#!/usr/bin/env bash
# pyreplab — CLI for the persistent Python REPL
#
# Usage:
#   pyreplab start [--workdir DIR] [--cwd DIR] [--venv PATH] [...]
#   pyreplab run notebook.py       Run all cells (stamps [N] into # %% markers)
#   pyreplab run notebook.py:2     Run cell 2
#   pyreplab run 'print("hello")'  Inline code
#   pyreplab run < script.py       Read from stdin
#   pyreplab cells notebook.py     List cells (stamps [N], peeks comments)
#   pyreplab wait                  Wait for a running command to finish
#   pyreplab cancel                Cancel the currently running command
#   pyreplab status                Check if REPL is running
#   pyreplab ps                    List all active sessions
#   pyreplab stop / stop-all       Stop session(s)
#   pyreplab clean                 Remove session files
#   pyreplab dir                   Print session directory path
#
# No sourcing needed. Each invocation is standalone.

set -euo pipefail

# Root directory that holds one sub-directory per session.
PYREPLAB_BASE="${PYREPLAB_BASE:-/tmp/pyreplab}"

# Locate pyreplab.py. An explicit PYREPLAB_SCRIPT always wins; otherwise prefer
# a copy sitting next to this script, and finally ask Python where the
# importable `pyreplab` module lives.
if [ -z "${PYREPLAB_SCRIPT:-}" ]; then
    _pyreplab_here="$(cd "$(dirname "$0")" && pwd)" || _pyreplab_here=""
    if [ -f "$_pyreplab_here/pyreplab.py" ]; then
        PYREPLAB_SCRIPT="$_pyreplab_here/pyreplab.py"
    else
        PYREPLAB_SCRIPT="$(python3 -c 'import pyreplab; print(pyreplab.__file__)')"
    fi
    unset _pyreplab_here
fi

# Print the session directory for this invocation on stdout.
#
# Precedence:
#   1. $PYREPLAB_DIR, when set and non-empty, is printed verbatim.
#   2. Otherwise, when $_PYREPLAB_WORKDIR is non-empty, the result is
#      $PYREPLAB_BASE/<basename of workdir>_<first 8 hex chars of its MD5>.
#   3. Otherwise $PYREPLAB_BASE/default.
#
# Returns non-zero, with a diagnostic on stderr, when a workdir must be hashed
# but neither `md5sum` nor `md5` can produce a digest.
_resolve_dir() {
    if [ -n "${PYREPLAB_DIR:-}" ]; then
        echo "$PYREPLAB_DIR"
        return 0
    fi
    if [ -z "${_PYREPLAB_WORKDIR:-}" ]; then
        echo "$PYREPLAB_BASE/default"
        return 0
    fi

    # Pick the hashing tool explicitly instead of relying on a failed pipeline
    # (and therefore on `pipefail`) to reach the fallback.
    local digest=""
    if command -v md5sum >/dev/null 2>&1; then
        digest=$(printf '%s' "$_PYREPLAB_WORKDIR" | md5sum | cut -c1-8) || digest=""
    elif command -v md5 >/dev/null 2>&1; then
        digest=$(printf '%s' "$_PYREPLAB_WORKDIR" | md5 -q | cut -c1-8) || digest=""
    fi
    if [ -z "$digest" ]; then
        echo "pyreplab: cannot hash workdir (need md5sum or md5)" >&2
        return 1
    fi

    local name
    name=$(basename "$_PYREPLAB_WORKDIR")
    echo "$PYREPLAB_BASE/${name}_${digest}"
}

