193 lines
7 KiB
Python
193 lines
7 KiB
Python
import os
|
||
import threading
|
||
import queue
|
||
from dataclasses import dataclass, field
|
||
from typing import Optional
|
||
|
||
# Set default cache locations BEFORE importing libraries that use them.
# Unless the environment already configures them, both caches live under
# PROJECT_DIR (the parent of the directory containing this file).
PROJECT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

if 'NLTK_DATA' not in os.environ:
    nltk_data_path = os.path.join(PROJECT_DIR, 'cache-nltk')
    # Create it up front so the directory exists before NLTK uses it.
    os.makedirs(nltk_data_path, exist_ok=True)
    os.environ['NLTK_DATA'] = nltk_data_path

os.environ.setdefault('HF_HOME', os.path.join(PROJECT_DIR, 'cache-huggingface'))

# Device configuration: set TORCH_DEVICE=cpu to force CPU, otherwise auto-detect.
# Both an unset and an empty TORCH_DEVICE map to None (auto-detect cuda/mps/cpu);
# an empty string would otherwise be handed to SentenceTransformer as a device name.
DEVICE = os.environ.get('TORCH_DEVICE') or None
|
||
|
||
from salience.timed_import import timed_import
|
||
|
||
with timed_import("import numpy as np"):
|
||
import numpy as np
|
||
with timed_import("import torch"):
|
||
import torch
|
||
with timed_import("from sentence_transformers import SentenceTransformer"):
|
||
from sentence_transformers import SentenceTransformer
|
||
with timed_import("import nltk"):
|
||
import nltk.data
|
||
import nltk
|
||
|
||
# Fetch the Punkt sentence-tokenizer data in the tab-separated "punkt_tab"
# format into the configured NLTK data location, instead of the legacy
# pickled "punkt" package.
# Which punkt_tab revision applies depends on the installed NLTK package
# version; check it with: uv pip show nltk
nltk.download(info_or_id='punkt_tab')
|
||
|
||
# Embedding models offered by the demo.
# Keys are the short names accepted by the API; values are the full
# Hugging Face repository IDs handed to SentenceTransformer.
AVAILABLE_MODELS = dict([
    ('all-mpnet-base-v2', 'sentence-transformers/all-mpnet-base-v2'),  # Dec 2020
    ('gte-large-en-v1.5', 'Alibaba-NLP/gte-large-en-v1.5'),  # Jan 2024
    # ('qwen3-embedding-4b', 'Qwen/Qwen3-Embedding-4B'),  # April 2025 (disabled)
    ('mxbai-embed-large-v1', 'mixedbread-ai/mxbai-embed-large-v1'),
])
|
||
|
||
# Reported benchmark scores for candidate embedding models (source not
# recorded here -- presumably MTEB, where higher is better).
#
# Clustering:
#   all-mpnet-base-v2: 40.03
#   mixedbread-ai/mxbai-embed-large-v1: 46.71
#   gte-large-en-v1.5: 47.95
#   Qwen/Qwen3-Embedding-0.6B: 52.33
#   Qwen/Qwen3-Embedding-4B: 57.15
#
# STS (semantic textual similarity):
#   all-mpnet-base-v2: 58.17
#   gte-large-en-v1.5: 81.43
#   Qwen/Qwen3-Embedding-0.6B: 76.17
#   Qwen/Qwen3-Embedding-4B: 80.86
#   mixedbread-ai/mxbai-embed-large-v1: 85.00
|
||
|
||
# Models loaded on first use in worker thread
_loaded_models = {}


def _get_model(model_name):
    """Return the SentenceTransformer for ``model_name``, loading it on first use.

    Loaded models are memoised in ``_loaded_models`` keyed by short name. The
    cache has no locking, so this is meant to be called only from the model
    worker thread.

    Raises:
        KeyError: if ``model_name`` is neither cached nor a key of AVAILABLE_MODELS.
    """
    if model_name in _loaded_models:
        return _loaded_models[model_name]

    repo_id = AVAILABLE_MODELS[model_name]
    print(f"Loading model {repo_id} into memory...")
    # These repos ship custom modelling code that must be allowed explicitly.
    needs_remote_code = model_name in ('gte-large-en-v1.5', 'qwen3-embedding-4b')
    model = SentenceTransformer(repo_id, trust_remote_code=needs_remote_code, device=DEVICE)
    _loaded_models[model_name] = model
    return model
|
||
|
||
# English Punkt sentence splitter built from the punkt_tab data that this
# module downloads via nltk.download('punkt_tab'). The legacy pickled
# 'tokenizers/punkt/english.pickle' resource is never downloaded here, and
# NLTK 3.9+ no longer loads pickled models, so loading it failed at import.
from nltk.tokenize.punkt import PunktTokenizer

