refactor: rename ML model python backend folder
This commit is contained in:
parent
fee0e643e4
commit
76c28bafab
9 changed files with 0 additions and 0 deletions
30
api/salience/__init__.py
Normal file
30
api/salience/__init__.py
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
from flask import Flask, request
import numpy as np
from .salience import extract, AVAILABLE_MODELS
import json

app = Flask(__name__)

# Load the source document once at startup; every request scores this same
# transcript.  Explicit encoding avoids depending on the platform's locale
# default (the prior code omitted it).
with open('./transcript.txt', 'r', encoding='utf-8') as file:
    source_text = file.read().strip()
||||
@app.route("/models")
def models_view():
    """Return the available embedding model names as a JSON array."""
    model_names = list(AVAILABLE_MODELS)
    return json.dumps(model_names)
|
||||
|
||||
@app.route("/salience")
def salience_view():
    """Score the startup transcript with the requested model.

    Reads the `model` query parameter (default 'all-mpnet-base-v2') and
    responds 400 for unknown model names; otherwise returns the source
    text, sentence character intervals, the adjacency matrix, and the
    model name as JSON.
    """
    model_name = request.args.get('model', 'all-mpnet-base-v2')

    # Reject unknown models before doing any expensive encoding work.
    if model_name not in AVAILABLE_MODELS:
        error_body = json.dumps({'error': f'Invalid model: {model_name}'})
        return error_body, 400

    sentence_ranges, adjacency = extract(source_text, model_name)

    # nan_to_num squashes any NaN entries (presumably from degenerate rows
    # in the normalized adjacency — verify in salience.py) so the payload
    # serializes as valid JSON.
    payload = {
        'source': source_text,
        'intervals': sentence_ranges,
        'adjacency': np.nan_to_num(adjacency.numpy()).tolist(),
        'model': model_name,
    }
    return json.dumps(payload)
|
||||
100
api/salience/salience.py
Normal file
100
api/salience/salience.py
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
import nltk.data
import nltk
import os

# Set NLTK data path to project directory
# PROJECT_DIR resolves to the parent of this package's directory.
PROJECT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
NLTK_DATA_DIR = os.path.join(PROJECT_DIR, 'nltk_data')

# Add to NLTK's search path
# Inserted at position 0 so the project-local copy wins over any
# system-wide NLTK data directories.
nltk.data.path.insert(0, NLTK_DATA_DIR)

# Download to the custom location
# Using punkt_tab (the modern tab-separated format introduced in NLTK 3.8+)
# instead of the older punkt pickle format
# The punkt_tab model version depends on the NLTK Python package version
# Check your NLTK version with: uv pip show nltk
# Runs at import time; a no-op when the data is already present.
nltk.download('punkt_tab', download_dir=NLTK_DATA_DIR)
|
||||
|
||||
# Available models for the demo
# Maps the short name used in API requests -> Hugging Face model id.
AVAILABLE_MODELS = {
    'all-mpnet-base-v2': 'all-mpnet-base-v2', # Dec 2020
    'gte-large-en-v1.5': 'Alibaba-NLP/gte-large-en-v1.5', # Jan 2024
    # 'qwen3-embedding-4b': 'Qwen/Qwen3-Embedding-4B', # April 2025
    'mxbai-embed-large-v1': 'mixedbread-ai/mxbai-embed-large-v1',
}

# Benchmark notes for model selection (scores presumably taken from the
# models' published cards — higher is better; verify source).
# On clustering
# mixedbread-ai/mxbai-embed-large-v1: 46.71
# gte-large-en-v1.5: 47.95
# Qwen/Qwen3-Embedding-0.6B: 52.33
# Qwen/Qwen3-Embedding-4B: 57.15

# On STS
# gte-large-en-v1.5: 81.43
# Qwen/Qwen3-Embedding-0.6B: 76.17
# Qwen/Qwen3-Embedding-4B: 80.86
# mixedbread-ai/mxbai-embed-large-v1: 85.00
|
||||
|
||||
# Load all models into memory
# Happens at import time, so the process blocks here until every model is
# downloaded/loaded.
print("Loading sentence transformer models...")
models = {}

