From 75bbbe152be2da9381ebedc6ec6fbdd3232ba63a Mon Sep 17 00:00:00 2001 From: Carles Tena Date: Thu, 7 Jul 2022 17:05:14 +0200 Subject: [PATCH] Horizontal interpolation working for Y axis --- nes/interpolation/horizontal_interpolation.py | 100 ++++++++++++++++-- 1 file changed, 89 insertions(+), 11 deletions(-) diff --git a/nes/interpolation/horizontal_interpolation.py b/nes/interpolation/horizontal_interpolation.py index ab6d774..d9b4eb0 100644 --- a/nes/interpolation/horizontal_interpolation.py +++ b/nes/interpolation/horizontal_interpolation.py @@ -24,7 +24,13 @@ def interpolate_horizontal(self, dst_grid, weight_matrix_path=None, kind='Neares """ # Obtain weight matrix - weights, idx = get_weights_idx(self, dst_grid, weight_matrix_path, kind, n_neighbours) + if self.parallel_method == 'T': + weights, idx = get_weights_idx_t_axis(self, dst_grid, weight_matrix_path, kind, n_neighbours) + elif self.parallel_method in ['Y', 'X']: + weights, idx = get_weights_idx_y_axis(self, dst_grid, weight_matrix_path, kind, n_neighbours) + else: + raise NotImplemented("Parallel method {0} is not implemented yet for horizontal interpolations. Use 'T'".format( + self.parallel_method)) # Apply weights final_dst = dst_grid.copy() @@ -48,19 +54,85 @@ def interpolate_horizontal(self, dst_grid, weight_matrix_path=None, kind='Neares # Creating empty data final_dst.variables[var_name]['data'] = np.empty(dst_shape) - src_data = var_info['data'].reshape((src_shape[0], src_shape[1], src_shape[2] * src_shape[3])) + # src_data = var_info['data'].reshape((src_shape[0], src_shape[1], src_shape[2] * src_shape[3])) for time in range(dst_shape[0]): for lev in range(dst_shape[1]): - src_aux = np.take(src_data[time, lev], idx) + src_aux = get_src_data(self.comm, var_info['data'][time, lev], idx, self.parallel_method) + # src_aux = np.take(src_data[time, lev], idx) final_dst.variables[var_name]['data'][time, lev] = np.sum(weights * src_aux, axis=1) return final_dst -def get_weights_idx(self, dst_grid, weight_matrix_path, kind, n_neighbours): - if self.parallel_method != 'T': - raise NotImplemented("Parallel method {0} is not implemented yet for horizontal interpolations. Use 'T'".format( - self.parallel_method)) +def get_src_data(comm, var_data, idx, parallel_method): + if parallel_method == 'T': + var_data = var_data.flatten() + else: + var_data = comm.gather(var_data, root=0) + if comm.Get_rank() == 0: + if parallel_method == 'Y': + axis = 0 + elif parallel_method == 'X': + axis = 1 + else: + raise NotImplementedError(parallel_method) + var_data = np.concatenate(var_data, axis=axis) + var_data = var_data.flatten() + + var_data = comm.bcast(var_data) + + var_data = np.take(var_data, idx) + return var_data + + +def get_weights_idx_t_axis(self, dst_grid, weight_matrix_path, kind, n_neighbours): + if weight_matrix_path is not None: + with FileLock(weight_matrix_path.replace('.nc', '.lock')): + if self.master: + if os.path.isfile(weight_matrix_path): + weight_matrix = read_weight_matrix(weight_matrix_path, comm=MPI.COMM_SELF) + if len(weight_matrix.lev['data']) != n_neighbours: + warn("The selected weight matrix does not have the same number of nearest neighbours." + + "Re-calculating again but not saving it.") + if kind in ['NearestNeighbour', 'NearestNeighbours', 'nn', 'NN']: + weight_matrix = create_nn_weight_matrix(self, dst_grid, n_neighbours=n_neighbours) + else: + raise NotImplementedError(kind) + else: + if kind in ['NearestNeighbour', 'NearestNeighbours', 'nn', 'NN']: + weight_matrix = create_nn_weight_matrix(self, dst_grid, n_neighbours=n_neighbours) + else: + raise NotImplementedError(kind) + if weight_matrix_path is not None: + weight_matrix.to_netcdf(weight_matrix_path) + else: + weight_matrix = None + else: + if kind in ['NearestNeighbour', 'NearestNeighbours', 'nn', 'NN']: + if self.master: + weight_matrix = create_nn_weight_matrix(self, dst_grid, n_neighbours=n_neighbours) + else: + weight_matrix = None + else: + raise NotImplementedError(kind) + + # Normalize to 1 + if self.master: + weights = np.array(np.array(weight_matrix.variables['inverse_dists']['data'], dtype=np.float64) / + np.array(weight_matrix.variables['inverse_dists']['data'], dtype=np.float64).sum(axis=1), + dtype=np.float64) + idx = np.array(weight_matrix.variables['idx']['data'][0], dtype=int) + else: + weights = None + idx = None + + weights = self.comm.bcast(weights, root=0) + idx = self.comm.bcast(idx, root=0) + + return weights, idx + + +def get_weights_idx_y_axis(self, dst_grid, weight_matrix_path, kind, n_neighbours): if weight_matrix_path is not None: with FileLock(weight_matrix_path.replace('.nc', '.lock')): if self.master: @@ -102,7 +174,13 @@ def get_weights_idx(self, dst_grid, weight_matrix_path, kind, n_neighbours): idx = None weights = self.comm.bcast(weights, root=0) + weights = weights[:, :, dst_grid.write_axis_limits['y_min']:dst_grid.write_axis_limits['y_max'], + dst_grid.write_axis_limits['x_min']:dst_grid.write_axis_limits['x_max']] idx = self.comm.bcast(idx, root=0) + idx = idx[:, dst_grid.write_axis_limits['y_min']:dst_grid.write_axis_limits['y_max'], + dst_grid.write_axis_limits['x_min']:dst_grid.write_axis_limits['x_max']] + + print(weights.shape, idx.shape) return weights, idx @@ -119,16 +197,16 @@ def create_nn_weight_matrix(self, dst_grid, n_neighbours=4, info=False): print("\tCreating Nearest Neighbour Weight Matrix with {0} neighbours".format(n_neighbours)) sys.stdout.flush() # Source - src_lat = np.array(self.lat['data'], dtype=np.float32) - src_lon = np.array(self.lon['data'], dtype=np.float32) + src_lat = np.array(self._lat['data'], dtype=np.float32) + src_lon = np.array(self._lon['data'], dtype=np.float32) # 1D to 2D coordinates if len(src_lon.shape) == 1: src_lon, src_lat = np.meshgrid(src_lon, src_lat) # Destination - dst_lat = np.array(dst_grid.lat['data'], dtype=np.float32) - dst_lon = np.array(dst_grid.lon['data'], dtype=np.float32) + dst_lat = np.array(dst_grid._lat['data'], dtype=np.float32) + dst_lon = np.array(dst_grid._lon['data'], dtype=np.float32) # 1D to 2D coordinates if len(dst_lon.shape) == 1: -- GitLab