Use port 15000 as the default development port. If you ever cloned the repo on a Mac, ran the demo, and found that the models list never loaded, or saw 403 errors in the browser console, check the Server header of those responses: chances are the requests went to the macOS AirPlay Receiver service, which also listens on port 5000.
146 lines
4.7 KiB
Python
146 lines
4.7 KiB
Python
# Salience API
# ============
# Uses a worker thread for model inference to avoid fork() issues with Metal/MPS.
# The worker thread owns all model instances; HTTP handlers submit work via queue.

# Printed before any import below runs, so it appears even if an import is
# slow or fails.
print("Starting salience __init__.py...")

import json
import threading
import time
from collections import deque

import numpy as np
from flask import Flask, request
from flask_cors import CORS

from .salience import submit_work, AVAILABLE_MODELS

app = Flask(__name__)

# Cross-origin calls are accepted only from these two localhost:5173 origins
# (presumably the frontend dev server).
CORS(app, origins=["http://localhost:5173", "http://127.0.0.1:5173"])
|
|
|
|
# Thread-safe stats tracker for this worker process
class StatsTracker:
    """Thread-safe record of recent request activity for this worker process.

    Two bounded histories are kept, both guarded by ``self.lock``:

    * ``processing_spans`` -- ``(start_time, end_time, duration)`` tuples for
      successfully processed requests.
    * ``overflow_arrivals`` -- arrival timestamps of overflow requests.

    Only entries from the last ``window_seconds`` are reported. Expired
    entries are trimmed whenever something is added. ``get_stats`` also
    filters by age at read time, so stale data is never reported, even
    after traffic stops and no further additions trigger a trim.

    Args:
        window_seconds: Age limit in seconds for reported entries
            (default 300, i.e. 5 minutes).
        maxlen: Maximum number of entries retained per history
            (default 1000); the oldest are discarded beyond this.
    """

    def __init__(self, window_seconds=300, maxlen=1000):
        self.window_seconds = window_seconds
        # Store (start_time, end_time, duration) for successful requests
        self.processing_spans = deque(maxlen=maxlen)
        # Store arrival timestamps for overflow requests
        self.overflow_arrivals = deque(maxlen=maxlen)
        self.lock = threading.Lock()

    def _cutoff(self):
        """Return the oldest timestamp still inside the reporting window."""
        return time.time() - self.window_seconds

    def add_processing_span(self, start_time, end_time):
        """Record a successful request that ran from start_time to end_time."""
        duration = end_time - start_time
        with self.lock:
            self.processing_spans.append((start_time, end_time, duration))
            # Trim expired entries from the old end; spans age by start time.
            cutoff = self._cutoff()
            while self.processing_spans and self.processing_spans[0][0] < cutoff:
                self.processing_spans.popleft()

    def add_overflow_arrival(self, arrival_time):
        """Record that an overflow request arrived at arrival_time."""
        with self.lock:
            self.overflow_arrivals.append(arrival_time)
            # Trim expired entries from the old end.
            cutoff = self._cutoff()
            while self.overflow_arrivals and self.overflow_arrivals[0] < cutoff:
                self.overflow_arrivals.popleft()

    def get_stats(self):
        """Return a JSON-serializable snapshot of in-window activity.

        Returns:
            dict with ``processing_spans`` (list of ``{'start', 'end',
            'duration'}`` dicts), ``overflow_arrivals`` (list of timestamps)
            and ``window_seconds``.
        """
        with self.lock:
            cutoff = self._cutoff()
            # Filter here instead of relying only on the trim in the add_*
            # methods, for two reasons:
            # - that trim runs only on writes;
            # - it stops at the first in-window entry. Spans are appended when
            #   a request finishes, so their start times need not be in order.
            spans = [
                {'start': start, 'end': end, 'duration': duration}
                for start, end, duration in self.processing_spans
                if start >= cutoff
            ]
            arrivals = [t for t in self.overflow_arrivals if t >= cutoff]
        return {
            'processing_spans': spans,
            'overflow_arrivals': arrivals,
            'window_seconds': self.window_seconds,
        }
|
|
|
|
# Single tracker instance read and written by the request handlers below.
stats_tracker = StatsTracker()

# Load default text from transcript.txt for GET requests.
# NOTE: the relative path is resolved against the process's current working
# directory, not this package's directory.
# The file is read as UTF-8 explicitly. Python's default encoding depends on
# the platform locale, and POST bodies are also decoded as UTF-8.
with open('./transcript.txt', 'r', encoding='utf-8') as file:
    default_source_text = file.read().strip()
|
|
|
|
@app.route("/models")
def models_view():
    """Return the names of the available models as a JSON array.

    The names are the keys of ``AVAILABLE_MODELS``. The salience endpoints
    reject any ``model`` query parameter that is not one of them.
    """
    payload = json.dumps(list(AVAILABLE_MODELS.keys()))
    # Flask serves a bare string return as text/html; label the body as JSON.
    return payload, 200, {'Content-Type': 'application/json'}