models['all-mpnet-base-v2'] = SentenceTransformer('all-mpnet-base-v2')
print("Loading Alibaba-NLP/gte-large-en-v1.5")
models['gte-large-en-v1.5'] = SentenceTransformer('Alibaba-NLP/gte-large-en-v1.5', trust_remote_code=True)
#print("Loading Qwen/Qwen3-Embedding-4B")
#models['qwen3-embedding-4b'] = SentenceTransformer('Qwen/Qwen3-Embedding-4B', trust_remote_code=True)
print("Loading mixedbread-ai/mxbai-embed-large-v1")
models["mxbai-embed-large-v1"] = SentenceTransformer('mixedbread-ai/mxbai-embed-large-v1')
print("All models loaded!")

# NOTE(review): this loads the legacy 'punkt' pickle, but the download step
# at the top of this module fetches 'punkt_tab' — confirm the punkt pickle
# is also available on nltk.data.path, otherwise this lookup will fail.
sent_detector = nltk.data.load('tokenizers/punkt/english.pickle')
|
||||
|
||||
def cos_sim(a, b):
    """Pairwise cosine similarity between the rows of `a` and `b`.

    Returns a matrix S where S[i, j] is the cosine of the angle between
    row i of `a` and row j of `b`.
    """
    # Normalize each row to unit length, then a plain dot product yields
    # the cosine similarities directly.
    a_unit = a / np.linalg.norm(a, axis=-1, keepdims=True)
    b_unit = b / np.linalg.norm(b, axis=-1, keepdims=True)
    return a_unit @ b_unit.T
|
||||
|
||||
def degree_power(A, k):
    """Return D**k, where D is the diagonal degree matrix of adjacency `A`.

    The degree of node i is the sum of row i of `A`.
    """
    row_sums = np.array(A.sum(1)).ravel()
    return np.diag(row_sums ** k)
|
||||
|
||||
def normalized_adjacency(A):
    """Symmetrically normalize `A` as D^-1/2 · A · D^-1/2.

    Returns the result as a torch tensor (converted from numpy).
    """
    d_inv_sqrt = degree_power(A, -0.5)
    normalized = d_inv_sqrt.dot(A).dot(d_inv_sqrt)
    return torch.from_numpy(normalized)
|
||||
|
||||
def get_sentences(source_text):
    """Split `source_text` into sentences using the module's Punkt detector.

    Returns (sentences, spans) where spans are (start, end) character
    offsets into `source_text`.
    """
    spans = list(sent_detector.span_tokenize(source_text))
    sentences = [source_text[span[0]:span[1]] for span in spans]
    return sentences, spans
|
||||
|
||||
def text_rank(sentences, model_name='all-mpnet-base-v2'):
    """Build the normalized sentence-similarity graph used for ranking.

    Encodes the sentences with the requested model, takes pairwise cosine
    similarity as edge weights (self-loops zeroed, negative similarities
    clamped to zero) and returns the symmetrically normalized adjacency.
    """
    embeddings = models[model_name].encode(sentences)
    similarity = torch.tensor(cos_sim(embeddings, embeddings))
    # fill_diagonal_ mutates in place: no self-loop edges.
    similarity.fill_diagonal_(0.)
    # Negative cosine would mean a negative edge weight; clamp to zero.
    similarity[similarity < 0] = 0
    return normalized_adjacency(similarity)
