#!/usr/bin/env python3
"""Public-facing TMLM-Haiku interactive CLI.

Pulls models from the CompactAI-O HuggingFace collection:
  https://huggingface.co/collections/CompactAI-O/tmlm-haiku-series
"""
from __future__ import annotations


#!/usr/bin/env python3
from __future__ import annotations

import hashlib
import json
import math
import os
import string
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterator, List, Optional, Sequence, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint


# Public model name -> HuggingFace repository id. Every model lives under the
# CompactAI-O organisation with a repo named exactly like the model.
HUGGINGFACE_MODELS = {
    model_name: f"CompactAI-O/{model_name}"
    for model_name in ("TMLM-Haiku-1", "TMLM-Haiku-1.3", "TMLM-Haiku-2", "Glint-1")
}


# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------


@dataclass
class ModelConfig:
    """Default architecture/runtime hyper-parameters for the model.

    Field meanings follow their names; ``head_dim`` is derived from ``dim``
    and ``n_heads`` rather than stored.
    """

    dim: int = 128
    n_unique_layers: int = 8
    n_logical_layers: int = 16
    n_heads: int = 4
    n_kv_heads: int = 2
    ffn_dim: int = 224
    dropout: float = 0.0
    seq_len: int = 2048
    sliding_window_size: int = 512
    mtp_horizons: Tuple[int, ...] = (2, 3, 4)
    rope_fraction: float = 0.5
    embed_scale: bool = True
    logit_soft_cap: float = -1.0
    quantization: str = "nvfp4"

    @property
    def head_dim(self) -> int:
        """Per-head width: ``dim`` split evenly over ``n_heads`` (floor division)."""
        per_head, _leftover = divmod(self.dim, self.n_heads)
        return per_head


# Shared default configuration instance.
model_config = ModelConfig()

# Architecture presets per model series, keyed by series name.
# NOTE(review): "haiku" also sets sliding_window_size and rope_fraction, which
# "sonnet" and "opus" omit — presumably the consumer falls back to defaults
# for missing keys; confirm against the code that reads these presets.
MODEL_SERIES = dict(
    haiku=dict(
        dim=64,
        n_unique_layers=12,
        n_logical_layers=24,
        n_heads=4,
        n_kv_heads=2,
        ffn_dim=384,
        dropout=0.0,
        seq_len=2048,
        sliding_window_size=2048,
        mtp_horizons=(),
        rope_fraction=0.5,
        engram_dim=8,
        engram_heads=2,
        engram_table_size=64,
        engram_max_ngram=2,
        mhc_expansion=2,
        sleep_gate_cap=0,
        sleep_gate_heads=4,
        latent_think_layers=0,
        prelude_layers=0,
        coda_layers=0,
        recurrent_loops=0,
        recurrent_act_threshold=0.9,
        recurrent_lora_rank=0,
        recurrent_loop_embed_dim=0,
    ),
    sonnet=dict(
        dim=1024,
        n_unique_layers=20,
        n_logical_layers=40,
        n_heads=16,
        n_kv_heads=4,
        ffn_dim=4096,
        dropout=0.0,
        seq_len=2048,
        mtp_horizons=(2,),
        engram_dim=32,
        engram_heads=8,
        engram_table_size=4096,
        engram_max_ngram=2,
        mhc_expansion=2,
        sleep_gate_cap=0,
        sleep_gate_heads=8,
        latent_think_layers=0,
        prelude_layers=0,
        coda_layers=0,
        recurrent_loops=0,
        recurrent_act_threshold=0.99,
        recurrent_lora_rank=0,
        recurrent_loop_embed_dim=0,
    ),
    opus=dict(
        dim=1536,
        n_unique_layers=18,
        n_logical_layers=36,
        n_heads=16,
        n_kv_heads=4,
        ffn_dim=5888,
        dropout=0.0,
        seq_len=2048,
        mtp_horizons=(2,),
        engram_dim=64,
        engram_heads=8,
        engram_table_size=8192,
        engram_max_ngram=2,
        mhc_expansion=4,
        sleep_gate_cap=0,
        sleep_gate_heads=8,
        latent_think_layers=0,
        prelude_layers=0,
        coda_layers=0,
        recurrent_loops=0,
        recurrent_act_threshold=0.99,
        recurrent_lora_rank=0,
        recurrent_loop_embed_dim=0,
    ),
)


# ---------------------------------------------------------------------------
# Tokenizer
# ---------------------------------------------------------------------------

FORMAT_TOKENS = [
    "<|user|>",
    "<|assistant|>",
    "<|system|>",
    "<|start_header_id|>",
    "<|end_header_id|>",
    "<|begin_of_thought|>",
    "<|end_of_thought|>",
    "<|begin_of_solution|>",
    "<|end_of_solution|>",
]


class WordTokenizer:
    """Character-level tokenizer with atomic multi-character special tokens.

    Vocabulary layout (fresh instance): the four core specials
    ``<PAD>``, ``<BOS>``, ``<EOS>``, ``<UNK>`` (ids 0-3), then the format
    tokens, then single characters (ASCII letters, digits, punctuation,
    ``" \\n\\t\\r"`` and ``extra_chars``, sorted). Encoding greedily matches
    the longest multi-character special token at each position and otherwise
    emits a single character; characters not in the vocabulary map to ``<UNK>``.
    """

    def __init__(
        self, extra_chars: str = "", format_tokens: Optional[List[str]] = None
    ) -> None:
        """Build the default vocabulary.

        Args:
            extra_chars: additional single characters to include.
            format_tokens: chat/format special tokens; ``None`` selects
                ``FORMAT_TOKENS``. An explicit empty list means "no format tokens".
        """
        base = string.ascii_letters + string.digits + string.punctuation + " \n\t\r"
        fallback_chars = sorted(set(base + extra_chars))
        self.core_special = ["<PAD>", "<BOS>", "<EOS>", "<UNK>"]
        # `is not None` (not truthiness) so that `[]` is honoured instead of
        # silently being replaced by the defaults.
        self.format_tokens = (
            list(format_tokens) if format_tokens is not None else list(FORMAT_TOKENS)
        )
        self.id_to_token: List[str] = (
            list(self.core_special) + self.format_tokens + fallback_chars
        )
        self.token_to_id: Dict[str, int] = {
            t: i for i, t in enumerate(self.id_to_token)
        }
        self._refresh_specials()
        self.dynamic_additions = 0

    def _refresh_specials(self) -> None:
        """Recompute the special-token list and its longest-first match order."""
        self.special = list(self.core_special) + list(self.format_tokens)
        # Longest first so e.g. "<|end_of_thought|>" wins over any shorter prefix.
        self.special_multi_tokens = sorted(
            [t for t in self.special if len(t) > 1], key=len, reverse=True
        )
        self.multi_char_tokens = self.special_multi_tokens

    @property
    def pad_id(self) -> int:
        return self.token_to_id["<PAD>"]

    @property
    def bos_id(self) -> int:
        return self.token_to_id["<BOS>"]

    @property
    def eos_id(self) -> int:
        return self.token_to_id["<EOS>"]

    @property
    def unk_id(self) -> int:
        return self.token_to_id["<UNK>"]

    @property
    def vocab_size(self) -> int:
        return len(self.id_to_token)

    def maybe_add_char(self, ch: str) -> bool:
        """Append ``ch`` to the vocabulary if missing; return True if it was added."""
        if ch in self.token_to_id:
            return False
        self.token_to_id[ch] = len(self.id_to_token)
        self.id_to_token.append(ch)
        self.dynamic_additions += 1
        return True

    def iter_lexical_tokens(self, text: str) -> Iterator[str]:
        """Yield special tokens (longest match first) or single characters of ``text``."""
        specials = self.special_multi_tokens
        i = 0
        n = len(text)
        while i < n:
            for token in specials:
                if text.startswith(token, i):
                    yield token
                    i += len(token)
                    break
            else:
                # No special token starts here: emit one character.
                yield text[i]
                i += 1

    def encode(
        self, text: str, add_bos: bool = False, add_eos: bool = False
    ) -> List[int]:
        """Map ``text`` to token ids, optionally wrapped in ``<BOS>``/``<EOS>``."""
        out: List[int] = []
        if add_bos:
            out.append(self.bos_id)
        unk = self.unk_id
        t2i = self.token_to_id
        for tok in self.iter_lexical_tokens(text):
            out.append(t2i.get(tok, unk))
        if add_eos:
            out.append(self.eos_id)
        return out

    def decode(self, ids: Sequence[int], skip_special: bool = True) -> str:
        """Map ids back to text.

        Out-of-range ids are dropped. When ``skip_special`` is True, core and
        format special tokens (including ``<UNK>``) are omitted.
        """
        vocab = self.id_to_token
        vocab_len = len(vocab)
        # Build the lookup set once per call instead of scanning a list per token.
        skipped = set(self.special) if skip_special else set()
        pieces: List[str] = []
        for raw in ids:
            idx = int(raw)
            if idx < 0 or idx >= vocab_len:
                continue
            tok = vocab[idx]
            if tok in skipped:
                continue
            pieces.append(tok)
        return "".join(pieces)

    @classmethod
    def load(cls, path: Path) -> WordTokenizer:
        """Load a tokenizer from a JSON file.

        The file must contain ``id_to_token`` (the full vocabulary, in id order)
        and may contain ``format_tokens`` (defaults to ``FORMAT_TOKENS``).
        ``path`` may be a ``pathlib.Path`` or a ``str``.
        """
        with Path(path).open("r", encoding="utf-8") as f:
            data = json.load(f)
        format_tokens = data.get("format_tokens", FORMAT_TOKENS)
        tokenizer = cls(extra_chars="", format_tokens=format_tokens)
        tokenizer.id_to_token = data["id_to_token"]
        tokenizer.token_to_id = {t: i for i, t in enumerate(tokenizer.id_to_token)}
        tokenizer._refresh_specials()
        return tokenizer


LetterTokenizer = WordTokenizer


# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------


class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned gain.

    Uses ``torch.nn.functional.rms_norm`` when the installed PyTorch provides it;
    otherwise normalises in float32, casts back to the input dtype and applies
    the gain.
    """

    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        fused_rms_norm = getattr(torch.nn.functional, "rms_norm", None)
        if fused_rms_norm is not None:
            return fused_rms_norm(x, self.weight.shape, self.weight, self.eps)
        # Fallback for older PyTorch: accumulate in float32 for stability.
        x32 = x.float()
        inv_rms = x32.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return (x32 * inv_rms).to(dtype=x.dtype) * self.weight


class RotaryEmbedding(nn.Module):
    """Rotary position embedding tables for a ``dim``-wide rotary slice.

    ``inv_freq[i] = base ** (-2i / dim)`` is kept as a non-persistent buffer
    (it is not saved in the state dict).
    """

    def __init__(self, dim: int, base: float = 10000.0) -> None:
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (base ** exponents), persistent=False)

    def cos_sin(
        self, seq_len: int, device: torch.device, dtype: torch.dtype
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(cos, sin)`` tables of shape ``(1, 1, seq_len, dim)`` in ``dtype``."""
        positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        angles = torch.outer(positions, self.inv_freq)
        # Duplicate so the table lines up with the two halves used by _rotate_half.
        table = torch.cat([angles, angles], dim=-1)[None, None]
        return table.cos().to(dtype=dtype), table.sin().to(dtype=dtype)


def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


class CausalSelfAttention(nn.Module):
    """Grouped-query causal self-attention with partial RoPE and per-head gates.

    ``n_heads`` query heads share ``n_kv_heads`` key/value heads; each KV head
    is repeated ``n_heads // n_kv_heads`` times before attention. Queries and
    keys are RMS-normalised per head, then rotary embeddings are applied to the
    first ``rope_dim`` channels of every head (the rest pass through unrotated).

    Global layers attend causally over all visible positions; local layers
    additionally limit each query to keys less than ``sliding_window``
    positions behind it. ``past_kv``/``use_cache`` support incremental decoding.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        head_dim: int,
        dropout: float,
        sliding_window: int,
        rope_fraction: float,
    ) -> None:
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = head_dim
        self.n_rep = n_heads // n_kv_heads
        self.dropout = dropout
        self.sliding_window = sliding_window

        self.wq = nn.Linear(dim, n_heads * head_dim, bias=False)
        self.wk = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.wv = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.wo = nn.Linear(n_heads * head_dim, dim, bias=False)

        # Number of rotated channels per head: even, and at least 2.
        self.rope_dim = max(2, int(head_dim * rope_fraction) // 2 * 2)
        self.rope = RotaryEmbedding(self.rope_dim)

        self.q_norm = RMSNorm(head_dim)
        self.k_norm = RMSNorm(head_dim)

        # Per-head output scale, squashed through a sigmoid in forward().
        self.output_gate = nn.Parameter(torch.ones(n_heads))

    @staticmethod
    def _causal_mask(
        past_len: int,
        t_q: int,
        s: int,
        device: torch.device,
        window: Optional[int],
    ) -> torch.Tensor:
        """Boolean ``(1, 1, t_q, s)`` mask; True where a query may attend a key.

        Query ``i`` sits at absolute position ``past_len + i`` and key ``j`` at
        position ``j`` (this assumes the cache holds every position from 0,
        which is how forward() builds it). With ``window`` set, keys at or beyond
        ``window`` positions behind the query are masked too.
        """
        q_pos = torch.arange(past_len, past_len + t_q, device=device).unsqueeze(1)
        k_pos = torch.arange(s, device=device).unsqueeze(0)
        mask = q_pos >= k_pos
        if window is not None:
            mask = mask & ((q_pos - k_pos) < window)
        return mask.unsqueeze(0).unsqueeze(0)

    def forward(
        self,
        x: torch.Tensor,
        is_global: bool,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Attend over ``x`` plus any cached keys/values.

        Args:
            x: ``(B, T, dim)`` input.
            is_global: full causal attention if True, sliding-window otherwise.
            past_kv: cached ``(k, v)``, each ``(B, n_kv_heads, S_past, head_dim)``,
                with RoPE already applied to ``k``.
            use_cache: when True, also return the concatenated ``(k, v)``.
        Returns:
            ``(B, T, dim)`` output and the updated cache (``None`` unless
            ``use_cache``).
        """
        B, T, _ = x.shape

        q = self.wq(x).view(B, T, self.n_heads, self.head_dim)
        k = self.wk(x).view(B, T, self.n_kv_heads, self.head_dim)
        v = self.wv(x).view(B, T, self.n_kv_heads, self.head_dim)

        q = self.q_norm(q)
        k = self.k_norm(k)

        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # New tokens continue the position count after the cached ones.
        past_len = past_kv[0].shape[2] if past_kv is not None else 0
        cos, sin = self.rope.cos_sin(T + past_len, x.device, q.dtype)
        cos_slice = cos[:, :, past_len : past_len + T, :]
        sin_slice = sin[:, :, past_len : past_len + T, :]

        q_rope = q[..., : self.rope_dim]
        q_pass = q[..., self.rope_dim :]
        k_rope = k[..., : self.rope_dim]
        k_pass = k[..., self.rope_dim :]

        q_rope = (q_rope * cos_slice) + (_rotate_half(q_rope) * sin_slice)
        k_rope = (k_rope * cos_slice) + (_rotate_half(k_rope) * sin_slice)

        q = torch.cat([q_rope, q_pass], dim=-1)
        k = torch.cat([k_rope, k_pass], dim=-1)

        if past_kv is not None:
            k = torch.cat([past_kv[0], k], dim=2)
            v = torch.cat([past_kv[1], v], dim=2)

        # Cache the un-repeated KV heads to keep the cache small.
        new_kv = (k, v) if use_cache else None

        S = k.shape[2]
        if self.n_rep > 1:
            k = (
                k[:, :, None, :, :]
                .expand(B, self.n_kv_heads, self.n_rep, S, self.head_dim)
                .reshape(B, self.n_heads, S, self.head_dim)
            )
            v = (
                v[:, :, None, :, :]
                .expand(B, self.n_kv_heads, self.n_rep, S, self.head_dim)
                .reshape(B, self.n_heads, S, self.head_dim)
            )

        drop_p = self.dropout if (self.training and torch.is_grad_enabled()) else 0.0

        if is_global and past_kv is None and T > 1:
            # Fresh sequence: keys and queries align, the built-in causal mask is exact.
            out = F.scaled_dot_product_attention(
                q, k, v, is_causal=True, dropout_p=drop_p
            )
        elif is_global and T == 1:
            # A single query at the newest position may see every visible key.
            out = F.scaled_dot_product_attention(q, k, v, dropout_p=drop_p)
        else:
            # Sliding-window layers, and global layers processing several new
            # tokens on top of a cache (chunked prefill): an explicit
            # position-offset mask is required so no query sees a later key.
            window = None if is_global else self.sliding_window
            mask = self._causal_mask(past_len, T, S, q.device, window)
            out = F.scaled_dot_product_attention(
                q, k, v, attn_mask=mask, dropout_p=drop_p
            )

        gate = torch.sigmoid(self.output_gate).view(1, self.n_heads, 1, 1)
        out = out * gate

        out = out.transpose(1, 2).contiguous().view(B, T, self.n_heads * self.head_dim)
        out = self.wo(out)

        return out, new_kv


