Add unit test for NATS API
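For orientation: the tests added below call the NatsAPI handlers directly with raw payloads, whereas in a running cc-backend these handlers would be fed by NATS subscriptions. The following is only a minimal sketch of such wiring using the standard nats.go client; the subject name and the stand-in handler are assumptions for illustration and are not part of this commit.

package main

import (
    "log"

    "github.com/nats-io/nats.go"
)

// handleJobEvent stands in for the NatsAPI method exercised by the tests below;
// it receives the NATS subject and the raw message bytes.
func handleJobEvent(subject string, data []byte) {
    log.Printf("received %d bytes on %q", len(data), subject)
}

func main() {
    nc, err := nats.Connect(nats.DefaultURL)
    if err != nil {
        log.Fatal(err)
    }
    defer nc.Drain()

    // The subject name is assumed for this sketch only.
    if _, err := nc.Subscribe("cc.job.events", func(m *nats.Msg) {
        handleJobEvent(m.Subject, m.Data)
    }); err != nil {
        log.Fatal(err)
    }

    select {} // block; a real service ties this into its shutdown logic
}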
@@ -36,6 +36,8 @@ import (
)

func setup(t *testing.T) *api.RestAPI {
    repository.ResetConnection()

    const testconfig = `{
    "main": {
        "addr": "0.0.0.0:8080",
@@ -190,11 +192,9 @@ func setup(t *testing.T) *api.RestAPI {
}

func cleanup() {
    // Gracefully shutdown archiver with timeout
    if err := archiver.Shutdown(5 * time.Second); err != nil {
        cclog.Warnf("Archiver shutdown timeout in tests: %v", err)
    }
    // TODO: Clear all caches, reset all modules, etc...
}

/*
@@ -230,7 +230,7 @@ func TestRestApi(t *testing.T) {
    r.StrictSlash(true)
    restapi.MountAPIRoutes(r)

-   var TestJobId int64 = 123
+   var TestJobID int64 = 123
    TestClusterName := "testcluster"
    var TestStartTime int64 = 123456789

@@ -280,7 +280,7 @@ func TestRestApi(t *testing.T) {
    }
    // resolver := graph.GetResolverInstance()
    restapi.JobRepository.SyncJobs()
-   job, err := restapi.JobRepository.Find(&TestJobId, &TestClusterName, &TestStartTime)
+   job, err := restapi.JobRepository.Find(&TestJobID, &TestClusterName, &TestStartTime)
    if err != nil {
        t.Fatal(err)
    }
@@ -338,7 +338,7 @@ func TestRestApi(t *testing.T) {
    }

    // Archiving happens asynchronously, will be completed in cleanup
-   job, err := restapi.JobRepository.Find(&TestJobId, &TestClusterName, &TestStartTime)
+   job, err := restapi.JobRepository.Find(&TestJobID, &TestClusterName, &TestStartTime)
    if err != nil {
        t.Fatal(err)
    }

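The cleanup change above assumes archiver.Shutdown accepts a timeout and returns an error when pending archiving work does not finish in time. A common Go pattern for such a bounded wait, given here only as a sketch (the real archiver in cc-backend may be structured differently), races a WaitGroup against a timer:

package main

import (
    "errors"
    "fmt"
    "sync"
    "time"
)

var wg sync.WaitGroup // tracks in-flight archiving goroutines

// shutdown waits for all pending work, but gives up after the given timeout.
// This mirrors the archiver.Shutdown(5 * time.Second) call used in the tests;
// the actual archiver implementation may differ.
func shutdown(timeout time.Duration) error {
    done := make(chan struct{})
    go func() {
        wg.Wait()
        close(done)
    }()

    select {
    case <-done:
        return nil
    case <-time.After(timeout):
        return errors.New("shutdown timed out")
    }
}

func main() {
    wg.Add(1)
    go func() {
        defer wg.Done()
        time.Sleep(100 * time.Millisecond) // pretend to archive a job
    }()

    if err := shutdown(5 * time.Second); err != nil {
        fmt.Println("warning:", err)
    }
}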
internal/api/nats_test.go (new file, 892 lines)
@@ -0,0 +1,892 @@
// Copyright (C) NHR@FAU, University Erlangen-Nuremberg.
// All rights reserved. This file is part of cc-backend.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package api

import (
    "context"
    "database/sql"
    "encoding/json"
    "fmt"
    "os"
    "path/filepath"
    "testing"
    "time"

    "github.com/ClusterCockpit/cc-backend/internal/archiver"
    "github.com/ClusterCockpit/cc-backend/internal/auth"
    "github.com/ClusterCockpit/cc-backend/internal/config"
    "github.com/ClusterCockpit/cc-backend/internal/graph"
    "github.com/ClusterCockpit/cc-backend/internal/metricdata"
    "github.com/ClusterCockpit/cc-backend/internal/repository"
    "github.com/ClusterCockpit/cc-backend/pkg/archive"
    ccconf "github.com/ClusterCockpit/cc-lib/ccConfig"
    cclog "github.com/ClusterCockpit/cc-lib/ccLogger"
    lp "github.com/ClusterCockpit/cc-lib/ccMessage"
    "github.com/ClusterCockpit/cc-lib/schema"

    _ "github.com/mattn/go-sqlite3"
)

func setupNatsTest(t *testing.T) *NatsAPI {
    repository.ResetConnection()

    const testconfig = `{
    "main": {
        "addr": "0.0.0.0:8080",
        "validate": false,
        "apiAllowedIPs": [
            "*"
        ]
    },
    "archive": {
        "kind": "file",
        "path": "./var/job-archive"
    },
    "auth": {
        "jwts": {
            "max-age": "2m"
        }
    },
    "clusters": [
        {
            "name": "testcluster",
            "metricDataRepository": {"kind": "test", "url": "bla:8081"},
            "filterRanges": {
                "numNodes": { "from": 1, "to": 64 },
                "duration": { "from": 0, "to": 86400 },
                "startTime": { "from": "2022-01-01T00:00:00Z", "to": null }
            }
        }
    ]
    }`
    const testclusterJSON = `{
    "name": "testcluster",
    "subClusters": [
        {
            "name": "sc1",
            "nodes": "host123,host124,host125",
            "processorType": "Intel Core i7-4770",
            "socketsPerNode": 1,
            "coresPerSocket": 4,
            "threadsPerCore": 2,
            "flopRateScalar": {
                "unit": {
                    "prefix": "G",
                    "base": "F/s"
                },
                "value": 14
            },
            "flopRateSimd": {
                "unit": {
                    "prefix": "G",
                    "base": "F/s"
                },
                "value": 112
            },
            "memoryBandwidth": {
                "unit": {
                    "prefix": "G",
                    "base": "B/s"
                },
                "value": 24
            },
            "numberOfNodes": 70,
            "topology": {
                "node": [0, 1, 2, 3, 4, 5, 6, 7],
                "socket": [[0, 1, 2, 3, 4, 5, 6, 7]],
                "memoryDomain": [[0, 1, 2, 3, 4, 5, 6, 7]],
                "die": [[0, 1, 2, 3, 4, 5, 6, 7]],
                "core": [[0], [1], [2], [3], [4], [5], [6], [7]]
            }
        }
    ],
    "metricConfig": [
        {
            "name": "load_one",
            "unit": { "base": ""},
            "scope": "node",
            "timestep": 60,
            "aggregation": "avg",
            "peak": 8,
            "normal": 0,
            "caution": 0,
            "alert": 0
        }
    ]
    }`

    cclog.Init("info", true)
    tmpdir := t.TempDir()
    jobarchive := filepath.Join(tmpdir, "job-archive")
    if err := os.Mkdir(jobarchive, 0o777); err != nil {
        t.Fatal(err)
    }

    if err := os.WriteFile(filepath.Join(jobarchive, "version.txt"), fmt.Appendf(nil, "%d", 3), 0o666); err != nil {
        t.Fatal(err)
    }

    if err := os.Mkdir(filepath.Join(jobarchive, "testcluster"), 0o777); err != nil {
        t.Fatal(err)
    }

    if err := os.WriteFile(filepath.Join(jobarchive, "testcluster", "cluster.json"), []byte(testclusterJSON), 0o666); err != nil {
        t.Fatal(err)
    }

    dbfilepath := filepath.Join(tmpdir, "test.db")
    err := repository.MigrateDB(dbfilepath)
    if err != nil {
        t.Fatal(err)
    }

    cfgFilePath := filepath.Join(tmpdir, "config.json")
    if err := os.WriteFile(cfgFilePath, []byte(testconfig), 0o666); err != nil {
        t.Fatal(err)
    }

    ccconf.Init(cfgFilePath)

    // Load and check main configuration
    if cfg := ccconf.GetPackageConfig("main"); cfg != nil {
        if clustercfg := ccconf.GetPackageConfig("clusters"); clustercfg != nil {
            config.Init(cfg, clustercfg)
        } else {
            cclog.Abort("Cluster configuration must be present")
        }
    } else {
        cclog.Abort("Main configuration must be present")
    }
    archiveCfg := fmt.Sprintf("{\"kind\": \"file\",\"path\": \"%s\"}", jobarchive)

    repository.Connect("sqlite3", dbfilepath)

    if err := archive.Init(json.RawMessage(archiveCfg), config.Keys.DisableArchive); err != nil {
        t.Fatal(err)
    }

    if err := metricdata.Init(); err != nil {
        t.Fatal(err)
    }

    archiver.Start(repository.GetJobRepository(), context.Background())

    if cfg := ccconf.GetPackageConfig("auth"); cfg != nil {
        auth.Init(&cfg)
    } else {
        cclog.Warn("Authentication disabled due to missing configuration")
        auth.Init(nil)
    }

    graph.Init()

    return NewNatsAPI()
}

func cleanupNatsTest() {
    if err := archiver.Shutdown(5 * time.Second); err != nil {
        cclog.Warnf("Archiver shutdown timeout in tests: %v", err)
    }
}

func TestNatsHandleStartJob(t *testing.T) {
    natsAPI := setupNatsTest(t)
    t.Cleanup(cleanupNatsTest)

    tests := []struct {
        name          string
        payload       string
        expectError   bool
        validateJob   func(t *testing.T, job *schema.Job)
        shouldFindJob bool
    }{
        {
            name: "valid job start",
            payload: `{
                "jobId": 1001,
                "user": "testuser1",
                "project": "testproj1",
                "cluster": "testcluster",
                "partition": "main",
                "walltime": 7200,
                "numNodes": 1,
                "numHwthreads": 8,
                "numAcc": 0,
                "shared": "none",
                "monitoringStatus": 1,
                "smt": 1,
                "resources": [
                    {
                        "hostname": "host123",
                        "hwthreads": [0, 1, 2, 3, 4, 5, 6, 7]
                    }
                ],
                "startTime": 1234567890
            }`,
            expectError:   false,
            shouldFindJob: true,
            validateJob: func(t *testing.T, job *schema.Job) {
                if job.JobID != 1001 {
                    t.Errorf("expected JobID 1001, got %d", job.JobID)
                }
                if job.User != "testuser1" {
                    t.Errorf("expected user testuser1, got %s", job.User)
                }
                if job.State != schema.JobStateRunning {
                    t.Errorf("expected state running, got %s", job.State)
                }
            },
        },
        {
            name: "invalid JSON",
            payload: `{
                "jobId": "not a number",
                "user": "testuser2"
            }`,
            expectError:   true,
            shouldFindJob: false,
        },
        {
            name: "missing required fields",
            payload: `{
                "jobId": 1002
            }`,
            expectError:   true,
            shouldFindJob: false,
        },
        {
            name: "job with unknown fields (should fail due to DisallowUnknownFields)",
            payload: `{
                "jobId": 1003,
                "user": "testuser3",
                "project": "testproj3",
                "cluster": "testcluster",
                "partition": "main",
                "walltime": 3600,
                "numNodes": 1,
                "numHwthreads": 8,
                "unknownField": "should cause error",
                "startTime": 1234567900
            }`,
            expectError:   true,
            shouldFindJob: false,
        },
        {
            name: "job with tags",
            payload: `{
                "jobId": 1004,
                "user": "testuser4",
                "project": "testproj4",
                "cluster": "testcluster",
                "partition": "main",
                "walltime": 3600,
                "numNodes": 1,
                "numHwthreads": 8,
                "numAcc": 0,
                "shared": "none",
                "monitoringStatus": 1,
                "smt": 1,
                "resources": [
                    {
                        "hostname": "host123",
                        "hwthreads": [0, 1, 2, 3]
                    }
                ],
                "tags": [
                    {
                        "type": "test",
                        "name": "testtag",
                        "scope": "testuser4"
                    }
                ],
                "startTime": 1234567910
            }`,
            expectError:   false,
            shouldFindJob: true,
            validateJob: func(t *testing.T, job *schema.Job) {
                if job.JobID != 1004 {
                    t.Errorf("expected JobID 1004, got %d", job.JobID)
                }
            },
        },
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            natsAPI.handleStartJob(tt.payload)
            natsAPI.JobRepository.SyncJobs()

            // Allow some time for async operations
            time.Sleep(100 * time.Millisecond)

            if tt.shouldFindJob {
                // Extract jobId from payload
                var payloadMap map[string]any
                json.Unmarshal([]byte(tt.payload), &payloadMap)
                jobID := int64(payloadMap["jobId"].(float64))
                cluster := payloadMap["cluster"].(string)
                startTime := int64(payloadMap["startTime"].(float64))

                job, err := natsAPI.JobRepository.Find(&jobID, &cluster, &startTime)
                if err != nil {
                    if !tt.expectError {
                        t.Fatalf("expected to find job, but got error: %v", err)
                    }
                    return
                }

                if tt.validateJob != nil {
                    tt.validateJob(t, job)
                }
            }
        })
    }
}

func TestNatsHandleStopJob(t *testing.T) {
    natsAPI := setupNatsTest(t)
    t.Cleanup(cleanupNatsTest)

    // First, create a running job
    startPayload := `{
        "jobId": 2001,
        "user": "testuser",
        "project": "testproj",
        "cluster": "testcluster",
        "partition": "main",
        "walltime": 3600,
        "numNodes": 1,
        "numHwthreads": 8,
        "numAcc": 0,
        "shared": "none",
        "monitoringStatus": 1,
        "smt": 1,
        "resources": [
            {
                "hostname": "host123",
                "hwthreads": [0, 1, 2, 3, 4, 5, 6, 7]
            }
        ],
        "startTime": 1234567890
    }`

    natsAPI.handleStartJob(startPayload)
    natsAPI.JobRepository.SyncJobs()
    time.Sleep(100 * time.Millisecond)

    tests := []struct {
        name         string
        payload      string
        expectError  bool
        validateJob  func(t *testing.T, job *schema.Job)
        setupJobFunc func() // Optional: create specific test job
    }{
        {
            name: "valid job stop - completed",
            payload: `{
                "jobId": 2001,
                "cluster": "testcluster",
                "startTime": 1234567890,
                "jobState": "completed",
                "stopTime": 1234571490
            }`,
            expectError: false,
            validateJob: func(t *testing.T, job *schema.Job) {
                if job.State != schema.JobStateCompleted {
                    t.Errorf("expected state completed, got %s", job.State)
                }
                expectedDuration := int32(1234571490 - 1234567890)
                if job.Duration != expectedDuration {
                    t.Errorf("expected duration %d, got %d", expectedDuration, job.Duration)
                }
            },
        },
        {
            name: "valid job stop - failed",
            setupJobFunc: func() {
                startPayloadFailed := `{
                    "jobId": 2002,
                    "user": "testuser",
                    "project": "testproj",
                    "cluster": "testcluster",
                    "partition": "main",
                    "walltime": 3600,
                    "numNodes": 1,
                    "numHwthreads": 8,
                    "numAcc": 0,
                    "shared": "none",
                    "monitoringStatus": 1,
                    "smt": 1,
                    "resources": [
                        {
                            "hostname": "host123",
                            "hwthreads": [0, 1, 2, 3]
                        }
                    ],
                    "startTime": 1234567900
                }`
                natsAPI.handleStartJob(startPayloadFailed)
                natsAPI.JobRepository.SyncJobs()
                time.Sleep(100 * time.Millisecond)
            },
            payload: `{
                "jobId": 2002,
                "cluster": "testcluster",
                "startTime": 1234567900,
                "jobState": "failed",
                "stopTime": 1234569900
            }`,
            expectError: false,
            validateJob: func(t *testing.T, job *schema.Job) {
                if job.State != schema.JobStateFailed {
                    t.Errorf("expected state failed, got %s", job.State)
                }
            },
        },
        {
            name: "invalid JSON",
            payload: `{
                "jobId": "not a number"
            }`,
            expectError: true,
        },
        {
            name: "missing jobId",
            payload: `{
                "cluster": "testcluster",
                "jobState": "completed",
                "stopTime": 1234571490
            }`,
            expectError: true,
        },
        {
            name: "invalid job state",
            setupJobFunc: func() {
                startPayloadInvalid := `{
                    "jobId": 2003,
                    "user": "testuser",
                    "project": "testproj",
                    "cluster": "testcluster",
                    "partition": "main",
                    "walltime": 3600,
                    "numNodes": 1,
                    "numHwthreads": 8,
                    "numAcc": 0,
                    "shared": "none",
                    "monitoringStatus": 1,
                    "smt": 1,
                    "resources": [
                        {
                            "hostname": "host123",
                            "hwthreads": [0, 1]
                        }
                    ],
                    "startTime": 1234567910
                }`
                natsAPI.handleStartJob(startPayloadInvalid)
                natsAPI.JobRepository.SyncJobs()
                time.Sleep(100 * time.Millisecond)
            },
            payload: `{
                "jobId": 2003,
                "cluster": "testcluster",
                "startTime": 1234567910,
                "jobState": "invalid_state",
                "stopTime": 1234571510
            }`,
            expectError: true,
        },
        {
            name: "stopTime before startTime",
            setupJobFunc: func() {
                startPayloadTime := `{
                    "jobId": 2004,
                    "user": "testuser",
                    "project": "testproj",
                    "cluster": "testcluster",
                    "partition": "main",
                    "walltime": 3600,
                    "numNodes": 1,
                    "numHwthreads": 8,
                    "numAcc": 0,
                    "shared": "none",
                    "monitoringStatus": 1,
                    "smt": 1,
                    "resources": [
                        {
                            "hostname": "host123",
                            "hwthreads": [0]
                        }
                    ],
                    "startTime": 1234567920
                }`
                natsAPI.handleStartJob(startPayloadTime)
                natsAPI.JobRepository.SyncJobs()
                time.Sleep(100 * time.Millisecond)
            },
            payload: `{
                "jobId": 2004,
                "cluster": "testcluster",
                "startTime": 1234567920,
                "jobState": "completed",
                "stopTime": 1234567900
            }`,
            expectError: true,
        },
        {
            name: "job not found",
            payload: `{
                "jobId": 99999,
                "cluster": "testcluster",
                "startTime": 1234567890,
                "jobState": "completed",
                "stopTime": 1234571490
            }`,
            expectError: true,
        },
    }

    testData := schema.JobData{
        "load_one": map[schema.MetricScope]*schema.JobMetric{
            schema.MetricScopeNode: {
                Unit:     schema.Unit{Base: "load"},
                Timestep: 60,
                Series: []schema.Series{
                    {
                        Hostname:   "host123",
                        Statistics: schema.MetricStatistics{Min: 0.1, Avg: 0.2, Max: 0.3},
                        Data:       []schema.Float{0.1, 0.1, 0.1, 0.2, 0.2, 0.2, 0.3, 0.3, 0.3},
                    },
                },
            },
        },
    }

    metricdata.TestLoadDataCallback = func(job *schema.Job, metrics []string, scopes []schema.MetricScope, ctx context.Context, resolution int) (schema.JobData, error) {
        return testData, nil
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            if tt.setupJobFunc != nil {
                tt.setupJobFunc()
            }

            natsAPI.handleStopJob(tt.payload)

            // Allow some time for async operations
            time.Sleep(100 * time.Millisecond)

            if !tt.expectError && tt.validateJob != nil {
                // Extract job details from payload
                var payloadMap map[string]any
                json.Unmarshal([]byte(tt.payload), &payloadMap)
                jobID := int64(payloadMap["jobId"].(float64))
                cluster := payloadMap["cluster"].(string)

                var startTime *int64
                if st, ok := payloadMap["startTime"]; ok {
                    t := int64(st.(float64))
                    startTime = &t
                }

                job, err := natsAPI.JobRepository.Find(&jobID, &cluster, startTime)
                if err != nil {
                    t.Fatalf("expected to find job, but got error: %v", err)
                }

                tt.validateJob(t, job)
            }
        })
    }
}

func TestNatsHandleNodeState(t *testing.T) {
    natsAPI := setupNatsTest(t)
    t.Cleanup(cleanupNatsTest)

    tests := []struct {
        name        string
        payload     string
        expectError bool
        validateFn  func(t *testing.T)
    }{
        {
            name: "valid node state update",
            payload: `{
                "cluster": "testcluster",
                "nodes": [
                    {
                        "hostname": "host123",
                        "states": ["allocated"],
                        "cpusAllocated": 8,
                        "memoryAllocated": 16384,
                        "gpusAllocated": 0,
                        "jobsRunning": 1
                    }
                ]
            }`,
            expectError: false,
            validateFn: func(t *testing.T) {
                // In a full test, we would verify the node state was updated in the database
                // For now, just ensure no error occurred
            },
        },
        {
            name: "multiple nodes",
            payload: `{
                "cluster": "testcluster",
                "nodes": [
                    {
                        "hostname": "host123",
                        "states": ["idle"],
                        "cpusAllocated": 0,
                        "memoryAllocated": 0,
                        "gpusAllocated": 0,
                        "jobsRunning": 0
                    },
                    {
                        "hostname": "host124",
                        "states": ["allocated"],
                        "cpusAllocated": 4,
                        "memoryAllocated": 8192,
                        "gpusAllocated": 1,
                        "jobsRunning": 1
                    }
                ]
            }`,
            expectError: false,
        },
        {
            name: "invalid JSON",
            payload: `{
                "cluster": "testcluster",
                "nodes": "not an array"
            }`,
            expectError: true,
        },
        {
            name: "empty nodes array",
            payload: `{
                "cluster": "testcluster",
                "nodes": []
            }`,
            expectError: false, // Empty array should not cause error
        },
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            natsAPI.handleNodeState("test.subject", []byte(tt.payload))

            // Allow some time for async operations
            time.Sleep(50 * time.Millisecond)

            if tt.validateFn != nil {
                tt.validateFn(t)
            }
        })
    }
}

func TestNatsProcessJobEvent(t *testing.T) {
    natsAPI := setupNatsTest(t)
    t.Cleanup(cleanupNatsTest)

    msgStartJob, err := lp.NewMessage(
        "job",
        map[string]string{"function": "start_job"},
        nil,
        map[string]any{
            "event": `{
                "jobId": 3001,
                "user": "testuser",
                "project": "testproj",
                "cluster": "testcluster",
                "partition": "main",
                "walltime": 3600,
                "numNodes": 1,
                "numHwthreads": 8,
                "numAcc": 0,
                "shared": "none",
                "monitoringStatus": 1,
                "smt": 1,
                "resources": [
                    {
                        "hostname": "host123",
                        "hwthreads": [0, 1, 2, 3]
                    }
                ],
                "startTime": 1234567890
            }`,
        },
        time.Now(),
    )
    if err != nil {
        t.Fatalf("failed to create test message: %v", err)
    }

    msgMissingTag, err := lp.NewMessage(
        "job",
        map[string]string{},
        nil,
        map[string]any{
            "event": `{}`,
        },
        time.Now(),
    )
    if err != nil {
        t.Fatalf("failed to create test message: %v", err)
    }

    msgUnknownFunc, err := lp.NewMessage(
        "job",
        map[string]string{"function": "unknown_function"},
        nil,
        map[string]any{
            "event": `{}`,
        },
        time.Now(),
    )
    if err != nil {
        t.Fatalf("failed to create test message: %v", err)
    }

    tests := []struct {
        name        string
        message     lp.CCMessage
        expectError bool
    }{
        {
            name:        "start_job function",
            message:     msgStartJob,
            expectError: false,
        },
        {
            name:        "missing function tag",
            message:     msgMissingTag,
            expectError: true,
        },
        {
            name:        "unknown function",
            message:     msgUnknownFunc,
            expectError: false,
        },
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            natsAPI.processJobEvent(tt.message)
            time.Sleep(50 * time.Millisecond)
        })
    }
}

func TestNatsHandleJobEvent(t *testing.T) {
    natsAPI := setupNatsTest(t)
    t.Cleanup(cleanupNatsTest)

    tests := []struct {
        name        string
        data        []byte
        expectError bool
    }{
        {
            name:        "valid influx line protocol",
            data:        []byte(`job,function=start_job event="{\"jobId\":4001,\"user\":\"testuser\",\"project\":\"testproj\",\"cluster\":\"testcluster\",\"partition\":\"main\",\"walltime\":3600,\"numNodes\":1,\"numHwthreads\":8,\"numAcc\":0,\"shared\":\"none\",\"monitoringStatus\":1,\"smt\":1,\"resources\":[{\"hostname\":\"host123\",\"hwthreads\":[0,1,2,3]}],\"startTime\":1234567890}"`),
            expectError: false,
        },
        {
            name:        "invalid influx line protocol",
            data:        []byte(`invalid line protocol format`),
            expectError: true,
        },
        {
            name:        "empty data",
            data:        []byte(``),
            expectError: false, // Decoder should handle empty input gracefully
        },
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            // HandleJobEvent doesn't return errors, it logs them
            // We're just ensuring it doesn't panic
            natsAPI.handleJobEvent("test.subject", tt.data)
            time.Sleep(50 * time.Millisecond)
        })
    }
}

func TestNatsHandleStartJobDuplicatePrevention(t *testing.T) {
    natsAPI := setupNatsTest(t)
    t.Cleanup(cleanupNatsTest)

    // Start a job
    payload := `{
        "jobId": 5001,
        "user": "testuser",
        "project": "testproj",
        "cluster": "testcluster",
        "partition": "main",
        "walltime": 3600,
        "numNodes": 1,
        "numHwthreads": 8,
        "numAcc": 0,
        "shared": "none",
        "monitoringStatus": 1,
        "smt": 1,
        "resources": [
            {
                "hostname": "host123",
                "hwthreads": [0, 1, 2, 3]
            }
        ],
        "startTime": 1234567890
    }`

    natsAPI.handleStartJob(payload)
    natsAPI.JobRepository.SyncJobs()
    time.Sleep(100 * time.Millisecond)

    // Try to start the same job again (within 24 hours)
    duplicatePayload := `{
        "jobId": 5001,
        "user": "testuser",
        "project": "testproj",
        "cluster": "testcluster",
        "partition": "main",
        "walltime": 3600,
        "numNodes": 1,
        "numHwthreads": 8,
        "numAcc": 0,
        "shared": "none",
        "monitoringStatus": 1,
        "smt": 1,
        "resources": [
            {
                "hostname": "host123",
                "hwthreads": [0, 1, 2, 3]
            }
        ],
        "startTime": 1234567900
    }`

    natsAPI.handleStartJob(duplicatePayload)
    natsAPI.JobRepository.SyncJobs()
    time.Sleep(100 * time.Millisecond)

    // Verify only one job exists
    jobID := int64(5001)
    cluster := "testcluster"
    jobs, err := natsAPI.JobRepository.FindAll(&jobID, &cluster, nil)
    if err != nil && err != sql.ErrNoRows {
        t.Fatalf("unexpected error: %v", err)
    }

    if len(jobs) != 1 {
        t.Errorf("expected 1 job, got %d", len(jobs))
    }
}
@@ -115,3 +115,26 @@ func GetConnection() *DBConnection {

    return dbConnInstance
}
+
+// ResetConnection closes the current database connection and resets the connection state.
+// This function is intended for testing purposes only to allow test isolation.
+func ResetConnection() error {
+    if dbConnInstance != nil && dbConnInstance.DB != nil {
+        if err := dbConnInstance.DB.Close(); err != nil {
+            return fmt.Errorf("failed to close database connection: %w", err)
+        }
+    }
+
+    dbConnInstance = nil
+    dbConnOnce = sync.Once{}
+    jobRepoInstance = nil
+    jobRepoOnce = sync.Once{}
+    nodeRepoInstance = nil
+    nodeRepoOnce = sync.Once{}
+    userRepoInstance = nil
+    userRepoOnce = sync.Once{}
+    userCfgRepoInstance = nil
+    userCfgRepoOnce = sync.Once{}
+
+    return nil
+}
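The new ResetConnection works because the repository package guards its singletons with sync.Once; clearing the instances and re-arming the Once values lets the next Connect/Get call rebuild them against a fresh test database. A self-contained sketch of that pattern (simplified stand-in names, not the actual cc-backend types) is:

package main

import (
    "fmt"
    "sync"
)

type DBConnection struct{ dsn string }

var (
    dbConnInstance *DBConnection
    dbConnOnce     sync.Once
)

// connect initializes the singleton exactly once per process.
func connect(dsn string) *DBConnection {
    dbConnOnce.Do(func() {
        dbConnInstance = &DBConnection{dsn: dsn}
    })
    return dbConnInstance
}

// resetConnection drops the instance and re-arms the sync.Once so the next
// connect call builds a new connection -- the same idea ResetConnection
// applies to the repository singletons for test isolation.
func resetConnection() {
    dbConnInstance = nil
    dbConnOnce = sync.Once{}
}

func main() {
    a := connect("file:test1.db")
    resetConnection()
    b := connect("file:test2.db")
    fmt.Println(a.dsn, b.dsn) // test1.db test2.db
}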