#!/usr/bin/env python3
"""
FastAPI REST API for XGBoost Multi-Label Classification Inference

Provides HTTP endpoints for:
- POST /predict: Single prediction with confidence scores
- POST /predict_top_k: Top-K predictions
- POST /batch_predict: Batch predictions

Supports both feature vectors and JSON roofline data as input.
"""

import logging
from contextlib import asynccontextmanager
from typing import Dict, List, Optional, Union

import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field

from xgb_local import XGBoostMultiLabelPredictor

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Global predictor instance; populated by the lifespan handler at startup,
# left as None when model loading fails (endpoints then answer 503).
predictor: Optional[XGBoostMultiLabelPredictor] = None


# Pydantic models for request/response validation

class PredictRequest(BaseModel):
    """Request model for single prediction."""
    # Either a ready-made feature dict, or (when is_json=True) a JSON string
    # of raw roofline records that the predictor aggregates itself.
    features: Union[Dict[str, float], str] = Field(
        ...,
        description="Feature dictionary or JSON string of roofline data"
    )
    threshold: float = Field(
        default=0.5,
        ge=0.0,
        le=1.0,
        description="Probability threshold for classification"
    )
    return_all_probabilities: bool = Field(
        default=True,
        description="Whether to return probabilities for all classes"
    )
    is_json: bool = Field(
        default=False,
        description="Whether features is a JSON string of roofline data"
    )
    job_id: Optional[str] = Field(
        default=None,
        description="Optional job ID for JSON aggregation"
    )


class PredictTopKRequest(BaseModel):
    """Request model for top-K prediction."""
    features: Union[Dict[str, float], str] = Field(
        ...,
        description="Feature dictionary or JSON string of roofline data"
    )
    k: int = Field(
        default=5,
        ge=1,
        le=100,
        description="Number of top predictions to return"
    )
    is_json: bool = Field(
        default=False,
        description="Whether features is a JSON string of roofline data"
    )
    job_id: Optional[str] = Field(
        default=None,
        description="Optional job ID for JSON aggregation"
    )


class BatchPredictRequest(BaseModel):
    """Request model for batch prediction."""
    features_list: List[Union[Dict[str, float], str]] = Field(
        ...,
        description="List of feature dictionaries or JSON strings"
    )
    threshold: float = Field(
        default=0.5,
        ge=0.0,
        le=1.0,
        description="Probability threshold for classification"
    )
    is_json: bool = Field(
        default=False,
        description="Whether features are JSON strings of roofline data"
    )
    job_ids: Optional[List[str]] = Field(
        default=None,
        description="Optional list of job IDs for JSON aggregation"
    )


class PredictResponse(BaseModel):
    """Response model for single prediction."""
    predictions: List[str] = Field(description="List of predicted class names")
    probabilities: Dict[str, float] = Field(description="Probabilities for each class")
    confidences: Dict[str, float] = Field(description="Confidence scores for predicted classes")
    threshold: float = Field(description="Threshold used for prediction")


class PredictTopKResponse(BaseModel):
    """Response model for top-K prediction."""
    top_predictions: List[str] = Field(description="Top-K predicted class names")
    top_probabilities: Dict[str, float] = Field(description="Probabilities for top-K classes")
    all_probabilities: Dict[str, float] = Field(description="Probabilities for all classes")


class BatchPredictResponse(BaseModel):
    """Response model for batch prediction."""
    # Each entry is either a successful prediction payload or an
    # {"error": ...} mapping produced by the predictor for a failed sample.
    results: List[Union[PredictResponse, Dict[str, str]]] = Field(
        description="List of prediction results or errors"
    )
    total: int = Field(description="Total number of samples processed")
    successful: int = Field(description="Number of successful predictions")
    failed: int = Field(description="Number of failed predictions")


class HealthResponse(BaseModel):
    """Response model for health check."""
    status: str
    model_loaded: bool
    n_classes: Optional[int] = None
    n_features: Optional[int] = None
    classes: Optional[List[str]] = None


class ErrorResponse(BaseModel):
    """Response model for errors."""
    error: str
    detail: Optional[str] = None


