#!/opt/conda/bin/python
# Read GLAM2 output: write an HTML version of it 
import fileinput
import html
import os
import sys
import tempfile

# Module-level state shared by the parser and the HTML-writing code below.

# Set to 1 when the glam2 command line contains the -M switch; when set, the
# page gets a "re-run GLAM2" submit button and a hidden field holding the
# contents of the sequence file.
embed_seqs = 0
# The input line that starts with 'Version' (kept with its trailing newline).
version = ''
# The input line that contains 'glam2', i.e. the command that produced the output.
commandline = ''
# 'n' or 'p', chosen from the number of fields on the 'Residue counts' line;
# stays '?' until that line has been seen.
alphabet = '?'
# One [score, keypos, rows] entry per alignment found in the input.
alignments = []
# The raw text (key-positions line plus alignment rows) of each alignment.
text_alignments = []
# One [pspm, columns, sequences] entry per count matrix found in the input.
freq_matrices = []
# Current state (0 to 5) of the parser that reads standard input.
state = 0

# Usage message; %s is filled in with the script name taken from sys.argv[0].
usage = """
USAGE: %s

\tConvert glam2 output to HTML.

\tReads standard input.
\tWrites standard output.
""" % (sys.argv[0])


# parse command line
# The script takes no options, so any argument is an error: report the first
# one together with the usage message and stop with a non-zero exit status.
if len(sys.argv) > 1:
    sys.stderr.write("Unknown command line argument: " + sys.argv[1] + "\n")
    sys.stderr.write(usage)
    sys.exit(1)

# Parse the GLAM2 text output from standard input with a small state machine:
#   0  header: remember the 'Version' line, the line containing 'glam2' (the
#      command line) and pick the alphabet from the 'Residue counts' line
#   1  wait for a 'Score:' line (start of the next motif)
#   2  wait for the key-positions line (a line with exactly one field)
#   3  collect 6-field alignment rows until any other line is seen
#   4  wait for the count-matrix header (alen + 3 fields, 'Del' after the letters)
#   5  collect matrix rows (alen + 2 fields) until `columns` rows were read,
#      then go back to state 1 for the next motif
for line in sys.stdin:
    fields = line.split()

    if state == 0:
        if line.startswith('Version'):
            version = line
        elif 'glam2' in line:
            commandline = line
        elif line.startswith('Residue counts'):
            if len(fields) == 4 + 3:
                alphabet = 'n'
                alen = 4
            elif len(fields) == 20 + 3:
                alphabet = 'p'
                alen = 20
            else:
                # Without a known alphabet size the matrices cannot be read,
                # so stop here instead of failing later on an undefined alen.
                sys.stderr.write("Unrecognised alphabet in line: " + line)
                sys.exit(1)
            state += 1

    elif state == 1:
        if len(fields) > 1 and fields[0] == 'Score:':
            score = fields[1]
            columns = int(fields[3])
            sequences = int(fields[5])
            state += 1

    elif state == 2:
        if len(fields) == 1:
            keypos = fields[0]
            aln = []
            alignments.append([score, keypos, aln])
            text_alignments.append(line)
            state += 1

    elif state == 3:
        if len(fields) == 6:
            # Every aligned sequence must be as wide as the key-positions line.
            if len(fields[2]) != len(keypos):
                sys.stderr.write("Alignment row does not match the key positions: " + line)
                sys.exit(1)
            aln.append(fields)
            text_alignments[-1] += line
        else:
            state += 1

    elif state == 4:
        if len(fields) == alen + 3 and fields[alen] == 'Del':
            ipos = 0
            pspm = []
            freq_matrices.append([pspm, columns, sequences])
            state += 1

    elif state == 5:
        if len(fields) == alen + 2:
            pspm.append(fields[:alen])
            ipos += 1
            if ipos == columns:
                state = 1

# Refuse to write a page for input that holds no usable motif, or in which an
# alignment is missing its count matrix.
if not alignments:
    sys.stderr.write("No alignments found in the GLAM2 output\n")
    sys.exit(1)
if len(freq_matrices) != len(alignments):
    sys.stderr.write("The GLAM2 output is incomplete: an alignment has no count matrix\n")
    sys.exit(1)

