#!/home/conda/feedstock_root/build_artifacts/bld/rattler-build_gstlal-ugly_1767612080/host_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_pl/bin/python
#
# Copyright (C) 2012  Drew Keppel, 2018  Heather Fong
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

import sys
import numpy
import scipy
from optparse import OptionParser

import lal

__author__ = "Drew Keppel <drew.keppel@ligo.org>"

# Build the command-line interface.  This file has no module docstring, so
# __doc__ is None; fall back to an explicit description so that --help says
# what the program does.
parser = OptionParser(description = __doc__ or "Compute the harmonic mean of the PSDs read from the input files, after interpolating each one to a common frequency resolution, and write the input PSDs together with the harmonic mean PSD to a LIGO light-weight XML file.")
parser.add_option("--instrument-flow", action = "append", help="instrument and lower frequency cutoff pairs, e.g. \"H1:10.\"")
parser.add_option("--output", metavar = "filename", help = "Set the name of the LIGO light-weight XML file to output")
parser.add_option("--df", metavar = "float", type = "float", default = 0.25, help = "set the frequency resolution to interpolate to, default = 0.25")
parser.add_option("--verbose", action = "store_true", help = "Be verbose.")
# the positional arguments are the names of the input PSD files
options, filenames = parser.parse_args()

# delay this import so the above options are displayed with the help message
from gstlal.psd import interpolate_psd, read_psd, write_psd

# save --df as deltaF
deltaF = options.df
# deltaF is used as a divisor and as a modulus further down, so it must be
# strictly positive
if deltaF <= 0.:
	parser.error(f"--df must be positive, got {deltaF:f}")

# construct a dictionary for instrument-lower frequency cutoff pairs
flowers = {}
# optparse leaves an "append" option set to None when it was never given on
# the command line, so substitute an empty sequence in that case
for ifoflow in options.instrument_flow or ():
	try:
		# unpacking fails with ValueError when the ":" separator is
		# missing, float() fails with ValueError when the cutoff is
		# not a number
		ifo, flow = ifoflow.split(":")[:2]
		flowers[ifo] = float(flow)
	except ValueError:
		parser.error(f"invalid --instrument-flow value \"{ifoflow}\", expected IFO:F_LOW, e.g. \"H1:10.\"")

# construct a container for the psds, keyed by instrument
psds = {}

# store the maximum frequency found for each instrument
fmaxs = {}

# loop over input files
for fname in filenames:
	# read in psds from file
	new_psds = read_psd(fname, verbose=options.verbose)

	for ifo, psd_freqseries in new_psds.items():
		# check that new instruments are not in psds
		if ifo in psds:
			raise ValueError(
				f"ERROR: PSD already loaded for instrument {ifo}. "
				"Please make sure the input files contain unique detector psds"
			)
		# check we have a flower for this instrument
		if ifo not in flowers:
			raise ValueError(
				f"ERROR: No lower frequency cutoff for instrument {ifo}. "
				"Please specify instrument and lower frequency cutoff pairs with --instrument-flow IFO:F_LOW"
			)
		# check flower is at or above f0 for this instrument:  the PSD
		# carries no information below its first sample
		if flowers[ifo] < psd_freqseries.f0:
			raise ValueError(
				f"ERROR: Lower frequency cutoff {flowers[ifo]:f} below f0 {psd_freqseries.f0:f} of PSD for instrument {ifo}. "
				"Please choose a larger lower frequency cutoff or obtain a PSD with information down to that point."
			)
		# check whether f0 is a multiple of df (a non-zero remainder
		# is truthy)
		if psd_freqseries.f0 % deltaF:
			raise ValueError(
				f"ERROR: f0 {psd_freqseries.f0:f} of PSD for instrument {ifo} not a multiple of --df {deltaF:f}. "
				"Please choose a more suitable --df or obtain a PSD with f0 that is a multiple of --df."
			)
		# get fmax for this instrument:  the frequency of the last
		# sample in the series
		fmaxs[ifo] = psd_freqseries.f0 + psd_freqseries.deltaF * (len(psd_freqseries.data.data) - 1)

	# add psds to dictionary only after every psd in this file passed
	# the checks above
	psds.update(new_psds)

# the harmonic mean is undefined without at least one input psd
if not psds:
	raise ValueError("ERROR: No PSDs loaded. Please provide at least one input PSD file.")

# create vectors to store f and the running sum of the inverse psds.  numpy
# is called directly because the numpy aliases in the top-level scipy
# namespace (scipy.arange, scipy.zeros) are deprecated
f = numpy.arange(max(fmaxs.values())/deltaF+1)*deltaF
invpsd_data = numpy.zeros(len(f))

# loop over read in psds
for ifo, psd_freqseries in psds.items():
	# put this psd on the common frequency resolution
	psdinterp = interpolate_psd(psd_freqseries, deltaF)
	# remember the units so the output series can reuse them
	sampleunits = psdinterp.sampleUnits
	windowed_invpsd_freqseries = psdinterp.data.data**-1.
	# index range of this psd's samples within f;  round instead of
	# truncating so a quotient that lands just below an integer because
	# of floating-point error still maps to the intended bin
	start = int(round(psdinterp.f0/deltaF))
	end = start + len(windowed_invpsd_freqseries)
	# drop the contribution of this instrument below its lower frequency
	# cutoff
	windowed_invpsd_freqseries[f[start:end] < flowers[ifo]] = 0.
	invpsd_data[start:end] += windowed_invpsd_freqseries

# the combined psd is stored under the concatenation of the sorted
# instrument names;  work out the key and the number of contributing psds
# before the dictionary is modified
n_psds = len(psds)
combined_key = ''.join(sorted(psds))

# allocate a frequency series with the same binning as f to hold the
# harmonic mean psd
series_args = dict(
	name = 'PSD',
	epoch = lal.LIGOTimeGPS(0.),
	f0 = f[0],
	deltaF = deltaF,
	sampleUnits = sampleunits,
	length = len(invpsd_data),
)
harmonic_mean_psd = lal.CreateREAL8FrequencySeries(**series_args)

# harmonic mean = inverse of the average of the inverse psds
mean_invpsd = invpsd_data/n_psds
harmonic_mean_psd.data.data = mean_invpsd**-1.
psds[combined_key] = harmonic_mean_psd

# write the input psds and the harmonic mean psd to file
write_psd(options.output, psds, options.verbose)
