import pysam
import pandas as pd
import argparse
import os
import matplotlib.pyplot as plt
import seaborn as sns
import re # Import regular expression module
import sys # Import sys for exiting

# --- Helper Functions ---

def read_sam(file):
    """Read a BAM/SAM file and return the element each read is assigned to.

    Each primary alignment contributes one row. Secondary and supplementary
    records are skipped: they are extra placements of a read that already has
    a primary record, so counting them would count that read more than once.

    Args:
        file: Path to a BAM or SAM file.

    Returns:
        A DataFrame with a single ``element`` column, one row per read: the
        reference name with any ``"_Swap"`` substring removed for mapped
        reads, or ``"unmapped"`` for unmapped reads. The column exists even
        when the file holds no reads.

    Exits the process with status 1 if the file cannot be opened or read.
    """
    try:
        samfile = pysam.AlignmentFile(file, "rb")
    except FileNotFoundError:
        print(f"Error: BAM/SAM file not found at '{file}'", file=sys.stderr)
        sys.exit(1)
    except (ValueError, OSError) as e:
        print(f"Error opening BAM/SAM file '{file}': {e}", file=sys.stderr)
        sys.exit(1)

    elements = []
    n_unmapped = 0
    n_skipped = 0
    # The context manager closes the file on every exit path, including sys.exit().
    with samfile:
        try:
            for alignment in samfile.fetch(until_eof=True):
                if alignment.is_secondary or alignment.is_supplementary:
                    n_skipped += 1
                    continue
                # Check the FLAG, not just reference_name: an unmapped read whose
                # mate is mapped is placed at the mate's reference (SAM spec
                # recommended practice) and so has a reference_name although it
                # is unmapped.
                if alignment.is_unmapped or alignment.reference_name is None:
                    n_unmapped += 1
                    elements.append("unmapped")
                else:
                    elements.append(alignment.reference_name.replace("_Swap", ""))
        except Exception as e:
            print(f"Error reading records from '{file}': {e}", file=sys.stderr)
            sys.exit(1)

    n_total = len(elements)
    unmapped_perc = (n_unmapped / n_total * 100) if n_total > 0 else 0
    print(f"Processed {n_total} reads from BAM/SAM.")
    print(f"  {n_unmapped} unmapped ({unmapped_perc:.2f}%).")
    if n_skipped:
        print(f"  {n_skipped} secondary/supplementary alignments skipped.")

    # Build from a dict so the "element" column exists even for an empty file.
    return pd.DataFrame({"element": elements})

def read_fastq_and_extract(file, pattern):
    """Read a FASTQ file and extract one barcode per read using a pattern.

    Args:
        file: Path to a FASTQ file (opened with ``pysam.FastxFile``).
        pattern: Read layout containing exactly one ``{BC}`` placeholder,
            e.g. ``"PREFIX{BC}SUFFIX"``. The text around the placeholder is
            matched literally; the placeholder matches one or more of the
            upper-case characters ``A``, ``C``, ``G``, ``T``, ``N``.

    Returns:
        A DataFrame with a single ``element`` column, one row per read: the
        captured barcode, or ``"unmatched"`` if the pattern is not found in
        the read. The column exists even when the file holds no reads.

    Exits the process with status 1 on an invalid pattern or if the file
    cannot be opened or read.
    """
    n_placeholders = pattern.count("{BC}")
    if n_placeholders == 0:
        print(f"Error: Match pattern '{pattern}' must contain '{{BC}}' placeholder.", file=sys.stderr)
        sys.exit(1)
    if n_placeholders > 1:
        print(f"Error: Match pattern '{pattern}' must contain exactly one '{{BC}}' placeholder.", file=sys.stderr)
        sys.exit(1)

    prefix, suffix = pattern.split("{BC}")
    # re.escape turns both flanks into literals, so the compiled pattern is
    # always valid. The capture is greedy: with an empty suffix it takes every
    # remaining base, and if the suffix occurs more than once the barcode runs
    # to the last occurrence reachable through A/C/G/T/N characters.
    barcode_regex = re.compile(re.escape(prefix) + r"([ACGTN]+)" + re.escape(suffix))

    try:
        fastqfile = pysam.FastxFile(file)
    except FileNotFoundError:
        print(f"Error: FASTQ file not found at '{file}'", file=sys.stderr)
        sys.exit(1)
    except Exception as e:
        print(f"Error opening FASTQ file '{file}': {e}", file=sys.stderr)
        sys.exit(1)

    barcodes = []
    n_matched = 0
    # The context manager closes the file on every exit path, including sys.exit().
    with fastqfile:
        try:
            for record in fastqfile:
                # `or ""` guards against a record whose sequence is None; such a
                # read is then reported as unmatched.
                match = barcode_regex.search(record.sequence or "")
                if match:
                    barcodes.append(match.group(1))
                    n_matched += 1
                else:
                    barcodes.append("unmatched")
        except Exception as e:
            print(f"Error reading records from '{file}': {e}", file=sys.stderr)
            sys.exit(1)

    n_total = len(barcodes)
    n_unmatched = n_total - n_matched
    # Derive both percentages from counts so an empty file reports 0%, not 100% unmatched.
    matched_perc = (n_matched / n_total * 100) if n_total > 0 else 0
    unmatched_perc = (n_unmatched / n_total * 100) if n_total > 0 else 0
    print(f"Processed {n_total} reads from FASTQ.")
    print(f"  {n_matched} matched pattern ({matched_perc:.2f}%).")
    print(f"  {n_unmatched} unmatched ({unmatched_perc:.2f}%).")

    # Build from a dict so the "element" column exists even for an empty file.
    return pd.DataFrame({"element": barcodes})


