#!/usr/bin/env python3
#
# Usage:
#
#   maint/update-md-links CHANGELOG.md
#
# Updates a markdown file, merging/inserting links from `gen_md_links`
#
# New link definitions are added to the collection of link definitions
# at the end of each section (see ALGORITHM, below).
#
# When we don't know what a link target should be, we emit four Xs.


# ALGORITHM
#
# Look for annotations in the file telling us how to work.
#
# Split the file up into sections (divided by the specified heading level).
#
# In each section, find existing link def lines
# (ie, lines giving the target `t` for a link anchor text `a`).
# This gives us
#   - for each section
#     - for each anchor, locations of relevant defs in this section
#   - targets for some anchors
#
# Feed each section separately to gen_md_links
# This gives us
#   - for each section, needed anchor set
#   - targets, for some anchors
#
# Reconcile definitions, to obtain precisely one target for each anchor.
#
# for each section.
#   For each needed anchor
#     If there are def(s) in this section, before the def collection, OK
#     Otherwise write a definition line to the new collection
#   Collect defs from existing final link def collection (if any)
#   Sort the collection
#   Replace the relevant part of the file with the normalised collection
#   (possibly *adding* the collection)

from __future__ import annotations
import argparse
import collections
import filecmp
import os
import re
import subprocess
import sys
import tempfile

from typing import Any, Optional, Tuple, Union

# ---------- "constant" definitions ----------

# regexps

# A link definition line, `[anchor]: target`, with an optional trailing newline.
# Group 1 is the anchor (which may not contain square brackets);
# group 2 is the target (which may be empty).  We apply this with `fullmatch`.
link_def_re = re.compile(r"\[([^][]+)\]\:\s?(.*)\n?")
# The start of a heading line.  Group 1 is the run of `#`;
# its length is the heading level.  We apply this with `match`.
heading_re = re.compile(r"(\#+)\s")
# A line consisting of a comment `<!--@@ update-md-links ... -->`.
# Group 1 is the `...` part, without surrounding whitespace.
instruction_re = re.compile(r"\<\!\-\-\@\@\s+update-md-links\s*(.*\S)\s*\-\-\>\s*$")
# The `...` captured by `instruction_re`, split up into
# an instruction name (group 1) and its value (group 2).
instruction_val_re = re.compile(r"\s*([-0-9a-z]+)\s+(.*\S)\s*")

# ---------- set up globals: parse command line and read the input file ----------

# "instructions" we understand in comments like this:
# <!--@@ update-md-links INSTRUCTION VALUE -->
#
# The values given here are the defaults;
# `process_instructions` replaces them with any values found in the input file.
instructions = {
    # Heading level (number of `#`) at which the file is split into sections.
    # No heading has level 0, so with 0 the whole file is a single section.
    "split-heading-level": 0,
    # Number of blank lines to leave at the end of each section we rewrite
    # (other than a section which runs to the end of the file).
    "section-blank-lines": 1,
}

parser = argparse.ArgumentParser(
    prog="update-md-links",
    description="update links in a markdown document",
)
parser.add_argument("filename")
parser.add_argument(
    "--check",
    action="store_true",
    help="Check that everything is up to date; make no changes",
)

args = parser.parse_args()

# The whole input file, as a list of lines (each still with its line ending).
# The later stages index into this list, and rewrite its elements in place.
#
# Read it inside `with`, so that the file is closed as soon as we have its
# contents, rather than whenever the file object happens to be collected.
with open(args.filename, "r") as md_file:
    md = list(md_file)

# ---------- types used in our data structures ----------

Anchor = str
Target = str
Source = str

# One definition, of `t`, with human-readable source `source`, at md line `md_i`
#
# `md_i` is a 0-based index into `md`.
# The link definition might be from maint/gen_md_links, in which case `md_i` is `None`.
Def = collections.namedtuple("Def", ["t", "source", "md_i"])

# For each anchor, every definition we have seen for it (from whatever source).
Defs = dict[Anchor, list[Def]]
# For each anchor, the one target we have settled on.
Resolved = dict[Anchor, Target]

