#!/usr/bin/env python3
"""Shared BibTeX parsing, serialization, and formatting utilities.\n
Provides a simple regex-based BibTeX parser, entry serializer,
merge/deduplicate logic, and formatters that convert entries to markdown
and HTML reference lists.  Also includes LaTeX accent cleanup for author
names.
"""

import html
import os
import re
import sys

from arxiv_lib import arxiv_id_sort_key, wrap_in_html_document

# Characters that escape_markdown_chars() backslash-escapes.  Each character
# is captured in group 1 so the substitution can re-emit it after a "\".
_MD_SPECIAL = re.compile(r"([\\`*_{}\[\]()#+\-.!|~>=<^])")

def escape_html_entities(text: str) -> str:
    """Make *text* safe to embed in HTML element content or attributes.\n
    Delegates to :func:`html.escape` with quote escaping enabled, so
    ``&``, ``<``, ``>``, ``"`` and ``'`` become character references.\n
    Args:
        text: Unescaped text.\n
    Returns:
        The escaped string.
    """
    escaped = html.escape(text, quote=True)
    return escaped

def escape_markdown_chars(text: str) -> str:
    """Backslash-escape every markdown-significant character in *text*.\n
    The escaped set is ``\\ ` * _ { } [ ] ( ) # + - . ! | ~ > = < ^``;
    each occurrence is prefixed with a single backslash so it renders
    literally.\n
    Args:
        text: Unescaped text.\n
    Returns:
        The markdown-safe string.
    """
    return _MD_SPECIAL.sub(lambda hit: "\\" + hit.group(1), text)

# ---------------------------------------------------------------------------
# BibTeX parser
# ---------------------------------------------------------------------------

def _opening_brace_closes_last(s: str) -> bool:
    """Return False if the ``{`` at ``s[0]`` is closed before ``s[-1]``."""
    depth = 0
    for ch in s[:-1]:
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                # s[0]'s group ended early: s[0] and s[-1] are not a pair.
                return False
    return True

def strip_outer_braces(s: str) -> str:
    """Remove enclosing brace pairs from a BibTeX value.\n
    BibTeX values are often wrapped in ``{...}``, sometimes several times
    (``{{Title}}``).  After trimming whitespace, this repeatedly removes a
    leading ``{`` and trailing ``}``, but only when they form a
    *matching* pair.  A value such as ``{GPU} computing with {CUDA}``
    starts with ``{`` and ends with ``}``, but those braces belong to
    different groups, so it is returned unchanged.\n
    Args:
        s: A BibTeX field value string.\n
    Returns:
        The value with enclosing brace pairs and surrounding whitespace
        removed.
    """
    s = s.strip()
    while s.startswith("{") and s.endswith("}") and _opening_brace_closes_last(s):
        s = s[1:-1].strip()
    return s

# Entry types that are BibTeX directives rather than citable references.
_BIBTEX_DIRECTIVE_TYPES = frozenset({"comment", "preamble", "string"})
# The "name =" prefix of a field.  Hyphens/dots/colons appear in
# tool-specific field names such as ``bdsk-url-1``.
_BIBTEX_FIELD_NAME_RE = re.compile(r"([\w\-.:]+)\s*=\s*")
# An undelimited value: a number (``2020``) or a macro name (``jan``).
_BIBTEX_BARE_VALUE_RE = re.compile(r"[^\s,{}\"#]+")

def _find_closing_brace(text: str, start: int) -> int:
    """Return the index of the ``}`` matching the ``{`` at *start*, or -1."""
    depth = 0
    for i in range(start, len(text)):
        ch = text[i]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return i
    return -1

