from numpy.testing import assert_allclose

from optimix.variables import merge_variables
from optimix.variables import Variables
from optimix import Scalar

def test_variables_set():
    """Check Scalar aliasing and ``Variables.set``.

    First, an alias of a Scalar must be the very same object (so a value
    assigned through one name is seen through the other, and ``raw`` is
    shared). Then ``Variables.set`` with a mapping must update the named
    entry so that ``get('a')`` compares close to the new value.
    """
    scalar = Scalar(1.0)
    alias = scalar
    scalar.value = 2.0
    assert scalar is alias
    assert scalar.raw is alias.raw

    variables = Variables({'a': Scalar(1.0), 'b': Scalar(1.5)})
    variables.set({'a': 0.5})
    assert_allclose(variables.get('a'), 0.5)

def test_variables_str():
    """Check the multi-line string rendering of a two-entry Variables."""
    variables = Variables({'a': Scalar(1.0), 'b': Scalar(1.5)})
    # Continuation entries are indented to line up after "Variables(".
    expected = """Variables(a=Scalar(1.0),
          b=Scalar(1.5))"""
    rendered = str(variables)
    assert rendered == expected

def test_variables_merge():
    """Check that ``merge_variables`` exposes the original Scalars.

    After merging under the prefixes ``a`` and ``b``, mutating ``a0``
    through the source collection must be visible through the merged
    collection's dotted key ``a.a0``.
    """
    first = Variables(a0=Scalar(1.0))
    second = Variables(b0=Scalar(1.0))
    merged = merge_variables({'a': first, 'b': second})

    target = first.get('a0')
    target.value += 1.0

    assert first.get('a0').value == 2.0
    assert first.get('a0').value == merged.get('a.a0').value

def test_variables_setattr():
    """Check item assignment on Variables.

    Despite the name, this exercises ``variables[key] = ...`` (item
    assignment), not attribute assignment. A newly added entry must be
    retrievable and mutable, and the pre-existing entry must be untouched.
    """
    variables = Variables(a0=Scalar(1.0))

    variables['a1'] = Scalar(2.0)
    added = variables['a1']
    added.value += 1.0

    assert variables.get('a0').value == 1.0
    assert variables.get('a1').value == 3.0

if __name__ == '__main__':
    # Allow running this file directly; "-s" disables output capturing.
    import pytest

    pytest.main([__file__, '-s'])
