#!/usr/bin/env python3
"""
Tests for the XGBoost FastAPI Multi-Label Classification API.

These tests use realistic sample data extracted from /Volumes/T7/roofline_features.h5,
which was generated by feature_aggregator.py from roofline dataframes.

Test data includes samples from different application types:
- TurTLE (turbulence simulation)
- SCALEXA (scaling benchmarks)
- Chroma (lattice QCD)
"""
import pytest
import json
from fastapi.testclient import TestClient
from unittest.mock import patch, MagicMock
import numpy as np
import os

# Import the FastAPI app and predictor
import xgb_fastapi
from xgb_fastapi import app
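
# Note: TestClient exercises the FastAPI app in-process, so no separate server
# has to be running for these tests.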
# ============================================================================
# Test Data: Realistic samples from roofline_features.h5
# Generated using feature_aggregator.py processing roofline dataframes
# ============================================================================

# Sample 1: TurTLE application - typical turbulence simulation workload
SAMPLE_TURTLE = {
    "bandwidth_raw_p10": 186.33,
    "bandwidth_raw_median": 205.14,
    "bandwidth_raw_p90": 210.83,
    "bandwidth_raw_mad": 3.57,
    "bandwidth_raw_range": 24.5,
    "bandwidth_raw_iqr": 12.075,
    "flops_raw_p10": 162.024,
    "flops_raw_median": 171.45,
    "flops_raw_p90": 176.48,
    "flops_raw_mad": 3.08,
    "flops_raw_range": 14.456,
    "flops_raw_iqr": 8.29,
    "arith_intensity_p10": 0.7906,
    "arith_intensity_median": 0.837,
    "arith_intensity_p90": 0.9109,
    "arith_intensity_mad": 0.02,
    "arith_intensity_range": 0.12,
    "arith_intensity_iqr": 0.0425,
    "bw_flops_covariance": 60.86,
    "bw_flops_correlation": 0.16,
    "avg_performance_gflops": 168.1,
    "median_performance_gflops": 171.45,
    "performance_gflops_mad": 3.08,
    "avg_memory_bw_gbs": 350.0,
    "scalar_peak_gflops": 432.0,
    "simd_peak_gflops": 9216.0,
    "node_num": 0,
    "duration": 19366,
}
# Sample 2: SCALEXA application - scaling benchmark workload
SAMPLE_SCALEXA = {
    "bandwidth_raw_p10": 13.474,
    "bandwidth_raw_median": 32.57,
    "bandwidth_raw_p90": 51.466,
    "bandwidth_raw_mad": 23.62,
    "bandwidth_raw_range": 37.992,
    "bandwidth_raw_iqr": 23.745,
    "flops_raw_p10": 4.24,
    "flops_raw_median": 16.16,
    "flops_raw_p90": 24.584,
    "flops_raw_mad": 10.53,
    "flops_raw_range": 20.344,
    "flops_raw_iqr": 12.715,
    "arith_intensity_p10": 0.211,
    "arith_intensity_median": 0.475,
    "arith_intensity_p90": 0.492,
    "arith_intensity_mad": 0.021,
    "arith_intensity_range": 0.281,
    "arith_intensity_iqr": 0.176,
    "bw_flops_covariance": 302.0,
    "bw_flops_correlation": 0.995,
    "avg_performance_gflops": 14.7,
    "median_performance_gflops": 16.16,
    "performance_gflops_mad": 10.53,
    "avg_memory_bw_gbs": 350.0,
    "scalar_peak_gflops": 432.0,
    "simd_peak_gflops": 9216.0,
    "node_num": 18,
    "duration": 165,
}
# Sample 3: Chroma application - lattice QCD workload (compute-intensive)
SAMPLE_CHROMA = {
    "bandwidth_raw_p10": 154.176,
    "bandwidth_raw_median": 200.57,
    "bandwidth_raw_p90": 259.952,
    "bandwidth_raw_mad": 5.12,
    "bandwidth_raw_range": 105.776,
    "bandwidth_raw_iqr": 10.215,
    "flops_raw_p10": 327.966,
    "flops_raw_median": 519.8,
    "flops_raw_p90": 654.422,
    "flops_raw_mad": 16.97,
    "flops_raw_range": 326.456,
    "flops_raw_iqr": 34.88,
    "arith_intensity_p10": 1.55,
    "arith_intensity_median": 2.595,
    "arith_intensity_p90": 3.445,
    "arith_intensity_mad": 0.254,
    "arith_intensity_range": 1.894,
    "arith_intensity_iqr": 0.512,
    "bw_flops_covariance": 382.76,
    "bw_flops_correlation": 0.063,
    "avg_performance_gflops": 503.26,
    "median_performance_gflops": 519.8,
    "performance_gflops_mad": 16.97,
    "avg_memory_bw_gbs": 350.0,
    "scalar_peak_gflops": 432.0,
    "simd_peak_gflops": 9216.0,
    "node_num": 3,
    "duration": 31133,
}
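
# All three samples above share the same 28-key aggregated feature schema
# (p10/median/p90/MAD/range/IQR for bandwidth, FLOPS and arithmetic intensity,
# plus covariance/correlation, performance summaries, hardware peaks, node_num
# and duration), i.e. the per-job output produced by feature_aggregator.py.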
# Sample JSON roofline data (raw data before aggregation, as would be received by the API)
SAMPLE_JSON_ROOFLINE = json.dumps([
    {
        "node_num": 1,
        "bandwidth_raw": 150.5,
        "flops_raw": 2500.0,
        "arith_intensity": 16.6,
        "performance_gflops": 1200.0,
        "memory_bw_gbs": 450,
        "scalar_peak_gflops": 600,
        "duration": 3600
    },
    {
        "node_num": 1,
        "bandwidth_raw": 155.2,
        "flops_raw": 2600.0,
        "arith_intensity": 16.8,
        "performance_gflops": 1250.0,
        "memory_bw_gbs": 450,
        "scalar_peak_gflops": 600,
        "duration": 3600
    },
    {
        "node_num": 1,
        "bandwidth_raw": 148.0,
        "flops_raw": 2450.0,
        "arith_intensity": 16.5,
        "performance_gflops": 1180.0,
        "memory_bw_gbs": 450,
        "scalar_peak_gflops": 600,
        "duration": 3600
    }
])
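
# A minimal sketch (not a test) of how a client could send the raw roofline
# records above to a running instance of the API. It assumes the app is served
# separately (e.g. with uvicorn) and that the `requests` package is installed;
# the base URL and job_id below are illustrative assumptions, and the payload
# simply mirrors the request bodies used in the tests further down.
def _example_manual_predict_request(base_url="http://localhost:8000"):
    import requests

    payload = {
        "features": SAMPLE_JSON_ROOFLINE,  # raw records as a JSON string
        "is_json": True,
        "job_id": "manual_check",
        "threshold": 0.3,
    }
    # POST to /predict exactly as the TestClient-based tests do below
    response = requests.post(f"{base_url}/predict", json=payload)
    response.raise_for_status()
    return response.json()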
# ============================================================================
# Fixtures
# ============================================================================

@pytest.fixture(scope="module")
def setup_predictor():
    """
    Set up the predictor for tests.
    Try to load the real model if available.
    """
    from xgb_inference_api import XGBoostMultiLabelPredictor

    model_path = os.path.join(os.path.dirname(__file__), 'xgb_model.joblib')
    if os.path.exists(model_path):
        try:
            predictor = XGBoostMultiLabelPredictor(model_path)
            xgb_fastapi.predictor = predictor
            return True
        except Exception as e:
            print(f"Failed to load model: {e}")
            return False
    return False


@pytest.fixture
def client(setup_predictor):
    """Create a test client for the FastAPI app."""
    return TestClient(app)


@pytest.fixture
def model_loaded(setup_predictor):
    """Check if the model is loaded."""
    return setup_predictor


def skip_if_no_model(model_loaded):
    """Helper to skip tests if model is not loaded."""
    if not model_loaded:
        pytest.skip("Model not loaded, skipping test")
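
# Note: the tests below skip inline when the model file is absent;
# skip_if_no_model provides the same check as a reusable helper.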
# ============================================================================
# Health and Root Endpoint Tests
# ============================================================================

class TestHealthEndpoints:
    """Tests for health check and root endpoints."""

    def test_root_endpoint(self, client):
        """Test the root endpoint returns API information."""
        response = client.get("/")
        assert response.status_code == 200
        data = response.json()
        assert "name" in data
        assert data["name"] == "XGBoost Multi-Label Classification API"
        assert "version" in data
        assert "endpoints" in data
        assert all(key in data["endpoints"] for key in ["health", "predict", "predict_top_k", "batch_predict"])

    def test_health_check(self, client, model_loaded):
        """Test the health check endpoint."""
        response = client.get("/health")
        assert response.status_code == 200
        data = response.json()
        assert "status" in data
        assert "model_loaded" in data
        # If the model is loaded, check for additional info
        if model_loaded:
            assert data["model_loaded"] is True
            assert data["status"] in ["healthy", "degraded"]
            if data["status"] == "healthy":
                assert "n_classes" in data
                assert "n_features" in data
                assert "classes" in data
                assert data["n_classes"] > 0
                assert data["n_features"] > 0

    def test_health_check_model_not_loaded(self, client):
        """Test that the health check reports status and model_loaded regardless of model state."""
        response = client.get("/health")
        assert response.status_code == 200
        data = response.json()
        # Should have status and model_loaded fields regardless
        assert "status" in data
        assert "model_loaded" in data
# ============================================================================
# Single Prediction Tests
# ============================================================================

class TestPredictEndpoint:
    """Tests for the /predict endpoint."""

    def test_predict_with_feature_dict(self, client, model_loaded):
        """Test prediction with a feature dictionary (TurTLE sample)."""
        if not model_loaded:
            pytest.skip("Model not loaded")
        request_data = {
            "features": SAMPLE_TURTLE,
            "threshold": 0.5,
            "return_all_probabilities": True
        }
        response = client.post("/predict", json=request_data)
        assert response.status_code == 200
        data = response.json()
        assert "predictions" in data
        assert "probabilities" in data
        assert "confidences" in data
        assert "threshold" in data
        assert isinstance(data["predictions"], list)
        assert isinstance(data["probabilities"], dict)
        assert data["threshold"] == 0.5
        # All probabilities should be between 0 and 1
        for prob in data["probabilities"].values():
            assert 0.0 <= prob <= 1.0

    def test_predict_with_different_thresholds(self, client, model_loaded):
        """Test that different thresholds affect predictions."""
        if not model_loaded:
            pytest.skip("Model not loaded")
        request_low = {
            "features": SAMPLE_TURTLE,
            "threshold": 0.1,
            "return_all_probabilities": True
        }
        request_high = {
            "features": SAMPLE_TURTLE,
            "threshold": 0.9,
            "return_all_probabilities": True
        }
        response_low = client.post("/predict", json=request_low)
        response_high = client.post("/predict", json=request_high)
        assert response_low.status_code == 200
        assert response_high.status_code == 200
        data_low = response_low.json()
        data_high = response_high.json()
        # A lower threshold should produce at least as many predictions
        assert len(data_low["predictions"]) >= len(data_high["predictions"])

    def test_predict_different_workloads(self, client, model_loaded):
        """Test predictions on different application workloads."""
        if not model_loaded:
            pytest.skip("Model not loaded")
        samples = [
            ("TurTLE", SAMPLE_TURTLE),
            ("SCALEXA", SAMPLE_SCALEXA),
            ("Chroma", SAMPLE_CHROMA),
        ]
        for name, sample in samples:
            request_data = {
                "features": sample,
                "threshold": 0.3,
                "return_all_probabilities": True
            }
            response = client.post("/predict", json=request_data)
            assert response.status_code == 200, f"Failed for {name}"
            data = response.json()
            assert len(data["probabilities"]) > 0, f"No probabilities for {name}"

    def test_predict_return_only_predicted_probabilities(self, client, model_loaded):
        """Test prediction with return_all_probabilities=False."""
        if not model_loaded:
            pytest.skip("Model not loaded")
        request_data = {
            "features": SAMPLE_TURTLE,
            "threshold": 0.3,
            "return_all_probabilities": False
        }
        response = client.post("/predict", json=request_data)
        assert response.status_code == 200
        data = response.json()
        # When return_all_probabilities is False, probabilities should only
        # contain classes that are in predictions
        if len(data["predictions"]) > 0:
            assert set(data["probabilities"].keys()) == set(data["predictions"])
    def test_predict_with_json_roofline_data(self, client, model_loaded):
        """Test prediction with raw JSON roofline data (requires aggregation)."""
        if not model_loaded:
            pytest.skip("Model not loaded")
        request_data = {
            "features": SAMPLE_JSON_ROOFLINE,
            "is_json": True,
            "job_id": "test_job_123",
            "threshold": 0.3
        }
        response = client.post("/predict", json=request_data)
        assert response.status_code == 200
        data = response.json()
        assert "predictions" in data
        assert "probabilities" in data

    def test_predict_threshold_boundaries(self, client, model_loaded):
        """Test prediction with threshold at boundaries."""
        if not model_loaded:
            pytest.skip("Model not loaded")
        for threshold in [0.0, 0.5, 1.0]:
            request_data = {
                "features": SAMPLE_TURTLE,
                "threshold": threshold
            }
            response = client.post("/predict", json=request_data)
            assert response.status_code == 200

    def test_predict_invalid_threshold(self, client):
        """Test that invalid threshold values are rejected."""
        for threshold in [-0.1, 1.5]:
            request_data = {
                "features": SAMPLE_TURTLE,
                "threshold": threshold
            }
            response = client.post("/predict", json=request_data)
            assert response.status_code == 422  # Validation error

    def test_predict_model_not_loaded(self, client):
        """Test that prediction returns 503 when model not loaded."""
        # Temporarily set predictor to None
        original_predictor = xgb_fastapi.predictor
        xgb_fastapi.predictor = None
        try:
            request_data = {
                "features": SAMPLE_TURTLE,
                "threshold": 0.5
            }
            response = client.post("/predict", json=request_data)
            assert response.status_code == 503
            assert "Model not loaded" in response.json().get("detail", "")
        finally:
            xgb_fastapi.predictor = original_predictor
# ============================================================================
# Top-K Prediction Tests
# ============================================================================

class TestPredictTopKEndpoint:
    """Tests for the /predict_top_k endpoint."""

    def test_predict_top_k_default(self, client, model_loaded):
        """Test top-K prediction with default k=5."""
        if not model_loaded:
            pytest.skip("Model not loaded")
        request_data = {
            "features": SAMPLE_TURTLE
        }
        response = client.post("/predict_top_k", json=request_data)
        assert response.status_code == 200
        data = response.json()
        assert "top_predictions" in data
        assert "top_probabilities" in data
        assert "all_probabilities" in data
        assert len(data["top_predictions"]) <= 5
        assert len(data["top_probabilities"]) <= 5

    def test_predict_top_k_custom_k(self, client, model_loaded):
        """Test top-K prediction with custom k values."""
        if not model_loaded:
            pytest.skip("Model not loaded")
        for k in [1, 3, 10]:
            request_data = {
                "features": SAMPLE_TURTLE,
                "k": k
            }
            response = client.post("/predict_top_k", json=request_data)
            assert response.status_code == 200
            data = response.json()
            assert len(data["top_predictions"]) <= k

    def test_predict_top_k_ordering(self, client, model_loaded):
        """Test that top-K predictions are ordered by probability."""
        if not model_loaded:
            pytest.skip("Model not loaded")
        request_data = {
            "features": SAMPLE_CHROMA,
            "k": 10
        }
        response = client.post("/predict_top_k", json=request_data)
        assert response.status_code == 200
        data = response.json()
        probabilities = [data["top_probabilities"][cls] for cls in data["top_predictions"]]
        # Check that probabilities are in descending order
        for i in range(len(probabilities) - 1):
            assert probabilities[i] >= probabilities[i + 1]

    def test_predict_top_k_with_json_data(self, client, model_loaded):
        """Test top-K prediction with JSON roofline data."""
        if not model_loaded:
            pytest.skip("Model not loaded")
        request_data = {
            "features": SAMPLE_JSON_ROOFLINE,
            "is_json": True,
            "job_id": "test_job_456",
            "k": 5
        }
        response = client.post("/predict_top_k", json=request_data)
        assert response.status_code == 200
        data = response.json()
        assert len(data["top_predictions"]) <= 5

    def test_predict_top_k_invalid_k(self, client):
        """Test that invalid k values are rejected."""
        for k in [0, -1, 101]:
            request_data = {
                "features": SAMPLE_TURTLE,
                "k": k
            }
            response = client.post("/predict_top_k", json=request_data)
            assert response.status_code == 422  # Validation error
# ============================================================================
# Batch Prediction Tests
# ============================================================================

class TestBatchPredictEndpoint:
    """Tests for the /batch_predict endpoint."""

    def test_batch_predict_multiple_samples(self, client, model_loaded):
        """Test batch prediction with multiple samples."""
        if not model_loaded:
            pytest.skip("Model not loaded")
        request_data = {
            "features_list": [SAMPLE_TURTLE, SAMPLE_SCALEXA, SAMPLE_CHROMA],
            "threshold": 0.3
        }
        response = client.post("/batch_predict", json=request_data)
        assert response.status_code == 200
        data = response.json()
        assert "results" in data
        assert "total" in data
        assert "successful" in data
        assert "failed" in data
        assert data["total"] == 3
        assert len(data["results"]) == 3
        assert data["successful"] + data["failed"] == data["total"]

    def test_batch_predict_single_sample(self, client, model_loaded):
        """Test batch prediction with a single sample."""
        if not model_loaded:
            pytest.skip("Model not loaded")
        request_data = {
            "features_list": [SAMPLE_TURTLE],
            "threshold": 0.5
        }
        response = client.post("/batch_predict", json=request_data)
        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 1
        assert len(data["results"]) == 1

    def test_batch_predict_with_json_data(self, client, model_loaded):
        """Test batch prediction with JSON roofline data."""
        if not model_loaded:
            pytest.skip("Model not loaded")
        request_data = {
            "features_list": [SAMPLE_JSON_ROOFLINE, SAMPLE_JSON_ROOFLINE],
            "is_json": True,
            "job_ids": ["job_001", "job_002"],
            "threshold": 0.3
        }
        response = client.post("/batch_predict", json=request_data)
        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 2

    def test_batch_predict_empty_list(self, client, model_loaded):
        """Test batch prediction with empty list."""
        if not model_loaded:
            pytest.skip("Model not loaded")
        request_data = {
            "features_list": [],
            "threshold": 0.5
        }
        response = client.post("/batch_predict", json=request_data)
        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 0
        assert data["successful"] == 0
        assert data["failed"] == 0
# ============================================================================
# Model Info Tests
# ============================================================================

class TestModelInfoEndpoint:
    """Tests for the /model/info endpoint."""

    def test_model_info(self, client, model_loaded):
        """Test getting model information."""
        response = client.get("/model/info")
        # May return 503 if model not loaded, or 200 if loaded
        if model_loaded:
            assert response.status_code == 200
            data = response.json()
            assert "classes" in data
            assert "n_classes" in data
            assert "features" in data
            assert "n_features" in data
            assert isinstance(data["classes"], list)
            assert len(data["classes"]) == data["n_classes"]
            assert len(data["features"]) == data["n_features"]
        else:
            assert response.status_code == 503

    def test_model_info_not_loaded(self, client):
        """Test that model info returns 503 when model not loaded."""
        original_predictor = xgb_fastapi.predictor
        xgb_fastapi.predictor = None
        try:
            response = client.get("/model/info")
            assert response.status_code == 503
        finally:
            xgb_fastapi.predictor = original_predictor
# ============================================================================
# Error Handling Tests
# ============================================================================

class TestErrorHandling:
    """Tests for error handling."""

    def test_predict_missing_features(self, client):
        """Test prediction without features field."""
        request_data = {
            "threshold": 0.5
        }
        response = client.post("/predict", json=request_data)
        assert response.status_code == 422

    def test_predict_invalid_json_format(self, client, model_loaded):
        """Test prediction with invalid JSON in is_json mode."""
        if not model_loaded:
            pytest.skip("Model not loaded")
        request_data = {
            "features": "not valid json {{",
            "is_json": True
        }
        response = client.post("/predict", json=request_data)
        # Should return a 400 error
        assert response.status_code == 400

    def test_invalid_endpoint(self, client):
        """Test accessing an invalid endpoint."""
        response = client.get("/nonexistent")
        assert response.status_code == 404
# ============================================================================
# Integration Tests (Full Pipeline)
# ============================================================================

class TestIntegration:
    """Integration tests for the full prediction pipeline."""

    def test_full_prediction_pipeline_features(self, client, model_loaded):
        """Test the complete prediction pipeline with a feature dict."""
        if not model_loaded:
            pytest.skip("Model not loaded")
        # 1. Check health
        health_response = client.get("/health")
        assert health_response.status_code == 200
        health_data = health_response.json()
        assert health_data["model_loaded"] is True
        # 2. Get model info
        info_response = client.get("/model/info")
        assert info_response.status_code == 200
        # 3. Make single prediction
        predict_response = client.post("/predict", json={
            "features": SAMPLE_TURTLE,
            "threshold": 0.3
        })
        assert predict_response.status_code == 200
        # 4. Make top-K prediction
        topk_response = client.post("/predict_top_k", json={
            "features": SAMPLE_TURTLE,
            "k": 5
        })
        assert topk_response.status_code == 200
        # 5. Make batch prediction
        batch_response = client.post("/batch_predict", json={
            "features_list": [SAMPLE_TURTLE, SAMPLE_CHROMA],
            "threshold": 0.3
        })
        assert batch_response.status_code == 200

    def test_consistency_single_vs_batch(self, client, model_loaded):
        """Test that single prediction and batch prediction give consistent results."""
        if not model_loaded:
            pytest.skip("Model not loaded")
        threshold = 0.3
        # Single prediction
        single_response = client.post("/predict", json={
            "features": SAMPLE_TURTLE,
            "threshold": threshold
        })
        # Batch prediction with the same sample
        batch_response = client.post("/batch_predict", json={
            "features_list": [SAMPLE_TURTLE],
            "threshold": threshold
        })
        assert single_response.status_code == 200
        assert batch_response.status_code == 200
        single_data = single_response.json()
        batch_data = batch_response.json()
        if batch_data["successful"] == 1:
            batch_result = batch_data["results"][0]
            # Predictions should be the same
            assert set(single_data["predictions"]) == set(batch_result["predictions"])
# ============================================================================
# CORS Tests
# ============================================================================

class TestCORS:
    """Tests for CORS configuration."""

    def test_cors_headers(self, client):
        """Test that an OPTIONS request to /predict is handled without a server error."""
        response = client.options("/predict")
        # Accept either 200 or 405 (OPTIONS may be disallowed in some configurations)
        assert response.status_code in [200, 405]
# ============================================================================
# Performance Tests (Basic)
# ============================================================================

class TestPerformance:
    """Basic performance tests."""

    def test_response_time_single_prediction(self, client, model_loaded):
        """Test that single prediction completes in reasonable time."""
        if not model_loaded:
            pytest.skip("Model not loaded")
        import time

        start = time.time()
        response = client.post("/predict", json={
            "features": SAMPLE_TURTLE,
            "threshold": 0.5
        })
        elapsed = time.time() - start
        # Should complete within 5 seconds (generous for CI environments)
        assert elapsed < 5.0, f"Prediction took {elapsed:.2f}s, expected < 5s"
        assert response.status_code == 200

    def test_response_time_batch_prediction(self, client, model_loaded):
        """Test that batch prediction scales reasonably."""
        if not model_loaded:
            pytest.skip("Model not loaded")
        import time

        # Create a batch of 10 samples
        features_list = [SAMPLE_TURTLE] * 10
        start = time.time()
        response = client.post("/batch_predict", json={
            "features_list": features_list,
            "threshold": 0.5
        })
        elapsed = time.time() - start
        # Should complete within 10 seconds
        assert elapsed < 10.0, f"Batch prediction took {elapsed:.2f}s, expected < 10s"
        assert response.status_code == 200
# ============================================================================
# Run tests if executed directly
# ============================================================================

if __name__ == "__main__":
    pytest.main([__file__, "-v", "--tb=short"])
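
# A subset of the suite can be selected with pytest's -k expression filter, e.g.:
#   pytest test_xgb_fastapi.py -k "top_k" -v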