# DataCouplerFileVar.py
# Copyright 2022 - Barcelona Supercomputing Center
# Authors: Etienne Tourigny, Rodrigo Martín Posada
# MIT License
from netCDF4 import Dataset  # pylint: disable=no-name-in-module
import logging
import numpy as np
from datetime import date, timedelta
import pyoasis

from DataCouplerUtils import days_since_refyear, filename_from_pattern, scale_factor_from_str

    def __init__(self, conf, fix_year, start_year):
        '''
        Build a coupling variable whose values are read from netCDF files.

        :param conf: configuration mapping; the keys read here are 'id',
            'oasis_name', 'grid_name', 'file_pattern', 'netcdf_variable',
            'yref_min', 'timestep', 'interpolate', 'scale_factor', 'offset',
            'min', 'max', 'update' and 'accum'.
        :param fix_year: when > 0, the grid is read from this year's file and
            `__read_file` substitutes first/last records for timesteps that
            fall outside the current year.
        :param start_year: year whose file provides the grid when `fix_year`
            is not > 0.
        '''
        self.id = conf['id']
        self.oasis_name = conf['oasis_name']
        self.grid_name = conf['grid_name']
        self.file_pattern = conf['file_pattern']
        self.var_name = conf['netcdf_variable']
        self.ref_year = conf['yref_min']
        self.timestep = conf['timestep']
        self.interpolate = conf['interpolate']
        scale_factor = conf['scale_factor']
        # String scale factors are converted by scale_factor_from_str; other values are used as-is.
        self.scale_factor = scale_factor_from_str(scale_factor) if isinstance(scale_factor, str) else scale_factor
        self.offset = conf['offset']
        # A missing bound disables clipping on that side. -np.inf / np.inf are
        # used because the np.NINF / np.Inf aliases were removed in NumPy 2.0.
        self.min = conf['min'] if conf['min'] is not None else -np.inf
        self.max = conf['max'] if conf['max'] is not None else np.inf
        self.update_timestep = conf['update']
        self.accum = conf['accum']

        self.fix_year = fix_year
        # Initialization: sets grid_nx / grid_ny (and the OASIS grid) from the first input file
        self.__get_grids(start_year)
        # field[0] holds the data at t1 (step before), field[1] the data at t2 (step after)
        self.field = pyoasis.asarray(np.empty((2,self.grid_nx,self.grid_ny)))
        # Field that will be sent (clipped result computed by update())
        self._send_field = pyoasis.asarray(np.empty((self.grid_nx,self.grid_ny)))
        # Current netCDF dataset
        self.dataset = None
        # Name of the current dataset file
        self.current_file = ''
        self.time_array = None
        # Index in time_array
        self.t_index = 0
        # Dummy values
        # Timesteps [days since refdate]
        # Step before with data
        self.t1 = -9999.
        # Step after with data
        self.t2 = -9999.
        # Step at Jan 1st
        self.t0101 = -9999.
        # Step at Dec 31st
        self.t1231 = -9999.
        self.tnewyear = -9999.

    def setup(self, yy, t_global, tnewyear_global):
        '''
        Position the variable at time `t_global` of year `yy` and read its first record.

        Computes `t0101`, `t1231` and `tnewyear` as days since Jan 1st of
        `ref_year` (via days_since_refyear). Opens the file matching the
        current date and selects the last record at or before the current
        time, or the first record if the current time precedes the file.
        That record is read into field[1] with t1 == t2.

        :param yy: year being set up.
        :param t_global: current time as an offset (days) from Jan 1st of `yy`.
        :param tnewyear_global: offset (days) from Jan 1st of `yy` at which the
            next year starts; used by `__read_file` when `fix_year` > 0.
        '''
        # Reset the bracketing timesteps to dummy values
        self.t1 = -9999.
        self.t2 = -9999.
        self.t_index = 0
        self.t0101 = float(days_since_refyear(yy, 1, 1, self.ref_year))
        self.t1231 = float(days_since_refyear(yy, 12, 31, self.ref_year))
        # Current timestep
        t_local = self.t0101 + t_global
        self.tnewyear = self.t0101 + tnewyear_global

        logging.debug('{}: INIT - t       : {}'.format(self.id, t_local))
        logging.debug('{}: INIT - t0101   : {}'.format(self.id, self.t0101))
        logging.debug('{}: INIT - t1231   : {}'.format(self.id, self.t1231))
        logging.debug('{}: INIT - tnewyear: {}'.format(self.id, self.tnewyear))

        curr_date = date(self.ref_year, 1, 1) + timedelta(days=t_local)
        file_name = filename_from_pattern(self.file_pattern, curr_date)
        # A previous setup() may have left a dataset open; close it before
        # opening the new one so the netCDF handle is not leaked.
        if self.dataset is not None and self.dataset.isopen():
            self.__close_current_file()
        self.__open_file(file_name)
        # Last record at or before t_local (stay at 0 if t_local precedes the file)
        if t_local >= self.time_array[0]:
            self.t_index = np.searchsorted(self.time_array, t_local, side='right')-1
        self.t2 = self.time_array[self.t_index]
        self.t1 = self.t2

        logging.debug('{}: INIT - t_index: {}'.format(self.id, self.t_index))

        self.__read_file()

    def update(self, t_global):
        '''
        Advance to time `t_global` and compute the field to send.

        Reads further records until the current time is <= t2. The raw field
        depends on `timestep` and `interpolate`:
          - monthly with interpolation: linear interpolation between
            field[0] (at t1) and field[1] (at t2), or field[1] if t1 >= t2;
          - monthly without interpolation: nearest of field[0]/field[1],
            comparing whole days;
          - any other timestep: field[1] unchanged.
        The result is clipped to [min, max] and stored in `_send_field`.
        For 'AMIP_sst', the lower bound is dropped when no interpolation
        is performed.

        :param t_global: current time as an offset (days) from Jan 1st of the
            year passed to setup().
        '''
        t_local = self.t0101 + t_global
        logging.debug('update var {} at time {}'.format(self.id, t_local))

        while t_local > self.t2:
            self.__read_next_timestep_from_files(t_local)

        var_max = self.max
        var_min = self.min
        # Monthly interpolation
        if self.timestep == 'monthly':
            if self.interpolate:
                if self.t1 < self.t2:
                    logging.debug('{}: time interpolation t1,t,t2 {} {} {}'.format(self.id, self.t1, t_local, self.t2))
                    raw_field = ((t_local-self.t1)*self.field[1] + (self.t2-t_local)*self.field[0])/(self.t2-self.t1)
                else:
                    logging.debug('{}: no time interpolation t,t2 {} {}'.format(self.id, t_local, self.t2))
                    # Hardcoded to match original Fortran amip-forcing clipping avoidance
                    if self.oasis_name == 'AMIP_sst':
                        var_min = -np.inf
                    raw_field = self.field[1]
            else:
                # Nearest neighbour: use field[0] (t1) only when the current day is
                # strictly closer to t1 than to t2; ties and t1 >= t2 use field[1] (t2).
                # NOTE: using '<=' instead of '<' was tried as a bugfix for rounding on the
                # last day of the month, but it is only a partial fix that does not work for all months.
                if self.t1 < self.t2 and np.floor(t_local)-np.floor(self.t1) < np.floor(self.t2)-np.floor(t_local):
                    nn = 0
                else:
                    nn = 1
                logging.debug('{}: nearest neigbour nn={} t1={} t={} t2={}'.format(
                    self.id, nn, self.t1, t_local, self.t2))
                # Hardcoded to match Fortran amip-forcing clipping avoidance
                if self.oasis_name == 'AMIP_sst':
                    var_min = -np.inf
                raw_field = self.field[nn]
        # Daily (no interpolation)
        else:
            raw_field = self.field[1]
        # Clipping required for CMIP6 BCs
        self._send_field = np.maximum(var_min, np.minimum(var_max, raw_field))

    def send_field(self):
        '''
        Return the field most recently prepared for sending.

        This is the buffer allocated in __init__ until update() replaces it
        with the clipped field.
        '''
        outgoing = self._send_field
        return outgoing

    def finalise(self):
        '''Close the netCDF dataset currently held by this variable.'''
        close_dataset = self.__close_current_file
        close_dataset()

    def __get_grids(self, start_year):
        '''
        Read the grid definition from the first input netCDF file.

        The file for Jan 1st of `fix_year` (when > 0) or else `start_year` is
        opened. A 2-D grid is built from the 'lat'/'latitude' and
        'lon'/'longitude' variables. If either is missing, a file with a
        'sector' variable of length 3 is treated as a single-point global box
        model.

        Sets `self.grid_nx`, `self.grid_ny`, `self.dx`, `self.dy`,
        `self.grid_lon` and `self.grid_lat` (both of shape (grid_nx, grid_ny)).

        :raises Exception: if neither a lat/lon pair nor a 'sector' variable
            of length 3 is present in the file.
        '''
        yy = self.fix_year if self.fix_year > 0 else start_year
        file_name = filename_from_pattern(self.file_pattern, date(yy, 1, 1))
        # The context manager closes the file even when an invalid grid raises below.
        with Dataset(file_name, 'r') as ds:
            variable_names = ds.variables.keys()

            # Look for lat and lon variables to create a 2d grid.
            # If not found, look for sector to create a 1d grid (only 3 sectors for co2 supported for now).
            if 'lat' in variable_names:
                lat = ds['lat'][:]
            elif 'latitude' in variable_names:
                lat = ds['latitude'][:]
            else:
                lat = None
            if 'lon' in variable_names:
                lon = ds['lon'][:]
            elif 'longitude' in variable_names:
                lon = ds['longitude'][:]
            else:
                lon = None

            if lat is None or lon is None:
                if 'sector' not in variable_names:
                    raise Exception(
                        'FileInputVar: {}: Could not find lon or longitude variable name. Invalid input netCDF file {}.'.format(self.id, file_name))
                num_sectors = len(ds['sector'][:])
                if num_sectors != 3:
                    raise Exception(
                        'FileInputVar: {}: Could not find sector variable of length 3. Invalid input netCDF file {}.'.format(self.id, file_name))
                # Global box model: a single grid point.
                # (A 2-hemisphere box model with ny=2, lat=[-45, 45] was used previously and is no longer supported.)
                self.grid_ny = 1
                self.grid_nx = 1
                lat = 0
                lon = 0
                self.dy = 0
                self.dx = 0
            else:
                self.grid_ny = lat.size
                self.grid_nx = lon.size
                self.dy = lat[1]-lat[0] if self.grid_ny > 1 else 0
                self.dx = lon[1]-lon[0] if self.grid_nx > 1 else 0

        # Create 2-D grid for OASIS (lat/lon values were already read into memory above)
        self.grid_lon = np.tile(lon, (self.grid_ny, 1)).T
        self.grid_lat = np.tile(lat, (self.grid_nx, 1))

    def __read_next_timestep_from_files(self, t_local):
        '''
        Shift the t2 data into the t1 slot and load the next record as t2.

        If `t_local` lies beyond the last time in the open file, that file is
        closed and the file for `t_local`'s date is opened. Reading then
        restarts at its first record.

        :raises Exception: if `t_local`'s date maps to the file that was just
            exhausted.
        '''
        # The previous "after" step becomes the new "before" step
        self.t1, self.field[0] = self.t2, self.field[1]

        past_end_of_file = t_local > np.max(self.time_array)
        if past_end_of_file:
            exhausted_file = self.current_file
            self.__close_current_file()
            next_date = date(self.ref_year, 1, 1) + timedelta(days=t_local)
            next_file = filename_from_pattern(self.file_pattern, next_date)
            if next_file == exhausted_file:
                raise Exception('FileInputVar: {}: Could not find the next file for date {}.'.format(self.id, next_date))
            self.__open_file(next_file)
            # Begin at the first record of the freshly opened file
            self.t_index = 0
        else:
            self.t_index += 1

        self.t2 = self.time_array[self.t_index]
        self.__read_file()

    def __open_file(self, file_name):
        '''Open `file_name` read-only as the current dataset and cache its 'time' values.'''
        self.current_file = file_name
        logging.debug('{}: open file  {}'.format(self.id, file_name))
        opened = Dataset(file_name, 'r')
        self.dataset = opened
        self.time_array = opened['time'][:]

    def __read_file(self):
        '''
        Load one record of `var_name` from the open dataset into `self.field[1]`.

        Normally the record at `self.t_index` is used. When `fix_year` > 0:
          - a timestep at or after `tnewyear` uses the record that follows the
            last one at or before `t0101`;
          - a timestep before `t0101` uses the last record at or before `t1231`.
        This presumably keeps cycling the fixed year's boundary conditions.

        CO2-specific reshaping is applied:
          - emissions are summed over axis 0 unless the grid is "global";
          - the CO2 mole fraction keeps only element 0.
        The stored value is transpose(raw) * scale_factor + offset as float64.
        '''
        if self.fix_year > 0 and self.time_array[self.t_index] >= self.tnewyear:
            logging.debug('{}:...use first BC data for timestep {}'.format(self.id, self.time_array[self.t_index]))
            # Record right after the last one at or before t0101
            first_t_index = np.searchsorted(self.time_array, self.t0101, side='right')-1
            raw_field = self.dataset[self.var_name][first_t_index+1]
        elif self.fix_year > 0 and self.time_array[self.t_index] < self.t0101:
            logging.debug('{}:...use last BC data for timestep {}'.format(self.id, self.time_array[self.t_index]))
            # Last record at or before t1231
            last_t_index = np.searchsorted(self.time_array, self.t1231, side='right')-1
            raw_field = self.dataset[self.var_name][last_t_index]
        else:
            logging.debug('{}: ...read timestep {} {:<018}'.format(
                self.id, self.t_index, self.time_array[self.t_index]))
            raw_field = self.dataset[self.var_name][self.t_index]

        # This section can be "hardcoded" by the user
        # Hardcoded CO2 emissions data.
        # Log arguments are passed lazily so whole fields are only converted to
        # text when DEBUG logging is actually enabled (this runs every read).
        logging.debug('reading %s', self.var_name)
        if self.grid_name != "global" and self.var_name in ('CO2_em_anthro', 'CO2_em_AIR_anthro'):
            # Read sum of all sectors / levels
            raw_field = np.sum(raw_field, axis=0)
        elif self.var_name == 'mole_fraction_of_carbon_dioxide_in_air':
            # Element 0 holds the global value; raw_field[1:] would give the NH and SH values instead.
            raw_field = raw_field[0]
            logging.debug('reading co2 raw_field=%s raw_field.shape=%s field.shape=%s',
                          raw_field, raw_field.shape, self.field[1].shape)
        self.field[1] = pyoasis.asarray(np.transpose(raw_field)).astype('float64')*self.scale_factor+self.offset
        logging.debug('read raw_field=%s updating field[1]=%s', raw_field, self.field[1])

    def __close_current_file(self):
        '''
        Close the currently open netCDF dataset, if there is one.

        Does nothing when no dataset has been opened yet (e.g. finalise()
        called before setup()) or when it has already been closed. This
        avoids calling close() on None or closing the same dataset twice.
        '''
        if self.dataset is None or not self.dataset.isopen():
            return
        logging.debug('{}: close file {}'.format(self.id, self.current_file))
        self.dataset.close()