feat: make version deployable
commit 49bd94cda2
parent 4aa8759514
22 changed files with 7785 additions and 10962 deletions
@@ -1,27 +1,14 @@
-# Memory Sharing for ML Models
-# ============================
-# This app is designed to run with Gunicorn's --preload flag, which loads the
-# SentenceTransformer models once in the master process before forking workers.
-# On Linux, fork uses copy-on-write (COW) semantics, so workers share the
-# read-only model weights in memory rather than each loading their own copy.
-# This is critical for keeping memory usage reasonable with large transformer models.
-#
-# ResourceTracker errors on shutdown (Python 3.14):
-# When you Ctrl+C the Gunicorn process, you may see
-# "ChildProcessError: [Errno 10] No child processes"
-# from multiprocessing.resource_tracker.
-#
-# I think this is harmless. I think what happens is each forked worker gets a
-# copy of the ResourceTracker object, then each copy tries to deallocate the
-# same resources. The process still shuts down reasonably quickly, so I'm not
-# concerned.
+# Salience API
+# ============
+# Uses a worker thread for model inference to avoid fork() issues with Metal/MPS.
+# The worker thread owns all model instances; HTTP handlers submit work via queue.

 print("Starting salience __init__.py...")

 from flask import Flask, request
 from flask_cors import CORS
 import numpy as np
-from .salience import extract, AVAILABLE_MODELS
+from .salience import submit_work, AVAILABLE_MODELS
 import json
 import time
 from collections import deque
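The removed header comment describes the old deployment model: Gunicorn's --preload flag loads the SentenceTransformer models once in the master process, and forked workers then share the read-only weights via copy-on-write. For context, a minimal sketch of that setup as a Gunicorn config file; the file name, worker count, and bind address are illustrative assumptions, not part of this commit:

# gunicorn.conf.py -- illustrative sketch only, not from this commit
preload_app = True       # load models once in the master process, before forking
workers = 2              # forked workers share the read-only weights via copy-on-write
bind = "127.0.0.1:8000"  # example bind address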
@@ -117,7 +104,7 @@ def salience_view_default():
     if model_name not in AVAILABLE_MODELS:
         return json.dumps({'error': f'Invalid model: {model_name}'}), 400

-    sentence_ranges, adjacency = extract(default_source_text, model_name)
+    sentence_ranges, adjacency = submit_work(default_source_text, model_name)

     end_time = time.time()
     stats_tracker.add_processing_span(start_time, end_time)
@@ -146,7 +133,7 @@ def salience_view_custom():
     if not source_text:
         return json.dumps({'error': 'No text provided'}), 400

-    sentence_ranges, adjacency = extract(source_text, model_name)
+    sentence_ranges, adjacency = submit_work(source_text, model_name)

     end_time = time.time()
     stats_tracker.add_processing_span(start_time, end_time)
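Both handlers now call submit_work, which (per its docstring later in this commit) can raise TimeoutError or RuntimeError; the diff does not show any handling for those. A hypothetical fragment of how a handler body could guard the call, with status codes chosen purely for illustration:

# Hypothetical guard around the submit_work call (not part of this commit)
try:
    sentence_ranges, adjacency = submit_work(source_text, model_name)
except TimeoutError:
    return json.dumps({'error': 'Model inference timed out'}), 504  # status code assumed
except RuntimeError as e:
    return json.dumps({'error': str(e)}), 500  # status code assumed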
@@ -1,4 +1,8 @@
 import os
+import threading
+import queue
+from dataclasses import dataclass, field
+from typing import Optional

 # Set default cache locations BEFORE importing libraries that use them
 PROJECT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
@@ -11,6 +15,9 @@ if 'NLTK_DATA' not in os.environ:
 if 'HF_HOME' not in os.environ:
     os.environ['HF_HOME'] = os.path.join(PROJECT_DIR, 'cache-huggingface')

+# Device configuration: set TORCH_DEVICE=cpu to force CPU, otherwise auto-detect
+DEVICE = os.environ.get('TORCH_DEVICE', None)  # None = auto-detect (cuda/mps/cpu)
+
 from salience.timed_import import timed_import

 with timed_import("import numpy as np"):
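DEVICE stays None unless TORCH_DEVICE is set, and is passed through to SentenceTransformer later in this diff, which then picks a device itself. Roughly, the auto-detection amounts to the following; this helper is purely illustrative and not part of the codebase:

# Illustration only: what "auto-detect" resolves to when DEVICE is None
import torch

def resolve_device(requested=None):
    # requested comes from TORCH_DEVICE, e.g. "cpu" to force CPU
    if requested:
        return requested
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():  # Apple Metal
        return "mps"
    return "cpu"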
@@ -31,8 +38,9 @@ with timed_import("import nltk"):
     nltk.download('punkt_tab')

 # Available models for the demo
+# Keys are short names for the API, values are full HuggingFace repo IDs
 AVAILABLE_MODELS = {
-    'all-mpnet-base-v2': 'all-mpnet-base-v2', # Dec 2020
+    'all-mpnet-base-v2': 'sentence-transformers/all-mpnet-base-v2', # Dec 2020
     'gte-large-en-v1.5': 'Alibaba-NLP/gte-large-en-v1.5', # Jan 2024
     # 'qwen3-embedding-4b': 'Qwen/Qwen3-Embedding-4B', # April 2025
     'mxbai-embed-large-v1': 'mixedbread-ai/mxbai-embed-large-v1',
@@ -52,18 +60,17 @@ AVAILABLE_MODELS = {
 # Qwen/Qwen3-Embedding-4B: 80.86
 # mixedbread-ai/mxbai-embed-large-v1: 85.00

-# Load all models into memory
-print("Loading sentence transformer models...")
-models = {}
+# Models loaded on first use in worker thread
+_loaded_models = {}

-models['all-mpnet-base-v2'] = SentenceTransformer('all-mpnet-base-v2')
-print("Loading Alibaba-NLP/gte-large-en-v1.5")
-models['gte-large-en-v1.5'] = SentenceTransformer('Alibaba-NLP/gte-large-en-v1.5', trust_remote_code=True)
-#print("Loading Qwen/Qwen3-Embedding-4B")
-#models['qwen3-embedding-4b'] = SentenceTransformer('Qwen/Qwen3-Embedding-4B', trust_remote_code=True)
-print("Loading mixedbread-ai/mxbai-embed-large-v1")
-models["mxbai-embed-large-v1"] = SentenceTransformer('mixedbread-ai/mxbai-embed-large-v1')
-print("All models loaded!")
+def _get_model(model_name):
+    """Load and cache a model. Called only from worker thread."""
+    if model_name not in _loaded_models:
+        repo_id = AVAILABLE_MODELS[model_name]
+        print(f"Loading model {repo_id} into memory...")
+        trust_remote = model_name in ('gte-large-en-v1.5', 'qwen3-embedding-4b')
+        _loaded_models[model_name] = SentenceTransformer(repo_id, trust_remote_code=trust_remote, device=DEVICE)
+    return _loaded_models[model_name]

 sent_detector = nltk.data.load('tokenizers/punkt/english.pickle')
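With this change nothing is loaded at import time: _get_model pulls a model into _loaded_models the first time a request asks for it, and later calls reuse the cached instance. The plain dict is not locked, so the docstring's restriction to the single worker thread is what keeps it safe. A hypothetical interaction, for illustration only:

# Illustration (hypothetical calls, not part of this commit)
m1 = _get_model('mxbai-embed-large-v1')  # first call: loads the model, prints progress
m2 = _get_model('mxbai-embed-large-v1')  # later calls: served from _loaded_models
assert m1 is m2                          # same cached SentenceTransformer instance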
@@ -89,7 +96,7 @@ def get_sentences(source_text):
     return sentences, sentence_ranges

 def text_rank(sentences, model_name='all-mpnet-base-v2'):
-    model = models[model_name]
+    model = _get_model(model_name)
     vectors = model.encode(sentences)
     adjacency = torch.tensor(cos_sim(vectors)).fill_diagonal_(0.)
     adjacency[adjacency < 0] = 0
@@ -111,6 +118,65 @@ def extract(source_text, model_name='all-mpnet-base-v2'):
     return sentence_ranges, adjacency


+# =============================================================================
+# Worker Thread for Model Inference
+# =============================================================================
+# All model inference runs in a dedicated worker thread. This:
+# 1. Avoids fork() issues with Metal/MPS (no forking server needed)
+# 2. Serializes inference requests (one at a time)
+# 3. Keeps /stats and other endpoints responsive
+
+@dataclass
+class WorkItem:
+    source_text: str
+    model_name: str
+    event: threading.Event = field(default_factory=threading.Event)
+    result: Optional[tuple] = None
+    error: Optional[str] = None
+
+_work_queue: queue.Queue[WorkItem] = queue.Queue()
+
+def _model_worker():
+    """Worker thread loop - processes inference requests from queue."""
+    while True:
+        item = _work_queue.get()
+        try:
+            item.result = extract(item.source_text, item.model_name)
+        except Exception as e:
+            item.error = str(e)
+        finally:
+            item.event.set()
+
+# Start worker thread
+threading.Thread(target=_model_worker, daemon=True, name="model-worker").start()
+
+def submit_work(source_text: str, model_name: str, timeout: float = 60.0) -> tuple:
+    """Submit text for salience extraction and wait for result.
+
+    Args:
+        source_text: Text to analyze
+        model_name: Name of the model to use (must be in AVAILABLE_MODELS)
+        timeout: Max seconds to wait for result
+
+    Returns:
+        (sentence_ranges, adjacency) tuple
+
+    Raises:
+        TimeoutError: If inference takes longer than timeout
+        RuntimeError: If inference fails
+    """
+    item = WorkItem(source_text=source_text, model_name=model_name)
+    _work_queue.put(item)
+
+    if not item.event.wait(timeout=timeout):
+        raise TimeoutError("Model inference timed out")
+
+    if item.error:
+        raise RuntimeError(item.error)
+
+    return item.result
+
+
 # =============================================================================
 # Unused/Debugging Code
 # =============================================================================
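The worker thread is the only place models run, and submit_work is the single entry point the HTTP handlers use. A hypothetical way to exercise it directly from a Python shell; the salience.salience import path and the sample text are assumptions made for illustration:

# Hypothetical smoke test (not part of this commit); import path assumed
from salience.salience import submit_work

text = "Transformers changed NLP. Sentence embeddings make similarity cheap. Salience ranks sentences."
sentence_ranges, adjacency = submit_work(text, 'all-mpnet-base-v2', timeout=120.0)
print(sentence_ranges)   # per-sentence ranges returned by get_sentences
print(adjacency.shape)   # pairwise similarity matrix with zeroed diagonal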