feat: add multiple models

This commit is contained in:
nobody 2025-10-30 16:26:48 -07:00
commit fee0e643e4
Signed by: GrocerPublishAgent
GPG key ID: D460CD54A9E3AB86
5 changed files with 517 additions and 32 deletions

View file

@ -1,18 +1,30 @@
from flask import Flask
from flask import Flask, request
import numpy as np
from .salience import extract
from .salience import extract, AVAILABLE_MODELS
import json
app = Flask(__name__)
with open('./transcript.txt', 'r') as file:
source_text = file.read().strip()
sentence_ranges, adjacency = extract(source_text)
@app.route("/models")
def models_view():
    """List the selectable embedding model names as a JSON array of strings."""
    return json.dumps([name for name in AVAILABLE_MODELS])
@app.route("/salience")
def salience_view():
    """Score the transcript's sentences with the requested embedding model.

    Query params:
        model: a key of AVAILABLE_MODELS (defaults to 'all-mpnet-base-v2').

    Returns a JSON object with the source text, per-sentence character
    intervals, the (NaN-scrubbed) adjacency matrix, and the model used;
    responds 400 for an unknown model name.
    """
    model_name = request.args.get('model', 'all-mpnet-base-v2')
    # Reject names we never loaded rather than KeyError-ing deeper down.
    if model_name not in AVAILABLE_MODELS:
        return json.dumps({'error': f'Invalid model: {model_name}'}), 400
    sentence_ranges, adjacency = extract(source_text, model_name)
    payload = {
        'source': source_text,
        'intervals': sentence_ranges,
        'adjacency': np.nan_to_num(adjacency.numpy()).tolist(),
        'model': model_name,
    }
    return json.dumps(payload)

View file

@ -13,9 +13,45 @@ NLTK_DATA_DIR = os.path.join(PROJECT_DIR, 'nltk_data')
nltk.data.path.insert(0, NLTK_DATA_DIR)
# Download to the custom location
nltk.download('punkt', download_dir=NLTK_DATA_DIR)
# Using punkt_tab (the modern tab-separated format introduced in NLTK 3.8+)
# instead of the older punkt pickle format
# The punkt_tab model version depends on the NLTK Python package version
# Check your NLTK version with: uv pip show nltk
nltk.download('punkt_tab', download_dir=NLTK_DATA_DIR)
# Available models for the demo
# Maps demo-facing model names (accepted by /salience?model=...) to the
# identifiers passed to SentenceTransformer(). Insertion order is the order
# the /models endpoint reports.
AVAILABLE_MODELS = {
'all-mpnet-base-v2': 'all-mpnet-base-v2', # Dec 2020
'gte-large-en-v1.5': 'Alibaba-NLP/gte-large-en-v1.5', # Jan 2024
# 'qwen3-embedding-4b': 'Qwen/Qwen3-Embedding-4B', # April 2025 — disabled; its load below is commented out too
'mxbai-embed-large-v1': 'mixedbread-ai/mxbai-embed-large-v1',
}
# On clustering
# mixedbread-ai/mxbai-embed-large-v1: 46.71
# gte-large-en-v1.5: 47.95
# Qwen/Qwen3-Embedding-0.6B: 52.33
# Qwen/Qwen3-Embedding-4B: 57.15
# On STS
# gte-large-en-v1.5: 81.43
# Qwen/Qwen3-Embedding-0.6B: 76.17
# Qwen/Qwen3-Embedding-4B: 80.86
# mixedbread-ai/mxbai-embed-large-v1: 85.00
# Load every advertised embedding model into memory up front so requests
# never pay model-load latency. NOTE(review): this keeps several large
# models resident in RAM at once.
print("Loading sentence transformer models...")
models = {}
models['all-mpnet-base-v2'] = SentenceTransformer('all-mpnet-base-v2')
print("Loading Alibaba-NLP/gte-large-en-v1.5")
models['gte-large-en-v1.5'] = SentenceTransformer('Alibaba-NLP/gte-large-en-v1.5', trust_remote_code=True)
#print("Loading Qwen/Qwen3-Embedding-4B")
#models['qwen3-embedding-4b'] = SentenceTransformer('Qwen/Qwen3-Embedding-4B', trust_remote_code=True)
print("Loading mixedbread-ai/mxbai-embed-large-v1")
models["mxbai-embed-large-v1"] = SentenceTransformer('mixedbread-ai/mxbai-embed-large-v1')
print("All models loaded!")
# Reuse the already-loaded default instead of instantiating a second copy:
# the original called SentenceTransformer('all-mpnet-base-v2') again here,
# doubling that model's memory footprint and startup time.
model = models['all-mpnet-base-v2']
sent_detector = nltk.data.load('tokenizers/punkt/english.pickle')
def cos_sim(a, b):
@ -40,7 +76,8 @@ def get_sentences(source_text):
sentences = [source_text[start:end] for start, end in sentence_ranges]
return sentences, sentence_ranges
def text_rank(sentences):
def text_rank(sentences, model_name='all-mpnet-base-v2'):
model = models[model_name]
vectors = model.encode(sentences)
adjacency = torch.tensor(cos_sim(vectors, vectors)).fill_diagonal_(0.)
adjacency[adjacency < 0] = 0
@ -51,9 +88,9 @@ def terminal_distr(adjacency, initial=None):
scores = sample.matmul(torch.matrix_power(adjacency, 10)).numpy().tolist()
return scores
def extract(source_text):
def extract(source_text, model_name='all-mpnet-base-v2'):
    """Segment source_text into sentences and score their pairwise salience.

    Returns a (sentence_ranges, adjacency) pair: character intervals for
    each sentence, and the similarity adjacency produced by text_rank with
    the chosen embedding model.
    """
    sentences, sentence_ranges = get_sentences(source_text)
    return sentence_ranges, text_rank(sentences, model_name)
