salience-editor/api/benchmarks/visualize_benchmarks.py

150 lines
5.4 KiB
Python
Raw Normal View History

2025-11-02 13:09:23 -08:00
"""
Visualize pytest-benchmark results with violin plots.
Usage: python visualize_benchmarks.py benchmark_results.json
"""
import json
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
# The one required argument is the path to the pytest-benchmark JSON report.
if not sys.argv[1:]:
    print("Usage: python visualize_benchmarks.py benchmark_results.json")
    raise SystemExit(1)
# Load the pytest-benchmark JSON report. JSON text is UTF-8 by specification,
# so decode it explicitly instead of relying on the platform's locale encoding.
with open(sys.argv[1], 'r', encoding='utf-8') as f:
    data = json.load(f)

# Flatten the report into one record per measured round:
#   {'Implementation': <pretty name>, 'Iteration': <round index>, 'Time (ms)': <float>}
# A report without a 'benchmarks' list is treated as empty and rejected below.
benchmarks = data.get('benchmarks', [])
timing_data = []
for bench in benchmarks:
    # e.g. "test_bench_numpy_dot" -> "Numpy Dot"
    name = bench['name'].replace('test_bench_', '').replace('_', ' ').title()
    stats = bench['stats']
    # The plots need every individual round, not just the aggregate statistics.
    if 'data' not in stats:
        print(f"ERROR: No raw timing data found for {name}", file=sys.stderr)
        print("Benchmark must be run with --benchmark-save-data to store raw data", file=sys.stderr)
        sys.exit(1)
    times = np.array(stats['data']) * 1000  # Convert to ms
    timing_data.extend(
        {'Implementation': name, 'Iteration': iteration, 'Time (ms)': elapsed}
        for iteration, elapsed in enumerate(times)
    )

# Everything downstream (groupby on 'Implementation', the subplot grid) fails
# on an empty table, so stop here with a clear message instead of a KeyError.
if not timing_data:
    print("ERROR: No benchmark timing data found in the input file", file=sys.stderr)
    sys.exit(1)
# Tabulate every round, then reduce to per-implementation statistics for ranking.
df = pd.DataFrame(timing_data)
per_impl_times = df.groupby('Implementation')['Time (ms)']
summary = per_impl_times.agg(['mean', 'median', 'std', 'min', 'max'])

# Speed relative to the slowest implementation, which scores exactly 1.0.
slowest_mean = summary['mean'].max()
summary['Speedup vs Slowest'] = slowest_mean / summary['mean']

# Lowest mean time (fastest) first.
summary_sorted = summary.sort_values('mean')
# Implementations in order of first appearance; this order drives subplot layout.
implementations = df['Implementation'].unique()
num_impls = len(implementations)

# One distinct husl color per implementation, shared by every figure so a given
# implementation is always drawn in the same color.
colors = sns.color_palette("husl", num_impls)
impl_colors = dict(zip(implementations, colors))
# Figure 1: one violin plot per implementation, in a grid at most three wide.
cols = min(3, num_impls)
rows = (num_impls + cols - 1) // cols  # Ceiling division
# squeeze=False always returns a 2-D array of Axes, so a single implementation
# needs no special case before flattening.
fig, axes = plt.subplots(rows, cols, figsize=(7 * cols, 5 * rows), squeeze=False)
axes = axes.flatten()
for panel, impl in zip(axes, implementations):
    impl_data = df[df['Implementation'] == impl]
    sns.violinplot(data=impl_data, y='Time (ms)', ax=panel, inner='box', color=impl_colors[impl])
    panel.set_title(f'{impl}', fontsize=12, fontweight='bold')
    panel.set_ylabel('Time (ms)', fontsize=10)
    panel.grid(True, alpha=0.3, axis='y')
    # Dashed red line marking the mean time.
    mean_val = impl_data['Time (ms)'].mean()
    panel.axhline(mean_val, color='red', linestyle='--', linewidth=1, alpha=0.7, label=f'Mean: {mean_val:.4f} ms')
    panel.legend(fontsize=8)
# Hide grid cells left over when num_impls is not a multiple of cols.
for panel in axes[num_impls:]:
    panel.set_visible(False)
