#!/usr/bin/env python3
"""
Tests for XGBoost FastAPI Multi-Label Classification API

These tests use realistic sample data extracted from /Volumes/T7/roofline_features.h5,
which was generated using feature_aggregator.py to process roofline dataframes.

Test data includes samples from different application types:
- TurTLE (turbulence simulation)
- SCALEXA (scaling benchmarks)
- Chroma (lattice QCD)
"""

import pytest
import json
from fastapi.testclient import TestClient
from unittest.mock import patch, MagicMock
import numpy as np
import os

# Import the FastAPI app and predictor
import xgb_fastapi
from xgb_fastapi import app


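# The tests below exercise the app in-process via FastAPI's TestClient, so no
# running server is required. For manual poking, the same app object can be
# served with uvicorn (command assumed; adjust the module name if the layout
# differs): `uvicorn xgb_fastapi:app --port 8000`.

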
# ============================================================================
# Test Data: Realistic samples from roofline_features.h5
# Generated using feature_aggregator.py processing roofline dataframes
# ============================================================================

# Sample 1: TurTLE application - typical turbulence simulation workload
SAMPLE_TURTLE = {
    "bandwidth_raw_p10": 186.33,
    "bandwidth_raw_median": 205.14,
    "bandwidth_raw_p90": 210.83,
    "bandwidth_raw_mad": 3.57,
    "bandwidth_raw_range": 24.5,
    "bandwidth_raw_iqr": 12.075,
    "flops_raw_p10": 162.024,
    "flops_raw_median": 171.45,
    "flops_raw_p90": 176.48,
    "flops_raw_mad": 3.08,
    "flops_raw_range": 14.456,
    "flops_raw_iqr": 8.29,
    "arith_intensity_p10": 0.7906,
    "arith_intensity_median": 0.837,
    "arith_intensity_p90": 0.9109,
    "arith_intensity_mad": 0.02,
    "arith_intensity_range": 0.12,
    "arith_intensity_iqr": 0.0425,
    "bw_flops_covariance": 60.86,
    "bw_flops_correlation": 0.16,
    "avg_performance_gflops": 168.1,
    "median_performance_gflops": 171.45,
    "performance_gflops_mad": 3.08,
    "avg_memory_bw_gbs": 350.0,
    "scalar_peak_gflops": 432.0,
    "simd_peak_gflops": 9216.0,
    "node_num": 0,
    "duration": 19366,
}

# Sample 2: SCALEXA application - scaling benchmark workload
SAMPLE_SCALEXA = {
    "bandwidth_raw_p10": 13.474,
    "bandwidth_raw_median": 32.57,
    "bandwidth_raw_p90": 51.466,
    "bandwidth_raw_mad": 23.62,
    "bandwidth_raw_range": 37.992,
    "bandwidth_raw_iqr": 23.745,
    "flops_raw_p10": 4.24,
    "flops_raw_median": 16.16,
    "flops_raw_p90": 24.584,
    "flops_raw_mad": 10.53,
    "flops_raw_range": 20.344,
    "flops_raw_iqr": 12.715,
    "arith_intensity_p10": 0.211,
    "arith_intensity_median": 0.475,
    "arith_intensity_p90": 0.492,
    "arith_intensity_mad": 0.021,
    "arith_intensity_range": 0.281,
    "arith_intensity_iqr": 0.176,
    "bw_flops_covariance": 302.0,
    "bw_flops_correlation": 0.995,
    "avg_performance_gflops": 14.7,
    "median_performance_gflops": 16.16,
    "performance_gflops_mad": 10.53,
    "avg_memory_bw_gbs": 350.0,
    "scalar_peak_gflops": 432.0,
    "simd_peak_gflops": 9216.0,
    "node_num": 18,
    "duration": 165,
}

# Sample 3: Chroma application - lattice QCD workload (compute-intensive)
SAMPLE_CHROMA = {
    "bandwidth_raw_p10": 154.176,
    "bandwidth_raw_median": 200.57,
    "bandwidth_raw_p90": 259.952,
    "bandwidth_raw_mad": 5.12,
    "bandwidth_raw_range": 105.776,
    "bandwidth_raw_iqr": 10.215,
    "flops_raw_p10": 327.966,
    "flops_raw_median": 519.8,
    "flops_raw_p90": 654.422,
    "flops_raw_mad": 16.97,
    "flops_raw_range": 326.456,
    "flops_raw_iqr": 34.88,
    "arith_intensity_p10": 1.55,
    "arith_intensity_median": 2.595,
    "arith_intensity_p90": 3.445,
    "arith_intensity_mad": 0.254,
    "arith_intensity_range": 1.894,
    "arith_intensity_iqr": 0.512,
    "bw_flops_covariance": 382.76,
    "bw_flops_correlation": 0.063,
    "avg_performance_gflops": 503.26,
    "median_performance_gflops": 519.8,
    "performance_gflops_mad": 16.97,
    "avg_memory_bw_gbs": 350.0,
    "scalar_peak_gflops": 432.0,
    "simd_peak_gflops": 9216.0,
    "node_num": 3,
    "duration": 31133,
}

# Sample JSON roofline data (raw data before aggregation, as would be received by API)
SAMPLE_JSON_ROOFLINE = json.dumps([
    {
        "node_num": 1,
        "bandwidth_raw": 150.5,
        "flops_raw": 2500.0,
        "arith_intensity": 16.6,
        "performance_gflops": 1200.0,
        "memory_bw_gbs": 450,
        "scalar_peak_gflops": 600,
        "duration": 3600
    },
    {
        "node_num": 1,
        "bandwidth_raw": 155.2,
        "flops_raw": 2600.0,
        "arith_intensity": 16.8,
        "performance_gflops": 1250.0,
        "memory_bw_gbs": 450,
        "scalar_peak_gflops": 600,
        "duration": 3600
    },
    {
        "node_num": 1,
        "bandwidth_raw": 148.0,
        "flops_raw": 2450.0,
        "arith_intensity": 16.5,
        "performance_gflops": 1180.0,
        "memory_bw_gbs": 450,
        "scalar_peak_gflops": 600,
        "duration": 3600
    }
])

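# Illustrative sketch (not used by the tests): how raw roofline records like the
# ones above could be reduced to the aggregated *_p10/median/p90/mad/range/iqr
# features the model expects (cf. SAMPLE_TURTLE). The real aggregation lives in
# feature_aggregator.py; the exact statistics below are an assumption.
def _example_aggregate_bandwidth(records):
    """Summarise the bandwidth_raw column of a list of raw roofline records."""
    values = np.array([r["bandwidth_raw"] for r in records], dtype=float)
    p10, p25, p50, p75, p90 = np.percentile(values, [10, 25, 50, 75, 90])
    return {
        "bandwidth_raw_p10": float(p10),
        "bandwidth_raw_median": float(p50),
        "bandwidth_raw_p90": float(p90),
        "bandwidth_raw_mad": float(np.median(np.abs(values - p50))),
        "bandwidth_raw_range": float(values.max() - values.min()),
        "bandwidth_raw_iqr": float(p75 - p25),
    }

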
# ============================================================================
# Fixtures
# ============================================================================

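# setup_predictor loads the trained model (when xgb_model.joblib sits next to
# this file) and injects it as the module-level xgb_fastapi.predictor that the
# endpoints read; scope="module" loads the model once for the whole module.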
@pytest.fixture(scope="module")
def setup_predictor():
    """
    Set up the predictor for tests.
    Try to load the real model if available.
    """
    from xgb_inference_api import XGBoostMultiLabelPredictor

    model_path = os.path.join(os.path.dirname(__file__), 'xgb_model.joblib')
    if os.path.exists(model_path):
        try:
            predictor = XGBoostMultiLabelPredictor(model_path)
            xgb_fastapi.predictor = predictor
            return True
        except Exception as e:
            print(f"Failed to load model: {e}")
            return False
    return False


@pytest.fixture
def client(setup_predictor):
    """Create a test client for the FastAPI app."""
    return TestClient(app)


@pytest.fixture
def model_loaded(setup_predictor):
    """Check if the model is loaded."""
    return setup_predictor


def skip_if_no_model(model_loaded):
    """Helper to skip tests if model is not loaded."""
    if not model_loaded:
        pytest.skip("Model not loaded, skipping test")


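# NOTE: skip_if_no_model above is an optional helper; the test methods below
# currently use the equivalent inline ``pytest.skip`` guard instead.

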
# ============================================================================
# Health and Root Endpoint Tests
# ============================================================================

class TestHealthEndpoints:
    """Tests for health check and root endpoints."""

    def test_root_endpoint(self, client):
        """Test the root endpoint returns API information."""
        response = client.get("/")
        assert response.status_code == 200

        data = response.json()
        assert "name" in data
        assert data["name"] == "XGBoost Multi-Label Classification API"
        assert "version" in data
        assert "endpoints" in data
        assert all(key in data["endpoints"] for key in ["health", "predict", "predict_top_k", "batch_predict"])

    def test_health_check(self, client, model_loaded):
        """Test the health check endpoint."""
        response = client.get("/health")
        assert response.status_code == 200

        data = response.json()
        assert "status" in data
        assert "model_loaded" in data

        # If model is loaded, check for additional info
        if model_loaded:
            assert data["model_loaded"] is True
            assert data["status"] in ["healthy", "degraded"]
            if data["status"] == "healthy":
                assert "n_classes" in data
                assert "n_features" in data
                assert "classes" in data
                assert data["n_classes"] > 0
                assert data["n_features"] > 0

    def test_health_check_model_not_loaded(self, client):
        """Test health check still responds when the model is not loaded."""
        # Temporarily remove the predictor so /health reflects a missing model
        original_predictor = xgb_fastapi.predictor
        xgb_fastapi.predictor = None

        try:
            response = client.get("/health")
            assert response.status_code == 200

            data = response.json()
            # Should have status and model_loaded fields regardless
            assert "status" in data
            assert "model_loaded" in data
            assert data["model_loaded"] is False
        finally:
            xgb_fastapi.predictor = original_predictor


# ============================================================================
# Single Prediction Tests
# ============================================================================

class TestPredictEndpoint:
    """Tests for the /predict endpoint."""

    def test_predict_with_feature_dict(self, client, model_loaded):
        """Test prediction with a feature dictionary (TurTLE sample)."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        request_data = {
            "features": SAMPLE_TURTLE,
            "threshold": 0.5,
            "return_all_probabilities": True
        }

        response = client.post("/predict", json=request_data)
        assert response.status_code == 200

        data = response.json()
        assert "predictions" in data
        assert "probabilities" in data
        assert "confidences" in data
        assert "threshold" in data

        assert isinstance(data["predictions"], list)
        assert isinstance(data["probabilities"], dict)
        assert data["threshold"] == 0.5

        # All probabilities should be between 0 and 1
        for prob in data["probabilities"].values():
            assert 0.0 <= prob <= 1.0

    def test_predict_with_different_thresholds(self, client, model_loaded):
        """Test that different thresholds affect predictions."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        request_low = {
            "features": SAMPLE_TURTLE,
            "threshold": 0.1,
            "return_all_probabilities": True
        }
        request_high = {
            "features": SAMPLE_TURTLE,
            "threshold": 0.9,
            "return_all_probabilities": True
        }

        response_low = client.post("/predict", json=request_low)
        response_high = client.post("/predict", json=request_high)

        assert response_low.status_code == 200
        assert response_high.status_code == 200

        data_low = response_low.json()
        data_high = response_high.json()

        # A lower threshold should produce at least as many predictions
        assert len(data_low["predictions"]) >= len(data_high["predictions"])

    def test_predict_different_workloads(self, client, model_loaded):
        """Test predictions on different application workloads."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        samples = [
            ("TurTLE", SAMPLE_TURTLE),
            ("SCALEXA", SAMPLE_SCALEXA),
            ("Chroma", SAMPLE_CHROMA),
        ]

        for name, sample in samples:
            request_data = {
                "features": sample,
                "threshold": 0.3,
                "return_all_probabilities": True
            }

            response = client.post("/predict", json=request_data)
            assert response.status_code == 200, f"Failed for {name}"

            data = response.json()
            assert len(data["probabilities"]) > 0, f"No probabilities for {name}"

    def test_predict_return_only_predicted_probabilities(self, client, model_loaded):
        """Test prediction with return_all_probabilities=False."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        request_data = {
            "features": SAMPLE_TURTLE,
            "threshold": 0.3,
            "return_all_probabilities": False
        }

        response = client.post("/predict", json=request_data)
        assert response.status_code == 200

        data = response.json()
        # When return_all_probabilities is False, probabilities should only
        # contain classes that are in predictions
        if len(data["predictions"]) > 0:
            assert set(data["probabilities"].keys()) == set(data["predictions"])

    def test_predict_with_json_roofline_data(self, client, model_loaded):
        """Test prediction with raw JSON roofline data (requires aggregation)."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        request_data = {
            "features": SAMPLE_JSON_ROOFLINE,
            "is_json": True,
            "job_id": "test_job_123",
            "threshold": 0.3
        }

        response = client.post("/predict", json=request_data)
        assert response.status_code == 200

        data = response.json()
        assert "predictions" in data
        assert "probabilities" in data

    def test_predict_threshold_boundaries(self, client, model_loaded):
        """Test prediction with threshold at boundaries."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        for threshold in [0.0, 0.5, 1.0]:
            request_data = {
                "features": SAMPLE_TURTLE,
                "threshold": threshold
            }
            response = client.post("/predict", json=request_data)
            assert response.status_code == 200

    def test_predict_invalid_threshold(self, client):
        """Test that invalid threshold values are rejected."""
        for threshold in [-0.1, 1.5]:
            request_data = {
                "features": SAMPLE_TURTLE,
                "threshold": threshold
            }
            response = client.post("/predict", json=request_data)
            assert response.status_code == 422  # Validation error

    def test_predict_model_not_loaded(self, client):
        """Test that prediction returns 503 when model not loaded."""
        # Temporarily set predictor to None
        original_predictor = xgb_fastapi.predictor
        xgb_fastapi.predictor = None

        try:
            request_data = {
                "features": SAMPLE_TURTLE,
                "threshold": 0.5
            }
            response = client.post("/predict", json=request_data)
            assert response.status_code == 503
            assert "Model not loaded" in response.json().get("detail", "")
        finally:
            xgb_fastapi.predictor = original_predictor


# ============================================================================
# Top-K Prediction Tests
# ============================================================================

class TestPredictTopKEndpoint:
    """Tests for the /predict_top_k endpoint."""

    def test_predict_top_k_default(self, client, model_loaded):
        """Test top-K prediction with default k=5."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        request_data = {
            "features": SAMPLE_TURTLE
        }

        response = client.post("/predict_top_k", json=request_data)
        assert response.status_code == 200

        data = response.json()
        assert "top_predictions" in data
        assert "top_probabilities" in data
        assert "all_probabilities" in data

        assert len(data["top_predictions"]) <= 5
        assert len(data["top_probabilities"]) <= 5

    def test_predict_top_k_custom_k(self, client, model_loaded):
        """Test top-K prediction with custom k values."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        for k in [1, 3, 10]:
            request_data = {
                "features": SAMPLE_TURTLE,
                "k": k
            }

            response = client.post("/predict_top_k", json=request_data)
            assert response.status_code == 200

            data = response.json()
            assert len(data["top_predictions"]) <= k

    def test_predict_top_k_ordering(self, client, model_loaded):
        """Test that top-K predictions are ordered by probability."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        request_data = {
            "features": SAMPLE_CHROMA,
            "k": 10
        }

        response = client.post("/predict_top_k", json=request_data)
        assert response.status_code == 200

        data = response.json()
        probabilities = [data["top_probabilities"][cls] for cls in data["top_predictions"]]

        # Check that probabilities are in descending order
        for i in range(len(probabilities) - 1):
            assert probabilities[i] >= probabilities[i + 1]

    def test_predict_top_k_with_json_data(self, client, model_loaded):
        """Test top-K prediction with JSON roofline data."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        request_data = {
            "features": SAMPLE_JSON_ROOFLINE,
            "is_json": True,
            "job_id": "test_job_456",
            "k": 5
        }

        response = client.post("/predict_top_k", json=request_data)
        assert response.status_code == 200

        data = response.json()
        assert len(data["top_predictions"]) <= 5

    def test_predict_top_k_invalid_k(self, client):
        """Test that invalid k values are rejected."""
        for k in [0, -1, 101]:
            request_data = {
                "features": SAMPLE_TURTLE,
                "k": k
            }
            response = client.post("/predict_top_k", json=request_data)
            assert response.status_code == 422  # Validation error


# ============================================================================
# Batch Prediction Tests
# ============================================================================

class TestBatchPredictEndpoint:
    """Tests for the /batch_predict endpoint."""

    def test_batch_predict_multiple_samples(self, client, model_loaded):
        """Test batch prediction with multiple samples."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        request_data = {
            "features_list": [SAMPLE_TURTLE, SAMPLE_SCALEXA, SAMPLE_CHROMA],
            "threshold": 0.3
        }

        response = client.post("/batch_predict", json=request_data)
        assert response.status_code == 200

        data = response.json()
        assert "results" in data
        assert "total" in data
        assert "successful" in data
        assert "failed" in data

        assert data["total"] == 3
        assert len(data["results"]) == 3
        assert data["successful"] + data["failed"] == data["total"]

    def test_batch_predict_single_sample(self, client, model_loaded):
        """Test batch prediction with a single sample."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        request_data = {
            "features_list": [SAMPLE_TURTLE],
            "threshold": 0.5
        }

        response = client.post("/batch_predict", json=request_data)
        assert response.status_code == 200

        data = response.json()
        assert data["total"] == 1
        assert len(data["results"]) == 1

    def test_batch_predict_with_json_data(self, client, model_loaded):
        """Test batch prediction with JSON roofline data."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        request_data = {
            "features_list": [SAMPLE_JSON_ROOFLINE, SAMPLE_JSON_ROOFLINE],
            "is_json": True,
            "job_ids": ["job_001", "job_002"],
            "threshold": 0.3
        }

        response = client.post("/batch_predict", json=request_data)
        assert response.status_code == 200

        data = response.json()
        assert data["total"] == 2

    def test_batch_predict_empty_list(self, client, model_loaded):
        """Test batch prediction with an empty list."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        request_data = {
            "features_list": [],
            "threshold": 0.5
        }

        response = client.post("/batch_predict", json=request_data)
        assert response.status_code == 200

        data = response.json()
        assert data["total"] == 0
        assert data["successful"] == 0
        assert data["failed"] == 0


# ============================================================================
# Model Info Tests
# ============================================================================

class TestModelInfoEndpoint:
    """Tests for the /model/info endpoint."""

    def test_model_info(self, client, model_loaded):
        """Test getting model information."""
        response = client.get("/model/info")

        # May return 503 if model not loaded, or 200 if loaded
        if model_loaded:
            assert response.status_code == 200
            data = response.json()
            assert "classes" in data
            assert "n_classes" in data
            assert "features" in data
            assert "n_features" in data

            assert isinstance(data["classes"], list)
            assert len(data["classes"]) == data["n_classes"]
            assert len(data["features"]) == data["n_features"]
        else:
            assert response.status_code == 503

    def test_model_info_not_loaded(self, client):
        """Test that model info returns 503 when model not loaded."""
        original_predictor = xgb_fastapi.predictor
        xgb_fastapi.predictor = None

        try:
            response = client.get("/model/info")
            assert response.status_code == 503
        finally:
            xgb_fastapi.predictor = original_predictor


# ============================================================================
# Error Handling Tests
# ============================================================================

class TestErrorHandling:
    """Tests for error handling."""

    def test_predict_missing_features(self, client):
        """Test prediction without the features field."""
        request_data = {
            "threshold": 0.5
        }
        response = client.post("/predict", json=request_data)
        assert response.status_code == 422

    def test_predict_invalid_json_format(self, client, model_loaded):
        """Test prediction with invalid JSON in is_json mode."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        request_data = {
            "features": "not valid json {{",
            "is_json": True
        }
        response = client.post("/predict", json=request_data)
        # Should return a 400 error for malformed JSON input
        assert response.status_code == 400

    def test_invalid_endpoint(self, client):
        """Test accessing an invalid endpoint."""
        response = client.get("/nonexistent")
        assert response.status_code == 404


# ============================================================================
# Integration Tests (Full Pipeline)
# ============================================================================

class TestIntegration:
    """Integration tests for the full prediction pipeline."""

    def test_full_prediction_pipeline_features(self, client, model_loaded):
        """Test the complete prediction pipeline with a feature dict."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        # 1. Check health
        health_response = client.get("/health")
        assert health_response.status_code == 200

        health_data = health_response.json()
        assert health_data["model_loaded"] is True

        # 2. Get model info
        info_response = client.get("/model/info")
        assert info_response.status_code == 200

        # 3. Make single prediction
        predict_response = client.post("/predict", json={
            "features": SAMPLE_TURTLE,
            "threshold": 0.3
        })
        assert predict_response.status_code == 200

        # 4. Make top-K prediction
        topk_response = client.post("/predict_top_k", json={
            "features": SAMPLE_TURTLE,
            "k": 5
        })
        assert topk_response.status_code == 200

        # 5. Make batch prediction
        batch_response = client.post("/batch_predict", json={
            "features_list": [SAMPLE_TURTLE, SAMPLE_CHROMA],
            "threshold": 0.3
        })
        assert batch_response.status_code == 200

    def test_consistency_single_vs_batch(self, client, model_loaded):
        """Test that single prediction and batch prediction give consistent results."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        threshold = 0.3

        # Single prediction
        single_response = client.post("/predict", json={
            "features": SAMPLE_TURTLE,
            "threshold": threshold
        })

        # Batch prediction with the same sample
        batch_response = client.post("/batch_predict", json={
            "features_list": [SAMPLE_TURTLE],
            "threshold": threshold
        })

        assert single_response.status_code == 200
        assert batch_response.status_code == 200

        single_data = single_response.json()
        batch_data = batch_response.json()

        if batch_data["successful"] == 1:
            batch_result = batch_data["results"][0]
            # Predictions should be the same
            assert set(single_data["predictions"]) == set(batch_result["predictions"])


# ============================================================================
# CORS Tests
# ============================================================================

class TestCORS:
    """Tests for CORS configuration."""

    def test_cors_headers(self, client):
        """Test that a bare OPTIONS request to the predict route is handled."""
        response = client.options("/predict")
        # Accept either 200 or 405 (method not allowed for OPTIONS in some configs)
        assert response.status_code in [200, 405]
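        # Note: without Origin / Access-Control-Request-Method headers,
        # Starlette's CORSMiddleware does not treat this as a CORS preflight,
        # so the request may fall through to the router and return 405.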


# ============================================================================
# Performance Tests (Basic)
# ============================================================================

class TestPerformance:
    """Basic performance tests."""

    def test_response_time_single_prediction(self, client, model_loaded):
        """Test that single prediction completes in reasonable time."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        import time

        start = time.time()
        response = client.post("/predict", json={
            "features": SAMPLE_TURTLE,
            "threshold": 0.5
        })
        elapsed = time.time() - start

        # Should complete within 5 seconds (generous for CI environments)
        assert elapsed < 5.0, f"Prediction took {elapsed:.2f}s, expected < 5s"
        assert response.status_code == 200

    def test_response_time_batch_prediction(self, client, model_loaded):
        """Test that batch prediction scales reasonably."""
        if not model_loaded:
            pytest.skip("Model not loaded")

        import time

        # Create a batch of 10 samples
        features_list = [SAMPLE_TURTLE] * 10

        start = time.time()
        response = client.post("/batch_predict", json={
            "features_list": features_list,
            "threshold": 0.5
        })
        elapsed = time.time() - start

        # Should complete within 10 seconds
        assert elapsed < 10.0, f"Batch prediction took {elapsed:.2f}s, expected < 10s"
        assert response.status_code == 200


# ============================================================================
# Run tests if executed directly
# ============================================================================

if __name__ == "__main__":
    pytest.main([__file__, "-v", "--tb=short"])