#!/usr/bin/env python3 """ Tests for XGBoost FastAPI Multi-Label Classification API These tests use realistic sample data extracted from /Volumes/T7/roofline_features.h5 which was generated using feature_aggregator.py to process roofline dataframes. Test data includes samples from different application types: - TurTLE (turbulence simulation) - SCALEXA (scaling benchmarks) - Chroma (lattice QCD) """ import pytest import json from fastapi.testclient import TestClient from unittest.mock import patch, MagicMock import numpy as np import os # Import the FastAPI app and predictor import xgb_fastapi from xgb_fastapi import app # ============================================================================ # Test Data: Realistic samples from roofline_features.h5 # Generated using feature_aggregator.py processing roofline dataframes # ============================================================================ # Sample 1: TurTLE application - typical turbulence simulation workload SAMPLE_TURTLE = { "bandwidth_raw_p10": 186.33, "bandwidth_raw_median": 205.14, "bandwidth_raw_p90": 210.83, "bandwidth_raw_mad": 3.57, "bandwidth_raw_range": 24.5, "bandwidth_raw_iqr": 12.075, "flops_raw_p10": 162.024, "flops_raw_median": 171.45, "flops_raw_p90": 176.48, "flops_raw_mad": 3.08, "flops_raw_range": 14.456, "flops_raw_iqr": 8.29, "arith_intensity_p10": 0.7906, "arith_intensity_median": 0.837, "arith_intensity_p90": 0.9109, "arith_intensity_mad": 0.02, "arith_intensity_range": 0.12, "arith_intensity_iqr": 0.0425, "bw_flops_covariance": 60.86, "bw_flops_correlation": 0.16, "avg_performance_gflops": 168.1, "median_performance_gflops": 171.45, "performance_gflops_mad": 3.08, "avg_memory_bw_gbs": 350.0, "scalar_peak_gflops": 432.0, "simd_peak_gflops": 9216.0, "node_num": 0, "duration": 19366, } # Sample 2: SCALEXA application - scaling benchmark workload SAMPLE_SCALEXA = { "bandwidth_raw_p10": 13.474, "bandwidth_raw_median": 32.57, "bandwidth_raw_p90": 51.466, "bandwidth_raw_mad": 23.62, "bandwidth_raw_range": 37.992, "bandwidth_raw_iqr": 23.745, "flops_raw_p10": 4.24, "flops_raw_median": 16.16, "flops_raw_p90": 24.584, "flops_raw_mad": 10.53, "flops_raw_range": 20.344, "flops_raw_iqr": 12.715, "arith_intensity_p10": 0.211, "arith_intensity_median": 0.475, "arith_intensity_p90": 0.492, "arith_intensity_mad": 0.021, "arith_intensity_range": 0.281, "arith_intensity_iqr": 0.176, "bw_flops_covariance": 302.0, "bw_flops_correlation": 0.995, "avg_performance_gflops": 14.7, "median_performance_gflops": 16.16, "performance_gflops_mad": 10.53, "avg_memory_bw_gbs": 350.0, "scalar_peak_gflops": 432.0, "simd_peak_gflops": 9216.0, "node_num": 18, "duration": 165, } # Sample 3: Chroma application - lattice QCD workload (compute-intensive) SAMPLE_CHROMA = { "bandwidth_raw_p10": 154.176, "bandwidth_raw_median": 200.57, "bandwidth_raw_p90": 259.952, "bandwidth_raw_mad": 5.12, "bandwidth_raw_range": 105.776, "bandwidth_raw_iqr": 10.215, "flops_raw_p10": 327.966, "flops_raw_median": 519.8, "flops_raw_p90": 654.422, "flops_raw_mad": 16.97, "flops_raw_range": 326.456, "flops_raw_iqr": 34.88, "arith_intensity_p10": 1.55, "arith_intensity_median": 2.595, "arith_intensity_p90": 3.445, "arith_intensity_mad": 0.254, "arith_intensity_range": 1.894, "arith_intensity_iqr": 0.512, "bw_flops_covariance": 382.76, "bw_flops_correlation": 0.063, "avg_performance_gflops": 503.26, "median_performance_gflops": 519.8, "performance_gflops_mad": 16.97, "avg_memory_bw_gbs": 350.0, "scalar_peak_gflops": 432.0, "simd_peak_gflops": 9216.0, "node_num": 3, "duration": 31133, } # Sample JSON roofline data (raw data before aggregation, as would be received by API) SAMPLE_JSON_ROOFLINE = json.dumps([ { "node_num": 1, "bandwidth_raw": 150.5, "flops_raw": 2500.0, "arith_intensity": 16.6, "performance_gflops": 1200.0, "memory_bw_gbs": 450, "scalar_peak_gflops": 600, "duration": 3600 }, { "node_num": 1, "bandwidth_raw": 155.2, "flops_raw": 2600.0, "arith_intensity": 16.8, "performance_gflops": 1250.0, "memory_bw_gbs": 450, "scalar_peak_gflops": 600, "duration": 3600 }, { "node_num": 1, "bandwidth_raw": 148.0, "flops_raw": 2450.0, "arith_intensity": 16.5, "performance_gflops": 1180.0, "memory_bw_gbs": 450, "scalar_peak_gflops": 600, "duration": 3600 } ]) # ============================================================================ # Fixtures # ============================================================================ @pytest.fixture(scope="module") def setup_predictor(): """ Set up the predictor for tests. Try to load the real model if available. """ from xgb_inference_api import XGBoostMultiLabelPredictor model_path = os.path.join(os.path.dirname(__file__), 'xgb_model.joblib') if os.path.exists(model_path): try: predictor = XGBoostMultiLabelPredictor(model_path) xgb_fastapi.predictor = predictor return True except Exception as e: print(f"Failed to load model: {e}") return False return False @pytest.fixture def client(setup_predictor): """Create a test client for the FastAPI app.""" return TestClient(app) @pytest.fixture def model_loaded(setup_predictor): """Check if the model is loaded.""" return setup_predictor def skip_if_no_model(model_loaded): """Helper to skip tests if model is not loaded.""" if not model_loaded: pytest.skip("Model not loaded, skipping test") # ============================================================================ # Health and Root Endpoint Tests # ============================================================================ class TestHealthEndpoints: """Tests for health check and root endpoints.""" def test_root_endpoint(self, client): """Test the root endpoint returns API information.""" response = client.get("/") assert response.status_code == 200 data = response.json() assert "name" in data assert data["name"] == "XGBoost Multi-Label Classification API" assert "version" in data assert "endpoints" in data assert all(key in data["endpoints"] for key in ["health", "predict", "predict_top_k", "batch_predict"]) def test_health_check(self, client, model_loaded): """Test the health check endpoint.""" response = client.get("/health") assert response.status_code == 200 data = response.json() assert "status" in data assert "model_loaded" in data # If model is loaded, check for additional info if model_loaded: assert data["model_loaded"] == True assert data["status"] in ["healthy", "degraded"] if data["status"] == "healthy": assert "n_classes" in data assert "n_features" in data assert "classes" in data assert data["n_classes"] > 0 assert data["n_features"] > 0 def test_health_check_model_not_loaded(self, client): """Test health check returns correct status when model not loaded.""" response = client.get("/health") assert response.status_code == 200 data = response.json() # Should have status and model_loaded fields regardless assert "status" in data assert "model_loaded" in data # ============================================================================ # Single Prediction Tests # ============================================================================ class TestPredictEndpoint: """Tests for the /predict endpoint.""" def test_predict_with_feature_dict(self, client, model_loaded): """Test prediction with a feature dictionary (TurTLE sample).""" if not model_loaded: pytest.skip("Model not loaded") request_data = { "features": SAMPLE_TURTLE, "threshold": 0.5, "return_all_probabilities": True } response = client.post("/predict", json=request_data) assert response.status_code == 200 data = response.json() assert "predictions" in data assert "probabilities" in data assert "confidences" in data assert "threshold" in data assert isinstance(data["predictions"], list) assert isinstance(data["probabilities"], dict) assert data["threshold"] == 0.5 # All probabilities should be between 0 and 1 for prob in data["probabilities"].values(): assert 0.0 <= prob <= 1.0 def test_predict_with_different_thresholds(self, client, model_loaded): """Test that different thresholds affect predictions.""" if not model_loaded: pytest.skip("Model not loaded") request_low = { "features": SAMPLE_TURTLE, "threshold": 0.1, "return_all_probabilities": True } request_high = { "features": SAMPLE_TURTLE, "threshold": 0.9, "return_all_probabilities": True } response_low = client.post("/predict", json=request_low) response_high = client.post("/predict", json=request_high) assert response_low.status_code == 200 assert response_high.status_code == 200 data_low = response_low.json() data_high = response_high.json() # Lower threshold should generally produce more predictions assert len(data_low["predictions"]) >= len(data_high["predictions"]) def test_predict_different_workloads(self, client, model_loaded): """Test predictions on different application workloads.""" if not model_loaded: pytest.skip("Model not loaded") samples = [ ("TurTLE", SAMPLE_TURTLE), ("SCALEXA", SAMPLE_SCALEXA), ("Chroma", SAMPLE_CHROMA), ] for name, sample in samples: request_data = { "features": sample, "threshold": 0.3, "return_all_probabilities": True } response = client.post("/predict", json=request_data) assert response.status_code == 200, f"Failed for {name}" data = response.json() assert len(data["probabilities"]) > 0, f"No probabilities for {name}" def test_predict_return_only_predicted_probabilities(self, client, model_loaded): """Test prediction with return_all_probabilities=False.""" if not model_loaded: pytest.skip("Model not loaded") request_data = { "features": SAMPLE_TURTLE, "threshold": 0.3, "return_all_probabilities": False } response = client.post("/predict", json=request_data) assert response.status_code == 200 data = response.json() # When return_all_probabilities is False, probabilities should only # contain classes that are in predictions if len(data["predictions"]) > 0: assert set(data["probabilities"].keys()) == set(data["predictions"]) def test_predict_with_json_roofline_data(self, client, model_loaded): """Test prediction with raw JSON roofline data (requires aggregation).""" if not model_loaded: pytest.skip("Model not loaded") request_data = { "features": SAMPLE_JSON_ROOFLINE, "is_json": True, "job_id": "test_job_123", "threshold": 0.3 } response = client.post("/predict", json=request_data) assert response.status_code == 200 data = response.json() assert "predictions" in data assert "probabilities" in data def test_predict_threshold_boundaries(self, client, model_loaded): """Test prediction with threshold at boundaries.""" if not model_loaded: pytest.skip("Model not loaded") for threshold in [0.0, 0.5, 1.0]: request_data = { "features": SAMPLE_TURTLE, "threshold": threshold } response = client.post("/predict", json=request_data) assert response.status_code == 200 def test_predict_invalid_threshold(self, client): """Test that invalid threshold values are rejected.""" for threshold in [-0.1, 1.5]: request_data = { "features": SAMPLE_TURTLE, "threshold": threshold } response = client.post("/predict", json=request_data) assert response.status_code == 422 # Validation error def test_predict_model_not_loaded(self, client): """Test that prediction returns 503 when model not loaded.""" # Temporarily set predictor to None original_predictor = xgb_fastapi.predictor xgb_fastapi.predictor = None try: request_data = { "features": SAMPLE_TURTLE, "threshold": 0.5 } response = client.post("/predict", json=request_data) assert response.status_code == 503 assert "Model not loaded" in response.json().get("detail", "") finally: xgb_fastapi.predictor = original_predictor # ============================================================================ # Top-K Prediction Tests # ============================================================================ class TestPredictTopKEndpoint: """Tests for the /predict_top_k endpoint.""" def test_predict_top_k_default(self, client, model_loaded): """Test top-K prediction with default k=5.""" if not model_loaded: pytest.skip("Model not loaded") request_data = { "features": SAMPLE_TURTLE } response = client.post("/predict_top_k", json=request_data) assert response.status_code == 200 data = response.json() assert "top_predictions" in data assert "top_probabilities" in data assert "all_probabilities" in data assert len(data["top_predictions"]) <= 5 assert len(data["top_probabilities"]) <= 5 def test_predict_top_k_custom_k(self, client, model_loaded): """Test top-K prediction with custom k values.""" if not model_loaded: pytest.skip("Model not loaded") for k in [1, 3, 10]: request_data = { "features": SAMPLE_TURTLE, "k": k } response = client.post("/predict_top_k", json=request_data) assert response.status_code == 200 data = response.json() assert len(data["top_predictions"]) <= k def test_predict_top_k_ordering(self, client, model_loaded): """Test that top-K predictions are ordered by probability.""" if not model_loaded: pytest.skip("Model not loaded") request_data = { "features": SAMPLE_CHROMA, "k": 10 } response = client.post("/predict_top_k", json=request_data) assert response.status_code == 200 data = response.json() probabilities = [data["top_probabilities"][cls] for cls in data["top_predictions"]] # Check that probabilities are in descending order for i in range(len(probabilities) - 1): assert probabilities[i] >= probabilities[i + 1] def test_predict_top_k_with_json_data(self, client, model_loaded): """Test top-K prediction with JSON roofline data.""" if not model_loaded: pytest.skip("Model not loaded") request_data = { "features": SAMPLE_JSON_ROOFLINE, "is_json": True, "job_id": "test_job_456", "k": 5 } response = client.post("/predict_top_k", json=request_data) assert response.status_code == 200 data = response.json() assert len(data["top_predictions"]) <= 5 def test_predict_top_k_invalid_k(self, client): """Test that invalid k values are rejected.""" for k in [0, -1, 101]: request_data = { "features": SAMPLE_TURTLE, "k": k } response = client.post("/predict_top_k", json=request_data) assert response.status_code == 422 # Validation error # ============================================================================ # Batch Prediction Tests # ============================================================================ class TestBatchPredictEndpoint: """Tests for the /batch_predict endpoint.""" def test_batch_predict_multiple_samples(self, client, model_loaded): """Test batch prediction with multiple samples.""" if not model_loaded: pytest.skip("Model not loaded") request_data = { "features_list": [SAMPLE_TURTLE, SAMPLE_SCALEXA, SAMPLE_CHROMA], "threshold": 0.3 } response = client.post("/batch_predict", json=request_data) assert response.status_code == 200 data = response.json() assert "results" in data assert "total" in data assert "successful" in data assert "failed" in data assert data["total"] == 3 assert len(data["results"]) == 3 assert data["successful"] + data["failed"] == data["total"] def test_batch_predict_single_sample(self, client, model_loaded): """Test batch prediction with a single sample.""" if not model_loaded: pytest.skip("Model not loaded") request_data = { "features_list": [SAMPLE_TURTLE], "threshold": 0.5 } response = client.post("/batch_predict", json=request_data) assert response.status_code == 200 data = response.json() assert data["total"] == 1 assert len(data["results"]) == 1 def test_batch_predict_with_json_data(self, client, model_loaded): """Test batch prediction with JSON roofline data.""" if not model_loaded: pytest.skip("Model not loaded") request_data = { "features_list": [SAMPLE_JSON_ROOFLINE, SAMPLE_JSON_ROOFLINE], "is_json": True, "job_ids": ["job_001", "job_002"], "threshold": 0.3 } response = client.post("/batch_predict", json=request_data) assert response.status_code == 200 data = response.json() assert data["total"] == 2 def test_batch_predict_empty_list(self, client, model_loaded): """Test batch prediction with empty list.""" if not model_loaded: pytest.skip("Model not loaded") request_data = { "features_list": [], "threshold": 0.5 } response = client.post("/batch_predict", json=request_data) assert response.status_code == 200 data = response.json() assert data["total"] == 0 assert data["successful"] == 0 assert data["failed"] == 0 # ============================================================================ # Model Info Tests # ============================================================================ class TestModelInfoEndpoint: """Tests for the /model/info endpoint.""" def test_model_info(self, client, model_loaded): """Test getting model information.""" response = client.get("/model/info") # May return 503 if model not loaded, or 200 if loaded if model_loaded: assert response.status_code == 200 data = response.json() assert "classes" in data assert "n_classes" in data assert "features" in data assert "n_features" in data assert isinstance(data["classes"], list) assert len(data["classes"]) == data["n_classes"] assert len(data["features"]) == data["n_features"] else: assert response.status_code == 503 def test_model_info_not_loaded(self, client): """Test that model info returns 503 when model not loaded.""" original_predictor = xgb_fastapi.predictor xgb_fastapi.predictor = None try: response = client.get("/model/info") assert response.status_code == 503 finally: xgb_fastapi.predictor = original_predictor # ============================================================================ # Error Handling Tests # ============================================================================ class TestErrorHandling: """Tests for error handling.""" def test_predict_missing_features(self, client): """Test prediction without features field.""" request_data = { "threshold": 0.5 } response = client.post("/predict", json=request_data) assert response.status_code == 422 def test_predict_invalid_json_format(self, client, model_loaded): """Test prediction with invalid JSON in is_json mode.""" if not model_loaded: pytest.skip("Model not loaded") request_data = { "features": "not valid json {{", "is_json": True } response = client.post("/predict", json=request_data) # Should return error (400) assert response.status_code == 400 def test_invalid_endpoint(self, client): """Test accessing an invalid endpoint.""" response = client.get("/nonexistent") assert response.status_code == 404 # ============================================================================ # Integration Tests (Full Pipeline) # ============================================================================ class TestIntegration: """Integration tests for full prediction pipeline.""" def test_full_prediction_pipeline_features(self, client, model_loaded): """Test complete prediction pipeline with feature dict.""" if not model_loaded: pytest.skip("Model not loaded") # 1. Check health health_response = client.get("/health") assert health_response.status_code == 200 health_data = health_response.json() assert health_data["model_loaded"] == True # 2. Get model info info_response = client.get("/model/info") assert info_response.status_code == 200 # 3. Make single prediction predict_response = client.post("/predict", json={ "features": SAMPLE_TURTLE, "threshold": 0.3 }) assert predict_response.status_code == 200 # 4. Make top-K prediction topk_response = client.post("/predict_top_k", json={ "features": SAMPLE_TURTLE, "k": 5 }) assert topk_response.status_code == 200 # 5. Make batch prediction batch_response = client.post("/batch_predict", json={ "features_list": [SAMPLE_TURTLE, SAMPLE_CHROMA], "threshold": 0.3 }) assert batch_response.status_code == 200 def test_consistency_single_vs_batch(self, client, model_loaded): """Test that single prediction and batch prediction give consistent results.""" if not model_loaded: pytest.skip("Model not loaded") threshold = 0.3 # Single prediction single_response = client.post("/predict", json={ "features": SAMPLE_TURTLE, "threshold": threshold }) # Batch prediction with same sample batch_response = client.post("/batch_predict", json={ "features_list": [SAMPLE_TURTLE], "threshold": threshold }) assert single_response.status_code == 200 assert batch_response.status_code == 200 single_data = single_response.json() batch_data = batch_response.json() if batch_data["successful"] == 1: batch_result = batch_data["results"][0] # Predictions should be the same assert set(single_data["predictions"]) == set(batch_result["predictions"]) # ============================================================================ # CORS Tests # ============================================================================ class TestCORS: """Tests for CORS configuration.""" def test_cors_headers(self, client): """Test that CORS headers are present in responses.""" response = client.options("/predict") # Accept either 200 or 405 (method not allowed for OPTIONS in some configs) assert response.status_code in [200, 405] # ============================================================================ # Performance Tests (Basic) # ============================================================================ class TestPerformance: """Basic performance tests.""" def test_response_time_single_prediction(self, client, model_loaded): """Test that single prediction completes in reasonable time.""" if not model_loaded: pytest.skip("Model not loaded") import time start = time.time() response = client.post("/predict", json={ "features": SAMPLE_TURTLE, "threshold": 0.5 }) elapsed = time.time() - start # Should complete within 5 seconds (generous for CI environments) assert elapsed < 5.0, f"Prediction took {elapsed:.2f}s, expected < 5s" assert response.status_code == 200 def test_response_time_batch_prediction(self, client, model_loaded): """Test that batch prediction scales reasonably.""" if not model_loaded: pytest.skip("Model not loaded") import time # Create a batch of 10 samples features_list = [SAMPLE_TURTLE] * 10 start = time.time() response = client.post("/batch_predict", json={ "features_list": features_list, "threshold": 0.5 }) elapsed = time.time() - start # Should complete within 10 seconds assert elapsed < 10.0, f"Batch prediction took {elapsed:.2f}s, expected < 10s" assert response.status_code == 200 # ============================================================================ # Run tests if executed directly # ============================================================================ if __name__ == "__main__": pytest.main([__file__, "-v", "--tb=short"])