Initial commit
.gitignore (vendored, new file, 62 lines)
@@ -0,0 +1,62 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# Virtual environments
venv/
ENV/
env/
.venv/

# IDE
.idea/
.vscode/
*.swp
*.swo
*~
.DS_Store

# Jupyter
.ipynb_checkpoints/

# Model files (optional - uncomment if you don't want to track large models)
# *.joblib
# *.pkl
# *.h5

# Logs
*.log
logs/

# Environment variables
.env
.env.local

# Testing
.pytest_cache/
.coverage
htmlcov/

# Go binaries
*.exe
*.exe~
*.dll
*.dylib
Makefile (new file, 42 lines)
@@ -0,0 +1,42 @@
.PHONY: install run-fastapi run-fastapi-dev run-flask run-flask-debug test clean help

# Default Python interpreter
PYTHON := python3

# Default ports
FASTAPI_PORT := 8000
FLASK_PORT := 5000

help:
	@echo "XGBoost Multi-Label Classification API"
	@echo ""
	@echo "Usage:"
	@echo "  make install        Install Python dependencies"
	@echo "  make run-fastapi    Start FastAPI server (port $(FASTAPI_PORT))"
	@echo "  make run-flask      Start Flask server (port $(FLASK_PORT))"
	@echo "  make test           Run the inference example"
	@echo "  make clean          Clean up cache files"
	@echo ""

install:
	$(PYTHON) -m pip install -r requirements.txt

run-fastapi:
	$(PYTHON) xgb_fastapi.py --port $(FASTAPI_PORT)

run-fastapi-dev:
	$(PYTHON) xgb_fastapi.py --port $(FASTAPI_PORT) --reload

run-flask:
	$(PYTHON) xgb_rest_api.py --port $(FLASK_PORT)

run-flask-debug:
	$(PYTHON) xgb_rest_api.py --port $(FLASK_PORT) --debug

test:
	$(PYTHON) xgb_inference_example.py

clean:
	find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true
	find . -type f -name "*.pyc" -delete 2>/dev/null || true
	find . -type f -name "*.pyo" -delete 2>/dev/null || true
README.md (new file, 105 lines)
@@ -0,0 +1,105 @@
# XGBoost Multi-Label Classification API

A REST API for multi-label classification of HPC workloads using XGBoost, classifying applications from their roofline performance metrics.

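The roofline metrics referred to throughout are the standard ones: arithmetic intensity is the ratio of floating-point operations to bytes moved between memory and the cores (FLOPs/byte), and a workload's attainable performance is bounded by min(peak FLOP/s, arithmetic intensity × memory bandwidth). The classifier consumes robust per-job, per-node statistics of these quantities (percentiles, MAD, IQR) produced by `feature_aggregator.py`, rather than raw time series.
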
## Features

- **Multi-Label Classification**: Predict multiple labels with confidence scores
- **Top-K Predictions**: Get the most likely K predictions
- **Batch Prediction**: Process multiple samples in a single request
- **JSON Aggregation**: Aggregate raw roofline data into features automatically
- **Around 60 HPC Application Classes**: Including VASP, GROMACS, TurTLE, Chroma, QuantumESPRESSO, etc.

## Installation

```bash
# Create and activate virtual environment
python -m venv venv
source venv/bin/activate

# Install dependencies
pip install -r requirements.txt
```

## Quick Start

### Run Tests

```bash
# Python tests
pytest test_xgb_fastapi.py -v

# Curl tests (start server first)
./test_api_curl.sh
```

## API Endpoints

| Endpoint | Method | Description |
|----------|--------|-------------|
| `/health` | GET | Health check and model status |
| `/predict` | POST | Single prediction with confidence scores |
| `/predict_top_k` | POST | Get top-K predictions |
| `/batch_predict` | POST | Batch prediction for multiple samples |
| `/model/info` | GET | Model information |

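The usage examples below exercise `/predict`; `/batch_predict` takes the same feature dictionaries as a list. A minimal Python sketch using `httpx` (already in `requirements.txt`): the `features_list` and `threshold` fields mirror the curl test suite, the two abbreviated sample dictionaries are placeholders, and the exact response format depends on the running server.

```python
import httpx

# Each entry must carry the 28 aggregated roofline features from the
# "Model Features" table below; the two dictionaries here are abbreviated
# placeholders and will not produce meaningful predictions on their own.
sample_a = {"bandwidth_raw_median": 205.14, "flops_raw_median": 171.45, "node_num": 0, "duration": 19366}
sample_b = {"bandwidth_raw_median": 200.57, "flops_raw_median": 519.8, "node_num": 3, "duration": 31133}

resp = httpx.post(
    "http://localhost:8000/batch_predict",
    json={"features_list": [sample_a, sample_b], "threshold": 0.3},
)
resp.raise_for_status()
print(resp.json())  # one result per submitted sample
```
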
## Usage Examples

### Start the Server

```bash
python xgb_fastapi.py --port 8000
```

### Testing with curl

```bash
curl -X POST "http://localhost:8000/predict" \
  -H "Content-Type: application/json" \
  -d '{
    "features": {
      "bandwidth_raw_p10": 186.33,
      "bandwidth_raw_median": 205.14,
      "bandwidth_raw_p90": 210.83,
      "flops_raw_p10": 162.024,
      "flops_raw_median": 171.45,
      "flops_raw_p90": 176.48,
      "arith_intensity_median": 0.837,
      "node_num": 0,
      "duration": 19366
    },
    "threshold": 0.3
  }'
```

### Testing with Python

```python
from xgb_local import XGBoostMultiLabelPredictor

predictor = XGBoostMultiLabelPredictor('xgb_model.joblib')

# `features` is a dict with the 28 aggregated roofline features (see table below)
result = predictor.predict(features, threshold=0.3)
print(f"Predictions: {result['predictions']}")
print(f"Confidences: {result['confidences']}")

# Top-K predictions
top_k = predictor.predict_top_k(features, k=5)
for cls, prob in top_k['top_probabilities'].items():
    print(f"{cls}: {prob:.4f}")
```

See `xgb_local_example.py` for complete examples.

## Model Features (28 total)

| Category | Features |
|----------|----------|
| Bandwidth | `bandwidth_raw_p10`, `_median`, `_p90`, `_mad`, `_range`, `_iqr` |
| FLOPS | `flops_raw_p10`, `_median`, `_p90`, `_mad`, `_range`, `_iqr` |
| Arithmetic Intensity | `arith_intensity_p10`, `_median`, `_p90`, `_mad`, `_range`, `_iqr` |
| Performance | `avg_performance_gflops`, `median_performance_gflops`, `performance_gflops_mad` |
| Correlation | `bw_flops_covariance`, `bw_flops_correlation` |
| System | `avg_memory_bw_gbs`, `scalar_peak_gflops`, `simd_peak_gflops` |
| Other | `node_num`, `duration` |
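
The suffixes in the first three rows are shorthand; the 18 per-metric column names expand as in this small helper (illustrative only, not part of the API):

```python
metrics = ["bandwidth_raw", "flops_raw", "arith_intensity"]
stats = ["p10", "median", "p90", "mad", "range", "iqr"]

# 3 metrics x 6 statistics = 18 per-metric features; the remaining 10 columns
# from the Performance, Correlation, System, and Other rows bring the total to 28.
per_metric_features = [f"{m}_{s}" for m in metrics for s in stats]
print(len(per_metric_features))  # 18
```
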
cluster.json (executable, new file, 600 lines)
@@ -0,0 +1,600 @@
{
  "name": "fritz",
  "metricConfig": [
    {
      "name": "cpu_load", "unit": { "base": "" }, "scope": "node", "aggregation": "avg", "footprint": "avg",
      "timestep": 60, "peak": 72, "normal": 72, "caution": 36, "alert": 20,
      "subClusters": [
        { "name": "spr1tb", "footprint": "avg", "peak": 104, "normal": 104, "caution": 52, "alert": 20 },
        { "name": "spr2tb", "footprint": "avg", "peak": 104, "normal": 104, "caution": 52, "alert": 20 }
      ]
    },
    {
      "name": "cpu_user", "unit": { "base": "" }, "scope": "hwthread", "aggregation": "avg",
      "timestep": 60, "peak": 100, "normal": 50, "caution": 20, "alert": 10
    },
    {
      "name": "mem_used", "unit": { "base": "B", "prefix": "G" }, "scope": "node", "aggregation": "sum", "footprint": "max",
      "timestep": 60, "peak": 256, "normal": 128, "caution": 200, "alert": 240,
      "subClusters": [
        { "name": "spr1tb", "footprint": "max", "peak": 1024, "normal": 512, "caution": 900, "alert": 1000 },
        { "name": "spr2tb", "footprint": "max", "peak": 2048, "normal": 1024, "caution": 1800, "alert": 2000 }
      ]
    },
    {
      "name": "flops_any", "unit": { "base": "Flops/s", "prefix": "G" }, "scope": "hwthread", "aggregation": "sum", "footprint": "avg",
      "timestep": 60, "peak": 5600, "normal": 1000, "caution": 200, "alert": 50,
      "subClusters": [
        { "name": "spr1tb", "footprint": "avg", "peak": 6656, "normal": 1500, "caution": 400, "alert": 50 },
        { "name": "spr2tb", "footprint": "avg", "peak": 6656, "normal": 1500, "caution": 400, "alert": 50 }
      ]
    },
    {
      "name": "flops_sp", "unit": { "base": "Flops/s", "prefix": "G" }, "scope": "hwthread", "aggregation": "sum",
      "timestep": 60, "peak": 5600, "normal": 1000, "caution": 200, "alert": 50,
      "subClusters": [
        { "name": "spr1tb", "peak": 6656, "normal": 1500, "caution": 400, "alert": 50 },
        { "name": "spr2tb", "peak": 6656, "normal": 1500, "caution": 400, "alert": 50 }
      ]
    },
    {
      "name": "flops_dp", "unit": { "base": "Flops/s", "prefix": "G" }, "scope": "hwthread", "aggregation": "sum",
      "timestep": 60, "peak": 2300, "normal": 500, "caution": 100, "alert": 50,
      "subClusters": [
        { "name": "spr1tb", "peak": 3300, "normal": 750, "caution": 200, "alert": 50 },
        { "name": "spr2tb", "peak": 3300, "normal": 750, "caution": 200, "alert": 50 }
      ]
    },
    {
      "name": "mem_bw", "unit": { "base": "B/s", "prefix": "G" }, "scope": "socket", "aggregation": "sum", "footprint": "avg",
      "timestep": 60, "peak": 350, "normal": 100, "caution": 50, "alert": 10,
      "subClusters": [
        { "name": "spr1tb", "footprint": "avg", "peak": 549, "normal": 200, "caution": 100, "alert": 20 },
        { "name": "spr2tb", "footprint": "avg", "peak": 520, "normal": 200, "caution": 100, "alert": 20 }
      ]
    },
    {
      "name": "clock", "unit": { "base": "Hz", "prefix": "M" }, "scope": "hwthread", "aggregation": "avg",
      "timestep": 60, "peak": 3000, "normal": 2400, "caution": 1800, "alert": 1200,
      "subClusters": [
        { "name": "spr1tb", "peak": 3000, "normal": 2000, "caution": 1600, "alert": 1200, "remove": false },
        { "name": "spr2tb", "peak": 3000, "normal": 2000, "caution": 1600, "alert": 1200, "remove": false }
      ]
    },
    {
      "name": "cpu_power", "unit": { "base": "W" }, "scope": "socket", "aggregation": "sum", "energy": "power",
      "timestep": 60, "peak": 500, "normal": 250, "caution": 100, "alert": 50,
      "subClusters": [
        { "name": "spr1tb", "energy": "power", "peak": 700, "normal": 350, "caution": 150, "alert": 50 },
        { "name": "spr2tb", "energy": "power", "peak": 700, "normal": 350, "caution": 150, "alert": 50 }
      ]
    },
    {
      "name": "mem_power", "unit": { "base": "W" }, "scope": "socket", "aggregation": "sum", "energy": "power",
      "timestep": 60, "peak": 100, "normal": 50, "caution": 20, "alert": 10,
      "subClusters": [
        { "name": "spr1tb", "energy": "power", "peak": 400, "normal": 200, "caution": 80, "alert": 40 },
        { "name": "spr2tb", "energy": "power", "peak": 800, "normal": 400, "caution": 160, "alert": 80 }
      ]
    },
    {
      "name": "ipc", "unit": { "base": "IPC" }, "scope": "hwthread", "aggregation": "avg",
      "timestep": 60, "peak": 4, "normal": 2, "caution": 1, "alert": 0.5,
      "subClusters": [
        { "name": "spr1tb", "peak": 6, "normal": 2, "caution": 1, "alert": 0.5 },
        { "name": "spr2tb", "peak": 6, "normal": 2, "caution": 1, "alert": 0.5 }
      ]
    },
    {
      "name": "vectorization_ratio", "unit": { "base": "" }, "scope": "hwthread", "aggregation": "avg",
      "timestep": 60, "peak": 100, "normal": 60, "caution": 40, "alert": 10,
      "subClusters": [
        { "name": "spr1tb", "peak": 100, "normal": 60, "caution": 40, "alert": 10 },
        { "name": "spr2tb", "peak": 100, "normal": 60, "caution": 40, "alert": 10 }
      ]
    },
    {
      "name": "ib_recv", "unit": { "base": "B/s" }, "scope": "node", "aggregation": "sum",
      "timestep": 60, "peak": 1250000, "normal": 6000000, "caution": 200, "alert": 1
    },
    {
      "name": "ib_xmit", "unit": { "base": "B/s" }, "scope": "node", "aggregation": "sum",
      "timestep": 60, "peak": 1250000, "normal": 6000000, "caution": 200, "alert": 1
    },
    {
      "name": "ib_recv_pkts", "unit": { "base": "packets/s" }, "scope": "node", "aggregation": "sum",
      "timestep": 60, "peak": 6, "normal": 4, "caution": 2, "alert": 1
    },
    {
      "name": "ib_xmit_pkts", "unit": { "base": "packets/s" }, "scope": "node", "aggregation": "sum",
      "timestep": 60, "peak": 6, "normal": 4, "caution": 2, "alert": 1
    },
    {
      "name": "nfs4_read", "unit": { "base": "IOP", "prefix": "" }, "scope": "node", "aggregation": "sum",
      "timestep": 60, "peak": 1000, "normal": 50, "caution": 200, "alert": 500
    },
    {
      "name": "nfs4_total", "unit": { "base": "IOP", "prefix": "" }, "scope": "node", "aggregation": "sum",
      "timestep": 60, "peak": 1000, "normal": 50, "caution": 200, "alert": 500
    }
  ],
  "subClusters": [
    {
      "name": "main",
      "nodes": "f[0101-0188,0201-0288,0301-0388,0401-0488,0501-0588,0601-0688,0701-0788,0801-0888,0901-0988,1001-1088,1101-1156,1201-1256]",
      "processorType": "Intel Icelake",
      "socketsPerNode": 2,
      "coresPerSocket": 36,
      "threadsPerCore": 1,
      "flopRateScalar": { "unit": { "base": "F/s", "prefix": "G" }, "value": 432 },
      "flopRateSimd": { "unit": { "base": "F/s", "prefix": "G" }, "value": 9216 },
      "memoryBandwidth": { "unit": { "base": "B/s", "prefix": "G" }, "value": 350 },
      "topology": {
        "node": [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71],
        "socket": [
          [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35],
          [36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71]
        ],
        "memoryDomain": [
          [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17],
          [18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35],
          [36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53],
          [54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71]
        ],
        "core": [[0],[1],[2],[3],[4],[5],[6],[7],[8],[9],[10],[11],[12],[13],[14],[15],[16],[17],[18],[19],[20],[21],[22],[23],[24],[25],[26],[27],[28],[29],[30],[31],[32],[33],[34],[35],[36],[37],[38],[39],[40],[41],[42],[43],[44],[45],[46],[47],[48],[49],[50],[51],[52],[53],[54],[55],[56],[57],[58],[59],[60],[61],[62],[63],[64],[65],[66],[67],[68],[69],[70],[71]]
      }
    },
    {
      "name": "spr1tb",
      "processorType": "Intel(R) Xeon(R) Platinum 8470",
      "socketsPerNode": 2,
      "coresPerSocket": 52,
      "threadsPerCore": 1,
      "flopRateScalar": { "unit": { "base": "F/s", "prefix": "G" }, "value": 695 },
      "flopRateSimd": { "unit": { "base": "F/s", "prefix": "G" }, "value": 9216 },
      "memoryBandwidth": { "unit": { "base": "B/s", "prefix": "G" }, "value": 549 },
      "nodes": "f[2157-2180,2257-2280]",
      "topology": {
        "node": [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103],
        "socket": [
          [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51],
          [52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103]
        ],
        "memoryDomain": [
          [0,1,2,3,4,5,6,7,8,9,10,11,12],
          [13,14,15,16,17,18,19,20,21,22,23,24,25],
          [26,27,28,29,30,31,32,33,34,35,36,37,38],
          [39,40,41,42,43,44,45,46,47,48,49,50,51],
          [52,53,54,55,56,57,58,59,60,61,62,63,64],
          [65,66,67,68,69,70,71,72,73,74,75,76,77],
          [78,79,80,81,82,83,84,85,86,87,88,89,90],
          [91,92,93,94,95,96,97,98,99,100,101,102,103]
        ],
        "core": [[0],[1],[2],[3],[4],[5],[6],[7],[8],[9],[10],[11],[12],[13],[14],[15],[16],[17],[18],[19],[20],[21],[22],[23],[24],[25],[26],[27],[28],[29],[30],[31],[32],[33],[34],[35],[36],[37],[38],[39],[40],[41],[42],[43],[44],[45],[46],[47],[48],[49],[50],[51],[52],[53],[54],[55],[56],[57],[58],[59],[60],[61],[62],[63],[64],[65],[66],[67],[68],[69],[70],[71],[72],[73],[74],[75],[76],[77],[78],[79],[80],[81],[82],[83],[84],[85],[86],[87],[88],[89],[90],[91],[92],[93],[94],[95],[96],[97],[98],[99],[100],[101],[102],[103]]
      }
    },
    {
      "name": "spr2tb",
      "processorType": "Intel(R) Xeon(R) Platinum 8470",
      "socketsPerNode": 2,
      "coresPerSocket": 52,
      "threadsPerCore": 1,
      "flopRateScalar": { "unit": { "base": "F/s", "prefix": "G" }, "value": 695 },
      "flopRateSimd": { "unit": { "base": "F/s", "prefix": "G" }, "value": 9216 },
      "memoryBandwidth": { "unit": { "base": "B/s", "prefix": "G" }, "value": 515 },
      "nodes": "f[2181-2188,2281-2288]",
      "topology": {
        "node": [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103],
        "socket": [
          [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51],
          [52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103]
        ],
        "memoryDomain": [
          [0,1,2,3,4,5,6,7,8,9,10,11,12],
          [13,14,15,16,17,18,19,20,21,22,23,24,25],
          [26,27,28,29,30,31,32,33,34,35,36,37,38],
          [39,40,41,42,43,44,45,46,47,48,49,50,51],
          [52,53,54,55,56,57,58,59,60,61,62,63,64],
          [65,66,67,68,69,70,71,72,73,74,75,76,77],
          [78,79,80,81,82,83,84,85,86,87,88,89,90],
          [91,92,93,94,95,96,97,98,99,100,101,102,103]
        ],
        "core": [[0],[1],[2],[3],[4],[5],[6],[7],[8],[9],[10],[11],[12],[13],[14],[15],[16],[17],[18],[19],[20],[21],[22],[23],[24],[25],[26],[27],[28],[29],[30],[31],[32],[33],[34],[35],[36],[37],[38],[39],[40],[41],[42],[43],[44],[45],[46],[47],[48],[49],[50],[51],[52],[53],[54],[55],[56],[57],[58],[59],[60],[61],[62],[63],[64],[65],[66],[67],[68],[69],[70],[71],[72],[73],[74],[75],[76],[77],[78],[79],[80],[81],[82],[83],[84],[85],[86],[87],[88],[89],[90],[91],[92],[93],[94],[95],[96],[97],[98],[99],[100],[101],[102],[103]]
      }
    }
  ]
}
feature_aggregator.py (new file, 172 lines)
@@ -0,0 +1,172 @@
import os
import pandas as pd
import numpy as np
import sqlite3
import h5py
from glob import glob
from scipy import stats
from pathlib import Path


def compute_mad(data):
    """Compute Median Absolute Deviation: a robust measure of dispersion"""
    return np.median(np.abs(data - np.median(data)))

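# Worked example for compute_mad (values chosen purely for illustration):
# for data = [1, 2, 4, 10] the median is 3.0, the absolute deviations are
# [2, 1, 1, 7], and their median, the MAD, is 1.5.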


def process_roofline_data(base_dir='D:/roofline_dataframes',
                          output_file='D:/roofline_features.h5',
                          job_tags_db='job_tags.db'):
    """
    Process roofline data to extract features for machine learning.

    Args:
        base_dir: Directory containing roofline dataframes
        output_file: Path to save the output features
        job_tags_db: Path to SQLite database with job tags
    """
    # Connect to the SQLite database
    conn = sqlite3.connect(job_tags_db)
    cursor = conn.cursor()

    # List to store all job features
    all_job_features = []

    # Find all job prefix folders in the base directory
    job_prefixes = [d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d))]

    for job_id_prefix in job_prefixes:
        job_prefix_path = os.path.join(base_dir, job_id_prefix)

        # Find all h5 files in this job prefix folder
        h5_files = glob(os.path.join(job_prefix_path, '*_dataframe.h5'))

        for h5_file in h5_files:
            filename = os.path.basename(h5_file)
            # Extract job_id_full from the filename pattern: {job_id_full}_{timestamp}_dataframe.h5
            job_id_full = filename.split('_')[0]

            try:
                # Read the dataframe from the h5 file
                with h5py.File(h5_file, 'r') as f:
                    if 'dataframe' not in f:
                        print(f"Warning: No 'dataframe' key in {h5_file}")
                        continue

                df = pd.read_hdf(h5_file, key='dataframe')

                # Group data by node_num
                grouped = df.groupby('node_num')

                for node_num, group in grouped:
                    features = {
                        'job_id': job_id_full,
                        'node_num': node_num
                    }

                    # Compute statistics for key metrics
                    for axis in ['bandwidth_raw', 'flops_raw', 'arith_intensity']:
                        data = group[axis].values

                        # Compute percentiles
                        p10 = np.percentile(data, 10)
                        p50 = np.median(data)
                        p90 = np.percentile(data, 90)

                        # Compute MAD (more robust than variance)
                        mad = compute_mad(data)

                        # Store features
                        features[f'{axis}_p10'] = p10
                        features[f'{axis}_median'] = p50
                        features[f'{axis}_p90'] = p90
                        features[f'{axis}_mad'] = mad
                        features[f'{axis}_range'] = p90 - p10
                        features[f'{axis}_iqr'] = np.percentile(data, 75) - np.percentile(data, 25)

                    # Compute covariance and correlation between bandwidth_raw and flops_raw
                    if len(group) > 1:  # Need at least 2 points for correlation
                        cov = np.cov(group['bandwidth_raw'], group['flops_raw'])[0, 1]
                        features['bw_flops_covariance'] = cov

                        corr, _ = stats.pearsonr(group['bandwidth_raw'], group['flops_raw'])
                        features['bw_flops_correlation'] = corr

                    # Additional useful features for the classifier

                    # Performance metrics
                    features['avg_performance_gflops'] = group['performance_gflops'].mean()
                    features['median_performance_gflops'] = group['performance_gflops'].median()
                    features['performance_gflops_mad'] = compute_mad(group['performance_gflops'].values)

                    # # Efficiency metrics
                    # features['avg_efficiency'] = group['efficiency'].mean()
                    # features['median_efficiency'] = group['efficiency'].median()
                    # features['efficiency_mad'] = compute_mad(group['efficiency'].values)
                    # features['efficiency_p10'] = np.percentile(group['efficiency'].values, 10)
                    # features['efficiency_p90'] = np.percentile(group['efficiency'].values, 90)

                    # # Distribution of roofline regions (memory-bound vs compute-bound)
                    # if 'roofline_region' in group.columns:
                    #     region_counts = group['roofline_region'].value_counts(normalize=True).to_dict()
                    #     for region, ratio in region_counts.items():
                    #         features[f'region_{region}_ratio'] = ratio

                    # System characteristics
                    if 'memory_bw_gbs' in group.columns:
                        features['avg_memory_bw_gbs'] = group['memory_bw_gbs'].mean()
                    if 'scalar_peak_gflops' in group.columns and len(group['scalar_peak_gflops'].unique()) > 0:
                        features['scalar_peak_gflops'] = group['scalar_peak_gflops'].iloc[0]
                    if 'simd_peak_gflops' in group.columns and len(group['simd_peak_gflops'].unique()) > 0:
                        features['simd_peak_gflops'] = group['simd_peak_gflops'].iloc[0]

                    # # Subcluster information if available
                    # if 'subcluster_name' in group.columns and not group['subcluster_name'].isna().all():
                    #     features['subcluster_name'] = group['subcluster_name'].iloc[0]

                    # Duration information
                    if 'duration' in group.columns:
                        features['duration'] = group['duration'].iloc[0]

                    # Get the label (application type) from the database
                    cursor.execute("SELECT tags FROM job_tags WHERE job_id = ?", (int(job_id_full),))
                    result = cursor.fetchone()

                    if result:
                        # Extract application name from tags
                        tags = result[0]
                        features['label'] = tags
                    else:
                        features['label'] = 'Unknown'

                    all_job_features.append(features)

            except Exception as e:
                print(f"Error processing file {h5_file}: {e}")

    # Close database connection
    conn.close()

    if not all_job_features:
        print("No features extracted. Check if files exist and have the correct format.")
        return

    # Convert to DataFrame
    features_df = pd.DataFrame(all_job_features)

    # Fill missing roofline region ratios with 0
    region_columns = [col for col in features_df.columns if col.startswith('region_')]
    for col in region_columns:
        if col not in features_df.columns:
            features_df[col] = 0
        else:
            features_df[col] = features_df[col].fillna(0)

    # Save to H5 file
    features_df.to_hdf(output_file, key='features', mode='w')

    print(f"Processed {len(all_job_features)} job-node combinations")
    print(f"Features saved to {output_file}")
    print(f"Feature columns: {', '.join(features_df.columns)}")

    return features_df


if __name__ == "__main__":
    process_roofline_data()
requirements.txt (new file, 20 lines)
@@ -0,0 +1,20 @@
# Core ML dependencies
numpy>=1.21.0
pandas>=1.3.0
scipy>=1.7.0
joblib>=1.0.0
xgboost>=1.5.0
scikit-learn>=1.0.0

# FastAPI server
fastapi>=0.68.0
uvicorn>=0.15.0
pydantic>=1.8.0

# Flask server (alternative)
flask>=2.0.0
flask-cors>=3.0.0

# Testing
pytest>=7.0.0
httpx>=0.23.0
test_api_curl.sh (executable, new file, 436 lines)
@@ -0,0 +1,436 @@
#!/bin/bash
#
# Test script for XGBoost FastAPI Multi-Label Classification API
# Uses curl to test all endpoints with realistic sample data from roofline_features.h5
#
# Usage:
#   1. Start the server: ./venv/bin/python xgb_fastapi.py
#   2. Run this script:  ./test_api_curl.sh
#
# Optional: set the API_URL environment variable to test a different host:
#   API_URL=http://192.168.1.100:8000 ./test_api_curl.sh

set -e

# Configuration
API_URL="${API_URL:-http://localhost:8000}"
VERBOSE="${VERBOSE:-false}"

# Colors for output
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
NC='\033[0m' # No Color

# Counters
PASSED=0
FAILED=0

# ============================================================================
# Helper Functions
# ============================================================================

print_header() {
    echo ""
    echo -e "${BLUE}============================================================${NC}"
    echo -e "${BLUE}$1${NC}"
    echo -e "${BLUE}============================================================${NC}"
}

print_test() {
    echo -e "\n${YELLOW}▶ TEST: $1${NC}"
}

print_success() {
    echo -e "${GREEN}✓ PASSED: $1${NC}"
    # "|| true" keeps set -e from aborting: arithmetic post-increment
    # returns a non-zero status when the counter is still 0
    ((PASSED++)) || true
}

print_failure() {
    echo -e "${RED}✗ FAILED: $1${NC}"
    ((FAILED++)) || true
}

# Make a curl request and check status code
# Usage: make_request METHOD ENDPOINT [DATA] [EXPECTED_STATUS]
make_request() {
    local method=$1
    local endpoint=$2
    local data=$3
    local expected_status=${4:-200}

    if [ "$method" == "GET" ]; then
        response=$(curl -s -w "\n%{http_code}" "${API_URL}${endpoint}")
    else
        response=$(curl -s -w "\n%{http_code}" -X "$method" \
            -H "Content-Type: application/json" \
            -d "$data" \
            "${API_URL}${endpoint}")
    fi

    # Extract status code (last line) and body (everything else)
    http_code=$(echo "$response" | tail -n1)
    body=$(echo "$response" | sed '$d')

    if [ "$VERBOSE" == "true" ]; then
        echo "Response: $body"
        echo "Status: $http_code"
    fi

    if [ "$http_code" == "$expected_status" ]; then
        return 0
    else
        echo "Expected status $expected_status, got $http_code"
        echo "Response: $body"
        return 1
    fi
}

# ============================================================================
# Sample Data (from roofline_features.h5)
# ============================================================================

# TurTLE application - turbulence simulation workload
SAMPLE_TURTLE='{
    "bandwidth_raw_p10": 186.33, "bandwidth_raw_median": 205.14, "bandwidth_raw_p90": 210.83,
    "bandwidth_raw_mad": 3.57, "bandwidth_raw_range": 24.5, "bandwidth_raw_iqr": 12.075,
    "flops_raw_p10": 162.024, "flops_raw_median": 171.45, "flops_raw_p90": 176.48,
    "flops_raw_mad": 3.08, "flops_raw_range": 14.456, "flops_raw_iqr": 8.29,
    "arith_intensity_p10": 0.7906, "arith_intensity_median": 0.837, "arith_intensity_p90": 0.9109,
    "arith_intensity_mad": 0.02, "arith_intensity_range": 0.12, "arith_intensity_iqr": 0.0425,
    "bw_flops_covariance": 60.86, "bw_flops_correlation": 0.16,
    "avg_performance_gflops": 168.1, "median_performance_gflops": 171.45, "performance_gflops_mad": 3.08,
    "avg_memory_bw_gbs": 350.0, "scalar_peak_gflops": 432.0, "simd_peak_gflops": 9216.0,
    "node_num": 0, "duration": 19366
}'

# Chroma application - lattice QCD workload (compute-intensive)
SAMPLE_CHROMA='{
    "bandwidth_raw_p10": 154.176, "bandwidth_raw_median": 200.57, "bandwidth_raw_p90": 259.952,
    "bandwidth_raw_mad": 5.12, "bandwidth_raw_range": 105.776, "bandwidth_raw_iqr": 10.215,
    "flops_raw_p10": 327.966, "flops_raw_median": 519.8, "flops_raw_p90": 654.422,
    "flops_raw_mad": 16.97, "flops_raw_range": 326.456, "flops_raw_iqr": 34.88,
    "arith_intensity_p10": 1.55, "arith_intensity_median": 2.595, "arith_intensity_p90": 3.445,
    "arith_intensity_mad": 0.254, "arith_intensity_range": 1.894, "arith_intensity_iqr": 0.512,
    "bw_flops_covariance": 382.76, "bw_flops_correlation": 0.063,
    "avg_performance_gflops": 503.26, "median_performance_gflops": 519.8, "performance_gflops_mad": 16.97,
    "avg_memory_bw_gbs": 350.0, "scalar_peak_gflops": 432.0, "simd_peak_gflops": 9216.0,
    "node_num": 3, "duration": 31133
}'

# Raw JSON roofline data (before aggregation)
SAMPLE_JSON_ROOFLINE='[
    {"node_num": 1, "bandwidth_raw": 150.5, "flops_raw": 2500.0, "arith_intensity": 16.6, "performance_gflops": 1200.0, "memory_bw_gbs": 450, "scalar_peak_gflops": 600, "duration": 3600},
    {"node_num": 1, "bandwidth_raw": 155.2, "flops_raw": 2600.0, "arith_intensity": 16.8, "performance_gflops": 1250.0, "memory_bw_gbs": 450, "scalar_peak_gflops": 600, "duration": 3600},
    {"node_num": 1, "bandwidth_raw": 148.0, "flops_raw": 2450.0, "arith_intensity": 16.5, "performance_gflops": 1180.0, "memory_bw_gbs": 450, "scalar_peak_gflops": 600, "duration": 3600}
]'

# ============================================================================
# Tests
# ============================================================================

print_header "XGBoost FastAPI Test Suite"
echo "Testing API at: $API_URL"
echo "Started at: $(date)"

# ----------------------------------------------------------------------------
# Health & Info Endpoints
# ----------------------------------------------------------------------------

print_header "Health & Info Endpoints"

print_test "Root Endpoint"
if make_request GET "/"; then
    print_success "Root endpoint accessible"
else
    print_failure "Root endpoint failed"
fi

print_test "Health Check"
if make_request GET "/health"; then
    print_success "Health check passed"
else
    print_failure "Health check failed"
fi

print_test "Model Info"
if make_request GET "/model/info"; then
    print_success "Model info retrieved"
else
    print_failure "Model info failed"
fi

# ----------------------------------------------------------------------------
# Single Prediction Endpoints
# ----------------------------------------------------------------------------

print_header "Single Prediction Endpoint (/predict)"

print_test "Predict with TurTLE sample (threshold=0.3)"
REQUEST_DATA=$(cat <<EOF
{
    "features": $SAMPLE_TURTLE,
    "threshold": 0.3,
    "return_all_probabilities": true
}
EOF
)
if make_request POST "/predict" "$REQUEST_DATA"; then
    print_success "TurTLE prediction completed"
else
    print_failure "TurTLE prediction failed"
fi

print_test "Predict with Chroma sample (threshold=0.5)"
REQUEST_DATA=$(cat <<EOF
{
    "features": $SAMPLE_CHROMA,
    "threshold": 0.5
}
EOF
)
if make_request POST "/predict" "$REQUEST_DATA"; then
    print_success "Chroma prediction completed"
else
    print_failure "Chroma prediction failed"
fi

print_test "Predict with JSON roofline data (is_json=true)"
# Need to escape the JSON for embedding
ESCAPED_JSON=$(echo "$SAMPLE_JSON_ROOFLINE" | tr -d '\n' | sed 's/"/\\"/g')
REQUEST_DATA=$(cat <<EOF
{
    "features": "$ESCAPED_JSON",
    "is_json": true,
    "job_id": "test_job_curl_001",
    "threshold": 0.3
}
EOF
)
if make_request POST "/predict" "$REQUEST_DATA"; then
    print_success "JSON roofline prediction completed"
else
    print_failure "JSON roofline prediction failed"
fi

print_test "Predict with return_all_probabilities=false"
REQUEST_DATA=$(cat <<EOF
{
    "features": $SAMPLE_TURTLE,
    "threshold": 0.3,
    "return_all_probabilities": false
}
EOF
)
if make_request POST "/predict" "$REQUEST_DATA"; then
    print_success "Prediction with filtered probabilities completed"
else
    print_failure "Prediction with filtered probabilities failed"
fi

# ----------------------------------------------------------------------------
# Top-K Prediction Endpoints
# ----------------------------------------------------------------------------

print_header "Top-K Prediction Endpoint (/predict_top_k)"

print_test "Top-5 predictions (default)"
REQUEST_DATA=$(cat <<EOF
{
    "features": $SAMPLE_TURTLE
}
EOF
)
if make_request POST "/predict_top_k" "$REQUEST_DATA"; then
    print_success "Top-5 prediction completed"
else
    print_failure "Top-5 prediction failed"
fi

print_test "Top-10 predictions"
REQUEST_DATA=$(cat <<EOF
{
    "features": $SAMPLE_CHROMA,
    "k": 10
}
EOF
)
if make_request POST "/predict_top_k" "$REQUEST_DATA"; then
    print_success "Top-10 prediction completed"
else
    print_failure "Top-10 prediction failed"
fi

print_test "Top-3 predictions"
REQUEST_DATA=$(cat <<EOF
{
    "features": $SAMPLE_TURTLE,
    "k": 3
}
EOF
)
if make_request POST "/predict_top_k" "$REQUEST_DATA"; then
    print_success "Top-3 prediction completed"
else
    print_failure "Top-3 prediction failed"
fi

# ----------------------------------------------------------------------------
# Batch Prediction Endpoints
# ----------------------------------------------------------------------------

print_header "Batch Prediction Endpoint (/batch_predict)"

print_test "Batch predict with 2 samples"
REQUEST_DATA=$(cat <<EOF
{
    "features_list": [$SAMPLE_TURTLE, $SAMPLE_CHROMA],
    "threshold": 0.3
}
EOF
)
if make_request POST "/batch_predict" "$REQUEST_DATA"; then
    print_success "Batch prediction (2 samples) completed"
else
    print_failure "Batch prediction (2 samples) failed"
fi

print_test "Batch predict with single sample"
REQUEST_DATA=$(cat <<EOF
{
    "features_list": [$SAMPLE_TURTLE],
    "threshold": 0.5
}
EOF
)
if make_request POST "/batch_predict" "$REQUEST_DATA"; then
    print_success "Batch prediction (single sample) completed"
else
    print_failure "Batch prediction (single sample) failed"
fi

print_test "Batch predict with empty list"
REQUEST_DATA='{"features_list": [], "threshold": 0.5}'
if make_request POST "/batch_predict" "$REQUEST_DATA"; then
    print_success "Batch prediction (empty list) completed"
else
    print_failure "Batch prediction (empty list) failed"
fi

# ----------------------------------------------------------------------------
# Error Handling Tests
# ----------------------------------------------------------------------------

print_header "Error Handling Tests"

print_test "Invalid threshold (negative) - should return 422"
REQUEST_DATA=$(cat <<EOF
{
    "features": $SAMPLE_TURTLE,
    "threshold": -0.5
}
EOF
)
if make_request POST "/predict" "$REQUEST_DATA" 422; then
    print_success "Invalid threshold correctly rejected"
else
    print_failure "Invalid threshold not rejected properly"
fi

print_test "Invalid threshold (> 1.0) - should return 422"
REQUEST_DATA=$(cat <<EOF
{
    "features": $SAMPLE_TURTLE,
    "threshold": 1.5
}
EOF
)
if make_request POST "/predict" "$REQUEST_DATA" 422; then
    print_success "Invalid threshold (>1) correctly rejected"
else
    print_failure "Invalid threshold (>1) not rejected properly"
fi

print_test "Missing features field - should return 422"
REQUEST_DATA='{"threshold": 0.5}'
if make_request POST "/predict" "$REQUEST_DATA" 422; then
    print_success "Missing features correctly rejected"
else
    print_failure "Missing features not rejected properly"
fi

print_test "Invalid endpoint - should return 404"
if make_request GET "/nonexistent" "" 404; then
    print_success "Invalid endpoint correctly returns 404"
else
    print_failure "Invalid endpoint not handled properly"
fi

print_test "Invalid k value (0) - should return 422"
REQUEST_DATA=$(cat <<EOF
{
    "features": $SAMPLE_TURTLE,
    "k": 0
}
EOF
)
if make_request POST "/predict_top_k" "$REQUEST_DATA" 422; then
    print_success "Invalid k value correctly rejected"
else
    print_failure "Invalid k value not rejected properly"
fi

# ============================================================================
# Summary
# ============================================================================

print_header "Test Summary"
echo ""
echo -e "Passed: ${GREEN}$PASSED${NC}"
echo -e "Failed: ${RED}$FAILED${NC}"
echo ""

if [ $FAILED -eq 0 ]; then
    echo -e "${GREEN}All tests passed! ✓${NC}"
    exit 0
else
    echo -e "${RED}Some tests failed! ✗${NC}"
    exit 1
fi
test_xgb_fastapi.py (new file, 811 lines)
@@ -0,0 +1,811 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Tests for XGBoost FastAPI Multi-Label Classification API
|
||||||
|
|
||||||
|
These tests use realistic sample data extracted from /Volumes/T7/roofline_features.h5
|
||||||
|
which was generated using feature_aggregator.py to process roofline dataframes.
|
||||||
|
|
||||||
|
Test data includes samples from different application types:
|
||||||
|
- TurTLE (turbulence simulation)
|
||||||
|
- SCALEXA (scaling benchmarks)
|
||||||
|
- Chroma (lattice QCD)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import json
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Import the FastAPI app and predictor
|
||||||
|
import xgb_fastapi
|
||||||
|
from xgb_fastapi import app
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Test Data: Realistic samples from roofline_features.h5
|
||||||
|
# Generated using feature_aggregator.py processing roofline dataframes
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
# Sample 1: TurTLE application - typical turbulence simulation workload
|
||||||
|
SAMPLE_TURTLE = {
|
||||||
|
"bandwidth_raw_p10": 186.33,
|
||||||
|
"bandwidth_raw_median": 205.14,
|
||||||
|
"bandwidth_raw_p90": 210.83,
|
||||||
|
"bandwidth_raw_mad": 3.57,
|
||||||
|
"bandwidth_raw_range": 24.5,
|
||||||
|
"bandwidth_raw_iqr": 12.075,
|
||||||
|
"flops_raw_p10": 162.024,
|
||||||
|
"flops_raw_median": 171.45,
|
||||||
|
"flops_raw_p90": 176.48,
|
||||||
|
"flops_raw_mad": 3.08,
|
||||||
|
"flops_raw_range": 14.456,
|
||||||
|
"flops_raw_iqr": 8.29,
|
||||||
|
"arith_intensity_p10": 0.7906,
|
||||||
|
"arith_intensity_median": 0.837,
|
||||||
|
"arith_intensity_p90": 0.9109,
|
||||||
|
"arith_intensity_mad": 0.02,
|
||||||
|
"arith_intensity_range": 0.12,
|
||||||
|
"arith_intensity_iqr": 0.0425,
|
||||||
|
"bw_flops_covariance": 60.86,
|
||||||
|
"bw_flops_correlation": 0.16,
|
||||||
|
"avg_performance_gflops": 168.1,
|
||||||
|
"median_performance_gflops": 171.45,
|
||||||
|
"performance_gflops_mad": 3.08,
|
||||||
|
"avg_memory_bw_gbs": 350.0,
|
||||||
|
"scalar_peak_gflops": 432.0,
|
||||||
|
"simd_peak_gflops": 9216.0,
|
||||||
|
"node_num": 0,
|
||||||
|
"duration": 19366,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Sample 2: SCALEXA application - scaling benchmark workload
|
||||||
|
SAMPLE_SCALEXA = {
|
||||||
|
"bandwidth_raw_p10": 13.474,
|
||||||
|
"bandwidth_raw_median": 32.57,
|
||||||
|
"bandwidth_raw_p90": 51.466,
|
||||||
|
"bandwidth_raw_mad": 23.62,
|
||||||
|
"bandwidth_raw_range": 37.992,
|
||||||
|
"bandwidth_raw_iqr": 23.745,
|
||||||
|
"flops_raw_p10": 4.24,
|
||||||
|
"flops_raw_median": 16.16,
|
||||||
|
"flops_raw_p90": 24.584,
|
||||||
|
"flops_raw_mad": 10.53,
|
||||||
|
"flops_raw_range": 20.344,
|
||||||
|
"flops_raw_iqr": 12.715,
|
||||||
|
"arith_intensity_p10": 0.211,
|
||||||
|
"arith_intensity_median": 0.475,
|
||||||
|
"arith_intensity_p90": 0.492,
|
||||||
|
"arith_intensity_mad": 0.021,
|
||||||
|
"arith_intensity_range": 0.281,
|
||||||
|
"arith_intensity_iqr": 0.176,
|
||||||
|
"bw_flops_covariance": 302.0,
|
||||||
|
"bw_flops_correlation": 0.995,
|
||||||
|
"avg_performance_gflops": 14.7,
|
||||||
|
"median_performance_gflops": 16.16,
|
||||||
|
"performance_gflops_mad": 10.53,
|
||||||
|
"avg_memory_bw_gbs": 350.0,
|
||||||
|
"scalar_peak_gflops": 432.0,
|
||||||
|
"simd_peak_gflops": 9216.0,
|
||||||
|
"node_num": 18,
|
||||||
|
"duration": 165,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Sample 3: Chroma application - lattice QCD workload (compute-intensive)
|
||||||
|
SAMPLE_CHROMA = {
|
||||||
|
"bandwidth_raw_p10": 154.176,
|
||||||
|
"bandwidth_raw_median": 200.57,
|
||||||
|
"bandwidth_raw_p90": 259.952,
|
||||||
|
"bandwidth_raw_mad": 5.12,
|
||||||
|
"bandwidth_raw_range": 105.776,
|
||||||
|
"bandwidth_raw_iqr": 10.215,
|
||||||
|
"flops_raw_p10": 327.966,
|
||||||
|
"flops_raw_median": 519.8,
|
||||||
|
"flops_raw_p90": 654.422,
|
||||||
|
"flops_raw_mad": 16.97,
|
||||||
|
"flops_raw_range": 326.456,
|
||||||
|
"flops_raw_iqr": 34.88,
|
||||||
|
"arith_intensity_p10": 1.55,
|
||||||
|
"arith_intensity_median": 2.595,
|
||||||
|
"arith_intensity_p90": 3.445,
|
||||||
|
"arith_intensity_mad": 0.254,
|
||||||
|
"arith_intensity_range": 1.894,
|
||||||
|
"arith_intensity_iqr": 0.512,
|
||||||
|
"bw_flops_covariance": 382.76,
|
||||||
|
"bw_flops_correlation": 0.063,
|
||||||
|
"avg_performance_gflops": 503.26,
|
||||||
|
"median_performance_gflops": 519.8,
|
||||||
|
"performance_gflops_mad": 16.97,
|
||||||
|
"avg_memory_bw_gbs": 350.0,
|
||||||
|
"scalar_peak_gflops": 432.0,
|
||||||
|
"simd_peak_gflops": 9216.0,
|
||||||
|
"node_num": 3,
|
||||||
|
"duration": 31133,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Sample JSON roofline data (raw data before aggregation, as would be received by API)
|
||||||
|
SAMPLE_JSON_ROOFLINE = json.dumps([
|
||||||
|
{
|
||||||
|
"node_num": 1,
|
||||||
|
"bandwidth_raw": 150.5,
|
||||||
|
"flops_raw": 2500.0,
|
||||||
|
"arith_intensity": 16.6,
|
||||||
|
"performance_gflops": 1200.0,
|
||||||
|
"memory_bw_gbs": 450,
|
||||||
|
"scalar_peak_gflops": 600,
|
||||||
|
"duration": 3600
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"node_num": 1,
|
||||||
|
"bandwidth_raw": 155.2,
|
||||||
|
"flops_raw": 2600.0,
|
||||||
|
"arith_intensity": 16.8,
|
||||||
|
"performance_gflops": 1250.0,
|
||||||
|
"memory_bw_gbs": 450,
|
||||||
|
"scalar_peak_gflops": 600,
|
||||||
|
"duration": 3600
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"node_num": 1,
|
||||||
|
"bandwidth_raw": 148.0,
|
||||||
|
"flops_raw": 2450.0,
|
||||||
|
"arith_intensity": 16.5,
|
||||||
|
"performance_gflops": 1180.0,
|
||||||
|
"memory_bw_gbs": 450,
|
||||||
|
"scalar_peak_gflops": 600,
|
||||||
|
"duration": 3600
|
||||||
|
}
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Fixtures
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
def setup_predictor():
    """
    Set up the predictor for tests.
    Try to load the real model if available.
    """
    from xgb_inference_api import XGBoostMultiLabelPredictor

    model_path = os.path.join(os.path.dirname(__file__), 'xgb_model.joblib')
    if os.path.exists(model_path):
        try:
            predictor = XGBoostMultiLabelPredictor(model_path)
            xgb_fastapi.predictor = predictor
            return True
        except Exception as e:
            print(f"Failed to load model: {e}")
            return False
    return False


@pytest.fixture
def client(setup_predictor):
    """Create a test client for the FastAPI app."""
    return TestClient(app)


@pytest.fixture
def model_loaded(setup_predictor):
    """Check if the model is loaded."""
    return setup_predictor


def skip_if_no_model(model_loaded):
    """Helper to skip tests if model is not loaded."""
    if not model_loaded:
        pytest.skip("Model not loaded, skipping test")

# ============================================================================
# Health and Root Endpoint Tests
# ============================================================================

class TestHealthEndpoints:
    """Tests for health check and root endpoints."""

    def test_root_endpoint(self, client):
        """Test the root endpoint returns API information."""
        response = client.get("/")
        assert response.status_code == 200

        data = response.json()
        assert "name" in data
        assert data["name"] == "XGBoost Multi-Label Classification API"
        assert "version" in data
        assert "endpoints" in data
        assert all(key in data["endpoints"] for key in ["health", "predict", "predict_top_k", "batch_predict"])

    def test_health_check(self, client, model_loaded):
        """Test the health check endpoint."""
        response = client.get("/health")
        assert response.status_code == 200

        data = response.json()
        assert "status" in data
        assert "model_loaded" in data

        # If model is loaded, check for additional info
        if model_loaded:
            assert data["model_loaded"] == True
            assert data["status"] in ["healthy", "degraded"]
            if data["status"] == "healthy":
                assert "n_classes" in data
                assert "n_features" in data
                assert "classes" in data
                assert data["n_classes"] > 0
                assert data["n_features"] > 0

    def test_health_check_model_not_loaded(self, client):
        """Test health check returns correct status when model not loaded."""
        response = client.get("/health")
        assert response.status_code == 200

        data = response.json()
        # Should have status and model_loaded fields regardless
        assert "status" in data
        assert "model_loaded" in data

# ============================================================================
# Single Prediction Tests
# ============================================================================

class TestPredictEndpoint:
    """Tests for the /predict endpoint."""

    def test_predict_with_feature_dict(self, client, model_loaded):
        """Test prediction with a feature dictionary (TurTLE sample)."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        request_data = {
            "features": SAMPLE_TURTLE,
            "threshold": 0.5,
            "return_all_probabilities": True
        }

        response = client.post("/predict", json=request_data)
        assert response.status_code == 200

        data = response.json()
        assert "predictions" in data
        assert "probabilities" in data
        assert "confidences" in data
        assert "threshold" in data

        assert isinstance(data["predictions"], list)
        assert isinstance(data["probabilities"], dict)
        assert data["threshold"] == 0.5

        # All probabilities should be between 0 and 1
        for prob in data["probabilities"].values():
            assert 0.0 <= prob <= 1.0

    def test_predict_with_different_thresholds(self, client, model_loaded):
        """Test that different thresholds affect predictions."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        request_low = {
            "features": SAMPLE_TURTLE,
            "threshold": 0.1,
            "return_all_probabilities": True
        }
        request_high = {
            "features": SAMPLE_TURTLE,
            "threshold": 0.9,
            "return_all_probabilities": True
        }

        response_low = client.post("/predict", json=request_low)
        response_high = client.post("/predict", json=request_high)

        assert response_low.status_code == 200
        assert response_high.status_code == 200

        data_low = response_low.json()
        data_high = response_high.json()

        # Lower threshold should generally produce more predictions
        assert len(data_low["predictions"]) >= len(data_high["predictions"])

    def test_predict_different_workloads(self, client, model_loaded):
        """Test predictions on different application workloads."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        samples = [
            ("TurTLE", SAMPLE_TURTLE),
            ("SCALEXA", SAMPLE_SCALEXA),
            ("Chroma", SAMPLE_CHROMA),
        ]

        for name, sample in samples:
            request_data = {
                "features": sample,
                "threshold": 0.3,
                "return_all_probabilities": True
            }

            response = client.post("/predict", json=request_data)
            assert response.status_code == 200, f"Failed for {name}"

            data = response.json()
            assert len(data["probabilities"]) > 0, f"No probabilities for {name}"

    def test_predict_return_only_predicted_probabilities(self, client, model_loaded):
        """Test prediction with return_all_probabilities=False."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        request_data = {
            "features": SAMPLE_TURTLE,
            "threshold": 0.3,
            "return_all_probabilities": False
        }

        response = client.post("/predict", json=request_data)
        assert response.status_code == 200

        data = response.json()
        # When return_all_probabilities is False, probabilities should only
        # contain classes that are in predictions
        if len(data["predictions"]) > 0:
            assert set(data["probabilities"].keys()) == set(data["predictions"])

    def test_predict_with_json_roofline_data(self, client, model_loaded):
        """Test prediction with raw JSON roofline data (requires aggregation)."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        request_data = {
            "features": SAMPLE_JSON_ROOFLINE,
            "is_json": True,
            "job_id": "test_job_123",
            "threshold": 0.3
        }

        response = client.post("/predict", json=request_data)
        assert response.status_code == 200

        data = response.json()
        assert "predictions" in data
        assert "probabilities" in data

    def test_predict_threshold_boundaries(self, client, model_loaded):
        """Test prediction with threshold at boundaries."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        for threshold in [0.0, 0.5, 1.0]:
            request_data = {
                "features": SAMPLE_TURTLE,
                "threshold": threshold
            }
            response = client.post("/predict", json=request_data)
            assert response.status_code == 200

    def test_predict_invalid_threshold(self, client):
        """Test that invalid threshold values are rejected."""
        for threshold in [-0.1, 1.5]:
            request_data = {
                "features": SAMPLE_TURTLE,
                "threshold": threshold
            }
            response = client.post("/predict", json=request_data)
            assert response.status_code == 422  # Validation error

    def test_predict_model_not_loaded(self, client):
        """Test that prediction returns 503 when model not loaded."""
        # Temporarily set predictor to None
        original_predictor = xgb_fastapi.predictor
        xgb_fastapi.predictor = None

        try:
            request_data = {
                "features": SAMPLE_TURTLE,
                "threshold": 0.5
            }
            response = client.post("/predict", json=request_data)
            assert response.status_code == 503
            assert "Model not loaded" in response.json().get("detail", "")
        finally:
            xgb_fastapi.predictor = original_predictor

# ============================================================================
# Top-K Prediction Tests
# ============================================================================

class TestPredictTopKEndpoint:
    """Tests for the /predict_top_k endpoint."""

    def test_predict_top_k_default(self, client, model_loaded):
        """Test top-K prediction with default k=5."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        request_data = {
            "features": SAMPLE_TURTLE
        }

        response = client.post("/predict_top_k", json=request_data)
        assert response.status_code == 200

        data = response.json()
        assert "top_predictions" in data
        assert "top_probabilities" in data
        assert "all_probabilities" in data

        assert len(data["top_predictions"]) <= 5
        assert len(data["top_probabilities"]) <= 5

    def test_predict_top_k_custom_k(self, client, model_loaded):
        """Test top-K prediction with custom k values."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        for k in [1, 3, 10]:
            request_data = {
                "features": SAMPLE_TURTLE,
                "k": k
            }

            response = client.post("/predict_top_k", json=request_data)
            assert response.status_code == 200

            data = response.json()
            assert len(data["top_predictions"]) <= k

    def test_predict_top_k_ordering(self, client, model_loaded):
        """Test that top-K predictions are ordered by probability."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        request_data = {
            "features": SAMPLE_CHROMA,
            "k": 10
        }

        response = client.post("/predict_top_k", json=request_data)
        assert response.status_code == 200

        data = response.json()
        probabilities = [data["top_probabilities"][cls] for cls in data["top_predictions"]]

        # Check that probabilities are in descending order
        for i in range(len(probabilities) - 1):
            assert probabilities[i] >= probabilities[i + 1]

    def test_predict_top_k_with_json_data(self, client, model_loaded):
        """Test top-K prediction with JSON roofline data."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        request_data = {
            "features": SAMPLE_JSON_ROOFLINE,
            "is_json": True,
            "job_id": "test_job_456",
            "k": 5
        }

        response = client.post("/predict_top_k", json=request_data)
        assert response.status_code == 200

        data = response.json()
        assert len(data["top_predictions"]) <= 5

    def test_predict_top_k_invalid_k(self, client):
        """Test that invalid k values are rejected."""
        for k in [0, -1, 101]:
            request_data = {
                "features": SAMPLE_TURTLE,
                "k": k
            }
            response = client.post("/predict_top_k", json=request_data)
            assert response.status_code == 422  # Validation error

# ============================================================================
# Batch Prediction Tests
# ============================================================================

class TestBatchPredictEndpoint:
    """Tests for the /batch_predict endpoint."""

    def test_batch_predict_multiple_samples(self, client, model_loaded):
        """Test batch prediction with multiple samples."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        request_data = {
            "features_list": [SAMPLE_TURTLE, SAMPLE_SCALEXA, SAMPLE_CHROMA],
            "threshold": 0.3
        }

        response = client.post("/batch_predict", json=request_data)
        assert response.status_code == 200

        data = response.json()
        assert "results" in data
        assert "total" in data
        assert "successful" in data
        assert "failed" in data

        assert data["total"] == 3
        assert len(data["results"]) == 3
        assert data["successful"] + data["failed"] == data["total"]

    def test_batch_predict_single_sample(self, client, model_loaded):
        """Test batch prediction with a single sample."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        request_data = {
            "features_list": [SAMPLE_TURTLE],
            "threshold": 0.5
        }

        response = client.post("/batch_predict", json=request_data)
        assert response.status_code == 200

        data = response.json()
        assert data["total"] == 1
        assert len(data["results"]) == 1

    def test_batch_predict_with_json_data(self, client, model_loaded):
        """Test batch prediction with JSON roofline data."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        request_data = {
            "features_list": [SAMPLE_JSON_ROOFLINE, SAMPLE_JSON_ROOFLINE],
            "is_json": True,
            "job_ids": ["job_001", "job_002"],
            "threshold": 0.3
        }

        response = client.post("/batch_predict", json=request_data)
        assert response.status_code == 200

        data = response.json()
        assert data["total"] == 2

    def test_batch_predict_empty_list(self, client, model_loaded):
        """Test batch prediction with empty list."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        request_data = {
            "features_list": [],
            "threshold": 0.5
        }

        response = client.post("/batch_predict", json=request_data)
        assert response.status_code == 200

        data = response.json()
        assert data["total"] == 0
        assert data["successful"] == 0
        assert data["failed"] == 0

# ============================================================================
# Model Info Tests
# ============================================================================

class TestModelInfoEndpoint:
    """Tests for the /model/info endpoint."""

    def test_model_info(self, client, model_loaded):
        """Test getting model information."""
        response = client.get("/model/info")

        # May return 503 if model not loaded, or 200 if loaded
        if model_loaded:
            assert response.status_code == 200
            data = response.json()
            assert "classes" in data
            assert "n_classes" in data
            assert "features" in data
            assert "n_features" in data

            assert isinstance(data["classes"], list)
            assert len(data["classes"]) == data["n_classes"]
            assert len(data["features"]) == data["n_features"]
        else:
            assert response.status_code == 503

    def test_model_info_not_loaded(self, client):
        """Test that model info returns 503 when model not loaded."""
        original_predictor = xgb_fastapi.predictor
        xgb_fastapi.predictor = None

        try:
            response = client.get("/model/info")
            assert response.status_code == 503
        finally:
            xgb_fastapi.predictor = original_predictor

# ============================================================================
# Error Handling Tests
# ============================================================================

class TestErrorHandling:
    """Tests for error handling."""

    def test_predict_missing_features(self, client):
        """Test prediction without features field."""
        request_data = {
            "threshold": 0.5
        }
        response = client.post("/predict", json=request_data)
        assert response.status_code == 422

    def test_predict_invalid_json_format(self, client, model_loaded):
        """Test prediction with invalid JSON in is_json mode."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        request_data = {
            "features": "not valid json {{",
            "is_json": True
        }
        response = client.post("/predict", json=request_data)
        # Should return error (400)
        assert response.status_code == 400

    def test_invalid_endpoint(self, client):
        """Test accessing an invalid endpoint."""
        response = client.get("/nonexistent")
        assert response.status_code == 404

# ============================================================================
# Integration Tests (Full Pipeline)
# ============================================================================

class TestIntegration:
    """Integration tests for full prediction pipeline."""

    def test_full_prediction_pipeline_features(self, client, model_loaded):
        """Test complete prediction pipeline with feature dict."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        # 1. Check health
        health_response = client.get("/health")
        assert health_response.status_code == 200

        health_data = health_response.json()
        assert health_data["model_loaded"] == True

        # 2. Get model info
        info_response = client.get("/model/info")
        assert info_response.status_code == 200

        # 3. Make single prediction
        predict_response = client.post("/predict", json={
            "features": SAMPLE_TURTLE,
            "threshold": 0.3
        })
        assert predict_response.status_code == 200

        # 4. Make top-K prediction
        topk_response = client.post("/predict_top_k", json={
            "features": SAMPLE_TURTLE,
            "k": 5
        })
        assert topk_response.status_code == 200

        # 5. Make batch prediction
        batch_response = client.post("/batch_predict", json={
            "features_list": [SAMPLE_TURTLE, SAMPLE_CHROMA],
            "threshold": 0.3
        })
        assert batch_response.status_code == 200

    def test_consistency_single_vs_batch(self, client, model_loaded):
        """Test that single prediction and batch prediction give consistent results."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        threshold = 0.3

        # Single prediction
        single_response = client.post("/predict", json={
            "features": SAMPLE_TURTLE,
            "threshold": threshold
        })

        # Batch prediction with same sample
        batch_response = client.post("/batch_predict", json={
            "features_list": [SAMPLE_TURTLE],
            "threshold": threshold
        })

        assert single_response.status_code == 200
        assert batch_response.status_code == 200

        single_data = single_response.json()
        batch_data = batch_response.json()

        if batch_data["successful"] == 1:
            batch_result = batch_data["results"][0]
            # Predictions should be the same
            assert set(single_data["predictions"]) == set(batch_result["predictions"])

# ============================================================================
# CORS Tests
# ============================================================================

class TestCORS:
    """Tests for CORS configuration."""

    def test_cors_headers(self, client):
        """Test that CORS headers are present in responses."""
        response = client.options("/predict")
        # Accept either 200 or 405 (method not allowed for OPTIONS in some configs)
        assert response.status_code in [200, 405]

# ============================================================================
# Performance Tests (Basic)
# ============================================================================

class TestPerformance:
    """Basic performance tests."""

    def test_response_time_single_prediction(self, client, model_loaded):
        """Test that single prediction completes in reasonable time."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        import time

        start = time.time()
        response = client.post("/predict", json={
            "features": SAMPLE_TURTLE,
            "threshold": 0.5
        })
        elapsed = time.time() - start

        # Should complete within 5 seconds (generous for CI environments)
        assert elapsed < 5.0, f"Prediction took {elapsed:.2f}s, expected < 5s"
        assert response.status_code == 200

    def test_response_time_batch_prediction(self, client, model_loaded):
        """Test that batch prediction scales reasonably."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        import time

        # Create a batch of 10 samples
        features_list = [SAMPLE_TURTLE] * 10

        start = time.time()
        response = client.post("/batch_predict", json={
            "features_list": features_list,
            "threshold": 0.5
        })
        elapsed = time.time() - start

        # Should complete within 10 seconds
        assert elapsed < 10.0, f"Batch prediction took {elapsed:.2f}s, expected < 10s"
        assert response.status_code == 200

# ============================================================================
# Run tests if executed directly
# ============================================================================

if __name__ == "__main__":
    pytest.main([__file__, "-v", "--tb=short"])
429
xgb_fastapi.py
Normal file
@@ -0,0 +1,429 @@
#!/usr/bin/env python3
"""
FastAPI REST API for XGBoost Multi-Label Classification Inference

Provides HTTP endpoints for:
- POST /predict: Single prediction with confidence scores
- POST /predict_top_k: Top-K predictions
- POST /batch_predict: Batch predictions

Supports both feature vectors and JSON roofline data as input.
"""

from fastapi import FastAPI, HTTPException, Body
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field, validator
from typing import Dict, List, Union, Optional, Any
import logging
import uvicorn
from contextlib import asynccontextmanager

from xgb_local import XGBoostMultiLabelPredictor

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Global predictor instance
predictor: Optional[XGBoostMultiLabelPredictor] = None

# Pydantic models for request/response validation
class PredictRequest(BaseModel):
    """Request model for single prediction."""
    features: Union[Dict[str, float], str] = Field(
        ...,
        description="Feature dictionary or JSON string of roofline data"
    )
    threshold: float = Field(
        default=0.5,
        ge=0.0,
        le=1.0,
        description="Probability threshold for classification"
    )
    return_all_probabilities: bool = Field(
        default=True,
        description="Whether to return probabilities for all classes"
    )
    is_json: bool = Field(
        default=False,
        description="Whether features is a JSON string of roofline data"
    )
    job_id: Optional[str] = Field(
        default=None,
        description="Optional job ID for JSON aggregation"
    )

class PredictTopKRequest(BaseModel):
    """Request model for top-K prediction."""
    features: Union[Dict[str, float], str] = Field(
        ...,
        description="Feature dictionary or JSON string of roofline data"
    )
    k: int = Field(
        default=5,
        ge=1,
        le=100,
        description="Number of top predictions to return"
    )
    is_json: bool = Field(
        default=False,
        description="Whether features is a JSON string of roofline data"
    )
    job_id: Optional[str] = Field(
        default=None,
        description="Optional job ID for JSON aggregation"
    )

class BatchPredictRequest(BaseModel):
    """Request model for batch prediction."""
    features_list: List[Union[Dict[str, float], str]] = Field(
        ...,
        description="List of feature dictionaries or JSON strings"
    )
    threshold: float = Field(
        default=0.5,
        ge=0.0,
        le=1.0,
        description="Probability threshold for classification"
    )
    is_json: bool = Field(
        default=False,
        description="Whether features are JSON strings of roofline data"
    )
    job_ids: Optional[List[str]] = Field(
        default=None,
        description="Optional list of job IDs for JSON aggregation"
    )

class PredictResponse(BaseModel):
    """Response model for single prediction."""
    predictions: List[str] = Field(description="List of predicted class names")
    probabilities: Dict[str, float] = Field(description="Probabilities for each class")
    confidences: Dict[str, float] = Field(description="Confidence scores for predicted classes")
    threshold: float = Field(description="Threshold used for prediction")


class PredictTopKResponse(BaseModel):
    """Response model for top-K prediction."""
    top_predictions: List[str] = Field(description="Top-K predicted class names")
    top_probabilities: Dict[str, float] = Field(description="Probabilities for top-K classes")
    all_probabilities: Dict[str, float] = Field(description="Probabilities for all classes")


class BatchPredictResponse(BaseModel):
    """Response model for batch prediction."""
    results: List[Union[PredictResponse, Dict[str, str]]] = Field(
        description="List of prediction results or errors"
    )
    total: int = Field(description="Total number of samples processed")
    successful: int = Field(description="Number of successful predictions")
    failed: int = Field(description="Number of failed predictions")


class HealthResponse(BaseModel):
    """Response model for health check."""
    status: str
    model_loaded: bool
    n_classes: Optional[int] = None
    n_features: Optional[int] = None
    classes: Optional[List[str]] = None


class ErrorResponse(BaseModel):
    """Response model for errors."""
    error: str
    detail: Optional[str] = None

# Lifespan context manager for startup/shutdown
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage application lifespan events."""
    # Startup
    global predictor
    try:
        logger.info("Loading XGBoost model...")
        predictor = XGBoostMultiLabelPredictor('xgb_model.joblib')
        logger.info("Model loaded successfully!")
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        predictor = None

    yield

    # Shutdown
    logger.info("Shutting down...")

# Initialize FastAPI app
app = FastAPI(
    title="XGBoost Multi-Label Classification API",
    description="REST API for multi-label classification inference using XGBoost",
    version="1.0.0",
    lifespan=lifespan
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, specify allowed origins
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# API Endpoints

@app.get("/", tags=["General"])
|
||||||
|
async def root():
|
||||||
|
"""Root endpoint with API information."""
|
||||||
|
return {
|
||||||
|
"name": "XGBoost Multi-Label Classification API",
|
||||||
|
"version": "1.0.0",
|
||||||
|
"endpoints": {
|
||||||
|
"health": "/health",
|
||||||
|
"predict": "/predict",
|
||||||
|
"predict_top_k": "/predict_top_k",
|
||||||
|
"batch_predict": "/batch_predict"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health", response_model=HealthResponse, tags=["General"])
|
||||||
|
async def health_check():
|
||||||
|
"""
|
||||||
|
Check API health and model status.
|
||||||
|
|
||||||
|
Returns model information if loaded successfully.
|
||||||
|
"""
|
||||||
|
if predictor is None:
|
||||||
|
return HealthResponse(
|
||||||
|
status="error",
|
||||||
|
model_loaded=False
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
info = predictor.get_class_info()
|
||||||
|
return HealthResponse(
|
||||||
|
status="healthy",
|
||||||
|
model_loaded=True,
|
||||||
|
n_classes=info['n_classes'],
|
||||||
|
n_features=info['n_features'],
|
||||||
|
classes=info['classes']
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Health check error: {e}")
|
||||||
|
return HealthResponse(
|
||||||
|
status="degraded",
|
||||||
|
model_loaded=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/predict", response_model=PredictResponse, tags=["Inference"])
|
||||||
|
async def predict(request: PredictRequest):
|
||||||
|
"""
|
||||||
|
Perform single prediction on input features.
|
||||||
|
|
||||||
|
**Input formats:**
|
||||||
|
- Feature dictionary: `{"feature1": value1, "feature2": value2, ...}`
|
||||||
|
- JSON roofline data: Set `is_json=true` and provide JSON string
|
||||||
|
|
||||||
|
**Example (features):**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"features": {
|
||||||
|
"bandwidth_raw_p10": 150.5,
|
||||||
|
"flops_raw_median": 2500.0,
|
||||||
|
...
|
||||||
|
},
|
||||||
|
"threshold": 0.5
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Example (JSON roofline):**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"features": "[{\"node_num\": 1, \"bandwidth_raw\": 150.5, ...}]",
|
||||||
|
"is_json": true,
|
||||||
|
"job_id": "test_job_123",
|
||||||
|
"threshold": 0.3
|
||||||
|
}
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
if predictor is None:
|
||||||
|
raise HTTPException(status_code=503, detail="Model not loaded")
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = predictor.predict(
|
||||||
|
features=request.features,
|
||||||
|
threshold=request.threshold,
|
||||||
|
return_all_probabilities=request.return_all_probabilities,
|
||||||
|
is_json=request.is_json,
|
||||||
|
job_id=request.job_id
|
||||||
|
)
|
||||||
|
return PredictResponse(**result)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Prediction error: {e}")
|
||||||
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
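
# Illustrative client usage for /predict (a sketch, not part of the original file).
# Assumes the server is running locally on the default port 8000 and that the
# `requests` package is installed on the client side:
#
#   import requests
#   payload = {"features": {"bandwidth_raw_p10": 150.5, "flops_raw_median": 2500.0},
#              "threshold": 0.3}
#   resp = requests.post("http://localhost:8000/predict", json=payload, timeout=30)
#   resp.raise_for_status()
#   print(resp.json()["predictions"], resp.json()["confidences"])
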
@app.post("/predict_top_k", response_model=PredictTopKResponse, tags=["Inference"])
|
||||||
|
async def predict_top_k(request: PredictTopKRequest):
|
||||||
|
"""
|
||||||
|
Get top-K predictions with their probabilities.
|
||||||
|
|
||||||
|
**Example (features):**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"features": {
|
||||||
|
"bandwidth_raw_p10": 150.5,
|
||||||
|
"flops_raw_median": 2500.0,
|
||||||
|
...
|
||||||
|
},
|
||||||
|
"k": 5
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Example (JSON roofline):**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"features": "[{\"node_num\": 1, \"bandwidth_raw\": 150.5, ...}]",
|
||||||
|
"is_json": true,
|
||||||
|
"job_id": "test_job_123",
|
||||||
|
"k": 10
|
||||||
|
}
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
if predictor is None:
|
||||||
|
raise HTTPException(status_code=503, detail="Model not loaded")
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = predictor.predict_top_k(
|
||||||
|
features=request.features,
|
||||||
|
k=request.k,
|
||||||
|
is_json=request.is_json,
|
||||||
|
job_id=request.job_id
|
||||||
|
)
|
||||||
|
return PredictTopKResponse(**result)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Top-K prediction error: {e}")
|
||||||
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/batch_predict", response_model=BatchPredictResponse, tags=["Inference"])
|
||||||
|
async def batch_predict(request: BatchPredictRequest):
|
||||||
|
"""
|
||||||
|
Perform batch prediction on multiple samples.
|
||||||
|
|
||||||
|
**Example (features):**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"features_list": [
|
||||||
|
{"bandwidth_raw_p10": 150.5, ...},
|
||||||
|
{"bandwidth_raw_p10": 160.2, ...}
|
||||||
|
],
|
||||||
|
"threshold": 0.5
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Example (JSON roofline):**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"features_list": [
|
||||||
|
"[{\"node_num\": 1, \"bandwidth_raw\": 150.5, ...}]",
|
||||||
|
"[{\"node_num\": 2, \"bandwidth_raw\": 160.2, ...}]"
|
||||||
|
],
|
||||||
|
"is_json": true,
|
||||||
|
"job_ids": ["job1", "job2"],
|
||||||
|
"threshold": 0.3
|
||||||
|
}
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
if predictor is None:
|
||||||
|
raise HTTPException(status_code=503, detail="Model not loaded")
|
||||||
|
|
||||||
|
try:
|
||||||
|
results = predictor.batch_predict(
|
||||||
|
features_list=request.features_list,
|
||||||
|
threshold=request.threshold,
|
||||||
|
is_json=request.is_json,
|
||||||
|
job_ids=request.job_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
# Count successful and failed predictions
|
||||||
|
successful = sum(1 for r in results if 'error' not in r)
|
||||||
|
failed = len(results) - successful
|
||||||
|
|
||||||
|
return BatchPredictResponse(
|
||||||
|
results=results,
|
||||||
|
total=len(results),
|
||||||
|
successful=successful,
|
||||||
|
failed=failed
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Batch prediction error: {e}")
|
||||||
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/model/info", tags=["Model"])
|
||||||
|
async def model_info():
|
||||||
|
"""
|
||||||
|
Get detailed model information.
|
||||||
|
|
||||||
|
Returns information about classes, features, and model configuration.
|
||||||
|
"""
|
||||||
|
if predictor is None:
|
||||||
|
raise HTTPException(status_code=503, detail="Model not loaded")
|
||||||
|
|
||||||
|
try:
|
||||||
|
info = predictor.get_class_info()
|
||||||
|
return {
|
||||||
|
"classes": info['classes'],
|
||||||
|
"n_classes": info['n_classes'],
|
||||||
|
"features": info['feature_columns'],
|
||||||
|
"n_features": info['n_features']
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Model info error: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Run the FastAPI server."""
    import argparse

    parser = argparse.ArgumentParser(description="XGBoost Multi-Label Classification REST API")
    parser.add_argument('--host', type=str, default='0.0.0.0',
                        help='Host to bind to (default: 0.0.0.0)')
    parser.add_argument('--port', type=int, default=8000,
                        help='Port to bind to (default: 8000)')
    parser.add_argument('--reload', action='store_true',
                        help='Enable auto-reload for development')
    parser.add_argument('--workers', type=int, default=1,
                        help='Number of worker processes (default: 1)')

    args = parser.parse_args()

    logger.info(f"Starting FastAPI server on {args.host}:{args.port}")
    logger.info(f"Workers: {args.workers}, Reload: {args.reload}")

    uvicorn.run(
        "xgb_fastapi:app",
        host=args.host,
        port=args.port,
        reload=args.reload,
        workers=args.workers if not args.reload else 1,
        log_level="info"
    )

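# Example invocations (based on the CLI flags defined above):
#   python xgb_fastapi.py --port 8000            # single worker
#   python xgb_fastapi.py --port 8000 --reload   # development with auto-reload
# Running via uvicorn directly should be equivalent:
#   uvicorn xgb_fastapi:app --host 0.0.0.0 --port 8000
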
if __name__ == "__main__":
    main()
530
xgb_local.py
Normal file
@@ -0,0 +1,530 @@
import pandas as pd
import numpy as np
import joblib
import json
from typing import Dict, List, Tuple, Union, Optional
import warnings
warnings.filterwarnings('ignore')
from scipy import stats

def compute_mad(data: np.ndarray) -> float:
    """Compute Median Absolute Deviation."""
    median = np.median(data)
    mad = np.median(np.abs(data - median))
    return mad

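# Worked example (illustrative): for data [1.0, 2.0, 3.0, 10.0] the median is 2.5,
# the absolute deviations are [1.5, 0.5, 0.5, 7.5], and their median (the MAD) is 1.0.
# Unlike the variance, the single outlier (10.0) barely affects the result.
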
def df_aggregate(json_str: str, job_id_full: Optional[str] = None) -> Dict:
    """
    Aggregate roofline data from JSON string into a single feature vector.

    Args:
        json_str: JSON string containing roofline data records
        job_id_full: Optional job ID to include in the result

    Returns:
        Dictionary containing aggregated features
    """
    # Parse JSON string to DataFrame
    try:
        data = json.loads(json_str)
        if isinstance(data, list):
            df = pd.DataFrame(data)
        elif isinstance(data, dict):
            df = pd.DataFrame([data])
        else:
            raise ValueError("JSON must contain a list of objects or a single object")
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON string: {e}")

    # Group data by node_num
    if 'node_num' not in df.columns:
        # If no node_num, treat all data as single node
        df['node_num'] = 1

    grouped = df.groupby('node_num')

    all_features = []
    for node_num, group in grouped:
        features = {
            'node_num': int(node_num)
        }

        if job_id_full is not None:
            features['job_id'] = job_id_full

        # Compute statistics for key metrics
        for axis in ['bandwidth_raw', 'flops_raw', 'arith_intensity']:
            data = group[axis].values

            # Compute percentiles
            p10 = np.percentile(data, 10)
            p50 = np.median(data)
            p90 = np.percentile(data, 90)

            # Compute MAD (more robust than variance)
            mad = compute_mad(data)

            # Store features
            features[f'{axis}_p10'] = p10
            features[f'{axis}_median'] = p50
            features[f'{axis}_p90'] = p90
            features[f'{axis}_mad'] = mad
            features[f'{axis}_range'] = p90 - p10
            features[f'{axis}_iqr'] = np.percentile(data, 75) - np.percentile(data, 25)

        # Compute covariance and correlation between bandwidth_raw and flops_raw
        if len(group) > 1:  # Need at least 2 points for correlation
            cov = np.cov(group['bandwidth_raw'], group['flops_raw'])[0, 1]
            features['bw_flops_covariance'] = cov

            corr, _ = stats.pearsonr(group['bandwidth_raw'], group['flops_raw'])
            features['bw_flops_correlation'] = corr

        # Additional useful features for the classifier

        # Performance metrics
        features['avg_performance_gflops'] = group['performance_gflops'].mean()
        features['median_performance_gflops'] = group['performance_gflops'].median()
        features['performance_gflops_mad'] = compute_mad(group['performance_gflops'].values)

        # # Efficiency metrics
        # features['avg_efficiency'] = group['efficiency'].mean()
        # features['median_efficiency'] = group['efficiency'].median()
        # features['efficiency_mad'] = compute_mad(group['efficiency'].values)
        # features['efficiency_p10'] = np.percentile(group['efficiency'].values, 10)
        # features['efficiency_p90'] = np.percentile(group['efficiency'].values, 90)

        # # Distribution of roofline regions (memory-bound vs compute-bound)
        # if 'roofline_region' in group.columns:
        #     region_counts = group['roofline_region'].value_counts(normalize=True).to_dict()
        #     for region, ratio in region_counts.items():
        #         features[f'region_{region}_ratio'] = ratio

        # System characteristics
        if 'memory_bw_gbs' in group.columns:
            features['avg_memory_bw_gbs'] = group['memory_bw_gbs'].mean()
        if 'scalar_peak_gflops' in group.columns and len(group['scalar_peak_gflops'].unique()) > 0:
            features['scalar_peak_gflops'] = group['scalar_peak_gflops'].iloc[0]
        if 'simd_peak_gflops' in group.columns and len(group['simd_peak_gflops'].unique()) > 0:
            features['simd_peak_gflops'] = group['simd_peak_gflops'].iloc[0]

        # # Subcluster information if available
        # if 'subcluster_name' in group.columns and not group['subcluster_name'].isna().all():
        #     features['subcluster_name'] = group['subcluster_name'].iloc[0]

        # Duration information
        if 'duration' in group.columns:
            features['duration'] = group['duration'].iloc[0]

        all_features.append(features)

    # Return first node's features (or combine multiple nodes if needed)
    if len(all_features) == 1:
        return all_features[0]
    else:
        # If multiple nodes, return the first one or average across nodes
        # For now, return the first node's features
        return all_features[0]

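# For each axis ('bandwidth_raw', 'flops_raw', 'arith_intensity') df_aggregate emits
# <axis>_p10, <axis>_median, <axis>_p90, <axis>_mad, <axis>_range and <axis>_iqr per
# node group, plus the performance, system and duration fields handled above.
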
class XGBoostMultiLabelPredictor:
    """
    Python API for XGBoost multi-label classification inference.

    Provides methods to load trained models and perform inference with
    confidence scores for each class.
    """

    def __init__(self, model_path: str = 'xgb_model.joblib'):
        """
        Initialize the predictor by loading the trained model.

        Args:
            model_path: Path to the saved model file (.joblib)
        """
        self.model_data = None
        self.model = None
        self.mlb = None
        self.feature_columns = None
        self.n_features = 0
        self.classes = []

        self.load_model(model_path)

    def load_model(self, model_path: str) -> None:
        """
        Load the trained XGBoost model from disk.

        Args:
            model_path: Path to the saved model file
        """
        try:
            print(f"Loading model from {model_path}...")
            self.model_data = joblib.load(model_path)
            self.model = self.model_data['model']
            self.mlb = self.model_data['mlb']
            self.feature_columns = self.model_data['feature_columns']

            self.classes = list(self.mlb.classes_)
            self.n_features = len(self.feature_columns)

            print("Model loaded successfully!")
            print(f" - {len(self.classes)} classes: {self.classes}")
            print(f" - {self.n_features} features: {self.feature_columns[:5]}...")
            print(f" - Model type: {type(self.model).__name__}")

        except Exception as e:
            raise ValueError(f"Failed to load model from {model_path}: {e}")

    def predict(self, features: Union[pd.DataFrame, np.ndarray, List, Dict, str],
                threshold: float = 0.5,
                return_all_probabilities: bool = True,
                is_json: bool = False,
                job_id: Optional[str] = None) -> Dict:
        """
        Perform multi-label prediction on input features.

        Args:
            features: Input features in various formats:
                - pandas DataFrame
                - numpy array (2D)
                - list of lists/dicts
                - single feature vector (list/dict)
                - JSON string (if is_json=True): roofline data to aggregate
            threshold: Probability threshold for binary classification (0.0-1.0)
            return_all_probabilities: If True, return probabilities for all classes.
                If False, return only classes above threshold.
            is_json: If True, treat features as JSON string of roofline data
            job_id: Optional job ID (used when is_json=True)

        Returns:
            Dictionary containing:
            - 'predictions': List of predicted class names
            - 'probabilities': Dict of {class_name: probability} for all classes
            - 'confidences': Dict of {class_name: confidence_score} for predicted classes
            - 'threshold': The threshold used
        """
        # If input is JSON string, aggregate features first
        if is_json:
            if not isinstance(features, str):
                raise ValueError("When is_json=True, features must be a JSON string")
            features = df_aggregate(features, job_id_full=job_id)

        # Convert input to proper format
        X = self._prepare_features(features)

        # Get probability predictions
        probabilities = self.model.predict_proba(X)

        # Convert to class probabilities
        class_probabilities = {}
        for i, class_name in enumerate(self.classes):
            # For OneVsRest, predict_proba returns shape (n_samples, n_classes)
            # Each column i contains probabilities for class i
            if isinstance(probabilities, list):
                # List of arrays (multiple samples)
                prob_array = probabilities[i]
                prob_positive = prob_array[0] if hasattr(prob_array, '__getitem__') else float(prob_array)
            else:
                # 2D numpy array (single sample or batch)
                if len(probabilities.shape) == 2:
                    # Shape: (n_samples, n_classes)
                    prob_positive = float(probabilities[0, i])
                else:
                    # 1D array
                    prob_positive = float(probabilities[i])
            class_probabilities[class_name] = prob_positive

        # Apply threshold for predictions
        predictions = []
        confidences = {}

        for class_name, prob in class_probabilities.items():
            if prob >= threshold:
                predictions.append(class_name)
                # Confidence score: distance from threshold as percentage
                confidence = min(1.0, (prob - threshold) / (1.0 - threshold)) * 100
                confidences[class_name] = round(confidence, 2)
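                # Worked example: with threshold=0.5 and prob=0.8 the score is
                # min(1.0, (0.8 - 0.5) / (1.0 - 0.5)) * 100 = 60.0, i.e. 60% of the
                # way from the threshold to certainty.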

        # Sort predictions by probability
        predictions.sort(key=lambda x: class_probabilities[x], reverse=True)

        result = {
            'predictions': predictions,
            'probabilities': {k: round(v, 4) for k, v in class_probabilities.items()},
            'confidences': confidences,
            'threshold': threshold
        }

        if not return_all_probabilities:
            result['probabilities'] = {k: v for k, v in result['probabilities'].items()
                                       if k in predictions}

        return result

    def predict_top_k(self, features: Union[pd.DataFrame, np.ndarray, List, Dict, str],
                      k: int = 5,
                      is_json: bool = False,
                      job_id: Optional[str] = None) -> Dict:
        """
        Get top-k predictions with their probabilities.

        Args:
            features: Input features (various formats) or JSON string if is_json=True
            k: Number of top predictions to return
            is_json: If True, treat features as JSON string of roofline data
            job_id: Optional job ID (used when is_json=True)

        Returns:
            Dictionary with top-k predictions and their details
        """
        # If input is JSON string, aggregate features first
        if is_json:
            if not isinstance(features, str):
                raise ValueError("When is_json=True, features must be a JSON string")
            features = df_aggregate(features, job_id_full=job_id)

        # Get all probabilities
        X = self._prepare_features(features)
        probabilities = self.model.predict_proba(X)

        class_probabilities = {}
        for i, class_name in enumerate(self.classes):
            # For OneVsRest, predict_proba returns shape (n_samples, n_classes)
            # Each column i contains probabilities for class i
            if isinstance(probabilities, list):
                # List of arrays (multiple samples)
                prob_array = probabilities[i]
                prob_positive = prob_array[0] if hasattr(prob_array, '__getitem__') else float(prob_array)
            else:
                # 2D numpy array (single sample or batch)
                if len(probabilities.shape) == 2:
                    # Shape: (n_samples, n_classes)
                    prob_positive = float(probabilities[0, i])
                else:
                    # 1D array
                    prob_positive = float(probabilities[i])
            class_probabilities[class_name] = prob_positive

        # Sort by probability
        sorted_classes = sorted(class_probabilities.items(),
                                key=lambda x: x[1], reverse=True)

        top_k_classes = sorted_classes[:k]

        return {
            'top_predictions': [cls for cls, _ in top_k_classes],
            'top_probabilities': {cls: round(prob, 4) for cls, prob in top_k_classes},
            'all_probabilities': {k: round(v, 4) for k, v in class_probabilities.items()}
        }

    def _prepare_features(self, features: Union[pd.DataFrame, np.ndarray, List, Dict]) -> pd.DataFrame:
        """
        Convert various input formats to the expected feature format.

        Args:
            features: Input features in various formats

        Returns:
            pandas DataFrame with correct columns and order
        """
        if isinstance(features, pd.DataFrame):
            df = features.copy()
        elif isinstance(features, np.ndarray):
            if features.ndim == 1:
                features = features.reshape(1, -1)
            df = pd.DataFrame(features, columns=self.feature_columns[:features.shape[1]])
        elif isinstance(features, list):
            if isinstance(features[0], dict):
                # List of dictionaries
                df = pd.DataFrame(features)
            else:
                # List of lists
                df = pd.DataFrame(features, columns=self.feature_columns[:len(features[0])])
        elif isinstance(features, dict):
            # Single feature dictionary
            df = pd.DataFrame([features])
        else:
            raise ValueError(f"Unsupported feature format: {type(features)}")

        # Ensure correct column order and fill missing columns with 0
        for col in self.feature_columns:
            if col not in df.columns:
                df[col] = 0.0

        df = df[self.feature_columns]

        # Validate feature count
        if df.shape[1] != self.n_features:
            raise ValueError(f"Expected {self.n_features} features, got {df.shape[1]}")

        return df

    def get_class_info(self) -> Dict:
        """
        Get information about available classes.

        Returns:
            Dictionary with class information
        """
        return {
            'classes': self.classes,
            'n_classes': len(self.classes),
            'feature_columns': self.feature_columns,
            'n_features': self.n_features
        }

    def batch_predict(self, features_list: List[Union[pd.DataFrame, np.ndarray, List, Dict, str]],
                      threshold: float = 0.5,
                      is_json: bool = False,
                      job_ids: Optional[List[str]] = None) -> List[Dict]:
        """
        Perform batch prediction on multiple samples.

        Args:
            features_list: List of feature inputs (or JSON strings if is_json=True)
            threshold: Probability threshold
            is_json: If True, treat each item in features_list as JSON string
            job_ids: Optional list of job IDs (used when is_json=True)

        Returns:
            List of prediction results
        """
        results = []
        for idx, features in enumerate(features_list):
            try:
                job_id = job_ids[idx] if job_ids and idx < len(job_ids) else None
                result = self.predict(features, threshold=threshold, is_json=is_json, job_id=job_id)
                results.append(result)
            except Exception as e:
                results.append({'error': str(e)})

        return results

def create_sample_data(n_samples: int = 5) -> List[Dict]:
    """
    Create sample feature data for testing.

    Args:
        n_samples: Number of sample feature vectors to create

    Returns:
        List of feature dictionaries
    """
    np.random.seed(42)

    # Load feature columns from model if available
    try:
        model_data = joblib.load('xgb_model.joblib')
        feature_columns = model_data['feature_columns']
    except Exception:
        # Fallback to some default features
        feature_columns = [
            'node_num', 'bandwidth_raw_p10', 'bandwidth_raw_median',
            'bandwidth_raw_p90', 'bandwidth_raw_mad', 'bandwidth_raw_range',
            'bandwidth_raw_iqr', 'flops_raw_p10', 'flops_raw_median',
            'flops_raw_p90', 'flops_raw_mad', 'flops_raw_range'
        ]

    samples = []
    for _ in range(n_samples):
        sample = {}
        for col in feature_columns:
            if 'bandwidth' in col:
                sample[col] = np.random.uniform(50, 500)
            elif 'flops' in col:
                sample[col] = np.random.uniform(100, 5000)
            elif 'node_num' in col:
                sample[col] = np.random.randint(1, 16)
            else:
                sample[col] = np.random.uniform(0, 1000)
        samples.append(sample)

    return samples

if __name__ == "__main__":
    print("XGBoost Multi-Label Inference API")
    print("=" * 40)

    # Initialize predictor
    try:
        predictor = XGBoostMultiLabelPredictor()
    except Exception as e:
        print(f"Error loading model: {e}")
        exit(1)

    # Example usage of df_aggregate with JSON string
    print("\n=== Example 0: JSON Aggregation ===")
    sample_json = json.dumps([
        {
            "node_num": 1,
            "bandwidth_raw": 150.5,
            "flops_raw": 2500.0,
            "arith_intensity": 16.6,
            "performance_gflops": 1200.0,
            "memory_bw_gbs": 450,
            "scalar_peak_gflops": 600,
            "duration": 3600
        },
        {
            "node_num": 2,
            "bandwidth_raw": 155.2,
            "flops_raw": 2600.0,
            "arith_intensity": 16.8,
            "performance_gflops": 1250.0,
            "memory_bw_gbs": 450,
            "scalar_peak_gflops": 600,
            "duration": 3600
        }
    ])

    try:
        aggregated_features = df_aggregate(sample_json, job_id_full="test_job_123")
        print("Aggregated features from JSON:")
        for key, value in list(aggregated_features.items())[:10]:
            print(f"  {key}: {value}")
        print(f"  ... ({len(aggregated_features)} total features)")

        # Use aggregated features for prediction
        result = predictor.predict(aggregated_features, threshold=0.3)
        print(f"\nPredictions from aggregated data: {result['predictions'][:3]}")
    except Exception as e:
        print(f"Error in aggregation: {e}")

    # Create sample data
    print("\n=== Generating sample data for other examples ===")
    sample_data = create_sample_data(3)

    # Example 1: Single prediction
    print("\n=== Example 1: Single Prediction ===")
    result = predictor.predict(sample_data[0], threshold=0.3)
    print(f"Predictions: {result['predictions']}")
    print(f"Confidences: {result['confidences']}")
    print("Top probabilities:")
    for class_name, prob in sorted(result['probabilities'].items(),
                                   key=lambda x: x[1], reverse=True)[:5]:
        print(f"  {class_name}: {prob:.4f}")

    # Example 2: Top-K predictions
    print("\n=== Example 2: Top-5 Predictions ===")
    top_result = predictor.predict_top_k(sample_data[1], k=5)
    for i, class_name in enumerate(top_result['top_predictions'], 1):
        prob = top_result['top_probabilities'][class_name]
        print(f"{i}. {class_name}: {prob:.4f}")

    # Example 3: Batch prediction
    print("\n=== Example 3: Batch Prediction ===")
    batch_results = predictor.batch_predict(sample_data, threshold=0.4)
    for i, result in enumerate(batch_results, 1):
        if 'error' not in result:
            print(f"Sample {i}: {len(result['predictions'])} predictions")
        else:
            print(f"Sample {i}: Error - {result['error']}")

    print("\nAPI ready for use!")
    print("Usage:")
    print("  predictor = XGBoostMultiLabelPredictor()")
    print("  result = predictor.predict(your_features)")
    print("  top_k = predictor.predict_top_k(your_features, k=5)")
221
xgb_local_example.py
Normal file
@@ -0,0 +1,221 @@
#!/usr/bin/env python3
"""
XGBoost Multi-Label Inference Usage Examples
============================================

This script demonstrates how to use the XGBoostMultiLabelPredictor class
for multi-label classification with confidence scores.

Sample data is from real HPC workloads extracted from roofline_features.h5:
- TurTLE: Turbulence simulation (memory-bound, low arithmetic intensity ~0.84)
- SCALEXA: Scaling benchmarks (high bw-flops correlation ~0.995)
- Chroma: Lattice QCD (compute-intensive, high arithmetic intensity ~2.6)
"""

import json
from xgb_local import XGBoostMultiLabelPredictor

# ============================================================================
# Realistic Sample Data from roofline_features.h5
# ============================================================================

# TurTLE application - turbulence simulation workload
SAMPLE_TURTLE = {
    "bandwidth_raw_p10": 186.33,
    "bandwidth_raw_median": 205.14,
    "bandwidth_raw_p90": 210.83,
    "bandwidth_raw_mad": 3.57,
    "bandwidth_raw_range": 24.5,
    "bandwidth_raw_iqr": 12.075,
    "flops_raw_p10": 162.024,
    "flops_raw_median": 171.45,
    "flops_raw_p90": 176.48,
    "flops_raw_mad": 3.08,
    "flops_raw_range": 14.456,
    "flops_raw_iqr": 8.29,
    "arith_intensity_p10": 0.7906,
    "arith_intensity_median": 0.837,
    "arith_intensity_p90": 0.9109,
    "arith_intensity_mad": 0.02,
    "arith_intensity_range": 0.12,
    "arith_intensity_iqr": 0.0425,
    "bw_flops_covariance": 60.86,
    "bw_flops_correlation": 0.16,
    "avg_performance_gflops": 168.1,
    "median_performance_gflops": 171.45,
    "performance_gflops_mad": 3.08,
    "avg_memory_bw_gbs": 350.0,
    "scalar_peak_gflops": 432.0,
    "simd_peak_gflops": 9216.0,
    "node_num": 0,
    "duration": 19366,
}

# SCALEXA application - scaling benchmark workload
SAMPLE_SCALEXA = {
    "bandwidth_raw_p10": 13.474,
    "bandwidth_raw_median": 32.57,
    "bandwidth_raw_p90": 51.466,
    "bandwidth_raw_mad": 23.62,
    "bandwidth_raw_range": 37.992,
    "bandwidth_raw_iqr": 23.745,
    "flops_raw_p10": 4.24,
    "flops_raw_median": 16.16,
    "flops_raw_p90": 24.584,
    "flops_raw_mad": 10.53,
    "flops_raw_range": 20.344,
    "flops_raw_iqr": 12.715,
    "arith_intensity_p10": 0.211,
    "arith_intensity_median": 0.475,
    "arith_intensity_p90": 0.492,
    "arith_intensity_mad": 0.021,
    "arith_intensity_range": 0.281,
    "arith_intensity_iqr": 0.176,
    "bw_flops_covariance": 302.0,
    "bw_flops_correlation": 0.995,
    "avg_performance_gflops": 14.7,
    "median_performance_gflops": 16.16,
    "performance_gflops_mad": 10.53,
    "avg_memory_bw_gbs": 350.0,
    "scalar_peak_gflops": 432.0,
    "simd_peak_gflops": 9216.0,
    "node_num": 18,
    "duration": 165,
}

# Chroma application - lattice QCD workload (compute-intensive)
SAMPLE_CHROMA = {
    "bandwidth_raw_p10": 154.176,
    "bandwidth_raw_median": 200.57,
    "bandwidth_raw_p90": 259.952,
    "bandwidth_raw_mad": 5.12,
    "bandwidth_raw_range": 105.776,
    "bandwidth_raw_iqr": 10.215,
    "flops_raw_p10": 327.966,
    "flops_raw_median": 519.8,
    "flops_raw_p90": 654.422,
    "flops_raw_mad": 16.97,
    "flops_raw_range": 326.456,
    "flops_raw_iqr": 34.88,
    "arith_intensity_p10": 1.55,
    "arith_intensity_median": 2.595,
    "arith_intensity_p90": 3.445,
    "arith_intensity_mad": 0.254,
    "arith_intensity_range": 1.894,
    "arith_intensity_iqr": 0.512,
    "bw_flops_covariance": 382.76,
    "bw_flops_correlation": 0.063,
    "avg_performance_gflops": 503.26,
    "median_performance_gflops": 519.8,
    "performance_gflops_mad": 16.97,
    "avg_memory_bw_gbs": 350.0,
    "scalar_peak_gflops": 432.0,
    "simd_peak_gflops": 9216.0,
    "node_num": 3,
    "duration": 31133,
}

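# Illustrative helper (not part of the original script): the aggregated samples above can
# be sanity-checked against the classic roofline bound min(peak, AI * BW). Using the
# scalar peak, SAMPLE_TURTLE gives min(432.0, 0.837 * 350.0) = 292.95 GFLOP/s (memory-bound),
# while SAMPLE_CHROMA gives min(432.0, 2.595 * 350.0) = 432.0 GFLOP/s (compute-bound),
# matching the workload descriptions in the module docstring.
def roofline_bound_gflops(sample: dict) -> float:
    """Attainable GFLOP/s under the roofline model, using the scalar peak."""
    return min(sample["scalar_peak_gflops"],
               sample["arith_intensity_median"] * sample["avg_memory_bw_gbs"])
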
# Raw JSON roofline data (before aggregation, as produced by monitoring)
SAMPLE_JSON_ROOFLINE = json.dumps([
    {"node_num": 1, "bandwidth_raw": 150.5, "flops_raw": 2500.0, "arith_intensity": 16.6,
     "performance_gflops": 1200.0, "memory_bw_gbs": 450, "scalar_peak_gflops": 600, "duration": 3600},
    {"node_num": 1, "bandwidth_raw": 155.2, "flops_raw": 2600.0, "arith_intensity": 16.8,
     "performance_gflops": 1250.0, "memory_bw_gbs": 450, "scalar_peak_gflops": 600, "duration": 3600},
    {"node_num": 1, "bandwidth_raw": 148.0, "flops_raw": 2450.0, "arith_intensity": 16.5,
     "performance_gflops": 1180.0, "memory_bw_gbs": 450, "scalar_peak_gflops": 600, "duration": 3600},
])

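# Rough sketch (an assumption, not the actual df_aggregate implementation): the aggregated
# feature names used above suggest summary statistics computed over the raw per-sample
# records, along these lines:
#   import numpy as np
#   raw = json.loads(SAMPLE_JSON_ROOFLINE)
#   bw = np.array([r["bandwidth_raw"] for r in raw])
#   features = {
#       "bandwidth_raw_p10": np.percentile(bw, 10),
#       "bandwidth_raw_median": np.median(bw),
#       "bandwidth_raw_p90": np.percentile(bw, 90),
#       "bandwidth_raw_mad": np.median(np.abs(bw - np.median(bw))),
#       "bandwidth_raw_range": bw.max() - bw.min(),
#       "bandwidth_raw_iqr": np.percentile(bw, 75) - np.percentile(bw, 25),
#   }
# The real pipeline goes through predictor.predict(..., is_json=True), as in Example 5 below.
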
def main():
    print("XGBoost Multi-Label Inference Examples")
    print("=" * 50)

    # Initialize the predictor
    predictor = XGBoostMultiLabelPredictor()

    # =========================================================================
    # Example 1: Single prediction with aggregated features
    # =========================================================================
    print("\n=== Example 1: Single Prediction (TurTLE workload) ===")

    result = predictor.predict(SAMPLE_TURTLE, threshold=0.3)

    print(f"Predictions: {result['predictions']}")
    print(f"Confidences: {result['confidences']}")
    print("\nTop 5 probabilities:")
    sorted_probs = sorted(result['probabilities'].items(), key=lambda x: x[1], reverse=True)
    for cls, prob in sorted_probs[:5]:
        print(f"  {cls}: {prob:.4f}")

    # =========================================================================
    # Example 2: Compare different workload types
    # =========================================================================
    print("\n=== Example 2: Compare Different Workloads ===")

    workloads = [
        ("TurTLE (turbulence)", SAMPLE_TURTLE),
        ("SCALEXA (benchmark)", SAMPLE_SCALEXA),
        ("Chroma (lattice QCD)", SAMPLE_CHROMA),
    ]

    for name, features in workloads:
        result = predictor.predict(features, threshold=0.3)
        top_pred = result['predictions'][0] if result['predictions'] else "None"
        top_prob = max(result['probabilities'].values()) if result['probabilities'] else 0
        print(f"{name:25} -> Top prediction: {top_pred:20} (prob: {top_prob:.4f})")

    # =========================================================================
    # Example 3: Top-K predictions
    # =========================================================================
    print("\n=== Example 3: Top-5 Predictions (Chroma workload) ===")

    top_k_result = predictor.predict_top_k(SAMPLE_CHROMA, k=5)

    for i, cls in enumerate(top_k_result['top_predictions'], 1):
        prob = top_k_result['top_probabilities'][cls]
        print(f"  {i}. {cls}: {prob:.4f}")

    # =========================================================================
    # Example 4: Batch prediction
    # =========================================================================
    print("\n=== Example 4: Batch Prediction ===")

    batch_data = [SAMPLE_TURTLE, SAMPLE_SCALEXA, SAMPLE_CHROMA]
    batch_results = predictor.batch_predict(batch_data, threshold=0.3)

    for i, result in enumerate(batch_results, 1):
        if 'error' not in result:
            preds = result['predictions'][:2]  # Show top 2
            print(f"  Sample {i}: {preds}")
        else:
            print(f"  Sample {i}: Error - {result['error']}")

    # =========================================================================
    # Example 5: Prediction from raw JSON roofline data
    # =========================================================================
    print("\n=== Example 5: Prediction from Raw JSON Data ===")

    result = predictor.predict(
        SAMPLE_JSON_ROOFLINE,
        is_json=True,
        job_id="example_job_001",
        threshold=0.3
    )

    print(f"Predictions: {result['predictions'][:3]}")
    print(f"(Aggregated from {len(json.loads(SAMPLE_JSON_ROOFLINE))} roofline samples)")

    # =========================================================================
    # Example 6: Model information
    # =========================================================================
    print("\n=== Example 6: Model Information ===")

    info = predictor.get_class_info()
    print(f"Number of classes: {info['n_classes']}")
    print(f"Number of features: {info['n_features']}")
    print(f"Sample classes: {info['classes'][:5]}...")
    print(f"Sample features: {info['feature_columns'][:3]}...")


if __name__ == "__main__":
    main()
BIN
xgb_model.joblib
Normal file
Binary file not shown.