#!/usr/bin/env python3
from __future__ import annotations

import argparse
import os
import sys
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple


MAP_SAMPLEID = "#SAMPLEID"
MAP_FIRST_READ = "FIRST"
MAP_SECOND_READ = "SECOND"
MAP_SAM = "SAM"
MAP_PROFILE = "PROFILE"
MAP_PROFILE_TRUTH = "PROFILE_TRUTH"
MAP_HEADER_TO_SPECIES = "HEADER_TO_SPECIES"
MAP_PREFIX = "PREFIX"

MAP_VAR_INPUT_DIR = "#INPUT_DIR"
MAP_VAR_OUTPUT_DIR = "#OUTPUT_DIR"
MAP_VAR_SAM_OUTPUT_DIR = "#SAM_OUTPUT_DIR"
MAP_VAR_PROFILE_OUTPUT_DIR = "#PROFILE_OUTPUT_DIR"
MAP_VAR_STRAIN_OUTPUT_DIR = "#STRAIN_OUTPUT_DIR"
MAP_VAR_MISC_OUTPUT_DIR = "#MISC_OUTPUT_DIR"

MAP_VAR_DEFAULT_SAM_OUTPUT_DIR = "alignments"
MAP_VAR_DEFAULT_PROFILE_OUTPUT_DIR = "profiles"
MAP_VAR_DEFAULT_STRAIN_OUTPUT_DIR = "strains"
MAP_VAR_DEFAULT_MISC_OUTPUT_DIR = "misc"
READ_EXTENSIONS = (".fq", ".fastq", ".fq.gz", ".fastq.gz")


@dataclass
class MapFile:
    """In-memory representation of a parsed tab-separated map file."""

    path: str  # source path of the map file; "" for maps synthesized in memory
    vars: Dict[str, str]  # '#KEY' -> value variable assignments
    var_lines: List[Tuple[str, str]]  # variable (key, value) pairs in original file order
    header_cols: List[str]  # header tokens; first is '#SAMPLEID'
    rows: List[List[str]]  # data rows, each already split on tabs


def parse_map(path: str) -> MapFile:
    """Parse a tab-separated map file into a MapFile.

    Expected layout: optional '#KEY<TAB>VALUE' variable lines, then a header
    line whose first column is #SAMPLEID, then data rows. Blank lines are
    skipped. Raises ValueError for a '#' line after data rows, a variable
    line without a value, a data row before the header, or a missing header.
    """
    variables: Dict[str, str] = {}
    ordered_vars: List[Tuple[str, str]] = []
    columns: List[str] = []
    data_rows: List[List[str]] = []

    with open(path, "r", encoding="utf-8") as fh:
        for num, raw in enumerate(fh, start=1):
            stripped = raw.rstrip("\n")
            if not stripped:
                continue
            fields = stripped.split("\t")
            if not stripped.startswith("#"):
                # Plain data row: only legal once the header has been seen.
                if not columns:
                    raise ValueError(
                        f"{path}:{num}: expected header line starting with {MAP_SAMPLEID}"
                    )
                data_rows.append(fields)
                continue
            # '#' lines (variables and the header) must precede all data rows.
            if columns:
                raise ValueError(
                    f"{path}:{num}: did not expect header line after data rows"
                )
            if fields[0] == MAP_SAMPLEID:
                columns = fields
                continue
            if len(fields) < 2:
                raise ValueError(
                    f"{path}:{num}: expected value for key {fields[0]}"
                )
            variables[fields[0]] = fields[1]
            ordered_vars.append((fields[0], fields[1]))

    if not columns:
        raise ValueError(f"{path}: missing {MAP_SAMPLEID} header line")

    return MapFile(path=path, vars=variables, var_lines=ordered_vars, header_cols=columns, rows=data_rows)


def find_col(header_cols: List[str], name: str) -> Optional[int]:
    """Return the index of *name* in the header, or None when absent."""
    return header_cols.index(name) if name in header_cols else None


def map_dir(path: str) -> str:
    """Return the absolute directory containing *path*."""
    absolute = os.path.abspath(path)
    return os.path.dirname(absolute)


def input_base(map_file: MapFile) -> str:
    """Base directory for read paths: #INPUT_DIR if set, else the map's own directory."""
    try:
        return map_file.vars[MAP_VAR_INPUT_DIR]
    except KeyError:
        return map_dir(map_file.path)


