# coding=utf-8
from earthdiagnostics.diagnostic import *
from unittest import TestCase

from earthdiagnostics.modelingrealm import ModelingRealms
from mock import patch, Mock
class TestDiagnosticOption(TestCase):
    """Tests for the plain string ``DiagnosticOption``."""

    def test_good_default_value(self):
        # An empty value falls back to the configured default.
        diag = DiagnosticOption('option', 'default')
        self.assertEqual('default', diag.parse(''))

    def test_no_default_value(self):
        # Without a default, an empty value must be rejected. Only the call
        # goes inside the context manager: an assertion there would be dead
        # code and would mask a missing exception with an unrelated failure.
        diag = DiagnosticOption('option')
        with self.assertRaises(DiagnosticOptionError):
            diag.parse('')

    def test_parse_value(self):
        # A non-empty value is returned as-is.
        diag = DiagnosticOption('option')
        self.assertEqual('value', diag.parse('value'))


class TestDiagnosticFloatOption(TestCase):
    """Tests for ``DiagnosticFloatOption``."""

    def test_float_default_value(self):
        diag = DiagnosticFloatOption('option', 3.0)
        self.assertEqual(3.0, diag.parse(''))

    def test_str_default_value(self):
        # A string default is converted to float when used.
        diag = DiagnosticFloatOption('option', '3')
        self.assertEqual(3.0, diag.parse(''))

    def test_bad_default_value(self):
        # A non-numeric default surfaces as ValueError once it is used.
        diag = DiagnosticFloatOption('option', 'default')
        with self.assertRaises(ValueError):
            diag.parse('')

    def test_no_default_value(self):
        # Without a default, an empty value must be rejected.
        diag = DiagnosticFloatOption('option')
        with self.assertRaises(DiagnosticOptionError):
            diag.parse('')

    def test_parse_value(self):
        diag = DiagnosticFloatOption('option')
        self.assertEqual(3.25, diag.parse('3.25'))


class TestDiagnosticDomainOption(TestCase):
    """Tests for ``DiagnosticDomainOption``, which yields ``ModelingRealms`` values."""

    def test_domain_default_value(self):
        option = DiagnosticDomainOption('option', ModelingRealms.ocean)
        parsed = option.parse('')
        self.assertEqual(ModelingRealms.ocean, parsed)

    def test_str_default_value(self):
        # A string default is mapped to the matching realm.
        option = DiagnosticDomainOption('option', 'atmos')
        parsed = option.parse('')
        self.assertEqual(ModelingRealms.atmos, parsed)

    def test_bad_default_value(self):
        option = DiagnosticDomainOption('option', 'default')
        self.assertRaises(ValueError, option.parse, '')

    def test_no_default_value(self):
        option = DiagnosticDomainOption('option')
        self.assertRaises(DiagnosticOptionError, option.parse, '')

    def test_parse_value(self):
        # Lower-case input resolves to the camel-case realm member.
        option = DiagnosticDomainOption('option')
        parsed = option.parse('seaice')
        self.assertEqual(ModelingRealms.seaIce, parsed)