def _read_bibtex_value(text: str, pos: int) -> tuple[str | None, int]:
    """Read one field value at ``text[pos]``; return ``(value, end)`` or ``(None, pos)``."""
    if pos >= len(text):
        return None, pos
    opener = text[pos]
    if opener == "{":
        close = _find_closing_brace(text, pos)
        if close < 0:
            return None, pos
        return text[pos + 1 : close], close + 1
    if opener == '"':
        # A quote ends the value only outside braces, so {\"o} is kept intact.
        depth = 0
        for i in range(pos + 1, len(text)):
            ch = text[i]
            if ch == "{":
                depth += 1
            elif ch == "}" and depth > 0:
                depth -= 1
            elif ch == '"' and depth == 0:
                return text[pos + 1 : i], i + 1
        return None, pos
    bare = _BIBTEX_BARE_VALUE_RE.match(text, pos)
    if bare is None:
        return None, pos
    return bare.group(0), bare.end()

def parse_bibtex_to_entries(bib: str) -> list[dict]:
    """Parse a BibTeX string into a list of entry dicts.\n
    Splits on ``@`` at the start of a line, then extracts the entry type,
    citation key, and all ``name = value`` fields from each entry.  A value
    may take three forms:\n
    - delimited by ``{...}``, with braces nested to any depth;
    - delimited by ``"..."``, which may contain braced groups such as
      ``{\\"o}``;
    - undelimited, i.e. a number or macro name such as ``year = 2020`` or
      ``month = jan``.  It is kept as the literal text; macros are not
      expanded.\n
    ``@comment``, ``@preamble`` and ``@string`` directives are skipped.
    Text after an entry's closing brace is ignored.  String concatenation
    with ``#`` is not supported: only the first piece is kept.\n
    Args:
        bib: Raw BibTeX text (may contain multiple entries).\n
    Returns:
        A list of dicts, one per entry.  Each dict contains:\n
        - ``_type`` (str): Entry type (e.g., ``"article"``).
        - ``_key`` (str): Citation key.
        - All other keys are lowercase field names with brace-stripped
          values.
    """
    entries: list[dict] = []
    for chunk in re.split(r"(?m)^@", bib):
        chunk = chunk.strip()
        if not chunk:
            continue
        header = re.match(r"(\w+)\s*\{", chunk)
        if not header:
            continue
        entry_type = header.group(1).lower()
        if entry_type in _BIBTEX_DIRECTIVE_TYPES:
            continue
        # Parse only inside the entry's own braces.  Free text after it is a
        # BibTeX comment and must not contribute fields.  If the entry is
        # truncated (no closing brace), fall back to the rest of the chunk.
        close = _find_closing_brace(chunk, header.end() - 1)
        content = chunk[header.end() : close] if close >= 0 else chunk[header.end() :]
        cite_key, comma, body = content.partition(",")
        if not comma:
            continue
        fields: dict[str, str] = {"_type": entry_type, "_key": cite_key.strip()}
        pos = 0
        while (name := _BIBTEX_FIELD_NAME_RE.search(body, pos)) is not None:
            value, end = _read_bibtex_value(body, name.end())
            if value is None:
                # Malformed value: skip past "name =" and keep scanning.
                pos = name.end()
                continue
            fields[name.group(1).lower()] = strip_outer_braces(value)
            pos = end
        entries.append(fields)
    return entries

# ---------------------------------------------------------------------------
# BibTeX serializer
# ---------------------------------------------------------------------------

def entry_to_bibtex_string(entry: dict) -> str:
    """Serialize a single entry dict back to a BibTeX string.\n
    Produces a standard BibTeX entry with the type and key on the first
    line, each field on its own line indented with two spaces, and a
    closing ``}``.  Keys starting with ``_`` (such as ``_type`` and
    ``_key``) are metadata and are not written as fields.  Every value is
    wrapped in ``{...}``, so values are expected to have balanced braces.\n
    Args:
        entry: An entry dict with ``_type``, ``_key``, and field keys.\n
    Returns:
        A BibTeX entry string ending with ``}``.  It has no trailing
        newline; :func:`entries_to_bibtex_string` adds the separators.\n
    Raises:
        KeyError: If ``_type`` or ``_key`` is missing.
    """
    entry_type = entry["_type"]
    cite_key = entry["_key"]
    lines = [f"@{entry_type}{{{cite_key},"]
    lines.extend(
        f"  {key} = {{{val}}},"
        for key, val in entry.items()
        if not key.startswith("_")
    )
    lines.append("}")
    return "\n".join(lines)