class SwiGLU(nn.Module):
    """SwiGLU feed-forward: ``down(silu(gate(x)) * up(x))``, then dropout.

    Dropout is applied only while training with autograd enabled. Input
    projections are initialised with std ``dim ** -0.5``; the output projection
    with std ``hidden_dim ** -0.5``.
    """

    def __init__(self, dim: int, hidden_dim: int, dropout: float) -> None:
        super().__init__()
        self.gate = nn.Linear(dim, hidden_dim, bias=False)
        self.up = nn.Linear(dim, hidden_dim, bias=False)
        self.down = nn.Linear(hidden_dim, dim, bias=False)
        self.drop = nn.Dropout(dropout)

        input_std = dim**-0.5
        for projection in (self.gate, self.up):
            nn.init.normal_(projection.weight, std=input_std)
        nn.init.normal_(self.down.weight, std=hidden_dim**-0.5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        activated = F.silu(self.gate(x))
        projected = self.down(activated * self.up(x))
        use_dropout = self.training and torch.is_grad_enabled()
        return self.drop(projected) if use_dropout else projected


def loop_index_embedding(
    h: torch.Tensor, loop_t: int, loop_dim: int, theta: float = 10000.0
) -> torch.Tensor:
    """Add a sinusoidal encoding of recurrence step ``loop_t`` to ``h``.

    The encoding ``[sin(t * f), cos(t * f)]`` with ``f_i = theta ** (-2i / width)``
    is added to the first ``width`` channels of the last dimension, where
    ``width`` is ``loop_dim`` clamped to the feature size and rounded down to
    an even number. If ``width`` ends up <= 0, ``h`` itself is returned;
    otherwise a new tensor is returned and ``h`` is left untouched. The math is
    done in ``h``'s dtype.
    """
    width = min(loop_dim, h.shape[-1]) if loop_dim > 0 else 0
    width -= width % 2
    if width <= 0:
        return h

    exponents = torch.arange(0, width, 2, device=h.device, dtype=h.dtype) / width
    inv_freq = 1.0 / (theta ** exponents)
    step = torch.tensor(float(loop_t), device=h.device, dtype=h.dtype)
    phase = step * inv_freq
    encoding = torch.cat([phase.sin(), phase.cos()], dim=0).view(1, 1, width)

    result = h.clone()
    result[..., :width] = result[..., :width] + encoding
    return result


class DepthLoRAAdapter(nn.Module):
    """Low-rank adapter with a per-loop-step scale: ``(down(x) * scale[t]) @ B``.

    With ``rank <= 0`` the adapter is disabled and ``forward`` returns zeros.
    The scale table is zero-initialised, so an enabled adapter also starts out
    contributing zeros. Loop indices past the table are clamped to its last row.
    """

    def __init__(self, dim: int, rank: int, max_loops: int) -> None:
        super().__init__()
        self.rank = max(0, rank)
        enabled = self.rank > 0
        self.down = nn.Linear(dim, self.rank, bias=False) if enabled else None
        self.B = nn.Parameter(torch.randn(self.rank, dim) * 0.02) if enabled else None
        if enabled:
            self.scale = nn.Embedding(max(1, max_loops), self.rank)
            nn.init.zeros_(self.scale.weight)
        else:
            self.scale = None

    def forward(self, x: torch.Tensor, loop_t: int) -> torch.Tensor:
        disabled = self.rank <= 0 or any(
            part is None for part in (self.down, self.B, self.scale)
        )
        if disabled:
            return torch.zeros_like(x)
        row = min(loop_t, self.scale.num_embeddings - 1)
        step_scale = self.scale(torch.tensor(row, device=x.device))
        return (self.down(x) * step_scale) @ self.B


class StableRecurrentInjection(nn.Module):
    """Combine state, input and block output: ``A * h + sigmoid(g) * e + out``.

    ``A = exp(-exp(log_dt + log_A))`` (exponent clamped to [-20, 20]) lies in
    (0, 1), giving a per-channel contraction of the recurrent state ``h``.
    """

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.log_A = nn.Parameter(torch.full((dim,), -2.0))
        self.log_dt = nn.Parameter(torch.full((dim,), -2.0))
        self.input_gate = nn.Parameter(torch.zeros(dim))

    def forward(self, h: torch.Tensor, e: torch.Tensor, transformer_out: torch.Tensor) -> torch.Tensor:
        log_rate = (self.log_dt + self.log_A).clamp(-20, 20)
        decay = log_rate.exp().neg().exp().view(1, 1, -1)
        input_weight = self.input_gate.sigmoid().view(1, 1, -1)
        return decay * h + input_weight * e + transformer_out


class AdaptiveHalting(nn.Module):
    """Per-position halting probability ``sigmoid(w . h + b)``.

    Zero weights and bias -2 make every position start at ``sigmoid(-2)``.
    Input ``(..., dim)`` maps to output ``(...)``.
    """

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.halt = nn.Linear(dim, 1, bias=True)
        nn.init.zeros_(self.halt.weight)
        nn.init.constant_(self.halt.bias, -2.0)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        logits = self.halt(h).squeeze(-1)
        return logits.sigmoid()


class EngramBlock(nn.Module):
    """DeepSeek Engram-style conditional memory via O(1) hashed n-gram lookup.

    For every n-gram order ``n`` in ``2..max_ngram`` and every hash head ``k``,
    the n-gram of token ids ending at each position is hashed into its own
    learned ``(table_size, engram_dim)`` embedding table. The retrieved vectors
    are concatenated, passed through a depthwise causal dilated convolution
    (zero-initialised), and injected into the residual stream scaled by a
    per-position sigmoid gate computed from the hidden state (query) and the
    memory (key).

    Reference: DeepSeek-AI, "Conditional Memory via Scalable Lookup" (2025).
    """

    def __init__(
        self,
        dim: int,
        engram_dim: int,
        n_heads: int = 4,
        table_size: int = 8192,
        max_ngram: int = 3,
    ) -> None:
        # Fail early with a clear message: either value would otherwise leave the
        # branch convolution with zero channels and fail inside nn.Conv1d.
        if max_ngram < 2:
            raise ValueError(f"max_ngram must be >= 2, got {max_ngram}")
        if n_heads < 1:
            raise ValueError(f"n_heads must be >= 1, got {n_heads}")
        super().__init__()
        self.dim = dim
        self.engram_dim = engram_dim
        self.n_heads = n_heads
        self.table_size = table_size
        self.max_ngram = max_ngram

        # One embedding table per (ngram_order, hash_head).
        self.embeddings = nn.ParameterDict()
        for n in range(2, max_ngram + 1):
            for k in range(n_heads):
                self.embeddings[f"{n}_{k}"] = nn.Parameter(
                    torch.randn(table_size, engram_dim) * (engram_dim**-0.5)
                )

        # Fixed hash parameters: derived from an MD5 of the (order, head) name via
        # a private generator, so they are identical across runs and processes
        # without touching the global RNG. Non-persistent (not in state dicts).
        for n in range(2, max_ngram + 1):
            for k in range(n_heads):
                seed = int(hashlib.md5(f"engram_{n}_{k}".encode()).hexdigest()[:8], 16)
                rng = torch.Generator().manual_seed(seed)
                a = torch.randint(1, 2**31, (1,), generator=rng).item()
                b = torch.randint(0, 2**31, (1,), generator=rng).item()
                self.register_buffer(
                    f"hash_a_{n}_{k}", torch.tensor(a), persistent=False
                )
                self.register_buffer(
                    f"hash_b_{n}_{k}", torch.tensor(b), persistent=False
                )

        # Depthwise convolution over the concatenated branch outputs
        # (kernel=4, dilation=max_ngram); causality comes from left padding in forward().
        total_branch_dim = engram_dim * n_heads * (max_ngram - 1)
        self.branch_conv = nn.Conv1d(
            total_branch_dim,
            total_branch_dim,
            kernel_size=4,
            dilation=max_ngram,
            padding=0,
            groups=total_branch_dim,
            bias=True,
        )
        # Zero init: the block starts out injecting nothing into the residual stream.
        nn.init.zeros_(self.branch_conv.weight)
        nn.init.zeros_(self.branch_conv.bias)

        # Context-aware gating: hidden state as query, memory as key/value.
        self.gate_query = nn.Linear(dim, engram_dim, bias=False)
        self.gate_key = nn.Linear(total_branch_dim, engram_dim, bias=False)
        self.gate_value = nn.Linear(total_branch_dim, dim, bias=False)
        self.gate_scale = engram_dim**-0.5

    def _hash_ngram(self, token_ids: torch.Tensor, n: int, k: int) -> torch.Tensor:
        """Hash n-gram token sequences into table indices.

        Args:
            token_ids: (B, T) token IDs
            n: n-gram order (2 = bigram, 3 = trigram)
            k: hash head index
        Returns:
            indices: (B, T) integer indices in ``[0, table_size)``; the n-gram at
            position t is ``token_ids[t-n+1 .. t]``, with id 0 before the start.
        """
        a = getattr(self, f"hash_a_{n}_{k}")
        b = getattr(self, f"hash_b_{n}_{k}")
        B, T = token_ids.shape

        # Pad left with id 0 so every position has a full n-gram.
        padded = F.pad(token_ids, (n - 1, 0), value=0)  # (B, T+n-1)

        # Polynomial rolling hash (base 31) over the n-gram, in int64.
        combined = torch.zeros(B, T, dtype=torch.long, device=token_ids.device)
        for i in range(n):
            combined = combined * 31 + padded[:, i : i + T].long()

        # NOTE(review): `a * combined` is int64 and may wrap for very large
        # vocabularies/orders; the result is still a valid index after `%`.
        indices = ((a * combined) ^ b) % self.table_size
        return indices

    def forward(
        self, hidden: torch.Tensor, token_ids: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Forward pass.

        Args:
            hidden: (B, T, dim) current hidden state
            token_ids: (B, T) input token IDs for n-gram hashing. If None,
                pseudo-ids are derived from the hidden state as
                ``hidden.mean(dim=-1).long() % table_size``.
        Returns:
            output: (B, T, dim) memory injection for the residual stream
        """
        B, T, _ = hidden.shape

        if token_ids is None:
            # Fallback: derive pseudo-token-ids from the hidden state.
            token_ids = hidden.mean(dim=-1).long() % self.table_size

        # Retrieve and concatenate across n-gram orders and hash heads.
        branch_outputs = []
        for n in range(2, self.max_ngram + 1):
            for k in range(self.n_heads):
                indices = self._hash_ngram(token_ids, n, k)  # (B, T)
                table = self.embeddings[f"{n}_{k}"]  # (table_size, engram_dim)
                retrieved = table[indices]  # (B, T, engram_dim)
                branch_outputs.append(retrieved)

        # (B, T, engram_dim * n_heads * (max_ngram - 1))
        memory = torch.cat(branch_outputs, dim=-1)

        # Causal convolution over the sequence dimension: left-pad by
        # (kernel_size - 1) * dilation so output t only sees inputs <= t and the
        # output length stays T.
        conv_in = memory.transpose(1, 2)  # (B, C, T)
        conv_in = F.pad(
            conv_in,
            ((self.branch_conv.kernel_size[0] - 1) * self.branch_conv.dilation[0], 0),
        )
        conv_out = self.branch_conv(conv_in)  # (B, C, T)
        memory = conv_out.transpose(1, 2)  # (B, T, C)

        # Context-aware gating: scaled dot product of query and key -> sigmoid.
        query = self.gate_query(hidden)  # (B, T, engram_dim)
        key = self.gate_key(memory)  # (B, T, engram_dim)
        gate = torch.sigmoid(
            (query * key).sum(dim=-1, keepdim=True) * self.gate_scale
        )  # (B, T, 1)
        value = self.gate_value(memory)  # (B, T, dim)

        return gate * value


class SleepGate(nn.Module):
    """Fixed-capacity ring-buffer memory of sequence summaries, read via attention.

    ``write`` stores one bfloat16 summary vector per batch element (the mean of
    the last up-to-16 positions). ``read`` lets every position attend over the
    stored vectors with ``n_heads`` heads. When retention is enabled, each
    entry's attention weight is further multiplied by ``beta ** (global_step -
    age)`` and renormalised, where ``beta`` comes from a learned retention gate.
    ``global_step`` is not advanced inside this class — presumably the caller
    updates it; confirm.
    """

    def __init__(
        self,
        dim: int,
        cap: int = 128,
        n_heads: int = 4,
        retention_enabled: bool = True,
        retention_hidden: int = 0,
    ) -> None:
        super().__init__()
        self.dim = dim
        self.cap = cap
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.scale = self.head_dim ** -0.5
        self.retention_enabled = retention_enabled

        # Ring buffer state (buffers, so it moves with the module and is saved).
        self.register_buffer("mem_emb", torch.zeros(cap, dim, dtype=torch.bfloat16))
        self.register_buffer("mem_age", torch.zeros(cap, dtype=torch.long))
        self.register_buffer("mem_beta", torch.ones(cap, dtype=torch.float32))
        self.register_buffer("mem_count", torch.zeros((), dtype=torch.long))
        self.register_buffer("mem_head", torch.zeros((), dtype=torch.long))
        self.register_buffer("global_step", torch.zeros((), dtype=torch.long))

        self.q_proj = nn.Linear(dim, dim, bias=False)
        self.k_proj = nn.Linear(dim, dim, bias=False)
        self.v_proj = nn.Linear(dim, dim, bias=False)
        self.o_proj = nn.Linear(dim, dim, bias=False)
        # Zero output projection: reads contribute nothing until trained.
        nn.init.zeros_(self.o_proj.weight)
        self.gate_scale = nn.Parameter(torch.zeros(()))

        if retention_enabled:
            # Bias 2.2 -> initial beta = sigmoid(2.2) ~= 0.9.
            if retention_hidden > 0:
                self.retention_gate: Optional[nn.Module] = nn.Sequential(
                    nn.Linear(dim, retention_hidden, bias=False),
                    nn.GELU(),
                    nn.Linear(retention_hidden, 1, bias=True),
                )
                nn.init.constant_(self.retention_gate[-1].bias, 2.2)
            else:
                self.retention_gate = nn.Linear(dim, 1, bias=True)
                nn.init.constant_(self.retention_gate.bias, 2.2)
        else:
            self.retention_gate = None

        # Differentiable betas from the most recent training-mode write (else None).
        self._last_beta: Optional[torch.Tensor] = None

    def write(self, hidden: torch.Tensor) -> None:
        """Store one summary per batch element of ``hidden`` (assumed ``(B, T, dim)``).

        Entries overwrite the oldest slots once the buffer is full. With
        ``cap == 0`` nothing is stored (``_last_beta`` is still updated).
        """
        B, T, _ = hidden.shape
        tail_full = hidden[:, max(0, T - 16):, :].float().mean(dim=1)
        if self.retention_gate is not None:
            beta_live = torch.sigmoid(self.retention_gate(tail_full).squeeze(-1))
            self._last_beta = beta_live if self.training else None
            beta_store = beta_live.detach().float()
        else:
            self._last_beta = None
            beta_store = torch.ones(B, device=hidden.device, dtype=torch.float32)
        if self.cap <= 0:
            # Zero-capacity memory: indexing the empty buffer (and `% self.cap`)
            # would raise, so there is nothing to store.
            return
        tail = tail_full.to(self.mem_emb.dtype).detach()
        with torch.no_grad():
            head = int(self.mem_head.item())
            count = int(self.mem_count.item())
            step = int(self.global_step.item())
            for b in range(B):
                self.mem_emb[head] = tail[b]
                self.mem_age[head] = step
                self.mem_beta[head] = beta_store[b]
                head = (head + 1) % self.cap
                if count < self.cap:
                    count += 1
            self.mem_head.fill_(head)
            self.mem_count.fill_(count)

    def read(self, x: torch.Tensor) -> torch.Tensor:
        """Attend from ``x`` (assumed ``(B, T, dim)``) over stored entries.

        Returns zeros shaped like ``x`` when the memory is empty; otherwise the
        attention output through ``o_proj`` scaled by ``sigmoid(gate_scale)``.
        """
        count = int(self.mem_count.item())
        if count == 0:
            return torch.zeros_like(x)
        B, T, D = x.shape
        mem = self.mem_emb[:count].clone().to(x.dtype)
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(mem).view(count, self.n_heads, self.head_dim).transpose(0, 1)
        v = self.v_proj(mem).view(count, self.n_heads, self.head_dim).transpose(0, 1)
        attn = torch.einsum("bhtd,hmd->bhtm", q, k) * self.scale
        attn = F.softmax(attn, dim=-1)
        if self.retention_enabled:
            # Decay older entries by beta ** (steps since written), then renormalise.
            step = int(self.global_step.item())
            ages = self.mem_age[:count].to(x.device)
            delta = (step - ages).clamp(min=0).to(x.dtype)
            betas = self.mem_beta[:count].to(x.dtype).clamp(min=1e-6, max=1.0)
            weights = betas.pow(delta)
            attn = attn * weights.view(1, 1, 1, count)
            attn = attn / attn.sum(dim=-1, keepdim=True).clamp_min(1e-9)
        out = torch.einsum("bhtm,hmd->bhtd", attn, v)
        out = out.transpose(1, 2).contiguous().view(B, T, D)
        out = self.o_proj(out)
        return torch.sigmoid(self.gate_scale) * out

    @torch.no_grad()
    def reset(self) -> None:
        """Clear all stored entries and the step counter."""
        self.mem_emb.zero_()
        self.mem_age.zero_()
        self.mem_beta.fill_(1.0)
        self.mem_count.zero_()
        self.mem_head.zero_()
        self.global_step.zero_()
        self._last_beta = None


def _sinkhorn_knopp(logits: torch.Tensor, n_iters: int = 7) -> torch.Tensor:
    M = torch.exp(logits.clamp(-10, 10))
    for _ in range(n_iters):
        M = M / M.sum(dim=-1, keepdim=True).clamp(min=1e-10)
        M = M / M.sum(dim=-2, keepdim=True).clamp(min=1e-10)
    return M


class ManifoldHyperConnection(nn.Module):
    """Multi-stream residual connection with learned, per-token mixing weights.

    The ``dim``-wide residual is widened to ``expansion`` parallel streams,
    stored flattened as ``(B, T, expansion * dim)``. ``_compute_mappings``
    yields, per token:

    * ``H_pre``  ``(B, T, 1, n)`` in (0, 1): mixes the streams into the layer input;
    * ``H_post`` ``(B, T, 1, n)`` in (0, 2): scatters the layer output back;
    * ``H_res``  ``(B, T, n, n)``: Sinkhorn-normalised stream-to-stream mixing.

    ``alpha_*`` start at 0, so initially the mappings depend only on the biases.
    """

    def __init__(self, dim: int, expansion: int = 2) -> None:
        super().__init__()
        self.dim = dim
        self.expansion = expansion
        streams = expansion
        flat_width = streams * dim

        # Descriptive labels only; not read anywhere in this class.
        self.expand_fn = "duplicate"
        self.collapse_fn = "mean"

        self.bias_pre = nn.Parameter(torch.zeros(1, streams))
        self.bias_post = nn.Parameter(torch.zeros(1, streams))
        self.bias_res = nn.Parameter(torch.zeros(streams, streams))

        self.theta_pre = nn.Linear(flat_width, streams, bias=False)
        self.theta_post = nn.Linear(flat_width, streams, bias=False)
        self.theta_res = nn.Linear(flat_width, streams * streams, bias=False)

        self.alpha_pre = nn.Parameter(torch.tensor(0.0))
        self.alpha_post = nn.Parameter(torch.tensor(0.0))
        self.alpha_res = nn.Parameter(torch.tensor(0.0))

    def _compute_mappings(
        self, x_expanded: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(H_pre, H_post, H_res)`` for the widened stream ``x_expanded``."""
        batch, seq, width = x_expanded.shape
        streams = self.expansion

        normed = F.rms_norm(x_expanded, [width])

        pre_logits = self.alpha_pre * torch.tanh(self.theta_pre(normed)) + self.bias_pre
        post_logits = (
            self.alpha_post * torch.tanh(self.theta_post(normed)) + self.bias_post
        )
        res_logits = self.alpha_res * self.theta_res(normed) + self.bias_res.reshape(
            1, 1, -1
        )

        h_pre = torch.sigmoid(pre_logits).unsqueeze(-2)
        h_post = (2.0 * torch.sigmoid(post_logits)).unsqueeze(-2)
        h_res = _sinkhorn_knopp(res_logits.reshape(batch, seq, streams, streams))
        return h_pre, h_post, h_res

    def expand_stream(self, x: torch.Tensor) -> torch.Tensor:
        """Duplicate ``(B, T, dim)`` into ``expansion`` copies: ``(B, T, expansion * dim)``."""
        return x.repeat(1, 1, self.expansion)

    def collapse_stream(self, x_expanded: torch.Tensor) -> torch.Tensor:
        """Average the streams of ``(B, T, expansion * dim)`` back to ``(B, T, dim)``."""
        batch, seq, _ = x_expanded.shape
        per_stream = x_expanded.view(batch, seq, self.expansion, self.dim)
        return per_stream.mean(dim=-2)

    def pre_mix(self, x_expanded: torch.Tensor, H_pre: torch.Tensor) -> torch.Tensor:
        """Weighted sum of the streams with ``H_pre`` ``(B, T, 1, n)`` -> ``(B, T, dim)``."""
        batch, seq, _ = x_expanded.shape
        per_stream = x_expanded.view(batch, seq, self.expansion, self.dim)
        return torch.matmul(H_pre, per_stream).squeeze(-2)

    def post_res_mix(
        self,
        layer_output: torch.Tensor,
        x_expanded: torch.Tensor,
        H_post: torch.Tensor,
        H_res: torch.Tensor,
    ) -> torch.Tensor:
        """New widened stream: ``H_res @ streams + H_post^T @ layer_output``."""
        batch, seq, _ = x_expanded.shape
        streams, width = self.expansion, self.dim

        per_stream = x_expanded.view(batch, seq, streams, width)
        residual = H_res @ per_stream
        injected = H_post.transpose(-2, -1) @ layer_output.unsqueeze(-2)
        return (residual + injected).reshape(batch, seq, streams * width)


class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: attention, optional Engram memory, SwiGLU FFN.

    With ``mhc_expansion > 1`` the attention and FFN residual connections are
    replaced by ``ManifoldHyperConnection`` multi-stream mixing; otherwise plain
    residual additions are used. Engram (when ``engram_dim > 0``) is applied
    between attention and the FFN.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        head_dim: int,
        ffn_dim: int,
        dropout: float,
        sliding_window: int,
        rope_fraction: float,
        engram_dim: int = 0,
        engram_heads: int = 4,
        engram_table_size: int = 8192,
        engram_max_ngram: int = 3,
        mhc_expansion: int = 1,
    ) -> None:
        super().__init__()
        # Sub-module creation order is kept fixed (state-dict order / RNG draws).
        self.norm1 = RMSNorm(dim)
        self.attn = CausalSelfAttention(
            dim=dim,
            n_heads=n_heads,
            n_kv_heads=n_kv_heads,
            head_dim=head_dim,
            dropout=dropout,
            sliding_window=sliding_window,
            rope_fraction=rope_fraction,
        )
        self.norm2 = RMSNorm(dim)
        self.ffn = SwiGLU(dim, ffn_dim, dropout)

        self.use_engram = engram_dim > 0
        if self.use_engram:
            self.engram = EngramBlock(
                dim=dim,
                engram_dim=engram_dim,
                n_heads=engram_heads,
                table_size=engram_table_size,
                max_ngram=engram_max_ngram,
            )
            self.engram_norm = RMSNorm(dim)

        self.use_mhc = mhc_expansion > 1
        if self.use_mhc:
            self.mhc_attn = ManifoldHyperConnection(dim, expansion=mhc_expansion)
            self.mhc_ffn = ManifoldHyperConnection(dim, expansion=mhc_expansion)

    def forward(
        self,
        x: torch.Tensor,
        is_global: bool,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
        token_ids: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Run the layer; returns the new hidden state and the attention KV cache."""
        runner = self._forward_hyper if self.use_mhc else self._forward_residual
        return runner(x, is_global, past_kv, use_cache, token_ids)

    def _forward_residual(
        self,
        x: torch.Tensor,
        is_global: bool,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]],
        use_cache: bool,
        token_ids: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Single-stream path with ordinary residual additions."""
        attn_out, new_kv = self.attn(self.norm1(x), is_global, past_kv, use_cache)
        hidden = x + attn_out
        if self.use_engram:
            hidden = hidden + self.engram(self.engram_norm(hidden), token_ids=token_ids)
        hidden = hidden + self.ffn(self.norm2(hidden))
        return hidden, new_kv

    def _forward_hyper(
        self,
        x: torch.Tensor,
        is_global: bool,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]],
        use_cache: bool,
        token_ids: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Multi-stream path using the manifold hyper-connections."""
        attn_hc = self.mhc_attn
        ffn_hc = self.mhc_ffn

        wide = attn_hc.expand_stream(x)
        mix_in, mix_out, mix_res = attn_hc._compute_mappings(wide)
        attn_input = self.norm1(attn_hc.pre_mix(wide, mix_in))
        attn_out, new_kv = self.attn(attn_input, is_global, past_kv, use_cache)
        wide = attn_hc.post_res_mix(attn_out, wide, mix_out, mix_res)

        if self.use_engram:
            # Engram works on the collapsed single stream, which is then re-widened.
            merged = attn_hc.collapse_stream(wide)
            merged = merged + self.engram(self.engram_norm(merged), token_ids=token_ids)
            wide = attn_hc.expand_stream(merged)

        mix_in, mix_out, mix_res = ffn_hc._compute_mappings(wide)
        ffn_out = self.ffn(self.norm2(ffn_hc.pre_mix(wide, mix_in)))
        wide = ffn_hc.post_res_mix(ffn_out, wide, mix_out, mix_res)
        return attn_hc.collapse_stream(wide), new_kv


class RecurrentDepthBlock(nn.Module):
    """Recurrent-depth core: one shared ``TransformerBlock`` unrolled several times.

    Every iteration re-injects the encoded input ``e`` and adds a
    per-iteration LoRA correction.  ``AdaptiveHalting`` yields a per-position
    halting probability; once a position's accumulated probability reaches
    ``act_threshold`` it stops updating.  The output is the halting-weighted
    sum of the per-iteration states, plus any leftover probability mass
    times the last state (adaptive-computation-time style).
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        head_dim: int,
        ffn_dim: int,
        dropout: float,
        sliding_window: int,
        rope_fraction: float,
        n_loops: int,
        act_threshold: float,
        lora_rank: int,
        loop_embed_dim: int,
    ) -> None:
        super().__init__()
        # Default unroll count; at least one iteration always runs.
        self.n_loops = max(1, n_loops)
        self.act_threshold = act_threshold
        self.loop_embed_dim = max(0, loop_embed_dim)
        self.norm = RMSNorm(dim)
        # Single block shared by every iteration; engram and mHC are disabled here.
        self.block = TransformerBlock(
            dim=dim, n_heads=n_heads, n_kv_heads=n_kv_heads, head_dim=head_dim,
            ffn_dim=ffn_dim, dropout=dropout, sliding_window=sliding_window,
            rope_fraction=rope_fraction, engram_dim=0, mhc_expansion=1,
        )
        # Combines (previous state, encoded input, block output) into the next state.
        self.injection = StableRecurrentInjection(dim)
        # Halting-probability head evaluated on each new state.
        self.act = AdaptiveHalting(dim)
        # Iteration-indexed low-rank correction; presumably one slot per default loop.
        self.lora = DepthLoRAAdapter(dim, lora_rank, self.n_loops)

    def forward(
        self,
        h: torch.Tensor,
        e: torch.Tensor,
        token_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[Optional[Tuple[torch.Tensor, torch.Tensor]]]] = None,
        use_cache: bool = False,
        n_loops: Optional[int] = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
        """Unroll the shared block with adaptive halting.

        Args:
            h: initial hidden state, unpacked as ``(B, T, _)``.
            e: encoded input added to the state (before norm) and passed to the
                injection at every iteration; presumably same shape as ``h``.
            token_ids: forwarded to the shared block.
            past_key_values: per-iteration KV caches; entry ``t`` feeds iteration ``t``.
            use_cache: collect one KV entry per iteration.  Early exit is disabled
                in this mode, so exactly ``loops`` entries are returned.
            n_loops: overrides ``self.n_loops`` for this call (falsy -> default; min 1).

        Returns:
            ``(output, aux, new_past)``: the halting-weighted hidden state, a dict
            with ``"recurrent_halt_mean"`` (mean halting probability of the last
            executed iteration) and the KV list (``None`` unless ``use_cache``).
        """
        loops = max(1, n_loops or self.n_loops)
        B, T, _ = h.shape
        # Per-position halting state.
        halted = torch.zeros(B, T, device=h.device, dtype=torch.bool)
        cumulative_p = torch.zeros(B, T, device=h.device, dtype=h.dtype)
        output = torch.zeros_like(h)
        new_past: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
        current = h
        final_halt = None

        for t in range(loops):
            # Presumably mixes an iteration-index embedding into the state.
            h_loop = loop_index_embedding(current, t, self.loop_embed_dim)
            combined = self.norm(h_loop + e)
            past_kv = None
            if past_key_values is not None and t < len(past_key_values):
                past_kv = past_key_values[t]
            trans_out, layer_kv = self.block(combined, is_global=True, past_kv=past_kv, use_cache=use_cache, token_ids=token_ids)
            trans_out = trans_out + self.lora(trans_out, t)
            next_h = self.injection(current, e, trans_out)
            p = self.act(next_h)
            # Positions that already halted contribute no further probability.
            p = p * (~halted).to(p.dtype)
            final_halt = p
            # Positions crossing the threshold now halt and take the remaining mass;
            # the others are weighted by their raw halting probability.
            should_halt = (~halted) & ((cumulative_p + p) >= self.act_threshold)
            update_weight = torch.where(should_halt, (1.0 - cumulative_p).clamp(min=0.0), p)
            output = output + next_h * update_weight.unsqueeze(-1)
            cumulative_p = cumulative_p + update_weight
            # Uses the halted mask from BEFORE this step: freezes earlier-halted
            # positions, while positions halting now still advance to next_h.
            current = torch.where(halted.unsqueeze(-1), current, next_h)
            halted = halted | should_halt
            if new_past is not None:
                new_past.append(layer_kv)
            # Early exit only without a cache, keeping the cache length fixed.
            if not use_cache and bool(halted.all()):
                break

        # Any unassigned probability mass goes to the final state.
        remainder = (1.0 - cumulative_p).clamp(min=0.0)
        output = output + current * remainder.unsqueeze(-1)
        aux: Dict[str, torch.Tensor] = {}
        if final_halt is not None:
            aux["recurrent_halt_mean"] = final_halt.mean()
        return output, aux, new_past


