# CST_Calibration.R
#' Forecast Calibration based on the ensemble inflation
#'
#' @author Verónica Torralba, \email{veronica.torralba@bsc.es}
#' @description This function applies a variance inflation technique described in Doblas-Reyes et al. (2005) in leave-one-out cross-validation. This bias adjustment method produces calibrated forecasts with mean and variance equivalent to those of the reference dataset, while at the same time preserving reliability.
#'
#' @references Doblas-Reyes F.J, Hagedorn R, Palmer T.N. The rationale behind the success of multi-model ensembles in seasonal forecasting-II calibration and combination. Tellus A. 2005;57:234-252. doi:10.1111/j.1600-0870.2005.00104.x
#'
#' @param exp an object of class \code{s2dv_cube} as returned by \code{CST_Load} function, containing the seasonal forecast experiment data in the element named \code{$data}.
#' @param obs an object of class \code{s2dv_cube} as returned by \code{CST_Load} function, containing the observed data in the element named \code{$data}.
#' @param cal.method is the calibration method used, can be either "bias", "cal" or "mbm_cal". Default value is "bias". It is passed on through \code{...}.
#' @param eval.method is the sampling method used, can be either "in-sample" or "take-one-out". Default value is "take-one-out". It is passed on through \code{...}.
#' @return an object of class \code{s2dv_cube} containing the calibrated forecasts in the element \code{$data} with the same dimensions as the experimental data.
#'
#' @import s2dverification
#' @import multiApply
#'
#' @seealso \code{\link{CST_Load}}
#'
#' @examples
#' # Example
#' # Load data using CST_Load or use the sample data provided:
#' library(zeallot)
#' c(exp, obs) %<-% areave_data
#' exp_calibrated <- CST_Calibration(exp = exp, obs = obs)
#' str(exp_calibrated)
#' @export


