#!/home/conda/feedstock_root/build_artifacts/bld/rattler-build_gstlal-ugly_1767612080/host_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_pl/bin/python
import sys
import itertools
import time
from gstlal import metric as metric_module
from ligo.lw import ligolw
from ligo.lw import utils as ligolw_utils
from ligo.lw import lsctables, param, array
import numpy
import argparse
import h5py
from gstlal import svd_bank

@array.use_in
@param.use_in
@lsctables.use_in
class LIGOLWContentHandler(ligolw.LIGOLWContentHandler):
	"""Content handler passed to svd_bank.read_banks() when loading the bank file.

	Adds no behaviour of its own: it only exists so that the ``use_in``
	decorators of ligo.lw's lsctables, param and array modules can be
	applied to a dedicated subclass (presumably enabling typed parsing of
	Table, Param and Array elements -- see the ligo.lw documentation).
	"""

def mchirp(m1, m2):
	"""Return the chirp mass, (m1*m2)**(3/5) / (m1+m2)**(1/5), of two component masses."""
	mass_product = m1 * m2
	total_mass = m1 + m2
	return mass_product**0.6 / total_mass**0.2

# Command line interface.  --output and the four per-category options are
# marked required: the script cannot run without them, and rejecting the
# command line here avoids failing only after the metrics have been computed.
parser = argparse.ArgumentParser()
parser.add_argument("--output", required = True, help = "provide the output file")
parser.add_argument("--psd-xml-file", help = "provide a psd xml file")
parser.add_argument("--svd-file", help = "provide the bank file for which overlaps will be calculated")
# The options below may be repeated, once per source category; the n-th
# occurrence of each of them together describes the n-th category.
parser.add_argument("--name", action = "append", required = True, help = "provide the name of the source category, e.g., NSBH")
parser.add_argument("--m1", action = "append", required = True, help = "provide the num:start:stop parameters for mass 1, e.g., 100:1:3")
parser.add_argument("--m2", action = "append", required = True, help = "provide the num:start:stop parameters for mass 2, e.g., 100:1:3")
parser.add_argument("--chi-eff", action = "append", required = True, help = "provide the num:start:stop parameters for chi effective, e.g., 11:-1:1")

args = parser.parse_args()

# The per-category options are consumed in lockstep with zip(), which would
# silently ignore surplus values, so insist on equal counts up front.
if not (len(args.name) == len(args.m1) == len(args.m2) == len(args.chi_eff)):
	parser.error("--name, --m1, --m2 and --chi-eff must each be given the same number of times")

# Opened for writing before any computation starts, so an unwritable path
# is reported immediately.
outfile = open(args.output, "w")

# Build the numeric metric object that is used further down to compute
# approximate matches.  The keyword settings are collected first and then
# handed to the constructor together with the PSD file from the command line.
metric_settings = dict(
	coord_func = metric_module.x_y_z_zn_func,
	duration = 1.0, # FIXME!!!!!
	flow = 10,
	fhigh = 1024,
	approximant = "IMRPhenomD",
)
g_ij = metric_module.Metric(args.psd_xml_file, **metric_settings)

# Gather the template rows of every bank stored in the svd bin into one list
sngl_inspiral_table = []
for bank in svd_bank.read_banks(args.svd_file, contenthandler = LIGOLWContentHandler, verbose = True):
	sngl_inspiral_table.extend(bank.sngl_inspiral_table)

print("number of templates", len(sngl_inspiral_table))

def _chi_eff(row):
	"""Return the mass-weighted mean of the two z spin components of a template row."""
	return (row.mass1 * row.spin1z + row.mass2 * row.spin2z) / (row.mass1 + row.mass2)

# Extremes of chirp mass and effective spin over all templates; they are
# reported here and again in the header of the output file.
template_mchirps = [row.mchirp for row in sngl_inspiral_table]
min_mchirp, max_mchirp = min(template_mchirps), max(template_mchirps)
chieff = [_chi_eff(row) for row in sngl_inspiral_table]
min_chi, max_chi = min(chieff), max(chieff)
print("mchirp (%f,%f) chieff (%f,%f)" % (min_mchirp, max_mchirp, min_chi, max_chi))

# define the coordinate transformation for the metric.
def x_y_z_zn_from_row(row):
	"""Return the [x, y, z, zn] metric coordinates of a template row.

	row must provide mass1, mass2, spin1z and spin2z; each coordinate is
	obtained from the matching metric_module.*_from_m1_m2_s1_s2 function.
	"""
	masses_and_spins = (row.mass1, row.mass2, row.spin1z, row.spin2z)
	coordinate_funcs = (
		metric_module.x_from_m1_m2_s1_s2,
		metric_module.y_from_m1_m2_s1_s2,
		metric_module.z_from_m1_m2_s1_s2,
		metric_module.zn_from_m1_m2_s1_s2,
	)
	return [func(*masses_and_spins) for func in coordinate_funcs]

