#!/home/conda/feedstock_root/build_artifacts/bld/rattler-build_gstlal-ugly_1767612080/host_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_pl/bin/python
#
# Copyright (C) 2016  Chad Hanna
# Copyright (C) 2019  Patrick Godwin
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


from collections import deque
import logging
from optparse import OptionParser
import sys

import numpy
from scipy import signal
import yaml

from ligo.scald import utils
from ligo.scald.io import influx, kafka

from gstlal import datasource, pipeio, psd
from gstlal import pipeparts
from gstlal.stream import MessageType, Stream

#
# =============================================================================
#
#                                 Command Line
#
# =============================================================================
#

def parse_command_line():
	"""Parse command line options for the online DQ monitoring job.

	Returns:
		tuple: (options, filenames) as produced by OptionParser.parse_args().
	"""
	parser = OptionParser(description = __doc__)

	# generic "source" options
	datasource.append_options(parser)

	# add our own options
	# NOTE: help strings fixed to agree with the actual defaults (4096 Hz, 16 s)
	parser.add_option("--sample-rate", metavar = "Hz", default = 4096, type = "int", help = "Sample rate at which to generate the PSD, default 4096 Hz")
	parser.add_option("--psd-fft-length", metavar = "s", default = 16, type = "int", help = "FFT length, default 16s")
	parser.add_option("--scald-config", metavar = "path", help = "sets ligo-scald options based on yaml configuration.")
	parser.add_option("--output-kafka-server", metavar = "addr", help = "Set the server address and port number for output data. Optional")
	parser.add_option("--analysis-tag", metavar = "tag", default = "test", help = "Set the string to identify the analysis in which this job is part of. Used when --output-kafka-server is set. May not contain \".\" nor \"-\". Default is test.")
	# boolean flag; metavar removed since it is meaningless for store_true options
	parser.add_option("--injection-channel", action = "store_true", help = "if True, compute noise and range history from injection channels.")
	parser.add_option("--reference-psd", metavar = "filename", help = "Load spectrum from this LIGO light-weight XML file. The noise spectrum will be measured and tracked starting from this reference. (optional).")
	parser.add_option("--horizon-approximant", type = "string", default = "IMRPhenomD", help = "Specify a waveform approximant to use while calculating the horizon distance and range. Default is IMRPhenomD.")
	parser.add_option("--horizon-f-min", metavar = "Hz", type = "float", default = 15., help = "Set the frequency at which the waveform model is to begin for the horizon distance and range calculation. Default is 15 Hz.")
	parser.add_option("--horizon-f-max", metavar = "Hz", type = "float", default = 900., help = "Set the upper frequency cut off for the waveform model used in the horizon distance and range calculation. Default is 900 Hz.")
	parser.add_option("-v", "--verbose", action = "store_true", help = "Be verbose (optional).")

	options, filenames = parser.parse_args()

	return options, filenames

class NoiseTracker(object):
	"""Track detector noise, binary-neutron-star range and PSD snapshots.

	Buffers per-buffer noise and range measurements in bounded deques,
	stores reduced data to the aggregator sink every 100 s, and optionally
	publishes the latest (decimated) PSD to a kafka topic.
	"""
	def __init__(self, instrument, agg_sink, client, tag, approximant, f_min, f_max, injections=False):
		# most recent PSD delivered via spectrum messages (None until the first arrives)
		self.psd = None
		self.instrument = instrument
		self.agg_sink = agg_sink
		self.client = client
		self.tag = tag
		self.injections = injections
		# horizon distance estimator for a 1.4/1.4 Msun binary, deltaF = 1/16 Hz
		self.horizon_distance_func = psd.HorizonDistance(f_min, f_max, 1./16., 1.4, 1.4, approximant = approximant)

		self.routes = ("noise", "range_history")
		# bounded history buffers; flushed after each 100 s reduction
		self.timedeq = deque(maxlen = 10000)
		self.datadeq = {route: deque(maxlen = 10000) for route in self.routes}
		self.last_reduce_time = None
		self.prevdataspan = set()

		# topic name distinguishes injection-channel PSDs from regular ones
		if injections:
			self.psd_topic = f'gstlal.{self.tag}.{self.instrument}_inj_psd'
		else:
			self.psd_topic = f'gstlal.{self.tag}.{self.instrument}_psd'

	def on_spectrum_message(self, message):
		"""Store the most recent PSD parsed from an element spectrum message."""
		self.psd = pipeio.parse_spectrum_message(message)
		return True

	def on_buffer(self, buf):
		"""Record noise/range for one stream buffer; reduce and publish every 100 s."""
		if buf.is_gap:
			return

		if self.last_reduce_time is None:
			# round to the nearest 100 s boundary
			self.last_reduce_time = round(int(buf.t0), -2)
		logging.debug(f"found buffer at t = {buf.t0}")

		# First noise
		ix = numpy.argmax(buf.data[0])
		self.timedeq.append(int(buf.t0))
		self.datadeq['noise'].append(buf.data[0,ix])
		if self.psd:
			# Then range; horizon / 2.25 is the angle-averaged range
			self.datadeq['range_history'].append(self.horizon_distance_func(self.psd, 8)[0] / 2.25)

			# The PSD, decimated by 4x to reduce data volume; sqrt gives the ASD
			psd_freq = numpy.arange(self.psd.data.length / 4) * self.psd.deltaF * 4
			psd_data = signal.decimate(self.psd.data.data[:], 4, ftype='fir', zero_phase=False)[:-1]**.5

		# Only reduce every 100s
		if (buf.t0 - self.last_reduce_time) >= 100:
			self.last_reduce_time = round(int(buf.t0), -2)
			logging.info("reducing data and writing PSD snapshot for %d @ %d" % (buf.t0, int(utils.gps_now())))

			data = {route: {self.instrument: {'time': list(self.timedeq), 'fields': {'data': list(self.datadeq[route])}}} for route in self.routes}

			### store and reduce noise / range history
			for route in self.routes:
				if self.datadeq[route]:
					if self.injections:
						measurement = f"inj_{route}"
					else:
						measurement = route
					# FIX: use the sink handed to this instance rather than
					# the module-level global `agg_sink` (which only exists
					# when this file is executed as a script)
					self.agg_sink.store_columns(measurement, data[route], aggregate="max")

			### output "latest" psd to kafka
			if self.client and self.psd:
				psd_output = {
						"freq": psd_freq.tolist(),
						"asd": psd_data.tolist(),
						"time": [int(buf.t0)],
						"deltaF": self.psd.deltaF * 4
				}

				self.client.write(self.psd_topic, psd_output)

			### flush buffers
			self.timedeq.clear()
			for route in self.routes:
				self.datadeq[route].clear()