CST_Calibration <- function(exp, obs, ...) {
  if (!inherits(exp, "s2dv_cube") || !inherits(exp, "s2dv_cube")) {
    stop("Parameter 'exp' and 'obs' must be of the class 's2dv_cube', ",
         "as output by CSTools::CST_Load.")
  if (dim(obs$data)["member"] != 1) {
    stop("The length of the dimension 'member' in the component 'data' ",
         "of the parameter 'obs' must be equal to 1.")

  exp$data <- .calibration(exp = exp$data, obs = obs$data, ...)
  exp$Datasets <- c(exp$Datasets, obs$Datasets)
  exp$source_files <- c(exp$source_files, obs$source_files)
  return(exp)
make.eval.train.dexeses <- function(eval.method, amt.points){
  if(amt.points < 10 & eval.method != "in-sample"){
	cat("Too few points, so sample method will necessarily be in-sample")
      eval.method <- "in-sample"
  }
  if(eval.method == "take-one-out"){
    dexes.lst <- lapply(seq(1, amt.points), function(x) return(list(eval.dexes = x, train.dexes = seq(1, amt.points)[-x])))
  } else if (eval.method == "in-sample"){
    dexes.lst <- list(list(eval.dexes = seq(1, amt.points), train.dexes = seq(1, amt.points)))
  } else {
    stop(paste0("unknown sampling method: ",eval.method))
  }
  return(dexes.lst)
}

.cal <- function(obs.fc, cal.method = "bias", eval.method = "take-one-out", ...) {
  dims.tmp=dim(obs.fc)
  amt.mbr <- dims.tmp["member"][] - 1
  amt.sdate <- dims.tmp["sdate"][]
  pos <- match(c("member","sdate"), names(dims.tmp))
  obs.fc <- aperm(obs.fc, pos)
  var.obs <- asub(obs.fc, list(1),1)
  var.fc <- asub(obs.fc, list(1+seq(1, amt.mbr)),1)
  dims.fc <- dim(var.fc)
  var.cor.fc <- NA * var.fc
  
  eval.train.dexeses <- make.eval.train.dexeses(eval.method, amt.points = amt.sdate)
  amt.resamples <- length(eval.train.dexeses)
  for (i.sample in seq(1, amt.resamples)) {
    # defining training (tr) and evaluation (ev) subsets 
    eval.dexes <- eval.train.dexeses[[i.sample]]$eval.dexes
    train.dexes <- eval.train.dexeses[[i.sample]]$train.dexes
    
    fc.ev <- var.fc[ , eval.dexes, drop = FALSE]
    fc.tr <- var.fc[ , train.dexes]
    obs.tr <- var.obs[train.dexes , drop = FALSE] 
    
    #calculate ensemble and observational characteristics
    if(cal.method == "bias"){
	  var.cor.fc[ , eval.dexes] <- fc.ev + mean(obs.tr, na.rm = TRUE) - mean(fc.tr, na.rm = TRUE)
	} else if (cal.method == "cal"){
	  quant.obs.fc.tr <- calc.obs.fc.quant(obs = obs.tr, fc = fc.tr)
      #calculate value for regression parameters
      init.par <- c(calc.cal.par(quant.obs.fc.tr), 0.)
	  #correct evaluation subset
      var.cor.fc[ , eval.dexes] <- correct.cal.fc(fc.ev , init.par)
    } else if (cal.method == "mbm_cal"){
	  quant.obs.fc.tr <- calc.obs.fc.quant.ext(obs = obs.tr, fc = fc.tr)
      #calculate initial value for regression parameters
      init.par <- c(calc.cal.par(quant.obs.fc.tr), 0.001)
      init.par[3] <- sqrt(init.par[3])
      #calculate regression parameters on training dataset
      optim.tmp <- optim(par = init.par, fn = calc.crps.opt, 
        gr = calc.crps.grad.opt,
        quant.obs.fc = quant.obs.fc.tr) #,  method = "CG"
      mbm.par <- optim.tmp$par
	  #correct evaluation subset
      var.cor.fc[ , eval.dexes] <- correct.mbm.fc(fc.ev , mbm.par)
    } else {
	  stop("unknown calibration method: ",cal.method)
    }
  } 
  names(dim(var.cor.fc)) <- c("member", "sdate")
  return(var.cor.fc)
}



.calibration <- function(exp, obs, ...) {
  target.dims <- c("member", "sdate")
  if (!all(target.dims %in% names(dim(exp)))) {
    stop("Parameter 'exp' must have the dimensions 'member' and 'sdate'.")
  }

  if (!all(c("sdate") %in% names(dim(obs)))) {
    stop("Parameter 'obs' must have the dimension 'sdate'.")
  }

    warning("Parameter 'exp' contains NA values.")
    warning("Parameter 'obs' contains NA values.")
  target_dims_obs <- "sdate"
  if ("member" %in% names(dim(obs))) {
    target_dims_obs <- c("member", target_dims_obs)
  
  amt.member=dim(exp)["member"]
  amt.sdate=dim(exp)["sdate"]
  target.dims <- c("member", "sdate")
  return.feat <- list(dim = c(amt.member, amt.sdate))
  return.feat$name <- c("member", "sdate")
  return.feat$dim.name <- list(dimnames(exp)[["member"]],dimnames(exp)[["sdate"]])
  
  ptm <- proc.time()
  calibrated <- .apply.obs.fc(obs = obs,
                      fc = exp,
                      target.dims = target.dims,
                      return.feat = return.feat,
                      FUN = .cal, ...)
                      
  return(calibrated)

.apply.obs.fc <- function(obs, fc, target.dims, FUN, return.feat, ...){
  dimnames.tmp <- dimnames(fc)
  fc.dims.tmp <- dim(fc)
  dims.out.tmp <- return.feat$dim
  
  obs.fc <- .combine.obs.fc(obs, fc)
  names.dim <- names(dim(obs.fc))
  amt.dims <- length(names.dim)
  margin.all <- seq(1, amt.dims)
  matched.dims <- match(target.dims, names.dim)
  margin.to.use <- margin.all[-matched.dims]
  arr.out <- apply(X = obs.fc,
    MARGIN = margin.to.use,
    FUN = FUN,
    ...)
  dims.tmp <- dim(arr.out)
  names.dims.tmp <- names(dim(arr.out))
  if(prod(return.feat$dim) != dims.tmp[1]){
    stop("apply.obs.fc: returned dimensions not as expected: ", prod(return.feat$dim), " and ", dims.tmp[1])
  dim(arr.out) <- c(dims.out.tmp, dims.tmp[-1])
  names(dim(arr.out)) <- c(return.feat$name, names.dims.tmp[c(-1)])
  names.dim[matched.dims] <- return.feat$name
  pos <- match(names.dim, names(dim(arr.out)))
  pos_inv <- match(names(dim(arr.out)), names.dim)
  arr.out <- aperm(arr.out, pos)
  for (i.item in seq(1,length(return.feat$name))){
    dimnames.tmp[[pos_inv[i.item]]] <- return.feat$dim.name[[i.item]]
  dimnames(arr.out) <- dimnames.tmp
  return(arr.out)
}



.combine.obs.fc <- function(obs,fc){
  names.dim.tmp <- names(dim(obs))
  members.dim <- which(names.dim.tmp == "member")
  arr.out <- abind(obs, fc, along = members.dim)
  dimnames.tmp <- dimnames(arr.out)
  names(dim(arr.out)) <- names.dim.tmp
  dimnames(arr.out) <- dimnames.tmp
  names(dimnames(arr.out)) <- names.dim.tmp
  return(arr.out)
}



calc.obs.fc.quant <- function(obs, fc){
  amt.mbr <- dim(fc)[1]
  obs.per.ens <- .spr(obs, amt.mbr)
  fc.ens.av <- apply(fc, c(2), mean, na.rm = TRUE)
  cor.obs.fc <- cor(fc.ens.av, obs, use = "complete.obs")
  obs.av <- mean(obs, na.rm = TRUE)
  obs.sd <- sd(obs, na.rm = TRUE)
  return(
    append(
      calc.fc.quant(fc = fc),
      list(
        obs.per.ens = obs.per.ens,
        cor.obs.fc = cor.obs.fc,
        obs.av = obs.av,
        obs.sd = obs.sd
      )
    )
  )
}

calc.obs.fc.quant.ext <- function(obs, fc){
  amt.mbr <- dim(fc)[1]
  obs.per.ens <- .spr(obs, amt.mbr)
  fc.ens.av <- apply(fc, c(2), mean, na.rm = TRUE)
  cor.obs.fc <- cor(fc.ens.av, obs, use = "complete.obs")
  obs.av <- mean(obs, na.rm = TRUE)
  obs.sd <- sd(obs, na.rm = TRUE)
  return(
    append(
      calc.fc.quant.ext(fc = fc),
      list(
        obs.per.ens = obs.per.ens,
        cor.obs.fc = cor.obs.fc,
        obs.av = obs.av,
        obs.sd = obs.sd
      )
    )
  )
}


calc.fc.quant <- function(fc){
  amt.mbr <- dim(fc)[1]
  fc.ens.av <- apply(fc, c(2), mean, na.rm = TRUE)
  fc.ens.av.av <- mean(fc.ens.av, na.rm = TRUE)
  fc.ens.av.sd <- sd(fc.ens.av, na.rm = TRUE)
  fc.ens.av.per.ens <- .spr(fc.ens.av, amt.mbr)
  fc.ens.sd <- apply(fc, c(2), sd, na.rm = TRUE)
  fc.ens.sd.av <- sqrt(mean(fc.ens.sd^2,na.rm = TRUE))
  fc.dev <- fc - fc.ens.av.per.ens
  fc.av <- mean(fc, na.rm = TRUE)
  fc.sd <- sd(fc, na.rm = TRUE)
  return(
    list(
      fc.ens.av = fc.ens.av,
      fc.ens.av.av = fc.ens.av.av,
      fc.ens.av.sd = fc.ens.av.sd,
      fc.ens.av.per.ens = fc.ens.av.per.ens,
      fc.ens.sd = fc.ens.sd,
      fc.ens.sd.av = fc.ens.sd.av,
      fc.dev = fc.dev,
      fc.av = fc.av,
      fc.sd = fc.sd
    )
  )
}

calc.fc.quant.ext <- function(fc){
  amt.mbr <- dim(fc)[1]
  fc.ens.av <- apply(fc, c(2), mean, na.rm = TRUE)
  fc.ens.av.av <- mean(fc.ens.av, na.rm = TRUE)
  fc.ens.av.sd <- sd(fc.ens.av, na.rm = TRUE)
  fc.ens.av.per.ens <- .spr(fc.ens.av, amt.mbr)
  fc.ens.sd <- apply(fc, c(2), sd, na.rm = TRUE)
  fc.ens.sd.av <- sqrt(mean(fc.ens.sd^2, na.rm = TRUE))
  fc.dev <- fc - fc.ens.av.per.ens
  repmat1.tmp <- .spr(fc, amt.mbr)
  repmat2.tmp <- aperm(repmat1.tmp, c(2, 1, 3))
  spr.abs <- apply(abs(repmat1.tmp - repmat2.tmp), c(3), mean, na.rm = TRUE)
  spr.abs.per.ens <- .spr(spr.abs, amt.mbr)
  fc.av <- mean(fc, na.rm = TRUE)
  fc.sd <- sd(fc, na.rm = TRUE)
  return(
    list(
      fc.ens.av = fc.ens.av,
      fc.ens.av.av = fc.ens.av.av,
      fc.ens.av.sd = fc.ens.av.sd,
      fc.ens.av.per.ens = fc.ens.av.per.ens,
      fc.ens.sd = fc.ens.sd,
      fc.ens.sd.av = fc.ens.sd.av,
      fc.dev = fc.dev,
      spr.abs = spr.abs,
      spr.abs.per.ens = spr.abs.per.ens,
      fc.av = fc.av,
      fc.sd = fc.sd
    )
  )
}



calc.cal.par <- function(quant.obs.fc){
  par.out <- rep(NA, 3)
  par.out[3] <- with(quant.obs.fc, obs.sd * sqrt(1 - cor.obs.fc^2) / fc.ens.sd.av)
  par.out[2] <- with(quant.obs.fc, cor.obs.fc * obs.sd / fc.ens.av.sd)
  par.out[1] <- with(quant.obs.fc, obs.av - par.out[2] * fc.ens.av.av, na.rm = TRUE)
  return(par.out)
}



correct.mbm.fc <- function(fc, par){
  quant.fc.mp <- calc.fc.quant.ext(fc = fc)
  return(with(quant.fc.mp, par[1] + par[2] * fc.ens.av.per.ens + fc.dev * abs((par[3])^2 + par[4] / spr.abs)))
}

correct.cal.fc <- function(fc, par){
  quant.fc.mp <- calc.fc.quant(fc = fc)
  return(with(quant.fc.mp, par[1] + par[2] * fc.ens.av.per.ens + fc.dev * par[3]))
}


calc.crps <- function(obs, fc){
  quant.obs.fc <- calc.obs.fc.quant.ext(obs = obs, fc = fc)
  return(with(quant.obs.fc,
    mean(apply(abs(obs.per.ens - fc), c(2), mean, na.rm = TRUE) - spr.abs / 2., na.rm = TRUE)))
}


calc.crps.opt <- function(par, quant.obs.fc){
  return( 
    with(quant.obs.fc, 
      mean(abs(obs.per.ens - (par[1] + par[2] * fc.ens.av.per.ens +
	    ((par[3])^2 + par[4] / spr.abs.per.ens) * fc.dev)), na.rm = TRUE) -
        mean(abs((par[3])^2 * spr.abs + par[4]) / 2., na.rm = TRUE)
    )
  )
}


calc.crps.grad.opt <- function(par, quant.obs.fc){
  attach(quant.obs.fc)
  sgn1 <- sign(obs.per.ens - (par[1] + par[2] * fc.ens.av.per.ens +
	    ((par[3])^2 + par[4] / spr.abs.per.ens) * fc.dev))
  sgn2 <- sign((par[3])^2 + par[4] / spr.abs.per.ens)
  sgn3 <- sign((par[3])^2 * spr.abs + par[4])
  deriv.par1 <- mean(sgn1, na.rm = TRUE)
  deriv.par2 <- mean(sgn1 * fc.dev, na.rm = TRUE)
  deriv.par3 <- mean(2* par[3] * sgn1 * sgn2 * fc.ens.av.per.ens, na.rm = TRUE) -
    mean(spr.abs * sgn3, na.rm = TRUE) / 2.
  deriv.par4 <- mean(sgn1 * sgn2 * fc.ens.av.per.ens / spr.abs.per.ens, na.rm = TRUE) -
    mean(sgn3, na.rm = TRUE) / 2.
  return(c(deriv.par1, deriv.par2, deriv.par3, deriv.par4))
}



.spr <- function(x, amt.spr, dim = 1) {
  if(is.vector(x)){
    amt.dims <- 1
    if(dim == 2){
		arr.out <- array(rep(x, amt.spr), c(length(x), amt.spr))
    } else if(dim == 1){
		arr.out <- t(array(rep(x, amt.spr), c(length(x), amt.spr)))
	} else {
		stop(paste0("error in .spr: amt.dims = ",amt.dims," while dim = ",dim))
	}
  } else if(is.array(x)) {
    amt.dims <- length(dim(x))
    if(dim > amt.dims + 1){ 
		stop(paste0("error in .spr: amt.dims = ",amt.dims," while dim = ",dim))
	}
    arr.out <- array(rep(as.vector(x), amt.spr), c(dim(x), amt.spr))
    if(dim != amt.dims + 1){ 
      amt.dims.out <- amt.dims + 1
	  dims.tmp <- seq(1, amt.dims.out)
	  dims.tmp[seq(dim, amt.dims.out)] <- c(amt.dims.out, seq(dim,amt.dims.out-1))
	  arr.out <- aperm(arr.out, dims.tmp)
	}
  } else {
    stop("x is not array nor vector but is ", class(x))
  return(arr.out)