#!/bin/bash
#
# Test script for XGBoost FastAPI Multi-Label Classification API
# Uses curl to test all endpoints with realistic sample data from roofline_features.h5
#
# Usage:
#   1. Start the server: ./venv/bin/python xgb_fastapi.py
#   2. Run this script:  ./test_api_curl.sh
#
# Environment:
#   API_URL        Base URL of the API under test (default: http://localhost:8000),
#                  e.g. API_URL=http://192.168.1.100:8000 ./test_api_curl.sh
#   VERBOSE        Set to "true" to print every response body and status code.
#   CURL_MAX_TIME  Per-request timeout in seconds (default: 60), so a hung
#                  server fails a test instead of stalling the whole suite.

set -e

# Configuration
API_URL="${API_URL:-http://localhost:8000}"
VERBOSE="${VERBOSE:-false}"
CURL_MAX_TIME="${CURL_MAX_TIME:-60}"

# Colors for output. These are literal backslash escapes; they are expanded
# by `echo -e` or printf's %b conversion at print time.
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
NC='\033[0m' # No Color

# Counters
PASSED=0
FAILED=0

# ============================================================================
# Helper Functions
# ============================================================================

#######################################
# Print a blue section banner preceded by a blank line.
# Arguments: $1 - section title (printed literally)
#######################################
print_header() {
    local rule='============================================================'
    printf '\n'
    printf '%b%s%b\n' "$BLUE" "$rule" "$NC"
    printf '%b%s%b\n' "$BLUE" "$1" "$NC"
    printf '%b%s%b\n' "$BLUE" "$rule" "$NC"
}

#######################################
# Announce the start of a single test case.
# Arguments: $1 - test description (printed literally)
#######################################
print_test() {
    printf '\n%b▶ TEST: %s%b\n' "$YELLOW" "$1" "$NC"
}

#######################################
# Report a passing test and bump the PASSED counter.
# Globals:   PASSED (written)
# Arguments: $1 - message
#######################################
print_success() {
    printf '%b✓ PASSED: %s%b\n' "$GREEN" "$1" "$NC"
    # Not ((PASSED++)): post-increment from 0 evaluates to 0, which makes
    # (( )) return status 1 and `set -e` would abort the entire script.
    PASSED=$((PASSED + 1))
}

#######################################
# Report a failing test and bump the FAILED counter.
# Globals:   FAILED (written)
# Arguments: $1 - message
#######################################
print_failure() {
    printf '%b✗ FAILED: %s%b\n' "$RED" "$1" "$NC"
    # See print_success: avoid ((FAILED++)) under `set -e`.
    FAILED=$((FAILED + 1))
}

#######################################
# Make a curl request and check the HTTP status code.
# Usage: make_request METHOD ENDPOINT [DATA] [EXPECTED_STATUS]
# Globals:
#   API_URL, VERBOSE, CURL_MAX_TIME (read)
#   response, http_code, body (written; left set so callers can inspect them)
# Arguments:
#   $1 - HTTP method; GET sends no body, any other method sends $3 as JSON
#   $2 - endpoint path appended to API_URL, e.g. "/predict"
#   $3 - request body (optional; ignored for GET)
#   $4 - expected HTTP status code (default: 200)
# Outputs:
#   Body and status on stdout when VERBOSE=true or when the status mismatches.
# Returns:
#   0 if the status matched, 1 otherwise.
#######################################
make_request() {
    local method=$1
    local endpoint=$2
    local data=$3
    local expected_status=${4:-200}
    local -a curl_args=(-s -w '\n%{http_code}' --max-time "$CURL_MAX_TIME")

    if [[ "$method" != "GET" ]]; then
        curl_args+=(-X "$method" -H "Content-Type: application/json" -d "$data")
    fi

    # curl prints the response body followed by "\n<status>"; on a transport
    # failure (connection refused, timeout) the status is typically "000".
    response=$(curl "${curl_args[@]}" "${API_URL}${endpoint}")

    # Status code is everything after the last newline, body everything before.
    http_code=${response##*$'\n'}
    if [[ "$response" == *$'\n'* ]]; then
        body=${response%$'\n'*}
    else
        body=""
    fi

    if [[ "$VERBOSE" == "true" ]]; then
        printf 'Response: %s\n' "$body"
        printf 'Status: %s\n' "$http_code"
    fi

    if [[ "$http_code" == "$expected_status" ]]; then
        return 0
    fi
    printf 'Expected status %s, got %s\n' "$expected_status" "$http_code"
    printf 'Response: %s\n' "$body"
    return 1
}

# ============================================================================
# Sample Data (from roofline_features.h5)
# ============================================================================

# TurTLE application - turbulence simulation workload
SAMPLE_TURTLE='{
    "bandwidth_raw_p10": 186.33,
    "bandwidth_raw_median": 205.14,
    "bandwidth_raw_p90": 210.83,
    "bandwidth_raw_mad": 3.57,
    "bandwidth_raw_range": 24.5,
    "bandwidth_raw_iqr": 12.075,
    "flops_raw_p10": 162.024,
    "flops_raw_median": 171.45,
    "flops_raw_p90": 176.48,
    "flops_raw_mad": 3.08,
    "flops_raw_range": 14.456,
    "flops_raw_iqr": 8.29,
    "arith_intensity_p10": 0.7906,
    "arith_intensity_median": 0.837,
    "arith_intensity_p90": 0.9109,
    "arith_intensity_mad": 0.02,
    "arith_intensity_range": 0.12,
    "arith_intensity_iqr": 0.0425,
    "bw_flops_covariance": 60.86,
    "bw_flops_correlation": 0.16,
    "avg_performance_gflops": 168.1,
    "median_performance_gflops": 171.45,
    "performance_gflops_mad": 3.08,
    "avg_memory_bw_gbs": 350.0,
    "scalar_peak_gflops": 432.0,
    "simd_peak_gflops": 9216.0,
    "node_num": 0,
    "duration": 19366
}'

