# plot_forecasts.R
# Last committed by Llorenç Lledó
#########################################################
# Barcelona Supercomputing Center			#
# llledo@bsc.es - September 2018			#
# Functions to plot several probabilistic forecasts	#
#########################################################
library(data.table)
library(ggplot2)
library(reshape2)
library(plyr)

#========================
# Function to plot several forecasts for an event, 
# either initialized at different moments or by different models. 
# Probabilities for extreme categories can be added as hatches.
# Ensemble members can be added as jittered points.
# The observed value can be added as a diamond over the pdf.
#========================
plot_forecasts <- function(fcst_df,tercile_limits,extreme_limits=NULL,obs=NULL,plotfile=NULL,title="Set a title",varname="Varname (units)",fcst_names=NULL,add_ensmemb=c("above","below","no"),colors=c("ggplot","dst","hydro")) {

	#------------------------
	# Arguments:
	#   fcst_df        : forecasts to plot; it is melted with reshape2::melt(), each
	#                    column becomes one panel ("init") and its values are drawn as
	#                    ensemble members (assumed: a data.frame of numeric columns)
	#   tercile_limits : two increasing values separating below/normal/above normal
	#   extreme_limits : NULL, or two values lying outside the tercile limits; the
	#                    areas beyond them are hatched and labelled
	#   obs            : NULL, or the observed value, repeated for every panel
	#   plotfile       : NULL, or a file name handed to ggsave()
	#   title, varname : plot title and label of the value axis
	#   fcst_names     : NULL, or one name per column of fcst_df (sets panel labels/order)
	#   add_ensmemb    : draw ensemble members "above" or "below" the pdf, or "no"
	#   colors         : color set to use ("ggplot", "dst" or "hydro")
	# Returns the ggplot object (after saving it when plotfile is given).
	#------------------------

	#------------------------
	# Color definitions
	#------------------------
	colors <- match.arg(colors)
	if(colors=="dst") {
		colorFill <- rev(c("#FF764D","#b5b5b5","#33BFD1"))
		colorHatch <- c("deepskyblue3","indianred3")
		colorMember <- c("#ffff7f")
		colorObs <- "purple"
		colorLab <- c("blue","red")
	} else if(colors=="hydro") {
		colorFill <- rev(c("#41CBC9","#b5b5b5","#FFAB38"))
		colorHatch <- c("darkorange1","deepskyblue3")
		colorMember <- c("#ffff7f")
		colorObs <- "purple"
		colorLab <- c("darkorange3","blue")
	} else if(colors=="ggplot") {
		# n equally spaced hues on the hcl color wheel
		gg_color_hue <- function(n) {
			hues <- seq(15, 375, length.out = n + 1)
			hcl(h = hues, l = 65, c = 100)[1:n]
		}
		colorFill <- rev(gg_color_hue(3))
		colorHatch <- c("deepskyblue3","indianred1")
		colorMember <- c("#ffff7f")
		colorObs <- "purple"
		colorLab <- c("blue","red")
	} else { stop("Unknown color set") }

	#------------------------
	# Check args
	#------------------------
	add_ensmemb <- match.arg(add_ensmemb)
	if(length(tercile_limits)!=2) stop("Provide two limits for delimiting tercile categories")
	if(tercile_limits[1]>=tercile_limits[2]) stop("The provided tercile limits are in the wrong order")
	if(!is.null(extreme_limits)) {
		if(length(extreme_limits)!=2) stop("Provide two limits for delimiting extreme categories")
		# scalar condition: use the short-circuit operator
		if(extreme_limits[1]>=tercile_limits[1] || extreme_limits[2]<=tercile_limits[2]) stop("The provided extreme limits are not consistent with tercile limits")
	}

	#------------------------
	# Set proper fcst names
	#------------------------
	if(!is.null(fcst_names)) {
		# a shorter vector would silently produce NA column names
		if(length(fcst_names)!=ncol(fcst_df)) stop("Provide one forecast name per column of fcst_df")
		colnames(fcst_df) <- factor(fcst_names,levels=fcst_names)
	}

	#------------------------
	# Produce a first plot with the pdf for each init in a panel
	#------------------------
	melt_df <- melt(fcst_df,variable.name="init")
	plot <- ggplot(melt_df,aes(x=value)) +
		geom_density(alpha=1,na.rm=TRUE) +
		coord_flip() +
		facet_wrap(~init,strip.position="top",nrow=1) +
		xlim(range(c(obs,density(melt_df$value,na.rm=TRUE)$x)))
	ggp <- ggplot_build(plot)

	#------------------------
	# Gather the coordinates of the plots
	# together with init and corresponding terciles
	#------------------------
	tmp.df <- ggp$data[[1]][,c("x","ymin","ymax","PANEL")]
	tmp.df$init <- ggp$layout$layout$init[as.numeric(tmp.df$PANEL)]
	tmp.df$tercile <- factor(ifelse(tmp.df$x<tercile_limits[1],"Below normal",ifelse(tmp.df$x<tercile_limits[2],"Normal","Above normal")),levels=c("Below normal","Normal","Above normal"))

	#------------------------
	# Get the height and width of a panel
	#------------------------
	pan_width <- diff(range(tmp.df$x))
	pan_height <- max(tmp.df$ymax)
	# scales the jitter offsets (computed in value units) to density units;
	# the factor 9 looks empirical -- NOTE(review): confirm
	magicratio <- 9*pan_height/pan_width

	#------------------------
	# Compute hatch coordinates for extremes
	#------------------------
	if(!is.null(extreme_limits)) {
		tmp.df$extremes <- factor(ifelse(tmp.df$x<extreme_limits[1],"Below P10",ifelse(tmp.df$x<extreme_limits[2],"Normal","Above P90")),levels=c("Below P10","Normal","Above P90"))
		hatch.ls <- dlply(tmp.df,.(init,extremes),function(x) {
				tmp.df2 <- data.frame(x=c(x$x,max(x$x),min(x$x)),y=c(x$ymax,0,0)) # close the polygon
				hatches <- polygon.fullhatch(tmp.df2$x,tmp.df2$y,angle=60,density=10,width_units=pan_width,height_units=pan_height) # compute the hatches for this polygon
				end1 <- data.frame(x=x$x[1],y=x$ymax[1],xend=x$x[1],yend=0) # add bottom segment
				end2 <- data.frame(x=x$x[length(x$x)],y=x$ymax[length(x$x)],xend=x$x[length(x$x)],yend=0) # add top segment
				return(rbind(hatches,end1,end2))
		      })
		# attach the (init, extremes) labels of each group to its segments
		split_labs <- attr(hatch.ls,"split_labels")
		for (i in seq_along(hatch.ls)) { hatch.ls[[i]] <- cbind(hatch.ls[[i]],split_labs[i,]) }
		hatch.df <- do.call("rbind",hatch.ls)

		# Compute max y for each extreme category
		max.ls <- dlply(tmp.df,.(init,extremes),function(x) data.frame(y=min(0.85*pan_height,max(x$ymax))))
		split_labs <- attr(max.ls,"split_labels")
		for (i in seq_along(max.ls)) { max.ls[[i]] <- cbind(max.ls[[i]],split_labs[i,]) }
		max.df <- do.call("rbind",max.ls)
	}

	#------------------------
	# Compute jitter space for ensemble members
	#------------------------
	jitter_df <- melt(data.frame(dlply(melt_df,.(init),function(x) jitter_ensmemb(sort(x$value),pan_width/100)),check.names=FALSE), value.name="yjitter",variable.name="init")
	jitter_df$x <- melt(data.frame(dlply(melt_df,.(init),function(x) sort(x$value))), value.name="x")$x

	#------------------------
	# Get y coordinates for observed x values,
	# using a cool data.table feature: merge to nearest value
	#------------------------
	if(!is.null(obs)) {
		tmp.dt <- data.table(tmp.df,key=c("init","x"))
		obs_dt <- data.table(init=factor(colnames(fcst_df),levels=colnames(fcst_df)),value=rep(obs,ncol(fcst_df)))
		setkey(obs_dt,init,value)
		obs_xy <- tmp.dt[obs_dt,roll="nearest"]
	}

	#------------------------
	# Fill each pdf with different colors for the terciles
	#------------------------
	plot <- plot +
		geom_ribbon(data=tmp.df,aes(x=x,ymin=ymin,ymax=ymax,fill=tercile),alpha=0.7)

	#------------------------
	# Add hatches for extremes
	#------------------------
	if(!is.null(extreme_limits)) {
		if(nrow(hatch.df[hatch.df$extremes!="Normal",])==0)
		{	warning("The provided extreme categories are outside the plot bounds. The extremes will not be drawn.")
			# disables the extreme probabilities and labels further down
			extreme_limits <- NULL
		} else {
			plot <- plot +
				geom_segment(data=hatch.df[hatch.df$extremes!="Normal",],aes(x=x,y=y,xend=xend,yend=yend,color=extremes))
		}
	}

	#------------------------
	# Add obs line
	#------------------------
	if(!is.null(obs)) {
		plot <- plot +
			geom_vline(data=obs_dt,aes(xintercept=value),linetype="dashed",color=colorObs)
	}

	#------------------------
	# Add ensemble members
	#------------------------
	if(add_ensmemb=="below") {
		# NOTE(review): `width` is passed to geom_rect unchanged from the original; confirm it has any effect
		plot <- plot +
			geom_rect(aes(xmin=-Inf,xmax=Inf,ymin=-Inf,ymax=-pan_height/10),fill="gray95",color="black",width=0.2) + # this adds a grey box for ensmembers
			geom_point(data=jitter_df,color="black",fill=colorMember,alpha=1,aes(x=x,y=-pan_height/10-magicratio*yjitter,shape="Ensemble members")) # this adds the ensemble members
	} else if(add_ensmemb=="above") {
		plot <- plot +
			geom_point(data=jitter_df,color="black",fill=colorMember,alpha=1,aes(x=x,y=0.7*magicratio*yjitter,shape="Ensemble members"))
	}

	#------------------------
	# Add obs diamond
	#------------------------
	if(!is.null(obs)) {
		plot <- plot +
			geom_point(data=obs_xy,aes(x=x,y=ymax,size="Observation"),shape=23,color="black",fill=colorObs) # this adds the obs diamond
	}

	#------------------------
	# Compute probability for each tercile and identify MLT
	#------------------------
	tmp.dt <- data.table(tmp.df)
	# area under the density curve within each tercile, later normalised per init
	pct <- tmp.dt[,.(pct=integrate(approxfun(x,ymax),lower=min(x),upper=max(x))$value),by=.(init,tercile)]
	tot <- pct[,.(tot=sum(pct)),by=init]
	pct <- merge(pct,tot,by="init")
	pct$pct <- round(100*pct$pct/pct$tot,0)
	# flag the tercile(s) holding the largest probability of each init
	pct$MLT <- pct[,.(MLT=pct==max(pct)),by=init]$MLT
	lab_pos <- c(tercile_limits[1],mean(tercile_limits),tercile_limits[2])

	#------------------------
	# Compute probability for extremes
	#------------------------
	if(!is.null(extreme_limits)) {
		pct2 <- tmp.dt[,.(pct=integrate(approxfun(x,ymax),lower=min(x),upper=max(x))$value),by=.(init,extremes)]
		tot2 <- pct2[,.(tot=sum(pct)),by=init]
		pct2 <- merge(pct2,tot2,by="init")
		pct2$pct <- round(100*pct2$pct/pct2$tot,0)
		lab_pos2 <- c(extreme_limits[1],NA,extreme_limits[2])
		pct2 <- merge(pct2,max.df,by=c("init","extremes"))
		# include potentially missing groups
		pct2 <- pct2[CJ(levels(pct2$init),factor(c("Below P10","Normal","Above P90"),levels=c("Below P10","Normal","Above P90"))),]
	}

	#------------------------
	# Add probability labels for terciles
	#------------------------
	if(add_ensmemb=="above") {
		labpos <- -0.2*pan_height
		vjust <- 0
	} else {
		labpos <- 0
		vjust <- -0.5
	}
	plot <- plot +
		geom_text(data=pct,aes(x=lab_pos[as.integer(tercile)],y=labpos,label=paste0(pct,"%"),hjust=as.integer(tercile)*1.5-2.5),vjust=vjust,angle=-90,size=3.2) +
		geom_text(data=pct[MLT==TRUE,],aes(x=lab_pos[as.integer(tercile)],y=labpos,label="*",hjust=as.integer(tercile)*3.5-5.0),vjust=0.1,angle=-90,size=7,color="black")

	#------------------------
	# Add probability labels for extremes
	#------------------------
	if(!is.null(extreme_limits)) {
		plot <- plot +
			geom_text(data=pct2[extremes!="Normal",],aes(x=lab_pos2[as.integer(extremes)],y=0.9*y,label=paste0(pct,"%"),hjust=as.integer(extremes)*1.5-2.5),vjust=-0.5,angle=-90,size=3.2,color=rep(colorLab,ncol(fcst_df)))
	}

	#------------------------
	# Finish all theme and legend details
	#------------------------
	plot <- plot +
		theme_minimal() +
		scale_fill_manual(name="Probability of\nterciles",breaks=c("Above normal","Normal","Below normal"),values=colorFill,drop=FALSE) +
		scale_color_manual(name="Probability of\nextremes",values=colorHatch) +
		scale_shape_manual(name="Ensemble\nmembers",values=c(21)) +
		scale_size_manual(name="Observation",values=c(3)) +
		labs(x=varname,y="Probability density\n(total area=1)",title=title) +
		theme(axis.text.x=element_blank(),
		      panel.grid.minor.x = element_blank(),
		      legend.key.size =  unit(0.3, "in"),
		      panel.border=element_rect(fill = NA, color = "black"),
		      strip.background = element_rect(colour="black", fill="gray80"),
		      panel.spacing=unit(0.2, "in"),
		      panel.grid.major.x=element_line(color="grey93")) +
		guides(fill = guide_legend(order=1), color = guide_legend(order=2,reverse=TRUE), shape = guide_legend(order=3,label=FALSE), size=guide_legend(order=4,label=FALSE))

	#------------------------
	# Save to plotfile if needed, and return plot
	#------------------------
	if(!is.null(plotfile)) {
		ggsave(plotfile,plot)
	}

	return(plot)
}

