import os
import multiprocessing as multi

from .util import sample_feature
from .util import lazy_wombat

import numpy as np
import dask.array as da
import xarray as xr
import pandas as pd
import geopandas as gpd
from rasterio.features import rasterize, shapes
from rasterio.warp import aligned_target
import shapely
from shapely.geometry import Polygon
from affine import Affine
from tqdm import tqdm
from deprecated import deprecated

# Switch on shapely's optional speedups once, when this module is imported.
# NOTE(review): ``shapely.speedups`` is reportedly deprecated (a no-op) as of Shapely 2.0 --
# confirm against the pinned shapely version before relying on this call.
shapely.speedups.enable()


def _iter_func(a):
    return a


class Converters(object):

    @staticmethod
    def indices_to_coords(col_index, row_index, transform):

        """
        Converts array indices to map coordinates

        Args:
            col_index (float or 1d array): The column index.
            row_index (float or 1d array): The row index.
            transform (Affine, DataArray, or tuple): The affine transform. A tuple is unpacked as the
                ``affine.Affine`` coefficients, and for a ``DataArray`` the transform is read from
                ``transform.gw.meta.affine``.

        Returns:

            ``tuple``:

                (x, y)

        Raises:
            TypeError: If ``transform`` is not an ``affine.Affine``, an ``xarray.DataArray``, or a tuple.

        Example:
            >>> import geowombat as gw
            >>> from geowombat.core import indices_to_coords
            >>>
            >>> with gw.open('image.tif') as src:
            >>>     x, y = indices_to_coords(j, i, src)
        """

        if not isinstance(transform, Affine):

            if isinstance(transform, tuple):
                transform = Affine(*transform)
            elif isinstance(transform, xr.DataArray):
                transform = transform.gw.meta.affine
            else:
                # The message travels with the exception. This module does not define a
                # ``logger``, so the previous ``logger.exception`` call failed with a
                # NameError before the TypeError could be raised.
                raise TypeError(
                    'The transform must be an instance of affine.Affine, an xarray.DataArray, or a tuple')

        # Affine * (col, row) maps pixel indices to map coordinates
        return transform * (col_index, row_index)

    @staticmethod
    def coords_to_indices(x, y, transform):

        """
        Converts map coordinates to array indices

        Args:
            x (float or 1d array): The x coordinates.
            y (float or 1d array): The y coordinates.
            transform (Affine, DataArray, or tuple): The affine transform. A tuple is unpacked as the
                ``affine.Affine`` coefficients, and for a ``DataArray`` the transform is read from
                ``transform.gw.meta.affine``.

        Returns:

            ``tuple``:

                (col_index, row_index)

        Raises:
            TypeError: If ``transform`` is not an ``affine.Affine``, an ``xarray.DataArray``, or a tuple.

        Example:
            >>> import geowombat as gw
            >>> from geowombat.core import coords_to_indices
            >>>
            >>> with gw.open('image.tif') as src:
            >>>     j, i = coords_to_indices(x, y, src)
        """

        if not isinstance(transform, Affine):

            if isinstance(transform, tuple):
                transform = Affine(*transform)
            elif isinstance(transform, xr.DataArray):
                transform = transform.gw.meta.affine
            else:
                # The message travels with the exception. This module does not define a
                # ``logger``, so the previous ``logger.exception`` call failed with a
                # NameError before the TypeError could be raised.
                raise TypeError(
                    'The transform must be an instance of affine.Affine, an xarray.DataArray, or a tuple')

        # The inverted transform maps map coordinates back to fractional pixel indices
        col_index, row_index = ~transform * (x, y)

        # NOTE(review): the integer cast truncates toward zero, so coordinates just left of / above
        # the image origin map to index 0 rather than -1 -- confirm callers do not rely on flooring.
        return np.int64(col_index), np.int64(row_index)

    @staticmethod
    @deprecated('Deprecated since 1.0.6. Use indices_to_coords() instead.')
    def ij_to_xy(j, i, transform):

        """
        .. deprecated:: 1.0.6
            Use :func:`geowombat.indices_to_coords()` instead.

        Converts array indices to map coordinates

        Args:
            j (float or 1d array): The column index.
            i (float or 1d array): The row index.
            transform (object): The affine transform, either an ``affine.Affine`` instance or
                an iterable of its coefficients.

        Returns:
            x, y
        """

        # Accept raw coefficients as well, mirroring what ``xy_to_ij`` does for its transform
        if not isinstance(transform, Affine):
            transform = Affine(*transform)

        return transform * (j, i)

    @staticmethod
    @deprecated('Deprecated since 1.0.6. Use coords_to_indices() instead.')
    def xy_to_ij(x, y, transform):

        """
        .. deprecated:: 1.0.6
            Use :func:`geowombat.coords_to_indices()` instead.

        Converts map coordinates to array indices

        Args:
            x (float or 1d array): The x coordinates.
            y (float or 1d array): The y coordinates.
            transform (object): The affine transform, either an ``affine.Affine`` instance or
                an iterable of its coefficients.

        Returns:
            j, i
        """

        # Anything that is not already an Affine is unpacked as its coefficients
        affine_transform = transform if isinstance(transform, Affine) else Affine(*transform)

        # Apply the inverse transform to go from map space to pixel space
        col_index, row_index = (~affine_transform) * (x, y)

        return np.int64(col_index), np.int64(row_index)

    @staticmethod
    def dask_to_xarray(data,
                       dask_data,
                       band_names):

        """
        Wraps a Dask array in an Xarray DataArray

        The y/x coordinates and the attributes are copied from ``data``.

        Args:
            data (DataArray): The DataArray that supplies the ``y`` and ``x`` coordinates and the attributes.
            dask_data (Dask Array): The Dask array to convert. A 2d array is given a leading band axis.
            band_names (1d array-like): The output band names.

        Returns:
            ``xarray.DataArray``
        """

        array_shape = dask_data.shape

        if len(array_shape) == 2:

            # Promote (rows, columns) to (1, rows, columns)
            n_rows, n_cols = array_shape
            dask_data = dask_data.reshape(1, n_rows, n_cols)

        band_coords = {'band': band_names,
                       'y': data.y,
                       'x': data.x}

        return xr.DataArray(dask_data,
                            dims=('band', 'y', 'x'),
                            coords=band_coords,
                            attrs=data.attrs)

    @staticmethod
    def ndarray_to_xarray(data,
                          numpy_data,
                          band_names):

        """
        Wraps a NumPy array in a Dask-backed Xarray DataArray

        The y/x coordinates, the attributes, and the chunk sizes are taken from ``data``.

        Args:
            data (DataArray): The DataArray that supplies the coordinates, attributes, and
                ``gw.row_chunks`` / ``gw.col_chunks``.
            numpy_data (ndarray): The array to convert. A 2d array is given a leading band axis.
            band_names (1d array-like): The output band names.

        Returns:
            ``xarray.DataArray``
        """

        if len(numpy_data.shape) == 2:

            # Promote (rows, columns) to (1, rows, columns)
            numpy_data = numpy_data[np.newaxis, ...]

        # One band per chunk, with spatial chunks matching the reference data
        chunk_shape = (1, data.gw.row_chunks, data.gw.col_chunks)
        chunked_data = da.from_array(numpy_data, chunks=chunk_shape)

        band_coords = {'band': band_names,
                       'y': data.y,
                       'x': data.x}

        return xr.DataArray(chunked_data,
                            dims=('band', 'y', 'x'),
                            coords=band_coords,
                            attrs=data.attrs)

    @staticmethod
    def xarray_to_xdataset(data_array,
                           band_names,
                           time_names,
                           ycoords=None,
                           xcoords=None,
                           attrs=None):

        """
        Converts an Xarray DataArray to a Xarray Dataset with a single ``bands`` variable

        Args:
            data_array (DataArray): The array to convert. A 2d array is given a ``band`` dimension.
            band_names (list): The band names. If empty, bands are named '1', '2', ...
            time_names (list): The date names. If given, the Dataset is built with a leading ``date`` dimension.
            ycoords (1d array-like): The y coordinates (only used when ``time_names`` is given).
            xcoords (1d array-like): The x coordinates (only used when ``time_names`` is given).
            attrs (dict): The Dataset attributes (only used when ``time_names`` is given).

        Returns:
            Dataset
        """

        if len(data_array.shape) == 2:
            data_array = data_array.expand_dims('band')

        # The band axis is second for 4d arrays and first otherwise
        band_axis = 1 if len(data_array.shape) == 4 else 0
        n_bands = data_array.shape[band_axis]

        if not band_names:
            band_names = [str(band_number) for band_number in range(1, n_bands + 1)]

        if not time_names:

            # Without dates, coordinates and attributes come straight from the input array
            return xr.Dataset({'bands': (['band', 'y', 'x'], data_array.data)},
                              coords={'band': band_names,
                                      'y': ('y', data_array.y),
                                      'x': ('x', data_array.x)},
                              attrs=data_array.attrs)

        return xr.Dataset({'bands': (['date', 'band', 'y', 'x'], data_array)},
                          coords={'date': time_names,
                                  'band': band_names,
                                  'y': ('y', ycoords),
                                  'x': ('x', xcoords)},
                          attrs=attrs)

    def prepare_points(self,
                       data,
                       aoi,
                       frac=1.0,
                       all_touched=False,
                       id_column='id',
                       mask=None,
                       n_jobs=8,
                       verbose=0):

        """
        Prepares a ``geopandas.GeoDataFrame`` of sample locations for an image

        The AOI is loaded (if needed), re-projected to the image CRS, stripped of missing
        geometry, limited to the image extent, optionally limited to a mask, and, when the
        first feature is a polygon, converted to points with ``polygons_to_points``.

        Args:
            data (DataArray): The image. It must provide ``crs``, ``res``, and the ``gw`` accessor.
            aoi (GeoDataFrame | str): The area of interest, as a ``geopandas.GeoDataFrame`` or a vector file path.
            frac (Optional[float]): Passed to ``polygons_to_points``.
            all_touched (Optional[bool]): Passed to ``polygons_to_points``.
            id_column (Optional[str]): Passed to ``polygons_to_points``.
            mask (Optional[Polygon | GeoDataFrame]): If given, only geometry within the mask is kept.
            n_jobs (Optional[int]): Passed to ``polygons_to_points``.
            verbose (Optional[int]): Progress messages are logged when greater than 0.

        Returns:
            ``geopandas.GeoDataFrame`` with a fresh 0..n-1 index

        Raises:
            OSError: If ``aoi`` is a string that is not an existing file.
            TypeError: If ``aoi`` is neither a string nor a ``geopandas.GeoDataFrame``.
            ValueError: If no geometry remains after applying ``mask``.
        """

        # Imported locally: this module has no top-level ``logger`` or ``CRS`` in scope, so the
        # previous references to them failed with a NameError.
        import logging
        from rasterio.crs import CRS

        logger = logging.getLogger(__name__)

        if isinstance(aoi, gpd.GeoDataFrame):
            df = aoi
        elif isinstance(aoi, str):

            if not os.path.isfile(aoi):
                raise OSError('The AOI file does not exist.')

            df = gpd.read_file(aoi)

        else:
            # Without this, ``df`` would be unbound below
            raise TypeError('The AOI must be a vector file or a GeoDataFrame.')

        # Re-project the data to match the image CRS
        if isinstance(df.crs, str):

            # NOTE(review): string CRS values that do not start with '+proj' are never
            # re-projected here -- confirm that this is intended.
            if df.crs.lower().startswith('+proj'):

                if data.crs != df.crs:
                    df = df.to_crs(data.crs)

        elif isinstance(df.crs, int):

            if data.crs != CRS.from_epsg(df.crs).to_proj4():
                df = df.to_crs(data.crs)

        else:

            if data.crs != CRS.from_dict(df.crs).to_proj4():
                df = df.to_crs(data.crs)

        if verbose > 0:
            logger.info('  Checking geometry validity ...')

        # Drop features with missing geometry
        df = df[df['geometry'].apply(lambda x_: x_ is not None)]

        if verbose > 0:
            logger.info('  Checking geometry extent ...')

        # Remove data outside of the image bounds.
        # Only the first feature is inspected to decide between polygons and points.
        if isinstance(df.iloc[0].geometry, Polygon):

            df = gpd.overlay(df,
                             gpd.GeoDataFrame(data=[0],
                                              geometry=[data.gw.meta.geometry],
                                              crs=df.crs),
                             how='intersection')

        else:

            # Clip points to the image bounds
            df = df[df.geometry.intersects(data.gw.unary_union)]

        if isinstance(mask, (Polygon, gpd.GeoDataFrame)):

            if isinstance(mask, gpd.GeoDataFrame):

                if CRS.from_dict(mask.crs).to_proj4() != CRS.from_dict(df.crs).to_proj4():
                    mask = mask.to_crs(df.crs)

            if verbose > 0:
                logger.info('  Clipping geometry ...')

            df = df[df.within(mask)]

            if df.empty:
                # Fail here with a clear message instead of at the ``df.iloc[0]`` lookup below
                raise ValueError('No geometry intersects the user-provided mask.')

        # Convert polygons to points
        if isinstance(df.iloc[0].geometry, Polygon):

            if verbose > 0:
                logger.info('  Converting polygons to points ...')

            df = self.polygons_to_points(data,
                                         df,
                                         frac=frac,
                                         all_touched=all_touched,
                                         id_column=id_column,
                                         n_jobs=n_jobs)

        # Ensure a unique index
        df.index = list(range(0, df.shape[0]))

        return df

    @staticmethod
    def polygons_to_points(data,
                           df,
                           frac=1.0,
                           all_touched=False,
                           id_column='id',
                           n_jobs=1):

        """
        Converts polygons to points

        Each feature in ``df`` is passed to ``sample_feature`` and the non-empty results are
        concatenated into a single frame with a unique ``point`` column.

        Args:
            data (DataArray or Dataset): The ``xarray.DataArray`` or ``xarray.Dataset``.
            df (GeoDataFrame): The ``geopandas.GeoDataFrame`` containing the geometry to rasterize.
            frac (Optional[float]): A fractional subset of points to extract in each feature.
            all_touched (Optional[bool]): The ``all_touched`` argument is passed to ``rasterio.features.rasterize``.
            id_column (Optional[str]): The 'id' column.
            n_jobs (Optional[int]): The number of processes in the ``multiprocessing`` pool. The pool
                is only used to iterate over the feature indices; ``sample_feature`` itself is
                called in this process, one feature at a time.

        Returns:
            ``geopandas.GeoDataFrame``
        """

        meta = data.gw.meta

        n_features = df.shape[0]
        sampled_frames = []

        with multi.Pool(processes=n_jobs) as pool:

            # The workers echo the indices back in order; tqdm reports the progress
            feature_indices = pool.imap(_iter_func, range(0, n_features))

            for feature_index in tqdm(feature_indices, total=n_features):

                # Get the current feature
                feature = df.iloc[feature_index]

                sampled = sample_feature(feature[id_column],
                                         feature.geometry,
                                         data.crs,
                                         data.res,
                                         all_touched,
                                         meta,
                                         frac)

                # Features that produced no points are skipped
                if sampled.empty:
                    continue

                sampled_frames.append(sampled)

        points_df = pd.concat(sampled_frames, axis=0)

        # Make the points unique
        points_df.loc[:, 'point'] = np.arange(0, points_df.shape[0])

        return points_df

    @staticmethod
    def array_to_polygon(data, mask=None, connectivity=4, num_workers=1):

        """
        Converts an ``xarray.DataArray`` to a ``geopandas.GeoDataFrame``

        Args:
            data (DataArray): The ``xarray.DataArray`` to convert.
            mask (Optional[str, numpy ndarray, or rasterio Band object]): Must evaluate to bool (rasterio.bool_ or rasterio.uint8).
                Values of False or 0 will be excluded from feature generation. Note well that this is the inverse sense from
                Numpy's, where a mask value of True indicates invalid data in an array. If source is a Numpy masked array
                and mask is None, the source's mask will be inverted and used in place of mask. If ``mask`` is equal to
                'source', then ``data`` (cast to 'uint8') is used as the mask.
            connectivity (Optional[int]): Use 4 or 8 pixel connectivity for grouping pixels into features.
            num_workers (Optional[int]): The number of parallel workers to send to ``dask.compute``.

        Returns:
            ``geopandas.GeoDataFrame`` with one polygon (including interior rings) per feature

        Raises:
            AttributeError: If ``data`` has no ``transform`` or no ``crs`` attribute.

        Example:
            >>> import geowombat as gw
            >>>
            >>> with gw.open('image.tif') as src:
            >>>
            >>>     # Convert the input image to a GeoDataFrame
            >>>     df = gw.array_to_polygon(src,
            >>>                              mask='source',
            >>>                              num_workers=8)
        """

        # Raise directly: this module does not define a ``logger``, so the previous
        # ``logger.exception`` calls failed with a NameError instead of reporting the problem.
        if not hasattr(data, 'transform'):
            raise AttributeError("The data should have a 'transform' object.")

        if not hasattr(data, 'crs'):
            raise AttributeError("The data should have a 'crs' object.")

        # Load the array once and reuse it for the mask
        values = data.data.compute(num_workers=num_workers)

        if isinstance(mask, str) and mask == 'source':
            mask = values.astype('uint8')

        poly_objects = shapes(values,
                              mask=mask,
                              connectivity=connectivity,
                              transform=data.transform)

        poly_geom = []

        for geometry, _ in poly_objects:

            # The first ring is the exterior; any remaining rings are holes and must be
            # kept, otherwise a ring-shaped region would be filled in.
            rings = geometry['coordinates']
            poly_geom.append(Polygon(rings[0], rings[1:]))

        return gpd.GeoDataFrame(data=np.ones(len(poly_geom), dtype='uint8'),
                                geometry=poly_geom,
                                crs=data.crs)

    @staticmethod
    @deprecated('Deprecated since 1.2.0. Use array_to_polygon() instead.')
    def to_geodataframe(data, mask=None, connectivity=4, num_workers=1):

        """
        .. deprecated:: 1.2.0
            Use :func:`geowombat.array_to_polygon()` instead.

        Converts a Dask array to a GeoDataFrame

        This is a thin wrapper that forwards all arguments to ``array_to_polygon``.

        Args:
            data (DataArray): The ``xarray.DataArray`` to convert.
            mask (Optional[str, numpy ndarray, or rasterio Band object]): Must evaluate to bool (rasterio.bool_ or rasterio.uint8).
                Values of False or 0 will be excluded from feature generation. Note well that this is the inverse sense from
                Numpy's, where a mask value of True indicates invalid data in an array. If source is a Numpy masked array
                and mask is None, the source's mask will be inverted and used in place of mask. If ``mask`` is equal to
                'source', then ``data`` is used as the mask.
            connectivity (Optional[int]): Use 4 or 8 pixel connectivity for grouping pixels into features.
            num_workers (Optional[int]): The number of parallel workers to send to ``dask.compute``.

        Returns:
            ``GeoDataFrame``

        Example:
            >>> import geowombat as gw
            >>>
            >>> with gw.open('image.tif') as src:
            >>>
            >>>     # Convert the input image to a GeoDataFrame
            >>>     df = gw.to_geodataframe(src,
            >>>                             mask='source',
            >>>                             num_workers=8)
        """

        # Keep a single implementation: the deprecated name forwards to its replacement
        return Converters.array_to_polygon(data,
                                           mask=mask,
                                           connectivity=connectivity,
                                           num_workers=num_workers)

    @lazy_wombat
    def polygon_to_array(self,
                         polygon,
                         data=None,
                         cellx=None,
                         celly=None,
                         band_name=None,
                         row_chunks=512,
                         col_chunks=512,
                         src_res=None,
                         fill=0,
                         default_value=1,
                         all_touched=True,
                         dtype='uint8',
                         sindex=None):

        """
        Converts a polygon geometry to an ``xarray.DataArray``.

        When ``data`` is an ``xarray.DataArray``, the output grid (cell size, chunks, bounds, shape,
        and transform) is taken from it and ``cellx``, ``celly``, ``row_chunks``, ``col_chunks``,
        and ``src_res`` are ignored. Otherwise the grid is built from the polygon bounds and the
        given cell sizes.

        Args:
            polygon (GeoDataFrame | str): The ``geopandas.DataFrame`` or file with polygon geometry.
            data (Optional[DataArray]): An ``xarray.DataArray`` to use as a reference.
            cellx (Optional[float]): The output cell x size. Required when ``data`` is not given.
            celly (Optional[float]): The output cell y size. Required when ``data`` is not given.
            band_name (Optional[list]): The ``xarray.DataArray`` band name.
            row_chunks (Optional[int]): The ``dask`` row chunk size.
            col_chunks (Optional[int]): The ``dask`` column chunk size.
            src_res (Optional[tuple]): A source resolution to align to.
            fill (Optional[int]): The output fill value for ``rasterio.features.rasterize``.
            default_value (Optional[int]): The output default value for ``rasterio.features.rasterize``.
            all_touched (Optional[int]): The ``all_touched`` value for ``rasterio.features.rasterize``.
            dtype (Optional[int]): The output data type for ``rasterio.features.rasterize``.
            sindex (Optional[object]): An instance of ``geopandas.GeoDataFrame.sindex``. If not given,
                the spatial index of the polygon data is used.

        Returns:
            ``xarray.DataArray``

        Raises:
            OSError: If ``polygon`` is not a ``geopandas.GeoDataFrame`` and is not an existing file.
            TypeError: If ``data`` is not given and ``cellx`` or ``celly`` is missing.

        Example:
            >>> import geowombat as gw
            >>> import geopandas as gpd
            >>>
            >>> df = gpd.read_file('polygons.gpkg')
            >>>
            >>> # 100x100 cell size
            >>> data = gw.polygon_to_array(df, cellx=100.0, celly=100.0)
            >>>
            >>> # Align to an existing image
            >>> with gw.open('image.tif') as src:
            >>>     data = gw.polygon_to_array(df, data=src)
        """

        if not band_name:
            band_name = [1]

        if isinstance(polygon, gpd.GeoDataFrame):
            dataframe = polygon
        elif os.path.isfile(polygon):
            dataframe = gpd.read_file(polygon)
        else:
            # This module does not define a ``logger``; raise with the message instead
            raise OSError('The polygon file does not exist.')

        if isinstance(data, xr.DataArray):

            def _empty_like_data():
                # A single all-zero band on the reference grid
                return self.dask_to_xarray(data, da.zeros((1, data.gw.nrows, data.gw.ncols),
                                                          chunks=(1, data.gw.row_chunks, data.gw.col_chunks),
                                                          dtype=data.dtype.name), [1])

            if dataframe.crs != data.crs:

                # Transform the geometry
                dataframe = dataframe.to_crs(data.crs)

            # Only build the R-tree spatial index when the caller did not supply one.
            # (The previous test was inverted: the default ``None`` was used as an index.)
            if sindex is None:
                sindex = dataframe.sindex

            # Get intersecting features
            int_idx = sorted(list(sindex.intersection(tuple(data.gw.geodataframe.bounds.values.flatten()))))

            if not int_idx:
                return _empty_like_data()

            # Subset to the intersecting features
            dataframe = dataframe.iloc[int_idx]

            # Clip the geometry
            dataframe = gpd.overlay(dataframe, data.gw.geodataframe, how='intersection')

            if dataframe.empty:
                return _empty_like_data()

            cellx = data.gw.cellx
            celly = data.gw.celly
            row_chunks = data.gw.row_chunks
            col_chunks = data.gw.col_chunks
            src_res = None

            left, bottom, right, top = data.gw.bounds

            dst_height = data.gw.nrows
            dst_width = data.gw.ncols

            dst_transform = data.transform

        else:

            if cellx is None or celly is None:
                raise TypeError('cellx and celly must be given when no reference data are provided.')

            left, bottom, right, top = dataframe.total_bounds.flatten().tolist()

            dst_height = int((top - bottom) / abs(celly))
            dst_width = int((right - left) / abs(cellx))

            dst_transform = Affine(cellx, 0.0, left, 0.0, -celly, top)

        if src_res:
            dst_transform = aligned_target(dst_transform,
                                           dst_width,
                                           dst_height,
                                           src_res)[0]

            # Keep the requested cell size but adopt the aligned upper-left corner
            left = dst_transform[2]
            top = dst_transform[5]

            dst_transform = Affine(cellx, 0.0, left, 0.0, -celly, top)

        varray = rasterize(dataframe.geometry.values,
                           out_shape=(dst_height, dst_width),
                           transform=dst_transform,
                           fill=fill,
                           default_value=default_value,
                           all_touched=all_touched,
                           dtype=dtype)

        cellxh = abs(cellx) / 2.0
        cellyh = abs(celly) / 2.0

        # Cell-center coordinates. Index-based construction always yields exactly
        # ``dst_width`` / ``dst_height`` values; a float-step ``np.arange`` can return one
        # extra element from rounding, which would not match the rasterized array.
        xcoords = (left + cellxh) + np.arange(dst_width) * abs(cellx)
        ycoords = (top - cellyh) - np.arange(dst_height) * abs(celly)

        attrs = {'transform': dst_transform[:6],
                 'crs': dataframe.crs,
                 'res': (cellx, celly),
                 'is_tiled': 1}

        return xr.DataArray(data=da.from_array(varray[np.newaxis, :, :],
                                               chunks=(1, row_chunks, col_chunks)),
                            coords={'band': band_name,
                                    'y': ycoords,
                                    'x': xcoords},
                            dims=('band', 'y', 'x'),
                            attrs=attrs)

    @staticmethod
    @deprecated('Deprecated since 1.2.0. Use polygon_to_array() instead.')
    def geodataframe_to_array(dataframe,
                              data=None,
                              cellx=None,
                              celly=None,
                              band_name=None,
                              row_chunks=512,
                              col_chunks=512,
                              src_res=None,
                              fill=0,
                              default_value=1,
                              all_touched=True,
                              dtype='uint8'):

        """
        .. deprecated:: 1.2.0
            Use :func:`geowombat.polygon_to_array()` instead.

        Converts a polygon ``geopandas.GeoDataFrame`` to an ``xarray.DataArray``.

        When ``data`` is an ``xarray.DataArray``, the output grid (cell size, chunks, bounds, shape,
        and transform) is taken from it and ``cellx``, ``celly``, ``row_chunks``, ``col_chunks``,
        and ``src_res`` are ignored. Otherwise the grid is built from the polygon bounds and the
        given cell sizes.

        Args:
            dataframe (GeoDataFrame): The ``geopandas.DataFrame`` or file with polygon geometries.
            data (Optional[DataArray]): An ``xarray.DataArray`` to use as a reference.
            cellx (Optional[float]): The output cell x size.
            celly (Optional[float]): The output cell y size.
            band_name (Optional[list]): The ``xarray.DataArray`` band name.
            row_chunks (Optional[int]): The ``dask`` row chunk size.
            col_chunks (Optional[int]): The ``dask`` column chunk size.
            src_res (Optional[tuple]): A source resolution to align to.
            fill (Optional[int]): The output fill value for ``rasterio.features.rasterize``.
            default_value (Optional[int]): The output default value for ``rasterio.features.rasterize``.
            all_touched (Optional[int]): The ``all_touched`` value for ``rasterio.features.rasterize``.
            dtype (Optional[int]): The output data type for ``rasterio.features.rasterize``.

        Returns:
            ``xarray.DataArray``

        Example:
            >>> import geowombat as gw
            >>> import geopandas as gpd
            >>>
            >>> df = gpd.read_file('polygons.gpkg')
            >>>
            >>> # 100x100 cell size
            >>> data = gw.geodataframe_to_array(df, cellx=100.0, celly=100.0)
            >>>
            >>> # Align to an existing image
            >>> with gw.open('image.tif') as src:
            >>>
            >>>     data = gw.geodataframe_to_array(df,
            >>>                                     cellx=src.gw.cellx,
            >>>                                     celly=src.gw.celly,
            >>>                                     row_chunks=src.gw.row_chunks,
            >>>                                     col_chunks=src.gw.col_chunks,
            >>>                                     src_res=src_res)
        """

        if not band_name:
            band_name = [1]

        if not isinstance(dataframe, gpd.GeoDataFrame):
            dataframe = gpd.read_file(dataframe)

        if isinstance(data, xr.DataArray):

            if dataframe.crs != data.crs:
                # Transform the geometry
                dataframe = dataframe.to_crs(data.crs)

            # Get the R-tree spatial index
            sindex = dataframe.sindex

            # Get intersecting features
            int_idx = sorted(list(sindex.intersection(tuple(data.gw.geodataframe.bounds.values.flatten()))))

            if not int_idx:

                # Nothing overlaps the reference image: return a single all-zero band
                return dask_to_xarray(data, da.zeros((1, data.gw.nrows, data.gw.ncols),
                                                     chunks=(1, data.gw.row_chunks, data.gw.col_chunks),
                                                     dtype=data.dtype.name), [1])

            # Subset to the intersecting features
            dataframe = dataframe.iloc[int_idx]

            # Clip the geometry
            dataframe = gpd.overlay(dataframe, data.gw.geodataframe, how='intersection')

            if dataframe.empty:

                return dask_to_xarray(data, da.zeros((1, data.gw.nrows, data.gw.ncols),
                                                     chunks=(1, data.gw.row_chunks, data.gw.col_chunks),
                                                     dtype=data.dtype.name), [1])

            cellx = data.gw.cellx
            celly = data.gw.celly
            row_chunks = data.gw.row_chunks
            col_chunks = data.gw.col_chunks
            src_res = None

            left, bottom, right, top = data.gw.bounds

            dst_height = data.gw.nrows
            dst_width = data.gw.ncols

            dst_transform = data.transform

        else:

            left, bottom, right, top = dataframe.total_bounds.flatten().tolist()

            dst_height = int((top - bottom) / abs(celly))
            dst_width = int((right - left) / abs(cellx))

            dst_transform = Affine(cellx, 0.0, left, 0.0, -celly, top)

        if src_res:
            dst_transform = aligned_target(dst_transform,
                                           dst_width,
                                           dst_height,
                                           src_res)[0]

            # Keep the requested cell size but adopt the aligned upper-left corner
            left = dst_transform[2]
            top = dst_transform[5]

            dst_transform = Affine(cellx, 0.0, left, 0.0, -celly, top)

        varray = rasterize(dataframe.geometry.values,
                           out_shape=(dst_height, dst_width),
                           transform=dst_transform,
                           fill=fill,
                           default_value=default_value,
                           all_touched=all_touched,
                           dtype=dtype)

        cellxh = abs(cellx) / 2.0
        cellyh = abs(celly) / 2.0

        # Cell-center coordinates. Index-based construction always yields exactly
        # ``dst_width`` / ``dst_height`` values; a float-step ``np.arange`` can return one
        # extra element from rounding, which would not match the rasterized array.
        xcoords = (left + cellxh) + np.arange(dst_width) * abs(cellx)
        ycoords = (top - cellyh) - np.arange(dst_height) * abs(celly)

        attrs = {'transform': dst_transform[:6],
                 'crs': dataframe.crs,
                 'res': (cellx, celly),
                 'is_tiled': 1}

        return xr.DataArray(data=da.from_array(varray[np.newaxis, :, :],
                                               chunks=(1, row_chunks, col_chunks)),
                            coords={'band': band_name,
                                    'y': ycoords,
                                    'x': xcoords},
                            dims=('band', 'y', 'x'),
                            attrs=attrs)