def _build_parser():
    """Return the argument parser for the command-line interface."""
    parser = argparse.ArgumentParser(
        description="Plot library representation from aligned BAM or unaligned FASTQ.",
        # Raw formatting keeps the explicit line break in the --match help text.
        formatter_class=argparse.RawTextHelpFormatter,
    )

    # Exactly one input source must be given.
    input_group = parser.add_mutually_exclusive_group(required=True)
    input_group.add_argument("--bam", help="Path to aligned BAM file.")
    input_group.add_argument("--fastq", help="Path to unaligned FASTQ file.")

    # Only meaningful with --fastq; the requirement is enforced in _load_input().
    parser.add_argument("--match",
                        help="Barcode pattern for FASTQ (e.g., 'PREFIX{BC}SUFFIX').\n"
                             "Required if --fastq is used. {BC} marks the barcode.")

    parser.add_argument("name", help="Base name for the output plot file (e.g., 'MyLibrary').")
    return parser


def _load_input(args, parser):
    """Validate the selected input and return its per-read ``element`` DataFrame.

    Returns None if neither --fastq nor --bam is set (the required argparse
    group should prevent this). Exits with status 1 if the input file does not
    exist, and through ``parser.error`` if --fastq is given without --match.
    """
    if args.fastq is not None:
        if not args.match:
            parser.error("--match pattern is required when using --fastq.")
        if not os.path.isfile(args.fastq):
            print(f"Error: FASTQ file '{args.fastq}' does not exist.", file=sys.stderr)
            sys.exit(1)
        print(f"Reading FASTQ: {args.fastq}")
        print(f"Using pattern: {args.match}")
        return read_fastq_and_extract(args.fastq, args.match)

    if args.bam is not None:
        if args.match:
            # Not an error, but tell the user their pattern had no effect.
            print("Warning: --match is ignored when reading a BAM file.", file=sys.stderr)
        if not os.path.isfile(args.bam):
            print(f"Error: BAM file '{args.bam}' does not exist.", file=sys.stderr)
            sys.exit(1)
        print(f"Reading BAM: {args.bam}")
        return read_sam(args.bam)

    return None


def _plot_representation(count, name):
    """Build the two-panel representation figure from per-element read counts.

    Args:
        count: DataFrame with ``element`` and ``n_reads`` columns, one row per
            element. It is not modified.
        name: Library name shown in the figure title.

    Returns:
        The matplotlib Figure; the caller is responsible for saving and closing it.
    """
    num_unique_elements = len(count)
    total_valid_reads = count["n_reads"].sum()

    fig, axes = plt.subplots(1, 2, figsize=(8, 4))
    fig.suptitle(f"{name} Library Representation ({num_unique_elements} elements)")

    # Left panel: histogram of reads per element (at most 50 bins, fewer for tiny libraries).
    sns.histplot(data=count, x="n_reads", bins=min(50, num_unique_elements), ax=axes[0])
    axes[0].set_title("Reads per Element")
    axes[0].set_xlabel("Number of Reads")
    axes[0].set_ylabel("Number of Elements")
    axes[0].set_yscale("log")  # Often useful for library representation
    axes[0].grid(axis="y", linestyle="--", alpha=0.7)

    # Right panel: cumulative fraction of reads covered by the top-ranked elements.
    # assign() returns a copy, so the caller's DataFrame is left unchanged.
    ranked = count.assign(frac=count["n_reads"] / total_valid_reads)
    ranked = ranked.sort_values("frac", ascending=False)
    ranked["cum_frac"] = ranked["frac"].cumsum()
    ranked["rank"] = range(1, len(ranked) + 1)

    sns.lineplot(data=ranked, x="rank", y="cum_frac", ax=axes[1])
    axes[1].set_title("Cumulative Read Fraction")
    axes[1].set_xlabel("Element Rank")
    axes[1].set_ylabel("Cumulative Fraction of Reads")
    axes[1].set_ylim(0, 1.05)
    axes[1].set_xlim(0, num_unique_elements + 1)
    axes[1].grid(True, linestyle="--", alpha=0.7)

    # Leave room at the top so the suptitle does not overlap the panels.
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])
    return fig


def main():
    """Command-line entry point: read input, count reads per element, save the plot.

    Writes ``<name>_representation.png`` (relative to the current directory
    unless *name* contains a path). Exit status is 0 on success or when
    there is nothing to plot. It is 1 on input or output errors. Argparse
    usage errors exit with argparse's own status.
    """
    parser = _build_parser()
    args = parser.parse_args()

    data = _load_input(args, parser)
    if data is None or data.empty:
        print("No data read from input file. Exiting.", file=sys.stderr)
        sys.exit(1)

    # Drop the placeholder labels the readers assign to unusable reads.
    valid_elements = data[~data["element"].isin(["unmapped", "unmatched"])]
    if valid_elements.empty:
        print("No valid elements (mapped reads or matched barcodes) found. Cannot generate plot.", file=sys.stderr)
        sys.exit(0)  # Exit cleanly, just nothing to plot

    count = valid_elements.groupby("element").size().reset_index(name="n_reads")
    total_valid_reads = count["n_reads"].sum()
    num_unique_elements = len(count)
    print(f"Found {num_unique_elements} unique elements with {total_valid_reads} total reads.")

    fig = _plot_representation(count, args.name)

    output_filename = f"{args.name}_representation.png"
    try:
        fig.savefig(output_filename, dpi=300, bbox_inches="tight")
    except (OSError, ValueError) as e:
        print(f"Error saving plot to '{output_filename}': {e}", file=sys.stderr)
        # The requested output was not produced, so report failure via the exit status.
        sys.exit(1)
    finally:
        plt.close(fig)
    print(f"Plot saved to {output_filename}")

# Run the command-line interface only when executed as a script, not when imported.
if __name__ == "__main__":
    main()

