#!/home/conda/feedstock_root/build_artifacts/bld/rattler-build_gstlal-ugly_1767612080/host_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_pl/bin/python

import os
import sys
import math
from gstlal.plots import set_matplotlib_cache_directory
set_matplotlib_cache_directory()
import matplotlib
matplotlib.rcParams.update({
	"font.size": 10.0,
	"axes.titlesize": 10.0,
	"axes.labelsize": 10.0,
	"xtick.labelsize": 8.0,
	"ytick.labelsize": 8.0,
	"legend.fontsize": 8.0,
	"figure.dpi": 600,
	"savefig.dpi": 600,
	"text.usetex": True
})
from matplotlib import figure
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
import numpy

from optparse import OptionParser

from ligo.lw import utils
from ligo.lw import ligolw
from ligo.lw import table
from ligo.lw import lsctables
from ligo.lw.utils import process as ligolw_process

import astropy.cosmology as cosmo
import astropy.stats as ast
import astropy.units as u

__author__ = "Sarah Caudill <sarah.caudill@ligo.org>"
__version__ = "git id %s" % ""	# FIXME
__date__ = ""	# FIXME

@lsctables.use_in
class ContentHandler(ligolw.LIGOLWContentHandler):
    """
    Content handler used when loading the injection XML documents.

    The class adds nothing of its own to LIGOLWContentHandler;  it exists
    so that lsctables.use_in() can be applied to a private subclass
    instead of to the library's base class.
    """


def create_plot(x_label = None, y_label = None, width = 165.0, aspect = None):
    """
    Build an Agg-backed figure containing one set of axes with the grid
    switched on.

    x_label, y_label -- axis labels;  an axis whose label is None is left
        unlabelled.
    width -- figure width;  it is divided by 25.4 before being handed to
        set_size_inches(), i.e. it is interpreted as millimetres.
    aspect -- width / height ratio.  None selects the golden ratio.

    Returns the tuple (fig, axes).
    """
    golden_ratio = (1 + math.sqrt(5)) / 2
    ratio = golden_ratio if aspect is None else aspect

    fig = figure.Figure()
    # attaching a canvas is what allows fig.savefig() to render later on
    FigureCanvas(fig)

    width_inches = width / 25.4
    fig.set_size_inches(width_inches, width_inches / ratio)

    axes = fig.gca()
    axes.grid(True)
    for set_label, text in ((axes.set_xlabel, x_label), (axes.set_ylabel, y_label)):
        if text is not None:
            set_label(text)
    return fig, axes

def parse_command_line():
    """
    Parse the process's command line.

    Returns the tuple (options, urls):  options carries the "verbose"
    flag (True when -v/--verbose was given, otherwise None) and urls is
    the list of remaining positional arguments.
    """
    option_parser = OptionParser(description = __doc__)
    option_parser.add_option("-v", "--verbose", action = "store_true", help = "Be verbose.")

    # parse_args() already returns the (options, positional arguments) pair
    return option_parser.parse_args()

def _save_histogram(values, x_label, file_stem, tag, log = False):
    """
    Histogram values into 100 bins and save the figure as
    "<file_stem>_<tag>.png" in the current working directory.

    values -- sequence of numbers to histogram.
    x_label -- x axis label (the y axis is always labelled "Count").
    file_stem -- output file name without the "_<tag>.png" suffix.
    tag -- per-input-file suffix that keeps the outputs distinct.
    log -- if True the count axis is logarithmic.
    """
    fig, axes = create_plot(x_label, 'Count')
    axes.set_title(r"Plot")
    axes.hist(values, 100, log = log)
    fig.savefig('%s_%s.png' % (file_stem, tag))


