Initial commit
utils/attention.py (new file, 295 lines)
"""
Attention Implementation Helpers for LLM Benchmarking

Provides functions for configuring different attention implementations
based on GPU type.
"""

from typing import Optional
import warnings


def get_default_attention(gpu_name: str) -> str:
    """
    Get default attention implementation for GPU type.

    Args:
        gpu_name: GPU device name (from monitoring)

    Returns:
        Attention implementation string
    """
    gpu_lower = gpu_name.lower()

    # H100/H200: FlashAttention-3 Hopper
    if 'h100' in gpu_lower or 'h200' in gpu_lower:
        return "flash_attention_3_hopper"

    # A100, MI300X, other: FlashAttention-2
    return "flash_attention_2"


def configure_model_attention(model, attn_implementation: str, verbose: bool = True):
    """
    Configure model to use specified attention implementation.

    This function patches the model if needed to use the specified attention.
    For standard implementations like flash_attention_2, the model should already
    be loaded with the correct implementation via AutoModelForCausalLM.from_pretrained().

    For FlashAttention-3 Hopper, this patches the model's attention modules.

    Args:
        model: The loaded model
        attn_implementation: Attention implementation to use
        verbose: Print configuration messages

    Returns:
        Configured model
    """
    if verbose:
        print(f"Configuring attention: {attn_implementation}")

    if attn_implementation == "flash_attention_3_hopper":
        # Patch model to use FlashAttention-3 Hopper
        try:
            import flash_attn_interface  # availability check only
        except ImportError:
            raise ImportError(
                "flash_attn_interface not found. This is required for FlashAttention-3.\n"
                "Install with the appropriate method for your system."
            )

        # Patch the model's attention function
        _patch_fa3_hopper(model, verbose=verbose)

    elif attn_implementation == "flash_attention_2":
        # Model should already be loaded with FA2
        if verbose:
            print(" Using FlashAttention-2 (configured during model loading)")

    elif attn_implementation == "sdpa":
        # PyTorch Scaled Dot Product Attention
        if verbose:
            print(" Using PyTorch SDPA")

    elif attn_implementation == "eager":
        # Standard PyTorch attention
        if verbose:
            print(" Using eager attention")

    else:
        warnings.warn(f"Unknown attention implementation: {attn_implementation}")

    return model


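# Example (sketch, not part of the benchmarking pipeline): one way to wire these
# helpers together when loading a model. The model id, dtype, and the choice to
# load with flash_attention_2 before patching to FA3 are illustrative assumptions,
# not something this module prescribes.
#
#   import torch
#   from transformers import AutoModelForCausalLM
#
#   gpu_name = "NVIDIA H100 80GB"                 # e.g. reported by monitoring
#   attn = get_default_attention(gpu_name)        # -> "flash_attention_3_hopper"
#   model = AutoModelForCausalLM.from_pretrained(
#       "Qwen/Qwen2.5-7B-Instruct",               # hypothetical model id
#       torch_dtype=torch.bfloat16,
#       attn_implementation="flash_attention_2",  # FA3 is applied by patching below
#   )
#   model = configure_model_attention(model, attn)

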
def _patch_fa3_hopper(model, verbose: bool = True):
    """
    Patch model to use FlashAttention-3 Hopper.

    This replaces the attention computation in the model's attention layers
    with calls to flash_attn_interface.flash_attn_func().

    Note: the patched path is a simplified attention computation intended for
    benchmarking raw forward passes. It does not apply rotary position
    embeddings, attention masks, or KV caching.

    Args:
        model: The model to patch
        verbose: Print patching messages
    """
    import flash_attn_interface

    # Counter for patched modules
    num_patched = 0

    # Iterate through all modules in the model
    for name, module in model.named_modules():
        # Look for attention modules (naming varies by model architecture).
        # Common names: "self_attn", "attn", "attention"
        if not any(attn_name in name.lower() for attn_name in ['self_attn', 'attention']):
            continue

        # Only patch modules that expose Q/K/V projections (as in Qwen-style
        # attention layers). This also skips the projection sub-modules, whose
        # names would otherwise match the filter above.
        if not all(hasattr(module, proj) for proj in ('q_proj', 'k_proj', 'v_proj')):
            continue

        # Bind the module as an argument so each patched forward closes over
        # its own attention module rather than the last one seen in the loop.
        def create_patched_forward(attn_module):
            def patched_forward(hidden_states, *args, **kwargs):
                # Extract batch, seq_len, hidden_dim
                batch_size, seq_len, hidden_dim = hidden_states.shape

                # Compute Q, K, V
                q = attn_module.q_proj(hidden_states)
                k = attn_module.k_proj(hidden_states)
                v = attn_module.v_proj(hidden_states)

                # Reshape for multi-head attention. K/V may use fewer heads
                # than Q when the model uses grouped-query attention.
                num_heads = attn_module.num_heads
                head_dim = hidden_dim // num_heads
                num_kv_heads = getattr(attn_module, 'num_key_value_heads', num_heads)

                q = q.view(batch_size, seq_len, num_heads, head_dim)
                k = k.view(batch_size, seq_len, num_kv_heads, head_dim)
                v = v.view(batch_size, seq_len, num_kv_heads, head_dim)

                # Call FlashAttention-3.
                # Note: flash_attn_func expects (batch, seqlen, nheads, headdim);
                # unlike FA2, the FA3 interface takes no dropout argument.
                attn_output = flash_attn_interface.flash_attn_func(
                    q, k, v,
                    softmax_scale=None,  # Will use default 1/sqrt(head_dim)
                    causal=True,  # For causal LM
                )
                # Some FA3 builds return (out, softmax_lse); keep only the output.
                if isinstance(attn_output, tuple):
                    attn_output = attn_output[0]

                # Reshape back
                attn_output = attn_output.view(batch_size, seq_len, hidden_dim)

                # Apply output projection if it exists
                if hasattr(attn_module, 'o_proj'):
                    attn_output = attn_module.o_proj(attn_output)

                # Return the (attn_output, attn_weights) pair that Hugging Face
                # attention modules produce, instead of re-running the original
                # forward just to measure its output length.
                return attn_output, None

            return patched_forward

        # Apply patch
        module.forward = create_patched_forward(module)
        num_patched += 1

    if verbose:
        if num_patched > 0:
            print(f" ✓ Patched {num_patched} attention modules to use FlashAttention-3 Hopper")
        else:
            warnings.warn(" ⚠ No attention modules found to patch for FlashAttention-3")


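# Standalone sketch of the tensor layout assumed by the patch above. Shapes,
# dtype, and device are illustrative; this requires an FA3 build of
# flash_attn_interface on a Hopper GPU, and depending on the FA3 version
# flash_attn_func may return (out, softmax_lse) rather than just out.
#
#   import torch
#   import flash_attn_interface
#
#   q = torch.randn(1, 2048, 32, 128, dtype=torch.bfloat16, device="cuda")
#   k = torch.randn(1, 2048, 8, 128, dtype=torch.bfloat16, device="cuda")  # GQA: fewer KV heads
#   v = torch.randn(1, 2048, 8, 128, dtype=torch.bfloat16, device="cuda")
#   out = flash_attn_interface.flash_attn_func(q, k, v, causal=True)
#   # expected output shape: (1, 2048, 32, 128)

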
def get_attention_info(attn_implementation: str) -> dict:
    """
    Get information about an attention implementation.

    Args:
        attn_implementation: Attention implementation string

    Returns:
        Dictionary with info about the implementation
    """
    info = {
        "flash_attention_2": {
            "name": "FlashAttention-2",
            "description": "Optimized attention for A100 and other GPUs",
            "gpu_support": ["A100", "MI300X", "V100", "RTX"],
            "memory_efficient": True,
            "requires_cuda": True,
        },
        "flash_attention_3_hopper": {
            "name": "FlashAttention-3 Hopper",
            "description": "Optimized attention for H100/H200 Hopper architecture",
            "gpu_support": ["H100", "H200"],
            "memory_efficient": True,
            "requires_cuda": True,
        },
        "sdpa": {
            "name": "PyTorch SDPA",
            "description": "PyTorch Scaled Dot Product Attention",
            "gpu_support": ["All"],
            "memory_efficient": True,
            "requires_cuda": False,
        },
        "eager": {
            "name": "Eager Attention",
            "description": "Standard PyTorch attention implementation",
            "gpu_support": ["All"],
            "memory_efficient": False,
            "requires_cuda": False,
        },
    }

    return info.get(attn_implementation, {
        "name": attn_implementation,
        "description": "Unknown attention implementation",
        "gpu_support": ["Unknown"],
        "memory_efficient": False,
        "requires_cuda": False,
    })


def validate_attention_for_gpu(attn_implementation: str, gpu_name: str) -> tuple[bool, Optional[str]]:
    """
    Validate if attention implementation is suitable for GPU.

    Args:
        attn_implementation: Attention implementation
        gpu_name: GPU device name

    Returns:
        Tuple of (is_valid, warning_message)
    """
    gpu_lower = gpu_name.lower()

    # FlashAttention-3 Hopper validation
    if attn_implementation == "flash_attention_3_hopper":
        if 'h100' not in gpu_lower and 'h200' not in gpu_lower:
            return False, (
                f"FlashAttention-3 Hopper is optimized for H100/H200. "
                f"Current GPU: {gpu_name}. Consider using flash_attention_2 instead."
            )

    # FlashAttention-2 on Hopper GPUs
    if attn_implementation == "flash_attention_2":
        if 'h100' in gpu_lower or 'h200' in gpu_lower:
            return True, (
                f"FlashAttention-2 will work on {gpu_name}, but FlashAttention-3 Hopper "
                f"may provide better performance."
            )

    return True, None


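# Example (sketch): guarding a benchmark run with the validation above. The
# gpu_name value is illustrative; in the real pipeline it would come from the
# monitoring utilities, and the fallback choice follows the warning text above.
#
#   gpu_name = "NVIDIA A100 80GB"
#   attn = "flash_attention_3_hopper"
#   ok, warning = validate_attention_for_gpu(attn, gpu_name)
#   if warning:
#       warnings.warn(warning)
#   if not ok:
#       attn = "flash_attention_2"  # broadly supported fallback
#   # attn == "flash_attention_2" here, since A100 is not a Hopper GPU

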
if __name__ == "__main__":
    # Test attention configuration.
    print("=" * 60)
    print("Attention Implementation Test")
    print("=" * 60)

    # Test getting default attention for different GPUs
    test_gpus = [
        "NVIDIA A100 80GB",
        "NVIDIA H100 80GB",
        "NVIDIA H200 141GB",
        "AMD Instinct MI300X",
    ]

    print("\nDefault attention implementations:")
    for gpu in test_gpus:
        attn = get_default_attention(gpu)
        print(f" {gpu:30s} → {attn}")

    # Test validation
    print("\nValidation tests:")
    test_cases = [
        ("flash_attention_3_hopper", "NVIDIA H100 80GB"),
        ("flash_attention_3_hopper", "NVIDIA A100 80GB"),
        ("flash_attention_2", "NVIDIA H100 80GB"),
        ("flash_attention_2", "NVIDIA A100 80GB"),
    ]

    for attn, gpu in test_cases:
        valid, warning = validate_attention_for_gpu(attn, gpu)
        status = "✓" if valid else "✗"
        print(f" {status} {attn:30s} on {gpu:25s}")
        if warning:
            print(f" ⚠ {warning}")

    # Test getting info
    print("\nAttention implementation info:")
    for attn in ["flash_attention_2", "flash_attention_3_hopper", "sdpa"]:
        info = get_attention_info(attn)
        print(f"\n {info['name']}:")
        print(f" Description: {info['description']}")
        print(f" GPU Support: {', '.join(info['gpu_support'])}")
        print(f" Memory Efficient: {info['memory_efficient']}")