"""A components sub-module containing classes for detecting intensity fluctuations at a
physical point in a model.

These Readout components essentially describe baseband and broadband detectors such as
DC and RF demodulated photodiodes typically used in optical experiments.
"""

import numpy as np
import types
from collections import defaultdict
import finesse

from finesse.components.general import Connector, borrows_nodes
from finesse.components.node import Node, NodeDirection, NodeType, Port
from finesse.components.workspace import ConnectorWorkspace
from finesse.parameter import float_parameter
from finesse.element import ModelElement
from finesse.detectors import pdtypes

from finesse.detectors.compute.quantum import (
    QShot0Workspace,
    QShotNWorkspace,
)

# Shared numpydoc "Parameters" fragment describing the constructor arguments
# that are common to the readout components defined in this module.
doc_readout_param = """
Parameters
----------
name : str
    Name of readout element
optical_node : Node
    Node object which this readout element should look at
pdtype : str, dict
    A name of a pdtype definition or a dict representing a pdtype definition
"""


class ReadoutWorkspace(ConnectorWorkspace):
    """Connector workspace used by the readout components in this module.

    It adds nothing to :class:`ConnectorWorkspace`; the readouts attach the extra
    attributes they need (e.g. ``frequencies``, ``is_segmented``) to the instance
    when they build it in their ``_get_workspace`` methods.
    """


class _Readout(Connector):
    """Abstract class that provides basic functionality similar to all Readouts.
    Underscore because users should not be accessing it directly.

    Parameters
    ----------
    name : str
        Name of readout element
    optical_node : Node
        Node object which this readout element should look at. When `None`,
        the readout creates its own input and output nodes on its `p1` port.
    pdtype : str, dict
        A name of a pdtype definition or a dict representing a pdtype definition
    output_detectors : bool
        Flag stored on the readout and exposed through the
        `output_detectors` property.
    """

    def __init__(
        self, name: str, optical_node: Node, pdtype=None, output_detectors: bool = False
    ):
        super().__init__(name)
        self.pdtype = pdtype
        self.__output_detectors = output_detectors
        self._add_port("p1", NodeType.OPTICAL)

        if optical_node is not None:
            # Re-use the nodes of the port being looked at rather than creating
            # new ones: "i" is the node we were given, "o" is the first other
            # node found on the same port.
            # NOTE(review): when a Port (rather than a Node) is passed, the Port
            # object itself is handed to `_add_node` as the "i" node and
            # `other_node` is simply the port's first node -- confirm that this
            # is the intended behaviour.
            port = optical_node if isinstance(optical_node, Port) else optical_node.port
            other_node = tuple(o for o in port.nodes if o is not optical_node)[0]

            self.p1._add_node("i", None, optical_node)
            self.p1._add_node("o", None, other_node)
        else:
            # Nothing to look at yet: give the readout its own optical nodes.
            self.p1._add_node("i", NodeDirection.INPUT)
            self.p1._add_node("o", NodeDirection.OUTPUT)

    def _on_add(self, model):
        """Check that the readout is being added to the model its port belongs to.

        Raises
        ------
        Exception
            If `model` is not the model of the `p1` port.
        """
        if model is not self.p1._model:
            # Report the node this readout looks at. The message previously read
            # `self.node`, an attribute this class never defines.
            raise Exception(
                f"{repr(self)} is using a node {self.p1.i} from a different model"
            )

    def _get_output_workspaces(self, model):
        """Return the output workspaces of this readout; the base class has none."""
        return None

    @property
    def optical_node(self):
        """The node being looked at, or `None` if the input node of `p1` is owned
        by this readout itself."""
        if self.p1.i.component != self:
            return self.p1.i
        return None

    @property
    def has_mask(self):
        """Always `False` for readouts."""
        return False

    @property
    def output_detectors(self):
        """The `output_detectors` flag given at construction (or set since)."""
        return self.__output_detectors

    @output_detectors.setter
    def output_detectors(self, value: bool):
        self.__output_detectors = value


class ReadoutDetectorOutput(ModelElement):
    """Placeholder model element standing in for a detector output that is
    generated by a Readout element.

    Notes
    -----
    Users should not create these directly; a Readout component creates them
    internally and adds them to the model.

    Parameters
    ----------
    name : str
        Name of the element.
    readout : _Readout
        The readout this output belongs to.
    """

    def __init__(self, name: str, readout: _Readout):
        super().__init__(name)
        # Kept private (name-mangled); read access goes through `readout`.
        self.__readout = readout

    @property
    def readout(self) -> _Readout:
        """The readout object that was passed in at construction."""
        return self.__readout


