From 82be6a6b1ea539d3d4b21407a3783e21ba3719bf Mon Sep 17 00:00:00 2001
From: nperez
Date: Wed, 16 Jul 2025 16:35:52 +0200
Subject: [PATCH] unit test for sub-seasonal crossval calibration

---
 modules/Crossval/Crossval_calibration.R    |   4 +-
 modules/Crossval/Crossval_metrics.R        |  10 +-
 .../test-subseasonal_weekly_crossval.R     | 116 ++++++++++++++++++
 3 files changed, 126 insertions(+), 4 deletions(-)
 create mode 100644 tests/testthat/test-subseasonal_weekly_crossval.R

diff --git a/modules/Crossval/Crossval_calibration.R b/modules/Crossval/Crossval_calibration.R
index 75765e75..48f250c9 100644
--- a/modules/Crossval/Crossval_calibration.R
+++ b/modules/Crossval/Crossval_calibration.R
@@ -56,10 +56,10 @@ Crossval_calibration <- function(recipe, data, correct_negative = FALSE) {
     if (horizon == 'subseasonal') {
       central_day <- (dim(exp)['sday'] + 1)/2
       hcst_tr <- MergeDims(hcst_tr, merge_dims = c('sday', 'syear'),
-                           rename_dim = 'syear', na.rm = na.rm)
+                           rename_dim = 'syear', na.rm = FALSE)
       obs_tr <- MergeDims(obs_tr, merge_dims = c('sday', 'syear'),
-                          rename_dim = 'syear', na.rm = na.rm)
+                          rename_dim = 'syear', na.rm = FALSE)
       # 'sday' dim to select the central day
       hcst_ev <- Subset(hcst_ev, along = 'sday', indices = central_day,
                         drop = 'selected')
diff --git a/modules/Crossval/Crossval_metrics.R b/modules/Crossval/Crossval_metrics.R
index 1532fdc2..0026dc0c 100644
--- a/modules/Crossval/Crossval_metrics.R
+++ b/modules/Crossval/Crossval_metrics.R
@@ -92,7 +92,7 @@ Crossval_metrics <- function(recipe,
                     extra_info = list(Fair = Fair),
                     ncores = ncores)$output1
   skill_metrics$crps <- crps
-  if (is.null(datos$ref_obs)) {
+  if (is.null(data$ref_obs)) {
     # Build the reference forecast:
     ref_clim <- Apply(list(datos$obs, tmp),
                       target_dims = list(c('ensemble', 'syear'), NULL),
@@ -101,8 +101,14 @@ Crossval_metrics <- function(recipe,
                                indices = cross[[y]]$train.dexes, drop = T)},
                       ncores = ncores, output_dims = 'ensemble')$output1
   } else {
-    ref_clim <- datos$ref_obs
+    ref_clim <- data$ref_obs
   }
+  if (horizon == 'subseasonal') {
+    # The evaluation of all metrics is done with the extra samples
+    ref_clim <- MergeDims(ref_clim, merge_dims = c('sweek', 'syear'),
+                          rename_dim = 'syear', na.rm = FALSE)
+  }
+
   ## Do we want to do this...?
   # CRPS Climatology
   datos <- append(list(ref = ref_clim), datos)
diff --git a/tests/testthat/test-subseasonal_weekly_crossval.R b/tests/testthat/test-subseasonal_weekly_crossval.R
new file mode 100644
index 00000000..f7d9e812
--- /dev/null
+++ b/tests/testthat/test-subseasonal_weekly_crossval.R
@@ -0,0 +1,116 @@
+context("Subseasonal weekly data")
+
+source("modules/Loading/Loading.R")
+source("modules/Visualization/Visualization.R")
+
+recipe_file <- "tests/recipes/recipe-subseasonal_weekly.yml"
+recipe <- prepare_outputs(recipe_file, disable_checks = F)
+
+# Load datasets
+suppressWarnings({invisible(capture.output(
+data <- Loading(recipe)
+))})
+
+# Calibrate data
+source("modules/Crossval/Crossval_calibration.R")
+calibrated <- Crossval_calibration(recipe = recipe, data = data)
+
+# Compute skill metrics
+source("modules/Crossval/Crossval_metrics.R")
+suppressWarnings({invisible(capture.output(
+skill_metrics <- Crossval_metrics(recipe, calibrated)
+))})
+
+# Plotting
+suppressWarnings({invisible(capture.output(
+Visualization(recipe = recipe, data = calibrated,
+              skill_metrics = skill_metrics, probabilities = calibrated$probs,
+              significance = T)
+))})
+outdir <- get_dir(recipe = recipe, variable = data$hcst$attrs$Variable$varName)
+
+# ------- TESTS --------
+
+test_that("2. Crossval_calibration", {
+
+expect_equal(is.list(calibrated), TRUE)
+
+expect_equal(names(calibrated),
+             c("hcst", "obs", "fcst", "hcst.full_val", "obs.full_val",
+               "cat_lims", "probs", "ref_obs"))
+
+expect_equal(class(calibrated$hcst), "s2dv_cube")
+
+expect_equal(class(calibrated$fcst), "s2dv_cube")
+
+expect_equal(dim(calibrated$hcst$data),
+             c(sday = 1, dat = 1, var = 1, sweek = 5, time = 4, latitude = 10,
+               longitude = 21, ensemble = 12, syear = 5))
+
+expect_equal(dim(calibrated$fcst$data),
+             c(dat = 1, var = 1, sday = 1, syear = 1, sweek = 1, time = 4,
+               latitude = 10, longitude = 21, ensemble = 48))
+
+expect_equal(mean(calibrated$fcst$data),
+             299.9444, tolerance = 0.0001)
+
+expect_equal(mean(calibrated$hcst$data, na.rm = TRUE),
+             299.4878, tolerance = 0.0001)
+
+expect_equal(as.vector(drop(calibrated$hcst$data)[2, , 1, 2, 3, 4]),
+             c(298.6695, 298.6082, 299.1191, 299.1713), tolerance = 0.0001)
+
+expect_equal(range(calibrated$fcst$data),
+             c(297.8898, 302.3096), tolerance = 0.0001)
+
+})
+
+
+#======================================
+test_that("3. Crossval_metrics", {
+
+expect_equal(is.list(skill_metrics), TRUE)
+
+expect_equal(names(skill_metrics),
+             c("crps", "crps_clim", "crpss",
+               "crpss_significance",
+               "enssprerr", "enssprerr_significance",
+               "enscorr", "enscorr_significance",
+               "rpss-set1", "rpss-set1_significance",
+               "mean_bias", "mean_bias_significance",
+               "rpss-set2", "rpss-set2_significance",
+               "rpss-set3", "rpss-set3_significance"))
+
+expect_equal(class(skill_metrics[['rpss-set1']]), "array")
+
+expect_equal(dim(skill_metrics[['rpss-set1']]),
+             c(var = 1, time = 4, latitude = 10, longitude = 21))
+
+expect_equal(dim(skill_metrics[['rpss-set1_significance']]),
+             dim(skill_metrics[['rpss-set1']]))
+
+expect_equal(as.vector(skill_metrics[['rpss-set1']][, , 2, 3]),
+             c(0.22625000, 0.41974432, -0.07348901, -0.06593407),
+             tolerance = 0.0001)
+
+expect_equal(as.vector(skill_metrics[['rpss-set1_significance']][, , 2, 3]),
+             c(0, 0, 0, 0))
+
+})
+
+
+test_that("5. Visualization", {
+plots <- paste0(recipe$Run$output_dir, "/plots/")
+expect_equal(all(basename(list.files(plots, recursive = T)) %in%
+                   c("forecast_ensemble_median-20240104_ft01.png",
+                     "forecast_ensemble_median-20240104_ft02.png",
+                     "forecast_ensemble_median-20240104_ft03.png",
+                     "forecast_ensemble_median-20240104_ft04.png")),
+             TRUE)
+
+expect_equal(length(list.files(plots, recursive = T)), 4)
+
+})
+
+# Delete files
+unlink(recipe$Run$output_dir, recursive = T)
--
GitLab
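
For illustration only (not part of the patch): a minimal standalone sketch of the 'sweek'/'syear' merge that the new block in Crossval_metrics.R applies to ref_clim for the subseasonal horizon. It assumes MergeDims() from the CSTools package, which offers the same interface as the helper used above; the toy array and its dimension sizes are invented for the example.

# Toy reference-observation array: 3 'sweek' samples x 5 'syear' samples (assumed sizes)
library(CSTools)
ref_obs <- array(rnorm(3 * 5 * 4), dim = c(sweek = 3, syear = 5, time = 4))
# Merge 'sweek' and 'syear' into a single 'syear' dimension of length 15,
# mirroring the call added for the subseasonal horizon
ref_merged <- MergeDims(ref_obs, merge_dims = c('sweek', 'syear'),
                        rename_dim = 'syear', na.rm = FALSE)
dim(ref_merged)  # the extra 'sweek' samples now extend the 'syear' dimension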