def entries_to_bibtex_string(entries: list[dict]) -> str:
    """Serialize a list of entry dicts to a BibTeX string.\n
    Entries are separated by blank lines and a single trailing newline is
    appended, so an empty *entries* list yields ``"\\n"``.\n
    Args:
        entries: List of entry dicts (as returned by
            :func:`parse_bibtex_to_entries`).\n
    Returns:
        Complete BibTeX text with all entries.
    """
    return "\n\n".join(map(entry_to_bibtex_string, entries)) + "\n"

# ---------------------------------------------------------------------------
# Merge / deduplicate
# ---------------------------------------------------------------------------

# Month abbreviations accepted in the BibTeX ``month`` field: the first three
# letters of the English month name, matched case-insensitively.
_BIBTEX_MONTHS = {
    name: number
    for number, name in enumerate(
        ("jan", "feb", "mar", "apr", "may", "jun",
         "jul", "aug", "sep", "oct", "nov", "dec"),
        start=1,
    )
}

def _parse_bibtex_month(raw: str) -> int:
    """Return the month number (1-12) encoded in *raw*, or 0 if unrecognized."""
    text = strip_outer_braces(raw).lower()
    if text.isdecimal():
        number = int(text)
        return number if 1 <= number <= 12 else 0
    # "jan", "January", "Sept." all map via their first three letters.
    return _BIBTEX_MONTHS.get(text[:3], 0)

def extract_arxiv_sort_key(entry: dict) -> tuple[int, int, int]:
    """Extract a date-based sort key from a BibTeX entry.\n
    Tries the ``eprint`` field first (via :func:`arxiv_id_sort_key`).  If
    that yields nothing, it falls back to the leading digits of ``year``
    (so ``2020a`` sorts as 2020).  In the fallback, ``month`` supplies the
    second component: a number, ``jan``, ``January``, etc.  An unknown
    month counts as 0.  Returns ``(0, 0, 0)`` for entries with no
    parseable date, which sort as oldest.\n
    Args:
        entry: A parsed BibTeX entry dict.\n
    Returns:
        A ``(year, month, idx)`` tuple for sorting (most recent first).
    """
    eprint = strip_outer_braces(entry.get("eprint", ""))
    if eprint:
        key = arxiv_id_sort_key(eprint)
        if key != (0, 0, 0):
            return key
    year_match = re.match(r"\d+", strip_outer_braces(entry.get("year", "")))
    if year_match is None:
        return (0, 0, 0)
    return (int(year_match.group(0)), _parse_bibtex_month(entry.get("month", "")), 0)

def merge_bibtex_strings(*bib_texts: str) -> list[dict]:
    """Combine several BibTeX strings into one de-duplicated entry list.\n
    Inputs are parsed in argument order and citation keys are compared
    exactly.  The first entry seen for a key is kept; later entries with
    the same key are discarded.  The survivors are ordered newest-first by
    :func:`extract_arxiv_sort_key`, and ties keep their input order.\n
    Args:
        *bib_texts: Raw BibTeX strings to parse and merge.\n
    Returns:
        The de-duplicated entry dicts, most recent first.
    """
    first_by_key: dict[str, dict] = {}
    for text in bib_texts:
        for parsed in parse_bibtex_to_entries(text):
            first_by_key.setdefault(parsed["_key"], parsed)
    return sorted(first_by_key.values(), key=extract_arxiv_sort_key, reverse=True)

# ---------------------------------------------------------------------------
# Author formatting
# ---------------------------------------------------------------------------