#==================
# A function to distribute ensemble members so that they do not overlap
# Requires a sorted array of values.
#==================
jitter_ensmemb <- function(x,thr=0.1) {
	# Assign a stacking level to each value so that values sharing a level
	# are more than `thr` apart, and return the vertical offset of each value.
	#
	# Arguments:
	#   x   : sorted (non-decreasing) vector of finite values
	#   thr : single finite number, minimum spacing between points of one level
	# Returns a vector as long as x holding level*thr*sqrt(3)/2.
	#
	# Idea: start with first level. Loop all points,
	# and if distance to last point in the level is more than a threshold,
	# include the point to the level.
	# Otherwise keep the point for another round.
	# Do one round in each direction to avoid ugly patterns.

	# A non-finite threshold can never be exceeded, so the loop below would not end
	if (length(thr) != 1 || !is.finite(thr)) { stop("Provide a single finite threshold!") }
	# NA/NaN/Inf values break the comparisons below with obscure errors
	if (!all(is.finite(x))) { stop("Provide finite values only!") }
	if (is.unsorted(x)) { stop("Provide a sorted array!") }

	lev <- x * 0 # 0 marks a point that has no level yet
	level <- 1
	while (any(lev == 0)) {
		# forward round: fill the current level from the smallest value up
		last <- -Inf
		for (i in seq_along(x)) {
			if (lev[i] != 0) { next }
			if (x[i] - last > thr) {
				lev[i] <- level
				last <- x[i]
			}
		}
		level <- level + 1
		# backward round: fill the next level from the largest value down
		last <- Inf
		for (i in rev(seq_along(x))) {
			if (lev[i] != 0) { next }
			if (last - x[i] > thr) {
				lev[i] <- level
				last <- x[i]
			}
		}
		level <- level + 1
	}
	return(lev * thr * sqrt(3) / 2)
}


