diff --git a/mapies/config/satellite_config.yaml b/mapies/config/satellite_config.yaml
index 847a47d71a4792fb8f4467cc2c95294d6269738b..af4837f296bb8916bc8f4cf7c1bc42df89e8f8b7 100644
--- a/mapies/config/satellite_config.yaml
+++ b/mapies/config/satellite_config.yaml
@@ -47,11 +47,27 @@ tropomi:
     lon_variable: "longitude"
     lat_variable: "latitude"
     obs_variable: "nitrogendioxide_tropospheric_column"
+  grids:
+    grid_repr: "" # default grid repr
+    rotated:
+      centre_lat: 35
+      centre_lon: 20
+      west: -51
+      south: -35
+      dlon: 0.1
+      dlat: 0.1
+    regular:
+      west: -180
+      south: -90
+      dlon: 0.1
+      dlat: 0.1
+      #nlon: 3600
+      #nlat: 1800
   qa:
     qa_variables:
       qa_value:
         qa_name: "qa_value"
-        qa_values: [0.5]
+        qa_values: [1] #[0.62, 0.65999997, 0.66999996, 0.69, 0.72999996, 0.74, 0.78999996, 0.84, 0.88, 0.9, 0.93, 1.]
         qa_combined: True
diff --git a/mapies/mapies.py b/mapies/mapies.py
index 206969bc5f97480d10b73534dfe06eba4b8e4842..adca3b403ea2aa7e53943394c7a6943ef91c8ad4 100644
--- a/mapies/mapies.py
+++ b/mapies/mapies.py
@@ -40,6 +40,7 @@ class MAPIES:
         self.time_orig = np.datetime64("1900-01-01T00:00:00")
         self.start_date = start_date
         self.end_date = end_date
+        self.grid_dict = {
@@ -189,8 +190,13 @@ class MAPIES:
         proj = ccrs.PlateCarree()
 
         # Define discrete levels for the color bar
-        levels = [0.02, 0.04, 0.06, 0.08, 0.10, 0.15, 0.20, 0.30, 0.40, 0.50, 0.60, 0.70, 0.80, 1.00]
-
+        if self.datatype == "viirs":
+            label = "AOD"
+            levels = [0.02, 0.04, 0.06, 0.08, 0.10, 0.15, 0.20, 0.30, 0.40, 0.50, 0.60, 0.70, 0.80, 1.00]
+        elif self.datatype == "tropomi":
+            #TODO remove hard coding
+            label = "NO2"
+            levels = [0.00001, 0.00002, 0.00003, 0.00004, 0.00005, 0.00006, 0.00007, 0.00008, 0.00009, 0.00010, 0.00015, 0.00020, 0.00030, 0.00040]
 
         # Define the custom color list
         colors = [
             '#290AD8', '#2645FA', '#3A91FF', '#66CBFF', '#99EEFF',
@@ -216,7 +222,7 @@ class MAPIES:
 
         # Add the color bar
         cbar = fig.colorbar(im, ax=ax, orientation="vertical", pad=0.05, ticks=levels, extend="both")
-        cbar.set_label("AOD")
+        cbar.set_label(label)
         cbar.ax.set_yticklabels([f"{level:.2f}" for level in levels])  # Format tick labels
 
         # Add title and save the plot
diff --git a/mapies/tropomi.py b/mapies/tropomi.py
index aaf5cdc6f20a01102f05d5525dce958bf056111e..26e3bef76e99eb77876e60230ba84f786749a3fb 100644
--- a/mapies/tropomi.py
+++ b/mapies/tropomi.py
@@ -1,7 +1,6 @@
 #!/usr/bin/env python
 from dataclasses import dataclass, field
 from typing import List, Tuple
-from netCDF4 import Dataset, num2date, chartostring
 from datetime import datetime, timedelta
 from functools import wraps
 
@@ -9,10 +8,62 @@ from functools import wraps
 from mapies.mapies import MAPIES
 from mapies.util.func_tools import timeit, get_file_list, time_converter, inflate_array
 import time
+import multiprocessing
 import numpy as np
 import pandas as pd
 import xarray as xr
-
+import os
+
+from mapies.util.func_tools import *
+#timeit, get_file_list, time_converter, frequency_int, error_estimation, time_domain_selection, geo_to_rot
+from mapies.grids.monarch import RotatedGrid, IrregularRotatedGrid, RegularGrid, regridding_function
+from numpy.typing import DTypeLike, NDArray
+
+
+def process_single_file(
+    r,
+    file,
+    obs_var,
+    lat_var,
+    lon_var,
+    time_var,
+    start_date,
+    end_date,
+    apply_qa,
+ qualityFlags, + flag=False, + geospatial_crop=None + ): + """ + Process a single file and return aggregated results or flattened arrays. + """ + try: + ds = xr.open_dataset(file, engine="h5netcdf", group="PRODUCT") + + obs, lat, lon, time_values, time_values_index = new_preprocess_vars(ds, obs_var, lat_var, lon_var, time_var, start_date, end_date) + # obs = ds[obs_var].values.flatten() + # lat = ds[lat_var].values.flatten() + + if apply_qa: + obs, lat, lon, time_values = quality_assurance(ds, obs, lat, lon, time_values, qualityFlags) + if geospatial_crop is not None: + obs, lat, lon, time_values = spatial_domain_selection(obs, lat, lon, time_values, geospatial_crop) + + if not flag: + # Regrid and aggregate + if isinstance(r, RotatedGrid): + lon_agg, lat_agg, rlon_agg, rlat_agg, obs_agg, count_obs = regridding_function(r, lon, lat, obs) + elif isinstance(r, RegularGrid): + lon_agg, lat_agg, obs_agg, count_obs = regridding_function(r, lon, lat, obs) + ds.close() + return obs_agg, lat_agg, lon_agg, count_obs + else: + ds.close() + return obs, lat, lon, time_values + + except Exception as e: + print(f"Error processing file {file}: {e}") + return None class TROPOMI(MAPIES): """ @@ -26,9 +77,17 @@ class TROPOMI(MAPIES): super().__init__(start_date, end_date) self.time_orig = np.datetime64("1993-01-01T00:00:00") - self.lat=lat - self.lon=lon + self.dest = kwargs.get("dest") + if not os.path.exists(self.dest): + raise Exception("Output directory doesn't exist") + self.indir = kwargs.get("indir") + self.year = int(start_date[:4]) #self.quality_flag_limit + self.monthly_dict = {} + self.daily_dict = {} # Dictionary to produce daily netcdf files + self.datatype = "tropomi" + + self.geospatial_crop = None # Add quality value filter number @@ -39,118 +98,352 @@ class TROPOMI(MAPIES): if isinstance(self.end_date, str): self.end_date = time_converter(self.end_date) + self.dates_slice = pd.date_range( + self.start_date, + self.end_date, + ).strftime('%Y%m%d%H%M') - self.read_config(**kwargs) - - - @timeit - def read_config(self, **kwargs): - """ - Read yaml config file - """ - module_dir = os.path.dirname(__file__) - if config_file: - try: - config = yaml.safe_load(open(config_file)) - except FileNotFoundError: - logging.error("This is not a config file") - else: - config = yaml.safe_load(open(os.path.join(module_dir, "config/satellite_config.yaml"))) - - variable_dict = config[self.datatype]["variables"] - self.time_var = variable_dict["time_variable"] - self.lon_var = variable_dict["lon_variable"] - self.lat_var = variable_dict["lat_variable"] - # If obs_var passed then run data analysis on that variable if not pull default - obs_var = kwargs.get("obs_var") - if obs_var: - self.obs_var = obs_var - else: - self.obs_var = variable_dict["obs_variable"] + self.read_config(**kwargs) @timeit - def preprocess_vars(self): + def read_nc(self): """ - Preprocessing of the dataset + Returns list of files needed to be processed by the application """ - super().preprocess_vars() - - - # no2 column values - obs_dims = self.ds[self.obs_var].dims - obs_shape = self.ds[self.obs_var].shape - obs_attrs = self.ds[self.obs_var].attrs - - # Duplicate values in the time array to get the same shape as the flattened variable array - self.time_values = inflate_array(self.time_values, self.time_shape, no2_shape) - self.time_values_index = inflate_array(self.time_values_index, self.time_shape, obs_shape) - # Lon and lat values - lon_dims = self.ds[self.lon_var].dims - lon_shape = self.ds[self.lon_var].shape - lon_attrs = 
self.ds[self.lon_var].attrs - self.lon_values = self.ds[self.lon_var].values.flatten() - - # Reindex lon column values - self.lon_values = self.reindex( - self.time_values_index, - self.lon_values, - ) - + file_patterns = [] + print(f' Date slice = {self.dates_slice}') + for date in self.dates_slice: + # Convert to Julian day + date = datetime.strptime(date, '%Y%m%d%H%M').strftime('%Y%m%d') + + #TODO add something that takes in the + filepaths = f'{self.indir}/**/S5P_OFFL_L2__NO2____{date}*' + file_patterns.append(filepaths) # Add pattern to the set + + files = sorted(get_file_list(file_patterns)) + print(f"Total number of files: {len(files)}") + - lat_dims = self.ds[self.lat_var].dims - lat_shape = self.ds[self.lat_var].shape - lat_attrs = self.ds[self.lat_var].attrs - self.lat_values = self.ds[self.lat_var].values.flatten() + # I'm not sure we need this part + files_filtered = [] + for file in files: + parts = file.split('_') + date = parts[-6] - # Reindex aod column values - self.lat_values = self.reindex( - self.time_values_index, - self.lat_values, - ) + file_dt = datetime.strptime(date, "%Y%m%dT%H%M%S") + if self.start_date <= file_dt <= self.end_date: + files_filtered.append(file) - # TODO: Check if this works well, if the inflated values are along the right dimensions - self.obs = self.ds[self.obs_var].values.flatten() - - # Reindex no2 column values - self.obs = self.reindex( - self.time_values_index, - self.obs, - ) + + first_idx = max(0, files.index(files_filtered[0]) - 2) + last_idx = min(len(files) - 1, files.index(files_filtered[-1]) + 2) + self.files = files[first_idx : last_idx + 1] + + print(f"Total number of filtered files: {len(self.files)}") + + for file in self.files: + parts = file.split('_') + julian_day = parts[-6] + + date = datetime.strptime(julian_day, "%Y%m%dT%H%M%S") + if date.month not in self.monthly_dict: + self.monthly_dict[date.month] = { + "files": [], # List to store file paths + "lon": None, # Longitudes array + "lat": None, # Latitudes array + "obs": None, # Observations array + "count": None, # Count of observations used to compute the average + } + if date.day not in self.daily_dict: + self.daily_dict[date.day] = { + "files": [], # List to store file paths + "lon": None, # Longitudes array + "lat": None, # Latitudes array + "obs": None, # Observations array + "time": None, # Time array + } + + # Append the file to the list of files for the corresponding month + self.monthly_dict[date.month]["files"].append(file) + self.daily_dict[date.day]["files"].append(file) + + print(f"Discovered files for months: {list(self.monthly_dict.keys())}") + for month, data in self.monthly_dict.items(): + print(f" Month {month:02d}: {len(data['files'])} files") + + print(f"Discovered files for days: {list(self.daily_dict.keys())}") + + # Print the daily disctionary to see if files are stored in chronological order + for day, data in self.daily_dict.items(): + print(f"Day {day:02d}: {len(data['files'])} files") @timeit - def read_nc(self): - """ - Read netcdf files with xarray + def process_data(self, monthly_avg=False, batch_size=100, geospatial_crop: NDArray | None = None): """ + Process the data for the specified year and months. - files = get_file_list("/esarchive/obs/sentinel/tropomi/original_files/tropomi_OFFL_NO2/S5P_OFFL_L2__NO2____20230101T220249_20230101*.nc") - print(files) + Parameters: + - monthly_avg: bool, whether to compute single monthly averages. + - batch_size: int, the number of files to process in each batch. 
+ - chunk_size: int, the size of chunks for processing within each batch. + - geospatial_crop, 2D numpy array + Returns: + - if monthly_avg --> lon_final, lat_final, obs_final: Aggregated longitude, latitude, and observation arrays. + - if not monthly_avg --> preprocess_vars() is called. + """ + if geospatial_crop is not None: + self.geospatial_crop = np.array(geospatial_crop) + if monthly_avg: + for month in self.monthly_dict.keys(): + # print(f'Month: {self.monthly_dict[month]}') + # Catch if they haven't selected a grid + try: + print(f"Running average on {self.grid}") + except AttributeError: + raise Exception("Please specify a grid in the initialisation of MAPIES, using grid_repr=''") + try: + # Process the data for the current month + final_lon, final_lat, final_obs, cumulative_count = self.process_monthly_data( + batch_size=batch_size, month=month + ) + # final_lon, final_lat, final_obs, cumulative_count = self.process_files_in_blocks(month, block_size=300, batch_size=batch_size) + + # Update the monthly dictionary + self.monthly_dict[month] = { + "lon": final_lon, + "lat": final_lat, + "obs": final_obs, + "count": cumulative_count, + } + print(f'Updated monthly dictionary for month {month}') + print(f'Lon shape: {final_lon.shape}, Lat shape: {final_lat.shape}, Obs shape: {final_obs.shape}') + if np.all(final_lon == 0) or np.all(np.isnan(final_lon)): + print(f"All longitude values for month {month} are 0 or NaN.") + if np.all(final_lat == 0) or np.all(np.isnan(final_lat)): + print(f"All latitude values for month {month} are 0 or NaN.") + if np.all(final_obs == 0) or np.all(np.isnan(final_obs)): + print(f"All observation values for month {month} are 0 or NaN.") + + except Exception as e: + print(f"Error processing data for month {month}: {e}") + continue + else: + try: + final_lon, final_lat, final_obs, cumulative_count = self.process_monthly_data( + batch_size=batch_size, month=None + ) + self.lon_values = final_lon + self.lat_values = final_lat + self.obs = final_obs + self.count_obs = cumulative_count - # Open dataset with xarray and dask - # Tropomi requires you choose the group - # If we want to read anything other than a single tropomi dataset we will need to do it in parallel - self.ds = xr.open_mfdataset(files, group="PRODUCT") + except Exception as e: + print(f"Error processing data for the time period from {self.start_date} to {self.end_date}: {e}") + return + def process_monthly_data(self, batch_size, month=None): - @timeit - def to_plumes(self): """ - Restructing the data to be passed to calculate the plumes + Function that returns all the needed variables for the DA """ + # r = RotatedGrid(centre_lon=20, centre_lat=35, dlon=.1, dlat=.1, west=-51, south=-35) + flag = False + if month is None: + month_files = self.files + print(f'Processing all files from {self.start_date} to {self.end_date}') + else: + month_files = self.monthly_dict[month]["files"] + print(f"Processing {len(month_files)} files for month {month}") + + + # Split files into batches + batches = [ + month_files[i:i + batch_size] for i in range(0, len(month_files), batch_size) + ] + + # Initialize cumulative arrays + cumulative_obs = None + cumulative_lat = None + cumulative_lon = None + cumulative_count = None + failed_files = [] + + for batch in batches: + print(f"Processing a batch of {len(batch)} files...") + print(f"Batch: {batch}") + + # Use multiprocessing to process files within the batch + args = [(self.grid, file, self.obs_var, self.lat_var, self.lon_var, self.time_var, self.start_date, 
self.end_date, self.apply_qa, self.qualityFlags, flag, self.geospatial_crop) for file in batch] + + with multiprocessing.Pool(processes=8) as pool: # Adjust number of processes as needed + results = pool.starmap(process_single_file, args) + + # Filter valid results and log failed files + valid_results = [res for res in results if res is not None] + failed_files.extend([file for file, res in zip(batch, results) if res is None]) + + if not valid_results: + print("All files in this batch failed. Skipping...") + continue + + # Aggregate results incrementally + for obs_agg, lat_agg, lon_agg, count_obs in valid_results: + if cumulative_obs is None: + # Initialize cumulative arrays + cumulative_obs = obs_agg * count_obs + cumulative_lat = lat_agg * count_obs + cumulative_lon = lon_agg * count_obs + cumulative_count = count_obs + else: + cumulative_obs += obs_agg * count_obs + cumulative_lat += lat_agg * count_obs + cumulative_lon += lon_agg * count_obs + cumulative_count += count_obs + + if failed_files: + print(f"The following files failed to process: {failed_files}") + + # Check if any valid data was processed + if cumulative_obs is None or cumulative_count is None: + print(f"No valid data processed for month {month}. Returning empty arrays.") + return None, None, None, None + + # Compute final averages + valid_mask = cumulative_count > 0 + final_obs = np.zeros_like(cumulative_obs) + final_lat = np.zeros_like(cumulative_lat) + final_lon = np.zeros_like(cumulative_lon) + + final_obs[valid_mask] = cumulative_obs[valid_mask] / cumulative_count[valid_mask] + final_lat[valid_mask] = cumulative_lat[valid_mask] / cumulative_count[valid_mask] + final_lon[valid_mask] = cumulative_lon[valid_mask] / cumulative_count[valid_mask] + + print(f"Completed processing for month {month}.") + + return final_lon, final_lat, final_obs, cumulative_count - coords = coords = dict(time=time_values,) - ds = self.to_xarray(coords=coords, data_vars=data_vars) + def plot_2D_observations(self, months=[0], filename=None, outdir=None): + if months == [0]: + try: + lon_values = self.lon_values + lat_values = self.lat_values + obs_values = self.obs + except Exception as e: + print(f"Error plotting all data: {e}") + return + title = f"Observation 2D plot from {self.start_date} to {self.end_date}" + filename = filename or f"{self.datatype}_2D_obs_from_{self.start_date}_to_{self.end_date}.png" + super().plot_2D_obs_custom( + lon_values=lon_values, + lat_values=lat_values, + obs_values=obs_values, + outdir=outdir, + title=title, + filename=filename, + ) + return + elif months is not None: + for m in months: + try: + month_data = self.monthly_dict[m] + lon_values = month_data["lon"] + lat_values = month_data["lat"] + obs_values = month_data["obs"] + except Exception as e: + print(f"Error for month {m}: {e}") + continue + title = f"Observation 2D plot of month {m} from {self.start_date} to {self.end_date}" + filename = filename or f"{self.datatype}_2D_obs_month_{m}_year_{self.year}.png" + super().plot_2D_obs_custom( + lon_values=lon_values, + lat_values=lat_values, + obs_values=obs_values, + outdir=outdir, + title=title, + filename=filename, + ) + else: + try: + lon_values = self.yearly_lon + lat_values = self.yearly_lat + obs_values = self.yearly_obs + except Exception as e: + print(f"Error for yearly data: {e}") + return + title = f"Observation 2D plot of yearly average of year {self.year}" + filename = filename or f"{self.datatype}_2D_yearly_obs_{self.year}.png" + super().plot_2D_obs_custom( + lon_values=lon_values, + 
lat_values=lat_values, + obs_values=obs_values, + outdir=outdir, + title=title, + filename=filename, + ) + + def plot_2D_num_obs(self, months=[0], filename=None, outdir=None): -if __name__ == "__main__": - start_date = "202301011430" - end_date = "202401011531" - c = TROPOMI(start_date, end_date) - c.read_nc() - c.preprocess_vars() \ No newline at end of file + if months == [0]: + try: + lon_values = self.lon_values + lat_values = self.lat_values + cum_count = self.count_obs + except Exception as e: + print(f"Error plotting all data: {e}") + return + title = f"Plot of the number of valid observations from {self.start_date} to {self.end_date}" + filename = filename or f"{self.datatype}_2D_obs_count_from_{self.start_date}_to_{self.end_date}.png" + super().plot_2D_num_obs_custom( + lon_values=lon_values, + lat_values=lat_values, + num_obs=cum_count, + outdir=outdir, + title=title, + filename=filename, + ) + return + elif months is not None: + for m in months: + try: + month_data = self.monthly_dict[m] + lon_values = month_data["lon"] + lat_values = month_data["lat"] + cum_count = month_data["count"] + except Exception as e: + print(f"Error for month {m}: {e}") + continue + title = f"Plot of the number of valid observations for month {m} from {self.start_date} to {self.end_date}" + filename = filename or f"{self.datatype}_2D_obs_count_month_{m}_year_{self.year}.png" + super().plot_2D_num_obs_custom( + lon_values=lon_values, + lat_values=lat_values, + num_obs=cum_count, + outdir=outdir, + title=title, + filename=filename, + ) + else: + try: + lon_values = self.yearly_lon + lat_values = self.yearly_lat + cum_count = self.yearly_count + except Exception as e: + print(f"Error for yearly data: {e}") + return + title = f"Plot of the number of valid observations for yearly average of {self.year}" + filename = filename or f"{outdir}/{self.datatype}_2D_obs_count_yearly_{self.year}.png" + super().plot_2D_num_obs_custom( + lon_values=lon_values, + lat_values=lat_values, + num_obs=cum_count, + outdir=outdir, + title=title, + filename=filename, + ) diff --git a/mapies/util/func_tools.py b/mapies/util/func_tools.py index b0b634c336fd47d69e9fe6fb6efffb66871eca33..300f2ec6ce7350f3a8f9db0eb2e6a17a93bbccbf 100644 --- a/mapies/util/func_tools.py +++ b/mapies/util/func_tools.py @@ -82,33 +82,64 @@ def time_converter(date:str) -> pd.Timestamp: return date def quality_assurance(ds, obs, lat, lon, time_values, qualityFlags): - """ - Applying quality assurance flags to mapies data - """ - mask = np.zeros_like(obs, dtype=bool) - - for i, (flag, info) in enumerate(qualityFlags.items()): - - flag_mask = np.zeros_like(obs, dtype=bool) - arr = ds[info["qa_name"]].values.flatten() - for value in info["qa_values"]: - flag_mask |= (arr == value) - - if i == 0: - mask = flag_mask + """ + Applying quality assurance flags to mapies data + """ + mask = np.zeros_like(obs, dtype=bool) + + for i, (flag, info) in enumerate(qualityFlags.items()): + + flag_mask = np.zeros_like(obs, dtype=bool) + arr = ds[info["qa_name"]].values.flatten() + for value in info["qa_values"]: + flag_mask |= (arr == value) + + if i == 0: + mask = flag_mask + else: + if info["qa_combined"]: + mask &= flag_mask else: - if info["qa_combined"]: - mask &= flag_mask - else: - mask |= flag_mask + mask |= flag_mask + + obs = obs[mask] + lat = lat[mask] + lon = lon[mask] + time_values = time_values[mask] - obs = obs[mask] - lat = lat[mask] - lon = lon[mask] - time_values = time_values[mask] + return obs, lat, lon, time_values - return obs, lat, lon, time_values +def 
spatial_domain_selection(obs, lat, lon, time, spatial_bounds):
+    """
+    A function for cutting the data spatially
+
+    :param spatial_bounds: - the boundaries that you would like to crop the data by [[lon], [lat]] e.g. [[0, 20], [10, 12]]
+    """
+    # Check that the spatial bounds are 2 dimensional
+    if len(spatial_bounds) != 2:
+        raise Exception("Spatial bounds passed are not 2 dimensional, make sure to pass [[left_lon, right_lon], [left_lat, right_lat]]")
+
+    # Get the bounds
+    lon_l_bound, lon_r_bound = spatial_bounds[0]
+    lat_l_bound, lat_r_bound = spatial_bounds[1]
+
+    # Crop the spatial variables
+    lon_mask = np.logical_and(lon>=lon_l_bound, lon<=lon_r_bound)
+    lat_mask = np.logical_and(lat>=lat_l_bound, lat<=lat_r_bound)
+
+    # Masking the data
+    lon = lon[lon_mask & lat_mask]
+    lat = lat[lon_mask & lat_mask]
+    time = time[lon_mask & lat_mask]
+    obs = obs[lon_mask & lat_mask]
+
+    assert len(lon) == len(lat)
+    assert len(obs) == len(time)
+    assert len(lon) == len(obs)
+
+    return obs, lat, lon, time
+
 
 # Time Domain Selection function based off of xarray Will this work with dask too, a question for later
 def time_domain_selection(
@@ -125,7 +156,6 @@ def time_domain_selection(
 
     :return: a cut down numpy array
     """
-    # But the format of the time column may change aaaaaaaghh
 
     # For these if it is a string we need to know what the format is (maybe pandas figures that out)
     if isinstance(start_date, str):
@@ -220,12 +250,43 @@ def geo_to_rot(lons, lats, centre_lon: float, centre_lat: float):
 
     return rlon, rlat
 
 def new_preprocess_vars(ds, obs_var, lat_var, lon_var, time_var, start_date, end_date):
     """
-    Preprocessing of the dataset
+    * Preprocessing of the dataset
+    * Flattening arrays
+    * Applying time domain selection
     """
-    time_values, time_values_index = preprocess_vars_mapies(ds, time_var, start_date, end_date)
+    # Get all info about time columns
+    time_dims = ds[time_var].dims
+    time_shape = ds[time_var].shape
+    time_attrs = ds[time_var].attrs  # Useful if we need to convert to datetime
+
+
+    # Get time values flatten and convert to datetime values
+    time_values = ds[time_var].values
+    time_values = time_values.flatten()
+
+    if time_values.dtype == "timedelta64[ns]":
+        logging.info("Adding time origin to time values as the time variable is in timedelta")
+        time_values = np.add(TIME_ORIGIN, time_values)
+        time_values = pd.to_datetime(time_values)
+    else:
+        time_values = pd.to_datetime(time_values)
+
+
+    # Time domain selection
+    # print(f'Time values before selection: {time_values}')
+    time_values, time_values_index = time_domain_selection(time_values, start_date, end_date)
+    # print(f'Time values after selection: {time_values}')
+
+    obs_dims = ds[obs_var].dims
+    obs_shape = ds[obs_var].shape
+    obs_attrs = ds[obs_var].attrs
+
+    # Duplicate values in the time array to get the same shape as the flattened variable array
+    if time_shape != obs_shape:
+        time_values = inflate_array(time_values, time_shape, obs_shape)
+        time_values_index = inflate_array(time_values_index, time_shape, obs_shape)
+        time_values, time_values_index = time_domain_selection(time_values, start_date, end_date)
+
 
     # Lon and lat values
     lon_dims = ds[lon_var].dims
@@ -252,10 +313,7 @@ def new_preprocess_vars(ds, obs_var, lat_var, lon_var, time_var, start_date, end
 
 
     # AOD column values, default "Aerosol_Optical_Thickness_550_Land_Ocean_Best_Estimate"
-    aod_dims = ds[obs_var].dims
-
-    aod_shape = ds[obs_var].shape
-    aod_attrs = ds[obs_var].attrs
+
     obs = ds[obs_var].values.flatten()
 
     # Reindex aod column values
@@ -266,34 +324,6 @@ def 
new_preprocess_vars(ds, obs_var, lat_var, lon_var, time_var, start_date, end return obs, lat_values, lon_values, time_values, time_values_index -def preprocess_vars_mapies(ds, time_var, start_date, end_date): - """ - Preprocessing of the dataset - """ - # Get all info about time columns - time_dims = ds[time_var].dims - time_shape = ds[time_var].shape - time_attrs = ds[time_var].attrs # Useful if we need to convert to datetime - - - # Get time values flatten and convert to datetime values - time_values = ds[time_var].values - time_values = time_values.flatten() - - if time_values.dtype == "timedelta64[ns]": - logging.info("Adding time origin to time values as the time variable is in timedelta") - time_values = np.add(TIME_ORIGIN, time_values) - time_values = pd.to_datetime(time_values) - else: - time_values = pd.to_datetime(time_values) - - - # Time domain selection - # print(f'Time values before selection: {time_values}') - time_values, time_values_index = time_domain_selection(time_values, start_date, end_date) - # print(f'Time values after selection: {time_values}') - - return time_values, time_values_index def reindex( dependent_var_index, diff --git a/mapies/viirs.py b/mapies/viirs.py index 50b3c180780480fc3fb949d95249daaf34fcda9f..e8d3aa83d5c09e7714a7470591d4fd18cf9e949a 100644 --- a/mapies/viirs.py +++ b/mapies/viirs.py @@ -26,6 +26,7 @@ from mapies.util.func_tools import * from mapies.grids.monarch import RotatedGrid, IrregularRotatedGrid, RegularGrid, regridding_function from pathlib import Path from glob import glob as glob_module +from numpy.typing import DTypeLike, NDArray os.environ["OMP_NUM_THREADS"] = "8" @@ -37,12 +38,25 @@ TIME_ORIGIN = np.datetime64("1900-01-01T00:00:00") -def process_single_file(r, file, obs_var, lat_var, lon_var, time_var, start_date, end_date, apply_qa, qualityFlags, flag=False): +def process_single_file( + r, + file, + obs_var, + lat_var, + lon_var, + time_var, + start_date, + end_date, + apply_qa, + qualityFlags, + flag=False, + geospatial_crop=None + ): """ Process a single file and return aggregated results or flattened arrays. 
""" try: - ds = xr.open_dataset(file, engine="h5netcdf") + ds = xr.open_dataset(file, engine="h5netcdf") obs, lat, lon, time_values, time_values_index = new_preprocess_vars(ds, obs_var, lat_var, lon_var, time_var, start_date, end_date) # obs = ds[obs_var].values.flatten() @@ -50,7 +64,8 @@ def process_single_file(r, file, obs_var, lat_var, lon_var, time_var, start_date if apply_qa: obs, lat, lon, time_values = quality_assurance(ds, obs, lat, lon, time_values, qualityFlags) - + if geospatial_crop is not None: + obs, lat, lon, time_values = spatial_domain_selection(obs, lat, lon, time_values, geospatial_crop) if not flag: # Regrid and aggregate if isinstance(r, RotatedGrid): @@ -95,13 +110,12 @@ class VIIRS(MAPIES): self.time_orig = np.datetime64("1993-01-01T00:00:00") self.datatype="viirs" - frequency = kwargs.get("frequency") self.dest = kwargs.get("dest") if not os.path.exists(self.dest): raise Exception("Output directory doesn't exist") self.indir = kwargs.get("indir") - self.year = int(start_date[:4]) + self.year = int(start_date[:4]) self.start_day = int(start_date[6:8]) self.end_day = int(end_date[6:8]) self.monthly_dict = {} @@ -139,80 +153,6 @@ class VIIRS(MAPIES): self.unc_const = da_dict["uncertainty_constants"] #float, int, list, dict - @timeit - def preprocess_vars(self): - """ - Preprocessing of the dataset - """ - print(f'Inside preprocess_vars') - super().preprocess_vars() - - - # Lon and lat values - lon_dims = self.ds[self.lon_var].dims - lon_shape = self.ds[self.lon_var].shape - lon_attrs = self.ds[self.lon_var].attrs - self.lon_values = self.ds[self.lon_var].values.flatten() - - # Reindex lon column values - self.lon_values = self.reindex( - self.time_values_index, - self.lon_values, - ) - - lat_dims = self.ds[self.lat_var].dims - lat_shape = self.ds[self.lat_var].shape - lat_attrs = self.ds[self.lat_var].attrs - self.lat_values = self.ds[self.lat_var].values.flatten() - - # Reindex aod column values - self.lat_values = self.reindex( - self.time_values_index, - self.lat_values, - ) - - - # AOD column values, default "Aerosol_Optical_Thickness_550_Land_Ocean_Best_Estimate" - aod_dims = self.ds[self.obs_var].dims - - aod_shape = self.ds[self.obs_var].shape - aod_attrs = self.ds[self.obs_var].attrs - self.obs = self.ds[self.obs_var].values.flatten() - - # Reindex aod column values - self.obs = self.reindex( - self.time_values_index, - self.obs, - ) - - # Apply QA - if self.apply_qa: - #.info("Applying quality assurance filters") - self.quality_assurance() - - - - # @timeit - # def regridding_function(self): - # """ - # Perform Rotation of Grid representation - # """ - # # Calculate the grid representation - # if isinstance(self.grid, RotatedGrid): - # # Aggregate the the observations to the grid representation - # lon_agg, lat_agg, rlon_agg, rlat_agg, obs_agg, count_obs = self.grid.aggregate(self.lon_values, self.lat_values, self.obs) - - # elif isinstance(self.grid, RegularGrid): - # # Aggregate the the observations to the grid representation - # lon_agg, lat_agg, obs_agg, count_obs = self.grid.aggregate(self.lon_values, self.lat_values, self.obs) - - # else: - # raise ValueError("Invalid grid representation") - - # self.lon_values = lon_agg - # self.lat_values = lat_agg - # self.obs = obs_agg - # self.count_obs = count_obs @timeit def read_nc(self): @@ -289,11 +229,11 @@ class VIIRS(MAPIES): @timeit - def process_data(self, monthly_avg=False, batch_size=100): + def process_data(self, monthly_avg=False, batch_size=100, geospatial_crop: NDArray | None = None): """ Process 
the data.
 
         Parameters:
         - monthly_avg: bool, whether to compute single monthly averages.
         - batch_size: int, the number of files to process in each batch.
+        - geospatial_crop: 2D numpy array [[lon_min, lon_max], [lat_min, lat_max]] used to spatially crop the observations (optional).
@@ -301,6 +241,8 @@ class VIIRS(MAPIES):
         - if monthly_avg --> updated monthly dictorionary for the processed month.
         - if not monthly_avg --> updated self.lon_values, self.lat_values, self.obs, self.count_obs.
         """
+        if geospatial_crop is not None:
+            self.geospatial_crop = np.array(geospatial_crop)
         if monthly_avg:
             for month in self.monthly_dict.keys():
                 # print(f'Month: {self.monthly_dict[month]}')
@@ -378,17 +320,18 @@ class VIIRS(MAPIES):
 
             args = [
                 (
-                    self.grid,
-                    file,
-                    self.obs_var,
-                    self.lat_var,
-                    self.lon_var,
-                    self.time_var,
-                    self.start_date,
-                    self.end_date,
-                    self.apply_qa,
-                    self.qualityFlags,
-                    flag
+                    self.grid,
+                    file,
+                    self.obs_var,
+                    self.lat_var,
+                    self.lon_var,
+                    self.time_var,
+                    self.start_date,
+                    self.end_date,
+                    self.apply_qa,
+                    self.qualityFlags,
+                    flag,
+                    self.geospatial_crop
                 )
                 for file in batch
             ]
diff --git a/run/run_tropomi.py b/run/run_tropomi.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ee763a5f97e52762a3b923d91c54714f9c263bf
--- /dev/null
+++ b/run/run_tropomi.py
@@ -0,0 +1,26 @@
+from mapies.tropomi import TROPOMI
+import numpy as np
+import time
+# from dask.distributed import Client
+
+
+if __name__ == "__main__":
+    start_date = "202212010000"
+    end_date = "202212202359"
+
+    outdir = "/esarchive/scratch/cmeikle/"
+    indir = "/esarchive/obs/sentinel/tropomi/original_files/tropomi_OFFL_NO2"
+
+    start_time = time.time()
+
+    c = TROPOMI(start_date, end_date, dest=outdir, indir=indir, apply_qa=True, grid_repr="rotated")
+    c.read_nc()
+    c.process_data(monthly_avg=True, batch_size = 13)
+
+    # c.yearly_average()
+    c.plot_2D_observations(months=[12], outdir=outdir)
+    c.plot_2D_num_obs(months=[12], outdir=outdir)
+    #c.process_lazy_data()
+    end_time = time.time()
+    elapsed_time = end_time - start_time
+    print(f"Script executed in {elapsed_time:.2f} seconds.")
diff --git a/run/run_viirs.py b/run/run_viirs.py
index d158cf87bb5802aa3c14f47ebe9c36c184cf2921..a3adc1536d6a25a55f1017b6d01b16292d2dce94 100644
--- a/run/run_viirs.py
+++ b/run/run_viirs.py
@@ -22,26 +22,24 @@ if __name__ == "__main__":
 
     # start_date = "202401010000"
     # end_date = "202401312359"
 
-    # outdir="/home/cmeikle/Projects/data/"
-    # indir="/home/cmeikle/Projects/data/VIIRS/original_files/AERDB_L2_VIIRS_NOAA20"
-    outdir = "/home/cgile/Documents/mapies/figures"
+    outdir="/home/cmeikle/Projects/data/"
+    indir="/home/cmeikle/Projects/data/VIIRS"
+
+    #outdir = "/home/cgile/Documents/mapies/figures"
     # indir = "/home/cgile/bscearth000/esarchive/obs/nasa/viirs_noaa20_aerdb_l2/original_files/VIIRS"
     # indir = "/home/cgile/Documents/mapies/VIIRS"
-    indir = ''
+
 
     filename_obs = 'carlotta'
     filename_num_obs = 'carlotta'
-
     start_time = time.time()
 
     c = VIIRS(start_date, end_date, dest=outdir, indir=indir, apply_qa=False, grid_repr="rotated")
     c.read_nc()
-    c.process_data(monthly_avg=False, batch_size = 50)
+    c.process_data(monthly_avg=False, batch_size = 50, geospatial_crop=[[-30, 50], [0, 50]])
 
     # c.yearly_average()
-    c.plot_2D_observations(months=[0], filename=filename_obs, outdir=outdir)
-    c.plot_2D_num_obs(months=[0], filename=filename_num_obs, outdir=outdir)
-    # c.process_lazy_data()
+    c.plot_2D_observations(months=[0], filename=filename_obs, outdir=outdir, display_fig=True)
     end_time = time.time()
     elapsed_time = end_time - start_time
     print(f"Script executed in {elapsed_time:.2f} seconds.")
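
Note on the new geospatial_crop option threaded through process_data and process_single_file above: the bounds follow the [[lon_min, lon_max], [lat_min, lat_max]] convention used by spatial_domain_selection. Below is a minimal standalone sketch of that masking logic on toy numpy arrays; it is an illustration only, not part of the patch, and the variable names are hypothetical.

# Illustration of the [[lon_min, lon_max], [lat_min, lat_max]] crop convention.
import numpy as np

lon = np.array([-40.0, -10.0, 5.0, 30.0, 60.0])
lat = np.array([10.0, 20.0, 45.0, 55.0, 5.0])
obs = np.array([1.0, 2.0, 3.0, 4.0, 5.0])

spatial_bounds = [[-30, 50], [0, 50]]  # same bounds as in run_viirs.py
(lon_min, lon_max), (lat_min, lat_max) = spatial_bounds

# Keep only points inside the lon/lat box; observations outside are dropped.
mask = (lon >= lon_min) & (lon <= lon_max) & (lat >= lat_min) & (lat <= lat_max)
print(obs[mask], lon[mask], lat[mask])  # -> [2. 3.] [-10. 5.] [20. 45.]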
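
For reference, TROPOMI.process_monthly_data merges per-file gridded averages by weighting each file's cell mean with its observation count and dividing by the accumulated count at the end. The sketch below reproduces that bookkeeping on a toy flattened "grid" of three cells; it is a simplified stand-in for the batched, multiprocessing version in the patch, with invented example values.

# Count-weighted combination of per-file (obs_agg, count_obs) pairs.
import numpy as np

per_file_results = [
    (np.array([0.2, 0.0, 0.5]), np.array([2, 0, 1])),  # file 1: cell means and counts
    (np.array([0.4, 0.3, 0.0]), np.array([1, 3, 0])),  # file 2
]

cumulative_obs = None
cumulative_count = None
for obs_agg, count_obs in per_file_results:
    if cumulative_obs is None:
        cumulative_obs = obs_agg * count_obs       # weight means by counts
        cumulative_count = count_obs.copy()
    else:
        cumulative_obs += obs_agg * count_obs
        cumulative_count += count_obs

valid = cumulative_count > 0
final_obs = np.zeros_like(cumulative_obs)
final_obs[valid] = cumulative_obs[valid] / cumulative_count[valid]
print(final_obs)  # [0.2667, 0.3, 0.5]: cells with no observations stay at 0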