|
||||
|
||||
def terminal_distr(adjacency, initial=None, steps=10):
    """Propagate node mass through the graph and return per-node scores.

    Multiplies an initial mass vector by `adjacency` raised to the
    `steps`-th power — i.e. `steps` rounds of message passing toward the
    stationary distribution.

    Args:
        adjacency: square (n, n) torch tensor (normalized adjacency).
        initial: optional (n,) torch tensor of starting mass; defaults to
            uniform all-ones mass.
        steps: number of propagation rounds (previously hard-coded to 10;
            the default preserves the old behavior).

    Returns:
        list[float] of length n — one salience score per node.
    """
    sample = initial if initial is not None else torch.full((adjacency.shape[0],), 1.)
    scores = sample.matmul(torch.matrix_power(adjacency, steps)).numpy().tolist()
    return scores
|
||||
|
||||
def extract(source_text, model_name='all-mpnet-base-v2'):
    """Tokenize `source_text` and score it with the named model.

    Returns (sentence spans, normalized adjacency tensor).
    """
    sentences, spans = get_sentences(source_text)
    return spans, text_rank(sentences, model_name)
|
||||
|
||||
def get_results(sentences, adjacency):
    """Print sentences whose salience score exceeds 1.1, lowest score first."""
    scores = terminal_distr(adjacency)
    ranked = sorted(zip(scores, sentences), key=lambda pair: pair[0])
    for score, sentence in ranked:
        if score <= 1.1:
            continue
        print('{:0.2f}: {}'.format(score, sentence))
|
||||
179
api/salience/static/index.html
Normal file
179
api/salience/static/index.html
Normal file
|
|
@ -0,0 +1,179 @@
|
|||
<!DOCTYPE HTML>
<html>
<head>
  <meta charset="utf8" />
  <title>Salience</title>
  <!-- math.js supplies the matrix operations used by the inline script below -->
  <script src="https://cdnjs.cloudflare.com/ajax/libs/mathjs/11.8.0/math.js" integrity="sha512-VW8/i4IZkHxdD8OlqNdF7fGn3ba0+lYqag+Uy4cG6BtJ/LIr8t23s/vls70pQ41UasHH0tL57GQfKDApqc9izA==" crossorigin="anonymous" referrerpolicy="no-referrer"></script>
  <style>
    body {
      display: flex;
      flex-direction: column;
      height: 100vh;
      margin: 0;
    }
    p {
      width: 700px;
      margin: 1em auto;
      color: #4d4d4d;
      font-family: sans-serif;
      font-size: 15px;
      line-height: 1.33em;
      flex: 1;
      overflow-y: scroll;
    }
    h1 {
      width: 700px;
      text-align: left;
      margin: 15px auto;
      margin-bottom: 0;
      color: #000;
      font-family: sans-serif;
      font-size: 24px;
    }
    h1 span {
      display: block;
      font-size: 0.7em;
      font-weight: normal;
      color: #a0a0a0;
    }
    .controls {
      width: 700px;
      margin: 15px auto;
      font-family: sans-serif;
    }
    .controls label {
      margin-right: 10px;
      color: #4d4d4d;
    }
    .controls select {
      padding: 5px 10px;
      font-size: 14px;
      border: 1px solid #ccc;
      border-radius: 4px;
      background-color: white;
      cursor: pointer;
    }
    /* Highlight intensity is set per sentence from JS via --salience */
    span.sentence {
      --salience: 1;
      background-color: rgba(249, 239, 104, var(--salience));
    }
    span.highlight {
      background-color: rgb(185, 225, 244);
    }
    /* Suppress the native selection color; selections are shown with .highlight */
    ::selection {
      background: transparent;
    }
  </style>
</head>
<body>
  <h1>
    Salience
    <span>sentence highlights based on their significance to the document</span>
  </h1>
  <div class="controls">
    <label for="model-select">Model:</label>
    <select id="model-select">
      <option value="">Loading...</option>
    </select>
  </div>
  <!-- Sentence spans are injected here by loadSalience() -->
  <p id="content"></p>
|
||||
  <script type="text/javascript">
    // Cached DOM handles and page-level state.
    const content = document.querySelector('#content')
    const modelSelect = document.querySelector('#model-select')
    // Sentence-similarity adjacency matrix fetched from /salience; null
    // until the first response arrives.
    let adjacency = null
    let currentModel = 'all-mpnet-base-v2'