# LaTeX symbol accents (``\"u``, ``\'e``, optionally ``\"{u}``) mapped to
# the Unicode combining mark each one adds.
_LATEX_SYMBOL_ACCENTS = {
    '"': "\u0308",  # diaeresis
    "'": "\u0301",  # acute
    "`": "\u0300",  # grave
    "^": "\u0302",  # circumflex
    "~": "\u0303",  # tilde
    "=": "\u0304",  # macron
    ".": "\u0307",  # dot above
}
# LaTeX letter accents (``\c{c}``, ``\v{s}``, ...).  Braces are required so
# that ordinary control words such as ``\url`` are not mistaken for them.
_LATEX_LETTER_ACCENTS = {
    "c": "\u0327",  # cedilla
    "v": "\u030c",  # caron
    "u": "\u0306",  # breve
    "H": "\u030b",  # double acute
    "k": "\u0328",  # ogonek
    "r": "\u030a",  # ring above
}
_LATEX_ACCENT_RE = re.compile(
    r"\\([\"'`^~=.])\{?(\\i|[A-Za-z])\}?"
    r"|\\([cvuHkr])\s*\{\s*(\\i|[A-Za-z])\s*\}"
)
# Stand-alone special letters; the lookahead keeps e.g. ``\ldots`` intact.
_LATEX_SPECIAL_LETTERS = {
    "aa": "å", "AA": "Å", "ae": "æ", "AE": "Æ", "oe": "œ", "OE": "Œ",
    "o": "ø", "O": "Ø", "l": "ł", "L": "Ł",
}
_LATEX_SPECIAL_LETTER_RE = re.compile(r"\\(aa|AA|ae|AE|oe|OE|o|O|l|L)(?![A-Za-z])")

def _split_outside_braces(text: str, separator: str) -> list[str]:
    """Split *text* at matches of regex *separator* that are not inside braces."""
    pieces: list[str] = []
    depth = start = 0
    for match in re.finditer(r"[{}]|" + separator, text):
        token = match.group(0)
        if token == "{":
            depth += 1
        elif token == "}":
            depth = max(depth - 1, 0)
        elif depth == 0:
            pieces.append(text[start : match.start()].strip())
            start = match.end()
    pieces.append(text[start:].strip())
    return pieces

def _reorder_bibtex_name(name: str) -> str:
    """Turn ``Last, First`` / ``Last, Jr, First`` into ``First Last [Jr]``."""
    parts = _split_outside_braces(name, ",")
    if len(parts) == 1:
        return parts[0]
    last, rest = parts[0], parts[1:]
    if len(rest) == 1:
        first, suffix = rest[0], ""
    else:
        suffix, first = rest[0], ", ".join(rest[1:])
    return " ".join(p for p in (first, last, suffix) if p)

def _latex_to_unicode(text: str) -> str:
    """Replace common LaTeX accent, special-letter and tie commands with Unicode."""
    import unicodedata

    def _compose(match: re.Match) -> str:
        if match.group(1) is not None:
            mark, base = _LATEX_SYMBOL_ACCENTS[match.group(1)], match.group(2)
        else:
            mark, base = _LATEX_LETTER_ACCENTS[match.group(3)], match.group(4)
        if base == "\\i":  # dotless i, used under accents: \'\i -> í
            base = "i"
        return unicodedata.normalize("NFC", base + mark)

    text = text.replace(r"{\ss}", "ß").replace(r"\ss", "ß")
    text = _LATEX_ACCENT_RE.sub(_compose, text)
    text = _LATEX_SPECIAL_LETTER_RE.sub(lambda m: _LATEX_SPECIAL_LETTERS[m.group(1)], text)
    # An unescaped "~" is LaTeX's non-breaking space (e.g. "D.~E. Knuth").
    return re.sub(r"(?<!\\)~", " ", text)

