diff --git a/CHANGELOG b/CHANGELOG index 8902aa78bd745ccdb8cc683cae03ede0a4919715..e1ef77d70870c3be3adaa6adb9007d25eaebd9c3 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -1,5 +1,12 @@ CHANGELOG +## 0.1.1 - 2025/02/28 - Code refactoring + +DRY code + +### Features +- Drying of code, moving duplicated code to Parent class and creating Multiprocessing class + ## 0.1.0 - 2025/02/21 - First version of MAPIES Established the initial structure of MAPIES, providing a flexible and scalable framework for satellite systems data analysis. diff --git a/mapies/mapies.py b/mapies/mapies.py index d548f422b1d70249f2f2c6eee0cabead18b1f531..0bee5218521989719a23891de15de888c86cf508 100644 --- a/mapies/mapies.py +++ b/mapies/mapies.py @@ -4,24 +4,26 @@ from functools import wraps from datetime import datetime, timedelta import sys -from mapies.util.func_tools import timeit, time_domain_selection, frequency_int, error_estimation +from mapies.util.func_tools import * from mapies.grids.monarch import RotatedGrid, IrregularRotatedGrid, RegularGrid import time import logging import numpy as np import pandas as pd import xarray as xr -#import matplotlib -#matplotlib.use("TkAgg") # Use tinker to perform plt.show() as it is GUI supported import matplotlib.pyplot as plt import cartopy import cartopy.crs as ccrs import os import yaml +import json import matplotlib.cm as cm from matplotlib.colors import ListedColormap import matplotlib.colors as mcolors +from mapies.util.variables_names import VARS_ATTRIBUTES +from numpy.typing import DTypeLike, NDArray + #logger = logging.getLogger(__name__) #logging.basicConfig(filename='../logs/mapies.log', level=logging.DEBUG) @@ -37,7 +39,6 @@ class MAPIES: """ self.np_typetime = "datetime64[s]" self.int_type = 'int64' - self.time_orig = np.datetime64("1900-01-01T00:00:00") self.start_date = start_date self.end_date = end_date self.apply_qa = True @@ -101,49 +102,432 @@ class MAPIES: self.qualityFlags = qa_dict["qa_variables"] - # TODO this can be removed I believe @timeit - def plot_2D_obs(self, **kwargs): - """ - Plotting the observations + def process_data(self, monthly_avg=False, batch_size=100, save=False, apply_qa=True, geospatial_crop: NDArray | None = None): + """ - # Regrid if specified as argument - if self.grid_repr: - #logger.info("Performing regridding of data") - self.regridding_function() - # TODO: Make these arguments - figsize = (15,10) - markersize = 2.5 + Process the data. - # Set Basemap projection - proj = ccrs.PlateCarree() + Parameters:geospatial_crop: NDArray | None = None + - monthly_avg: bool, whether to compute single monthly averages. + - batch_size: int, the number of files to process in each batch. - # Create the plot and add features, TODO make this adjustable in arguments - fig, ax = plt.subplots(subplot_kw={"projection": proj}, figsize=figsize) - ax.gridlines() - #ax.add_feature(cartopy.feature.BORDERS, linestyle=':', alpha=1) - #ax.add_feature(cartopy.feature.OCEAN,facecolor=("lightblue")) - #ax.add_feature(cartopy.feature.LAND) - ax.coastlines(resolution='10m') + Returns: + - if monthly_avg --> updated monthly dictorionary for the processed month. + - if not monthly_avg --> updated self.lon_values, self.lat_values, self.obs, self.count_obs. 
+ """ - + #Check if the user wants qa flags + if apply_qa: + self.apply_qa = True + else: + self.apply_qa = False + if geospatial_crop is not None: + self.geospatial_crop = np.array(geospatial_crop) + if monthly_avg: + for month in self.monthly_dict.keys(): + + # Catch if they haven't selected a grid + try: + print(f"Running average on {self.grid}") + except AttributeError: + raise Exception("Please specify a grid in the initialisation of MAPIES, using grid_repr=''") + try: + # Process the data for the current month + final_lon, final_lat, final_obs, cumulative_count = self.process_monthly_data( + batch_size=batch_size, month=month + ) + + self.monthly_dict[month] = { + "lon": final_lon, + "lat": final_lat, + "obs": final_obs, + "count": cumulative_count, + } + print(f'Updated monthly dictionary for month {month}') + print(f'Lon shape: {final_lon.shape}, Lat shape: {final_lat.shape}, Obs shape: {final_obs.shape}') + if np.all(final_lon == 0) or np.all(np.isnan(final_lon)): + print(f"All longitude values for month {month} are 0 or NaN.") + if np.all(final_lat == 0) or np.all(np.isnan(final_lat)): + print(f"All latitude values for month {month} are 0 or NaN.") + if np.all(final_obs == 0) or np.all(np.isnan(final_obs)): + print(f"All observation values for month {month} are 0 or NaN.") + + except Exception as e: + print(f"Error processing data for month {month}: {e}") + continue + else: + try: + final_lon, final_lat, final_obs, cumulative_count = self.process_monthly_data( + batch_size=batch_size, month=None + ) + + self.obs, self.lat, self.lon, self.count_obs = filter_observations(final_obs, final_lat, final_lon, cumulative_count) + + + if save: + try: + mid_time_array = generate_mid_time_array(self.start_date, self.end_date, self.obs) + + self.to_netCDF(self.obs, self.lat, self.lon, mid_time_array, self.start_date, self.end_date) + print(f'Saved processed data to netCDF') + except Exception as e: + print(f"Error saving data into netCDF: {e}") + return + + except Exception as e: + print(f"Error processing data for the time period from {self.start_date} to {self.end_date}: {e}") + return + + + def process_monthly_data(self, batch_size, month=None): - print("Plotting observations") + """ + Process the data with multiprocessing. + """ + flag = False + if month is None: + month_files = self.files + print(f'Processing all files from {self.start_date} to {self.end_date}') + else: + month_files = self.monthly_dict[month]["files"] + print(f"Processing {len(month_files)} files for month {month}") + + + # Split files into batches + batches = create_batches(month_files, batch_size) + + # Initialize cumulative arrays + cumulative_obs = None + cumulative_lat = None + cumulative_lon = None + cumulative_count = None + failed_files = [] + + for batch in batches: + print(f"Processing a batch of {len(batch)} files...") + + # Use multiprocessing to process files within the batch + + viirs_mp = self.mp_type( + self.grid, + self.obs_var, + self.lat_var, + self.lon_var, + self.time_var, + self.start_date, + self.end_date, + self.apply_qa, + self.qualityFlags, + batch, + self.geospatial_crop, + flag, + ) + valid_results, failed = viirs_mp.process_batch() + failed_files.extend(failed) + + if not valid_results: + print("All files in this batch failed. 
Skipping...") + continue + + # Aggregate results incrementally + cumulative_obs, cumulative_lat, cumulative_lon, cumulative_count = combine_arrays(valid_results, cumulative_obs, cumulative_lat, cumulative_lon, cumulative_count) + if failed_files: + print(f"The following files failed to process: {failed_files}") - x, y = self.lon_values, self.lat_values + # Check if any valid data was processed + if cumulative_obs is None or cumulative_count is None: + print(f"No valid data processed for month {month}. Returning empty arrays.") + return None, None, None, None - im = ax.scatter(x,y,markersize,c=self.obs, transform=proj) - - fig.colorbar(im, ax=ax) + # Compute final averages + final_obs, final_lat, final_lon = compute_final_averages(cumulative_obs, cumulative_lat, cumulative_lon, cumulative_count) + + print(f"Completed processing for month {month}.") - ax.set_title(f'Observation 2D plot of {self.datatype.upper()} data from {self.start_date} to {self.end_date}') - print("Saving Figure") - plt.savefig(f"{self.dest}{self.datatype}_2D_obs.png", format="png") - #plt.close(fig) - plt.show() - #return fig + return final_lon, final_lat, final_obs, cumulative_count + + + @timeit + def yearly_average(self): + """ + Compute the yearly average for the processed data. + """ + cumulative_obs = None + cumulative_lat = None + cumulative_lon = None + cumulative_count = None + + for key, month_data in self.monthly_dict.items(): + if not all(var in month_data for var in ["lon", "lat", "obs", "count"]): + print(f"Skipping month {key}: Missing data.") + continue + if any(month_data[var] is None for var in ["lon", "lat", "obs", "count"]): + print(f"Skipping month {key}: No valid data.") + continue + + lon_values = month_data["lon"] + lat_values = month_data["lat"] + obs_values = month_data["obs"] + count_values = month_data["count"] + + results = [(obs_values, lat_values, lon_values, count_values)] + cumulative_obs, cumulative_lat, cumulative_lon, cumulative_count = combine_arrays(results, cumulative_obs, cumulative_lat, cumulative_lon, cumulative_count) + if cumulative_obs is None or cumulative_count is None or np.all(cumulative_count == 0): + print("No valid data for the entire year. Returning empty arrays.") + self.yearly_lon = None + self.yearly_lat = None + self.yearly_obs = None + self.yearly_count = None + return + + yearly_obs, yearly_lat, yearly_lon = compute_final_averages(cumulative_obs, cumulative_lat, cumulative_lon, cumulative_count) + + self.yearly_lon = yearly_lon + self.yearly_lat = yearly_lat + self.yearly_obs = yearly_obs + self.yearly_count = cumulative_count + + print("Yearly average computation completed.") + + + @timeit + def process_lazy_data(self, apply_qa=True, geospatial_crop: NDArray | None = None, save=True): + """ + Process the data for the specified time range. It only retrieves a daily dictionary using dask lazy loading. 
+ """ + flag = True + #Check if the user wants qa flags + if apply_qa: + self.apply_qa = True + else: + self.apply_qa = False + if geospatial_crop is not None: + self.geospatial_crop = np.array(geospatial_crop) + for day in self.daily_dict.keys(): + try: + files = self.daily_dict[day]["files"] + + # Use multiprocessing to process files within the batch + + viirs_mp = self.mp_type( + self.grid, + self.obs_var, + self.lat_var, + self.lon_var, + self.time_var, + self.start_date, + self.end_date, + self.apply_qa, + self.qualityFlags, + files, + self.geospatial_crop, + flag, + ) + valid_results, failed_files = viirs_mp.process_batch() + + final_obs, final_lat, final_lon, final_time = append_and_concatenate(valid_results, day) + + # Skip the day if no valid data was processed + if final_obs is None: + continue + + # Check if any valid data exists + if np.all(final_obs == 0) or np.all(np.isnan(final_obs)): + print(f"All observation values for day {day} are 0 or NaN. Skipping...") + continue + + # Update day dict + self.daily_dict[day]["obs"] = final_obs + self.daily_dict[day]["lon"] = final_lon + self.daily_dict[day]["lat"] = final_lat + self.daily_dict[day]["time"] = final_time + + # Get day start and day end but only if they are within object start and end date + day_start, day_end = day_boundaries(day, self.start_date, self.end_date) + if save: + self.to_netCDF(final_obs, final_lat, final_lon, final_time, day_start, day_end) + + except Exception as e: + print(f"Error processing data for day {day}: {e}") + continue + + + def to_netCDF(self, obs, lat, lon, time, start_time, end_time): + try: + # Convert time to seconds since TIME_ORIGIN + final_time_seconds = (time - self.time_orig) / np.timedelta64(1, "s") + final_time_seconds = np.round(final_time_seconds, 0) + + grid_representation_str = json.dumps(self.grid_config, ensure_ascii=False) + qa_flags_str = json.dumps(self.qualityFlags, ensure_ascii=False) + + # Create xarray Dataset + ds = xr.Dataset( + coords={"time": ("time", final_time_seconds, { + "units": "seconds since 1900-01-01 00:00:00", + "calendar": "gregorian" + })}, + + data_vars={ + "lon": (["time"], lon, VARS_ATTRIBUTES['lon'].copy()), + "lat": (["time"], lat, VARS_ATTRIBUTES['lat'].copy()), + VARS_ATTRIBUTES[self.obs_var]["mapies_variable"]: (["time"], obs, VARS_ATTRIBUTES[self.obs_var].copy()), + }, + attrs={ + "title": f"{VARS_ATTRIBUTES['title']['description']} from {start_time} to {end_time}", + "institution": "Barcelona Supercomputing Center", + "grid": f"{self.grid_repr}: {grid_representation_str}", + "Developers": "MAPIES team", + "QA": f"Quality Assurance: {self.apply_qa}, quality assurance applied: {qa_flags_str}", + "history": f"Created {datetime.now()}", + }, + ) + + # Define encoding for compact storage + encoding = { + "time": {"dtype": "float32"}, + "lon": {"dtype": "float32"}, + "lat": {"dtype": "float32"}, + VARS_ATTRIBUTES[self.obs_var]["mapies_variable"]: {"dtype": "float32"}, + } + + # Save to NetCDF file + filename = f"{self.dest}/Data_{self.datatype}_{start_time}_{end_time}.nc" + ds.to_netcdf(filename, mode='w', encoding=encoding) + + print(f"Saved data to {filename}") + + except Exception as e: + print(f"Error saving data into netCDF: {e}") + return + + +# ============================================================================= +# Plotting +# ============================================================================= + def plot_2D_observations(self, months=[0], filename=None, outdir=None): + + if months == [0]: + try: + lon_values = self.lon + 
lat_values = self.lat + obs_values = self.obs + except Exception as e: + print(f"Error plotting all data: {e}") + return + title = f"Observation 2D plot from {self.start_date} to {self.end_date}" + filename = filename or f"{self.datatype}_2D_obs_from_{self.start_date}_to_{self.end_date}.png" + self.plot_2D_obs_custom( + lon_values=lon_values, + lat_values=lat_values, + obs_values=obs_values, + outdir=outdir, + title=title, + filename=filename, + ) + return + elif months is not None: + for m in months: + try: + month_data = self.monthly_dict[m] + lon_values = month_data["lon"] + lat_values = month_data["lat"] + obs_values = month_data["obs"] + except Exception as e: + print(f"Error for month {m}: {e}") + continue + title = f"Observation 2D plot of month {m} from {self.start_date} to {self.end_date}" + filename = filename or f"{self.datatype}_2D_obs_month_{m}_year_{self.year}.png" + super().plot_2D_obs_custom( + lon_values=lon_values, + lat_values=lat_values, + obs_values=obs_values, + outdir=outdir, + title=title, + filename=filename, + ) + else: + try: + lon_values = self.yearly_lon + lat_values = self.yearly_lat + obs_values = self.yearly_obs + except Exception as e: + print(f"Error for yearly data: {e}") + return + title = f"Observation 2D plot of yearly average of year {self.year}" + filename = filename or f"{self.datatype}_2D_yearly_obs_{self.year}.png" + self.plot_2D_obs_custom( + lon_values=lon_values, + lat_values=lat_values, + obs_values=obs_values, + outdir=outdir, + title=title, + filename=filename, + ) + + def plot_2D_num_obs(self, months=[0], filename=None, outdir=None): + + if months == [0]: + try: + lon_values = self.lon + lat_values = self.lat + cum_count = self.count_obs + except Exception as e: + print(f"Error plotting all data: {e}") + return + title = f"Plot of the number of valid observations from {self.start_date} to {self.end_date}" + filename = filename or f"{self.datatype}_2D_obs_count_from_{self.start_date}_to_{self.end_date}.png" + self.plot_2D_num_obs_custom( + lon_values=lon_values, + lat_values=lat_values, + num_obs=cum_count, + outdir=outdir, + title=title, + filename=filename, + ) + return + elif months is not None: + for m in months: + try: + month_data = self.monthly_dict[m] + lon_values = month_data["lon"] + lat_values = month_data["lat"] + cum_count = month_data["count"] + except Exception as e: + print(f"Error for month {m}: {e}") + continue + title = f"Plot of the number of valid observations for month {m} from {self.start_date} to {self.end_date}" + filename = filename or f"{self.datatype}_2D_obs_count_month_{m}_year_{self.year}.png" + self.plot_2D_num_obs_custom( + lon_values=lon_values, + lat_values=lat_values, + num_obs=cum_count, + outdir=outdir, + title=title, + filename=filename, + ) + else: + try: + lon_values = self.yearly_lon + lat_values = self.yearly_lat + cum_count = self.yearly_count + except Exception as e: + print(f"Error for yearly data: {e}") + return + title = f"Plot of the number of valid observations for yearly average of {self.year}" + filename = filename or f"{outdir}/{self.datatype}_2D_obs_count_yearly_{self.year}.png" + self.plot_2D_num_obs_custom( + lon_values=lon_values, + lat_values=lat_values, + num_obs=cum_count, + outdir=outdir, + title=title, + filename=filename, + ) + # TODO this can be moved to a plotting function python file @timeit @@ -208,92 +592,6 @@ class MAPIES: print(f"Saved plot: {filepath}") - @timeit - def plot_2D_cumulative_count(self, lon_values, lat_values, cumulative_count, 
outdir="/home/cgile/Documents/mapies/mapies/figures", title=None, filename=None, display_fig=False): - """ - Plot the cumulative count of observations contributing to the final averages. - - Parameters: - - lon_values: 1D array of longitude values. - - lat_values: 1D array of latitude values. - - cumulative_count: 1D array of the count of valid observations per grid cell. - - outdir: Directory to save the plot. - - title: Title for the plot. - - filename: Name of the output file. - """ - - # Set Basemap projection - figsize = (15, 10) - markersize = 2.5 - proj = ccrs.PlateCarree() - - # Set colormap and normalization - cmap = plt.cm.YlOrBr # Colormap for cumulative count - norm = mcolors.Normalize(vmin=0, vmax=np.max(cumulative_count)) # Continuous normalization - - # Create the plot and add features - fig, ax = plt.subplots(subplot_kw={"projection": proj}, figsize=figsize) - ax.gridlines() - ax.coastlines(resolution="10m") - - print("Plotting cumulative count") - - # Scatter plot for cumulative count - im = ax.scatter( - lon_values, lat_values, markersize, c=cumulative_count, cmap=cmap, norm=norm, transform=proj - ) - - # Add the color bar - cbar = fig.colorbar(im, ax=ax, orientation="vertical", pad=0.05) - cbar.set_label("Cumulative Count") - - # Add title and save the plot - ax.set_title(title) - - # Ensure the output directory exists - os.makedirs(outdir, exist_ok=True) - - output_filename = "few{}fcec".format(site_code) - print("{}".format(this_string)) - - - filepath = os.path.join(outdir, filename) - plt.savefig(filepath, format="png") - if display_fig: - plt.show() - print(f"Saved plot: {filepath}") - - - - @staticmethod - def reindex( - dependent_var_index, - independent_var_values, - ): - """" - Recutting the data whenever a selction of the data has been made along one of the dimenisons - - Based off of how it has been done in Providentia - """ - # Maybe add a checker try/except or assert - independent_var_values = independent_var_values[dependent_var_index] - - # return reindexed values - return independent_var_values - - - @staticmethod - def to_xarray(coords:dict, data_vars:dict, **kwargs): - """ - Method to convert numpy arrays to xarray opject - """ - attrs = kwargs.get("attrs") - if attrs is None: - attrs=dict() - - ds = xr.Dataset(data_vars=data_vars, coords=coords, attrs=attrs) - return ds - # TODO this can be moved to a plotting function python file @timeit def plot_2D_num_obs_custom( @@ -355,4 +653,14 @@ class MAPIES: #return fig - + @staticmethod + def to_xarray(coords:dict, data_vars:dict, **kwargs): + """ + Method to convert numpy arrays to xarray opject + """ + attrs = kwargs.get("attrs") + if attrs is None: + attrs=dict() + + ds = xr.Dataset(data_vars=data_vars, coords=coords, attrs=attrs) + return ds \ No newline at end of file diff --git a/mapies/tropomi.py b/mapies/tropomi.py index 2df0c03af7d1a21f7ac1f1d9461c73e5541f348d..b1c5a39c94a863cd837951c0c10cc7673e663fe2 100644 --- a/mapies/tropomi.py +++ b/mapies/tropomi.py @@ -17,7 +17,7 @@ import os from mapies.util.func_tools import * #timeit, get_file_list, time_converter, frequency_int, error_estimation, time_domain_selection, geo_to_rot from mapies.grids.monarch import RotatedGrid, IrregularRotatedGrid, RegularGrid, regridding_function -from mapies.util.variables_names import VARS_ATTRIBUTES +from mapies.util.multiprocessing import TROPOMIMultiprocessing import json from numpy.typing import DTypeLike, NDArray @@ -26,52 +26,8 @@ os.environ["OMP_NUM_THREADS"] = "8" #logger = logging.getLogger(__name__) 
#logging.basicConfig(filename='../logs/mapies.log', level=logging.INFO) -TIME_ORIGIN = np.datetime64("1900-01-01T00:00:00") -def process_single_file( - r, - file, - obs_var, - lat_var, - lon_var, - time_var, - start_date, - end_date, - apply_qa, - qualityFlags, - flag=False, - geospatial_crop=None - ): - """ - Process a single file and return aggregated results or flattened arrays. - """ - try: - ds = xr.open_dataset(file, engine="h5netcdf", group="PRODUCT") - - obs, lat, lon, time_values = new_preprocess_vars(ds, obs_var, lat_var, lon_var, time_var, start_date, end_date) - - if apply_qa: - obs, lat, lon, time_values = quality_assurance(ds, obs, lat, lon, time_values, qualityFlags) - obs, lon, lat, time_values = time_domain_selection(obs, lat, lon, time_values, start_date, end_date) - if geospatial_crop is not None: - obs, lat, lon, time_values = spatial_domain_selection(obs, lat, lon, time_values, geospatial_crop) - - if not flag: - # Regrid and aggregate - if isinstance(r, RotatedGrid): - lon_agg, lat_agg, rlon_agg, rlat_agg, obs_agg, count_obs = regridding_function(r, lon, lat, obs) - elif isinstance(r, RegularGrid): - lon_agg, lat_agg, obs_agg, count_obs = regridding_function(r, lon, lat, obs) - ds.close() - return obs_agg, lat_agg, lon_agg, count_obs - else: - ds.close() - return obs, lat, lon, time_values - - except Exception as e: - print(f"Error processing file {file}: {e}") - return None class TROPOMI(MAPIES): """ @@ -84,7 +40,7 @@ class TROPOMI(MAPIES): """ super().__init__(start_date, end_date) - self.time_orig = np.datetime64("1993-01-01T00:00:00") + self.time_orig = np.datetime64("1900-01-01T00:00:00") self.dest = kwargs.get("dest") if not os.path.exists(self.dest): raise Exception("Output directory doesn't exist") @@ -94,6 +50,7 @@ class TROPOMI(MAPIES): self.monthly_dict = {} self.daily_dict = {} # Dictionary to produce daily netcdf files self.datatype = "tropomi" + self.mp_type = TROPOMIMultiprocessing # Add quality value filter number @@ -185,459 +142,3 @@ class TROPOMI(MAPIES): # Print the daily disctionary to see if files are stored in chronological order for day, data in self.daily_dict.items(): print(f"Day {day:02d}: {len(data['files'])} files") - - - @timeit - def process_data(self, monthly_avg=False, batch_size=100, apply_qa=True, geospatial_crop: NDArray | None = None): - """ - Process the data for the specified year and months. - - Parameters: - - monthly_avg: bool, whether to compute single monthly averages. - - batch_size: int, the number of files to process in each batch. - - chunk_size: int, the size of chunks for processing within each batch. - - geospatial_crop, 2D numpy array - - Returns: - - if monthly_avg --> lon_final, lat_final, obs_final: Aggregated longitude, latitude, and observation arrays. - - if not monthly_avg --> preprocess_vars() is called. 
- """ - #Check if the user wants qa flags - if apply_qa: - self.apply_qa = True - if geospatial_crop is not None: - self.geospatial_crop = np.array(geospatial_crop) - if monthly_avg: - for month in self.monthly_dict.keys(): - # print(f'Month: {self.monthly_dict[month]}') - # Catch if they haven't selected a grid - try: - print(f"Running average on {self.grid}") - except AttributeError: - raise Exception("Please specify a grid in the initialisation of MAPIES, using grid_repr=''") - try: - # Process the data for the current month - final_lon, final_lat, final_obs, cumulative_count = self.process_monthly_data( - batch_size=batch_size, month=month - ) - - # Update the monthly dictionary - valid_obs = final_obs > 0 - self.monthly_dict[month] = { - "lon": final_lon[valid_obs], - "lat": final_lat[valid_obs], - "obs": final_obs[valid_obs], - "count": cumulative_count[valid_obs], - } - print(f'Updated monthly dictionary for month {month}') - print(f'Lon shape: {final_lon.shape}, Lat shape: {final_lat.shape}, Obs shape: {final_obs.shape}') - if np.all(final_lon == 0) or np.all(np.isnan(final_lon)): - print(f"All longitude values for month {month} are 0 or NaN.") - if np.all(final_lat == 0) or np.all(np.isnan(final_lat)): - print(f"All latitude values for month {month} are 0 or NaN.") - if np.all(final_obs == 0) or np.all(np.isnan(final_obs)): - print(f"All observation values for month {month} are 0 or NaN.") - - except Exception as e: - print(f"Error processing data for month {month}: {e}") - continue - else: - try: - final_lon, final_lat, final_obs, cumulative_count = self.process_monthly_data( - batch_size=batch_size, month=None - ) - valid_obs = final_obs > 0 - self.lon = final_lon[valid_obs] - self.lat = final_lat[valid_obs] - self.obs = final_obs[valid_obs] - self.count_obs = cumulative_count[valid_obs] - - except Exception as e: - print(f"Error processing data for the time period from {self.start_date} to {self.end_date}: {e}") - return - - def process_monthly_data(self, batch_size, month=None): - - """ - Function that returns all the needed variables for the DA - """ - # r = RotatedGrid(centre_lon=20, centre_lat=35, dlon=.1, dlat=.1, west=-51, south=-35) - flag = False - if month is None: - month_files = self.files - print(f'Processing all files from {self.start_date} to {self.end_date}') - else: - month_files = self.monthly_dict[month]["files"] - print(f"Processing {len(month_files)} files for month {month}") - - - # Split files into batches - batches = [ - month_files[i:i + batch_size] for i in range(0, len(month_files), batch_size) - ] - - # Initialize cumulative arrays - cumulative_obs = None - cumulative_lat = None - cumulative_lon = None - cumulative_count = None - failed_files = [] - - for batch in batches: - print(f"Processing a batch of {len(batch)} files...") - print(f"Batch: {batch}") - - # Use multiprocessing to process files within the batch - args = [(self.grid, file, self.obs_var, self.lat_var, self.lon_var, self.time_var, self.start_date, self.end_date, self.apply_qa, self.qualityFlags, flag, self.geospatial_crop) for file in batch] - - with multiprocessing.Pool(processes=8) as pool: # Adjust number of processes as needed - results = pool.starmap(process_single_file, args) - - # Filter valid results and log failed files - valid_results = [res for res in results if res is not None] - failed_files.extend([file for file, res in zip(batch, results) if res is None]) - - if not valid_results: - print("All files in this batch failed. 
Skipping...") - continue - - # Aggregate results incrementally - for obs_agg, lat_agg, lon_agg, count_obs in valid_results: - if cumulative_obs is None: - # Initialize cumulative arrays - cumulative_obs = obs_agg * count_obs - cumulative_lat = lat_agg * count_obs - cumulative_lon = lon_agg * count_obs - cumulative_count = count_obs - else: - cumulative_obs += obs_agg * count_obs - cumulative_lat += lat_agg * count_obs - cumulative_lon += lon_agg * count_obs - cumulative_count += count_obs - - if failed_files: - print(f"The following files failed to process: {failed_files}") - - # Check if any valid data was processed - if cumulative_obs is None or cumulative_count is None: - print(f"No valid data processed for month {month}. Returning empty arrays.") - return None, None, None, None - - # Compute final averages - valid_mask = cumulative_count > 0 - final_obs = np.zeros_like(cumulative_obs) - final_lat = np.zeros_like(cumulative_lat) - final_lon = np.zeros_like(cumulative_lon) - - final_obs[valid_mask] = cumulative_obs[valid_mask] / cumulative_count[valid_mask] - final_lat[valid_mask] = cumulative_lat[valid_mask] / cumulative_count[valid_mask] - final_lon[valid_mask] = cumulative_lon[valid_mask] / cumulative_count[valid_mask] - - print(f"Completed processing for month {month}.") - - return final_lon, final_lat, final_obs, cumulative_count - - @timeit - def yearly_average(self): - """ - Compute the yearly average for the processed data. - """ - cumulative_obs = None - cumulative_lat = None - cumulative_lon = None - cumulative_count = None - - for key, month_data in self.monthly_dict.items(): - if not all(var in month_data for var in ["lon", "lat", "obs", "count"]): - print(f"Skipping month {key}: Missing data.") - continue - if any(month_data[var] is None for var in ["lon", "lat", "obs", "count"]): - print(f"Skipping month {key}: No valid data.") - continue - - lon_values = month_data["lon"] - lat_values = month_data["lat"] - obs_values = month_data["obs"] - count_values = month_data["count"] - - if cumulative_obs is None: - cumulative_obs = obs_values * count_values - cumulative_lat = lat_values * count_values - cumulative_lon = lon_values * count_values - cumulative_count = count_values - else: - cumulative_obs += obs_values * count_values - cumulative_lat += lat_values * count_values - cumulative_lon += lon_values * count_values - cumulative_count += count_values - - if cumulative_obs is None or cumulative_count is None or np.all(cumulative_count == 0): - print("No valid data for the entire year. Returning empty arrays.") - self.yearly_lon = None - self.yearly_lat = None - self.yearly_obs = None - self.yearly_count = None - return - - valid_mask = cumulative_count > 0 - yearly_obs = np.zeros_like(cumulative_obs) - yearly_lat = np.zeros_like(cumulative_lat) - yearly_lon = np.zeros_like(cumulative_lon) - - yearly_obs[valid_mask] = cumulative_obs[valid_mask] / cumulative_count[valid_mask] - yearly_lat[valid_mask] = cumulative_lat[valid_mask] / cumulative_count[valid_mask] - yearly_lon[valid_mask] = cumulative_lon[valid_mask] / cumulative_count[valid_mask] - - self.yearly_lon = yearly_lon - self.yearly_lat = yearly_lat - self.yearly_obs = yearly_obs - self.yearly_count = cumulative_count - - print("Yearly average computation completed.") - - @timeit - def process_lazy_data(self, apply_qa=True, geospatial_crop: NDArray | None = None, save=True): - """ - Process the data for the specified time range. It only retrieves a daily dictionary using dask lazy loading. 
- """ - flag = True - #Check if the user wants qa flags - if apply_qa: - self.apply_qa = True - else: - self.apply_qa = False - if geospatial_crop is not None: - self.geospatial_crop = np.array(geospatial_crop) - for day in self.daily_dict.keys(): - try: - files = self.daily_dict[day]["files"] - args = [ - ( - self.grid, - file, - self.obs_var, - self.lat_var, - self.lon_var, - self.time_var, - self.start_date, - self.end_date, - self.apply_qa, - self.qualityFlags, - flag, - self.geospatial_crop - - ) - for file in files - ] - - daily_obs = [] - daily_lat = [] - daily_lon = [] - daily_time = [] - - with multiprocessing.Pool(processes=8) as pool: # Adjust number of processes as needed - results = pool.starmap(process_single_file, args) - - valid_results = [res for res in results if res is not None] - # Check shapes of obs lat and lon to see if they are all the same - - for obs, lat, lon, time_values in valid_results: - #print(f'Shape of obs: {obs.shape}, Shape of lat: {lat.shape}, Shape of lon: {lon.shape}, Shape of time: {time_values.shape}') - if len(obs) > 0: - daily_lon.append(lon) - daily_lat.append(lat) - daily_obs.append(obs) - daily_time.append(time_values) - - if not daily_obs: - print(f"No valid data for day {day}") - continue - - final_obs = np.concatenate(daily_obs, axis=0) - final_lon = np.concatenate(daily_lon, axis=0) - final_lat = np.concatenate(daily_lat, axis=0) - final_time = np.concatenate(daily_time, axis=0) - - # Check if any valid data exists - if np.all(final_obs == 0) or np.all(np.isnan(final_obs)): - print(f"All observation values for day {day} are 0 or NaN. Skipping...") - continue - # Get day start and day end but only if they are within object start and end date - day_start, day_end = day_boundaries(day, self.start_date, self.end_date) - if save: - self.to_netCDF(final_obs, final_lat, final_lon, final_time, day_start, day_end) - except Exception as e: - print(f"Error processing data for day {day}: {e}") - continue - - - def to_netCDF(self, obs, lat, lon, time, start_time, end_time): - try: - # Convert time to seconds since TIME_ORIGIN - final_time_seconds = (time - TIME_ORIGIN) / np.timedelta64(1, "s") - final_time_seconds = np.round(final_time_seconds, 0) - - grid_representation_str = json.dumps(self.grid_config, ensure_ascii=False) - qa_flags_str = json.dumps(self.qualityFlags, ensure_ascii=False) - - # Create xarray Dataset - ds = xr.Dataset( - coords={"time": ("time", final_time_seconds, { - "units": "seconds since 1900-01-01 00:00:00", - "calendar": "gregorian" - })}, - - data_vars={ - "lon": (["time"], lon, VARS_ATTRIBUTES['lon'].copy()), - "lat": (["time"], lat, VARS_ATTRIBUTES['lat'].copy()), - VARS_ATTRIBUTES[self.obs_var]["mapies_variable"]: (["time"], obs, VARS_ATTRIBUTES[self.obs_var].copy()), - }, - attrs={ - "title": f"{VARS_ATTRIBUTES['title']['description']} from {start_time} to {end_time}", - "institution": "Barcelona Supercomputing Center", - "grid": f"{self.grid_repr}: {grid_representation_str}", - "Developers": "MAPIES team", - "QA": f"Quality Assurance: {self.apply_qa}, quality assurance applied: {qa_flags_str}", - "history": f"Created {datetime.now()}", - }, - ) - - # Define encoding for compact storage - encoding = { - "time": {"dtype": "float32"}, - "lon": {"dtype": "float32"}, - "lat": {"dtype": "float32"}, - VARS_ATTRIBUTES[self.obs_var]["mapies_variable"]: {"dtype": "float32"}, - } - - # Save to NetCDF file - filename = f"{self.dest}/Data_{self.datatype}_{start_time}_{end_time}.nc" - ds.to_netcdf(filename, mode='w', encoding=encoding) 
- - print(f"Saved data to {filename}") - - except Exception as e: - print(f"Error saving data into netCDF: {e}") - return - - - - def plot_2D_observations(self, months=[0], filename=None, outdir=None): - - if months == [0]: - try: - lon_values = self.lon - lat_values = self.lat - obs_values = self.obs - except Exception as e: - print(f"Error plotting all data: {e}") - return - title = f"Observation 2D plot from {self.start_date} to {self.end_date}" - filename = filename or f"{self.datatype}_2D_obs_from_{self.start_date}_to_{self.end_date}.png" - super().plot_2D_obs_custom( - lon_values=lon_values, - lat_values=lat_values, - obs_values=obs_values, - outdir=outdir, - title=title, - filename=filename, - ) - return - elif months is not None: - for m in months: - try: - month_data = self.monthly_dict[m] - lon_values = month_data["lon"] - lat_values = month_data["lat"] - obs_values = month_data["obs"] - except Exception as e: - print(f"Error for month {m}: {e}") - continue - title = f"Observation 2D plot of month {m} from {self.start_date} to {self.end_date}" - filename = filename or f"{self.datatype}_2D_obs_month_{m}_year_{self.year}.png" - super().plot_2D_obs_custom( - lon_values=lon_values, - lat_values=lat_values, - obs_values=obs_values, - outdir=outdir, - title=title, - filename=filename, - ) - else: - try: - lon_values = self.yearly_lon - lat_values = self.yearly_lat - obs_values = self.yearly_obs - except Exception as e: - print(f"Error for yearly data: {e}") - return - title = f"Observation 2D plot of yearly average of year {self.year}" - filename = filename or f"{self.datatype}_2D_yearly_obs_{self.year}.png" - super().plot_2D_obs_custom( - lon_values=lon_values, - lat_values=lat_values, - obs_values=obs_values, - outdir=outdir, - title=title, - filename=filename, - ) - - def plot_2D_num_obs(self, months=[0], filename=None, outdir=None): - - if months == [0]: - try: - lon_values = self.lon - lat_values = self.lat - cum_count = self.count_obs - except Exception as e: - print(f"Error plotting all data: {e}") - return - title = f"Plot of the number of valid observations from {self.start_date} to {self.end_date}" - filename = filename or f"{self.datatype}_2D_obs_count_from_{self.start_date}_to_{self.end_date}.png" - super().plot_2D_num_obs_custom( - lon_values=lon_values, - lat_values=lat_values, - num_obs=cum_count, - outdir=outdir, - title=title, - filename=filename, - ) - return - elif months is not None: - for m in months: - try: - month_data = self.monthly_dict[m] - lon_values = month_data["lon"] - lat_values = month_data["lat"] - cum_count = month_data["count"] - except Exception as e: - print(f"Error for month {m}: {e}") - continue - title = f"Plot of the number of valid observations for month {m} from {self.start_date} to {self.end_date}" - filename = filename or f"{self.datatype}_2D_obs_count_month_{m}_year_{self.year}.png" - super().plot_2D_num_obs_custom( - lon_values=lon_values, - lat_values=lat_values, - num_obs=cum_count, - outdir=outdir, - title=title, - filename=filename, - ) - else: - try: - lon_values = self.yearly_lon - lat_values = self.yearly_lat - cum_count = self.yearly_count - except Exception as e: - print(f"Error for yearly data: {e}") - return - title = f"Plot of the number of valid observations for yearly average of {self.year}" - filename = filename or f"{outdir}/{self.datatype}_2D_obs_count_yearly_{self.year}.png" - super().plot_2D_num_obs_custom( - lon_values=lon_values, - lat_values=lat_values, - num_obs=cum_count, - outdir=outdir, - title=title, - 
filename=filename, - ) diff --git a/mapies/util/func_tools.py b/mapies/util/func_tools.py index 9a63fdc8fa44140631d9140f2bcfceda7d621dd7..4823d55c995371742084a911e57076055a298441 100644 --- a/mapies/util/func_tools.py +++ b/mapies/util/func_tools.py @@ -11,7 +11,6 @@ from typing import List from datetime import datetime import multiprocessing -from mapies.grids.monarch import RotatedGrid, IrregularRotatedGrid, RegularGrid, regridding_function TIME_ORIGIN = np.datetime64("1993-01-01T00:00:00") @@ -270,30 +269,6 @@ def error_estimation(datatype:str, obs:np.typing.NDArray, unc_const:list) -> np. - -# def geo_to_rot(lons, lats, centre_lon: float, centre_lat: float): -# """ -# Rotating coordinates from cartesian lat/lon to rotated rlon/rlat -# """ -# distance_lons = np.radians(lons - centre_lon) -# lons = np.radians(lons) -# lats = np.radians(lats) -# centre_lon = np.radians(centre_lon) -# centre_lat = np.radians(centre_lat) - -# x = cos(centre_lat) * sin(lats) - sin(centre_lat) * cos(lats) * cos(distance_lons) -# y = cos(lats) * sin(distance_lons) -# z = cos(centre_lat) * cos(lats) * cos(distance_lons) + sin(centre_lat) * sin(lats) -# # Arctan2 used -# # Explanation of the difference https://geo.libretexts.org/Courses/University_of_California_Davis/GEL_056%3A_Introduction_to_Geophysics/Geophysics_is_everywhere_in_geology.../zz%3A_Back_Matter/Arctan_vs_Arctan2 -# rlon = np.arctan2(y, z) -# rlat = np.arcsin(x) -# #rlon[x < 0] += pi -# # Convert back to degrees -# rlon = np.degrees(rlon) -# rlat = np.degrees(rlat) -# return rlon, rlat - def new_preprocess_vars(ds, obs_var, lat_var, lon_var, time_var, start_date, end_date): """ * Preprocessing of the dataset @@ -381,6 +356,8 @@ def day_boundaries(day, start_date, end_date): return day_start, day_end + + def filter_files(files, start_date, end_date): """ Filter files based on the start_date and end_date @@ -397,6 +374,8 @@ def filter_files(files, start_date, end_date): return files_filtered + + def generate_mid_time_array(start_date, end_date, obs): """ Generate the mid time array for the observations @@ -410,6 +389,8 @@ def generate_mid_time_array(start_date, end_date, obs): return mid_time_array + + def create_batches(files, batch_size): """ Create batches of files to pass to MP @@ -422,65 +403,69 @@ def create_batches(files, batch_size): return batches -def process_single_file( - r, - file, - obs_var, - lat_var, - lon_var, - time_var, - start_date, - end_date, - apply_qa, - qualityFlags, - flag=False, - geospatial_crop=None - ): - """ - Process a single file and return aggregated results or flattened arrays. - """ - try: - ds = xr.open_dataset(file, engine="h5netcdf") + + +# def process_single_file( +# r, +# file, +# obs_var, +# lat_var, +# lon_var, +# time_var, +# start_date, +# end_date, +# apply_qa, +# qualityFlags, +# flag=False, +# geospatial_crop=None +# ): +# """ +# Process a single file and return aggregated results or flattened arrays. 
+# """ +# try: +# ds = xr.open_dataset(file, engine="h5netcdf") - obs, lat, lon, time_values = new_preprocess_vars(ds, obs_var, lat_var, lon_var, time_var, start_date, end_date) +# obs, lat, lon, time_values = new_preprocess_vars(ds, obs_var, lat_var, lon_var, time_var, start_date, end_date) - #Run a time domain selection - if apply_qa: - obs, lat, lon, time_values = quality_assurance(ds, obs, lat, lon, time_values, qualityFlags) +# #Run a time domain selection +# if apply_qa: +# obs, lat, lon, time_values = quality_assurance(ds, obs, lat, lon, time_values, qualityFlags) - obs, lon, lat, time_values = time_domain_selection(obs, lat, lon, time_values, start_date, end_date) +# obs, lon, lat, time_values = time_domain_selection(obs, lat, lon, time_values, start_date, end_date) - if geospatial_crop is not None: - obs, lat, lon, time_values = spatial_domain_selection(obs, lat, lon, time_values, geospatial_crop) - - if not flag: - # Regrid and aggregate - if isinstance(r, RotatedGrid): - lon_agg, lat_agg, rlon_agg, rlat_agg, obs_agg, count_obs = regridding_function(r, lon, lat, obs) - elif isinstance(r, RegularGrid): - lon_agg, lat_agg, obs_agg, count_obs = regridding_function(r, lon, lat, obs) - ds.close() - return obs_agg, lat_agg, lon_agg, count_obs - else: - ds.close() - return obs, lat, lon, time_values +# if geospatial_crop is not None: +# obs, lat, lon, time_values = spatial_domain_selection(obs, lat, lon, time_values, geospatial_crop) + +# if not flag: +# # Regrid and aggregate +# if isinstance(r, RotatedGrid): +# lon_agg, lat_agg, rlon_agg, rlat_agg, obs_agg, count_obs = regridding_function(r, lon, lat, obs) +# elif isinstance(r, RegularGrid): +# lon_agg, lat_agg, obs_agg, count_obs = regridding_function(r, lon, lat, obs) +# ds.close() +# return obs_agg, lat_agg, lon_agg, count_obs +# else: +# ds.close() +# return obs, lat, lon, time_values + +# except Exception as e: +# print(f"Error processing file {file}: {e}") + +# return None - except Exception as e: - print(f"Error processing file {file}: {e}") - return None -def process_batch(args, batch): +# def process_batch(args, batch): - with multiprocessing.Pool(processes=8) as pool: # Adjust number of processes as needed - results = pool.starmap(process_single_file, args) +# with multiprocessing.Pool(processes=8) as pool: # Adjust number of processes as needed +# results = pool.starmap(process_single_file, args) - # Filter valid results and log failed files - valid_results = [res for res in results if res is not None] - print(f"Valid results: {len(valid_results)}") - failed_files= [file for file, res in zip(batch, results) if res is None] +# # Filter valid results and log failed files +# valid_results = [res for res in results if res is not None] +# print(f"Valid results: {len(valid_results)}") +# failed_files= [file for file, res in zip(batch, results) if res is None] - return valid_results, failed_files +# return valid_results, failed_files def combine_arrays(valid_results, cumulative_obs, cumulative_lat, cumulative_lon, cumulative_count): @@ -502,6 +487,8 @@ def combine_arrays(valid_results, cumulative_obs, cumulative_lat, cumulative_lon return cumulative_obs, cumulative_lat, cumulative_lon, cumulative_count + + def compute_final_averages(cumulative_obs, cumulative_lat, cumulative_lon, cumulative_count): """ Compute the final averages from the cumulative arrays. In other words divide the cumulative_obs by the cumulative_count. 
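Reviewer note: as context for the DRY refactor, below is a minimal sketch of the count-weighted aggregation pattern that `combine_arrays` and `compute_final_averages` factor out of the per-satellite classes. It is reconstructed from the inline aggregation code deleted from `tropomi.py` later in this diff; the exact bodies in `func_tools.py` are not shown here, and the variable names are illustrative only.

```python
import numpy as np

def combine_arrays(valid_results, cum_obs, cum_lat, cum_lon, cum_count):
    """Accumulate count-weighted sums so batches can be merged incrementally."""
    for obs, lat, lon, count in valid_results:
        if cum_obs is None:
            cum_obs = obs * count
            cum_lat = lat * count
            cum_lon = lon * count
            cum_count = count.copy()
        else:
            cum_obs += obs * count
            cum_lat += lat * count
            cum_lon += lon * count
            cum_count += count
    return cum_obs, cum_lat, cum_lon, cum_count

def compute_final_averages(cum_obs, cum_lat, cum_lon, cum_count):
    """Divide the weighted sums by the counts, leaving never-observed cells at 0."""
    valid = cum_count > 0
    final_obs = np.zeros_like(cum_obs)
    final_lat = np.zeros_like(cum_lat)
    final_lon = np.zeros_like(cum_lon)
    final_obs[valid] = cum_obs[valid] / cum_count[valid]
    final_lat[valid] = cum_lat[valid] / cum_count[valid]
    final_lon[valid] = cum_lon[valid] / cum_count[valid]
    return final_obs, final_lat, final_lon
```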
@@ -518,6 +505,8 @@ def compute_final_averages(cumulative_obs, cumulative_lat, cumulative_lon, cumul return final_obs, final_lat, final_lon + + def append_and_concatenate(valid_results, day): """ Append and concatenate arrays @@ -546,6 +535,8 @@ def append_and_concatenate(valid_results, day): return final_obs, final_lat, final_lon, final_time + + def filter_observations(obs, lat, lon, count): """ Filter out observations with negative values @@ -557,4 +548,5 @@ def filter_observations(obs, lat, lon, count): final_obs = obs[valid_obs] final_count = count[valid_obs] - return final_obs, final_lat, final_lon, final_count \ No newline at end of file + return final_obs, final_lat, final_lon, final_count + diff --git a/mapies/util/multiprocessing.py b/mapies/util/multiprocessing.py new file mode 100644 index 0000000000000000000000000000000000000000..0557fe24a6dc11515364f770a3a943357bf2bade --- /dev/null +++ b/mapies/util/multiprocessing.py @@ -0,0 +1,279 @@ +import time +import logging +import numpy as np +import pandas as pd +import xarray as xr + +from functools import wraps +from numpy import cos, sin, arctan, pi, nan +from glob import glob +from typing import List +from datetime import datetime +from abc import ABC, abstractmethod +import multiprocessing +from dataclasses import dataclass +from numpy.typing import DTypeLike, NDArray + +from mapies.util.func_tools import * +from mapies.grids.monarch import Grid, RotatedGrid, IrregularRotatedGrid, RegularGrid, regridding_function + + +def process_single_file_tropomi( + r, + file, + obs_var, + lat_var, + lon_var, + time_var, + start_date, + end_date, + apply_qa, + qualityFlags, + flag=False, + geospatial_crop=None + ): + """ + Process a single file and return aggregated results or flattened arrays. + """ + try: + ds = xr.open_dataset(file, engine="h5netcdf", group="PRODUCT") + except Exception as e: + print(f"Reading file {file}: {e}") + return None + + # Flatten arrays + try: + obs, lat, lon, time_values = new_preprocess_vars(ds, obs_var, lat_var, lon_var, time_var, start_date, end_date) + except Exception as e: + print(f"Error processing file {file}: {e}") + return None + + # Apply quality assurance if applicable + if apply_qa: + try: + obs, lat, lon, time_values = quality_assurance(ds, obs, lat, lon, time_values, qualityFlags) + except Exception as e: + print(f"Error applying QA Flags for file {file}: {e}") + return None + # Perform time domain selection beteen start and end date + try: + obs, lon, lat, time_values = time_domain_selection(obs, lat, lon, time_values, start_date, end_date) + except Exception as e: + print(f"Error performing time domain selection for file {file}: {e}") + return None + + # Perform station domain selection if specified + if geospatial_crop is not None: + try: + obs, lat, lon, time_values = spatial_domain_selection(obs, lat, lon, time_values, geospatial_crop) + except Exception as e: + print(f"Error performing spatial domain selection for file {file}: {e}") + return None + + if not flag: + try: + # Regrid and aggregate + if isinstance(r, RotatedGrid): + lon_agg, lat_agg, rlon_agg, rlat_agg, obs_agg, count_obs = regridding_function(r, lon, lat, obs) + elif isinstance(r, RegularGrid): + lon_agg, lat_agg, obs_agg, count_obs = regridding_function(r, lon, lat, obs) + ds.close() + return obs_agg, lat_agg, lon_agg, count_obs + except Exception as e: + print(f"Error performing regridding for file {file}: {e}") + return None + else: + ds.close() + return obs, lat, lon, time_values + + + +def process_single_file_viirs( + r, + 
file, + obs_var, + lat_var, + lon_var, + time_var, + start_date, + end_date, + apply_qa, + qualityFlags, + flag=False, + geospatial_crop=None + ): + """ + Process a single file and return aggregated results or flattened arrays. + """ + try: + ds = xr.open_dataset(file, engine="h5netcdf") + except Exception as e: + print(f"Reading file {file}: {e}") + return None + + # Flatten arrays + try: + obs, lat, lon, time_values = new_preprocess_vars(ds, obs_var, lat_var, lon_var, time_var, start_date, end_date) + except Exception as e: + print(f"Error processing file {file}: {e}") + return None + + # Apply quality assurance if applicable + if apply_qa: + try: + obs, lat, lon, time_values = quality_assurance(ds, obs, lat, lon, time_values, qualityFlags) + except Exception as e: + print(f"Error applying QA Flags for file {file}: {e}") + return None + # Perform time domain selection beteen start and end date + try: + obs, lon, lat, time_values = time_domain_selection(obs, lat, lon, time_values, start_date, end_date) + except Exception as e: + print(f"Error performing time domain selection for file {file}: {e}") + return None + + # Perform station domain selection if specified + if geospatial_crop is not None: + try: + obs, lat, lon, time_values = spatial_domain_selection(obs, lat, lon, time_values, geospatial_crop) + except Exception as e: + print(f"Error performing spatial domain selection for file {file}: {e}") + return None + + if not flag: + try: + # Regrid and aggregate + if isinstance(r, RotatedGrid): + lon_agg, lat_agg, rlon_agg, rlat_agg, obs_agg, count_obs = regridding_function(r, lon, lat, obs) + elif isinstance(r, RegularGrid): + lon_agg, lat_agg, obs_agg, count_obs = regridding_function(r, lon, lat, obs) + ds.close() + return obs_agg, lat_agg, lon_agg, count_obs + except Exception as e: + print(f"Error performing regridding for file {file}: {e}") + return None + else: + ds.close() + return obs, lat, lon, time_values + + + +class GenericMultiProcessing(ABC): + """ + Base method for multiprocessing + """ + + @abstractmethod + def process_batch(self): + pass + + +@dataclass +class VIIRSMultiprocessing(GenericMultiProcessing): + """ + Multiprocessing method for VIIRS + """ + r: Grid + obs_var: NDArray + lat_var: NDArray + lon_var: NDArray + time_var: NDArray + start_date: pd.Timestamp + end_date: pd.Timestamp + apply_qa: bool + qualityFlags: dict + batch: List[str] + geospatial_crop: NDArray = None + flag: bool = False + + + + def process_batch(self): + """ + Batch processing for VIIRS + """ + + args = [ + ( + self.r, + file, + self.obs_var, + self.lat_var, + self.lon_var, + self.time_var, + self.start_date, + self.end_date, + self.apply_qa, + self.qualityFlags, + self.flag, + self.geospatial_crop + ) + for file in self.batch + ] + + + + with multiprocessing.Pool(processes=8) as pool: # Adjust number of processes as needed + results = pool.starmap(process_single_file_viirs, args) + + # Filter valid results and log failed files + valid_results = [res for res in results if res is not None] + print(f"Valid results: {len(valid_results)}") + failed_files= [file for file, res in zip(self.batch, results) if res is None] + + return valid_results, failed_files + + +@dataclass +class TROPOMIMultiprocessing(GenericMultiProcessing): + """ + Multiprocessing method for TROPOMI + """ + r: Grid + obs_var: NDArray + lat_var: NDArray + lon_var: NDArray + time_var: NDArray + start_date: pd.Timestamp + end_date: pd.Timestamp + apply_qa: bool + qualityFlags: dict + batch: List[str] + geospatial_crop: NDArray = None 
+ flag: bool = False + + + + def process_batch(self): + """ + Batch processing for TROPOMI + """ + + args = [ + ( + self.r, + file, + self.obs_var, + self.lat_var, + self.lon_var, + self.time_var, + self.start_date, + self.end_date, + self.apply_qa, + self.qualityFlags, + self.flag, + self.geospatial_crop + ) + for file in self.batch + ] + + + with multiprocessing.Pool(processes=8) as pool: # Adjust number of processes as needed + results = pool.starmap(process_single_file_tropomi, args) + + # Filter valid results and log failed files + valid_results = [res for res in results if res is not None] + print(f"Valid results: {len(valid_results)}") + failed_files= [file for file, res in zip(self.batch, results) if res is None] + + return valid_results, failed_files diff --git a/mapies/viirs.py b/mapies/viirs.py index 3b045bc160abd02fb98e1b62e8f01e7ad3e48901..3670d494b3763d7c24dc9985d599544ea7a0fb93 100644 --- a/mapies/viirs.py +++ b/mapies/viirs.py @@ -5,7 +5,7 @@ from functools import wraps import logging import os import time -from tqdm import tqdm +#from tqdm import tqdm import yaml import multiprocessing import gc @@ -29,65 +29,13 @@ from mapies.util.variables_names import VARS_ATTRIBUTES from pathlib import Path import json from numpy.typing import DTypeLike, NDArray +from mapies.util.multiprocessing import VIIRSMultiprocessing os.environ["OMP_NUM_THREADS"] = "8" #logger = logging.getLogger(__name__) #logging.basicConfig(filename='../logs/mapies.log', level=logging.INFO) -TIME_ORIGIN = np.datetime64("1900-01-01T00:00:00") - - - - -# def process_single_file( -# r, -# file, -# obs_var, -# lat_var, -# lon_var, -# time_var, -# start_date, -# end_date, -# apply_qa, -# qualityFlags, -# flag=False, -# geospatial_crop=None -# ): -# """ -# Process a single file and return aggregated results or flattened arrays. -# """ -# try: -# ds = xr.open_dataset(file, engine="h5netcdf") - -# obs, lat, lon, time_values = new_preprocess_vars(ds, obs_var, lat_var, lon_var, time_var, start_date, end_date) - -# #Run a time domain selection -# if apply_qa: -# obs, lat, lon, time_values = quality_assurance(ds, obs, lat, lon, time_values, qualityFlags) - -# obs, lon, lat, time_values = time_domain_selection(obs, lat, lon, time_values, start_date, end_date) - -# if geospatial_crop is not None: -# obs, lat, lon, time_values = spatial_domain_selection(obs, lat, lon, time_values, geospatial_crop) - -# if not flag: -# # Regrid and aggregate -# if isinstance(r, RotatedGrid): -# lon_agg, lat_agg, rlon_agg, rlat_agg, obs_agg, count_obs = regridding_function(r, lon, lat, obs) -# elif isinstance(r, RegularGrid): -# lon_agg, lat_agg, obs_agg, count_obs = regridding_function(r, lon, lat, obs) -# ds.close() -# return obs_agg, lat_agg, lon_agg, count_obs -# else: -# ds.close() -# return obs, lat, lon, time_values - -# except Exception as e: -# print(f"Error processing file {file}: {e}") - -# return None - class VIIRS(MAPIES): """ @@ -102,6 +50,7 @@ class VIIRS(MAPIES): self.time_orig = np.datetime64("1993-01-01T00:00:00") self.datatype="viirs" + self.mp_type = VIIRSMultiprocessing self.dest = kwargs.get("dest") if not os.path.exists(self.dest): raise Exception("Output directory doesn't exist") @@ -207,327 +156,6 @@ class VIIRS(MAPIES): print(f"Day {day:02d}: {len(data['files'])} files") - - @timeit - def process_data(self, monthly_avg=False, batch_size=100, save=False, apply_qa=True, geospatial_crop: NDArray | None = None): - - """ - Process the data. 
- - Parameters:geospatial_crop: NDArray | None = None - - monthly_avg: bool, whether to compute single monthly averages. - - batch_size: int, the number of files to process in each batch. - - Returns: - - if monthly_avg --> updated monthly dictorionary for the processed month. - - if not monthly_avg --> updated self.lon_values, self.lat_values, self.obs, self.count_obs. - """ - - #Check if the user wants qa flags - - if apply_qa: - self.apply_qa = True - else: - self.apply_qa = False - if geospatial_crop is not None: - self.geospatial_crop = np.array(geospatial_crop) - if monthly_avg: - for month in self.monthly_dict.keys(): - # print(f'Month: {self.monthly_dict[month]}') - # Catch if they haven't selected a grid - try: - print(f"Running average on {self.grid}") - except AttributeError: - raise Exception("Please specify a grid in the initialisation of MAPIES, using grid_repr=''") - try: - # Process the data for the current month - final_lon, final_lat, final_obs, cumulative_count = self.process_monthly_data( - batch_size=batch_size, month=month - ) - - # Update the monthly dictionary - # valid_obs = final_obs > 0 - # self.monthly_dict[month] = { - # "lon": final_lon[valid_obs], - # "lat": final_lat[valid_obs], - # "obs": final_obs[valid_obs], - # "count": cumulative_count[valid_obs], - # } - self.monthly_dict[month] = { - "lon": final_lon, - "lat": final_lat, - "obs": final_obs, - "count": cumulative_count, - } - print(f'Updated monthly dictionary for month {month}') - print(f'Lon shape: {final_lon.shape}, Lat shape: {final_lat.shape}, Obs shape: {final_obs.shape}') - if np.all(final_lon == 0) or np.all(np.isnan(final_lon)): - print(f"All longitude values for month {month} are 0 or NaN.") - if np.all(final_lat == 0) or np.all(np.isnan(final_lat)): - print(f"All latitude values for month {month} are 0 or NaN.") - if np.all(final_obs == 0) or np.all(np.isnan(final_obs)): - print(f"All observation values for month {month} are 0 or NaN.") - - except Exception as e: - print(f"Error processing data for month {month}: {e}") - continue - else: - try: - final_lon, final_lat, final_obs, cumulative_count = self.process_monthly_data( - batch_size=batch_size, month=None - ) - - self.obs, self.lat, self.lon, self.count_obs = filter_observations(final_obs, final_lat, final_lon, cumulative_count) - # valid_obs = final_obs > 0 - # self.lon = final_lon[valid_obs] - # self.lat = final_lat[valid_obs] - # self.obs = final_obs[valid_obs] - # self.count_obs = cumulative_count[valid_obs] - - if save: - try: - # mid_time = self.start_date + (self.end_date - self.start_date) / 2 - # mid_time = np.datetime64(mid_time) - # N = len(self.obs) - # mid_time_array = np.repeat(mid_time, N) - mid_time_array = generate_mid_time_array(self.start_date, self.end_date, self.obs) - - self.to_netCDF(self.obs, self.lat, self.lon, mid_time_array, self.start_date, self.end_date) - print(f'Saved processed data to netCDF') - except Exception as e: - print(f"Error saving data into netCDF: {e}") - return - - except Exception as e: - print(f"Error processing data for the time period from {self.start_date} to {self.end_date}: {e}") - return - - - def process_monthly_data(self, batch_size, month=None): - - """ - Process the data with multiprocessing. 
- """ - flag = False - if month is None: - month_files = self.files - print(f'Processing all files from {self.start_date} to {self.end_date}') - else: - month_files = self.monthly_dict[month]["files"] - print(f"Processing {len(month_files)} files for month {month}") - - - # Split files into batches - batches = create_batches(month_files, batch_size) - - # Initialize cumulative arrays - cumulative_obs = None - cumulative_lat = None - cumulative_lon = None - cumulative_count = None - failed_files = [] - - for batch in batches: - print(f"Processing a batch of {len(batch)} files...") - args = [ - ( - self.grid, - file, - self.obs_var, - self.lat_var, - self.lon_var, - self.time_var, - self.start_date, - self.end_date, - self.apply_qa, - self.qualityFlags, - flag, - self.geospatial_crop - ) - for file in batch - ] - valid_results, failed = process_batch(args, batch) - failed_files.extend(failed) - - if not valid_results: - print("All files in this batch failed. Skipping...") - continue - - # Aggregate results incrementally - cumulative_obs, cumulative_lat, cumulative_lon, cumulative_count = combine_arrays(valid_results, cumulative_obs, cumulative_lat, cumulative_lon, cumulative_count) - - if failed_files: - print(f"The following files failed to process: {failed_files}") - - # Check if any valid data was processed - if cumulative_obs is None or cumulative_count is None: - print(f"No valid data processed for month {month}. Returning empty arrays.") - return None, None, None, None - - # Compute final averages - final_obs, final_lat, final_lon = compute_final_averages(cumulative_obs, cumulative_lat, cumulative_lon, cumulative_count) - - print(f"Completed processing for month {month}.") - - return final_lon, final_lat, final_obs, cumulative_count - - - @timeit - def yearly_average(self): - """ - Compute the yearly average for the processed data. - """ - cumulative_obs = None - cumulative_lat = None - cumulative_lon = None - cumulative_count = None - - for key, month_data in self.monthly_dict.items(): - if not all(var in month_data for var in ["lon", "lat", "obs", "count"]): - print(f"Skipping month {key}: Missing data.") - continue - if any(month_data[var] is None for var in ["lon", "lat", "obs", "count"]): - print(f"Skipping month {key}: No valid data.") - continue - - lon_values = month_data["lon"] - lat_values = month_data["lat"] - obs_values = month_data["obs"] - count_values = month_data["count"] - - results = [(obs_values, lat_values, lon_values, count_values)] - cumulative_obs, cumulative_lat, cumulative_lon, cumulative_count = combine_arrays(results, cumulative_obs, cumulative_lat, cumulative_lon, cumulative_count) - - if cumulative_obs is None or cumulative_count is None or np.all(cumulative_count == 0): - print("No valid data for the entire year. Returning empty arrays.") - self.yearly_lon = None - self.yearly_lat = None - self.yearly_obs = None - self.yearly_count = None - return - - yearly_obs, yearly_lat, yearly_lon = compute_final_averages(cumulative_obs, cumulative_lat, cumulative_lon, cumulative_count) - - self.yearly_lon = yearly_lon - self.yearly_lat = yearly_lat - self.yearly_obs = yearly_obs - self.yearly_count = cumulative_count - - print("Yearly average computation completed.") - - @timeit - def process_lazy_data(self, apply_qa=True, geospatial_crop: NDArray | None = None, save=True): - """ - Process the data for the specified time range. It only retrieves a daily dictionary using dask lazy loading. 
- """ - flag = True - #Check if the user wants qa flags - if apply_qa: - self.apply_qa = True - else: - self.apply_qa = False - if geospatial_crop is not None: - self.geospatial_crop = np.array(geospatial_crop) - for day in self.daily_dict.keys(): - try: - files = self.daily_dict[day]["files"] - args = [ - ( - self.grid, - file, - self.obs_var, - self.lat_var, - self.lon_var, - self.time_var, - self.start_date, - self.end_date, - self.apply_qa, - self.qualityFlags, - flag, - self.geospatial_crop - - ) - for file in files - ] - - valid_results, failed_files = process_batch(args, files) - - final_obs, final_lat, final_lon, final_time = append_and_concatenate(valid_results, day) - - # Skip the day if no valid data was processed - if final_obs is None: - continue - - # Check if any valid data exists - if np.all(final_obs == 0) or np.all(np.isnan(final_obs)): - print(f"All observation values for day {day} are 0 or NaN. Skipping...") - continue - - # Update day dict - self.daily_dict[day]["obs"] = final_obs - self.daily_dict[day]["lon"] = final_lon - self.daily_dict[day]["lat"] = final_lat - self.daily_dict[day]["time"] = final_time - - # Get day start and day end but only if they are within object start and end date - day_start, day_end = day_boundaries(day, self.start_date, self.end_date) - if save: - self.to_netCDF(final_obs, final_lat, final_lon, final_time, day_start, day_end) - - except Exception as e: - print(f"Error processing data for day {day}: {e}") - continue - - def to_netCDF(self, obs, lat, lon, time, start_time, end_time): - try: - # Convert time to seconds since TIME_ORIGIN - final_time_seconds = (time - TIME_ORIGIN) / np.timedelta64(1, "s") - final_time_seconds = np.round(final_time_seconds, 0) - - grid_representation_str = json.dumps(self.grid_config, ensure_ascii=False) - qa_flags_str = json.dumps(self.qualityFlags, ensure_ascii=False) - - # Create xarray Dataset - ds = xr.Dataset( - coords={"time": ("time", final_time_seconds, { - "units": "seconds since 1900-01-01 00:00:00", - "calendar": "gregorian" - })}, - - data_vars={ - "lon": (["time"], lon, VARS_ATTRIBUTES['lon'].copy()), - "lat": (["time"], lat, VARS_ATTRIBUTES['lat'].copy()), - VARS_ATTRIBUTES[self.obs_var]["mapies_variable"]: (["time"], obs, VARS_ATTRIBUTES[self.obs_var].copy()), - }, - attrs={ - "title": f"{VARS_ATTRIBUTES['title']['description']} from {start_time} to {end_time}", - "institution": "Barcelona Supercomputing Center", - "grid": f"{self.grid_repr}: {grid_representation_str}", - "Developers": "MAPIES team", - "QA": f"Quality Assurance: {self.apply_qa}, quality assurance applied: {qa_flags_str}", - "history": f"Created {datetime.now()}", - }, - ) - - # Define encoding for compact storage - encoding = { - "time": {"dtype": "float32"}, - "lon": {"dtype": "float32"}, - "lat": {"dtype": "float32"}, - VARS_ATTRIBUTES[self.obs_var]["mapies_variable"]: {"dtype": "float32"}, - } - - # Save to NetCDF file - filename = f"{self.dest}/Data_{self.datatype}_{start_time}_{end_time}.nc" - ds.to_netcdf(filename, mode='w', encoding=encoding) - - print(f"Saved data to {filename}") - - except Exception as e: - print(f"Error saving data into netCDF: {e}") - return - - # ============================================================================= # Supporting functions # ============================================================================= @@ -639,141 +267,4 @@ class VIIRS(MAPIES): ds.to_netcdf(filename, encoding={}) outfiles.append(filename) - return outfiles - - - - - - - - - -# 
============================================================================= -# Plotting -# ============================================================================= - def plot_2D_observations(self, months=[0], filename=None, outdir=None, display_fig=False): - - if months == [0]: - try: - lon_values = self.lon - lat_values = self.lat - obs_values = self.obs - except Exception as e: - print(f"Error plotting all data: {e}") - return - title = f"Observation 2D plot from {self.start_date} to {self.end_date}" - filename = filename or f"{self.datatype}_2D_obs_from_{self.start_date}_to_{self.end_date}.png" - super().plot_2D_obs_custom( - lon_values=lon_values, - lat_values=lat_values, - obs_values=obs_values, - outdir=outdir, - title=title, - filename=filename, - display_fig=display_fig - ) - return - elif months is not None: - for m in months: - try: - month_data = self.monthly_dict[m] - lon_values = month_data["lon"] - lat_values = month_data["lat"] - obs_values = month_data["obs"] - except Exception as e: - print(f"Error for month {m}: {e}") - continue - title = f"Observation 2D plot of month {m} from {self.start_date} to {self.end_date}" - filename = filename or f"{self.datatype}_2D_obs_month_{m}_year_{self.year}.png" - super().plot_2D_obs_custom( - lon_values=lon_values, - lat_values=lat_values, - obs_values=obs_values, - outdir=outdir, - title=title, - filename=filename, - display_fig=display_fig - ) - else: - try: - lon_values = self.yearly_lon - lat_values = self.yearly_lat - obs_values = self.yearly_obs - except Exception as e: - print(f"Error for yearly data: {e}") - return - title = f"Observation 2D plot of yearly average of year {self.year}" - filename = filename or f"{self.datatype}_2D_yearly_obs_{self.year}.png" - super().plot_2D_obs_custom( - lon_values=lon_values, - lat_values=lat_values, - obs_values=obs_values, - outdir=outdir, - title=title, - filename=filename, - display_fig=display_fig, - ) - - def plot_2D_num_obs(self, months=[0], filename=None, outdir=None, display_fig=False): - - if months == [0]: - try: - lon_values = self.lon - lat_values = self.lat - cum_count = self.count_obs - except Exception as e: - print(f"Error plotting all data: {e}") - return - title = f"Plot of the number of valid observations from {self.start_date} to {self.end_date}" - filename = filename or f"{self.datatype}_2D_obs_count_from_{self.start_date}_to_{self.end_date}.png" - super().plot_2D_num_obs_custom( - lon_values=lon_values, - lat_values=lat_values, - num_obs=cum_count, - outdir=outdir, - title=title, - filename=filename, - display_fig=display_fig - ) - return - elif months is not None: - for m in months: - try: - month_data = self.monthly_dict[m] - lon_values = month_data["lon"] - lat_values = month_data["lat"] - cum_count = month_data["count"] - except Exception as e: - print(f"Error for month {m}: {e}") - continue - title = f"Plot of the number of valid observations for month {m} from {self.start_date} to {self.end_date}" - filename = filename or f"{self.datatype}_2D_obs_count_month_{m}_year_{self.year}.png" - super().plot_2D_num_obs_custom( - lon_values=lon_values, - lat_values=lat_values, - num_obs=cum_count, - outdir=outdir, - title=title, - filename=filename, - display_fig=display_fig - ) - else: - try: - lon_values = self.yearly_lon - lat_values = self.yearly_lat - cum_count = self.yearly_count - except Exception as e: - print(f"Error for yearly data: {e}") - return - title = f"Plot of the number of valid observations for yearly average of {self.year}" - filename = filename or 
f"{outdir}/{self.datatype}_2D_obs_count_yearly_{self.year}.png" - super().plot_2D_num_obs_custom( - lon_values=lon_values, - lat_values=lat_values, - num_obs=cum_count, - outdir=outdir, - title=title, - filename=filename, - display_fig=display_fig - ) \ No newline at end of file + return outfiles \ No newline at end of file diff --git a/run/run_tropomi.py b/run/run_tropomi.py index ad5ff7786fe3154d79995620d71e2c269c08030c..5e4d7a22c97820c772a003bde11330928cf1d812 100644 --- a/run/run_tropomi.py +++ b/run/run_tropomi.py @@ -15,12 +15,12 @@ if __name__ == "__main__": c = TROPOMI(start_date, end_date, dest=outdir, indir=indir, grid_repr="regular") c.read_nc() - c.process_data(monthly_avg=False, batch_size = 14, apply_qa=True, geospatial_crop=[[-40, 40], [0, 70]]) + c.process_data(monthly_avg=False, batch_size = 14, apply_qa=True, geospatial_crop=[[-40, 60], [0, 70]]) # c.yearly_average() c.plot_2D_observations(months=[0], outdir=outdir) c.plot_2D_num_obs(months=[0], outdir=outdir) - c.process_lazy_data(apply_qa=True, geospatial_crop=[[-40, 40], [0, 70]]) + c.process_lazy_data(apply_qa=True, geospatial_crop=[[-40, 60], [0, 70]]) end_time = time.time() elapsed_time = end_time - start_time print(f"Script executed in {elapsed_time:.2f} seconds.") diff --git a/run/run_viirs.py b/run/run_viirs.py index 41eff9979050054f8a4ae93f749ee768ffbca6fc..60959d360f0d573e9fb77d288e4f62cf6a16d5d8 100644 --- a/run/run_viirs.py +++ b/run/run_viirs.py @@ -6,16 +6,16 @@ import time if __name__ == "__main__": # Change these variables # 00:30 to 02:36 from 2024-01-01 to 2024-01-03 - start_date = "202401010000" - end_date = "202401012359" + start_date = "202401300000" + end_date = "202402012359" - # outdir="/home/cmeikle/Projects/data/" - # indir="/home/cmeikle/Projects/data/VIIRS" + outdir="/home/cmeikle/Projects/data/" + indir="/home/cmeikle/Projects/data/VIIRS" - outdir="/home/cgile/Documents/mapies/figures" - indir="/home/cgile/Documents/mapies/VIIRS" + #outdir="/home/cgile/Documents/mapies/figures" + #indir="/home/cgile/Documents/mapies/VIIRS" start_time = time.time() @@ -24,12 +24,12 @@ if __name__ == "__main__": c.read_nc() - c.process_data(monthly_avg=False, batch_size = 100, apply_qa=True, save=True) - # c.yearly_average() - # c.plot_2D_observations(months=None, outdir=outdir) - # c.plot_2D_num_obs(months=None, outdir=outdir) - # c.process_lazy_data(apply_qa=True, save=True) - # c.to_da(frequency="3H", save_figs=True) + c.process_data(monthly_avg=True, batch_size = 100, apply_qa=True, save=True) + c.yearly_average() + c.plot_2D_observations(months=None, filename="new.png", outdir=outdir) + c.plot_2D_num_obs(months=None, outdir=outdir) + #c.process_lazy_data(apply_qa=True, save=True) + #c.to_da(frequency="3H", save_figs=True) diff --git a/setup.py b/setup.py index 0916106fdd0741f4c2ab41fdb7b522244eb9f521..de6cc56e769d8c73dd27f8ad8af779cee12fa055 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,7 @@ from setuptools import find_packages from setuptools import setup # Could update this using versioneer -version="0.1.0" +version="0.1.1" setup( name="mapies",