# Write the top of the page: document head, style sheet, title, citation and
# the opening of the form that wraps everything that follows.
for _markup in (
    '<html>',
    '<head>',
    '<title>GLAM2 Results</title>',
    '<style type="text/css">',
    'body {background: #D5F0FF}',
    'th {text-align: left}',
    'td.kk {font-family: monospace; font-weight: bold}',
    'td.jj {font-family: monospace; color: gray}',
):
    print(_markup)

# Residue colors, copied from MEME: each entry is (letters, color) and
# produces one CSS rule per letter, in the order listed.
if alphabet == 'n':
    _residue_colors = (
        ('a', 'red'),
        ('c', 'blue'),
        ('g', 'orange'),
        ('t', 'green'),
    )
elif alphabet == 'p':
    _residue_colors = (
        ('NQST', 'green'),
        ('DE', 'magenta'),
        ('KR', 'red'),
        ('H', 'pink'),
        ('G', 'orange'),
        ('P', 'yellow'),
        ('Y', 'turquoise'),
        ('ACFILMVW', 'blue'),
    )
else:
    _residue_colors = ()
for _letters, _color in _residue_colors:
    for _letter in _letters:
        print('td.' + _letter + ' {color: ' + _color + '}')

for _markup in (
    '</style>',
    '</head>',
    '<body>',
    '<h1>GLAM2: Gapped Local Alignment of Motifs</h1>',
):
    print(_markup)

# Program version and the command line that produced the results.
print('<p style="font-family: monospace"> %s <br> %s </p>' % (version, commandline))

# Citation request.
print(
    '<p>If you use this program in your research, please cite:<b> '
    'MC Frith, NFW Saunders, B Kobe, TL Bailey, &quot;Discovering sequence motifs with arbitrary insertions and deletions&quot;, PLoS Computational Biology, <b>4</b>(5):e1000071, 2008.'
    '<a href="http://www.ploscompbiol.org/article/info:doi/10.1371/journal.pcbi.1000071">'
    '[full text]</a></b></p>'
)

# Opening of the form; it is closed at the very end of the page.
for _markup in (
    '<form enctype="application/x-www-form-urlencoded" method="POST" target="_new" action="https://meme-suite.org/meme/utilities/process_request">',
    '<ul>',
    '</ul>',
):
    print(_markup)

# Print the command line parameters as hidden form fields.
# The command line is read as: program name, options, alphabet, sequence file.
# The last two tokens are therefore never treated as options.  After this
# block `fields[i]` is the sequence file name, which later code relies on.
fields = commandline.split()
if len(fields) < 3:
    sys.stderr.write("Cannot find a complete glam2 command line in the input\n")
    sys.exit(1)

i = 1
while i < len(fields)-2:
    switch = fields[i]
    if switch in ('-2', '-Q'):    # simple flags
        print('<input type="hidden" name="' + html.escape(switch) + '" value="1">')
        i += 1
    elif switch == '-M':    # embed the input sequences in the page
        embed_seqs = 1
        i += 1
    elif switch == '-A':    # the email address, with single quotes removed
        address = fields[i+1].replace('\'', '')
        print('<input type="hidden" name="address" value="' + html.escape(address) + '">')
        print('<input type="hidden" name="address_verify" value="' + html.escape(address) + '">')
        i += 2
    elif switch == '-X':    # description field; replace '&' with ' ', drop single quotes
        description = fields[i+1].replace('&', ' ').replace('\'', '')
        print('<input type="hidden" name="description" value="' + html.escape(description) + '">')
        i += 2
    else:   # other switches with values
        print('<input type="hidden" name="' + html.escape(switch) + '" value="' + html.escape(fields[i+1]) + '">')
        i += 2

# A switch that swallowed the alphabet as its value leaves no room for the two
# trailing arguments; report that instead of failing on an index error.
if i != len(fields)-2:
    sys.stderr.write("Cannot find the alphabet and sequence file in the glam2 command line\n")
    sys.exit(1)

# Values are escaped so that quotes or angle brackets cannot break out of the
# attribute; browsers decode them again before submitting the form.
print('<input type="hidden" name="alphabet" value="' + html.escape(fields[i]) + '">')
i = i+1
print('<input type="hidden" name="seq_file" value="' + html.escape(fields[i]) + '">')