|
|
|
|
@app.route("/overflow", methods=['GET', 'POST'])
def overflow_view():
    """
    Endpoint hit when HAProxy queue is full.

    Records this arrival in ``stats_tracker``, then returns HTTP 429. The JSON
    body carries an error message and the current processing/overflow
    statistics.
    """
    arrival_time = time.time()
    stats_tracker.add_overflow_arrival(arrival_time)

    # Snapshot taken after recording, so this arrival is included.
    stats = stats_tracker.get_stats()

    response = {
        'error': 'Queue full',
        'status': 429,
        'stats': stats,
        'message': 'Service is at capacity. Try again or check queue statistics.'
    }

    # Flask serves a bare string return as text/html; label the body as JSON.
    return json.dumps(response), 429, {'Content-Type': 'application/json'}
|
|
|
|
@app.route("/stats")
def stats_view():
    """
    Endpoint for frontend to poll current queue statistics.

    Returns, as JSON, the processing spans and overflow arrivals that
    ``stats_tracker`` currently reports for its time window, plus the
    window length.
    """
    stats = stats_tracker.get_stats()
    # Flask serves a bare string return as text/html; label the body as JSON.
    return json.dumps(stats), 200, {'Content-Type': 'application/json'}
|
|
|
|
@app.route("/salience", methods=['GET'])
def salience_view_default():
    """GET endpoint - processes default text from transcript.txt

    Query parameters:
        model: Key of ``AVAILABLE_MODELS`` to use
            (default ``'all-mpnet-base-v2'``).

    Returns:
        On success, a JSON object with these keys:
        - ``source``: the processed text;
        - ``intervals`` and ``adjacency``: the results of ``submit_work``;
        - ``model``: the model name used.
        For an unknown model name, HTTP 400 with an ``error`` message.
        Only successful calls are recorded in ``stats_tracker``.
    """
    # Every response is JSON. The error body echoes the caller-supplied model
    # name; served as text/html (Flask's default for a bare string), a crafted
    # ?model=<script>... value would be rendered by a browser.
    json_headers = {'Content-Type': 'application/json'}
    start_time = time.time()

    model_name = request.args.get('model', 'all-mpnet-base-v2')

    # Validate model name
    if model_name not in AVAILABLE_MODELS:
        return json.dumps({'error': f'Invalid model: {model_name}'}), 400, json_headers

    sentence_ranges, adjacency = submit_work(default_source_text, model_name)

    end_time = time.time()
    stats_tracker.add_processing_span(start_time, end_time)

    return json.dumps({
        'source': default_source_text,
        'intervals': sentence_ranges,
        # adjacency presumably is a tensor (it exposes .numpy()). nan_to_num
        # replaces NaN/inf, which are not valid JSON, with finite numbers.
        'adjacency': np.nan_to_num(adjacency.numpy()).tolist(),
        'model': model_name,
    }), 200, json_headers
|
|
|
|
@app.route("/salience", methods=['POST'])
def salience_view_custom():
    """POST endpoint - processes text from request body

    The request body is read as UTF-8 plain text (surrounding whitespace is
    stripped).

    Query parameters:
        model: Key of ``AVAILABLE_MODELS`` to use
            (default ``'all-mpnet-base-v2'``).

    Returns:
        On success, a JSON object with these keys:
        - ``source``: the processed text;
        - ``intervals`` and ``adjacency``: the results of ``submit_work``;
        - ``model``: the model name used.
        HTTP 400 with an ``error`` message when the model name is unknown,
        the body is not valid UTF-8, or the body is empty after stripping.
        Only successful calls are recorded in ``stats_tracker``.
    """
    # Every response is JSON. Both the model name and the posted text are
    # echoed back; served as text/html (Flask's default for a bare string),
    # attacker-supplied markup in either would be rendered by a browser.
    json_headers = {'Content-Type': 'application/json'}
    start_time = time.time()

    model_name = request.args.get('model', 'all-mpnet-base-v2')

    # Validate model name
    if model_name not in AVAILABLE_MODELS:
        return json.dumps({'error': f'Invalid model: {model_name}'}), 400, json_headers

    # Get document content from request body as plain text. A malformed body
    # is a client error (400), not an unhandled exception (500).
    try:
        source_text = request.data.decode('utf-8').strip()
    except UnicodeDecodeError:
        return json.dumps({'error': 'Request body must be valid UTF-8 text'}), 400, json_headers

    if not source_text:
        return json.dumps({'error': 'No text provided'}), 400, json_headers

    sentence_ranges, adjacency = submit_work(source_text, model_name)

    end_time = time.time()
    stats_tracker.add_processing_span(start_time, end_time)

    return json.dumps({
        'source': source_text,
        'intervals': sentence_ranges,
        # adjacency presumably is a tensor (it exposes .numpy()). nan_to_num
        # replaces NaN/inf, which are not valid JSON, with finite numbers.
        'adjacency': np.nan_to_num(adjacency.numpy()).tolist(),
        'model': model_name,
    }), 200, json_headers
|