import numpy as np

from ocgis import RequestDataset
from ocgis.interface.base.dimension.base import VectorDimension
from ocgis.interface.base.dimension.spatial import SpatialGridDimension
from ocgis.interface.nc.spatial import NcSpatialGridDimension
from ocgis.test.base import TestBase


class TestNcSpatialGridDimension(TestBase):
    """Tests for :class:`NcSpatialGridDimension`, the grid dimension that can read its values from a netCDF source."""

    def get(self, **kwargs):
        """
        Construct an :class:`NcSpatialGridDimension` from ``kwargs``.

        A default two-element ``row`` and three-element ``col`` :class:`VectorDimension` are supplied when the caller
        does not pass them explicitly.
        """
        kwargs.setdefault('row', VectorDimension(value=[4, 5]))
        kwargs.setdefault('col', VectorDimension(value=[6, 7, 8]))
        return NcSpatialGridDimension(**kwargs)

    def test_init(self):
        """The class directly subclasses :class:`SpatialGridDimension` and can be constructed from row/col."""
        self.assertEqual(NcSpatialGridDimension.__bases__, (SpatialGridDimension,))
        ngd = self.get()
        self.assertIsInstance(ngd, NcSpatialGridDimension)

    def test_getitem(self):
        """Slicing subsets ``_src_idx`` on the returned grid and leaves the parent grid's ``_src_idx`` unchanged."""
        src_idx = {'row': np.array([5, 6, 7, 8]), 'col': np.array([9, 10, 11])}
        grid = NcSpatialGridDimension(src_idx=src_idx, request_dataset='foo')
        self.assertIsNone(grid._uid)
        sub = grid[1:3, 1]
        self.assertNumpyAll(sub._src_idx['col'], np.array([10]))
        self.assertNumpyAll(sub._src_idx['row'], np.array([6, 7]))
        # The parent grid must not be modified by slicing. ``items`` (not ``iteritems``) works on Python 2 and 3.
        for k, v in src_idx.items():
            self.assertNumpyAll(grid._src_idx[k], v)

    def test_get_uid(self):
        """A uid array built from ``src_idx`` matches one built from an equally shaped ``value`` array."""
        src_idx = {'row': np.array([5, 6, 7, 8]), 'col': np.array([9, 10, 11])}
        grid = NcSpatialGridDimension(src_idx=src_idx, request_dataset='foo')
        uid1 = grid._get_uid_()
        self.assertEqual(uid1.shape, (4, 3))

        # Leading axis of length 2 holds the row/column planes; the trailing (4, 3) is the grid shape.
        value = np.ma.array(np.zeros((2, 4, 3)))
        grid = NcSpatialGridDimension(value=value)
        uid2 = grid._get_uid_()
        self.assertEqual(uid2.shape, (4, 3))

        self.assertNumpyAll(uid1, uid2)

    def test_get_value_from_source(self):
        """Grid values and corners are loaded lazily from the request dataset using ``src_idx``."""
        path = self.get_netcdf_path_no_row_column()
        rd = RequestDataset(path)

        src_idx = {'row': np.array([0, 1]), 'col': np.array([0])}
        grid = NcSpatialGridDimension(request_dataset=rd, src_idx=src_idx, name_row='yc', name_col='xc')
        self.assertEqual(grid.value.shape, (2, 2, 1))
        with self.nc_scope(path) as ds:
            var_row = ds.variables[grid.name_row]
            var_col = ds.variables[grid.name_col]
            # Column 0 of the source coordinate variables corresponds to ``src_idx['col'] == [0]``.
            self.assertNumpyAll(var_row[:, 0].reshape(2, 1), grid.value[0].data)
            self.assertNumpyAll(var_col[:, 0].reshape(2, 1), grid.value[1].data)

        src_idx = {'row': np.array([0]), 'col': np.array([1])}
        grid = NcSpatialGridDimension(request_dataset=rd, src_idx=src_idx, name_row='yc', name_col='xc')
        # Nothing is loaded until the properties are accessed.
        self.assertIsNone(grid._value)
        self.assertIsNone(grid._corners)
        self.assertEqual(grid.value.shape, (2, 1, 1))
        self.assertEqual(grid.corners.shape, (2, 1, 1, 4))
        self.assertEqual(grid.corners_esmf.shape, (2, 2, 2))
        actual = np.ma.array([[[[3.5, 3.5, 4.5, 4.5]]], [[[45.0, 55.0, 55.0, 45.0]]]])
        self.assertNumpyAll(actual, grid.corners)

    def test_shape(self):
        """Shape comes from ``src_idx`` without loading values, or from row/col when those are given."""
        src_idx = {'row': np.array([5, 6, 7, 8]), 'col': np.array([9, 10, 11])}
        grid = NcSpatialGridDimension(src_idx=src_idx, request_dataset='foo')
        self.assertEqual(grid.shape, (4, 3))
        # Computing the shape from ``src_idx`` must not trigger a value load.
        self.assertIsNone(grid._value)

        grid = self.get()
        self.assertEqual(grid.shape, (2, 3))

    def test_src_idx(self):
        """``src_idx`` is stored as given, and a non-dict ``src_idx`` is rejected by an assertion."""
        src_idx = {'row': np.array([5]), 'col': np.array([6])}
        ngd = self.get(src_idx=src_idx)
        # Compare key by key: ``assertEqual`` on dicts of arrays relies on the truth value of ``array == array``,
        # which is ambiguous (raises) for arrays with more than one element.
        self.assertEqual(set(ngd._src_idx), set(src_idx))
        for k, v in src_idx.items():
            self.assertNumpyAll(ngd._src_idx[k], v)

        # Test assertions.
        with self.assertRaises(AssertionError):
            self.get(src_idx=[])

    def test_validate(self):
        """Construction with no arguments raises ``ValueError``; a request dataset alone is sufficient."""
        with self.assertRaises(ValueError):
            NcSpatialGridDimension()
        NcSpatialGridDimension(request_dataset='foo')

    def test_value(self):
        """A grid rebuilt from another grid's value array reproduces that value exactly."""
        grid = self.get()
        self.assertEqual(grid.shape, (2, 3))

        value = grid.value.copy()
        grid = NcSpatialGridDimension(value=value)
        self.assertNumpyAll(grid.value, value)
