Initial commit

Bole Ma committed 2026-02-05 23:18:26 +01:00 · commit 747c92ac6b
31 changed files with 4220 additions and 0 deletions

utils/attention.py (new file, 295 lines)

@@ -0,0 +1,295 @@
"""
Attention Implementation Helpers for LLM Benchmarking
Provides functions for configuring different attention implementations
based on GPU type.
"""
from typing import Optional
import warnings
def get_default_attention(gpu_name: str) -> str:
"""
Get default attention implementation for GPU type.
Args:
gpu_name: GPU device name (from monitoring)
Returns:
Attention implementation string
"""
gpu_lower = gpu_name.lower()
# H100/H200: FlashAttention-3 Hopper
if 'h100' in gpu_lower or 'h200' in gpu_lower:
return "flash_attention_3_hopper"
# A100, MI300X, other: FlashAttention-2
return "flash_attention_2"
def configure_model_attention(model, attn_implementation: str, verbose: bool = True):
"""
Configure model to use specified attention implementation.
    If needed, this function patches the model to use the specified attention
    implementation. Standard implementations such as flash_attention_2 should
    already be set when the model is loaded via
    AutoModelForCausalLM.from_pretrained(); FlashAttention-3 Hopper is applied
    here by patching the model's attention modules.
Args:
model: The loaded model
attn_implementation: Attention implementation to use
verbose: Print configuration messages
Returns:
Configured model
"""
if verbose:
print(f"Configuring attention: {attn_implementation}")
if attn_implementation == "flash_attention_3_hopper":
# Patch model to use FlashAttention-3 Hopper
try:
import flash_attn_interface
except ImportError:
raise ImportError(
"flash_attn_interface not found. This is required for FlashAttention-3.\n"
"Install with appropriate method for your system."
)
# Patch the model's attention function
_patch_fa3_hopper(model, verbose=verbose)
elif attn_implementation == "flash_attention_2":
# Model should already be loaded with FA2
if verbose:
print(" Using FlashAttention-2 (configured during model loading)")
elif attn_implementation == "sdpa":
# PyTorch Scaled Dot Product Attention
if verbose:
print(" Using PyTorch SDPA")
elif attn_implementation == "eager":
# Standard PyTorch attention
if verbose:
print(" Using eager attention")
else:
warnings.warn(f"Unknown attention implementation: {attn_implementation}")
return model
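
# A minimal usage sketch (assumptions: Hugging Face transformers is the model
# loader, and "Qwen/Qwen2.5-7B-Instruct" stands in for whatever checkpoint the
# benchmark actually targets). FA2/SDPA/eager are selected at load time via
# `attn_implementation`; FA3 Hopper is applied afterwards by this module.
#
#     import torch
#     from transformers import AutoModelForCausalLM
#
#     gpu_name = torch.cuda.get_device_name(0)
#     attn = get_default_attention(gpu_name)
#     load_attn = "flash_attention_2" if attn == "flash_attention_3_hopper" else attn
#     model = AutoModelForCausalLM.from_pretrained(
#         "Qwen/Qwen2.5-7B-Instruct",
#         torch_dtype=torch.bfloat16,
#         attn_implementation=load_attn,
#         device_map="cuda",
#     )
#     model = configure_model_attention(model, attn)
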
def _patch_fa3_hopper(model, verbose: bool = True):
"""
Patch model to use FlashAttention-3 Hopper.
    This replaces the attention computation in the model's attention layers
    with calls to flash_attn_interface.flash_attn_func(). The patch is a
    simplified, benchmarking-oriented substitute: it does not apply rotary
    position embeddings, attention masks, or KV caching, so patched outputs
    are not numerically equivalent to the original forward pass.
Args:
model: The model to patch
verbose: Print patching messages
"""
import flash_attn_interface
import torch
# Counter for patched modules
num_patched = 0
    # Iterate through all modules in the model
    for name, module in model.named_modules():
        # Look for attention blocks (naming varies by architecture; common
        # names are "self_attn", "attn", "attention")
        if not any(attn_name in name.lower() for attn_name in ['self_attn', 'attention']):
            continue
        # Only patch blocks that expose separate Q/K/V projections
        # (Llama/Qwen-style). This also skips the q_proj/k_proj/v_proj/o_proj
        # sub-modules, whose names would otherwise match the check above.
        if not (hasattr(module, 'q_proj') and hasattr(module, 'k_proj') and hasattr(module, 'v_proj')):
            continue
        # Save original forward
        original_forward = module.forward

        # Build the replacement through a factory so each closure binds its own
        # module and original forward rather than the loop variables
        def create_patched_forward(orig_forward, attn_module):
            def patched_forward(hidden_states, *args, **kwargs):
                # Extract batch, seq_len, hidden_dim
                batch_size, seq_len, hidden_dim = hidden_states.shape
                # Compute Q, K, V
                q = attn_module.q_proj(hidden_states)
                k = attn_module.k_proj(hidden_states)
                v = attn_module.v_proj(hidden_states)
                # Reshape for multi-head attention
                num_heads = attn_module.num_heads
                head_dim = hidden_dim // num_heads
                q = q.view(batch_size, seq_len, num_heads, head_dim)
                # K/V may use fewer heads than Q (grouped-query attention),
                # so infer their head count from the projection width
                k = k.view(batch_size, seq_len, -1, head_dim)
                v = v.view(batch_size, seq_len, -1, head_dim)
                # Call FlashAttention-3
                # Note: flash_attn_func expects (batch, seqlen, nheads, headdim)
                attn_output = flash_attn_interface.flash_attn_func(
                    q, k, v,
                    dropout_p=0.0,
                    softmax_scale=None,  # Will use default 1/sqrt(head_dim)
                    causal=True,  # For causal LM
                )
                # Reshape back to (batch, seqlen, hidden_dim)
                attn_output = attn_output.reshape(batch_size, seq_len, hidden_dim)
                # Apply output projection if it exists
                if hasattr(attn_module, 'o_proj'):
                    attn_output = attn_module.o_proj(attn_output)
                # Recent transformers attention blocks return (attn_output, attn_weights);
                # adjust the tuple arity here if the installed version also returns a
                # KV-cache entry. This avoids re-running the original forward just to
                # measure its output length.
                return attn_output, None
            return patched_forward

        # Apply patch
        module.forward = create_patched_forward(original_forward, module)
        num_patched += 1
if verbose:
if num_patched > 0:
print(f" ✓ Patched {num_patched} attention modules to use FlashAttention-3 Hopper")
else:
warnings.warn(" ⚠ No attention modules found to patch for FlashAttention-3")
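
# Sanity-check idea (not wired into the benchmark): on random tensors, the FA3
# call used above should agree with PyTorch SDPA up to numerical tolerance.
# Shapes follow the (batch, seqlen, nheads, headdim) convention flash_attn uses,
# while SDPA expects (batch, nheads, seqlen, headdim). Note that, depending on
# the installed flash_attn_interface version, flash_attn_func may return either
# the output alone or an (output, softmax_lse) pair.
#
#     import torch
#     import torch.nn.functional as F
#     import flash_attn_interface
#
#     q = torch.randn(1, 128, 8, 64, dtype=torch.bfloat16, device="cuda")
#     k, v = torch.randn_like(q), torch.randn_like(q)
#     ref = F.scaled_dot_product_attention(
#         q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
#     ).transpose(1, 2)
#     out = flash_attn_interface.flash_attn_func(q, k, v, causal=True)
#     torch.testing.assert_close(out, ref, atol=2e-2, rtol=2e-2)
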
def get_attention_info(attn_implementation: str) -> dict:
"""
Get information about an attention implementation.
Args:
attn_implementation: Attention implementation string
Returns:
Dictionary with info about the implementation
"""
info = {
"flash_attention_2": {
"name": "FlashAttention-2",
"description": "Optimized attention for A100 and other GPUs",
"gpu_support": ["A100", "MI300X", "V100", "RTX"],
"memory_efficient": True,
"requires_cuda": True,
},
"flash_attention_3_hopper": {
"name": "FlashAttention-3 Hopper",
"description": "Optimized attention for H100/H200 Hopper architecture",
"gpu_support": ["H100", "H200"],
"memory_efficient": True,
"requires_cuda": True,
},
"sdpa": {
"name": "PyTorch SDPA",
"description": "PyTorch Scaled Dot Product Attention",
"gpu_support": ["All"],
"memory_efficient": True,
"requires_cuda": False,
},
"eager": {
"name": "Eager Attention",
"description": "Standard PyTorch attention implementation",
"gpu_support": ["All"],
"memory_efficient": False,
"requires_cuda": False,
},
}
return info.get(attn_implementation, {
"name": attn_implementation,
"description": "Unknown attention implementation",
"gpu_support": ["Unknown"],
"memory_efficient": False,
"requires_cuda": False,
})
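
# Example lookups (illustrative), including the fallback for unknown names:
#
#     >>> get_attention_info("sdpa")["name"]
#     'PyTorch SDPA'
#     >>> get_attention_info("made_up")["description"]
#     'Unknown attention implementation'
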
def validate_attention_for_gpu(attn_implementation: str, gpu_name: str) -> tuple[bool, Optional[str]]:
"""
Validate if attention implementation is suitable for GPU.
Args:
attn_implementation: Attention implementation
gpu_name: GPU device name
Returns:
Tuple of (is_valid, warning_message)
"""
gpu_lower = gpu_name.lower()
# FlashAttention-3 Hopper validation
if attn_implementation == "flash_attention_3_hopper":
if 'h100' not in gpu_lower and 'h200' not in gpu_lower:
return False, (
f"FlashAttention-3 Hopper is optimized for H100/H200. "
f"Current GPU: {gpu_name}. Consider using flash_attention_2 instead."
)
# FlashAttention-2 on Hopper GPUs
if attn_implementation == "flash_attention_2":
if 'h100' in gpu_lower or 'h200' in gpu_lower:
return True, (
f"FlashAttention-2 will work on {gpu_name}, but FlashAttention-3 Hopper "
f"may provide better performance."
)
return True, None
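
# How a benchmark runner might act on the validation result (sketch only; the
# surrounding runner code is assumed and not part of this module):
#
#     ok, msg = validate_attention_for_gpu(attn, gpu_name)
#     if msg:
#         warnings.warn(msg)
#     if not ok:
#         attn = get_default_attention(gpu_name)  # fall back to the GPU default
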
if __name__ == "__main__":
    # Self-test: exercise the attention configuration helpers.
print("=" * 60)
print("Attention Implementation Test")
print("=" * 60)
# Test getting default attention for different GPUs
test_gpus = [
"NVIDIA A100 80GB",
"NVIDIA H100 80GB",
"NVIDIA H200 141GB",
"AMD Instinct MI300X",
]
print("\nDefault attention implementations:")
for gpu in test_gpus:
attn = get_default_attention(gpu)
print(f" {gpu:30s}{attn}")
# Test validation
print("\nValidation tests:")
test_cases = [
("flash_attention_3_hopper", "NVIDIA H100 80GB"),
("flash_attention_3_hopper", "NVIDIA A100 80GB"),
("flash_attention_2", "NVIDIA H100 80GB"),
("flash_attention_2", "NVIDIA A100 80GB"),
]
for attn, gpu in test_cases:
valid, warning = validate_attention_for_gpu(attn, gpu)
        status = "✓" if valid else "✗"
print(f" {status} {attn:30s} on {gpu:25s}")
if warning:
print(f"{warning}")
# Test getting info
print("\nAttention implementation info:")
for attn in ["flash_attention_2", "flash_attention_3_hopper", "sdpa"]:
info = get_attention_info(attn)
print(f"\n {info['name']}:")
print(f" Description: {info['description']}")
print(f" GPU Support: {', '.join(info['gpu_support'])}")
print(f" Memory Efficient: {info['memory_efficient']}")