def _save_redshift_plot(redshifts, tag):
    """
    Save "cosmological_distribution_<tag>.png":  the normalized histogram
    of redshifts overlaid with the Planck13 differential comoving volume
    divided by (1 + z), itself normalized to unit area over
    [0, max(redshifts)].

    redshifts -- non-empty sequence of redshift values.
    tag -- per-input-file suffix for the output file name.
    """
    fig, axes = create_plot('Redshift','Count')
    # "density" replaces the "normed" keyword, which Matplotlib 3.1 removed
    axes.hist(redshifts, 100, density = True)

    zs = numpy.linspace(0, max(redshifts), 1000)
    dvcdzo1pzs = cosmo.Planck13.differential_comoving_volume(zs).to(u.Gpc**3/u.sr).value/(1+zs)
    # numpy.trapz is deprecated since NumPy 2.0 in favour of
    # numpy.trapezoid;  use whichever this installation provides
    trapezoid = getattr(numpy, "trapezoid", None) or numpy.trapz
    propernorm = trapezoid(dvcdzo1pzs, zs)
    # a zero-width redshift range has no area to normalize by
    if propernorm > 0:
        axes.plot(zs, dvcdzo1pzs/propernorm, label='Cosmological Distribution')
        # without a legend the label above would never be displayed
        axes.legend()
    fig.savefig('cosmological_distribution_%s.png' % tag)


def _plot_injection_file(file_path, verbose = False):
    """
    Load the sim_inspiral table from file_path and write the full set of
    diagnostic histograms to the current working directory.  Every output
    name ends in "_<tag>.png", where tag is the input file's base name up
    to its first ".".

    file_path -- name of the LIGO Light-Weight XML document to read.
    verbose -- forwarded to the document loader and enables a progress
        message on stderr.
    """
    tag = os.path.basename(file_path).split(".")[0]
    # load the file named by this call, not sys.argv[1]:  the latter is
    # the same for every input and need not even be a file name
    xmldoc = utils.load_filename(file_path, verbose, contenthandler=ContentHandler)
    sim_table = table.get_table(xmldoc, lsctables.SimInspiralTable.tableName)

    if len(sim_table) == 0:
        # max() of the redshifts below would raise on an empty table
        sys.stderr.write("warning:  no sim_inspiral rows in '%s', skipping\n" % file_path)
        return

    def column(name):
        return numpy.array([getattr(sim, name) for sim in sim_table])

    mass1 = column("mass1")
    mass2 = column("mass2")
    mtotal = mass1 + mass2
    # alpha3 is the column this script plots as the redshift and uses to
    # convert detector-frame masses to source-frame masses
    redshift = column("alpha3")
    one_plus_z = 1.0 + redshift
    end_time = column("geocent_end_time") + 10**-9 * column("geocent_end_time_ns")

    # (values, x axis label, output file stem, logarithmic counts?)
    histograms = [
        (mass1, 'det mass1', 'det_mass1_hist', False),
        (mass2, 'det mass2', 'det_mass2_hist', False),
        (mtotal, 'det mtotal', 'det_mtotal_hist', False),
        (mass1 / one_plus_z, 'src mass1', 'src_mass1_loghist', True),
        (mass2 / one_plus_z, 'src mass2', 'src_mass2_loghist', True),
        (mtotal / one_plus_z, 'src mtotal', 'src_mtotal_loghist', True),
        (mass1, 'det mass1', 'det_mass1_loghist', True),
        (mass2, 'det mass2', 'det_mass2_loghist', True),
        (mtotal, 'det mtotal', 'det_mtotal_loghist', True),
        (mass1 / one_plus_z, 'src mass1', 'src_mass1_hist', False),
        (mass2 / one_plus_z, 'src mass2', 'src_mass2_hist', False),
        (mtotal / one_plus_z, 'src mtotal', 'src_mtotal_hist', False),
    ]
    histograms += [
        (column(name), name, '%s_hist' % name, False)
        for name in ("spin1x", "spin1y", "spin1z", "spin2x", "spin2y", "spin2z")
    ]
    histograms += [
        (column("distance"), 'distance', 'distance_hist', False),
        (end_time, 'Time', 'time_hist', False),
    ]

    for values, x_label, file_stem, log in histograms:
        _save_histogram(values, x_label, file_stem, tag, log = log)

    _save_redshift_plot(redshift, tag)

    if verbose:
        sys.stderr.write("wrote plots for '%s' with tag '%s'\n" % (file_path, tag))


options, urls = parse_command_line()

for file_path in urls:
    # store_true options default to None, hence the explicit bool()
    _plot_injection_file(file_path, verbose = bool(options.verbose))
