Cross Validations
TanayLabUtilities.CrossValidations
—
Module
Cross-validation functions.
TanayLabUtilities.CrossValidations.CrossValidationIndices
—
Type
@kwdef struct CrossValidationIndices
# Number of cross-validation parts.
n_parts::Integer
# One vector of training indices per part; each is the union of the test indices of all the other parts.
train_indices_per_part::AbstractVector{<:AbstractVector{<:Integer}}
# One vector of test indices per part, distinct from that part's training indices.
test_indices_per_part::AbstractVector{<:AbstractVector{<:Integer}}
end
A set of cross-validation indices for `n_parts` parts, where each part is a combination of distinct test and training indices, such that the training indices for each part are the union of the test indices of all the other parts.
TanayLabUtilities.CrossValidations.pick_cross_validation_indices
—
Function
function pick_cross_validation_indices(;
full_indices::AbstractVector{<:Integer},
cross_validation_parts::Integer,
rng::AbstractRNG,
)::CrossValidationIndices
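Given a vector of `full_indices`, split them into `cross_validation_parts` parts, where each part has distinct training and testing indices. For example, splitting 12 indices into 3 parts gives each part 4 test indices and 8 training indices.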
using Test
using Random

# A seeded RNG keeps the example reproducible across runs.
cross_validation_indices = pick_cross_validation_indices(;
    full_indices = collect(1:12),
    cross_validation_parts = 3,
    rng = MersenneTwister(123),
)

test_per_part = cross_validation_indices.test_indices_per_part
train_per_part = cross_validation_indices.train_indices_per_part

# There is one (test, train) pair per part.
@test length(test_per_part) == length(train_per_part) == cross_validation_indices.n_parts == 3

# 12 indices in 3 parts: each test set holds 4 indices, each training set the other 8.
@test all(length(test_indices) == 4 for test_indices in test_per_part)
@test all(length(train_indices) == 8 for train_indices in train_per_part)

# Within each part, test and training indices are disjoint and together cover all 12 indices.
@test all(
    isempty(intersect(test_indices, train_indices)) &&
    sort(vcat(test_indices, train_indices)) == 1:12
    for (test_indices, train_indices) in zip(test_per_part, train_per_part)
)

# The test sets of the different parts are disjoint and together cover all 12 indices.
@test sort(reduce(vcat, test_per_part)) == 1:12

# Each part's training indices are exactly the union of the other parts' test indices.
@test all(
    sort(train_per_part[part]) == sort(reduce(
        vcat,
        [test_per_part[other] for other in eachindex(test_per_part) if other != part],
    ))
    for part in eachindex(train_per_part)
)

println("OK")
# output
OK