#!/home/conda/feedstock_root/build_artifacts/bld/rattler-build_gstlal-ugly_1767612080/host_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_pl/bin/python
#
# Copyright (C) 2023  Rachael Huxford
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


from collections import defaultdict
import logging
from optparse import OptionParser
import os
import queue
import tempfile

import gi
gi.require_version('Gst', '1.0')
from gi.repository import GObject
from gi.repository import Gst

from lal import gpstime
from lal import LIGOTimeGPS
from lal.utils import CacheEntry

from ligo import segments

from gstlal import pipeparts
from gstlal import simplehandler
from gstlal.pipeparts import pipetools
from gstlal.dagparts import T050017_filename

#
# =============================================================================
#
#                                 Command Line
#
# =============================================================================
#

def parse_command_line():
	"""Parse the gstlal_ll_inj_stream command line.

	Returns:
		tuple, (options, filenames) exactly as produced by
		OptionParser.parse_args()
	"""
	op = OptionParser(description = __doc__)

	# devshmsrc / strain options
	op.add_option(
		"--passthrough-channel", type = "str", action = "append",
		help = "Channels to split from streamed frames, and write without alteration. Can be provided multiple times per ifo analyzed.",
	)
	op.add_option(
		"--strain-channel", type = "str", action = "append",
		help = "Strain channel to split from streamed frames and add to noiseless injections before writing. Must be provided once per ifo analyzed",
	)
	op.add_option(
		"--shared-memory-partition", metavar = "path", action = "append",
		help = "Path to shared memory partition from which frames will be read. Must be provided once per ifo analyzed as {IFO}=/path/to/partition.",
	)

	# frame cache / injection options
	op.add_option(
		"--inj-channel", metavar = "channel_name", action = "append", type = "str",
		help = "Full channel name of noiseless injection channel to be added to hoft. Can only be provided once per ifo.",
	)
	op.add_option(
		"--inj-frame-cache", metavar = "path",
		help = "Path to frame cache which points at noiseless injection frames",
	)
	op.add_option(
		"--inj-frame-cache-regex", metavar = "inj regex", type = "str",
		help = "Regex option to use with injection frame cache. Must match the frame cache descriptor exactly. Example: H (instead of H1)",
	)

	# general options
	op.add_option(
		"--output-dir", metavar = "path", action = "append",
		help = "Path to directory where frames should be written. Must be given once per ifo for which channels are provided as {IFO}=/path/to/directory.",
	)
	op.add_option(
		"--frame-type", metavar = "name", type = "string", action = "append",
		help = "Frame type to be included in the name of the output frame files. Must be provided once for each ifo anlyzed as {IFO}={FRAME_TYPE}.",
	)
	op.add_option(
		"--gps-end-time", metavar = "s", default = 2000000000,
		help = "The time at which to stop looking for data in frame cache if different from legnth of cache. Script will terminate at this time. Default is to take all times in frame cache. e.g. GPS=200000000",
	)
	op.add_option(
		"--history-len", metavar = "s", type = int, default = 300,
		help = "Length of time (in seconds) to keep files for in output-dir. Files written more than history-len in the past, will be deleted from output-dir. Default: 300s",
	)
	op.add_option("-v", "--verbose", action = "store_true", help = "Be verbose (optional).")

	return op.parse_args()


#
# =============================================================================
#
#                                 Handler
#
# =============================================================================
#

