import uarray as ua
import pickle

import pytest  # type: ignore


@pytest.fixture(scope="function", autouse=True)
def cleanup_backends():
    """Run each test inside ``ua.reset_state()``.

    Applied automatically to every test in this module. ``reset_state`` is
    assumed to restore the previous backend state on exit, isolating backend
    changes made by a test (TODO confirm against uarray docs).
    """
    state_guard = ua.reset_state()
    with state_guard:
        yield


class Backend:
    """Minimal backend stub for the ``ua_tests`` domain.

    Tests attach a ``__ua_function__`` to individual instances to control
    what each backend returns.
    """

    __ua_domain__ = "ua_tests"


@pytest.fixture()
def nullary_mm():
    """Provide a fresh zero-argument multimethod in the ``ua_tests`` domain.

    Its extractor reports no dispatchables, and its replacer hands the
    positional and keyword arguments back unchanged.
    """
    return ua.generate_multimethod(
        lambda: (),
        lambda args, kwargs, dispatchables: (args, kwargs),
        "ua_tests",
    )


def test_nestedbackend(nullary_mm):
    """An inner backend can re-dispatch while skipping itself.

    The inner backend calls the multimethod again inside
    ``skip_backend(inner)``, so the call reaches the outer backend. The outer
    backend's sentinel must be the final result.
    """
    sentinel = object()

    outer = Backend()
    outer.__ua_function__ = lambda f, a, kw: sentinel

    def forward_to_nullary(*args, **kwargs):
        return nullary_mm(*args, **kwargs)

    mm2 = ua.generate_multimethod(
        lambda: (), lambda a, kw, d: (a, kw), "ua_tests", default=forward_to_nullary
    )

    inner = Backend()

    def redispatch_without_inner(f, a, kw):
        with ua.skip_backend(inner):
            return f(*a, **kw)

    inner.__ua_function__ = redispatch_without_inner

    with ua.set_backend(outer), ua.set_backend(inner):
        assert mm2() is sentinel


def _replacer(args, kwargs, dispatchables):
    return (args, kwargs)


@ua.create_multimethod(_replacer, "ua_tests")
def pickle_mm():
    """Zero-argument multimethod defined at module level for the pickling test."""
    return tuple()


def test_pickle_support():
    """A module-level multimethod round-trips through pickle to the identical object."""
    round_tripped = pickle.loads(pickle.dumps(pickle_mm))
    assert round_tripped is pickle_mm


def test_registration(nullary_mm):
    """A registered backend handles the call when nothing else is selected."""
    sentinel = object()
    backend = Backend()
    backend.__ua_function__ = lambda f, a, kw: sentinel

    ua.register_backend(backend)

    assert nullary_mm() is sentinel


def test_global(nullary_mm):
    """A global backend handles the call when nothing else is selected."""
    sentinel = object()
    backend = Backend()
    backend.__ua_function__ = lambda f, a, kw: sentinel

    ua.set_global_backend(backend)

    assert nullary_mm() is sentinel


def test_ctx_before_global(nullary_mm):
    """A backend selected with ``set_backend`` takes precedence over the global one.

    Once the ``set_backend`` context exits, dispatch goes back to the global
    backend.
    """
    obj = object()
    obj2 = object()
    be = Backend()
    be.__ua_function__ = lambda f, a, kw: obj

    be2 = Backend()
    be2.__ua_function__ = lambda f, a, kw: obj2

    ua.set_global_backend(be)

    with ua.set_backend(be2):
        assert nullary_mm() is obj2

    # Outside the context only the global backend remains selected.
    assert nullary_mm() is obj


# The function was originally named without the ``test_`` prefix, so pytest
# never collected it. The old name is kept as an alias for any external
# reference; pytest does not collect it because it lacks the prefix.
ctx_before_global = test_ctx_before_global


def test_global_before_registered(nullary_mm):
    """A global backend takes precedence over a registered backend."""
    global_result = object()
    registered_result = object()

    global_be = Backend()
    global_be.__ua_function__ = lambda f, a, kw: global_result

    registered_be = Backend()
    registered_be.__ua_function__ = lambda f, a, kw: registered_result

    ua.set_global_backend(global_be)
    ua.register_backend(registered_be)

    assert nullary_mm() is global_result


def test_global_only(nullary_mm):
    """A declining ``only=True`` global backend prevents fallback to registered ones."""
    fallback_result = object()

    declining = Backend()
    declining.__ua_function__ = lambda f, a, kw: NotImplemented

    fallback = Backend()
    fallback.__ua_function__ = lambda f, a, kw: fallback_result

    ua.set_global_backend(declining, only=True)
    ua.register_backend(fallback)

    with pytest.raises(ua.BackendNotImplementedError):
        nullary_mm()


def test_clear_backends(nullary_mm):
    """Clearing registered and global backends leaves nothing to dispatch to."""
    global_result = object()
    registered_result = object()

    global_be = Backend()
    global_be.__ua_function__ = lambda f, a, kw: global_result

    registered_be = Backend()
    registered_be.__ua_function__ = lambda f, a, kw: registered_result

    ua.set_global_backend(global_be)
    ua.register_backend(registered_be)

    domain = Backend.__ua_domain__
    ua.clear_backends(domain, registered=True, globals=True)

    with pytest.raises(ua.BackendNotImplementedError):
        nullary_mm()


def test_get_extractor_replacer():
    """A generated multimethod exposes the exact extractor and replacer it was built from."""

    def extract():
        return ()

    def replace(a, kw, d):
        return a, kw

    multimethod = ua.generate_multimethod(extract, replace, "ua_tests")

    assert multimethod.arg_extractor is extract
    assert multimethod.arg_replacer is replace


def test_raising_from_backend(nullary_mm):
    """Backends raising BackendNotImplementedError are collected into one error.

    If every selected backend raises, the multimethod raises its own
    BackendNotImplementedError. Its first arg is a fixed message, and each
    later arg is a ``(backend, error)`` tuple, in the order the backends were
    tried. If an outer backend succeeds after an inner one raised, its result
    is returned.
    """

    def raise_(foo):
        raise foo

    Foo = ua.BackendNotImplementedError("Foo")
    be = Backend()
    be.__ua_function__ = lambda f, a, kw: raise_(Foo)

    # BackendNotImplementedErrors are nested
    with ua.set_backend(be):
        with pytest.raises(ua.BackendNotImplementedError) as e:
            nullary_mm()

        assert (
            e.value.args[0]
            == "No selected backends had an implementation for this function."
        )
        # Exact-type check (not isinstance): the entry must be a plain tuple.
        assert type(e.value.args[1]) is tuple
        assert e.value.args[1] == (be, Foo)

    Bar = ua.BackendNotImplementedError("Bar")
    be2 = Backend()
    be2.__ua_function__ = lambda f, a, kw: raise_(Bar)
    # Errors are in the order the backends were tried (innermost first)
    with ua.set_backend(be), ua.set_backend(be2):
        with pytest.raises(ua.BackendNotImplementedError) as e:
            nullary_mm()

        assert e.value.args[1] == (be2, Bar)
        assert e.value.args[2] == (be, Foo)

    be3 = Backend()
    be3.__ua_function__ = lambda f, a, kw: "Success"
    # Can succeed after a backend has raised BackendNotImplementedError
    with ua.set_backend(be3), ua.set_backend(be):
        assert nullary_mm() == "Success"


def test_nested():
    """The same ``set_backend`` context object can be re-entered while active."""
    backend = Backend()
    backend.__ua_function__ = lambda f, a, kw: None

    ctx = ua.set_backend(backend)

    with ctx:
        with ctx:
            pass


