diff --git a/environment.yml b/environment.yml
index 0bd6897f1e950f46c9d3dafe7836c9039d120680..d02dfeff9748c1f1bdab3244d4b4447345c282a7 100644
--- a/environment.yml
+++ b/environment.yml
@@ -22,6 +22,7 @@ dependencies:
   - pyYAML
   - h5netcdf
   - loguru
+  - tqdm
   - pip

 variables:
diff --git a/mapies/grids/base_grid.py b/mapies/grids/base_grid.py
index c1ffdeec793cb25a66134353f69fadd0ee3c9b42..2e78183dcab90ac399e26438719d5ec336a83563 100644
--- a/mapies/grids/base_grid.py
+++ b/mapies/grids/base_grid.py
@@ -23,6 +23,7 @@ class Grid:
     south: float = None
     centre_lat: Optional[float] = None
     centre_lon: Optional[float] = None
+    min_num_obs: Optional[int] = None

     def calculate_grid_coords(self):
@@ -64,7 +65,12 @@ class Grid:

     @staticmethod
-    def create_bounds(coordinates, inc, number_vertices=2, inverse=False):
+    def create_bounds(
+        coordinates: NDArray,
+        inc: float,
+        number_vertices: int = 2,
+        inverse: bool = False
+        ):
         """
         Calculate the boundaries of each grid point
         """
diff --git a/mapies/grids/regular.py b/mapies/grids/regular.py
index 0ca54c99bfaf9f1e3e566829e071d7cac5c1a4da..a03bc04509fb0b4c146beafa69c3414b1f10a1c6 100644
--- a/mapies/grids/regular.py
+++ b/mapies/grids/regular.py
@@ -5,6 +5,7 @@ import numpy as np
 from typing import Tuple, Dict, Optional
 from mapies.util.func_tools import *
 from mapies.grids.base_grid import Grid
+from numpy.typing import DTypeLike, NDArray
 import logging
@@ -21,7 +22,13 @@ class RegularGrid(Grid):
     def __post_init__(self):
         self.calculate_grid_coords()

-    def aggregate(self, lon, lat, obs, obserr=None):
+    def aggregate(
+        self,
+        lon: NDArray,
+        lat: NDArray,
+        obs: NDArray,
+        obserr: NDArray | None = None
+        ):
         """
         Aggregate the observations to the grid cells
@@ -79,6 +86,7 @@ class RegularGrid(Grid):
         if aggregated_obs_err is not None:
             aggregated_obs_err[non_zero_counts] /= count_obs[non_zero_counts]

+        # TODO: Add a function that takes the center or aggregated value
         # Return the results
         lon_out = aggregated_lon.flatten() #self.center_longitudes
diff --git a/mapies/grids/rotated.py b/mapies/grids/rotated.py
index 5cfdaa2fbf7fcfa1b5da8519a4e99ea96db11076..c6847103759db8270188d25847241629cd9936d4 100644
--- a/mapies/grids/rotated.py
+++ b/mapies/grids/rotated.py
@@ -12,7 +12,12 @@ from mapies.grids.base_grid import Grid
 import logging

-def geo_to_rot(lons, lats, centre_lon: float, centre_lat: float):
+def geo_to_rot(
+    lons: NDArray,
+    lats: NDArray,
+    centre_lon: float,
+    centre_lat: float,
+    ):
     """
     Rotating coordinates from cartesian lat/lon to rotated rlon/rlat
     """
@@ -48,7 +53,13 @@ class RotatedGrid(Grid):
     def __post_init__(self):
         self.calculate_grid_coords()

-    def aggregate(self, lon, lat, obs, obserr=None):
+    def aggregate(
+        self,
+        lon: NDArray,
+        lat: NDArray,
+        obs: NDArray,
+        obserr: NDArray | None = None
+        ):
         """
         Aggregate the observations to the grid cells
diff --git a/mapies/mapies.py b/mapies/mapies.py
index f5230169ea0977366b5a4321dd1303365048e8b5..5ba6a8f50bc0eebf7bcb5a01b7945255a3726775 100644
--- a/mapies/mapies.py
+++ b/mapies/mapies.py
@@ -25,6 +25,7 @@ from mapies.util.variables_names import VARS_ATTRIBUTES
 from numpy.typing import DTypeLike, NDArray

 from loguru import logger
+from tqdm import tqdm

 logger.remove()
 logger.add(sys.stderr, level="INFO")
@@ -47,10 +48,12 @@ class MAPIES:
         self.end_date = end_date
         self.apply_qa = True
         self.geospatial_crop = None
+        self.min_num_obs = None
+        self.filters = {}
         self.grid_dict = {
-            "rotated":RotatedGrid,
-            "regular":RegularGrid,
+            "rotated":RotatedGrid,
+            "regular":RegularGrid,
"ireg_rotated":IrregularRotatedGrid, "cams":CAMSGrid, "bdrc":BDRCGrid, @@ -117,8 +120,15 @@ class MAPIES: @timeit @logger.catch - def process_avg_data(self, monthly_avg=False, batch_size=100, save=False, apply_qa=True, geospatial_crop: NDArray | None = None): - + def process_avg_data( + self, + monthly_avg: bool = 0, + min_num_obs: int = 0, + batch_size: int = 100, + save: bool = False, + apply_qa: bool = True, + geospatial_crop: NDArray | None = None + ): """ Process the data. @@ -130,7 +140,8 @@ class MAPIES: - if monthly_avg --> updated monthly dictorionary for the processed month. - if not monthly_avg --> updated self.lon_values, self.lat_values, self.obs, self.count_obs. """ - + self.min_num_obs = min_num_obs + self.filters.update({"min_num_obs":min_num_obs}) #Check if the user wants qa flags if apply_qa: logger.info(f"Applying QA flags: {self.qualityFlags}") @@ -185,7 +196,11 @@ class MAPIES: - def process_data_mp(self, batch_size, month=None): + def process_data_mp( + self, + batch_size: int, + month: int | None =None + ): """ Process the data with multiprocessing. """ @@ -208,7 +223,7 @@ class MAPIES: cumulative_count = None failed_files = [] - for batch in batches: + for batch in tqdm(batches): logger.debug(f"Processing a batch of {len(batch)} files...") # Use multiprocessing to process files within the batch @@ -247,6 +262,13 @@ class MAPIES: # Compute final averages final_obs, final_lat, final_lon = compute_final_averages(cumulative_obs, cumulative_lat, cumulative_lon, cumulative_count) + # Set values less than min num of counts to nan + min_num_counts = cumulative_count > self.min_num_obs + final_obs = final_obs[min_num_counts] + final_lon = final_lon[min_num_counts] + final_lat = final_lat[min_num_counts] + cumulative_count = cumulative_count[min_num_counts] + logger.debug(f"Completed processing for month {month}.") @@ -300,7 +322,12 @@ class MAPIES: @timeit @logger.catch - def process_lazy_data(self, apply_qa=True, geospatial_crop: NDArray | None = None, save=True): + def process_lazy_data( + self, + apply_qa: bool = True, + geospatial_crop: NDArray | None = None, + save: bool = True + ): """ Process the data for the specified time range. It only retrieves a daily dictionary using dask lazy loading. 
""" @@ -312,7 +339,7 @@ class MAPIES: self.apply_qa = False if geospatial_crop is not None: self.geospatial_crop = np.array(geospatial_crop) - for day in self.daily_dict.keys(): + for day in tqdm(self.daily_dict.keys()): files = self.daily_dict[day]["files"] # Use multiprocessing to process files within the batch @@ -357,7 +384,18 @@ class MAPIES: @logger.catch - def to_netCDF(self, obs, lat, lon, time, start_time, end_time): + def to_netCDF( + self, + obs: NDArray | None = None, + lat: NDArray | None = None, + lon: NDArray | None = None, + time: NDArray | None = None, + start_time: pd.Timestamp | None = None, + end_time: pd.Timestamp | None = None, + ): + """ + Function to return the start MAPIES output as a netcdf + """ # Convert time to seconds since TIME_ORIGIN final_time_seconds = (time - self.time_orig) / np.timedelta64(1, "s") @@ -385,6 +423,7 @@ class MAPIES: "Developers": "MAPIES team", "QA": f"Quality Assurance: {self.apply_qa}, quality assurance applied: {qa_flags_str}", "history": f"Created {datetime.now()}", + "Extra Filters Applied": f"{json.dumps(self.filters, ensure_ascii=False)}" }, ) @@ -408,7 +447,15 @@ class MAPIES: # Plotting # ============================================================================= @logger.catch - def plot_2D_observations(self, months=[0], filename=None, outdir=None): + def plot_2D_observations( + self, + months: List[int] = [0], + filename: str = None, + outdir: str = None + ): + """ + Mapies main plotting function for 2d surface obs + """ if months == [0]: try: @@ -469,7 +516,15 @@ class MAPIES: ) - def plot_2D_num_obs(self, months=[0], filename=None, outdir=None): + def plot_2D_num_obs( + self, + months: List[int] = [0], + filename: str = None, + outdir: str = None + ): + """ + Mapies main plotting function for observation counts at the surface + """ if months == [0]: try: @@ -531,7 +586,11 @@ class MAPIES: @staticmethod - def to_xarray(coords:dict, data_vars:dict, **kwargs): + def to_xarray( + coords:dict, + data_vars:dict, + **kwargs + ): """ Method to convert numpy arrays to xarray opject """ diff --git a/mapies/util/func_tools.py b/mapies/util/func_tools.py index fb2c05b00a85ff9356c353348ec170f9ea940816..553b295131ca4446bd2273872ed3a68e1d46e147 100644 --- a/mapies/util/func_tools.py +++ b/mapies/util/func_tools.py @@ -8,7 +8,8 @@ import sys from functools import wraps from numpy import cos, sin, arctan, pi, nan from glob import glob -from typing import List +from typing import List, Tuple +from numpy.typing import DTypeLike, NDArray from datetime import datetime import multiprocessing from loguru import logger @@ -35,7 +36,7 @@ def timeit(func): return timeit_wrapper -def get_filepaths(pattern): +def get_filepaths(pattern:str): """ Get file paths from a pattern """ @@ -43,7 +44,7 @@ def get_filepaths(pattern): # Only supported in python3.10 -def get_file_list(patterns) -> List[str]: +def get_file_list(patterns: List[str]) -> List[str]: """ param patterns: one or several glob patterns (or exact file names) pointing to the location of files on disk. 
     return: the matching list of files
@@ -89,7 +90,14 @@ def time_converter(date:str) -> pd.Timestamp:

-def quality_assurance(ds, obs, lat, lon, time_values, qualityFlags):
+def quality_assurance(
+    ds: xr.Dataset,
+    obs: NDArray,
+    lat: NDArray,
+    lon: NDArray,
+    time_values: NDArray,
+    qualityFlags: dict,
+    ):
     """
     Applying quality assurance flags to mapies data
     """
@@ -150,7 +158,13 @@

-def spatial_domain_selection(obs, lat, lon, time, spatial_bounds):
+def spatial_domain_selection(
+    obs: NDArray,
+    lat: NDArray,
+    lon: NDArray,
+    time: NDArray,
+    spatial_bounds: NDArray,
+    ):
     """
     A function for cutting the data spatially
@@ -183,13 +197,13 @@
 # Time Domain Selection function based off of xarray Will this work with dask too, a question for later
 def time_domain_selection(
-    obs,
-    lat,
-    lon,
-    time_values,
-    start_date,
-    end_date,
-    closed = None) -> np.array:
+    obs: NDArray,
+    lat: NDArray,
+    lon: NDArray,
+    time_values: NDArray,
+    start_date: pd.Timestamp,
+    end_date: pd.Timestamp,
+    closed: str | None = None) -> NDArray:
     """
     :param ds: one or several xarray datasets to cut down to correct size.
     :param start_date: the start date that we want to start with for our time domain selection
@@ -231,7 +245,11 @@
     return obs, lon, lat, time_values

-def inflate_array(var_array, shape1, shape2):
+def inflate_array(
+    var_array: NDArray,
+    shape1: tuple,
+    shape2: tuple,
+    ) -> NDArray:
     """
     A function to inflate an array by repeating it along one dimension
     This is useful if you have a flattened array and need to repeat it due to the original array having less
@@ -245,7 +263,7 @@
     return np.repeat(var_array, repeats)

-def frequency_int(time_interval):
+def frequency_int(time_interval: str) -> int:
     """
     Convert frequency to an integer value, used in the DA
     """
@@ -258,7 +276,11 @@
     return frequency

-def error_estimation(datatype:str, obs:np.typing.NDArray, unc_const:list) -> np.typing.NDArray:
+def error_estimation(
+    datatype: str,
+    obs: NDArray,
+    unc_const: list
+    ) -> NDArray:
     """
     Error estimation function
     """
@@ -272,7 +294,13 @@
-def flatten_arrays(ds, obs_var, lat_var, lon_var, time_var):
+def flatten_arrays(
+    ds: xr.Dataset,
+    obs_var: str,
+    lat_var: str,
+    lon_var: str,
+    time_var: str,
+    ):
     """
     * Preprocessing of the dataset
     * Flattening arrays
@@ -323,8 +351,8 @@
 def reindex(
-    dependent_var_index,
-    independent_var_values,
+    dependent_var_index: NDArray,
+    independent_var_values: NDArray,
     ):
     """
     Recutting the data whenever a selection of the data has been made along one of the dimensions
@@ -338,7 +366,11 @@
     return independent_var_values

-def day_boundaries(day, start_date, end_date):
+def day_boundaries(
+    day,
+    start_date: pd.Timestamp,
+    end_date: pd.Timestamp,
+    ):
     """
     Get daily boundaries only if they are within start and end date
     If not return the start and end date as the boundaries
@@ -360,7 +392,11 @@

-def filter_files(files, start_date, end_date):
+def filter_files(
+    files: List[str],
+    start_date: pd.Timestamp,
+    end_date: pd.Timestamp,
+    ):
     """
     Filter files based on the start_date and end_date
     """
@@ -378,7 +414,11 @@

-def generate_mid_time_array(start_date, end_date, obs):
+def generate_mid_time_array(
+    start_date: pd.Timestamp,
+    end_date: pd.Timestamp,
+    obs: NDArray,
+    ):
     """
     Generate the mid time array for the observations
     (Used when saving into netCDF averaged data)
@@ -393,7 +433,10 @@

-def create_batches(files, batch_size):
+def create_batches(
+    files: List[str],
+    batch_size: int,
+    ):
     """
     Create batches of files to pass to MP
     """
@@ -406,7 +449,13 @@

-def combine_arrays(valid_results, cumulative_obs, cumulative_lat, cumulative_lon, cumulative_count):
+def combine_arrays(
+    valid_results: Tuple[NDArray],
+    cumulative_obs: NDArray | None,
+    cumulative_lat: NDArray | None,
+    cumulative_lon: NDArray | None,
+    cumulative_count: NDArray | None,
+    ):
     """
     Combine the arrays from the valid results. In other words sum the arrays and count the number of observations.
     """
@@ -427,7 +476,12 @@

-def compute_final_averages(cumulative_obs, cumulative_lat, cumulative_lon, cumulative_count):
+def compute_final_averages(
+    cumulative_obs: NDArray,
+    cumulative_lat: NDArray,
+    cumulative_lon: NDArray,
+    cumulative_count: NDArray,
+    ):
     """
     Compute the final averages from the cumulative arrays. In other words divide the cumulative_obs by the cumulative_count.
""" @@ -445,7 +499,10 @@ def compute_final_averages(cumulative_obs, cumulative_lat, cumulative_lon, cumul -def append_and_concatenate(valid_results, day): +def append_and_concatenate( + valid_results: Tuple[NDArray], + day + ): """ Append and concatenate arrays """ @@ -475,7 +532,12 @@ def append_and_concatenate(valid_results, day): -def filter_observations(obs, lat, lon, count): +def filter_observations( + obs: NDArray, + lat: NDArray, + lon: NDArray, + count: NDArray, + ): """ Filter out observations with negative values """ diff --git a/mapies/util/multiprocessing.py b/mapies/util/multiprocessing.py index 034fa2be1518ed557ed6ac60e8a997b69d428d10..ad0d3a49d431fa87ea74814e2de9e136af04cc3e 100644 --- a/mapies/util/multiprocessing.py +++ b/mapies/util/multiprocessing.py @@ -16,18 +16,18 @@ from mapies.grids.regular import RegularGrid @logger.catch def process_single_file_tropomi( - grid, - file, - obs_var, - lat_var, - lon_var, - time_var, - start_date, - end_date, - apply_qa, - qualityFlags, - flag=False, - geospatial_crop=None + grid: Grid, + file: str, + obs_var: NDArray, + lat_var: NDArray, + lon_var: NDArray, + time_var: NDArray, + start_date: pd.Timestamp, + end_date: pd.Timestamp, + apply_qa: bool, + qualityFlags: dict, + flag: bool = False, + geospatial_crop: NDArray | None = None ): """ Process a single file and return aggregated results or flattened arrays. @@ -63,18 +63,18 @@ def process_single_file_tropomi( @logger.catch def process_single_file_viirs( - grid, - file, - obs_var, - lat_var, - lon_var, - time_var, - start_date, - end_date, - apply_qa, - qualityFlags, - flag=False, - geospatial_crop=None + grid: Grid, + file: str, + obs_var: NDArray, + lat_var: NDArray, + lon_var: NDArray, + time_var: NDArray, + start_date: pd.Timestamp, + end_date: pd.Timestamp, + apply_qa: bool, + qualityFlags: dict, + flag: bool = False, + geospatial_crop: NDArray | None = None ): """ Process a single file and return aggregated results or flattened arrays. diff --git a/mapies/util/plotting.py b/mapies/util/plotting.py index 2791ed452f932b30b2c002223c4bf1f88567b1b5..38a0417d2c6e2c2f7d135ce812343c2f0f7a70e2 100644 --- a/mapies/util/plotting.py +++ b/mapies/util/plotting.py @@ -16,7 +16,16 @@ class MapiesPlotter: self.datatype=datatype # @timeit - def plot_2D_obs_custom(self, lon_values=None, lat_values=None, obs_values=None, outdir=None, title=None, filename=None, display_fig=False): + def plot_2D_obs_custom( + self, + lon_values=None, + lat_values=None, + obs_values=None, + outdir=None, + title=None, + filename=None, + display_fig=False + ): """ General method for plotting observations with CB-RdYlBu colormap and arrow-like ends for out-of-bound values. """ @@ -74,18 +83,19 @@ class MapiesPlotter: plt.savefig(filepath, format="png") if display_fig: plt.show() + plt.close(fig) logger.info(f"Saved plot: {filepath}") def plot_2D_num_obs_custom( - self, - lon_values=None, - lat_values=None, - num_obs=None, - outdir="./", - title=None, - filename=None, - display_fig=False + self, + lon_values=None, + lat_values=None, + num_obs=None, + outdir="./", + title=None, + filename=None, + display_fig=False ): """ General method for plotting the number of observations contributing to the final averages. 
@@ -132,6 +142,7 @@ class MapiesPlotter:
         plt.savefig(filepath, format="png")
         if display_fig:
             plt.show()
+        plt.close(fig)
         logger.info(f"Saved plot: {filepath}")

     # @property
diff --git a/mapies/viirs.py b/mapies/viirs.py
index ee864eabd4ceff09b081f78fa8eecebb18b09c2a..5887af70d2a06b73751f5ddfc198a9ec47ba0106 100644
--- a/mapies/viirs.py
+++ b/mapies/viirs.py
@@ -82,6 +82,8 @@ class VIIRS(MAPIES):

+
+
     @timeit
     @logger.catch
     def gather_nc_files(self):
@@ -149,7 +151,13 @@ class VIIRS(MAPIES):
     # =============================================================================
     @timeit
     @logger.catch
-    def to_da(self, frequency="D", save_nc= True, save_figs=False):
+    def to_da(
+        self,
+        frequency: str = "D",
+        min_num_obs: int = 0,
+        save_nc: bool = True,
+        save_figs: bool = False
+        ):
         """
         Function that returns all the needed variables for the DA
         """
@@ -207,7 +215,17 @@ class VIIRS(MAPIES):
                 lon, lat, obs, obserr, count_obs = self.grid.aggregate(lon, lat, obs, obserr)
             except ValueError:
                 raise Exception("Have you specified a grid_repr")
-
+
+            # Apply the min number of observations filter here too
+            min_num_counts = count_obs > min_num_obs
+            obs = obs[min_num_counts]
+            lon = lon[min_num_counts]
+            lat = lat[min_num_counts]
+            obserr = obserr[min_num_counts]
+            count_obs = count_obs[min_num_counts]
+            if isinstance(self.grid, RotatedGrid):
+                rlon = rlon[min_num_counts]
+                rlat = rlat[min_num_counts]

             # Create new arrays with same length as obs
             obsid = np.full(shape=obs.shape, fill_value=self.obsid_da, dtype=self.int_type)
@@ -244,9 +262,8 @@ class VIIRS(MAPIES):
             if save_figs:
                 self.plotter.plot_2D_obs_custom(lon, lat, obs, outdir=self.dest, filename=f"da_obs_{date}.png")

-            ds = self.to_xarray(coords=coords, data_vars=data_vars)
+            ds = self.to_xarray(coords=coords, data_vars=data_vars, attrs={"filters":f"Min_num_obs = {min_num_obs}"})
             if save_nc:
-                #logger.info(f"Outputting da data with {filename}")
                 logger.debug(f"Outputting da data with {filename}")
                 ds.to_netcdf(filename, encoding={})
                 outfiles.append(filename)
diff --git a/run/run_viirs.py b/run/run_viirs.py
index 646d251445af2001bf1800b471c64286e5d9d5e9..af39004ec0d631c5c10c0503a49c00a3c116d1b4 100644
--- a/run/run_viirs.py
+++ b/run/run_viirs.py
@@ -7,25 +7,24 @@ if __name__ == "__main__":
     # Change these variables
     # 00:30 to 02:36 from 2024-01-01 to 2024-01-03
     start_date = "202401010000"
-    end_date = "202401012359"
+    end_date = "202401312359"

-    # outdir="/home/cmeikle/Projects/data/"
-    # indir="/home/cmeikle/Projects/data/VIIRS"
+    outdir="/home/cmeikle/Projects/data/"
+    indir="/home/cmeikle/Projects/data/VIIRS"

     start_time = time.time()

     c = VIIRS(start_date, end_date, dest=outdir, indir=indir, grid_repr="bdrc")
     c.gather_nc_files()

-    c.process_avg_data(monthly_avg=False, batch_size = 100, apply_qa=True, save=True)
-    #c.yearly_average()
-    c.plot_2D_observations(months=[0], filename="new.png", outdir=outdir)
+    c.process_avg_data(monthly_avg=False, min_num_obs = 20, save=True)
+    c.plot_2D_observations(months=[0], filename="month_without_min_obs_and_qa_applied.png", outdir=outdir)
     c.plot_2D_num_obs(months=[0], outdir=outdir)

-    c.process_lazy_data(apply_qa=True, save=False)
+    #c.process_lazy_data(apply_qa=True, save=False)

-    c.to_da(frequency="3H", save_figs=True)
+    #c.to_da(frequency="D", min_num_obs=2, save_figs=True)
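
The min_num_obs screening that this diff adds in process_data_mp and VIIRS.to_da reduces to the boolean-mask pattern below. This is a minimal, self-contained sketch with synthetic arrays; the variable values are illustrative only and not part of the MAPIES codebase.

    import numpy as np

    # Synthetic aggregated output: one mean value and one observation count per grid cell
    final_obs = np.array([0.12, 0.48, 0.31, 0.95])
    final_lon = np.array([-10.0, 0.0, 10.0, 20.0])
    final_lat = np.array([40.0, 41.0, 42.0, 43.0])
    cumulative_count = np.array([1, 25, 3, 40])

    min_num_obs = 20

    # Keep only cells whose count strictly exceeds the threshold (count > min_num_obs, as in the diff)
    keep = cumulative_count > min_num_obs
    final_obs = final_obs[keep]
    final_lon = final_lon[keep]
    final_lat = final_lat[keep]
    cumulative_count = cumulative_count[keep]

    print(final_obs)         # [0.48 0.95]
    print(cumulative_count)  # [25 40]

Because the comparison is strict, cells with exactly min_num_obs observations are dropped; with the default min_num_obs = 0, zero-count cells are removed as well.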