sent_detector = PunktTokenizer('english')
|
||
|
||
def cos_sim(a):
    """Return the pairwise cosine-similarity matrix of the rows of ``a``.

    Args:
        a: 2-D array-like of shape (n, d); each row is one vector.

    Returns:
        (n, n) numpy array whose entry (i, j) is the cosine similarity of rows
        i and j. Floating inputs keep their dtype; other inputs are promoted
        to float64. A zero-norm row gets similarity 0 with every row
        (itself included) instead of NaN. A NaN here would otherwise reach
        every degree sum in the normalized adjacency built from this matrix.
    """
    a = np.asarray(a)
    if not np.issubdtype(a.dtype, np.floating):
        # np.divide below writes into an array of a's dtype, which must be float.
        a = a.astype(np.float64)
    norms = np.linalg.norm(a, axis=-1, keepdims=True)
    # Normalize rows first; zero-norm rows stay all-zero rather than 0/0.
    unit = np.divide(a, norms, out=np.zeros_like(a), where=norms > 0)
    return unit @ unit.T
|
||
|
||
def degree_power(A, k):
    """Return the diagonal matrix of the degrees of ``A`` raised to ``k``.

    The degree of node i is the sum of row i of ``A``. Entries that become
    infinite are set to 0. This happens when a zero degree is raised to a
    negative power, i.e. an isolated node such as the only sentence of a
    one-sentence text. With the reset, the node simply gets no weight instead
    of turning D @ A @ D into NaN.

    Args:
        A: square adjacency matrix; anything with ``.sum(1)`` whose result
           ``np.array`` can convert (numpy array or CPU torch tensor).
        k: exponent applied element-wise to the degrees (e.g. -0.5).

    Returns:
        numpy array of shape (n, n) with the powered degrees on the diagonal.
    """
    with np.errstate(divide='ignore'):  # 0 ** negative is handled just below
        degrees = np.power(np.array(A.sum(1)), k).ravel()
    degrees[np.isinf(degrees)] = 0.
    return np.diag(degrees)
|
||
|
||
def normalized_adjacency(A):
    """Symmetrically normalize ``A`` as D^(-1/2) . A . D^(-1/2).

    D is the diagonal degree matrix of ``A`` (row sums, via ``degree_power``).
    The products are evaluated with numpy ``.dot`` and the resulting array is
    wrapped as a torch tensor with ``torch.from_numpy`` (memory is shared).
    """
    inv_sqrt_degrees = degree_power(A, -0.5)
    left_scaled = inv_sqrt_degrees.dot(A)
    symmetric = left_scaled.dot(inv_sqrt_degrees)
    return torch.from_numpy(symmetric)
|
||
|
||
def get_sentences(source_text):
    """Split ``source_text`` into sentences with the module-level Punkt tokenizer.

    Returns:
        (sentences, sentence_ranges): the sentence strings, and the list of
        (start, end) character offsets into ``source_text`` they were sliced from.
    """
    spans = [*sent_detector.span_tokenize(source_text)]
    return [source_text[lo:hi] for lo, hi in spans], spans
|
||
|
||
def text_rank(sentences, model_name='all-mpnet-base-v2'):
    """Build the normalized sentence-similarity graph for ``sentences``.

    The sentences are embedded with ``model_name``, and their pairwise cosine
    similarities form the edge weights. Self-loops and negative similarities
    are zeroed. The result is returned symmetrically normalized (a torch
    tensor, see ``normalized_adjacency``).
    """
    embeddings = _get_model(model_name).encode(sentences)
    similarity = torch.tensor(cos_sim(embeddings))
    similarity.fill_diagonal_(0.)                 # no self-loops
    similarity.masked_fill_(similarity < 0, 0)    # drop negative edges
    return normalized_adjacency(similarity)
|
||
|
||
def extract(source_text, model_name='all-mpnet-base-v2'):
    """
    Main API function that extracts sentence positions and computes normalized adjacency matrix.

    Returns:
        sentence_ranges: List of (start, end) tuples for each sentence's character position
        adjacency: (N x N) normalized adjacency matrix (torch tensor) where N is the number
            of sentences. Each entry (i,j) represents the normalized similarity between
            sentences i and j. This matrix is returned to the frontend, which raises it to
            a power and computes the final salience scores via random walk simulation.

    Text with no sentences (e.g. empty or whitespace-only) yields ([], 0x0 tensor)
    without loading or running the model. Without this shortcut, that input failed
    inside text_rank, because the similarity tensor is not 2-D.
    """
    sentences, sentence_ranges = get_sentences(source_text)
    if not sentences:
        return sentence_ranges, torch.zeros((0, 0))
    adjacency = text_rank(sentences, model_name)
    return sentence_ranges, adjacency
