feat: create deployment scripts
This commit is contained in:
parent
78297efe5c
commit
8d5bce4bfb
22 changed files with 2697 additions and 74 deletions
150
api/benchmarks/visualize_benchmarks.py
Normal file
150
api/benchmarks/visualize_benchmarks.py
Normal file
|
|
@ -0,0 +1,150 @@
|
|||
"""
|
||||
Visualize pytest-benchmark results with violin plots.
|
||||
|
||||
Usage: python visualize_benchmarks.py benchmark_results.json
|
||||
"""
|
||||
import sys
|
||||
import json
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
import pandas as pd
|
||||
|
||||
if len(sys.argv) < 2:
|
||||
print("Usage: python visualize_benchmarks.py benchmark_results.json")
|
||||
sys.exit(1)
|
||||
|
||||
# Load benchmark results
|
||||
with open(sys.argv[1], 'r') as f:
|
||||
data = json.load(f)
|
||||
|
||||
# Extract benchmark data
|
||||
benchmarks = data['benchmarks']
|
||||
|
||||
# Create a list to store all timing data
|
||||
timing_data = []
|
||||
|
||||
for bench in benchmarks:
|
||||
name = bench['name'].replace('test_bench_', '').replace('_', ' ').title()
|
||||
stats = bench['stats']
|
||||
|
||||
# Require actual timing data
|
||||
if 'data' not in stats:
|
||||
print(f"ERROR: No raw timing data found for {name}", file=sys.stderr)
|
||||
print(f"Benchmark must be run with --benchmark-save-data to store raw data", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
times = np.array(stats['data']) * 1000 # Convert to ms
|
||||
|
||||
for iteration, time in enumerate(times):
|
||||
timing_data.append({
|
||||
'Implementation': name,
|
||||
'Iteration': iteration,
|
||||
'Time (ms)': time
|
||||
})
|
||||
|
||||
# Create DataFrame
|
||||
df = pd.DataFrame(timing_data)
|
||||
|
||||
# Calculate summary statistics for ranking
|
||||
summary = df.groupby('Implementation')['Time (ms)'].agg(['mean', 'median', 'std', 'min', 'max'])
|
||||
slowest_mean = summary['mean'].max()
|
||||
summary['Speedup vs Slowest'] = slowest_mean / summary['mean']
|
||||
summary_sorted = summary.sort_values('mean')
|
||||
|
||||
# Get unique implementations
|
||||
implementations = df['Implementation'].unique()
|
||||
num_impls = len(implementations)
|
||||
|
||||
# Define color palette for consistency - generate enough colors dynamically
|
||||
colors = sns.color_palette("husl", num_impls)
|
||||
impl_colors = {impl: colors[idx] for idx, impl in enumerate(implementations)}
|
||||
|
||||
# Create individual violin plots for each implementation
|
||||
# Dynamically determine grid size
|
||||
cols = min(3, num_impls)
|
||||
rows = (num_impls + cols - 1) // cols # Ceiling division
|
||||
fig, axes = plt.subplots(rows, cols, figsize=(7*cols, 5*rows))
|
||||
if num_impls == 1:
|
||||
axes = [axes]
|
||||
else:
|
||||
axes = axes.flatten()
|
||||
|
||||
for idx, impl in enumerate(implementations):
|
||||
impl_data = df[df['Implementation'] == impl]
|
||||
sns.violinplot(data=impl_data, y='Time (ms)', ax=axes[idx], inner='box', color=impl_colors[impl])
|
||||
axes[idx].set_title(f'{impl}', fontsize=12, fontweight='bold')
|
||||
axes[idx].set_ylabel('Time (ms)', fontsize=10)
|
||||
axes[idx].grid(True, alpha=0.3, axis='y')
|
||||
|
||||
# Add mean line
|
||||
mean_val = impl_data['Time (ms)'].mean()
|
||||
axes[idx].axhline(mean_val, color='red', linestyle='--', linewidth=1, alpha=0.7, label=f'Mean: {mean_val:.4f} ms')
|
||||
axes[idx].legend(fontsize=8)
|
||||
|
||||
# Hide any extra empty subplots
|
||||
for idx in range(num_impls, len(axes)):
|
||||
axes[idx].set_visible(False)
|
||||
|
||||
plt.tight_layout()
|
||||
output_file_individual = sys.argv[1].replace('.json', '_individual.png')
|
||||
plt.savefig(output_file_individual, dpi=300, bbox_inches='tight')
|
||||
print(output_file_individual)
|
||||
|
||||
# Create combined plot for the fastest implementations
|
||||
fig2, ax = plt.subplots(1, 1, figsize=(10, 6))
|
||||
# Pick the top 3 fastest implementations (or fewer if there aren't that many)
|
||||
num_fast = min(3, num_impls)
|
||||
fast_implementations = list(summary_sorted.head(num_fast).index)
|
||||
df_fast = df[df['Implementation'].isin(fast_implementations)]
|
||||
|
||||
# Use the same colors as in individual plots
|
||||
palette = [impl_colors[impl] for impl in fast_implementations]
|
||||
sns.violinplot(data=df_fast, x='Implementation', y='Time (ms)', ax=ax, inner='box', palette=palette)
|
||||
ax.set_title(f'Cosine Similarity: Top {num_fast} Fastest Implementations', fontsize=14, fontweight='bold')
|
||||
ax.set_xlabel('Implementation', fontsize=12)
|
||||
ax.set_ylabel('Time (ms)', fontsize=12)
|
||||
ax.grid(True, alpha=0.3, axis='y')
|
||||
|
||||
# Add mean values as text
|
||||
for impl in fast_implementations:
|
||||
impl_data = df_fast[df_fast['Implementation'] == impl]
|
||||
mean_val = impl_data['Time (ms)'].mean()
|
||||
x_pos = list(fast_implementations).index(impl)
|
||||
ax.text(x_pos, mean_val, f'{mean_val:.4f} ms', ha='center', va='bottom', fontsize=10, fontweight='bold')
|
||||
|
||||
plt.tight_layout()
|
||||
output_file_combined = sys.argv[1].replace('.json', '_fast_comparison.png')
|
||||
plt.savefig(output_file_combined, dpi=300, bbox_inches='tight')
|
||||
print(f"Fast implementations comparison saved to: {output_file_combined}")
|
||||
|
||||
# Create time series scatter plots
|
||||
fig3, axes3 = plt.subplots(rows, cols, figsize=(7*cols, 5*rows))
|
||||
if num_impls == 1:
|
||||
axes3 = [axes3]
|
||||
else:
|
||||
axes3 = axes3.flatten()
|
||||
|
||||
for idx, impl in enumerate(implementations):
|
||||
impl_data = df[df['Implementation'] == impl].sort_values('Iteration')
|
||||
axes3[idx].scatter(impl_data['Iteration'], impl_data['Time (ms)'], alpha=0.5, s=10, color=impl_colors[impl])
|
||||
axes3[idx].set_title(f'{impl}', fontsize=12, fontweight='bold')
|
||||
axes3[idx].set_xlabel('Iteration', fontsize=10)
|
||||
axes3[idx].set_ylabel('Time (ms)', fontsize=10)
|
||||
axes3[idx].grid(True, alpha=0.3)
|
||||
|
||||
# Add mean line
|
||||
mean_val = impl_data['Time (ms)'].mean()
|
||||
axes3[idx].axhline(mean_val, color='red', linestyle='--', linewidth=1, alpha=0.7, label=f'Mean: {mean_val:.4f} ms')
|
||||
axes3[idx].legend(fontsize=8)
|
||||
|
||||
# Hide any extra empty subplots
|
||||
for idx in range(num_impls, len(axes3)):
|
||||
axes3[idx].set_visible(False)
|
||||
|
||||
plt.tight_layout()
|
||||
output_file_timeseries = sys.argv[1].replace('.json', '_timeseries.png')
|
||||
plt.savefig(output_file_timeseries, dpi=300, bbox_inches='tight')
|
||||
print(f"Time series scatter plots saved to: {output_file_timeseries}")
|
||||
|
||||
print("\nAll plots generated successfully!")
|
||||
Loading…
Add table
Add a link
Reference in a new issue