def test_invalid():
    """Exiting ``set_backend`` contexts out of LIFO order must raise RuntimeError.

    ctx1 is entered, then ctx2, but ctx1 is exited before ctx2. ``pytest.raises``
    only checks that some RuntimeError escapes the block. Which ``__exit__``
    call raises it is not pinned down here.
    """
    be1 = Backend()
    be1.__ua_function__ = lambda f, a, kw: None

    be2 = Backend()
    be2.__ua_function__ = lambda f, a, kw: None

    ctx1 = ua.set_backend(be1)
    ctx2 = ua.set_backend(be2)

    with pytest.raises(RuntimeError):
        # Dunder methods are called by hand because a ``with`` statement
        # cannot express the misordered exit.
        try:
            ctx1.__enter__()
            try:
                ctx2.__enter__()
            finally:
                # Exit the outer context while the inner one is still active.
                ctx1.__exit__(None, None, None)
        finally:
            # Attempted even if the exit above raised.
            ctx2.__exit__(None, None, None)


def test_skip_comparison(nullary_mm):
    """``skip_backend`` is expected to skip any backend that compares equal.

    ``Backend2`` claims equality with ``target``. Skipping a ``Backend2``
    instance therefore skips ``target``, leaving no implementation.
    """
    target = Backend()
    target.__ua_function__ = lambda f, a, kw: None

    class Backend2(Backend):
        @staticmethod
        def __ua_function__(f, a, kw):
            pass

        def __eq__(self, other):
            return other is self or other is target

    with pytest.raises(ua.BackendNotImplementedError):
        with ua.set_backend(target), ua.skip_backend(Backend2()):
            nullary_mm()


def test_skip_raises(nullary_mm):
    """An exception raised by a skipped backend's ``__eq__`` propagates as-is."""
    active = Backend()
    active.__ua_function__ = lambda f, a, kw: None

    expected = Exception("Foo")

    class Backend2(Backend):
        @staticmethod
        def __ua_function__(f, a, kw):
            pass

        def __eq__(self, other):
            raise expected

    with pytest.raises(Exception) as excinfo:
        with ua.set_backend(active), ua.skip_backend(Backend2()):
            nullary_mm()

    assert excinfo.value is expected


def test_getset_state(cleanup_backends):
    """State captured with ``get_state`` can be reinstated with ``set_state``.

    The captured state differs from the state after the contexts exit. Inside
    ``set_state``, the first two pickled components match the captured ones.
    """
    ua.set_global_backend(Backend())
    ua.register_backend(Backend())

    with ua.set_backend(Backend()), ua.skip_backend(Backend()):
        captured = ua.get_state()

    pickled = captured._pickle()
    assert pickled != ua.get_state()._pickle()

    with ua.set_state(captured):
        # Only the first two components of the pickled state are compared.
        current = ua.get_state()._pickle()
        assert pickled[:2] == current[:2]


class ComparableBackend(Backend):
    """Test backend whose equality is defined by a wrapped value.

    Two instances compare equal when both are ``ComparableBackend`` and their
    ``obj`` attributes compare equal. This lets state restored from a pickle
    be compared with the original.
    """

    def __init__(self, obj):
        super().__init__()
        self.obj = obj

    def __eq__(self, other):
        return isinstance(other, ComparableBackend) and self.obj == other.obj

    def __ne__(self, other):
        return not (self == other)

    def __hash__(self):
        # Defining __eq__ alone sets __hash__ to None, making instances
        # unhashable. Hash the same value equality uses, so equal backends
        # hash equally.
        return hash(self.obj)


def test_pickle_state():
    """Captured backend state survives a pickle round trip."""
    ua.set_global_backend(ComparableBackend("a"))
    ua.register_backend(ComparableBackend("b"))

    with ua.set_backend(ComparableBackend("c")):
        with ua.skip_backend(ComparableBackend("d")):
            state = ua.get_state()

    restored = pickle.loads(pickle.dumps(state))

    assert state._pickle() == restored._pickle()
