from __future__ import print_function, division, absolute_import

from copy import copy
import nose.tools as nt

import numpy as np

import regreg.api as rr
from regreg.tests.decorators import set_seed_for_test

@set_seed_for_test()
def test_lasso_separable():
    """
    Check that an l1 (lasso) penalty written as a ``separable`` atom over
    the two halves of a 20-vector gives the same solution as

    * one ``l1norm`` acting on the whole vector, and
    * two ``linear_atom`` penalties built from ``selector`` matrices.

    Along the way the printing, conjugate, dual, seminorm, constraint and
    copy code paths of ``separable`` are exercised (their output is not
    checked).
    """
    # Draw order (design, noise, true coefficients) matches the seed setup.
    design = np.random.standard_normal((100, 20))
    noise = np.random.standard_normal((100,))
    response = noise + np.dot(design, np.random.standard_normal(20))

    left = rr.l1norm(10, lagrange=1.2)
    right = rr.l1norm(10, lagrange=1.2)
    halves = [slice(0, 10), slice(10, 20)]
    sep_penalty = rr.separable((20,), [left, right], halves,
                               test_for_overlap=True)

    # Coverage only: these calls must run without error.
    print(left.latexify())
    print(sep_penalty.latexify())
    print(sep_penalty.conjugate)
    print(sep_penalty.dual)
    print(sep_penalty.seminorm(np.ones(sep_penalty.shape)))
    print(sep_penalty.constraint(np.ones(sep_penalty.shape), bound=2.))

    duplicate = copy(sep_penalty)
    duplicate.set_quadratic(rr.identity_quadratic(1, 0, 0, 0))
    duplicate.conjugate  # property access only, for coverage

    loss = rr.quadratic_loss.affine(design, -response, coef=0.5)

    # (1) separable formulation
    sep_solver = rr.FISTA(rr.separable_problem.fromatom(sep_penalty, loss))
    sep_solver.fit(min_its=200, tol=1.0e-12)
    sep_coefs = sep_solver.composite.coefs

    # (2) a single l1norm over all 20 coefficients
    full_solver = rr.FISTA(rr.container(loss, rr.l1norm(20, lagrange=1.2)))
    full_solver.fit(min_its=100, tol=1.0e-12)
    full_coefs = full_solver.composite.coefs

    # (3) the separable pieces re-expressed through selector matrices
    selected = [rr.linear_atom(atom, rr.selector(group, (20,)))
                for atom, group in zip(sep_penalty.atoms, sep_penalty.groups)]
    sel_solver = rr.FISTA(rr.container(loss, *selected))
    sel_solver.fit(min_its=500, tol=1.0e-12)
    sel_coefs = sel_solver.composite.coefs

    for other in (full_coefs, sel_coefs):
        np.testing.assert_almost_equal(sep_coefs, other)


    
@set_seed_for_test()
def test_group_lasso_separable():
    """
    Check that a group-lasso penalty (an ``l2norm`` on each of two
    10-coefficient groups) given as a ``separable`` atom yields the same
    solution as the equivalent pair of ``linear_atom`` penalties built
    from ``selector`` matrices.
    """
    # Draw order (design, noise, true coefficients) matches the seed setup.
    design = np.random.standard_normal((100, 20))
    noise = np.random.standard_normal((100,))
    response = noise + np.dot(design, np.random.standard_normal(20))

    group_norms = [rr.l2norm(10, lagrange=.2), rr.l2norm(10, lagrange=.2)]
    blocks = [slice(0, 10), slice(10, 20)]
    sep_penalty = rr.separable((20,), group_norms, blocks)

    loss = rr.quadratic_loss.affine(design, -response, coef=0.5)

    # separable formulation
    sep_solver = rr.FISTA(rr.separable_problem.fromatom(sep_penalty, loss))
    sep_solver.fit(min_its=200, tol=1.0e-12)
    sep_coefs = sep_solver.composite.coefs

    # the same penalty expressed with selector matrices
    selected = [rr.linear_atom(atom, rr.selector(group, (20,)))
                for atom, group in zip(sep_penalty.atoms, sep_penalty.groups)]
    sel_solver = rr.FISTA(rr.container(loss, *selected))
    sel_solver.fit(min_its=200, tol=1.0e-12)

    np.testing.assert_almost_equal(sep_coefs, sel_solver.composite.coefs)


