#!/home/conda/feedstock_root/build_artifacts/bld/rattler-build_gstlal-ugly_1767612080/host_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_pl/bin/python

from __future__ import division

import numpy as np
import sys
import matplotlib
matplotlib.use("agg")
from matplotlib import pyplot
from mpl_toolkits.mplot3d import Axes3D

from ligo.lw import ligolw
from ligo.lw import lsctables
from ligo.lw import utils
from ligo.lw.utils import process as ligolw_process

from optparse import OptionParser

#from sbank import git_version FIXME
from lalinspiral.sbank.bank import Bank
from lalinspiral.sbank.tau0tau3 import m1m2_to_tau0tau3
from lalinspiral.sbank.psds import noise_models, read_psd, get_PSD
from lalinspiral.sbank.waveforms import waveforms
from lal import PI, MTSUN_SI

import lal

class ContentHandler(ligolw.LIGOLWContentHandler):
    """Content handler passed to ``utils.load_filename`` when reading the bank.

    Adds nothing of its own to ``ligolw.LIGOLWContentHandler``; it exists so
    that ``lsctables.use_in`` can be applied to a script-local class.
    """


# NOTE(review): presumably this makes the parser build lsctables table/row
# types (e.g. SnglInspiralTable) while loading -- confirm against ligo.lw docs.
lsctables.use_in(ContentHandler)


# The template bank file is the single required positional argument.  Exit
# with a usage message instead of an opaque IndexError when it is missing.
if len(sys.argv) < 2:
    sys.exit("usage: %s BANK_FILE (e.g. bank.xml.gz)" % sys.argv[0])

print("Reading in the template bank...", file=sys.stdout)
# Load the document named on the command line and fetch its sngl_inspiral
# table; the rows of this table are what every plot below is built from.
tmpdoc = utils.load_filename(sys.argv[1], contenthandler=ContentHandler)
sngl_inspiral = lsctables.SnglInspiralTable.get_table(tmpdoc)


# Per-template parameter lists, all in the row order of the sngl_inspiral
# table.  The first five are read straight from row attributes.
m1s = [row.mass1 for row in sngl_inspiral]
m2s = [row.mass2 for row in sngl_inspiral]
tau0s = [row.tau0 for row in sngl_inspiral]
tau3s = [row.tau3 for row in sngl_inspiral]
chis = [row.spin1z for row in sngl_inspiral]

# Derived quantities: total mass, and the mass ratio oriented so that it is
# always >= 1 regardless of which component is heavier.
mtots = [m1 + m2 for m1, m2 in zip(m1s, m2s)]
mratios = [max(m1 / m2, m2 / m1) for m1, m2 in zip(m1s, m2s)]


# A single figure is reused for every plot; each section clears it when done.
fig = pyplot.figure()

# Build the common filename suffix for all output images: the input bank
# filename with its XML extension replaced by ".png".
#
# str.strip(".xml.gz") must NOT be used here: strip() treats its argument as a
# set of characters and removes them from BOTH ends, so e.g. "mybank.xml.gz"
# would become "ybank".  Remove the extension as a true suffix instead.
bank_filename = sys.argv[1]
for suffix in (".xml.gz", ".xml"):
    if bank_filename.endswith(suffix):
        bank_filename = bank_filename[:-len(suffix)]
        break
tag = bank_filename + ".png"

#
# plot number of templates versus total mass
#
ax = fig.gca()
# 10-bin histogram of total mass with a logarithmic count axis
ax.hist(mtots, bins=10, log=True)
# raw string: "\m" is not a valid Python escape sequence and triggers a
# SyntaxWarning on recent interpreters; the label text itself is unchanged
ax.set_xlabel(r"$M_\mathrm{total}$")
ax.set_ylabel("Number of Templates")
# clamp the x axis to the range actually covered by the bank
ax.set_xlim([min(mtots), max(mtots)])
ax.grid()
fig.savefig("number_templates_vs_mtotal_" + tag)
# clear the shared figure so the next plot starts from empty axes
fig.clf()

#
# plot templates vs m1 m2
#
ax = fig.gca()
# one point per template in the component-mass plane
ax.scatter(m1s, m2s)
ax.set_xlabel("$m_1$")
# wrapped in $...$ so it is typeset as math, matching the x label; without
# the delimiters matplotlib draws the literal text "m_2"
ax.set_ylabel("$m_2$")
ax.grid()
fig.savefig("m1m2_scatter_" + tag)
# clear the shared figure so the next plot starts from empty axes
fig.clf()

#
# plot templates vs total mass and mass ratio
#
ax = fig.gca()
# one point per template: x is m1 + m2, y is the mass ratio (always >= 1)
ax.scatter(mtots, mratios)
ax.set(xlabel="$M$", ylabel="$q$")
ax.grid()
fig.savefig("mtotq_scatter_" + tag)
# clear the shared figure so the next plot starts from empty axes
fig.clf()


#
# plot templates vs tau0 tau3
#
ax = fig.gca()
# tau0/tau3 are recomputed from the component masses rather than taken from
# the table's own tau0/tau3 columns.
# NOTE(review): the third argument (30) is presumably the low-frequency
# cutoff in Hz expected by m1m2_to_tau0tau3 -- confirm against sbank.
tau0, tau3 = m1m2_to_tau0tau3(np.array(m1s), np.array(m2s), 30)
ax.scatter(tau0, tau3)
ax.set_xlabel(r"$\tau_0$")
# wrapped in $...$ so it is typeset as math, matching the x label; without
# the delimiters matplotlib draws the literal text "\tau_3"
ax.set_ylabel(r"$\tau_3$")
ax.grid()
fig.savefig("tau0tau3_scatter_" + tag)
# clear the shared figure so the next plot starts from empty axes
fig.clf()


#
# plot number of templates versus mass ratio
#
ax = fig.gca()
# 10-bin histogram of the mass ratio with a logarithmic count axis
ax.hist(mratios, bins=10, log=True)
ax.set(xlabel="$m_1/m_2$", ylabel="Number of Templates")
# clamp the x axis to the range actually covered by the bank
ax.set_xlim(min(mratios), max(mratios))
ax.grid()
fig.savefig("number_templates_vs_mratio_" + tag)
fig.clf()