@borrows_nodes()
class ReadoutDC(_Readout):
    """A Readout component which represents a photodiode measuring the intensity of
    some incident field. Audio band intensity signals present in the incident optical
    field are converted into an electrical signal and output at the `self.DC` port,
    which has a single `self.DC.o` node.

    Parameters
    ----------
    name : str
        Name of readout element
    optical_node : Node
        Node object which this readout element should look at
    pdtype : str, dict
        A name of a pdtype definition or a dict representing a pdtype definition
    output_detectors : bool
        Flag stored on the readout and exposed through the
        `output_detectors` property.
    """

    def __init__(
        self,
        name: str,
        optical_node: Node = None,
        pdtype=None,
        output_detectors: bool = False,
    ):
        super().__init__(
            name, optical_node, pdtype=pdtype, output_detectors=output_detectors
        )

        # Replace the raw value stored by the base class with whatever
        # `get_pdtype` resolves it to.
        self.pdtype = pdtypes.get_pdtype(pdtype)

        self._add_port("DC", NodeType.ELECTRICAL)
        self.DC._add_node("o", NodeDirection.OUTPUT)

        self._register_node_coupling("P1i_DC", self.p1.i, self.DC.o)

        # Names of the detector outputs this readout provides.
        self.outputs = types.SimpleNamespace()
        self.outputs.DC = f"{self.name}_DC"

    def _on_add(self, model):
        """Validate the model and add the placeholder element for the DC output."""
        super()._on_add(model)
        model.add(ReadoutDetectorOutput(f"{self.name}_DC", self))

    def _get_workspace(self, sim):
        """Build the signal-simulation workspace for this readout.

        Returns `None` when there is no signal simulation or when the `DC.o`
        node is not part of it.
        """
        if not sim.signal:
            return None

        if self.DC.o.full_name not in sim.signal.nodes:
            return None  # Don't do anything if no nodes included

        ws = ReadoutWorkspace(self, sim)
        # -1 forces the first call of the fill function to do the fill
        ws.prev_carrier_solve_num = -1
        ws.I = np.eye(sim.model_settings.num_HOMs, dtype=np.complex128)
        ws.signal.add_fill_function(self._fill_matrix, True)
        ws.frequencies = sim.signal.signal_frequencies[self.DC.o].frequencies
        ws.is_segmented = self.pdtype is not None
        if ws.is_segmented:
            ws.K = pdtypes.construct_segment_beat_matrix(
                sim.model.mode_index_map, self.pdtype  # , sparse_output=True
            )
        return ws

    def _get_output_workspaces(self, sim):
        """Build the detector output workspaces provided by this readout.

        A DC power output named ``<name>_DC`` is always created; a shot noise
        output named ``<name>_shot_noise`` is added when a signal simulation is
        present.
        """
        # Imported here rather than at module level, as in the original code
        # (presumably to avoid circular imports -- TODO confirm).
        from finesse.detectors import PowerDetector, QuantumShotNoiseDetector
        from finesse.detectors.workspace import OutputInformation
        from finesse.detectors.compute.power import PD0Workspace

        # NOTE(review): `self.output_detectors` is not consulted here, the
        # workspaces below are always built -- confirm that this is intended.
        wss = []

        # Setup a DC output photodiode detector for
        # using for outputs
        oinfo = OutputInformation(
            self.name + "_DC",
            PowerDetector,
            (self.p1.i,),
            np.float64,
            "W",
            None,
            "W",
            True,
            False,
        )
        ws = PD0Workspace(self, sim, oinfo=oinfo, pdtype=self.pdtype)
        wss.append(ws)

        if sim.signal:
            oinfo = OutputInformation(
                self.name + "_shot_noise",
                QuantumShotNoiseDetector,
                (self.p1.i,),
                np.float64,
                "W/rtHz",
                None,
                "ASD",
                True,
                False,
            )
            wss.append(QShot0Workspace(self, sim, False, output_info=oinfo))

        return wss

    def _fill_matrix(self, ws):
        """Fill the optical-to-electrical couplings of the signal matrix.

        Computing E.conj() * upper + E * lower.conj()
        """
        carrier = ws.sim.carrier
        # if the previous fill was done with this carrier then there
        # is no need to refill it...
        if ws.prev_carrier_solve_num == carrier.num_solves:
            return

        conn_idx = ws.signal.connections.P1i_DC_idx
        # The connection index does not change during the fill, so check it
        # once instead of for every frequency combination.
        if conn_idx > -1:
            num_homs = ws.sim.model_settings.num_HOMs

            for freq in ws.sim.signal.optical_frequencies.frequencies:
                # Get the carrier HOMs for this frequency
                cidx = freq.audio_carrier_index
                rhs_idx = carrier.field(self.p1.i, cidx, 0)
                Ec = np.conjugate(carrier.out_view[rhs_idx : (rhs_idx + num_homs)])
                # The values written are the same for every electrical
                # frequency, so compute them once per optical frequency.
                fill = np.dot(ws.K, Ec) if ws.is_segmented else Ec

                for efreq in ws.frequencies:
                    with ws.sim.signal.component_edge_fill3(
                        ws.owner_id,
                        conn_idx,
                        freq.index,
                        efreq.index,
                    ) as mat:
                        mat[:] = fill

        # store what carrier solve number this fill was done with
        ws.prev_carrier_solve_num = carrier.num_solves


