diff --git a/mapies/tests/test_viirs.py b/mapies/tests/test_viirs.py
index 046b4f3fc658e321ef29d2be9d53951bb6234b3c..ce2e87fec0c911bd8c3996b6aede9df0264fcab8 100644
--- a/mapies/tests/test_viirs.py
+++ b/mapies/tests/test_viirs.py
@@ -11,7 +11,7 @@ from mapies.viirs import VIIRS
 
 def test_to_da():
     start_date = "202401010000"
-    end_date = "202401010030"
+    end_date = "202401011630"
     outdir="./"
     indir="test_data"
     c = VIIRS(start_date, end_date, dest=outdir, indir=indir, grid_repr="rotated")
diff --git a/mapies/util/func_tools.py b/mapies/util/func_tools.py
index aec2f8d67ebf4159226829d4989001281356cb61..fe93b21ec3b8a4e3449b7db92dfd8fe7d2235ef6 100644
--- a/mapies/util/func_tools.py
+++ b/mapies/util/func_tools.py
@@ -11,7 +11,7 @@ from typing import List
 from datetime import datetime
 import multiprocessing
 
-
+TIME_ORIGIN = np.datetime64("1993-01-01T00:00:00")
 
 #Timing decorator
 def timeit(func):
@@ -82,7 +82,7 @@ def time_converter(date:str) -> pd.Timestamp:
     return date
 
 
-def quality_assurance(ds, obs, lat, lon, qualityFlags):
+def quality_assurance(ds, obs, lat, lon, time_values, qualityFlags):
     """
     Applying quality assurance flags to mapies data
     """
@@ -106,9 +106,10 @@ def quality_assurance(ds, obs, lat, lon, qualityFlags):
         obs = obs[mask]
         lat = lat[mask]
         lon = lon[mask]
+        time_values = time_values[mask]
 
