XGBoost Multi-Label Classification API
A REST API for multi-label classification of high-performance computing (HPC) workloads using XGBoost. It classifies applications from their roofline performance metrics.
Features
- Multi-Label Classification: Predict multiple labels per sample, each with a confidence score
- Top-K Predictions: Return the K most likely classes
- Batch Prediction: Process multiple samples in a single request
- JSON Aggregation: Automatically aggregate raw roofline data into model features
- Around 60 HPC Application Classes: Including VASP, GROMACS, TurTLE, Chroma, and QuantumESPRESSO, among others
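Multi-label prediction and top-K selection both start from one probability per class. The sketch below illustrates the two selection rules, assuming the model yields an independent probability for each class and that a class counts as predicted when its probability is at or above `threshold`. The helper names, the `>=` comparison and the example probabilities are hypothetical, not taken from this project's code.

```python
from typing import Dict, List, Tuple


def select_labels(probs: Dict[str, float], threshold: float = 0.3) -> Tuple[List[str], Dict[str, float]]:
    """Return every class whose probability meets the threshold, plus its confidence.

    Assumption: the comparison is `>=`; the real service may use `>`.
    """
    predictions = [cls for cls, p in probs.items() if p >= threshold]
    confidences = {cls: probs[cls] for cls in predictions}
    return predictions, confidences


def select_top_k(probs: Dict[str, float], k: int = 5) -> Dict[str, float]:
    """Return the k most probable classes, highest first (ties keep input order)."""
    ranked = sorted(probs.items(), key=lambda item: item[1], reverse=True)
    return dict(ranked[:k])


# Hypothetical per-class probabilities for a single sample
probs = {"VASP": 0.72, "GROMACS": 0.41, "TurTLE": 0.08, "Chroma": 0.05}
labels, conf = select_labels(probs, threshold=0.3)
print(labels)                    # ['VASP', 'GROMACS']
print(select_top_k(probs, k=2))  # {'VASP': 0.72, 'GROMACS': 0.41}
```

Because each class is scored independently, several classes can clear the threshold at once. Top-K, by contrast, always returns exactly K classes regardless of their probabilities.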
Installation
```bash
# Create and activate virtual environment
python -m venv venv
source venv/bin/activate
# Install dependencies
pip install -r requirements.txt
```
Quick Start
Run Tests
```bash
# Python tests
pytest test_xgb_fastapi.py -v
# Curl tests (start server first)
./test_api_curl.sh
```
API Endpoints
| Endpoint | Method | Description |
|---|---|---|
| `/health` | GET | Health check and model status |
| `/predict` | POST | Single prediction with confidence scores |
| `/predict_top_k` | POST | Top-K predictions |
| `/batch_predict` | POST | Batch prediction for multiple samples |
| `/model/info` | GET | Model information |
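The curl tests need a running server, so it can help to wait until `/health` answers before sending requests. Below is a minimal standard-library sketch, assuming the server listens at `http://localhost:8000` and that a healthy server returns HTTP 200 on `/health`. The response body is not inspected because its format is not documented here, and `wait_for_health` is a hypothetical helper, not part of this repository.

```python
import time
import urllib.request


def wait_for_health(base_url: str = "http://localhost:8000", timeout: float = 30.0,
                    interval: float = 0.5) -> bool:
    """Poll GET {base_url}/health until it returns HTTP 200 or `timeout` seconds pass.

    Assumption: a 200 status means the server and model are ready.
    """
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            with urllib.request.urlopen(f"{base_url}/health", timeout=interval) as resp:
                if resp.status == 200:
                    return True
        except OSError:
            # URLError, HTTPError and socket timeouts all derive from OSError:
            # the server is not up yet, so retry after a short pause.
            pass
        time.sleep(interval)
    return False


if __name__ == "__main__":
    print("ready" if wait_for_health() else "server did not become healthy in time")
```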
Usage Examples
Start the Server
```bash
python xgb_fastapi.py --port 8000
```
Testing with curl
```bash
curl -X POST "http://localhost:8000/predict" \
  -H "Content-Type: application/json" \
  -d '{
    "features": {
      "bandwidth_raw_p10": 186.33,
      "bandwidth_raw_median": 205.14,
      "bandwidth_raw_p90": 210.83,
      "flops_raw_p10": 162.024,
      "flops_raw_median": 171.45,
      "flops_raw_p90": 176.48,
      "arith_intensity_median": 0.837,
      "node_num": 0,
      "duration": 19366
    },
    "threshold": 0.3
  }'
```
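The same request can also be sent from Python without extra dependencies. This sketch uses only `urllib` from the standard library, reuses the payload shape of the curl call (a `features` object plus a `threshold`), and assumes the server replies with JSON. `build_predict_request` and `predict` are illustrative helper names, not part of this project.

```python
import json
import urllib.request


def build_predict_request(features: dict, threshold: float = 0.3,
                          base_url: str = "http://localhost:8000") -> urllib.request.Request:
    """Build a POST /predict request with the same JSON body as the curl example."""
    body = json.dumps({"features": features, "threshold": threshold}).encode("utf-8")
    return urllib.request.Request(
        f"{base_url}/predict",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def predict(features: dict, threshold: float = 0.3,
            base_url: str = "http://localhost:8000") -> dict:
    """Send the request and decode the response (assumed to be JSON)."""
    req = build_predict_request(features, threshold, base_url)
    with urllib.request.urlopen(req, timeout=10) as resp:
        return json.loads(resp.read().decode("utf-8"))


if __name__ == "__main__":
    # Feature values copied from the curl example
    features = {
        "bandwidth_raw_p10": 186.33, "bandwidth_raw_median": 205.14,
        "bandwidth_raw_p90": 210.83, "flops_raw_p10": 162.024,
        "flops_raw_median": 171.45, "flops_raw_p90": 176.48,
        "arith_intensity_median": 0.837, "node_num": 0, "duration": 19366,
    }
    print(predict(features, threshold=0.3))
```

Splitting request construction from sending makes the payload easy to inspect or test without a running server.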
Testing with Python
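```python
from xgb_local import XGBoostMultiLabelPredictor

predictor = XGBoostMultiLabelPredictor('xgb_model.joblib')

# Features are assumed to be a dict of name -> value, as in the curl payload
features = {"bandwidth_raw_median": 205.14, "flops_raw_median": 171.45,
            "arith_intensity_median": 0.837, "node_num": 0, "duration": 19366}

result = predictor.predict(features, threshold=0.3)
print(f"Predictions: {result['predictions']}")
print(f"Confidences: {result['confidences']}")

# Top-K predictions
top_k = predictor.predict_top_k(features, k=5)
for cls, prob in top_k['top_probabilities'].items():
    print(f"{cls}: {prob:.4f}")
```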
See xgb_local_example.py for complete examples.
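Model Features (28 total)
Entries that begin with an underscore share the prefix of the first feature in their row; for example, `_median` in the Bandwidth row stands for `bandwidth_raw_median`.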
| Category | Features |
|---|---|
| Bandwidth | bandwidth_raw_p10, _median, _p90, _mad, _range, _iqr |
| FLOPS | flops_raw_p10, _median, _p90, _mad, _range, _iqr |
| Arithmetic Intensity | arith_intensity_p10, _median, _p90, _mad, _range, _iqr |
| Performance | avg_performance_gflops, median_performance_gflops, performance_gflops_mad |
| Correlation | bw_flops_covariance, bw_flops_correlation |
| System | avg_memory_bw_gbs, scalar_peak_gflops, simd_peak_gflops |
| Other | node_num, duration |
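The JSON Aggregation feature turns raw roofline samples into these summary statistics. The project's exact aggregation is not documented here, so the sketch below is only one plausible reading of the feature names. It assumes that each job provides equal-length lists of per-sample bandwidth and FLOPS values (at least two samples, non-constant), that arithmetic intensity is FLOPS divided by bandwidth per sample (consistent with the curl example, where 171.45 / 205.14 ≈ 0.836 is close to the listed 0.837), that `_mad` is the median absolute deviation, and that `_range` and `_iqr` are max minus min and p75 minus p25. `percentile`, `summarize` and `aggregate` are hypothetical helpers. The Performance, System and Other groups are omitted.

```python
import math
import statistics
from typing import Dict, List


def percentile(values: List[float], q: float) -> float:
    """Linear-interpolation percentile (the same rule as NumPy's default)."""
    xs = sorted(values)
    pos = (len(xs) - 1) * q / 100.0
    lo = math.floor(pos)
    hi = min(lo + 1, len(xs) - 1)
    return xs[lo] + (xs[hi] - xs[lo]) * (pos - lo)


def summarize(prefix: str, values: List[float]) -> Dict[str, float]:
    """Assumed definitions of the six distribution features for one metric."""
    med = statistics.median(values)
    return {
        f"{prefix}_p10": percentile(values, 10),
        f"{prefix}_median": med,
        f"{prefix}_p90": percentile(values, 90),
        f"{prefix}_mad": statistics.median(abs(v - med) for v in values),
        f"{prefix}_range": max(values) - min(values),
        f"{prefix}_iqr": percentile(values, 75) - percentile(values, 25),
    }


def aggregate(bandwidth: List[float], flops: List[float]) -> Dict[str, float]:
    """Build distribution and correlation features from per-sample values.

    Assumptions: equal-length lists, >= 2 samples, non-zero bandwidth,
    non-constant inputs, and arithmetic intensity = flops / bandwidth.
    """
    intensity = [f / b for f, b in zip(flops, bandwidth)]
    features: Dict[str, float] = {}
    features.update(summarize("bandwidth_raw", bandwidth))
    features.update(summarize("flops_raw", flops))
    features.update(summarize("arith_intensity", intensity))
    # Sample covariance and Pearson correlation between bandwidth and FLOPS
    mb, mf = statistics.mean(bandwidth), statistics.mean(flops)
    cov = sum((b - mb) * (f - mf) for b, f in zip(bandwidth, flops)) / (len(bandwidth) - 1)
    features["bw_flops_covariance"] = cov
    features["bw_flops_correlation"] = cov / (statistics.stdev(bandwidth) * statistics.stdev(flops))
    return features
```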