class TestDiagnosticIntOption(TestCase):
    """Tests for ``DiagnosticIntOption``, including its optional bounds."""

    def test_int_default_value(self):
        option = DiagnosticIntOption('option', 3)
        parsed = option.parse('')
        self.assertEqual(3, parsed)

    def test_str_default_value(self):
        option = DiagnosticIntOption('option', '3')
        parsed = option.parse('')
        self.assertEqual(3, parsed)

    def test_bad_default_value(self):
        option = DiagnosticIntOption('option', 'default')
        self.assertRaises(ValueError, option.parse, '')

    def test_no_default_value(self):
        option = DiagnosticIntOption('option')
        self.assertRaises(DiagnosticOptionError, option.parse, '')

    def test_parse_value(self):
        option = DiagnosticIntOption('option')
        parsed = option.parse('3')
        self.assertEqual(3, parsed)

    def test_parse_bad_value(self):
        # Non-integer numeric text is rejected rather than truncated.
        option = DiagnosticIntOption('option')
        self.assertRaises(ValueError, option.parse, '3.5')

    # In the bound tests below the third positional argument acts as the
    # lower limit and the fourth as the upper limit.

    def test_good_low_limit(self):
        option = DiagnosticIntOption('option', None, 0)
        parsed = option.parse('1')
        self.assertEqual(1, parsed)

    def test_bad_low_limit(self):
        option = DiagnosticIntOption('option', None, 0)
        self.assertRaises(DiagnosticOptionError, option.parse, '-1')

    def test_good_high_limit(self):
        option = DiagnosticIntOption('option', None, None, 0)
        parsed = option.parse('-1')
        self.assertEqual(-1, parsed)

    def test_bad_high_limit(self):
        option = DiagnosticIntOption('option', None, None, 0)
        self.assertRaises(DiagnosticOptionError, option.parse, '1')


class TestDiagnosticBoolOption(TestCase):
    """Tests for ``DiagnosticBoolOption``."""

    def test_bool_default_value(self):
        diag = DiagnosticBoolOption('option', True)
        self.assertEqual(True, diag.parse(''))

    def test_str_default_value(self):
        # A string default is converted to a bool when used.
        diag = DiagnosticBoolOption('option', 'False')
        self.assertEqual(False, diag.parse(''))

    def test_no_default_value(self):
        diag = DiagnosticBoolOption('option')
        with self.assertRaises(DiagnosticOptionError):
            diag.parse('')

    def test_parse_True(self):
        # Capitalised spelling; the lower-case one is covered by test_parse_true.
        diag = DiagnosticBoolOption('option')
        self.assertTrue(diag.parse('True'))

    def test_parse_true(self):
        diag = DiagnosticBoolOption('option')
        self.assertTrue(diag.parse('true'))

    def test_parse_t(self):
        diag = DiagnosticBoolOption('option')
        self.assertTrue(diag.parse('t'))

    def test_parse_yes(self):
        # Parsing is case-insensitive.
        diag = DiagnosticBoolOption('option')
        self.assertTrue(diag.parse('YES'))

    def test_parse_bad_value(self):
        # Unrecognised text is treated as False instead of raising.
        diag = DiagnosticBoolOption('option')
        self.assertFalse(diag.parse('3.5'))


class TestDiagnosticComplexStrOption(TestCase):
    """Tests for ``DiagnosticComplexStrOption``.

    As exercised below, the escape ``&.`` decodes to a space and ``&;``
    decodes to a comma.
    """

    def test_complex_default_value(self):
        diag = DiagnosticComplexStrOption('option', 'default&.str&;&.working')
        self.assertEqual('default str, working', diag.parse(''))

    def test_simple_default_value(self):
        # A default without escape sequences is returned unchanged. The option
        # name is the usual 'option' (it was previously a copy of the value).
        diag = DiagnosticComplexStrOption('option', 'default str, working')
        self.assertEqual('default str, working', diag.parse(''))

    def test_no_default_value(self):
        diag = DiagnosticComplexStrOption('option')
        with self.assertRaises(DiagnosticOptionError):
            diag.parse('')

    def test_parse_value(self):
        diag = DiagnosticComplexStrOption('option')
        self.assertEqual('complex string, for testing', diag.parse('complex&.string&;&.for&.testing'))