@set_seed_for_test()
def test_nonnegative_positive_part(debug=False):
    """
    Check two equivalent formulations of a nonnegative, partly penalized
    least-squares problem.

    The first uses a single ``nonnegative`` constraint carrying a linear
    term whose weights are ``lagrange`` on the first ``P - 5`` coefficients
    and zero on the last 5 (unpenalized but still nonnegative). The second
    uses ``separable`` with ``constrained_positive_part`` on the first
    ``P - 5`` coefficients and a plain ``nonnegative`` constraint on the
    last 5. The two solutions must agree to within 1% relative error.

    Parameters
    ----------
    debug : bool
        Assigned to the ``debug`` attribute of both FISTA solvers.
    """
    # N - number of data points
    # P - number of columns in design == number of betas
    N, P = 40, 30
    # trailing coefficients that are unpenalized (but still nonnegative)
    n_free = 5
    n_pen = P - n_free
    # an arbitrary positive offset for data and design
    offset = 2
    # data (drawn before the design to keep the seeded draw order)
    Y = np.random.normal(size=(N,)) + offset
    # design - with ones as last column
    X = np.ones((N, P))
    X[:, :-1] = np.random.normal(size=(N, P - 1)) + offset
    # coef for loss
    coef = 0.5
    # lagrange for penalty
    lagrange = .1

    # Loss function (squared difference between fitted and actual data)
    loss = rr.quadratic_loss.affine(X, -Y, coef=coef)

    # Penalty using nonnegative with a linear term; the last n_free
    # coefficients get zero weight, i.e. are left unpenalized.
    weights = np.ones(P) * lagrange
    weights[-n_free:] = 0
    linq = rr.identity_quadratic(0, 0, weights, 0)
    penalty = rr.nonnegative(P, quadratic=linq)

    composite_form = rr.separable_problem.singleton(penalty, loss)
    solver = rr.FISTA(composite_form)
    solver.debug = debug
    solver.fit(tol=1.0e-12, min_its=200)
    coefs = solver.composite.coefs

    # Separable version: penalize only the first n_pen coefficients with
    # constrained_positive_part, constrain the rest to be nonnegative.
    penalties_s = [rr.constrained_positive_part(n_pen, lagrange=lagrange),
                   rr.nonnegative(n_free)]
    groups_s = [slice(0, n_pen), slice(n_pen, P)]
    penalty_s = rr.separable((P,), penalties_s, groups_s)
    composite_form_s = rr.separable_problem.singleton(penalty_s, loss)
    solver_s = rr.FISTA(composite_form_s)
    solver_s.debug = debug
    solver_s.fit(tol=1.0e-12, min_its=200)
    coefs_s = solver_s.composite.coefs

    # Relative comparison written without division so an all-zero
    # solution does not turn into a nan and fail spuriously.
    rel_tol = 1.0e-02
    discrepancy = np.linalg.norm(coefs - coefs_s)
    scale = np.linalg.norm(coefs)
    assert discrepancy <= rel_tol * scale, (
        "solutions differ: |coefs - coefs_s| = %g, |coefs| = %g"
        % (discrepancy, scale))

@set_seed_for_test()
def test_different_dim():
    """
    Check the ``shapes`` argument of ``separable``.

    The first group of 10 coefficients is reshaped to a (5, 2) matrix for
    a nuclear-norm penalty; the second group is passed as-is (shape
    ``None``) to an l1 penalty. Besides exercising the printing,
    conjugate, dual, seminorm, constraint and copy code paths, the fitted
    coefficients are checked to have the penalty's shape, to be finite,
    and to do no worse than the all-zero vector on loss + penalty.
    """

    X = np.random.standard_normal((100,20))
    Y = (np.random.standard_normal((100,)) +
         np.dot(X, np.random.standard_normal(20)))

    penalty1 = rr.nuclear_norm((5, 2), lagrange=1.2)
    penalty2 = rr.l1norm(10, lagrange=1.2)
    penalty = rr.separable((20,),
                           [penalty1, penalty2],
                           [slice(0,10), slice(10,20)],
                           test_for_overlap=True,
                           shapes=[(5,2), None])

    # ensure code is tested

    print(penalty1.latexify())

    print(penalty.latexify())
    print(penalty.conjugate)
    print(penalty.dual)
    print(penalty.seminorm(np.ones(penalty.shape)))
    print(penalty.constraint(np.ones(penalty.shape), bound=2.))

    pencopy = copy(penalty)
    pencopy.set_quadratic(rr.identity_quadratic(1,0,0,0))
    pencopy.conjugate

    # solve using separable

    loss = rr.quadratic_loss.affine(X, -Y, coef=0.5)
    problem = rr.separable_problem.fromatom(penalty, loss)
    solver = rr.FISTA(problem)
    solver.fit(min_its=200, tol=1.0e-12)
    coefs = solver.composite.coefs

    # Previously the fit was never checked; at minimum the result must be
    # a finite vector of the right shape that is no worse than beta = 0.
    np.testing.assert_equal(coefs.shape, penalty.shape)
    assert np.all(np.isfinite(coefs)), \
        "separable solver returned non-finite coefficients"

    def objective(beta):
        # smooth loss plus the separable (reshaped) penalty
        return loss.smooth_objective(beta, 'func') + penalty.seminorm(beta)

    assert objective(coefs) <= objective(np.zeros(penalty.shape)) + 1.0e-8
