Initial commit

This commit is contained in:
Bole Ma
2025-12-10 12:17:41 +01:00
commit 739563f916
12 changed files with 3428 additions and 0 deletions

429
xgb_fastapi.py Normal file
View File

@@ -0,0 +1,429 @@
#!/usr/bin/env python3
"""
FastAPI REST API for XGBoost Multi-Label Classification Inference
Provides HTTP endpoints for:
- POST /predict: Single prediction with confidence scores
- POST /predict_top_k: Top-K predictions
- POST /batch_predict: Batch predictions
Supports both feature vectors and JSON roofline data as input.
"""
from fastapi import FastAPI, HTTPException, Body
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field, validator
from typing import Dict, List, Union, Optional, Any
import logging
import uvicorn
from contextlib import asynccontextmanager
from xgb_local import XGBoostMultiLabelPredictor
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Global predictor instance: populated once at startup by the lifespan
# handler; stays None if model loading fails, in which case the inference
# endpoints respond with HTTP 503.
predictor: Optional[XGBoostMultiLabelPredictor] = None
# Pydantic models for request/response validation
class PredictRequest(BaseModel):
    """Request model for single prediction."""
    # Either a flat {feature_name: value} mapping, or — when is_json=True —
    # a JSON string of raw roofline records that the predictor aggregates.
    features: Union[Dict[str, float], str] = Field(
        ...,
        description="Feature dictionary or JSON string of roofline data"
    )
    # Classes whose probability meets this threshold are returned as predictions.
    threshold: float = Field(
        default=0.5,
        ge=0.0,
        le=1.0,
        description="Probability threshold for classification"
    )
    return_all_probabilities: bool = Field(
        default=True,
        description="Whether to return probabilities for all classes"
    )
    # Switches interpretation of `features` between dict and JSON-string form.
    is_json: bool = Field(
        default=False,
        description="Whether features is a JSON string of roofline data"
    )
    job_id: Optional[str] = Field(
        default=None,
        description="Optional job ID for JSON aggregation"
    )
class PredictTopKRequest(BaseModel):
    """Request model for top-K prediction."""
    # Either a flat {feature_name: value} mapping, or — when is_json=True —
    # a JSON string of raw roofline records that the predictor aggregates.
    features: Union[Dict[str, float], str] = Field(
        ...,
        description="Feature dictionary or JSON string of roofline data"
    )
    # Number of highest-probability classes to report (capped at 100).
    k: int = Field(
        default=5,
        ge=1,
        le=100,
        description="Number of top predictions to return"
    )
    is_json: bool = Field(
        default=False,
        description="Whether features is a JSON string of roofline data"
    )
    job_id: Optional[str] = Field(
        default=None,
        description="Optional job ID for JSON aggregation"
    )
class BatchPredictRequest(BaseModel):
    """Request model for batch prediction."""
    # One entry per sample; each entry follows the same dict-or-JSON-string
    # convention as PredictRequest.features.
    features_list: List[Union[Dict[str, float], str]] = Field(
        ...,
        description="List of feature dictionaries or JSON strings"
    )
    threshold: float = Field(
        default=0.5,
        ge=0.0,
        le=1.0,
        description="Probability threshold for classification"
    )
    is_json: bool = Field(
        default=False,
        description="Whether features are JSON strings of roofline data"
    )
    # When provided, presumably pairs element-wise with features_list —
    # TODO(review): confirm expected alignment in the predictor.
    job_ids: Optional[List[str]] = Field(
        default=None,
        description="Optional list of job IDs for JSON aggregation"
    )
class PredictResponse(BaseModel):
    """Response model for single prediction."""
    predictions: List[str] = Field(description="List of predicted class names")
    probabilities: Dict[str, float] = Field(description="Probabilities for each class")
    confidences: Dict[str, float] = Field(description="Confidence scores for predicted classes")
    # Echoes back the threshold that produced `predictions`.
    threshold: float = Field(description="Threshold used for prediction")
class PredictTopKResponse(BaseModel):
    """Response model for top-K prediction."""
    top_predictions: List[str] = Field(description="Top-K predicted class names")
    top_probabilities: Dict[str, float] = Field(description="Probabilities for top-K classes")
    all_probabilities: Dict[str, float] = Field(description="Probabilities for all classes")
