salience-editor/api/salience/salience.py

127 lines
4.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os

# Default cache locations must be configured BEFORE the libraries that use
# them are imported. Values already present in the environment win.
# PROJECT_DIR is the directory two levels above this file (the parent of the
# package directory that contains it).
PROJECT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

# Hugging Face cache: <PROJECT_DIR>/cache-huggingface unless HF_HOME is set.
os.environ.setdefault('HF_HOME', os.path.join(PROJECT_DIR, 'cache-huggingface'))

# NLTK data: <PROJECT_DIR>/cache-nltk unless NLTK_DATA is set. The directory
# is created eagerly so it already exists when nltk.download() runs.
if 'NLTK_DATA' not in os.environ:
    nltk_data_path = os.path.join(PROJECT_DIR, 'cache-nltk')
    os.makedirs(nltk_data_path, exist_ok=True)
    os.environ['NLTK_DATA'] = nltk_data_path
# Heavy third-party imports. They must run after the cache environment
# variables above have been set. Each one is wrapped in the project's
# timed_import context manager, which presumably reports how long the import
# took (NOTE(review): confirm in salience/timed_import.py).
from salience.timed_import import timed_import
with timed_import("import numpy as np"):
    import numpy as np
with timed_import("import torch"):
    import torch
with timed_import("from sentence_transformers import SentenceTransformer"):
    from sentence_transformers import SentenceTransformer
with timed_import("import nltk"):
    import nltk.data
    import nltk
# Download punkt_tab to the configured location
# Using punkt_tab (the modern tab-separated format introduced in NLTK 3.8+)
# instead of the older punkt pickle format
# The punkt_tab model version depends on the NLTK Python package version
# Check your NLTK version with: uv pip show nltk
# NOTE(review): this call runs on every import of this module; the data is
# expected to land in the NLTK_DATA directory configured above — confirm.
nltk.download('punkt_tab')
# Available models for the demo
AVAILABLE_MODELS = {
'all-mpnet-base-v2': 'all-mpnet-base-v2', # Dec 2020
'gte-large-en-v1.5': 'Alibaba-NLP/gte-large-en-v1.5', # Jan 2024
# 'qwen3-embedding-4b': 'Qwen/Qwen3-Embedding-4B', # April 2025
'mxbai-embed-large-v1': 'mixedbread-ai/mxbai-embed-large-v1',
}
# Benchmark scores by model on clustering tasks:
# all-mpnet-base-v2: 40.03
# mixedbread-ai/mxbai-embed-large-v1: 46.71
# gte-large-en-v1.5: 47.95
# Qwen/Qwen3-Embedding-0.6B: 52.33
# Qwen/Qwen3-Embedding-4B: 57.15
# Benchmark scores by model on STS (semantic textual similarity) tasks:
# all-mpnet-base-v2: 58.17
# gte-large-en-v1.5: 81.43
# Qwen/Qwen3-Embedding-0.6B: 76.17
# Qwen/Qwen3-Embedding-4B: 80.86
# mixedbread-ai/mxbai-embed-large-v1: 85.00
# Load every model listed in AVAILABLE_MODELS into memory once, at import
# time. Driving the loader from AVAILABLE_MODELS keeps the two in sync: any
# model the demo advertises is guaranteed to be present in `models`.

# Model ids loaded with trust_remote_code=True, which lets SentenceTransformer
# run code shipped in the model's repository. Keep this list minimal.
_TRUST_REMOTE_CODE = {
    'Alibaba-NLP/gte-large-en-v1.5',
    'Qwen/Qwen3-Embedding-4B',
}


def _load_models():
    """Instantiate a SentenceTransformer for every entry of AVAILABLE_MODELS.

    Returns:
        dict mapping each AVAILABLE_MODELS key to its loaded model, in the
        same order as AVAILABLE_MODELS.
    """
    loaded = {}
    for name, model_id in AVAILABLE_MODELS.items():
        print(f"Loading {model_id}")
        loaded[name] = SentenceTransformer(
            model_id, trust_remote_code=model_id in _TRUST_REMOTE_CODE
        )
    return loaded


print("Loading sentence transformer models...")
models = _load_models()
print("All models loaded!")
# Sentence splitter backed by the punkt_tab resource that nltk.download()
# fetches above. The previous nltk.data.load('tokenizers/punkt/english.pickle')
# needed the separate legacy `punkt` pickle package, which this module never
# downloads and which NLTK 3.9+ no longer unpickles. PunktTokenizer (NLTK
# 3.8.2+) reads the tab-separated punkt_tab model instead.
sent_detector = nltk.tokenize.PunktTokenizer('english')
def cos_sim(a):
    """Return the matrix of pairwise cosine similarities between the rows of ``a``.

    Entry (i, j) is ``a_i . a_j / (||a_i|| * ||a_j||)``. The Gram matrix is
    computed first and then divided in place, so the result keeps the dtype
    of ``a @ a.T``. A row with zero norm yields NaN in its row and column.

    Args:
        a: 2-D NumPy array with one vector per row (text_rank passes the
            sentence embeddings).

    Returns:
        Square NumPy array of shape (len(a), len(a)).
    """
    row_norms = np.linalg.norm(a, axis=-1, keepdims=True)
    similarity = a @ a.T
    # Scale by 1/||a_i|| across rows, then by 1/||a_j|| across columns.
    for scale in (row_norms, row_norms.T):
        similarity /= scale
    return similarity
