"""
Plane wave quantum noise tests.
"""

import pytest
import numpy as np
from scipy import constants
from numpy.testing import assert_allclose


@pytest.fixture
def lossy_mirror_model(model):
    """Laser illuminating a lossy mirror (R + T = 0.9 < 1), with noise detectors.

    Quantum-noise detectors are placed on the laser output (``qd_laser``), the
    mirror's ``p2`` output (``qd_trans``) and its ``p1`` output (``qd_refl``).
    No ``fsig`` is declared here; tests that need one must add it themselves.
    """
    kat_script = """
        l l1 P=1
        s s1 l1.p1 m1.p1 L=0
        m m1 R=0.5 T=0.4

        qnoised qd_laser l1.p1.o
        qnoised qd_trans m1.p2.o
        qnoised qd_refl m1.p1.o
    """
    model.parse(kat_script)
    return model


@pytest.fixture
def modulator_model(model):
    """Laser passing through an amplitude modulator (``mod_type=am``).

    ``qd1`` is a plain ``qnoised`` detector on the modulator output.
    ``qd2`` is a ``qnoised1`` on the same port with an extra demodulation at
    ``&mod1.f``; the trailing ``0`` is presumably the demodulation phase.
    A signal frequency is declared with ``fsig(1)``.
    """
    kat_script = """
        l l1 P=1
        s s1 l1.p1 mod1.p1 L=0
        mod mod1 f=10M midx=0.8 mod_type=am

        qnoised qd1 mod1.p2.o
        qnoised1 qd2 mod1.p2.o &mod1.f 0

        fsig(1)
    """
    model.parse(kat_script)
    return model


@pytest.fixture
def squeezer_model(model):
    """Squeezer -> modulator -> two mirrors with a free-mass end mirror.

    Layout:
    - ``sqz`` (squeezing parameter 10, presumably dB) feeds modulator ``mod``
      (f=10M, midx=0.5).
    - ``mod`` feeds mirror ``m1``.
    - ``m1`` faces ``m2`` across a 30 m space ``scav``.
    - ``m2.mech`` is attached to a free mass of 1 (mass units as Finesse defines).

    All quantum-noise detectors look at ``m1.p1.o``, and ``fsig(1)`` is declared.
    """
    # NOTE(review): qd4/qd5 are exact duplicates of qd2/qd3 (same port, same
    # demodulation frequency, same trailing argument). Possibly a different
    # demod phase was intended — confirm before relying on them.
    kat_script = """
        sq sqz 10
        s s1 sqz.p1 mod.p1 L=5
        mod mod f=10M midx=0.5
        s s2 mod.p2 m1.p1 L=5
        m m1 R=0.9 T=0.1
        s scav m1.p2 m2.p1 L=30
        m m2 R=0.9 T=0.1

        free_mass fm m2.mech mass=1

        qnoised qd1 m1.p1.o
        qnoised1 qd2 m1.p1.o &mod.f 0
        qnoised1 qd3 m1.p1.o -&mod.f 0
        qnoised1 qd4 m1.p1.o &mod.f 0
        qnoised1 qd5 m1.p1.o -&mod.f 0

        fsig(1)
    """
    model.parse(kat_script)
    return model


def test_vacuum_quantum_noise(lossy_mirror_model):
    """Check laser vacuum quantum noise at the source and after a lossy mirror.

    The laser frequency is swept logarithmically from 1 Hz to 1 GHz. The
    expected ASD at the mirror's transmitted/reflected detectors is the laser
    ASD scaled by sqrt(T) / sqrt(R) respectively.
    """
    kat = lossy_mirror_model
    kat.parse(
        """
        xaxis(l1.f, log, 1, 1e9, 100)  # Sweep laser frequency from 1Hz to 1GHz
        """
    )

    # The fixture declares no fsig, so detecting quantum noise must fail here.
    with pytest.raises(Exception):
        kat.run()

    kat.parse("fsig(1)")
    out = kat.run()
    laser_freq = out.x[0]
    f0 = kat.f0

    # Per-sideband quantum noise produced by the laser at each sweep point.
    sideband_qn = 0.5 * (1 + laser_freq / f0)

    # Scale factor turning the generated noise into an ASD:
    #   2      number of sidebands being added
    #   P      laser power
    #   4      origin undocumented
    #   2      compensates the 0.5 demod factor at the signal frequency
    #   0.25   factor 0.25 per demodulation
    # TODO: the 4 and 0.25 cancel; their justification is still unclear.
    scale = 2 * kat.l1.P.value * 4 * 2 * 0.25
    laser_asd = np.sqrt(scale * sideband_qn * constants.h * f0)

    expectations = (
        ("qd_laser", laser_asd),
        ("qd_trans", laser_asd * np.sqrt(kat.m1.T.value)),
        ("qd_refl", laser_asd * np.sqrt(kat.m1.R.value)),
    )
    for detector, expected_asd in expectations:
        assert_allclose(out[detector], expected_asd, rtol=1e-14, atol=0)