|
||
|
||
|
||
# =============================================================================
|
||
# Worker Thread for Model Inference
|
||
# =============================================================================
|
||
# All model inference runs in a dedicated worker thread. This:
|
||
# 1. Avoids fork() issues with Metal/MPS (no forking server needed)
|
||
# 2. Serializes inference requests (one at a time)
|
||
# 3. Keeps /stats and other endpoints responsive
|
||
|
||
@dataclass
class WorkItem:
    """A queued inference request together with the slots its outcome is written to.

    The submitting thread waits on ``event``. The worker stores either
    ``result`` (on success) or ``error`` (on failure) and then sets ``event``.
    """

    # Inputs supplied by the caller.
    source_text: str
    model_name: str
    # Set by the worker once the request has been handled.
    event: threading.Event = field(default_factory=threading.Event)
    # Output slots; both stay None until the worker fills one in.
    result: Optional[tuple] = field(default=None)
    error: Optional[str] = field(default=None)
|
||
|
||
_work_queue: queue.Queue[WorkItem] = queue.Queue()


def _model_worker():
    """Worker thread loop - processes inference requests from queue.

    Runs forever. Each WorkItem is handled one at a time: ``result`` is set on
    success, or ``error`` (always a non-empty string) on failure, and then
    ``event`` is set to wake the waiting submitter.
    """
    import traceback  # local import keeps this block self-contained

    while True:
        item = _work_queue.get()
        try:
            item.result = extract(item.source_text, item.model_name)
        except Exception as e:
            # str(e) is empty for exceptions raised without a message (e.g.
            # `raise RuntimeError()`). An empty error would be read as success
            # by a truthiness check, so fall back to the exception type's name.
            item.error = str(e) or type(e).__name__
            # Keep the traceback visible in the server log; callers only get the message.
            traceback.print_exc()
        finally:
            item.event.set()


# Start worker thread
threading.Thread(target=_model_worker, daemon=True, name="model-worker").start()
|
||
|
||
def submit_work(source_text: str, model_name: str, timeout: float = 60.0) -> tuple:
    """Submit text for salience extraction and wait for result.

    Args:
        source_text: Text to analyze
        model_name: Name of the model to use (must be in AVAILABLE_MODELS)
        timeout: Max seconds to wait for result. The clock starts when the item
            is queued, so time spent waiting behind earlier requests counts too.

    Returns:
        (sentence_ranges, adjacency) tuple

    Raises:
        TimeoutError: If inference takes longer than timeout. The request is
            not withdrawn: it stays queued and the worker still processes it.
        RuntimeError: If inference fails
    """
    item = WorkItem(source_text=source_text, model_name=model_name)
    _work_queue.put(item)

    if not item.event.wait(timeout=timeout):
        raise TimeoutError("Model inference timed out")

    # Compare with None rather than truthiness: a failure recorded with an
    # empty message must not be mistaken for success (result would be None).
    if item.error is not None:
        raise RuntimeError(item.error)

    return item.result
|
||
|
||
|
||
# =============================================================================
|
||
# Unused/Debugging Code
|
||
# =============================================================================
|
||
|
||
def terminal_distr(adjacency, initial=None, steps=10):
    """Propagate a score vector ``steps`` times through ``adjacency`` (debugging helper).

    Computes ``initial @ matrix_power(adjacency, steps)`` and returns it as a
    plain Python list.

    Args:
        adjacency: square torch tensor, e.g. the output of ``normalized_adjacency``.
        initial: optional 1-D torch tensor of starting scores; defaults to all ones.
        steps: matrix power to apply (default 10, the previously hard-coded value).

    Returns:
        list of floats, one per row/column of ``adjacency``.
    """
    if initial is None:
        # Match adjacency's dtype/device: a fixed float32 default made matmul
        # fail with a dtype mismatch for float64 adjacency matrices.
        initial = torch.ones(adjacency.shape[0], dtype=adjacency.dtype, device=adjacency.device)
    propagated = initial.matmul(torch.matrix_power(adjacency, steps))
    return propagated.cpu().numpy().tolist()
|
||
|
||
def get_results(sentences, adjacency):
    """Debugging helper: print the sentences scoring above 1.1, lowest score first.

    Scores come from ``terminal_distr(adjacency)``. Each printed line has the
    form ``<score to 2 decimals>: <sentence>``.
    """
    ranked = sorted(zip(terminal_distr(adjacency), sentences), key=lambda pair: pair[0])
    salient = [(score, sentence) for score, sentence in ranked if score > 1.1]
    for score, sentence in salient:
        print('{:0.2f}: {}'.format(score, sentence))
|