def degree_power(A, k):
    """Return the diagonal matrix of node degrees of ``A`` raised to the power ``k``.

    The degree of node i is the i-th row sum of ``A``. For negative ``k`` a
    zero-degree node would give ``0 ** k = inf``. Those entries are set to 0
    instead, so an isolated node (e.g. a one-sentence document, or a sentence
    with no positive similarity to any other) produces a zero row and column
    in the normalized adjacency rather than NaNs from ``inf * 0``.

    Args:
        A: square matrix (NumPy array or torch tensor) of non-negative edge
            weights.
        k: exponent applied to each degree (``-0.5`` for symmetric
            normalization).

    Returns:
        NumPy array: diagonal matrix with the powered degrees on the diagonal.
    """
    # 0 ** negative raises a divide-by-zero warning; the resulting infs are
    # handled explicitly on the next line.
    with np.errstate(divide='ignore'):
        degrees = np.power(np.array(A.sum(1)), k).ravel()
    degrees[np.isinf(degrees)] = 0.0
    return np.diag(degrees)
def normalized_adjacency(A):
    """Symmetrically normalize ``A`` as ``D^-1/2 A D^-1/2``.

    ``D`` is the diagonal degree matrix of ``A`` (row sums), built by
    degree_power.

    Args:
        A: square weight matrix (text_rank passes a torch tensor).

    Returns:
        torch.Tensor wrapping the NumPy result of the two matrix products.
    """
    d_inv_sqrt = degree_power(A, -0.5)
    left_scaled = d_inv_sqrt.dot(A)  # D^-1/2 A
    symmetric = left_scaled.dot(d_inv_sqrt)  # (D^-1/2 A) D^-1/2
    return torch.from_numpy(symmetric)
def get_sentences(source_text):
    """Split ``source_text`` into sentences using the module's Punkt tokenizer.

    Returns:
        A ``(sentences, sentence_ranges)`` tuple: the sentence strings, and
        the list of (start, end) character offsets they were sliced from.
    """
    spans = [*sent_detector.span_tokenize(source_text)]
    texts = [source_text[begin:end] for begin, end in spans]
    return texts, spans
def text_rank(sentences, model_name='all-mpnet-base-v2'):
    """Build the normalized sentence-similarity graph for ``sentences``.

    Despite the name, this does not compute ranks itself. It embeds the
    sentences, takes their pairwise cosine similarities, removes self-loops
    and negative similarities, and returns the symmetrically normalized
    adjacency matrix ``D^-1/2 A D^-1/2``.

    Args:
        sentences: list of sentence strings.
        model_name: key into ``models`` selecting the embedding model.

    Returns:
        torch.Tensor of shape (N, N), where N = len(sentences). An empty
        input yields a (0, 0) tensor.

    Raises:
        KeyError: if ``model_name`` is not a loaded model.
    """
    # Look the model up first so an unknown name still raises, even for
    # empty input.
    model = models[model_name]
    if not sentences:
        # Nothing to embed. The similarity path below also cannot handle
        # this case, because fill_diagonal_ requires a 2-D tensor.
        return torch.zeros((0, 0))
    vectors = model.encode(sentences)
    # Zero the diagonal: a sentence is not an edge to itself.
    adjacency = torch.tensor(cos_sim(vectors)).fill_diagonal_(0.)
    # Only positive similarities count as edge weights.
    adjacency[adjacency < 0] = 0
    return normalized_adjacency(adjacency)
def extract(source_text, model_name='all-mpnet-base-v2'):
    """Split ``source_text`` into sentences and build their similarity graph.

    This is the main API entry point. The frontend receives the matrix,
    raises it to a power and computes the final salience scores from that
    random walk.

    Args:
        source_text: the document to analyse.
        model_name: which loaded embedding model to use.

    Returns:
        A ``(sentence_ranges, adjacency)`` tuple:
          sentence_ranges -- list of (start, end) character offsets, one per
            sentence.
          adjacency -- N x N normalized adjacency matrix (torch tensor), where
            N is the number of sentences; entry (i, j) is the normalized
            similarity between sentences i and j.
    """
    sentences, spans = get_sentences(source_text)
    return spans, text_rank(sentences, model_name)
# =============================================================================
# Unused/Debugging Code
# =============================================================================
def terminal_distr(adjacency, initial=None, *, steps=10):
    """Propagate a score vector through ``adjacency`` for ``steps`` steps.

    Computes ``initial @ matrix_power(adjacency, steps)``. This is a debugging
    helper and is not called by extract().

    Args:
        adjacency: (N, N) torch tensor, e.g. the output of text_rank().
        initial: optional length-N starting vector; defaults to all ones.
            It is cast to ``adjacency``'s dtype.
        steps: number of propagation steps (the matrix power). Defaults to 10.

    Returns:
        list of N Python floats.
    """
    n = adjacency.shape[0]
    if initial is None:
        sample = torch.ones(n, dtype=adjacency.dtype)
    else:
        # torch.matmul does not promote mixed dtypes (e.g. float32 vs float64),
        # so align the start vector with the matrix.
        sample = initial.to(adjacency.dtype)
    return sample.matmul(torch.matrix_power(adjacency, steps)).numpy().tolist()
def get_results(sentences, adjacency):
    """Print each sentence whose terminal score exceeds 1.1, lowest score first.

    Debugging helper: scores come from terminal_distr(adjacency) with its
    default start vector, and each line is printed as ``"<score>: <sentence>"``.
    """
    ranked = sorted(zip(terminal_distr(adjacency), sentences), key=lambda pair: pair[0])
    salient = [(score, text) for score, text in ranked if score > 1.1]
    for score, text in salient:
        print('{:0.2f}: {}'.format(score, text))