# One section (if we have `split-heading-level` of other than 0).
#
# Comprises lines [`start`, `end`) (0-based indices into `md`).
# `defs_t2i` maps link target to a list of (0-based) line numbers it's defined on
# `a_needed` is a dict which gains an entry `a: True` for each anchor `a`
# that gen_md_links reports for this section; so it is empty (falsy)
# precisely when no anchors were reported for the section.
Section = collections.namedtuple("Section", ["start", "end", "defs_t2i", "a_needed"])

# ---------- utility functions ----------

# How many nonfatal problems have been reported (via `trouble`) so far.
troubles = 0


def trouble(m: str) -> None:
    """
    Report a nonfatal problem (a "trouble").

    Writes `m` to stderr, prefixed with "trouble: ", and counts it in the
    global `troubles`, so that we can exit nonzero later.
    """
    global troubles
    sys.stderr.write("trouble: " + m + "\n")
    troubles = troubles + 1


# Anything whose `returncode` we may want to check: the result of
# `subprocess.run`, or a `Popen` object.
CheckableSubprocess = Union[subprocess.CompletedProcess[Any], subprocess.Popen[str]]


def check_returncode(process: CheckableSubprocess) -> None:
    """
    Insist that `process.returncode` is zero.

    If it is anything else, prints an error message to stderr
    and exits the whole program with status 12.
    """
    status = process.returncode
    if status == 0:
        return
    print("subprocess failed with nonzero returncode %s" % status, file=sys.stderr)
    sys.exit(12)


def is_link_def(line: str) -> Optional[Tuple[str, str]]:
    """
    Try to parse `line` as a link definition.

    If the whole of `line` has link definition syntax, returns `(a, t)`,
    where `a` is the anchor and `t` the target.  Otherwise returns None.
    """
    m = link_def_re.fullmatch(line)
    if m is not None:
        # Both capture groups always take part in a successful match,
        # so neither of these can be None.
        return m.group(1), m.group(2)
    return None


# ---------- search for instructions ----------


def process_instructions() -> None:
    """
    Looks for instructions and updates the global `instructions`.

    Each line of `md` which is an `update-md-links` instruction comment
    should contain `INSTRUCTION VALUE`, where INSTRUCTION is one we know
    (ie, is already a key in `instructions`) and VALUE is an integer.

    Anything else is reported via `trouble`, and leaves `instructions` alone.
    """

    for i, line in enumerate(md):
        source = "%s:%d" % (args.filename, i + 1)

        g = instruction_re.fullmatch(line)
        if not g:
            continue  # not an instruction comment at all

        g = instruction_val_re.fullmatch(g.group(1))
        if not g:
            trouble("%s: unknown instruction" % source)
            continue

        kv: Tuple[str, str] = g.groups()  # type: ignore
        k, v = kv
        if k not in instructions:
            # Don't store it: nothing would ever look at the value.
            trouble("%s: unknown value instruction %s" % (source, k))
            continue

        try:
            instructions[k] = int(v)
        except ValueError:
            # Report this like any other problem with the input file,
            # rather than dying with a traceback.
            trouble(
                "%s: value for instruction %s is not an integer: %s" % (source, k, v)
            )


# ---------- break input into sections ----------


def split_input() -> list[Section]:
    """
    Divide `md` into sections.

    A new section starts at each heading whose level is `split-heading-level`.
    The first section always starts at the start of the file, and the last
    runs to the end of the file, so there is always at least one section.
    """
    split_level = int(instructions["split-heading-level"])

    # Indices (into `md`) at which a section other than the first one starts.
    # Line 0 is left out: it starts the first section whether or not it is a heading.
    boundaries = []
    for idx, text in enumerate(md):
        hashes = heading_re.match(text)
        if hashes and len(hashes.group(1)) == split_level and idx != 0:
            boundaries.append(idx)

    starts = [0] + boundaries
    ends = boundaries + [len(md)]

    # Give every Section its own fresh pair of empty dicts: they are filled in
    # separately for each section later, so they must not be shared.
    return [Section(start, end, {}, {}) for start, end in zip(starts, ends)]


