#!/usr/bin/env python3
"""
Model Caching Script for LLM Benchmarking

This script downloads and caches the Qwen3-4B model from HuggingFace
before running benchmarks on offline compute nodes.
"""

import argparse
import os
import sys
from pathlib import Path


def cache_model(model_name: str, cache_dir: str, force: bool = False):
    """
    Download and cache a HuggingFace model.

    Args:
        model_name: HuggingFace model identifier (e.g., "Qwen/Qwen3-4B-Instruct-2507")
        cache_dir: Local directory to cache the model
        force: Force re-download even if model exists
    """
    try:
        from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
    except ImportError:
        print("Error: transformers library not found. Please install it:")
        print("  pip install transformers")
        sys.exit(1)

    # Create cache directory
    cache_path = Path(cache_dir).resolve()
    cache_path.mkdir(parents=True, exist_ok=True)

    print(f"Caching model: {model_name}")
    print(f"Cache directory: {cache_path}")
    print("-" * 60)

    # Set HuggingFace cache directory
    os.environ['HF_HOME'] = str(cache_path)
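    # Note: HF_HOME only changes the *default* cache location; the explicit
    # cache_dir= arguments passed to from_pretrained() below take precedence,
    # so downloads land in cache_path either way.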

    # Check if model already exists. The HF hub cache stores repos under a
    # models--{org}--{name} directory, so include that prefix in the check
    # (without it, the check never matched and the model always re-downloaded).
    model_path = cache_path / f"models--{model_name.replace('/', '--')}"
    if model_path.exists() and not force:
        print(f"Model already cached at: {model_path}")
        print("Use --force to re-download")
        return str(cache_path)

    try:
        # Download config
        print("\n[1/3] Downloading model config...")
        config = AutoConfig.from_pretrained(
            model_name,
            cache_dir=cache_path,
            trust_remote_code=True
        )
        print("  ✓ Config downloaded")
        print(f"  - Model type: {config.model_type}")
        print(f"  - Hidden size: {config.hidden_size}")
        print(f"  - Num layers: {config.num_hidden_layers}")
        print(f"  - Num attention heads: {config.num_attention_heads}")

        # Download tokenizer
        print("\n[2/3] Downloading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            cache_dir=cache_path,
            trust_remote_code=True
        )
        print("  ✓ Tokenizer downloaded")
        print(f"  - Vocab size: {len(tokenizer)}")
        print(f"  - Model max length: {tokenizer.model_max_length}")

        # Download model weights
        print("\n[3/3] Downloading model weights...")
        print("  (This may take several minutes depending on connection speed)")
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            cache_dir=cache_path,
            trust_remote_code=True,
            torch_dtype="auto",
            low_cpu_mem_usage=True
        )
        print("  ✓ Model weights downloaded")

        # Calculate total parameters
        total_params = sum(p.numel() for p in model.parameters())
        print(f"  - Total parameters: {total_params:,} ({total_params/1e9:.2f}B)")

        # Clean up model from memory (the cached weights stay on disk)
        del model

        print("\n" + "=" * 60)
        print("✓ Model successfully cached!")
        print("=" * 60)
        print(f"\nCache location: {cache_path}")
        print("\nTo use in benchmarks, pass:")
        print(f"  --model-path {cache_path}")
        print("\nOr set the environment variable:")
        print(f"  export HF_HOME={cache_path}")

        return str(cache_path)

    except Exception as e:
        print(f"\n✗ Error downloading model: {e}", file=sys.stderr)
        sys.exit(1)


def main():
    parser = argparse.ArgumentParser(
        description="Cache HuggingFace model for offline use",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Cache model to default location
  python cache_model.py

  # Cache model to custom directory
  python cache_model.py --cache-dir /path/to/cache

  # Force re-download
  python cache_model.py --force
"""
    )

    parser.add_argument(
        "--model-name",
        type=str,
        default="Qwen/Qwen3-4B",
        help="HuggingFace model identifier (default: Qwen/Qwen3-4B)"
    )

    parser.add_argument(
        "--cache-dir",
        type=str,
        default="./model_cache",
        help="Directory to cache the model (default: ./model_cache)"
    )

    parser.add_argument(
        "--force",
        action="store_true",
        help="Force re-download even if model exists"
    )

    args = parser.parse_args()

    cache_model(args.model_name, args.cache_dir, args.force)
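

# Example (a sketch, not part of the CLI): how a benchmark script might load
# the cached model later on an offline node. local_files_only=True makes
# transformers read from the cache and fail fast instead of hitting the network.
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained(
#       "Qwen/Qwen3-4B", cache_dir="./model_cache", local_files_only=True)
#   model = AutoModelForCausalLM.from_pretrained(
#       "Qwen/Qwen3-4B", cache_dir="./model_cache", local_files_only=True)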


if __name__ == "__main__":
    main()