#!/usr/bin/env python3
"""
FastAPI REST API for XGBoost Multi-Label Classification Inference

Provides HTTP endpoints for:
- POST /predict: Single prediction with confidence scores
- POST /predict_top_k: Top-K predictions
- POST /batch_predict: Batch predictions

Supports both feature vectors and JSON roofline data as input.
"""
import logging
import os
from contextlib import asynccontextmanager
from typing import Any, Dict, List, Optional, Union

import uvicorn
from fastapi import FastAPI, HTTPException, Body
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field, validator

from xgb_local import XGBoostMultiLabelPredictor
# Configure logging
|
|
logging.basicConfig(
|
|
level=logging.INFO,
|
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
|
)
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Global predictor instance
|
|
predictor: Optional[XGBoostMultiLabelPredictor] = None
|
|
|
|
|
|
# Pydantic models for request/response validation
|
|
class PredictRequest(BaseModel):
|
|
"""Request model for single prediction."""
|
|
features: Union[Dict[str, float], str] = Field(
|
|
...,
|
|
description="Feature dictionary or JSON string of roofline data"
|
|
)
|
|
threshold: float = Field(
|
|
default=0.5,
|
|
ge=0.0,
|
|
le=1.0,
|
|
description="Probability threshold for classification"
|
|
)
|
|
return_all_probabilities: bool = Field(
|
|
default=True,
|
|
description="Whether to return probabilities for all classes"
|
|
)
|
|
is_json: bool = Field(
|
|
default=False,
|
|
description="Whether features is a JSON string of roofline data"
|
|
)
|
|
job_id: Optional[str] = Field(
|
|
default=None,
|
|
description="Optional job ID for JSON aggregation"
|
|
)
|
|
|
|
|
|
class PredictTopKRequest(BaseModel):
|
|
"""Request model for top-K prediction."""
|
|
features: Union[Dict[str, float], str] = Field(
|
|
...,
|
|
description="Feature dictionary or JSON string of roofline data"
|
|
)
|
|
k: int = Field(
|
|
default=5,
|
|
ge=1,
|
|
le=100,
|
|
description="Number of top predictions to return"
|
|
)
|
|
is_json: bool = Field(
|
|
default=False,
|
|
description="Whether features is a JSON string of roofline data"
|
|
)
|
|
job_id: Optional[str] = Field(
|
|
default=None,
|
|
description="Optional job ID for JSON aggregation"
|
|
)
|
|
|
|
|
|
class BatchPredictRequest(BaseModel):
|
|
"""Request model for batch prediction."""
|
|
features_list: List[Union[Dict[str, float], str]] = Field(
|
|
...,
|
|
description="List of feature dictionaries or JSON strings"
|
|
)
|
|
threshold: float = Field(
|
|
default=0.5,
|
|
ge=0.0,
|
|
le=1.0,
|
|
description="Probability threshold for classification"
|
|
)
|
|
is_json: bool = Field(
|
|
default=False,
|
|
description="Whether features are JSON strings of roofline data"
|
|
)
|
|
job_ids: Optional[List[str]] = Field(
|
|
default=None,
|
|
description="Optional list of job IDs for JSON aggregation"
|
|
)
|
|
|
|
|
|
class PredictResponse(BaseModel):
|
|
"""Response model for single prediction."""
|
|
predictions: List[str] = Field(description="List of predicted class names")
|
|
probabilities: Dict[str, float] = Field(description="Probabilities for each class")
|
|
confidences: Dict[str, float] = Field(description="Confidence scores for predicted classes")
|
|
threshold: float = Field(description="Threshold used for prediction")
|
|
|
|
|
|
class PredictTopKResponse(BaseModel):
|
|
"""Response model for top-K prediction."""
|
|
top_predictions: List[str] = Field(description="Top-K predicted class names")
|
|
top_probabilities: Dict[str, float] = Field(description="Probabilities for top-K classes")
|
|
all_probabilities: Dict[str, float] = Field(description="Probabilities for all classes")
|
|
|
|
|
|
class BatchPredictResponse(BaseModel):
|
|
"""Response model for batch prediction."""
|
|
results: List[Union[PredictResponse, Dict[str, str]]] = Field(
|
|
description="List of prediction results or errors"
|
|
)
|
|
total: int = Field(description="Total number of samples processed")
|
|
successful: int = Field(description="Number of successful predictions")
|
|
failed: int = Field(description="Number of failed predictions")
|
|
|
|
|
|
class HealthResponse(BaseModel):
|
|
"""Response model for health check."""
|
|
status: str
|
|
model_loaded: bool
|
|
n_classes: Optional[int] = None
|
|
n_features: Optional[int] = None
|
|
classes: Optional[List[str]] = None
|
|
|
|
|
|
class ErrorResponse(BaseModel):
|
|
"""Response model for errors."""
|
|
error: str
|
|
detail: Optional[str] = None
|
|
|
|
|
|
# Lifespan context manager for startup/shutdown
|
|
@asynccontextmanager
|
|
async def lifespan(app: FastAPI):
|
|
"""Manage application lifespan events."""
|
|
# Startup
|
|
global predictor
|
|
try:
|
|
logger.info("Loading XGBoost model...")
|
|
predictor = XGBoostMultiLabelPredictor('xgb_model.joblib')
|
|
logger.info("Model loaded successfully!")
|
|
except Exception as e:
|
|
logger.error(f"Failed to load model: {e}")
|
|
predictor = None
|
|
|
|
yield
|
|
|
|
# Shutdown
|
|
logger.info("Shutting down...")
|
|
|
|
|
|
# Initialize FastAPI app.  The lifespan handler above loads the model on
# startup; uvicorn imports this object via the "xgb_fastapi:app" string.
app = FastAPI(
    title="XGBoost Multi-Label Classification API",
    description="REST API for multi-label classification inference using XGBoost",
    version="1.0.0",
    lifespan=lifespan
)