# Lifespan context manager for startup/shutdown
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage application lifespan events.

    Loads the XGBoost model once at startup; on failure the predictor is
    left as None so endpoints can report 503 instead of crashing the app.
    """
    # Startup
    global predictor
    try:
        logger.info("Loading XGBoost model...")
        predictor = XGBoostMultiLabelPredictor('xgb_model.joblib')
        logger.info("Model loaded successfully!")
    except Exception:
        # logger.exception records the full traceback for diagnosis.
        logger.exception("Failed to load model")
        predictor = None
    yield
    # Shutdown
    logger.info("Shutting down...")


# Initialize FastAPI app
app = FastAPI(
    title="XGBoost Multi-Label Classification API",
    description="REST API for multi-label classification inference using XGBoost",
    version="1.0.0",
    lifespan=lifespan
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, specify allowed origins
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# API Endpoints

@app.get("/", tags=["General"])
async def root():
    """Root endpoint with API information."""
    return {
        "name": "XGBoost Multi-Label Classification API",
        "version": "1.0.0",
        "endpoints": {
            "health": "/health",
            "predict": "/predict",
            "predict_top_k": "/predict_top_k",
            "batch_predict": "/batch_predict"
        }
    }


@app.get("/health", response_model=HealthResponse, tags=["General"])
async def health_check():
    """
    Check API health and model status.

    Returns model information if loaded successfully.
    """
    if predictor is None:
        return HealthResponse(
            status="error",
            model_loaded=False
        )
    try:
        info = predictor.get_class_info()
        return HealthResponse(
            status="healthy",
            model_loaded=True,
            n_classes=info['n_classes'],
            n_features=info['n_features'],
            classes=info['classes']
        )
    except Exception as e:
        # Model object exists but metadata lookup failed: degraded, not dead.
        logger.exception("Health check error: %s", e)
        return HealthResponse(
            status="degraded",
            model_loaded=True
        )


@app.post("/predict", response_model=PredictResponse, tags=["Inference"])
async def predict(request: PredictRequest):
    """
    Perform single prediction on input features.

    **Input formats:**
    - Feature dictionary: `{"feature1": value1, "feature2": value2, ...}`
    - JSON roofline data: Set `is_json=true` and provide JSON string

    **Example (features):**
    ```json
    {
        "features": {
            "bandwidth_raw_p10": 150.5,
            "flops_raw_median": 2500.0,
            ...
        },
        "threshold": 0.5
    }
    ```

    **Example (JSON roofline):**
    ```json
    {
        "features": "[{\\"node_num\\": 1, \\"bandwidth_raw\\": 150.5, ...}]",
        "is_json": true,
        "job_id": "test_job_123",
        "threshold": 0.3
    }
    ```
    """
    if predictor is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        result = predictor.predict(
            features=request.features,
            threshold=request.threshold,
            return_all_probabilities=request.return_all_probabilities,
            is_json=request.is_json,
            job_id=request.job_id
        )
        return PredictResponse(**result)
    except Exception as e:
        # Treat predictor failures as client-side input errors (400).
        logger.exception("Prediction error: %s", e)
        raise HTTPException(status_code=400, detail=str(e))


@app.post("/predict_top_k", response_model=PredictTopKResponse, tags=["Inference"])
async def predict_top_k(request: PredictTopKRequest):
    """
    Get top-K predictions with their probabilities.

    **Example (features):**
    ```json
    {
        "features": {
            "bandwidth_raw_p10": 150.5,
            "flops_raw_median": 2500.0,
            ...
        },
        "k": 5
    }
    ```

    **Example (JSON roofline):**
    ```json
    {
        "features": "[{\\"node_num\\": 1, \\"bandwidth_raw\\": 150.5, ...}]",
        "is_json": true,
        "job_id": "test_job_123",
        "k": 10
    }
    ```
    """
    if predictor is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        result = predictor.predict_top_k(
            features=request.features,
            k=request.k,
            is_json=request.is_json,
            job_id=request.job_id
        )
        return PredictTopKResponse(**result)
    except Exception as e:
        logger.exception("Top-K prediction error: %s", e)
        raise HTTPException(status_code=400, detail=str(e))


@app.post("/batch_predict", response_model=BatchPredictResponse, tags=["Inference"])
async def batch_predict(request: BatchPredictRequest):
    """
    Perform batch prediction on multiple samples.

    **Example (features):**
    ```json
    {
        "features_list": [
            {"bandwidth_raw_p10": 150.5, ...},
            {"bandwidth_raw_p10": 160.2, ...}
        ],
        "threshold": 0.5
    }
    ```

    **Example (JSON roofline):**
    ```json
    {
        "features_list": [
            "[{\\"node_num\\": 1, \\"bandwidth_raw\\": 150.5, ...}]",
            "[{\\"node_num\\": 2, \\"bandwidth_raw\\": 160.2, ...}]"
        ],
        "is_json": true,
        "job_ids": ["job1", "job2"],
        "threshold": 0.3
    }
    ```
    """
    if predictor is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        results = predictor.batch_predict(
            features_list=request.features_list,
            threshold=request.threshold,
            is_json=request.is_json,
            job_ids=request.job_ids
        )
        # Count successful and failed predictions; per-sample failures come
        # back from the predictor as {"error": ...} entries.
        successful = sum(1 for r in results if 'error' not in r)
        failed = len(results) - successful
        return BatchPredictResponse(
            results=results,
            total=len(results),
            successful=successful,
            failed=failed
        )
    except Exception as e:
        logger.exception("Batch prediction error: %s", e)
        raise HTTPException(status_code=400, detail=str(e))


@app.get("/model/info", tags=["Model"])
async def model_info():
    """
    Get detailed model information.

    Returns information about classes, features, and model configuration.
    """
    if predictor is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        info = predictor.get_class_info()
        return {
            "classes": info['classes'],
            "n_classes": info['n_classes'],
            "features": info['feature_columns'],
            "n_features": info['n_features']
        }
    except Exception as e:
        # Introspection failure is a server-side problem, hence 500.
        logger.exception("Model info error: %s", e)
        raise HTTPException(status_code=500, detail=str(e))


def main():
    """Run the FastAPI server."""
    import argparse
    parser = argparse.ArgumentParser(description="XGBoost Multi-Label Classification REST API")
    parser.add_argument('--host', type=str, default='0.0.0.0',
                        help='Host to bind to (default: 0.0.0.0)')
    parser.add_argument('--port', type=int, default=8000,
                        help='Port to bind to (default: 8000)')
    parser.add_argument('--reload', action='store_true',
                        help='Enable auto-reload for development')
    parser.add_argument('--workers', type=int, default=1,
                        help='Number of worker processes (default: 1)')
    args = parser.parse_args()

    logger.info("Starting FastAPI server on %s:%s", args.host, args.port)
    logger.info("Workers: %s, Reload: %s", args.workers, args.reload)

    uvicorn.run(
        "xgb_fastapi:app",
        host=args.host,
        port=args.port,
        reload=args.reload,
        # uvicorn forbids multiple workers with reload enabled.
        workers=args.workers if not args.reload else 1,
        log_level="info"
    )


if __name__ == "__main__":
    main()