def framecpp_filesink_path_handler_simple(elem: pipetools.Element, pspec, filename_q: queue.Queue):
	"""Handle filesink "notify::timestamp": name the new file, prune old ones.

	Builds the T050017-style filename for the frame file just written and
	records it in ``filename_q``; once the queue is full (i.e. history-len
	seconds of one-second files are on disk), the oldest file is deleted.
	Also sends EOS down the pipeline once --gps-end-time is reached.

	Note:
		Reads the module-level globals ``options`` (for --gps-end-time)
		and ``pipeline`` (to send the EOS event).

	Args:
		elem:
			Element, the framecpp filesink whose "instrument",
			"frame-type", "timestamp" and "path" properties are read
		pspec:
			GObject.ParamSpec of the property that changed (unused)
		filename_q:
			queue.Queue, bounded FIFO of the frame files currently kept
			on disk

	Examples:
		>>> filesinkelem.connect("notify::timestamp", framecpp_filesink_path_handler_simple, filename_queue)
	"""

	# get file metadata from the sink element
	instrument = elem.get_property("instrument")
	frame_type = elem.get_property("frame-type")
	timestamp = elem.get_property("timestamp") // Gst.SECOND
	path = elem.get_property("path")

	# check if the desired end time has been reached
	if timestamp >= int(options.gps_end_time):
		logging.info('Desired gps end time reached. Shutting down...')
		# needs to die gracefully here with EOS as in gstlal online handler case
		pipeline.send_event(Gst.Event.new_eos())

	# do cleanup for older files: when the queue is full, each new file
	# pushes the oldest one off disk
	filename = T050017_filename(instrument, frame_type, (timestamp, timestamp+1), '.gwf', path = path)
	if filename_q.full():
		old_filename = filename_q.get()
		logging.debug('removing old file from disk: %s', old_filename)
		os.unlink(old_filename)
	filename_q.put(filename)
	logging.debug('adding new file to queue: %s', filename)


#
# =============================================================================
#
#                                     Setup
#
# =============================================================================
#

# Set up gstreamer pipeline
# NOTE(review): GObject.threads_init() has been a deprecated no-op since
# PyGObject 3.11 — presumably kept for compatibility with older installs;
# confirm before removing.
GObject.threads_init()
Gst.init(None)
# main event loop; runs until the handler sees EOS or an error on the bus
mainloop = GObject.MainLoop()
pipeline = Gst.Pipeline(name="gstlal_ll_inj_stream")
# simplehandler.Handler watches the pipeline bus and stops mainloop on
# EOS / error messages
handler = simplehandler.Handler(mainloop, pipeline)


#
# Parse Input Options
#

options, filenames = parse_command_line()

# Noiseless Injection / Frame Cache Options
# map IFO -> injection channel name, e.g. "H1:INJ_CHAN" -> {"H1": "INJ_CHAN"}
hoft_inj_channel_dict = {c.split(":")[0]: c.split(":")[1] for c in options.inj_channel}
inj_frame_cache = options.inj_frame_cache

# Strain / Devshmsrc Options
# map IFO -> list of channel names to pull out of the live frames
hoft_channel_dict = defaultdict(list)
hoft_passthrough_channel_dict = defaultdict(list)
for strain_channel in options.strain_channel:
	hoft_channel_dict[strain_channel.split(":")[0]].append(strain_channel.split(":")[1])

# passthrough channels are demuxed/re-muxed but written unchanged, so they
# are tracked both on their own and alongside the strain channels
if options.passthrough_channel:
	for channel in options.passthrough_channel:
		ifo, name = channel.split(":")[0], channel.split(":")[1]
		hoft_passthrough_channel_dict[ifo].append(name)
		hoft_channel_dict[ifo].append(name)

# map IFO -> shared-memory partition path ("{IFO}=/path" on the command line)
shared_memory_dict = {smd.split('=')[0]: smd.split('=')[1] for smd in options.shared_memory_partition}

# Output Frame Options
# map IFO -> output directory, creating it if needed.  os.makedirs with
# exist_ok avoids the check-then-create race of os.path.exists()+os.mkdir()
# and also creates any missing parent directories.
output_dir_dict = {}
for out_dir_by_ifo in options.output_dir:
	ifo, outdir = out_dir_by_ifo.split('=')[0], out_dir_by_ifo.split('=')[1]
	os.makedirs(outdir, exist_ok = True)
	output_dir_dict[ifo] = outdir

# map IFO -> frame type used in the output file names
frame_type_dict = {ftype.split('=')[0]: ftype.split('=')[1] for ftype in options.frame_type}


#
# Logging
#

log_level = logging.DEBUG if options.verbose else logging.INFO
logging.basicConfig(level = log_level, format = "%(asctime)s | gstlal_ll_inj_stream : %(levelname)s : %(message)s")


#
# Set up frame cache
#
# We want it to start just before the script is launched
# and go to the end of available data.
# This way, the adder will not be waiting long to synchronize
# the injections with live data.
#

