feat: add multiple models
This commit is contained in:
parent
8e2865c5ac
commit
fee0e643e4
5 changed files with 517 additions and 32 deletions
|
|
@ -13,9 +13,45 @@ NLTK_DATA_DIR = os.path.join(PROJECT_DIR, 'nltk_data')
|
|||
nltk.data.path.insert(0, NLTK_DATA_DIR)
|
||||
|
||||
# Download to the custom location
|
||||
nltk.download('punkt', download_dir=NLTK_DATA_DIR)
|
||||
# Using punkt_tab (the modern tab-separated format introduced in NLTK 3.8+)
|
||||
# instead of the older punkt pickle format
|
||||
# The punkt_tab model version depends on the NLTK Python package version
|
||||
# Check your NLTK version with: uv pip show nltk
|
||||
nltk.download('punkt_tab', download_dir=NLTK_DATA_DIR)
|
||||
|
||||
# Available models for the demo
|
||||
AVAILABLE_MODELS = {
|
||||
'all-mpnet-base-v2': 'all-mpnet-base-v2', # Dec 2020
|
||||
'gte-large-en-v1.5': 'Alibaba-NLP/gte-large-en-v1.5', # Jan 2024
|
||||
# 'qwen3-embedding-4b': 'Qwen/Qwen3-Embedding-4B', # April 2025
|
||||
'mxbai-embed-large-v1': 'mixedbread-ai/mxbai-embed-large-v1',
|
||||
}
|
||||
|
||||
# On clustering
|
||||
# mixedbread-ai/mxbai-embed-large-v1: 46.71
|
||||
# gte-large-en-v1.5: 47.95
|
||||
# Qwen/Qwen3-Embedding-0.6B: 52.33
|
||||
# Qwen/Qwen3-Embedding-4B: 57.15
|
||||
|
||||
# On STS
|
||||
# gte-large-en-v1.5: 81.43
|
||||
# Qwen/Qwen3-Embedding-0.6B: 76.17
|
||||
# Qwen/Qwen3-Embedding-4B: 80.86
|
||||
# mixedbread-ai/mxbai-embed-large-v1: 85.00
|
||||
|
||||
# Load all models into memory
|
||||
print("Loading sentence transformer models...")
|
||||
models = {}
|
||||
|
||||
models['all-mpnet-base-v2'] = SentenceTransformer('all-mpnet-base-v2')
|
||||
print("Loading Alibaba-NLP/gte-large-en-v1.5")
|
||||
models['gte-large-en-v1.5'] = SentenceTransformer('Alibaba-NLP/gte-large-en-v1.5', trust_remote_code=True)
|
||||
#print("Loading Qwen/Qwen3-Embedding-4B")
|
||||
#models['qwen3-embedding-4b'] = SentenceTransformer('Qwen/Qwen3-Embedding-4B', trust_remote_code=True)
|
||||
print("Loading mixedbread-ai/mxbai-embed-large-v1")
|
||||
models["mxbai-embed-large-v1"] = SentenceTransformer('mixedbread-ai/mxbai-embed-large-v1')
|
||||
print("All models loaded!")
|
||||
|
||||
model = SentenceTransformer('all-mpnet-base-v2')
|
||||
sent_detector = nltk.data.load('tokenizers/punkt/english.pickle')
|
||||
|
||||
def cos_sim(a, b):
|
||||
|
|
@ -40,7 +76,8 @@ def get_sentences(source_text):
|
|||
sentences = [source_text[start:end] for start, end in sentence_ranges]
|
||||
return sentences, sentence_ranges
|
||||
|
||||
def text_rank(sentences):
|
||||
def text_rank(sentences, model_name='all-mpnet-base-v2'):
|
||||
model = models[model_name]
|
||||
vectors = model.encode(sentences)
|
||||
adjacency = torch.tensor(cos_sim(vectors, vectors)).fill_diagonal_(0.)
|
||||
adjacency[adjacency < 0] = 0
|
||||
|
|
@ -51,9 +88,9 @@ def terminal_distr(adjacency, initial=None):
|
|||
scores = sample.matmul(torch.matrix_power(adjacency, 10)).numpy().tolist()
|
||||
return scores
|
||||
|
||||
def extract(source_text):
|
||||
def extract(source_text, model_name='all-mpnet-base-v2'):
|
||||
sentences, sentence_ranges = get_sentences(source_text)
|
||||
adjacency = text_rank(sentences)
|
||||
adjacency = text_rank(sentences, model_name)
|
||||
return sentence_ranges, adjacency
|
||||
|
||||
def get_results(sentences, adjacency):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue