Initial commit
This commit is contained in:
105
README.md
Normal file
105
README.md
Normal file
@@ -0,0 +1,105 @@
|
||||
# XGBoost Multi-Label Classification API
|
||||
|
||||
A REST API for multi-label classification of HPC workloads using XGBoost. Classifies applications based on roofline performance metrics.
|
||||
|
||||
## Features
|
||||
|
||||
- **Multi-Label Classification**: Predict multiple labels with confidence scores
|
||||
- **Top-K Predictions**: Get the most likely K predictions
|
||||
- **Batch Prediction**: Process multiple samples in a single request
|
||||
- **JSON Aggregation**: Aggregate raw roofline data into features automatically
|
||||
- **Around 60 HPC Application Classes**: Including VASP, GROMACS, TurTLE, Chroma, QuantumESPRESSO, etc.
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
# Create and activate virtual environment
|
||||
python -m venv venv
|
||||
source venv/bin/activate
|
||||
|
||||
# Install dependencies
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Run Tests
|
||||
|
||||
```bash
|
||||
# Python tests
|
||||
pytest test_xgb_fastapi.py -v
|
||||
|
||||
# Curl tests (start server first)
|
||||
./test_api_curl.sh
|
||||
```
|
||||
|
||||
## API Endpoints
|
||||
|
||||
| Endpoint | Method | Description |
|
||||
|----------|--------|-------------|
|
||||
| `/health` | GET | Health check and model status |
|
||||
| `/predict` | POST | Single prediction with confidence scores |
|
||||
| `/predict_top_k` | POST | Get top-K predictions |
|
||||
| `/batch_predict` | POST | Batch prediction for multiple samples |
|
||||
| `/model/info` | GET | Model information |
|
||||
|
||||
## Usage Examples
|
||||
|
||||
|
||||
### Start the Server
|
||||
|
||||
```bash
|
||||
python xgb_fastapi.py --port 8000
|
||||
```
|
||||
|
||||
### Testing with curl
|
||||
|
||||
```bash
|
||||
curl -X POST "http://localhost:8000/predict" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"features": {
|
||||
"bandwidth_raw_p10": 186.33,
|
||||
"bandwidth_raw_median": 205.14,
|
||||
"bandwidth_raw_p90": 210.83,
|
||||
"flops_raw_p10": 162.024,
|
||||
"flops_raw_median": 171.45,
|
||||
"flops_raw_p90": 176.48,
|
||||
"arith_intensity_median": 0.837,
|
||||
"node_num": 0,
|
||||
"duration": 19366
|
||||
},
|
||||
"threshold": 0.3
|
||||
}'
|
||||
```
|
||||
|
||||
### Testing with Python
|
||||
|
||||
```python
|
||||
from xgb_local import XGBoostMultiLabelPredictor
|
||||
|
||||
predictor = XGBoostMultiLabelPredictor('xgb_model.joblib')
|
||||
|
||||
result = predictor.predict(features, threshold=0.3)
|
||||
print(f"Predictions: {result['predictions']}")
|
||||
print(f"Confidences: {result['confidences']}")
|
||||
|
||||
# Top-K predictions
|
||||
top_k = predictor.predict_top_k(features, k=5)
|
||||
for cls, prob in top_k['top_probabilities'].items():
|
||||
print(f"{cls}: {prob:.4f}")
|
||||
```
|
||||
|
||||
See `xgb_local_example.py` for complete examples.
|
||||
|
||||
## Model Features (28 total)
|
||||
|
||||
| Category | Features |
|
||||
|----------|----------|
|
||||
| Bandwidth | `bandwidth_raw_p10`, `_median`, `_p90`, `_mad`, `_range`, `_iqr` |
|
||||
| FLOPS | `flops_raw_p10`, `_median`, `_p90`, `_mad`, `_range`, `_iqr` |
|
||||
| Arithmetic Intensity | `arith_intensity_p10`, `_median`, `_p90`, `_mad`, `_range`, `_iqr` |
|
||||
| Performance | `avg_performance_gflops`, `median_performance_gflops`, `performance_gflops_mad` |
|
||||
| Correlation | `bw_flops_covariance`, `bw_flops_correlation` |
|
||||
| System | `avg_memory_bw_gbs`, `scalar_peak_gflops`, `simd_peak_gflops` |
|
||||
| Other | `node_num`, `duration` |
|
||||
Reference in New Issue
Block a user