# cocogoat/utils/attention.py
"""
Attention Implementation Helpers for LLM Benchmarking
Provides functions for configuring different attention implementations
based on GPU type.
"""
from typing import Optional
import warnings
def get_default_attention(gpu_name: str) -> str:
"""
Get default attention implementation for GPU type.
Args:
gpu_name: GPU device name (from monitoring)
Returns:
Attention implementation string
"""
gpu_lower = gpu_name.lower()
# H100/H200: FlashAttention-3 Hopper
if 'h100' in gpu_lower or 'h200' in gpu_lower:
return "flash_attention_3_hopper"
# A100, MI300X, other: FlashAttention-2
return "flash_attention_2"

def configure_model_attention(model, attn_implementation: str, verbose: bool = True):
    """
    Configure model to use the specified attention implementation.

    This function patches the model if needed to use the specified attention.
    For standard implementations like flash_attention_2, the model should already
    be loaded with the correct implementation via AutoModelForCausalLM.from_pretrained().
    For FlashAttention-3 Hopper, this patches the model's attention modules.

    Args:
        model: The loaded model
        attn_implementation: Attention implementation to use
        verbose: Print configuration messages

    Returns:
        Configured model
    """
    if verbose:
        print(f"Configuring attention: {attn_implementation}")

    if attn_implementation == "flash_attention_3_hopper":
        # Patch model to use FlashAttention-3 Hopper
        try:
            import flash_attn_interface  # presence check only
        except ImportError:
            raise ImportError(
                "flash_attn_interface not found. This is required for FlashAttention-3.\n"
                "Install with the appropriate method for your system."
            )
        # Patch the model's attention function
        _patch_fa3_hopper(model, verbose=verbose)
    elif attn_implementation == "flash_attention_2":
        # Model should already be loaded with FA2
        if verbose:
            print("  Using FlashAttention-2 (configured during model loading)")
    elif attn_implementation == "sdpa":
        # PyTorch Scaled Dot Product Attention
        if verbose:
            print("  Using PyTorch SDPA")
    elif attn_implementation == "eager":
        # Standard PyTorch attention
        if verbose:
            print("  Using eager attention")
    else:
        warnings.warn(f"Unknown attention implementation: {attn_implementation}")

    return model

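
# Illustrative usage sketch (not called anywhere in this module): the model id
# passed in is a placeholder and the snippet assumes Hugging Face `transformers`
# is installed. FA3 models are first loaded with a standard implementation name
# that from_pretrained() understands, then patched by configure_model_attention().
def _example_load_and_configure(model_id: str, gpu_name: str):
    """Load a causal LM and apply the default attention choice for gpu_name."""
    from transformers import AutoModelForCausalLM

    attn = get_default_attention(gpu_name)
    # "flash_attention_3_hopper" is not a transformers implementation name, so
    # load with flash_attention_2 first and patch to FA3 afterwards.
    load_attn = "flash_attention_2" if attn == "flash_attention_3_hopper" else attn
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        attn_implementation=load_attn,
        torch_dtype="auto",
    )
    return configure_model_attention(model, attn)
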

def _patch_fa3_hopper(model, verbose: bool = True):
    """
    Patch model to use FlashAttention-3 Hopper.

    This replaces the attention computation in the model's attention layers
    with calls to flash_attn_interface.flash_attn_func().

    Note: this is a simplified patch. The replacement forward ignores attention
    masks, rotary position embeddings, and KV caching, so it is suitable for
    benchmarking the attention kernel path rather than for producing correct
    generations.

    Args:
        model: The model to patch
        verbose: Print patching messages
    """
    import flash_attn_interface

    # Counter for patched modules
    num_patched = 0

    # Iterate through all modules in the model and look for attention modules.
    # Naming varies by architecture (common names: "self_attn", "attn",
    # "attention"); only modules exposing q/k/v projections are patched.
    for name, module in model.named_modules():
        if not any(attn_name in name.lower() for attn_name in ['self_attn', 'attention']):
            continue
        if not (hasattr(module, 'q_proj') and hasattr(module, 'k_proj')
                and hasattr(module, 'v_proj')):
            continue

        # Save original forward
        original_forward = module.forward

        # Create patched forward function. Both the original forward and the
        # module are bound as factory arguments so that each patched layer
        # keeps a reference to its own module; a plain closure over the loop
        # variable would make every layer point at the last module visited.
        def create_patched_forward(orig_forward, attn_module):
            def patched_forward(hidden_states, *args, **kwargs):
                # Fall back to the original forward if the projections are missing.
                if not (hasattr(attn_module, 'q_proj') and hasattr(attn_module, 'k_proj')
                        and hasattr(attn_module, 'v_proj')):
                    return orig_forward(hidden_states, *args, **kwargs)

                # Extract batch, seq_len, hidden_dim
                batch_size, seq_len, hidden_dim = hidden_states.shape

                # Compute Q, K, V
                q = attn_module.q_proj(hidden_states)
                k = attn_module.k_proj(hidden_states)
                v = attn_module.v_proj(hidden_states)

                # Reshape for multi-head attention. The head count lives in
                # different places depending on the transformers version.
                num_heads = getattr(attn_module, 'num_heads', None) or \
                    getattr(getattr(attn_module, 'config', None), 'num_attention_heads', None)
                head_dim = hidden_dim // num_heads
                # K/V may use fewer heads than Q (grouped-query attention).
                num_kv_heads = k.shape[-1] // head_dim

                q = q.view(batch_size, seq_len, num_heads, head_dim)
                k = k.view(batch_size, seq_len, num_kv_heads, head_dim)
                v = v.view(batch_size, seq_len, num_kv_heads, head_dim)

                # Call FlashAttention-3.
                # Note: flash_attn_func expects (batch, seqlen, nheads, headdim);
                # the FA3 interface does not expose a dropout argument.
                attn_output = flash_attn_interface.flash_attn_func(
                    q, k, v,
                    softmax_scale=None,  # Will use default 1/sqrt(head_dim)
                    causal=True,         # For causal LM
                )
                # Some FA3 builds return (output, softmax_lse); keep the output.
                if isinstance(attn_output, tuple):
                    attn_output = attn_output[0]

                # Reshape back
                attn_output = attn_output.reshape(batch_size, seq_len, -1)

                # Apply output projection if it exists
                if hasattr(attn_module, 'o_proj'):
                    attn_output = attn_module.o_proj(attn_output)

                # Most HF attention modules return (hidden_states, attn_weights);
                # older versions also append a past_key_value slot.
                return (attn_output, None)

            return patched_forward

        # Apply patch
        module.forward = create_patched_forward(original_forward, module)
        num_patched += 1

    if verbose:
        if num_patched > 0:
            print(f"  ✓ Patched {num_patched} attention modules to use FlashAttention-3 Hopper")
        else:
            warnings.warn("⚠ No attention modules found to patch for FlashAttention-3")

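
# Minimal sketch of the kernel call the patch above relies on (illustrative
# only, not used by the benchmark). It assumes an H100/H200 with a working
# FlashAttention-3 build; the exact keyword arguments accepted by
# flash_attn_interface.flash_attn_func may differ between FA3 releases.
def _example_fa3_tensor_layout():
    """Call flash_attn_func directly on dummy (batch, seqlen, nheads, headdim) tensors."""
    import torch
    import flash_attn_interface

    batch, seqlen, nheads, headdim = 1, 128, 8, 64
    q = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.bfloat16, device="cuda")
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    out = flash_attn_interface.flash_attn_func(q, k, v, causal=True)
    # Some FA3 builds return (output, softmax_lse); keep only the output tensor.
    if isinstance(out, tuple):
        out = out[0]
    return out.shape  # (batch, seqlen, nheads, headdim)
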

def get_attention_info(attn_implementation: str) -> dict:
    """
    Get information about an attention implementation.

    Args:
        attn_implementation: Attention implementation string

    Returns:
        Dictionary with info about the implementation
    """
    info = {
        "flash_attention_2": {
            "name": "FlashAttention-2",
            "description": "Optimized attention for A100 and other GPUs",
            "gpu_support": ["A100", "MI300X", "V100", "RTX"],
            "memory_efficient": True,
            "requires_cuda": True,
        },
        "flash_attention_3_hopper": {
            "name": "FlashAttention-3 Hopper",
            "description": "Optimized attention for H100/H200 Hopper architecture",
            "gpu_support": ["H100", "H200"],
            "memory_efficient": True,
            "requires_cuda": True,
        },
        "sdpa": {
            "name": "PyTorch SDPA",
            "description": "PyTorch Scaled Dot Product Attention",
            "gpu_support": ["All"],
            "memory_efficient": True,
            "requires_cuda": False,
        },
        "eager": {
            "name": "Eager Attention",
            "description": "Standard PyTorch attention implementation",
            "gpu_support": ["All"],
            "memory_efficient": False,
            "requires_cuda": False,
        },
    }

    return info.get(attn_implementation, {
        "name": attn_implementation,
        "description": "Unknown attention implementation",
        "gpu_support": ["Unknown"],
        "memory_efficient": False,
        "requires_cuda": False,
    })


def validate_attention_for_gpu(attn_implementation: str, gpu_name: str) -> tuple[bool, Optional[str]]:
    """
    Validate if attention implementation is suitable for GPU.

    Args:
        attn_implementation: Attention implementation
        gpu_name: GPU device name

    Returns:
        Tuple of (is_valid, warning_message)
    """
    gpu_lower = gpu_name.lower()

    # FlashAttention-3 Hopper validation
    if attn_implementation == "flash_attention_3_hopper":
        if 'h100' not in gpu_lower and 'h200' not in gpu_lower:
            return False, (
                f"FlashAttention-3 Hopper is optimized for H100/H200. "
                f"Current GPU: {gpu_name}. Consider using flash_attention_2 instead."
            )

    # FlashAttention-2 on Hopper GPUs
    if attn_implementation == "flash_attention_2":
        if 'h100' in gpu_lower or 'h200' in gpu_lower:
            return True, (
                f"FlashAttention-2 will work on {gpu_name}, but FlashAttention-3 Hopper "
                f"may provide better performance."
            )

    return True, None

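
# Hedged helper sketch (illustrative only): shows how the functions above are
# typically combined with torch.cuda.get_device_name() to pick and sanity-check
# an implementation for the GPU the benchmark is actually running on. Assumes a
# CUDA-capable PyTorch build; the flash_attention_2 fallback mirrors the
# default in get_default_attention() and is this example's own choice.
def _example_pick_attention_for_current_device():
    """Return the default attention choice for GPU 0, warning if it looks suboptimal."""
    import torch

    gpu_name = torch.cuda.get_device_name(0)
    attn = get_default_attention(gpu_name)
    is_valid, warning = validate_attention_for_gpu(attn, gpu_name)
    if warning:
        warnings.warn(warning)
    return attn if is_valid else "flash_attention_2"
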

if __name__ == "__main__":
    """Test attention configuration."""
    print("=" * 60)
    print("Attention Implementation Test")
    print("=" * 60)

    # Test getting default attention for different GPUs
    test_gpus = [
        "NVIDIA A100 80GB",
        "NVIDIA H100 80GB",
        "NVIDIA H200 141GB",
        "AMD Instinct MI300X",
    ]

    print("\nDefault attention implementations:")
    for gpu in test_gpus:
        attn = get_default_attention(gpu)
        print(f"  {gpu:30s} -> {attn}")

    # Test validation
    print("\nValidation tests:")
    test_cases = [
        ("flash_attention_3_hopper", "NVIDIA H100 80GB"),
        ("flash_attention_3_hopper", "NVIDIA A100 80GB"),
        ("flash_attention_2", "NVIDIA H100 80GB"),
        ("flash_attention_2", "NVIDIA A100 80GB"),
    ]
    for attn, gpu in test_cases:
        valid, warning = validate_attention_for_gpu(attn, gpu)
        status = "✓" if valid else "✗"
        print(f"  {status} {attn:30s} on {gpu:25s}")
        if warning:
            print(f"      {warning}")

    # Test getting info
    print("\nAttention implementation info:")
    for attn in ["flash_attention_2", "flash_attention_3_hopper", "sdpa"]:
        info = get_attention_info(attn)
        print(f"\n  {info['name']}:")
        print(f"    Description: {info['description']}")
        print(f"    GPU Support: {', '.join(info['gpu_support'])}")
        print(f"    Memory Efficient: {info['memory_efficient']}")