#==================
# Compute one hatch line clipped to a polygon
# Based on base polygon function
#==================
polygon.onehatch <- function(x, y, x0, y0, xd, yd, fillOddEven=FALSE) {
    # Compute the pieces of one hatch line that fall inside a polygon.
    #
    # Arguments:
    #   x, y        : polygon vertices; consecutive pairs are treated as edges,
    #                 so the caller repeats the first vertex at the end to close it
    #   x0, y0      : a point on the hatch line
    #   xd, yd      : direction of the hatch line
    #   fillOddEven : if TRUE crossings are counted modulo 2 (odd-even rule),
    #                 otherwise any non-zero crossing count is inside
    # Returns a data.frame with columns x, y, xend, yend (one row per segment),
    # or NULL when the line does not cross any edge.

    # which side of the hatch line each vertex lies on
    halfplane <- as.integer(xd * (y - y0) - yd * (x - x0) <= 0)
    # -1/+1 where an edge goes from one side to the other, 0 otherwise
    cross <- halfplane[-1L] - halfplane[-length(halfplane)]
    does.cross <- cross != 0
    if (!any(does.cross)) {
        return(NULL)
    }
    # end points of the crossing edges
    x1 <- x[-length(x)][does.cross]
    y1 <- y[-length(y)][does.cross]
    x2 <- x[-1L][does.cross]
    y2 <- y[-1L][does.cross]
    # position of each crossing along the line: (x0, y0) + tpar * (xd, yd)
    tpar <- (((x1 - x0) * (y2 - y1) - (y1 - y0) * (x2 - x1))/(xd * (y2 - y1) - yd * (x2 - x1)))
    o <- order(tpar)
    tsort <- tpar[o]
    # running crossing count along the line; a stretch is drawn while it is non-zero
    crossings <- cumsum(cross[does.cross][o])
    if (fillOddEven) {
        crossings <- crossings%%2
    }
    drawline <- crossings != 0
    lx <- x0 + xd * tsort
    ly <- y0 + yd * tsort
    # each drawn stretch runs from one crossing to the next one
    lx1 <- lx[-length(lx)][drawline]
    ly1 <- ly[-length(ly)][drawline]
    lx2 <- lx[-1L][drawline]
    ly2 <- ly[-1L][drawline]
    return(data.frame(x=lx1,y=ly1,xend=lx2,yend=ly2))
}

