"""
|
|
Visualize pytest-benchmark results with violin plots.
|
|
|
|
Usage: python visualize_benchmarks.py benchmark_results.json
|
|
"""

import sys
import json

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

if len(sys.argv) < 2:
    print("Usage: python visualize_benchmarks.py benchmark_results.json")
    sys.exit(1)

# Load benchmark results
with open(sys.argv[1], 'r') as f:
    data = json.load(f)

# Extract benchmark data
benchmarks = data['benchmarks']

# Create a list to store all timing data
timing_data = []

for bench in benchmarks:
    name = bench['name'].replace('test_bench_', '').replace('_', ' ').title()
    stats = bench['stats']

    # Require raw timing data
    if 'data' not in stats:
        print(f"ERROR: No raw timing data found for {name}", file=sys.stderr)
        print("Benchmarks must be run with --benchmark-save-data to store raw timings", file=sys.stderr)
        sys.exit(1)

    times = np.array(stats['data']) * 1000  # Convert to ms

    for iteration, time in enumerate(times):
        timing_data.append({
            'Implementation': name,
            'Iteration': iteration,
            'Time (ms)': time
        })

# Create DataFrame
df = pd.DataFrame(timing_data)

# Calculate summary statistics for ranking
summary = df.groupby('Implementation')['Time (ms)'].agg(['mean', 'median', 'std', 'min', 'max'])
slowest_mean = summary['mean'].max()
summary['Speedup vs Slowest'] = slowest_mean / summary['mean']
summary_sorted = summary.sort_values('mean')
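
# Optional addition (not in the original script): print the ranked summary,
# including the speedup-vs-slowest column computed above, so the numbers behind
# the plots are also visible on stdout.
print(summary_sorted.to_string())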

# Get unique implementations
implementations = df['Implementation'].unique()
num_impls = len(implementations)

# Define color palette for consistency - generate enough colors dynamically
colors = sns.color_palette("husl", num_impls)
impl_colors = {impl: colors[idx] for idx, impl in enumerate(implementations)}

# Create individual violin plots for each implementation
# Dynamically determine grid size
cols = min(3, num_impls)
rows = (num_impls + cols - 1) // cols  # Ceiling division
fig, axes = plt.subplots(rows, cols, figsize=(7*cols, 5*rows))
if num_impls == 1:
    axes = [axes]
else:
    axes = axes.flatten()

for idx, impl in enumerate(implementations):
    impl_data = df[df['Implementation'] == impl]
    sns.violinplot(data=impl_data, y='Time (ms)', ax=axes[idx], inner='box', color=impl_colors[impl])
    axes[idx].set_title(impl, fontsize=12, fontweight='bold')
    axes[idx].set_ylabel('Time (ms)', fontsize=10)
    axes[idx].grid(True, alpha=0.3, axis='y')

    # Add mean line
    mean_val = impl_data['Time (ms)'].mean()
    axes[idx].axhline(mean_val, color='red', linestyle='--', linewidth=1, alpha=0.7, label=f'Mean: {mean_val:.4f} ms')
    axes[idx].legend(fontsize=8)

# Hide any extra empty subplots
for idx in range(num_impls, len(axes)):
    axes[idx].set_visible(False)

plt.tight_layout()
output_file_individual = sys.argv[1].replace('.json', '_individual.png')
plt.savefig(output_file_individual, dpi=300, bbox_inches='tight')
print(f"Individual violin plots saved to: {output_file_individual}")

# Create combined plot for the fastest implementations
fig2, ax = plt.subplots(1, 1, figsize=(10, 6))

# Pick the top 3 fastest implementations (or fewer if there aren't that many)
num_fast = min(3, num_impls)
fast_implementations = list(summary_sorted.head(num_fast).index)
df_fast = df[df['Implementation'].isin(fast_implementations)]

# Use the same colors as in the individual plots.
# Note: seaborn 0.13+ warns that passing palette without hue is deprecated;
# assigning hue='Implementation' with legend=False is the forward-compatible form.
palette = [impl_colors[impl] for impl in fast_implementations]
sns.violinplot(data=df_fast, x='Implementation', y='Time (ms)', ax=ax, inner='box', palette=palette)
ax.set_title(f'Cosine Similarity: Top {num_fast} Fastest Implementations', fontsize=14, fontweight='bold')
ax.set_xlabel('Implementation', fontsize=12)
ax.set_ylabel('Time (ms)', fontsize=12)
ax.grid(True, alpha=0.3, axis='y')

# Add mean values as text
for impl in fast_implementations:
    impl_data = df_fast[df_fast['Implementation'] == impl]
    mean_val = impl_data['Time (ms)'].mean()
    x_pos = fast_implementations.index(impl)
    ax.text(x_pos, mean_val, f'{mean_val:.4f} ms', ha='center', va='bottom', fontsize=10, fontweight='bold')

plt.tight_layout()
output_file_combined = sys.argv[1].replace('.json', '_fast_comparison.png')
plt.savefig(output_file_combined, dpi=300, bbox_inches='tight')
print(f"Fast implementations comparison saved to: {output_file_combined}")

# Create time series scatter plots
fig3, axes3 = plt.subplots(rows, cols, figsize=(7*cols, 5*rows))
if num_impls == 1:
    axes3 = [axes3]
else:
    axes3 = axes3.flatten()

for idx, impl in enumerate(implementations):
    impl_data = df[df['Implementation'] == impl].sort_values('Iteration')
    axes3[idx].scatter(impl_data['Iteration'], impl_data['Time (ms)'], alpha=0.5, s=10, color=impl_colors[impl])
    axes3[idx].set_title(impl, fontsize=12, fontweight='bold')
    axes3[idx].set_xlabel('Iteration', fontsize=10)
    axes3[idx].set_ylabel('Time (ms)', fontsize=10)
    axes3[idx].grid(True, alpha=0.3)

    # Add mean line
    mean_val = impl_data['Time (ms)'].mean()
    axes3[idx].axhline(mean_val, color='red', linestyle='--', linewidth=1, alpha=0.7, label=f'Mean: {mean_val:.4f} ms')
    axes3[idx].legend(fontsize=8)

# Hide any extra empty subplots
for idx in range(num_impls, len(axes3)):
    axes3[idx].set_visible(False)

plt.tight_layout()
output_file_timeseries = sys.argv[1].replace('.json', '_timeseries.png')
plt.savefig(output_file_timeseries, dpi=300, bbox_inches='tight')
print(f"Time series scatter plots saved to: {output_file_timeseries}")

print("\nAll plots generated successfully!")