#!/home/conda/feedstock_root/build_artifacts/bld/rattler-build_gstlal-ugly_1767612080/host_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_pl/bin/python
#
# Copyright (C) 2014 Miguel Fernandez, Chad Hanna
# Copyright (C) 2016,2017 Kipp Cannon, Miguel Fernandez, Chad Hanna, Stephen Privitera, Jonathan Wang
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

import itertools
import numpy
from ligo.lw import ligolw
from ligo.lw import lsctables
from ligo.lw import table
from ligo.lw import utils
from ligo.lw import ilwd
from ligo.lw.utils import process as ligolw_process
#from glue.ligolw import ligolw
#from glue.ligolw import lsctables
#from glue.ligolw import table
#from glue.ligolw import utils
#from glue.ligolw import ilwd
#from glue.ligolw.utils import process as ligolw_process
from gstlal import metric as metric_module
from gstlal import tree
# FIXME: don't use a wildcard import; import the names that are used unqualified (e.g. Node, HyperCube) explicitly
from gstlal.tree import *
import os,sys,argparse


# Read command line options
def parse_command_line(argv = None):
	"""Parse and validate the command line options of the bank generator.

	Parameters
	----------
	argv : list of str, optional
		Argument strings to parse.  When None (the default) argparse
		reads sys.argv[1:], which is what the script does.

	Returns
	-------
	argparse.Namespace
		The parsed options.

	Raises
	------
	ValueError
		If both, or neither, of --psd-file and --noise-model are given.
	NotImplementedError
		If --noise-model is given; only --psd-file is implemented.

	As a side effect numpy's global random number generator is seeded
	with --rng-seed before the PSD options are validated.
	"""

	parser = argparse.ArgumentParser(description="Template generator via a binary tree decomposition.")
	parser.add_argument("-v", "--verbose", action="store_true", default=False,
		help="Be verbose.")
	#parser.add_argument("--mcm2", action="store_true", default=False,\
	#				help="Use total mass and mass ratio coordinates.")
	parser.add_argument("-d", "--debug", action="store_true", default=False,
		help="Extra explicit information for debugging and sanity checks.")
	parser.add_argument("-o", "--output-name", action="store", default="treebank.xml.gz",
		help="Specify output bank filename.")

	# mass limits
	for name in ("mass1", "mass2"):
		parser.add_argument("--min-%s" % name, action="store", type=float,
			default=3.0, help="Minimum %s to generate bank." % name)
		parser.add_argument("--max-%s" % name, action="store", type=float,
			default=10.0, help="Maximum %s to generate bank." % name)
	parser.add_argument("--min-mc", action="store", type=float,
		help="Minimum chirp mass to generate bank. Must have --mcm2.")
	parser.add_argument("--max-mc", action="store", type=float,
		help="Maximum chirp mass to generate bank.  Must have --mcm2.")

	# spin limits: aligned (z) components first, then the in-plane (x, y)
	# components.  The order is kept so that --help lists the options
	# exactly as before; each help text names its own component.
	for name in ("spin1z", "spin2z", "spin1x", "spin2x", "spin1y", "spin2y"):
		parser.add_argument("--min-%s" % name, action="store", type=float,
			default=0, help="Minimum %s to generate bank." % name)
		parser.add_argument("--max-%s" % name, action="store", type=float,
			default=0, help="Maximum %s to generate bank." % name)

	parser.add_argument("--min-inc", action="store", type=float,
		default=0, help="Minimum inclination to generate bank. Relevant only for precessing / sub-dominant mode banks.")
	parser.add_argument("--max-inc", action="store", type=float,
		default=numpy.pi/2, help="Maximum inclination to generate bank. Relevant only for precessing / sub-dominant mode banks.")
	parser.add_argument("--max-mtotal", action="store", type=float,
		default=float('inf'), help="Maximum total mass, default infinity")
	parser.add_argument("--max-q", action="store", type=float,
		default=float('inf'), help="Maximum q, default infinity")

	# meta-params
	parser.add_argument("--min-match", action="store", type=float,
		default=0.95, help="Minimum match to generate bank.")
	# The seed must be an integer: numpy.random.seed() rejects floats, so
	# parsing it as a float made every user-supplied seed fail.
	parser.add_argument("--rng-seed", action="store", type=int,
		default=314, help="Set seed used for random number generator. Relevant only when using stochastic placement.")
	parser.add_argument("--flow", action="store", type=float,
		default=30.0, help="Low frequency cutoff for overlap calculations.")
	parser.add_argument("--fhigh", action="store", type=float,
		default=1024.0, help="High frequency cutoff for overlap calculations.")
	parser.add_argument("--approximant", action="store", type=str,
		default="IMRPhenomD", help="Specify approximant.")
	parser.add_argument("--psd-file", action="store",
		default=None, help="Input PSD file.")
	parser.add_argument("--noise-model", action="store",
		default=None, help="Specify standard noise model.")

	args = parser.parse_args(argv)

	numpy.random.seed(args.rng_seed)

	if args.noise_model and args.psd_file:
		raise ValueError("Cannot specify both --psd-file and --noise-model")

	#if (args.min_mc or args.max_mc) and not args.mcm2:
	#	raise ValueError("Must specify --mcm2 with --min-mc or --max-mc")

	if not (args.noise_model or args.psd_file):
		raise ValueError("Must specify a PSD.")

	if args.noise_model:
		raise NotImplementedError("IMPLEMENT NOISE MODELS!!")

	return args



