Files
slurm-application-detection…/test_api_curl.sh
2025-12-10 12:17:41 +01:00

437 lines
12 KiB
Bash
Executable File

#!/bin/bash
#
# Test script for XGBoost FastAPI Multi-Label Classification API
# Uses curl to test all endpoints with realistic sample data from roofline_features.h5
#
# Usage:
# 1. Start the server: ./venv/bin/python xgb_fastapi.py
# 2. Run this script: ./test_api_curl.sh
#
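# Optional environment variables:
#   API_URL  Base URL of the API to test (default: http://localhost:8000), e.g.
#            API_URL=http://192.168.1.100:8000 ./test_api_curl.sh
#   VERBOSE  Set to "true" to print every response body and status code.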
# Exit on any unhandled command failure. Caveat: under `set -e` an arithmetic
# command whose expression evaluates to 0 (e.g. ((n++)) while n is 0) returns
# status 1 and aborts the script, so counters should be bumped with n=$((n + 1)).
set -e
# Configuration (override via environment variables)
API_URL="${API_URL:-http://localhost:8000}"
# Strip one trailing slash: endpoints are appended as "${API_URL}${endpoint}"
# and always start with "/", so "http://host:8000/" would yield "//health".
API_URL="${API_URL%/}"
# Set VERBOSE=true to print each response body and status code.
VERBOSE="${VERBOSE:-false}"
# ANSI color escapes, stored as literal "\033..." text and expanded at print
# time by `echo -e`.
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
NC='\033[0m' # No Color
# Test result counters
PASSED=0
FAILED=0
# ============================================================================
# Helper Functions
# ============================================================================
# Print a section banner: an empty line, then a rule, the title ($1) and a
# second rule, each rendered in blue. Backslash escapes in the colors and in
# the title are expanded (printf %b behaves like echo -e).
print_header() {
  local rule='============================================================'
  local banner_line
  printf '\n'
  for banner_line in "$rule" "$1" "$rule"; do
    printf '%b\n' "${BLUE}${banner_line}${NC}"
  done
}
# Announce a test case: an empty line followed by "▶ TEST: <name>" in yellow.
# Arguments: $1 - test name (backslash escapes are expanded, as with echo -e).
print_test() {
  local test_name=$1
  printf '\n%b\n' "${YELLOW}▶ TEST: ${test_name}${NC}"
}
# Report a passed test in green and increment the global PASSED counter.
# Arguments: $1 - message. Globals: GREEN, NC (read); PASSED (written).
# Returns 0.
# PASSED=$((PASSED + 1)) is used instead of ((PASSED++)): the post-increment
# expression evaluates to the old value, so ((PASSED++)) returns status 1 when
# PASSED is 0. Under `set -e` that aborts the script after the first passing test.
print_success() {
  echo -e "${GREEN}✓ PASSED: $1${NC}"
  PASSED=$((PASSED + 1))
}
# Report a failed test in red and increment the global FAILED counter.
# Arguments: $1 - message. Globals: RED, NC (read); FAILED (written).
# Returns 0.
# FAILED=$((FAILED + 1)) is used instead of ((FAILED++)): the latter returns
# status 1 when FAILED is 0, which aborts the whole script under `set -e`.
print_failure() {
  echo -e "${RED}✗ FAILED: $1${NC}"
  FAILED=$((FAILED + 1))
}
# Send one HTTP request to the API and check the response status code.
# Usage: make_request METHOD ENDPOINT [DATA] [EXPECTED_STATUS]
#   METHOD           HTTP method. GET is sent without a body; any other method
#                    sends DATA with "Content-Type: application/json".
#   ENDPOINT         Path appended to $API_URL (e.g. "/health").
#   DATA             Request body (ignored for GET; may be "").
#   EXPECTED_STATUS  HTTP status code to expect (default: 200).
# Globals read: API_URL, VERBOSE, and optionally CURL_MAX_TIME (per-request
# timeout in seconds, default 60) so a hung server cannot stall the suite.
# Outputs: body and status when VERBOSE=true; details on a status mismatch.
# Returns: 0 if the status matches EXPECTED_STATUS, 1 otherwise.
make_request() {
  local method=$1
  local endpoint=$2
  local data=${3-}
  local expected_status=${4:-200}
  local url="${API_URL}${endpoint}"

  # -w appends "\n<status code>" after the body so the two can be split below.
  local -a curl_args=(-s --max-time "${CURL_MAX_TIME:-60}" -w '\n%{http_code}')
  if [[ "$method" != "GET" ]]; then
    curl_args+=(-X "$method" -H "Content-Type: application/json" -d "$data")
  fi

  local response
  local curl_rc=0
  # Capture curl's exit status explicitly: a connection failure is reported
  # below instead of tripping `set -e` if this is ever called outside an `if`.
  response=$(curl "${curl_args[@]}" "$url") || curl_rc=$?

  # Status code = text after the last newline; body = everything before it.
  local http_code=${response##*$'\n'}
  local body=${response%$'\n'*}

  if [[ "$VERBOSE" == "true" ]]; then
    echo "Response: $body"
    echo "Status: $http_code"
  fi

  if [[ "$http_code" == "$expected_status" ]]; then
    return 0
  fi
  if (( curl_rc != 0 )); then
    echo "curl exited with status $curl_rc for $method $url"
  fi
  echo "Expected status $expected_status, got $http_code"
  echo "Response: $body"
  return 1
}
# ============================================================================
# Sample Data (from roofline_features.h5)
# ============================================================================
# Each SAMPLE_* variable holds a JSON literal that the tests below splice into
# request bodies. SAMPLE_TURTLE and SAMPLE_CHROMA are flat feature objects used
# as the "features" value for /predict and /predict_top_k, and as elements of
# "features_list" for /batch_predict. Single quotes keep the shell from
# expanding anything inside the literals.
# TurTLE application - turbulence simulation workload
SAMPLE_TURTLE='{
"bandwidth_raw_p10": 186.33,
"bandwidth_raw_median": 205.14,
"bandwidth_raw_p90": 210.83,
"bandwidth_raw_mad": 3.57,
"bandwidth_raw_range": 24.5,
"bandwidth_raw_iqr": 12.075,
"flops_raw_p10": 162.024,
"flops_raw_median": 171.45,
"flops_raw_p90": 176.48,
"flops_raw_mad": 3.08,
"flops_raw_range": 14.456,
"flops_raw_iqr": 8.29,
"arith_intensity_p10": 0.7906,
"arith_intensity_median": 0.837,
"arith_intensity_p90": 0.9109,
"arith_intensity_mad": 0.02,
"arith_intensity_range": 0.12,
"arith_intensity_iqr": 0.0425,
"bw_flops_covariance": 60.86,
"bw_flops_correlation": 0.16,
"avg_performance_gflops": 168.1,
"median_performance_gflops": 171.45,
"performance_gflops_mad": 3.08,
"avg_memory_bw_gbs": 350.0,
"scalar_peak_gflops": 432.0,
"simd_peak_gflops": 9216.0,
"node_num": 0,
"duration": 19366
}'
# Chroma application - lattice QCD workload (compute-intensive)
SAMPLE_CHROMA='{
"bandwidth_raw_p10": 154.176,
"bandwidth_raw_median": 200.57,
"bandwidth_raw_p90": 259.952,
"bandwidth_raw_mad": 5.12,
"bandwidth_raw_range": 105.776,
"bandwidth_raw_iqr": 10.215,
"flops_raw_p10": 327.966,
"flops_raw_median": 519.8,
"flops_raw_p90": 654.422,
"flops_raw_mad": 16.97,
"flops_raw_range": 326.456,
"flops_raw_iqr": 34.88,
"arith_intensity_p10": 1.55,
"arith_intensity_median": 2.595,
"arith_intensity_p90": 3.445,
"arith_intensity_mad": 0.254,
"arith_intensity_range": 1.894,
"arith_intensity_iqr": 0.512,
"bw_flops_covariance": 382.76,
"bw_flops_correlation": 0.063,
"avg_performance_gflops": 503.26,
"median_performance_gflops": 519.8,
"performance_gflops_mad": 16.97,
"avg_memory_bw_gbs": 350.0,
"scalar_peak_gflops": 432.0,
"simd_peak_gflops": 9216.0,
"node_num": 3,
"duration": 31133
}'
# Raw JSON roofline data (before aggregation)
# A JSON array of per-record measurements. It is not spliced in as an object:
# the /predict test escapes it and sends it as a JSON *string* in the
# "features" field, together with "is_json": true.
SAMPLE_JSON_ROOFLINE='[
{"node_num": 1, "bandwidth_raw": 150.5, "flops_raw": 2500.0, "arith_intensity": 16.6, "performance_gflops": 1200.0, "memory_bw_gbs": 450, "scalar_peak_gflops": 600, "duration": 3600},
{"node_num": 1, "bandwidth_raw": 155.2, "flops_raw": 2600.0, "arith_intensity": 16.8, "performance_gflops": 1250.0, "memory_bw_gbs": 450, "scalar_peak_gflops": 600, "duration": 3600},
{"node_num": 1, "bandwidth_raw": 148.0, "flops_raw": 2450.0, "arith_intensity": 16.5, "performance_gflops": 1180.0, "memory_bw_gbs": 450, "scalar_peak_gflops": 600, "duration": 3600}
]'
# ============================================================================
# Tests
# ============================================================================
print_header "XGBoost FastAPI Test Suite"
printf '%s\n' "Testing API at: $API_URL" "Started at: $(date)"
# ----------------------------------------------------------------------------
# Health & Info Endpoints
# ----------------------------------------------------------------------------
print_header "Health & Info Endpoints"
# One GET check per entry; fields are separated by "|":
#   test title | endpoint | success message | failure message
for health_check in \
  "Root Endpoint|/|Root endpoint accessible|Root endpoint failed" \
  "Health Check|/health|Health check passed|Health check failed" \
  "Model Info|/model/info|Model info retrieved|Model info failed"; do
  IFS='|' read -r check_title check_endpoint check_ok_msg check_fail_msg <<<"$health_check"
  print_test "$check_title"
  if make_request GET "$check_endpoint"; then
    print_success "$check_ok_msg"
  else
    print_failure "$check_fail_msg"
  fi
done
# ----------------------------------------------------------------------------
# Single Prediction Endpoints
# ----------------------------------------------------------------------------
print_header "Single Prediction Endpoint (/predict)"

print_test "Predict with TurTLE sample (threshold=0.3)"
REQUEST_DATA=$(cat <<EOF
{
"features": $SAMPLE_TURTLE,
"threshold": 0.3,
"return_all_probabilities": true
}
EOF
)
if make_request POST "/predict" "$REQUEST_DATA"; then
  print_success "TurTLE prediction completed"
else
  print_failure "TurTLE prediction failed"
fi

print_test "Predict with Chroma sample (threshold=0.5)"
REQUEST_DATA=$(cat <<EOF
{
"features": $SAMPLE_CHROMA,
"threshold": 0.5
}
EOF
)
if make_request POST "/predict" "$REQUEST_DATA"; then
  print_success "Chroma prediction completed"
else
  print_failure "Chroma prediction failed"
fi

print_test "Predict with JSON roofline data (is_json=true)"
# Here "features" is a JSON *string* that contains the raw roofline array, so
# the array must be encoded as a JSON string literal. Strip CR/LF, then escape
# backslashes BEFORE double quotes; otherwise the backslashes added for the
# quotes would be doubled again. Tabs and other control characters are not
# escaped, so keep SAMPLE_JSON_ROOFLINE free of them.
ESCAPED_JSON=$(printf '%s' "$SAMPLE_JSON_ROOFLINE" | tr -d '\r\n' | sed -e 's/\\/\\\\/g' -e 's/"/\\"/g')
REQUEST_DATA=$(cat <<EOF
{
"features": "$ESCAPED_JSON",
"is_json": true,
"job_id": "test_job_curl_001",
"threshold": 0.3
}
EOF
)
if make_request POST "/predict" "$REQUEST_DATA"; then
  print_success "JSON roofline prediction completed"
else
  print_failure "JSON roofline prediction failed"
fi

print_test "Predict with return_all_probabilities=false"
REQUEST_DATA=$(cat <<EOF
{
"features": $SAMPLE_TURTLE,
"threshold": 0.3,
"return_all_probabilities": false
}
EOF
)
if make_request POST "/predict" "$REQUEST_DATA"; then
  print_success "Prediction with filtered probabilities completed"
else
  print_failure "Prediction with filtered probabilities failed"
fi
# ----------------------------------------------------------------------------
# Top-K Prediction Endpoints
# ----------------------------------------------------------------------------
print_header "Top-K Prediction Endpoint (/predict_top_k)"

# No "k" field is sent; the title assumes the server default is k=5.
print_test "Top-5 predictions (default)"
printf -v REQUEST_DATA '{\n"features": %s\n}' "$SAMPLE_TURTLE"
if make_request POST "/predict_top_k" "$REQUEST_DATA"; then
  print_success "Top-5 prediction completed"
else
  print_failure "Top-5 prediction failed"
fi

print_test "Top-10 predictions"
printf -v REQUEST_DATA '{\n"features": %s,\n"k": 10\n}' "$SAMPLE_CHROMA"
if make_request POST "/predict_top_k" "$REQUEST_DATA"; then
  print_success "Top-10 prediction completed"
else
  print_failure "Top-10 prediction failed"
fi

print_test "Top-3 predictions"
printf -v REQUEST_DATA '{\n"features": %s,\n"k": 3\n}' "$SAMPLE_TURTLE"
if make_request POST "/predict_top_k" "$REQUEST_DATA"; then
  print_success "Top-3 prediction completed"
else
  print_failure "Top-3 prediction failed"
fi
# ----------------------------------------------------------------------------
# Batch Prediction Endpoints
# ----------------------------------------------------------------------------
print_header "Batch Prediction Endpoint (/batch_predict)"

print_test "Batch predict with 2 samples"
printf -v REQUEST_DATA '{\n"features_list": [%s, %s],\n"threshold": 0.3\n}' \
  "$SAMPLE_TURTLE" "$SAMPLE_CHROMA"
if make_request POST "/batch_predict" "$REQUEST_DATA"; then
  print_success "Batch prediction (2 samples) completed"
else
  print_failure "Batch prediction (2 samples) failed"
fi

print_test "Batch predict with single sample"
printf -v REQUEST_DATA '{\n"features_list": [%s],\n"threshold": 0.5\n}' "$SAMPLE_TURTLE"
if make_request POST "/batch_predict" "$REQUEST_DATA"; then
  print_success "Batch prediction (single sample) completed"
else
  print_failure "Batch prediction (single sample) failed"
fi

# An empty list is expected to be accepted with HTTP 200.
print_test "Batch predict with empty list"
REQUEST_DATA='{"features_list": [], "threshold": 0.5}'
if make_request POST "/batch_predict" "$REQUEST_DATA"; then
  print_success "Batch prediction (empty list) completed"
else
  print_failure "Batch prediction (empty list) failed"
fi
# ----------------------------------------------------------------------------
# Error Handling Tests
# ----------------------------------------------------------------------------
# Each case passes the expected non-200 status as make_request's 4th argument.
print_header "Error Handling Tests"

print_test "Invalid threshold (negative) - should return 422"
printf -v REQUEST_DATA '{\n"features": %s,\n"threshold": -0.5\n}' "$SAMPLE_TURTLE"
if ! make_request POST "/predict" "$REQUEST_DATA" 422; then
  print_failure "Invalid threshold not rejected properly"
else
  print_success "Invalid threshold correctly rejected"
fi

print_test "Invalid threshold (> 1.0) - should return 422"
printf -v REQUEST_DATA '{\n"features": %s,\n"threshold": 1.5\n}' "$SAMPLE_TURTLE"
if ! make_request POST "/predict" "$REQUEST_DATA" 422; then
  print_failure "Invalid threshold (>1) not rejected properly"
else
  print_success "Invalid threshold (>1) correctly rejected"
fi

print_test "Missing features field - should return 422"
REQUEST_DATA='{"threshold": 0.5}'
if ! make_request POST "/predict" "$REQUEST_DATA" 422; then
  print_failure "Missing features not rejected properly"
else
  print_success "Missing features correctly rejected"
fi

print_test "Invalid endpoint - should return 404"
if ! make_request GET "/nonexistent" "" 404; then
  print_failure "Invalid endpoint not handled properly"
else
  print_success "Invalid endpoint correctly returns 404"
fi

print_test "Invalid k value (0) - should return 422"
printf -v REQUEST_DATA '{\n"features": %s,\n"k": 0\n}' "$SAMPLE_TURTLE"
if ! make_request POST "/predict_top_k" "$REQUEST_DATA" 422; then
  print_failure "Invalid k value not rejected properly"
else
  print_success "Invalid k value correctly rejected"
fi
# ============================================================================
# Summary
# ============================================================================
# Print the pass/fail counts; exit 1 if any test failed, 0 otherwise.
print_header "Test Summary"
printf '\n%b\n%b\n\n' "Passed: ${GREEN}${PASSED}${NC}" "Failed: ${RED}${FAILED}${NC}"
if (( FAILED != 0 )); then
  printf '%b\n' "${RED}Some tests failed! ✗${NC}"
  exit 1
fi
printf '%b\n' "${GREEN}All tests passed! ✓${NC}"
exit 0