def format_author_list(raw: str) -> str:
    r"""Format a BibTeX author string into a readable display string.

    Splits the raw author field on the word ``and``, surrounded by any
    whitespace and ignoring ``and`` inside braces such as
    ``{Barnes and Noble}``.  Each name is reordered from ``"Last, First"``
    (or ``"Last, Jr, First"``) to ``"First Last"`` (``"First Last Jr"``).
    Two names are joined with ``" and "``; three or more use Oxford commas
    (e.g. ``"Alice Bob, Carol Dave, and Eve Frank"``).

    LaTeX markup is converted to Unicode:

    - ``{\ss}`` → ``ß``
    - symbol accents with or without braces, e.g. ``\"a``/``\"{a}`` → ``ä``,
      ``\'e`` → ``é``, ``\~n`` → ``ñ`` (also ``\` \^ \= \.``)
    - letter accents, e.g. ``\c{c}`` → ``ç``, ``\v{s}`` → ``š``,
      ``\H{o}`` → ``ő`` (also ``\u \k \r``)
    - ``\o``, ``\aa``, ``\ae``, ``\oe``, ``\l`` (and capitals) → ``ø``,
      ``å``, ``æ``, ``œ``, ``ł``
    - an unescaped ``~`` (non-breaking space) → a space

    Remaining braces are stripped after this conversion.

    Args:
        raw: Raw BibTeX author field value (may be empty).

    Returns:
        Formatted author string, or empty string if *raw* has no names.

    Examples:
        >>> format_author_list("Einstein, Albert")
        'Albert Einstein'
        >>> format_author_list('M\\"uller, Hans and Schr\\"odinger, Erwin')
        'Hans Müller and Erwin Schrödinger'
        >>> format_author_list("")
        ''
    """
    if not raw:
        return ""
    names = [_reorder_bibtex_name(n) for n in _split_outside_braces(raw, r"\s+and\s+") if n]
    names = [n for n in names if n]
    if not names:
        return ""
    if len(names) == 1:
        result = names[0]
    elif len(names) == 2:
        result = f"{names[0]} and {names[1]}"
    else:
        result = ", ".join(names[:-1]) + ", and " + names[-1]
    result = _latex_to_unicode(result)
    return result.replace("{", "").replace("}", "")

# ---------------------------------------------------------------------------
# Markdown formatting
# ---------------------------------------------------------------------------

def entry_to_markdown_line(idx: int, entry: dict) -> str:
    """Format a single BibTeX entry as a numbered markdown line.\n
    Produces output like (shown wrapped; the result is one line)::\n
        1. **Title**, Authors, *Journal* **Volume** (Number) Pages (Year)
           [arXiv:XXXX.XXXXX](https://arxiv.org/abs/XXXX.XXXXX) [doi:DOI](https://doi.org/DOI)\n
    Fields that are absent are omitted.  Conference proceedings use
    ``booktitle`` instead of ``journal``.  All displayed text is
    markdown-escaped.  Link targets are percent-encoded, so identifiers
    containing ``(``, ``)``, ``<``, ``>`` or spaces (common in older DOIs)
    cannot break the link syntax.  Identifiers are first reduced to their
    bare form: a ``doi`` given as an ``https://doi.org/...`` URL or with a
    ``doi:`` prefix, and an ``eprint`` with an ``arXiv:`` prefix.\n
    Args:
        idx: 1-based entry number.
        entry: A parsed BibTeX entry dict.\n
    Returns:
        A single markdown line (no embedded newlines).
    """
    from urllib.parse import quote
    #
    title = strip_outer_braces(entry.get("title", "Untitled"))
    authors = format_author_list(entry.get("author", ""))
    journal = strip_outer_braces(entry.get("journal", ""))
    booktitle = strip_outer_braces(entry.get("booktitle", ""))
    volume = strip_outer_braces(entry.get("volume", ""))
    number = strip_outer_braces(entry.get("number", ""))
    # BibTeX page ranges use "--"; render them with an en dash.
    pages = strip_outer_braces(entry.get("pages", "")).replace("--", "–")
    year = strip_outer_braces(entry.get("year", ""))
    # Bare identifiers, so the link prefixes below are not duplicated.
    eprint = re.sub(
        r"^arxiv:\s*", "", strip_outer_braces(entry.get("eprint", "")), flags=re.IGNORECASE
    )
    doi = re.sub(
        r"^(?:https?://(?:dx\.)?doi\.org/|doi:\s*)",
        "",
        strip_outer_braces(entry.get("doi", "")),
        flags=re.IGNORECASE,
    )
    #
    venue_parts: list[str] = []
    venue = journal or booktitle
    if venue:
        venue_parts.append(f"*{escape_markdown_chars(venue)}*")
    if volume:
        venue_parts.append(f"**{escape_markdown_chars(volume)}**")
    if number:
        venue_parts.append(f"({escape_markdown_chars(number)})")
    if pages:
        venue_parts.append(escape_markdown_chars(pages))
    venue_str = " ".join(venue_parts)
    #
    link_parts: list[str] = []
    if eprint:
        # Percent-encode so characters like ")" cannot end the link target early.
        url = "https://arxiv.org/abs/" + quote(eprint, safe="/:")
        link_parts.append(f"[arXiv:{escape_markdown_chars(eprint)}]({url})")
    if doi:
        url = "https://doi.org/" + quote(doi, safe="/:;")
        link_parts.append(f"[doi:{escape_markdown_chars(doi)}]({url})")
    #
    line = f"{idx}. **{escape_markdown_chars(title)}**"
    if authors:
        line += f", {escape_markdown_chars(authors)}"
    if venue_str:
        line += f", {venue_str}"
    if year:
        line += f" ({escape_markdown_chars(year)})"
    if link_parts:
        line += " " + " ".join(link_parts)
    return line