def output_base(map_file: MapFile) -> Optional[str]:
    """Return the configured #OUTPUT_DIR variable, or None when absent."""
    try:
        return map_file.vars[MAP_VAR_OUTPUT_DIR]
    except KeyError:
        return None


def sam_output_base(map_file: MapFile) -> Optional[str]:
    """Directory for SAM outputs.

    Prefers an explicit #SAM_OUTPUT_DIR; otherwise derives the default
    'alignments' subdirectory under #OUTPUT_DIR; None when neither is set.
    """
    explicit = map_file.vars.get(MAP_VAR_SAM_OUTPUT_DIR)
    if explicit:
        return explicit
    parent = output_base(map_file)
    return os.path.join(parent, MAP_VAR_DEFAULT_SAM_OUTPUT_DIR) if parent else None


def profile_output_base(map_file: MapFile) -> Optional[str]:
    """Directory for profile outputs.

    Prefers an explicit #PROFILE_OUTPUT_DIR; otherwise derives the default
    'profiles' subdirectory under #OUTPUT_DIR; None when neither is set.
    """
    explicit = map_file.vars.get(MAP_VAR_PROFILE_OUTPUT_DIR)
    if explicit:
        return explicit
    parent = output_base(map_file)
    return os.path.join(parent, MAP_VAR_DEFAULT_PROFILE_OUTPUT_DIR) if parent else None


def resolve_path(base: Optional[str], value: str, fallback_base: str) -> str:
    """Resolve *value* against *base* (or *fallback_base* when base is falsy).

    Absolute values are returned untouched; relative values are joined to the
    chosen base and normalized to an absolute path.
    """
    if os.path.isabs(value):
        return value
    return os.path.abspath(os.path.join(base or fallback_base, value))


def write_map(map_file: MapFile) -> None:
    """Serialize a MapFile to stdout: variable lines, header, then data rows."""
    emit = sys.stdout.write
    for key, value in map_file.var_lines:
        emit(f"{key}\t{value}\n")
    emit("\t".join(map_file.header_cols) + "\n")
    for row in map_file.rows:
        emit("\t".join(row) + "\n")


def validate_map(path: str) -> int:
    """Check that every FIRST/SECOND read file referenced by the map exists.

    Returns 0 when all files are present, 1 otherwise (missing paths are
    listed on stderr). Raises ValueError when the required FIRST/SECOND
    columns are absent from the header.
    """
    mf = parse_map(path)
    idx_first = find_col(mf.header_cols, MAP_FIRST_READ)
    idx_second = find_col(mf.header_cols, MAP_SECOND_READ)
    if idx_first is None or idx_second is None:
        raise ValueError(
            f"{path}: missing required columns {MAP_FIRST_READ} and {MAP_SECOND_READ}"
        )

    base = input_base(mf)
    fallback = map_dir(mf.path)
    problems: List[str] = []
    for num, row in enumerate(mf.rows, start=1):
        if len(row) <= max(idx_first, idx_second):
            problems.append(f"row {num}: missing FIRST/SECOND value")
            continue
        for idx in (idx_first, idx_second):
            full = resolve_path(base, row[idx], fallback)
            if not os.path.exists(full):
                problems.append(f"row {num}: {full}")

    if not problems:
        return 0
    sys.stderr.write("Missing read files:\n")
    for entry in problems:
        sys.stderr.write(f"- {entry}\n")
    return 1


def flatten_map(path: str, output_dir_override: Optional[str]) -> None:
    """Print the map to stdout with FIRST/SECOND rewritten as absolute paths.

    The #INPUT_DIR variable is dropped (absolute paths make it redundant);
    when *output_dir_override* is given it replaces any #OUTPUT_DIR variable.
    Raises ValueError when the FIRST/SECOND columns are absent.
    """
    mf = parse_map(path)
    idx_first = find_col(mf.header_cols, MAP_FIRST_READ)
    idx_second = find_col(mf.header_cols, MAP_SECOND_READ)
    if idx_first is None or idx_second is None:
        raise ValueError(
            f"{path}: missing required columns {MAP_FIRST_READ} and {MAP_SECOND_READ}"
        )

    base = input_base(mf)
    fallback = map_dir(mf.path)
    for row in mf.rows:
        for idx in (idx_first, idx_second):
            if idx < len(row):
                row[idx] = resolve_path(base, row[idx], fallback)

    mf.vars.pop(MAP_VAR_INPUT_DIR, None)
    mf.var_lines = [pair for pair in mf.var_lines if pair[0] != MAP_VAR_INPUT_DIR]
    if output_dir_override:
        mf.vars[MAP_VAR_OUTPUT_DIR] = output_dir_override
        mf.var_lines = [pair for pair in mf.var_lines if pair[0] != MAP_VAR_OUTPUT_DIR]
        mf.var_lines.append((MAP_VAR_OUTPUT_DIR, output_dir_override))
    write_map(mf)


