2025-10-30 14:16:04 -07:00
|
|
|
import numpy as np
|
|
|
|
|
import torch
|
|
|
|
|
from sentence_transformers import SentenceTransformer
|
|
|
|
|
import nltk.data
|
|
|
|
|
import nltk
|
|
|
|
|
import os
|
|
|
|
|
|
|
|
|
|
# Keep NLTK data inside the project: <project root>/nltk_data, where the
# project root is the parent of the directory that contains this file.
PROJECT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
NLTK_DATA_DIR = os.path.join(PROJECT_DIR, 'nltk_data')

# Search the project-local directory before NLTK's default locations.
nltk.data.path.insert(0, NLTK_DATA_DIR)

# Sentence splitting uses punkt_tab, the tab-separated Punkt model format used
# by newer NLTK releases instead of the older pickled punkt model.
# The punkt_tab model version depends on the installed NLTK package version;
# check it with: uv pip show nltk
# Only download when the model cannot already be found, so importing this
# module does not contact the NLTK download server on every run (and works
# offline once the data is present).
try:
    nltk.data.find('tokenizers/punkt_tab/english/')
except LookupError:
    nltk.download('punkt_tab', download_dir=NLTK_DATA_DIR)
|
|
|
|
|
|
|
|
|
|
# Embedding models offered by the demo, as (short name, model id) pairs in
# display order. The model id is what SentenceTransformer() is given.
_MODEL_CHOICES = (
    ('all-mpnet-base-v2', 'all-mpnet-base-v2'),                     # Dec 2020
    ('gte-large-en-v1.5', 'Alibaba-NLP/gte-large-en-v1.5'),         # Jan 2024
    # ('qwen3-embedding-4b', 'Qwen/Qwen3-Embedding-4B'),            # April 2025
    ('mxbai-embed-large-v1', 'mixedbread-ai/mxbai-embed-large-v1'),
)

# Short name -> model id, preserving the order above.
AVAILABLE_MODELS = dict(_MODEL_CHOICES)
|
|
|
|
|
|
|
|
|
|
# Reported clustering benchmark scores (higher is better):
|
|
|
|
|
# mixedbread-ai/mxbai-embed-large-v1: 46.71
|
|
|
|
|
# gte-large-en-v1.5: 47.95
|
|
|
|
|
# Qwen/Qwen3-Embedding-0.6B: 52.33
|
|
|
|
|
# Qwen/Qwen3-Embedding-4B: 57.15
|
|
|
|
|
|
|
|
|
|
# Reported STS (semantic textual similarity) benchmark scores (higher is better):
|
|
|
|
|
# gte-large-en-v1.5: 81.43
|
|
|
|
|
# Qwen/Qwen3-Embedding-0.6B: 76.17
|
|
|
|
|
# Qwen/Qwen3-Embedding-4B: 80.86
|
|
|
|
|
# mixedbread-ai/mxbai-embed-large-v1: 85.00
|
|
|
|
|
|
|
|
|
|
# Load every model listed in AVAILABLE_MODELS into memory once, at import
# time, keyed by the same short names so text_rank(model_name=...) can look
# them up. Driving this from AVAILABLE_MODELS keeps the two in sync: a model
# added there is loaded here automatically.
print("Loading sentence transformer models...")

# Short names whose model repos need trust_remote_code=True.
# If 'qwen3-embedding-4b' is re-enabled in AVAILABLE_MODELS, add it here too
# (its original load call passed trust_remote_code=True).
_TRUST_REMOTE_CODE = {'gte-large-en-v1.5'}

models = {}
for _name, _model_id in AVAILABLE_MODELS.items():
    print(f"Loading {_model_id}")
    models[_name] = SentenceTransformer(
        _model_id, trust_remote_code=_name in _TRUST_REMOTE_CODE
    )

print("All models loaded!")
|
2025-10-30 14:16:04 -07:00
|
|
|
|
|
|
|
|
# Sentence splitter used by get_sentences(). Only the punkt_tab model is
# downloaded into NLTK_DATA_DIR, so load it via PunktTokenizer rather than the
# legacy pickled model at 'tokenizers/punkt/english.pickle', which is
# generally not present (and pickle loading is a security concern).
try:
    from nltk.tokenize.punkt import PunktTokenizer
except ImportError:
    # Older NLTK without the punkt_tab loader: fall back to the pickled model,
    # which must then be installed separately.
    sent_detector = nltk.data.load('tokenizers/punkt/english.pickle')
else:
    sent_detector = PunktTokenizer('english')
|
|
|
|
|
|
|
|
|
|
def cos_sim(a, b):
    """Return the cosine similarity between the rows of ``a`` and ``b``.

    For 2-D arrays of shape ``(n, d)`` and ``(m, d)`` the result has shape
    ``(n, m)``; entry ``[i, j]`` is ``a[i] . b[j] / (||a[i]|| * ||b[j]||)``.
    Norms are taken along the last axis. Zero-norm rows are not guarded and
    produce ``nan`` entries.
    """
    dot_products = a @ b.T
    # Divide row i by ||a[i]||: transpose so a's axis is last and broadcasts.
    row_scaled = (dot_products.T / np.linalg.norm(a, axis=-1).T).T
    # Then divide column j by ||b[j]||.
    return row_scaled / np.linalg.norm(b, axis=-1)
