From 3d98e5f7397c8eee88fb81c5c21af7deb3630851 Mon Sep 17 00:00:00 2001 From: nperez Date: Wed, 10 Apr 2024 16:17:41 +0200 Subject: [PATCH 01/16] k fold in make.eval.train.dexes --- R/CST_Calibration.R | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/R/CST_Calibration.R b/R/CST_Calibration.R index d4b9170b..5668c41c 100644 --- a/R/CST_Calibration.R +++ b/R/CST_Calibration.R @@ -567,10 +567,25 @@ Calibration <- function(exp, obs, exp_cor = NULL, } } -.make.eval.train.dexes <- function(eval.method, amt.points, amt.points_cor) { +.make.eval.train.dexes <- function(eval.method, amt.points, amt.points_cor, + k = 1) { if (eval.method == "leave-one-out") { dexes.lst <- lapply(seq(1, amt.points), function(x) return(list(eval.dexes = x, train.dexes = seq(1, amt.points)[-x]))) + } else if (eval.method == "k-fold") { + # k is a odd number + dexes.lst <- lapply(seq(1, amt.points), function(x, kfold = k) { + if (x >= ((kfold-1)/2) + 1 && x + ((kfold-1)/2) <= amt.points) { + ind <- (x-((kfold-1)/2)):(x+((kfold-1)/2)) + } else if (x < ((kfold-1)/2) + 1) { + ind <- c((amt.points - ((kfold-1)/2-x)): amt.points, 1:(x+(kfold-1)/2)) + } else if ((x+((kfold-1)/2)) > amt.points) { + ind <- c((x-(kfold-1)/2):amt.points, 1:(((kfold-1)/2)-amt.points + x)) + } else { + stop("Review make.eval.train.dexes function") + } + return(list(eval.dexes = x, train.dexes = seq(1, amt.points)[-ind])) + }) } else if (eval.method == "in-sample") { dexes.lst <- list(list(eval.dexes = seq(1, amt.points), train.dexes = seq(1, amt.points))) -- GitLab From 8173c9b30fedcce7de9413e05cb28c77ec85475c Mon Sep 17 00:00:00 2001 From: nperez Date: Wed, 10 Apr 2024 17:32:43 +0200 Subject: [PATCH 02/16] retrospective indices --- R/CST_Calibration.R | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/R/CST_Calibration.R b/R/CST_Calibration.R index 5668c41c..525514ad 100644 --- a/R/CST_Calibration.R +++ b/R/CST_Calibration.R @@ -585,7 +585,17 @@ Calibration <- function(exp, obs, exp_cor = NULL, stop("Review make.eval.train.dexes function") } return(list(eval.dexes = x, train.dexes = seq(1, amt.points)[-ind])) - }) + }) + } else if (eval.metrod == "retrospective") { + # k can be any integer indicating the when to start + dexes.lst <- Filter(length, lapply(seq(1, amt.points), + function(x, mindata = k) { + if (x > k) { + eval.dexes <- x + train.dexes <- 1:(x-1) + return(list(eval.dexes = x, + train.dexes = 1:(x-1))) + }})) } else if (eval.method == "in-sample") { dexes.lst <- list(list(eval.dexes = seq(1, amt.points), train.dexes = seq(1, amt.points))) -- GitLab From 2286cebcc31e96232d3892812aad6c6bcb6fb4dc Mon Sep 17 00:00:00 2001 From: nperez Date: Wed, 10 Apr 2024 18:00:42 +0200 Subject: [PATCH 03/16] fix typo --- R/CST_Calibration.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/CST_Calibration.R b/R/CST_Calibration.R index b2e87f2b..78c4d7c9 100644 --- a/R/CST_Calibration.R +++ b/R/CST_Calibration.R @@ -586,7 +586,7 @@ Calibration <- function(exp, obs, exp_cor = NULL, } return(list(eval.dexes = x, train.dexes = seq(1, amt.points)[-ind])) }) - } else if (eval.metrod == "retrospective") { + } else if (eval.method == "retrospective") { # k can be any integer indicating the when to start dexes.lst <- Filter(length, lapply(seq(1, amt.points), function(x, mindata = k) { -- GitLab From f6f0a6c2794e8333651273052c7a670455ae7742 Mon Sep 17 00:00:00 2001 From: nperez Date: Thu, 11 Apr 2024 14:19:52 +0200 Subject: [PATCH 04/16] unit test --- R/CST_Calibration.R | 12 ++- tests/testthat/test-make-eval-train-dexes.R | 93 +++++++++++++++++++++ 2 files changed, 101 insertions(+), 4 deletions(-) create mode 100644 tests/testthat/test-make-eval-train-dexes.R diff --git a/R/CST_Calibration.R b/R/CST_Calibration.R index 78c4d7c9..c36aabb2 100644 --- a/R/CST_Calibration.R +++ b/R/CST_Calibration.R @@ -568,7 +568,11 @@ Calibration <- function(exp, obs, exp_cor = NULL, } .make.eval.train.dexes <- function(eval.method, amt.points, amt.points_cor, - k = 1) { + k = 1) { + if (k >= amt.points && !is.null(k)) { + stop("k need to be smaller than the amt.points") + } + if (eval.method == "leave-one-out") { dexes.lst <- lapply(seq(1, amt.points), function(x) return(list(eval.dexes = x, train.dexes = seq(1, amt.points)[-x]))) @@ -578,14 +582,14 @@ Calibration <- function(exp, obs, exp_cor = NULL, if (x >= ((kfold-1)/2) + 1 && x + ((kfold-1)/2) <= amt.points) { ind <- (x-((kfold-1)/2)):(x+((kfold-1)/2)) } else if (x < ((kfold-1)/2) + 1) { - ind <- c((amt.points - ((kfold-1)/2-x)): amt.points, 1:(x+(kfold-1)/2)) + ind <- c((amt.points - ((kfold-1)/2-x)):amt.points, 1:(x+(kfold-1)/2)) } else if ((x+((kfold-1)/2)) > amt.points) { - ind <- c((x-(kfold-1)/2):amt.points, 1:(((kfold-1)/2)-amt.points + x)) + ind <- c((x-(kfold-1)/2):amt.points, 1:(((kfold-1)/2)-amt.points+x)) } else { stop("Review make.eval.train.dexes function") } return(list(eval.dexes = x, train.dexes = seq(1, amt.points)[-ind])) - }) + }) } else if (eval.method == "retrospective") { # k can be any integer indicating the when to start dexes.lst <- Filter(length, lapply(seq(1, amt.points), diff --git a/tests/testthat/test-make-eval-train-dexes.R b/tests/testthat/test-make-eval-train-dexes.R new file mode 100644 index 00000000..53e5dc88 --- /dev/null +++ b/tests/testthat/test-make-eval-train-dexes.R @@ -0,0 +1,93 @@ +############################################## + +test_that("1. Input checks", { + # s2dv_cube + expect_error( + .make.eval.train.dexes(eval.method = 1, 2), + paste0("unknown sampling method: 1") + ) + expect_error( + .make.eval.train.dexes(eval.method = 'k-fold', 10, NULL, 12), + paste0("k need to be smaller than the amt.points") + ) + expect_error( + .make.eval.train.dexes(eval.method = 'retrospective', 10, NULL, 10), + paste0("k need to be smaller than the amt.points") + ) +}) + +############################################## + +test_that("2. Output checks: k-fold", { + expect_equal( + is.list(.make.eval.train.dexes(eval.method = 'k-fold', 10, NULL, 5)), + TRUE + ) + expect_equal( + is.list(.make.eval.train.dexes(eval.method = 'k-fold', 10, NULL, 5)[[1]]), + TRUE + ) + expect_equal( + length(.make.eval.train.dexes(eval.method = 'k-fold', 10, NULL, 5)), + 10 + ) + expect_equal( + names(.make.eval.train.dexes(eval.method = 'k-fold', 10, NULL, 5)[[1]]), + c("eval.dexes", "train.dexes") + ) + expect_equal( + .make.eval.train.dexes(eval.method = 'k-fold', 10, NULL, 5)[[1]], + list(eval.dexes = 1, + train.dexes = 4:8) + ) + expect_equal( + .make.eval.train.dexes(eval.method = 'k-fold', 10, NULL, 5)[[2]], + list(eval.dexes = 2, + train.dexes = 5:9) + ) + expect_equal( + .make.eval.train.dexes(eval.method = 'k-fold', 10, NULL, 5)[[3]], + list(eval.dexes = 3, + train.dexes = 6:10) + ) + expect_equal( + .make.eval.train.dexes(eval.method = 'k-fold', 10, NULL, 5)[[9]], + list(eval.dexes = 9, + train.dexes = 2:6) + ) + expect_equal( + .make.eval.train.dexes(eval.method = 'k-fold', 10, NULL, 5)[[10]], + list(eval.dexes = 10, + train.dexes = 3:7) + ) + expect_equal( + .make.eval.train.dexes(eval.method = 'k-fold', 10, NULL, 1), + .make.eval.train.dexes(eval.method = 'leave-one-out', 10, NULL) + ) + expect_equal( + .make.eval.train.dexes(eval.method = 'k-fold', 20, NULL, 1), + .make.eval.train.dexes(eval.method = 'leave-one-out', 20, NULL) + ) +}) + +############################################## + +test_that("3. Output checks: retrospective", { + expect_equal( + length(.make.eval.train.dexes(eval.method = 'retrospective', 10, NULL, 3)), + 7 + ) + expect_equal( + .make.eval.train.dexes(eval.method = 'retrospective', 10, NULL, 3)[[1]], + list(eval.dexes = 4, train.dexes = 1:3) + ) + expect_equal( + .make.eval.train.dexes(eval.method = 'retrospective', 10, NULL, 3)[[2]], + list(eval.dexes = 5, train.dexes = 1:4) + ) + expect_equal( + .make.eval.train.dexes(eval.method = 'retrospective', 10, NULL, 3)[[7]], + list(eval.dexes = 10, train.dexes = 1:9) + ) +}) + -- GitLab From 87422b119df951815a1a4e23ff676eb3bb6df0be Mon Sep 17 00:00:00 2001 From: nperez Date: Thu, 11 Apr 2024 15:48:33 +0200 Subject: [PATCH 05/16] Filter function from base package --- R/CST_Calibration.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/CST_Calibration.R b/R/CST_Calibration.R index c36aabb2..7cad2fc9 100644 --- a/R/CST_Calibration.R +++ b/R/CST_Calibration.R @@ -592,7 +592,7 @@ Calibration <- function(exp, obs, exp_cor = NULL, }) } else if (eval.method == "retrospective") { # k can be any integer indicating the when to start - dexes.lst <- Filter(length, lapply(seq(1, amt.points), + dexes.lst <- base::Filter(length, lapply(seq(1, amt.points), function(x, mindata = k) { if (x > k) { eval.dexes <- x -- GitLab From ed2bc3ba6088d78eaf93a1df104c2f086df6a976 Mon Sep 17 00:00:00 2001 From: Sara Moreno Date: Thu, 18 Jul 2024 12:59:45 +0200 Subject: [PATCH 06/16] tail out parameter for cross val --- R/CST_Calibration.R | 29 +++++++++++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/R/CST_Calibration.R b/R/CST_Calibration.R index 7cad2fc9..048b7243 100644 --- a/R/CST_Calibration.R +++ b/R/CST_Calibration.R @@ -568,7 +568,17 @@ Calibration <- function(exp, obs, exp_cor = NULL, } .make.eval.train.dexes <- function(eval.method, amt.points, amt.points_cor, - k = 1) { + k = 1, tail.out = TRUE) { +# eval.method: 'leave-one-out', 'k-fold', 'in-sample', "hindcast-vs-forecast", "retrospective" +# amt.points: length of the sample +# amt.points_cor: only needed for hindcast-vs-forecast method +# k: the number of samples to leave-out +# tail_out: boolean for 'k-fold method; TRUE to remove both extremes of the sample when k is +# in the extreme keeping the same sample size for all k-folds (e.g. amt.points=50, k=3, +# eval.dexes=1, train.dexes={3,49}). FALSE to remove only the corresponding tail (e.g.. +# amt.points=50, k=3, eval.dexes=1, train.dexes={3,50}) + + if (k >= amt.points && !is.null(k)) { stop("k need to be smaller than the amt.points") } @@ -578,7 +588,8 @@ Calibration <- function(exp, obs, exp_cor = NULL, train.dexes = seq(1, amt.points)[-x]))) } else if (eval.method == "k-fold") { # k is a odd number - dexes.lst <- lapply(seq(1, amt.points), function(x, kfold = k) { + if (tail.out == TRUE){ + dexes.lst <- lapply(seq(1, amt.points), function(x, kfold = k) { if (x >= ((kfold-1)/2) + 1 && x + ((kfold-1)/2) <= amt.points) { ind <- (x-((kfold-1)/2)):(x+((kfold-1)/2)) } else if (x < ((kfold-1)/2) + 1) { @@ -590,6 +601,20 @@ Calibration <- function(exp, obs, exp_cor = NULL, } return(list(eval.dexes = x, train.dexes = seq(1, amt.points)[-ind])) }) + } else { + if (k != 1){ + k <- (k-1)/2 + } + + dexes.lst <- lapply(seq(1, amt.points), function(x) { + start_idx <- max(1, x - k) + end_idx <- min(amt.points, x + k) + eval_range <- start_idx:end_idx + train_idx <- setdiff(seq(1, amt.points), eval_range) + return(list(eval.dexes = x, train.dexes = train_idx)) + }) + } + } else if (eval.method == "retrospective") { # k can be any integer indicating the when to start dexes.lst <- base::Filter(length, lapply(seq(1, amt.points), -- GitLab From 62b81970d6e7b052b976f5ee1e6d8551883b97f7 Mon Sep 17 00:00:00 2001 From: THEERTHA KARIYATHAN Date: Tue, 12 Nov 2024 15:24:14 +0100 Subject: [PATCH 07/16] add checks for new calibration methods --- R/CST_Calibration.R | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/R/CST_Calibration.R b/R/CST_Calibration.R index 048b7243..6f20c3ff 100644 --- a/R/CST_Calibration.R +++ b/R/CST_Calibration.R @@ -165,7 +165,7 @@ CST_Calibration <- function(exp, obs, exp_cor = NULL, cal.method = "mse_min", multi.model = multi.model, na.fill = na.fill, na.rm = na.rm, apply_to = apply_to, alpha = alpha, memb_dim = memb_dim, sdate_dim = sdate_dim, - dat_dim = dat_dim, ncores = ncores) + dat_dim = dat_dim, ncores = ncores,k = k,tail.out = tail.out) if (is.null(exp_cor)) { exp$data <- Calibration @@ -305,7 +305,7 @@ Calibration <- function(exp, obs, exp_cor = NULL, multi.model = FALSE, na.fill = TRUE, na.rm = TRUE, apply_to = NULL, alpha = NULL, memb_dim = 'member', sdate_dim = 'sdate', dat_dim = NULL, - ncores = NULL) { + ncores = NULL, k = 1, tail.out = TRUE) { # Check inputs ## exp, obs @@ -478,10 +478,10 @@ Calibration <- function(exp, obs, exp_cor = NULL, } } ## eval.method - if (!any(eval.method %in% c('in-sample', 'leave-one-out', 'hindcast-vs-forecast'))) { + if (!any(eval.method %in% c('in-sample', 'leave-one-out', 'hindcast-vs-forecast', 'k-fold', 'retrospective'))) { stop(paste0("Parameter 'eval.method' must be a character string indicating ", - "the sampling method used ('in-sample', 'leave-one-out' or ", - "'hindcast-vs-forecast').")) + "the sampling method used ('in-sample', 'leave-one-out', 'hindcast-vs-forecast', 'k-fold' or ", + "'retrospective').")) } ## multi.model if (!inherits(multi.model, "logical")) { @@ -568,7 +568,7 @@ Calibration <- function(exp, obs, exp_cor = NULL, } .make.eval.train.dexes <- function(eval.method, amt.points, amt.points_cor, - k = 1, tail.out = TRUE) { + k, tail.out) { # eval.method: 'leave-one-out', 'k-fold', 'in-sample', "hindcast-vs-forecast", "retrospective" # amt.points: length of the sample # amt.points_cor: only needed for hindcast-vs-forecast method @@ -578,7 +578,6 @@ Calibration <- function(exp, obs, exp_cor = NULL, # eval.dexes=1, train.dexes={3,49}). FALSE to remove only the corresponding tail (e.g.. # amt.points=50, k=3, eval.dexes=1, train.dexes={3,50}) - if (k >= amt.points && !is.null(k)) { stop("k need to be smaller than the amt.points") } @@ -615,16 +614,16 @@ Calibration <- function(exp, obs, exp_cor = NULL, }) } - } else if (eval.method == "retrospective") { + } else if (eval.method == "retrospective") { # k can be any integer indicating the when to start dexes.lst <- base::Filter(length, lapply(seq(1, amt.points), function(x, mindata = k) { - if (x > k) { + if (x > k) { eval.dexes <- x train.dexes <- 1:(x-1) return(list(eval.dexes = x, train.dexes = 1:(x-1))) - }})) + }})) } else if (eval.method == "in-sample") { dexes.lst <- list(list(eval.dexes = seq(1, amt.points), train.dexes = seq(1, amt.points))) @@ -634,6 +633,7 @@ Calibration <- function(exp, obs, exp_cor = NULL, } else { stop(paste0("unknown sampling method: ", eval.method)) } + return(dexes.lst) } @@ -696,7 +696,8 @@ Calibration <- function(exp, obs, exp_cor = NULL, eval.train.dexeses <- .make.eval.train.dexes(eval.method = eval.method, amt.points = sdate, - amt.points_cor = sdate_cor) + amt.points_cor = sdate_cor, + k = k, tail.out = tail.out) amt.resamples <- length(eval.train.dexeses) for (i.sample in seq(1, amt.resamples)) { # defining training (tr) and evaluation (ev) subsets -- GitLab From 5e35c00cf2d9b8e61794750c2e35da199eb4a369 Mon Sep 17 00:00:00 2001 From: THEERTHA KARIYATHAN Date: Tue, 12 Nov 2024 15:27:00 +0100 Subject: [PATCH 08/16] tests for k-fold and retrospective --- tests/testthat/test-CST_Calibration.R | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/tests/testthat/test-CST_Calibration.R b/tests/testthat/test-CST_Calibration.R index 491aff29..4f5b34b3 100644 --- a/tests/testthat/test-CST_Calibration.R +++ b/tests/testthat/test-CST_Calibration.R @@ -218,8 +218,8 @@ test_that("1. Input checks", { expect_error( Calibration(exp4, obs4, eval.method = 'biass'), paste0("Parameter 'eval.method' must be a character string indicating ", - "the sampling method used ('in-sample', 'leave-one-out' or ", - "'hindcast-vs-forecast')."), + "the sampling method used ('in-sample', 'leave-one-out', 'hindcast-vs-forecast', + 'k-fold', or ", "'retrospective')."), fixed = TRUE ) # multi.model @@ -453,6 +453,18 @@ test_that("6. Output checks: dat4", { c(-0.7119142, 0.2626203, -0.9635483, 1.9607986, 0.4380930), tolerance = 0.0001 ) + expect_equal( + as.vector(Calibration(exp4, obs4, eval.method = "k-fold", k = 5, tail.out = T))[1:5], + c(-0.8149557, 0.2043942, -1.0781616, 1.9806657, 0.3879362), + tolerance = 0.0001 + ) + + expect_equal( + as.vector(Calibration(exp4, obs4, eval.method = "retrospective", k = 3)[1,1:5]), + c(NA, NA, NA, 1.6267865, -0.1658674), + tolerance = 0.0001 + ) + }) ############################################## -- GitLab From a0bd3a8c318c8376e7811bbb3361e77d7b946cfa Mon Sep 17 00:00:00 2001 From: THEERTHA KARIYATHAN Date: Tue, 12 Nov 2024 16:04:39 +0100 Subject: [PATCH 09/16] k and tail.out added to CST_Calibration.Rd/R --- R/CST_Calibration.R | 4 ++-- man/CST_Calibration.Rd | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/R/CST_Calibration.R b/R/CST_Calibration.R index 6f20c3ff..11367952 100644 --- a/R/CST_Calibration.R +++ b/R/CST_Calibration.R @@ -149,7 +149,7 @@ CST_Calibration <- function(exp, obs, exp_cor = NULL, cal.method = "mse_min", eval.method = "leave-one-out", multi.model = FALSE, na.fill = TRUE, na.rm = TRUE, apply_to = NULL, alpha = NULL, memb_dim = 'member', sdate_dim = 'sdate', - dat_dim = NULL, ncores = NULL) { + dat_dim = NULL, ncores = NULL, k = 1, tail.out = TRUE) { # Check 's2dv_cube' if (!inherits(exp, "s2dv_cube") || !inherits(obs, "s2dv_cube")) { stop("Parameter 'exp' and 'obs' must be of the class 's2dv_cube'.") @@ -165,7 +165,7 @@ CST_Calibration <- function(exp, obs, exp_cor = NULL, cal.method = "mse_min", multi.model = multi.model, na.fill = na.fill, na.rm = na.rm, apply_to = apply_to, alpha = alpha, memb_dim = memb_dim, sdate_dim = sdate_dim, - dat_dim = dat_dim, ncores = ncores,k = k,tail.out = tail.out) + dat_dim = dat_dim, ncores = ncores, k = k, tail.out = tail.out) if (is.null(exp_cor)) { exp$data <- Calibration diff --git a/man/CST_Calibration.Rd b/man/CST_Calibration.Rd index 491b7271..6843f1c9 100644 --- a/man/CST_Calibration.Rd +++ b/man/CST_Calibration.Rd @@ -18,7 +18,9 @@ CST_Calibration( memb_dim = "member", sdate_dim = "sdate", dat_dim = NULL, - ncores = NULL + ncores = NULL, + k = 1, + tail.out = TRUE ) } \arguments{ -- GitLab From d1cb46bed9f49d1bf95c6e62d65bb1ea6ace8977 Mon Sep 17 00:00:00 2001 From: THEERTHA KARIYATHAN Date: Thu, 14 Nov 2024 10:48:59 +0100 Subject: [PATCH 10/16] k and tail.iut added to man/CST_CategoricalEnsCombination.Rd/R --- R/CST_Calibration.R | 6 +++--- R/CST_CategoricalEnsCombination.R | 17 +++++++++++------ man/CST_CategoricalEnsCombination.Rd | 2 ++ 3 files changed, 16 insertions(+), 9 deletions(-) diff --git a/R/CST_Calibration.R b/R/CST_Calibration.R index 11367952..fc972d6a 100644 --- a/R/CST_Calibration.R +++ b/R/CST_Calibration.R @@ -513,7 +513,7 @@ Calibration <- function(exp, obs, exp_cor = NULL, cal.method = cal.method, eval.method = eval.method, multi.model = multi.model, na.fill = na.fill, na.rm = na.rm, apply_to = apply_to, alpha = alpha, target_dims = list(exp = target_dims_exp, obs = target_dims_obs), - ncores = ncores, fun = .cal)$output1 + ncores = ncores, k = k, tail.out = tail.out, fun = .cal)$output1 } else { calibrated <- Apply(data = list(exp = exp, obs = obs, exp_cor = exp_cor), dat_dim = dat_dim, cal.method = cal.method, eval.method = eval.method, @@ -521,7 +521,7 @@ Calibration <- function(exp, obs, exp_cor = NULL, apply_to = apply_to, alpha = alpha, target_dims = list(exp = target_dims_exp, obs = target_dims_obs, exp_cor = target_dims_cor), - ncores = ncores, fun = .cal)$output1 + ncores = ncores, k = k, tail.out = tail.out, fun = .cal)$output1 } if (!is.null(dat_dim)) { @@ -639,7 +639,7 @@ Calibration <- function(exp, obs, exp_cor = NULL, .cal <- function(exp, obs, exp_cor = NULL, dat_dim = NULL, cal.method = "mse_min", eval.method = "leave-one-out", multi.model = FALSE, na.fill = TRUE, - na.rm = TRUE, apply_to = NULL, alpha = NULL) { + na.rm = TRUE, apply_to = NULL, alpha = NULL, k = 1, tail.out = TRUE) { # exp: [memb, sdate, (dat)] # obs: [sdate (dat)] diff --git a/R/CST_CategoricalEnsCombination.R b/R/CST_CategoricalEnsCombination.R index 3c92d587..f973cde7 100644 --- a/R/CST_CategoricalEnsCombination.R +++ b/R/CST_CategoricalEnsCombination.R @@ -97,7 +97,7 @@ #'@export CST_CategoricalEnsCombination <- function(exp, obs, cat.method = "pool", eval.method = "leave-one-out", - amt.cat = 3, + amt.cat = 3, k = 1, tail.out = TRUE, ...) { # Check 's2dv_cube' if (!inherits(exp, "s2dv_cube") || !inherits(exp, "s2dv_cube")) { @@ -113,7 +113,8 @@ CST_CategoricalEnsCombination <- function(exp, obs, cat.method = "pool", exp$data <- CategoricalEnsCombination(fc = exp$data, obs = obs$data, cat.method = cat.method, eval.method = eval.method, - amt.cat = amt.cat, ...) + amt.cat = amt.cat, k = k, + tail.out = tail.out, ...) names.dim.tmp[which(names.dim.tmp == "member")] <- "category" names(dim(exp$data)) <- names.dim.tmp @@ -175,7 +176,7 @@ CST_CategoricalEnsCombination <- function(exp, obs, cat.method = "pool", #'@importFrom s2dv InsertDim #'@import abind #'@export -CategoricalEnsCombination <- function (fc, obs, cat.method, eval.method, amt.cat, ...) { +CategoricalEnsCombination <- function (fc, obs, cat.method, eval.method, amt.cat, k, tail.out, ...) { if (!all(c("member", "sdate") %in% names(dim(fc)))) { stop("Parameter 'exp' must have the dimensions 'member' and 'sdate'.") @@ -205,6 +206,8 @@ CategoricalEnsCombination <- function (fc, obs, cat.method, eval.method, amt.cat cat.method = cat.method, eval.method = eval.method, amt.cat = amt.cat, + k = k, + tail.out = tail.out, ...) return(cat_fc_out) } @@ -243,7 +246,7 @@ comb.dims <- function(arr.in, dims.to.combine){ return(arr.out) } -.apply.obs.fc <- function(obs, fc, target.dims, FUN, return.feat, cat.method, eval.method, amt.cat, ...){ +.apply.obs.fc <- function(obs, fc, target.dims, FUN, return.feat, cat.method, eval.method, amt.cat, k, tail.out, ...){ dimnames.tmp <- dimnames(fc) fc.dims.tmp <- dim(fc) dims.out.tmp <- return.feat$dim @@ -260,6 +263,8 @@ comb.dims <- function(arr.in, dims.to.combine){ cat.method = cat.method, eval.method = eval.method, amt.cat = amt.cat, + k = k, + tail.out = tail.out, ...) dims.tmp <- dim(arr.out) names.dims.tmp <- names(dim(arr.out)) @@ -280,7 +285,7 @@ comb.dims <- function(arr.in, dims.to.combine){ } -.cat_fc <- function(obs.fc, amt.cat, cat.method, eval.method) { +.cat_fc <- function(obs.fc, amt.cat, cat.method, eval.method, k, tail.out) { dims.tmp=dim(obs.fc) amt.mbr <- dims.tmp["member"][]-1 @@ -299,7 +304,7 @@ comb.dims <- function(arr.in, dims.to.combine){ amt.coeff <- amt.mdl + 1 var.cat.fc <- array(NA, c(amt.cat, amt.sdate)) - eval.train.dexeses <- .make.eval.train.dexes(eval.method = eval.method, amt.points = amt.sdate) + eval.train.dexeses <- .make.eval.train.dexes(eval.method = eval.method, amt.points = amt.sdate, k = k, tail.out = tail.out) amt.resamples <- length(eval.train.dexeses) for (i.sample in seq(1, amt.resamples)) { diff --git a/man/CST_CategoricalEnsCombination.Rd b/man/CST_CategoricalEnsCombination.Rd index 85ebb7f8..416ad7f8 100644 --- a/man/CST_CategoricalEnsCombination.Rd +++ b/man/CST_CategoricalEnsCombination.Rd @@ -11,6 +11,8 @@ CST_CategoricalEnsCombination( cat.method = "pool", eval.method = "leave-one-out", amt.cat = 3, + k = 1, + tail.out = TRUE, ... ) } -- GitLab From 3ec5fecfb4f88d0a7ba7d41ab25517246df94f6e Mon Sep 17 00:00:00 2001 From: THEERTHA KARIYATHAN Date: Thu, 14 Nov 2024 12:28:54 +0100 Subject: [PATCH 11/16] set default value for tail.out in .make.eval.train.dexes() --- R/CST_Calibration.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/CST_Calibration.R b/R/CST_Calibration.R index fc972d6a..d900a773 100644 --- a/R/CST_Calibration.R +++ b/R/CST_Calibration.R @@ -568,7 +568,7 @@ Calibration <- function(exp, obs, exp_cor = NULL, } .make.eval.train.dexes <- function(eval.method, amt.points, amt.points_cor, - k, tail.out) { + k, tail.out = TRUE) { # eval.method: 'leave-one-out', 'k-fold', 'in-sample', "hindcast-vs-forecast", "retrospective" # amt.points: length of the sample # amt.points_cor: only needed for hindcast-vs-forecast method -- GitLab From cd74879c86f842aa20161a23ca20244c987c12b5 Mon Sep 17 00:00:00 2001 From: THEERTHA KARIYATHAN Date: Thu, 14 Nov 2024 12:42:31 +0100 Subject: [PATCH 12/16] set default value for k in .make.eval.train.dexes() --- R/CST_Calibration.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/CST_Calibration.R b/R/CST_Calibration.R index d900a773..86bd7712 100644 --- a/R/CST_Calibration.R +++ b/R/CST_Calibration.R @@ -568,7 +568,7 @@ Calibration <- function(exp, obs, exp_cor = NULL, } .make.eval.train.dexes <- function(eval.method, amt.points, amt.points_cor, - k, tail.out = TRUE) { + k = 1, tail.out = TRUE) { # eval.method: 'leave-one-out', 'k-fold', 'in-sample', "hindcast-vs-forecast", "retrospective" # amt.points: length of the sample # amt.points_cor: only needed for hindcast-vs-forecast method -- GitLab From 5942810b1a09215ae2b97c7086942755d7d270ab Mon Sep 17 00:00:00 2001 From: THEERTHA KARIYATHAN Date: Thu, 14 Nov 2024 14:52:07 +0100 Subject: [PATCH 13/16] test-CST_Calibration.R update --- tests/testthat/test-CST_Calibration.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/testthat/test-CST_Calibration.R b/tests/testthat/test-CST_Calibration.R index 4f5b34b3..511d845c 100644 --- a/tests/testthat/test-CST_Calibration.R +++ b/tests/testthat/test-CST_Calibration.R @@ -218,8 +218,8 @@ test_that("1. Input checks", { expect_error( Calibration(exp4, obs4, eval.method = 'biass'), paste0("Parameter 'eval.method' must be a character string indicating ", - "the sampling method used ('in-sample', 'leave-one-out', 'hindcast-vs-forecast', - 'k-fold', or ", "'retrospective')."), + "the sampling method used ('in-sample', 'leave-one-out', 'hindcast-vs-forecast', 'k-fold', or ", + "'retrospective')."), fixed = TRUE ) # multi.model -- GitLab From 8db9c93d0746079b0c1e4e29c1779d336dfe534a Mon Sep 17 00:00:00 2001 From: THEERTHA KARIYATHAN Date: Thu, 14 Nov 2024 15:26:13 +0100 Subject: [PATCH 14/16] test-CST_Calibration.R update error message --- tests/testthat/test-CST_Calibration.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/testthat/test-CST_Calibration.R b/tests/testthat/test-CST_Calibration.R index 511d845c..285e81c8 100644 --- a/tests/testthat/test-CST_Calibration.R +++ b/tests/testthat/test-CST_Calibration.R @@ -218,7 +218,7 @@ test_that("1. Input checks", { expect_error( Calibration(exp4, obs4, eval.method = 'biass'), paste0("Parameter 'eval.method' must be a character string indicating ", - "the sampling method used ('in-sample', 'leave-one-out', 'hindcast-vs-forecast', 'k-fold', or ", + "the sampling method used ('in-sample', 'leave-one-out', 'hindcast-vs-forecast', 'k-fold' or ", "'retrospective')."), fixed = TRUE ) -- GitLab From 052c3f2a9caeaf36b5785ac6a38d9a06fa33be75 Mon Sep 17 00:00:00 2001 From: THEERTHA KARIYATHAN Date: Thu, 9 Jan 2025 15:26:55 +0100 Subject: [PATCH 15/16] Add .Rproj.user/ to .gitignore --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index 2f6c062a..14ce9c42 100644 --- a/.gitignore +++ b/.gitignore @@ -17,3 +17,5 @@ Rplots.pdf .nfs* *.RData !data/*.RData +.Rproj.user/ + -- GitLab From 1707b16a3b8c683b59f7d26ffaa8b7d5113b750b Mon Sep 17 00:00:00 2001 From: THEERTHA KARIYATHAN Date: Thu, 9 Jan 2025 16:11:20 +0100 Subject: [PATCH 16/16] check +ve k give error msg --- R/CST_Calibration.R | 7 +++---- tests/testthat/test-make-eval-train-dexes.R | 4 ++-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/R/CST_Calibration.R b/R/CST_Calibration.R index 86bd7712..c4942eaa 100644 --- a/R/CST_Calibration.R +++ b/R/CST_Calibration.R @@ -577,11 +577,10 @@ Calibration <- function(exp, obs, exp_cor = NULL, # in the extreme keeping the same sample size for all k-folds (e.g. amt.points=50, k=3, # eval.dexes=1, train.dexes={3,49}). FALSE to remove only the corresponding tail (e.g.. # amt.points=50, k=3, eval.dexes=1, train.dexes={3,50}) - - if (k >= amt.points && !is.null(k)) { - stop("k need to be smaller than the amt.points") - } + if (k >= amt.points && !is.null(k) | k < 1) { + stop("k needs to be a positive integer less than the amt.points") + } if (eval.method == "leave-one-out") { dexes.lst <- lapply(seq(1, amt.points), function(x) return(list(eval.dexes = x, train.dexes = seq(1, amt.points)[-x]))) diff --git a/tests/testthat/test-make-eval-train-dexes.R b/tests/testthat/test-make-eval-train-dexes.R index 53e5dc88..a74a06c0 100644 --- a/tests/testthat/test-make-eval-train-dexes.R +++ b/tests/testthat/test-make-eval-train-dexes.R @@ -8,11 +8,11 @@ test_that("1. Input checks", { ) expect_error( .make.eval.train.dexes(eval.method = 'k-fold', 10, NULL, 12), - paste0("k need to be smaller than the amt.points") + paste0("k needs to be a positive integer less than the amt.points") ) expect_error( .make.eval.train.dexes(eval.method = 'retrospective', 10, NULL, 10), - paste0("k need to be smaller than the amt.points") + paste0("k needs to be a positive integer less than the amt.points") ) }) -- GitLab