-    return obs, lat, lon
+    return obs, lat, lon, time_values
 
 
 # Time Domain Selection function based off of xarray Will this work with dask too, a question for later
 def time_domain_selection(
@@ -218,3 +219,94 @@ def geo_to_rot(lons, lats, centre_lon: float, centre_lat: float):
     rlon = np.degrees(rlon)
     rlat = np.degrees(rlat)
     return rlon, rlat
+
+def new_preprocess_vars(ds, obs_var, lat_var, lon_var, time_var, start_date, end_date):
+
+    """
+    Preprocessing of the dataset
+    """
+    time_values, time_values_index = preprocess_vars_mapies(ds, time_var, start_date, end_date)
+
+
+    # Lon and lat values
+    lon_dims = ds[lon_var].dims
+    lon_shape = ds[lon_var].shape
+    lon_attrs = ds[lon_var].attrs
+    lon_values = ds[lon_var].values.flatten()
+
+    # Reindex lon column values
+    lon_values = reindex(
+        time_values_index,
+        lon_values,
+    )
+
+    lat_dims = ds[lat_var].dims
+    lat_shape = ds[lat_var].shape
+    lat_attrs = ds[lat_var].attrs
+    lat_values = ds[lat_var].values.flatten()
+
+    # Reindex lat column values
+    lat_values = reindex(
+        time_values_index,
+        lat_values,
+    )
+
+
+    # AOD column values, default "Aerosol_Optical_Thickness_550_Land_Ocean_Best_Estimate"
+    aod_dims = ds[obs_var].dims
+
+    aod_shape = ds[obs_var].shape
+    aod_attrs = ds[obs_var].attrs
+    obs = ds[obs_var].values.flatten()
+
+    # Reindex aod column values
+    obs = reindex(
+        time_values_index,
+        obs,
+    )
+
+    return obs, lat_values, lon_values, time_values, time_values_index
+
+def preprocess_vars_mapies(ds, time_var, start_date, end_date):
+    """
+    Preprocessing of the dataset
+    """
+    # Get all info about time columns
+    time_dims = ds[time_var].dims
+    time_shape = ds[time_var].shape
+    time_attrs = ds[time_var].attrs  # Useful if we need to convert to datetime
+
+
+    # Get time values, flatten and convert to datetime values
+    time_values = ds[time_var].values
+    time_values = time_values.flatten()
+
+    if time_values.dtype == "timedelta64[ns]":
+        logging.info("Adding time origin to time values as the time variable is in timedelta")
+        time_values = np.add(TIME_ORIGIN, time_values)
+        time_values = pd.to_datetime(time_values)
+    else:
+        time_values = pd.to_datetime(time_values)
+
+
+    # Time domain selection
+    # print(f'Time values before selection: {time_values}')
+    time_values, time_values_index = time_domain_selection(time_values, start_date, end_date)
+    # print(f'Time values after selection: {time_values}')
+
+    return time_values, time_values_index
+
+def reindex(
+    dependent_var_index,
+    independent_var_values,
+    ):
+    """
+    Re-cutting the data whenever a selection of the data has been made along one of the dimensions
+
+    Based off of how it has been done in Providentia
+    """
+    # Maybe add a checker try/except or assert
+    independent_var_values = independent_var_values[dependent_var_index]
+
+    # return reindexed values
+    return independent_var_values
\ No newline at end of file
diff --git a/mapies/viirs.py b/mapies/viirs.py
index 7b066dd4f8ef59f90f64cf53d8de6fc78d7f6ac4..a5fff213ca9305ec46e6c6caf1d4bd319a6d34c7 100644
--- a/mapies/viirs.py
+++ b/mapies/viirs.py
@@ -38,62 +38,35 @@ TIME_ORIGIN = np.datetime64("1900-01-01T00:00:00")
 
 
 
-
-
-def process_single_file(r, file, obs_var, lat_var, lon_var, apply_qa, qualityFlags):
+def process_single_file(r, file, obs_var, lat_var, lon_var, time_var, start_date, end_date, apply_qa, qualityFlags, flag=False):
     """
     Process a single file and return aggregated results.
     """
-    ds = xr.open_dataset(file, engine="h5netcdf") # Open the file
-    obs = ds[obs_var].values.flatten()
-    lat = ds[lat_var].values.flatten()
-    lon = ds[lon_var].values.flatten()
-
-    if apply_qa:
-        obs, lat, lon = quality_assurance(ds, obs, lat, lon, qualityFlags)
-
-    # Regrid and aggregate
-    if isinstance(r, RotatedGrid):
-        lon_agg, lat_agg, rlon_agg, rlat_agg, obs_agg, count_obs = regridding_function(r, lon, lat, obs)
-    elif isinstance(r, RegularGrid):
-        lon_agg, lat_agg, obs_agg, count_obs = regridding_function(r, lon, lat, obs)
-
-    ds.close() # Close the file to free memory
-    return obs_agg, lat_agg, lon_agg, count_obs
-
-
-def process_batch(r, batch_files, obs_var, lat_var, lon_var, chunk_size, qualityFlags):
-    """
-    Process a batch of files to calculate cumulative sums and counts.
-
-    Parameters:
-    - batch_files: List of file paths in the batch.At the beginning I read the .csv files (I always save data in .csv files)
-    - obs_var, lat_var, lon_var: Variable names for observations, latitude, and longitude.
-    - rotate_func: Function to apply rotation and aggregation.
-
-    Returns:
-    - cumulative_obs, cumulative_lat, cumulative_lon, cumulative_count: Aggregated results for the batch.
-    """
-    def preprocess(ds):
-        """Add a unique dimension to distinguish datasets."""
-        return ds.expand_dims(dim={"index": [ds.attrs["product_name"]]})
-
-    ds = xr.open_mfdataset(
-        batch_files,
-        chunks={"index": chunk_size},
-        parallel=True,
-        engine="h5netcdf",
-        preprocess=preprocess,
-    )
-
-    obs = ds[obs_var].values.flatten()
-    lat = ds[lat_var].values.flatten()
-    lon = ds[lon_var].values.flatten()
-
-    lon_agg, lat_agg, obs_agg, count_obs = regridding_function(r,lon, lat, obs)
+    try:
+        ds = xr.open_dataset(file, engine="h5netcdf") # Open the file
-    ds.close()
+        obs, lat, lon, time_values, time_values_index = new_preprocess_vars(ds, obs_var, lat_var, lon_var, time_var, start_date, end_date)
+        # obs = ds[obs_var].values.flatten()
+        # lat = ds[lat_var].values.flatten()
-    return obs_agg, lat_agg, lon_agg, count_obs
+        if apply_qa:
+            obs, lat, lon, time_values = quality_assurance(ds, obs, lat, lon, time_values, qualityFlags)
+
+        if not flag:
+            # Regrid and aggregate
+            if isinstance(r, RotatedGrid):
+                lon_agg, lat_agg, rlon_agg, rlat_agg, obs_agg, count_obs = regridding_function(r, lon, lat, obs)
+            elif isinstance(r, RegularGrid):
+                lon_agg, lat_agg, obs_agg, count_obs = regridding_function(r, lon, lat, obs)
+            ds.close()
+            return obs_agg, lat_agg, lon_agg, count_obs
+        else:
+            ds.close()
+            return obs, lat, lon, time_values
+
+    except Exception as e:
+        print(f"Error processing file {file}: {e}")
+        return None
 
 
 def filter_file(args):
@@ -133,6 +106,7 @@ class VIIRS(MAPIES):
         self.start_day = int(start_date[6:8])
         self.end_day = int(end_date[6:8])
         self.monthly_dict = {}
+        self.daily_dict = {}  # Dictionary to produce daily netcdf files
 
         self.dates_slice = pd.date_range(
             self.start_date,
            self.end_date,
@@ -258,13 +232,6 @@ class VIIRS(MAPIES):
 
         files = sorted(get_file_list(file_patterns))
         print(f"Total number of files: {len(files)}")
-
-        start_dt = datetime.strftime(self.start_date, "%Y%m%d%H%M")
-        end_dt = datetime.strftime(self.end_date, "%Y%m%d%H%M")
-
-        start_dt = datetime.strptime(start_dt, "%Y%m%d%H%M")
-        end_dt = datetime.strptime(end_dt, "%Y%m%d%H%M")
-
         files_filtered = []
         for file in files:
             parts = file.split('.')
@@ -272,7 +239,7 @@ class VIIRS(MAPIES):
             hhmm = parts[2]
             file_dt = datetime.strptime(date + hhmm, "%Y%j%H%M")
 
-            if start_dt <= file_dt <= end_dt:
+            if self.start_date <= file_dt <= self.end_date:
                 files_filtered.append(file)
 
 
@@ -295,22 +262,35 @@ class VIIRS(MAPIES):
                     "obs": None,    # Observations array
                    "count": None,  # Count of observations used to compute the average
                 }
+            if date.day not in self.daily_dict:
+                self.daily_dict[date.day] = {
+                    "files": [],   # List to store file paths
+                    "lon": None,   # Longitudes array
+                    "lat": None,   # Latitudes array
+                    "obs": None,   # Observations array
+                    "time": None,  # Scan Start Time array
+                }
 
             # Append the file to the list of files for the corresponding month
             self.monthly_dict[date.month]["files"].append(file)
+            self.daily_dict[date.day]["files"].append(file)
 
         print(f"Discovered files for months: {list(self.monthly_dict.keys())}")
         for month, data in self.monthly_dict.items():
            print(f" Month {month:02d}: {len(data['files'])} files")
 
+        print(f"Discovered files for days: {list(self.daily_dict.keys())}")
+
+        # Print the daily dictionary to see if files are stored in chronological order
+        for day, data in self.daily_dict.items():
+            print(f"Day {day:02d}: {len(data['files'])} files")
+            # for filepath in data['files']:
+            #     print(f" {filepath}")
 
-# =============================================================================
-# MULTIPROCESSING - POOL
-# =============================================================================
+
     @timeit
-    def process_data(self, monthly_avg=True, batch_size=100):
-
+    def process_data(self, monthly_avg=False, batch_size=100):
         """
         Process the data for the specified year and months.
@@ -334,7 +314,7 @@ class VIIRS(MAPIES):
                 try:
                     # Process the data for the current month
                     final_lon, final_lat, final_obs, cumulative_count = self.process_monthly_data(
-                        month, batch_size=batch_size
+                        batch_size=batch_size, month=month
                     )
 
                    # final_lon, final_lat, final_obs, cumulative_count = self.process_files_in_blocks(month, block_size=300, batch_size=batch_size)
@@ -358,51 +338,32 @@ class VIIRS(MAPIES):
                     print(f"Error processing data for month {month}: {e}")
                     continue
         else:
-            # Check if the number of files is too big
-            if isinstance(self.start_date, str):
-                start_dt = datetime.strptime(self.start_date, "%Y%m%d%H%M")
-            else:
-                start_dt = self.start_date
-
-            if isinstance(self.end_date, str):
-                end_dt = datetime.strptime(self.end_date, "%Y%m%d%H%M")
-            else:
-                end_dt = self.end_date
-
-            time_difference = (end_dt - start_dt).days
-
-            if time_difference > 3:
-                raise ValueError(f"The time difference between start_date ({self.start_date}) "
-                                 f"and end_date ({self.end_date}) exceeds 3 days. Under development.")
-
-            def preprocess(ds):
-                ds = ds.expand_dims(dim={"index": [ds.attrs["product_name"]]})
-                return ds
-
-            try:
-                print(f"Processing {len(self.files)} files")
-                self.ds = xr.open_mfdataset(self.files, preprocess=preprocess)
-                if self.ds is None:
-                    raise ValueError("Failed to initialize dataset `self.ds`.")
-                self.preprocess_vars()
+            try:
+                final_lon, final_lat, final_obs, cumulative_count = self.process_monthly_data(
+                    batch_size=batch_size, month=None
+                )
+                self.lon_values = final_lon
+                self.lat_values = final_lat
+                self.obs = final_obs
+                self.count_obs = cumulative_count
             except Exception as e:
-                print(f"Error processing data: {e}")
+                print(f"Error processing data for the time period from {self.start_date} to {self.end_date}: {e}")
                 return
 
+    def process_monthly_data(self, batch_size, month=None):
 
-    def process_monthly_data(self, month, batch_size):
         """
-        Process and compute the 2D monthly average for a given month/list of files using multiprocessing.
-
-        Parameters:
-        - month: int, the month to process (e.g., 1 for January).
-        - batch_size: int, the number of files to process in each batch.
-
-        Returns:
-        - lon_final, lat_final, obs_final: Aggregated longitude, latitude, and observation arrays.
+        Function that returns all the needed variables for the DA
         """
+        # r = RotatedGrid(centre_lon=20, centre_lat=35, dlon=.1, dlat=.1, west=-51, south=-35)
+        flag = False
+        if month is None:
+            month_files = self.files
+            print(f'Processing all files from {self.start_date} to {self.end_date}')
+        else:
+            month_files = self.monthly_dict[month]["files"]
+            print(f"Processing {len(month_files)} files for month {month}")
-
-        month_files = self.monthly_dict[month]["files"]
-        print(f"Processing {len(month_files)} files for month {month}")
 
         # Split files into batches
         batches = [
@@ -418,10 +379,11 @@ class VIIRS(MAPIES):
 
         for batch in batches:
             print(f"Processing a batch of {len(batch)} files...")
-
+            print(f"Batch: {batch}")
             # Use multiprocessing to process files within the batch
-            args = [(self.grid, file, self.obs_var, self.lat_var, self.lon_var, self.apply_qa, self.qualityFlags) for file in batch]
+            args = [(self.grid, file, self.obs_var, self.lat_var, self.lon_var, self.time_var, self.start_date, self.end_date, self.apply_qa, self.qualityFlags, flag) for file in batch]
+
             with multiprocessing.Pool(processes=8) as pool:  # Adjust number of processes as needed
                 results = pool.starmap(process_single_file, args)
@@ -469,7 +431,6 @@ class VIIRS(MAPIES):
 
         return final_lon, final_lat, final_obs, cumulative_count
 
-
     @timeit
     def yearly_average(self):
         """
@@ -537,10 +498,86 @@ class VIIRS(MAPIES):
 
         print("Yearly average computation completed.")
 
+    def process_lazy_data(self):
+        """
+        Process the data for the specified time range. It only retrieves a daily dictionary using dask lazy loading.
+        """
+        flag = True
+        for day in self.daily_dict.keys():
+            try:
+                files = self.daily_dict[day]["files"]
+                args = [
+                    (
+                        self.grid,
+                        file,
+                        self.obs_var,
+                        self.lat_var,
+                        self.lon_var,
+                        self.time_var,
+                        self.start_date,
+                        self.end_date,
+                        self.apply_qa,
+                        self.qualityFlags,
+                        flag
+                    )
+                    for file in files
+                ]
+
+                daily_obs = []
+                daily_lat = []
+                daily_lon = []
+                daily_time = []
+
+                with multiprocessing.Pool(processes=8) as pool:  # Adjust number of processes as needed
+                    results = pool.starmap(process_single_file, args)
+
+                for obs, lat, lon, time_values in results:
+                    if len(obs) > 0:  # Skip empty arrays
+                        daily_lon.append(lon)
+                        daily_lat.append(lat)
+                        daily_obs.append(obs)
+                        daily_time.append(time_values)
+
+                # Skip saving if all arrays are empty
+                if not daily_obs:
+                    print(f"No valid data for day {day}")
+                    continue
+
+                # Concatenate all data in one step
+                final_obs = np.concatenate(daily_obs, axis=0)
+                final_lon = np.concatenate(daily_lon, axis=0)
+                final_lat = np.concatenate(daily_lat, axis=0)
+                final_time = np.concatenate(daily_time, axis=0)
+
+                # Check if any valid data exists
+                if np.all(final_obs == 0) or np.all(np.isnan(final_obs)):
+                    print(f"All observation values for day {day} are 0 or NaN. Skipping...")
+                    continue
+                # print first 100 values of final_obs
+                print(f"First 100 values of final_obs: {final_obs[:100]}")
+                print(f"First 100 values of final_lon: {final_lon[:100]}")
+                print(f"First 100 values of final_lat: {final_lat[:100]}")
+
+                # Save directly to NetCDF
+                ds = xr.Dataset(
+                    coords={"time": final_time},
+                    data_vars={
+                        "lon": (["time"], final_lon),
+                        "lat": (["time"], final_lat),
+                        "obs": (["time"], final_obs),
+                    },
+                )
+                filename = f"{self.dest}/daily_{day}.nc"
+                ds.to_netcdf(filename, encoding={})
+
+            except Exception as e:
+                print(f"Error processing data for day {day}: {e}")
+                continue
+
+
 # =============================================================================
 # Supporting functions
 # =============================================================================
-    @timeit
     def to_da(self, frequency="D", save_nc= True, save_figs=False):
         """
@@ -630,24 +667,42 @@ class VIIRS(MAPIES):
             outfiles.append(filename)
 
         return outfiles
+
+
+
+
+
+
+
+
 # =============================================================================
 # Plotting
 # =============================================================================
-    @timeit
-    def plot_2D_observations(self, months=None, outdir=None):
+    def plot_2D_observations(self, months=[0], outdir=None):
 
-        # if 0 in months:
         if months == [0]:
             try:
-                self.plot_2D_obs()
+                lon_values = self.lon_values
+                lat_values = self.lat_values
+                obs_values = self.obs
             except Exception as e:
-                logger.error(f"Error plotting all data: {e}")
+                print(f"Error plotting all data: {e}")
                 return
+            title = f"Observation 2D plot from {self.start_date} to {self.end_date}"
+            filename = f"{outdir}/{self.datatype}_2D_obs_from_{self.start_date}_to_{self.end_date}_dada.png"
+            super().plot_2D_obs_custom(
+                lon_values=lon_values,
+                lat_values=lat_values,
+                obs_values=obs_values,
+                outdir=outdir,
+                title=title,
+                filename=filename,
+            )
+            return
         elif months is not None:
-            # elif m in months:
             for m in months:
                 try:
                     month_data = self.monthly_dict[m]  # Access the dictionary for this month
@@ -656,11 +711,11 @@ class VIIRS(MAPIES):
                     lat_values = month_data["lat"]
                     obs_values = month_data["obs"]
                 except Exception as e:
-                    logger.error(f"Error for month {m}: {e}")
+                    print(f"Error for month {m}: {e}")
                     continue
                # title = f"Observation 2D plot of {self.datatype.upper()} data from {self.start_date} to {self.end_date}"
                 title = f"Observation 2D plot of month {m} from {self.start_date} to {self.end_date}"
-                filename = f"{outdir}/{self.datatype}_2D_obs_month{m}_{self.year}_.png"
+                filename = f"{outdir}/{self.datatype}_2D_obs_month{m}_{self.year}_dada.png"
 
                 super().plot_2D_obs_custom(
                     lon_values=lon_values,
@@ -676,7 +731,7 @@ class VIIRS(MAPIES):
                 lat_values = self.yearly_lat
                 obs_values = self.yearly_obs
             except Exception as e:
-                logger.error(f"Error for yearly data: {e}")
+                print(f"Error for yearly data: {e}")
                 return
             title = f"Observation 2D plot of yearly average of year {self.year} from {self.start_date} to {self.end_date}"
             filename = f"{outdir}/{self.datatype}_2D_yearly_obs_{self.year}.png"
@@ -689,11 +744,28 @@ class VIIRS(MAPIES):
                 filename=filename,
             )
 
-    def plot_2D_num_obs(self, months=None, outdir=None):
+    def plot_2D_num_obs(self, months=[0], outdir=None):
 
-        # if isinstance(months, int):
-        #     months = [months]
-        if months is not None:
+        if months == [0]:
+            try:
+                lon_values = self.lon_values
+                lat_values = self.lat_values
+                cum_count = self.count_obs
+            except Exception as e:
+                print(f"Error plotting all data: {e}")
+                return
+            title = f"Plot of the number of valid observations from {self.start_date} to {self.end_date}"
+            filename = f"{outdir}/{self.datatype}_2D_obs_count_from_{self.start_date}_to_{self.end_date}_dada.png"
+            super().plot_2D_num_obs_custom(
+                lon_values=lon_values,
+                lat_values=lat_values,
+                num_obs=cum_count,
+                outdir=outdir,
+                title=title,
+                filename=filename,
+            )
+            return
+        elif months is not None:
             for m in months:
                 try:
                     month_data = self.monthly_dict[m]  # Access the dictionary for this month
@@ -702,11 +774,11 @@ class VIIRS(MAPIES):
                     lat_values = month_data["lat"]
                     cum_count = month_data["count"]
                 except Exception as e:
-                    logger.error(f"Error for month {m}: {e}")
+                    print(f"Error for month {m}: {e}")
                     continue
                # title = f"Observation 2D plot of {self.datatype.upper()} data from {self.start_date} to {self.end_date}"
                 title = f"Plot of the number of valid observations for month {m} from {self.start_date} to {self.end_date}"
-                filename = f"{outdir}/{self.datatype}_2D_obs_count_month{m}_{self.year}_trial4.png"
+                filename = f"{outdir}/{self.datatype}_2D_obs_count_month{m}_{self.year}_dada.png"
                 super().plot_2D_num_obs_custom(
                     lon_values=lon_values,
                     lat_values=lat_values,
@@ -721,7 +793,7 @@ class VIIRS(MAPIES):
                 lat_values = self.yearly_lat
                 cum_count = self.yearly_count
             except Exception as e:
-                logger.error(f"Error for yearly data: {e}")
+                print(f"Error for yearly data: {e}")
                 return
             title = f"Plot of the number of valid observations for yearly average of {self.year} from {self.start_date} to {self.end_date}"
             filename = f"{outdir}/{self.datatype}_2D_obs_count_yearly_{self.year}.png"
             super().plot_2D_num_obs_custom(
                 lon_values=lon_values,
                 lat_values=lat_values,
                 outdir=outdir,
                 title=title,
                 filename=filename,
-        )
\ No newline at end of file
+        )
diff --git a/run/run_viirs.py b/run/run_viirs.py
index 8e74881d4a999a87a2a63fe88216807908cfb64b..d6d359b728e5f27bd73894ac0ba620d7544edd2f 100644
--- a/run/run_viirs.py
+++ b/run/run_viirs.py
@@ -5,15 +5,40 @@ import time
 
 
 if __name__ == "__main__":
     # Change these variables
+    # 00:30 to 02:36 from 2024-01-01 to 2024-01-03
+    # start_date = "202401010030"
+    # end_date = "202401030236"
+    # 00:30 to 16:36 of 2024-01-01
+    start_date = "202401010030"
+    end_date = "202401011636"
-    start_date = "202401010000"
-    end_date = "202401040000"
-    outdir="/home/cmeikle/Projects/data/VIIRS/"
-    indir="/home/cmeikle/Projects/data/VIIRS"
-    c = VIIRS(start_date, end_date, dest=outdir, indir=indir, apply_qa=True, grid_repr="rotated")
+    # ONE YEAR
+    # start_date = "202401300000"
+    # end_date = "202412312359"
+    # end_date = "202402022359"
+
+    # ONE MONTH - JANUARY 2024
+    # start_date = "202401010000"
+    # end_date = "202401312359"
+
+    # outdir="/home/cmeikle/Projects/data/"
+    # indir="/home/cmeikle/Projects/data/VIIRS/original_files/AERDB_L2_VIIRS_NOAA20"
+    outdir = "/home/cgile/Documents/mapies/figures"
+    # indir = "/home/cgile/bscearth000/esarchive/obs/nasa/viirs_noaa20_aerdb_l2/original_files/VIIRS"
+    indir = "/home/cgile/Documents/mapies/VIIRS"
+
+    start_time = time.time()
+
+    c = VIIRS(start_date, end_date, frequency="D", dest=outdir, indir=indir, apply_qa=False, grid_repr="rotated")
     c.read_nc()
-    c.process_data(monthly_avg=True, batch_size = 20)
-    #c.to_da(frequency="D", save_figs=True)
-    #c.plot_2D_observations(months=[1], outdir=outdir)
-    #c.plot_2D_num_obs(months=[1], outdir=outdir)
\ No newline at end of file
+    c.process_data(monthly_avg=False, batch_size = 100)
+    # c.yearly_average()
+    c.plot_2D_observations(months=[0], outdir=outdir)
+    c.plot_2D_num_obs(months=[0], outdir=outdir)
+    c.process_lazy_data()
+    end_time = time.time()
+    elapsed_time = end_time - start_time
+
+    print(f"Script executed in {elapsed_time:.2f} seconds.")
+