From 3b32735b292a216470e04cc10ad66ecbc8fa59f0 Mon Sep 17 00:00:00 2001 From: aho Date: Fri, 29 Sep 2023 16:21:13 +0200 Subject: [PATCH 1/2] Allow obs to have no memb_dim --- R/Ano_CrossValid.R | 85 ++++++++++++++++++---------- tests/testthat/test-Ano_CrossValid.R | 43 +++++++++++++- 2 files changed, 95 insertions(+), 33 deletions(-) diff --git a/R/Ano_CrossValid.R b/R/Ano_CrossValid.R index d1996b9..7117f41 100644 --- a/R/Ano_CrossValid.R +++ b/R/Ano_CrossValid.R @@ -72,10 +72,6 @@ Ano_CrossValid <- function(exp, obs, time_dim = 'sdate', dat_dim = c('dataset', any(is.null(names(dim(obs))))| any(nchar(names(dim(obs))) == 0)) { stop("Parameter 'exp' and 'obs' must have dimension names.") } - if(!all(names(dim(exp)) %in% names(dim(obs))) | - !all(names(dim(obs)) %in% names(dim(exp)))) { - stop("Parameter 'exp' and 'obs' must have same dimension name.") - } ## time_dim if (!is.character(time_dim) | length(time_dim) > 1) { stop("Parameter 'time_dim' must be a character string.") @@ -83,13 +79,38 @@ Ano_CrossValid <- function(exp, obs, time_dim = 'sdate', dat_dim = c('dataset', if (!time_dim %in% names(dim(exp)) | !time_dim %in% names(dim(obs))) { stop("Parameter 'time_dim' is not found in 'exp' or 'obs' dimension.") } + ## memb + if (!is.logical(memb) | length(memb) > 1) { + stop("Parameter 'memb' must be one logical value.") + } + ## memb_dim + if (!memb) { + if (!is.character(memb_dim) | length(memb_dim) > 1) { + stop("Parameter 'memb_dim' must be a character string.") + } + if (!memb_dim %in% names(dim(exp)) & !memb_dim %in% names(dim(obs))) { + stop("Parameter 'memb_dim' is not found in 'exp' nor 'obs' dimension. ", + "Set it as NULL if there is no member dimension.") + } +# # Add [member = 1] +# if (memb_dim %in% names(dim(exp)) & !memb_dim %in% names(dim(obs))) { +# dim(obs) <- c(dim(obs), 1) +# names(dim(obs))[length(dim(obs))] <- memb_dim +# } +# if (!memb_dim %in% names(dim(exp)) & memb_dim %in% names(dim(obs))) { +# dim(exp) <- c(dim(exp), 1) +# names(dim(exp))[length(dim(exp))] <- memb_dim +# } + } + ## dat_dim + reset_obs_dim <- reset_exp_dim <- FALSE if (!is.null(dat_dim)) { if (!is.character(dat_dim)) { stop("Parameter 'dat_dim' must be a character vector.") } - if (!all(dat_dim %in% names(dim(exp))) | !all(dat_dim %in% names(dim(obs)))) { - stop("Parameter 'dat_dim' is not found in 'exp' or 'obs' dimension.", + if (!any(dat_dim %in% names(dim(exp))) & !any(dat_dim %in% names(dim(obs)))) { + stop("Parameter 'dat_dim' is not found in 'exp' nor 'obs' dimension.", " Set it as NULL if there is no dataset dimension.") } # If dat_dim is not in obs, add it in @@ -98,28 +119,22 @@ Ano_CrossValid <- function(exp, obs, time_dim = 'sdate', dat_dim = c('dataset', ori_obs_dim <- dim(obs) dim(obs) <- c(dim(obs), rep(1, length(dat_dim[which(!dat_dim %in% names(dim(obs)))]))) names(dim(obs)) <- c(names(ori_obs_dim), dat_dim[which(!dat_dim %in% names(dim(obs)))]) - } else { - reset_obs_dim <- FALSE } - } else { - reset_obs_dim <- FALSE - } - ## memb - if (!is.logical(memb) | length(memb) > 1) { - stop("Parameter 'memb' must be one logical value.") + # If dat_dim is not in obs, add it in + if (any(!dat_dim %in% names(dim(exp)))) { + reset_exp_dim <- TRUE + ori_exp_dim <- dim(exp) + dim(exp) <- c(dim(exp), rep(1, length(dat_dim[which(!dat_dim %in% names(dim(exp)))]))) + names(dim(exp)) <- c(names(ori_exp_dim), dat_dim[which(!dat_dim %in% names(dim(exp)))]) + } } - ## memb_dim + # memb_dim and dat_dim if (!memb) { - if (!is.character(memb_dim) | length(memb_dim) > 1) { - stop("Parameter 'memb_dim' must be a character string.") - } - if (!memb_dim %in% names(dim(exp)) | !memb_dim %in% names(dim(obs))) { - stop("Parameter 'memb_dim' is not found in 'exp' or 'obs' dimension.") - } if (!memb_dim %in% dat_dim) { stop("Parameter 'memb_dim' must be one element in parameter 'dat_dim'.") - } + } } + ## ncores if (!is.null(ncores)) { if (!is.numeric(ncores) | ncores %% 1 != 0 | ncores <= 0 | @@ -184,17 +199,25 @@ Ano_CrossValid <- function(exp, obs, time_dim = 'sdate', dat_dim = c('dataset', # Remove dat_dim in obs if obs doesn't have at first place if (reset_obs_dim) { - res_obs_dim <- ori_obs_dim[-which(names(ori_obs_dim) == time_dim)] - if (!memb & memb_dim %in% names(res_obs_dim)) { - res_obs_dim <- res_obs_dim[-which(names(res_obs_dim) == memb_dim)] - } - if (is.integer(res_obs_dim) & length(res_obs_dim) == 0) { - res$obs <- as.vector(res$obs) - } else { - res$obs <- array(res$obs, dim = res_obs_dim) - } + tmp <- match(names(dim(res$obs)), names(ori_obs_dim)) + dim(res$obs) <- ori_obs_dim[tmp[which(!is.na(tmp))]] + } + if (reset_exp_dim) { + tmp <- match(names(dim(res$exp)), names(ori_exp_dim)) + dim(res$exp) <- ori_exp_dim[tmp[which(!is.na(tmp))]] } +# res_obs_dim <- ori_obs_dim[-which(names(ori_obs_dim) == time_dim)] +# if (!memb & memb_dim %in% names(res_obs_dim)) { +# res_obs_dim <- res_obs_dim[-which(names(res_obs_dim) == memb_dim)] +# } +# if (is.integer(res_obs_dim) & length(res_obs_dim) == 0) { +# res$obs <- as.vector(res$obs) +# } else { +# res$obs <- array(res$obs, dim = res_obs_dim) +# } +# } + return(res) } diff --git a/tests/testthat/test-Ano_CrossValid.R b/tests/testthat/test-Ano_CrossValid.R index c5eea59..d450ff0 100644 --- a/tests/testthat/test-Ano_CrossValid.R +++ b/tests/testthat/test-Ano_CrossValid.R @@ -5,6 +5,10 @@ exp1 <- array(rnorm(60), dim = c(dataset = 2, member = 3, sdate = 5, ftime = 2)) set.seed(2) obs1 <- array(rnorm(20), dim = c(dataset = 1, member = 2, sdate = 5, ftime = 2)) +obs1_2 <- obs1 +dim(obs1_2) <- c(member = 2, sdate = 5, ftime = 2) +obs1_3 <- obs1[1,1,,] +exp1_2 <- exp1[,1,,] # dat2 set.seed(1) exp2 <- array(rnorm(30), dim = c(member = 3, ftime = 2, sdate = 5)) @@ -55,7 +59,7 @@ test_that("1. Input checks", { ) expect_error( Ano_CrossValid(exp1, obs1, dat_dim = 'dat'), - "Parameter 'dat_dim' is not found in 'exp' or 'obs' dimension. Set it as NULL if there is no dataset dimension." + "Parameter 'dat_dim' is not found in 'exp' nor 'obs' dimension. Set it as NULL if there is no dataset dimension." ) # memb expect_error( @@ -69,7 +73,7 @@ test_that("1. Input checks", { ) expect_error( Ano_CrossValid(exp1, obs1, memb = FALSE, memb_dim = 'memb'), - "Parameter 'memb_dim' is not found in 'exp' or 'obs' dimension." + "Parameter 'memb_dim' is not found in 'exp' nor 'obs' dimension. Set it as NULL if there is no member dimension." ) expect_error( Ano_CrossValid(exp1, obs1, memb = FALSE, memb_dim = 'ftime'), @@ -115,6 +119,41 @@ test_that("2. dat1", { tolerance = 0.0001 ) + expect_equal( + dim(Ano_CrossValid(exp1, obs1_2)$obs), + c(sdate = 5, member = 2, ftime = 2) + ) + expect_equal( + Ano_CrossValid(exp1, obs1)$exp, + Ano_CrossValid(exp1, obs1_2)$exp + ) + expect_equal( + c(Ano_CrossValid(exp1, obs1)$obs), + c(Ano_CrossValid(exp1, obs1_2)$obs) + ) + + expect_equal( + Ano_CrossValid(exp1, obs1)$exp, + Ano_CrossValid(exp1, obs1_3)$exp + ) + expect_equal( + dim(Ano_CrossValid(exp1, obs1_3)$obs), + c(sdate = 5, ftime = 2) + ) + expect_equal( + c(Ano_CrossValid(exp1, obs1_3)$obs), + c(Ano_CrossValid(exp1, obs1)$obs[, 1, 1, ]) + ) + + expect_equal( + dim(Ano_CrossValid(exp1_2, obs1)$exp), + c(sdate = 5, dataset = 2, ftime = 2) + ) + expect_equal( + c(Ano_CrossValid(exp1_2, obs1)$exp), + c(Ano_CrossValid(exp1, obs1)$exp[,,1,]) + ) + }) ############################################## -- GitLab From 8568137787c33f16dc2cd7538cd52718146e932e Mon Sep 17 00:00:00 2001 From: aho Date: Mon, 2 Oct 2023 12:39:11 +0200 Subject: [PATCH 2/2] Improve checks; Consider different dimension orders of dat and member --- R/ACC.R | 2 +- R/Ano_CrossValid.R | 10 +++++----- R/Corr.R | 2 +- R/RMS.R | 2 +- R/RatioSDRMS.R | 2 +- tests/testthat/test-Ano_CrossValid.R | 30 +++++++++++++++++++++++++++- 6 files changed, 38 insertions(+), 10 deletions(-) diff --git a/R/ACC.R b/R/ACC.R index d921ce8..71544b9 100644 --- a/R/ACC.R +++ b/R/ACC.R @@ -281,7 +281,7 @@ ACC <- function(exp, obs, dat_dim = NULL, lat_dim = 'lat', lon_dim = 'lon', name_exp <- name_exp[-which(name_exp == memb_dim)] name_obs <- name_obs[-which(name_obs == memb_dim)] } - if (!all(dim(exp)[name_exp] == dim(obs)[name_obs])) { + if (!identical(dim(exp)[name_exp], dim(obs)[name_obs])) { stop(paste0("Parameter 'exp' and 'obs' must have same length of ", "all the dimensions except 'dat_dim' and 'memb_dim'.")) } diff --git a/R/Ano_CrossValid.R b/R/Ano_CrossValid.R index 7117f41..13f7e97 100644 --- a/R/Ano_CrossValid.R +++ b/R/Ano_CrossValid.R @@ -151,7 +151,7 @@ Ano_CrossValid <- function(exp, obs, time_dim = 'sdate', dat_dim = c('dataset', name_obs <- name_obs[-which(name_obs == dat_dim[i])] } } - if(!all(dim(exp)[name_exp] == dim(obs)[name_obs])) { + if (!identical(dim(exp)[name_exp], dim(obs)[name_obs])) { stop(paste0("Parameter 'exp' and 'obs' must have the same length of ", "all dimensions except 'dat_dim'.")) } @@ -175,10 +175,10 @@ Ano_CrossValid <- function(exp, obs, time_dim = 'sdate', dat_dim = c('dataset', outrows_exp <- MeanDims(exp, pos, na.rm = FALSE) + MeanDims(obs, pos, na.rm = FALSE) outrows_obs <- outrows_exp - - for (i in 1:length(pos)) { - outrows_exp <- InsertDim(outrows_exp, pos[i], dim(exp)[pos[i]]) - outrows_obs <- InsertDim(outrows_obs, pos[i], dim(obs)[pos[i]]) +#browser() + for (i_pos in sort(pos)) { + outrows_exp <- InsertDim(outrows_exp, i_pos, dim(exp)[i_pos]) + outrows_obs <- InsertDim(outrows_obs, i_pos, dim(obs)[i_pos]) } exp_for_clim <- exp obs_for_clim <- obs diff --git a/R/Corr.R b/R/Corr.R index fe03041..c11fcf6 100644 --- a/R/Corr.R +++ b/R/Corr.R @@ -222,7 +222,7 @@ Corr <- function(exp, obs, time_dim = 'sdate', dat_dim = NULL, name_exp <- name_exp[-which(name_exp == memb_dim)] name_obs <- name_obs[-which(name_obs == memb_dim)] } - if(!all(dim(exp)[name_exp] == dim(obs)[name_obs])) { + if (!identical(dim(exp)[name_exp], dim(obs)[name_obs])) { stop(paste0("Parameter 'exp' and 'obs' must have same length of ", "all dimension except 'dat_dim' and 'memb_dim'.")) } diff --git a/R/RMS.R b/R/RMS.R index b603c37..8f7e58b 100644 --- a/R/RMS.R +++ b/R/RMS.R @@ -185,7 +185,7 @@ RMS <- function(exp, obs, time_dim = 'sdate', memb_dim = NULL, dat_dim = NULL, if (!all(name_exp == name_obs)) { stop("Parameter 'exp' and 'obs' must have the same dimension names.") } - if (!all(dim(exp)[name_exp] == dim(obs)[name_obs])) { + if (!identical(dim(exp)[name_exp], dim(obs)[name_obs])) { stop(paste0("Parameter 'exp' and 'obs' must have same length of ", "all dimensions except 'dat_dim' and 'memb_dim'.")) } diff --git a/R/RatioSDRMS.R b/R/RatioSDRMS.R index b38d5e2..6040410 100644 --- a/R/RatioSDRMS.R +++ b/R/RatioSDRMS.R @@ -111,7 +111,7 @@ RatioSDRMS <- function(exp, obs, dat_dim = NULL, memb_dim = 'member', } name_exp <- name_exp[-which(name_exp == memb_dim)] name_obs <- name_obs[-which(name_obs == memb_dim)] - if(!all(dim(exp)[name_exp] == dim(obs)[name_obs])) { + if (!identical(dim(exp)[name_exp], dim(obs)[name_obs])) { stop(paste0("Parameter 'exp' and 'obs' must have same length of ", "all the dimensions except 'dat_dim' and 'memb_dim'.")) } diff --git a/tests/testthat/test-Ano_CrossValid.R b/tests/testthat/test-Ano_CrossValid.R index d450ff0..0e2b442 100644 --- a/tests/testthat/test-Ano_CrossValid.R +++ b/tests/testthat/test-Ano_CrossValid.R @@ -1,14 +1,22 @@ ############################################## - # dat1 +# dat1 set.seed(1) exp1 <- array(rnorm(60), dim = c(dataset = 2, member = 3, sdate = 5, ftime = 2)) set.seed(2) obs1 <- array(rnorm(20), dim = c(dataset = 1, member = 2, sdate = 5, ftime = 2)) +## different member and dat dim obs1_2 <- obs1 dim(obs1_2) <- c(member = 2, sdate = 5, ftime = 2) obs1_3 <- obs1[1,1,,] +obs1_4 <- obs1[, 1, , ]; dim(obs1_4) <- c(dataset = 1, dim(obs1_4)) + exp1_2 <- exp1[,1,,] + +## not usual dimension order +exp1_5 <- aperm(exp1, 4:1) +obs1_5 <- aperm(obs1, c(3, 4, 2, 1)) + # dat2 set.seed(1) exp2 <- array(rnorm(30), dim = c(member = 3, ftime = 2, sdate = 5)) @@ -21,6 +29,8 @@ exp3 <- array(rnorm(30), dim = c(ftime = 2, sdate = 5)) set.seed(2) obs3 <- array(rnorm(20), dim = c(ftime = 2, sdate = 5)) +# dat4: not usual dimension order + ############################################## test_that("1. Input checks", { @@ -145,6 +155,19 @@ test_that("2. dat1", { c(Ano_CrossValid(exp1, obs1)$obs[, 1, 1, ]) ) + expect_equal( + dim(Ano_CrossValid(exp1, obs1_4)$obs), + c(sdate = 5, dataset = 1, ftime = 2) + ) + expect_equal( + Ano_CrossValid(exp1, obs1_4)$exp, + Ano_CrossValid(exp1, obs1)$exp + ) + expect_equal( + c(Ano_CrossValid(exp1, obs1_4)$obs), + c(Ano_CrossValid(exp1, obs1)$obs[, , 1, ]) + ) + expect_equal( dim(Ano_CrossValid(exp1_2, obs1)$exp), c(sdate = 5, dataset = 2, ftime = 2) @@ -154,6 +177,11 @@ test_that("2. dat1", { c(Ano_CrossValid(exp1, obs1)$exp[,,1,]) ) + expect_equal( + Ano_CrossValid(exp1, obs1), + Ano_CrossValid(exp1_5, obs1_5) + ) + }) ############################################## -- GitLab