#!/bin/bash
#
# Test script for XGBoost FastAPI Multi-Label Classification API
# Uses curl to test all endpoints with realistic sample data from roofline_features.h5
#
# Usage:
#   1. Start the server: ./venv/bin/python xgb_fastapi.py
#   2. Run this script:  ./test_api_curl.sh
#
# Optional: Set the API_URL environment variable to test a different host
#   API_URL=http://192.168.1.100:8000 ./test_api_curl.sh
#
# Optional: Set VERBOSE=true to print the body and status code of every response
#   VERBOSE=true ./test_api_curl.sh

# Stop at the first command that fails outside of a condition.
set -o errexit

# Configuration: both settings keep a non-empty value inherited from the
# environment and otherwise fall back to the default shown here.
: "${API_URL:=http://localhost:8000}"
: "${VERBOSE:=false}"

# Colors for output. The escape sequences are stored unexpanded and are
# rendered when printed with `echo -e`.
NC='\033[0m' # No Color
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'

# Counters updated by print_success / print_failure.
FAILED=0
PASSED=0

# ============================================================================
# Helper Functions
# ============================================================================

# Print a section banner: an empty line, then the title between two rules,
# each line wrapped in the BLUE / NC color sequences.
# Arguments: $1 - banner title (backslash escapes in it are interpreted)
# Globals:   BLUE, NC (read)
print_header() {
  local rule='============================================================'

  printf '\n'
  printf '%b\n' "${BLUE}${rule}${NC}"
  printf '%b\n' "${BLUE}$1${NC}"
  printf '%b\n' "${BLUE}${rule}${NC}"
}

# Announce the test that is about to run: an empty line followed by the
# test name wrapped in the YELLOW / NC color sequences.
# Arguments: $1 - test name (backslash escapes in it are interpreted)
# Globals:   YELLOW, NC (read)
print_test() {
  printf '%b\n' "\n${YELLOW}▶ TEST: $1${NC}"
}

# Report a passed test and count it.
# Arguments: $1 - message printed after "✓ PASSED: "
# Globals:   GREEN, NC (read); PASSED (incremented)
# Returns:   0
print_success() {
  echo -e "${GREEN}✓ PASSED: $1${NC}"
  # Use an arithmetic assignment rather than ((PASSED++)): the post-increment
  # form evaluates to the old value, so it returns status 1 while PASSED is 0
  # and would end the whole script under `set -e` on the first passing test.
  PASSED=$((PASSED + 1))
}

# Report a failed test and count it.
# Arguments: $1 - message printed after "✗ FAILED: "
# Globals:   RED, NC (read); FAILED (incremented)
# Returns:   0
print_failure() {
  echo -e "${RED}✗ FAILED: $1${NC}"
  # Use an arithmetic assignment rather than ((FAILED++)): the post-increment
  # form evaluates to the old value, so it returns status 1 while FAILED is 0
  # and would end the whole script under `set -e` on the first failing test,
  # before the summary is printed.
  FAILED=$((FAILED + 1))
}

# Make a curl request and check status code
# Usage: make_request METHOD ENDPOINT [DATA] [EXPECTED_STATUS]
#
# Arguments:
#   METHOD           "GET" sends a plain request; any other value is passed to
#                    curl via -X together with DATA as a JSON body
#   ENDPOINT         path appended to API_URL
#   DATA             request body (optional; not sent for GET)
#   EXPECTED_STATUS  HTTP status code that counts as success (default: 200)
# Globals:
#   API_URL, VERBOSE (read)
#   http_code (written) - last line of curl's output, i.e. the status code
#   body      (written) - curl's output without that last line
# Outputs:
#   With VERBOSE=true the body and status are always printed; on a status
#   mismatch the expected/actual status and the body are printed to stdout.
# Returns:
#   0 if the status code equals EXPECTED_STATUS, 1 otherwise
make_request() {
  local method=${1:-}
  local endpoint=${2:-}
  local data=${3:-}
  local expected_status=${4:-200}
  local response
  local curl_status=0

  # curl appends "\n<status>" to the body so both arrive in one capture.
  # The exit status is recorded (not acted on) so that a transport failure
  # can be named in the diagnostics below.
  if [[ "$method" == "GET" ]]; then
    response=$(curl -s -w "\n%{http_code}" "${API_URL}${endpoint}") \
      || curl_status=$?
  else
    response=$(curl -s -w "\n%{http_code}" -X "$method" \
      -H "Content-Type: application/json" \
      -d "$data" \
      "${API_URL}${endpoint}") || curl_status=$?
  fi

  # Extract status code (last line) and body (everything else)
  http_code=$(printf '%s\n' "$response" | tail -n1)
  body=$(printf '%s\n' "$response" | sed '$d')

  if [[ "${VERBOSE:-}" == "true" ]]; then
    echo "Response: $body"
    echo "Status: $http_code"
  fi

  if [[ "$http_code" == "$expected_status" ]]; then
    return 0
  fi

  # A non-zero curl exit (e.g. connection refused) otherwise only shows up
  # as an uninformative status such as 000.
  if ((curl_status != 0)); then
    echo "curl exited with status $curl_status (is the API reachable at ${API_URL}?)"
  fi
  echo "Expected status $expected_status, got $http_code"
  echo "Response: $body"
  return 1
}