# ---------- scan input sections' contents ----------


def scan_sections(sections: list[Section]) -> Defs:
    """
    Scans each section in `sections`

    For each section, this:
      - records every link definition line already present in `md`
        (also noting its line index in the section's `defs_t2i`);
      - runs `maint/gen_md_links` on the section's text, records every link
        definition it prints, and marks that anchor in the section's `a_needed`.

    Returns all the definitions recorded, for each anchor.

    Exits the program with status 12 if `gen_md_links` prints a line which
    isn't a link definition, or exits with a nonzero returncode.
    """
    link_defs: Defs = {}

    def record_link_def(
        a: Anchor, t: Target, source: Source, md_i: Optional[int]
    ) -> None:
        """
        Record that anchor `a` is defined to have target url `t`.

        `source` and `md_i` are as for `Def`.

        `t` may be the empty string (and for output from `gen_md_links`, often is).
        """
        link_defs.setdefault(a, []).append(Def(t, source, md_i))

    for s in sections:
        # ---------- for each section, find existing link def lines ----------

        for i in range(s.start, s.end):
            at = is_link_def(md[i])
            if at:
                a, t = at
                record_link_def(a, t, "%s:%d" % (args.filename, i + 1), i)
                s.defs_t2i.setdefault(t, []).append(i)

        # ---------- for each section, run gen_md_links ----------

        # Use `with` for both the temporary file and the subprocess, so that
        # the file and the pipe are closed once we are done with this section,
        # rather than accumulating one of each per section.
        with tempfile.TemporaryFile(mode="w+", buffering=True) as text_file:
            for i in range(s.start, s.end):
                # NOTE(review): a line read from the file normally still ends
                # with its newline, and `print` adds another, so gen_md_links
                # sees a blank line after each input line.  Left as it was,
                # since changing it might change what gen_md_links reports;
                # TODO confirm whether this is intended.
                print(md[i], file=text_file)
            text_file.flush()
            text_file.seek(0, 0)

            with subprocess.Popen(
                ["maint/gen_md_links", "--", "-"],
                stdin=text_file,
                stdout=subprocess.PIPE,
                encoding="utf-8",
            ) as gen_links_output:
                assert gen_links_output.stdout
                for line in gen_links_output.stdout:
                    line = line.strip()
                    if line == "":
                        continue
                    at = is_link_def(line)
                    if at is None:
                        print(
                            "gen_md_links produced bad output line %s (for %s:%d..%d)"
                            % (repr(line), args.filename, s.start + 1, s.end),
                            file=sys.stderr,
                        )
                        sys.exit(12)
                    a, t = at
                    record_link_def(a, t, "gen_md_links", None)
                    s.a_needed[a] = True

                gen_links_output.wait()

        check_returncode(gen_links_output)

    return link_defs


# ---------- reconcile definitions ----------


def resolve_definitions(link_defs: Defs) -> Resolved:
    """
    Choose precisely one target for each anchor in `link_defs`.

    Definitions whose target is blank don't count.  If an anchor has no
    other definitions, its target is four Xs.  If it has several different
    (non-blank) targets, that is a `trouble`: we list on stderr everything
    we saw for that anchor, and use the target we came across first.
    """
    resolved = {}

    for anchor, defs in link_defs.items():
        # The sources of each target seen for this anchor (blank targets
        # included), with targets and sources in the order we came across them.
        by_target: dict[Target, list[Source]] = {}
        for d in defs:
            by_target.setdefault(d.t, []).append(d.source)

        real_targets = [t for t in by_target if t.strip() != ""]

        if len(real_targets) > 1:
            trouble("conflicting definitions for [%s]" % anchor)
            for t, sources in by_target.items():
                print("  candidate %s" % t, file=sys.stderr)
                for source in sources:
                    print("    defined %s" % source, file=sys.stderr)

        if real_targets:
            resolved[anchor] = real_targets[0]
        else:
            # Written in two halves so that this script doesn't itself
            # contain the four-Xs marker.
            resolved[anchor] = "XX" + "XX"

    return resolved


