# The MIT License (MIT)
# Copyright (c) 2020 by the xcube development team and contributors
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
# of the Software, and to permit persons to whom the Software is furnished to do
# so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from typing import Tuple, Sequence, Dict, Optional, Mapping, Union, Hashable

import numpy as np
import xarray as xr


class CubeSchema:
    """
    A schema that can be used to create new xcube datasets.
    The given *shape*, *dims*, *chunks*, and *coords* apply to all data variables.

    :param shape: A sequence of dimension sizes. Must have at least three entries.
    :param coords: A mapping of coordinate variables. Must contain 1-D variables
        named *x_name*, *y_name*, and *time_name*, and 1-D labels for all *dims*
        whose lengths match the corresponding entries of *shape*.
    :param x_name: Name of the spatial x coordinate variable. Defaults to ``'lon'``.
    :param y_name: Name of the spatial y coordinate variable. Defaults to ``'lat'``.
    :param time_name: Name of the time coordinate variable. Defaults to ``'time'``.
    :param dims: A sequence of dimension names. If omitted or empty, defaults to
        ``(time_name, y_name, x_name)``, that is ``('time', 'lat', 'lon')`` with
        the default names.
    :param chunks: An optional sequence of chunk sizes, one for each dimension.
    :raise ValueError: If a required argument is missing or empty, or if the
        arguments are inconsistent with each other.
    """

    def __init__(self,
                 shape: Sequence[int],
                 coords: Mapping[str, xr.DataArray],
                 x_name: str = 'lon',
                 y_name: str = 'lat',
                 time_name: str = 'time',
                 dims: Optional[Sequence[str]] = None,
                 chunks: Optional[Sequence[int]] = None):

        if not shape:
            raise ValueError('shape must be a sequence of integer sizes')
        if not coords:
            raise ValueError('coords must be a mapping from dimension names to label arrays')
        if not x_name:
            raise ValueError('x_name must be given')
        if not y_name:
            raise ValueError('y_name must be given')
        if not time_name:
            raise ValueError('time_name must be given')

        ndim = len(shape)
        if ndim < 3:
            raise ValueError('shape must have at least three dimensions')
        # "dims or ()" guards against dims=None: tuple(None) would raise a
        # TypeError and make the documented default unreachable.
        dims = tuple(dims or ()) or (time_name, y_name, x_name)
        if len(dims) != ndim:
            raise ValueError('dims must have same length as shape')
        if x_name not in coords or y_name not in coords or time_name not in coords:
            raise ValueError(f'missing variables {x_name!r}, {y_name!r}, {time_name!r} in coords')
        x_var, y_var, time_var = coords[x_name], coords[y_name], coords[time_name]
        if x_var.ndim != 1 or y_var.ndim != 1 or time_var.ndim != 1:
            raise ValueError(f'variables {x_name!r}, {y_name!r}, {time_name!r} in coords must be 1-D')
        # The coordinate variables determine the required dimension layout:
        # time first, then any extra dimensions, then y and x last.
        x_dim, y_dim, time_dim = x_var.dims[0], y_var.dims[0], time_var.dims[0]
        if dims[0] != time_dim:
            raise ValueError(f"the first dimension in dims must be {time_dim!r}")
        if dims[-2:] != (y_dim, x_dim):
            raise ValueError(f"the last two dimensions in dims must be {y_dim!r} and {x_dim!r}")
        if chunks and len(chunks) != ndim:
            raise ValueError('chunks must have same length as shape')
        # Every dimension needs 1-D labels in coords whose length equals its size.
        for dim_name, dim_size in zip(dims, shape):
            if dim_name not in coords:
                raise ValueError(f'missing dimension {dim_name!r} in coords')
            dim_labels = coords[dim_name]
            if len(dim_labels.shape) != 1:
                raise ValueError(f'labels of {dim_name!r} in coords must be one-dimensional')
            if len(dim_labels) != dim_size:
                raise ValueError(f'number of labels of {dim_name!r} in coords does not match shape')

        self._shape = tuple(shape)
        self._x_name = x_name
        self._y_name = y_name
        self._time_name = time_name
        self._dims = dims
        self._chunks = tuple(chunks) if chunks else None
        self._coords = dict(coords)

    @property
    def ndim(self) -> int:
        """Number of dimensions."""
        return len(self._dims)

    @property
    def dims(self) -> Tuple[str, ...]:
        """Tuple of dimension names."""
        return self._dims

    @property
    def x_name(self) -> str:
        """Name of the spatial x coordinate variable."""
        return self._x_name

    @property
    def y_name(self) -> str:
        """Name of the spatial y coordinate variable."""
        return self._y_name

    @property
    def time_name(self) -> str:
        """Name of the time coordinate variable."""
        return self._time_name

    @property
    def x_var(self) -> xr.DataArray:
        """Spatial x coordinate variable."""
        return self._coords[self._x_name]

    @property
    def y_var(self) -> xr.DataArray:
        """Spatial y coordinate variable."""
        return self._coords[self._y_name]

    @property
    def time_var(self) -> xr.DataArray:
        """Time coordinate variable."""
        return self._coords[self._time_name]

    @property
    def x_dim(self) -> str:
        """Name of the spatial x dimension."""
        return self._dims[-1]

    @property
    def y_dim(self) -> str:
        """Name of the spatial y dimension."""
        return self._dims[-2]

    @property
    def time_dim(self) -> str:
        """Name of the time dimension."""
        return self._dims[0]

    @property
    def x_size(self) -> int:
        """Size of the spatial x dimension."""
        return self._shape[-1]

    @property
    def y_size(self) -> int:
        """Size of the spatial y dimension."""
        return self._shape[-2]

    @property
    def time_size(self) -> int:
        """Size of the time dimension."""
        return self._shape[0]

    @property
    def shape(self) -> Tuple[int, ...]:
        """Tuple of dimension sizes."""
        return self._shape

    @property
    def chunks(self) -> Optional[Tuple[int, ...]]:
        """Tuple of dimension chunk sizes, or ``None`` if no chunks were given."""
        return self._chunks

    @property
    def coords(self) -> Dict[str, xr.DataArray]:
        """Dictionary of coordinate variables."""
        return self._coords

    @classmethod
    def new(cls, cube: xr.Dataset) -> 'CubeSchema':
        """Create a cube schema from given *cube*."""
        return get_cube_schema(cube)

    def _repr_html_(self):
        """Return a HTML representation for Jupyter Notebooks."""
        return (
            f'<table>'
            f'<tr><td>Shape:</td><td>{self.shape}</td></tr>'
            f'<tr><td>Chunk sizes:</td><td>{self.chunks}</td></tr>'
            f'<tr><td>Dimensions:</td><td>{self.dims}</td></tr>'
            f'</table>'
        )


