2025-12-10 12:17:41 +01:00

XGBoost Multi-Label Classification API

A REST API for multi-label classification of HPC workloads using XGBoost. Classifies applications based on roofline performance metrics.

Features

  • Multi-Label Classification: Predict multiple labels with confidence scores
  • Top-K Predictions: Get the most likely K predictions
  • Batch Prediction: Process multiple samples in a single request
  • JSON Aggregation: Aggregate raw roofline data into features automatically
  • Around 60 HPC Application Classes: Including VASP, GROMACS, TurTLE, Chroma, and QuantumESPRESSO

Installation

# Create and activate virtual environment
python -m venv venv
source venv/bin/activate

# Install dependencies
pip install -r requirements.txt

Quick Start

Run Tests

# Python tests
pytest test_xgb_fastapi.py -v

# Curl tests (start server first)
./test_api_curl.sh

API Endpoints

| Endpoint | Method | Description |
|---|---|---|
| /health | GET | Health check and model status |
| /predict | POST | Single prediction with confidence scores |
| /predict_top_k | POST | Get top-K predictions |
| /batch_predict | POST | Batch prediction for multiple samples |
| /model/info | GET | Model information |

Usage Examples

Start the Server

python xgb_fastapi.py --port 8000

Testing with curl

curl -X POST "http://localhost:8000/predict" \
  -H "Content-Type: application/json" \
  -d '{
    "features": {
      "bandwidth_raw_p10": 186.33,
      "bandwidth_raw_median": 205.14,
      "bandwidth_raw_p90": 210.83,
      "flops_raw_p10": 162.024,
      "flops_raw_median": 171.45,
      "flops_raw_p90": 176.48,
      "arith_intensity_median": 0.837,
      "node_num": 0,
      "duration": 19366
    },
    "threshold": 0.3
  }'
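
The same request can also be sent from Python using only the standard library. The following is a minimal sketch, assuming the server is running on http://localhost:8000 as in the curl example. The helper names `build_predict_payload` and `post_predict` are illustrative and not part of this project, and the response is returned as parsed JSON because its exact schema is not described here.

```python
import json
import urllib.request

# Sample features copied from the curl example above.
SAMPLE_FEATURES = {
    "bandwidth_raw_p10": 186.33,
    "bandwidth_raw_median": 205.14,
    "bandwidth_raw_p90": 210.83,
    "flops_raw_p10": 162.024,
    "flops_raw_median": 171.45,
    "flops_raw_p90": 176.48,
    "arith_intensity_median": 0.837,
    "node_num": 0,
    "duration": 19366,
}


def build_predict_payload(features, threshold=0.3):
    """Build the JSON body in the shape used by the curl example for /predict."""
    return {"features": dict(features), "threshold": threshold}


def post_predict(features, threshold=0.3, base_url="http://localhost:8000"):
    """POST one sample to /predict and return the decoded JSON response.

    Hypothetical helper: assumes the API server is already running.
    """
    body = json.dumps(build_predict_payload(features, threshold)).encode("utf-8")
    request = urllib.request.Request(
        f"{base_url}/predict",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(request, timeout=30) as response:
        return json.loads(response.read().decode("utf-8"))


# Usage (requires a running server):
# result = post_predict(SAMPLE_FEATURES, threshold=0.3)
# print(result)
```

Keeping payload construction separate from the network call makes the request body easy to inspect or unit-test without starting the server.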

Testing with Python

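from xgb_local import XGBoostMultiLabelPredictor

predictor = XGBoostMultiLabelPredictor('xgb_model.joblib')

# Sample features (same values as the curl example above)
features = {
    "bandwidth_raw_p10": 186.33,
    "bandwidth_raw_median": 205.14,
    "bandwidth_raw_p90": 210.83,
    "flops_raw_p10": 162.024,
    "flops_raw_median": 171.45,
    "flops_raw_p90": 176.48,
    "arith_intensity_median": 0.837,
    "node_num": 0,
    "duration": 19366,
}
result = predictor.predict(features, threshold=0.3)
print(f"Predictions: {result['predictions']}")
print(f"Confidences: {result['confidences']}")

# Top-K predictions
top_k = predictor.predict_top_k(features, k=5)
for cls, prob in top_k['top_probabilities'].items():
    print(f"{cls}: {prob:.4f}")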

See xgb_local_example.py for complete examples.
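
The difference between threshold-based prediction and top-K prediction can be illustrated on a plain dictionary of class probabilities. This is only a sketch of the two selection strategies, not the code in xgb_local. The function names, the use of `>=` for the threshold comparison, and the probability values are all assumptions made for illustration.

```python
def select_by_threshold(probabilities, threshold=0.3):
    """Return every class whose probability is at or above the threshold."""
    return {cls: p for cls, p in probabilities.items() if p >= threshold}


def select_top_k(probabilities, k=5):
    """Return the k most probable classes, highest probability first."""
    ranked = sorted(probabilities.items(), key=lambda item: item[1], reverse=True)
    return dict(ranked[:k])


# Made-up probabilities for illustration only, not real model output.
example = {"VASP": 0.62, "GROMACS": 0.35, "Chroma": 0.08, "TurTLE": 0.02}

print(select_by_threshold(example, threshold=0.3))
# {'VASP': 0.62, 'GROMACS': 0.35}
print(select_top_k(example, k=3))
# {'VASP': 0.62, 'GROMACS': 0.35, 'Chroma': 0.08}
```

A threshold can return any number of labels, including none, while top-K always returns K entries (or fewer if there are fewer classes) regardless of how low their probabilities are.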

Model Features (28 total)

| Category | Features |
|---|---|
| Bandwidth | bandwidth_raw_p10, _median, _p90, _mad, _range, _iqr |
| FLOPS | flops_raw_p10, _median, _p90, _mad, _range, _iqr |
| Arithmetic Intensity | arith_intensity_p10, _median, _p90, _mad, _range, _iqr |
| Performance | avg_performance_gflops, median_performance_gflops, performance_gflops_mad |
| Correlation | bw_flops_covariance, bw_flops_correlation |
| System | avg_memory_bw_gbs, scalar_peak_gflops, simd_peak_gflops |
| Other | node_num, duration |
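
The six statistics listed for bandwidth, FLOPS, and arithmetic intensity could be derived from raw per-sample measurements roughly as follows. This is a hedged sketch, not the aggregation code used by the API. It assumes that `_mad` means median absolute deviation, `_range` means max minus min, `_iqr` means the 75th minus the 25th percentile, and that percentiles use linear interpolation. The function names `percentile` and `summarize` are hypothetical.

```python
import statistics


def percentile(values, q):
    """Percentile (q in 0..100) with linear interpolation between neighbors."""
    ordered = sorted(values)
    if not ordered:
        raise ValueError("percentile() needs at least one value")
    position = (len(ordered) - 1) * q / 100.0
    lower = int(position)
    upper = min(lower + 1, len(ordered) - 1)
    fraction = position - lower
    return ordered[lower] + (ordered[upper] - ordered[lower]) * fraction


def summarize(values, prefix):
    """Compute the six summary features for one metric, e.g. prefix='bandwidth_raw'."""
    values = list(values)
    median = statistics.median(values)
    return {
        f"{prefix}_p10": percentile(values, 10),
        f"{prefix}_median": median,
        f"{prefix}_p90": percentile(values, 90),
        # Assumption: MAD is the median absolute deviation from the median.
        f"{prefix}_mad": statistics.median(abs(v - median) for v in values),
        # Assumption: range is max minus min.
        f"{prefix}_range": max(values) - min(values),
        # Assumption: IQR is the 75th percentile minus the 25th percentile.
        f"{prefix}_iqr": percentile(values, 75) - percentile(values, 25),
    }


# Made-up raw bandwidth samples for illustration only.
raw_bandwidth = [100.0, 200.0, 300.0, 400.0, 500.0]
for name, value in summarize(raw_bandwidth, "bandwidth_raw").items():
    print(f"{name}: {value:.2f}")
```

Calling `summarize` once per metric with the prefixes `bandwidth_raw`, `flops_raw`, and `arith_intensity` would yield 18 of the 28 features; the remaining ones come from performance, correlation, and system values that this sketch does not cover.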