import os
import threading
import queue
from dataclasses import dataclass, field
from typing import Optional

# Set default cache locations BEFORE importing libraries that use them
PROJECT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

if 'NLTK_DATA' not in os.environ:
    nltk_data_path = os.path.join(PROJECT_DIR, 'cache-nltk')
    os.makedirs(nltk_data_path, exist_ok=True)
    os.environ['NLTK_DATA'] = nltk_data_path

if 'HF_HOME' not in os.environ:
    os.environ['HF_HOME'] = os.path.join(PROJECT_DIR, 'cache-huggingface')

# Device configuration: set TORCH_DEVICE=cpu to force CPU, otherwise auto-detect
DEVICE = os.environ.get('TORCH_DEVICE', None)  # None = auto-detect (cuda/mps/cpu)

from salience.timed_import import timed_import

with timed_import("import numpy as np"):
    import numpy as np

with timed_import("import torch"):
    import torch

with timed_import("from sentence_transformers import SentenceTransformer"):
    from sentence_transformers import SentenceTransformer

with timed_import("import nltk"):
    import nltk.data
    import nltk

# Make sure punkt_tab is available in the configured NLTK data location.
# Using punkt_tab (the modern tab-separated format introduced in NLTK 3.8+)
# instead of the older punkt pickle format.
# The punkt_tab model version depends on the NLTK Python package version.
# Check your NLTK version with: uv pip show nltk
#
# nltk.download() contacts the remote package index on each call, even when
# the data is already installed, so only call it when the resource is missing.
# This keeps imports fast and avoids network access when the data is cached.
# To refresh an existing copy, delete it or run nltk.download('punkt_tab')
# by hand.
try:
    nltk.data.find('tokenizers/punkt_tab/english/')
except LookupError:
    nltk.download('punkt_tab')

# Available models for the demo
# Keys are short names for the API, values are full HuggingFace repo IDs
AVAILABLE_MODELS = {
    'all-mpnet-base-v2': 'sentence-transformers/all-mpnet-base-v2',  # Dec 2020
    'gte-large-en-v1.5': 'Alibaba-NLP/gte-large-en-v1.5',  # Jan 2024
    # 'qwen3-embedding-4b': 'Qwen/Qwen3-Embedding-4B',  # April 2025
    'mxbai-embed-large-v1': 'mixedbread-ai/mxbai-embed-large-v1',
}

# On clustering
# all-mpnet-base-v2: 40.03
# mixedbread-ai/mxbai-embed-large-v1: 46.71
# gte-large-en-v1.5: 47.95
# Qwen/Qwen3-Embedding-0.6B: 52.33
# Qwen/Qwen3-Embedding-4B: 57.15

# On STS
# all-mpnet-base-v2: 58.17
# gte-large-en-v1.5: 81.43
# Qwen/Qwen3-Embedding-0.6B: 76.17
# Qwen/Qwen3-Embedding-4B: 80.86
# mixedbread-ai/mxbai-embed-large-v1: 85.00

# Models loaded on first use in worker thread
_loaded_models = {}

# Short model names for which SentenceTransformer is created with
# trust_remote_code=True.
_TRUST_REMOTE_CODE_MODELS = ('gte-large-en-v1.5', 'qwen3-embedding-4b')


def _get_model(model_name):
    """Load and cache a model. Called only from worker thread.

    Args:
        model_name: Short model name; must be a key of AVAILABLE_MODELS.

    Returns:
        The cached SentenceTransformer for ``model_name`` (created on first use).

    Raises:
        KeyError: If ``model_name`` is not in AVAILABLE_MODELS.
    """
    if model_name not in _loaded_models:
        repo_id = AVAILABLE_MODELS[model_name]
        print(f"Loading model {repo_id} into memory...")
        trust_remote = model_name in _TRUST_REMOTE_CODE_MODELS
        _loaded_models[model_name] = SentenceTransformer(
            repo_id, trust_remote_code=trust_remote, device=DEVICE
        )
    return _loaded_models[model_name]


def _load_sentence_detector():
    """Return an English Punkt sentence tokenizer.

    This file downloads the ``punkt_tab`` resource, which is what NLTK's
    ``PunktTokenizer`` reads. Loading 'tokenizers/punkt/english.pickle' needs
    the separate legacy ``punkt`` download, so that path is only used as a
    fallback on NLTK releases that do not provide ``PunktTokenizer``.
    """
    from nltk.tokenize import punkt

    if hasattr(punkt, 'PunktTokenizer'):
        return punkt.PunktTokenizer('english')
    return nltk.data.load('tokenizers/punkt/english.pickle')


sent_detector = _load_sentence_detector()


def cos_sim(a):
    """Return the matrix of pairwise cosine similarities between rows of ``a``.

    For a 2-D array with N rows the result is N x N. A row with zero norm gets
    similarity 0 with every row instead of NaN.
    """
    sims = a @ a.T
    a_norm = np.linalg.norm(a, axis=-1, keepdims=True)
    # A zero row has a zero dot product with everything, so dividing by 1
    # (instead of 0) leaves its similarities at 0 rather than producing NaN.
    a_norm[a_norm == 0] = 1
    sims /= a_norm
    sims /= a_norm.T
    return sims


def degree_power(A, k):
    """Return the diagonal matrix whose entries are the row sums of ``A`` to the power ``k``.

    For negative ``k``, rows that sum to zero would give ``inf``. Those entries
    are set to 0, so an isolated node stays disconnected instead of turning
    the normalized matrix into NaN.
    """
    degrees = np.array(A.sum(1)).ravel()
    with np.errstate(divide='ignore'):
        degrees = np.power(degrees, k)
    degrees[np.isinf(degrees)] = 0.0
    return np.diag(degrees)