def regexp(aln, keypos):
    """Print a regular-expression summary of a motif, similarly to MEME.

    aln    -- alignment rows; element 2 of each row is the aligned sequence
    keypos -- key-positions string in which '*' marks an aligned column

    Column 0 is always treated as an aligned column; the remaining aligned
    columns are located by searching keypos for '*'.  The expression is
    written to standard output wrapped in <b> ... </b>.
    """
    seqs = [entry[2] for entry in aln]
    by_column = list(zip(*seqs))
    parts = []
    col = 0
    while True:
        # Count the letters in this aligned column, keeping first-seen order.
        tally = {}
        for symbol in by_column[col]:
            if symbol.isalpha():
                tally[symbol] = tally.get(symbol, 0) + 1
        total = sum(tally.values())
        # Keep the letters that make up at least a fifth of the letters seen.
        frequent = [symbol for symbol, seen in tally.items()
                    if seen > 0 and seen*5 >= total]
        if not frequent:
            parts.append('.')
        elif len(frequent) == 1:
            parts.append(frequent[0])
        else:
            parts.append('[' + ''.join(frequent) + ']')
        # A '.' in an aligned column makes the position optional.
        if '.' in by_column[col]:
            parts.append('?')

        # Move on to the next aligned column, if there is one.
        after = col + 1
        col = keypos.find('*', after)
        if col == -1:
            break

        # Columns skipped in between are an insertion: its longest form fills
        # all of them, its shortest form is what the sequence with the most
        # '.' characters in that stretch has left.
        if col > after:
            widest = col - after
            dots = [s[after:col].count('.') for s in seqs]
            narrowest = widest - max(dots)
            parts.append('.{' + str(narrowest) + ',' + str(widest) + '}')
    print('<b>', ''.join(parts), '</b>')

def alntable(aln, keypos):
    """Print an alignment as an HTML table, one table row per sequence.

    aln    -- alignment rows; each row holds, in order: name, start, aligned
              sequence, end, strand and marginal score
    keypos -- key-positions string; a '*' at a position marks a key column

    Each character of the aligned sequence gets its own cell.  Cells in key
    columns use the classes "kk" plus the character itself (which selects the
    residue color from the style sheet); all other cells use class "jj".
    Every value taken from the alignment is HTML-escaped, because sequence
    names come from the user's input and may contain markup characters.
    """
    print('<table>')
    print('<tr>')
    print('<th style="padding-right: 1em">NAME</th>')
    print('<th style="padding-right: 1em">START</th>')
    print('<th style="text-align: center" colspan="', len(keypos), '">SITES</th>')
    print('<th style="padding-left: 1em">END</th>')
    print('<th style="padding-left: 1em">STRAND</th>')
    print('<th style="padding-left: 1em">MARGINAL SCORE</th>')
    print('</tr>')
    print('<tbody>')
    for row in aln:
        print('<tr>')
        print('<td style="padding-right: 1em">', html.escape(row[0]), '</td>')
        print('<td style="padding-right: 1em;text-align: right">', html.escape(row[1]), '</td>')
        for i, key in enumerate(keypos):
            cell = html.escape(row[2][i])
            if key == '*':
                print('<td class="kk', cell, '">', cell, '</td>')
            else:
                print('<td class="jj">', cell, '</td>')
        print('<td style="padding-left: 1em;text-align: right">', html.escape(row[3]), '</td>')
        print('<td style="padding-left: 1em;text-align: center">', html.escape(row[4]), '</td>')
        print('<td style="padding-left: 1em">', html.escape(row[5]), '</td>')
        print('</tr>')
    print('</tbody>')
    print('</table>')

def print_aln(imotif, text_aln):
    """Print the form controls for one alignment.

    imotif   -- 1-based motif number used in the button labels and field name
    text_aln -- the alignment as plain text (key-positions line plus rows)

    Writes a "Scan alignment" button, a hidden field named aln<imotif> that
    carries the alignment text, and a "View alignment" button.  The text is
    HTML-escaped so that a quote or angle bracket in a sequence name cannot
    end the attribute early; browsers decode it again on submission.
    """
    print('<input type="submit" name="action" value="Scan alignment %d"><b> against sequence databases using <a href="https://meme-suite.org/meme/doc/glam2scan.html">GLAM2SCAN.</a></b>' % (imotif))
    # The "dots" in the value field are needed so Firefox doesn't remove whitespace!
    print('<input type="hidden" name="aln' + str(imotif) + '" value=".\n' + html.escape(text_aln) + '.">')
    print('<br><input type="submit" name="action" value="View alignment %d">' % (imotif))

