diff --git a/rest-api/api/pkg/api/handler/allocation.go b/rest-api/api/pkg/api/handler/allocation.go index 4c6455b556..d4f0eb5c87 100644 --- a/rest-api/api/pkg/api/handler/allocation.go +++ b/rest-api/api/pkg/api/handler/allocation.go @@ -361,7 +361,11 @@ func (cah CreateAllocationHandler) Handle(c echo.Context) error { imAcAdd := []cdbm.AllocationConstraint{} imAcUpd := []cdbm.AllocationConstraint{} for _, ac := range dbacs { - retac, serr := acDAO.CreateFromParams(ctx, tx, a.ID, ac.ResourceType, ac.ResourceTypeID, ac.ConstraintType, ac.ConstraintValue, ac.DerivedResourceID, dbUser.ID) + retac, serr := acDAO.Create(ctx, tx, cdbm.AllocationConstraintCreateInput{ + AllocationID: a.ID, ResourceType: ac.ResourceType, ResourceTypeID: ac.ResourceTypeID, + ConstraintType: ac.ConstraintType, ConstraintValue: ac.ConstraintValue, + DerivedResourceID: ac.DerivedResourceID, CreatedBy: dbUser.ID, + }) if serr != nil { logger.Error().Err(serr).Msg("error creating Allocation Constraint DB entry") return cutil.NewAPIError(http.StatusInternalServerError, "Failed to create Allocation Constraint entry for Allocation", nil) @@ -750,7 +754,7 @@ func (gaah GetAllAllocationHandler) Handle(c echo.Context) error { // Get allocation constraints based on allocation filter by resource type acDAO := cdbm.NewAllocationConstraintDAO(gaah.dbSession) - alcs, _, err := acDAO.GetAll(ctx, nil, aids, nil, nil, nil, nil, nil, nil, cutil.GetPtr(cdbp.TotalLimit), nil) + alcs, _, err := acDAO.GetAll(ctx, nil, cdbm.AllocationConstraintFilterInput{AllocationIDs: aids}, cdbp.PageInput{Limit: cutil.GetPtr(cdbp.TotalLimit)}, nil) if err != nil { logger.Error().Err(err).Msg("error retrieving Allocation Constraints for Allocations from DB") return cutil.NewAPIErrorResponse(c, http.StatusInternalServerError, "Failed to populate Constraints for Allocations", nil) @@ -884,7 +888,7 @@ func (gah GetAllocationHandler) Handle(c echo.Context) error { return cutil.NewAPIErrorResponse(c, http.StatusInternalServerError, "Failed to retrieve Status Details for Allocation", nil) } acDAO := cdbm.NewAllocationConstraintDAO(gah.dbSession) - acs, _, err := acDAO.GetAll(ctx, nil, []uuid.UUID{a.ID}, nil, nil, nil, nil, nil, nil, nil, nil) + acs, _, err := acDAO.GetAll(ctx, nil, cdbm.AllocationConstraintFilterInput{AllocationIDs: []uuid.UUID{a.ID}}, cdbp.PageInput{}, nil) if err != nil { logger.Error().Err(err).Msg("error retrieving Allocation Constraints for Allocation from DB") return cutil.NewAPIErrorResponse(c, http.StatusInternalServerError, "Failed to retrieve Allocation Constraints for Allocation", nil) @@ -1059,7 +1063,10 @@ func (uah UpdateAllocationHandler) Handle(c echo.Context) error { // If this was an IP Block allocation, then update the derived resource name // Get IP Block Allocation Constraints, if any acDAO := cdbm.NewAllocationConstraintDAO(uah.dbSession) - ipbAcs, _, derr := acDAO.GetAll(ctx, tx, []uuid.UUID{a.ID}, cutil.GetPtr(cdbm.AllocationResourceTypeIPBlock), nil, nil, nil, nil, nil, nil, nil) + ipbAcs, _, derr := acDAO.GetAll(ctx, tx, cdbm.AllocationConstraintFilterInput{ + AllocationIDs: []uuid.UUID{a.ID}, + ResourceType: cutil.GetPtr(cdbm.AllocationResourceTypeIPBlock), + }, cdbp.PageInput{}, nil) if derr != nil { logger.Error().Err(derr).Msg("error retrieving Allocation Constraints for Allocation from DB") return cutil.NewAPIError(http.StatusInternalServerError, "Failed to retrieve Allocation Constraints for Allocation", nil) @@ -1097,7 +1104,7 @@ func (uah UpdateAllocationHandler) Handle(c echo.Context) error { ssds = retSsds acDAO := cdbm.NewAllocationConstraintDAO(uah.dbSession) - retAcs, _, derr := acDAO.GetAll(ctx, tx, []uuid.UUID{a.ID}, nil, nil, nil, nil, nil, nil, nil, nil) + retAcs, _, derr := acDAO.GetAll(ctx, tx, cdbm.AllocationConstraintFilterInput{AllocationIDs: []uuid.UUID{a.ID}}, cdbp.PageInput{}, nil) if derr != nil { logger.Error().Err(derr).Msg("error retrieving Allocation Constraints for Allocation from DB") return cutil.NewAPIError(http.StatusInternalServerError, "Failed to retrieve Allocation Constraints for Allocation", nil) @@ -1233,7 +1240,7 @@ func (dah DeleteAllocationHandler) Handle(c echo.Context) error { // check dependent objects (instances or subnets for the tenant) in allocation constraints for the allocation acDAO := cdbm.NewAllocationConstraintDAO(dah.dbSession) - acs, _, derr := acDAO.GetAll(ctx, tx, []uuid.UUID{a.ID}, nil, nil, nil, nil, nil, nil, cutil.GetPtr(cdbp.TotalLimit), nil) + acs, _, derr := acDAO.GetAll(ctx, tx, cdbm.AllocationConstraintFilterInput{AllocationIDs: []uuid.UUID{a.ID}}, cdbp.PageInput{Limit: cutil.GetPtr(cdbp.TotalLimit)}, nil) if derr != nil && derr != cdb.ErrDoesNotExist { logger.Error().Err(derr).Msg("error retrieving Allocation Constraints from DB") return cutil.NewAPIError(http.StatusInternalServerError, "Error getting allocation constraints for allocation", nil) diff --git a/rest-api/api/pkg/api/handler/allocationconstraint.go b/rest-api/api/pkg/api/handler/allocationconstraint.go index bd6a5ad6e5..45fd0dbff2 100644 --- a/rest-api/api/pkg/api/handler/allocationconstraint.go +++ b/rest-api/api/pkg/api/handler/allocationconstraint.go @@ -243,14 +243,13 @@ func (uach UpdateAllocationConstraintHandler) Handle(c echo.Context) error { allocConstraints, _, derr := acDAO.GetAll( ctx, tx, - allocationIDs, - cutil.GetPtr(cdbm.AllocationResourceTypeInstanceType), - []uuid.UUID{dbit.ID}, - cutil.GetPtr(cdbm.AllocationConstraintTypeReserved), - nil, - nil, - nil, - cutil.GetPtr(paginator.TotalLimit), + cdbm.AllocationConstraintFilterInput{ + AllocationIDs: allocationIDs, + ResourceType: cutil.GetPtr(cdbm.AllocationResourceTypeInstanceType), + ResourceTypeIDs: []uuid.UUID{dbit.ID}, + ConstraintType: cutil.GetPtr(cdbm.AllocationConstraintTypeReserved), + }, + paginator.PageInput{Limit: cutil.GetPtr(paginator.TotalLimit)}, nil, ) if derr != nil { @@ -424,7 +423,10 @@ func (uach UpdateAllocationConstraintHandler) Handle(c echo.Context) error { } } - newac, derr := acDAO.UpdateFromParams(ctx, tx, ac.ID, nil, nil, nil, nil, cutil.GetPtr(apiRequest.ConstraintValue), nil) + newac, derr := acDAO.Update(ctx, tx, cdbm.AllocationConstraintUpdateInput{ + AllocationConstraintID: ac.ID, + ConstraintValue: cutil.GetPtr(apiRequest.ConstraintValue), + }) if derr != nil { logger.Error().Err(derr).Msg("error updating Allocation Constraint in DB") return nil, cutil.NewAPIError(http.StatusInternalServerError, "Failed to update Allocation Constraint with new constraint value, DB error", nil) diff --git a/rest-api/api/pkg/api/handler/allocationconstraint_test.go b/rest-api/api/pkg/api/handler/allocationconstraint_test.go index b4f67bd20c..d3175886f2 100644 --- a/rest-api/api/pkg/api/handler/allocationconstraint_test.go +++ b/rest-api/api/pkg/api/handler/allocationconstraint_test.go @@ -19,6 +19,7 @@ import ( cutil "github.com/NVIDIA/infra-controller/rest-api/common/pkg/util" "github.com/NVIDIA/infra-controller/rest-api/db/pkg/db/ipam" cdbm "github.com/NVIDIA/infra-controller/rest-api/db/pkg/db/model" + cdbp "github.com/NVIDIA/infra-controller/rest-api/db/pkg/db/paginator" "github.com/google/uuid" "github.com/labstack/echo/v4" "github.com/stretchr/testify/assert" @@ -162,19 +163,19 @@ func TestAllocationConstraintHandler_Update(t *testing.T) { // Get Allocation Constraints for above Allocations acDAO := cdbm.NewAllocationConstraintDAO(dbSession) - acsit1, _, _ := acDAO.GetAll(ctx, nil, []uuid.UUID{aID1}, nil, nil, nil, nil, nil, nil, nil, nil) + acsit1, _, _ := acDAO.GetAll(ctx, nil, cdbm.AllocationConstraintFilterInput{AllocationIDs: []uuid.UUID{aID1}}, cdbp.PageInput{}, nil) assert.NotNil(t, acsit1) - acsip1, _, _ := acDAO.GetAll(ctx, nil, []uuid.UUID{aipID1}, nil, nil, nil, nil, nil, nil, nil, nil) + acsip1, _, _ := acDAO.GetAll(ctx, nil, cdbm.AllocationConstraintFilterInput{AllocationIDs: []uuid.UUID{aipID1}}, cdbp.PageInput{}, nil) assert.NotNil(t, acsip1) - acsit2, _, _ := acDAO.GetAll(ctx, nil, []uuid.UUID{aID2}, nil, nil, nil, nil, nil, nil, nil, nil) + acsit2, _, _ := acDAO.GetAll(ctx, nil, cdbm.AllocationConstraintFilterInput{AllocationIDs: []uuid.UUID{aID2}}, cdbp.PageInput{}, nil) assert.NotNil(t, acsit2) - acsip2, _, _ := acDAO.GetAll(ctx, nil, []uuid.UUID{aipID2}, nil, nil, nil, nil, nil, nil, nil, nil) + acsip2, _, _ := acDAO.GetAll(ctx, nil, cdbm.AllocationConstraintFilterInput{AllocationIDs: []uuid.UUID{aipID2}}, cdbp.PageInput{}, nil) assert.NotNil(t, acsip2) - acsip3, _, _ := acDAO.GetAll(ctx, nil, []uuid.UUID{aipID3}, nil, nil, nil, nil, nil, nil, nil, nil) + acsip3, _, _ := acDAO.GetAll(ctx, nil, cdbm.AllocationConstraintFilterInput{AllocationIDs: []uuid.UUID{aipID3}}, cdbp.PageInput{}, nil) assert.NotNil(t, acsip3) // Setup test data for Allocation Constraint Update diff --git a/rest-api/api/pkg/api/handler/instance_test.go b/rest-api/api/pkg/api/handler/instance_test.go index 2671f354dc..8e2f5ced84 100644 --- a/rest-api/api/pkg/api/handler/instance_test.go +++ b/rest-api/api/pkg/api/handler/instance_test.go @@ -247,7 +247,10 @@ func testInstanceSiteBuildAllocation(t *testing.T, dbSession *cdb.Session, st *c func testInstanceSiteBuildAllocationContraints(t *testing.T, dbSession *cdb.Session, al *cdbm.Allocation, rt string, rtID uuid.UUID, ct string, cv int, user *cdbm.User) *cdbm.AllocationConstraint { alctDAO := cdbm.NewAllocationConstraintDAO(dbSession) - alct, err := alctDAO.CreateFromParams(context.Background(), nil, al.ID, rt, rtID, ct, cv, nil, user.ID) + alct, err := alctDAO.Create(context.Background(), nil, cdbm.AllocationConstraintCreateInput{ + AllocationID: al.ID, ResourceType: rt, ResourceTypeID: rtID, + ConstraintType: ct, ConstraintValue: cv, CreatedBy: user.ID, + }) assert.Nil(t, err) return alct diff --git a/rest-api/api/pkg/api/handler/instancetype.go b/rest-api/api/pkg/api/handler/instancetype.go index 99bffd12e6..f5a3d13e10 100644 --- a/rest-api/api/pkg/api/handler/instancetype.go +++ b/rest-api/api/pkg/api/handler/instancetype.go @@ -1411,7 +1411,10 @@ func (dith DeleteInstanceTypeHandler) Handle(c echo.Context) error { return cutil.NewAPIError(http.StatusBadRequest, "Instance Type is being used by one or more Instances and cannot be deleted", nil) } - acs, _, derr := acDAO.GetAll(ctx, tx, nil, &resourceType, []uuid.UUID{it.ID}, nil, nil, nil, nil, nil, nil) + acs, _, derr := acDAO.GetAll(ctx, tx, cdbm.AllocationConstraintFilterInput{ + ResourceType: &resourceType, + ResourceTypeIDs: []uuid.UUID{it.ID}, + }, cdbp.PageInput{}, nil) if derr != nil { logger.Error().Err(derr).Msg("error retrieving Allocation Constraints for Instance Type from DB") return cutil.NewAPIError(http.StatusInternalServerError, "Failed to retrieve Allocation Constraints for Instance Type", nil) diff --git a/rest-api/api/pkg/api/handler/ipblock.go b/rest-api/api/pkg/api/handler/ipblock.go index 2e561ccbb7..6c72c12f73 100644 --- a/rest-api/api/pkg/api/handler/ipblock.go +++ b/rest-api/api/pkg/api/handler/ipblock.go @@ -637,7 +637,9 @@ func (gadipbh GetAllDerivedIPBlockHandler) Handle(c echo.Context) error { // Get allocation constraints by resourcetype ID (parent IPBlock) acDAO := cdbm.NewAllocationConstraintDAO(gadipbh.dbSession) - acs, _, err := acDAO.GetAll(ctx, nil, nil, nil, []uuid.UUID{ipb.ID}, nil, nil, nil, nil, cutil.GetPtr(cdbp.TotalLimit), nil) + acs, _, err := acDAO.GetAll(ctx, nil, cdbm.AllocationConstraintFilterInput{ + ResourceTypeIDs: []uuid.UUID{ipb.ID}, + }, cdbp.PageInput{Limit: cutil.GetPtr(cdbp.TotalLimit)}, nil) if err != nil { logger.Error().Err(err).Msg("error retrieving Allocation Constraints for parent IPBlock from DB") return cutil.NewAPIErrorResponse(c, http.StatusInternalServerError, "Failed to retrieve Allocation Constraints for parent IPBlock", nil) @@ -1129,7 +1131,10 @@ func (dipbh DeleteIPBlockHandler) Handle(c echo.Context) error { // Verify that the IPBlock does not have any allocations associated with it acDAO := cdbm.NewAllocationConstraintDAO(dipbh.dbSession) - _, acCount, err := acDAO.GetAll(ctx, nil, nil, cutil.GetPtr(cdbm.AllocationResourceTypeIPBlock), []uuid.UUID{ipb.ID}, nil, nil, nil, nil, nil, nil) + _, acCount, err := acDAO.GetAll(ctx, nil, cdbm.AllocationConstraintFilterInput{ + ResourceType: cutil.GetPtr(cdbm.AllocationResourceTypeIPBlock), + ResourceTypeIDs: []uuid.UUID{ipb.ID}, + }, cdbp.PageInput{}, nil) if err != nil { logger.Error().Err(err).Msg("error getting allocation constraints") return cutil.NewAPIErrorResponse(c, http.StatusInternalServerError, "Error retrieving Allocations for IP Block", nil) diff --git a/rest-api/api/pkg/api/handler/ipblock_test.go b/rest-api/api/pkg/api/handler/ipblock_test.go index 061d853397..b3ab061879 100644 --- a/rest-api/api/pkg/api/handler/ipblock_test.go +++ b/rest-api/api/pkg/api/handler/ipblock_test.go @@ -205,7 +205,11 @@ func testIPBlockBuildAllocation(t *testing.T, dbSession *cdb.Session, st *cdbm.S func testIPBlockBuildAllocationConstraint(t *testing.T, dbSession *cdb.Session, allocationID uuid.UUID, resourceType string, resourceTypeID uuid.UUID, constraintType string, constraintValue int, derivedResourceID *uuid.UUID, createdBy uuid.UUID) *cdbm.AllocationConstraint { alcDAO := cdbm.NewAllocationConstraintDAO(dbSession) - alc, err := alcDAO.CreateFromParams(context.Background(), nil, allocationID, resourceType, resourceTypeID, constraintType, constraintValue, derivedResourceID, createdBy) + alc, err := alcDAO.Create(context.Background(), nil, cdbm.AllocationConstraintCreateInput{ + AllocationID: allocationID, ResourceType: resourceType, ResourceTypeID: resourceTypeID, + ConstraintType: constraintType, ConstraintValue: constraintValue, + DerivedResourceID: derivedResourceID, CreatedBy: createdBy, + }) assert.Nil(t, err) return alc diff --git a/rest-api/api/pkg/api/handler/machine_test.go b/rest-api/api/pkg/api/handler/machine_test.go index 78bdb139c4..23ea9e6c03 100644 --- a/rest-api/api/pkg/api/handler/machine_test.go +++ b/rest-api/api/pkg/api/handler/machine_test.go @@ -253,7 +253,10 @@ func testMachineBuildAllocation(t *testing.T, dbSession *cdb.Session, ip *cdbm.I func testMachineBuildAllocationContraints(t *testing.T, dbSession *cdb.Session, al *cdbm.Allocation, rt string, rtID uuid.UUID, ct string, cv int, user *cdbm.User) *cdbm.AllocationConstraint { alctDAO := cdbm.NewAllocationConstraintDAO(dbSession) - alct, err := alctDAO.CreateFromParams(context.Background(), nil, al.ID, rt, rtID, ct, cv, nil, user.ID) + alct, err := alctDAO.Create(context.Background(), nil, cdbm.AllocationConstraintCreateInput{ + AllocationID: al.ID, ResourceType: rt, ResourceTypeID: rtID, + ConstraintType: ct, ConstraintValue: cv, CreatedBy: user.ID, + }) assert.Nil(t, err) return alct diff --git a/rest-api/api/pkg/api/handler/machineinstancetype_test.go b/rest-api/api/pkg/api/handler/machineinstancetype_test.go index 6f2ca222c3..6fd2a00096 100644 --- a/rest-api/api/pkg/api/handler/machineinstancetype_test.go +++ b/rest-api/api/pkg/api/handler/machineinstancetype_test.go @@ -592,7 +592,7 @@ func TestDeleteMachineInstanceTypeHandler_Handle(t *testing.T) { al, err := alDAO.GetByID(context.Background(), nil, uuid.MustParse(apial.ID), nil) assert.NoError(t, err) - alcs, _, err := alcDAO.GetAll(context.Background(), nil, []uuid.UUID{al.ID}, nil, nil, nil, nil, nil, nil, nil, nil) + alcs, _, err := alcDAO.GetAll(context.Background(), nil, cdbm.AllocationConstraintFilterInput{AllocationIDs: []uuid.UUID{al.ID}}, paginator.PageInput{}, nil) assert.NoError(t, err) assert.Equal(t, 1, len(alcs)) diff --git a/rest-api/api/pkg/api/handler/stats.go b/rest-api/api/pkg/api/handler/stats.go index 5076a43d56..98c4c2e9f8 100644 --- a/rest-api/api/pkg/api/handler/stats.go +++ b/rest-api/api/pkg/api/handler/stats.go @@ -191,9 +191,11 @@ func (gtitsh GetTenantInstanceTypeStatsHandler) Handle(c echo.Context) error { var constraints []cdbm.AllocationConstraint if len(allocationIDs) > 0 { acDAO := cdbm.NewAllocationConstraintDAO(gtitsh.dbSession) - constraints, _, err = acDAO.GetAll(ctx, nil, allocationIDs, - cutil.GetPtr(cdbm.AllocationResourceTypeInstanceType), instanceTypeIDs, - nil, nil, []string{"Allocation.Tenant"}, nil, cutil.GetPtr(cdbp.TotalLimit), nil) + constraints, _, err = acDAO.GetAll(ctx, nil, cdbm.AllocationConstraintFilterInput{ + AllocationIDs: allocationIDs, + ResourceType: cutil.GetPtr(cdbm.AllocationResourceTypeInstanceType), + ResourceTypeIDs: instanceTypeIDs, + }, cdbp.PageInput{Limit: cutil.GetPtr(cdbp.TotalLimit)}, []string{"Allocation.Tenant"}) if err != nil { logger.Error().Err(err).Msg("error retrieving allocation constraints") return cutil.NewAPIErrorResponse(c, http.StatusInternalServerError, "Failed to retrieve allocation constraints", nil) @@ -526,9 +528,11 @@ func (gmitsh GetMachineInstanceTypeStatsHandler) Handle(c echo.Context) error { var constraints []cdbm.AllocationConstraint if len(allocationIDs) > 0 { acDAO := cdbm.NewAllocationConstraintDAO(gmitsh.dbSession) - constraints, _, err = acDAO.GetAll(ctx, nil, allocationIDs, - cutil.GetPtr(cdbm.AllocationResourceTypeInstanceType), instanceTypeIDs, - nil, nil, []string{"Allocation.Tenant"}, nil, cutil.GetPtr(cdbp.TotalLimit), nil) + constraints, _, err = acDAO.GetAll(ctx, nil, cdbm.AllocationConstraintFilterInput{ + AllocationIDs: allocationIDs, + ResourceType: cutil.GetPtr(cdbm.AllocationResourceTypeInstanceType), + ResourceTypeIDs: instanceTypeIDs, + }, cdbp.PageInput{Limit: cutil.GetPtr(cdbp.TotalLimit)}, []string{"Allocation.Tenant"}) if err != nil { logger.Error().Err(err).Msg("error retrieving allocation constraints") return cutil.NewAPIErrorResponse(c, http.StatusInternalServerError, "Failed to retrieve allocation constraints", nil) diff --git a/rest-api/api/pkg/api/handler/util/common/common.go b/rest-api/api/pkg/api/handler/util/common/common.go index f7875d7447..5545f20020 100644 --- a/rest-api/api/pkg/api/handler/util/common/common.go +++ b/rest-api/api/pkg/api/handler/util/common/common.go @@ -215,7 +215,12 @@ func GetAllocationConstraintsForInstanceType(ctx context.Context, tx *cdb.Tx, db var alconstraints []cdbm.AllocationConstraint for _, ac := range allocations { // improve this query by adding allocation slices in allocation constraints model - alcoss, _, err := alcsDAO.GetAll(ctx, tx, []uuid.UUID{ac.ID}, cutil.GetPtr(cdbm.AllocationResourceTypeInstanceType), []uuid.UUID{instancetype.ID}, cutil.GetPtr(cdbm.AllocationConstraintTypeReserved), nil, nil, nil, cutil.GetPtr(cdbp.TotalLimit), nil) + alcoss, _, err := alcsDAO.GetAll(ctx, tx, cdbm.AllocationConstraintFilterInput{ + AllocationIDs: []uuid.UUID{ac.ID}, + ResourceType: cutil.GetPtr(cdbm.AllocationResourceTypeInstanceType), + ResourceTypeIDs: []uuid.UUID{instancetype.ID}, + ConstraintType: cutil.GetPtr(cdbm.AllocationConstraintTypeReserved), + }, cdbp.PageInput{Limit: cutil.GetPtr(cdbp.TotalLimit)}, nil) if err != nil { return nil, err } @@ -458,7 +463,10 @@ func GetSiteMachineCountStats(ctx context.Context, tx *cdb.Tx, dbSession *cdb.Se // Get all Allocation Constraints for Allocation IDs acDAO := cdbm.NewAllocationConstraintDAO(dbSession) - acs, _, err := acDAO.GetAll(ctx, tx, allocationIDs, cutil.GetPtr(cdbm.AllocationResourceTypeInstanceType), nil, nil, nil, nil, nil, cutil.GetPtr(cdbp.TotalLimit), nil) + acs, _, err := acDAO.GetAll(ctx, tx, cdbm.AllocationConstraintFilterInput{ + AllocationIDs: allocationIDs, + ResourceType: cutil.GetPtr(cdbm.AllocationResourceTypeInstanceType), + }, cdbp.PageInput{Limit: cutil.GetPtr(cdbp.TotalLimit)}, nil) if err != nil { return nil, err } @@ -495,7 +503,12 @@ func GetTotalAllocationConstraintValueForInstanceType(ctx context.Context, tx *c if instanceTypeID != nil { instanceTypeIDs = []uuid.UUID{*instanceTypeID} } - acs, _, err := acDAO.GetAll(ctx, tx, paramAllocationIDs, cutil.GetPtr(cdbm.AllocationResourceTypeInstanceType), instanceTypeIDs, constraintType, nil, nil, nil, cutil.GetPtr(cdbp.TotalLimit), nil) + acs, _, err := acDAO.GetAll(ctx, tx, cdbm.AllocationConstraintFilterInput{ + AllocationIDs: paramAllocationIDs, + ResourceType: cutil.GetPtr(cdbm.AllocationResourceTypeInstanceType), + ResourceTypeIDs: instanceTypeIDs, + ConstraintType: constraintType, + }, cdbp.PageInput{Limit: cutil.GetPtr(cdbp.TotalLimit)}, nil) if err != nil { return 0, err } @@ -552,7 +565,11 @@ func GetAllAllocationConstraintsForInstanceType(ctx context.Context, tx *cdb.Tx, for _, alloc := range allocs { allocIDs = append(allocIDs, alloc.ID) } - acs, tot, serr := acDAO.GetAll(ctx, tx, allocIDs, cutil.GetPtr(cdbm.AllocationResourceTypeInstanceType), resourceTypeIDs, nil, nil, nil, nil, cutil.GetPtr(cdbp.TotalLimit), nil) + acs, tot, serr := acDAO.GetAll(ctx, tx, cdbm.AllocationConstraintFilterInput{ + AllocationIDs: allocIDs, + ResourceType: cutil.GetPtr(cdbm.AllocationResourceTypeInstanceType), + ResourceTypeIDs: resourceTypeIDs, + }, cdbp.PageInput{Limit: cutil.GetPtr(cdbp.TotalLimit)}, nil) if serr != nil { return nil, 0, serr } @@ -646,7 +663,11 @@ func GetAllInstanceTypeAllocationStats(ctx context.Context, dbSession *cdb.Sessi // Get all Allocation Constraints for the Instance Type IDs and Allocation IDs acDAO := cdbm.NewAllocationConstraintDAO(dbSession) - acss, _, err := acDAO.GetAll(ctx, nil, aids, cutil.GetPtr(cdbm.AllocationResourceTypeInstanceType), instanceTypeIDs, nil, nil, nil, nil, cutil.GetPtr(cdbp.TotalLimit), nil) + acss, _, err := acDAO.GetAll(ctx, nil, cdbm.AllocationConstraintFilterInput{ + AllocationIDs: aids, + ResourceType: cutil.GetPtr(cdbm.AllocationResourceTypeInstanceType), + ResourceTypeIDs: instanceTypeIDs, + }, cdbp.PageInput{Limit: cutil.GetPtr(cdbp.TotalLimit)}, nil) if err != nil { return nil, cutil.NewAPIError(http.StatusInternalServerError, "Error retrieving Allocations for Instance Type, DB error", nil) } diff --git a/rest-api/api/pkg/api/handler/util/common/testing.go b/rest-api/api/pkg/api/handler/util/common/testing.go index 5f05d2e01b..900aed30eb 100644 --- a/rest-api/api/pkg/api/handler/util/common/testing.go +++ b/rest-api/api/pkg/api/handler/util/common/testing.go @@ -321,10 +321,18 @@ func TestBuildAllocationConstraint(t *testing.T, dbSession *cdb.Session, al *cdb var ac *cdbm.AllocationConstraint var err error if it != nil { - ac, err = acDAO.CreateFromParams(context.Background(), nil, al.ID, cdbm.AllocationResourceTypeInstanceType, it.ID, cdbm.AllocationConstraintTypeReserved, constraintValue, nil, user.ID) + ac, err = acDAO.Create(context.Background(), nil, cdbm.AllocationConstraintCreateInput{ + AllocationID: al.ID, ResourceType: cdbm.AllocationResourceTypeInstanceType, + ResourceTypeID: it.ID, ConstraintType: cdbm.AllocationConstraintTypeReserved, + ConstraintValue: constraintValue, CreatedBy: user.ID, + }) assert.Nil(t, err) } else if ipb != nil { - ac, err = acDAO.CreateFromParams(context.Background(), nil, al.ID, cdbm.AllocationResourceTypeIPBlock, ipb.ID, cdbm.AllocationConstraintTypeReserved, constraintValue, nil, user.ID) + ac, err = acDAO.Create(context.Background(), nil, cdbm.AllocationConstraintCreateInput{ + AllocationID: al.ID, ResourceType: cdbm.AllocationResourceTypeIPBlock, + ResourceTypeID: ipb.ID, ConstraintType: cdbm.AllocationConstraintTypeReserved, + ConstraintValue: constraintValue, CreatedBy: user.ID, + }) assert.Nil(t, err) } return ac diff --git a/rest-api/db/pkg/db/model/allocation_test.go b/rest-api/db/pkg/db/model/allocation_test.go index 99fbccee17..c5c51ff17d 100644 --- a/rest-api/db/pkg/db/model/allocation_test.go +++ b/rest-api/db/pkg/db/model/allocation_test.go @@ -391,11 +391,19 @@ func TestAllocationSQLDAO_GetAll(t *testing.T) { if i%2 == 0 { // Create AllocationConstraint for every other Allocation var serr error - allocationConstraint1, serr = acDAO.CreateFromParams(ctx, nil, at.ID, AllocationResourceTypeInstanceType, it.ID, AllocationConstraintTypeReserved, 5, nil, user.ID) + allocationConstraint1, serr = acDAO.Create(ctx, nil, AllocationConstraintCreateInput{ + AllocationID: at.ID, ResourceType: AllocationResourceTypeInstanceType, + ResourceTypeID: it.ID, ConstraintType: AllocationConstraintTypeReserved, + ConstraintValue: 5, CreatedBy: user.ID, + }) assert.NoError(t, serr) } else { var serr error - allocationConstraint2, serr = acDAO.CreateFromParams(ctx, nil, at.ID, AllocationResourceTypeIPBlock, ipb.ID, AllocationConstraintTypeReserved, 10, nil, user.ID) + allocationConstraint2, serr = acDAO.Create(ctx, nil, AllocationConstraintCreateInput{ + AllocationID: at.ID, ResourceType: AllocationResourceTypeIPBlock, + ResourceTypeID: ipb.ID, ConstraintType: AllocationConstraintTypeReserved, + ConstraintValue: 10, CreatedBy: user.ID, + }) assert.NoError(t, serr) } } @@ -1102,9 +1110,17 @@ func TestAllocationSQLDAO_GetCount(t *testing.T) { asTenant1 = append(asTenant1, *at) if i%2 == 0 { // Create AllocationConstraint for every other Allocation - _, serr := acDAO.CreateFromParams(ctx, nil, at.ID, AllocationResourceTypeInstanceType, it.ID, AllocationConstraintTypeReserved, 5, nil, user.ID) + _, serr := acDAO.Create(ctx, nil, AllocationConstraintCreateInput{ + AllocationID: at.ID, ResourceType: AllocationResourceTypeInstanceType, + ResourceTypeID: it.ID, ConstraintType: AllocationConstraintTypeReserved, + ConstraintValue: 5, CreatedBy: user.ID, + }) assert.NoError(t, serr) - _, serr = acDAO.CreateFromParams(ctx, nil, at.ID, AllocationResourceTypeInstanceType, it2.ID, AllocationConstraintTypeReserved, 5, nil, user.ID) + _, serr = acDAO.Create(ctx, nil, AllocationConstraintCreateInput{ + AllocationID: at.ID, ResourceType: AllocationResourceTypeInstanceType, + ResourceTypeID: it2.ID, ConstraintType: AllocationConstraintTypeReserved, + ConstraintValue: 5, CreatedBy: user.ID, + }) assert.NoError(t, serr) } } diff --git a/rest-api/db/pkg/db/model/allocationconstraint.go b/rest-api/db/pkg/db/model/allocationconstraint.go index b19a68d27b..51f6e90f1d 100644 --- a/rest-api/db/pkg/db/model/allocationconstraint.go +++ b/rest-api/db/pkg/db/model/allocationconstraint.go @@ -57,6 +57,43 @@ var ( } ) +// AllocationConstraintCreateInput input parameters for Create method +type AllocationConstraintCreateInput struct { + AllocationID uuid.UUID + ResourceType string + ResourceTypeID uuid.UUID + ConstraintType string + ConstraintValue int + DerivedResourceID *uuid.UUID + CreatedBy uuid.UUID +} + +// AllocationConstraintUpdateInput input parameters for Update method +type AllocationConstraintUpdateInput struct { + AllocationConstraintID uuid.UUID + AllocationID *uuid.UUID + ResourceType *string + ResourceTypeID *uuid.UUID + ConstraintType *string + ConstraintValue *int + DerivedResourceID *uuid.UUID +} + +// AllocationConstraintClearInput input parameters for Clear method +type AllocationConstraintClearInput struct { + AllocationConstraintID uuid.UUID + DerivedResourceID bool +} + +// AllocationConstraintFilterInput input parameters for GetAll method +type AllocationConstraintFilterInput struct { + AllocationIDs []uuid.UUID + ResourceType *string + ResourceTypeIDs []uuid.UUID + ConstraintType *string + DerivedResourceID *uuid.UUID +} + // AllocationConstraint represents entries in the allocation_constraint table // Constraints an allocation by specifying limits for different resource types type AllocationConstraint struct { @@ -108,28 +145,15 @@ func (ac *AllocationConstraint) BeforeCreateTable(ctx context.Context, // AllocationConstraintDAO is an interface for interacting with the AllocationConstraint model type AllocationConstraintDAO interface { // - CreateFromParams(ctx context.Context, tx *db.Tx, - allocationID uuid.UUID, resourceType string, - resourceTypeID uuid.UUID, constraintType string, - constraintValue int, derivedResourceID *uuid.UUID, - createdBy uuid.UUID) (*AllocationConstraint, error) + Create(ctx context.Context, tx *db.Tx, input AllocationConstraintCreateInput) (*AllocationConstraint, error) // - GetByID(ctx context.Context, tx *db.Tx, id uuid.UUID, - includeRelations []string) (*AllocationConstraint, error) + GetByID(ctx context.Context, tx *db.Tx, id uuid.UUID, includeRelations []string) (*AllocationConstraint, error) // - GetAll(ctx context.Context, tx *db.Tx, - allocationIDs []uuid.UUID, resourceType *string, - resourceTypeID []uuid.UUID, constraintType *string, - derivedResourceID *uuid.UUID, includeRelations []string, - offset *int, limit *int, orderBy *paginator.OrderBy) ([]AllocationConstraint, int, error) + GetAll(ctx context.Context, tx *db.Tx, filter AllocationConstraintFilterInput, page paginator.PageInput, includeRelations []string) ([]AllocationConstraint, int, error) // - UpdateFromParams(ctx context.Context, tx *db.Tx, id uuid.UUID, - allocationID *uuid.UUID, resourceType *string, - resourceTypeID *uuid.UUID, constraintType *string, - constraintValue *int, derivedResourceID *uuid.UUID) (*AllocationConstraint, error) + Update(ctx context.Context, tx *db.Tx, input AllocationConstraintUpdateInput) (*AllocationConstraint, error) // - ClearFromParams(ctx context.Context, tx *db.Tx, id uuid.UUID, - derivedResourceID bool) (*AllocationConstraint, error) + Clear(ctx context.Context, tx *db.Tx, input AllocationConstraintClearInput) (*AllocationConstraint, error) // DeleteByID(ctx context.Context, tx *db.Tx, id uuid.UUID) error } @@ -141,39 +165,34 @@ type AllocationConstraintSQLDAO struct { tracerSpan *stracer.TracerSpan } -// CreateFromParams creates a new AllocationConstraint from the given parameters -// The returned AllocationConstraint will not have any related structs filled in -// since there are 2 operations (INSERT, SELECT), in this, it is required that -// this library call happens within a transaction -func (acd AllocationConstraintSQLDAO) CreateFromParams( - ctx context.Context, tx *db.Tx, allocationID uuid.UUID, - resourceType string, resourceTypeID uuid.UUID, - constraintType string, constraintValue int, - derivedResourceID *uuid.UUID, createdBy uuid.UUID) (*AllocationConstraint, error) { +// Create creates a new AllocationConstraint from the given input. +// The returned AllocationConstraint will not have any related structs filled in. +// Since there are 2 operations (INSERT, SELECT), this call must happen within a transaction. +func (acd AllocationConstraintSQLDAO) Create( + ctx context.Context, tx *db.Tx, input AllocationConstraintCreateInput) (*AllocationConstraint, error) { // Create a child span and set the attributes for current request - ctx, aDAOSpan := acd.tracerSpan.CreateChildInCurrentContext(ctx, "AllocationConstraintDAO.CreateFromParams") + ctx, aDAOSpan := acd.tracerSpan.CreateChildInCurrentContext(ctx, "AllocationConstraintDAO.Create") if aDAOSpan != nil { defer aDAOSpan.End() - acd.tracerSpan.SetAttribute(aDAOSpan, "allocation_id", allocationID.String()) + acd.tracerSpan.SetAttribute(aDAOSpan, "allocation_id", input.AllocationID.String()) } - // - if len(strings.TrimSpace(resourceType)) == 0 { + if len(strings.TrimSpace(input.ResourceType)) == 0 { return nil, errors.New("resourceType is empty") } - if len(strings.TrimSpace(constraintType)) == 0 { + if len(strings.TrimSpace(input.ConstraintType)) == 0 { return nil, errors.New("constraintType is empty") } a := &AllocationConstraint{ ID: uuid.New(), - AllocationID: allocationID, - ResourceType: resourceType, - ResourceTypeID: resourceTypeID, - ConstraintType: constraintType, - ConstraintValue: constraintValue, - DerivedResourceID: derivedResourceID, - CreatedBy: createdBy, + AllocationID: input.AllocationID, + ResourceType: input.ResourceType, + ResourceTypeID: input.ResourceTypeID, + ConstraintType: input.ConstraintType, + ConstraintValue: input.ConstraintValue, + DerivedResourceID: input.DerivedResourceID, + CreatedBy: input.CreatedBy, } _, err := db.GetIDB(tx, acd.dbSession).NewInsert().Model(a).Exec(ctx) if err != nil { @@ -219,15 +238,12 @@ func (acd AllocationConstraintSQLDAO) GetByID(ctx context.Context, tx *db.Tx, id return a, nil } -// GetAll returns all AllocationConstraints for an InstanceType -// Errors are returned only when there is a db related error -// if records not found, then error is nil, but length of returned slice is 0 -// if orderBy is nil, then records are ordered by column specified in AllocationConstraintOrderByDefault in ascending order +// GetAll returns all AllocationConstraints matching the given filter. +// Errors are returned only when there is a db related error. +// If records not found, then error is nil, but length of returned slice is 0. +// If orderBy is nil, then records are ordered by column specified in AllocationConstraintOrderByDefault in ascending order. func (acd AllocationConstraintSQLDAO) GetAll(ctx context.Context, tx *db.Tx, - allocationIDs []uuid.UUID, resourceType *string, - resourceTypeIDs []uuid.UUID, constraintType *string, - derivedResourceID *uuid.UUID, includeRelations []string, - offset *int, limit *int, orderBy *paginator.OrderBy) ([]AllocationConstraint, int, error) { + filter AllocationConstraintFilterInput, page paginator.PageInput, includeRelations []string) ([]AllocationConstraint, int, error) { acs := []AllocationConstraint{} // Create a child span and set the attributes for current request ctx, aDAOSpan := acd.tracerSpan.CreateChildInCurrentContext(ctx, "AllocationConstraintDAO.GetAll") @@ -237,47 +253,51 @@ func (acd AllocationConstraintSQLDAO) GetAll(ctx context.Context, tx *db.Tx, query := db.GetIDB(tx, acd.dbSession).NewSelect().Model(&acs) - if allocationIDs != nil { - if len(allocationIDs) == 1 { - query = query.Where("ac.allocation_id = ?", allocationIDs[0]) + if filter.AllocationIDs != nil { + if len(filter.AllocationIDs) == 1 { + query = query.Where("ac.allocation_id = ?", filter.AllocationIDs[0]) } else { - query = query.Where("ac.allocation_id IN (?)", bun.In(allocationIDs)) + query = query.Where("ac.allocation_id IN (?)", bun.In(filter.AllocationIDs)) + } + + if aDAOSpan != nil { + acd.tracerSpan.SetAttribute(aDAOSpan, "allocation_ids", filter.AllocationIDs) } } - if resourceType != nil { - query = query.Where("ac.resource_type = ?", *resourceType) + if filter.ResourceType != nil { + query = query.Where("ac.resource_type = ?", *filter.ResourceType) if aDAOSpan != nil { - acd.tracerSpan.SetAttribute(aDAOSpan, "resource_type", *resourceType) + acd.tracerSpan.SetAttribute(aDAOSpan, "resource_type", *filter.ResourceType) } } - if resourceTypeIDs != nil { - if len(resourceTypeIDs) == 1 { - query = query.Where("ac.resource_type_id = ?", resourceTypeIDs[0]) + if filter.ResourceTypeIDs != nil { + if len(filter.ResourceTypeIDs) == 1 { + query = query.Where("ac.resource_type_id = ?", filter.ResourceTypeIDs[0]) } else { - query = query.Where("ac.resource_type_id IN (?)", bun.In(resourceTypeIDs)) + query = query.Where("ac.resource_type_id IN (?)", bun.In(filter.ResourceTypeIDs)) } if aDAOSpan != nil { - acd.tracerSpan.SetAttribute(aDAOSpan, "resource_type_ids", resourceTypeIDs) + acd.tracerSpan.SetAttribute(aDAOSpan, "resource_type_ids", filter.ResourceTypeIDs) } } - if constraintType != nil { - query = query.Where("ac.constraint_type = ?", *constraintType) + if filter.ConstraintType != nil { + query = query.Where("ac.constraint_type = ?", *filter.ConstraintType) if aDAOSpan != nil { - acd.tracerSpan.SetAttribute(aDAOSpan, "constraint_type", *constraintType) + acd.tracerSpan.SetAttribute(aDAOSpan, "constraint_type", *filter.ConstraintType) } } - if derivedResourceID != nil { - query = query.Where("ac.derived_resource_id = ?", *derivedResourceID) + if filter.DerivedResourceID != nil { + query = query.Where("ac.derived_resource_id = ?", *filter.DerivedResourceID) if aDAOSpan != nil { - acd.tracerSpan.SetAttribute(aDAOSpan, "derived_resource_id", *derivedResourceID) + acd.tracerSpan.SetAttribute(aDAOSpan, "derived_resource_id", filter.DerivedResourceID.String()) } } @@ -286,11 +306,11 @@ func (acd AllocationConstraintSQLDAO) GetAll(ctx context.Context, tx *db.Tx, } // if no order is passed, set default to make sure objects return always in the same order and pagination works properly - if orderBy == nil { - orderBy = paginator.NewDefaultOrderBy(AllocationConstraintOrderByDefault) + if page.OrderBy == nil { + page.OrderBy = paginator.NewDefaultOrderBy(AllocationConstraintOrderByDefault) } - paginator, err := paginator.NewPaginator(ctx, query, offset, limit, orderBy, SiteOrderByFields) + paginator, err := paginator.NewPaginator(ctx, query, page.Offset, page.Limit, page.OrderBy, AllocationConstraintOrderByFields) if err != nil { return nil, 0, err } @@ -303,87 +323,83 @@ func (acd AllocationConstraintSQLDAO) GetAll(ctx context.Context, tx *db.Tx, return acs, paginator.Total, nil } -// UpdateFromParams updates specified fields of an existing AllocationConstraint -// The updated fields are assumed to be set to non-null values -// since there are 2 operations (UPDATE, SELECT), in this, it is required that -// this library call happens within a transaction -func (acd AllocationConstraintSQLDAO) UpdateFromParams(ctx context.Context, tx *db.Tx, id uuid.UUID, - allocationID *uuid.UUID, resourceType *string, - resourceTypeID *uuid.UUID, constraintType *string, - constraintValue *int, derivedResourceID *uuid.UUID) (*AllocationConstraint, error) { +// Update updates specified fields of an existing AllocationConstraint. +// The updated fields are assumed to be set to non-null values. +// Since there are 2 operations (UPDATE, SELECT), this call must happen within a transaction. +func (acd AllocationConstraintSQLDAO) Update(ctx context.Context, tx *db.Tx, input AllocationConstraintUpdateInput) (*AllocationConstraint, error) { // Create a child span and set the attributes for current request - ctx, aDAOSpan := acd.tracerSpan.CreateChildInCurrentContext(ctx, "AllocationConstraintDAO.UpdateFromParams") + ctx, aDAOSpan := acd.tracerSpan.CreateChildInCurrentContext(ctx, "AllocationConstraintDAO.Update") if aDAOSpan != nil { defer aDAOSpan.End() - acd.tracerSpan.SetAttribute(aDAOSpan, "id", id.String()) + acd.tracerSpan.SetAttribute(aDAOSpan, "id", input.AllocationConstraintID.String()) } a := &AllocationConstraint{ - ID: id, + ID: input.AllocationConstraintID, } updatedFields := []string{} - if allocationID != nil { - a.AllocationID = *allocationID + if input.AllocationID != nil { + a.AllocationID = *input.AllocationID updatedFields = append(updatedFields, "allocation_id") if aDAOSpan != nil { - acd.tracerSpan.SetAttribute(aDAOSpan, "allocation_id", allocationID.String()) + acd.tracerSpan.SetAttribute(aDAOSpan, "allocation_id", input.AllocationID.String()) } } - if resourceType != nil { - if len(strings.TrimSpace(*resourceType)) == 0 { + if input.ResourceType != nil { + if len(strings.TrimSpace(*input.ResourceType)) == 0 { return nil, errors.New("resourceType is empty") } - a.ResourceType = *resourceType + a.ResourceType = *input.ResourceType updatedFields = append(updatedFields, "resource_type") if aDAOSpan != nil { - acd.tracerSpan.SetAttribute(aDAOSpan, "resource_type", *resourceType) + acd.tracerSpan.SetAttribute(aDAOSpan, "resource_type", *input.ResourceType) } } - if resourceTypeID != nil { - a.ResourceTypeID = *resourceTypeID + if input.ResourceTypeID != nil { + a.ResourceTypeID = *input.ResourceTypeID updatedFields = append(updatedFields, "resource_type_id") if aDAOSpan != nil { - acd.tracerSpan.SetAttribute(aDAOSpan, "resource_type_id", resourceTypeID.String()) + acd.tracerSpan.SetAttribute(aDAOSpan, "resource_type_id", input.ResourceTypeID.String()) } } - if constraintType != nil { - if len(strings.TrimSpace(*constraintType)) == 0 { + if input.ConstraintType != nil { + if len(strings.TrimSpace(*input.ConstraintType)) == 0 { return nil, errors.New("constraintType is empty") } - a.ConstraintType = *constraintType + a.ConstraintType = *input.ConstraintType updatedFields = append(updatedFields, "constraint_type") if aDAOSpan != nil { - acd.tracerSpan.SetAttribute(aDAOSpan, "constraint_type", *constraintType) + acd.tracerSpan.SetAttribute(aDAOSpan, "constraint_type", *input.ConstraintType) } } - if constraintValue != nil { - a.ConstraintValue = *constraintValue + if input.ConstraintValue != nil { + a.ConstraintValue = *input.ConstraintValue updatedFields = append(updatedFields, "constraint_value") if aDAOSpan != nil { - acd.tracerSpan.SetAttribute(aDAOSpan, "constraint_value", *constraintValue) + acd.tracerSpan.SetAttribute(aDAOSpan, "constraint_value", *input.ConstraintValue) } } - if derivedResourceID != nil { - a.DerivedResourceID = derivedResourceID + if input.DerivedResourceID != nil { + a.DerivedResourceID = input.DerivedResourceID updatedFields = append(updatedFields, "derived_resource_id") if aDAOSpan != nil { - acd.tracerSpan.SetAttribute(aDAOSpan, "derived_resource_id", derivedResourceID.String()) + acd.tracerSpan.SetAttribute(aDAOSpan, "derived_resource_id", input.DerivedResourceID.String()) } } if len(updatedFields) > 0 { updatedFields = append(updatedFields, "updated") - _, err := db.GetIDB(tx, acd.dbSession).NewUpdate().Model(a).Column(updatedFields...).Where("id = ?", id).Exec(ctx) + _, err := db.GetIDB(tx, acd.dbSession).NewUpdate().Model(a).Column(updatedFields...).Where("id = ?", input.AllocationConstraintID).Exec(ctx) if err != nil { return nil, err } @@ -397,25 +413,23 @@ func (acd AllocationConstraintSQLDAO) UpdateFromParams(ctx context.Context, tx * return nv, nil } -// ClearFromParams sets parameters of an existing AllocationConstraint to null values in db -// since there are 2 operations (UPDATE, SELECT), it is required that -// this must be within a transaction -func (acd AllocationConstraintSQLDAO) ClearFromParams(ctx context.Context, tx *db.Tx, id uuid.UUID, - derivedResourceID bool) (*AllocationConstraint, error) { +// Clear sets parameters of an existing AllocationConstraint to null values in db. +// Since there are 2 operations (UPDATE, SELECT), this must be within a transaction. +func (acd AllocationConstraintSQLDAO) Clear(ctx context.Context, tx *db.Tx, input AllocationConstraintClearInput) (*AllocationConstraint, error) { // Create a child span and set the attributes for current request - ctx, aDAOSpan := acd.tracerSpan.CreateChildInCurrentContext(ctx, "AllocationConstraintDAO.ClearFromParams") + ctx, aDAOSpan := acd.tracerSpan.CreateChildInCurrentContext(ctx, "AllocationConstraintDAO.Clear") if aDAOSpan != nil { defer aDAOSpan.End() - acd.tracerSpan.SetAttribute(aDAOSpan, "id", id.String()) + acd.tracerSpan.SetAttribute(aDAOSpan, "id", input.AllocationConstraintID.String()) } a := &AllocationConstraint{ - ID: id, + ID: input.AllocationConstraintID, } updatedFields := []string{} - if derivedResourceID { + if input.DerivedResourceID { a.DerivedResourceID = nil updatedFields = append(updatedFields, "derived_resource_id") } @@ -423,13 +437,13 @@ func (acd AllocationConstraintSQLDAO) ClearFromParams(ctx context.Context, tx *d if len(updatedFields) > 0 { updatedFields = append(updatedFields, "updated") - _, err := db.GetIDB(tx, acd.dbSession).NewUpdate().Model(a).Column(updatedFields...).Where("id = ?", id).Exec(ctx) + _, err := db.GetIDB(tx, acd.dbSession).NewUpdate().Model(a).Column(updatedFields...).Where("id = ?", input.AllocationConstraintID).Exec(ctx) if err != nil { return nil, err } } - nv, err := acd.GetByID(ctx, tx, id, []string{"Allocation"}) + nv, err := acd.GetByID(ctx, tx, input.AllocationConstraintID, []string{"Allocation"}) if err != nil { return nil, err } diff --git a/rest-api/db/pkg/db/model/allocationconstraint_test.go b/rest-api/db/pkg/db/model/allocationconstraint_test.go index 836381acf7..5133a00649 100644 --- a/rest-api/db/pkg/db/model/allocationconstraint_test.go +++ b/rest-api/db/pkg/db/model/allocationconstraint_test.go @@ -165,7 +165,7 @@ func testAllocationConstraintBuildIPBlock(t *testing.T, dbSession *db.Session, return ipBlock } -func TestAllocationConstraintSQLDAO_CreateFromParams(t *testing.T) { +func TestAllocationConstraintSQLDAO_Create(t *testing.T) { ctx := context.Background() dbSession := testAllocationConstraintInitDB(t) defer dbSession.Close() @@ -189,30 +189,28 @@ func TestAllocationConstraintSQLDAO_CreateFromParams(t *testing.T) { tests := []struct { desc string - as []AllocationConstraint + inputs []AllocationConstraintCreateInput expectError bool verifyChildSpanner bool }{ { desc: "create constraint of InstanceType and IPBlock", - as: []AllocationConstraint{ + inputs: []AllocationConstraintCreateInput{ { - AllocationID: alloc1.ID, - ResourceType: AllocationResourceTypeInstanceType, - ResourceTypeID: insType.ID, - ConstraintType: AllocationConstraintTypeReserved, - ConstraintValue: 100, - DerivedResourceID: nil, - CreatedBy: user.ID, + AllocationID: alloc1.ID, + ResourceType: AllocationResourceTypeInstanceType, + ResourceTypeID: insType.ID, + ConstraintType: AllocationConstraintTypeReserved, + ConstraintValue: 100, + CreatedBy: user.ID, }, { - AllocationID: alloc2.ID, - ResourceType: AllocationResourceTypeIPBlock, - ResourceTypeID: ipv4Block.ID, - ConstraintType: AllocationConstraintTypeReserved, - ConstraintValue: 100, - DerivedResourceID: nil, - CreatedBy: user.ID, + AllocationID: alloc2.ID, + ResourceType: AllocationResourceTypeIPBlock, + ResourceTypeID: ipv4Block.ID, + ConstraintType: AllocationConstraintTypeReserved, + ConstraintValue: 100, + CreatedBy: user.ID, }, }, expectError: false, @@ -220,39 +218,36 @@ func TestAllocationConstraintSQLDAO_CreateFromParams(t *testing.T) { }, { desc: "failure - foreign key violation on allocation_id", - as: []AllocationConstraint{ + inputs: []AllocationConstraintCreateInput{ { - AllocationID: uuid.New(), - ResourceType: AllocationResourceTypeIPBlock, - ResourceTypeID: ipv4Block.ID, - ConstraintType: AllocationConstraintTypeReserved, - ConstraintValue: 100, - DerivedResourceID: nil, - CreatedBy: user.ID, + AllocationID: uuid.New(), + ResourceType: AllocationResourceTypeIPBlock, + ResourceTypeID: ipv4Block.ID, + ConstraintType: AllocationConstraintTypeReserved, + ConstraintValue: 100, + CreatedBy: user.ID, }, }, expectError: true, }, { desc: "failure - multiple fields with nil", - as: []AllocationConstraint{ + inputs: []AllocationConstraintCreateInput{ { - AllocationID: alloc1.ID, - ResourceType: " ", - ResourceTypeID: ipv4Block.ID, - ConstraintType: AllocationConstraintTypeReserved, - ConstraintValue: 100, - DerivedResourceID: nil, - CreatedBy: user.ID, + AllocationID: alloc1.ID, + ResourceType: " ", + ResourceTypeID: ipv4Block.ID, + ConstraintType: AllocationConstraintTypeReserved, + ConstraintValue: 100, + CreatedBy: user.ID, }, { - AllocationID: alloc1.ID, - ResourceType: AllocationResourceTypeIPBlock, - ResourceTypeID: ipv4Block.ID, - ConstraintType: " ", - ConstraintValue: 100, - DerivedResourceID: nil, - CreatedBy: user.ID, + AllocationID: alloc1.ID, + ResourceType: AllocationResourceTypeIPBlock, + ResourceTypeID: ipv4Block.ID, + ConstraintType: " ", + ConstraintValue: 100, + CreatedBy: user.ID, }, }, expectError: true, @@ -260,12 +255,8 @@ func TestAllocationConstraintSQLDAO_CreateFromParams(t *testing.T) { } for _, tc := range tests { t.Run(tc.desc, func(t *testing.T) { - for _, i := range tc.as { - it, err := asd.CreateFromParams( - ctx, nil, i.AllocationID, i.ResourceType, - i.ResourceTypeID, i.ConstraintType, - i.ConstraintValue, i.DerivedResourceID, - i.CreatedBy) + for _, input := range tc.inputs { + it, err := asd.Create(ctx, nil, input) assert.Equal(t, tc.expectError, err != nil) if !tc.expectError { assert.NotNil(t, it) @@ -301,15 +292,25 @@ func TestAllocationConstraintSQLDAO_GetByID(t *testing.T) { asd := NewAllocationConstraintDAO(dbSession) - a1, err := asd.CreateFromParams( - ctx, nil, alloc1.ID, AllocationResourceTypeInstanceType, - insType.ID, AllocationConstraintTypeReserved, 10, nil, user.ID) + a1, err := asd.Create(ctx, nil, AllocationConstraintCreateInput{ + AllocationID: alloc1.ID, + ResourceType: AllocationResourceTypeInstanceType, + ResourceTypeID: insType.ID, + ConstraintType: AllocationConstraintTypeReserved, + ConstraintValue: 10, + CreatedBy: user.ID, + }) assert.Nil(t, err) assert.NotNil(t, a1) - a2, err := asd.CreateFromParams( - ctx, nil, alloc2.ID, AllocationResourceTypeIPBlock, - ipv4Block.ID, AllocationConstraintTypeReserved, 10, nil, user.ID) + a2, err := asd.Create(ctx, nil, AllocationConstraintCreateInput{ + AllocationID: alloc2.ID, + ResourceType: AllocationResourceTypeIPBlock, + ResourceTypeID: ipv4Block.ID, + ConstraintType: AllocationConstraintTypeReserved, + ConstraintValue: 10, + CreatedBy: user.ID, + }) assert.Nil(t, err) assert.NotNil(t, a2) @@ -397,20 +398,31 @@ func TestAllocationConstraintSQLDAO_GetAll(t *testing.T) { totalCount := 30 asc1it := []AllocationConstraint{} - for i := 0; i < totalCount/2; i++ { - asc1, err := asd.CreateFromParams( - ctx, nil, alloc1.ID, AllocationResourceTypeInstanceType, - insType.ID, AllocationConstraintTypeReserved, 10, nil, user.ID) + for range totalCount / 2 { + asc1, err := asd.Create(ctx, nil, AllocationConstraintCreateInput{ + AllocationID: alloc1.ID, + ResourceType: AllocationResourceTypeInstanceType, + ResourceTypeID: insType.ID, + ConstraintType: AllocationConstraintTypeReserved, + ConstraintValue: 10, + CreatedBy: user.ID, + }) assert.Nil(t, err) assert.NotNil(t, asc1) asc1it = append(asc1it, *asc1) } asc2ipb := []AllocationConstraint{} - for i := 0; i < totalCount/2; i++ { - asc2, err := asd.CreateFromParams( - ctx, nil, alloc2.ID, AllocationResourceTypeIPBlock, - ipv4Block.ID, AllocationConstraintTypeReserved, 10, cutil.GetPtr(uuid.New()), user.ID) + for range totalCount / 2 { + asc2, err := asd.Create(ctx, nil, AllocationConstraintCreateInput{ + AllocationID: alloc2.ID, + ResourceType: AllocationResourceTypeIPBlock, + ResourceTypeID: ipv4Block.ID, + ConstraintType: AllocationConstraintTypeReserved, + ConstraintValue: 10, + DerivedResourceID: cutil.GetPtr(uuid.New()), + CreatedBy: user.ID, + }) assert.Nil(t, err) assert.NotNil(t, asc2) asc2ipb = append(asc2ipb, *asc2) @@ -421,14 +433,8 @@ func TestAllocationConstraintSQLDAO_GetAll(t *testing.T) { tests := []struct { desc string - allocationIDs []uuid.UUID - resourceType *string - resourceTypeID *uuid.UUID - constraintType *string - derivedResourceID *uuid.UUID - offset *int - limit *int - orderBy *paginator.OrderBy + filter AllocationConstraintFilterInput + page paginator.PageInput firstEntry *AllocationConstraint expectedCount int expectedTotal *int @@ -438,102 +444,96 @@ func TestAllocationConstraintSQLDAO_GetAll(t *testing.T) { }{ { desc: "GetAll with no filters returns objects", - allocationIDs: nil, - resourceType: nil, - resourceTypeID: nil, - constraintType: nil, + filter: AllocationConstraintFilterInput{}, expectedCount: paginator.DefaultLimit, expectedError: false, verifyChildSpanner: true, }, { - desc: "GetAll with Allocation ID filter returns objects", - allocationIDs: []uuid.UUID{alloc2.ID}, - resourceType: nil, - resourceTypeID: nil, - constraintType: nil, + desc: "GetAll with Allocation ID filter returns objects", + filter: AllocationConstraintFilterInput{ + AllocationIDs: []uuid.UUID{alloc2.ID}, + }, expectedCount: totalCount / 2, expectedError: false, paramRelations: []string{AllocationRelationName}, }, { - desc: "GetAll with Resource Type filter returns objects", - allocationIDs: nil, - resourceType: cutil.GetPtr(AllocationResourceTypeIPBlock), - resourceTypeID: nil, - constraintType: nil, - expectedCount: totalCount / 2, - expectedError: false, + desc: "GetAll with Resource Type filter returns objects", + filter: AllocationConstraintFilterInput{ + ResourceType: cutil.GetPtr(AllocationResourceTypeIPBlock), + }, + expectedCount: totalCount / 2, + expectedError: false, }, { - desc: "GetAll with Resource Type ID filter returns objects", - allocationIDs: nil, - resourceType: nil, - resourceTypeID: &insType.ID, - constraintType: nil, - expectedCount: totalCount / 2, - expectedError: false, + desc: "GetAll with Resource Type ID filter returns objects", + filter: AllocationConstraintFilterInput{ + ResourceTypeIDs: []uuid.UUID{insType.ID}, + }, + expectedCount: totalCount / 2, + expectedError: false, }, { - desc: "GetAll with Derived Resource ID filter returns objects", - allocationIDs: nil, - resourceType: nil, - resourceTypeID: nil, - constraintType: nil, - derivedResourceID: asc2ipb[0].DerivedResourceID, - expectedCount: 1, - expectedError: false, + desc: "GetAll with Derived Resource ID filter returns objects", + filter: AllocationConstraintFilterInput{ + DerivedResourceID: asc2ipb[0].DerivedResourceID, + }, + expectedCount: 1, + expectedError: false, }, { - desc: "GetAll with invalid Resource Type ID filter returns no objects", - allocationIDs: nil, - resourceType: nil, - resourceTypeID: cutil.GetPtr(uuid.New()), - constraintType: nil, - expectedCount: 0, - expectedError: false, + desc: "GetAll with invalid Resource Type ID filter returns no objects", + filter: AllocationConstraintFilterInput{ + ResourceTypeIDs: []uuid.UUID{uuid.New()}, + }, + expectedCount: 0, + expectedError: false, }, { - desc: "GetAll with Constraint Type filter returns objects", - allocationIDs: nil, - resourceType: cutil.GetPtr(AllocationResourceTypeIPBlock), - resourceTypeID: nil, - constraintType: cutil.GetPtr(AllocationConstraintTypeReserved), - expectedCount: totalCount / 2, - expectedError: false, + desc: "GetAll with Constraint Type filter returns objects", + filter: AllocationConstraintFilterInput{ + ResourceType: cutil.GetPtr(AllocationResourceTypeIPBlock), + ConstraintType: cutil.GetPtr(AllocationConstraintTypeReserved), + }, + expectedCount: totalCount / 2, + expectedError: false, }, { - desc: "GetAll with limit returns objects", - allocationIDs: []uuid.UUID{alloc1.ID}, - resourceType: nil, - resourceTypeID: nil, - constraintType: nil, - offset: cutil.GetPtr(0), - limit: cutil.GetPtr(5), - expectedCount: 5, - expectedTotal: cutil.GetPtr(totalCount / 2), - expectedError: false, + desc: "GetAll with limit returns objects", + filter: AllocationConstraintFilterInput{ + AllocationIDs: []uuid.UUID{alloc1.ID}, + }, + page: paginator.PageInput{ + Offset: cutil.GetPtr(0), + Limit: cutil.GetPtr(5), + }, + expectedCount: 5, + expectedTotal: cutil.GetPtr(totalCount / 2), + expectedError: false, }, { - desc: "GetAll with offset returns objects", - allocationIDs: []uuid.UUID{alloc2.ID}, - resourceType: nil, - resourceTypeID: nil, - constraintType: nil, - offset: cutil.GetPtr(5), - expectedCount: 10, - expectedTotal: cutil.GetPtr(totalCount / 2), - expectedError: false, + desc: "GetAll with offset returns objects", + filter: AllocationConstraintFilterInput{ + AllocationIDs: []uuid.UUID{alloc2.ID}, + }, + page: paginator.PageInput{ + Offset: cutil.GetPtr(5), + }, + expectedCount: 10, + expectedTotal: cutil.GetPtr(totalCount / 2), + expectedError: false, }, { - desc: "GetAll with order by returns objects", - allocationIDs: []uuid.UUID{alloc1.ID}, - resourceType: nil, - resourceTypeID: nil, - constraintType: nil, - orderBy: &paginator.OrderBy{ - Field: "created", - Order: paginator.OrderAscending, + desc: "GetAll with order by returns objects", + filter: AllocationConstraintFilterInput{ + AllocationIDs: []uuid.UUID{alloc1.ID}, + }, + page: paginator.PageInput{ + OrderBy: &paginator.OrderBy{ + Field: "created", + Order: paginator.OrderAscending, + }, }, firstEntry: &asc1it[0], expectedCount: totalCount / 2, @@ -541,25 +541,31 @@ func TestAllocationConstraintSQLDAO_GetAll(t *testing.T) { expectedError: false, }, { - desc: "GetAll with resource Type filter returns none objects", - allocationIDs: []uuid.UUID{alloc1.ID}, - resourceType: cutil.GetPtr(AllocationResourceTypeIPBlock), + desc: "GetAll with resource Type filter returns none objects", + filter: AllocationConstraintFilterInput{ + AllocationIDs: []uuid.UUID{alloc1.ID}, + ResourceType: cutil.GetPtr(AllocationResourceTypeIPBlock), + }, expectedCount: 0, expectedTotal: cutil.GetPtr(0), expectedError: false, }, { - desc: "GetAll with resource Type filter returns objects", - allocationIDs: []uuid.UUID{alloc2.ID}, - resourceType: cutil.GetPtr(AllocationResourceTypeIPBlock), + desc: "GetAll with resource Type filter returns objects", + filter: AllocationConstraintFilterInput{ + AllocationIDs: []uuid.UUID{alloc2.ID}, + ResourceType: cutil.GetPtr(AllocationResourceTypeIPBlock), + }, expectedCount: totalCount / 2, expectedTotal: cutil.GetPtr(totalCount / 2), expectedError: false, }, { - desc: "GetAll with resource Type filter with mixed allocation uuids returns objects", - allocationIDs: []uuid.UUID{alloc1.ID, alloc2.ID}, - resourceType: cutil.GetPtr(AllocationResourceTypeIPBlock), + desc: "GetAll with resource Type filter with mixed allocation uuids returns objects", + filter: AllocationConstraintFilterInput{ + AllocationIDs: []uuid.UUID{alloc1.ID, alloc2.ID}, + ResourceType: cutil.GetPtr(AllocationResourceTypeIPBlock), + }, expectedCount: totalCount / 2, expectedTotal: cutil.GetPtr(totalCount / 2), expectedError: false, @@ -567,19 +573,14 @@ func TestAllocationConstraintSQLDAO_GetAll(t *testing.T) { } for _, tc := range tests { t.Run(tc.desc, func(t *testing.T) { - var resourceIDs []uuid.UUID - if tc.resourceTypeID != nil { - resourceIDs = append(resourceIDs, *tc.resourceTypeID) - } - got, total, err := asd.GetAll(ctx, nil, tc.allocationIDs, tc.resourceType, resourceIDs, tc.constraintType, - tc.derivedResourceID, tc.paramRelations, tc.offset, tc.limit, tc.orderBy) + got, total, err := asd.GetAll(ctx, nil, tc.filter, tc.page, tc.paramRelations) assert.Equal(t, tc.expectedError, err != nil) if tc.expectedError { assert.Equal(t, nil, got) } else { assert.Equal(t, tc.expectedCount, len(got)) - if len(tc.paramRelations) > 0 && len(tc.allocationIDs) > 0 { - assert.Equal(t, tc.allocationIDs[0], got[0].Allocation.ID) + if len(tc.paramRelations) > 0 && len(tc.filter.AllocationIDs) > 0 { + assert.Equal(t, tc.filter.AllocationIDs[0], got[0].Allocation.ID) } } @@ -597,7 +598,7 @@ func TestAllocationConstraintSQLDAO_GetAll(t *testing.T) { } } -func TestAllocationConstraintSQLDAO_UpdateFromParams(t *testing.T) { +func TestAllocationConstraintSQLDAO_Update(t *testing.T) { ctx := context.Background() dbSession := testAllocationConstraintInitDB(t) defer dbSession.Close() @@ -620,9 +621,15 @@ func TestAllocationConstraintSQLDAO_UpdateFromParams(t *testing.T) { derivedResourceID2 := uuid.New() constraintValue := 1 constraintValue2 := 2 - a1, err := asd.CreateFromParams( - ctx, nil, alloc1.ID, AllocationResourceTypeInstanceType, insType.ID, - AllocationConstraintTypeReserved, constraintValue, &derivedResourceID, user.ID) + a1, err := asd.Create(ctx, nil, AllocationConstraintCreateInput{ + AllocationID: alloc1.ID, + ResourceType: AllocationResourceTypeInstanceType, + ResourceTypeID: insType.ID, + ConstraintType: AllocationConstraintTypeReserved, + ConstraintValue: constraintValue, + DerivedResourceID: &derivedResourceID, + CreatedBy: user.ID, + }) assert.Nil(t, err) assert.NotNil(t, a1) @@ -630,14 +637,8 @@ func TestAllocationConstraintSQLDAO_UpdateFromParams(t *testing.T) { _, _, ctx = testCommonTraceProviderSetup(t, ctx) tests := []struct { - desc string - - paramAllocationID *uuid.UUID - paramResourceType *string - paramResourceTypeID *uuid.UUID - paramConstraintType *string - paramConstraintValue *int - paramDerivedResourceID *uuid.UUID + desc string + input AllocationConstraintUpdateInput expectedError bool @@ -651,13 +652,10 @@ func TestAllocationConstraintSQLDAO_UpdateFromParams(t *testing.T) { verifyChildSpanner bool }{ { - desc: "can update nothing", - paramAllocationID: nil, - paramResourceType: nil, - paramResourceTypeID: nil, - paramConstraintType: nil, - paramConstraintValue: nil, - paramDerivedResourceID: nil, + desc: "can update nothing", + input: AllocationConstraintUpdateInput{ + AllocationConstraintID: a1.ID, + }, expectedError: false, expectedAllocationID: &alloc1.ID, @@ -670,13 +668,11 @@ func TestAllocationConstraintSQLDAO_UpdateFromParams(t *testing.T) { verifyChildSpanner: true, }, { - desc: "error updating due to foreign key violation", - paramAllocationID: &derivedResourceID, - paramResourceType: nil, - paramResourceTypeID: nil, - paramConstraintType: nil, - paramConstraintValue: nil, - paramDerivedResourceID: nil, + desc: "error updating due to foreign key violation", + input: AllocationConstraintUpdateInput{ + AllocationConstraintID: a1.ID, + AllocationID: &derivedResourceID, + }, expectedError: true, expectedAllocationID: &alloc1.ID, @@ -688,13 +684,13 @@ func TestAllocationConstraintSQLDAO_UpdateFromParams(t *testing.T) { expectedUpdate: true, }, { - desc: "can update AllocationID ResourceType n ID", - paramAllocationID: &alloc2.ID, - paramResourceType: cutil.GetPtr(AllocationResourceTypeIPBlock), - paramResourceTypeID: &ipv4Block.ID, - paramConstraintType: nil, - paramConstraintValue: nil, - paramDerivedResourceID: nil, + desc: "can update AllocationID ResourceType and ResourceTypeID", + input: AllocationConstraintUpdateInput{ + AllocationConstraintID: a1.ID, + AllocationID: &alloc2.ID, + ResourceType: cutil.GetPtr(AllocationResourceTypeIPBlock), + ResourceTypeID: &ipv4Block.ID, + }, expectedError: false, expectedAllocationID: &alloc2.ID, @@ -706,13 +702,13 @@ func TestAllocationConstraintSQLDAO_UpdateFromParams(t *testing.T) { expectedUpdate: true, }, { - desc: "can update Constraint Type Value n ResourceID", - paramAllocationID: nil, - paramResourceType: nil, - paramResourceTypeID: nil, - paramConstraintType: cutil.GetPtr(AllocationConstraintTypePreemptible), - paramConstraintValue: &constraintValue2, - paramDerivedResourceID: &derivedResourceID2, + desc: "can update Constraint Type Value and ResourceID", + input: AllocationConstraintUpdateInput{ + AllocationConstraintID: a1.ID, + ConstraintType: cutil.GetPtr(AllocationConstraintTypePreemptible), + ConstraintValue: &constraintValue2, + DerivedResourceID: &derivedResourceID2, + }, expectedError: false, expectedAllocationID: &alloc2.ID, @@ -724,13 +720,11 @@ func TestAllocationConstraintSQLDAO_UpdateFromParams(t *testing.T) { expectedUpdate: true, }, { - desc: "invalid Constraint Type", - paramAllocationID: nil, - paramResourceType: nil, - paramResourceTypeID: nil, - paramConstraintType: cutil.GetPtr(" "), - paramConstraintValue: nil, - paramDerivedResourceID: nil, + desc: "invalid Constraint Type", + input: AllocationConstraintUpdateInput{ + AllocationConstraintID: a1.ID, + ConstraintType: cutil.GetPtr(" "), + }, expectedError: true, expectedAllocationID: &alloc2.ID, @@ -742,13 +736,11 @@ func TestAllocationConstraintSQLDAO_UpdateFromParams(t *testing.T) { expectedUpdate: true, }, { - desc: "invalid Constraint Type", - paramAllocationID: nil, - paramResourceType: cutil.GetPtr(" "), - paramResourceTypeID: nil, - paramConstraintType: nil, - paramConstraintValue: nil, - paramDerivedResourceID: nil, + desc: "invalid Resource Type", + input: AllocationConstraintUpdateInput{ + AllocationConstraintID: a1.ID, + ResourceType: cutil.GetPtr(" "), + }, expectedError: true, expectedAllocationID: &alloc2.ID, @@ -762,10 +754,7 @@ func TestAllocationConstraintSQLDAO_UpdateFromParams(t *testing.T) { } for _, tc := range tests { t.Run(tc.desc, func(t *testing.T) { - got, err := asd.UpdateFromParams(ctx, nil, a1.ID, - tc.paramAllocationID, tc.paramResourceType, - tc.paramResourceTypeID, tc.paramConstraintType, - tc.paramConstraintValue, tc.paramDerivedResourceID) + got, err := asd.Update(ctx, nil, tc.input) assert.Equal(t, tc.expectedError, err != nil) if !tc.expectedError { assert.NotNil(t, got) @@ -791,7 +780,7 @@ func TestAllocationConstraintSQLDAO_UpdateFromParams(t *testing.T) { } } -func TestAllocationConstraintSQLDAO_ClearFromParams(t *testing.T) { +func TestAllocationConstraintSQLDAO_Clear(t *testing.T) { ctx := context.Background() dbSession := testAllocationConstraintInitDB(t) defer dbSession.Close() @@ -808,10 +797,15 @@ func TestAllocationConstraintSQLDAO_ClearFromParams(t *testing.T) { asd := NewAllocationConstraintDAO(dbSession) - a1, err := asd.CreateFromParams( - ctx, nil, alloc.ID, AllocationResourceTypeInstanceType, - insType.ID, AllocationConstraintTypeReserved, 10, - &dummyUID, user.ID) + a1, err := asd.Create(ctx, nil, AllocationConstraintCreateInput{ + AllocationID: alloc.ID, + ResourceType: AllocationResourceTypeInstanceType, + ResourceTypeID: insType.ID, + ConstraintType: AllocationConstraintTypeReserved, + ConstraintValue: 10, + DerivedResourceID: &dummyUID, + CreatedBy: user.ID, + }) assert.Nil(t, err) assert.NotNil(t, a1) @@ -820,21 +814,22 @@ func TestAllocationConstraintSQLDAO_ClearFromParams(t *testing.T) { tests := []struct { desc string - a *AllocationConstraint - paramDerivedResourceID bool + input AllocationConstraintClearInput expectedDerivedResourceID *uuid.UUID verifyChildSpanner bool }{ { - desc: "can clear derivedResourceId", - a: a1, - paramDerivedResourceID: true, + desc: "can clear derivedResourceId", + input: AllocationConstraintClearInput{ + AllocationConstraintID: a1.ID, + DerivedResourceID: true, + }, expectedDerivedResourceID: nil, }, } for _, tc := range tests { t.Run(tc.desc, func(t *testing.T) { - got, err := asd.ClearFromParams(ctx, nil, tc.a.ID, tc.paramDerivedResourceID) + got, err := asd.Clear(ctx, nil, tc.input) assert.Nil(t, err) assert.NotNil(t, got) assert.Equal(t, tc.expectedDerivedResourceID == nil, @@ -844,7 +839,7 @@ func TestAllocationConstraintSQLDAO_ClearFromParams(t *testing.T) { *got.DerivedResourceID) } - assert.True(t, got.Updated.After(tc.a.Updated)) + assert.True(t, got.Updated.After(a1.Updated)) if tc.verifyChildSpanner { span := otrace.SpanFromContext(ctx) @@ -872,9 +867,14 @@ func TestAllocationConstraintSQLDAO_DeleteByID(t *testing.T) { asd := NewAllocationConstraintDAO(dbSession) - a1, err := asd.CreateFromParams( - ctx, nil, alloc.ID, AllocationResourceTypeInstanceType, - insType.ID, AllocationConstraintTypeReserved, 10, nil, user.ID) + a1, err := asd.Create(ctx, nil, AllocationConstraintCreateInput{ + AllocationID: alloc.ID, + ResourceType: AllocationResourceTypeInstanceType, + ResourceTypeID: insType.ID, + ConstraintType: AllocationConstraintTypeReserved, + ConstraintValue: 10, + CreatedBy: user.ID, + }) assert.Nil(t, err) assert.NotNil(t, a1) diff --git a/rest-api/db/pkg/db/model/domain.go b/rest-api/db/pkg/db/model/domain.go index f979b0c768..dea7350475 100644 --- a/rest-api/db/pkg/db/model/domain.go +++ b/rest-api/db/pkg/db/model/domain.go @@ -38,6 +38,38 @@ var ( } ) +// DomainCreateInput input parameters for Create method +type DomainCreateInput struct { + Hostname string + Org string + ControllerDomainID *uuid.UUID + Status string + CreatedBy uuid.UUID +} + +// DomainUpdateInput input parameters for Update method +type DomainUpdateInput struct { + DomainID uuid.UUID + Hostname *string + Org *string + ControllerDomainID *uuid.UUID + Status *string +} + +// DomainClearInput input parameters for Clear method +type DomainClearInput struct { + DomainID uuid.UUID + ControllerDomainID bool +} + +// DomainFilterInput input parameters for GetAll method +type DomainFilterInput struct { + Hostname *string + Org *string + ControllerDomainID *uuid.UUID + Status *string +} + // Domain contains information about the fully qualified domain // name for determining machine hostnames type Domain struct { @@ -71,17 +103,17 @@ func (d *Domain) BeforeAppendModel(ctx context.Context, query bun.Query) error { // DomainDAO is an interface for interacting with the Domain model type DomainDAO interface { // - CreateFromParams(ctx context.Context, tx *db.Tx, hostname string, org string, controllerDomainID *uuid.UUID, status string, createdBy uuid.UUID) (*Domain, error) + Create(ctx context.Context, tx *db.Tx, input DomainCreateInput) (*Domain, error) // GetByID(ctx context.Context, tx *db.Tx, id uuid.UUID, includeRelations []string) (*Domain, error) // - GetAll(ctx context.Context, tx *db.Tx, hostname, org *string, controllerDomainID *uuid.UUID, status *string, includeRelations []string) ([]Domain, error) + GetAll(ctx context.Context, tx *db.Tx, filter DomainFilterInput, includeRelations []string) ([]Domain, error) // - UpdateFromParams(ctx context.Context, tx *db.Tx, id uuid.UUID, hostname *string, org *string, controllerDomainID *uuid.UUID, status *string) (*Domain, error) + Update(ctx context.Context, tx *db.Tx, input DomainUpdateInput) (*Domain, error) // - ClearFromParams(ctx context.Context, tx *db.Tx, id uuid.UUID, controllerDomainID bool) (*Domain, error) + Clear(ctx context.Context, tx *db.Tx, input DomainClearInput) (*Domain, error) // - DeleteByID(ctx context.Context, tx *db.Tx, id uuid.UUID) error + Delete(ctx context.Context, tx *db.Tx, id uuid.UUID) error } // DomainSQLDAO is an implementation of the DomainDAO interface @@ -91,28 +123,22 @@ type DomainSQLDAO struct { tracerSpan *stracer.TracerSpan } -// CreateFromParams creates a new Domain from the given parameters -// since there are 2 operations (INSERT, SELECT), in this, it is required that -// this library call happens within a transaction -func (dsd DomainSQLDAO) CreateFromParams( - ctx context.Context, tx *db.Tx, - hostname string, - org string, - controllerDomainID *uuid.UUID, - status string, createdBy uuid.UUID) (*Domain, error) { +// Create creates a new Domain from the given input. +// Since there are 2 operations (INSERT, SELECT), this call must happen within a transaction. +func (dsd DomainSQLDAO) Create(ctx context.Context, tx *db.Tx, input DomainCreateInput) (*Domain, error) { // Create a child span and set the attributes for current request - ctx, domainDAOSpan := dsd.tracerSpan.CreateChildInCurrentContext(ctx, "DomainDAO.CreateFromParams") + ctx, domainDAOSpan := dsd.tracerSpan.CreateChildInCurrentContext(ctx, "DomainDAO.Create") if domainDAOSpan != nil { defer domainDAOSpan.End() } d := &Domain{ ID: uuid.New(), - Hostname: hostname, - Org: org, - ControllerDomainID: controllerDomainID, - Status: status, - CreatedBy: createdBy, + Hostname: input.Hostname, + Org: input.Org, + ControllerDomainID: input.ControllerDomainID, + Status: input.Status, + CreatedBy: input.CreatedBy, } _, err := db.GetIDB(tx, dsd.dbSession).NewInsert().Model(d).Exec(ctx) @@ -163,7 +189,7 @@ func (dsd DomainSQLDAO) GetByID(ctx context.Context, tx *db.Tx, id uuid.UUID, in // Optional filters can be specified on hostname, org, controllerDomainID // errors are returned only when there is a db related error // if records not found, then error is nil, but length of returned slice is 0 -func (dsd DomainSQLDAO) GetAll(ctx context.Context, tx *db.Tx, hostname, org *string, controllerDomainID *uuid.UUID, status *string, includeRelations []string) ([]Domain, error) { +func (dsd DomainSQLDAO) GetAll(ctx context.Context, tx *db.Tx, filter DomainFilterInput, includeRelations []string) ([]Domain, error) { // Create a child span and set the attributes for current request ctx, domainDAOSpan := dsd.tracerSpan.CreateChildInCurrentContext(ctx, "DomainDAO.GetAll") if domainDAOSpan != nil { @@ -174,32 +200,32 @@ func (dsd DomainSQLDAO) GetAll(ctx context.Context, tx *db.Tx, hostname, org *st query := db.GetIDB(tx, dsd.dbSession).NewSelect().Model(&d) - if hostname != nil { - query = query.Where("d.hostname = ?", *hostname) + if filter.Hostname != nil { + query = query.Where("d.hostname = ?", *filter.Hostname) if domainDAOSpan != nil { - dsd.tracerSpan.SetAttribute(domainDAOSpan, "hostname", *hostname) + dsd.tracerSpan.SetAttribute(domainDAOSpan, "hostname", *filter.Hostname) } } - if org != nil { - query = query.Where("d.org = ?", *org) + if filter.Org != nil { + query = query.Where("d.org = ?", *filter.Org) if domainDAOSpan != nil { - dsd.tracerSpan.SetAttribute(domainDAOSpan, "org", *org) + dsd.tracerSpan.SetAttribute(domainDAOSpan, "org", *filter.Org) } } - if controllerDomainID != nil { - query = query.Where("d.controller_domain_id = ?", *controllerDomainID) + if filter.ControllerDomainID != nil { + query = query.Where("d.controller_domain_id = ?", *filter.ControllerDomainID) if domainDAOSpan != nil { - dsd.tracerSpan.SetAttribute(domainDAOSpan, "controller_domain_id", controllerDomainID.String()) + dsd.tracerSpan.SetAttribute(domainDAOSpan, "controller_domain_id", filter.ControllerDomainID.String()) } } - if status != nil { - query = query.Where("d.status = ?", *status) + if filter.Status != nil { + query = query.Where("d.status = ?", *filter.Status) if domainDAOSpan != nil { - dsd.tracerSpan.SetAttribute(domainDAOSpan, "status", *status) + dsd.tracerSpan.SetAttribute(domainDAOSpan, "status", *filter.Status) } } @@ -216,66 +242,60 @@ func (dsd DomainSQLDAO) GetAll(ctx context.Context, tx *db.Tx, hostname, org *st return d, nil } -// UpdateFromParams updates specified fields of an existing Domain +// Update updates specified fields of an existing Domain // The updated fields are assumed to be set to non-null values -// For setting to null values, use: ClearFromParams +// For setting to null values, use: Clear // since there are 2 operations (UPDATE, SELECT), in this, it is required that // this library call happens within a transaction -func (dsd DomainSQLDAO) UpdateFromParams( - ctx context.Context, tx *db.Tx, - id uuid.UUID, - hostname *string, - org *string, - controllerDomainID *uuid.UUID, - status *string) (*Domain, error) { +func (dsd DomainSQLDAO) Update(ctx context.Context, tx *db.Tx, input DomainUpdateInput) (*Domain, error) { d := &Domain{ - ID: id, + ID: input.DomainID, } // Create a child span and set the attributes for current request - ctx, domainDAOSpan := dsd.tracerSpan.CreateChildInCurrentContext(ctx, "DomainDAO.UpdateFromParams") + ctx, domainDAOSpan := dsd.tracerSpan.CreateChildInCurrentContext(ctx, "DomainDAO.Update") if domainDAOSpan != nil { defer domainDAOSpan.End() } updatedFields := []string{} - if hostname != nil { - d.Hostname = *hostname + if input.Hostname != nil { + d.Hostname = *input.Hostname updatedFields = append(updatedFields, "hostname") if domainDAOSpan != nil { - dsd.tracerSpan.SetAttribute(domainDAOSpan, "hostname", *hostname) + dsd.tracerSpan.SetAttribute(domainDAOSpan, "hostname", *input.Hostname) } } - if org != nil { - d.Org = *org + if input.Org != nil { + d.Org = *input.Org updatedFields = append(updatedFields, "org") if domainDAOSpan != nil { - dsd.tracerSpan.SetAttribute(domainDAOSpan, "org", *org) + dsd.tracerSpan.SetAttribute(domainDAOSpan, "org", *input.Org) } } - if controllerDomainID != nil { - d.ControllerDomainID = controllerDomainID + if input.ControllerDomainID != nil { + d.ControllerDomainID = input.ControllerDomainID updatedFields = append(updatedFields, "controller_domain_id") if domainDAOSpan != nil { - dsd.tracerSpan.SetAttribute(domainDAOSpan, "controller_domain_id", controllerDomainID.String()) + dsd.tracerSpan.SetAttribute(domainDAOSpan, "controller_domain_id", input.ControllerDomainID.String()) } } - if status != nil { - d.Status = *status + if input.Status != nil { + d.Status = *input.Status updatedFields = append(updatedFields, "status") if domainDAOSpan != nil { - dsd.tracerSpan.SetAttribute(domainDAOSpan, "status", *status) + dsd.tracerSpan.SetAttribute(domainDAOSpan, "status", *input.Status) } } if len(updatedFields) > 0 { updatedFields = append(updatedFields, "updated") - _, err := db.GetIDB(tx, dsd.dbSession).NewUpdate().Model(d).Column(updatedFields...).Where("id = ?", id).Exec(ctx) + _, err := db.GetIDB(tx, dsd.dbSession).NewUpdate().Model(d).Column(updatedFields...).Where("id = ?", input.DomainID).Exec(ctx) if err != nil { return nil, err } @@ -288,24 +308,24 @@ func (dsd DomainSQLDAO) UpdateFromParams( return nv, nil } -// ClearFromParams sets parameters of an existing Domain to null values in db +// Clear sets parameters of an existing Domain to null values in db // parameter controllerDomainID when true, the are set to null in db // since there are 2 operations (UPDATE, SELECT), it is required that // this must be within a transaction -func (dsd DomainSQLDAO) ClearFromParams(ctx context.Context, tx *db.Tx, id uuid.UUID, controllerDomainID bool) (*Domain, error) { +func (dsd DomainSQLDAO) Clear(ctx context.Context, tx *db.Tx, input DomainClearInput) (*Domain, error) { // Create a child span and set the attributes for current request - ctx, domainDAOSpan := dsd.tracerSpan.CreateChildInCurrentContext(ctx, "DomainDAO.ClearFromParams") + ctx, domainDAOSpan := dsd.tracerSpan.CreateChildInCurrentContext(ctx, "DomainDAO.Clear") if domainDAOSpan != nil { defer domainDAOSpan.End() } d := &Domain{ - ID: id, + ID: input.DomainID, } updatedFields := []string{} - if controllerDomainID { + if input.ControllerDomainID { d.ControllerDomainID = nil updatedFields = append(updatedFields, "controller_domain_id") } @@ -313,25 +333,25 @@ func (dsd DomainSQLDAO) ClearFromParams(ctx context.Context, tx *db.Tx, id uuid. if len(updatedFields) > 0 { updatedFields = append(updatedFields, "updated") - _, err := db.GetIDB(tx, dsd.dbSession).NewUpdate().Model(d).Column(updatedFields...).Where("id = ?", id).Exec(ctx) + _, err := db.GetIDB(tx, dsd.dbSession).NewUpdate().Model(d).Column(updatedFields...).Where("id = ?", input.DomainID).Exec(ctx) if err != nil { return nil, err } } - nv, err := dsd.GetByID(ctx, tx, d.ID, nil) + nv, err := dsd.GetByID(ctx, tx, input.DomainID, nil) if err != nil { return nil, err } return nv, nil } -// DeleteByID deletes an Domain by ID +// Delete deletes an Domain by ID // error is returned only if there is a db error // if the object being deleted doesnt exist, error is not returned (idempotent delete) -func (dsd DomainSQLDAO) DeleteByID(ctx context.Context, tx *db.Tx, id uuid.UUID) error { +func (dsd DomainSQLDAO) Delete(ctx context.Context, tx *db.Tx, id uuid.UUID) error { // Create a child span and set the attributes for current request - ctx, domainDAOSpan := dsd.tracerSpan.CreateChildInCurrentContext(ctx, "DomainDAO.DeleteByID") + ctx, domainDAOSpan := dsd.tracerSpan.CreateChildInCurrentContext(ctx, "DomainDAO.Delete") if domainDAOSpan != nil { defer domainDAOSpan.End() } diff --git a/rest-api/db/pkg/db/model/domain_test.go b/rest-api/db/pkg/db/model/domain_test.go index f42e45e9b2..617f8c5a97 100644 --- a/rest-api/db/pkg/db/model/domain_test.go +++ b/rest-api/db/pkg/db/model/domain_test.go @@ -51,7 +51,7 @@ func testDomainBuildUser(t *testing.T, dbSession *db.Session, starfleetID string return user } -func TestDomainSQLDAO_CreateFromParams(t *testing.T) { +func TestDomainSQLDAO_Create(t *testing.T) { ctx := context.Background() dbSession := testDomainInitDB(t) defer dbSession.Close() @@ -68,16 +68,19 @@ func TestDomainSQLDAO_CreateFromParams(t *testing.T) { tests := []struct { desc string - ds []Domain + inputs []DomainCreateInput expectError bool tx *db.Tx verifyChildSpanner bool }{ { desc: "create one", - ds: []Domain{ + inputs: []DomainCreateInput{ { - Hostname: "test.com", Org: "testOrg", ControllerDomainID: &controllerDomainID, Status: DomainStatusPending, + Hostname: "test.com", + Org: "testOrg", + ControllerDomainID: &controllerDomainID, + Status: DomainStatusPending, }, }, expectError: false, @@ -86,15 +89,27 @@ func TestDomainSQLDAO_CreateFromParams(t *testing.T) { }, { desc: "create multiple, some with null controllerDomainID field", - ds: []Domain{ + inputs: []DomainCreateInput{ { - Hostname: "test1.com", Org: "testOrg1", ControllerDomainID: &controllerDomainID, Status: DomainStatusPending, CreatedBy: user.ID, + Hostname: "test1.com", + Org: "testOrg1", + ControllerDomainID: &controllerDomainID, + Status: DomainStatusPending, + CreatedBy: user.ID, }, { - Hostname: "test2.com", Org: "testOrg2", ControllerDomainID: nil, Status: DomainStatusPending, CreatedBy: user.ID, + Hostname: "test2.com", + Org: "testOrg2", + ControllerDomainID: nil, + Status: DomainStatusPending, + CreatedBy: user.ID, }, { - Hostname: "test3.com", Org: "testOrg3", ControllerDomainID: &controllerDomainID, Status: DomainStatusPending, CreatedBy: user.ID, + Hostname: "test3.com", + Org: "testOrg3", + ControllerDomainID: &controllerDomainID, + Status: DomainStatusPending, + CreatedBy: user.ID, }, }, expectError: false, @@ -102,15 +117,27 @@ func TestDomainSQLDAO_CreateFromParams(t *testing.T) { }, { desc: "create multiple, within transaction", - ds: []Domain{ + inputs: []DomainCreateInput{ { - Hostname: "test4.com", Org: "testOrg1", ControllerDomainID: &controllerDomainID, Status: DomainStatusPending, CreatedBy: user.ID, + Hostname: "test4.com", + Org: "testOrg1", + ControllerDomainID: &controllerDomainID, + Status: DomainStatusPending, + CreatedBy: user.ID, }, { - Hostname: "test5.com", Org: "testOrg2", ControllerDomainID: nil, Status: DomainStatusPending, CreatedBy: user.ID, + Hostname: "test5.com", + Org: "testOrg2", + ControllerDomainID: nil, + Status: DomainStatusPending, + CreatedBy: user.ID, }, { - Hostname: "test6.com", Org: "testOrg3", ControllerDomainID: &controllerDomainID, Status: DomainStatusPending, CreatedBy: user.ID, + Hostname: "test6.com", + Org: "testOrg3", + ControllerDomainID: &controllerDomainID, + Status: DomainStatusPending, + CreatedBy: user.ID, }, }, expectError: false, @@ -119,10 +146,8 @@ func TestDomainSQLDAO_CreateFromParams(t *testing.T) { } for _, tc := range tests { t.Run(tc.desc, func(t *testing.T) { - for _, i := range tc.ds { - d, err := dsd.CreateFromParams( - ctx, tc.tx, i.Hostname, i.Org, i.ControllerDomainID, i.Status, i.CreatedBy, - ) + for _, input := range tc.inputs { + d, err := dsd.Create(ctx, tc.tx, input) assert.Equal(t, tc.expectError, err != nil) if !tc.expectError { assert.NotNil(t, d) @@ -151,7 +176,13 @@ func TestDomainSQLDAO_GetByID(t *testing.T) { controllerDomainID := uuid.New() dsd := NewDomainDAO(dbSession) - domain1, err := dsd.CreateFromParams(ctx, nil, "test.com", "testOrg", &controllerDomainID, DomainStatusPending, user.ID) + domain1, err := dsd.Create(ctx, nil, DomainCreateInput{ + Hostname: "test.com", + Org: "testOrg", + ControllerDomainID: &controllerDomainID, + Status: DomainStatusPending, + CreatedBy: user.ID, + }) assert.Nil(t, err) // OTEL Spanner configuration @@ -210,13 +241,31 @@ func TestDomainSQLDAO_GetAll(t *testing.T) { controllerDomainID := uuid.New() dsd := NewDomainDAO(dbSession) - domain1, err := dsd.CreateFromParams(ctx, nil, "test1.com", "testOrg1", &controllerDomainID, DomainStatusPending, user.ID) + domain1, err := dsd.Create(ctx, nil, DomainCreateInput{ + Hostname: "test1.com", + Org: "testOrg1", + ControllerDomainID: &controllerDomainID, + Status: DomainStatusPending, + CreatedBy: user.ID, + }) assert.Nil(t, err) assert.NotNil(t, domain1) - domain2, err := dsd.CreateFromParams(ctx, nil, "test1.com", "testOrg2", &controllerDomainID, DomainStatusPending, user.ID) + domain2, err := dsd.Create(ctx, nil, DomainCreateInput{ + Hostname: "test1.com", + Org: "testOrg2", + ControllerDomainID: &controllerDomainID, + Status: DomainStatusPending, + CreatedBy: user.ID, + }) assert.Nil(t, err) assert.NotNil(t, domain2) - domain3, err := dsd.CreateFromParams(ctx, nil, "test2.com", "testOrg2", &controllerDomainID, DomainStatusPending, user.ID) + domain3, err := dsd.Create(ctx, nil, DomainCreateInput{ + Hostname: "test2.com", + Org: "testOrg2", + ControllerDomainID: &controllerDomainID, + Status: DomainStatusPending, + CreatedBy: user.ID, + }) assert.Nil(t, err) assert.NotNil(t, domain3) @@ -225,85 +274,82 @@ func TestDomainSQLDAO_GetAll(t *testing.T) { tests := []struct { desc string - hostname *string - org *string - controllerDomainID *uuid.UUID - status *string + filter DomainFilterInput expectedCnt int expectedError bool verifyChildSpanner bool }{ { desc: "GetAll with no filters returns objects", - hostname: nil, - org: nil, - controllerDomainID: nil, + filter: DomainFilterInput{}, expectedCnt: 3, expectedError: false, verifyChildSpanner: true, }, { - desc: "GetAll with hostname filter returns objects", - hostname: cutil.GetPtr("test1.com"), - org: nil, - controllerDomainID: nil, - expectedCnt: 2, - expectedError: false, + desc: "GetAll with hostname filter returns objects", + filter: DomainFilterInput{ + Hostname: cutil.GetPtr("test1.com"), + }, + expectedCnt: 2, + expectedError: false, }, { - desc: "GetAll with org filter returns objects", - hostname: nil, - org: cutil.GetPtr("testOrg2"), - controllerDomainID: nil, - expectedCnt: 2, - expectedError: false, + desc: "GetAll with org filter returns objects", + filter: DomainFilterInput{ + Org: cutil.GetPtr("testOrg2"), + }, + expectedCnt: 2, + expectedError: false, }, { - desc: "GetAll with controllerDomainID filter returns objects", - hostname: nil, - org: nil, - controllerDomainID: &controllerDomainID, - expectedCnt: 3, - expectedError: false, + desc: "GetAll with controllerDomainID filter returns objects", + filter: DomainFilterInput{ + ControllerDomainID: &controllerDomainID, + }, + expectedCnt: 3, + expectedError: false, }, { - desc: "GetAll with multiple filters returns objects", - hostname: cutil.GetPtr("test1.com"), - org: cutil.GetPtr("testOrg1"), - controllerDomainID: &controllerDomainID, - expectedCnt: 1, - expectedError: false, + desc: "GetAll with multiple filters returns objects", + filter: DomainFilterInput{ + Hostname: cutil.GetPtr("test1.com"), + Org: cutil.GetPtr("testOrg1"), + ControllerDomainID: &controllerDomainID, + }, + expectedCnt: 1, + expectedError: false, }, { - desc: "GetAll with multiple filters returns no objects", - hostname: cutil.GetPtr("notfound.com"), - org: cutil.GetPtr("testOrg1"), - controllerDomainID: &controllerDomainID, - expectedCnt: 0, - expectedError: false, + desc: "GetAll with multiple filters returns no objects", + filter: DomainFilterInput{ + Hostname: cutil.GetPtr("notfound.com"), + Org: cutil.GetPtr("testOrg1"), + ControllerDomainID: &controllerDomainID, + }, + expectedCnt: 0, + expectedError: false, }, { - desc: "GetAll with DomainStatusPending status returns objects", - hostname: nil, - org: nil, - controllerDomainID: nil, - expectedCnt: 3, - status: cutil.GetPtr(DomainStatusPending), - expectedError: false, + desc: "GetAll with DomainStatusPending status returns objects", + filter: DomainFilterInput{ + Status: cutil.GetPtr(DomainStatusPending), + }, + expectedCnt: 3, + expectedError: false, }, { - desc: "GetAll with DomainStatusError status returns no objects", - hostname: nil, - org: nil, - controllerDomainID: nil, - expectedCnt: 0, - status: cutil.GetPtr(DomainStatusError), - expectedError: false, + desc: "GetAll with DomainStatusError status returns no objects", + filter: DomainFilterInput{ + Status: cutil.GetPtr(DomainStatusError), + }, + expectedCnt: 0, + expectedError: false, }, } for _, tc := range tests { t.Run(tc.desc, func(t *testing.T) { - tmp, err := dsd.GetAll(ctx, nil, tc.hostname, tc.org, tc.controllerDomainID, tc.status, nil) + tmp, err := dsd.GetAll(ctx, nil, tc.filter, nil) assert.Equal(t, tc.expectedError, err != nil) if tc.expectedError { assert.Equal(t, nil, tmp) @@ -321,7 +367,7 @@ func TestDomainSQLDAO_GetAll(t *testing.T) { } } -func TestDomainSQLDAO_UpdateFromParams(t *testing.T) { +func TestDomainSQLDAO_Update(t *testing.T) { ctx := context.Background() dbSession := testDomainInitDB(t) defer dbSession.Close() @@ -330,9 +376,21 @@ func TestDomainSQLDAO_UpdateFromParams(t *testing.T) { controllerDomainID := uuid.New() updatedControllerDomainID := uuid.New() dsd := NewDomainDAO(dbSession) - domain, err := dsd.CreateFromParams(ctx, nil, "test.com", "testOrg", &controllerDomainID, DomainStatusPending, user.ID) + domain, err := dsd.Create(ctx, nil, DomainCreateInput{ + Hostname: "test.com", + Org: "testOrg", + ControllerDomainID: &controllerDomainID, + Status: DomainStatusPending, + CreatedBy: user.ID, + }) assert.Nil(t, err) - domain2, err := dsd.CreateFromParams(ctx, nil, "test2.com", "testOrg2", &controllerDomainID, DomainStatusPending, user.ID) + domain2, err := dsd.Create(ctx, nil, DomainCreateInput{ + Hostname: "test2.com", + Org: "testOrg2", + ControllerDomainID: &controllerDomainID, + Status: DomainStatusPending, + CreatedBy: user.ID, + }) assert.Nil(t, err) tx1, err := db.BeginTx(context.Background(), dbSession, &sql.TxOptions{}) assert.Nil(t, err) @@ -342,14 +400,9 @@ func TestDomainSQLDAO_UpdateFromParams(t *testing.T) { tests := []struct { desc string - id uuid.UUID - - paramDomain *Domain - paramHostname *string - paramOrg *string - paramControllerDomainID *uuid.UUID - paramStatus *string + input DomainUpdateInput + paramDomain *Domain expectedError bool expectedHostname *string expectedOrg *string @@ -359,13 +412,12 @@ func TestDomainSQLDAO_UpdateFromParams(t *testing.T) { verifyChildSpanner bool }{ { - desc: "can update hostname", - id: domain.ID, + desc: "can update hostname", + input: DomainUpdateInput{ + DomainID: domain.ID, + Hostname: cutil.GetPtr("updated.com"), + }, paramDomain: domain, - paramHostname: cutil.GetPtr("updated.com"), - paramOrg: nil, - paramControllerDomainID: nil, - paramStatus: nil, expectedError: false, expectedHostname: cutil.GetPtr("updated.com"), expectedOrg: &domain.Org, @@ -374,13 +426,12 @@ func TestDomainSQLDAO_UpdateFromParams(t *testing.T) { verifyChildSpanner: true, }, { - desc: "error when updating object doesnt exist", - id: uuid.New(), + desc: "error when updating object doesnt exist", + input: DomainUpdateInput{ + DomainID: uuid.New(), + Hostname: cutil.GetPtr("updated.com"), + }, paramDomain: domain, - paramHostname: cutil.GetPtr("updated.com"), - paramOrg: nil, - paramControllerDomainID: nil, - paramStatus: nil, expectedError: true, expectedHostname: cutil.GetPtr("updated.com"), expectedOrg: &domain.Org, @@ -388,13 +439,12 @@ func TestDomainSQLDAO_UpdateFromParams(t *testing.T) { expectedStatus: &domain.Status, }, { - desc: "can update org", - id: domain.ID, + desc: "can update org", + input: DomainUpdateInput{ + DomainID: domain.ID, + Org: cutil.GetPtr("updatedOrg"), + }, paramDomain: domain, - paramHostname: nil, - paramOrg: cutil.GetPtr("updatedOrg"), - paramControllerDomainID: nil, - paramStatus: nil, expectedError: false, expectedHostname: cutil.GetPtr("updated.com"), expectedOrg: cutil.GetPtr("updatedOrg"), @@ -402,13 +452,12 @@ func TestDomainSQLDAO_UpdateFromParams(t *testing.T) { expectedStatus: &domain.Status, }, { - desc: "can update controllerDomainID", - id: domain.ID, + desc: "can update controllerDomainID", + input: DomainUpdateInput{ + DomainID: domain.ID, + ControllerDomainID: &updatedControllerDomainID, + }, paramDomain: domain, - paramHostname: nil, - paramOrg: nil, - paramControllerDomainID: &updatedControllerDomainID, - paramStatus: nil, expectedError: false, expectedHostname: cutil.GetPtr("updated.com"), expectedOrg: cutil.GetPtr("updatedOrg"), @@ -416,13 +465,12 @@ func TestDomainSQLDAO_UpdateFromParams(t *testing.T) { expectedStatus: &domain.Status, }, { - desc: "can update status", - id: domain.ID, + desc: "can update status", + input: DomainUpdateInput{ + DomainID: domain.ID, + Status: cutil.GetPtr(DomainStatusReady), + }, paramDomain: domain, - paramHostname: nil, - paramOrg: nil, - paramControllerDomainID: nil, - paramStatus: cutil.GetPtr(DomainStatusReady), expectedError: false, expectedHostname: cutil.GetPtr("updated.com"), expectedOrg: cutil.GetPtr("updatedOrg"), @@ -430,13 +478,15 @@ func TestDomainSQLDAO_UpdateFromParams(t *testing.T) { expectedStatus: cutil.GetPtr(DomainStatusReady), }, { - desc: "can update multiple fields", - id: domain2.ID, + desc: "can update multiple fields", + input: DomainUpdateInput{ + DomainID: domain2.ID, + Hostname: cutil.GetPtr("updated.com"), + Org: cutil.GetPtr("updatedOrg"), + ControllerDomainID: &updatedControllerDomainID, + Status: cutil.GetPtr(DomainStatusReady), + }, paramDomain: domain2, - paramHostname: cutil.GetPtr("updated.com"), - paramOrg: cutil.GetPtr("updatedOrg"), - paramControllerDomainID: &updatedControllerDomainID, - paramStatus: cutil.GetPtr(DomainStatusReady), expectedError: false, expectedHostname: cutil.GetPtr("updated.com"), expectedOrg: cutil.GetPtr("updatedOrg"), @@ -445,13 +495,11 @@ func TestDomainSQLDAO_UpdateFromParams(t *testing.T) { tx: tx1, }, { - desc: "noop when no fields are specified", - id: domain2.ID, + desc: "noop when no fields are specified", + input: DomainUpdateInput{ + DomainID: domain2.ID, + }, paramDomain: domain2, - paramHostname: nil, - paramOrg: nil, - paramControllerDomainID: nil, - paramStatus: nil, expectedError: false, expectedHostname: cutil.GetPtr("updated.com"), expectedOrg: cutil.GetPtr("updatedOrg"), @@ -461,7 +509,7 @@ func TestDomainSQLDAO_UpdateFromParams(t *testing.T) { } for _, tc := range tests { t.Run(tc.desc, func(t *testing.T) { - got, err := dsd.UpdateFromParams(ctx, tc.tx, tc.id, tc.paramHostname, tc.paramOrg, tc.paramControllerDomainID, tc.paramStatus) + got, err := dsd.Update(ctx, tc.tx, tc.input) assert.Equal(t, tc.expectedError, err != nil) if err == nil { assert.NotNil(t, got) @@ -493,7 +541,7 @@ func TestDomainSQLDAO_UpdateFromParams(t *testing.T) { } } -func TestDomainSQLDAO_ClearFromParams(t *testing.T) { +func TestDomainSQLDAO_Clear(t *testing.T) { ctx := context.Background() dbSession := testDomainInitDB(t) defer dbSession.Close() @@ -501,9 +549,21 @@ func TestDomainSQLDAO_ClearFromParams(t *testing.T) { user := testDomainBuildUser(t, dbSession, "testUser") controllerDomainID := uuid.New() dsd := NewDomainDAO(dbSession) - domain, err := dsd.CreateFromParams(ctx, nil, "test.com", "testOrg", &controllerDomainID, DomainStatusPending, user.ID) + domain, err := dsd.Create(ctx, nil, DomainCreateInput{ + Hostname: "test.com", + Org: "testOrg", + ControllerDomainID: &controllerDomainID, + Status: DomainStatusPending, + CreatedBy: user.ID, + }) assert.Nil(t, err) - domain2, err := dsd.CreateFromParams(ctx, nil, "test.com", "testOrg", &controllerDomainID, DomainStatusPending, user.ID) + domain2, err := dsd.Create(ctx, nil, DomainCreateInput{ + Hostname: "test.com", + Org: "testOrg", + ControllerDomainID: &controllerDomainID, + Status: DomainStatusPending, + CreatedBy: user.ID, + }) assert.Nil(t, err) tx1, err := db.BeginTx(context.Background(), dbSession, &sql.TxOptions{}) assert.Nil(t, err) @@ -511,11 +571,13 @@ func TestDomainSQLDAO_ClearFromParams(t *testing.T) { // OTEL Spanner configuration _, _, ctx = testCommonTraceProviderSetup(t, ctx) + missingID := uuid.New() + tests := []struct { desc string domain *Domain + input DomainClearInput - paramControllerDomainID bool expectedUpdate bool expectedError bool expectedControllerDomainID *uuid.UUID @@ -523,9 +585,12 @@ func TestDomainSQLDAO_ClearFromParams(t *testing.T) { verifyChildSpanner bool }{ { - desc: "can clear controllerDomainID", - domain: domain, - paramControllerDomainID: true, + desc: "can clear controllerDomainID", + domain: domain, + input: DomainClearInput{ + DomainID: domain.ID, + ControllerDomainID: true, + }, expectedUpdate: true, expectedError: false, expectedControllerDomainID: nil, @@ -533,32 +598,40 @@ func TestDomainSQLDAO_ClearFromParams(t *testing.T) { verifyChildSpanner: true, }, { - desc: "can clear controllerDomainID when it is already nil", - domain: domain, - paramControllerDomainID: true, + desc: "can clear controllerDomainID when it is already nil", + domain: domain, + input: DomainClearInput{ + DomainID: domain.ID, + ControllerDomainID: true, + }, expectedUpdate: true, expectedError: false, expectedControllerDomainID: nil, }, { - desc: "noop when nothing cleared", - domain: domain2, - paramControllerDomainID: false, + desc: "noop when nothing cleared", + domain: domain2, + input: DomainClearInput{ + DomainID: domain2.ID, + }, expectedUpdate: false, expectedError: false, expectedControllerDomainID: domain2.ControllerDomainID, }, { - desc: "error when updating object doesnt exist", - domain: &Domain{ID: uuid.New()}, - paramControllerDomainID: true, + desc: "error when updating object doesnt exist", + domain: &Domain{ID: missingID}, + input: DomainClearInput{ + DomainID: missingID, + ControllerDomainID: true, + }, expectedError: true, expectedControllerDomainID: nil, }, } for _, tc := range tests { t.Run(tc.desc, func(t *testing.T) { - got, err := dsd.ClearFromParams(ctx, tc.tx, tc.domain.ID, tc.paramControllerDomainID) + got, err := dsd.Clear(ctx, tc.tx, tc.input) assert.Equal(t, tc.expectedError, err != nil) if err == nil { assert.NotNil(t, got) @@ -584,7 +657,7 @@ func TestDomainSQLDAO_ClearFromParams(t *testing.T) { } } -func TestDomainSQLDAO_DeleteByID(t *testing.T) { +func TestDomainSQLDAO_Delete(t *testing.T) { ctx := context.Background() dbSession := testDomainInitDB(t) defer dbSession.Close() @@ -592,7 +665,13 @@ func TestDomainSQLDAO_DeleteByID(t *testing.T) { user := testDomainBuildUser(t, dbSession, "testUser") controllerDomainID := uuid.New() dsd := NewDomainDAO(dbSession) - domain, err := dsd.CreateFromParams(ctx, nil, "test.com", "testOrg", &controllerDomainID, DomainStatusPending, user.ID) + domain, err := dsd.Create(ctx, nil, DomainCreateInput{ + Hostname: "test.com", + Org: "testOrg", + ControllerDomainID: &controllerDomainID, + Status: DomainStatusPending, + CreatedBy: user.ID, + }) assert.Nil(t, err) tx1, err := db.BeginTx(context.Background(), dbSession, &sql.TxOptions{}) assert.Nil(t, err) @@ -622,7 +701,7 @@ func TestDomainSQLDAO_DeleteByID(t *testing.T) { } for _, tc := range tests { t.Run(tc.desc, func(t *testing.T) { - err := dsd.DeleteByID(ctx, tc.tx, tc.id) + err := dsd.Delete(ctx, tc.tx, tc.id) assert.Equal(t, tc.expectedError, err != nil) if !tc.expectedError { tmp, err := dsd.GetByID(ctx, tc.tx, tc.id, nil) diff --git a/rest-api/db/pkg/db/model/instancetype_test.go b/rest-api/db/pkg/db/model/instancetype_test.go index 9047e3afb5..61496c6636 100644 --- a/rest-api/db/pkg/db/model/instancetype_test.go +++ b/rest-api/db/pkg/db/model/instancetype_test.go @@ -545,13 +545,21 @@ func TestInstanceTypeSQLDAO_GetAll(t *testing.T) { // Create a single allocation with a constraint > 0 if i == 1 { - _, serr := acDAO.CreateFromParams(ctx, nil, at1.ID, AllocationResourceTypeInstanceType, it.ID, AllocationConstraintTypeReserved, 5, nil, user.ID) + _, serr := acDAO.Create(ctx, nil, AllocationConstraintCreateInput{ + AllocationID: at1.ID, ResourceType: AllocationResourceTypeInstanceType, + ResourceTypeID: it.ID, ConstraintType: AllocationConstraintTypeReserved, + ConstraintValue: 5, CreatedBy: user.ID, + }) assert.NoError(t, serr) } // Create an allocation but with no real constraint. An allocation that exists but is empty. if i == 3 { - _, serr := acDAO.CreateFromParams(ctx, nil, at2.ID, AllocationResourceTypeInstanceType, it.ID, AllocationConstraintTypeReserved, 0, nil, user.ID) + _, serr := acDAO.Create(ctx, nil, AllocationConstraintCreateInput{ + AllocationID: at2.ID, ResourceType: AllocationResourceTypeInstanceType, + ResourceTypeID: it.ID, ConstraintType: AllocationConstraintTypeReserved, + ConstraintValue: 0, CreatedBy: user.ID, + }) assert.NoError(t, serr) } diff --git a/rest-api/db/pkg/db/model/testing.go b/rest-api/db/pkg/db/model/testing.go index 01f01e8407..0d199a97fd 100644 --- a/rest-api/db/pkg/db/model/testing.go +++ b/rest-api/db/pkg/db/model/testing.go @@ -304,7 +304,11 @@ func TestBuildAllocationConstraint(t *testing.T, dbSession *db.Session, al *Allo } acDAO := NewAllocationConstraintDAO(dbSession) - ac, err := acDAO.CreateFromParams(context.Background(), nil, al.ID, AllocationResourceTypeInstanceType, resourceID, AllocationConstraintTypeReserved, constraintValue, nil, user.ID) + ac, err := acDAO.Create(context.Background(), nil, AllocationConstraintCreateInput{ + AllocationID: al.ID, ResourceType: AllocationResourceTypeInstanceType, + ResourceTypeID: resourceID, ConstraintType: AllocationConstraintTypeReserved, + ConstraintValue: constraintValue, CreatedBy: user.ID, + }) assert.Nil(t, err) return ac diff --git a/rest-api/workflow/pkg/util/testing.go b/rest-api/workflow/pkg/util/testing.go index 3bded0ade3..618d32e035 100644 --- a/rest-api/workflow/pkg/util/testing.go +++ b/rest-api/workflow/pkg/util/testing.go @@ -342,7 +342,10 @@ func TestBuildAllocation(t *testing.T, dbSession *cdb.Session, ip *cdbm.Infrastr func TestBuildAllocationContraints(t *testing.T, dbSession *cdb.Session, al *cdbm.Allocation, rt string, rtID uuid.UUID, ct string, cv int, user *cdbm.User) *cdbm.AllocationConstraint { alctDAO := cdbm.NewAllocationConstraintDAO(dbSession) - alct, err := alctDAO.CreateFromParams(context.Background(), nil, al.ID, rt, rtID, ct, cv, nil, user.ID) + alct, err := alctDAO.Create(context.Background(), nil, cdbm.AllocationConstraintCreateInput{ + AllocationID: al.ID, ResourceType: rt, ResourceTypeID: rtID, + ConstraintType: ct, ConstraintValue: cv, CreatedBy: user.ID, + }) assert.Nil(t, err) return alct @@ -577,7 +580,11 @@ func TestUpdateVPC(t *testing.T, dbSession *cdb.Session, v *cdbm.Vpc) { // TestBuildAllocationConstraint creates a test Allocation Constraint of Instance Type func TestBuildAllocationConstraint(t *testing.T, dbSession *cdb.Session, al *cdbm.Allocation, it *cdbm.InstanceType, constraintValue int, user *cdbm.User) *cdbm.AllocationConstraint { acDAO := cdbm.NewAllocationConstraintDAO(dbSession) - ac, err := acDAO.CreateFromParams(context.Background(), nil, al.ID, cdbm.AllocationResourceTypeInstanceType, it.ID, cdbm.AllocationConstraintTypeReserved, constraintValue, nil, user.ID) + ac, err := acDAO.Create(context.Background(), nil, cdbm.AllocationConstraintCreateInput{ + AllocationID: al.ID, ResourceType: cdbm.AllocationResourceTypeInstanceType, + ResourceTypeID: it.ID, ConstraintType: cdbm.AllocationConstraintTypeReserved, + ConstraintValue: constraintValue, CreatedBy: user.ID, + }) assert.Nil(t, err) return ac