# XGBoost Multi-Label Classification API
A REST API for multi-label classification of HPC workloads using XGBoost. Classifies applications based on roofline performance metrics.
## Features
- **Multi-Label Classification**: Predict multiple labels with confidence scores
- **Top-K Predictions**: Return the K most likely labels
- **Batch Prediction**: Process multiple samples in a single request
- **JSON Aggregation**: Aggregate raw roofline data into features automatically
- **Around 60 HPC Application Classes**: Including VASP, GROMACS, TurTLE, Chroma, and QuantumESPRESSO
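
The difference between threshold-based multi-label prediction and top-K prediction can be illustrated with a minimal sketch. It is not the project's implementation: the helper names `select_by_threshold` and `select_top_k` are hypothetical, the probabilities are made up, and whether the API compares with `>=` or `>` is an assumption.

```python
def select_by_threshold(probabilities, threshold=0.3):
    """Return every label whose probability is at least `threshold`."""
    # Assumption: the comparison is inclusive; the real API may differ.
    return {label: p for label, p in probabilities.items() if p >= threshold}


def select_top_k(probabilities, k=5):
    """Return the k labels with the highest probability, best first."""
    ranked = sorted(probabilities.items(), key=lambda item: item[1], reverse=True)
    return dict(ranked[:k])


# Hypothetical per-class probabilities, made up for illustration only
probs = {"VASP": 0.62, "GROMACS": 0.35, "Chroma": 0.08, "TurTLE": 0.02}

above_threshold = select_by_threshold(probs, threshold=0.3)
top_three = select_top_k(probs, k=3)
print(above_threshold)
print(top_three)
```

A threshold can return any number of labels (including none), while top-K always returns exactly K labels as long as at least K classes exist.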
## Installation
```bash
# Create and activate virtual environment
python -m venv venv
source venv/bin/activate
# Install dependencies
pip install -r requirements.txt
```
## Quick Start
### Run Tests
```bash
# Python tests
pytest test_xgb_fastapi.py -v
# Curl tests (start server first)
./test_api_curl.sh
```
## API Endpoints
| Endpoint | Method | Description |
|----------|--------|-------------|
| `/health` | GET | Health check and model status |
| `/predict` | POST | Single prediction with confidence scores |
| `/predict_top_k` | POST | Get top-K predictions |
| `/batch_predict` | POST | Batch prediction for multiple samples |
| `/model/info` | GET | Model information |
## Usage Examples
### Start the Server
```bash
python xgb_fastapi.py --port 8000
```
### Testing with curl
```bash
curl -X POST "http://localhost:8000/predict" \
  -H "Content-Type: application/json" \
  -d '{
    "features": {
      "bandwidth_raw_p10": 186.33,
      "bandwidth_raw_median": 205.14,
      "bandwidth_raw_p90": 210.83,
      "flops_raw_p10": 162.024,
      "flops_raw_median": 171.45,
      "flops_raw_p90": 176.48,
      "arith_intensity_median": 0.837,
      "node_num": 0,
      "duration": 19366
    },
    "threshold": 0.3
  }'
```
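
The same request can also be sent from Python. The sketch below uses only the standard library and mirrors the curl payload above. The helpers `build_predict_request` and `send` are hypothetical names introduced for illustration, and the response schema is not documented here, so the result is treated as opaque JSON.

```python
import json
import urllib.request

# Assumes the server was started locally on port 8000, as in "Start the Server"
BASE_URL = "http://localhost:8000"


def build_predict_request(features, threshold=0.3, base_url=BASE_URL):
    """Build (but do not send) a POST request for the /predict endpoint."""
    payload = json.dumps({"features": features, "threshold": threshold})
    return urllib.request.Request(
        f"{base_url}/predict",
        data=payload.encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def send(request, timeout=10):
    """Send a prepared request and decode the JSON response body."""
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return json.loads(response.read().decode("utf-8"))


# Same feature values as in the curl example
features = {
    "bandwidth_raw_p10": 186.33,
    "bandwidth_raw_median": 205.14,
    "bandwidth_raw_p90": 210.83,
    "flops_raw_p10": 162.024,
    "flops_raw_median": 171.45,
    "flops_raw_p90": 176.48,
    "arith_intensity_median": 0.837,
    "node_num": 0,
    "duration": 19366,
}

request = build_predict_request(features, threshold=0.3)
```

With the server running, `send(request)` returns the decoded JSON response. Building the request separately from sending it makes the payload easy to inspect without a live server.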
### Testing with Python
```python
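from xgb_local import XGBoostMultiLabelPredictor

# Example feature values (the same ones used in the curl request above)
features = {
    "bandwidth_raw_p10": 186.33, "bandwidth_raw_median": 205.14, "bandwidth_raw_p90": 210.83,
    "flops_raw_p10": 162.024, "flops_raw_median": 171.45, "flops_raw_p90": 176.48,
    "arith_intensity_median": 0.837, "node_num": 0, "duration": 19366,
}

predictor = XGBoostMultiLabelPredictor('xgb_model.joblib')
result = predictor.predict(features, threshold=0.3)
print(f"Predictions: {result['predictions']}")
print(f"Confidences: {result['confidences']}")

# Top-K predictions
top_k = predictor.predict_top_k(features, k=5)
for cls, prob in top_k['top_probabilities'].items():
    print(f"{cls}: {prob:.4f}")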
See `xgb_local_example.py` for complete examples.
## Model Features (28 total)
| Category | Features |
|----------|----------|
| Bandwidth | `bandwidth_raw_p10`, `_median`, `_p90`, `_mad`, `_range`, `_iqr` |
| FLOPS | `flops_raw_p10`, `_median`, `_p90`, `_mad`, `_range`, `_iqr` |
| Arithmetic Intensity | `arith_intensity_p10`, `_median`, `_p90`, `_mad`, `_range`, `_iqr` |
| Performance | `avg_performance_gflops`, `median_performance_gflops`, `performance_gflops_mad` |
| Correlation | `bw_flops_covariance`, `bw_flops_correlation` |
| System | `avg_memory_bw_gbs`, `scalar_peak_gflops`, `simd_peak_gflops` |
| Other | `node_num`, `duration` |
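
The six per-metric statistics in the table (`_p10`, `_median`, `_p90`, `_mad`, `_range`, `_iqr`) could be derived from raw samples roughly as follows. This is a sketch under stated assumptions, not the project's aggregation code: `percentile` and `summarize` are hypothetical helpers, percentiles use linear interpolation, "mad" is taken to mean the median absolute deviation, and the sample values are made up.

```python
import statistics


def percentile(values, q):
    """Percentile q (0-100) using linear interpolation between sorted samples."""
    ordered = sorted(values)
    position = (len(ordered) - 1) * q / 100
    lower = int(position)
    upper = min(lower + 1, len(ordered) - 1)
    return ordered[lower] + (ordered[upper] - ordered[lower]) * (position - lower)


def summarize(values, prefix):
    """Compute the six summary statistics for one raw metric."""
    median = statistics.median(values)
    return {
        f"{prefix}_p10": percentile(values, 10),
        f"{prefix}_median": median,
        f"{prefix}_p90": percentile(values, 90),
        # Assumption: "mad" is the median absolute deviation from the median
        f"{prefix}_mad": statistics.median([abs(v - median) for v in values]),
        f"{prefix}_range": max(values) - min(values),
        f"{prefix}_iqr": percentile(values, 75) - percentile(values, 25),
    }


# Hypothetical raw bandwidth samples, made up for illustration only
bandwidth_samples = [1.0, 2.0, 3.0, 4.0, 5.0]
bandwidth_features = summarize(bandwidth_samples, "bandwidth_raw")
for name, value in bandwidth_features.items():
    print(f"{name}: {value:.3f}")
```

The same helper would apply to the FLOPS and arithmetic intensity series by changing the prefix; the remaining features (performance, correlation, system peaks) are not covered by this sketch.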