feat: deploy model api server to chicago-web01
parent 515a0e6d81
commit 0cb89ddc80
6 changed files with 394 additions and 18 deletions
api/salience/__init__.py
@@ -1,12 +1,74 @@
# Memory Sharing for ML Models
# ============================
# This app is designed to run with Gunicorn's --preload flag, which loads the
# SentenceTransformer models once in the master process before forking workers.
# On Linux, fork uses copy-on-write (COW) semantics, so workers share the
# read-only model weights in memory rather than each loading their own copy.
# This is critical for keeping memory usage reasonable with large transformer models.
#
# ResourceTracker errors on shutdown (Python 3.14):
# When you Ctrl+C the Gunicorn process, you may see
# "ChildProcessError: [Errno 10] No child processes"
# from multiprocessing.resource_tracker.
#
# This appears to be harmless: most likely each forked worker gets a copy of the
# ResourceTracker object, and each copy then tries to deallocate the same
# resources. The process still shuts down reasonably quickly, so I'm not
# concerned.

print("Starting salience __init__.py...")

from flask import Flask, request
from flask_cors import CORS
import numpy as np
from .salience import extract, AVAILABLE_MODELS
import json
import time
from collections import deque
import threading

app = Flask(__name__)
CORS(app, origins=["http://localhost:5173"])


# Thread-safe stats tracker for this worker process
class StatsTracker:
    def __init__(self):
        # Store (start_time, end_time, duration) for successful requests
        self.processing_spans = deque(maxlen=1000)
        # Store arrival timestamps for overflow requests
        self.overflow_arrivals = deque(maxlen=1000)
        self.lock = threading.Lock()

    def add_processing_span(self, start_time, end_time):
        duration = end_time - start_time
        with self.lock:
            self.processing_spans.append((start_time, end_time, duration))
            # Clean old entries (>5 min)
            cutoff = time.time() - 300
            while self.processing_spans and self.processing_spans[0][0] < cutoff:
                self.processing_spans.popleft()

    def add_overflow_arrival(self, arrival_time):
        with self.lock:
            self.overflow_arrivals.append(arrival_time)
            # Clean old entries (>5 min)
            cutoff = time.time() - 300
            while self.overflow_arrivals and self.overflow_arrivals[0] < cutoff:
                self.overflow_arrivals.popleft()

    def get_stats(self):
        with self.lock:
            return {
                'processing_spans': [
                    {'start': start, 'end': end, 'duration': duration}
                    for start, end, duration in self.processing_spans
                ],
                'overflow_arrivals': list(self.overflow_arrivals),
                'window_seconds': 300  # 5 minutes
            }


stats_tracker = StatsTracker()

# Load default text from transcript.txt for GET requests
with open('./transcript.txt', 'r') as file:
    default_source_text = file.read().strip()
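A minimal gunicorn.conf.py sketch of the --preload setup described in the module comment above. The bind address, worker count, and app module path are assumptions for illustration, not part of this commit:

# gunicorn.conf.py (sketch; values are assumptions)
wsgi_app = "salience:app"   # assumed import path for the Flask app defined above
preload_app = True          # load the models once in the master, then fork workers (COW sharing)
workers = 4                 # assumed worker count
bind = "0.0.0.0:8000"       # assumed bind address behind HAProxy

On the command line, `gunicorn --preload salience:app` selects the same behaviour as preload_app = True.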
@@ -15,9 +77,40 @@ with open('./transcript.txt', 'r') as file:
def models_view():
    return json.dumps(list(AVAILABLE_MODELS.keys()))

@app.route("/overflow", methods=['GET', 'POST'])
def overflow_view():
    """
    Endpoint hit when HAProxy queue is full.
    Returns 429 with statistics about processing and overflow.
    """
    arrival_time = time.time()
    stats_tracker.add_overflow_arrival(arrival_time)

    stats = stats_tracker.get_stats()

    response = {
        'error': 'Queue full',
        'status': 429,
        'stats': stats,
        'message': 'Service is at capacity. Try again or check queue statistics.'
    }

    return json.dumps(response), 429

@app.route("/stats")
def stats_view():
    """
    Endpoint for frontend to poll current queue statistics.
    Returns processing spans and overflow arrivals from last 5 minutes.
    """
    stats = stats_tracker.get_stats()
    return json.dumps(stats)

@app.route("/salience", methods=['GET'])
def salience_view_default():
    """GET endpoint - processes default text from transcript.txt"""
    start_time = time.time()

    model_name = request.args.get('model', 'all-mpnet-base-v2')

    # Validate model name
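A sketch of how a client or the frontend might consume the /stats and /overflow responses defined above; the base URL is an assumption, and the field names follow StatsTracker.get_stats():

# Client-side sketch; host/port are assumptions, not part of this commit.
import requests

BASE = "http://localhost:8000"  # assumed address of the API behind HAProxy

resp = requests.get(f"{BASE}/stats")
stats = resp.json()
spans = stats["processing_spans"]   # [{'start': ..., 'end': ..., 'duration': ...}, ...]
window = stats["window_seconds"]    # 300

if spans:
    avg_duration = sum(s["duration"] for s in spans) / len(spans)
    print(f"{len(spans)} requests in the last {window}s, avg {avg_duration:.2f}s each")

# Overflow responses carry the same stats payload alongside the 429 status.
resp = requests.post(f"{BASE}/overflow")
if resp.status_code == 429:
    body = resp.json()
    print(body["message"], "overflow arrivals:", len(body["stats"]["overflow_arrivals"]))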
@@ -26,6 +119,9 @@ def salience_view_default():
    sentence_ranges, adjacency = extract(default_source_text, model_name)

    end_time = time.time()
    stats_tracker.add_processing_span(start_time, end_time)

    return json.dumps({
        'source': default_source_text,
        'intervals': sentence_ranges,
@@ -36,6 +132,8 @@ def salience_view_default():
@app.route("/salience", methods=['POST'])
def salience_view_custom():
    """POST endpoint - processes text from request body"""
    start_time = time.time()

    model_name = request.args.get('model', 'all-mpnet-base-v2')

    # Validate model name
@@ -50,6 +148,9 @@ def salience_view_custom():
    sentence_ranges, adjacency = extract(source_text, model_name)

    end_time = time.time()
    stats_tracker.add_processing_span(start_time, end_time)

    return json.dumps({
        'source': source_text,
        'intervals': sentence_ranges,
api/salience/salience.py
@@ -1,24 +1,34 @@
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
import nltk.data
import nltk
import os

# Set NLTK data path to project directory
# Set default cache locations BEFORE importing libraries that use them
PROJECT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
NLTK_DATA_DIR = os.path.join(PROJECT_DIR, 'nltk_data')
TRANSFORMERS_CACHE_DIR = os.path.join(PROJECT_DIR, 'models_cache')

# Add to NLTK's search path
nltk.data.path.insert(0, NLTK_DATA_DIR)
if 'NLTK_DATA' not in os.environ:
    nltk_data_path = os.path.join(PROJECT_DIR, 'cache-nltk')
    os.makedirs(nltk_data_path, exist_ok=True)
    os.environ['NLTK_DATA'] = nltk_data_path

# Download to the custom location
if 'HF_HOME' not in os.environ:
    os.environ['HF_HOME'] = os.path.join(PROJECT_DIR, 'cache-huggingface')

from salience.timed_import import timed_import

with timed_import("import numpy as np"):
    import numpy as np
with timed_import("import torch"):
    import torch
with timed_import("from sentence_transformers import SentenceTransformer"):
    from sentence_transformers import SentenceTransformer
with timed_import("import nltk"):
    import nltk.data
    import nltk

# Download punkt_tab to the configured location
# Using punkt_tab (the modern tab-separated format introduced in NLTK 3.8+)
# instead of the older punkt pickle format
# The punkt_tab model version depends on the NLTK Python package version
# Check your NLTK version with: uv pip show nltk
nltk.download('punkt_tab', download_dir=NLTK_DATA_DIR)
nltk.download('punkt_tab')

# Available models for the demo
AVAILABLE_MODELS = {
@@ -46,13 +56,13 @@ AVAILABLE_MODELS = {
print("Loading sentence transformer models...")
models = {}

models['all-mpnet-base-v2'] = SentenceTransformer('all-mpnet-base-v2', cache_folder=TRANSFORMERS_CACHE_DIR)
models['all-mpnet-base-v2'] = SentenceTransformer('all-mpnet-base-v2')
print("Loading Alibaba-NLP/gte-large-en-v1.5")
models['gte-large-en-v1.5'] = SentenceTransformer('Alibaba-NLP/gte-large-en-v1.5', trust_remote_code=True, cache_folder=TRANSFORMERS_CACHE_DIR)
models['gte-large-en-v1.5'] = SentenceTransformer('Alibaba-NLP/gte-large-en-v1.5', trust_remote_code=True)
#print("Loading Qwen/Qwen3-Embedding-4B")
#models['qwen3-embedding-4b'] = SentenceTransformer('Qwen/Qwen3-Embedding-4B', trust_remote_code=True, cache_folder=TRANSFORMERS_CACHE_DIR)
#models['qwen3-embedding-4b'] = SentenceTransformer('Qwen/Qwen3-Embedding-4B', trust_remote_code=True)
print("Loading mixedbread-ai/mxbai-embed-large-v1")
models["mxbai-embed-large-v1"] = SentenceTransformer('mixedbread-ai/mxbai-embed-large-v1', cache_folder=TRANSFORMERS_CACHE_DIR)
models["mxbai-embed-large-v1"] = SentenceTransformer('mixedbread-ai/mxbai-embed-large-v1')
print("All models loaded!")

sent_detector = nltk.data.load('tokenizers/punkt/english.pickle')
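A small verification sketch for the deployment host, assuming the package is importable as salience from the api directory; it only reports where the caches resolve once the module has applied its defaults:

# Sketch: confirm where the NLTK and Hugging Face caches resolve on the deploy host.
# Importing salience.salience runs the setup above (and loads the models); the
# defaults are cache-nltk / cache-huggingface unless NLTK_DATA / HF_HOME were
# already set in the environment.
import os
import salience.salience  # triggers the cache setup and model loading
import nltk

print("NLTK_DATA =", os.environ.get("NLTK_DATA"))
print("HF_HOME   =", os.environ.get("HF_HOME"))
print("NLTK search path starts with:", nltk.data.path[:2])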
api/salience/timed_import.py (new file, 20 lines)
@@ -0,0 +1,20 @@
import sys
import time


class timed_import:
    """Context manager for timing imports."""

    def __init__(self, name):
        self.name = name
        self.start = None

    def __enter__(self):
        sys.stdout.write(f"{self.name} ")
        sys.stdout.flush()
        self.start = time.time()
        return self

    def __exit__(self, *args):
        elapsed = time.time() - self.start
        print(f"in {elapsed:.1f}s")
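Usage mirrors the salience.py hunk above; a minimal standalone sketch (the printed duration is illustrative):

# Minimal usage sketch for the context manager above.
from salience.timed_import import timed_import

with timed_import("import torch"):
    import torch
# stdout: "import torch in 3.2s"  (elapsed time will vary)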