#
# =============================================================================
#
#                                     Main
#
# =============================================================================
#

if __name__ == '__main__':
	options, filenames = parse_command_line()

	# --verbose lowers the logging threshold to DEBUG; otherwise INFO
	log_level = logging.DEBUG if options.verbose else logging.INFO
	logging.basicConfig(level = log_level, format = "%(asctime)s | gstlal_ll_dq : %(levelname)s : %(message)s")

	# set up aggregator sink: the influx backend is configured from the
	# "default" backend section of the scald yaml configuration
	with open(options.scald_config, 'r') as f:
		agg_config = yaml.safe_load(f)
	agg_sink = influx.Aggregator(**agg_config["backends"]["default"])

	# register measurement schemas for aggregators
	agg_sink.load(path=options.scald_config)

	# set up kafka client (optional; PSD snapshots are only published when
	# --output-kafka-server is given)
	if options.output_kafka_server:
		client = kafka.Client("kafka://{}".format(options.output_kafka_server))
	else:
		client = None

	# parse the generic "source" options, check for inconsistencies is done inside
	# the class init method
	gw_data_source_info = datasource.DataSourceInfo.from_optparse(options)

	# only support one channel: take the first instrument in the channel dict
	instrument = list(gw_data_source_info.channel_dict.keys())[0]

	# optional reference PSD used to seed the tracked spectrum
	if options.reference_psd is not None:
		reference_psd = psd.read_psd(options.reference_psd, verbose=options.verbose)[instrument]
	else:
		reference_psd = None

	# set up noise tracker
	tracker = NoiseTracker(instrument, agg_sink, client, options.analysis_tag, options.horizon_approximant, options.horizon_f_min, options.horizon_f_max, injections=options.injection_channel)

	#
	# build pipeline
	#

	stream = Stream.from_datasource(gw_data_source_info, instrument, state_vector=True, dq_vector=True, verbose=options.verbose)
	# route element "spectrum" messages to the tracker so it always holds the latest PSD
	stream.add_callback(MessageType.ELEMENT, "spectrum", tracker.on_spectrum_message)

	logging.info("building pipeline ...")
	# resample -> condition (whiten + track PSD, gated by state/dq vectors)
	# -> reblock -> hand each buffer to the tracker
	stream.resample(quality=9) \
		.queue(max_size_buffers=8) \
		.condition(options.sample_rate, instrument, psd = reference_psd, psd_fft_length = options.psd_fft_length, track_psd = True, statevector = stream.source.state_vector[instrument], dqvector = stream.source.dq_vector[instrument]) \
		.queue() \
		.reblock() \
		.bufsink(tracker.on_buffer)

	#
	# process segment
	#

	logging.info("running pipeline ...")
	stream.start()
	logging.info("shutting down...")
	#
	# done.  online pipeline always ends with an error code so that dagman does
	# not mark the job "done" and the job will be restarted when the dag is
	# restarted.
	#
	sys.exit(1)
