From 38d988194e8f843419d012f950bcbcd353f44b1e Mon Sep 17 00:00:00 2001 From: Calum Meikle Date: Thu, 27 Feb 2025 16:58:29 +0100 Subject: [PATCH 1/4] Removing non needed methods and adding multiprocessing classes --- mapies/mapies.py | 141 ++--------------- mapies/tropomi.py | 70 +++------ mapies/util/func_tools.py | 142 ++++++++--------- mapies/util/multiprocessing.py | 280 +++++++++++++++++++++++++++++++++ mapies/viirs.py | 98 +++++++----- run/run_viirs.py | 8 +- 6 files changed, 437 insertions(+), 302 deletions(-) create mode 100644 mapies/util/multiprocessing.py diff --git a/mapies/mapies.py b/mapies/mapies.py index d548f422..baa8c75b 100644 --- a/mapies/mapies.py +++ b/mapies/mapies.py @@ -101,49 +101,6 @@ class MAPIES: self.qualityFlags = qa_dict["qa_variables"] - # TODO this can be removed I believe - @timeit - def plot_2D_obs(self, **kwargs): - """ - Plotting the observations - """ - # Regrid if specified as argument - if self.grid_repr: - #logger.info("Performing regridding of data") - self.regridding_function() - # TODO: Make these arguments - figsize = (15,10) - markersize = 2.5 - - # Set Basemap projection - proj = ccrs.PlateCarree() - - # Create the plot and add features, TODO make this adjustable in arguments - fig, ax = plt.subplots(subplot_kw={"projection": proj}, figsize=figsize) - ax.gridlines() - #ax.add_feature(cartopy.feature.BORDERS, linestyle=':', alpha=1) - #ax.add_feature(cartopy.feature.OCEAN,facecolor=("lightblue")) - #ax.add_feature(cartopy.feature.LAND) - ax.coastlines(resolution='10m') - - - - print("Plotting observations") - - - x, y = self.lon_values, self.lat_values - - im = ax.scatter(x,y,markersize,c=self.obs, transform=proj) - - fig.colorbar(im, ax=ax) - - ax.set_title(f'Observation 2D plot of {self.datatype.upper()} data from {self.start_date} to {self.end_date}') - print("Saving Figure") - plt.savefig(f"{self.dest}{self.datatype}_2D_obs.png", format="png") - #plt.close(fig) - plt.show() - #return fig - # TODO this can be moved to a plotting function python file @timeit @@ -208,92 +165,6 @@ class MAPIES: print(f"Saved plot: {filepath}") - @timeit - def plot_2D_cumulative_count(self, lon_values, lat_values, cumulative_count, outdir="/home/cgile/Documents/mapies/mapies/figures", title=None, filename=None, display_fig=False): - """ - Plot the cumulative count of observations contributing to the final averages. - - Parameters: - - lon_values: 1D array of longitude values. - - lat_values: 1D array of latitude values. - - cumulative_count: 1D array of the count of valid observations per grid cell. - - outdir: Directory to save the plot. - - title: Title for the plot. - - filename: Name of the output file. 
- """ - - # Set Basemap projection - figsize = (15, 10) - markersize = 2.5 - proj = ccrs.PlateCarree() - - # Set colormap and normalization - cmap = plt.cm.YlOrBr # Colormap for cumulative count - norm = mcolors.Normalize(vmin=0, vmax=np.max(cumulative_count)) # Continuous normalization - - # Create the plot and add features - fig, ax = plt.subplots(subplot_kw={"projection": proj}, figsize=figsize) - ax.gridlines() - ax.coastlines(resolution="10m") - - print("Plotting cumulative count") - - # Scatter plot for cumulative count - im = ax.scatter( - lon_values, lat_values, markersize, c=cumulative_count, cmap=cmap, norm=norm, transform=proj - ) - - # Add the color bar - cbar = fig.colorbar(im, ax=ax, orientation="vertical", pad=0.05) - cbar.set_label("Cumulative Count") - - # Add title and save the plot - ax.set_title(title) - - # Ensure the output directory exists - os.makedirs(outdir, exist_ok=True) - - output_filename = "few{}fcec".format(site_code) - print("{}".format(this_string)) - - - filepath = os.path.join(outdir, filename) - plt.savefig(filepath, format="png") - if display_fig: - plt.show() - print(f"Saved plot: {filepath}") - - - - @staticmethod - def reindex( - dependent_var_index, - independent_var_values, - ): - """" - Recutting the data whenever a selction of the data has been made along one of the dimenisons - - Based off of how it has been done in Providentia - """ - # Maybe add a checker try/except or assert - independent_var_values = independent_var_values[dependent_var_index] - - # return reindexed values - return independent_var_values - - - @staticmethod - def to_xarray(coords:dict, data_vars:dict, **kwargs): - """ - Method to convert numpy arrays to xarray opject - """ - attrs = kwargs.get("attrs") - if attrs is None: - attrs=dict() - - ds = xr.Dataset(data_vars=data_vars, coords=coords, attrs=attrs) - return ds - # TODO this can be moved to a plotting function python file @timeit def plot_2D_num_obs_custom( @@ -355,4 +226,14 @@ class MAPIES: #return fig - + @staticmethod + def to_xarray(coords:dict, data_vars:dict, **kwargs): + """ + Method to convert numpy arrays to xarray opject + """ + attrs = kwargs.get("attrs") + if attrs is None: + attrs=dict() + + ds = xr.Dataset(data_vars=data_vars, coords=coords, attrs=attrs) + return ds diff --git a/mapies/tropomi.py b/mapies/tropomi.py index 2df0c03a..af9d9e8b 100644 --- a/mapies/tropomi.py +++ b/mapies/tropomi.py @@ -18,6 +18,7 @@ from mapies.util.func_tools import * #timeit, get_file_list, time_converter, frequency_int, error_estimation, time_domain_selection, geo_to_rot from mapies.grids.monarch import RotatedGrid, IrregularRotatedGrid, RegularGrid, regridding_function from mapies.util.variables_names import VARS_ATTRIBUTES +from mapies.util.multiprocessing import TROPOMIMultiprocessing import json from numpy.typing import DTypeLike, NDArray @@ -29,49 +30,7 @@ os.environ["OMP_NUM_THREADS"] = "8" TIME_ORIGIN = np.datetime64("1900-01-01T00:00:00") -def process_single_file( - r, - file, - obs_var, - lat_var, - lon_var, - time_var, - start_date, - end_date, - apply_qa, - qualityFlags, - flag=False, - geospatial_crop=None - ): - """ - Process a single file and return aggregated results or flattened arrays. 
- """ - try: - ds = xr.open_dataset(file, engine="h5netcdf", group="PRODUCT") - obs, lat, lon, time_values = new_preprocess_vars(ds, obs_var, lat_var, lon_var, time_var, start_date, end_date) - - if apply_qa: - obs, lat, lon, time_values = quality_assurance(ds, obs, lat, lon, time_values, qualityFlags) - obs, lon, lat, time_values = time_domain_selection(obs, lat, lon, time_values, start_date, end_date) - if geospatial_crop is not None: - obs, lat, lon, time_values = spatial_domain_selection(obs, lat, lon, time_values, geospatial_crop) - - if not flag: - # Regrid and aggregate - if isinstance(r, RotatedGrid): - lon_agg, lat_agg, rlon_agg, rlat_agg, obs_agg, count_obs = regridding_function(r, lon, lat, obs) - elif isinstance(r, RegularGrid): - lon_agg, lat_agg, obs_agg, count_obs = regridding_function(r, lon, lat, obs) - ds.close() - return obs_agg, lat_agg, lon_agg, count_obs - else: - ds.close() - return obs, lat, lon, time_values - - except Exception as e: - print(f"Error processing file {file}: {e}") - return None class TROPOMI(MAPIES): """ @@ -288,15 +247,24 @@ class TROPOMI(MAPIES): print(f"Batch: {batch}") # Use multiprocessing to process files within the batch - args = [(self.grid, file, self.obs_var, self.lat_var, self.lon_var, self.time_var, self.start_date, self.end_date, self.apply_qa, self.qualityFlags, flag, self.geospatial_crop) for file in batch] - - with multiprocessing.Pool(processes=8) as pool: # Adjust number of processes as needed - results = pool.starmap(process_single_file, args) - - # Filter valid results and log failed files - valid_results = [res for res in results if res is not None] - failed_files.extend([file for file, res in zip(batch, results) if res is None]) - + + tropomi_mp = TROPOMIMultiprocessing( + self.grid, + self.obs_var, + self.lat_var, + self.lon_var, + self.time_var, + self.start_date, + self.end_date, + self.apply_qa, + self.qualityFlags, + batch, + self.geospatial_crop, + flag, + ) + valid_results, failed = tropomi_mp.process_batch() + failed_files.extend(failed) + if not valid_results: print("All files in this batch failed. Skipping...") continue diff --git a/mapies/util/func_tools.py b/mapies/util/func_tools.py index 9a63fdc8..4823d55c 100644 --- a/mapies/util/func_tools.py +++ b/mapies/util/func_tools.py @@ -11,7 +11,6 @@ from typing import List from datetime import datetime import multiprocessing -from mapies.grids.monarch import RotatedGrid, IrregularRotatedGrid, RegularGrid, regridding_function TIME_ORIGIN = np.datetime64("1993-01-01T00:00:00") @@ -270,30 +269,6 @@ def error_estimation(datatype:str, obs:np.typing.NDArray, unc_const:list) -> np. 
- -# def geo_to_rot(lons, lats, centre_lon: float, centre_lat: float): -# """ -# Rotating coordinates from cartesian lat/lon to rotated rlon/rlat -# """ -# distance_lons = np.radians(lons - centre_lon) -# lons = np.radians(lons) -# lats = np.radians(lats) -# centre_lon = np.radians(centre_lon) -# centre_lat = np.radians(centre_lat) - -# x = cos(centre_lat) * sin(lats) - sin(centre_lat) * cos(lats) * cos(distance_lons) -# y = cos(lats) * sin(distance_lons) -# z = cos(centre_lat) * cos(lats) * cos(distance_lons) + sin(centre_lat) * sin(lats) -# # Arctan2 used -# # Explanation of the difference https://geo.libretexts.org/Courses/University_of_California_Davis/GEL_056%3A_Introduction_to_Geophysics/Geophysics_is_everywhere_in_geology.../zz%3A_Back_Matter/Arctan_vs_Arctan2 -# rlon = np.arctan2(y, z) -# rlat = np.arcsin(x) -# #rlon[x < 0] += pi -# # Convert back to degrees -# rlon = np.degrees(rlon) -# rlat = np.degrees(rlat) -# return rlon, rlat - def new_preprocess_vars(ds, obs_var, lat_var, lon_var, time_var, start_date, end_date): """ * Preprocessing of the dataset @@ -381,6 +356,8 @@ def day_boundaries(day, start_date, end_date): return day_start, day_end + + def filter_files(files, start_date, end_date): """ Filter files based on the start_date and end_date @@ -397,6 +374,8 @@ def filter_files(files, start_date, end_date): return files_filtered + + def generate_mid_time_array(start_date, end_date, obs): """ Generate the mid time array for the observations @@ -410,6 +389,8 @@ def generate_mid_time_array(start_date, end_date, obs): return mid_time_array + + def create_batches(files, batch_size): """ Create batches of files to pass to MP @@ -422,65 +403,69 @@ def create_batches(files, batch_size): return batches -def process_single_file( - r, - file, - obs_var, - lat_var, - lon_var, - time_var, - start_date, - end_date, - apply_qa, - qualityFlags, - flag=False, - geospatial_crop=None - ): - """ - Process a single file and return aggregated results or flattened arrays. - """ - try: - ds = xr.open_dataset(file, engine="h5netcdf") + + +# def process_single_file( +# r, +# file, +# obs_var, +# lat_var, +# lon_var, +# time_var, +# start_date, +# end_date, +# apply_qa, +# qualityFlags, +# flag=False, +# geospatial_crop=None +# ): +# """ +# Process a single file and return aggregated results or flattened arrays. 
+# """ +# try: +# ds = xr.open_dataset(file, engine="h5netcdf") - obs, lat, lon, time_values = new_preprocess_vars(ds, obs_var, lat_var, lon_var, time_var, start_date, end_date) +# obs, lat, lon, time_values = new_preprocess_vars(ds, obs_var, lat_var, lon_var, time_var, start_date, end_date) - #Run a time domain selection - if apply_qa: - obs, lat, lon, time_values = quality_assurance(ds, obs, lat, lon, time_values, qualityFlags) +# #Run a time domain selection +# if apply_qa: +# obs, lat, lon, time_values = quality_assurance(ds, obs, lat, lon, time_values, qualityFlags) - obs, lon, lat, time_values = time_domain_selection(obs, lat, lon, time_values, start_date, end_date) +# obs, lon, lat, time_values = time_domain_selection(obs, lat, lon, time_values, start_date, end_date) - if geospatial_crop is not None: - obs, lat, lon, time_values = spatial_domain_selection(obs, lat, lon, time_values, geospatial_crop) - - if not flag: - # Regrid and aggregate - if isinstance(r, RotatedGrid): - lon_agg, lat_agg, rlon_agg, rlat_agg, obs_agg, count_obs = regridding_function(r, lon, lat, obs) - elif isinstance(r, RegularGrid): - lon_agg, lat_agg, obs_agg, count_obs = regridding_function(r, lon, lat, obs) - ds.close() - return obs_agg, lat_agg, lon_agg, count_obs - else: - ds.close() - return obs, lat, lon, time_values +# if geospatial_crop is not None: +# obs, lat, lon, time_values = spatial_domain_selection(obs, lat, lon, time_values, geospatial_crop) + +# if not flag: +# # Regrid and aggregate +# if isinstance(r, RotatedGrid): +# lon_agg, lat_agg, rlon_agg, rlat_agg, obs_agg, count_obs = regridding_function(r, lon, lat, obs) +# elif isinstance(r, RegularGrid): +# lon_agg, lat_agg, obs_agg, count_obs = regridding_function(r, lon, lat, obs) +# ds.close() +# return obs_agg, lat_agg, lon_agg, count_obs +# else: +# ds.close() +# return obs, lat, lon, time_values + +# except Exception as e: +# print(f"Error processing file {file}: {e}") + +# return None - except Exception as e: - print(f"Error processing file {file}: {e}") - return None -def process_batch(args, batch): +# def process_batch(args, batch): - with multiprocessing.Pool(processes=8) as pool: # Adjust number of processes as needed - results = pool.starmap(process_single_file, args) +# with multiprocessing.Pool(processes=8) as pool: # Adjust number of processes as needed +# results = pool.starmap(process_single_file, args) - # Filter valid results and log failed files - valid_results = [res for res in results if res is not None] - print(f"Valid results: {len(valid_results)}") - failed_files= [file for file, res in zip(batch, results) if res is None] +# # Filter valid results and log failed files +# valid_results = [res for res in results if res is not None] +# print(f"Valid results: {len(valid_results)}") +# failed_files= [file for file, res in zip(batch, results) if res is None] - return valid_results, failed_files +# return valid_results, failed_files def combine_arrays(valid_results, cumulative_obs, cumulative_lat, cumulative_lon, cumulative_count): @@ -502,6 +487,8 @@ def combine_arrays(valid_results, cumulative_obs, cumulative_lat, cumulative_lon return cumulative_obs, cumulative_lat, cumulative_lon, cumulative_count + + def compute_final_averages(cumulative_obs, cumulative_lat, cumulative_lon, cumulative_count): """ Compute the final averages from the cumulative arrays. In other words divide the cumulative_obs by the cumulative_count. 
@@ -518,6 +505,8 @@ def compute_final_averages(cumulative_obs, cumulative_lat, cumulative_lon, cumul return final_obs, final_lat, final_lon + + def append_and_concatenate(valid_results, day): """ Append and concatenate arrays @@ -546,6 +535,8 @@ def append_and_concatenate(valid_results, day): return final_obs, final_lat, final_lon, final_time + + def filter_observations(obs, lat, lon, count): """ Filter out observations with negative values @@ -557,4 +548,5 @@ def filter_observations(obs, lat, lon, count): final_obs = obs[valid_obs] final_count = count[valid_obs] - return final_obs, final_lat, final_lon, final_count \ No newline at end of file + return final_obs, final_lat, final_lon, final_count + diff --git a/mapies/util/multiprocessing.py b/mapies/util/multiprocessing.py new file mode 100644 index 00000000..6bbb60f4 --- /dev/null +++ b/mapies/util/multiprocessing.py @@ -0,0 +1,280 @@ +import time +import logging +import numpy as np +import pandas as pd +import xarray as xr + +from functools import wraps +from numpy import cos, sin, arctan, pi, nan +from glob import glob +from typing import List +from datetime import datetime +from abc import ABC, abstractmethod +import multiprocessing +from dataclasses import dataclass +from numpy.typing import DTypeLike, NDArray + +from mapies.util.func_tools import * +#timeit, get_file_list, time_converter, frequency_int, error_estimation, time_domain_selection, geo_to_rot +from mapies.grids.monarch import Grid, RotatedGrid, IrregularRotatedGrid, RegularGrid, regridding_function + + +def process_single_file_tropomi( + r, + file, + obs_var, + lat_var, + lon_var, + time_var, + start_date, + end_date, + apply_qa, + qualityFlags, + flag=False, + geospatial_crop=None + ): + """ + Process a single file and return aggregated results or flattened arrays. 
+    """
+    try:
+        ds = xr.open_dataset(file, engine="h5netcdf", group="PRODUCT")
+    except Exception as e:
+        print(f"Reading file {file}: {e}")
+        return None
+
+    # Flatten arrays
+    try:
+        obs, lat, lon, time_values = new_preprocess_vars(ds, obs_var, lat_var, lon_var, time_var, start_date, end_date)
+    except Exception as e:
+        print(f"Error processing file {file}: {e}")
+        return None
+
+    # Apply quality assurance if applicable
+    if apply_qa:
+        try:
+            obs, lat, lon, time_values = quality_assurance(ds, obs, lat, lon, time_values, qualityFlags)
+        except Exception as e:
+            print(f"Error applying QA Flags for file {file}: {e}")
+            return None
+    # Perform time domain selection between start and end date
+    try:
+        obs, lon, lat, time_values = time_domain_selection(obs, lat, lon, time_values, start_date, end_date)
+    except Exception as e:
+        print(f"Error performing time domain selection for file {file}: {e}")
+        return None
+
+    # Perform spatial domain selection if specified
+    if geospatial_crop is not None:
+        try:
+            obs, lat, lon, time_values = spatial_domain_selection(obs, lat, lon, time_values, geospatial_crop)
+        except Exception as e:
+            print(f"Error performing spatial domain selection for file {file}: {e}")
+            return None
+
+    if not flag:
+        try:
+            # Regrid and aggregate
+            if isinstance(r, RotatedGrid):
+                lon_agg, lat_agg, rlon_agg, rlat_agg, obs_agg, count_obs = regridding_function(r, lon, lat, obs)
+            elif isinstance(r, RegularGrid):
+                lon_agg, lat_agg, obs_agg, count_obs = regridding_function(r, lon, lat, obs)
+            ds.close()
+            return obs_agg, lat_agg, lon_agg, count_obs
+        except Exception as e:
+            print(f"Error performing regridding for file {file}: {e}")
+            return None
+    else:
+        ds.close()
+        return obs, lat, lon, time_values
+
+
+
+def process_single_file_viirs(
+    r,
+    file,
+    obs_var,
+    lat_var,
+    lon_var,
+    time_var,
+    start_date,
+    end_date,
+    apply_qa,
+    qualityFlags,
+    flag=False,
+    geospatial_crop=None
+    ):
+    """
+    Process a single file and return aggregated results or flattened arrays.
+    """
+    try:
+        ds = xr.open_dataset(file, engine="h5netcdf")
+    except Exception as e:
+        print(f"Reading file {file}: {e}")
+        return None
+
+    # Flatten arrays
+    try:
+        obs, lat, lon, time_values = new_preprocess_vars(ds, obs_var, lat_var, lon_var, time_var, start_date, end_date)
+    except Exception as e:
+        print(f"Error processing file {file}: {e}")
+        return None
+
+    # Apply quality assurance if applicable
+    if apply_qa:
+        try:
+            obs, lat, lon, time_values = quality_assurance(ds, obs, lat, lon, time_values, qualityFlags)
+        except Exception as e:
+            print(f"Error applying QA Flags for file {file}: {e}")
+            return None
+    # Perform time domain selection between start and end date
+    try:
+        obs, lon, lat, time_values = time_domain_selection(obs, lat, lon, time_values, start_date, end_date)
+    except Exception as e:
+        print(f"Error performing time domain selection for file {file}: {e}")
+        return None
+
+    # Perform spatial domain selection if specified
+    if geospatial_crop is not None:
+        try:
+            obs, lat, lon, time_values = spatial_domain_selection(obs, lat, lon, time_values, geospatial_crop)
+        except Exception as e:
+            print(f"Error performing spatial domain selection for file {file}: {e}")
+            return None
+
+    if not flag:
+        try:
+            # Regrid and aggregate
+            if isinstance(r, RotatedGrid):
+                lon_agg, lat_agg, rlon_agg, rlat_agg, obs_agg, count_obs = regridding_function(r, lon, lat, obs)
+            elif isinstance(r, RegularGrid):
+                lon_agg, lat_agg, obs_agg, count_obs = regridding_function(r, lon, lat, obs)
+            ds.close()
+            return obs_agg, lat_agg, lon_agg, count_obs
+        except Exception as e:
+            print(f"Error performing regridding for file {file}: {e}")
+            return None
+    else:
+        ds.close()
+        return obs, lat, lon, time_values
+
+
+
+class GenericMultiProcessing(ABC):
+    """
+    Base class for multiprocessing
+    """
+
+    @abstractmethod
+    def process_batch(self):
+        pass
+
+
+class VIIRSMultiprocessing(GenericMultiProcessing):
+    """
+    Multiprocessing method for VIIRS
+    """
+
+    r: Grid
+    obs_var: NDArray
+    lat_var: NDArray
+    lon_var: NDArray
+    time_var: NDArray
+    start_date: pd.Timestamp
+    end_date: pd.Timestamp
+    apply_qa: bool
+    qualityFlags: dict
+    batch: List[str]
+    geospatial_crop: NDArray = None
+    flag: bool = False
+
+
+
+    def process_batch(self):
+        """
+        Batch processing for VIIRS
+        """
+
+        args = [
+            (
+                self.r,
+                file,
+                self.obs_var,
+                self.lat_var,
+                self.lon_var,
+                self.time_var,
+                self.start_date,
+                self.end_date,
+                self.apply_qa,
+                self.qualityFlags,
+                self.flag,
+                self.geospatial_crop
+            )
+            for file in self.batch
+        ]
+
+
+
+        with multiprocessing.Pool(processes=8) as pool:  # Adjust number of processes as needed
+            results = pool.starmap(process_single_file_viirs, args)
+
+        # Filter valid results and log failed files
+        valid_results = [res for res in results if res is not None]
+        print(f"Valid results: {len(valid_results)}")
+        failed_files = [file for file, res in zip(self.batch, results) if res is None]
+
+        return valid_results, failed_files
+
+
+@dataclass
+class TROPOMIMultiprocessing(GenericMultiProcessing):
+    """
+    Multiprocessing method for TROPOMI
+    """
+    r: Grid
+    obs_var: NDArray
+    lat_var: NDArray
+    lon_var: NDArray
+    time_var: NDArray
+    start_date: pd.Timestamp
+    end_date: pd.Timestamp
+    apply_qa: bool
+    qualityFlags: dict
+    batch: List[str]
+    geospatial_crop: NDArray = None
+    flag: bool = False
+
+
+
+    def process_batch(self):
+        """
+        Batch processing for TROPOMI
+        """
+
+        args = [
+            (
+                self.r,
+                file,
+                self.obs_var,
+                self.lat_var,
+                self.lon_var,
+                self.time_var,
+                self.start_date,
+                self.end_date,
+                self.apply_qa,
+                self.qualityFlags,
+                self.flag,
+                self.geospatial_crop
+            )
+            for file in self.batch
+        ]
+
+
+        with multiprocessing.Pool(processes=8) as pool:  # Adjust number of processes as needed
+            results = pool.starmap(process_single_file_tropomi, args)
+
+        #
Filter valid results and log failed files + valid_results = [res for res in results if res is not None] + print(f"Valid results: {len(valid_results)}") + failed_files= [file for file, res in zip(batch, results) if res is None] + + return valid_results, failed_files -# return None class VIIRS(MAPIES): diff --git a/run/run_viirs.py b/run/run_viirs.py index 41eff997..989b1695 100644 --- a/run/run_viirs.py +++ b/run/run_viirs.py @@ -11,11 +11,11 @@ if __name__ == "__main__": - # outdir="/home/cmeikle/Projects/data/" - # indir="/home/cmeikle/Projects/data/VIIRS" + outdir="/home/cmeikle/Projects/data/" + indir="/home/cmeikle/Projects/data/VIIRS" - outdir="/home/cgile/Documents/mapies/figures" - indir="/home/cgile/Documents/mapies/VIIRS" + #outdir="/home/cgile/Documents/mapies/figures" + #indir="/home/cgile/Documents/mapies/VIIRS" start_time = time.time() -- GitLab From f75ab094e9bae6f921ddf976164f188376852640 Mon Sep 17 00:00:00 2001 From: Calum Meikle Date: Fri, 28 Feb 2025 13:14:29 +0100 Subject: [PATCH 2/4] Moved a lot of duplicate code to mapies.py --- mapies/mapies.py | 434 ++++++++++++++++++++++++++- mapies/tropomi.py | 471 +---------------------------- mapies/util/multiprocessing.py | 2 +- mapies/viirs.py | 526 +-------------------------------- run/run_tropomi.py | 4 +- run/run_viirs.py | 14 +- 6 files changed, 445 insertions(+), 1006 deletions(-) diff --git a/mapies/mapies.py b/mapies/mapies.py index baa8c75b..82a1e294 100644 --- a/mapies/mapies.py +++ b/mapies/mapies.py @@ -4,7 +4,7 @@ from functools import wraps from datetime import datetime, timedelta import sys -from mapies.util.func_tools import timeit, time_domain_selection, frequency_int, error_estimation +from mapies.util.func_tools import * from mapies.grids.monarch import RotatedGrid, IrregularRotatedGrid, RegularGrid import time import logging @@ -18,10 +18,14 @@ import cartopy import cartopy.crs as ccrs import os import yaml +import json import matplotlib.cm as cm from matplotlib.colors import ListedColormap import matplotlib.colors as mcolors +from mapies.util.variables_names import VARS_ATTRIBUTES +from numpy.typing import DTypeLike, NDArray + #logger = logging.getLogger(__name__) #logging.basicConfig(filename='../logs/mapies.log', level=logging.DEBUG) @@ -37,7 +41,6 @@ class MAPIES: """ self.np_typetime = "datetime64[s]" self.int_type = 'int64' - self.time_orig = np.datetime64("1900-01-01T00:00:00") self.start_date = start_date self.end_date = end_date self.apply_qa = True @@ -100,7 +103,432 @@ class MAPIES: self.qualityFlags = qa_dict["qa_variables"] + @timeit + def process_data(self, monthly_avg=False, batch_size=100, save=False, apply_qa=True, geospatial_crop: NDArray | None = None): + + """ + Process the data. + + Parameters:geospatial_crop: NDArray | None = None + - monthly_avg: bool, whether to compute single monthly averages. + - batch_size: int, the number of files to process in each batch. + + Returns: + - if monthly_avg --> updated monthly dictorionary for the processed month. + - if not monthly_avg --> updated self.lon_values, self.lat_values, self.obs, self.count_obs. 
+ """ + + #Check if the user wants qa flags + if apply_qa: + self.apply_qa = True + else: + self.apply_qa = False + if geospatial_crop is not None: + self.geospatial_crop = np.array(geospatial_crop) + if monthly_avg: + for month in self.monthly_dict.keys(): + + # Catch if they haven't selected a grid + try: + print(f"Running average on {self.grid}") + except AttributeError: + raise Exception("Please specify a grid in the initialisation of MAPIES, using grid_repr=''") + try: + # Process the data for the current month + final_lon, final_lat, final_obs, cumulative_count = self.process_monthly_data( + batch_size=batch_size, month=month + ) + + self.monthly_dict[month] = { + "lon": final_lon, + "lat": final_lat, + "obs": final_obs, + "count": cumulative_count, + } + print(f'Updated monthly dictionary for month {month}') + print(f'Lon shape: {final_lon.shape}, Lat shape: {final_lat.shape}, Obs shape: {final_obs.shape}') + if np.all(final_lon == 0) or np.all(np.isnan(final_lon)): + print(f"All longitude values for month {month} are 0 or NaN.") + if np.all(final_lat == 0) or np.all(np.isnan(final_lat)): + print(f"All latitude values for month {month} are 0 or NaN.") + if np.all(final_obs == 0) or np.all(np.isnan(final_obs)): + print(f"All observation values for month {month} are 0 or NaN.") + + except Exception as e: + print(f"Error processing data for month {month}: {e}") + continue + else: + try: + final_lon, final_lat, final_obs, cumulative_count = self.process_monthly_data( + batch_size=batch_size, month=None + ) + + self.obs, self.lat, self.lon, self.count_obs = filter_observations(final_obs, final_lat, final_lon, cumulative_count) + + + if save: + try: + mid_time_array = generate_mid_time_array(self.start_date, self.end_date, self.obs) + + self.to_netCDF(self.obs, self.lat, self.lon, mid_time_array, self.start_date, self.end_date) + print(f'Saved processed data to netCDF') + except Exception as e: + print(f"Error saving data into netCDF: {e}") + return + + except Exception as e: + print(f"Error processing data for the time period from {self.start_date} to {self.end_date}: {e}") + return + + + def process_monthly_data(self, batch_size, month=None): + + """ + Process the data with multiprocessing. + """ + flag = False + if month is None: + month_files = self.files + print(f'Processing all files from {self.start_date} to {self.end_date}') + else: + month_files = self.monthly_dict[month]["files"] + print(f"Processing {len(month_files)} files for month {month}") + + + # Split files into batches + batches = create_batches(month_files, batch_size) + + # Initialize cumulative arrays + cumulative_obs = None + cumulative_lat = None + cumulative_lon = None + cumulative_count = None + failed_files = [] + + for batch in batches: + print(f"Processing a batch of {len(batch)} files...") + + # Use multiprocessing to process files within the batch + + viirs_mp = self.mp_type( + self.grid, + self.obs_var, + self.lat_var, + self.lon_var, + self.time_var, + self.start_date, + self.end_date, + self.apply_qa, + self.qualityFlags, + batch, + self.geospatial_crop, + flag, + ) + valid_results, failed = viirs_mp.process_batch() + failed_files.extend(failed) + + if not valid_results: + print("All files in this batch failed. 
Skipping...") + continue + + # Aggregate results incrementally + cumulative_obs, cumulative_lat, cumulative_lon, cumulative_count = combine_arrays(valid_results, cumulative_obs, cumulative_lat, cumulative_lon, cumulative_count) + + if failed_files: + print(f"The following files failed to process: {failed_files}") + + # Check if any valid data was processed + if cumulative_obs is None or cumulative_count is None: + print(f"No valid data processed for month {month}. Returning empty arrays.") + return None, None, None, None + # Compute final averages + final_obs, final_lat, final_lon = compute_final_averages(cumulative_obs, cumulative_lat, cumulative_lon, cumulative_count) + + print(f"Completed processing for month {month}.") + + return final_lon, final_lat, final_obs, cumulative_count + + + @timeit + def yearly_average(self): + """ + Compute the yearly average for the processed data. + """ + cumulative_obs = None + cumulative_lat = None + cumulative_lon = None + cumulative_count = None + + for key, month_data in self.monthly_dict.items(): + if not all(var in month_data for var in ["lon", "lat", "obs", "count"]): + print(f"Skipping month {key}: Missing data.") + continue + if any(month_data[var] is None for var in ["lon", "lat", "obs", "count"]): + print(f"Skipping month {key}: No valid data.") + continue + + lon_values = month_data["lon"] + lat_values = month_data["lat"] + obs_values = month_data["obs"] + count_values = month_data["count"] + + results = [(obs_values, lat_values, lon_values, count_values)] + cumulative_obs, cumulative_lat, cumulative_lon, cumulative_count = combine_arrays(results, cumulative_obs, cumulative_lat, cumulative_lon, cumulative_count) + + if cumulative_obs is None or cumulative_count is None or np.all(cumulative_count == 0): + print("No valid data for the entire year. Returning empty arrays.") + self.yearly_lon = None + self.yearly_lat = None + self.yearly_obs = None + self.yearly_count = None + return + + yearly_obs, yearly_lat, yearly_lon = compute_final_averages(cumulative_obs, cumulative_lat, cumulative_lon, cumulative_count) + + self.yearly_lon = yearly_lon + self.yearly_lat = yearly_lat + self.yearly_obs = yearly_obs + self.yearly_count = cumulative_count + + print("Yearly average computation completed.") + + + @timeit + def process_lazy_data(self, apply_qa=True, geospatial_crop: NDArray | None = None, save=True): + """ + Process the data for the specified time range. It only retrieves a daily dictionary using dask lazy loading. + """ + flag = True + #Check if the user wants qa flags + if apply_qa: + self.apply_qa = True + else: + self.apply_qa = False + if geospatial_crop is not None: + self.geospatial_crop = np.array(geospatial_crop) + for day in self.daily_dict.keys(): + try: + files = self.daily_dict[day]["files"] + + # Use multiprocessing to process files within the batch + + viirs_mp = self.mp_type( + self.grid, + self.obs_var, + self.lat_var, + self.lon_var, + self.time_var, + self.start_date, + self.end_date, + self.apply_qa, + self.qualityFlags, + files, + self.geospatial_crop, + flag, + ) + valid_results, failed_files = viirs_mp.process_batch() + + final_obs, final_lat, final_lon, final_time = append_and_concatenate(valid_results, day) + + # Skip the day if no valid data was processed + if final_obs is None: + continue + + # Check if any valid data exists + if np.all(final_obs == 0) or np.all(np.isnan(final_obs)): + print(f"All observation values for day {day} are 0 or NaN. 
Skipping...") + continue + + # Update day dict + self.daily_dict[day]["obs"] = final_obs + self.daily_dict[day]["lon"] = final_lon + self.daily_dict[day]["lat"] = final_lat + self.daily_dict[day]["time"] = final_time + + # Get day start and day end but only if they are within object start and end date + day_start, day_end = day_boundaries(day, self.start_date, self.end_date) + if save: + self.to_netCDF(final_obs, final_lat, final_lon, final_time, day_start, day_end) + + except Exception as e: + print(f"Error processing data for day {day}: {e}") + continue + + + def to_netCDF(self, obs, lat, lon, time, start_time, end_time): + try: + # Convert time to seconds since TIME_ORIGIN + final_time_seconds = (time - self.time_orig) / np.timedelta64(1, "s") + final_time_seconds = np.round(final_time_seconds, 0) + + grid_representation_str = json.dumps(self.grid_config, ensure_ascii=False) + qa_flags_str = json.dumps(self.qualityFlags, ensure_ascii=False) + + # Create xarray Dataset + ds = xr.Dataset( + coords={"time": ("time", final_time_seconds, { + "units": "seconds since 1900-01-01 00:00:00", + "calendar": "gregorian" + })}, + + data_vars={ + "lon": (["time"], lon, VARS_ATTRIBUTES['lon'].copy()), + "lat": (["time"], lat, VARS_ATTRIBUTES['lat'].copy()), + VARS_ATTRIBUTES[self.obs_var]["mapies_variable"]: (["time"], obs, VARS_ATTRIBUTES[self.obs_var].copy()), + }, + attrs={ + "title": f"{VARS_ATTRIBUTES['title']['description']} from {start_time} to {end_time}", + "institution": "Barcelona Supercomputing Center", + "grid": f"{self.grid_repr}: {grid_representation_str}", + "Developers": "MAPIES team", + "QA": f"Quality Assurance: {self.apply_qa}, quality assurance applied: {qa_flags_str}", + "history": f"Created {datetime.now()}", + }, + ) + + # Define encoding for compact storage + encoding = { + "time": {"dtype": "float32"}, + "lon": {"dtype": "float32"}, + "lat": {"dtype": "float32"}, + VARS_ATTRIBUTES[self.obs_var]["mapies_variable"]: {"dtype": "float32"}, + } + + # Save to NetCDF file + filename = f"{self.dest}/Data_{self.datatype}_{start_time}_{end_time}.nc" + ds.to_netcdf(filename, mode='w', encoding=encoding) + + print(f"Saved data to {filename}") + + except Exception as e: + print(f"Error saving data into netCDF: {e}") + return + + +# ============================================================================= +# Plotting +# ============================================================================= + def plot_2D_observations(self, months=[0], filename=None, outdir=None): + + if months == [0]: + try: + lon_values = self.lon + lat_values = self.lat + obs_values = self.obs + except Exception as e: + print(f"Error plotting all data: {e}") + return + title = f"Observation 2D plot from {self.start_date} to {self.end_date}" + filename = filename or f"{self.datatype}_2D_obs_from_{self.start_date}_to_{self.end_date}.png" + self.plot_2D_obs_custom( + lon_values=lon_values, + lat_values=lat_values, + obs_values=obs_values, + outdir=outdir, + title=title, + filename=filename, + ) + return + elif months is not None: + for m in months: + try: + month_data = self.monthly_dict[m] + lon_values = month_data["lon"] + lat_values = month_data["lat"] + obs_values = month_data["obs"] + except Exception as e: + print(f"Error for month {m}: {e}") + continue + title = f"Observation 2D plot of month {m} from {self.start_date} to {self.end_date}" + filename = filename or f"{self.datatype}_2D_obs_month_{m}_year_{self.year}.png" + super().plot_2D_obs_custom( + lon_values=lon_values, + lat_values=lat_values, + 
obs_values=obs_values, + outdir=outdir, + title=title, + filename=filename, + ) + else: + try: + lon_values = self.yearly_lon + lat_values = self.yearly_lat + obs_values = self.yearly_obs + except Exception as e: + print(f"Error for yearly data: {e}") + return + title = f"Observation 2D plot of yearly average of year {self.year}" + filename = filename or f"{self.datatype}_2D_yearly_obs_{self.year}.png" + self.plot_2D_obs_custom( + lon_values=lon_values, + lat_values=lat_values, + obs_values=obs_values, + outdir=outdir, + title=title, + filename=filename, + ) + + def plot_2D_num_obs(self, months=[0], filename=None, outdir=None): + + if months == [0]: + try: + lon_values = self.lon + lat_values = self.lat + cum_count = self.count_obs + except Exception as e: + print(f"Error plotting all data: {e}") + return + title = f"Plot of the number of valid observations from {self.start_date} to {self.end_date}" + filename = filename or f"{self.datatype}_2D_obs_count_from_{self.start_date}_to_{self.end_date}.png" + self.plot_2D_num_obs_custom( + lon_values=lon_values, + lat_values=lat_values, + num_obs=cum_count, + outdir=outdir, + title=title, + filename=filename, + ) + return + elif months is not None: + for m in months: + try: + month_data = self.monthly_dict[m] + lon_values = month_data["lon"] + lat_values = month_data["lat"] + cum_count = month_data["count"] + except Exception as e: + print(f"Error for month {m}: {e}") + continue + title = f"Plot of the number of valid observations for month {m} from {self.start_date} to {self.end_date}" + filename = filename or f"{self.datatype}_2D_obs_count_month_{m}_year_{self.year}.png" + self.plot_2D_num_obs_custom( + lon_values=lon_values, + lat_values=lat_values, + num_obs=cum_count, + outdir=outdir, + title=title, + filename=filename, + ) + else: + try: + lon_values = self.yearly_lon + lat_values = self.yearly_lat + cum_count = self.yearly_count + except Exception as e: + print(f"Error for yearly data: {e}") + return + title = f"Plot of the number of valid observations for yearly average of {self.year}" + filename = filename or f"{outdir}/{self.datatype}_2D_obs_count_yearly_{self.year}.png" + self.plot_2D_num_obs_custom( + lon_values=lon_values, + lat_values=lat_values, + num_obs=cum_count, + outdir=outdir, + title=title, + filename=filename, + ) + # TODO this can be moved to a plotting function python file @timeit @@ -236,4 +664,4 @@ class MAPIES: attrs=dict() ds = xr.Dataset(data_vars=data_vars, coords=coords, attrs=attrs) - return ds + return ds \ No newline at end of file diff --git a/mapies/tropomi.py b/mapies/tropomi.py index af9d9e8b..b1c5a39c 100644 --- a/mapies/tropomi.py +++ b/mapies/tropomi.py @@ -17,7 +17,6 @@ import os from mapies.util.func_tools import * #timeit, get_file_list, time_converter, frequency_int, error_estimation, time_domain_selection, geo_to_rot from mapies.grids.monarch import RotatedGrid, IrregularRotatedGrid, RegularGrid, regridding_function -from mapies.util.variables_names import VARS_ATTRIBUTES from mapies.util.multiprocessing import TROPOMIMultiprocessing import json from numpy.typing import DTypeLike, NDArray @@ -27,8 +26,6 @@ os.environ["OMP_NUM_THREADS"] = "8" #logger = logging.getLogger(__name__) #logging.basicConfig(filename='../logs/mapies.log', level=logging.INFO) -TIME_ORIGIN = np.datetime64("1900-01-01T00:00:00") - @@ -43,7 +40,7 @@ class TROPOMI(MAPIES): """ super().__init__(start_date, end_date) - self.time_orig = np.datetime64("1993-01-01T00:00:00") + self.time_orig = np.datetime64("1900-01-01T00:00:00") 
self.dest = kwargs.get("dest") if not os.path.exists(self.dest): raise Exception("Output directory doesn't exist") @@ -53,6 +50,7 @@ class TROPOMI(MAPIES): self.monthly_dict = {} self.daily_dict = {} # Dictionary to produce daily netcdf files self.datatype = "tropomi" + self.mp_type = TROPOMIMultiprocessing # Add quality value filter number @@ -144,468 +142,3 @@ class TROPOMI(MAPIES): # Print the daily disctionary to see if files are stored in chronological order for day, data in self.daily_dict.items(): print(f"Day {day:02d}: {len(data['files'])} files") - - - @timeit - def process_data(self, monthly_avg=False, batch_size=100, apply_qa=True, geospatial_crop: NDArray | None = None): - """ - Process the data for the specified year and months. - - Parameters: - - monthly_avg: bool, whether to compute single monthly averages. - - batch_size: int, the number of files to process in each batch. - - chunk_size: int, the size of chunks for processing within each batch. - - geospatial_crop, 2D numpy array - - Returns: - - if monthly_avg --> lon_final, lat_final, obs_final: Aggregated longitude, latitude, and observation arrays. - - if not monthly_avg --> preprocess_vars() is called. - """ - #Check if the user wants qa flags - if apply_qa: - self.apply_qa = True - if geospatial_crop is not None: - self.geospatial_crop = np.array(geospatial_crop) - if monthly_avg: - for month in self.monthly_dict.keys(): - # print(f'Month: {self.monthly_dict[month]}') - # Catch if they haven't selected a grid - try: - print(f"Running average on {self.grid}") - except AttributeError: - raise Exception("Please specify a grid in the initialisation of MAPIES, using grid_repr=''") - try: - # Process the data for the current month - final_lon, final_lat, final_obs, cumulative_count = self.process_monthly_data( - batch_size=batch_size, month=month - ) - - # Update the monthly dictionary - valid_obs = final_obs > 0 - self.monthly_dict[month] = { - "lon": final_lon[valid_obs], - "lat": final_lat[valid_obs], - "obs": final_obs[valid_obs], - "count": cumulative_count[valid_obs], - } - print(f'Updated monthly dictionary for month {month}') - print(f'Lon shape: {final_lon.shape}, Lat shape: {final_lat.shape}, Obs shape: {final_obs.shape}') - if np.all(final_lon == 0) or np.all(np.isnan(final_lon)): - print(f"All longitude values for month {month} are 0 or NaN.") - if np.all(final_lat == 0) or np.all(np.isnan(final_lat)): - print(f"All latitude values for month {month} are 0 or NaN.") - if np.all(final_obs == 0) or np.all(np.isnan(final_obs)): - print(f"All observation values for month {month} are 0 or NaN.") - - except Exception as e: - print(f"Error processing data for month {month}: {e}") - continue - else: - try: - final_lon, final_lat, final_obs, cumulative_count = self.process_monthly_data( - batch_size=batch_size, month=None - ) - valid_obs = final_obs > 0 - self.lon = final_lon[valid_obs] - self.lat = final_lat[valid_obs] - self.obs = final_obs[valid_obs] - self.count_obs = cumulative_count[valid_obs] - - except Exception as e: - print(f"Error processing data for the time period from {self.start_date} to {self.end_date}: {e}") - return - - def process_monthly_data(self, batch_size, month=None): - - """ - Function that returns all the needed variables for the DA - """ - # r = RotatedGrid(centre_lon=20, centre_lat=35, dlon=.1, dlat=.1, west=-51, south=-35) - flag = False - if month is None: - month_files = self.files - print(f'Processing all files from {self.start_date} to {self.end_date}') - else: - month_files = 
self.monthly_dict[month]["files"] - print(f"Processing {len(month_files)} files for month {month}") - - - # Split files into batches - batches = [ - month_files[i:i + batch_size] for i in range(0, len(month_files), batch_size) - ] - - # Initialize cumulative arrays - cumulative_obs = None - cumulative_lat = None - cumulative_lon = None - cumulative_count = None - failed_files = [] - - for batch in batches: - print(f"Processing a batch of {len(batch)} files...") - print(f"Batch: {batch}") - - # Use multiprocessing to process files within the batch - - tropomi_mp = TROPOMIMultiprocessing( - self.grid, - self.obs_var, - self.lat_var, - self.lon_var, - self.time_var, - self.start_date, - self.end_date, - self.apply_qa, - self.qualityFlags, - batch, - self.geospatial_crop, - flag, - ) - valid_results, failed = tropomi_mp.process_batch() - failed_files.extend(failed) - - if not valid_results: - print("All files in this batch failed. Skipping...") - continue - - # Aggregate results incrementally - for obs_agg, lat_agg, lon_agg, count_obs in valid_results: - if cumulative_obs is None: - # Initialize cumulative arrays - cumulative_obs = obs_agg * count_obs - cumulative_lat = lat_agg * count_obs - cumulative_lon = lon_agg * count_obs - cumulative_count = count_obs - else: - cumulative_obs += obs_agg * count_obs - cumulative_lat += lat_agg * count_obs - cumulative_lon += lon_agg * count_obs - cumulative_count += count_obs - - if failed_files: - print(f"The following files failed to process: {failed_files}") - - # Check if any valid data was processed - if cumulative_obs is None or cumulative_count is None: - print(f"No valid data processed for month {month}. Returning empty arrays.") - return None, None, None, None - - # Compute final averages - valid_mask = cumulative_count > 0 - final_obs = np.zeros_like(cumulative_obs) - final_lat = np.zeros_like(cumulative_lat) - final_lon = np.zeros_like(cumulative_lon) - - final_obs[valid_mask] = cumulative_obs[valid_mask] / cumulative_count[valid_mask] - final_lat[valid_mask] = cumulative_lat[valid_mask] / cumulative_count[valid_mask] - final_lon[valid_mask] = cumulative_lon[valid_mask] / cumulative_count[valid_mask] - - print(f"Completed processing for month {month}.") - - return final_lon, final_lat, final_obs, cumulative_count - - @timeit - def yearly_average(self): - """ - Compute the yearly average for the processed data. - """ - cumulative_obs = None - cumulative_lat = None - cumulative_lon = None - cumulative_count = None - - for key, month_data in self.monthly_dict.items(): - if not all(var in month_data for var in ["lon", "lat", "obs", "count"]): - print(f"Skipping month {key}: Missing data.") - continue - if any(month_data[var] is None for var in ["lon", "lat", "obs", "count"]): - print(f"Skipping month {key}: No valid data.") - continue - - lon_values = month_data["lon"] - lat_values = month_data["lat"] - obs_values = month_data["obs"] - count_values = month_data["count"] - - if cumulative_obs is None: - cumulative_obs = obs_values * count_values - cumulative_lat = lat_values * count_values - cumulative_lon = lon_values * count_values - cumulative_count = count_values - else: - cumulative_obs += obs_values * count_values - cumulative_lat += lat_values * count_values - cumulative_lon += lon_values * count_values - cumulative_count += count_values - - if cumulative_obs is None or cumulative_count is None or np.all(cumulative_count == 0): - print("No valid data for the entire year. 
Returning empty arrays.") - self.yearly_lon = None - self.yearly_lat = None - self.yearly_obs = None - self.yearly_count = None - return - - valid_mask = cumulative_count > 0 - yearly_obs = np.zeros_like(cumulative_obs) - yearly_lat = np.zeros_like(cumulative_lat) - yearly_lon = np.zeros_like(cumulative_lon) - - yearly_obs[valid_mask] = cumulative_obs[valid_mask] / cumulative_count[valid_mask] - yearly_lat[valid_mask] = cumulative_lat[valid_mask] / cumulative_count[valid_mask] - yearly_lon[valid_mask] = cumulative_lon[valid_mask] / cumulative_count[valid_mask] - - self.yearly_lon = yearly_lon - self.yearly_lat = yearly_lat - self.yearly_obs = yearly_obs - self.yearly_count = cumulative_count - - print("Yearly average computation completed.") - - @timeit - def process_lazy_data(self, apply_qa=True, geospatial_crop: NDArray | None = None, save=True): - """ - Process the data for the specified time range. It only retrieves a daily dictionary using dask lazy loading. - """ - flag = True - #Check if the user wants qa flags - if apply_qa: - self.apply_qa = True - else: - self.apply_qa = False - if geospatial_crop is not None: - self.geospatial_crop = np.array(geospatial_crop) - for day in self.daily_dict.keys(): - try: - files = self.daily_dict[day]["files"] - args = [ - ( - self.grid, - file, - self.obs_var, - self.lat_var, - self.lon_var, - self.time_var, - self.start_date, - self.end_date, - self.apply_qa, - self.qualityFlags, - flag, - self.geospatial_crop - - ) - for file in files - ] - - daily_obs = [] - daily_lat = [] - daily_lon = [] - daily_time = [] - - with multiprocessing.Pool(processes=8) as pool: # Adjust number of processes as needed - results = pool.starmap(process_single_file, args) - - valid_results = [res for res in results if res is not None] - # Check shapes of obs lat and lon to see if they are all the same - - for obs, lat, lon, time_values in valid_results: - #print(f'Shape of obs: {obs.shape}, Shape of lat: {lat.shape}, Shape of lon: {lon.shape}, Shape of time: {time_values.shape}') - if len(obs) > 0: - daily_lon.append(lon) - daily_lat.append(lat) - daily_obs.append(obs) - daily_time.append(time_values) - - if not daily_obs: - print(f"No valid data for day {day}") - continue - - final_obs = np.concatenate(daily_obs, axis=0) - final_lon = np.concatenate(daily_lon, axis=0) - final_lat = np.concatenate(daily_lat, axis=0) - final_time = np.concatenate(daily_time, axis=0) - - # Check if any valid data exists - if np.all(final_obs == 0) or np.all(np.isnan(final_obs)): - print(f"All observation values for day {day} are 0 or NaN. 
Skipping...") - continue - # Get day start and day end but only if they are within object start and end date - day_start, day_end = day_boundaries(day, self.start_date, self.end_date) - if save: - self.to_netCDF(final_obs, final_lat, final_lon, final_time, day_start, day_end) - except Exception as e: - print(f"Error processing data for day {day}: {e}") - continue - - - def to_netCDF(self, obs, lat, lon, time, start_time, end_time): - try: - # Convert time to seconds since TIME_ORIGIN - final_time_seconds = (time - TIME_ORIGIN) / np.timedelta64(1, "s") - final_time_seconds = np.round(final_time_seconds, 0) - - grid_representation_str = json.dumps(self.grid_config, ensure_ascii=False) - qa_flags_str = json.dumps(self.qualityFlags, ensure_ascii=False) - - # Create xarray Dataset - ds = xr.Dataset( - coords={"time": ("time", final_time_seconds, { - "units": "seconds since 1900-01-01 00:00:00", - "calendar": "gregorian" - })}, - - data_vars={ - "lon": (["time"], lon, VARS_ATTRIBUTES['lon'].copy()), - "lat": (["time"], lat, VARS_ATTRIBUTES['lat'].copy()), - VARS_ATTRIBUTES[self.obs_var]["mapies_variable"]: (["time"], obs, VARS_ATTRIBUTES[self.obs_var].copy()), - }, - attrs={ - "title": f"{VARS_ATTRIBUTES['title']['description']} from {start_time} to {end_time}", - "institution": "Barcelona Supercomputing Center", - "grid": f"{self.grid_repr}: {grid_representation_str}", - "Developers": "MAPIES team", - "QA": f"Quality Assurance: {self.apply_qa}, quality assurance applied: {qa_flags_str}", - "history": f"Created {datetime.now()}", - }, - ) - - # Define encoding for compact storage - encoding = { - "time": {"dtype": "float32"}, - "lon": {"dtype": "float32"}, - "lat": {"dtype": "float32"}, - VARS_ATTRIBUTES[self.obs_var]["mapies_variable"]: {"dtype": "float32"}, - } - - # Save to NetCDF file - filename = f"{self.dest}/Data_{self.datatype}_{start_time}_{end_time}.nc" - ds.to_netcdf(filename, mode='w', encoding=encoding) - - print(f"Saved data to {filename}") - - except Exception as e: - print(f"Error saving data into netCDF: {e}") - return - - - - def plot_2D_observations(self, months=[0], filename=None, outdir=None): - - if months == [0]: - try: - lon_values = self.lon - lat_values = self.lat - obs_values = self.obs - except Exception as e: - print(f"Error plotting all data: {e}") - return - title = f"Observation 2D plot from {self.start_date} to {self.end_date}" - filename = filename or f"{self.datatype}_2D_obs_from_{self.start_date}_to_{self.end_date}.png" - super().plot_2D_obs_custom( - lon_values=lon_values, - lat_values=lat_values, - obs_values=obs_values, - outdir=outdir, - title=title, - filename=filename, - ) - return - elif months is not None: - for m in months: - try: - month_data = self.monthly_dict[m] - lon_values = month_data["lon"] - lat_values = month_data["lat"] - obs_values = month_data["obs"] - except Exception as e: - print(f"Error for month {m}: {e}") - continue - title = f"Observation 2D plot of month {m} from {self.start_date} to {self.end_date}" - filename = filename or f"{self.datatype}_2D_obs_month_{m}_year_{self.year}.png" - super().plot_2D_obs_custom( - lon_values=lon_values, - lat_values=lat_values, - obs_values=obs_values, - outdir=outdir, - title=title, - filename=filename, - ) - else: - try: - lon_values = self.yearly_lon - lat_values = self.yearly_lat - obs_values = self.yearly_obs - except Exception as e: - print(f"Error for yearly data: {e}") - return - title = f"Observation 2D plot of yearly average of year {self.year}" - filename = filename or 
f"{self.datatype}_2D_yearly_obs_{self.year}.png" - super().plot_2D_obs_custom( - lon_values=lon_values, - lat_values=lat_values, - obs_values=obs_values, - outdir=outdir, - title=title, - filename=filename, - ) - - def plot_2D_num_obs(self, months=[0], filename=None, outdir=None): - - if months == [0]: - try: - lon_values = self.lon - lat_values = self.lat - cum_count = self.count_obs - except Exception as e: - print(f"Error plotting all data: {e}") - return - title = f"Plot of the number of valid observations from {self.start_date} to {self.end_date}" - filename = filename or f"{self.datatype}_2D_obs_count_from_{self.start_date}_to_{self.end_date}.png" - super().plot_2D_num_obs_custom( - lon_values=lon_values, - lat_values=lat_values, - num_obs=cum_count, - outdir=outdir, - title=title, - filename=filename, - ) - return - elif months is not None: - for m in months: - try: - month_data = self.monthly_dict[m] - lon_values = month_data["lon"] - lat_values = month_data["lat"] - cum_count = month_data["count"] - except Exception as e: - print(f"Error for month {m}: {e}") - continue - title = f"Plot of the number of valid observations for month {m} from {self.start_date} to {self.end_date}" - filename = filename or f"{self.datatype}_2D_obs_count_month_{m}_year_{self.year}.png" - super().plot_2D_num_obs_custom( - lon_values=lon_values, - lat_values=lat_values, - num_obs=cum_count, - outdir=outdir, - title=title, - filename=filename, - ) - else: - try: - lon_values = self.yearly_lon - lat_values = self.yearly_lat - cum_count = self.yearly_count - except Exception as e: - print(f"Error for yearly data: {e}") - return - title = f"Plot of the number of valid observations for yearly average of {self.year}" - filename = filename or f"{outdir}/{self.datatype}_2D_obs_count_yearly_{self.year}.png" - super().plot_2D_num_obs_custom( - lon_values=lon_values, - lat_values=lat_values, - num_obs=cum_count, - outdir=outdir, - title=title, - filename=filename, - ) diff --git a/mapies/util/multiprocessing.py b/mapies/util/multiprocessing.py index 6bbb60f4..a844e172 100644 --- a/mapies/util/multiprocessing.py +++ b/mapies/util/multiprocessing.py @@ -169,11 +169,11 @@ class GenericMultiProcessing(ABC): pass +@dataclass class VIIRSMultiprocessing(GenericMultiProcessing): """ Multiprocessing method for VIIRS """ - r: Grid obs_var: NDArray lat_var: NDArray diff --git a/mapies/viirs.py b/mapies/viirs.py index 27bba79f..0d95768e 100644 --- a/mapies/viirs.py +++ b/mapies/viirs.py @@ -29,79 +29,13 @@ from mapies.util.variables_names import VARS_ATTRIBUTES from pathlib import Path import json from numpy.typing import DTypeLike, NDArray +from mapies.util.multiprocessing import VIIRSMultiprocessing os.environ["OMP_NUM_THREADS"] = "8" #logger = logging.getLogger(__name__) #logging.basicConfig(filename='../logs/mapies.log', level=logging.INFO) -TIME_ORIGIN = np.datetime64("1900-01-01T00:00:00") - - - -def process_single_file( - r, - file, - obs_var, - lat_var, - lon_var, - time_var, - start_date, - end_date, - apply_qa, - qualityFlags, - flag=False, - geospatial_crop=None - ): - """ - Process a single file and return aggregated results or flattened arrays. 
- """ - try: - ds = xr.open_dataset(file, engine="h5netcdf") - - obs, lat, lon, time_values = new_preprocess_vars(ds, obs_var, lat_var, lon_var, time_var, start_date, end_date) - - #Run a time domain selection - if apply_qa: - obs, lat, lon, time_values = quality_assurance(ds, obs, lat, lon, time_values, qualityFlags) - - obs, lon, lat, time_values = time_domain_selection(obs, lat, lon, time_values, start_date, end_date) - - if geospatial_crop is not None: - obs, lat, lon, time_values = spatial_domain_selection(obs, lat, lon, time_values, geospatial_crop) - - if not flag: - # Regrid and aggregate - if isinstance(r, RotatedGrid): - lon_agg, lat_agg, rlon_agg, rlat_agg, obs_agg, count_obs = regridding_function(r, lon, lat, obs) - elif isinstance(r, RegularGrid): - lon_agg, lat_agg, obs_agg, count_obs = regridding_function(r, lon, lat, obs) - ds.close() - return obs_agg, lat_agg, lon_agg, count_obs - else: - ds.close() - return obs, lat, lon, time_values - - except Exception as e: - print(f"Error processing file {file}: {e}") - - return None - - - -def process_batch(args, batch): - - with multiprocessing.Pool(processes=8) as pool: # Adjust number of processes as needed - results = pool.starmap(process_single_file, args) - - # Filter valid results and log failed files - valid_results = [res for res in results if res is not None] - print(f"Valid results: {len(valid_results)}") - failed_files= [file for file, res in zip(batch, results) if res is None] - - return valid_results, failed_files - - class VIIRS(MAPIES): """ @@ -116,6 +50,7 @@ class VIIRS(MAPIES): self.time_orig = np.datetime64("1993-01-01T00:00:00") self.datatype="viirs" + self.mp_type = VIIRSMultiprocessing self.dest = kwargs.get("dest") if not os.path.exists(self.dest): raise Exception("Output directory doesn't exist") @@ -221,327 +156,6 @@ class VIIRS(MAPIES): print(f"Day {day:02d}: {len(data['files'])} files") - - @timeit - def process_data(self, monthly_avg=False, batch_size=100, save=False, apply_qa=True, geospatial_crop: NDArray | None = None): - - """ - Process the data. - - Parameters:geospatial_crop: NDArray | None = None - - monthly_avg: bool, whether to compute single monthly averages. - - batch_size: int, the number of files to process in each batch. - - Returns: - - if monthly_avg --> updated monthly dictorionary for the processed month. - - if not monthly_avg --> updated self.lon_values, self.lat_values, self.obs, self.count_obs. 
- """ - - #Check if the user wants qa flags - - if apply_qa: - self.apply_qa = True - else: - self.apply_qa = False - if geospatial_crop is not None: - self.geospatial_crop = np.array(geospatial_crop) - if monthly_avg: - for month in self.monthly_dict.keys(): - # print(f'Month: {self.monthly_dict[month]}') - # Catch if they haven't selected a grid - try: - print(f"Running average on {self.grid}") - except AttributeError: - raise Exception("Please specify a grid in the initialisation of MAPIES, using grid_repr=''") - try: - # Process the data for the current month - final_lon, final_lat, final_obs, cumulative_count = self.process_monthly_data( - batch_size=batch_size, month=month - ) - - # Update the monthly dictionary - # valid_obs = final_obs > 0 - # self.monthly_dict[month] = { - # "lon": final_lon[valid_obs], - # "lat": final_lat[valid_obs], - # "obs": final_obs[valid_obs], - # "count": cumulative_count[valid_obs], - # } - self.monthly_dict[month] = { - "lon": final_lon, - "lat": final_lat, - "obs": final_obs, - "count": cumulative_count, - } - print(f'Updated monthly dictionary for month {month}') - print(f'Lon shape: {final_lon.shape}, Lat shape: {final_lat.shape}, Obs shape: {final_obs.shape}') - if np.all(final_lon == 0) or np.all(np.isnan(final_lon)): - print(f"All longitude values for month {month} are 0 or NaN.") - if np.all(final_lat == 0) or np.all(np.isnan(final_lat)): - print(f"All latitude values for month {month} are 0 or NaN.") - if np.all(final_obs == 0) or np.all(np.isnan(final_obs)): - print(f"All observation values for month {month} are 0 or NaN.") - - except Exception as e: - print(f"Error processing data for month {month}: {e}") - continue - else: - try: - final_lon, final_lat, final_obs, cumulative_count = self.process_monthly_data( - batch_size=batch_size, month=None - ) - - self.obs, self.lat, self.lon, self.count_obs = filter_observations(final_obs, final_lat, final_lon, cumulative_count) - # valid_obs = final_obs > 0 - # self.lon = final_lon[valid_obs] - # self.lat = final_lat[valid_obs] - # self.obs = final_obs[valid_obs] - # self.count_obs = cumulative_count[valid_obs] - - if save: - try: - # mid_time = self.start_date + (self.end_date - self.start_date) / 2 - # mid_time = np.datetime64(mid_time) - # N = len(self.obs) - # mid_time_array = np.repeat(mid_time, N) - mid_time_array = generate_mid_time_array(self.start_date, self.end_date, self.obs) - - self.to_netCDF(self.obs, self.lat, self.lon, mid_time_array, self.start_date, self.end_date) - print(f'Saved processed data to netCDF') - except Exception as e: - print(f"Error saving data into netCDF: {e}") - return - - except Exception as e: - print(f"Error processing data for the time period from {self.start_date} to {self.end_date}: {e}") - return - - - def process_monthly_data(self, batch_size, month=None): - - """ - Process the data with multiprocessing. 
- """ - flag = False - if month is None: - month_files = self.files - print(f'Processing all files from {self.start_date} to {self.end_date}') - else: - month_files = self.monthly_dict[month]["files"] - print(f"Processing {len(month_files)} files for month {month}") - - - # Split files into batches - batches = create_batches(month_files, batch_size) - - # Initialize cumulative arrays - cumulative_obs = None - cumulative_lat = None - cumulative_lon = None - cumulative_count = None - failed_files = [] - - for batch in batches: - print(f"Processing a batch of {len(batch)} files...") - args = [ - ( - self.grid, - file, - self.obs_var, - self.lat_var, - self.lon_var, - self.time_var, - self.start_date, - self.end_date, - self.apply_qa, - self.qualityFlags, - flag, - self.geospatial_crop - ) - for file in batch - ] - valid_results, failed = process_batch(args, batch) - failed_files.extend(failed) - - if not valid_results: - print("All files in this batch failed. Skipping...") - continue - - # Aggregate results incrementally - cumulative_obs, cumulative_lat, cumulative_lon, cumulative_count = combine_arrays(valid_results, cumulative_obs, cumulative_lat, cumulative_lon, cumulative_count) - - if failed_files: - print(f"The following files failed to process: {failed_files}") - - # Check if any valid data was processed - if cumulative_obs is None or cumulative_count is None: - print(f"No valid data processed for month {month}. Returning empty arrays.") - return None, None, None, None - - # Compute final averages - final_obs, final_lat, final_lon = compute_final_averages(cumulative_obs, cumulative_lat, cumulative_lon, cumulative_count) - - print(f"Completed processing for month {month}.") - - return final_lon, final_lat, final_obs, cumulative_count - - - @timeit - def yearly_average(self): - """ - Compute the yearly average for the processed data. - """ - cumulative_obs = None - cumulative_lat = None - cumulative_lon = None - cumulative_count = None - - for key, month_data in self.monthly_dict.items(): - if not all(var in month_data for var in ["lon", "lat", "obs", "count"]): - print(f"Skipping month {key}: Missing data.") - continue - if any(month_data[var] is None for var in ["lon", "lat", "obs", "count"]): - print(f"Skipping month {key}: No valid data.") - continue - - lon_values = month_data["lon"] - lat_values = month_data["lat"] - obs_values = month_data["obs"] - count_values = month_data["count"] - - results = [(obs_values, lat_values, lon_values, count_values)] - cumulative_obs, cumulative_lat, cumulative_lon, cumulative_count = combine_arrays(results, cumulative_obs, cumulative_lat, cumulative_lon, cumulative_count) - - if cumulative_obs is None or cumulative_count is None or np.all(cumulative_count == 0): - print("No valid data for the entire year. Returning empty arrays.") - self.yearly_lon = None - self.yearly_lat = None - self.yearly_obs = None - self.yearly_count = None - return - - yearly_obs, yearly_lat, yearly_lon = compute_final_averages(cumulative_obs, cumulative_lat, cumulative_lon, cumulative_count) - - self.yearly_lon = yearly_lon - self.yearly_lat = yearly_lat - self.yearly_obs = yearly_obs - self.yearly_count = cumulative_count - - print("Yearly average computation completed.") - - @timeit - def process_lazy_data(self, apply_qa=True, geospatial_crop: NDArray | None = None, save=True): - """ - Process the data for the specified time range. It only retrieves a daily dictionary using dask lazy loading. 
- """ - flag = True - #Check if the user wants qa flags - if apply_qa: - self.apply_qa = True - else: - self.apply_qa = False - if geospatial_crop is not None: - self.geospatial_crop = np.array(geospatial_crop) - for day in self.daily_dict.keys(): - try: - files = self.daily_dict[day]["files"] - args = [ - ( - self.grid, - file, - self.obs_var, - self.lat_var, - self.lon_var, - self.time_var, - self.start_date, - self.end_date, - self.apply_qa, - self.qualityFlags, - flag, - self.geospatial_crop - - ) - for file in files - ] - - valid_results, failed_files = process_batch(args, files) - - final_obs, final_lat, final_lon, final_time = append_and_concatenate(valid_results, day) - - # Skip the day if no valid data was processed - if final_obs is None: - continue - - # Check if any valid data exists - if np.all(final_obs == 0) or np.all(np.isnan(final_obs)): - print(f"All observation values for day {day} are 0 or NaN. Skipping...") - continue - - # Update day dict - self.daily_dict[day]["obs"] = final_obs - self.daily_dict[day]["lon"] = final_lon - self.daily_dict[day]["lat"] = final_lat - self.daily_dict[day]["time"] = final_time - - # Get day start and day end but only if they are within object start and end date - day_start, day_end = day_boundaries(day, self.start_date, self.end_date) - if save: - self.to_netCDF(final_obs, final_lat, final_lon, final_time, day_start, day_end) - - except Exception as e: - print(f"Error processing data for day {day}: {e}") - continue - - def to_netCDF(self, obs, lat, lon, time, start_time, end_time): - try: - # Convert time to seconds since TIME_ORIGIN - final_time_seconds = (time - TIME_ORIGIN) / np.timedelta64(1, "s") - final_time_seconds = np.round(final_time_seconds, 0) - - grid_representation_str = json.dumps(self.grid_config, ensure_ascii=False) - qa_flags_str = json.dumps(self.qualityFlags, ensure_ascii=False) - - # Create xarray Dataset - ds = xr.Dataset( - coords={"time": ("time", final_time_seconds, { - "units": "seconds since 1900-01-01 00:00:00", - "calendar": "gregorian" - })}, - - data_vars={ - "lon": (["time"], lon, VARS_ATTRIBUTES['lon'].copy()), - "lat": (["time"], lat, VARS_ATTRIBUTES['lat'].copy()), - VARS_ATTRIBUTES[self.obs_var]["mapies_variable"]: (["time"], obs, VARS_ATTRIBUTES[self.obs_var].copy()), - }, - attrs={ - "title": f"{VARS_ATTRIBUTES['title']['description']} from {start_time} to {end_time}", - "institution": "Barcelona Supercomputing Center", - "grid": f"{self.grid_repr}: {grid_representation_str}", - "Developers": "MAPIES team", - "QA": f"Quality Assurance: {self.apply_qa}, quality assurance applied: {qa_flags_str}", - "history": f"Created {datetime.now()}", - }, - ) - - # Define encoding for compact storage - encoding = { - "time": {"dtype": "float32"}, - "lon": {"dtype": "float32"}, - "lat": {"dtype": "float32"}, - VARS_ATTRIBUTES[self.obs_var]["mapies_variable"]: {"dtype": "float32"}, - } - - # Save to NetCDF file - filename = f"{self.dest}/Data_{self.datatype}_{start_time}_{end_time}.nc" - ds.to_netcdf(filename, mode='w', encoding=encoding) - - print(f"Saved data to {filename}") - - except Exception as e: - print(f"Error saving data into netCDF: {e}") - return - - # ============================================================================= # Supporting functions # ============================================================================= @@ -655,139 +269,3 @@ class VIIRS(MAPIES): return outfiles - - - - - - - - -# ============================================================================= -# Plotting -# 
============================================================================= - def plot_2D_observations(self, months=[0], filename=None, outdir=None, display_fig=False): - - if months == [0]: - try: - lon_values = self.lon - lat_values = self.lat - obs_values = self.obs - except Exception as e: - print(f"Error plotting all data: {e}") - return - title = f"Observation 2D plot from {self.start_date} to {self.end_date}" - filename = filename or f"{self.datatype}_2D_obs_from_{self.start_date}_to_{self.end_date}.png" - super().plot_2D_obs_custom( - lon_values=lon_values, - lat_values=lat_values, - obs_values=obs_values, - outdir=outdir, - title=title, - filename=filename, - display_fig=display_fig - ) - return - elif months is not None: - for m in months: - try: - month_data = self.monthly_dict[m] - lon_values = month_data["lon"] - lat_values = month_data["lat"] - obs_values = month_data["obs"] - except Exception as e: - print(f"Error for month {m}: {e}") - continue - title = f"Observation 2D plot of month {m} from {self.start_date} to {self.end_date}" - filename = filename or f"{self.datatype}_2D_obs_month_{m}_year_{self.year}.png" - super().plot_2D_obs_custom( - lon_values=lon_values, - lat_values=lat_values, - obs_values=obs_values, - outdir=outdir, - title=title, - filename=filename, - display_fig=display_fig - ) - else: - try: - lon_values = self.yearly_lon - lat_values = self.yearly_lat - obs_values = self.yearly_obs - except Exception as e: - print(f"Error for yearly data: {e}") - return - title = f"Observation 2D plot of yearly average of year {self.year}" - filename = filename or f"{self.datatype}_2D_yearly_obs_{self.year}.png" - super().plot_2D_obs_custom( - lon_values=lon_values, - lat_values=lat_values, - obs_values=obs_values, - outdir=outdir, - title=title, - filename=filename, - display_fig=display_fig, - ) - - def plot_2D_num_obs(self, months=[0], filename=None, outdir=None, display_fig=False): - - if months == [0]: - try: - lon_values = self.lon - lat_values = self.lat - cum_count = self.count_obs - except Exception as e: - print(f"Error plotting all data: {e}") - return - title = f"Plot of the number of valid observations from {self.start_date} to {self.end_date}" - filename = filename or f"{self.datatype}_2D_obs_count_from_{self.start_date}_to_{self.end_date}.png" - super().plot_2D_num_obs_custom( - lon_values=lon_values, - lat_values=lat_values, - num_obs=cum_count, - outdir=outdir, - title=title, - filename=filename, - display_fig=display_fig - ) - return - elif months is not None: - for m in months: - try: - month_data = self.monthly_dict[m] - lon_values = month_data["lon"] - lat_values = month_data["lat"] - cum_count = month_data["count"] - except Exception as e: - print(f"Error for month {m}: {e}") - continue - title = f"Plot of the number of valid observations for month {m} from {self.start_date} to {self.end_date}" - filename = filename or f"{self.datatype}_2D_obs_count_month_{m}_year_{self.year}.png" - super().plot_2D_num_obs_custom( - lon_values=lon_values, - lat_values=lat_values, - num_obs=cum_count, - outdir=outdir, - title=title, - filename=filename, - display_fig=display_fig - ) - else: - try: - lon_values = self.yearly_lon - lat_values = self.yearly_lat - cum_count = self.yearly_count - except Exception as e: - print(f"Error for yearly data: {e}") - return - title = f"Plot of the number of valid observations for yearly average of {self.year}" - filename = filename or f"{outdir}/{self.datatype}_2D_obs_count_yearly_{self.year}.png" - 
super().plot_2D_num_obs_custom( - lon_values=lon_values, - lat_values=lat_values, - num_obs=cum_count, - outdir=outdir, - title=title, - filename=filename, - display_fig=display_fig - ) \ No newline at end of file diff --git a/run/run_tropomi.py b/run/run_tropomi.py index ad5ff778..5e4d7a22 100644 --- a/run/run_tropomi.py +++ b/run/run_tropomi.py @@ -15,12 +15,12 @@ if __name__ == "__main__": c = TROPOMI(start_date, end_date, dest=outdir, indir=indir, grid_repr="regular") c.read_nc() - c.process_data(monthly_avg=False, batch_size = 14, apply_qa=True, geospatial_crop=[[-40, 40], [0, 70]]) + c.process_data(monthly_avg=False, batch_size = 14, apply_qa=True, geospatial_crop=[[-40, 60], [0, 70]]) # c.yearly_average() c.plot_2D_observations(months=[0], outdir=outdir) c.plot_2D_num_obs(months=[0], outdir=outdir) - c.process_lazy_data(apply_qa=True, geospatial_crop=[[-40, 40], [0, 70]]) + c.process_lazy_data(apply_qa=True, geospatial_crop=[[-40, 60], [0, 70]]) end_time = time.time() elapsed_time = end_time - start_time print(f"Script executed in {elapsed_time:.2f} seconds.") diff --git a/run/run_viirs.py b/run/run_viirs.py index 989b1695..16966236 100644 --- a/run/run_viirs.py +++ b/run/run_viirs.py @@ -6,8 +6,8 @@ import time if __name__ == "__main__": # Change these variables # 00:30 to 02:36 from 2024-01-01 to 2024-01-03 - start_date = "202401010000" - end_date = "202401012359" + start_date = "202401300000" + end_date = "202402012359" @@ -25,11 +25,11 @@ if __name__ == "__main__": c.process_data(monthly_avg=False, batch_size = 100, apply_qa=True, save=True) - # c.yearly_average() - # c.plot_2D_observations(months=None, outdir=outdir) - # c.plot_2D_num_obs(months=None, outdir=outdir) - # c.process_lazy_data(apply_qa=True, save=True) - # c.to_da(frequency="3H", save_figs=True) + c.yearly_average() + c.plot_2D_observations(months=[0], filename="new.png", outdir=outdir) + c.plot_2D_num_obs(months=[0], outdir=outdir) + c.process_lazy_data(apply_qa=True, save=True) + c.to_da(frequency="3H", save_figs=True) -- GitLab From 727f8c128c21e5c33c0a3cb538f18fa766540120 Mon Sep 17 00:00:00 2001 From: Calum Meikle Date: Fri, 28 Feb 2025 14:58:07 +0100 Subject: [PATCH 3/4] Remove commented code --- mapies/mapies.py | 3 +-- mapies/util/multiprocessing.py | 1 - mapies/viirs.py | 5 ++--- run/run_viirs.py | 10 +++++----- 4 files changed, 8 insertions(+), 11 deletions(-) diff --git a/mapies/mapies.py b/mapies/mapies.py index 82a1e294..0bee5218 100644 --- a/mapies/mapies.py +++ b/mapies/mapies.py @@ -11,8 +11,6 @@ import logging import numpy as np import pandas as pd import xarray as xr -#import matplotlib -#matplotlib.use("TkAgg") # Use tinker to perform plt.show() as it is GUI supported import matplotlib.pyplot as plt import cartopy import cartopy.crs as ccrs @@ -103,6 +101,7 @@ class MAPIES: self.qualityFlags = qa_dict["qa_variables"] + @timeit def process_data(self, monthly_avg=False, batch_size=100, save=False, apply_qa=True, geospatial_crop: NDArray | None = None): diff --git a/mapies/util/multiprocessing.py b/mapies/util/multiprocessing.py index a844e172..0557fe24 100644 --- a/mapies/util/multiprocessing.py +++ b/mapies/util/multiprocessing.py @@ -15,7 +15,6 @@ from dataclasses import dataclass from numpy.typing import DTypeLike, NDArray from mapies.util.func_tools import * -#timeit, get_file_list, time_converter, frequency_int, error_estimation, time_domain_selection, geo_to_rot from mapies.grids.monarch import Grid, RotatedGrid, IrregularRotatedGrid, RegularGrid, regridding_function diff --git 
a/mapies/viirs.py b/mapies/viirs.py index 0d95768e..3670d494 100644 --- a/mapies/viirs.py +++ b/mapies/viirs.py @@ -5,7 +5,7 @@ from functools import wraps import logging import os import time -from tqdm import tqdm +#from tqdm import tqdm import yaml import multiprocessing import gc @@ -267,5 +267,4 @@ class VIIRS(MAPIES): ds.to_netcdf(filename, encoding={}) outfiles.append(filename) - return outfiles - + return outfiles \ No newline at end of file diff --git a/run/run_viirs.py b/run/run_viirs.py index 16966236..60959d36 100644 --- a/run/run_viirs.py +++ b/run/run_viirs.py @@ -24,12 +24,12 @@ if __name__ == "__main__": c.read_nc() - c.process_data(monthly_avg=False, batch_size = 100, apply_qa=True, save=True) + c.process_data(monthly_avg=True, batch_size = 100, apply_qa=True, save=True) c.yearly_average() - c.plot_2D_observations(months=[0], filename="new.png", outdir=outdir) - c.plot_2D_num_obs(months=[0], outdir=outdir) - c.process_lazy_data(apply_qa=True, save=True) - c.to_da(frequency="3H", save_figs=True) + c.plot_2D_observations(months=None, filename="new.png", outdir=outdir) + c.plot_2D_num_obs(months=None, outdir=outdir) + #c.process_lazy_data(apply_qa=True, save=True) + #c.to_da(frequency="3H", save_figs=True) -- GitLab From 8e4fdade21df364dea5965059807f60592a20dd6 Mon Sep 17 00:00:00 2001 From: Calum Meikle Date: Fri, 28 Feb 2025 15:01:12 +0100 Subject: [PATCH 4/4] Updated Changelog and version --- CHANGELOG | 7 +++++++ setup.py | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/CHANGELOG b/CHANGELOG index 8902aa78..e1ef77d7 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -1,5 +1,12 @@ CHANGELOG +## 0.1.1 - 2025/02/28 - Code refactoring + +DRY code + +### Features +- Drying of code, moving duplicated code to Parent class and creating Multiprocessing class + ## 0.1.0 - 2025/02/21 - First version of MAPIES Established the initial structure of MAPIES, providing a flexible and scalable framework for satellite systems data analysis. diff --git a/setup.py b/setup.py index 0916106f..de6cc56e 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,7 @@ from setuptools import find_packages from setuptools import setup # Could update this using versioneer -version="0.1.0" +version="0.1.1" setup( name="mapies", -- GitLab
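
Putting the refactor together, a minimal driver in the shape of run/run_viirs.py after this series; the import path, the indir/outdir values and the grid_repr argument (copied from the TROPOMI runner) are assumptions rather than part of the patch.

    import time

    from mapies.viirs import VIIRS   # import path assumed from package layout

    if __name__ == "__main__":
        start_date = "202401300000"
        end_date = "202402012359"
        indir = "/path/to/viirs/files"   # placeholder input directory
        outdir = "/path/to/output"       # placeholder output directory

        start_time = time.time()
        c = VIIRS(start_date, end_date, dest=outdir, indir=indir,
                  grid_repr="regular")
        c.read_nc()
        c.process_data(monthly_avg=True, batch_size=100, apply_qa=True, save=True)
        c.yearly_average()
        c.plot_2D_observations(months=None, filename="new.png", outdir=outdir)
        c.plot_2D_num_obs(months=None, outdir=outdir)
        print(f"Script executed in {time.time() - start_time:.2f} seconds.")
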