# coding=utf-8
from bscearth.utils.date import parse_date, add_months
from bscearth.utils.log import Log

from earthdiagnostics.diagnostic import Diagnostic, DiagnosticVariableOption, DiagnosticDomainOption, \
    DiagnosticIntOption, DiagnosticFloatOption
from earthdiagnostics.utils import Utils, TempFile
from earthdiagnostics.variable_type import VariableType
import numpy as np
import iris
from iris.cube import Cube
import iris.coord_categorisation
from iris.time import PartialDateTime
import iris.exceptions
import iris.coords
import math
import psutil
import six

class Discretize(Diagnostic):
    """
    Discretizes a variable into histograms of its values for each leadtime

    For one startdate, the values of the variable are binned into ``num_bins`` bins between
    ``min_value`` and ``max_value`` (histogram computed along the first data axis at every
    remaining point), and the counts are summed over all realizations. The result is declared as
    a new variable named ``<variable>_dis``.

    :param data_manager: data management object
    :type data_manager: DataManager
    :param startdate: startdate to process
    :type startdate: str
    :param domain: domain of the variable
    :param variable: variable to discretize
    :type variable: str
    :param num_bins: number of bins of the histogram
    :type num_bins: int
    :param min_value: lower bound of the bins range. NaN means: use the variable's valid_min, or
                      compute it from the data if that is not available
    :type min_value: float
    :param max_value: upper bound of the bins range. NaN means: use the variable's valid_max, or
                      compute it from the data if that is not available
    :type max_value: float
    """

    alias = 'discretize'
    "Diagnostic alias for the configuration file"

    Percentiles = np.array([0.1, 0.25, 0.33, 0.5, 0.66, 0.75, 0.9])

    def __init__(self, data_manager, startdate, domain, variable, num_bins, min_value, max_value):
        Diagnostic.__init__(self, data_manager)

        self.startdate = startdate
        self.variable = variable
        self.domain = domain

        self.realizations = None
        self.num_bins = num_bins
        self._bins = None
        self.cmor_var = data_manager.variable_list.get_variable(variable, silent=True)

        # Priority for each bound: explicit option > variable definition > computed from the data
        # (the latter is flagged with check_*_value and resolved in _get_value_interval)
        if not math.isnan(min_value):
            self.min_value = min_value
            self.check_min_value = False
        elif self.cmor_var and self.cmor_var.valid_min:
            self.min_value = float(self.cmor_var.valid_min)
            self.check_min_value = False
        else:
            self.min_value = None
            self.check_min_value = True

        if not math.isnan(max_value):
            self.max_value = max_value
            self.check_max_value = False
        elif self.cmor_var and self.cmor_var.valid_max:
            # Test valid_max itself: it is the attribute converted on the next line
            self.max_value = float(self.cmor_var.valid_max)
            self.check_max_value = False
        else:
            self.max_value = None
            self.check_max_value = True

        self.process = psutil.Process()

    def print_memory_used(self):
        """Log the resident memory of the current process, in GB"""
        Log.debug('Memory: {0:.2f} GB'.format(self.process.memory_info().rss / 1024.0**3))

    @property
    def bins(self):
        """
        Bins specification passed to numpy.histogram

        Before the first histogram is computed this is the number of bins; afterwards it holds the
        bin edges returned by numpy, so every later histogram reuses exactly the same edges.
        """
        if self._bins is None:
            return self.num_bins
        return self._bins

    @bins.setter
    def bins(self, value):
        self._bins = value

    def __eq__(self, other):
        # Comparing with other diagnostic types must not blow up on missing attributes
        if not isinstance(other, Discretize):
            return NotImplemented
        return self.domain == other.domain and self.variable == other.variable and self.num_bins == other.num_bins and \
            self.min_value == other.min_value and self.max_value == other.max_value and \
            self.startdate == other.startdate

    def __str__(self):
        return 'Discretizing variable: {0.domain}:{0.variable} Startdate: {0.startdate} ' \
               'Bins: {0.num_bins} Range: [{0.min_value}, {0.max_value}]'.format(self)

    @classmethod
    def generate_jobs(cls, diags, options):
        """
        Creates a job for each startdate to compute the diagnostic

        :param diags: Diagnostics manager class
        :type diags: Diags
        :param options: domain, variable, number of bins (default 2000), minimum value and maximum
                        value (both default to NaN, meaning "not set")
        :type options: list[str]
        :return: one Discretize job per startdate
        :rtype: list[Discretize]
        """
        options_available = (DiagnosticDomainOption(),
                             DiagnosticVariableOption(),
                             DiagnosticIntOption('bins', 2000),
                             DiagnosticFloatOption('min_value', float('nan')),
                             DiagnosticFloatOption('max_value', float('nan')),
                             )
        options = cls.process_options(options, options_available)

        return [Discretize(diags.data_manager, startdate, options['domain'], options['variable'],
                           options['bins'], options['min_value'], options['max_value'])
                for startdate in diags.config.experiment.startdates]

    def request_data(self):
        """Request the whole-startdate data of the variable"""
        self.original_data = self.request_chunk(self.domain, self.variable, self.startdate, None, None)

    def declare_data_generated(self):
        """Declare the ``<variable>_dis`` output as statistic data for the startdate"""
        var_name = '{0.variable}_dis'.format(self)
        self.discretized_data = self.declare_chunk(self.domain, var_name, self.startdate, None, None,
                                                   vartype=VariableType.STATISTIC)

    def compute(self):
        """
        Runs the diagnostic
        """
        self.print_memory_used()
        iris.FUTURE.netcdf_promote = True
        self._load_cube()
        self.print_memory_used()
        self._get_value_interval()
        self.print_memory_used()
        Log.info('Range: [{0}, {1}]', self.min_value, self.max_value)
        self._get_distribution()
        self.print_memory_used()
        self._save_results()
        self.print_memory_used()
        # Release the large arrays as soon as the results are written
        del self.distribution
        del self.data_cube
        self.print_memory_used()

    def _load_cube(self):
        """Load the input file as an iris cube and add a 'leadtime' (in months) coordinate"""
        handler = Utils.openCdf(self.original_data.local_file)
        try:
            # Link the realization variable to the data variable so iris loads it as a coordinate
            if 'realization' in handler.variables:
                handler.variables[self.variable].coordinates = 'realization'
        finally:
            handler.close()
        data_cube = iris.load_cube(self.original_data.local_file)

        calendar = self.data_manager.config.experiment.calendar
        date = parse_date(self.startdate)
        lead_date = add_months(date, 1, calendar)
        # Cache of leadtime month -> first date that belongs to the NEXT leadtime
        leadtimes = {1: PartialDateTime(lead_date.year, lead_date.month, lead_date.day)}

        def assign_leadtime(coord, x):
            # Leadtime n covers [startdate + (n-1) months, startdate + n months)
            leadtime_month = 1
            partial_date = leadtimes[leadtime_month]
            while coord.units.num2date(x) >= partial_date:
                leadtime_month += 1
                try:
                    partial_date = leadtimes[leadtime_month]
                except KeyError:
                    new_date = add_months(date, leadtime_month, calendar)
                    partial_date = PartialDateTime(new_date.year, new_date.month, new_date.day)
                    leadtimes[leadtime_month] = partial_date
            return leadtime_month

        iris.coord_categorisation.add_categorised_coord(data_cube, 'leadtime', 'time', assign_leadtime)
        self.data_cube = data_cube

    def _save_results(self):
        """Build one cube per leadtime with the bin counts, merge them and store the result"""
        Log.debug('Saving results...')

        # Bin centres and bounds from the histogram edges (num_bins + 1 edges -> num_bins bins)
        edges = np.asarray(self.bins, dtype=float)
        lower = edges[:-1]
        upper = edges[1:]
        bins = (upper - lower) / 2 + lower
        bins_bounds = np.column_stack((lower, upper))

        bins_coord = iris.coords.DimCoord(bins, var_name='bin', units=self.data_cube.units, bounds=bins_bounds)

        cubes = iris.cube.CubeList()

        for leadtime, distribution in six.iteritems(self.distribution):
            leadtime_cube = Cube(distribution.astype(np.uint32), var_name=self.data_cube.var_name,
                                 standard_name=self.data_cube.standard_name, units='1')
            leadtime_cube.add_dim_coord(bins_coord, 0)
            leadtime_cube.add_dim_coord(self.data_cube.coord('latitude'), 1)
            leadtime_cube.add_dim_coord(self.data_cube.coord('longitude'), 2)
            # int16: leadtimes are month counts and easily exceed the int8 limit (127) in long runs
            leadtime_cube.add_aux_coord(iris.coords.AuxCoord(np.array((leadtime,), np.int16), var_name='leadtime',
                                                             units='months'))
            cubes.append(leadtime_cube)
        temp = TempFile.get()
        iris.FUTURE.netcdf_no_unlimited = True
        iris.save(cubes.merge_cube(), temp, zlib=True)
        self.discretized_data.set_local_file(temp, rename_var=self.data_cube.var_name)

    def _get_distribution(self):
        """Compute, for each leadtime, the histogram counts summed over all realizations"""
        self.distribution = {}
        Log.debug('Discretizing...')

        try:
            realizations = self.data_cube.coord('realization').points
        except iris.exceptions.CoordinateNotFoundError:
            # No ensemble information: each leadtime slice is treated as a single member
            realizations = None

        for leadtime in sorted(set(self.data_cube.coord('leadtime').points)):
            Log.debug('Discretizing leadtime {0}', leadtime)
            leadtime_cube = self.data_cube.extract(iris.Constraint(leadtime=leadtime))
            if realizations is None:
                self._add_to_distribution(leadtime, leadtime_cube)
                continue
            for realization in realizations:
                Log.debug('Discretizing realization {0}', realization)
                self.print_memory_used()
                realization_cube = leadtime_cube.extract(iris.Constraint(realization=realization))
                if realization_cube is None and realization == 0:
                    realization_cube = leadtime_cube
                self._add_to_distribution(leadtime, realization_cube)

    def _add_to_distribution(self, leadtime, data_cube):
        """Add the histogram counts of ``data_cube`` to the running total stored for ``leadtime``"""
        histogram = self._calculate_distribution(data_cube)
        if leadtime in self.distribution:
            self.distribution[leadtime] += histogram
        else:
            self.distribution[leadtime] = histogram

    # noinspection PyTypeChecker
    def _get_value_interval(self):
        """Fill in the bounds that were not given nor defined for the variable with the data extrema"""
        if self.check_min_value or self.check_max_value:
            Log.debug('Calculating max and min values...')
            for time_slice in self.data_cube.slices_over('time'):
                if self.check_min_value:
                    file_min = np.amin(time_slice.data)
                    self.min_value = file_min if self.min_value is None else min(self.min_value, file_min)

                if self.check_max_value:
                    file_max = np.amax(time_slice.data)
                    # Guard against None: max(None, x) raises TypeError on Python 3
                    self.max_value = file_max if self.max_value is None else max(self.max_value, file_max)

    def _calculate_distribution(self, data_cube):
        """
        Histogram of the cube data along its first axis, computed independently at every other index

        :return: counts with shape (num_bins,) + data.shape[1:]
        :rtype: numpy.ndarray
        """
        def calculate_histogram(time_series):
            # The first call computes the edges; storing them in self.bins makes all later calls
            # reuse exactly the same edges so the counts can be summed
            histogram, self.bins = np.histogram(time_series, bins=self.bins,
                                                range=(self.min_value, self.max_value))
            return histogram

        return np.apply_along_axis(calculate_histogram, 0, data_cube.data)