# ============================================================================
# Sample Data (from roofline_features.h5)
# ============================================================================

# Each sample is kept as a literal JSON text (quoted here-doc, no expansion)
# so it can be spliced verbatim into request bodies.

# TurTLE application - turbulence simulation workload
SAMPLE_TURTLE=$(
  cat <<'JSON'
{
"bandwidth_raw_p10": 186.33,
"bandwidth_raw_median": 205.14,
"bandwidth_raw_p90": 210.83,
"bandwidth_raw_mad": 3.57,
"bandwidth_raw_range": 24.5,
"bandwidth_raw_iqr": 12.075,
"flops_raw_p10": 162.024,
"flops_raw_median": 171.45,
"flops_raw_p90": 176.48,
"flops_raw_mad": 3.08,
"flops_raw_range": 14.456,
"flops_raw_iqr": 8.29,
"arith_intensity_p10": 0.7906,
"arith_intensity_median": 0.837,
"arith_intensity_p90": 0.9109,
"arith_intensity_mad": 0.02,
"arith_intensity_range": 0.12,
"arith_intensity_iqr": 0.0425,
"bw_flops_covariance": 60.86,
"bw_flops_correlation": 0.16,
"avg_performance_gflops": 168.1,
"median_performance_gflops": 171.45,
"performance_gflops_mad": 3.08,
"avg_memory_bw_gbs": 350.0,
"scalar_peak_gflops": 432.0,
"simd_peak_gflops": 9216.0,
"node_num": 0,
"duration": 19366
}
JSON
)

# Chroma application - lattice QCD workload (compute-intensive)
SAMPLE_CHROMA=$(
  cat <<'JSON'
{
"bandwidth_raw_p10": 154.176,
"bandwidth_raw_median": 200.57,
"bandwidth_raw_p90": 259.952,
"bandwidth_raw_mad": 5.12,
"bandwidth_raw_range": 105.776,
"bandwidth_raw_iqr": 10.215,
"flops_raw_p10": 327.966,
"flops_raw_median": 519.8,
"flops_raw_p90": 654.422,
"flops_raw_mad": 16.97,
"flops_raw_range": 326.456,
"flops_raw_iqr": 34.88,
"arith_intensity_p10": 1.55,
"arith_intensity_median": 2.595,
"arith_intensity_p90": 3.445,
"arith_intensity_mad": 0.254,
"arith_intensity_range": 1.894,
"arith_intensity_iqr": 0.512,
"bw_flops_covariance": 382.76,
"bw_flops_correlation": 0.063,
"avg_performance_gflops": 503.26,
"median_performance_gflops": 519.8,
"performance_gflops_mad": 16.97,
"avg_memory_bw_gbs": 350.0,
"scalar_peak_gflops": 432.0,
"simd_peak_gflops": 9216.0,
"node_num": 3,
"duration": 31133
}
JSON
)

# Raw JSON roofline data (before aggregation)
SAMPLE_JSON_ROOFLINE=$(
  cat <<'JSON'
[
{"node_num": 1, "bandwidth_raw": 150.5, "flops_raw": 2500.0, "arith_intensity": 16.6, "performance_gflops": 1200.0, "memory_bw_gbs": 450, "scalar_peak_gflops": 600, "duration": 3600},
{"node_num": 1, "bandwidth_raw": 155.2, "flops_raw": 2600.0, "arith_intensity": 16.8, "performance_gflops": 1250.0, "memory_bw_gbs": 450, "scalar_peak_gflops": 600, "duration": 3600},
{"node_num": 1, "bandwidth_raw": 148.0, "flops_raw": 2450.0, "arith_intensity": 16.5, "performance_gflops": 1180.0, "memory_bw_gbs": 450, "scalar_peak_gflops": 600, "duration": 3600}
]
JSON
)

