diff --git a/internal/scheduler/slurmRest.go b/internal/scheduler/slurmRest.go
index 63b48fe..66f1576 100644
--- a/internal/scheduler/slurmRest.go
+++ b/internal/scheduler/slurmRest.go
@@ -27,11 +27,15 @@ type SlurmRestSchedulerConfig struct {
 	JobRepository *repository.JobRepository
 
 	clusterConfig ClusterConfig
+
+	client *http.Client
+
+	nodeClusterMap map[string]string
+
+	clusterPCIEAddressMap map[string][]string
 }
 
-var client *http.Client
-
-func queryDB(qtime int64, clusterName string) ([]interface{}, error) {
+func (cfg *SlurmRestSchedulerConfig) queryDB(qtime int64, clusterName string) ([]interface{}, error) {
 
 	apiEndpoint := "/slurmdb/v0.0.38/jobs"
 
@@ -50,7 +54,7 @@ func queryDB(qtime int64, clusterName string) ([]interface{}, error) {
 	}
 
 	// Send the request
-	resp, err := client.Do(req)
+	resp, err := cfg.client.Do(req)
 	if err != nil {
 		log.Errorf("Error sending request: %v", err)
 	}
@@ -88,7 +92,7 @@ func queryDB(qtime int64, clusterName string) ([]interface{}, error) {
 	return jobs, nil
 }
 
-func fetchJobs() (SlurmPayload, error) {
+func (cfg *SlurmRestSchedulerConfig) fetchJobs() (SlurmPayload, error) {
 	var ctlOutput []byte
 
 	apiEndpoint := "http://:8080/slurm/v0.0.38/jobs"
@@ -99,7 +103,7 @@ func fetchJobs() (SlurmPayload, error) {
 	}
 
 	// Send the request
-	resp, err := client.Do(req)
+	resp, err := cfg.client.Do(req)
 	if err != nil {
 		log.Errorf("Error sending request: %v", err)
 	}
@@ -219,8 +223,23 @@ func (cfg *SlurmRestSchedulerConfig) Init() error {
 	cfg.clusterConfig, err = DecodeClusterConfig("cluster-alex.json")
 
+	cfg.nodeClusterMap = make(map[string]string)
+	cfg.clusterPCIEAddressMap = make(map[string][]string)
+
+	for _, subCluster := range cfg.clusterConfig.SubClusters {
+
+		cfg.ConstructNodeClusterMap(subCluster.Nodes, subCluster.Name)
+
+		pcieAddresses := make([]string, len(subCluster.Topology.Accelerators))
+
+		for idx, accelerator := range subCluster.Topology.Accelerators {
+			pcieAddresses[idx] = accelerator.ID
+		}
+
+		cfg.clusterPCIEAddressMap[subCluster.Name] = pcieAddresses
+	}
 	// Create an HTTP client
-	client = &http.Client{}
+	cfg.client = &http.Client{}
 
 	return err
 }
 
@@ -259,11 +278,14 @@ func (cfg *SlurmRestSchedulerConfig) checkAndHandleStopJob(job *schema.Job, req
 	cfg.JobRepository.TriggerArchiving(job)
 }
 
-func ConstructNodeAcceleratorMap(input string, accelerator string) map[string]string {
-	numberMap := make(map[string]string)
+func (cfg *SlurmRestSchedulerConfig) ConstructNodeClusterMap(nodes string, cluster string) {
+
+	if cfg.nodeClusterMap == nil {
+		cfg.nodeClusterMap = make(map[string]string)
+	}
 
 	// Split the input by commas
-	groups := strings.Split(input, ",")
+	groups := strings.Split(nodes, ",")
 
 	for _, group := range groups {
 		// Use regular expressions to match numbers and ranges
@@ -277,22 +299,59 @@ func (cfg *SlurmRestSchedulerConfig) ConstructNodeClusterMap(nodes string, clust
 				start, _ := strconv.Atoi(matches[1])
 				end, _ := strconv.Atoi(matches[2])
 				for i := start; i <= end; i++ {
-					numberMap[matches[0]+fmt.Sprintf("%04d", i)] = accelerator
+					cfg.nodeClusterMap[matches[0]+fmt.Sprintf("%04d", i)] = cluster
 				}
 			}
 		} else if numberRegex.MatchString(group) {
 			// Extract individual node
 			matches := numberRegex.FindStringSubmatch(group)
 			if len(matches) == 2 {
-				numberMap[group] = accelerator
+				cfg.nodeClusterMap[group] = cluster
 			}
 		}
 	}
-
-	return numberMap
 }
 
-func CreateJobMeta(job Job) *schema.JobMeta {
+func extractElements(indicesStr string, addresses []string) ([]string, error) {
+	// Split the input string by commas to get individual index ranges
+	indexRanges := strings.Split(indicesStr, ",")
+
+	var selected []string
+	for _, indexRange := range indexRanges {
+		// Split each index range by hyphen to separate start and end indices
+		rangeParts := strings.Split(indexRange, "-")
+
+		if len(rangeParts) == 1 {
+			// If there's only one part, it's a single index
+			index, err := strconv.Atoi(rangeParts[0])
+			if err != nil {
+				return nil, err
+			}
+			selected = append(selected, addresses[index])
+		} else if len(rangeParts) == 2 {
+			// If there are two parts, it's a range
+			start, err := strconv.Atoi(rangeParts[0])
+			if err != nil {
+				return nil, err
+			}
+			end, err := strconv.Atoi(rangeParts[1])
+			if err != nil {
+				return nil, err
+			}
+			// Add all indices in the range to the result
+			for i := start; i <= end; i++ {
+				selected = append(selected, addresses[i])
+			}
+		} else {
+			// Invalid format
+			return nil, fmt.Errorf("invalid index range: %s", indexRange)
+		}
+	}
+
+	return selected, nil
+}
+
+func (cfg *SlurmRestSchedulerConfig) CreateJobMeta(job Job) (*schema.JobMeta, error) {
 	var exclusive int32
 
 	if job.Shared == nil {
@@ -301,24 +360,11 @@ func CreateJobMeta(job Job) *schema.JobMeta {
 		exclusive = 0
 	}
 
+	totalGPUs := 0
+
 	var resources []*schema.Resource
 
-	// Define a regular expression to match "gpu=x"
-	regex := regexp.MustCompile(`gpu=(\d+)`)
-
-	// Find all matches in the input string
-	matches := regex.FindAllStringSubmatch(job.TresAllocStr, -1)
-
-	// Initialize a variable to store the total number of GPUs
-	var totalGPUs int32
-	// Iterate through the matches
-	match := matches[0]
-	if len(match) == 2 {
-		gpuCount, _ := strconv.Atoi(match[1])
-		totalGPUs += int32(gpuCount)
-	}
-
-	for _, node := range job.JobResources.AllocatedNodes {
+	for nodeIndex, node := range job.JobResources.AllocatedNodes {
 		var res schema.Resource
 
 		res.Hostname = node.Nodename
@@ -334,9 +380,25 @@ func CreateJobMeta(job Job) *schema.JobMeta {
 			res.HWThreads = append(res.HWThreads, threadID)
 		}
 
-		// cpu=512,mem=1875G,node=4,billing=512,gres\/gpu=32,gres\/gpu:a40=32
+		re := regexp.MustCompile(`\(([^)]*)\)`)
+		matches := re.FindStringSubmatch(job.GresDetail[nodeIndex])
+
+		if len(matches) < 2 {
+			return nil, fmt.Errorf("no substring found in brackets")
+		}
+
+		nodePCIEAddresses := cfg.clusterPCIEAddressMap[cfg.nodeClusterMap[node.Nodename]]
+
+		// Map the allocated accelerator indices to PCIe addresses for this node
+		selectedPCIEAddresses, err := extractElements(matches[1], nodePCIEAddresses)
+		if err != nil {
+			return nil, err
+		}
+
+		totalGPUs += len(selectedPCIEAddresses)
+
 		// For core/GPU id mapping, need to query from cluster config file
-		res.Accelerators = append(res.Accelerators, job.Comment)
+		res.Accelerators = selectedPCIEAddresses
 
 		resources = append(resources, &res)
 	}
@@ -359,7 +421,7 @@ func CreateJobMeta(job Job) *schema.JobMeta {
 			ArrayJobId:   job.ArrayJobID,
 			NumNodes:     job.NodeCount,
 			NumHWThreads: job.CPUs,
-			NumAcc:       totalGPUs,
+			NumAcc:       int32(totalGPUs),
 			Exclusive:    exclusive,
 			// MonitoringStatus: job.MonitoringStatus,
 			// SMT:              job.TasksPerCore,
@@ -387,7 +449,7 @@ func CreateJobMeta(job Job) *schema.JobMeta {
 	}
 
 	// log.Debugf("Generated JobMeta %v", req.BaseJob.JobID)
 
-	return meta
+	return meta, nil
 }
 
@@ -405,7 +467,11 @@ func (cfg *SlurmRestSchedulerConfig) HandleJobs(jobs []Job) error {
 
 		if job.JobState == "RUNNING" {
 
-			meta := CreateJobMeta(job)
+			meta, err := cfg.CreateJobMeta(job)
+			if err != nil {
+				log.Errorf("Error creating job meta: %v", err)
+				continue
+			}
 
 			// For all running jobs from Slurm
 			_, notFoundError := cfg.JobRepository.Find(&job.JobID, &job.Cluster, &job.StartTime)