cocogoat/cache_model.py

#!/usr/bin/env python3
"""
Model Caching Script for LLM Benchmarking
This script downloads and caches the Qwen3-4B model from HuggingFace
before running benchmarks on offline compute nodes.
"""
import argparse
import os
import sys
from pathlib import Path


def cache_model(model_name: str, cache_dir: str, force: bool = False):
    """
    Download and cache a HuggingFace model.

    Args:
        model_name: HuggingFace model identifier (e.g., "Qwen/Qwen3-4B-Instruct-2507")
        cache_dir: Local directory to cache the model
        force: Force re-download even if model exists
    """
    try:
        from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
    except ImportError:
        print("Error: transformers library not found. Please install it:")
        print("  pip install transformers")
        sys.exit(1)

    # Create cache directory
    cache_path = Path(cache_dir).resolve()
    cache_path.mkdir(parents=True, exist_ok=True)

    print(f"Caching model: {model_name}")
    print(f"Cache directory: {cache_path}")
    print("-" * 60)

    # Point the HuggingFace hub cache at the local directory for any child
    # processes; the explicit cache_dir= arguments below control this process.
    os.environ["HF_HUB_CACHE"] = str(cache_path)

    # Check if model already exists (hub cache layout: models--<org>--<name>)
    model_path = cache_path / ("models--" + model_name.replace("/", "--"))
    if model_path.exists() and not force:
        print(f"Model already cached at: {model_path}")
        print("Use --force to re-download")
        return str(cache_path)

    try:
        # Download config
        print("\n[1/3] Downloading model config...")
        config = AutoConfig.from_pretrained(
            model_name,
            cache_dir=cache_path,
            trust_remote_code=True,
            force_download=force,
        )
        print("  ✓ Config downloaded")
        print(f"    - Model type: {config.model_type}")
        print(f"    - Hidden size: {config.hidden_size}")
        print(f"    - Num layers: {config.num_hidden_layers}")
        print(f"    - Num attention heads: {config.num_attention_heads}")

        # Download tokenizer
        print("\n[2/3] Downloading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            cache_dir=cache_path,
            trust_remote_code=True,
            force_download=force,
        )
        print("  ✓ Tokenizer downloaded")
        print(f"    - Vocab size: {len(tokenizer)}")
        print(f"    - Model max length: {tokenizer.model_max_length}")

        # Download model weights
        print("\n[3/3] Downloading model weights...")
        print("  (This may take several minutes depending on connection speed)")
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            cache_dir=cache_path,
            trust_remote_code=True,
            torch_dtype="auto",
            low_cpu_mem_usage=True,
            force_download=force,
        )
        print("  ✓ Model weights downloaded")

        # Calculate total parameters
        total_params = sum(p.numel() for p in model.parameters())
        print(f"    - Total parameters: {total_params:,} ({total_params/1e9:.2f}B)")

        # Clean up model from memory; the cached files remain on disk
        del model

        print("\n" + "=" * 60)
        print("✓ Model successfully cached!")
        print("=" * 60)
        print(f"\nCache location: {cache_path}")
        print("\nTo use in benchmarks, set:")
        print(f"  --model-path {cache_path}")
        print("\nOr set environment variable:")
        print(f"  export HF_HUB_CACHE={cache_path}")
        return str(cache_path)

    except Exception as e:
        print(f"\n✗ Error downloading model: {e}", file=sys.stderr)
        sys.exit(1)


def main():
    parser = argparse.ArgumentParser(
        description="Cache HuggingFace model for offline use",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Cache model to default location
  python cache_model.py

  # Cache model to custom directory
  python cache_model.py --cache-dir /path/to/cache

  # Force re-download
  python cache_model.py --force
""",
    )
    parser.add_argument(
        "--model-name",
        type=str,
        default="Qwen/Qwen3-4B",
        help="HuggingFace model identifier (default: Qwen/Qwen3-4B)",
    )
    parser.add_argument(
        "--cache-dir",
        type=str,
        default="./model_cache",
        help="Directory to cache model (default: ./model_cache in current directory)",
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="Force re-download even if model exists",
    )

    args = parser.parse_args()
    cache_model(args.model_name, args.cache_dir, args.force)


if __name__ == "__main__":
    main()
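

# ---------------------------------------------------------------------------
# For reference only (not executed): a minimal sketch of how a benchmark job
# on an offline compute node might load the model cached by this script.
# Assumes the default cache directory used above (./model_cache);
# HF_HUB_OFFLINE and TRANSFORMERS_OFFLINE are standard HuggingFace
# environment switches that prevent any network access.
#
#   import os
#   os.environ["HF_HUB_OFFLINE"] = "1"
#   os.environ["TRANSFORMERS_OFFLINE"] = "1"
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained(
#       "Qwen/Qwen3-4B", cache_dir="./model_cache", trust_remote_code=True
#   )
#   model = AutoModelForCausalLM.from_pretrained(
#       "Qwen/Qwen3-4B", cache_dir="./model_cache",
#       torch_dtype="auto", trust_remote_code=True
#   )
# ---------------------------------------------------------------------------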