diff --git a/api/rest.go b/api/rest.go index 9361d56..7d367d9 100644 --- a/api/rest.go +++ b/api/rest.go @@ -212,7 +212,7 @@ func (api *RestApi) startJob(rw http.ResponseWriter, r *http.Request) { } // Check if combination of (job_id, cluster_id, start_time) already exists: - job, err := api.JobRepository.Find(req.JobID, req.Cluster, req.StartTime) + job, err := api.JobRepository.Find(&req.JobID, &req.Cluster, &req.StartTime) if err != nil && err != sql.ErrNoRows { handleError(fmt.Errorf("checking for duplicate failed: %w", err), http.StatusInternalServerError, rw) return @@ -282,12 +282,12 @@ func (api *RestApi) stopJob(rw http.ResponseWriter, r *http.Request) { job, err = api.JobRepository.FindById(id) } else { - if req.JobId == nil || req.Cluster == nil || req.StartTime == nil { - handleError(errors.New("the fields 'jobId', 'cluster' and 'startTime' are required"), http.StatusBadRequest, rw) + if req.JobId == nil { + handleError(errors.New("the field 'jobId' is required"), http.StatusBadRequest, rw) return } - job, err = api.JobRepository.Find(*req.JobId, *req.Cluster, *req.StartTime) + job, err = api.JobRepository.Find(req.JobId, req.Cluster, req.StartTime) } if err != nil { handleError(fmt.Errorf("finding job failed: %w", err), http.StatusUnprocessableEntity, rw) diff --git a/repository/job.go b/repository/job.go index 567507f..d38a133 100644 --- a/repository/job.go +++ b/repository/job.go @@ -23,13 +23,19 @@ type JobRepository struct { // It returns a pointer to a schema.Job data structure and an error variable. // To check if no job was found test err == sql.ErrNoRows func (r *JobRepository) Find( - jobId int64, - cluster string, - startTime int64) (*schema.Job, error) { + jobId *int64, + cluster *string, + startTime *int64) (*schema.Job, error) { + qb := sq.Select(schema.JobColumns...).From("job"). - Where("job.job_id = ?", jobId). - Where("job.cluster = ?", cluster). - Where("job.start_time = ?", startTime) + Where("job.job_id = ?", jobId) + + if cluster != nil { + qb = qb.Where("job.cluster = ?", *cluster) + } + if startTime != nil { + qb = qb.Where("job.start_time = ?", *startTime) + } sqlQuery, args, err := qb.ToSql() if err != nil {