class TestDiagnosticListIntOption(TestCase):
    """Tests for ``DiagnosticListIntOption`` (dash-separated integer lists)."""

    def test_tuple_default_value(self):
        # A non-string default is returned as given (a tuple stays a tuple).
        diag = DiagnosticListIntOption('option', (3,))
        self.assertEqual((3,), diag.parse(''))

    def test_list_default_value(self):
        diag = DiagnosticListIntOption('option', [3])
        self.assertEqual([3], diag.parse(''))

    def test_str_default_value(self):
        # A string default is split on '-' into a list of ints.
        diag = DiagnosticListIntOption('option', '3-4')
        self.assertEqual([3, 4], diag.parse(''))

    def test_bad_default_value(self):
        diag = DiagnosticListIntOption('option', 'default')
        with self.assertRaises(ValueError):
            diag.parse('')

    def test_no_default_value(self):
        diag = DiagnosticListIntOption('option')
        with self.assertRaises(DiagnosticOptionError):
            diag.parse('')

    def test_parse_value(self):
        # Input order is preserved (no sorting).
        diag = DiagnosticListIntOption('option')
        self.assertEqual([3, 2], diag.parse('3-2'))

    def test_parse_single_value(self):
        diag = DiagnosticListIntOption('option')
        self.assertEqual([3], diag.parse('3'))

    def test_too_low(self):
        diag = DiagnosticListIntOption('option', min_limit=5)
        with self.assertRaises(DiagnosticOptionError):
            diag.parse('3')

    def test_too_high(self):
        diag = DiagnosticListIntOption('option', max_limit=5)
        with self.assertRaises(DiagnosticOptionError):
            diag.parse('8')

    def test_parse_bad_value(self):
        diag = DiagnosticListIntOption('option')
        with self.assertRaises(ValueError):
            diag.parse('3.5')


class TestDiagnosticChoiceOption(TestCase):
    """Tests for ``DiagnosticChoiceOption``."""

    def test_choice_value(self):
        diag = DiagnosticChoiceOption('option', ('a', 'b'))
        self.assertEqual('a', diag.parse('a'))

    def test_choice_default_value(self):
        diag = DiagnosticChoiceOption('option', ('a', 'b'), default_value='a')
        self.assertEqual('a', diag.parse(''))

    def test_bad_default_value(self):
        # A default outside the allowed choices is rejected at construction.
        with self.assertRaises(DiagnosticOptionError):
            DiagnosticChoiceOption('option', ('a', 'b'), default_value='c')

    def test_ignore_case_value(self):
        # By default matching is case-insensitive and returns the declared choice.
        diag = DiagnosticChoiceOption('option', ('a', 'b'))
        self.assertEqual('b', diag.parse('b'))
        self.assertEqual('b', diag.parse('B'))

        # With ignore_case=False only the exact spelling is accepted.
        diag = DiagnosticChoiceOption('option', ('a', 'b'), ignore_case=False)
        self.assertEqual('b', diag.parse('b'))
        with self.assertRaises(DiagnosticOptionError):
            diag.parse('B')


class TestDiagnosticVariableOption(TestCase):
    """Tests for ``DiagnosticVariableOption``.

    ``VariableManager.get_variable`` is patched in every test, so no real
    variable table is needed.
    """

    def get_var_mock(self, name):
        # Stand-in for a variable object exposing ``short_name``.
        mock = Mock()
        mock.short_name = name
        return mock

    # NOTE: ``test_parse`` used to be defined twice with identical bodies;
    # the second definition silently replaced the first, so only one is kept.
    @patch('earthdiagnostics.variable.VariableManager.get_variable')
    def test_parse(self, get_variable_mock):
        get_variable_mock.return_value = self.get_var_mock('var1')

        diag = DiagnosticVariableOption()
        self.assertEqual('var1', diag.parse('var1'))

    @patch('earthdiagnostics.variable.VariableManager.get_variable')
    def test_not_recognized(self, get_variable_mock):
        # An unknown variable still parses, yielding the raw name.
        get_variable_mock.return_value = None

        diag = DiagnosticVariableOption()
        self.assertEqual('var1', diag.parse('var1'))