# Add CORS middleware.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive — restrict allow_origins before exposing this service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, specify allowed origins
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# API Endpoints

@app.get("/", tags=["General"])
async def root():
    """Root endpoint with API information."""
    endpoints = {
        "health": "/health",
        "predict": "/predict",
        "predict_top_k": "/predict_top_k",
        "batch_predict": "/batch_predict",
    }
    return {
        "name": "XGBoost Multi-Label Classification API",
        "version": "1.0.0",
        "endpoints": endpoints,
    }
@app.get("/health", response_model=HealthResponse, tags=["General"])
async def health_check():
    """
    Check API health and model status.

    Returns model information if loaded successfully.

    Status values:
    - "error": model never loaded (predictor is None)
    - "degraded": model loaded but class-info introspection failed
    - "healthy": model loaded and class info available
    """
    if predictor is None:
        return HealthResponse(status="error", model_loaded=False)

    try:
        info = predictor.get_class_info()
        return HealthResponse(
            status="healthy",
            model_loaded=True,
            n_classes=info['n_classes'],
            n_features=info['n_features'],
            classes=info['classes']
        )
    except Exception:
        # Log the full traceback (logger.exception) but report "degraded"
        # instead of failing the probe outright.
        logger.exception("Health check error")
        return HealthResponse(status="degraded", model_loaded=True)
@app.post("/predict", response_model=PredictResponse, tags=["Inference"])
async def predict(request: PredictRequest):
    """
    Perform single prediction on input features.

    **Input formats:**
    - Feature dictionary: `{"feature1": value1, "feature2": value2, ...}`
    - JSON roofline data: Set `is_json=true` and provide JSON string

    **Example (features):**
    ```json
    {
      "features": {
        "bandwidth_raw_p10": 150.5,
        "flops_raw_median": 2500.0,
        ...
      },
      "threshold": 0.5
    }
    ```

    **Example (JSON roofline):**
    ```json
    {
      "features": "[{\"node_num\": 1, \"bandwidth_raw\": 150.5, ...}]",
      "is_json": true,
      "job_id": "test_job_123",
      "threshold": 0.3
    }
    ```

    Raises HTTP 503 when the model is not loaded and HTTP 400 when the
    predictor rejects the input.
    """
    if predictor is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        result = predictor.predict(
            features=request.features,
            threshold=request.threshold,
            return_all_probabilities=request.return_all_probabilities,
            is_json=request.is_json,
            job_id=request.job_id
        )
        return PredictResponse(**result)
    except Exception as e:
        # logger.exception keeps the traceback server-side; the client only
        # receives str(e).  Chain the cause so debuggers see the original.
        logger.exception("Prediction error")
        raise HTTPException(status_code=400, detail=str(e)) from e