@deprecated('Deprecated since 1.2.0. Use geowombat.core.dask_to_xarray() instead.')
def dask_to_xarray(data, dask_data, band_names):

    """
    .. deprecated:: 1.2.0
        Use :func:`geowombat.core.dask_to_xarray()` instead.

    Converts a Dask array to an Xarray DataArray

    Forwards to ``Converters.dask_to_xarray``, which holds the same implementation.

    Args:
        data (DataArray): The DataArray with attribute information.
        dask_data (Dask Array): The Dask array to convert.
        band_names (1d array-like): The output band names.

    Returns:
        ``xarray.DataArray``
    """

    return Converters.dask_to_xarray(data, dask_data, band_names)


@deprecated('Deprecated since 1.2.0. Use geowombat.core.ndarray_to_xarray() instead.')
def ndarray_to_xarray(data, numpy_data, band_names):

    """
    .. deprecated:: 1.2.0
        Use :func:`geowombat.core.ndarray_to_xarray()` instead.

    Converts a NumPy array to an Xarray DataArray

    Forwards to ``Converters.ndarray_to_xarray``, which holds the same implementation.

    Args:
        data (DataArray): The DataArray with attribute information.
        numpy_data (ndarray): The NumPy array to convert.
        band_names (1d array-like): The output band names.

    Returns:
        ``xarray.DataArray``
    """

    return Converters.ndarray_to_xarray(data, numpy_data, band_names)


@deprecated('Deprecated since 1.2.0.')
def xarray_to_xdataset(data_array, band_names, time_names, ycoords=None, xcoords=None, attrs=None):

    """
    .. deprecated:: 1.2.0

    Converts an Xarray DataArray to a Xarray Dataset

    Forwards to ``Converters.xarray_to_xdataset``, which holds the same implementation.

    Args:
        data_array (DataArray)
        band_names (list)
        time_names (list)
        ycoords (1d array-like)
        xcoords (1d array-like)
        attrs (dict)

    Returns:
        Dataset
    """

    return Converters.xarray_to_xdataset(data_array,
                                         band_names,
                                         time_names,
                                         ycoords=ycoords,
                                         xcoords=xcoords,
                                         attrs=attrs)
