#!/usr/bin/env python3
from __future__ import annotations

import argparse
import glob
import os
import re
import sys
from dataclasses import dataclass
from decimal import Decimal, InvalidOperation
from typing import Dict, Iterable, List, Tuple


@dataclass
class Profile:
    """One taxonomic profile parsed from a single input file."""

    # Path the profile was read from (kept for error reporting).
    path: str
    # Sample name used as the column header in the merged output.
    sample: str
    # Taxonomy lineage -> summed abundance; Decimal preserves the exact
    # textual values from the input file.
    abundances: Dict[str, Decimal]


def expand_inputs(inputs: Iterable[str]) -> List[str]:
    """Expand glob patterns into a sorted, de-duplicated list of paths.

    A pattern that matches nothing is kept verbatim so that a missing
    file surfaces as an open() error later instead of vanishing silently.
    """
    collected: List[str] = []
    for pattern in inputs:
        matches = glob.glob(pattern)
        collected.extend(matches if matches else [pattern])
    return sorted(set(collected))


def sample_name(path: str) -> str:
    """Derive a sample name from *path* by stripping known extensions.

    Recognised profile extensions (including their .gz variants) are
    removed whole; any other filename just loses its final extension.
    """
    known_suffixes = (".profile", ".tsv", ".txt", ".profile.gz", ".tsv.gz", ".txt.gz")
    base = os.path.basename(path)
    for suffix in known_suffixes:
        if base.endswith(suffix):
            return base[: -len(suffix)]
    return os.path.splitext(base)[0]


def read_profile(path: str, sample: str | None = None) -> Profile:
    """Parse one tab-delimited profile file into a Profile.

    Every non-empty line must carry at least three tab-separated columns;
    column 2 is the taxonomy lineage, column 3 the abundance (column 1 is
    ignored). Repeated lineages have their abundances summed.

    Raises:
        ValueError: on malformed lines, empty lineages, or non-numeric
            abundances, with the offending path and line number.
    """
    totals: Dict[str, Decimal] = {}
    with open(path, "r", encoding="utf-8") as stream:
        for lineno, raw_line in enumerate(stream, start=1):
            record = raw_line.rstrip("\n")
            if not record:
                continue
            fields = record.split("\t")
            if len(fields) < 3:
                raise ValueError(f"{path}:{lineno}: expected 3 tab-delimited columns")
            lineage = fields[1].strip()
            value_text = fields[2].strip()
            if not lineage:
                raise ValueError(f"{path}:{lineno}: empty taxonomy lineage")
            try:
                value = Decimal(value_text)
            except InvalidOperation as exc:
                raise ValueError(f"{path}:{lineno}: invalid abundance '{value_text}'") from exc
            totals[lineage] = totals.get(lineage, Decimal(0)) + value
    resolved_sample = sample if sample is not None else sample_name(path)
    return Profile(path=path, sample=resolved_sample, abundances=totals)


def find_duplicates(values: List[str]) -> Dict[str, List[int]]:
    """Map each value occurring more than once to all of its indices."""
    positions: Dict[str, List[int]] = {}
    for position, value in enumerate(values):
        if value in positions:
            positions[value].append(position)
        else:
            positions[value] = [position]
    return {value: where for value, where in positions.items() if len(where) > 1}


def shorten_stems(stems: List[str]) -> List[str]:
    """Shorten filename stems to minimal distinguishing token prefixes.

    Stems are split on runs of ``.``, ``_`` and ``-``; a leading token run
    shared by every stem is dropped, then each stem is represented by the
    shortest leading run of its remaining tokens that is unique across all
    stems. Stems that cannot be disambiguated are returned still colliding
    (the caller is expected to detect remaining duplicates).
    """
    tokens_list = [re.split(r"[._-]+", stem) for stem in stems]
    # Collect the run of leading tokens shared by every stem.
    common_prefix: List[str] = []
    if tokens_list:
        shortest = min(len(tokens) for tokens in tokens_list)
        for idx in range(shortest):
            token = tokens_list[0][idx]
            if token and all(tokens[idx] == token for tokens in tokens_list):
                common_prefix.append(token)
            else:
                break
    if common_prefix:
        trimmed = []
        for tokens in tokens_list:
            remaining = tokens[len(common_prefix) :]
            # A stem made up entirely of the common prefix keeps its full
            # token list so it does not collapse to an empty name.
            trimmed.append(remaining if remaining else tokens)
        tokens_list = trimmed
    # lengths[i] = number of leading tokens of stem i currently in use.
    lengths = [1 for _ in stems]

    def candidate(idx: int) -> str:
        # Current shortened name for stem idx; falls back to the original
        # stem when tokenisation produced nothing usable.
        tokens = tokens_list[idx]
        if not tokens or tokens == [""]:
            return stems[idx]
        length = min(lengths[idx], len(tokens))
        return "_".join(tokens[:length])

    # Grow the token count of each duplicated stem until all candidates
    # are unique, or until no duplicated stem has tokens left to add.
    while True:
        candidates = [candidate(idx) for idx in range(len(stems))]
        duplicates = find_duplicates(candidates)
        if not duplicates:
            return candidates
        progressed = False
        for idxs in duplicates.values():
            for idx in idxs:
                if lengths[idx] < len(tokens_list[idx]):
                    lengths[idx] += 1
                    progressed = True
        if not progressed:
            return candidates


