Initial commit
This commit is contained in:
429
xgb_fastapi.py
Normal file
429
xgb_fastapi.py
Normal file
@@ -0,0 +1,429 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
FastAPI REST API for XGBoost Multi-Label Classification Inference
|
||||
|
||||
Provides HTTP endpoints for:
|
||||
- POST /predict: Single prediction with confidence scores
|
||||
- POST /predict_top_k: Top-K predictions
|
||||
- POST /batch_predict: Batch predictions
|
||||
|
||||
Supports both feature vectors and JSON roofline data as input.
|
||||
"""
|
||||
|
||||
# Standard library
import logging
import os
from contextlib import asynccontextmanager
from typing import Dict, List, Union, Optional, Any

# Third-party
import uvicorn
from fastapi import FastAPI, HTTPException, Body
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field, validator

# Local
from xgb_local import XGBoostMultiLabelPredictor
|
||||
|
||||
# Module-wide logging: timestamped records at INFO level.
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Shared predictor handle. Populated by the lifespan hook at startup;
# stays None when model loading fails so endpoints can answer 503.
predictor: Optional[XGBoostMultiLabelPredictor] = None
|
||||
|
||||
|
||||
# Pydantic models for request/response validation
|
||||
class PredictRequest(BaseModel):
    """Request model for single prediction."""

    # Either a {feature_name: value} mapping or, when is_json is set,
    # a JSON string of raw roofline records.
    features: Union[Dict[str, float], str] = Field(
        ..., description="Feature dictionary or JSON string of roofline data"
    )
    threshold: float = Field(
        0.5, ge=0.0, le=1.0,
        description="Probability threshold for classification",
    )
    return_all_probabilities: bool = Field(
        True, description="Whether to return probabilities for all classes"
    )
    is_json: bool = Field(
        False, description="Whether features is a JSON string of roofline data"
    )
    job_id: Optional[str] = Field(
        None, description="Optional job ID for JSON aggregation"
    )
|
||||
|
||||
|
||||
class PredictTopKRequest(BaseModel):
    """Request model for top-K prediction."""

    # Same dual-format input as PredictRequest.
    features: Union[Dict[str, float], str] = Field(
        ..., description="Feature dictionary or JSON string of roofline data"
    )
    # K is bounded to a sane range; values outside 1..100 are rejected.
    k: int = Field(5, ge=1, le=100, description="Number of top predictions to return")
    is_json: bool = Field(
        False, description="Whether features is a JSON string of roofline data"
    )
    job_id: Optional[str] = Field(None, description="Optional job ID for JSON aggregation")
|
||||
|
||||
|
||||
class BatchPredictRequest(BaseModel):
    """Request model for batch prediction."""

    # One entry per sample; entries follow the same dual format as
    # PredictRequest.features.
    features_list: List[Union[Dict[str, float], str]] = Field(
        ..., description="List of feature dictionaries or JSON strings"
    )
    threshold: float = Field(
        0.5, ge=0.0, le=1.0,
        description="Probability threshold for classification",
    )
    is_json: bool = Field(
        False, description="Whether features are JSON strings of roofline data"
    )
    # NOTE(review): length is not validated against features_list here;
    # presumably the predictor aligns them positionally — confirm.
    job_ids: Optional[List[str]] = Field(
        None, description="Optional list of job IDs for JSON aggregation"
    )
|
||||
|
||||
|
||||
class PredictResponse(BaseModel):
    """Response model for single prediction."""
    # Classes the model predicted (presumably those whose probability
    # cleared the threshold — confirm against XGBoostMultiLabelPredictor).
    predictions: List[str] = Field(description="List of predicted class names")
    probabilities: Dict[str, float] = Field(description="Probabilities for each class")
    # Scores only for the classes listed in `predictions`.
    confidences: Dict[str, float] = Field(description="Confidence scores for predicted classes")
    # Echo of the threshold that was applied.
    threshold: float = Field(description="Threshold used for prediction")
|
||||
|
||||
|
||||
class PredictTopKResponse(BaseModel):
    """Response model for top-K prediction."""
    # The K most probable class names, presumably in descending
    # probability order — confirm against the predictor implementation.
    top_predictions: List[str] = Field(description="Top-K predicted class names")
    top_probabilities: Dict[str, float] = Field(description="Probabilities for top-K classes")
    # Full distribution, not just the top K.
    all_probabilities: Dict[str, float] = Field(description="Probabilities for all classes")
|
||||
|
||||
|
||||
class BatchPredictResponse(BaseModel):
    """Response model for batch prediction."""
    # Per-sample outcome: a PredictResponse-shaped dict on success, or a
    # dict carrying an 'error' key on failure (see batch_predict).
    results: List[Union[PredictResponse, Dict[str, str]]] = Field(
        description="List of prediction results or errors"
    )
    total: int = Field(description="Total number of samples processed")
    successful: int = Field(description="Number of successful predictions")
    failed: int = Field(description="Number of failed predictions")
|
||||
|
||||
|
||||
class HealthResponse(BaseModel):
    """Response model for health check."""
    # One of "healthy", "degraded", or "error" (values set by health_check).
    status: str
    model_loaded: bool
    # Model metadata; populated only when the model loaded and its
    # class info is queryable.
    n_classes: Optional[int] = None
    n_features: Optional[int] = None
    classes: Optional[List[str]] = None
|
||||
|
||||
|
||||
class ErrorResponse(BaseModel):
    """Response model for errors."""
    # Short machine-readable error summary.
    # NOTE(review): this model is not referenced by any visible endpoint;
    # errors are raised via HTTPException instead — confirm it is needed.
    error: str
    # Optional human-readable detail.
    detail: Optional[str] = None
|
||||
|
||||
|
||||
# Lifespan context manager for startup/shutdown
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage application lifespan events.

    On startup, loads the XGBoost model from the path given by the
    ``XGB_MODEL_PATH`` environment variable (default: ``xgb_model.joblib``,
    preserving the previous hard-coded behavior). On failure the global
    ``predictor`` is left as ``None`` so /health and the inference
    endpoints can report the error state instead of crashing the app.
    """
    # Startup
    global predictor
    model_path = os.environ.get("XGB_MODEL_PATH", "xgb_model.joblib")
    try:
        logger.info(f"Loading XGBoost model from {model_path}...")
        predictor = XGBoostMultiLabelPredictor(model_path)
        logger.info("Model loaded successfully!")
    except Exception as e:
        # Deliberate best-effort: keep serving so health checks can
        # surface the failure rather than the process dying at boot.
        logger.error(f"Failed to load model: {e}")
        predictor = None

    yield

    # Shutdown
    logger.info("Shutting down...")