@app.post("/predict_top_k", response_model=PredictTopKResponse, tags=["Inference"])
async def predict_top_k(request: PredictTopKRequest):
    """
    Get top-K predictions with their probabilities.

    **Example (features):**
    ```json
    {
      "features": {
        "bandwidth_raw_p10": 150.5,
        "flops_raw_median": 2500.0,
        ...
      },
      "k": 5
    }
    ```

    **Example (JSON roofline):**
    ```json
    {
      "features": "[{\"node_num\": 1, \"bandwidth_raw\": 150.5, ...}]",
      "is_json": true,
      "job_id": "test_job_123",
      "k": 10
    }
    ```

    Raises HTTP 503 when the model is not loaded and HTTP 400 when the
    predictor rejects the input.
    """
    if predictor is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        result = predictor.predict_top_k(
            features=request.features,
            k=request.k,
            is_json=request.is_json,
            job_id=request.job_id
        )
        return PredictTopKResponse(**result)
    except Exception as e:
        # Log with traceback and chain the cause into the HTTP error.
        logger.exception("Top-K prediction error")
        raise HTTPException(status_code=400, detail=str(e)) from e
@app.post("/batch_predict", response_model=BatchPredictResponse, tags=["Inference"])
async def batch_predict(request: BatchPredictRequest):
    """
    Perform batch prediction on multiple samples.

    **Example (features):**
    ```json
    {
      "features_list": [
        {"bandwidth_raw_p10": 150.5, ...},
        {"bandwidth_raw_p10": 160.2, ...}
      ],
      "threshold": 0.5
    }
    ```

    **Example (JSON roofline):**
    ```json
    {
      "features_list": [
        "[{\"node_num\": 1, \"bandwidth_raw\": 150.5, ...}]",
        "[{\"node_num\": 2, \"bandwidth_raw\": 160.2, ...}]"
      ],
      "is_json": true,
      "job_ids": ["job1", "job2"],
      "threshold": 0.3
    }
    ```

    Raises HTTP 503 when the model is not loaded and HTTP 400 when the
    whole batch fails.
    """
    if predictor is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        results = predictor.batch_predict(
            features_list=request.features_list,
            threshold=request.threshold,
            is_json=request.is_json,
            job_ids=request.job_ids
        )

        # Count successful and failed predictions; a failed sample is
        # represented by an entry carrying an 'error' key.
        successful = sum(1 for r in results if 'error' not in r)
        failed = len(results) - successful

        return BatchPredictResponse(
            results=results,
            total=len(results),
            successful=successful,
            failed=failed
        )
    except Exception as e:
        # Log with traceback and chain the cause into the HTTP error.
        logger.exception("Batch prediction error")
        raise HTTPException(status_code=400, detail=str(e)) from e
@app.get("/model/info", tags=["Model"])
async def model_info():
    """
    Get detailed model information.

    Returns information about classes, features, and model configuration.

    Raises HTTP 503 when the model is not loaded and HTTP 500 when
    introspection fails.
    """
    if predictor is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        info = predictor.get_class_info()
        return {
            "classes": info['classes'],
            "n_classes": info['n_classes'],
            "features": info['feature_columns'],
            "n_features": info['n_features']
        }
    except Exception as e:
        # Introspection failure is a server-side problem -> 500; keep the
        # traceback in the log and chain the cause.
        logger.exception("Model info error")
        raise HTTPException(status_code=500, detail=str(e)) from e
def main():
    """Parse CLI options and launch the uvicorn server."""
    import argparse

    arg_parser = argparse.ArgumentParser(
        description="XGBoost Multi-Label Classification REST API")
    arg_parser.add_argument('--host', type=str, default='0.0.0.0',
                            help='Host to bind to (default: 0.0.0.0)')
    arg_parser.add_argument('--port', type=int, default=8000,
                            help='Port to bind to (default: 8000)')
    arg_parser.add_argument('--reload', action='store_true',
                            help='Enable auto-reload for development')
    arg_parser.add_argument('--workers', type=int, default=1,
                            help='Number of worker processes (default: 1)')
    args = arg_parser.parse_args()

    logger.info(f"Starting FastAPI server on {args.host}:{args.port}")
    logger.info(f"Workers: {args.workers}, Reload: {args.reload}")

    # Auto-reload is incompatible with multiple workers, so force a single one.
    worker_count = 1 if args.reload else args.workers

    uvicorn.run(
        "xgb_fastapi:app",
        host=args.host,
        port=args.port,
        reload=args.reload,
        workers=worker_count,
        log_level="info"
    )
# Script entry point: parse CLI arguments and start the server.
if __name__ == "__main__":
    main()