|
||||
|
||||
function scale(score) {
|
||||
return Math.max(0, Math.min(1, score ** 3 - 0.95))
|
||||
}
|
||||
|
||||
let exponent = 5
|
||||
|
||||
const redraw = () => {
|
||||
if (!adjacency) return
|
||||
const sentences = document.querySelectorAll('span.sentence')
|
||||
if (!window.getSelection().isCollapsed) {
|
||||
const sel = window.getSelection()
|
||||
const fromNode = sel.anchorNode.parentNode
|
||||
const toNode = sel.extentNode.parentNode
|
||||
const fromIdx = Array.from(sentences).indexOf(fromNode)
|
||||
const toIdx = Array.from(sentences).indexOf(toNode)
|
||||
const range = [fromIdx, toIdx]
|
||||
console.log('range', range)
|
||||
range.sort((a, b) => a - b)
|
||||
const vec = adjacency.map((x, i) => (i >= range[0] && i <= range[1]) ? 1 : 0)
|
||||
const vec_sum = vec.reduce((a, x) => a + x, 0)
|
||||
const scores = math.multiply(vec, adjacency).map(x => x * adjacency.length / vec_sum)
|
||||
Array.from(sentences).forEach((node, i) => {
|
||||
node.style.setProperty('--salience', scale(scores[i]))
|
||||
if (i >= range[0] && i <= range[1]) node.classList.add('highlight')
|
||||
else node.classList.remove('highlight')
|
||||
})
|
||||
} else {
|
||||
const initial = adjacency.map(() => 1)
|
||||
const scores = math.multiply(initial, math.pow(adjacency, exponent))
|
||||
Array.from(sentences).forEach((node, i) => {
|
||||
node.style.setProperty('--salience', scale(scores[i]))
|
||||
node.classList.remove('highlight')
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
function loadSalience(model) {
|
||||
// Clear existing content
|
||||
content.innerHTML = ''
|
||||
adjacency = null
|
||||
|
||||
fetch(`/salience?model=${encodeURIComponent(model)}`).then(async res => {
|
||||
const data = await res.json()
|
||||
console.log(data)
|
||||
const source = data.source
|
||||
const intervals = data.intervals
|
||||
const tokens = intervals.map(([start, end]) => source.substr(start, end - start))
|
||||
adjacency = data.adjacency
|
||||
tokens.forEach((t, i) => {
|
||||
const token = document.createElement('span')
|
||||
token.innerText = t
|
||||
token.classList.add('sentence')
|
||||
content.appendChild(token)
|
||||
if (tokens[i+1] && intervals[i+1][0] > intervals[i][1]) {
|
||||
const intervening = document.createElement('span')
|
||||
const start = intervals[i][1]
|
||||
intervening.innerText = source.substr(start, intervals[i+1][0] - start)
|
||||
content.appendChild(intervening)
|
||||
}
|
||||
})
|
||||
redraw()
|
||||
})
|
||||
}
|
||||
|
||||
    // Load available models and populate dropdown
    fetch('/models').then(async res => {
      const models = await res.json()
      modelSelect.innerHTML = ''
      models.forEach(model => {
        const option = document.createElement('option')
        option.value = model
        option.textContent = model
        // Pre-select the model the page booted with.
        if (model === currentModel) {
          option.selected = true
        }
        modelSelect.appendChild(option)
      })
    })

    // Handle model selection change
    modelSelect.addEventListener('change', (e) => {
      currentModel = e.target.value
      loadSalience(currentModel)
    })

    // Disabled functionality to center highlights on a selected fragment
    // document.addEventListener('mousemove', redraw)
    // document.addEventListener('mouseup', redraw)

    // Load initial salience data
    loadSalience(currentModel)
  </script>
</body>
</html>
|
||||
Loading…
Add table
Add a link
Reference in a new issue