diff --git a/.gitignore b/.gitignore index 503b47d8c1742777c940c6a79f01b5b0a087e5e6..8529173ceb8faa6d8a96aeffc5d0afa3979f682a 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ *.orig venv out-logs +logs figures mapies-env outputs diff --git a/mapies/grids/monarch.py b/mapies/grids/monarch.py index 6929811a93ca06676f1e744f04d58f2752d52e7c..1b11ec9c14b24a17cf705af62a42ed21b208f2be 100644 --- a/mapies/grids/monarch.py +++ b/mapies/grids/monarch.py @@ -13,6 +13,32 @@ from shapely.geometry import Polygon, Point from mapies.util.func_tools import * import logging +logger = logging.getLogger(__name__) +logging.basicConfig(filename='../logs/mapies.log', level=logging.INFO) + + +def regridding_function(r, lon, lat, obs, obserr=None): + + # If obserr is selected return that + if obserr is not None: + if isinstance(r, RotatedGrid): + lon_agg, lat_agg, rlon_agg, rlat_agg, obs_agg, obserr_agg, count_obs = r.aggregate(lon, lat, obs, obserr) + return lon_agg, lat_agg, rlon_agg, rlat_agg, obs_agg, obserr_agg, count_obs + elif isinstance(r, RegularGrid): + lon_agg, lat_agg, obs_agg, obserr_agg, count_obs = r.aggregate(lon, lat, obs, obserr) + return lon_agg, lat_agg, obs_agg, obserr_agg, count_obs + else: + raise ValueError("Invalid grid representation") + + else: + if isinstance(r, RotatedGrid): + lon_agg, lat_agg, rlon_agg, rlat_agg, obs_agg, count_obs = r.aggregate(lon, lat, obs) + return lon_agg, lat_agg, rlon_agg, rlat_agg, obs_agg, count_obs + elif isinstance(r, RegularGrid): + lon_agg, lat_agg, obs_agg, count_obs = r.aggregate(lon, lat, obs) + return lon_agg, lat_agg, obs_agg, count_obs + else: + raise ValueError("Invalid grid representation") @@ -35,6 +61,7 @@ class Grid: This can then be used in many grid representations, for example in a rotated grid used in the ihouse Monarch model """ + logger.info("Creating Grid") if (self.nlon == None) and (self.nlat == None): self.nlat = int((abs(self.south) / self.dlat) * 2) self.nlon = int((abs(self.west) / self.dlon) * 2) diff --git a/mapies/mapies.py b/mapies/mapies.py index 67b93688d137d8177f5f222e8e3a4dbd826d33c7..e8da0cae2c4e97bd0089d0d93f08f088380115d8 100644 --- a/mapies/mapies.py +++ b/mapies/mapies.py @@ -22,7 +22,8 @@ import matplotlib.cm as cm from matplotlib.colors import ListedColormap import matplotlib.colors as mcolors - +logger = logging.getLogger(__name__) +logging.basicConfig(filename='../logs/mapies.log', level=logging.DEBUG) class MAPIES: """ @@ -40,6 +41,7 @@ class MAPIES: self.start_date = start_date self.end_date = end_date + self.grid_dict = { "rotated":RotatedGrid, "regular":RegularGrid, @@ -60,7 +62,8 @@ class MAPIES: try: self.config = yaml.safe_load(open(config_file)) except FileNotFoundError: - logging.error("This is not a config file") + logger.error("Config file passed does not exist; using default config instead") + self.config = yaml.safe_load(open(os.path.join(module_dir, "config/satellite_config.yaml"))) else: self.config = yaml.safe_load(open(os.path.join(module_dir, "config/satellite_config.yaml"))) @@ -85,10 +88,11 @@ class MAPIES: else: self.grid_repr = grid_dict["grid_repr"] #str - # Cread grid object if grid repr is selected in the input - self.grid_config = grid_dict[self.grid_repr] - self.grid = self.grid_dict[self.grid_repr](**self.grid_config) - print(self.grid) + if self.grid_repr: + # Create grid object if grid repr is selected in the input + self.grid_config = grid_dict[self.grid_repr] + self.grid = self.grid_dict[self.grid_repr](**self.grid_config) + print(self.grid) qa_dict =
self.config[self.datatype]["qa"] @@ -118,7 +122,7 @@ class MAPIES: self.time_values = self.time_values.flatten() if self.time_values.dtype == "timedelta64[ns]": - logging.info("Adding time origin to time values as the time variable is in timedelta") + logger.info("Adding time origin to time values as the time variable is in timedelta") self.time_values = np.add(self.time_orig, self.time_values) self.time_values = pd.to_datetime(self.time_values) else: @@ -135,6 +139,10 @@ class MAPIES: """ Plotting the observations """ + # Regrid if specified as argument + if self.grid_repr: + logger.info("Performing regridding of data") + self.regridding_function() # TODO: Make these arguments figsize = (15,10) markersize = 2.5 @@ -321,6 +329,7 @@ class MAPIES: self.obs = self.obs[mask] self.lon_values = self.lon_values[mask] self.lat_values = self.lat_values[mask] + self.time_values = self.time_values[mask] @staticmethod diff --git a/mapies/tests/test_data/VIIRS/AERDB_L2_VIIRS_NOAA20.A2024001.0024.002.2024001191538.nc b/mapies/tests/test_data/VIIRS/AERDB_L2_VIIRS_NOAA20.A2024001.0024.002.2024001191538.nc new file mode 100644 index 0000000000000000000000000000000000000000..b199fd32d19fec99e93a38c667e9399f244cdd63 Binary files /dev/null and b/mapies/tests/test_data/VIIRS/AERDB_L2_VIIRS_NOAA20.A2024001.0024.002.2024001191538.nc differ diff --git a/mapies/tests/test_monarch_grid.py b/mapies/tests/test_monarch_grid.py index 8555e5074677a0fdb712ae96226181aa83cfd4ef..527c973e5824df62c10211323df6e422f4e6b0ea 100644 --- a/mapies/tests/test_monarch_grid.py +++ b/mapies/tests/test_monarch_grid.py @@ -10,7 +10,7 @@ from mapies.grids.monarch import RegularGrid, RotatedGrid, IrregularRotatedGrid "test_input,expected", [ ((np.array([-10, -10, -10]), np.array([-10, -10, -10]), np.array([1, 2, 3])), - (np.array([-10.]), np.array([-10.]), np.array([2.]), np.array([[3]]))), + (np.array([-10., 0., 0., 0.]), np.array([-10., 0., 0., 0.]), np.array([2., 0., 0., 0.]), np.array([[3., 0., 0., 0.]]))), ] ) def test_regular_grid_aggregation(test_input, expected): diff --git a/mapies/tests/test_viirs.py b/mapies/tests/test_viirs.py new file mode 100644 index 0000000000000000000000000000000000000000..046b4f3fc658e321ef29d2be9d53951bb6234b3c --- /dev/null +++ b/mapies/tests/test_viirs.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python + +import pytest +import pandas as pd +import numpy as np +from pathlib import Path +from mapies.viirs import VIIRS + +# Test data lives in the new directory mapies/tests/test_data so it can always be read in and the function tested + + +def test_to_da(): + start_date = "202401010000" + end_date = "202401010030" + outdir="./" + indir="test_data" + c = VIIRS(start_date, end_date, dest=outdir, indir=indir, grid_repr="rotated") + print(c.dest) + c.read_nc() + c.process_data(monthly_avg=False) + outfiles = c.to_da(frequency="3H", save_nc=True, save_figs=False) + assert outfiles == [Path("obs202401010000.nc")] \ No newline at end of file diff --git a/mapies/util/func_tools.py b/mapies/util/func_tools.py index 0722ec2e5d6fe67df05ff0a5108e4e06a431d12b..aec2f8d67ebf4159226829d4989001281356cb61 100644 --- a/mapies/util/func_tools.py +++ b/mapies/util/func_tools.py @@ -50,6 +50,7 @@ def get_file_list(patterns) -> List[str]: for file_list in pool.imap_unordered(get_filepaths, args): files_final.extend(file_list) # Flatten the lists into files_final + return sorted(files_final) def exception_factory(exception, message): diff --git a/mapies/viirs.py b/mapies/viirs.py index
3ede174d97a16f8b276d7c5d496f35c6800ff66b..7b066dd4f8ef59f90f64cf53d8de6fc78d7f6ac4 100644 --- a/mapies/viirs.py +++ b/mapies/viirs.py @@ -23,35 +23,19 @@ from netCDF4 import Dataset from .mapies import MAPIES from mapies.util.func_tools import * #timeit, get_file_list, time_converter, frequency_int, error_estimation, time_domain_selection, geo_to_rot -from mapies.grids.monarch import RotatedGrid, IrregularRotatedGrid, RegularGrid +from mapies.grids.monarch import RotatedGrid, IrregularRotatedGrid, RegularGrid, regridding_function from pathlib import Path from glob import glob as glob_module from memory_profiler import profile os.environ["OMP_NUM_THREADS"] = "8" -# TODO: still need to do a load of filtering huuuhhh -# TODO: Create Yaml config file for time, lat, lon, uncertainty const etc. +logger = logging.getLogger(__name__) +logging.basicConfig(filename='../logs/mapies.log', level=logging.INFO) +TIME_ORIGIN = np.datetime64("1900-01-01T00:00:00") + -# Uncertainty constants should be able to be passed as number, array or dict - -def rotate_function(r, lon, lat, obs): - - if isinstance(r, RotatedGrid): - - lon_agg, lat_agg, rlon_agg, rlat_agg, obs_agg, count_obs = r.aggregate(lon, lat, obs) - - return lon_agg, lat_agg, obs_agg, count_obs - - elif isinstance(r, RegularGrid): - - lon_agg, lat_agg, obs_agg, count_obs = r.aggregate(lon, lat, obs) - - return lon_agg, lat_agg, obs_agg, count_obs - - else: - raise ValueError("Invalid grid representation") @@ -67,9 +51,13 @@ def process_single_file(r, file, obs_var, lat_var, lon_var, apply_qa, qualityFla if apply_qa: obs, lat, lon = quality_assurance(ds, obs, lat, lon, qualityFlags) - # Rotate and aggregate - lon_agg, lat_agg, obs_agg, count_obs = rotate_function(r, lon, lat, obs) - + # Regrid and aggregate + if isinstance(r, RotatedGrid): + lon_agg, lat_agg, rlon_agg, rlat_agg, obs_agg, count_obs = regridding_function(r, lon, lat, obs) + elif isinstance(r, RegularGrid): + lon_agg, lat_agg, obs_agg, count_obs = regridding_function(r, lon, lat, obs) + else: + raise ValueError("Invalid grid representation") ds.close() # Close the file to free memory return obs_agg, lat_agg, lon_agg, count_obs @@ -103,7 +89,7 @@ def process_batch(r, batch_files, obs_var, lat_var, lon_var, chunk_size, quality lat = ds[lat_var].values.flatten() lon = ds[lon_var].values.flatten() - lon_agg, lat_agg, obs_agg, count_obs = rotate_function(r,lon, lat, obs) + lon_agg, lat_agg, obs_agg, count_obs = regridding_function(r, lon, lat, obs) ds.close() @@ -140,7 +126,7 @@ class VIIRS(MAPIES): frequency = kwargs.get("frequency") self.dest = kwargs.get("dest") if not os.path.exists(self.dest): - raise Exception("Doesn't exist") + raise Exception(f"Output directory {self.dest} doesn't exist") self.indir = kwargs.get("indir") self.year = int(start_date[:4]) @@ -151,10 +137,7 @@ class VIIRS(MAPIES): self.dates_slice = pd.date_range( self.start_date, self.end_date, - freq=frequency ).strftime('%Y%m%d%H%M') - if frequency: - self.frequency = frequency_int(frequency) if isinstance(self.start_date, str): self.start_date = time_converter(self.start_date) @@ -171,6 +154,7 @@ class VIIRS(MAPIES): Read yaml config file """ super().read_config(**kwargs) + logger.info("Reading VIIRS-specific config variables") da_dict = self.config[self.datatype]["da"] self.obsid_da = da_dict["obsid"] #int @@ -229,11 +213,14 @@ class VIIRS(MAPIES): ) # Apply QA - self.quality_assurance() + if self.apply_qa: + logger.info("Applying quality assurance filters") + self.quality_assurance() + @timeit - def rotate(self): + def regridding_function(self): """ Perform regridding of the data onto the grid
representation """ @@ -257,7 +244,7 @@ class VIIRS(MAPIES): @timeit def read_nc(self): """ - Read netcdf files with xarray + Returns list of files needed to be processed by the application """ file_patterns = [] print(f' Date slice = {self.dates_slice}') @@ -316,185 +303,13 @@ class VIIRS(MAPIES): for month, data in self.monthly_dict.items(): print(f" Month {month:02d}: {len(data['files'])} files") - @timeit - def to_da(self): - """ - Function that returns all the needed variables for the DA - """ - - # Calculate error - self.obserr = error_estimation(self.datatype, self.obs, self.unc_const) - - outfiles = [] - - for date in self.dates_slice: - l_border = (datetime.strptime(date, "%Y%m%d%H%M") - timedelta(hours=self.frequency/2)).strftime('%Y%m%d%H%M') - r_border = (datetime.strptime(date, "%Y%m%d%H%M") + timedelta(hours=self.frequency/2)).strftime('%Y%m%d%H%M') - # Included by Guillaume in his DA function so I believe he needs this to be true - filename = Path(self.dest).joinpath(f'obs{date}.nc') - # Cut down the time interval again then reindex obs, obserr, lat and lon - # Returning also arrays for type, lev, id and n with the same index as the time var - time_values, frequency_index = time_domain_selection( - self.time_values, - l_border, - r_border, - ) - - - # Reindex - obs = self.reindex(frequency_index, self.obs) - lon = self.reindex(frequency_index, self.lon_values) - lat = self.reindex(frequency_index, self.lat_values) - obserr = self.reindex(frequency_index, self.obserr) - # Run aggregation with grid representation - lon, lat, obs, obserr = self.grid.aggregate(lon, lat, obs, obserr) - rlon, rlat = geo_to_rot(lon, lat, centre_lon=20, centre_lat=35) - - - # Create new arrays with same length as obs - obsid = np.full(shape=obs.shape, fill_value=self.obsid_da, dtype=self.int_type) - obstype = np.full(shape=obs.shape, fill_value=30, dtype=self.int_type) - obslev = np.full(shape=obs.shape, fill_value=-99999, dtype=self.int_type) - obsn = np.full(shape=obs.shape, fill_value=1, dtype=self.int_type) - time_values = np.full(shape=obs.shape, fill_value=date, dtype=str) - - # Coords equals index - - assert obs.shape == obserr.shape - assert lon.shape == time_values.shape - coords = dict(index=("index", np.indices(obs.shape)[0, :])) - data_vars = dict( - time=(["index",], time_values), - lon=(["index",], lon), - lat=(["index",], lat), - rlon=(["index",], rlon), - rlat=(["index",], rlat), #Also return rlon/rlat values if available - obs=(["index",], obs), - obserr=(["index", ], obserr), - obsid=(["index", ], obsid), - obstype=(["index", ], obstype), - obslev=(["index", ], obslev), - obsn=(["index", ], obsn), - ) - - ds = self.to_xarray(coords=coords, data_vars=data_vars) - - ds.to_netcdf(filename, encoding={}) - outfiles.append(filename) - - return outfiles - - # Plots - def plot_2D_obs(self, outdir="./", **kwargs): - """ - Plotting the observations specific to VIIRS - """ - if self.grid_repr != "": - self.rotate() - super().plot_2D_obs_custom() - - - - @timeit - def plot_2D_observations(self, months=None, outdir=None): - - # if 0 in months: - if months == [0]: - try: - self.plot_2D_obs() - except Exception as e: - print(f"Error plotting all data: {e}") - return - elif months is not None: - # elif m in months: - for m in months: - try: - month_data = self.monthly_dict[m] # Access the dictionary for this month - # files = month_data["files"] # List of files for the month (not used for plotting) - lon_values = month_data["lon"] - lat_values = month_data["lat"] - obs_values = month_data["obs"] - 
except Exception as e: - print(f"Error for month {m}: {e}") - continue - # title = f"Observation 2D plot of {self.datatype.upper()} data from {self.start_date} to {self.end_date}" - title = f"Observation 2D plot of month {m} from {self.start_date} to {self.end_date}" - filename = f"{outdir}/{self.datatype}_2D_obs_month{m}_{self.year}_.png" - - super().plot_2D_obs_custom( - lon_values=lon_values, - lat_values=lat_values, - obs_values=obs_values, - outdir=outdir, - title=title, - filename=filename, - ) - else: - try: - lon_values = self.yearly_lon - lat_values = self.yearly_lat - obs_values = self.yearly_obs - except Exception as e: - print(f"Error for yearly data: {e}") - return - title = f"Observation 2D plot of yearly average of year {self.year} from {self.start_date} to {self.end_date}" - filename = f"{outdir}/{self.datatype}_2D_yearly_obs_{self.year}.png" - super().plot_2D_obs_custom( - lon_values=lon_values, - lat_values=lat_values, - obs_values=obs_values, - outdir=outdir, - title=title, - filename=filename, - ) - - def plot_2D_num_obs(self, months=None, outdir=None): - # if isinstance(months, int): - # months = [months] - if months is not None: - for m in months: - try: - month_data = self.monthly_dict[m] # Access the dictionary for this month - # files = month_data["files"] # List of files for the month (not used for plotting) - lon_values = month_data["lon"] - lat_values = month_data["lat"] - cum_count = month_data["count"] - except Exception as e: - print(f"Error for month {m}: {e}") - continue - # title = f"Observation 2D plot of {self.datatype.upper()} data from {self.start_date} to {self.end_date}" - title = f"Plot of the number of valid observations for month {m} from {self.start_date} to {self.end_date}" - filename = f"{outdir}/{self.datatype}_2D_obs_count_month{m}_{self.year}_trial4.png" - super().plot_2D_num_obs_custom( - lon_values=lon_values, - lat_values=lat_values, - num_obs=cum_count, - outdir=outdir, - title=title, - filename=filename, - ) - else: - try: - lon_values = self.yearly_lon - lat_values = self.yearly_lat - cum_count = self.yearly_count - except Exception as e: - print(f"Error for yearly data: {e}") - return - title = f"Plot of the number of valid observations for yearly average of {self.year} from {self.start_date} to {self.end_date}" - filename = f"{outdir}/{self.datatype}_2D_obs_count_yearly_{self.year}.png" - super().plot_2D_num_obs_custom( - lon_values=lon_values, - lat_values=lat_values, - num_obs=cum_count, - outdir=outdir, - title=title, - filename=filename, - ) +# ============================================================================= +# MULTIPROCESSING - POOL +# ============================================================================= @timeit - def process_data(self, monthly_avg=True, batch_size=100, regrid=""): + def process_data(self, monthly_avg=True, batch_size=100): """ Process the data for the specified year and months. 
@@ -511,6 +326,11 @@ class VIIRS(MAPIES): if monthly_avg: for month in self.monthly_dict.keys(): # print(f'Month: {self.monthly_dict[month]}') + # Catch if they haven't selected a grid + try: + print(f"Running average on {self.grid}") + except AttributeError: + raise Exception('Please specify a grid in the initialisation of MAPIES, e.g. grid_repr="rotated"') try: # Process the data for the current month final_lon, final_lat, final_obs, cumulative_count = self.process_monthly_data( @@ -568,6 +388,88 @@ class VIIRS(MAPIES): print(f"Error processing data: {e}") return + + def process_monthly_data(self, month, batch_size): + """ + Process and compute the 2D monthly average for a given month/list of files using multiprocessing. + + Parameters: + - month: int, the month to process (e.g., 1 for January). + - batch_size: int, the number of files to process in each batch. + + Returns: + - lon_final, lat_final, obs_final: Aggregated longitude, latitude, and observation arrays. + """ + + month_files = self.monthly_dict[month]["files"] + print(f"Processing {len(month_files)} files for month {month}") + + # Split files into batches + batches = [ + month_files[i:i + batch_size] for i in range(0, len(month_files), batch_size) + ] + + # Initialize cumulative arrays + cumulative_obs = None + cumulative_lat = None + cumulative_lon = None + cumulative_count = None + failed_files = [] + + for batch in batches: + print(f"Processing a batch of {len(batch)} files...") + + + # Use multiprocessing to process files within the batch + args = [(self.grid, file, self.obs_var, self.lat_var, self.lon_var, self.apply_qa, self.qualityFlags) for file in batch] + with multiprocessing.Pool(processes=8) as pool: # Adjust number of processes as needed + results = pool.starmap(process_single_file, args) + + # Filter valid results and log failed files + valid_results = [res for res in results if res is not None] + failed_files.extend([file for file, res in zip(batch, results) if res is None]) + + if not valid_results: + print("All files in this batch failed. Skipping...") + continue + + # Aggregate results incrementally + for obs_agg, lat_agg, lon_agg, count_obs in valid_results: + if cumulative_obs is None: + # Initialize cumulative arrays + cumulative_obs = obs_agg * count_obs + cumulative_lat = lat_agg * count_obs + cumulative_lon = lon_agg * count_obs + cumulative_count = count_obs + else: + cumulative_obs += obs_agg * count_obs + cumulative_lat += lat_agg * count_obs + cumulative_lon += lon_agg * count_obs + cumulative_count += count_obs + + if failed_files: + print(f"The following files failed to process: {failed_files}") + + # Check if any valid data was processed + if cumulative_obs is None or cumulative_count is None: + print(f"No valid data processed for month {month}.
Returning empty arrays.") + return None, None, None, None + + # Compute final averages + valid_mask = cumulative_count > 0 + final_obs = np.zeros_like(cumulative_obs) + final_lat = np.zeros_like(cumulative_lat) + final_lon = np.zeros_like(cumulative_lon) + + final_obs[valid_mask] = cumulative_obs[valid_mask] / cumulative_count[valid_mask] + final_lat[valid_mask] = cumulative_lat[valid_mask] / cumulative_count[valid_mask] + final_lon[valid_mask] = cumulative_lon[valid_mask] / cumulative_count[valid_mask] + + print(f"Completed processing for month {month}.") + + return final_lon, final_lat, final_obs, cumulative_count + + @timeit def yearly_average(self): """ @@ -636,84 +538,198 @@ class VIIRS(MAPIES): print("Yearly average computation completed.") # ============================================================================= -# MULTIPROCESSING - POOL +# Supporting functions # ============================================================================= - def process_monthly_data(self, month, batch_size): + @timeit + def to_da(self, frequency="D", save_nc= True, save_figs=False): + """ + Function that returns all the needed variables for the DA """ - Process and compute the 2D monthly average for a given month/list of files using multiprocessing. - Parameters: - - month: int, the month to process (e.g., 1 for January). - - batch_size: int, the number of files to process in each batch. + # Calculate error + self.obserr = error_estimation(self.datatype, self.obs, self.unc_const) - Returns: - - lon_final, lat_final, obs_final: Aggregated longitude, latitude, and observation arrays. - """ + outfiles = [] + + self.frequency = frequency_int(frequency) + da_dates_slice = pd.date_range( + self.start_date, + self.end_date, + freq=frequency + ).strftime('%Y%m%d%H%M') - month_files = self.monthly_dict[month]["files"] - print(f"Processing {len(month_files)} files for month {month}") + for date in da_dates_slice: + l_border = (datetime.strptime(date, "%Y%m%d%H%M") - timedelta(hours=self.frequency/2)).strftime('%Y%m%d%H%M') + r_border = (datetime.strptime(date, "%Y%m%d%H%M") + timedelta(hours=self.frequency/2)).strftime('%Y%m%d%H%M') + # Included by Guillaume in his DA function so I believe he needs this to be true + filename = Path(self.dest).joinpath(f'obs{date}.nc') + # Cut down the time interval again then reindex obs, obserr, lat and lon + # Returning also arrays for type, lev, id and n with the same index as the time var + time_values, frequency_index = time_domain_selection( + self.time_values, + l_border, + r_border, + ) - # Split files into batches - batches = [ - month_files[i:i + batch_size] for i in range(0, len(month_files), batch_size) - ] + # Reindex + obs = self.reindex(frequency_index, self.obs) + lon = self.reindex(frequency_index, self.lon_values) + lat = self.reindex(frequency_index, self.lat_values) + obserr = self.reindex(frequency_index, self.obserr) + # Run aggregation with grid representation + try: + if self.grid_repr == "rotated": + lon, lat, rlon, rlat, obs, obserr, count_obs = regridding_function(self.grid, lon, lat, obs, obserr) + elif self.grid_repr == "regular": + lon, lat, obs, obserr, count_obs = regridding_function(self.grid, lon, lat, obs, obserr) + except ValueError: + raise Exception("Have you specified a grid_repr") - # Initialize cumulative arrays - cumulative_obs = None - cumulative_lat = None - cumulative_lon = None - cumulative_count = None - failed_files = [] + - for batch in batches: - print(f"Processing a batch of {len(batch)} files...") + # Create new arrays 
with same length as obs + obsid = np.full(shape=obs.shape, fill_value=self.obsid_da, dtype=self.int_type) + obstype = np.full(shape=obs.shape, fill_value=30, dtype=self.int_type) + obslev = np.full(shape=obs.shape, fill_value=-99999, dtype=self.int_type) + obsn = np.full(shape=obs.shape, fill_value=count_obs, dtype=self.int_type) + time_values = np.full(shape=obs.shape, fill_value=date, dtype=str) - # Use multiprocessing to process files within the batch - args = [(self.grid, file, self.obs_var, self.lat_var, self.lon_var, self.apply_qa, self.qualityFlags) for file in batch] - with multiprocessing.Pool(processes=8) as pool: # Adjust number of processes as needed - results = pool.starmap(process_single_file, args) + # Coords equals index - # Filter valid results and log failed files - valid_results = [res for res in results if res is not None] - failed_files.extend([file for file, res in zip(batch, results) if res is None]) + assert obs.shape == obserr.shape + assert lon.shape == time_values.shape + coords = dict(index=("index", np.indices(obs.shape)[0, :])) + data_vars = dict( + time=(["index",], time_values), + lon=(["index",], lon), + lat=(["index",], lat), + obs=(["index",], obs), + obserr=(["index", ], obserr), + obsid=(["index", ], obsid), + obstype=(["index", ], obstype), + obslev=(["index", ], obslev), + obsn=(["index", ], obsn), + ) + if self.grid_repr == "rotated": + rotated_dict = dict( + rlon=(["index",], rlon), + rlat=(["index",], rlat), # Also return rlon/rlat values if available + ) + data_vars.update(rotated_dict) - if not valid_results: - print("All files in this batch failed. Skipping...") - continue + + # Save the plots for each time step if specified + if save_figs: + self.plot_2D_obs_custom(lon, lat, obs, outdir=self.dest, filename=f"da_obs_{date}.png") - # Aggregate results incrementally - for obs_agg, lat_agg, lon_agg, count_obs in valid_results: - if cumulative_obs is None: - # Initialize cumulative arrays - cumulative_obs = obs_agg * count_obs - cumulative_lat = lat_agg * count_obs - cumulative_lon = lon_agg * count_obs - cumulative_count = count_obs - else: - cumulative_obs += obs_agg * count_obs - cumulative_lat += lat_agg * count_obs - cumulative_lon += lon_agg * count_obs - cumulative_count += count_obs + ds = self.to_xarray(coords=coords, data_vars=data_vars) + if save_nc: + logger.info(f"Writing DA data to {filename}") + ds.to_netcdf(filename, encoding={}) + outfiles.append(filename) - if failed_files: - print(f"The following files failed to process: {failed_files}") + return outfiles - # Check if any valid data was processed - if cumulative_obs is None or cumulative_count is None: - print(f"No valid data processed for month {month}.
Returning empty arrays.") - return None, None, None, None +# ============================================================================= +# Plotting +# ============================================================================= - # Compute final averages - valid_mask = cumulative_count > 0 - final_obs = np.zeros_like(cumulative_obs) - final_lat = np.zeros_like(cumulative_lat) - final_lon = np.zeros_like(cumulative_lon) - final_obs[valid_mask] = cumulative_obs[valid_mask] / cumulative_count[valid_mask] - final_lat[valid_mask] = cumulative_lat[valid_mask] / cumulative_count[valid_mask] - final_lon[valid_mask] = cumulative_lon[valid_mask] / cumulative_count[valid_mask] + @timeit + def plot_2D_observations(self, months=None, outdir=None): - print(f"Completed processing for month {month}.") - - return final_lon, final_lat, final_obs, cumulative_count + # if 0 in months: + if months == [0]: + try: + self.plot_2D_obs() + except Exception as e: + logger.error(f"Error plotting all data: {e}") + return + elif months is not None: + # elif m in months: + for m in months: + try: + month_data = self.monthly_dict[m] # Access the dictionary for this month + # files = month_data["files"] # List of files for the month (not used for plotting) + lon_values = month_data["lon"] + lat_values = month_data["lat"] + obs_values = month_data["obs"] + except Exception as e: + logger.error(f"Error for month {m}: {e}") + continue + # title = f"Observation 2D plot of {self.datatype.upper()} data from {self.start_date} to {self.end_date}" + title = f"Observation 2D plot of month {m} from {self.start_date} to {self.end_date}" + filename = f"{outdir}/{self.datatype}_2D_obs_month{m}_{self.year}_.png" + + super().plot_2D_obs_custom( + lon_values=lon_values, + lat_values=lat_values, + obs_values=obs_values, + outdir=outdir, + title=title, + filename=filename, + ) + else: + try: + lon_values = self.yearly_lon + lat_values = self.yearly_lat + obs_values = self.yearly_obs + except Exception as e: + logger.error(f"Error for yearly data: {e}") + return + title = f"Observation 2D plot of yearly average of year {self.year} from {self.start_date} to {self.end_date}" + filename = f"{outdir}/{self.datatype}_2D_yearly_obs_{self.year}.png" + super().plot_2D_obs_custom( + lon_values=lon_values, + lat_values=lat_values, + obs_values=obs_values, + outdir=outdir, + title=title, + filename=filename, + ) + + def plot_2D_num_obs(self, months=None, outdir=None): + + # if isinstance(months, int): + # months = [months] + if months is not None: + for m in months: + try: + month_data = self.monthly_dict[m] # Access the dictionary for this month + # files = month_data["files"] # List of files for the month (not used for plotting) + lon_values = month_data["lon"] + lat_values = month_data["lat"] + cum_count = month_data["count"] + except Exception as e: + logger.error(f"Error for month {m}: {e}") + continue + # title = f"Observation 2D plot of {self.datatype.upper()} data from {self.start_date} to {self.end_date}" + title = f"Plot of the number of valid observations for month {m} from {self.start_date} to {self.end_date}" + filename = f"{outdir}/{self.datatype}_2D_obs_count_month{m}_{self.year}_trial4.png" + super().plot_2D_num_obs_custom( + lon_values=lon_values, + lat_values=lat_values, + num_obs=cum_count, + outdir=outdir, + title=title, + filename=filename, + ) + else: + try: + lon_values = self.yearly_lon + lat_values = self.yearly_lat + cum_count = self.yearly_count + except Exception as e: + logger.error(f"Error for yearly data: {e}") + 
return + title = f"Plot of the number of valid observations for yearly average of {self.year} from {self.start_date} to {self.end_date}" + filename = f"{outdir}/{self.datatype}_2D_obs_count_yearly_{self.year}.png" + super().plot_2D_num_obs_custom( + lon_values=lon_values, + lat_values=lat_values, + num_obs=cum_count, + outdir=outdir, + title=title, + filename=filename, + ) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index aadd1aa32fba3f5bc38e4b0d5aff0105840edd48..f46a0d4621ff18dc7cc7eab641f775cfec0830f9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ + attrs==24.2.0 Cartopy==0.22.0 certifi==2024.8.30 @@ -59,4 +60,4 @@ tzdata==2024.2 urllib3==2.2.3 xarray==2023.1.0 zict==3.0.0 -zipp==3.20.2 +zipp==3.20.2 \ No newline at end of file diff --git a/run/run_viirs.py b/run/run_viirs.py index b622c73be1bcf7dca9ea6d008eaf3af36936a1be..8e74881d4a999a87a2a63fe88216807908cfb64b 100644 --- a/run/run_viirs.py +++ b/run/run_viirs.py @@ -8,14 +8,12 @@ if __name__ == "__main__": start_date = "202401010000" - end_date = "202401020000" - outdir="/esarchive/scratch/cmeikle/" - indir="/esarchive/obs/nasa/viirs_noaa20_aerdb_l2/original_files/VIIRS" - c = VIIRS(start_date, end_date, frequency="D", dest=outdir, indir=indir, apply_qa=True, grid_repr="regular") + end_date = "202401040000" + outdir="/home/cmeikle/Projects/data/VIIRS/" + indir="/home/cmeikle/Projects/data/VIIRS" + c = VIIRS(start_date, end_date, dest=outdir, indir=indir, apply_qa=True, grid_repr="rotated") c.read_nc() - c.process_data(monthly_avg=False, batch_size = 10) - #c.plot_2D_obs() - c.plot_2D_observations(months=[0], outdir=outdir) - c.plot_2D_num_obs(months=[0], outdir=outdir) - - + c.process_data(monthly_avg=True, batch_size = 20) + #c.to_da(frequency="D", save_figs=True) + #c.plot_2D_observations(months=[1], outdir=outdir) + #c.plot_2D_num_obs(months=[1], outdir=outdir) \ No newline at end of file diff --git a/setup.py b/setup.py index b147a08a39fef5086f5f11e6d4a98dc1949fd394..e954ab7820750e018b348e6fa924715c6924d8ff 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,7 @@ from setuptools import find_packages from setuptools import setup # Could update this using versioneer -version="0.0.8" +version="0.0.9" setup( name="mapies",
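A minimal usage sketch of the pipeline as it stands after this diff, distilled from run/run_viirs.py and mapies/tests/test_viirs.py above. The paths are placeholders, and frequency is now passed to to_da() rather than to the VIIRS constructor:

#!/usr/bin/env python
# Usage sketch (placeholder paths): read, process and write DA-ready output.
from mapies.viirs import VIIRS

start_date = "202401010000"
end_date = "202401040000"
outdir = "/path/to/output"  # must already exist, otherwise VIIRS raises at init
indir = "/path/to/VIIRS"    # directory holding the AERDB_L2 netCDF files

c = VIIRS(start_date, end_date, dest=outdir, indir=indir,
          apply_qa=True, grid_repr="rotated")  # grid_repr selects the grid object
c.read_nc()                                    # build the list of files to process
c.process_data(monthly_avg=False, batch_size=20)

# One DA-ready netCDF file per window centred on each date in the range
outfiles = c.to_da(frequency="D", save_nc=True, save_figs=False)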
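The module-level regridding_function in mapies/grids/monarch.py returns tuples of different length depending on the grid type and on whether obserr is passed, as the call sites in viirs.py show. Below is a hypothetical wrapper (not part of this diff) that names those fields so callers need not track the tuple arity by hand:

from mapies.grids.monarch import RotatedGrid, regridding_function

def regrid_to_dict(grid, lon, lat, obs, obserr=None):
    # Label the variable-length tuple returned by regridding_function.
    out = regridding_function(grid, lon, lat, obs, obserr)
    keys = ["lon", "lat"]
    if isinstance(grid, RotatedGrid):
        keys += ["rlon", "rlat"]  # rotated grids also return rotated coordinates
    keys.append("obs")
    if obserr is not None:
        keys.append("obserr")     # obserr_agg is only returned when obserr is given
    keys.append("count_obs")
    return dict(zip(keys, out))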