Initial commit

test_xgb_fastapi.py (new file, 811 lines)
#!/usr/bin/env python3
"""
Tests for XGBoost FastAPI Multi-Label Classification API

These tests use realistic sample data extracted from /Volumes/T7/roofline_features.h5,
which was generated using feature_aggregator.py to process roofline dataframes.

Test data includes samples from different application types:
- TurTLE (turbulence simulation)
- SCALEXA (scaling benchmarks)
- Chroma (lattice QCD)
"""

import pytest
import json
from fastapi.testclient import TestClient
from unittest.mock import patch, MagicMock
import numpy as np
import os

# Import the FastAPI app and predictor
import xgb_fastapi
from xgb_fastapi import app

# ============================================================================
# Test Data: Realistic samples from roofline_features.h5
# Generated using feature_aggregator.py processing roofline dataframes
# ============================================================================

# Sample 1: TurTLE application - typical turbulence simulation workload
SAMPLE_TURTLE = {
    "bandwidth_raw_p10": 186.33,
    "bandwidth_raw_median": 205.14,
    "bandwidth_raw_p90": 210.83,
    "bandwidth_raw_mad": 3.57,
    "bandwidth_raw_range": 24.5,
    "bandwidth_raw_iqr": 12.075,
    "flops_raw_p10": 162.024,
    "flops_raw_median": 171.45,
    "flops_raw_p90": 176.48,
    "flops_raw_mad": 3.08,
    "flops_raw_range": 14.456,
    "flops_raw_iqr": 8.29,
    "arith_intensity_p10": 0.7906,
    "arith_intensity_median": 0.837,
    "arith_intensity_p90": 0.9109,
    "arith_intensity_mad": 0.02,
    "arith_intensity_range": 0.12,
    "arith_intensity_iqr": 0.0425,
    "bw_flops_covariance": 60.86,
    "bw_flops_correlation": 0.16,
    "avg_performance_gflops": 168.1,
    "median_performance_gflops": 171.45,
    "performance_gflops_mad": 3.08,
    "avg_memory_bw_gbs": 350.0,
    "scalar_peak_gflops": 432.0,
    "simd_peak_gflops": 9216.0,
    "node_num": 0,
    "duration": 19366,
}

# Sample 2: SCALEXA application - scaling benchmark workload
SAMPLE_SCALEXA = {
    "bandwidth_raw_p10": 13.474,
    "bandwidth_raw_median": 32.57,
    "bandwidth_raw_p90": 51.466,
    "bandwidth_raw_mad": 23.62,
    "bandwidth_raw_range": 37.992,
    "bandwidth_raw_iqr": 23.745,
    "flops_raw_p10": 4.24,
    "flops_raw_median": 16.16,
    "flops_raw_p90": 24.584,
    "flops_raw_mad": 10.53,
    "flops_raw_range": 20.344,
    "flops_raw_iqr": 12.715,
    "arith_intensity_p10": 0.211,
    "arith_intensity_median": 0.475,
    "arith_intensity_p90": 0.492,
    "arith_intensity_mad": 0.021,
    "arith_intensity_range": 0.281,
    "arith_intensity_iqr": 0.176,
    "bw_flops_covariance": 302.0,
    "bw_flops_correlation": 0.995,
    "avg_performance_gflops": 14.7,
    "median_performance_gflops": 16.16,
    "performance_gflops_mad": 10.53,
    "avg_memory_bw_gbs": 350.0,
    "scalar_peak_gflops": 432.0,
    "simd_peak_gflops": 9216.0,
    "node_num": 18,
    "duration": 165,
}

# Sample 3: Chroma application - lattice QCD workload (compute-intensive)
SAMPLE_CHROMA = {
    "bandwidth_raw_p10": 154.176,
    "bandwidth_raw_median": 200.57,
    "bandwidth_raw_p90": 259.952,
    "bandwidth_raw_mad": 5.12,
    "bandwidth_raw_range": 105.776,
    "bandwidth_raw_iqr": 10.215,
    "flops_raw_p10": 327.966,
    "flops_raw_median": 519.8,
    "flops_raw_p90": 654.422,
    "flops_raw_mad": 16.97,
    "flops_raw_range": 326.456,
    "flops_raw_iqr": 34.88,
    "arith_intensity_p10": 1.55,
    "arith_intensity_median": 2.595,
    "arith_intensity_p90": 3.445,
    "arith_intensity_mad": 0.254,
    "arith_intensity_range": 1.894,
    "arith_intensity_iqr": 0.512,
    "bw_flops_covariance": 382.76,
    "bw_flops_correlation": 0.063,
    "avg_performance_gflops": 503.26,
    "median_performance_gflops": 519.8,
    "performance_gflops_mad": 16.97,
    "avg_memory_bw_gbs": 350.0,
    "scalar_peak_gflops": 432.0,
    "simd_peak_gflops": 9216.0,
    "node_num": 3,
    "duration": 31133,
}

# Sample JSON roofline data (raw data before aggregation, as would be received by API)
SAMPLE_JSON_ROOFLINE = json.dumps([
    {
        "node_num": 1,
        "bandwidth_raw": 150.5,
        "flops_raw": 2500.0,
        "arith_intensity": 16.6,
        "performance_gflops": 1200.0,
        "memory_bw_gbs": 450,
        "scalar_peak_gflops": 600,
        "duration": 3600
    },
    {
        "node_num": 1,
        "bandwidth_raw": 155.2,
        "flops_raw": 2600.0,
        "arith_intensity": 16.8,
        "performance_gflops": 1250.0,
        "memory_bw_gbs": 450,
        "scalar_peak_gflops": 600,
        "duration": 3600
    },
    {
        "node_num": 1,
        "bandwidth_raw": 148.0,
        "flops_raw": 2450.0,
        "arith_intensity": 16.5,
        "performance_gflops": 1180.0,
        "memory_bw_gbs": 450,
        "scalar_peak_gflops": 600,
        "duration": 3600
    }
])
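
# Illustrative only (not used by the tests): a minimal sketch of the kind of
# per-metric aggregation that feature_aggregator.py presumably performs to turn raw
# roofline records (like SAMPLE_JSON_ROOFLINE) into feature dicts (like SAMPLE_TURTLE).
# The statistic names mirror the feature keys above (p10/median/p90/MAD/range/IQR);
# the real implementation may compute them differently.
def _aggregate_metric_sketch(records, metric):
    """Compute p10/median/p90/MAD/range/IQR for one metric across raw records."""
    values = np.array([rec[metric] for rec in records], dtype=float)
    p10, p25, median, p75, p90 = np.percentile(values, [10, 25, 50, 75, 90])
    return {
        f"{metric}_p10": float(p10),
        f"{metric}_median": float(median),
        f"{metric}_p90": float(p90),
        f"{metric}_mad": float(np.median(np.abs(values - median))),
        f"{metric}_range": float(values.max() - values.min()),
        f"{metric}_iqr": float(p75 - p25),
    }

# Example: _aggregate_metric_sketch(json.loads(SAMPLE_JSON_ROOFLINE), "bandwidth_raw")
# yields keys matching "bandwidth_raw_p10", "bandwidth_raw_median", and so on.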

# ============================================================================
# Fixtures
# ============================================================================

@pytest.fixture(scope="module")
def setup_predictor():
    """
    Set up the predictor for tests.
    Try to load the real model if available.
    """
    from xgb_inference_api import XGBoostMultiLabelPredictor

    model_path = os.path.join(os.path.dirname(__file__), 'xgb_model.joblib')
    if os.path.exists(model_path):
        try:
            predictor = XGBoostMultiLabelPredictor(model_path)
            xgb_fastapi.predictor = predictor
            return True
        except Exception as e:
            print(f"Failed to load model: {e}")
            return False
    return False


@pytest.fixture
def client(setup_predictor):
    """Create a test client for the FastAPI app."""
    return TestClient(app)


@pytest.fixture
def model_loaded(setup_predictor):
    """Check if the model is loaded."""
    return setup_predictor


def skip_if_no_model(model_loaded):
    """Helper to skip tests if model is not loaded."""
    if not model_loaded:
        pytest.skip("Model not loaded, skipping test")
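
# Note: skip_if_no_model is currently unused; the tests below inline the equivalent
# `if not model_loaded: pytest.skip(...)` guard. A test could call the helper instead:
#
#     def test_something(self, client, model_loaded):
#         skip_if_no_model(model_loaded)
#         ...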

# ============================================================================
# Health and Root Endpoint Tests
# ============================================================================

class TestHealthEndpoints:
    """Tests for health check and root endpoints."""

    def test_root_endpoint(self, client):
        """Test the root endpoint returns API information."""
        response = client.get("/")
        assert response.status_code == 200

        data = response.json()
        assert "name" in data
        assert data["name"] == "XGBoost Multi-Label Classification API"
        assert "version" in data
        assert "endpoints" in data
        assert all(key in data["endpoints"] for key in ["health", "predict", "predict_top_k", "batch_predict"])

    def test_health_check(self, client, model_loaded):
        """Test the health check endpoint."""
        response = client.get("/health")
        assert response.status_code == 200

        data = response.json()
        assert "status" in data
        assert "model_loaded" in data

        # If model is loaded, check for additional info
        if model_loaded:
            assert data["model_loaded"] is True
            assert data["status"] in ["healthy", "degraded"]
            if data["status"] == "healthy":
                assert "n_classes" in data
                assert "n_features" in data
                assert "classes" in data
                assert data["n_classes"] > 0
                assert data["n_features"] > 0

    def test_health_check_model_not_loaded(self, client):
        """Test that the health check reports status and model_loaded regardless of model state."""
        response = client.get("/health")
        assert response.status_code == 200

        data = response.json()
        # Should have status and model_loaded fields regardless
        assert "status" in data
        assert "model_loaded" in data

# ============================================================================
# Single Prediction Tests
# ============================================================================

class TestPredictEndpoint:
    """Tests for the /predict endpoint."""

    def test_predict_with_feature_dict(self, client, model_loaded):
        """Test prediction with a feature dictionary (TurTLE sample)."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        request_data = {
            "features": SAMPLE_TURTLE,
            "threshold": 0.5,
            "return_all_probabilities": True
        }

        response = client.post("/predict", json=request_data)
        assert response.status_code == 200

        data = response.json()
        assert "predictions" in data
        assert "probabilities" in data
        assert "confidences" in data
        assert "threshold" in data

        assert isinstance(data["predictions"], list)
        assert isinstance(data["probabilities"], dict)
        assert data["threshold"] == 0.5

        # All probabilities should be between 0 and 1
        for prob in data["probabilities"].values():
            assert 0.0 <= prob <= 1.0

    def test_predict_with_different_thresholds(self, client, model_loaded):
        """Test that different thresholds affect predictions."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        request_low = {
            "features": SAMPLE_TURTLE,
            "threshold": 0.1,
            "return_all_probabilities": True
        }
        request_high = {
            "features": SAMPLE_TURTLE,
            "threshold": 0.9,
            "return_all_probabilities": True
        }

        response_low = client.post("/predict", json=request_low)
        response_high = client.post("/predict", json=request_high)

        assert response_low.status_code == 200
        assert response_high.status_code == 200

        data_low = response_low.json()
        data_high = response_high.json()

        # Lower threshold should generally produce more predictions
        assert len(data_low["predictions"]) >= len(data_high["predictions"])

    def test_predict_different_workloads(self, client, model_loaded):
        """Test predictions on different application workloads."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        samples = [
            ("TurTLE", SAMPLE_TURTLE),
            ("SCALEXA", SAMPLE_SCALEXA),
            ("Chroma", SAMPLE_CHROMA),
        ]

        for name, sample in samples:
            request_data = {
                "features": sample,
                "threshold": 0.3,
                "return_all_probabilities": True
            }

            response = client.post("/predict", json=request_data)
            assert response.status_code == 200, f"Failed for {name}"

            data = response.json()
            assert len(data["probabilities"]) > 0, f"No probabilities for {name}"

    def test_predict_return_only_predicted_probabilities(self, client, model_loaded):
        """Test prediction with return_all_probabilities=False."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        request_data = {
            "features": SAMPLE_TURTLE,
            "threshold": 0.3,
            "return_all_probabilities": False
        }

        response = client.post("/predict", json=request_data)
        assert response.status_code == 200

        data = response.json()
        # When return_all_probabilities is False, probabilities should only
        # contain classes that are in predictions
        if len(data["predictions"]) > 0:
            assert set(data["probabilities"].keys()) == set(data["predictions"])

    def test_predict_with_json_roofline_data(self, client, model_loaded):
        """Test prediction with raw JSON roofline data (requires aggregation)."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        request_data = {
            "features": SAMPLE_JSON_ROOFLINE,
            "is_json": True,
            "job_id": "test_job_123",
            "threshold": 0.3
        }

        response = client.post("/predict", json=request_data)
        assert response.status_code == 200

        data = response.json()
        assert "predictions" in data
        assert "probabilities" in data

    def test_predict_threshold_boundaries(self, client, model_loaded):
        """Test prediction with threshold at boundaries."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        for threshold in [0.0, 0.5, 1.0]:
            request_data = {
                "features": SAMPLE_TURTLE,
                "threshold": threshold
            }
            response = client.post("/predict", json=request_data)
            assert response.status_code == 200

    def test_predict_invalid_threshold(self, client):
        """Test that invalid threshold values are rejected."""
        for threshold in [-0.1, 1.5]:
            request_data = {
                "features": SAMPLE_TURTLE,
                "threshold": threshold
            }
            response = client.post("/predict", json=request_data)
            assert response.status_code == 422  # Validation error

    def test_predict_model_not_loaded(self, client):
        """Test that prediction returns 503 when model not loaded."""
        # Temporarily set predictor to None
        original_predictor = xgb_fastapi.predictor
        xgb_fastapi.predictor = None

        try:
            request_data = {
                "features": SAMPLE_TURTLE,
                "threshold": 0.5
            }
            response = client.post("/predict", json=request_data)
            assert response.status_code == 503
            assert "Model not loaded" in response.json().get("detail", "")
        finally:
            xgb_fastapi.predictor = original_predictor

# ============================================================================
# Top-K Prediction Tests
# ============================================================================

class TestPredictTopKEndpoint:
    """Tests for the /predict_top_k endpoint."""

    def test_predict_top_k_default(self, client, model_loaded):
        """Test top-K prediction with default k=5."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        request_data = {
            "features": SAMPLE_TURTLE
        }

        response = client.post("/predict_top_k", json=request_data)
        assert response.status_code == 200

        data = response.json()
        assert "top_predictions" in data
        assert "top_probabilities" in data
        assert "all_probabilities" in data

        assert len(data["top_predictions"]) <= 5
        assert len(data["top_probabilities"]) <= 5

    def test_predict_top_k_custom_k(self, client, model_loaded):
        """Test top-K prediction with custom k values."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        for k in [1, 3, 10]:
            request_data = {
                "features": SAMPLE_TURTLE,
                "k": k
            }

            response = client.post("/predict_top_k", json=request_data)
            assert response.status_code == 200

            data = response.json()
            assert len(data["top_predictions"]) <= k

    def test_predict_top_k_ordering(self, client, model_loaded):
        """Test that top-K predictions are ordered by probability."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        request_data = {
            "features": SAMPLE_CHROMA,
            "k": 10
        }

        response = client.post("/predict_top_k", json=request_data)
        assert response.status_code == 200

        data = response.json()
        probabilities = [data["top_probabilities"][cls] for cls in data["top_predictions"]]

        # Check that probabilities are in descending order
        for i in range(len(probabilities) - 1):
            assert probabilities[i] >= probabilities[i + 1]

    def test_predict_top_k_with_json_data(self, client, model_loaded):
        """Test top-K prediction with JSON roofline data."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        request_data = {
            "features": SAMPLE_JSON_ROOFLINE,
            "is_json": True,
            "job_id": "test_job_456",
            "k": 5
        }

        response = client.post("/predict_top_k", json=request_data)
        assert response.status_code == 200

        data = response.json()
        assert len(data["top_predictions"]) <= 5

    def test_predict_top_k_invalid_k(self, client):
        """Test that invalid k values are rejected."""
        for k in [0, -1, 101]:
            request_data = {
                "features": SAMPLE_TURTLE,
                "k": k
            }
            response = client.post("/predict_top_k", json=request_data)
            assert response.status_code == 422  # Validation error

# ============================================================================
# Batch Prediction Tests
# ============================================================================

class TestBatchPredictEndpoint:
    """Tests for the /batch_predict endpoint."""

    def test_batch_predict_multiple_samples(self, client, model_loaded):
        """Test batch prediction with multiple samples."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        request_data = {
            "features_list": [SAMPLE_TURTLE, SAMPLE_SCALEXA, SAMPLE_CHROMA],
            "threshold": 0.3
        }

        response = client.post("/batch_predict", json=request_data)
        assert response.status_code == 200

        data = response.json()
        assert "results" in data
        assert "total" in data
        assert "successful" in data
        assert "failed" in data

        assert data["total"] == 3
        assert len(data["results"]) == 3
        assert data["successful"] + data["failed"] == data["total"]

    def test_batch_predict_single_sample(self, client, model_loaded):
        """Test batch prediction with a single sample."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        request_data = {
            "features_list": [SAMPLE_TURTLE],
            "threshold": 0.5
        }

        response = client.post("/batch_predict", json=request_data)
        assert response.status_code == 200

        data = response.json()
        assert data["total"] == 1
        assert len(data["results"]) == 1

    def test_batch_predict_with_json_data(self, client, model_loaded):
        """Test batch prediction with JSON roofline data."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        request_data = {
            "features_list": [SAMPLE_JSON_ROOFLINE, SAMPLE_JSON_ROOFLINE],
            "is_json": True,
            "job_ids": ["job_001", "job_002"],
            "threshold": 0.3
        }

        response = client.post("/batch_predict", json=request_data)
        assert response.status_code == 200

        data = response.json()
        assert data["total"] == 2

    def test_batch_predict_empty_list(self, client, model_loaded):
        """Test batch prediction with empty list."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        request_data = {
            "features_list": [],
            "threshold": 0.5
        }

        response = client.post("/batch_predict", json=request_data)
        assert response.status_code == 200

        data = response.json()
        assert data["total"] == 0
        assert data["successful"] == 0
        assert data["failed"] == 0

# ============================================================================
# Model Info Tests
# ============================================================================

class TestModelInfoEndpoint:
    """Tests for the /model/info endpoint."""

    def test_model_info(self, client, model_loaded):
        """Test getting model information."""
        response = client.get("/model/info")

        # May return 503 if model not loaded, or 200 if loaded
        if model_loaded:
            assert response.status_code == 200
            data = response.json()
            assert "classes" in data
            assert "n_classes" in data
            assert "features" in data
            assert "n_features" in data

            assert isinstance(data["classes"], list)
            assert len(data["classes"]) == data["n_classes"]
            assert len(data["features"]) == data["n_features"]
        else:
            assert response.status_code == 503

    def test_model_info_not_loaded(self, client):
        """Test that model info returns 503 when model not loaded."""
        original_predictor = xgb_fastapi.predictor
        xgb_fastapi.predictor = None

        try:
            response = client.get("/model/info")
            assert response.status_code == 503
        finally:
            xgb_fastapi.predictor = original_predictor

# ============================================================================
# Error Handling Tests
# ============================================================================

class TestErrorHandling:
    """Tests for error handling."""

    def test_predict_missing_features(self, client):
        """Test prediction without features field."""
        request_data = {
            "threshold": 0.5
        }
        response = client.post("/predict", json=request_data)
        assert response.status_code == 422

    def test_predict_invalid_json_format(self, client, model_loaded):
        """Test prediction with invalid JSON in is_json mode."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        request_data = {
            "features": "not valid json {{",
            "is_json": True
        }
        response = client.post("/predict", json=request_data)
        # Should return error (400)
        assert response.status_code == 400

    def test_invalid_endpoint(self, client):
        """Test accessing an invalid endpoint."""
        response = client.get("/nonexistent")
        assert response.status_code == 404

# ============================================================================
# Integration Tests (Full Pipeline)
# ============================================================================

class TestIntegration:
    """Integration tests for full prediction pipeline."""

    def test_full_prediction_pipeline_features(self, client, model_loaded):
        """Test complete prediction pipeline with feature dict."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        # 1. Check health
        health_response = client.get("/health")
        assert health_response.status_code == 200

        health_data = health_response.json()
        assert health_data["model_loaded"] is True

        # 2. Get model info
        info_response = client.get("/model/info")
        assert info_response.status_code == 200

        # 3. Make single prediction
        predict_response = client.post("/predict", json={
            "features": SAMPLE_TURTLE,
            "threshold": 0.3
        })
        assert predict_response.status_code == 200

        # 4. Make top-K prediction
        topk_response = client.post("/predict_top_k", json={
            "features": SAMPLE_TURTLE,
            "k": 5
        })
        assert topk_response.status_code == 200

        # 5. Make batch prediction
        batch_response = client.post("/batch_predict", json={
            "features_list": [SAMPLE_TURTLE, SAMPLE_CHROMA],
            "threshold": 0.3
        })
        assert batch_response.status_code == 200

    def test_consistency_single_vs_batch(self, client, model_loaded):
        """Test that single prediction and batch prediction give consistent results."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        threshold = 0.3

        # Single prediction
        single_response = client.post("/predict", json={
            "features": SAMPLE_TURTLE,
            "threshold": threshold
        })

        # Batch prediction with same sample
        batch_response = client.post("/batch_predict", json={
            "features_list": [SAMPLE_TURTLE],
            "threshold": threshold
        })

        assert single_response.status_code == 200
        assert batch_response.status_code == 200

        single_data = single_response.json()
        batch_data = batch_response.json()

        if batch_data["successful"] == 1:
            batch_result = batch_data["results"][0]
            # Predictions should be the same
            assert set(single_data["predictions"]) == set(batch_result["predictions"])

# ============================================================================
# CORS Tests
# ============================================================================

class TestCORS:
    """Tests for CORS configuration."""

    def test_cors_headers(self, client):
        """Probe /predict with an OPTIONS request; exact behaviour depends on CORS configuration."""
        response = client.options("/predict")
        # Accept either 200 or 405 (method not allowed for OPTIONS in some configs)
        assert response.status_code in [200, 405]
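
    # A slightly stronger, still-hedged check. Assumption (not verified by this suite):
    # xgb_fastapi installs Starlette's CORSMiddleware, which reflects an allowed Origin
    # back in the access-control-allow-origin header on simple requests. The assertion
    # is guarded so the test stays green if no CORS middleware is configured.
    def test_cors_allow_origin_on_simple_request(self, client):
        """Check Origin handling on a simple GET request, if CORS is enabled."""
        response = client.get("/", headers={"Origin": "http://example.com"})
        assert response.status_code == 200
        allow_origin = response.headers.get("access-control-allow-origin")
        if allow_origin is not None:
            assert allow_origin in ("*", "http://example.com")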

# ============================================================================
# Performance Tests (Basic)
# ============================================================================

class TestPerformance:
    """Basic performance tests."""

    def test_response_time_single_prediction(self, client, model_loaded):
        """Test that single prediction completes in reasonable time."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        import time

        start = time.time()
        response = client.post("/predict", json={
            "features": SAMPLE_TURTLE,
            "threshold": 0.5
        })
        elapsed = time.time() - start

        # Should complete within 5 seconds (generous for CI environments)
        assert elapsed < 5.0, f"Prediction took {elapsed:.2f}s, expected < 5s"
        assert response.status_code == 200

    def test_response_time_batch_prediction(self, client, model_loaded):
        """Test that batch prediction scales reasonably."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        import time

        # Create a batch of 10 samples
        features_list = [SAMPLE_TURTLE] * 10

        start = time.time()
        response = client.post("/batch_predict", json={
            "features_list": features_list,
            "threshold": 0.5
        })
        elapsed = time.time() - start

        # Should complete within 10 seconds
        assert elapsed < 10.0, f"Batch prediction took {elapsed:.2f}s, expected < 10s"
        assert response.status_code == 200

# ============================================================================
# Run tests if executed directly
# ============================================================================

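# Notes (informational): `pytest test_xgb_fastapi.py -v` is equivalent to the __main__
# block below. Tests that need the trained model are skipped unless xgb_model.joblib
# sits next to this file (see setup_predictor above).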
if __name__ == "__main__":
    pytest.main([__file__, "-v", "--tb=short"])