Initial commit

Bole Ma
2025-12-10 12:17:41 +01:00
commit 739563f916
12 changed files with 3428 additions and 0 deletions

62
.gitignore vendored Normal file

@@ -0,0 +1,62 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
# Virtual environments
venv/
ENV/
env/
.venv/
# IDE
.idea/
.vscode/
*.swp
*.swo
*~
.DS_Store
# Jupyter
.ipynb_checkpoints/
# Model files (optional - uncomment if you don't want to track large models)
# *.joblib
# *.pkl
# *.h5
# Logs
*.log
logs/
# Environment variables
.env
.env.local
# Testing
.pytest_cache/
.coverage
htmlcov/
# Go binaries
*.exe
*.exe~
*.dll
*.dylib

42
Makefile Normal file

@@ -0,0 +1,42 @@
.PHONY: install run-fastapi run-fastapi-dev run-flask run-flask-debug test clean help
# Default Python interpreter
PYTHON := python3
# Default ports
FASTAPI_PORT := 8000
FLASK_PORT := 5000
help:
@echo "XGBoost Multi-Label Classification API"
@echo ""
@echo "Usage:"
@echo " make install Install Python dependencies"
@echo " make run-fastapi Start FastAPI server (port $(FASTAPI_PORT))"
@echo " make run-flask Start Flask server (port $(FLASK_PORT))"
@echo " make test Run the inference example"
@echo " make clean Clean up cache files"
@echo ""
install:
$(PYTHON) -m pip install -r requirements.txt
run-fastapi:
$(PYTHON) xgb_fastapi.py --port $(FASTAPI_PORT)
run-fastapi-dev:
$(PYTHON) xgb_fastapi.py --port $(FASTAPI_PORT) --reload
run-flask:
$(PYTHON) xgb_rest_api.py --port $(FLASK_PORT)
run-flask-debug:
$(PYTHON) xgb_rest_api.py --port $(FLASK_PORT) --debug
test:
$(PYTHON) xgb_inference_example.py
clean:
find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true
find . -type f -name "*.pyc" -delete 2>/dev/null || true
find . -type f -name "*.pyo" -delete 2>/dev/null || true

105
README.md Normal file

@@ -0,0 +1,105 @@
# XGBoost Multi-Label Classification API
A REST API for multi-label classification of HPC workloads using XGBoost. It classifies applications based on roofline performance metrics.
## Features
- **Multi-Label Classification**: Predict multiple labels with confidence scores
- **Top-K Predictions**: Get the most likely K predictions
- **Batch Prediction**: Process multiple samples in a single request
- **JSON Aggregation**: Aggregate raw roofline data into features automatically
- **Around 60 HPC Application Classes**: Including VASP, GROMACS, TurTLE, Chroma, QuantumESPRESSO, etc.
## Installation
```bash
# Create and activate virtual environment
python -m venv venv
source venv/bin/activate
# Install dependencies
pip install -r requirements.txt
```
## Quick Start
### Run Tests
```bash
# Python tests
pytest test_xgb_fastapi.py -v
# Curl tests (start server first)
./test_api_curl.sh
```
## API Endpoints
| Endpoint | Method | Description |
|----------|--------|-------------|
| `/health` | GET | Health check and model status |
| `/predict` | POST | Single prediction with confidence scores |
| `/predict_top_k` | POST | Get top-K predictions |
| `/batch_predict` | POST | Batch prediction for multiple samples |
| `/model/info` | GET | Model information |
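
A quick smoke test once the server is running on the default port (see the next section for startup):

```bash
curl http://localhost:8000/health
```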
## Usage Examples
### Start the Server
```bash
python xgb_fastapi.py --port 8000
```
### Testing with curl
```bash
curl -X POST "http://localhost:8000/predict" \
-H "Content-Type: application/json" \
-d '{
"features": {
"bandwidth_raw_p10": 186.33,
"bandwidth_raw_median": 205.14,
"bandwidth_raw_p90": 210.83,
"flops_raw_p10": 162.024,
"flops_raw_median": 171.45,
"flops_raw_p90": 176.48,
"arith_intensity_median": 0.837,
"node_num": 0,
"duration": 19366
},
"threshold": 0.3
}'
```
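
The top-K endpoint accepts the same feature payload plus a `k` parameter; the request below is a sketch based on the schema exercised in `test_api_curl.sh`:

```bash
curl -X POST "http://localhost:8000/predict_top_k" \
  -H "Content-Type: application/json" \
  -d '{
    "features": {
      "bandwidth_raw_median": 205.14,
      "flops_raw_median": 171.45,
      "arith_intensity_median": 0.837,
      "node_num": 0,
      "duration": 19366
    },
    "k": 5
  }'
```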
### Testing with Python
```python
from xgb_local import XGBoostMultiLabelPredictor
predictor = XGBoostMultiLabelPredictor('xgb_model.joblib')
result = predictor.predict(features, threshold=0.3)
print(f"Predictions: {result['predictions']}")
print(f"Confidences: {result['confidences']}")
# Top-K predictions
top_k = predictor.predict_top_k(features, k=5)
for cls, prob in top_k['top_probabilities'].items():
print(f"{cls}: {prob:.4f}")
```
See `xgb_local_example.py` for complete examples.
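Raw roofline records can also be sent without pre-aggregation by setting `is_json`; the server then aggregates them into the model features. The sketch below assumes the `requests` package (not listed in `requirements.txt`) and the server started above:

```python
import json
import requests

# Raw per-sample roofline records (same shape as SAMPLE_JSON_ROOFLINE in the tests)
raw_samples = [
    {"node_num": 1, "bandwidth_raw": 150.5, "flops_raw": 2500.0, "arith_intensity": 16.6,
     "performance_gflops": 1200.0, "memory_bw_gbs": 450, "scalar_peak_gflops": 600, "duration": 3600},
    {"node_num": 1, "bandwidth_raw": 155.2, "flops_raw": 2600.0, "arith_intensity": 16.8,
     "performance_gflops": 1250.0, "memory_bw_gbs": 450, "scalar_peak_gflops": 600, "duration": 3600},
]

resp = requests.post(
    "http://localhost:8000/predict",
    json={
        "features": json.dumps(raw_samples),  # JSON string, since is_json=true
        "is_json": True,
        "job_id": "example_job_001",          # illustrative job ID
        "threshold": 0.3,
    },
)
resp.raise_for_status()
print(resp.json()["predictions"])
```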
## Model Features (28 total)
| Category | Features |
|----------|----------|
| Bandwidth | `bandwidth_raw_p10`, `_median`, `_p90`, `_mad`, `_range`, `_iqr` |
| FLOPS | `flops_raw_p10`, `_median`, `_p90`, `_mad`, `_range`, `_iqr` |
| Arithmetic Intensity | `arith_intensity_p10`, `_median`, `_p90`, `_mad`, `_range`, `_iqr` |
| Performance | `avg_performance_gflops`, `median_performance_gflops`, `performance_gflops_mad` |
| Correlation | `bw_flops_covariance`, `bw_flops_correlation` |
| System | `avg_memory_bw_gbs`, `scalar_peak_gflops`, `simd_peak_gflops` |
| Other | `node_num`, `duration` |

600
cluster.json Executable file

@@ -0,0 +1,600 @@
{
"name": "fritz",
"metricConfig": [
{
"name": "cpu_load",
"unit": {
"base": ""
},
"scope": "node",
"aggregation": "avg",
"footprint": "avg",
"timestep": 60,
"peak": 72,
"normal": 72,
"caution": 36,
"alert": 20,
"subClusters": [
{
"name": "spr1tb",
"footprint": "avg",
"peak": 104,
"normal": 104,
"caution": 52,
"alert": 20
},
{
"name": "spr2tb",
"footprint": "avg",
"peak": 104,
"normal": 104,
"caution": 52,
"alert": 20
}
]
},
{
"name": "cpu_user",
"unit": {
"base": ""
},
"scope": "hwthread",
"aggregation": "avg",
"timestep": 60,
"peak": 100,
"normal": 50,
"caution": 20,
"alert": 10
},
{
"name": "mem_used",
"unit": {
"base": "B",
"prefix": "G"
},
"scope": "node",
"aggregation": "sum",
"footprint": "max",
"timestep": 60,
"peak": 256,
"normal": 128,
"caution": 200,
"alert": 240,
"subClusters": [
{
"name": "spr1tb",
"footprint": "max",
"peak": 1024,
"normal": 512,
"caution": 900,
"alert": 1000
},
{
"name": "spr2tb",
"footprint": "max",
"peak": 2048,
"normal": 1024,
"caution": 1800,
"alert": 2000
}
]
},
{
"name": "flops_any",
"unit": {
"base": "Flops/s",
"prefix": "G"
},
"scope": "hwthread",
"aggregation": "sum",
"footprint": "avg",
"timestep": 60,
"peak": 5600,
"normal": 1000,
"caution": 200,
"alert": 50,
"subClusters": [
{
"name": "spr1tb",
"peak": 6656,
"normal": 1500,
"caution": 400,
"alert": 50,
"footprint": "avg"
},
{
"name": "spr2tb",
"peak": 6656,
"normal": 1500,
"caution": 400,
"alert": 50,
"footprint": "avg"
}
]
},
{
"name": "flops_sp",
"unit": {
"base": "Flops/s",
"prefix": "G"
},
"scope": "hwthread",
"aggregation": "sum",
"timestep": 60,
"peak": 5600,
"normal": 1000,
"caution": 200,
"alert": 50,
"subClusters": [
{
"name": "spr1tb",
"peak": 6656,
"normal": 1500,
"caution": 400,
"alert": 50
},
{
"name": "spr2tb",
"peak": 6656,
"normal": 1500,
"caution": 400,
"alert": 50
}
]
},
{
"name": "flops_dp",
"unit": {
"base": "Flops/s",
"prefix": "G"
},
"scope": "hwthread",
"aggregation": "sum",
"timestep": 60,
"peak": 2300,
"normal": 500,
"caution": 100,
"alert": 50,
"subClusters": [
{
"name": "spr1tb",
"peak": 3300,
"normal": 750,
"caution": 200,
"alert": 50
},
{
"name": "spr2tb",
"peak": 3300,
"normal": 750,
"caution": 200,
"alert": 50
}
]
},
{
"name": "mem_bw",
"unit": {
"base": "B/s",
"prefix": "G"
},
"scope": "socket",
"aggregation": "sum",
"footprint": "avg",
"timestep": 60,
"peak": 350,
"normal": 100,
"caution": 50,
"alert": 10,
"subClusters": [
{
"name": "spr1tb",
"footprint": "avg",
"peak": 549,
"normal": 200,
"caution": 100,
"alert": 20
},
{
"name": "spr2tb",
"footprint": "avg",
"peak": 520,
"normal": 200,
"caution": 100,
"alert": 20
}
]
},
{
"name": "clock",
"unit": {
"base": "Hz",
"prefix": "M"
},
"scope": "hwthread",
"aggregation": "avg",
"timestep": 60,
"peak": 3000,
"normal": 2400,
"caution": 1800,
"alert": 1200,
"subClusters": [
{
"name": "spr1tb",
"peak": 3000,
"normal": 2000,
"caution": 1600,
"alert": 1200,
"remove": false
},
{
"name": "spr2tb",
"peak": 3000,
"normal": 2000,
"caution": 1600,
"alert": 1200,
"remove": false
}
]
},
{
"name": "cpu_power",
"unit": {
"base": "W"
},
"scope": "socket",
"aggregation": "sum",
"energy": "power",
"timestep": 60,
"peak": 500,
"normal": 250,
"caution": 100,
"alert": 50,
"subClusters": [
{
"name": "spr1tb",
"peak": 700,
"energy": "power",
"normal": 350,
"caution": 150,
"alert": 50
},
{
"name": "spr2tb",
"peak": 700,
"energy": "power",
"normal": 350,
"caution": 150,
"alert": 50
}
]
},
{
"name": "mem_power",
"unit": {
"base": "W"
},
"scope": "socket",
"aggregation": "sum",
"energy": "power",
"timestep": 60,
"peak": 100,
"normal": 50,
"caution": 20,
"alert": 10,
"subClusters": [
{
"name": "spr1tb",
"peak": 400,
"energy": "power",
"normal": 200,
"caution": 80,
"alert": 40
},
{
"name": "spr2tb",
"peak": 800,
"energy": "power",
"normal": 400,
"caution": 160,
"alert": 80
}
]
},
{
"name": "ipc",
"unit": {
"base": "IPC"
},
"scope": "hwthread",
"aggregation": "avg",
"timestep": 60,
"peak": 4,
"normal": 2,
"caution": 1,
"alert": 0.5,
"subClusters": [
{
"name": "spr1tb",
"peak": 6,
"normal": 2,
"caution": 1,
"alert": 0.5
},
{
"name": "spr2tb",
"peak": 6,
"normal": 2,
"caution": 1,
"alert": 0.5
}
]
},
{
"name": "vectorization_ratio",
"unit": {
"base": ""
},
"scope": "hwthread",
"aggregation": "avg",
"timestep": 60,
"peak": 100,
"normal": 60,
"caution": 40,
"alert": 10,
"subClusters": [
{
"name": "spr1tb",
"peak": 100,
"normal": 60,
"caution": 40,
"alert": 10
},
{
"name": "spr2tb",
"peak": 100,
"normal": 60,
"caution": 40,
"alert": 10
}
]
},
{
"name": "ib_recv",
"unit": {
"base": "B/s"
},
"scope": "node",
"aggregation": "sum",
"timestep": 60,
"peak": 1250000,
"normal": 6000000,
"caution": 200,
"alert": 1
},
{
"name": "ib_xmit",
"unit": {
"base": "B/s"
},
"scope": "node",
"aggregation": "sum",
"timestep": 60,
"peak": 1250000,
"normal": 6000000,
"caution": 200,
"alert": 1
},
{
"name": "ib_recv_pkts",
"unit": {
"base": "packets/s"
},
"scope": "node",
"aggregation": "sum",
"timestep": 60,
"peak": 6,
"normal": 4,
"caution": 2,
"alert": 1
},
{
"name": "ib_xmit_pkts",
"unit": {
"base": "packets/s"
},
"scope": "node",
"aggregation": "sum",
"timestep": 60,
"peak": 6,
"normal": 4,
"caution": 2,
"alert": 1
},
{
"name": "nfs4_read",
"unit": {
"base": "IOP",
"prefix": ""
},
"scope": "node",
"aggregation": "sum",
"timestep": 60,
"peak": 1000,
"normal": 50,
"caution": 200,
"alert": 500
},
{
"name": "nfs4_total",
"unit": {
"base": "IOP",
"prefix": ""
},
"scope": "node",
"aggregation": "sum",
"timestep": 60,
"peak": 1000,
"normal": 50,
"caution": 200,
"alert": 500
}
],
"subClusters": [
{
"name": "main",
"nodes": "f[0101-0188,0201-0288,0301-0388,0401-0488,0501-0588,0601-0688,0701-0788,0801-0888,0901-0988,1001-1088,1101-1156,1201-1256]",
"processorType": "Intel Icelake",
"socketsPerNode": 2,
"coresPerSocket": 36,
"threadsPerCore": 1,
"flopRateScalar": {
"unit": {
"base": "F/s",
"prefix": "G"
},
"value": 432
},
"flopRateSimd": {
"unit": {
"base": "F/s",
"prefix": "G"
},
"value": 9216
},
"memoryBandwidth": {
"unit": {
"base": "B/s",
"prefix": "G"
},
"value": 350
},
"topology": {
"node": [
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71
],
"socket": [
[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35 ],
[ 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71 ]
],
"memoryDomain": [
[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17 ],
[ 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35 ],
[ 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53 ],
[ 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71 ]
],
"core": [
[ 0 ], [ 1 ], [ 2 ], [ 3 ], [ 4 ], [ 5 ], [ 6 ], [ 7 ], [ 8 ], [ 9 ], [ 10 ], [ 11 ], [ 12 ], [ 13 ], [ 14 ], [ 15 ], [ 16 ], [ 17 ], [ 18 ], [ 19 ], [ 20 ], [ 21 ], [ 22 ], [ 23 ], [ 24 ], [ 25 ], [ 26 ], [ 27 ], [ 28 ], [ 29 ], [ 30 ], [ 31 ], [ 32 ], [ 33 ], [ 34 ], [ 35 ], [ 36 ], [ 37 ], [ 38 ], [ 39 ], [ 40 ], [ 41 ], [ 42 ], [ 43 ], [ 44 ], [ 45 ], [ 46 ], [ 47 ], [ 48 ], [ 49 ], [ 50 ], [ 51 ], [ 52 ], [ 53 ], [ 54 ], [ 55 ], [ 56 ], [ 57 ], [ 58 ], [ 59 ], [ 60 ], [ 61 ], [ 62 ], [ 63 ], [ 64 ], [ 65 ], [ 66 ], [ 67 ], [ 68 ], [ 69 ], [ 70 ], [ 71 ]
]
}
},
{
"name": "spr1tb",
"processorType": "Intel(R) Xeon(R) Platinum 8470",
"socketsPerNode": 2,
"coresPerSocket": 52,
"threadsPerCore": 1,
"flopRateScalar": {
"unit": {
"base": "F/s",
"prefix": "G"
},
"value": 695
},
"flopRateSimd": {
"unit": {
"base": "F/s",
"prefix": "G"
},
"value": 9216
},
"memoryBandwidth": {
"unit": {
"base": "B/s",
"prefix": "G"
},
"value": 549
},
"nodes": "f[2157-2180,2257-2280]",
"topology": {
"node":
[
0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103
],
"socket":
[
[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51],
[52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103]
],
"memoryDomain": [
[0,1,2,3,4,5,6,7,8,9,10,11,12],
[13,14,15,16,17,18,19,20,21,22,23,24,25],
[26,27,28,29,30,31,32,33,34,35,36,37,38],
[39,40,41,42,43,44,45,46,47,48,49,50,51],
[52,53,54,55,56,57,58,59,60,61,62,63,64],
[65,66,67,68,69,70,71,72,73,74,75,76,77],
[78,79,80,81,82,83,84,85,86,87,88,89,90],
[91,92,93,94,95,96,97,98,99,100,101,102,103]
],
"core": [
[0],[1],[2],[3],[4],[5],[6],[7],[8],[9],[10],[11],[12],[13],[14],[15],[16],[17],[18],[19],[20],[21],[22],[23],[24],[25],[26],[27],[28],[29],[30],[31],[32],[33],[34],[35],[36],[37],[38],[39],[40],[41],[42],[43],[44],[45],[46],[47],[48],[49],[50],[51],[52],[53],[54],[55],[56],[57],[58],[59],[60],[61],[62],[63],[64],[65],[66],[67],[68],[69],[70],[71],[72],[73],[74],[75],[76],[77],[78],[79],[80],[81],[82],[83],[84],[85],[86],[87],[88],[89],[90],[91],[92],[93],[94],[95],[96],[97],[98],[99],[100],[101],[102],[103]
]
}
},
{
"name": "spr2tb",
"processorType": "Intel(R) Xeon(R) Platinum 8470",
"socketsPerNode": 2,
"coresPerSocket": 52,
"threadsPerCore": 1,
"flopRateScalar": {
"unit": {
"base": "F/s",
"prefix": "G"
},
"value": 695
},
"flopRateSimd": {
"unit": {
"base": "F/s",
"prefix": "G"
},
"value": 9216
},
"memoryBandwidth": {
"unit": {
"base": "B/s",
"prefix": "G"
},
"value": 515
},
"nodes": "f[2181-2188,2281-2288]",
"topology": {
"node": [
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103
],
"socket": [
[
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51
],
[
52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103
]
],
"memoryDomain": [
[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 ],
[ 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25 ],
[ 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38 ],
[ 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51 ],
[ 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64 ],
[ 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77 ],
[ 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90 ],
[ 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103 ]
],
"core": [
[ 0 ], [ 1 ], [ 2 ], [ 3 ], [ 4 ], [ 5 ], [ 6 ], [ 7 ], [ 8 ], [ 9 ], [ 10 ], [ 11 ], [ 12 ], [ 13 ], [ 14 ], [ 15 ], [ 16 ], [ 17 ], [ 18 ], [ 19 ], [ 20 ], [ 21 ], [ 22 ], [ 23 ], [ 24 ], [ 25 ], [ 26 ], [ 27 ], [ 28 ], [ 29 ], [ 30 ], [ 31 ], [ 32 ], [ 33 ], [ 34 ], [ 35 ], [ 36 ], [ 37 ], [ 38 ], [ 39 ], [ 40 ], [ 41 ], [ 42 ], [ 43 ], [ 44 ], [ 45 ], [ 46 ], [ 47 ], [ 48 ], [ 49 ], [ 50 ], [ 51 ], [ 52 ], [ 53 ], [ 54 ], [ 55 ], [ 56 ], [ 57 ], [ 58 ], [ 59 ], [ 60 ], [ 61 ], [ 62 ], [ 63 ], [ 64 ], [ 65 ], [ 66 ], [ 67 ], [ 68 ], [ 69 ], [ 70 ], [ 71 ], [ 72 ], [ 73 ], [ 74 ], [ 75 ], [ 76 ], [ 77 ], [ 78 ], [ 79 ], [ 80 ], [ 81 ], [ 82 ], [ 83 ], [ 84 ], [ 85 ], [ 86 ], [ 87 ], [ 88 ], [ 89 ], [ 90 ], [ 91 ], [ 92 ], [ 93 ], [ 94 ], [ 95 ], [ 96 ], [ 97 ], [ 98 ], [ 99 ], [ 100 ], [ 101 ], [ 102 ], [ 103 ]
]
}
}
]
}

172
feature_aggregator.py Normal file

@@ -0,0 +1,172 @@
import os
import pandas as pd
import numpy as np
import sqlite3
import h5py
from glob import glob
from scipy import stats
from pathlib import Path
def compute_mad(data):
"""Compute Median Absolute Deviation: a robust measure of dispersion"""
return np.median(np.abs(data - np.median(data)))
def process_roofline_data(base_dir='D:/roofline_dataframes',
output_file='D:/roofline_features.h5',
job_tags_db='job_tags.db'):
"""
Process roofline data to extract features for machine learning.
Args:
base_dir: Directory containing roofline dataframes
output_file: Path to save the output features
job_tags_db: Path to SQLite database with job tags
"""
# Connect to the SQLite database
conn = sqlite3.connect(job_tags_db)
cursor = conn.cursor()
# List to store all job features
all_job_features = []
# Find all job prefix folders in the base directory
job_prefixes = [d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d))]
for job_id_prefix in job_prefixes:
job_prefix_path = os.path.join(base_dir, job_id_prefix)
# Find all h5 files in this job prefix folder
h5_files = glob(os.path.join(job_prefix_path, '*_dataframe.h5'))
for h5_file in h5_files:
filename = os.path.basename(h5_file)
# Extract job_id_full from the filename pattern: {job_id_full}_{timestamp}_dataframe.h5
job_id_full = filename.split('_')[0]
try:
# Read the dataframe from the h5 file
with h5py.File(h5_file, 'r') as f:
if 'dataframe' not in f:
print(f"Warning: No 'dataframe' key in {h5_file}")
continue
df = pd.read_hdf(h5_file, key='dataframe')
# Group data by node_num
grouped = df.groupby('node_num')
for node_num, group in grouped:
features = {
'job_id': job_id_full,
'node_num': node_num
}
# Compute statistics for key metrics
for axis in ['bandwidth_raw', 'flops_raw', 'arith_intensity']:
data = group[axis].values
# Compute percentiles
p10 = np.percentile(data, 10)
p50 = np.median(data)
p90 = np.percentile(data, 90)
# Compute MAD (more robust than variance)
mad = compute_mad(data)
# Store features
features[f'{axis}_p10'] = p10
features[f'{axis}_median'] = p50
features[f'{axis}_p90'] = p90
features[f'{axis}_mad'] = mad
features[f'{axis}_range'] = p90 - p10
features[f'{axis}_iqr'] = np.percentile(data, 75) - np.percentile(data, 25)
# Compute covariance and correlation between bandwidth_raw and flops_raw
if len(group) > 1: # Need at least 2 points for correlation
cov = np.cov(group['bandwidth_raw'], group['flops_raw'])[0, 1]
features['bw_flops_covariance'] = cov
corr, _ = stats.pearsonr(group['bandwidth_raw'], group['flops_raw'])
features['bw_flops_correlation'] = corr
# Additional useful features for the classifier
# Performance metrics
features['avg_performance_gflops'] = group['performance_gflops'].mean()
features['median_performance_gflops'] = group['performance_gflops'].median()
features['performance_gflops_mad'] = compute_mad(group['performance_gflops'].values)
# # Efficiency metrics
# features['avg_efficiency'] = group['efficiency'].mean()
# features['median_efficiency'] = group['efficiency'].median()
# features['efficiency_mad'] = compute_mad(group['efficiency'].values)
# features['efficiency_p10'] = np.percentile(group['efficiency'].values, 10)
# features['efficiency_p90'] = np.percentile(group['efficiency'].values, 90)
# # Distribution of roofline regions (memory-bound vs compute-bound)
# if 'roofline_region' in group.columns:
# region_counts = group['roofline_region'].value_counts(normalize=True).to_dict()
# for region, ratio in region_counts.items():
# features[f'region_{region}_ratio'] = ratio
# System characteristics
if 'memory_bw_gbs' in group.columns:
features['avg_memory_bw_gbs'] = group['memory_bw_gbs'].mean()
if 'scalar_peak_gflops' in group.columns and len(group['scalar_peak_gflops'].unique()) > 0:
features['scalar_peak_gflops'] = group['scalar_peak_gflops'].iloc[0]
if 'simd_peak_gflops' in group.columns and len(group['simd_peak_gflops'].unique()) > 0:
features['simd_peak_gflops'] = group['simd_peak_gflops'].iloc[0]
# # Subcluster information if available
# if 'subcluster_name' in group.columns and not group['subcluster_name'].isna().all():
# features['subcluster_name'] = group['subcluster_name'].iloc[0]
# Duration information
if 'duration' in group.columns:
features['duration'] = group['duration'].iloc[0]
# Get the label (application type) from the database
cursor.execute("SELECT tags FROM job_tags WHERE job_id = ?", (int(job_id_full),))
result = cursor.fetchone()
if result:
# Extract application name from tags
tags = result[0]
features['label'] = tags
else:
features['label'] = 'Unknown'
all_job_features.append(features)
except Exception as e:
print(f"Error processing file {h5_file}: {e}")
# Close database connection
conn.close()
if not all_job_features:
print("No features extracted. Check if files exist and have the correct format.")
return
# Convert to DataFrame
features_df = pd.DataFrame(all_job_features)
# Fill missing roofline region ratios with 0
region_columns = [col for col in features_df.columns if col.startswith('region_')]
for col in region_columns:
if col not in features_df.columns:
features_df[col] = 0
else:
features_df[col] = features_df[col].fillna(0)
# Save to H5 file
features_df.to_hdf(output_file, key='features', mode='w')
print(f"Processed {len(all_job_features)} job-node combinations")
print(f"Features saved to {output_file}")
print(f"Feature columns: {', '.join(features_df.columns)}")
return features_df
if __name__ == "__main__":
process_roofline_data()
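# Illustrative note (not part of the original pipeline): the saved feature table
# can be read back with pandas for inspection or model training, e.g.
#   features_df = pd.read_hdf('D:/roofline_features.h5', key='features')  # key matches to_hdf() above
#   print(features_df.shape, features_df['label'].value_counts().head())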

20
requirements.txt Normal file

@@ -0,0 +1,20 @@
# Core ML dependencies
numpy>=1.21.0
pandas>=1.3.0
scipy>=1.7.0
joblib>=1.0.0
xgboost>=1.5.0
scikit-learn>=1.0.0
# FastAPI server
fastapi>=0.68.0
uvicorn>=0.15.0
pydantic>=1.8.0
# Flask server (alternative)
flask>=2.0.0
flask-cors>=3.0.0
# Testing
pytest>=7.0.0
httpx>=0.23.0

436
test_api_curl.sh Executable file

@@ -0,0 +1,436 @@
#!/bin/bash
#
# Test script for XGBoost FastAPI Multi-Label Classification API
# Uses curl to test all endpoints with realistic sample data from roofline_features.h5
#
# Usage:
# 1. Start the server: ./venv/bin/python xgb_fastapi.py
# 2. Run this script: ./test_api_curl.sh
#
# Optional: Set API_URL environment variable to test a different host
# API_URL=http://192.168.1.100:8000 ./test_api_curl.sh
set -e
# Configuration
API_URL="${API_URL:-http://localhost:8000}"
VERBOSE="${VERBOSE:-false}"
# Colors for output
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
NC='\033[0m' # No Color
# Counters
PASSED=0
FAILED=0
# ============================================================================
# Helper Functions
# ============================================================================
print_header() {
echo ""
echo -e "${BLUE}============================================================${NC}"
echo -e "${BLUE}$1${NC}"
echo -e "${BLUE}============================================================${NC}"
}
print_test() {
echo -e "\n${YELLOW}▶ TEST: $1${NC}"
}
print_success() {
echo -e "${GREEN}✓ PASSED: $1${NC}"
PASSED=$((PASSED + 1))  # avoid ((PASSED++)), which returns non-zero and trips set -e when the counter is 0
}
print_failure() {
echo -e "${RED}✗ FAILED: $1${NC}"
FAILED=$((FAILED + 1))  # same set -e caveat as above
}
# Make a curl request and check status code
# Usage: make_request METHOD ENDPOINT [DATA]
make_request() {
local method=$1
local endpoint=$2
local data=$3
local expected_status=${4:-200}
if [ "$method" == "GET" ]; then
response=$(curl -s -w "\n%{http_code}" "${API_URL}${endpoint}")
else
response=$(curl -s -w "\n%{http_code}" -X "$method" \
-H "Content-Type: application/json" \
-d "$data" \
"${API_URL}${endpoint}")
fi
# Extract status code (last line) and body (everything else)
http_code=$(echo "$response" | tail -n1)
body=$(echo "$response" | sed '$d')
if [ "$VERBOSE" == "true" ]; then
echo "Response: $body"
echo "Status: $http_code"
fi
if [ "$http_code" == "$expected_status" ]; then
return 0
else
echo "Expected status $expected_status, got $http_code"
echo "Response: $body"
return 1
fi
}
# ============================================================================
# Sample Data (from roofline_features.h5)
# ============================================================================
# TurTLE application - turbulence simulation workload
SAMPLE_TURTLE='{
"bandwidth_raw_p10": 186.33,
"bandwidth_raw_median": 205.14,
"bandwidth_raw_p90": 210.83,
"bandwidth_raw_mad": 3.57,
"bandwidth_raw_range": 24.5,
"bandwidth_raw_iqr": 12.075,
"flops_raw_p10": 162.024,
"flops_raw_median": 171.45,
"flops_raw_p90": 176.48,
"flops_raw_mad": 3.08,
"flops_raw_range": 14.456,
"flops_raw_iqr": 8.29,
"arith_intensity_p10": 0.7906,
"arith_intensity_median": 0.837,
"arith_intensity_p90": 0.9109,
"arith_intensity_mad": 0.02,
"arith_intensity_range": 0.12,
"arith_intensity_iqr": 0.0425,
"bw_flops_covariance": 60.86,
"bw_flops_correlation": 0.16,
"avg_performance_gflops": 168.1,
"median_performance_gflops": 171.45,
"performance_gflops_mad": 3.08,
"avg_memory_bw_gbs": 350.0,
"scalar_peak_gflops": 432.0,
"simd_peak_gflops": 9216.0,
"node_num": 0,
"duration": 19366
}'
# Chroma application - lattice QCD workload (compute-intensive)
SAMPLE_CHROMA='{
"bandwidth_raw_p10": 154.176,
"bandwidth_raw_median": 200.57,
"bandwidth_raw_p90": 259.952,
"bandwidth_raw_mad": 5.12,
"bandwidth_raw_range": 105.776,
"bandwidth_raw_iqr": 10.215,
"flops_raw_p10": 327.966,
"flops_raw_median": 519.8,
"flops_raw_p90": 654.422,
"flops_raw_mad": 16.97,
"flops_raw_range": 326.456,
"flops_raw_iqr": 34.88,
"arith_intensity_p10": 1.55,
"arith_intensity_median": 2.595,
"arith_intensity_p90": 3.445,
"arith_intensity_mad": 0.254,
"arith_intensity_range": 1.894,
"arith_intensity_iqr": 0.512,
"bw_flops_covariance": 382.76,
"bw_flops_correlation": 0.063,
"avg_performance_gflops": 503.26,
"median_performance_gflops": 519.8,
"performance_gflops_mad": 16.97,
"avg_memory_bw_gbs": 350.0,
"scalar_peak_gflops": 432.0,
"simd_peak_gflops": 9216.0,
"node_num": 3,
"duration": 31133
}'
# Raw JSON roofline data (before aggregation)
SAMPLE_JSON_ROOFLINE='[
{"node_num": 1, "bandwidth_raw": 150.5, "flops_raw": 2500.0, "arith_intensity": 16.6, "performance_gflops": 1200.0, "memory_bw_gbs": 450, "scalar_peak_gflops": 600, "duration": 3600},
{"node_num": 1, "bandwidth_raw": 155.2, "flops_raw": 2600.0, "arith_intensity": 16.8, "performance_gflops": 1250.0, "memory_bw_gbs": 450, "scalar_peak_gflops": 600, "duration": 3600},
{"node_num": 1, "bandwidth_raw": 148.0, "flops_raw": 2450.0, "arith_intensity": 16.5, "performance_gflops": 1180.0, "memory_bw_gbs": 450, "scalar_peak_gflops": 600, "duration": 3600}
]'
# ============================================================================
# Tests
# ============================================================================
print_header "XGBoost FastAPI Test Suite"
echo "Testing API at: $API_URL"
echo "Started at: $(date)"
# ----------------------------------------------------------------------------
# Health & Info Endpoints
# ----------------------------------------------------------------------------
print_header "Health & Info Endpoints"
print_test "Root Endpoint"
if make_request GET "/"; then
print_success "Root endpoint accessible"
else
print_failure "Root endpoint failed"
fi
print_test "Health Check"
if make_request GET "/health"; then
print_success "Health check passed"
else
print_failure "Health check failed"
fi
print_test "Model Info"
if make_request GET "/model/info"; then
print_success "Model info retrieved"
else
print_failure "Model info failed"
fi
# ----------------------------------------------------------------------------
# Single Prediction Endpoints
# ----------------------------------------------------------------------------
print_header "Single Prediction Endpoint (/predict)"
print_test "Predict with TurTLE sample (threshold=0.3)"
REQUEST_DATA=$(cat <<EOF
{
"features": $SAMPLE_TURTLE,
"threshold": 0.3,
"return_all_probabilities": true
}
EOF
)
if make_request POST "/predict" "$REQUEST_DATA"; then
print_success "TurTLE prediction completed"
else
print_failure "TurTLE prediction failed"
fi
print_test "Predict with Chroma sample (threshold=0.5)"
REQUEST_DATA=$(cat <<EOF
{
"features": $SAMPLE_CHROMA,
"threshold": 0.5
}
EOF
)
if make_request POST "/predict" "$REQUEST_DATA"; then
print_success "Chroma prediction completed"
else
print_failure "Chroma prediction failed"
fi
print_test "Predict with JSON roofline data (is_json=true)"
# Need to escape the JSON for embedding
ESCAPED_JSON=$(echo "$SAMPLE_JSON_ROOFLINE" | tr -d '\n' | sed 's/"/\\"/g')
REQUEST_DATA=$(cat <<EOF
{
"features": "$ESCAPED_JSON",
"is_json": true,
"job_id": "test_job_curl_001",
"threshold": 0.3
}
EOF
)
if make_request POST "/predict" "$REQUEST_DATA"; then
print_success "JSON roofline prediction completed"
else
print_failure "JSON roofline prediction failed"
fi
print_test "Predict with return_all_probabilities=false"
REQUEST_DATA=$(cat <<EOF
{
"features": $SAMPLE_TURTLE,
"threshold": 0.3,
"return_all_probabilities": false
}
EOF
)
if make_request POST "/predict" "$REQUEST_DATA"; then
print_success "Prediction with filtered probabilities completed"
else
print_failure "Prediction with filtered probabilities failed"
fi
# ----------------------------------------------------------------------------
# Top-K Prediction Endpoints
# ----------------------------------------------------------------------------
print_header "Top-K Prediction Endpoint (/predict_top_k)"
print_test "Top-5 predictions (default)"
REQUEST_DATA=$(cat <<EOF
{
"features": $SAMPLE_TURTLE
}
EOF
)
if make_request POST "/predict_top_k" "$REQUEST_DATA"; then
print_success "Top-5 prediction completed"
else
print_failure "Top-5 prediction failed"
fi
print_test "Top-10 predictions"
REQUEST_DATA=$(cat <<EOF
{
"features": $SAMPLE_CHROMA,
"k": 10
}
EOF
)
if make_request POST "/predict_top_k" "$REQUEST_DATA"; then
print_success "Top-10 prediction completed"
else
print_failure "Top-10 prediction failed"
fi
print_test "Top-3 predictions"
REQUEST_DATA=$(cat <<EOF
{
"features": $SAMPLE_TURTLE,
"k": 3
}
EOF
)
if make_request POST "/predict_top_k" "$REQUEST_DATA"; then
print_success "Top-3 prediction completed"
else
print_failure "Top-3 prediction failed"
fi
# ----------------------------------------------------------------------------
# Batch Prediction Endpoints
# ----------------------------------------------------------------------------
print_header "Batch Prediction Endpoint (/batch_predict)"
print_test "Batch predict with 2 samples"
REQUEST_DATA=$(cat <<EOF
{
"features_list": [$SAMPLE_TURTLE, $SAMPLE_CHROMA],
"threshold": 0.3
}
EOF
)
if make_request POST "/batch_predict" "$REQUEST_DATA"; then
print_success "Batch prediction (2 samples) completed"
else
print_failure "Batch prediction (2 samples) failed"
fi
print_test "Batch predict with single sample"
REQUEST_DATA=$(cat <<EOF
{
"features_list": [$SAMPLE_TURTLE],
"threshold": 0.5
}
EOF
)
if make_request POST "/batch_predict" "$REQUEST_DATA"; then
print_success "Batch prediction (single sample) completed"
else
print_failure "Batch prediction (single sample) failed"
fi
print_test "Batch predict with empty list"
REQUEST_DATA='{"features_list": [], "threshold": 0.5}'
if make_request POST "/batch_predict" "$REQUEST_DATA"; then
print_success "Batch prediction (empty list) completed"
else
print_failure "Batch prediction (empty list) failed"
fi
# ----------------------------------------------------------------------------
# Error Handling Tests
# ----------------------------------------------------------------------------
print_header "Error Handling Tests"
print_test "Invalid threshold (negative) - should return 422"
REQUEST_DATA=$(cat <<EOF
{
"features": $SAMPLE_TURTLE,
"threshold": -0.5
}
EOF
)
if make_request POST "/predict" "$REQUEST_DATA" 422; then
print_success "Invalid threshold correctly rejected"
else
print_failure "Invalid threshold not rejected properly"
fi
print_test "Invalid threshold (> 1.0) - should return 422"
REQUEST_DATA=$(cat <<EOF
{
"features": $SAMPLE_TURTLE,
"threshold": 1.5
}
EOF
)
if make_request POST "/predict" "$REQUEST_DATA" 422; then
print_success "Invalid threshold (>1) correctly rejected"
else
print_failure "Invalid threshold (>1) not rejected properly"
fi
print_test "Missing features field - should return 422"
REQUEST_DATA='{"threshold": 0.5}'
if make_request POST "/predict" "$REQUEST_DATA" 422; then
print_success "Missing features correctly rejected"
else
print_failure "Missing features not rejected properly"
fi
print_test "Invalid endpoint - should return 404"
if make_request GET "/nonexistent" "" 404; then
print_success "Invalid endpoint correctly returns 404"
else
print_failure "Invalid endpoint not handled properly"
fi
print_test "Invalid k value (0) - should return 422"
REQUEST_DATA=$(cat <<EOF
{
"features": $SAMPLE_TURTLE,
"k": 0
}
EOF
)
if make_request POST "/predict_top_k" "$REQUEST_DATA" 422; then
print_success "Invalid k value correctly rejected"
else
print_failure "Invalid k value not rejected properly"
fi
# ============================================================================
# Summary
# ============================================================================
print_header "Test Summary"
echo ""
echo -e "Passed: ${GREEN}$PASSED${NC}"
echo -e "Failed: ${RED}$FAILED${NC}"
echo ""
if [ $FAILED -eq 0 ]; then
echo -e "${GREEN}All tests passed! ✓${NC}"
exit 0
else
echo -e "${RED}Some tests failed! ✗${NC}"
exit 1
fi

811
test_xgb_fastapi.py Normal file

@@ -0,0 +1,811 @@
#!/usr/bin/env python3
"""
Tests for XGBoost FastAPI Multi-Label Classification API
These tests use realistic sample data extracted from /Volumes/T7/roofline_features.h5
which was generated using feature_aggregator.py to process roofline dataframes.
Test data includes samples from different application types:
- TurTLE (turbulence simulation)
- SCALEXA (scaling benchmarks)
- Chroma (lattice QCD)
"""
import pytest
import json
from fastapi.testclient import TestClient
from unittest.mock import patch, MagicMock
import numpy as np
import os
# Import the FastAPI app and predictor
import xgb_fastapi
from xgb_fastapi import app
# ============================================================================
# Test Data: Realistic samples from roofline_features.h5
# Generated using feature_aggregator.py processing roofline dataframes
# ============================================================================
# Sample 1: TurTLE application - typical turbulence simulation workload
SAMPLE_TURTLE = {
"bandwidth_raw_p10": 186.33,
"bandwidth_raw_median": 205.14,
"bandwidth_raw_p90": 210.83,
"bandwidth_raw_mad": 3.57,
"bandwidth_raw_range": 24.5,
"bandwidth_raw_iqr": 12.075,
"flops_raw_p10": 162.024,
"flops_raw_median": 171.45,
"flops_raw_p90": 176.48,
"flops_raw_mad": 3.08,
"flops_raw_range": 14.456,
"flops_raw_iqr": 8.29,
"arith_intensity_p10": 0.7906,
"arith_intensity_median": 0.837,
"arith_intensity_p90": 0.9109,
"arith_intensity_mad": 0.02,
"arith_intensity_range": 0.12,
"arith_intensity_iqr": 0.0425,
"bw_flops_covariance": 60.86,
"bw_flops_correlation": 0.16,
"avg_performance_gflops": 168.1,
"median_performance_gflops": 171.45,
"performance_gflops_mad": 3.08,
"avg_memory_bw_gbs": 350.0,
"scalar_peak_gflops": 432.0,
"simd_peak_gflops": 9216.0,
"node_num": 0,
"duration": 19366,
}
# Sample 2: SCALEXA application - scaling benchmark workload
SAMPLE_SCALEXA = {
"bandwidth_raw_p10": 13.474,
"bandwidth_raw_median": 32.57,
"bandwidth_raw_p90": 51.466,
"bandwidth_raw_mad": 23.62,
"bandwidth_raw_range": 37.992,
"bandwidth_raw_iqr": 23.745,
"flops_raw_p10": 4.24,
"flops_raw_median": 16.16,
"flops_raw_p90": 24.584,
"flops_raw_mad": 10.53,
"flops_raw_range": 20.344,
"flops_raw_iqr": 12.715,
"arith_intensity_p10": 0.211,
"arith_intensity_median": 0.475,
"arith_intensity_p90": 0.492,
"arith_intensity_mad": 0.021,
"arith_intensity_range": 0.281,
"arith_intensity_iqr": 0.176,
"bw_flops_covariance": 302.0,
"bw_flops_correlation": 0.995,
"avg_performance_gflops": 14.7,
"median_performance_gflops": 16.16,
"performance_gflops_mad": 10.53,
"avg_memory_bw_gbs": 350.0,
"scalar_peak_gflops": 432.0,
"simd_peak_gflops": 9216.0,
"node_num": 18,
"duration": 165,
}
# Sample 3: Chroma application - lattice QCD workload (compute-intensive)
SAMPLE_CHROMA = {
"bandwidth_raw_p10": 154.176,
"bandwidth_raw_median": 200.57,
"bandwidth_raw_p90": 259.952,
"bandwidth_raw_mad": 5.12,
"bandwidth_raw_range": 105.776,
"bandwidth_raw_iqr": 10.215,
"flops_raw_p10": 327.966,
"flops_raw_median": 519.8,
"flops_raw_p90": 654.422,
"flops_raw_mad": 16.97,
"flops_raw_range": 326.456,
"flops_raw_iqr": 34.88,
"arith_intensity_p10": 1.55,
"arith_intensity_median": 2.595,
"arith_intensity_p90": 3.445,
"arith_intensity_mad": 0.254,
"arith_intensity_range": 1.894,
"arith_intensity_iqr": 0.512,
"bw_flops_covariance": 382.76,
"bw_flops_correlation": 0.063,
"avg_performance_gflops": 503.26,
"median_performance_gflops": 519.8,
"performance_gflops_mad": 16.97,
"avg_memory_bw_gbs": 350.0,
"scalar_peak_gflops": 432.0,
"simd_peak_gflops": 9216.0,
"node_num": 3,
"duration": 31133,
}
# Sample JSON roofline data (raw data before aggregation, as would be received by API)
SAMPLE_JSON_ROOFLINE = json.dumps([
{
"node_num": 1,
"bandwidth_raw": 150.5,
"flops_raw": 2500.0,
"arith_intensity": 16.6,
"performance_gflops": 1200.0,
"memory_bw_gbs": 450,
"scalar_peak_gflops": 600,
"duration": 3600
},
{
"node_num": 1,
"bandwidth_raw": 155.2,
"flops_raw": 2600.0,
"arith_intensity": 16.8,
"performance_gflops": 1250.0,
"memory_bw_gbs": 450,
"scalar_peak_gflops": 600,
"duration": 3600
},
{
"node_num": 1,
"bandwidth_raw": 148.0,
"flops_raw": 2450.0,
"arith_intensity": 16.5,
"performance_gflops": 1180.0,
"memory_bw_gbs": 450,
"scalar_peak_gflops": 600,
"duration": 3600
}
])
# ============================================================================
# Fixtures
# ============================================================================
@pytest.fixture(scope="module")
def setup_predictor():
"""
Set up the predictor for tests.
Try to load the real model if available.
"""
from xgb_inference_api import XGBoostMultiLabelPredictor
model_path = os.path.join(os.path.dirname(__file__), 'xgb_model.joblib')
if os.path.exists(model_path):
try:
predictor = XGBoostMultiLabelPredictor(model_path)
xgb_fastapi.predictor = predictor
return True
except Exception as e:
print(f"Failed to load model: {e}")
return False
return False
@pytest.fixture
def client(setup_predictor):
"""Create a test client for the FastAPI app."""
return TestClient(app)
@pytest.fixture
def model_loaded(setup_predictor):
"""Check if the model is loaded."""
return setup_predictor
def skip_if_no_model(model_loaded):
"""Helper to skip tests if model is not loaded."""
if not model_loaded:
pytest.skip("Model not loaded, skipping test")
# ============================================================================
# Health and Root Endpoint Tests
# ============================================================================
class TestHealthEndpoints:
"""Tests for health check and root endpoints."""
def test_root_endpoint(self, client):
"""Test the root endpoint returns API information."""
response = client.get("/")
assert response.status_code == 200
data = response.json()
assert "name" in data
assert data["name"] == "XGBoost Multi-Label Classification API"
assert "version" in data
assert "endpoints" in data
assert all(key in data["endpoints"] for key in ["health", "predict", "predict_top_k", "batch_predict"])
def test_health_check(self, client, model_loaded):
"""Test the health check endpoint."""
response = client.get("/health")
assert response.status_code == 200
data = response.json()
assert "status" in data
assert "model_loaded" in data
# If model is loaded, check for additional info
if model_loaded:
assert data["model_loaded"] == True
assert data["status"] in ["healthy", "degraded"]
if data["status"] == "healthy":
assert "n_classes" in data
assert "n_features" in data
assert "classes" in data
assert data["n_classes"] > 0
assert data["n_features"] > 0
def test_health_check_model_not_loaded(self, client):
"""Test health check returns correct status when model not loaded."""
response = client.get("/health")
assert response.status_code == 200
data = response.json()
# Should have status and model_loaded fields regardless
assert "status" in data
assert "model_loaded" in data
# ============================================================================
# Single Prediction Tests
# ============================================================================
class TestPredictEndpoint:
"""Tests for the /predict endpoint."""
def test_predict_with_feature_dict(self, client, model_loaded):
"""Test prediction with a feature dictionary (TurTLE sample)."""
if not model_loaded:
pytest.skip("Model not loaded")
request_data = {
"features": SAMPLE_TURTLE,
"threshold": 0.5,
"return_all_probabilities": True
}
response = client.post("/predict", json=request_data)
assert response.status_code == 200
data = response.json()
assert "predictions" in data
assert "probabilities" in data
assert "confidences" in data
assert "threshold" in data
assert isinstance(data["predictions"], list)
assert isinstance(data["probabilities"], dict)
assert data["threshold"] == 0.5
# All probabilities should be between 0 and 1
for prob in data["probabilities"].values():
assert 0.0 <= prob <= 1.0
def test_predict_with_different_thresholds(self, client, model_loaded):
"""Test that different thresholds affect predictions."""
if not model_loaded:
pytest.skip("Model not loaded")
request_low = {
"features": SAMPLE_TURTLE,
"threshold": 0.1,
"return_all_probabilities": True
}
request_high = {
"features": SAMPLE_TURTLE,
"threshold": 0.9,
"return_all_probabilities": True
}
response_low = client.post("/predict", json=request_low)
response_high = client.post("/predict", json=request_high)
assert response_low.status_code == 200
assert response_high.status_code == 200
data_low = response_low.json()
data_high = response_high.json()
# Lower threshold should generally produce more predictions
assert len(data_low["predictions"]) >= len(data_high["predictions"])
def test_predict_different_workloads(self, client, model_loaded):
"""Test predictions on different application workloads."""
if not model_loaded:
pytest.skip("Model not loaded")
samples = [
("TurTLE", SAMPLE_TURTLE),
("SCALEXA", SAMPLE_SCALEXA),
("Chroma", SAMPLE_CHROMA),
]
for name, sample in samples:
request_data = {
"features": sample,
"threshold": 0.3,
"return_all_probabilities": True
}
response = client.post("/predict", json=request_data)
assert response.status_code == 200, f"Failed for {name}"
data = response.json()
assert len(data["probabilities"]) > 0, f"No probabilities for {name}"
def test_predict_return_only_predicted_probabilities(self, client, model_loaded):
"""Test prediction with return_all_probabilities=False."""
if not model_loaded:
pytest.skip("Model not loaded")
request_data = {
"features": SAMPLE_TURTLE,
"threshold": 0.3,
"return_all_probabilities": False
}
response = client.post("/predict", json=request_data)
assert response.status_code == 200
data = response.json()
# When return_all_probabilities is False, probabilities should only
# contain classes that are in predictions
if len(data["predictions"]) > 0:
assert set(data["probabilities"].keys()) == set(data["predictions"])
def test_predict_with_json_roofline_data(self, client, model_loaded):
"""Test prediction with raw JSON roofline data (requires aggregation)."""
if not model_loaded:
pytest.skip("Model not loaded")
request_data = {
"features": SAMPLE_JSON_ROOFLINE,
"is_json": True,
"job_id": "test_job_123",
"threshold": 0.3
}
response = client.post("/predict", json=request_data)
assert response.status_code == 200
data = response.json()
assert "predictions" in data
assert "probabilities" in data
def test_predict_threshold_boundaries(self, client, model_loaded):
"""Test prediction with threshold at boundaries."""
if not model_loaded:
pytest.skip("Model not loaded")
for threshold in [0.0, 0.5, 1.0]:
request_data = {
"features": SAMPLE_TURTLE,
"threshold": threshold
}
response = client.post("/predict", json=request_data)
assert response.status_code == 200
def test_predict_invalid_threshold(self, client):
"""Test that invalid threshold values are rejected."""
for threshold in [-0.1, 1.5]:
request_data = {
"features": SAMPLE_TURTLE,
"threshold": threshold
}
response = client.post("/predict", json=request_data)
assert response.status_code == 422 # Validation error
def test_predict_model_not_loaded(self, client):
"""Test that prediction returns 503 when model not loaded."""
# Temporarily set predictor to None
original_predictor = xgb_fastapi.predictor
xgb_fastapi.predictor = None
try:
request_data = {
"features": SAMPLE_TURTLE,
"threshold": 0.5
}
response = client.post("/predict", json=request_data)
assert response.status_code == 503
assert "Model not loaded" in response.json().get("detail", "")
finally:
xgb_fastapi.predictor = original_predictor
# ============================================================================
# Top-K Prediction Tests
# ============================================================================
class TestPredictTopKEndpoint:
"""Tests for the /predict_top_k endpoint."""
def test_predict_top_k_default(self, client, model_loaded):
"""Test top-K prediction with default k=5."""
if not model_loaded:
pytest.skip("Model not loaded")
request_data = {
"features": SAMPLE_TURTLE
}
response = client.post("/predict_top_k", json=request_data)
assert response.status_code == 200
data = response.json()
assert "top_predictions" in data
assert "top_probabilities" in data
assert "all_probabilities" in data
assert len(data["top_predictions"]) <= 5
assert len(data["top_probabilities"]) <= 5
def test_predict_top_k_custom_k(self, client, model_loaded):
"""Test top-K prediction with custom k values."""
if not model_loaded:
pytest.skip("Model not loaded")
for k in [1, 3, 10]:
request_data = {
"features": SAMPLE_TURTLE,
"k": k
}
response = client.post("/predict_top_k", json=request_data)
assert response.status_code == 200
data = response.json()
assert len(data["top_predictions"]) <= k
def test_predict_top_k_ordering(self, client, model_loaded):
"""Test that top-K predictions are ordered by probability."""
if not model_loaded:
pytest.skip("Model not loaded")
request_data = {
"features": SAMPLE_CHROMA,
"k": 10
}
response = client.post("/predict_top_k", json=request_data)
assert response.status_code == 200
data = response.json()
probabilities = [data["top_probabilities"][cls] for cls in data["top_predictions"]]
# Check that probabilities are in descending order
for i in range(len(probabilities) - 1):
assert probabilities[i] >= probabilities[i + 1]
def test_predict_top_k_with_json_data(self, client, model_loaded):
"""Test top-K prediction with JSON roofline data."""
if not model_loaded:
pytest.skip("Model not loaded")
request_data = {
"features": SAMPLE_JSON_ROOFLINE,
"is_json": True,
"job_id": "test_job_456",
"k": 5
}
response = client.post("/predict_top_k", json=request_data)
assert response.status_code == 200
data = response.json()
assert len(data["top_predictions"]) <= 5
def test_predict_top_k_invalid_k(self, client):
"""Test that invalid k values are rejected."""
for k in [0, -1, 101]:
request_data = {
"features": SAMPLE_TURTLE,
"k": k
}
response = client.post("/predict_top_k", json=request_data)
assert response.status_code == 422 # Validation error
# ============================================================================
# Batch Prediction Tests
# ============================================================================
class TestBatchPredictEndpoint:
"""Tests for the /batch_predict endpoint."""
def test_batch_predict_multiple_samples(self, client, model_loaded):
"""Test batch prediction with multiple samples."""
if not model_loaded:
pytest.skip("Model not loaded")
request_data = {
"features_list": [SAMPLE_TURTLE, SAMPLE_SCALEXA, SAMPLE_CHROMA],
"threshold": 0.3
}
response = client.post("/batch_predict", json=request_data)
assert response.status_code == 200
data = response.json()
assert "results" in data
assert "total" in data
assert "successful" in data
assert "failed" in data
assert data["total"] == 3
assert len(data["results"]) == 3
assert data["successful"] + data["failed"] == data["total"]
def test_batch_predict_single_sample(self, client, model_loaded):
"""Test batch prediction with a single sample."""
if not model_loaded:
pytest.skip("Model not loaded")
request_data = {
"features_list": [SAMPLE_TURTLE],
"threshold": 0.5
}
response = client.post("/batch_predict", json=request_data)
assert response.status_code == 200
data = response.json()
assert data["total"] == 1
assert len(data["results"]) == 1
def test_batch_predict_with_json_data(self, client, model_loaded):
"""Test batch prediction with JSON roofline data."""
if not model_loaded:
pytest.skip("Model not loaded")
request_data = {
"features_list": [SAMPLE_JSON_ROOFLINE, SAMPLE_JSON_ROOFLINE],
"is_json": True,
"job_ids": ["job_001", "job_002"],
"threshold": 0.3
}
response = client.post("/batch_predict", json=request_data)
assert response.status_code == 200
data = response.json()
assert data["total"] == 2
def test_batch_predict_empty_list(self, client, model_loaded):
"""Test batch prediction with empty list."""
if not model_loaded:
pytest.skip("Model not loaded")
request_data = {
"features_list": [],
"threshold": 0.5
}
response = client.post("/batch_predict", json=request_data)
assert response.status_code == 200
data = response.json()
assert data["total"] == 0
assert data["successful"] == 0
assert data["failed"] == 0
# ============================================================================
# Model Info Tests
# ============================================================================
class TestModelInfoEndpoint:
"""Tests for the /model/info endpoint."""
def test_model_info(self, client, model_loaded):
"""Test getting model information."""
response = client.get("/model/info")
# May return 503 if model not loaded, or 200 if loaded
if model_loaded:
assert response.status_code == 200
data = response.json()
assert "classes" in data
assert "n_classes" in data
assert "features" in data
assert "n_features" in data
assert isinstance(data["classes"], list)
assert len(data["classes"]) == data["n_classes"]
assert len(data["features"]) == data["n_features"]
else:
assert response.status_code == 503
def test_model_info_not_loaded(self, client):
"""Test that model info returns 503 when model not loaded."""
original_predictor = xgb_fastapi.predictor
xgb_fastapi.predictor = None
try:
response = client.get("/model/info")
assert response.status_code == 503
finally:
xgb_fastapi.predictor = original_predictor
# ============================================================================
# Error Handling Tests
# ============================================================================
class TestErrorHandling:
"""Tests for error handling."""
def test_predict_missing_features(self, client):
"""Test prediction without features field."""
request_data = {
"threshold": 0.5
}
response = client.post("/predict", json=request_data)
assert response.status_code == 422
def test_predict_invalid_json_format(self, client, model_loaded):
"""Test prediction with invalid JSON in is_json mode."""
if not model_loaded:
pytest.skip("Model not loaded")
request_data = {
"features": "not valid json {{",
"is_json": True
}
response = client.post("/predict", json=request_data)
# Should return error (400)
assert response.status_code == 400
def test_invalid_endpoint(self, client):
"""Test accessing an invalid endpoint."""
response = client.get("/nonexistent")
assert response.status_code == 404
# ============================================================================
# Integration Tests (Full Pipeline)
# ============================================================================
class TestIntegration:
"""Integration tests for full prediction pipeline."""
def test_full_prediction_pipeline_features(self, client, model_loaded):
"""Test complete prediction pipeline with feature dict."""
if not model_loaded:
pytest.skip("Model not loaded")
# 1. Check health
health_response = client.get("/health")
assert health_response.status_code == 200
health_data = health_response.json()
assert health_data["model_loaded"] == True
# 2. Get model info
info_response = client.get("/model/info")
assert info_response.status_code == 200
# 3. Make single prediction
predict_response = client.post("/predict", json={
"features": SAMPLE_TURTLE,
"threshold": 0.3
})
assert predict_response.status_code == 200
# 4. Make top-K prediction
topk_response = client.post("/predict_top_k", json={
"features": SAMPLE_TURTLE,
"k": 5
})
assert topk_response.status_code == 200
# 5. Make batch prediction
batch_response = client.post("/batch_predict", json={
"features_list": [SAMPLE_TURTLE, SAMPLE_CHROMA],
"threshold": 0.3
})
assert batch_response.status_code == 200
def test_consistency_single_vs_batch(self, client, model_loaded):
"""Test that single prediction and batch prediction give consistent results."""
if not model_loaded:
pytest.skip("Model not loaded")
threshold = 0.3
# Single prediction
single_response = client.post("/predict", json={
"features": SAMPLE_TURTLE,
"threshold": threshold
})
# Batch prediction with same sample
batch_response = client.post("/batch_predict", json={
"features_list": [SAMPLE_TURTLE],
"threshold": threshold
})
assert single_response.status_code == 200
assert batch_response.status_code == 200
single_data = single_response.json()
batch_data = batch_response.json()
if batch_data["successful"] == 1:
batch_result = batch_data["results"][0]
# Predictions should be the same
assert set(single_data["predictions"]) == set(batch_result["predictions"])
# ============================================================================
# CORS Tests
# ============================================================================
class TestCORS:
"""Tests for CORS configuration."""
def test_cors_headers(self, client):
"""Test that CORS headers are present in responses."""
response = client.options("/predict")
# Accept either 200 or 405 (method not allowed for OPTIONS in some configs)
assert response.status_code in [200, 405]
# ============================================================================
# Performance Tests (Basic)
# ============================================================================
class TestPerformance:
"""Basic performance tests."""
def test_response_time_single_prediction(self, client, model_loaded):
"""Test that single prediction completes in reasonable time."""
if not model_loaded:
pytest.skip("Model not loaded")
import time
start = time.time()
response = client.post("/predict", json={
"features": SAMPLE_TURTLE,
"threshold": 0.5
})
elapsed = time.time() - start
# Should complete within 5 seconds (generous for CI environments)
assert elapsed < 5.0, f"Prediction took {elapsed:.2f}s, expected < 5s"
assert response.status_code == 200
def test_response_time_batch_prediction(self, client, model_loaded):
"""Test that batch prediction scales reasonably."""
if not model_loaded:
pytest.skip("Model not loaded")
import time
# Create a batch of 10 samples
features_list = [SAMPLE_TURTLE] * 10
start = time.time()
response = client.post("/batch_predict", json={
"features_list": features_list,
"threshold": 0.5
})
elapsed = time.time() - start
# Should complete within 10 seconds
assert elapsed < 10.0, f"Batch prediction took {elapsed:.2f}s, expected < 10s"
assert response.status_code == 200
# ============================================================================
# Run tests if executed directly
# ============================================================================
if __name__ == "__main__":
pytest.main([__file__, "-v", "--tb=short"])

429
xgb_fastapi.py Normal file

@@ -0,0 +1,429 @@
#!/usr/bin/env python3
"""
FastAPI REST API for XGBoost Multi-Label Classification Inference
Provides HTTP endpoints for:
- POST /predict: Single prediction with confidence scores
- POST /predict_top_k: Top-K predictions
- POST /batch_predict: Batch predictions
Supports both feature vectors and JSON roofline data as input.
"""
from fastapi import FastAPI, HTTPException, Body
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import Dict, List, Union, Optional, Any
import logging
import uvicorn
from contextlib import asynccontextmanager
from xgb_local import XGBoostMultiLabelPredictor
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Global predictor instance
predictor: Optional[XGBoostMultiLabelPredictor] = None
# Pydantic models for request/response validation
class PredictRequest(BaseModel):
"""Request model for single prediction."""
features: Union[Dict[str, float], str] = Field(
...,
description="Feature dictionary or JSON string of roofline data"
)
threshold: float = Field(
default=0.5,
ge=0.0,
le=1.0,
description="Probability threshold for classification"
)
return_all_probabilities: bool = Field(
default=True,
description="Whether to return probabilities for all classes"
)
is_json: bool = Field(
default=False,
description="Whether features is a JSON string of roofline data"
)
job_id: Optional[str] = Field(
default=None,
description="Optional job ID for JSON aggregation"
)
class PredictTopKRequest(BaseModel):
"""Request model for top-K prediction."""
features: Union[Dict[str, float], str] = Field(
...,
description="Feature dictionary or JSON string of roofline data"
)
k: int = Field(
default=5,
ge=1,
le=100,
description="Number of top predictions to return"
)
is_json: bool = Field(
default=False,
description="Whether features is a JSON string of roofline data"
)
job_id: Optional[str] = Field(
default=None,
description="Optional job ID for JSON aggregation"
)
class BatchPredictRequest(BaseModel):
"""Request model for batch prediction."""
features_list: List[Union[Dict[str, float], str]] = Field(
...,
description="List of feature dictionaries or JSON strings"
)
threshold: float = Field(
default=0.5,
ge=0.0,
le=1.0,
description="Probability threshold for classification"
)
is_json: bool = Field(
default=False,
description="Whether features are JSON strings of roofline data"
)
job_ids: Optional[List[str]] = Field(
default=None,
description="Optional list of job IDs for JSON aggregation"
)
class PredictResponse(BaseModel):
"""Response model for single prediction."""
predictions: List[str] = Field(description="List of predicted class names")
probabilities: Dict[str, float] = Field(description="Probabilities for each class")
confidences: Dict[str, float] = Field(description="Confidence scores for predicted classes")
threshold: float = Field(description="Threshold used for prediction")
class PredictTopKResponse(BaseModel):
"""Response model for top-K prediction."""
top_predictions: List[str] = Field(description="Top-K predicted class names")
top_probabilities: Dict[str, float] = Field(description="Probabilities for top-K classes")
all_probabilities: Dict[str, float] = Field(description="Probabilities for all classes")
class BatchPredictResponse(BaseModel):
"""Response model for batch prediction."""
results: List[Union[PredictResponse, Dict[str, str]]] = Field(
description="List of prediction results or errors"
)
total: int = Field(description="Total number of samples processed")
successful: int = Field(description="Number of successful predictions")
failed: int = Field(description="Number of failed predictions")
class HealthResponse(BaseModel):
"""Response model for health check."""
status: str
model_loaded: bool
n_classes: Optional[int] = None
n_features: Optional[int] = None
classes: Optional[List[str]] = None
class ErrorResponse(BaseModel):
"""Response model for errors."""
error: str
detail: Optional[str] = None
# Lifespan context manager for startup/shutdown
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Manage application lifespan events."""
# Startup
global predictor
try:
logger.info("Loading XGBoost model...")
predictor = XGBoostMultiLabelPredictor('xgb_model.joblib')
logger.info("Model loaded successfully!")
except Exception as e:
logger.error(f"Failed to load model: {e}")
predictor = None
yield
# Shutdown
logger.info("Shutting down...")
# Initialize FastAPI app
app = FastAPI(
title="XGBoost Multi-Label Classification API",
description="REST API for multi-label classification inference using XGBoost",
version="1.0.0",
lifespan=lifespan
)
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # In production, specify allowed origins
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# API Endpoints
@app.get("/", tags=["General"])
async def root():
"""Root endpoint with API information."""
return {
"name": "XGBoost Multi-Label Classification API",
"version": "1.0.0",
"endpoints": {
"health": "/health",
"predict": "/predict",
"predict_top_k": "/predict_top_k",
"batch_predict": "/batch_predict"
}
}
@app.get("/health", response_model=HealthResponse, tags=["General"])
async def health_check():
"""
Check API health and model status.
Returns model information if loaded successfully.
"""
if predictor is None:
return HealthResponse(
status="error",
model_loaded=False
)
try:
info = predictor.get_class_info()
return HealthResponse(
status="healthy",
model_loaded=True,
n_classes=info['n_classes'],
n_features=info['n_features'],
classes=info['classes']
)
except Exception as e:
logger.error(f"Health check error: {e}")
return HealthResponse(
status="degraded",
model_loaded=True
)
@app.post("/predict", response_model=PredictResponse, tags=["Inference"])
async def predict(request: PredictRequest):
"""
Perform single prediction on input features.
**Input formats:**
- Feature dictionary: `{"feature1": value1, "feature2": value2, ...}`
- JSON roofline data: Set `is_json=true` and provide JSON string
**Example (features):**
```json
{
"features": {
"bandwidth_raw_p10": 150.5,
"flops_raw_median": 2500.0,
...
},
"threshold": 0.5
}
```
**Example (JSON roofline):**
```json
{
"features": "[{\"node_num\": 1, \"bandwidth_raw\": 150.5, ...}]",
"is_json": true,
"job_id": "test_job_123",
"threshold": 0.3
}
```
"""
if predictor is None:
raise HTTPException(status_code=503, detail="Model not loaded")
try:
result = predictor.predict(
features=request.features,
threshold=request.threshold,
return_all_probabilities=request.return_all_probabilities,
is_json=request.is_json,
job_id=request.job_id
)
return PredictResponse(**result)
except Exception as e:
logger.error(f"Prediction error: {e}")
raise HTTPException(status_code=400, detail=str(e))
@app.post("/predict_top_k", response_model=PredictTopKResponse, tags=["Inference"])
async def predict_top_k(request: PredictTopKRequest):
"""
Get top-K predictions with their probabilities.
**Example (features):**
```json
{
"features": {
"bandwidth_raw_p10": 150.5,
"flops_raw_median": 2500.0,
...
},
"k": 5
}
```
**Example (JSON roofline):**
```json
{
"features": "[{\"node_num\": 1, \"bandwidth_raw\": 150.5, ...}]",
"is_json": true,
"job_id": "test_job_123",
"k": 10
}
```
"""
if predictor is None:
raise HTTPException(status_code=503, detail="Model not loaded")
try:
result = predictor.predict_top_k(
features=request.features,
k=request.k,
is_json=request.is_json,
job_id=request.job_id
)
return PredictTopKResponse(**result)
except Exception as e:
logger.error(f"Top-K prediction error: {e}")
raise HTTPException(status_code=400, detail=str(e))
@app.post("/batch_predict", response_model=BatchPredictResponse, tags=["Inference"])
async def batch_predict(request: BatchPredictRequest):
"""
Perform batch prediction on multiple samples.
**Example (features):**
```json
{
"features_list": [
{"bandwidth_raw_p10": 150.5, ...},
{"bandwidth_raw_p10": 160.2, ...}
],
"threshold": 0.5
}
```
**Example (JSON roofline):**
```json
{
"features_list": [
"[{\"node_num\": 1, \"bandwidth_raw\": 150.5, ...}]",
"[{\"node_num\": 2, \"bandwidth_raw\": 160.2, ...}]"
],
"is_json": true,
"job_ids": ["job1", "job2"],
"threshold": 0.3
}
```
"""
if predictor is None:
raise HTTPException(status_code=503, detail="Model not loaded")
try:
results = predictor.batch_predict(
features_list=request.features_list,
threshold=request.threshold,
is_json=request.is_json,
job_ids=request.job_ids
)
# Count successful and failed predictions
successful = sum(1 for r in results if 'error' not in r)
failed = len(results) - successful
return BatchPredictResponse(
results=results,
total=len(results),
successful=successful,
failed=failed
)
except Exception as e:
logger.error(f"Batch prediction error: {e}")
raise HTTPException(status_code=400, detail=str(e))
@app.get("/model/info", tags=["Model"])
async def model_info():
"""
Get detailed model information.
Returns information about classes, features, and model configuration.
"""
if predictor is None:
raise HTTPException(status_code=503, detail="Model not loaded")
try:
info = predictor.get_class_info()
return {
"classes": info['classes'],
"n_classes": info['n_classes'],
"features": info['feature_columns'],
"n_features": info['n_features']
}
except Exception as e:
logger.error(f"Model info error: {e}")
raise HTTPException(status_code=500, detail=str(e))
def main():
"""Run the FastAPI server."""
import argparse
parser = argparse.ArgumentParser(description="XGBoost Multi-Label Classification REST API")
parser.add_argument('--host', type=str, default='0.0.0.0',
help='Host to bind to (default: 0.0.0.0)')
parser.add_argument('--port', type=int, default=8000,
help='Port to bind to (default: 8000)')
parser.add_argument('--reload', action='store_true',
help='Enable auto-reload for development')
parser.add_argument('--workers', type=int, default=1,
help='Number of worker processes (default: 1)')
args = parser.parse_args()
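# Typical invocations (matching the Makefile targets):
#   python xgb_fastapi.py --port 8000
#   python xgb_fastapi.py --port 8000 --reload   # development auto-reload; forces a single worker below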
logger.info(f"Starting FastAPI server on {args.host}:{args.port}")
logger.info(f"Workers: {args.workers}, Reload: {args.reload}")
uvicorn.run(
"xgb_fastapi:app",
host=args.host,
port=args.port,
reload=args.reload,
workers=args.workers if not args.reload else 1,
log_level="info"
)
if __name__ == "__main__":
main()

530
xgb_local.py Normal file
View File

@@ -0,0 +1,530 @@
import pandas as pd
import numpy as np
import joblib
import json
from typing import Dict, List, Tuple, Union, Optional
import warnings
warnings.filterwarnings('ignore')
from scipy import stats
def compute_mad(data: np.ndarray) -> float:
"""Compute Median Absolute Deviation."""
median = np.median(data)
mad = np.median(np.abs(data - median))
return mad
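# Worked example (hypothetical data): for [1, 2, 4, 8] the median is 3,
# the absolute deviations are [2, 1, 1, 5], and the MAD is 1.5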
def df_aggregate(json_str: str, job_id_full: Optional[str] = None) -> Dict:
"""
Aggregate roofline data from JSON string into a single feature vector.
Args:
json_str: JSON string containing roofline data records
job_id_full: Optional job ID to include in the result
Returns:
Dictionary containing aggregated features
"""
# Parse JSON string to DataFrame
try:
data = json.loads(json_str)
if isinstance(data, list):
df = pd.DataFrame(data)
elif isinstance(data, dict):
df = pd.DataFrame([data])
else:
raise ValueError("JSON must contain a list of objects or a single object")
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON string: {e}")
# Group data by node_num
if 'node_num' not in df.columns:
# If no node_num, treat all data as single node
df['node_num'] = 1
grouped = df.groupby('node_num')
all_features = []
for node_num, group in grouped:
features = {
'node_num': int(node_num)
}
if job_id_full is not None:
features['job_id'] = job_id_full
# Compute statistics for key metrics
for axis in ['bandwidth_raw', 'flops_raw', 'arith_intensity']:
data = group[axis].values
# Compute percentiles
p10 = np.percentile(data, 10)
p50 = np.median(data)
p90 = np.percentile(data, 90)
# Compute MAD (more robust than variance)
mad = compute_mad(data)
# Store features
features[f'{axis}_p10'] = p10
features[f'{axis}_median'] = p50
features[f'{axis}_p90'] = p90
features[f'{axis}_mad'] = mad
features[f'{axis}_range'] = p90 - p10
features[f'{axis}_iqr'] = np.percentile(data, 75) - np.percentile(data, 25)
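# e.g. for axis == 'bandwidth_raw' this produces bandwidth_raw_p10, bandwidth_raw_median,
# bandwidth_raw_p90, bandwidth_raw_mad, bandwidth_raw_range and bandwidth_raw_iqr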
# Compute covariance and correlation between bandwidth_raw and flops_raw
if len(group) > 1: # Need at least 2 points for correlation
cov = np.cov(group['bandwidth_raw'], group['flops_raw'])[0, 1]
features['bw_flops_covariance'] = cov
corr, _ = stats.pearsonr(group['bandwidth_raw'], group['flops_raw'])
features['bw_flops_correlation'] = corr
# Additional useful features for the classifier
# Performance metrics
features['avg_performance_gflops'] = group['performance_gflops'].mean()
features['median_performance_gflops'] = group['performance_gflops'].median()
features['performance_gflops_mad'] = compute_mad(group['performance_gflops'].values)
# # Efficiency metrics
# features['avg_efficiency'] = group['efficiency'].mean()
# features['median_efficiency'] = group['efficiency'].median()
# features['efficiency_mad'] = compute_mad(group['efficiency'].values)
# features['efficiency_p10'] = np.percentile(group['efficiency'].values, 10)
# features['efficiency_p90'] = np.percentile(group['efficiency'].values, 90)
# # Distribution of roofline regions (memory-bound vs compute-bound)
# if 'roofline_region' in group.columns:
# region_counts = group['roofline_region'].value_counts(normalize=True).to_dict()
# for region, ratio in region_counts.items():
# features[f'region_{region}_ratio'] = ratio
# System characteristics
if 'memory_bw_gbs' in group.columns:
features['avg_memory_bw_gbs'] = group['memory_bw_gbs'].mean()
if 'scalar_peak_gflops' in group.columns:
features['scalar_peak_gflops'] = group['scalar_peak_gflops'].iloc[0]
if 'simd_peak_gflops' in group.columns:
features['simd_peak_gflops'] = group['simd_peak_gflops'].iloc[0]
# # Subcluster information if available
# if 'subcluster_name' in group.columns and not group['subcluster_name'].isna().all():
# features['subcluster_name'] = group['subcluster_name'].iloc[0]
# Duration information
if 'duration' in group.columns:
features['duration'] = group['duration'].iloc[0]
all_features.append(features)
# Return first node's features (or combine multiple nodes if needed)
if len(all_features) == 1:
return all_features[0]
else:
# If multiple nodes, return the first one or average across nodes
# For now, return the first node's features
return all_features[0]
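# Minimal usage sketch (hypothetical record; field names follow the schema consumed above):
#   sample = '[{"node_num": 1, "bandwidth_raw": 150.5, "flops_raw": 2500.0, "arith_intensity": 16.6, "performance_gflops": 1200.0, "duration": 60}]'
#   feats = df_aggregate(sample, job_id_full="job_42")
#   feats["flops_raw_median"]  # -> 2500.0 (single record, so p10/median/p90 collapse to the same value)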
class XGBoostMultiLabelPredictor:
"""
Python API for XGBoost multi-label classification inference.
Provides methods to load trained models and perform inference with
confidence scores for each class.
"""
def __init__(self, model_path: str = 'xgb_model.joblib'):
"""
Initialize the predictor by loading the trained model.
Args:
model_path: Path to the saved model file (.joblib)
"""
self.model_data = None
self.model = None
self.mlb = None
self.feature_columns = None
self.n_features = 0
self.classes = []
self.load_model(model_path)
def load_model(self, model_path: str) -> None:
"""
Load the trained XGBoost model from disk.
Args:
model_path: Path to the saved model file
"""
try:
print(f"Loading model from {model_path}...")
self.model_data = joblib.load(model_path)
self.model = self.model_data['model']
self.mlb = self.model_data['mlb']
self.feature_columns = self.model_data['feature_columns']
self.classes = list(self.mlb.classes_)
self.n_features = len(self.feature_columns)
print("Model loaded successfully!")
print(f" - {len(self.classes)} classes: {self.classes}")
print(f" - {self.n_features} features: {self.feature_columns[:5]}...")
print(f" - Model type: {type(self.model).__name__}")
except Exception as e:
raise ValueError(f"Failed to load model from {model_path}: {e}")
def predict(self, features: Union[pd.DataFrame, np.ndarray, List, Dict, str],
threshold: float = 0.5,
return_all_probabilities: bool = True,
is_json: bool = False,
job_id: Optional[str] = None) -> Dict:
"""
Perform multi-label prediction on input features.
Args:
features: Input features in various formats:
- pandas DataFrame
- numpy array (2D)
- list of lists/dicts
- single feature vector (list/dict)
- JSON string (if is_json=True): roofline data to aggregate
threshold: Probability threshold for binary classification (0.0-1.0)
return_all_probabilities: If True, return probabilities for all classes.
If False, return only classes above threshold.
is_json: If True, treat features as JSON string of roofline data
job_id: Optional job ID (used when is_json=True)
Returns:
Dictionary containing:
- 'predictions': List of predicted class names
- 'probabilities': Dict of {class_name: probability} for all classes
- 'confidences': Dict of {class_name: confidence_score} for predicted classes
- 'threshold': The threshold used
"""
# If input is JSON string, aggregate features first
if is_json:
if not isinstance(features, str):
raise ValueError("When is_json=True, features must be a JSON string")
features = df_aggregate(features, job_id_full=job_id)
# Convert input to proper format
X = self._prepare_features(features)
# Get probability predictions
probabilities = self.model.predict_proba(X)
# Convert to class probabilities
class_probabilities = {}
for i, class_name in enumerate(self.classes):
# For OneVsRest, predict_proba returns shape (n_samples, n_classes)
# Each column i contains probabilities for class i
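# e.g. (hypothetical sizes): with 60 classes and one input sample,
# probabilities has shape (1, 60) and probabilities[0, i] is the
# positive-class probability for self.classes[i]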
if isinstance(probabilities, list):
# List of arrays (multiple samples)
prob_array = probabilities[i]
prob_positive = prob_array[0] if hasattr(prob_array, '__getitem__') else float(prob_array)
else:
# 2D numpy array (single sample or batch)
if len(probabilities.shape) == 2:
# Shape: (n_samples, n_classes)
prob_positive = float(probabilities[0, i])
else:
# 1D array
prob_positive = float(probabilities[i])
class_probabilities[class_name] = prob_positive
# Apply threshold for predictions
predictions = []
confidences = {}
for class_name, prob in class_probabilities.items():
if prob >= threshold:
predictions.append(class_name)
# Confidence score: distance from threshold as percentage
confidence = min(1.0, (prob - threshold) / (1.0 - threshold)) * 100
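# e.g. with hypothetical values prob=0.8 and threshold=0.5:
# (0.8 - 0.5) / (1.0 - 0.5) * 100 = 60.0, i.e. the prediction sits 60% of the
# way from the threshold to full certainty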
confidences[class_name] = round(confidence, 2)
# Sort predictions by probability
predictions.sort(key=lambda x: class_probabilities[x], reverse=True)
result = {
'predictions': predictions,
'probabilities': {k: round(v, 4) for k, v in class_probabilities.items()},
'confidences': confidences,
'threshold': threshold
}
if not return_all_probabilities:
result['probabilities'] = {k: v for k, v in result['probabilities'].items()
if k in predictions}
return result
def predict_top_k(self, features: Union[pd.DataFrame, np.ndarray, List, Dict, str],
k: int = 5,
is_json: bool = False,
job_id: Optional[str] = None) -> Dict:
"""
Get top-k predictions with their probabilities.
Args:
features: Input features (various formats) or JSON string if is_json=True
k: Number of top predictions to return
is_json: If True, treat features as JSON string of roofline data
job_id: Optional job ID (used when is_json=True)
Returns:
Dictionary with top-k predictions and their details
"""
# If input is JSON string, aggregate features first
if is_json:
if not isinstance(features, str):
raise ValueError("When is_json=True, features must be a JSON string")
features = df_aggregate(features, job_id_full=job_id)
# Get all probabilities
X = self._prepare_features(features)
probabilities = self.model.predict_proba(X)
class_probabilities = {}
for i, class_name in enumerate(self.classes):
# For OneVsRest, predict_proba returns shape (n_samples, n_classes)
# Each column i contains probabilities for class i
if isinstance(probabilities, list):
# List of arrays (multiple samples)
prob_array = probabilities[i]
prob_positive = prob_array[0] if hasattr(prob_array, '__getitem__') else float(prob_array)
else:
# 2D numpy array (single sample or batch)
if len(probabilities.shape) == 2:
# Shape: (n_samples, n_classes)
prob_positive = float(probabilities[0, i])
else:
# 1D array
prob_positive = float(probabilities[i])
class_probabilities[class_name] = prob_positive
# Sort by probability
sorted_classes = sorted(class_probabilities.items(),
key=lambda x: x[1], reverse=True)
top_k_classes = sorted_classes[:k]
return {
'top_predictions': [cls for cls, _ in top_k_classes],
'top_probabilities': {cls: round(prob, 4) for cls, prob in top_k_classes},
'all_probabilities': {cls: round(prob, 4) for cls, prob in class_probabilities.items()}
}
def _prepare_features(self, features: Union[pd.DataFrame, np.ndarray, List, Dict]) -> pd.DataFrame:
"""
Convert various input formats to the expected feature format.
Args:
features: Input features in various formats
Returns:
pandas DataFrame with correct columns and order
"""
if isinstance(features, pd.DataFrame):
df = features.copy()
elif isinstance(features, np.ndarray):
if features.ndim == 1:
features = features.reshape(1, -1)
df = pd.DataFrame(features, columns=self.feature_columns[:features.shape[1]])
elif isinstance(features, list):
if isinstance(features[0], dict):
# List of dictionaries
df = pd.DataFrame(features)
else:
# List of lists
df = pd.DataFrame(features, columns=self.feature_columns[:len(features[0])])
elif isinstance(features, dict):
# Single feature dictionary
df = pd.DataFrame([features])
else:
raise ValueError(f"Unsupported feature format: {type(features)}")
# Ensure correct column order and fill missing columns with 0
for col in self.feature_columns:
if col not in df.columns:
df[col] = 0.0
df = df[self.feature_columns]
# Validate feature count
if df.shape[1] != self.n_features:
raise ValueError(f"Expected {self.n_features} features, got {df.shape[1]}")
return df
def get_class_info(self) -> Dict:
"""
Get information about available classes.
Returns:
Dictionary with class information
"""
return {
'classes': self.classes,
'n_classes': len(self.classes),
'feature_columns': self.feature_columns,
'n_features': self.n_features
}
def batch_predict(self, features_list: List[Union[pd.DataFrame, np.ndarray, List, Dict, str]],
threshold: float = 0.5,
is_json: bool = False,
job_ids: Optional[List[str]] = None) -> List[Dict]:
"""
Perform batch prediction on multiple samples.
Args:
features_list: List of feature inputs (or JSON strings if is_json=True)
threshold: Probability threshold
is_json: If True, treat each item in features_list as JSON string
job_ids: Optional list of job IDs (used when is_json=True)
Returns:
List of prediction results
"""
results = []
for idx, features in enumerate(features_list):
try:
job_id = job_ids[idx] if job_ids and idx < len(job_ids) else None
result = self.predict(features, threshold=threshold, is_json=is_json, job_id=job_id)
results.append(result)
except Exception as e:
results.append({'error': str(e)})
return results
def create_sample_data(n_samples: int = 5) -> List[Dict]:
"""
Create sample feature data for testing.
Args:
n_samples: Number of sample feature vectors to create
Returns:
List of feature dictionaries
"""
np.random.seed(42)
# Load feature columns from model if available
try:
model_data = joblib.load('xgb_model.joblib')
feature_columns = model_data['feature_columns']
except Exception:
# Fallback to some default features
feature_columns = [
'node_num', 'bandwidth_raw_p10', 'bandwidth_raw_median',
'bandwidth_raw_p90', 'bandwidth_raw_mad', 'bandwidth_raw_range',
'bandwidth_raw_iqr', 'flops_raw_p10', 'flops_raw_median',
'flops_raw_p90', 'flops_raw_mad', 'flops_raw_range'
]
samples = []
for _ in range(n_samples):
sample = {}
for col in feature_columns:
if 'bandwidth' in col:
sample[col] = np.random.uniform(50, 500)
elif 'flops' in col:
sample[col] = np.random.uniform(100, 5000)
elif 'node_num' in col:
sample[col] = np.random.randint(1, 16)
else:
sample[col] = np.random.uniform(0, 1000)
samples.append(sample)
return samples
if __name__ == "__main__":
print("XGBoost Multi-Label Inference API")
print("=" * 40)
# Initialize predictor
try:
predictor = XGBoostMultiLabelPredictor()
except Exception as e:
print(f"Error loading model: {e}")
exit(1)
# Example usage of df_aggregate with JSON string
print("\n=== Example 0: JSON Aggregation ===")
sample_json = json.dumps([
{
"node_num": 1,
"bandwidth_raw": 150.5,
"flops_raw": 2500.0,
"arith_intensity": 16.6,
"performance_gflops": 1200.0,
"memory_bw_gbs": 450,
"scalar_peak_gflops": 600,
"duration": 3600
},
{
"node_num": 2,
"bandwidth_raw": 155.2,
"flops_raw": 2600.0,
"arith_intensity": 16.8,
"performance_gflops": 1250.0,
"memory_bw_gbs": 450,
"scalar_peak_gflops": 600,
"duration": 3600
}
])
try:
aggregated_features = df_aggregate(sample_json, job_id_full="test_job_123")
print(f"Aggregated features from JSON:")
for key, value in list(aggregated_features.items())[:10]:
print(f" {key}: {value}")
print(f" ... ({len(aggregated_features)} total features)")
# Use aggregated features for prediction
result = predictor.predict(aggregated_features, threshold=0.3)
print(f"\nPredictions from aggregated data: {result['predictions'][:3]}")
except Exception as e:
print(f"Error in aggregation: {e}")
# Create sample data
print("\n=== Generating sample data for other examples ===")
sample_data = create_sample_data(3)
# Example 1: Single prediction
print("\n=== Example 1: Single Prediction ===")
result = predictor.predict(sample_data[0], threshold=0.3)
print(f"Predictions: {result['predictions']}")
print(f"Confidences: {result['confidences']}")
print(f"Top probabilities:")
for class_name, prob in sorted(result['probabilities'].items(),
key=lambda x: x[1], reverse=True)[:5]:
print(".4f")
# Example 2: Top-K predictions
print("\n=== Example 2: Top-5 Predictions ===")
top_result = predictor.predict_top_k(sample_data[1], k=5)
for i, class_name in enumerate(top_result['top_predictions'], 1):
prob = top_result['top_probabilities'][class_name]
print(f"{i}. {class_name}: {prob:.4f}")
# Example 3: Batch prediction
print("\n=== Example 3: Batch Prediction ===")
batch_results = predictor.batch_predict(sample_data, threshold=0.4)
for i, result in enumerate(batch_results, 1):
if 'error' not in result:
print(f"Sample {i}: {len(result['predictions'])} predictions")
else:
print(f"Sample {i}: Error - {result['error']}")
print("\nAPI ready for use!")
print("Usage:")
print(" predictor = XGBoostMultiLabelPredictor()")
print(" result = predictor.predict(your_features)")
print(" top_k = predictor.predict_top_k(your_features, k=5)")

221
xgb_local_example.py Normal file
View File

@@ -0,0 +1,221 @@
#!/usr/bin/env python3
"""
XGBoost Multi-Label Inference Usage Examples
===================================================
This script demonstrates how to use the XGBoostMultiLabelPredictor class
for multi-label classification with confidence scores.
Sample data is from real HPC workloads extracted from roofline_features.h5:
- TurTLE: Turbulence simulation (memory-bound, low arithmetic intensity ~0.84)
- SCALEXA: Scaling benchmarks (high bw-flops correlation ~0.995)
- Chroma: Lattice QCD (compute-intensive, high arithmetic intensity ~2.6)
"""
import json
from xgb_local import XGBoostMultiLabelPredictor
# ============================================================================
# Realistic Sample Data from roofline_features.h5
# ============================================================================
# TurTLE application - turbulence simulation workload
SAMPLE_TURTLE = {
"bandwidth_raw_p10": 186.33,
"bandwidth_raw_median": 205.14,
"bandwidth_raw_p90": 210.83,
"bandwidth_raw_mad": 3.57,
"bandwidth_raw_range": 24.5,
"bandwidth_raw_iqr": 12.075,
"flops_raw_p10": 162.024,
"flops_raw_median": 171.45,
"flops_raw_p90": 176.48,
"flops_raw_mad": 3.08,
"flops_raw_range": 14.456,
"flops_raw_iqr": 8.29,
"arith_intensity_p10": 0.7906,
"arith_intensity_median": 0.837,
"arith_intensity_p90": 0.9109,
"arith_intensity_mad": 0.02,
"arith_intensity_range": 0.12,
"arith_intensity_iqr": 0.0425,
"bw_flops_covariance": 60.86,
"bw_flops_correlation": 0.16,
"avg_performance_gflops": 168.1,
"median_performance_gflops": 171.45,
"performance_gflops_mad": 3.08,
"avg_memory_bw_gbs": 350.0,
"scalar_peak_gflops": 432.0,
"simd_peak_gflops": 9216.0,
"node_num": 0,
"duration": 19366,
}
# SCALEXA application - scaling benchmark workload
SAMPLE_SCALEXA = {
"bandwidth_raw_p10": 13.474,
"bandwidth_raw_median": 32.57,
"bandwidth_raw_p90": 51.466,
"bandwidth_raw_mad": 23.62,
"bandwidth_raw_range": 37.992,
"bandwidth_raw_iqr": 23.745,
"flops_raw_p10": 4.24,
"flops_raw_median": 16.16,
"flops_raw_p90": 24.584,
"flops_raw_mad": 10.53,
"flops_raw_range": 20.344,
"flops_raw_iqr": 12.715,
"arith_intensity_p10": 0.211,
"arith_intensity_median": 0.475,
"arith_intensity_p90": 0.492,
"arith_intensity_mad": 0.021,
"arith_intensity_range": 0.281,
"arith_intensity_iqr": 0.176,
"bw_flops_covariance": 302.0,
"bw_flops_correlation": 0.995,
"avg_performance_gflops": 14.7,
"median_performance_gflops": 16.16,
"performance_gflops_mad": 10.53,
"avg_memory_bw_gbs": 350.0,
"scalar_peak_gflops": 432.0,
"simd_peak_gflops": 9216.0,
"node_num": 18,
"duration": 165,
}
# Chroma application - lattice QCD workload (compute-intensive)
SAMPLE_CHROMA = {
"bandwidth_raw_p10": 154.176,
"bandwidth_raw_median": 200.57,
"bandwidth_raw_p90": 259.952,
"bandwidth_raw_mad": 5.12,
"bandwidth_raw_range": 105.776,
"bandwidth_raw_iqr": 10.215,
"flops_raw_p10": 327.966,
"flops_raw_median": 519.8,
"flops_raw_p90": 654.422,
"flops_raw_mad": 16.97,
"flops_raw_range": 326.456,
"flops_raw_iqr": 34.88,
"arith_intensity_p10": 1.55,
"arith_intensity_median": 2.595,
"arith_intensity_p90": 3.445,
"arith_intensity_mad": 0.254,
"arith_intensity_range": 1.894,
"arith_intensity_iqr": 0.512,
"bw_flops_covariance": 382.76,
"bw_flops_correlation": 0.063,
"avg_performance_gflops": 503.26,
"median_performance_gflops": 519.8,
"performance_gflops_mad": 16.97,
"avg_memory_bw_gbs": 350.0,
"scalar_peak_gflops": 432.0,
"simd_peak_gflops": 9216.0,
"node_num": 3,
"duration": 31133,
}
# Raw JSON roofline data (before aggregation, as produced by monitoring)
SAMPLE_JSON_ROOFLINE = json.dumps([
{"node_num": 1, "bandwidth_raw": 150.5, "flops_raw": 2500.0, "arith_intensity": 16.6,
"performance_gflops": 1200.0, "memory_bw_gbs": 450, "scalar_peak_gflops": 600, "duration": 3600},
{"node_num": 1, "bandwidth_raw": 155.2, "flops_raw": 2600.0, "arith_intensity": 16.8,
"performance_gflops": 1250.0, "memory_bw_gbs": 450, "scalar_peak_gflops": 600, "duration": 3600},
{"node_num": 1, "bandwidth_raw": 148.0, "flops_raw": 2450.0, "arith_intensity": 16.5,
"performance_gflops": 1180.0, "memory_bw_gbs": 450, "scalar_peak_gflops": 600, "duration": 3600},
])
def main():
print("XGBoost Multi-Label Inference Examples")
print("=" * 50)
# Initialize the predictor
predictor = XGBoostMultiLabelPredictor()
# =========================================================================
# Example 1: Single prediction with aggregated features
# =========================================================================
print("\n=== Example 1: Single Prediction (TurTLE workload) ===")
result = predictor.predict(SAMPLE_TURTLE, threshold=0.3)
print(f"Predictions: {result['predictions']}")
print(f"Confidences: {result['confidences']}")
print("\nTop 5 probabilities:")
sorted_probs = sorted(result['probabilities'].items(), key=lambda x: x[1], reverse=True)
for cls, prob in sorted_probs[:5]:
print(f" {cls}: {prob:.4f}")
# =========================================================================
# Example 2: Compare different workload types
# =========================================================================
print("\n=== Example 2: Compare Different Workloads ===")
workloads = [
("TurTLE (turbulence)", SAMPLE_TURTLE),
("SCALEXA (benchmark)", SAMPLE_SCALEXA),
("Chroma (lattice QCD)", SAMPLE_CHROMA),
]
for name, features in workloads:
result = predictor.predict(features, threshold=0.3)
top_pred = result['predictions'][0] if result['predictions'] else "None"
top_prob = max(result['probabilities'].values()) if result['probabilities'] else 0
print(f"{name:25} -> Top prediction: {top_pred:20} (prob: {top_prob:.4f})")
# =========================================================================
# Example 3: Top-K predictions
# =========================================================================
print("\n=== Example 3: Top-5 Predictions (Chroma workload) ===")
top_k_result = predictor.predict_top_k(SAMPLE_CHROMA, k=5)
for i, cls in enumerate(top_k_result['top_predictions'], 1):
prob = top_k_result['top_probabilities'][cls]
print(f" {i}. {cls}: {prob:.4f}")
# =========================================================================
# Example 4: Batch prediction
# =========================================================================
print("\n=== Example 4: Batch Prediction ===")
batch_data = [SAMPLE_TURTLE, SAMPLE_SCALEXA, SAMPLE_CHROMA]
batch_results = predictor.batch_predict(batch_data, threshold=0.3)
for i, result in enumerate(batch_results, 1):
if 'error' not in result:
preds = result['predictions'][:2] # Show top 2
print(f" Sample {i}: {preds}")
else:
print(f" Sample {i}: Error - {result['error']}")
# =========================================================================
# Example 5: Prediction from raw JSON roofline data
# =========================================================================
print("\n=== Example 5: Prediction from Raw JSON Data ===")
result = predictor.predict(
SAMPLE_JSON_ROOFLINE,
is_json=True,
job_id="example_job_001",
threshold=0.3
)
print(f"Predictions: {result['predictions'][:3]}")
print(f"(Aggregated from {len(json.loads(SAMPLE_JSON_ROOFLINE))} roofline samples)")
# =========================================================================
# Example 6: Model information
# =========================================================================
print("\n=== Example 6: Model Information ===")
info = predictor.get_class_info()
print(f"Number of classes: {info['n_classes']}")
print(f"Number of features: {info['n_features']}")
print(f"Sample classes: {info['classes'][:5]}...")
print(f"Sample features: {info['feature_columns'][:3]}...")
if __name__ == "__main__":
main()

BIN
xgb_model.joblib Normal file

Binary file not shown.