From baa4510933f525070f8bd7c61485e9e16c77c2e2 Mon Sep 17 00:00:00 2001 From: Calum Meikle Date: Fri, 28 Feb 2025 16:15:21 +0100 Subject: [PATCH 1/2] Name changes and updating runs and tests --- mapies/mapies.py | 8 ++++---- mapies/tests/test_viirs.py | 2 +- mapies/tropomi.py | 2 +- mapies/viirs.py | 2 +- run/run_tropomi.py | 4 ++-- run/run_viirs.py | 4 ++-- 6 files changed, 11 insertions(+), 11 deletions(-) diff --git a/mapies/mapies.py b/mapies/mapies.py index 0bee5218..551df9fc 100644 --- a/mapies/mapies.py +++ b/mapies/mapies.py @@ -103,7 +103,7 @@ class MAPIES: @timeit - def process_data(self, monthly_avg=False, batch_size=100, save=False, apply_qa=True, geospatial_crop: NDArray | None = None): + def process_avg_data(self, monthly_avg=False, batch_size=100, save=False, apply_qa=True, geospatial_crop: NDArray | None = None): """ Process the data. @@ -134,7 +134,7 @@ class MAPIES: raise Exception("Please specify a grid in the initialisation of MAPIES, using grid_repr=''") try: # Process the data for the current month - final_lon, final_lat, final_obs, cumulative_count = self.process_monthly_data( + final_lon, final_lat, final_obs, cumulative_count = self.process_data_mp( batch_size=batch_size, month=month ) @@ -158,7 +158,7 @@ class MAPIES: continue else: try: - final_lon, final_lat, final_obs, cumulative_count = self.process_monthly_data( + final_lon, final_lat, final_obs, cumulative_count = self.process_data_mp( batch_size=batch_size, month=None ) @@ -180,7 +180,7 @@ class MAPIES: return - def process_monthly_data(self, batch_size, month=None): + def process_data_mp(self, batch_size, month=None): """ Process the data with multiprocessing. diff --git a/mapies/tests/test_viirs.py b/mapies/tests/test_viirs.py index 0383225c..d204e5eb 100644 --- a/mapies/tests/test_viirs.py +++ b/mapies/tests/test_viirs.py @@ -16,7 +16,7 @@ def test_to_da(): indir="test_data" c = VIIRS(start_date, end_date, dest=outdir, indir=indir, grid_repr="rotated") print(c.dest) - c.read_nc() + c.gather_nc_files() c.process_lazy_data(apply_qa=False, save=False) outfiles = c.to_da(frequency="3H", save_nc=False, save_figs=False) assert outfiles == [Path("obs202401010000.nc")] \ No newline at end of file diff --git a/mapies/tropomi.py b/mapies/tropomi.py index b1c5a39c..a9feb6d4 100644 --- a/mapies/tropomi.py +++ b/mapies/tropomi.py @@ -72,7 +72,7 @@ class TROPOMI(MAPIES): @timeit - def read_nc(self): + def gather_nc_files(self): """ Returns list of files needed to be processed by the application """ diff --git a/mapies/viirs.py b/mapies/viirs.py index 3670d494..1ab4f01b 100644 --- a/mapies/viirs.py +++ b/mapies/viirs.py @@ -96,7 +96,7 @@ class VIIRS(MAPIES): @timeit - def read_nc(self): + def gather_nc_files(self): """ Returns list of files needed to be processed by the application """ diff --git a/run/run_tropomi.py b/run/run_tropomi.py index 5e4d7a22..2b438d8b 100644 --- a/run/run_tropomi.py +++ b/run/run_tropomi.py @@ -14,8 +14,8 @@ if __name__ == "__main__": start_time = time.time() c = TROPOMI(start_date, end_date, dest=outdir, indir=indir, grid_repr="regular") - c.read_nc() - c.process_data(monthly_avg=False, batch_size = 14, apply_qa=True, geospatial_crop=[[-40, 60], [0, 70]]) + c.gather_nc_files() + c.process_avg_data(monthly_avg=False, batch_size = 14, apply_qa=True, geospatial_crop=[[-40, 60], [0, 70]]) # c.yearly_average() c.plot_2D_observations(months=[0], outdir=outdir) diff --git a/run/run_viirs.py b/run/run_viirs.py index 60959d36..9cdb89eb 100644 --- a/run/run_viirs.py +++ b/run/run_viirs.py @@ -21,10 
+21,10 @@ if __name__ == "__main__": start_time = time.time() c = VIIRS(start_date, end_date, dest=outdir, indir=indir, grid_repr="rotated") - c.read_nc() + c.gather_nc_files() - c.process_data(monthly_avg=True, batch_size = 100, apply_qa=True, save=True) + c.process_avg_data(monthly_avg=True, batch_size = 100, apply_qa=True, save=True) c.yearly_average() c.plot_2D_observations(months=None, filename="new.png", outdir=outdir) c.plot_2D_num_obs(months=None, outdir=outdir) -- GitLab From 237d0a49b5e96a1406d0f26837273b42e1027f3d Mon Sep 17 00:00:00 2001 From: Calum Meikle Date: Mon, 3 Mar 2025 13:10:25 +0100 Subject: [PATCH 2/2] Cleaning and adding units to no2 --- CHANGELOG | 3 +- mapies/grids/monarch.py | 53 +++++++------- mapies/mapies.py | 101 +++++++++++++------------- mapies/tests/test_viirs.py | 3 - mapies/tropomi.py | 42 ++++------- mapies/util/func_tools.py | 129 ++++++++------------------------- mapies/util/multiprocessing.py | 127 +++++++++++++++----------------- mapies/util/variables_names.py | 2 +- mapies/viirs.py | 65 ++++++----------- run/run_tropomi.py | 4 +- run/run_viirs.py | 4 +- 11 files changed, 210 insertions(+), 323 deletions(-) diff --git a/CHANGELOG b/CHANGELOG index e1ef77d7..41756c63 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -1,4 +1,5 @@ -CHANGELOG +CHANGELOG + ## 0.1.1 - 2025/02/28 - Code refactoring diff --git a/mapies/grids/monarch.py b/mapies/grids/monarch.py index c00bf208..2134dde1 100644 --- a/mapies/grids/monarch.py +++ b/mapies/grids/monarch.py @@ -1,12 +1,9 @@ #!/usr/bin/env python from dataclasses import dataclass -from pandas import DataFrame -import xarray as xr import pandas as pd import numpy as np from typing import Tuple, Dict, Optional from numpy import cos, sin, arctan, pi, nan -from functools import partial, wraps import geopandas as gpd from geopandas import GeoDataFrame from shapely.geometry import Polygon, Point @@ -53,16 +50,16 @@ def geo_to_rot(lons, lats, centre_lon: float, centre_lat: float): x = cos(centre_lat) * sin(lats) - sin(centre_lat) * cos(lats) * cos(distance_lons) y = cos(lats) * sin(distance_lons) z = cos(centre_lat) * cos(lats) * cos(distance_lons) + sin(centre_lat) * sin(lats) - # Arctan2 used + # Arctan2 used # Explanation of the difference https://geo.libretexts.org/Courses/University_of_California_Davis/GEL_056%3A_Introduction_to_Geophysics/Geophysics_is_everywhere_in_geology.../zz%3A_Back_Matter/Arctan_vs_Arctan2 rlon = np.arctan2(y, z) rlat = np.arcsin(x) - #rlon[x < 0] += pi # Convert back to degrees rlon = np.degrees(rlon) rlat = np.degrees(rlat) return rlon, rlat + @dataclass class Grid: dlon: float = None @@ -98,7 +95,6 @@ class Grid: self.center_latitudes = np.linspace(self.cenlat, self.cenlat + self.dlat * (self.nlat) - self.dlat, self.nlat, dtype=float) self.center_longitudes = np.linspace(self.cenlon, self.cenlon + self.dlon * (self.nlon) - self.dlat, self.nlon, dtype=float) - counter = 0 self.coords = np.empty((self.center_latitudes.shape[0] * self.center_longitudes.shape[0], 2)) for i in range(self.center_longitudes.shape[0]): @@ -154,7 +150,7 @@ class Grid: # Re-shape aux_b_lats = self.b_lats.squeeze() aux_b_lons = self.b_lons.squeeze() - + geometry = [] # Create one dataframe with 8 columns, 4 points with two coordinates each one @@ -167,14 +163,15 @@ class Grid: self.gdf = GeoDataFrame(index=range(aux_b_lons.shape[0]), geometry=geometry, crs='epsg:4326') self.gdf["grid_cell"] = ["grid_cell_{}".format(i+1) for i in range(len(self.gdf))] - + # Append the center lat and lon values to this dataframe 
self.gdf["lon_cen"] = self.coords[:, 0] self.gdf["lat_cen"] = self.coords[:, 1] - + return True + @dataclass class RegularGrid(Grid): """ @@ -186,7 +183,7 @@ class RegularGrid(Grid): def __post_init__(self): self.calculate_grid_coords() - + def aggregate(self, lon, lat, obs, obserr=None): """ Aggregate the observations to the grid cells @@ -211,47 +208,47 @@ class RegularGrid(Grid): obs = obs[mask] if obserr is not None: obserr = obserr[mask] - + # Vectorized calculation of grid indices lon_idx = np.searchsorted(self.center_longitudes, lon) - 1 # Index of longitude grid lat_idx = np.searchsorted(self.center_latitudes, lat) - 1 # Index of latitude grid - + # Ensure indices are within bounds lon_idx = np.clip(lon_idx, 0, self.nlon - 1) lat_idx = np.clip(lat_idx, 0, self.nlat - 1) - + # Initialize arrays for aggregation aggregated_obs = np.zeros((self.nlat, self.nlon)) aggregated_lon = np.zeros((self.nlat, self.nlon)) aggregated_lat = np.zeros((self.nlat, self.nlon)) aggregated_obs_err = np.zeros((self.nlat, self.nlon)) if obserr is not None else None count_obs = np.zeros((self.nlat, self.nlon), dtype=int) - + # Efficient accumulation using np.add.at np.add.at(aggregated_obs, (lat_idx, lon_idx), obs) # Accumulate observations np.add.at(aggregated_lon, (lat_idx, lon_idx), lon) # Accumulate lon values np.add.at(aggregated_lat, (lat_idx, lon_idx), lat) # Accumulate lat values np.add.at(count_obs, (lat_idx, lon_idx), 1) # Count observations per grid cell - + if obserr is not None: np.add.at(aggregated_obs_err, (lat_idx, lon_idx), obserr) # Accumulate errors - + # Compute averages by dividing by counts non_zero_counts = count_obs > 0 aggregated_obs[non_zero_counts] /= count_obs[non_zero_counts] aggregated_lon[non_zero_counts] /= count_obs[non_zero_counts] aggregated_lat[non_zero_counts] /= count_obs[non_zero_counts] - + if aggregated_obs_err is not None: aggregated_obs_err[non_zero_counts] /= count_obs[non_zero_counts] - + # TODO: Add a function that takes the center or aggreagated value # Return the results lon_out = aggregated_lon.flatten() #self.center_longitudes lat_out = aggregated_lat.flatten() #self.center_latitudes obs_out = aggregated_obs.flatten() count_obs = count_obs.flatten() - + if aggregated_obs_err is not None: obs_err_out = aggregated_obs_err.flatten() return lon_out, lat_out, obs_out, obs_err_out, count_obs @@ -271,7 +268,7 @@ class RotatedGrid(Grid): def __post_init__(self): self.calculate_grid_coords() - + def aggregate(self, lon, lat, obs, obserr=None): """ Aggregate the observations to the grid cells @@ -306,12 +303,12 @@ class RotatedGrid(Grid): rlon_idx = np.searchsorted(self.center_longitudes, rlon) - 1 # Index of longitude grid rlat_idx = np.searchsorted(self.center_latitudes, rlat) - 1 # Index of latitude grid - + # Ensure indices are within bounds rlon_idx = np.clip(rlon_idx, 0, self.nlon - 1) rlat_idx = np.clip(rlat_idx, 0, self.nlat - 1) - + # Initialize arrays for aggregation aggregated_obs = np.zeros((self.nlat, self.nlon)) aggregated_lon = np.zeros((self.nlat, self.nlon)) @@ -330,12 +327,12 @@ class RotatedGrid(Grid): np.add.at(aggregated_rlon, (rlat_idx, rlon_idx), rlon) # Accumulate lon values np.add.at(aggregated_rlat, (rlat_idx, rlon_idx), rlat) # Accumulate lat values np.add.at(count_obs, (rlat_idx, rlon_idx), 1) # Count observations per grid cell - + if obserr is not None: np.add.at(aggregated_obs_err, (rlat_idx, rlon_idx), obserr) # Accumulate errors - + # Compute averages by dividing by counts non_zero_counts = count_obs > 0 aggregated_obs[non_zero_counts] 
/= count_obs[non_zero_counts] @@ -343,7 +340,7 @@ class RotatedGrid(Grid): aggregated_lat[non_zero_counts] /= count_obs[non_zero_counts] aggregated_rlon[non_zero_counts] /= count_obs[non_zero_counts] aggregated_rlat[non_zero_counts] /= count_obs[non_zero_counts] - + if aggregated_obs_err is not None: aggregated_obs_err[non_zero_counts] /= count_obs[non_zero_counts] @@ -367,12 +364,12 @@ class RotatedGrid(Grid): class IrregularRotatedGrid(Grid): projection: str = "rotated" - + def __post_init__(self): self.calculate_grid_coords() self.building_grid() - - + + def aggregate(self, lon, lat, obs, obserr=None): """ Aggregate diff --git a/mapies/mapies.py b/mapies/mapies.py index 551df9fc..5b840ed2 100644 --- a/mapies/mapies.py +++ b/mapies/mapies.py @@ -1,9 +1,8 @@ #!/usr/bin/env python # MAPIES base -from functools import wraps -from datetime import datetime, timedelta -import sys +from datetime import datetime + from mapies.util.func_tools import * from mapies.grids.monarch import RotatedGrid, IrregularRotatedGrid, RegularGrid import time @@ -12,7 +11,6 @@ import numpy as np import pandas as pd import xarray as xr import matplotlib.pyplot as plt -import cartopy import cartopy.crs as ccrs import os import yaml @@ -137,7 +135,7 @@ class MAPIES: final_lon, final_lat, final_obs, cumulative_count = self.process_data_mp( batch_size=batch_size, month=month ) - + self.monthly_dict[month] = { "lon": final_lon, "lat": final_lat, @@ -152,7 +150,7 @@ class MAPIES: print(f"All latitude values for month {month} are 0 or NaN.") if np.all(final_obs == 0) or np.all(np.isnan(final_obs)): print(f"All observation values for month {month} are 0 or NaN.") - + except Exception as e: print(f"Error processing data for month {month}: {e}") continue @@ -164,24 +162,23 @@ class MAPIES: self.obs, self.lat, self.lon, self.count_obs = filter_observations(final_obs, final_lat, final_lon, cumulative_count) - - if save: - try: + + if save: + try: mid_time_array = generate_mid_time_array(self.start_date, self.end_date, self.obs) - + self.to_netCDF(self.obs, self.lat, self.lon, mid_time_array, self.start_date, self.end_date) print(f'Saved processed data to netCDF') except Exception as e: print(f"Error saving data into netCDF: {e}") return - + except Exception as e: print(f"Error processing data for the time period from {self.start_date} to {self.end_date}: {e}") return def process_data_mp(self, batch_size, month=None): - """ Process the data with multiprocessing. 
""" @@ -196,7 +193,7 @@ class MAPIES: # Split files into batches batches = create_batches(month_files, batch_size) - + # Initialize cumulative arrays cumulative_obs = None cumulative_lat = None @@ -208,7 +205,7 @@ class MAPIES: print(f"Processing a batch of {len(batch)} files...") # Use multiprocessing to process files within the batch - + viirs_mp = self.mp_type( self.grid, self.obs_var, @@ -232,7 +229,7 @@ class MAPIES: # Aggregate results incrementally cumulative_obs, cumulative_lat, cumulative_lon, cumulative_count = combine_arrays(valid_results, cumulative_obs, cumulative_lat, cumulative_lon, cumulative_count) - + if failed_files: print(f"The following files failed to process: {failed_files}") @@ -245,7 +242,7 @@ class MAPIES: final_obs, final_lat, final_lon = compute_final_averages(cumulative_obs, cumulative_lat, cumulative_lon, cumulative_count) print(f"Completed processing for month {month}.") - + return final_lon, final_lat, final_obs, cumulative_count @@ -311,7 +308,7 @@ class MAPIES: files = self.daily_dict[day]["files"] # Use multiprocessing to process files within the batch - + viirs_mp = self.mp_type( self.grid, self.obs_var, @@ -332,19 +329,19 @@ class MAPIES: # Skip the day if no valid data was processed if final_obs is None: - continue + continue # Check if any valid data exists if np.all(final_obs == 0) or np.all(np.isnan(final_obs)): print(f"All observation values for day {day} are 0 or NaN. Skipping...") continue - + # Update day dict - self.daily_dict[day]["obs"] = final_obs - self.daily_dict[day]["lon"] = final_lon + self.daily_dict[day]["obs"] = final_obs + self.daily_dict[day]["lon"] = final_lon self.daily_dict[day]["lat"] = final_lat self.daily_dict[day]["time"] = final_time - + # Get day start and day end but only if they are within object start and end date day_start, day_end = day_boundaries(day, self.start_date, self.end_date) if save: @@ -363,14 +360,14 @@ class MAPIES: grid_representation_str = json.dumps(self.grid_config, ensure_ascii=False) qa_flags_str = json.dumps(self.qualityFlags, ensure_ascii=False) - + # Create xarray Dataset ds = xr.Dataset( coords={"time": ("time", final_time_seconds, { "units": "seconds since 1900-01-01 00:00:00", "calendar": "gregorian" })}, - + data_vars={ "lon": (["time"], lon, VARS_ATTRIBUTES['lon'].copy()), "lat": (["time"], lat, VARS_ATTRIBUTES['lat'].copy()), @@ -385,13 +382,13 @@ class MAPIES: "history": f"Created {datetime.now()}", }, ) - + # Define encoding for compact storage encoding = { - "time": {"dtype": "float32"}, - "lon": {"dtype": "float32"}, - "lat": {"dtype": "float32"}, - VARS_ATTRIBUTES[self.obs_var]["mapies_variable"]: {"dtype": "float32"}, + "time": {"dtype": "float32"}, + "lon": {"dtype": "float32"}, + "lat": {"dtype": "float32"}, + VARS_ATTRIBUTES[self.obs_var]["mapies_variable"]: {"dtype": "float32"}, } # Save to NetCDF file @@ -431,8 +428,8 @@ class MAPIES: return elif months is not None: for m in months: - try: - month_data = self.monthly_dict[m] + try: + month_data = self.monthly_dict[m] lon_values = month_data["lon"] lat_values = month_data["lat"] obs_values = month_data["obs"] @@ -440,7 +437,7 @@ class MAPIES: print(f"Error for month {m}: {e}") continue title = f"Observation 2D plot of month {m} from {self.start_date} to {self.end_date}" - filename = filename or f"{self.datatype}_2D_obs_month_{m}_year_{self.year}.png" + filename = filename or f"{self.datatype}_2D_obs_month_{m}_year_{self.year}.png" super().plot_2D_obs_custom( lon_values=lon_values, lat_values=lat_values, @@ -449,10 +446,10 @@ class 
MAPIES: title=title, filename=filename, ) - else: - try: + else: + try: lon_values = self.yearly_lon - lat_values = self.yearly_lat + lat_values = self.yearly_lat obs_values = self.yearly_obs except Exception as e: print(f"Error for yearly data: {e}") @@ -467,7 +464,8 @@ class MAPIES: title=title, filename=filename, ) - + + def plot_2D_num_obs(self, months=[0], filename=None, outdir=None): if months == [0]: @@ -491,8 +489,8 @@ class MAPIES: return elif months is not None: for m in months: - try: - month_data = self.monthly_dict[m] + try: + month_data = self.monthly_dict[m] lon_values = month_data["lon"] lat_values = month_data["lat"] cum_count = month_data["count"] @@ -500,7 +498,7 @@ class MAPIES: print(f"Error for month {m}: {e}") continue title = f"Plot of the number of valid observations for month {m} from {self.start_date} to {self.end_date}" - filename = filename or f"{self.datatype}_2D_obs_count_month_{m}_year_{self.year}.png" + filename = filename or f"{self.datatype}_2D_obs_count_month_{m}_year_{self.year}.png" self.plot_2D_num_obs_custom( lon_values=lon_values, lat_values=lat_values, @@ -509,10 +507,10 @@ class MAPIES: title=title, filename=filename, ) - else: - try: + else: + try: lon_values = self.yearly_lon - lat_values = self.yearly_lat + lat_values = self.yearly_lat cum_count = self.yearly_count except Exception as e: print(f"Error for yearly data: {e}") @@ -527,7 +525,7 @@ class MAPIES: title=title, filename=filename, ) - + # TODO this can be moved to a plotting function python file @timeit @@ -553,8 +551,8 @@ class MAPIES: level_format = [f"{level:.5f}" for level in levels] # Define the custom color list colors = [ - '#290AD8', '#2645FA', '#3A91FF', '#66CBFF', '#99EEFF', - '#CDFCFF', '#F2FFD9', '#FFEEAA', '#FFCA88', '#FC966A', + '#290AD8', '#2645FA', '#3A91FF', '#66CBFF', '#99EEFF', + '#CDFCFF', '#F2FFD9', '#FFEEAA', '#FFCA88', '#FC966A', '#ED5750', '#CB1C2D', '#9A001E', '#68001D', '#690067' ] @@ -590,16 +588,16 @@ class MAPIES: if display_fig: plt.show() print(f"Saved plot: {filepath}") - + # TODO this can be moved to a plotting function python file - @timeit + @timeit def plot_2D_num_obs_custom( - self, lon_values=None, - lat_values=None, - num_obs=None, + self, lon_values=None, + lat_values=None, + num_obs=None, outdir="./", - title=None, + title=None, filename=None, display_fig=False ): @@ -650,7 +648,6 @@ class MAPIES: plt.show() print(f"Saved plot: {filepath}") - #return fig @staticmethod diff --git a/mapies/tests/test_viirs.py b/mapies/tests/test_viirs.py index d204e5eb..96dd71e1 100644 --- a/mapies/tests/test_viirs.py +++ b/mapies/tests/test_viirs.py @@ -1,8 +1,5 @@ #!/usr/bin/env python -import pytest -import pandas as pd -import numpy as np from pathlib import Path from mapies.viirs import VIIRS diff --git a/mapies/tropomi.py b/mapies/tropomi.py index a9feb6d4..180649c8 100644 --- a/mapies/tropomi.py +++ b/mapies/tropomi.py @@ -1,24 +1,15 @@ #!/usr/bin/env python -from dataclasses import dataclass, field -from typing import List, Tuple -from datetime import datetime, timedelta - -from functools import wraps +from typing import List, Tuple +from datetime import datetime from mapies.mapies import MAPIES -from mapies.util.func_tools import timeit, get_file_list, time_converter, inflate_array import time -import multiprocessing import numpy as np import pandas as pd import xarray as xr import os - from mapies.util.func_tools import * -#timeit, get_file_list, time_converter, frequency_int, error_estimation, time_domain_selection, geo_to_rot -from mapies.grids.monarch 
import RotatedGrid, IrregularRotatedGrid, RegularGrid, regridding_function from mapies.util.multiprocessing import TROPOMIMultiprocessing -import json from numpy.typing import DTypeLike, NDArray os.environ["OMP_NUM_THREADS"] = "8" @@ -33,29 +24,26 @@ class TROPOMI(MAPIES): """ Class for TROPOMI specific data """ - + def __init__(self, start_date, end_date,**kwargs): """ Inherited init class with new variables """ super().__init__(start_date, end_date) - + self.time_orig = np.datetime64("1900-01-01T00:00:00") self.dest = kwargs.get("dest") if not os.path.exists(self.dest): raise Exception("Output directory doesn't exist") self.indir = kwargs.get("indir") - self.year = int(start_date[:4]) + self.year = int(start_date[:4]) #self.quality_flag_limit self.monthly_dict = {} - self.daily_dict = {} # Dictionary to produce daily netcdf files + self.daily_dict = {} # Dictionary to produce daily netcdf files self.datatype = "tropomi" self.mp_type = TROPOMIMultiprocessing - # Add quality value filter number - - if isinstance(self.start_date, str): self.start_date = time_converter(self.start_date) @@ -82,13 +70,13 @@ class TROPOMI(MAPIES): # Convert to Julian day date = datetime.strptime(date, '%Y%m%d%H%M').strftime('%Y%m%d') - #TODO add something that takes in the + #TODO add something that takes in the multiple different types of tropomi filepaths = f'{self.indir}/**/S5P_OFFL_L2__NO2____{date}*' file_patterns.append(filepaths) # Add pattern to the set files = sorted(get_file_list(file_patterns)) print(f"Total number of files: {len(files)}") - + # I'm not sure we need this part files_filtered = [] for file in files: parts = file.split('_') julian_day = parts[-6] file_dt = datetime.strptime(julian_day, '%Y%m%dT%H%M%S') if self.start_date <= file_dt <= self.end_date: files_filtered.append(file) - + first_idx = max(0, files.index(files_filtered[0]) - 2) last_idx = min(len(files) - 1, files.index(files_filtered[-1]) + 2) self.files = files[first_idx : last_idx + 1] - + print(f"Total number of filtered files: {len(self.files)}") - + for file in self.files: parts = file.split('_') julian_day = parts[-6] date = datetime.strptime(julian_day, '%Y%m%dT%H%M%S') if date.month not in self.monthly_dict: self.monthly_dict[date.month] = { "files": [], # List to store file paths "lon": None, # Longitudes array "lat": None, # Latitudes array "obs": None, # Observations array "count": None, # Count of observations used to compute the average } if date.day not in self.daily_dict: self.daily_dict[date.day] = { "files": [], # List to store file paths "lon": None, "lat": None, "obs": None, "time": None, } # Append the file to the list of files for the corresponding month self.monthly_dict[date.month]["files"].append(file) self.daily_dict[date.day]["files"].append(file) - + print(f"Discovered files for months: {list(self.monthly_dict.keys())}") for month, data in self.monthly_dict.items(): print(f" Month {month:02d}: {len(data['files'])} files") - + print(f"Discovered files for days: {list(self.daily_dict.keys())}") - + # Print the daily dictionary to see if files are stored in chronological order for day, data in self.daily_dict.items(): print(f"Day {day:02d}: {len(data['files'])} files") diff --git a/mapies/util/func_tools.py b/mapies/util/func_tools.py index 4823d55c..60fd8709 100644 --- a/mapies/util/func_tools.py +++ b/mapies/util/func_tools.py @@ -14,6 +14,11 @@ import multiprocessing TIME_ORIGIN = np.datetime64("1993-01-01T00:00:00") +def exception_factory(exception, message): + return exception(message) + + + #Timing decorator def timeit(func): @wraps(func) @@ -33,7 +38,7 @@ def get_filepaths(pattern): """ return glob(pattern, recursive=True) -# Function from Guillaume's script could be useful + # Only supported in python3.10 def get_file_list(patterns)
-> List[str]: """ @@ -53,9 +58,6 @@ def get_file_list(patterns) -> List[str]: return sorted(files_final) -def exception_factory(exception, message): - return exception(message) - def time_converter(date:str) -> pd.Timestamp: """ @@ -66,7 +68,7 @@ if ch in date: date = date.replace(ch,"") logging.info(f"Special character {ch} was replaced in the date") - + if len(date) == 14: # Probable str format %Y%m%d%H%M%S date = pd.to_datetime(date, format="%Y%m%d%H%M%S") elif len(date) == 12: # Probable str format %Y%m%d%H%M date = pd.to_datetime(date, format="%Y%m%d%H%M") @@ -91,15 +93,13 @@ def quality_assurance(ds, obs, lat, lon, time_values, qualityFlags): mask = np.zeros_like(obs, dtype=bool) # Iterate through each qa flag for i, (flag, info) in enumerate(qualityFlags.items()): - # TODO add logger of flag being applied with condition and value - + flag_mask = np.zeros_like(obs, dtype=bool) arr = ds[info["qa_name"]].values.flatten() # Get qa conditions try: condition = info["qa_condition"] except KeyError: - # print("No QA condition passed, putting equal") condition = "equal" # Get qa values try: @@ -119,9 +119,9 @@ def quality_assurance(ds, obs, lat, lon, time_values, qualityFlags): flag_mask |= (arr >= value) elif condition == "less_than_or_equal": flag_mask |= (arr <= value) - elif condition == "equal": + elif condition == "equal": flag_mask |= (arr == value) - + if i == 0: mask = flag_mask else: @@ -131,7 +131,7 @@ def quality_assurance(ds, obs, lat, lon, time_values, qualityFlags): except KeyError: # print("No combination of and or or passed for the condition, presuming or(False)") combination = False - + if combination: mask &= flag_mask else: @@ -156,7 +156,7 @@ def spatial_domain_selection(obs, lat, lon, time, spatial_bounds): # Check that the spatial bounds are 2 dimensional if len(spatial_bounds) != 2: raise Exception("Spatial bounds passed are not 2 dimensional, make sure to pass [[left_lon, right_lon], [left_lat, right_lat]]") - + # Get the bounds lon_l_bound, lon_r_bound = spatial_bounds[0] lat_l_bound, lat_r_bound = spatial_bounds[1] @@ -164,7 +164,7 @@ def spatial_domain_selection(obs, lat, lon, time, spatial_bounds): # Crop the spatial variables lon_mask = np.logical_and(lon>=lon_l_bound, lon<=lon_r_bound) lat_mask = np.logical_and(lat>=lat_l_bound, lat<=lat_r_bound) - + # Masking the data lon = lon[lon_mask & lat_mask] lat = lat[lon_mask & lat_mask] @@ -174,18 +174,18 @@ def spatial_domain_selection(obs, lat, lon, time, spatial_bounds): assert len(lon) == len(lat) assert len(obs) == len(time) assert len(lon) == len(obs) - + return obs, lat, lon, time - + # Time Domain Selection function based on xarray Will this work with dask too, a question for later def time_domain_selection( - obs, - lat, - lon, - time_values, - start_date, - end_date, + obs, + lat, + lon, + time_values, + start_date, + end_date, closed = None) -> np.array: """ :param ds: one or several xarray datasets to cut down to correct size.
@@ -204,19 +204,19 @@ def time_domain_selection( if isinstance(end_date, str): end_date = pd.to_datetime(end_date) - + # For the time column we can either centralise it in the separate Instrument objects or pass it here, # I will start by passing it here def is_between_times(array): - + if closed == "left": return (array > start_date) & (array <= end_date) elif closed == "right": return (array >= start_date) & (array < end_date) else: return (array >= start_date) & (array <= end_date) - + time_mask = np.where(is_between_times(time_values))[0] @@ -249,7 +249,7 @@ def frequency_int(time_interval): if time_interval == "H": frequency = 1 elif time_interval == "D": - frequency = 24 + frequency = 24 else: frequency = int(time_interval[0]) return frequency @@ -269,7 +269,7 @@ def error_estimation(datatype:str, obs:np.typing.NDArray, unc_const:list) -> np. -def new_preprocess_vars(ds, obs_var, lat_var, lon_var, time_var, start_date, end_date): +def flatten_arrays(ds, obs_var, lat_var, lon_var, time_var): """ * Preprocessing of the dataset * Flattening arrays @@ -279,34 +279,33 @@ def new_preprocess_vars(ds, obs_var, lat_var, lon_var, time_var, start_date, end time_dims = ds[time_var].dims time_shape = ds[time_var].shape time_attrs = ds[time_var].attrs # Useful if we need to convert to datetime - + # Get time values flatten and convert to datetime values time_values = ds[time_var].values time_values = time_values.flatten() - + if time_values.dtype == "timedelta64[ns]": logging.info("Adding time origin to time values as the time variable is in timedelta") time_values = np.add(TIME_ORIGIN, time_values) time_values = pd.to_datetime(time_values) else: time_values = pd.to_datetime(time_values) - - + + obs_shape = ds[obs_var].shape obs_attrs = ds[obs_var].attrs - + # Duplicate values in the time array to get the same shape as the flattened variable array if time_shape != obs_shape: time_values = inflate_array(time_values, time_shape, obs_shape) - + # Lon and lat values lon_dims = ds[lon_var].dims lon_shape = ds[lon_var].shape lon_attrs = ds[lon_var].attrs lon_values = ds[lon_var].values.flatten() - lat_dims = ds[lat_var].dims lat_shape = ds[lat_var].shape lat_attrs = ds[lat_var].attrs @@ -404,70 +403,6 @@ def create_batches(files, batch_size): - - - -# def process_single_file( -# r, -# file, -# obs_var, -# lat_var, -# lon_var, -# time_var, -# start_date, -# end_date, -# apply_qa, -# qualityFlags, -# flag=False, -# geospatial_crop=None -# ): -# """ -# Process a single file and return aggregated results or flattened arrays.
-# """ -# try: -# ds = xr.open_dataset(file, engine="h5netcdf") - -# obs, lat, lon, time_values = new_preprocess_vars(ds, obs_var, lat_var, lon_var, time_var, start_date, end_date) - -# #Run a time domain selection -# if apply_qa: -# obs, lat, lon, time_values = quality_assurance(ds, obs, lat, lon, time_values, qualityFlags) - -# obs, lon, lat, time_values = time_domain_selection(obs, lat, lon, time_values, start_date, end_date) - -# if geospatial_crop is not None: -# obs, lat, lon, time_values = spatial_domain_selection(obs, lat, lon, time_values, geospatial_crop) - -# if not flag: -# # Regrid and aggregate -# if isinstance(r, RotatedGrid): -# lon_agg, lat_agg, rlon_agg, rlat_agg, obs_agg, count_obs = regridding_function(r, lon, lat, obs) -# elif isinstance(r, RegularGrid): -# lon_agg, lat_agg, obs_agg, count_obs = regridding_function(r, lon, lat, obs) -# ds.close() -# return obs_agg, lat_agg, lon_agg, count_obs -# else: -# ds.close() -# return obs, lat, lon, time_values - -# except Exception as e: -# print(f"Error processing file {file}: {e}") - -# return None - - - -# def process_batch(args, batch): - -# with multiprocessing.Pool(processes=8) as pool: # Adjust number of processes as needed -# results = pool.starmap(process_single_file, args) - -# # Filter valid results and log failed files -# valid_results = [res for res in results if res is not None] -# print(f"Valid results: {len(valid_results)}") -# failed_files= [file for file, res in zip(batch, results) if res is None] - -# return valid_results, failed_files - - def combine_arrays(valid_results, cumulative_obs, cumulative_lat, cumulative_lon, cumulative_count): """ Combine the arrays from the valid results. In other words sum the arrays and count the number of observations. diff --git a/mapies/util/multiprocessing.py b/mapies/util/multiprocessing.py index 0557fe24..3ce5bed9 100644 --- a/mapies/util/multiprocessing.py +++ b/mapies/util/multiprocessing.py @@ -1,35 +1,28 @@ -import time -import logging + import numpy as np import pandas as pd import xarray as xr - -from functools import wraps -from numpy import cos, sin, arctan, pi, nan -from glob import glob from typing import List -from datetime import datetime from abc import ABC, abstractmethod import multiprocessing from dataclasses import dataclass from numpy.typing import DTypeLike, NDArray - from mapies.util.func_tools import * -from mapies.grids.monarch import Grid, RotatedGrid, IrregularRotatedGrid, RegularGrid, regridding_function +from mapies.grids.monarch import Grid, RotatedGrid, RegularGrid, regridding_function def process_single_file_tropomi( - r, - file, - obs_var, - lat_var, - lon_var, - time_var, - start_date, - end_date, - apply_qa, - qualityFlags, - flag=False, + r, + file, + obs_var, + lat_var, + lon_var, + time_var, + start_date, + end_date, + apply_qa, + qualityFlags, + flag=False, geospatial_crop=None ): """ @@ -43,7 +36,7 @@ def process_single_file_tropomi( # Flatten arrays try: - obs, lat, lon, time_values = new_preprocess_vars(ds, obs_var, lat_var, lon_var, time_var, start_date, end_date) + obs, lat, lon, time_values = flatten_arrays(ds, obs_var, lat_var, lon_var, time_var) except Exception as e: print(f"Error processing file {file}: {e}") return None @@ -89,17 +82,17 @@ def process_single_file_tropomi( def process_single_file_viirs( - r, - file, - obs_var, - lat_var, - lon_var, - time_var, - start_date, - end_date, - apply_qa, - qualityFlags, - flag=False, + r, + file, + obs_var, + lat_var, + lon_var, + time_var, + start_date, + end_date, + apply_qa, + 
qualityFlags, + flag=False, geospatial_crop=None ): """ @@ -113,7 +106,7 @@ def process_single_file_viirs( # Flatten arrays try: - obs, lat, lon, time_values = new_preprocess_vars(ds, obs_var, lat_var, lon_var, time_var, start_date, end_date) + obs, lat, lon, time_values = flatten_arrays(ds, obs_var, lat_var, lon_var, time_var) except Exception as e: print(f"Error processing file {file}: {e}") return None @@ -155,7 +148,7 @@ def process_single_file_viirs( else: ds.close() return obs, lat, lon, time_values - + class GenericMultiProcessing(ABC): @@ -174,20 +167,20 @@ class VIIRSMultiprocessing(GenericMultiProcessing): Multiprocessing method for VIIRS """ r: Grid - obs_var: NDArray - lat_var: NDArray + obs_var: NDArray + lat_var: NDArray lon_var: NDArray time_var: NDArray start_date: pd.Timestamp - end_date: pd.Timestamp + end_date: pd.Timestamp apply_qa: bool qualityFlags: dict batch: List[str] geospatial_crop: NDArray = None flag: bool = False - - + + def process_batch(self): """ Batch processing for VIIRS """ args = [ ( - self.r, - file, - self.obs_var, - self.lat_var, - self.lon_var, - self.time_var, - self.start_date, - self.end_date, - self.apply_qa, - self.qualityFlags, - self.flag, + self.r, + file, + self.obs_var, + self.lat_var, + self.lon_var, + self.time_var, + self.start_date, + self.end_date, + self.apply_qa, + self.qualityFlags, + self.flag, self.geospatial_crop - ) + ) for file in self.batch ] @@ -230,20 +223,20 @@ class TROPOMIMultiprocessing(GenericMultiProcessing): Multiprocessing method for TROPOMI """ r: Grid - obs_var: NDArray - lat_var: NDArray + obs_var: NDArray + lat_var: NDArray lon_var: NDArray time_var: NDArray start_date: pd.Timestamp - end_date: pd.Timestamp + end_date: pd.Timestamp apply_qa: bool qualityFlags: dict batch: List[str] geospatial_crop: NDArray = None flag: bool = False - - + + def process_batch(self): """ Batch processing for TROPOMI """ args = [ ( - self.r, - file, - self.obs_var, - self.lat_var, - self.lon_var, - self.time_var, - self.start_date, - self.end_date, - self.apply_qa, - self.qualityFlags, + self.r, + file, + self.obs_var, + self.lat_var, + self.lon_var, + self.time_var, + self.start_date, + self.end_date, + self.apply_qa, + self.qualityFlags, self.flag, self.geospatial_crop - ) + ) for file in self.batch ] diff --git a/mapies/util/variables_names.py b/mapies/util/variables_names.py index a17b08cb..fd1c4fbd 100644 --- a/mapies/util/variables_names.py +++ b/mapies/util/variables_names.py @@ -2,6 +2,6 @@ VARS_ATTRIBUTES={ "lon": {"description": "Longitude of observations", "units": "degrees_east"}, "lat": {"description": "Latitude of observations", "units": "degrees_north"}, "Aerosol_Optical_Thickness_550_Land_Ocean_Best_Estimate": {"description": "Aerosol_Optical_Thickness_550_Land_Ocean_Best_Estimate values", "units": "unitless", "mapies_variable":"aod550"}, - "nitrogendioxide_tropospheric_column": {"description": "nitrogendioxide_tropospheric_column", "units": "unitless", "mapies_variable":"no2"}, + "nitrogendioxide_tropospheric_column": {"description": "nitrogendioxide_tropospheric_column", "units": "mol m-2", "mapies_variable":"no2"}, "title": {"description": "VIIRS data"}, } \ No newline at end of file diff --git a/mapies/viirs.py b/mapies/viirs.py index 1ab4f01b..8629f835 100644 --- a/mapies/viirs.py +++ b/mapies/viirs.py @@ -1,34 +1,20 @@ #!/usr/bin/env python from datetime import datetime,
timedelta -from functools import wraps import logging import os import time -#from tqdm import tqdm -import yaml -import multiprocessing -import gc import dask.array as da -from glob import glob as glob_module -import sys import numpy as np import pandas as pd -import xarray as xr -import matplotlib.pyplot as plt -import cartopy import cartopy.crs as ccrs -#from netCDF4 import Dataset - from .mapies import MAPIES from mapies.util.func_tools import * -#timeit, get_file_list, time_converter, frequency_int, error_estimation, time_domain_selection, geo_to_rot -from mapies.grids.monarch import RotatedGrid, IrregularRotatedGrid, RegularGrid, regridding_function +from mapies.grids.monarch import regridding_function from mapies.util.variables_names import VARS_ATTRIBUTES from pathlib import Path -import json -from numpy.typing import DTypeLike, NDArray + from mapies.util.multiprocessing import VIIRSMultiprocessing os.environ["OMP_NUM_THREADS"] = "8" @@ -47,7 +33,7 @@ class VIIRS(MAPIES): Inherited init class with new variables """ super().__init__(start_date, end_date) - + self.time_orig = np.datetime64("1993-01-01T00:00:00") self.datatype="viirs" self.mp_type = VIIRSMultiprocessing @@ -57,14 +43,14 @@ class VIIRS(MAPIES): self.indir = kwargs.get("indir") self.year = int(start_date[:4]) - self.start_day = int(start_date[6:8]) - self.end_day = int(end_date[6:8]) + self.start_day = int(start_date[6:8]) + self.end_day = int(end_date[6:8]) self.monthly_dict = {} self.daily_dict = {} # Dictionary to produce daily netcdf files self.dates_slice = pd.date_range( - self.start_date, - self.end_date, + self.start_date, + self.end_date, ).strftime('%Y%m%d%H%M') if isinstance(self.start_date, str): @@ -92,7 +78,7 @@ class VIIRS(MAPIES): self.unc_const = unc_const else: self.unc_const = da_dict["uncertainty_constants"] #float, int, list, dict - + @timeit @@ -107,21 +93,21 @@ class VIIRS(MAPIES): date = datetime.strptime(date, '%Y%m%d%H%M').strftime('%Y%j') filepaths = f'{self.indir}/**/AERDB_L2_VIIRS_NOAA20.A{date}*' - file_patterns.append(filepaths) + file_patterns.append(filepaths) files = sorted(get_file_list(file_patterns)) print(f"Total number of files: {len(files)}") else: raise Exception("Please specify an input directory") - + files_filtered = filter_files(files, self.start_date, self.end_date) - + first_idx = max(0, files.index(files_filtered[0]) - 2) last_idx = min(len(files) - 1, files.index(files_filtered[-1]) + 2) self.files = files[first_idx : last_idx + 1] - + print(f"Total number of filtered files: {len(self.files)}") - + for file in self.files: parts = file.split('.') julian_day = parts[1][1:] @@ -133,7 +119,7 @@ class VIIRS(MAPIES): "lat": None, # Latitudes array "obs": None, # Observations array "count": None, # Count of observations used to compute the average - } + } if date.day not in self.daily_dict: self.daily_dict[date.day] = { "files": [], # List to store file paths @@ -145,13 +131,13 @@ class VIIRS(MAPIES): self.monthly_dict[date.month]["files"].append(file) self.daily_dict[date.day]["files"].append(file) - + print(f"Discovered files for months: {list(self.monthly_dict.keys())}") for month, data in self.monthly_dict.items(): print(f" Month {month:02d}: {len(data['files'])} files") - + print(f"Discovered files for days: {list(self.daily_dict.keys())}") - + for day, data in self.daily_dict.items(): print(f"Day {day:02d}: {len(data['files'])} files") @@ -164,15 +150,11 @@ class VIIRS(MAPIES): """ Function that returns all the needed variables for the DA """ - - # Calculate error - - 
outfiles = [] - + self.frequency = frequency_int(frequency) da_dates_slice = pd.date_range( - self.start_date, + self.start_date, self.end_date, freq=frequency ).strftime('%Y%m%d%H%M') @@ -182,7 +164,7 @@ class VIIRS(MAPIES): r_border = (datetime.strptime(date, "%Y%m%d%H%M") + timedelta(hours=self.frequency/2)).strftime('%Y%m%d%H%M') filename = Path(self.dest).joinpath(f'obs{date}.nc') - + # concat all days that we need obs = np.array([]) lon = np.array([]) @@ -223,7 +205,6 @@ class VIIRS(MAPIES): except ValueError: raise Exception("Have you specified a grid_repr") - # Create new arrays with same length as obs obsid = np.full(shape=obs.shape, fill_value=self.obsid_da, dtype=self.int_type) @@ -239,7 +220,7 @@ class VIIRS(MAPIES): coords = dict(index=("index", np.indices(obs.shape)[0, :])) data_vars = dict( time=(["index",], time_values), - lon=(["index",], lon), + lon=(["index",], lon), lat=(["index",], lat), obs=(["index",], obs), obserr=(["index", ], obserr), @@ -250,12 +231,12 @@ class VIIRS(MAPIES): ) if self.grid_repr == "rotated": rotated_dict = dict( - rlon=(["index",], rlon), + rlon=(["index",], rlon), rlat=(["index",], rlat), #Also return rlon/rlat values if available ) data_vars.update(rotated_dict) - + #save the plots if specified for each time frequency if save_figs: self.plot_2D_obs_custom(lon, lat, obs, outdir=self.dest, filename=f"da_obs_{date}.png") diff --git a/run/run_tropomi.py b/run/run_tropomi.py index 2b438d8b..f8188943 100644 --- a/run/run_tropomi.py +++ b/run/run_tropomi.py @@ -1,7 +1,6 @@ from mapies.tropomi import TROPOMI -import numpy as np import time -# from dask.distributed import Client + if __name__ == "__main__": @@ -17,7 +16,6 @@ if __name__ == "__main__": c.gather_nc_files() c.process_avg_data(monthly_avg=False, batch_size = 14, apply_qa=True, geospatial_crop=[[-40, 60], [0, 70]]) - # c.yearly_average() c.plot_2D_observations(months=[0], outdir=outdir) c.plot_2D_num_obs(months=[0], outdir=outdir) c.process_lazy_data(apply_qa=True, geospatial_crop=[[-40, 60], [0, 70]]) diff --git a/run/run_viirs.py b/run/run_viirs.py index 9cdb89eb..47053401 100644 --- a/run/run_viirs.py +++ b/run/run_viirs.py @@ -28,8 +28,8 @@ if __name__ == "__main__": c.yearly_average() c.plot_2D_observations(months=None, filename="new.png", outdir=outdir) c.plot_2D_num_obs(months=None, outdir=outdir) - #c.process_lazy_data(apply_qa=True, save=True) - #c.to_da(frequency="3H", save_figs=True) + c.process_lazy_data(apply_qa=True, save=False) + c.to_da(frequency="3H", save_figs=True) -- GitLab
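
Usage after this series, as exercised by the updated run scripts (a minimal sketch mirroring run/run_viirs.py; the dates and directories below are placeholders, not values fixed by the patches):

    from mapies.viirs import VIIRS

    # Placeholder time window and paths; substitute real ones.
    start_date, end_date = "202401010000", "202401312359"
    c = VIIRS(start_date, end_date, dest="./out", indir="./data", grid_repr="rotated")
    c.gather_nc_files()    # renamed from read_nc() in patch 1/2
    c.process_avg_data(monthly_avg=True, batch_size=100, apply_qa=True, save=True)  # renamed from process_data() in patch 1/2
    c.yearly_average()
    c.plot_2D_observations(months=None, outdir="./out")
    c.plot_2D_num_obs(months=None, outdir="./out")
    c.process_lazy_data(apply_qa=True, save=False)
    c.to_da(frequency="3H", save_figs=True)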