class TinyMemoryLM(nn.Module):
    """Decoder-only language model with a weight-tied output head.

    Two body layouts are supported (chosen by ``recurrent_loops``):

    * flat: a small set of unique ``TransformerBlock`` s replayed to form the
      logical layer stack;
    * recurrent-depth: prelude blocks -> ``RecurrentDepthBlock`` -> coda blocks.

    Optional extras: engram / mHC inside blocks, a ``SleepGate`` memory read
    after the final norm, latent "think" blocks, and multi-token-prediction
    (MTP) heads that are only evaluated in training mode.
    """

    def __init__(
        self,
        vocab_size: int,
        dim: int,
        n_unique_layers: int,
        n_logical_layers: int,
        n_heads: int,
        n_kv_heads: int,
        ffn_dim: int,
        dropout: float,
        mtp_horizons: Sequence[int],
        grad_checkpoint: bool,
        sliding_window: int = 512,
        rope_fraction: float = 0.5,
        embed_scale: bool = True,
        engram_dim: int = 0,
        engram_heads: int = 4,
        engram_table_size: int = 8192,
        engram_max_ngram: int = 3,
        mhc_expansion: int = 1,
        sleep_gate_cap: int = 0,
        sleep_gate_heads: int = 4,
        sleep_retention_enabled: bool = True,
        sleep_retention_hidden: int = 0,
        latent_think_layers: int = 0,
        prelude_layers: int = 0,
        coda_layers: int = 0,
        recurrent_loops: int = 0,
        recurrent_act_threshold: float = 0.99,
        recurrent_lora_rank: int = 0,
        recurrent_loop_embed_dim: int = 0,
    ) -> None:
        """Build the model.

        * ``recurrent_loops > 0``: ``prelude_layers`` blocks, a
          ``RecurrentDepthBlock`` with ``recurrent_loops`` default iterations,
          then ``coda_layers`` blocks.  No flat ``blocks`` are created.
        * otherwise: ``max(1, n_unique_layers)`` blocks replayed up to
          ``n_logical_layers`` times (see ``_build_logical_layers``).

        ``mtp_horizons`` values <= 1 are dropped; ``sleep_gate_cap > 0`` enables
        the SleepGate; ``latent_think_layers > 0`` adds post-norm think blocks.
        """
        super().__init__()
        self.dim = dim
        self.n_unique_layers = n_unique_layers
        self.n_logical_layers = n_logical_layers
        self.grad_checkpoint = grad_checkpoint
        # Input embeddings are multiplied by sqrt(dim); logits are divided by it.
        self.embed_scale_factor = math.sqrt(dim) if embed_scale else 1.0
        head_dim = dim // n_heads

        self.embed_tokens = nn.Embedding(vocab_size, dim)
        self.head = nn.Linear(dim, vocab_size, bias=False)
        # Tie the output projection to the input embedding table.
        self.head.weight = self.embed_tokens.weight
        self.output_bias = nn.Parameter(torch.zeros(vocab_size))

        self.use_recurrent_depth = recurrent_loops > 0
        self.prelude_layers = max(0, prelude_layers)
        self.coda_layers = max(0, coda_layers)
        self.recurrent_loops = max(0, recurrent_loops)

        self.blocks: Optional[nn.ModuleList] = None
        self.prelude: Optional[nn.ModuleList] = None
        self.recurrent: Optional[RecurrentDepthBlock] = None
        self.coda: Optional[nn.ModuleList] = None

        # Factory for n identically-configured blocks (engram/mHC as configured).
        def _make_blocks(n: int) -> nn.ModuleList:
            return nn.ModuleList([
                TransformerBlock(
                    dim=dim, n_heads=n_heads, n_kv_heads=n_kv_heads, head_dim=head_dim,
                    ffn_dim=ffn_dim, dropout=dropout, sliding_window=sliding_window,
                    rope_fraction=rope_fraction, engram_dim=engram_dim,
                    engram_heads=engram_heads, engram_table_size=engram_table_size,
                    engram_max_ngram=engram_max_ngram, mhc_expansion=mhc_expansion,
                )
                for _ in range(n)
            ])

        if self.use_recurrent_depth:
            if self.prelude_layers > 0:
                self.prelude = _make_blocks(self.prelude_layers)
            self.recurrent = RecurrentDepthBlock(
                dim=dim, n_heads=n_heads, n_kv_heads=n_kv_heads, head_dim=head_dim,
                ffn_dim=ffn_dim, dropout=dropout, sliding_window=sliding_window,
                rope_fraction=rope_fraction, n_loops=self.recurrent_loops,
                act_threshold=recurrent_act_threshold, lora_rank=recurrent_lora_rank,
                loop_embed_dim=recurrent_loop_embed_dim or max(2, dim // 8),
            )
            if self.coda_layers > 0:
                self.coda = _make_blocks(self.coda_layers)
        else:
            self.blocks = _make_blocks(max(1, n_unique_layers))

        self.norm = RMSNorm(dim)

        # MTP: per-horizon adapter + norm feeding the shared (tied) head.
        self.mtp_horizons = sorted({int(h) for h in mtp_horizons if int(h) > 1})
        self.mtp_adapters = nn.ModuleDict(
            {str(h): nn.Linear(dim, dim, bias=False) for h in self.mtp_horizons}
        )
        self.mtp_norms = nn.ModuleDict(
            {str(h): RMSNorm(dim) for h in self.mtp_horizons}
        )

        # Shrink residual-branch output projections by (2 * n_logical_layers) ** -0.5
        # (presumably GPT-2-style depth scaling of the init).
        res_scale = (2 * max(1, n_logical_layers)) ** -0.5
        for group in (self.blocks, self.prelude, self.coda):
            if group is None:
                continue
            for block in group:
                block.attn.wo.weight.data.mul_(res_scale)
                block.ffn.down.weight.data.mul_(res_scale)
        if self.recurrent is not None:
            self.recurrent.block.attn.wo.weight.data.mul_(res_scale)
            self.recurrent.block.ffn.down.weight.data.mul_(res_scale)

        self.sleep_gate: Optional[SleepGate] = None
        if sleep_gate_cap > 0:
            self.sleep_gate = SleepGate(
                dim=dim, cap=sleep_gate_cap, n_heads=sleep_gate_heads,
                retention_enabled=sleep_retention_enabled,
                retention_hidden=sleep_retention_hidden,
            )

        # Optional "latent think" blocks applied after the final norm (no engram/mHC).
        self.think_blocks: Optional[nn.ModuleList] = None
        self.think_norm: Optional[RMSNorm] = None
        if latent_think_layers > 0:
            self.think_blocks = nn.ModuleList([
                TransformerBlock(
                    dim=dim, n_heads=n_heads, n_kv_heads=n_kv_heads, head_dim=head_dim,
                    ffn_dim=ffn_dim, dropout=0.0, sliding_window=2048,
                    rope_fraction=rope_fraction, engram_dim=0, mhc_expansion=1,
                )
                for _ in range(latent_think_layers)
            ])
            self.think_norm = RMSNorm(dim)

    def resize_token_embeddings(self, new_vocab_size: int) -> None:
        """Resize the tied embedding/head and the output bias to ``new_vocab_size``.

        The first ``min(old, new)`` rows of the embedding and bias are kept;
        new embedding rows use ``nn.Embedding``'s default init and new bias
        entries are zero.  No-op when the size is unchanged.
        """
        old_vocab_size = self.embed_tokens.num_embeddings
        if new_vocab_size == old_vocab_size:
            return
        device = self.embed_tokens.weight.device
        old_embed_weight = self.embed_tokens.weight.data.clone()
        self.embed_tokens = nn.Embedding(new_vocab_size, self.embed_tokens.embedding_dim).to(device)
        self.head = nn.Linear(self.embed_tokens.embedding_dim, new_vocab_size, bias=False).to(device)
        # Re-tie the head to the new embedding table.
        self.head.weight = self.embed_tokens.weight
        old_bias = self.output_bias.data.clone()
        self.output_bias = nn.Parameter(torch.zeros(new_vocab_size, device=device))
        copy_size = min(old_vocab_size, new_vocab_size)
        self.output_bias.data[:copy_size] = old_bias[:copy_size]
        self.embed_tokens.weight.data[:copy_size] = old_embed_weight[:copy_size]

    def _build_logical_layers(self) -> List[Tuple[nn.Module, int]]:
        """Return ``(block, logical_index)`` pairs for the flat layout.

        The unique blocks are played twice in order, truncated to
        ``n_logical_layers``.  NOTE(review): when ``n_logical_layers`` exceeds
        ``2 * len(blocks)`` only ``2 * len(blocks)`` layers are produced.
        Empty for the recurrent-depth layout.
        """
        if self.blocks is None:
            return []
        blocks_list = list(self.blocks)
        full_sequence = blocks_list + blocks_list
        return [(block, i) for i, block in enumerate(full_sequence[: self.n_logical_layers])]

    def forward(
        self,
        ids: torch.Tensor,
        use_cache: bool = False,
        past_key_values: Optional[List[Optional[Tuple[torch.Tensor, torch.Tensor]]]] = None,
        return_hidden: bool = False,
    ) -> Tuple[torch.Tensor, Dict[int, torch.Tensor], Dict[str, torch.Tensor], Optional[torch.Tensor], Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
        """Run the LM over ``ids``.

        Args:
            ids: token ids, unpacked as ``(B, T)``.
            use_cache: collect per-layer KV entries from every body block.
            past_key_values: flat per-layer KV list in the order this method
                produces it (flat layout: one per logical layer; recurrent
                layout: prelude, then one per recurrent loop, then coda).
            return_hidden: also return the hidden state fed to the head.

        Returns:
            ``(logits, mtp, aux, h_out, new_past_key_values)``.  ``mtp`` maps
            horizon -> logits and is only filled in training mode; ``aux`` holds
            recurrent-depth statistics; ``h_out`` is None unless
            ``return_hidden``; the KV list is None unless ``use_cache``.
        """
        B, T = ids.shape
        x = self.embed_tokens(ids) * self.embed_scale_factor
        new_past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
        aux: Dict[str, torch.Tensor] = {}

        if self.use_recurrent_depth:
            # `offset` walks the flat KV list: prelude, recurrent loops, coda.
            offset = 0
            if self.prelude is not None:
                for block in self.prelude:
                    past_kv = past_key_values[offset] if past_key_values is not None and offset < len(past_key_values) else None
                    x, layer_kv = block(x, is_global=True, past_kv=past_kv, use_cache=use_cache, token_ids=ids)
                    if new_past_key_values is not None:
                        new_past_key_values.append(layer_kv)
                    offset += 1
            # The prelude output is both the initial state and the re-injected input.
            encoded = x
            recurrent_past = past_key_values[offset: offset + self.recurrent_loops] if past_key_values is not None else None
            x, recurrent_aux, recurrent_kv = self.recurrent(
                x, encoded, token_ids=ids, past_key_values=recurrent_past, use_cache=use_cache,
            )
            aux.update(recurrent_aux)
            if new_past_key_values is not None and recurrent_kv is not None:
                new_past_key_values.extend(recurrent_kv)
            offset += self.recurrent_loops
            if self.coda is not None:
                for block in self.coda:
                    past_kv = past_key_values[offset] if past_key_values is not None and offset < len(past_key_values) else None
                    x, layer_kv = block(x, is_global=True, past_kv=past_kv, use_cache=use_cache, token_ids=ids)
                    if new_past_key_values is not None:
                        new_past_key_values.append(layer_kv)
                    offset += 1
        else:
            logical_layers = self._build_logical_layers()
            last_logical_idx = len(logical_layers) - 1
            for layer_idx, (block, logical_idx) in enumerate(logical_layers):
                # Even logical layers and the final layer are global; others are not.
                is_global = logical_idx % 2 == 0 or layer_idx == last_logical_idx
                past_kv = past_key_values[layer_idx] if past_key_values is not None and layer_idx < len(past_key_values) else None
                if self.grad_checkpoint and self.training and not use_cache:
                    x, layer_kv = checkpoint(block, x, is_global, past_kv, use_cache, ids, use_reentrant=True)
                else:
                    x, layer_kv = block(x, is_global, past_kv, use_cache, ids)
                if new_past_key_values is not None:
                    new_past_key_values.append(layer_kv)

        x = self.norm(x)

        if self.sleep_gate is not None:
            x = x + self.sleep_gate.read(x)
            # Memory is only written while training.
            if self.training:
                self.sleep_gate.write(x)

        if self.think_blocks is not None:
            # NOTE(review): think blocks receive no past_kv and their KV is not
            # returned, so with use_cache + past_key_values they only see the
            # current chunk — confirm this is intended.
            for think_block in self.think_blocks:
                x, _ = think_block(x, is_global=True)
            x = self.think_norm(x)

        h_out = x if return_hidden else None
        logits = self.head(x)
        if self.embed_scale_factor != 1.0:
            logits = logits / self.embed_scale_factor
        logits = logits + self.output_bias

        mtp: Dict[int, torch.Tensor] = {}
        if self.mtp_horizons and self.training:
            for horizon in self.mtp_horizons:
                if horizon > 1 and horizon <= T - 1:
                    # Drop the last `horizon` positions, which have no target that far ahead.
                    shifted_h = x[:, :-horizon, :]
                    adapted_h = self.mtp_adapters[str(horizon)](shifted_h)
                    adapted_h = self.mtp_norms[str(horizon)](adapted_h)
                    mtp_logits = self.head(adapted_h)
                    if self.embed_scale_factor != 1.0:
                        mtp_logits = mtp_logits / self.embed_scale_factor
                    mtp_logits = mtp_logits + self.output_bias
                    mtp[horizon] = mtp_logits

        return logits, mtp, aux, h_out, new_past_key_values


# ---------------------------------------------------------------------------
# Generation
# ---------------------------------------------------------------------------


def build_stop_token_ids(tokenizer: WordTokenizer) -> set:
    """Collect the token ids that should end generation.

    Always includes ``tokenizer.eos_id``; the chat role markers
    ``<|user|>``, ``<|system|>`` and ``<|assistant|>`` are added when the
    tokenizer's vocabulary contains them.
    """
    role_markers = ("<|user|>", "<|system|>", "<|assistant|>")
    lookups = (tokenizer.token_to_id.get(marker) for marker in role_markers)
    marker_ids = {int(tid) for tid in lookups if tid is not None}
    return {tokenizer.eos_id} | marker_ids


def apply_no_repeat_ngram(
    logits: torch.Tensor,
    token_history: Sequence[int],
    ngram_size: int,
) -> torch.Tensor:
    """Forbid any token that would repeat an n-gram already in *token_history*.

    The last ``ngram_size - 1`` tokens form the current context.  Every earlier
    occurrence of that context bans the token that followed it by setting its
    logit to ``-inf`` in a copy of *logits*.  When nothing is banned (or
    ``ngram_size <= 1``, or the history is too short) *logits* itself is
    returned unchanged.
    """
    if ngram_size <= 1:
        return logits
    context_len = ngram_size - 1
    if len(token_history) < context_len:
        return logits
    history = tuple(token_history)
    tail = history[-context_len:]
    banned = {
        int(history[start + context_len])
        for start in range(len(history) - context_len)
        if history[start:start + context_len] == tail
    }
    if not banned:
        return logits
    masked = logits.clone()
    ban_index = torch.tensor(sorted(banned), device=logits.device, dtype=torch.long)
    masked[ban_index] = float("-inf")
    return masked


def apply_loop_penalty(
    logits: torch.Tensor,
    tokenizer: WordTokenizer,
    generated_text: str,
    penalty: float = 5.0,
) -> torch.Tensor:
    """Detect repeated substring loops and penalise continuation tokens.

    The windows 24, 16, 12 and 8 are tried in that order (longest first).  For
    the first window whose trailing text also appears earlier, the character
    that followed the earlier occurrence is looked up as a token and its logit
    is lowered by *penalty*.  Text shorter than 16 characters returns *logits*
    untouched; otherwise a (possibly unmodified) copy is returned.
    """
    if len(generated_text) < 16:
        return logits
    penalized = logits.clone()
    text_len = len(generated_text)
    for window in (24, 16, 12, 8):
        if text_len < 2 * window:
            continue
        tail = generated_text[-window:]
        earlier = generated_text[:-window].rfind(tail)
        if earlier < 0:
            continue
        follow_pos = earlier + window
        if follow_pos < text_len:
            token_id = tokenizer.token_to_id.get(generated_text[follow_pos])
            if token_id is not None:
                penalized[token_id] -= penalty
        break
    return penalized


def apply_min_p(logits: torch.Tensor, min_p: float) -> torch.Tensor:
    """Filter tokens below min_p fraction of the top token probability.

    Works on a 1-D logit vector or on any batch of logit rows: the threshold
    ``min_p * max(prob)`` is computed per row along the last dimension, so one
    confident row cannot wipe out the candidates of another.

    Args:
        logits: raw logits; the last dimension is the vocabulary.
        min_p: fraction of the row's top probability a token must reach to
            survive.  ``<= 0`` disables filtering.

    Returns:
        *logits* itself when disabled, otherwise a copy with filtered entries
        set to ``-inf``.
    """
    if min_p <= 0.0:
        return logits
    probs = torch.softmax(logits, dim=-1)
    # keepdim so the per-row threshold broadcasts back over the vocabulary.
    threshold = probs.amax(dim=-1, keepdim=True) * min_p
    out = logits.clone()
    out[probs < threshold] = float("-inf")
    return out


def generate(
    model: TinyMemoryLM,
    tokenizer: WordTokenizer,
    prompt: str,
    max_new_tokens: int = 256,
    temperature: float = 0.8,
    top_k: int = 16,
    top_p: float = 0.95,
    repetition_penalty: float = 1.0,
    device: str = "cuda",
    sft_mode: bool = True,
    stream: bool = True,
    no_repeat_ngram_size: int = 0,
    context_window: int = 2048,
    logit_soft_cap: float = 15.0,
    min_p: float = 0.05,
    loop_penalty: float = 5.0,
) -> str:
    """Sample a completion for *prompt* one token at a time.

    Each step re-runs the model over the (optionally truncated) context
    without a KV cache, then shapes the last-position logits in this order:
    soft-cap, repetition penalty, no-repeat n-gram ban, substring-loop
    penalty, temperature, min-p, top-k, top-p.  If every token ends up masked,
    the step falls back to the soft-capped logits with temperature only.

    Args:
        model: the language model; ``model(ids)[0]`` must be ``(1, T, vocab)`` logits.
        tokenizer: provides ``encode``, ``id_to_token``, ``special`` and the stop ids.
        prompt: user text; wrapped in the chat template when *sft_mode* is True.
        max_new_tokens: upper bound on sampled tokens.
        temperature: 0 selects greedy argmax; other values scale the logits.
        top_k: keep only the k highest logits (0 disables).
        top_p: nucleus sampling mass in (0, 1); other values disable it.
        repetition_penalty: CTRL-style penalty on already generated ids (1.0 disables).
        device: device for the input tensors.
        sft_mode: wrap the prompt as ``<|user|> ... <|assistant|>``.
        stream: print visible tokens as they are produced.
        no_repeat_ngram_size: ban repeating n-grams of generated ids (0 disables).
        context_window: number of trailing tokens fed to the model (<= 0: all).
        logit_soft_cap: ``cap * tanh(logits / cap)`` when > 0.
        min_p: drop tokens below ``min_p`` times the top probability.
        loop_penalty: logit penalty applied by the substring-loop detector.

    Returns:
        The concatenation of the generated non-special tokens.
    """
    if sft_mode:
        full_prompt = f"<|user|>\n{prompt}\n<|assistant|>\n"
    else:
        full_prompt = prompt
    input_ids = tokenizer.encode(full_prompt, add_bos=True, add_eos=False)
    input_ids_t = torch.tensor([input_ids], dtype=torch.long, device=device)
    visible_tokens: List[str] = []
    stop_token_ids = build_stop_token_ids(tokenizer)
    generated_text = ""
    # Generated ids only (not the prompt): used by repetition/n-gram penalties.
    generated_ids: List[int] = []

    with torch.no_grad():
        for _ in range(max_new_tokens):
            ctx_ids = (
                input_ids_t[:, -context_window:] if context_window > 0 else input_ids_t
            )
            logits, *_ = model(ctx_ids)
            next_logits = logits[0, -1, :].clone()

            # Logit soft-capping (Gemma-style) — prevents overconfident collapse
            if logit_soft_cap > 0:
                next_logits = logit_soft_cap * torch.tanh(next_logits / logit_soft_cap)

            raw_next_logits = next_logits.clone()

            # Repetition penalty on previously generated tokens (vectorised:
            # positive logits are divided, non-positive ones multiplied).
            if repetition_penalty != 1.0 and generated_ids:
                seen = torch.tensor(
                    sorted(set(generated_ids)), device=next_logits.device, dtype=torch.long
                )
                seen_logits = next_logits[seen]
                next_logits[seen] = torch.where(
                    seen_logits > 0,
                    seen_logits / repetition_penalty,
                    seen_logits * repetition_penalty,
                )

            # No-repeat n-gram blocking on generated tokens only
            if no_repeat_ngram_size > 0 and generated_ids:
                next_logits = apply_no_repeat_ngram(next_logits, generated_ids, no_repeat_ngram_size)

            # Substring loop detection
            next_logits = apply_loop_penalty(next_logits, tokenizer, generated_text, penalty=loop_penalty)

            # Temperature scaling
            if temperature != 1.0:
                next_logits = next_logits / max(temperature, 1e-6)

            # Min-p filtering — remove tokens below min_p * max_prob
            if min_p > 0:
                next_logits = apply_min_p(next_logits, min_p)

            # Top-k filtering
            if top_k > 0:
                v, _ = torch.topk(next_logits, min(top_k, next_logits.size(0)))
                next_logits[next_logits < v[-1]] = float("-inf")

            # Top-p (nucleus) filtering: keep the smallest prefix of the sorted
            # distribution whose mass reaches top_p.  A token is dropped only if
            # the mass strictly BEFORE it already reaches top_p, so the token
            # that crosses the threshold is kept.
            if 0.0 < top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
                sorted_probs = torch.softmax(sorted_logits, dim=-1)
                mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
                remove_mask = mass_before >= top_p
                remove_mask[0] = False
                indices_to_remove = sorted_indices[remove_mask]
                next_logits[indices_to_remove] = float("-inf")

            # Fallback if all tokens masked
            if not torch.isfinite(next_logits).any():
                next_logits = raw_next_logits
                if temperature != 1.0:
                    next_logits = next_logits / max(temperature, 1e-6)

            if temperature == 0:
                next_id = torch.argmax(next_logits).item()
            else:
                probs = torch.softmax(next_logits, dim=-1)
                next_id = torch.multinomial(probs, num_samples=1).item()
            if next_id in stop_token_ids:
                break
            token_str = (
                tokenizer.id_to_token[next_id]
                if next_id < len(tokenizer.id_to_token)
                else ""
            )
            generated_ids.append(next_id)
            if token_str not in tokenizer.special:
                visible_tokens.append(token_str)
                generated_text += token_str
                if stream:
                    print(token_str, end="", flush=True)
            input_ids_t = torch.cat(
                [input_ids_t, torch.tensor([[next_id]], device=device)], dim=1
            )
    if stream:
        print()
    return "".join(visible_tokens)


# ---------------------------------------------------------------------------
# Local model loading
# ---------------------------------------------------------------------------


def series_from_name(name: str) -> str | None:
    """Guess the model series from a run/model name (case-insensitive).

    Checks for "haiku", "sonnet" and "opus" in that priority order and
    returns the capitalised series name, or ``None`` if none matches
    (an empty or ``None`` name also yields ``None``).
    """
    lowered = (name or "").lower()
    candidates = (("haiku", "Haiku"), ("sonnet", "Sonnet"), ("opus", "Opus"))
    return next((label for needle, label in candidates if needle in lowered), None)


def series_config(series: str) -> dict[str, object]:
    """Return the ``MODEL_SERIES`` entry for *series* (case-insensitive).

    Unknown names fall back to the ``"sonnet"`` entry.  The returned dict is
    the shared table entry itself, not a copy.
    """
    default_entry = MODEL_SERIES["sonnet"]
    key = series.lower()
    return MODEL_SERIES[key] if key in MODEL_SERIES else default_entry


def _latest_step_checkpoint(run_dir: Path) -> Optional[Path]:
    """Return the ``checkpoint_step_<N>.pt`` in *run_dir* with the largest integer N.

    Files whose suffix after the last ``_`` is not an integer (for example
    ``checkpoint_step_final.pt``) are skipped rather than aborting the scan.
    Returns ``None`` when no numbered step checkpoint exists.
    """
    numbered: List[Tuple[int, Path]] = []
    for path in run_dir.glob("checkpoint_step_*.pt"):
        try:
            step = int(path.stem.rsplit("_", 1)[-1])
        except ValueError:
            continue
        numbered.append((step, path))
    if not numbered:
        return None
    # Stable sort + last element: same tie-breaking as sorting the paths by step.
    return sorted(numbered, key=lambda item: item[0])[-1][1]


def discover_models(runs_dir: Path) -> List[dict]:
    """List loadable checkpoints for every run directory under *runs_dir*.

    A run is any sub-directory containing ``tokenizer.json``.  One entry is
    produced per existing ``model.pt`` / ``model_rep.pt`` / ``pretrain.pt``; if
    none exist, the highest-numbered ``checkpoint_step_<N>.pt`` is used instead.
    The series is guessed from the weights of ``model.pt`` (or ``pretrain.pt``),
    then from the directory name, defaulting to ``"Sonnet"``.

    Returns:
        Dicts with keys ``name``, ``checkpoint``, ``series``, ``model_path`` and
        ``tokenizer_path``; empty when *runs_dir* is not a directory.
    """
    models: List[dict] = []
    if not runs_dir.is_dir():
        return models
    for child in sorted(runs_dir.iterdir()):
        if not child.is_dir():
            continue
        tokenizer_path = child / "tokenizer.json"
        if not tokenizer_path.exists():
            continue
        name = child.name
        series = None
        # Probe only the first existing of model.pt / pretrain.pt for the series.
        for ckpt_name in ("model.pt", "pretrain.pt"):
            ckpt_path = child / ckpt_name
            if ckpt_path.exists():
                series = _fast_series_from_checkpoint(ckpt_path)
                break
        if series is None:
            series = series_from_name(name) or "Sonnet"

        ckpt_paths = [
            child / ckpt_name
            for ckpt_name in ("model.pt", "model_rep.pt", "pretrain.pt")
            if (child / ckpt_name).exists()
        ]
        if not ckpt_paths:
            latest = _latest_step_checkpoint(child)
            if latest is not None:
                ckpt_paths.append(latest)
        for ckpt_path in ckpt_paths:
            models.append(
                {
                    "name": name,
                    "checkpoint": ckpt_path.name,
                    "series": series,
                    "model_path": ckpt_path,
                    "tokenizer_path": tokenizer_path,
                }
            )
    return models


def _detect_engram(state_dict):
    for key in state_dict:
        if ".engram." in key:
            if ".embeddings." in key:
                return state_dict[key].shape[-1]
    return 0


def _detect_mhc(state_dict):
    for key, val in state_dict.items():
        if ".mhc_attn.bias_pre" in key and val.dim() == 2:
            return val.shape[-1]  # (1, expansion)
    return 1


def _detect_sleep_gate(state_dict) -> Tuple[int, int]:
    for key, val in state_dict.items():
        if key == "sleep_gate.mem_emb" and val.dim() == 2:
            cap = val.shape[0]
            return cap, 4
    return 0, 4


def _detect_latent_think(state_dict) -> int:
    indices = {
        int(k.split(".")[1])
        for k in state_dict
        if k.startswith("think_blocks.") and k.split(".")[1].isdigit()
    }
    return max(indices) + 1 if indices else 0


def _detect_prelude_layers(state_dict) -> int:
    indices = {
        int(k.split(".")[1])
        for k in state_dict
        if k.startswith("prelude.") and k.split(".")[1].isdigit()
    }
    return max(indices) + 1 if indices else 0


def _detect_coda_layers(state_dict) -> int:
    indices = {
        int(k.split(".")[1])
        for k in state_dict
        if k.startswith("coda.") and k.split(".")[1].isdigit()
    }
    return max(indices) + 1 if indices else 0


def _detect_recurrent_loops(state_dict) -> int:
    if "recurrent.norm.weight" in state_dict or "recurrent.block.attn.wq.weight" in state_dict:
        if "recurrent.lora.scale.weight" in state_dict:
            return state_dict["recurrent.lora.scale.weight"].shape[0]
        return 1
    return 0


def _detect_recurrent_lora_rank(state_dict) -> int:
    for key in ("recurrent.lora.B", "recurrent.lora.down.weight"):
        if key in state_dict:
            shape = state_dict[key].shape
            if len(shape) == 2:
                return int(shape[0])
    return 0


def _infer_series_from_lora_rank(rank: int) -> str | None:
    if rank == 0:
        return None
    if rank <= 8:
        return "haiku"
    if rank <= 16:
        return "sonnet"
    return "opus"


def _fast_series_from_checkpoint(ckpt_path: Path) -> str | None:
    try:
        cp = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        sd = cp.get("model_state", cp.get("state_dict", {}))
        rank = 0
        for key in ("recurrent.lora.B", "recurrent.lora.down.weight"):
            if key in sd:
                rank = int(sd[key].shape[0])
                break
        if rank == 0:
            return None
        if rank <= 8:
            return "Haiku"
        if rank <= 16:
            return "Sonnet"
        return "Opus"
    except Exception:
        pass
    return None


def _infer_arch_from_state_dict(state_dict, cfg):
    """Infer architecture hyper-parameters directly from checkpoint weights,
    falling back to *cfg* (series config) when a key is not found."""
    overrides = {}

    has_prelude = any(k.startswith("prelude.") for k in state_dict)
    has_blocks = any(k.startswith("blocks.") for k in state_dict)
    has_recurrent = any(k.startswith("recurrent.") for k in state_dict)
    uses_recurrent_arch = has_prelude and has_recurrent and not has_blocks

    # dim from embed_tokens.weight [vocab, dim]
    if "embed_tokens.weight" in state_dict:
        overrides["dim"] = state_dict["embed_tokens.weight"].shape[1]

    if uses_recurrent_arch:
        if "prelude.0.ffn.gate.weight" in state_dict:
            overrides["ffn_dim"] = state_dict["prelude.0.ffn.gate.weight"].shape[0]
        overrides["n_unique_layers"] = 0
        src = "prelude.0"
    else:
        if "blocks.0.ffn.gate.weight" in state_dict:
            overrides["ffn_dim"] = state_dict["blocks.0.ffn.gate.weight"].shape[0]
        block_ids = {
            int(k.split(".")[1])
            for k in state_dict
            if k.startswith("blocks.") and k.split(".")[1].isdigit()
        }
        if block_ids:
            overrides["n_unique_layers"] = max(block_ids) + 1
        src = "blocks.0"

    dim = overrides.get("dim", int(cfg.get("dim", model_config.dim)))
    if f"{src}.attn.wq.weight" in state_dict:
        wq_rows = state_dict[f"{src}.attn.wq.weight"].shape[0]
        if f"{src}.attn.q_norm.weight" in state_dict:
            head_dim = state_dict[f"{src}.attn.q_norm.weight"].shape[0]
            overrides["n_heads"] = wq_rows // head_dim
    if f"{src}.attn.wk.weight" in state_dict:
        wk_rows = state_dict[f"{src}.attn.wk.weight"].shape[0]
        if f"{src}.attn.k_norm.weight" in state_dict:
            head_dim = state_dict[f"{src}.attn.k_norm.weight"].shape[0]
            overrides["n_kv_heads"] = wk_rows // head_dim

    # engram params
    for key, val in state_dict.items():
        if ".engram.embeddings." in key and key.endswith("_0") and val.dim() == 2:
            overrides["engram_table_size"] = val.shape[0]
            overrides["engram_dim"] = val.shape[1]
            break
    engram_dim = overrides.get("engram_dim", int(cfg.get("engram_dim", 0)))
    engram_max_ngram = int(cfg.get("engram_max_ngram", 2))
    if engram_dim > 0:
        for key, val in state_dict.items():
            if ".engram.branch_conv.weight" in key and val.dim() == 3:
                total_branch_dim = val.shape[0]
                denom = engram_dim * (engram_max_ngram - 1)
                if denom > 0 and total_branch_dim % denom == 0:
                    overrides["engram_heads"] = total_branch_dim // denom
                break

    merged = dict(cfg)
    merged.update(overrides)
    return merged


def load_local_model(model_path: Path, tokenizer_path: Path, series: str) -> dict:
    """Load a checkpoint and tokenizer from disk into a ready-to-run ``TinyMemoryLM``.

    Architecture hyper-parameters are read from the checkpoint weights where
    possible, and otherwise taken from the series config.  The *series* hint is
    replaced when the checkpoint's recurrent LoRA rank implies a different
    series.  That correction happens before the config is built, so the
    weight-derived overrides (dim, ffn_dim, heads, ...) are never discarded.

    Args:
        model_path: checkpoint file produced by ``torch.save``.
        tokenizer_path: ``tokenizer.json`` for ``WordTokenizer.load``.
        series: series name hint such as ``"Haiku"``.

    Returns:
        dict with ``model`` (eval mode, on CUDA if available else CPU),
        ``tokenizer``, ``device``, ``series`` (possibly corrected), and
        ``sft_mode`` / ``phase`` from the checkpoint (``None`` if absent).
    """
    tokenizer = WordTokenizer.load(tokenizer_path)
    # SECURITY NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects (i.e. run code embedded in the file). This CLI also loads files
    # downloaded from the Hugging Face Hub — only load trusted checkpoints.
    ckpt = torch.load(str(model_path), map_location="cpu", weights_only=False)
    state_dict = ckpt.get("model_state") or ckpt.get("state_dict") or ckpt

    # Resolve the series first; switching it later would throw away the
    # architecture overrides inferred from the weights.
    ckpt_lora_rank = _detect_recurrent_lora_rank(state_dict)
    if ckpt_lora_rank > 0:
        inferred_series = _infer_series_from_lora_rank(ckpt_lora_rank)
        if inferred_series and inferred_series != series.lower():
            series = inferred_series.capitalize()

    cfg = _infer_arch_from_state_dict(state_dict, series_config(series))
    recurrent_lora_rank = (
        ckpt_lora_rank if ckpt_lora_rank > 0 else int(cfg.get("recurrent_lora_rank", 0))
    )

    # Prefer the recorded vocab size, then the embedding row count, so that a raw
    # state dict still loads when the tokenizer vocabulary differs in size.
    if "vocab_size" in ckpt:
        vocab_size = int(ckpt["vocab_size"])
    elif "embed_tokens.weight" in state_dict:
        vocab_size = int(state_dict["embed_tokens.weight"].shape[0])
    else:
        vocab_size = int(tokenizer.vocab_size)

    engram_dim = int(cfg.get("engram_dim", 0))
    if _detect_engram(state_dict) == 0:
        engram_dim = 0

    mhc_expansion = _detect_mhc(state_dict)
    if mhc_expansion == 1:
        mhc_expansion = int(cfg.get("mhc_expansion", 1))

    ckpt_sleep_cap, ckpt_sleep_heads = _detect_sleep_gate(state_dict)
    sleep_gate_cap = ckpt_sleep_cap if ckpt_sleep_cap > 0 else int(cfg.get("sleep_gate_cap", 0))
    sleep_gate_heads = ckpt_sleep_heads if ckpt_sleep_cap > 0 else int(cfg.get("sleep_gate_heads", 4))
    sleep_retention_enabled = bool(cfg.get("sleep_retention_enabled", True))
    sleep_retention_hidden = int(cfg.get("sleep_retention_hidden", 0))

    latent_think_layers = _detect_latent_think(state_dict)
    if latent_think_layers == 0:
        latent_think_layers = int(cfg.get("latent_think_layers", 0))

    prelude_layers = _detect_prelude_layers(state_dict)
    coda_layers = _detect_coda_layers(state_dict)
    recurrent_loops = _detect_recurrent_loops(state_dict)

    recurrent_act_threshold = float(cfg.get("recurrent_act_threshold", 0.99))
    recurrent_loop_embed_dim = int(cfg.get("recurrent_loop_embed_dim", 0))

    n_unique = int(cfg.get("n_unique_layers", model_config.n_unique_layers))

    model = TinyMemoryLM(
        vocab_size=vocab_size,
        dim=int(cfg.get("dim", model_config.dim)),
        n_unique_layers=n_unique,
        n_logical_layers=int(cfg.get("n_logical_layers", model_config.n_logical_layers)),
        n_heads=int(cfg.get("n_heads", model_config.n_heads)),
        n_kv_heads=int(cfg.get("n_kv_heads", model_config.n_kv_heads)),
        ffn_dim=int(cfg.get("ffn_dim", model_config.ffn_dim)),
        dropout=float(cfg.get("dropout", model_config.dropout)),
        mtp_horizons=tuple(int(v) for v in cfg.get("mtp_horizons", model_config.mtp_horizons)),
        grad_checkpoint=False,
        sliding_window=int(cfg.get("sliding_window_size", getattr(model_config, "sliding_window_size", 512))),
        rope_fraction=float(cfg.get("rope_fraction", getattr(model_config, "rope_fraction", 0.25))),
        embed_scale=bool(cfg.get("embed_scale", getattr(model_config, "embed_scale", True))),
        engram_dim=engram_dim,
        engram_heads=int(cfg.get("engram_heads", 4)),
        engram_table_size=int(cfg.get("engram_table_size", 8192)),
        engram_max_ngram=int(cfg.get("engram_max_ngram", 3)),
        mhc_expansion=mhc_expansion,
        sleep_gate_cap=sleep_gate_cap,
        sleep_gate_heads=sleep_gate_heads,
        sleep_retention_enabled=sleep_retention_enabled,
        sleep_retention_hidden=sleep_retention_hidden,
        latent_think_layers=latent_think_layers,
        prelude_layers=prelude_layers,
        coda_layers=coda_layers,
        recurrent_loops=recurrent_loops,
        recurrent_act_threshold=recurrent_act_threshold,
        recurrent_lora_rank=recurrent_lora_rank,
        recurrent_loop_embed_dim=recurrent_loop_embed_dim,
    )
    # strict=False tolerates missing/unexpected keys (shape mismatches still raise).
    model.load_state_dict(state_dict, strict=False)
    model.eval()
    if tokenizer.vocab_size > vocab_size:
        model.resize_token_embeddings(tokenizer.vocab_size)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    return {
        "model": model,
        "tokenizer": tokenizer,
        "device": device,
        "series": series,
        "sft_mode": ckpt.get("sft_mode", None),
        "phase": ckpt.get("phase", None),
    }


# ---------------------------------------------------------------------------
# HuggingFace Model Download & Loading
# ---------------------------------------------------------------------------

def download_huggingface_model(hf_id: str, cache_dir: Path) -> Optional[dict]:
    """Fetch (or reuse from *cache_dir*) an HF repo snapshot and locate its model files.

    Checkpoints are looked for in a ``models/`` or ``model/`` subdirectory when
    one exists, otherwise at the snapshot root. ``model.pt`` is preferred over
    ``pretrain.pt``. ``tokenizer.json`` is taken from the same directory, falling
    back to the snapshot root.

    Returns:
        ``{"model_path", "tokenizer_path", "model_name"}`` on success, or ``None``
        when the download fails or a required file is missing (the reason is printed).

    Exits the process (status 1) if ``huggingface_hub`` is not installed.
    """
    try:
        from huggingface_hub import snapshot_download
    except ImportError:
        print("huggingface_hub not installed. Install with: pip install huggingface_hub")
        sys.exit(1)

    print(f"Downloading {hf_id}...")
    try:
        local_dir = Path(snapshot_download(repo_id=hf_id, cache_dir=str(cache_dir)))
    except Exception as e:
        print(f"Failed to download {hf_id}: {e}")
        return None

    print(f"Using cached {hf_id} from {local_dir}")

    # Repos keep checkpoints under "models/" or "model/"; otherwise use the root.
    model_dir = next(
        (local_dir / name for name in ("models", "model") if (local_dir / name).is_dir()),
        local_dir,
    )
    model_path = model_dir / "model.pt"
    pretrain_path = model_dir / "pretrain.pt"
    ckpt_path = next((p for p in (model_path, pretrain_path) if p.exists()), None)

    tokenizer_path = model_dir / "tokenizer.json"
    if not tokenizer_path.exists() and (local_dir / "tokenizer.json").exists():
        # Some repos ship the tokenizer at the root next to a checkpoint subdirectory.
        tokenizer_path = local_dir / "tokenizer.json"

    if ckpt_path is None or not tokenizer_path.exists():
        print(f"Missing model files in {model_dir}")
        print(f"  model.pt exists: {model_path.exists()}")
        print(f"  pretrain.pt exists: {pretrain_path.exists()}")
        print(f"  tokenizer.json exists: {tokenizer_path.exists()}")
        return None

    return {
        "model_path": ckpt_path,
        "tokenizer_path": tokenizer_path,
        "model_name": ckpt_path.stem,
    }


def load_huggingface_model(hf_id: str, cache_dir: Path, series: str = "Haiku") -> Optional[dict]:
    """Download *hf_id* into *cache_dir* and load it with ``load_local_model``.

    Args:
        hf_id: HuggingFace repo id, e.g. ``"CompactAI-O/TMLM-Haiku-2"``.
        cache_dir: ``snapshot_download`` cache directory.
        series: series name forwarded to ``load_local_model`` (defaults to the
            previously hard-coded ``"Haiku"``).

    Returns:
        The bundle dict from ``load_local_model``, or ``None`` when the download
        or file lookup in ``download_huggingface_model`` failed.
    """
    files = download_huggingface_model(hf_id, cache_dir)
    if files is None:
        return None
    return load_local_model(files["model_path"], files["tokenizer_path"], series)


# ---------------------------------------------------------------------------
# Compare All Models
# ---------------------------------------------------------------------------

# Loaded HuggingFace model bundles, keyed by the display names of
# HUGGINGFACE_MODELS. Filled by prefetch_huggingface_models() and iterated by
# compare_all_models().
_hf_model_cache: Dict[str, dict] = {}


def prefetch_huggingface_models() -> None:
    """Load every repo listed in ``HUGGINGFACE_MODELS`` into ``_hf_model_cache``.

    Snapshots are stored under ``<script dir>/cache/huggingface``. Repos for which
    ``load_huggingface_model`` returns a falsy value (download or file lookup
    failed) are skipped rather than cached.
    """
    cache_dir = Path(__file__).resolve().parent / "cache" / "huggingface"
    cache_dir.mkdir(parents=True, exist_ok=True)

    print("Downloading/preparing HuggingFace models...")
    for display_name, repo_id in HUGGINGFACE_MODELS.items():
        print(f"  {display_name}...")
        loaded = load_huggingface_model(repo_id, cache_dir)
        if not loaded:
            continue
        _hf_model_cache[display_name] = loaded
    print(f"Prepared {len(_hf_model_cache)} HuggingFace models")


def _generate_for_comparison(bundle: dict, prompt: str, cfg: dict):
    """Run ``generate`` on one loaded bundle with the sampling settings in *cfg*.

    Output is streamed while it is produced (``stream=True``). Returns whatever
    ``generate`` returns; the comparison prints it verbatim.
    """
    return generate(
        model=bundle["model"],
        tokenizer=bundle["tokenizer"],
        prompt=prompt,
        max_new_tokens=cfg["max_new_tokens"],
        temperature=cfg["temperature"],
        top_k=cfg["top_k"],
        top_p=cfg["top_p"],
        min_p=cfg["min_p"],
        no_repeat_ngram_size=cfg["no_repeat_ngram_size"],
        repetition_penalty=cfg["repetition_penalty"],
        logit_soft_cap=cfg["logit_soft_cap"],
        loop_penalty=cfg["loop_penalty"],
        device=str(bundle["device"]),
        sft_mode=cfg["sft_mode"],
        stream=True,
        context_window=cfg["context_window"],
    )


def compare_all_models(prompt: str, cfg: dict) -> None:
    """Generate a reply to *prompt* with every available model and print them side by side.

    Local checkpoints discovered under ``<script dir>/runs`` are included when
    their "pretrain"-ness matches the mode (a falsy ``cfg["sft_mode"]`` means
    pretrain). Every bundle already in ``_hf_model_cache`` is included too.
    Local models that fail to load are reported and skipped.
    """
    root = Path(__file__).resolve().parent
    all_models = discover_models(root / "runs")

    is_pretrain = not cfg.get("sft_mode", True)
    local_models = [
        m for m in all_models
        if ("pretrain" in m["checkpoint"]) == is_pretrain
    ]

    # Only give up when there is nothing at all to compare: users of the public
    # --compare flow typically have prefetched HF models but no local runs.
    if not local_models and not _hf_model_cache:
        print("No models found matching mode.")
        return

    results: List[dict] = []

    for m in local_models:
        print(f"\n{'='*60}")
        print(f"Loading local {m['name']}/{m['checkpoint']}...")
        try:
            bundle = load_local_model(m["model_path"], m["tokenizer_path"], m["series"])
        except Exception as e:
            print(f"Failed to load {m['name']}: {e}")
            continue

        print(f"Generating on '{prompt}'...")
        results.append({
            "name": f"[LOCAL] {m['name']}/{m['checkpoint']}",
            "output": _generate_for_comparison(bundle, prompt, cfg),
            "device": bundle["device"],
        })

    for name, bundle in _hf_model_cache.items():
        print(f"\n{'='*60}")
        print(f"Loading {name} (cached)...")
        print(f"Generating on '{prompt}'...")
        results.append({
            "name": name,
            "output": _generate_for_comparison(bundle, prompt, cfg),
            "device": bundle["device"],
        })

    print(f"\n{'='*60}")
    print("=" * 60)
    print("SIDE-BY-SIDE COMPARISON")
    print("=" * 60)
    for r in results:
        print(f"\n--- {r['name']} ---")
        print(r["output"])
    print(f"\n{'='*60}")


# ---------------------------------------------------------------------------
# Benchmark
# ---------------------------------------------------------------------------

# Benchmarks offered by run_benchmark_mode(). Fields per entry:
#   label      - display name for the selection menu and the chart title
#   desc       - one-line description shown in the selection menu
#   hf_dataset - (dataset id, config) pair. Informational: the _run_* helpers
#                pass their own hard-coded dataset ids to load_dataset()
#   metric     - "accuracy" or "perplexity"; run_benchmark_mode() uses it to pick
#                the sort direction, y-axis label and (for accuracy) a 0-1.05 y-range
BENCHMARKS = {
    "blimp": {
        "label": "BLiMP",
        "desc": "Grammaticality minimal pairs (67 paradigms). Accuracy = % grammatical < ungrammatical perplexity.",
        "hf_dataset": ("nyu-mll/blimp", None),
        "metric": "accuracy",
    },
    "wikitext2": {
        "label": "WikiText-2",
        "desc": "LM perplexity on Wikipedia test split. Lower is better.",
        "hf_dataset": ("Salesforce/wikitext", "wikitext-2-raw-v1"),
        "metric": "perplexity",
    },
    "arc_easy": {
        "label": "ARC-Easy",
        "desc": "Multiple-choice science QA (~2.4K). Perplexity-based answer selection.",
        "hf_dataset": ("allenai/ai2_arc", "ARC-Easy"),
        "metric": "accuracy",
    },
}


def _score_text(model: TinyMemoryLM, tokenizer: WordTokenizer, text: str, device: str) -> float:
    ids = tokenizer.encode(text, add_bos=True, add_eos=False)
    if len(ids) < 2:
        return float("nan")
    ids_t = torch.tensor([ids], dtype=torch.long, device=device)
    with torch.no_grad():
        logits, *_ = model(ids_t)
    log_probs = F.log_softmax(logits[0], dim=-1)
    targets = ids_t[0, 1:]
    nll = -log_probs[range(len(targets)), targets].mean().item()
    return nll


def _score_completion(model: TinyMemoryLM, tokenizer: WordTokenizer, context: str, completion: str, device: str) -> float:
    full_ids = tokenizer.encode(context + completion, add_bos=True, add_eos=False)
    ctx_ids = tokenizer.encode(context, add_bos=True, add_eos=False)
    n_ctx = len(ctx_ids)
    n_ref = len(full_ids) - n_ctx
    if n_ref <= 0:
        return float("nan")
    ids_t = torch.tensor([full_ids], dtype=torch.long, device=device)
    with torch.no_grad():
        logits, *_ = model(ids_t)
    log_probs = F.log_softmax(logits[0], dim=-1)
    targets = ids_t[0, 1:]
    ref_start = n_ctx - 1
    ref_end = min(ref_start + n_ref, log_probs.shape[0])
    if ref_start >= ref_end:
        return float("nan")
    nll = -log_probs[ref_start:ref_end][range(ref_end - ref_start), targets[ref_start:ref_end]].mean().item()
    return nll


# BLiMP paradigm (sub-dataset config) names evaluated by _run_blimp(): one
# load_dataset("nyu-mll/blimp", <name>, split="train") call per entry, scored in
# this order. 67 entries, matching the count quoted in BENCHMARKS["blimp"]["desc"].
BLIMP_PARADIGMS = [
    "adjunct_island", "anaphor_gender_agreement", "anaphor_number_agreement",
    "animate_subject_passive", "animate_subject_trans", "causative",
    "complex_NP_island", "coordinate_structure_constraint_complex_left_branch",
    "coordinate_structure_constraint_object_extraction",
    "determiner_noun_agreement_1", "determiner_noun_agreement_2",
    "determiner_noun_agreement_irregular_1", "determiner_noun_agreement_irregular_2",
    "determiner_noun_agreement_with_adj_2", "determiner_noun_agreement_with_adj_irregular_1",
    "determiner_noun_agreement_with_adj_irregular_2", "determiner_noun_agreement_with_adjective_1",
    "distractor_agreement_relational_noun", "distractor_agreement_relative_clause",
    "drop_argument", "ellipsis_n_bar_1", "ellipsis_n_bar_2",
    "existential_there_object_raising", "existential_there_quantifiers_1",
    "existential_there_quantifiers_2", "existential_there_subject_raising",
    "expletive_it_object_raising", "inchoative", "intransitive",
    "irregular_past_participle_adjectives", "irregular_past_participle_verbs",
    "irregular_plural_subject_verb_agreement_1", "irregular_plural_subject_verb_agreement_2",
    "left_branch_island_echo_question", "left_branch_island_simple_question",
    "matrix_question_npi_licensor_present", "npi_present_1", "npi_present_2",
    "only_npi_licensor_present", "only_npi_scope", "passive_1", "passive_2",
    "principle_A_c_command", "principle_A_case_1", "principle_A_case_2",
    "principle_A_domain_1", "principle_A_domain_2", "principle_A_domain_3",
    "principle_A_reconstruction", "regular_plural_subject_verb_agreement_1",
    "regular_plural_subject_verb_agreement_2", "sentential_negation_npi_licensor_present",
    "sentential_negation_npi_scope", "sentential_subject_island",
    "superlative_quantifiers_1", "superlative_quantifiers_2",
    "tough_vs_raising_1", "tough_vs_raising_2", "transitive", "wh_island",
    "wh_questions_object_gap", "wh_questions_subject_gap",
    "wh_questions_subject_gap_long_distance", "wh_vs_that_no_gap",
    "wh_vs_that_no_gap_long_distance", "wh_vs_that_with_gap",
    "wh_vs_that_with_gap_long_distance",
]


def _run_blimp(model, tokenizer, device, n_samples: int = 200) -> Tuple[List[str], List[float]]:
    """Score up to *n_samples* minimal pairs from each paradigm in ``BLIMP_PARADIGMS``.

    A pair is correct when the grammatical sentence has a lower mean per-token NLL
    (``_score_text``) than the ungrammatical one. Pairs where either sentence
    cannot be scored (NaN) are excluded from the denominator instead of being
    counted as wrong, matching how ``_run_arc_easy`` treats unscorable items.

    Returns:
        (paradigm names, per-paradigm accuracy). A paradigm whose dataset fails
        to load, or that has no scorable pairs, gets NaN.
    """
    from itertools import islice

    from datasets import load_dataset  # type: ignore

    accuracies: List[float] = []
    for paradigm in BLIMP_PARADIGMS:
        try:
            ds = load_dataset("nyu-mll/blimp", paradigm, split="train")
        except Exception as e:
            print(f"  {paradigm}: skip ({e})")
            accuracies.append(float("nan"))
            continue
        correct = 0
        scored = 0
        # islice avoids materialising the whole paradigm just to keep n_samples rows.
        for ex in islice(ds, max(n_samples, 0)):
            good_nll = _score_text(model, tokenizer, ex["sentence_good"], device)
            bad_nll = _score_text(model, tokenizer, ex["sentence_bad"], device)
            if math.isnan(good_nll) or math.isnan(bad_nll):
                continue
            scored += 1
            if good_nll < bad_nll:
                correct += 1
        acc = correct / scored if scored else float("nan")
        accuracies.append(acc)
        print(f"  {paradigm:50s}  acc={acc:.3f}")
    # Return a copy so callers cannot mutate the module-level paradigm list.
    return list(BLIMP_PARADIGMS), accuracies


def _run_wikitext2(model, tokenizer, device, chunk_chars: int = 512, max_chunks: int = 100) -> Tuple[List[str], List[float]]:
    """Per-chunk perplexity on the WikiText-2 (raw) test split.

    The non-blank lines of the split are joined with newlines and cut into
    consecutive *chunk_chars*-character pieces. Pieces of 20 characters or fewer
    are dropped, and at most *max_chunks* are kept. Each piece's perplexity is
    ``exp`` of its mean NLL from ``_score_text``. A running mean is printed every
    10 chunks.

    Returns:
        (["chunk 1", ...], per-chunk perplexity, NaN where unscorable).
    """
    from datasets import load_dataset  # type: ignore

    test_split = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="test")
    corpus = "\n".join(row["text"] for row in test_split if row["text"].strip())

    pieces: List[str] = []
    for start in range(0, len(corpus), chunk_chars):
        piece = corpus[start:start + chunk_chars]
        if len(piece) > 20:
            pieces.append(piece)
    pieces = pieces[:max_chunks]

    labels: List[str] = []
    ppls: List[float] = []
    total = len(pieces)
    for number, piece in enumerate(pieces, start=1):
        nll = _score_text(model, tokenizer, piece, device)
        labels.append(f"chunk {number}")
        ppls.append(float("nan") if math.isnan(nll) else math.exp(nll))
        if number % 10:
            continue
        finite = [p for p in ppls if not math.isnan(p)]
        running = sum(finite) / len(finite) if finite else float("nan")
        print(f"  chunk {number}/{total}  running mean ppl={running:.2f}")
    return labels, ppls


def _run_arc_easy(model, tokenizer, device, max_samples: int = 200) -> Tuple[List[str], List[float]]:
    """Zero-shot ARC-Easy (test split) accuracy via completion likelihood.

    For each of the first *max_samples* questions, every answer choice (with a
    leading space) is scored by ``_score_completion`` after the prompt
    ``Question: <question>`` + newline + ``Answer:``. The choice with the lowest
    NLL is the prediction; NaN scores are treated as +inf.

    Returns:
        (["Q1", ...], per-question score): 1.0 correct, 0.0 wrong, NaN when no
        choice could be scored.
    """
    from datasets import load_dataset  # type: ignore

    dataset = load_dataset("allenai/ai2_arc", "ARC-Easy", split="test")
    rows = list(dataset)[:max_samples]
    labels: List[str] = []
    scores: List[float] = []
    for number, row in enumerate(rows, start=1):
        question = row["question"]
        option_texts = row["choices"]["text"]
        option_labels = row["choices"]["label"]
        answer = row["answerKey"]
        prompt_ctx = f"Question: {question}\nAnswer:"
        nlls = [_score_completion(model, tokenizer, prompt_ctx, f" {text}", device) for text in option_texts]
        if all(math.isnan(v) for v in nlls):
            score = float("nan")
        else:
            ranked = [float("inf") if math.isnan(v) else v for v in nlls]
            best = ranked.index(min(ranked))
            score = 1.0 if option_labels[best] == answer else 0.0
        scores.append(score)
        labels.append(f"Q{number}")
    valid = [s for s in scores if not math.isnan(s)]
    acc = sum(valid) / len(valid) if valid else float("nan")
    print(f"  {len(valid)} questions evaluated, accuracy={acc:.3f}")
    return labels, scores


def _choose_benchmark() -> Optional[str]:
    """Print the BENCHMARKS menu and return the chosen key (None on cancel or bad input)."""
    bench_keys = list(BENCHMARKS.keys())
    print("\nBenchmarks:")
    for i, k in enumerate(bench_keys):
        b = BENCHMARKS[k]
        print(f"  [{i + 1}] {b['label']} — {b['desc']}")
    print("Select benchmark [1]:", end=" ", flush=True)
    try:
        b_choice = input().strip() or "1"
    except (EOFError, KeyboardInterrupt):
        print()
        return None
    # isdecimal(), not isdigit(): "²" passes isdigit() but int("²") raises ValueError.
    if not (b_choice.isdecimal() and 1 <= int(b_choice) <= len(bench_keys)):
        print("Invalid selection.")
        return None
    return bench_keys[int(b_choice) - 1]


def _choose_benchmark_models(root: Path) -> List[dict]:
    """List local runs plus HUGGINGFACE_MODELS and return the entries the user picks.

    Entries are ``{"label", "type": "local", "meta"}`` or
    ``{"label", "type": "hf", "hf_id", "hf_name"}``. Returns [] on cancel or when
    nothing valid was selected.
    """
    model_entries: List[dict] = [
        {"label": f"[LOCAL] {m['name']}/{m['checkpoint']}", "type": "local", "meta": m}
        for m in discover_models(root / "runs")
    ]
    model_entries += [
        {"label": f"[HF] {hf_name}", "type": "hf", "hf_id": hf_id, "hf_name": hf_name}
        for hf_name, hf_id in HUGGINGFACE_MODELS.items()
    ]
    if not model_entries:
        print("No models found.")
        return []

    print("\nAvailable models:")
    for i, e in enumerate(model_entries):
        print(f"  [{i + 1}] {e['label']}")
    print("  [a] All models")
    print("Select models (comma-separated or 'a'):", end=" ", flush=True)
    try:
        raw = input().strip()
    except (EOFError, KeyboardInterrupt):
        print()
        return []

    if raw.lower() == "a":
        return model_entries
    selected = [
        model_entries[int(tok) - 1]
        for tok in (t.strip() for t in raw.split(","))
        if tok.isdecimal() and 1 <= int(tok) <= len(model_entries)
    ]
    if not selected:
        print("No valid selection.")
    return selected


def _load_benchmark_bundle(entry: dict, root: Path) -> Optional[dict]:
    """Load the model for one menu entry; None when an HF download/file lookup failed."""
    if entry["type"] == "local":
        m = entry["meta"]
        return load_local_model(m["model_path"], m["tokenizer_path"], m["series"])
    # Share the download cache used by the chat CLI and --compare mode.
    cache_dir = root / "cache" / "huggingface"
    cache_dir.mkdir(parents=True, exist_ok=True)
    return load_huggingface_model(entry["hf_id"], cache_dir)


def _plot_benchmark(plt, bench_key: str, results: List[dict], root: Path) -> Path:
    """Draw a ranked bar chart of per-model summaries and save it; returns the PNG path."""
    bench = BENCHMARKS[bench_key]
    metric = bench["metric"]
    higher_is_better = metric != "perplexity"

    def _rank(r: dict) -> Tuple[int, float]:
        s = r["summary"]
        if math.isnan(s):
            return (1, 0.0)  # models without a valid score go last
        return (0, -s if higher_is_better else s)

    ranked = sorted(results, key=_rank)
    n = len(ranked)
    # Best model first; RdYlGn maps 1.0 to green, so fade from green (best) to red.
    colors = [plt.cm.RdYlGn(1.0 - i / max(n - 1, 1)) for i in range(n)]
    heights = [0.0 if math.isnan(r["summary"]) else r["summary"] for r in ranked]

    fig, ax = plt.subplots(figsize=(max(6, n * 1.4), 6))
    bars = ax.bar(range(n), heights, color=colors, edgecolor="black")
    for bar, r in zip(bars, ranked):
        text = "n/a" if math.isnan(r["summary"]) else f"{r['summary']:.3f}"
        ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.005,
                text, ha="center", va="bottom", fontsize=9, fontweight="bold")

    ylabel = "Mean Perplexity (↓ better)" if metric == "perplexity" else "Mean Accuracy (↑ better)"
    ax.set_ylabel(ylabel)
    ax.set_title(f"{bench['label']} Benchmark — Model Comparison")
    ax.set_xticks(range(n))
    ax.set_xticklabels([r["label"] for r in ranked], rotation=20, ha="right", fontsize=9)
    if metric == "accuracy":
        ax.set_ylim(0, 1.05)
    ax.grid(True, axis="y", alpha=0.3)
    fig.tight_layout()

    out_path = root / f"benchmark_{bench_key}.png"
    fig.savefig(str(out_path), dpi=150)
    plt.close(fig)
    return out_path


