# coding=utf-8
"""Diagnostic to compute regional averages"""
import iris.cube
import iris.util
import iris.analysis
import iris.exceptions
import numpy as np

from earthdiagnostics.box import Box
from earthdiagnostics.constants import Basins
from earthdiagnostics.diagnostic import Diagnostic, DiagnosticOption, DiagnosticIntOption, DiagnosticDomainOption, \
    DiagnosticBoolOption, DiagnosticBasinOption, DiagnosticVariableOption
from earthdiagnostics.modelingrealm import ModelingRealms
from earthdiagnostics.utils import Utils, TempFile
from bscearth.utils.log import Log


class RegionMean(Diagnostic):
    """
    Computes the mean value of the field (3D, weighted).

    For 3D fields, a horizontal mean for each level is also given. If a spatial window
    is specified, the mean value is computed only in this window.

    :original author: Javier Vegas-Regidor <javier.vegas@bsc.es>

    :created: March 2017

    :param data_manager: data management object
    :type data_manager: DataManager
    :param startdate: startdate
    :type startdate: str
    :param member: member number
    :type member: int
    :param chunk: chunk's number
    :type chunk: int
    :param domain: variable domain
    :type domain: ModelingRealm
    :param variable: variable to average
    :type variable: str
    :param box: box used to restrict the vertical mean
    :type box: Box
    :param save3d: if True, also save the mean at each level for 3D variables
    :type save3d: bool
    :param weights_file: path to the file containing the regional weights
    :type weights_file: str
    :param variance: if True, also compute the regional variance
    :type variance: bool
    :param basin: basin to compute the diagnostic on
    :type basin: Basin
    """

    alias = 'regmean'
    "Diagnostic alias for the configuration file"

    def __init__(self, data_manager, startdate, member, chunk, domain, variable, box, save3d, weights_file,
                 variance, basin):
        Diagnostic.__init__(self, data_manager)
        self.startdate = startdate
        self.member = member
        self.chunk = chunk
        self.domain = domain
        self.variable = variable
        self.box = box
        self.save3d = save3d
        self.weights_file = weights_file
        self.variance = variance
        self.basin = basin
        self.declared = {}
        self.lat_name = 'lat'
        self.lon_name = 'lon'

    def __eq__(self, other):
        if type(self) is not type(other):
            return False
        return self.startdate == other.startdate and self.member == other.member and self.chunk == other.chunk and \
            self.box == other.box and self.variable == other.variable

    def __str__(self):
        return 'Region mean Startdate: {0.startdate} Member: {0.member} Chunk: {0.chunk} Variable: {0.variable} ' \
               'Box: {0.box} Save 3D: {0.save3d} Save variance: {0.variance}'.format(self)

    def __hash__(self):
        return hash(str(self))

    @classmethod
    def generate_jobs(cls, diags, options):
        """
        Create a job for each chunk to compute the diagnostic

        :param diags: Diagnostics manager class
        :type diags: Diags
        :param options: domain, variable, grid point, basin, minimum depth (level), maximum depth (level),
                        save 3D, variance, grid
        :type options: list[str]
        :return: list of jobs to run, one for each chunk
        :rtype: list[RegionMean]
        """
        options_available = (DiagnosticDomainOption(),
                             DiagnosticVariableOption(diags.data_manager.config.var_manager),
                             DiagnosticOption('grid_point', 'T'),
                             DiagnosticBasinOption('basin', Basins().Global),
                             DiagnosticIntOption('min_depth', 0),
                             DiagnosticIntOption('max_depth', 0),
                             DiagnosticBoolOption('save3D', True),
                             DiagnosticBoolOption('variance', False),
                             DiagnosticOption('grid', ''))
        options = cls.process_options(options, options_available)

        box = Box()
        box.min_depth = options['min_depth']
        box.max_depth = options['max_depth']
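        # The weights depend only on the grid point and basin, so a single
        # ComputeWeights subjob is shared by every chunk job created below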
        weights_file = TempFile.get()
        weight_diagnostics = ComputeWeights(diags.data_manager, options['grid_point'], options['basin'], weights_file)

        job_list = list()
        for startdate, member, chunk in diags.config.experiment.get_chunk_list():
            job = RegionMean(diags.data_manager, startdate, member, chunk,
                             options['domain'], options['variable'], box,
                             options['save3D'], weights_file, options['variance'], options['basin'])
            job.add_subjob(weight_diagnostics)
            job_list.append(job)

        return job_list

    def request_data(self):
        """Request data required by the diagnostic"""
        self.variable_file = self.request_chunk(self.domain, self.variable, self.startdate, self.member, self.chunk)

    def declare_data_generated(self):
        """Declare data to be generated by the diagnostic"""
        if self.box.min_depth == 0:
            # Following the cdftools convention, a minimum depth of 0 means that all levels are used
            box_save = None
        else:
            box_save = self.box

        self._declare_var('mean', False, box_save)
        self._declare_var('mean', True, box_save)

        if self.variance:
            self._declare_var('var', False, box_save)
            self._declare_var('var', True, box_save)

    def compute(self):
        """Run the diagnostic"""
        handler = Utils.open_cdf(self.variable_file.local_file)
        var = handler.variables[self.variable]
        coordinates = ''
        has_levels = False
        for dimension in handler.variables.keys():
            if dimension in ['time', 'lev', 'lat', 'latitude', 'lon', 'longitude']:
                coordinates += ' {0}'.format(dimension)
            if dimension == 'lev':
                has_levels = True
        var.coordinates = coordinates
        handler.close()

        data = iris.load_cube(self.variable_file.local_file, iris.AttributeConstraint(short_name=self.variable))
        Log.info(str(data))
        weights = iris.load_cube(self.weights_file, 'weights').data
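        # Collapse each time step separately: the regional mean uses the precomputed
        # weights over the full 3D region and, for variables with levels, also level by level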
        if has_levels:
            mean = iris.cube.CubeList()
            mean3d = iris.cube.CubeList()
            var = iris.cube.CubeList()
            var3d = iris.cube.CubeList()
            for time_slice in data.slices_over('time'):
                mean.append(time_slice.collapsed(['latitude', 'longitude', 'depth'],
                                                 iris.analysis.MEAN, weights=weights))
                mean3d.append(time_slice.collapsed(['latitude', 'longitude'],
                                                   iris.analysis.MEAN, weights=weights))
                if self.variance:
                    var.append(time_slice.collapsed(['latitude', 'longitude', 'depth'],
                                                    iris.analysis.VARIANCE, weights=weights))
                    var3d.append(time_slice.collapsed(['latitude', 'longitude'],
                                                      iris.analysis.VARIANCE, weights=weights))
            self._send_var('mean', True, mean3d.merge_cube())
            self._send_var('mean', False, mean.merge_cube())
            if self.variance:
                self._send_var('var', True, var3d.merge_cube())
                self._send_var('var', False, var.merge_cube())
        else:
            mean = iris.cube.CubeList()
            var = iris.cube.CubeList()
            for time_slice in data.slices_over('time'):
                mean.append(time_slice.collapsed(['latitude', 'longitude'], iris.analysis.MEAN, weights=weights))
                var.append(time_slice.collapsed(['latitude', 'longitude'], iris.analysis.VARIANCE, weights=weights))

            self._send_var('mean', False, mean.merge_cube())
            if self.variance:
                self._send_var('var', False, var.merge_cube())

    def _declare_var(self, var, threed, box_save):
        if threed:
            if not self.save3d:
                return False
            final_name = '{1}3d{0}iris'.format(var, self.variable)
        else:
            final_name = '{1}{0}iris'.format(var, self.variable)

        self.declared[final_name] = self.declare_chunk(ModelingRealms.ocean, final_name, self.startdate, self.member,
                                                       self.chunk, box=box_save, region=self.basin)

    def _send_var(self, var, threed, cube):
        if threed:
            final_name = '{1}3d{0}iris'.format(var, self.variable)
        else:
            final_name = '{1}{0}iris'.format(var, self.variable)
        cube.var_name = 'result'
        temp = TempFile.get()
        iris.save(cube, temp)
        self.declared[final_name].set_local_file(temp, diagnostic=self, rename_var='result', region=self.basin)


class ComputeWeights(Diagnostic):
    """
    Diagnostic used to compute regional mean and sum weights

    Parameters
    ----------
    data_manager: DataManager
    grid_point: str
    basin: Basin
    weights_file: str
    """

    alias = 'computeinterpcdoweights'
    "Diagnostic alias for the configuration file"

    @classmethod
    def generate_jobs(cls, diags, options):
        """Do nothing, as this diagnostic is only created as a subjob of other diagnostics"""
        pass

    def __init__(self, data_manager, grid_point, basin, weights_file):
        Diagnostic.__init__(self, data_manager)
        self.weights_file = weights_file
        self.basin = basin
        self.grid_point = grid_point.lower()

    def __eq__(self, other):
        if type(self) is not type(other):
            return False
        return self.weights_file == other.weights_file and self.basin == other.basin and \
               self.grid_point == other.grid_point

    def __str__(self):
        return 'Computing weights for region averaging: Point {0.grid_point} Basin: {0.basin}'.format(self)

    def __hash__(self):
        return hash(str(self))

    def compute(self):
        """Compute weights"""
        iris.FUTURE.netcdf_promote = True
        iris.FUTURE.netcdf_no_unlimited = True


        mask = np.squeeze(Utils.get_mask(self.basin))
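        # The weights are the basin-masked cell volumes: e1 * e2 gives the horizontal
        # cell area, multiplying by the mask restricts it to the requested basin and
        # e3 adds the vertical scale factor (cell thickness) for each level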
        e1 = self.try_load_cube(1)
        e2 = self.try_load_cube(2)
        e3 = self.try_load_cube(3)
        mask = e1 * e2 * mask
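        # Strip all coordinate metadata so the cubes can be multiplied even if
        # their coordinates do not match exactly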
        iris.util.demote_dim_coord_to_aux_coord(e3, 'lev')
        for coord in e3.coords():
            e3.remove_coord(coord)
        for coord in mask.coords():
            mask.remove_coord(coord)
        weights = e3 * mask
        weights.var_name = 'weights'
        iris.save(weights, self.weights_file)

    def try_load_cube(self, number):
        """Load the e<number> scale factor for the grid point, falling back to the '_0' suffixed variable name"""
        try:
            cube = iris.load_cube('mesh_hgr.nc', 'e{0}{1}'.format(number, self.grid_point))
        except iris.exceptions.ConstraintMismatchError:
            cube = iris.load_cube('mesh_hgr.nc', 'e{0}{1}_0'.format(number, self.grid_point))
        return iris.util.squeeze(cube)

    def request_data(self):
        """Request data required by the diagnostic"""

    def declare_data_generated(self):
        """Declare data to be generated by the diagnostic"""
        pass