diff --git a/earthdiagnostics/ocean/psi.py b/earthdiagnostics/ocean/psi.py index 24d8b322f566277e3c2f433a43381e728eaff2c5..e1624b6a88a35e65d4fc8facc352f50c083b06b5 100644 --- a/earthdiagnostics/ocean/psi.py +++ b/earthdiagnostics/ocean/psi.py @@ -1,15 +1,13 @@ # coding=utf-8 import numpy as np -import netCDF4 import iris import iris.analysis import iris.coords import iris.util -from bscearth.utils.log import Log -from earthdiagnostics.diagnostic import Diagnostic, DiagnosticBasinListOption +from earthdiagnostics.diagnostic import Diagnostic from earthdiagnostics.modelingrealm import ModelingRealms from earthdiagnostics.utils import Utils, TempFile @@ -60,7 +58,7 @@ class Psi(Diagnostic): ) def __str__(self): - return "PSI Startdate: {0} Member: {1} Chunk: {2} Basins: {3}".format( + return "PSI Startdate: {0} Member: {1} Chunk: {2}".format( self.startdate, self.member, self.chunk, @@ -81,17 +79,11 @@ class Psi(Diagnostic): :type options: list[str] :return: """ - options_available = (DiagnosticBasinListOption("basins", "global"),) + options_available = () options = cls.process_options(options, options_available) - basins = options["basins"] - if not basins: - Log.error("Basins not recognized") - return () - masks = {} - basins.sort() - for basin in basins: + for basin in ('global', ): masks[basin] = Utils.get_mask(basin) job_list = list() @@ -158,14 +150,11 @@ class Psi(Diagnostic): Utils.copy_variable(handler_source, handler_temp, lon_name, True, True) handler_temp.createDimension("region", len(result)) handler_temp.createDimension("region_length", 50) - var_region = handler_temp.createVariable( - "region", "S1", ("region", "region_length") - ) var = handler_temp.createVariable( - "vsftbarot", float, ("time", "j", "i", "region") + "vsftbarot", float, ("time", "j", "i",) ) var.units = "m3/s" - var.coordinates = " ".join((lat_name, lon_name)) + var.coordinates = " ".join((lat_name, lon_name,)) var.missing_value = 1e20 var.fill_value = 1e20 var.valid_min = -300e6 @@ -173,9 +162,8 @@ class Psi(Diagnostic): var.long_name = "Barotropic_Stream_Function" for i, basin in enumerate(result): - var_region[i, ...] = netCDF4.stringtoarr(str(basin), 50) result[basin].mask = self.masks[basin] < 1 - var[..., i] = result[basin] + var[...] = result[basin] handler_temp.close() self.psi.set_local_file(temp, diagnostic=self) diff --git a/test/unit/ocean/test_psi.py b/test/unit/ocean/test_psi.py index a4b5efdb2d8b82063d1d802eafa8235f284db847..c94ff3eaeba7de5aaeb7628a06084566f93703fd 100644 --- a/test/unit/ocean/test_psi.py +++ b/test/unit/ocean/test_psi.py @@ -33,27 +33,12 @@ class TestPsi(TestCase): Psi(self.data_manager, "20010101", 0, 1, {Basins().Global: None}), ) - jobs = Psi.generate_jobs(self.diags, ["diagnostic", "atl"]) - self.assertEqual(len(jobs), 2) - self.assertEqual( - jobs[0], - Psi( - self.data_manager, "20010101", 0, 0, {Basins().Atlantic: None} - ), - ) - self.assertEqual( - jobs[1], - Psi( - self.data_manager, "20010101", 0, 1, {Basins().Atlantic: None} - ), - ) - with self.assertRaises(Exception): - Psi.generate_jobs(self.diags, ["diagnostic", "atl", "badoption"]) + Psi.generate_jobs(self.diags, ["diagnostic", "extra"]) def test_str(self): psi = Psi(self.data_manager, "20000101", 1, 1, {Basins().Global: None}) self.assertEqual( str(psi), - "PSI Startdate: 20000101 Member: 1 Chunk: 1 Basins: Global", + "PSI Startdate: 20000101 Member: 1 Chunk: 1", )