# salience-editor/api/salience/salience.py


import os
import threading
import queue
from dataclasses import dataclass, field
from typing import Optional

# Set default cache locations BEFORE importing libraries that use them
PROJECT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if 'NLTK_DATA' not in os.environ:
    nltk_data_path = os.path.join(PROJECT_DIR, 'cache-nltk')
    os.makedirs(nltk_data_path, exist_ok=True)
    os.environ['NLTK_DATA'] = nltk_data_path
if 'HF_HOME' not in os.environ:
    os.environ['HF_HOME'] = os.path.join(PROJECT_DIR, 'cache-huggingface')

# Device configuration: set TORCH_DEVICE=cpu to force CPU, otherwise auto-detect
DEVICE = os.environ.get('TORCH_DEVICE', None)  # None = auto-detect (cuda/mps/cpu)

from salience.timed_import import timed_import

with timed_import("import numpy as np"):
    import numpy as np
with timed_import("import torch"):
    import torch
with timed_import("from sentence_transformers import SentenceTransformer"):
    from sentence_transformers import SentenceTransformer
with timed_import("import nltk"):
    import nltk.data
    import nltk

# Download punkt_tab to the configured location. punkt_tab is the modern
# tab-separated format introduced in NLTK 3.8+, replacing the older punkt
# pickle format. The punkt_tab model version depends on the NLTK Python
# package version; check yours with: uv pip show nltk
nltk.download('punkt_tab')

# Available models for the demo.
# Keys are short names for the API, values are full HuggingFace repo IDs.
AVAILABLE_MODELS = {
    'all-mpnet-base-v2': 'sentence-transformers/all-mpnet-base-v2',  # Dec 2020
    'gte-large-en-v1.5': 'Alibaba-NLP/gte-large-en-v1.5',  # Jan 2024
    # 'qwen3-embedding-4b': 'Qwen/Qwen3-Embedding-4B',  # April 2025
    'mxbai-embed-large-v1': 'mixedbread-ai/mxbai-embed-large-v1',
}

# Benchmark scores for reference (higher is better).
# On clustering:
#   all-mpnet-base-v2:                  40.03
#   mixedbread-ai/mxbai-embed-large-v1: 46.71
#   gte-large-en-v1.5:                  47.95
#   Qwen/Qwen3-Embedding-0.6B:          52.33
#   Qwen/Qwen3-Embedding-4B:            57.15
# On STS:
#   all-mpnet-base-v2:                  58.17
#   gte-large-en-v1.5:                  81.43
#   Qwen/Qwen3-Embedding-0.6B:          76.17
#   Qwen/Qwen3-Embedding-4B:            80.86
#   mixedbread-ai/mxbai-embed-large-v1: 85.00

# Models are loaded lazily, on first use, in the worker thread.
_loaded_models = {}


def _get_model(model_name):
    """Load and cache a model. Called only from the worker thread."""
    if model_name not in _loaded_models:
        repo_id = AVAILABLE_MODELS[model_name]
        print(f"Loading model {repo_id} into memory...")
        # Some repos ship custom modeling code and must be loaded with
        # trust_remote_code=True
        trust_remote = model_name in ('gte-large-en-v1.5', 'qwen3-embedding-4b')
        _loaded_models[model_name] = SentenceTransformer(
            repo_id, trust_remote_code=trust_remote, device=DEVICE
        )
    return _loaded_models[model_name]

# Sentence splitter. Load the punkt_tab data downloaded above via
# PunktTokenizer; the legacy 'tokenizers/punkt/english.pickle' path would
# require the old 'punkt' package, which is no longer fetched.
sent_detector = nltk.tokenize.punkt.PunktTokenizer('english')

def cos_sim(a):
    """Pairwise cosine similarity between the rows of a."""
    sims = a @ a.T
    a_norm = np.linalg.norm(a, axis=-1, keepdims=True)
    sims /= a_norm
    sims /= a_norm.T
    return sims


def degree_power(A, k):
    """Diagonal degree matrix D^k for adjacency matrix A."""
    degrees = np.power(np.array(A.sum(1)), k).ravel()
    D = np.diag(degrees)
    return D


def normalized_adjacency(A):
    """Symmetrically normalized adjacency D^-1/2 A D^-1/2, as a torch tensor.

    A is a CPU torch tensor here; numpy converts it implicitly inside dot().
    """
    normalized_D = degree_power(A, -0.5)
    return torch.from_numpy(normalized_D.dot(A).dot(normalized_D))
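
# Worked example (hypothetical values): for a fully connected 3-node graph
# with unit edge weights,
#   A = [[0, 1, 1],
#        [1, 0, 1],
#        [1, 1, 0]]
# every node has degree 2, so D^-1/2 A D^-1/2 simply halves each edge weight
# and every row sums to 1. (The rows are exactly stochastic only because all
# degrees are equal; in general symmetric normalization is not row-stochastic.)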

def get_sentences(source_text):
    """Split source_text into sentences, keeping character offsets."""
    sentence_ranges = list(sent_detector.span_tokenize(source_text))
    sentences = [source_text[start:end] for start, end in sentence_ranges]
    return sentences, sentence_ranges

def text_rank(sentences, model_name='all-mpnet-base-v2'):
    """Embed the sentences and build the normalized similarity graph."""
    model = _get_model(model_name)
    vectors = model.encode(sentences)
    # Zero the diagonal (no self-loops) and clamp negative similarities to 0
    adjacency = torch.tensor(cos_sim(vectors)).fill_diagonal_(0.)
    adjacency[adjacency < 0] = 0
    return normalized_adjacency(adjacency)

def extract(source_text, model_name='all-mpnet-base-v2'):
    """
    Main API function: extract sentence positions and compute the normalized
    adjacency matrix.

    Returns:
        sentence_ranges: List of (start, end) tuples giving each sentence's
            character position in source_text.
        adjacency: (N × N) normalized adjacency matrix, where N is the number
            of sentences. Entry (i, j) is the normalized similarity between
            sentences i and j. This matrix is returned to the frontend, which
            raises it to a power and computes the final salience scores via
            random walk simulation.
    """
    sentences, sentence_ranges = get_sentences(source_text)
    adjacency = text_rank(sentences, model_name)
    return sentence_ranges, adjacency
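
# Illustrative call (expected values; punkt yields these spans for this input):
#   ranges, adjacency = extract("First sentence. Second sentence.")
#   # ranges    -> [(0, 15), (16, 32)]
#   # adjacency -> 2×2 torch tensor with zeros on the diagonal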

# =============================================================================
# Worker Thread for Model Inference
# =============================================================================
# All model inference runs in a dedicated worker thread. This:
#   1. Avoids fork() issues with Metal/MPS (no forking server needed)
#   2. Serializes inference requests (one at a time)
#   3. Keeps /stats and other endpoints responsive


@dataclass
class WorkItem:
    source_text: str
    model_name: str
    event: threading.Event = field(default_factory=threading.Event)
    result: Optional[tuple] = None
    error: Optional[str] = None


_work_queue: queue.Queue[WorkItem] = queue.Queue()

def _model_worker():
    """Worker thread loop - processes inference requests from the queue."""
    while True:
        item = _work_queue.get()
        try:
            item.result = extract(item.source_text, item.model_name)
        except Exception as e:
            item.error = str(e)
        finally:
            item.event.set()


# Start the worker thread
threading.Thread(target=_model_worker, daemon=True, name="model-worker").start()

def submit_work(source_text: str, model_name: str, timeout: float = 60.0) -> tuple:
    """Submit text for salience extraction and wait for the result.

    Args:
        source_text: Text to analyze
        model_name: Name of the model to use (must be in AVAILABLE_MODELS)
        timeout: Max seconds to wait for the result

    Returns:
        (sentence_ranges, adjacency) tuple

    Raises:
        TimeoutError: If inference takes longer than timeout
        RuntimeError: If inference fails
    """
    item = WorkItem(source_text=source_text, model_name=model_name)
    _work_queue.put(item)
    if not item.event.wait(timeout=timeout):
        raise TimeoutError("Model inference timed out")
    if item.error:
        raise RuntimeError(item.error)
    return item.result
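
# Illustrative caller (hypothetical endpoint; none of these names exist in
# this module). A request handler would validate the model name against
# AVAILABLE_MODELS and map the exceptions onto HTTP status codes:
#
#   def handle_salience_request(text: str, model: str):
#       if model not in AVAILABLE_MODELS:
#           return {"error": f"unknown model {model!r}"}, 400
#       try:
#           ranges, adjacency = submit_work(text, model)
#       except TimeoutError:
#           return {"error": "inference timed out"}, 504
#       return {"ranges": ranges, "adjacency": adjacency.tolist()}, 200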

# =============================================================================
# Unused/Debugging Code
# =============================================================================
def terminal_distr(adjacency, initial=None):
    """Score sentences by a 10-step random walk from a uniform start vector."""
    sample = initial if initial is not None else torch.full((adjacency.shape[0],), 1.)
    scores = sample.matmul(torch.matrix_power(adjacency, 10)).numpy().tolist()
    return scores

def get_results(sentences, adjacency):
    """Print the sentences whose random-walk score exceeds a fixed threshold."""
    scores = terminal_distr(adjacency)
    for score, sentence in sorted(zip(scores, sentences), key=lambda xs: xs[0]):
        if score > 1.1:
            print('{:0.2f}: {}'.format(score, sentence))
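
# Minimal smoke test (illustrative; downloads a model on first run, and the
# demo text is arbitrary):
if __name__ == '__main__':
    demo_text = (
        "The cat sat on the mat. "
        "A dog slept nearby. "
        "Cats and dogs often share homes peacefully."
    )
    # Go through the work queue so inference stays on the worker thread;
    # the generous timeout allows for the first-run model download.
    ranges, adjacency = submit_work(demo_text, 'all-mpnet-base-v2', timeout=600.0)
    print("sentence ranges:", ranges)
    print("salience scores:", terminal_distr(adjacency))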