def print_pspm(imotif, pspm, columns, sequences, alphabet):
    """Print the form controls for one motif's letter-probability matrix.

    imotif    -- 1-based motif number used in the field name and button labels
    pspm      -- count matrix: one list of count strings per motif column
    columns   -- motif width, written as "w="
    sequences -- number of sequences, written as "nsites="
    alphabet  -- 'n' adds the button for comparing the motif with Tomtom

    The alphabet size is read from the module-level `alen`.  The counts of
    each column are normalised to frequencies; a column whose counts sum to
    zero gets the uniform frequency 1/alen for every letter.

    The matrix goes into a hidden field named pspm<imotif>, with the header
    on one line and one line of space-separated frequencies per column.
    """
    print('<input type="hidden" name="pspm' + str(imotif) + '" value="')
    # Header and each matrix row must each stay on a single line; separate
    # print calls would put every value on a line of its own.
    print('letter-probability matrix: alength= ' + str(alen)
          + ' w= ' + str(columns) + ' nsites= ' + str(sequences) + ' E= 1')
    uniform = 1.0/alen
    for col in pspm:
        counts = [float(entry) for entry in col]
        match_count = sum(counts)
        if match_count == 0:
            freqs = [uniform] * len(counts)
        else:
            freqs = [count/match_count for count in counts]
        print(' '.join(str(freq) for freq in freqs))
    print('">')
    print('<input type="submit" name="action" value="View PSPM %d">' % (imotif))
    if alphabet == 'n':
        print('<br><input type="submit" name="action" value="Compare PSPM %d"><b> to known motifs in motif databases using <a href="https://meme-suite.org/meme/doc/tomtom.html">Tomtom</a>.</b>' % (imotif))

def seqlogo(imotif, score):
    """Print a small table with a motif's score and sequence logo.

    imotif -- motif number; selects the images logo<imotif>.png and
              logo_ssc<imotif>.png
    score  -- the motif score, shown in bold

    The logo image spans both table rows; the second row links to the logo
    with small-sample correction ("ssc").
    """
    print('<table> <tr> <td> Score: <b>', score, '</b></td>')
    print('<td rowspan=2><img src="logo' + str(imotif) + '.png"></td></tr>')
    # The cell is opened with <th>, so it has to be closed with </th>.
    print('<tr><th><a href="logo_ssc' + str(imotif) + '.png">logo with ssc</a></th></tr></table>')

# Write one section per motif.  The first alignment is the best motif found;
# any further alignments are replicates and are introduced by a heading that
# is printed once, in front of the first replicate.
for imotif, a in enumerate(alignments):
    t = text_alignments[imotif]
    f = freq_matrices[imotif]
    if imotif == 0:
        print('<h2>Best Motif Found:</h2>')
        alntable(a[2], a[1])
    elif imotif == 1:
        print('<h2>Replicates:</h2>')
        print('<p>Check that at least some of these are similar to the best motif found.<br>If not, there may well be an even better motif.<BR>Click here to ')
        # The rest of the sentence belongs to the paragraph opened above, so
        # it is printed only here and not once per motif.  The re-run button
        # is offered only when the sequences are embedded in the page.
        if embed_seqs:
            print('<input type="submit" name="action" value="re-run GLAM2">')
        else:
            print('re-run GLAM2')
        print('with double the "maximum number of iterations without improvement" to find it.</p>')
    print('<p>')
    seqlogo(imotif+1, a[0])
    print_aln(imotif+1, t)
    print_pspm(imotif+1, f[0], f[1], f[2], alphabet)
    print('<br><i>Regular Expression for Motif: </i>')
    regexp(a[2], a[1])
    print('<br>')

if len(alignments) < 2:
    print('<strong>No replicates were performed!</strong>')
print('</p>')

# Embed the input sequences in the form when the command line asked for it.
# At this point fields[i] is the sequence file name from the command line.
if embed_seqs:
    # Escape the file content: a double quote in it would otherwise end the
    # value attribute early.  Browsers decode it again on submission.
    with open(fields[i], 'r') as seq_file:
        print('<input type="hidden" name="data" value="\n' + html.escape(seq_file.read()) + '">')

# close the HTML:
print('</form></body>')
print('</html>')