|
|
|
|
|
|
|
|
|
|
def degree_power(A, k):
    """Return the dense diagonal matrix of ``degree ** k`` for adjacency ``A``.

    Parameters
    ----------
    A : numpy.ndarray or torch.Tensor
        Adjacency matrix (anything with ``.sum(1)`` that ``np.array`` can
        convert, e.g. a CPU tensor). Row sums are used as node degrees.
    k : float
        Exponent applied element-wise to the degrees; ``normalized_adjacency``
        uses ``-0.5``.

    Returns
    -------
    numpy.ndarray
        Diagonal matrix with ``degree ** k`` on the diagonal. Nodes with zero
        degree get ``0`` (not ``inf``) when ``k`` is negative.
    """
    # 0 ** negative -> inf; that case is handled explicitly below.
    with np.errstate(divide='ignore'):
        degrees = np.power(np.array(A.sum(1)), k).ravel()
    # An isolated node (zero row sum, e.g. a single-sentence input) would
    # otherwise put inf on the diagonal and turn into nan via inf * 0 in D·A·D.
    degrees[np.isinf(degrees)] = 0.0
    return np.diag(degrees)
|
|
|
|
|
|
|
|
|
|
def normalized_adjacency(A):
    """Symmetrically normalise ``A`` as ``D^(-1/2) . A . D^(-1/2)``.

    ``D^(-1/2)`` comes from ``degree_power(A, -0.5)``. The products are
    computed with NumPy ``.dot`` and the result is wrapped with
    ``torch.from_numpy`` (sharing memory with the NumPy array).
    """
    inv_sqrt_degree = degree_power(A, -0.5)
    left_scaled = inv_sqrt_degree.dot(A)
    return torch.from_numpy(left_scaled.dot(inv_sqrt_degree))
|
|
|
|
|
|
|
|
|
|
def get_sentences(source_text):
    """Split ``source_text`` into sentences with the module-level ``sent_detector``.

    Returns ``(sentences, sentence_ranges)``: ``sentence_ranges`` is the list
    of ``(start, end)`` character offsets from ``span_tokenize``, and
    ``sentences`` holds the matching ``source_text[start:end]`` substrings.
    """
    spans = list(sent_detector.span_tokenize(source_text))
    return [source_text[lo:hi] for lo, hi in spans], spans
|
|
|
|
|
|
2025-10-30 16:26:48 -07:00
|
|
|
def text_rank(sentences, model_name='all-mpnet-base-v2'):
    """Build the normalised sentence-similarity graph used for scoring.

    Each sentence is embedded with ``models[model_name]``. Pairwise cosine
    similarities become edge weights; self-loops are removed and negative
    similarities are set to 0 before ``normalized_adjacency`` is applied.

    Returns the tensor produced by ``normalized_adjacency``. Raises
    ``KeyError`` if ``model_name`` is not a key of ``models``.
    """
    embeddings = models[model_name].encode(sentences)
    similarity = torch.tensor(cos_sim(embeddings, embeddings))
    similarity.fill_diagonal_(0.)
    similarity[similarity < 0] = 0
    return normalized_adjacency(similarity)
|
|
|
|
|
|
|
|
|
|
def terminal_distr(adjacency, initial=None, steps=10):
    """Propagate a score vector through the graph and return the result.

    Computes ``initial @ adjacency ** steps`` (matrix power).

    Parameters
    ----------
    adjacency : torch.Tensor
        Square 2-D tensor of edge weights (as returned by ``text_rank``).
    initial : torch.Tensor, optional
        1-D starting scores, one per node. Defaults to all ones with the same
        dtype and device as ``adjacency``, so float64 or GPU graphs work.
    steps : int, optional
        Non-negative number of propagation steps (the matrix power); the
        default of 10 matches the previous hard-coded value.

    Returns
    -------
    list of float
        One score per node.
    """
    if initial is None:
        sample = torch.ones(
            adjacency.shape[0], dtype=adjacency.dtype, device=adjacency.device
        )
    else:
        sample = initial
    # .tolist() (instead of .numpy().tolist()) also works for CUDA tensors
    # and tensors that require grad.
    return sample.matmul(torch.matrix_power(adjacency, steps)).tolist()
|
|
|
|
|
|
2025-10-30 16:26:48 -07:00
|
|
|
def extract(source_text, model_name='all-mpnet-base-v2'):
    """Split ``source_text`` into sentences and build their similarity graph.

    Returns ``(sentence_ranges, adjacency)``: the ``(start, end)`` character
    offsets of each sentence (from ``get_sentences``) and the normalised
    adjacency tensor from ``text_rank`` computed with ``model_name``. The
    sentence strings are not returned; slice ``source_text`` with the ranges
    to recover them.
    """
    sentences, ranges = get_sentences(source_text)
    return ranges, text_rank(sentences, model_name)
|
|
|
|
|
|
|
|
|
|
def get_results(sentences, adjacency, threshold=1.1):
    """Print the sentences whose propagated score exceeds ``threshold``.

    Scores come from ``terminal_distr(adjacency)``, one per sentence, and are
    paired positionally with ``sentences``. Qualifying sentences are printed
    as ``"<score>: <sentence>"`` in ascending score order, so the
    highest-scoring sentence is printed last. Returns ``None``.

    ``threshold`` defaults to the previously hard-coded cutoff of 1.1.
    """
    scores = terminal_distr(adjacency)
    for score, sentence in sorted(zip(scores, sentences), key=lambda pair: pair[0]):
        if score > threshold:
            print(f'{score:0.2f}: {sentence}')
|