def entry_to_html_list_item(entry: dict) -> str:
    """Format a single BibTeX entry as an HTML list item.\n
    Produces an ``<li>`` element containing, joined by ``", "``:\n
    - the title (bold);
    - the authors;
    - the venue (italic journal or booktitle, bold volume, number, pages);
    - the year;
    - arXiv/DOI links.\n
    All text is HTML-escaped.  Link URLs are percent-encoded and then
    attribute-escaped, so identifiers containing quotes or ``<``/``>``
    cannot break out of the ``href`` attribute.  Identifiers are first
    reduced to their bare form: a ``doi`` given as an
    ``https://doi.org/...`` URL or with a ``doi:`` prefix, and an
    ``eprint`` with an ``arXiv:`` prefix.\n
    Args:
        entry: A parsed BibTeX entry dict.\n
    Returns:
        An HTML ``<li>...</li>`` string.
    """
    from urllib.parse import quote
    #
    title = strip_outer_braces(entry.get("title", "Untitled"))
    authors = format_author_list(entry.get("author", ""))
    journal = strip_outer_braces(entry.get("journal", ""))
    booktitle = strip_outer_braces(entry.get("booktitle", ""))
    volume = strip_outer_braces(entry.get("volume", ""))
    number = strip_outer_braces(entry.get("number", ""))
    # BibTeX page ranges use "--"; render them with an en dash.
    pages = strip_outer_braces(entry.get("pages", "")).replace("--", "–")
    year = strip_outer_braces(entry.get("year", ""))
    # Bare identifiers, so the link prefixes below are not duplicated.
    eprint = re.sub(
        r"^arxiv:\s*", "", strip_outer_braces(entry.get("eprint", "")), flags=re.IGNORECASE
    )
    doi = re.sub(
        r"^(?:https?://(?:dx\.)?doi\.org/|doi:\s*)",
        "",
        strip_outer_braces(entry.get("doi", "")),
        flags=re.IGNORECASE,
    )
    #
    parts: list[str] = [f"<strong>{escape_html_entities(title)}</strong>"]
    if authors:
        parts.append(escape_html_entities(authors))
    venue_parts: list[str] = []
    venue = journal or booktitle
    if venue:
        venue_parts.append(f"<em>{escape_html_entities(venue)}</em>")
    if volume:
        venue_parts.append(f"<strong>{escape_html_entities(volume)}</strong>")
    if number:
        venue_parts.append(f"({escape_html_entities(number)})")
    if pages:
        venue_parts.append(escape_html_entities(pages))
    if venue_parts:
        parts.append(" ".join(venue_parts))
    if year:
        parts.append(f"({escape_html_entities(year)})")
    #
    link_parts: list[str] = []
    if eprint:
        href = escape_html_entities("https://arxiv.org/abs/" + quote(eprint, safe="/:"))
        link_parts.append(f'<a href="{href}">arXiv:{escape_html_entities(eprint)}</a>')
    if doi:
        href = escape_html_entities("https://doi.org/" + quote(doi, safe="/:;"))
        link_parts.append(f'<a href="{href}">doi:{escape_html_entities(doi)}</a>')
    if link_parts:
        parts.append(" ".join(link_parts))
    return "<li>" + ", ".join(parts) + "</li>"

