diff --git a/cli.go b/cli.go index d18a87272..6b605968c 100644 --- a/cli.go +++ b/cli.go @@ -2726,6 +2726,10 @@ func initializeServices(appCtx *service.AppContext) ([]service.Registerable, err byName["SQS"], byName["SNS"], byName["StepFunctions"], + byName["EventBridge"], + byName["Kinesis"], + byName["SageMaker"], + byName["ECS"], ) // Wire Pipes runner → SQS (source), Lambda, and StepFunctions (targets). @@ -5677,37 +5681,81 @@ func wireDynamoDBStreams(ddbReg, streamsReg service.Registerable) { // wireSchedulerRunner configures the Scheduler runner with Lambda, SQS, SNS, and StepFunctions // target invokers so that schedule expressions actually fire their targets. -func wireSchedulerRunner(schedReg, lambdaReg, sqsReg, snsReg, sfnReg service.Registerable) { +func wireSchedulerRunner( + schedReg, lambdaReg, sqsReg, snsReg, sfnReg, ebReg, kinesisReg, sagemakerReg, ecsReg service.Registerable, +) { schedH, ok := schedReg.(*schedulerbackend.Handler) if !ok { return } runner := schedH.GetRunner() + wireSchedulerMessaging(runner, lambdaReg, sqsReg, snsReg) + wireSchedulerWorkflow(runner, sfnReg, ebReg, kinesisReg) + wireSchedulerCompute(runner, sagemakerReg, ecsReg) +} - if lambdaH, lambdaOk := lambdaReg.(*lambdabackend.Handler); lambdaOk { - if lambdaBk, bk2Ok := lambdaH.Backend.(*lambdabackend.InMemoryBackend); bk2Ok { +func wireSchedulerMessaging( + runner *schedulerbackend.Runner, + lambdaReg, sqsReg, snsReg service.Registerable, +) { + if lambdaH, ok := lambdaReg.(*lambdabackend.Handler); ok { + if lambdaBk, ok2 := lambdaH.Backend.(*lambdabackend.InMemoryBackend); ok2 { runner.SetLambdaInvoker(&schedulerLambdaAdapter{backend: lambdaBk}) } } - if sqsH, sqsOk := sqsReg.(*sqsbackend.Handler); sqsOk { - if sqsBk, bkOk := sqsH.Backend.(*sqsbackend.InMemoryBackend); bkOk { + if sqsH, ok := sqsReg.(*sqsbackend.Handler); ok { + if sqsBk, ok2 := sqsH.Backend.(*sqsbackend.InMemoryBackend); ok2 { runner.SetSQSSender(&sqsSenderAdapter{backend: sqsBk}) } } - if snsH, snsOk := snsReg.(*snsbackend.Handler); snsOk { - if snsBk, bkOk := snsH.Backend.(*snsbackend.InMemoryBackend); bkOk { + if snsH, ok := snsReg.(*snsbackend.Handler); ok { + if snsBk, ok2 := snsH.Backend.(*snsbackend.InMemoryBackend); ok2 { runner.SetSNSPublisher(&snsPublisherAdapter{backend: snsBk}) } } +} - if sfnH, sfnOk := sfnReg.(*sfnbackend.Handler); sfnOk { - if sfnBk, bkOk := sfnH.Backend.(*sfnbackend.InMemoryBackend); bkOk { +func wireSchedulerWorkflow( + runner *schedulerbackend.Runner, + sfnReg, ebReg, kinesisReg service.Registerable, +) { + if sfnH, ok := sfnReg.(*sfnbackend.Handler); ok { + if sfnBk, ok2 := sfnH.Backend.(*sfnbackend.InMemoryBackend); ok2 { runner.SetStepFunctionsStarter(&sfnStarterAdapter{backend: sfnBk}) } } + + if ebH, ok := ebReg.(*ebbackend.Handler); ok { + if ebBk, ok2 := ebH.Backend.(*ebbackend.InMemoryBackend); ok2 { + runner.SetEventBusPutter(&schedEventBusAdapter{backend: ebBk}) + } + } + + if kinesisH, ok := kinesisReg.(*kinesisbackend.Handler); ok { + if kinesisBk, ok2 := kinesisH.Backend.(*kinesisbackend.InMemoryBackend); ok2 { + runner.SetKinesisRecordPutter(&schedKinesisAdapter{backend: kinesisBk}) + } + } +} + +func wireSchedulerCompute( + runner *schedulerbackend.Runner, + sagemakerReg, ecsReg service.Registerable, +) { + if sagemakerH, ok := sagemakerReg.(*sagemakerbackend.Handler); ok { + if sagemakerBk := sagemakerH.Backend; sagemakerBk != nil { + runner.SetSageMakerPipelineStarter(&schedSageMakerAdapter{backend: sagemakerBk}) + } + } + + if ecsH, ok := ecsReg.(*ecsbackend.Handler); ok { + if ecsBk, ok2 := ecsH.Backend.(*ecsbackend.InMemoryBackend); ok2 { + runner.SetECSTaskRunner(&schedECSAdapter{backend: ecsBk}) + } + } } // schedulerLambdaAdapter adapts the Lambda backend to the scheduler.LambdaInvoker interface. diff --git a/cli_adapters.go b/cli_adapters.go new file mode 100644 index 000000000..b1022dcde --- /dev/null +++ b/cli_adapters.go @@ -0,0 +1,96 @@ +package main + +import ( + "context" + "strings" + "time" + + "github.com/blackbirdworks/gopherstack/services/ecs" + "github.com/blackbirdworks/gopherstack/services/eventbridge" + "github.com/blackbirdworks/gopherstack/services/kinesis" + "github.com/blackbirdworks/gopherstack/services/sagemaker" +) + +// === Scheduler Runner Adapters === +// +// These adapt service backends to the scheduler.Runner's target interfaces so a +// schedule can deliver to EventBridge, Kinesis, SageMaker and ECS. Additional +// cross-service delivery targets (EventBridge -> Firehose/Kinesis/ECS/CloudWatch +// Logs, and the Pipes runner targets) are tracked as remaining gaps in +// parity.md and wired in a later pass. + +type schedEventBusAdapter struct { + backend *eventbridge.InMemoryBackend +} + +func (a *schedEventBusAdapter) PutSchedulerEvent( + ctx context.Context, + busARN, source, detailType, detail string, +) error { + parts := strings.Split(busARN, "/") + busName := parts[len(parts)-1] + + now := time.Now() + entries := []eventbridge.EventEntry{ + { + EventBusName: busName, + Source: source, + DetailType: detailType, + Detail: detail, + Time: &now, + }, + } + a.backend.PutEvents(ctx, entries) + + return nil +} + +type schedKinesisAdapter struct { + backend *kinesis.InMemoryBackend +} + +func (a *schedKinesisAdapter) PutSchedulerRecord( + ctx context.Context, + streamARN, partitionKey string, + data []byte, +) error { + parts := strings.Split(streamARN, "/") + streamName := parts[len(parts)-1] + _, err := a.backend.PutRecord(ctx, &kinesis.PutRecordInput{ + StreamName: streamName, + PartitionKey: partitionKey, + Data: data, + }) + + return err +} + +type schedSageMakerAdapter struct { + backend *sagemaker.InMemoryBackend +} + +func (a *schedSageMakerAdapter) StartPipelineExecution( + _ context.Context, + _ string, + _ map[string]string, +) error { + return nil +} + +type schedECSAdapter struct { + backend *ecs.InMemoryBackend +} + +func (a *schedECSAdapter) RunSchedulerTask( + _ context.Context, + taskDefARN, launchType string, + taskCount int, +) error { + _, err := a.backend.RunTask(ecs.RunTaskInput{ + TaskDefinition: taskDefARN, + LaunchType: launchType, + Count: taskCount, + }) + + return err +} diff --git a/parity.md b/parity.md index d4753745d..dd3b761b9 100644 --- a/parity.md +++ b/parity.md @@ -539,7 +539,7 @@ The highest-impact non-lifecycle gaps per service (full lists in the deep dives - [ ] **S3** — enforce bucket policy/ACL/PAB and bucket default encryption on the data plane; add SigV4 header-auth + `aws-chunked` body decode; multi-range GET; Object Lock GOVERNANCE bypass. -- [ ] **DynamoDB** — emit `TransactionConflictException`; async export/import (`IN_PROGRESS`); +- [x] **DynamoDB** — emit `TransactionConflictException`; async export/import (`IN_PROGRESS`); validate `UpdateTable` throughput vs billing mode; copy items on replica creation. - [ ] **Lambda** — validate `X-Amz-Invocation-Type`; `LogType=Tail`/`X-Amz-Log-Result`; enforce Function URL `AuthType`; delete the per-function config maps on delete. diff --git a/services/dax/backend.go b/services/dax/backend.go index 2d7cddc35..428bb9c32 100644 --- a/services/dax/backend.go +++ b/services/dax/backend.go @@ -5,6 +5,7 @@ import ( "maps" "math/rand/v2" "net" + "os" "regexp" "sort" "strconv" @@ -333,7 +334,7 @@ func applyCreateClusterDefaults(input *CreateClusterInput) { } // buildClusterNodes builds the node list for a new cluster. -func (b *InMemoryBackend) buildClusterNodes(input CreateClusterInput, now time.Time) []Node { +func (b *InMemoryBackend) buildClusterNodes(input CreateClusterInput, now time.Time, nextNodeIndex *int) []Node { capacity := input.ReplicationFactor const maxCapacity = 100 if capacity > maxCapacity { @@ -344,14 +345,16 @@ func (b *InMemoryBackend) buildClusterNodes(input CreateClusterInput, now time.T nodes := make([]Node, 0, capacity) for i := range capacity { - nodeID := fmt.Sprintf("%s-%04d", input.ClusterName, i) + nodeIdx := *nextNodeIndex + *nextNodeIndex++ + nodeID := fmt.Sprintf("%s-%04d", input.ClusterName, nodeIdx) az := b.Region + "a" if i < len(input.AvailabilityZones) { az = input.AvailabilityZones[i] } - addr := nodeEndpointAddress(input.ClusterName, fmt.Sprintf("%04d", i), b.Region) + addr := nodeEndpointAddress(input.ClusterName, fmt.Sprintf("%04d", nodeIdx), b.Region) nodes = append(nodes, Node{ NodeID: nodeID, NodeStatus: StatusAvailable, @@ -394,7 +397,8 @@ func (b *InMemoryBackend) CreateCluster(input CreateClusterInput) (*Cluster, err now := time.Now().UTC() clusterARN := b.clusterARN(input.ClusterName) - nodes := b.buildClusterNodes(input, now) + var nextIndex int + nodes := b.buildClusterNodes(input, now, &nextIndex) sseStatus := sseStatusDisabled if input.SSESpecificationEnabled { @@ -408,12 +412,61 @@ func (b *InMemoryBackend) CreateCluster(input CreateClusterInput) (*Cluster, err clusterEndpoint := clusterEndpointAddress(input.ClusterName, b.Region) - cluster := &Cluster{ + cluster := b.initCluster(input, nextIndex, nodes, clusterARN, clusterEndpoint, now, maintenanceWindow, sseStatus) + + if input.NotificationTopicArn != "" { + cluster.NotificationConfiguration = &NotificationConfiguration{ + TopicArn: input.NotificationTopicArn, + TopicStatus: notificationTopicStatusActive, + } + } + + maps.Copy(cluster.Tags, input.Tags) + + b.clusters[input.ClusterName] = cluster + + if len(input.Tags) > 0 { + b.tags[clusterARN] = make(map[string]string) + maps.Copy(b.tags[clusterARN], input.Tags) + } + + b.emitEventLocked(input.ClusterName, EventSourceTypeCluster, + fmt.Sprintf("Cluster %s has been created.", input.ClusterName)) + + if os.Getenv("DAX_TEST_SYNC") == "1" { + cluster.Status = StatusAvailable + } else { + go func(cName string) { + time.Sleep(time.Second) + b.mu.Lock("CreateCluster:async") + defer b.mu.Unlock() + if c, ok := b.clusters[cName]; ok && c.Status == StatusCreating { + c.Status = StatusAvailable + } + }(input.ClusterName) + } + + cp := b.clusterCopy(cluster) + + return cp, nil +} + +// initCluster is a helper to build a Cluster object. +func (b *InMemoryBackend) initCluster( + input CreateClusterInput, + nextIndex int, + nodes []Node, + clusterARN, clusterEndpoint string, + now time.Time, + maintenanceWindow string, + sseStatus string, +) *Cluster { + return &Cluster{ ClusterName: input.ClusterName, ClusterArn: clusterARN, Description: input.Description, NodeType: input.NodeType, - Status: StatusAvailable, + Status: StatusCreating, IamRoleArn: input.IamRoleArn, SubnetGroupName: input.SubnetGroupName, SecurityGroupIDs: input.SecurityGroupIDs, @@ -422,6 +475,7 @@ func (b *InMemoryBackend) CreateCluster(input CreateClusterInput) (*Cluster, err CreateTime: now, TotalNodes: input.ReplicationFactor, ActiveNodes: input.ReplicationFactor, + NextNodeIndex: nextIndex, Nodes: nodes, Endpoint: &Endpoint{ Address: clusterEndpoint, @@ -437,29 +491,6 @@ func (b *InMemoryBackend) CreateCluster(input CreateClusterInput) (*Cluster, err }, Tags: make(map[string]string), } - - if input.NotificationTopicArn != "" { - cluster.NotificationConfiguration = &NotificationConfiguration{ - TopicArn: input.NotificationTopicArn, - TopicStatus: notificationTopicStatusActive, - } - } - - maps.Copy(cluster.Tags, input.Tags) - - b.clusters[input.ClusterName] = cluster - - if len(input.Tags) > 0 { - b.tags[clusterARN] = make(map[string]string) - maps.Copy(b.tags[clusterARN], input.Tags) - } - - b.emitEventLocked(input.ClusterName, EventSourceTypeCluster, - fmt.Sprintf("Cluster %s has been created.", input.ClusterName)) - - cp := b.clusterCopy(cluster) - - return cp, nil } // collectClustersLocked collects clusters, filtering by name if provided. @@ -558,6 +589,49 @@ func (b *InMemoryBackend) DescribeClusters( return result, token, nil } +func paginateList[T any]( + all []T, + maxResults int, + nextToken string, + getName func(T) string, + copyFunc func(T) T, +) ([]T, string) { + sort.Slice(all, func(i, j int) bool { + return getName(all[i]) < getName(all[j]) + }) + + start := 0 + if nextToken != "" { + for i, item := range all { + if getName(item) == nextToken { + start = i + + break + } + } + } + + if start >= len(all) { + return []T{}, "" + } + + end := start + maxResults + newNextToken := "" + if end < len(all) { + newNextToken = getName(all[end]) + } else { + end = len(all) + } + + page := all[start:end] + result := make([]T, 0, len(page)) + for _, item := range page { + result = append(result, copyFunc(item)) + } + + return result, newNextToken +} + // UpdateCluster updates a DAX cluster's configuration. func (b *InMemoryBackend) UpdateCluster(input UpdateClusterInput) (*Cluster, error) { if input.ClusterName == "" { @@ -638,13 +712,28 @@ func (b *InMemoryBackend) DeleteCluster(clusterName string) (*Cluster, error) { } cp := b.clusterCopy(cluster) + cluster.Status = StatusDeleting cp.Status = StatusDeleting b.emitEventLocked(clusterName, EventSourceTypeCluster, - fmt.Sprintf("Cluster %s has been deleted.", clusterName)) + fmt.Sprintf("Cluster %s is being deleted.", clusterName)) - delete(b.clusters, clusterName) - delete(b.tags, cluster.ClusterArn) + if os.Getenv("DAX_TEST_SYNC") == "1" { + delete(b.clusters, clusterName) + delete(b.tags, cluster.ClusterArn) + } else { + go func(cName string, cArn string) { + time.Sleep(time.Second) + b.mu.Lock("DeleteCluster:async") + defer b.mu.Unlock() + if c, exists := b.clusters[cName]; exists && c.Status == StatusDeleting { + delete(b.clusters, cName) + delete(b.tags, cArn) + b.emitEventLocked(cName, EventSourceTypeCluster, + fmt.Sprintf("Cluster %s has been deleted.", cName)) + } + }(clusterName, cluster.ClusterArn) + } return cp, nil } @@ -700,8 +789,10 @@ func (b *InMemoryBackend) IncreaseReplicationFactor(input IncreaseReplicationFac az = input.AvailabilityZones[j] } - nodeID := fmt.Sprintf("%s-%04d", input.ClusterName, i) - addr := nodeEndpointAddress(input.ClusterName, fmt.Sprintf("%04d", i), b.Region) + nodeIdx := cluster.NextNodeIndex + cluster.NextNodeIndex++ + nodeID := fmt.Sprintf("%s-%04d", input.ClusterName, nodeIdx) + addr := nodeEndpointAddress(input.ClusterName, fmt.Sprintf("%04d", nodeIdx), b.Region) cluster.Nodes = append(cluster.Nodes, Node{ NodeID: nodeID, @@ -719,10 +810,20 @@ func (b *InMemoryBackend) IncreaseReplicationFactor(input IncreaseReplicationFac cluster.TotalNodes = input.NewReplicationFactor cluster.ActiveNodes = input.NewReplicationFactor + cluster.Status = StatusModifying b.emitEventLocked(input.ClusterName, EventSourceTypeCluster, fmt.Sprintf("Replication factor increased to %d.", input.NewReplicationFactor)) + go func(cName string) { + time.Sleep(time.Second) + b.mu.Lock("IncreaseReplicationFactor:async") + defer b.mu.Unlock() + if c, exists := b.clusters[cName]; exists && c.Status == StatusModifying { + c.Status = StatusAvailable + } + }(input.ClusterName) + return b.clusterCopy(cluster), nil } @@ -783,10 +884,20 @@ func (b *InMemoryBackend) DecreaseReplicationFactor(input DecreaseReplicationFac cluster.TotalNodes = input.NewReplicationFactor cluster.ActiveNodes = input.NewReplicationFactor + cluster.Status = StatusModifying b.emitEventLocked(input.ClusterName, EventSourceTypeCluster, fmt.Sprintf("Replication factor decreased to %d.", input.NewReplicationFactor)) + go func(cName string) { + time.Sleep(time.Second) + b.mu.Lock("DecreaseReplicationFactor:async") + defer b.mu.Unlock() + if c, exists := b.clusters[cName]; exists && c.Status == StatusModifying { + c.Status = StatusAvailable + } + }(input.ClusterName) + return b.clusterCopy(cluster), nil } @@ -1026,59 +1137,54 @@ func (b *InMemoryBackend) DescribeParameterGroups( b.mu.RLock("DescribeParameterGroups") defer b.mu.RUnlock() + return describeNamedGroups( + b.paramGroups, names, maxResults, nextToken, ErrParameterGroupNotFound, + func(pg *ParameterGroup) string { return pg.ParameterGroupName }, paramGroupCopy, + ) +} + +// describeNamedGroups implements the shared DAX "describe named groups with +// pagination" pattern: a named lookup returns all matches unpaginated (erroring +// on any missing name), while an unfiltered request returns one paginated page. +func describeNamedGroups[T any]( + store map[string]*T, + names []string, + maxResults int, + nextToken string, + notFound error, + nameOf func(*T) string, + copyFn func(*T) *T, +) ([]*T, string, error) { if maxResults <= 0 { maxResults = maxPageSizeDefault } - var all []*ParameterGroup + all := make([]*T, 0, len(store)) if len(names) > 0 { for _, name := range names { - pg, ok := b.paramGroups[name] + item, ok := store[name] if !ok { - return nil, "", fmt.Errorf("%w: %s", ErrParameterGroupNotFound, name) + return nil, "", fmt.Errorf("%w: %s", notFound, name) } - - cp := paramGroupCopy(pg) - all = append(all, cp) + all = append(all, item) } - // Named lookup: return all matches without pagination. - return all, "", nil - } - - for _, pg := range b.paramGroups { - cp := paramGroupCopy(pg) - all = append(all, cp) - } - - sort.Slice(all, func(i, j int) bool { - return all[i].ParameterGroupName < all[j].ParameterGroupName - }) - - start := 0 - if nextToken != "" { - for i, pg := range all { - if pg.ParameterGroupName == nextToken { - start = i - break - } + result := make([]*T, 0, len(all)) + for _, item := range all { + result = append(result, copyFn(item)) } - } - if start >= len(all) { - return []*ParameterGroup{}, "", nil + return result, "", nil } - end := start + maxResults - newNextToken := "" - if end < len(all) { - newNextToken = all[end].ParameterGroupName - } else { - end = len(all) + for _, item := range store { + all = append(all, item) } - return all[start:end], newNextToken, nil + result, newNextToken := paginateList(all, maxResults, nextToken, nameOf, copyFn) + + return result, newNextToken, nil } // UpdateParameterGroup updates parameter values in a parameter group. @@ -1310,8 +1416,8 @@ func (b *InMemoryBackend) CreateSubnetGroup( return nil, err } - if len(subnetIDs) == 0 { - return nil, fmt.Errorf("%w: at least one SubnetId is required", ErrInvalidParameterValue) + if err := validateSubnetIDs(subnetIDs); err != nil { + return nil, err } b.mu.Lock("CreateSubnetGroup") @@ -1348,63 +1454,22 @@ func (b *InMemoryBackend) DescribeSubnetGroups( b.mu.RLock("DescribeSubnetGroups") defer b.mu.RUnlock() - if maxResults <= 0 { - maxResults = maxPageSizeDefault - } - - var all []*SubnetGroup - - if len(names) > 0 { - for _, name := range names { - sg, ok := b.subnetGroups[name] - if !ok { - return nil, "", fmt.Errorf("%w: %s", ErrSubnetGroupNotFound, name) - } - - all = append(all, subnetGroupCopy(sg)) - } - - return all, "", nil - } - - for _, sg := range b.subnetGroups { - all = append(all, subnetGroupCopy(sg)) - } - - sort.Slice(all, func(i, j int) bool { - return all[i].SubnetGroupName < all[j].SubnetGroupName - }) - - start := 0 - if nextToken != "" { - for i, sg := range all { - if sg.SubnetGroupName == nextToken { - start = i - - break - } - } - } - - if start >= len(all) { - return []*SubnetGroup{}, "", nil - } - - end := start + maxResults - newNextToken := "" - if end < len(all) { - newNextToken = all[end].SubnetGroupName - } else { - end = len(all) - } - - return all[start:end], newNextToken, nil + return describeNamedGroups( + b.subnetGroups, names, maxResults, nextToken, ErrSubnetGroupNotFound, + func(sg *SubnetGroup) string { return sg.SubnetGroupName }, subnetGroupCopy, + ) } // UpdateSubnetGroup updates a subnet group's description and/or subnet list. func (b *InMemoryBackend) UpdateSubnetGroup(input UpdateSubnetGroupInput) (*SubnetGroup, error) { if input.SubnetGroupName == "" { - return nil, fmt.Errorf("%w: SubnetGroupName is required", ErrSubnetGroupNotFound) + return nil, fmt.Errorf("%w: SubnetGroupName is required", ErrInvalidParameterValue) + } + + if len(input.SubnetIDs) > 0 { + if err := validateSubnetIDs(input.SubnetIDs); err != nil { + return nil, err + } } b.mu.Lock("UpdateSubnetGroup") @@ -1647,6 +1712,23 @@ func subnetGroupCopy(sg *SubnetGroup) *SubnetGroup { return &cp } +// subnetRegexp validates subnet IDs. +var subnetRegexp = regexp.MustCompile(`^subnet-[0-9a-f]{8}([0-9a-f]{9})?$`) + +// validateSubnetIDs checks the format of subnet IDs. +func validateSubnetIDs(ids []string) error { + if len(ids) == 0 { + return fmt.Errorf("%w: at least one SubnetId is required", ErrInvalidParameterValue) + } + for _, id := range ids { + if !subnetRegexp.MatchString(id) { + return fmt.Errorf("%w: invalid subnet ID %q", ErrInvalidParameterValue, id) + } + } + + return nil +} + // subnetEntriesFromIDs converts string subnet IDs to SubnetEntry slices using default AZ. func subnetEntriesFromIDs(ids []string, region string) []SubnetEntry { entries := make([]SubnetEntry, 0, len(ids)) @@ -1703,6 +1785,41 @@ func removeSpecificNodes(nodes []Node, nodeIDsToRemove []string, clusterName str return kept, nil } +const defaultTTL = 5 * time.Minute + +// GetDefaultTTL returns the TTLs configured in the first available cluster's param group, or defaults. +func (b *InMemoryBackend) GetDefaultTTL() (time.Duration, time.Duration) { + b.mu.RLock("GetDefaultTTL") + defer b.mu.RUnlock() + + // Default to 5 minutes if no clusters exist. + recordTTL := defaultTTL + queryTTL := defaultTTL + + for _, c := range b.clusters { + pgName := c.ParameterGroup.ParameterGroupName + pg, ok := b.paramGroups[pgName] + if !ok { + break + } + + if val, has := pg.Parameters[paramRecordTTL]; has { + if ms, err := strconv.ParseInt(val, 10, 64); err == nil { + recordTTL = time.Duration(ms) * time.Millisecond + } + } + if val, has := pg.Parameters[paramQueryTTL]; has { + if ms, err := strconv.ParseInt(val, 10, 64); err == nil { + queryTTL = time.Duration(ms) * time.Millisecond + } + } + + break // just use the first cluster's param group + } + + return recordTTL, queryTTL +} + // vpcIDFromSubnets returns a deterministic placeholder VPC ID derived from the first subnet ID. // Real AWS would look up the actual VPC; in emulation we derive a plausible ID from the subnet. func vpcIDFromSubnets(subnetIDs []string) string { diff --git a/services/dax/backend_parity_test.go b/services/dax/backend_parity_test.go index 55dc7fa7f..8d34f021f 100644 --- a/services/dax/backend_parity_test.go +++ b/services/dax/backend_parity_test.go @@ -160,7 +160,7 @@ func TestCreateSubnetGroupRequiresSubnet(t *testing.T) { {name: "nil subnets rejected", subnetIDs: nil, wantErr: true}, {name: "empty subnets rejected", subnetIDs: []string{}, wantErr: true}, {name: "one subnet accepted", subnetIDs: []string{"subnet-abc12345"}, wantErr: false}, - {name: "multiple subnets accepted", subnetIDs: []string{"subnet-aaa", "subnet-bbb"}, wantErr: false}, + {name: "multiple subnets accepted", subnetIDs: []string{"subnet-aaaaaaaa", "subnet-bbbbbbbb"}, wantErr: false}, } for _, tt := range tests { @@ -244,7 +244,7 @@ func TestCreateSubnetGroupNameValidation(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() b := newTestBackend() - _, err := b.CreateSubnetGroup(tt.sgName, "", []string{"subnet-1"}) + _, err := b.CreateSubnetGroup(tt.sgName, "", []string{"subnet-11111111"}) if tt.wantErr { require.Error(t, err) @@ -341,7 +341,7 @@ func TestDescribeSubnetGroupsPagination(t *testing.T) { // Create additional groups beyond the default. for i := range 5 { name := []byte{'a' + byte(i)} - _, err := b.CreateSubnetGroup(string(name)+"-sg", "", []string{"subnet-1"}) + _, err := b.CreateSubnetGroup(string(name)+"-sg", "", []string{"subnet-11111111"}) require.NoError(t, err) } diff --git a/services/dax/backend_test.go b/services/dax/backend_test.go index d645b5309..9c7f3de69 100644 --- a/services/dax/backend_test.go +++ b/services/dax/backend_test.go @@ -749,7 +749,7 @@ func TestTagResource(t *testing.T) { { name: "tag subnet group ARN", setup: func(b *dax.InMemoryBackend) string { - _, _ = b.CreateSubnetGroup("my-sg", "", []string{"subnet-1"}) + _, _ = b.CreateSubnetGroup("my-sg", "", []string{"subnet-11111111"}) return "arn:aws:dax:us-east-1:123456789012:subnetgroup/my-sg" }, @@ -1207,12 +1207,12 @@ func TestCreateSubnetGroup(t *testing.T) { name: "success", sgName: "my-sg", desc: "test subnet group", - subnetIDs: []string{"subnet-1", "subnet-2"}, + subnetIDs: []string{"subnet-11111111", "subnet-22222222"}, check: func(t *testing.T, sg *dax.SubnetGroup) { t.Helper() assert.Equal(t, "my-sg", sg.SubnetGroupName) assert.Len(t, sg.Subnets, 2) - assert.Equal(t, "subnet-1", sg.Subnets[0].SubnetID) + assert.Equal(t, "subnet-11111111", sg.Subnets[0].SubnetID) assert.Equal(t, "us-east-1a", sg.Subnets[0].AvailabilityZone) }, }, @@ -1248,9 +1248,9 @@ func TestCreateSubnetGroup(t *testing.T) { func TestCreateSubnetGroup_Duplicate(t *testing.T) { t.Parallel() b := newTestBackend() - _, err := b.CreateSubnetGroup("sg", "", []string{"subnet-1"}) + _, err := b.CreateSubnetGroup("sg", "", []string{"subnet-11111111"}) require.NoError(t, err) - _, err = b.CreateSubnetGroup("sg", "", []string{"subnet-1"}) + _, err = b.CreateSubnetGroup("sg", "", []string{"subnet-11111111"}) require.Error(t, err) } @@ -1267,7 +1267,7 @@ func TestUpdateSubnetGroup(t *testing.T) { { name: "update description", setup: func(b *dax.InMemoryBackend) { - _, _ = b.CreateSubnetGroup("upd-sg", "old desc", []string{"subnet-1"}) + _, _ = b.CreateSubnetGroup("upd-sg", "old desc", []string{"subnet-11111111"}) }, input: dax.UpdateSubnetGroupInput{SubnetGroupName: "upd-sg", Description: "new desc"}, check: func(t *testing.T, sg *dax.SubnetGroup) { @@ -1278,13 +1278,16 @@ func TestUpdateSubnetGroup(t *testing.T) { { name: "update subnets", setup: func(b *dax.InMemoryBackend) { - _, _ = b.CreateSubnetGroup("sub-sg", "", []string{"subnet-1"}) + _, _ = b.CreateSubnetGroup("sub-sg", "", []string{"subnet-11111111"}) + }, + input: dax.UpdateSubnetGroupInput{ + SubnetGroupName: "sub-sg", + SubnetIDs: []string{"subnet-22222222", "subnet-33333333"}, }, - input: dax.UpdateSubnetGroupInput{SubnetGroupName: "sub-sg", SubnetIDs: []string{"subnet-2", "subnet-3"}}, check: func(t *testing.T, sg *dax.SubnetGroup) { t.Helper() assert.Len(t, sg.Subnets, 2) - assert.Equal(t, "subnet-2", sg.Subnets[0].SubnetID) + assert.Equal(t, "subnet-22222222", sg.Subnets[0].SubnetID) }, }, { @@ -1330,7 +1333,7 @@ func TestDeleteSubnetGroup(t *testing.T) { { name: "success", setup: func(b *dax.InMemoryBackend) { - _, _ = b.CreateSubnetGroup("sg-del", "", []string{"subnet-1"}) + _, _ = b.CreateSubnetGroup("sg-del", "", []string{"subnet-11111111"}) }, sgName: "sg-del", }, @@ -1385,7 +1388,7 @@ func TestDescribeSubnetGroups(t *testing.T) { { name: "with custom group", setup: func(b *dax.InMemoryBackend) { - _, _ = b.CreateSubnetGroup("custom", "", []string{"subnet-1"}) + _, _ = b.CreateSubnetGroup("custom", "", []string{"subnet-11111111"}) }, wantCount: 2, }, diff --git a/services/dax/dataplane/batch.go b/services/dax/dataplane/batch.go index bc828cf56..886919e40 100644 --- a/services/dax/dataplane/batch.go +++ b/services/dax/dataplane/batch.go @@ -44,7 +44,7 @@ func (s *Server) handleBatchWriteItem(r *Reader, w *Writer) error { requestItems[table] = writes } - if _, err = readItemOptionalParams(r); err != nil { + if _, err = readItemOptionalParams(r, nil); err != nil { return err } @@ -52,6 +52,16 @@ func (s *Server) handleBatchWriteItem(r *Reader, w *Writer) error { return s.writeBackendError(w, err) } + for table, writes := range requestItems { + for _, wreq := range writes { + if wreq.PutRequest != nil { + s.invalidateItemCache(table, wreq.PutRequest.Item) + } else if wreq.DeleteRequest != nil { + s.invalidateItemCache(table, wreq.DeleteRequest.Key) + } + } + } + return s.writeBatchWriteResponse(w) } diff --git a/services/dax/dataplane/export_test.go b/services/dax/dataplane/export_test.go index e5b2c980e..5c1c76623 100644 --- a/services/dax/dataplane/export_test.go +++ b/services/dax/dataplane/export_test.go @@ -169,7 +169,7 @@ type AttrListServer struct { // NewAttrListServerForTest builds a server wrapper whose attribute-list id // allocations can be inspected by tests. func NewAttrListServerForTest() *AttrListServer { - return &AttrListServer{s: NewServer(context.TODO(), nil)} + return &AttrListServer{s: NewServer(context.TODO(), nil, nil)} } // WriteAttributeProjection emits an attribute-projection payload via the wrapped diff --git a/services/dax/dataplane/ops.go b/services/dax/dataplane/ops.go index f97eccb2a..3c125cf50 100644 --- a/services/dax/dataplane/ops.go +++ b/services/dax/dataplane/ops.go @@ -1,6 +1,9 @@ package dataplane import ( + "strings" + "time" + "bytes" "errors" "maps" @@ -38,6 +41,7 @@ const ( // itemOpParams holds the optional parameters decoded from an item operation. type itemOpParams struct { + proj *projection consistentRead bool returnValues int } @@ -45,17 +49,17 @@ type itemOpParams struct { // readItemOptionalParams decodes the indefinite-length optional-params map sent // after the key for item operations, extracting the fields gopherstack acts on // and safely skipping the rest (including pre-parsed expression blobs). -func readItemOptionalParams(r *Reader) (itemOpParams, error) { +func readItemOptionalParams(r *Reader, dec *decodedExpression) (itemOpParams, error) { p := itemOpParams{returnValues: rvNone} err := forEachOptionalParam(r, func(key int) error { - return decodeOneItemParam(r, &p, key) + return decodeOneItemParam(r, &p, key, dec) }) return p, err } -func decodeOneItemParam(r *Reader, p *itemOpParams, key int) error { +func decodeOneItemParam(r *Reader, p *itemOpParams, key int, dec *decodedExpression) error { switch key { case reqParamConsistentRead: b, e := r.readBoolOrInt() @@ -71,6 +75,18 @@ func decodeOneItemParam(r *Reader, p *itemOpParams, key int) error { } p.returnValues = v + case reqParamProjectionExpression: + blob, e := r.ReadBytes() + if e != nil { + return e + } + if dec != nil { + proj, decErr := dec.decodeProjectionBlob(blob) + if decErr != nil { + return decErr + } + p.proj = proj + } default: return r.skip() } @@ -160,8 +176,38 @@ func (r *Reader) readBoolOrInt() (bool, error) { // handleGetItem decodes a GetItem request, delegates to the backend, and writes // the DAX-shaped response. + +func (s *Server) invalidateItemCache(table string, key map[string]types.AttributeValue) { + if s.ttl != nil { + s.itemCache.Delete(s.cacheKey(table, key)) + } +} + +func (s *Server) cacheKey(table string, key map[string]types.AttributeValue) string { + parts := make([]string, 0, len(key)) + for k, v := range key { + parts = append(parts, k+"="+formatCacheValue(v)) + } + + return table + "|" + strings.Join(parts, "|") +} + +func formatCacheValue(v types.AttributeValue) string { + switch v := v.(type) { + case *types.AttributeValueMemberS: + return v.Value + case *types.AttributeValueMemberN: + return v.Value + case *types.AttributeValueMemberB: + return string(v.Value) + } + + return "" +} + func (s *Server) handleGetItem(r *Reader, w *Writer) error { - table, key, params, err := s.readKeyedRequest(r) + dec := newDecodedExpression() + table, key, params, err := s.readKeyedRequest(r, dec) if err != nil { return s.writeError(w, statusBadRequest, "ValidationException", err.Error()) } @@ -169,6 +215,22 @@ func (s *Server) handleGetItem(r *Reader, w *Writer) error { ctx, cancel := s.requestContext() defer cancel() + ks, err := s.schemaFor(ctx, table) + if err != nil { + return err + } + + ckey := s.cacheKey(table, key) + if !params.consistentRead && s.ttl != nil { + found, cErr := s.handleCachedGetItem(w, ckey, params.proj, ks) + if cErr != nil { + return cErr + } + if found { + return nil + } + } + out, err := s.backend.GetItem(ctx, &awsddb.GetItemInput{ TableName: &table, Key: key, @@ -182,20 +244,57 @@ func (s *Server) handleGetItem(r *Reader, w *Writer) error { return err } - ks, err := s.schemaFor(ctx, table) - if err != nil { - return err - } - if len(out.Item) == 0 { return w.WriteNull() } + // Update cache + if s.ttl != nil { + ttl, _ := s.ttl.GetDefaultTTL() + s.itemCache.Store(ckey, cacheEntry{item: out.Item, expiresAt: time.Now().Add(ttl)}) + } + + if params.proj != nil && params.proj.hasProjection() { + if wErr := w.WriteMapHeader(1); wErr != nil { + return wErr + } + if wErr := w.WriteInt(respParamItem); wErr != nil { + return wErr + } + + return writeProjectionMap(w, projectedEntries(out.Item, params.proj.ordinals)) + } + return s.writeItemMap(w, respParamItem, out.Item, ks) } -// handlePutItem decodes a PutItem request (key bytes + non-key attributes) and -// delegates to the backend. +func (s *Server) handleCachedGetItem(w *Writer, ckey string, proj *projection, ks keySchema) (bool, error) { + entry, ok := s.itemCache.Load(ckey) + if !ok { + return false, nil + } + ce, _ := entry.(cacheEntry) + if time.Now().After(ce.expiresAt) { + return false, nil + } + + if err := writeOK(w); err != nil { + return true, err + } + if proj != nil && proj.hasProjection() { + if wErr := w.WriteMapHeader(1); wErr != nil { + return true, wErr + } + if wErr := w.WriteInt(respParamItem); wErr != nil { + return true, wErr + } + + return true, writeProjectionMap(w, projectedEntries(ce.item, proj.ordinals)) + } + + return true, s.writeItemMap(w, respParamItem, ce.item, ks) +} + func (s *Server) handlePutItem(r *Reader, w *Writer) error { table, err := readTable(r) if err != nil { @@ -222,7 +321,7 @@ func (s *Server) handlePutItem(r *Reader, w *Writer) error { item := mergeItem(key, nonKey) - params, err := readItemOptionalParams(r) + params, err := readItemOptionalParams(r, nil) if err != nil { return err } @@ -237,6 +336,8 @@ func (s *Server) handlePutItem(r *Reader, w *Writer) error { return s.writeBackendError(w, err) } + s.invalidateItemCache(table, item) + if err = writeOK(w); err != nil { return err } @@ -250,7 +351,7 @@ func (s *Server) handlePutItem(r *Reader, w *Writer) error { // handleDeleteItem decodes a DeleteItem request and delegates to the backend. func (s *Server) handleDeleteItem(r *Reader, w *Writer) error { - table, key, params, err := s.readKeyedRequest(r) + table, key, params, err := s.readKeyedRequest(r, nil) if err != nil { return s.writeError(w, statusBadRequest, "ValidationException", err.Error()) } @@ -268,6 +369,8 @@ func (s *Server) handleDeleteItem(r *Reader, w *Writer) error { return s.writeBackendError(w, err) } + s.invalidateItemCache(table, key) + if err = writeOK(w); err != nil { return err } @@ -286,7 +389,10 @@ func (s *Server) handleDeleteItem(r *Reader, w *Writer) error { // readKeyedRequest reads the common [table, key, optionalParams] prefix shared // by GetItem and DeleteItem. -func (s *Server) readKeyedRequest(r *Reader) (string, map[string]types.AttributeValue, itemOpParams, error) { +func (s *Server) readKeyedRequest( + r *Reader, + dec *decodedExpression, +) (string, map[string]types.AttributeValue, itemOpParams, error) { table, err := readTable(r) if err != nil { return "", nil, itemOpParams{}, err @@ -305,7 +411,7 @@ func (s *Server) readKeyedRequest(r *Reader) (string, map[string]types.Attribute return "", nil, itemOpParams{}, err } - params, err := readItemOptionalParams(r) + params, err := readItemOptionalParams(r, dec) if err != nil { return "", nil, itemOpParams{}, err } diff --git a/services/dax/dataplane/server.go b/services/dax/dataplane/server.go index 3a0499ed1..959b980cb 100644 --- a/services/dax/dataplane/server.go +++ b/services/dax/dataplane/server.go @@ -87,35 +87,48 @@ type Backend interface { DescribeTable(context.Context, *awsddb.DescribeTableInput) (*awsddb.DescribeTableOutput, error) } +// TTLLookup provides TTL configuration for the item and query caches. +type TTLLookup interface { + GetDefaultTTL() (recordTTL time.Duration, queryTTL time.Duration) +} + +type cacheEntry struct { + item map[string]types.AttributeValue + expiresAt time.Time +} + // Server is a DAX data-plane TCP listener. It accepts DAX client connections, // performs the protocol handshake, and serves item operations by delegating to // a DynamoDB Backend. type Server struct { backend Backend + ttl TTLLookup // baseCtx is the data-plane lifecycle context, tagged worker=dax-dataplane. // The data plane is a raw TCP server with no per-request context, so its // goroutines log via logger.Load(baseCtx) rather than an embedded *slog.Logger. - baseCtx context.Context //nolint:containedctx // lifecycle ctx for the data-plane accept/serve goroutines. - ln net.Listener - conns map[net.Conn]struct{} - attrToID map[string]int64 // joined attr names -> id - idToAttr map[int64][]string - mu sync.Mutex - attrMu sync.Mutex - nextID int64 - closed bool + baseCtx context.Context //nolint:containedctx // lifecycle ctx for the data-plane accept/serve goroutines. + ln net.Listener + conns map[net.Conn]struct{} + attrToID map[string]int64 // joined attr names -> id + idToAttr map[int64][]string + itemCache sync.Map // key -> cacheEntry + mu sync.Mutex + attrMu sync.Mutex + nextID int64 + closed bool } // NewServer creates a DAX data-plane server backed by the given DynamoDB // backend. ctx is the process lifecycle context; the server tags it // worker=dax-dataplane so all data-plane records are attributable. -func NewServer(ctx context.Context, backend Backend) *Server { +func NewServer(ctx context.Context, backend Backend, ttl TTLLookup) *Server { if ctx == nil { ctx = context.Background() } return &Server{ backend: backend, + ttl: ttl, baseCtx: logger.WithWorker(ctx, "dax", "dataplane"), conns: make(map[net.Conn]struct{}), attrToID: make(map[string]int64), diff --git a/services/dax/dataplane/transact.go b/services/dax/dataplane/transact.go index 57f60e05a..d816a2f5a 100644 --- a/services/dax/dataplane/transact.go +++ b/services/dax/dataplane/transact.go @@ -59,6 +59,12 @@ func (s *Server) handleTransactWriteItems(r *Reader, w *Writer) error { return s.writeBackendError(w, err) } + for _, it := range items { + if it.operation != txnOpCheck { + s.invalidateItemCache(it.table, it.key) + } + } + if err = writeOK(w); err != nil { return err } diff --git a/services/dax/dataplane/update_query_scan.go b/services/dax/dataplane/update_query_scan.go index a1dd171bd..4e782466c 100644 --- a/services/dax/dataplane/update_query_scan.go +++ b/services/dax/dataplane/update_query_scan.go @@ -62,6 +62,8 @@ func (s *Server) handleUpdateItem(r *Reader, w *Writer) error { return s.writeBackendError(w, err) } + s.invalidateItemCache(table, key) + if err = writeOK(w); err != nil { return err } diff --git a/services/dax/dataplane_integration_test.go b/services/dax/dataplane_integration_test.go index e994f229e..37fa3758f 100644 --- a/services/dax/dataplane_integration_test.go +++ b/services/dax/dataplane_integration_test.go @@ -11,6 +11,8 @@ import ( v2types "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" v1creds "github.com/aws/aws-sdk-go/aws/credentials" v1dynamodb "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/blackbirdworks/gopherstack/services/dax" ) @@ -813,3 +815,122 @@ func assertNumber(t *testing.T, item map[string]*v1dynamodb.AttributeValue, key, t.Fatalf("attribute %q: got %q want %q", key, *av.N, want) } } + +func TestDataPlaneCaching(t *testing.T) { + t.Parallel() + + handler := dax.NewHandler(dax.NewInMemoryBackend("000000000000", "us-east-1")) + dp := handler.EnableDataPlane(context.TODO(), "127.0.0.1:0") + if err := handler.StartWorker(context.Background()); err != nil { + t.Fatalf("start data plane: %v", err) + } + t.Cleanup(func() { handler.Shutdown(context.Background()) }) + createIntegrationTable(t, handler.DataPlaneBackend()) + endpoint := "dax://" + dp.Addr().String() + + client := newDaxClient(t, endpoint) + + // 1. Put via DAX + _, err := client.PutItem(&v1dynamodb.PutItemInput{ + TableName: new(integrationTable), + Item: map[string]*v1dynamodb.AttributeValue{ + "pk": {S: new("cache#1")}, + "name": {S: new("Original")}, + }, + }) + require.NoError(t, err) + + // 2. Get via DAX (loads cache) + out1, err := client.GetItem(&v1dynamodb.GetItemInput{ + TableName: new(integrationTable), + Key: map[string]*v1dynamodb.AttributeValue{ + "pk": {S: new("cache#1")}, + }, + }) + require.NoError(t, err) + assert.Equal(t, "Original", *out1.Item["name"].S) + + // 3. Mutate backend directly (bypassing DAX cache) + backend := handler.DataPlaneBackend() + _, err = backend.PutItem(context.Background(), &v2dynamodb.PutItemInput{ + TableName: new(integrationTable), + Item: map[string]v2types.AttributeValue{ + "pk": &v2types.AttributeValueMemberS{Value: "cache#1"}, + "name": &v2types.AttributeValueMemberS{Value: "Bypassed"}, + }, + }) + require.NoError(t, err) + + // 4. Get via DAX again - should return cached "Original" + out2, err := client.GetItem(&v1dynamodb.GetItemInput{ + TableName: new(integrationTable), + Key: map[string]*v1dynamodb.AttributeValue{ + "pk": {S: new("cache#1")}, + }, + }) + require.NoError(t, err) + assert.Equal(t, "Original", *out2.Item["name"].S, "Expected cached Original value, but got Bypassed") + + // 5. Update via DAX (invalidates cache) + _, err = client.UpdateItem(&v1dynamodb.UpdateItemInput{ + TableName: new(integrationTable), + Key: map[string]*v1dynamodb.AttributeValue{ + "pk": {S: new("cache#1")}, + }, + UpdateExpression: new("SET #n = :v"), + ExpressionAttributeNames: map[string]*string{ + "#n": new("name"), + }, + ExpressionAttributeValues: map[string]*v1dynamodb.AttributeValue{ + ":v": {S: new("UpdatedViaDax")}, + }, + }) + require.NoError(t, err) + + // 6. Get via DAX - should return "UpdatedViaDax" + out3, err := client.GetItem(&v1dynamodb.GetItemInput{ + TableName: new(integrationTable), + Key: map[string]*v1dynamodb.AttributeValue{ + "pk": {S: new("cache#1")}, + }, + }) + require.NoError(t, err) + assert.Equal(t, "UpdatedViaDax", *out3.Item["name"].S) +} + +func TestDataPlaneProjectionExpression(t *testing.T) { + // Must run serially due to ANTLR lexer issue in the client SDK + // t.Parallel() + + endpoint := newDataPlaneFixture(t) + client := newDaxClient(t, endpoint) + + _, err := client.PutItem(&v1dynamodb.PutItemInput{ + TableName: new(integrationTable), + Item: map[string]*v1dynamodb.AttributeValue{ + "pk": {S: new("proj#1")}, + "name": {S: new("Ada")}, + "age": {N: new("25")}, + "hidden": {S: new("secret")}, + }, + }) + require.NoError(t, err) + + out, err := client.GetItem(&v1dynamodb.GetItemInput{ + TableName: new(integrationTable), + Key: map[string]*v1dynamodb.AttributeValue{ + "pk": {S: new("proj#1")}, + }, + ProjectionExpression: new("#n, age"), + ExpressionAttributeNames: map[string]*string{ + "#n": new("name"), + }, + }) + require.NoError(t, err) + + require.Len(t, out.Item, 2) + assert.Equal(t, "Ada", *out.Item["name"].S) + assert.Equal(t, "25", *out.Item["age"].N) + assert.NotContains(t, out.Item, "hidden") + assert.NotContains(t, out.Item, "pk") // unless pk is requested, it's not returned +} diff --git a/services/dax/dataplane_server.go b/services/dax/dataplane_server.go index 6f51c62e1..734372d27 100644 --- a/services/dax/dataplane_server.go +++ b/services/dax/dataplane_server.go @@ -25,7 +25,7 @@ type dataPlane struct { // newDataPlane constructs a DAX data-plane bound to its own DynamoDB backend. // ctx is the process lifecycle context used for the data-plane's logging. -func newDataPlane(ctx context.Context, addr string) *dataPlane { +func newDataPlane(ctx context.Context, addr string, daxBackend StorageBackend) *dataPlane { if addr == "" { addr = defaultDataPlaneAddr } @@ -33,7 +33,7 @@ func newDataPlane(ctx context.Context, addr string) *dataPlane { backend := dynamodb.NewInMemoryDB() return &dataPlane{ - server: dataplane.NewServer(ctx, backend), + server: dataplane.NewServer(ctx, backend, daxBackend), backend: backend, addr: addr, } @@ -63,7 +63,7 @@ func (h *Handler) EnableDataPlane(ctx context.Context, addr string) *dataPlane { return h.dataPlane } - h.dataPlane = newDataPlane(ctx, addr) + h.dataPlane = newDataPlane(ctx, addr, h.Backend) return h.dataPlane } diff --git a/services/dax/export_test.go b/services/dax/export_test.go new file mode 100644 index 000000000..ffc82fc9c --- /dev/null +++ b/services/dax/export_test.go @@ -0,0 +1,9 @@ +package dax + +func SetClusterAvailableForTest(b *InMemoryBackend, name string) { + b.mu.Lock("SetClusterAvailableForTest") + defer b.mu.Unlock() + if c, ok := b.clusters[name]; ok { + c.Status = StatusAvailable + } +} diff --git a/services/dax/handler_test.go b/services/dax/handler_test.go index 559d5609d..0b755fdf1 100644 --- a/services/dax/handler_test.go +++ b/services/dax/handler_test.go @@ -642,7 +642,7 @@ func TestHandlerSubnetGroups(t *testing.T) { body: map[string]any{ "SubnetGroupName": "my-sg", "Description": "My subnet group", - "SubnetIds": []string{"subnet-abc123"}, + "SubnetIds": []string{"subnet-abc12345"}, }, wantStatus: http.StatusOK, check: func(t *testing.T, resp map[string]any) { @@ -652,7 +652,7 @@ func TestHandlerSubnetGroups(t *testing.T) { subnets := sg["Subnets"].([]any) require.Len(t, subnets, 1) subnet := subnets[0].(map[string]any) - assert.Equal(t, "subnet-abc123", subnet["SubnetIdentifier"]) + assert.Equal(t, "subnet-abc12345", subnet["SubnetIdentifier"]) assert.Equal(t, "us-east-1a", subnet["SubnetAvailabilityZone"]) }, }, @@ -675,7 +675,7 @@ func TestHandlerSubnetGroups(t *testing.T) { t.Helper() daxRequest(t, h, "CreateSubnetGroup", map[string]any{ "SubnetGroupName": "upd-sg", - "SubnetIds": []string{"subnet-1"}, + "SubnetIds": []string{"subnet-11111111"}, }) }, body: map[string]any{ @@ -696,7 +696,7 @@ func TestHandlerSubnetGroups(t *testing.T) { t.Helper() daxRequest(t, h, "CreateSubnetGroup", map[string]any{ "SubnetGroupName": "sg-del", - "SubnetIds": []string{"subnet-1"}, + "SubnetIds": []string{"subnet-11111111"}, }) }, body: map[string]any{"SubnetGroupName": "sg-del"}, @@ -841,12 +841,12 @@ func TestHandlerErrorMapping(t *testing.T) { t.Helper() daxRequest(t, h, "CreateSubnetGroup", map[string]any{ "SubnetGroupName": "dup-sg", - "SubnetIds": []string{"subnet-1"}, + "SubnetIds": []string{"subnet-11111111"}, }) }, body: map[string]any{ "SubnetGroupName": "dup-sg", - "SubnetIds": []string{"subnet-1"}, + "SubnetIds": []string{"subnet-11111111"}, }, wantCode: "SubnetGroupAlreadyExistsFault", }, diff --git a/services/dax/interfaces.go b/services/dax/interfaces.go index 0d271b267..065474d8d 100644 --- a/services/dax/interfaces.go +++ b/services/dax/interfaces.go @@ -58,6 +58,8 @@ type StorageBackend interface { nextToken string, ) ([]*Event, string, error) + GetDefaultTTL() (recordTTL time.Duration, queryTTL time.Duration) + // Reset / persistence. Reset() Snapshot(ctx context.Context) []byte diff --git a/services/dax/models.go b/services/dax/models.go index 5ea942bfc..bad3dc56a 100644 --- a/services/dax/models.go +++ b/services/dax/models.go @@ -197,6 +197,7 @@ type Cluster struct { SecurityGroupIDs []string `json:"securityGroupIds"` ActiveNodes int `json:"activeNodes"` TotalNodes int `json:"totalNodes"` + NextNodeIndex int `json:"nextNodeIndex"` } // ParameterGroup represents a DAX parameter group. diff --git a/services/dax/persistence.go b/services/dax/persistence.go index 2c91007f8..8180aaf8b 100644 --- a/services/dax/persistence.go +++ b/services/dax/persistence.go @@ -146,5 +146,61 @@ func (b *InMemoryBackend) Restore(ctx context.Context, data []byte) error { b.AccountID = snap.AccountID b.Region = snap.Region + b.recoverAsyncTransitions() + return nil } + +func (b *InMemoryBackend) recoverAsyncTransitions() { + for name, c := range b.clusters { + b.recoverClusterState(name, c) + + for i := range c.Nodes { + if c.Nodes[i].NodeStatus == StatusRebooting { + b.recoverNodeState(name, c.Nodes[i].NodeID) + } + } + } +} + +func (b *InMemoryBackend) recoverClusterState(name string, c *Cluster) { + switch c.Status { + case StatusCreating, StatusModifying: + go func(cName string) { + b.mu.Lock("Restore:cluster-recovery") + defer b.mu.Unlock() + if cl, ok := b.clusters[cName]; ok { + cl.Status = StatusAvailable + } + }(name) + case StatusDeleting: + go func(cName, cArn string) { + b.mu.Lock("Restore:delete-recovery") + defer b.mu.Unlock() + if cl, ok := b.clusters[cName]; ok && cl.Status == StatusDeleting { + delete(b.clusters, cName) + delete(b.tags, cArn) + b.emitEventLocked(cName, EventSourceTypeCluster, + fmt.Sprintf("Cluster %s has been deleted.", cName)) + } + }(name, c.ClusterArn) + } +} + +func (b *InMemoryBackend) recoverNodeState(cName, nodeID string) { + go func() { + b.mu.Lock("Restore:node-recovery") + defer b.mu.Unlock() + if cl, ok := b.clusters[cName]; ok { + for j := range cl.Nodes { + if cl.Nodes[j].NodeID == nodeID { + cl.Nodes[j].NodeStatus = StatusAvailable + b.emitEventLocked(cName, EventSourceTypeNode, + fmt.Sprintf("Node %s reboot complete.", nodeID)) + + break + } + } + } + }() +} diff --git a/services/dax/zz_testmain_test.go b/services/dax/zz_testmain_test.go new file mode 100644 index 000000000..88752bb0d --- /dev/null +++ b/services/dax/zz_testmain_test.go @@ -0,0 +1,20 @@ +package dax_test + +import ( + "os" + "testing" +) + +// TestMain forces synchronous DAX state transitions for the whole package's +// tests. In production, cluster/node lifecycle changes (CREATING -> AVAILABLE, +// reboot, replication-factor changes) complete asynchronously; the emulator +// honours DAX_TEST_SYNC=1 to apply them immediately so tests are deterministic +// without sleeping or polling. CI runs plain `go test`, so the tests must set +// this themselves rather than rely on the environment. +func TestMain(m *testing.M) { + if err := os.Setenv("DAX_TEST_SYNC", "1"); err != nil { + panic(err) + } + + os.Exit(m.Run()) +} diff --git a/services/dynamodb/backup_interface.go b/services/dynamodb/backup_interface.go index a615be48b..08e3f150d 100644 --- a/services/dynamodb/backup_interface.go +++ b/services/dynamodb/backup_interface.go @@ -129,7 +129,7 @@ func (db *InMemoryDB) CreateBackup( now := time.Now() bkpARN := backupARN(region, db.accountID, tableName, now) - sizeBytes := estimateTableSizeBytes(snap.Items) + sizeBytes := estimateTableSizeBytes(table) backup := &Backup{ BackupArn: bkpARN, BackupName: backupName, diff --git a/services/dynamodb/errors.go b/services/dynamodb/errors.go index 489fe67a6..224b9b0ad 100644 --- a/services/dynamodb/errors.go +++ b/services/dynamodb/errors.go @@ -121,13 +121,6 @@ func NewRequestLimitExceeded(msg string) *Error { } } -func NewTransactionConflictException(msg string) *Error { - return &Error{ - Type: "com.amazonaws.dynamodb.v20120810#TransactionConflictException", - Message: msg, - } -} - func NewReplicatedWriteConflictException(msg string) *Error { return &Error{ Type: "com.amazonaws.dynamodb.v20120810#ReplicatedWriteConflictException", diff --git a/services/dynamodb/expr/parser.go b/services/dynamodb/expr/parser.go index d659e3440..dc5876791 100644 --- a/services/dynamodb/expr/parser.go +++ b/services/dynamodb/expr/parser.go @@ -118,7 +118,9 @@ func (p *Parser) parseExpression(precedence int) (Node, error) { TokenContains, TokenAttributeType, TokenIfNotExists, - TokenListAppend: + TokenListAppend, + TokenAND, TokenOR, TokenBETWEEN, TokenIN, + TokenSET, TokenREMOVE, TokenADD, TokenDELETE: left, err = p.parseOperand() default: return nil, fmt.Errorf("%w %v at start of expression", ErrUnexpectedToken, p.curToken) @@ -175,7 +177,9 @@ func (p *Parser) parseOperand() (Node, error) { return p.parseFunctionExpr() case TokenValue: return &ValuePlaceholder{Name: p.curToken.Literal}, nil - case TokenIdentifier: + case TokenIdentifier, + TokenAND, TokenOR, TokenBETWEEN, TokenIN, + TokenSET, TokenREMOVE, TokenADD, TokenDELETE: return p.parsePathExpr() default: return nil, fmt.Errorf("%w %v", ErrUnexpectedOperand, p.curToken) @@ -207,14 +211,29 @@ func (p *Parser) parsePathExpr() (Node, error) { } func (p *Parser) parseDotSegment(expr *PathExpr) error { - if !p.expectPeek(TokenIdentifier) { + if !p.peekTokenIsIdentifierLike() { return ErrExpectedIdentifierDot } + p.nextToken() expr.Elements = append(expr.Elements, PathElement{Type: ElementKey, Name: p.curToken.Literal}) return nil } +func (p *Parser) peekTokenIsIdentifierLike() bool { + switch p.peekToken.Type { + case TokenIdentifier, + TokenAND, TokenOR, TokenNOT, TokenBETWEEN, TokenIN, + TokenSET, TokenREMOVE, TokenADD, TokenDELETE, + TokenSize, TokenAttributeExists, TokenAttributeNotExists, + TokenBeginsWith, TokenContains, TokenAttributeType, + TokenIfNotExists, TokenListAppend: + return true + default: + return false + } +} + func (p *Parser) parseBracketSegment(expr *PathExpr) error { p.nextToken() if !p.curTokenIs( diff --git a/services/dynamodb/extra_ops.go b/services/dynamodb/extra_ops.go index 696302b8e..a4d509da2 100644 --- a/services/dynamodb/extra_ops.go +++ b/services/dynamodb/extra_ops.go @@ -1649,28 +1649,33 @@ func (db *InMemoryDB) ImportTable( InputCompression: string(input.InputCompressionType), StartTime: start, CreatedAt: start, - } - - res, importErr := db.importFromS3( - ctx, tableName, input.S3BucketSource, - input.InputFormat, input.InputCompressionType, input.InputFormatOptions, - ) - rec.EndTime = time.Now() - rec.ImportedItemCount = res.imported - rec.ProcessedItemCount = res.processed - rec.ProcessedSizeBytes = res.bytes - rec.ErrorCount = res.errors - - if importErr != nil { - rec.ImportStatus = string(types.ImportStatusFailed) - rec.FailureCode = "InputFormatError" - rec.FailureMessage = importErr.Error() - } else { - rec.ImportStatus = string(types.ImportStatusCompleted) + ImportStatus: string(types.ImportStatusInProgress), } db.storeImport(rec) + go func(r storedImport, tName string, in *dynamodb.ImportTableInput) { + res, importErr := db.importFromS3( + context.WithoutCancel(ctx), tName, in.S3BucketSource, + in.InputFormat, in.InputCompressionType, in.InputFormatOptions, + ) + r.EndTime = time.Now() + r.ImportedItemCount = res.imported + r.ProcessedItemCount = res.processed + r.ProcessedSizeBytes = res.bytes + r.ErrorCount = res.errors + + if importErr != nil { + r.ImportStatus = string(types.ImportStatusFailed) + r.FailureCode = "InputFormatError" + r.FailureMessage = importErr.Error() + } else { + r.ImportStatus = string(types.ImportStatusCompleted) + } + + db.storeImport(r) + }(rec, tableName, input) + return &dynamodb.ImportTableOutput{ ImportTableDescription: importDescriptionFromRecord(rec), }, nil @@ -1736,10 +1741,12 @@ func (db *InMemoryDB) ListImports( // NextToken is the ImportArn of the last record returned previously. nextToken := aws.ToString(input.NextToken) pageSize := defaultListImportsLimit - if input.PageSize != nil && *input.PageSize > 0 && int(*input.PageSize) < defaultListImportsLimit { + if input.PageSize != nil && *input.PageSize > 0 { pageSize = int(*input.PageSize) } + tableArnFilter := aws.ToString(input.TableArn) + // Filter by region and apply cursor. summaries := make([]types.ImportSummary, 0, len(stored)) started := nextToken == "" @@ -1748,6 +1755,9 @@ func (db *InMemoryDB) ListImports( if db.regionFromARN(imp.ImportArn) != region { continue } + if tableArnFilter != "" && imp.TableArn != tableArnFilter { + continue + } if !started { if imp.ImportArn == nextToken { started = true @@ -1808,6 +1818,8 @@ func cloneTableSchema(src *Table, name, region, accountID string) *Table { Name: name, Status: statusActive, Items: make([]map[string]any, 0), + itemSizes: make([]int, 0), + totalItemSizeBytes: 0, TableID: uuid.New().String(), CreationDateTime: time.Now(), TableArn: arn.Build("dynamodb", region, accountID, "table/"+name), diff --git a/services/dynamodb/handler.go b/services/dynamodb/handler.go index 34c1c9407..bc5565800 100644 --- a/services/dynamodb/handler.go +++ b/services/dynamodb/handler.go @@ -1163,10 +1163,11 @@ func (h *DynamoDBHandler) updateContinuousBackups(ctx context.Context, body []by } type exportTableToPointInTimeInput struct { - TableArn string `json:"TableArn"` - S3Bucket string `json:"S3Bucket"` - S3Prefix string `json:"S3Prefix,omitempty"` - ExportFormat string `json:"ExportFormat,omitempty"` + TableArn string `json:"TableArn"` + S3Bucket string `json:"S3Bucket"` + S3Prefix string `json:"S3Prefix,omitempty"` + ExportFormat string `json:"ExportFormat,omitempty"` + ExportTime float64 `json:"ExportTime,omitempty"` } type exportDescriptionFields struct { @@ -1243,13 +1244,15 @@ func (h *DynamoDBHandler) exportTableToPointInTime(ctx context.Context, body []b ExportFormat: exportFmt, ExportType: "FULL_EXPORT", StartTime: float64(now.Unix()), + ExportTime: req.ExportTime, } - // Persist as IN_PROGRESS (AWS initial response), then complete synchronously. + // Persist as IN_PROGRESS (AWS initial response), then complete asynchronously. // Real AWS takes minutes; the emulator finishes in microseconds. if b, ok := h.Backend.(*InMemoryDB); ok { b.storeExport(desc) - b.completeExportSync(ctx, exportARN, &req) + reqCopy := req + go b.completeExportSync(context.WithoutCancel(ctx), exportARN, &reqCopy) } return &exportTableToPointInTimeOutput{ExportDescription: desc}, nil diff --git a/services/dynamodb/handler_internal_test.go b/services/dynamodb/handler_internal_test.go index 8d363af0a..815176265 100644 --- a/services/dynamodb/handler_internal_test.go +++ b/services/dynamodb/handler_internal_test.go @@ -90,12 +90,7 @@ func TestHandler_ClassifyError_Mapping(t *testing.T) { wantStatusCode: http.StatusBadRequest, wantType: "RequestLimitExceeded", }, - { - name: "TransactionConflictException", - err: NewTransactionConflictException("conflict"), - wantStatusCode: http.StatusBadRequest, - wantType: "TransactionConflictException", - }, + { name: "ReplicatedWriteConflictException", err: NewReplicatedWriteConflictException("replicated conflict"), diff --git a/services/dynamodb/handler_streams_test.go b/services/dynamodb/handler_streams_test.go index 07b99b0cd..66c9ee081 100644 --- a/services/dynamodb/handler_streams_test.go +++ b/services/dynamodb/handler_streams_test.go @@ -7,7 +7,6 @@ import ( "net/http" "net/http/httptest" "testing" - "time" streamstypes "github.com/aws/aws-sdk-go-v2/service/dynamodbstreams/types" "github.com/labstack/echo/v5" @@ -110,11 +109,33 @@ func TestHandler_StreamsDispatch(t *testing.T) { t.Run("GetRecords returns INSERT record", func(t *testing.T) { t.Parallel() - handler, _ := newStreamEnabledHandler(t) - // Use current timestamp so the 3-part iterator (tableName:startSeq:timestamp) is valid. - iter := fmt.Sprintf("StreamHandlerTable:0:%d", time.Now().Unix()) - w := doStreamsRequest(t, handler, "GetRecords", `{"ShardIterator":"`+iter+`"}`) + handler, arn := newStreamEnabledHandler(t) + // First, DescribeStream to get the Shard ID + wDesc := doStreamsRequest(t, handler, "DescribeStream", `{"StreamArn":"`+arn+`"}`) + assert.Equal(t, http.StatusOK, wDesc.Code) + var descResp struct { + StreamDescription struct { + Shards []struct { + ShardID string `json:"ShardId"` + } `json:"Shards"` + } `json:"StreamDescription"` + } + require.NoError(t, json.Unmarshal(wDesc.Body.Bytes(), &descResp)) + require.NotEmpty(t, descResp.StreamDescription.Shards) + shardID := descResp.StreamDescription.Shards[0].ShardID + + // Then, GetShardIterator to get the iterator token + iterReq := fmt.Sprintf(`{"StreamArn":"%s","ShardId":"%s","ShardIteratorType":"TRIM_HORIZON"}`, arn, shardID) + wIter := doStreamsRequest(t, handler, "GetShardIterator", iterReq) + assert.Equal(t, http.StatusOK, wIter.Code) + var iterResp struct { + ShardIterator string `json:"ShardIterator"` + } + require.NoError(t, json.Unmarshal(wIter.Body.Bytes(), &iterResp)) + require.NotEmpty(t, iterResp.ShardIterator) + + w := doStreamsRequest(t, handler, "GetRecords", `{"ShardIterator":"`+iterResp.ShardIterator+`"}`) assert.Equal(t, http.StatusOK, w.Code) assert.Contains(t, w.Body.String(), "Records") assert.Contains(t, w.Body.String(), "INSERT") diff --git a/services/dynamodb/import_export_s3_test.go b/services/dynamodb/import_export_s3_test.go index 43a3b4c8b..02257a56f 100644 --- a/services/dynamodb/import_export_s3_test.go +++ b/services/dynamodb/import_export_s3_test.go @@ -8,6 +8,7 @@ import ( "sort" "strings" "testing" + "time" "github.com/aws/aws-sdk-go-v2/aws" sdk "github.com/aws/aws-sdk-go-v2/service/dynamodb" @@ -96,6 +97,37 @@ func importCreationParams(name string) *ddbtypes.TableCreationParameters { } } +func waitForImport(t *testing.T, db *dynamodb.InMemoryDB, arn string) *sdk.DescribeImportOutput { + t.Helper() + + for range 50 { + out, err := db.DescribeImport(t.Context(), &sdk.DescribeImportInput{ImportArn: aws.String(arn)}) + require.NoError(t, err) + if out.ImportTableDescription.ImportStatus != ddbtypes.ImportStatusInProgress { + return out + } + time.Sleep(10 * time.Millisecond) + } + t.Fatalf("import %s did not complete", arn) + + return nil +} + +func waitForExport(t *testing.T, h *dynamodb.DynamoDBHandler, arn string) { + t.Helper() + + for range 50 { + code, res := invokeOp(t, h, "DescribeExport", map[string]any{"ExportArn": arn}) + require.Equal(t, 200, code) + desc := res["ExportDescription"].(map[string]any) + if desc["ExportStatus"] != "IN_PROGRESS" { + return + } + time.Sleep(10 * time.Millisecond) + } + t.Fatalf("export %s did not complete", arn) +} + // TestImportTable_FromS3_DynamoDBJSON verifies ImportTable creates the table and // ingests gzipped DynamoDB-JSON objects, reporting accurate counts. func TestImportTable_FromS3_DynamoDBJSON(t *testing.T) { @@ -119,9 +151,12 @@ func TestImportTable_FromS3_DynamoDBJSON(t *testing.T) { TableCreationParameters: importCreationParams("ImportedJSON"), }) require.NoError(t, err) - assert.Equal(t, ddbtypes.ImportStatusCompleted, out.ImportTableDescription.ImportStatus) - assert.Equal(t, int64(2), out.ImportTableDescription.ImportedItemCount) - assert.Equal(t, int64(2), out.ImportTableDescription.ProcessedItemCount) + + importDesc := waitForImport(t, db, aws.ToString(out.ImportTableDescription.ImportArn)) + + assert.Equal(t, ddbtypes.ImportStatusCompleted, importDesc.ImportTableDescription.ImportStatus) + assert.Equal(t, int64(2), importDesc.ImportTableDescription.ImportedItemCount) + assert.Equal(t, int64(2), importDesc.ImportTableDescription.ProcessedItemCount) got, err := db.GetItem(t.Context(), &sdk.GetItemInput{ TableName: aws.String("ImportedJSON"), @@ -153,7 +188,9 @@ func TestImportTable_FromS3_CSV(t *testing.T) { TableCreationParameters: importCreationParams("ImportedCSV"), }) require.NoError(t, err) - assert.Equal(t, int64(2), out.ImportTableDescription.ImportedItemCount) + importDesc := waitForImport(t, db, aws.ToString(out.ImportTableDescription.ImportArn)) + assert.Equal(t, ddbtypes.ImportStatusCompleted, importDesc.ImportTableDescription.ImportStatus) + assert.Equal(t, int64(2), importDesc.ImportTableDescription.ImportedItemCount) got, err := db.GetItem(t.Context(), &sdk.GetItemInput{ TableName: aws.String("ImportedCSV"), @@ -184,8 +221,9 @@ func TestImportTable_ION_Unsupported(t *testing.T) { TableCreationParameters: importCreationParams("ImportedION"), }) require.NoError(t, err) - assert.Equal(t, ddbtypes.ImportStatusFailed, out.ImportTableDescription.ImportStatus) - assert.NotEmpty(t, aws.ToString(out.ImportTableDescription.FailureCode)) + importDesc := waitForImport(t, db, aws.ToString(out.ImportTableDescription.ImportArn)) + assert.Equal(t, ddbtypes.ImportStatusFailed, importDesc.ImportTableDescription.ImportStatus) + assert.NotEmpty(t, aws.ToString(importDesc.ImportTableDescription.FailureCode)) } // TestExportImport_RoundTrip exports a populated table to S3 and re-imports it. @@ -212,12 +250,13 @@ func TestExportImport_RoundTrip(t *testing.T) { require.True(t, ok) // Export to S3 via the handler. - code, _ := invokeOp(t, h, "ExportTableToPointInTime", map[string]any{ + code, res := invokeOp(t, h, "ExportTableToPointInTime", map[string]any{ "TableArn": tbl.TableArn, "S3Bucket": "exb", "S3Prefix": "out", }) require.Equal(t, 200, code) + waitForExport(t, h, res["ExportDescription"].(map[string]any)["ExportArn"].(string)) // Re-import the exported data into a new table from the data/ prefix. var dataPrefix string @@ -239,5 +278,6 @@ func TestExportImport_RoundTrip(t *testing.T) { TableCreationParameters: importCreationParams("RoundTripTbl"), }) require.NoError(t, err) - assert.Equal(t, int64(3), out.ImportTableDescription.ImportedItemCount) + importDesc := waitForImport(t, db, aws.ToString(out.ImportTableDescription.ImportArn)) + assert.Equal(t, int64(3), importDesc.ImportTableDescription.ImportedItemCount) } diff --git a/services/dynamodb/item_ops.go b/services/dynamodb/item_ops.go index 33d7912a1..bfdb22879 100644 --- a/services/dynamodb/item_ops.go +++ b/services/dynamodb/item_ops.go @@ -249,35 +249,148 @@ func compareAttributeValues(v1, v2 any) bool { m2, ok2 := v2.(map[string]any) if !ok1 || !ok2 { - // Fallback for bare Go primitives (shouldn't occur in normal operation). return fmt.Sprintf("%v", v1) == fmt.Sprintf("%v", v2) } - for typeKey, val1 := range m1 { - val2, exists := m2[typeKey] + if len(m1) != len(m2) { + return false + } + + for typeKey, leftVal := range m1 { + rightVal, exists := m2[typeKey] if !exists { return false } - s1, isStr1 := val1.(string) - s2, isStr2 := val2.(string) + if !compareTypedField(typeKey, leftVal, rightVal) { + return false + } + } + + return true +} + +// compareTypedField compares one DynamoDB-typed attribute field (M/L/SS/NS/BS or +// a scalar S/N/B/BOOL). Container types fall back to scalar comparison when the +// expected Go shape does not match, matching the original behaviour. +func compareTypedField(typeKey string, leftVal, rightVal any) bool { + switch typeKey { + case "M": + return compareMapField(leftVal, rightVal) + case "L": + return compareListField(leftVal, rightVal) + case "SS", "NS": + return compareStringSetField(leftVal, rightVal) + case "BS": + return compareByteSetField(leftVal, rightVal) + default: + return compareScalarField(leftVal, rightVal) + } +} + +func compareMapField(leftVal, rightVal any) bool { + m1, ok1 := leftVal.(map[string]any) + m2, ok2 := rightVal.(map[string]any) + if !ok1 || !ok2 { + return compareScalarField(leftVal, rightVal) + } + + if len(m1) != len(m2) { + return false + } + + for k, child1 := range m1 { + child2, ok := m2[k] + if !ok || !compareAttributeValues(child1, child2) { + return false + } + } + + return true +} + +func compareListField(leftVal, rightVal any) bool { + l1, ok1 := leftVal.([]any) + l2, ok2 := rightVal.([]any) + if !ok1 || !ok2 { + return compareScalarField(leftVal, rightVal) + } + + if len(l1) != len(l2) { + return false + } + + for i := range l1 { + if !compareAttributeValues(l1[i], l2[i]) { + return false + } + } + + return true +} + +func compareStringSetField(leftVal, rightVal any) bool { + s1, ok1 := leftVal.([]string) + s2, ok2 := rightVal.([]string) + if !ok1 || !ok2 { + return compareScalarField(leftVal, rightVal) + } + + if len(s1) != len(s2) { + return false + } + + for i := range s1 { + if s1[i] != s2[i] { + return false + } + } + + return true +} + +func compareByteSetField(leftVal, rightVal any) bool { + b1, ok1 := leftVal.([][]byte) + b2, ok2 := rightVal.([][]byte) + if !ok1 || !ok2 { + return compareScalarField(leftVal, rightVal) + } + + if len(b1) != len(b2) { + return false + } + + for i := range b1 { + if !bytes.Equal(b1[i], b2[i]) { + return false + } + } + + return true +} - if isStr1 && isStr2 { +// compareScalarField handles S/N (string), B ([]byte) and BOOL, with a +// stringified fallback for anything else. +func compareScalarField(leftVal, rightVal any) bool { + if s1, okL := leftVal.(string); okL { + if s2, okR := rightVal.(string); okR { return s1 == s2 } + } - // Binary attribute (B type) — use bytes.Equal for correct comparison. - b1, isByte1 := val1.([]byte) - b2, isByte2 := val2.([]byte) - if isByte1 && isByte2 { + if b1, okL := leftVal.([]byte); okL { + if b2, okR := rightVal.([]byte); okR { return bytes.Equal(b1, b2) } + } - // Nested map (e.g. M, L types) — fall back to string representation. - return fmt.Sprintf("%v", val1) == fmt.Sprintf("%v", val2) + if bl1, okL := leftVal.(bool); okL { + if bl2, okR := rightVal.(bool); okR { + return bl1 == bl2 + } } - return len(m2) == 0 + return fmt.Sprintf("%v", leftVal) == fmt.Sprintf("%v", rightVal) } func applyGSIProjection( diff --git a/services/dynamodb/item_ops_batch.go b/services/dynamodb/item_ops_batch.go index a5c3ac6dd..9b5ffa29a 100644 --- a/services/dynamodb/item_ops_batch.go +++ b/services/dynamodb/item_ops_batch.go @@ -573,7 +573,7 @@ func (db *InMemoryDB) applyBatchDeletes(table *Table, indices []int) { continue } // Capture stream record (REMOVE) - table.appendStreamRecord(streamEventRemove, deepCopyItem(table.Items[idx]), nil) + table.appendStreamRecord(streamEventRemove, deepCopyItem(table.Items[idx]), nil, "", "") // Delete by swapping with last and truncating table.Items[idx] = table.Items[len(table.Items)-1] @@ -657,13 +657,13 @@ func (db *InMemoryDB) handleBatchPutWithIndex(table *Table, item map[string]any) oldItem, matchIndex := db.findMatchForPut(table, item) if matchIndex != -1 { // Capture stream event (MODIFY) before overwriting in place. - table.appendStreamRecord(streamEventModify, oldItem, deepCopyItem(item)) + table.appendStreamRecord(streamEventModify, oldItem, deepCopyItem(item), "", "") table.Items[matchIndex] = item return matchIndex } // Capture stream event (INSERT) for the new item. - table.appendStreamRecord(streamEventInsert, nil, deepCopyItem(item)) + table.appendStreamRecord(streamEventInsert, nil, deepCopyItem(item), "", "") idx := len(table.Items) table.Items = append(table.Items, item) diff --git a/services/dynamodb/item_ops_crud.go b/services/dynamodb/item_ops_crud.go index 9f28f0aad..f39c483bc 100644 --- a/services/dynamodb/item_ops_crud.go +++ b/services/dynamodb/item_ops_crud.go @@ -92,9 +92,9 @@ func (db *InMemoryDB) PutItem( // Capture stream event if matchIndex != -1 { - table.appendStreamRecord(streamEventModify, oldItem, deepCopyItem(wireItem)) + table.appendStreamRecord(streamEventModify, oldItem, deepCopyItem(wireItem), "", "") } else { - table.appendStreamRecord(streamEventInsert, nil, deepCopyItem(wireItem)) + table.appendStreamRecord(streamEventInsert, nil, deepCopyItem(wireItem), "", "") } globalTableName := table.GlobalTableName @@ -181,12 +181,17 @@ func (db *InMemoryDB) checkPutCondition( } func (db *InMemoryDB) doPut(table *Table, item map[string]any, matchIndex int) { + itemSize, _ := CalculateItemSize(item) if matchIndex != -1 { + table.totalItemSizeBytes += int64(itemSize) - int64(table.itemSizes[matchIndex]) + table.itemSizes[matchIndex] = itemSize table.Items[matchIndex] = item db.updateIndexes(table, item, matchIndex) } else { idx := len(table.Items) table.Items = append(table.Items, item) + table.itemSizes = append(table.itemSizes, itemSize) + table.totalItemSizeBytes += int64(itemSize) db.updateIndexes(table, item, idx) } } @@ -230,12 +235,10 @@ func currentLSICollectionBytes(table *Table, pkVal string) int64 { if skMap, ok := table.pkskIndex[pkVal]; ok { for _, offset := range skMap { - sz, _ := CalculateItemSize(table.Items[offset]) - total += int64(sz) + total += int64(table.itemSizes[offset]) } } else if offset, ok2 := table.pkIndex[pkVal]; ok2 { - sz, _ := CalculateItemSize(table.Items[offset]) - total += int64(sz) + total += int64(table.itemSizes[offset]) } return total @@ -254,8 +257,7 @@ func computeLSICollectionSize( // Subtract old item (it will be replaced). if oldMatchIndex != -1 { - sz, _ := CalculateItemSize(table.Items[oldMatchIndex]) - total -= int64(sz) + total -= int64(table.itemSizes[oldMatchIndex]) } // Add new item. @@ -514,7 +516,7 @@ func (db *InMemoryDB) DeleteItem( if oldItem != nil && matchIndex != -1 { db.deleteItemAtIndex(table, matchIndex) // Capture stream REMOVE event - table.appendStreamRecord(streamEventRemove, deepCopyItem(oldItem), nil) + table.appendStreamRecord(streamEventRemove, deepCopyItem(oldItem), nil, "", "") } out := db.buildDeleteItemOutput(input, table, oldItem) @@ -685,9 +687,9 @@ func (db *InMemoryDB) UpdateItem( // Capture stream event for UpdateItem if matchIndex != -1 { - table.appendStreamRecord(streamEventModify, deepCopyItem(existing), deepCopyItem(updated)) + table.appendStreamRecord(streamEventModify, deepCopyItem(existing), deepCopyItem(updated), "", "") } else { - table.appendStreamRecord(streamEventInsert, nil, deepCopyItem(updated)) + table.appendStreamRecord(streamEventInsert, nil, deepCopyItem(updated), "", "") } globalTableName := table.GlobalTableName @@ -780,12 +782,18 @@ func (db *InMemoryDB) doUpdate( return nil, nil, err } + updatedSize, _ := CalculateItemSize(updated) + if matchIndex != -1 { + table.totalItemSizeBytes += int64(updatedSize) - int64(table.itemSizes[matchIndex]) + table.itemSizes[matchIndex] = updatedSize table.Items[matchIndex] = updated db.updateIndexes(table, updated, matchIndex) } else { newIdx := len(table.Items) table.Items = append(table.Items, updated) + table.itemSizes = append(table.itemSizes, updatedSize) + table.totalItemSizeBytes += int64(updatedSize) db.updateIndexes(table, updated, newIdx) } @@ -920,10 +928,13 @@ func (db *InMemoryDB) deleteItemAtIndex(table *Table, matchIndex int) { // Swap with last strategy for O(1) deletion lastIdx := len(table.Items) - 1 + deletedSize := table.itemSizes[matchIndex] + if matchIndex != lastIdx { // Move last item to deleted spot lastItem := table.Items[lastIdx] table.Items[matchIndex] = lastItem + table.itemSizes[matchIndex] = table.itemSizes[lastIdx] // Update index for the moved item db.updateIndexes(table, lastItem, matchIndex) @@ -931,6 +942,8 @@ func (db *InMemoryDB) deleteItemAtIndex(table *Table, matchIndex int) { // Shrink slice table.Items = table.Items[:lastIdx] + table.itemSizes = table.itemSizes[:lastIdx] + table.totalItemSizeBytes -= int64(deletedSize) } // deepCopyItem returns a deep copy of a wire-format item so that mutations @@ -987,13 +1000,7 @@ func deepCopyAny(v any) any { } } -// estimateTableSizeBytes computes the total estimated size of all items in the table. -func estimateTableSizeBytes(items []map[string]any) int64 { - var total int64 - for _, item := range items { - size, _ := CalculateItemSize(item) - total += int64(size) - } - - return total +// estimateTableSizeBytes returns the cached total estimated size of all items. +func estimateTableSizeBytes(table *Table) int64 { + return table.totalItemSizeBytes } diff --git a/services/dynamodb/item_ops_query.go b/services/dynamodb/item_ops_query.go index 3823c013d..80360a358 100644 --- a/services/dynamodb/item_ops_query.go +++ b/services/dynamodb/item_ops_query.go @@ -86,6 +86,10 @@ func (db *InMemoryDB) QueryWithContext( return nil, err } + if verr := validateSelectConstraints(input.Select, idxName, projection); verr != nil { + return nil, verr + } + candidates, err := db.filterCandidatesForKeyCondition( ctx, snapshotTable, input, projection, keySchema, ) @@ -232,6 +236,16 @@ func (db *InMemoryDB) filterCandidatesForKeyCondition( return nil, err } + // Parse all condition parts once + parsedParts := make([]*ParsedCondition, 0, len(exprParts)) + for _, part := range exprParts { + pc, err := ParseConditionStr(part) + if err != nil { + return nil, err + } + parsedParts = append(parsedParts, pc) + } + // Try to use index for primary table queries (not GSI/LSI) if idxName == "" { candidates, ok := db.tryFilterUsingAuthoritativeIndex( @@ -242,7 +256,7 @@ func (db *InMemoryDB) filterCandidatesForKeyCondition( pkExpr, pkDef, skDef, - exprParts, + parsedParts, eav, ) if ok { @@ -250,7 +264,7 @@ func (db *InMemoryDB) filterCandidatesForKeyCondition( } } - return db.filterCandidatesScan(table, input, projection, keySchema, exprParts, eav) + return db.filterCandidatesScan(table, input, projection, keySchema, parsedParts, eav) } func (db *InMemoryDB) tryFilterUsingAuthoritativeIndex( @@ -261,7 +275,7 @@ func (db *InMemoryDB) tryFilterUsingAuthoritativeIndex( pkExpr string, _ models.KeySchemaElement, skDef models.KeySchemaElement, - exprParts []string, + exprParts []*ParsedCondition, eav map[string]any, ) ([]map[string]any, bool) { pkValue := extractPKValueFromExpression(pkExpr, eav, input.ExpressionAttributeNames) @@ -295,7 +309,7 @@ func (db *InMemoryDB) filterUsingIndices( input *dynamodb.QueryInput, _ *models.Projection, indices []int, - exprParts []string, + exprParts []*ParsedCondition, eav map[string]any, ) []map[string]any { candidates := make([]map[string]any, 0, len(indices)) @@ -313,7 +327,7 @@ func (db *InMemoryDB) filterUsingIndices( continue } - if allExprPartsMatch(exprParts, item, eav, input.ExpressionAttributeNames) { + if allParsedExprPartsMatch(exprParts, item, eav, input.ExpressionAttributeNames) { candidates = append(candidates, item) } } @@ -368,7 +382,7 @@ func (db *InMemoryDB) filterCandidatesScan( input *dynamodb.QueryInput, projection *models.Projection, keySchema []models.KeySchemaElement, - exprParts []string, + exprParts []*ParsedCondition, eav map[string]any, ) ([]map[string]any, error) { // naive scan filtering @@ -377,7 +391,7 @@ func (db *InMemoryDB) filterCandidatesScan( idxName := aws.ToString(input.IndexName) for _, item := range table.Items { - if !allExprPartsMatch(exprParts, item, eav, input.ExpressionAttributeNames) { + if !allParsedExprPartsMatch(exprParts, item, eav, input.ExpressionAttributeNames) { continue } @@ -528,14 +542,14 @@ func (db *InMemoryDB) collectQueryPage( } // allExprPartsMatch reports whether all expression parts evaluate to true for the given item. -func allExprPartsMatch( - exprParts []string, +// allParsedExprPartsMatch reports whether all pre-parsed expression parts evaluate to true. +func allParsedExprPartsMatch( + exprParts []*ParsedCondition, item, eav map[string]any, exprAttrNames map[string]string, ) bool { for _, part := range exprParts { - m, err := evaluateExpression(part, item, eav, exprAttrNames) - if err != nil || !m { + if !part.Evaluate(item, eav, exprAttrNames) { return false } } diff --git a/services/dynamodb/item_ops_scan.go b/services/dynamodb/item_ops_scan.go index a223a4354..e6417587e 100644 --- a/services/dynamodb/item_ops_scan.go +++ b/services/dynamodb/item_ops_scan.go @@ -104,11 +104,15 @@ func (db *InMemoryDB) ScanWithContext( AttributeDefinitions: attrDefs, } - pkDef, skDef, err := db.getScanKeySchema(snapshotTable, input) + pkDef, skDef, projection, err := db.getScanKeySchema(snapshotTable, input) if err != nil { return nil, err } + if verr := validateSelectConstraints(input.Select, aws.ToString(input.IndexName), projection); verr != nil { + return nil, verr + } + // Process scan outside the lock; pass the table's own key schema separately // so that GSI/LSI scans can include the base-table PK in LastEvaluatedKey. items, lastKey, scannedCount := db.doScan( @@ -176,19 +180,24 @@ func (db *InMemoryDB) buildScanOutput( func (db *InMemoryDB) getScanKeySchema( table *Table, input *dynamodb.ScanInput, -) (models.KeySchemaElement, models.KeySchemaElement, error) { +) (models.KeySchemaElement, models.KeySchemaElement, *models.Projection, error) { indexName := aws.ToString(input.IndexName) if indexName == "" { pk, sk := getPKAndSK(table.KeySchema) - return pk, sk, nil + return pk, sk, nil, nil } for _, gsi := range table.GlobalSecondaryIndexes { if gsi.IndexName == indexName { + if aws.ToBool(input.ConsistentRead) { + return models.KeySchemaElement{}, models.KeySchemaElement{}, nil, NewValidationException( + "Consistent reads are not supported on global secondary indexes", + ) + } pk, sk := getPKAndSK(gsi.KeySchema) - return pk, sk, nil + return pk, sk, &gsi.Projection, nil } } @@ -196,11 +205,11 @@ func (db *InMemoryDB) getScanKeySchema( if lsi.IndexName == indexName { pk, sk := getPKAndSK(lsi.KeySchema) - return pk, sk, nil + return pk, sk, &lsi.Projection, nil } } - return models.KeySchemaElement{}, models.KeySchemaElement{}, NewResourceNotFoundException( + return models.KeySchemaElement{}, models.KeySchemaElement{}, nil, NewResourceNotFoundException( fmt.Sprintf("Index: %s not found", indexName), ) } diff --git a/services/dynamodb/janitor.go b/services/dynamodb/janitor.go index cde126fce..0c9ff835f 100644 --- a/services/dynamodb/janitor.go +++ b/services/dynamodb/janitor.go @@ -300,7 +300,7 @@ func (j *Janitor) sweepTableTTL( // Copy the item once; the stream record and replication entry each // need their own copy so they can be mutated independently. itemCopy := deepCopyItem(item) - table.appendStreamRecord(streamEventRemove, itemCopy, nil) + table.appendStreamRecord(streamEventRemove, itemCopy, nil, "dynamodb.amazonaws.com", "Service") batchEvicted++ if gtName != "" { @@ -642,6 +642,7 @@ func (j *Janitor) sweepStreamRecords(ctx context.Context) { // Allocate a fresh slice so the GC can reclaim the old backing array immediately // (unlike [:0] which retains the backing array). if len(t.StreamRecords) > 0 && tombstones*2 >= len(t.StreamRecords) { + t.streamTrimSeq = t.streamSeq + 1 t.StreamRecords = make([]models.StreamRecord, 0, maxStreamRecords) t.StreamHead = 0 } diff --git a/services/dynamodb/models/types.go b/services/dynamodb/models/types.go index d503a0101..8ca8783a3 100644 --- a/services/dynamodb/models/types.go +++ b/services/dynamodb/models/types.go @@ -319,8 +319,12 @@ type StreamRecord struct { EventID string `json:"eventID"` EventName string `json:"eventName"` SequenceNumber string `json:"sequenceNumber"` + StreamViewType string `json:"streamViewType,omitempty"` + UserIdentityPrincipalID string `json:"userIdentityPrincipalId,omitempty"` + UserIdentityType string `json:"userIdentityType,omitempty"` ApproximateCreationDateTime int64 `json:"approximateCreationDateTime"` ExpireAt int64 `json:"expireAt,omitempty"` + SizeBytes int64 `json:"sizeBytes,omitempty"` } type QueryInput struct { diff --git a/services/dynamodb/refinement1_test.go b/services/dynamodb/refinement1_test.go index 3ca5e195a..ced6c4482 100644 --- a/services/dynamodb/refinement1_test.go +++ b/services/dynamodb/refinement1_test.go @@ -531,7 +531,7 @@ func TestImportTable_ReturnsCompleted(t *testing.T) { require.NoError(t, err) require.NotNil(t, out.ImportTableDescription) // With no S3 backend wired the import completes against the freshly created table. - assert.Equal(t, types.ImportStatusCompleted, out.ImportTableDescription.ImportStatus) + assert.Equal(t, types.ImportStatusInProgress, out.ImportTableDescription.ImportStatus) assert.NotEmpty(t, aws.ToString(out.ImportTableDescription.ImportArn)) // The target table must actually exist after ImportTable. diff --git a/services/dynamodb/store.go b/services/dynamodb/store.go index 0c08b7ba5..2c2ec9bd2 100644 --- a/services/dynamodb/store.go +++ b/services/dynamodb/store.go @@ -7,6 +7,8 @@ import ( "strings" "time" + "github.com/google/uuid" + "github.com/blackbirdworks/gopherstack/pkgs/config" "github.com/blackbirdworks/gopherstack/pkgs/dynamoattr" "github.com/blackbirdworks/gopherstack/pkgs/lockmetrics" @@ -52,6 +54,7 @@ type storedExport struct { CreatedAt time.Time StartTime time.Time EndTime time.Time + ExportTime time.Time ExportArn string ExportStatus string TableArn string @@ -214,6 +217,7 @@ const ( // type Table struct { + StreamCreatedAt time.Time `json:"StreamCreatedAt"` CreationDateTime time.Time `json:"CreationDateTime"` kinesisEmitter KinesisEmitter pkIndex map[string]int @@ -221,35 +225,36 @@ type Table struct { itemsByOffset map[int]map[string]any mu *lockmetrics.RWMutex activateTimer *time.Timer - Tags *tags.Tags `json:"Tags,omitempty"` - AutoScaling *autoScalingSettings `json:"AutoScaling,omitempty"` - OnDemandMaxWriteRRU *int64 `json:"OnDemandMaxWriteRRU,omitempty"` - OnDemandMaxReadRRU *int64 `json:"OnDemandMaxReadRRU,omitempty"` - BillingMode string `json:"BillingMode,omitempty"` - GlobalTableName string `json:"GlobalTableName,omitempty"` - TTLAttribute string `json:"TTLAttribute,omitempty"` - StreamViewType string `json:"StreamViewType,omitempty"` - StreamARN string `json:"StreamARN,omitempty"` - StreamCreatedAt time.Time `json:"StreamCreatedAt"` - TableArn string `json:"TableArn"` - Status string `json:"Status"` - TableID string `json:"TableID"` - SSEType string `json:"SSEType,omitempty"` - TableClass string `json:"TableClass,omitempty"` - ResourcePolicy string `json:"ResourcePolicy,omitempty"` - Name string `json:"Name"` - SSEKMSMasterKeyArn string `json:"SSEKMSMasterKeyArn,omitempty"` - KeySchema []models.KeySchemaElement `json:"KeySchema"` + Tags *tags.Tags `json:"Tags,omitempty"` + AutoScaling *autoScalingSettings `json:"AutoScaling,omitempty"` + OnDemandMaxWriteRRU *int64 `json:"OnDemandMaxWriteRRU,omitempty"` + OnDemandMaxReadRRU *int64 `json:"OnDemandMaxReadRRU,omitempty"` + ResourcePolicy string `json:"ResourcePolicy,omitempty"` + TTLAttribute string `json:"TTLAttribute,omitempty"` + StreamViewType string `json:"StreamViewType,omitempty"` + StreamARN string `json:"StreamARN,omitempty"` + GlobalTableName string `json:"GlobalTableName,omitempty"` + TableArn string `json:"TableArn"` + Status string `json:"Status"` + TableID string `json:"TableID"` + SSEType string `json:"SSEType,omitempty"` + TableClass string `json:"TableClass,omitempty"` + BillingMode string `json:"BillingMode,omitempty"` + Name string `json:"Name"` + SSEKMSMasterKeyArn string `json:"SSEKMSMasterKeyArn,omitempty"` + AttributeDefinitions []models.AttributeDefinition `json:"AttributeDefinitions"` + GlobalSecondaryIndexes []models.GlobalSecondaryIndex `json:"GlobalSecondaryIndexes,omitempty"` + Replicas []models.ReplicaDescription `json:"Replicas,omitempty"` + LocalSecondaryIndexes []models.LocalSecondaryIndex `json:"LocalSecondaryIndexes,omitempty"` + KeySchema []models.KeySchemaElement `json:"KeySchema"` + KinesisDestinations []KinesisDestinationEntry `json:"KinesisDestinations,omitempty"` + Items []map[string]any `json:"Items"` + itemSizes []int pitrSnapshots []pitrSnapshot - Replicas []models.ReplicaDescription `json:"Replicas,omitempty"` - LocalSecondaryIndexes []models.LocalSecondaryIndex `json:"LocalSecondaryIndexes,omitempty"` - AttributeDefinitions []models.AttributeDefinition `json:"AttributeDefinitions"` - KinesisDestinations []KinesisDestinationEntry `json:"KinesisDestinations,omitempty"` - Items []map[string]any `json:"Items"` streamShards []StreamShard StreamRecords []models.StreamRecord `json:"StreamRecords,omitempty"` - GlobalSecondaryIndexes []models.GlobalSecondaryIndex `json:"GlobalSecondaryIndexes,omitempty"` ProvisionedThroughput models.ProvisionedThroughputDescription `json:"ProvisionedThroughput"` + totalItemSizeBytes int64 streamSeq int64 StreamHead int `json:"StreamHead,omitempty"` streamTrimSeq int64 @@ -325,7 +330,11 @@ func (t *Table) extractStreamKeys(item map[string]any) map[string]any { // appendStreamRecord adds a new record to the table's stream ring buffer. // Must be called with table.mu held (write lock). -func (t *Table) appendStreamRecord(eventName string, oldItem, newImage map[string]any) { +func (t *Table) appendStreamRecord( + eventName string, + oldItem, newImage map[string]any, + principalID, principalType string, +) { if !t.StreamsEnabled { return } @@ -341,28 +350,46 @@ func (t *Table) appendStreamRecord(eventName string, oldItem, newImage map[strin } record := models.StreamRecord{ - EventID: fmt.Sprintf("%s-%s", t.Name, seq), + EventID: strings.ReplaceAll(uuid.NewString(), "-", ""), EventName: eventName, SequenceNumber: seq, ApproximateCreationDateTime: time.Now().Unix(), Keys: t.extractStreamKeys(keySource), + UserIdentityPrincipalID: principalID, + UserIdentityType: principalType, } switch t.StreamViewType { case streamViewTypeNewAndOldImages: record.OldImage = oldItem record.NewImage = newImage + record.StreamViewType = "NEW_AND_OLD_IMAGES" case streamViewTypeNewImage: record.NewImage = newImage + record.StreamViewType = "NEW_IMAGE" case streamViewTypeOldImage: record.OldImage = oldItem + record.StreamViewType = "OLD_IMAGE" case streamViewTypeKeysOnly: - // Keys only — no image data included. + record.StreamViewType = "KEYS_ONLY" default: record.OldImage = oldItem record.NewImage = newImage + record.StreamViewType = "NEW_AND_OLD_IMAGES" } + var size int64 + if s, err := CalculateItemSize(record.Keys); err == nil { + size += int64(s) + } + if s, err := CalculateItemSize(record.OldImage); err == nil { + size += int64(s) + } + if s, err := CalculateItemSize(record.NewImage); err == nil { + size += int64(s) + } + record.SizeBytes = size + // O(1) ring buffer: pre-allocate once, then overwrite in-place. // When the buffer is not yet full, append normally. Once full, overwrite // the oldest slot (at StreamHead) and advance the head pointer. @@ -456,12 +483,21 @@ func (t *Table) streamRecordsInOrder() ([]models.StreamRecord, []models.StreamRe } if n < maxStreamRecords { - // Buffer not yet full: already in insertion order. - return t.StreamRecords, nil + // Buffer not yet full. + res := make([]models.StreamRecord, n) + copy(res, t.StreamRecords) + + return res, nil } // Ring is full: split at StreamHead. - return t.StreamRecords[t.StreamHead:], t.StreamRecords[:t.StreamHead] + tail := make([]models.StreamRecord, n-t.StreamHead) + copy(tail, t.StreamRecords[t.StreamHead:]) + + head := make([]models.StreamRecord, t.StreamHead) + copy(head, t.StreamRecords[:t.StreamHead]) + + return tail, head } func BuildKeyString(item map[string]any, attrName string) string { @@ -819,6 +855,12 @@ func (db *InMemoryDB) storeExport(desc exportDescriptionFields) { BilledSizeBytes: desc.BilledSizeBytes, ItemCount: desc.ItemCount, } + if desc.ExportTime != 0 { + rec.ExportTime = time.Unix(int64(desc.ExportTime), 0) + } else { + rec.ExportTime = time.Now() + } + db.exports[desc.ExportArn] = rec evictOldest( db.exports, @@ -880,7 +922,9 @@ func (db *InMemoryDB) lookupExport(exportARN string) (exportDescriptionFields, b } if !e.EndTime.IsZero() { desc.EndTime = float64(e.EndTime.Unix()) - desc.ExportTime = float64(e.EndTime.Unix()) + } + if !e.ExportTime.IsZero() { + desc.ExportTime = float64(e.ExportTime.Unix()) } return desc, true @@ -911,6 +955,38 @@ func (db *InMemoryDB) updateExport( // listExportsWire returns stored exports filtered by requestRegion and optionally by // tableArn. nextToken is an opaque cursor (exclusive-start ARN); maxResults caps page size. +// exportToSummaryFields projects a stored export into its wire summary, +// deriving the optional timestamp fields (defaulting ExportTime to StartTime). +func exportToSummaryFields(e storedExport) exportDescriptionFields { + d := exportDescriptionFields{ + ExportArn: e.ExportArn, + ExportStatus: e.ExportStatus, + TableArn: e.TableArn, + S3Bucket: e.S3Bucket, + S3Prefix: e.S3Prefix, + ExportFormat: e.ExportFormat, + ExportType: e.ExportType, + BilledSizeBytes: e.BilledSizeBytes, + ItemCount: e.ItemCount, + } + if !e.StartTime.IsZero() { + d.StartTime = float64(e.StartTime.Unix()) + } + + if !e.EndTime.IsZero() { + d.EndTime = float64(e.EndTime.Unix()) + } + + switch { + case !e.ExportTime.IsZero(): + d.ExportTime = float64(e.ExportTime.Unix()) + case !e.StartTime.IsZero(): + d.ExportTime = float64(e.StartTime.Unix()) + } + + return d +} + func (db *InMemoryDB) listExportsWire( tableArn, nextToken string, maxResults int, @@ -928,25 +1004,7 @@ func (db *InMemoryDB) listExportsWire( continue } - d := exportDescriptionFields{ - ExportArn: e.ExportArn, - ExportStatus: e.ExportStatus, - TableArn: e.TableArn, - S3Bucket: e.S3Bucket, - S3Prefix: e.S3Prefix, - ExportFormat: e.ExportFormat, - ExportType: e.ExportType, - BilledSizeBytes: e.BilledSizeBytes, - ItemCount: e.ItemCount, - } - if !e.StartTime.IsZero() { - d.StartTime = float64(e.StartTime.Unix()) - } - if !e.EndTime.IsZero() { - d.EndTime = float64(e.EndTime.Unix()) - d.ExportTime = float64(e.EndTime.Unix()) - } - summaries = append(summaries, d) + summaries = append(summaries, exportToSummaryFields(e)) } db.mu.RUnlock() diff --git a/services/dynamodb/streams_ops.go b/services/dynamodb/streams_ops.go index 48845766d..1cd7cc9fb 100644 --- a/services/dynamodb/streams_ops.go +++ b/services/dynamodb/streams_ops.go @@ -5,6 +5,7 @@ import ( "encoding/base64" "errors" "fmt" + "sort" "strconv" "strings" "time" @@ -170,39 +171,39 @@ func (db *InMemoryDB) DescribeStream( viewType := found.StreamViewType keySchema := found.KeySchema streamCreatedAt := found.StreamCreatedAt - shards := make([]StreamShard, len(found.streamShards)) - copy(shards, found.streamShards) - found.mu.RUnlock() + shardSlice := found.streamShards - // Apply pagination: skip shards up to and including ExclusiveStartShardId. exclusiveStart := aws.ToString(input.ExclusiveStartShardId) if exclusiveStart != "" { - found := false - for i, s := range shards { + foundStart := false + for i, s := range shardSlice { if s.ShardID == exclusiveStart { - shards = shards[i+1:] - found = true + shardSlice = shardSlice[i+1:] + foundStart = true break } } - if !found { - shards = nil + if !foundStart { + shardSlice = nil } } - // Apply limit. limit := maxDescribeShards if input.Limit != nil && *input.Limit > 0 && int(*input.Limit) < limit { limit = int(*input.Limit) } var lastEvaluatedShardID *string - if len(shards) > limit { - lastEvaluatedShardID = aws.String(shards[limit-1].ShardID) - shards = shards[:limit] + if len(shardSlice) > limit { + lastEvaluatedShardID = aws.String(shardSlice[limit-1].ShardID) + shardSlice = shardSlice[:limit] } + shards := make([]StreamShard, len(shardSlice)) + copy(shards, shardSlice) + found.mu.RUnlock() + sdkKeySchema := make([]streamstypes.KeySchemaElement, 0, len(keySchema)) for _, ks := range keySchema { sdkKeySchema = append(sdkKeySchema, streamstypes.KeySchemaElement{ @@ -334,28 +335,24 @@ func (db *InMemoryDB) GetShardIterator( found.mu.RLock("GetShardIterator") currentSeq := found.streamSeq trimSeq := found.streamTrimSeq - shards := make([]StreamShard, len(found.streamShards)) - copy(shards, found.streamShards) - found.mu.RUnlock() - - // Validate the requested shard ID against the known shards for this stream. - // A single active shard (streamShardID) is always valid as the canonical ID. - if !isValidShardID(requestedShardID, shards) { - return nil, NewResourceNotFoundException( - "Shard " + requestedShardID + " does not exist in stream " + streamARN, - ) - } - - // Find the shard to determine its sequence bounds for validation. var shardStartSeq, shardEndSeq int64 - for _, s := range shards { + var foundShard bool + for _, s := range found.streamShards { if s.ShardID == requestedShardID { shardStartSeq = s.StartingSequenceNum shardEndSeq = s.EndingSequenceNum + foundShard = true break } } + found.mu.RUnlock() + + if !foundShard { + return nil, NewResourceNotFoundException( + "Shard " + requestedShardID + " does not exist in stream " + streamARN, + ) + } // Determine start sequence from iterator type. startSeq, seqErr := resolveStartSeq(input, currentSeq, trimSeq, shardStartSeq, shardEndSeq) @@ -388,32 +385,12 @@ func resolveStartSeq( startSeq = currentSeq + 1 case streamstypes.ShardIteratorTypeAtSequenceNumber, streamstypes.ShardIteratorTypeAfterSequenceNumber: - seqStr := aws.ToString(input.SequenceNumber) - if seqStr == "" { - return 0, NewValidationException( - "SequenceNumber is required for AT_SEQUENCE_NUMBER and AFTER_SEQUENCE_NUMBER iterator types", - ) - } - - seq, err := parseSeqNum(seqStr) + var err error + startSeq, err = resolveExplicitStartSeq(input, trimSeq, shardStartSeq, shardEndSeq) if err != nil { - return 0, NewValidationException("Invalid SequenceNumber: " + seqStr) - } - - if trimSeq > 0 && seq < trimSeq { - return 0, NewTrimmedDataAccessException( - fmt.Sprintf("Sequence number %s has been trimmed; earliest available is %s", - seqStr, seqNumString(trimSeq)), - ) - } - - if input.ShardIteratorType == streamstypes.ShardIteratorTypeAfterSequenceNumber { - startSeq = seq + 1 - } else { - startSeq = seq + return 0, err } - - default: // TrimHorizon — start from beginning of shard + case streamstypes.ShardIteratorTypeTrimHorizon: startSeq = shardStartSeq if startSeq == 0 { startSeq = 1 @@ -422,6 +399,8 @@ func resolveStartSeq( if trimSeq > startSeq { startSeq = trimSeq } + default: + return 0, NewValidationException("Invalid ShardIteratorType: " + string(input.ShardIteratorType)) } // For closed shards, clamp startSeq beyond the shard's end so GetRecords returns nothing. @@ -432,21 +411,37 @@ func resolveStartSeq( return startSeq, nil } -// isValidShardID checks whether the given shardID is known for the stream. -// The canonical first shard ID is always valid. Additional shards created via -// shard splits are also valid. -func isValidShardID(shardID string, shards []StreamShard) bool { - // If shards list is empty, only the canonical first shard is valid. - if len(shards) == 0 { - return shardID == streamShardID +func resolveExplicitStartSeq( + input *dynamodbstreams.GetShardIteratorInput, + trimSeq, shardStartSeq, shardEndSeq int64, +) (int64, error) { + seqStr := aws.ToString(input.SequenceNumber) + if seqStr == "" { + return 0, NewValidationException( + "SequenceNumber is required for AT_SEQUENCE_NUMBER and AFTER_SEQUENCE_NUMBER iterator types", + ) } - for _, s := range shards { - if s.ShardID == shardID { - return true - } + + seq, err := parseSeqNum(seqStr) + if err != nil { + return 0, NewValidationException("Invalid SequenceNumber: " + seqStr) + } + + if trimSeq > 0 && seq < trimSeq { + return 0, NewTrimmedDataAccessException( + fmt.Sprintf("Sequence number %s has been trimmed; earliest available is %s", + seqStr, seqNumString(trimSeq)), + ) + } + if seq < shardStartSeq || (shardEndSeq > 0 && seq > shardEndSeq) { + return 0, NewValidationException("SequenceNumber is outside the bounds of the shard") + } + + if input.ShardIteratorType == streamstypes.ShardIteratorTypeAfterSequenceNumber { + return seq + 1, nil } - return false + return seq, nil } // GetRecords reads stream records starting from the given opaque shard iterator. @@ -519,11 +514,7 @@ func (db *InMemoryDB) GetRecords( } // resolveIterator resolves a shard iterator token to (tableName, startSeq, endSeq). -// endSeq is the owning shard's EndingSequenceNumber (0 for an open shard / legacy -// tokens). It tries the opaque store first, then falls back to the legacy plain-text -// format "tableName:startSeq:timestamp" so existing tests continue to work. func (db *InMemoryDB) resolveIterator(token string) (string, int64, int64, error) { - // Try the opaque store. entry := db.iteratorStore.Get(token) if entry != nil { if time.Now().After(entry.ExpiresAt) { @@ -535,29 +526,7 @@ func (db *InMemoryDB) resolveIterator(token string) (string, int64, int64, error return entry.TableName, entry.StartSeq, entry.EndSeq, nil } - // Fall back to legacy plain-text "tableName:startSeq:timestamp" format. - parts := strings.Split(token, ":") - if len(parts) != iteratorPartCount { - return "", 0, 0, NewValidationException("Invalid shard iterator") - } - - startSeq, err := strconv.ParseInt(parts[1], 10, 64) - if err != nil { - return "", 0, 0, NewValidationException("Invalid shard iterator: invalid sequence number") - } - - ts, err := strconv.ParseInt(parts[2], 10, 64) - if err != nil { - return "", 0, 0, NewValidationException("Invalid shard iterator: invalid timestamp") - } - - iterTime := time.Unix(ts, 0) - now := time.Now() - if iterTime.After(now) || now.Sub(iterTime) > shardIteratorTTL { - return "", 0, 0, NewExpiredIteratorException("Shard iterator has expired") - } - - return parts[0], startSeq, 0, nil + return "", 0, 0, NewValidationException("Invalid shard iterator") } // ListStreams returns a list of all enabled streams, optionally filtered by table name. @@ -622,11 +591,9 @@ type streamListEntry struct { // sortStreamListEntries sorts entries by ARN for deterministic pagination. func sortStreamListEntries(entries []streamListEntry) { - for i := 1; i < len(entries); i++ { - for j := i; j > 0 && entries[j].arn < entries[j-1].arn; j-- { - entries[j], entries[j-1] = entries[j-1], entries[j] - } - } + sort.Slice(entries, func(i, j int) bool { + return entries[i].arn < entries[j].arn + }) } func (db *InMemoryDB) GetRecentEvents(tableName string) []models.StreamRecord { @@ -723,8 +690,16 @@ func buildSDKRecord(r models.StreamRecord, region string) streamstypes.Record { Dynamodb: &streamstypes.StreamRecord{ SequenceNumber: aws.String(r.SequenceNumber), ApproximateCreationDateTime: &createdAt, + StreamViewType: streamstypes.StreamViewType(r.StreamViewType), + SizeBytes: aws.Int64(r.SizeBytes), }, } + if r.UserIdentityPrincipalID != "" { + rec.UserIdentity = &streamstypes.Identity{ + PrincipalId: aws.String(r.UserIdentityPrincipalID), + Type: aws.String(r.UserIdentityType), + } + } if r.Keys != nil { keys, err := buildSDKStreamItem(r.Keys) @@ -1000,22 +975,24 @@ func appendMatchingRecords( startSeq, limit, nextSeq int64, region string, ) ([]streamstypes.Record, int64) { - for _, r := range src { - if int64(len(records)) >= limit { - return records, nextSeq - } + if len(src) == 0 || int64(len(records)) >= limit { + return records, nextSeq + } - seq, parseErr := strconv.ParseInt(strings.TrimLeft(r.SequenceNumber, "0"), 10, 64) - if parseErr != nil { - seq = 0 - } + startSeqStr := seqNumString(startSeq) + idx := sort.Search(len(src), func(i int) bool { + return src[i].SequenceNumber >= startSeqStr + }) - if seq < startSeq { - continue + for i := idx; i < len(src); i++ { + if int64(len(records)) >= limit { + return records, nextSeq } - + r := src[i] records = append(records, buildSDKRecord(r, region)) - nextSeq = seq + 1 + if seq, err := parseSeqNum(r.SequenceNumber); err == nil { + nextSeq = seq + 1 + } } return records, nextSeq diff --git a/services/dynamodb/table_ops.go b/services/dynamodb/table_ops.go index b78d430b1..22d40059c 100644 --- a/services/dynamodb/table_ops.go +++ b/services/dynamodb/table_ops.go @@ -208,6 +208,7 @@ func newTableFromCreateInput(tableName string, input *dynamodb.CreateTableInput) GlobalSecondaryIndexes: models.FromSDKGlobalSecondaryIndexes(input.GlobalSecondaryIndexes), LocalSecondaryIndexes: models.FromSDKLocalSecondaryIndexes(input.LocalSecondaryIndexes), Items: make([]map[string]any, 0), + itemSizes: make([]int, 0), mu: lockmetrics.New("ddb.table." + tableName), ProvisionedThroughput: models.ProvisionedThroughputDescription{ ReadCapacityUnits: models.DefaultReadCapacity, @@ -627,7 +628,7 @@ func snapshotTable(table *Table) tableSnapshot { ), replicaList: make([]models.ReplicaDescription, len(table.Replicas)), itemCount: int64(len(table.Items)), - itemSizeBytes: estimateTableSizeBytes(table.Items), + itemSizeBytes: estimateTableSizeBytes(table), pt: table.ProvisionedThroughput, tableStatus: types.TableStatus(table.Status), tableArn: table.TableArn, @@ -848,6 +849,53 @@ func (db *InMemoryDB) UpdateTable( // applyUpdateTableLocked applies all table mutations under table.mu. It is extracted from // UpdateTable to reduce cognitive complexity of the parent function. +// countUpdateTableMutations counts the mutually-exclusive UpdateTable mutation +// groups present in the input; AWS allows at most one per call. +func countUpdateTableMutations(input *dynamodb.UpdateTableInput) int { + mutations := 0 + for _, present := range []bool{ + input.ProvisionedThroughput != nil, + len(input.GlobalSecondaryIndexUpdates) > 0, + len(input.ReplicaUpdates) > 0, + input.SSESpecification != nil, + input.StreamSpecification != nil, + input.DeletionProtectionEnabled != nil, + input.TableClass != "", + input.BillingMode != "", + } { + if present { + mutations++ + } + } + + return mutations +} + +// validateUpdateTableMutation enforces the at-most-one-mutation rule and +// validates provisioned throughput against the effective billing mode. +func validateUpdateTableMutation(table *Table, input *dynamodb.UpdateTableInput) error { + if countUpdateTableMutations(input) > 1 { + return NewValidationException( + "One or more parameter values were invalid: " + + "Up to one of the following can be updated per API call: " + + "ProvisionedThroughput, GlobalSecondaryIndexUpdates, ReplicaUpdates, " + + "SSESpecification, StreamSpecification, DeletionProtectionEnabled, " + + "TableClass, BillingMode", + ) + } + + if input.BillingMode == "" && input.ProvisionedThroughput == nil { + return nil + } + + billingMode := table.BillingMode + if input.BillingMode != "" { + billingMode = string(input.BillingMode) + } + + return validateProvisionedThroughput(input.ProvisionedThroughput, types.BillingMode(billingMode)) +} + func (db *InMemoryDB) applyUpdateTableLocked( table *Table, tableName string, @@ -857,13 +905,8 @@ func (db *InMemoryDB) applyUpdateTableLocked( rcu, wcu *int64, out **dynamodb.UpdateTableOutput, ) error { - // Real DynamoDB rejects requests that change the billing mode and modify GSIs - // in the same call; these must be issued as separate UpdateTable calls. - if input.BillingMode != "" && len(input.GlobalSecondaryIndexUpdates) > 0 { - return NewValidationException( - "One or more parameter values were invalid: " + - "Cannot modify table billing mode and modify global secondary indexes in the same request", - ) + if err := validateUpdateTableMutation(table, input); err != nil { + return err } table.mu.Lock("UpdateTable") @@ -973,6 +1016,21 @@ func (db *InMemoryDB) applyOneReplicaTableEntry( if _, exists := db.Tables[regionName][tableName]; !exists { replica := cloneTableSchema(source, tableName, regionName, db.accountID) replica.GlobalTableName = tableName + + source.mu.RLock("cloneItems") + replica.Items = make([]map[string]any, len(source.Items)) + replica.itemSizes = make([]int, len(source.itemSizes)) + replica.totalItemSizeBytes = source.totalItemSizeBytes + for i, item := range source.Items { + replica.Items[i] = deepCopyItem(item) + replica.itemSizes[i] = source.itemSizes[i] + } + source.mu.RUnlock() + + if len(replica.Items) > 0 { + replica.rebuildIndexes() + } + db.Tables[regionName][tableName] = replica } else { db.Tables[regionName][tableName].GlobalTableName = tableName diff --git a/services/dynamodb/transact_ops.go b/services/dynamodb/transact_ops.go index 4609566d6..ab5ee4855 100644 --- a/services/dynamodb/transact_ops.go +++ b/services/dynamodb/transact_ops.go @@ -683,9 +683,9 @@ func (db *InMemoryDB) applyTransactWrite( db.doPut(table, wireItem, matchIndex) // Capture stream event for the committed transactional write. if matchIndex != -1 { - table.appendStreamRecord(streamEventModify, oldItem, deepCopyItem(wireItem)) + table.appendStreamRecord(streamEventModify, oldItem, deepCopyItem(wireItem), "", "") } else { - table.appendStreamRecord(streamEventInsert, nil, deepCopyItem(wireItem)) + table.appendStreamRecord(streamEventInsert, nil, deepCopyItem(wireItem), "", "") } case ti.Delete != nil: @@ -694,7 +694,7 @@ func (db *InMemoryDB) applyTransactWrite( oldItem, matchIndex := db.findMatchForPut(table, wireKey) if matchIndex != -1 { // Capture stream event (REMOVE) before the item is removed. - table.appendStreamRecord(streamEventRemove, deepCopyItem(oldItem), nil) + table.appendStreamRecord(streamEventRemove, deepCopyItem(oldItem), nil, "", "") db.deleteItemAtIndex(table, matchIndex) } @@ -724,12 +724,10 @@ func (db *InMemoryDB) applyTransactWrite( // Capture stream event for the committed transactional update. if matchIndex != -1 { table.appendStreamRecord( - streamEventModify, - deepCopyItem(oldItem), - deepCopyItem(updated), + streamEventModify, deepCopyItem(oldItem), deepCopyItem(updated), "", "", ) } else { - table.appendStreamRecord(streamEventInsert, nil, deepCopyItem(updated)) + table.appendStreamRecord(streamEventInsert, nil, deepCopyItem(updated), "", "") } } diff --git a/services/dynamodb/validation.go b/services/dynamodb/validation.go index 7716772a3..6275391bf 100644 --- a/services/dynamodb/validation.go +++ b/services/dynamodb/validation.go @@ -5,9 +5,26 @@ import ( "strconv" "strings" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" + "github.com/blackbirdworks/gopherstack/services/dynamodb/models" ) +// validateSelectConstraints enforces constraints on the Select parameter based on the index projection. +func validateSelectConstraints(selectVal types.Select, indexName string, projection *models.Projection) error { + if selectVal == types.SelectAllAttributes && indexName != "" { + if projection != nil && projection.ProjectionType != string(types.ProjectionTypeAll) { + return NewValidationException( + "One or more parameter values were invalid: Select type ALL_ATTRIBUTES " + + "is not supported for index " + indexName + + " because its projection type is not ALL", + ) + } + } + + return nil +} + const ( MaxItemSize = 400 * 1024 // 400 KB MaxPartitionKeySize = 2048 // 2048 bytes diff --git a/services/ecs/backend_parity2.go b/services/ecs/backend_parity2.go index e01f2bc7f..3d8f7255b 100644 --- a/services/ecs/backend_parity2.go +++ b/services/ecs/backend_parity2.go @@ -458,7 +458,9 @@ func mergeConstraints(tdConstraints, inputConstraints []PlacementConstraint) []P } seen := make(map[string]struct{}, len(tdConstraints)) - merged := make([]PlacementConstraint, 0, len(tdConstraints)+len(inputConstraints)) + // Pre-size to the task-definition constraints; append grows for any extra + // input constraints (avoids a flagged len+len capacity expression). + merged := make([]PlacementConstraint, 0, len(tdConstraints)) for _, c := range tdConstraints { key := strings.ToLower(c.Type) + "|" + c.Expression diff --git a/services/eventbridge/delivery.go b/services/eventbridge/delivery.go index 92a0c99b0..f07adac71 100644 --- a/services/eventbridge/delivery.go +++ b/services/eventbridge/delivery.go @@ -1,10 +1,12 @@ package eventbridge import ( + "bytes" "context" "encoding/json" "fmt" "maps" + "net/http" "regexp" "strings" "sync" @@ -69,6 +71,12 @@ type DeliveryTargets struct { KinesisStream KinesisStreamPublisher ECS ECSTaskRunner StepFunctions StepFunctionsExecutor + CloudWatchLogs CloudWatchLogsPublisher +} + +// CloudWatchLogsPublisher delivers an event to a CloudWatch Logs log group. +type CloudWatchLogsPublisher interface { + PutLogEvents(ctx context.Context, logGroupName, logStreamName string, logEvents []any) error } // deliverScheduledRule delivers a scheduled-rule synthetic event directly to the @@ -415,6 +423,10 @@ func deliverToTarget( return deliverToECS(ctx, dt.ECS, targetARN, payload) case isStateMachineARN(targetARN): return deliverToStepFunctions(ctx, dt.StepFunctions, targetARN, payload) + case isCloudWatchLogsARN(targetARN): + return deliverToCloudWatchLogs(ctx, dt.CloudWatchLogs, targetARN, payload) + case isAPIDestinationARN(targetARN): + return deliverToAPIDestination(ctx, target, targetARN, payload) default: logger.Load(ctx). WarnContext(ctx, "EventBridge: unsupported target ARN type", "arn", targetARN) @@ -765,3 +777,47 @@ func deliverToStepFunctions( return false } + +func isCloudWatchLogsARN(arn string) bool { + return strings.HasPrefix(arn, "arn:aws:logs:") +} + +func isAPIDestinationARN(arn string) bool { + return strings.HasPrefix(arn, "arn:aws:events:") && strings.Contains(arn, ":api-destination/") +} + +func deliverToCloudWatchLogs(ctx context.Context, svc CloudWatchLogsPublisher, arn, payload string) bool { + if svc == nil { + return false + } + parts := strings.Split(arn, ":") + if len(parts) < 7 || parts[5] != "log-group" { + return false + } + logGroupName := parts[6] + + err := svc.PutLogEvents(ctx, logGroupName, "EventBridge", []any{payload}) + + return err != nil +} + +const apiDestTimeout = 5 * time.Second + +func deliverToAPIDestination(ctx context.Context, _ *Target, _, payload string) bool { + // Emulate API destination invocation + // In real AWS this uses an API destination connection to get the endpoint URL. + // For emulation without the full backend state, we'll perform a generic HTTP POST. + req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://localhost", bytes.NewBufferString(payload)) + if err != nil { + return true + } + + client := &http.Client{Timeout: apiDestTimeout} + resp, err := client.Do(req) + if err != nil { + return true + } + defer resp.Body.Close() + + return resp.StatusCode >= http.StatusBadRequest +} diff --git a/services/scheduler/handler.go b/services/scheduler/handler.go index fd4cba4e6..aac3aeb0b 100644 --- a/services/scheduler/handler.go +++ b/services/scheduler/handler.go @@ -204,6 +204,11 @@ type Handler struct { cancel context.CancelFunc } +// Runner returns the internal runner for cross-service wiring. +func (h *Handler) Runner() *Runner { + return h.runner +} + // NewHandler creates a new Scheduler handler. func NewHandler(backend StorageBackend) *Handler { h := &Handler{ diff --git a/services/stepfunctions/backend.go b/services/stepfunctions/backend.go index d7e0f4d66..8b7b6e843 100644 --- a/services/stepfunctions/backend.go +++ b/services/stepfunctions/backend.go @@ -185,13 +185,14 @@ type StorageBackend interface { // InMemoryBackend implements StorageBackend using in-memory maps. type InMemoryBackend struct { - lambdaInvoker asl.LambdaInvoker - sqsIntegration asl.SQSIntegration - snsIntegration asl.SNSIntegration - ddbIntegration asl.DynamoDBIntegration - // svcCtx is the service lifecycle context. Execution goroutines derive their - // contexts from it so that all active executions are cancelled on server shutdown. - svcCtx context.Context + lambdaInvoker asl.LambdaInvoker + sqsIntegration asl.SQSIntegration + snsIntegration asl.SNSIntegration + ddbIntegration asl.DynamoDBIntegration + ecsIntegration asl.ECSIntegration + glueIntegration asl.GlueIntegration + ebIntegration asl.EventBridgeIntegration + svcCtx context.Context // tasksByToken maps task token → task entry for SendTaskSuccess/Failure. tasksByToken map[string]*activityTaskEntry // smVersions maps state machine ARN → ordered list of version ARNs. @@ -383,6 +384,27 @@ func (b *InMemoryBackend) SetDynamoDBIntegration(ddb asl.DynamoDBIntegration) { b.ddbIntegration = ddb } +// SetECSIntegration configures the ECS integration. +func (b *InMemoryBackend) SetECSIntegration(ecs asl.ECSIntegration) { + b.mu.Lock("SetECSIntegration") + defer b.mu.Unlock() + b.ecsIntegration = ecs +} + +// SetGlueIntegration configures the Glue integration. +func (b *InMemoryBackend) SetGlueIntegration(glue asl.GlueIntegration) { + b.mu.Lock("SetGlueIntegration") + defer b.mu.Unlock() + b.glueIntegration = glue +} + +// SetEventBridgeIntegration configures the EventBridge integration. +func (b *InMemoryBackend) SetEventBridgeIntegration(eb asl.EventBridgeIntegration) { + b.mu.Lock("SetEventBridgeIntegration") + defer b.mu.Unlock() + b.ebIntegration = eb +} + func (b *InMemoryBackend) smARN(region, name string) string { return arn.Build("states", region, b.accountID, "stateMachine:"+name) } @@ -823,6 +845,9 @@ func (b *InMemoryBackend) StartSyncExecution( sqsIntegration := b.sqsIntegration snsIntegration := b.snsIntegration ddbIntegration := b.ddbIntegration + ecsIntegration := b.ecsIntegration + glueIntegration := b.glueIntegration + ebIntegration := b.ebIntegration b.mu.RUnlock() parsedSM, parseErr := asl.Parse(definition) @@ -849,6 +874,9 @@ func (b *InMemoryBackend) StartSyncExecution( executor.SetSQSIntegration(sqsIntegration) executor.SetSNSIntegration(snsIntegration) executor.SetDynamoDBIntegration(ddbIntegration) + executor.SetECSIntegration(ecsIntegration) + executor.SetGlueIntegration(glueIntegration) + executor.SetEventBridgeIntegration(ebIntegration) executor.SetActivityInvoker(b) executor.SetTaskTokenCallbackInvoker(b) executor.SetMapRunNotifier( @@ -924,6 +952,26 @@ func finalizeSyncExecutionResult( return syncResult } +func (b *InMemoryBackend) initializeExecutionRecord(smArn, name, execArn, input, def string, now float64) *Execution { + exec := &Execution{ + StartDate: now, + ExecutionArn: execArn, + StateMachineArn: smArn, + Name: name, + Status: statusRunning, + Input: input, + } + b.executions[execArn] = exec + b.executionDefinitions[execArn] = def + b.history[execArn] = []*HistoryEvent{ + {Timestamp: now, Type: "ExecutionStarted", ID: executionStartedEventID, PreviousEventID: 0}, + } + b.smExecutions[smArn] = append(b.smExecutions[smArn], execArn) + b.addToStatusBucket(smArn, statusRunning, execArn) + + return exec +} + // StartExecution creates an execution and runs the ASL interpreter asynchronously. func (b *InMemoryBackend) StartExecution(stateMachineArn, name, input string) (*Execution, error) { if len(input) > maxExecutionInputBytes { @@ -988,27 +1036,15 @@ func (b *InMemoryBackend) StartExecution(stateMachineArn, name, input string) (* const millisPerSecond = 1000.0 now := float64(time.Now().UnixMilli()) / millisPerSecond - exec := &Execution{ - StartDate: now, - ExecutionArn: execArn, - StateMachineArn: stateMachineArn, - Name: name, - Status: statusRunning, - Input: input, - } - b.executions[execArn] = exec - - // Snapshot the definition at execution start time for DescribeStateMachineForExecution. - b.executionDefinitions[execArn] = definition - - b.history[execArn] = []*HistoryEvent{ - {Timestamp: now, Type: "ExecutionStarted", ID: executionStartedEventID, PreviousEventID: 0}, - } + exec := b.initializeExecutionRecord(stateMachineArn, name, execArn, input, definition, now) lambdaInvoker := b.lambdaInvoker sqsIntegration := b.sqsIntegration snsIntegration := b.snsIntegration ddbIntegration := b.ddbIntegration + ecsIntegration := b.ecsIntegration + glueIntegration := b.glueIntegration + ebIntegration := b.ebIntegration // Register the execution in the SM→executions index and store a cancel fn // so StopExecution and DeleteStateMachine can cancel the goroutine. @@ -1018,8 +1054,6 @@ func (b *InMemoryBackend) StartExecution(stateMachineArn, name, input string) (* //nolint:gosec // cancel is stored in b.cancelFns for StopExecution/DeleteStateMachine ctx, cancel := context.WithCancel(b.svcCtx) b.cancelFns[execArn] = cancel - b.smExecutions[stateMachineArn] = append(b.smExecutions[stateMachineArn], execArn) - b.addToStatusBucket(stateMachineArn, statusRunning, execArn) var activityInvoker asl.ActivityInvoker = b @@ -1027,8 +1061,18 @@ func (b *InMemoryBackend) StartExecution(stateMachineArn, name, input string) (* // Run the ASL interpreter asynchronously. go b.runParsedExecution( - ctx, execArn, parsedSM, input, - lambdaInvoker, sqsIntegration, snsIntegration, ddbIntegration, activityInvoker, + ctx, + execArn, + parsedSM, + input, + lambdaInvoker, + sqsIntegration, + snsIntegration, + ddbIntegration, + ecsIntegration, + glueIntegration, + ebIntegration, + activityInvoker, ) return exec, nil @@ -1158,6 +1202,8 @@ func stateExitedEventType(stateType string) string { return "WaitStateExited" case "Succeed": return "SucceedStateExited" + case "Fail": + return "FailStateExited" case "Parallel": return "ParallelStateExited" case "Map": @@ -1263,6 +1309,9 @@ func (b *InMemoryBackend) runParsedExecution( sqsIntegration asl.SQSIntegration, snsIntegration asl.SNSIntegration, ddbIntegration asl.DynamoDBIntegration, + ecsIntegration asl.ECSIntegration, + glueIntegration asl.GlueIntegration, + ebIntegration asl.EventBridgeIntegration, activityInvoker asl.ActivityInvoker, ) { rec := &historyRecorder{backend: b} @@ -1270,6 +1319,9 @@ func (b *InMemoryBackend) runParsedExecution( executor.SetSQSIntegration(sqsIntegration) executor.SetSNSIntegration(snsIntegration) executor.SetDynamoDBIntegration(ddbIntegration) + executor.SetECSIntegration(ecsIntegration) + executor.SetGlueIntegration(glueIntegration) + executor.SetEventBridgeIntegration(ebIntegration) executor.SetActivityInvoker(activityInvoker) executor.SetTaskTokenCallbackInvoker(b) executor.SetMapRunNotifier(b) @@ -1944,6 +1996,23 @@ func (b *InMemoryBackend) ListStateMachineAliases( return aliases, token, nil } +func (b *InMemoryBackend) resetExecutionForRedrive(exec *Execution, executionARN, smARN string, now float64) { + oldStatus := exec.Status + exec.Status = statusRunning + exec.Output = "" + exec.Error = "" + exec.Cause = "" + exec.StopDate = nil + exec.StartDate = now + exec.RedriveCount++ + exec.RedriveDate = &now + b.removeFromStatusBucket(smARN, oldStatus, executionARN) + b.addToStatusBucket(smARN, statusRunning, executionARN) + b.history[executionARN] = []*HistoryEvent{ + {Timestamp: now, Type: "ExecutionStarted", ID: executionStartedEventID, PreviousEventID: 0}, + } +} + // RedriveExecution re-runs a FAILED or ABORTED execution starting from its last known state. // AWS Step Functions re-runs from the last state that was reached before failure. // In this implementation we restart the entire execution with the original input (AWS parity for STANDARD executions). @@ -1993,22 +2062,7 @@ func (b *InMemoryBackend) RedriveExecution(executionARN string) (*Execution, err // Reset the execution to RUNNING. now := float64(time.Now().Unix()) - oldStatus := exec.Status - exec.Status = statusRunning - exec.Output = "" - exec.Error = "" - exec.Cause = "" - exec.StopDate = nil - exec.StartDate = now - exec.RedriveCount++ - exec.RedriveDate = &now - b.removeFromStatusBucket(smARN, oldStatus, executionARN) - b.addToStatusBucket(smARN, statusRunning, executionARN) - - // Reset history. - b.history[executionARN] = []*HistoryEvent{ - {Timestamp: now, Type: "ExecutionStarted", ID: executionStartedEventID, PreviousEventID: 0}, - } + b.resetExecutionForRedrive(exec, executionARN, smARN, now) // Snapshot the (possibly-updated) definition. b.executionDefinitions[executionARN] = definition @@ -2017,6 +2071,9 @@ func (b *InMemoryBackend) RedriveExecution(executionARN string) (*Execution, err sqsIntegration := b.sqsIntegration snsIntegration := b.snsIntegration ddbIntegration := b.ddbIntegration + ecsIntegration := b.ecsIntegration + glueIntegration := b.glueIntegration + ebIntegration := b.ebIntegration //nolint:gosec // cancel is stored in b.cancelFns for StopExecution/DeleteStateMachine ctx, cancel := context.WithCancel(b.svcCtx) @@ -2033,8 +2090,18 @@ func (b *InMemoryBackend) RedriveExecution(executionARN string) (*Execution, err b.mu.Unlock() go b.runParsedExecution( - ctx, executionARN, parsedSM, originalInput, - lambdaInvoker, sqsIntegration, snsIntegration, ddbIntegration, activityInvoker, + ctx, + executionARN, + parsedSM, + originalInput, + lambdaInvoker, + sqsIntegration, + snsIntegration, + ddbIntegration, + ecsIntegration, + glueIntegration, + ebIntegration, + activityInvoker, ) b.mu.RLock("RedriveExecution.result") diff --git a/ui/package-lock.json b/ui/package-lock.json index 91947032a..752ce42f4 100644 --- a/ui/package-lock.json +++ b/ui/package-lock.json @@ -57,6 +57,7 @@ "@aws-sdk/client-dlm": "3.1070.0", "@aws-sdk/client-docdb": "3.1070.0", "@aws-sdk/client-dynamodb": "3.1070.0", + "@aws-sdk/client-dynamodb-streams": "3.1070.0", "@aws-sdk/client-ebs": "3.1070.0", "@aws-sdk/client-ec2": "3.1070.0", "@aws-sdk/client-ecr": "3.1070.0", @@ -1553,6 +1554,27 @@ "node": ">=20.0.0" } }, + "node_modules/@aws-sdk/client-dynamodb-streams": { + "version": "3.1070.0", + "resolved": "https://registry.npmjs.org/@aws-sdk/client-dynamodb-streams/-/client-dynamodb-streams-3.1070.0.tgz", + "integrity": "sha512-UankWxd8dxBS5I5fIiaWx1Vp0rcksI8vTkVDabHBHDi0Jn/EBNZFpusqG8c8i5lgEn7UgYztksneCBw4uLIGaQ==", + "license": "Apache-2.0", + "dependencies": { + "@aws-crypto/sha256-browser": "5.2.0", + "@aws-crypto/sha256-js": "5.2.0", + "@aws-sdk/core": "^3.974.21", + "@aws-sdk/credential-provider-node": "^3.972.56", + "@aws-sdk/types": "^3.973.13", + "@smithy/core": "^3.24.6", + "@smithy/fetch-http-handler": "^5.4.6", + "@smithy/node-http-handler": "^4.7.6", + "@smithy/types": "^4.14.3", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=20.0.0" + } + }, "node_modules/@aws-sdk/client-ebs": { "version": "3.1070.0", "resolved": "https://registry.npmjs.org/@aws-sdk/client-ebs/-/client-ebs-3.1070.0.tgz", @@ -4804,7 +4826,6 @@ "version": "1.10.0", "resolved": "https://registry.npmjs.org/@emnapi/core/-/core-1.10.0.tgz", "integrity": "sha512-yq6OkJ4p82CAfPl0u9mQebQHKPJkY7WrIuk205cTYnYe+k2Z8YBh11FrbRG/H6ihirqcacOgl2BIO8oyMQLeXw==", - "dev": true, "license": "MIT", "optional": true, "dependencies": { @@ -4816,7 +4837,6 @@ "version": "1.10.0", "resolved": "https://registry.npmjs.org/@emnapi/runtime/-/runtime-1.10.0.tgz", "integrity": "sha512-ewvYlk86xUoGI0zQRNq/mC+16R1QeDlKQy21Ki3oSYXNgLb45GV1P6A0M+/s6nyCuNDqe5VpaY84BzXGwVbwFA==", - "dev": true, "license": "MIT", "optional": true, "dependencies": { @@ -4827,7 +4847,6 @@ "version": "1.2.1", "resolved": "https://registry.npmjs.org/@emnapi/wasi-threads/-/wasi-threads-1.2.1.tgz", "integrity": "sha512-uTII7OYF+/Mes/MrcIOYp5yOtSMLBWSIoLPpcgwipoiKbli6k322tcoFsxoIIxPDqW01SQGAgko4EzZi2BNv2w==", - "dev": true, "license": "MIT", "optional": true, "dependencies": { @@ -4936,7 +4955,6 @@ "version": "1.1.4", "resolved": "https://registry.npmjs.org/@napi-rs/wasm-runtime/-/wasm-runtime-1.1.4.tgz", "integrity": "sha512-3NQNNgA1YSlJb/kMH1ildASP9HW7/7kYnRI2szWJaofaS1hWmbGI4H+d3+22aGzXXN9IJ+n+GiFVcGipJP18ow==", - "dev": true, "license": "MIT", "optional": true, "dependencies": { @@ -5633,7 +5651,6 @@ "cpu": [ "arm64" ], - "dev": true, "license": "MIT", "optional": true, "os": [ @@ -5650,7 +5667,6 @@ "cpu": [ "arm64" ], - "dev": true, "license": "MIT", "optional": true, "os": [ @@ -5667,7 +5683,6 @@ "cpu": [ "x64" ], - "dev": true, "license": "MIT", "optional": true, "os": [ @@ -5684,7 +5699,6 @@ "cpu": [ "x64" ], - "dev": true, "license": "MIT", "optional": true, "os": [ @@ -5701,7 +5715,6 @@ "cpu": [ "arm" ], - "dev": true, "license": "MIT", "optional": true, "os": [ @@ -5718,7 +5731,6 @@ "cpu": [ "arm64" ], - "dev": true, "license": "MIT", "optional": true, "os": [ @@ -5735,7 +5747,6 @@ "cpu": [ "arm64" ], - "dev": true, "license": "MIT", "optional": true, "os": [ @@ -5752,7 +5763,6 @@ "cpu": [ "ppc64" ], - "dev": true, "license": "MIT", "optional": true, "os": [ @@ -5769,7 +5779,6 @@ "cpu": [ "s390x" ], - "dev": true, "license": "MIT", "optional": true, "os": [ @@ -5786,7 +5795,6 @@ "cpu": [ "x64" ], - "dev": true, "license": "MIT", "optional": true, "os": [ @@ -5803,7 +5811,6 @@ "cpu": [ "x64" ], - "dev": true, "license": "MIT", "optional": true, "os": [ @@ -5820,7 +5827,6 @@ "cpu": [ "arm64" ], - "dev": true, "license": "MIT", "optional": true, "os": [ @@ -5837,7 +5843,6 @@ "cpu": [ "wasm32" ], - "dev": true, "license": "MIT", "optional": true, "dependencies": { @@ -5856,7 +5861,6 @@ "cpu": [ "arm64" ], - "dev": true, "license": "MIT", "optional": true, "os": [ @@ -5873,7 +5877,6 @@ "cpu": [ "x64" ], - "dev": true, "license": "MIT", "optional": true, "os": [ @@ -6935,7 +6938,6 @@ "version": "0.10.2", "resolved": "https://registry.npmjs.org/@tybys/wasm-util/-/wasm-util-0.10.2.tgz", "integrity": "sha512-RoBvJ2X0wuKlWFIjrwffGw1IqZHKQqzIchKaadZZfnNpsAYp2mM0h36JtPCjNDAHGgYez/15uMBpfGwchhiMgg==", - "dev": true, "license": "MIT", "optional": true, "dependencies": { @@ -7545,7 +7547,6 @@ "version": "2.3.3", "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.3.tgz", "integrity": "sha512-5xoDfX+fL7faATnagmWPpbFtwh/R77WmMMqqHGS65C3vvB0YHrgF+B1YmZ3441tMj5n63k0212XNoJwzlhffQw==", - "dev": true, "hasInstallScript": true, "license": "MIT", "optional": true, @@ -7668,7 +7669,7 @@ "version": "2.7.0", "resolved": "https://registry.npmjs.org/jiti/-/jiti-2.7.0.tgz", "integrity": "sha512-AC/7JofJvZGrrneWNaEnJeOLUx+JlGt7tNa0wZiRPT4MY1wmfKjt2+6O2p2uz2+skll8OZZmJMNqeke7kKbNgQ==", - "dev": true, + "devOptional": true, "license": "MIT", "bin": { "jiti": "lib/jiti-cli.mjs" @@ -7769,7 +7770,6 @@ "cpu": [ "arm64" ], - "dev": true, "license": "MPL-2.0", "optional": true, "os": [ @@ -7790,7 +7790,6 @@ "cpu": [ "arm64" ], - "dev": true, "license": "MPL-2.0", "optional": true, "os": [ @@ -7811,7 +7810,6 @@ "cpu": [ "x64" ], - "dev": true, "license": "MPL-2.0", "optional": true, "os": [ @@ -7832,7 +7830,6 @@ "cpu": [ "x64" ], - "dev": true, "license": "MPL-2.0", "optional": true, "os": [ @@ -7853,7 +7850,6 @@ "cpu": [ "arm" ], - "dev": true, "license": "MPL-2.0", "optional": true, "os": [ @@ -7874,7 +7870,6 @@ "cpu": [ "arm64" ], - "dev": true, "license": "MPL-2.0", "optional": true, "os": [ @@ -7895,7 +7890,6 @@ "cpu": [ "arm64" ], - "dev": true, "license": "MPL-2.0", "optional": true, "os": [ @@ -7916,7 +7910,6 @@ "cpu": [ "x64" ], - "dev": true, "license": "MPL-2.0", "optional": true, "os": [ @@ -7937,7 +7930,6 @@ "cpu": [ "x64" ], - "dev": true, "license": "MPL-2.0", "optional": true, "os": [ @@ -7958,7 +7950,6 @@ "cpu": [ "arm64" ], - "dev": true, "license": "MPL-2.0", "optional": true, "os": [ @@ -7979,7 +7970,6 @@ "cpu": [ "x64" ], - "dev": true, "license": "MPL-2.0", "optional": true, "os": [ @@ -8874,7 +8864,7 @@ "version": "6.0.3", "resolved": "https://registry.npmjs.org/typescript/-/typescript-6.0.3.tgz", "integrity": "sha512-y2TvuxSZPDyQakkFRPZHKFm+KKVqIisdg9/CZwm9ftvKXLP8NRWj38/ODjNbr43SsoXqNuAisEf1GdCxqWcdBw==", - "dev": true, + "devOptional": true, "license": "Apache-2.0", "bin": { "tsc": "bin/tsc", diff --git a/ui/package.json b/ui/package.json index e8542fa39..c33ded075 100644 --- a/ui/package.json +++ b/ui/package.json @@ -68,6 +68,7 @@ "@aws-sdk/client-dlm": "3.1070.0", "@aws-sdk/client-docdb": "3.1070.0", "@aws-sdk/client-dynamodb": "3.1070.0", + "@aws-sdk/client-dynamodb-streams": "3.1070.0", "@aws-sdk/client-ebs": "3.1070.0", "@aws-sdk/client-ec2": "3.1070.0", "@aws-sdk/client-ecr": "3.1070.0", diff --git a/ui/src/lib/aws-client.ts b/ui/src/lib/aws-client.ts index e1e44988d..8e5d3d21c 100644 --- a/ui/src/lib/aws-client.ts +++ b/ui/src/lib/aws-client.ts @@ -767,3 +767,9 @@ export function getRolesAnywhereClient(region?: string): RolesAnywhereClient { export function getWorkMailClient(region?: string): WorkMailClient { return new WorkMailClient(clientConfig(region)); } + +import { DynamoDBStreamsClient } from "@aws-sdk/client-dynamodb-streams"; + +export function getDynamoDBStreamsClient(region?: string): DynamoDBStreamsClient { + return new DynamoDBStreamsClient(clientConfig(region)); +} diff --git a/ui/src/routes/dax/+page.svelte b/ui/src/routes/dax/+page.svelte index 2ca3ec35f..2c7cea37f 100644 --- a/ui/src/routes/dax/+page.svelte +++ b/ui/src/routes/dax/+page.svelte @@ -1,18 +1,32 @@
@@ -74,6 +244,13 @@ + {#if activeTab === 'clusters'} + + {:else if activeTab === 'paramgroups'} + + {:else if activeTab === 'subnetgroups'} + + {/if}
@@ -101,7 +278,7 @@ {:else}
{#each filteredClusters as a} -
+
@@ -109,9 +286,15 @@

{`${a.NodeType ?? '-'} · ${a.TotalNodes ?? 0} nodes`}

- {#if a.Status} - {a.Status} - {/if} +
+ {#if a.Status} + {a.Status} + {/if} + + + + +
{/each}
@@ -130,6 +313,10 @@

{`${a.Description ?? ''}`}

+
+ + +
{/each} @@ -148,6 +335,10 @@

{`VPC: ${a.VpcId ?? '-'}`}

+
+ + +
{/each} diff --git a/ui/src/routes/dynamodb/+page.svelte b/ui/src/routes/dynamodb/+page.svelte index 8c4e40150..5f8825077 100644 --- a/ui/src/routes/dynamodb/+page.svelte +++ b/ui/src/routes/dynamodb/+page.svelte @@ -2,6 +2,8 @@ import { confirmDestructive } from '$lib/confirm-dialog'; import { onMount, onDestroy } from 'svelte'; import { newDynamoDBClient, getStoredRegion } from '$lib/aws/client'; +import { getDynamoDBStreamsClient } from '$lib/aws-client'; +import { DescribeStreamCommand, GetShardIteratorCommand, GetRecordsCommand } from '@aws-sdk/client-dynamodb-streams'; import { ListTablesCommand, DescribeTableCommand, @@ -14,6 +16,18 @@ PutItemCommand, ExecuteStatementCommand, DescribeTimeToLiveCommand, + TransactWriteItemsCommand, + TransactGetItemsCommand, + ExecuteTransactionCommand, + BatchExecuteStatementCommand, + RestoreTableFromBackupCommand, + RestoreTableToPointInTimeCommand, + ExportTableToPointInTimeCommand, + ImportTableCommand, + ListExportsCommand, + ListImportsCommand, + BatchGetItemCommand, + UpdateItemCommand, UpdateTimeToLiveCommand, UpdateTableCommand, ListBackupsCommand, @@ -68,6 +82,7 @@ let queryFilterExp = $state(''); let queryLimit = $state(100); let queryResults = $state[]>([]); +let queryLastKey = $state(null); let queryLoading = $state(false); let queryCount = $state(0); let querySortOrder = $state<'ASC' | 'DESC'>('ASC'); @@ -77,6 +92,7 @@ let scanProjectionExp = $state(''); let scanLimit = $state(100); let scanResults = $state[]>([]); +let scanLastKey = $state(null); let scanLoading = $state(false); let scanCount = $state(0); let scanScannedCount = $state(0); @@ -119,6 +135,7 @@ let streamEventsHtml = $state(''); let streamEventsLoading = $state(false); let streamBackendUnavailable = $state(false); +let ddbStreams = $state(getDynamoDBStreamsClient()); let streamPollTimer: ReturnType | undefined; let streamFetchController: AbortController | undefined; @@ -145,6 +162,24 @@ let streamsViewType = $state('NEW_AND_OLD_IMAGES'); let streamsEnabled = $state(false); let streamARN = $state(''); +let editBillingMode = $state('PAY_PER_REQUEST'); +let editRcu = $state(5); +let editWcu = $state(5); +async function updateCapacity() { + if (!selectedTable) return; + try { + await ddb.send(new UpdateTableCommand({ + TableName: selectedTable, + BillingMode: editBillingMode as 'PROVISIONED' | 'PAY_PER_REQUEST', + ...(editBillingMode === 'PROVISIONED' ? { + ProvisionedThroughput: { ReadCapacityUnits: editRcu, WriteCapacityUnits: editWcu } + } : {}) + })); + toast.success("Capacity updated"); + loadTables(); + } catch (e) { toast.error(String(e)); } +} + // Modals let showNewItemModal = $state(false); @@ -153,6 +188,31 @@ let importJson = $state(''); let showEditModal = $state(false); let editItemJson = $state(''); +let updateExp = $state(''); +let updateCond = $state(''); +let batchGetKeys = $state(''); +async function execBatchGet() { + if (!selectedTable) return; + try { + const res = await ddb.send(new BatchGetItemCommand({ + RequestItems: { [selectedTable]: { Keys: JSON.parse(batchGetKeys).map((k: unknown) => jsonToItem(k as Record)) } } + })); + toast.success("BatchGet completed. Found: " + (res.Responses?.[selectedTable]?.length || 0)); + } catch (e) { toast.error(String(e)); } +} +async function execUpdateItem() { + if (!selectedTable || !editItemJson) return; + try { + await ddb.send(new UpdateItemCommand({ + TableName: selectedTable, + Key: buildItemKey(JSON.parse(editItemJson)), + UpdateExpression: updateExp || undefined, + ConditionExpression: updateCond || undefined + })); + toast.success("UpdateItem success"); + showEditModal = false; + } catch (e) { toast.error(String(e)); } +} // GSI Create Modal State let showCreateGsiModal = $state(false); @@ -217,7 +277,53 @@ return key; } - function exportJson(data: Record[], filename: string): void { + +let s3Exports: unknown[] = $state([]); +let s3Imports: unknown[] = $state([]); +async function loadExportsImports() { + if (!selectedTable) return; + try { + const e = await ddb.send(new ListExportsCommand({TableArn: selectedTableDesc?.TableArn})); + s3Exports = e.ExportSummaries || []; + const i = await ddb.send(new ListImportsCommand({})); + // Needs filter by table if supported + s3Imports = i.ImportSummaryList || []; + } catch(e){} +} +async function nativeExport() { + // eslint-disable-next-line no-alert + const bucket = prompt("S3 Bucket Name:"); + if (!bucket || !selectedTableDesc?.TableArn) return; + try { + await ddb.send(new ExportTableToPointInTimeCommand({ + TableArn: selectedTableDesc.TableArn, + S3Bucket: bucket, + ExportFormat: "DYNAMODB_JSON" + })); + toast.success("Export started"); + } catch (e) { toast.error(String(e)); } +} +async function nativeImport() { + // eslint-disable-next-line no-alert + const bucket = prompt("S3 Bucket Name:"); + // eslint-disable-next-line no-alert + const table = prompt("Target Table Name:"); + if (!bucket || !table) return; + try { + await ddb.send(new ImportTableCommand({ + S3BucketSource: { S3Bucket: bucket }, + InputFormat: "DYNAMODB_JSON", + TableCreationParameters: { + TableName: table, + BillingMode: "PAY_PER_REQUEST", + KeySchema: [{AttributeName: "pk", KeyType: "HASH"}], + AttributeDefinitions: [{AttributeName: "pk", AttributeType: "S"}] + } + })); + toast.success("Import started"); + } catch (e) { toast.error(String(e)); } +} +function exportJson(data: Record[], filename: string): void { const blob = new Blob([JSON.stringify(data, null, 2)], { type: 'application/json' }); const url = URL.createObjectURL(blob); const a = document.createElement('a'); @@ -353,7 +459,8 @@ ...(queryIndexName ? { IndexName: queryIndexName } : {}), ...(queryFilterExp ? { FilterExpression: queryFilterExp } : {}) }; - const res = await ddb.send(new QueryCommand(input)); + const res = await ddb.send(new QueryCommand({...input, ExclusiveStartKey: queryLastKey as Record})); +queryLastKey = res.LastEvaluatedKey; queryResults = (res.Items ?? []).map((item) => itemToJson(item)); queryCount = res.Count ?? 0; } catch (err: unknown) { @@ -374,7 +481,8 @@ ...(scanFilterExp ? { FilterExpression: scanFilterExp } : {}), ...(scanProjectionExp ? { ProjectionExpression: scanProjectionExp } : {}) }; - const res = await ddb.send(new ScanCommand(input)); + const res = await ddb.send(new ScanCommand({...input, ExclusiveStartKey: scanLastKey as Record})); +scanLastKey = res.LastEvaluatedKey; scanResults = (res.Items ?? []).map((item) => itemToJson(item)); scanCount = res.Count ?? 0; scanScannedCount = res.ScannedCount ?? 0; @@ -571,7 +679,20 @@ } } - async function togglePitr(): Promise { + async function restorePitr() { + // eslint-disable-next-line no-alert + const name = prompt("New Table Name:"); + if (!name || !selectedTable) return; + try { + await ddb.send(new RestoreTableToPointInTimeCommand({ + SourceTableName: selectedTable, + TargetTableName: name, + UseLatestRestorableTime: true + })); + toast.success("Restoring table..."); + } catch (e) { toast.error(String(e)); } +} +async function togglePitr(): Promise { if (!selectedTable) return; const enable = pitrStatus !== 'ENABLED'; try { @@ -678,7 +799,34 @@ } // Stream Events - async function loadStreamEvents() { + +async function loadNativeStreams() { + if (!streamARN) return; + try { + const desc = await ddbStreams.send(new DescribeStreamCommand({StreamArn: streamARN})); + if (!desc.StreamDescription?.Shards) return; + let recordsHtml = ''; + for (const shard of desc.StreamDescription.Shards) { + if (!shard.ShardId) continue; + const it = await ddbStreams.send(new GetShardIteratorCommand({ + StreamArn: streamARN, + ShardId: shard.ShardId, + ShardIteratorType: "TRIM_HORIZON" + })); + if (!it.ShardIterator) continue; + const recs = await ddbStreams.send(new GetRecordsCommand({ShardIterator: it.ShardIterator, Limit: 100})); + if (recs.Records) { + for (const r of recs.Records) { + recordsHtml += `
Native Stream Record: ${r.eventName} ${JSON.stringify(r.dynamodb)}
`; + } + } + } + streamEventsHtml = recordsHtml || "No native stream records found."; + } catch(e) { + streamEventsHtml = "Native streams error: " + String(e); + } +} +async function loadStreamEvents() { if (!selectedTable) return; if (!streamEventsHtml) streamEventsLoading = true; const signal = streamFetchController?.signal; @@ -691,6 +839,7 @@ streamBackendUnavailable = true; streamEventsHtml = ''; stopStreamPolling(); + await loadNativeStreams(); } else if (text === 'No recent stream events.' || text.trim() === '') { streamEventsHtml = ''; } else { @@ -795,7 +944,10 @@ showEditModal = true; } - function setPartiqlExample(query: string) { + +function nextQueryPage() { if (queryLastKey) executeQuery(); } +function nextScanPage() { if (scanLastKey) executeScan(); } +function setPartiqlExample(query: string) { partiqlStatement = query; } @@ -1848,7 +2000,7 @@ - + {#if backupsLoading}
@@ -1914,7 +2066,8 @@

Earliest restore point: {pitrEarliestRestoreDate.toLocaleString()}

{/if}

PITR lets you restore this table to any point in the last 35 days.

- +