def run_benchmark_mode() -> None:
    """Interactively run one benchmark over user-selected models and chart the results.

    Each selected model is loaded in turn. Models that fail to load or whose
    benchmark run raises are reported and skipped. The mean of each model's
    non-NaN scores is plotted to ``<script dir>/benchmark_<key>.png``, which is
    then opened with ``xdg-open`` when available.
    """
    try:
        import matplotlib
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt
    except ImportError:
        print("matplotlib not installed. pip install matplotlib")
        return

    bench_key = _choose_benchmark()
    if bench_key is None:
        return
    bench = BENCHMARKS[bench_key]
    print(f"Benchmark: {bench['label']}")

    root = Path(__file__).resolve().parent
    entries = _choose_benchmark_models(root)
    if not entries:
        return

    runners = {"blimp": _run_blimp, "wikitext2": _run_wikitext2, "arc_easy": _run_arc_easy}
    run = runners[bench_key]

    all_results: List[dict] = []
    for entry in entries:
        print(f"\n{'='*60}\nLoading {entry['label']}...")
        try:
            bundle = _load_benchmark_bundle(entry, root)
        except Exception as e:
            print(f"  Failed: {e}")
            continue
        if bundle is None:
            # load_huggingface_model returns None (after printing why) on failure.
            print("  Failed: model files unavailable")
            continue

        model = bundle["model"]
        model.eval()
        try:
            _, y_vals = run(model, bundle["tokenizer"], str(bundle["device"]))
        except Exception as e:
            print(f"  Benchmark failed: {e}")
            continue
        finally:
            # Free this model before the next one is loaded.
            del model, bundle
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        valid = [v for v in y_vals if not math.isnan(v)]
        summary = sum(valid) / len(valid) if valid else float("nan")
        all_results.append({"label": entry["label"], "y": y_vals, "summary": summary})

    if not all_results:
        print("No results to plot.")
        return

    out_path = _plot_benchmark(plt, bench_key, all_results, root)
    print(f"\nChart saved to {out_path}")
    try:
        import subprocess
        subprocess.Popen(["xdg-open", str(out_path)])
    except OSError:
        pass  # best effort: no desktop opener (headless or non-Linux host)