def entries_to_markdown_list(entries: list[dict]) -> str:
    """Convert a list of BibTeX entry dicts to a markdown reference list.\n
    Each entry is formatted as a numbered line via
    :func:`entry_to_markdown_line`, numbering from 1.  Lines are separated
    by blank lines and a trailing newline is included, so an empty
    *entries* list yields ``"\\n"``.\n
    Args:
        entries: List of parsed BibTeX entry dicts.\n
    Returns:
        A markdown string with numbered references.
    """
    lines = [entry_to_markdown_line(i, e) for i, e in enumerate(entries, 1)]
    return "\n\n".join(lines) + "\n"

def entries_to_html_document(entries: list[dict], title: str = "") -> str:
    """Render BibTeX entry dicts as a complete HTML document.\n
    The entries become ``<li>`` items of a single ``<ol>``.  The list's
    left padding is ``(digits + 1.5)em``, where ``digits`` is the number of
    digits in the entry count but never less than 2.  An empty list
    therefore gets ``3.5em``.  The list is then passed to
    :func:`wrap_in_html_document`.\n
    Args:
        entries: List of parsed BibTeX entry dicts.
        title: Optional HTML document title.\n
    Returns:
        A complete HTML5 document string.
    """
    list_items = "\n".join(map(entry_to_html_list_item, entries))
    digit_count = max(2, len(str(len(entries))))
    padding = f"{digit_count + 1.5}em"
    return wrap_in_html_document(
        f'<ol style="padding-left: {padding}">\n{list_items}\n</ol>', title=title
    )

# ---------------------------------------------------------------------------
# Output
# ---------------------------------------------------------------------------

def write_bibtex_to_file(entries: list[str], output_path: str | None) -> None:
    """Write BibTeX entries to a file or stdout.\n
    Entries are separated by blank lines.  If *output_path* is ``None``,
    writes to stdout instead.\n
    Args:
        entries: List of raw BibTeX entry strings.
        output_path: Destination file path, or ``None`` for stdout.
    """
    import sys
    #
    output = "\n\n".join(entries) + "\n" if entries else ""
    if output_path:
        with open(output_path, "w") as f:
            f.write(output)
    else:
        sys.stdout.write(output)

def convert_entries_to_markdown_list(bib_path: str, md_path: str) -> None:
    """Read a .bib file and write a numbered markdown reference list.\n
    Each entry becomes a single line of the form (shown wrapped)::\n
        1. **Title**, Authors, *Journal* **vol** (num) pages (year)
           [arXiv:ID](url) [doi:ID](url)\n
    Entries are separated by blank lines and absent fields are omitted.  If
    the .bib file contains no entries, writes an empty file.  Both files
    are read/written as UTF-8, so the result does not depend on the
    platform's locale encoding.  Progress messages go to stderr.\n
    Args:
        bib_path: Path to the input .bib file.
        md_path: Path to the output .md file.
    """
    with open(bib_path, encoding="utf-8") as f:
        bib = f.read()
    #
    entries = parse_bibtex_to_entries(bib)
    if not entries:
        print(
            f"  No entries in {os.path.basename(bib_path)}, writing empty markdown",
            file=sys.stderr,
        )
        with open(md_path, "w", encoding="utf-8") as f:
            f.write("")
        return
    #
    md = entries_to_markdown_list(entries)
    with open(md_path, "w", encoding="utf-8") as f:
        f.write(md)
    print(
        f"  Wrote {len(entries)} entries to {os.path.basename(md_path)}",
        file=sys.stderr,
    )