# Set the global _PYREPLAB_WORKDIR from the given argument list.
#
# The first `--workdir DIR` pair wins; DIR is resolved to an absolute path by
# changing into it. A trailing `--workdir` with no value is ignored. When no
# pair is present the current directory is used.
#
# Arguments: the command-line arguments to scan (may be empty).
# Returns:   non-zero, with a diagnostic on stderr, if DIR cannot be entered.
_PYREPLAB_WORKDIR=""
_parse_workdir() {
    local -a args=("$@")
    local i   # keep the loop counter out of the global namespace
    for ((i = 0; i + 1 < ${#args[@]}; i++)); do
        if [ "${args[i]}" = "--workdir" ]; then
            # CDPATH is cleared so `cd` cannot print a directory name into
            # the captured output.
            if ! _PYREPLAB_WORKDIR="$(CDPATH='' cd -- "${args[i + 1]}" 2>/dev/null && pwd)"; then
                echo "pyreplab: cannot access --workdir '${args[i + 1]}'" >&2
                return 1
            fi
            return 0
        fi
    done
    # Default: cwd
    _PYREPLAB_WORKDIR="$(pwd)"
}

_is_running() {
    # Succeeds only when the resolved session has a pidfile whose recorded
    # pid still accepts signal 0 (i.e. the process exists).
    local pid_path
    pid_path="$(_resolve_dir)/pyreplab.pid"
    if [ ! -f "$pid_path" ]; then
        return 1
    fi
    kill -0 "$(cat "$pid_path")" 2>/dev/null
}

# Session fallback, to be called after _parse_workdir.
# When PYREPLAB_DIR is not set and the session derived from the workdir has no
# live daemon, adopt the directory recorded in $PYREPLAB_BASE/.last_session,
# provided the daemon recorded there is still alive, by exporting it as
# PYREPLAB_DIR. In every other case nothing is changed.
_resolve_session() {
    # An explicit session directory is never overridden.
    if [ -n "${PYREPLAB_DIR:-}" ]; then
        return 0
    fi
    # The workdir-derived session is alive: keep using it.
    if _is_running; then
        return 0
    fi

    local marker="$PYREPLAB_BASE/.last_session"
    if [ ! -f "$marker" ]; then
        return 0
    fi

    local candidate
    candidate=$(cat "$marker")
    local candidate_pid="$candidate/pyreplab.pid"
    if [ -f "$candidate_pid" ] && kill -0 "$(cat "$candidate_pid")" 2>/dev/null; then
        export PYREPLAB_DIR="$candidate"
    fi
}

cmd_start() {
    # Start the daemon for the session selected by the arguments.
    #
    # Arguments: forwarded unchanged to pyreplab.py after `--session-dir DIR`;
    #            `--workdir DIR` is additionally used to pick the session.
    # Outputs:   status messages on stderr.
    # Returns:   0 when the daemon is (already) running, 1 when the freshly
    #            launched process is gone after a short grace period.
    _parse_workdir "$@"
    local dir
    dir=$(_resolve_dir)
    local pidfile="$dir/pyreplab.pid"

    if [ -f "$pidfile" ] && kill -0 "$(cat "$pidfile")" 2>/dev/null; then
        echo "pyreplab: already running (pid $(cat "$pidfile"), dir $dir)" >&2
        return 0
    fi
    # PYREPLAB_BASE is created as well: with PYREPLAB_DIR pointing outside
    # of it, it may not exist yet and the .last_session write below would fail.
    mkdir -p "$dir" "$PYREPLAB_BASE"
    # Clean stale IPC files from a previous (dead) session
    rm -f "$dir/cmd.py" "$dir/output.json" "$dir/done" "$dir/pending_id" "$dir/pending_start"
    python3 "$PYREPLAB_SCRIPT" --session-dir "$dir" "$@" &
    local pid=$!
    echo "$pid" > "$pidfile"
    # Give the interpreter a moment to fail fast (bad flags, import errors).
    sleep 0.3
    if ! kill -0 "$pid" 2>/dev/null; then
        echo "pyreplab: failed to start" >&2
        rm -f "$pidfile"
        return 1
    fi
    echo "pyreplab: started (pid $pid, dir $dir)" >&2
    # Save as last-used session for cross-directory reuse
    printf '%s' "$dir" > "$PYREPLAB_BASE/.last_session"
}

cmd_stop() {
    # Stop the daemon of the resolved session.
    #
    # Sends SIGTERM, waits up to ~3 seconds for the process to disappear and
    # escalates to SIGKILL if it is still alive, so the pidfile is never
    # removed while the daemon keeps running.
    # Outputs: status messages on stderr.
    # Returns: 0 (also when nothing was running).
    local dir
    dir=$(_resolve_dir)
    local pidfile="$dir/pyreplab.pid"
    if ! [ -f "$pidfile" ] || ! kill -0 "$(cat "$pidfile")" 2>/dev/null; then
        echo "pyreplab: not running" >&2
        rm -f "$pidfile"
        return 0
    fi
    local pid
    pid=$(cat "$pidfile")
    # The daemon may exit between the liveness check and this signal; that
    # must not abort the script under `set -e`.
    kill "$pid" 2>/dev/null || true
    local i=0
    while kill -0 "$pid" 2>/dev/null && [ "$i" -lt 30 ]; do
        sleep 0.1
        i=$((i + 1))
    done
    if kill -0 "$pid" 2>/dev/null; then
        echo "pyreplab: pid $pid did not exit after SIGTERM, sending SIGKILL" >&2
        kill -KILL "$pid" 2>/dev/null || true
    fi
    rm -f "$pidfile"
    echo "pyreplab: stopped (pid $pid)" >&2
}

cmd_clean() {
    # Remove the IPC files of the resolved session (command, output,
    # handshake and pending markers, plus their .tmp siblings). The pidfile
    # and anything else in the directory are left alone.
    # Outputs: a confirmation message on stderr.
    local dir f   # `f` is local so the loop cannot clobber a caller's variable
    dir=$(_resolve_dir)
    for f in cmd.py cmd.py.tmp output.json output.json.tmp done pending_id pending_start; do
        if [ -f "$dir/$f" ]; then
            rm -f -- "$dir/$f"
        fi
    done
    echo "pyreplab: cleaned session files in $dir" >&2
}

_count_cells() {
    # Print how many cells a .py file contains.
    # A file without newline-terminated markers is a single cell; code in
    # front of the first marker counts as one extra (preamble) cell.
    local file="$1"
    python3 -c '
import re
import sys

with open(sys.argv[1]) as handle:
    source = handle.read()

marker_count = len(re.findall(r"(?m)^# ?%%[^\n]*\n", source))
opens_with_marker = re.match(r"# ?%%", source) is not None

if marker_count == 0:
    total = 1
elif opens_with_marker:
    total = marker_count
else:
    total = marker_count + 1
print(total)
' "$file"
}

_stamp_cells() {
    # Rewrite every cell marker of a .py file, in place, as
    # "# %% [N]" followed by the marker label if it has one.
    # N starts at 0 when the file opens with a marker and at 1 when code
    # precedes the first marker (that preamble is cell 0). Running it again
    # on an already stamped file changes nothing; the file is only written
    # when at least one line differs.
    local file="$1"
    python3 -c '
import re
import sys

target = sys.argv[1]
source = open(target).read()
rows = source.split("\n")

marker = re.compile(r"^(# ?%%)\s*(?:\[(\d+)\]\s*)?(.*?)$")

# Code ahead of the first marker is cell 0, so markers then number from 1.
next_idx = 0 if re.match(r"# ?%%", source) else 1

stamped = []
for row in rows:
    hit = marker.match(row)
    if hit is None:
        stamped.append(row)
        continue
    title = hit.group(3).strip()
    if title:
        stamped.append(f"# %% [{next_idx}] {title}")
    else:
        stamped.append(f"# %% [{next_idx}]")
    next_idx += 1

if stamped != rows:
    with open(target, "w") as out:
        out.write("\n".join(stamped))
' "$file"
}

_extract_cell() {
    # Print one cell of a .py file split on "# %%" / "#%%" marker lines.
    # Usage: _extract_cell file.py [cell_index]
    #   - Without an index (or with "all") the whole file is printed.
    #   - Cell 0 is the code in front of the first marker when there is any,
    #     otherwise the body of the first marker; marker lines themselves
    #     are not part of the output.
    # Returns non-zero, with a one-line message on stderr, when the index is
    # not an integer or is out of range.
    local file="$1"
    local cell_idx="${2:-all}"

    if [ "$cell_idx" = "all" ]; then
        cat "$file"
        return
    fi

    python3 -c '
import re
import sys

path, raw_idx = sys.argv[1], sys.argv[2]

# Reject a non-numeric index with a readable message instead of a traceback.
try:
    idx = int(raw_idx)
except ValueError:
    print(f"pyreplab: invalid cell index {raw_idx!r} (expected an integer)", file=sys.stderr)
    sys.exit(1)

with open(path) as handle:
    text = handle.read()

# Split on # %% or #%% lines (the space is optional by convention).
parts = re.split(r"(?m)^# ?%%[^\n]*\n", text)

# When the file opens with a marker, the first piece is the empty string in
# front of it and is dropped; otherwise the first piece is the preamble and
# stays as cell 0.
cells = parts[1:] if re.match(r"# ?%%", text) else parts

if idx < 0 or idx >= len(cells):
    print(f"pyreplab: cell {idx} not found (file has {len(cells)} cells)", file=sys.stderr)
    sys.exit(1)
print(cells[idx], end="")
' "$file" "$cell_idx"
}

_wait_for_result() {
    # Block until the session's `done` file appears, then print the result.
    # Argument: $1 - how long to wait, in seconds (default 30); the wait is
    #           counted in 0.1s polling steps.
    # Output:   the "stdout" field of output.json on stdout, its "stderr"
    #           and "error" fields on stderr.
    # Returns:  0 on success, 1 when there is no output or the result
    #           carries an error, 2 when the wait expired (still running).
    local limit_s="${1:-30}"
    local session
    session=$(_resolve_dir)
    local done_flag="$session/done"
    local result_file="$session/output.json"
    local pending_flag="$session/pending_id"
    local started_file="$session/pending_start"

    # Submission time, used to report the total time a command has been running.
    local submitted_at=""
    if [ -f "$started_file" ]; then
        submitted_at=$(cat "$started_file")
    fi

    local ticks=0
    until [ -f "$done_flag" ]; do
        sleep 0.1
        ticks=$((ticks + 1))
        if [ "$ticks" -lt "$((limit_s * 10))" ]; then
            continue
        fi
        # Out of patience: report how long the command has run and bail out,
        # leaving every session file in place for a later `wait`.
        local total="${limit_s}s"
        if [ -n "$submitted_at" ]; then
            total="$(( $(date +%s) - submitted_at ))s"
        fi
        echo "pyreplab: still running (${total} elapsed). Run \`pyreplab wait\` to check again." >&2
        return 2
    done

    # Consume the handshake marker first.
    rm -f "$done_flag"

    # Grab the payload; an unreadable or empty file counts as "no output".
    local payload
    payload=$(cat "$result_file" 2>/dev/null) || true
    if [ -z "$payload" ]; then
        echo "pyreplab: no output received" >&2
        rm -f "$pending_flag" "$started_file"
        return 1
    fi
    rm -f "$result_file" "$pending_flag" "$started_file"

    # Decode the JSON reply and route each field to the matching stream.
    printf '%s' "$payload" | python3 -c '
import json
import sys

reply = json.load(sys.stdin)
for key, stream in (("stdout", sys.stdout), ("stderr", sys.stderr)):
    chunk = reply.get(key, "")
    if chunk:
        print(chunk, end="", file=stream)

failure = reply.get("error")
if failure:
    print(failure, end="", file=sys.stderr)
    sys.exit(1)
'
}

_send_code() {
    # Hand a piece of code to the daemon through the session directory and
    # wait for the reply.
    # Arguments: $1 - the code to run
    #            $2 - optional cell label recorded in the command header
    # Returns:   1 when a previous command is still pending on a live
    #            daemon, otherwise whatever _wait_for_result returns.
    local payload="$1"
    local label="${2:-}"
    local session
    session=$(_resolve_dir)
    local request_id
    request_id="cmd_$$_$(date +%s%N 2>/dev/null || echo $RANDOM)"
    local cmd_file="$session/cmd.py"
    local done_flag="$session/done"
    local result_file="$session/output.json"
    local pending_flag="$session/pending_id"
    local wait_limit="${PYREPLAB_TIMEOUT:-30}"

    # A pending marker means either a busy daemon or leftovers of a dead one.
    if [ -f "$pending_flag" ]; then
        if _is_running; then
            echo "pyreplab: busy running previous command. Run \`pyreplab wait\` first." >&2
            return 1
        fi
        # Daemon is gone: the markers are stale, sweep them away.
        rm -f "$pending_flag" "$done_flag" "$result_file" "$session/cmd.py" "$session/pending_start"
    fi

    # Start from a clean handshake state.
    rm -f "$done_flag" "$result_file"

    # Build the "#%% ..." header line, then publish the command atomically
    # (write a temp file, rename it into place).
    local heading
    heading="id: $request_id cwd: $(pwd)"
    if [ -n "$label" ]; then
        heading="$heading cell: $label"
    fi
    printf '#%%%% %s\n%s\n' "$heading" "$payload" > "$cmd_file.tmp"
    mv "$cmd_file.tmp" "$cmd_file"

    # Record what is pending and when it was submitted.
    printf '%s' "$request_id" > "$pending_flag"
    date +%s > "$session/pending_start"

    _wait_for_result "$wait_limit"
}

_send_notebook() {
    # Ask the daemon to run a whole notebook file. Only a header naming the
    # notebook path is written to cmd.py; no code body is sent (the daemon is
    # expected to read and split the file itself).
    # Argument: $1 - path of the notebook, written verbatim into the header
    # Returns:  1 when a previous command is still pending on a live daemon,
    #           otherwise whatever _wait_for_result returns.
    local notebook_abs="$1"
    local session
    session=$(_resolve_dir)
    local request_id
    request_id="cmd_$$_$(date +%s%N 2>/dev/null || echo $RANDOM)"
    local cmd_file="$session/cmd.py"
    local done_flag="$session/done"
    local result_file="$session/output.json"
    local pending_flag="$session/pending_id"
    local wait_limit="${PYREPLAB_TIMEOUT:-30}"

    # A pending marker means either a busy daemon or leftovers of a dead one.
    if [ -f "$pending_flag" ]; then
        if _is_running; then
            echo "pyreplab: busy running previous command. Run \`pyreplab wait\` first." >&2
            return 1
        fi
        rm -f "$pending_flag" "$done_flag" "$result_file" "$session/cmd.py" "$session/pending_start"
    fi

    rm -f "$done_flag" "$result_file"

    # Header-only command file, published atomically via rename.
    local heading
    heading="id: $request_id cwd: $(pwd) notebook: $notebook_abs"
    printf '#%%%% %s\n' "$heading" > "$cmd_file.tmp"
    mv "$cmd_file.tmp" "$cmd_file"

    printf '%s' "$request_id" > "$pending_flag"
    date +%s > "$session/pending_start"

    _wait_for_result "$wait_limit"
}

cmd_cells() {
    # List the cells of a .py file, one "  N: ..." line per cell.
    # Argument: $1 - path of the file to inspect
    # Unless PYREPLAB_STAMP is set to something other than 1, the markers
    # are first (re)stamped with their [N] index via _stamp_cells.
    # Each marker is shown with its inline label; when it has none, a plain
    # comment on the following line is used, else "(unnamed)". Code before
    # the first marker, or a file without markers, is shown as cell 0 using
    # its first non-blank line ("(empty)" when there is none).
    # Returns 1 with a usage message when the file is missing.
    local file="${1:-}"
    if [ -z "$file" ] || ! [ -f "$file" ]; then
        echo "Usage: pyreplab cells <file.py>" >&2
        return 1
    fi
    local stamp="${PYREPLAB_STAMP:-1}"
    if [ "$stamp" = "1" ]; then
        _stamp_cells "$file"
    fi
    python3 -c '
import re
import sys

with open(sys.argv[1]) as handle:
    text = handle.read()


def first_code_line(chunk):
    # First non-blank line of a chunk, or a placeholder when it has none.
    return next((ln.strip() for ln in chunk.split("\n") if ln.strip()), "(empty)")


markers = list(re.finditer(r"(?m)^# ?%%[^\n]*$", text))
if not markers:
    # No cell markers: the whole file is cell 0.
    print(f"  0: {first_code_line(text)}")
    sys.exit(0)

# Content in front of the first marker is cell 0 (the preamble).
offset = 0
if markers[0].start() > 0:
    print(f"  0: {first_code_line(text[:markers[0].start()])}")
    offset = 1

for i, m in enumerate(markers):
    # Strip the marker prefix (and any [N] stamp) to get the inline label.
    label = re.sub(r"^# ?%%\s*(?:\[\d+\]\s*)?", "", m.group()).strip()
    if not label:
        # No inline label: peek at the line right after the marker and use
        # it when it is an ordinary comment (not another marker).
        _, _, tail = text[m.end():].partition("\n")
        next_line = tail.split("\n", 1)[0].strip()
        if next_line.startswith("#") and not re.match(r"# ?%%", next_line):
            label = next_line.lstrip("# ").strip()
    label = label or "(unnamed)"
    print(f"  {i + offset}: # %% {label}")
' "$file"
}

cmd_run() {
    # Run code in the resolved session. The single argument selects the mode:
    #   (none)      read the code from stdin
    #   file.py     stamp the cells, then send the notebook path to the daemon
    #   file.py:N   stamp the cells, then send only cell N
    #   anything else is sent as inline code
    # Cell stamping is skipped unless PYREPLAB_STAMP is 1 (the default).
    # Returns 1 when there is no code to run, otherwise the status of the
    # send helper.

    # Resolve session from cwd (unless PYREPLAB_DIR is set)
    if [ -z "${PYREPLAB_DIR:-}" ]; then
        _parse_workdir
    fi
    _resolve_session
    local arg="${1:-}"
    local stamp="${PYREPLAB_STAMP:-1}"
    local code

    if [ -z "$arg" ]; then
        # No args: read from stdin
        code=$(cat)
        if [ -z "$code" ]; then
            echo "pyreplab: no code to run" >&2
            return 1
        fi
        _send_code "$code" ""
    elif [ -f "$arg" ]; then
        # pyreplab run file.py — whole notebook, executed by the daemon
        if [ "$stamp" = "1" ]; then
            _stamp_cells "$arg"
        fi
        local file_abs
        file_abs="$(cd "$(dirname "$arg")" && pwd)/$(basename "$arg")"
        _send_notebook "$file_abs"
    elif [[ "$arg" == *:* ]] && [ -f "${arg%:*}" ]; then
        # pyreplab run file.py:N — the cell index follows the LAST colon, so
        # the file is everything in front of that same colon; this keeps
        # paths that themselves contain ":" working.
        local file="${arg%:*}"
        local cell="${arg##*:}"
        if [ "$stamp" = "1" ]; then
            _stamp_cells "$file"
        fi
        code=$(_extract_cell "$file" "$cell")
        if [ -z "$code" ]; then
            echo "pyreplab: no code to run" >&2
            return 1
        fi
        _send_code "$code" "$(basename "$file"):$cell"
    else
        # Inline code
        _send_code "$arg" ""
    fi
}

cmd_wait() {
    # Check on a previously submitted command.
    # Returns 1 when nothing is pending or when the daemon has died (its
    # leftover files are removed in that case); otherwise polls for the
    # result for 2 seconds and returns _wait_for_result's status.
    local session
    session=$(_resolve_dir)

    if ! [ -f "$session/pending_id" ]; then
        echo "pyreplab: no command pending" >&2
        return 1
    fi

    if _is_running; then
        # Short poll (2s) so callers such as agents are never blocked for long.
        _wait_for_result 2
        return
    fi

    # Daemon is dead: the pending command can never complete, so tidy up.
    rm -f "$session/pending_id" "$session/pending_start" "$session/done" "$session/output.json" "$session/cmd.py"
    echo "pyreplab: server died while command was pending (cleaned up stale files)" >&2
    return 1
}

cmd_cancel() {
    # Ask the daemon to cancel the command it is currently running by
    # sending it SIGUSR1, then wait up to 5 seconds for the result.
    # Outputs: status messages on stderr.
    # Returns: 1 when no daemon is running or the signal cannot be
    #          delivered, 0 when nothing is pending, otherwise the status
    #          of _wait_for_result.
    local dir
    dir=$(_resolve_dir)
    local pidfile="$dir/pyreplab.pid"

    if ! [ -f "$pidfile" ] || ! kill -0 "$(cat "$pidfile")" 2>/dev/null; then
        echo "pyreplab: not running" >&2
        return 1
    fi

    if [ ! -f "$dir/pending_id" ]; then
        echo "pyreplab: no command running" >&2
        return 0
    fi

    local pid
    pid=$(cat "$pidfile")
    # Report a failed signal explicitly; left unguarded, `set -e` would end
    # the script here without any message.
    if ! kill -USR1 "$pid" 2>/dev/null; then
        echo "pyreplab: could not signal daemon (pid $pid)" >&2
        return 1
    fi
    echo "pyreplab: cancel signal sent" >&2

    # Wait briefly for the daemon to finish processing the cancellation
    _wait_for_result 5
}

cmd_dir() {
    # Print the resolved session directory on stdout.
    local session_dir
    session_dir=$(_resolve_dir)
    echo "$session_dir"
}

cmd_status() {
    # Report on stdout whether the resolved session has a live daemon and,
    # if so, whether a command is pending (with its age when known).
    # Returns 1 when no daemon is running, 0 otherwise.
    local session
    session=$(_resolve_dir)
    local pid_path="$session/pyreplab.pid"

    if ! [ -f "$pid_path" ] || ! kill -0 "$(cat "$pid_path")" 2>/dev/null; then
        echo "pyreplab: not running"
        return 1
    fi

    local activity="idle"
    if [ -f "$session/pending_id" ]; then
        activity="executing command"
        if [ -f "$session/pending_start" ]; then
            # Seconds since the pending command was submitted.
            local age=$(( $(date +%s) - $(cat "$session/pending_start") ))
            activity="executing command (${age}s elapsed)"
        fi
    fi
    echo "pyreplab: running (pid $(cat "$pid_path"), dir $session), $activity"
}

cmd_ps() {
    # Print a table of the live sessions found under $PYREPLAB_BASE:
    # session name, pid, uptime, resident memory and directory.
    # Pidfiles whose process is gone are removed along the way. When no
    # live session is found a note follows the table header.
    local found=0
    local format="%-28s %-7s %-8s %-6s %s\n"
    local pidfile pid dir name uptime mem rss mtime elapsed
    # shellcheck disable=SC2059 # $format is the fixed literal defined above
    printf "$format" "SESSION" "PID" "UPTIME" "MEM" "DIR"
    for pidfile in "$PYREPLAB_BASE"/*/pyreplab.pid; do
        [ -f "$pidfile" ] || continue
        pid=$(cat "$pidfile")
        dir=$(dirname "$pidfile")
        name=$(basename "$dir")
        if ! kill -0 "$pid" 2>/dev/null; then
            rm -f "$pidfile"
            continue
        fi

        # Uptime is derived from the pidfile's modification time. Try the
        # BSD/macOS `stat` syntax first, then the GNU one; anything that is
        # not a plain number leaves the column at "?".
        uptime="?"
        mtime=$(stat -f%m "$pidfile" 2>/dev/null) || mtime=""
        if ! [[ "$mtime" =~ ^[0-9]+$ ]]; then
            mtime=$(stat -c%Y "$pidfile" 2>/dev/null) || mtime=""
        fi
        if [[ "$mtime" =~ ^[0-9]+$ ]]; then
            elapsed=$(( $(date +%s) - mtime ))
            if [ "$elapsed" -ge 3600 ]; then
                uptime="$((elapsed / 3600))h$((elapsed % 3600 / 60))m"
            elif [ "$elapsed" -ge 60 ]; then
                uptime="$((elapsed / 60))m$((elapsed % 60))s"
            else
                uptime="${elapsed}s"
            fi
        fi

        # Memory (RSS in MB). The process can vanish between the liveness
        # check and `ps`; with `pipefail` + `set -e` an unguarded failure
        # here would abort the whole listing.
        mem="?"
        rss=$(ps -o rss= -p "$pid" 2>/dev/null | tr -d ' ') || rss=""
        if [[ "$rss" =~ ^[0-9]+$ ]]; then
            mem="$((rss / 1024))MB"
        fi

        # shellcheck disable=SC2059
        printf "$format" "$name" "$pid" "$uptime" "$mem" "$dir"
        found=$((found + 1))
    done
    if [ "$found" -eq 0 ]; then
        echo "pyreplab: no active sessions"
    fi
}

cmd_stop_all() {
    # Stop every session that has a pidfile under $PYREPLAB_BASE.
    # Each live daemon gets SIGTERM, up to ~3 seconds to exit, and SIGKILL
    # if it is still alive after that. All pidfiles found are removed.
    # Outputs: one line per stopped session plus a summary, on stderr.
    local stopped=0
    local pidfile pid dir name i
    for pidfile in "$PYREPLAB_BASE"/*/pyreplab.pid; do
        [ -f "$pidfile" ] || continue
        pid=$(cat "$pidfile")
        dir=$(dirname "$pidfile")
        name=$(basename "$dir")
        if kill -0 "$pid" 2>/dev/null; then
            # A daemon exiting on its own right here must not abort the
            # loop under `set -e`, hence the guard.
            kill "$pid" 2>/dev/null || true
            i=0
            while kill -0 "$pid" 2>/dev/null && [ "$i" -lt 30 ]; do
                sleep 0.1
                i=$((i + 1))
            done
            if kill -0 "$pid" 2>/dev/null; then
                echo "pyreplab: $name (pid $pid) did not exit after SIGTERM, sending SIGKILL" >&2
                kill -KILL "$pid" 2>/dev/null || true
            fi
            echo "pyreplab: stopped $name (pid $pid)" >&2
            stopped=$((stopped + 1))
        fi
        rm -f "$pidfile"
    done
    if [ "$stopped" -eq 0 ]; then
        echo "pyreplab: no sessions to stop" >&2
    else
        echo "pyreplab: stopped $stopped session(s)" >&2
    fi
}

# --- Main dispatch ---
case "${1:-help}" in
    start)
        shift
        cmd_start "$@"
        ;;
    stop|clean|wait|cancel|dir|status)
        # Session-scoped commands share one preamble: pick the session from
        # the remaining arguments (or the fallback), then run cmd_<name>.
        _pyreplab_cmd="$1"
        shift
        _parse_workdir "$@"
        _resolve_session
        "cmd_${_pyreplab_cmd}"
        ;;
    stop-all)
        cmd_stop_all
        ;;
    run)
        shift
        cmd_run "$@"
        ;;
    cells)
        shift
        cmd_cells "$@"
        ;;
    ps|list|ls)
        cmd_ps
        ;;
    help|--help|-h)
        # One printf call, one output line per argument.
        printf '%s\n' \
            "Usage: pyreplab <command> [args]" \
            "" \
            "Commands:" \
            "  start [opts]        Start the REPL (opts: --workdir, --cwd, --venv, ...)" \
            "  run file.py         Run all cells (stamps [N] indices into file)" \
            "  run file.py:N       Run cell N from file (0-indexed)" \
            "  run 'code'          Run inline code" \
            "  run                 Read code from stdin" \
            "  cells file.py       List cells (stamps [N] indices into file)" \
            "  wait                Wait for a running command to finish" \
            "  cancel              Cancel the currently running command" \
            "  stop                Stop the current session" \
            "  stop-all            Stop all active sessions" \
            "  dir                 Print session directory path" \
            "  status              Check if REPL is running" \
            "  ps                  List all active sessions" \
            "  clean               Remove session files" \
            "" \
            "" \
            "Cell stamping adds [N] indices to # %% markers in your .py files." \
            "Set PYREPLAB_STAMP=0 to disable. Accepts both # %% and #%% markers."
        ;;
    *)
        echo "pyreplab: unknown command '$1' (try 'pyreplab help')" >&2
        exit 1
        ;;
esac