# Start the report: first the parameter ranges of the bank, then the column
# titles for the per-category rows that are written later.
bank_summary = "mchirp (%f,%f) chieff (%f,%f)" % (min_mchirp, max_mchirp, min_chi, max_chi)
column_titles = "name, m1 start, m1 stop, m2 start, m2 stop, activation counts"
for header_line in (bank_summary, column_titles):
	print(header_line, file=outfile)

# Evaluate the metric at the coordinates of every template.  Coordinates and
# tensors are stored in the same order so that they can be paired up later.
print("computing metrics")
start = time.time()
templates = []
metric_tensors = []

for row in sngl_inspiral_table:
	coords = x_y_z_zn_from_row(row)
	# only the first element of what g_ij returns is kept
	tensor = g_ij(coords)[0]
	templates.append(coords)
	metric_tensors.append(tensor)

# convert both lists to arrays for the vectorised arithmetic below
templates = numpy.array(templates)
metric_tensors = numpy.array(metric_tensors)

print("done computing metrics", time.time() - start)

# SNR grid covering a plausible signal range: 15 evenly spaced values
# running from 6 to 20.
snrs = numpy.linspace(6., 20., num = 15)
# Sum of the rho**-4 weights over the grid, used to normalise the
# per-SNR weights when the counts are accumulated.
inverse_fourth_powers = snrs**-4
snrnorm = sum(inverse_fourth_powers)

def _parse_num_start_stop(spec):
	"""Split a "num:start:stop" option value into (int num, float start, float stop).

	Raises ValueError if spec does not consist of exactly three numeric
	fields or if num is not a whole number.
	"""
	fields = spec.split(":")
	if len(fields) != 3:
		raise ValueError("expected num:start:stop, got '%s'" % spec)
	num, start, stop = (float(field) for field in fields)
	if not num.is_integer():
		raise ValueError("the number of points in '%s' must be a whole number" % spec)
	return int(num), start, stop

# iterate over the lists of masses and spins that define the different target
# regions, e.g., bns, nsbh, bbh etc.
for M1, M2, S1, name in zip(args.m1, args.m2, args.chi_eff, args.name):

	print(M1, M2)

	# The point counts must be integers: numpy.logspace() and
	# numpy.linspace() raise TypeError when num is a float.
	m1n, m1s, m1e = _parse_num_start_stop(M1)
	m2n, m2s, m2e = _parse_num_start_stop(M2)
	chin, chis, chie = _parse_num_start_stop(S1)
	num_injections = m1n * m2n * chin

	# Setup some "injections"
	# uniform in log population as requested by Jolien for mass and uniform in spin
	print(m1n, m1s, m1e, m2n,m2s,m2e, chin,chis,chie)
	m1_grid = numpy.logspace(numpy.log10(m1s), numpy.log10(m1e), m1n)
	m2_grid = numpy.logspace(numpy.log10(m2s), numpy.log10(m2e), m2n)
	chi_grid = numpy.linspace(chis, chie, chin)

	# One row of metric coordinates per (m1, m2, chi) grid point; both
	# component spins are set to the same value chi.
	signals = numpy.array([
		[
			metric_module.x_from_m1_m2_s1_s2(m1, m2, chi, chi),
			metric_module.y_from_m1_m2_s1_s2(m1, m2, chi, chi),
			metric_module.z_from_m1_m2_s1_s2(m1, m2, chi, chi),
			metric_module.zn_from_m1_m2_s1_s2(m1, m2, chi, chi),
		]
		for m1, m2, chi in itertools.product(m1_grid, m2_grid, chi_grid)
	])

	count = 0
	print("computing counts for %d signals" % len(signals))

	if len(signals) > 0:
		# templates and metric_tensors are stored in the same order
		for template, metric_tensor in zip(templates, metric_tensors):

			delta = signals - template
			# quadratic form delta . g . delta, one value per signal
			d2 = numpy.sum(numpy.dot(delta, metric_tensor) * delta, axis = 1)
			# flatten out the parabolic match function
			matches = 1. - (numpy.arctan(d2**.5 * numpy.pi / 2) / numpy.pi * 2)**2
			# NaN (e.g. from the square root of a negative d2) counts as no match
			matches[numpy.isnan(matches)] = 0.
			matches = matches[matches > 0]
			if len(matches) > 0:
				# does not depend on the SNR, so compute it once per template
				mismatch_squared = (1.0 - matches)**2
				for rho in snrs:
					count += rho**-4 / snrnorm * (numpy.exp(-rho**2 * mismatch_squared / 2.)).sum() / num_injections
	print(count)
	print("%s, %f, %f, %f, %f, %e" % (name, m1s, m1e, m2s, m2e, count), file=outfile)
	# push the finished row to disk so it survives a failure in a later category
	outfile.flush()