# ---------------------------------------------------------------------------
# Interactive CLI
# ---------------------------------------------------------------------------


def _pick_series(detected: str) -> str:
    """Ask which ``MODEL_SERIES`` entry to use, defaulting to *detected*.

    Returns the chosen series key capitalized (e.g. ``"Haiku"``). The prompt is
    skipped when only one series exists. EOF/Ctrl-C exits the process with status 0.
    """
    series_list = list(MODEL_SERIES.keys())
    detected_lower = detected.lower()
    default_idx = next(
        (i + 1 for i, s in enumerate(series_list) if s == detected_lower), 1
    )

    # Skip selection if only one series available
    if len(series_list) == 1:
        return series_list[0].capitalize()

    print("Series:")
    for i, s in enumerate(series_list):
        marker = " (detected)" if s == detected_lower else ""
        print(f"  [{i + 1}] {s.capitalize()}{marker}")
    while True:
        try:
            choice = input(f"Select series [{default_idx}]: ").strip()
        except (EOFError, KeyboardInterrupt):
            print()
            sys.exit(0)
        if not choice:
            choice = str(default_idx)
        # isdecimal(), not isdigit(): "²" passes isdigit() but int("²") raises ValueError.
        if choice.isdecimal() and 1 <= int(choice) <= len(series_list):
            return series_list[int(choice) - 1].capitalize()
        print(f"Enter a number 1-{len(series_list)}")