def common_path(paths: List[str]) -> Optional[str]:
    """Longest common sub-path of *paths*, or None for an empty list."""
    return os.path.commonpath(paths) if paths else None


def iter_read_files(input_dirs: List[str]) -> List[str]:
    """Recursively collect all read files (by known extension) under the given dirs."""
    found: List[str] = []
    for directory in input_dirs:
        for root, _dirs, names in os.walk(directory):
            found.extend(
                os.path.join(root, name)
                for name in names
                if name.endswith(READ_EXTENSIONS)
            )
    return found


def omit_extension(name: str) -> str:
    """Strip a known read-file extension from *name*, if one matches."""
    matched = next((ext for ext in READ_EXTENSIONS if name.endswith(ext)), None)
    return name if matched is None else name[: -len(matched)]


def hamming_distance(a: str, b: str) -> int:
    """Count differing positions; unequal lengths yield the longer length."""
    if len(a) != len(b):
        return max(len(a), len(b))
    distance = 0
    for left, right in zip(a, b):
        if left != right:
            distance += 1
    return distance


def is_pair(read1: str, read2: str) -> bool:
    """Heuristically decide whether two read names form an R1/R2 mate pair.

    Typically called with extension-stripped basenames (see
    generate_samples_from_reads). Requires identical directory components,
    identical length, exactly one differing character, and that character
    being '1' in one name and '2' in the other with no digit immediately
    before or after it (so digits inside longer numbers do not qualify).
    """
    r1 = os.path.split(read1)
    r2 = os.path.split(read2)
    # Directory parts must match exactly (joined to compare the tuples).
    if " ".join(r1[:-1]) != " ".join(r2[:-1]):
        return False
    # Exactly one mismatching character at equal lengths.
    if hamming_distance(read1, read2) != 1 or len(read1) != len(read2):
        return False
    # NOTE(review): the scan starts at index 1, so a difference in the very
    # first character can never be accepted as the pair marker — confirm
    # this is intended.
    for i in range(1, len(read1)):
        c1 = read1[i]
        c2 = read2[i]
        # Accept the differing position only when neither neighbor is a
        # digit, i.e. the character is not part of a multi-digit number.
        if (
            c1 != c2
            and not read1[i - 1].isdigit()
            and not read2[i - 1].isdigit()
            and (i == (len(read1) - 1) or (not read1[i + 1].isdigit() and not read2[i + 1].isdigit()))
        ):
            if (c1 == "1" and c2 == "2") or (c1 == "2" and c2 == "1"):
                return True
    return False


def first_diff_index(str1: str, str2: str) -> int:
    """Index of the first differing character; -1 when equal or lengths differ."""
    if len(str1) != len(str2):
        return -1
    return next(
        (pos for pos, (a, b) in enumerate(zip(str1, str2)) if a != b),
        -1,
    )


def remove_pair(read1: str, read2: str) -> str:
    """Derive a sample name from a read pair.

    When the two names differ only in their final character, that character
    and any trailing '_', '/', '.' separators are dropped from read1;
    otherwise read1 is returned unchanged.
    """
    separators = {"_", "/", "."}
    diff = first_diff_index(read1, read2)
    if diff != len(read1) - 1:
        return read1
    end = diff - 1
    while end > 0 and read1[end] in separators:
        end -= 1
    return read1[: end + 1]


def get_relevant_indices(path_list: List[str]) -> List[int]:
    """Return path-component indices that distinguish the given paths.

    Indices beyond the shortest path's length come first, followed by the
    indices (within the shortest length) where the paths do not all agree.
    Raises ValueError on an empty list.
    """
    if not path_list:
        raise ValueError("No read files found in input directory")
    split_paths = [p.split("/") for p in path_list]
    lengths = [len(parts) for parts in split_paths]
    shortest, longest = min(lengths), max(lengths)
    relevant = list(range(shortest, longest))
    reference = split_paths[0]
    for idx in range(shortest):
        if any(parts[idx] != reference[idx] for parts in split_paths):
            relevant.append(idx)
    return relevant