@borrows_nodes()
@float_parameter("f", "Frequency")
@float_parameter("phase", "Phase")
class ReadoutRF(_Readout):
    """A Readout component which represents a demodulated photodiode measuring the
    intensity of some incident field. The electrical signals are output at the
    `self.I` and `self.Q` ports, each of which has a single `o` node. The Q
    quadrature uses the demodulation phase offset by 90 degrees.

    Parameters
    ----------
    name : str
        Name of readout element
    optical_node : Node
        Node object which this readout element should look at
    f : float
        Demodulation frequency
    phase : float
        Demodulation phase, in degrees
    output_detectors : bool
        Flag stored on the readout and exposed through the
        `output_detectors` property.
    pdtype : str, dict
        A name of a pdtype definition or a dict representing a pdtype definition
    """

    def __init__(
        self,
        name: str,
        optical_node: Node = None,
        *,
        f=None,
        phase=0,
        output_detectors: bool = False,
        pdtype=None,
    ):
        super().__init__(
            name, optical_node, pdtype=pdtype, output_detectors=output_detectors
        )

        # Resolve the pdtype in the same way ReadoutDC does, so that a pdtype
        # given by name is looked up before it is used to build the segment
        # beat matrix.
        self.pdtype = pdtypes.get_pdtype(pdtype)

        self.f = f
        self.phase = phase

        self._add_port("I", NodeType.ELECTRICAL)
        self.I._add_node("o", NodeDirection.OUTPUT)
        self._add_port("Q", NodeType.ELECTRICAL)
        self.Q._add_node("o", NodeDirection.OUTPUT)

        self._register_node_coupling("P1i_I", self.p1.i, self.I.o)
        self._register_node_coupling("P1i_Q", self.p1.i, self.Q.o)

        # Names of the detector outputs this readout provides.
        self.outputs = types.SimpleNamespace()
        self.outputs.I = f"{self.name}_I"
        self.outputs.Q = f"{self.name}_Q"

    # `optical_node` is inherited from `_Readout`; the identical override that
    # used to live here was redundant.

    def _on_add(self, model):
        """Validate the model and add placeholder elements for the I/Q outputs."""
        super()._on_add(model)
        model.add(ReadoutDetectorOutput(self.name + "_I", self))
        model.add(ReadoutDetectorOutput(self.name + "_Q", self))

    def _get_workspace(self, sim):
        """Build the signal-simulation workspace for this readout.

        Returns `None` when there is no signal simulation or when neither the
        `I.o` nor the `Q.o` node is part of it.
        """
        if not sim.signal:
            return None

        has_I_node = self.I.o.full_name in sim.signal.nodes
        has_Q_node = self.Q.o.full_name in sim.signal.nodes

        if not (has_I_node or has_Q_node):
            return None  # Don't do anything if no nodes included

        ws = ReadoutWorkspace(self, sim)
        # -1 forces the first call of the fill function to do the fill
        ws.prev_carrier_solve_num = -1
        ws.signal.add_fill_function(self._fill_matrix, True)
        ws.frequencies = sim.signal.signal_frequencies[
            self.I.o if has_I_node else self.Q.o
        ].frequencies
        ws.dc_node_id = sim.carrier.node_id(self.p1.i)
        ws.is_segmented = self.pdtype is not None
        if ws.is_segmented:
            ws.K = pdtypes.construct_segment_beat_matrix(
                sim.model.mode_index_map, self.pdtype  # , sparse_output=True
            )
        return ws

    def _get_output_workspaces(self, sim):
        """Build the detector output workspaces provided by this readout.

        Demodulated power outputs named ``<name>_I`` and ``<name>_Q`` are always
        created; a shot noise output named ``<name>_shot_noise`` is added when a
        signal simulation is present.
        """
        # Imported here rather than at module level, as in the original code
        # (presumably to avoid circular imports -- TODO confirm).
        from finesse.detectors import (
            PowerDetectorDemod1,
            QuantumShotNoiseDetectorDemod1,
        )
        from finesse.detectors.workspace import OutputInformation
        from finesse.detectors.compute.power import PD1Workspace

        # NOTE(review): `self.output_detectors` is not consulted here, the
        # workspaces below are always built -- confirm that this is intended.
        wss = []
        for quadrature in ("I", "Q"):
            # Setup a single demodulation photodiode detector for
            # using for outputs
            oinfo = OutputInformation(
                self.name + "_" + quadrature,
                PowerDetectorDemod1,
                (self.p1.i,),
                np.float64,
                "W",
                None,
                "W",
                True,
                False,
            )
            # Q is demodulated 90 degrees out of phase with I
            poff = 90 if quadrature == "Q" else 0
            ws = PD1Workspace(
                self,
                sim,
                self.f,
                self.phase,
                phase_offset=poff,
                oinfo=oinfo,
                pdtype=self.pdtype,
            )
            wss.append(ws)

        if sim.signal:
            oinfo = OutputInformation(
                self.name + "_shot_noise",
                QuantumShotNoiseDetectorDemod1,
                (self.p1.i,),
                np.float64,
                "W/rtHz",
                None,
                "ASD",
                True,
                False,
            )
            wss.append(QShotNWorkspace(self, sim, 1, False, output_info=oinfo))
        return wss

    @staticmethod
    def _append_beat_terms(ws, terms, quadratures, f1, f2, E1c, E2c, conjugate_lower):
        """Append the lower and upper sideband terms of one carrier pair.

        For every ``(connection index, factor)`` entry of `quadratures` a term
        for the lower audio sideband of `f2` (built from `E1c`) and one for the
        upper audio sideband of `f1` (built from `E2c`) is appended to `terms`.
        `conjugate_lower` selects which of the two terms is scaled by the
        conjugated factor; the other one uses the factor as is.
        """
        get_info = ws.sim.carrier.optical_frequencies.get_info

        for conn_idx, factor in quadratures:
            if conjugate_lower:
                lower_factor, upper_factor = factor.conjugate(), factor
            else:
                lower_factor, upper_factor = factor, factor.conjugate()

            key = (ws.owner_id, conn_idx, get_info(f2.index)["audio_lower_index"])
            terms[key].append(lower_factor * E1c)

            key = (ws.owner_id, conn_idx, get_info(f1.index)["audio_upper_index"])
            terms[key].append(upper_factor * E2c)

    def _fill_matrix(self, ws):
        """Fill the optical-to-electrical couplings of the signal matrix.

        Every pair of carrier frequencies whose difference equals plus or minus
        the demodulation frequency contributes terms to the I and Q couplings.
        The terms are collected per matrix block, summed and then written.
        """
        carrier = ws.sim.carrier
        # if the previous fill was done with this carrier then there
        # is no need to refill it...
        if ws.prev_carrier_solve_num == carrier.num_solves:
            return

        # extra factor of two we do not apply here as we work
        # directly with amplitudes from the matrix solution
        # need one half gain from demod. Other factor of two from
        # signal scaling and 0.5 from second demod cancel out
        factorI = (
            0.5
            * ws.sim.model_settings.EPSILON0_C
            * np.exp(-1j * ws.values.phase * finesse.constants.DEG2RAD)
        )
        factorQ = (
            0.5
            * ws.sim.model_settings.EPSILON0_C
            * np.exp(-1j * (ws.values.phase + 90) * finesse.constants.DEG2RAD)
        )

        # Only the quadratures that have a connection in the signal matrix are
        # filled; I is kept before Q so terms are collected in the same order
        # as before.
        connections = ws.signal.connections
        quadratures = tuple(
            (conn_idx, factor)
            for conn_idx, factor in (
                (connections.P1i_I_idx, factorI),
                (connections.P1i_Q_idx, factorQ),
            )
            if conn_idx >= 0
        )

        # Conjugated carrier HOMs for each frequency, computed once per
        # frequency instead of once per frequency pair.
        num_homs = ws.sim.model_settings.num_HOMs
        frequencies = tuple(carrier.optical_frequencies.frequencies)
        conj_fields = []
        for freq in frequencies:
            rhs_idx = carrier.field(self.p1.i, freq.index, 0)
            conj_fields.append(
                np.conjugate(carrier.out_view[rhs_idx : (rhs_idx + num_homs)])
            )

        demod_f = ws.values.f
        terms = defaultdict(list)

        for f1, E1c in zip(frequencies, conj_fields):
            for f2, E2c in zip(frequencies, conj_fields):
                df = f1.f - f2.f

                # Two separate checks (not elif): both apply when the
                # demodulation frequency is zero.
                if df == -demod_f:
                    self._append_beat_terms(
                        ws, terms, quadratures, f1, f2, E1c, E2c, False
                    )

                if df == demod_f:
                    self._append_beat_terms(
                        ws, terms, quadratures, f1, f2, E1c, E2c, True
                    )

        for key, values in terms.items():
            total = sum(values)
            if ws.is_segmented:
                total = np.dot(ws.K, total)

            with ws.sim.signal.component_edge_fill3(*key, 0) as mat:
                mat[:] = total
        # store previous carrier solve number this fill was done with
        # so we don't have to repeat it
        ws.prev_carrier_solve_num = carrier.num_solves
