// This file is part of Eigenx, an extension of Eigen.
//
// Copyright (C) 2018-2019 Allan Leal
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_EIGENXFUNCTIONS_H
#define EIGEN_EIGENXFUNCTIONS_H

#if EIGEN_HAS_CXX11

namespace Eigen {

namespace internal {

/// Alias for a column vector of scalar type `T` whose number of rows is dynamic.
template<typename T>
using VectorX = Matrix<T, Dynamic, 1>;

/// Alias for a row-major row vector of scalar type `T` whose number of columns is dynamic.
template<typename T>
using RowVectorX = Matrix<T, 1, Dynamic, Eigen::RowMajor, 1, Dynamic>;

/// Alias for a matrix of scalar type `T` whose numbers of rows and columns are dynamic.
template<typename T>
using MatrixX = Matrix<T, Dynamic, Dynamic>;

} // namespace internal

/// Create an expression of a column vector whose entries are all zero.
/// @tparam T The scalar type of the vector entries (defaults to `double`)
/// @param rows The number of rows of the vector
/// @return The expression produced by `Zero(rows)` on a dynamic vector of `T`
template<typename T = double>
auto zeros(Index rows)
    -> decltype(internal::VectorX<T>::Zero(rows))
{
    using Vector = internal::VectorX<T>;
    return Vector::Zero(rows);
}

/// Create an expression of a column vector whose entries are all one.
/// @tparam T The scalar type of the vector entries (defaults to `double`)
/// @param rows The number of rows of the vector
/// @return The expression produced by `Ones(rows)` on a dynamic vector of `T`
template<typename T = double>
auto ones(Index rows)
    -> decltype(internal::VectorX<T>::Ones(rows))
{
    using Vector = internal::VectorX<T>;
    return Vector::Ones(rows);
}

/// Create an expression of a column vector whose entries all equal a given constant.
/// @tparam T The scalar type of the vector entries (defaults to `double`)
/// @tparam U The type of the constant passed by the caller
/// @param rows The number of rows of the vector
/// @param val The constant value assigned to every entry
/// @return The expression produced by `Constant(rows, val)` on a dynamic vector of `T`
template<typename T = double, typename U>
auto constants(Index rows, const U& val)
    -> decltype(internal::VectorX<T>::Constant(rows, val))
{
    using Vector = internal::VectorX<T>;
    return Vector::Constant(rows, val);
}

/// Create an expression of a column vector with random entries.
/// @tparam T The scalar type of the vector entries (defaults to `double`)
/// @param rows The number of rows of the vector
/// @return The expression produced by `Random(rows)` on a dynamic vector of `T`
template<typename T = double>
auto random(Index rows)
    -> decltype(internal::VectorX<T>::Random(rows))
{
    using Vector = internal::VectorX<T>;
    return Vector::Random(rows);
}

/// Create a column vector whose entries are linearly spaced between two values.
/// @tparam T The scalar type of the vector entries (defaults to `double`)
/// @param rows The number of rows of the vector
/// @param start The first value of the sequence
/// @param stop The last value of the sequence
/// @return The expression produced by `LinSpaced(rows, start, stop)` on a dynamic vector of `T`
template<typename T = double, typename U, typename V>
auto linspace(Index rows, const U& start, const V& stop)
    -> decltype(internal::VectorX<T>::LinSpaced(rows, start, stop))
{
    using Vector = internal::VectorX<T>;
    return Vector::LinSpaced(rows, start, stop);
}

/// Create a column vector whose entries are linearly spaced from zero to `rows` - 1.
/// @tparam T The scalar type of the vector entries (defaults to `double`)
/// @param rows The number of rows of the vector
/// @return The expression produced by `LinSpaced(rows, 0, rows - 1)` on a dynamic vector of `T`
template<typename T = double>
auto linspace(Index rows)
    -> decltype(internal::VectorX<T>::LinSpaced(rows, 0, rows - 1))
{
    using Vector = internal::VectorX<T>;
    return Vector::LinSpaced(rows, 0, rows - 1);
}

/// Create an expression of a unit column vector.
/// @tparam T The scalar type of the vector entries (defaults to `double`)
/// @param rows The number of rows of the vector
/// @param i The index of the row whose entry is one
/// @return The expression produced by `Unit(rows, i)` on a dynamic vector of `T`
template<typename T = double>
auto unit(Index rows, Index i)
    -> decltype(internal::VectorX<T>::Unit(rows, i))
{
    using Vector = internal::VectorX<T>;
    return Vector::Unit(rows, i);
}

/// Create an expression of a unit row vector.
/// @tparam T The scalar type of the vector entries (defaults to `double`)
/// @param cols The number of columns of the row vector
/// @param i The index of the column whose entry is one
/// @return The expression produced by `Unit(cols, i)` on a dynamic row vector of `T`
template<typename T = double>
auto unitrow(Index cols, Index i)
    -> decltype(internal::RowVectorX<T>::Unit(cols, i))
{
    using RowVector = internal::RowVectorX<T>;
    return RowVector::Unit(cols, i);
}

/// Create an expression of a matrix whose entries are all zero.
/// @tparam T The scalar type of the matrix entries (defaults to `double`)
/// @param rows The number of rows of the matrix
/// @param cols The number of columns of the matrix
/// @return The expression produced by `Zero(rows, cols)` on a dynamic matrix of `T`
template<typename T = double>
auto zeros(Index rows, Index cols)
    -> decltype(internal::MatrixX<T>::Zero(rows, cols))
{
    using Mat = internal::MatrixX<T>;
    return Mat::Zero(rows, cols);
}

/// Create an expression of a matrix whose entries are all one.
/// @tparam T The scalar type of the matrix entries (defaults to `double`)
/// @param rows The number of rows of the matrix
/// @param cols The number of columns of the matrix
/// @return The expression produced by `Ones(rows, cols)` on a dynamic matrix of `T`
template<typename T = double>
auto ones(Index rows, Index cols)
    -> decltype(internal::MatrixX<T>::Ones(rows, cols))
{
    using Mat = internal::MatrixX<T>;
    return Mat::Ones(rows, cols);
}

/// Create an expression of a matrix whose entries all equal a given constant.
/// @tparam T The scalar type of the matrix entries (defaults to `double`)
/// @tparam U The type of the constant passed by the caller
/// @param rows The number of rows of the matrix
/// @param cols The number of columns of the matrix
/// @param val The constant value assigned to every entry
/// @return The expression produced by `Constant(rows, cols, val)` on a dynamic matrix of `T`
template<typename T = double, typename U>
auto constants(Index rows, Index cols, const U& val)
    -> decltype(internal::MatrixX<T>::Constant(rows, cols, val))
{
    using Mat = internal::MatrixX<T>;
    return Mat::Constant(rows, cols, val);
}

/// Create an expression of a matrix with random entries.
/// @tparam T The scalar type of the matrix entries (defaults to `double`)
/// @param rows The number of rows of the matrix
/// @param cols The number of columns of the matrix
/// @return The expression produced by `Random(rows, cols)` on a dynamic matrix of `T`
template<typename T = double>
auto random(Index rows, Index cols)
    -> decltype(internal::MatrixX<T>::Random(rows, cols))
{
    using Mat = internal::MatrixX<T>;
    return Mat::Random(rows, cols);
}

/// Create an expression of an identity matrix.
/// @tparam T The scalar type of the matrix entries (defaults to `double`)
/// @param rows The number of rows of the matrix
/// @param cols The number of columns of the matrix
/// @return The expression produced by `Identity(rows, cols)` on a dynamic matrix of `T`
template<typename T = double>
auto identity(Index rows, Index cols)
    -> decltype(internal::MatrixX<T>::Identity(rows, cols))
{
    using Mat = internal::MatrixX<T>;
    return Mat::Identity(rows, cols);
}

/// Shorthand for `mat.transpose()` on a non-const matrix.
/// @param mat The matrix to be transposed
/// @return The transpose expression of `mat`
template<typename Derived>
auto tr(MatrixBase<Derived>& mat)
    -> decltype(mat.transpose())
{
    return mat.transpose();
}

/// Shorthand for `mat.transpose()` on a const matrix.
/// @param mat The matrix to be transposed
/// @return The const transpose expression of `mat`
template<typename Derived>
auto tr(const MatrixBase<Derived>& mat)
    -> decltype(mat.transpose())
{
    return mat.transpose();
}

/// Return the component-wise inverse (reciprocal) expression of a matrix.
/// @note This forwards to `cwiseInverse()`; it is NOT the matrix inverse.
/// @param mat The matrix whose entries are inverted
/// @return The expression produced by `mat.cwiseInverse()`
template<typename Derived>
auto inv(const MatrixBase<Derived>& mat)
    -> decltype(mat.cwiseInverse())
{
    return mat.cwiseInverse();
}

/// Shorthand for `vec.asDiagonal()`.
/// @param vec The vector holding the diagonal entries
/// @return The diagonal matrix expression built from `vec`
template<typename Derived>
auto diag(const MatrixBase<Derived>& vec)
    -> decltype(vec.asDiagonal())
{
    return vec.asDiagonal();
}

/// Shorthand for `mat.diagonal()` on a non-const matrix.
/// @param mat The matrix whose diagonal is requested
/// @return The vector expression of the diagonal of `mat`
template<typename Derived>
auto diagonal(MatrixBase<Derived>& mat)
    -> decltype(mat.diagonal())
{
    return mat.diagonal();
}

/// Shorthand for `mat.diagonal()` on a const matrix.
/// @param mat The matrix whose diagonal is requested
/// @return The const vector expression of the diagonal of `mat`
template<typename Derived>
auto diagonal(const MatrixBase<Derived>& mat)
    -> decltype(mat.diagonal())
{
    return mat.diagonal();
}

/// Compute the Lp norm of a matrix for a compile-time order `p`.
/// @tparam p The order of the norm, forwarded to `lpNorm<p>()`
/// @param mat The matrix whose norm is computed
/// @return The value of `mat.lpNorm<p>()`
template<int p, typename Derived>
auto norm(const MatrixBase<Derived>& mat)
    -> decltype(mat.template lpNorm<p>())
{
    return mat.template lpNorm<p>();
}

/// Compute the L2 norm of a matrix.
/// @param mat The matrix whose norm is computed
/// @return The value of `mat.norm()`
template<typename Derived>
auto norm(const MatrixBase<Derived>& mat)
    -> decltype(mat.norm())
{
    return mat.norm();
}

/// Compute the L-infinity norm of a matrix.
/// @param mat The matrix whose norm is computed
/// @return The value of `mat.lpNorm<Infinity>()`
template<typename Derived>
auto norminf(const MatrixBase<Derived>& mat)
    -> decltype(mat.template lpNorm<Infinity>())
{
    return mat.template lpNorm<Infinity>();
}

/// Compute the sum of all entries of a matrix or array.
/// @param mat The dense object whose entries are summed
/// @return The value of `mat.sum()`
template<typename Derived>
auto sum(const DenseBase<Derived>& mat)
    -> decltype(mat.sum())
{
    return mat.sum();
}

/// Compute the dot product of two vectors.
/// @param lhs The left-hand side operand
/// @param rhs The right-hand side operand
/// @return The value of `lhs.dot(rhs)`
template<typename DerivedLHS, typename DerivedRHS>
auto dot(const MatrixBase<DerivedLHS>& lhs, const MatrixBase<DerivedRHS>& rhs)
    -> decltype(lhs.dot(rhs))
{
    return lhs.dot(rhs);
}

/// Find the smallest entry of a matrix.
/// @param mat The matrix that is searched
/// @return The value of `mat.minCoeff()`
template<typename Derived>
auto min(const MatrixBase<Derived>& mat)
    -> decltype(mat.minCoeff())
{
    return mat.minCoeff();
}

/// Build the component-wise minimum expression of two matrices.
/// @param lhs The left-hand side matrix
/// @param rhs The right-hand side matrix
/// @return The expression produced by `lhs.cwiseMin(rhs)`
template<typename DerivedLHS, typename DerivedRHS>
auto min(const MatrixBase<DerivedLHS>& lhs, const MatrixBase<DerivedRHS>& rhs)
    -> decltype(lhs.cwiseMin(rhs))
{
    return lhs.cwiseMin(rhs);
}

/// Find the largest entry of a matrix.
/// @param mat The matrix that is searched
/// @return The value of `mat.maxCoeff()`
template<typename Derived>
auto max(const MatrixBase<Derived>& mat)
    -> decltype(mat.maxCoeff())
{
    return mat.maxCoeff();
}

/// Build the component-wise maximum expression of two matrices.
/// @param lhs The left-hand side matrix
/// @param rhs The right-hand side matrix
/// @return The expression produced by `lhs.cwiseMax(rhs)`
template<typename DerivedLHS, typename DerivedRHS>
auto max(const MatrixBase<DerivedLHS>& lhs, const MatrixBase<DerivedRHS>& rhs)
    -> decltype(lhs.cwiseMax(rhs))
{
    return lhs.cwiseMax(rhs);
}

/// Build the component-wise absolute value expression of a matrix.
/// @param mat The matrix operand
/// @return The expression produced by `mat.cwiseAbs()`
template<typename Derived>
auto abs(const MatrixBase<Derived>& mat)
    -> decltype(mat.cwiseAbs())
{
    return mat.cwiseAbs();
}

/// Build the component-wise square root expression of a matrix.
/// @param mat The matrix operand
/// @return The expression produced by `mat.cwiseSqrt()`
template<typename Derived>
auto sqrt(const MatrixBase<Derived>& mat)
    -> decltype(mat.cwiseSqrt())
{
    return mat.cwiseSqrt();
}

/// Build the expression of a matrix with every entry raised to a given power.
/// @param mat The matrix operand
/// @param power The exponent applied to each entry
/// @return The expression produced by `mat.array().pow(power).matrix()`
template<typename Derived>
auto pow(const MatrixBase<Derived>& mat, double power)
    -> decltype(mat.array().pow(power).matrix())
{
    // Switch to the array interface for the component-wise power, then back to a matrix.
    return mat.array().pow(power).matrix();
}

/// Build the component-wise natural exponential expression of a matrix.
/// @param mat The matrix operand
/// @return The expression produced by `mat.array().exp().matrix()`
template<typename Derived>
auto exp(const MatrixBase<Derived>& mat)
    -> decltype(mat.array().exp().matrix())
{
    // Switch to the array interface for the component-wise operation, then back to a matrix.
    return mat.array().exp().matrix();
}

/// Build the component-wise natural logarithm expression of a matrix.
/// @param mat The matrix operand
/// @return The expression produced by `mat.array().log().matrix()`
template<typename Derived>
auto log(const MatrixBase<Derived>& mat)
    -> decltype(mat.array().log().matrix())
{
    // Switch to the array interface for the component-wise operation, then back to a matrix.
    return mat.array().log().matrix();
}

/// Build the component-wise base-10 logarithm expression of a matrix.
/// @param mat The matrix operand
/// @return The expression produced by `mat.array().log10().matrix()`
template<typename Derived>
auto log10(const MatrixBase<Derived>& mat)
    -> decltype(mat.array().log10().matrix())
{
    // Switch to the array interface for the component-wise operation, then back to a matrix.
    return mat.array().log10().matrix();
}

/// Create a view of a contiguous sequence of rows of a matrix.
/// @param mat The matrix for which the view is created
/// @param start The row index of the start of the sequence
/// @param num The number of rows in the sequence
/// @return The expression produced by `mat.middleRows(start, num)`
template<typename Derived>
auto rows(MatrixBase<Derived>& mat, Index start, Index num)
    -> decltype(mat.middleRows(start, num))
{
    return mat.middleRows(start, num);
}

/// Create a const view of a contiguous sequence of rows of a matrix.
/// @param mat The matrix for which the view is created
/// @param start The row index of the start of the sequence
/// @param num The number of rows in the sequence
/// @return The expression produced by `mat.middleRows(start, num)`
template<typename Derived>
auto rows(const MatrixBase<Derived>& mat, Index start, Index num)
    -> decltype(mat.middleRows(start, num))
{
    return mat.middleRows(start, num);
}

/// Create a view of selected rows of a matrix, keeping all columns.
/// @param mat The matrix for which the view is created
/// @param irows The indices of the rows of the matrix
/// @return The expression produced by `mat(irows, all)`
template<typename Derived, typename Indices>
auto rows(MatrixBase<Derived>& mat, const Indices& irows)
    -> decltype(mat(irows, all))
{
    return mat(irows, all);
}

/// Create a const view of selected rows of a matrix, keeping all columns.
/// @param mat The matrix for which the view is created
/// @param irows The indices of the rows of the matrix
/// @return The expression produced by `mat(irows, all)`
template<typename Derived, typename Indices>
auto rows(const MatrixBase<Derived>& mat, const Indices& irows)
    -> decltype(mat(irows, all))
{
    return mat(irows, all);
}

/// Create a view of a contiguous sequence of columns of a matrix.
/// @param mat The matrix for which the view is created
/// @param start The column index of the start of the sequence
/// @param num The number of columns in the sequence
/// @return The expression produced by `mat.middleCols(start, num)`
template<typename Derived>
auto cols(MatrixBase<Derived>& mat, Index start, Index num)
    -> decltype(mat.middleCols(start, num))
{
    return mat.middleCols(start, num);
}

/// Create a const view of a contiguous sequence of columns of a matrix.
/// @param mat The matrix for which the view is created
/// @param start The column index of the start of the sequence
/// @param num The number of columns in the sequence
/// @return The expression produced by `mat.middleCols(start, num)`
template<typename Derived>
auto cols(const MatrixBase<Derived>& mat, Index start, Index num)
    -> decltype(mat.middleCols(start, num))
{
    return mat.middleCols(start, num);
}

/// Create a view of selected columns of a matrix, keeping all rows.
/// @param mat The matrix for which the view is created
/// @param icols The indices of the columns of the matrix
/// @return The expression produced by `mat(all, icols)`
template<typename Derived, typename Indices>
auto cols(MatrixBase<Derived>& mat, const Indices& icols)
    -> decltype(mat(all, icols))
{
    return mat(all, icols);
}

/// Create a const view of selected columns of a matrix, keeping all rows.
/// @param mat The matrix for which the view is created
/// @param icols The indices of the columns of the matrix
/// @return The expression produced by `mat(all, icols)`
template<typename Derived, typename Indices>
auto cols(const MatrixBase<Derived>& mat, const Indices& icols)
    -> decltype(mat(all, icols))
{
    return mat(all, icols);
}

/// Create a view of a contiguous segment of a vector.
/// @param vec The vector for which the view is created
/// @param irow The index of the first entry of the segment
/// @param nrows The number of entries in the segment
/// @return The expression produced by `vec.segment(irow, nrows)`
template<typename Derived>
auto segment(MatrixBase<Derived>& vec, Index irow, Index nrows)
    -> decltype(vec.segment(irow, nrows))
{
    return vec.segment(irow, nrows);
}

/// Create a const view of a contiguous segment of a vector.
/// @param vec The vector for which the view is created
/// @param irow The index of the first entry of the segment
/// @param nrows The number of entries in the segment
/// @return The expression produced by `vec.segment(irow, nrows)`
template<typename Derived>
auto segment(const MatrixBase<Derived>& vec, Index irow, Index nrows)
    -> decltype(vec.segment(irow, nrows))
{
    return vec.segment(irow, nrows);
}

/// Create a view of a rectangular block of a matrix.
/// @param mat The matrix for which the view is created
/// @param irow The index of the row at which the block starts
/// @param icol The index of the column at which the block starts
/// @param nrows The number of rows of the block
/// @param ncols The number of columns of the block
/// @return The expression produced by `mat.block(irow, icol, nrows, ncols)`
template<typename Derived>
auto block(MatrixBase<Derived>& mat, Index irow, Index icol, Index nrows, Index ncols)
    -> decltype(mat.block(irow, icol, nrows, ncols))
{
    return mat.block(irow, icol, nrows, ncols);
}

/// Create a const view of a rectangular block of a matrix.
/// @param mat The matrix for which the view is created
/// @param irow The index of the row at which the block starts
/// @param icol The index of the column at which the block starts
/// @param nrows The number of rows of the block
/// @param ncols The number of columns of the block
/// @return The expression produced by `mat.block(irow, icol, nrows, ncols)`
template<typename Derived>
auto block(const MatrixBase<Derived>& mat, Index irow, Index icol, Index nrows, Index ncols)
    -> decltype(mat.block(irow, icol, nrows, ncols))
{
    return mat.block(irow, icol, nrows, ncols);
}

/// Return a view of some rows and columns of a matrix.
/// The row and column index containers may have different types; `ColIndices`
/// defaults to `Indices`, so calls that pass the same type for both are unaffected.
/// @tparam Indices The type of the row indices
/// @tparam ColIndices The type of the column indices (defaults to `Indices`)
/// @param mat The matrix for which the view is created
/// @param irows The indices of the rows of the matrix
/// @param icols The indices of the columns of the matrix
/// @return The expression produced by `mat(irows, icols)`
template<typename Derived, typename Indices, typename ColIndices = Indices>
auto submatrix(MatrixBase<Derived>& mat, const Indices& irows, const ColIndices& icols) -> decltype(mat(irows, icols))
{
    return mat(irows, icols);
}

/// Return a const view of some rows and columns of a matrix.
/// The row and column index containers may have different types; `ColIndices`
/// defaults to `Indices`, so calls that pass the same type for both are unaffected.
/// @tparam Indices The type of the row indices
/// @tparam ColIndices The type of the column indices (defaults to `Indices`)
/// @param mat The matrix for which the view is created
/// @param irows The indices of the rows of the matrix
/// @param icols The indices of the columns of the matrix
/// @return The expression produced by `mat(irows, icols)`
template<typename Derived, typename Indices, typename ColIndices = Indices>
auto submatrix(const MatrixBase<Derived>& mat, const Indices& irows, const ColIndices& icols) -> decltype(mat(irows, icols))
{
    return mat(irows, icols);
}

/// Return a block mapped view of a matrix.
/// The view is a `Map` built on the data pointer of `mat.block(row, col, nrows, ncols)`
/// and on the outer and inner strides of `mat`.
/// NOTE(review): `Stride<Rows,Cols>` reuses the compile-time sizes of the matrix as
/// compile-time strides; this looks intended for dynamic-size matrices only — confirm
/// before using it with fixed-size matrix types.
/// @param mat The matrix from which the mapped view is created.
/// @param row The index of the row at which the view starts.
/// @param col The index of the column at which the view starts.
/// @param nrows The number of rows of the mapped view.
/// @param ncols The number of columns of the mapped view.
/// @return A `Map` with `nrows` rows and `ncols` columns onto the storage of `mat`.
template<typename Scalar, int Rows, int Cols, int Options, int MaxRows, int MaxCols>
auto blockmap(Matrix<Scalar,Rows,Cols,Options,MaxRows,MaxCols>& mat, Index row, Index col, Index nrows, Index ncols) -> Map<Matrix<Scalar,Rows,Cols,Options,MaxRows,MaxCols>, Unaligned, Stride<Rows,Cols>>
{
    // The constructed Map must use the same stride type as the declared return type;
    // using Stride<MaxRows,MaxCols> here breaks compilation when MaxRows != Rows or MaxCols != Cols.
    using StrideType = Stride<Rows,Cols>;
    using MapType = Map<Matrix<Scalar,Rows,Cols,Options,MaxRows,MaxCols>, Unaligned, StrideType>;

    const StrideType stride(mat.outerStride(), mat.innerStride());
    return MapType(mat.block(row, col, nrows, ncols).data(), nrows, ncols, stride);
}

/// Return a const block mapped view of a matrix.
/// The view is a `Map` built on the data pointer of `mat.block(row, col, nrows, ncols)`
/// and on the outer and inner strides of `mat`.
/// NOTE(review): `Stride<Rows,Cols>` reuses the compile-time sizes of the matrix as
/// compile-time strides; this looks intended for dynamic-size matrices only — confirm
/// before using it with fixed-size matrix types.
/// @param mat The matrix from which the mapped view is created.
/// @param row The index of the row at which the view starts.
/// @param col The index of the column at which the view starts.
/// @param nrows The number of rows of the mapped view.
/// @param ncols The number of columns of the mapped view.
/// @return A const `Map` with `nrows` rows and `ncols` columns onto the storage of `mat`.
template<typename Scalar, int Rows, int Cols, int Options, int MaxRows, int MaxCols>
auto blockmap(const Matrix<Scalar,Rows,Cols,Options,MaxRows,MaxCols>& mat, Index row, Index col, Index nrows, Index ncols) -> Map<const Matrix<Scalar,Rows,Cols,Options,MaxRows,MaxCols>, Unaligned, Stride<Rows,Cols>>
{
    // The constructed Map must use the same stride type as the declared return type;
    // using Stride<MaxRows,MaxCols> here breaks compilation when MaxRows != Rows or MaxCols != Cols.
    using StrideType = Stride<Rows,Cols>;
    using MapType = Map<const Matrix<Scalar,Rows,Cols,Options,MaxRows,MaxCols>, Unaligned, StrideType>;

    const StrideType stride(mat.outerStride(), mat.innerStride());
    return MapType(mat.block(row, col, nrows, ncols).data(), nrows, ncols, stride);
}

/// Create a mapped view of consecutive rows of a matrix, spanning all its columns.
/// Forwards to `blockmap` with a starting column of zero and `mat.cols()` columns.
/// @param mat The matrix from which the mapped view is created.
/// @param row The index of the row at which the view starts.
/// @param nrows The number of rows of the mapped view.
template<typename Scalar, int Rows, int Cols, int Options, int MaxRows, int MaxCols>
auto rowsmap(Matrix<Scalar,Rows,Cols,Options,MaxRows,MaxCols>& mat, Index row, Index nrows)
    -> Map<Matrix<Scalar,Rows,Cols,Options,MaxRows,MaxCols>, Unaligned, Stride<Rows,Cols>>
{
    return blockmap(mat, row, 0, nrows, mat.cols());
}

/// Create a const mapped view of consecutive rows of a matrix, spanning all its columns.
/// Forwards to `blockmap` with a starting column of zero and `mat.cols()` columns.
/// @param mat The matrix from which the mapped view is created.
/// @param row The index of the row at which the view starts.
/// @param nrows The number of rows of the mapped view.
template<typename Scalar, int Rows, int Cols, int Options, int MaxRows, int MaxCols>
auto rowsmap(const Matrix<Scalar,Rows,Cols,Options,MaxRows,MaxCols>& mat, Index row, Index nrows)
    -> Map<const Matrix<Scalar,Rows,Cols,Options,MaxRows,MaxCols>, Unaligned, Stride<Rows,Cols>>
{
    return blockmap(mat, row, 0, nrows, mat.cols());
}

/// Create a mapped view of consecutive columns of a matrix, spanning all its rows.
/// Forwards to `blockmap` with a starting row of zero and `mat.rows()` rows.
/// @param mat The matrix from which the mapped view is created.
/// @param col The index of the column at which the view starts.
/// @param ncols The number of columns of the mapped view.
template<typename Scalar, int Rows, int Cols, int Options, int MaxRows, int MaxCols>
auto colsmap(Matrix<Scalar,Rows,Cols,Options,MaxRows,MaxCols>& mat, Index col, Index ncols)
    -> Map<Matrix<Scalar,Rows,Cols,Options,MaxRows,MaxCols>, Unaligned, Stride<Rows,Cols>>
{
    return blockmap(mat, 0, col, mat.rows(), ncols);
}

/// Create a const mapped view of consecutive columns of a matrix, spanning all its rows.
/// Forwards to `blockmap` with a starting row of zero and `mat.rows()` rows.
/// @param mat The matrix from which the mapped view is created.
/// @param col The index of the column at which the view starts.
/// @param ncols The number of columns of the mapped view.
template<typename Scalar, int Rows, int Cols, int Options, int MaxRows, int MaxCols>
auto colsmap(const Matrix<Scalar,Rows,Cols,Options,MaxRows,MaxCols>& mat, Index col, Index ncols)
    -> Map<const Matrix<Scalar,Rows,Cols,Options,MaxRows,MaxCols>, Unaligned, Stride<Rows,Cols>>
{
    return blockmap(mat, 0, col, mat.rows(), ncols);
}

/// Component-wise multiplication of two matrices.
/// @param lhs The left-hand side matrix
/// @param rhs The right-hand side matrix
/// @return The expression produced by `lhs.cwiseProduct(rhs)`
template<typename DerivedLHS, typename DerivedRHS>
auto operator%(const MatrixBase<DerivedLHS>& lhs, const MatrixBase<DerivedRHS>& rhs)
    -> decltype(lhs.cwiseProduct(rhs))
{
    return lhs.cwiseProduct(rhs);
}

/// Component-wise division of two matrices.
/// @param lhs The matrix of numerators
/// @param rhs The matrix of denominators
/// @return The expression produced by `lhs.cwiseQuotient(rhs)`
template<typename DerivedLHS, typename DerivedRHS>
auto operator/(const MatrixBase<DerivedLHS>& lhs, const MatrixBase<DerivedRHS>& rhs)
    -> decltype(lhs.cwiseQuotient(rhs))
{
    return lhs.cwiseQuotient(rhs);
}

/// Component-wise division of a scalar by a matrix.
/// The result is formed as the scalar multiplied by the component-wise inverse of the matrix.
/// @param scalar The numerator applied to every entry
/// @param mat The matrix of denominators
/// @return The expression produced by `scalar*mat.cwiseInverse()`
template<typename Derived>
auto operator/(const typename Derived::Scalar& scalar, const MatrixBase<Derived>& mat)
    -> decltype(scalar*mat.cwiseInverse())
{
    return scalar*mat.cwiseInverse();
}

/// Component-wise addition of a scalar (left operand) and a matrix.
/// @param scalar The value added to every entry
/// @param mat The matrix operand
/// @return The expression produced by `(scalar + mat.array()).matrix()`
template<typename Derived>
auto operator+(const typename Derived::Scalar& scalar, const MatrixBase<Derived>& mat)
    -> decltype((scalar + mat.array()).matrix())
{
    // The sum is formed through the array interface and converted back to a matrix.
    return (scalar + mat.array()).matrix();
}

/// Component-wise addition of a matrix and a scalar (right operand).
/// @param mat The matrix operand
/// @param scalar The value added to every entry
/// @return The expression produced by `(scalar + mat.array()).matrix()`
template<typename Derived>
auto operator+(const MatrixBase<Derived>& mat, const typename Derived::Scalar& scalar)
    -> decltype((scalar + mat.array()).matrix())
{
    // The sum is formed through the array interface and converted back to a matrix.
    return (scalar + mat.array()).matrix();
}

/// Component-wise subtraction of a matrix from a scalar.
/// @param scalar The value from which every entry is subtracted
/// @param mat The matrix operand
/// @return The expression produced by `(scalar - mat.array()).matrix()`
template<typename Derived>
auto operator-(const typename Derived::Scalar& scalar, const MatrixBase<Derived>& mat)
    -> decltype((scalar - mat.array()).matrix())
{
    // The difference is formed through the array interface and converted back to a matrix.
    return (scalar - mat.array()).matrix();
}

/// Component-wise subtraction of a scalar from a matrix.
/// @param mat The matrix operand
/// @param scalar The value subtracted from every entry
/// @return The expression produced by `(mat.array() - scalar).matrix()`
template<typename Derived>
auto operator-(const MatrixBase<Derived>& mat, const typename Derived::Scalar& scalar)
    -> decltype((mat.array() - scalar).matrix())
{
    // The difference is formed through the array interface and converted back to a matrix.
    return (mat.array() - scalar).matrix();
}

/// Return the start of the coefficient storage of a vector, as given by `data()`.
/// Guarded by a vector-only static assertion on `Derived`.
/// @param vec The vector whose storage start is requested
/// @return The value of `vec.derived().data()`
template<typename Derived>
auto begin(MatrixBase<Derived>& vec)
    -> decltype(vec.derived().data())
{
    EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
    return vec.derived().data();
}

/// Return the start of the coefficient storage of a const vector, as given by `data()`.
/// Guarded by a vector-only static assertion on `Derived`.
/// @param vec The vector whose storage start is requested
/// @return The value of `vec.derived().data()`
template<typename Derived>
auto begin(const MatrixBase<Derived>& vec)
    -> decltype(vec.derived().data())
{
    EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
    return vec.derived().data();
}

/// Return the past-the-end position of the coefficient storage of a vector.
/// Computed as `begin(vec)` advanced by `vec.size()`.
/// Guarded by a vector-only static assertion on `Derived`.
/// @param vec The vector whose storage end is requested
/// @return The value of `begin(vec) + vec.size()`
template<typename Derived>
auto end(MatrixBase<Derived>& vec)
    -> decltype(begin(vec) + vec.size())
{
    EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
    return begin(vec) + vec.size();
}

/// Return the past-the-end position of the coefficient storage of a const vector.
/// Computed as `begin(vec)` advanced by `vec.size()`.
/// Guarded by a vector-only static assertion on `Derived`.
/// @param vec The vector whose storage end is requested
/// @return The value of `begin(vec) + vec.size()`
template<typename Derived>
auto end(const MatrixBase<Derived>& vec)
    -> decltype(begin(vec) + vec.size())
{
    EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
    return begin(vec) + vec.size();
}

} // namespace Eigen

#endif // EIGEN_HAS_CXX11

#endif // EIGEN_EIGENXFUNCTIONS_H