# ---------- collate outputs ----------


def collate_insert_outputs(
    sections: list[Section], link_defs: Defs, link_def: Resolved
) -> None:
    """
    Rewrite the collection of link definitions at the end of each section.

    Updates `md` in place.

    A section's collection is the run of blank lines and link definition
    lines at its end.  It is replaced by a sorted list of definitions
    (with targets from `link_def`) for those anchors the section needs
    which aren't defined earlier in the section.

    Sections which need no anchors, or which consist of nothing but blank
    lines and link definitions, are left untouched.
    """
    blank_lines = int(instructions["section-blank-lines"])

    for sec in sections:
        # Walk backwards from the end of the section, over blank lines and
        # link definitions, to find where the link collection starts.
        # (So the collection includes blank lines either side of the definitions.)
        coll_start = sec.end
        while coll_start > sec.start:
            line_before = md[coll_start - 1]
            if line_before.strip() != "" and not is_link_def(line_before):
                break
            coll_start -= 1

        if coll_start <= sec.start:
            continue  # nothing in this section but links and blank lines, ignore it

        if not sec.a_needed:
            continue  # section contains no link anchors, ignore it

        # Definition lines for the needed anchors which have no definition
        # in this section before the collection.
        missing = []
        for a in sec.a_needed:
            in_body = any(
                d.md_i is not None and sec.start <= d.md_i < coll_start
                for d in link_defs[a]
            )
            if not in_body:
                missing.append("[%s]: %s\n" % (a, link_def[a]))
        missing.sort()

        # Blank out the old collection.  This keeps `md` the same length,
        # so the other sections' line indices remain valid.
        md[coll_start : sec.end] = [""] * (sec.end - coll_start)

        # The replacement goes onto the end of the last line before the collection.
        addition = ""
        if missing:
            addition = "\n" + "".join(missing)
        if sec.end != len(md):
            addition += "\n" * blank_lines

        md[coll_start - 1] += addition


# ---------- write output ----------


def write_output() -> None:
    """
    Writes the output file

    Writes to a `.tmp`, and then runs diff, or installs it, as appropriate:

      - If there were any troubles, exits with status 12, leaving the `.tmp`.
      - With `--check`, runs `diff -u` between the input file and the `.tmp`.
        If they differ, exits with status 1 (leaving the `.tmp`);
        if they are the same, removes the `.tmp`.
      - Otherwise, says whether anything changed,
        and renames the `.tmp` over the input file.
    """
    new_filename = "%s.tmp" % args.filename
    # Use `with`, so that the file is flushed and closed before we go on to
    # compare or rename it, and is closed even if writing fails part way.
    with open(new_filename, "w", buffering=True) as output:
        output.writelines(md)

    if troubles != 0:
        print("trouble, not installing %s" % new_filename, file=sys.stderr)
        sys.exit(12)

    if args.check:
        r = subprocess.run(["diff", "-u", "--", args.filename, new_filename])
        if r.returncode == 1:
            # diff exits with status 1 when the files differ
            print("%s links not up to date." % args.filename, file=sys.stderr)
            sys.exit(1)
        check_returncode(r)
        os.remove(new_filename)
    else:
        if filecmp.cmp(args.filename, new_filename):
            print("%s unchanged" % args.filename)
        else:
            print("%s *updated*!" % args.filename)
        os.rename(new_filename, args.filename)


# ---------- main program ----------

# Each stage works on the global `md` (the lines of the input file).
# The stages correspond to the steps listed under ALGORITHM, at the top of this file.

# Pick up any `<!--@@ update-md-links ... -->` settings from the file.
process_instructions()
# Divide the file into sections.
sections = split_input()
# Find the existing link definitions, and ask gen_md_links what each section needs.
link_defs = scan_sections(sections)
# Settle on one target for each anchor.
link_def = resolve_definitions(link_defs)
# Rewrite each section's collection of link definitions, within `md`.
collate_insert_outputs(sections, link_defs, link_def)
# Write out the result, and then either diff it (`--check`) or install it.
write_output()