# ============================================================================
# Tests
# ============================================================================

# Opening banner: which API is being exercised and when the run began.
print_header "XGBoost FastAPI Test Suite"
printf '%s\n' "Testing API at: $API_URL"
printf '%s\n' "Started at: $(date)"

# ----------------------------------------------------------------------------
# Health & Info Endpoints
# ----------------------------------------------------------------------------

print_header "Health & Info Endpoints"

# Four fields per check: test name, endpoint, success message, failure message.
info_checks=(
  "Root Endpoint" "/" "Root endpoint accessible" "Root endpoint failed"
  "Health Check" "/health" "Health check passed" "Health check failed"
  "Model Info" "/model/info" "Model info retrieved" "Model info failed"
)

# Walk the table one row (four fields) at a time, issuing a GET per row.
for ((i = 0; i < ${#info_checks[@]}; i += 4)); do
  print_test "${info_checks[i]}"
  if make_request GET "${info_checks[i + 1]}"; then
    print_success "${info_checks[i + 2]}"
  else
    print_failure "${info_checks[i + 3]}"
  fi
done

# ----------------------------------------------------------------------------
# Single Prediction Endpoints
# ----------------------------------------------------------------------------

print_header "Single Prediction Endpoint (/predict)"

# POST the current REQUEST_DATA to /predict and record the outcome.
#   $1 - message reported on success
#   $2 - message reported on failure
check_predict() {
  if make_request POST "/predict" "$REQUEST_DATA"; then
    print_success "$1"
  else
    print_failure "$2"
  fi
}

print_test "Predict with TurTLE sample (threshold=0.3)"
REQUEST_DATA=$(
  cat <<EOF
{
"features": $SAMPLE_TURTLE,
"threshold": 0.3,
"return_all_probabilities": true
}
EOF
)
check_predict "TurTLE prediction completed" "TurTLE prediction failed"

print_test "Predict with Chroma sample (threshold=0.5)"
REQUEST_DATA=$(
  cat <<EOF
{
"features": $SAMPLE_CHROMA,
"threshold": 0.5
}
EOF
)
check_predict "Chroma prediction completed" "Chroma prediction failed"

print_test "Predict with JSON roofline data (is_json=true)"
# The roofline JSON is sent as a string value, so it has to be escaped for
# embedding: drop the newlines, then backslash-escape every double quote.
ESCAPED_JSON=${SAMPLE_JSON_ROOFLINE//$'\n'/}
ESCAPED_JSON=${ESCAPED_JSON//\"/\\\"}
REQUEST_DATA=$(
  cat <<EOF
{
"features": "$ESCAPED_JSON",
"is_json": true,
"job_id": "test_job_curl_001",
"threshold": 0.3
}
EOF
)
check_predict "JSON roofline prediction completed" "JSON roofline prediction failed"

print_test "Predict with return_all_probabilities=false"
REQUEST_DATA=$(
  cat <<EOF
{
"features": $SAMPLE_TURTLE,
"threshold": 0.3,
"return_all_probabilities": false
}
EOF
)
check_predict "Prediction with filtered probabilities completed" \
  "Prediction with filtered probabilities failed"

# ----------------------------------------------------------------------------
# Top-K Prediction Endpoints
# ----------------------------------------------------------------------------

print_header "Top-K Prediction Endpoint (/predict_top_k)"

# POST the current REQUEST_DATA to /predict_top_k and record the outcome.
#   $1 - message reported on success
#   $2 - message reported on failure
check_top_k() {
  if make_request POST "/predict_top_k" "$REQUEST_DATA"; then
    print_success "$1"
  else
    print_failure "$2"
  fi
}

# No "k" field in the body: the test name calls this the default top-5 case.
print_test "Top-5 predictions (default)"
REQUEST_DATA=$(
  cat <<EOF
{
"features": $SAMPLE_TURTLE
}
EOF
)
check_top_k "Top-5 prediction completed" "Top-5 prediction failed"

print_test "Top-10 predictions"
REQUEST_DATA=$(
  cat <<EOF
{
"features": $SAMPLE_CHROMA,
"k": 10
}
EOF
)
check_top_k "Top-10 prediction completed" "Top-10 prediction failed"

print_test "Top-3 predictions"
REQUEST_DATA=$(
  cat <<EOF
{
"features": $SAMPLE_TURTLE,
"k": 3
}
EOF
)
check_top_k "Top-3 prediction completed" "Top-3 prediction failed"

# ----------------------------------------------------------------------------
# Batch Prediction Endpoints
# ----------------------------------------------------------------------------

print_header "Batch Prediction Endpoint (/batch_predict)"

# POST the current REQUEST_DATA to /batch_predict and record the outcome.
#   $1 - label inserted into the messages, e.g. "2 samples"
check_batch() {
  if make_request POST "/batch_predict" "$REQUEST_DATA"; then
    print_success "Batch prediction ($1) completed"
  else
    print_failure "Batch prediction ($1) failed"
  fi
}

print_test "Batch predict with 2 samples"
REQUEST_DATA=$(
  cat <<EOF
{
"features_list": [$SAMPLE_TURTLE, $SAMPLE_CHROMA],
"threshold": 0.3
}
EOF
)
check_batch "2 samples"

print_test "Batch predict with single sample"
REQUEST_DATA=$(
  cat <<EOF
{
"features_list": [$SAMPLE_TURTLE],
"threshold": 0.5
}
EOF
)
check_batch "single sample"

print_test "Batch predict with empty list"
REQUEST_DATA='{"features_list": [], "threshold": 0.5}'
check_batch "empty list"

# ----------------------------------------------------------------------------
# Error Handling Tests
# ----------------------------------------------------------------------------

print_header "Error Handling Tests"

# Send a request that must be answered with a specific status code.
#   $1 - HTTP method          $2 - endpoint
#   $3 - request body         $4 - expected status code
#   $5 - message on success   $6 - message on failure
expect_status() {
  if make_request "$1" "$2" "$3" "$4"; then
    print_success "$5"
  else
    print_failure "$6"
  fi
}

print_test "Invalid threshold (negative) - should return 422"
REQUEST_DATA=$(
  cat <<EOF
{
"features": $SAMPLE_TURTLE,
"threshold": -0.5
}
EOF
)
expect_status POST "/predict" "$REQUEST_DATA" 422 \
  "Invalid threshold correctly rejected" \
  "Invalid threshold not rejected properly"

print_test "Invalid threshold (> 1.0) - should return 422"
REQUEST_DATA=$(
  cat <<EOF
{
"features": $SAMPLE_TURTLE,
"threshold": 1.5
}
EOF
)
expect_status POST "/predict" "$REQUEST_DATA" 422 \
  "Invalid threshold (>1) correctly rejected" \
  "Invalid threshold (>1) not rejected properly"

print_test "Missing features field - should return 422"
REQUEST_DATA='{"threshold": 0.5}'
expect_status POST "/predict" "$REQUEST_DATA" 422 \
  "Missing features correctly rejected" \
  "Missing features not rejected properly"

# GET request: the body argument is passed as an empty string.
print_test "Invalid endpoint - should return 404"
expect_status GET "/nonexistent" "" 404 \
  "Invalid endpoint correctly returns 404" \
  "Invalid endpoint not handled properly"

print_test "Invalid k value (0) - should return 422"
REQUEST_DATA=$(
  cat <<EOF
{
"features": $SAMPLE_TURTLE,
"k": 0
}
EOF
)
expect_status POST "/predict_top_k" "$REQUEST_DATA" 422 \
  "Invalid k value correctly rejected" \
  "Invalid k value not rejected properly"

# ============================================================================
# Summary
# ============================================================================

print_header "Test Summary"
printf '\n'
printf '%b\n' "Passed: ${GREEN}$PASSED${NC}"
printf '%b\n' "Failed: ${RED}$FAILED${NC}"
printf '\n'

# Choose the closing line and the exit status together: the script exits 0
# only when no failure was counted.
if [ "$FAILED" -eq 0 ]; then
  verdict="${GREEN}All tests passed! ✓${NC}"
  exit_status=0
else
  verdict="${RED}Some tests failed! ✗${NC}"
  exit_status=1
fi

printf '%b\n' "$verdict"
exit "$exit_status"