146 lines
4.7 KiB
Python
146 lines
4.7 KiB
Python
# Salience API
# ============
# Uses a worker thread for model inference to avoid fork() issues with Metal/MPS.
# The worker thread owns all model instances; HTTP handlers submit work via queue.

# Module-load marker: confirms this package's __init__ is executing
# (useful when debugging per-worker imports in multi-process servers).
print("Starting salience __init__.py...")
|
|
|
|
from flask import Flask, request
|
|
from flask_cors import CORS
|
|
import numpy as np
|
|
from .salience import submit_work, AVAILABLE_MODELS
|
|
import json
|
|
import time
|
|
from collections import deque
|
|
import threading
|
|
|
|
# Flask application object for this worker process.
app = Flask(__name__)

# Allow the local frontend dev server (Vite default port) to call this API.
CORS(app, origins=["http://localhost:5173"])
|
|
|
|
# Thread-safe stats tracker for this worker process
class StatsTracker:
    """Track request processing spans and overflow arrivals over a rolling window.

    All public methods are safe to call from multiple threads (a single lock
    guards both deques). Entries older than WINDOW_SECONDS are pruned on every
    mutation AND on every read, so get_stats() never reports stale data even
    when no new requests arrive to trigger cleanup.
    """

    # Rolling window length in seconds (5 minutes).
    WINDOW_SECONDS = 300

    def __init__(self):
        # (start_time, end_time, duration) tuples for successful requests.
        # maxlen bounds memory even under sustained load.
        self.processing_spans = deque(maxlen=1000)
        # Arrival timestamps for requests rejected because the queue was full.
        self.overflow_arrivals = deque(maxlen=1000)
        self.lock = threading.Lock()

    def _prune(self):
        """Drop entries older than the window. Caller must hold self.lock."""
        cutoff = time.time() - self.WINDOW_SECONDS
        while self.processing_spans and self.processing_spans[0][0] < cutoff:
            self.processing_spans.popleft()
        while self.overflow_arrivals and self.overflow_arrivals[0] < cutoff:
            self.overflow_arrivals.popleft()

    def add_processing_span(self, start_time, end_time):
        """Record the [start_time, end_time] span of a successfully handled request."""
        duration = end_time - start_time
        with self.lock:
            self.processing_spans.append((start_time, end_time, duration))
            self._prune()

    def add_overflow_arrival(self, arrival_time):
        """Record the arrival timestamp of a request rejected at capacity."""
        with self.lock:
            self.overflow_arrivals.append(arrival_time)
            self._prune()

    def get_stats(self):
        """Return a JSON-serializable snapshot of the current window.

        Returns a dict with 'processing_spans' (list of
        {'start', 'end', 'duration'} dicts), 'overflow_arrivals'
        (list of timestamps), and 'window_seconds'.
        """
        with self.lock:
            # Prune on read too: without this, entries added more than
            # WINDOW_SECONDS ago would linger whenever traffic stops,
            # contradicting the advertised window.
            self._prune()
            return {
                'processing_spans': [
                    {'start': start, 'end': end, 'duration': duration}
                    for start, end, duration in self.processing_spans
                ],
                'overflow_arrivals': list(self.overflow_arrivals),
                'window_seconds': self.WINDOW_SECONDS  # 5 minutes
            }
|
|
|
|
# Process-wide singleton shared by all HTTP handlers below.
stats_tracker = StatsTracker()
|
|
|
|
# Load default text from transcript.txt for GET requests.
# Read once at import time so every request reuses the same in-memory text.
# Explicit encoding: the platform/locale default could mis-decode the file
# on non-UTF-8 systems (e.g. some Windows deployments).
with open('./transcript.txt', 'r', encoding='utf-8') as file:
    default_source_text = file.read().strip()
|
|
|
|
@app.route("/models")
def models_view():
    """Return the available model names as a JSON array of strings."""
    model_names = list(AVAILABLE_MODELS.keys())
    return json.dumps(model_names)
|
|
|
|
@app.route("/overflow", methods=['GET', 'POST'])
def overflow_view():
    """
    Endpoint hit when HAProxy queue is full.
    Returns 429 with statistics about processing and overflow.
    """
    # Record this rejection first so it is included in the stats we return.
    stats_tracker.add_overflow_arrival(time.time())

    payload = {
        'error': 'Queue full',
        'status': 429,
        'stats': stats_tracker.get_stats(),
        'message': 'Service is at capacity. Try again or check queue statistics.'
    }
    return json.dumps(payload), 429
|
|
|
|
@app.route("/stats")
def stats_view():
    """
    Endpoint for frontend to poll current queue statistics.
    Returns processing spans and overflow arrivals from last 5 minutes.
    """
    return json.dumps(stats_tracker.get_stats())
|
|
|
|
@app.route("/salience", methods=['GET'])
def salience_view_default():
    """GET endpoint - processes default text from transcript.txt"""
    start_time = time.time()

    model_name = request.args.get('model', 'all-mpnet-base-v2')

    # Reject unknown model names before doing any inference work.
    if model_name not in AVAILABLE_MODELS:
        return json.dumps({'error': f'Invalid model: {model_name}'}), 400

    # Hand the text to the inference worker thread and block for the result.
    sentence_ranges, adjacency = submit_work(default_source_text, model_name)

    stats_tracker.add_processing_span(start_time, time.time())

    payload = {
        'source': default_source_text,
        'intervals': sentence_ranges,
        # nan_to_num: NaN is not valid JSON, so scrub it before serializing.
        'adjacency': np.nan_to_num(adjacency.numpy()).tolist(),
        'model': model_name,
    }
    return json.dumps(payload)
|
|
|
|
@app.route("/salience", methods=['POST'])
def salience_view_custom():
    """POST endpoint - processes text from request body"""
    start_time = time.time()

    model_name = request.args.get('model', 'all-mpnet-base-v2')

    # Guard clauses: reject bad model names and empty bodies before inference.
    if model_name not in AVAILABLE_MODELS:
        return json.dumps({'error': f'Invalid model: {model_name}'}), 400

    # Request body is treated as raw UTF-8 plain text, not JSON.
    source_text = request.data.decode('utf-8').strip()
    if not source_text:
        return json.dumps({'error': 'No text provided'}), 400

    # Hand the text to the inference worker thread and block for the result.
    sentence_ranges, adjacency = submit_work(source_text, model_name)

    stats_tracker.add_processing_span(start_time, time.time())

    payload = {
        'source': source_text,
        'intervals': sentence_ranges,
        # nan_to_num: NaN is not valid JSON, so scrub it before serializing.
        'adjacency': np.nan_to_num(adjacency.numpy()).tolist(),
        'model': model_name,
    }
    return json.dumps(payload)
|