Initial commit

Bole Ma
2026-02-05 23:18:26 +01:00
commit 747c92ac6b
31 changed files with 4220 additions and 0 deletions

utils/__init__.py (new file, 3 lines)

@@ -0,0 +1,3 @@
"""Utility package for LLM benchmarking."""
__version__ = "1.0.0"

utils/attention.py (new file, 295 lines)

@@ -0,0 +1,295 @@
"""
Attention Implementation Helpers for LLM Benchmarking
Provides functions for configuring different attention implementations
based on GPU type.
"""
from typing import Optional
import warnings
def get_default_attention(gpu_name: str) -> str:
"""
Get default attention implementation for GPU type.
Args:
gpu_name: GPU device name (from monitoring)
Returns:
Attention implementation string
"""
gpu_lower = gpu_name.lower()
# H100/H200: FlashAttention-3 Hopper
if 'h100' in gpu_lower or 'h200' in gpu_lower:
return "flash_attention_3_hopper"
# A100, MI300X, other: FlashAttention-2
return "flash_attention_2"
def configure_model_attention(model, attn_implementation: str, verbose: bool = True):
"""
Configure model to use specified attention implementation.
This function patches the model if needed to use the specified attention.
For standard implementations like flash_attention_2, the model should already
be loaded with the correct implementation via AutoModelForCausalLM.from_pretrained().
For FlashAttention-3 Hopper, this patches the model's attention modules.
Args:
model: The loaded model
attn_implementation: Attention implementation to use
verbose: Print configuration messages
Returns:
Configured model
"""
if verbose:
print(f"Configuring attention: {attn_implementation}")
if attn_implementation == "flash_attention_3_hopper":
# Patch model to use FlashAttention-3 Hopper
try:
import flash_attn_interface
except ImportError:
raise ImportError(
"flash_attn_interface not found. This is required for FlashAttention-3.\n"
"Install with appropriate method for your system."
)
# Patch the model's attention function
_patch_fa3_hopper(model, verbose=verbose)
elif attn_implementation == "flash_attention_2":
# Model should already be loaded with FA2
if verbose:
print(" Using FlashAttention-2 (configured during model loading)")
elif attn_implementation == "sdpa":
# PyTorch Scaled Dot Product Attention
if verbose:
print(" Using PyTorch SDPA")
elif attn_implementation == "eager":
# Standard PyTorch attention
if verbose:
print(" Using eager attention")
else:
warnings.warn(f"Unknown attention implementation: {attn_implementation}")
return model
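# Illustrative end-to-end sketch (not called anywhere in this module): load a
# model with a loadable implementation, then apply FA3 patching on top. The
# model id is the one used in this repo's tests, and `transformers` is an
# assumed dependency.
def _example_load_and_configure(gpu_name: str = "NVIDIA H100 80GB"):
    from transformers import AutoModelForCausalLM  # assumed dependency
    attn = get_default_attention(gpu_name)
    # FA3 Hopper is applied as a patch, so the model itself is loaded with FA2
    load_attn = "flash_attention_2" if attn == "flash_attention_3_hopper" else attn
    model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen2.5-3B-Instruct",  # placeholder model id
        attn_implementation=load_attn,
        torch_dtype="auto",
    )
    return configure_model_attention(model, attn)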
def _patch_fa3_hopper(model, verbose: bool = True):
"""
Patch model to use FlashAttention-3 Hopper.
This replaces the attention computation in the model's attention layers
with calls to flash_attn_interface.flash_attn_func().
Args:
model: The model to patch
verbose: Print patching messages
"""
import flash_attn_interface
import torch
# Counter for patched modules
num_patched = 0
# Iterate through all modules in the model
for name, module in model.named_modules():
        # Look for attention modules (names vary by architecture; common leaf
        # names are "self_attn", "attn", and "attention"). Matching on the
        # last name component avoids also patching child projections such as
        # "self_attn.q_proj".
        if name.lower().rsplit('.', 1)[-1] in ('self_attn', 'attn', 'attention'):
            # Check if module has a forward method we can patch
            if hasattr(module, 'forward'):
                # Save original forward
                original_forward = module.forward

                # Create patched forward function. The module is passed in
                # explicitly so each closure binds its own layer; a bare
                # reference to the loop variable would leave every patched
                # forward pointing at the last module visited.
                def create_patched_forward(orig_forward, module):
                    def patched_forward(hidden_states, *args, **kwargs):
                        # Only handle layers that expose separate Q/K/V
                        # projections (as Qwen-style attention layers do)
                        if hasattr(module, 'q_proj') and hasattr(module, 'k_proj') and hasattr(module, 'v_proj'):
                            batch_size, seq_len, hidden_dim = hidden_states.shape
                            # Compute Q, K, V
                            q = module.q_proj(hidden_states)
                            k = module.k_proj(hidden_states)
                            v = module.v_proj(hidden_states)
                            # Reshape for multi-head attention; K/V may use
                            # fewer heads than Q (grouped-query attention)
                            num_heads = module.num_heads
                            head_dim = getattr(module, 'head_dim', hidden_dim // num_heads)
                            num_kv_heads = getattr(module, 'num_key_value_heads', num_heads)
                            q = q.view(batch_size, seq_len, num_heads, head_dim)
                            k = k.view(batch_size, seq_len, num_kv_heads, head_dim)
                            v = v.view(batch_size, seq_len, num_kv_heads, head_dim)
                            # Call FlashAttention-3
                            # Note: flash_attn_func expects (batch, seqlen, nheads, headdim).
                            # This simplified path applies neither rotary position
                            # embeddings nor attention masks, so it is meant for
                            # throughput/energy benchmarking, not exact generation.
                            attn_output = flash_attn_interface.flash_attn_func(
                                q, k, v,
                                softmax_scale=None,  # defaults to 1/sqrt(head_dim)
                                causal=True,  # for causal LM
                            )
                            # Some flash_attn_interface versions also return the
                            # softmax LSE alongside the output
                            if isinstance(attn_output, tuple):
                                attn_output = attn_output[0]
                            # Reshape back to (batch, seqlen, hidden)
                            attn_output = attn_output.reshape(batch_size, seq_len, -1)
                            # Apply output projection if it exists
                            if hasattr(module, 'o_proj'):
                                attn_output = module.o_proj(attn_output)
                            # Return (hidden_states, attn_weights); the previous
                            # version re-ran the unpatched forward just to measure
                            # the tuple length, doubling the attention work
                            return (attn_output, None)
                        # Not an attention layer we can patch; use the original
                        return orig_forward(hidden_states, *args, **kwargs)
                    return patched_forward

                # Apply patch
                module.forward = create_patched_forward(original_forward, module)
                num_patched += 1
if verbose:
if num_patched > 0:
print(f" ✓ Patched {num_patched} attention modules to use FlashAttention-3 Hopper")
else:
warnings.warn(" ⚠ No attention modules found to patch for FlashAttention-3")
def get_attention_info(attn_implementation: str) -> dict:
"""
Get information about an attention implementation.
Args:
attn_implementation: Attention implementation string
Returns:
Dictionary with info about the implementation
"""
info = {
"flash_attention_2": {
"name": "FlashAttention-2",
"description": "Optimized attention for A100 and other GPUs",
"gpu_support": ["A100", "MI300X", "V100", "RTX"],
"memory_efficient": True,
"requires_cuda": True,
},
"flash_attention_3_hopper": {
"name": "FlashAttention-3 Hopper",
"description": "Optimized attention for H100/H200 Hopper architecture",
"gpu_support": ["H100", "H200"],
"memory_efficient": True,
"requires_cuda": True,
},
"sdpa": {
"name": "PyTorch SDPA",
"description": "PyTorch Scaled Dot Product Attention",
"gpu_support": ["All"],
"memory_efficient": True,
"requires_cuda": False,
},
"eager": {
"name": "Eager Attention",
"description": "Standard PyTorch attention implementation",
"gpu_support": ["All"],
"memory_efficient": False,
"requires_cuda": False,
},
}
return info.get(attn_implementation, {
"name": attn_implementation,
"description": "Unknown attention implementation",
"gpu_support": ["Unknown"],
"memory_efficient": False,
"requires_cuda": False,
})
def validate_attention_for_gpu(attn_implementation: str, gpu_name: str) -> tuple[bool, Optional[str]]:
"""
Validate if attention implementation is suitable for GPU.
Args:
attn_implementation: Attention implementation
gpu_name: GPU device name
Returns:
Tuple of (is_valid, warning_message)
"""
gpu_lower = gpu_name.lower()
# FlashAttention-3 Hopper validation
if attn_implementation == "flash_attention_3_hopper":
if 'h100' not in gpu_lower and 'h200' not in gpu_lower:
return False, (
f"FlashAttention-3 Hopper is optimized for H100/H200. "
f"Current GPU: {gpu_name}. Consider using flash_attention_2 instead."
)
# FlashAttention-2 on Hopper GPUs
if attn_implementation == "flash_attention_2":
if 'h100' in gpu_lower or 'h200' in gpu_lower:
return True, (
f"FlashAttention-2 will work on {gpu_name}, but FlashAttention-3 Hopper "
f"may provide better performance."
)
return True, None
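# A minimal selection helper combining the pieces above (sketch; not used by
# the benchmarks directly): pick the default, validate it, and fall back to
# FlashAttention-2 when validation fails, e.g. for a user-overridden choice.
def _select_attention(gpu_name: str) -> str:
    attn = get_default_attention(gpu_name)
    is_valid, message = validate_attention_for_gpu(attn, gpu_name)
    if message:
        warnings.warn(message)
    return attn if is_valid else "flash_attention_2"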
if __name__ == "__main__":
"""Test attention configuration."""
print("=" * 60)
print("Attention Implementation Test")
print("=" * 60)
# Test getting default attention for different GPUs
test_gpus = [
"NVIDIA A100 80GB",
"NVIDIA H100 80GB",
"NVIDIA H200 141GB",
"AMD Instinct MI300X",
]
print("\nDefault attention implementations:")
for gpu in test_gpus:
attn = get_default_attention(gpu)
print(f" {gpu:30s}{attn}")
# Test validation
print("\nValidation tests:")
test_cases = [
("flash_attention_3_hopper", "NVIDIA H100 80GB"),
("flash_attention_3_hopper", "NVIDIA A100 80GB"),
("flash_attention_2", "NVIDIA H100 80GB"),
("flash_attention_2", "NVIDIA A100 80GB"),
]
for attn, gpu in test_cases:
valid, warning = validate_attention_for_gpu(attn, gpu)
status = "" if valid else ""
print(f" {status} {attn:30s} on {gpu:25s}")
if warning:
print(f"{warning}")
# Test getting info
print("\nAttention implementation info:")
for attn in ["flash_attention_2", "flash_attention_3_hopper", "sdpa"]:
info = get_attention_info(attn)
print(f"\n {info['name']}:")
print(f" Description: {info['description']}")
print(f" GPU Support: {', '.join(info['gpu_support'])}")
print(f" Memory Efficient: {info['memory_efficient']}")

utils/gpu_monitor.py (new file, 562 lines)

@@ -0,0 +1,562 @@
"""
GPU Monitoring Infrastructure for LLM Benchmarking
Provides unified interface for monitoring both NVIDIA and AMD GPUs.
"""
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional, List
import warnings
@dataclass
class GPUMetrics:
"""Container for GPU metrics."""
timestamp: float
power_watts: float
gpu_utilization_percent: float
memory_used_gb: float
memory_total_gb: float
temperature_celsius: Optional[float] = None
energy_joules: Optional[float] = None # Cumulative energy
class GPUMonitor(ABC):
"""Abstract base class for GPU monitoring."""
def __init__(self, device_id: int = 0):
"""
Initialize GPU monitor.
Args:
device_id: GPU device ID to monitor
"""
self.device_id = device_id
self.start_time = None
self.start_energy = None
self.last_metrics = None
@abstractmethod
def get_metrics(self) -> GPUMetrics:
"""Get current GPU metrics."""
pass
@abstractmethod
def get_device_name(self) -> str:
"""Get GPU device name."""
pass
@abstractmethod
def cleanup(self):
"""Cleanup resources."""
pass
def start_monitoring(self):
"""Start energy monitoring session."""
self.start_time = time.time()
metrics = self.get_metrics()
self.start_energy = metrics.energy_joules if metrics.energy_joules is not None else 0.0
self.last_metrics = metrics
def get_energy_consumed(self) -> float:
"""
Get energy consumed since start_monitoring() was called.
Returns:
Energy in Joules
"""
if self.start_time is None:
raise RuntimeError("Must call start_monitoring() first")
current_metrics = self.get_metrics()
if current_metrics.energy_joules is not None:
# If GPU provides cumulative energy, use it
return current_metrics.energy_joules - self.start_energy
else:
# Otherwise, integrate power over time
elapsed_time = time.time() - self.start_time
# Use average of start and current power
avg_power = (self.last_metrics.power_watts + current_metrics.power_watts) / 2.0
return avg_power * elapsed_time
def get_average_power(self) -> float:
"""
Get average power consumption since start_monitoring().
Returns:
Average power in Watts
"""
if self.start_time is None:
raise RuntimeError("Must call start_monitoring() first")
elapsed_time = time.time() - self.start_time
if elapsed_time == 0:
return 0.0
energy = self.get_energy_consumed()
return energy / elapsed_time
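# Worked example of the power-integration fallback above (numbers are
# illustrative): if power was 250 W at start_monitoring() and is 300 W now,
# then after 10 s the estimate is ((250 + 300) / 2) * 10 = 2750 J, i.e. a
# single trapezoid over the whole interval. For long or bursty workloads,
# sampling get_metrics() periodically would give a tighter estimate.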
class NVIDIAMonitor(GPUMonitor):
"""NVIDIA GPU monitor using pynvml."""
def __init__(self, device_id: int = 0):
"""Initialize NVIDIA monitor."""
try:
import pynvml
self.pynvml = pynvml
except ImportError:
raise ImportError(
"pynvml not found. Install with: pip install pynvml"
)
try:
self.pynvml.nvmlInit()
self.handle = self.pynvml.nvmlDeviceGetHandleByIndex(device_id)
except Exception as e:
raise RuntimeError(f"Failed to initialize NVIDIA GPU {device_id}: {e}")
super().__init__(device_id)
def get_metrics(self) -> GPUMetrics:
"""Get current NVIDIA GPU metrics."""
try:
# Power (in milliwatts)
power_mw = self.pynvml.nvmlDeviceGetPowerUsage(self.handle)
power_watts = power_mw / 1000.0
# Utilization
util = self.pynvml.nvmlDeviceGetUtilizationRates(self.handle)
gpu_util = util.gpu
# Memory
mem_info = self.pynvml.nvmlDeviceGetMemoryInfo(self.handle)
memory_used_gb = mem_info.used / (1024**3)
memory_total_gb = mem_info.total / (1024**3)
# Temperature
try:
temp = self.pynvml.nvmlDeviceGetTemperature(
self.handle,
self.pynvml.NVML_TEMPERATURE_GPU
)
            except Exception:
temp = None
# Try to get cumulative energy (newer GPUs)
energy_joules = None
try:
energy_mj = self.pynvml.nvmlDeviceGetTotalEnergyConsumption(self.handle)
energy_joules = energy_mj / 1000.0
            except Exception:
# Not supported on this GPU, will use power integration
pass
return GPUMetrics(
timestamp=time.time(),
power_watts=power_watts,
gpu_utilization_percent=gpu_util,
memory_used_gb=memory_used_gb,
memory_total_gb=memory_total_gb,
temperature_celsius=temp,
energy_joules=energy_joules
)
except Exception as e:
raise RuntimeError(f"Failed to get NVIDIA GPU metrics: {e}")
def get_device_name(self) -> str:
"""Get NVIDIA GPU device name."""
try:
name = self.pynvml.nvmlDeviceGetName(self.handle)
if isinstance(name, bytes):
name = name.decode('utf-8')
return name
        except Exception:
return f"NVIDIA GPU {self.device_id}"
def cleanup(self):
"""Cleanup NVIDIA resources."""
try:
self.pynvml.nvmlShutdown()
        except Exception:
pass
class AMDMonitor(GPUMonitor):
"""AMD GPU monitor using rocm-smi command line tool."""
def __init__(self, device_id: int = 0):
"""Initialize AMD monitor."""
import subprocess
import shutil
# Check if rocm-smi is available
if shutil.which('rocm-smi') is None:
raise RuntimeError("rocm-smi command not found. Make sure ROCm is installed and in PATH.")
self.device_id = device_id
# Verify device exists
try:
result = subprocess.run(
['rocm-smi', '--showid'],
capture_output=True,
text=True,
timeout=5
)
if result.returncode != 0:
raise RuntimeError(f"rocm-smi failed: {result.stderr}")
except subprocess.TimeoutExpired:
raise RuntimeError("rocm-smi command timed out")
except Exception as e:
raise RuntimeError(f"Failed to initialize AMD GPU {device_id}: {e}")
super().__init__(device_id)
def _parse_detailed_output(self, output: str) -> dict:
"""Parse rocm-smi detailed output format."""
lines = output.strip().split('\n')
# Parse detailed format: GPU[X] : Metric : Value
metrics = {
'temperature': None,
'power': None,
'vram_percent': None,
'gpu_percent': None,
}
device_prefix = f"GPU[{self.device_id}]"
for line in lines:
if not line.strip() or not line.startswith(device_prefix):
continue
# Split by colon
parts = line.split(':')
if len(parts) < 3:
continue
metric_name = parts[1].strip().lower()
value_str = parts[2].strip()
try:
# Temperature (Sensor junction)
if 'temperature' in metric_name and 'junction' in metric_name:
metrics['temperature'] = float(value_str)
# Power consumption
elif 'power' in metric_name and 'package' in metric_name:
metrics['power'] = float(value_str)
# GPU utilization
elif 'gpu use' in metric_name:
metrics['gpu_percent'] = float(value_str)
# VRAM usage percentage
elif 'memory allocated' in metric_name and 'vram%' in metric_name:
metrics['vram_percent'] = float(value_str)
except (ValueError, IndexError):
continue
# Validate we got the required metrics
if metrics['temperature'] is None:
raise ValueError(f"Could not find temperature for GPU[{self.device_id}]")
if metrics['power'] is None:
raise ValueError(f"Could not find power for GPU[{self.device_id}]")
if metrics['gpu_percent'] is None:
metrics['gpu_percent'] = 0.0
if metrics['vram_percent'] is None:
metrics['vram_percent'] = 0.0
return metrics
def _get_memory_info(self) -> tuple:
"""Get memory usage in GB using rocm-smi --showmeminfo."""
import subprocess
try:
result = subprocess.run(
['rocm-smi', '--showmeminfo', 'vram', '-d', str(self.device_id)],
capture_output=True,
text=True,
timeout=5
)
if result.returncode != 0:
return 0.0, 0.0
            # Parse output for memory info. Depending on the ROCm version,
            # values are reported either in plain bytes
            # ("VRAM Total Memory (B): ...") or with a unit token after the
            # number.
            used_gb = 0.0
            total_gb = 0.0
            for line in result.stdout.split('\n'):
                if 'Used' in line or 'used' in line:
                    # Extract the numeric value and its unit (if any)
                    parts = line.split()
                    for i, part in enumerate(parts):
                        if part.replace('.', '').isdigit():
                            value = float(part)
                            unit = parts[i + 1].lower() if i + 1 < len(parts) else ''
                            if 'mb' in unit or 'mib' in unit:
                                used_gb = value / 1024
                            elif 'gb' in unit or 'gib' in unit:
                                used_gb = value
                            elif 'kb' in unit or 'kib' in unit:
                                used_gb = value / (1024 * 1024)
                            else:
                                # No unit token: assume plain bytes
                                used_gb = value / (1024 ** 3)
                            break
                # elif, so "Total Used Memory" lines do not also match here
                elif 'Total' in line or 'total' in line:
                    parts = line.split()
                    for i, part in enumerate(parts):
                        if part.replace('.', '').isdigit():
                            value = float(part)
                            unit = parts[i + 1].lower() if i + 1 < len(parts) else ''
                            if 'mb' in unit or 'mib' in unit:
                                total_gb = value / 1024
                            elif 'gb' in unit or 'gib' in unit:
                                total_gb = value
                            elif 'kb' in unit or 'kib' in unit:
                                total_gb = value / (1024 * 1024)
                            else:
                                # No unit token: assume plain bytes
                                total_gb = value / (1024 ** 3)
                            break
return used_gb, total_gb
except Exception:
return 0.0, 0.0
def get_metrics(self) -> GPUMetrics:
"""Get current AMD GPU metrics."""
import subprocess
try:
            # Query one "GPU[X] : Metric : Value" line per requested metric
result = subprocess.run(
['rocm-smi', '--showid', '--showtemp', '--showpower', '--showuse', '--showmemuse'],
capture_output=True,
text=True,
timeout=5
)
if result.returncode != 0:
raise RuntimeError(f"rocm-smi failed: {result.stderr}")
metrics = self._parse_detailed_output(result.stdout)
# Get detailed memory info
memory_used_gb, memory_total_gb = self._get_memory_info()
# If we couldn't get absolute memory, estimate from percentage
if memory_total_gb == 0.0:
# MI300X has ~192GB, MI250X has ~128GB - use a reasonable default
memory_total_gb = 192.0 # Assume MI300X
memory_used_gb = memory_total_gb * (metrics['vram_percent'] / 100.0)
return GPUMetrics(
timestamp=time.time(),
power_watts=metrics['power'],
gpu_utilization_percent=metrics['gpu_percent'],
memory_used_gb=memory_used_gb,
memory_total_gb=memory_total_gb,
temperature_celsius=metrics['temperature'],
energy_joules=None # Will use power integration
)
except subprocess.TimeoutExpired:
raise RuntimeError("rocm-smi command timed out")
except Exception as e:
raise RuntimeError(f"Failed to get AMD GPU metrics: {e}")
def get_device_name(self) -> str:
"""Get AMD GPU device name."""
import subprocess
try:
result = subprocess.run(
['rocm-smi', '--showproductname', '-d', str(self.device_id)],
capture_output=True,
text=True,
timeout=5
)
if result.returncode == 0:
# Parse output to find device name
for line in result.stdout.split('\n'):
if 'Card series' in line or 'Card model' in line or 'name' in line.lower():
parts = line.split(':')
if len(parts) > 1:
return parts[1].strip()
except Exception:
pass
return f"AMD GPU {self.device_id}"
def cleanup(self):
"""Cleanup AMD resources."""
# No cleanup needed for command-line tool
pass
def get_gpu_monitor(device_id: int = 0) -> GPUMonitor:
"""
Factory function to automatically detect and create appropriate GPU monitor.
Args:
device_id: GPU device ID to monitor
Returns:
GPUMonitor instance (NVIDIAMonitor or AMDMonitor)
Raises:
RuntimeError: If no supported GPU is found
"""
    # Try AMD first: rocm-smi only exists on ROCm systems, so probing it is
    # cheap and cannot misdetect an NVIDIA machine
try:
return AMDMonitor(device_id)
    except Exception:
pass
# Try NVIDIA if AMD fails
try:
return NVIDIAMonitor(device_id)
    except Exception:
pass
# Try to import torch to detect GPU type as last resort
try:
import torch
if torch.cuda.is_available():
# Check if it's NVIDIA or AMD
device_name = torch.cuda.get_device_name(device_id).lower()
if 'nvidia' in device_name or 'tesla' in device_name or 'geforce' in device_name:
return NVIDIAMonitor(device_id)
elif 'amd' in device_name or 'radeon' in device_name or 'mi300' in device_name or 'mi200' in device_name:
return AMDMonitor(device_id)
    except Exception:
pass
raise RuntimeError(
"No supported GPU found. Make sure either ROCm (rocm-smi) or NVIDIA (pynvml) drivers are installed."
)
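# Typical measurement pattern built on the factory above (sketch; assumes a
# supported GPU is present and `workload` is any callable to measure):
def _measure_energy(workload, device_id: int = 0) -> float:
    """Run `workload()` and return the energy it consumed, in Joules."""
    monitor = get_gpu_monitor(device_id)
    try:
        monitor.start_monitoring()
        workload()
        return monitor.get_energy_consumed()
    finally:
        monitor.cleanup()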
def list_available_gpus() -> List[str]:
"""
List all available GPUs.
Returns:
List of GPU names
"""
gpus = []
# Try NVIDIA
try:
import pynvml
pynvml.nvmlInit()
device_count = pynvml.nvmlDeviceGetCount()
for i in range(device_count):
handle = pynvml.nvmlDeviceGetHandleByIndex(i)
name = pynvml.nvmlDeviceGetName(handle)
if isinstance(name, bytes):
name = name.decode('utf-8')
gpus.append(f"GPU {i}: {name} (NVIDIA)")
pynvml.nvmlShutdown()
    except Exception:
pass
# Try AMD with rocm-smi
try:
import subprocess
import shutil
if shutil.which('rocm-smi'):
result = subprocess.run(
['rocm-smi', '--showid'],
capture_output=True,
text=True,
timeout=5
)
if result.returncode == 0:
# Parse device IDs from output
for line in result.stdout.split('\n'):
if not line.strip() or line.startswith('=') or 'Device' in line or 'ROCm' in line:
continue
parts = line.split()
if parts and parts[0].isdigit():
device_id = int(parts[0])
# Try to get device name
name_result = subprocess.run(
['rocm-smi', '--showproductname', '-d', str(device_id)],
capture_output=True,
text=True,
timeout=5
)
name = f"AMD GPU"
if name_result.returncode == 0:
for name_line in name_result.stdout.split('\n'):
if 'Card' in name_line or 'name' in name_line.lower():
parts_name = name_line.split(':')
if len(parts_name) > 1:
name = parts_name[1].strip()
break
gpus.append(f"GPU {device_id}: {name} (AMD)")
    except Exception:
pass
return gpus
if __name__ == "__main__":
"""Test GPU monitoring."""
print("=" * 60)
print("GPU Monitoring Test")
print("=" * 60)
# List available GPUs
print("\nAvailable GPUs:")
gpus = list_available_gpus()
if not gpus:
print(" No GPUs found!")
exit(1)
for gpu in gpus:
print(f" {gpu}")
# Test monitoring
print("\nTesting GPU 0 monitoring...")
try:
monitor = get_gpu_monitor(0)
print(f" Device: {monitor.get_device_name()}")
# Get metrics
metrics = monitor.get_metrics()
print(f"\nCurrent Metrics:")
print(f" Power: {metrics.power_watts:.2f} W")
print(f" GPU Utilization: {metrics.gpu_utilization_percent:.1f}%")
print(f" Memory: {metrics.memory_used_gb:.2f} / {metrics.memory_total_gb:.2f} GB")
        if metrics.temperature_celsius is not None:
print(f" Temperature: {metrics.temperature_celsius:.1f}°C")
# Test energy monitoring
print("\nTesting energy monitoring (5 seconds)...")
monitor.start_monitoring()
time.sleep(5)
energy = monitor.get_energy_consumed()
avg_power = monitor.get_average_power()
print(f" Energy consumed: {energy:.2f} J")
print(f" Average power: {avg_power:.2f} W")
monitor.cleanup()
print("\n✓ Monitoring test successful!")
except Exception as e:
print(f"\n✗ Error: {e}")
exit(1)

utils/metrics.py (new file, 473 lines)

@@ -0,0 +1,473 @@
"""
Metrics Collection and Reporting for LLM Benchmarking
Provides centralized metrics collection, aggregation, and reporting.
"""
import json
import csv
from dataclasses import dataclass, asdict, field
from typing import Dict, List, Optional, Any
from pathlib import Path
import time
@dataclass
class StageMetrics:
"""Metrics for a specific stage (e.g., forward pass, prefill, etc.)."""
stage_name: str
duration_ms: float
tokens_processed: int
tokens_per_second: float
energy_joules: float
energy_per_token: float
avg_power_watts: float
peak_memory_gb: float
avg_gpu_util_percent: float
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary."""
return asdict(self)
@dataclass
class PretrainMetrics:
"""Metrics for pretraining benchmark."""
model_name: str
gpu_name: str
attention_implementation: str
batch_size: int
sequence_length: int
num_steps: int
# Stage-specific metrics
forward: StageMetrics
backward: StageMetrics
optimizer: StageMetrics
# Overall metrics
total_duration_ms: float
total_tokens: int
total_tokens_per_second: float
total_energy_joules: float
total_energy_per_token: float
timestamp: float = field(default_factory=time.time)
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary."""
return {
"model_name": self.model_name,
"gpu_name": self.gpu_name,
"attention_implementation": self.attention_implementation,
"batch_size": self.batch_size,
"sequence_length": self.sequence_length,
"num_steps": self.num_steps,
"forward": self.forward.to_dict(),
"backward": self.backward.to_dict(),
"optimizer": self.optimizer.to_dict(),
"total_duration_ms": self.total_duration_ms,
"total_tokens": self.total_tokens,
"total_tokens_per_second": self.total_tokens_per_second,
"total_energy_joules": self.total_energy_joules,
"total_energy_per_token": self.total_energy_per_token,
"timestamp": self.timestamp,
}
@dataclass
class InferenceMetrics:
"""Metrics for inference benchmark."""
model_name: str
gpu_name: str
attention_implementation: str
num_requests: int
prompt_length: int
generation_length: int
# Stage-specific metrics
prefill: StageMetrics # Time to First Token
decode: StageMetrics # Inter-Token Latency
# End-to-end metrics
e2e_latency_ms: float
e2e_tokens_per_second: float
e2e_energy_joules: float
e2e_energy_per_token: float
# Additional metrics
ttft_ms: float # Time to First Token (same as prefill duration)
itl_ms: float # Inter-Token Latency (decode duration / num_tokens)
timestamp: float = field(default_factory=time.time)
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary."""
return {
"model_name": self.model_name,
"gpu_name": self.gpu_name,
"attention_implementation": self.attention_implementation,
"num_requests": self.num_requests,
"prompt_length": self.prompt_length,
"generation_length": self.generation_length,
"prefill": self.prefill.to_dict(),
"decode": self.decode.to_dict(),
"e2e_latency_ms": self.e2e_latency_ms,
"e2e_tokens_per_second": self.e2e_tokens_per_second,
"e2e_energy_joules": self.e2e_energy_joules,
"e2e_energy_per_token": self.e2e_energy_per_token,
"ttft_ms": self.ttft_ms,
"itl_ms": self.itl_ms,
"timestamp": self.timestamp,
}
class MetricsCollector:
"""Collects metrics during benchmark runs."""
def __init__(self):
"""Initialize metrics collector."""
self.metrics_history: List[Dict[str, Any]] = []
def add_pretrain_metrics(self, metrics: PretrainMetrics):
"""Add pretraining metrics."""
self.metrics_history.append({
"type": "pretrain",
"metrics": metrics.to_dict()
})
def add_inference_metrics(self, metrics: InferenceMetrics):
"""Add inference metrics."""
self.metrics_history.append({
"type": "inference",
"metrics": metrics.to_dict()
})
def get_all_metrics(self) -> List[Dict[str, Any]]:
"""Get all collected metrics."""
return self.metrics_history
def clear(self):
"""Clear all metrics."""
self.metrics_history.clear()
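# Intended collection flow (sketch; the metrics objects are placeholders that
# would come from actual benchmark runs):
#
#   collector = MetricsCollector()
#   collector.add_pretrain_metrics(pretrain_metrics)
#   collector.add_inference_metrics(inference_metrics)
#   for entry in collector.get_all_metrics():
#       print(entry["type"], entry["metrics"]["gpu_name"])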
class MetricsReporter:
"""Formats and outputs benchmark results."""
@staticmethod
def print_pretrain_metrics(metrics: PretrainMetrics, verbose: bool = True):
"""Print pretraining metrics to console."""
print("\n" + "=" * 80)
print("PRETRAINING BENCHMARK RESULTS")
print("=" * 80)
print(f"\nModel: {metrics.model_name}")
print(f"GPU: {metrics.gpu_name}")
print(f"Attention: {metrics.attention_implementation}")
print(f"Batch Size: {metrics.batch_size}")
print(f"Sequence Length: {metrics.sequence_length}")
print(f"Training Steps: {metrics.num_steps}")
print("\n" + "-" * 80)
print("STAGE BREAKDOWN")
print("-" * 80)
# Forward pass
print(f"\n[1] FORWARD PASS")
MetricsReporter._print_stage_metrics(metrics.forward, verbose)
# Backward pass
print(f"\n[2] BACKWARD PASS")
MetricsReporter._print_stage_metrics(metrics.backward, verbose)
# Optimizer step
print(f"\n[3] OPTIMIZER STEP")
MetricsReporter._print_stage_metrics(metrics.optimizer, verbose)
# Overall
print("\n" + "-" * 80)
print("OVERALL METRICS")
print("-" * 80)
print(f" Total Duration: {metrics.total_duration_ms:>10.2f} ms")
print(f" Total Tokens: {metrics.total_tokens:>10,}")
print(f" Throughput: {metrics.total_tokens_per_second:>10.2f} tokens/s")
print(f" Total Energy: {metrics.total_energy_joules:>10.2f} J")
print(f" Energy per Token: {metrics.total_energy_per_token*1000:>10.4f} mJ/token")
print("=" * 80 + "\n")
@staticmethod
def print_inference_metrics(metrics: InferenceMetrics, verbose: bool = True):
"""Print inference metrics to console."""
print("\n" + "=" * 80)
print("INFERENCE BENCHMARK RESULTS")
print("=" * 80)
print(f"\nModel: {metrics.model_name}")
print(f"GPU: {metrics.gpu_name}")
print(f"Attention: {metrics.attention_implementation}")
print(f"Requests: {metrics.num_requests}")
print(f"Prompt Length: {metrics.prompt_length}")
print(f"Generation Length: {metrics.generation_length}")
print("\n" + "-" * 80)
print("STAGE BREAKDOWN")
print("-" * 80)
# Prefill
print(f"\n[1] PREFILL (Time to First Token)")
MetricsReporter._print_stage_metrics(metrics.prefill, verbose)
print(f" TTFT: {metrics.ttft_ms:>10.2f} ms")
# Decode
print(f"\n[2] DECODE (Inter-Token Latency)")
MetricsReporter._print_stage_metrics(metrics.decode, verbose)
print(f" ITL: {metrics.itl_ms:>10.2f} ms/token")
# End-to-end
print("\n" + "-" * 80)
print("END-TO-END METRICS")
print("-" * 80)
print(f" Request Latency: {metrics.e2e_latency_ms:>10.2f} ms")
print(f" Throughput: {metrics.e2e_tokens_per_second:>10.2f} tokens/s")
print(f" Total Energy: {metrics.e2e_energy_joules:>10.2f} J")
print(f" Energy per Token: {metrics.e2e_energy_per_token*1000:>10.4f} mJ/token")
print("=" * 80 + "\n")
@staticmethod
def _print_stage_metrics(stage: StageMetrics, verbose: bool = True):
"""Print metrics for a single stage."""
print(f" Duration: {stage.duration_ms:>10.2f} ms")
print(f" Tokens: {stage.tokens_processed:>10,}")
print(f" Throughput: {stage.tokens_per_second:>10.2f} tokens/s")
print(f" Energy: {stage.energy_joules:>10.2f} J")
print(f" Energy per Token: {stage.energy_per_token*1000:>10.4f} mJ/token")
if verbose:
print(f" Avg Power: {stage.avg_power_watts:>10.2f} W")
print(f" Peak Memory: {stage.peak_memory_gb:>10.2f} GB")
print(f" Avg GPU Utilization: {stage.avg_gpu_util_percent:>10.1f} %")
@staticmethod
def save_json(metrics: Any, output_path: Path):
"""
Save metrics to JSON file.
Args:
metrics: PretrainMetrics or InferenceMetrics object
output_path: Path to output JSON file
"""
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, 'w') as f:
json.dump(metrics.to_dict(), f, indent=2)
print(f"Metrics saved to: {output_path}")
@staticmethod
def save_csv(metrics_list: List[Any], output_path: Path, benchmark_type: str = "pretrain"):
"""
Save multiple metrics to CSV file for comparison.
Args:
metrics_list: List of PretrainMetrics or InferenceMetrics objects
output_path: Path to output CSV file
benchmark_type: "pretrain" or "inference"
"""
if not metrics_list:
print("No metrics to save")
return
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, 'w', newline='') as f:
if benchmark_type == "pretrain":
MetricsReporter._save_pretrain_csv(metrics_list, f)
else:
MetricsReporter._save_inference_csv(metrics_list, f)
print(f"CSV saved to: {output_path}")
@staticmethod
def _save_pretrain_csv(metrics_list: List[PretrainMetrics], file):
"""Save pretraining metrics to CSV."""
fieldnames = [
            'model_name', 'gpu_name', 'attention_implementation', 'batch_size', 'sequence_length', 'num_steps',
'forward_duration_ms', 'forward_tokens_per_sec', 'forward_energy_j', 'forward_energy_per_token_mj',
'backward_duration_ms', 'backward_tokens_per_sec', 'backward_energy_j', 'backward_energy_per_token_mj',
'optimizer_duration_ms', 'optimizer_tokens_per_sec', 'optimizer_energy_j', 'optimizer_energy_per_token_mj',
'total_duration_ms', 'total_tokens_per_sec', 'total_energy_j', 'total_energy_per_token_mj',
'timestamp'
]
writer = csv.DictWriter(file, fieldnames=fieldnames)
writer.writeheader()
for m in metrics_list:
writer.writerow({
                'model_name': m.model_name,
                'gpu_name': m.gpu_name,
'attention_implementation': m.attention_implementation,
'batch_size': m.batch_size,
'sequence_length': m.sequence_length,
'num_steps': m.num_steps,
'forward_duration_ms': m.forward.duration_ms,
'forward_tokens_per_sec': m.forward.tokens_per_second,
'forward_energy_j': m.forward.energy_joules,
'forward_energy_per_token_mj': m.forward.energy_per_token * 1000,
'backward_duration_ms': m.backward.duration_ms,
'backward_tokens_per_sec': m.backward.tokens_per_second,
'backward_energy_j': m.backward.energy_joules,
'backward_energy_per_token_mj': m.backward.energy_per_token * 1000,
'optimizer_duration_ms': m.optimizer.duration_ms,
'optimizer_tokens_per_sec': m.optimizer.tokens_per_second,
'optimizer_energy_j': m.optimizer.energy_joules,
'optimizer_energy_per_token_mj': m.optimizer.energy_per_token * 1000,
'total_duration_ms': m.total_duration_ms,
'total_tokens_per_sec': m.total_tokens_per_second,
'total_energy_j': m.total_energy_joules,
'total_energy_per_token_mj': m.total_energy_per_token * 1000,
'timestamp': m.timestamp,
})
@staticmethod
def _save_inference_csv(metrics_list: List[InferenceMetrics], file):
"""Save inference metrics to CSV."""
fieldnames = [
            'model_name', 'gpu_name', 'attention_implementation', 'num_requests', 'prompt_length', 'generation_length',
'prefill_duration_ms', 'prefill_tokens_per_sec', 'prefill_energy_j', 'prefill_energy_per_token_mj',
'ttft_ms',
'decode_duration_ms', 'decode_tokens_per_sec', 'decode_energy_j', 'decode_energy_per_token_mj',
'itl_ms',
'e2e_latency_ms', 'e2e_tokens_per_sec', 'e2e_energy_j', 'e2e_energy_per_token_mj',
'timestamp'
]
writer = csv.DictWriter(file, fieldnames=fieldnames)
writer.writeheader()
for m in metrics_list:
writer.writerow({
                'model_name': m.model_name,
                'gpu_name': m.gpu_name,
'attention_implementation': m.attention_implementation,
'num_requests': m.num_requests,
'prompt_length': m.prompt_length,
'generation_length': m.generation_length,
'prefill_duration_ms': m.prefill.duration_ms,
'prefill_tokens_per_sec': m.prefill.tokens_per_second,
'prefill_energy_j': m.prefill.energy_joules,
'prefill_energy_per_token_mj': m.prefill.energy_per_token * 1000,
'ttft_ms': m.ttft_ms,
'decode_duration_ms': m.decode.duration_ms,
'decode_tokens_per_sec': m.decode.tokens_per_second,
'decode_energy_j': m.decode.energy_joules,
'decode_energy_per_token_mj': m.decode.energy_per_token * 1000,
'itl_ms': m.itl_ms,
'e2e_latency_ms': m.e2e_latency_ms,
'e2e_tokens_per_sec': m.e2e_tokens_per_second,
'e2e_energy_j': m.e2e_energy_joules,
'e2e_energy_per_token_mj': m.e2e_energy_per_token * 1000,
'timestamp': m.timestamp,
})
if __name__ == "__main__":
"""Test metrics reporting."""
# Create sample pretraining metrics
forward = StageMetrics(
stage_name="forward",
duration_ms=100.5,
tokens_processed=1024,
tokens_per_second=10189.3,
energy_joules=25.3,
energy_per_token=0.0247,
avg_power_watts=251.7,
peak_memory_gb=45.2,
avg_gpu_util_percent=95.3
)
backward = StageMetrics(
stage_name="backward",
duration_ms=205.2,
tokens_processed=1024,
tokens_per_second=4991.2,
energy_joules=51.6,
energy_per_token=0.0504,
avg_power_watts=251.5,
peak_memory_gb=48.6,
avg_gpu_util_percent=97.1
)
optimizer = StageMetrics(
stage_name="optimizer",
duration_ms=15.3,
tokens_processed=1024,
tokens_per_second=66928.1,
energy_joules=3.8,
energy_per_token=0.0037,
avg_power_watts=248.4,
peak_memory_gb=48.6,
avg_gpu_util_percent=42.1
)
pretrain_metrics = PretrainMetrics(
model_name="Qwen/Qwen2.5-3B-Instruct",
gpu_name="NVIDIA A100 80GB",
attention_implementation="flash_attention_2",
batch_size=8,
sequence_length=2048,
num_steps=10,
forward=forward,
backward=backward,
optimizer=optimizer,
total_duration_ms=321.0,
total_tokens=10240,
total_tokens_per_second=31900.3,
total_energy_joules=80.7,
total_energy_per_token=0.00788
)
# Print pretrain metrics
MetricsReporter.print_pretrain_metrics(pretrain_metrics)
# Create sample inference metrics
prefill = StageMetrics(
stage_name="prefill",
duration_ms=45.2,
tokens_processed=512,
tokens_per_second=11327.4,
energy_joules=11.3,
energy_per_token=0.0221,
avg_power_watts=250.0,
peak_memory_gb=42.1,
avg_gpu_util_percent=89.2
)
decode = StageMetrics(
stage_name="decode",
duration_ms=223.5,
tokens_processed=100,
tokens_per_second=447.4,
energy_joules=55.9,
energy_per_token=0.559,
avg_power_watts=250.1,
peak_memory_gb=42.1,
avg_gpu_util_percent=62.3
)
inference_metrics = InferenceMetrics(
model_name="Qwen/Qwen2.5-3B-Instruct",
gpu_name="NVIDIA A100 80GB",
attention_implementation="flash_attention_2",
num_requests=10,
prompt_length=512,
generation_length=100,
prefill=prefill,
decode=decode,
e2e_latency_ms=268.7,
e2e_tokens_per_second=2277.9,
e2e_energy_joules=67.2,
e2e_energy_per_token=0.110,
ttft_ms=45.2,
itl_ms=2.235
)
# Print inference metrics
MetricsReporter.print_inference_metrics(inference_metrics)
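    # To persist these sample results, the reporters above could be used as
    # follows (paths are illustrative):
    # MetricsReporter.save_json(pretrain_metrics, Path("results/sample_pretrain.json"))
    # MetricsReporter.save_csv([pretrain_metrics], Path("results/sample_pretrain.csv"), benchmark_type="pretrain")
    # MetricsReporter.save_csv([inference_metrics], Path("results/sample_inference.csv"), benchmark_type="inference")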