# salience-editor/api/salience/salience.py

import os

import nltk
import nltk.data
import numpy as np
import torch
from nltk.tokenize import PunktTokenizer
from sentence_transformers import SentenceTransformer

# Keep NLTK data inside the project directory
PROJECT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
NLTK_DATA_DIR = os.path.join(PROJECT_DIR, 'nltk_data')

# Add the project directory to the front of NLTK's search path
nltk.data.path.insert(0, NLTK_DATA_DIR)

# Download the sentence tokenizer data to the custom location. punkt_tab is
# the modern tab-separated format (introduced in NLTK 3.9) that replaces the
# older punkt pickle format. The punkt_tab model version depends on the NLTK
# Python package version; check yours with: uv pip show nltk
nltk.download('punkt_tab', download_dir=NLTK_DATA_DIR)
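
# Optional sanity check (not in the original pipeline; the resource path is
# the one NLTK itself reports for English punkt_tab): uncomment to fail fast
# with a LookupError if the download did not land on the search path.
# nltk.data.find('tokenizers/punkt_tab/english/')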

# Available models for the demo (display name -> Hugging Face model id)
AVAILABLE_MODELS = {
    'all-mpnet-base-v2': 'all-mpnet-base-v2',  # Dec 2020
    'gte-large-en-v1.5': 'Alibaba-NLP/gte-large-en-v1.5',  # Jan 2024
    # 'qwen3-embedding-4b': 'Qwen/Qwen3-Embedding-4B',  # April 2025
    'mxbai-embed-large-v1': 'mixedbread-ai/mxbai-embed-large-v1',
}

# Benchmark scores noted during model selection:
# On clustering
#   mixedbread-ai/mxbai-embed-large-v1: 46.71
#   gte-large-en-v1.5: 47.95
#   Qwen/Qwen3-Embedding-0.6B: 52.33
#   Qwen/Qwen3-Embedding-4B: 57.15
# On STS
#   gte-large-en-v1.5: 81.43
#   Qwen/Qwen3-Embedding-0.6B: 76.17
#   Qwen/Qwen3-Embedding-4B: 80.86
#   mixedbread-ai/mxbai-embed-large-v1: 85.00

# Load all models into memory at import time
print("Loading sentence transformer models...")
models = {}
models['all-mpnet-base-v2'] = SentenceTransformer('all-mpnet-base-v2')
print("Loading Alibaba-NLP/gte-large-en-v1.5")
models['gte-large-en-v1.5'] = SentenceTransformer('Alibaba-NLP/gte-large-en-v1.5', trust_remote_code=True)
# print("Loading Qwen/Qwen3-Embedding-4B")
# models['qwen3-embedding-4b'] = SentenceTransformer('Qwen/Qwen3-Embedding-4B', trust_remote_code=True)
print("Loading mixedbread-ai/mxbai-embed-large-v1")
models['mxbai-embed-large-v1'] = SentenceTransformer('mixedbread-ai/mxbai-embed-large-v1')
print("All models loaded!")

# Sentence splitter backed by the punkt_tab data downloaded above.
# PunktTokenizer (added in NLTK 3.9 alongside punkt_tab) replaces
# nltk.data.load('tokenizers/punkt/english.pickle'), which expects the old
# pickle format that this file no longer downloads.
sent_detector = PunktTokenizer('english')


def cos_sim(a, b):
    """Pairwise cosine similarity between the rows of a and the rows of b."""
    sims = a @ b.T
    a_norm = np.linalg.norm(a, axis=-1)
    b_norm = np.linalg.norm(b, axis=-1)
    # Divide row i by ||a_i|| and column j by ||b_j||
    a_normalized = (sims.T / a_norm).T
    sims = a_normalized / b_norm
    return sims
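
# Illustrative doctest-style check for cos_sim (an addition, not from the
# original file): orthogonal unit vectors give the identity matrix.
#   >>> v = np.eye(2)
#   >>> np.allclose(cos_sim(v, v), np.eye(2))
#   True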


def degree_power(A, k):
    """Return the diagonal degree matrix of A, with degrees raised to the power k."""
    degrees = np.power(np.array(A.sum(1)), k).ravel()
    D = np.diag(degrees)
    return D


def normalized_adjacency(A):
    """Symmetrically normalize A as D^(-1/2) @ A @ D^(-1/2)."""
    normalized_D = degree_power(A, -0.5)
    # A may arrive as a torch tensor; convert explicitly for the numpy dots
    return torch.from_numpy(normalized_D.dot(np.asarray(A)).dot(normalized_D))


def get_sentences(source_text):
    """Split source_text into sentences, keeping character offsets."""
    sentence_ranges = list(sent_detector.span_tokenize(source_text))
    sentences = [source_text[start:end] for start, end in sentence_ranges]
    return sentences, sentence_ranges


def text_rank(sentences, model_name='all-mpnet-base-v2'):
    """Build the normalized sentence-similarity graph for a TextRank-style walk."""
    model = models[model_name]
    vectors = model.encode(sentences)
    # Cosine-similarity adjacency matrix with self-loops removed
    adjacency = torch.tensor(cos_sim(vectors, vectors)).fill_diagonal_(0.)
    # Clamp negative similarities so edge weights stay non-negative
    adjacency[adjacency < 0] = 0
    return normalized_adjacency(adjacency)


def extract(source_text, model_name='all-mpnet-base-v2'):
    """
    Main API function: extracts sentence positions and computes the
    normalized adjacency matrix.

    Returns:
        sentence_ranges: list of (start, end) tuples giving each sentence's
            character position in source_text
        adjacency: (N × N) normalized adjacency matrix, where N is the
            number of sentences. Entry (i, j) is the normalized similarity
            between sentences i and j. This matrix is returned to the
            frontend, which raises it to a power and computes the final
            salience scores via random walk simulation.
    """
    sentences, sentence_ranges = get_sentences(source_text)
    adjacency = text_rank(sentences, model_name)
    return sentence_ranges, adjacency
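
# A minimal sketch of the consumer side (assumed, for illustration; the real
# frontend runs this walk in the browser). It mirrors terminal_distr() below,
# including its walk length of 10:
#   ranges, A = extract(text)
#   scores = torch.full((A.shape[0],), 1.).matmul(torch.matrix_power(A, 10))
#   # scores[i] is the random-walk mass accumulated by sentence i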


# =============================================================================
# Unused/Debugging Code
# =============================================================================

def terminal_distr(adjacency, initial=None):
    """Score each sentence by simulating a 10-step random walk on the graph."""
    # Start every sentence with a score of 1 unless an initial vector is given
    sample = initial if initial is not None else torch.full((adjacency.shape[0],), 1.)
    scores = sample.matmul(torch.matrix_power(adjacency, 10)).numpy().tolist()
    return scores


def get_results(sentences, adjacency):
    """Debug helper: print sentences whose salience exceeds a fixed threshold."""
    scores = terminal_distr(adjacency)
    for score, sentence in sorted(zip(scores, sentences), key=lambda xs: xs[0]):
        if score > 1.1:
            print('{:0.2f}: {}'.format(score, sentence))
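

if __name__ == '__main__':
    # Minimal smoke test (an addition for illustration; the sample text is
    # made up and the default model is assumed).
    demo_text = (
        "The cat sat on the mat. Cats are popular pets. "
        "The stock market fell sharply today."
    )
    sentences, _ = get_sentences(demo_text)
    _, adjacency = extract(demo_text)
    get_results(sentences, adjacency)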