# Parse the command line; all of the module-level setup below is driven by
# these options.
args = parse_command_line()
#
# Set up the hypercube which bounds the entire space.
#
# coord_limits holds [min, max] for every parameter that is allowed to vary.
# positions holds, for each of those, its index in the 8-element
# (mass1, mass2, spin1x, spin1y, spin1z, spin2x, spin2y, spin2z) vector.
# fixed_values holds the value of every parameter whose min and max are
# equal; the entries of varying parameters are left at zero.
coord_limits = []
positions = []
fixed_values = numpy.zeros(8)

# FIXME: This loop has to ensure that the parameter limits make sense. For
#   instance, the user should not specify inclination limits when the bank is
#   intended to be aligned-spin dominant mode.
for i, v in enumerate(("mass1", "mass2", "spin1x", "spin1y", "spin1z", "spin2x", "spin2y", "spin2z")):
	xi = getattr(args, "min_%s" % v)
	xf = getattr(args, "max_%s" % v)
	dx = xf-xi
	if dx != 0:
		coord_limits.append([xi,xf])
		positions.append(i)
	else:
		# The parameter is held fixed: remember its value, which need
		# not be zero (e.g. --min-spin1z 0.5 --max-spin1z 0.5).
		fixed_values[i] = xi
# Function to map from the coordinates used for template placement and the
# metric calculation and the coordinates needed to call into LALSimulation, the
# latter of which includes parameters that may have fixed values.
#   FIXME: Supports only precessing template banks. Add
#	 support for sub-dominant mode banks.
#   FIXME: Define a number of standard coordinate functions. This one here is
#	 essentially just the identity mapping.
def coord_func(coords, positions = positions, fixed_values = fixed_values):
	"""Expand placement coordinates into the full 8-parameter vector.

	Parameters
	----------
	coords : sequence of float
		Values of the varying parameters; coords[i] is stored at
		index positions[i] of the result.
	positions : sequence of int, optional
		Indices of the varying parameters.  Defaults to the list built
		from the command line limits.
	fixed_values : sequence of 8 floats, optional
		Base vector supplying the values of the parameters that are
		not listed in positions.  Defaults to the fixed values taken
		from the command line (previously these were always zero).

	Returns
	-------
	numpy.ndarray
		A new length-8 float array ordered as (mass1, mass2, spin1x,
		spin1y, spin1z, spin2x, spin2y, spin2z).
	"""
	# work on a copy so the shared default is never modified
	out = numpy.array(fixed_values, dtype = float)
	for i, pi in enumerate(positions):
		out[pi] = coords[i]
	return out

# Pick the mapping from placement coordinates to waveform parameters.  The
# alternative mapping provided by the metric module is switched off (it used
# to be selected by args.mcm2), so the coord_func already defined is kept.
if not False:#not args.mcm2
	pass
else:
	coord_func = metric_module.M_q_func

# Largest mismatch tolerated between neighbouring templates.
mismatch = 1.0 - args.min_match

# Initialize the metric and set the coordinate function
metric_options = {
	"coord_func": coord_func,
	"duration": 4.0, # FIXME!!!!!
	"flow": args.flow,
	"fhigh": args.fhigh,
	"approximant": args.approximant,
}
g_ij = metric_module.Metric(args.psd_file, **metric_options)

# Initialize the tree root and then split with a given splitsize
def trans_M_q(coord_limits):
	"""Return a new limits list with the first two rows replaced.

	Row 0 becomes [0.9 * mc(min1, min2), mc(max1, max2) * 1.11], where mc
	is metric_module.mc_from_m1_m2 applied to the lower / upper bounds of
	the first two rows (presumably the chirp mass of mass1 and mass2 --
	confirm against the metric module).  Row 1 becomes the second row's
	bounds scaled to [0.9 * min, 1.256789 * max].  All remaining rows are
	passed through unchanged.  The input list itself is not modified.
	"""
	first_range = coord_limits[0]
	second_range = coord_limits[1]

	mc_low = 0.9 * metric_module.mc_from_m1_m2(first_range[0], second_range[0])
	mc_high = metric_module.mc_from_m1_m2(first_range[1], second_range[1]) * 1.11

	m2_low = second_range[0] * 0.9
	m2_high = second_range[1] * 1.256789

	return [[mc_low, mc_high]] + [[m2_low, m2_high]] + coord_limits[2:]