def generate_sample_name(read1: str, read2: str, indices: List[int]) -> str:
    """Build a sample name from distinguishing directory components of read1
    plus the shared (pair-marker-stripped) basename of the read pair."""
    dir_parts = read1.split("/")[:-1]
    base1 = omit_extension(os.path.basename(read1))
    base2 = omit_extension(os.path.basename(read2))
    prefix_parts = [dir_parts[i] for i in indices if i < len(dir_parts)]
    stem = remove_pair(base1, base2)
    if prefix_parts:
        return "_".join(prefix_parts) + "_" + stem
    return stem


def generate_samples_from_reads(read_files: List[str], gzip_sam: bool) -> List[Dict[str, str]]:
    """Pair adjacent read files and build one sample record per detected pair.

    Files are sorted so that mates sort next to each other; files with no
    adjacent mate are skipped. *gzip_sam* selects the '.sam.gz' vs '.sam'
    extension for the SAM output name.
    """
    ordered = sorted(read_files)
    indices = get_relevant_indices(ordered)
    samples: List[Dict[str, str]] = []
    pos = 0
    while pos + 1 < len(ordered):
        path_a = ordered[pos]
        path_b = ordered[pos + 1]
        name_a = omit_extension(os.path.basename(path_a))
        name_b = omit_extension(os.path.basename(path_b))
        if not is_pair(name_a, name_b):
            pos += 1
            continue
        sample = generate_sample_name(path_a, path_b, indices)
        samples.append(
            {
                MAP_SAMPLEID: sample,
                MAP_FIRST_READ: path_a,
                MAP_SECOND_READ: path_b,
                MAP_SAM: f"{sample}.sam.gz" if gzip_sam else f"{sample}.sam",
                MAP_PROFILE: f"{sample}.profile",
                MAP_PREFIX: sample,
            }
        )
        pos += 2
    return samples


def generate_map(input_dirs: List[str], output_dir: str, gzip_sam: bool) -> None:
    """Scan input directories for paired reads and print a new map to stdout.

    Read files are paired heuristically (see generate_samples_from_reads);
    sample rows use paths relative to the common input directory, and the
    default output subdirectory variables are emitted. Raises ValueError on
    empty input, no reads found, or duplicate sample ids / file entries.
    """
    input_dirs_abs = [os.path.abspath(d) for d in input_dirs]
    input_dir = common_path(input_dirs_abs)
    if not input_dir:
        raise ValueError("No input directories provided")

    all_reads = iter_read_files(input_dirs_abs)
    if not all_reads:
        raise ValueError("No read files found in input directories")

    # Relativize against the common input dir so the generated map is
    # portable; paths outside that prefix are silently dropped.
    reads_rel = [
        strip_prefix(os.path.normpath(path), input_dir)
        for path in all_reads
        if os.path.normpath(path).startswith(input_dir)
    ]
    samples = generate_samples_from_reads(reads_rel, gzip_sam)

    # Reject duplicate sample ids and duplicate file/output entries; the
    # duplicate check shares one dict across all path-like columns.
    seen_sample_ids: Dict[str, str] = {}
    seen_files: Dict[str, str] = {}
    for idx, sample in enumerate(samples, start=1):
        sample_id = sample[MAP_SAMPLEID]
        if sample_id in seen_sample_ids:
            raise ValueError(
                f"Duplicate sample id '{sample_id}' in generate (rows {seen_sample_ids[sample_id]} and {idx})"
            )
        seen_sample_ids[sample_id] = str(idx)
        for col in (MAP_FIRST_READ, MAP_SECOND_READ, MAP_SAM, MAP_PROFILE, MAP_PREFIX):
            value = sample.get(col, "")
            if not value:
                continue
            if value in seen_files:
                raise ValueError(
                    f"Duplicate {col} entry '{value}' in generate (rows {seen_files[value]} and {idx})"
                )
            seen_files[value] = str(idx)

    header_cols = [
        MAP_SAMPLEID,
        MAP_FIRST_READ,
        MAP_SECOND_READ,
        MAP_SAM,
        MAP_PROFILE,
        MAP_PREFIX,
    ]
    rows = [[sample.get(col, "") for col in header_cols] for sample in samples]

    # Declare input/output roots plus the default output subdirectories.
    var_lines = [
        (MAP_VAR_INPUT_DIR, input_dir),
        (MAP_VAR_OUTPUT_DIR, os.path.abspath(output_dir)),
        (MAP_VAR_SAM_OUTPUT_DIR, MAP_VAR_DEFAULT_SAM_OUTPUT_DIR),
        (MAP_VAR_PROFILE_OUTPUT_DIR, MAP_VAR_DEFAULT_PROFILE_OUTPUT_DIR),
        (MAP_VAR_STRAIN_OUTPUT_DIR, MAP_VAR_DEFAULT_STRAIN_OUTPUT_DIR),
        (MAP_VAR_MISC_OUTPUT_DIR, MAP_VAR_DEFAULT_MISC_OUTPUT_DIR),
    ]
    merged = MapFile(
        path="",
        vars={k: v for k, v in var_lines},
        var_lines=var_lines,
        header_cols=header_cols,
        rows=rows,
    )
    write_map(merged)


