#!/home/conda/feedstock_root/build_artifacts/bld/rattler-build_gstlal-ugly_1767612080/host_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_pl/bin/python
import numpy
import h5py
import sys, os

# FIXME this is all really stupid. We can change the code that makes these
# files to write an h5 file with a reasonable schema, or json, or whatever.
# I don't care, just not a custom format.
class AcCounts(object):
	"""Accumulate per-bin category counts parsed from text files.

	Attributes:
		mchirp: maps bin identifier -> value evaluated from the second
			whitespace-separated token of a file's first line.
		chieff: maps bin identifier -> value evaluated from the fourth
			whitespace-separated token of a file's first line.
		counts: maps bin identifier -> {category name: summed count}.
		norm: only present after normalize(); maps category name -> the
			sum of that category's counts over all bins.
	"""

	def __init__(self):
		self.mchirp = {}
		self.chieff = {}
		self.counts = {}

	def insert(self, fname, binnum):
		"""Parse the file fname and store its contents under binnum.

		The first line must split on whitespace into exactly four tokens;
		the second and fourth are evaluated as Python expressions and
		stored as the bin's mchirp and chieff.  The line after that is
		skipped (presumably a column header -- TODO confirm against the
		code that writes these files).  Every remaining line must consist
		of exactly six comma-separated fields: the first is the category
		name and the last is the count.  Counts for a repeated category
		name are summed.

		Any counts previously stored for binnum are replaced.

		Raises:
			ValueError: a line does not have the expected number of
				fields, or a count is not a valid float.
			IOError: the file could not be opened or read.
			In both cases a message naming the file is printed before the
			exception propagates.
		"""
		try:
			with open(fname) as f:
				_, mchirp, _, chieff = f.readline().split()
				# NOTE(review): eval() executes arbitrary Python taken from
				# the input file.  Only feed this trusted files, or switch
				# to ast.literal_eval once it is confirmed that the header
				# values are always plain literals.
				self.mchirp[binnum] = eval(mchirp)
				self.chieff[binnum] = eval(chieff)
				bin_counts = self.counts[binnum] = {}
				# [1:] drops the line that follows the header line
				for line in f.readlines()[1:]:
					name, _, _, _, _, count = line.split(",")
					if name in bin_counts:
						bin_counts[name] += float(count)
					else:
						bin_counts[name] = float(count)
		# a tuple is required here: "except ValueError as IOError" would
		# catch only ValueError and let I/O failures through unreported
		except (ValueError, IOError):
			print("%s could not be processed" % fname)
			raise

	def __str__(self):
		"""Return one line per bin, sorted by the bin's total count.

		Each line shows the total count, the bin identifier and the bin's
		individual category counts.
		"""
		total_counts = sorted((sum(self.counts[b].values()), b, self.counts[b].values()) for b in self.counts)
		return "\n".join("%e: %s %s" % (c, b, d) for (c, b, d) in total_counts)

	def normalize(self):
		"""Divide each count by its category's total over all bins, in place.

		The category names are taken from the first bin in self.counts and
		the per-category totals are stored in self.norm.  Categories whose
		total is not positive are left unscaled, as are categories that
		do not appear in the first bin.  When no bins have been inserted
		self.norm is set to an empty dict and nothing else happens.

		Raises:
			KeyError: a bin lacks a category that the first bin has.
		"""
		if not self.counts:
			# nothing to normalize; avoid indexing into an empty list
			self.norm = {}
			return
		# FIXME this does no error checking that category keys are consistent
		categories = next(iter(self.counts.values()))
		self.norm = dict((cat, 0.) for cat in categories)
		for bin_counts in self.counts.values():
			for cat in self.norm:
				self.norm[cat] += bin_counts[cat]
		for bin_counts in self.counts.values():
			for cat, total in self.norm.items():
				if total > 0.:
					bin_counts[cat] /= total

			

# Parse every input file named on the command line before touching the
# output, so that a bad input does not leave behind a truncated
# activation_counts.h5.
ac_counts = AcCounts()
for fname in sys.argv[1:]:
	# the bin identifier is the part of the file's base name before the first "-"
	binnum = os.path.basename(fname).split("-")[0]
	ac_counts.insert(fname, binnum)
ac_counts.normalize()

# Write one group per bin holding the two mchirp entries and one scalar
# dataset per category.  The context manager closes the file even if a
# dataset can not be created.
with h5py.File("activation_counts.h5", "w") as h5:
	for b, counts in sorted(ac_counts.counts.items()):
		grp = h5.create_group(b)
		grp.create_dataset("mchirp_max", data = ac_counts.mchirp[b][1])
		grp.create_dataset("mchirp_min", data = ac_counts.mchirp[b][0])
		for cat, count in counts.items():
			grp.create_dataset(cat, data = numpy.array(count))
