diff --git a/environment.yml b/environment.yml
index a1943e8ce1013cdf57b04669f6dd57b40cf68a53..21ca5cd3f432cf9216656580d2f0e38d5496d298 100755
--- a/environment.yml
+++ b/environment.yml
@@ -28,6 +28,8 @@ dependencies:
   - jupyter
   - ipykernel
   - pip
+  - gdal
+  - numba
   - pip:
     - argcomplete
diff --git a/nes/cli/__init__.py b/nes/cli/__init__.py
index 110a2569338c398d302830e7492361a74b61defc..bc4935a7db88949a033890ac1571e384ba168738 100644
--- a/nes/cli/__init__.py
+++ b/nes/cli/__init__.py
@@ -1,8 +1,12 @@
 from .checker import run_checks
 from .reorder_longitudes import reorder_longitudes
 from .interpolate import interpolate
-from .geostructure import nc2geostructure
+from .geostructure import nc2geostructure, nc2mbtiles
 
 __all__ = [
-    'run_checks', 'reorder_longitudes', 'interpolate', 'nc2geostructure'
-]
\ No newline at end of file
+    "run_checks",
+    "reorder_longitudes",
+    "interpolate",
+    "nc2geostructure",
+    "nc2mbtiles",
+]
diff --git a/nes/cli/cli.py b/nes/cli/cli.py
index 124af48e298561009376020cc43547d95fa71caa..4d478c216fb2e54357a2055e8aca19da49ed2a96 100644
--- a/nes/cli/cli.py
+++ b/nes/cli/cli.py
@@ -4,6 +4,48 @@ from configargparse import ArgParser
 import argcomplete
 
 
+def _add_nc2mbtiles_subparser(subparsers):
+    """
+    Add the 'mbtiles' subcommand to the NES CLI.
+
+    This command converts a monthly NetCDF into daily mbtiles.
+
+    Parameters
+    ----------
+    subparsers : argparse._SubParsersAction
+        The subparsers object returned by `add_subparsers()` on the main parser.
+    """
+    from nes.cli import nc2mbtiles
+
+    mbtiles_parser = subparsers.add_parser(
+        "nc2mbtiles", help="Convert a monthly NetCDF into daily mbtiles"
+    )
+    mbtiles_parser.add_argument(
+        "-i", "--input_file", required=True, help="Path to input NetCDF file"
+    )
+    mbtiles_parser.add_argument(
+        "-o",
+        "--output_path",
+        required=True,
+        help="Path to the directory where the mbtiles will be saved",
+    )
+    mbtiles_parser.add_argument(
+        "-f",
+        "--mbtiles_file",
+        required=True,
+        help="Name for the mbtiles files; used as a prefix for the daily mbtiles",
+    )
+    mbtiles_parser.add_argument(
+        "-gs",
+        "--geostructure",
+        required=False,
+        default=None,
+        help="Optional path to a geostructure file, allowing the manual computation to be skipped.",
+    )
+
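+    # argparse stores the chosen handler in args.func; main() dispatches on it.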
+    mbtiles_parser.set_defaults(func=nc2mbtiles)
+
+
 def _add_nc2geostructure_subparser(subparsers):
     """
     Add the 'geostructure' subcommand to the NES CLI.
@@ -18,15 +60,33 @@ def _add_nc2geostructure_subparser(subparsers):
     from nes.cli import nc2geostructure
 
     # TODO: TEST
-    geo_parser = subparsers.add_parser("nc2geostructure", help="Convert NetCDF to geospatial structure (GeoJSON, shapefile) (TESTING PHASE)")
-    geo_parser.add_argument("-i", "--input", required=True, help="Path to input NetCDF file")
-    geo_parser.add_argument("-o", "--output", required=True, help="Path to output geostructure")
+    geo_parser = subparsers.add_parser(
+        "nc2geostructure",
+        help="Convert NetCDF to geospatial structure (GeoJSON, shapefile) (TESTING PHASE)",
+    )
     geo_parser.add_argument(
-        "--var-list", nargs="+",
-        help="List of variable names to include in the geostructure. If omitted, all variables will be included."
+        "-i", "--input", required=True, help="Path to input NetCDF file"
+    )
+    geo_parser.add_argument(
+        "-o", "--output", required=True, help="Path to output geostructure"
+    )
+    geo_parser.add_argument(
+        "--var-list",
+        nargs="+",
+        help="List of variable names to include in the geostructure. If omitted, all variables will be included.",
+    )
+    geo_parser.add_argument(
+        "--time-step",
+        type=int,
+        default=0,
+        help="Time step index to extract (default: 0)",
+    )
+    geo_parser.add_argument(
+        "--level",
+        type=int,
+        default=0,
+        help="Vertical level index to extract (default: 0)",
     )
-    geo_parser.add_argument("--time-step", type=int, default=0, help="Time step index to extract (default: 0)")
-    geo_parser.add_argument("--level", type=int, default=0, help="Vertical level index to extract (default: 0)")
 
     geo_parser.set_defaults(func=nc2geostructure)
@@ -45,70 +105,148 @@ def _add_interpolate_subparser(subparsers):
     from nes.cli import interpolate
 
     # TODO: TEST
-    interp_parser = subparsers.add_parser("interpolate", help="Interpolate data onto a different grid (TESTING PHASE)")
+    interp_parser = subparsers.add_parser(
+        "interpolate", help="Interpolate data onto a different grid (TESTING PHASE)"
+    )
 
     # Main input/output
     general = interp_parser.add_argument_group("General options")
-    general.add_argument("-i", "--input", required=True, help="Path to source NetCDF file")
+    general.add_argument(
+        "-i", "--input", required=True, help="Path to source NetCDF file"
+    )
     general.add_argument("-o", "--output", help="Path to output NetCDF file")
-    general.add_argument("--axis", choices=["horizontal", "vertical"], default="horizontal",
-                         help="Interpolation axis (default: horizontal)")
+    general.add_argument(
+        "--axis",
+        choices=["horizontal", "vertical"],
+        default="horizontal",
+        help="Interpolation axis (default: horizontal)",
+    )
 
     dst_group = general.add_mutually_exclusive_group(required=True)
-    dst_group.add_argument("-d", "--destination", help="Path to destination grid NetCDF file")
-    dst_group.add_argument("--projection", help="Projection type to generate destination grid (e.g. regular, rotated, lcc)")
+    dst_group.add_argument(
+        "-d", "--destination", help="Path to destination grid NetCDF file"
+    )
+    dst_group.add_argument(
+        "--projection",
+        help="Projection type to generate destination grid (e.g. regular, rotated, lcc)",
+    )
 
     # Horizontal interpolation options
     horizontal = interp_parser.add_argument_group("Horizontal interpolation options")
-    horizontal.add_argument("--kind", choices=["NearestNeighbour", "Conservative"],
-                            help="Interpolation method for horizontal axis")
-    horizontal.add_argument("--n-neighbours", type=int, help="Number of neighbors (NearestNeighbour)")
-    horizontal.add_argument("--flux", action="store_true", help="Treat variables as fluxes (Conservative only)")
-    horizontal.add_argument("--keep-nan", action="store_true", help="Keep NaN values after interpolation")
-    horizontal.add_argument("--fix-border", action="store_true", help="Fix border effects (NearestNeighbour only)")
+    horizontal.add_argument(
+        "--kind",
+        choices=["NearestNeighbour", "Conservative"],
+        help="Interpolation method for horizontal axis",
+    )
+    horizontal.add_argument(
+        "--n-neighbours", type=int, help="Number of neighbors (NearestNeighbour)"
+    )
+    horizontal.add_argument(
+        "--flux",
+        action="store_true",
+        help="Treat variables as fluxes (Conservative only)",
+    )
+    horizontal.add_argument(
+        "--keep-nan", action="store_true", help="Keep NaN values after interpolation"
+    )
+    horizontal.add_argument(
+        "--fix-border",
+        action="store_true",
+        help="Fix border effects (NearestNeighbour only)",
+    )
     horizontal.add_argument("--weight-matrix-path", help="Path to weight matrix file")
-    horizontal.add_argument("--only-create-wm", action="store_true", help="Only generate weight matrix")
-    horizontal.add_argument("--to-providentia", action="store_true", help="Format output for Providentia")
+    horizontal.add_argument(
+        "--only-create-wm", action="store_true", help="Only generate weight matrix"
+    )
+    horizontal.add_argument(
+        "--to-providentia", action="store_true", help="Format output for Providentia"
+    )
 
     # Vertical interpolation options
     vertical = interp_parser.add_argument_group("Vertical interpolation options")
-    vertical.add_argument("--method", help="Interpolation method for vertical axis (e.g. linear)")
-    vertical.add_argument("--extrapolate", action="store_true", help="Allow extrapolation in vertical interpolation")
+    vertical.add_argument(
+        "--method", help="Interpolation method for vertical axis (e.g. linear)"
linear)" + ) + vertical.add_argument( + "--extrapolate", + action="store_true", + help="Allow extrapolation in vertical interpolation", + ) # Grid creation arguments grid = interp_parser.add_argument_group("Grid creation options (for --projection)") grid.add_argument("--lat_orig", type=float, help="Latitude origin (regular/global)") - grid.add_argument("--lon_orig", type=float, help="Longitude origin (regular/global)") - grid.add_argument("--inc_lat", type=float, help="Latitude increment (regular/global)") - grid.add_argument("--inc_lon", type=float, help="Longitude increment (regular/global)") - grid.add_argument("--n_lat", type=int, help="Number of latitude points (regular/global)") - grid.add_argument("--n_lon", type=int, help="Number of longitude points (regular/global)") - - grid.add_argument("--centre_lat", type=float, help="Rotated pole latitude (rotated)") - grid.add_argument("--centre_lon", type=float, help="Rotated pole longitude (rotated)") + grid.add_argument( + "--lon_orig", type=float, help="Longitude origin (regular/global)" + ) + grid.add_argument( + "--inc_lat", type=float, help="Latitude increment (regular/global)" + ) + grid.add_argument( + "--inc_lon", type=float, help="Longitude increment (regular/global)" + ) + grid.add_argument( + "--n_lat", type=int, help="Number of latitude points (regular/global)" + ) + grid.add_argument( + "--n_lon", type=int, help="Number of longitude points (regular/global)" + ) + + grid.add_argument( + "--centre_lat", type=float, help="Rotated pole latitude (rotated)" + ) + grid.add_argument( + "--centre_lon", type=float, help="Rotated pole longitude (rotated)" + ) grid.add_argument("--west_boundary", type=float, help="Western boundary (rotated)") - grid.add_argument("--south_boundary", type=float, help="Southern boundary (rotated)") + grid.add_argument( + "--south_boundary", type=float, help="Southern boundary (rotated)" + ) grid.add_argument("--inc_rlat", type=float, help="Latitude increment (rotated)") grid.add_argument("--inc_rlon", type=float, help="Longitude increment (rotated)") - grid.add_argument("--parent_grid_path", help="Path to parent grid NetCDF (rotated_nested)") + grid.add_argument( + "--parent_grid_path", help="Path to parent grid NetCDF (rotated_nested)" + ) grid.add_argument("--parent_ratio", type=int, help="Parent ratio (rotated_nested)") - grid.add_argument("--i_parent_start", type=int, help="Parent grid i index start (rotated_nested)") - grid.add_argument("--j_parent_start", type=int, help="Parent grid j index start (rotated_nested)") - grid.add_argument("--n_rlat", type=int, help="Number of lat points (rotated_nested)") - grid.add_argument("--n_rlon", type=int, help="Number of lon points (rotated_nested)") + grid.add_argument( + "--i_parent_start", type=int, help="Parent grid i index start (rotated_nested)" + ) + grid.add_argument( + "--j_parent_start", type=int, help="Parent grid j index start (rotated_nested)" + ) + grid.add_argument( + "--n_rlat", type=int, help="Number of lat points (rotated_nested)" + ) + grid.add_argument( + "--n_rlon", type=int, help="Number of lon points (rotated_nested)" + ) - grid.add_argument("--lat_1", type=float, help="First standard parallel (LCC projection)") - grid.add_argument("--lat_2", type=float, help="Second standard parallel (LCC projection)") + grid.add_argument( + "--lat_1", type=float, help="First standard parallel (LCC projection)" + ) + grid.add_argument( + "--lat_2", type=float, help="Second standard parallel (LCC projection)" + ) grid.add_argument("--lon_0", type=float, 
help="Central meridian (LCC projection)") grid.add_argument("--x_0", type=float, help="False easting (LCC projection)") grid.add_argument("--y_0", type=float, help="False northing (LCC projection)") - grid.add_argument("--dx", type=float, help="Grid spacing in x direction (LCC projection)") - grid.add_argument("--dy", type=float, help="Grid spacing in y direction (LCC projection)") - grid.add_argument("--nx", type=int, help="Number of grid points in x (LCC projection)") - grid.add_argument("--ny", type=int, help="Number of grid points in y (LCC projection)") + grid.add_argument( + "--dx", type=float, help="Grid spacing in x direction (LCC projection)" + ) + grid.add_argument( + "--dy", type=float, help="Grid spacing in y direction (LCC projection)" + ) + grid.add_argument( + "--nx", type=int, help="Number of grid points in x (LCC projection)" + ) + grid.add_argument( + "--ny", type=int, help="Number of grid points in y (LCC projection)" + ) - grid.add_argument("--lat_ts", type=float, help="Latitude of true scale (Mercator projection)") + grid.add_argument( + "--lat_ts", type=float, help="Latitude of true scale (Mercator projection)" + ) interp_parser.set_defaults(func=interpolate) @@ -127,11 +265,27 @@ def _add_check_subparser(subparsers): from nes.cli import run_checks check_parser = subparsers.add_parser("check", help="Run checks on a NetCDF file") - check_parser.add_argument("-i", "--input", required=True, help="Input NetCDF file path") - check_parser.add_argument("--nan", dest="check_nan", action="store_true", help="Check for NaN values") - check_parser.add_argument("--no-nan", dest="check_nan", action="store_false", help="Do not check NaN values") - check_parser.add_argument("--inf", dest="check_inf", action="store_true", help="Check for Inf values") - check_parser.add_argument("--no-inf", dest="check_inf", action="store_false", help="Do not check Inf values") + check_parser.add_argument( + "-i", "--input", required=True, help="Input NetCDF file path" + ) + check_parser.add_argument( + "--nan", dest="check_nan", action="store_true", help="Check for NaN values" + ) + check_parser.add_argument( + "--no-nan", + dest="check_nan", + action="store_false", + help="Do not check NaN values", + ) + check_parser.add_argument( + "--inf", dest="check_inf", action="store_true", help="Check for Inf values" + ) + check_parser.add_argument( + "--no-inf", + dest="check_inf", + action="store_false", + help="Do not check Inf values", + ) check_parser.set_defaults(check_nan=True, check_inf=True) check_parser.set_defaults(func=run_checks) @@ -150,8 +304,12 @@ def _add_reorder_subparser(subparsers): from nes.cli import reorder_longitudes # TODO: Add support for parallel version - reorder_parser = subparsers.add_parser("reorder", help="Reorder longitudes in a NetCDF file (ONLY SERIAL)") - reorder_parser.add_argument("-i", "--input", required=True, help="Input NetCDF file path") + reorder_parser = subparsers.add_parser( + "reorder", help="Reorder longitudes in a NetCDF file (ONLY SERIAL)" + ) + reorder_parser.add_argument( + "-i", "--input", required=True, help="Input NetCDF file path" + ) reorder_parser.add_argument("-o", "--output", help="Output NetCDF file path") reorder_parser.set_defaults(func=reorder_longitudes) @@ -173,6 +331,7 @@ def _filter_args(func, args_namespace): A dictionary containing only the arguments accepted by the function. 
""" import inspect + sig = inspect.signature(func) arg_keys = set(sig.parameters.keys()) args_dict = vars(args_namespace) @@ -189,11 +348,12 @@ def main(): """ parser = ArgParser( description="NES - NetCDF for Earth Science utilities", - default_config_files=['~/.nes_config'] + default_config_files=["~/.nes_config"], ) subparsers = parser.add_subparsers(dest="command", required=True) # Add subcommands + _add_nc2mbtiles_subparser(subparsers) _add_nc2geostructure_subparser(subparsers) _add_check_subparser(subparsers) _add_reorder_subparser(subparsers) @@ -208,9 +368,15 @@ def main(): filtered_args = _filter_args(args.func, args) args.func(**filtered_args) except Exception as e: - print(f"Process {MPI.COMM_WORLD.Get_rank()}: NES critical error detected {e}, aborting MPI.", flush=True) - print(f"Process {MPI.COMM_WORLD.Get_rank()}: Traceback:\n{traceback.format_exc()}", flush=True) + print( + f"Process {MPI.COMM_WORLD.Get_rank()}: NES critical error detected {e}, aborting MPI.", + flush=True, + ) + print( + f"Process {MPI.COMM_WORLD.Get_rank()}: Traceback:\n{traceback.format_exc()}", + flush=True, + ) MPI.COMM_WORLD.Abort(1) - return \ No newline at end of file + return diff --git a/nes/cli/geostructure.py b/nes/cli/geostructure.py index 40e49a1ac9f5ff92fcc997e2ab5eea57dadf4743..282da1cf2b8d69051f68e51b63c424d1ae25d1e4 100644 --- a/nes/cli/geostructure.py +++ b/nes/cli/geostructure.py @@ -10,7 +10,14 @@ Intended to be used as part of a CLI interface (e.g. via `nes geostructure`). from nes import open_netcdf -def nc2geostructure(input_file: str, output_file: str, var_list: list=None, time_step:int=0, level:int=0): + +def nc2geostructure( + input_file: str, + output_file: str, + var_list: list = None, + time_step: int = 0, + level: int = 0, +): """ Extracts geospatial data from a NetCDF file and writes it as a shapefile. 
@@ -41,9 +48,11 @@ def nc2geostructure(input_file: str, output_file: str, var_list: list=None, time
     # Select the desired time step and vertical level for extraction
     nessy.sel(
         time_min=nessy.time[time_step],  # Minimum time index to extract
-        time_max=nessy.time[time_step],  # Maximum time index to extract (same as min for single step)
-        level_min=level,  # Minimum vertical level to extract
-        level_max=level  # Maximum vertical level to extract (same as min for single level)
+        time_max=nessy.time[
+            time_step
+        ],  # Maximum time index to extract (same as min for single step)
+        level_min=level,  # Minimum vertical level to extract
+        level_max=level,  # Maximum vertical level to extract (same as min for single level)
     )
 
     # Filter to keep only the specified variables and load the data into memory
@@ -58,3 +67,24 @@ def nc2geostructure(input_file: str, output_file: str, var_list: list=None, time
     nessy.write_shapefile(output_file)
 
     return None
+
+
+def nc2mbtiles(
+    input_file: str, output_path: str, mbtiles_file: str, geostructure: str = None
+):
+    """Converts a monthly NetCDF into daily mbtiles.
+
+    Parameters
+    ----------
+    input_file : str
+        Input NetCDF.
+    output_path : str
+        Path where the mbtiles should be saved.
+    mbtiles_file : str
+        Name of the mbtiles to be saved; used as the prefix for the date portion and file extension.
+    geostructure : str
+        Optional geostructure file, allowing the manual computation to be skipped, by default None.
+    """
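+    # Load the month with time-based ("T") parallelisation, then let NES write
+    # one mbtiles file per day under output_path.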
+    nessy = open_netcdf(input_file, parallel_method="T")
+    nessy.load()
+    nessy.to_mbtile(output_path, mbtiles_file, geostructure)
diff --git a/nes/nc_projections/default_nes.py b/nes/nc_projections/default_nes.py
index cf1a1e22068002185c7a75651211e53fed801b33..e294f4e73a60931bcfc49ad19357b4b963c2ed9a 100644
--- a/nes/nc_projections/default_nes.py
+++ b/nes/nc_projections/default_nes.py
@@ -1,25 +1,81 @@
 #!/usr/bin/env python
-
+import os
 import sys
+import tempfile
+from copy import copy, deepcopy
+from datetime import datetime, timedelta
 from gc import collect
-from warnings import warn
 from math import isclose
-from numpy import (array, ndarray, abs, mean, diff, dstack, append, tile, empty, unique, stack, vstack, full, isnan,
-                   flipud, nan, float32, float64, ma, generic, character, issubdtype, arange, newaxis, concatenate,
-                   split, cumsum, zeros, column_stack, hstack, argsort, take)
-from pandas import Index, concat
-from geopandas import GeoDataFrame
-from datetime import timedelta, datetime
-from netCDF4 import Dataset, num2date, date2num, stringtochar
-from mpi4py import MPI
-from shapely.geometry import Polygon, Point
-from copy import deepcopy, copy
+from typing import Any, Dict, List, Union
+from warnings import warn
+
 from dateutil.relativedelta import relativedelta
-from typing import Union, List, Dict, Any
+from geopandas import GeoDataFrame, read_file
+from mpi4py import MPI
+from netCDF4 import Dataset, date2num, num2date, stringtochar
+from numba import boolean, njit, prange
+from numpy import (
+    abs,
+    append,
+    arange,
+    argsort,
+    array,
+    character,
+    column_stack,
+    concatenate,
+    cumsum,
+    diff,
+    dstack,
+    empty,
+    flipud,
+    float32,
+    float64,
+    full,
+    generic,
+    hstack,
+    isnan,
+    issubdtype,
+    ma,
+    mean,
+    nan,
+    ndarray,
+    newaxis,
+    split,
+    stack,
+    take,
+    tile,
+    unique,
+    vstack,
+    zeros,
+)
+from osgeo import gdal
+from pandas import Index, concat
 from pyproj import Proj, Transformer
-from ..methods import vertical_interpolation, horizontal_interpolation, cell_measures, spatial_join
-from ..nes_formats import to_netcdf_cams_ra, to_netcdf_monarch, to_monarch_units, to_netcdf_cmaq, to_cmaq_units, \
-    to_netcdf_wrf_chem, to_wrf_chem_units, to_netcdf_mocage, to_mocage_units
+from shapely.geometry import Point, Polygon
+
+from ..methods import (
+    cell_measures,
+    horizontal_interpolation,
+    spatial_join,
+    vertical_interpolation,
+)
+from ..nes_formats import (
+    to_cmaq_units,
+    to_mocage_units,
+    to_monarch_units,
+    to_netcdf_cams_ra,
+    to_netcdf_cmaq,
+    to_netcdf_mocage,
+    to_netcdf_monarch,
+    to_netcdf_wrf_chem,
+    to_wrf_chem_units,
+)
+
+# GDAL config
+gdal.SetConfigOption("CPL_LOG_ERRORS", "ON")
+gdal.SetConfigOption("GDAL_NUM_THREADS", "1")
+gdal.UseExceptions()
+os.environ["GDAL_NUM_THREADS"] = "1"
 
 
 class Nes(object):
@@ -175,11 +231,23 @@
             ...
         }
     """
-    def __init__(self, comm: Union[MPI.Comm, None] = None, path: Union[str, None] = None, info: bool = False,
-                 dataset: Union[Dataset, None] = None, parallel_method: str = "Y", avoid_first_hours: int = 0,
-                 avoid_last_hours: int = 0, first_level: int = 0, last_level: Union[int, None] = None,
-                 create_nes: bool = False, balanced: bool = False, times: Union[List[datetime], None] = None,
-                 **kwargs) -> None:
+
+    def __init__(
+        self,
+        comm: Union[MPI.Comm, None] = None,
+        path: Union[str, None] = None,
+        info: bool = False,
+        dataset: Union[Dataset, None] = None,
+        parallel_method: str = "Y",
+        avoid_first_hours: int = 0,
+        avoid_last_hours: int = 0,
+        first_level: int = 0,
+        last_level: Union[int, None] = None,
+        create_nes: bool = False,
+        balanced: bool = False,
+        times: Union[List[datetime], None] = None,
+        **kwargs,
+    ) -> None:
         """
         Initialize the Nes class
@@ -252,7 +320,6 @@
 
         # NetCDF object
         if create_nes:
-
             self.dataset = None
 
             # Set string length
@@ -270,7 +337,9 @@
             self._full_time = times
 
             self._full_time_bnds = self.__get_time_bnds(create_nes)
-            self._full_lat_bnds, self._full_lon_bnds = self.__get_coordinates_bnds(create_nes)
+            self._full_lat_bnds, self._full_lon_bnds = self.__get_coordinates_bnds(
+                create_nes
+            )
             self._full_lev = {"data": array([0]), "units": "", "positive": "up"}
             self._full_lat, self._full_lon = self._create_centre_coordinates(**kwargs)
@@ -279,7 +348,9 @@
             self.write_axis_limits = self._get_write_axis_limits()
 
             # Dimensions screening
-            self.time = self.get_full_times()[self.read_axis_limits["t_min"]:self.read_axis_limits["t_max"]]
+            self.time = self.get_full_times()[
+                self.read_axis_limits["t_min"] : self.read_axis_limits["t_max"]
+            ]
             self.time_bnds = self.get_full_time_bnds()
             self.lev = self.get_full_levels()
             self.lat_bnds = self.get_full_latitudes_boundaries()
@@ -307,9 +378,15 @@
             # Complete dimensions
             self._full_time = self.__get_time()
             self._full_time_bnds = self.__get_time_bnds()
-            self._full_lev = self._get_coordinate_dimension(["lev", "level", "lm", "plev"])
-            self._full_lat = self._get_coordinate_dimension(["lat", "latitude", "latitudes"])
-            self._full_lon = self._get_coordinate_dimension(["lon", "longitude", "longitudes"])
+            self._full_lev = self._get_coordinate_dimension(
+                ["lev", "level", "lm", "plev"]
+            )
+            self._full_lat = self._get_coordinate_dimension(
+                ["lat", "latitude", "latitudes"]
+            )
+            self._full_lon = self._get_coordinate_dimension(
+                ["lon", "longitude", "longitudes"]
+            )
             self._full_lat_bnds, self._full_lon_bnds = self.__get_coordinates_bnds()
 
             # Complete cell measures
@@ -321,13 +398,19 @@
             self.write_axis_limits = self._get_write_axis_limits()
 
             # Dimensions screening
-            self.time = self.get_full_times()[self.read_axis_limits["t_min"]:self.read_axis_limits["t_max"]]
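+            # Slice the full time axis down to this rank's read window.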
+            self.time = self.get_full_times()[
+                self.read_axis_limits["t_min"] : self.read_axis_limits["t_max"]
+            ]
             self.time_bnds = self.get_full_time_bnds()
             self.lev = self._get_coordinate_values(self.get_full_levels(), "Z")
             self.lat = self._get_coordinate_values(self.get_full_latitudes(), "Y")
             self.lon = self._get_coordinate_values(self.get_full_longitudes(), "X")
-            self.lat_bnds = self._get_coordinate_values(self.get_full_latitudes_boundaries(), "Y", bounds=True)
-            self.lon_bnds = self._get_coordinate_values(self.get_full_longitudes_boundaries(), "X", bounds=True)
+            self.lat_bnds = self._get_coordinate_values(
+                self.get_full_latitudes_boundaries(), "Y", bounds=True
+            )
+            self.lon_bnds = self._get_coordinate_values(
+                self.get_full_longitudes_boundaries(), "X", bounds=True
+            )
 
             # Cell measures screening
             self.cell_measures = self._get_cell_measures_values(self._cell_measures)
@@ -353,8 +436,10 @@
         # Filtering (portion of the filter coordinates function)
         idx = self._get_idx_intervals()
         if self.master:
-            self.set_full_times(self._full_time[idx["idx_t_min"]:idx["idx_t_max"]])
-            self._full_lev["data"] = self._full_lev["data"][idx["idx_z_min"]:idx["idx_z_max"]]
+            self.set_full_times(self._full_time[idx["idx_t_min"] : idx["idx_t_max"]])
+            self._full_lev["data"] = self._full_lev["data"][
+                idx["idx_z_min"] : idx["idx_z_max"]
+            ]
 
         self.hours_start = 0
         self.hours_end = 0
@@ -362,9 +447,21 @@
         self.first_level = None
 
     @staticmethod
-    def new(comm=None, path=None, info=False, dataset=None, parallel_method="Y",
-            avoid_first_hours=0, avoid_last_hours=0, first_level=0, last_level=None, create_nes=False,
-            balanced=False, times=None, **kwargs):
+    def new(
+        comm=None,
+        path=None,
+        info=False,
+        dataset=None,
+        parallel_method="Y",
+        avoid_first_hours=0,
+        avoid_last_hours=0,
+        first_level=0,
+        last_level=None,
+        create_nes=False,
+        balanced=False,
+        times=None,
+        **kwargs,
+    ):
         """
         Initialize the Nes class.
@@ -398,9 +495,21 @@
             List of times to substitute the current ones while creation.
 
         """
-        new = Nes(comm=comm, path=path, info=info, dataset=dataset, parallel_method=parallel_method,
-                  avoid_first_hours=avoid_first_hours, avoid_last_hours=avoid_last_hours, first_level=first_level,
-                  last_level=last_level, create_nes=create_nes, balanced=balanced, times=times, **kwargs)
+        new = Nes(
+            comm=comm,
+            path=path,
+            info=info,
+            dataset=dataset,
+            parallel_method=parallel_method,
+            avoid_first_hours=avoid_first_hours,
+            avoid_last_hours=avoid_last_hours,
+            first_level=first_level,
+            last_level=last_level,
+            create_nes=create_nes,
+            balanced=balanced,
+            times=times,
+            **kwargs,
+        )
 
         return new
@@ -485,7 +594,11 @@
         """
         d = self.__dict__
-        state = {k: d[k] for k in d if k not in ["comm", "variables", "dataset", "cell_measures"]}
+        state = {
+            k: d[k]
+            for k in d
+            if k not in ["comm", "variables", "dataset", "cell_measures"]
+        }
 
         return state
@@ -909,8 +1022,12 @@
             2D array with the FID data.
         """
         if self.master:
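+            # FIDs are row-major cell indices over the full grid.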
""" if self.master: - fids = arange(self._full_lat["data"].shape[0] * self._full_lon["data"].shape[-1]) - fids = fids.reshape((self._full_lat["data"].shape[0], self._full_lon["data"].shape[-1])) + fids = arange( + self._full_lat["data"].shape[0] * self._full_lon["data"].shape[-1] + ) + fids = fids.reshape( + (self._full_lat["data"].shape[0], self._full_lon["data"].shape[-1]) + ) if self.size == 1: return fids else: @@ -919,12 +1036,16 @@ class Nes(object): fids = self.comm.bcast(fids, root=0) if use_read: - fids = fids[self.read_axis_limits["y_min"]:self.read_axis_limits["y_max"], - self.read_axis_limits["x_min"]:self.read_axis_limits["x_max"]] + fids = fids[ + self.read_axis_limits["y_min"] : self.read_axis_limits["y_max"], + self.read_axis_limits["x_min"] : self.read_axis_limits["x_max"], + ] else: try: - fids = fids[self.write_axis_limits["y_min"]:self.write_axis_limits["y_max"], - self.write_axis_limits["x_min"]:self.write_axis_limits["x_max"]] + fids = fids[ + self.write_axis_limits["y_min"] : self.write_axis_limits["y_max"], + self.write_axis_limits["x_min"] : self.write_axis_limits["x_max"], + ] except TypeError as e: print(self.rank, fids, self.write_axis_limits) sys.stdout.flush() @@ -970,7 +1091,9 @@ class Nes(object): If `new_direction` is not "up" or "down". """ if new_direction not in ["up", "down"]: - raise ValueError(f"Level direction mus be up or down. '{new_direction}' is not a valid option") + raise ValueError( + f"Level direction mus be up or down. '{new_direction}' is not a valid option" + ) if self.master: self._full_lev["positive"] = new_direction self.lev["positive"] = new_direction @@ -1115,7 +1238,9 @@ class Nes(object): else: msg = "WARNING!!! " msg += "The given time bounds list has a different length than the time array. " - msg += "(time:{0}, bnds:{1}). Time bounds will not be set.".format(len(self.time), len(time_bnds)) + msg += "(time:{0}, bnds:{1}). Time bounds will not be set.".format( + len(self.time), len(time_bnds) + ) warn(msg) sys.stderr.flush() else: @@ -1147,12 +1272,23 @@ class Nes(object): ValueError If `new_resolution` is not one of the accepted values. """ - accepted_resolutions = ["second", "seconds", "minute", "minutes", "hour", "hours", "day", "days"] + accepted_resolutions = [ + "second", + "seconds", + "minute", + "minutes", + "hour", + "hours", + "day", + "days", + ] if new_resolution in accepted_resolutions: self._time_resolution = new_resolution else: - raise ValueError(f"Time resolution '{new_resolution}' is not accepted. " + - f"Use one of this: {accepted_resolutions}") + raise ValueError( + f"Time resolution '{new_resolution}' is not accepted. 
" + + f"Use one of this: {accepted_resolutions}" + ) return True @staticmethod @@ -1205,18 +1341,30 @@ class Nes(object): # Latitudes full_lat = self.get_full_latitudes() inc_lat = abs(mean(diff(full_lat["data"]))) - lat_bnds = self._create_single_spatial_bounds(full_lat["data"], inc_lat, spatial_nv=2) + lat_bnds = self._create_single_spatial_bounds( + full_lat["data"], inc_lat, spatial_nv=2 + ) self.set_full_latitudes_boundaries({"data": deepcopy(lat_bnds)}) - self.lat_bnds = {"data": lat_bnds[self.write_axis_limits["y_min"]:self.write_axis_limits["y_max"], :]} + self.lat_bnds = { + "data": lat_bnds[ + self.write_axis_limits["y_min"] : self.write_axis_limits["y_max"], : + ] + } # Longitudes full_lon = self.get_full_longitudes() inc_lon = abs(mean(diff(full_lon["data"]))) - lon_bnds = self._create_single_spatial_bounds(full_lon["data"], inc_lon, spatial_nv=2) + lon_bnds = self._create_single_spatial_bounds( + full_lon["data"], inc_lon, spatial_nv=2 + ) self.set_full_longitudes_boundaries({"data": deepcopy(lon_bnds)}) - self.lon_bnds = {"data": lon_bnds[self.write_axis_limits["x_min"]:self.write_axis_limits["x_max"], :]} + self.lon_bnds = { + "data": lon_bnds[ + self.write_axis_limits["x_min"] : self.write_axis_limits["x_max"], : + ] + } return None @@ -1234,7 +1382,9 @@ class Nes(object): Latitude boundaries in the mesh format """ if self.size > 1: - raise RuntimeError("NES.get_spatial_bounds_mesh_format() function only works in serial mode.") + raise RuntimeError( + "NES.get_spatial_bounds_mesh_format() function only works in serial mode." + ) if self.lat_bnds is None: self.create_spatial_bounds() @@ -1248,19 +1398,27 @@ class Nes(object): elif self.lat_bnds["data"].shape[-1] == 4: # Irregular quadrilateral polygon cell definition - lat_bnds_mesh = empty((self.lat["data"].shape[0] + 1, self.lat["data"].shape[1] + 1)) + lat_bnds_mesh = empty( + (self.lat["data"].shape[0] + 1, self.lat["data"].shape[1] + 1) + ) lat_bnds_mesh[:-1, :-1] = self.lat_bnds["data"][:, :, 0] lat_bnds_mesh[:-1, 1:] = self.lat_bnds["data"][:, :, 1] lat_bnds_mesh[1:, 1:] = self.lat_bnds["data"][:, :, 2] lat_bnds_mesh[1:, :-1] = self.lat_bnds["data"][:, :, 3] - lon_bnds_mesh = empty((self.lat["data"].shape[0] + 1, self.lat["data"].shape[1] + 1)) + lon_bnds_mesh = empty( + (self.lat["data"].shape[0] + 1, self.lat["data"].shape[1] + 1) + ) lon_bnds_mesh[:-1, :-1] = self.lon_bnds["data"][:, :, 0] lon_bnds_mesh[:-1, 1:] = self.lon_bnds["data"][:, :, 1] lon_bnds_mesh[1:, 1:] = self.lon_bnds["data"][:, :, 2] lon_bnds_mesh[1:, :-1] = self.lon_bnds["data"][:, :, 3] else: - raise RuntimeError("Invalid number of vertices: {0}".format(self.lat_bnds["data"].shape[-1])) + raise RuntimeError( + "Invalid number of vertices: {0}".format( + self.lat_bnds["data"].shape[-1] + ) + ) return lon_bnds_mesh, lat_bnds_mesh @@ -1365,12 +1523,25 @@ class Nes(object): else: aux_nessy.variables[var_name][att_name] = att_value else: - aux_nessy.variables[var_name]["data"] = aux_nessy.variables[var_name]["data"][[idx_time]] + aux_nessy.variables[var_name]["data"] = aux_nessy.variables[var_name][ + "data" + ][[idx_time]] return aux_nessy - def sel(self, hours_start=None, time_min=None, hours_end=None, time_max=None, lev_min=None, lev_max=None, - lat_min=None, lat_max=None, lon_min=None, lon_max=None): + def sel( + self, + hours_start=None, + time_min=None, + hours_end=None, + time_max=None, + lev_min=None, + lev_max=None, + lat_min=None, + lat_max=None, + lon_min=None, + lon_max=None, + ): """ Select a slice of time, vertical level, latitude, or 
@@ -1417,23 +1588,31 @@
                 if var_info["data"] is not None:
                     loaded_vars = True
             if loaded_vars:
-                raise ValueError("Some variables have been loaded. Use select function before load.")
+                raise ValueError(
+                    "Some variables have been loaded. Use select function before load."
+                )
 
         # First time filter
         if hours_start is not None:
             if time_min is not None:
-                raise ValueError("Choose to select by hours_start or time_min but not both")
+                raise ValueError(
+                    "Choose to select by hours_start or time_min but not both"
+                )
             self.hours_start = hours_start
         elif time_min is not None:
             if time_min <= full_time[0]:
                 self.hours_start = 0
             else:
-                self.hours_start = int((time_min - full_time[0]).total_seconds() // 3600)
+                self.hours_start = int(
+                    (time_min - full_time[0]).total_seconds() // 3600
+                )
 
         # Last time filter
         if hours_end is not None:
             if time_max is not None:
-                raise ValueError("Choose to select by hours_end or time_max but not both")
+                raise ValueError(
+                    "Choose to select by hours_end or time_max but not both"
+                )
             self.hours_end = hours_end
         elif time_max is not None:
             if time_max >= full_time[-1]:
@@ -1455,14 +1634,20 @@
         self.read_axis_limits = self._get_read_axis_limits()
 
         # Dimensions screening
-        self.time = self.get_full_times()[self.read_axis_limits["t_min"]:self.read_axis_limits["t_max"]]
+        self.time = self.get_full_times()[
+            self.read_axis_limits["t_min"] : self.read_axis_limits["t_max"]
+        ]
         self.time_bnds = self.get_full_time_bnds()
         self.lev = self._get_coordinate_values(self.get_full_levels(), "Z")
         self.lat = self._get_coordinate_values(self.get_full_latitudes(), "Y")
         self.lon = self._get_coordinate_values(self.get_full_longitudes(), "X")
-        self.lat_bnds = self._get_coordinate_values(self.get_full_latitudes_boundaries(), "Y", bounds=True)
-        self.lon_bnds = self._get_coordinate_values(self.get_full_longitudes_boundaries(), "X", bounds=True)
+        self.lat_bnds = self._get_coordinate_values(
+            self.get_full_latitudes_boundaries(), "Y", bounds=True
+        )
+        self.lon_bnds = self._get_coordinate_values(
+            self.get_full_longitudes_boundaries(), "X", bounds=True
+        )
 
         # Filter dimensions
         self._filter_coordinates_selection()
@@ -1491,7 +1676,9 @@
         self._expand_coordinates(n_cells, is_first, is_last)
         self._expand_variables(n_cells, is_first, is_last)
         self._expand_full_coordinates(n_cells)
-        self._update_axis_limits(expand=True, n_cells=n_cells, is_first=is_first, is_last=is_last)
+        self._update_axis_limits(
+            expand=True, n_cells=n_cells, is_first=is_first, is_last=is_last
+        )
         self._post_expand_processing()
 
     def contract(self, n_cells: int = 1) -> None:
@@ -1513,7 +1700,9 @@
         self._contract_coordinates(n_cells, is_first, is_last)
         self._contract_variables(n_cells, is_first, is_last)
         self._contract_full_coordinates(n_cells)
-        self._update_axis_limits(expand=False, n_cells=n_cells, is_first=is_first, is_last=is_last)
+        self._update_axis_limits(
+            expand=False, n_cells=n_cells, is_first=is_first, is_last=is_last
+        )
         self._post_expand_processing()
 
     def _expand_coordinates(self, n_cells: int, is_first: bool, is_last: bool) -> None:
@@ -1535,7 +1724,9 @@
         """
         raise NotImplementedError("It is not possible to expand Default coordinates")
 
-    def _contract_coordinates(self, n_cells: int, is_first: bool, is_last: bool) -> None:
+    def _contract_coordinates(
+        self, n_cells: int, is_first: bool, is_last: bool
+    ) -> None:
         """
         Contracts latitude and longitude coordinates by removing `n_cells` from the borders.
@@ -1567,7 +1758,9 @@
         -------
         None
         """
-        raise NotImplementedError("It is not possible to expand Default full coordinates")
+        raise NotImplementedError(
+            "It is not possible to expand Default full coordinates"
+        )
 
     def _contract_full_coordinates(self, n_cells: int) -> None:
         """
@@ -1582,7 +1775,9 @@
         -------
         None
         """
-        raise NotImplementedError("It is not possible to contract Default full coordinates")
+        raise NotImplementedError(
+            "It is not possible to contract Default full coordinates"
+        )
 
     def _expand_variables(self, n_cells: int, is_first: bool, is_last: bool) -> None:
         """
@@ -1601,20 +1796,21 @@
         -------
         None
         """
-        expand_x = (self.parallel_method in ["X", None])
-        expand_y = (self.parallel_method in ["Y", None])
+        expand_x = self.parallel_method in ["X", None]
+        expand_y = self.parallel_method in ["Y", None]
 
         for var_name, var_data in self.variables.items():
-            if isinstance(self.variables[var_name]['data'], ndarray):
-                self.variables[var_name]['data'] = self._expand_variables_data(
-                    var_data['data'], n_cells=n_cells,
+            if isinstance(self.variables[var_name]["data"], ndarray):
+                self.variables[var_name]["data"] = self._expand_variables_data(
+                    var_data["data"],
+                    n_cells=n_cells,
                     left=is_first if expand_x else True,
                     right=is_last if expand_x else True,
                     top=is_last if expand_y else True,
-                    bottom=is_first if expand_y else True
+                    bottom=is_first if expand_y else True,
                 )
             else:
-                self.variables[var_name]['data'] = var_data['data']
+                self.variables[var_name]["data"] = var_data["data"]
         return None
 
     def _contract_variables(self, n_cells: int, is_first: bool, is_last: bool) -> None:
@@ -1634,24 +1830,27 @@
         -------
         None
         """
-        contract_x = (self.parallel_method in ["X", None])
-        contract_y = (self.parallel_method in ["Y", None])
+        contract_x = self.parallel_method in ["X", None]
+        contract_y = self.parallel_method in ["Y", None]
 
         for var_name, var_data in self.variables.items():
-            if isinstance(self.variables[var_name]['data'], ndarray):
-                self.variables[var_name]['data'] = self._contract_variables_data(
-                    var_data['data'], n_cells=n_cells,
+            if isinstance(self.variables[var_name]["data"], ndarray):
+                self.variables[var_name]["data"] = self._contract_variables_data(
+                    var_data["data"],
+                    n_cells=n_cells,
                     left=is_first if contract_x else True,
                     right=is_last if contract_x else True,
                    top=is_last if contract_y else True,
-                    bottom=is_first if contract_y else True
+                    bottom=is_first if contract_y else True,
                 )
             else:
-                self.variables[var_name]['data'] = var_data['data']
+                self.variables[var_name]["data"] = var_data["data"]
         return None
 
     @staticmethod
-    def _expand_1d(coordinate: ndarray, n_cells: int, left: bool = True, right: bool = True) -> ndarray or None:
+    def _expand_1d(
+        coordinate: ndarray, n_cells: int, left: bool = True, right: bool = True
+    ) -> ndarray or None:
         """
         Expands a 1D or 2D coordinate array by adding `n_cells` values on each side.
@@ -1716,19 +1915,44 @@
 
             # Expansion for 1D coordinates
             step = coordinate[1] - coordinate[0]
-            expanded_left = arange(coordinate[0] - n_cells * step, coordinate[0], step) if left else empty((0,))
-            expanded_right = arange(coordinate[-1] + step, coordinate[-1] + (n_cells + 1) * step,
-                                    step) if right else empty((0,))
+            expanded_left = (
+                arange(coordinate[0] - n_cells * step, coordinate[0], step)
+                if left
+                else empty((0,))
+            )
+            expanded_right = (
+                arange(
+                    coordinate[-1] + step, coordinate[-1] + (n_cells + 1) * step, step
+                )
+                if right
+                else empty((0,))
+            )
             result = concatenate([expanded_left, coordinate, expanded_right])
 
         elif coordinate.ndim == 2:
             # Expansion for 2D coordinates (bounds)
             step = coordinate[1, 0] - coordinate[0, 0]  # Assuming uniform step size
-            expanded_left = array([[coordinate[0, 0] - i * step, coordinate[0, 1] - i * step] for i in
-                                   range(n_cells, 0, -1)]) if left else empty((0, 2))
-            expanded_right = array([[coordinate[-1, 0] + i * step, coordinate[-1, 1] + i * step] for i in
-                                    range(1, n_cells + 1)]) if right else empty((0, 2))
+            expanded_left = (
+                array(
+                    [
+                        [coordinate[0, 0] - i * step, coordinate[0, 1] - i * step]
+                        for i in range(n_cells, 0, -1)
+                    ]
+                )
+                if left
+                else empty((0, 2))
+            )
+            expanded_right = (
+                array(
+                    [
+                        [coordinate[-1, 0] + i * step, coordinate[-1, 1] + i * step]
+                        for i in range(1, n_cells + 1)
+                    ]
+                )
+                if right
+                else empty((0, 2))
+            )
 
             result = concatenate([expanded_left, coordinate, expanded_right])
@@ -1738,8 +1962,9 @@
         return result
 
     @staticmethod
-    def _contract_1d(coordinate: ndarray, n_cells: int, left: bool = True,
-                     right: bool = True) -> ndarray or None:
+    def _contract_1d(
+        coordinate: ndarray, n_cells: int, left: bool = True, right: bool = True
+    ) -> ndarray or None:
         """
         Contracts a 1D or 2D coordinate array by removing `n_cells` from each specified side.
@@ -1774,13 +1999,17 @@
 
         if coordinate.ndim == 1:
             if n_cells * (left + right) >= coordinate.shape[0]:
-                raise IndexError("n_cells is too large, would remove all elements in the 1D array.")
+                raise IndexError(
+                    "n_cells is too large, would remove all elements in the 1D array."
+                )
             result = coordinate[n_cells:] if left else coordinate
             result = result[:-n_cells] if right else result
 
         elif coordinate.ndim == 2:
             if n_cells * (left + right) >= coordinate.shape[0]:
-                raise IndexError("n_cells is too large, would remove all rows in the 2D array.")
+                raise IndexError(
+                    "n_cells is too large, would remove all rows in the 2D array."
+                )
             result = coordinate[n_cells:] if left else coordinate
             result = result[:-n_cells] if right else result
@@ -1790,9 +2019,14 @@
         return result
 
     @staticmethod
-    def _expand_2d(coordinate: ndarray, n_cells: int,
-                   left: bool = True, right: bool = True,
-                   top: bool = True, bottom: bool = True) -> ndarray or None:
+    def _expand_2d(
+        coordinate: ndarray,
+        n_cells: int,
+        left: bool = True,
+        right: bool = True,
+        top: bool = True,
+        bottom: bool = True,
+    ) -> ndarray or None:
         """
         Expands a 2D coordinate grid by adding `n_cells` on each specified side.
@@ -1827,39 +2061,60 @@
         # Expand left and right
         if left:
             left_extension = tile(coordinate[:, [0]], (1, n_cells))
-            left_extension -= arange(n_cells, 0, -1).reshape(1, -1) * (coordinate[:, [1]] - coordinate[:, [0]])
+            left_extension -= arange(n_cells, 0, -1).reshape(1, -1) * (
+                coordinate[:, [1]] - coordinate[:, [0]]
+            )
         else:
             left_extension = empty((ny, 0))
 
         if right:
             right_extension = tile(coordinate[:, [-1]], (1, n_cells))
-            right_extension += arange(1, n_cells + 1).reshape(1, -1) * (coordinate[:, [-1]] - coordinate[:, [-2]])
+            right_extension += arange(1, n_cells + 1).reshape(1, -1) * (
+                coordinate[:, [-1]] - coordinate[:, [-2]]
+            )
         else:
             right_extension = empty((ny, 0))
 
-        expanded = hstack([left_extension, coordinate, right_extension]) if left or right else coordinate
+        expanded = (
+            hstack([left_extension, coordinate, right_extension])
+            if left or right
+            else coordinate
+        )
 
         # Expand top and bottom
         if top:
             top_extension = tile(expanded[[0], :], (n_cells, 1))
-            top_extension -= arange(n_cells, 0, -1).reshape(-1, 1) * (expanded[[1], :] - expanded[[0], :])
+            top_extension -= arange(n_cells, 0, -1).reshape(-1, 1) * (
+                expanded[[1], :] - expanded[[0], :]
+            )
         else:
             top_extension = empty((0, expanded.shape[1]))
 
         if bottom:
             bottom_extension = tile(expanded[[-1], :], (n_cells, 1))
-            bottom_extension += arange(1, n_cells + 1).reshape(-1, 1) * (expanded[[-1], :] - expanded[[-2], :])
+            bottom_extension += arange(1, n_cells + 1).reshape(-1, 1) * (
+                expanded[[-1], :] - expanded[[-2], :]
+            )
         else:
             bottom_extension = empty((0, expanded.shape[1]))
 
-        expanded = vstack([top_extension, expanded, bottom_extension]) if top or bottom else expanded
+        expanded = (
+            vstack([top_extension, expanded, bottom_extension])
+            if top or bottom
+            else expanded
+        )
 
         return expanded
 
     @staticmethod
-    def _expand_2d_bounds(coordinate: ndarray, n_cells: int,
-                          left: bool = True, right: bool = True,
-                          top: bool = True, bottom: bool = True) -> ndarray or None:
+    def _expand_2d_bounds(
+        coordinate: ndarray,
+        n_cells: int,
+        left: bool = True,
+        right: bool = True,
+        top: bool = True,
+        bottom: bool = True,
+    ) -> ndarray or None:
         """
         Expands a 3D coordinate grid (bounds format) by adding `n_cells` on each specified side.
@@ -1887,7 +2142,9 @@
             return None
 
         if coordinate.ndim != 3 or coordinate.shape[2] != 4:
-            raise ValueError("Input coordinate array must be 3D with the last dimension of size 4 (bounds format).")
+            raise ValueError(
+                "Input coordinate array must be 3D with the last dimension of size 4 (bounds format)."
+            )
 
         if n_cells <= 0 or not (left or right or top or bottom):
             return coordinate  # No expansion needed
@@ -1897,40 +2154,60 @@
         ny, nbounds = coordinate.shape[0], coordinate.shape[2]
 
         # Expand left and right
         if left:
             left_extension = tile(coordinate[:, [0], :], (1, n_cells, 1))
-            left_extension -= arange(n_cells, 0, -1).reshape(1, -1, 1) * (coordinate[:, [1], :] - coordinate[:, [0], :])
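+            # Step outward by multiples of the spacing between the first two bound columns.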
+            left_extension -= arange(n_cells, 0, -1).reshape(1, -1, 1) * (
+                coordinate[:, [1], :] - coordinate[:, [0], :]
+            )
         else:
             left_extension = empty((ny, 0, nbounds))
 
         if right:
             right_extension = tile(coordinate[:, [-1], :], (1, n_cells, 1))
             right_extension += arange(1, n_cells + 1).reshape(1, -1, 1) * (
-                coordinate[:, [-1], :] - coordinate[:, [-2], :])
+                coordinate[:, [-1], :] - coordinate[:, [-2], :]
+            )
         else:
             right_extension = empty((ny, 0, nbounds))
 
-        expanded = hstack([left_extension, coordinate, right_extension]) if left or right else coordinate
+        expanded = (
+            hstack([left_extension, coordinate, right_extension])
+            if left or right
+            else coordinate
+        )
 
         # Expand top and bottom
         if top:
             top_extension = tile(expanded[[0], :, :], (n_cells, 1, 1))
-            top_extension -= arange(n_cells, 0, -1).reshape(-1, 1, 1) * (expanded[[1], :, :] - expanded[[0], :, :])
+            top_extension -= arange(n_cells, 0, -1).reshape(-1, 1, 1) * (
+                expanded[[1], :, :] - expanded[[0], :, :]
+            )
         else:
             top_extension = empty((0, expanded.shape[1], nbounds))
 
         if bottom:
             bottom_extension = tile(expanded[[-1], :, :], (n_cells, 1, 1))
-            bottom_extension += arange(1, n_cells + 1).reshape(-1, 1, 1) * (expanded[[-1], :, :] - expanded[[-2], :, :])
+            bottom_extension += arange(1, n_cells + 1).reshape(-1, 1, 1) * (
+                expanded[[-1], :, :] - expanded[[-2], :, :]
+            )
         else:
             bottom_extension = empty((0, expanded.shape[1], nbounds))
 
-        expanded = vstack([top_extension, expanded, bottom_extension]) if top or bottom else expanded
+        expanded = (
+            vstack([top_extension, expanded, bottom_extension])
+            if top or bottom
+            else expanded
+        )
 
         return expanded
 
     @staticmethod
-    def _contract_2d(coordinate: ndarray, n_cells: int,
-                     left: bool = True, right: bool = True,
-                     top: bool = True, bottom: bool = True) -> ndarray or None:
+    def _contract_2d(
+        coordinate: ndarray,
+        n_cells: int,
+        left: bool = True,
+        right: bool = True,
+        top: bool = True,
+        bottom: bool = True,
+    ) -> ndarray or None:
         """
         Contracts a 2D coordinate grid by removing `n_cells` from each specified side.
@@ -1974,9 +2251,14 @@
         return coordinate
 
     @staticmethod
-    def _contract_2d_bounds(coordinate: ndarray, n_cells: int,
-                            left: bool = True, right: bool = True,
-                            top: bool = True, bottom: bool = True) -> ndarray or None:
+    def _contract_2d_bounds(
+        coordinate: ndarray,
+        n_cells: int,
+        left: bool = True,
+        right: bool = True,
+        top: bool = True,
+        bottom: bool = True,
+    ) -> ndarray or None:
         """
         Contracts a 3D coordinate grid (bounds format) by removing `n_cells` from each specified side.
@@ -2004,7 +2286,9 @@
             return None
 
         if coordinate.ndim != 3 or coordinate.shape[2] != 4:
-            raise ValueError("Input coordinate array must be 3D with the last dimension of size 4 (bounds format).")
+            raise ValueError(
+                "Input coordinate array must be 3D with the last dimension of size 4 (bounds format)."
+            )
 
         if n_cells <= 0 or not (left or right or top or bottom):
             return coordinate  # No contraction needed
@@ -2023,8 +2307,14 @@
         return coordinate
 
     @staticmethod
-    def _expand_variables_data(data: ndarray, n_cells: int, left: bool = True, right: bool = True, top: bool = True,
-                               bottom: bool = True) -> ndarray or None:
+    def _expand_variables_data(
+        data: ndarray,
+        n_cells: int,
+        left: bool = True,
+        right: bool = True,
+        top: bool = True,
+        bottom: bool = True,
+    ) -> ndarray or None:
         """
         Expands a 4D data array by adding zero-filled cells along the spatial dimensions (Y and X).
@@ -2100,9 +2390,14 @@
         return data
 
     @staticmethod
-    def _contract_variables_data(data: ndarray, n_cells: int,
-                                 left: bool = True, right: bool = True,
-                                 top: bool = True, bottom: bool = True) -> ndarray or None:
+    def _contract_variables_data(
+        data: ndarray,
+        n_cells: int,
+        left: bool = True,
+        right: bool = True,
+        top: bool = True,
+        bottom: bool = True,
+    ) -> ndarray or None:
         """
         Contracts a 4D data array by removing `n_cells` along the spatial dimensions (Y and X).
@@ -2153,9 +2448,13 @@
 
         # Validate that n_cells does not exceed the available dimensions
         if n_cells >= y_dim and (top or bottom):
-            raise IndexError("n_cells is too large, would remove all rows in Y dimension.")
+            raise IndexError(
+                "n_cells is too large, would remove all rows in Y dimension."
+            )
         if n_cells >= x_dim and (left or right):
-            raise IndexError("n_cells is too large, would remove all columns in X dimension.")
+            raise IndexError(
+                "n_cells is too large, would remove all columns in X dimension."
+            )
 
         # Contract Y dimension (top and bottom)
         if top:
@@ -2181,10 +2480,12 @@
         """
         if self.shapefile is not None:
             self.create_shapefile(overwrite=True)
-        if 'cell_area' in self.cell_measures:
+        if "cell_area" in self.cell_measures:
             self.calculate_grid_area(overwrite=True)
 
-    def _update_axis_limits(self, expand: bool = True, n_cells: int = 1, is_first=True, is_last=True) -> None:
+    def _update_axis_limits(
+        self, expand: bool = True, n_cells: int = 1, is_first=True, is_last=True
+    ) -> None:
         """
         Update the axis limits by expanding or contracting the read and write boundaries.
@@ -2206,10 +2507,16 @@
         # self.read_axis_limits = self._adjust_axis_limits(axis_limits=self.read_axis_limits, n_cells=n_cells, expand=expand, is_first=is_first, is_last=is_last)
         # self.write_axis_limits = self._adjust_axis_limits(axis_limits=self.write_axis_limits, n_cells=n_cells, expand=expand, is_first=is_first, is_last=is_last)
-        self._adjust_read_axis_limits(n_cells=n_cells, expand=expand, is_first=is_first, is_last=is_last)
-        self._adjust_write_axis_limits(n_cells=n_cells, expand=expand, is_first=is_first, is_last=is_last)
+        self._adjust_read_axis_limits(
+            n_cells=n_cells, expand=expand, is_first=is_first, is_last=is_last
+        )
+        self._adjust_write_axis_limits(
+            n_cells=n_cells, expand=expand, is_first=is_first, is_last=is_last
+        )
 
-    def _adjust_axis_limits(self, axis_limits, n_cells: int, expand: bool, is_first=True, is_last=True) -> dict:
+    def _adjust_axis_limits(
+        self, axis_limits, n_cells: int, expand: bool, is_first=True, is_last=True
+    ) -> dict:
         """
         Adjust read axis limits by expanding or contracting them based on parallelization method.
@@ -2233,41 +2540,43 @@
 
         if self.parallel_method == "Y":
             if is_first:
-                axis_limits['y_min'] -= sign * n_cells
-                axis_limits['x_min'] -= sign * n_cells
+                axis_limits["y_min"] -= sign * n_cells
+                axis_limits["x_min"] -= sign * n_cells
             if is_last:
-                if axis_limits['y_max'] is not None:
-                    axis_limits['y_max'] += sign * n_cells
-                if axis_limits['x_max'] is not None:
-                    axis_limits['x_max'] += sign * n_cells
+                if axis_limits["y_max"] is not None:
+                    axis_limits["y_max"] += sign * n_cells
+                if axis_limits["x_max"] is not None:
+                    axis_limits["x_max"] += sign * n_cells
         elif self.parallel_method == "X":
             if is_first:
-                axis_limits['x_min'] -= sign * n_cells
-                axis_limits['y_min'] -= sign * n_cells
+                axis_limits["x_min"] -= sign * n_cells
+                axis_limits["y_min"] -= sign * n_cells
             if is_last:
-                if axis_limits['x_max'] is not None:
-                    axis_limits['x_max'] += sign * n_cells
-                if axis_limits['y_max'] is not None:
-                    axis_limits['y_max'] += sign * n_cells
+                if axis_limits["x_max"] is not None:
+                    axis_limits["x_max"] += sign * n_cells
+                if axis_limits["y_max"] is not None:
+                    axis_limits["y_max"] += sign * n_cells
         else:
-            axis_limits['y_min'] -= sign * n_cells
-            axis_limits['x_min'] -= sign * n_cells
-            if axis_limits['y_max'] is not None:
-                axis_limits['y_max'] += sign * n_cells
-            if axis_limits['x_max'] is not None:
-                axis_limits['x_max'] += sign * n_cells
-
-        if axis_limits['x_min'] < 0:
-            if axis_limits['x_max'] is not None:
-                axis_limits['x_max'] += abs(axis_limits['x_min'])
-            axis_limits['x_min'] = 0
-        if axis_limits['y_min'] < 0:
-            if axis_limits['y_max'] is not None:
-                axis_limits['y_max'] += abs(axis_limits['y_min'])
-            axis_limits['y_min'] = 0
+            axis_limits["y_min"] -= sign * n_cells
+            axis_limits["x_min"] -= sign * n_cells
+            if axis_limits["y_max"] is not None:
+                axis_limits["y_max"] += sign * n_cells
+            if axis_limits["x_max"] is not None:
+                axis_limits["x_max"] += sign * n_cells
+
+        if axis_limits["x_min"] < 0:
+            if axis_limits["x_max"] is not None:
+                axis_limits["x_max"] += abs(axis_limits["x_min"])
+            axis_limits["x_min"] = 0
+        if axis_limits["y_min"] < 0:
+            if axis_limits["y_max"] is not None:
+                axis_limits["y_max"] += abs(axis_limits["y_min"])
+            axis_limits["y_min"] = 0
 
         return axis_limits
 
-    def _adjust_read_axis_limits(self, n_cells: int, expand: bool, is_first=True, is_last=True) -> None:
+    def _adjust_read_axis_limits(
+        self, n_cells: int, expand: bool, is_first=True, is_last=True
+    ) -> None:
         """
         Adjust read axis limits by expanding or contracting them based on parallelization method.
@@ -2290,26 +2599,28 @@
 
         if self.parallel_method == "Y":
             if is_first:
-                self.read_axis_limits['y_min'] -= sign * n_cells
-                self.read_axis_limits['x_min'] -= sign * n_cells
+                self.read_axis_limits["y_min"] -= sign * n_cells
+                self.read_axis_limits["x_min"] -= sign * n_cells
             if is_last:
-                self.read_axis_limits['y_max'] += sign * n_cells
-                self.read_axis_limits['x_max'] += sign * n_cells
+                self.read_axis_limits["y_max"] += sign * n_cells
+                self.read_axis_limits["x_max"] += sign * n_cells
         elif self.parallel_method == "X":
             if is_first:
-                self.read_axis_limits['x_min'] -= sign * n_cells
-                self.read_axis_limits['y_min'] -= sign * n_cells
+                self.read_axis_limits["x_min"] -= sign * n_cells
+                self.read_axis_limits["y_min"] -= sign * n_cells
             if is_last:
-                self.read_axis_limits['x_max'] += sign * n_cells
-                self.read_axis_limits['y_max'] += sign * n_cells
+                self.read_axis_limits["x_max"] += sign * n_cells
+                self.read_axis_limits["y_max"] += sign * n_cells
         else:
-            self.read_axis_limits['y_min'] -= sign * n_cells
-            self.read_axis_limits['x_min'] -= sign * n_cells
-            self.read_axis_limits['y_max'] += sign * n_cells
-            self.read_axis_limits['x_max'] += sign * n_cells
+            self.read_axis_limits["y_min"] -= sign * n_cells
+            self.read_axis_limits["x_min"] -= sign * n_cells
+            self.read_axis_limits["y_max"] += sign * n_cells
+            self.read_axis_limits["x_max"] += sign * n_cells
 
         return None
 
-    def _adjust_write_axis_limits(self, n_cells: int, expand: bool, is_first=True, is_last=True) -> None:
+    def _adjust_write_axis_limits(
+        self, n_cells: int, expand: bool, is_first=True, is_last=True
+    ) -> None:
         """
         Adjust write axis limits by expanding or contracting them.
@@ -2331,19 +2642,19 @@
         """
         sign = 1 if expand else -1
         if self.parallel_method == "X":
-            max_axis = ['x_max']
-            min_axis = ['x_min']
+            max_axis = ["x_max"]
+            min_axis = ["x_min"]
         elif self.parallel_method == "Y":
-            max_axis = ['y_max']
-            min_axis = ['y_min']
+            max_axis = ["y_max"]
+            min_axis = ["y_min"]
         else:
-            max_axis = ['y_max', 'x_max']
-            min_axis = ['y_min', 'x_min']
+            max_axis = ["y_max", "x_max"]
+            min_axis = ["y_min", "x_min"]
 
         for axis in max_axis:
             if self.write_axis_limits[axis] is not None:
                 if is_last:
-                    self.write_axis_limits[axis] += sign * n_cells*2
+                    self.write_axis_limits[axis] += sign * n_cells * 2
                 else:
                     self.write_axis_limits[axis] += sign * n_cells
         for axis in min_axis:
@@ -2409,20 +2720,24 @@
         my_limits = deepcopy(limits)
         for key in my_limits:
             if my_limits[key] is None:
-                my_limits[key] = float('inf') if 'min' in key else float('-inf')
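+                # Placeholder values so the MPI MIN/MAX reductions skip missing limits.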
float('inf') if 'min' in key else float('-inf') + my_limits[key] = float("inf") if "min" in key else float("-inf") # Perform MPI reduction to find global min and max values global_limits = {} for key in my_limits: - if 'min' in key: - global_limits[key] = self.comm.reduce(my_limits[key], op=MPI.MIN, root=0) - elif 'max' in key: - global_limits[key] = self.comm.reduce(my_limits[key], op=MPI.MAX, root=0) + if "min" in key: + global_limits[key] = self.comm.reduce( + my_limits[key], op=MPI.MIN, root=0 + ) + elif "max" in key: + global_limits[key] = self.comm.reduce( + my_limits[key], op=MPI.MAX, root=0 + ) # Restore None for dimensions that were originally None if self.rank == 0: for key in global_limits: - if global_limits[key] in [float('inf'), float('-inf')]: + if global_limits[key] in [float("inf"), float("-inf")]: global_limits[key] = None return global_limits - def _filter_coordinates_selection(self): """ Use the selection limits to filter time, lev, lat, lon, lon_bnds and lat_bnds. @@ -2508,75 +2826,145 @@ class Nes(object): global_limits = self.get_global_limits(self.read_axis_limits) if self.master: - self._full_time = self._full_time[global_limits["t_min"]:global_limits["t_max"]] - self._full_lev["data"] = self._full_lev["data"][global_limits["z_min"]:global_limits["z_max"]] + self._full_time = self._full_time[ + global_limits["t_min"] : global_limits["t_max"] + ] + self._full_lev["data"] = self._full_lev["data"][ + global_limits["z_min"] : global_limits["z_max"] + ] if len(self._full_lat["data"].shape) == 1: # Regular projection - self._full_lat["data"] = self._full_lat["data"][global_limits["y_min"]:global_limits["y_max"]] + self._full_lat["data"] = self._full_lat["data"][ + global_limits["y_min"] : global_limits["y_max"] + ] if global_limits["x_min"] < 0: - self._full_lon["data"] = concatenate((self._full_lon["data"][global_limits["x_min"]:], - self._full_lon["data"][:global_limits["x_max"]])) + self._full_lon["data"] = concatenate( + ( + self._full_lon["data"][global_limits["x_min"] :], + self._full_lon["data"][: global_limits["x_max"]], + ) + ) self._full_lon["data"][self._full_lon["data"] > 180] -= 360 else: - self._full_lon["data"] = self._full_lon["data"][global_limits["x_min"]:global_limits["x_max"]] + self._full_lon["data"] = self._full_lon["data"][ + global_limits["x_min"] : global_limits["x_max"] + ] if self._full_lat_bnds is not None: - self._full_lat_bnds["data"] = self._full_lat_bnds["data"][global_limits["y_min"]:global_limits["y_max"], :] + self._full_lat_bnds["data"] = self._full_lat_bnds["data"][ + global_limits["y_min"] : global_limits["y_max"], : + ] if self._full_lon_bnds is not None: if global_limits["x_min"] < 0: - self._full_lon_bnds["data"] = concatenate((self._full_lon_bnds["data"][global_limits["x_min"]:, :], - self._full_lon_bnds["data"][:global_limits["x_max"], :])) - - self._full_lon_bnds["data"][self._full_lon_bnds["data"] > 180] -= 360 + self._full_lon_bnds["data"] = concatenate( + ( + self._full_lon_bnds["data"][ + global_limits["x_min"] :, : + ], + self._full_lon_bnds["data"][ + : global_limits["x_max"], : + ], + ) + ) + + self._full_lon_bnds["data"][ + self._full_lon_bnds["data"] > 180 + ] -= 360 else: - self._full_lon_bnds["data"] = self._full_lon_bnds["data"][global_limits["x_min"]:global_limits["x_max"], :] + self._full_lon_bnds["data"] = self._full_lon_bnds["data"][ + global_limits["x_min"] : global_limits["x_max"], : + ] else: # Irregular projections if global_limits["x_min"] < 0: - self._full_lat["data"] = concatenate(( - 
self._full_lat["data"][global_limits["y_min"]:global_limits["y_max"], - global_limits["x_min"]:], - self._full_lat["data"][global_limits["y_min"]:global_limits["y_max"], - :global_limits["x_max"]]), - axis=1) - - self._full_lon["data"] = concatenate(( - self._full_lon["data"][global_limits["y_min"]:global_limits["y_max"], - global_limits["x_min"]:], - self._full_lon["data"][global_limits["y_min"]:global_limits["y_max"], - :global_limits["x_max"]]), - axis=1) + self._full_lat["data"] = concatenate( + ( + self._full_lat["data"][ + global_limits["y_min"] : global_limits["y_max"], + global_limits["x_min"] :, + ], + self._full_lat["data"][ + global_limits["y_min"] : global_limits["y_max"], + : global_limits["x_max"], + ], + ), + axis=1, + ) + + self._full_lon["data"] = concatenate( + ( + self._full_lon["data"][ + global_limits["y_min"] : global_limits["y_max"], + global_limits["x_min"] :, + ], + self._full_lon["data"][ + global_limits["y_min"] : global_limits["y_max"], + : global_limits["x_max"], + ], + ), + axis=1, + ) self._full_lon["data"][self._full_lon["data"] > 180] -= 360 if self._full_lat_bnds is not None: - self._full_lat_bnds["data"] = concatenate(( - self._full_lat_bnds["data"][global_limits["y_min"]:global_limits["y_max"], - global_limits["x_min"]:, :], - self._full_lat_bnds["data"][global_limits["y_min"]:global_limits["y_max"], - :global_limits["x_max"], :]), - axis=1) + self._full_lat_bnds["data"] = concatenate( + ( + self._full_lat_bnds["data"][ + global_limits["y_min"] : global_limits["y_max"], + global_limits["x_min"] :, + :, + ], + self._full_lat_bnds["data"][ + global_limits["y_min"] : global_limits["y_max"], + : global_limits["x_max"], + :, + ], + ), + axis=1, + ) if self._full_lon_bnds is not None: - self._full_lon_bnds["data"] = concatenate(( - self._full_lon_bnds["data"][global_limits["y_min"]:global_limits["y_max"], - global_limits["x_min"]:, :], - self._full_lon_bnds["data"][global_limits["y_min"]:global_limits["y_max"], - :global_limits["x_max"], :]), - axis=1) - self._full_lon_bnds["data"][self._full_lon_bnds["data"] > 180] -= 360 + self._full_lon_bnds["data"] = concatenate( + ( + self._full_lon_bnds["data"][ + global_limits["y_min"] : global_limits["y_max"], + global_limits["x_min"] :, + :, + ], + self._full_lon_bnds["data"][ + global_limits["y_min"] : global_limits["y_max"], + : global_limits["x_max"], + :, + ], + ), + axis=1, + ) + self._full_lon_bnds["data"][ + self._full_lon_bnds["data"] > 180 + ] -= 360 else: - self._full_lat["data"] = self._full_lat["data"][global_limits["y_min"]:global_limits["y_max"], - global_limits["x_min"]:global_limits["x_max"]] + self._full_lat["data"] = self._full_lat["data"][ + global_limits["y_min"] : global_limits["y_max"], + global_limits["x_min"] : global_limits["x_max"], + ] - self._full_lon["data"] = self._full_lon["data"][global_limits["y_min"]:global_limits["y_max"], - global_limits["x_min"]:global_limits["x_max"]] + self._full_lon["data"] = self._full_lon["data"][ + global_limits["y_min"] : global_limits["y_max"], + global_limits["x_min"] : global_limits["x_max"], + ] if self._full_lat_bnds is not None: - self._full_lat_bnds["data"] = self._full_lat_bnds["data"][global_limits["y_min"]:global_limits["y_max"], - global_limits["x_min"]:global_limits["x_max"], :] + self._full_lat_bnds["data"] = self._full_lat_bnds["data"][ + global_limits["y_min"] : global_limits["y_max"], + global_limits["x_min"] : global_limits["x_max"], + :, + ] if self._full_lon_bnds is not None: - self._full_lon_bnds["data"] = 
self._full_lon_bnds["data"][global_limits["y_min"]:global_limits["y_max"], - global_limits["x_min"]:global_limits["x_max"], :] + self._full_lon_bnds["data"] = self._full_lon_bnds["data"][ + global_limits["y_min"] : global_limits["y_max"], + global_limits["x_min"] : global_limits["x_max"], + :, + ] self.hours_start = 0 self.hours_end = 0 @@ -2623,20 +3011,26 @@ class Nes(object): """ full_lat = self.get_full_latitudes() full_lon = self.get_full_longitudes() - idx = {"idx_t_min": self._get_time_id(self.hours_start, first=True), - "idx_t_max": self._get_time_id(self.hours_end, first=False), - "idx_z_min": self.first_level, - "idx_z_max": self.last_level} + idx = { + "idx_t_min": self._get_time_id(self.hours_start, first=True), + "idx_t_max": self._get_time_id(self.hours_end, first=False), + "idx_z_min": self.first_level, + "idx_z_max": self.last_level, + } # Axis Y if self.lat_min is None: idx["idx_y_min"] = 0 else: - idx["idx_y_min"] = self._get_coordinate_id(full_lat["data"], self.lat_min, axis=0) + idx["idx_y_min"] = self._get_coordinate_id( + full_lat["data"], self.lat_min, axis=0 + ) if self.lat_max is None: idx["idx_y_max"] = full_lat["data"].shape[0] else: - idx["idx_y_max"] = self._get_coordinate_id(full_lat["data"], self.lat_max, axis=0) + 1 + idx["idx_y_max"] = ( + self._get_coordinate_id(full_lat["data"], self.lat_max, axis=0) + 1 + ) if idx["idx_y_min"] > idx["idx_y_max"]: idx_aux = copy(idx["idx_y_min"]) @@ -2652,10 +3046,14 @@ class Nes(object): else: axis = 1 if self.lon_min < 0 and full_lon["data"].max() > 180: - aux_xmin = self._get_coordinate_id(full_lon["data"], self.lon_min + 360, axis=axis) - idx["idx_x_min"] = - (self.lon["data"].shape[-1] - aux_xmin) + aux_xmin = self._get_coordinate_id( + full_lon["data"], self.lon_min + 360, axis=axis + ) + idx["idx_x_min"] = -(self.lon["data"].shape[-1] - aux_xmin) else: - idx["idx_x_min"] = self._get_coordinate_id(full_lon["data"], self.lon_min, axis=axis) + idx["idx_x_min"] = self._get_coordinate_id( + full_lon["data"], self.lon_min, axis=axis + ) if self.lon_max is None: idx["idx_x_max"] = full_lon["data"].shape[-1] else: @@ -2663,7 +3061,9 @@ class Nes(object): axis = 0 else: axis = 1 - idx["idx_x_max"] = self._get_coordinate_id(full_lon["data"], self.lon_max, axis=axis) + 1 + idx["idx_x_max"] = ( + self._get_coordinate_id(full_lon["data"], self.lon_max, axis=axis) + 1 + ) if idx["idx_x_min"] > idx["idx_x_max"]: idx_aux = copy(idx["idx_x_min"]) @@ -2681,8 +3081,12 @@ class Nes(object): """ if self.parallel_method == "T": - raise NotImplementedError("Statistics are not implemented on time axis parallelization method.") - aux_time = self.get_full_times()[0].replace(hour=0, minute=0, second=0, microsecond=0) + raise NotImplementedError( + "Statistics are not implemented on time axis parallelization method." 
+ ) + aux_time = self.get_full_times()[0].replace( + hour=0, minute=0, second=0, microsecond=0 + ) self.set_full_times([aux_time]) self.time = [aux_time] @@ -2691,7 +3095,9 @@ class Nes(object): self.load(var_name) aux_data = var_info["data"][-1, :] if len(aux_data.shape) == 3: - aux_data = aux_data.reshape((1, aux_data.shape[0], aux_data.shape[1], aux_data.shape[2])) + aux_data = aux_data.reshape( + (1, aux_data.shape[0], aux_data.shape[1], aux_data.shape[2]) + ) self.variables[var_name]["data"] = aux_data self.hours_start = 0 self.hours_end = 0 @@ -2715,7 +3121,9 @@ class Nes(object): """ if self.parallel_method == "T": - raise NotImplementedError("Statistics are not implemented on time axis parallel method.") + raise NotImplementedError( + "Statistics are not implemented on time axis parallel method." + ) time_interval = self.get_time_interval if type_op == "calendar": aux_time_bounds = [] @@ -2726,39 +3134,69 @@ class Nes(object): self.load(var_name) stat_data = None for day in unique(day_list): - idx_first = next(i for i, val in enumerate(day_list, 0) if val == day) - idx_last = len(day_list) - next(i for i, val in enumerate(reversed(day_list), 1) if val == day) + idx_first = next( + i for i, val in enumerate(day_list, 0) if val == day + ) + idx_last = len(day_list) - next( + i for i, val in enumerate(reversed(day_list), 1) if val == day + ) if idx_first != idx_last: # To avoid single time step statistic if idx_last != len(day_list): if op == "mean": - data_aux = var_info["data"][idx_first:idx_last + 1, :, :, :].mean(axis=0) + data_aux = var_info["data"][ + idx_first : idx_last + 1, :, :, : + ].mean(axis=0) elif op == "max": - data_aux = var_info["data"][idx_first:idx_last + 1, :, :, :].max(axis=0) + data_aux = var_info["data"][ + idx_first : idx_last + 1, :, :, : + ].max(axis=0) elif op == "min": - data_aux = var_info["data"][idx_first:idx_last + 1, :, :, :].min(axis=0) + data_aux = var_info["data"][ + idx_first : idx_last + 1, :, :, : + ].min(axis=0) else: - raise NotImplementedError(f"Statistic operation '{op}' is not implemented.") - aux_time_bounds.append([self.time[idx_first], self.time[idx_last]]) + raise NotImplementedError( + f"Statistic operation '{op}' is not implemented." + ) + aux_time_bounds.append( + [self.time[idx_first], self.time[idx_last]] + ) else: if op == "mean": - data_aux = var_info["data"][idx_first:, :, :, :].mean(axis=0) + data_aux = var_info["data"][idx_first:, :, :, :].mean( + axis=0 + ) elif op == "max": - data_aux = var_info["data"][idx_first:, :, :, :].max(axis=0) + data_aux = var_info["data"][idx_first:, :, :, :].max( + axis=0 + ) elif op == "min": - data_aux = var_info["data"][idx_first:, :, :, :].min(axis=0) + data_aux = var_info["data"][idx_first:, :, :, :].min( + axis=0 + ) else: - raise NotImplementedError(f"Statistic operation '{op}' is not implemented.") - aux_time_bounds.append([self.time[idx_first], self.time[-1]]) - - data_aux = data_aux.reshape((1, data_aux.shape[0], data_aux.shape[1], data_aux.shape[2])) - aux_time.append(self.time[idx_first].replace(hour=0, minute=0, second=0)) + raise NotImplementedError( + f"Statistic operation '{op}' is not implemented." 
+ ) + aux_time_bounds.append( + [self.time[idx_first], self.time[-1]] + ) + + data_aux = data_aux.reshape( + (1, data_aux.shape[0], data_aux.shape[1], data_aux.shape[2]) + ) + aux_time.append( + self.time[idx_first].replace(hour=0, minute=0, second=0) + ) # Append over time dimension if stat_data is None: stat_data = data_aux.copy() else: stat_data = vstack([stat_data, data_aux]) self.variables[var_name]["data"] = stat_data - self.variables[var_name]["cell_methods"] = "time: {0} (interval: {1}hr)".format(op, time_interval) + self.variables[var_name]["cell_methods"] = ( + "time: {0} (interval: {1}hr)".format(op, time_interval) + ) self.time = aux_time self.set_full_times(self.time) @@ -2775,11 +3213,17 @@ class Nes(object): elif op == "min": aux_data = var_info["data"].min(axis=0) else: - raise NotImplementedError(f"Statistic operation '{op}' is not implemented.") + raise NotImplementedError( + f"Statistic operation '{op}' is not implemented." + ) if len(aux_data.shape) == 3: - aux_data = aux_data.reshape((1, aux_data.shape[0], aux_data.shape[1], aux_data.shape[2])) + aux_data = aux_data.reshape( + (1, aux_data.shape[0], aux_data.shape[1], aux_data.shape[2]) + ) self.variables[var_name]["data"] = aux_data - self.variables[var_name]["cell_methods"] = "time: {0} (interval: {1}hr)".format(op, time_interval) + self.variables[var_name]["cell_methods"] = ( + "time: {0} (interval: {1}hr)".format(op, time_interval) + ) aux_time = self.time[0].replace(hour=0, minute=0, second=0, microsecond=0) aux_time_bounds = [[self.time[0], self.time[-1]]] @@ -2799,11 +3243,17 @@ class Nes(object): elif op == "min": aux_data = var_info["data"][1:, :].min(axis=0) else: - raise NotImplementedError(f"Statistic operation '{op}' is not implemented.") + raise NotImplementedError( + f"Statistic operation '{op}' is not implemented." + ) if len(aux_data.shape) == 3: - aux_data = aux_data.reshape((1, aux_data.shape[0], aux_data.shape[1], aux_data.shape[2])) + aux_data = aux_data.reshape( + (1, aux_data.shape[0], aux_data.shape[1], aux_data.shape[2]) + ) self.variables[var_name]["data"] = aux_data - self.variables[var_name]["cell_methods"] = "time: {0} (interval: {1}hr)".format(op, time_interval) + self.variables[var_name]["cell_methods"] = ( + "time: {0} (interval: {1}hr)".format(op, time_interval) + ) full_time = self.get_full_times() aux_time = full_time[1].replace(hour=0, minute=0, second=0, microsecond=0) aux_time_bounds = [[full_time[1], full_time[-1]]] @@ -2812,7 +3262,9 @@ class Nes(object): self.set_time_bnds(aux_time_bounds) else: - raise NotImplementedError(f"Statistic operation type '{type_op}' is not implemented.") + raise NotImplementedError( + f"Statistic operation type '{type_op}' is not implemented." 
+                )
 
         self.hours_start = 0
         self.hours_end = 0
 
@@ -2820,7 +3272,6 @@ class Nes(object):
 
     @staticmethod
     def _get_axis_index_(axis):
-
         if axis == "T":
             value = 0
         elif axis == "Z":
@@ -2835,18 +3286,22 @@
         return value
 
     def sum_axis(self, axis="Z"):
-
         if self.parallel_method == axis:
             raise NotImplementedError(
-                f"It is not possible to sum the axis with it is parallelized '{self.parallel_method}'")
+                f"It is not possible to sum along the axis used for parallelization ('{self.parallel_method}')"
+            )
 
         for var_name, var_info in self.variables.items():
             if var_info["data"] is not None:
                 self.variables[var_name]["data"] = self.variables[var_name]["data"].sum(
-                    axis=self._get_axis_index_(axis), keepdims=True)
+                    axis=self._get_axis_index_(axis), keepdims=True
+                )
                 if axis == "T":
-                    self.variables[var_name]["cell_methods"] = "time: sum (interval: {0}hr)".format(
-                        (self.time[-1] - self.time[0]).total_seconds() // 3600)
+                    self.variables[var_name]["cell_methods"] = (
+                        "time: sum (interval: {0}hr)".format(
+                            (self.time[-1] - self.time[0]).total_seconds() // 3600
+                        )
+                    )
 
         if axis == "T":
             self.set_time_bnds([self.time[0], self.time[-1]])
@@ -2894,7 +3349,9 @@ class Nes(object):
         """
 
         if self.parallel_method == "T":
-            raise NotImplementedError("The rolling mean cannot be calculated using the time axis parallel method.")
+            raise NotImplementedError(
+                "The rolling mean cannot be calculated using the time axis parallel method."
+            )
 
         aux_nessy = self.copy(copy_vars=False)
         aux_nessy.set_communicator(self.comm)
@@ -2915,11 +3372,13 @@
                 # Initialise array
                 aux_nessy.variables[var_name] = {}
                 aux_nessy.variables[var_name]["data"] = empty(shape=nessy_shape)
-                aux_nessy.variables[var_name]["dimensions"] = deepcopy(self.variables[var_name]["dimensions"])
+                aux_nessy.variables[var_name]["dimensions"] = deepcopy(
+                    self.variables[var_name]["dimensions"]
+                )
 
             for curr_time in self.time:
                 # Get previous time given a set of hours
-                prev_time = curr_time - timedelta(hours=(hours-1))
+                prev_time = curr_time - timedelta(hours=(hours - 1))
 
                 # Get time indices
                 curr_time_id = self.find_time_id(curr_time)
@@ -2929,8 +3388,11 @@ class Nes(object):
                 if prev_time_id is not None:
                     if self.info:
                         print(f"Calculating mean between {prev_time} and {curr_time}.")
-                    aux_nessy.variables[var_name]["data"][curr_time_id, :, :, :] = self.variables[var_name]["data"][
-                        prev_time_id:curr_time_id, :, :, :].mean(axis=0, keepdims=True)
+                    aux_nessy.variables[var_name]["data"][curr_time_id, :, :, :] = (
+                        self.variables[var_name]["data"][
+                            prev_time_id:curr_time_id, :, :, :
+                        ].mean(axis=0, keepdims=True)
+                    )
                 # Fill with nan if previous time is not available
                 else:
                     if self.info:
                         msg += f"because data for {prev_time} is not available."
                         print(msg)
                     aux_nessy.variables[var_name]["data"][curr_time_id, :, :, :] = full(
-                        shape=(1, nessy_shape[1], nessy_shape[2], nessy_shape[3]), fill_value=nan)
+                        shape=(1, nessy_shape[1], nessy_shape[2], nessy_shape[3]),
+                        fill_value=nan,
+                    )
 
         return aux_nessy
 
@@ -2973,20 +3437,31 @@
             t_min, t_max, z_min, z_max, y_min, y_max, x_min and x_max.
""" - axis_limits = {"x_min": None, "x_max": None, - "y_min": None, "y_max": None, - "z_min": None, "z_max": None, - "t_min": None, "t_max": None} + axis_limits = { + "x_min": None, + "x_max": None, + "y_min": None, + "y_max": None, + "z_min": None, + "z_max": None, + "t_min": None, + "t_max": None, + } idx = self._get_idx_intervals() if self.parallel_method == "Y": y_len = idx["idx_y_max"] - idx["idx_y_min"] if y_len < self.size: - raise IndexError("More processors (size={0}) selected than Y elements (size={1})".format( - self.size, y_len)) + raise IndexError( + "More processors (size={0}) selected than Y elements (size={1})".format( + self.size, y_len + ) + ) axis_limits["y_min"] = ((y_len // self.size) * self.rank) + idx["idx_y_min"] if self.rank + 1 < self.size: - axis_limits["y_max"] = ((y_len // self.size) * (self.rank + 1)) + idx["idx_y_min"] + axis_limits["y_max"] = ((y_len // self.size) * (self.rank + 1)) + idx[ + "idx_y_min" + ] else: axis_limits["y_max"] = idx["idx_y_max"] @@ -3000,11 +3475,16 @@ class Nes(object): elif self.parallel_method == "X": x_len = idx["idx_x_max"] - idx["idx_x_min"] if x_len < self.size: - raise IndexError("More processors (size={0}) selected than X elements (size={1})".format( - self.size, x_len)) + raise IndexError( + "More processors (size={0}) selected than X elements (size={1})".format( + self.size, x_len + ) + ) axis_limits["x_min"] = ((x_len // self.size) * self.rank) + idx["idx_x_min"] if self.rank + 1 < self.size: - axis_limits["x_max"] = ((x_len // self.size) * (self.rank + 1)) + idx["idx_x_min"] + axis_limits["x_max"] = ((x_len // self.size) * (self.rank + 1)) + idx[ + "idx_x_min" + ] else: axis_limits["x_max"] = idx["idx_x_max"] @@ -3018,11 +3498,16 @@ class Nes(object): elif self.parallel_method == "T": t_len = idx["idx_t_max"] - idx["idx_t_min"] if t_len < self.size: - raise IndexError("More processors (size={0}) selected than T elements (size={1})".format( - self.size, t_len)) + raise IndexError( + "More processors (size={0}) selected than T elements (size={1})".format( + self.size, t_len + ) + ) axis_limits["t_min"] = ((t_len // self.size) * self.rank) + idx["idx_t_min"] if self.rank + 1 < self.size: - axis_limits["t_max"] = ((t_len // self.size) * (self.rank + 1)) + idx["idx_t_min"] + axis_limits["t_max"] = ((t_len // self.size) * (self.rank + 1)) + idx[ + "idx_t_min" + ] # Non parallel filters axis_limits["y_min"] = idx["idx_y_min"] @@ -3032,8 +3517,11 @@ class Nes(object): axis_limits["x_max"] = idx["idx_x_max"] else: - raise NotImplementedError("Parallel method '{meth}' is not implemented. Use one of these: {accept}".format( - meth=self.parallel_method, accept=["X", "Y", "T"])) + raise NotImplementedError( + "Parallel method '{meth}' is not implemented. 
Use one of these: {accept}".format( + meth=self.parallel_method, accept=["X", "Y", "T"] + ) + ) # Vertical levels selection: axis_limits["z_min"] = self.first_level @@ -3063,8 +3551,11 @@ class Nes(object): if self.parallel_method == "Y": len_to_split = idx["idx_y_max"] - idx["idx_y_min"] if len_to_split < self.size: - raise IndexError("More processors (size={0}) selected than Y elements (size={1})".format( - self.size, len_to_split)) + raise IndexError( + "More processors (size={0}) selected than Y elements (size={1})".format( + self.size, len_to_split + ) + ) min_axis = "y_min" max_axis = "y_max" to_add = idx["idx_y_min"] @@ -3072,31 +3563,45 @@ class Nes(object): elif self.parallel_method == "X": len_to_split = idx["idx_x_max"] - idx["idx_x_min"] if len_to_split < self.size: - raise IndexError("More processors (size={0}) selected than X elements (size={1})".format( - self.size, len_to_split)) + raise IndexError( + "More processors (size={0}) selected than X elements (size={1})".format( + self.size, len_to_split + ) + ) min_axis = "x_min" max_axis = "x_max" to_add = idx["idx_x_min"] elif self.parallel_method == "T": len_to_split = idx["idx_t_max"] - idx["idx_t_min"] if len_to_split < self.size: - raise IndexError(f"More processors (size={self.size}) selected than T elements (size={len_to_split})") + raise IndexError( + f"More processors (size={self.size}) selected than T elements (size={len_to_split})" + ) min_axis = "t_min" max_axis = "t_max" to_add = idx["idx_t_min"] else: - raise NotImplementedError("Parallel method '{meth}' is not implemented. Use one of these: {accept}".format( - meth=self.parallel_method, accept=["X", "Y", "T"])) + raise NotImplementedError( + "Parallel method '{meth}' is not implemented. Use one of these: {accept}".format( + meth=self.parallel_method, accept=["X", "Y", "T"] + ) + ) procs_len = len_to_split // self.size procs_rows_extended = len_to_split - (procs_len * self.size) rows_sum = 0 for proc in range(self.size): - fid_dist[proc] = {"x_min": 0, "x_max": None, - "y_min": 0, "y_max": None, - "z_min": 0, "z_max": None, - "t_min": 0, "t_max": None} + fid_dist[proc] = { + "x_min": 0, + "x_max": None, + "y_min": 0, + "y_max": None, + "z_min": 0, + "z_max": None, + "t_min": 0, + "t_max": None, + } if proc < procs_rows_extended: aux_rows = procs_len + 1 else: @@ -3221,10 +3726,18 @@ class Nes(object): """ if self.size == 1: - netcdf = Dataset(self.__ini_path, format="NETCDF4", mode=mode, parallel=False) + netcdf = Dataset( + self.__ini_path, format="NETCDF4", mode=mode, parallel=False + ) else: - netcdf = Dataset(self.__ini_path, format="NETCDF4", mode=mode, parallel=True, comm=self.comm, - info=MPI.Info()) + netcdf = Dataset( + self.__ini_path, + format="NETCDF4", + mode=mode, + parallel=True, + comm=self.comm, + info=MPI.Info(), + ) self.dataset = netcdf return netcdf @@ -3263,7 +3776,11 @@ class Nes(object): """ start_date_str = time.units.split("since")[1].lstrip() - start_date = datetime(int(start_date_str[0:4]), int(start_date_str[5:7]), int(start_date_str[8:10])) + start_date = datetime( + int(start_date_str[0:4]), + int(start_date_str[5:7]), + int(start_date_str[8:10]), + ) new_time_deltas = [] @@ -3330,7 +3847,13 @@ class Nes(object): if "h @" in t_units: t_units = "hours since {0}-{1}-{2} {3}:{4}:{5} UTC".format( - t_units[4:8], t_units[8:10], t_units[10:12], t_units[13:15], t_units[15:17], t_units[17:-4]) + t_units[4:8], + t_units[8:10], + t_units[10:12], + t_units[13:15], + t_units[15:17], + t_units[17:-4], + ) return t_units @@ -3380,7 +3903,9 @@ 
class Nes(object): # Accepted name options for the time variable accepted_time_names = ["TIME", "valid_time"] # Get name of the time variable of the dataset - time_var_name = list(set(nc_variable_names).intersection(set(accepted_time_names)))[0] + time_var_name = list( + set(nc_variable_names).intersection(set(accepted_time_names)) + )[0] nc_var = self.dataset.variables[time_var_name] time_data, units, calendar = self.__parse_time(nc_var) @@ -3391,7 +3916,16 @@ class Nes(object): self._climatology = True self._climatology_var_name = nc_var.climatology time = num2date(time_data, units, calendar=calendar) - time = [datetime(year=dt.year, month=dt.month, day=dt.day, hour=dt.hour, minute=dt.minute) for dt in time] + time = [ + datetime( + year=dt.year, + month=dt.month, + day=dt.day, + hour=dt.hour, + minute=dt.minute, + ) + for dt in time + ] else: time = None # Free the time variable @@ -3422,8 +3956,11 @@ class Nes(object): nc_var = self.dataset.variables[self._climatology_var_name] else: nc_var = self.dataset.variables["time_bnds"] - time_bnds = num2date(nc_var[:], self.__parse_time_unit(time.units), - calendar=time.calendar).tolist() + time_bnds = num2date( + nc_var[:], + self.__parse_time_unit(time.units), + calendar=time.calendar, + ).tolist() # Iterate over each inner list for inner_list in time_bnds: # Create a new list to store datetime objects @@ -3431,7 +3968,13 @@ class Nes(object): # Iterate over datetime objects within each inner list for dt in inner_list: # Access year, month, day, hour, and minute attributes of datetime objects - new_dt = datetime(year=dt.year, month=dt.month, day=dt.day, hour=dt.hour, minute=dt.minute) + new_dt = datetime( + year=dt.year, + month=dt.month, + day=dt.day, + hour=dt.hour, + minute=dt.minute, + ) # Append the new datetime object to the new inner list new_inner_list.append(new_dt) # Replace the old inner list with the new one @@ -3467,12 +4010,20 @@ class Nes(object): if not create_nes: if self.master: if "lat_bnds" in self.dataset.variables.keys(): - lat_bnds = {"data": self._unmask_array(self.dataset.variables["lat_bnds"][:])} + lat_bnds = { + "data": self._unmask_array( + self.dataset.variables["lat_bnds"][:] + ) + } else: lat_bnds = None if "lon_bnds" in self.dataset.variables.keys(): - lon_bnds = {"data": self._unmask_array(self.dataset.variables["lon_bnds"][:])} + lon_bnds = { + "data": self._unmask_array( + self.dataset.variables["lon_bnds"][:] + ) + } else: lon_bnds = None else: @@ -3506,7 +4057,9 @@ class Nes(object): if not create_nes: if "cell_area" in self.dataset.variables.keys(): c_measures["cell_area"] = {} - c_measures["cell_area"]["data"] = self._unmask_array(self.dataset.variables["cell_area"][:]) + c_measures["cell_area"]["data"] = self._unmask_array( + self.dataset.variables["cell_area"][:] + ) c_measures = self.comm.bcast(c_measures, root=0) self.free_vars(["cell_area"]) @@ -3534,7 +4087,9 @@ class Nes(object): possible_names = [possible_names] try: - dimension_name = set(possible_names).intersection(set(self.variables.keys())).pop() + dimension_name = ( + set(possible_names).intersection(set(self.variables.keys())).pop() + ) if self.master: nc_var = self.variables[dimension_name].copy() nc_var["data"] = self.dataset.variables[dimension_name][:] @@ -3546,8 +4101,7 @@ class Nes(object): self.free_vars(dimension_name) except KeyError: if self.master: - nc_var = {"data": array([0]), - "units": ""} + nc_var = {"data": array([0]), "units": ""} else: nc_var = None @@ -3585,43 +4139,78 @@ class Nes(object): if coordinate_axis == "Y": 
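+            # A 1-D coordinate (regular projection) is sliced along its own axis only;
+            # a 2-D coordinate (irregular projections) must be sliced along both the
+            # Y and X read windows.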
if coordinate_len == 1: - values["data"] = values["data"][self.read_axis_limits["y_min"]:self.read_axis_limits["y_max"]] + values["data"] = values["data"][ + self.read_axis_limits["y_min"] : self.read_axis_limits["y_max"] + ] elif coordinate_len == 2: - values["data"] = values["data"][self.read_axis_limits["y_min"]:self.read_axis_limits["y_max"], - self.read_axis_limits["x_min"]:self.read_axis_limits["x_max"]] + values["data"] = values["data"][ + self.read_axis_limits["y_min"] : self.read_axis_limits["y_max"], + self.read_axis_limits["x_min"] : self.read_axis_limits["x_max"], + ] else: - raise NotImplementedError("The coordinate has wrong dimensions: {dim}".format( - dim=values["data"].shape)) + raise NotImplementedError( + "The coordinate has wrong dimensions: {dim}".format( + dim=values["data"].shape + ) + ) elif coordinate_axis == "X": if coordinate_len == 1: if self.read_axis_limits["x_min"] < 0: # Negative longitudes - values["data"] = concatenate((values["data"][self.read_axis_limits["x_min"]:], - values["data"][:self.read_axis_limits["x_max"]])) + values["data"] = concatenate( + ( + values["data"][self.read_axis_limits["x_min"] :], + values["data"][: self.read_axis_limits["x_max"]], + ) + ) values["data"][values["data"] > 180] -= 360 else: - values["data"] = values["data"][self.read_axis_limits["x_min"]:self.read_axis_limits["x_max"]] + values["data"] = values["data"][ + self.read_axis_limits["x_min"] : self.read_axis_limits["x_max"] + ] elif coordinate_len == 2: if self.read_axis_limits["x_min"] < 0: # Negative longitudes values["data"] = concatenate( - (values["data"][self.read_axis_limits["y_min"]:self.read_axis_limits["y_max"], - self.read_axis_limits["x_min"]:], - values["data"][self.read_axis_limits["y_min"]:self.read_axis_limits["y_max"], - :self.read_axis_limits["x_max"]]), axis=1) + ( + values["data"][ + self.read_axis_limits["y_min"] : self.read_axis_limits[ + "y_max" + ], + self.read_axis_limits["x_min"] :, + ], + values["data"][ + self.read_axis_limits["y_min"] : self.read_axis_limits[ + "y_max" + ], + : self.read_axis_limits["x_max"], + ], + ), + axis=1, + ) values["data"][values["data"] > 180] -= 360 else: - values["data"] = values["data"][self.read_axis_limits["y_min"]:self.read_axis_limits["y_max"], - self.read_axis_limits["x_min"]:self.read_axis_limits["x_max"]] + values["data"] = values["data"][ + self.read_axis_limits["y_min"] : self.read_axis_limits["y_max"], + self.read_axis_limits["x_min"] : self.read_axis_limits["x_max"], + ] else: - raise NotImplementedError("The coordinate has wrong dimensions: {dim}".format( - dim=values["data"].shape)) + raise NotImplementedError( + "The coordinate has wrong dimensions: {dim}".format( + dim=values["data"].shape + ) + ) elif coordinate_axis == "Z": if coordinate_len == 1: - values["data"] = values["data"][self.read_axis_limits["z_min"]:self.read_axis_limits["z_max"]] + values["data"] = values["data"][ + self.read_axis_limits["z_min"] : self.read_axis_limits["z_max"] + ] else: - raise NotImplementedError("The coordinate has wrong dimensions: {dim}".format( - dim=values["data"].shape)) + raise NotImplementedError( + "The coordinate has wrong dimensions: {dim}".format( + dim=values["data"].shape + ) + ) return values @@ -3646,18 +4235,24 @@ class Nes(object): cell_measures_values = {} for cell_measures_var in cell_measures_info.keys(): - values = deepcopy(cell_measures_info[cell_measures_var]) coordinate_len = len(values["data"].shape) if coordinate_len == 1: - values["data"] = 
values["data"][self.read_axis_limits["x_min"]:self.read_axis_limits["x_max"]] + values["data"] = values["data"][ + self.read_axis_limits["x_min"] : self.read_axis_limits["x_max"] + ] elif coordinate_len == 2: - values["data"] = values["data"][self.read_axis_limits["y_min"]:self.read_axis_limits["y_max"], - self.read_axis_limits["x_min"]:self.read_axis_limits["x_max"]] + values["data"] = values["data"][ + self.read_axis_limits["y_min"] : self.read_axis_limits["y_max"], + self.read_axis_limits["x_min"] : self.read_axis_limits["x_max"], + ] else: - raise NotImplementedError("The coordinate has wrong dimensions: {dim}".format( - dim=values["data"].shape)) + raise NotImplementedError( + "The coordinate has wrong dimensions: {dim}".format( + dim=values["data"].shape + ) + ) cell_measures_values[cell_measures_var] = values @@ -3694,7 +4289,12 @@ class Nes(object): # Avoid some attributes for attrname in var_info.ncattrs(): - if attrname not in ["missing_value", "_FillValue", "add_offset", "scale_factor"]: + if attrname not in [ + "missing_value", + "_FillValue", + "add_offset", + "scale_factor", + ]: value = getattr(var_info, attrname) if str(value) in ["unitless", "-"]: value = "" @@ -3728,118 +4328,264 @@ class Nes(object): if len(var_dims) < 2: data = nc_var[:] elif len(var_dims) == 2: - data = concatenate(( - nc_var[self.read_axis_limits["y_min"]:self.read_axis_limits["y_max"], - self.read_axis_limits["x_min"]:], - nc_var[self.read_axis_limits["y_min"]:self.read_axis_limits["y_max"], - :self.read_axis_limits["x_max"]]), axis=1) + data = concatenate( + ( + nc_var[ + self.read_axis_limits["y_min"] : self.read_axis_limits[ + "y_max" + ], + self.read_axis_limits["x_min"] :, + ], + nc_var[ + self.read_axis_limits["y_min"] : self.read_axis_limits[ + "y_max" + ], + : self.read_axis_limits["x_max"], + ], + ), + axis=1, + ) data = data.reshape(1, 1, data.shape[-2], data.shape[-1]) elif len(var_dims) == 3: if "strlen" in var_dims: - data = concatenate(( - nc_var[self.read_axis_limits["y_min"]:self.read_axis_limits["y_max"], - self.read_axis_limits["x_min"]:, :], - nc_var[self.read_axis_limits["y_min"]:self.read_axis_limits["y_max"], - :self.read_axis_limits["x_max"], :]), axis=1) + data = concatenate( + ( + nc_var[ + self.read_axis_limits["y_min"] : self.read_axis_limits[ + "y_max" + ], + self.read_axis_limits["x_min"] :, + :, + ], + nc_var[ + self.read_axis_limits["y_min"] : self.read_axis_limits[ + "y_max" + ], + : self.read_axis_limits["x_max"], + :, + ], + ), + axis=1, + ) data_aux = empty(shape=(data.shape[0], data.shape[1]), dtype=object) for lat_n in range(data.shape[0]): for lon_n in range(data.shape[1]): data_aux[lat_n, lon_n] = "".join( - data[lat_n, lon_n].tobytes().decode("ascii").replace("\x00", "")) - data = data_aux.reshape((1, 1, data_aux.shape[-2], data_aux.shape[-1])) + data[lat_n, lon_n] + .tobytes() + .decode("ascii") + .replace("\x00", "") + ) + data = data_aux.reshape( + (1, 1, data_aux.shape[-2], data_aux.shape[-1]) + ) else: - data = concatenate(( - nc_var[self.read_axis_limits["t_min"]:self.read_axis_limits["t_max"], - self.read_axis_limits["y_min"]:self.read_axis_limits["y_max"], - self.read_axis_limits["x_min"]:], - nc_var[self.read_axis_limits["t_min"]:self.read_axis_limits["t_max"], - self.read_axis_limits["y_min"]:self.read_axis_limits["y_max"], - :self.read_axis_limits["x_max"]]), axis=2) - data = data.reshape(data.shape[-3], 1, data.shape[-2], data.shape[-1]) + data = concatenate( + ( + nc_var[ + self.read_axis_limits["t_min"] : self.read_axis_limits[ + "t_max" + ], + 
self.read_axis_limits["y_min"] : self.read_axis_limits[ + "y_max" + ], + self.read_axis_limits["x_min"] :, + ], + nc_var[ + self.read_axis_limits["t_min"] : self.read_axis_limits[ + "t_max" + ], + self.read_axis_limits["y_min"] : self.read_axis_limits[ + "y_max" + ], + : self.read_axis_limits["x_max"], + ], + ), + axis=2, + ) + data = data.reshape( + data.shape[-3], 1, data.shape[-2], data.shape[-1] + ) elif len(var_dims) == 4: - data = concatenate(( - nc_var[self.read_axis_limits["t_min"]:self.read_axis_limits["t_max"], - self.read_axis_limits["z_min"]:self.read_axis_limits["z_max"], - self.read_axis_limits["y_min"]:self.read_axis_limits["y_max"], - self.read_axis_limits["x_min"]:], - nc_var[self.read_axis_limits["t_min"]:self.read_axis_limits["t_max"], - self.read_axis_limits["z_min"]:self.read_axis_limits["z_max"], - self.read_axis_limits["y_min"]:self.read_axis_limits["y_max"], - :self.read_axis_limits["x_max"]]), axis=3) + data = concatenate( + ( + nc_var[ + self.read_axis_limits["t_min"] : self.read_axis_limits[ + "t_max" + ], + self.read_axis_limits["z_min"] : self.read_axis_limits[ + "z_max" + ], + self.read_axis_limits["y_min"] : self.read_axis_limits[ + "y_max" + ], + self.read_axis_limits["x_min"] :, + ], + nc_var[ + self.read_axis_limits["t_min"] : self.read_axis_limits[ + "t_max" + ], + self.read_axis_limits["z_min"] : self.read_axis_limits[ + "z_max" + ], + self.read_axis_limits["y_min"] : self.read_axis_limits[ + "y_max" + ], + : self.read_axis_limits["x_max"], + ], + ), + axis=3, + ) elif len(var_dims) == 5: if "strlen" in var_dims: - data = concatenate(( - nc_var[self.read_axis_limits["t_min"]:self.read_axis_limits["t_max"], - self.read_axis_limits["z_min"]:self.read_axis_limits["z_max"], - self.read_axis_limits["y_min"]:self.read_axis_limits["y_max"], - self.read_axis_limits["x_min"]:, :], - nc_var[self.read_axis_limits["t_min"]:self.read_axis_limits["t_max"], - self.read_axis_limits["z_min"]:self.read_axis_limits["z_max"], - self.read_axis_limits["y_min"]:self.read_axis_limits["y_max"], - :self.read_axis_limits["x_max"], :]), axis=3) - data_aux = empty(shape=(data.shape[0], data.shape[1], data.shape[2], data.shape[3]), dtype=object) + data = concatenate( + ( + nc_var[ + self.read_axis_limits["t_min"] : self.read_axis_limits[ + "t_max" + ], + self.read_axis_limits["z_min"] : self.read_axis_limits[ + "z_max" + ], + self.read_axis_limits["y_min"] : self.read_axis_limits[ + "y_max" + ], + self.read_axis_limits["x_min"] :, + :, + ], + nc_var[ + self.read_axis_limits["t_min"] : self.read_axis_limits[ + "t_max" + ], + self.read_axis_limits["z_min"] : self.read_axis_limits[ + "z_max" + ], + self.read_axis_limits["y_min"] : self.read_axis_limits[ + "y_max" + ], + : self.read_axis_limits["x_max"], + :, + ], + ), + axis=3, + ) + data_aux = empty( + shape=( + data.shape[0], + data.shape[1], + data.shape[2], + data.shape[3], + ), + dtype=object, + ) for time_n in range(data.shape[0]): for lev_n in range(data.shape[1]): for lat_n in range(data.shape[2]): for lon_n in range(data.shape[3]): data_aux[time_n, lev_n, lat_n, lon_n] = "".join( - data[time_n, lev_n, lat_n, lon_n].tobytes().decode("ascii").replace("\x00", "")) + data[time_n, lev_n, lat_n, lon_n] + .tobytes() + .decode("ascii") + .replace("\x00", "") + ) data = data_aux else: - raise NotImplementedError("Error with {0}. Only can be read netCDF with 4 dimensions or less".format( - var_name)) + raise NotImplementedError( + "Error with {0}. 
Only can be read netCDF with 4 dimensions or less".format( + var_name + ) + ) else: - raise NotImplementedError("Error with {0}. Only can be read netCDF with 4 dimensions or less".format( - var_name)) + raise NotImplementedError( + "Error with {0}. Only can be read netCDF with 4 dimensions or less".format( + var_name + ) + ) else: if len(var_dims) < 2: data = nc_var[:] elif len(var_dims) == 2: - data = nc_var[self.read_axis_limits["y_min"]:self.read_axis_limits["y_max"], - self.read_axis_limits["x_min"]:self.read_axis_limits["x_max"]] + data = nc_var[ + self.read_axis_limits["y_min"] : self.read_axis_limits["y_max"], + self.read_axis_limits["x_min"] : self.read_axis_limits["x_max"], + ] data = data.reshape(1, 1, data.shape[-2], data.shape[-1]) elif len(var_dims) == 3: if "strlen" in var_dims: - data = nc_var[self.read_axis_limits["y_min"]:self.read_axis_limits["y_max"], - self.read_axis_limits["x_min"]:self.read_axis_limits["x_max"], - :] + data = nc_var[ + self.read_axis_limits["y_min"] : self.read_axis_limits["y_max"], + self.read_axis_limits["x_min"] : self.read_axis_limits["x_max"], + :, + ] data_aux = empty(shape=(data.shape[0], data.shape[1]), dtype=object) for lat_n in range(data.shape[0]): for lon_n in range(data.shape[1]): data_aux[lat_n, lon_n] = "".join( - data[lat_n, lon_n].tobytes().decode("ascii").replace("\x00", "")) - data = data_aux.reshape((1, 1, data_aux.shape[-2], data_aux.shape[-1])) + data[lat_n, lon_n] + .tobytes() + .decode("ascii") + .replace("\x00", "") + ) + data = data_aux.reshape( + (1, 1, data_aux.shape[-2], data_aux.shape[-1]) + ) else: - data = nc_var[self.read_axis_limits["t_min"]:self.read_axis_limits["t_max"], - self.read_axis_limits["y_min"]:self.read_axis_limits["y_max"], - self.read_axis_limits["x_min"]:self.read_axis_limits["x_max"]] - data = data.reshape(data.shape[-3], 1, data.shape[-2], data.shape[-1]) + data = nc_var[ + self.read_axis_limits["t_min"] : self.read_axis_limits["t_max"], + self.read_axis_limits["y_min"] : self.read_axis_limits["y_max"], + self.read_axis_limits["x_min"] : self.read_axis_limits["x_max"], + ] + data = data.reshape( + data.shape[-3], 1, data.shape[-2], data.shape[-1] + ) elif len(var_dims) == 4: - data = nc_var[self.read_axis_limits["t_min"]:self.read_axis_limits["t_max"], - self.read_axis_limits["z_min"]:self.read_axis_limits["z_max"], - self.read_axis_limits["y_min"]:self.read_axis_limits["y_max"], - self.read_axis_limits["x_min"]:self.read_axis_limits["x_max"]] + data = nc_var[ + self.read_axis_limits["t_min"] : self.read_axis_limits["t_max"], + self.read_axis_limits["z_min"] : self.read_axis_limits["z_max"], + self.read_axis_limits["y_min"] : self.read_axis_limits["y_max"], + self.read_axis_limits["x_min"] : self.read_axis_limits["x_max"], + ] elif len(var_dims) == 5: if "strlen" in var_dims: - data = nc_var[self.read_axis_limits["t_min"]:self.read_axis_limits["t_max"], - self.read_axis_limits["z_min"]:self.read_axis_limits["z_max"], - self.read_axis_limits["y_min"]:self.read_axis_limits["y_max"], - self.read_axis_limits["x_min"]:self.read_axis_limits["x_max"], - :] - data_aux = empty(shape=(data.shape[0], data.shape[1], data.shape[2], data.shape[3]), dtype=object) + data = nc_var[ + self.read_axis_limits["t_min"] : self.read_axis_limits["t_max"], + self.read_axis_limits["z_min"] : self.read_axis_limits["z_max"], + self.read_axis_limits["y_min"] : self.read_axis_limits["y_max"], + self.read_axis_limits["x_min"] : self.read_axis_limits["x_max"], + :, + ] + data_aux = empty( + shape=( + data.shape[0], + data.shape[1], + 
data.shape[2],
+                            data.shape[3],
+                        ),
+                        dtype=object,
+                    )
                     for time_n in range(data.shape[0]):
                         for lev_n in range(data.shape[1]):
                             for lat_n in range(data.shape[2]):
                                 for lon_n in range(data.shape[3]):
                                     data_aux[time_n, lev_n, lat_n, lon_n] = "".join(
-                                        data[time_n, lev_n, lat_n, lon_n].tobytes().decode("ascii").replace("\x00", ""))
+                                        data[time_n, lev_n, lat_n, lon_n]
+                                        .tobytes()
+                                        .decode("ascii")
+                                        .replace("\x00", "")
+                                    )
                     data = data_aux
                 else:
-                    raise NotImplementedError("Error with {0}. Only can be read netCDF with 4 dimensions or less".format(
-                        var_name))
+                    raise NotImplementedError(
+                        "Error with {0}. Only NetCDF variables with 4 dimensions or fewer can be read".format(
+                            var_name
+                        )
+                    )
             else:
-                raise NotImplementedError("Error with {0}. Only can be read netCDF with 4 dimensions or less".format(
-                    var_name))
+                raise NotImplementedError(
+                    "Error with {0}. Only NetCDF variables with 4 dimensions or fewer can be read".format(
+                        var_name
+                    )
+                )
 
         # Unmask array
         data = self._unmask_array(data)
@@ -3874,7 +4620,11 @@ class Nes(object):
 
         for i, var_name in enumerate(var_list):
             if self.info:
-                print("Rank {0:03d}: Loading {1} var ({2}/{3})".format(self.rank, var_name, i + 1, len(var_list)))
+                print(
+                    "Rank {0:03d}: Loading {1} var ({2}/{3})".format(
+                        self.rank, var_name, i + 1, len(var_list)
+                    )
+                )
             if self.variables[var_name]["data"] is None:
                 self.variables[var_name]["data"] = self._read_variable(var_name)
                 # Data type changes when joining characters in read_variable (S1 to S+strlen)
                 if self.strlen is None:
                     self.set_strlen()
                 self.variables[var_name]["dtype"] = str
-                self.variables[var_name]["dimensions"] = tuple([x for x in self.variables[var_name]["dimensions"]
-                                                                if x != "strlen"])
+                self.variables[var_name]["dimensions"] = tuple(
+                    [
+                        x
+                        for x in self.variables[var_name]["dimensions"]
+                        if x != "strlen"
+                    ]
+                )
             else:
                 if self.master:
-                    print("Data for {0} was previously loaded. Skipping variable.".format(var_name))
+                    print(
+                        "Data for {0} was previously loaded. Skipping variable.".format(
+                            var_name
+                        )
+                    )
             if self.info:
-                print("Rank {0:03d}: Loaded {1} var ({2})".format(
-                    self.rank, var_name, self.variables[var_name]["data"].shape))
+                print(
+                    "Rank {0:03d}: Loaded {1} var ({2})".format(
+                        self.rank, var_name, self.variables[var_name]["data"].shape
+                    )
+                )
 
         if close:
             self.close()
@@ -3925,7 +4687,7 @@ class Nes(object):
         return data
 
     def to_dtype(self, data_type="float32"):
-        """ Cast variables data into selected data type.
+        """Cast variables data into selected data type.
Parameters ---------- @@ -3935,7 +4697,9 @@ class Nes(object): for var_name, var_info in self.variables.items(): if isinstance(var_info["data"], ndarray): - self.variables[var_name]["data"] = self.variables[var_name]["data"].astype(data_type) + self.variables[var_name]["data"] = self.variables[var_name][ + "data" + ].astype(data_type) self.variables[var_name]["dtype"] = data_type return None @@ -3956,9 +4720,15 @@ class Nes(object): """ if isinstance(aux_nessy, str): - aux_nessy = self.new(path=aux_nessy, comm=self.comm, parallel_method=self.parallel_method, - avoid_first_hours=self.hours_start, avoid_last_hours=self.hours_end, - first_level=self.first_level, last_level=self.last_level) + aux_nessy = self.new( + path=aux_nessy, + comm=self.comm, + parallel_method=self.parallel_method, + avoid_first_hours=self.hours_start, + avoid_last_hours=self.hours_end, + first_level=self.first_level, + last_level=self.last_level, + ) new = True else: new = False @@ -4032,10 +4802,16 @@ class Nes(object): t_min, t_max, z_min, z_max, y_min, y_max, x_min and x_max. """ - axis_limits = {"x_min": None, "x_max": None, - "y_min": None, "y_max": None, - "z_min": None, "z_max": None, - "t_min": None, "t_max": None} + axis_limits = { + "x_min": None, + "x_max": None, + "y_min": None, + "y_max": None, + "z_min": None, + "z_max": None, + "t_min": None, + "t_max": None, + } my_shape = self.get_full_shape() if self.parallel_method == "Y": y_len = my_shape[0] @@ -4049,12 +4825,15 @@ class Nes(object): axis_limits["x_max"] = (x_len // self.size) * (self.rank + 1) elif self.parallel_method == "T": t_len = len(self.get_full_times()) - axis_limits["t_min"] = ((t_len // self.size) * self.rank) + axis_limits["t_min"] = (t_len // self.size) * self.rank if self.rank + 1 < self.size: axis_limits["t_max"] = (t_len // self.size) * (self.rank + 1) else: - raise NotImplementedError("Parallel method '{meth}' is not implemented. Use one of these: {accept}".format( - meth=self.parallel_method, accept=["X", "Y", "T"])) + raise NotImplementedError( + "Parallel method '{meth}' is not implemented. Use one of these: {accept}".format( + meth=self.parallel_method, accept=["X", "Y", "T"] + ) + ) return axis_limits @@ -4083,18 +4862,27 @@ class Nes(object): min_axis = "t_min" max_axis = "t_max" else: - raise NotImplementedError("Parallel method '{meth}' is not implemented. Use one of these: {accept}".format( - meth=self.parallel_method, accept=["X", "Y", "T"])) + raise NotImplementedError( + "Parallel method '{meth}' is not implemented. 
Use one of these: {accept}".format( + meth=self.parallel_method, accept=["X", "Y", "T"] + ) + ) procs_len = len_to_split // self.size procs_rows_extended = len_to_split - (procs_len * self.size) rows_sum = 0 for proc in range(self.size): - fid_dist[proc] = {"x_min": 0, "x_max": None, - "y_min": 0, "y_max": None, - "z_min": 0, "z_max": None, - "t_min": 0, "t_max": None} + fid_dist[proc] = { + "x_min": 0, + "x_max": None, + "y_min": 0, + "y_max": None, + "z_min": 0, + "z_max": None, + "t_min": 0, + "t_max": None, + } if proc < procs_rows_extended: aux_rows = procs_len + 1 else: @@ -4173,8 +4961,12 @@ class Nes(object): # TIMES full_time = self.get_full_times() full_time_bnds = self.get_full_time_bnds() - time_var = netcdf.createVariable("time", float32, ("time",), zlib=self.zip_lvl > 0, complevel=self.zip_lvl) - time_var.units = "{0} since {1}".format(self._time_resolution, full_time[0].strftime("%Y-%m-%d %H:%M:%S")) + time_var = netcdf.createVariable( + "time", float32, ("time",), zlib=self.zip_lvl > 0, complevel=self.zip_lvl + ) + time_var.units = "{0} since {1}".format( + self._time_resolution, full_time[0].strftime("%Y-%m-%d %H:%M:%S") + ) time_var.standard_name = "time" time_var.calendar = "standard" time_var.long_name = "time" @@ -4190,19 +4982,38 @@ class Nes(object): # TIME BOUNDS if full_time_bnds is not None: if self._climatology: - time_bnds_var = netcdf.createVariable(self._climatology_var_name, float64, ("time", "time_nv",), - zlib=self.zip_lvl, complevel=self.zip_lvl) + time_bnds_var = netcdf.createVariable( + self._climatology_var_name, + float64, + ( + "time", + "time_nv", + ), + zlib=self.zip_lvl, + complevel=self.zip_lvl, + ) else: - time_bnds_var = netcdf.createVariable("time_bnds", float64, ("time", "time_nv",), - zlib=self.zip_lvl, complevel=self.zip_lvl) + time_bnds_var = netcdf.createVariable( + "time_bnds", + float64, + ( + "time", + "time_nv", + ), + zlib=self.zip_lvl, + complevel=self.zip_lvl, + ) if self.size > 1: time_bnds_var.set_collective(True) - time_bnds_var[:] = date2num(full_time_bnds, time_var.units, calendar="standard") + time_bnds_var[:] = date2num( + full_time_bnds, time_var.units, calendar="standard" + ) # LEVELS full_lev = self.get_full_levels() - lev = netcdf.createVariable("lev", float32, ("lev",), - zlib=self.zip_lvl > 0, complevel=self.zip_lvl) + lev = netcdf.createVariable( + "lev", float32, ("lev",), zlib=self.zip_lvl > 0, complevel=self.zip_lvl + ) if "units" in full_lev.keys(): lev.units = full_lev["units"] else: @@ -4217,8 +5028,9 @@ class Nes(object): # LATITUDES full_lat = self.get_full_latitudes() full_lat_bnds = self.get_full_latitudes_boundaries() - lat = netcdf.createVariable("lat", float32, self._lat_dim, - zlib=self.zip_lvl > 0, complevel=self.zip_lvl) + lat = netcdf.createVariable( + "lat", float32, self._lat_dim, zlib=self.zip_lvl > 0, complevel=self.zip_lvl + ) lat.units = "degrees_north" lat.axis = "Y" lat.long_name = "latitude coordinate" @@ -4231,9 +5043,13 @@ class Nes(object): # LATITUDES BOUNDS if full_lat_bnds is not None: - lat_bnds_var = netcdf.createVariable("lat_bnds", float32, - self._lat_dim + ("spatial_nv",), - zlib=self.zip_lvl > 0, complevel=self.zip_lvl) + lat_bnds_var = netcdf.createVariable( + "lat_bnds", + float32, + self._lat_dim + ("spatial_nv",), + zlib=self.zip_lvl > 0, + complevel=self.zip_lvl, + ) if self.size > 1: lat_bnds_var.set_collective(True) lat_bnds_var[:] = array(full_lat_bnds["data"], dtype=float32) @@ -4241,8 +5057,9 @@ class Nes(object): # LONGITUDES full_lon = self.get_full_longitudes() 
full_lon_bnds = self.get_full_longitudes_boundaries() - lon = netcdf.createVariable("lon", float32, self._lon_dim, - zlib=self.zip_lvl > 0, complevel=self.zip_lvl) + lon = netcdf.createVariable( + "lon", float32, self._lon_dim, zlib=self.zip_lvl > 0, complevel=self.zip_lvl + ) lon.units = "degrees_east" lon.axis = "X" lon.long_name = "longitude coordinate" @@ -4255,9 +5072,13 @@ class Nes(object): # LONGITUDES BOUNDS if full_lon_bnds is not None: - lon_bnds_var = netcdf.createVariable("lon_bnds", float32, - self._lon_dim + ("spatial_nv",), - zlib=self.zip_lvl > 0, complevel=self.zip_lvl) + lon_bnds_var = netcdf.createVariable( + "lon_bnds", + float32, + self._lon_dim + ("spatial_nv",), + zlib=self.zip_lvl > 0, + complevel=self.zip_lvl, + ) if self.size > 1: lon_bnds_var.set_collective(True) lon_bnds_var[:] = array(full_lon_bnds["data"], dtype=float32) @@ -4277,8 +5098,12 @@ class Nes(object): # TIMES full_time = self.get_full_times() full_time_bnds = self.get_full_time_bnds() - time_var = netcdf.createVariable("time", float64, ("time",), zlib=self.zip_lvl > 0, complevel=self.zip_lvl) - time_var.units = "{0} since {1}".format(self._time_resolution, full_time[0].strftime("%Y-%m-%d %H:%M:%S")) + time_var = netcdf.createVariable( + "time", float64, ("time",), zlib=self.zip_lvl > 0, complevel=self.zip_lvl + ) + time_var.units = "{0} since {1}".format( + self._time_resolution, full_time[0].strftime("%Y-%m-%d %H:%M:%S") + ) time_var.standard_name = "time" time_var.calendar = "standard" time_var.long_name = "time" @@ -4294,19 +5119,42 @@ class Nes(object): # TIME BOUNDS if full_time_bnds is not None: if self._climatology: - time_bnds_var = netcdf.createVariable(self._climatology_var_name, float64, ("time", "time_nv",), - zlib=self.zip_lvl, complevel=self.zip_lvl) + time_bnds_var = netcdf.createVariable( + self._climatology_var_name, + float64, + ( + "time", + "time_nv", + ), + zlib=self.zip_lvl, + complevel=self.zip_lvl, + ) else: - time_bnds_var = netcdf.createVariable("time_bnds", float64, ("time", "time_nv",), - zlib=self.zip_lvl, complevel=self.zip_lvl) + time_bnds_var = netcdf.createVariable( + "time_bnds", + float64, + ( + "time", + "time_nv", + ), + zlib=self.zip_lvl, + complevel=self.zip_lvl, + ) if self.size > 1: time_bnds_var.set_collective(True) - time_bnds_var[:] = date2num(full_time_bnds, time_var.units, calendar="standard") + time_bnds_var[:] = date2num( + full_time_bnds, time_var.units, calendar="standard" + ) # LEVELS full_lev = self.get_full_levels() - lev = netcdf.createVariable("lev", full_lev["data"].dtype, ("lev",), - zlib=self.zip_lvl > 0, complevel=self.zip_lvl) + lev = netcdf.createVariable( + "lev", + full_lev["data"].dtype, + ("lev",), + zlib=self.zip_lvl > 0, + complevel=self.zip_lvl, + ) if "units" in full_lev.keys(): lev.units = full_lev["units"] else: @@ -4321,8 +5169,13 @@ class Nes(object): # LATITUDES full_lat = self.get_full_latitudes() full_lat_bnds = self.get_full_latitudes_boundaries() - lat = netcdf.createVariable("lat", full_lat["data"].dtype, self._lat_dim, - zlib=self.zip_lvl > 0, complevel=self.zip_lvl) + lat = netcdf.createVariable( + "lat", + full_lat["data"].dtype, + self._lat_dim, + zlib=self.zip_lvl > 0, + complevel=self.zip_lvl, + ) lat.units = "degrees_north" lat.axis = "Y" lat.long_name = "latitude coordinate" @@ -4335,9 +5188,13 @@ class Nes(object): # LATITUDES BOUNDS if full_lat_bnds is not None: - lat_bnds_var = netcdf.createVariable("lat_bnds", full_lat_bnds["data"].dtype, - self._lat_dim + ("spatial_nv",), - zlib=self.zip_lvl > 0, 
complevel=self.zip_lvl) + lat_bnds_var = netcdf.createVariable( + "lat_bnds", + full_lat_bnds["data"].dtype, + self._lat_dim + ("spatial_nv",), + zlib=self.zip_lvl > 0, + complevel=self.zip_lvl, + ) if self.size > 1: lat_bnds_var.set_collective(True) lat_bnds_var[:] = full_lat_bnds["data"] @@ -4345,8 +5202,13 @@ class Nes(object): # LONGITUDES full_lon = self.get_full_longitudes() full_lon_bnds = self.get_full_longitudes_boundaries() - lon = netcdf.createVariable("lon", full_lon["data"].dtype, self._lon_dim, - zlib=self.zip_lvl > 0, complevel=self.zip_lvl) + lon = netcdf.createVariable( + "lon", + full_lon["data"].dtype, + self._lon_dim, + zlib=self.zip_lvl > 0, + complevel=self.zip_lvl, + ) lon.units = "degrees_east" lon.axis = "X" lon.long_name = "longitude coordinate" @@ -4359,9 +5221,13 @@ class Nes(object): # LONGITUDES BOUNDS if full_lon_bnds is not None: - lon_bnds_var = netcdf.createVariable("lon_bnds", full_lon_bnds["data"].dtype, - self._lon_dim + ("spatial_nv",), - zlib=self.zip_lvl > 0, complevel=self.zip_lvl) + lon_bnds_var = netcdf.createVariable( + "lon_bnds", + full_lon_bnds["data"].dtype, + self._lon_dim + ("spatial_nv",), + zlib=self.zip_lvl > 0, + complevel=self.zip_lvl, + ) if self.size > 1: lon_bnds_var.set_collective(True) lon_bnds_var[:] = full_lon_bnds["data"] @@ -4369,16 +5235,21 @@ class Nes(object): return None def _create_cell_measures(self, netcdf): - # CELL AREA if "cell_area" in self.cell_measures.keys(): - cell_area = netcdf.createVariable("cell_area", self.cell_measures["cell_area"]["data"].dtype, self._var_dim, - zlib=self.zip_lvl > 0, complevel=self.zip_lvl) + cell_area = netcdf.createVariable( + "cell_area", + self.cell_measures["cell_area"]["data"].dtype, + self._var_dim, + zlib=self.zip_lvl > 0, + complevel=self.zip_lvl, + ) if self.size > 1: cell_area.set_collective(True) - cell_area[self.write_axis_limits["y_min"]:self.write_axis_limits["y_max"], - self.write_axis_limits["x_min"]:self.write_axis_limits["x_max"]] = \ - self.cell_measures["cell_area"]["data"] + cell_area[ + self.write_axis_limits["y_min"] : self.write_axis_limits["y_max"], + self.write_axis_limits["x_min"] : self.write_axis_limits["x_max"], + ] = self.cell_measures["cell_area"]["data"] cell_area.long_name = "area of grid cell" cell_area.standard_name = "cell_area" @@ -4392,20 +5263,23 @@ class Nes(object): return None def _str2char(self, data): - if self.strlen is None: msg = "String data could not be converted into chars while writing." msg += " Please, set the maximum string length (set_strlen) before writing." 
raise RuntimeError(msg) # Get final shape by adding strlen at the end - data_new_shape = data.shape + (self.strlen, ) + data_new_shape = data.shape + (self.strlen,) # nD (2D, 3D, 4D) data as 1D string array data = data.flatten() # Split strings into chars (S1) - data_aux = stringtochar(array([v.encode("ascii", "ignore") for v in data]).astype("S" + str(self.strlen))) + data_aux = stringtochar( + array([v.encode("ascii", "ignore") for v in data]).astype( + "S" + str(self.strlen) + ) + ) data_aux = data_aux.reshape(data_new_shape) return data_aux @@ -4424,42 +5298,61 @@ class Nes(object): for i, (var_name, var_dict) in enumerate(self.variables.items()): if isinstance(var_dict["data"], int) and var_dict["data"] == 0: - var_dims = ("time", "lev",) + self._var_dim + var_dims = ( + "time", + "lev", + ) + self._var_dim var_dtype = float32 else: # Get dimensions if (var_dict["data"] is None) or (len(var_dict["data"].shape) == 4): - var_dims = ("time", "lev",) + self._var_dim + var_dims = ( + "time", + "lev", + ) + self._var_dim else: var_dims = self._var_dim # Get data type if "dtype" in var_dict.keys(): var_dtype = var_dict["dtype"] - if (var_dict["data"] is not None) and (var_dtype != var_dict["data"].dtype): + if (var_dict["data"] is not None) and ( + var_dtype != var_dict["data"].dtype + ): msg = "WARNING!!! " - msg += "Different data types for variable {0}. ".format(var_name) - msg += "Input dtype={0}. Data dtype={1}.".format(var_dtype, var_dict["data"].dtype) + msg += "Different data types for variable {0}. ".format( + var_name + ) + msg += "Input dtype={0}. Data dtype={1}.".format( + var_dtype, var_dict["data"].dtype + ) warn(msg) sys.stderr.flush() try: var_dict["data"] = var_dict["data"].astype(var_dtype) except Exception as e: # TODO: Detect exception print(e) - raise TypeError("It was not possible to cast the data to the input dtype.") + raise TypeError( + "It was not possible to cast the data to the input dtype." + ) else: var_dtype = var_dict["data"].dtype if var_dtype is object: - raise TypeError("Data dtype is object. Define dtype explicitly as dictionary key 'dtype'") + raise TypeError( + "Data dtype is object. 
Define dtype explicitly as dictionary key 'dtype'" + ) if var_dict["data"] is not None: - # Ensure data is of type numpy array (to create NES) if not isinstance(var_dict["data"], (ndarray, generic)): try: var_dict["data"] = array(var_dict["data"]) except AttributeError: - raise AttributeError("Data for variable {0} must be a numpy array.".format(var_name)) + raise AttributeError( + "Data for variable {0} must be a numpy array.".format( + var_name + ) + ) # Convert list of strings to chars for parallelization if issubdtype(var_dtype, character): @@ -4468,12 +5361,20 @@ class Nes(object): var_dtype = "S1" if self.info: - print("Rank {0:03d}: Writing {1} var ({2}/{3})".format( - self.rank, var_name, i + 1, len(self.variables))) + print( + "Rank {0:03d}: Writing {1} var ({2}/{3})".format( + self.rank, var_name, i + 1, len(self.variables) + ) + ) if not chunking: - var = netcdf.createVariable(var_name, var_dtype, var_dims, - zlib=self.zip_lvl > 0, complevel=self.zip_lvl) + var = netcdf.createVariable( + var_name, + var_dtype, + var_dims, + zlib=self.zip_lvl > 0, + complevel=self.zip_lvl, + ) else: if self.balanced: raise NotImplementedError("A balanced data cannot be chunked.") @@ -4482,60 +5383,122 @@ class Nes(object): else: chunk_size = None chunk_size = self.comm.bcast(chunk_size, root=0) - var = netcdf.createVariable(var_name, var_dtype, var_dims, - zlib=self.zip_lvl > 0, complevel=self.zip_lvl, - chunksizes=chunk_size) + var = netcdf.createVariable( + var_name, + var_dtype, + var_dims, + zlib=self.zip_lvl > 0, + complevel=self.zip_lvl, + chunksizes=chunk_size, + ) if self.info: - print("Rank {0:03d}: Var {1} created ({2}/{3})".format( - self.rank, var_name, i + 1, len(self.variables))) + print( + "Rank {0:03d}: Var {1} created ({2}/{3})".format( + self.rank, var_name, i + 1, len(self.variables) + ) + ) if self.size > 1: var.set_collective(True) if self.info: - print("Rank {0:03d}: Var {1} collective ({2}/{3})".format( - self.rank, var_name, i + 1, len(self.variables))) + print( + "Rank {0:03d}: Var {1} collective ({2}/{3})".format( + self.rank, var_name, i + 1, len(self.variables) + ) + ) for att_name, att_value in var_dict.items(): if att_name == "data": if att_value is not None: if self.info: - print("Rank {0:03d}: Filling {1}".format(self.rank, var_name)) + print( + "Rank {0:03d}: Filling {1}".format(self.rank, var_name) + ) if "data_aux" in var_dict.keys(): att_value = var_dict["data_aux"] if isinstance(att_value, int) and att_value == 0: - var[self.write_axis_limits["t_min"]:self.write_axis_limits["t_max"], - self.write_axis_limits["z_min"]:self.write_axis_limits["z_max"], - self.write_axis_limits["y_min"]:self.write_axis_limits["y_max"], - self.write_axis_limits["x_min"]:self.write_axis_limits["x_max"]] = 0 + var[ + self.write_axis_limits[ + "t_min" + ] : self.write_axis_limits["t_max"], + self.write_axis_limits[ + "z_min" + ] : self.write_axis_limits["z_max"], + self.write_axis_limits[ + "y_min" + ] : self.write_axis_limits["y_max"], + self.write_axis_limits[ + "x_min" + ] : self.write_axis_limits["x_max"], + ] = 0 elif len(att_value.shape) == 5: if "strlen" in var_dims: - var[self.write_axis_limits["t_min"]:self.write_axis_limits["t_max"], - self.write_axis_limits["z_min"]:self.write_axis_limits["z_max"], - self.write_axis_limits["y_min"]:self.write_axis_limits["y_max"], - self.write_axis_limits["x_min"]:self.write_axis_limits["x_max"], - :] = att_value + var[ + self.write_axis_limits[ + "t_min" + ] : self.write_axis_limits["t_max"], + self.write_axis_limits[ + "z_min" + ] : 
self.write_axis_limits["z_max"], + self.write_axis_limits[ + "y_min" + ] : self.write_axis_limits["y_max"], + self.write_axis_limits[ + "x_min" + ] : self.write_axis_limits["x_max"], + :, + ] = att_value else: - raise NotImplementedError("It is not possible to write 5D variables.") + raise NotImplementedError( + "It is not possible to write 5D variables." + ) elif len(att_value.shape) == 4: - var[self.write_axis_limits["t_min"]:self.write_axis_limits["t_max"], - self.write_axis_limits["z_min"]:self.write_axis_limits["z_max"], - self.write_axis_limits["y_min"]:self.write_axis_limits["y_max"], - self.write_axis_limits["x_min"]:self.write_axis_limits["x_max"]] = att_value + var[ + self.write_axis_limits[ + "t_min" + ] : self.write_axis_limits["t_max"], + self.write_axis_limits[ + "z_min" + ] : self.write_axis_limits["z_max"], + self.write_axis_limits[ + "y_min" + ] : self.write_axis_limits["y_max"], + self.write_axis_limits[ + "x_min" + ] : self.write_axis_limits["x_max"], + ] = att_value elif len(att_value.shape) == 3: if "strlen" in var_dims: - var[self.write_axis_limits["y_min"]:self.write_axis_limits["y_max"], - self.write_axis_limits["x_min"]:self.write_axis_limits["x_max"], - :] = att_value + var[ + self.write_axis_limits[ + "y_min" + ] : self.write_axis_limits["y_max"], + self.write_axis_limits[ + "x_min" + ] : self.write_axis_limits["x_max"], + :, + ] = att_value else: - raise NotImplementedError("It is not possible to write 3D variables.") + raise NotImplementedError( + "It is not possible to write 3D variables." + ) if self.info: - print("Rank {0:03d}: Var {1} data ({2}/{3})".format( - self.rank, var_name, i + 1, len(self.variables))) - - elif att_name not in ["chunk_size", "var_dims", "dimensions", "dtype", "data_aux"]: + print( + "Rank {0:03d}: Var {1} data ({2}/{3})".format( + self.rank, var_name, i + 1, len(self.variables) + ) + ) + + elif att_name not in [ + "chunk_size", + "var_dims", + "dimensions", + "dtype", + "data_aux", + ]: var.setncattr(att_name, att_value) if "data_aux" in var_dict.keys(): @@ -4543,8 +5506,11 @@ class Nes(object): self._set_var_crs(var) if self.info: - print("Rank {0:03d}: Var {1} completed ({2}/{3})".format( - self.rank, var_name, i + 1, len(self.variables))) + print( + "Rank {0:03d}: Var {1} completed ({2}/{3})".format( + self.rank, var_name, i + 1, len(self.variables) + ) + ) return None @@ -4581,37 +5547,71 @@ class Nes(object): for i, (var_name, var_dict) in enumerate(self.variables.items()): for att_name, att_value in var_dict.items(): if att_name == "data": - if att_value is not None: if self.info: - print("Rank {0:03d}: Filling {1}".format(self.rank, var_name)) + print( + "Rank {0:03d}: Filling {1}".format( + self.rank, var_name + ) + ) var = self.dataset.variables[var_name] if isinstance(att_value, int) and att_value == 0: - var[i_time, - self.write_axis_limits["z_min"]:self.write_axis_limits["z_max"], - self.write_axis_limits["y_min"]:self.write_axis_limits["y_max"], - self.write_axis_limits["x_min"]:self.write_axis_limits["x_max"]] = 0 + var[ + i_time, + self.write_axis_limits[ + "z_min" + ] : self.write_axis_limits["z_max"], + self.write_axis_limits[ + "y_min" + ] : self.write_axis_limits["y_max"], + self.write_axis_limits[ + "x_min" + ] : self.write_axis_limits["x_max"], + ] = 0 elif len(att_value.shape) == 4: if len(var.shape) == 3: # No level info - var[i_time, - self.write_axis_limits["y_min"]:self.write_axis_limits["y_max"], - self.write_axis_limits["x_min"]:self.write_axis_limits["x_max"]] = att_value + var[ + i_time, + 
self.write_axis_limits[ + "y_min" + ] : self.write_axis_limits["y_max"], + self.write_axis_limits[ + "x_min" + ] : self.write_axis_limits["x_max"], + ] = att_value else: - var[i_time, - self.write_axis_limits["z_min"]:self.write_axis_limits["z_max"], - self.write_axis_limits["y_min"]:self.write_axis_limits["y_max"], - self.write_axis_limits["x_min"]:self.write_axis_limits["x_max"]] = att_value + var[ + i_time, + self.write_axis_limits[ + "z_min" + ] : self.write_axis_limits["z_max"], + self.write_axis_limits[ + "y_min" + ] : self.write_axis_limits["y_max"], + self.write_axis_limits[ + "x_min" + ] : self.write_axis_limits["x_max"], + ] = att_value elif len(att_value.shape) == 3: - raise NotImplementedError("It is not possible to write 3D variables.") + raise NotImplementedError( + "It is not possible to write 3D variables." + ) else: - raise NotImplementedError("SHAPE APPEND ERROR: {0}".format(att_value.shape)) + raise NotImplementedError( + "SHAPE APPEND ERROR: {0}".format(att_value.shape) + ) if self.info: - print("Rank {0:03d}: Var {1} data ({2}/{3})".format( - self.rank, var_name, i + 1, len(self.variables))) + print( + "Rank {0:03d}: Var {1} data ({2}/{3})".format( + self.rank, var_name, i + 1, len(self.variables) + ) + ) else: - raise ValueError("Cannot append None Data for {0}".format(var_name)) + raise ValueError( + "Cannot append None Data for {0}".format(var_name) + ) else: # Metadata already writen pass @@ -4671,7 +5671,14 @@ class Nes(object): if self.info: print("Rank {0:03d}: Creating {1}".format(self.rank, path)) if self.size > 1: - netcdf = Dataset(path, format="NETCDF4", mode="w", parallel=True, comm=self.comm, info=MPI.Info()) + netcdf = Dataset( + path, + format="NETCDF4", + mode="w", + parallel=True, + comm=self.comm, + info=MPI.Info(), + ) else: netcdf = Dataset(path, format="NETCDF4", mode="w", parallel=False) if self.info: @@ -4710,8 +5717,16 @@ class Nes(object): def __to_netcdf_cams_ra(self, path): return to_netcdf_cams_ra(self, path) - def to_netcdf(self, path, compression_level=0, serial=False, info=False, chunking=False, nc_type="NES", - keep_open=False): + def to_netcdf( + self, + path, + compression_level=0, + serial=False, + info=False, + chunking=False, + nc_type="NES", + keep_open=False, + ): """ Write the netCDF output file. @@ -4762,7 +5777,9 @@ class Nes(object): elif nc_type == "CAMS_RA": new_nc.__to_netcdf_cams_ra(path) elif nc_type == "MONARCH": - to_netcdf_monarch(new_nc, path, chunking=chunking, keep_open=keep_open) + to_netcdf_monarch( + new_nc, path, chunking=chunking, keep_open=keep_open + ) elif nc_type == "CMAQ": to_netcdf_cmaq(new_nc, path, keep_open=keep_open) elif nc_type == "WRF_CHEM": @@ -4791,14 +5808,18 @@ class Nes(object): to_netcdf_mocage(self, path, keep_open=keep_open) else: msg = f"Unknown NetCDF type '{nc_type}''. " - msg += "Use CAMS_RA, MONARCH, CMAQ, WRF_CHEM, MOCAGE or NES (or DEFAULT)" + msg += ( + "Use CAMS_RA, MONARCH, CMAQ, WRF_CHEM, MOCAGE or NES (or DEFAULT)" + ) raise ValueError(msg) self.info = old_info return None - def __to_grib2(self, path, grib_keys, grib_template_path, lat_flip=True, info=False): + def __to_grib2( + self, path, grib_keys, grib_template_path, lat_flip=True, info=False + ): """ Private method to write output file with grib2 format. @@ -4814,17 +5835,19 @@ class Nes(object): Indicates if you want to print extra information during the process. 
""" - from eccodes import codes_grib_new_from_file - from eccodes import codes_keys_iterator_new - from eccodes import codes_keys_iterator_next - from eccodes import codes_keys_iterator_get_name - from eccodes import codes_get_string - from eccodes import codes_keys_iterator_delete - from eccodes import codes_clone - from eccodes import codes_set - from eccodes import codes_set_values - from eccodes import codes_write - from eccodes import codes_release + from eccodes import ( + codes_clone, + codes_get_string, + codes_grib_new_from_file, + codes_keys_iterator_delete, + codes_keys_iterator_get_name, + codes_keys_iterator_new, + codes_keys_iterator_next, + codes_release, + codes_set, + codes_set_values, + codes_write, + ) fout = open(path, "wb") @@ -4854,29 +5877,49 @@ class Nes(object): try: codes_set(clone_id, key, value) except Exception as e: - print(f"Something went wrong while writing the Grib key '{key}': {value}") + print( + f"Something went wrong while writing the Grib key '{key}': {value}" + ) raise e # Time dependent keys - if "dataTime" in grib_keys.keys() and grib_keys["dataTime"] in ["", "None", None, nan]: + if "dataTime" in grib_keys.keys() and grib_keys["dataTime"] in [ + "", + "None", + None, + nan, + ]: codes_set(clone_id, "dataTime", int(i_time * 100)) - if "stepRange" in grib_keys.keys() and grib_keys["stepRange"] in ["", "None", None, nan]: + if "stepRange" in grib_keys.keys() and grib_keys["stepRange"] in [ + "", + "None", + None, + nan, + ]: n_secs = (time - self.get_full_times()[0]).total_seconds() codes_set(clone_id, "stepRange", int(n_secs // 3600)) - if "forecastTime" in grib_keys.keys() and grib_keys["forecastTime"] in ["", "None", None, nan]: + if "forecastTime" in grib_keys.keys() and grib_keys[ + "forecastTime" + ] in ["", "None", None, nan]: n_secs = (time - self.get_full_times()[0]).total_seconds() codes_set(clone_id, "forecastTime", int(n_secs)) # Level dependent keys - if "typeOfFirstFixedSurface" in grib_keys.keys() and \ - grib_keys["typeOfFirstFixedSurface"] in ["", "None", None, nan]: + if "typeOfFirstFixedSurface" in grib_keys.keys() and grib_keys[ + "typeOfFirstFixedSurface" + ] in ["", "None", None, nan]: if float(lev) == 0: codes_set(clone_id, "typeOfFirstFixedSurface", 1) # grib_keys["typeOfFirstFixedSurface"] = 1 else: codes_set(clone_id, "typeOfFirstFixedSurface", 103) # grib_keys["typeOfFirstFixedSurface"] = 103 - if "level" in grib_keys.keys() and grib_keys["level"] in ["", "None", None, nan]: + if "level" in grib_keys.keys() and grib_keys["level"] in [ + "", + "None", + None, + nan, + ]: codes_set(clone_id, "level", float(lev)) newval = var_info["data"][i_time, i_lev, :, :] @@ -4884,7 +5927,7 @@ class Nes(object): newval = flipud(newval) # TODO Check default NaN Value - newval[isnan(newval)] = 0. 
+ newval[isnan(newval)] = 0.0 codes_set_values(clone_id, array(newval.ravel(), dtype="float64")) codes_write(clone_id, fout) @@ -4927,9 +5970,13 @@ class Nes(object): new_nc = self.copy(copy_vars=False, new_comm=MPI.COMM_SELF) new_nc.variables = data new_nc.cell_measures = c_measures - new_nc.__to_grib2(path, grib_keys, grib_template_path, lat_flip=lat_flip, info=info) + new_nc.__to_grib2( + path, grib_keys, grib_template_path, lat_flip=lat_flip, info=info + ) else: - self.__to_grib2(path, grib_keys, grib_template_path, lat_flip=lat_flip, info=info) + self.__to_grib2( + path, grib_keys, grib_template_path, lat_flip=lat_flip, info=info + ) return None @@ -4944,12 +5991,15 @@ class Nes(object): """ if self.shapefile is None or overwrite: - if self.lat_bnds is None or self.lon_bnds is None: self.create_spatial_bounds() # Reshape arrays to create geometry - aux_shape = (self.lat_bnds["data"].shape[0], self.lon_bnds["data"].shape[0], 4) + aux_shape = ( + self.lat_bnds["data"].shape[0], + self.lon_bnds["data"].shape[0], + 4, + ) lon_bnds_aux = empty(aux_shape) lon_bnds_aux[:, :, 0] = self.lon_bnds["data"][newaxis, :, 0] lon_bnds_aux[:, :, 1] = self.lon_bnds["data"][newaxis, :, 1] @@ -4968,20 +6018,34 @@ class Nes(object): lat_bnds = lat_bnds_aux del lat_bnds_aux - aux_b_lats = lat_bnds.reshape((lat_bnds.shape[0] * lat_bnds.shape[1], lat_bnds.shape[2])) - aux_b_lons = lon_bnds.reshape((lon_bnds.shape[0] * lon_bnds.shape[1], lon_bnds.shape[2])) + aux_b_lats = lat_bnds.reshape( + (lat_bnds.shape[0] * lat_bnds.shape[1], lat_bnds.shape[2]) + ) + aux_b_lons = lon_bnds.reshape( + (lon_bnds.shape[0] * lon_bnds.shape[1], lon_bnds.shape[2]) + ) # Create dataframe cointaining all polygons geometry = [] for i in range(aux_b_lons.shape[0]): - geometry.append(Polygon([(aux_b_lons[i, 0], aux_b_lats[i, 0]), - (aux_b_lons[i, 1], aux_b_lats[i, 1]), - (aux_b_lons[i, 2], aux_b_lats[i, 2]), - (aux_b_lons[i, 3], aux_b_lats[i, 3]), - (aux_b_lons[i, 0], aux_b_lats[i, 0])])) + geometry.append( + Polygon( + [ + (aux_b_lons[i, 0], aux_b_lats[i, 0]), + (aux_b_lons[i, 1], aux_b_lats[i, 1]), + (aux_b_lons[i, 2], aux_b_lats[i, 2]), + (aux_b_lons[i, 3], aux_b_lats[i, 3]), + (aux_b_lons[i, 0], aux_b_lats[i, 0]), + ] + ) + ) fids = self.get_fids() - gdf = GeoDataFrame(index=Index(name="FID", data=fids.ravel()), geometry=geometry, crs="EPSG:4326") + gdf = GeoDataFrame( + index=Index(name="FID", data=fids.ravel()), + geometry=geometry, + crs="EPSG:4326", + ) self.shapefile = gdf else: @@ -5002,15 +6066,19 @@ class Nes(object): if self.shapefile is None: raise ValueError("Shapefile was not created.") - if self.size == 1: - # In serial, avoid gather + if self.parallel_method in ["X", "Y"]: + if self.size == 1: + # In serial, avoid gather + self.shapefile.to_file(path) + else: + # In parallel + data = self.comm.gather(self.shapefile, root=0) + if self.master: + data = concat(data) + data.to_file(path) + + elif self.parallel_method == "T": self.shapefile.to_file(path) - else: - # In parallel - data = self.comm.gather(self.shapefile, root=0) - if self.master: - data = concat(data) - data.to_file(path) return None @@ -5049,8 +6117,11 @@ class Nes(object): if self.variables[var_name]["data"] is None: unloaded_vars.append(var_name) if len(unloaded_vars) > 0: - raise ValueError("The variables {0} need to be loaded/created before using to_shapefile.".format( - unloaded_vars)) + raise ValueError( + "The variables {0} need to be loaded/created before using to_shapefile.".format( + unloaded_vars + ) + ) # Select first vertical level (if 
needed) if lev is None: @@ -5061,7 +6132,11 @@ class Nes(object): idx_lev = 0 else: if lev not in self.lev["data"]: - raise ValueError("Level {} is not available. Choose from {}".format(lev, self.lev["data"])) + raise ValueError( + "Level {} is not available. Choose from {}".format( + lev, self.lev["data"] + ) + ) idx_lev = lev # Select first time (if needed) @@ -5073,7 +6148,9 @@ class Nes(object): idx_time = 0 else: if time not in self.time: - raise ValueError("Time {} is not available. Choose from {}".format(time, self.time)) + raise ValueError( + "Time {} is not available. Choose from {}".format(time, self.time) + ) idx_time = self.time.index(time) # Create shapefile @@ -5099,8 +6176,13 @@ class Nes(object): Index of time for which the data will be saved in the shapefile. """ + mask = getattr(self, "_keep_mask", None) + for var_name in var_list: - self.shapefile[var_name] = self.variables[var_name]["data"][idx_time, idx_lev, :].ravel() + flat_data = self.variables[var_name]["data"][idx_time, idx_lev, :].ravel() + if mask is not None: + flat_data = flat_data[mask] + self.shapefile[var_name] = flat_data return None @@ -5118,12 +6200,17 @@ class Nes(object): centroids = [] for lat_ind in range(0, len(self.lat["data"])): for lon_ind in range(0, len(self.lon["data"])): - centroids.append(Point(self.lon["data"][lon_ind], - self.lat["data"][lat_ind])) + centroids.append( + Point(self.lon["data"][lon_ind], self.lat["data"][lat_ind]) + ) # Create dataframe containing all points fids = self.get_fids() - centroids_gdf = GeoDataFrame(index=Index(name="FID", data=fids.ravel()), geometry=centroids, crs="EPSG:4326") + centroids_gdf = GeoDataFrame( + index=Index(name="FID", data=fids.ravel()), + geometry=centroids, + crs="EPSG:4326", + ) return centroids_gdf @@ -5179,7 +6266,9 @@ class Nes(object): else: raise NotImplementedError( "Parallel method '{meth}' is not implemented. Use one of these: {accept}".format( - meth=self.parallel_method, accept=["X", "Y", "T"])) + meth=self.parallel_method, accept=["X", "Y", "T"] + ) + ) if add_dimension: data_list[var_name]["data"] = stack(data_aux) else: @@ -5211,19 +6300,26 @@ class Nes(object): print("Gathering {0}".format(var_name)) if data_list[var_name]["data"] is None: data_list[var_name]["data"] = None - elif isinstance(data_list[var_name]["data"], int) and data_list[var_name]["data"] == 0: + elif ( + isinstance(data_list[var_name]["data"], int) + and data_list[var_name]["data"] == 0 + ): data_list[var_name]["data"] = 0 else: shp_len = len(data_list[var_name]["data"].shape) # Collect local array sizes using the gather communication pattern - rank_shapes = array(self.comm.gather(data_list[var_name]["data"].shape, root=0)) + rank_shapes = array( + self.comm.gather(data_list[var_name]["data"].shape, root=0) + ) sendbuf = data_list[var_name]["data"].flatten() sendcounts = array(self.comm.gather(len(sendbuf), root=0)) if self.master: recvbuf = empty(sum(sendcounts), dtype=type(sendbuf.max())) else: recvbuf = None - self.comm.Gatherv(sendbuf=sendbuf, recvbuf=(recvbuf, sendcounts), root=0) + self.comm.Gatherv( + sendbuf=sendbuf, recvbuf=(recvbuf, sendcounts), root=0 + ) if self.master: recvbuf = split(recvbuf, cumsum(sendcounts)) # TODO ask @@ -5267,7 +6363,9 @@ class Nes(object): else: raise NotImplementedError( "Parallel method '{meth}' is not implemented. 
Use one of these: {accept}".format(
-                    meth=self.parallel_method, accept=["X", "Y", "T"]))
+                    meth=self.parallel_method, accept=["X", "Y", "T"]
+                )
+            )
         if add_dimension:
             data_list[var_name]["data"] = stack(recvbuf)
         else:
@@ -5275,9 +6373,190 @@ class Nes(object):
 
         return data_list
 
-    # ==================================================================================================================
-    # Extra Methods
-    # ==================================================================================================================
+    def to_mbtile(
+        self,
+        mbtiles_path: str,
+        mbtiles_file: str,
+        geostructure: str = None,
+    ) -> None:
+        """Convert a monthly NES object into daily MBTiles.
+
+        Parameters
+        ----------
+        mbtiles_path : str
+            Path to the folder where the daily MBTiles will be saved.
+        mbtiles_file : str
+            Base name of the MBTiles files; the daily files are named 'daily_<mbtiles_file>_<YYYYMMDD>.mbtiles'.
+        geostructure : str, optional
+            Optional path to a geostructure file to read the grid geometry from,
+            skipping the shapefile computation, by default None.
+        """
+
+        def _geojson_to_mbtile(
+            mbtiles_file: str,
+            geojson_file: str,
+            min_zoom: int = 5,
+            max_zoom: int = 12,
+            max_size: int = 10000000000,
+            max_features: int = 10000000000,
+        ):
+            """Convert a GeoJSON file into an MBTiles file using GDAL's MBTiles driver.
+
+            Parameters
+            ----------
+            mbtiles_file : str
+                File path of the MBTiles file to be created.
+            geojson_file : str
+                File path of the GeoJSON file converted from the NetCDF.
+            min_zoom : int, optional
+                Minimum zoom level at which tiles are generated, by default 5.
+            max_zoom : int, optional
+                Maximum zoom level at which tiles are generated, by default 12.
+            max_size : int, optional
+                Maximum size of a tile in bytes (after compression). If a tile exceeds this threshold,
+                features will be written with reduced precision or discarded, by default 10000000000.
+            max_features : int, optional
+                Maximum number of features per tile, by default 10000000000.
+            """
+
+            # GDAL options
+            opts = gdal.VectorTranslateOptions(
+                format="MBTiles",
+                datasetCreationOptions=[
+                    f"MINZOOM={min_zoom}",
+                    f"MAXZOOM={max_zoom}",
+                    f"MAX_SIZE={max_size}",
+                    f"MAX_FEATURES={max_features}",
+                ],
+                layerCreationOptions=[
+                    "NAME=layer",
+                ],
+            )
+
+            # Execute conversion
+            ds = gdal.VectorTranslate(
+                destNameOrDestDS=mbtiles_file,
+                srcDS=geojson_file,
+                options=opts,
+            )
+            # Recover resources held by dataset, remove file locks, equivalent to GDALClose()
+            ds = None
+
+        @njit(fastmath=True, cache=True, parallel=True)
+        def _filter_by_index_sum(data_arrays: list[ndarray], threshold: float = 1e-12):
+            """
+            Filter a list of 1D arrays (same length) by summing the values at each index across all arrays
+            and applying a threshold. In practice this filters out cells that contain no data.
+
+            Parameters
+            ----------
+            data_arrays : list[ndarray]
+                List of 1D NumPy arrays representing different data sectors. All arrays must have the same length.
+            threshold : float, optional
+                Minimum absolute sum required at each index to retain the data across all arrays. Default is 1e-12.
+
+            Returns
+            -------
+            filtered : list[ndarray]
+                List of filtered arrays with values retained only at indices where:
+                - no array has a NaN at that index, and
+                - the absolute sum of all values at that index exceeds the threshold.
+            keep : ndarray[bool]
+                Boolean mask of the original length indicating which indices were kept.
+ """ + + n = data_arrays[0].shape[0] + num_sectors = len(data_arrays) + keep = zeros(n, dtype=boolean) + + # Parallelization with multithreading + for i in prange(n): + total = 0.0 + valid = True + for j in range(num_sectors): + val = data_arrays[j][i] + if isnan(val): + valid = False + break + total += val + if valid and abs(total) > threshold: + keep[i] = True + + # Count kept values + count = 0 + for i in range(n): + if keep[i]: + count += 1 + + # Filter each sector + filtered = [] + for j in range(num_sectors): + arr = data_arrays[j] + out = empty(count, dtype=arr.dtype) + idx = 0 + for i in range(n): + if keep[i]: + out[idx] = arr[i] + idx += 1 + filtered.append(out) + + return filtered, keep + + if self.parallel_method != "T" or self.size > 1: + raise ( + f"Warning, {self.parallel_method} incompatible with 'to_mbtiles', we can only parallelize across the time domain." + ) + + # Extract sectors available + sectors = self.variables.keys() + + if not os.path.exists(mbtiles_path): + # Create output directory if it doesn't exist + os.makedirs(mbtiles_path, exist_ok=True) + + if geostructure is not None: + self.shapefile = read_file(geostructure) + else: + self.create_shapefile() + + for date in self.time: + # Select time slice + sliced_data = self.sel_time(date, inplace=False) + # Flatten sector data for each sector into a single array + sector_arrays = [ + sliced_data.variables[sector]["data"].flatten() for sector in sectors + ] + # Filter 0s and near 0 values + filtered_arrays, keep_mask = _filter_by_index_sum(sector_arrays) + # Filter the shapefile DataFrame itself + sliced_data.shapefile = sliced_data.shapefile[keep_mask].reset_index( + drop=True + ) + # Boolean mask of which indices to keep + sliced_data._keep_mask = keep_mask + for i, sector in enumerate(sectors): + sliced_data.shapefile[sector] = filtered_arrays[i] + # Format date portion of file name + date_part = ( + str(date.year) + str(date.month).zfill(2) + str(date.day).zfill(2) + ) + print( + f"\tProcessing daily file, date: {date_part}", + flush=True, + ) + # Create temporary GeoJSON file + geojson_file = os.path.join( + tempfile.gettempdir(), f"temp_geojson_{self.rank}_{date_part}.geojson" + ) + # Convert to shapefile + sliced_data.to_shapefile(geojson_file, time=date, lev=0) + + # Assemble path to mbtiles file + mbtiles = os.path.join( + mbtiles_path, f"daily_{mbtiles_file}_{date_part}.mbtiles" + ) + # Convert geojson to mbtile + _geojson_to_mbtile(mbtiles, geojson_file) + @staticmethod def lon_lat_to_cartesian_ecef(lon, lat): """ @@ -5320,8 +6599,15 @@ class Nes(object): return vertical_interpolation.add_4d_vertical_info(self, info_to_add) - def interpolate_vertical(self, new_levels, new_src_vertical=None, kind="linear", extrapolate=None, info=None, - overwrite=False): + def interpolate_vertical( + self, + new_levels, + new_src_vertical=None, + kind="linear", + extrapolate=None, + info=None, + overwrite=False, + ): """ Vertical interpolation function. 
@@ -5364,11 +6650,29 @@ class Nes(object): """ return vertical_interpolation.interpolate_vertical( - self, new_levels, new_src_vertical=new_src_vertical, kind=kind, extrapolate_options=extrapolate, info=info, - overwrite=overwrite) - - def interpolate_horizontal(self, dst_grid, weight_matrix_path=None, kind="NearestNeighbour", n_neighbours=4, - info=False, to_providentia=False, only_create_wm=False, wm=None, flux=False, keep_nan=False, fix_border=False): + self, + new_levels, + new_src_vertical=new_src_vertical, + kind=kind, + extrapolate_options=extrapolate, + info=info, + overwrite=overwrite, + ) + + def interpolate_horizontal( + self, + dst_grid, + weight_matrix_path=None, + kind="NearestNeighbour", + n_neighbours=4, + info=False, + to_providentia=False, + only_create_wm=False, + wm=None, + flux=False, + keep_nan=False, + fix_border=False, + ): """ Horizontal methods from the current grid to another one. @@ -5400,11 +6704,23 @@ class Nes(object): """ return horizontal_interpolation.interpolate_horizontal( - self, dst_grid, weight_matrix_path=weight_matrix_path, kind=kind, n_neighbours=n_neighbours, info=info, - to_providentia=to_providentia, only_create_wm=only_create_wm, wm=wm, flux=flux, keep_nan=keep_nan, - fix_border=fix_border) - - def spatial_join(self, ext_shp, method=None, var_list=None, info=False, apply_bbox=True): + self, + dst_grid, + weight_matrix_path=weight_matrix_path, + kind=kind, + n_neighbours=n_neighbours, + info=info, + to_providentia=to_providentia, + only_create_wm=only_create_wm, + wm=wm, + flux=flux, + keep_nan=keep_nan, + fix_border=fix_border, + ) + + def spatial_join( + self, ext_shp, method=None, var_list=None, info=False, apply_bbox=True + ): """ Compute overlay intersection of two GeoPandasDataFrames. @@ -5422,8 +6738,14 @@ class Nes(object): Indicates if you want to reduce the shapefile to a bbox. """ - return spatial_join(self, ext_shp=ext_shp, method=method, var_list=var_list, info=info, - apply_bbox=apply_bbox) + return spatial_join( + self, + ext_shp=ext_shp, + method=method, + var_list=var_list, + info=info, + apply_bbox=apply_bbox, + ) def calculate_grid_area(self, overwrite=True): """ @@ -5439,7 +6761,9 @@ class Nes(object): if ("cell_area" not in self.cell_measures.keys()) or overwrite: grid_area = cell_measures.calculate_grid_area(self) - grid_area = grid_area.reshape([self.lat["data"].shape[0], self.lon["data"].shape[-1]]) + grid_area = grid_area.reshape( + [self.lat["data"].shape[0], self.lon["data"].shape[-1]] + ) self.cell_measures["cell_area"] = {"data": grid_area} else: grid_area = self.cell_measures["cell_area"]["data"] @@ -5447,8 +6771,11 @@ class Nes(object): return grid_area @staticmethod - def calculate_geometry_area(geometry_list, earth_radius_minor_axis=6356752.3142, - earth_radius_major_axis=6378137.0): + def calculate_geometry_area( + geometry_list, + earth_radius_minor_axis=6356752.3142, + earth_radius_major_axis=6378137.0, + ): """ Get coordinate bounds and call function to calculate the area of each cell of a set of geometries. @@ -5462,8 +6789,11 @@ class Nes(object): Radius of the major axis of the Earth. 
""" - return cell_measures.calculate_geometry_area(geometry_list, earth_radius_minor_axis=earth_radius_minor_axis, - earth_radius_major_axis=earth_radius_major_axis) + return cell_measures.calculate_geometry_area( + geometry_list, + earth_radius_minor_axis=earth_radius_minor_axis, + earth_radius_major_axis=earth_radius_major_axis, + ) @staticmethod def get_earth_radius(ellps): @@ -5493,7 +6823,9 @@ class Nes(object): Dictionary with data of centre coordinates for longitude in 2D (latitude, longitude). """ - raise NotImplementedError("create_providentia_exp_centre_coordinates function is not implemented by default") + raise NotImplementedError( + "create_providentia_exp_centre_coordinates function is not implemented by default" + ) # noinspection DuplicatedCode def create_providentia_exp_grid_edge_coordinates(self): @@ -5507,7 +6839,9 @@ class Nes(object): grid_edge_lon : dict Dictionary with data of grid edge longitudes. """ - raise NotImplementedError("create_providentia_exp_grid_edge_coordinates function is not implemented by default") + raise NotImplementedError( + "create_providentia_exp_grid_edge_coordinates function is not implemented by default" + ) def _detect_longitude_format(self): """ @@ -5536,7 +6870,9 @@ class Nes(object): """ for variable in self.variables.keys(): if self.variables[variable]["data"] is None: - raise ValueError(f"All variables data must be loaded before using this function. Data for {variable} is not loaded.") + raise ValueError( + f"All variables data must be loaded before using this function. Data for {variable} is not loaded." + ) def convert_longitudes(self): """ @@ -5556,7 +6892,9 @@ class Nes(object): raise ValueError("This method is currently only available in serial.") if not self._detect_longitude_format(): - raise ValueError("Longitudes are already in [-180, 180] format or another unrecognised format.") + raise ValueError( + "Longitudes are already in [-180, 180] format or another unrecognised format." + ) self._check_if_data_is_loaded() @@ -5575,9 +6913,13 @@ class Nes(object): lon_bnds_data[lon_bnds_data > 180] -= 360 lon_bnds_sorted = lon_bnds_data[sorted_indices] - if (lon_bnds_sorted[0][0] > lon_bnds_sorted[0][1]) and (isclose(lon_bnds_sorted[0][0], 180)): + if (lon_bnds_sorted[0][0] > lon_bnds_sorted[0][1]) and ( + isclose(lon_bnds_sorted[0][0], 180) + ): lon_bnds_sorted[0][0] = -180 - elif (lon_bnds_sorted[-1][0] > lon_bnds_sorted[-1][1]) and (isclose(lon_bnds_sorted[-1][1], -180)): + elif (lon_bnds_sorted[-1][0] > lon_bnds_sorted[-1][1]) and ( + isclose(lon_bnds_sorted[-1][1], -180) + ): lon_bnds_sorted[-1][1] = 180 self.lon_bnds["data"] = lon_bnds_sorted diff --git a/requirements.txt b/requirements.txt index d85263846dbaa211a162b1082fda58f28936f5ad..7b7a595f2bcf9a5a789372259c0da5d6ea442e1b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,4 +17,6 @@ toml importlib-metadata packaging pytest -argcomplete \ No newline at end of file +argcomplete +numba +gdal \ No newline at end of file