# Choose the limits of the bounding hypercube and the constraint applied to
# candidate points.  The first branch (transformed limits, optionally clipped
# by --min-mc / --max-mc) is switched off; it used to be selected by args.mcm2.
if False:#args.mcm2:
	coord_limits_2 = trans_M_q(coord_limits)
	if args.min_mc:
		coord_limits_2[0][0] = max(coord_limits_2[0][0], args.min_mc)
	if args.max_mc:
		coord_limits_2[0][1] = min(coord_limits_2[0][1], args.max_mc)
	print(coord_limits_2)

	def constraint(x):
		return tree.mass_sym_constraint_mc(x, args.max_q, args.max_mtotal, coord_limits[0], coord_limits[1], coord_limits_2[0])
else:
	coord_limits_2 = coord_limits

	def constraint(x):
		return tree.mass_sym_constraint(x, args.max_q, args.max_mtotal)

# Root of the tree: a single hypercube spanning the whole parameter space.
root_cube = HyperCube(numpy.array(coord_limits_2), mismatch, constraint_func = constraint, metric = g_ij)
bank = Node(root_cube)

if args.verbose:
	print("The bank size according to the center metric: ", bank.cube.size)
	print("The bank boundaries are: ")
	for row in bank.cube.boundaries:
		print("\t", row)

# Split the root; the split size comes from tree.packing_density() evaluated
# at the number of varying coordinates.
split_size = tree.packing_density(len(coord_limits))
bank.split(split_size, mismatch = mismatch, verbose = args.verbose)

# prepare a new XML document for writing template bank
xmldoc = ligolw.Document()
xmldoc.appendChild(ligolw.LIGO_LW())
sngl_inspiral_columns = ("process_id", "mass1", "mass2", "spin1x", "spin1y", "spin1z", "spin2x", "spin2y", "spin2z")
tbl = lsctables.New(lsctables.SnglInspiralTable, columns = sngl_inspiral_columns)
xmldoc.childNodes[-1].appendChild(tbl)
# FIXME make a real process table
process = ligolw_process.register_to_xmldoc(xmldoc, sys.argv[0], {})
process.set_end_time_now()

nodes = bank.leafnodes()
patches = []
expected = 0
# Counted explicitly so the summary below also works when there are no leaf
# nodes (the enumerate() loop variable would be undefined in that case).
num_leaves = 0

# Accepted mass window: the requested limits widened by 1.5% below and 5%
# above.  The limits are read from the options directly rather than from
# coord_limits[0] / coord_limits[1], which only hold the mass limits when both
# masses are varying coordinates.
mass1_low = args.min_mass1 * 0.985
mass1_high = args.max_mass1 * 1.05
mass2_low = args.min_mass2 * 0.985
mass2_high = args.max_mass2 * 1.05

if args.verbose:
	print("Writing output document", file=sys.stderr)

for num_leaves, c in enumerate(nodes, 1):

	tiles = c.tile(mismatch = mismatch)
	expected += c.num_templates(mismatch)
	for t in tiles:
		row = lsctables.SnglInspiralTable.RowType()
		row.mass1, row.mass2, row.spin1x, row.spin1y, row.spin1z, row.spin2x, row.spin2y, row.spin2z = coord_func(t)
		# Tiles with mass1 < mass2 are dropped, not swapped (the swap
		# code that used to follow the continue was unreachable).
		if row.mass1 < row.mass2:
			continue
		row.event_id = ilwd.ilwdchar('sngl_inspiral:event_id:%d' % len(tbl))
		row.ifo = "H1" # FIXME
		row.process_id = process.process_id
		in_mass_window = mass1_low <= row.mass1 <= mass1_high and mass2_low <= row.mass2 <= mass2_high
		if in_mass_window and (row.mass1+row.mass2 < args.max_mtotal) and (row.mass1 / row.mass2 < args.max_q):
			tbl.append(row)

	patches.append(c.boundaries)

# Boundaries of every leaf node, written to the current working directory.
numpy.save("patches.npy", numpy.array(patches))

utils.write_filename(xmldoc, args.output_name)

if args.verbose:
	print("Number of leaf nodes: ", num_leaves, file=sys.stderr)
	print("Number of templates (tiles): ", len(tbl), file=sys.stderr)
	print("Expected number ", expected, file=sys.stderr)