#==================
# Hatch one polygon with a full set of parallel lines
# Based on base polygon function
#==================
polygon.fullhatch <- function(x, y, density, angle, width_units, height_units, inches=c(5,1)) {
    # Compute all hatch lines of a polygon as line segments.
    #
    # Arguments:
    #   x, y         : polygon vertices (the first one is appended again here
    #                  to close the polygon)
    #   density      : number of hatch lines per inch; NA, zero or negative
    #                  values produce no hatching (NULL is returned)
    #   angle        : slope of the hatch lines in degrees
    #   width_units  : extent of the plotting area along x, in data units
    #   height_units : extent of the plotting area along y, in data units
    #   inches       : size (width, height) in inches used to convert data
    #                  units; presumably the panel size -- TODO confirm
    # Returns a data.frame with columns x, y, xend, yend as built by
    # polygon.onehatch, or NULL when no hatch line crosses the polygon.

    if (length(density) != 1L) {
        stop("Provide a single hatching density")
    }
    # a negative density gives a negative step and the loops below never end;
    # zero or NA make the loop condition fail with an obscure error
    if (is.na(density) || density <= 0) {
        return(NULL)
    }

    x <- c(x, x[1L])
    y <- c(y, y[1L])
    angle <- angle%%180
    # data units per inch in each direction (taken from the arguments
    # instead of par("usr") and par("pin"))
    upi <- c(width_units, height_units)/inches
    if (upi[1L] < 0)
        angle <- 180 - angle
    if (upi[2L] < 0)
        angle <- 180 - angle
    upi <- abs(upi)
    # direction of the hatch lines in data units
    xd <- cos(angle/180 * pi) * upi[1L]
    yd <- sin(angle/180 * pi) * upi[2L]
    hatch_ls <- list()
    if (angle < 45 || angle > 135) {
        # flat lines: step the line upwards along y
        if (angle < 45) {
            first.x <- max(x)
            last.x <- min(x)
        }
        else {
            first.x <- min(x)
            last.x <- max(x)
        }
        y.shift <- upi[2L]/density/abs(cos(angle/180 * pi))
        x0 <- 0
        y0 <- floor((min(y) - first.x * yd/xd)/y.shift) * y.shift
        y.end <- max(y) - last.x * yd/xd
        while (y0 < y.end) {
            seg <- polygon.onehatch(x, y, x0, y0, xd, yd)
            # lines that miss the polygon yield NULL and are skipped
            if (!is.null(seg)) {
                hatch_ls[[length(hatch_ls) + 1L]] <- seg
            }
            y0 <- y0 + y.shift
        }
    }
    else {
        # steep lines: step the line rightwards along x
        if (angle < 90) {
            first.y <- max(y)
            last.y <- min(y)
        }
        else {
            first.y <- min(y)
            last.y <- max(y)
        }
        x.shift <- upi[1L]/density/abs(sin(angle/180 * pi))
        x0 <- floor((min(x) - first.y * xd/yd)/x.shift) * x.shift
        y0 <- 0
        x.end <- max(x) - last.y * xd/yd
        while (x0 < x.end) {
            seg <- polygon.onehatch(x, y, x0, y0, xd, yd)
            # lines that miss the polygon yield NULL and are skipped
            if (!is.null(seg)) {
                hatch_ls[[length(hatch_ls) + 1L]] <- seg
            }
            x0 <- x0 + x.shift
        }
    }
    return(do.call("rbind",hatch_ls))
}