def normalized_adjacency(A):
    """Return D^-1/2 @ A @ D^-1/2 as a torch tensor, where D holds A's row sums."""
    normalized_D = degree_power(A, -0.5)
    A = np.asarray(A)
    return torch.from_numpy(normalized_D.dot(A).dot(normalized_D))


def get_sentences(source_text):
    """Split ``source_text`` into sentences.

    Returns:
        (sentences, sentence_ranges): the sentence strings and their
        (start, end) character offsets into ``source_text``.
    """
    sentence_ranges = list(sent_detector.span_tokenize(source_text))
    sentences = [source_text[start:end] for start, end in sentence_ranges]
    return sentences, sentence_ranges


def text_rank(sentences, model_name='all-mpnet-base-v2'):
    """Embed ``sentences`` and return their normalized similarity graph.

    Self-similarities and negative similarities are zeroed before the
    symmetric degree normalization. An empty ``sentences`` sequence yields an
    empty (0, 0) tensor.
    """
    # Load the model first so an unknown model_name is still rejected for
    # empty input.
    model = _get_model(model_name)
    if len(sentences) == 0:
        # Encoding an empty batch does not give a 2-D matrix, and
        # fill_diagonal_ would then fail.
        return torch.zeros((0, 0), dtype=torch.float32)
    vectors = model.encode(sentences)
    adjacency = torch.tensor(cos_sim(vectors)).fill_diagonal_(0.)
    adjacency[adjacency < 0] = 0
    return normalized_adjacency(adjacency)


def extract(source_text, model_name='all-mpnet-base-v2'):
    """
    Main API function that extracts sentence positions and computes normalized adjacency matrix.

    Returns:
        sentence_ranges: List of (start, end) tuples for each sentence's character position
        adjacency: (N × N) normalized adjacency matrix where N is the number of sentences.
            Each entry (i,j) represents the normalized similarity between sentences i and j.
            This matrix is returned to the frontend, which raises it to a power and computes
            the final salience scores via random walk simulation.
            Text with no sentences yields an empty list and an empty (0, 0) matrix.
    """
    sentences, sentence_ranges = get_sentences(source_text)
    adjacency = text_rank(sentences, model_name)
    return sentence_ranges, adjacency


# =============================================================================
# Worker Thread for Model Inference
# =============================================================================
# All model inference runs in a dedicated worker thread. This:
# 1. Avoids fork() issues with Metal/MPS (no forking server needed)
# 2. Serializes inference requests (one at a time)
# 3. Keeps /stats and other endpoints responsive


@dataclass
class WorkItem:
    """One inference request handed from submit_work to the worker thread."""

    source_text: str
    model_name: str
    # Set by the worker once result or error has been filled in.
    event: threading.Event = field(default_factory=threading.Event)
    result: Optional[tuple] = None
    error: Optional[str] = None
    # Set by submit_work after a timeout so the worker can skip the item if
    # it has not started processing it yet.
    cancelled: bool = False


_work_queue: queue.Queue[WorkItem] = queue.Queue()


def _model_worker():
    """Worker thread loop - processes inference requests from queue."""
    import traceback

    while True:
        item = _work_queue.get()
        if item.cancelled:
            # The submitter already timed out; don't spend inference time on
            # a result nobody will read.
            item.event.set()
            continue
        try:
            item.result = extract(item.source_text, item.model_name)
        except Exception as e:
            traceback.print_exc()
            # Some exceptions have an empty str(); fall back to repr() so the
            # submitter always sees a non-empty error.
            item.error = str(e) or repr(e)
        finally:
            item.event.set()


# Start worker thread
threading.Thread(target=_model_worker, daemon=True, name="model-worker").start()


def submit_work(source_text: str, model_name: str, timeout: float = 60.0) -> tuple:
    """Submit text for salience extraction and wait for result.

    Args:
        source_text: Text to analyze
        model_name: Name of the model to use (must be in AVAILABLE_MODELS)
        timeout: Max seconds to wait for result, counted from submission
            (so it includes time spent queued behind other requests)

    Returns:
        (sentence_ranges, adjacency) tuple

    Raises:
        TimeoutError: If no result arrives within timeout seconds
        RuntimeError: If inference fails
    """
    item = WorkItem(source_text=source_text, model_name=model_name)
    _work_queue.put(item)
    if not item.event.wait(timeout=timeout):
        # Let the worker skip this item if it is still waiting in the queue.
        item.cancelled = True
        raise TimeoutError("Model inference timed out")
    if item.error is not None:
        raise RuntimeError(item.error)
    return item.result


# =============================================================================
# Unused/Debugging Code
# =============================================================================


def terminal_distr(adjacency, initial=None, steps=10):
    """Return ``initial @ adjacency**steps`` as a Python list.

    ``initial`` defaults to a vector of ones with ``adjacency``'s dtype.
    """
    if initial is None:
        initial = torch.ones(adjacency.shape[0], dtype=adjacency.dtype)
    return initial.matmul(torch.matrix_power(adjacency, steps)).numpy().tolist()


def get_results(sentences, adjacency, threshold=1.1):
    """Print each sentence whose terminal score exceeds ``threshold``, lowest score first."""
    scores = terminal_distr(adjacency)
    for score, sentence in sorted(zip(scores, sentences), key=lambda xs: xs[0]):
        if score > threshold:
            print('{:0.2f}: {}'.format(score, sentence))