#!/home/conda/feedstock_root/build_artifacts/bld/rattler-build_gstlal-ugly_1767612080/host_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_pl/bin/python
import sys
from gstlal import metric as metric_module
from ligo.lw import ligolw
from ligo.lw import utils as ligolw_utils
from lal.utils import CacheEntry
from ligo.lw import lsctables
import numpy
import argparse
import h5py

@lsctables.use_in
class LIGOLWContentHandler(ligolw.LIGOLWContentHandler):
    """Content handler given to ligolw_utils.load_filename() in this script.

    lsctables.use_in() is applied to the class so that documents parsed with
    it can be queried through lsctables.SnglInspiralTable.get_table().
    """

# Command-line interface.
#
# --split-bank-cache, --split-bank-index and --number-of-templates are used
# unconditionally further down (to open the cache, to index the list of split
# banks and to size the output matrix), so they are declared required: a
# missing option is reported by argparse instead of surfacing later as a
# TypeError on None.
parser = argparse.ArgumentParser()
parser.add_argument("--psd-xml-file", help = "provide a psd xml file")
parser.add_argument("--bank-file", help = "provide the bank file for which overlaps will be calculated")
parser.add_argument("--out-h5-file", required = True, help = "provide the output hdf5 file name")
parser.add_argument("--approximant", default="IMRPhenomD", help = "Waveform model. Default IMRPhenomD.")
parser.add_argument("--f-low", type=float, default=15.0, help = "Lowest frequency component of template. Default 15 Hz.")
parser.add_argument("--f-high", type=float, default=4096.0, help = "Highest frequency component of template Default 4096 Hz.")
parser.add_argument("--split-bank-cache", required = True, help = "Cache file containing paths to split banks. Required.")
parser.add_argument("--split-bank-index", type=int, required = True, help = "Split bank index that will be looped over.")
parser.add_argument("--overlap-threshold", type=float, default=0.25, help = "Overlap threshold. Default 0.25.")
parser.add_argument("--number-of-templates", type=int, required = True, help = "Total number of templates in bank. Required.")

args = parser.parse_args()

# Metric object built from the PSD file; it is called below to evaluate the
# metric at a template's (x, y, z, zn) coordinates and to compute pseudo
# matches between templates.
g_ij = metric_module.Metric(
    args.psd_xml_file,
    coord_func = metric_module.x_y_z_zn_func,
    duration = 1.0, # FIXME!!!!!
    flow = args.f_low,
    fhigh = args.f_high,
    approximant = args.approximant
)

# Sorted paths of all split banks listed in the cache file.  The file is read
# inside a context manager so the handle is closed as soon as the lines have
# been parsed.
with open(args.split_bank_cache, 'r') as cache_file:
    split_banks = sorted(CacheEntry(line).path for line in cache_file)

# Templates of the split bank selected by --split-bank-index; each of them
# becomes one row of the overlap matrix.
xmldoc = ligolw_utils.load_filename(split_banks[args.split_bank_index], verbose=True, contenthandler = LIGOLWContentHandler)
sngl_inspiral_table = lsctables.SnglInspiralTable.get_table(xmldoc)

def id_x_y_z_zn_from_row(row):
    """Return ``[template_id, x, y, z, zn]`` for one template row.

    The four coordinates are produced by metric_module's
    ``*_from_m1_m2_s1_s2`` helpers, each called with the row's
    ``mass1``, ``mass2``, ``spin1z`` and ``spin2z``.
    """
    masses_and_spins = (row.mass1, row.mass2, row.spin1z, row.spin2z)
    coordinate_funcs = (
        metric_module.x_from_m1_m2_s1_s2,
        metric_module.y_from_m1_m2_s1_s2,
        metric_module.z_from_m1_m2_s1_s2,
        metric_module.zn_from_m1_m2_s1_s2,
    )
    vector = [row.template_id]
    vector.extend(func(*masses_and_spins) for func in coordinate_funcs)
    return vector

# One [template_id, x, y, z, zn] row per template of the selected split bank.
vec1s = numpy.array(list(map(id_x_y_z_zn_from_row, sngl_inspiral_table)))