|
||||
|
||||
|
||||
# Initialize FastAPI app; the lifespan hook loads/unloads the model.
app = FastAPI(
    title="XGBoost Multi-Label Classification API",
    description="REST API for multi-label classification inference using XGBoost",
    version="1.0.0",
    lifespan=lifespan
)

# Add CORS middleware
# NOTE(review): wildcard origins combined with allow_credentials=True is
# not honored by browsers (credentialed requests require an explicit
# origin) — lock down allow_origins before production use.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, specify allowed origins
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
||||
|
||||
|
||||
# API Endpoints
|
||||
|
||||
@app.get("/", tags=["General"])
|
||||
async def root():
|
||||
"""Root endpoint with API information."""
|
||||
return {
|
||||
"name": "XGBoost Multi-Label Classification API",
|
||||
"version": "1.0.0",
|
||||
"endpoints": {
|
||||
"health": "/health",
|
||||
"predict": "/predict",
|
||||
"predict_top_k": "/predict_top_k",
|
||||
"batch_predict": "/batch_predict"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@app.get("/health", response_model=HealthResponse, tags=["General"])
|
||||
async def health_check():
|
||||
"""
|
||||
Check API health and model status.
|
||||
|
||||
Returns model information if loaded successfully.
|
||||
"""
|
||||
if predictor is None:
|
||||
return HealthResponse(
|
||||
status="error",
|
||||
model_loaded=False
|
||||
)
|
||||
|
||||
try:
|
||||
info = predictor.get_class_info()
|
||||
return HealthResponse(
|
||||
status="healthy",
|
||||
model_loaded=True,
|
||||
n_classes=info['n_classes'],
|
||||
n_features=info['n_features'],
|
||||
classes=info['classes']
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Health check error: {e}")
|
||||
return HealthResponse(
|
||||
status="degraded",
|
||||
model_loaded=True
|
||||
)
|
||||
|
||||
|
||||
@app.post("/predict", response_model=PredictResponse, tags=["Inference"])
|
||||
async def predict(request: PredictRequest):
|
||||
"""
|
||||
Perform single prediction on input features.
|
||||
|
||||
**Input formats:**
|
||||
- Feature dictionary: `{"feature1": value1, "feature2": value2, ...}`
|
||||
- JSON roofline data: Set `is_json=true` and provide JSON string
|
||||
|
||||
**Example (features):**
|
||||
```json
|
||||
{
|
||||
"features": {
|
||||
"bandwidth_raw_p10": 150.5,
|
||||
"flops_raw_median": 2500.0,
|
||||
...
|
||||
},
|
||||
"threshold": 0.5
|
||||
}
|
||||
```
|
||||
|
||||
**Example (JSON roofline):**
|
||||
```json
|
||||
{
|
||||
"features": "[{\"node_num\": 1, \"bandwidth_raw\": 150.5, ...}]",
|
||||
"is_json": true,
|
||||
"job_id": "test_job_123",
|
||||
"threshold": 0.3
|
||||
}
|
||||
```
|
||||
"""
|
||||
if predictor is None:
|
||||
raise HTTPException(status_code=503, detail="Model not loaded")
|
||||
|
||||
try:
|
||||
result = predictor.predict(
|
||||
features=request.features,
|
||||
threshold=request.threshold,
|
||||
return_all_probabilities=request.return_all_probabilities,
|
||||
is_json=request.is_json,
|
||||
job_id=request.job_id
|
||||
)
|
||||
return PredictResponse(**result)
|
||||
except Exception as e:
|
||||
logger.error(f"Prediction error: {e}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@app.post("/predict_top_k", response_model=PredictTopKResponse, tags=["Inference"])
|
||||
async def predict_top_k(request: PredictTopKRequest):
|
||||
"""
|
||||
Get top-K predictions with their probabilities.
|
||||
|
||||
**Example (features):**
|
||||
```json
|
||||
{
|
||||
"features": {
|
||||
"bandwidth_raw_p10": 150.5,
|
||||
"flops_raw_median": 2500.0,
|
||||
...
|
||||
},
|
||||
"k": 5
|
||||
}
|
||||
```
|
||||
|
||||
**Example (JSON roofline):**
|
||||
```json
|
||||
{
|
||||
"features": "[{\"node_num\": 1, \"bandwidth_raw\": 150.5, ...}]",
|
||||
"is_json": true,
|
||||
"job_id": "test_job_123",
|
||||
"k": 10
|
||||
}
|
||||
```
|
||||
"""
|
||||
if predictor is None:
|
||||
raise HTTPException(status_code=503, detail="Model not loaded")
|
||||
|
||||
try:
|
||||
result = predictor.predict_top_k(
|
||||
features=request.features,
|
||||
k=request.k,
|
||||
is_json=request.is_json,
|
||||
job_id=request.job_id
|
||||
)
|
||||
return PredictTopKResponse(**result)
|
||||
except Exception as e:
|
||||
logger.error(f"Top-K prediction error: {e}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@app.post("/batch_predict", response_model=BatchPredictResponse, tags=["Inference"])
|
||||
async def batch_predict(request: BatchPredictRequest):
|
||||
"""
|
||||
Perform batch prediction on multiple samples.
|
||||
|
||||
**Example (features):**
|
||||
```json
|
||||
{
|
||||
"features_list": [
|
||||
{"bandwidth_raw_p10": 150.5, ...},
|
||||
{"bandwidth_raw_p10": 160.2, ...}
|
||||
],
|
||||
"threshold": 0.5
|
||||
}
|
||||
```
|
||||
|
||||
**Example (JSON roofline):**
|
||||
```json
|
||||
{
|
||||
"features_list": [
|
||||
"[{\"node_num\": 1, \"bandwidth_raw\": 150.5, ...}]",
|
||||
"[{\"node_num\": 2, \"bandwidth_raw\": 160.2, ...}]"
|
||||
],
|
||||
"is_json": true,
|
||||
"job_ids": ["job1", "job2"],
|
||||
"threshold": 0.3
|
||||
}
|
||||
```
|
||||
"""
|
||||
if predictor is None:
|
||||
raise HTTPException(status_code=503, detail="Model not loaded")
|
||||
|
||||
try:
|
||||
results = predictor.batch_predict(
|
||||
features_list=request.features_list,
|
||||
threshold=request.threshold,
|
||||
is_json=request.is_json,
|
||||
job_ids=request.job_ids
|
||||
)
|
||||
|
||||
# Count successful and failed predictions
|
||||
successful = sum(1 for r in results if 'error' not in r)
|
||||
failed = len(results) - successful
|
||||
|
||||
return BatchPredictResponse(
|
||||
results=results,
|
||||
total=len(results),
|
||||
successful=successful,
|
||||
failed=failed
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Batch prediction error: {e}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@app.get("/model/info", tags=["Model"])
|
||||
async def model_info():
|
||||
"""
|
||||
Get detailed model information.
|
||||
|
||||
Returns information about classes, features, and model configuration.
|
||||
"""
|
||||
if predictor is None:
|
||||
raise HTTPException(status_code=503, detail="Model not loaded")
|
||||
|
||||
try:
|
||||
info = predictor.get_class_info()
|
||||
return {
|
||||
"classes": info['classes'],
|
||||
"n_classes": info['n_classes'],
|
||||
"features": info['feature_columns'],
|
||||
"n_features": info['n_features']
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Model info error: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
def main():
    """Run the FastAPI server."""
    import argparse

    parser = argparse.ArgumentParser(description="XGBoost Multi-Label Classification REST API")
    add = parser.add_argument
    add('--host', type=str, default='0.0.0.0', help='Host to bind to (default: 0.0.0.0)')
    add('--port', type=int, default=8000, help='Port to bind to (default: 8000)')
    add('--reload', action='store_true', help='Enable auto-reload for development')
    add('--workers', type=int, default=1, help='Number of worker processes (default: 1)')

    args = parser.parse_args()

    logger.info(f"Starting FastAPI server on {args.host}:{args.port}")
    logger.info(f"Workers: {args.workers}, Reload: {args.reload}")

    # Auto-reload is incompatible with multiple workers, so force one.
    worker_count = 1 if args.reload else args.workers
    uvicorn.run(
        "xgb_fastapi:app",
        host=args.host,
        port=args.port,
        reload=args.reload,
        workers=worker_count,
        log_level="info",
    )
|
||||
|
||||
|
||||
# Script entry point: start the server when run directly.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user