def convert_entries_to_html_document(bib_path: str, html_path: str) -> None:
    """Read a .bib file and write an HTML ordered reference list.\n
    Each entry becomes an ``<li>`` with title, authors, venue, year, and
    arXiv/DOI links, wrapped in a complete HTML5 document whose title is
    the output file's base name.  If the .bib file contains no entries,
    writes an empty file.  Both files are read/written as UTF-8, so the
    result does not depend on the platform's locale encoding.  Progress
    messages go to stderr.\n
    Args:
        bib_path: Path to the input .bib file.
        html_path: Path to the output .html file.
    """
    with open(bib_path, encoding="utf-8") as f:
        bib = f.read()
    #
    entries = parse_bibtex_to_entries(bib)
    if not entries:
        print(
            f"  No entries in {os.path.basename(bib_path)}, writing empty HTML",
            file=sys.stderr,
        )
        with open(html_path, "w", encoding="utf-8") as f:
            f.write("")
        return
    #
    # Named html_doc (not "html") to avoid shadowing the html module.
    html_doc = entries_to_html_document(entries, title=os.path.basename(html_path))
    with open(html_path, "w", encoding="utf-8") as f:
        f.write(html_doc)
    print(
        f"  Wrote {len(entries)} entries to {os.path.basename(html_path)}",
        file=sys.stderr,
    )

def merge_bib_files_and_write(files: list[str], bib_output: str, label: str) -> None:
    """Read BibTeX files, merge and deduplicate, write .bib/.md/.html outputs.\n
    Reads all *files* and merges their entries; the first occurrence wins
    on duplicate keys.  It then writes three output files derived from
    *bib_output*:\n
    - ``bib_output`` (.bib): Merged BibTeX.
    - ``stem.md``: Numbered markdown reference list.
    - ``stem.html``: HTML ordered reference list.\n
    ``stem`` is *bib_output* with one trailing ``.bib`` removed, if
    present.  So ``refs.bib`` yields ``refs.md``/``refs.html``, while
    ``refs.txt`` yields ``refs.txt.md``/``refs.txt.html`` and can never
    collide with the .bib output.  All files are read and written as
    UTF-8.  A one-line summary is printed to stdout.\n
    Args:
        files: List of .bib file paths to read and merge.
        bib_output: Path for the output .bib file.
        label: Human-readable label for log messages (e.g.,
            ``"reference"``, ``"citation"``, ``"paper"``).
    """
    bib_texts: list[str] = []
    for path in files:
        with open(path, encoding="utf-8") as f:
            bib_texts.append(f.read())
    # Counted separately so the summary can report how many duplicates
    # the merge dropped.
    total_input = sum(len(parse_bibtex_to_entries(bib_text)) for bib_text in bib_texts)
    merged = merge_bibtex_strings(*bib_texts)
    bib_result = entries_to_bibtex_string(merged)
    md_result = entries_to_markdown_list(merged)
    duplicates = total_input - len(merged)
    #
    stem = bib_output.removesuffix(".bib")
    md_output = stem + ".md"
    html_output = stem + ".html"
    html_result = entries_to_html_document(merged, title=os.path.basename(html_output))
    #
    with open(bib_output, "w", encoding="utf-8") as f:
        f.write(bib_result)
    with open(md_output, "w", encoding="utf-8") as f:
        f.write(md_result)
    with open(html_output, "w", encoding="utf-8") as f:
        f.write(html_result)
    #
    print(
        f"Merged {len(merged)} {label} entries from {len(files)} file(s) → {bib_output}, {md_output}, {html_output} ({duplicates} duplicates removed)"
    )
