#!/usr/bin/env python3

import argparse
import re
import copy

import btllib

"""
Convert ACTG sequences less than k to lower case or "N"
"""


def get_cli_args():
    """Parse and validate the command-line arguments.

    Exactly one of ``-n`` (hard mask) or ``-s`` (soft mask) must be given,
    ``-k`` must be a positive integer, and a ``seqspath`` of ``-`` is
    rewritten to ``/dev/stdin``.

    Returns:
        The populated ``argparse.Namespace`` with attributes ``n``, ``s``,
        ``k`` and ``seqspath``.

    Raises:
        SystemExit: raised by ``parser.error`` / ``parse_args`` when the
            arguments are missing or invalid.
    """
    parser = argparse.ArgumentParser(
        description="Hard or soft-mask ACTG sequences less than supplied k value."
    )
    parser.add_argument(
        "-n", help="Hard-mask regions less than k.", action="store_true"
    )
    parser.add_argument(
        "-s", help="Soft-mask regions less than k.", action="store_true"
    )
    parser.add_argument("-k", help="k-mer size", required=True, type=int)
    parser.add_argument(
        "seqspath", help="Input file for masking (or - to read from stdin)."
    )

    args = parser.parse_args()

    if not args.n and not args.s:
        parser.error("Either -n or -s must be set")
    if args.n and args.s:
        parser.error(
            "Both -n and -s cannot be set -- choose one to hard mask OR soft mask."
        )

    # k is used as a slice bound downstream (seq[:k], seq[-k:]); zero or a
    # negative value would silently slice the wrong regions instead of failing.
    if args.k < 1:
        parser.error("-k must be a positive integer")

    if args.seqspath == "-":
        args.seqspath = "/dev/stdin"

    return args


def extend_gaps(reader, writer, k, soft_mask, hard_mask):
    """Mask short base runs in every record and write the result.

    For each record the first and last ``k`` characters are upper-cased (the
    whole sequence when it is shorter than ``2 * k``). The sequence is then
    split into runs, and every run shorter than ``k`` that does not start
    with an upper-case ``N`` is masked. Leading and trailing ``N``/``n`` are
    stripped from the result; if nothing is left, the sequence becomes ``"N"``.

    Args:
        reader: iterable of records exposing ``id``, ``comment``, ``seq`` and
            ``qual`` attributes.
        writer: object with a ``write(id, comment, seq, qual)`` method.
        k: minimum run length that is left unmasked; also the flank length.
        soft_mask: not consulted by the logic; any short run that is not
            hard-masked is lower-cased.
        hard_mask: when true, short runs are replaced by ``N`` of equal
            length instead of being lower-cased.

    NOTE: characters not covered by the pattern below are dropped, and the
    ends are stripped, so the written sequence can be shorter than the input
    while ``record.qual`` is passed through unchanged.
    """
    # Alternation order matters: upper-case ACGT runs first, then N/n runs,
    # then everything else. The last class also contains N/n, so a run that
    # starts with a lower-case base or ambiguity code absorbs any N/n that
    # follows it.
    run_pattern = re.compile(
        r"([ACTG]+|[Nn]+|[actgUNMRWSYKVHDBunmrwsykvhdb]+)"
    )

    for record in reader:
        seq = record.seq

        if len(seq) < 2 * k:
            seq = seq.upper()
        else:
            # Slice with the ``k`` parameter rather than the script-level
            # ``args`` so the function also works when called directly.
            seq = seq[:k].upper() + seq[k:-k] + seq[-k:].upper()

        pieces = []
        for run in run_pattern.findall(seq):
            if run[0] == "N" or len(run) >= k:
                # Gap runs and sufficiently long runs pass through untouched.
                pieces.append(run)
            elif hard_mask:
                pieces.append("N" * len(run))
            else:
                pieces.append(run.lower())

        new_seq = "".join(pieces).strip("Nn")
        if new_seq == "":
            # Keep a placeholder base when the whole sequence was masked away.
            new_seq = "N"

        record.seq = new_seq
        writer.write(record.id, record.comment, record.seq, record.qual)


if __name__ == "__main__":
    # Script entry point: parse the CLI, then stream every record from the
    # input through extend_gaps into a btllib writer opened on "-"
    # (presumably stdout -- confirm against the btllib documentation).
    args = get_cli_args()

    with btllib.SeqReader(args.seqspath, btllib.SeqReaderFlag.LONG_MODE) as reader:
        with btllib.SeqWriter("-") as writer:
            extend_gaps(
                reader,
                writer,
                k=args.k,
                soft_mask=args.s,
                hard_mask=args.n,
            )