class BatchPredictResponse(BaseModel):
    """Response model for batch prediction."""
    # Per-sample outcome: a full PredictResponse on success, or a dict
    # carrying an 'error' key for samples that failed.
    results: List[Union[PredictResponse, Dict[str, str]]] = Field(
        description="List of prediction results or errors"
    )
    total: int = Field(description="Total number of samples processed")
    successful: int = Field(description="Number of successful predictions")
    failed: int = Field(description="Number of failed predictions")
class HealthResponse(BaseModel):
    """Response model for health check."""
    # One of: "healthy", "degraded", "error" (see /health endpoint).
    status: str
    model_loaded: bool
    # Model metadata; only populated when the model loaded and introspection
    # succeeded.
    n_classes: Optional[int] = None
    n_features: Optional[int] = None
    classes: Optional[List[str]] = None
class ErrorResponse(BaseModel):
    """Response model for errors."""
    # Short error label; `detail` carries an optional longer explanation.
    error: str
    detail: Optional[str] = None
# Lifespan context manager for startup/shutdown
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the XGBoost model on startup; log on shutdown.

    On load failure the global `predictor` is left as None so endpoints
    can report 503 instead of crashing the server.
    """
    global predictor
    logger.info("Loading XGBoost model...")
    try:
        predictor = XGBoostMultiLabelPredictor('xgb_model.joblib')
    except Exception as exc:
        predictor = None
        logger.error(f"Failed to load model: {exc}")
    else:
        logger.info("Model loaded successfully!")
    yield
    # Shutdown
    logger.info("Shutting down...")
# Initialize FastAPI app; `lifespan` loads the model before serving requests.
app = FastAPI(
    title="XGBoost Multi-Label Classification API",
    description="REST API for multi-label classification inference using XGBoost",
    version="1.0.0",
    lifespan=lifespan
)
# Add CORS middleware (wide open for development)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, specify allowed origins
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# API Endpoints
@app.get("/", tags=["General"])
async def root():
    """Describe the API and list its available endpoints."""
    endpoint_map = {
        "health": "/health",
        "predict": "/predict",
        "predict_top_k": "/predict_top_k",
        "batch_predict": "/batch_predict",
    }
    return {
        "name": "XGBoost Multi-Label Classification API",
        "version": "1.0.0",
        "endpoints": endpoint_map,
    }
@app.get("/health", response_model=HealthResponse, tags=["General"])
async def health_check():
    """Report API health plus basic model metadata.

    Status is "error" when the model never loaded, "degraded" when it
    loaded but introspection fails, and "healthy" otherwise.
    """
    if predictor is None:
        return HealthResponse(status="error", model_loaded=False)
    try:
        meta = predictor.get_class_info()
        return HealthResponse(
            status="healthy",
            model_loaded=True,
            n_classes=meta['n_classes'],
            n_features=meta['n_features'],
            classes=meta['classes'],
        )
    except Exception as exc:
        logger.error(f"Health check error: {exc}")
        return HealthResponse(status="degraded", model_loaded=True)
@app.post("/predict", response_model=PredictResponse, tags=["Inference"])
async def predict(request: PredictRequest):
    """
    Run a single multi-label prediction.

    **Input formats:**
    - Feature dictionary: `{"feature1": value1, "feature2": value2, ...}`
    - JSON roofline data: set `is_json=true` and pass a JSON string

    **Example (feature dict):**
    ```json
    {
        "features": {
            "bandwidth_raw_p10": 150.5,
            "flops_raw_median": 2500.0,
            ...
        },
        "threshold": 0.5
    }
    ```
    **Example (JSON roofline):**
    ```json
    {
        "features": "[{\\"node_num\\": 1, \\"bandwidth_raw\\": 150.5, ...}]",
        "is_json": true,
        "job_id": "test_job_123",
        "threshold": 0.3
    }
    ```
    """
    # Model failed to load at startup -> service unavailable.
    if predictor is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        outcome = predictor.predict(
            features=request.features,
            threshold=request.threshold,
            return_all_probabilities=request.return_all_probabilities,
            is_json=request.is_json,
            job_id=request.job_id,
        )
        return PredictResponse(**outcome)
    except Exception as exc:
        logger.error(f"Prediction error: {exc}")
        raise HTTPException(status_code=400, detail=str(exc))
@app.post("/predict_top_k", response_model=PredictTopKResponse, tags=["Inference"])
async def predict_top_k(request: PredictTopKRequest):
    """
    Return the K most probable classes with their probabilities.

    **Example (feature dict):**
    ```json
    {
        "features": {
            "bandwidth_raw_p10": 150.5,
            "flops_raw_median": 2500.0,
            ...
        },
        "k": 5
    }
    ```
    **Example (JSON roofline):**
    ```json
    {
        "features": "[{\\"node_num\\": 1, \\"bandwidth_raw\\": 150.5, ...}]",
        "is_json": true,
        "job_id": "test_job_123",
        "k": 10
    }
    ```
    """
    # Model failed to load at startup -> service unavailable.
    if predictor is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        outcome = predictor.predict_top_k(
            features=request.features,
            k=request.k,
            is_json=request.is_json,
            job_id=request.job_id,
        )
        return PredictTopKResponse(**outcome)
    except Exception as exc:
        logger.error(f"Top-K prediction error: {exc}")
        raise HTTPException(status_code=400, detail=str(exc))
@app.post("/batch_predict", response_model=BatchPredictResponse, tags=["Inference"])
async def batch_predict(request: BatchPredictRequest):
    """
    Perform batch prediction on multiple samples.

    **Example (features):**
    ```json
    {
        "features_list": [
            {"bandwidth_raw_p10": 150.5, ...},
            {"bandwidth_raw_p10": 160.2, ...}
        ],
        "threshold": 0.5
    }
    ```
    **Example (JSON roofline):**
    ```json
    {
        "features_list": [
            "[{\\"node_num\\": 1, \\"bandwidth_raw\\": 150.5, ...}]",
            "[{\\"node_num\\": 2, \\"bandwidth_raw\\": 160.2, ...}]"
        ],
        "is_json": true,
        "job_ids": ["job1", "job2"],
        "threshold": 0.3
    }
    ```

    Raises HTTP 503 if the model is not loaded, HTTP 400 on bad input
    (including a job_ids list whose length does not match features_list).
    """
    if predictor is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    # Fail fast on a mismatched job_ids list: forwarding it would silently
    # misalign (or drop) samples when paired with features_list downstream.
    if request.job_ids is not None and len(request.job_ids) != len(request.features_list):
        raise HTTPException(
            status_code=400,
            detail="job_ids length must match features_list length"
        )
    try:
        results = predictor.batch_predict(
            features_list=request.features_list,
            threshold=request.threshold,
            is_json=request.is_json,
            job_ids=request.job_ids
        )
        # Per-sample failures come back as dicts with an 'error' key.
        successful = sum(1 for r in results if 'error' not in r)
        failed = len(results) - successful
        return BatchPredictResponse(
            results=results,
            total=len(results),
            successful=successful,
            failed=failed
        )
    except Exception as e:
        logger.error(f"Batch prediction error: {e}")
        raise HTTPException(status_code=400, detail=str(e))
@app.get("/model/info", tags=["Model"])
async def model_info():
    """
    Expose model metadata: class list, feature columns, and their counts.
    """
    # Model failed to load at startup -> service unavailable.
    if predictor is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        details = predictor.get_class_info()
        return {
            "classes": details['classes'],
            "n_classes": details['n_classes'],
            "features": details['feature_columns'],
            "n_features": details['n_features'],
        }
    except Exception as exc:
        logger.error(f"Model info error: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
def main():
    """Parse CLI options and launch the uvicorn server."""
    import argparse

    parser = argparse.ArgumentParser(description="XGBoost Multi-Label Classification REST API")
    parser.add_argument('--host', type=str, default='0.0.0.0',
                        help='Host to bind to (default: 0.0.0.0)')
    parser.add_argument('--port', type=int, default=8000,
                        help='Port to bind to (default: 8000)')
    parser.add_argument('--reload', action='store_true',
                        help='Enable auto-reload for development')
    parser.add_argument('--workers', type=int, default=1,
                        help='Number of worker processes (default: 1)')
    opts = parser.parse_args()

    logger.info(f"Starting FastAPI server on {opts.host}:{opts.port}")
    logger.info(f"Workers: {opts.workers}, Reload: {opts.reload}")

    # uvicorn's reload mode is single-process only.
    worker_count = 1 if opts.reload else opts.workers
    uvicorn.run(
        "xgb_fastapi:app",
        host=opts.host,
        port=opts.port,
        reload=opts.reload,
        workers=worker_count,
        log_level="info",
    )


if __name__ == "__main__":
    main()