""" Attention Implementation Helpers for LLM Benchmarking Provides functions for configuring different attention implementations based on GPU type. """ from typing import Optional import warnings def get_default_attention(gpu_name: str) -> str: """ Get default attention implementation for GPU type. Args: gpu_name: GPU device name (from monitoring) Returns: Attention implementation string """ gpu_lower = gpu_name.lower() # H100/H200: FlashAttention-3 Hopper if 'h100' in gpu_lower or 'h200' in gpu_lower: return "flash_attention_3_hopper" # A100, MI300X, other: FlashAttention-2 return "flash_attention_2" def configure_model_attention(model, attn_implementation: str, verbose: bool = True): """ Configure model to use specified attention implementation. This function patches the model if needed to use the specified attention. For standard implementations like flash_attention_2, the model should already be loaded with the correct implementation via AutoModelForCausalLM.from_pretrained(). For FlashAttention-3 Hopper, this patches the model's attention modules. Args: model: The loaded model attn_implementation: Attention implementation to use verbose: Print configuration messages Returns: Configured model """ if verbose: print(f"Configuring attention: {attn_implementation}") if attn_implementation == "flash_attention_3_hopper": # Patch model to use FlashAttention-3 Hopper try: import flash_attn_interface except ImportError: raise ImportError( "flash_attn_interface not found. This is required for FlashAttention-3.\n" "Install with appropriate method for your system." ) # Patch the model's attention function _patch_fa3_hopper(model, verbose=verbose) elif attn_implementation == "flash_attention_2": # Model should already be loaded with FA2 if verbose: print(" Using FlashAttention-2 (configured during model loading)") elif attn_implementation == "sdpa": # PyTorch Scaled Dot Product Attention if verbose: print(" Using PyTorch SDPA") elif attn_implementation == "eager": # Standard PyTorch attention if verbose: print(" Using eager attention") else: warnings.warn(f"Unknown attention implementation: {attn_implementation}") return model def _patch_fa3_hopper(model, verbose: bool = True): """ Patch model to use FlashAttention-3 Hopper. This replaces the attention computation in the model's attention layers with calls to flash_attn_interface.flash_attn_func(). 


def _patch_fa3_hopper(model, verbose: bool = True):
    """
    Patch model to use FlashAttention-3 Hopper.

    This replaces the attention computation in the model's attention layers
    with calls to flash_attn_interface.flash_attn_func().

    Args:
        model: The model to patch
        verbose: Print patching messages
    """
    import flash_attn_interface

    # Counter for patched modules
    num_patched = 0

    # Iterate through all modules in the model and look for attention modules.
    # Module names vary by architecture; common names are "self_attn", "attn",
    # and "attention".
    for name, module in model.named_modules():
        if not any(attn_name in name.lower() for attn_name in ['self_attn', 'attention']):
            continue

        # Only patch modules that expose the q/k/v projections we need
        # (e.g. Qwen/Llama-style self-attention blocks). Checking for a
        # `forward` attribute alone is not enough: every nn.Module has one.
        if not (hasattr(module, 'q_proj') and hasattr(module, 'k_proj') and hasattr(module, 'v_proj')):
            continue

        # Keep a reference to the original forward (useful for debugging or
        # undoing the patch).
        module._original_forward = module.forward

        # Bind the module as a factory argument so each patched module keeps
        # its own reference (a plain closure over the loop variable would
        # always see the last module in the loop).
        def create_patched_forward(attn_module):
            def patched_forward(hidden_states, *args, **kwargs):
                # Simplified patch intended for benchmarking: it does not apply
                # rotary position embeddings, attention masks, or KV caching.
                batch_size, seq_len, hidden_dim = hidden_states.shape

                # Compute Q, K, V
                q = attn_module.q_proj(hidden_states)
                k = attn_module.k_proj(hidden_states)
                v = attn_module.v_proj(hidden_states)

                # Reshape for multi-head attention. K/V may use fewer heads
                # than Q (grouped-query attention), which flash_attn_func
                # supports.
                num_heads = attn_module.num_heads
                head_dim = hidden_dim // num_heads
                num_kv_heads = getattr(attn_module, 'num_key_value_heads', num_heads)

                q = q.view(batch_size, seq_len, num_heads, head_dim)
                k = k.view(batch_size, seq_len, num_kv_heads, head_dim)
                v = v.view(batch_size, seq_len, num_kv_heads, head_dim)

                # Call FlashAttention-3.
                # Note: flash_attn_func expects (batch, seqlen, nheads, headdim).
                # dropout_p is omitted here; inference benchmarking does not use dropout.
                out = flash_attn_interface.flash_attn_func(
                    q, k, v,
                    softmax_scale=None,  # Will use default 1/sqrt(head_dim)
                    causal=True,         # For causal LM
                )
                # Some builds return (out, softmax_lse); take the output only.
                attn_output = out[0] if isinstance(out, tuple) else out

                # Reshape back
                attn_output = attn_output.view(batch_size, seq_len, hidden_dim)

                # Apply output projection if it exists
                if hasattr(attn_module, 'o_proj'):
                    attn_output = attn_module.o_proj(attn_output)

                # Most HF attention modules return a tuple of
                # (attn_output, attn_weights, ...); FlashAttention does not
                # materialize attention weights, so return None for them.
                return (attn_output, None)

            return patched_forward

        # Apply patch
        module.forward = create_patched_forward(module)
        num_patched += 1

    if verbose:
        if num_patched > 0:
            print(f"  ✓ Patched {num_patched} attention modules to use FlashAttention-3 Hopper")
        else:
            warnings.warn("  ⚠ No attention modules found to patch for FlashAttention-3")
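

# Optional sketch (not used by the functions in this module): GPU-name
# matching can be brittle across drivers and cloud providers, so an
# alternative is to check the CUDA compute capability directly. Hopper
# (H100/H200) reports major version 9. Assumes PyTorch is importable.
def _local_gpu_is_hopper() -> bool:
    """Best-effort check for a Hopper-class GPU on the local machine."""
    try:
        import torch
        if not torch.cuda.is_available():
            return False
        major, _minor = torch.cuda.get_device_capability(0)
        return major == 9
    except ImportError:
        return False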


def get_attention_info(attn_implementation: str) -> dict:
    """
    Get information about an attention implementation.

    Args:
        attn_implementation: Attention implementation string

    Returns:
        Dictionary with info about the implementation
    """
    info = {
        "flash_attention_2": {
            "name": "FlashAttention-2",
            "description": "Optimized attention for A100 and other GPUs",
            "gpu_support": ["A100", "MI300X", "V100", "RTX"],
            "memory_efficient": True,
            "requires_cuda": True,
        },
        "flash_attention_3_hopper": {
            "name": "FlashAttention-3 Hopper",
            "description": "Optimized attention for H100/H200 Hopper architecture",
            "gpu_support": ["H100", "H200"],
            "memory_efficient": True,
            "requires_cuda": True,
        },
        "sdpa": {
            "name": "PyTorch SDPA",
            "description": "PyTorch Scaled Dot Product Attention",
            "gpu_support": ["All"],
            "memory_efficient": True,
            "requires_cuda": False,
        },
        "eager": {
            "name": "Eager Attention",
            "description": "Standard PyTorch attention implementation",
            "gpu_support": ["All"],
            "memory_efficient": False,
            "requires_cuda": False,
        },
    }

    return info.get(attn_implementation, {
        "name": attn_implementation,
        "description": "Unknown attention implementation",
        "gpu_support": ["Unknown"],
        "memory_efficient": False,
        "requires_cuda": False,
    })


def validate_attention_for_gpu(attn_implementation: str, gpu_name: str) -> tuple[bool, Optional[str]]:
    """
    Validate if attention implementation is suitable for GPU.

    Args:
        attn_implementation: Attention implementation
        gpu_name: GPU device name

    Returns:
        Tuple of (is_valid, warning_message)
    """
    gpu_lower = gpu_name.lower()

    # FlashAttention-3 Hopper validation
    if attn_implementation == "flash_attention_3_hopper":
        if 'h100' not in gpu_lower and 'h200' not in gpu_lower:
            return False, (
                f"FlashAttention-3 Hopper is optimized for H100/H200. "
                f"Current GPU: {gpu_name}. Consider using flash_attention_2 instead."
            )

    # FlashAttention-2 on Hopper GPUs
    if attn_implementation == "flash_attention_2":
        if 'h100' in gpu_lower or 'h200' in gpu_lower:
            return True, (
                f"FlashAttention-2 will work on {gpu_name}, but FlashAttention-3 Hopper "
                f"may provide better performance."
            )

    return True, None


if __name__ == "__main__":
    """Test attention configuration."""
    print("=" * 60)
    print("Attention Implementation Test")
    print("=" * 60)

    # Test getting default attention for different GPUs
    test_gpus = [
        "NVIDIA A100 80GB",
        "NVIDIA H100 80GB",
        "NVIDIA H200 141GB",
        "AMD Instinct MI300X",
    ]

    print("\nDefault attention implementations:")
    for gpu in test_gpus:
        attn = get_default_attention(gpu)
        print(f"  {gpu:30s} → {attn}")

    # Test validation
    print("\nValidation tests:")
    test_cases = [
        ("flash_attention_3_hopper", "NVIDIA H100 80GB"),
        ("flash_attention_3_hopper", "NVIDIA A100 80GB"),
        ("flash_attention_2", "NVIDIA H100 80GB"),
        ("flash_attention_2", "NVIDIA A100 80GB"),
    ]
    for attn, gpu in test_cases:
        valid, warning = validate_attention_for_gpu(attn, gpu)
        status = "✓" if valid else "✗"
        print(f"  {status} {attn:30s} on {gpu:25s}")
        if warning:
            print(f"    ⚠ {warning}")

    # Test getting info
    print("\nAttention implementation info:")
    for attn in ["flash_attention_2", "flash_attention_3_hopper", "sdpa"]:
        info = get_attention_info(attn)
        print(f"\n  {info['name']}:")
        print(f"    Description: {info['description']}")
        print(f"    GPU Support: {', '.join(info['gpu_support'])}")
        print(f"    Memory Efficient: {info['memory_efficient']}")
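
    # Optional live check (a sketch, not in the original tests): if PyTorch
    # with CUDA is available in this environment, show which implementation
    # this module would pick for the local GPU. Skipped when torch is absent.
    try:
        import torch
        if torch.cuda.is_available():
            local_gpu = torch.cuda.get_device_name(0)
            print(f"\nLocal GPU: {local_gpu} → {get_default_attention(local_gpu)}")
    except ImportError:
        pass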