def test_modulator_quantum_noise_sweep_laser_freq(modulator_model):
    """Check modulator quantum-noise coupling while sweeping the laser frequency.

    The laser frequency is swept logarithmically from 1 Hz to 1 GHz. Only the
    plain ``qd1`` detector is checked against an analytic expectation.
    """
    kat = modulator_model
    kat.parse(
        """
        xaxis(l1.f, log, 1, 1e9, 100)  # Sweep laser frequency from 1Hz to 1GHz
        """
    )
    out = kat.run()

    f0 = kat.f0
    laser_freq = out.x[0]
    mod_index = kat.mod1.midx.value

    # Relative power weighting of the carrier and the two modulation sidebands
    # produced by the amplitude modulator.
    power_weight = (1 - 0.5 * mod_index) ** 2 + 2 * (0.25 * mod_index) ** 2
    # Laser quantum noise per sideband, weighted by the modulated powers.
    total_qn = 0.5 * (1 + laser_freq / f0) * power_weight

    # Scale factor turning the generated noise into an ASD:
    #   2      number of sidebands being added
    #   P      laser power
    #   4      origin undocumented
    #   2      compensates the 0.5 demod factor at the signal frequency
    #   0.25   factor 0.25 per demodulation
    # TODO: the 4 and 0.25 cancel; their justification is still unclear.
    scale = 2 * kat.l1.P.value * 4 * 2 * 0.25
    expected_asd = np.sqrt(scale * total_qn * constants.h * f0)

    assert_allclose(out["qd1"], expected_asd, rtol=1e-14, atol=0)

    # TODO: work out the correct analytics for this
    # assert_allclose(out["qd2"], qn_asd_upper, rtol=1e-14, atol=0)


def test_modulator_quantum_noise_sweep_modulation_index(modulator_model):
    """
    Test modulator quantum noise couplings with varying modulation index.

    The modulation index is swept linearly from 0 to 1 while the laser stays at
    its fixture frequency. Only the plain ``qd1`` detector is checked against
    an analytic expectation.
    """
    modulator_model.parse(
        """
        xaxis(mod1.midx, lin, 0, 1, 100)
        """
    )

    out = modulator_model.run()
    midx = out.x[0]

    # Quantum noise sidebands generated by the laser. Derive this from the
    # laser's actual frequency offset instead of hard-coding 0.5, which is only
    # correct when l1.f == 0. This keeps the expectation consistent with the
    # laser-frequency sweep test.
    f0 = modulator_model.f0
    qn = 0.5 * (1 + modulator_model.l1.f.value / f0)

    # Combine the quantum noise from each carrier with the correct modulated powers
    qn_tot = qn * ((1 - 0.5 * midx) ** 2 + 2 * (0.25 * midx) ** 2)

    # Convert generated qn to an ASD
    # Factors:
    #   2 (number of sidebands being added)
    # * laser_power
    # * 4 (Where does this come from?)
    # * 2 (compensation for 0.5 demod factor when demodulating at signal frequency)
    # * 0.25 (factor 0.25 per demodulation)
    # TODO: Won't the 4 and 0.25 always cancel each other out? What's
    # the justification for these?
    fac = 2 * modulator_model.l1.P.value * 4 * 2 * 0.25
    qn_asd = np.sqrt(fac * qn_tot * constants.h * f0)

    # atol=0 made explicit to match the other tests (numpy's default is also 0),
    # so the comparison is purely relative.
    assert_allclose(out["qd1"], qn_asd, rtol=1e-14, atol=0)

    # TODO: work out the correct analytics for this
    # assert_allclose(out["qd2"], qn_asd_upper, rtol=1e-14, atol=0)


@pytest.mark.skip(reason="Not valid until finesse 3 handles this properly")
def test_modulator_quantum_noise_squeezer(squeezer_model):
    """Check modulator quantum-noise coupling with a squeezed-light input.

    The squeezer frequency is swept linearly from -10 Hz to +10 Hz. Detectors
    ``qd1``..``qd3`` are compared against the same expectation.
    """
    kat = squeezer_model
    kat.parse(
        """
        xaxis(sqz.f, lin, -10, 10, 20)  # Sweep squeezer frequency from -10Hz to +10Hz
        """
    )
    out = kat.run()

    sqz_freq = out.x[0]
    expected = 0.5 * (1 + sqz_freq / kat.f0) * constants.h

    # NOTE(review): `expected` is of order h (~6.6e-34), so atol=1e-8 lets
    # almost any small output pass. The analytic expectation itself also looks
    # provisional. Revisit both when un-skipping this test.
    for detector in ("qd1", "qd2", "qd3"):
        assert_allclose(out[detector], expected, rtol=1e-5, atol=1e-8)