# read the cache inside a context manager so the file handle is closed
# (the previous bare open() leaked it)
with open(options.inj_frame_cache) as cache_file:
	cache_entries = [CacheEntry(x) for x in cache_file]
now = gpstime.gps_time_now()
end = LIGOTimeGPS(int(options.gps_end_time), 0)

# Sanity check: the requested end time must lie in the future, otherwise
# the trimmed cache built below would be empty.
if end < now:
	raise ValueError("--gps-end-time must be greater than the current gps time. Now: %s, Supplied gps-end-time:%s"%(now, end))

# create a new cache restricted to frames overlapping [now, end)
now_to_end_seg = segments.segment((now, end))
new_cache = [c for c in cache_entries if now_to_end_seg.intersects(c.segment)]

#
# =============================================================================
#
#                                     Main
#
# =============================================================================
#


logging.info("building pipeline ...")

#
# Support Functions
#

def gap_dropper(pad, info):
	"""Pad probe callback that drops buffers flagged as gaps.

	Args:
		pad:
			Gst.Pad, the pad the probe is attached to (unused)
		info:
			Gst.PadProbeInfo, carries the buffer under inspection

	Returns:
		Gst.PadProbeReturn, DROP when the buffer carries the GAP flag,
		OK otherwise
	"""
	buf = info.get_buffer()
	startts = LIGOTimeGPS(0, buf.pts)
	if bool(buf.mini_object.flags & Gst.BufferFlags.GAP):
		# lazy %-args: the message is only formatted when DEBUG is enabled
		logging.debug("Dropping b/c flagged as gap: %s", startts)
		return Gst.PadProbeReturn.DROP
	return Gst.PadProbeReturn.OK

#
# Mainloop
#