# Chroma application - lattice QCD workload (compute-intensive)
SAMPLE_CHROMA='{
    "bandwidth_raw_p10": 154.176,
    "bandwidth_raw_median": 200.57,
    "bandwidth_raw_p90": 259.952,
    "bandwidth_raw_mad": 5.12,
    "bandwidth_raw_range": 105.776,
    "bandwidth_raw_iqr": 10.215,
    "flops_raw_p10": 327.966,
    "flops_raw_median": 519.8,
    "flops_raw_p90": 654.422,
    "flops_raw_mad": 16.97,
    "flops_raw_range": 326.456,
    "flops_raw_iqr": 34.88,
    "arith_intensity_p10": 1.55,
    "arith_intensity_median": 2.595,
    "arith_intensity_p90": 3.445,
    "arith_intensity_mad": 0.254,
    "arith_intensity_range": 1.894,
    "arith_intensity_iqr": 0.512,
    "bw_flops_covariance": 382.76,
    "bw_flops_correlation": 0.063,
    "avg_performance_gflops": 503.26,
    "median_performance_gflops": 519.8,
    "performance_gflops_mad": 16.97,
    "avg_memory_bw_gbs": 350.0,
    "scalar_peak_gflops": 432.0,
    "simd_peak_gflops": 9216.0,
    "node_num": 3,
    "duration": 31133
}'

# Raw JSON
# NOTE(review): this physical line appears to be a whitespace-collapsed paste
# of the original script. Statements are separated by spaces instead of
# newlines, which causes two problems:
#   - The leading "roofline data (before aggregation)" is the tail of the
#     "# Raw JSON" comment from the previous line.
#   - Bash treats everything after the first word-initial `#` (right after the
#     SAMPLE_JSON_ROOFLINE assignment) as a comment, so none of the tests
#     below would run.
# The heredoc request payloads also appear to have been stripped. For example,
# `REQUEST_DATA=$(cat < 1.0) - should return 422"` no longer contains a JSON
# body or a closing delimiter.
# TODO: restore the original line breaks and heredoc contents from version
# control before running this section.
roofline data (before aggregation) SAMPLE_JSON_ROOFLINE='[ {"node_num": 1, "bandwidth_raw": 150.5, "flops_raw": 2500.0, "arith_intensity": 16.6, "performance_gflops": 1200.0, "memory_bw_gbs": 450, "scalar_peak_gflops": 600, "duration": 3600}, {"node_num": 1, "bandwidth_raw": 155.2, "flops_raw": 2600.0, "arith_intensity": 16.8, "performance_gflops": 1250.0, "memory_bw_gbs": 450, "scalar_peak_gflops": 600, "duration": 3600}, {"node_num": 1, "bandwidth_raw": 148.0, "flops_raw": 2450.0, "arith_intensity": 16.5, "performance_gflops": 1180.0, "memory_bw_gbs": 450, "scalar_peak_gflops": 600, "duration": 3600} ]' # ============================================================================ # Tests # ============================================================================ print_header "XGBoost FastAPI Test Suite" echo "Testing API at: $API_URL" echo "Started at: $(date)" # ---------------------------------------------------------------------------- # Health & Info Endpoints # ---------------------------------------------------------------------------- print_header "Health & Info Endpoints" print_test "Root Endpoint" if make_request GET "/"; then print_success "Root endpoint accessible" else print_failure "Root endpoint failed" fi print_test "Health Check" if make_request GET "/health"; then print_success "Health check passed" else print_failure "Health check failed" fi print_test "Model Info" if make_request GET "/model/info"; then print_success "Model info retrieved" else print_failure "Model info failed" fi # ---------------------------------------------------------------------------- # Single Prediction Endpoints # ---------------------------------------------------------------------------- print_header "Single Prediction Endpoint (/predict)" print_test "Predict with TurTLE sample (threshold=0.3)" REQUEST_DATA=$(cat < 1.0) - should return 422" REQUEST_DATA=$(cat <1) correctly rejected" else print_failure "Invalid threshold (>1) not rejected properly" fi print_test
"Missing features field - should return 422" REQUEST_DATA='{"threshold": 0.5}' if make_request POST "/predict" "$REQUEST_DATA" 422; then print_success "Missing features correctly rejected" else print_failure "Missing features not rejected properly" fi print_test "Invalid endpoint - should return 404" if make_request GET "/nonexistent" "" 404; then print_success "Invalid endpoint correctly returns 404" else print_failure "Invalid endpoint not handled properly" fi print_test "Invalid k value (0) - should return 422" REQUEST_DATA=$(cat <