def strip_prefix(path: str, prefix: str) -> str:
    """Return *path* relative to *prefix*: '.' when equal, the remainder when
    nested under it, otherwise the absolute path unchanged."""
    abs_prefix = os.path.abspath(prefix)
    abs_path = os.path.abspath(path)
    if abs_path == abs_prefix:
        return "."
    anchored = abs_prefix + os.sep
    if abs_path.startswith(anchored):
        return abs_path[len(anchored):]
    return abs_path


def normalize_prefix(value: str, output_dir: Optional[str]) -> str:
    """Rewrite a PREFIX value relative to the merged output directory.

    Relative values (or no output dir) only get a leading default
    misc/strain subdirectory stripped. Absolute values are made relative to
    *output_dir* when nested under it, otherwise cut at an embedded
    misc/strain marker, otherwise reduced to their basename.
    """
    defaults = (MAP_VAR_DEFAULT_MISC_OUTPUT_DIR, MAP_VAR_DEFAULT_STRAIN_OUTPUT_DIR)
    if not output_dir or not os.path.isabs(value):
        for subdir in defaults:
            if value.startswith(subdir + os.sep):
                return value[len(subdir) + 1 :]
        return value
    if value.startswith(os.path.abspath(output_dir) + os.sep):
        return strip_prefix(value, output_dir)
    for subdir in defaults:
        marker = os.sep + subdir + os.sep
        if marker in value:
            return value.split(marker, 1)[1]
    return os.path.basename(value)


