"""
|
|
Attention Implementation Helpers for LLM Benchmarking
|
|
|
|
Provides functions for configuring different attention implementations
|
|
based on GPU type.
|
|
"""
|
|
|
|
from typing import Optional
|
|
import warnings
|
|
|
|
|
|


def get_default_attention(gpu_name: str) -> str:
    """
    Get the default attention implementation for a GPU type.

    Args:
        gpu_name: GPU device name (from monitoring)

    Returns:
        Attention implementation string
    """
    gpu_lower = gpu_name.lower()

    # H100/H200: FlashAttention-3 Hopper
    if 'h100' in gpu_lower or 'h200' in gpu_lower:
        return "flash_attention_3_hopper"

    # A100, MI300X, and others: FlashAttention-2
    return "flash_attention_2"
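

# Illustrative usage sketch (not part of the original API): one way the gpu_name
# argument might be obtained at runtime. Assumes PyTorch is installed;
# torch.cuda.is_available() and torch.cuda.get_device_name() are standard calls.
def _example_default_attention_for_local_gpu() -> str:
    import torch

    gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "cpu"
    return get_default_attention(gpu_name)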


def configure_model_attention(model, attn_implementation: str, verbose: bool = True):
    """
    Configure the model to use the specified attention implementation.

    This function patches the model only when needed. For standard implementations
    such as flash_attention_2, the model should already have been loaded with the
    correct implementation via AutoModelForCausalLM.from_pretrained().

    For FlashAttention-3 Hopper, this patches the model's attention modules.

    Args:
        model: The loaded model
        attn_implementation: Attention implementation to use
        verbose: Print configuration messages

    Returns:
        Configured model
    """
    if verbose:
        print(f"Configuring attention: {attn_implementation}")

    if attn_implementation == "flash_attention_3_hopper":
        # Availability check only; _patch_fa3_hopper() imports and uses it.
        try:
            import flash_attn_interface  # noqa: F401
        except ImportError:
            raise ImportError(
                "flash_attn_interface not found. It is required for FlashAttention-3.\n"
                "Install it using the method appropriate for your system."
            )

        # Patch the model's attention modules
        _patch_fa3_hopper(model, verbose=verbose)

    elif attn_implementation == "flash_attention_2":
        # The model should already have been loaded with FA2
        if verbose:
            print("  Using FlashAttention-2 (configured during model loading)")

    elif attn_implementation == "sdpa":
        # PyTorch scaled dot product attention
        if verbose:
            print("  Using PyTorch SDPA")

    elif attn_implementation == "eager":
        # Standard PyTorch attention
        if verbose:
            print("  Using eager attention")

    else:
        warnings.warn(f"Unknown attention implementation: {attn_implementation}")

    return model
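

# Illustrative usage sketch (assumes transformers and torch are installed; the model
# name is supplied by the caller). A model is first loaded with a standard
# implementation via from_pretrained(); configure_model_attention() then applies the
# FA3 patch on top when flash_attention_3_hopper is requested.
def _example_load_and_configure(model_name: str, gpu_name: str):
    import torch
    from transformers import AutoModelForCausalLM

    attn = get_default_attention(gpu_name)
    # from_pretrained() only understands the standard implementations, so load with
    # FA2 and patch to FA3 afterwards if needed.
    load_attn = "flash_attention_2" if attn == "flash_attention_3_hopper" else attn
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        attn_implementation=load_attn,
    )
    return configure_model_attention(model, attn)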


def _patch_fa3_hopper(model, verbose: bool = True):
    """
    Patch a model to use FlashAttention-3 Hopper.

    This replaces the attention computation in the model's attention layers
    with calls to flash_attn_interface.flash_attn_func().

    Note: this simplified patch recomputes Q/K/V directly from the incoming
    hidden states and does not apply rotary position embeddings, attention
    masks, or KV caching.

    Args:
        model: The model to patch
        verbose: Print patching messages
    """
    import flash_attn_interface

    # Counter for patched modules
    num_patched = 0

    # Iterate through all modules in the model and look for attention modules.
    # Common names vary by architecture: "self_attn", "attn", "attention".
    for name, module in model.named_modules():
        if not any(attn_name in name.lower() for attn_name in ['self_attn', 'attention']):
            continue

        # Only patch modules that expose Q/K/V projections (e.g. Llama/Qwen-style
        # attention blocks); skip their sub-layers such as the projections themselves.
        if not (hasattr(module, 'q_proj') and hasattr(module, 'k_proj') and hasattr(module, 'v_proj')):
            continue

        # Bind the module explicitly so every patched layer captures its own module
        # (avoids the usual late-binding bug when patching in a loop).
        def create_patched_forward(attn_module):
            def patched_forward(hidden_states, *args, **kwargs):
                batch_size, seq_len, hidden_dim = hidden_states.shape

                # Compute Q, K, V
                q = attn_module.q_proj(hidden_states)
                k = attn_module.k_proj(hidden_states)
                v = attn_module.v_proj(hidden_states)

                # Reshape for multi-head attention; derive the KV head count from the
                # projection width so grouped-query attention layouts also work.
                num_heads = getattr(attn_module, 'num_heads', None) or attn_module.config.num_attention_heads
                head_dim = hidden_dim // num_heads
                num_kv_heads = k.shape[-1] // head_dim

                q = q.view(batch_size, seq_len, num_heads, head_dim)
                k = k.view(batch_size, seq_len, num_kv_heads, head_dim)
                v = v.view(batch_size, seq_len, num_kv_heads, head_dim)

                # Call FlashAttention-3.
                # flash_attn_func expects (batch, seqlen, nheads, headdim) and is
                # assumed here to return the attention output tensor.
                attn_output = flash_attn_interface.flash_attn_func(
                    q, k, v,
                    softmax_scale=None,  # defaults to 1/sqrt(head_dim)
                    causal=True,         # causal LM
                )

                # Reshape back
                attn_output = attn_output.view(batch_size, seq_len, hidden_dim)

                # Apply the output projection if it exists
                if hasattr(attn_module, 'o_proj'):
                    attn_output = attn_module.o_proj(attn_output)

                # Hugging Face attention modules typically return a tuple whose first
                # element is the attention output; attention weights are not computed.
                return attn_output, None

            return patched_forward

        # Apply patch
        module.forward = create_patched_forward(module)
        num_patched += 1

    if verbose:
        if num_patched > 0:
            print(f"  ✓ Patched {num_patched} attention modules to use FlashAttention-3 Hopper")
        else:
            warnings.warn("  ⚠ No attention modules found to patch for FlashAttention-3")


def get_attention_info(attn_implementation: str) -> dict:
    """
    Get information about an attention implementation.

    Args:
        attn_implementation: Attention implementation string

    Returns:
        Dictionary with info about the implementation
    """
    info = {
        "flash_attention_2": {
            "name": "FlashAttention-2",
            "description": "Optimized attention for A100 and other GPUs",
            "gpu_support": ["A100", "MI300X", "RTX 30/40 series"],
            "memory_efficient": True,
            "requires_cuda": True,
        },
        "flash_attention_3_hopper": {
            "name": "FlashAttention-3 Hopper",
            "description": "Optimized attention for the H100/H200 Hopper architecture",
            "gpu_support": ["H100", "H200"],
            "memory_efficient": True,
            "requires_cuda": True,
        },
        "sdpa": {
            "name": "PyTorch SDPA",
            "description": "PyTorch Scaled Dot Product Attention",
            "gpu_support": ["All"],
            "memory_efficient": True,
            "requires_cuda": False,
        },
        "eager": {
            "name": "Eager Attention",
            "description": "Standard PyTorch attention implementation",
            "gpu_support": ["All"],
            "memory_efficient": False,
            "requires_cuda": False,
        },
    }

    return info.get(attn_implementation, {
        "name": attn_implementation,
        "description": "Unknown attention implementation",
        "gpu_support": ["Unknown"],
        "memory_efficient": False,
        "requires_cuda": False,
    })


def validate_attention_for_gpu(attn_implementation: str, gpu_name: str) -> tuple[bool, Optional[str]]:
    """
    Validate if an attention implementation is suitable for a GPU.

    Args:
        attn_implementation: Attention implementation
        gpu_name: GPU device name

    Returns:
        Tuple of (is_valid, warning_message)
    """
    gpu_lower = gpu_name.lower()

    # FlashAttention-3 Hopper validation
    if attn_implementation == "flash_attention_3_hopper":
        if 'h100' not in gpu_lower and 'h200' not in gpu_lower:
            return False, (
                "FlashAttention-3 Hopper is optimized for H100/H200. "
                f"Current GPU: {gpu_name}. Consider using flash_attention_2 instead."
            )

    # FlashAttention-2 on Hopper GPUs
    if attn_implementation == "flash_attention_2":
        if 'h100' in gpu_lower or 'h200' in gpu_lower:
            return True, (
                f"FlashAttention-2 will work on {gpu_name}, but FlashAttention-3 Hopper "
                "may provide better performance."
            )

    return True, None
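

# Illustrative sketch (not part of the original API): resolving an attention choice
# with validation, surfacing any warning via warnings.warn() and falling back to the
# GPU default when the requested implementation is unsuitable.
def _example_resolve_attention(requested: Optional[str], gpu_name: str) -> str:
    attn = requested or get_default_attention(gpu_name)
    valid, message = validate_attention_for_gpu(attn, gpu_name)
    if message:
        warnings.warn(message)
    if not valid:
        attn = get_default_attention(gpu_name)
    return attn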
if __name__ == "__main__":
|
|
"""Test attention configuration."""
|
|
print("=" * 60)
|
|
print("Attention Implementation Test")
|
|
print("=" * 60)
|
|
|
|
# Test getting default attention for different GPUs
|
|
test_gpus = [
|
|
"NVIDIA A100 80GB",
|
|
"NVIDIA H100 80GB",
|
|
"NVIDIA H200 141GB",
|
|
"AMD Instinct MI300X",
|
|
]
|
|
|
|
print("\nDefault attention implementations:")
|
|
for gpu in test_gpus:
|
|
attn = get_default_attention(gpu)
|
|
print(f" {gpu:30s} → {attn}")
|
|
|
|
# Test validation
|
|
print("\nValidation tests:")
|
|
test_cases = [
|
|
("flash_attention_3_hopper", "NVIDIA H100 80GB"),
|
|
("flash_attention_3_hopper", "NVIDIA A100 80GB"),
|
|
("flash_attention_2", "NVIDIA H100 80GB"),
|
|
("flash_attention_2", "NVIDIA A100 80GB"),
|
|
]
|
|
|
|
for attn, gpu in test_cases:
|
|
valid, warning = validate_attention_for_gpu(attn, gpu)
|
|
status = "✓" if valid else "✗"
|
|
print(f" {status} {attn:30s} on {gpu:25s}")
|
|
if warning:
|
|
print(f" ⚠ {warning}")
|
|
|
|
# Test getting info
|
|
print("\nAttention implementation info:")
|
|
for attn in ["flash_attention_2", "flash_attention_3_hopper", "sdpa"]:
|
|
info = get_attention_info(attn)
|
|
print(f"\n {info['name']}:")
|
|
print(f" Description: {info['description']}")
|
|
print(f" GPU Support: {', '.join(info['gpu_support'])}")
|
|
print(f" Memory Efficient: {info['memory_efficient']}")
|