Initial commit
62
.gitignore
vendored
Normal file
@@ -0,0 +1,62 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# Virtual environments
venv/
ENV/
env/
.venv/

# IDE
.idea/
.vscode/
*.swp
*.swo
*~
.DS_Store

# Jupyter
.ipynb_checkpoints/

# Model files (optional - uncomment if you don't want to track large models)
# *.joblib
# *.pkl
# *.h5

# Logs
*.log
logs/

# Environment variables
.env
.env.local

# Testing
.pytest_cache/
.coverage
htmlcov/

# Go binaries
*.exe
*.exe~
*.dll
*.dylib
42
Makefile
Normal file
@@ -0,0 +1,42 @@
.PHONY: install run-fastapi run-fastapi-dev run-flask run-flask-debug test clean help

# Default Python interpreter
PYTHON := python3

# Default ports
FASTAPI_PORT := 8000
FLASK_PORT := 5000

help:
	@echo "XGBoost Multi-Label Classification API"
	@echo ""
	@echo "Usage:"
	@echo "  make install       Install Python dependencies"
	@echo "  make run-fastapi   Start FastAPI server (port $(FASTAPI_PORT))"
	@echo "  make run-flask     Start Flask server (port $(FLASK_PORT))"
	@echo "  make test          Run the inference example"
	@echo "  make clean         Clean up cache files"
	@echo ""

install:
	$(PYTHON) -m pip install -r requirements.txt

run-fastapi:
	$(PYTHON) xgb_fastapi.py --port $(FASTAPI_PORT)

run-fastapi-dev:
	$(PYTHON) xgb_fastapi.py --port $(FASTAPI_PORT) --reload

run-flask:
	$(PYTHON) xgb_rest_api.py --port $(FLASK_PORT)

run-flask-debug:
	$(PYTHON) xgb_rest_api.py --port $(FLASK_PORT) --debug

test:
	$(PYTHON) xgb_inference_example.py

clean:
	find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true
	find . -type f -name "*.pyc" -delete 2>/dev/null || true
	find . -type f -name "*.pyo" -delete 2>/dev/null || true
105
README.md
Normal file
@@ -0,0 +1,105 @@
# XGBoost Multi-Label Classification API

A REST API for multi-label classification of HPC workloads using XGBoost. It classifies applications based on their roofline performance metrics.

## Features

- **Multi-Label Classification**: Predict multiple labels with confidence scores
- **Top-K Predictions**: Get the K most likely predictions
- **Batch Prediction**: Process multiple samples in a single request
- **JSON Aggregation**: Aggregate raw roofline data into features automatically
- **Around 60 HPC Application Classes**: Including VASP, GROMACS, TurTLE, Chroma, QuantumESPRESSO, and more

## Installation

```bash
# Create and activate a virtual environment
python -m venv venv
source venv/bin/activate

# Install dependencies
pip install -r requirements.txt
```

## Quick Start

### Run Tests

```bash
# Python tests
pytest test_xgb_fastapi.py -v

# Curl tests (start the server first)
./test_api_curl.sh
```

## API Endpoints

| Endpoint | Method | Description |
|----------|--------|-------------|
| `/health` | GET | Health check and model status |
| `/predict` | POST | Single prediction with confidence scores |
| `/predict_top_k` | POST | Get top-K predictions |
| `/batch_predict` | POST | Batch prediction for multiple samples |
| `/model/info` | GET | Model information |

## Usage Examples

### Start the Server

```bash
python xgb_fastapi.py --port 8000
```

### Testing with curl

```bash
curl -X POST "http://localhost:8000/predict" \
  -H "Content-Type: application/json" \
  -d '{
    "features": {
      "bandwidth_raw_p10": 186.33,
      "bandwidth_raw_median": 205.14,
      "bandwidth_raw_p90": 210.83,
      "flops_raw_p10": 162.024,
      "flops_raw_median": 171.45,
      "flops_raw_p90": 176.48,
      "arith_intensity_median": 0.837,
      "node_num": 0,
      "duration": 19366
    },
    "threshold": 0.3
  }'
```

### Testing with Python

```python
from xgb_local import XGBoostMultiLabelPredictor

predictor = XGBoostMultiLabelPredictor('xgb_model.joblib')

result = predictor.predict(features, threshold=0.3)
print(f"Predictions: {result['predictions']}")
print(f"Confidences: {result['confidences']}")

# Top-K predictions
top_k = predictor.predict_top_k(features, k=5)
for cls, prob in top_k['top_probabilities'].items():
    print(f"{cls}: {prob:.4f}")
```

See `xgb_local_example.py` for complete examples.
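Conceptually, the `threshold` parameter above keeps every class whose predicted probability clears the cutoff. A minimal standalone sketch of that selection logic (an illustration with made-up probabilities, not the package's actual implementation):

```python
def select_labels(probabilities, threshold=0.3):
    """Keep classes whose probability meets the threshold,
    sorted by descending confidence."""
    kept = {cls: p for cls, p in probabilities.items() if p >= threshold}
    return dict(sorted(kept.items(), key=lambda kv: -kv[1]))

# Hypothetical per-class probabilities for one sample
probs = {"VASP": 0.82, "GROMACS": 0.41, "TurTLE": 0.05}
print(select_labels(probs))  # {'VASP': 0.82, 'GROMACS': 0.41}
```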

## Model Features (28 total)

| Category | Features |
|----------|----------|
| Bandwidth | `bandwidth_raw_p10`, `_median`, `_p90`, `_mad`, `_range`, `_iqr` |
| FLOPS | `flops_raw_p10`, `_median`, `_p90`, `_mad`, `_range`, `_iqr` |
| Arithmetic Intensity | `arith_intensity_p10`, `_median`, `_p90`, `_mad`, `_range`, `_iqr` |
| Performance | `avg_performance_gflops`, `median_performance_gflops`, `performance_gflops_mad` |
| Correlation | `bw_flops_covariance`, `bw_flops_correlation` |
| System | `avg_memory_bw_gbs`, `scalar_peak_gflops`, `simd_peak_gflops` |
| Other | `node_num`, `duration` |
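The 28 feature names in the table can be reconstructed programmatically, which doubles as a count check; a minimal sketch (the expansion of shorthand suffixes like `_median` is assumed to follow the first column of each row):

```python
# Expand the feature table into the full list of 28 column names.
STATS = ["p10", "median", "p90", "mad", "range", "iqr"]

feature_names = (
    [f"bandwidth_raw_{s}" for s in STATS]
    + [f"flops_raw_{s}" for s in STATS]
    + [f"arith_intensity_{s}" for s in STATS]
    + ["avg_performance_gflops", "median_performance_gflops", "performance_gflops_mad"]
    + ["bw_flops_covariance", "bw_flops_correlation"]
    + ["avg_memory_bw_gbs", "scalar_peak_gflops", "simd_peak_gflops"]
    + ["node_num", "duration"]
)

print(len(feature_names))  # 28
```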
600
cluster.json
Executable file
@@ -0,0 +1,600 @@
{
  "name": "fritz",
  "metricConfig": [
    {
      "name": "cpu_load",
      "unit": { "base": "" },
      "scope": "node",
      "aggregation": "avg",
      "footprint": "avg",
      "timestep": 60,
      "peak": 72,
      "normal": 72,
      "caution": 36,
      "alert": 20,
      "subClusters": [
        { "name": "spr1tb", "footprint": "avg", "peak": 104, "normal": 104, "caution": 52, "alert": 20 },
        { "name": "spr2tb", "footprint": "avg", "peak": 104, "normal": 104, "caution": 52, "alert": 20 }
      ]
    },
    {
      "name": "cpu_user",
      "unit": { "base": "" },
      "scope": "hwthread",
      "aggregation": "avg",
      "timestep": 60,
      "peak": 100,
      "normal": 50,
      "caution": 20,
      "alert": 10
    },
    {
      "name": "mem_used",
      "unit": { "base": "B", "prefix": "G" },
      "scope": "node",
      "aggregation": "sum",
      "footprint": "max",
      "timestep": 60,
      "peak": 256,
      "normal": 128,
      "caution": 200,
      "alert": 240,
      "subClusters": [
        { "name": "spr1tb", "footprint": "max", "peak": 1024, "normal": 512, "caution": 900, "alert": 1000 },
        { "name": "spr2tb", "footprint": "max", "peak": 2048, "normal": 1024, "caution": 1800, "alert": 2000 }
      ]
    },
    {
      "name": "flops_any",
      "unit": { "base": "Flops/s", "prefix": "G" },
      "scope": "hwthread",
      "aggregation": "sum",
      "footprint": "avg",
      "timestep": 60,
      "peak": 5600,
      "normal": 1000,
      "caution": 200,
      "alert": 50,
      "subClusters": [
        { "name": "spr1tb", "footprint": "avg", "peak": 6656, "normal": 1500, "caution": 400, "alert": 50 },
        { "name": "spr2tb", "footprint": "avg", "peak": 6656, "normal": 1500, "caution": 400, "alert": 50 }
      ]
    },
    {
      "name": "flops_sp",
      "unit": { "base": "Flops/s", "prefix": "G" },
      "scope": "hwthread",
      "aggregation": "sum",
      "timestep": 60,
      "peak": 5600,
      "normal": 1000,
      "caution": 200,
      "alert": 50,
      "subClusters": [
        { "name": "spr1tb", "peak": 6656, "normal": 1500, "caution": 400, "alert": 50 },
        { "name": "spr2tb", "peak": 6656, "normal": 1500, "caution": 400, "alert": 50 }
      ]
    },
    {
      "name": "flops_dp",
      "unit": { "base": "Flops/s", "prefix": "G" },
      "scope": "hwthread",
      "aggregation": "sum",
      "timestep": 60,
      "peak": 2300,
      "normal": 500,
      "caution": 100,
      "alert": 50,
      "subClusters": [
        { "name": "spr1tb", "peak": 3300, "normal": 750, "caution": 200, "alert": 50 },
        { "name": "spr2tb", "peak": 3300, "normal": 750, "caution": 200, "alert": 50 }
      ]
    },
    {
      "name": "mem_bw",
      "unit": { "base": "B/s", "prefix": "G" },
      "scope": "socket",
      "aggregation": "sum",
      "footprint": "avg",
      "timestep": 60,
      "peak": 350,
      "normal": 100,
      "caution": 50,
      "alert": 10,
      "subClusters": [
        { "name": "spr1tb", "footprint": "avg", "peak": 549, "normal": 200, "caution": 100, "alert": 20 },
        { "name": "spr2tb", "footprint": "avg", "peak": 520, "normal": 200, "caution": 100, "alert": 20 }
      ]
    },
    {
      "name": "clock",
      "unit": { "base": "Hz", "prefix": "M" },
      "scope": "hwthread",
      "aggregation": "avg",
      "timestep": 60,
      "peak": 3000,
      "normal": 2400,
      "caution": 1800,
      "alert": 1200,
      "subClusters": [
        { "name": "spr1tb", "peak": 3000, "normal": 2000, "caution": 1600, "alert": 1200, "remove": false },
        { "name": "spr2tb", "peak": 3000, "normal": 2000, "caution": 1600, "alert": 1200, "remove": false }
      ]
    },
    {
      "name": "cpu_power",
      "unit": { "base": "W" },
      "scope": "socket",
      "aggregation": "sum",
      "energy": "power",
      "timestep": 60,
      "peak": 500,
      "normal": 250,
      "caution": 100,
      "alert": 50,
      "subClusters": [
        { "name": "spr1tb", "energy": "power", "peak": 700, "normal": 350, "caution": 150, "alert": 50 },
        { "name": "spr2tb", "energy": "power", "peak": 700, "normal": 350, "caution": 150, "alert": 50 }
      ]
    },
    {
      "name": "mem_power",
      "unit": { "base": "W" },
      "scope": "socket",
      "aggregation": "sum",
      "energy": "power",
      "timestep": 60,
      "peak": 100,
      "normal": 50,
      "caution": 20,
      "alert": 10,
      "subClusters": [
        { "name": "spr1tb", "energy": "power", "peak": 400, "normal": 200, "caution": 80, "alert": 40 },
        { "name": "spr2tb", "energy": "power", "peak": 800, "normal": 400, "caution": 160, "alert": 80 }
      ]
    },
    {
      "name": "ipc",
      "unit": { "base": "IPC" },
      "scope": "hwthread",
      "aggregation": "avg",
      "timestep": 60,
      "peak": 4,
      "normal": 2,
      "caution": 1,
      "alert": 0.5,
      "subClusters": [
        { "name": "spr1tb", "peak": 6, "normal": 2, "caution": 1, "alert": 0.5 },
        { "name": "spr2tb", "peak": 6, "normal": 2, "caution": 1, "alert": 0.5 }
      ]
    },
    {
      "name": "vectorization_ratio",
      "unit": { "base": "" },
      "scope": "hwthread",
      "aggregation": "avg",
      "timestep": 60,
      "peak": 100,
      "normal": 60,
      "caution": 40,
      "alert": 10,
      "subClusters": [
        { "name": "spr1tb", "peak": 100, "normal": 60, "caution": 40, "alert": 10 },
        { "name": "spr2tb", "peak": 100, "normal": 60, "caution": 40, "alert": 10 }
      ]
    },
    {
      "name": "ib_recv",
      "unit": { "base": "B/s" },
      "scope": "node",
      "aggregation": "sum",
      "timestep": 60,
      "peak": 1250000,
      "normal": 6000000,
      "caution": 200,
      "alert": 1
    },
    {
      "name": "ib_xmit",
      "unit": { "base": "B/s" },
      "scope": "node",
      "aggregation": "sum",
      "timestep": 60,
      "peak": 1250000,
      "normal": 6000000,
      "caution": 200,
      "alert": 1
    },
    {
      "name": "ib_recv_pkts",
      "unit": { "base": "packets/s" },
      "scope": "node",
      "aggregation": "sum",
      "timestep": 60,
      "peak": 6,
      "normal": 4,
      "caution": 2,
      "alert": 1
    },
    {
      "name": "ib_xmit_pkts",
      "unit": { "base": "packets/s" },
      "scope": "node",
      "aggregation": "sum",
      "timestep": 60,
      "peak": 6,
      "normal": 4,
      "caution": 2,
      "alert": 1
    },
    {
      "name": "nfs4_read",
      "unit": { "base": "IOP", "prefix": "" },
      "scope": "node",
      "aggregation": "sum",
      "timestep": 60,
      "peak": 1000,
      "normal": 50,
      "caution": 200,
      "alert": 500
    },
    {
      "name": "nfs4_total",
      "unit": { "base": "IOP", "prefix": "" },
      "scope": "node",
      "aggregation": "sum",
      "timestep": 60,
      "peak": 1000,
      "normal": 50,
      "caution": 200,
      "alert": 500
    }
  ],
  "subClusters": [
    {
      "name": "main",
      "nodes": "f[0101-0188,0201-0288,0301-0388,0401-0488,0501-0588,0601-0688,0701-0788,0801-0888,0901-0988,1001-1088,1101-1156,1201-1256]",
      "processorType": "Intel Icelake",
      "socketsPerNode": 2,
      "coresPerSocket": 36,
      "threadsPerCore": 1,
      "flopRateScalar": { "unit": { "base": "F/s", "prefix": "G" }, "value": 432 },
      "flopRateSimd": { "unit": { "base": "F/s", "prefix": "G" }, "value": 9216 },
      "memoryBandwidth": { "unit": { "base": "B/s", "prefix": "G" }, "value": 350 },
      "topology": {
        "node": [
          0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71
        ],
        "socket": [
          [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35 ],
          [ 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71 ]
        ],
        "memoryDomain": [
          [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17 ],
          [ 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35 ],
          [ 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53 ],
          [ 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71 ]
        ],
        "core": [
          [ 0 ], [ 1 ], [ 2 ], [ 3 ], [ 4 ], [ 5 ], [ 6 ], [ 7 ], [ 8 ], [ 9 ], [ 10 ], [ 11 ], [ 12 ], [ 13 ], [ 14 ], [ 15 ], [ 16 ], [ 17 ], [ 18 ], [ 19 ], [ 20 ], [ 21 ], [ 22 ], [ 23 ], [ 24 ], [ 25 ], [ 26 ], [ 27 ], [ 28 ], [ 29 ], [ 30 ], [ 31 ], [ 32 ], [ 33 ], [ 34 ], [ 35 ], [ 36 ], [ 37 ], [ 38 ], [ 39 ], [ 40 ], [ 41 ], [ 42 ], [ 43 ], [ 44 ], [ 45 ], [ 46 ], [ 47 ], [ 48 ], [ 49 ], [ 50 ], [ 51 ], [ 52 ], [ 53 ], [ 54 ], [ 55 ], [ 56 ], [ 57 ], [ 58 ], [ 59 ], [ 60 ], [ 61 ], [ 62 ], [ 63 ], [ 64 ], [ 65 ], [ 66 ], [ 67 ], [ 68 ], [ 69 ], [ 70 ], [ 71 ]
        ]
      }
    },
    {
      "name": "spr1tb",
      "processorType": "Intel(R) Xeon(R) Platinum 8470",
      "socketsPerNode": 2,
      "coresPerSocket": 52,
      "threadsPerCore": 1,
      "flopRateScalar": { "unit": { "base": "F/s", "prefix": "G" }, "value": 695 },
      "flopRateSimd": { "unit": { "base": "F/s", "prefix": "G" }, "value": 9216 },
      "memoryBandwidth": { "unit": { "base": "B/s", "prefix": "G" }, "value": 549 },
      "nodes": "f[2157-2180,2257-2280]",
      "topology": {
        "node": [
          0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103
        ],
        "socket": [
          [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51 ],
          [ 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103 ]
        ],
        "memoryDomain": [
          [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 ],
          [ 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25 ],
          [ 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38 ],
          [ 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51 ],
          [ 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64 ],
          [ 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77 ],
          [ 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90 ],
          [ 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103 ]
        ],
        "core": [
          [ 0 ], [ 1 ], [ 2 ], [ 3 ], [ 4 ], [ 5 ], [ 6 ], [ 7 ], [ 8 ], [ 9 ], [ 10 ], [ 11 ], [ 12 ], [ 13 ], [ 14 ], [ 15 ], [ 16 ], [ 17 ], [ 18 ], [ 19 ], [ 20 ], [ 21 ], [ 22 ], [ 23 ], [ 24 ], [ 25 ], [ 26 ], [ 27 ], [ 28 ], [ 29 ], [ 30 ], [ 31 ], [ 32 ], [ 33 ], [ 34 ], [ 35 ], [ 36 ], [ 37 ], [ 38 ], [ 39 ], [ 40 ], [ 41 ], [ 42 ], [ 43 ], [ 44 ], [ 45 ], [ 46 ], [ 47 ], [ 48 ], [ 49 ], [ 50 ], [ 51 ], [ 52 ], [ 53 ], [ 54 ], [ 55 ], [ 56 ], [ 57 ], [ 58 ], [ 59 ], [ 60 ], [ 61 ], [ 62 ], [ 63 ], [ 64 ], [ 65 ], [ 66 ], [ 67 ], [ 68 ], [ 69 ], [ 70 ], [ 71 ], [ 72 ], [ 73 ], [ 74 ], [ 75 ], [ 76 ], [ 77 ], [ 78 ], [ 79 ], [ 80 ], [ 81 ], [ 82 ], [ 83 ], [ 84 ], [ 85 ], [ 86 ], [ 87 ], [ 88 ], [ 89 ], [ 90 ], [ 91 ], [ 92 ], [ 93 ], [ 94 ], [ 95 ], [ 96 ], [ 97 ], [ 98 ], [ 99 ], [ 100 ], [ 101 ], [ 102 ], [ 103 ]
        ]
      }
    },
    {
      "name": "spr2tb",
      "processorType": "Intel(R) Xeon(R) Platinum 8470",
      "socketsPerNode": 2,
      "coresPerSocket": 52,
      "threadsPerCore": 1,
      "flopRateScalar": { "unit": { "base": "F/s", "prefix": "G" }, "value": 695 },
      "flopRateSimd": { "unit": { "base": "F/s", "prefix": "G" }, "value": 9216 },
      "memoryBandwidth": { "unit": { "base": "B/s", "prefix": "G" }, "value": 515 },
      "nodes": "f[2181-2188,2281-2288]",
      "topology": {
        "node": [
          0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103
        ],
        "socket": [
          [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51 ],
          [ 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103 ]
        ],
        "memoryDomain": [
          [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 ],
          [ 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25 ],
          [ 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38 ],
          [ 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51 ],
          [ 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64 ],
          [ 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77 ],
          [ 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90 ],
          [ 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103 ]
        ],
        "core": [
          [ 0 ], [ 1 ], [ 2 ], [ 3 ], [ 4 ], [ 5 ], [ 6 ], [ 7 ], [ 8 ], [ 9 ], [ 10 ], [ 11 ], [ 12 ], [ 13 ], [ 14 ], [ 15 ], [ 16 ], [ 17 ], [ 18 ], [ 19 ], [ 20 ], [ 21 ], [ 22 ], [ 23 ], [ 24 ], [ 25 ], [ 26 ], [ 27 ], [ 28 ], [ 29 ], [ 30 ], [ 31 ], [ 32 ], [ 33 ], [ 34 ], [ 35 ], [ 36 ], [ 37 ], [ 38 ], [ 39 ], [ 40 ], [ 41 ], [ 42 ], [ 43 ], [ 44 ], [ 45 ], [ 46 ], [ 47 ], [ 48 ], [ 49 ], [ 50 ], [ 51 ], [ 52 ], [ 53 ], [ 54 ], [ 55 ], [ 56 ], [ 57 ], [ 58 ], [ 59 ], [ 60 ], [ 61 ], [ 62 ], [ 63 ], [ 64 ], [ 65 ], [ 66 ], [ 67 ], [ 68 ], [ 69 ], [ 70 ], [ 71 ], [ 72 ], [ 73 ], [ 74 ], [ 75 ], [ 76 ], [ 77 ], [ 78 ], [ 79 ], [ 80 ], [ 81 ], [ 82 ], [ 83 ], [ 84 ], [ 85 ], [ 86 ], [ 87 ], [ 88 ], [ 89 ], [ 90 ], [ 91 ], [ 92 ], [ 93 ], [ 94 ], [ 95 ], [ 96 ], [ 97 ], [ 98 ], [ 99 ], [ 100 ], [ 101 ], [ 102 ], [ 103 ]
        ]
      }
    }
  ]
}
172
feature_aggregator.py
Normal file
@@ -0,0 +1,172 @@
import os
import sqlite3
from glob import glob

import h5py
import numpy as np
import pandas as pd
from scipy import stats


def compute_mad(data):
    """Compute the Median Absolute Deviation: a robust measure of dispersion."""
    return np.median(np.abs(data - np.median(data)))


def process_roofline_data(base_dir='D:/roofline_dataframes',
                          output_file='D:/roofline_features.h5',
                          job_tags_db='job_tags.db'):
    """
    Process roofline data to extract features for machine learning.

    Args:
        base_dir: Directory containing roofline dataframes
        output_file: Path to save the output features
        job_tags_db: Path to SQLite database with job tags
    """
    # Connect to the SQLite database
    conn = sqlite3.connect(job_tags_db)
    cursor = conn.cursor()

    # List to store all job features
    all_job_features = []

    # Find all job prefix folders in the base directory
    job_prefixes = [d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d))]

    for job_id_prefix in job_prefixes:
        job_prefix_path = os.path.join(base_dir, job_id_prefix)

        # Find all h5 files in this job prefix folder
        h5_files = glob(os.path.join(job_prefix_path, '*_dataframe.h5'))

        for h5_file in h5_files:
            filename = os.path.basename(h5_file)
            # Extract job_id_full from the filename pattern: {job_id_full}_{timestamp}_dataframe.h5
            job_id_full = filename.split('_')[0]

            try:
                # Make sure the file actually contains a dataframe before reading it
                with h5py.File(h5_file, 'r') as f:
                    if 'dataframe' not in f:
                        print(f"Warning: No 'dataframe' key in {h5_file}")
                        continue

                df = pd.read_hdf(h5_file, key='dataframe')

                # Group data by node_num
                grouped = df.groupby('node_num')

                for node_num, group in grouped:
                    features = {
                        'job_id': job_id_full,
                        'node_num': node_num
                    }

                    # Compute statistics for key metrics
                    for axis in ['bandwidth_raw', 'flops_raw', 'arith_intensity']:
                        data = group[axis].values

                        # Percentiles
                        p10 = np.percentile(data, 10)
                        p50 = np.median(data)
                        p90 = np.percentile(data, 90)

                        # MAD (more robust than variance)
                        mad = compute_mad(data)

                        # Store features
                        features[f'{axis}_p10'] = p10
                        features[f'{axis}_median'] = p50
                        features[f'{axis}_p90'] = p90
                        features[f'{axis}_mad'] = mad
                        features[f'{axis}_range'] = p90 - p10
                        features[f'{axis}_iqr'] = np.percentile(data, 75) - np.percentile(data, 25)

                    # Covariance and correlation between bandwidth_raw and flops_raw
                    if len(group) > 1:  # Need at least 2 points for correlation
                        cov = np.cov(group['bandwidth_raw'], group['flops_raw'])[0, 1]
                        features['bw_flops_covariance'] = cov

                        corr, _ = stats.pearsonr(group['bandwidth_raw'], group['flops_raw'])
                        features['bw_flops_correlation'] = corr

                    # Performance metrics
                    features['avg_performance_gflops'] = group['performance_gflops'].mean()
                    features['median_performance_gflops'] = group['performance_gflops'].median()
                    features['performance_gflops_mad'] = compute_mad(group['performance_gflops'].values)

                    # # Efficiency metrics
                    # features['avg_efficiency'] = group['efficiency'].mean()
                    # features['median_efficiency'] = group['efficiency'].median()
                    # features['efficiency_mad'] = compute_mad(group['efficiency'].values)
                    # features['efficiency_p10'] = np.percentile(group['efficiency'].values, 10)
                    # features['efficiency_p90'] = np.percentile(group['efficiency'].values, 90)

                    # # Distribution of roofline regions (memory-bound vs compute-bound)
                    # if 'roofline_region' in group.columns:
                    #     region_counts = group['roofline_region'].value_counts(normalize=True).to_dict()
                    #     for region, ratio in region_counts.items():
                    #         features[f'region_{region}_ratio'] = ratio

                    # System characteristics
                    if 'memory_bw_gbs' in group.columns:
                        features['avg_memory_bw_gbs'] = group['memory_bw_gbs'].mean()
                    if 'scalar_peak_gflops' in group.columns and len(group['scalar_peak_gflops'].unique()) > 0:
                        features['scalar_peak_gflops'] = group['scalar_peak_gflops'].iloc[0]
                    if 'simd_peak_gflops' in group.columns and len(group['simd_peak_gflops'].unique()) > 0:
                        features['simd_peak_gflops'] = group['simd_peak_gflops'].iloc[0]

                    # # Subcluster information if available
                    # if 'subcluster_name' in group.columns and not group['subcluster_name'].isna().all():
                    #     features['subcluster_name'] = group['subcluster_name'].iloc[0]

                    # Duration information
                    if 'duration' in group.columns:
                        features['duration'] = group['duration'].iloc[0]

                    # Get the label (application type) from the database
                    cursor.execute("SELECT tags FROM job_tags WHERE job_id = ?", (int(job_id_full),))
                    result = cursor.fetchone()
                    features['label'] = result[0] if result else 'Unknown'

                    all_job_features.append(features)

            except Exception as e:
                print(f"Error processing file {h5_file}: {e}")

    # Close the database connection
    conn.close()

    if not all_job_features:
        print("No features extracted. Check if files exist and have the correct format.")
        return

    # Convert to DataFrame
    features_df = pd.DataFrame(all_job_features)

    # Fill missing roofline region ratios with 0
    region_columns = [col for col in features_df.columns if col.startswith('region_')]
    for col in region_columns:
        features_df[col] = features_df[col].fillna(0)

    # Save to H5 file
    features_df.to_hdf(output_file, key='features', mode='w')

    print(f"Processed {len(all_job_features)} job-node combinations")
    print(f"Features saved to {output_file}")
    print(f"Feature columns: {', '.join(features_df.columns)}")

    return features_df


if __name__ == "__main__":
    process_roofline_data()
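The `compute_mad` helper in `feature_aggregator.py` above is the robust dispersion measure behind all the `_mad` features; a quick standalone illustration of why it is preferred over the standard deviation (sample numbers are made up):

```python
import numpy as np

def compute_mad(data):
    """Median Absolute Deviation, as defined in feature_aggregator.py."""
    return np.median(np.abs(data - np.median(data)))

# One outlier barely moves the MAD but dominates the standard deviation.
x = np.array([1.0, 2.0, 3.0, 4.0, 100.0])
print(compute_mad(x))  # 1.0
print(np.std(x))       # roughly 39, dominated by the outlier
```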
20
requirements.txt
Normal file
@@ -0,0 +1,20 @@
# Core ML dependencies
numpy>=1.21.0
pandas>=1.3.0
scipy>=1.7.0
joblib>=1.0.0
xgboost>=1.5.0
scikit-learn>=1.0.0

# FastAPI server
fastapi>=0.68.0
uvicorn>=0.15.0
pydantic>=1.8.0

# Flask server (alternative)
flask>=2.0.0
flask-cors>=3.0.0

# Testing
pytest>=7.0.0
httpx>=0.23.0
436
test_api_curl.sh
Executable file
@@ -0,0 +1,436 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# Test script for XGBoost FastAPI Multi-Label Classification API
|
||||
# Uses curl to test all endpoints with realistic sample data from roofline_features.h5
|
||||
#
|
||||
# Usage:
|
||||
# 1. Start the server: ./venv/bin/python xgb_fastapi.py
|
||||
# 2. Run this script: ./test_api_curl.sh
|
||||
#
|
||||
# Optional: Set API_URL environment variable to test a different host
|
||||
# API_URL=http://192.168.1.100:8000 ./test_api_curl.sh
|
||||
|
||||
set -e
|
||||
|
||||
# Configuration
|
||||
API_URL="${API_URL:-http://localhost:8000}"
|
||||
VERBOSE="${VERBOSE:-false}"
|
||||
|
||||
# Colors for output
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
YELLOW='\033[1;33m'
|
||||
BLUE='\033[0;34m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
# Counters
|
||||
PASSED=0
|
||||
FAILED=0
|
||||
|
||||
# ============================================================================
|
||||
# Helper Functions
|
||||
# ============================================================================
|
||||
|
||||
print_header() {
|
||||
echo ""
|
||||
echo -e "${BLUE}============================================================${NC}"
|
||||
echo -e "${BLUE}$1${NC}"
|
||||
echo -e "${BLUE}============================================================${NC}"
|
||||
}
|
||||
|
||||
print_test() {
|
||||
echo -e "\n${YELLOW}▶ TEST: $1${NC}"
|
||||
}
|
||||
|
||||
print_success() {
|
||||
echo -e "${GREEN}✓ PASSED: $1${NC}"
|
||||
((PASSED++))
|
||||
}
|
||||
|
||||
print_failure() {
|
||||
echo -e "${RED}✗ FAILED: $1${NC}"
|
||||
((FAILED++))
|
||||
}
|
||||
|
||||
# Make a curl request and check the status code
# Usage: make_request METHOD ENDPOINT [DATA] [EXPECTED_STATUS]
make_request() {
    local method=$1
    local endpoint=$2
    local data=$3
    local expected_status=${4:-200}

    if [ "$method" == "GET" ]; then
        response=$(curl -s -w "\n%{http_code}" "${API_URL}${endpoint}")
    else
        response=$(curl -s -w "\n%{http_code}" -X "$method" \
            -H "Content-Type: application/json" \
            -d "$data" \
            "${API_URL}${endpoint}")
    fi

    # Extract status code (last line) and body (everything else)
    http_code=$(echo "$response" | tail -n1)
    body=$(echo "$response" | sed '$d')

    if [ "$VERBOSE" == "true" ]; then
        echo "Response: $body"
        echo "Status: $http_code"
    fi

    if [ "$http_code" == "$expected_status" ]; then
        return 0
    else
        echo "Expected status $expected_status, got $http_code"
        echo "Response: $body"
        return 1
    fi
}

# ============================================================================
# Sample Data (from roofline_features.h5)
# ============================================================================

# TurTLE application - turbulence simulation workload
SAMPLE_TURTLE='{
    "bandwidth_raw_p10": 186.33,
    "bandwidth_raw_median": 205.14,
    "bandwidth_raw_p90": 210.83,
    "bandwidth_raw_mad": 3.57,
    "bandwidth_raw_range": 24.5,
    "bandwidth_raw_iqr": 12.075,
    "flops_raw_p10": 162.024,
    "flops_raw_median": 171.45,
    "flops_raw_p90": 176.48,
    "flops_raw_mad": 3.08,
    "flops_raw_range": 14.456,
    "flops_raw_iqr": 8.29,
    "arith_intensity_p10": 0.7906,
    "arith_intensity_median": 0.837,
    "arith_intensity_p90": 0.9109,
    "arith_intensity_mad": 0.02,
    "arith_intensity_range": 0.12,
    "arith_intensity_iqr": 0.0425,
    "bw_flops_covariance": 60.86,
    "bw_flops_correlation": 0.16,
    "avg_performance_gflops": 168.1,
    "median_performance_gflops": 171.45,
    "performance_gflops_mad": 3.08,
    "avg_memory_bw_gbs": 350.0,
    "scalar_peak_gflops": 432.0,
    "simd_peak_gflops": 9216.0,
    "node_num": 0,
    "duration": 19366
}'

# Chroma application - lattice QCD workload (compute-intensive)
SAMPLE_CHROMA='{
    "bandwidth_raw_p10": 154.176,
    "bandwidth_raw_median": 200.57,
    "bandwidth_raw_p90": 259.952,
    "bandwidth_raw_mad": 5.12,
    "bandwidth_raw_range": 105.776,
    "bandwidth_raw_iqr": 10.215,
    "flops_raw_p10": 327.966,
    "flops_raw_median": 519.8,
    "flops_raw_p90": 654.422,
    "flops_raw_mad": 16.97,
    "flops_raw_range": 326.456,
    "flops_raw_iqr": 34.88,
    "arith_intensity_p10": 1.55,
    "arith_intensity_median": 2.595,
    "arith_intensity_p90": 3.445,
    "arith_intensity_mad": 0.254,
    "arith_intensity_range": 1.894,
    "arith_intensity_iqr": 0.512,
    "bw_flops_covariance": 382.76,
    "bw_flops_correlation": 0.063,
    "avg_performance_gflops": 503.26,
    "median_performance_gflops": 519.8,
    "performance_gflops_mad": 16.97,
    "avg_memory_bw_gbs": 350.0,
    "scalar_peak_gflops": 432.0,
    "simd_peak_gflops": 9216.0,
    "node_num": 3,
    "duration": 31133
}'

# Raw JSON roofline data (before aggregation)
SAMPLE_JSON_ROOFLINE='[
    {"node_num": 1, "bandwidth_raw": 150.5, "flops_raw": 2500.0, "arith_intensity": 16.6, "performance_gflops": 1200.0, "memory_bw_gbs": 450, "scalar_peak_gflops": 600, "duration": 3600},
    {"node_num": 1, "bandwidth_raw": 155.2, "flops_raw": 2600.0, "arith_intensity": 16.8, "performance_gflops": 1250.0, "memory_bw_gbs": 450, "scalar_peak_gflops": 600, "duration": 3600},
    {"node_num": 1, "bandwidth_raw": 148.0, "flops_raw": 2450.0, "arith_intensity": 16.5, "performance_gflops": 1180.0, "memory_bw_gbs": 450, "scalar_peak_gflops": 600, "duration": 3600}
]'

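The raw roofline records above are collapsed into summary features (medians, percentiles, MADs, and so on) before they reach the model; the script delegates that to the server via `is_json: true`. As a rough illustration only — the real aggregation lives in `feature_aggregator.py` and computes many more statistics — a minimal sketch of the idea:

```python
import statistics


def aggregate(records: list[dict]) -> dict:
    """Hypothetical sketch: reduce per-measurement roofline records to a
    couple of the summary features the API expects (medians only here)."""
    bw = [r["bandwidth_raw"] for r in records]
    fl = [r["flops_raw"] for r in records]
    return {
        "bandwidth_raw_median": statistics.median(bw),
        "flops_raw_median": statistics.median(fl),
    }


# Same three measurements as SAMPLE_JSON_ROOFLINE, trimmed to two fields
rows = [
    {"bandwidth_raw": 150.5, "flops_raw": 2500.0},
    {"bandwidth_raw": 155.2, "flops_raw": 2600.0},
    {"bandwidth_raw": 148.0, "flops_raw": 2450.0},
]
features = aggregate(rows)
```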
# ============================================================================
# Tests
# ============================================================================

print_header "XGBoost FastAPI Test Suite"
echo "Testing API at: $API_URL"
echo "Started at: $(date)"

# ----------------------------------------------------------------------------
# Health & Info Endpoints
# ----------------------------------------------------------------------------

print_header "Health & Info Endpoints"

print_test "Root Endpoint"
if make_request GET "/"; then
    print_success "Root endpoint accessible"
else
    print_failure "Root endpoint failed"
fi

print_test "Health Check"
if make_request GET "/health"; then
    print_success "Health check passed"
else
    print_failure "Health check failed"
fi

print_test "Model Info"
if make_request GET "/model/info"; then
    print_success "Model info retrieved"
else
    print_failure "Model info failed"
fi

# ----------------------------------------------------------------------------
# Single Prediction Endpoints
# ----------------------------------------------------------------------------

print_header "Single Prediction Endpoint (/predict)"

print_test "Predict with TurTLE sample (threshold=0.3)"
REQUEST_DATA=$(cat <<EOF
{
    "features": $SAMPLE_TURTLE,
    "threshold": 0.3,
    "return_all_probabilities": true
}
EOF
)
if make_request POST "/predict" "$REQUEST_DATA"; then
    print_success "TurTLE prediction completed"
else
    print_failure "TurTLE prediction failed"
fi

print_test "Predict with Chroma sample (threshold=0.5)"
REQUEST_DATA=$(cat <<EOF
{
    "features": $SAMPLE_CHROMA,
    "threshold": 0.5
}
EOF
)
if make_request POST "/predict" "$REQUEST_DATA"; then
    print_success "Chroma prediction completed"
else
    print_failure "Chroma prediction failed"
fi

print_test "Predict with JSON roofline data (is_json=true)"
# Need to escape the JSON for embedding
ESCAPED_JSON=$(echo "$SAMPLE_JSON_ROOFLINE" | tr -d '\n' | sed 's/"/\\"/g')
REQUEST_DATA=$(cat <<EOF
{
    "features": "$ESCAPED_JSON",
    "is_json": true,
    "job_id": "test_job_curl_001",
    "threshold": 0.3
}
EOF
)
if make_request POST "/predict" "$REQUEST_DATA"; then
    print_success "JSON roofline prediction completed"
else
    print_failure "JSON roofline prediction failed"
fi

print_test "Predict with return_all_probabilities=false"
REQUEST_DATA=$(cat <<EOF
{
    "features": $SAMPLE_TURTLE,
    "threshold": 0.3,
    "return_all_probabilities": false
}
EOF
)
if make_request POST "/predict" "$REQUEST_DATA"; then
    print_success "Prediction with filtered probabilities completed"
else
    print_failure "Prediction with filtered probabilities failed"
fi

# ----------------------------------------------------------------------------
# Top-K Prediction Endpoints
# ----------------------------------------------------------------------------

print_header "Top-K Prediction Endpoint (/predict_top_k)"

print_test "Top-5 predictions (default)"
REQUEST_DATA=$(cat <<EOF
{
    "features": $SAMPLE_TURTLE
}
EOF
)
if make_request POST "/predict_top_k" "$REQUEST_DATA"; then
    print_success "Top-5 prediction completed"
else
    print_failure "Top-5 prediction failed"
fi

print_test "Top-10 predictions"
REQUEST_DATA=$(cat <<EOF
{
    "features": $SAMPLE_CHROMA,
    "k": 10
}
EOF
)
if make_request POST "/predict_top_k" "$REQUEST_DATA"; then
    print_success "Top-10 prediction completed"
else
    print_failure "Top-10 prediction failed"
fi

print_test "Top-3 predictions"
REQUEST_DATA=$(cat <<EOF
{
    "features": $SAMPLE_TURTLE,
    "k": 3
}
EOF
)
if make_request POST "/predict_top_k" "$REQUEST_DATA"; then
    print_success "Top-3 prediction completed"
else
    print_failure "Top-3 prediction failed"
fi

# ----------------------------------------------------------------------------
# Batch Prediction Endpoints
# ----------------------------------------------------------------------------

print_header "Batch Prediction Endpoint (/batch_predict)"

print_test "Batch predict with 2 samples"
REQUEST_DATA=$(cat <<EOF
{
    "features_list": [$SAMPLE_TURTLE, $SAMPLE_CHROMA],
    "threshold": 0.3
}
EOF
)
if make_request POST "/batch_predict" "$REQUEST_DATA"; then
    print_success "Batch prediction (2 samples) completed"
else
    print_failure "Batch prediction (2 samples) failed"
fi

print_test "Batch predict with single sample"
REQUEST_DATA=$(cat <<EOF
{
    "features_list": [$SAMPLE_TURTLE],
    "threshold": 0.5
}
EOF
)
if make_request POST "/batch_predict" "$REQUEST_DATA"; then
    print_success "Batch prediction (single sample) completed"
else
    print_failure "Batch prediction (single sample) failed"
fi

print_test "Batch predict with empty list"
REQUEST_DATA='{"features_list": [], "threshold": 0.5}'
if make_request POST "/batch_predict" "$REQUEST_DATA"; then
    print_success "Batch prediction (empty list) completed"
else
    print_failure "Batch prediction (empty list) failed"
fi

# ----------------------------------------------------------------------------
# Error Handling Tests
# ----------------------------------------------------------------------------

print_header "Error Handling Tests"

print_test "Invalid threshold (negative) - should return 422"
REQUEST_DATA=$(cat <<EOF
{
    "features": $SAMPLE_TURTLE,
    "threshold": -0.5
}
EOF
)
if make_request POST "/predict" "$REQUEST_DATA" 422; then
    print_success "Invalid threshold correctly rejected"
else
    print_failure "Invalid threshold not rejected properly"
fi

print_test "Invalid threshold (> 1.0) - should return 422"
REQUEST_DATA=$(cat <<EOF
{
    "features": $SAMPLE_TURTLE,
    "threshold": 1.5
}
EOF
)
if make_request POST "/predict" "$REQUEST_DATA" 422; then
    print_success "Invalid threshold (>1) correctly rejected"
else
    print_failure "Invalid threshold (>1) not rejected properly"
fi

print_test "Missing features field - should return 422"
REQUEST_DATA='{"threshold": 0.5}'
if make_request POST "/predict" "$REQUEST_DATA" 422; then
    print_success "Missing features correctly rejected"
else
    print_failure "Missing features not rejected properly"
fi

print_test "Invalid endpoint - should return 404"
if make_request GET "/nonexistent" "" 404; then
    print_success "Invalid endpoint correctly returns 404"
else
    print_failure "Invalid endpoint not handled properly"
fi

print_test "Invalid k value (0) - should return 422"
REQUEST_DATA=$(cat <<EOF
{
    "features": $SAMPLE_TURTLE,
    "k": 0
}
EOF
)
if make_request POST "/predict_top_k" "$REQUEST_DATA" 422; then
    print_success "Invalid k value correctly rejected"
else
    print_failure "Invalid k value not rejected properly"
fi

# ============================================================================
# Summary
# ============================================================================

print_header "Test Summary"
echo ""
echo -e "Passed: ${GREEN}$PASSED${NC}"
echo -e "Failed: ${RED}$FAILED${NC}"
echo ""

if [ "$FAILED" -eq 0 ]; then
    echo -e "${GREEN}All tests passed! ✓${NC}"
    exit 0
else
    echo -e "${RED}Some tests failed! ✗${NC}"
    exit 1
fi
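The `make_request` helper above relies on curl's `-w "\n%{http_code}"` write-out to append the HTTP status on its own line, then splits it back apart with `tail -n1` and `sed '$d'`. A minimal Python sketch of that same parsing step (the sample input string is made up for illustration):

```python
def split_response(raw: str) -> tuple[str, str]:
    # Mirrors the shell pair: http_code=$(... | tail -n1); body=$(... | sed '$d')
    lines = raw.split("\n")
    return "\n".join(lines[:-1]), lines[-1]


# Example: a one-line JSON body followed by the appended status code
body, status = split_response('{"status": "healthy"}\n200')
```

This works even for multi-line bodies, because only the final line (the write-out) is treated as the status code.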
811
test_xgb_fastapi.py
Normal file
@@ -0,0 +1,811 @@
#!/usr/bin/env python3
"""
Tests for XGBoost FastAPI Multi-Label Classification API

These tests use realistic sample data extracted from /Volumes/T7/roofline_features.h5
which was generated using feature_aggregator.py to process roofline dataframes.

Test data includes samples from different application types:
- TurTLE (turbulence simulation)
- SCALEXA (scaling benchmarks)
- Chroma (lattice QCD)
"""

import pytest
import json
from fastapi.testclient import TestClient
from unittest.mock import patch, MagicMock
import numpy as np
import os

# Import the FastAPI app and predictor
import xgb_fastapi
from xgb_fastapi import app


# ============================================================================
# Test Data: Realistic samples from roofline_features.h5
# Generated using feature_aggregator.py processing roofline dataframes
# ============================================================================

# Sample 1: TurTLE application - typical turbulence simulation workload
SAMPLE_TURTLE = {
    "bandwidth_raw_p10": 186.33,
    "bandwidth_raw_median": 205.14,
    "bandwidth_raw_p90": 210.83,
    "bandwidth_raw_mad": 3.57,
    "bandwidth_raw_range": 24.5,
    "bandwidth_raw_iqr": 12.075,
    "flops_raw_p10": 162.024,
    "flops_raw_median": 171.45,
    "flops_raw_p90": 176.48,
    "flops_raw_mad": 3.08,
    "flops_raw_range": 14.456,
    "flops_raw_iqr": 8.29,
    "arith_intensity_p10": 0.7906,
    "arith_intensity_median": 0.837,
    "arith_intensity_p90": 0.9109,
    "arith_intensity_mad": 0.02,
    "arith_intensity_range": 0.12,
    "arith_intensity_iqr": 0.0425,
    "bw_flops_covariance": 60.86,
    "bw_flops_correlation": 0.16,
    "avg_performance_gflops": 168.1,
    "median_performance_gflops": 171.45,
    "performance_gflops_mad": 3.08,
    "avg_memory_bw_gbs": 350.0,
    "scalar_peak_gflops": 432.0,
    "simd_peak_gflops": 9216.0,
    "node_num": 0,
    "duration": 19366,
}

# Sample 2: SCALEXA application - scaling benchmark workload
SAMPLE_SCALEXA = {
    "bandwidth_raw_p10": 13.474,
    "bandwidth_raw_median": 32.57,
    "bandwidth_raw_p90": 51.466,
    "bandwidth_raw_mad": 23.62,
    "bandwidth_raw_range": 37.992,
    "bandwidth_raw_iqr": 23.745,
    "flops_raw_p10": 4.24,
    "flops_raw_median": 16.16,
    "flops_raw_p90": 24.584,
    "flops_raw_mad": 10.53,
    "flops_raw_range": 20.344,
    "flops_raw_iqr": 12.715,
    "arith_intensity_p10": 0.211,
    "arith_intensity_median": 0.475,
    "arith_intensity_p90": 0.492,
    "arith_intensity_mad": 0.021,
    "arith_intensity_range": 0.281,
    "arith_intensity_iqr": 0.176,
    "bw_flops_covariance": 302.0,
    "bw_flops_correlation": 0.995,
    "avg_performance_gflops": 14.7,
    "median_performance_gflops": 16.16,
    "performance_gflops_mad": 10.53,
    "avg_memory_bw_gbs": 350.0,
    "scalar_peak_gflops": 432.0,
    "simd_peak_gflops": 9216.0,
    "node_num": 18,
    "duration": 165,
}

# Sample 3: Chroma application - lattice QCD workload (compute-intensive)
SAMPLE_CHROMA = {
    "bandwidth_raw_p10": 154.176,
    "bandwidth_raw_median": 200.57,
    "bandwidth_raw_p90": 259.952,
    "bandwidth_raw_mad": 5.12,
    "bandwidth_raw_range": 105.776,
    "bandwidth_raw_iqr": 10.215,
    "flops_raw_p10": 327.966,
    "flops_raw_median": 519.8,
    "flops_raw_p90": 654.422,
    "flops_raw_mad": 16.97,
    "flops_raw_range": 326.456,
    "flops_raw_iqr": 34.88,
    "arith_intensity_p10": 1.55,
    "arith_intensity_median": 2.595,
    "arith_intensity_p90": 3.445,
    "arith_intensity_mad": 0.254,
    "arith_intensity_range": 1.894,
    "arith_intensity_iqr": 0.512,
    "bw_flops_covariance": 382.76,
    "bw_flops_correlation": 0.063,
    "avg_performance_gflops": 503.26,
    "median_performance_gflops": 519.8,
    "performance_gflops_mad": 16.97,
    "avg_memory_bw_gbs": 350.0,
    "scalar_peak_gflops": 432.0,
    "simd_peak_gflops": 9216.0,
    "node_num": 3,
    "duration": 31133,
}

# Sample JSON roofline data (raw data before aggregation, as would be received by API)
SAMPLE_JSON_ROOFLINE = json.dumps([
    {
        "node_num": 1,
        "bandwidth_raw": 150.5,
        "flops_raw": 2500.0,
        "arith_intensity": 16.6,
        "performance_gflops": 1200.0,
        "memory_bw_gbs": 450,
        "scalar_peak_gflops": 600,
        "duration": 3600
    },
    {
        "node_num": 1,
        "bandwidth_raw": 155.2,
        "flops_raw": 2600.0,
        "arith_intensity": 16.8,
        "performance_gflops": 1250.0,
        "memory_bw_gbs": 450,
        "scalar_peak_gflops": 600,
        "duration": 3600
    },
    {
        "node_num": 1,
        "bandwidth_raw": 148.0,
        "flops_raw": 2450.0,
        "arith_intensity": 16.5,
        "performance_gflops": 1180.0,
        "memory_bw_gbs": 450,
        "scalar_peak_gflops": 600,
        "duration": 3600
    }
])


# ============================================================================
# Fixtures
# ============================================================================

@pytest.fixture(scope="module")
def setup_predictor():
    """
    Set up the predictor for tests.
    Try to load the real model if available.
    """
    from xgb_inference_api import XGBoostMultiLabelPredictor

    model_path = os.path.join(os.path.dirname(__file__), 'xgb_model.joblib')
    if os.path.exists(model_path):
        try:
            predictor = XGBoostMultiLabelPredictor(model_path)
            xgb_fastapi.predictor = predictor
            return True
        except Exception as e:
            print(f"Failed to load model: {e}")
            return False
    return False


@pytest.fixture
def client(setup_predictor):
    """Create a test client for the FastAPI app."""
    return TestClient(app)


@pytest.fixture
def model_loaded(setup_predictor):
    """Check if the model is loaded."""
    return setup_predictor


def skip_if_no_model(model_loaded):
    """Helper to skip tests if model is not loaded."""
    if not model_loaded:
        pytest.skip("Model not loaded, skipping test")


# ============================================================================
# Health and Root Endpoint Tests
# ============================================================================

class TestHealthEndpoints:
    """Tests for health check and root endpoints."""

    def test_root_endpoint(self, client):
        """Test the root endpoint returns API information."""
        response = client.get("/")
        assert response.status_code == 200

        data = response.json()
        assert "name" in data
        assert data["name"] == "XGBoost Multi-Label Classification API"
        assert "version" in data
        assert "endpoints" in data
        assert all(key in data["endpoints"] for key in ["health", "predict", "predict_top_k", "batch_predict"])

    def test_health_check(self, client, model_loaded):
        """Test the health check endpoint."""
        response = client.get("/health")
        assert response.status_code == 200

        data = response.json()
        assert "status" in data
        assert "model_loaded" in data

        # If model is loaded, check for additional info
        if model_loaded:
            assert data["model_loaded"] is True
            assert data["status"] in ["healthy", "degraded"]
            if data["status"] == "healthy":
                assert "n_classes" in data
                assert "n_features" in data
                assert "classes" in data
                assert data["n_classes"] > 0
                assert data["n_features"] > 0

    def test_health_check_model_not_loaded(self, client):
        """Test health check returns correct status when model not loaded."""
        response = client.get("/health")
        assert response.status_code == 200

        data = response.json()
        # Should have status and model_loaded fields regardless
        assert "status" in data
        assert "model_loaded" in data


# ============================================================================
# Single Prediction Tests
# ============================================================================

class TestPredictEndpoint:
    """Tests for the /predict endpoint."""

    def test_predict_with_feature_dict(self, client, model_loaded):
        """Test prediction with a feature dictionary (TurTLE sample)."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        request_data = {
            "features": SAMPLE_TURTLE,
            "threshold": 0.5,
            "return_all_probabilities": True
        }

        response = client.post("/predict", json=request_data)
        assert response.status_code == 200

        data = response.json()
        assert "predictions" in data
        assert "probabilities" in data
        assert "confidences" in data
        assert "threshold" in data

        assert isinstance(data["predictions"], list)
        assert isinstance(data["probabilities"], dict)
        assert data["threshold"] == 0.5

        # All probabilities should be between 0 and 1
        for prob in data["probabilities"].values():
            assert 0.0 <= prob <= 1.0

    def test_predict_with_different_thresholds(self, client, model_loaded):
        """Test that different thresholds affect predictions."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        request_low = {
            "features": SAMPLE_TURTLE,
            "threshold": 0.1,
            "return_all_probabilities": True
        }
        request_high = {
            "features": SAMPLE_TURTLE,
            "threshold": 0.9,
            "return_all_probabilities": True
        }

        response_low = client.post("/predict", json=request_low)
        response_high = client.post("/predict", json=request_high)

        assert response_low.status_code == 200
        assert response_high.status_code == 200

        data_low = response_low.json()
        data_high = response_high.json()

        # Lower threshold should generally produce more predictions
        assert len(data_low["predictions"]) >= len(data_high["predictions"])

    def test_predict_different_workloads(self, client, model_loaded):
        """Test predictions on different application workloads."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        samples = [
            ("TurTLE", SAMPLE_TURTLE),
            ("SCALEXA", SAMPLE_SCALEXA),
            ("Chroma", SAMPLE_CHROMA),
        ]

        for name, sample in samples:
            request_data = {
                "features": sample,
                "threshold": 0.3,
                "return_all_probabilities": True
            }

            response = client.post("/predict", json=request_data)
            assert response.status_code == 200, f"Failed for {name}"

            data = response.json()
            assert len(data["probabilities"]) > 0, f"No probabilities for {name}"

    def test_predict_return_only_predicted_probabilities(self, client, model_loaded):
        """Test prediction with return_all_probabilities=False."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        request_data = {
            "features": SAMPLE_TURTLE,
            "threshold": 0.3,
            "return_all_probabilities": False
        }

        response = client.post("/predict", json=request_data)
        assert response.status_code == 200

        data = response.json()
        # When return_all_probabilities is False, probabilities should only
        # contain classes that are in predictions
        if len(data["predictions"]) > 0:
            assert set(data["probabilities"].keys()) == set(data["predictions"])

    def test_predict_with_json_roofline_data(self, client, model_loaded):
        """Test prediction with raw JSON roofline data (requires aggregation)."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        request_data = {
            "features": SAMPLE_JSON_ROOFLINE,
            "is_json": True,
            "job_id": "test_job_123",
            "threshold": 0.3
        }

        response = client.post("/predict", json=request_data)
        assert response.status_code == 200

        data = response.json()
        assert "predictions" in data
        assert "probabilities" in data

    def test_predict_threshold_boundaries(self, client, model_loaded):
        """Test prediction with threshold at boundaries."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        for threshold in [0.0, 0.5, 1.0]:
            request_data = {
                "features": SAMPLE_TURTLE,
                "threshold": threshold
            }
            response = client.post("/predict", json=request_data)
            assert response.status_code == 200

    def test_predict_invalid_threshold(self, client):
        """Test that invalid threshold values are rejected."""
        for threshold in [-0.1, 1.5]:
            request_data = {
                "features": SAMPLE_TURTLE,
                "threshold": threshold
            }
            response = client.post("/predict", json=request_data)
            assert response.status_code == 422  # Validation error

    def test_predict_model_not_loaded(self, client):
        """Test that prediction returns 503 when model not loaded."""
        # Temporarily set predictor to None
        original_predictor = xgb_fastapi.predictor
        xgb_fastapi.predictor = None

        try:
            request_data = {
                "features": SAMPLE_TURTLE,
                "threshold": 0.5
            }
            response = client.post("/predict", json=request_data)
            assert response.status_code == 503
            assert "Model not loaded" in response.json().get("detail", "")
        finally:
            xgb_fastapi.predictor = original_predictor


# ============================================================================
# Top-K Prediction Tests
# ============================================================================

class TestPredictTopKEndpoint:
    """Tests for the /predict_top_k endpoint."""

    def test_predict_top_k_default(self, client, model_loaded):
        """Test top-K prediction with default k=5."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        request_data = {
            "features": SAMPLE_TURTLE
        }

        response = client.post("/predict_top_k", json=request_data)
        assert response.status_code == 200

        data = response.json()
        assert "top_predictions" in data
        assert "top_probabilities" in data
        assert "all_probabilities" in data

        assert len(data["top_predictions"]) <= 5
        assert len(data["top_probabilities"]) <= 5

    def test_predict_top_k_custom_k(self, client, model_loaded):
        """Test top-K prediction with custom k values."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        for k in [1, 3, 10]:
            request_data = {
                "features": SAMPLE_TURTLE,
                "k": k
            }

            response = client.post("/predict_top_k", json=request_data)
            assert response.status_code == 200

            data = response.json()
            assert len(data["top_predictions"]) <= k

    def test_predict_top_k_ordering(self, client, model_loaded):
        """Test that top-K predictions are ordered by probability."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        request_data = {
            "features": SAMPLE_CHROMA,
            "k": 10
        }

        response = client.post("/predict_top_k", json=request_data)
        assert response.status_code == 200

        data = response.json()
        probabilities = [data["top_probabilities"][cls] for cls in data["top_predictions"]]

        # Check that probabilities are in descending order
        for i in range(len(probabilities) - 1):
            assert probabilities[i] >= probabilities[i + 1]

    def test_predict_top_k_with_json_data(self, client, model_loaded):
        """Test top-K prediction with JSON roofline data."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        request_data = {
            "features": SAMPLE_JSON_ROOFLINE,
            "is_json": True,
            "job_id": "test_job_456",
            "k": 5
        }

        response = client.post("/predict_top_k", json=request_data)
        assert response.status_code == 200

        data = response.json()
        assert len(data["top_predictions"]) <= 5

    def test_predict_top_k_invalid_k(self, client):
        """Test that invalid k values are rejected."""
        for k in [0, -1, 101]:
            request_data = {
                "features": SAMPLE_TURTLE,
                "k": k
            }
            response = client.post("/predict_top_k", json=request_data)
            assert response.status_code == 422  # Validation error


# ============================================================================
# Batch Prediction Tests
# ============================================================================

class TestBatchPredictEndpoint:
    """Tests for the /batch_predict endpoint."""

    def test_batch_predict_multiple_samples(self, client, model_loaded):
        """Test batch prediction with multiple samples."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        request_data = {
            "features_list": [SAMPLE_TURTLE, SAMPLE_SCALEXA, SAMPLE_CHROMA],
            "threshold": 0.3
        }

        response = client.post("/batch_predict", json=request_data)
        assert response.status_code == 200

        data = response.json()
        assert "results" in data
        assert "total" in data
        assert "successful" in data
        assert "failed" in data

        assert data["total"] == 3
        assert len(data["results"]) == 3
        assert data["successful"] + data["failed"] == data["total"]

    def test_batch_predict_single_sample(self, client, model_loaded):
        """Test batch prediction with a single sample."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        request_data = {
            "features_list": [SAMPLE_TURTLE],
            "threshold": 0.5
        }

        response = client.post("/batch_predict", json=request_data)
        assert response.status_code == 200

        data = response.json()
        assert data["total"] == 1
        assert len(data["results"]) == 1

    def test_batch_predict_with_json_data(self, client, model_loaded):
        """Test batch prediction with JSON roofline data."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        request_data = {
            "features_list": [SAMPLE_JSON_ROOFLINE, SAMPLE_JSON_ROOFLINE],
            "is_json": True,
            "job_ids": ["job_001", "job_002"],
            "threshold": 0.3
        }

        response = client.post("/batch_predict", json=request_data)
        assert response.status_code == 200

        data = response.json()
        assert data["total"] == 2

    def test_batch_predict_empty_list(self, client, model_loaded):
        """Test batch prediction with empty list."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        request_data = {
            "features_list": [],
            "threshold": 0.5
        }

        response = client.post("/batch_predict", json=request_data)
        assert response.status_code == 200

        data = response.json()
        assert data["total"] == 0
        assert data["successful"] == 0
        assert data["failed"] == 0


# ============================================================================
|
||||
# Model Info Tests
|
||||
# ============================================================================
|
||||
|
||||
class TestModelInfoEndpoint:
|
||||
"""Tests for the /model/info endpoint."""
|
||||
|
||||
def test_model_info(self, client, model_loaded):
|
||||
"""Test getting model information."""
|
||||
response = client.get("/model/info")
|
||||
|
||||
# May return 503 if model not loaded, or 200 if loaded
|
||||
if model_loaded:
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "classes" in data
|
||||
assert "n_classes" in data
|
||||
assert "features" in data
|
||||
assert "n_features" in data
|
||||
|
||||
assert isinstance(data["classes"], list)
|
||||
assert len(data["classes"]) == data["n_classes"]
|
||||
assert len(data["features"]) == data["n_features"]
|
||||
else:
|
||||
assert response.status_code == 503
|
||||
|
||||
def test_model_info_not_loaded(self, client):
|
||||
"""Test that model info returns 503 when model not loaded."""
|
||||
original_predictor = xgb_fastapi.predictor
|
||||
xgb_fastapi.predictor = None
|
||||
|
||||
try:
|
||||
response = client.get("/model/info")
|
||||
assert response.status_code == 503
|
||||
finally:
|
||||
xgb_fastapi.predictor = original_predictor
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Error Handling Tests
|
||||
# ============================================================================
|
||||
|
||||
class TestErrorHandling:
|
||||
"""Tests for error handling."""
|
||||
|
||||
def test_predict_missing_features(self, client):
|
||||
"""Test prediction without features field."""
|
||||
request_data = {
|
||||
"threshold": 0.5
|
||||
}
|
||||
response = client.post("/predict", json=request_data)
|
||||
assert response.status_code == 422
|
||||
|
||||
def test_predict_invalid_json_format(self, client, model_loaded):
|
||||
"""Test prediction with invalid JSON in is_json mode."""
|
||||
if not model_loaded:
|
||||
pytest.skip("Model not loaded")
|
||||
|
||||
request_data = {
|
||||
"features": "not valid json {{",
|
||||
"is_json": True
|
||||
}
|
||||
response = client.post("/predict", json=request_data)
|
||||
# Should return error (400)
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_invalid_endpoint(self, client):
|
||||
"""Test accessing an invalid endpoint."""
|
||||
response = client.get("/nonexistent")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Integration Tests (Full Pipeline)
|
||||
# ============================================================================
|
||||
|
||||
class TestIntegration:
|
||||
"""Integration tests for full prediction pipeline."""
|
||||
|
||||
def test_full_prediction_pipeline_features(self, client, model_loaded):
|
||||
"""Test complete prediction pipeline with feature dict."""
|
||||
if not model_loaded:
|
||||
pytest.skip("Model not loaded")
|
||||
|
||||
# 1. Check health
|
||||
health_response = client.get("/health")
|
||||
assert health_response.status_code == 200
|
||||
|
||||
health_data = health_response.json()
|
||||
assert health_data["model_loaded"] == True
|
||||
|
||||
# 2. Get model info
|
||||
info_response = client.get("/model/info")
|
||||
assert info_response.status_code == 200
|
||||
|
||||
# 3. Make single prediction
|
||||
predict_response = client.post("/predict", json={
|
||||
"features": SAMPLE_TURTLE,
|
||||
"threshold": 0.3
|
||||
})
|
||||
assert predict_response.status_code == 200
|
||||
|
||||
# 4. Make top-K prediction
|
||||
topk_response = client.post("/predict_top_k", json={
|
||||
"features": SAMPLE_TURTLE,
|
||||
"k": 5
|
||||
})
|
||||
assert topk_response.status_code == 200
|
||||
|
||||
# 5. Make batch prediction
|
||||
batch_response = client.post("/batch_predict", json={
|
||||
"features_list": [SAMPLE_TURTLE, SAMPLE_CHROMA],
|
||||
"threshold": 0.3
|
||||
})
|
||||
assert batch_response.status_code == 200
|
||||
|
||||
def test_consistency_single_vs_batch(self, client, model_loaded):
|
||||
"""Test that single prediction and batch prediction give consistent results."""
|
||||
if not model_loaded:
|
||||
pytest.skip("Model not loaded")
|
||||
|
||||
threshold = 0.3
|
||||
|
||||
# Single prediction
|
||||
single_response = client.post("/predict", json={
|
||||
"features": SAMPLE_TURTLE,
|
||||
"threshold": threshold
|
||||
})
|
||||
|
||||
# Batch prediction with same sample
|
||||
batch_response = client.post("/batch_predict", json={
|
||||
"features_list": [SAMPLE_TURTLE],
|
||||
"threshold": threshold
|
||||
})
|
||||
|
||||
assert single_response.status_code == 200
|
||||
assert batch_response.status_code == 200
|
||||
|
||||
single_data = single_response.json()
|
||||
batch_data = batch_response.json()
|
||||
|
||||
if batch_data["successful"] == 1:
|
||||
batch_result = batch_data["results"][0]
|
||||
# Predictions should be the same
|
||||
assert set(single_data["predictions"]) == set(batch_result["predictions"])
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# CORS Tests
|
||||
# ============================================================================
|
||||
|
||||
class TestCORS:
|
||||
"""Tests for CORS configuration."""
|
||||
|
||||
def test_cors_headers(self, client):
|
||||
"""Test that CORS headers are present in responses."""
|
||||
response = client.options("/predict")
|
||||
# Accept either 200 or 405 (method not allowed for OPTIONS in some configs)
|
||||
assert response.status_code in [200, 405]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Performance Tests (Basic)
|
||||
# ============================================================================
|
||||
|
||||
class TestPerformance:
|
||||
"""Basic performance tests."""
|
||||
|
||||
def test_response_time_single_prediction(self, client, model_loaded):
|
||||
"""Test that single prediction completes in reasonable time."""
|
||||
if not model_loaded:
|
||||
pytest.skip("Model not loaded")
|
||||
|
||||
import time
|
||||
|
||||
start = time.time()
|
||||
response = client.post("/predict", json={
|
||||
"features": SAMPLE_TURTLE,
|
||||
"threshold": 0.5
|
||||
})
|
||||
elapsed = time.time() - start
|
||||
|
||||
# Should complete within 5 seconds (generous for CI environments)
|
||||
assert elapsed < 5.0, f"Prediction took {elapsed:.2f}s, expected < 5s"
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_response_time_batch_prediction(self, client, model_loaded):
|
||||
"""Test that batch prediction scales reasonably."""
|
||||
if not model_loaded:
|
||||
pytest.skip("Model not loaded")
|
||||
|
||||
import time
|
||||
|
||||
# Create a batch of 10 samples
|
||||
features_list = [SAMPLE_TURTLE] * 10
|
||||
|
||||
start = time.time()
|
||||
response = client.post("/batch_predict", json={
|
||||
"features_list": features_list,
|
||||
"threshold": 0.5
|
||||
})
|
||||
elapsed = time.time() - start
|
||||
|
||||
# Should complete within 10 seconds
|
||||
assert elapsed < 10.0, f"Batch prediction took {elapsed:.2f}s, expected < 10s"
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Run tests if executed directly
|
||||
# ============================================================================
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v", "--tb=short"])
|
||||
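The `confidences` values carried in the `/predict` responses these tests exercise come from the predictor's threshold-distance formula in xgb_local.py. A standalone illustration of that formula (the function name here is mine, for demonstration only):

```python
# Illustration of the confidence score used by the predictor:
# how far the probability sits above the threshold, rescaled to 0-100.
def confidence_score(prob: float, threshold: float) -> float:
    return min(1.0, (prob - threshold) / (1.0 - threshold)) * 100

print(confidence_score(0.75, 0.5))  # 50.0
print(confidence_score(1.0, 0.5))   # 100.0
```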
429
xgb_fastapi.py
Normal file
@@ -0,0 +1,429 @@
#!/usr/bin/env python3
"""
FastAPI REST API for XGBoost Multi-Label Classification Inference

Provides HTTP endpoints for:
- POST /predict: Single prediction with confidence scores
- POST /predict_top_k: Top-K predictions
- POST /batch_predict: Batch predictions

Supports both feature vectors and JSON roofline data as input.
"""

from fastapi import FastAPI, HTTPException, Body
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field, validator
from typing import Dict, List, Union, Optional, Any
import logging
import uvicorn
from contextlib import asynccontextmanager

from xgb_local import XGBoostMultiLabelPredictor

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Global predictor instance
predictor: Optional[XGBoostMultiLabelPredictor] = None


# Pydantic models for request/response validation
class PredictRequest(BaseModel):
    """Request model for single prediction."""
    features: Union[Dict[str, float], str] = Field(
        ...,
        description="Feature dictionary or JSON string of roofline data"
    )
    threshold: float = Field(
        default=0.5,
        ge=0.0,
        le=1.0,
        description="Probability threshold for classification"
    )
    return_all_probabilities: bool = Field(
        default=True,
        description="Whether to return probabilities for all classes"
    )
    is_json: bool = Field(
        default=False,
        description="Whether features is a JSON string of roofline data"
    )
    job_id: Optional[str] = Field(
        default=None,
        description="Optional job ID for JSON aggregation"
    )


class PredictTopKRequest(BaseModel):
    """Request model for top-K prediction."""
    features: Union[Dict[str, float], str] = Field(
        ...,
        description="Feature dictionary or JSON string of roofline data"
    )
    k: int = Field(
        default=5,
        ge=1,
        le=100,
        description="Number of top predictions to return"
    )
    is_json: bool = Field(
        default=False,
        description="Whether features is a JSON string of roofline data"
    )
    job_id: Optional[str] = Field(
        default=None,
        description="Optional job ID for JSON aggregation"
    )


class BatchPredictRequest(BaseModel):
    """Request model for batch prediction."""
    features_list: List[Union[Dict[str, float], str]] = Field(
        ...,
        description="List of feature dictionaries or JSON strings"
    )
    threshold: float = Field(
        default=0.5,
        ge=0.0,
        le=1.0,
        description="Probability threshold for classification"
    )
    is_json: bool = Field(
        default=False,
        description="Whether features are JSON strings of roofline data"
    )
    job_ids: Optional[List[str]] = Field(
        default=None,
        description="Optional list of job IDs for JSON aggregation"
    )


class PredictResponse(BaseModel):
    """Response model for single prediction."""
    predictions: List[str] = Field(description="List of predicted class names")
    probabilities: Dict[str, float] = Field(description="Probabilities for each class")
    confidences: Dict[str, float] = Field(description="Confidence scores for predicted classes")
    threshold: float = Field(description="Threshold used for prediction")


class PredictTopKResponse(BaseModel):
    """Response model for top-K prediction."""
    top_predictions: List[str] = Field(description="Top-K predicted class names")
    top_probabilities: Dict[str, float] = Field(description="Probabilities for top-K classes")
    all_probabilities: Dict[str, float] = Field(description="Probabilities for all classes")


class BatchPredictResponse(BaseModel):
    """Response model for batch prediction."""
    results: List[Union[PredictResponse, Dict[str, str]]] = Field(
        description="List of prediction results or errors"
    )
    total: int = Field(description="Total number of samples processed")
    successful: int = Field(description="Number of successful predictions")
    failed: int = Field(description="Number of failed predictions")


class HealthResponse(BaseModel):
    """Response model for health check."""
    status: str
    model_loaded: bool
    n_classes: Optional[int] = None
    n_features: Optional[int] = None
    classes: Optional[List[str]] = None


class ErrorResponse(BaseModel):
    """Response model for errors."""
    error: str
    detail: Optional[str] = None


# Lifespan context manager for startup/shutdown
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage application lifespan events."""
    # Startup
    global predictor
    try:
        logger.info("Loading XGBoost model...")
        predictor = XGBoostMultiLabelPredictor('xgb_model.joblib')
        logger.info("Model loaded successfully!")
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        predictor = None

    yield

    # Shutdown
    logger.info("Shutting down...")


# Initialize FastAPI app
app = FastAPI(
    title="XGBoost Multi-Label Classification API",
    description="REST API for multi-label classification inference using XGBoost",
    version="1.0.0",
    lifespan=lifespan
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, specify allowed origins
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# API Endpoints

@app.get("/", tags=["General"])
async def root():
    """Root endpoint with API information."""
    return {
        "name": "XGBoost Multi-Label Classification API",
        "version": "1.0.0",
        "endpoints": {
            "health": "/health",
            "predict": "/predict",
            "predict_top_k": "/predict_top_k",
            "batch_predict": "/batch_predict"
        }
    }


@app.get("/health", response_model=HealthResponse, tags=["General"])
async def health_check():
    """
    Check API health and model status.

    Returns model information if loaded successfully.
    """
    if predictor is None:
        return HealthResponse(
            status="error",
            model_loaded=False
        )

    try:
        info = predictor.get_class_info()
        return HealthResponse(
            status="healthy",
            model_loaded=True,
            n_classes=info['n_classes'],
            n_features=info['n_features'],
            classes=info['classes']
        )
    except Exception as e:
        logger.error(f"Health check error: {e}")
        return HealthResponse(
            status="degraded",
            model_loaded=True
        )


@app.post("/predict", response_model=PredictResponse, tags=["Inference"])
async def predict(request: PredictRequest):
    """
    Perform single prediction on input features.

    **Input formats:**
    - Feature dictionary: `{"feature1": value1, "feature2": value2, ...}`
    - JSON roofline data: Set `is_json=true` and provide a JSON string

    **Example (features):**
    ```json
    {
        "features": {
            "bandwidth_raw_p10": 150.5,
            "flops_raw_median": 2500.0,
            ...
        },
        "threshold": 0.5
    }
    ```

    **Example (JSON roofline):**
    ```json
    {
        "features": "[{\"node_num\": 1, \"bandwidth_raw\": 150.5, ...}]",
        "is_json": true,
        "job_id": "test_job_123",
        "threshold": 0.3
    }
    ```
    """
    if predictor is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        result = predictor.predict(
            features=request.features,
            threshold=request.threshold,
            return_all_probabilities=request.return_all_probabilities,
            is_json=request.is_json,
            job_id=request.job_id
        )
        return PredictResponse(**result)
    except Exception as e:
        logger.error(f"Prediction error: {e}")
        raise HTTPException(status_code=400, detail=str(e))


@app.post("/predict_top_k", response_model=PredictTopKResponse, tags=["Inference"])
async def predict_top_k(request: PredictTopKRequest):
    """
    Get top-K predictions with their probabilities.

    **Example (features):**
    ```json
    {
        "features": {
            "bandwidth_raw_p10": 150.5,
            "flops_raw_median": 2500.0,
            ...
        },
        "k": 5
    }
    ```

    **Example (JSON roofline):**
    ```json
    {
        "features": "[{\"node_num\": 1, \"bandwidth_raw\": 150.5, ...}]",
        "is_json": true,
        "job_id": "test_job_123",
        "k": 10
    }
    ```
    """
    if predictor is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        result = predictor.predict_top_k(
            features=request.features,
            k=request.k,
            is_json=request.is_json,
            job_id=request.job_id
        )
        return PredictTopKResponse(**result)
    except Exception as e:
        logger.error(f"Top-K prediction error: {e}")
        raise HTTPException(status_code=400, detail=str(e))


@app.post("/batch_predict", response_model=BatchPredictResponse, tags=["Inference"])
async def batch_predict(request: BatchPredictRequest):
    """
    Perform batch prediction on multiple samples.

    **Example (features):**
    ```json
    {
        "features_list": [
            {"bandwidth_raw_p10": 150.5, ...},
            {"bandwidth_raw_p10": 160.2, ...}
        ],
        "threshold": 0.5
    }
    ```

    **Example (JSON roofline):**
    ```json
    {
        "features_list": [
            "[{\"node_num\": 1, \"bandwidth_raw\": 150.5, ...}]",
            "[{\"node_num\": 2, \"bandwidth_raw\": 160.2, ...}]"
        ],
        "is_json": true,
        "job_ids": ["job1", "job2"],
        "threshold": 0.3
    }
    ```
    """
    if predictor is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        results = predictor.batch_predict(
            features_list=request.features_list,
            threshold=request.threshold,
            is_json=request.is_json,
            job_ids=request.job_ids
        )

        # Count successful and failed predictions
        successful = sum(1 for r in results if 'error' not in r)
        failed = len(results) - successful

        return BatchPredictResponse(
            results=results,
            total=len(results),
            successful=successful,
            failed=failed
        )
    except Exception as e:
        logger.error(f"Batch prediction error: {e}")
        raise HTTPException(status_code=400, detail=str(e))


@app.get("/model/info", tags=["Model"])
async def model_info():
    """
    Get detailed model information.

    Returns information about classes, features, and model configuration.
    """
    if predictor is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        info = predictor.get_class_info()
        return {
            "classes": info['classes'],
            "n_classes": info['n_classes'],
            "features": info['feature_columns'],
            "n_features": info['n_features']
        }
    except Exception as e:
        logger.error(f"Model info error: {e}")
        raise HTTPException(status_code=500, detail=str(e))


def main():
    """Run the FastAPI server."""
    import argparse

    parser = argparse.ArgumentParser(description="XGBoost Multi-Label Classification REST API")
    parser.add_argument('--host', type=str, default='0.0.0.0',
                        help='Host to bind to (default: 0.0.0.0)')
    parser.add_argument('--port', type=int, default=8000,
                        help='Port to bind to (default: 8000)')
    parser.add_argument('--reload', action='store_true',
                        help='Enable auto-reload for development')
    parser.add_argument('--workers', type=int, default=1,
                        help='Number of worker processes (default: 1)')

    args = parser.parse_args()

    logger.info(f"Starting FastAPI server on {args.host}:{args.port}")
    logger.info(f"Workers: {args.workers}, Reload: {args.reload}")

    uvicorn.run(
        "xgb_fastapi:app",
        host=args.host,
        port=args.port,
        reload=args.reload,
        workers=args.workers if not args.reload else 1,
        log_level="info"
    )


if __name__ == "__main__":
    main()
530
xgb_local.py
Normal file
@@ -0,0 +1,530 @@
import pandas as pd
import numpy as np
import joblib
import json
from typing import Dict, List, Tuple, Union, Optional
import warnings
warnings.filterwarnings('ignore')
from scipy import stats


def compute_mad(data: np.ndarray) -> float:
    """Compute the Median Absolute Deviation."""
    median = np.median(data)
    mad = np.median(np.abs(data - median))
    return mad
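As a sanity check on the definition above, the MAD can be reproduced with only the standard library (this snippet is illustrative and not part of the module; `compute_mad` itself uses numpy):

```python
# Pure-stdlib equivalent of compute_mad: median of absolute deviations
# from the median.
from statistics import median

def mad(values):
    m = median(values)
    return median(abs(v - m) for v in values)

print(mad([1, 1, 2, 2, 4, 6, 9]))  # 1
```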

def df_aggregate(json_str: str, job_id_full: Optional[str] = None) -> Dict:
    """
    Aggregate roofline data from a JSON string into a single feature vector.

    Args:
        json_str: JSON string containing roofline data records
        job_id_full: Optional job ID to include in the result

    Returns:
        Dictionary containing aggregated features
    """
    # Parse JSON string to DataFrame
    try:
        data = json.loads(json_str)
        if isinstance(data, list):
            df = pd.DataFrame(data)
        elif isinstance(data, dict):
            df = pd.DataFrame([data])
        else:
            raise ValueError("JSON must contain a list of objects or a single object")
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON string: {e}")

    # Group data by node_num
    if 'node_num' not in df.columns:
        # If no node_num, treat all data as a single node
        df['node_num'] = 1

    grouped = df.groupby('node_num')

    all_features = []
    for node_num, group in grouped:
        features = {
            'node_num': int(node_num)
        }

        if job_id_full is not None:
            features['job_id'] = job_id_full

        # Compute statistics for key metrics
        for axis in ['bandwidth_raw', 'flops_raw', 'arith_intensity']:
            data = group[axis].values

            # Compute percentiles
            p10 = np.percentile(data, 10)
            p50 = np.median(data)
            p90 = np.percentile(data, 90)

            # Compute MAD (more robust than variance)
            mad = compute_mad(data)

            # Store features
            features[f'{axis}_p10'] = p10
            features[f'{axis}_median'] = p50
            features[f'{axis}_p90'] = p90
            features[f'{axis}_mad'] = mad
            features[f'{axis}_range'] = p90 - p10
            features[f'{axis}_iqr'] = np.percentile(data, 75) - np.percentile(data, 25)

        # Compute covariance and correlation between bandwidth_raw and flops_raw
        if len(group) > 1:  # Need at least 2 points for correlation
            cov = np.cov(group['bandwidth_raw'], group['flops_raw'])[0, 1]
            features['bw_flops_covariance'] = cov

            corr, _ = stats.pearsonr(group['bandwidth_raw'], group['flops_raw'])
            features['bw_flops_correlation'] = corr

        # Additional useful features for the classifier

        # Performance metrics
        features['avg_performance_gflops'] = group['performance_gflops'].mean()
        features['median_performance_gflops'] = group['performance_gflops'].median()
        features['performance_gflops_mad'] = compute_mad(group['performance_gflops'].values)

        # # Efficiency metrics
        # features['avg_efficiency'] = group['efficiency'].mean()
        # features['median_efficiency'] = group['efficiency'].median()
        # features['efficiency_mad'] = compute_mad(group['efficiency'].values)
        # features['efficiency_p10'] = np.percentile(group['efficiency'].values, 10)
        # features['efficiency_p90'] = np.percentile(group['efficiency'].values, 90)

        # # Distribution of roofline regions (memory-bound vs compute-bound)
        # if 'roofline_region' in group.columns:
        #     region_counts = group['roofline_region'].value_counts(normalize=True).to_dict()
        #     for region, ratio in region_counts.items():
        #         features[f'region_{region}_ratio'] = ratio

        # System characteristics
        if 'memory_bw_gbs' in group.columns:
            features['avg_memory_bw_gbs'] = group['memory_bw_gbs'].mean()
        if 'scalar_peak_gflops' in group.columns and len(group['scalar_peak_gflops'].unique()) > 0:
            features['scalar_peak_gflops'] = group['scalar_peak_gflops'].iloc[0]
        if 'simd_peak_gflops' in group.columns and len(group['simd_peak_gflops'].unique()) > 0:
            features['simd_peak_gflops'] = group['simd_peak_gflops'].iloc[0]

        # # Subcluster information if available
        # if 'subcluster_name' in group.columns and not group['subcluster_name'].isna().all():
        #     features['subcluster_name'] = group['subcluster_name'].iloc[0]

        # Duration information
        if 'duration' in group.columns:
            features['duration'] = group['duration'].iloc[0]

        all_features.append(features)

    # Return the first node's features (or combine multiple nodes if needed)
    if len(all_features) == 1:
        return all_features[0]
    else:
        # If multiple nodes, return the first one or average across nodes
        # For now, return the first node's features
        return all_features[0]
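For reference, a minimal roofline payload in the shape `df_aggregate` consumes. The field names follow the columns the function reads above; the metric values here are invented purely for illustration:

```python
import json

# Hypothetical roofline records for one node; values are made up.
records = [
    {"node_num": 1, "bandwidth_raw": 150.5, "flops_raw": 2500.0,
     "arith_intensity": 16.6, "performance_gflops": 2400.0},
    {"node_num": 1, "bandwidth_raw": 148.2, "flops_raw": 2510.0,
     "arith_intensity": 16.9, "performance_gflops": 2450.0},
]
json_str = json.dumps(records)

# df_aggregate(json_str, job_id_full="job_001") would then yield keys such
# as 'bandwidth_raw_p10', 'flops_raw_median', 'arith_intensity_mad', ...
print(len(json.loads(json_str)))  # 2
```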
||||
class XGBoostMultiLabelPredictor:
|
||||
"""
|
||||
Python API for XGBoost multi-label classification inference.
|
||||
|
||||
Provides methods to load trained models band perform inference with
|
||||
confidence scores for each class.
|
||||
"""
|
||||
|
||||
def __init__(self, model_path: str = 'xgb_model.joblib'):
|
||||
"""
|
||||
Initialize the predictor by loading the trained model.
|
||||
|
||||
Args:
|
||||
model_path: Path to the saved model file (.joblib)
|
||||
"""
|
||||
self.model_data = None
|
||||
self.model = None
|
||||
self.mlb = None
|
||||
self.feature_columns = None
|
||||
self.n_features = 0
|
||||
self.classes = []
|
||||
|
||||
self.load_model(model_path)
|
||||
|
||||
def load_model(self, model_path: str) -> None:
|
||||
"""
|
||||
Load the trained XGBoost model from disk.
|
||||
|
||||
Args:
|
||||
model_path: Path to the saved model file
|
||||
"""
|
||||
try:
|
||||
print(f"Loading model from {model_path}...")
|
||||
self.model_data = joblib.load(model_path)
|
||||
self.model = self.model_data['model']
|
||||
self.mlb = self.model_data['mlb']
|
||||
self.feature_columns = self.model_data['feature_columns']
|
||||
|
||||
self.classes = list(self.mlb.classes_)
|
||||
self.n_features = len(self.feature_columns)
|
||||
|
||||
print("Model loaded successfully!")
|
||||
print(f" - {len(self.classes)} classes: {self.classes}")
|
||||
print(f" - {self.n_features} features: {self.feature_columns[:5]}...")
|
||||
print(f" - Model type: {type(self.model).__name__}")
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to load model from {model_path}: {e}")
|
||||
|
||||
def predict(self, features: Union[pd.DataFrame, np.ndarray, List, Dict, str],
|
||||
threshold: float = 0.5,
|
||||
return_all_probabilities: bool = True,
|
||||
is_json: bool = False,
|
||||
job_id: Optional[str] = None) -> Dict:
|
||||
"""
|
||||
Perform multi-label prediction on input features.
|
||||
|
||||
Args:
|
||||
features: Input features in various formats:
|
||||
- pandas DataFrame
|
||||
- numpy array (2D)
|
||||
- list of lists/dicts
|
||||
- single feature vector (list/dict)
|
||||
- JSON string (if is_json=True): roofline data to aggregate
|
||||
threshold: Probability threshold for binary classification (0.0-1.0)
|
||||
return_all_probabilities: If True, return probabilities for all classes.
|
||||
If False, return only classes above threshold.
|
||||
is_json: If True, treat features as JSON string of roofline data
|
||||
job_id: Optional job ID (used when is_json=True)
|
||||
|
||||
Returns:
|
||||
Dictionary containing:
|
||||
- 'predictions': List of predicted class names
|
||||
- 'probabilities': Dict of {class_name: probability} for all classes
|
||||
- 'confidences': Dict of {class_name: confidence_score} for predicted classes
|
||||
- 'threshold': The threshold used
|
||||
"""
|
||||
# If input is JSON string, aggregate features first
|
||||
if is_json:
|
||||
if not isinstance(features, str):
|
||||
                raise ValueError("When is_json=True, features must be a JSON string")
            features = df_aggregate(features, job_id_full=job_id)

        # Convert input to the expected feature format
        X = self._prepare_features(features)

        # Get probability predictions
        probabilities = self.model.predict_proba(X)

        # Convert to per-class probabilities
        class_probabilities = {}
        for i, class_name in enumerate(self.classes):
            # For OneVsRest, predict_proba returns shape (n_samples, n_classes);
            # column i holds the positive-class probability for class i
            if isinstance(probabilities, list):
                # List of arrays (one array per class)
                prob_array = probabilities[i]
                prob_positive = float(prob_array[0]) if hasattr(prob_array, '__getitem__') else float(prob_array)
            elif probabilities.ndim == 2:
                # 2D numpy array, shape (n_samples, n_classes)
                prob_positive = float(probabilities[0, i])
            else:
                # 1D array
                prob_positive = float(probabilities[i])
            class_probabilities[class_name] = prob_positive

        # Apply the threshold to pick predicted labels
        predictions = []
        confidences = {}

        for class_name, prob in class_probabilities.items():
            if prob >= threshold:
                predictions.append(class_name)
                # Confidence score: distance above the threshold, as a percentage
                confidence = min(1.0, (prob - threshold) / (1.0 - threshold)) * 100
                confidences[class_name] = round(confidence, 2)

        # Sort predictions by probability (highest first)
        predictions.sort(key=lambda x: class_probabilities[x], reverse=True)

        result = {
            'predictions': predictions,
            'probabilities': {k: round(v, 4) for k, v in class_probabilities.items()},
            'confidences': confidences,
            'threshold': threshold
        }

        if not return_all_probabilities:
            result['probabilities'] = {k: v for k, v in result['probabilities'].items()
                                       if k in predictions}

        return result

    def predict_top_k(self, features: Union[pd.DataFrame, np.ndarray, List, Dict, str],
                      k: int = 5,
                      is_json: bool = False,
                      job_id: Optional[str] = None) -> Dict:
        """
        Get top-k predictions with their probabilities.

        Args:
            features: Input features (various formats) or JSON string if is_json=True
            k: Number of top predictions to return
            is_json: If True, treat features as a JSON string of roofline data
            job_id: Optional job ID (used when is_json=True)

        Returns:
            Dictionary with top-k predictions and their details
        """
        # If input is a JSON string, aggregate features first
        if is_json:
            if not isinstance(features, str):
                raise ValueError("When is_json=True, features must be a JSON string")
            features = df_aggregate(features, job_id_full=job_id)

        # Get all class probabilities
        X = self._prepare_features(features)
        probabilities = self.model.predict_proba(X)

        class_probabilities = {}
        for i, class_name in enumerate(self.classes):
            # For OneVsRest, predict_proba returns shape (n_samples, n_classes);
            # column i holds the positive-class probability for class i
            if isinstance(probabilities, list):
                # List of arrays (one array per class)
                prob_array = probabilities[i]
                prob_positive = float(prob_array[0]) if hasattr(prob_array, '__getitem__') else float(prob_array)
            elif probabilities.ndim == 2:
                # 2D numpy array, shape (n_samples, n_classes)
                prob_positive = float(probabilities[0, i])
            else:
                # 1D array
                prob_positive = float(probabilities[i])
            class_probabilities[class_name] = prob_positive

        # Sort by probability (highest first)
        sorted_classes = sorted(class_probabilities.items(),
                                key=lambda x: x[1], reverse=True)

        top_k_classes = sorted_classes[:k]

        return {
            'top_predictions': [cls for cls, _ in top_k_classes],
            'top_probabilities': {cls: round(prob, 4) for cls, prob in top_k_classes},
            'all_probabilities': {cls: round(prob, 4) for cls, prob in class_probabilities.items()}
        }

    def _prepare_features(self, features: Union[pd.DataFrame, np.ndarray, List, Dict]) -> pd.DataFrame:
        """
        Convert various input formats to the expected feature format.

        Args:
            features: Input features in various formats

        Returns:
            pandas DataFrame with the correct columns, in the correct order
        """
        if isinstance(features, pd.DataFrame):
            df = features.copy()
        elif isinstance(features, np.ndarray):
            if features.ndim == 1:
                features = features.reshape(1, -1)
            df = pd.DataFrame(features, columns=self.feature_columns[:features.shape[1]])
        elif isinstance(features, list):
            if isinstance(features[0], dict):
                # List of dictionaries
                df = pd.DataFrame(features)
            else:
                # List of lists
                df = pd.DataFrame(features, columns=self.feature_columns[:len(features[0])])
        elif isinstance(features, dict):
            # Single feature dictionary
            df = pd.DataFrame([features])
        else:
            raise ValueError(f"Unsupported feature format: {type(features)}")

        # Ensure correct column order and fill missing columns with 0
        for col in self.feature_columns:
            if col not in df.columns:
                df[col] = 0.0

        df = df[self.feature_columns]

        # Validate feature count
        if df.shape[1] != self.n_features:
            raise ValueError(f"Expected {self.n_features} features, got {df.shape[1]}")

        return df

    def get_class_info(self) -> Dict:
        """
        Get information about the available classes.

        Returns:
            Dictionary with class information
        """
        return {
            'classes': self.classes,
            'n_classes': len(self.classes),
            'feature_columns': self.feature_columns,
            'n_features': self.n_features
        }

    def batch_predict(self, features_list: List[Union[pd.DataFrame, np.ndarray, List, Dict, str]],
                      threshold: float = 0.5,
                      is_json: bool = False,
                      job_ids: Optional[List[str]] = None) -> List[Dict]:
        """
        Perform batch prediction on multiple samples.

        Args:
            features_list: List of feature inputs (or JSON strings if is_json=True)
            threshold: Probability threshold
            is_json: If True, treat each item in features_list as a JSON string
            job_ids: Optional list of job IDs (used when is_json=True)

        Returns:
            List of prediction results
        """
        results = []
        for idx, features in enumerate(features_list):
            try:
                job_id = job_ids[idx] if job_ids and idx < len(job_ids) else None
                result = self.predict(features, threshold=threshold, is_json=is_json, job_id=job_id)
                results.append(result)
            except Exception as e:
                results.append({'error': str(e)})

        return results


def create_sample_data(n_samples: int = 5) -> List[Dict]:
    """
    Create sample feature data for testing.

    Args:
        n_samples: Number of sample feature vectors to create

    Returns:
        List of feature dictionaries
    """
    np.random.seed(42)

    # Load feature columns from the saved model if available
    try:
        model_data = joblib.load('xgb_model.joblib')
        feature_columns = model_data['feature_columns']
    except (OSError, KeyError):
        # Fall back to a default feature set
        feature_columns = [
            'node_num', 'bandwidth_raw_p10', 'bandwidth_raw_median',
            'bandwidth_raw_p90', 'bandwidth_raw_mad', 'bandwidth_raw_range',
            'bandwidth_raw_iqr', 'flops_raw_p10', 'flops_raw_median',
            'flops_raw_p90', 'flops_raw_mad', 'flops_raw_range'
        ]

    samples = []
    for _ in range(n_samples):
        sample = {}
        for col in feature_columns:
            if 'bandwidth' in col:
                sample[col] = np.random.uniform(50, 500)
            elif 'flops' in col:
                sample[col] = np.random.uniform(100, 5000)
            elif 'node_num' in col:
                sample[col] = np.random.randint(1, 16)
            else:
                sample[col] = np.random.uniform(0, 1000)
        samples.append(sample)

    return samples


if __name__ == "__main__":
    print("XGBoost Multi-Label Inference API")
    print("=" * 40)

    # Initialize predictor
    try:
        predictor = XGBoostMultiLabelPredictor()
    except Exception as e:
        print(f"Error loading model: {e}")
        exit(1)

    # Example usage of df_aggregate with a JSON string
    print("\n=== Example 0: JSON Aggregation ===")
    sample_json = json.dumps([
        {
            "node_num": 1,
            "bandwidth_raw": 150.5,
            "flops_raw": 2500.0,
            "arith_intensity": 16.6,
            "performance_gflops": 1200.0,
            "memory_bw_gbs": 450,
            "scalar_peak_gflops": 600,
            "duration": 3600
        },
        {
            "node_num": 2,
            "bandwidth_raw": 155.2,
            "flops_raw": 2600.0,
            "arith_intensity": 16.8,
            "performance_gflops": 1250.0,
            "memory_bw_gbs": 450,
            "scalar_peak_gflops": 600,
            "duration": 3600
        }
    ])

    try:
        aggregated_features = df_aggregate(sample_json, job_id_full="test_job_123")
        print("Aggregated features from JSON:")
        for key, value in list(aggregated_features.items())[:10]:
            print(f"  {key}: {value}")
        print(f"  ... ({len(aggregated_features)} total features)")

        # Use the aggregated features for prediction
        result = predictor.predict(aggregated_features, threshold=0.3)
        print(f"\nPredictions from aggregated data: {result['predictions'][:3]}")
    except Exception as e:
        print(f"Error in aggregation: {e}")

    # Create sample data
    print("\n=== Generating sample data for other examples ===")
    sample_data = create_sample_data(3)

    # Example 1: Single prediction
    print("\n=== Example 1: Single Prediction ===")
    result = predictor.predict(sample_data[0], threshold=0.3)
    print(f"Predictions: {result['predictions']}")
    print(f"Confidences: {result['confidences']}")
    print("Top probabilities:")
    for class_name, prob in sorted(result['probabilities'].items(),
                                   key=lambda x: x[1], reverse=True)[:5]:
        print(f"  {class_name}: {prob:.4f}")

    # Example 2: Top-K predictions
    print("\n=== Example 2: Top-5 Predictions ===")
    top_result = predictor.predict_top_k(sample_data[1], k=5)
    for i, class_name in enumerate(top_result['top_predictions'], 1):
        prob = top_result['top_probabilities'][class_name]
        print(f"{i}. {class_name}: {prob:.4f}")

    # Example 3: Batch prediction
    print("\n=== Example 3: Batch Prediction ===")
    batch_results = predictor.batch_predict(sample_data, threshold=0.4)
    for i, result in enumerate(batch_results, 1):
        if 'error' not in result:
            print(f"Sample {i}: {len(result['predictions'])} predictions")
        else:
            print(f"Sample {i}: Error - {result['error']}")

    print("\nAPI ready for use!")
    print("Usage:")
    print("  predictor = XGBoostMultiLabelPredictor()")
    print("  result = predictor.predict(your_features)")
    print("  top_k = predictor.predict_top_k(your_features, k=5)")
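The confidence score in `predict()` rescales each probability's distance above the threshold onto a 0-100 range, so a label that barely clears the threshold reports ~0 and a probability of 1.0 reports 100. A minimal standalone sketch of that formula (the helper name `confidence_score` is ours, not part of the module):

```python
def confidence_score(prob: float, threshold: float) -> float:
    """Distance above the threshold, rescaled to a 0-100 percentage.

    Mirrors the expression used inside predict():
    min(1.0, (prob - threshold) / (1.0 - threshold)) * 100
    """
    return min(1.0, (prob - threshold) / (1.0 - threshold)) * 100

# A probability exactly at the threshold gets confidence 0;
# a probability of 1.0 saturates at 100.
print(confidence_score(0.5, 0.5))   # 0.0
print(confidence_score(1.0, 0.5))   # 100.0
print(confidence_score(0.75, 0.5))  # 50.0
```

Note that the rescaling depends on the threshold: the same probability of 0.75 yields a higher confidence under a stricter threshold, since it sits further into the remaining (threshold, 1.0] band relative to its width.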
221
xgb_local_example.py
Normal file
@@ -0,0 +1,221 @@
#!/usr/bin/env python3
"""
XGBoost Multi-Label Inference Usage Examples
============================================

This script demonstrates how to use the XGBoostMultiLabelPredictor class
for multi-label classification with confidence scores.

Sample data comes from real HPC workloads extracted from roofline_features.h5:
- TurTLE: turbulence simulation (memory-bound, low arithmetic intensity ~0.84)
- SCALEXA: scaling benchmarks (high bandwidth-flops correlation ~0.995)
- Chroma: lattice QCD (compute-intensive, high arithmetic intensity ~2.6)
"""

import json

from xgb_local import XGBoostMultiLabelPredictor

# ============================================================================
# Realistic Sample Data from roofline_features.h5
# ============================================================================

# TurTLE application - turbulence simulation workload
SAMPLE_TURTLE = {
    "bandwidth_raw_p10": 186.33,
    "bandwidth_raw_median": 205.14,
    "bandwidth_raw_p90": 210.83,
    "bandwidth_raw_mad": 3.57,
    "bandwidth_raw_range": 24.5,
    "bandwidth_raw_iqr": 12.075,
    "flops_raw_p10": 162.024,
    "flops_raw_median": 171.45,
    "flops_raw_p90": 176.48,
    "flops_raw_mad": 3.08,
    "flops_raw_range": 14.456,
    "flops_raw_iqr": 8.29,
    "arith_intensity_p10": 0.7906,
    "arith_intensity_median": 0.837,
    "arith_intensity_p90": 0.9109,
    "arith_intensity_mad": 0.02,
    "arith_intensity_range": 0.12,
    "arith_intensity_iqr": 0.0425,
    "bw_flops_covariance": 60.86,
    "bw_flops_correlation": 0.16,
    "avg_performance_gflops": 168.1,
    "median_performance_gflops": 171.45,
    "performance_gflops_mad": 3.08,
    "avg_memory_bw_gbs": 350.0,
    "scalar_peak_gflops": 432.0,
    "simd_peak_gflops": 9216.0,
    "node_num": 0,
    "duration": 19366,
}

# SCALEXA application - scaling benchmark workload
SAMPLE_SCALEXA = {
    "bandwidth_raw_p10": 13.474,
    "bandwidth_raw_median": 32.57,
    "bandwidth_raw_p90": 51.466,
    "bandwidth_raw_mad": 23.62,
    "bandwidth_raw_range": 37.992,
    "bandwidth_raw_iqr": 23.745,
    "flops_raw_p10": 4.24,
    "flops_raw_median": 16.16,
    "flops_raw_p90": 24.584,
    "flops_raw_mad": 10.53,
    "flops_raw_range": 20.344,
    "flops_raw_iqr": 12.715,
    "arith_intensity_p10": 0.211,
    "arith_intensity_median": 0.475,
    "arith_intensity_p90": 0.492,
    "arith_intensity_mad": 0.021,
    "arith_intensity_range": 0.281,
    "arith_intensity_iqr": 0.176,
    "bw_flops_covariance": 302.0,
    "bw_flops_correlation": 0.995,
    "avg_performance_gflops": 14.7,
    "median_performance_gflops": 16.16,
    "performance_gflops_mad": 10.53,
    "avg_memory_bw_gbs": 350.0,
    "scalar_peak_gflops": 432.0,
    "simd_peak_gflops": 9216.0,
    "node_num": 18,
    "duration": 165,
}

# Chroma application - lattice QCD workload (compute-intensive)
SAMPLE_CHROMA = {
    "bandwidth_raw_p10": 154.176,
    "bandwidth_raw_median": 200.57,
    "bandwidth_raw_p90": 259.952,
    "bandwidth_raw_mad": 5.12,
    "bandwidth_raw_range": 105.776,
    "bandwidth_raw_iqr": 10.215,
    "flops_raw_p10": 327.966,
    "flops_raw_median": 519.8,
    "flops_raw_p90": 654.422,
    "flops_raw_mad": 16.97,
    "flops_raw_range": 326.456,
    "flops_raw_iqr": 34.88,
    "arith_intensity_p10": 1.55,
    "arith_intensity_median": 2.595,
    "arith_intensity_p90": 3.445,
    "arith_intensity_mad": 0.254,
    "arith_intensity_range": 1.894,
    "arith_intensity_iqr": 0.512,
    "bw_flops_covariance": 382.76,
    "bw_flops_correlation": 0.063,
    "avg_performance_gflops": 503.26,
    "median_performance_gflops": 519.8,
    "performance_gflops_mad": 16.97,
    "avg_memory_bw_gbs": 350.0,
    "scalar_peak_gflops": 432.0,
    "simd_peak_gflops": 9216.0,
    "node_num": 3,
    "duration": 31133,
}

# Raw JSON roofline data (before aggregation, as produced by monitoring)
SAMPLE_JSON_ROOFLINE = json.dumps([
    {"node_num": 1, "bandwidth_raw": 150.5, "flops_raw": 2500.0, "arith_intensity": 16.6,
     "performance_gflops": 1200.0, "memory_bw_gbs": 450, "scalar_peak_gflops": 600, "duration": 3600},
    {"node_num": 1, "bandwidth_raw": 155.2, "flops_raw": 2600.0, "arith_intensity": 16.8,
     "performance_gflops": 1250.0, "memory_bw_gbs": 450, "scalar_peak_gflops": 600, "duration": 3600},
    {"node_num": 1, "bandwidth_raw": 148.0, "flops_raw": 2450.0, "arith_intensity": 16.5,
     "performance_gflops": 1180.0, "memory_bw_gbs": 450, "scalar_peak_gflops": 600, "duration": 3600},
])


def main():
    print("XGBoost Multi-Label Inference Examples")
    print("=" * 50)

    # Initialize the predictor
    predictor = XGBoostMultiLabelPredictor()

    # =========================================================================
    # Example 1: Single prediction with aggregated features
    # =========================================================================
    print("\n=== Example 1: Single Prediction (TurTLE workload) ===")

    result = predictor.predict(SAMPLE_TURTLE, threshold=0.3)

    print(f"Predictions: {result['predictions']}")
    print(f"Confidences: {result['confidences']}")
    print("\nTop 5 probabilities:")
    sorted_probs = sorted(result['probabilities'].items(), key=lambda x: x[1], reverse=True)
    for cls, prob in sorted_probs[:5]:
        print(f"  {cls}: {prob:.4f}")

    # =========================================================================
    # Example 2: Compare different workload types
    # =========================================================================
    print("\n=== Example 2: Compare Different Workloads ===")

    workloads = [
        ("TurTLE (turbulence)", SAMPLE_TURTLE),
        ("SCALEXA (benchmark)", SAMPLE_SCALEXA),
        ("Chroma (lattice QCD)", SAMPLE_CHROMA),
    ]

    for name, features in workloads:
        result = predictor.predict(features, threshold=0.3)
        top_pred = result['predictions'][0] if result['predictions'] else "None"
        top_prob = max(result['probabilities'].values()) if result['probabilities'] else 0
        print(f"{name:25} -> Top prediction: {top_pred:20} (prob: {top_prob:.4f})")

    # =========================================================================
    # Example 3: Top-K predictions
    # =========================================================================
    print("\n=== Example 3: Top-5 Predictions (Chroma workload) ===")

    top_k_result = predictor.predict_top_k(SAMPLE_CHROMA, k=5)

    for i, cls in enumerate(top_k_result['top_predictions'], 1):
        prob = top_k_result['top_probabilities'][cls]
        print(f"  {i}. {cls}: {prob:.4f}")

    # =========================================================================
    # Example 4: Batch prediction
    # =========================================================================
    print("\n=== Example 4: Batch Prediction ===")

    batch_data = [SAMPLE_TURTLE, SAMPLE_SCALEXA, SAMPLE_CHROMA]
    batch_results = predictor.batch_predict(batch_data, threshold=0.3)

    for i, result in enumerate(batch_results, 1):
        if 'error' not in result:
            preds = result['predictions'][:2]  # Show top 2
            print(f"  Sample {i}: {preds}")
        else:
            print(f"  Sample {i}: Error - {result['error']}")

    # =========================================================================
    # Example 5: Prediction from raw JSON roofline data
    # =========================================================================
    print("\n=== Example 5: Prediction from Raw JSON Data ===")

    result = predictor.predict(
        SAMPLE_JSON_ROOFLINE,
        is_json=True,
        job_id="example_job_001",
        threshold=0.3
    )

    print(f"Predictions: {result['predictions'][:3]}")
    print(f"(Aggregated from {len(json.loads(SAMPLE_JSON_ROOFLINE))} roofline samples)")

    # =========================================================================
    # Example 6: Model information
    # =========================================================================
    print("\n=== Example 6: Model Information ===")

    info = predictor.get_class_info()
    print(f"Number of classes: {info['n_classes']}")
    print(f"Number of features: {info['n_features']}")
    print(f"Sample classes: {info['classes'][:5]}...")
    print(f"Sample features: {info['feature_columns'][:3]}...")


if __name__ == "__main__":
    main()
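The predictor's `_prepare_features` pads any feature columns missing from the input with 0.0 and then reorders the frame to the model's fixed column list. A minimal sketch of that behavior, assuming a hypothetical three-feature model (this is not code from the repository; pandas offers `reindex` as a one-call equivalent of the loop-and-reorder approach used in the class):

```python
import pandas as pd

# Hypothetical feature columns a trained model might expect
feature_columns = ['a', 'b', 'c']

# Input supplies only one of the three features
df = pd.DataFrame([{'b': 2.0}])

# reindex inserts the absent columns filled with 0.0 and enforces the order
df = df.reindex(columns=feature_columns, fill_value=0.0)
print(df.values.tolist())  # [[0.0, 2.0, 0.0]]
```

Filling missing features with 0.0 keeps inference permissive, but note it silently masks incomplete inputs; the class's separate `n_features` check only catches a wrong column count, not zero-filled columns.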
BIN
xgb_model.joblib
Normal file
Binary file not shown.