def merge_maps(
    paths: List[str],
    output_dir_override: Optional[str],
    use_sampleid: bool,
    gzip_sam: bool,
) -> None:
    """Merge several map files into one and print the result to stdout.

    All inputs must share the same set of header columns (order may differ;
    the first map's order wins). Row paths are first resolved to absolute
    form using each map's own base directories, then re-relativized against
    the common input/output directories inferred from the merged set.

    Args:
        paths: map files to merge.
        output_dir_override: use this #OUTPUT_DIR instead of inferring one.
        use_sampleid: regenerate PREFIX/SAM/PROFILE values from #SAMPLEID.
        gzip_sam: with use_sampleid, SAM names get '.sam.gz' vs '.sam'.

    Raises:
        ValueError: on header mismatch, missing required columns, or
            duplicate sample ids / file entries across the merged rows.
    """
    maps = [parse_map(p) for p in paths]
    # The first map's column order defines the merged output order.
    header = maps[0].header_cols
    header_set = set(header)
    for mp in maps[1:]:
        if set(mp.header_cols) != header_set:
            raise ValueError("All map files must use the same header columns (order may differ)")

    # Column positions within the merged header.
    out_first_idx = find_col(header, MAP_FIRST_READ)
    out_second_idx = find_col(header, MAP_SECOND_READ)
    out_prefix_idx = find_col(header, MAP_PREFIX)
    out_sam_idx = find_col(header, MAP_SAM)
    out_profile_idx = find_col(header, MAP_PROFILE)
    out_sampleid_idx = find_col(header, MAP_SAMPLEID)

    if out_first_idx is None or out_second_idx is None or out_prefix_idx is None:
        raise ValueError(
            f"Missing required columns {MAP_FIRST_READ}, {MAP_SECOND_READ}, {MAP_PREFIX}"
        )
    if use_sampleid and out_sampleid_idx is None:
        raise ValueError(f"Missing required column {MAP_SAMPLEID} for --use-sampleid")

    # Absolute paths collected across all maps; used below to infer the
    # common #INPUT_DIR / #OUTPUT_DIR of the merged map.
    read_paths: List[str] = []
    prefix_paths: List[str] = []
    sam_paths: List[str] = []
    profile_paths: List[str] = []

    seen_sample_ids: Dict[str, str] = {}
    seen_files: Dict[str, str] = {}

    merged_rows: List[List[str]] = []

    for mp in maps:
        # Per-map base directories used to absolutize that map's values.
        base_dir = input_base(mp)
        out_dir = output_base(mp)
        sam_base = sam_output_base(mp)
        profile_base = profile_output_base(mp)
        fallback_base = map_dir(mp.path)

        # Re-order each source row into the merged header's column order.
        col_index = {name: idx for idx, name in enumerate(mp.header_cols)}
        for row in mp.rows:
            out_row = [""] * len(header)
            for out_idx, col_name in enumerate(header):
                src_idx = col_index.get(col_name)
                if src_idx is not None and src_idx < len(row):
                    out_row[out_idx] = row[src_idx]

            first_abs = resolve_path(base_dir, out_row[out_first_idx], fallback_base)
            second_abs = resolve_path(base_dir, out_row[out_second_idx], fallback_base)
            read_paths.extend([first_abs, second_abs])

            prefix_abs = resolve_path(out_dir, out_row[out_prefix_idx], fallback_base)
            prefix_paths.append(prefix_abs)

            if out_sam_idx is not None and out_sam_idx < len(out_row):
                sam_abs = resolve_path(sam_base, out_row[out_sam_idx], fallback_base)
                sam_paths.append(sam_abs)
            if out_profile_idx is not None and out_profile_idx < len(out_row):
                profile_abs = resolve_path(profile_base, out_row[out_profile_idx], fallback_base)
                profile_paths.append(profile_abs)

            out_row[out_first_idx] = first_abs
            out_row[out_second_idx] = second_abs
            out_row[out_prefix_idx] = prefix_abs
            if out_sam_idx is not None and out_sam_idx < len(out_row):
                out_row[out_sam_idx] = sam_paths[-1] if sam_paths else out_row[out_sam_idx]
            if out_profile_idx is not None and out_profile_idx < len(out_row):
                out_row[out_profile_idx] = profile_paths[-1] if profile_paths else out_row[out_profile_idx]

            merged_rows.append(out_row)

    # Infer shared directories; when the inferred output dir is itself a
    # default subdirectory, step up to its parent output root.
    input_dir = common_path(read_paths)
    output_dir = output_dir_override or common_path(prefix_paths + sam_paths + profile_paths)
    if output_dir:
        for subdir in (
            MAP_VAR_DEFAULT_SAM_OUTPUT_DIR,
            MAP_VAR_DEFAULT_PROFILE_OUTPUT_DIR,
            MAP_VAR_DEFAULT_STRAIN_OUTPUT_DIR,
            MAP_VAR_DEFAULT_MISC_OUTPUT_DIR,
        ):
            if os.path.basename(output_dir) == subdir:
                output_dir = os.path.dirname(output_dir)
                break
    # Re-relativize row paths against the inferred directories.
    if input_dir:
        for row in merged_rows:
            row[out_first_idx] = strip_prefix(row[out_first_idx], input_dir)
            row[out_second_idx] = strip_prefix(row[out_second_idx], input_dir)
    if output_dir:
        for row in merged_rows:
            row[out_prefix_idx] = normalize_prefix(row[out_prefix_idx], output_dir)

    # Emit variable lines for whichever directories could be inferred.
    var_lines: List[Tuple[str, str]] = []
    if input_dir:
        var_lines.append((MAP_VAR_INPUT_DIR, input_dir))
    if output_dir:
        var_lines.append((MAP_VAR_OUTPUT_DIR, output_dir))
        var_lines.append((MAP_VAR_SAM_OUTPUT_DIR, MAP_VAR_DEFAULT_SAM_OUTPUT_DIR))
        var_lines.append((MAP_VAR_PROFILE_OUTPUT_DIR, MAP_VAR_DEFAULT_PROFILE_OUTPUT_DIR))
        var_lines.append((MAP_VAR_STRAIN_OUTPUT_DIR, MAP_VAR_DEFAULT_STRAIN_OUTPUT_DIR))
        var_lines.append((MAP_VAR_MISC_OUTPUT_DIR, MAP_VAR_DEFAULT_MISC_OUTPUT_DIR))

    # With an output dir declared, SAM/PROFILE entries become bare names
    # (their directories come from the map variables).
    if output_dir:
        if out_sam_idx is not None:
            for row in merged_rows:
                if out_sam_idx < len(row):
                    row[out_sam_idx] = os.path.basename(row[out_sam_idx])
        if out_profile_idx is not None:
            for row in merged_rows:
                if out_profile_idx < len(row):
                    row[out_profile_idx] = os.path.basename(row[out_profile_idx])

    # Optionally regenerate output names from the sample id column.
    if use_sampleid:
        for row in merged_rows:
            sample_id = row[out_sampleid_idx]
            row[out_prefix_idx] = sample_id
            if out_sam_idx is not None and out_sam_idx < len(row):
                row[out_sam_idx] = f"{sample_id}.sam.gz" if gzip_sam else f"{sample_id}.sam"
            if out_profile_idx is not None and out_profile_idx < len(row):
                row[out_profile_idx] = f"{sample_id}.profile"

    # Final duplicate checks on the fully rewritten rows.
    for idx, row in enumerate(merged_rows, start=1):
        sample_id = row[out_sampleid_idx] if out_sampleid_idx is not None else ""
        if sample_id:
            if sample_id in seen_sample_ids:
                raise ValueError(
                    f"Duplicate sample id '{sample_id}' in merge (rows {seen_sample_ids[sample_id]} and {idx})"
                )
            seen_sample_ids[sample_id] = str(idx)
        for col_name, col_idx in (
            (MAP_FIRST_READ, out_first_idx),
            (MAP_SECOND_READ, out_second_idx),
            (MAP_SAM, out_sam_idx),
            (MAP_PROFILE, out_profile_idx),
            (MAP_PREFIX, out_prefix_idx),
        ):
            if col_idx is None or col_idx >= len(row):
                continue
            value = row[col_idx]
            if not value:
                continue
            if value in seen_files:
                raise ValueError(
                    f"Duplicate {col_name} entry '{value}' in merge (rows {seen_files[value]} and {idx})"
                )
            seen_files[value] = str(idx)
    merged = MapFile(
        path="",
        vars={k: v for k, v in var_lines},
        var_lines=var_lines,
        header_cols=header,
        rows=merged_rows,
    )
    write_map(merged)