def pick_model(runs_dir: Path) -> tuple[dict, str]:
    """Let the user choose a local checkpoint under *runs_dir* and load it.

    Loads directly when only one model is discovered. Otherwise shows a numbered
    menu where empty input picks entry 1. Exits the process when no models are
    found (status 1) or on EOF/Ctrl-C at the prompt (status 0).

    Returns:
        (bundle from ``load_local_model``, the entry's ``checkpoint`` name).
    """
    models = discover_models(runs_dir)
    if not models:
        print(f"No models found in {runs_dir}")
        print("Expected layout: runs/<name>/model.pt (or pretrain.pt) + tokenizer.json")
        sys.exit(1)

    def _load(m: dict) -> tuple[dict, str]:
        # Shared by the single-model shortcut and the interactive menu.
        print(f"Loading {m['name']}/{m['checkpoint']} ({m['series']})...")
        bundle = load_local_model(m["model_path"], m["tokenizer_path"], m["series"])
        return bundle, m["checkpoint"]

    if len(models) == 1:
        return _load(models[0])

    print("Available models:")
    for i, m in enumerate(models):
        print(f"  [{i + 1}] {m['name']}/{m['checkpoint']} ({m['series']})")
    while True:
        try:
            choice = input("Select model [1]: ").strip()
        except (EOFError, KeyboardInterrupt):
            print()
            sys.exit(0)
        if not choice:
            choice = "1"
        # isdecimal(), not isdigit(): "²" passes isdigit() but int("²") raises ValueError.
        if choice.isdecimal() and 1 <= int(choice) <= len(models):
            return _load(models[int(choice) - 1])
        print(f"Enter a number 1-{len(models)}")