def get_results(sentences, adjacency):

View file

@ -1,4 +1,4 @@
CTYPE HTML>
<!DOCTYPE HTML>
<html>
<head>
<meta charset="utf8" />
@ -36,6 +36,23 @@ CTYPE HTML>
font-weight: normal;
color: #a0a0a0;
}
/* Model-picker bar, centered to align with the 700px content column. */
.controls {
width: 700px;
margin: 15px auto;
font-family: sans-serif;
}
.controls label {
margin-right: 10px;
color: #4d4d4d;
}
/* Native <select>, lightly restyled to match the page. */
.controls select {
padding: 5px 10px;
font-size: 14px;
border: 1px solid #ccc;
border-radius: 4px;
background-color: white;
cursor: pointer;
}
span.sentence {
--salience: 1;
background-color: rgba(249, 239, 104, var(--salience));
@ -51,16 +68,27 @@ CTYPE HTML>
<body>
<h1>
Salience
<span>automatic sentence highlights based on their significance to the document</span>
<span>sentence highlights based on their significance to the document</span>
</h1>
<div class="controls">
<label for="model-select">Model:</label>
<select id="model-select">
<option value="">Loading...</option>
</select>
</div>
<p id="content"></p>
<script type="text/javascript">
const content = document.querySelector('#content')
const modelSelect = document.querySelector('#model-select')
let adjacency = null
let currentModel = 'all-mpnet-base-v2'
function scale(score) {
  // Cube the raw salience score, shift it down by 0.95, then clamp to [0, 1]
  // so it can be used directly as a CSS alpha value.
  const shifted = score * score * score - 0.95
  return Math.min(1, Math.max(0, shifted))
}
let exponent = 5
const redraw = () => {
if (!adjacency) return
const sentences = document.querySelectorAll('span.sentence')
@ -90,31 +118,62 @@ CTYPE HTML>
})
}
}
// Fetch salience scores for the given model and rebuild the highlighted
// transcript. Clears the page while the request is in flight.
function loadSalience(model) {
  content.innerHTML = ''
  adjacency = null
  fetch(`/salience?model=${encodeURIComponent(model)}`).then(async res => {
    const data = await res.json()
    console.log(data)
    const source = data.source
    const intervals = data.intervals
    // slice(start, end) replaces the deprecated substr(start, length);
    // identical output for these non-negative, in-order indices.
    const tokens = intervals.map(([start, end]) => source.slice(start, end))
    adjacency = data.adjacency
    tokens.forEach((t, i) => {
      const token = document.createElement('span')
      token.innerText = t
      token.classList.add('sentence')
      content.appendChild(token)
      // Preserve any text between sentence intervals (whitespace, headings)
      // in an unhighlighted span.
      if (tokens[i+1] && intervals[i+1][0] > intervals[i][1]) {
        const intervening = document.createElement('span')
        const start = intervals[i][1]
        intervening.innerText = source.slice(start, intervals[i+1][0])
        content.appendChild(intervening)
      }
    })
    redraw()
  })
}
// Populate the model dropdown from the backend's /models listing,
// pre-selecting whichever model is current.
fetch('/models').then(async res => {
  const names = await res.json()
  modelSelect.innerHTML = ''
  for (const name of names) {
    const option = document.createElement('option')
    option.value = name
    option.textContent = name
    if (name === currentModel) {
      option.selected = true
    }
    modelSelect.appendChild(option)
  }
})
// Re-fetch salience whenever the user picks a different model.
modelSelect.addEventListener('change', (e) => {
  currentModel = e.target.value
  loadSalience(currentModel)
})
// Disabled functionality to center highlights on a selected fragment
// document.addEventListener('mousemove', redraw)
// document.addEventListener('mouseup', redraw)
fetch('/salience').then(async res => {
const data = await res.json()
console.log(data)
const source = data.source
const intervals = data.intervals
const tokens = intervals.map(([start, end]) => source.substr(start, end - start))
adjacency = data.adjacency
tokens.forEach((t, i) => {
const token = document.createElement('span')
token.innerText = t
token.classList.add('sentence')
content.appendChild(token)
if (tokens[i+1] && intervals[i+1][0] > intervals[i][1]) {
const intervening = document.createElement('span')
const start = intervals[i][1]
intervening.innerText = source.substr(start, intervals[i+1][0] - start)
content.appendChild(intervening)
}
})
redraw()
})
// Load initial salience data
loadSalience(currentModel)
</script>
</body>
</html>