def build_parser() -> argparse.ArgumentParser:
    """Construct the command-line parser for all subcommands."""
    parser = argparse.ArgumentParser(prog="protal_map_utils")
    commands = parser.add_subparsers(dest="command", required=True)

    cmd = commands.add_parser("validate")
    cmd.add_argument("--map", required=True)

    cmd = commands.add_parser("flatten")
    cmd.add_argument("--map", required=True)
    cmd.add_argument("--out")

    cmd = commands.add_parser("generate")
    cmd.add_argument("--input", required=True, nargs="+")
    cmd.add_argument("--out", required=True)
    cmd.add_argument("--nogzip", action="store_true")

    cmd = commands.add_parser("merge")
    cmd.add_argument("--map", required=True, action="append", nargs="+")
    cmd.add_argument("--out")
    cmd.add_argument("--use-sampleid", action="store_true")
    cmd.add_argument("--nogzip", action="store_true")

    return parser


def main() -> int:
    """Command-line entry point; returns the process exit code."""
    args = build_parser().parse_args()

    try:
        command = args.command
        if command == "validate":
            return validate_map(args.map)
        if command == "flatten":
            flatten_map(args.map, args.out)
        elif command == "generate":
            generate_map(args.input, args.out, not args.nogzip)
        elif command == "merge":
            # --map is append+nargs, so flatten the list-of-lists first.
            flat = [item for group in args.map for item in group]
            if len(flat) < 2:
                raise ValueError("merge requires at least two --map arguments")
            merge_maps(flat, args.out, args.use_sampleid, not args.nogzip)
    except Exception as exc:  # top-level boundary: report and exit non-zero
        sys.stderr.write(f"Error: {exc}\n")
        return 2

    return 0


# Script entry point: exit with the code returned by main().
if __name__ == "__main__":
    raise SystemExit(main())