# ---------------------------------------------------------------------------
# Generation mode configs
# ---------------------------------------------------------------------------

# Sampling presets offered by pick_mode() and consumed by _run_loop() and
# compare_all_models(). Key names matter: pick_mode() offers a mode for
# pretrain checkpoints iff "pretrain" appears in its key. Fields per entry:
#   label / desc   - menu text (also echoed by the /mode command)
#   sft_mode       - forwarded to generate(); "chat" for chat checkpoints, False
#                    for raw continuation (truthiness also picks "You"/"Prompt")
#   temperature, top_k, top_p, min_p, no_repeat_ngram_size, repetition_penalty,
#   logit_soft_cap, loop_penalty - sampling controls forwarded to generate()
#   max_new_tokens - cap on generated tokens, forwarded to generate()
#   context_window - forwarded to generate(); presumably the token budget kept
#                    from the prompt - TODO confirm against generate()
MODES = {
    "chat-coherent": {
        "label": "Chat — Coherent",
        "desc": "structured, consistent, strong repetition control",
        "sft_mode": "chat",
        "temperature": 0.35,
        "top_k": 20,
        "top_p": 0.88,
        "min_p": 0.10,
        "no_repeat_ngram_size": 4,
        "repetition_penalty": 1.22,
        "logit_soft_cap": 20.0,
        "loop_penalty": 20.0,
        "max_new_tokens": 4096,
        "context_window": 2048,
    },
    "chat-variants": {
        "label": "Chat — Variants",
        "desc": "creative, diverse, more surprising outputs",
        "sft_mode": "chat",
        "temperature": 0.65,
        "top_k": 60,
        "top_p": 0.92,
        "min_p": 0.05,
        "no_repeat_ngram_size": 4,
        "repetition_penalty": 1.12,
        "logit_soft_cap": 20.0,
        "loop_penalty": 14.0,
        "max_new_tokens": 4096,
        "context_window": 2048,
    },
    "pretrain-coherent": {
        "label": "Pretrain — Coherent",
        "desc": "grounded continuation, low temperature, tight sampling",
        "sft_mode": False,
        "temperature": 0.3,
        "top_k": 20,
        "top_p": 0.85,
        "min_p": 0.10,
        "no_repeat_ngram_size": 4,
        "repetition_penalty": 1.2,
        "logit_soft_cap": 20.0,
        "loop_penalty": 20.0,
        "max_new_tokens": 4096,
        "context_window": 2048,
    },
    "pretrain-variants": {
        "label": "Pretrain — Variants",
        "desc": "free-form continuation, higher temperature, more exploration",
        "sft_mode": False,
        "temperature": 0.7,
        "top_k": 60,
        "top_p": 0.93,
        "min_p": 0.04,
        "no_repeat_ngram_size": 4,
        "repetition_penalty": 1.12,
        "logit_soft_cap": 20.0,
        "loop_penalty": 12.0,
        "max_new_tokens": 4096,
        "context_window": 2048,
    },
}