# Overlap matrix: one row per template above and int(number_of_templates / 3)
# columns.  NOTE(review): the reason for the factor 1/3 is not visible here;
# columns are handed out on demand by the scan below - confirm the bound.
output = numpy.zeros((vec1s.shape[0], int(args.number_of_templates / 3)))

# Column bookkeeping filled in while the split banks are scanned:
#   col_idx    - next unused column of ``output``
#   col_id_map - template id -> column assigned to it
#   id2        - template ids in column order
col_idx = 0
col_id_map = {}
id2 = []

# Parsed split banks keyed by bank index -> (coordinate rows, output columns).
# Each bank file is read at most once instead of once per template; the column
# assignment of a template id never changes once made, so reusing the cached
# columns yields the same matrix.
bank_cache = {}

# For every template of the selected split bank, walk outwards through the
# sorted split banks - first forwards from the template's own bank, then
# backwards from the bank just before it - recording the pseudo match against
# every template met on the way.  A direction is abandoned on the third
# consecutive bank whose best overlap is below --overlap-threshold.
for n, vec1 in enumerate(vec1s):
    # Metric at this template's coordinates; the second return value is unused.
    g, _ = g_ij(vec1[1:])
    b_idx = args.split_bank_index
    searching_forward = True
    # Number of consecutive banks below threshold in the current direction.
    # Reset per template so the counter is always defined and never inherits
    # a stale value from the previous template.
    second_chance = 0

    while 0 <= b_idx < len(split_banks):
        if b_idx not in bank_cache:
            xmldoc2 = ligolw_utils.load_filename(
                split_banks[b_idx],
                verbose=False,
                contenthandler=LIGOLWContentHandler,
            )
            sngl_inspiral_table2 = lsctables.SnglInspiralTable.get_table(xmldoc2)
            bank_vecs = numpy.array([id_x_y_z_zn_from_row(row) for row in sngl_inspiral_table2])
            # The rows have been copied into the array; release the document.
            xmldoc2.unlink()

            # Give every template id not seen before the next free column.
            bank_columns = []
            for t2_id in bank_vecs[:, 0]:
                if t2_id not in col_id_map:
                    col_id_map[t2_id] = col_idx
                    col_idx += 1
                    id2.append(t2_id)
                bank_columns.append(col_id_map[t2_id])
            bank_cache[b_idx] = (bank_vecs, bank_columns)

        vec2s, i2list = bank_cache[b_idx]

        # Advance to the next bank in the current direction before evaluating
        # this one; the stopping logic below may override b_idx.
        b_idx += 1 if searching_forward else -1

        overlaps = numpy.array([g_ij.pseudo_match(g, vec1[1:], vec2[1:]) for vec2 in vec2s])
        maxovrlp = max(overlaps)
        output[n, i2list] = overlaps

        if maxovrlp >= args.overlap_threshold:
            second_chance = 0
        elif maxovrlp < args.overlap_threshold:
            if second_chance < 2:
                second_chance += 1
            elif searching_forward:
                # Forward direction exhausted: restart just below the
                # template's own bank and walk backwards.
                second_chance = 0
                searching_forward = False
                b_idx = args.split_bank_index - 1
            else:
                print("done")
                break

        # Ran past the last bank: switch to the backward direction.
        if b_idx > len(split_banks) - 1:
            second_chance = 0
            searching_forward = False
            b_idx = args.split_bank_index - 1

# Number of columns actually handed out during the scan.  Every assigned
# column has exactly one entry in id2, so slicing by its length keeps the
# "overlaps" columns aligned with "id2".  Searching for the first all-zero
# column instead fails when the matrix is completely filled and truncates too
# early if an assigned column happens to contain only zeros.
stop = len(id2)

# Write the row ids, the column ids and the populated part of the overlap
# matrix; the context manager closes the file even if a write fails.
with h5py.File(args.out_h5_file, 'w') as h5f:
    olapdata = h5f.create_group("%s_metric" %(args.approximant))
    olapdata.create_dataset("id", data = vec1s[:,0])
    olapdata.create_dataset("id2", data = id2)
    olapdata.create_dataset('overlaps', data = output[:,:stop])