def resolve_sample_names(paths: List[str]) -> List[str]:
    """Derive unique, human-friendly sample names for *paths*.

    Filename stems are first shortened via shorten_stems(); names that
    still collide are prefixed with as many trailing directory components
    as needed. When any path required directory context, every name is
    normalised to carry at least one directory component so the output
    columns look consistent.

    Raises:
        ValueError: if no unique naming can be derived at all.
    """
    stems = [sample_name(path) for path in paths]
    short_stems = shorten_stems(stems)
    # Directory components of each path (absolute, filename dropped).
    dir_parts = [os.path.abspath(path).split(os.sep)[:-1] for path in paths]
    depths = [0 for _ in paths]

    def build_candidates(current_depths: List[int]) -> List[str]:
        # Join the last `depth` directory parts onto each shortened stem.
        candidates: List[str] = []
        for idx, stem in enumerate(short_stems):
            depth = current_depths[idx]
            if depth == 0:
                candidates.append(stem)
            else:
                parts = dir_parts[idx][-depth:] + [stem]
                candidates.append("_".join(parts))
        return candidates

    def disambiguate(start_depths: List[int]) -> Tuple[List[str], List[int]]:
        # Deepen each duplicated name until unique or out of directory parts.
        current_depths = list(start_depths)
        while True:
            candidates = build_candidates(current_depths)
            duplicates = find_duplicates(candidates)
            if not duplicates:
                return candidates, current_depths
            progressed = False
            for idxs in duplicates.values():
                for idx in idxs:
                    if current_depths[idx] < len(dir_parts[idx]):
                        current_depths[idx] += 1
                        progressed = True
            if not progressed:
                return candidates, current_depths

    candidates, depths = disambiguate(depths)
    duplicates = find_duplicates(candidates)
    if not duplicates:
        if not any(depth > 0 for depth in depths):
            return candidates
        # Some names needed directory context: normalise so every name
        # carries at least one directory component.
        min_depths = [min(len(dir_parts[idx]), max(depths[idx], 1)) for idx in range(len(paths))]
        normalized, _ = disambiguate(min_depths)
        if not find_duplicates(normalized):
            return normalized
        # Normalisation reintroduced collisions (e.g. "a_b" from dir "a" +
        # stem "b" clashing with dir "a_b"); fall back to the unique names
        # already found instead of failing outright.
        return candidates

    detail = "; ".join(
        f"{name}: {', '.join(paths[idx] for idx in idxs)}" for name, idxs in duplicates.items()
    )
    raise ValueError(
        "Unable to derive unique sample names from paths. "
        f"Duplicates: {detail}"
    )


def ensure_unique_samples(samples: List[str], paths: List[str], hint: str | None = None) -> None:
    """Raise ValueError when any sample name appears more than once.

    The error lists every duplicated name together with the paths that
    produced it; *hint* is appended to the message when provided.
    """
    clashes = find_duplicates(samples)
    if not clashes:
        return
    described: List[str] = []
    for name, idxs in clashes.items():
        offenders = ", ".join(paths[idx] for idx in idxs)
        described.append(f"{name}: {offenders}")
    detail = "; ".join(described)
    message = f"Duplicate sample names detected: {detail}"
    if hint:
        message = f"{message}. {hint}"
    raise ValueError(message)


def merge_profiles(inputs: Iterable[str], resolve_samples: bool = False) -> None:
    """Merge the given profiles into one TSV table written to stdout.

    The table has one row per taxonomy lineage (sorted) and one column
    per sample; missing entries are emitted as 0. Sample names come from
    resolve_sample_names() when *resolve_samples* is true, otherwise from
    each filename.

    Raises:
        ValueError: when no inputs remain after expansion, or sample
            names are not unique.
    """
    paths = expand_inputs(inputs)
    if not paths:
        raise ValueError("No input profiles provided")

    if resolve_samples:
        samples = resolve_sample_names(paths)
    else:
        samples = [sample_name(path) for path in paths]
    ensure_unique_samples(samples, paths, hint="Use --resolve-samples to disambiguate.")

    profiles = [read_profile(path, name) for path, name in zip(paths, samples)]
    all_taxa: set = set()
    for profile in profiles:
        all_taxa.update(profile.abundances)

    out = sys.stdout
    out.write("\t".join(["taxon"] + [profile.sample for profile in profiles]) + "\n")

    zero = Decimal(0)
    for taxon in sorted(all_taxa):
        cells = [taxon] + [str(profile.abundances.get(taxon, zero)) for profile in profiles]
        out.write("\t".join(cells) + "\n")


def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser with its `merge` subcommand."""
    parser = argparse.ArgumentParser(prog="protal_profile_utils")
    subcommands = parser.add_subparsers(dest="command", required=True)

    merge_cmd = subcommands.add_parser("merge")
    merge_cmd.add_argument("--input", required=True, nargs="+")
    merge_cmd.add_argument(
        "--resolve-samples",
        action="store_true",
        help="Try to disambiguate sample names by shortening filenames and adding path segments.",
    )

    return parser


def main() -> int:
    """CLI entry point; returns a process exit code (0 on success, 2 on error)."""
    args = build_parser().parse_args()

    try:
        if args.command == "merge":
            merge_profiles(args.input, resolve_samples=args.resolve_samples)
    except Exception as exc:  # top-level boundary: report and exit non-zero
        sys.stderr.write(f"Error: {exc}\n")
        return 2
    return 0


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit code.
    raise SystemExit(main())
