# XGBoost Multi-Label Classification API

A REST API for multi-label classification of HPC workloads using XGBoost. It classifies applications based on their roofline performance metrics.

## Features

- **Multi-Label Classification**: Predict multiple labels with confidence scores
- **Top-K Predictions**: Get the K most likely predictions
- **Batch Prediction**: Process multiple samples in a single request
- **JSON Aggregation**: Aggregate raw roofline data into features automatically
- **Around 60 HPC Application Classes**: Including VASP, GROMACS, TurTLE, Chroma, and QuantumESPRESSO

## Installation

```bash
# Create and activate virtual environment
python -m venv venv
source venv/bin/activate

# Install dependencies
pip install -r requirements.txt
```

## Quick Start

### Run Tests

```bash
# Python tests
pytest test_xgb_fastapi.py -v

# Curl tests (start the server first, see "Start the Server")
./test_api_curl.sh
```

## API Endpoints

| Endpoint | Method | Description |
|----------|--------|-------------|
| `/health` | GET | Health check and model status |
| `/predict` | POST | Single prediction with confidence scores |
| `/predict_top_k` | POST | Get top-K predictions |
| `/batch_predict` | POST | Batch prediction for multiple samples |
| `/model/info` | GET | Model information |

## Usage Examples

### Start the Server

```bash
python xgb_fastapi.py --port 8000
```

### Testing with curl

```bash
curl -X POST "http://localhost:8000/predict" \
  -H "Content-Type: application/json" \
  -d '{
    "features": {
      "bandwidth_raw_p10": 186.33,
      "bandwidth_raw_median": 205.14,
      "bandwidth_raw_p90": 210.83,
      "flops_raw_p10": 162.024,
      "flops_raw_median": 171.45,
      "flops_raw_p90": 176.48,
      "arith_intensity_median": 0.837,
      "node_num": 0,
      "duration": 19366
    },
    "threshold": 0.3
  }'
```
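
The same request can be issued from Python. The sketch below uses only the standard library and mirrors the URL, header, and body of the curl call; the helper names `build_predict_request` and `predict` are hypothetical, and it assumes the server replies with a JSON body, which this README does not spell out.

```python
import json
import urllib.request

# Assumption: the server from "Start the Server" is listening on this address.
BASE_URL = "http://localhost:8000"


def build_predict_request(features, threshold=0.3, base_url=BASE_URL):
    """Build (but do not send) a POST request for the /predict endpoint."""
    payload = {"features": features, "threshold": threshold}
    return urllib.request.Request(
        url=f"{base_url}/predict",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def predict(features, threshold=0.3, base_url=BASE_URL, timeout=10.0):
    """Send the request and decode the response (assumed to be JSON)."""
    request = build_predict_request(features, threshold, base_url)
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return json.loads(response.read().decode("utf-8"))


# Same feature values as the curl example.
features = {
    "bandwidth_raw_p10": 186.33,
    "bandwidth_raw_median": 205.14,
    "bandwidth_raw_p90": 210.83,
    "flops_raw_p10": 162.024,
    "flops_raw_median": 171.45,
    "flops_raw_p90": 176.48,
    "arith_intensity_median": 0.837,
    "node_num": 0,
    "duration": 19366,
}

# Building the request needs no running server; only predict() touches the network.
request = build_predict_request(features)
print(request.get_method(), request.full_url)
```

With the server running, `predict(features, threshold=0.3)` would send the request and return the decoded response.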

### Testing with Python

```python
from xgb_local import XGBoostMultiLabelPredictor

# Load the trained model from disk
predictor = XGBoostMultiLabelPredictor('xgb_model.joblib')

# Same feature values as the curl example above
features = {
    "bandwidth_raw_p10": 186.33, "bandwidth_raw_median": 205.14, "bandwidth_raw_p90": 210.83,
    "flops_raw_p10": 162.024, "flops_raw_median": 171.45, "flops_raw_p90": 176.48,
    "arith_intensity_median": 0.837, "node_num": 0, "duration": 19366,
}

result = predictor.predict(features, threshold=0.3)
print(f"Predictions: {result['predictions']}")
print(f"Confidences: {result['confidences']}")

# Top-K predictions
top_k = predictor.predict_top_k(features, k=5)
for cls, prob in top_k['top_probabilities'].items():
    print(f"{cls}: {prob:.4f}")
```

See `xgb_local_example.py` for complete examples.
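
To make the roles of `threshold` and `k` concrete, the sketch below shows one plausible way per-class probabilities could be turned into the `predictions`, `confidences`, and `top_probabilities` fields used above. It is an illustration, not the predictor's actual code: the function names, the `>=` comparison, and the probability values are all assumptions.

```python
def select_labels(probabilities, threshold=0.3):
    """Keep every class whose probability reaches the threshold (assumed >=)."""
    confidences = {c: p for c, p in probabilities.items() if p >= threshold}
    # Order the kept labels from most to least confident.
    predictions = sorted(confidences, key=confidences.get, reverse=True)
    return {"predictions": predictions, "confidences": confidences}


def select_top_k(probabilities, k=5):
    """Keep the k most probable classes regardless of any threshold."""
    ranked = sorted(probabilities.items(), key=lambda item: item[1], reverse=True)
    return {"top_probabilities": dict(ranked[:k])}


# Made-up probabilities for four of the application classes named in this README.
example = {"VASP": 0.62, "GROMACS": 0.35, "Chroma": 0.02, "TurTLE": 0.01}

labels = select_labels(example, threshold=0.3)
top_two = select_top_k(example, k=2)
print(labels["predictions"])
print(list(top_two["top_probabilities"]))
```

Under this reading, a threshold can return zero or many labels, while top-K always returns exactly K (or fewer if there are fewer classes).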

## Model Features (28 total)

Suffix-only entries share the prefix of the first feature in their row; for example, `_median` in the Bandwidth row stands for `bandwidth_raw_median`.

| Category | Features |
|----------|----------|
| Bandwidth | `bandwidth_raw_p10`, `_median`, `_p90`, `_mad`, `_range`, `_iqr` |
| FLOPS | `flops_raw_p10`, `_median`, `_p90`, `_mad`, `_range`, `_iqr` |
| Arithmetic Intensity | `arith_intensity_p10`, `_median`, `_p90`, `_mad`, `_range`, `_iqr` |
| Performance | `avg_performance_gflops`, `median_performance_gflops`, `performance_gflops_mad` |
| Correlation | `bw_flops_covariance`, `bw_flops_correlation` |
| System | `avg_memory_bw_gbs`, `scalar_peak_gflops`, `simd_peak_gflops` |
| Other | `node_num`, `duration` |
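
### Computing the Summary Statistics

The suffixes in the table (`_p10`, `_median`, `_p90`, `_mad`, `_range`, `_iqr`) suggest per-metric summary statistics over raw samples. The sketch below shows how one raw series might be reduced to those six features with the standard library. The function name `summarize`, the sample values, the inclusive percentile interpolation, the reading of `_mad` as median absolute deviation, and `_range` as maximum minus minimum are all assumptions; the project's own aggregation may define them differently.

```python
import statistics


def summarize(samples, prefix):
    """Reduce one raw metric series to six summary features (assumed definitions)."""
    data = sorted(float(x) for x in samples)
    if len(data) < 2:
        raise ValueError("need at least two samples")
    # Cut points that split the data into 10 and 4 equal-probability groups.
    deciles = statistics.quantiles(data, n=10, method="inclusive")
    quartiles = statistics.quantiles(data, n=4, method="inclusive")
    median = statistics.median(data)
    return {
        f"{prefix}_p10": deciles[0],
        f"{prefix}_median": median,
        f"{prefix}_p90": deciles[8],
        # Assumption: MAD is the median absolute deviation from the median.
        f"{prefix}_mad": statistics.median(abs(x - median) for x in data),
        # Assumption: range is max minus min.
        f"{prefix}_range": data[-1] - data[0],
        f"{prefix}_iqr": quartiles[2] - quartiles[0],
    }


# Made-up bandwidth samples, for illustration only.
bandwidth_samples = [float(v) for v in range(1, 12)]
bandwidth_features = summarize(bandwidth_samples, "bandwidth_raw")
for name, value in bandwidth_features.items():
    print(f"{name}: {value}")
```

Passing a different prefix, such as `flops_raw` or `arith_intensity`, yields the feature names of the other rows.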