class TestDiagnosticVariableListOption(TestCase):
    """Tests for ``DiagnosticVariableListOption`` (dash-separated variable names).

    ``VariableManager.get_variable`` is patched in every test.
    """

    def get_var_mock(self, name):
        # Stand-in for a variable object exposing ``short_name``.
        return Mock(short_name=name)

    @patch('earthdiagnostics.variable.VariableManager.get_variable')
    def test_parse_multiple(self, get_variable_mock):
        first = self.get_var_mock('var1')
        second = self.get_var_mock('var2')
        get_variable_mock.side_effect = (first, second)
        option = DiagnosticVariableListOption('variables')
        self.assertEqual(['var1', 'var2'], option.parse('var1-var2'))

    @patch('earthdiagnostics.variable.VariableManager.get_variable')
    def test_parse_one(self, get_variable_mock):
        get_variable_mock.return_value = self.get_var_mock('var1')
        option = DiagnosticVariableListOption('variables')
        self.assertEqual(['var1'], option.parse('var1'))

    @patch('earthdiagnostics.variable.VariableManager.get_variable')
    def test_not_recognized(self, get_variable_mock):
        # Unknown names are still kept in the result.
        get_variable_mock.return_value = None
        option = DiagnosticVariableListOption('variables')
        self.assertEqual(['var1'], option.parse('var1'))


class TestDiagnostic(TestCase):
    """Tests for the ``Diagnostic`` base class and its registration API."""

    def setUp(self):
        # A fresh subclass per test, so an ``alias`` set in one test is not
        # seen by the next one.
        # NOTE(review): ``Diagnostic.register`` appears to keep global state;
        # registrations made here are not undone — confirm that is harmless.
        class MockDiag(Diagnostic):
            pass
        self.MockDiag = MockDiag

    def test_str(self):
        self.assertEqual(str(Diagnostic(None)), 'Developer must override base class __str__ method')

    def test_compute_is_virtual(self):
        with self.assertRaises(NotImplementedError):
            Diagnostic(None).compute()

    def test_declare_data_generated_is_virtual(self):
        with self.assertRaises(NotImplementedError):
            Diagnostic(None).declare_data_generated()

    def test_request_data_is_virtual(self):
        with self.assertRaises(NotImplementedError):
            Diagnostic(None).request_data()

    @patch.object(Diagnostic, 'dispatch')
    def test_set_status_call_dispatch(self, dispatch_mock):
        # Changing the status must call ``dispatch`` once with the diagnostic.
        # Assumes a new Diagnostic does not start in FAILED state.
        diag = Diagnostic(None)
        diag.status = DiagnosticStatus.FAILED
        dispatch_mock.assert_called_once_with(diag)

    @patch.object(Diagnostic, 'dispatch')
    def test_set_same_status_does_not_dispatch(self, dispatch_mock):
        # Re-assigning the current status must not dispatch. This test used to
        # share its name with the one above, which hid that test.
        diag = Diagnostic(None)
        diag.status = diag.status
        self.assertFalse(dispatch_mock.called, 'Dispatch should not have been called')

    def test_register(self):
        # Something that is not a Diagnostic subclass is rejected.
        with self.assertRaises(ValueError):
            Diagnostic.register(TestDiagnostic)

        # A subclass without an alias is rejected too.
        with self.assertRaises(ValueError):
            Diagnostic.register(self.MockDiag)

        self.MockDiag.alias = 'mock'
        Diagnostic.register(self.MockDiag)

    def test_get_diagnostic(self):
        self.assertIsNone(Diagnostic.get_diagnostic('none'))
        self.MockDiag.alias = 'mock'
        Diagnostic.register(self.MockDiag)
        self.assertIs(self.MockDiag, Diagnostic.get_diagnostic('mock'))

    def test_generate_jobs(self):
        with self.assertRaises(NotImplementedError):
            Diagnostic.generate_jobs(None, [''])

    def test_compute(self):
        with self.assertRaises(NotImplementedError):
            Diagnostic(None).compute()

    def test_repr(self):
        # ``assertEquals`` was removed in Python 3.12; use ``assertEqual``.
        self.assertEqual(repr(Diagnostic(None)), str(Diagnostic(None)))

    def test_empty_process_options(self):
        self.assertEqual(len(Diagnostic.process_options(('diag_name',), tuple())), 0)