# salience-editor/api/salience/salience.py


import os
# Set default cache locations BEFORE importing libraries that use them
PROJECT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

if 'NLTK_DATA' not in os.environ:
    nltk_data_path = os.path.join(PROJECT_DIR, 'cache-nltk')
    os.makedirs(nltk_data_path, exist_ok=True)
    os.environ['NLTK_DATA'] = nltk_data_path

if 'HF_HOME' not in os.environ:
    os.environ['HF_HOME'] = os.path.join(PROJECT_DIR, 'cache-huggingface')

from salience.timed_import import timed_import
with timed_import("import numpy as np"):
import numpy as np
with timed_import("import torch"):
import torch
with timed_import("from sentence_transformers import SentenceTransformer"):
from sentence_transformers import SentenceTransformer
with timed_import("import nltk"):
import nltk.data
import nltk

# Download punkt_tab to the configured location.
# punkt_tab is the modern tab-separated format introduced in NLTK 3.8+,
# replacing the older punkt pickle format. The punkt_tab model version
# depends on the NLTK Python package version; check yours with: uv pip show nltk
nltk.download('punkt_tab')

# Available models for the demo
AVAILABLE_MODELS = {
    'all-mpnet-base-v2': 'all-mpnet-base-v2',                      # Dec 2020
    'gte-large-en-v1.5': 'Alibaba-NLP/gte-large-en-v1.5',          # Jan 2024
    # 'qwen3-embedding-4b': 'Qwen/Qwen3-Embedding-4B',             # April 2025
    'mxbai-embed-large-v1': 'mixedbread-ai/mxbai-embed-large-v1',
}

# Benchmark scores for reference:
#
# On clustering:
#   all-mpnet-base-v2:                  40.03
#   mixedbread-ai/mxbai-embed-large-v1: 46.71
#   gte-large-en-v1.5:                  47.95
#   Qwen/Qwen3-Embedding-0.6B:          52.33
#   Qwen/Qwen3-Embedding-4B:            57.15
#
# On STS:
#   all-mpnet-base-v2:                  58.17
#   gte-large-en-v1.5:                  81.43
#   Qwen/Qwen3-Embedding-0.6B:          76.17
#   Qwen/Qwen3-Embedding-4B:            80.86
#   mixedbread-ai/mxbai-embed-large-v1: 85.00

# Load all models into memory
print("Loading sentence transformer models...")
models = {}

models['all-mpnet-base-v2'] = SentenceTransformer('all-mpnet-base-v2')

print("Loading Alibaba-NLP/gte-large-en-v1.5")
models['gte-large-en-v1.5'] = SentenceTransformer('Alibaba-NLP/gte-large-en-v1.5', trust_remote_code=True)

# print("Loading Qwen/Qwen3-Embedding-4B")
# models['qwen3-embedding-4b'] = SentenceTransformer('Qwen/Qwen3-Embedding-4B', trust_remote_code=True)

print("Loading mixedbread-ai/mxbai-embed-large-v1")
models['mxbai-embed-large-v1'] = SentenceTransformer('mixedbread-ai/mxbai-embed-large-v1')

print("All models loaded!")

# Load the English sentence tokenizer. The old 'tokenizers/punkt/english.pickle'
# path requires the legacy 'punkt' download, but only punkt_tab is downloaded
# above, so use the punkt_tab-backed PunktTokenizer instead.
sent_detector = PunktTokenizer('english')


def cos_sim(a):
    """Pairwise cosine similarity between the rows of a: (N, D) -> (N, N)."""
    sims = a @ a.T
    a_norm = np.linalg.norm(a, axis=-1, keepdims=True)
    sims /= a_norm
    sims /= a_norm.T
    return sims


def degree_power(A, k):
    """Diagonal matrix of node degrees raised to the power k."""
    degrees = np.power(np.array(A.sum(1)), k).ravel()
    return np.diag(degrees)


def normalized_adjacency(A):
    """Symmetrically normalize the adjacency matrix: D^(-1/2) A D^(-1/2).

    A may be a CPU torch tensor; numpy's dot coerces it to an ndarray.
    """
    normalized_D = degree_power(A, -0.5)
    return torch.from_numpy(normalized_D.dot(A).dot(normalized_D))
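
# Tiny worked example (illustrative only): for A = [[0, 2], [2, 0]] both
# degrees are 2, so D^(-1/2) = diag(1/sqrt(2)) and
#   D^(-1/2) A D^(-1/2) = A / 2 = [[0, 1], [1, 0]],
# i.e. each edge weight is rescaled by the degrees of both of its endpoints.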


def get_sentences(source_text):
    """Split source_text into sentences; return the sentence strings together
    with their (start, end) character offsets."""
    sentence_ranges = list(sent_detector.span_tokenize(source_text))
    sentences = [source_text[start:end] for start, end in sentence_ranges]
    return sentences, sentence_ranges
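
# Example with a made-up input, for illustration:
#   get_sentences("Hello world. Bye.")
#   -> (['Hello world.', 'Bye.'], [(0, 12), (13, 17)])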


def text_rank(sentences, model_name='all-mpnet-base-v2'):
    """Embed the sentences and build a TextRank-style similarity graph,
    returning its symmetrically normalized adjacency matrix."""
    model = models[model_name]
    vectors = model.encode(sentences)

    # Cosine-similarity graph with self-loops removed and negative
    # similarities clipped to zero.
    adjacency = torch.tensor(cos_sim(vectors)).fill_diagonal_(0.)
    adjacency[adjacency < 0] = 0
    return normalized_adjacency(adjacency)


def extract(source_text, model_name='all-mpnet-base-v2'):
    """
    Main API function: extracts sentence positions and computes the normalized
    adjacency matrix.

    Returns:
        sentence_ranges: list of (start, end) tuples giving each sentence's
            character position in source_text.
        adjacency: (N x N) normalized adjacency matrix, where N is the number
            of sentences. Entry (i, j) is the normalized similarity between
            sentences i and j. This matrix is returned to the frontend, which
            raises it to a power and computes the final salience scores via
            random-walk simulation.
    """
    sentences, sentence_ranges = get_sentences(source_text)
    adjacency = text_rank(sentences, model_name)
    return sentence_ranges, adjacency
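
# Sketch of the frontend-side step described in the docstring above
# (hypothetical; the real implementation lives in the frontend, not here):
#
#   scores = torch.ones(N) @ torch.matrix_power(adjacency, k)
#
# where N is the number of sentences, i.e. k steps of a random walk started
# uniformly over sentences; terminal_distr() below does the same thing
# server-side with k fixed at 10.
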
# =============================================================================
# Unused/Debugging Code
# =============================================================================
def terminal_distr(adjacency, initial=None):
    """Approximate terminal distribution of a 10-step random walk on the
    sentence graph, starting from `initial` (uniform by default)."""
    # Match the adjacency dtype (float64 after torch.from_numpy); a plain
    # torch.full(..., 1.) would be float32 and make matmul raise a dtype error.
    sample = initial if initial is not None else torch.full((adjacency.shape[0],), 1., dtype=adjacency.dtype)
    scores = sample.matmul(torch.matrix_power(adjacency, 10)).numpy().tolist()
    return scores


def get_results(sentences, adjacency):
    """Print sentences whose salience score exceeds 1.1, lowest-scoring first."""
    scores = terminal_distr(adjacency)
    for score, sentence in sorted(zip(scores, sentences), key=lambda xs: xs[0]):
        if score > 1.1:
            print('{:0.2f}: {}'.format(score, sentence))
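

# Minimal smoke test, guarded so it doesn't run on import. The sample text is
# made up for illustration; any plain-English paragraph works.
if __name__ == '__main__':
    sample = (
        "Graph-based ranking treats each sentence as a node. "
        "Edges between nodes are weighted by embedding similarity. "
        "Sentences central to the graph accumulate high scores. "
        "The weather was nice yesterday."
    )
    sentences, _ = get_sentences(sample)
    adjacency = text_rank(sentences)
    get_results(sentences, adjacency)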