# MODES keys in insertion order; pick_mode() numbers its menu in this order.
_MODE_LIST = list(MODES.keys())


def pick_mode(is_pretrain: bool) -> dict:
    """Prompt the user to choose a generation mode. Returns a config dict.

    Only MODES whose key contains "pretrain" (when *is_pretrain*) or does not
    (otherwise) are offered. Empty input selects entry 1. EOF/Ctrl-C exits the
    process with status 0.
    """
    # Filter to relevant modes based on checkpoint type
    candidates = [k for k in _MODE_LIST if ("pretrain" in k) == is_pretrain]
    print("\nGeneration mode:")
    for i, key in enumerate(candidates):
        cfg = MODES[key]
        print(f"  [{i + 1}] {cfg['label']}  — {cfg['desc']}")
    while True:
        try:
            choice = input("Select mode [1]: ").strip()
        except (EOFError, KeyboardInterrupt):
            print()
            sys.exit(0)
        if not choice:
            choice = "1"
        # isdecimal(), not isdigit(): "²" passes isdigit() but int("²") raises ValueError.
        if choice.isdecimal() and 1 <= int(choice) <= len(candidates):
            key = candidates[int(choice) - 1]
            cfg = MODES[key]
            print(f"Mode: {cfg['label']}")
            return cfg
        print(f"Enter a number 1-{len(candidates)}")


def _run_loop(bundle: dict, cfg: dict) -> None:
    """Run the interactive read/generate loop for one loaded model.

    Reads prompts until /quit, /exit, /q, EOF, or Ctrl-C at the prompt. /help and
    /mode print information. Any other non-empty line is passed to ``generate``
    with the sampling settings in *cfg*, and the output is streamed. Ctrl-C while
    a reply is being generated cancels only that reply and returns to the prompt.
    """
    model = bundle["model"]
    tokenizer = bundle["tokenizer"]
    device = bundle["device"]
    sft = cfg["sft_mode"]
    prompt_label = "You" if sft else "Prompt"
    print(f"\nModel ready on {device}. Type your message, or /quit to exit.")
    print(f"  temp={cfg['temperature']}  top_k={cfg['top_k']}  top_p={cfg['top_p']}")
    print(f"  min_p={cfg['min_p']}  ng={cfg['no_repeat_ngram_size']}  rp={cfg['repetition_penalty']}")
    print(f"  cap={cfg['logit_soft_cap']}  loop_penalty={cfg['loop_penalty']}\n")
    while True:
        try:
            prompt = input(f"{prompt_label}: ").strip()
        except (EOFError, KeyboardInterrupt):
            print()
            break
        if not prompt:
            continue
        if prompt in ("/quit", "/exit", "/q"):
            break
        if prompt == "/help":
            print("Commands: /quit  /exit  /q  /help  /mode")
            if sft:
                print("Anything else is sent as a chat prompt.")
            else:
                print("Anything else is sent as a raw continuation prompt.")
            continue
        if prompt == "/mode":
            print(f"Current: {cfg['label']} — {cfg['desc']}")
            continue
        print("AI: ", end="", flush=True)
        try:
            generate(
                model=model,
                tokenizer=tokenizer,
                prompt=prompt,
                max_new_tokens=cfg["max_new_tokens"],
                temperature=cfg["temperature"],
                top_k=cfg["top_k"],
                top_p=cfg["top_p"],
                min_p=cfg["min_p"],
                no_repeat_ngram_size=cfg["no_repeat_ngram_size"],
                repetition_penalty=cfg["repetition_penalty"],
                logit_soft_cap=cfg["logit_soft_cap"],
                loop_penalty=cfg["loop_penalty"],
                device=str(device),
                sft_mode=cfg["sft_mode"],
                stream=True,
                context_window=cfg["context_window"],
            )
        except KeyboardInterrupt:
            # Abort only this reply; the session keeps running.
            print("\n[generation interrupted]")




# ---------------------------------------------------------------------------
# Dynamic collection discovery
# ---------------------------------------------------------------------------

# HF collection slug for the series. Not referenced by fetch_collection(), which
# discovers repos by author + search term instead.
_COLLECTION_SLUG = "CompactAI-O/tmlm-haiku-series"
# HF account and repo-name substring used by fetch_collection() to list repos.
_AUTHOR = "CompactAI-O"
_SEARCH = "TMLM-Haiku"

# Repos fetch_collection() probes when the HF listing call fails or yields no
# usable repos. Only "hf_id" is read there; "version" is informational.
_FALLBACK_COLLECTION = [
    {"version": "TMLM-Haiku-2.3", "hf_id": "CompactAI-O/TMLM-Haiku-2.3"},
    {"version": "TMLM-Haiku-2",   "hf_id": "CompactAI-O/TMLM-Haiku-2"},
    {"version": "TMLM-Haiku-1.3", "hf_id": "CompactAI-O/TMLM-Haiku-1.3"},
    {"version": "TMLM-Haiku-1",   "hf_id": "CompactAI-O/TMLM-Haiku-1"},
    {"version": "Glint-1",         "hf_id": "CompactAI-O/Glint-1"},
]

# Repos always probed in addition to the search results, because their names do
# not contain _SEARCH and would otherwise be filtered out.
_EXTRA_REPOS = ["CompactAI-O/Glint-1"]


def _probe_repo(hf_id: str) -> dict | None:
    """Describe the checkpoints published in one HF repo.

    The repo's file list is fetched. ``models/`` (then ``model/``) is used as the
    checkpoint directory when any file lives under it; otherwise the repo root is
    used. Every ``.pt`` file below that prefix becomes a
    ``(label, filename, is_pretrain)`` tuple, sorted by filename.

    Returns:
        ``{"version", "hf_id", "subdir", "checkpoints", "desc"}``, or None when
        listing fails or no ``.pt`` file is found.
    """
    from huggingface_hub import list_repo_files

    try:
        repo_files = set(list_repo_files(hf_id))
    except Exception:
        return None

    subdir: str | None = next(
        (d for d in ("models", "model") if any(path.startswith(d + "/") for path in repo_files)),
        None,
    )
    prefix = "" if subdir is None else subdir + "/"

    # Friendly labels for the well-known filenames; anything else is labelled by
    # its stem and treated as pretrain when "pretrain" appears in the name.
    known_labels = {
        "model.pt": ("Chat (SFT)", False),
        "model_rep.pt": ("Chat (anti-repetition)", False),
        "pretrain.pt": ("Pretrain (base)", True),
    }

    checkpoint_names = sorted(
        path[len(prefix):]
        for path in repo_files
        if path.startswith(prefix) and path.endswith(".pt")
    )
    checkpoints = []
    for name in checkpoint_names:
        fallback = (name.removesuffix(".pt"), "pretrain" in name)
        label, is_pretrain = known_labels.get(name, fallback)
        checkpoints.append((label, name, is_pretrain))

    if not checkpoints:
        return None

    return {
        "version": hf_id.split("/")[-1],
        "hf_id": hf_id,
        "subdir": subdir,
        "checkpoints": checkpoints,
        "desc": "",
    }


def fetch_collection() -> list[dict]:
    """Query HF for all CompactAI-O TMLM-Haiku models, newest first.

    Repos come from ``HfApi.list_models`` (author ``_AUTHOR``, search ``_SEARCH``,
    newest first). If that call fails, the ids in ``_FALLBACK_COLLECTION`` are used
    instead. Only repos whose id contains ``_SEARCH`` are kept, and each is probed
    with ``_probe_repo``. Repos in ``_EXTRA_REPOS`` are always probed as well. If
    nothing usable is found, the fallback ids are probed directly.

    Exits the process (status 1) if ``huggingface_hub`` is not installed.
    """
    try:
        from huggingface_hub import HfApi
    except ImportError:
        # Same handling as _download_version: an install hint instead of a traceback.
        print("huggingface_hub not installed. Run: pip install huggingface_hub")
        sys.exit(1)
    from types import SimpleNamespace

    def _modified_key(info) -> str:
        # huggingface_hub has exposed the timestamp as `last_modified` (datetime) in
        # newer releases and `lastModified` (str) in older ones; either may be absent.
        stamp = getattr(info, "last_modified", None) or getattr(info, "lastModified", None)
        if stamp is None:
            return ""
        return stamp.isoformat() if hasattr(stamp, "isoformat") else str(stamp)

    print("Checking HuggingFace collection for available models...")
    try:
        infos = list(
            HfApi().list_models(
                author=_AUTHOR,
                search=_SEARCH,
                sort="lastModified",
                direction=-1,
            )
        )
    except Exception as exc:
        print(f"  Could not reach HuggingFace ({exc}); using fallback list.")
        infos = [SimpleNamespace(id=e["hf_id"]) for e in _FALLBACK_COLLECTION]
    # Keys are always strings, so this cannot raise. The sort is stable, so
    # entries without a timestamp keep their listed order.
    infos.sort(key=_modified_key, reverse=True)

    entries = []
    seen_ids: set = set()
    for info in infos:
        repo_id = info.id
        if _SEARCH.lower() not in repo_id.lower():
            continue
        entry = _probe_repo(repo_id)
        if entry:
            entries.append(entry)
            seen_ids.add(repo_id)

    # Always include extra repos (e.g. Glint-1) not caught by TMLM-Haiku search
    for repo_id in _EXTRA_REPOS:
        if repo_id not in seen_ids:
            entry = _probe_repo(repo_id)
            if entry:
                entries.append(entry)
                seen_ids.add(repo_id)

    if not entries:
        print("  No models found; using fallback list.")
        for fb in _FALLBACK_COLLECTION:
            e = _probe_repo(fb["hf_id"])
            if e:
                entries.append(e)

    return entries


# ---------------------------------------------------------------------------
# Download helper
# ---------------------------------------------------------------------------


def _download_version(entry: dict, cache_dir: Path) -> Path:
    """Snapshot ``entry["hf_id"]`` into *cache_dir* and return its checkpoint directory.

    The directory is ``<snapshot>/<entry["subdir"]>`` when a subdir is recorded
    and exists locally; otherwise it is the snapshot root. Exits the process
    (status 1) if huggingface_hub is missing or the download fails.
    """
    try:
        from huggingface_hub import snapshot_download
    except ImportError:
        print("huggingface_hub not installed. Run: pip install huggingface_hub")
        sys.exit(1)

    repo_id = entry["hf_id"]
    print(f"Fetching {repo_id} ...")
    try:
        snapshot_root = Path(snapshot_download(repo_id=repo_id, cache_dir=str(cache_dir)))
    except Exception as exc:
        print(f"Download failed: {exc}")
        sys.exit(1)

    subdir = entry.get("subdir")
    if not subdir:
        return snapshot_root
    candidate = snapshot_root / subdir
    # The recorded subdirectory may be missing locally; fall back to the root.
    return candidate if candidate.exists() else snapshot_root


# ---------------------------------------------------------------------------
# Selection prompts
# ---------------------------------------------------------------------------


def _prompt_int(prompt: str, lo: int, hi: int, default: int = 1) -> int:
    while True:
        try:
            raw = input(f"{prompt} [{default}]: ").strip()
        except (EOFError, KeyboardInterrupt):
            print()
            sys.exit(0)
        if not raw:
            return default
        if raw.isdigit() and lo <= int(raw) <= hi:
            return int(raw)
        print(f"  Enter a number {lo}–{hi}.")


def pick_version(collection: list[dict]) -> dict:
    """Show the discovered repos as a numbered menu and return the chosen entry."""
    print("\nTMLM-Haiku series  (CompactAI-O)\n")
    for number, item in enumerate(collection, start=1):
        suffix = f"  — {item['desc']}" if item["desc"] else ""
        print(f"  [{number}] {item['version']}{suffix}")
    choice = _prompt_int("Select version", 1, len(collection))
    return collection[choice - 1]


def pick_checkpoint(entry: dict) -> tuple[str, bool]:
    """Return (filename, is_pretrain).

    A single checkpoint is used without asking. Otherwise the user picks one of
    ``entry["checkpoints"]`` (``(label, filename, is_pretrain)`` tuples) from a
    numbered menu.
    """
    options = entry["checkpoints"]
    if len(options) != 1:
        print(f"\nCheckpoints for {entry['version']}:")
        for number, (label, fname, _) in enumerate(options, start=1):
            print(f"  [{number}] {label}  ({fname})")
        picked = _prompt_int("Select checkpoint", 1, len(options))
        label, fname, is_pretrain = options[picked - 1]
        return fname, is_pretrain

    label, fname, is_pretrain = options[0]
    print(f"  Using: {label} ({fname})")
    return fname, is_pretrain


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------


def _compare_repl(cfg) -> None:
    """Read prompts from stdin and pass each to compare_all_models until quit/EOF."""
    prompt_label = "You" if cfg["sft_mode"] else "Prompt"
    while True:
        print(f"{prompt_label}:", end=" ", flush=True)
        try:
            line = sys.stdin.readline()
        except KeyboardInterrupt:
            # End the session quietly instead of dumping a traceback.
            print()
            break
        prompt = line.strip()
        # Empty line (or EOF, where readline() returns "") also ends the loop.
        if not prompt or prompt in ("/quit", "/exit", "/q"):
            break
        compare_all_models(prompt, cfg)


def _resolve_is_pretrain(ckpt_is_pretrain: bool, args, bundle) -> bool:
    """Decide pretrain vs. SFT: CLI flag > bundle sft_mode > bundle phase > checkpoint table."""
    if args.pretrain:
        return True
    if args.sft:
        return False
    sft_mode_flag = bundle.get("sft_mode")
    if sft_mode_flag is not None:
        return not sft_mode_flag
    phase_flag = bundle.get("phase")
    if phase_flag is not None:
        return phase_flag == "pretrain"
    return ckpt_is_pretrain


def _prompt_action() -> str:
    """Show the post-load action menu; re-prompt until "1", "2" or "3" (Enter -> "1")."""
    print("\nChoose action:")
    print("  [1] Chat with this model")
    print("  [2] Compare ALL models (local + HuggingFace)")
    print("  [3] Run Benchmark (BLiMP / WikiText-2 / ARC-Easy)")
    while True:
        print("Select [1]:", end=" ", flush=True)
        # EOF yields "" and therefore the default, so this cannot spin forever.
        choice = sys.stdin.readline().strip() or "1"
        if choice in ("1", "2", "3"):
            return choice
        print("Enter 1, 2, or 3")


def main() -> None:
    """Entry point for the interactive CLI.

    Flags (unknown arguments are ignored via ``parse_known_args``):
      --compare/-c        skip model selection and go straight to the
                          multi-model comparison prompt loop.
      --prompt/-p         parsed but not used by this function.
      --pretrain / --sft  force the prompt mode. Without either flag, the mode
                          comes from the loaded bundle's ``sft_mode``/``phase``
                          values, falling back to the checkpoint table.

    Exits with status 1 when no models are found, the download fails, or the
    chosen checkpoint/tokenizer file is missing.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--compare", "-c", action="store_true")
    # NOTE(review): --prompt is parsed but never read below.
    parser.add_argument("--prompt", "-p", type=str, default="Hello")
    mode_group = parser.add_mutually_exclusive_group()
    mode_group.add_argument("--pretrain", action="store_true")
    mode_group.add_argument("--sft", action="store_true")
    args, _ = parser.parse_known_args()

    print("=" * 56)
    print("  CompactAI-O Interactive Chat")
    print("  Models: huggingface.co/CompactAI-O")
    print("=" * 56)

    if args.compare:
        prefetch_huggingface_models()
        _compare_repl(pick_mode(is_pretrain=args.pretrain))
        return

    collection = fetch_collection()
    if not collection:
        print("No models found. Check your internet connection.")
        sys.exit(1)

    entry = pick_version(collection)
    fname, ckpt_is_pretrain = pick_checkpoint(entry)

    cache_dir = Path(__file__).resolve().parent / "cache" / "huggingface"
    cache_dir.mkdir(parents=True, exist_ok=True)
    model_dir = _download_version(entry, cache_dir)

    model_path = model_dir / fname
    tokenizer_path = model_dir / "tokenizer.json"
    if not model_path.exists():
        print(f"File not found: {model_path}")
        sys.exit(1)
    if not tokenizer_path.exists():
        print(f"Tokenizer not found: {tokenizer_path}")
        sys.exit(1)

    print(f"Loading {entry['version']} / {fname} ...")
    bundle = load_local_model(model_path, tokenizer_path, "Haiku")
    is_pretrain = _resolve_is_pretrain(ckpt_is_pretrain, args, bundle)

    choice = _prompt_action()
    if choice == "1":
        cfg = pick_mode(is_pretrain)
        _run_loop(bundle, cfg)
    elif choice == "2":
        print("\nDownloading/preparing HuggingFace models...")
        prefetch_huggingface_models()
        _compare_repl(pick_mode(is_pretrain))
    else:
        run_benchmark_mode()


# Run the CLI only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()