fig.tight_layout()
# Derive the output name from the input's stem. str.replace('.json', ...) would
# return the input path unchanged (overwriting the input) when there is no
# '.json' in it, and would also rewrite '.json' in directory names.
output_file_individual = f"{os.path.splitext(sys.argv[1])[0]}_individual.png"
fig.savefig(output_file_individual, dpi=300, bbox_inches='tight')
plt.close(fig)
print(f"Individual violin plots saved to: {output_file_individual}")
# Figure 2: the (up to) three fastest implementations side by side, fastest first.
fig2, ax = plt.subplots(1, 1, figsize=(10, 6))
num_fast = min(3, num_impls)
fast_implementations = list(summary_sorted.head(num_fast).index)
df_fast = df[df['Implementation'].isin(fast_implementations)]
# Use the same colors as in the individual plots.
palette = [impl_colors[impl] for impl in fast_implementations]
# `order` pins category i to fast_implementations[i]. Without it seaborn places
# the violins in first-appearance order within df_fast, so the palette and the
# mean labels below (both indexed by fast_implementations) could be attached
# to the wrong violin.
sns.violinplot(
    data=df_fast,
    x='Implementation',
    y='Time (ms)',
    ax=ax,
    inner='box',
    order=fast_implementations,
    palette=palette,
)
ax.set_title(f'Cosine Similarity: Top {num_fast} Fastest Implementations', fontsize=14, fontweight='bold')
ax.set_xlabel('Implementation', fontsize=12)
ax.set_ylabel('Time (ms)', fontsize=12)
ax.grid(True, alpha=0.3, axis='y')
# Label each violin with its mean time; x position i is category i of `order`.
for x_pos, impl in enumerate(fast_implementations):
    mean_val = summary_sorted.at[impl, 'mean']
    ax.text(x_pos, mean_val, f'{mean_val:.4f} ms', ha='center', va='bottom', fontsize=10, fontweight='bold')
fig2.tight_layout()
# Name the output after the input's stem (see the note on str.replace pitfalls:
# an input without '.json' would otherwise be overwritten).
output_file_combined = f"{os.path.splitext(sys.argv[1])[0]}_fast_comparison.png"
fig2.savefig(output_file_combined, dpi=300, bbox_inches='tight')
plt.close(fig2)
print(f"Fast implementations comparison saved to: {output_file_combined}")
# Figure 3: per implementation, each round's time against its iteration index
# (useful for spotting trends across rounds), using the same grid as Figure 1.
# squeeze=False always returns a 2-D array of Axes, so no single-plot special case.
fig3, axes3 = plt.subplots(rows, cols, figsize=(7 * cols, 5 * rows), squeeze=False)
axes3 = axes3.flatten()
for panel, impl in zip(axes3, implementations):
    impl_data = df[df['Implementation'] == impl].sort_values('Iteration')
    panel.scatter(impl_data['Iteration'], impl_data['Time (ms)'], alpha=0.5, s=10, color=impl_colors[impl])
    panel.set_title(f'{impl}', fontsize=12, fontweight='bold')
    panel.set_xlabel('Iteration', fontsize=10)
    panel.set_ylabel('Time (ms)', fontsize=10)
    panel.grid(True, alpha=0.3)
    # Dashed red line marking the mean time.
    mean_val = impl_data['Time (ms)'].mean()
    panel.axhline(mean_val, color='red', linestyle='--', linewidth=1, alpha=0.7, label=f'Mean: {mean_val:.4f} ms')
    panel.legend(fontsize=8)
# Hide grid cells left over when num_impls is not a multiple of cols.
for panel in axes3[num_impls:]:
    panel.set_visible(False)
fig3.tight_layout()
# Name the output after the input's stem; str.replace('.json', ...) would
# overwrite an input whose path does not contain '.json'.
output_file_timeseries = f"{os.path.splitext(sys.argv[1])[0]}_timeseries.png"
fig3.savefig(output_file_timeseries, dpi=300, bbox_inches='tight')
plt.close(fig3)
print(f"Time series scatter plots saved to: {output_file_timeseries}")
print("\nAll plots generated successfully!")