# TODO (forman): code duplication with xcube.core.verify._check_data_variables(), line 76
def get_cube_schema(cube: xr.Dataset) -> CubeSchema:
    """
    Derive cube schema from given *cube*.

    All data variables of *cube* must share the same dimensions and shape.
    Variables that are chunked must also share the same chunk sizes, where
    only the last chunk of a dimension may differ in size. The coordinates
    are taken from the first data variable.

    :param cube: The data cube.
    :return: The cube schema.
    :raise ValueError: If *cube* has no valid spatial or time coordinate
        variables, has no data variables, or if its data variables disagree
        in dimensions, shape, or chunk sizes.
    """
    x_name, y_name = get_dataset_xy_var_names(cube, must_exist=True, dataset_arg_name='cube')
    time_name = get_dataset_time_var_name(cube, must_exist=True, dataset_arg_name='cube')

    # Reference values, taken from the first variable that provides them.
    ref_dims = ref_shape = ref_chunks = ref_coords = None

    for name, data_var in cube.data_vars.items():

        var_dims = data_var.dims
        if ref_dims is None:
            ref_dims = var_dims
        elif ref_dims != var_dims:
            raise ValueError(f'all variables must have same dimensions, but variable {name!r} '
                             f'has dimensions {var_dims!r}')

        var_shape = data_var.shape
        if ref_shape is None:
            ref_shape = var_shape
        elif ref_shape != var_shape:
            raise ValueError(f'all variables must have same shape, but variable {name!r} '
                             f'has shape {var_shape!r}')

        if ref_coords is None:
            ref_coords = data_var.coords

        var_chunks = data_var.chunks
        if not var_chunks:
            # Variables without chunks do not take part in the chunk check.
            continue

        chunk_sizes = []
        for dim, sizes_along_dim in zip(data_var.dims, var_chunks):
            leading_size = sizes_along_dim[0]
            # The trailing chunk is excluded from the comparison.
            if not all(size == leading_size for size in sizes_along_dim[1:-1]):
                raise ValueError(f'dimension {dim!r} of variable {name!r} has chunks of different sizes: '
                                 f'{sizes_along_dim!r}')
            chunk_sizes.append(leading_size)
        chunk_sizes = tuple(chunk_sizes)

        if ref_chunks is None:
            ref_chunks = chunk_sizes
        elif ref_chunks != chunk_sizes:
            raise ValueError(f'all variables must have same chunks, but variable {name!r} '
                             f'has chunks {chunk_sizes!r}')

    if ref_dims is None:
        raise ValueError('cube is empty')

    return CubeSchema(ref_shape,
                      ref_coords,
                      x_name=x_name,
                      y_name=y_name,
                      time_name=time_name,
                      dims=tuple(map(str, ref_dims)),
                      chunks=ref_chunks)


def get_dataset_xy_var_names(coords: Union[xr.Dataset, xr.DataArray, Mapping[Hashable, xr.DataArray]],
                             must_exist: bool = False,
                             dataset_arg_name: str = 'dataset') -> Optional[Tuple[str, str]]:
    """
    Find the names of the 1-D spatial x and y coordinate variables.

    Three strategies are tried in order, and the first that yields both
    names wins:

    1. variables whose ``standard_name`` or ``long_name`` attribute marks
       them as projection x/y coordinates;
    2. variables whose ``long_name`` attribute is ``'longitude'`` or
       ``'latitude'``;
    3. variables named ``'lon'``/``'lat'`` or ``'x'``/``'y'``.

    :param coords: A mapping of coordinate variables, or an object with a
        ``coords`` attribute providing such a mapping.
    :param must_exist: Whether to raise if no pair of variables is found.
    :param dataset_arg_name: Name used for the dataset in the error message.
    :return: A tuple (x name, y name), or ``None`` if no pair was found
        and *must_exist* is false.
    :raise ValueError: If no pair was found and *must_exist* is true.
    """
    if hasattr(coords, 'coords'):
        coords = coords.coords

    # Each rule is a pair (x criteria, y criteria); the criteria are
    # (attribute name, expected value) pairs, of which any may match.
    attr_rules = (
        ((('standard_name', 'projection_x_coordinate'),
          ('long_name', 'x coordinate of projection')),
         (('standard_name', 'projection_y_coordinate'),
          ('long_name', 'y coordinate of projection'))),
        ((('long_name', 'longitude'),),
         (('long_name', 'latitude'),)),
    )

    def _is_match(var, criteria) -> bool:
        """Test whether *var* meets any of *criteria* and is 1-D."""
        return any(var.attrs.get(attr_name) == expected
                   for attr_name, expected in criteria) and var.ndim == 1

    def _scan(x_criteria, y_criteria) -> Optional[Tuple[str, str]]:
        """Scan the coordinates once, returning as soon as both names are known."""
        found_x = found_y = None
        for name, var in coords.items():
            if _is_match(var, x_criteria):
                found_x = name
            if _is_match(var, y_criteria):
                found_y = name
            if found_x and found_y:
                return str(found_x), str(found_y)
        return None

    for x_criteria, y_criteria in attr_rules:
        names = _scan(x_criteria, y_criteria)
        if names is not None:
            return names

    # Last resort: conventional variable names.
    for x_name, y_name in (('lon', 'lat'), ('x', 'y')):
        if x_name in coords and y_name in coords:
            if all(coords[name].ndim == 1 for name in (x_name, y_name)):
                return x_name, y_name

    if not must_exist:
        return None

    raise ValueError(f'{dataset_arg_name} has no valid spatial coordinate variables')


def get_dataset_time_var_name(dataset: Union[xr.Dataset, xr.DataArray],
                              must_exist: bool = False,
                              dataset_arg_name: str = 'dataset') -> Optional[str]:
    """
    Find the name of the time coordinate variable of *dataset*.

    The only candidate is a coordinate named ``'time'``; it is accepted if
    it is 1-D and its data type is a sub-type of ``numpy.datetime64``.

    :param dataset: The dataset or data array whose coordinates are inspected.
    :param must_exist: Whether to raise if no valid time variable is found.
    :param dataset_arg_name: Name used for the dataset in the error message.
    :return: ``'time'``, or ``None`` if there is no valid time variable
        and *must_exist* is false.
    :raise ValueError: If there is no valid time variable and *must_exist* is true.
    """
    candidate_name = 'time'
    coord_vars = dataset.coords

    if candidate_name in coord_vars:
        candidate = coord_vars[candidate_name]
        is_time_axis = candidate.ndim == 1 and np.issubdtype(candidate.dtype, np.datetime64)
        if is_time_axis:
            return candidate_name

    if not must_exist:
        return None

    raise ValueError(f'{dataset_arg_name} has no valid time coordinate variable')


def get_dataset_bounds_var_name(dataset: Union[xr.Dataset, xr.DataArray],
                                var_name: str,
                                must_exist: bool = False,
                                dataset_arg_name: str = 'dataset') -> Optional[str]:
    """
    Find the name of the bounds variable that belongs to coordinate *var_name*.

    The candidate name is read from the ``bounds`` attribute of the coordinate
    variable and falls back to ``'<var_name>_bnds'``. The candidate is accepted
    if it is contained in *dataset*, is 2-D, has the same first-axis length as
    the coordinate variable, and has a second axis of length 2.

    :param dataset: The dataset or data array to inspect.
    :param var_name: Name of the coordinate variable.
    :param must_exist: Whether to raise if no valid bounds variable is found.
    :param dataset_arg_name: Name used for the dataset in the error message.
    :return: The name of the bounds variable, or ``None`` if there is none
        and *must_exist* is false.
    :raise ValueError: If there is no valid bounds variable and *must_exist* is true.
    """
    found = False
    candidate_name = None

    if var_name in dataset.coords:
        coord_var = dataset[var_name]
        candidate_name = coord_var.attrs.get('bounds', f'{var_name}_bnds')
        if candidate_name in dataset:
            candidate = dataset[candidate_name]
            # Rank is checked first so that the shape is only indexed for 2-D candidates.
            found = (candidate.ndim == 2
                     and candidate.shape[0] == coord_var.shape[0]
                     and candidate.shape[1] == 2)

    if found:
        return candidate_name

    if not must_exist:
        return None

    raise ValueError(f'{dataset_arg_name} has no valid bounds variable for variable {var_name!r}')
