#!/home/conda/feedstock_root/build_artifacts/bld/rattler-build_gstlal-ugly_1767612080/host_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_pl/bin/python
#
# Copyright (C) 2019  Duncan Meacher
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
"""Split SVD bank into 'odd' or 'even' banks"""


#
#
# =============================================================================
#
#                                   Preamble
#
# =============================================================================
#

import itertools
from optparse import OptionParser
import os

import numpy as np

from ligo.lw import ligolw
from ligo.lw import lsctables
from ligo.lw import array as ligolw_array
from ligo.lw import param as ligolw_param

from gstlal import svd_bank
from gstlal.psd import read_psd

@ligolw_array.use_in
@ligolw_param.use_in
@lsctables.use_in
class LIGOLWContentHandler(ligolw.LIGOLWContentHandler):
	"""Content handler passed to svd_bank.read_banks().

	Adds nothing of its own: it is the stock ligolw content handler with
	the ``use_in`` hooks of ligo.lw's lsctables, param and array modules
	applied (in that order, innermost decorator first).
	"""

#
#
# =============================================================================
#
#                                 Command Line
#
# =============================================================================
#

# Build the command-line interface; the module docstring doubles as the
# program description shown by --help.
parser = OptionParser(description = __doc__)
parser.add_option("--svd-files", metavar = "filename", action = "append", help = "A LIGO light-weight xml / xml.gz file containing svd bank information (require). Can be passed multiple times." )
parser.add_option("--outdir", metavar = "directory", type = "str", default = ".", help = "Output directory for modified SVD files.")
parser.add_option("--reference-psd", metavar = "filename", help = "Load the spectrum from this LIGO light-weight XML file (required).")
parser.add_option("--in-place", action = "store_true", help = "If set, apply the SVD checkerboarding in place rather than generate new SVD files.")
parser.add_option("--even", action = "store_true", help = "Output even rows of reconstruction matrix. Default is odd.")
parser.add_option("-v", "--verbose", action = "store_true", help = "Be verbose (optional).")

# Positional arguments are collected in `filenames`; nothing in the visible
# part of this script reads them.
options, filenames = parser.parse_args()

# optparse has no built-in notion of a required option, so the options that
# the help text marks as required are enforced by hand.
if options.svd_files is None:
	raise ValueError("SVD file(s) must be selected. --svd-files bank1.xml --svd-files bank2.xml --svd-files bank3.xml")

# --reference-psd is documented as required as well; reject a missing value
# here instead of handing None to the PSD reader later on.
if options.reference_psd is None:
	raise ValueError("Reference PSD must be selected. --reference-psd psd.xml")

#
#
# =============================================================================
#
#                                     Main
#
# =============================================================================
#

# Read in reference PSD; it is handed unchanged to write_bank() for every
# output file.
psd = read_psd(options.reference_psd, verbose=options.verbose)

# Offset of the entries to keep in every quantity that is sliced with
# stride 2 below: 1 with --even, 0 otherwise (the default, "odd").  `drop`
# is the complementary offset, used where entries are deleted rather than
# selected.
keep = 1 if options.even else 0
drop = 1 - keep

# chifacs entries and mix_matrix columns are selected in adjacent pairs with
# stride 4; the two members of a pair are called "re" and "im" below
# (presumably the real and imaginary part of one template -- confirm
# against svd_bank).
re_start = 2 * keep
im_start = re_start + 1

# The output directory is the same for every bank file, so make sure it
# exists once, up front, before any bank is read or modified.  makedirs()
# also creates missing parent directories and, with exist_ok, does not fail
# if the directory appears between a check and the creation.
if not options.in_place:
	os.makedirs(options.outdir, exist_ok = True)

# Loop over all SVD files given in command line
for bank_file in options.svd_files:
	# read in SVD bank file
	banks = svd_bank.read_banks(bank_file, contenthandler = LIGOLWContentHandler, verbose = options.verbose)
	# Loop over each bank within SVD file
	for bank in banks:
		# Loop over each bank_fragment within each bank
		for frag in bank.bank_fragments:
			# Extract the kept entries of chifacs and the kept columns of
			# mix_matrix from this bank fragment
			chifacs_re = frag.chifacs[re_start::4]
			chifacs_im = frag.chifacs[im_start::4]
			mix_mat_re = frag.mix_matrix[:, re_start::4]
			mix_mat_im = frag.mix_matrix[:, im_start::4]

			# Re-interleave the two halves as re, im, re, im, ...
			frag.chifacs = np.array(list(itertools.chain.from_iterable(zip(chifacs_re, chifacs_im))))
			mix_mat_new = np.empty((mix_mat_re.shape[0], mix_mat_re.shape[1] + mix_mat_im.shape[1]))
			mix_mat_new[:, 0::2] = mix_mat_re
			mix_mat_new[:, 1::2] = mix_mat_im
			frag.mix_matrix = mix_mat_new

		# delete the unwanted entries from sngl_inspiral table, in place
		del bank.sngl_inspiral_table[drop::2]

		# keep only the wanted entries of sigmasq, correlation matrices/masks
		bank.sigmasq = bank.sigmasq[keep::2]
		bank.autocorrelation_bank = bank.autocorrelation_bank[keep::2, :]
		bank.autocorrelation_mask = bank.autocorrelation_mask[keep::2, :]
		bank.bank_correlation_matrix = bank.bank_correlation_matrix[keep::2, keep::2]

	# Set output filename: overwrite the input, or reuse its basename inside
	# the output directory
	if options.in_place:
		out_bank_file = bank_file
	else:
		out_bank_file = os.path.join(options.outdir, os.path.basename(bank_file))

	# Write out checkerboard SVD bank
	svd_bank.write_bank(out_bank_file, banks, psd, verbose = options.verbose)