# set up a temp file holding the trimmed cache; the context manager
# removes it from disk once the pipeline shuts down
with tempfile.NamedTemporaryFile(mode='w',suffix='.cache') as f:

	# write out the new cache to the temp filename
	f.write("\n".join(["%s" % c for c in new_cache]))
	# flush so lalcachesrc, which opens the file by name below, sees the
	# full contents -- without this the cache text can still be sitting
	# in this process's stdio buffer when the pipeline starts
	f.flush()
	# fetch the duration of the frame files to use in queue lengths
	frame_duration = int(abs(new_cache[0].segment))

	# build one independent source -> mux -> sink chain per instrument
	for instrument in hoft_inj_channel_dict:


		#
		# fetch inj channel from disk
		#

		inj_src = pipeparts.mklalcachesrc(pipeline, location=f.name, cache_src_regex=options.inj_frame_cache_regex)
		inj_demux = pipeparts.mkframecppchanneldemux(pipeline, inj_src, do_file_checksum=False, channel_list=list(map("%s:%s".__mod__, hoft_inj_channel_dict.items())))

		# allow frame reading and decoding to occur in a different
		# thread
		inj_strain = pipeparts.mkqueue(pipeline, None, max_size_buffers=0, max_size_bytes=0, max_size_time=frame_duration * Gst.SECOND)
		pipeparts.src_deferred_link(inj_demux, "%s:%s" % (instrument, hoft_inj_channel_dict[instrument]), inj_strain.get_static_pad("sink"))


		#
		# fetch hoft frames from devshm
		#

		hoft_src = pipeparts.mkdevshmsrc(pipeline, shm_dirname=shared_memory_dict[instrument], wait_time=60, watch_suffix='.gwf')
		hoft_src = pipeparts.mkqueue(pipeline, hoft_src, max_size_buffers=0, max_size_bytes=0, max_size_time=Gst.SECOND * frame_duration)
		hoft_demux = pipeparts.mkframecppchanneldemux(pipeline, hoft_src, do_file_checksum=False, skip_bad_files=True)

		# extract strain with 10 buffers of buffering
		# Extract all the channels we want to mux in the end and src defer link each of them.
		channel_dict = {}
		strain_dict = {}
		for channel_name in hoft_channel_dict[instrument]:
			channel_src = pipeparts.mkqueue(pipeline, None, max_size_buffers=0, max_size_bytes=0, max_size_time=Gst.SECOND * frame_duration)
			pipeparts.src_deferred_link(hoft_demux, "%s:%s" % (instrument, channel_name), channel_src.get_static_pad("sink"))
			# fill in and drop samples as necessary
			channel_src = pipeparts.mkaudiorate(pipeline, channel_src, skip_to_first=True, silent=False, name="%s_%s_audiorate" %(instrument, channel_name))
			# NOTE(review): max_size_time is in nanoseconds, so 10 here is
			# 10 ns (compare 10 * Gst.SECOND used below) -- confirm intent
			channel_src = pipeparts.mkqueue(pipeline, channel_src, max_size_time = 10)
			# audioconvert necessary for incorporation of Virgo data
			channel_src = pipeparts.mkaudioconvert(pipeline,channel_src)

			# Only add items to strain_dict which we want to become a new hoft + inj stream
			if channel_name not in hoft_passthrough_channel_dict[instrument]:
				# tee here, so we can use a tee later for the adder
				strain_dict[instrument+':'+channel_name] = pipeparts.mktee(pipeline, channel_src)
			else:
				channel_dict[instrument+':'+channel_name] = pipeparts.mkqueue(pipeline, channel_src, max_size_time = Gst.SECOND * 8)
				channel_dict[instrument+':'+channel_name].get_static_pad("src").add_probe(Gst.PadProbeType.BUFFER, gap_dropper)


		#
		# Gate inj stream when hoft is not available
		# this prevents inj-only frames from being written to disk.
		#
		hoft_key = list(strain_dict.keys())[0] # hoft strains arrive together
		gate_hoft_strain = pipeparts.mkqueue(pipeline, strain_dict[hoft_key])
		inj_strain = pipeparts.mkgate(pipeline, inj_strain, control = gate_hoft_strain, threshold = 0, leaky = True)

		#
		# Combine strains and re-mux
		#

		# tee the inj strain here, so it can be used multiple times
		inj_strain = pipeparts.mktee(pipeline, inj_strain)
		inj_strain_dict = {}
		for channel, hoft_strain in strain_dict.items():
			# Add a tee to this dict, so we can mux o.g. hoft later
			# Drop gaps b/c no injection buffers for gaps.
			hoft_strain = pipeparts.mkqueue(pipeline, hoft_strain, max_size_time = 10 * Gst.SECOND)
			hoft_strain.get_static_pad("src").add_probe(Gst.PadProbeType.BUFFER, gap_dropper)

			# large queues are necessary for timestamp synchronization
			this_inj_strain = pipeparts.mkqueue(pipeline, inj_strain, max_size_time = 1024 * Gst.SECOND)

			# add injections to hoft
			# note that the adder sinks the streams
			inj_strain_dict[channel+'_INJ'] = pipeparts.mkadder(pipeline, [hoft_strain, this_inj_strain], sync=True, mix_mode="sum")


		# want to mux hoft, inj + hoft, and passthrough channels
		# cannot mux tee, so add queue. Drop gaps b/c no inj for gaps.
		for channel in strain_dict:
			strain_dict[channel] = pipeparts.mkqueue(pipeline, strain_dict[channel])
			strain_dict[channel].get_static_pad("src").add_probe(Gst.PadProbeType.BUFFER, gap_dropper)
		channel_dict.update(strain_dict)
		channel_dict.update(inj_strain_dict)
		output_stream = pipeparts.mkframecppchannelmux(pipeline, channel_src_map=channel_dict, frame_duration=1, frames_per_file=1)


		#
		# Write frames to disk
		#

		# make a queue for filename clean-up; holds history-len filenames
		# (one-second frame files), after which the oldest is deleted
		filename_queue = queue.Queue(options.history_len)
		framesink = pipeparts.mkframecppfilesink(pipeline, output_stream, frame_type=frame_type_dict[instrument], path=output_dir_dict[instrument])
		# note that how the output is organized is controlled
		# by the function framecpp_filesink_path_handler_simple
		framesink.connect("notify::timestamp", framecpp_filesink_path_handler_simple, filename_queue)


	#
	# Start pipeline
	#

	logging.info("running pipeline ...")

	if pipeline.set_state(Gst.State.PLAYING) == Gst.StateChangeReturn.FAILURE:
		raise RuntimeError("pipeline failed to enter PLAYING state")
	# blocks until the simplehandler sees EOS or an error on the bus
	mainloop.run()

	logging.info("shutting down...")
