From 8f1477eaa7b57b2e1d1a9d586d3c87bcf4275f7d Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Sat, 21 Mar 2026 15:42:46 +0000 Subject: [PATCH 01/40] feat(consensus): implement the scheduler fix(consensus/propeller): shceduler.go minor improvs --- consensus/propeller/scheduler.go | 254 +++++++++++++++++++++++++++++++ 1 file changed, 254 insertions(+) create mode 100644 consensus/propeller/scheduler.go diff --git a/consensus/propeller/scheduler.go b/consensus/propeller/scheduler.go new file mode 100644 index 0000000000..075791cd30 --- /dev/null +++ b/consensus/propeller/scheduler.go @@ -0,0 +1,254 @@ +package propeller + +import ( + "cmp" + "errors" + "fmt" + "slices" + + "github.com/libp2p/go-libp2p/core/peer" +) + +type Stake uint64 + +// todo(rdr): this is a Peer that belongs to a committee and has a stake. I would like to +// give it a better name +type PeerCommittee struct { + ID peer.ID + Stake Stake +} + +// Scheduler represents the tree manager that computes the tree topology on demand for each +// publisher. It holds a deterministic shard-to-peer mapping for a committee. +// Given a sorted set of peers and a publisher, it computes which peer is +// responsible for broadcasting each shard index. The mapping is deterministic +// so that all nodes agree on the assignment without coordination. +// +// The design relies on the invariant that there are N-1 shards for N peers, +// and each non-publisher peer gets exactly one shard. The publisher is "skipped" +// in the sorted peer list when assigning shard indices. 
+// +// Propeller uses a distributed broadcast approach where: +// - numDataShards = floor((N-1)/3) where N is total number of nodes +// - numDataShards represents both max faulty nodes AND number of data shards +// - numCodingShards = N-1-numDataShards (meaning, the rest) +// - Message is BUILT when numDataShards are received (can reconstruct) +// - Message is RECEIVED when 2*numDataShards shards are received (guarantees gossip property) +// - Each peer broadcasts received shards to all other peers (full mesh) +type Scheduler struct { + peerID peer.ID + peerIDIndex int + peers []PeerCommittee + numDataShards int + numCodingShards int +} + +// NewScheduler creates a schedule from a list of peers. The peers are sorted +// lexicographically by their string representation to ensure all nodes derive +// the same ordering regardless of discovery order. +// Note that `nodes` will be mutated after this function gets called +// todo(rdr): should we return scheduler by reference or by value? +func NewScheduler( + id peer.ID, + nodes []PeerCommittee, +) (*Scheduler, error) { + if len(nodes) < 2 { + return nil, fmt.Errorf( + "at least 2 peers are required to form a new committee: %d given", + len(nodes), + ) + } + + // todo(rdr): check with function is faster for sorting in our case: + // `slices.Sort` or `sort.Slice` + // sort.Slice(nodes, func(i, j int) bool { return nodes[i].ID < nodes[j].ID }) + slices.SortFunc(nodes, func(i, j PeerCommittee) int { return cmp.Compare(i.ID, j.ID) }) + + // check that the local peer ID is part of the peer committee + idIndex, exists := slices.BinarySearchFunc( + nodes, + id, + func(elem PeerCommittee, target peer.ID) int { + return cmp.Compare(elem.ID, target) + }, + ) + if !exists { + return nil, errors.New("the local peer id is not part of the suplied list of peeers") + } + + // check that there is no duplicated ID in the node list + for i := range len(nodes) - 1 { + if nodes[i].ID == nodes[i+1].ID { + return nil, fmt.Errorf("duplicated 
ids in the suplied list of peers: %s", nodes[i].ID) + } + } + + totalNodes := len(nodes) + // We guarantee always one data shard for small networks (N = 2 or N = 3) + numDataShards := max(1, (totalNodes-1)/3) + // We avoid the possibility of an underflow + numCodingShards := max(0, totalNodes-1-numDataShards) + + return &Scheduler{ + peerID: id, + peerIDIndex: idIndex, + peers: nodes, + numDataShards: numDataShards, + numCodingShards: numCodingShards, + }, nil +} + +// PeerID returns the Scheduler Peer ID +func (s *Scheduler) PeerID() peer.ID { + return s.peerID +} + +// Peers return the Scheduler list of nodes +func (s *Scheduler) Peers() []PeerCommittee { + return s.peers +} + +// DataShards returns the number of data (systematic) shards. +func (s *Scheduler) NumDataShards() int { return s.numDataShards } + +// CodingShards returns the number of parity (coding) shards. +func (s *Scheduler) NumCodingShards() int { return s.numCodingShards } + +// NumShards returns the total number of shards (data + coding = N-1). +func (s *Scheduler) NumTotalShards() int { return s.numDataShards + s.numCodingShards } + +func (s *Scheduler) publisherIndex(publisher peer.ID) (int, error) { + publisherIndex, found := slices.BinarySearchFunc( + s.peers, + publisher, + func(elem PeerCommittee, target peer.ID) int { + return cmp.Compare(elem.ID, target) + }, + ) + if !found { + return -1, fmt.Errorf("publisher with id \"%s\" not found in the peer list", publisher) + } + return publisherIndex, nil +} + +// PeerForShardIndex returns the peer responsible for broadcasting a given +// shard index to a given publisher. 
The mapping skips the publisher in the +// sorted list: +// +// if shardIndex < publisherIndex: peer = peers[shardIndex] +// if shardIndex >= publisherIndex: peer = peers[shardIndex + 1] +// +// Example with peers [A, B, C, D] and publisher C (index 2): +// +// shard 0 -> A, shard 1 -> B, shard 2 -> D +func (s *Scheduler) PeerForShardIndex( + publisher peer.ID, shardIndex ShardIndex, +) (peer.ID, error) { + if int(shardIndex) >= s.NumTotalShards() { + return "", fmt.Errorf( + "shard index %d out of range [0, %d)", shardIndex, s.NumTotalShards(), + ) + } + + pubIdx, err := s.publisherIndex(publisher) + if err != nil { + return "", err + } + + // Skip the publisher's position. + peerIdx := int(shardIndex) + if peerIdx >= pubIdx { + peerIdx++ + } + + return s.peers[peerIdx].ID, nil +} + +// ShardIndexForPublisher returns the shard index that shceduler is responsible for +// broadcasting for a given publisher. This is the inverse of PeerForShard: +// +// if localPeerIndex < publisherIndex: shard = localPeerIndex +// if localPeerIndex > publisherIndex: shard = localPeerIndex - 1 +// +// Returns an error if Scheduler's peer is the publisher (publishers don't have an +// assigned shard) or if the publisher is not in the list. +func (s *Scheduler) ShardIndexForPublisher( + publisher peer.ID, +) (ShardIndex, error) { + if s.peerID == publisher { + return 0, fmt.Errorf( + "scheduler peer is the same as the publisher and has no assinged shard: %s", + publisher, + ) + } + + pubIdx, err := s.publisherIndex(publisher) + if err != nil { + return 0, fmt.Errorf("couldn't locate shard index for publisher: %w", err) + } + + shardIdx := s.peerIDIndex + if s.peerIDIndex >= pubIdx { + shardIdx = s.peerIDIndex - 1 + } + + return ShardIndex(shardIdx), nil +} + +// ValidateShardOrigin verifies that a shard unit was received from the expected sender. +// The sender has to be either the publisher for direct shards or a designated +// broadcasted for the given shard index. 
+func (s *Scheduler) ValidateShardOrigin( + sender peer.ID, + publisher peer.ID, + shardIndex ShardIndex, +) error { + if sender == s.peerID { + return fmt.Errorf("scheduler sent itself a shard: %s", sender) + } + if publisher == s.peerID { + return fmt.Errorf("scheduler broadcast itself a shard: %s", publisher) + } + + expectedBroadcaster, err := s.PeerForShardIndex(publisher, shardIndex) + if err != nil { + return fmt.Errorf( + "couldn't validate publisher %s with shard %d: %w", + publisher, + shardIndex, + err, + ) + } + + validDirectShard := expectedBroadcaster == s.peerID && sender == publisher + if validDirectShard { + return nil + } + + validBroadcastShard := expectedBroadcaster == sender + if validBroadcastShard { + return nil + } + + return fmt.Errorf( + "received shard index %d from unexpected sender %s", + shardIndex, + sender, + ) +} + +// BroadcastTargets returns all peers the Schudler's peer needs to braodcast to, +// in shard-index order. The i-th element of the returned slice is the peer responsible for +// shard i. 
+func (s *Scheduler) BroadcastTargets() []peer.ID { + targets := make([]peer.ID, s.NumTotalShards()-1) + i := 0 + for _, p := range s.peers { + if i == s.peerIDIndex { + continue + } + targets[i] = p.ID + i += 1 + } + return targets +} From 95145c0d5d7b7d9ec5b63ecad19aafb6057f401f Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Wed, 13 May 2026 10:26:49 +0100 Subject: [PATCH 02/40] feat(consensus/propeller): implement reedsolomon fec --- .../propeller/reedsolomon/reedsolomon.go | 81 +++++++++++++++++++ .../propeller/reedsolomon/reedsolomon_test.go | 1 + go.mod | 1 + go.sum | 2 + 4 files changed, 85 insertions(+) create mode 100644 consensus/propeller/reedsolomon/reedsolomon.go create mode 100644 consensus/propeller/reedsolomon/reedsolomon_test.go diff --git a/consensus/propeller/reedsolomon/reedsolomon.go b/consensus/propeller/reedsolomon/reedsolomon.go new file mode 100644 index 0000000000..56991aba26 --- /dev/null +++ b/consensus/propeller/reedsolomon/reedsolomon.go @@ -0,0 +1,81 @@ +package reedsolomon + +import ( + "errors" + "fmt" + + "github.com/klauspost/reedsolomon" +) + +// EncodeData generates the coding shards usign Reed-Solomon +// erasure codes. Receives the data, amount of shards and parity number. +// It will return the Reed Solomon encoding where the first `numDataShards` +// `[]byte` slices will be occupied by the original data. The remaining `parity` +// `[]byte` slices will contain the coding shards. +// The data will be modified in place so the input shouldn't be modified after calling this +// function. 
+func EncodeData( + data []byte, + numDataShards, + parity int, +) ([][]byte, error) { + if len(data) == 0 { + return nil, errors.New("received empty data") + } + + encoder, err := reedsolomon.New(numDataShards, parity) + if err != nil { + return nil, fmt.Errorf("creating Reed-Solomon encoder: %w", err) + } + + split, err := encoder.Split(data) + if err != nil { + return nil, fmt.Errorf("splitting the data into shards: %w", err) + } + + err = encoder.Encode(split) + if err != nil { + return nil, fmt.Errorf("encoding the data shards: %w", err) + } + + return split, nil +} + +// RecoverData restores the missing data using Reed-Solomon erasure codes. There cannot be more than +// `parity` shards missing otherwise the recover will fail. Data that is considered missing needs to +// be marked as `nil`. Returns the recovered data. +// The data will be modified in place so the input shouldn't be modified after calling this function. +func RecoverData( + shards [][]byte, + numDataShards, + parity int, +) ([][]byte, error) { + if len(shards) == 0 { + return nil, errors.New("no data shards provided") + } + + // todo(rdr): numDataShards can be inferred by getting the length of the shards + decoder, err := reedsolomon.New(numDataShards, parity) + if err != nil { + return nil, fmt.Errorf("creating Reed-Solomon decoder: %w", err) + } + + // todo(rdr): this is a slow approach where we are reconstructing parity shards as + // well. This is safe because at the end we can verify that it is correct. We might + // want to speed this up using `ReconstructData` with no `Verify` which should be 3x faster. 
+ err = decoder.Reconstruct(shards) + if err != nil { + return nil, fmt.Errorf("recovering the data shards: %w", err) + } + + correct, err := decoder.Verify(shards) + if err != nil { + return nil, fmt.Errorf("verifying the data shards: %w", err) + } + + if !correct { + return nil, errors.New("data shard failed verification") + } + + return shards, nil +} diff --git a/consensus/propeller/reedsolomon/reedsolomon_test.go b/consensus/propeller/reedsolomon/reedsolomon_test.go new file mode 100644 index 0000000000..c37f6836a0 --- /dev/null +++ b/consensus/propeller/reedsolomon/reedsolomon_test.go @@ -0,0 +1 @@ +package reedsolomon_test diff --git a/go.mod b/go.mod index eca739c819..749aa1926d 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,7 @@ require ( github.com/ethereum/go-ethereum v1.17.2 github.com/fxamacker/cbor/v2 v2.9.1 github.com/go-playground/validator/v10 v10.30.2 + github.com/klauspost/reedsolomon v1.14.0 github.com/libp2p/go-libp2p v0.48.0 github.com/libp2p/go-libp2p-kad-dht v0.39.1 github.com/libp2p/go-libp2p-pubsub v0.16.0 diff --git a/go.sum b/go.sum index eb2ec50c55..0f81defce2 100644 --- a/go.sum +++ b/go.sum @@ -379,6 +379,8 @@ github.com/klauspost/compress v1.18.4 h1:RPhnKRAQ4Fh8zU2FY/6ZFDwTVTxgJ/EMydqSTzE github.com/klauspost/compress v1.18.4/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= +github.com/klauspost/reedsolomon v1.14.0 h1:5YSZeclzSYg5nl349+GDG/agDtQ6MZiwUYXvVKN1Jx0= +github.com/klauspost/reedsolomon v1.14.0/go.mod h1:yjqqjgMTQkBUHSG97/rm4zipffCNbCiZcB3kTqr++sQ= github.com/koron/go-ssdp v0.1.0 h1:ckl5x5H6qSNFmi+wCuROvvGUu2FQnMbQrU95IHCcv3Y= github.com/koron/go-ssdp v0.1.0/go.mod h1:GltaDBjtK1kemZOusWYLGotV0kBeEf59Bp0wtSB0uyU= github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= From 45a5eb3e8a942a37a7bde8db31ef93919b0a8d58 Mon Sep 17 
00:00:00 2001 From: Rodrigo Date: Wed, 13 May 2026 10:32:13 +0100 Subject: [PATCH 03/40] refactor(p2p): apply style correctness --- p2p/server/server.go | 287 ++++++++++++++++++++++++------------------- 1 file changed, 161 insertions(+), 126 deletions(-) diff --git a/p2p/server/server.go b/p2p/server/server.go index d8f4621646..148e8d8c7b 100644 --- a/p2p/server/server.go +++ b/p2p/server/server.go @@ -216,7 +216,10 @@ func (h *Server) onHeadersRequest( HeaderMessage: &header.BlockHeadersResponse_Fin{}, } - return h.processIterationRequest(req.Iteration, finMsg, func(it blockDataAccessor) (proto.Message, error) { + return h.processIterationRequest( + req.Iteration, + finMsg, + func(it blockDataAccessor) (proto.Message, error) { blockHeader, err := it.Header() if err != nil { return nil, err @@ -252,17 +255,17 @@ func (h *Server) onHeadersRequest( } } - stateDiffCommitment := stateUpdate.StateDiff.Hash() - return &header.BlockHeadersResponse{ - HeaderMessage: &header.BlockHeadersResponse_Header{ - Header: core2p2p.AdaptHeader( - blockHeader, - commitments, - &stateDiffCommitment, - stateUpdate.StateDiff.Length()), - }, - }, nil - }) + stateDiffCommitment := stateUpdate.StateDiff.Hash() + return &header.BlockHeadersResponse{ + HeaderMessage: &header.BlockHeadersResponse_Header{ + Header: core2p2p.AdaptHeader( + blockHeader, + commitments, + &stateDiffCommitment, + stateUpdate.StateDiff.Length()), + }, + }, nil + }) } func (h *Server) onEventsRequest( @@ -329,114 +332,119 @@ func (h *Server) onStateDiffRequest( finMsg := &state.StateDiffsResponse{ StateDiffMessage: &state.StateDiffsResponse_Fin{}, } - return h.processIterationRequestMulti(req.Iteration, finMsg, func(it blockDataAccessor) ([]proto.Message, error) { - block, err := it.Block() - if err != nil { - return nil, err - } - blockNumber := block.Number + return h.processIterationRequestMulti( + req.Iteration, + finMsg, + func(it blockDataAccessor) ([]proto.Message, error) { + block, err := it.Block() + if err 
!= nil { + return nil, err + } + blockNumber := block.Number - stateUpdate, err := h.bcReader.StateUpdateByNumber(blockNumber) - if err != nil { - return nil, err - } - diff := stateUpdate.StateDiff + stateUpdate, err := h.bcReader.StateUpdateByNumber(blockNumber) + if err != nil { + return nil, err + } + diff := stateUpdate.StateDiff - type contractDiff struct { - address *felt.Felt - storageDiffs map[felt.Felt]*felt.Felt - nonce *felt.Felt - classHash *felt.Felt // set only if contract deployed or replaced - } - modifiedContracts := make(map[felt.Felt]*contractDiff) + type contractDiff struct { + address *felt.Felt + storageDiffs map[felt.Felt]*felt.Felt + nonce *felt.Felt + classHash *felt.Felt // set only if contract deployed or replaced + } + modifiedContracts := make(map[felt.Felt]*contractDiff) - initContractDiff := func(addr *felt.Felt) *contractDiff { - return &contractDiff{address: addr} - } - updateModifiedContracts := func(addr felt.Felt, f func(*contractDiff)) error { - cDiff, ok := modifiedContracts[addr] - if !ok { - cDiff = initContractDiff(&addr) - if err != nil { - return err - } - modifiedContracts[addr] = cDiff + initContractDiff := func(addr *felt.Felt) *contractDiff { + return &contractDiff{address: addr} } + updateModifiedContracts := func(addr felt.Felt, f func(*contractDiff)) error { + cDiff, ok := modifiedContracts[addr] + if !ok { + cDiff = initContractDiff(&addr) + if err != nil { + return err + } + modifiedContracts[addr] = cDiff + } - f(cDiff) - return nil - } + f(cDiff) + return nil + } - for addr, n := range diff.Nonces { - err = updateModifiedContracts(addr, func(diff *contractDiff) { - diff.nonce = n - }) - if err != nil { - return nil, err + for addr, n := range diff.Nonces { + err = updateModifiedContracts(addr, func(diff *contractDiff) { + diff.nonce = n + }) + if err != nil { + return nil, err + } } - } - for addr, sDiff := range diff.StorageDiffs { - err = updateModifiedContracts(addr, func(diff *contractDiff) { - 
diff.storageDiffs = sDiff - }) - if err != nil { - return nil, err + for addr, sDiff := range diff.StorageDiffs { + err = updateModifiedContracts(addr, func(diff *contractDiff) { + diff.storageDiffs = sDiff + }) + if err != nil { + return nil, err + } } - } - for addr, classHash := range diff.DeployedContracts { - classHashCopy := classHash - err = updateModifiedContracts(addr, func(diff *contractDiff) { - diff.classHash = classHashCopy - }) - if err != nil { - return nil, err + for addr, classHash := range diff.DeployedContracts { + classHashCopy := classHash + err = updateModifiedContracts(addr, func(diff *contractDiff) { + diff.classHash = classHashCopy + }) + if err != nil { + return nil, err + } } - } - for addr, classHash := range diff.ReplacedClasses { - classHashCopy := classHash - err = updateModifiedContracts(addr, func(diff *contractDiff) { - diff.classHash = classHashCopy - }) - if err != nil { - return nil, err + for addr, classHash := range diff.ReplacedClasses { + classHashCopy := classHash + err = updateModifiedContracts(addr, func(diff *contractDiff) { + diff.classHash = classHashCopy + }) + if err != nil { + return nil, err + } } - } - var responses []proto.Message - for _, c := range modifiedContracts { - responses = append(responses, &state.StateDiffsResponse{ - StateDiffMessage: &state.StateDiffsResponse_ContractDiff{ - ContractDiff: core2p2p.AdaptContractDiff(c.address, c.nonce, c.classHash, c.storageDiffs), - }, - }) - } + var responses []proto.Message + for _, c := range modifiedContracts { + responses = append(responses, &state.StateDiffsResponse{ + StateDiffMessage: &state.StateDiffsResponse_ContractDiff{ + ContractDiff: core2p2p.AdaptContractDiff( + c.address, c.nonce, c.classHash, c.storageDiffs, + ), + }, + }) + } - for _, classHash := range diff.DeclaredV0Classes { - responses = append(responses, &state.StateDiffsResponse{ - StateDiffMessage: &state.StateDiffsResponse_DeclaredClass{ - DeclaredClass: &state.DeclaredClass{ - ClassHash: 
core2p2p.AdaptHash(classHash), - CompiledClassHash: nil, // for cairo0 it's nil + for _, classHash := range diff.DeclaredV0Classes { + responses = append(responses, &state.StateDiffsResponse{ + StateDiffMessage: &state.StateDiffsResponse_DeclaredClass{ + DeclaredClass: &state.DeclaredClass{ + ClassHash: core2p2p.AdaptHash(classHash), + CompiledClassHash: nil, // for cairo0 it's nil + }, }, - }, - }) - } - for classHash, compiledHash := range diff.DeclaredV1Classes { - responses = append(responses, &state.StateDiffsResponse{ - StateDiffMessage: &state.StateDiffsResponse_DeclaredClass{ - DeclaredClass: &state.DeclaredClass{ - ClassHash: core2p2p.AdaptHash(&classHash), - CompiledClassHash: core2p2p.AdaptHash(compiledHash), + }) + } + for classHash, compiledHash := range diff.DeclaredV1Classes { + responses = append(responses, &state.StateDiffsResponse{ + StateDiffMessage: &state.StateDiffsResponse_DeclaredClass{ + DeclaredClass: &state.DeclaredClass{ + ClassHash: core2p2p.AdaptHash(&classHash), + CompiledClassHash: core2p2p.AdaptHash(compiledHash), + }, }, - }, - }) - } + }) + } - return responses, nil - }) + return responses, nil + }) } func (h *Server) onClassesRequest( @@ -445,6 +453,7 @@ func (h *Server) onClassesRequest( finMsg := &syncclass.ClassesResponse{ ClassMessage: &syncclass.ClassesResponse_Fin{}, } +<<<<<<< HEAD return h.processIterationRequestMulti(req.Iteration, finMsg, func(it blockDataAccessor) ([]proto.Message, error) { block, err := it.Block() if err != nil { @@ -464,39 +473,63 @@ func (h *Server) onClassesRequest( defer func() { if closeErr := closer(); closeErr != nil { h.logger.Error("Failed to close state reader", zap.Error(closeErr)) +======= + return h.processIterationRequestMulti( + req.Iteration, + finMsg, + func(it blockDataAccessor) ([]proto.Message, error) { + block, err := it.Block() + if err != nil { + return nil, err +>>>>>>> 7470f4d1f (refactor(p2p): apply style correctness) } - }() + blockNumber := block.Number - stateDiff := 
stateUpdate.StateDiff - - var responses []proto.Message - for _, hash := range stateDiff.DeclaredV0Classes { - cls, err := stateReader.Class(hash) + stateUpdate, err := h.bcReader.StateUpdateByNumber(blockNumber) if err != nil { return nil, err } - responses = append(responses, &syncclass.ClassesResponse{ - ClassMessage: &syncclass.ClassesResponse_Class{ - Class: core2p2p.AdaptClass(cls.Class), - }, - }) - } - for classHash := range stateDiff.DeclaredV1Classes { - cls, err := stateReader.Class(&classHash) + stateReader, closer, err := h.bcReader.StateAtBlockNumber(blockNumber) if err != nil { return nil, err } + defer func() { + if closeErr := closer(); closeErr != nil { + h.log.Error("Failed to close state reader", zap.Error(closeErr)) + } + }() - responses = append(responses, &syncclass.ClassesResponse{ - ClassMessage: &syncclass.ClassesResponse_Class{ - Class: core2p2p.AdaptClass(cls.Class), - }, - }) - } + stateDiff := stateUpdate.StateDiff - return responses, nil - }) + var responses []proto.Message + for _, hash := range stateDiff.DeclaredV0Classes { + cls, err := stateReader.Class(hash) + if err != nil { + return nil, err + } + + responses = append(responses, &syncclass.ClassesResponse{ + ClassMessage: &syncclass.ClassesResponse_Class{ + Class: core2p2p.AdaptClass(cls.Class), + }, + }) + } + for classHash := range stateDiff.DeclaredV1Classes { + cls, err := stateReader.Class(&classHash) + if err != nil { + return nil, err + } + + responses = append(responses, &syncclass.ClassesResponse{ + ClassMessage: &syncclass.ClassesResponse_Class{ + Class: core2p2p.AdaptClass(cls.Class), + }, + }) + } + + return responses, nil + }) } // blockDataAccessor provides access to either entire block or header @@ -570,7 +603,8 @@ func (h *Server) processIterationRequestMulti(iteration *synccommon.Iteration, f return func(yield yieldFunc) { // while iterator is valid for it.Valid() { - // pass it to handler function (some might be interested in header, others in entire block) + 
// pass it to handler function; some might be interested in header, + // others in entire block messages, err := getMsg(it) if err != nil { if !errors.Is(err, db.ErrKeyNotFound) { @@ -586,7 +620,8 @@ func (h *Server) processIterationRequestMulti(iteration *synccommon.Iteration, f for _, msg := range messages { // push generated msg to caller if !yield(msg) { - // if caller is not interested in remaining data (example: connection to a peer is closed) exit + // if caller is not interested in the remaining data, exit. + // (example: connection to a peer is closed) // note that in this case we won't send finMsg return } From bf9b63d1cb98c6f4b180bc0d4c1b6bc981a4ba55 Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Wed, 25 Mar 2026 18:06:34 +0000 Subject: [PATCH 04/40] [wip] feat(consensus/propeller): full and messy propeller impl --- consensus/propeller/engine.go | 521 ++++++++++++++++++++++ consensus/propeller/engine_test.go | 372 +++++++++++++++ consensus/propeller/merkle/merkle.go | 168 +++++++ consensus/propeller/merkle/merkle_test.go | 243 ++++++++++ consensus/propeller/pool/pool.go | 26 ++ consensus/propeller/pool/pool_test.go | 0 consensus/propeller/processor.go | 343 ++++++++++++++ consensus/propeller/processor_test.go | 451 +++++++++++++++++++ consensus/propeller/propeller.go | 67 +++ consensus/propeller/propeller_test.go | 0 consensus/propeller/scheduler_test.go | 286 ++++++++++++ consensus/propeller/sharding.go | 113 +++++ consensus/propeller/sharding_test.go | 204 +++++++++ consensus/propeller/timecache.go | 73 +++ consensus/propeller/timecache_test.go | 122 +++++ consensus/propeller/types.go | 306 +++++++++++++ consensus/propeller/utils/padding.go | 60 +++ consensus/propeller/utils/padding_test.go | 108 +++++ consensus/propeller/utils/signing.go | 33 ++ consensus/propeller/utils/signing_test.go | 1 + consensus/propeller/validator.go | 171 +++++++ consensus/propeller/validator_test.go | 364 +++++++++++++++ 22 files changed, 4032 insertions(+) create mode 100644 
consensus/propeller/engine.go create mode 100644 consensus/propeller/engine_test.go create mode 100644 consensus/propeller/merkle/merkle.go create mode 100644 consensus/propeller/merkle/merkle_test.go create mode 100644 consensus/propeller/pool/pool.go create mode 100644 consensus/propeller/pool/pool_test.go create mode 100644 consensus/propeller/processor.go create mode 100644 consensus/propeller/processor_test.go create mode 100644 consensus/propeller/propeller.go create mode 100644 consensus/propeller/propeller_test.go create mode 100644 consensus/propeller/scheduler_test.go create mode 100644 consensus/propeller/sharding.go create mode 100644 consensus/propeller/sharding_test.go create mode 100644 consensus/propeller/timecache.go create mode 100644 consensus/propeller/timecache_test.go create mode 100644 consensus/propeller/types.go create mode 100644 consensus/propeller/utils/padding.go create mode 100644 consensus/propeller/utils/padding_test.go create mode 100644 consensus/propeller/utils/signing.go create mode 100644 consensus/propeller/utils/signing_test.go create mode 100644 consensus/propeller/validator.go create mode 100644 consensus/propeller/validator_test.go diff --git a/consensus/propeller/engine.go b/consensus/propeller/engine.go new file mode 100644 index 0000000000..19a3ff79c6 --- /dev/null +++ b/consensus/propeller/engine.go @@ -0,0 +1,521 @@ +package propeller + +import ( + "context" + "fmt" + + "github.com/NethermindEth/juno/utils" + "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/peer" + "go.uber.org/zap" +) + +// Channel buffer sizes for the engine's internal channels. These are large +// enough to absorb bursts without blocking, but bounded to prevent unbounded +// memory growth from slow consumers. 
+const ( + eventChSize = 256 + cleanupChSize = 256 + appEventChSize = 256 + cmdChSize = 64 +) + +type broadcastResult struct { + units []PropellerUnit + err error +} + +// todo(rdr): using String until I find a better type +type StakerID string + +type committeeState struct { + scheduler *Scheduler + peerKeys []StakerID +} + +// engineCommand is a tagged union of commands sent to the engine's Run() loop. +type engineCommand interface { + isCommand() +} + +type registerCommittee struct { + committeeID CommitteeID + peers []PeerCommittee + peersKeys []*StakerID + errCh chan error +} + +func (registerCommittee) isCommand() + +type cmdUnregister struct { + committeeID CommitteeID +} + +func (cmdUnregister) isCommand() + +type cmdBroadcast struct { + committeeID CommitteeID + msg []byte + errCh chan error +} + +func (cmdBroadcast) isCommand() + +type cmdHandleUnit struct { + unit *PropellerUnit + sender peer.ID +} + +func (cmdHandleUnit) isCommand() + +// Engine is the central orchestrator of the Propeller protocol. It: +// +// - Manages channel registrations (each channel has its own peer set and schedule). +// - Routes incoming PropellerUnits to the correct MessageProcessor. +// - Handles broadcast requests from the application layer. +// - Collects and forwards events from processors to the application. +// +// The engine is designed to be run as a single long-lived goroutine via Run(). +// External callers interact with it through thread-safe methods that send +// commands on internal channels, so no locks are needed on the hot path. +type Engine struct { + localPeer peer.ID + privKey crypto.PrivKey + config Config + log utils.Logger + // committees holds the Scheduler (i.e. 
Propeller Tree) and Stakers ID of + // the peers of each registered channel + committees map[CommitteeID]*committeeState + // connected peers hold all the connected peers to the engine + connectedPeers map[peer.ID]struct{} + + // whenever a broadcast action is started, units preparaition are done concurrently + // and delivered through this channel + unitsPrepared chan broadcastResult + + // processors maps each active message to its processor's shard input + // channel. The engine creates processors lazily on first shard receipt. + // Only accessed from the Run() goroutine, so no lock needed. + processors map[messageKey]chan<- shardDelivery + + // finalised tracks recently finalised messages to avoid re-creating + // processors for late-arriving shards. + finalised *TimeCache[messageKey] + + // eventCh is shared between all processors and the engine. The engine + // reads from it and forwards events to the application via Events(). + eventCh chan any + + // cleanupCh carries internal processor-done signals. This is separate + // from eventCh so that a full eventCh never blocks processor goroutines + // trying to signal completion, which would leak goroutines. + cleanupCh chan processorDone + + // appEventCh is the externally-visible event channel. The engine copies + // events from eventCh to appEventCh in its Run() loop, filtering out + // internal events as needed. + appEventCh chan any + + // cmdCh carries commands from external callers into the Run() loop. + cmdCh chan engineCommand + + // sendFn is the network callback for delivering units to peers. + // Injected at construction time for testability. + sendFn SendUnitFunc +} + +// NewEngine creates an engine instance. Call Run() to start processing. +// +// Parameters: +// - localPeer: this node's peer ID. +// - privKey: this node's Ed25519 private key (for signing published messages). +// - config: protocol parameters. +// - sendFn: callback for delivering PropellerUnits to peers over the network. 
+// - log: structured logger. +func NewEngine( + // todo(rdr): this should be a key pair + privKey crypto.PrivKey, + config *Config, + sendFn SendUnitFunc, + log utils.Logger, +) *Engine { + // todo(rdr): generate local peer id from keypair + return &Engine{ + localPeer: peer.ID("some random value for now"), + privKey: privKey, + config: *config, + log: log, + committees: make(map[CommitteeID]*committeeState), + connectedPeers: make(map[peer.ID]struct{}), + cmdCh: make(chan engineCommand, cmdChSize), + unitsPrepared: make(chan broadcastResult), + // Unsure of the fields below + processors: make(map[messageKey]chan<- shardDelivery), + finalised: NewTimeCache[messageKey](config.StaleMessageTimeout * 2), + eventCh: make(chan any, eventChSize), + cleanupCh: make(chan processorDone, cleanupChSize), + appEventCh: make(chan any, appEventChSize), + sendFn: sendFn, + } +} + +// registerCommittee creates the schedule and encoder for a new channel. +func (e *Engine) registerCommittee( + committeeID CommitteeID, + peers []PeerCommittee, + peersKeys []*StakerID, +) error { + if _, ok := e.committees[committeeID]; ok { + e.log.Warn( + "committee already registered, will ignore re-registration attempt", + zap.Uint64("committeeID", uint64(committeeID)), + ) + return nil + } + + stakerIDs := make([]StakerID, len(peersKeys)) + for i := range peersKeys { + if peersKeys[i] != nil { + stakerIDs[i] = *peersKeys[i] + } else { + // todo(rdr): re-check this flow once implementation is complete + panic("received nil key, they shoudln't be nil") + } + } + + schedule, err := NewScheduler(e.localPeer, peers) + if err != nil { + return fmt.Errorf("couldn't register a new committee: %w", err) + } + + e.committees[committeeID] = &committeeState{ + scheduler: schedule, + peerKeys: stakerIDs, + } + + e.log.Info("registered new committee", + zap.Uint64("channel", uint64(committeeID)), + zap.Int("peers", len(peers)), + zap.Int("dataShards", schedule.NumDataShards()), + zap.Int("codingShards", 
schedule.NumCodingShards()), + ) + + return nil +} + +// unregisterCommittee removes a channel's state. Not new processors will be started but +// currently running ones will continue until the timeout / stop naturally +func (e *Engine) unregisterCommittee(committeeID CommitteeID) { + delete(e.committees, committeeID) + + e.log.Info("unregistered propeller committee", + zap.Uint64("committee", uint64(committeeID)), + ) +} + +// prepareBroadcast creates Proppeller units asynchronously since it is a very expensive +// operation. +func (e *Engine) prepareBroadcast(committeeID CommitteeID, data []byte) error { + cs, ok := e.committees[committeeID] + if !ok { + return fmt.Errorf("cannot broadcast to an unregistered committee: %s", committeeID) + } + + // todo(rdr): consider having a maximum amount of working threads and a queue tasks for this + // This is an expensive operation, hence we need to do it separately + go func() { + units, err := CreatePropellerUnits( + committeeID, + data, + e.privKey, + cs.scheduler.NumDataShards(), + cs.scheduler.NumCodingShards(), + ) + e.unitsPrepared <- broadcastResult{ + units: units, + err: err, + } + }() + + return nil +} + +// broacast receives Propeller units (built in `prepareBroadcast`) and sends them +func (e *Engine) broadcast(units []PropellerUnit) error { + targetCommittee := units[0].CommitteeID + + cs, ok := e.committees[targetCommittee] + if !ok { + return fmt.Errorf("target committee ID not found: %d", targetCommittee) + } + + targetPeers := cs.scheduler.BroadcastTargets() + if len(targetPeers) != len(units) { + return fmt.Errorf( + "different amount of target peers and propeller units to broadcast: %d vs %d", + len(targetPeers), + len(units), + ) + } + + // todo(rdr): I need to do the actual sending + + return nil +} + +// doHandleUnit routes an incoming unit to the correct processor, creating +// one if needed. 
+func (e *Engine) doHandleUnit(ctx context.Context, cmd *cmdHandleUnit) { + unit := cmd.unit + key := messageKey{ + Channel: unit.CommitteeID, + Publisher: unit.Publisher, + Root: unit.MerkleRoot, + } + + // Skip already-finalised messages. + if e.finalised.Contains(key) { + return + } + + // Route to existing processor or create a new one. + shardCh, exists := e.processors[key] + if !exists { + shardCh = e.createProcessor(ctx, key, unit) + if shardCh == nil { + return // Channel not registered; logged inside createProcessor. + } + } + + // Non-blocking send to the processor. If its buffer is full, the shard + // is dropped (the processor can reconstruct from other shards). + select { + case shardCh <- shardDelivery{Unit: unit, Sender: cmd.sender}: + default: + e.log.Warn("dropping shard: processor channel full", + zap.Uint32("shard", uint32(unit.ShardIndex)), + zap.Stringer("publisher", unit.Publisher), + ) + } +} + +// createProcessor spins up a new MessageProcessor goroutine for a message +// we haven't seen before. +func (e *Engine) createProcessor( + ctx context.Context, key messageKey, unit *PropellerUnit, +) chan<- shardDelivery { + cs, ok := e.committees[unit.CommitteeID] + if !ok { + e.log.Warn("received unit for unregistered channel", + zap.Uint32("channel", uint32(unit.CommitteeID)), + ) + return nil + } + + validator := NewValidator(cs.scheduler, e.localPeer, &DefaultSignatureVerifier{}) + + // Buffer the shard channel so the engine doesn't block when delivering + // multiple shards in rapid succession. + shardCh := make(chan shardDelivery, cs.scheduler.NumShards()) + + proc := NewMessageProcessor( + key.Channel, + key.Publisher, + key.Root, + e.localPeer, + e.config, + cs.scheduler, + validator, + cs.encoder, + shardCh, + e.eventCh, + e.sendFn, + ) + + e.processors[key] = shardCh + + // Launch the processor goroutine. It will run until finalisation, + // timeout, or context cancellation. 
The cleanup signal goes to a + // dedicated channel so it cannot be blocked by a full eventCh. + go func() { + proc.Run(ctx) + select { + case e.cleanupCh <- processorDone{key: key}: + case <-ctx.Done(): + } + }() + + return shardCh +} + +// processorDone is an internal event signalling that a processor's goroutine +// has exited. The engine uses this to clean up the processors map. +type processorDone struct { + key messageKey +} + +// handleProcessorDone cleans up after a processor goroutine exits. +func (e *Engine) handleProcessorDone(done processorDone) { + delete(e.processors, done.key) + e.finalised.Add(done.key) + + // Periodically clean up expired entries in the time cache. + // Amortised cost: we do it on every processor exit, which is + // infrequent relative to shard processing. + e.finalised.Cleanup() +} + +// forwardEvent sends an event to the application's event channel. Non-blocking +// to avoid stalling the engine if the application is slow to consume events. +func (e *Engine) forwardEvent(event any) { + select { + case e.appEventCh <- event: + default: + e.log.Warn("dropping event: application event channel full") + } +} + +// handleCommand dispatches a command to the appropriate handler. +func (e *Engine) handleCommand(ctx context.Context, command engineCommand) { + switch cmd := command.(type) { + case *registerCommittee: + err := e.registerCommittee(cmd.committeeID, cmd.peers, cmd.peersKeys) + cmd.errCh <- err + case *cmdUnregister: + e.unregisterCommittee(cmd.committeeID) + case *cmdBroadcast: + err := e.prepareBroadcast(cmd.committeeID, cmd.msg) + cmd.errCh <- err + case *cmdHandleUnit: + e.doHandleUnit(ctx, cmd) + } +} + +// Run starts the engine's main loop. It blocks until the context is cancelled. +// This should be called in its own goroutine. +// +// The loop processes three things concurrently: +// 1. Commands from external callers (register, broadcast, handle incoming unit). +// 2. 
Events from message processors (forward to application). +// 3. Context cancellation (graceful shutdown). +func (e *Engine) Run(ctx context.Context) error { + for { + select { + case <-ctx.Done(): + return ctx.Err() + + case cmd := <-e.cmdCh: + e.handleCommand(ctx, cmd) + + case broadcastResult := <-e.unitsPrepared: + if broadcastResult.err != nil { + e.log.Error("couldn't prepare units", zap.Error(broadcastResult.err)) + } + e.broadcast(broadcastResult.units) + + case event := <-e.eventCh: + // Forward application-visible events from processors. + e.forwardEvent(event) + + case done := <-e.cleanupCh: + // Processor goroutine exited; clean up the processors map. + e.handleProcessorDone(done) + } + } +} + +// Probably unuseful code + +// Events returns the channel on which the application receives protocol +// events. The caller should read from this channel continuously to avoid +// backpressure on the engine. +// func (e *Engine) Events() <-chan any { +// return e.appEventCh +// } + +// RegisterCommitee registers a committee with its peer set. This must be called +// before broadcasting on or receiving shards for a committee. +// +// The method blocks until the command is processed by the engine's Run() loop. +// todo(rdr): I am not sure this method should exist, or at least be defined at engine level +// func (e *Engine) RegisterCommittee( +// ctx context.Context, +// committeeID CommitteeID, +// peers []peer.ID, +// ) error { +// errCh := make(chan error, 1) +// select { +// case e.cmdCh <- ®isterCommittee{ +// committeeID: committeeID, +// peers: peers, +// errCh: errCh, +// }: +// case <-ctx.Done(): +// return ctx.Err() +// } +// +// select { +// case err := <-errCh: +// return err +// case <-ctx.Done(): +// return ctx.Err() +// } +// } + +// // UnregisterCommittee removes a committee. Existing processors for that channel +// // will continue running until they finalise or time out, but no new +// // processors will be created. 
+// func (e *Engine) UnregisterCommittee(ctx context.Context, channel CommitteeID) error { +// select { +// case e.cmdCh <- &cmdUnregister{committeeID: channel}: +// return nil +// case <-ctx.Done(): +// return ctx.Err() +// } +// } + +// // Broadcast encodes and distributes a message to all peers in the given +// // channel. The local node acts as the publisher. +// // +// // The method blocks until the command is processed by the engine's Run() loop. +// func (e *Engine) Broadcast( +// ctx context.Context, channel CommitteeID, msg []byte, +// ) error { +// errCh := make(chan error, 1) +// select { +// case e.cmdCh <- &cmdBroadcast{ +// channel: channel, +// msg: msg, +// errCh: errCh, +// }: +// case <-ctx.Done(): +// return ctx.Err() +// } +// +// select { +// case err := <-errCh: +// return err +// case <-ctx.Done(): +// return ctx.Err() +// } +// } + +// // HandleUnit routes an incoming PropellerUnit from the network to the +// // appropriate message processor. This method is non-blocking: it sends +// // the unit to the engine's command channel. +// func (e *Engine) HandleUnit(unit *PropellerUnit, sender peer.ID) { +// // Non-blocking send: if the command channel is full, drop the unit. +// // This provides backpressure against flood attacks. The sender can +// // retry or the processor can reconstruct from other shards. 
+// select { +// case e.cmdCh <- &cmdHandleUnit{ +// unit: unit, +// sender: sender, +// }: +// default: +// e.log.Warn("dropping incoming unit: command channel full", +// zap.Uint32("shard", uint32(unit.ShardIndex)), +// zap.Stringer("publisher", unit.Publisher), +// ) +// } +// } diff --git a/consensus/propeller/engine_test.go b/consensus/propeller/engine_test.go new file mode 100644 index 0000000000..b92a4e67f0 --- /dev/null +++ b/consensus/propeller/engine_test.go @@ -0,0 +1,372 @@ +package propeller + +import ( + "bytes" + "context" + "crypto/ed25519" + "fmt" + "sync" + "testing" + "time" + + "github.com/NethermindEth/juno/utils" + "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// engineTestEnv provides the common setup for engine-level tests. +type engineTestEnv struct { + peers []peer.ID + privKeys []crypto.PrivKey + engines []*Engine + sentUnits map[peer.ID][]*PropellerUnit + sentMu sync.Mutex + log utils.Logger +} + +//nolint:unparam // n is always 4 in current tests but kept for flexibility +func newEngineTestEnv(t *testing.T, n int) *engineTestEnv { + t.Helper() + + peers := make([]peer.ID, n) + privKeys := make([]crypto.PrivKey, n) + for i := range n { + seed := make([]byte, ed25519.SeedSize) + seed[0] = byte(i) + reader := bytes.NewReader(seed) + priv, pub, err := crypto.GenerateEd25519Key(reader) + require.NoError(t, err) + id, err := peer.IDFromPublicKey(pub) + require.NoError(t, err) + privKeys[i] = priv + peers[i] = id + } + + log := utils.NewNopZapLogger() + + env := &engineTestEnv{ + peers: peers, + privKeys: privKeys, + sentUnits: make(map[peer.ID][]*PropellerUnit), + log: log, + } + + config := Config{ + StaleMessageTimeout: 5 * time.Second, + StreamProtocol: "/propeller/test/0.1.0", + MaxWireMessageSize: 1 << 20, + } + + engines := make([]*Engine, n) + for i := range n { + engines[i] = NewEngine( + peers[i], privKeys[i], 
config, + env.makeSendFn(), + log, + ) + } + env.engines = engines + + return env +} + +// makeSendFn creates a SendUnitFunc that records sent units. +func (env *engineTestEnv) makeSendFn() SendUnitFunc { + return func(_ context.Context, to peer.ID, unit *PropellerUnit) error { + env.sentMu.Lock() + env.sentUnits[to] = append(env.sentUnits[to], unit) + env.sentMu.Unlock() + return nil + } +} + +// getSentUnits returns all units sent to a given peer. +func (env *engineTestEnv) getSentUnits(to peer.ID) []*PropellerUnit { + env.sentMu.Lock() + defer env.sentMu.Unlock() + result := make([]*PropellerUnit, len(env.sentUnits[to])) + copy(result, env.sentUnits[to]) + return result +} + +func TestEngine_RegisterAndBroadcast(t *testing.T) { + env := newEngineTestEnv(t, 4) + + ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second) + defer cancel() + + engine := env.engines[0] + + // Run the engine in the background. + done := make(chan error, 1) + go func() { + done <- engine.Run(ctx) + }() + + // Register a channel with all peers. + err := engine.RegisterChannel(ctx, 1, env.peers) + require.NoError(t, err) + + // Broadcast a message. + msg := []byte("hello from engine test") + err = engine.Broadcast(ctx, 1, msg) + require.NoError(t, err) + + // Verify that units were sent to the other 3 peers. + // Give a moment for async processing. 
+ time.Sleep(100 * time.Millisecond) + + totalSent := 0 + for _, p := range env.peers { + if p == env.peers[0] { + continue + } + units := env.getSentUnits(p) + totalSent += len(units) + } + assert.Equal(t, 3, totalSent, "should send one unit to each non-publisher peer") + + cancel() + <-done +} + +func TestEngine_BroadcastUnregisteredChannel(t *testing.T) { + env := newEngineTestEnv(t, 4) + + ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) + defer cancel() + + engine := env.engines[0] + + go func() { + engine.Run(ctx) //nolint:errcheck // test helper + }() + + err := engine.Broadcast(ctx, 99, []byte("should fail")) + require.Error(t, err) + + var pubErr *ShardPublishError + require.ErrorAs(t, err, &pubErr) + assert.Equal(t, ReasonChannelNotRegistered, pubErr.Reason) + + cancel() +} + +func TestEngine_HandleUnit_CreatesProcessor(t *testing.T) { + env := newEngineTestEnv(t, 4) + + ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second) + defer cancel() + + // Set up engine for peer 0. + engine := env.engines[0] + + go func() { + engine.Run(ctx) //nolint:errcheck // test helper + }() + + // Register the channel. + err := engine.RegisterChannel(ctx, 1, env.peers) + require.NoError(t, err) + + // Simulate receiving a unit from peer 1 (as publisher). + schedule := NewScheduler(env.peers) + enc, err := NewEncoder(schedule.NumDataShards(), schedule.NumCodingShards()) + require.NoError(t, err) + + msg := []byte("incoming message") + units, root, err := EncodeMessage(msg, schedule, enc) + require.NoError(t, err) + + publisher := env.peers[1] + sig, err := SignRoot(root, env.privKeys[1]) + require.NoError(t, err) + + for i := range units { + units[i].Publisher = publisher + units[i].Signature = sig + units[i].CommitteeID = 1 + } + + // Send units from their correct senders. 
+ for i, unit := range units { + sender, err := schedule.PeerForShard(publisher, ShardIndex(i)) + require.NoError(t, err) + + // Skip units "from ourselves" -- the validator rejects those. + if sender == env.peers[0] { + continue + } + + unitCopy := unit + engine.HandleUnit(&unitCopy, sender) + } + + // Wait for the message to be processed and check events. + var received *EventMessageReceived + deadline := time.After(5 * time.Second) + for received == nil { + select { + case ev := <-engine.Events(): + if r, ok := ev.(EventMessageReceived); ok { + received = &r + } + case <-deadline: + t.Fatal("timed out waiting for EventMessageReceived") + } + } + + assert.Equal(t, msg, received.Message) + assert.Equal(t, publisher, received.Publisher) + assert.Equal(t, root, received.Root) + + cancel() +} + +func TestEngine_UnregisterChannel(t *testing.T) { + env := newEngineTestEnv(t, 4) + + ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) + defer cancel() + + engine := env.engines[0] + + go func() { + engine.Run(ctx) //nolint:errcheck // test helper + }() + + err := engine.RegisterChannel(ctx, 1, env.peers) + require.NoError(t, err) + + err = engine.UnregisterChannel(ctx, 1) + require.NoError(t, err) + + // Allow command to be processed. + time.Sleep(50 * time.Millisecond) + + // Broadcasting should fail now. + err = engine.Broadcast(ctx, 1, []byte("after unregister")) + require.Error(t, err) + + cancel() +} + +func TestEngine_HandleUnit_UnregisteredChannel(t *testing.T) { + env := newEngineTestEnv(t, 4) + + ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) + defer cancel() + + engine := env.engines[0] + + go func() { + engine.Run(ctx) //nolint:errcheck // test helper + }() + + // Send a unit for an unregistered channel. + unit := &PropellerUnit{ + CommitteeID: 99, + Publisher: env.peers[1], + MerkleRoot: MessageRoot{0x01}, + ShardIndex: 0, + ShardData: []byte("data"), + } + engine.HandleUnit(unit, env.peers[1]) + + // Allow time for processing. 
+ time.Sleep(100 * time.Millisecond) + + // No crash, no panic -- the unit is silently dropped. + cancel() +} + +func TestEngine_GracefulShutdown(t *testing.T) { + env := newEngineTestEnv(t, 4) + + ctx, cancel := context.WithCancel(t.Context()) + + engine := env.engines[0] + done := make(chan error, 1) + go func() { + done <- engine.Run(ctx) + }() + + cancel() + + select { + case err := <-done: + assert.ErrorIs(t, err, context.Canceled) + case <-time.After(2 * time.Second): + t.Fatal("engine did not shut down in time") + } +} + +func TestEngine_SendFailureEmitsEvent(t *testing.T) { + env := newEngineTestEnv(t, 4) + + ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) + defer cancel() + + // Create an engine with a failing send function. + engine := NewEngine( + env.peers[0], env.privKeys[0], + Config{ + StaleMessageTimeout: 5 * time.Second, + StreamProtocol: "/propeller/test/0.1.0", + MaxWireMessageSize: 1 << 20, + }, + func(_ context.Context, _ peer.ID, _ *PropellerUnit) error { + return fmt.Errorf("simulated network failure") + }, + utils.NewNopZapLogger(), + ) + + go func() { + engine.Run(ctx) //nolint:errcheck // test helper + }() + + err := engine.RegisterChannel(ctx, 1, env.peers) + require.NoError(t, err) + + err = engine.Broadcast(ctx, 1, []byte("will fail sending")) + require.NoError(t, err) // Broadcast itself succeeds; send failures are events. + + // Collect send failure events. 
+ deadline := time.After(2 * time.Second) + failures := 0 +loop: + for failures < 3 { + select { + case ev := <-engine.Events(): + if _, ok := ev.(EventShardSendFailed); ok { + failures++ + } + case <-deadline: + break loop + } + } + assert.Equal(t, 3, failures, "should have 3 send failures (one per non-publisher peer)") + + cancel() +} + +func TestEngine_RegisterChannelTooFewPeers(t *testing.T) { + env := newEngineTestEnv(t, 4) + + ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) + defer cancel() + + engine := env.engines[0] + + go func() { + engine.Run(ctx) //nolint:errcheck // test helper + }() + + // A single peer cannot form a channel (0 shards). + err := engine.RegisterChannel(ctx, 1, []peer.ID{env.peers[0]}) + require.Error(t, err) + + cancel() +} diff --git a/consensus/propeller/merkle/merkle.go b/consensus/propeller/merkle/merkle.go new file mode 100644 index 0000000000..65c542370a --- /dev/null +++ b/consensus/propeller/merkle/merkle.go @@ -0,0 +1,168 @@ +package merkle + +import ( + "crypto/sha256" + "math/bits" +) + +type Hash [32]byte + +// Proof contains the sibling hashes needed to verify that a leaf +// belongs to a Merkle tree with a known root. Siblings are ordered from +// leaf level (index 0) up to the root. +type Proof struct { + Siblings []Hash +} + +// Represents a Merkle Tree +type Tree []Proof + +// Merkle tree construction and verification using a specific SHA-256 tagging +// scheme. Tags prevent second-preimage attacks by domain-separating leaf +// hashes from internal node hashes. The exact tag format matches the Propeller +// protocol specification so that all implementations produce identical trees. +// +// Tree layout: leaves are at the bottom, padded to the next power-of-two +// with the hash of empty data. The tree is built bottom-up by hashing pairs. + +// merkleLeafHash computes: SHA256("" || data || "") +// +// The XML-like tags are the domain separator specified by the Propeller +// protocol. 
They ensure a leaf hash can never collide with a node hash, +// even if an attacker controls the data. +func merkleLeafHash(data []byte) Hash { + h := sha256.New() + h.Write([]byte("")) + h.Write(data) + h.Write([]byte("")) + var out [32]byte + h.Sum(out[:0]) + return out +} + +// merkleNodeHash computes: +// +// SHA256("" || left || "" || right || "") +// +// The nested tags ensure node hashes are in a separate domain from leaf hashes. +func merkleNodeHash(left, right [32]byte) Hash { + h := sha256.New() + h.Write([]byte("")) + h.Write(left[:]) + h.Write([]byte("")) + h.Write(right[:]) + h.Write([]byte("")) + var out [32]byte + h.Sum(out[:0]) + return out +} + +// nextPowerOfTwo returns the smallest power of two >= n, with a minimum of 2. +// A minimum of 2 ensures even a single-leaf tree has a sibling for its proof. +func nextPowerOfTwo(n int) int { + if n <= 2 { + return 2 + } + // bits.Len returns the position of the highest set bit + 1. + // Subtracting 1 before Len handles exact powers-of-two correctly. + return 1 << bits.Len(uint(n-1)) +} + +// emptyLeafHash is the hash of a padding leaf (no data). We precompute it +// because the same value is used repeatedly when the leaf count is not a +// power of two. +var emptyLeafHash = merkleLeafHash(nil) + +// New constructs a binary Merkle tree from the given leaf data +// and returns the root hash plus one inclusion proof per original leaf. +// +// The tree is padded to the next power-of-two size with empty leaves. This +// simplifies the proof logic: every node at every level has a sibling, and +// the proof path length is always log2(paddedSize). +// +// Returns a zero root and nil proofs if leaves is empty. +func New(leaves [][]byte) (root Hash, tree Tree) { + n := len(leaves) + if n == 0 { + // todo(rdr): maybe here we return a default merkle tree + return [32]byte{}, nil + } + + size := nextPowerOfTwo(n) + + // Build the bottom layer: hash each leaf, pad to power-of-two. 
+ layer := make([]Hash, size) + for i := range n { + layer[i] = merkleLeafHash(leaves[i]) + } + for i := n; i < size; i++ { + layer[i] = emptyLeafHash + } + + // proofSiblings[i] accumulates the sibling hashes for leaf i's proof. + // We collect them bottom-up as we build the tree. + proofSiblings := make([][]Hash, n) + + // Build the tree bottom-up, one level at a time. + for len(layer) > 1 { + nextLayer := make([]Hash, len(layer)/2) + for i := 0; i < len(layer); i += 2 { + left, right := layer[i], layer[i+1] + nextLayer[i/2] = merkleNodeHash(left, right) + + // Record siblings for any original leaves still tracked at + // this level. Leaf j at this level has its sibling at j^1 + // (XOR flips the last bit to get the pair partner). + for j := range n { + // Which position in the current layer does leaf j's + // ancestor occupy? It's j >> (current depth), but we + // track this implicitly: at depth d the ancestor of + // leaf j is at position j >> d. Since we've already + // collected d levels of siblings, d == len(proofSiblings[j]). + d := len(proofSiblings[j]) + ancestorPos := j >> d + if ancestorPos/2 == i/2 { + // This pair contains leaf j's ancestor. The sibling + // is the other element of the pair. + sibling := ancestorPos ^ 1 + proofSiblings[j] = append(proofSiblings[j], layer[sibling]) + } + } + } + layer = nextLayer + } + + root = layer[0] + + tree = make([]Proof, n) + for i := range n { + tree[i] = Proof{Siblings: proofSiblings[i]} + } + + return root, tree +} + +// VerifyProof checks that a leaf at the given index is included in a +// tree with the claimed root. The proof contains sibling hashes from the leaf +// level up to the root. +// +// The index determines the path through the tree: at each level, if the +// current bit of the index is 0 the current hash is the left child and the +// sibling is the right child, and vice versa. 
+func VerifyProof(root Hash, leaf []byte, index uint32, proof Proof) bool { + current := merkleLeafHash(leaf) + + idx := index + for _, sibling := range proof.Siblings { + if idx%2 == 0 { + // Current node is left child, sibling is right. + current = merkleNodeHash(current, sibling) + } else { + // Current node is right child, sibling is left. + current = merkleNodeHash(sibling, current) + } + idx /= 2 + } + + return current == root +} diff --git a/consensus/propeller/merkle/merkle_test.go b/consensus/propeller/merkle/merkle_test.go new file mode 100644 index 0000000000..e02236492b --- /dev/null +++ b/consensus/propeller/merkle/merkle_test.go @@ -0,0 +1,243 @@ +package merkle_test + +import ( + "crypto/sha256" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMerkleLeafHash(t *testing.T) { + data := []byte("hello") + hash := merkleLeafHash(data) + + // Manually compute expected: SHA256("hello") + h := sha256.New() + h.Write([]byte("hello")) + var expected [32]byte + h.Sum(expected[:0]) + + assert.Equal(t, expected, hash) +} + +func TestMerkleLeafHash_Empty(t *testing.T) { + hash := merkleLeafHash(nil) + + h := sha256.New() + h.Write([]byte("")) + var expected [32]byte + h.Sum(expected[:0]) + + assert.Equal(t, expected, hash) +} + +func TestMerkleNodeHash(t *testing.T) { + left := merkleLeafHash([]byte("L")) + right := merkleLeafHash([]byte("R")) + node := merkleNodeHash(left, right) + + h := sha256.New() + h.Write([]byte("")) + h.Write(left[:]) + h.Write([]byte("")) + h.Write(right[:]) + h.Write([]byte("")) + var expected [32]byte + h.Sum(expected[:0]) + + assert.Equal(t, expected, node) +} + +func TestNextPowerOfTwo(t *testing.T) { + tests := []struct { + n int + expected int + }{ + {0, 2}, + {1, 2}, + {2, 2}, + {3, 4}, + {4, 4}, + {5, 8}, + {7, 8}, + {8, 8}, + {9, 16}, + {16, 16}, + {17, 32}, + } + + for _, tc := range tests { + assert.Equal(t, tc.expected, nextPowerOfTwo(tc.n), "nextPowerOfTwo(%d)", tc.n) 
+ } +} + +func TestBuildMerkleTree_Empty(t *testing.T) { + root, proofs := BuildMerkleTree(nil) + assert.Equal(t, [32]byte{}, root) + assert.Nil(t, proofs) +} + +func TestBuildMerkleTree_SingleLeaf(t *testing.T) { + leaves := [][]byte{[]byte("only")} + root, proofs := BuildMerkleTree(leaves) + + require.Len(t, proofs, 1) + + // With one leaf padded to 2, the tree is: + // root + // / \ + // leaf0 empty + leafHash := merkleLeafHash([]byte("only")) + expectedRoot := merkleNodeHash(leafHash, emptyLeafHash) + assert.Equal(t, expectedRoot, root) + + // Proof for leaf 0 should contain the empty leaf as sibling. + assert.Len(t, proofs[0].Siblings, 1) + assert.Equal(t, emptyLeafHash, proofs[0].Siblings[0]) +} + +func TestBuildMerkleTree_TwoLeaves(t *testing.T) { + leaves := [][]byte{[]byte("A"), []byte("B")} + root, proofs := BuildMerkleTree(leaves) + + require.Len(t, proofs, 2) + + h0 := merkleLeafHash([]byte("A")) + h1 := merkleLeafHash([]byte("B")) + expectedRoot := merkleNodeHash(h0, h1) + assert.Equal(t, expectedRoot, root) + + // Leaf 0's sibling is leaf 1. + assert.Equal(t, h1, proofs[0].Siblings[0]) + // Leaf 1's sibling is leaf 0. 
+ assert.Equal(t, h0, proofs[1].Siblings[0]) +} + +func TestBuildMerkleTree_FourLeaves(t *testing.T) { + leaves := [][]byte{[]byte("A"), []byte("B"), []byte("C"), []byte("D")} + root, proofs := BuildMerkleTree(leaves) + + require.Len(t, proofs, 4) + + // Build expected tree manually: + // root + // / \ + // n01 n23 + // / \ / \ + // h0 h1 h2 h3 + h0 := merkleLeafHash([]byte("A")) + h1 := merkleLeafHash([]byte("B")) + h2 := merkleLeafHash([]byte("C")) + h3 := merkleLeafHash([]byte("D")) + n01 := merkleNodeHash(h0, h1) + n23 := merkleNodeHash(h2, h3) + expectedRoot := merkleNodeHash(n01, n23) + assert.Equal(t, expectedRoot, root) + + // Proof for leaf 0: siblings [h1, n23] + require.Len(t, proofs[0].Siblings, 2) + assert.Equal(t, h1, proofs[0].Siblings[0]) + assert.Equal(t, n23, proofs[0].Siblings[1]) + + // Proof for leaf 2: siblings [h3, n01] + require.Len(t, proofs[2].Siblings, 2) + assert.Equal(t, h3, proofs[2].Siblings[0]) + assert.Equal(t, n01, proofs[2].Siblings[1]) +} + +func TestBuildMerkleTree_ThreeLeaves(t *testing.T) { + // Three leaves means padding to 4: the fourth leaf is empty. + leaves := [][]byte{[]byte("A"), []byte("B"), []byte("C")} + root, proofs := BuildMerkleTree(leaves) + + require.Len(t, proofs, 3) + + h0 := merkleLeafHash([]byte("A")) + h1 := merkleLeafHash([]byte("B")) + h2 := merkleLeafHash([]byte("C")) + h3 := emptyLeafHash + n01 := merkleNodeHash(h0, h1) + n23 := merkleNodeHash(h2, h3) + expectedRoot := merkleNodeHash(n01, n23) + assert.Equal(t, expectedRoot, root) + + // Proof for leaf 2: siblings [emptyLeaf, n01] + require.Len(t, proofs[2].Siblings, 2) + assert.Equal(t, h3, proofs[2].Siblings[0]) + assert.Equal(t, n01, proofs[2].Siblings[1]) +} + +func TestVerifyMerkleProof_ValidProofs(t *testing.T) { + // Build a tree with several leaves, then verify every proof. 
+ data := [][]byte{ + []byte("alpha"), + []byte("bravo"), + []byte("charlie"), + []byte("delta"), + []byte("echo"), + } + root, proofs := BuildMerkleTree(data) + require.Len(t, proofs, len(data)) + + for i, d := range data { + ok := VerifyMerkleProof(root, d, uint32(i), proofs[i]) + assert.True(t, ok, "proof for leaf %d should verify", i) + } +} + +func TestVerifyMerkleProof_WrongData(t *testing.T) { + leaves := [][]byte{[]byte("real"), []byte("data")} + root, proofs := BuildMerkleTree(leaves) + + // Tamper with the data. + ok := VerifyMerkleProof(root, []byte("fake"), 0, proofs[0]) + assert.False(t, ok, "tampered data should not verify") +} + +func TestVerifyMerkleProof_WrongIndex(t *testing.T) { + leaves := [][]byte{[]byte("A"), []byte("B"), []byte("C"), []byte("D")} + root, proofs := BuildMerkleTree(leaves) + + // Use leaf 0's data with leaf 1's index. + ok := VerifyMerkleProof(root, []byte("A"), 1, proofs[0]) + assert.False(t, ok, "wrong index should not verify") +} + +func TestVerifyMerkleProof_WrongRoot(t *testing.T) { + leaves := [][]byte{[]byte("A"), []byte("B")} + _, proofs := BuildMerkleTree(leaves) + + fakeRoot := [32]byte{0xff} + ok := VerifyMerkleProof(fakeRoot, []byte("A"), 0, proofs[0]) + assert.False(t, ok, "wrong root should not verify") +} + +func TestVerifyMerkleProof_TamperedSibling(t *testing.T) { + leaves := [][]byte{[]byte("A"), []byte("B")} + root, _ := BuildMerkleTree(leaves) + + // Tamper with a sibling hash in the proof. + badProof := MerkleProof{Siblings: [][32]byte{{0xde, 0xad}}} + ok := VerifyMerkleProof(root, []byte("A"), 0, badProof) + assert.False(t, ok, "tampered sibling should not verify") +} + +func TestBuildAndVerify_LargeTree(t *testing.T) { + // Build a tree with a non-power-of-two count to exercise padding. 
+ n := 31 + leaves := make([][]byte, n) + for i := range n { + leaves[i] = []byte{byte(i), byte(i >> 8)} + } + + root, proofs := BuildMerkleTree(leaves) + require.Len(t, proofs, n) + + for i, leaf := range leaves { + assert.True(t, + VerifyMerkleProof(root, leaf, uint32(i), proofs[i]), + "proof for leaf %d in 31-leaf tree should verify", i, + ) + } +} diff --git a/consensus/propeller/pool/pool.go b/consensus/propeller/pool/pool.go new file mode 100644 index 0000000000..d4282891db --- /dev/null +++ b/consensus/propeller/pool/pool.go @@ -0,0 +1,26 @@ +package pool + +import ( + "context" + "time" +) + +type Pool[T any] struct { + ctx context.Context + taskTimeout time.Duration + activeWorkers uint64 + maxWorkers uint64 +} + +func New[T any]( + ctx context.Context, + taskTimeout time.Duration, + maxWorkers uint64, +) *Pool[T] { + return &Pool[T]{ + ctx: ctx, + taskTimeout: taskTimeout, + activeWorkers: 0, + maxWorkers: maxWorkers, + } +} diff --git a/consensus/propeller/pool/pool_test.go b/consensus/propeller/pool/pool_test.go new file mode 100644 index 0000000000..e69de29bb2 diff --git a/consensus/propeller/processor.go b/consensus/propeller/processor.go new file mode 100644 index 0000000000..33a2b06c13 --- /dev/null +++ b/consensus/propeller/processor.go @@ -0,0 +1,343 @@ +package propeller + +import ( + "context" + "fmt" + "time" + + "github.com/libp2p/go-libp2p/core/peer" +) + +// processorState tracks which phase of the message lifecycle a processor is in. +// The transitions are strictly one-directional: +// +// PreConstruction -> PostConstruction -> Finalised +// +// There is also a direct path from either state to Finalised via timeout. +type processorState int + +const ( + // statePreConstruction: collecting shards, waiting to reach the build + // threshold so we can reconstruct the original message. + statePreConstruction processorState = iota + + // statePostConstruction: message has been reconstructed. 
We continue + // counting incoming shards until we hit the receive threshold, which + // guarantees that enough honest nodes have our shard to ensure all + // other honest nodes can also reconstruct. + statePostConstruction + + // stateFinalised: terminal state. The processor emits a result event + // and stops accepting shards. The engine should clean up this processor. + stateFinalised +) + +// SendUnitFunc is called by the processor to send a PropellerUnit to a +// specific peer. The engine provides this callback, which handles the +// actual network I/O. The processor doesn't know or care how delivery works. +type SendUnitFunc func(ctx context.Context, to peer.ID, unit *PropellerUnit) error + +// shardDelivery bundles an incoming shard with the peer that sent it, +// so the processor can validate the sender identity. +type shardDelivery struct { + Unit *PropellerUnit + Sender peer.ID +} + +// MessageProcessor manages the lifecycle of a single message identified by +// (channel, publisher, root). It runs as a goroutine that: +// +// 1. Accepts validated shards via its input channel. +// 2. In PreConstruction: collects shards until the build threshold is met, +// then reconstructs the message via Reed-Solomon. +// 3. In PostConstruction: counts additional shards until the receive +// threshold is met, then emits the message to the application. +// 4. On timeout: emits a timeout event and finalises. +// +// The processor is deliberately simple -- it owns no locks and communicates +// entirely through channels. All mutable state is confined to its goroutine. +type MessageProcessor struct { + // Identity and configuration. + channel CommitteeID + publisher peer.ID + root MessageRoot + localPeer peer.ID + config Config + + // Dependencies (injected for testability). + schedule *Scheduler + validator *Validator + encoder Encoder + + // State. 
+ state processorState + shards [][]byte // indexed by ShardIndex, nil = not yet received + seenShards map[ShardIndex]bool + receivedCount int + signatureVerified bool + storedSignature []byte // cached from the first valid unit + reconstructedMsg []byte + myShardUnit *PropellerUnit // the unit we are responsible for forwarding + + // Channels. + shardCh chan shardDelivery // incoming shards from the engine + eventCh chan<- any // outgoing events to the engine/application + sendFn SendUnitFunc // callback for sending units to peers +} + +// NewMessageProcessor creates a processor for a specific message. The caller +// must call Run() in a goroutine to start processing. +// +// Parameters: +// - shardCh: the engine writes incoming shards here. Buffered to prevent +// blocking the engine's main loop. +// - eventCh: the processor writes lifecycle events here (shared with other +// processors; the engine reads from it). +// - sendFn: callback for network delivery of units to peers. +func NewMessageProcessor( + channel CommitteeID, + publisher peer.ID, + root MessageRoot, + localPeer peer.ID, + config Config, + schedule *Scheduler, + validator *Validator, + encoder Encoder, + shardCh chan shardDelivery, + eventCh chan<- any, + sendFn SendUnitFunc, +) *MessageProcessor { + return &MessageProcessor{ + channel: channel, + publisher: publisher, + root: root, + localPeer: localPeer, + config: config, + schedule: schedule, + validator: validator, + encoder: encoder, + state: statePreConstruction, + shards: make([][]byte, schedule.NumShards()), + seenShards: make(map[ShardIndex]bool), + shardCh: shardCh, + eventCh: eventCh, + sendFn: sendFn, + } +} + +// Run is the processor's main loop. It blocks until the processor finalises +// (either by completing the protocol or timing out) or the context is cancelled. +// +// The select on shardCh vs timer is the core of the state machine. 
We +// intentionally use a single goroutine to avoid any need for synchronisation +// on the processor's internal state. +func (p *MessageProcessor) Run(ctx context.Context) { + timer := time.NewTimer(p.config.StaleMessageTimeout) + defer timer.Stop() + + for { + select { + case <-ctx.Done(): + return + + case delivery, ok := <-p.shardCh: + if !ok { + // Channel closed by engine; processor is being torn down. + return + } + if p.state == stateFinalised { + return + } + p.handleShard(ctx, delivery) + if p.state == stateFinalised { + return + } + + case <-timer.C: + if p.state != stateFinalised { + p.emitEvent(EventMessageTimeout{ + Channel: p.channel, + Publisher: p.publisher, + Root: p.root, + }) + p.state = stateFinalised + } + return + } + } +} + +// handleShard processes a single incoming shard delivery. +func (p *MessageProcessor) handleShard(ctx context.Context, delivery shardDelivery) { + unit := delivery.Unit + + // Validate the unit. + if err := p.validator.ValidateUnit( + unit, delivery.Sender, p.seenShards, p.signatureVerified, + ); err != nil { + p.emitEvent(EventShardValidationFailed{ + Sender: delivery.Sender, + ClaimedRoot: unit.MerkleRoot, + ClaimedPublisher: unit.Publisher, + Err: err, + }) + return + } + + // Mark the shard as received and store its data. + p.seenShards[unit.ShardIndex] = true + p.shards[unit.ShardIndex] = unit.ShardData + p.receivedCount++ + + // Cache the signature from the first valid unit. All units for the same + // message carry the same publisher signature, so we only need one copy. + // We store it here rather than in the unit slice because we only keep + // shard data (not full units) to save memory. 
+ if !p.signatureVerified { + p.storedSignature = make([]byte, len(unit.Signature)) + copy(p.storedSignature, unit.Signature) + } + p.signatureVerified = true + + switch p.state { + case statePreConstruction: + p.handlePreConstruction(ctx) + case statePostConstruction: + p.handlePostConstruction() + case stateFinalised: + // Should not reach here due to early return in Run, but be safe. + } +} + +// handlePreConstruction checks if we have enough shards to reconstruct. +func (p *MessageProcessor) handlePreConstruction(ctx context.Context) { + if p.receivedCount < p.schedule.BuildThreshold() { + return + } + + // Attempt Reed-Solomon reconstruction. + // We pass copies of the shard data because Reconstruct modifies the + // slice in-place, and we don't want to corrupt our stored references. + shardsCopy := make([][]byte, len(p.shards)) + for i, s := range p.shards { + if s != nil { + c := make([]byte, len(s)) + copy(c, s) + shardsCopy[i] = c + } + } + + msg, err := ReconstructMessage(shardsCopy, p.schedule, p.encoder, p.root) + if err != nil { + p.emitEvent(EventReconstructionFailed{ + Root: p.root, + Publisher: p.publisher, + Err: err, + }) + p.state = stateFinalised + return + } + + // Find our assigned shard so we can forward it to all other peers. + myShard, err := p.schedule.ShardForPeer(p.publisher, p.localPeer) + if err != nil { + p.emitEvent(EventReconstructionFailed{ + Root: p.root, + Publisher: p.publisher, + Err: fmt.Errorf("determining my shard assignment: %w", err), + }) + p.state = stateFinalised + return + } + + // Rebuild the Merkle tree from the complete shard set to get a valid + // proof for our shard. We may not have received our own shard from + // the network, so we need the proof from the reconstructed data. 
+ leaves := make([][]byte, len(shardsCopy)) + copy(leaves, shardsCopy) + _, proofs := BuildMerkleTree(leaves) + + p.myShardUnit = &PropellerUnit{ + CommitteeID: p.channel, + Publisher: p.publisher, + MerkleRoot: p.root, + Signature: p.storedSignature, + ShardIndex: myShard, + ShardData: shardsCopy[myShard], + MerkleProof: proofs[myShard], + } + + p.reconstructedMsg = msg + + // Replace our sparse shard data with the fully reconstructed set. + p.shards = shardsCopy + + // Count our own shard as held if we didn't receive it from the network. + if !p.seenShards[myShard] { + p.seenShards[myShard] = true + p.receivedCount++ + } + + p.state = statePostConstruction + + // Broadcast our shard to all other peers (except the publisher, who + // already has all shards). + p.broadcastMyShard(ctx) + + // Check if we already meet the receive threshold (possible if many + // shards arrived before reconstruction completed). + p.handlePostConstruction() +} + +// handlePostConstruction checks if the receive threshold has been met. +func (p *MessageProcessor) handlePostConstruction() { + if p.receivedCount < p.schedule.ReceiveThreshold() { + return + } + + p.emitEvent(EventMessageReceived{ + Publisher: p.publisher, + Root: p.root, + Message: p.reconstructedMsg, + }) + p.state = stateFinalised +} + +// broadcastMyShard sends our assigned shard to all peers except the publisher. +// Failures are reported as events but do not stop the broadcast to other peers. +func (p *MessageProcessor) broadcastMyShard(ctx context.Context) { + targets, err := p.schedule.BroadcastTargets(p.publisher) + if err != nil { + p.emitEvent(EventShardPublishFailed{ + Err: fmt.Errorf("getting broadcast targets: %w", err), + }) + return + } + + for _, target := range targets { + if target == p.localPeer { + // Don't send to ourselves. 
+ continue + } + + if err := p.sendFn(ctx, target, p.myShardUnit); err != nil { + p.emitEvent(EventShardSendFailed{ + From: p.localPeer, + To: target, + Err: err, + }) + } + } +} + +// emitEvent sends an event to the application layer. Uses a non-blocking send +// so a slow consumer doesn't block the processor. The engine's event channel +// should be large enough that this rarely drops. +func (p *MessageProcessor) emitEvent(event any) { + select { + case p.eventCh <- event: + default: + // Event channel is full. This should be rare with a properly sized + // buffer. The event is lost, but the processor continues operating. + } +} diff --git a/consensus/propeller/processor_test.go b/consensus/propeller/processor_test.go new file mode 100644 index 0000000000..6c2f5342e5 --- /dev/null +++ b/consensus/propeller/processor_test.go @@ -0,0 +1,451 @@ +package propeller + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// processorTestEnv bundles the common setup for processor tests. It creates +// a realistic N-peer environment with a real Reed-Solomon encoder. 
+type processorTestEnv struct { + peers []peer.ID + privKeys map[peer.ID]crypto.PrivKey + schedule *Scheduler + encoder Encoder + config Config + eventCh chan any + sentUnits []sentUnit + sentMu sync.Mutex +} + +type sentUnit struct { + To peer.ID + Unit *PropellerUnit +} + +func newProcessorTestEnv(t *testing.T, n int) *processorTestEnv { + t.Helper() + + rawPeers := make([]peer.ID, n) + privKeys := make(map[peer.ID]crypto.PrivKey, n) + for i := range n { + priv, id := realPeer(byte(i)) + rawPeers[i] = id + privKeys[id] = priv + } + + schedule := NewScheduler(rawPeers) + + enc, err := NewEncoder(schedule.NumDataShards(), schedule.NumCodingShards()) + require.NoError(t, err) + + return &processorTestEnv{ + peers: schedule.Peers(), // Use sorted order. + privKeys: privKeys, + schedule: schedule, + encoder: enc, + config: Config{ + StaleMessageTimeout: 5 * time.Second, + }, + eventCh: make(chan any, 100), + } +} + +// sendFunc records sent units for later inspection. +func (env *processorTestEnv) sendFunc() SendUnitFunc { + return func(_ context.Context, to peer.ID, unit *PropellerUnit) error { + env.sentMu.Lock() + defer env.sentMu.Unlock() + env.sentUnits = append(env.sentUnits, sentUnit{To: to, Unit: unit}) + return nil + } +} + +// encodeTestMessage encodes a message from the given publisher and returns +// the signed units and root. +func (env *processorTestEnv) encodeTestMessage( + t *testing.T, publisher peer.ID, msg []byte, +) ([]PropellerUnit, MessageRoot) { + t.Helper() + + units, root, err := EncodeMessage(msg, env.schedule, env.encoder) + require.NoError(t, err) + + privKey, ok := env.privKeys[publisher] + require.True(t, ok, "no private key for publisher %s", publisher) + + sig, err := SignRoot(root, privKey) + require.NoError(t, err) + + for i := range units { + units[i].Publisher = publisher + units[i].Signature = sig + units[i].CommitteeID = 1 + } + + return units, root +} + +// drainEvents reads all currently available events from the event channel. 
+func (env *processorTestEnv) drainEvents() []any { + var events []any + for { + select { + case ev := <-env.eventCh: + events = append(events, ev) + default: + return events + } + } +} + +func TestProcessor_FullLifecycle(t *testing.T) { + // Simulate a 7-node network. localPeer (sorted[0]) receives shards from + // a message published by sorted[1]. With 7 peers: 2 data shards, + // 4 coding shards. Build threshold = 2, receive threshold = 4. + env := newProcessorTestEnv(t, 7) + + localPeer := env.peers[0] + publisher := env.peers[1] + msg := []byte("hello propeller protocol") + + units, root := env.encodeTestMessage(t, publisher, msg) + + validator := NewValidator( + env.schedule, localPeer, &DefaultSignatureVerifier{}, + ) + + shardCh := make(chan shardDelivery, 20) + proc := NewMessageProcessor( + 1, publisher, root, localPeer, env.config, + env.schedule, validator, env.encoder, + shardCh, env.eventCh, env.sendFunc(), + ) + + ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) + defer cancel() + + // Run the processor in a goroutine. + done := make(chan struct{}) + go func() { + proc.Run(ctx) + close(done) + }() + + // Feed shards one at a time, from their correct senders. + // Skip shards "from" localPeer (the validator rejects self-sends). + for i, unit := range units { + sender, err := env.schedule.PeerForShard(publisher, ShardIndex(i)) + require.NoError(t, err) + + if sender == localPeer { + continue + } + + unitCopy := unit + shardCh <- shardDelivery{Unit: &unitCopy, Sender: sender} + } + + // Wait for the processor to finalise. + select { + case <-done: + case <-time.After(3 * time.Second): + t.Fatal("processor did not finalise in time") + } + + // Check that we got a MessageReceived event. 
+ events := env.drainEvents() + var received *EventMessageReceived + for _, ev := range events { + if r, ok := ev.(EventMessageReceived); ok { + received = &r + break + } + } + require.NotNil(t, received, + "expected EventMessageReceived, got %d events", len(events)) + assert.Equal(t, msg, received.Message) + assert.Equal(t, publisher, received.Publisher) + assert.Equal(t, root, received.Root) + + // Check that our shard was broadcast to other peers. + env.sentMu.Lock() + defer env.sentMu.Unlock() + assert.NotEmpty(t, env.sentUnits, "should have broadcast our shard") +} + +func TestProcessor_ReconstructionFromMinimumShards(t *testing.T) { + // With 4 peers: 1 data shard, 2 coding shards. + // Build threshold = 1, receive threshold = 2 (N>3). + // After reconstruction, the processor counts its own shard (=1), + // so it needs at least 1 more from the network to reach 2. + env := newProcessorTestEnv(t, 4) + + localPeer := env.peers[0] + publisher := env.peers[1] + msg := []byte("minimum shards test") + + units, root := env.encodeTestMessage(t, publisher, msg) + + validator := NewValidator( + env.schedule, localPeer, &DefaultSignatureVerifier{}, + ) + + shardCh := make(chan shardDelivery, 10) + proc := NewMessageProcessor( + 1, publisher, root, localPeer, env.config, + env.schedule, validator, env.encoder, + shardCh, env.eventCh, env.sendFunc(), + ) + + ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) + defer cancel() + + done := make(chan struct{}) + go func() { + proc.Run(ctx) + close(done) + }() + + // Send all non-local shards. With 3 shards total and localPeer holding + // one slot, we have 2 shards to send. receive threshold = 2, and + // after reconstruction the processor holds its own shard (+1), so + // it needs 1 from the network + 1 own = 2 to finalise. 
+ for i, unit := range units { + sender, err := env.schedule.PeerForShard(publisher, ShardIndex(i)) + require.NoError(t, err) + + if sender == localPeer { + continue + } + + unitCopy := unit + shardCh <- shardDelivery{Unit: &unitCopy, Sender: sender} + } + + select { + case <-done: + case <-time.After(3 * time.Second): + t.Fatal("processor did not finalise in time") + } + + events := env.drainEvents() + var received *EventMessageReceived + for _, ev := range events { + if r, ok := ev.(EventMessageReceived); ok { + received = &r + break + } + } + require.NotNil(t, received, "expected EventMessageReceived") + assert.Equal(t, msg, received.Message) +} + +func TestProcessor_Timeout(t *testing.T) { + env := newProcessorTestEnv(t, 4) + // Use a very short timeout for the test. + env.config.StaleMessageTimeout = 50 * time.Millisecond + + localPeer := env.peers[0] + publisher := env.peers[1] + + _, root := env.encodeTestMessage(t, publisher, []byte("will timeout")) + + validator := NewValidator( + env.schedule, localPeer, &DefaultSignatureVerifier{}, + ) + + shardCh := make(chan shardDelivery, 10) + proc := NewMessageProcessor( + 42, publisher, root, localPeer, env.config, + env.schedule, validator, env.encoder, + shardCh, env.eventCh, env.sendFunc(), + ) + + ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second) + defer cancel() + + done := make(chan struct{}) + go func() { + proc.Run(ctx) + close(done) + }() + + // Don't send any shards -- just wait for timeout. 
+ select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("processor did not finalise after timeout") + } + + events := env.drainEvents() + var timeout *EventMessageTimeout + for _, ev := range events { + if to, ok := ev.(EventMessageTimeout); ok { + timeout = &to + break + } + } + require.NotNil(t, timeout, + "expected EventMessageTimeout, got %d events", len(events)) + assert.Equal(t, CommitteeID(42), timeout.Channel) + assert.Equal(t, publisher, timeout.Publisher) + assert.Equal(t, root, timeout.Root) +} + +func TestProcessor_DuplicateShardRejected(t *testing.T) { + // Use 10 peers so the processor doesn't finalise after the first shard. + // N=10: numDataShards=3, receiveThreshold=6. + // Sending just one shard (receivedCount=1) is far below the threshold, + // so the processor stays in PreConstruction and will reject the duplicate. + env := newProcessorTestEnv(t, 10) + + localPeer := env.peers[0] + publisher := env.peers[1] + msg := []byte("test duplicates") + + units, root := env.encodeTestMessage(t, publisher, msg) + + validator := NewValidator( + env.schedule, localPeer, &DefaultSignatureVerifier{}, + ) + + shardCh := make(chan shardDelivery, 10) + proc := NewMessageProcessor( + 1, publisher, root, localPeer, env.config, + env.schedule, validator, env.encoder, + shardCh, env.eventCh, env.sendFunc(), + ) + + ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) + defer cancel() + + done := make(chan struct{}) + go func() { + proc.Run(ctx) + close(done) + }() + + // Find a shard not from localPeer. + var targetUnit PropellerUnit + var targetSender peer.ID + for i, unit := range units { + sender, err := env.schedule.PeerForShard(publisher, ShardIndex(i)) + require.NoError(t, err) + if sender != localPeer { + targetUnit = unit + targetSender = sender + break + } + } + + // Send the same shard twice, back-to-back. 
The processor handles them + // sequentially (single goroutine), so the second one will see the first + // already in seenShards. + shardCh <- shardDelivery{Unit: &targetUnit, Sender: targetSender} + + dup := targetUnit + shardCh <- shardDelivery{Unit: &dup, Sender: targetSender} + + // Give the processor time to handle both deliveries. + time.Sleep(200 * time.Millisecond) + cancel() + <-done + + // Check that we got a validation failure event for the duplicate. + events := env.drainEvents() + var validationFailed bool + for _, ev := range events { + if vf, ok := ev.(EventShardValidationFailed); ok { + var valErr *ShardValidationError + if asErr, ok := vf.Err.(*ShardValidationError); ok { + valErr = asErr + } + if valErr != nil && valErr.Reason == ReasonDuplicateShard { + validationFailed = true + break + } + } + } + assert.True(t, validationFailed, "expected duplicate shard to be rejected") +} + +func TestProcessor_ContextCancellation(t *testing.T) { + env := newProcessorTestEnv(t, 4) + + localPeer := env.peers[0] + publisher := env.peers[1] + + validator := NewValidator( + env.schedule, localPeer, &DefaultSignatureVerifier{}, + ) + + root := MessageRoot{0x01} + shardCh := make(chan shardDelivery, 10) + proc := NewMessageProcessor( + 1, publisher, root, localPeer, env.config, + env.schedule, validator, env.encoder, + shardCh, env.eventCh, env.sendFunc(), + ) + + ctx, cancel := context.WithCancel(t.Context()) + + done := make(chan struct{}) + go func() { + proc.Run(ctx) + close(done) + }() + + // Cancel immediately. + cancel() + + select { + case <-done: + // Good, processor exited. 
+ case <-time.After(1 * time.Second): + t.Fatal("processor did not exit on context cancellation") + } +} + +func TestProcessor_ChannelClose(t *testing.T) { + env := newProcessorTestEnv(t, 4) + + localPeer := env.peers[0] + publisher := env.peers[1] + + validator := NewValidator( + env.schedule, localPeer, &DefaultSignatureVerifier{}, + ) + + root := MessageRoot{0x02} + shardCh := make(chan shardDelivery, 10) + proc := NewMessageProcessor( + 1, publisher, root, localPeer, env.config, + env.schedule, validator, env.encoder, + shardCh, env.eventCh, env.sendFunc(), + ) + + ctx := t.Context() + + done := make(chan struct{}) + go func() { + proc.Run(ctx) + close(done) + }() + + // Close the shard channel to signal teardown. + close(shardCh) + + select { + case <-done: + case <-time.After(1 * time.Second): + t.Fatal("processor did not exit on channel close") + } +} diff --git a/consensus/propeller/propeller.go b/consensus/propeller/propeller.go new file mode 100644 index 0000000000..620b5b4fc0 --- /dev/null +++ b/consensus/propeller/propeller.go @@ -0,0 +1,67 @@ +package propeller + +import ( + "bytes" + "context" + "io" + + "github.com/NethermindEth/juno/utils" + "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/host" + "github.com/libp2p/go-libp2p/core/network" + "go.uber.org/zap" +) + +const propellerProtocolID = "/propeller/0.0.1" + +// This would represent the propeller service that glues the whole +// thing to p2p. Thing is, I've no clue how to do that. 
+type Service interface{} + +type propellerService struct { + host host.Host + engine *Engine + config Config + log utils.Logger +} + +func New( + host host.Host, + privKey crypto.PrivKey, + config *Config, + log utils.Logger, +) Service { + engine := NewEngine( + privKey, + config, + nil, + log, + ) + + return &propellerService{ + host: host, + engine: engine, + config: *config, + log: log, + } +} + +func (s *propellerService) Run(ctx context.Context) { +} + +func (s *propellerService) handleInboudStream(stream network.Stream) { + defer stream.Close() + + sender := stream.Conn().RemotePeer() + + reader := io.LimitReader(stream, int64(s.config.MaxWireMessageSize)) + + var buf bytes.Buffer + _, err := buf.ReadFrom(reader) + if err != nil { + s.log.Debug("error reading inbound propeller stream", + zap.Stringer("peer", sender), + zap.Error(err), + ) + } +} diff --git a/consensus/propeller/propeller_test.go b/consensus/propeller/propeller_test.go new file mode 100644 index 0000000000..e69de29bb2 diff --git a/consensus/propeller/scheduler_test.go b/consensus/propeller/scheduler_test.go new file mode 100644 index 0000000000..d8153483fd --- /dev/null +++ b/consensus/propeller/scheduler_test.go @@ -0,0 +1,286 @@ +// todo(rdr): make it propeller_test +package propeller + +import ( + "cmp" + "math/rand" + "slices" + "testing" + + "github.com/libp2p/go-libp2p/core/peer" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// testPeers creates N deterministic peer IDs that sort in alphabetical order. +// Also returns a local peer ID chosen at random from the list +func testPeers(t *testing.T, names ...string) (peer.ID, []PeerCommittee) { + t.Helper() + + peers := make([]PeerCommittee, len(names)) + for i, n := range names { + peers[i] = PeerCommittee{ + ID: peer.ID(n), + Stake: Stake(rand.Uint32()), + } + } + // note(rdr): should we make the random generation deterministic?
+ localPeer := peers[rand.Int()%len(peers)].ID + + return localPeer, peers +} + +func TestSchedule_Thresholds(t *testing.T) { + tests := []struct { + name string + n int + numDataShards int + numCodingShards int + numShards int + buildThreshold int + receiveThreshold int + }{ + { + name: "N=1 (solo node, no shards)", + n: 1, + numDataShards: 0, + numCodingShards: 0, + numShards: 0, + }, + { + name: "N=2", + n: 2, + numDataShards: 1, // max of 1 and (1/3) yields 1 + numCodingShards: 0, // 1 minus 1 yields 0 + numShards: 1, + }, + { + name: "N=3", + n: 3, + numDataShards: 1, // max of 1 and (2/3) yields 1 + numCodingShards: 1, // 2 minus 1 yields 1 + numShards: 2, + }, + { + name: "N=4", + n: 4, + numDataShards: 1, // 3/3 = 1 + numCodingShards: 2, // 3 - 1 = 2 + numShards: 3, + }, + { + name: "N=5", + n: 5, + numDataShards: 1, // 4/3 = 1 + numCodingShards: 3, // 4 - 1 = 3 + numShards: 4, + }, + { + name: "N=7", + n: 7, + numDataShards: 2, // 6/3 = 2 + numCodingShards: 4, // 6 - 2 = 4 + numShards: 6, + }, + { + name: "N=10", + n: 10, + numDataShards: 3, // 9/3 = 3 + numCodingShards: 6, // 9 - 3 = 6 + numShards: 9, + }, + { + name: "N=31", + n: 31, + numDataShards: 10, // 30/3 = 10 + numCodingShards: 20, // 30 - 10 = 20 + numShards: 30, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + names := make([]string, tc.n) + for i := range tc.n { + names[i] = string(rune('A' + i)) + } + localPeer, peers := testPeers(t, names...) + + s, err := NewScheduler(localPeer, peers) + require.NoError(t, err) + + assert.Equal(t, tc.numDataShards, s.DataShards()) + assert.Equal(t, tc.numCodingShards, s.CodingShards()) + assert.Equal(t, tc.numShards, s.NumShards()) + }) + } +} + +func TestSchedule_Sorting(t *testing.T) { + // Peers provided out of order should be sorted. 
+ localPeer, peers := testPeers(t, "D", "B", "A", "C") + + // Build the expectation from a real copy of the input. The previous + // make(..., 0, len(peers)) + copy copied zero elements (copy is bounded by + // len(dst)), so the expected slice was always empty. + sortedPeers := slices.Clone(peers) + slices.SortFunc(sortedPeers, func(a, b PeerCommittee) int { + return cmp.Compare(a.ID, b.ID) + }) + + s, err := NewScheduler(localPeer, peers) + require.NoError(t, err) + + assert.Equal(t, sortedPeers, s.Peers()) +} + +func TestSchedule_PeerForShard_SpecExample(t *testing.T) { + // From the specification: peers [A, B, C, D], publisher = C (index 2). + // Shard 0 -> A, Shard 1 -> B, Shard 2 -> D + localPeer, peers := testPeers(t, "A", "B", "C", "D") + + s, err := NewScheduler(localPeer, peers) + require.NoError(t, err) + + publisher := peer.ID("C") + + tests := []struct { + shardIndex ShardIndex + expected peer.ID + }{ + {0, peer.ID("A")}, + {1, peer.ID("B")}, + {2, peer.ID("D")}, + } + + for _, tc := range tests { + got, err := s.PeerForShard(publisher, tc.shardIndex) + require.NoError(t, err) + assert.Equal(t, tc.expected, got, "shard %d", tc.shardIndex) + } +} + +func TestSchedule_PeerForShard_PublisherFirst(t *testing.T) { + // Publisher is the first peer in sorted order. + peers := testPeers("A", "B", "C", "D") + s := NewScheduler(peers) + publisher := peer.ID("A") + + // Shard 0 -> B, Shard 1 -> C, Shard 2 -> D + expected := testPeers("B", "C", "D") + for i, exp := range expected { + got, err := s.PeerForShard(publisher, ShardIndex(i)) + require.NoError(t, err) + assert.Equal(t, exp, got) + } +} + +func TestSchedule_PeerForShard_PublisherLast(t *testing.T) { + // Publisher is the last peer in sorted order.
+ peers := testPeers("A", "B", "C", "D") + s := NewScheduler(peers) + publisher := peer.ID("D") + + // Shard 0 -> A, Shard 1 -> B, Shard 2 -> C + expected := testPeers("A", "B", "C") + for i, exp := range expected { + got, err := s.PeerForShard(publisher, ShardIndex(i)) + require.NoError(t, err) + assert.Equal(t, exp, got) + } +} + +func TestSchedule_PeerForShard_Errors(t *testing.T) { + peers := testPeers("A", "B", "C") + s := NewScheduler(peers) + + // Publisher not in list. + _, err := s.PeerForShard(peer.ID("Z"), 0) + assert.Error(t, err) + assert.Contains(t, err.Error(), "not found") + + // Shard index out of range. + _, err = s.PeerForShard(peer.ID("A"), ShardIndex(s.NumShards())) + assert.Error(t, err) + assert.Contains(t, err.Error(), "out of range") +} + +func TestSchedule_ShardForPeer_SpecExample(t *testing.T) { + // Inverse of PeerForShard: peers [A,B,C,D], publisher=C. + // A -> shard 0, B -> shard 1, D -> shard 2 + peers := testPeers("A", "B", "C", "D") + s := NewScheduler(peers) + publisher := peer.ID("C") + + tests := []struct { + localPeer peer.ID + expected ShardIndex + }{ + {peer.ID("A"), 0}, + {peer.ID("B"), 1}, + {peer.ID("D"), 2}, + } + + for _, tc := range tests { + got, err := s.ShardForPeer(publisher, tc.localPeer) + require.NoError(t, err) + assert.Equal(t, tc.expected, got, "peer %s", tc.localPeer) + } +} + +func TestSchedule_ShardForPeer_PublisherError(t *testing.T) { + peers := testPeers("A", "B", "C") + s := NewScheduler(peers) + + // The publisher itself has no assigned shard. 
+ _, err := s.ShardForPeer(peer.ID("B"), peer.ID("B")) + assert.Error(t, err) + assert.Contains(t, err.Error(), "is the publisher") +} + +func TestSchedule_ShardForPeer_NotFound(t *testing.T) { + peers := testPeers("A", "B", "C") + s := NewScheduler(peers) + + _, err := s.ShardForPeer(peer.ID("A"), peer.ID("Z")) + assert.Error(t, err) + assert.Contains(t, err.Error(), "not found") +} + +func TestSchedule_PeerForShardAndShardForPeer_AreInverses(t *testing.T) { + // For every publisher, verify that PeerForShard and ShardForPeer + // are consistent inverses. + peers := testPeers("A", "B", "C", "D", "E") + s := NewScheduler(peers) + + for _, publisher := range s.Peers() { + for shardIdx := range s.NumShards() { + p, err := s.PeerForShard(publisher, ShardIndex(shardIdx)) + require.NoError(t, err) + + // The reverse: given that peer, find its shard index. + gotShard, err := s.ShardForPeer(publisher, p) + require.NoError(t, err) + assert.Equal(t, ShardIndex(shardIdx), gotShard, + "publisher=%s, shard=%d, peer=%s", publisher, shardIdx, p) + } + } +} + +func TestSchedule_BroadcastTargets(t *testing.T) { + peers := testPeers("A", "B", "C", "D") + s := NewScheduler(peers) + + targets, err := s.BroadcastTargets(peer.ID("C")) + require.NoError(t, err) + + // Should be all peers except C, in shard order. 
+ expected := testPeers("A", "B", "D") + assert.Equal(t, expected, targets) +} + +func TestSchedule_BroadcastTargets_PublisherNotFound(t *testing.T) { + peers := testPeers("A", "B") + s := NewScheduler(peers) + + _, err := s.BroadcastTargets(peer.ID("Z")) + assert.Error(t, err) +} diff --git a/consensus/propeller/sharding.go b/consensus/propeller/sharding.go new file mode 100644 index 0000000000..292042c6d6 --- /dev/null +++ b/consensus/propeller/sharding.go @@ -0,0 +1,113 @@ +package propeller + +import ( + "errors" + "fmt" + + "github.com/NethermindEth/juno/consensus/propeller/merkle" + "github.com/NethermindEth/juno/consensus/propeller/reedsolomon" + "github.com/NethermindEth/juno/consensus/propeller/utils" + "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/peer" +) + +// CreatePropellerUnits builds the PropellerUnits a publisher sends out for message. +// +// The message is padded with utils.PadMessage and encoded with reedsolomon.EncodeData +// (numDataShards and parity are passed through to the encoder). A Merkle tree is built +// over the encoded shards and its root is signed with privKey. One unit is returned per +// encoded shard: all units share committeeID, the publisher ID derived from privKey, the +// root and the signature, and differ in ShardIndex, ShardData and MerkleProof. +// +// An error is returned when the publisher ID cannot be derived from privKey, when +// encoding fails, or when signing the root fails. +func CreatePropellerUnits( + committeeID CommitteeID, + message []byte, + privKey crypto.PrivKey, + numDataShards, + parity int, +) ([]PropellerUnit, error) { + publisherID, err := peer.IDFromPrivateKey(privKey) + if err != nil { + // Wrap err, not publisherID: %w needs an error operand and the ID is the + // zero value on this path anyway. + return nil, fmt.Errorf("getting publisher id from private key: %w", err) + } + + paddedMessage := utils.PadMessage(message, numDataShards) + encodedMessage, err := reedsolomon.EncodeData(paddedMessage, numDataShards, parity) + if err != nil { + return nil, fmt.Errorf("encoding the message: %w", err) + } + + merkleRoot, merkleTree := merkle.New(encodedMessage) + messageRoot := MessageRoot(merkleRoot) + + signature, err := SignRoot(messageRoot, privKey) + if err != nil { + return nil, fmt.Errorf("signing the message root: %w", err) + } + + units := make([]PropellerUnit, len(encodedMessage)) + for i, shard := range encodedMessage { + merkleProof := merkleTree[i] + + units[i] = PropellerUnit{ + CommitteeID: committeeID, + Publisher: publisherID, + MerkleRoot: messageRoot, + MerkleProof: merkleProof, + Signature: signature, + ShardIndex: ShardIndex(i), + ShardData: shard, + } + } + return units, nil +} + +//
DecodePropellerUnit receives Propeller units, recovers any missing data and returns +// the still-padded message (after checking its Merkle root against messageRoot), together with the shard data and merkle proof at localShardIndex. +func DecodePropellerUnit( + units []PropellerUnit, + messageRoot MessageRoot, + localShardIndex ShardIndex, + numDataShards int, + parity int, +) ([]byte, []byte, merkle.Proof, error) { + if len(units) == 0 { + return nil, nil, merkle.Proof{}, errors.New("no propeller units to decode") + } + + shards := make([][]byte, len(units)) + for i := range shards { + shards[i] = units[i].ShardData + } + + shards, err := reedsolomon.RecoverData(shards, numDataShards, parity) + if err != nil { + return nil, nil, merkle.Proof{}, fmt.Errorf("recovering shards data: %w", err) + } + shardSize := len(shards[0]) + for i := range numDataShards { + if shards[i] != nil && len(shards[i]) != shardSize { + return nil, nil, merkle.Proof{}, fmt.Errorf( + "missmatch on shard size: %d (at index 0) vs %d (at index %d)", + len(shards[0]), + len(shards[i]), + i, + ) + } + } + + merkleRoot, merkleTree := merkle.New(shards) + + expectedRoot := MessageRoot(merkleRoot) + if messageRoot != expectedRoot { + // todo(rdr): probably need to write string methods for the MessageRoot type + return nil, nil, merkle.Proof{}, fmt.Errorf( + "wrong message root hash. 
Expected %x but got %x", + expectedRoot, + messageRoot, + ) + } + + paddedMessage := make([]byte, len(shards[0])*len(shards)) + for i := range shards { + copy(paddedMessage[i*shardSize:], shards[i]) + } + + localShard := shards[localShardIndex] + localProof := merkleTree[localShardIndex] + + return paddedMessage, localShard, localProof, nil +} diff --git a/consensus/propeller/sharding_test.go b/consensus/propeller/sharding_test.go new file mode 100644 index 0000000000..8c908dfe41 --- /dev/null +++ b/consensus/propeller/sharding_test.go @@ -0,0 +1,204 @@ +package propeller + +import ( + "testing" + + "github.com/libp2p/go-libp2p/core/peer" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// makeSchedule is a test helper that creates a schedule from N single-char peers. +func makeSchedule(n int) *Scheduler { + names := make([]peer.ID, n) + for i := range n { + names[i] = peer.ID(string(rune('A' + i))) + } + return NewScheduler(names) +} + +func TestEncodeMessage_RoundTrip(t *testing.T) { + tests := []struct { + name string + n int + msgLen int + }{ + {"4 peers, short message", 4, 10}, + {"4 peers, medium message", 4, 500}, + {"7 peers, short message", 7, 20}, + {"10 peers, 1KB message", 10, 1024}, + {"2 peers, tiny message", 2, 1}, + {"3 peers, empty message", 3, 0}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + schedule := makeSchedule(tc.n) + if schedule.NumShards() == 0 { + t.Skip("no shards for single peer") + } + + enc, err := NewEncoder( + schedule.NumDataShards(), schedule.NumCodingShards(), + ) + require.NoError(t, err) + + msg := make([]byte, tc.msgLen) + for i := range msg { + msg[i] = byte(i) + } + + units, root, err := EncodeMessage(msg, schedule, enc) + require.NoError(t, err) + assert.Len(t, units, schedule.NumShards()) + + // All units should reference the same root. + for _, u := range units { + assert.Equal(t, root, u.MerkleRoot) + } + + // Reconstruct from all shards. 
+ shards := make([][]byte, schedule.NumShards()) + for _, u := range units { + shards[u.ShardIndex] = u.ShardData + } + + recovered, err := ReconstructMessage( + shards, schedule, enc, root, + ) + require.NoError(t, err) + assert.Equal(t, msg, recovered) + }) + } +} + +func TestEncodeMessage_ReconstructFromMinimumShards(t *testing.T) { + // With N=10 we have 3 data shards and 6 coding shards. + // We should be able to reconstruct from just the 3 data shards. + schedule := makeSchedule(10) + enc, err := NewEncoder( + schedule.NumDataShards(), schedule.NumCodingShards(), + ) + require.NoError(t, err) + + msg := []byte("reconstruct me from minimum shards please") + units, root, err := EncodeMessage(msg, schedule, enc) + require.NoError(t, err) + + // Keep only the first numDataShards shards. + shards := make([][]byte, schedule.NumShards()) + for i := range schedule.NumDataShards() { + shards[units[i].ShardIndex] = units[i].ShardData + } + + recovered, err := ReconstructMessage(shards, schedule, enc, root) + require.NoError(t, err) + assert.Equal(t, msg, recovered) +} + +func TestEncodeMessage_ReconstructWithMissingDataShards(t *testing.T) { + // With N=7 we have 2 data shards and 4 coding shards. + // Drop all data shards, keep only coding shards -> should reconstruct. + schedule := makeSchedule(7) + enc, err := NewEncoder( + schedule.NumDataShards(), schedule.NumCodingShards(), + ) + require.NoError(t, err) + + msg := []byte("even without data shards") + units, root, err := EncodeMessage(msg, schedule, enc) + require.NoError(t, err) + + // Keep only coding shards (indices >= numDataShards). 
+ shards := make([][]byte, schedule.NumShards()) + for _, u := range units { + if int(u.ShardIndex) >= schedule.NumDataShards() { + shards[u.ShardIndex] = u.ShardData + } + } + + recovered, err := ReconstructMessage(shards, schedule, enc, root) + require.NoError(t, err) + assert.Equal(t, msg, recovered) +} + +func TestEncodeMessage_MerkleProofsVerify(t *testing.T) { + schedule := makeSchedule(5) + enc, err := NewEncoder( + schedule.NumDataShards(), schedule.NumCodingShards(), + ) + require.NoError(t, err) + + msg := []byte("verify all proofs") + units, root, err := EncodeMessage(msg, schedule, enc) + require.NoError(t, err) + + for _, u := range units { + ok := VerifyMerkleProof(root, u.ShardData, uint32(u.ShardIndex), u.MerkleProof) + assert.True(t, ok, "proof for shard %d should verify", u.ShardIndex) + } +} + +func TestReconstructMessage_MismatchedRoot(t *testing.T) { + schedule := makeSchedule(4) + enc, err := NewEncoder( + schedule.NumDataShards(), schedule.NumCodingShards(), + ) + require.NoError(t, err) + + msg := []byte("good message") + units, _, err := EncodeMessage(msg, schedule, enc) + require.NoError(t, err) + + shards := make([][]byte, schedule.NumShards()) + for _, u := range units { + shards[u.ShardIndex] = u.ShardData + } + + // Pass a wrong root. + fakeRoot := MessageRoot{0xff} + _, err = ReconstructMessage(shards, schedule, enc, fakeRoot) + require.Error(t, err) + + var reconErr *ReconstructionError + require.ErrorAs(t, err, &reconErr) + assert.Equal(t, ReasonMismatchedMessageRoot, reconErr.Reason) +} + +func TestReconstructMessage_InsufficientShards(t *testing.T) { + schedule := makeSchedule(10) // 3 data, 6 coding + enc, err := NewEncoder( + schedule.NumDataShards(), schedule.NumCodingShards(), + ) + require.NoError(t, err) + + msg := []byte("not enough shards") + units, root, err := EncodeMessage(msg, schedule, enc) + require.NoError(t, err) + + // Provide only 2 shards when 3 are needed. 
+ shards := make([][]byte, schedule.NumShards()) + shards[units[0].ShardIndex] = units[0].ShardData + shards[units[1].ShardIndex] = units[1].ShardData + + _, err = ReconstructMessage(shards, schedule, enc, root) + require.Error(t, err) + + var reconErr *ReconstructionError + require.ErrorAs(t, err, &reconErr) + assert.Equal(t, ReasonErasureReconstructionFailed, reconErr.Reason) +} + +func TestEncodeMessage_NoShards(t *testing.T) { + // A single-node schedule has no shards. + schedule := makeSchedule(1) + enc, err := NewEncoder(1, 0) + require.NoError(t, err) + + _, _, err = EncodeMessage([]byte("x"), schedule, enc) + require.Error(t, err) + + var pubErr *ShardPublishError + require.ErrorAs(t, err, &pubErr) + assert.Equal(t, ReasonInvalidDataSize, pubErr.Reason) +} diff --git a/consensus/propeller/timecache.go b/consensus/propeller/timecache.go new file mode 100644 index 0000000000..8f539d8605 --- /dev/null +++ b/consensus/propeller/timecache.go @@ -0,0 +1,73 @@ +package propeller + +import ( + "sync" + "time" +) + +// TimeCache is a set data structure where entries automatically expire after a +// configured TTL. It is used to remember which messages have been finalised so +// we can reject late-arriving shards without keeping state forever. +// +// The cache is safe for concurrent access. Expired entries are lazily removed: +// Contains() ignores expired entries, and Cleanup() bulk-removes them. This +// amortised approach avoids the overhead of per-entry timers. +type TimeCache[K comparable] struct { + mu sync.Mutex + entries map[K]time.Time + ttl time.Duration + // nowFn is injectable for testing. In production it is time.Now. + nowFn func() time.Time +} + +// NewTimeCache creates a cache where entries expire after the given TTL. +func NewTimeCache[K comparable](ttl time.Duration) *TimeCache[K] { + return &TimeCache[K]{ + entries: make(map[K]time.Time), + ttl: ttl, + nowFn: time.Now, + } +} + +// Add inserts a key into the cache with an expiry of now + TTL. 
+// If the key already exists, its expiry is refreshed. +func (c *TimeCache[K]) Add(key K) { + c.mu.Lock() + defer c.mu.Unlock() + c.entries[key] = c.nowFn().Add(c.ttl) +} + +// Contains returns true if the key is present and has not expired. +// Expired keys are treated as absent but not removed -- call Cleanup() +// periodically to reclaim memory. +func (c *TimeCache[K]) Contains(key K) bool { + c.mu.Lock() + defer c.mu.Unlock() + expiry, ok := c.entries[key] + if !ok { + return false + } + return c.nowFn().Before(expiry) +} + +// Cleanup removes all expired entries from the cache. Call this periodically +// (e.g., every N operations or on a timer) to prevent unbounded growth from +// expired entries that are never looked up again. +func (c *TimeCache[K]) Cleanup() { + c.mu.Lock() + defer c.mu.Unlock() + now := c.nowFn() + for k, expiry := range c.entries { + if !now.Before(expiry) { + delete(c.entries, k) + } + } +} + +// Len returns the total number of entries including expired ones that have +// not yet been cleaned up. Useful for testing and monitoring. 
+func (c *TimeCache[K]) Len() int { + c.mu.Lock() + defer c.mu.Unlock() + return len(c.entries) +} diff --git a/consensus/propeller/timecache_test.go b/consensus/propeller/timecache_test.go new file mode 100644 index 0000000000..25a12b4a5b --- /dev/null +++ b/consensus/propeller/timecache_test.go @@ -0,0 +1,122 @@ +package propeller + +import ( + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestTimeCache_AddAndContains(t *testing.T) { + cache := NewTimeCache[string](10 * time.Second) + + assert.False(t, cache.Contains("a"), "empty cache should not contain any key") + + cache.Add("a") + assert.True(t, cache.Contains("a"), "key should be present after Add") + assert.False(t, cache.Contains("b"), "unrelated key should not be present") +} + +func TestTimeCache_Expiration(t *testing.T) { + // Use a controllable clock so we don't need real sleeps. + now := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC) + cache := NewTimeCache[string](5 * time.Second) + cache.nowFn = func() time.Time { return now } + + cache.Add("x") + assert.True(t, cache.Contains("x")) + + // Advance time to just before expiry. + now = now.Add(4 * time.Second) + assert.True(t, cache.Contains("x"), "should still be present before TTL") + + // Advance time to exactly the expiry moment. + now = now.Add(1 * time.Second) + assert.False(t, cache.Contains("x"), "should be expired at TTL boundary") + + // Advance well past expiry. + now = now.Add(10 * time.Second) + assert.False(t, cache.Contains("x"), "should be expired well after TTL") +} + +func TestTimeCache_RefreshExpiry(t *testing.T) { + now := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC) + cache := NewTimeCache[string](5 * time.Second) + cache.nowFn = func() time.Time { return now } + + cache.Add("k") + + // Advance 3 seconds, then re-add to refresh. 
+ now = now.Add(3 * time.Second) + cache.Add("k") + + // Advance another 3 seconds -- would be expired without refresh (6s > 5s), + // but the refresh pushed the deadline to 3s+5s=8s. + now = now.Add(3 * time.Second) + assert.True(t, cache.Contains("k"), "re-add should refresh the TTL") + + // Advance past the refreshed expiry. + now = now.Add(3 * time.Second) + assert.False(t, cache.Contains("k"), "should expire after refreshed TTL") +} + +func TestTimeCache_Cleanup(t *testing.T) { + now := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC) + cache := NewTimeCache[int](2 * time.Second) + cache.nowFn = func() time.Time { return now } + + for i := range 5 { + cache.Add(i) + } + require.Equal(t, 5, cache.Len()) + + // Expire all entries. + now = now.Add(3 * time.Second) + + // They're expired but still in the map until Cleanup. + assert.Equal(t, 5, cache.Len(), "expired entries linger until Cleanup") + assert.False(t, cache.Contains(0), "expired entries should not be found") + + cache.Cleanup() + assert.Equal(t, 0, cache.Len(), "Cleanup should remove all expired entries") +} + +func TestTimeCache_CleanupPartial(t *testing.T) { + now := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC) + cache := NewTimeCache[string](5 * time.Second) + cache.nowFn = func() time.Time { return now } + + cache.Add("early") + now = now.Add(3 * time.Second) + cache.Add("late") + + // Advance so "early" is expired but "late" is not. + now = now.Add(3 * time.Second) + + cache.Cleanup() + assert.Equal(t, 1, cache.Len(), "only the expired entry should be removed") + assert.False(t, cache.Contains("early")) + assert.True(t, cache.Contains("late")) +} + +func TestTimeCache_ConcurrentAccess(t *testing.T) { + cache := NewTimeCache[int](1 * time.Second) + + var wg sync.WaitGroup + // Hammer the cache from multiple goroutines to verify no races. 
+ for i := range 100 { + wg.Add(1) + go func(v int) { + defer wg.Done() + cache.Add(v) + cache.Contains(v) + cache.Cleanup() + }(i) + } + wg.Wait() + + // We just care that it didn't panic or race. + assert.LessOrEqual(t, cache.Len(), 100) +} diff --git a/consensus/propeller/types.go b/consensus/propeller/types.go new file mode 100644 index 0000000000..da24cd8651 --- /dev/null +++ b/consensus/propeller/types.go @@ -0,0 +1,306 @@ +// Package propeller implements an erasure-coding based message broadcast protocol +// for Byzantine fault-tolerant consensus. A publisher splits a message into shards, +// erasure-encodes them via Reed-Solomon, and distributes one shard per peer. +// Any peer can reconstruct the full message from a threshold number of shards, +// then forwards its own assigned shard to all others. +// +// The protocol tolerates up to f = floor((N-1)/3) Byzantine faulty nodes. +package propeller + +import ( + "fmt" + "time" + + "github.com/NethermindEth/juno/consensus/propeller/merkle" + "github.com/libp2p/go-libp2p/core/peer" +) + +// CommitteeID identifies a committee or logical broadcast group. Multiple committees +// can operate concurrently within the same engine, each with its own peer set. +type CommitteeID uint64 + +// ShardIndex is the position of a shard within the erasure-coded output. +// Valid range is [0, N-2] where N is the total number of peers. +type ShardIndex uint32 + +// MessageRoot is the SHA-256 Merkle root over all shard leaves. It uniquely +// identifies a message and is signed by the publisher to bind authenticity. +type MessageRoot merkle.Hash + +// Config holds tunable parameters for the propeller engine. Sensible defaults +// are provided by DefaultConfig(). +type Config struct { + // StaleMessageTimeout is how long the engine waits for a message to + // reach the receive threshold before giving up. 
This prevents memory + // leaks from partially-received messages that will never complete + // (e.g., due to a crashed publisher or network partition). + StaleMessageTimeout time.Duration + + // StreamProtocol is the libp2p protocol identifier used for direct + // shard transfers between peers. + StreamProtocol string + + // MaxWireMessageSize caps the size of a single serialised PropellerUnit + // on the wire. Units exceeding this are rejected to prevent memory + // exhaustion from malicious peers. + MaxWireMessageSize int +} + +// DefaultConfig returns production-ready defaults. +func DefaultConfig() Config { + return Config{ + StaleMessageTimeout: 120 * time.Second, + StreamProtocol: "/propeller/0.1.0", + MaxWireMessageSize: 1 << 20, // 1 MiB + } +} + +// PropellerUnit is the atomic wire message: one erasure-coded shard plus +// the metadata needed for independent verification. Each unit is self-contained +// so a receiver can validate it without any other shards. +type PropellerUnit struct { + CommitteeID CommitteeID // Which committee this belongs to + Publisher peer.ID // Original message author + MerkleRoot MessageRoot // Merkle root binding all shards together + MerkleProof merkle.Proof // Merkle inclusion proof for this shard + Signature []byte // Publisher's Ed25519 signature over the root + ShardIndex ShardIndex // This shard's position in the erasure-coded output + ShardData []byte // The actual data fragment +} + +// messageKey uniquely identifies a message within a channel. We track +// per-message state (processor, time cache) using this composite key +// because the same publisher could broadcast different messages (different +// roots) and we need to handle each independently. +type messageKey struct { + Channel CommitteeID + Publisher peer.ID + Root MessageRoot +} + +// --------------------------------------------------------------------------- +// Events: structured outputs from the engine to the application layer. 
+// Each event is emitted at most once per message lifecycle. +// --------------------------------------------------------------------------- + +// EventMessageReceived signals that a message has been fully reconstructed +// and enough shards have been forwarded to guarantee delivery to all honest +// nodes. The application can safely process the contained message bytes. +type EventMessageReceived struct { + Publisher peer.ID + Root MessageRoot + Message []byte +} + +// EventReconstructionFailed signals that Reed-Solomon reconstruction or +// post-reconstruction verification failed. This typically indicates Byzantine +// behaviour from the publisher (e.g., inconsistent shards). +type EventReconstructionFailed struct { + Root MessageRoot + Publisher peer.ID + Err error +} + +// EventShardPublishFailed signals that the local node failed to encode or +// distribute shards when acting as publisher. +type EventShardPublishFailed struct { + Err error +} + +// EventShardSendFailed signals that sending a single shard to a specific +// peer failed. The engine continues sending to other peers; this is +// informational for monitoring. +type EventShardSendFailed struct { + From peer.ID + To peer.ID + Err error +} + +// EventShardValidationFailed signals that an incoming shard was rejected +// during validation. This may indicate Byzantine behaviour from the sender +// or publisher. +type EventShardValidationFailed struct { + Sender peer.ID + ClaimedRoot MessageRoot + ClaimedPublisher peer.ID + Err error +} + +// EventMessageTimeout signals that a message did not reach the receive +// threshold before the stale message timeout elapsed. The engine cleans +// up state for this message. +type EventMessageTimeout struct { + Channel CommitteeID + Publisher peer.ID + Root MessageRoot +} + +// reasonUnknown is the string representation for unrecognised enum values. +// Extracted as a constant to satisfy goconst. 
+const reasonUnknown = "unknown" + +// --------------------------------------------------------------------------- +// Error types: structured errors for each failure domain. +// Using typed errors rather than sentinel values lets callers inspect the +// specific failure reason programmatically. +// --------------------------------------------------------------------------- + +// ShardValidationReason enumerates the specific causes of shard rejection. +type ShardValidationReason int + +const ( + // ReasonSelfSending means a peer sent us a unit claiming to be from us. + ReasonSelfSending ShardValidationReason = iota + // ReasonReceivedSelfPublishedShard means we received a shard for a + // message we published ourselves -- we already have all shards. + ReasonReceivedSelfPublishedShard + // ReasonDuplicateShard means we already have a shard at this index + // for this message. + ReasonDuplicateShard + // ReasonUnexpectedSender means the sender is not the peer assigned + // to broadcast this shard index. + ReasonUnexpectedSender + // ReasonSignatureVerificationFailed means the publisher's signature + // over the Merkle root did not verify. + ReasonSignatureVerificationFailed + // ReasonMerkleProofVerificationFailed means the Merkle inclusion + // proof for this shard is invalid. + ReasonMerkleProofVerificationFailed + // ReasonScheduleError means the shard-to-peer mapping lookup failed + // (e.g., publisher not in the channel's peer set). 
+ ReasonScheduleError +) + +func (r ShardValidationReason) String() string { + switch r { + case ReasonSelfSending: + return "self_sending" + case ReasonReceivedSelfPublishedShard: + return "received_self_published_shard" + case ReasonDuplicateShard: + return "duplicate_shard" + case ReasonUnexpectedSender: + return "unexpected_sender" + case ReasonSignatureVerificationFailed: + return "signature_verification_failed" + case ReasonMerkleProofVerificationFailed: + return "merkle_proof_verification_failed" + case ReasonScheduleError: + return "schedule_error" + default: + return reasonUnknown + } +} + +// ShardValidationError is returned when an incoming PropellerUnit fails +// validation. The Reason field allows programmatic inspection; the Detail +// field carries human-readable context. +type ShardValidationError struct { + Reason ShardValidationReason + Detail string +} + +func (e *ShardValidationError) Error() string { + return fmt.Sprintf("shard validation failed (%s): %s", e.Reason, e.Detail) +} + +// ReconstructionReason enumerates the specific causes of reconstruction failure. +type ReconstructionReason int + +const ( + // ReasonErasureReconstructionFailed means Reed-Solomon decoding failed, + // likely because too many shards are missing or corrupted. + ReasonErasureReconstructionFailed ReconstructionReason = iota + // ReasonMismatchedMessageRoot means the Merkle root computed from the + // reconstructed shards does not match the claimed root. This indicates + // Byzantine behaviour from the publisher. + ReasonMismatchedMessageRoot + // ReasonUnequalShardLengths means shards have inconsistent lengths, + // which violates Reed-Solomon's equal-length requirement. + ReasonUnequalShardLengths + // ReasonMessagePaddingError means the varint length prefix in the + // unpadded message is malformed or points beyond the data. 
+ ReasonMessagePaddingError +) + +func (r ReconstructionReason) String() string { + switch r { + case ReasonErasureReconstructionFailed: + return "erasure_reconstruction_failed" + case ReasonMismatchedMessageRoot: + return "mismatched_message_root" + case ReasonUnequalShardLengths: + return "unequal_shard_lengths" + case ReasonMessagePaddingError: + return "message_padding_error" + default: + return reasonUnknown + } +} + +// ReconstructionError is returned when message reconstruction fails after +// collecting enough shards. +type ReconstructionError struct { + Reason ReconstructionReason + Detail string +} + +func (e *ReconstructionError) Error() string { + return fmt.Sprintf("reconstruction failed (%s): %s", e.Reason, e.Detail) +} + +// ShardPublishReason enumerates the specific causes of publish failure. +type ShardPublishReason int + +const ( + // ReasonLocalPeerNotInChannel means the local peer is not a member + // of the channel it is trying to broadcast on. + ReasonLocalPeerNotInChannel ShardPublishReason = iota + // ReasonInvalidDataSize means the message is too large to encode. + ReasonInvalidDataSize + // ReasonSigningFailed means the local private key failed to sign. + ReasonSigningFailed + // ReasonEncodingFailed means Reed-Solomon encoding failed. + ReasonEncodingFailed + // ReasonNotConnectedToPeer means we have no open connection to a + // target peer. + ReasonNotConnectedToPeer + // ReasonChannelNotRegistered means the channel has not been registered + // with the engine. + ReasonChannelNotRegistered + // ReasonBroadcastFailed means the broadcast operation failed for an + // unspecified reason. 
+ ReasonBroadcastFailed +) + +func (r ShardPublishReason) String() string { + switch r { + case ReasonLocalPeerNotInChannel: + return "local_peer_not_in_channel" + case ReasonInvalidDataSize: + return "invalid_data_size" + case ReasonSigningFailed: + return "signing_failed" + case ReasonEncodingFailed: + return "encoding_failed" + case ReasonNotConnectedToPeer: + return "not_connected_to_peer" + case ReasonChannelNotRegistered: + return "channel_not_registered" + case ReasonBroadcastFailed: + return "broadcast_failed" + default: + return reasonUnknown + } +} + +// todo(rdr): check if we want to do this. I think it is better not, unless necessary +// ShardPublishError is returned when the local node fails to publish shards. +type ShardPublishError struct { + Reason ShardPublishReason + Detail string +} + +func (e *ShardPublishError) Error() string { + return fmt.Sprintf("shard publish failed (%s): %s", e.Reason, e.Detail) +} diff --git a/consensus/propeller/utils/padding.go b/consensus/propeller/utils/padding.go new file mode 100644 index 0000000000..465d279622 --- /dev/null +++ b/consensus/propeller/utils/padding.go @@ -0,0 +1,60 @@ +package utils + +import ( + "encoding/binary" + "fmt" +) + +// PadMessage prepends an unsigned varint-encoded length to the message and +// pads the result with zeros so the total length is divisible by +// 2*numDataShards. +// +// The varint prefix lets the receiver recover the exact original message +// length after reconstruction. The zero-padding ensures the padded message +// can be evenly split into numDataShards pieces, which is required by +// Reed-Solomon encoding (all shards must be equal length). +// +// Layout: [varint(len(msg))] [msg bytes] [zero padding] +func PadMessage(msg []byte, numDataShards int) []byte { + // Compute the varint-encoded length prefix. 
+ var varintBuf [binary.MaxVarintLen64]byte + varintLen := binary.PutUvarint(varintBuf[:], uint64(len(msg))) + + unpaddedMsgLen := uint64(varintLen + len(msg)) + + // Round up to the next multiple of divisor. + divisor := uint64(2 * numDataShards) + paddedMsgLen := unpaddedMsgLen + if remainder := paddedMsgLen % divisor; remainder != 0 { + paddedMsgLen += divisor - remainder + } + + result := make([]byte, paddedMsgLen) + copy(result, varintBuf[:varintLen]) + copy(result[varintLen:], msg) + // Remaining bytes are already zero (Go slice initialization). + + return result +} + +// UnpadMessage performs the reverse operation to PadMessage: it reads the varint length prefix and +// extracts the original message bytes, discarding the zero padding. The slice returned uses the +// input's backing array but re-sliced to start and finish on the original message. +// +// An error is returned if the varint is malformed or the encoded length exceeds the available data. +func UnpadMessage(padded []byte) ([]byte, error) { + msgLen, varintLen := binary.Uvarint(padded) + if varintLen <= 0 { + return nil, fmt.Errorf("invalid varint prefix in padded message: %d", varintLen) + } + // Compare against the bytes left after the prefix; adding varintLen to msgLen could overflow uint64 + available := uint64(len(padded) - varintLen) + if msgLen > available { + return nil, fmt.Errorf( + "varint length %d exceeds available data (have %d bytes after prefix)", + msgLen, available, + ) + } + + return padded[varintLen : varintLen+int(msgLen)], nil +} diff --git a/consensus/propeller/utils/padding_test.go b/consensus/propeller/utils/padding_test.go new file mode 100644 index 0000000000..9d29a3dcfc --- /dev/null +++ b/consensus/propeller/utils/padding_test.go @@ -0,0 +1,108 @@ +package utils_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPadMessage_RoundTrip(t *testing.T) { + tests := []struct { + name string + msg []byte + numDataShards int + }{ + { + name: "empty message, 1 shard", + msg: []byte{}, + numDataShards: 1, + }, + { + 
name: "small message, 1 shard", + msg: []byte("hello"), + numDataShards: 1, + }, + { + name: "small message, 3 shards", + msg: []byte("hello world"), + numDataShards: 3, + }, + { + name: "message exactly divisible", + msg: make([]byte, 6), // varint(6)=1 byte, total=7, divisor=2*1=2 -> pad to 8 + numDataShards: 1, + }, + { + name: "larger message, 10 shards", + msg: make([]byte, 1000), + numDataShards: 10, + }, + { + name: "single byte", + msg: []byte{0x42}, + numDataShards: 5, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + padded := PadMessage(tc.msg, tc.numDataShards) + + // Verify divisibility. + divisor := 2 * tc.numDataShards + assert.Equal(t, 0, len(padded)%divisor, + "padded length %d should be divisible by %d", len(padded), divisor) + + // Verify round-trip. + recovered, err := UnpadMessage(padded) + require.NoError(t, err) + assert.Equal(t, tc.msg, recovered) + }) + } +} + +func TestPadMessage_Alignment(t *testing.T) { + // Verify that padding produces the minimum size that is a multiple of divisor. + msg := []byte("ab") // 2 bytes + padded := PadMessage(msg, 3) // divisor = 6 + // varint(2) = 1 byte, payload = 3 bytes, next multiple of 6 = 6 + assert.Equal(t, 6, len(padded)) +} + +func TestUnpadMessage_InvalidVarint(t *testing.T) { + // An empty buffer has no valid varint. + _, err := UnpadMessage([]byte{}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid varint") +} + +func TestUnpadMessage_LengthExceedsData(t *testing.T) { + // Manually encode a varint claiming 100 bytes, but only provide 5. + buf := make([]byte, 6) + buf[0] = 100 // varint encoding of 100 + copy(buf[1:], "short") + + _, err := UnpadMessage(buf) + assert.Error(t, err) + assert.Contains(t, err.Error(), "exceeds available data") +} + +func TestUnpadMessage_Truncated(t *testing.T) { + // Encode a valid varint pointing past the end. 
+ buf := []byte{0x80, 0x01} // varint 128, but only 2 bytes total + _, err := UnpadMessage(buf) + assert.Error(t, err) +} + +func TestPadMessage_LargeVarint(t *testing.T) { + // A message large enough to need a multi-byte varint. + msg := make([]byte, 300) + for i := range msg { + msg[i] = byte(i) + } + padded := PadMessage(msg, 4) + recovered, err := UnpadMessage(padded) + require.NoError(t, err) + assert.Equal(t, msg, recovered) +} diff --git a/consensus/propeller/utils/signing.go b/consensus/propeller/utils/signing.go new file mode 100644 index 0000000000..4d66eb5ee2 --- /dev/null +++ b/consensus/propeller/utils/signing.go @@ -0,0 +1,33 @@ +package utils + +import ( + "fmt" + + "github.com/libp2p/go-libp2p/core/crypto" +) + +// SignPayload constructs the byte sequence that the publisher signs: +// +// "" || root[0:32] || "" +// +// The tags domain-separate propeller signatures from any other protocol +// that might use the same key, preventing cross-protocol signature reuse. +func SignPayload[T ~[32]byte](root T) []byte { + payload := make([]byte, 0, len("")+32+len("")) + payload = append(payload, []byte("")...) + payload = append(payload, root[:]...) + payload = append(payload, []byte("")...) + return payload +} + +// todo(rdr): verify this is correct +// SignRoot signs the Merkle root with the given private key, producing the +// signature that goes into every PropellerUnit for this message. 
+func SignRoot[T ~[32]byte](root T, privKey crypto.PrivKey) ([]byte, error) { + payload := SignPayload(root) + sig, err := privKey.Sign(payload) + if err != nil { + return nil, fmt.Errorf("signing message root: %w", err) + } + return sig, nil +} diff --git a/consensus/propeller/utils/signing_test.go b/consensus/propeller/utils/signing_test.go new file mode 100644 index 0000000000..90c372d319 --- /dev/null +++ b/consensus/propeller/utils/signing_test.go @@ -0,0 +1 @@ +package utils_test diff --git a/consensus/propeller/validator.go b/consensus/propeller/validator.go new file mode 100644 index 0000000000..9155e7198c --- /dev/null +++ b/consensus/propeller/validator.go @@ -0,0 +1,171 @@ +package propeller + +import ( + "fmt" + + "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/peer" +) + +// todo(rdr): Need to review this whole module + +// SignatureVerifier abstracts Ed25519 signature verification. The default +// implementation extracts the public key from a peer.ID and verifies using +// libp2p crypto. Tests can inject a mock to control verification outcomes. +type SignatureVerifier interface { + Verify(peerID peer.ID, data, signature []byte) (bool, error) +} + +// DefaultSignatureVerifier implements SignatureVerifier by extracting the +// public key embedded in a libp2p peer.ID. This works because peer.IDs +// for Ed25519 keys are derived from the public key, and for small keys +// the public key is embedded directly in the ID. +type DefaultSignatureVerifier struct{} + +func (DefaultSignatureVerifier) Verify( + peerID peer.ID, data, signature []byte, +) (bool, error) { + pubKey, err := peerID.ExtractPublicKey() + if err != nil { + return false, fmt.Errorf("extracting public key from peer %s: %w", peerID, err) + } + return pubKey.Verify(data, signature) +} + +// Validator checks incoming PropellerUnits for correctness. Each check +// serves a specific defensive purpose: +// +// - Self-sending check: prevents reflection attacks. 
+// - Self-published check: we already have all shards for our own messages. +// - Duplicate check: avoids redundant work and state corruption. +// - Origin check: ensures the sender is the peer assigned to this shard, +// preventing sybil-like relay attacks. +// - Merkle proof check: ensures the shard data is authentic (matches the +// committed tree root). +// - Signature check: ensures the publisher actually authored the message +// (the root they committed to). +// +// These checks are ordered from cheapest to most expensive so we reject +// invalid units as early as possible. +type Validator struct { + schedule *Scheduler + localPeer peer.ID + verifier SignatureVerifier +} + +// NewValidator creates a validator for the given channel configuration. +func NewValidator( + schedule *Scheduler, + localPeer peer.ID, + verifier SignatureVerifier, +) *Validator { + return &Validator{ + schedule: schedule, + localPeer: localPeer, + verifier: verifier, + } +} + +// ValidateUnit checks an incoming unit against all validation rules. +// +// Parameters: +// - unit: the incoming PropellerUnit to validate. +// - sender: the peer.ID of the network peer that sent us this unit. +// - seenShards: set of shard indices already received for this message +// (used for duplicate detection). +// - signatureVerified: true if we have already verified the publisher's +// signature for this Merkle root. Allows skipping the expensive crypto +// check after the first shard from the same message passes. +// +// Returns nil if valid, or a *ShardValidationError describing the failure. +func (v *Validator) ValidateUnit( + unit *PropellerUnit, + sender peer.ID, + seenShards map[ShardIndex]bool, + signatureVerified bool, +) error { + // 1. Reject units from ourselves (should never happen in normal + // operation; indicates a routing bug or reflection attack). 
+ if sender == v.localPeer { + return &ShardValidationError{ + Reason: ReasonSelfSending, + Detail: "received unit from ourselves", + } + } + + // 2. Reject units for messages we published (we already have all + // shards and don't need them relayed back). + if unit.Publisher == v.localPeer { + return &ShardValidationError{ + Reason: ReasonReceivedSelfPublishedShard, + Detail: "received shard for a message we published", + } + } + + // 3. Reject duplicate shards. A well-behaved peer sends each shard + // exactly once; duplicates waste bandwidth and could corrupt state. + if seenShards[unit.ShardIndex] { + return &ShardValidationError{ + Reason: ReasonDuplicateShard, + Detail: fmt.Sprintf("already received shard %d", unit.ShardIndex), + } + } + + // 4. Verify the sender is either the peer assigned to broadcast this + // shard or the publisher itself (who initially distributes all shards). + // This prevents a Byzantine node from impersonating another peer's + // shard assignment while still allowing the publisher's initial send. + expectedPeer, err := v.schedule.PeerForShard(unit.Publisher, unit.ShardIndex) + if err != nil { + return &ShardValidationError{ + Reason: ReasonScheduleError, + Detail: fmt.Sprintf( + "looking up peer for shard %d: %v", unit.ShardIndex, err, + ), + } + } + if sender != expectedPeer && sender != unit.Publisher { + return &ShardValidationError{ + Reason: ReasonUnexpectedSender, + Detail: fmt.Sprintf( + "shard %d should come from %s or publisher %s, got %s", + unit.ShardIndex, expectedPeer, unit.Publisher, sender, + ), + } + } + + // 5. Verify the Merkle inclusion proof. This ensures the shard data + // is consistent with the tree root the publisher committed to. + if !VerifyMerkleProof( + unit.MerkleRoot, unit.ShardData, uint32(unit.ShardIndex), unit.MerkleProof, + ) { + return &ShardValidationError{ + Reason: ReasonMerkleProofVerificationFailed, + Detail: fmt.Sprintf("merkle proof invalid for shard %d", unit.ShardIndex), + } + } + + // 6. 
Verify the publisher's signature over the root. This is the most + // expensive check (public-key crypto), so we skip it if we've already + // verified the same root from this publisher. + if !signatureVerified { + payload := SignPayload(unit.MerkleRoot) + ok, err := v.verifier.Verify(unit.Publisher, payload, unit.Signature) + if err != nil { + return &ShardValidationError{ + Reason: ReasonSignatureVerificationFailed, + Detail: fmt.Sprintf( + "verifying signature: %v", err, + ), + } + } + if !ok { + return &ShardValidationError{ + Reason: ReasonSignatureVerificationFailed, + Detail: "signature does not match publisher's public key", + } + } + } + + return nil +} diff --git a/consensus/propeller/validator_test.go b/consensus/propeller/validator_test.go new file mode 100644 index 0000000000..3f4e48fe38 --- /dev/null +++ b/consensus/propeller/validator_test.go @@ -0,0 +1,364 @@ +package propeller + +import ( + "bytes" + "crypto/ed25519" + "fmt" + "testing" + + "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockVerifier is a test double for SignatureVerifier that returns +// configurable results. +type mockVerifier struct { + valid bool + err error +} + +func (m *mockVerifier) Verify(peer.ID, []byte, []byte) (bool, error) { + return m.valid, m.err +} + +// realPeer creates a real libp2p peer.ID from a deterministic Ed25519 seed. +// Real peer IDs are needed because DefaultSignatureVerifier extracts the +// public key from the peer ID, which only works for keys encoded into the ID. 
+func realPeer(seed byte) (crypto.PrivKey, peer.ID) { + seedBytes := make([]byte, ed25519.SeedSize) + seedBytes[0] = seed + reader := bytes.NewReader(seedBytes) + priv, pub, err := crypto.GenerateEd25519Key(reader) + if err != nil { + panic(err) + } + id, err := peer.IDFromPublicKey(pub) + if err != nil { + panic(err) + } + return priv, id +} + +// makeValidUnit creates a PropellerUnit that passes all validation checks +// for the given schedule, publisher, and shard index. The unit has a valid +// Merkle proof and signature. +func makeValidUnit( + t *testing.T, + schedule *Scheduler, + publisherKey crypto.PrivKey, + publisher peer.ID, + shardIndex ShardIndex, +) *PropellerUnit { + t.Helper() + + // Create a simple message and encode it. + msg := []byte("test message for validation") + enc, err := NewEncoder(schedule.NumDataShards(), schedule.NumCodingShards()) + require.NoError(t, err) + + units, root, err := EncodeMessage(msg, schedule, enc) + require.NoError(t, err) + + sig, err := SignRoot(root, publisherKey) + require.NoError(t, err) + + unit := &units[shardIndex] + unit.Publisher = publisher + unit.Signature = sig + unit.CommitteeID = 1 + return unit +} + +// validatorTestSetup creates a realistic N-peer environment and returns +// the schedule, a local peer (which is NOT the publisher), the publisher's +// key and ID, and a shard index that maps to a sender who is neither +// localPeer nor publisher. +type validatorTestSetup struct { + schedule *Scheduler + localPeer peer.ID + publisher peer.ID + publisherKey crypto.PrivKey + // shardIndex and expectedSender: a shard whose assigned sender is a + // third peer (neither localPeer nor publisher). + shardIndex ShardIndex + expectedSender peer.ID +} + +func newValidatorTestSetup(t *testing.T) validatorTestSetup { + t.Helper() + + // Create 5 peers so we have enough room to find a shard where the + // sender is a third party. 
+ n := 5 + keys := make([]crypto.PrivKey, n) + ids := make([]peer.ID, n) + for i := range n { + keys[i], ids[i] = realPeer(byte(i)) + } + + schedule := NewScheduler(ids) + sorted := schedule.Peers() + + // Pick localPeer = sorted[0], publisher = sorted[1]. + localPeer := sorted[0] + publisher := sorted[1] + + // Find the publisher's private key. + var publisherKey crypto.PrivKey + for i, id := range ids { + if id == publisher { + publisherKey = keys[i] + break + } + } + require.NotNil(t, publisherKey) + + // Find a shard whose expected sender is NOT localPeer. + var shardIndex ShardIndex + var expectedSender peer.ID + found := false + for si := range schedule.NumShards() { + s, err := schedule.PeerForShard(publisher, ShardIndex(si)) + require.NoError(t, err) + if s != localPeer { + shardIndex = ShardIndex(si) + expectedSender = s + found = true + break + } + } + require.True(t, found, "could not find a shard with a third-party sender") + + return validatorTestSetup{ + schedule: schedule, + localPeer: localPeer, + publisher: publisher, + publisherKey: publisherKey, + shardIndex: shardIndex, + expectedSender: expectedSender, + } +} + +func TestValidator_HappyPath(t *testing.T) { + setup := newValidatorTestSetup(t) + v := NewValidator(setup.schedule, setup.localPeer, &DefaultSignatureVerifier{}) + + unit := makeValidUnit( + t, setup.schedule, setup.publisherKey, + setup.publisher, setup.shardIndex, + ) + + seenShards := make(map[ShardIndex]bool) + err := v.ValidateUnit(unit, setup.expectedSender, seenShards, false) + assert.NoError(t, err) +} + +func TestValidator_SelfSending(t *testing.T) { + setup := newValidatorTestSetup(t) + v := NewValidator(setup.schedule, setup.localPeer, &mockVerifier{valid: true}) + + unit := &PropellerUnit{Publisher: setup.publisher, ShardIndex: setup.shardIndex} + err := v.ValidateUnit(unit, setup.localPeer, nil, true) + + var valErr *ShardValidationError + require.ErrorAs(t, err, &valErr) + assert.Equal(t, ReasonSelfSending, valErr.Reason) 
+} + +func TestValidator_ReceivedSelfPublishedShard(t *testing.T) { + setup := newValidatorTestSetup(t) + v := NewValidator(setup.schedule, setup.localPeer, &mockVerifier{valid: true}) + + // Unit claims we are the publisher. + unit := &PropellerUnit{Publisher: setup.localPeer, ShardIndex: 0} + err := v.ValidateUnit(unit, setup.expectedSender, nil, true) + + var valErr *ShardValidationError + require.ErrorAs(t, err, &valErr) + assert.Equal(t, ReasonReceivedSelfPublishedShard, valErr.Reason) +} + +func TestValidator_DuplicateShard(t *testing.T) { + setup := newValidatorTestSetup(t) + v := NewValidator(setup.schedule, setup.localPeer, &mockVerifier{valid: true}) + + unit := &PropellerUnit{Publisher: setup.publisher, ShardIndex: setup.shardIndex} + seenShards := map[ShardIndex]bool{setup.shardIndex: true} + + err := v.ValidateUnit(unit, setup.expectedSender, seenShards, true) + var valErr *ShardValidationError + require.ErrorAs(t, err, &valErr) + assert.Equal(t, ReasonDuplicateShard, valErr.Reason) +} + +func TestValidator_UnexpectedSender(t *testing.T) { + setup := newValidatorTestSetup(t) + v := NewValidator(setup.schedule, setup.localPeer, &mockVerifier{valid: true}) + + // Find a peer that is NOT the expected sender, NOT localPeer, and NOT + // the publisher. The publisher is now an accepted sender for any shard, + // so it must be excluded from the "wrong sender" set. 
+ var wrongSender peer.ID + for _, p := range setup.schedule.Peers() { + if p != setup.expectedSender && p != setup.localPeer && p != setup.publisher { + wrongSender = p + break + } + } + require.NotEmpty(t, wrongSender) + + unit := &PropellerUnit{Publisher: setup.publisher, ShardIndex: setup.shardIndex} + seenShards := make(map[ShardIndex]bool) + err := v.ValidateUnit(unit, wrongSender, seenShards, true) + + var valErr *ShardValidationError + require.ErrorAs(t, err, &valErr) + assert.Equal(t, ReasonUnexpectedSender, valErr.Reason) +} + +func TestValidator_PublisherAsAcceptedSender(t *testing.T) { + setup := newValidatorTestSetup(t) + v := NewValidator(setup.schedule, setup.localPeer, &DefaultSignatureVerifier{}) + + // The publisher initially distributes all shards, so it should be + // accepted as a sender for any shard -- even one assigned to another peer. + unit := makeValidUnit( + t, setup.schedule, setup.publisherKey, + setup.publisher, setup.shardIndex, + ) + + seenShards := make(map[ShardIndex]bool) + err := v.ValidateUnit(unit, setup.publisher, seenShards, false) + assert.NoError(t, err, "publisher should be accepted as sender for any shard") +} + +func TestValidator_MerkleProofFailed(t *testing.T) { + setup := newValidatorTestSetup(t) + v := NewValidator(setup.schedule, setup.localPeer, &mockVerifier{valid: true}) + + // Create a unit with a bad Merkle proof. 
+ unit := &PropellerUnit{ + Publisher: setup.publisher, + ShardIndex: setup.shardIndex, + MerkleRoot: MessageRoot{0x01}, + ShardData: []byte("data"), + MerkleProof: MerkleProof{Siblings: [][32]byte{{0xde, 0xad}}}, + } + + seenShards := make(map[ShardIndex]bool) + err := v.ValidateUnit(unit, setup.expectedSender, seenShards, true) + + var valErr *ShardValidationError + require.ErrorAs(t, err, &valErr) + assert.Equal(t, ReasonMerkleProofVerificationFailed, valErr.Reason) +} + +func TestValidator_SignatureVerificationFailed(t *testing.T) { + setup := newValidatorTestSetup(t) + + // Use a verifier that rejects signatures. + v := NewValidator(setup.schedule, setup.localPeer, &mockVerifier{valid: false}) + + unit := makeValidUnit( + t, setup.schedule, setup.publisherKey, + setup.publisher, setup.shardIndex, + ) + + seenShards := make(map[ShardIndex]bool) + err := v.ValidateUnit(unit, setup.expectedSender, seenShards, false) + + var valErr *ShardValidationError + require.ErrorAs(t, err, &valErr) + assert.Equal(t, ReasonSignatureVerificationFailed, valErr.Reason) +} + +func TestValidator_SignatureVerificationError(t *testing.T) { + setup := newValidatorTestSetup(t) + + // Use a verifier that returns an error. + v := NewValidator(setup.schedule, setup.localPeer, &mockVerifier{ + valid: false, + err: fmt.Errorf("key extraction failed"), + }) + + unit := makeValidUnit( + t, setup.schedule, setup.publisherKey, + setup.publisher, setup.shardIndex, + ) + + seenShards := make(map[ShardIndex]bool) + err := v.ValidateUnit(unit, setup.expectedSender, seenShards, false) + + var valErr *ShardValidationError + require.ErrorAs(t, err, &valErr) + assert.Equal(t, ReasonSignatureVerificationFailed, valErr.Reason) + assert.Contains(t, valErr.Detail, "key extraction failed") +} + +func TestValidator_SkipSignatureWhenAlreadyVerified(t *testing.T) { + setup := newValidatorTestSetup(t) + + // Verifier that would reject -- but we pass signatureVerified=true. 
+ v := NewValidator(setup.schedule, setup.localPeer, &mockVerifier{valid: false}) + + unit := makeValidUnit( + t, setup.schedule, setup.publisherKey, + setup.publisher, setup.shardIndex, + ) + + seenShards := make(map[ShardIndex]bool) + err := v.ValidateUnit(unit, setup.expectedSender, seenShards, true) + assert.NoError(t, err, "should skip signature check when already verified") +} + +func TestSignPayload(t *testing.T) { + root := MessageRoot{0x01, 0x02, 0x03} + payload := SignPayload(root) + + expected := append([]byte(""), root[:]...) + expected = append(expected, []byte("")...) + assert.Equal(t, expected, payload) +} + +func TestSignRoot_RoundTrip(t *testing.T) { + privKey, peerID := realPeer(42) + + root := MessageRoot{0xaa, 0xbb, 0xcc} + sig, err := SignRoot(root, privKey) + require.NoError(t, err) + require.NotEmpty(t, sig) + + // Verify with the default verifier. + verifier := DefaultSignatureVerifier{} + payload := SignPayload(root) + ok, err := verifier.Verify(peerID, payload, sig) + require.NoError(t, err) + assert.True(t, ok) + + // Wrong root should fail. 
+ wrongRoot := MessageRoot{0xff} + wrongPayload := SignPayload(wrongRoot) + ok, err = verifier.Verify(peerID, wrongPayload, sig) + require.NoError(t, err) + assert.False(t, ok) +} + +func TestValidator_ScheduleError(t *testing.T) { + setup := newValidatorTestSetup(t) + v := NewValidator(setup.schedule, setup.localPeer, &mockVerifier{valid: true}) + + _, unknownPeer := realPeer(99) + unit := &PropellerUnit{ + Publisher: unknownPeer, + ShardIndex: 0, + ShardData: []byte("data"), + MerkleProof: MerkleProof{}, + } + + err := v.ValidateUnit(unit, setup.expectedSender, make(map[ShardIndex]bool), true) + var valErr *ShardValidationError + require.ErrorAs(t, err, &valErr) + assert.Equal(t, ReasonScheduleError, valErr.Reason) +} From d5475ca6068237dfb4e91660cdb2df08bff7bbb6 Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Wed, 1 Apr 2026 10:55:21 +0100 Subject: [PATCH 05/40] refactor: rename processor to deprecated_processor --- .../{processor.go => deprecated_processor.go} | 102 ++++++++---------- ...r_test.go => deprecated_processor_test.go} | 8 +- 2 files changed, 51 insertions(+), 59 deletions(-) rename consensus/propeller/{processor.go => deprecated_processor.go} (85%) rename consensus/propeller/{processor_test.go => deprecated_processor_test.go} (98%) diff --git a/consensus/propeller/processor.go b/consensus/propeller/deprecated_processor.go similarity index 85% rename from consensus/propeller/processor.go rename to consensus/propeller/deprecated_processor.go index 33a2b06c13..22af75641f 100644 --- a/consensus/propeller/processor.go +++ b/consensus/propeller/deprecated_processor.go @@ -32,15 +32,11 @@ const ( stateFinalised ) -// SendUnitFunc is called by the processor to send a PropellerUnit to a -// specific peer. The engine provides this callback, which handles the -// actual network I/O. The processor doesn't know or care how delivery works. 
-type SendUnitFunc func(ctx context.Context, to peer.ID, unit *PropellerUnit) error - // shardDelivery bundles an incoming shard with the peer that sent it, // so the processor can validate the sender identity. +// todo(rdr): a better name for this type shardDelivery struct { - Unit *PropellerUnit + Unit *Unit Sender peer.ID } @@ -57,32 +53,27 @@ type shardDelivery struct { // The processor is deliberately simple -- it owns no locks and communicates // entirely through channels. All mutable state is confined to its goroutine. type MessageProcessor struct { - // Identity and configuration. - channel CommitteeID - publisher peer.ID - root MessageRoot - localPeer peer.ID - config Config - - // Dependencies (injected for testability). - schedule *Scheduler - validator *Validator - encoder Encoder - - // State. + // Identity + committeeID CommitteeID + publisher peer.ID + root MessageRoot + + // Config + timeout time.Duration + + // Internal State. state processorState shards [][]byte // indexed by ShardIndex, nil = not yet received - seenShards map[ShardIndex]bool + seenShards map[ShardIndex]struct{} receivedCount int signatureVerified bool storedSignature []byte // cached from the first valid unit reconstructedMsg []byte - myShardUnit *PropellerUnit // the unit we are responsible for forwarding + myShardUnit *Unit // the unit we are responsible for forwarding // Channels. shardCh chan shardDelivery // incoming shards from the engine eventCh chan<- any // outgoing events to the engine/application - sendFn SendUnitFunc // callback for sending units to peers } // NewMessageProcessor creates a processor for a specific message. 
The caller @@ -108,20 +99,20 @@ func NewMessageProcessor( sendFn SendUnitFunc, ) *MessageProcessor { return &MessageProcessor{ - channel: channel, - publisher: publisher, - root: root, - localPeer: localPeer, - config: config, - schedule: schedule, - validator: validator, - encoder: encoder, - state: statePreConstruction, - shards: make([][]byte, schedule.NumShards()), - seenShards: make(map[ShardIndex]bool), - shardCh: shardCh, - eventCh: eventCh, - sendFn: sendFn, + committeeID: channel, + publisher: publisher, + root: root, + localPeer: localPeer, + config: config, + schedule: schedule, + validator: validator, + encoder: encoder, + state: statePreConstruction, + shards: make([][]byte, schedule.NumShards()), + seenShards: make(map[ShardIndex]bool), + shardCh: shardCh, + eventCh: eventCh, + sendFn: sendFn, } } @@ -131,38 +122,39 @@ func NewMessageProcessor( // The select on shardCh vs timer is the core of the state machine. We // intentionally use a single goroutine to avoid any need for synchronisation // on the processor's internal state. -func (p *MessageProcessor) Run(ctx context.Context) { - timer := time.NewTimer(p.config.StaleMessageTimeout) +func (p *MessageProcessor) Run(ctx context.Context) error { + timer := time.NewTimer(p.timeout) defer timer.Stop() for { select { case <-ctx.Done(): - return + return ctx.Err() + case <-timer.C: + if p.state != stateFinalised { + p.emitEvent(EventMessageTimeout{ + Channel: p.committeeID, + Publisher: p.publisher, + Root: p.root, + }) + p.state = stateFinalised + } + // todo: should we return an error when the processor is stopped after a timeout? + return nil case delivery, ok := <-p.shardCh: if !ok { - // Channel closed by engine; processor is being torn down. - return + // Channel closed by engine; processor is being shut down. 
+ return nil } if p.state == stateFinalised { - return + return nil } p.handleShard(ctx, delivery) if p.state == stateFinalised { - return + return nil } - case <-timer.C: - if p.state != stateFinalised { - p.emitEvent(EventMessageTimeout{ - Channel: p.channel, - Publisher: p.publisher, - Root: p.root, - }) - p.state = stateFinalised - } - return } } } @@ -257,8 +249,8 @@ func (p *MessageProcessor) handlePreConstruction(ctx context.Context) { copy(leaves, shardsCopy) _, proofs := BuildMerkleTree(leaves) - p.myShardUnit = &PropellerUnit{ - CommitteeID: p.channel, + p.myShardUnit = &Unit{ + CommitteeID: p.committeeID, Publisher: p.publisher, MerkleRoot: p.root, Signature: p.storedSignature, diff --git a/consensus/propeller/processor_test.go b/consensus/propeller/deprecated_processor_test.go similarity index 98% rename from consensus/propeller/processor_test.go rename to consensus/propeller/deprecated_processor_test.go index 6c2f5342e5..b718984d38 100644 --- a/consensus/propeller/processor_test.go +++ b/consensus/propeller/deprecated_processor_test.go @@ -27,7 +27,7 @@ type processorTestEnv struct { type sentUnit struct { To peer.ID - Unit *PropellerUnit + Unit *Unit } func newProcessorTestEnv(t *testing.T, n int) *processorTestEnv { @@ -60,7 +60,7 @@ func newProcessorTestEnv(t *testing.T, n int) *processorTestEnv { // sendFunc records sent units for later inspection. func (env *processorTestEnv) sendFunc() SendUnitFunc { - return func(_ context.Context, to peer.ID, unit *PropellerUnit) error { + return func(_ context.Context, to peer.ID, unit *Unit) error { env.sentMu.Lock() defer env.sentMu.Unlock() env.sentUnits = append(env.sentUnits, sentUnit{To: to, Unit: unit}) @@ -72,7 +72,7 @@ func (env *processorTestEnv) sendFunc() SendUnitFunc { // the signed units and root. 
func (env *processorTestEnv) encodeTestMessage( t *testing.T, publisher peer.ID, msg []byte, -) ([]PropellerUnit, MessageRoot) { +) ([]Unit, MessageRoot) { t.Helper() units, root, err := EncodeMessage(msg, env.schedule, env.encoder) @@ -334,7 +334,7 @@ func TestProcessor_DuplicateShardRejected(t *testing.T) { }() // Find a shard not from localPeer. - var targetUnit PropellerUnit + var targetUnit Unit var targetSender peer.ID for i, unit := range units { sender, err := env.schedule.PeerForShard(publisher, ShardIndex(i)) From e9fdc8f8b92356d0a85437816f714d06aecd080d Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Wed, 1 Apr 2026 10:55:49 +0100 Subject: [PATCH 06/40] chore: add proto units --- consensus/propeller/proto/README.md | 25 ++ consensus/propeller/proto/propeller.pb.go | 295 ++++++++++++++++++++++ consensus/propeller/proto/propeller.proto | 41 +++ 3 files changed, 361 insertions(+) create mode 100644 consensus/propeller/proto/README.md create mode 100644 consensus/propeller/proto/propeller.pb.go create mode 100644 consensus/propeller/proto/propeller.proto diff --git a/consensus/propeller/proto/README.md b/consensus/propeller/proto/README.md new file mode 100644 index 0000000000..8daac27a4c --- /dev/null +++ b/consensus/propeller/proto/README.md @@ -0,0 +1,25 @@ +# Generating Go code from propeller.proto + +The `propeller.proto` file imports `p2p/proto/common.proto` from the upstream +[starknet-p2p-specs](https://github.com/starknet-io/starknet-p2p-specs) repository. +Since the upstream module is not on the Buf Schema Registry, we use `buf export` + `protoc` directly. + +From the project root: + +```bash +# 1. Export upstream .proto sources (needed for import resolution) +buf export \ + "https://github.com/starknet-io/starknet-p2p-specs.git#branch=bcfa353a169c859e4d5d97757caccbe76f75bc06,depth=1" \ + -o /tmp/starknet-p2p-specs-proto + +# 2. Generate Go code +protoc \ + --go_out=. 
\ + --go_opt=paths=source_relative \ + --go_opt=Mp2p/proto/common.proto=github.com/starknet-io/starknet-p2p-specs/p2p/proto/common \ + -I /tmp/starknet-p2p-specs-proto \ + -I . \ + consensus/propeller/proto/propeller.proto +``` + +This produces `propeller.pb.go` in this directory. diff --git a/consensus/propeller/proto/propeller.pb.go b/consensus/propeller/proto/propeller.pb.go new file mode 100644 index 0000000000..0a1fcf667b --- /dev/null +++ b/consensus/propeller/proto/propeller.pb.go @@ -0,0 +1,295 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.9 +// protoc v7.34.0 +// source: consensus/propeller/proto/propeller.proto + +package proto + +import ( + common "github.com/starknet-io/starknet-p2p-specs/p2p/proto/common" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +// A Merkle proof consisting of sibling hashes used to verify that a leaf belongs to a Merkle tree. +// Each sibling hash is 32 bytes (SHA-256). The siblings are ordered from leaf level to root level. +type MerkleProof struct { + state protoimpl.MessageState `protogen:"open.v1"` + // The sibling hashes needed to reconstruct the path from the leaf to the root. + // Each hash is 32 bytes. 
+ Siblings []*common.Hash256 `protobuf:"bytes,1,rep,name=siblings,proto3" json:"siblings,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *MerkleProof) Reset() { + *x = MerkleProof{} + mi := &file_consensus_propeller_proto_propeller_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *MerkleProof) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*MerkleProof) ProtoMessage() {} + +func (x *MerkleProof) ProtoReflect() protoreflect.Message { + mi := &file_consensus_propeller_proto_propeller_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use MerkleProof.ProtoReflect.Descriptor instead. +func (*MerkleProof) Descriptor() ([]byte, []int) { + return file_consensus_propeller_proto_propeller_proto_rawDescGZIP(), []int{0} +} + +func (x *MerkleProof) GetSiblings() []*common.Hash256 { + if x != nil { + return x.Siblings + } + return nil +} + +// A single unit in the Propeller protocol containing a shard of erasure-coded data +// along with cryptographic proofs for verification. +type PropellerUnit struct { + state protoimpl.MessageState `protogen:"open.v1"` + // The actual data shard (erasure-coded fragment of the original message). + Shard []byte `protobuf:"bytes,1,opt,name=shard,proto3" json:"shard,omitempty"` + // The position of this shard in the erasure coding scheme. + Index uint64 `protobuf:"varint,2,opt,name=index,proto3" json:"index,omitempty"` + // The Merkle root of all shards, used to verify shard integrity. + MerkleRoot *common.Hash256 `protobuf:"bytes,3,opt,name=merkle_root,json=merkleRoot,proto3" json:"merkle_root,omitempty"` + // The Merkle proof that this shard belongs to the tree with the given root. 
+ MerkleProof *MerkleProof `protobuf:"bytes,4,opt,name=merkle_proof,json=merkleProof,proto3" json:"merkle_proof,omitempty"` + // The peer ID of the original publisher who created and signed this unit. + Publisher *common.PeerID `protobuf:"bytes,5,opt,name=publisher,proto3" json:"publisher,omitempty"` + // Cryptographic signature from the publisher over the merkle_root. + Signature []byte `protobuf:"bytes,6,opt,name=signature,proto3" json:"signature,omitempty"` + // TODO(AndrewL): consider re-naming channel + // TODO(AndrewL): make it uint64 instead of uint32. + // Logical channel identifier for multiplexing different message streams. + Channel uint32 `protobuf:"varint,7,opt,name=channel,proto3" json:"channel,omitempty"` // TODO(AndrewL): CRITICAL: protect against replay attacks (maybe using a timestamp) + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *PropellerUnit) Reset() { + *x = PropellerUnit{} + mi := &file_consensus_propeller_proto_propeller_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *PropellerUnit) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*PropellerUnit) ProtoMessage() {} + +func (x *PropellerUnit) ProtoReflect() protoreflect.Message { + mi := &file_consensus_propeller_proto_propeller_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use PropellerUnit.ProtoReflect.Descriptor instead. 
+func (*PropellerUnit) Descriptor() ([]byte, []int) { + return file_consensus_propeller_proto_propeller_proto_rawDescGZIP(), []int{1} +} + +func (x *PropellerUnit) GetShard() []byte { + if x != nil { + return x.Shard + } + return nil +} + +func (x *PropellerUnit) GetIndex() uint64 { + if x != nil { + return x.Index + } + return 0 +} + +func (x *PropellerUnit) GetMerkleRoot() *common.Hash256 { + if x != nil { + return x.MerkleRoot + } + return nil +} + +func (x *PropellerUnit) GetMerkleProof() *MerkleProof { + if x != nil { + return x.MerkleProof + } + return nil +} + +func (x *PropellerUnit) GetPublisher() *common.PeerID { + if x != nil { + return x.Publisher + } + return nil +} + +func (x *PropellerUnit) GetSignature() []byte { + if x != nil { + return x.Signature + } + return nil +} + +func (x *PropellerUnit) GetChannel() uint32 { + if x != nil { + return x.Channel + } + return 0 +} + +// A batch of PropellerUnits for efficient transmission. +type PropellerUnitBatch struct { + state protoimpl.MessageState `protogen:"open.v1"` + Batch []*PropellerUnit `protobuf:"bytes,1,rep,name=batch,proto3" json:"batch,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *PropellerUnitBatch) Reset() { + *x = PropellerUnitBatch{} + mi := &file_consensus_propeller_proto_propeller_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *PropellerUnitBatch) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*PropellerUnitBatch) ProtoMessage() {} + +func (x *PropellerUnitBatch) ProtoReflect() protoreflect.Message { + mi := &file_consensus_propeller_proto_propeller_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use PropellerUnitBatch.ProtoReflect.Descriptor instead. 
+func (*PropellerUnitBatch) Descriptor() ([]byte, []int) { + return file_consensus_propeller_proto_propeller_proto_rawDescGZIP(), []int{2} +} + +func (x *PropellerUnitBatch) GetBatch() []*PropellerUnit { + if x != nil { + return x.Batch + } + return nil +} + +var File_consensus_propeller_proto_propeller_proto protoreflect.FileDescriptor + +const file_consensus_propeller_proto_propeller_proto_rawDesc = "" + + "\n" + + ")consensus/propeller/proto/propeller.proto\x1a\x16p2p/proto/common.proto\"3\n" + + "\vMerkleProof\x12$\n" + + "\bsiblings\x18\x01 \x03(\v2\b.Hash256R\bsiblings\"\xf6\x01\n" + + "\rPropellerUnit\x12\x14\n" + + "\x05shard\x18\x01 \x01(\fR\x05shard\x12\x14\n" + + "\x05index\x18\x02 \x01(\x04R\x05index\x12)\n" + + "\vmerkle_root\x18\x03 \x01(\v2\b.Hash256R\n" + + "merkleRoot\x12/\n" + + "\fmerkle_proof\x18\x04 \x01(\v2\f.MerkleProofR\vmerkleProof\x12%\n" + + "\tpublisher\x18\x05 \x01(\v2\a.PeerIDR\tpublisher\x12\x1c\n" + + "\tsignature\x18\x06 \x01(\fR\tsignature\x12\x18\n" + + "\achannel\x18\a \x01(\rR\achannel\":\n" + + "\x12PropellerUnitBatch\x12$\n" + + "\x05batch\x18\x01 \x03(\v2\x0e.PropellerUnitR\x05batchB9Z7github.com/NethermindEth/juno/consensus/propeller/protob\x06proto3" + +var ( + file_consensus_propeller_proto_propeller_proto_rawDescOnce sync.Once + file_consensus_propeller_proto_propeller_proto_rawDescData []byte +) + +func file_consensus_propeller_proto_propeller_proto_rawDescGZIP() []byte { + file_consensus_propeller_proto_propeller_proto_rawDescOnce.Do(func() { + file_consensus_propeller_proto_propeller_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_consensus_propeller_proto_propeller_proto_rawDesc), len(file_consensus_propeller_proto_propeller_proto_rawDesc))) + }) + return file_consensus_propeller_proto_propeller_proto_rawDescData +} + +var file_consensus_propeller_proto_propeller_proto_msgTypes = make([]protoimpl.MessageInfo, 3) +var file_consensus_propeller_proto_propeller_proto_goTypes = []any{ + 
(*MerkleProof)(nil), // 0: MerkleProof + (*PropellerUnit)(nil), // 1: PropellerUnit + (*PropellerUnitBatch)(nil), // 2: PropellerUnitBatch + (*common.Hash256)(nil), // 3: Hash256 + (*common.PeerID)(nil), // 4: PeerID +} +var file_consensus_propeller_proto_propeller_proto_depIdxs = []int32{ + 3, // 0: MerkleProof.siblings:type_name -> Hash256 + 3, // 1: PropellerUnit.merkle_root:type_name -> Hash256 + 0, // 2: PropellerUnit.merkle_proof:type_name -> MerkleProof + 4, // 3: PropellerUnit.publisher:type_name -> PeerID + 1, // 4: PropellerUnitBatch.batch:type_name -> PropellerUnit + 5, // [5:5] is the sub-list for method output_type + 5, // [5:5] is the sub-list for method input_type + 5, // [5:5] is the sub-list for extension type_name + 5, // [5:5] is the sub-list for extension extendee + 0, // [0:5] is the sub-list for field type_name +} + +func init() { file_consensus_propeller_proto_propeller_proto_init() } +func file_consensus_propeller_proto_propeller_proto_init() { + if File_consensus_propeller_proto_propeller_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_consensus_propeller_proto_propeller_proto_rawDesc), len(file_consensus_propeller_proto_propeller_proto_rawDesc)), + NumEnums: 0, + NumMessages: 3, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_consensus_propeller_proto_propeller_proto_goTypes, + DependencyIndexes: file_consensus_propeller_proto_propeller_proto_depIdxs, + MessageInfos: file_consensus_propeller_proto_propeller_proto_msgTypes, + }.Build() + File_consensus_propeller_proto_propeller_proto = out.File + file_consensus_propeller_proto_propeller_proto_goTypes = nil + file_consensus_propeller_proto_propeller_proto_depIdxs = nil +} diff --git a/consensus/propeller/proto/propeller.proto b/consensus/propeller/proto/propeller.proto new file mode 100644 index 0000000000..67787bdd3a --- 
/dev/null +++ b/consensus/propeller/proto/propeller.proto @@ -0,0 +1,41 @@ +syntax = "proto3"; + +import "p2p/proto/common.proto"; + +option go_package = "github.com/NethermindEth/juno/consensus/propeller/proto"; + +// A Merkle proof consisting of sibling hashes used to verify that a leaf belongs to a Merkle tree. +// Each sibling hash is 32 bytes (SHA-256). The siblings are ordered from leaf level to root level. +message MerkleProof { + // The sibling hashes needed to reconstruct the path from the leaf to the root. + // Each hash is 32 bytes. + repeated Hash256 siblings = 1; +} + +// A single unit in the Propeller protocol containing a shard of erasure-coded data +// along with cryptographic proofs for verification. +message PropellerUnit { + // The actual data shard (erasure-coded fragment of the original message). + bytes shard = 1; + // The position of this shard in the erasure coding scheme. + uint64 index = 2; + // The Merkle root of all shards, used to verify shard integrity. + Hash256 merkle_root = 3; + // The Merkle proof that this shard belongs to the tree with the given root. + MerkleProof merkle_proof = 4; + // The peer ID of the original publisher who created and signed this unit. + PeerID publisher = 5; + // Cryptographic signature from the publisher over the merkle_root. + bytes signature = 6; + // TODO(AndrewL): consider re-naming channel + // TODO(AndrewL): make it uint64 instead of uint32. + // Logical channel identifier for multiplexing different message streams. + uint32 channel = 7; + // TODO(AndrewL): CRITICAL: protect against replay attacks (maybe using a timestamp) +} + +// A batch of PropellerUnits for efficient transmission. 
+message PropellerUnitBatch { + repeated PropellerUnit batch = 1; +} + From 4e2fbb36638eac4988bb090b440923fcec14001f Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Wed, 1 Apr 2026 10:56:41 +0100 Subject: [PATCH 07/40] chore: add propeller unit type --- consensus/propeller/unit.go | 53 ++++++++++++++++++++++++++++++++ consensus/propeller/unit_test.go | 0 2 files changed, 53 insertions(+) create mode 100644 consensus/propeller/unit.go create mode 100644 consensus/propeller/unit_test.go diff --git a/consensus/propeller/unit.go b/consensus/propeller/unit.go new file mode 100644 index 0000000000..b60a9ee88d --- /dev/null +++ b/consensus/propeller/unit.go @@ -0,0 +1,53 @@ +package propeller + +import ( + "github.com/NethermindEth/juno/consensus/propeller/merkle" + pb "github.com/NethermindEth/juno/consensus/propeller/proto" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/starknet-io/starknet-p2p-specs/p2p/proto/common" +) + +// Unit is the atomic wire message: one erasure-coded shard plus +// the metadata needed for independent verification. Each unit is self-contained +// so a receiver can validate it without any other shards. +type Unit struct { + CommitteeID CommitteeID // Which committee this belongs to + Publisher peer.ID // Original message author + MerkleRoot MessageRoot // Merkle root binding all shards together + MerkleProof merkle.Proof // Merkle inclusion proof for this shard + Signature []byte // Publisher's Ed25519 signature over the root + ShardIndex ShardIndex // This shard's position in the erasure-coded output + ShardData []byte // The actual data fragment +} + +func UnitFromProto(protoUnit *pb.PropellerUnit) Unit { + return Unit{ + CommitteeID: CommitteeID(protoUnit.Channel), + // todo(rdr): this casting operations seem a bit risky, are they? 
+ Publisher: peer.ID(protoUnit.Publisher.Id), + MerkleRoot: MessageRoot(protoUnit.MerkleRoot.Elements), + Signature: protoUnit.Signature, + ShardIndex: ShardIndex(protoUnit.Index), + ShardData: protoUnit.Shard, + } +} + +func (u *Unit) ToProto() *pb.PropellerUnit { + siblings := make([]*common.Hash256, len(u.MerkleProof.Siblings)) + for i, s := range u.MerkleProof.Siblings { + siblings[i] = &common.Hash256{Elements: s[:]} + } + + root := merkle.Hash(u.MerkleRoot) + return &pb.PropellerUnit{ + Shard: u.ShardData, + Index: uint64(u.ShardIndex), + MerkleRoot: &common.Hash256{ + Elements: root[:], + }, + MerkleProof: &pb.MerkleProof{Siblings: siblings}, + Publisher: &common.PeerID{Id: []byte(u.Publisher)}, + Signature: u.Signature, + Channel: uint32(u.CommitteeID), + } +} diff --git a/consensus/propeller/unit_test.go b/consensus/propeller/unit_test.go new file mode 100644 index 0000000000..e69de29bb2 From ca88d79d4cd1db7a0a192d1195d9e45ba290e0a5 Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Sat, 4 Apr 2026 14:47:36 +0100 Subject: [PATCH 08/40] [wip]feat(propeller): introduce new message processing behaviour --- consensus/propeller/engine.go | 168 +++++++-------- consensus/propeller/engine_test.go | 14 +- consensus/propeller/processor.go | 281 ++++++++++++++++++++++++++ consensus/propeller/processor_test.go | 1 + 4 files changed, 378 insertions(+), 86 deletions(-) create mode 100644 consensus/propeller/processor.go create mode 100644 consensus/propeller/processor_test.go diff --git a/consensus/propeller/engine.go b/consensus/propeller/engine.go index 19a3ff79c6..deb6f69643 100644 --- a/consensus/propeller/engine.go +++ b/consensus/propeller/engine.go @@ -21,16 +21,27 @@ const ( ) type broadcastResult struct { - units []PropellerUnit + units []Unit err error } // todo(rdr): using String until I find a better type type StakerID string +// Holds the state for a Committee ID: +// - The `scheduler` represents the Propeller Tree of peers +// - The `processor` stores the state 
of this committee: when the built or receive threshold +// have been reached +// - The `peerKeys` I am not sure yet todo(rdr): <- type committeeState struct { scheduler *Scheduler - peerKeys []StakerID + // todo(rdr): A look at processor shows that it's lifetime is strictly coupled with the + // state of a current committee. They both should be created and closed at the + // same time. If it is like this then it stands to reason that it should be coupled + // here. Not 100% sure right now, so leaving a big todo for now. + proccessor *MessageProcessor + // todo(rdr): why do we need this + peerKeys []StakerID } // engineCommand is a tagged union of commands sent to the engine's Run() loop. @@ -47,26 +58,26 @@ type registerCommittee struct { func (registerCommittee) isCommand() -type cmdUnregister struct { +type unregisterCommittee struct { committeeID CommitteeID } -func (cmdUnregister) isCommand() +func (unregisterCommittee) isCommand() -type cmdBroadcast struct { +type broadcast struct { committeeID CommitteeID msg []byte errCh chan error } -func (cmdBroadcast) isCommand() +func (broadcast) isCommand() -type cmdHandleUnit struct { - unit *PropellerUnit +type processUnit struct { + unit *Unit sender peer.ID } -func (cmdHandleUnit) isCommand() +func (processUnit) isCommand() // Engine is the central orchestrator of the Propeller protocol. It: // @@ -82,81 +93,77 @@ type Engine struct { localPeer peer.ID privKey crypto.PrivKey config Config - log utils.Logger + log utils.StructuredLogger + // processor handles validates and process all the messages received by other peers + processor *Processor + // committees holds the Scheduler (i.e. Propeller Tree) and Stakers ID of // the peers of each registered channel + // todo(rdr): committeeState can set be there by value instead of by ref? 
committees map[CommitteeID]*committeeState + + // todo(rdr): not sure of this one yet // connected peers hold all the connected peers to the engine connectedPeers map[peer.ID]struct{} - // whenever a broadcast action is started, units preparaition are done concurrently + // whenever a broadcast action is started, units preparation are done concurrently // and delivered through this channel unitsPrepared chan broadcastResult - // processors maps each active message to its processor's shard input - // channel. The engine creates processors lazily on first shard receipt. - // Only accessed from the Run() goroutine, so no lock needed. - processors map[messageKey]chan<- shardDelivery - - // finalised tracks recently finalised messages to avoid re-creating - // processors for late-arriving shards. - finalised *TimeCache[messageKey] - // eventCh is shared between all processors and the engine. The engine // reads from it and forwards events to the application via Events(). eventCh chan any - // cleanupCh carries internal processor-done signals. This is separate - // from eventCh so that a full eventCh never blocks processor goroutines - // trying to signal completion, which would leak goroutines. - cleanupCh chan processorDone - // appEventCh is the externally-visible event channel. The engine copies // events from eventCh to appEventCh in its Run() loop, filtering out // internal events as needed. appEventCh chan any - // cmdCh carries commands from external callers into the Run() loop. - cmdCh chan engineCommand - - // sendFn is the network callback for delivering units to peers. - // Injected at construction time for testability. - sendFn SendUnitFunc + // cmdCh receives commands from the propeller service and act on those + cmdCh <-chan engineCommand } -// NewEngine creates an engine instance. Call Run() to start processing. +// NewEngine creates an engine instance. It returns the engine and the channel to +// send engineCommands to. 
+// Call Run() to start processing. // // Parameters: // - localPeer: this node's peer ID. // - privKey: this node's Ed25519 private key (for signing published messages). // - config: protocol parameters. -// - sendFn: callback for delivering PropellerUnits to peers over the network. // - log: structured logger. +// +// todo(rdr): Maybe in the future we don't want to expose the command channel and instead hide +// the interaction behind a public API. :think: func NewEngine( - // todo(rdr): this should be a key pair privKey crypto.PrivKey, config *Config, - sendFn SendUnitFunc, - log utils.Logger, -) *Engine { - // todo(rdr): generate local peer id from keypair + log utils.StructuredLogger, +) (*Engine, chan<- engineCommand) { + localPeerID, err := peer.IDFromPrivateKey(privKey) + if err != nil { + // todo(rdr): pannic for now, error handling for later + panic(err) + } + + processor := NewProcessor(localPeerID, config) + + cmdCh := make(chan engineCommand) + return &Engine{ - localPeer: peer.ID("some random value for now"), - privKey: privKey, - config: *config, - log: log, - committees: make(map[CommitteeID]*committeeState), - connectedPeers: make(map[peer.ID]struct{}), - cmdCh: make(chan engineCommand, cmdChSize), - unitsPrepared: make(chan broadcastResult), + localPeer: localPeerID, + privKey: privKey, + config: *config, + log: log, + processor: processor, + committees: make(map[CommitteeID]*committeeState), + cmdCh: cmdCh, + unitsPrepared: make(chan broadcastResult), // Unsure of the fields below - processors: make(map[messageKey]chan<- shardDelivery), - finalised: NewTimeCache[messageKey](config.StaleMessageTimeout * 2), - eventCh: make(chan any, eventChSize), - cleanupCh: make(chan processorDone, cleanupChSize), - appEventCh: make(chan any, appEventChSize), - sendFn: sendFn, - } + connectedPeers: make(map[peer.ID]struct{}), + eventCh: make(chan any, eventChSize), + appEventCh: make(chan any, appEventChSize), + }, cmdCh } // registerCommittee creates the 
schedule and encoder for a new channel. @@ -165,6 +172,7 @@ func (e *Engine) registerCommittee( peers []PeerCommittee, peersKeys []*StakerID, ) error { + // todo(rdr): Why re-registration should be ignored, as far as I know, it shouldn't happen :think: if _, ok := e.committees[committeeID]; ok { e.log.Warn( "committee already registered, will ignore re-registration attempt", @@ -194,7 +202,7 @@ func (e *Engine) registerCommittee( } e.log.Info("registered new committee", - zap.Uint64("channel", uint64(committeeID)), + zap.Uint64("committeeID", uint64(committeeID)), zap.Int("peers", len(peers)), zap.Int("dataShards", schedule.NumDataShards()), zap.Int("codingShards", schedule.NumCodingShards()), @@ -207,6 +215,9 @@ func (e *Engine) registerCommittee( // currently running ones will continue until the timeout / stop naturally func (e *Engine) unregisterCommittee(committeeID CommitteeID) { delete(e.committees, committeeID) + // todo(rdr): We have to clean the processors, right? + // or will they shut down on their own eventually + // better to pass a context with cancelj e.log.Info("unregistered propeller committee", zap.Uint64("committee", uint64(committeeID)), @@ -241,7 +252,7 @@ func (e *Engine) prepareBroadcast(committeeID CommitteeID, data []byte) error { } // broacast receives Propeller units (built in `prepareBroadcast`) and sends them -func (e *Engine) broadcast(units []PropellerUnit) error { +func (e *Engine) broadcast(units []Unit) error { targetCommittee := units[0].CommitteeID cs, ok := e.committees[targetCommittee] @@ -263,17 +274,24 @@ func (e *Engine) broadcast(units []PropellerUnit) error { return nil } -// doHandleUnit routes an incoming unit to the correct processor, creating +// processUnit routes an incoming unit to the correct processor, creating // one if needed. 
-func (e *Engine) doHandleUnit(ctx context.Context, cmd *cmdHandleUnit) { - unit := cmd.unit - key := messageKey{ - Channel: unit.CommitteeID, - Publisher: unit.Publisher, - Root: unit.MerkleRoot, +func (e *Engine) processUnit(ctx context.Context, unit *Unit, sender peer.ID) { + if _, ok := e.committees[unit.CommitteeID]; !ok { + // note(rdr): maybe debug? + e.log.Warn("received key for unregistered committee, dropping", + zap.Uint64("committee id", uint64(unit.CommitteeID)), + ) + return } // Skip already-finalised messages. + // todo(rdr): add timestamps to message keys to avoid replay attacks on old messages + key := messageKey{ + CommitteeID: unit.CommitteeID, + Publisher: unit.Publisher, + Root: unit.MerkleRoot, + } if e.finalised.Contains(key) { return } @@ -290,7 +308,7 @@ func (e *Engine) doHandleUnit(ctx context.Context, cmd *cmdHandleUnit) { // Non-blocking send to the processor. If its buffer is full, the shard // is dropped (the processor can reconstruct from other shards). select { - case shardCh <- shardDelivery{Unit: unit, Sender: cmd.sender}: + case shardCh <- shardDelivery{Unit: unit, Sender: sender}: default: e.log.Warn("dropping shard: processor channel full", zap.Uint32("shard", uint32(unit.ShardIndex)), @@ -302,7 +320,7 @@ func (e *Engine) doHandleUnit(ctx context.Context, cmd *cmdHandleUnit) { // createProcessor spins up a new MessageProcessor goroutine for a message // we haven't seen before. 
func (e *Engine) createProcessor( - ctx context.Context, key messageKey, unit *PropellerUnit, + ctx context.Context, key messageKey, unit *Unit, ) chan<- shardDelivery { cs, ok := e.committees[unit.CommitteeID] if !ok { @@ -319,14 +337,11 @@ func (e *Engine) createProcessor( shardCh := make(chan shardDelivery, cs.scheduler.NumShards()) proc := NewMessageProcessor( - key.Channel, - key.Publisher, - key.Root, + unit.CommitteeID, + unit.Publisher, + unit.Root, e.localPeer, e.config, - cs.scheduler, - validator, - cs.encoder, shardCh, e.eventCh, e.sendFn, @@ -381,13 +396,13 @@ func (e *Engine) handleCommand(ctx context.Context, command engineCommand) { case *registerCommittee: err := e.registerCommittee(cmd.committeeID, cmd.peers, cmd.peersKeys) cmd.errCh <- err - case *cmdUnregister: + case *unregisterCommittee: e.unregisterCommittee(cmd.committeeID) - case *cmdBroadcast: + case *broadcast: err := e.prepareBroadcast(cmd.committeeID, cmd.msg) cmd.errCh <- err - case *cmdHandleUnit: - e.doHandleUnit(ctx, cmd) + case *processUnit: + e.processUnit(ctx, cmd.unit, cmd.sender) } } @@ -416,11 +431,6 @@ func (e *Engine) Run(ctx context.Context) error { case event := <-e.eventCh: // Forward application-visible events from processors. e.forwardEvent(event) - - case done := <-e.cleanupCh: - // Processor goroutine exited; clean up the processors map. 
- e.handleProcessorDone(done) - } } } diff --git a/consensus/propeller/engine_test.go b/consensus/propeller/engine_test.go index b92a4e67f0..5543e20042 100644 --- a/consensus/propeller/engine_test.go +++ b/consensus/propeller/engine_test.go @@ -21,7 +21,7 @@ type engineTestEnv struct { peers []peer.ID privKeys []crypto.PrivKey engines []*Engine - sentUnits map[peer.ID][]*PropellerUnit + sentUnits map[peer.ID][]*Unit sentMu sync.Mutex log utils.Logger } @@ -49,7 +49,7 @@ func newEngineTestEnv(t *testing.T, n int) *engineTestEnv { env := &engineTestEnv{ peers: peers, privKeys: privKeys, - sentUnits: make(map[peer.ID][]*PropellerUnit), + sentUnits: make(map[peer.ID][]*Unit), log: log, } @@ -74,7 +74,7 @@ func newEngineTestEnv(t *testing.T, n int) *engineTestEnv { // makeSendFn creates a SendUnitFunc that records sent units. func (env *engineTestEnv) makeSendFn() SendUnitFunc { - return func(_ context.Context, to peer.ID, unit *PropellerUnit) error { + return func(_ context.Context, to peer.ID, unit *Unit) error { env.sentMu.Lock() env.sentUnits[to] = append(env.sentUnits[to], unit) env.sentMu.Unlock() @@ -83,10 +83,10 @@ func (env *engineTestEnv) makeSendFn() SendUnitFunc { } // getSentUnits returns all units sent to a given peer. -func (env *engineTestEnv) getSentUnits(to peer.ID) []*PropellerUnit { +func (env *engineTestEnv) getSentUnits(to peer.ID) []*Unit { env.sentMu.Lock() defer env.sentMu.Unlock() - result := make([]*PropellerUnit, len(env.sentUnits[to])) + result := make([]*Unit, len(env.sentUnits[to])) copy(result, env.sentUnits[to]) return result } @@ -266,7 +266,7 @@ func TestEngine_HandleUnit_UnregisteredChannel(t *testing.T) { }() // Send a unit for an unregistered channel. 
- unit := &PropellerUnit{ + unit := &Unit{ CommitteeID: 99, Publisher: env.peers[1], MerkleRoot: MessageRoot{0x01}, @@ -317,7 +317,7 @@ func TestEngine_SendFailureEmitsEvent(t *testing.T) { StreamProtocol: "/propeller/test/0.1.0", MaxWireMessageSize: 1 << 20, }, - func(_ context.Context, _ peer.ID, _ *PropellerUnit) error { + func(_ context.Context, _ peer.ID, _ *Unit) error { return fmt.Errorf("simulated network failure") }, utils.NewNopZapLogger(), diff --git a/consensus/propeller/processor.go b/consensus/propeller/processor.go new file mode 100644 index 0000000000..21807a47b3 --- /dev/null +++ b/consensus/propeller/processor.go @@ -0,0 +1,281 @@ +package propeller + +import ( + "context" + "errors" + "fmt" + "sync" + "time" + + "github.com/libp2p/go-libp2p/core/peer" +) + +type unitWithSender struct { + unit *Unit + sender peer.ID +} + +type messageState uint64 + +const ( + preBuilt = iota + preReceived +) + +func (ms *messageState) NextState() { + *ms += 1 +} + +type subprocessor struct { + scheduler *Scheduler + localShardIndex ShardIndex + localShardWasBroadcast bool + + messageState messageState + unitsReceived []Unit +} + +func newSubprocessor(scheduler *Scheduler, localShardIndex ShardIndex) subprocessor { + return subprocessor{ + scheduler: scheduler, + localShardIndex: localShardIndex, + localShardWasBroadcast: false, + + messageState: preBuilt, + unitsReceived: make([]Unit, 0, scheduler.ReceiveThreshold()), + } +} + +func (s *subprocessor) Run(ctx context.Context, unitChan <-chan unitWithSender) error { + for { + select { + case <-ctx.Done(): + // todo(rdr): need to differentiate between context cancellation and timeout. + // can check for `context.DeadlineExceeded` + return ctx.Err() + case unitWithSender := <-unitChan: + // todo(rdr): validate that the unit is correct + // if the unit is incorrect penalize publisher (how?) 
+ + s.unitsReceived = append(s.unitsReceived, *unitWithSender.unit) + switch s.messageState { + case preBuilt: + // if the unit / shard is our own and we are pre-construction then we should + // broadcast our own shard (only once) + // todo(rdr): consider inlining this function? or use go naming ("once" in the name) + s.maybeBroacastLocalShard(unitWithSender.unit) + + // todo(rdr): do something with a signature that I don't understand very well + + if len(s.unitsReceived) == s.scheduler.BuildThreshold() { + s.messageState.NextState() + } + + case preReceived: + if len(s.unitsReceived) == s.scheduler.ReceiveThreshold() { + // broadcast and finish execution – but don't broadcast the local shard + } + } + + } + } +} + +// todo(rdr): this can probably be inlined? +func (s *subprocessor) maybeBroacastLocalShard(unit *Unit) { + if !s.localShardWasBroadcast && s.localShardIndex == unit.ShardIndex { + // broadcast shard index + s.localShardWasBroadcast = true + } +} + +// messageKey uniquely identifies a message within a committee. We track +// per-message state (processor, time cache) using this composite key +// because the same publisher could broadcast different messages (different +// roots) and we need to handle each independently. 
+type messageKey struct { + CommitteeID CommitteeID + Publisher peer.ID + Root MessageRoot +} + +type messageKeyWithError struct { + messageKey messageKey + error error +} + +type concurrentTasksBounds struct { + maxWorkers uint64 + maxWorkersPerPublisher uint64 +} + +// Processor handles all concurrent work on message processing +type Processor struct { + finalized *TimeCache[messageKey] + // todo(rdr): channel to communicate that a certain subprocessor has finished + done chan messageKeyWithError + + mu sync.Mutex + publisherTasks map[peer.ID]uint64 + tasks uint64 + // ---------------------------------- + subProcessors map[messageKey]chan unitWithSender + + // config inherited from Engine + localPeer peer.ID + timeout time.Duration + concurrentTasksBounds concurrentTasksBounds +} + +func NewProcessor(localPeer peer.ID, config *Config) *Processor { + timeout := config.StaleMessageTimeout + + return &Processor{ + finalized: NewTimeCache[messageKey](timeout), + done: make(chan messageKeyWithError), + + publisherTasks: make(map[peer.ID]uint64), + tasks: 0, + subProcessors: make(map[messageKey]chan unitWithSender), + + localPeer: localPeer, + timeout: timeout, + // todo(rdr): set this ones based on the config (or some consts?) + concurrentTasksBounds: concurrentTasksBounds{}, + } +} + +func (p *Processor) Run(ctx context.Context) { +} + +func (p *Processor) ProcessMessage( + ctx context.Context, + unit *Unit, + sender peer.ID, + scheduler *Scheduler, +) error { + key := messageKey{ + CommitteeID: unit.CommitteeID, + Publisher: unit.Publisher, + Root: unit.MerkleRoot, + } + if p.finalized.Contains(key) { + return nil + } + + // todo(rdr): currently on a single go-routine the validation is performed and then the unit + // is processed. 
This could be divided into: + // - A validation task that performs validation (go routine A) + // - A processing task that process the message (go routine B) + // - Then A will send the correct units to B + // This means that when many messages are received in quick succession, they can be validated + // non blockingly. This also means we have two go routines by sub processor than just a single + // one. Does it makes sense? + unitChan, err := p.subprocessorChannel(ctx, &key, scheduler) + if err != nil { + fmt.Errorf("couldn't get processor channel for key: %w", err) + } + + select { + case unitChan <- unitWithSender{unit: unit, sender: sender}: + return nil + default: + } + + return errors.New("dropping shard, processor channel full") +} + +// createSubprocessor creates a go-routine (subprocessor) that handles all the processing of `key`. +// It returns a channel through which this processor can be given units to process +// todo(rdr): I would like not to create a channel for everytime we have a different messageKey +// since that can be a bit rough to the GC, better to have a pool of them. Benchmarks will give +// the final word +func (p *Processor) createSubprocessor( + ctx context.Context, + key *messageKey, + scheduler *Scheduler, +) (chan unitWithSender, error) { + localShardIndex, err := scheduler.ShardIndexForPublisher(key.Publisher) + if err != nil { + return nil, fmt.Errorf("cannot create new subprocessor: %w", err) + } + + err = p.increaseTasks(key.Publisher) + if err != nil { + return nil, err + } + + // create communication channel + unitChan := make(chan unitWithSender) + p.subProcessors[*key] = unitChan + + // launch subprocessor + ctxWithTimeout, _ := context.WithTimeout(ctx, p.timeout) + // todo(rdr): passing to avoid closures. Does it makes sense? + // need to learn more how closures work in Go if it makes any difference + // in performance. + // todo(rdr): should I pass p.done as an argument? 
+ go func( + ctx context.Context, + messageKey messageKey, + scheduler *Scheduler, + localShardIndex ShardIndex, + unitChan <-chan unitWithSender, + ) { + subProcessor := newSubprocessor(scheduler, localShardIndex) + err := subProcessor.Run(ctx, unitChan) + p.done <- messageKeyWithError{ + messageKey: messageKey, + error: err, + } + }(ctxWithTimeout, *key, scheduler, localShardIndex, unitChan) + + return unitChan, nil +} + +// Given a message key it returns a channel that communicates with the subprocessor +// handling this specific message key. +func (p *Processor) subprocessorChannel( + ctx context.Context, + key *messageKey, + scheduler *Scheduler, +) (chan unitWithSender, error) { + unitChan, ok := p.subProcessors[*key] + if !ok { + return p.createSubprocessor(ctx, key, scheduler) + } + return unitChan, nil +} + +func (p *Processor) increaseTasks(publisher peer.ID) error { + p.mu.Lock() + defer p.mu.Unlock() + + if p.publisherTasks[publisher] == p.concurrentTasksBounds.maxWorkersPerPublisher { + return fmt.Errorf( + "tasks per publisher exceeded (max: %d): %s", + p.publisherTasks[publisher], + publisher, + ) + } + + if p.tasks == p.concurrentTasksBounds.maxWorkers { + return fmt.Errorf( + "max tasks that the processor can handle has been reached (max: %d)", + p.tasks, + ) + } + + p.publisherTasks[publisher] += 1 + p.tasks += 1 + + return nil +} + +func (p *Processor) decreaseTask(publisher peer.ID) { + p.mu.Lock() + defer p.mu.Unlock() + + p.publisherTasks[publisher] -= 1 + p.tasks -= 1 +} diff --git a/consensus/propeller/processor_test.go b/consensus/propeller/processor_test.go new file mode 100644 index 0000000000..bb61865bd4 --- /dev/null +++ b/consensus/propeller/processor_test.go @@ -0,0 +1 @@ +package propeller_test From c27f995fcedff4eb29e063b792f6f919cf6a318e Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Sat, 4 Apr 2026 14:48:28 +0100 Subject: [PATCH 09/40] chore(propeller): rename validator to deprecated validator --- .../{validator.go => 
deprecated_validator.go} | 28 +------------------ ...r_test.go => deprecated_validator_test.go} | 14 +++++----- 2 files changed, 8 insertions(+), 34 deletions(-) rename consensus/propeller/{validator.go => deprecated_validator.go} (82%) rename consensus/propeller/{validator_test.go => deprecated_validator_test.go} (96%) diff --git a/consensus/propeller/validator.go b/consensus/propeller/deprecated_validator.go similarity index 82% rename from consensus/propeller/validator.go rename to consensus/propeller/deprecated_validator.go index 9155e7198c..da954f8cbb 100644 --- a/consensus/propeller/validator.go +++ b/consensus/propeller/deprecated_validator.go @@ -3,35 +3,9 @@ package propeller import ( "fmt" - "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/peer" ) -// todo(rdr): Need to review this whole module - -// SignatureVerifier abstracts Ed25519 signature verification. The default -// implementation extracts the public key from a peer.ID and verifies using -// libp2p crypto. Tests can inject a mock to control verification outcomes. -type SignatureVerifier interface { - Verify(peerID peer.ID, data, signature []byte) (bool, error) -} - -// DefaultSignatureVerifier implements SignatureVerifier by extracting the -// public key embedded in a libp2p peer.ID. This works because peer.IDs -// for Ed25519 keys are derived from the public key, and for small keys -// the public key is embedded directly in the ID. -type DefaultSignatureVerifier struct{} - -func (DefaultSignatureVerifier) Verify( - peerID peer.ID, data, signature []byte, -) (bool, error) { - pubKey, err := peerID.ExtractPublicKey() - if err != nil { - return false, fmt.Errorf("extracting public key from peer %s: %w", peerID, err) - } - return pubKey.Verify(data, signature) -} - // Validator checks incoming PropellerUnits for correctness. 
Each check // serves a specific defensive purpose: // @@ -79,7 +53,7 @@ func NewValidator( // // Returns nil if valid, or a *ShardValidationError describing the failure. func (v *Validator) ValidateUnit( - unit *PropellerUnit, + unit *Unit, sender peer.ID, seenShards map[ShardIndex]bool, signatureVerified bool, diff --git a/consensus/propeller/validator_test.go b/consensus/propeller/deprecated_validator_test.go similarity index 96% rename from consensus/propeller/validator_test.go rename to consensus/propeller/deprecated_validator_test.go index 3f4e48fe38..44b39215fb 100644 --- a/consensus/propeller/validator_test.go +++ b/consensus/propeller/deprecated_validator_test.go @@ -50,7 +50,7 @@ func makeValidUnit( publisherKey crypto.PrivKey, publisher peer.ID, shardIndex ShardIndex, -) *PropellerUnit { +) *Unit { t.Helper() // Create a simple message and encode it. @@ -159,7 +159,7 @@ func TestValidator_SelfSending(t *testing.T) { setup := newValidatorTestSetup(t) v := NewValidator(setup.schedule, setup.localPeer, &mockVerifier{valid: true}) - unit := &PropellerUnit{Publisher: setup.publisher, ShardIndex: setup.shardIndex} + unit := &Unit{Publisher: setup.publisher, ShardIndex: setup.shardIndex} err := v.ValidateUnit(unit, setup.localPeer, nil, true) var valErr *ShardValidationError @@ -172,7 +172,7 @@ func TestValidator_ReceivedSelfPublishedShard(t *testing.T) { v := NewValidator(setup.schedule, setup.localPeer, &mockVerifier{valid: true}) // Unit claims we are the publisher. 
- unit := &PropellerUnit{Publisher: setup.localPeer, ShardIndex: 0} + unit := &Unit{Publisher: setup.localPeer, ShardIndex: 0} err := v.ValidateUnit(unit, setup.expectedSender, nil, true) var valErr *ShardValidationError @@ -184,7 +184,7 @@ func TestValidator_DuplicateShard(t *testing.T) { setup := newValidatorTestSetup(t) v := NewValidator(setup.schedule, setup.localPeer, &mockVerifier{valid: true}) - unit := &PropellerUnit{Publisher: setup.publisher, ShardIndex: setup.shardIndex} + unit := &Unit{Publisher: setup.publisher, ShardIndex: setup.shardIndex} seenShards := map[ShardIndex]bool{setup.shardIndex: true} err := v.ValidateUnit(unit, setup.expectedSender, seenShards, true) @@ -209,7 +209,7 @@ func TestValidator_UnexpectedSender(t *testing.T) { } require.NotEmpty(t, wrongSender) - unit := &PropellerUnit{Publisher: setup.publisher, ShardIndex: setup.shardIndex} + unit := &Unit{Publisher: setup.publisher, ShardIndex: setup.shardIndex} seenShards := make(map[ShardIndex]bool) err := v.ValidateUnit(unit, wrongSender, seenShards, true) @@ -239,7 +239,7 @@ func TestValidator_MerkleProofFailed(t *testing.T) { v := NewValidator(setup.schedule, setup.localPeer, &mockVerifier{valid: true}) // Create a unit with a bad Merkle proof. 
- unit := &PropellerUnit{ + unit := &Unit{ Publisher: setup.publisher, ShardIndex: setup.shardIndex, MerkleRoot: MessageRoot{0x01}, @@ -350,7 +350,7 @@ func TestValidator_ScheduleError(t *testing.T) { v := NewValidator(setup.schedule, setup.localPeer, &mockVerifier{valid: true}) _, unknownPeer := realPeer(99) - unit := &PropellerUnit{ + unit := &Unit{ Publisher: unknownPeer, ShardIndex: 0, ShardData: []byte("data"), From 9d114f0834f0311716814da9dde90f92de0b9eaa Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Sun, 5 Apr 2026 11:16:19 +0100 Subject: [PATCH 10/40] chore: apply changes to deprecated packages --- consensus/propeller/deprecated_processor.go | 2 +- consensus/propeller/deprecated_validator.go | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/consensus/propeller/deprecated_processor.go b/consensus/propeller/deprecated_processor.go index 22af75641f..3c79ccb8b3 100644 --- a/consensus/propeller/deprecated_processor.go +++ b/consensus/propeller/deprecated_processor.go @@ -92,7 +92,7 @@ func NewMessageProcessor( localPeer peer.ID, config Config, schedule *Scheduler, - validator *Validator, + validator *DeprecatedValidator, encoder Encoder, shardCh chan shardDelivery, eventCh chan<- any, diff --git a/consensus/propeller/deprecated_validator.go b/consensus/propeller/deprecated_validator.go index da954f8cbb..0ea7725725 100644 --- a/consensus/propeller/deprecated_validator.go +++ b/consensus/propeller/deprecated_validator.go @@ -6,7 +6,7 @@ import ( "github.com/libp2p/go-libp2p/core/peer" ) -// Validator checks incoming PropellerUnits for correctness. Each check +// DeprecatedValidator checks incoming PropellerUnits for correctness. Each check // serves a specific defensive purpose: // // - Self-sending check: prevents reflection attacks. @@ -21,19 +21,19 @@ import ( // // These checks are ordered from cheapest to most expensive so we reject // invalid units as early as possible. 
-type Validator struct { +type DeprecatedValidator struct { schedule *Scheduler localPeer peer.ID verifier SignatureVerifier } // NewValidator creates a validator for the given channel configuration. -func NewValidator( +func NewDeprecatedValidator( schedule *Scheduler, localPeer peer.ID, verifier SignatureVerifier, -) *Validator { - return &Validator{ +) *DeprecatedValidator { + return &DeprecatedValidator{ schedule: schedule, localPeer: localPeer, verifier: verifier, @@ -52,7 +52,7 @@ func NewValidator( // check after the first shard from the same message passes. // // Returns nil if valid, or a *ShardValidationError describing the failure. -func (v *Validator) ValidateUnit( +func (v *DeprecatedValidator) ValidateUnit( unit *Unit, sender peer.ID, seenShards map[ShardIndex]bool, From acf3efd654ba1ba1d26fb278542e84f5c7d0c73f Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Sun, 5 Apr 2026 11:16:36 +0100 Subject: [PATCH 11/40] feat: implement new validator logic --- consensus/propeller/validator.go | 111 ++++++++++++++++++++++++++ consensus/propeller/validator_test.go | 1 + 2 files changed, 112 insertions(+) create mode 100644 consensus/propeller/validator.go create mode 100644 consensus/propeller/validator_test.go diff --git a/consensus/propeller/validator.go b/consensus/propeller/validator.go new file mode 100644 index 0000000000..e1ce60855c --- /dev/null +++ b/consensus/propeller/validator.go @@ -0,0 +1,111 @@ +package propeller + +import ( + "bytes" + "fmt" + "time" + + "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/peer" +) + +type verified struct { + verified bool + signature []byte + nonce time.Duration +} + +// Retuns if +func (v *verified) Verify(unit *Unit, pubKey crypto.PubKey) error { + if v.verified { + if bytes.Equal(v.signature, unit.Signature) && v.nonce == unit.Nonce { + return nil + } + // todo(rdr): add error information. Perhaps build an error type? 
+ return fmt.Errorf("unit signature or nonce missmatch") + } + + err := verifyMessageIDSignature( + unit.CommitteeID, + unit.MerkleRoot, + unit.Signature, + unit.Nonce, + pubKey, + ) + if err != nil { + // add error information + return err + } + + *v = verified{ + verified: true, + // todo(rdr): by storing a field of unit.Signature am I forcing the whole `unit` to + // continue to exist on the heap, or the remaining fields can be cleaned. Probably the + // latter. + signature: unit.Signature, + nonce: unit.Nonce, + } + + return nil +} + +func verifyMessageIDSignature( + committeeID CommitteeID, + root MessageRoot, + signature []byte, + nonce time.Duration, + publisherPubKey crypto.PubKey, +) error { + panic("not yet implemented") +} + +// Validates all the incoming units / shards given a committee and the publisher +type Validator struct { + // Required fields to perform the validation + committeeID CommitteeID + publisher peer.ID + publisherPubKey crypto.PubKey + messageRoot MessageRoot + scheduler *Scheduler + + // Once the validation is done it's stored here, subsequent runs + // compare against it + verified verified + + // track of every shard index received + receivedShards map[ShardIndex]struct{} +} + +func NewValidator(key *messageKey, scheduler *Scheduler) Validator { + return Validator{ + committeeID: key.CommitteeID, + publisher: key.Publisher, + messageRoot: key.Root, + publisherPubKey: nil, // todo(rdr): nil for now, need to think how to pass this one + scheduler: scheduler, + receivedShards: make(map[ShardIndex]struct{}, scheduler.NumDataShards()), + } +} + +func (v *Validator) ValidateUnit(unit *Unit, sender peer.ID) error { + // todo(rdr): Do I need to check that comitteeID, publisher and messageRoot to be the + // same as the one being hold by the validator, or is that infered from the signature being + // correct or wrong? 
+ + if _, ok := v.receivedShards[unit.ShardIndex]; ok { + return fmt.Errorf("duplicated shard %d received", unit.ShardIndex) + } + + err := v.scheduler.ValidateShardOrigin(sender, v.publisher, unit.ShardIndex) + if err != nil { + } + + if err = v.verified.Verify(unit, v.publisherPubKey); err != nil { + return nil + } + + // Cache the verified shard to avoid re-verification + v.receivedShards[unit.ShardIndex] = struct{}{} + + return nil +} diff --git a/consensus/propeller/validator_test.go b/consensus/propeller/validator_test.go new file mode 100644 index 0000000000..bb61865bd4 --- /dev/null +++ b/consensus/propeller/validator_test.go @@ -0,0 +1 @@ +package propeller_test From 22cdef034fae83c78365b261c27a52bab5383b59 Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Sun, 5 Apr 2026 11:18:02 +0100 Subject: [PATCH 12/40] chore(consensus/propeller): general continue cleanup of propeller --- consensus/propeller/engine.go | 129 ++++++------------------------- consensus/propeller/processor.go | 32 +++++++- consensus/propeller/propeller.go | 57 ++++++++++++-- consensus/propeller/scheduler.go | 13 ++++ consensus/propeller/sharding.go | 10 ++- consensus/propeller/timecache.go | 9 +-- consensus/propeller/types.go | 26 +------ consensus/propeller/unit.go | 5 ++ 8 files changed, 129 insertions(+), 152 deletions(-) diff --git a/consensus/propeller/engine.go b/consensus/propeller/engine.go index deb6f69643..962b11dc1d 100644 --- a/consensus/propeller/engine.go +++ b/consensus/propeller/engine.go @@ -26,7 +26,10 @@ type broadcastResult struct { } // todo(rdr): using String until I find a better type -type StakerID string +type StakerID struct { + peerID peer.ID + pubKey crypto.PubKey +} // Holds the state for a Committee ID: // - The `scheduler` represents the Propeller Tree of peers @@ -39,9 +42,9 @@ type committeeState struct { // state of a current committee. They both should be created and closed at the // same time. 
If it is like this then it stands to reason that it should be coupled // here. Not 100% sure right now, so leaving a big todo for now. - proccessor *MessageProcessor + // todo(rdr): why do we need this - peerKeys []StakerID + peerKeys map[peer.ID]crypto.PubKey } // engineCommand is a tagged union of commands sent to the engine's Run() loop. @@ -181,15 +184,15 @@ func (e *Engine) registerCommittee( return nil } - stakerIDs := make([]StakerID, len(peersKeys)) - for i := range peersKeys { - if peersKeys[i] != nil { - stakerIDs[i] = *peersKeys[i] - } else { - // todo(rdr): re-check this flow once implementation is complete - panic("received nil key, they shoudln't be nil") - } - } + // stakerIDs := make([]StakerID, len(peersKeys)) + // for i := range peersKeys { + // if peersKeys[i] != nil { + // stakerIDs[i] = *peersKeys[i] + // } else { + // // todo(rdr): re-check this flow once implementation is complete + // panic("received nil key, they shoudln't be nil") + // } + // } schedule, err := NewScheduler(e.localPeer, peers) if err != nil { @@ -198,7 +201,8 @@ func (e *Engine) registerCommittee( e.committees[committeeID] = &committeeState{ scheduler: schedule, - peerKeys: stakerIDs, + // todo(rdr): need to add the peer pub keys + peerKeys: nil, } e.log.Info("registered new committee", @@ -277,7 +281,8 @@ func (e *Engine) broadcast(units []Unit) error { // processUnit routes an incoming unit to the correct processor, creating // one if needed. func (e *Engine) processUnit(ctx context.Context, unit *Unit, sender peer.ID) { - if _, ok := e.committees[unit.CommitteeID]; !ok { + cs, ok := e.committees[unit.CommitteeID] + if !ok { // note(rdr): maybe debug? e.log.Warn("received key for unregistered committee, dropping", zap.Uint64("committee id", uint64(unit.CommitteeID)), @@ -285,99 +290,10 @@ func (e *Engine) processUnit(ctx context.Context, unit *Unit, sender peer.ID) { return } - // Skip already-finalised messages. 
- // todo(rdr): add timestamps to message keys to avoid replay attacks on old messages - key := messageKey{ - CommitteeID: unit.CommitteeID, - Publisher: unit.Publisher, - Root: unit.MerkleRoot, - } - if e.finalised.Contains(key) { - return - } - - // Route to existing processor or create a new one. - shardCh, exists := e.processors[key] - if !exists { - shardCh = e.createProcessor(ctx, key, unit) - if shardCh == nil { - return // Channel not registered; logged inside createProcessor. - } - } - - // Non-blocking send to the processor. If its buffer is full, the shard - // is dropped (the processor can reconstruct from other shards). - select { - case shardCh <- shardDelivery{Unit: unit, Sender: sender}: - default: - e.log.Warn("dropping shard: processor channel full", - zap.Uint32("shard", uint32(unit.ShardIndex)), - zap.Stringer("publisher", unit.Publisher), - ) - } -} - -// createProcessor spins up a new MessageProcessor goroutine for a message -// we haven't seen before. -func (e *Engine) createProcessor( - ctx context.Context, key messageKey, unit *Unit, -) chan<- shardDelivery { - cs, ok := e.committees[unit.CommitteeID] - if !ok { - e.log.Warn("received unit for unregistered channel", - zap.Uint32("channel", uint32(unit.CommitteeID)), - ) - return nil + err := e.processor.ProcessMessage(ctx, unit, sender, cs.scheduler) + if err != nil { + e.log.Error("cannot process incoming unit", zap.Error(err)) } - - validator := NewValidator(cs.scheduler, e.localPeer, &DefaultSignatureVerifier{}) - - // Buffer the shard channel so the engine doesn't block when delivering - // multiple shards in rapid succession. - shardCh := make(chan shardDelivery, cs.scheduler.NumShards()) - - proc := NewMessageProcessor( - unit.CommitteeID, - unit.Publisher, - unit.Root, - e.localPeer, - e.config, - shardCh, - e.eventCh, - e.sendFn, - ) - - e.processors[key] = shardCh - - // Launch the processor goroutine. It will run until finalisation, - // timeout, or context cancellation. 
The cleanup signal goes to a - // dedicated channel so it cannot be blocked by a full eventCh. - go func() { - proc.Run(ctx) - select { - case e.cleanupCh <- processorDone{key: key}: - case <-ctx.Done(): - } - }() - - return shardCh -} - -// processorDone is an internal event signalling that a processor's goroutine -// has exited. The engine uses this to clean up the processors map. -type processorDone struct { - key messageKey -} - -// handleProcessorDone cleans up after a processor goroutine exits. -func (e *Engine) handleProcessorDone(done processorDone) { - delete(e.processors, done.key) - e.finalised.Add(done.key) - - // Periodically clean up expired entries in the time cache. - // Amortised cost: we do it on every processor exit, which is - // infrequent relative to shard processing. - e.finalised.Cleanup() } // forwardEvent sends an event to the application's event channel. Non-blocking @@ -431,6 +347,7 @@ func (e *Engine) Run(ctx context.Context) error { case event := <-e.eventCh: // Forward application-visible events from processors. 
e.forwardEvent(event) + } } } diff --git a/consensus/propeller/processor.go b/consensus/propeller/processor.go index 21807a47b3..c62bbeda89 100644 --- a/consensus/propeller/processor.go +++ b/consensus/propeller/processor.go @@ -7,7 +7,9 @@ import ( "sync" "time" + "github.com/NethermindEth/juno/utils" "github.com/libp2p/go-libp2p/core/peer" + "go.uber.org/zap" ) type unitWithSender struct { @@ -31,11 +33,14 @@ type subprocessor struct { localShardIndex ShardIndex localShardWasBroadcast bool + validator Validator messageState messageState unitsReceived []Unit } -func newSubprocessor(scheduler *Scheduler, localShardIndex ShardIndex) subprocessor { +func newSubprocessor( + key *messageKey, scheduler *Scheduler, localShardIndex ShardIndex, +) subprocessor { return subprocessor{ scheduler: scheduler, localShardIndex: localShardIndex, @@ -125,6 +130,7 @@ type Processor struct { localPeer peer.ID timeout time.Duration concurrentTasksBounds concurrentTasksBounds + log utils.StructuredLogger } func NewProcessor(localPeer peer.ID, config *Config) *Processor { @@ -146,6 +152,24 @@ func NewProcessor(localPeer peer.ID, config *Config) *Processor { } func (p *Processor) Run(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case done := <-p.done: + if done.error != nil { + p.log.Error( + "subprocessor error", + // todo(rdr): need to use proper zap logger here + zap.Any("message key", done.messageKey), + zap.Error(done.error), + ) + } + p.decreaseTask(done.messageKey.Publisher) + delete(p.subProcessors, done.messageKey) + + } + } } func (p *Processor) ProcessMessage( @@ -169,8 +193,8 @@ func (p *Processor) ProcessMessage( // - A processing task that process the message (go routine B) // - Then A will send the correct units to B // This means that when many messages are received in quick succession, they can be validated - // non blockingly. This also means we have two go routines by sub processor than just a single - // one. Does it makes sense? 
+ // non blockingly. This also means we have two go routines for sub processor than just a single + // one. unitChan, err := p.subprocessorChannel(ctx, &key, scheduler) if err != nil { fmt.Errorf("couldn't get processor channel for key: %w", err) @@ -222,7 +246,7 @@ func (p *Processor) createSubprocessor( localShardIndex ShardIndex, unitChan <-chan unitWithSender, ) { - subProcessor := newSubprocessor(scheduler, localShardIndex) + subProcessor := newSubprocessor(&key, scheduler, localShardIndex) err := subProcessor.Run(ctx, unitChan) p.done <- messageKeyWithError{ messageKey: messageKey, diff --git a/consensus/propeller/propeller.go b/consensus/propeller/propeller.go index 620b5b4fc0..a07cb315dd 100644 --- a/consensus/propeller/propeller.go +++ b/consensus/propeller/propeller.go @@ -5,15 +5,15 @@ import ( "context" "io" + pb "github.com/NethermindEth/juno/consensus/propeller/proto" "github.com/NethermindEth/juno/utils" "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/network" "go.uber.org/zap" + "google.golang.org/protobuf/proto" ) -const propellerProtocolID = "/propeller/0.0.1" - // This would represent the propeller service that glues the whole // thing to p2p. Thing is, I've no clue how to do that. 
type Service interface{} @@ -21,6 +21,7 @@ type Service interface{} type propellerService struct { host host.Host engine *Engine + cmdCh chan<- engineCommand config Config log utils.Logger } @@ -31,25 +32,22 @@ func New( config *Config, log utils.Logger, ) Service { - engine := NewEngine( + engine, cmdCh := NewEngine( privKey, config, - nil, log, ) return &propellerService{ host: host, engine: engine, + cmdCh: cmdCh, config: *config, log: log, } } -func (s *propellerService) Run(ctx context.Context) { -} - -func (s *propellerService) handleInboudStream(stream network.Stream) { +func (s *propellerService) receivePropellerUnits(stream network.Stream) { defer stream.Close() sender := stream.Conn().RemotePeer() @@ -64,4 +62,47 @@ func (s *propellerService) handleInboudStream(stream network.Stream) { zap.Error(err), ) } + + var batch pb.PropellerUnitBatch + err = proto.Unmarshal(buf.Bytes(), &batch) + if err != nil { + s.log.Debug("error unmarshalling propeller batch", + zap.Stringer("peer", sender), + zap.Error(err), + ) + } + + for _, protoUnit := range batch.GetBatch() { + unit := UnitFromProto(protoUnit) + // send unit to engine + s.cmdCh <- processUnit{ + &unit, + sender, + } + } +} + +func (s *propellerService) Run(ctx context.Context) error { + go func() { + err := s.engine.Run(ctx) + if err != nil { + s.log.Error("shutting down propeller engine", zap.Error(err)) + return + } + s.log.Info("shutting down propeller engine") + }() + + s.host.SetStreamHandler(s.config.StreamProtocol, s.receivePropellerUnits) + defer s.host.RemoveStreamHandler(s.config.StreamProtocol) + + for { + select { + case <-ctx.Done(): + return ctx.Err() + } + // todo(rdr): handle the engines output such as units to broadcast + } +} + +func (s *propellerService) broadcast() { } diff --git a/consensus/propeller/scheduler.go b/consensus/propeller/scheduler.go index 075791cd30..48e9e4e02b 100644 --- a/consensus/propeller/scheduler.go +++ b/consensus/propeller/scheduler.go @@ -117,6 +117,19 @@ func 
(s *Scheduler) NumCodingShards() int { return s.numCodingShards } // NumShards returns the total number of shards (data + coding = N-1). func (s *Scheduler) NumTotalShards() int { return s.numDataShards + s.numCodingShards } +// Minimum (inclusive) amount of shards required to build a message +func (s *Scheduler) BuildThreshold() int { + return s.numDataShards +} + +// Minimum (inclusive) amount of shards required to guarantee a message is received +func (s *Scheduler) ReceiveThreshold() int { + if len(s.peers) <= 3 { + return s.BuildThreshold() + } + return s.numDataShards * 2 +} + func (s *Scheduler) publisherIndex(publisher peer.ID) (int, error) { publisherIndex, found := slices.BinarySearchFunc( s.peers, diff --git a/consensus/propeller/sharding.go b/consensus/propeller/sharding.go index 292042c6d6..08a1c411b2 100644 --- a/consensus/propeller/sharding.go +++ b/consensus/propeller/sharding.go @@ -12,13 +12,14 @@ import ( ) // CreatePropellerUnits creates the PropellerUnits for publishing +// todo(rdr): maybe call it create message for sharing or somth like that func CreatePropellerUnits( committeeID CommitteeID, message []byte, privKey crypto.PrivKey, numDataShards, parity int, -) ([]PropellerUnit, error) { +) ([]Unit, error) { publisherID, err := peer.IDFromPrivateKey(privKey) if err != nil { return nil, fmt.Errorf("getting publisher id from private key: %w", publisherID) @@ -38,11 +39,11 @@ func CreatePropellerUnits( return nil, err } - units := make([]PropellerUnit, len(encodedMessage)) + units := make([]Unit, len(encodedMessage)) for i, shard := range encodedMessage { merkleProof := merkleTree[i] - units[i] = PropellerUnit{ + units[i] = Unit{ CommitteeID: committeeID, Publisher: publisherID, MerkleRoot: messageRoot, @@ -57,8 +58,9 @@ func CreatePropellerUnits( // DecodePropellerUnit receives Propeller units, recovers any missing data and returns // the fully verified message, together with the corresponding shard data and merkle proof. 
+// todo(rdr): maybe call it decode received message func DecodePropellerUnit( - units []PropellerUnit, + units []Unit, messageRoot MessageRoot, localShardIndex ShardIndex, numDataShards int, diff --git a/consensus/propeller/timecache.go b/consensus/propeller/timecache.go index 8f539d8605..cf229ed695 100644 --- a/consensus/propeller/timecache.go +++ b/consensus/propeller/timecache.go @@ -16,8 +16,6 @@ type TimeCache[K comparable] struct { mu sync.Mutex entries map[K]time.Time ttl time.Duration - // nowFn is injectable for testing. In production it is time.Now. - nowFn func() time.Time } // NewTimeCache creates a cache where entries expire after the given TTL. @@ -25,7 +23,6 @@ func NewTimeCache[K comparable](ttl time.Duration) *TimeCache[K] { return &TimeCache[K]{ entries: make(map[K]time.Time), ttl: ttl, - nowFn: time.Now, } } @@ -34,7 +31,7 @@ func NewTimeCache[K comparable](ttl time.Duration) *TimeCache[K] { func (c *TimeCache[K]) Add(key K) { c.mu.Lock() defer c.mu.Unlock() - c.entries[key] = c.nowFn().Add(c.ttl) + c.entries[key] = time.Now().Add(c.ttl) } // Contains returns true if the key is present and has not expired. @@ -47,7 +44,7 @@ func (c *TimeCache[K]) Contains(key K) bool { if !ok { return false } - return c.nowFn().Before(expiry) + return time.Now().Before(expiry) } // Cleanup removes all expired entries from the cache. 
Call this periodically @@ -56,7 +53,7 @@ func (c *TimeCache[K]) Contains(key K) bool { func (c *TimeCache[K]) Cleanup() { c.mu.Lock() defer c.mu.Unlock() - now := c.nowFn() + now := time.Now() for k, expiry := range c.entries { if !now.Before(expiry) { delete(c.entries, k) diff --git a/consensus/propeller/types.go b/consensus/propeller/types.go index da24cd8651..a120009fb6 100644 --- a/consensus/propeller/types.go +++ b/consensus/propeller/types.go @@ -13,6 +13,7 @@ import ( "github.com/NethermindEth/juno/consensus/propeller/merkle" "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/protocol" ) // CommitteeID identifies a committee or logical broadcast group. Multiple committees @@ -38,7 +39,7 @@ type Config struct { // StreamProtocol is the libp2p protocol identifier used for direct // shard transfers between peers. - StreamProtocol string + StreamProtocol protocol.ID // MaxWireMessageSize caps the size of a single serialised PropellerUnit // on the wire. Units exceeding this are rejected to prevent memory @@ -55,29 +56,6 @@ func DefaultConfig() Config { } } -// PropellerUnit is the atomic wire message: one erasure-coded shard plus -// the metadata needed for independent verification. Each unit is self-contained -// so a receiver can validate it without any other shards. -type PropellerUnit struct { - CommitteeID CommitteeID // Which committee this belongs to - Publisher peer.ID // Original message author - MerkleRoot MessageRoot // Merkle root binding all shards together - MerkleProof merkle.Proof // Merkle inclusion proof for this shard - Signature []byte // Publisher's Ed25519 signature over the root - ShardIndex ShardIndex // This shard's position in the erasure-coded output - ShardData []byte // The actual data fragment -} - -// messageKey uniquely identifies a message within a channel. 
We track -// per-message state (processor, time cache) using this composite key -// because the same publisher could broadcast different messages (different -// roots) and we need to handle each independently. -type messageKey struct { - Channel CommitteeID - Publisher peer.ID - Root MessageRoot -} - // --------------------------------------------------------------------------- // Events: structured outputs from the engine to the application layer. // Each event is emitted at most once per message lifecycle. diff --git a/consensus/propeller/unit.go b/consensus/propeller/unit.go index b60a9ee88d..404702b82a 100644 --- a/consensus/propeller/unit.go +++ b/consensus/propeller/unit.go @@ -1,6 +1,8 @@ package propeller import ( + "time" + "github.com/NethermindEth/juno/consensus/propeller/merkle" pb "github.com/NethermindEth/juno/consensus/propeller/proto" "github.com/libp2p/go-libp2p/core/peer" @@ -18,6 +20,9 @@ type Unit struct { Signature []byte // Publisher's Ed25519 signature over the root ShardIndex ShardIndex // This shard's position in the erasure-coded output ShardData []byte // The actual data fragment + // todo(rdr): calling it nonce because that's what is called on the rust side but + // time stamp or some other name would be better + Nonce time.Duration // Strictly increasing number, starting from the Unix epoch } func UnitFromProto(protoUnit *pb.PropellerUnit) Unit { From a7aee53b1888681b4a8499cd0cbd472681c9ee5f Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Sun, 5 Apr 2026 12:34:24 +0100 Subject: [PATCH 13/40] chore: general sanitazing of the code --- consensus/propeller/deprecated_processor.go | 4 +- consensus/propeller/deprecated_validator.go | 4 +- .../propeller/deprecated_validator_test.go | 2 +- consensus/propeller/engine.go | 96 ------------- consensus/propeller/engine_test.go | 2 +- consensus/propeller/processor.go | 62 +++++--- consensus/propeller/scheduler.go | 40 +++--- consensus/propeller/sharding.go | 5 +- consensus/propeller/unit.go | 22 +-- 
consensus/propeller/validator.go | 134 +++++++++++------- 10 files changed, 166 insertions(+), 205 deletions(-) diff --git a/consensus/propeller/deprecated_processor.go b/consensus/propeller/deprecated_processor.go index 3c79ccb8b3..8268bb71bb 100644 --- a/consensus/propeller/deprecated_processor.go +++ b/consensus/propeller/deprecated_processor.go @@ -169,7 +169,7 @@ func (p *MessageProcessor) handleShard(ctx context.Context, delivery shardDelive ); err != nil { p.emitEvent(EventShardValidationFailed{ Sender: delivery.Sender, - ClaimedRoot: unit.MerkleRoot, + ClaimedRoot: unit.MessageRoot, ClaimedPublisher: unit.Publisher, Err: err, }) @@ -252,7 +252,7 @@ func (p *MessageProcessor) handlePreConstruction(ctx context.Context) { p.myShardUnit = &Unit{ CommitteeID: p.committeeID, Publisher: p.publisher, - MerkleRoot: p.root, + MessageRoot: p.root, Signature: p.storedSignature, ShardIndex: myShard, ShardData: shardsCopy[myShard], diff --git a/consensus/propeller/deprecated_validator.go b/consensus/propeller/deprecated_validator.go index 0ea7725725..ec0c390cae 100644 --- a/consensus/propeller/deprecated_validator.go +++ b/consensus/propeller/deprecated_validator.go @@ -111,7 +111,7 @@ func (v *DeprecatedValidator) ValidateUnit( // 5. Verify the Merkle inclusion proof. This ensures the shard data // is consistent with the tree root the publisher committed to. if !VerifyMerkleProof( - unit.MerkleRoot, unit.ShardData, uint32(unit.ShardIndex), unit.MerkleProof, + unit.MessageRoot, unit.ShardData, uint32(unit.ShardIndex), unit.MerkleProof, ) { return &ShardValidationError{ Reason: ReasonMerkleProofVerificationFailed, @@ -123,7 +123,7 @@ func (v *DeprecatedValidator) ValidateUnit( // expensive check (public-key crypto), so we skip it if we've already // verified the same root from this publisher. 
if !signatureVerified { - payload := SignPayload(unit.MerkleRoot) + payload := SignPayload(unit.MessageRoot) ok, err := v.verifier.Verify(unit.Publisher, payload, unit.Signature) if err != nil { return &ShardValidationError{ diff --git a/consensus/propeller/deprecated_validator_test.go b/consensus/propeller/deprecated_validator_test.go index 44b39215fb..37b50ceb16 100644 --- a/consensus/propeller/deprecated_validator_test.go +++ b/consensus/propeller/deprecated_validator_test.go @@ -242,7 +242,7 @@ func TestValidator_MerkleProofFailed(t *testing.T) { unit := &Unit{ Publisher: setup.publisher, ShardIndex: setup.shardIndex, - MerkleRoot: MessageRoot{0x01}, + MessageRoot: MessageRoot{0x01}, ShardData: []byte("data"), MerkleProof: MerkleProof{Siblings: [][32]byte{{0xde, 0xad}}}, } diff --git a/consensus/propeller/engine.go b/consensus/propeller/engine.go index 962b11dc1d..e577f0cd3a 100644 --- a/consensus/propeller/engine.go +++ b/consensus/propeller/engine.go @@ -350,99 +350,3 @@ func (e *Engine) Run(ctx context.Context) error { } } } - -// Probably unuseful code - -// Events returns the channel on which the application receives protocol -// events. The caller should read from this channel continuously to avoid -// backpressure on the engine. -// func (e *Engine) Events() <-chan any { -// return e.appEventCh -// } - -// RegisterCommitee registers a committee with its peer set. This must be called -// before broadcasting on or receiving shards for a committee. -// -// The method blocks until the command is processed by the engine's Run() loop. 
-// todo(rdr): I am not sure this method should exist, or at least be defined at engine level -// func (e *Engine) RegisterCommittee( -// ctx context.Context, -// committeeID CommitteeID, -// peers []peer.ID, -// ) error { -// errCh := make(chan error, 1) -// select { -// case e.cmdCh <- ®isterCommittee{ -// committeeID: committeeID, -// peers: peers, -// errCh: errCh, -// }: -// case <-ctx.Done(): -// return ctx.Err() -// } -// -// select { -// case err := <-errCh: -// return err -// case <-ctx.Done(): -// return ctx.Err() -// } -// } - -// // UnregisterCommittee removes a committee. Existing processors for that channel -// // will continue running until they finalise or time out, but no new -// // processors will be created. -// func (e *Engine) UnregisterCommittee(ctx context.Context, channel CommitteeID) error { -// select { -// case e.cmdCh <- &cmdUnregister{committeeID: channel}: -// return nil -// case <-ctx.Done(): -// return ctx.Err() -// } -// } - -// // Broadcast encodes and distributes a message to all peers in the given -// // channel. The local node acts as the publisher. -// // -// // The method blocks until the command is processed by the engine's Run() loop. -// func (e *Engine) Broadcast( -// ctx context.Context, channel CommitteeID, msg []byte, -// ) error { -// errCh := make(chan error, 1) -// select { -// case e.cmdCh <- &cmdBroadcast{ -// channel: channel, -// msg: msg, -// errCh: errCh, -// }: -// case <-ctx.Done(): -// return ctx.Err() -// } -// -// select { -// case err := <-errCh: -// return err -// case <-ctx.Done(): -// return ctx.Err() -// } -// } - -// // HandleUnit routes an incoming PropellerUnit from the network to the -// // appropriate message processor. This method is non-blocking: it sends -// // the unit to the engine's command channel. -// func (e *Engine) HandleUnit(unit *PropellerUnit, sender peer.ID) { -// // Non-blocking send: if the command channel is full, drop the unit. 
-// // This provides backpressure against flood attacks. The sender can -// // retry or the processor can reconstruct from other shards. -// select { -// case e.cmdCh <- &cmdHandleUnit{ -// unit: unit, -// sender: sender, -// }: -// default: -// e.log.Warn("dropping incoming unit: command channel full", -// zap.Uint32("shard", uint32(unit.ShardIndex)), -// zap.Stringer("publisher", unit.Publisher), -// ) -// } -// } diff --git a/consensus/propeller/engine_test.go b/consensus/propeller/engine_test.go index 5543e20042..1d4b57fdef 100644 --- a/consensus/propeller/engine_test.go +++ b/consensus/propeller/engine_test.go @@ -269,7 +269,7 @@ func TestEngine_HandleUnit_UnregisteredChannel(t *testing.T) { unit := &Unit{ CommitteeID: 99, Publisher: env.peers[1], - MerkleRoot: MessageRoot{0x01}, + MessageRoot: MessageRoot{0x01}, ShardIndex: 0, ShardData: []byte("data"), } diff --git a/consensus/propeller/processor.go b/consensus/propeller/processor.go index c62bbeda89..2491c7edea 100644 --- a/consensus/propeller/processor.go +++ b/consensus/propeller/processor.go @@ -46,6 +46,7 @@ func newSubprocessor( localShardIndex: localShardIndex, localShardWasBroadcast: false, + validator: NewValidator(key, scheduler), messageState: preBuilt, unitsReceived: make([]Unit, 0, scheduler.ReceiveThreshold()), } @@ -74,6 +75,8 @@ func (s *subprocessor) Run(ctx context.Context, unitChan <-chan unitWithSender) if len(s.unitsReceived) == s.scheduler.BuildThreshold() { s.messageState.NextState() + // trigger message rebuilding - can it be done in a non-blocking way? + // does it makes sense to do it in a non-blocking way? 
} case preReceived: @@ -102,6 +105,11 @@ type messageKey struct { CommitteeID CommitteeID Publisher peer.ID Root MessageRoot + Nonce Nonce +} + +func (mk *messageKey) String() string { + return fmt.Sprintf("%+v", *mk) } type messageKeyWithError struct { @@ -116,16 +124,17 @@ type concurrentTasksBounds struct { // Processor handles all concurrent work on message processing type Processor struct { + // to avoid processing units already finalized finalized *TimeCache[messageKey] - // todo(rdr): channel to communicate that a certain subprocessor has finished + // channel that communicates when a subprocessor has finished done chan messageKeyWithError + subProcessors map[messageKey]chan unitWithSender + + // track current open and closed tasks to avoid resource starvation mu sync.Mutex publisherTasks map[peer.ID]uint64 tasks uint64 - // ---------------------------------- - subProcessors map[messageKey]chan unitWithSender - // config inherited from Engine localPeer peer.ID timeout time.Duration @@ -147,7 +156,11 @@ func NewProcessor(localPeer peer.ID, config *Config) *Processor { localPeer: localPeer, timeout: timeout, // todo(rdr): set this ones based on the config (or some consts?) 
- concurrentTasksBounds: concurrentTasksBounds{}, + concurrentTasksBounds: concurrentTasksBounds{ + // dummy values for now + maxWorkers: 1000, + maxWorkersPerPublisher: 250, + }, } } @@ -156,22 +169,23 @@ func (p *Processor) Run(ctx context.Context) { select { case <-ctx.Done(): return - case done := <-p.done: - if done.error != nil { + case finishedSubP := <-p.done: + if finishedSubP.error != nil { p.log.Error( "subprocessor error", - // todo(rdr): need to use proper zap logger here - zap.Any("message key", done.messageKey), - zap.Error(done.error), + zap.String("message key", finishedSubP.messageKey.String()), + zap.Error(finishedSubP.error), ) } - p.decreaseTask(done.messageKey.Publisher) - delete(p.subProcessors, done.messageKey) + p.decreaseTask(finishedSubP.messageKey.Publisher) + delete(p.subProcessors, finishedSubP.messageKey) } } } +// ProcessMessage validates and process the received `unit` non-blockingly. It returns an +// error if the unit couldn't start processing. func (p *Processor) ProcessMessage( ctx context.Context, unit *Unit, @@ -181,7 +195,8 @@ func (p *Processor) ProcessMessage( key := messageKey{ CommitteeID: unit.CommitteeID, Publisher: unit.Publisher, - Root: unit.MerkleRoot, + Root: unit.MessageRoot, + Nonce: unit.Nonce, } if p.finalized.Contains(key) { return nil @@ -193,8 +208,8 @@ func (p *Processor) ProcessMessage( // - A processing task that process the message (go routine B) // - Then A will send the correct units to B // This means that when many messages are received in quick succession, they can be validated - // non blockingly. This also means we have two go routines for sub processor than just a single - // one. + // non blockingly. This also means we have two go routines for sub processor rather than just + // a single one. 
unitChan, err := p.subprocessorChannel(ctx, &key, scheduler) if err != nil { fmt.Errorf("couldn't get processor channel for key: %w", err) @@ -221,7 +236,9 @@ func (p *Processor) createSubprocessor( ) (chan unitWithSender, error) { localShardIndex, err := scheduler.ShardIndexForPublisher(key.Publisher) if err != nil { - return nil, fmt.Errorf("cannot create new subprocessor: %w", err) + return nil, fmt.Errorf( + "cannot get local shard index for publisher %s: %w", key.Publisher, err, + ) } err = p.increaseTasks(key.Publisher) @@ -241,7 +258,7 @@ func (p *Processor) createSubprocessor( // todo(rdr): should I pass p.done as an argument? go func( ctx context.Context, - messageKey messageKey, + key messageKey, scheduler *Scheduler, localShardIndex ShardIndex, unitChan <-chan unitWithSender, @@ -249,7 +266,7 @@ func (p *Processor) createSubprocessor( subProcessor := newSubprocessor(&key, scheduler, localShardIndex) err := subProcessor.Run(ctx, unitChan) p.done <- messageKeyWithError{ - messageKey: messageKey, + messageKey: key, error: err, } }(ctxWithTimeout, *key, scheduler, localShardIndex, unitChan) @@ -265,8 +282,13 @@ func (p *Processor) subprocessorChannel( scheduler *Scheduler, ) (chan unitWithSender, error) { unitChan, ok := p.subProcessors[*key] - if !ok { - return p.createSubprocessor(ctx, key, scheduler) + if ok { + return unitChan, nil + } + + unitChan, err := p.createSubprocessor(ctx, key, scheduler) + if err != nil { + return nil, fmt.Errorf("creating new subprocessor: %w", err) } return unitChan, nil } diff --git a/consensus/propeller/scheduler.go b/consensus/propeller/scheduler.go index 48e9e4e02b..c7840c4a58 100644 --- a/consensus/propeller/scheduler.go +++ b/consensus/propeller/scheduler.go @@ -36,11 +36,11 @@ type PeerCommittee struct { // - Message is RECEIVED when 2*numDataShards shards are received (guarantees gossip property) // - Each peer broadcasts received shards to all other peers (full mesh) type Scheduler struct { - peerID peer.ID - 
peerIDIndex int - peers []PeerCommittee - numDataShards int - numCodingShards int + localPeerID peer.ID + localPeerIDIndex int + peers []PeerCommittee + numDataShards int + numCodingShards int } // NewScheduler creates a schedule from a list of peers. The peers are sorted @@ -60,7 +60,7 @@ func NewScheduler( } // todo(rdr): check with function is faster for sorting in our case: - // `slices.Sort` or `sort.Slice` + // `slices.Sort` or `sort.Slice`. Alternative: // sort.Slice(nodes, func(i, j int) bool { return nodes[i].ID < nodes[j].ID }) slices.SortFunc(nodes, func(i, j PeerCommittee) int { return cmp.Compare(i.ID, j.ID) }) @@ -90,17 +90,17 @@ func NewScheduler( numCodingShards := max(0, totalNodes-1-numDataShards) return &Scheduler{ - peerID: id, - peerIDIndex: idIndex, - peers: nodes, - numDataShards: numDataShards, - numCodingShards: numCodingShards, + localPeerID: id, + localPeerIDIndex: idIndex, + peers: nodes, + numDataShards: numDataShards, + numCodingShards: numCodingShards, }, nil } // PeerID returns the Scheduler Peer ID func (s *Scheduler) PeerID() peer.ID { - return s.peerID + return s.localPeerID } // Peers return the Scheduler list of nodes @@ -188,7 +188,7 @@ func (s *Scheduler) PeerForShardIndex( func (s *Scheduler) ShardIndexForPublisher( publisher peer.ID, ) (ShardIndex, error) { - if s.peerID == publisher { + if s.localPeerID == publisher { return 0, fmt.Errorf( "scheduler peer is the same as the publisher and has no assinged shard: %s", publisher, @@ -200,9 +200,9 @@ func (s *Scheduler) ShardIndexForPublisher( return 0, fmt.Errorf("couldn't locate shard index for publisher: %w", err) } - shardIdx := s.peerIDIndex - if s.peerIDIndex >= pubIdx { - shardIdx = s.peerIDIndex - 1 + shardIdx := s.localPeerIDIndex + if s.localPeerIDIndex >= pubIdx { + shardIdx = s.localPeerIDIndex - 1 } return ShardIndex(shardIdx), nil @@ -216,10 +216,10 @@ func (s *Scheduler) ValidateShardOrigin( publisher peer.ID, shardIndex ShardIndex, ) error { - if sender == 
s.peerID { + if sender == s.localPeerID { return fmt.Errorf("scheduler sent itself a shard: %s", sender) } - if publisher == s.peerID { + if publisher == s.localPeerID { return fmt.Errorf("scheduler broadcast itself a shard: %s", publisher) } @@ -233,7 +233,7 @@ func (s *Scheduler) ValidateShardOrigin( ) } - validDirectShard := expectedBroadcaster == s.peerID && sender == publisher + validDirectShard := expectedBroadcaster == s.localPeerID && sender == publisher if validDirectShard { return nil } @@ -257,7 +257,7 @@ func (s *Scheduler) BroadcastTargets() []peer.ID { targets := make([]peer.ID, s.NumTotalShards()-1) i := 0 for _, p := range s.peers { - if i == s.peerIDIndex { + if i == s.localPeerIDIndex { continue } targets[i] = p.ID diff --git a/consensus/propeller/sharding.go b/consensus/propeller/sharding.go index 08a1c411b2..3760fae5b2 100644 --- a/consensus/propeller/sharding.go +++ b/consensus/propeller/sharding.go @@ -34,7 +34,8 @@ func CreatePropellerUnits( merkleRoot, merkleTree := merkle.New(encodedMessage) messageRoot := MessageRoot(merkleRoot) - signature, err := SignRoot(messageRoot, privKey) + // todo(rdr): check that this signing is correct + signature, err := utils.SignRoot(messageRoot, privKey) if err != nil { return nil, err } @@ -46,7 +47,7 @@ func CreatePropellerUnits( units[i] = Unit{ CommitteeID: committeeID, Publisher: publisherID, - MerkleRoot: messageRoot, + MessageRoot: messageRoot, MerkleProof: merkleProof, Signature: signature, ShardIndex: ShardIndex(i), diff --git a/consensus/propeller/unit.go b/consensus/propeller/unit.go index 404702b82a..d700bd2e77 100644 --- a/consensus/propeller/unit.go +++ b/consensus/propeller/unit.go @@ -9,31 +9,35 @@ import ( "github.com/starknet-io/starknet-p2p-specs/p2p/proto/common" ) +type Signature []byte + +type Nonce time.Duration + // Unit is the atomic wire message: one erasure-coded shard plus // the metadata needed for independent verification. 
Each unit is self-contained // so a receiver can validate it without any other shards. type Unit struct { CommitteeID CommitteeID // Which committee this belongs to Publisher peer.ID // Original message author - MerkleRoot MessageRoot // Merkle root binding all shards together + MessageRoot MessageRoot // Merkle root binding all shards together MerkleProof merkle.Proof // Merkle inclusion proof for this shard - Signature []byte // Publisher's Ed25519 signature over the root + Signature Signature // Publisher's Ed25519 signature over the root ShardIndex ShardIndex // This shard's position in the erasure-coded output ShardData []byte // The actual data fragment // todo(rdr): calling it nonce because that's what is called on the rust side but // time stamp or some other name would be better - Nonce time.Duration // Strictly increasing number, starting from the Unix epoch + Nonce Nonce // Strictly increasing number, starting from the Unix epoch } func UnitFromProto(protoUnit *pb.PropellerUnit) Unit { return Unit{ CommitteeID: CommitteeID(protoUnit.Channel), // todo(rdr): this casting operations seem a bit risky, are they? 
- Publisher: peer.ID(protoUnit.Publisher.Id), - MerkleRoot: MessageRoot(protoUnit.MerkleRoot.Elements), - Signature: protoUnit.Signature, - ShardIndex: ShardIndex(protoUnit.Index), - ShardData: protoUnit.Shard, + Publisher: peer.ID(protoUnit.Publisher.Id), + MessageRoot: MessageRoot(protoUnit.MerkleRoot.Elements), + Signature: protoUnit.Signature, + ShardIndex: ShardIndex(protoUnit.Index), + ShardData: protoUnit.Shard, } } @@ -43,7 +47,7 @@ func (u *Unit) ToProto() *pb.PropellerUnit { siblings[i] = &common.Hash256{Elements: s[:]} } - root := merkle.Hash(u.MerkleRoot) + root := merkle.Hash(u.MessageRoot) return &pb.PropellerUnit{ Shard: u.ShardData, Index: uint64(u.ShardIndex), diff --git a/consensus/propeller/validator.go b/consensus/propeller/validator.go index e1ce60855c..d4b8f2a150 100644 --- a/consensus/propeller/validator.go +++ b/consensus/propeller/validator.go @@ -3,60 +3,15 @@ package propeller import ( "bytes" "fmt" - "time" "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/peer" ) -type verified struct { - verified bool - signature []byte - nonce time.Duration -} - -// Retuns if -func (v *verified) Verify(unit *Unit, pubKey crypto.PubKey) error { - if v.verified { - if bytes.Equal(v.signature, unit.Signature) && v.nonce == unit.Nonce { - return nil - } - // todo(rdr): add error information. Perhaps build an error type? - return fmt.Errorf("unit signature or nonce missmatch") - } - - err := verifyMessageIDSignature( - unit.CommitteeID, - unit.MerkleRoot, - unit.Signature, - unit.Nonce, - pubKey, - ) - if err != nil { - // add error information - return err - } - - *v = verified{ - verified: true, - // todo(rdr): by storing a field of unit.Signature am I forcing the whole `unit` to - // continue to exist on the heap, or the remaining fields can be cleaned. Probably the - // latter. 
- signature: unit.Signature, - nonce: unit.Nonce, - } - - return nil -} - -func verifyMessageIDSignature( - committeeID CommitteeID, - root MessageRoot, - signature []byte, - nonce time.Duration, - publisherPubKey crypto.PubKey, -) error { - panic("not yet implemented") +type verificationResult struct { + done bool + signature Signature + nonce Nonce } // Validates all the incoming units / shards given a committee and the publisher @@ -70,23 +25,88 @@ type Validator struct { // Once the validation is done it's stored here, subsequent runs // compare against it - verified verified + verification verificationResult // track of every shard index received receivedShards map[ShardIndex]struct{} } func NewValidator(key *messageKey, scheduler *Scheduler) Validator { + pubKey, err := key.Publisher.ExtractPublicKey() + // for now we are assuming that extracting a publisher key is always successful + // and done in constant time + if err != nil { + panic(err) + } return Validator{ committeeID: key.CommitteeID, publisher: key.Publisher, messageRoot: key.Root, - publisherPubKey: nil, // todo(rdr): nil for now, need to think how to pass this one + publisherPubKey: pubKey, // todo(rdr): nil for now, need to think how to pass this one scheduler: scheduler, receivedShards: make(map[ShardIndex]struct{}, scheduler.NumDataShards()), } } +func (v *Validator) verifyKeyFields(unit *Unit) error { + if unit.CommitteeID != v.committeeID { + return fmt.Errorf( + "different committe id. Expected: %s. Received: %s", unit.CommitteeID, v.committeeID, + ) + } + if unit.Publisher != v.publisher { + return fmt.Errorf( + "different publisher. Expected: %s. Received: %s", unit.Publisher, v.publisher, + ) + } + if unit.MessageRoot != v.messageRoot { + return fmt.Errorf( + "different message root. Expected: %s. 
Received: %s", unit.MessageRoot, v.messageRoot, + ) + } + + return nil +} + +func (v *Validator) verify(unit *Unit) error { + // todo(rdr): Here we are verifying everything but the data, do we verify the data + // at some point. What happens if a Peer has all the data correct except the shard data + // for example? Is that an attack vector? + // Something fails at some point and the publisher gets slashed? + if v.verification.done { + verificationMatch := bytes.Equal(v.verification.signature, unit.Signature) && + v.verification.nonce == unit.Nonce + if verificationMatch { + return nil + } + // todo(rdr): add error information. Perhaps build an error type? + return fmt.Errorf("unit signature or nonce missmatch") + } + + err := verifyMessageIDSignature( + unit.CommitteeID, + unit.MessageRoot, + unit.Signature, + unit.Nonce, + v.publisherPubKey, + ) + if err != nil { + // add error information + return err + } + + v.verification = verificationResult{ + done: true, + // todo(rdr): by storing a field of unit.Signature am I forcing the whole `unit` to + // continue to exist on the heap, or can the remaining fields be cleaned. Probably the + // latter. 
+ signature: unit.Signature, + nonce: unit.Nonce, + } + + return nil +} + func (v *Validator) ValidateUnit(unit *Unit, sender peer.ID) error { // todo(rdr): Do I need to check that comitteeID, publisher and messageRoot to be the // same as the one being hold by the validator, or is that infered from the signature being @@ -100,7 +120,7 @@ func (v *Validator) ValidateUnit(unit *Unit, sender peer.ID) error { if err != nil { } - if err = v.verified.Verify(unit, v.publisherPubKey); err != nil { + if err = v.verify(unit); err != nil { return nil } @@ -109,3 +129,13 @@ func (v *Validator) ValidateUnit(unit *Unit, sender peer.ID) error { return nil } + +func verifyMessageIDSignature( + committeeID CommitteeID, + root MessageRoot, + signature Signature, + nonce Nonce, + publisherPubKey crypto.PubKey, +) error { + panic("not yet implemented") +} From 66daa9f7cb9c1b9e802795e8ed41a05538505c2d Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Sun, 5 Apr 2026 13:03:47 +0100 Subject: [PATCH 14/40] chore: code healing for validator logic --- consensus/propeller/validator.go | 94 ++++++++++++-------------------- 1 file changed, 36 insertions(+), 58 deletions(-) diff --git a/consensus/propeller/validator.go b/consensus/propeller/validator.go index d4b8f2a150..27fd4a0d7f 100644 --- a/consensus/propeller/validator.go +++ b/consensus/propeller/validator.go @@ -8,29 +8,34 @@ import ( "github.com/libp2p/go-libp2p/core/peer" ) -type verificationResult struct { - done bool - signature Signature - nonce Nonce -} +// todo(rdr): A validator lifetime is attached to a `subprocessor`. A `subprocessor` is attached +// to a message key field. This logic is handled by a `Processor`. This means that a validator will // always be given units that have the same committeeID, publisher, messageRoot and Nonce (the +// current fields of a `messageKey`). Does it makes sense for the validator to also hold a copy +// of this. 
Is there a way of testing this invariant – where a validator only sees the same +// fields. I need to add a test for that invariant // Validates all the incoming units / shards given a committee and the publisher type Validator struct { // Required fields to perform the validation - committeeID CommitteeID - publisher peer.ID + // or not. Check if I can delete them + // committeeID CommitteeID + // publisher peer.ID + // messageRoot MessageRoot + // nonce Nonce + // ---------------------------------------- + publisherPubKey crypto.PubKey - messageRoot MessageRoot scheduler *Scheduler // Once the validation is done it's stored here, subsequent runs // compare against it - verification verificationResult + verifiedSignature Signature // track of every shard index received receivedShards map[ShardIndex]struct{} } +// todo(rdr): maybe just pass the publisher? func NewValidator(key *messageKey, scheduler *Scheduler) Validator { pubKey, err := key.Publisher.ExtractPublicKey() // for now we are assuming that extracting a publisher key is always successful @@ -39,48 +44,27 @@ func NewValidator(key *messageKey, scheduler *Scheduler) Validator { panic(err) } return Validator{ - committeeID: key.CommitteeID, - publisher: key.Publisher, - messageRoot: key.Root, - publisherPubKey: pubKey, // todo(rdr): nil for now, need to think how to pass this one + // committeeID: key.CommitteeID, + // publisher: key.Publisher, + // messageRoot: key.Root, + // nonce: key.Nonce, + publisherPubKey: pubKey, scheduler: scheduler, receivedShards: make(map[ShardIndex]struct{}, scheduler.NumDataShards()), } } -func (v *Validator) verifyKeyFields(unit *Unit) error { - if unit.CommitteeID != v.committeeID { - return fmt.Errorf( - "different committe id. Expected: %s. Received: %s", unit.CommitteeID, v.committeeID, - ) - } - if unit.Publisher != v.publisher { - return fmt.Errorf( - "different publisher. Expected: %s. 
Received: %s", unit.Publisher, v.publisher, - ) - } - if unit.MessageRoot != v.messageRoot { - return fmt.Errorf( - "different message root. Expected: %s. Received: %s", unit.MessageRoot, v.messageRoot, - ) - } - - return nil -} - func (v *Validator) verify(unit *Unit) error { - // todo(rdr): Here we are verifying everything but the data, do we verify the data - // at some point. What happens if a Peer has all the data correct except the shard data - // for example? Is that an attack vector? - // Something fails at some point and the publisher gets slashed? - if v.verification.done { - verificationMatch := bytes.Equal(v.verification.signature, unit.Signature) && - v.verification.nonce == unit.Nonce - if verificationMatch { + if v.verifiedSignature != nil { + if bytes.Equal(v.verifiedSignature, unit.Signature) { return nil } - // todo(rdr): add error information. Perhaps build an error type? - return fmt.Errorf("unit signature or nonce missmatch") + // todo(rdr): make sure this error is readable + return fmt.Errorf( + "signature missmatch. Expected: %v. Received %v", + v.verifiedSignature, + unit.Signature, + ) } err := verifyMessageIDSignature( @@ -95,33 +79,27 @@ func (v *Validator) verify(unit *Unit) error { return err } - v.verification = verificationResult{ - done: true, - // todo(rdr): by storing a field of unit.Signature am I forcing the whole `unit` to - // continue to exist on the heap, or can the remaining fields be cleaned. Probably the - // latter. - signature: unit.Signature, - nonce: unit.Nonce, - } - + // todo(rdr): by storing a field of unit.Signature am I forcing the whole `unit` to + // continue to exist on the heap, or can the remaining fields be cleaned. Probably the + // latter. 
+ v.verifiedSignature = unit.Signature return nil } func (v *Validator) ValidateUnit(unit *Unit, sender peer.ID) error { - // todo(rdr): Do I need to check that comitteeID, publisher and messageRoot to be the - // same as the one being hold by the validator, or is that infered from the signature being - // correct or wrong? - if _, ok := v.receivedShards[unit.ShardIndex]; ok { return fmt.Errorf("duplicated shard %d received", unit.ShardIndex) } - err := v.scheduler.ValidateShardOrigin(sender, v.publisher, unit.ShardIndex) + // We can use `unit.Publisher` because it is part of messageKey and hence + // this validator wouldn't be used otherwise + err := v.scheduler.ValidateShardOrigin(sender, unit.Publisher, unit.ShardIndex) if err != nil { + return err } if err = v.verify(unit); err != nil { - return nil + return err } // Cache the verified shard to avoid re-verification From 6e29701801cf293fc57a6dc7ddee4279d353ef18 Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Sun, 5 Apr 2026 13:38:13 +0100 Subject: [PATCH 15/40] chore: update propeller unit (including protobuf) to latest specs --- consensus/propeller/proto/propeller.pb.go | 189 +++++++++++++++++----- consensus/propeller/proto/propeller.proto | 29 +++- consensus/propeller/types.go | 2 +- consensus/propeller/unit.go | 64 ++++++-- 4 files changed, 224 insertions(+), 60 deletions(-) diff --git a/consensus/propeller/proto/propeller.pb.go b/consensus/propeller/proto/propeller.pb.go index 0a1fcf667b..fdf371adb1 100644 --- a/consensus/propeller/proto/propeller.pb.go +++ b/consensus/propeller/proto/propeller.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. 
// versions: // protoc-gen-go v1.36.9 -// protoc v7.34.0 +// protoc v7.34.1 // source: consensus/propeller/proto/propeller.proto package proto @@ -70,12 +70,104 @@ func (x *MerkleProof) GetSiblings() []*common.Hash256 { return nil } -// A single unit in the Propeller protocol containing a shard of erasure-coded data +// A single erasure-coded fragment of the original message. +type Shard struct { + state protoimpl.MessageState `protogen:"open.v1"` + Data []byte `protobuf:"bytes,1,opt,name=data,proto3" json:"data,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Shard) Reset() { + *x = Shard{} + mi := &file_consensus_propeller_proto_propeller_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Shard) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Shard) ProtoMessage() {} + +func (x *Shard) ProtoReflect() protoreflect.Message { + mi := &file_consensus_propeller_proto_propeller_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Shard.ProtoReflect.Descriptor instead. +func (*Shard) Descriptor() ([]byte, []int) { + return file_consensus_propeller_proto_propeller_proto_rawDescGZIP(), []int{1} +} + +func (x *Shard) GetData() []byte { + if x != nil { + return x.Data + } + return nil +} + +// A collection of shards assigned to a single peer. +// The proto-encoded bytes of this message are used as Merkle tree leaf data, +// ensuring cross-language determinism. 
+type ShardsOfPeer struct { + state protoimpl.MessageState `protogen:"open.v1"` + Shards []*Shard `protobuf:"bytes,1,rep,name=shards,proto3" json:"shards,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ShardsOfPeer) Reset() { + *x = ShardsOfPeer{} + mi := &file_consensus_propeller_proto_propeller_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ShardsOfPeer) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ShardsOfPeer) ProtoMessage() {} + +func (x *ShardsOfPeer) ProtoReflect() protoreflect.Message { + mi := &file_consensus_propeller_proto_propeller_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ShardsOfPeer.ProtoReflect.Descriptor instead. +func (*ShardsOfPeer) Descriptor() ([]byte, []int) { + return file_consensus_propeller_proto_propeller_proto_rawDescGZIP(), []int{2} +} + +func (x *ShardsOfPeer) GetShards() []*Shard { + if x != nil { + return x.Shards + } + return nil +} + +// A single unit in the Propeller protocol containing shards of erasure-coded data // along with cryptographic proofs for verification. type PropellerUnit struct { state protoimpl.MessageState `protogen:"open.v1"` - // The actual data shard (erasure-coded fragment of the original message). - Shard []byte `protobuf:"bytes,1,opt,name=shard,proto3" json:"shard,omitempty"` + // The shards assigned to this unit's peer. + Shards *ShardsOfPeer `protobuf:"bytes,1,opt,name=shards,proto3" json:"shards,omitempty"` // The position of this shard in the erasure coding scheme. Index uint64 `protobuf:"varint,2,opt,name=index,proto3" json:"index,omitempty"` // The Merkle root of all shards, used to verify shard integrity. 
@@ -86,17 +178,19 @@ type PropellerUnit struct { Publisher *common.PeerID `protobuf:"bytes,5,opt,name=publisher,proto3" json:"publisher,omitempty"` // Cryptographic signature from the publisher over the merkle_root. Signature []byte `protobuf:"bytes,6,opt,name=signature,proto3" json:"signature,omitempty"` - // TODO(AndrewL): consider re-naming channel - // TODO(AndrewL): make it uint64 instead of uint32. - // Logical channel identifier for multiplexing different message streams. - Channel uint32 `protobuf:"varint,7,opt,name=channel,proto3" json:"channel,omitempty"` // TODO(AndrewL): CRITICAL: protect against replay attacks (maybe using a timestamp) + // Committee identifier for multiplexing different message streams. + CommitteeId *common.Hash256 `protobuf:"bytes,7,opt,name=committee_id,json=committeeId,proto3" json:"committee_id,omitempty"` + // A strictly increasing number, to avoid replays. + // Current implementation is nanoseconds since UNIX_EPOCH. + // TODO(guyn): CRITICAL: protect against replay attacks using this timestamp + Nonce uint64 `protobuf:"varint,8,opt,name=nonce,proto3" json:"nonce,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } func (x *PropellerUnit) Reset() { *x = PropellerUnit{} - mi := &file_consensus_propeller_proto_propeller_proto_msgTypes[1] + mi := &file_consensus_propeller_proto_propeller_proto_msgTypes[3] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -108,7 +202,7 @@ func (x *PropellerUnit) String() string { func (*PropellerUnit) ProtoMessage() {} func (x *PropellerUnit) ProtoReflect() protoreflect.Message { - mi := &file_consensus_propeller_proto_propeller_proto_msgTypes[1] + mi := &file_consensus_propeller_proto_propeller_proto_msgTypes[3] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -121,12 +215,12 @@ func (x *PropellerUnit) ProtoReflect() protoreflect.Message { // Deprecated: Use 
PropellerUnit.ProtoReflect.Descriptor instead. func (*PropellerUnit) Descriptor() ([]byte, []int) { - return file_consensus_propeller_proto_propeller_proto_rawDescGZIP(), []int{1} + return file_consensus_propeller_proto_propeller_proto_rawDescGZIP(), []int{3} } -func (x *PropellerUnit) GetShard() []byte { +func (x *PropellerUnit) GetShards() *ShardsOfPeer { if x != nil { - return x.Shard + return x.Shards } return nil } @@ -166,9 +260,16 @@ func (x *PropellerUnit) GetSignature() []byte { return nil } -func (x *PropellerUnit) GetChannel() uint32 { +func (x *PropellerUnit) GetCommitteeId() *common.Hash256 { if x != nil { - return x.Channel + return x.CommitteeId + } + return nil +} + +func (x *PropellerUnit) GetNonce() uint64 { + if x != nil { + return x.Nonce } return 0 } @@ -183,7 +284,7 @@ type PropellerUnitBatch struct { func (x *PropellerUnitBatch) Reset() { *x = PropellerUnitBatch{} - mi := &file_consensus_propeller_proto_propeller_proto_msgTypes[2] + mi := &file_consensus_propeller_proto_propeller_proto_msgTypes[4] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -195,7 +296,7 @@ func (x *PropellerUnitBatch) String() string { func (*PropellerUnitBatch) ProtoMessage() {} func (x *PropellerUnitBatch) ProtoReflect() protoreflect.Message { - mi := &file_consensus_propeller_proto_propeller_proto_msgTypes[2] + mi := &file_consensus_propeller_proto_propeller_proto_msgTypes[4] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -208,7 +309,7 @@ func (x *PropellerUnitBatch) ProtoReflect() protoreflect.Message { // Deprecated: Use PropellerUnitBatch.ProtoReflect.Descriptor instead. 
func (*PropellerUnitBatch) Descriptor() ([]byte, []int) { - return file_consensus_propeller_proto_propeller_proto_rawDescGZIP(), []int{2} + return file_consensus_propeller_proto_propeller_proto_rawDescGZIP(), []int{4} } func (x *PropellerUnitBatch) GetBatch() []*PropellerUnit { @@ -224,16 +325,21 @@ const file_consensus_propeller_proto_propeller_proto_rawDesc = "" + "\n" + ")consensus/propeller/proto/propeller.proto\x1a\x16p2p/proto/common.proto\"3\n" + "\vMerkleProof\x12$\n" + - "\bsiblings\x18\x01 \x03(\v2\b.Hash256R\bsiblings\"\xf6\x01\n" + - "\rPropellerUnit\x12\x14\n" + - "\x05shard\x18\x01 \x01(\fR\x05shard\x12\x14\n" + + "\bsiblings\x18\x01 \x03(\v2\b.Hash256R\bsiblings\"\x1b\n" + + "\x05Shard\x12\x12\n" + + "\x04data\x18\x01 \x01(\fR\x04data\".\n" + + "\fShardsOfPeer\x12\x1e\n" + + "\x06shards\x18\x01 \x03(\v2\x06.ShardR\x06shards\"\xb0\x02\n" + + "\rPropellerUnit\x12%\n" + + "\x06shards\x18\x01 \x01(\v2\r.ShardsOfPeerR\x06shards\x12\x14\n" + "\x05index\x18\x02 \x01(\x04R\x05index\x12)\n" + "\vmerkle_root\x18\x03 \x01(\v2\b.Hash256R\n" + "merkleRoot\x12/\n" + "\fmerkle_proof\x18\x04 \x01(\v2\f.MerkleProofR\vmerkleProof\x12%\n" + "\tpublisher\x18\x05 \x01(\v2\a.PeerIDR\tpublisher\x12\x1c\n" + - "\tsignature\x18\x06 \x01(\fR\tsignature\x12\x18\n" + - "\achannel\x18\a \x01(\rR\achannel\":\n" + + "\tsignature\x18\x06 \x01(\fR\tsignature\x12+\n" + + "\fcommittee_id\x18\a \x01(\v2\b.Hash256R\vcommitteeId\x12\x14\n" + + "\x05nonce\x18\b \x01(\x04R\x05nonce\":\n" + "\x12PropellerUnitBatch\x12$\n" + "\x05batch\x18\x01 \x03(\v2\x0e.PropellerUnitR\x05batchB9Z7github.com/NethermindEth/juno/consensus/propeller/protob\x06proto3" @@ -249,25 +355,30 @@ func file_consensus_propeller_proto_propeller_proto_rawDescGZIP() []byte { return file_consensus_propeller_proto_propeller_proto_rawDescData } -var file_consensus_propeller_proto_propeller_proto_msgTypes = make([]protoimpl.MessageInfo, 3) +var file_consensus_propeller_proto_propeller_proto_msgTypes = 
make([]protoimpl.MessageInfo, 5) var file_consensus_propeller_proto_propeller_proto_goTypes = []any{ (*MerkleProof)(nil), // 0: MerkleProof - (*PropellerUnit)(nil), // 1: PropellerUnit - (*PropellerUnitBatch)(nil), // 2: PropellerUnitBatch - (*common.Hash256)(nil), // 3: Hash256 - (*common.PeerID)(nil), // 4: PeerID + (*Shard)(nil), // 1: Shard + (*ShardsOfPeer)(nil), // 2: ShardsOfPeer + (*PropellerUnit)(nil), // 3: PropellerUnit + (*PropellerUnitBatch)(nil), // 4: PropellerUnitBatch + (*common.Hash256)(nil), // 5: Hash256 + (*common.PeerID)(nil), // 6: PeerID } var file_consensus_propeller_proto_propeller_proto_depIdxs = []int32{ - 3, // 0: MerkleProof.siblings:type_name -> Hash256 - 3, // 1: PropellerUnit.merkle_root:type_name -> Hash256 - 0, // 2: PropellerUnit.merkle_proof:type_name -> MerkleProof - 4, // 3: PropellerUnit.publisher:type_name -> PeerID - 1, // 4: PropellerUnitBatch.batch:type_name -> PropellerUnit - 5, // [5:5] is the sub-list for method output_type - 5, // [5:5] is the sub-list for method input_type - 5, // [5:5] is the sub-list for extension type_name - 5, // [5:5] is the sub-list for extension extendee - 0, // [0:5] is the sub-list for field type_name + 5, // 0: MerkleProof.siblings:type_name -> Hash256 + 1, // 1: ShardsOfPeer.shards:type_name -> Shard + 2, // 2: PropellerUnit.shards:type_name -> ShardsOfPeer + 5, // 3: PropellerUnit.merkle_root:type_name -> Hash256 + 0, // 4: PropellerUnit.merkle_proof:type_name -> MerkleProof + 6, // 5: PropellerUnit.publisher:type_name -> PeerID + 5, // 6: PropellerUnit.committee_id:type_name -> Hash256 + 3, // 7: PropellerUnitBatch.batch:type_name -> PropellerUnit + 8, // [8:8] is the sub-list for method output_type + 8, // [8:8] is the sub-list for method input_type + 8, // [8:8] is the sub-list for extension type_name + 8, // [8:8] is the sub-list for extension extendee + 0, // [0:8] is the sub-list for field type_name } func init() { file_consensus_propeller_proto_propeller_proto_init() } @@ -281,7 
+392,7 @@ func file_consensus_propeller_proto_propeller_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_consensus_propeller_proto_propeller_proto_rawDesc), len(file_consensus_propeller_proto_propeller_proto_rawDesc)), NumEnums: 0, - NumMessages: 3, + NumMessages: 5, NumExtensions: 0, NumServices: 0, }, diff --git a/consensus/propeller/proto/propeller.proto b/consensus/propeller/proto/propeller.proto index 67787bdd3a..a45dc7ed19 100644 --- a/consensus/propeller/proto/propeller.proto +++ b/consensus/propeller/proto/propeller.proto @@ -12,11 +12,23 @@ message MerkleProof { repeated Hash256 siblings = 1; } -// A single unit in the Propeller protocol containing a shard of erasure-coded data +// A single erasure-coded fragment of the original message. +message Shard { + bytes data = 1; +} + +// A collection of shards assigned to a single peer. +// The proto-encoded bytes of this message are used as Merkle tree leaf data, +// ensuring cross-language determinism. +message ShardsOfPeer { + repeated Shard shards = 1; +} + +// A single unit in the Propeller protocol containing shards of erasure-coded data // along with cryptographic proofs for verification. message PropellerUnit { - // The actual data shard (erasure-coded fragment of the original message). - bytes shard = 1; + // The shards assigned to this unit's peer. + ShardsOfPeer shards = 1; // The position of this shard in the erasure coding scheme. uint64 index = 2; // The Merkle root of all shards, used to verify shard integrity. @@ -27,11 +39,12 @@ message PropellerUnit { PeerID publisher = 5; // Cryptographic signature from the publisher over the merkle_root. bytes signature = 6; - // TODO(AndrewL): consider re-naming channel - // TODO(AndrewL): make it uint64 instead of uint32. - // Logical channel identifier for multiplexing different message streams. 
- uint32 channel = 7; - // TODO(AndrewL): CRITICAL: protect against replay attacks (maybe using a timestamp) + // Committee identifier for multiplexing different message streams. + Hash256 committee_id = 7; + // A strictly increasing number, to avoid replays. + // Current implementation is nanoseconds since UNIX_EPOCH. + // TODO(guyn): CRITICAL: protect against replay attacks using this timestamp + uint64 nonce = 8; } // A batch of PropellerUnits for efficient transmission. diff --git a/consensus/propeller/types.go b/consensus/propeller/types.go index a120009fb6..ea6aadf912 100644 --- a/consensus/propeller/types.go +++ b/consensus/propeller/types.go @@ -18,7 +18,7 @@ import ( // CommitteeID identifies a committee or logical broadcast group. Multiple committees // can operate concurrently within the same engine, each with its own peer set. -type CommitteeID uint64 +type CommitteeID [4]uint64 // ShardIndex is the position of a shard within the erasure-coded output. // Valid range is [0, N-2] where N is the total number of peers. 
diff --git a/consensus/propeller/unit.go b/consensus/propeller/unit.go index d700bd2e77..4a5d483cdb 100644 --- a/consensus/propeller/unit.go +++ b/consensus/propeller/unit.go @@ -1,6 +1,7 @@ package propeller import ( + "encoding/binary" "time" "github.com/NethermindEth/juno/consensus/propeller/merkle" @@ -9,8 +10,16 @@ import ( "github.com/starknet-io/starknet-p2p-specs/p2p/proto/common" ) +// The actual shard fragmen +type Shard []byte + +// Holds the shard fragments carried by the Propeller Unit +type ShardData []Shard + +// Propeller Unit Signature type Signature []byte +// Propeller Unit Nonce type Nonce time.Duration // Unit is the atomic wire message: one erasure-coded shard plus @@ -23,25 +32,41 @@ type Unit struct { MerkleProof merkle.Proof // Merkle inclusion proof for this shard Signature Signature // Publisher's Ed25519 signature over the root ShardIndex ShardIndex // This shard's position in the erasure-coded output - ShardData []byte // The actual data fragment + ShardData ShardData // // todo(rdr): calling it nonce because that's what is called on the rust side but // time stamp or some other name would be better Nonce Nonce // Strictly increasing number, starting from the Unix epoch } func UnitFromProto(protoUnit *pb.PropellerUnit) Unit { + shards := make(ShardData, len(protoUnit.Shards.GetShards())) + for i, s := range protoUnit.Shards.GetShards() { + shards[i] = Shard(s.Data) + } + + siblings := make([]merkle.Hash, len(protoUnit.MerkleProof.GetSiblings())) + for i, s := range protoUnit.MerkleProof.GetSiblings() { + copy(siblings[i][:], s.Elements) + } + return Unit{ - CommitteeID: CommitteeID(protoUnit.Channel), - // todo(rdr): this casting operations seem a bit risky, are they? 
- Publisher: peer.ID(protoUnit.Publisher.Id), - MessageRoot: MessageRoot(protoUnit.MerkleRoot.Elements), + CommitteeID: committeeIDFromBytes(protoUnit.CommitteeId.GetElements()), + Publisher: peer.ID(protoUnit.Publisher.GetId()), + MessageRoot: MessageRoot(protoUnit.MerkleRoot.GetElements()), + MerkleProof: merkle.Proof{Siblings: siblings}, Signature: protoUnit.Signature, ShardIndex: ShardIndex(protoUnit.Index), - ShardData: protoUnit.Shard, + ShardData: shards, + Nonce: Nonce(time.Duration(protoUnit.Nonce)), } } func (u *Unit) ToProto() *pb.PropellerUnit { + protoShards := make([]*pb.Shard, len(u.ShardData)) + for i, s := range u.ShardData { + protoShards[i] = &pb.Shard{Data: s} + } + siblings := make([]*common.Hash256, len(u.MerkleProof.Siblings)) for i, s := range u.MerkleProof.Siblings { siblings[i] = &common.Hash256{Elements: s[:]} @@ -49,14 +74,29 @@ func (u *Unit) ToProto() *pb.PropellerUnit { root := merkle.Hash(u.MessageRoot) return &pb.PropellerUnit{ - Shard: u.ShardData, - Index: uint64(u.ShardIndex), - MerkleRoot: &common.Hash256{ - Elements: root[:], - }, + Shards: &pb.ShardsOfPeer{Shards: protoShards}, + Index: uint64(u.ShardIndex), + MerkleRoot: &common.Hash256{Elements: root[:]}, MerkleProof: &pb.MerkleProof{Siblings: siblings}, Publisher: &common.PeerID{Id: []byte(u.Publisher)}, Signature: u.Signature, - Channel: uint32(u.CommitteeID), + CommitteeId: &common.Hash256{Elements: committeeIDToBytes(u.CommitteeID)}, + Nonce: uint64(u.Nonce), + } +} + +func committeeIDFromBytes(b []byte) CommitteeID { + var id CommitteeID + for i := range 4 { + id[i] = binary.BigEndian.Uint64(b[i*8 : (i+1)*8]) + } + return id +} + +func committeeIDToBytes(id CommitteeID) []byte { + b := make([]byte, 32) + for i := range 4 { + binary.BigEndian.PutUint64(b[i*8:(i+1)*8], id[i]) } + return b } From f8712d6b8d461c41a69ad5651c569bce8b345fbc Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Sun, 5 Apr 2026 14:28:35 +0100 Subject: [PATCH 16/40] feat: implement data shard verification 
for validator --- consensus/propeller/merkle/merkle.go | 136 +++++++++++++-------------- consensus/propeller/processor.go | 18 +++- consensus/propeller/unit.go | 27 +++++- consensus/propeller/validator.go | 41 ++++++-- 4 files changed, 140 insertions(+), 82 deletions(-) diff --git a/consensus/propeller/merkle/merkle.go b/consensus/propeller/merkle/merkle.go index 65c542370a..1fecd859c7 100644 --- a/consensus/propeller/merkle/merkle.go +++ b/consensus/propeller/merkle/merkle.go @@ -14,59 +14,33 @@ type Proof struct { Siblings []Hash } -// Represents a Merkle Tree -type Tree []Proof - -// Merkle tree construction and verification using a specific SHA-256 tagging -// scheme. Tags prevent second-preimage attacks by domain-separating leaf -// hashes from internal node hashes. The exact tag format matches the Propeller -// protocol specification so that all implementations produce identical trees. +// VerifyProof checks that a leaf at the given index is included in a +// tree with the claimed root. The proof contains sibling hashes from the leaf +// level up to the root. // -// Tree layout: leaves are at the bottom, padded to the next power-of-two -// with the hash of empty data. The tree is built bottom-up by hashing pairs. +// The index determines the path through the tree: at each level, if the +// current bit of the index is 0 the current hash is the left child and the +// sibling is the right child, and vice versa. +func (p *Proof) Verify(root *Hash, leaf []byte, index uint32) bool { + current := merkleLeafHash(leaf) -// merkleLeafHash computes: SHA256("" || data || "") -// -// The XML-like tags are the domain separator specified by the Propeller -// protocol. They ensure a leaf hash can never collide with a node hash, -// even if an attacker controls the data. 
-func merkleLeafHash(data []byte) Hash { - h := sha256.New() - h.Write([]byte("")) - h.Write(data) - h.Write([]byte("")) - var out [32]byte - h.Sum(out[:0]) - return out -} + idx := index + for _, sibling := range p.Siblings { + if idx%2 == 0 { + // Current node is left child, sibling is right. + current = merkleNodeHash(current, sibling) + } else { + // Current node is right child, sibling is left. + current = merkleNodeHash(sibling, current) + } + idx /= 2 + } -// merkleNodeHash computes: -// -// SHA256("" || left || "" || right || "") -// -// The nested tags ensure node hashes are in a separate domain from leaf hashes. -func merkleNodeHash(left, right [32]byte) Hash { - h := sha256.New() - h.Write([]byte("")) - h.Write(left[:]) - h.Write([]byte("")) - h.Write(right[:]) - h.Write([]byte("")) - var out [32]byte - h.Sum(out[:0]) - return out + return current == *root } -// nextPowerOfTwo returns the smallest power of two >= n, with a minimum of 2. -// A minimum of 2 ensures even a single-leaf tree has a sibling for its proof. -func nextPowerOfTwo(n int) int { - if n <= 2 { - return 2 - } - // bits.Len returns the position of the highest set bit + 1. - // Subtracting 1 before Len handles exact powers-of-two correctly. - return 1 << bits.Len(uint(n-1)) -} +// Represents a Merkle Tree +type Tree []Proof // emptyLeafHash is the hash of a padding leaf (no data). We precompute it // because the same value is used repeatedly when the leaf count is not a @@ -142,27 +116,53 @@ func New(leaves [][]byte) (root Hash, tree Tree) { return root, tree } -// VerifyProof checks that a leaf at the given index is included in a -// tree with the claimed root. The proof contains sibling hashes from the leaf -// level up to the root. +// Merkle tree construction and verification using a specific SHA-256 tagging +// scheme. Tags prevent second-preimage attacks by domain-separating leaf +// hashes from internal node hashes. 
The exact tag format matches the Propeller +// protocol specification so that all implementations produce identical trees. // -// The index determines the path through the tree: at each level, if the -// current bit of the index is 0 the current hash is the left child and the -// sibling is the right child, and vice versa. -func VerifyProof(root Hash, leaf []byte, index uint32, proof Proof) bool { - current := merkleLeafHash(leaf) +// Tree layout: leaves are at the bottom, padded to the next power-of-two +// with the hash of empty data. The tree is built bottom-up by hashing pairs. - idx := index - for _, sibling := range proof.Siblings { - if idx%2 == 0 { - // Current node is left child, sibling is right. - current = merkleNodeHash(current, sibling) - } else { - // Current node is right child, sibling is left. - current = merkleNodeHash(sibling, current) - } - idx /= 2 - } +// merkleLeafHash computes: SHA256("" || data || "") +// +// The XML-like tags are the domain separator specified by the Propeller +// protocol. They ensure a leaf hash can never collide with a node hash, +// even if an attacker controls the data. +func merkleLeafHash(data []byte) Hash { + h := sha256.New() + h.Write([]byte("")) + h.Write(data) + h.Write([]byte("")) + var out [32]byte + h.Sum(out[:0]) + return out +} + +// merkleNodeHash computes: +// +// SHA256("" || left || "" || right || "") +// +// The nested tags ensure node hashes are in a separate domain from leaf hashes. +func merkleNodeHash(left, right [32]byte) Hash { + h := sha256.New() + h.Write([]byte("")) + h.Write(left[:]) + h.Write([]byte("")) + h.Write(right[:]) + h.Write([]byte("")) + var out [32]byte + h.Sum(out[:0]) + return out +} - return current == root +// nextPowerOfTwo returns the smallest power of two >= n, with a minimum of 2. +// A minimum of 2 ensures even a single-leaf tree has a sibling for its proof. 
+func nextPowerOfTwo(n int) int { + if n <= 2 { + return 2 + } + // bits.Len returns the position of the highest set bit + 1. + // Subtracting 1 before Len handles exact powers-of-two correctly. + return 1 << bits.Len(uint(n-1)) } diff --git a/consensus/propeller/processor.go b/consensus/propeller/processor.go index 2491c7edea..8132587b25 100644 --- a/consensus/propeller/processor.go +++ b/consensus/propeller/processor.go @@ -60,16 +60,23 @@ func (s *subprocessor) Run(ctx context.Context, unitChan <-chan unitWithSender) // can check for `context.DeadlineExceeded` return ctx.Err() case unitWithSender := <-unitChan: - // todo(rdr): validate that the unit is correct - // if the unit is incorrect penalize publisher (how?) + unit := unitWithSender.unit + sender := unitWithSender.sender + + err := s.validator.ValidateUnit(unit, sender) + if err != nil { + // do something with the error, logging it and + // sharing it with the main processor + continue + } - s.unitsReceived = append(s.unitsReceived, *unitWithSender.unit) + s.unitsReceived = append(s.unitsReceived, *unit) switch s.messageState { case preBuilt: // if the unit / shard is our own and we are pre-construction then we should // broadcast our own shard (only once) // todo(rdr): consider inlining this function? or use go naming ("once" in the name) - s.maybeBroacastLocalShard(unitWithSender.unit) + s.maybeBroacastLocalShard(unit) // todo(rdr): do something with a signature that I don't understand very well @@ -108,6 +115,9 @@ type messageKey struct { Nonce Nonce } +// todo(rdr): since message key is a subset of a unit, it should probably be constructed by +// receiving a unit as an argument!! 
+ func (mk *messageKey) String() string { return fmt.Sprintf("%+v", *mk) } diff --git a/consensus/propeller/unit.go b/consensus/propeller/unit.go index 4a5d483cdb..f8dab77cc3 100644 --- a/consensus/propeller/unit.go +++ b/consensus/propeller/unit.go @@ -2,12 +2,14 @@ package propeller import ( "encoding/binary" + "errors" "time" "github.com/NethermindEth/juno/consensus/propeller/merkle" pb "github.com/NethermindEth/juno/consensus/propeller/proto" "github.com/libp2p/go-libp2p/core/peer" "github.com/starknet-io/starknet-p2p-specs/p2p/proto/common" + "google.golang.org/protobuf/proto" ) // The actual shard fragmen @@ -16,6 +18,17 @@ type Shard []byte // Holds the shard fragments carried by the Propeller Unit type ShardData []Shard +func (sd ShardData) MarshalProto() []byte { + shards := make([]*pb.Shard, len(sd)) + for i, s := range sd { + shards[i] = &pb.Shard{Data: s} + } + // We ignore the error because this data has already been converted and it is expected + // to be correct. + res, _ := proto.Marshal(&pb.ShardsOfPeer{Shards: shards}) + return res +} + // Propeller Unit Signature type Signature []byte @@ -38,12 +51,22 @@ type Unit struct { Nonce Nonce // Strictly increasing number, starting from the Unix epoch } -func UnitFromProto(protoUnit *pb.PropellerUnit) Unit { +func UnitFromProto(protoUnit *pb.PropellerUnit) (Unit, error) { shards := make(ShardData, len(protoUnit.Shards.GetShards())) for i, s := range protoUnit.Shards.GetShards() { shards[i] = Shard(s.Data) } + // validate that all shard length is the same + // todo(rdr): What other validations should I do? + // todo(rdr): Should I do these validations here? 
+ shardLen := len(shards[0]) + for i := range shards[1:] { + if len(shards[i]) != shardLen { + return Unit{}, errors.New("unit has shards of different length") + } + } + siblings := make([]merkle.Hash, len(protoUnit.MerkleProof.GetSiblings())) for i, s := range protoUnit.MerkleProof.GetSiblings() { copy(siblings[i][:], s.Elements) @@ -58,7 +81,7 @@ func UnitFromProto(protoUnit *pb.PropellerUnit) Unit { ShardIndex: ShardIndex(protoUnit.Index), ShardData: shards, Nonce: Nonce(time.Duration(protoUnit.Nonce)), - } + }, nil } func (u *Unit) ToProto() *pb.PropellerUnit { diff --git a/consensus/propeller/validator.go b/consensus/propeller/validator.go index 27fd4a0d7f..12aeb7bea3 100644 --- a/consensus/propeller/validator.go +++ b/consensus/propeller/validator.go @@ -2,8 +2,10 @@ package propeller import ( "bytes" + "errors" "fmt" + "github.com/NethermindEth/juno/consensus/propeller/merkle" "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/peer" ) @@ -27,12 +29,11 @@ type Validator struct { publisherPubKey crypto.PubKey scheduler *Scheduler + // track of every shard index received + receivedShards map[ShardIndex]struct{} // Once the validation is done it's stored here, subsequent runs // compare against it verifiedSignature Signature - - // track of every shard index received - receivedShards map[ShardIndex]struct{} } // todo(rdr): maybe just pass the publisher? 
@@ -48,13 +49,33 @@ func NewValidator(key *messageKey, scheduler *Scheduler) Validator { // publisher: key.Publisher, // messageRoot: key.Root, // nonce: key.Nonce, - publisherPubKey: pubKey, - scheduler: scheduler, - receivedShards: make(map[ShardIndex]struct{}, scheduler.NumDataShards()), + publisherPubKey: pubKey, + scheduler: scheduler, + receivedShards: make(map[ShardIndex]struct{}, scheduler.NumDataShards()), + verifiedSignature: nil, + } +} + +func (v *Validator) verifyDataShards(unit *Unit) error { + if len(unit.ShardData) != 1 { + return fmt.Errorf( + "unexpected amount of shards. Expected %d. Received %d", + 1, + len(unit.ShardData), + ) + } + + proof := unit.MerkleProof + root := merkle.Hash(unit.MessageRoot) + // We marshal to Proto bytes to make the verification language agnostic + if proof.Verify(&root, unit.ShardData.MarshalProto(), uint32(unit.ShardIndex)) { + return nil } + + return errors.New("data shards verification failed") } -func (v *Validator) verify(unit *Unit) error { +func (v *Validator) verifySignature(unit *Unit) error { if v.verifiedSignature != nil { if bytes.Equal(v.verifiedSignature, unit.Signature) { return nil @@ -98,7 +119,11 @@ func (v *Validator) ValidateUnit(unit *Unit, sender peer.ID) error { return err } - if err = v.verify(unit); err != nil { + if err = v.verifyDataShards(unit); err != nil { + return err + } + + if err = v.verifySignature(unit); err != nil { return err } From 8c548acbefd2a36070240e7b794b331721668935 Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Mon, 6 Apr 2026 12:59:49 +0100 Subject: [PATCH 17/40] feat: complete unit validator logic --- .../propeller/deprecated_processor_test.go | 2 +- .../propeller/deprecated_validator_test.go | 4 +- consensus/propeller/engine.go | 15 ++-- consensus/propeller/engine_test.go | 2 +- consensus/propeller/processor.go | 26 +++---- consensus/propeller/propeller.go | 7 +- consensus/propeller/signing.go | 77 +++++++++++++++++++ .../propeller/{utils => }/signing_test.go | 0 
consensus/propeller/types.go | 2 +- consensus/propeller/unit.go | 11 +-- consensus/propeller/utils/signing.go | 33 -------- consensus/propeller/validator.go | 44 ++--------- 12 files changed, 121 insertions(+), 102 deletions(-) create mode 100644 consensus/propeller/signing.go rename consensus/propeller/{utils => }/signing_test.go (100%) delete mode 100644 consensus/propeller/utils/signing.go diff --git a/consensus/propeller/deprecated_processor_test.go b/consensus/propeller/deprecated_processor_test.go index b718984d38..a4b1e29400 100644 --- a/consensus/propeller/deprecated_processor_test.go +++ b/consensus/propeller/deprecated_processor_test.go @@ -81,7 +81,7 @@ func (env *processorTestEnv) encodeTestMessage( privKey, ok := env.privKeys[publisher] require.True(t, ok, "no private key for publisher %s", publisher) - sig, err := SignRoot(root, privKey) + sig, err := SignMessage(root, privKey) require.NoError(t, err) for i := range units { diff --git a/consensus/propeller/deprecated_validator_test.go b/consensus/propeller/deprecated_validator_test.go index 37b50ceb16..cd48adcfec 100644 --- a/consensus/propeller/deprecated_validator_test.go +++ b/consensus/propeller/deprecated_validator_test.go @@ -61,7 +61,7 @@ func makeValidUnit( units, root, err := EncodeMessage(msg, schedule, enc) require.NoError(t, err) - sig, err := SignRoot(root, publisherKey) + sig, err := SignMessage(root, publisherKey) require.NoError(t, err) unit := &units[shardIndex] @@ -326,7 +326,7 @@ func TestSignRoot_RoundTrip(t *testing.T) { privKey, peerID := realPeer(42) root := MessageRoot{0xaa, 0xbb, 0xcc} - sig, err := SignRoot(root, privKey) + sig, err := SignMessage(root, privKey) require.NoError(t, err) require.NotEmpty(t, sig) diff --git a/consensus/propeller/engine.go b/consensus/propeller/engine.go index e577f0cd3a..50849c9a2f 100644 --- a/consensus/propeller/engine.go +++ b/consensus/propeller/engine.go @@ -175,11 +175,13 @@ func (e *Engine) registerCommittee( peers []PeerCommittee, 
peersKeys []*StakerID, ) error { - // todo(rdr): Why re-registration should be ignored, as far as I know, it shouldn't happen :think: + // todo(rdr): Why re-registration should be ignored, + // as far as I understand, it shouldn't happen :think: if _, ok := e.committees[committeeID]; ok { e.log.Warn( "committee already registered, will ignore re-registration attempt", - zap.Uint64("committeeID", uint64(committeeID)), + // todo(rdr): give a propper string repr + zap.Any("committee id", committeeID), ) return nil } @@ -206,7 +208,8 @@ func (e *Engine) registerCommittee( } e.log.Info("registered new committee", - zap.Uint64("committeeID", uint64(committeeID)), + // todo(rdr): give a proper string representation + zap.Any("committeeID", committeeID), zap.Int("peers", len(peers)), zap.Int("dataShards", schedule.NumDataShards()), zap.Int("codingShards", schedule.NumCodingShards()), @@ -224,7 +227,8 @@ func (e *Engine) unregisterCommittee(committeeID CommitteeID) { // better to pass a context with cancelj e.log.Info("unregistered propeller committee", - zap.Uint64("committee", uint64(committeeID)), + // todo(rdr): give a proper string representation + zap.Any("committee id", committeeID), ) } @@ -285,7 +289,8 @@ func (e *Engine) processUnit(ctx context.Context, unit *Unit, sender peer.ID) { if !ok { // note(rdr): maybe debug? 
e.log.Warn("received key for unregistered committee, dropping", - zap.Uint64("committee id", uint64(unit.CommitteeID)), + // todo(rdr): give a propper string representation + zap.Any("committee id", unit.CommitteeID), ) return } diff --git a/consensus/propeller/engine_test.go b/consensus/propeller/engine_test.go index 1d4b57fdef..2e8f811267 100644 --- a/consensus/propeller/engine_test.go +++ b/consensus/propeller/engine_test.go @@ -181,7 +181,7 @@ func TestEngine_HandleUnit_CreatesProcessor(t *testing.T) { require.NoError(t, err) publisher := env.peers[1] - sig, err := SignRoot(root, env.privKeys[1]) + sig, err := SignMessage(root, env.privKeys[1]) require.NoError(t, err) for i := range units { diff --git a/consensus/propeller/processor.go b/consensus/propeller/processor.go index 8132587b25..ca02bfb4a5 100644 --- a/consensus/propeller/processor.go +++ b/consensus/propeller/processor.go @@ -104,10 +104,8 @@ func (s *subprocessor) maybeBroacastLocalShard(unit *Unit) { } } -// messageKey uniquely identifies a message within a committee. We track -// per-message state (processor, time cache) using this composite key -// because the same publisher could broadcast different messages (different -// roots) and we need to handle each independently. +// messageKey are a copy of the values of a propeller unit that uniquely identifies it +// all unit that carries shard of the same message will have the same "key" fields type messageKey struct { CommitteeID CommitteeID Publisher peer.ID @@ -115,8 +113,14 @@ type messageKey struct { Nonce Nonce } -// todo(rdr): since message key is a subset of a unit, it should probably be constructed by -// receiving a unit as an argument!! 
+func extractKey(unit *Unit) messageKey { + return messageKey{ + CommitteeID: unit.CommitteeID, + Publisher: unit.Publisher, + Root: unit.MessageRoot, + Nonce: unit.Nonce, + } +} func (mk *messageKey) String() string { return fmt.Sprintf("%+v", *mk) @@ -202,12 +206,7 @@ func (p *Processor) ProcessMessage( sender peer.ID, scheduler *Scheduler, ) error { - key := messageKey{ - CommitteeID: unit.CommitteeID, - Publisher: unit.Publisher, - Root: unit.MessageRoot, - Nonce: unit.Nonce, - } + key := extractKey(unit) if p.finalized.Contains(key) { return nil } @@ -234,7 +233,8 @@ func (p *Processor) ProcessMessage( return errors.New("dropping shard, processor channel full") } -// createSubprocessor creates a go-routine (subprocessor) that handles all the processing of `key`. +// createSubprocessor creates a go-routine (subprocessor) that handles all the processing of the +// messages identified with the given `messageKey`. // It returns a channel through which this processor can be given units to process // todo(rdr): I would like not to create a channel for everytime we have a different messageKey // since that can be a bit rough to the GC, better to have a pool of them. Benchmarks will give diff --git a/consensus/propeller/propeller.go b/consensus/propeller/propeller.go index a07cb315dd..24dfce294e 100644 --- a/consensus/propeller/propeller.go +++ b/consensus/propeller/propeller.go @@ -73,7 +73,12 @@ func (s *propellerService) receivePropellerUnits(stream network.Stream) { } for _, protoUnit := range batch.GetBatch() { - unit := UnitFromProto(protoUnit) + unit, err := UnitFromProto(protoUnit) + if err != nil { + s.log.Warn("received invalid unit", zap.Error(err)) + // todo(rdr): penalize sender? 
+ continue + } // send unit to engine s.cmdCh <- processUnit{ &unit, diff --git a/consensus/propeller/signing.go b/consensus/propeller/signing.go new file mode 100644 index 0000000000..71e5b93290 --- /dev/null +++ b/consensus/propeller/signing.go @@ -0,0 +1,77 @@ +package propeller + +import ( + "encoding/binary" + "errors" + "fmt" + + "github.com/libp2p/go-libp2p/core/crypto" +) + +const payloadLen = 95 + +// buildSignPayload constructs the byte sequence that the publisher signs. +// Does it in constant time without heap allocations +func buildSignPayload( + root *MessageRoot, committeeID *CommitteeID, nonce Nonce, +) [payloadLen]byte { + // The tags domain-separate propeller signatures from any other protocol + // that might use the same key, preventing cross-protocol signature reuse. + const prefix = "" + const suffix = "" + + // cummulative lenghts denoting the ranges in where each bytes of data should be stored + const prefixLen = len(prefix) + const rootLen = prefixLen + 32 + const committeeIDLen = rootLen + 32 + const nonceLen = committeeIDLen + 8 + const suffixLen = nonceLen + len(suffix) + + var payload [payloadLen]byte + + copy(payload[0:prefixLen], prefix) + copy(payload[prefixLen:rootLen], root[:]) + copy(payload[rootLen:committeeIDLen], committeeID[:]) + binary.BigEndian.PutUint64(payload[committeeIDLen:nonceLen], uint64(nonce)) + copy(payload[nonceLen:suffixLen], suffix) + + return payload +} + +func SignMessage( + privKey crypto.PrivKey, + root *MessageRoot, + committeeID *CommitteeID, + nonce Nonce, +) ([]byte, error) { + payload := buildSignPayload(root, committeeID, nonce) + sig, err := privKey.Sign(payload[:]) + if err != nil { + return nil, fmt.Errorf("signing message root: %w", err) + } + return sig, nil +} + +func VerifyMessageSignature( + pubKey crypto.PubKey, + root *MessageRoot, + committeeID *CommitteeID, + nonce Nonce, + signature Signature, +) error { + if len(signature) == 0 { + return errors.New("empty signature") + } + + payload := 
buildSignPayload(root, committeeID, nonce) + valid, err := pubKey.Verify(payload[:], signature) + if err != nil { + return fmt.Errorf("failed pub key verification: %w", err) + } + + if !valid { + return errors.New("signature is invalid") + } + + return nil +} diff --git a/consensus/propeller/utils/signing_test.go b/consensus/propeller/signing_test.go similarity index 100% rename from consensus/propeller/utils/signing_test.go rename to consensus/propeller/signing_test.go diff --git a/consensus/propeller/types.go b/consensus/propeller/types.go index ea6aadf912..b516b9c00f 100644 --- a/consensus/propeller/types.go +++ b/consensus/propeller/types.go @@ -18,7 +18,7 @@ import ( // CommitteeID identifies a committee or logical broadcast group. Multiple committees // can operate concurrently within the same engine, each with its own peer set. -type CommitteeID [4]uint64 +type CommitteeID [32]byte // ShardIndex is the position of a shard within the erasure-coded output. // Valid range is [0, N-2] where N is the total number of peers. 
diff --git a/consensus/propeller/unit.go b/consensus/propeller/unit.go index f8dab77cc3..9d14d02a60 100644 --- a/consensus/propeller/unit.go +++ b/consensus/propeller/unit.go @@ -1,7 +1,6 @@ package propeller import ( - "encoding/binary" "errors" "time" @@ -110,16 +109,10 @@ func (u *Unit) ToProto() *pb.PropellerUnit { func committeeIDFromBytes(b []byte) CommitteeID { var id CommitteeID - for i := range 4 { - id[i] = binary.BigEndian.Uint64(b[i*8 : (i+1)*8]) - } + copy(id[:], b) return id } func committeeIDToBytes(id CommitteeID) []byte { - b := make([]byte, 32) - for i := range 4 { - binary.BigEndian.PutUint64(b[i*8:(i+1)*8], id[i]) - } - return b + return id[:] } diff --git a/consensus/propeller/utils/signing.go b/consensus/propeller/utils/signing.go deleted file mode 100644 index 4d66eb5ee2..0000000000 --- a/consensus/propeller/utils/signing.go +++ /dev/null @@ -1,33 +0,0 @@ -package utils - -import ( - "fmt" - - "github.com/libp2p/go-libp2p/core/crypto" -) - -// SignPayload constructs the byte sequence that the publisher signs: -// -// "" || root[0:32] || "" -// -// The tags domain-separate propeller signatures from any other protocol -// that might use the same key, preventing cross-protocol signature reuse. -func SignPayload[T ~[32]byte](root T) []byte { - payload := make([]byte, 0, len("")+32+len("")) - payload = append(payload, []byte("")...) - payload = append(payload, root[:]...) - payload = append(payload, []byte("")...) - return payload -} - -// todo(rdr): verify this is correct -// SignRoot signs the Merkle root with the given private key, producing the -// signature that goes into every PropellerUnit for this message. 
-func SignRoot[T ~[32]byte](root T, privKey crypto.PrivKey) ([]byte, error) { - payload := SignPayload(root) - sig, err := privKey.Sign(payload) - if err != nil { - return nil, fmt.Errorf("signing message root: %w", err) - } - return sig, nil -} diff --git a/consensus/propeller/validator.go b/consensus/propeller/validator.go index 12aeb7bea3..066056d035 100644 --- a/consensus/propeller/validator.go +++ b/consensus/propeller/validator.go @@ -18,20 +18,12 @@ import ( // Validates all the incoming units / shards given a committee and the publisher type Validator struct { - // Required fields to perform the validation - // or not. Check if I can delete them - // committeeID CommitteeID - // publisher peer.ID - // messageRoot MessageRoot - // nonce Nonce - // ---------------------------------------- - publisherPubKey crypto.PubKey scheduler *Scheduler // track of every shard index received receivedShards map[ShardIndex]struct{} - // Once the validation is done it's stored here, subsequent runs + // Once the validation is done it's stored here, subsequent validation // compare against it verifiedSignature Signature } @@ -45,10 +37,6 @@ func NewValidator(key *messageKey, scheduler *Scheduler) Validator { panic(err) } return Validator{ - // committeeID: key.CommitteeID, - // publisher: key.Publisher, - // messageRoot: key.Root, - // nonce: key.Nonce, publisherPubKey: pubKey, scheduler: scheduler, receivedShards: make(map[ShardIndex]struct{}, scheduler.NumDataShards()), @@ -88,21 +76,17 @@ func (v *Validator) verifySignature(unit *Unit) error { ) } - err := verifyMessageIDSignature( - unit.CommitteeID, - unit.MessageRoot, - unit.Signature, - unit.Nonce, + err := VerifyMessageSignature( v.publisherPubKey, + &unit.MessageRoot, + &unit.CommitteeID, + unit.Nonce, + unit.Signature, ) if err != nil { - // add error information - return err + return fmt.Errorf("failed message signature verification: %w", err) } - // todo(rdr): by storing a field of unit.Signature am I forcing the 
whole `unit` to - // continue to exist on the heap, or can the remaining fields be cleaned. Probably the - // latter. v.verifiedSignature = unit.Signature return nil } @@ -112,8 +96,6 @@ func (v *Validator) ValidateUnit(unit *Unit, sender peer.ID) error { return fmt.Errorf("duplicated shard %d received", unit.ShardIndex) } - // We can use `unit.Publisher` because it is part of messageKey and hence - // this validator wouldn't be used otherwise err := v.scheduler.ValidateShardOrigin(sender, unit.Publisher, unit.ShardIndex) if err != nil { return err @@ -127,18 +109,8 @@ func (v *Validator) ValidateUnit(unit *Unit, sender peer.ID) error { return err } - // Cache the verified shard to avoid re-verification + // Store the verified shard to avoid re-verification v.receivedShards[unit.ShardIndex] = struct{}{} return nil } - -func verifyMessageIDSignature( - committeeID CommitteeID, - root MessageRoot, - signature Signature, - nonce Nonce, - publisherPubKey crypto.PubKey, -) error { - panic("not yet implemented") -} From 3982b8c34bde885cff6f6ecdbcd235376fae9bb1 Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Mon, 6 Apr 2026 18:01:04 +0100 Subject: [PATCH 18/40] refactor: the entier processor and subprocessor making it more legible --- consensus/propeller/processor.go | 219 +++++++++++++++++++++---------- consensus/propeller/validator.go | 4 +- 2 files changed, 150 insertions(+), 73 deletions(-) diff --git a/consensus/propeller/processor.go b/consensus/propeller/processor.go index ca02bfb4a5..c707243a0c 100644 --- a/consensus/propeller/processor.go +++ b/consensus/propeller/processor.go @@ -17,91 +17,140 @@ type unitWithSender struct { sender peer.ID } -type messageState uint64 - -const ( - preBuilt = iota - preReceived -) - -func (ms *messageState) NextState() { - *ms += 1 -} - type subprocessor struct { scheduler *Scheduler localShardIndex ShardIndex localShardWasBroadcast bool - validator Validator - messageState messageState - unitsReceived []Unit + unitsChan <-chan 
unitWithSender + invalidUnitsChan chan<- invalidUnit + + validator Validator } func newSubprocessor( - key *messageKey, scheduler *Scheduler, localShardIndex ShardIndex, + publisher peer.ID, + scheduler *Scheduler, + localShardIndex ShardIndex, + unitsChan <-chan unitWithSender, + invalidUnitsChan chan<- invalidUnit, ) subprocessor { return subprocessor{ - scheduler: scheduler, - localShardIndex: localShardIndex, - localShardWasBroadcast: false, + scheduler: scheduler, + localShardIndex: localShardIndex, + + unitsChan: unitsChan, + invalidUnitsChan: invalidUnitsChan, - validator: NewValidator(key, scheduler), - messageState: preBuilt, - unitsReceived: make([]Unit, 0, scheduler.ReceiveThreshold()), + validator: NewValidator(publisher, scheduler), } } -func (s *subprocessor) Run(ctx context.Context, unitChan <-chan unitWithSender) error { - for { +func (s *subprocessor) beforeMessageBuiltStage(ctx context.Context) ( + unitsReceived []*Unit, + unitCount int, + message []byte, + err error, +) { + // Keep track of the units received + unitsReceived = make([]*Unit, s.scheduler.ReceiveThreshold()) + unitCount = 0 + + localShardWasBroadcast := false + + buildThreshold := s.scheduler.BuildThreshold() + for unitCount != buildThreshold { select { case <-ctx.Done(): - // todo(rdr): need to differentiate between context cancellation and timeout. - // can check for `context.DeadlineExceeded` - return ctx.Err() - case unitWithSender := <-unitChan: + return + case unitWithSender := <-s.unitsChan: unit := unitWithSender.unit sender := unitWithSender.sender - err := s.validator.ValidateUnit(unit, sender) + err = s.validator.ValidateUnit(unit, sender) if err != nil { - // do something with the error, logging it and - // sharing it with the main processor + s.invalidUnitsChan <- invalidUnit{ + // todo(rdr): not sure if we need message key. 
+ // We just want to penalize the sender + messageKey: extractKey(unit), + sender: sender, + error: err, + } + // if this is the first unit we are receiving, finish abruptly since + // it can be a DOS attack. + if unitCount == 0 { + return + } continue } - s.unitsReceived = append(s.unitsReceived, *unit) - switch s.messageState { - case preBuilt: - // if the unit / shard is our own and we are pre-construction then we should - // broadcast our own shard (only once) - // todo(rdr): consider inlining this function? or use go naming ("once" in the name) - s.maybeBroacastLocalShard(unit) + unitsReceived[int(unit.ShardIndex)] = unit + unitCount += 1 - // todo(rdr): do something with a signature that I don't understand very well + // broadcast as soon as I get my shard + if localShardWasBroadcast && s.localShardIndex == unit.ShardIndex { + localShardWasBroadcast = true + // todo(rdr): actually broadcast shard index + } + } + } - if len(s.unitsReceived) == s.scheduler.BuildThreshold() { - s.messageState.NextState() - // trigger message rebuilding - can it be done in a non-blocking way? - // does it makes sense to do it in a non-blocking way? 
- } + // perform the build thing + panic("not implemented") +} - case preReceived: - if len(s.unitsReceived) == s.scheduler.ReceiveThreshold() { - // broadcast and finish execution – but don't broadcast the local shard +func (s *subprocessor) beforeMessageReceivedStage( + ctx context.Context, + unitsReceived []*Unit, + unitCount int, + message []byte, +) error { + receivedThreshold := s.scheduler.ReceiveThreshold() + for unitCount != receivedThreshold { + select { + case <-ctx.Done(): + return ctx.Err() + case unitWithSender := <-s.unitsChan: + unit := unitWithSender.unit + sender := unitWithSender.sender + if err := s.validator.ValidateUnit(unit, sender); err != nil { + s.invalidUnitsChan <- invalidUnit{ + messageKey: extractKey(unit), + sender: sender, + error: err, } + continue } + unitsReceived[int(unit.ShardIndex)] = unit + unitCount += 1 } } + + // do the actual job that requires doing once the receive threshold is reached + panic("not implemented") } -// todo(rdr): this can probably be inlined? -func (s *subprocessor) maybeBroacastLocalShard(unit *Unit) { - if !s.localShardWasBroadcast && s.localShardIndex == unit.ShardIndex { - // broadcast shard index - s.localShardWasBroadcast = true +// todo(rdr): we need to be sure to test both cases: +// - when built threshold == received threshold +// - when build threshold != received threshold +func (s *subprocessor) Run( + ctx context.Context, +) error { + // The Run function works in two main loops depending on the stage we are in. + // First stage is before we can build the message, where which we receive messsages + // until we have enough to build the full messsage. The local shard will be broadcasted + // during this stage. + // Second stage starts with the full message built and waits until we receive enough + // messages to reach the received threshold, which guarantees that at leasrt 2/3 of the + // network is non faulty. 
This stages broadcasts the whole message once finished + + unitsReceived, unitCount, message, err := s.beforeMessageBuiltStage(ctx) + if err != nil { + return err } + + return s.beforeMessageReceivedStage(ctx, unitsReceived, unitCount, message) } // messageKey are a copy of the values of a propeller unit that uniquely identifies it @@ -126,7 +175,17 @@ func (mk *messageKey) String() string { return fmt.Sprintf("%+v", *mk) } -type messageKeyWithError struct { +// invalidUnit is sent when a unit identified with `messageKey` failed validation with +// error `error` +type invalidUnit struct { + messageKey messageKey + sender peer.ID + error error +} + +// finalizedSubprocessor is sent once a subprocessor finalizes processing a message +// identified with `messageKey`. If it finalized on error the `error` field will be non-nil +type finalizedSubprocessor struct { messageKey messageKey error error } @@ -140,10 +199,12 @@ type concurrentTasksBounds struct { type Processor struct { // to avoid processing units already finalized finalized *TimeCache[messageKey] - // channel that communicates when a subprocessor has finished - done chan messageKeyWithError subProcessors map[messageKey]chan unitWithSender + // channel through wich subprocessors signal they have finalized execution + subProcessorsFinalized chan finalizedSubprocessor + // channel through which subprocessor sharedunits that failed validation + invalidUnits chan invalidUnit // track current open and closed tasks to avoid resource starvation mu sync.Mutex @@ -161,11 +222,14 @@ func NewProcessor(localPeer peer.ID, config *Config) *Processor { return &Processor{ finalized: NewTimeCache[messageKey](timeout), - done: make(chan messageKeyWithError), + subProcessors: make(map[messageKey]chan unitWithSender), + subProcessorsFinalized: make(chan finalizedSubprocessor), + invalidUnits: make(chan invalidUnit), + + mu: sync.Mutex{}, publisherTasks: make(map[peer.ID]uint64), tasks: 0, - subProcessors: make(map[messageKey]chan 
unitWithSender), localPeer: localPeer, timeout: timeout, @@ -183,16 +247,25 @@ func (p *Processor) Run(ctx context.Context) { select { case <-ctx.Done(): return - case finishedSubP := <-p.done: - if finishedSubP.error != nil { - p.log.Error( - "subprocessor error", - zap.String("message key", finishedSubP.messageKey.String()), - zap.Error(finishedSubP.error), + case finalizedSubP := <-p.subProcessorsFinalized: + if finalizedSubP.error != nil { + p.log.Error("subprocessor finalized with error", + zap.String("message key", finalizedSubP.messageKey.String()), + zap.Error(finalizedSubP.error), + ) + } else { + p.log.Info("subprocessor finalized", + zap.String("message key", finalizedSubP.messageKey.String()), ) } - p.decreaseTask(finishedSubP.messageKey.Publisher) - delete(p.subProcessors, finishedSubP.messageKey) + p.finalize(&finalizedSubP.messageKey) + + case invalidUnit := <-p.invalidUnits: + p.log.Error("unit validation failed", + zap.String("message key", invalidUnit.messageKey.String()), + zap.Error(invalidUnit.error), + ) + // todo(rdr): should we mark sender to penalize? } } @@ -265,7 +338,7 @@ func (p *Processor) createSubprocessor( // todo(rdr): passing to avoid closures. Does it makes sense? // need to learn more how closures work in Go if it makes any difference // in performance. - // todo(rdr): should I pass p.done as an argument? + // todo(rdr): should I pass p.chan as an argument? 
go func( ctx context.Context, key messageKey, @@ -273,12 +346,10 @@ func (p *Processor) createSubprocessor( localShardIndex ShardIndex, unitChan <-chan unitWithSender, ) { - subProcessor := newSubprocessor(&key, scheduler, localShardIndex) - err := subProcessor.Run(ctx, unitChan) - p.done <- messageKeyWithError{ - messageKey: key, - error: err, - } + subProcessor := newSubprocessor( + key.Publisher, scheduler, localShardIndex, unitChan, p.invalidUnits, + ) + subProcessor.Run(ctx) }(ctxWithTimeout, *key, scheduler, localShardIndex, unitChan) return unitChan, nil @@ -303,6 +374,12 @@ func (p *Processor) subprocessorChannel( return unitChan, nil } +func (p *Processor) finalize(key *messageKey) { + p.decreaseTask(key.Publisher) + delete(p.subProcessors, *key) + p.finalized.Add(*key) +} + func (p *Processor) increaseTasks(publisher peer.ID) error { p.mu.Lock() defer p.mu.Unlock() diff --git a/consensus/propeller/validator.go b/consensus/propeller/validator.go index 066056d035..81c13001fd 100644 --- a/consensus/propeller/validator.go +++ b/consensus/propeller/validator.go @@ -29,8 +29,8 @@ type Validator struct { } // todo(rdr): maybe just pass the publisher? 
-func NewValidator(key *messageKey, scheduler *Scheduler) Validator { - pubKey, err := key.Publisher.ExtractPublicKey() +func NewValidator(publisher peer.ID, scheduler *Scheduler) Validator { + pubKey, err := publisher.ExtractPublicKey() // for now we are assuming that extracting a publisher key is always successful // and done in constant time if err != nil { From d4dd031ae2dcfbe24bc2540effe2571a69d71dbf Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Thu, 9 Apr 2026 10:19:11 +0100 Subject: [PATCH 19/40] chore: remove deprecated validator and processor --- consensus/propeller/deprecated_processor.go | 335 ------------- .../propeller/deprecated_processor_test.go | 451 ------------------ consensus/propeller/deprecated_validator.go | 145 ------ .../propeller/deprecated_validator_test.go | 364 -------------- 4 files changed, 1295 deletions(-) delete mode 100644 consensus/propeller/deprecated_processor.go delete mode 100644 consensus/propeller/deprecated_processor_test.go delete mode 100644 consensus/propeller/deprecated_validator.go delete mode 100644 consensus/propeller/deprecated_validator_test.go diff --git a/consensus/propeller/deprecated_processor.go b/consensus/propeller/deprecated_processor.go deleted file mode 100644 index 8268bb71bb..0000000000 --- a/consensus/propeller/deprecated_processor.go +++ /dev/null @@ -1,335 +0,0 @@ -package propeller - -import ( - "context" - "fmt" - "time" - - "github.com/libp2p/go-libp2p/core/peer" -) - -// processorState tracks which phase of the message lifecycle a processor is in. -// The transitions are strictly one-directional: -// -// PreConstruction -> PostConstruction -> Finalised -// -// There is also a direct path from either state to Finalised via timeout. -type processorState int - -const ( - // statePreConstruction: collecting shards, waiting to reach the build - // threshold so we can reconstruct the original message. 
- statePreConstruction processorState = iota - - // statePostConstruction: message has been reconstructed. We continue - // counting incoming shards until we hit the receive threshold, which - // guarantees that enough honest nodes have our shard to ensure all - // other honest nodes can also reconstruct. - statePostConstruction - - // stateFinalised: terminal state. The processor emits a result event - // and stops accepting shards. The engine should clean up this processor. - stateFinalised -) - -// shardDelivery bundles an incoming shard with the peer that sent it, -// so the processor can validate the sender identity. -// todo(rdr): a better name for this -type shardDelivery struct { - Unit *Unit - Sender peer.ID -} - -// MessageProcessor manages the lifecycle of a single message identified by -// (channel, publisher, root). It runs as a goroutine that: -// -// 1. Accepts validated shards via its input channel. -// 2. In PreConstruction: collects shards until the build threshold is met, -// then reconstructs the message via Reed-Solomon. -// 3. In PostConstruction: counts additional shards until the receive -// threshold is met, then emits the message to the application. -// 4. On timeout: emits a timeout event and finalises. -// -// The processor is deliberately simple -- it owns no locks and communicates -// entirely through channels. All mutable state is confined to its goroutine. -type MessageProcessor struct { - // Identity - committeeID CommitteeID - publisher peer.ID - root MessageRoot - - // Config - timeout time.Duration - - // Internal State. - state processorState - shards [][]byte // indexed by ShardIndex, nil = not yet received - seenShards map[ShardIndex]struct{} - receivedCount int - signatureVerified bool - storedSignature []byte // cached from the first valid unit - reconstructedMsg []byte - myShardUnit *Unit // the unit we are responsible for forwarding - - // Channels. 
- shardCh chan shardDelivery // incoming shards from the engine - eventCh chan<- any // outgoing events to the engine/application -} - -// NewMessageProcessor creates a processor for a specific message. The caller -// must call Run() in a goroutine to start processing. -// -// Parameters: -// - shardCh: the engine writes incoming shards here. Buffered to prevent -// blocking the engine's main loop. -// - eventCh: the processor writes lifecycle events here (shared with other -// processors; the engine reads from it). -// - sendFn: callback for network delivery of units to peers. -func NewMessageProcessor( - channel CommitteeID, - publisher peer.ID, - root MessageRoot, - localPeer peer.ID, - config Config, - schedule *Scheduler, - validator *DeprecatedValidator, - encoder Encoder, - shardCh chan shardDelivery, - eventCh chan<- any, - sendFn SendUnitFunc, -) *MessageProcessor { - return &MessageProcessor{ - committeeID: channel, - publisher: publisher, - root: root, - localPeer: localPeer, - config: config, - schedule: schedule, - validator: validator, - encoder: encoder, - state: statePreConstruction, - shards: make([][]byte, schedule.NumShards()), - seenShards: make(map[ShardIndex]bool), - shardCh: shardCh, - eventCh: eventCh, - sendFn: sendFn, - } -} - -// Run is the processor's main loop. It blocks until the processor finalises -// (either by completing the protocol or timing out) or the context is cancelled. -// -// The select on shardCh vs timer is the core of the state machine. We -// intentionally use a single goroutine to avoid any need for synchronisation -// on the processor's internal state. 
-func (p *MessageProcessor) Run(ctx context.Context) error { - timer := time.NewTimer(p.timeout) - defer timer.Stop() - - for { - select { - case <-ctx.Done(): - return ctx.Err() - case <-timer.C: - if p.state != stateFinalised { - p.emitEvent(EventMessageTimeout{ - Channel: p.committeeID, - Publisher: p.publisher, - Root: p.root, - }) - p.state = stateFinalised - } - // throw an error processor is stopped after timeout? - return nil - - case delivery, ok := <-p.shardCh: - if !ok { - // Channel closed by engine; processor is being shot down. - return nil - } - if p.state == stateFinalised { - return nil - } - p.handleShard(ctx, delivery) - if p.state == stateFinalised { - return nil - } - - } - } -} - -// handleShard processes a single incoming shard delivery. -func (p *MessageProcessor) handleShard(ctx context.Context, delivery shardDelivery) { - unit := delivery.Unit - - // Validate the unit. - if err := p.validator.ValidateUnit( - unit, delivery.Sender, p.seenShards, p.signatureVerified, - ); err != nil { - p.emitEvent(EventShardValidationFailed{ - Sender: delivery.Sender, - ClaimedRoot: unit.MessageRoot, - ClaimedPublisher: unit.Publisher, - Err: err, - }) - return - } - - // Mark the shard as received and store its data. - p.seenShards[unit.ShardIndex] = true - p.shards[unit.ShardIndex] = unit.ShardData - p.receivedCount++ - - // Cache the signature from the first valid unit. All units for the same - // message carry the same publisher signature, so we only need one copy. - // We store it here rather than in the unit slice because we only keep - // shard data (not full units) to save memory. 
- if !p.signatureVerified { - p.storedSignature = make([]byte, len(unit.Signature)) - copy(p.storedSignature, unit.Signature) - } - p.signatureVerified = true - - switch p.state { - case statePreConstruction: - p.handlePreConstruction(ctx) - case statePostConstruction: - p.handlePostConstruction() - case stateFinalised: - // Should not reach here due to early return in Run, but be safe. - } -} - -// handlePreConstruction checks if we have enough shards to reconstruct. -func (p *MessageProcessor) handlePreConstruction(ctx context.Context) { - if p.receivedCount < p.schedule.BuildThreshold() { - return - } - - // Attempt Reed-Solomon reconstruction. - // We pass copies of the shard data because Reconstruct modifies the - // slice in-place, and we don't want to corrupt our stored references. - shardsCopy := make([][]byte, len(p.shards)) - for i, s := range p.shards { - if s != nil { - c := make([]byte, len(s)) - copy(c, s) - shardsCopy[i] = c - } - } - - msg, err := ReconstructMessage(shardsCopy, p.schedule, p.encoder, p.root) - if err != nil { - p.emitEvent(EventReconstructionFailed{ - Root: p.root, - Publisher: p.publisher, - Err: err, - }) - p.state = stateFinalised - return - } - - // Find our assigned shard so we can forward it to all other peers. - myShard, err := p.schedule.ShardForPeer(p.publisher, p.localPeer) - if err != nil { - p.emitEvent(EventReconstructionFailed{ - Root: p.root, - Publisher: p.publisher, - Err: fmt.Errorf("determining my shard assignment: %w", err), - }) - p.state = stateFinalised - return - } - - // Rebuild the Merkle tree from the complete shard set to get a valid - // proof for our shard. We may not have received our own shard from - // the network, so we need the proof from the reconstructed data. 
- leaves := make([][]byte, len(shardsCopy)) - copy(leaves, shardsCopy) - _, proofs := BuildMerkleTree(leaves) - - p.myShardUnit = &Unit{ - CommitteeID: p.committeeID, - Publisher: p.publisher, - MessageRoot: p.root, - Signature: p.storedSignature, - ShardIndex: myShard, - ShardData: shardsCopy[myShard], - MerkleProof: proofs[myShard], - } - - p.reconstructedMsg = msg - - // Replace our sparse shard data with the fully reconstructed set. - p.shards = shardsCopy - - // Count our own shard as held if we didn't receive it from the network. - if !p.seenShards[myShard] { - p.seenShards[myShard] = true - p.receivedCount++ - } - - p.state = statePostConstruction - - // Broadcast our shard to all other peers (except the publisher, who - // already has all shards). - p.broadcastMyShard(ctx) - - // Check if we already meet the receive threshold (possible if many - // shards arrived before reconstruction completed). - p.handlePostConstruction() -} - -// handlePostConstruction checks if the receive threshold has been met. -func (p *MessageProcessor) handlePostConstruction() { - if p.receivedCount < p.schedule.ReceiveThreshold() { - return - } - - p.emitEvent(EventMessageReceived{ - Publisher: p.publisher, - Root: p.root, - Message: p.reconstructedMsg, - }) - p.state = stateFinalised -} - -// broadcastMyShard sends our assigned shard to all peers except the publisher. -// Failures are reported as events but do not stop the broadcast to other peers. -func (p *MessageProcessor) broadcastMyShard(ctx context.Context) { - targets, err := p.schedule.BroadcastTargets(p.publisher) - if err != nil { - p.emitEvent(EventShardPublishFailed{ - Err: fmt.Errorf("getting broadcast targets: %w", err), - }) - return - } - - for _, target := range targets { - if target == p.localPeer { - // Don't send to ourselves. 
- continue - } - - if err := p.sendFn(ctx, target, p.myShardUnit); err != nil { - p.emitEvent(EventShardSendFailed{ - From: p.localPeer, - To: target, - Err: err, - }) - } - } -} - -// emitEvent sends an event to the application layer. Uses a non-blocking send -// so a slow consumer doesn't block the processor. The engine's event channel -// should be large enough that this rarely drops. -func (p *MessageProcessor) emitEvent(event any) { - select { - case p.eventCh <- event: - default: - // Event channel is full. This should be rare with a properly sized - // buffer. The event is lost, but the processor continues operating. - } -} diff --git a/consensus/propeller/deprecated_processor_test.go b/consensus/propeller/deprecated_processor_test.go deleted file mode 100644 index a4b1e29400..0000000000 --- a/consensus/propeller/deprecated_processor_test.go +++ /dev/null @@ -1,451 +0,0 @@ -package propeller - -import ( - "context" - "sync" - "testing" - "time" - - "github.com/libp2p/go-libp2p/core/crypto" - "github.com/libp2p/go-libp2p/core/peer" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// processorTestEnv bundles the common setup for processor tests. It creates -// a realistic N-peer environment with a real Reed-Solomon encoder. 
-type processorTestEnv struct { - peers []peer.ID - privKeys map[peer.ID]crypto.PrivKey - schedule *Scheduler - encoder Encoder - config Config - eventCh chan any - sentUnits []sentUnit - sentMu sync.Mutex -} - -type sentUnit struct { - To peer.ID - Unit *Unit -} - -func newProcessorTestEnv(t *testing.T, n int) *processorTestEnv { - t.Helper() - - rawPeers := make([]peer.ID, n) - privKeys := make(map[peer.ID]crypto.PrivKey, n) - for i := range n { - priv, id := realPeer(byte(i)) - rawPeers[i] = id - privKeys[id] = priv - } - - schedule := NewScheduler(rawPeers) - - enc, err := NewEncoder(schedule.NumDataShards(), schedule.NumCodingShards()) - require.NoError(t, err) - - return &processorTestEnv{ - peers: schedule.Peers(), // Use sorted order. - privKeys: privKeys, - schedule: schedule, - encoder: enc, - config: Config{ - StaleMessageTimeout: 5 * time.Second, - }, - eventCh: make(chan any, 100), - } -} - -// sendFunc records sent units for later inspection. -func (env *processorTestEnv) sendFunc() SendUnitFunc { - return func(_ context.Context, to peer.ID, unit *Unit) error { - env.sentMu.Lock() - defer env.sentMu.Unlock() - env.sentUnits = append(env.sentUnits, sentUnit{To: to, Unit: unit}) - return nil - } -} - -// encodeTestMessage encodes a message from the given publisher and returns -// the signed units and root. -func (env *processorTestEnv) encodeTestMessage( - t *testing.T, publisher peer.ID, msg []byte, -) ([]Unit, MessageRoot) { - t.Helper() - - units, root, err := EncodeMessage(msg, env.schedule, env.encoder) - require.NoError(t, err) - - privKey, ok := env.privKeys[publisher] - require.True(t, ok, "no private key for publisher %s", publisher) - - sig, err := SignMessage(root, privKey) - require.NoError(t, err) - - for i := range units { - units[i].Publisher = publisher - units[i].Signature = sig - units[i].CommitteeID = 1 - } - - return units, root -} - -// drainEvents reads all currently available events from the event channel. 
-func (env *processorTestEnv) drainEvents() []any { - var events []any - for { - select { - case ev := <-env.eventCh: - events = append(events, ev) - default: - return events - } - } -} - -func TestProcessor_FullLifecycle(t *testing.T) { - // Simulate a 7-node network. localPeer (sorted[0]) receives shards from - // a message published by sorted[1]. With 7 peers: 2 data shards, - // 4 coding shards. Build threshold = 2, receive threshold = 4. - env := newProcessorTestEnv(t, 7) - - localPeer := env.peers[0] - publisher := env.peers[1] - msg := []byte("hello propeller protocol") - - units, root := env.encodeTestMessage(t, publisher, msg) - - validator := NewValidator( - env.schedule, localPeer, &DefaultSignatureVerifier{}, - ) - - shardCh := make(chan shardDelivery, 20) - proc := NewMessageProcessor( - 1, publisher, root, localPeer, env.config, - env.schedule, validator, env.encoder, - shardCh, env.eventCh, env.sendFunc(), - ) - - ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) - defer cancel() - - // Run the processor in a goroutine. - done := make(chan struct{}) - go func() { - proc.Run(ctx) - close(done) - }() - - // Feed shards one at a time, from their correct senders. - // Skip shards "from" localPeer (the validator rejects self-sends). - for i, unit := range units { - sender, err := env.schedule.PeerForShard(publisher, ShardIndex(i)) - require.NoError(t, err) - - if sender == localPeer { - continue - } - - unitCopy := unit - shardCh <- shardDelivery{Unit: &unitCopy, Sender: sender} - } - - // Wait for the processor to finalise. - select { - case <-done: - case <-time.After(3 * time.Second): - t.Fatal("processor did not finalise in time") - } - - // Check that we got a MessageReceived event. 
- events := env.drainEvents() - var received *EventMessageReceived - for _, ev := range events { - if r, ok := ev.(EventMessageReceived); ok { - received = &r - break - } - } - require.NotNil(t, received, - "expected EventMessageReceived, got %d events", len(events)) - assert.Equal(t, msg, received.Message) - assert.Equal(t, publisher, received.Publisher) - assert.Equal(t, root, received.Root) - - // Check that our shard was broadcast to other peers. - env.sentMu.Lock() - defer env.sentMu.Unlock() - assert.NotEmpty(t, env.sentUnits, "should have broadcast our shard") -} - -func TestProcessor_ReconstructionFromMinimumShards(t *testing.T) { - // With 4 peers: 1 data shard, 2 coding shards. - // Build threshold = 1, receive threshold = 2 (N>3). - // After reconstruction, the processor counts its own shard (=1), - // so it needs at least 1 more from the network to reach 2. - env := newProcessorTestEnv(t, 4) - - localPeer := env.peers[0] - publisher := env.peers[1] - msg := []byte("minimum shards test") - - units, root := env.encodeTestMessage(t, publisher, msg) - - validator := NewValidator( - env.schedule, localPeer, &DefaultSignatureVerifier{}, - ) - - shardCh := make(chan shardDelivery, 10) - proc := NewMessageProcessor( - 1, publisher, root, localPeer, env.config, - env.schedule, validator, env.encoder, - shardCh, env.eventCh, env.sendFunc(), - ) - - ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) - defer cancel() - - done := make(chan struct{}) - go func() { - proc.Run(ctx) - close(done) - }() - - // Send all non-local shards. With 3 shards total and localPeer holding - // one slot, we have 2 shards to send. receive threshold = 2, and - // after reconstruction the processor holds its own shard (+1), so - // it needs 1 from the network + 1 own = 2 to finalise. 
- for i, unit := range units { - sender, err := env.schedule.PeerForShard(publisher, ShardIndex(i)) - require.NoError(t, err) - - if sender == localPeer { - continue - } - - unitCopy := unit - shardCh <- shardDelivery{Unit: &unitCopy, Sender: sender} - } - - select { - case <-done: - case <-time.After(3 * time.Second): - t.Fatal("processor did not finalise in time") - } - - events := env.drainEvents() - var received *EventMessageReceived - for _, ev := range events { - if r, ok := ev.(EventMessageReceived); ok { - received = &r - break - } - } - require.NotNil(t, received, "expected EventMessageReceived") - assert.Equal(t, msg, received.Message) -} - -func TestProcessor_Timeout(t *testing.T) { - env := newProcessorTestEnv(t, 4) - // Use a very short timeout for the test. - env.config.StaleMessageTimeout = 50 * time.Millisecond - - localPeer := env.peers[0] - publisher := env.peers[1] - - _, root := env.encodeTestMessage(t, publisher, []byte("will timeout")) - - validator := NewValidator( - env.schedule, localPeer, &DefaultSignatureVerifier{}, - ) - - shardCh := make(chan shardDelivery, 10) - proc := NewMessageProcessor( - 42, publisher, root, localPeer, env.config, - env.schedule, validator, env.encoder, - shardCh, env.eventCh, env.sendFunc(), - ) - - ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second) - defer cancel() - - done := make(chan struct{}) - go func() { - proc.Run(ctx) - close(done) - }() - - // Don't send any shards -- just wait for timeout. 
- select { - case <-done: - case <-time.After(2 * time.Second): - t.Fatal("processor did not finalise after timeout") - } - - events := env.drainEvents() - var timeout *EventMessageTimeout - for _, ev := range events { - if to, ok := ev.(EventMessageTimeout); ok { - timeout = &to - break - } - } - require.NotNil(t, timeout, - "expected EventMessageTimeout, got %d events", len(events)) - assert.Equal(t, CommitteeID(42), timeout.Channel) - assert.Equal(t, publisher, timeout.Publisher) - assert.Equal(t, root, timeout.Root) -} - -func TestProcessor_DuplicateShardRejected(t *testing.T) { - // Use 10 peers so the processor doesn't finalise after the first shard. - // N=10: numDataShards=3, receiveThreshold=6. - // Sending just one shard (receivedCount=1) is far below the threshold, - // so the processor stays in PreConstruction and will reject the duplicate. - env := newProcessorTestEnv(t, 10) - - localPeer := env.peers[0] - publisher := env.peers[1] - msg := []byte("test duplicates") - - units, root := env.encodeTestMessage(t, publisher, msg) - - validator := NewValidator( - env.schedule, localPeer, &DefaultSignatureVerifier{}, - ) - - shardCh := make(chan shardDelivery, 10) - proc := NewMessageProcessor( - 1, publisher, root, localPeer, env.config, - env.schedule, validator, env.encoder, - shardCh, env.eventCh, env.sendFunc(), - ) - - ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) - defer cancel() - - done := make(chan struct{}) - go func() { - proc.Run(ctx) - close(done) - }() - - // Find a shard not from localPeer. - var targetUnit Unit - var targetSender peer.ID - for i, unit := range units { - sender, err := env.schedule.PeerForShard(publisher, ShardIndex(i)) - require.NoError(t, err) - if sender != localPeer { - targetUnit = unit - targetSender = sender - break - } - } - - // Send the same shard twice, back-to-back. The processor handles them - // sequentially (single goroutine), so the second one will see the first - // already in seenShards. 
- shardCh <- shardDelivery{Unit: &targetUnit, Sender: targetSender} - - dup := targetUnit - shardCh <- shardDelivery{Unit: &dup, Sender: targetSender} - - // Give the processor time to handle both deliveries. - time.Sleep(200 * time.Millisecond) - cancel() - <-done - - // Check that we got a validation failure event for the duplicate. - events := env.drainEvents() - var validationFailed bool - for _, ev := range events { - if vf, ok := ev.(EventShardValidationFailed); ok { - var valErr *ShardValidationError - if asErr, ok := vf.Err.(*ShardValidationError); ok { - valErr = asErr - } - if valErr != nil && valErr.Reason == ReasonDuplicateShard { - validationFailed = true - break - } - } - } - assert.True(t, validationFailed, "expected duplicate shard to be rejected") -} - -func TestProcessor_ContextCancellation(t *testing.T) { - env := newProcessorTestEnv(t, 4) - - localPeer := env.peers[0] - publisher := env.peers[1] - - validator := NewValidator( - env.schedule, localPeer, &DefaultSignatureVerifier{}, - ) - - root := MessageRoot{0x01} - shardCh := make(chan shardDelivery, 10) - proc := NewMessageProcessor( - 1, publisher, root, localPeer, env.config, - env.schedule, validator, env.encoder, - shardCh, env.eventCh, env.sendFunc(), - ) - - ctx, cancel := context.WithCancel(t.Context()) - - done := make(chan struct{}) - go func() { - proc.Run(ctx) - close(done) - }() - - // Cancel immediately. - cancel() - - select { - case <-done: - // Good, processor exited. 
- case <-time.After(1 * time.Second): - t.Fatal("processor did not exit on context cancellation") - } -} - -func TestProcessor_ChannelClose(t *testing.T) { - env := newProcessorTestEnv(t, 4) - - localPeer := env.peers[0] - publisher := env.peers[1] - - validator := NewValidator( - env.schedule, localPeer, &DefaultSignatureVerifier{}, - ) - - root := MessageRoot{0x02} - shardCh := make(chan shardDelivery, 10) - proc := NewMessageProcessor( - 1, publisher, root, localPeer, env.config, - env.schedule, validator, env.encoder, - shardCh, env.eventCh, env.sendFunc(), - ) - - ctx := t.Context() - - done := make(chan struct{}) - go func() { - proc.Run(ctx) - close(done) - }() - - // Close the shard channel to signal teardown. - close(shardCh) - - select { - case <-done: - case <-time.After(1 * time.Second): - t.Fatal("processor did not exit on channel close") - } -} diff --git a/consensus/propeller/deprecated_validator.go b/consensus/propeller/deprecated_validator.go deleted file mode 100644 index ec0c390cae..0000000000 --- a/consensus/propeller/deprecated_validator.go +++ /dev/null @@ -1,145 +0,0 @@ -package propeller - -import ( - "fmt" - - "github.com/libp2p/go-libp2p/core/peer" -) - -// DeprecatedValidator checks incoming PropellerUnits for correctness. Each check -// serves a specific defensive purpose: -// -// - Self-sending check: prevents reflection attacks. -// - Self-published check: we already have all shards for our own messages. -// - Duplicate check: avoids redundant work and state corruption. -// - Origin check: ensures the sender is the peer assigned to this shard, -// preventing sybil-like relay attacks. -// - Merkle proof check: ensures the shard data is authentic (matches the -// committed tree root). -// - Signature check: ensures the publisher actually authored the message -// (the root they committed to). -// -// These checks are ordered from cheapest to most expensive so we reject -// invalid units as early as possible. 
-type DeprecatedValidator struct { - schedule *Scheduler - localPeer peer.ID - verifier SignatureVerifier -} - -// NewValidator creates a validator for the given channel configuration. -func NewDeprecatedValidator( - schedule *Scheduler, - localPeer peer.ID, - verifier SignatureVerifier, -) *DeprecatedValidator { - return &DeprecatedValidator{ - schedule: schedule, - localPeer: localPeer, - verifier: verifier, - } -} - -// ValidateUnit checks an incoming unit against all validation rules. -// -// Parameters: -// - unit: the incoming PropellerUnit to validate. -// - sender: the peer.ID of the network peer that sent us this unit. -// - seenShards: set of shard indices already received for this message -// (used for duplicate detection). -// - signatureVerified: true if we have already verified the publisher's -// signature for this Merkle root. Allows skipping the expensive crypto -// check after the first shard from the same message passes. -// -// Returns nil if valid, or a *ShardValidationError describing the failure. -func (v *DeprecatedValidator) ValidateUnit( - unit *Unit, - sender peer.ID, - seenShards map[ShardIndex]bool, - signatureVerified bool, -) error { - // 1. Reject units from ourselves (should never happen in normal - // operation; indicates a routing bug or reflection attack). - if sender == v.localPeer { - return &ShardValidationError{ - Reason: ReasonSelfSending, - Detail: "received unit from ourselves", - } - } - - // 2. Reject units for messages we published (we already have all - // shards and don't need them relayed back). - if unit.Publisher == v.localPeer { - return &ShardValidationError{ - Reason: ReasonReceivedSelfPublishedShard, - Detail: "received shard for a message we published", - } - } - - // 3. Reject duplicate shards. A well-behaved peer sends each shard - // exactly once; duplicates waste bandwidth and could corrupt state. 
- if seenShards[unit.ShardIndex] { - return &ShardValidationError{ - Reason: ReasonDuplicateShard, - Detail: fmt.Sprintf("already received shard %d", unit.ShardIndex), - } - } - - // 4. Verify the sender is either the peer assigned to broadcast this - // shard or the publisher itself (who initially distributes all shards). - // This prevents a Byzantine node from impersonating another peer's - // shard assignment while still allowing the publisher's initial send. - expectedPeer, err := v.schedule.PeerForShard(unit.Publisher, unit.ShardIndex) - if err != nil { - return &ShardValidationError{ - Reason: ReasonScheduleError, - Detail: fmt.Sprintf( - "looking up peer for shard %d: %v", unit.ShardIndex, err, - ), - } - } - if sender != expectedPeer && sender != unit.Publisher { - return &ShardValidationError{ - Reason: ReasonUnexpectedSender, - Detail: fmt.Sprintf( - "shard %d should come from %s or publisher %s, got %s", - unit.ShardIndex, expectedPeer, unit.Publisher, sender, - ), - } - } - - // 5. Verify the Merkle inclusion proof. This ensures the shard data - // is consistent with the tree root the publisher committed to. - if !VerifyMerkleProof( - unit.MessageRoot, unit.ShardData, uint32(unit.ShardIndex), unit.MerkleProof, - ) { - return &ShardValidationError{ - Reason: ReasonMerkleProofVerificationFailed, - Detail: fmt.Sprintf("merkle proof invalid for shard %d", unit.ShardIndex), - } - } - - // 6. Verify the publisher's signature over the root. This is the most - // expensive check (public-key crypto), so we skip it if we've already - // verified the same root from this publisher. 
- if !signatureVerified { - payload := SignPayload(unit.MessageRoot) - ok, err := v.verifier.Verify(unit.Publisher, payload, unit.Signature) - if err != nil { - return &ShardValidationError{ - Reason: ReasonSignatureVerificationFailed, - Detail: fmt.Sprintf( - "verifying signature: %v", err, - ), - } - } - if !ok { - return &ShardValidationError{ - Reason: ReasonSignatureVerificationFailed, - Detail: "signature does not match publisher's public key", - } - } - } - - return nil -} diff --git a/consensus/propeller/deprecated_validator_test.go b/consensus/propeller/deprecated_validator_test.go deleted file mode 100644 index cd48adcfec..0000000000 --- a/consensus/propeller/deprecated_validator_test.go +++ /dev/null @@ -1,364 +0,0 @@ -package propeller - -import ( - "bytes" - "crypto/ed25519" - "fmt" - "testing" - - "github.com/libp2p/go-libp2p/core/crypto" - "github.com/libp2p/go-libp2p/core/peer" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// mockVerifier is a test double for SignatureVerifier that returns -// configurable results. -type mockVerifier struct { - valid bool - err error -} - -func (m *mockVerifier) Verify(peer.ID, []byte, []byte) (bool, error) { - return m.valid, m.err -} - -// realPeer creates a real libp2p peer.ID from a deterministic Ed25519 seed. -// Real peer IDs are needed because DefaultSignatureVerifier extracts the -// public key from the peer ID, which only works for keys encoded into the ID. -func realPeer(seed byte) (crypto.PrivKey, peer.ID) { - seedBytes := make([]byte, ed25519.SeedSize) - seedBytes[0] = seed - reader := bytes.NewReader(seedBytes) - priv, pub, err := crypto.GenerateEd25519Key(reader) - if err != nil { - panic(err) - } - id, err := peer.IDFromPublicKey(pub) - if err != nil { - panic(err) - } - return priv, id -} - -// makeValidUnit creates a PropellerUnit that passes all validation checks -// for the given schedule, publisher, and shard index. 
The unit has a valid -// Merkle proof and signature. -func makeValidUnit( - t *testing.T, - schedule *Scheduler, - publisherKey crypto.PrivKey, - publisher peer.ID, - shardIndex ShardIndex, -) *Unit { - t.Helper() - - // Create a simple message and encode it. - msg := []byte("test message for validation") - enc, err := NewEncoder(schedule.NumDataShards(), schedule.NumCodingShards()) - require.NoError(t, err) - - units, root, err := EncodeMessage(msg, schedule, enc) - require.NoError(t, err) - - sig, err := SignMessage(root, publisherKey) - require.NoError(t, err) - - unit := &units[shardIndex] - unit.Publisher = publisher - unit.Signature = sig - unit.CommitteeID = 1 - return unit -} - -// validatorTestSetup creates a realistic N-peer environment and returns -// the schedule, a local peer (which is NOT the publisher), the publisher's -// key and ID, and a shard index that maps to a sender who is neither -// localPeer nor publisher. -type validatorTestSetup struct { - schedule *Scheduler - localPeer peer.ID - publisher peer.ID - publisherKey crypto.PrivKey - // shardIndex and expectedSender: a shard whose assigned sender is a - // third peer (neither localPeer nor publisher). - shardIndex ShardIndex - expectedSender peer.ID -} - -func newValidatorTestSetup(t *testing.T) validatorTestSetup { - t.Helper() - - // Create 5 peers so we have enough room to find a shard where the - // sender is a third party. - n := 5 - keys := make([]crypto.PrivKey, n) - ids := make([]peer.ID, n) - for i := range n { - keys[i], ids[i] = realPeer(byte(i)) - } - - schedule := NewScheduler(ids) - sorted := schedule.Peers() - - // Pick localPeer = sorted[0], publisher = sorted[1]. - localPeer := sorted[0] - publisher := sorted[1] - - // Find the publisher's private key. - var publisherKey crypto.PrivKey - for i, id := range ids { - if id == publisher { - publisherKey = keys[i] - break - } - } - require.NotNil(t, publisherKey) - - // Find a shard whose expected sender is NOT localPeer. 
- var shardIndex ShardIndex - var expectedSender peer.ID - found := false - for si := range schedule.NumShards() { - s, err := schedule.PeerForShard(publisher, ShardIndex(si)) - require.NoError(t, err) - if s != localPeer { - shardIndex = ShardIndex(si) - expectedSender = s - found = true - break - } - } - require.True(t, found, "could not find a shard with a third-party sender") - - return validatorTestSetup{ - schedule: schedule, - localPeer: localPeer, - publisher: publisher, - publisherKey: publisherKey, - shardIndex: shardIndex, - expectedSender: expectedSender, - } -} - -func TestValidator_HappyPath(t *testing.T) { - setup := newValidatorTestSetup(t) - v := NewValidator(setup.schedule, setup.localPeer, &DefaultSignatureVerifier{}) - - unit := makeValidUnit( - t, setup.schedule, setup.publisherKey, - setup.publisher, setup.shardIndex, - ) - - seenShards := make(map[ShardIndex]bool) - err := v.ValidateUnit(unit, setup.expectedSender, seenShards, false) - assert.NoError(t, err) -} - -func TestValidator_SelfSending(t *testing.T) { - setup := newValidatorTestSetup(t) - v := NewValidator(setup.schedule, setup.localPeer, &mockVerifier{valid: true}) - - unit := &Unit{Publisher: setup.publisher, ShardIndex: setup.shardIndex} - err := v.ValidateUnit(unit, setup.localPeer, nil, true) - - var valErr *ShardValidationError - require.ErrorAs(t, err, &valErr) - assert.Equal(t, ReasonSelfSending, valErr.Reason) -} - -func TestValidator_ReceivedSelfPublishedShard(t *testing.T) { - setup := newValidatorTestSetup(t) - v := NewValidator(setup.schedule, setup.localPeer, &mockVerifier{valid: true}) - - // Unit claims we are the publisher. 
- unit := &Unit{Publisher: setup.localPeer, ShardIndex: 0} - err := v.ValidateUnit(unit, setup.expectedSender, nil, true) - - var valErr *ShardValidationError - require.ErrorAs(t, err, &valErr) - assert.Equal(t, ReasonReceivedSelfPublishedShard, valErr.Reason) -} - -func TestValidator_DuplicateShard(t *testing.T) { - setup := newValidatorTestSetup(t) - v := NewValidator(setup.schedule, setup.localPeer, &mockVerifier{valid: true}) - - unit := &Unit{Publisher: setup.publisher, ShardIndex: setup.shardIndex} - seenShards := map[ShardIndex]bool{setup.shardIndex: true} - - err := v.ValidateUnit(unit, setup.expectedSender, seenShards, true) - var valErr *ShardValidationError - require.ErrorAs(t, err, &valErr) - assert.Equal(t, ReasonDuplicateShard, valErr.Reason) -} - -func TestValidator_UnexpectedSender(t *testing.T) { - setup := newValidatorTestSetup(t) - v := NewValidator(setup.schedule, setup.localPeer, &mockVerifier{valid: true}) - - // Find a peer that is NOT the expected sender, NOT localPeer, and NOT - // the publisher. The publisher is now an accepted sender for any shard, - // so it must be excluded from the "wrong sender" set. 
- var wrongSender peer.ID - for _, p := range setup.schedule.Peers() { - if p != setup.expectedSender && p != setup.localPeer && p != setup.publisher { - wrongSender = p - break - } - } - require.NotEmpty(t, wrongSender) - - unit := &Unit{Publisher: setup.publisher, ShardIndex: setup.shardIndex} - seenShards := make(map[ShardIndex]bool) - err := v.ValidateUnit(unit, wrongSender, seenShards, true) - - var valErr *ShardValidationError - require.ErrorAs(t, err, &valErr) - assert.Equal(t, ReasonUnexpectedSender, valErr.Reason) -} - -func TestValidator_PublisherAsAcceptedSender(t *testing.T) { - setup := newValidatorTestSetup(t) - v := NewValidator(setup.schedule, setup.localPeer, &DefaultSignatureVerifier{}) - - // The publisher initially distributes all shards, so it should be - // accepted as a sender for any shard -- even one assigned to another peer. - unit := makeValidUnit( - t, setup.schedule, setup.publisherKey, - setup.publisher, setup.shardIndex, - ) - - seenShards := make(map[ShardIndex]bool) - err := v.ValidateUnit(unit, setup.publisher, seenShards, false) - assert.NoError(t, err, "publisher should be accepted as sender for any shard") -} - -func TestValidator_MerkleProofFailed(t *testing.T) { - setup := newValidatorTestSetup(t) - v := NewValidator(setup.schedule, setup.localPeer, &mockVerifier{valid: true}) - - // Create a unit with a bad Merkle proof. 
- unit := &Unit{ - Publisher: setup.publisher, - ShardIndex: setup.shardIndex, - MessageRoot: MessageRoot{0x01}, - ShardData: []byte("data"), - MerkleProof: MerkleProof{Siblings: [][32]byte{{0xde, 0xad}}}, - } - - seenShards := make(map[ShardIndex]bool) - err := v.ValidateUnit(unit, setup.expectedSender, seenShards, true) - - var valErr *ShardValidationError - require.ErrorAs(t, err, &valErr) - assert.Equal(t, ReasonMerkleProofVerificationFailed, valErr.Reason) -} - -func TestValidator_SignatureVerificationFailed(t *testing.T) { - setup := newValidatorTestSetup(t) - - // Use a verifier that rejects signatures. - v := NewValidator(setup.schedule, setup.localPeer, &mockVerifier{valid: false}) - - unit := makeValidUnit( - t, setup.schedule, setup.publisherKey, - setup.publisher, setup.shardIndex, - ) - - seenShards := make(map[ShardIndex]bool) - err := v.ValidateUnit(unit, setup.expectedSender, seenShards, false) - - var valErr *ShardValidationError - require.ErrorAs(t, err, &valErr) - assert.Equal(t, ReasonSignatureVerificationFailed, valErr.Reason) -} - -func TestValidator_SignatureVerificationError(t *testing.T) { - setup := newValidatorTestSetup(t) - - // Use a verifier that returns an error. - v := NewValidator(setup.schedule, setup.localPeer, &mockVerifier{ - valid: false, - err: fmt.Errorf("key extraction failed"), - }) - - unit := makeValidUnit( - t, setup.schedule, setup.publisherKey, - setup.publisher, setup.shardIndex, - ) - - seenShards := make(map[ShardIndex]bool) - err := v.ValidateUnit(unit, setup.expectedSender, seenShards, false) - - var valErr *ShardValidationError - require.ErrorAs(t, err, &valErr) - assert.Equal(t, ReasonSignatureVerificationFailed, valErr.Reason) - assert.Contains(t, valErr.Detail, "key extraction failed") -} - -func TestValidator_SkipSignatureWhenAlreadyVerified(t *testing.T) { - setup := newValidatorTestSetup(t) - - // Verifier that would reject -- but we pass signatureVerified=true. 
- v := NewValidator(setup.schedule, setup.localPeer, &mockVerifier{valid: false}) - - unit := makeValidUnit( - t, setup.schedule, setup.publisherKey, - setup.publisher, setup.shardIndex, - ) - - seenShards := make(map[ShardIndex]bool) - err := v.ValidateUnit(unit, setup.expectedSender, seenShards, true) - assert.NoError(t, err, "should skip signature check when already verified") -} - -func TestSignPayload(t *testing.T) { - root := MessageRoot{0x01, 0x02, 0x03} - payload := SignPayload(root) - - expected := append([]byte(""), root[:]...) - expected = append(expected, []byte("")...) - assert.Equal(t, expected, payload) -} - -func TestSignRoot_RoundTrip(t *testing.T) { - privKey, peerID := realPeer(42) - - root := MessageRoot{0xaa, 0xbb, 0xcc} - sig, err := SignMessage(root, privKey) - require.NoError(t, err) - require.NotEmpty(t, sig) - - // Verify with the default verifier. - verifier := DefaultSignatureVerifier{} - payload := SignPayload(root) - ok, err := verifier.Verify(peerID, payload, sig) - require.NoError(t, err) - assert.True(t, ok) - - // Wrong root should fail. 
- wrongRoot := MessageRoot{0xff} - wrongPayload := SignPayload(wrongRoot) - ok, err = verifier.Verify(peerID, wrongPayload, sig) - require.NoError(t, err) - assert.False(t, ok) -} - -func TestValidator_ScheduleError(t *testing.T) { - setup := newValidatorTestSetup(t) - v := NewValidator(setup.schedule, setup.localPeer, &mockVerifier{valid: true}) - - _, unknownPeer := realPeer(99) - unit := &Unit{ - Publisher: unknownPeer, - ShardIndex: 0, - ShardData: []byte("data"), - MerkleProof: MerkleProof{}, - } - - err := v.ValidateUnit(unit, setup.expectedSender, make(map[ShardIndex]bool), true) - var valErr *ShardValidationError - require.ErrorAs(t, err, &valErr) - assert.Equal(t, ReasonScheduleError, valErr.Reason) -} From 5f97c57b620b08beb49a280c932976265759dbea Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Thu, 9 Apr 2026 10:19:41 +0100 Subject: [PATCH 20/40] chore: move padding out of utils and delete utils --- consensus/propeller/{utils => }/padding.go | 0 consensus/propeller/{utils => }/padding_test.go | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename consensus/propeller/{utils => }/padding.go (100%) rename consensus/propeller/{utils => }/padding_test.go (100%) diff --git a/consensus/propeller/utils/padding.go b/consensus/propeller/padding.go similarity index 100% rename from consensus/propeller/utils/padding.go rename to consensus/propeller/padding.go diff --git a/consensus/propeller/utils/padding_test.go b/consensus/propeller/padding_test.go similarity index 100% rename from consensus/propeller/utils/padding_test.go rename to consensus/propeller/padding_test.go From 9c730422bc16a361280393014cdc4d031db96695 Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Thu, 9 Apr 2026 10:20:28 +0100 Subject: [PATCH 21/40] chore: remove unused pool impl --- consensus/propeller/pool/pool.go | 26 -------------------------- consensus/propeller/pool/pool_test.go | 0 2 files changed, 26 deletions(-) delete mode 100644 consensus/propeller/pool/pool.go delete mode 100644 
consensus/propeller/pool/pool_test.go diff --git a/consensus/propeller/pool/pool.go b/consensus/propeller/pool/pool.go deleted file mode 100644 index d4282891db..0000000000 --- a/consensus/propeller/pool/pool.go +++ /dev/null @@ -1,26 +0,0 @@ -package pool - -import ( - "context" - "time" -) - -type Pool[T any] struct { - ctx context.Context - taskTimeout time.Duration - activeWorkers uint64 - maxWorkers uint64 -} - -func New[T any]( - ctx context.Context, - taskTimeout time.Duration, - maxWorkers uint64, -) *Pool[T] { - return &Pool[T]{ - ctx: ctx, - taskTimeout: taskTimeout, - activeWorkers: 0, - maxWorkers: maxWorkers, - } -} diff --git a/consensus/propeller/pool/pool_test.go b/consensus/propeller/pool/pool_test.go deleted file mode 100644 index e69de29bb2..0000000000 From 61e872dfe1f3d02fa080f19a1b2d08c97345e9f6 Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Fri, 10 Apr 2026 11:46:52 +0100 Subject: [PATCH 22/40] feat: add new (super fast) time cache --- consensus/propeller/timecache/timecache.go | 146 ++++++++++++++++ .../propeller/timecache/timecache_test.go | 165 ++++++++++++++++++ 2 files changed, 311 insertions(+) create mode 100644 consensus/propeller/timecache/timecache.go create mode 100644 consensus/propeller/timecache/timecache_test.go diff --git a/consensus/propeller/timecache/timecache.go b/consensus/propeller/timecache/timecache.go new file mode 100644 index 0000000000..924da4599e --- /dev/null +++ b/consensus/propeller/timecache/timecache.go @@ -0,0 +1,146 @@ +package timecache + +import ( + "sync" + "time" +) + +type index int + +type timedValue[K any] struct { + value K + expiry time.Time +} + +type TimeCache[K comparable] struct { + // Maps every stored key to its expiry time, giving access to valid keys in O(1) + values map[K]time.Time + // Entries in insertion order; expiring costs O(k) where `k` is the amount of expired keys + timestamps []timedValue[K] + mu sync.RWMutex + + // Track currently stored values: the oldest one is at `start` and + // `end` is the slot where the next one gets written.
`size` is the maximum amount of elements we can hold + start index + end index + size int + // used to delete any timed value which has been inserted more than + // `expiry` time.Duration ago. + expiry time.Duration +} + +// New allocates a new TimeCache with initial allocation size and expiry time. +// If `size` gets filled the timecache will allocate more memory to fit more +// elements into it. The cache will not shrink after regrowing. +func New[K comparable](size int, expiry time.Duration) *TimeCache[K] { + // we allocate size+1 because we always leave one position empty + // to detect when the cache is full + return &TimeCache[K]{ + values: make(map[K]time.Time, size+1), + timestamps: make([]timedValue[K], size+1), + mu: sync.RWMutex{}, + + start: 0, + end: 0, + size: size + 1, + expiry: expiry, + } +} + +// Add adds a new key into the timecache. It doesn't guard against duplicated +// entries. Adding the same entry twice will result in undefined behaviour. +func (tc *TimeCache[K]) Add(value *K) { + tc.mu.Lock() + defer tc.mu.Unlock() + + now := time.Now() + tc.removeExpired(now) + if tc.almostFull() { + tc.regrowth() + } + + expiryTime := now.Add(tc.expiry) + tc.values[*value] = expiryTime + tc.timestamps[tc.end] = timedValue[K]{ + value: *value, + expiry: expiryTime, + } + tc.increaseIndex(&tc.end) +} + +func (tc *TimeCache[K]) Get(value *K) bool { + tc.mu.RLock() + expiry, ok := tc.values[*value] + tc.mu.RUnlock() + + if !ok { + return false + } + + now := time.Now() + if expiry.After(now) { + return true + } + + // The key is present but already expired: take the write lock + // and evict the expired entries before reporting a miss + tc.mu.Lock() + tc.removeExpired(now) + tc.mu.Unlock() + return false +} + +func (tc *TimeCache[K]) increaseIndex(idx *index) { + *idx = (*idx + 1) % index(tc.size) +} + +// removeExpired deletes all the elements that have already expired until it +// finds the first one that hasn't or the cache empties +func (tc *TimeCache[K]) removeExpired(now time.Time) { + for
tc.start != tc.end { + tv := tc.timestamps[tc.start] + if now.Before(tv.expiry) { + break + } + + delete(tc.values, tv.value) + tc.increaseIndex(&tc.start) + } +} + +// almostFull reports whether the time cache will become full on the next insertion +func (tc *TimeCache[K]) almostFull() bool { + nextEnd := tc.end + tc.increaseIndex(&nextEnd) + return nextEnd == tc.start +} + +func (tc *TimeCache[K]) regrowth() { + const standardSize = 1024 + + nextSize := tc.size * 2 + if tc.size > standardSize { + // grow by 20% + nextSize = (tc.size * 12) / 10 + } + + nextTimestamps := make([]timedValue[K], nextSize) + + // Stored values are contiguous: this case only applies when start == 0 and end == size-1 + if tc.start < tc.end { + copy(nextTimestamps, tc.timestamps) + tc.size = nextSize + tc.timestamps = nextTimestamps + return + } + + count := tc.size - int(tc.start) + copy(nextTimestamps[0:count], tc.timestamps[tc.start:tc.size]) + nextEnd := count + int(tc.end) + copy(nextTimestamps[count:nextEnd], tc.timestamps[0:tc.end]) + + tc.start = 0 + tc.end = index(count) + tc.end + tc.size = nextSize + tc.timestamps = nextTimestamps +} diff --git a/consensus/propeller/timecache/timecache_test.go b/consensus/propeller/timecache/timecache_test.go new file mode 100644 index 0000000000..0f858f4fc9 --- /dev/null +++ b/consensus/propeller/timecache/timecache_test.go @@ -0,0 +1,165 @@ +package timecache_test + +import ( + "math/rand/v2" + "sync" + "testing" + "time" + + "github.com/NethermindEth/juno/consensus/propeller/timecache" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestTimeCacheSequentially(t *testing.T) { + t.Parallel() + t.Run("correctly expires keys", func(t *testing.T) { + t.Parallel() + + const expiry = 3 * time.Second + tc := timecache.New[int](3, expiry) + + key := 3 + tc.Add(&key) + + require.True(t, tc.Get(&key), "key should exists while it hasn't expired") + + time.Sleep(expiry + 1*time.Second) + + require.False(t, tc.Get(&key), "key shouldn't exist after
expiration window") + }) + + t.Run("correctly increases in size the cache's exceeds orignal size", func(t *testing.T) { + t.Parallel() + + const size = 3 + const expiry = 3 * time.Second + + // Fill the time cache with data up to its initial size + tc := timecache.New[int](size, expiry) + for i := range size { + tc.Add(&i) + } + for i := range size { + require.Truef(t, tc.Get(&i), "key %d should still exist", i) + } + + // Add more elements and check that old and new keys both exist in the cache + time.Sleep(time.Second) + for i := range size { + newKey := i + size + tc.Add(&newKey) + require.Truef(t, tc.Get(&i), "key %d should exist", i) + require.Truef(t, tc.Get(&newKey), "new key %d should also exist ", newKey) + } + + // Wait for old keys to expire + time.Sleep(2*time.Second + 200*time.Millisecond) + for i := range size { + newKey := i + size + require.Falsef(t, tc.Get(&i), "key %d shouldn't exist", i) + require.Truef(t, tc.Get(&newKey), "new key %d should exist ", newKey) + } + + // Add back the initial keys and check that both old and new keys are being held + for i := range size { + newKey := i + size + tc.Add(&i) + require.Truef(t, tc.Get(&i), "key %d should exist again", i) + require.Truef(t, tc.Get(&newKey), "new key %d should still exist ", newKey) + } + }) +} + +func TestTimeCacheConcurrently(t *testing.T) { + const size = 100 + const expiry = 1 * time.Second + + tc := timecache.New[int](size, expiry) + + // Go-routine A will keep adding elements at intervals for 3 seconds + // Go-routine B will check for the elements right after + // Go-routine C will check for the elements after expiry time + + fastCheckCh := make(chan int) + + type timedInt struct { + val int + time time.Time + } + slowCheckCh := make(chan timedInt) + + var wg sync.WaitGroup + + // Go-routine A + const sendPeriod = expiry * 3 + wg.Go(func() { + key := 0 + timeout := time.After(sendPeriod) + for { + select { + case <-timeout: + close(fastCheckCh) + close(slowCheckCh) + return + case
<-time.After(time.Duration(rand.IntN(250)+1) * time.Millisecond): + tc.Add(&key) + fastCheckCh <- key + slowCheckCh <- timedInt{val: key, time: time.Now()} + key += 1 + } + } + }) + + // Go-routine B + wg.Go(func() { + for v := range fastCheckCh { + assert.Truef(t, tc.Get(&v), "value %d should still exists", v) + } + }) + + // Go-routine C + wg.Go(func() { + valToReview := 0 + valsToCheck := make([]timedInt, 0, 100) + waitDuration := time.Now().Add(sendPeriod * 2) + + for { + select { + case timedInt, ok := <-slowCheckCh: + if !ok { + if valToReview == len(valsToCheck) { + // This shouldn't happen, as with the current config after sending + // is finished there should be more values to check + return + } + slowCheckCh = nil + continue + } + valsToCheck = append(valsToCheck, timedInt) + if valToReview == len(valsToCheck)-1 { + waitDuration = valsToCheck[valToReview].time.Add(expiry) + } + case <-time.After(time.Until(waitDuration)): + val := valsToCheck[valToReview] + assert.Falsef(t, tc.Get(&val.val), "value %d shouldn't exist", val.val) + + valToReview += 1 + if valToReview == len(valsToCheck) { + if slowCheckCh == nil { + // all values have been received and checked + return + } + // all values have been checked but there are still more to receive + // Set a long enough wait duration to avoid triggering this until + // a new element is received + waitDuration = time.Now().Add(sendPeriod * 2) + continue + } + waitDuration = valsToCheck[valToReview].time.Add(expiry) + } + } + }) + + wg.Wait() +} From db915959a8a6ffc5d0904a1273c6a8ce078a2c47 Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Fri, 10 Apr 2026 18:35:08 +0100 Subject: [PATCH 23/40] chore: add bench tests to timecache --- consensus/propeller/timecache/timecache.go | 2 +- .../timecache/timecache_bench_test.go | 131 ++++++++++++++++++ 2 files changed, 132 insertions(+), 1 deletion(-) create mode 100644 consensus/propeller/timecache/timecache_bench_test.go diff --git
a/consensus/propeller/timecache/timecache.go b/consensus/propeller/timecache/timecache.go index 924da4599e..890c74bf3b 100644 --- a/consensus/propeller/timecache/timecache.go +++ b/consensus/propeller/timecache/timecache.go @@ -98,7 +98,7 @@ func (tc *TimeCache[K]) increaseIndex(idx *index) { // finds the first one that hasn't or the cache empties func (tc *TimeCache[K]) removeExpired(now time.Time) { for tc.start != tc.end { - tv := tc.timestamps[tc.start] + tv := &tc.timestamps[tc.start] if now.Before(tv.expiry) { break } diff --git a/consensus/propeller/timecache/timecache_bench_test.go b/consensus/propeller/timecache/timecache_bench_test.go new file mode 100644 index 0000000000..f885127cc5 --- /dev/null +++ b/consensus/propeller/timecache/timecache_bench_test.go @@ -0,0 +1,131 @@ +package timecache_test + +import ( + "math/rand/v2" + "sync/atomic" + "testing" + "time" + + "github.com/NethermindEth/juno/consensus/propeller/timecache" +) + +func BenchmarkTimeCacheAdd(b *testing.B) { + b.Run("small cache size", func(b *testing.B) { + tc := timecache.New[int](100, 3*time.Second) + for i := range b.N { + tc.Add(&i) + } + }) + + b.Run("big cache size", func(b *testing.B) { + tc := timecache.New[int](5000, 3*time.Second) + for i := range b.N { + tc.Add(&i) + } + }) + + b.Run("big cache size with some expired values", func(b *testing.B) { + const size = 2000 + tc := timecache.New[int](size, 2*time.Second) + for i := range size - 1 { + tc.Add(&i) + time.Sleep(1 * time.Millisecond) + } + + // Let some values expire + time.Sleep(500 * time.Millisecond) + b.ResetTimer() + + // Add a lot of new ones + for i := range b.N { + tc.Add(&i) + } + }) + + b.Run("custom key", func(b *testing.B) { + // the same size as `messageKey` + type key [10]uint64 + + tc := timecache.New[key](5000, 3*time.Second) + for i := range b.N { + key := key{} + key[0] = uint64(i) + tc.Add(&key) + } + }) +} + +func BenchmarkTimeCacheGet(b *testing.B) { + b.Run("empty cache", func(b *testing.B) { + tc := 
timecache.New[int](100, 3*time.Second) + for i := range b.N { + tc.Get(&i) + } + }) + + b.Run("full unexpired cache", func(b *testing.B) { + const size = 10000 + tc := timecache.New[int](size, 1*time.Hour) + for i := range size { + tc.Add(&i) + } + b.ResetTimer() + + for i := range b.N { + key := i % size + tc.Get(&key) + } + }) + + b.Run("full with some expired values", func(b *testing.B) { + const size = 2000 + tc := timecache.New[int](size, 2*time.Second) + for i := range size / 2 { + tc.Add(&i) + time.Sleep(time.Millisecond) + } + time.Sleep(200 * time.Millisecond) + + b.ResetTimer() + + for i := range b.N { + key := i % size + tc.Get(&key) + } + }) + + b.Run("while additions are ocurring at the same time", func(b *testing.B) { + const size = 100 + tc := timecache.New[int64](size, 100*time.Millisecond) + + var insertedKeys atomic.Int64 + // random val != 0 + insertedKeys.Store(13) + + done := make(chan struct{}) + go func() { + var key int64 + for { + select { + case <-done: + return + case <-time.After(time.Duration((rand.IntN(10))+1) * time.Millisecond): + tc.Add(&key) + key += 1 + insertedKeys.Store(key) + } + } + }() + + b.ResetTimer() + for range b.N { + n := insertedKeys.Load() + + key := rand.Int64N(n) + tc.Get(&key) + } + + b.StopTimer() + close(done) + }) +} From 9808f3a8640a8bfc32b9981c6cb46c83371efd44 Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Fri, 10 Apr 2026 18:42:04 +0100 Subject: [PATCH 24/40] chore: final touches to timecache --- consensus/propeller/timecache/timecache.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/consensus/propeller/timecache/timecache.go b/consensus/propeller/timecache/timecache.go index 890c74bf3b..474027e48c 100644 --- a/consensus/propeller/timecache/timecache.go +++ b/consensus/propeller/timecache/timecache.go @@ -15,7 +15,7 @@ type timedValue[K any] struct { type TimeCache[K comparable] struct { // Access valid keys in O(1) values map[K]time.Time - // Expire in O(k) where `k` is the amount 
of expired keys + // Clean expired keys O(k) where `k` is the amount of expired keys timestamps []timedValue[K] mu sync.RWMutex @@ -31,7 +31,8 @@ type TimeCache[K comparable] struct { // New allocates a new Timecache with initial allocation size and expiry time. // If `size` gets filled the timecache will allocate more memory to fit more -// elements into it. The cache will not shrink after regrowing. +// elements into it. The cache will not shrink after regrowing. It is safe for +// concurrent use. func New[K comparable](size int, expiry time.Duration) *TimeCache[K] { // we allocate size+1 because we allways leave the last position empty // to detect when the cache is full @@ -68,6 +69,7 @@ func (tc *TimeCache[K]) Add(value *K) { tc.increaseIndex(&tc.end) } +// Get returns true if the entry exists and it hasn't expired, false otherwise func (tc *TimeCache[K]) Get(value *K) bool { tc.mu.RLock() expiry, ok := tc.values[*value] @@ -140,7 +142,7 @@ func (tc *TimeCache[K]) regrowth() { copy(nextTimestamps[count:nextEnd], tc.timestamps[0:tc.end]) tc.start = 0 - tc.end = index(count) + tc.end + tc.end = index(nextEnd) tc.size = nextSize tc.timestamps = nextTimestamps } From 72ff50246c28959d4a6bc6f7499066fccdb51f59 Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Sat, 11 Apr 2026 09:35:00 +0100 Subject: [PATCH 25/40] refactor: remove old timecache --- consensus/propeller/timecache.go | 70 --------------- consensus/propeller/timecache_test.go | 122 -------------------------- 2 files changed, 192 deletions(-) delete mode 100644 consensus/propeller/timecache.go delete mode 100644 consensus/propeller/timecache_test.go diff --git a/consensus/propeller/timecache.go b/consensus/propeller/timecache.go deleted file mode 100644 index cf229ed695..0000000000 --- a/consensus/propeller/timecache.go +++ /dev/null @@ -1,70 +0,0 @@ -package propeller - -import ( - "sync" - "time" -) - -// TimeCache is a set data structure where entries automatically expire after a -// configured TTL. 
It is used to remember which messages have been finalised so -// we can reject late-arriving shards without keeping state forever. -// -// The cache is safe for concurrent access. Expired entries are lazily removed: -// Contains() ignores expired entries, and Cleanup() bulk-removes them. This -// amortised approach avoids the overhead of per-entry timers. -type TimeCache[K comparable] struct { - mu sync.Mutex - entries map[K]time.Time - ttl time.Duration -} - -// NewTimeCache creates a cache where entries expire after the given TTL. -func NewTimeCache[K comparable](ttl time.Duration) *TimeCache[K] { - return &TimeCache[K]{ - entries: make(map[K]time.Time), - ttl: ttl, - } -} - -// Add inserts a key into the cache with an expiry of now + TTL. -// If the key already exists, its expiry is refreshed. -func (c *TimeCache[K]) Add(key K) { - c.mu.Lock() - defer c.mu.Unlock() - c.entries[key] = time.Now().Add(c.ttl) -} - -// Contains returns true if the key is present and has not expired. -// Expired keys are treated as absent but not removed -- call Cleanup() -// periodically to reclaim memory. -func (c *TimeCache[K]) Contains(key K) bool { - c.mu.Lock() - defer c.mu.Unlock() - expiry, ok := c.entries[key] - if !ok { - return false - } - return time.Now().Before(expiry) -} - -// Cleanup removes all expired entries from the cache. Call this periodically -// (e.g., every N operations or on a timer) to prevent unbounded growth from -// expired entries that are never looked up again. -func (c *TimeCache[K]) Cleanup() { - c.mu.Lock() - defer c.mu.Unlock() - now := time.Now() - for k, expiry := range c.entries { - if !now.Before(expiry) { - delete(c.entries, k) - } - } -} - -// Len returns the total number of entries including expired ones that have -// not yet been cleaned up. Useful for testing and monitoring. 
-func (c *TimeCache[K]) Len() int { - c.mu.Lock() - defer c.mu.Unlock() - return len(c.entries) -} diff --git a/consensus/propeller/timecache_test.go b/consensus/propeller/timecache_test.go deleted file mode 100644 index 25a12b4a5b..0000000000 --- a/consensus/propeller/timecache_test.go +++ /dev/null @@ -1,122 +0,0 @@ -package propeller - -import ( - "sync" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestTimeCache_AddAndContains(t *testing.T) { - cache := NewTimeCache[string](10 * time.Second) - - assert.False(t, cache.Contains("a"), "empty cache should not contain any key") - - cache.Add("a") - assert.True(t, cache.Contains("a"), "key should be present after Add") - assert.False(t, cache.Contains("b"), "unrelated key should not be present") -} - -func TestTimeCache_Expiration(t *testing.T) { - // Use a controllable clock so we don't need real sleeps. - now := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC) - cache := NewTimeCache[string](5 * time.Second) - cache.nowFn = func() time.Time { return now } - - cache.Add("x") - assert.True(t, cache.Contains("x")) - - // Advance time to just before expiry. - now = now.Add(4 * time.Second) - assert.True(t, cache.Contains("x"), "should still be present before TTL") - - // Advance time to exactly the expiry moment. - now = now.Add(1 * time.Second) - assert.False(t, cache.Contains("x"), "should be expired at TTL boundary") - - // Advance well past expiry. - now = now.Add(10 * time.Second) - assert.False(t, cache.Contains("x"), "should be expired well after TTL") -} - -func TestTimeCache_RefreshExpiry(t *testing.T) { - now := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC) - cache := NewTimeCache[string](5 * time.Second) - cache.nowFn = func() time.Time { return now } - - cache.Add("k") - - // Advance 3 seconds, then re-add to refresh. 
- now = now.Add(3 * time.Second) - cache.Add("k") - - // Advance another 3 seconds -- would be expired without refresh (6s > 5s), - // but the refresh pushed the deadline to 3s+5s=8s. - now = now.Add(3 * time.Second) - assert.True(t, cache.Contains("k"), "re-add should refresh the TTL") - - // Advance past the refreshed expiry. - now = now.Add(3 * time.Second) - assert.False(t, cache.Contains("k"), "should expire after refreshed TTL") -} - -func TestTimeCache_Cleanup(t *testing.T) { - now := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC) - cache := NewTimeCache[int](2 * time.Second) - cache.nowFn = func() time.Time { return now } - - for i := range 5 { - cache.Add(i) - } - require.Equal(t, 5, cache.Len()) - - // Expire all entries. - now = now.Add(3 * time.Second) - - // They're expired but still in the map until Cleanup. - assert.Equal(t, 5, cache.Len(), "expired entries linger until Cleanup") - assert.False(t, cache.Contains(0), "expired entries should not be found") - - cache.Cleanup() - assert.Equal(t, 0, cache.Len(), "Cleanup should remove all expired entries") -} - -func TestTimeCache_CleanupPartial(t *testing.T) { - now := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC) - cache := NewTimeCache[string](5 * time.Second) - cache.nowFn = func() time.Time { return now } - - cache.Add("early") - now = now.Add(3 * time.Second) - cache.Add("late") - - // Advance so "early" is expired but "late" is not. - now = now.Add(3 * time.Second) - - cache.Cleanup() - assert.Equal(t, 1, cache.Len(), "only the expired entry should be removed") - assert.False(t, cache.Contains("early")) - assert.True(t, cache.Contains("late")) -} - -func TestTimeCache_ConcurrentAccess(t *testing.T) { - cache := NewTimeCache[int](1 * time.Second) - - var wg sync.WaitGroup - // Hammer the cache from multiple goroutines to verify no races. 
- for i := range 100 { - wg.Add(1) - go func(v int) { - defer wg.Done() - cache.Add(v) - cache.Contains(v) - cache.Cleanup() - }(i) - } - wg.Wait() - - // We just care that it didn't panic or race. - assert.LessOrEqual(t, cache.Len(), 100) -} From 8a9584fd02a19e4a242521cfc8f851cf5f26711e Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Sat, 11 Apr 2026 12:04:22 +0100 Subject: [PATCH 26/40] test: the reedsolomon pkg --- .../propeller/reedsolomon/reedsolomon.go | 2 - .../propeller/reedsolomon/reedsolomon_test.go | 281 ++++++++++++++++++ 2 files changed, 281 insertions(+), 2 deletions(-) diff --git a/consensus/propeller/reedsolomon/reedsolomon.go b/consensus/propeller/reedsolomon/reedsolomon.go index 56991aba26..0ed3073649 100644 --- a/consensus/propeller/reedsolomon/reedsolomon.go +++ b/consensus/propeller/reedsolomon/reedsolomon.go @@ -12,8 +12,6 @@ import ( // It will return the Reed Solomon encoding where the first `numDataShards` // `[]byte` slices will be occupied by the original data. The remaining `parity` // `[]byte` slices will contain the coding shards. -// The data will be modified in place so the input shouldn't be modified after calling this -// function. 
func EncodeData( data []byte, numDataShards, diff --git a/consensus/propeller/reedsolomon/reedsolomon_test.go b/consensus/propeller/reedsolomon/reedsolomon_test.go index c37f6836a0..dc1dc27050 100644 --- a/consensus/propeller/reedsolomon/reedsolomon_test.go +++ b/consensus/propeller/reedsolomon/reedsolomon_test.go @@ -1 +1,282 @@ package reedsolomon_test + +import ( + "bytes" + "crypto/rand" + "slices" + "testing" + + "github.com/NethermindEth/juno/consensus/propeller/reedsolomon" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestEncodeData(t *testing.T) { + requireEqualLength := func(t *testing.T, fragments [][]byte) { + length := len(fragments[0]) + for i := range fragments { + require.Len(t, fragments[i], length) + } + } + + requireEqualPrefix := func(t *testing.T, expected []byte, fragments [][]byte) { + actual := bytes.Join(fragments, nil) + require.Truef( + t, + bytes.HasPrefix(actual, expected), + "expected to get prefix: %s in %s", + expected, + actual, + ) + } + + largeData := make([]byte, 10*1024) + _, err := rand.Read(largeData) + require.NoError(t, err) + + successTests := []struct { + name string + data []byte + numData int + parity int + }{ + { + name: "success", + data: []byte("A journey of a thousands shards begins with a single byte"), + numData: 5, + parity: 3, + }, + { + name: "single data shard and single parity", + data: []byte("some data"), + numData: 1, + parity: 1, + }, + { + name: "large data", + data: largeData, + numData: 8, + parity: 4, + }, + { + name: "above 256 total shards", + data: largeData, + numData: 200, + parity: 100, + }, + } + for _, tc := range successTests { + t.Run(tc.name, func(t *testing.T) { + shards, err := reedsolomon.EncodeData(tc.data, tc.numData, tc.parity) + require.NoError(t, err) + require.Len(t, shards, tc.numData+tc.parity) + requireEqualLength(t, shards) + requireEqualPrefix(t, tc.data, shards) + }) + } + + errorTests := []struct { + name string + data []byte + 
numData int + parity int + errContains string + }{ + { + name: "empty data", + data: []byte{}, + numData: 5, + parity: 3, + errContains: "received empty data", + }, + { + name: "zero data shards", + data: []byte("data"), + numData: 0, + parity: 3, + errContains: "creating Reed-Solomon encoder", + }, + { + name: "negative parity", + data: []byte("data"), + numData: 5, + parity: -1, + errContains: "creating Reed-Solomon encoder", + }, + { + name: "exceeds max shard count", + data: []byte("data"), + numData: 40000, + parity: 40000, + errContains: "creating Reed-Solomon encoder", + }, + } + for _, tc := range errorTests { + t.Run(tc.name, func(t *testing.T) { + _, err := reedsolomon.EncodeData(tc.data, tc.numData, tc.parity) + require.ErrorContains(t, err, tc.errContains) + }) + } +} + +func TestRecoverData(t *testing.T) { + encode := func(t *testing.T, data []byte, numData, parity int) [][]byte { + t.Helper() + shards, err := reedsolomon.EncodeData(data, numData, parity) + require.NoError(t, err) + return shards + } + + buildDataShards := func(t *testing.T, original [][]byte, missingData ...int) [][]byte { + t.Helper() + dataShards := make([][]byte, len(original)) + for i := range original { + if slices.Contains(missingData, i) { + continue + } + dataShards[i] = make([]byte, len(original[i])) + copy(dataShards[i], original[i]) + } + return dataShards + } + + requireEqualShards := func(t *testing.T, expected [][]byte, actual [][]byte) { + t.Helper() + for i := range expected { + assert.Equalf( + t, expected[i], actual[i], + "at index %d, expected: %s, actual: %s", + i, expected[i], actual[i], + ) + } + } + + successTests := []struct { + name string + data []byte + numData int + parity int + missingIdx []int + }{ + { + name: "no missing shards", + data: []byte("nothing is missing here"), + numData: 4, + parity: 2, + }, + { + name: "missing parity shards", + data: []byte("recover parity shards"), + numData: 4, + parity: 3, + missingIdx: []int{5, 6, 7}, + }, + { + name: 
"missing data shards within parity limit", + data: []byte("recover data shards from parity"), + numData: 5, + parity: 3, + missingIdx: []int{0, 2, 4}, + }, + { + name: "missing mixed data and parity shards", + data: []byte("mixed missing shards scenario"), + numData: 5, + parity: 4, + missingIdx: []int{1, 3, 5, 6}, + }, + } + for _, tc := range successTests { + t.Run(tc.name, func(t *testing.T) { + expected := encode(t, tc.data, tc.numData, tc.parity) + dataShards := buildDataShards(t, expected, tc.missingIdx...) + + recovered, err := reedsolomon.RecoverData(dataShards, tc.numData, tc.parity) + require.NoError(t, err) + requireEqualShards(t, expected, recovered) + }) + } + + errorTests := []struct { + name string + data []byte + numData int + parity int + missingIdx []int + errContains string + }{ + { + name: "too many missing shards", + data: []byte("too many shards gone"), + numData: 4, + parity: 2, + missingIdx: []int{0, 1, 4}, + errContains: "recovering the data shards:", + }, + { + name: "empty shards slice", + numData: 4, + parity: 2, + errContains: "no data shards provided", + }, + } + for _, tc := range errorTests { + t.Run(tc.name, func(t *testing.T) { + var shards [][]byte + if tc.data != nil { + expected := encode(t, tc.data, tc.numData, tc.parity) + shards = buildDataShards(t, expected, tc.missingIdx...) 
+ } + + recovered, err := reedsolomon.RecoverData(shards, tc.numData, tc.parity) + require.Nil(t, recovered) + require.ErrorContains(t, err, tc.errContains) + }) + } +} + +func TestEncodeDecodeRoundTrip(t *testing.T) { + tests := []struct { + name string + data []byte + numData int + parity int + nilIdxs []int // indices to nil out before recovery + }{ + { + name: "small data, lose 1 data shard", + data: []byte("round trip test"), + numData: 4, parity: 2, + nilIdxs: []int{0}, + }, + { + name: "medium data, lose max shards", + data: bytes.Repeat([]byte("abcdefghij"), 100), + numData: 5, parity: 3, + nilIdxs: []int{1, 3, 6}, + }, + { + name: "single byte", + data: []byte{0xff}, + numData: 2, parity: 1, + nilIdxs: []int{0}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + shards, err := reedsolomon.EncodeData(tc.data, tc.numData, tc.parity) + require.NoError(t, err) + + for _, idx := range tc.nilIdxs { + shards[idx] = nil + } + + recovered, err := reedsolomon.RecoverData(shards, tc.numData, tc.parity) + require.NoError(t, err) + + joined := bytes.Join(recovered[:tc.numData], nil) + assert.Equal(t, tc.data, joined[:len(tc.data)]) + }) + } +} From 7241ce5e44db862d385ff14849f5101bdcc4013f Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Sat, 11 Apr 2026 13:01:06 +0100 Subject: [PATCH 27/40] feat: optimize merkle tree ops --- consensus/propeller/merkle/merkle.go | 131 ++++++++++++++------------- 1 file changed, 68 insertions(+), 63 deletions(-) diff --git a/consensus/propeller/merkle/merkle.go b/consensus/propeller/merkle/merkle.go index 1fecd859c7..eb708d024a 100644 --- a/consensus/propeller/merkle/merkle.go +++ b/consensus/propeller/merkle/merkle.go @@ -1,3 +1,11 @@ +// Package merkle implements Merkle tree construction and verification using a +// SHA-256 tagging scheme. Tags prevent second-preimage attacks by +// domain-separating leaf hashes from internal node hashes. 
The exact tag +// format matches the Propeller protocol specification so that all +// implementations produce identical trees. +// +// Tree layout: leaves are at the bottom, padded to the next power-of-two +// with the hash of empty data. The tree is built bottom-up by hashing pairs. package merkle import ( @@ -5,6 +13,28 @@ import ( "math/bits" ) +const ( + leafOpenTag = "" + leafCloseTag = "" + nodeOpenTag = "" + nodeMidTag = "" + nodeCloseTag = "" +) + +// Pre-computed domain-separator tags to avoid repeated []byte conversions. +var ( + leafOpen = []byte(leafOpenTag) + leafClose = []byte(leafCloseTag) + nodeOpen = []byte(nodeOpenTag) + nodeMid = []byte(nodeMidTag) + nodeClose = []byte(nodeCloseTag) +) + +// emptyLeafHash is the hash of a padding leaf (no data). We precompute it +// because the same value is used repeatedly when the leaf count is not a +// power of two. +var emptyLeafHash = merkleLeafHash(nil) + type Hash [32]byte // Proof contains the sibling hashes needed to verify that a leaf @@ -14,9 +44,9 @@ type Proof struct { Siblings []Hash } -// VerifyProof checks that a leaf at the given index is included in a -// tree with the claimed root. The proof contains sibling hashes from the leaf -// level up to the root. +// Verify checks that a leaf at the given index is included in a tree with +// the claimed root. The proof contains sibling hashes from the leaf level +// up to the root. // // The index determines the path through the tree: at each level, if the // current bit of the index is 0 the current hash is the left child and the @@ -25,13 +55,11 @@ func (p *Proof) Verify(root *Hash, leaf []byte, index uint32) bool { current := merkleLeafHash(leaf) idx := index - for _, sibling := range p.Siblings { + for i := range p.Siblings { if idx%2 == 0 { - // Current node is left child, sibling is right. - current = merkleNodeHash(current, sibling) + current = merkleNodeHash(&current, &p.Siblings[i]) } else { - // Current node is right child, sibling is left. 
- current = merkleNodeHash(sibling, current) + current = merkleNodeHash(&p.Siblings[i], &current) } idx /= 2 } @@ -39,14 +67,9 @@ func (p *Proof) Verify(root *Hash, leaf []byte, index uint32) bool { return current == *root } -// Represents a Merkle Tree +// Tree is a set of inclusion proofs, one per original leaf. type Tree []Proof -// emptyLeafHash is the hash of a padding leaf (no data). We precompute it -// because the same value is used repeatedly when the leaf count is not a -// power of two. -var emptyLeafHash = merkleLeafHash(nil) - // New constructs a binary Merkle tree from the given leaf data // and returns the root hash plus one inclusion proof per original leaf. // @@ -54,12 +77,12 @@ var emptyLeafHash = merkleLeafHash(nil) // simplifies the proof logic: every node at every level has a sibling, and // the proof path length is always log2(paddedSize). // -// Returns a zero root and nil proofs if leaves is empty. +// Returns a zero root and nil Tree if leaves is empty. func New(leaves [][]byte) (root Hash, tree Tree) { n := len(leaves) if n == 0 { // todo(rdr): maybe here we return a default merkle tree - return [32]byte{}, nil + return Hash{}, nil } size := nextPowerOfTwo(n) @@ -67,6 +90,7 @@ func New(leaves [][]byte) (root Hash, tree Tree) { // Build the bottom layer: hash each leaf, pad to power-of-two. layer := make([]Hash, size) for i := range n { + //nolint: gosec // Everything is in bounds here layer[i] = merkleLeafHash(leaves[i]) } for i := n; i < size; i++ { @@ -74,34 +98,22 @@ func New(leaves [][]byte) (root Hash, tree Tree) { } // proofSiblings[i] accumulates the sibling hashes for leaf i's proof. - // We collect them bottom-up as we build the tree. + // ancestors[i] tracks leaf i's ancestor position in the current layer. proofSiblings := make([][]Hash, n) + ancestors := make([]int, n) + for j := range n { + ancestors[j] = j + } // Build the tree bottom-up, one level at a time. 
for len(layer) > 1 { nextLayer := make([]Hash, len(layer)/2) for i := 0; i < len(layer); i += 2 { - left, right := layer[i], layer[i+1] - nextLayer[i/2] = merkleNodeHash(left, right) - - // Record siblings for any original leaves still tracked at - // this level. Leaf j at this level has its sibling at j^1 - // (XOR flips the last bit to get the pair partner). - for j := range n { - // Which position in the current layer does leaf j's - // ancestor occupy? It's j >> (current depth), but we - // track this implicitly: at depth d the ancestor of - // leaf j is at position j >> d. Since we've already - // collected d levels of siblings, d == len(proofSiblings[j]). - d := len(proofSiblings[j]) - ancestorPos := j >> d - if ancestorPos/2 == i/2 { - // This pair contains leaf j's ancestor. The sibling - // is the other element of the pair. - sibling := ancestorPos ^ 1 - proofSiblings[j] = append(proofSiblings[j], layer[sibling]) - } - } + nextLayer[i/2] = merkleNodeHash(&layer[i], &layer[i+1]) + } + for i := range n { + proofSiblings[i] = append(proofSiblings[i], layer[ancestors[i]^1]) + ancestors[i] /= 2 } layer = nextLayer } @@ -116,27 +128,19 @@ func New(leaves [][]byte) (root Hash, tree Tree) { return root, tree } -// Merkle tree construction and verification using a specific SHA-256 tagging -// scheme. Tags prevent second-preimage attacks by domain-separating leaf -// hashes from internal node hashes. The exact tag format matches the Propeller -// protocol specification so that all implementations produce identical trees. -// -// Tree layout: leaves are at the bottom, padded to the next power-of-two -// with the hash of empty data. The tree is built bottom-up by hashing pairs. - // merkleLeafHash computes: SHA256("" || data || "") // // The XML-like tags are the domain separator specified by the Propeller // protocol. They ensure a leaf hash can never collide with a node hash, // even if an attacker controls the data. 
func merkleLeafHash(data []byte) Hash { - h := sha256.New() - h.Write([]byte("")) - h.Write(data) - h.Write([]byte("")) - var out [32]byte - h.Sum(out[:0]) - return out + buf := make([]byte, len(leafOpenTag)+len(data)+len(leafCloseTag)) + + n := copy(buf, leafOpen) + n += copy(buf[n:], data) + copy(buf[n:], leafClose) + + return sha256.Sum256(buf) } // merkleNodeHash computes: @@ -144,16 +148,17 @@ func merkleLeafHash(data []byte) Hash { // SHA256("" || left || "" || right || "") // // The nested tags ensure node hashes are in a separate domain from leaf hashes. -func merkleNodeHash(left, right [32]byte) Hash { - h := sha256.New() - h.Write([]byte("")) - h.Write(left[:]) - h.Write([]byte("")) - h.Write(right[:]) - h.Write([]byte("")) - var out [32]byte - h.Sum(out[:0]) - return out +func merkleNodeHash(left, right *Hash) Hash { + const size = len(nodeOpenTag) + 32 + len(nodeMidTag) + 32 + len(nodeCloseTag) + var buf [size]byte + + n := copy(buf[:], nodeOpen) + n += copy(buf[n:], left[:]) + n += copy(buf[n:], nodeMid) + n += copy(buf[n:], right[:]) + copy(buf[n:], nodeClose) + + return sha256.Sum256(buf[:]) } // nextPowerOfTwo returns the smallest power of two >= n, with a minimum of 2. 
From 945d4c8e90b392e9307e6ecaae9d8467bcfd7966 Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Sat, 11 Apr 2026 13:36:44 +0100 Subject: [PATCH 28/40] test: for merkle tree implementation --- consensus/propeller/merkle/merkle_test.go | 282 ++++++---------------- 1 file changed, 76 insertions(+), 206 deletions(-) diff --git a/consensus/propeller/merkle/merkle_test.go b/consensus/propeller/merkle/merkle_test.go index e02236492b..eaf870cc2e 100644 --- a/consensus/propeller/merkle/merkle_test.go +++ b/consensus/propeller/merkle/merkle_test.go @@ -1,243 +1,113 @@ package merkle_test import ( - "crypto/sha256" + "fmt" "testing" + "github.com/NethermindEth/juno/consensus/propeller/merkle" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestMerkleLeafHash(t *testing.T) { - data := []byte("hello") - hash := merkleLeafHash(data) - - // Manually compute expected: SHA256("hello") - h := sha256.New() - h.Write([]byte("hello")) - var expected [32]byte - h.Sum(expected[:0]) - - assert.Equal(t, expected, hash) -} - -func TestMerkleLeafHash_Empty(t *testing.T) { - hash := merkleLeafHash(nil) - - h := sha256.New() - h.Write([]byte("")) - var expected [32]byte - h.Sum(expected[:0]) - - assert.Equal(t, expected, hash) -} - -func TestMerkleNodeHash(t *testing.T) { - left := merkleLeafHash([]byte("L")) - right := merkleLeafHash([]byte("R")) - node := merkleNodeHash(left, right) - - h := sha256.New() - h.Write([]byte("")) - h.Write(left[:]) - h.Write([]byte("")) - h.Write(right[:]) - h.Write([]byte("")) - var expected [32]byte - h.Sum(expected[:0]) - - assert.Equal(t, expected, node) -} - -func TestNextPowerOfTwo(t *testing.T) { - tests := []struct { - n int - expected int - }{ - {0, 2}, - {1, 2}, - {2, 2}, - {3, 4}, - {4, 4}, - {5, 8}, - {7, 8}, - {8, 8}, - {9, 16}, - {16, 16}, - {17, 32}, - } - - for _, tc := range tests { - assert.Equal(t, tc.expected, nextPowerOfTwo(tc.n), "nextPowerOfTwo(%d)", tc.n) +func makeLeaves(n int) [][]byte { + leaves := 
make([][]byte, n) + for i := range n { + leaves[i] = fmt.Appendf(nil, "leaf-%d", i) } + return leaves } -func TestBuildMerkleTree_Empty(t *testing.T) { - root, proofs := BuildMerkleTree(nil) - assert.Equal(t, [32]byte{}, root) +func TestNew_Empty(t *testing.T) { + root, proofs := merkle.New(nil) + assert.Equal(t, merkle.Hash{}, root) assert.Nil(t, proofs) } -func TestBuildMerkleTree_SingleLeaf(t *testing.T) { - leaves := [][]byte{[]byte("only")} - root, proofs := BuildMerkleTree(leaves) - - require.Len(t, proofs, 1) - - // With one leaf padded to 2, the tree is: - // root - // / \ - // leaf0 empty - leafHash := merkleLeafHash([]byte("only")) - expectedRoot := merkleNodeHash(leafHash, emptyLeafHash) - assert.Equal(t, expectedRoot, root) - - // Proof for leaf 0 should contain the empty leaf as sibling. - assert.Len(t, proofs[0].Siblings, 1) - assert.Equal(t, emptyLeafHash, proofs[0].Siblings[0]) +func TestNew_ProofsVerify(t *testing.T) { + for _, n := range []int{1, 2, 3, 4, 5, 8, 16, 31} { + t.Run(fmt.Sprintf("n=%d", n), func(t *testing.T) { + leaves := makeLeaves(n) + root, proofs := merkle.New(leaves) + + require.Len(t, proofs, n) + assert.NotEqual(t, merkle.Hash{}, root) + + for i, leaf := range leaves { + assert.True(t, + proofs[i].Verify(&root, leaf, uint32(i)), + "proof for leaf %d should verify", i, + ) + } + }) + } } -func TestBuildMerkleTree_TwoLeaves(t *testing.T) { - leaves := [][]byte{[]byte("A"), []byte("B")} - root, proofs := BuildMerkleTree(leaves) - - require.Len(t, proofs, 2) - - h0 := merkleLeafHash([]byte("A")) - h1 := merkleLeafHash([]byte("B")) - expectedRoot := merkleNodeHash(h0, h1) - assert.Equal(t, expectedRoot, root) - - // Leaf 0's sibling is leaf 1. - assert.Equal(t, h1, proofs[0].Siblings[0]) - // Leaf 1's sibling is leaf 0. 
- assert.Equal(t, h0, proofs[1].Siblings[0]) +func TestNew_WrongDataDoesNotVerify(t *testing.T) { + for _, n := range []int{1, 2, 3, 4, 5, 31} { + t.Run(fmt.Sprintf("n=%d", n), func(t *testing.T) { + leaves := makeLeaves(n) + root, proofs := merkle.New(leaves) + + for i := range leaves { + assert.False(t, + proofs[i].Verify(&root, []byte("tampered"), uint32(i)), + "tampered data should not verify for leaf %d", i, + ) + } + }) + } } -func TestBuildMerkleTree_FourLeaves(t *testing.T) { +func TestVerify_Rejects(t *testing.T) { leaves := [][]byte{[]byte("A"), []byte("B"), []byte("C"), []byte("D")} - root, proofs := BuildMerkleTree(leaves) - - require.Len(t, proofs, 4) - - // Build expected tree manually: - // root - // / \ - // n01 n23 - // / \ / \ - // h0 h1 h2 h3 - h0 := merkleLeafHash([]byte("A")) - h1 := merkleLeafHash([]byte("B")) - h2 := merkleLeafHash([]byte("C")) - h3 := merkleLeafHash([]byte("D")) - n01 := merkleNodeHash(h0, h1) - n23 := merkleNodeHash(h2, h3) - expectedRoot := merkleNodeHash(n01, n23) - assert.Equal(t, expectedRoot, root) - - // Proof for leaf 0: siblings [h1, n23] - require.Len(t, proofs[0].Siblings, 2) - assert.Equal(t, h1, proofs[0].Siblings[0]) - assert.Equal(t, n23, proofs[0].Siblings[1]) - - // Proof for leaf 2: siblings [h3, n01] - require.Len(t, proofs[2].Siblings, 2) - assert.Equal(t, h3, proofs[2].Siblings[0]) - assert.Equal(t, n01, proofs[2].Siblings[1]) -} - -func TestBuildMerkleTree_ThreeLeaves(t *testing.T) { - // Three leaves means padding to 4: the fourth leaf is empty. 
- leaves := [][]byte{[]byte("A"), []byte("B"), []byte("C")} - root, proofs := BuildMerkleTree(leaves) - - require.Len(t, proofs, 3) - - h0 := merkleLeafHash([]byte("A")) - h1 := merkleLeafHash([]byte("B")) - h2 := merkleLeafHash([]byte("C")) - h3 := emptyLeafHash - n01 := merkleNodeHash(h0, h1) - n23 := merkleNodeHash(h2, h3) - expectedRoot := merkleNodeHash(n01, n23) - assert.Equal(t, expectedRoot, root) - - // Proof for leaf 2: siblings [emptyLeaf, n01] - require.Len(t, proofs[2].Siblings, 2) - assert.Equal(t, h3, proofs[2].Siblings[0]) - assert.Equal(t, n01, proofs[2].Siblings[1]) -} - -func TestVerifyMerkleProof_ValidProofs(t *testing.T) { - // Build a tree with several leaves, then verify every proof. - data := [][]byte{ - []byte("alpha"), - []byte("bravo"), - []byte("charlie"), - []byte("delta"), - []byte("echo"), - } - root, proofs := BuildMerkleTree(data) - require.Len(t, proofs, len(data)) + root, proofs := merkle.New(leaves) - for i, d := range data { - ok := VerifyMerkleProof(root, d, uint32(i), proofs[i]) - assert.True(t, ok, "proof for leaf %d should verify", i) - } -} + t.Run("wrong index", func(t *testing.T) { + assert.False(t, proofs[0].Verify(&root, leaves[0], 1)) + }) -func TestVerifyMerkleProof_WrongData(t *testing.T) { - leaves := [][]byte{[]byte("real"), []byte("data")} - root, proofs := BuildMerkleTree(leaves) + t.Run("wrong root", func(t *testing.T) { + fakeRoot := merkle.Hash{0xff} + assert.False(t, proofs[0].Verify(&fakeRoot, leaves[0], 0)) + }) - // Tamper with the data. 
- ok := VerifyMerkleProof(root, []byte("fake"), 0, proofs[0]) - assert.False(t, ok, "tampered data should not verify") + t.Run("tampered sibling", func(t *testing.T) { + badProof := merkle.Proof{Siblings: []merkle.Hash{{0xde, 0xad}}} + assert.False(t, badProof.Verify(&root, leaves[0], 0)) + }) } -func TestVerifyMerkleProof_WrongIndex(t *testing.T) { - leaves := [][]byte{[]byte("A"), []byte("B"), []byte("C"), []byte("D")} - root, proofs := BuildMerkleTree(leaves) - - // Use leaf 0's data with leaf 1's index. - ok := VerifyMerkleProof(root, []byte("A"), 1, proofs[0]) - assert.False(t, ok, "wrong index should not verify") -} +func TestNew_Deterministic(t *testing.T) { + leaves := makeLeaves(7) -func TestVerifyMerkleProof_WrongRoot(t *testing.T) { - leaves := [][]byte{[]byte("A"), []byte("B")} - _, proofs := BuildMerkleTree(leaves) + root1, proofs1 := merkle.New(leaves) + root2, proofs2 := merkle.New(leaves) - fakeRoot := [32]byte{0xff} - ok := VerifyMerkleProof(fakeRoot, []byte("A"), 0, proofs[0]) - assert.False(t, ok, "wrong root should not verify") + assert.Equal(t, root1, root2) + require.Len(t, proofs1, len(proofs2)) + for i := range proofs1 { + assert.Equal(t, proofs1[i], proofs2[i], "proof %d should be identical", i) + } } -func TestVerifyMerkleProof_TamperedSibling(t *testing.T) { - leaves := [][]byte{[]byte("A"), []byte("B")} - root, _ := BuildMerkleTree(leaves) +func TestNew_DifferentLeavesDifferentRoots(t *testing.T) { + rootA, _ := merkle.New([][]byte{[]byte("A"), []byte("B")}) + rootB, _ := merkle.New([][]byte{[]byte("X"), []byte("Y")}) - // Tamper with a sibling hash in the proof. - badProof := MerkleProof{Siblings: [][32]byte{{0xde, 0xad}}} - ok := VerifyMerkleProof(root, []byte("A"), 0, badProof) - assert.False(t, ok, "tampered sibling should not verify") + assert.NotEqual(t, rootA, rootB) } -func TestBuildAndVerify_LargeTree(t *testing.T) { - // Build a tree with a non-power-of-two count to exercise padding. 
- n := 31 - leaves := make([][]byte, n) - for i := range n { - leaves[i] = []byte{byte(i), byte(i >> 8)} - } +func TestNew_CrossTreeIsolation(t *testing.T) { + leavesA := [][]byte{[]byte("A"), []byte("B"), []byte("C"), []byte("D")} + leavesB := [][]byte{[]byte("W"), []byte("X"), []byte("Y"), []byte("Z")} - root, proofs := BuildMerkleTree(leaves) - require.Len(t, proofs, n) + rootB, _ := merkle.New(leavesB) + _, proofsA := merkle.New(leavesA) - for i, leaf := range leaves { - assert.True(t, - VerifyMerkleProof(root, leaf, uint32(i), proofs[i]), - "proof for leaf %d in 31-leaf tree should verify", i, + for i, leaf := range leavesA { + assert.False(t, + proofsA[i].Verify(&rootB, leaf, uint32(i)), + "proof from tree A should not verify against tree B root", ) } } From f07b7b1d938971eba4a4b2f9d469c9db4dd15ae9 Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Sat, 11 Apr 2026 15:55:51 +0100 Subject: [PATCH 29/40] chore: improve padding logic and tests --- consensus/propeller/padding.go | 5 +- consensus/propeller/padding_test.go | 87 ++++++++++++++++------------- 2 files changed, 49 insertions(+), 43 deletions(-) diff --git a/consensus/propeller/padding.go b/consensus/propeller/padding.go index 465d279622..3dca643e93 100644 --- a/consensus/propeller/padding.go +++ b/consensus/propeller/padding.go @@ -1,4 +1,4 @@ -package utils +package propeller import ( "encoding/binary" @@ -32,7 +32,6 @@ func PadMessage(msg []byte, numDataShards int) []byte { result := make([]byte, paddedMsgLen) copy(result, varintBuf[:varintLen]) copy(result[varintLen:], msg) - // Remaining bytes are already zero (Go slice initialization). 
return result } @@ -56,5 +55,5 @@ func UnpadMessage(padded []byte) ([]byte, error) { ) } - return padded[varintLen:msgLen], nil + return padded[varintLen:end], nil } diff --git a/consensus/propeller/padding_test.go b/consensus/propeller/padding_test.go index 9d29a3dcfc..bfb6d0cb2b 100644 --- a/consensus/propeller/padding_test.go +++ b/consensus/propeller/padding_test.go @@ -1,8 +1,9 @@ -package utils_test +package propeller_test import ( "testing" + "github.com/NethermindEth/juno/consensus/propeller" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -43,11 +44,16 @@ func TestPadMessage_RoundTrip(t *testing.T) { msg: []byte{0x42}, numDataShards: 5, }, + { + name: "large message requiring multi-byte varint", + msg: makeSequentialBytes(300), + numDataShards: 4, + }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - padded := PadMessage(tc.msg, tc.numDataShards) + padded := propeller.PadMessage(tc.msg, tc.numDataShards) // Verify divisibility. divisor := 2 * tc.numDataShards @@ -55,54 +61,55 @@ func TestPadMessage_RoundTrip(t *testing.T) { "padded length %d should be divisible by %d", len(padded), divisor) // Verify round-trip. - recovered, err := UnpadMessage(padded) + recovered, err := propeller.UnpadMessage(padded) require.NoError(t, err) assert.Equal(t, tc.msg, recovered) }) } } -func TestPadMessage_Alignment(t *testing.T) { - // Verify that padding produces the minimum size that is a multiple of divisor. - msg := []byte("ab") // 2 bytes - padded := PadMessage(msg, 3) // divisor = 6 - // varint(2) = 1 byte, payload = 3 bytes, next multiple of 6 = 6 - assert.Equal(t, 6, len(padded)) -} - -func TestUnpadMessage_InvalidVarint(t *testing.T) { - // An empty buffer has no valid varint. - _, err := UnpadMessage([]byte{}) - assert.Error(t, err) - assert.Contains(t, err.Error(), "invalid varint") -} - -func TestUnpadMessage_LengthExceedsData(t *testing.T) { - // Manually encode a varint claiming 100 bytes, but only provide 5. 
- buf := make([]byte, 6) - buf[0] = 100 // varint encoding of 100 - copy(buf[1:], "short") +func TestUnpadMessage_Errors(t *testing.T) { + tests := []struct { + name string + input []byte + wantErr string + }{ + { + name: "empty buffer", + input: []byte{}, + wantErr: "invalid varint", + }, + { + name: "truncated varint", + input: []byte{0x80}, // continuation bit set, no following byte + wantErr: "invalid varint", + }, + { + name: "length exceeds data", + input: append([]byte{100}, []byte("short")...), // varint 100, only 5 bytes + wantErr: "exceeds available data", + }, + } - _, err := UnpadMessage(buf) - assert.Error(t, err) - assert.Contains(t, err.Error(), "exceeds available data") + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + _, err := propeller.UnpadMessage(tc.input) + require.ErrorContains(t, err, tc.wantErr) + }) + } } -func TestUnpadMessage_Truncated(t *testing.T) { - // Encode a valid varint pointing past the end. - buf := []byte{0x80, 0x01} // varint 128, but only 2 bytes total - _, err := UnpadMessage(buf) - assert.Error(t, err) +func TestPadMessage_Size(t *testing.T) { + msg := []byte("ab") // 2 bytes + padded := propeller.PadMessage(msg, 3) // divisor = 6 + // varint(2) = 1 byte, payload = 3 bytes, next multiple of 6 = 6 + require.Len(t, padded, 6) } -func TestPadMessage_LargeVarint(t *testing.T) { - // A message large enough to need a multi-byte varint. 
- msg := make([]byte, 300) - for i := range msg { - msg[i] = byte(i) +func makeSequentialBytes(n int) []byte { + b := make([]byte, n) + for i := range b { + b[i] = byte(i) } - padded := PadMessage(msg, 4) - recovered, err := UnpadMessage(padded) - require.NoError(t, err) - assert.Equal(t, msg, recovered) + return b } From 8ac0f42687729bb98425adcbe38d4dce30ed0575 Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Sat, 11 Apr 2026 16:11:31 +0100 Subject: [PATCH 30/40] test: for signing and buf fix --- consensus/propeller/signing.go | 4 +- consensus/propeller/signing_test.go | 136 +++++++++++++++++++++++++++- 2 files changed, 137 insertions(+), 3 deletions(-) diff --git a/consensus/propeller/signing.go b/consensus/propeller/signing.go index 71e5b93290..5b1c8a80d8 100644 --- a/consensus/propeller/signing.go +++ b/consensus/propeller/signing.go @@ -20,7 +20,7 @@ func buildSignPayload( const prefix = "" const suffix = "" - // cummulative lenghts denoting the ranges in where each bytes of data should be stored + // cumulative lenghts denoting the ranges in where each bytes of data should be stored const prefixLen = len(prefix) const rootLen = prefixLen + 32 const committeeIDLen = rootLen + 32 @@ -43,7 +43,7 @@ func SignMessage( root *MessageRoot, committeeID *CommitteeID, nonce Nonce, -) ([]byte, error) { +) (Signature, error) { payload := buildSignPayload(root, committeeID, nonce) sig, err := privKey.Sign(payload[:]) if err != nil { diff --git a/consensus/propeller/signing_test.go b/consensus/propeller/signing_test.go index 90c372d319..29e7f50b8d 100644 --- a/consensus/propeller/signing_test.go +++ b/consensus/propeller/signing_test.go @@ -1 +1,135 @@ -package utils_test +package propeller_test + +import ( + "bytes" + "crypto/ed25519" + "testing" + "time" + + "github.com/NethermindEth/juno/consensus/propeller" + "github.com/libp2p/go-libp2p/core/crypto" + "github.com/stretchr/testify/require" +) + +func generateKey(t *testing.T, seed byte) (crypto.PrivKey, crypto.PubKey) { + 
t.Helper() + s := make([]byte, ed25519.SeedSize) + s[0] = seed + priv, pub, err := crypto.GenerateEd25519Key(bytes.NewReader(s)) + require.NoError(t, err) + return priv, pub +} + +func TestSignAndVerify(t *testing.T) { + privA, pubA := generateKey(t, 1) + + root := propeller.MessageRoot{0xAA} + committeeID := propeller.CommitteeID{0xBB} + nonce := propeller.Nonce(time.Second) + + sig, err := propeller.SignMessage(privA, &root, &committeeID, nonce) + require.NoError(t, err) + + t.Run("success", func(t *testing.T) { + tests := []struct { + name string + root propeller.MessageRoot + committeeID propeller.CommitteeID + nonce propeller.Nonce + }{ + { + name: "valid roundtrip", + root: root, + committeeID: committeeID, + nonce: nonce, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := propeller.VerifyMessageSignature( + pubA, &tc.root, &tc.committeeID, tc.nonce, sig, + ) + require.NoError(t, err) + }) + } + }) + + t.Run("error", func(t *testing.T) { + _, pubB := generateKey(t, 2) + + tests := []struct { + name string + pubKey crypto.PubKey + root propeller.MessageRoot + committeeID propeller.CommitteeID + nonce propeller.Nonce + signature propeller.Signature + wantErr string + }{ + { + name: "wrong public key", + pubKey: pubB, + root: root, + committeeID: committeeID, + nonce: nonce, + signature: sig, + wantErr: "signature is invalid", + }, + { + name: "tampered root", + pubKey: pubA, + root: propeller.MessageRoot{0xFF}, + committeeID: committeeID, + nonce: nonce, + signature: sig, + wantErr: "signature is invalid", + }, + { + name: "tampered committee ID", + pubKey: pubA, + root: root, + committeeID: propeller.CommitteeID{0xFF}, + nonce: nonce, + signature: sig, + wantErr: "signature is invalid", + }, + { + name: "tampered nonce", + pubKey: pubA, + root: root, + committeeID: committeeID, + nonce: propeller.Nonce(time.Hour), + signature: sig, + wantErr: "signature is invalid", + }, + { + name: "empty signature", + pubKey: pubA, + root: 
root, + committeeID: committeeID, + nonce: nonce, + signature: nil, + wantErr: "empty signature", + }, + { + name: "corrupted signature", + pubKey: pubA, + root: root, + committeeID: committeeID, + nonce: nonce, + signature: append(append([]byte{}, sig...), 0xFF), + wantErr: "signature is invalid", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := propeller.VerifyMessageSignature( + tc.pubKey, &tc.root, &tc.committeeID, tc.nonce, tc.signature, + ) + require.ErrorContains(t, err, tc.wantErr) + }) + } + }) +} From d5bc6a09718109771a29319f467db3d14341dfce Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Sat, 11 Apr 2026 17:04:01 +0100 Subject: [PATCH 31/40] chore: remove all compilation issuess --- consensus/propeller/engine.go | 36 +- consensus/propeller/engine_test.go | 742 +++++++++--------- consensus/propeller/processor.go | 157 +++- consensus/propeller/propeller_test.go | 1 + .../propeller/reedsolomon/reedsolomon.go | 8 +- consensus/propeller/scheduler.go | 30 +- consensus/propeller/scheduler_test.go | 407 +++++----- consensus/propeller/sharding.go | 45 +- consensus/propeller/sharding_test.go | 404 +++++----- consensus/propeller/unit_test.go | 1 + 10 files changed, 987 insertions(+), 844 deletions(-) diff --git a/consensus/propeller/engine.go b/consensus/propeller/engine.go index 50849c9a2f..68e75323ed 100644 --- a/consensus/propeller/engine.go +++ b/consensus/propeller/engine.go @@ -3,6 +3,7 @@ package propeller import ( "context" "fmt" + "time" "github.com/NethermindEth/juno/utils" "github.com/libp2p/go-libp2p/core/crypto" @@ -171,13 +172,13 @@ func NewEngine( // registerCommittee creates the schedule and encoder for a new channel. 
func (e *Engine) registerCommittee( - committeeID CommitteeID, + committeeID *CommitteeID, peers []PeerCommittee, peersKeys []*StakerID, ) error { // todo(rdr): Why re-registration should be ignored, // as far as I understand, it shouldn't happen :think: - if _, ok := e.committees[committeeID]; ok { + if _, ok := e.committees[*committeeID]; ok { e.log.Warn( "committee already registered, will ignore re-registration attempt", // todo(rdr): give a propper string repr @@ -201,7 +202,7 @@ func (e *Engine) registerCommittee( return fmt.Errorf("couldn't register a new committee: %w", err) } - e.committees[committeeID] = &committeeState{ + e.committees[*committeeID] = &committeeState{ scheduler: schedule, // todo(rdr): need to add the peer pub keys peerKeys: nil, @@ -220,8 +221,8 @@ func (e *Engine) registerCommittee( // unregisterCommittee removes a channel's state. Not new processors will be started but // currently running ones will continue until the timeout / stop naturally -func (e *Engine) unregisterCommittee(committeeID CommitteeID) { - delete(e.committees, committeeID) +func (e *Engine) unregisterCommittee(committeeID *CommitteeID) { + delete(e.committees, *committeeID) // todo(rdr): We have to clean the processors, right? // or will they shut down on their own eventually // better to pass a context with cancelj @@ -234,27 +235,30 @@ func (e *Engine) unregisterCommittee(committeeID CommitteeID) { // prepareBroadcast creates Proppeller units asynchronously since it is a very expensive // operation. 
-func (e *Engine) prepareBroadcast(committeeID CommitteeID, data []byte) error { - cs, ok := e.committees[committeeID] +func (e *Engine) prepareBroadcast(committeeID *CommitteeID, data []byte) error { + cs, ok := e.committees[*committeeID] if !ok { return fmt.Errorf("cannot broadcast to an unregistered committee: %s", committeeID) } + // todo(rdr): unsure if this approach of passing arguments to the go routine makes sense // todo(rdr): consider having a maximum amount of working threads and a queue tasks for this // This is an expensive operation, hence we need to do it separately - go func() { + go func(e *Engine, scheduler *Scheduler, committeeID CommitteeID, data []byte) { units, err := CreatePropellerUnits( - committeeID, - data, e.privKey, - cs.scheduler.NumDataShards(), - cs.scheduler.NumCodingShards(), + &committeeID, + // todo(rdr): Find how nonce is set when creating propeller units + Nonce(time.Now().UnixNano()), + data, + scheduler.NumDataShards(), + scheduler.NumCodingShards(), ) e.unitsPrepared <- broadcastResult{ units: units, err: err, } - }() + }(e, cs.scheduler, *committeeID, data) return nil } @@ -315,12 +319,12 @@ func (e *Engine) forwardEvent(event any) { func (e *Engine) handleCommand(ctx context.Context, command engineCommand) { switch cmd := command.(type) { case *registerCommittee: - err := e.registerCommittee(cmd.committeeID, cmd.peers, cmd.peersKeys) + err := e.registerCommittee(&cmd.committeeID, cmd.peers, cmd.peersKeys) cmd.errCh <- err case *unregisterCommittee: - e.unregisterCommittee(cmd.committeeID) + e.unregisterCommittee(&cmd.committeeID) case *broadcast: - err := e.prepareBroadcast(cmd.committeeID, cmd.msg) + err := e.prepareBroadcast(&cmd.committeeID, cmd.msg) cmd.errCh <- err case *processUnit: e.processUnit(ctx, cmd.unit, cmd.sender) diff --git a/consensus/propeller/engine_test.go b/consensus/propeller/engine_test.go index 2e8f811267..66e9781a2d 100644 --- a/consensus/propeller/engine_test.go +++ 
b/consensus/propeller/engine_test.go @@ -1,372 +1,372 @@ -package propeller - -import ( - "bytes" - "context" - "crypto/ed25519" - "fmt" - "sync" - "testing" - "time" - - "github.com/NethermindEth/juno/utils" - "github.com/libp2p/go-libp2p/core/crypto" - "github.com/libp2p/go-libp2p/core/peer" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - +package propeller_test + +// import ( +// "bytes" +// "context" +// "crypto/ed25519" +// "fmt" +// "sync" +// "testing" +// "time" +// +// "github.com/NethermindEth/juno/utils" +// "github.com/libp2p/go-libp2p/core/crypto" +// "github.com/libp2p/go-libp2p/core/peer" +// "github.com/stretchr/testify/assert" +// "github.com/stretchr/testify/require" +// ) +// // engineTestEnv provides the common setup for engine-level tests. -type engineTestEnv struct { - peers []peer.ID - privKeys []crypto.PrivKey - engines []*Engine - sentUnits map[peer.ID][]*Unit - sentMu sync.Mutex - log utils.Logger -} - -//nolint:unparam // n is always 4 in current tests but kept for flexibility -func newEngineTestEnv(t *testing.T, n int) *engineTestEnv { - t.Helper() - - peers := make([]peer.ID, n) - privKeys := make([]crypto.PrivKey, n) - for i := range n { - seed := make([]byte, ed25519.SeedSize) - seed[0] = byte(i) - reader := bytes.NewReader(seed) - priv, pub, err := crypto.GenerateEd25519Key(reader) - require.NoError(t, err) - id, err := peer.IDFromPublicKey(pub) - require.NoError(t, err) - privKeys[i] = priv - peers[i] = id - } - - log := utils.NewNopZapLogger() - - env := &engineTestEnv{ - peers: peers, - privKeys: privKeys, - sentUnits: make(map[peer.ID][]*Unit), - log: log, - } - - config := Config{ - StaleMessageTimeout: 5 * time.Second, - StreamProtocol: "/propeller/test/0.1.0", - MaxWireMessageSize: 1 << 20, - } - - engines := make([]*Engine, n) - for i := range n { - engines[i] = NewEngine( - peers[i], privKeys[i], config, - env.makeSendFn(), - log, - ) - } - env.engines = engines - - return env -} - -// 
makeSendFn creates a SendUnitFunc that records sent units. -func (env *engineTestEnv) makeSendFn() SendUnitFunc { - return func(_ context.Context, to peer.ID, unit *Unit) error { - env.sentMu.Lock() - env.sentUnits[to] = append(env.sentUnits[to], unit) - env.sentMu.Unlock() - return nil - } -} - -// getSentUnits returns all units sent to a given peer. -func (env *engineTestEnv) getSentUnits(to peer.ID) []*Unit { - env.sentMu.Lock() - defer env.sentMu.Unlock() - result := make([]*Unit, len(env.sentUnits[to])) - copy(result, env.sentUnits[to]) - return result -} - -func TestEngine_RegisterAndBroadcast(t *testing.T) { - env := newEngineTestEnv(t, 4) - - ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second) - defer cancel() - - engine := env.engines[0] - - // Run the engine in the background. - done := make(chan error, 1) - go func() { - done <- engine.Run(ctx) - }() - - // Register a channel with all peers. - err := engine.RegisterChannel(ctx, 1, env.peers) - require.NoError(t, err) - - // Broadcast a message. - msg := []byte("hello from engine test") - err = engine.Broadcast(ctx, 1, msg) - require.NoError(t, err) - - // Verify that units were sent to the other 3 peers. - // Give a moment for async processing. 
- time.Sleep(100 * time.Millisecond) - - totalSent := 0 - for _, p := range env.peers { - if p == env.peers[0] { - continue - } - units := env.getSentUnits(p) - totalSent += len(units) - } - assert.Equal(t, 3, totalSent, "should send one unit to each non-publisher peer") - - cancel() - <-done -} - -func TestEngine_BroadcastUnregisteredChannel(t *testing.T) { - env := newEngineTestEnv(t, 4) - - ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) - defer cancel() - - engine := env.engines[0] - - go func() { - engine.Run(ctx) //nolint:errcheck // test helper - }() - - err := engine.Broadcast(ctx, 99, []byte("should fail")) - require.Error(t, err) - - var pubErr *ShardPublishError - require.ErrorAs(t, err, &pubErr) - assert.Equal(t, ReasonChannelNotRegistered, pubErr.Reason) - - cancel() -} - -func TestEngine_HandleUnit_CreatesProcessor(t *testing.T) { - env := newEngineTestEnv(t, 4) - - ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second) - defer cancel() - - // Set up engine for peer 0. - engine := env.engines[0] - - go func() { - engine.Run(ctx) //nolint:errcheck // test helper - }() - - // Register the channel. - err := engine.RegisterChannel(ctx, 1, env.peers) - require.NoError(t, err) - - // Simulate receiving a unit from peer 1 (as publisher). - schedule := NewScheduler(env.peers) - enc, err := NewEncoder(schedule.NumDataShards(), schedule.NumCodingShards()) - require.NoError(t, err) - - msg := []byte("incoming message") - units, root, err := EncodeMessage(msg, schedule, enc) - require.NoError(t, err) - - publisher := env.peers[1] - sig, err := SignMessage(root, env.privKeys[1]) - require.NoError(t, err) - - for i := range units { - units[i].Publisher = publisher - units[i].Signature = sig - units[i].CommitteeID = 1 - } - - // Send units from their correct senders. 
- for i, unit := range units { - sender, err := schedule.PeerForShard(publisher, ShardIndex(i)) - require.NoError(t, err) - - // Skip units "from ourselves" -- the validator rejects those. - if sender == env.peers[0] { - continue - } - - unitCopy := unit - engine.HandleUnit(&unitCopy, sender) - } - - // Wait for the message to be processed and check events. - var received *EventMessageReceived - deadline := time.After(5 * time.Second) - for received == nil { - select { - case ev := <-engine.Events(): - if r, ok := ev.(EventMessageReceived); ok { - received = &r - } - case <-deadline: - t.Fatal("timed out waiting for EventMessageReceived") - } - } - - assert.Equal(t, msg, received.Message) - assert.Equal(t, publisher, received.Publisher) - assert.Equal(t, root, received.Root) - - cancel() -} - -func TestEngine_UnregisterChannel(t *testing.T) { - env := newEngineTestEnv(t, 4) - - ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) - defer cancel() - - engine := env.engines[0] - - go func() { - engine.Run(ctx) //nolint:errcheck // test helper - }() - - err := engine.RegisterChannel(ctx, 1, env.peers) - require.NoError(t, err) - - err = engine.UnregisterChannel(ctx, 1) - require.NoError(t, err) - - // Allow command to be processed. - time.Sleep(50 * time.Millisecond) - - // Broadcasting should fail now. - err = engine.Broadcast(ctx, 1, []byte("after unregister")) - require.Error(t, err) - - cancel() -} - -func TestEngine_HandleUnit_UnregisteredChannel(t *testing.T) { - env := newEngineTestEnv(t, 4) - - ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) - defer cancel() - - engine := env.engines[0] - - go func() { - engine.Run(ctx) //nolint:errcheck // test helper - }() - - // Send a unit for an unregistered channel. - unit := &Unit{ - CommitteeID: 99, - Publisher: env.peers[1], - MessageRoot: MessageRoot{0x01}, - ShardIndex: 0, - ShardData: []byte("data"), - } - engine.HandleUnit(unit, env.peers[1]) - - // Allow time for processing. 
- time.Sleep(100 * time.Millisecond) - - // No crash, no panic -- the unit is silently dropped. - cancel() -} - -func TestEngine_GracefulShutdown(t *testing.T) { - env := newEngineTestEnv(t, 4) - - ctx, cancel := context.WithCancel(t.Context()) - - engine := env.engines[0] - done := make(chan error, 1) - go func() { - done <- engine.Run(ctx) - }() - - cancel() - - select { - case err := <-done: - assert.ErrorIs(t, err, context.Canceled) - case <-time.After(2 * time.Second): - t.Fatal("engine did not shut down in time") - } -} - -func TestEngine_SendFailureEmitsEvent(t *testing.T) { - env := newEngineTestEnv(t, 4) - - ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) - defer cancel() - - // Create an engine with a failing send function. - engine := NewEngine( - env.peers[0], env.privKeys[0], - Config{ - StaleMessageTimeout: 5 * time.Second, - StreamProtocol: "/propeller/test/0.1.0", - MaxWireMessageSize: 1 << 20, - }, - func(_ context.Context, _ peer.ID, _ *Unit) error { - return fmt.Errorf("simulated network failure") - }, - utils.NewNopZapLogger(), - ) - - go func() { - engine.Run(ctx) //nolint:errcheck // test helper - }() - - err := engine.RegisterChannel(ctx, 1, env.peers) - require.NoError(t, err) - - err = engine.Broadcast(ctx, 1, []byte("will fail sending")) - require.NoError(t, err) // Broadcast itself succeeds; send failures are events. - - // Collect send failure events. 
- deadline := time.After(2 * time.Second) - failures := 0 -loop: - for failures < 3 { - select { - case ev := <-engine.Events(): - if _, ok := ev.(EventShardSendFailed); ok { - failures++ - } - case <-deadline: - break loop - } - } - assert.Equal(t, 3, failures, "should have 3 send failures (one per non-publisher peer)") - - cancel() -} - -func TestEngine_RegisterChannelTooFewPeers(t *testing.T) { - env := newEngineTestEnv(t, 4) - - ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) - defer cancel() - - engine := env.engines[0] - - go func() { - engine.Run(ctx) //nolint:errcheck // test helper - }() - - // A single peer cannot form a channel (0 shards). - err := engine.RegisterChannel(ctx, 1, []peer.ID{env.peers[0]}) - require.Error(t, err) - - cancel() -} +// type engineTestEnv struct { +// peers []peer.ID +// privKeys []crypto.PrivKey +// engines []*Engine +// sentUnits map[peer.ID][]*Unit +// sentMu sync.Mutex +// log utils.Logger +// } +// +// //nolint:unparam // n is always 4 in current tests but kept for flexibility +// func newEngineTestEnv(t *testing.T, n int) *engineTestEnv { +// t.Helper() +// +// peers := make([]peer.ID, n) +// privKeys := make([]crypto.PrivKey, n) +// for i := range n { +// seed := make([]byte, ed25519.SeedSize) +// seed[0] = byte(i) +// reader := bytes.NewReader(seed) +// priv, pub, err := crypto.GenerateEd25519Key(reader) +// require.NoError(t, err) +// id, err := peer.IDFromPublicKey(pub) +// require.NoError(t, err) +// privKeys[i] = priv +// peers[i] = id +// } +// +// log := utils.NewNopZapLogger() +// +// env := &engineTestEnv{ +// peers: peers, +// privKeys: privKeys, +// sentUnits: make(map[peer.ID][]*Unit), +// log: log, +// } +// +// config := Config{ +// StaleMessageTimeout: 5 * time.Second, +// StreamProtocol: "/propeller/test/0.1.0", +// MaxWireMessageSize: 1 << 20, +// } +// +// engines := make([]*Engine, n) +// for i := range n { +// engines[i] = NewEngine( +// peers[i], privKeys[i], config, +// 
env.makeSendFn(), +// log, +// ) +// } +// env.engines = engines +// +// return env +// } +// +// // makeSendFn creates a SendUnitFunc that records sent units. +// func (env *engineTestEnv) makeSendFn() SendUnitFunc { +// return func(_ context.Context, to peer.ID, unit *Unit) error { +// env.sentMu.Lock() +// env.sentUnits[to] = append(env.sentUnits[to], unit) +// env.sentMu.Unlock() +// return nil +// } +// } +// +// // getSentUnits returns all units sent to a given peer. +// func (env *engineTestEnv) getSentUnits(to peer.ID) []*Unit { +// env.sentMu.Lock() +// defer env.sentMu.Unlock() +// result := make([]*Unit, len(env.sentUnits[to])) +// copy(result, env.sentUnits[to]) +// return result +// } +// +// func TestEngine_RegisterAndBroadcast(t *testing.T) { +// env := newEngineTestEnv(t, 4) +// +// ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second) +// defer cancel() +// +// engine := env.engines[0] +// +// // Run the engine in the background. +// done := make(chan error, 1) +// go func() { +// done <- engine.Run(ctx) +// }() +// +// // Register a channel with all peers. +// err := engine.RegisterChannel(ctx, 1, env.peers) +// require.NoError(t, err) +// +// // Broadcast a message. +// msg := []byte("hello from engine test") +// err = engine.Broadcast(ctx, 1, msg) +// require.NoError(t, err) +// +// // Verify that units were sent to the other 3 peers. +// // Give a moment for async processing. 
+// time.Sleep(100 * time.Millisecond) +// +// totalSent := 0 +// for _, p := range env.peers { +// if p == env.peers[0] { +// continue +// } +// units := env.getSentUnits(p) +// totalSent += len(units) +// } +// assert.Equal(t, 3, totalSent, "should send one unit to each non-publisher peer") +// +// cancel() +// <-done +// } +// +// func TestEngine_BroadcastUnregisteredChannel(t *testing.T) { +// env := newEngineTestEnv(t, 4) +// +// ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) +// defer cancel() +// +// engine := env.engines[0] +// +// go func() { +// engine.Run(ctx) //nolint:errcheck // test helper +// }() +// +// err := engine.Broadcast(ctx, 99, []byte("should fail")) +// require.Error(t, err) +// +// var pubErr *ShardPublishError +// require.ErrorAs(t, err, &pubErr) +// assert.Equal(t, ReasonChannelNotRegistered, pubErr.Reason) +// +// cancel() +// } +// +// func TestEngine_HandleUnit_CreatesProcessor(t *testing.T) { +// env := newEngineTestEnv(t, 4) +// +// ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second) +// defer cancel() +// +// // Set up engine for peer 0. +// engine := env.engines[0] +// +// go func() { +// engine.Run(ctx) //nolint:errcheck // test helper +// }() +// +// // Register the channel. +// err := engine.RegisterChannel(ctx, 1, env.peers) +// require.NoError(t, err) +// +// // Simulate receiving a unit from peer 1 (as publisher). +// schedule := NewScheduler(env.peers) +// enc, err := NewEncoder(schedule.NumDataShards(), schedule.NumCodingShards()) +// require.NoError(t, err) +// +// msg := []byte("incoming message") +// units, root, err := EncodeMessage(msg, schedule, enc) +// require.NoError(t, err) +// +// publisher := env.peers[1] +// sig, err := SignMessage(root, env.privKeys[1]) +// require.NoError(t, err) +// +// for i := range units { +// units[i].Publisher = publisher +// units[i].Signature = sig +// units[i].CommitteeID = 1 +// } +// +// // Send units from their correct senders. 
+// for i, unit := range units { +// sender, err := schedule.PeerForShard(publisher, ShardIndex(i)) +// require.NoError(t, err) +// +// // Skip units "from ourselves" -- the validator rejects those. +// if sender == env.peers[0] { +// continue +// } +// +// unitCopy := unit +// engine.HandleUnit(&unitCopy, sender) +// } +// +// // Wait for the message to be processed and check events. +// var received *EventMessageReceived +// deadline := time.After(5 * time.Second) +// for received == nil { +// select { +// case ev := <-engine.Events(): +// if r, ok := ev.(EventMessageReceived); ok { +// received = &r +// } +// case <-deadline: +// t.Fatal("timed out waiting for EventMessageReceived") +// } +// } +// +// assert.Equal(t, msg, received.Message) +// assert.Equal(t, publisher, received.Publisher) +// assert.Equal(t, root, received.Root) +// +// cancel() +// } +// +// func TestEngine_UnregisterChannel(t *testing.T) { +// env := newEngineTestEnv(t, 4) +// +// ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) +// defer cancel() +// +// engine := env.engines[0] +// +// go func() { +// engine.Run(ctx) //nolint:errcheck // test helper +// }() +// +// err := engine.RegisterChannel(ctx, 1, env.peers) +// require.NoError(t, err) +// +// err = engine.UnregisterChannel(ctx, 1) +// require.NoError(t, err) +// +// // Allow command to be processed. +// time.Sleep(50 * time.Millisecond) +// +// // Broadcasting should fail now. +// err = engine.Broadcast(ctx, 1, []byte("after unregister")) +// require.Error(t, err) +// +// cancel() +// } +// +// func TestEngine_HandleUnit_UnregisteredChannel(t *testing.T) { +// env := newEngineTestEnv(t, 4) +// +// ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) +// defer cancel() +// +// engine := env.engines[0] +// +// go func() { +// engine.Run(ctx) //nolint:errcheck // test helper +// }() +// +// // Send a unit for an unregistered channel. 
+// unit := &Unit{ +// CommitteeID: 99, +// Publisher: env.peers[1], +// MessageRoot: MessageRoot{0x01}, +// ShardIndex: 0, +// ShardData: []byte("data"), +// } +// engine.HandleUnit(unit, env.peers[1]) +// +// // Allow time for processing. +// time.Sleep(100 * time.Millisecond) +// +// // No crash, no panic -- the unit is silently dropped. +// cancel() +// } +// +// func TestEngine_GracefulShutdown(t *testing.T) { +// env := newEngineTestEnv(t, 4) +// +// ctx, cancel := context.WithCancel(t.Context()) +// +// engine := env.engines[0] +// done := make(chan error, 1) +// go func() { +// done <- engine.Run(ctx) +// }() +// +// cancel() +// +// select { +// case err := <-done: +// assert.ErrorIs(t, err, context.Canceled) +// case <-time.After(2 * time.Second): +// t.Fatal("engine did not shut down in time") +// } +// } +// +// func TestEngine_SendFailureEmitsEvent(t *testing.T) { +// env := newEngineTestEnv(t, 4) +// +// ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) +// defer cancel() +// +// // Create an engine with a failing send function. +// engine := NewEngine( +// env.peers[0], env.privKeys[0], +// Config{ +// StaleMessageTimeout: 5 * time.Second, +// StreamProtocol: "/propeller/test/0.1.0", +// MaxWireMessageSize: 1 << 20, +// }, +// func(_ context.Context, _ peer.ID, _ *Unit) error { +// return fmt.Errorf("simulated network failure") +// }, +// utils.NewNopZapLogger(), +// ) +// +// go func() { +// engine.Run(ctx) //nolint:errcheck // test helper +// }() +// +// err := engine.RegisterChannel(ctx, 1, env.peers) +// require.NoError(t, err) +// +// err = engine.Broadcast(ctx, 1, []byte("will fail sending")) +// require.NoError(t, err) // Broadcast itself succeeds; send failures are events. +// +// // Collect send failure events. 
+// deadline := time.After(2 * time.Second) +// failures := 0 +// loop: +// for failures < 3 { +// select { +// case ev := <-engine.Events(): +// if _, ok := ev.(EventShardSendFailed); ok { +// failures++ +// } +// case <-deadline: +// break loop +// } +// } +// assert.Equal(t, 3, failures, "should have 3 send failures (one per non-publisher peer)") +// +// cancel() +// } +// +// func TestEngine_RegisterChannelTooFewPeers(t *testing.T) { +// env := newEngineTestEnv(t, 4) +// +// ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) +// defer cancel() +// +// engine := env.engines[0] +// +// go func() { +// engine.Run(ctx) //nolint:errcheck // test helper +// }() +// +// // A single peer cannot form a channel (0 shards). +// err := engine.RegisterChannel(ctx, 1, []peer.ID{env.peers[0]}) +// require.Error(t, err) +// +// cancel() +// } diff --git a/consensus/propeller/processor.go b/consensus/propeller/processor.go index c707243a0c..78168993d4 100644 --- a/consensus/propeller/processor.go +++ b/consensus/propeller/processor.go @@ -4,9 +4,11 @@ import ( "context" "errors" "fmt" + "math/rand" "sync" "time" + "github.com/NethermindEth/juno/consensus/propeller/timecache" "github.com/NethermindEth/juno/utils" "github.com/libp2p/go-libp2p/core/peer" "go.uber.org/zap" @@ -18,25 +20,29 @@ type unitWithSender struct { } type subprocessor struct { - scheduler *Scheduler - localShardIndex ShardIndex - localShardWasBroadcast bool + scheduler *Scheduler + localPeer peer.ID + localShardIndex ShardIndex unitsChan <-chan unitWithSender invalidUnitsChan chan<- invalidUnit + // todo(rdr): I think I would like it more if it is called UnitValidator since + // is more specfic validator Validator } func newSubprocessor( publisher peer.ID, scheduler *Scheduler, + localPeer peer.ID, localShardIndex ShardIndex, unitsChan <-chan unitWithSender, invalidUnitsChan chan<- invalidUnit, ) subprocessor { return subprocessor{ scheduler: scheduler, + localPeer: localPeer, localShardIndex: 
localShardIndex, unitsChan: unitsChan, @@ -46,29 +52,54 @@ func newSubprocessor( } } +func (s *subprocessor) broadcastUnit(unit *Unit) error { + index := 0 + peers := make([]peer.ID, len(s.scheduler.Peers())-2) + for _, peerCommittee := range s.scheduler.Peers() { + if peerCommittee.ID == unit.Publisher || peerCommittee.ID == s.localPeer { + continue + } + // todo(rdr): index out of range issue in this code + peers[index] = peerCommittee.ID + index += 1 + } + rand.Shuffle(len(peers), func(i, j int) { + peers[i], peers[j] = peers[j], peers[i] + }) + + // todo(rdr): This should forward the unit and the peers that require broadcasting + panic("not implemented") +} + func (s *subprocessor) beforeMessageBuiltStage(ctx context.Context) ( - unitsReceived []*Unit, - unitCount int, - message []byte, - err error, + int, []byte, error, ) { // Keep track of the units received - unitsReceived = make([]*Unit, s.scheduler.ReceiveThreshold()) - unitCount = 0 + unitsReceived := make([]*Unit, s.scheduler.NumTotalShards()) + unitCount := 0 localShardWasBroadcast := false - buildThreshold := s.scheduler.BuildThreshold() - for unitCount != buildThreshold { + // todo(rdr): we are triggering message building (expensive) as soon as the bulid threshold is + // achieved, but it might be convenient to wait a few seconds to see if more messages + // will arrive. Although, that will mean we also need to validate any of those extra messages. + // The question is then: Do the cost of validating missing messages reduces greatly the cost + // of recovering them? Cases to consider: + // - Perfect network condition: a lot of bandwith and everybody is good. Does receiving all + // all the missing messages and validating them is cheaper than recovering them? What's the + // performance difference? 
<- Write benchmark + // - Bad network conditions: does the time waiting but receiving no messages will + // cause to waste a few seconds were the build was already done + // - Bad messages: the remaining messages we are waiting for and hence we incur on the cost + // of validating them but we get no benefit and we don't reduce the cost of recovering them. + for unitCount != s.scheduler.BuildThreshold() { select { case <-ctx.Done(): - return + return 0, nil, ctx.Err() case unitWithSender := <-s.unitsChan: unit := unitWithSender.unit sender := unitWithSender.sender - - err = s.validator.ValidateUnit(unit, sender) - if err != nil { + if err := s.validator.ValidateUnit(unit, sender); err != nil { s.invalidUnitsChan <- invalidUnit{ // todo(rdr): not sure if we need message key. // We just want to penalize the sender @@ -79,7 +110,7 @@ func (s *subprocessor) beforeMessageBuiltStage(ctx context.Context) ( // if this is the first unit we are receiving, finish abruptly since // it can be a DOS attack. if unitCount == 0 { - return + return 0, nil, fmt.Errorf("couldn't validate first unit received: %w", err) } continue } @@ -88,25 +119,60 @@ func (s *subprocessor) beforeMessageBuiltStage(ctx context.Context) ( unitCount += 1 // broadcast as soon as I get my shard - if localShardWasBroadcast && s.localShardIndex == unit.ShardIndex { + if !localShardWasBroadcast && s.localShardIndex == unit.ShardIndex { localShardWasBroadcast = true - // todo(rdr): actually broadcast shard index + err := s.broadcastUnit(unit) + if err != nil { + // todo(rdr): tbd if we need an error here + panic(err) + } } } } - // perform the build thing - panic("not implemented") + fullMessage, localShardData, localProof, err := ConstructMessageFromUnits( + unitsReceived, + s.localShardIndex, + s.scheduler.NumDataShards(), + s.scheduler.NumCodingShards(), + ) + if err != nil { + return 0, nil, err + } + + if !localShardWasBroadcast { + // We pick a unit at random to fill the common data between the two. 
All of these values + // have already been verified up top. + // todo(rdr): there is an issue where unit in 0 is not guaranteed to be non-nil + unit := unitsReceived[0] + localUnit := Unit{ + CommitteeID: unit.CommitteeID, + Publisher: unit.Publisher, + MessageRoot: unit.MessageRoot, + Nonce: unit.Nonce, + Signature: unit.Signature, + MerkleProof: localProof, + ShardIndex: s.localShardIndex, + ShardData: localShardData, + } + err := s.broadcastUnit(&localUnit) + if err != nil { + // todo(rdr): tbd if we need an error here + panic(err) + } + unitCount += 1 + } + + return unitCount, fullMessage, nil } func (s *subprocessor) beforeMessageReceivedStage( ctx context.Context, - unitsReceived []*Unit, unitCount int, message []byte, ) error { - receivedThreshold := s.scheduler.ReceiveThreshold() - for unitCount != receivedThreshold { + receiveThreshold := s.scheduler.ReceiveThreshold() + for unitCount != receiveThreshold { select { case <-ctx.Done(): return ctx.Err() @@ -121,14 +187,17 @@ func (s *subprocessor) beforeMessageReceivedStage( } continue } - - unitsReceived[int(unit.ShardIndex)] = unit + if unit.ShardIndex == s.localShardIndex { + continue + } unitCount += 1 } } - // do the actual job that requires doing once the receive threshold is reached - panic("not implemented") + // todo(rdr): if we are here it means the message has been received. + // forward it to (proc/engine/service <- one of these) + + return nil } // todo(rdr): we need to be sure to test both cases: @@ -138,19 +207,19 @@ func (s *subprocessor) Run( ctx context.Context, ) error { // The Run function works in two main loops depending on the stage we are in. - // First stage is before we can build the message, where which we receive messsages + // First stage is before we can build the message, in which we receive messsages // until we have enough to build the full messsage. The local shard will be broadcasted // during this stage. 
// Second stage starts with the full message built and waits until we receive enough - // messages to reach the received threshold, which guarantees that at leasrt 2/3 of the - // network is non faulty. This stages broadcasts the whole message once finished + // messages to reach the received threshold, which guarantees that at least 2/3 of the + // network is non faulty. Once there, Broadcast the rebuilt message and finishes - unitsReceived, unitCount, message, err := s.beforeMessageBuiltStage(ctx) + unitCount, message, err := s.beforeMessageBuiltStage(ctx) if err != nil { return err } - return s.beforeMessageReceivedStage(ctx, unitsReceived, unitCount, message) + return s.beforeMessageReceivedStage(ctx, unitCount, message) } // messageKey are a copy of the values of a propeller unit that uniquely identifies it @@ -198,9 +267,9 @@ type concurrentTasksBounds struct { // Processor handles all concurrent work on message processing type Processor struct { // to avoid processing units already finalized - finalized *TimeCache[messageKey] + finalized *timecache.TimeCache[messageKey] - subProcessors map[messageKey]chan unitWithSender + subProcessors map[messageKey]chan<- unitWithSender // channel through wich subprocessors signal they have finalized execution subProcessorsFinalized chan finalizedSubprocessor // channel through which subprocessor sharedunits that failed validation @@ -221,9 +290,9 @@ func NewProcessor(localPeer peer.ID, config *Config) *Processor { timeout := config.StaleMessageTimeout return &Processor{ - finalized: NewTimeCache[messageKey](timeout), + finalized: timecache.New[messageKey](2048, timeout), - subProcessors: make(map[messageKey]chan unitWithSender), + subProcessors: make(map[messageKey]chan<- unitWithSender), subProcessorsFinalized: make(chan finalizedSubprocessor), invalidUnits: make(chan invalidUnit), @@ -280,7 +349,7 @@ func (p *Processor) ProcessMessage( scheduler *Scheduler, ) error { key := extractKey(unit) - if 
p.finalized.Contains(key) { + if p.finalized.Get(&key) { return nil } @@ -316,7 +385,7 @@ func (p *Processor) createSubprocessor( ctx context.Context, key *messageKey, scheduler *Scheduler, -) (chan unitWithSender, error) { +) (chan<- unitWithSender, error) { localShardIndex, err := scheduler.ShardIndexForPublisher(key.Publisher) if err != nil { return nil, fmt.Errorf( @@ -347,9 +416,13 @@ func (p *Processor) createSubprocessor( unitChan <-chan unitWithSender, ) { subProcessor := newSubprocessor( - key.Publisher, scheduler, localShardIndex, unitChan, p.invalidUnits, + key.Publisher, scheduler, p.localPeer, localShardIndex, unitChan, p.invalidUnits, ) - subProcessor.Run(ctx) + err := subProcessor.Run(ctx) + p.subProcessorsFinalized <- finalizedSubprocessor{ + messageKey: key, + error: err, + } }(ctxWithTimeout, *key, scheduler, localShardIndex, unitChan) return unitChan, nil @@ -361,7 +434,7 @@ func (p *Processor) subprocessorChannel( ctx context.Context, key *messageKey, scheduler *Scheduler, -) (chan unitWithSender, error) { +) (chan<- unitWithSender, error) { unitChan, ok := p.subProcessors[*key] if ok { return unitChan, nil @@ -377,7 +450,7 @@ func (p *Processor) subprocessorChannel( func (p *Processor) finalize(key *messageKey) { p.decreaseTask(key.Publisher) delete(p.subProcessors, *key) - p.finalized.Add(*key) + p.finalized.Add(key) } func (p *Processor) increaseTasks(publisher peer.ID) error { diff --git a/consensus/propeller/propeller_test.go b/consensus/propeller/propeller_test.go index e69de29bb2..bb61865bd4 100644 --- a/consensus/propeller/propeller_test.go +++ b/consensus/propeller/propeller_test.go @@ -0,0 +1 @@ +package propeller_test diff --git a/consensus/propeller/reedsolomon/reedsolomon.go b/consensus/propeller/reedsolomon/reedsolomon.go index 0ed3073649..78a7b9da77 100644 --- a/consensus/propeller/reedsolomon/reedsolomon.go +++ b/consensus/propeller/reedsolomon/reedsolomon.go @@ -39,10 +39,10 @@ func EncodeData( return split, nil } -// 
RecoverData restores the missing data using Reed-Solomon erasure codes. There cannot be more than -// `parity` shards missing otherwise the recover will fail. Data that is considered missing needs to -// be marked as `nil`. Returns the recovered data. -// The data will be modified in place so the input shouldn't be modified after calling this function. +// RecoverData restores the missing data using Reed-Solomon erasure codes. +// There cannot be more than `parity` shards missing otherwise the recover will fail. +// Data that is considered missing needs to be marked as `nil`. Returns the recovered data. +// The input data shards well be modified in place. func RecoverData( shards [][]byte, numDataShards, diff --git a/consensus/propeller/scheduler.go b/consensus/propeller/scheduler.go index c7840c4a58..76333908d5 100644 --- a/consensus/propeller/scheduler.go +++ b/consensus/propeller/scheduler.go @@ -73,13 +73,13 @@ func NewScheduler( }, ) if !exists { - return nil, errors.New("the local peer id is not part of the suplied list of peeers") + return nil, errors.New("the local peer id is not part of the supplied list of peeers") } // check that there is no duplicated ID in the node list for i := range len(nodes) - 1 { if nodes[i].ID == nodes[i+1].ID { - return nil, fmt.Errorf("duplicated ids in the suplied list of peers: %s", nodes[i].ID) + return nil, fmt.Errorf("duplicated ids in the supplied list of peers: %s", nodes[i].ID) } } @@ -118,9 +118,7 @@ func (s *Scheduler) NumCodingShards() int { return s.numCodingShards } func (s *Scheduler) NumTotalShards() int { return s.numDataShards + s.numCodingShards } // Minimum (inclusive) amount of shards required to build a message -func (s *Scheduler) BuildThreshold() int { - return s.numDataShards -} +func (s *Scheduler) BuildThreshold() int { return s.numDataShards } // Minimum (inclusive) amount of shards required to guarantee a message is received func (s *Scheduler) ReceiveThreshold() int { @@ -190,7 +188,7 @@ func (s 
*Scheduler) ShardIndexForPublisher( ) (ShardIndex, error) { if s.localPeerID == publisher { return 0, fmt.Errorf( - "scheduler peer is the same as the publisher and has no assinged shard: %s", + "scheduler peer is the same as the publisher and has no assigned shard: %s", publisher, ) } @@ -210,17 +208,18 @@ func (s *Scheduler) ShardIndexForPublisher( // ValidateShardOrigin verifies that a shard unit was received from the expected sender. // The sender has to be either the publisher for direct shards or a designated -// broadcasted for the given shard index. +// broadcaster for the given shard index. +// todo(rdr): Maybe the unit validator should have this implementation func (s *Scheduler) ValidateShardOrigin( sender peer.ID, publisher peer.ID, shardIndex ShardIndex, ) error { if sender == s.localPeerID { - return fmt.Errorf("scheduler sent itself a shard: %s", sender) + return fmt.Errorf("self sending message from %s", sender) } if publisher == s.localPeerID { - return fmt.Errorf("scheduler broadcast itself a shard: %s", publisher) + return fmt.Errorf("self published shard was sent back by %s", sender) } expectedBroadcaster, err := s.PeerForShardIndex(publisher, shardIndex) @@ -250,18 +249,15 @@ func (s *Scheduler) ValidateShardOrigin( ) } -// BroadcastTargets returns all peers the Schudler's peer needs to braodcast to, -// in shard-index order. The i-th element of the returned slice is the peer responsible for -// shard i. +// BroadcastTargets returns all peers whom to broadcast to, in shard-index order. +// The i-th element of the returned slice is the peer responsible for shard i. 
func (s *Scheduler) BroadcastTargets() []peer.ID { - targets := make([]peer.ID, s.NumTotalShards()-1) - i := 0 - for _, p := range s.peers { + targets := make([]peer.ID, 0, s.NumTotalShards()) + for i, p := range s.peers { if i == s.localPeerIDIndex { continue } - targets[i] = p.ID - i += 1 + targets = append(targets, p.ID) } return targets } diff --git a/consensus/propeller/scheduler_test.go b/consensus/propeller/scheduler_test.go index d8153483fd..e849fc7e4e 100644 --- a/consensus/propeller/scheduler_test.go +++ b/consensus/propeller/scheduler_test.go @@ -1,10 +1,6 @@ -// todo(rdr): make it propeller_test package propeller import ( - "cmp" - "math/rand" - "slices" "testing" "github.com/libp2p/go-libp2p/core/peer" @@ -12,89 +8,116 @@ import ( "github.com/stretchr/testify/require" ) -// testPeers creates N deterministic peer IDs that sort in alphabetical order. -// Also returns a local peer ID choosen at random fromt the list -func testPeers(t *testing.T, names ...string) (peer.ID, []PeerCommittee) { +// testPeers creates PeerCommittee entries from the given names. +// Each test chooses its own local peer explicitly. +func testPeers(t *testing.T, names ...string) []PeerCommittee { t.Helper() - peers := make([]PeerCommittee, len(names)) for i, n := range names { - peers[i] = PeerCommittee{ - ID: peer.ID(n), - Stake: Stake(rand.Uint32()), - } + peers[i] = PeerCommittee{ID: peer.ID(n), Stake: 1} } - // note(rdr): should we make the random generation deterministic? 
- localPeer := peers[rand.Int()%len(peers)].ID + return peers +} + +func TestScheduler_NewScheduler_Validation(t *testing.T) { + t.Run("fewer than 2 peers", func(t *testing.T) { + peers := testPeers(t, "A") + _, err := NewScheduler(peer.ID("A"), peers) + assert.Error(t, err) + }) + + t.Run("local peer not in list", func(t *testing.T) { + peers := testPeers(t, "A", "B", "C") + _, err := NewScheduler(peer.ID("Z"), peers) + assert.Error(t, err) + }) + + t.Run("duplicate peers", func(t *testing.T) { + peers := testPeers(t, "A", "B", "B") + _, err := NewScheduler(peer.ID("A"), peers) + assert.Error(t, err) + }) - return localPeer, peers + t.Run("valid construction", func(t *testing.T) { + peers := testPeers(t, "A", "B", "C") + s, err := NewScheduler(peer.ID("B"), peers) + require.NoError(t, err) + assert.Equal(t, peer.ID("B"), s.PeerID()) + }) } -func TestSchedule_Thresholds(t *testing.T) { +func TestScheduler_ShardCounts(t *testing.T) { tests := []struct { name string n int numDataShards int numCodingShards int - numShards int + numTotalShards int buildThreshold int receiveThreshold int }{ { - name: "N=1 (solo node, no shards)", - n: 1, - numDataShards: 0, - numCodingShards: 0, - numShards: 0, - }, - { - name: "N=2", - n: 2, - numDataShards: 1, // max of 1 and (1/3) yields 1 - numCodingShards: 0, // 1 minus 1 yields 0 - numShards: 1, + name: "N=2", + n: 2, + numDataShards: 1, + numCodingShards: 0, + numTotalShards: 1, + buildThreshold: 1, + receiveThreshold: 1, // len<=3, falls back to buildThreshold }, { - name: "N=3", - n: 3, - numDataShards: 1, // max of 1 and (2/3) yields 1 - numCodingShards: 1, // 2 minus 1 yields 1 - numShards: 2, + name: "N=3", + n: 3, + numDataShards: 1, + numCodingShards: 1, + numTotalShards: 2, + buildThreshold: 1, + receiveThreshold: 1, // len<=3, falls back to buildThreshold }, { - name: "N=4", - n: 4, - numDataShards: 1, // 3/3 = 1 - numCodingShards: 2, // 3 - 1 = 2 - numShards: 3, + name: "N=4", + n: 4, + numDataShards: 1, + 
numCodingShards: 2, + numTotalShards: 3, + buildThreshold: 1, + receiveThreshold: 2, }, { - name: "N=5", - n: 5, - numDataShards: 1, // 4/3 = 1 - numCodingShards: 3, // 4 - 1 = 3 - numShards: 4, + name: "N=5", + n: 5, + numDataShards: 1, + numCodingShards: 3, + numTotalShards: 4, + buildThreshold: 1, + receiveThreshold: 2, }, { - name: "N=7", - n: 7, - numDataShards: 2, // 6/3 = 2 - numCodingShards: 4, // 6 - 2 = 4 - numShards: 6, + name: "N=7", + n: 7, + numDataShards: 2, + numCodingShards: 4, + numTotalShards: 6, + buildThreshold: 2, + receiveThreshold: 4, }, { - name: "N=10", - n: 10, - numDataShards: 3, // 9/3 = 3 - numCodingShards: 6, // 9 - 3 = 6 - numShards: 9, + name: "N=10", + n: 10, + numDataShards: 3, + numCodingShards: 6, + numTotalShards: 9, + buildThreshold: 3, + receiveThreshold: 6, }, { - name: "N=31", - n: 31, - numDataShards: 10, // 30/3 = 10 - numCodingShards: 20, // 30 - 10 = 20 - numShards: 30, + name: "N=31", + n: 31, + numDataShards: 10, + numCodingShards: 20, + numTotalShards: 30, + buildThreshold: 10, + receiveThreshold: 20, }, } @@ -104,183 +127,217 @@ func TestSchedule_Thresholds(t *testing.T) { for i := range tc.n { names[i] = string(rune('A' + i)) } - localPeer, peers := testPeers(t, names...) + peers := testPeers(t, names...) - s, err := NewScheduler(localPeer, peers) + s, err := NewScheduler(peers[0].ID, peers) require.NoError(t, err) - assert.Equal(t, tc.numDataShards, s.DataShards()) - assert.Equal(t, tc.numCodingShards, s.CodingShards()) - assert.Equal(t, tc.numShards, s.NumShards()) + assert.Equal(t, tc.numDataShards, s.NumDataShards()) + assert.Equal(t, tc.numCodingShards, s.NumCodingShards()) + assert.Equal(t, tc.numTotalShards, s.NumTotalShards()) + assert.Equal(t, tc.buildThreshold, s.BuildThreshold()) + assert.Equal(t, tc.receiveThreshold, s.ReceiveThreshold()) }) } } -func TestSchedule_Sorting(t *testing.T) { - // Peers provided out of order should be sorted. 
- localPeer, peers := testPeers(t, "D", "B", "A", "C") - - sortedPeers := make([]PeerCommittee, 0, len(peers)) - copy(sortedPeers, peers) - slices.SortFunc(sortedPeers, func(a, b PeerCommittee) int { - return cmp.Compare(a.ID, b.ID) - }) +func TestScheduler_DeterministicMapping(t *testing.T) { + // Two schedulers built from differently-ordered peer lists must + // produce identical shard-to-peer mappings for every publisher. + s1, err := NewScheduler(peer.ID("A"), testPeers(t, "A", "B", "C", "D")) + require.NoError(t, err) - s, err := NewScheduler(localPeer, peers) + s2, err := NewScheduler(peer.ID("A"), testPeers(t, "D", "B", "A", "C")) require.NoError(t, err) - assert.Equal(t, sortedPeers, s.Peers()) + for _, pub := range s1.Peers() { + for idx := range s1.NumTotalShards() { + p1, err1 := s1.PeerForShardIndex(pub.ID, ShardIndex(idx)) + p2, err2 := s2.PeerForShardIndex(pub.ID, ShardIndex(idx)) + require.NoError(t, err1) + require.NoError(t, err2) + assert.Equal(t, p1, p2, "publisher=%s shard=%d", pub.ID, idx) + } + } } -func TestSchedule_PeerForShard_SpecExample(t *testing.T) { - // From the specification: peers [A, B, C, D], publisher = C (index 2). +func TestScheduler_PeerForShardIndex_SpecExample(t *testing.T) { + // From the doc comment: peers [A, B, C, D], publisher = C (index 2). 
// Shard 0 -> A, Shard 1 -> B, Shard 2 -> D - localPeer, peers := testPeers(t, "A", "B", "C", "D") - - s, err := NewScheduler(localPeer, peers) + s, err := NewScheduler(peer.ID("A"), testPeers(t, "A", "B", "C", "D")) require.NoError(t, err) publisher := peer.ID("C") - - tests := []struct { - shardIndex ShardIndex - expected peer.ID - }{ - {0, peer.ID("A")}, - {1, peer.ID("B")}, - {2, peer.ID("D")}, - } - - for _, tc := range tests { - got, err := s.PeerForShard(publisher, tc.shardIndex) + expected := []peer.ID{"A", "B", "D"} + for i, want := range expected { + got, err := s.PeerForShardIndex(publisher, ShardIndex(i)) require.NoError(t, err) - assert.Equal(t, tc.expected, got, "shard %d", tc.shardIndex) + assert.Equal(t, want, got, "shard %d", i) } } -func TestSchedule_PeerForShard_PublisherFirst(t *testing.T) { +func TestScheduler_PeerForShardIndex_PublisherFirst(t *testing.T) { // Publisher is the first peer in sorted order. - peers := testPeers("A", "B", "C", "D") - s := NewScheduler(peers) - publisher := peer.ID("A") + s, err := NewScheduler(peer.ID("B"), testPeers(t, "A", "B", "C", "D")) + require.NoError(t, err) + publisher := peer.ID("A") // Shard 0 -> B, Shard 1 -> C, Shard 2 -> D - expected := testPeers("B", "C", "D") - for i, exp := range expected { - got, err := s.PeerForShard(publisher, ShardIndex(i)) + expected := []peer.ID{"B", "C", "D"} + for i, want := range expected { + got, err := s.PeerForShardIndex(publisher, ShardIndex(i)) require.NoError(t, err) - assert.Equal(t, exp, got) + assert.Equal(t, want, got, "shard %d", i) } } -func TestSchedule_PeerForShard_PublisherLast(t *testing.T) { +func TestScheduler_PeerForShardIndex_PublisherLast(t *testing.T) { // Publisher is the last peer in sorted order. 
- peers := testPeers("A", "B", "C", "D") - s := NewScheduler(peers) - publisher := peer.ID("D") + s, err := NewScheduler(peer.ID("A"), testPeers(t, "A", "B", "C", "D")) + require.NoError(t, err) + publisher := peer.ID("D") // Shard 0 -> A, Shard 1 -> B, Shard 2 -> C - expected := testPeers("A", "B", "C") - for i, exp := range expected { - got, err := s.PeerForShard(publisher, ShardIndex(i)) + expected := []peer.ID{"A", "B", "C"} + for i, want := range expected { + got, err := s.PeerForShardIndex(publisher, ShardIndex(i)) require.NoError(t, err) - assert.Equal(t, exp, got) + assert.Equal(t, want, got, "shard %d", i) } } -func TestSchedule_PeerForShard_Errors(t *testing.T) { - peers := testPeers("A", "B", "C") - s := NewScheduler(peers) +func TestScheduler_PeerForShardIndex_Errors(t *testing.T) { + s, err := NewScheduler(peer.ID("A"), testPeers(t, "A", "B", "C")) + require.NoError(t, err) - // Publisher not in list. - _, err := s.PeerForShard(peer.ID("Z"), 0) - assert.Error(t, err) - assert.Contains(t, err.Error(), "not found") + t.Run("publisher not in list", func(t *testing.T) { + _, err := s.PeerForShardIndex(peer.ID("Z"), 0) + assert.Error(t, err) + }) - // Shard index out of range. - _, err = s.PeerForShard(peer.ID("A"), ShardIndex(s.NumShards())) - assert.Error(t, err) - assert.Contains(t, err.Error(), "out of range") + t.Run("shard index out of range", func(t *testing.T) { + _, err := s.PeerForShardIndex(peer.ID("A"), ShardIndex(s.NumTotalShards())) + assert.Error(t, err) + }) } -func TestSchedule_ShardForPeer_SpecExample(t *testing.T) { - // Inverse of PeerForShard: peers [A,B,C,D], publisher=C. - // A -> shard 0, B -> shard 1, D -> shard 2 - peers := testPeers("A", "B", "C", "D") - s := NewScheduler(peers) - publisher := peer.ID("C") - +func TestScheduler_ShardIndexForPublisher(t *testing.T) { + // peers [A, B, C, D], publisher=C. 
+ // A(idx 0) -> shard 0, B(idx 1) -> shard 1, D(idx 3) -> shard 2 tests := []struct { - localPeer peer.ID - expected ShardIndex + localPeer string + expectedShard ShardIndex }{ - {peer.ID("A"), 0}, - {peer.ID("B"), 1}, - {peer.ID("D"), 2}, + {"A", 0}, + {"B", 1}, + {"D", 2}, } + publisher := peer.ID("C") for _, tc := range tests { - got, err := s.ShardForPeer(publisher, tc.localPeer) - require.NoError(t, err) - assert.Equal(t, tc.expected, got, "peer %s", tc.localPeer) + t.Run("local="+tc.localPeer, func(t *testing.T) { + s, err := NewScheduler( + peer.ID(tc.localPeer), testPeers(t, "A", "B", "C", "D"), + ) + require.NoError(t, err) + + got, err := s.ShardIndexForPublisher(publisher) + require.NoError(t, err) + assert.Equal(t, tc.expectedShard, got) + }) } } -func TestSchedule_ShardForPeer_PublisherError(t *testing.T) { - peers := testPeers("A", "B", "C") - s := NewScheduler(peers) - - // The publisher itself has no assigned shard. - _, err := s.ShardForPeer(peer.ID("B"), peer.ID("B")) - assert.Error(t, err) - assert.Contains(t, err.Error(), "is the publisher") -} +func TestScheduler_ShardIndexForPublisher_Errors(t *testing.T) { + s, err := NewScheduler(peer.ID("B"), testPeers(t, "A", "B", "C")) + require.NoError(t, err) -func TestSchedule_ShardForPeer_NotFound(t *testing.T) { - peers := testPeers("A", "B", "C") - s := NewScheduler(peers) + t.Run("local peer is the publisher", func(t *testing.T) { + _, err := s.ShardIndexForPublisher(peer.ID("B")) + assert.Error(t, err) + }) - _, err := s.ShardForPeer(peer.ID("A"), peer.ID("Z")) - assert.Error(t, err) - assert.Contains(t, err.Error(), "not found") + t.Run("publisher not in list", func(t *testing.T) { + _, err := s.ShardIndexForPublisher(peer.ID("Z")) + assert.Error(t, err) + }) } -func TestSchedule_PeerForShardAndShardForPeer_AreInverses(t *testing.T) { - // For every publisher, verify that PeerForShard and ShardForPeer - // are consistent inverses. 
- peers := testPeers("A", "B", "C", "D", "E") - s := NewScheduler(peers) +func TestScheduler_InverseProperty(t *testing.T) { + // For every local peer and every publisher, verify that + // PeerForShardIndex and ShardIndexForPublisher are inverses. + names := []string{"A", "B", "C", "D", "E"} + + for _, local := range names { + s, err := NewScheduler(peer.ID(local), testPeers(t, names...)) + require.NoError(t, err) - for _, publisher := range s.Peers() { - for shardIdx := range s.NumShards() { - p, err := s.PeerForShard(publisher, ShardIndex(shardIdx)) + for _, pub := range names { + if pub == local { + continue + } + // ShardIndexForPublisher -> PeerForShardIndex should round-trip + shardIdx, err := s.ShardIndexForPublisher(peer.ID(pub)) require.NoError(t, err) - // The reverse: given that peer, find its shard index. - gotShard, err := s.ShardForPeer(publisher, p) + gotPeer, err := s.PeerForShardIndex(peer.ID(pub), shardIdx) require.NoError(t, err) - assert.Equal(t, ShardIndex(shardIdx), gotShard, - "publisher=%s, shard=%d, peer=%s", publisher, shardIdx, p) + assert.Equal(t, peer.ID(local), gotPeer, + "local=%s publisher=%s shard=%d", local, pub, shardIdx) } } } -func TestSchedule_BroadcastTargets(t *testing.T) { - peers := testPeers("A", "B", "C", "D") - s := NewScheduler(peers) - - targets, err := s.BroadcastTargets(peer.ID("C")) +func TestScheduler_BroadcastTargets(t *testing.T) { + s, err := NewScheduler(peer.ID("C"), testPeers(t, "A", "B", "C", "D")) require.NoError(t, err) - // Should be all peers except C, in shard order. - expected := testPeers("A", "B", "D") - assert.Equal(t, expected, targets) + targets := s.BroadcastTargets() + + // Should contain all peers except the local peer. 
+ assert.Len(t, targets, s.NumTotalShards()) + assert.Equal(t, []peer.ID{"A", "B", "D"}, targets) } -func TestSchedule_BroadcastTargets_PublisherNotFound(t *testing.T) { - peers := testPeers("A", "B") - s := NewScheduler(peers) +func TestScheduler_ValidateShardOrigin(t *testing.T) { + // Setup: peers [A, B, C, D], local = C + // For publisher A: shard 0 -> A(skip, publisher), shard 1 -> B, shard 2 -> D + // Wait — publisher A is at index 0, so: + // shard 0 -> B (peers[1]), shard 1 -> C (peers[2]), shard 2 -> D (peers[3]) + // So local=C is responsible for shard 1 when publisher=A. + s, err := NewScheduler(peer.ID("C"), testPeers(t, "A", "B", "C", "D")) + require.NoError(t, err) + + publisher := peer.ID("A") + + t.Run("valid direct shard from publisher", func(t *testing.T) { + // Shard 1's designated broadcaster is C (local peer). + // A direct shard means publisher sends to the designated broadcaster. + err := s.ValidateShardOrigin(publisher, publisher, 1) + assert.NoError(t, err) + }) + + t.Run("valid broadcast shard from designated peer", func(t *testing.T) { + // Shard 0's designated broadcaster is B. + // B broadcasts shard 0 to other peers including C. + err := s.ValidateShardOrigin(peer.ID("B"), publisher, 0) + assert.NoError(t, err) + }) + + t.Run("self-send rejected", func(t *testing.T) { + err := s.ValidateShardOrigin(peer.ID("C"), publisher, 0) + assert.Error(t, err) + }) + + t.Run("self-published shard sent back rejected", func(t *testing.T) { + // Local peer is both the publisher and the receiver — should error + err := s.ValidateShardOrigin(peer.ID("A"), peer.ID("C"), 0) + assert.Error(t, err) + }) - _, err := s.BroadcastTargets(peer.ID("Z")) - assert.Error(t, err) + t.Run("wrong sender rejected", func(t *testing.T) { + // Shard 0's designated broadcaster is B, but D sends it. 
+ err := s.ValidateShardOrigin(peer.ID("D"), publisher, 0) + assert.Error(t, err) + }) } diff --git a/consensus/propeller/sharding.go b/consensus/propeller/sharding.go index 3760fae5b2..fa3d5f3c4a 100644 --- a/consensus/propeller/sharding.go +++ b/consensus/propeller/sharding.go @@ -6,7 +6,6 @@ import ( "github.com/NethermindEth/juno/consensus/propeller/merkle" "github.com/NethermindEth/juno/consensus/propeller/reedsolomon" - "github.com/NethermindEth/juno/consensus/propeller/utils" "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/peer" ) @@ -14,9 +13,10 @@ import ( // CreatePropellerUnits creates the PropellerUnits for publishing // todo(rdr): maybe call it create message for sharing or somth like that func CreatePropellerUnits( - committeeID CommitteeID, - message []byte, privKey crypto.PrivKey, + committeeID *CommitteeID, + nonce Nonce, + message []byte, numDataShards, parity int, ) ([]Unit, error) { @@ -25,7 +25,7 @@ func CreatePropellerUnits( return nil, fmt.Errorf("getting publisher id from private key: %w", publisherID) } - paddedMessage := utils.PadMessage(message, numDataShards) + paddedMessage := PadMessage(message, numDataShards) encodedMessage, err := reedsolomon.EncodeData(paddedMessage, numDataShards, parity) if err != nil { return nil, fmt.Errorf("encoding the message: %w", err) @@ -34,8 +34,7 @@ func CreatePropellerUnits( merkleRoot, merkleTree := merkle.New(encodedMessage) messageRoot := MessageRoot(merkleRoot) - // todo(rdr): check that this signing is correct - signature, err := utils.SignRoot(messageRoot, privKey) + signature, err := SignMessage(privKey, &messageRoot, committeeID, nonce) if err != nil { return nil, err } @@ -45,35 +44,39 @@ func CreatePropellerUnits( merkleProof := merkleTree[i] units[i] = Unit{ - CommitteeID: committeeID, + CommitteeID: *committeeID, Publisher: publisherID, MessageRoot: messageRoot, MerkleProof: merkleProof, Signature: signature, ShardIndex: ShardIndex(i), - ShardData: shard, + // 
todo(rdr): assigning one shard per unit until multi shard algo per unit + // is clear to me + ShardData: []Shard{shard}, } } return units, nil } -// DecodePropellerUnit receives Propeller units, recovers any missing data and returns +// ConstructMessageFromUnits receives Propeller units, recovers any missing data and returns // the fully verified message, together with the corresponding shard data and merkle proof. -// todo(rdr): maybe call it decode received message -func DecodePropellerUnit( - units []Unit, - messageRoot MessageRoot, +func ConstructMessageFromUnits( + units []*Unit, localShardIndex ShardIndex, numDataShards int, parity int, -) ([]byte, []byte, merkle.Proof, error) { +) ([]byte, ShardData, merkle.Proof, error) { if len(units) == 0 { return nil, nil, merkle.Proof{}, errors.New("no propeller units to decode") } shards := make([][]byte, len(units)) for i := range shards { - shards[i] = units[i].ShardData + if units[i] != nil { + // todo(rdr): we are assuming that every unit only carries one shard data for now + // Not sure how the matrix is built when unit carries more than one + shards[i] = units[i].ShardData[0] + } } shards, err := reedsolomon.RecoverData(shards, numDataShards, parity) @@ -94,6 +97,7 @@ func DecodePropellerUnit( merkleRoot, merkleTree := merkle.New(shards) + messageRoot := units[0].MessageRoot expectedRoot := MessageRoot(merkleRoot) if messageRoot != expectedRoot { // todo(rdr): probably need to write string methods for the MessageRoot type @@ -108,9 +112,16 @@ func DecodePropellerUnit( for i := range shards { copy(paddedMessage[i*shardSize:], shards[i]) } + message, err := UnpadMessage(paddedMessage) + if err != nil { + return nil, nil, merkle.Proof{}, fmt.Errorf("unpadding reconstructed message: %w", err) + } - localShard := shards[localShardIndex] + // todo(rdr): only one for now, but there can be more.TBD how that works + localShard := []Shard{ + shards[localShardIndex], + } localProof := merkleTree[localShardIndex] - return 
paddedMessage, localShard, localProof, nil + return message, localShard, localProof, nil } diff --git a/consensus/propeller/sharding_test.go b/consensus/propeller/sharding_test.go index 8c908dfe41..108f138015 100644 --- a/consensus/propeller/sharding_test.go +++ b/consensus/propeller/sharding_test.go @@ -1,204 +1,204 @@ package propeller -import ( - "testing" - - "github.com/libp2p/go-libp2p/core/peer" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// makeSchedule is a test helper that creates a schedule from N single-char peers. -func makeSchedule(n int) *Scheduler { - names := make([]peer.ID, n) - for i := range n { - names[i] = peer.ID(string(rune('A' + i))) - } - return NewScheduler(names) -} - -func TestEncodeMessage_RoundTrip(t *testing.T) { - tests := []struct { - name string - n int - msgLen int - }{ - {"4 peers, short message", 4, 10}, - {"4 peers, medium message", 4, 500}, - {"7 peers, short message", 7, 20}, - {"10 peers, 1KB message", 10, 1024}, - {"2 peers, tiny message", 2, 1}, - {"3 peers, empty message", 3, 0}, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - schedule := makeSchedule(tc.n) - if schedule.NumShards() == 0 { - t.Skip("no shards for single peer") - } - - enc, err := NewEncoder( - schedule.NumDataShards(), schedule.NumCodingShards(), - ) - require.NoError(t, err) - - msg := make([]byte, tc.msgLen) - for i := range msg { - msg[i] = byte(i) - } - - units, root, err := EncodeMessage(msg, schedule, enc) - require.NoError(t, err) - assert.Len(t, units, schedule.NumShards()) - - // All units should reference the same root. - for _, u := range units { - assert.Equal(t, root, u.MerkleRoot) - } - - // Reconstruct from all shards. 
- shards := make([][]byte, schedule.NumShards()) - for _, u := range units { - shards[u.ShardIndex] = u.ShardData - } - - recovered, err := ReconstructMessage( - shards, schedule, enc, root, - ) - require.NoError(t, err) - assert.Equal(t, msg, recovered) - }) - } -} - -func TestEncodeMessage_ReconstructFromMinimumShards(t *testing.T) { - // With N=10 we have 3 data shards and 6 coding shards. - // We should be able to reconstruct from just the 3 data shards. - schedule := makeSchedule(10) - enc, err := NewEncoder( - schedule.NumDataShards(), schedule.NumCodingShards(), - ) - require.NoError(t, err) - - msg := []byte("reconstruct me from minimum shards please") - units, root, err := EncodeMessage(msg, schedule, enc) - require.NoError(t, err) - - // Keep only the first numDataShards shards. - shards := make([][]byte, schedule.NumShards()) - for i := range schedule.NumDataShards() { - shards[units[i].ShardIndex] = units[i].ShardData - } - - recovered, err := ReconstructMessage(shards, schedule, enc, root) - require.NoError(t, err) - assert.Equal(t, msg, recovered) -} - -func TestEncodeMessage_ReconstructWithMissingDataShards(t *testing.T) { - // With N=7 we have 2 data shards and 4 coding shards. - // Drop all data shards, keep only coding shards -> should reconstruct. - schedule := makeSchedule(7) - enc, err := NewEncoder( - schedule.NumDataShards(), schedule.NumCodingShards(), - ) - require.NoError(t, err) - - msg := []byte("even without data shards") - units, root, err := EncodeMessage(msg, schedule, enc) - require.NoError(t, err) - - // Keep only coding shards (indices >= numDataShards). 
- shards := make([][]byte, schedule.NumShards()) - for _, u := range units { - if int(u.ShardIndex) >= schedule.NumDataShards() { - shards[u.ShardIndex] = u.ShardData - } - } - - recovered, err := ReconstructMessage(shards, schedule, enc, root) - require.NoError(t, err) - assert.Equal(t, msg, recovered) -} - -func TestEncodeMessage_MerkleProofsVerify(t *testing.T) { - schedule := makeSchedule(5) - enc, err := NewEncoder( - schedule.NumDataShards(), schedule.NumCodingShards(), - ) - require.NoError(t, err) - - msg := []byte("verify all proofs") - units, root, err := EncodeMessage(msg, schedule, enc) - require.NoError(t, err) - - for _, u := range units { - ok := VerifyMerkleProof(root, u.ShardData, uint32(u.ShardIndex), u.MerkleProof) - assert.True(t, ok, "proof for shard %d should verify", u.ShardIndex) - } -} - -func TestReconstructMessage_MismatchedRoot(t *testing.T) { - schedule := makeSchedule(4) - enc, err := NewEncoder( - schedule.NumDataShards(), schedule.NumCodingShards(), - ) - require.NoError(t, err) - - msg := []byte("good message") - units, _, err := EncodeMessage(msg, schedule, enc) - require.NoError(t, err) - - shards := make([][]byte, schedule.NumShards()) - for _, u := range units { - shards[u.ShardIndex] = u.ShardData - } - - // Pass a wrong root. - fakeRoot := MessageRoot{0xff} - _, err = ReconstructMessage(shards, schedule, enc, fakeRoot) - require.Error(t, err) - - var reconErr *ReconstructionError - require.ErrorAs(t, err, &reconErr) - assert.Equal(t, ReasonMismatchedMessageRoot, reconErr.Reason) -} - -func TestReconstructMessage_InsufficientShards(t *testing.T) { - schedule := makeSchedule(10) // 3 data, 6 coding - enc, err := NewEncoder( - schedule.NumDataShards(), schedule.NumCodingShards(), - ) - require.NoError(t, err) - - msg := []byte("not enough shards") - units, root, err := EncodeMessage(msg, schedule, enc) - require.NoError(t, err) - - // Provide only 2 shards when 3 are needed. 
- shards := make([][]byte, schedule.NumShards()) - shards[units[0].ShardIndex] = units[0].ShardData - shards[units[1].ShardIndex] = units[1].ShardData - - _, err = ReconstructMessage(shards, schedule, enc, root) - require.Error(t, err) - - var reconErr *ReconstructionError - require.ErrorAs(t, err, &reconErr) - assert.Equal(t, ReasonErasureReconstructionFailed, reconErr.Reason) -} - -func TestEncodeMessage_NoShards(t *testing.T) { - // A single-node schedule has no shards. - schedule := makeSchedule(1) - enc, err := NewEncoder(1, 0) - require.NoError(t, err) - - _, _, err = EncodeMessage([]byte("x"), schedule, enc) - require.Error(t, err) - - var pubErr *ShardPublishError - require.ErrorAs(t, err, &pubErr) - assert.Equal(t, ReasonInvalidDataSize, pubErr.Reason) -} +// import ( +// "testing" +// +// "github.com/libp2p/go-libp2p/core/peer" +// "github.com/stretchr/testify/assert" +// "github.com/stretchr/testify/require" +// ) +// +// // makeSchedule is a test helper that creates a schedule from N single-char peers. 
+// func makeSchedule(n int) *Scheduler { +// names := make([]peer.ID, n) +// for i := range n { +// names[i] = peer.ID(string(rune('A' + i))) +// } +// return NewScheduler(names) +// } +// +// func TestEncodeMessage_RoundTrip(t *testing.T) { +// tests := []struct { +// name string +// n int +// msgLen int +// }{ +// {"4 peers, short message", 4, 10}, +// {"4 peers, medium message", 4, 500}, +// {"7 peers, short message", 7, 20}, +// {"10 peers, 1KB message", 10, 1024}, +// {"2 peers, tiny message", 2, 1}, +// {"3 peers, empty message", 3, 0}, +// } +// +// for _, tc := range tests { +// t.Run(tc.name, func(t *testing.T) { +// schedule := makeSchedule(tc.n) +// if schedule.NumShards() == 0 { +// t.Skip("no shards for single peer") +// } +// +// enc, err := NewEncoder( +// schedule.NumDataShards(), schedule.NumCodingShards(), +// ) +// require.NoError(t, err) +// +// msg := make([]byte, tc.msgLen) +// for i := range msg { +// msg[i] = byte(i) +// } +// +// units, root, err := EncodeMessage(msg, schedule, enc) +// require.NoError(t, err) +// assert.Len(t, units, schedule.NumShards()) +// +// // All units should reference the same root. +// for _, u := range units { +// assert.Equal(t, root, u.MerkleRoot) +// } +// +// // Reconstruct from all shards. +// shards := make([][]byte, schedule.NumShards()) +// for _, u := range units { +// shards[u.ShardIndex] = u.ShardData +// } +// +// recovered, err := ReconstructMessage( +// shards, schedule, enc, root, +// ) +// require.NoError(t, err) +// assert.Equal(t, msg, recovered) +// }) +// } +// } +// +// func TestEncodeMessage_ReconstructFromMinimumShards(t *testing.T) { +// // With N=10 we have 3 data shards and 6 coding shards. +// // We should be able to reconstruct from just the 3 data shards. 
+// schedule := makeSchedule(10) +// enc, err := NewEncoder( +// schedule.NumDataShards(), schedule.NumCodingShards(), +// ) +// require.NoError(t, err) +// +// msg := []byte("reconstruct me from minimum shards please") +// units, root, err := EncodeMessage(msg, schedule, enc) +// require.NoError(t, err) +// +// // Keep only the first numDataShards shards. +// shards := make([][]byte, schedule.NumShards()) +// for i := range schedule.NumDataShards() { +// shards[units[i].ShardIndex] = units[i].ShardData +// } +// +// recovered, err := ReconstructMessage(shards, schedule, enc, root) +// require.NoError(t, err) +// assert.Equal(t, msg, recovered) +// } +// +// func TestEncodeMessage_ReconstructWithMissingDataShards(t *testing.T) { +// // With N=7 we have 2 data shards and 4 coding shards. +// // Drop all data shards, keep only coding shards -> should reconstruct. +// schedule := makeSchedule(7) +// enc, err := NewEncoder( +// schedule.NumDataShards(), schedule.NumCodingShards(), +// ) +// require.NoError(t, err) +// +// msg := []byte("even without data shards") +// units, root, err := EncodeMessage(msg, schedule, enc) +// require.NoError(t, err) +// +// // Keep only coding shards (indices >= numDataShards). 
+// shards := make([][]byte, schedule.NumShards()) +// for _, u := range units { +// if int(u.ShardIndex) >= schedule.NumDataShards() { +// shards[u.ShardIndex] = u.ShardData +// } +// } +// +// recovered, err := ReconstructMessage(shards, schedule, enc, root) +// require.NoError(t, err) +// assert.Equal(t, msg, recovered) +// } +// +// func TestEncodeMessage_MerkleProofsVerify(t *testing.T) { +// schedule := makeSchedule(5) +// enc, err := NewEncoder( +// schedule.NumDataShards(), schedule.NumCodingShards(), +// ) +// require.NoError(t, err) +// +// msg := []byte("verify all proofs") +// units, root, err := EncodeMessage(msg, schedule, enc) +// require.NoError(t, err) +// +// for _, u := range units { +// ok := VerifyMerkleProof(root, u.ShardData, uint32(u.ShardIndex), u.MerkleProof) +// assert.True(t, ok, "proof for shard %d should verify", u.ShardIndex) +// } +// } +// +// func TestReconstructMessage_MismatchedRoot(t *testing.T) { +// schedule := makeSchedule(4) +// enc, err := NewEncoder( +// schedule.NumDataShards(), schedule.NumCodingShards(), +// ) +// require.NoError(t, err) +// +// msg := []byte("good message") +// units, _, err := EncodeMessage(msg, schedule, enc) +// require.NoError(t, err) +// +// shards := make([][]byte, schedule.NumShards()) +// for _, u := range units { +// shards[u.ShardIndex] = u.ShardData +// } +// +// // Pass a wrong root. 
+// fakeRoot := MessageRoot{0xff} +// _, err = ReconstructMessage(shards, schedule, enc, fakeRoot) +// require.Error(t, err) +// +// var reconErr *ReconstructionError +// require.ErrorAs(t, err, &reconErr) +// assert.Equal(t, ReasonMismatchedMessageRoot, reconErr.Reason) +// } +// +// func TestReconstructMessage_InsufficientShards(t *testing.T) { +// schedule := makeSchedule(10) // 3 data, 6 coding +// enc, err := NewEncoder( +// schedule.NumDataShards(), schedule.NumCodingShards(), +// ) +// require.NoError(t, err) +// +// msg := []byte("not enough shards") +// units, root, err := EncodeMessage(msg, schedule, enc) +// require.NoError(t, err) +// +// // Provide only 2 shards when 3 are needed. +// shards := make([][]byte, schedule.NumShards()) +// shards[units[0].ShardIndex] = units[0].ShardData +// shards[units[1].ShardIndex] = units[1].ShardData +// +// _, err = ReconstructMessage(shards, schedule, enc, root) +// require.Error(t, err) +// +// var reconErr *ReconstructionError +// require.ErrorAs(t, err, &reconErr) +// assert.Equal(t, ReasonErasureReconstructionFailed, reconErr.Reason) +// } +// +// func TestEncodeMessage_NoShards(t *testing.T) { +// // A single-node schedule has no shards. 
+// schedule := makeSchedule(1) +// enc, err := NewEncoder(1, 0) +// require.NoError(t, err) +// +// _, _, err = EncodeMessage([]byte("x"), schedule, enc) +// require.Error(t, err) +// +// var pubErr *ShardPublishError +// require.ErrorAs(t, err, &pubErr) +// assert.Equal(t, ReasonInvalidDataSize, pubErr.Reason) +// } diff --git a/consensus/propeller/unit_test.go b/consensus/propeller/unit_test.go index e69de29bb2..bb61865bd4 100644 --- a/consensus/propeller/unit_test.go +++ b/consensus/propeller/unit_test.go @@ -0,0 +1 @@ +package propeller_test From caf7ad99c9efb486934078f9ca3ada37d3c46a93 Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Sat, 11 Apr 2026 17:08:38 +0100 Subject: [PATCH 32/40] refactor: scheduler.go tests --- consensus/propeller/scheduler_test.go | 196 +++++++++++++++----------- 1 file changed, 117 insertions(+), 79 deletions(-) diff --git a/consensus/propeller/scheduler_test.go b/consensus/propeller/scheduler_test.go index e849fc7e4e..60bedf2d43 100644 --- a/consensus/propeller/scheduler_test.go +++ b/consensus/propeller/scheduler_test.go @@ -161,48 +161,42 @@ func TestScheduler_DeterministicMapping(t *testing.T) { } } -func TestScheduler_PeerForShardIndex_SpecExample(t *testing.T) { - // From the doc comment: peers [A, B, C, D], publisher = C (index 2). - // Shard 0 -> A, Shard 1 -> B, Shard 2 -> D - s, err := NewScheduler(peer.ID("A"), testPeers(t, "A", "B", "C", "D")) - require.NoError(t, err) - - publisher := peer.ID("C") - expected := []peer.ID{"A", "B", "D"} - for i, want := range expected { - got, err := s.PeerForShardIndex(publisher, ShardIndex(i)) - require.NoError(t, err) - assert.Equal(t, want, got, "shard %d", i) - } -} - -func TestScheduler_PeerForShardIndex_PublisherFirst(t *testing.T) { - // Publisher is the first peer in sorted order. 
- s, err := NewScheduler(peer.ID("B"), testPeers(t, "A", "B", "C", "D")) - require.NoError(t, err) - - publisher := peer.ID("A") - // Shard 0 -> B, Shard 1 -> C, Shard 2 -> D - expected := []peer.ID{"B", "C", "D"} - for i, want := range expected { - got, err := s.PeerForShardIndex(publisher, ShardIndex(i)) - require.NoError(t, err) - assert.Equal(t, want, got, "shard %d", i) +func TestScheduler_PeerForShardIndex(t *testing.T) { + // peers [A, B, C, D]: the publisher is skipped in the sorted list, + // so each remaining peer maps to shard indices 0..2 in order. + tests := []struct { + name string + publisher string + expected []peer.ID + }{ + { + name: "publisher middle (C, index 2)", + publisher: "C", + expected: []peer.ID{"A", "B", "D"}, + }, + { + name: "publisher first (A, index 0)", + publisher: "A", + expected: []peer.ID{"B", "C", "D"}, + }, + { + name: "publisher last (D, index 3)", + publisher: "D", + expected: []peer.ID{"A", "B", "C"}, + }, } -} -func TestScheduler_PeerForShardIndex_PublisherLast(t *testing.T) { - // Publisher is the last peer in sorted order. 
s, err := NewScheduler(peer.ID("A"), testPeers(t, "A", "B", "C", "D")) require.NoError(t, err) - publisher := peer.ID("D") - // Shard 0 -> A, Shard 1 -> B, Shard 2 -> C - expected := []peer.ID{"A", "B", "C"} - for i, want := range expected { - got, err := s.PeerForShardIndex(publisher, ShardIndex(i)) - require.NoError(t, err) - assert.Equal(t, want, got, "shard %d", i) + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + for i, want := range tc.expected { + got, err := s.PeerForShardIndex(peer.ID(tc.publisher), ShardIndex(i)) + require.NoError(t, err) + assert.Equal(t, want, got, "shard %d", i) + } + }) } } @@ -289,55 +283,99 @@ func TestScheduler_InverseProperty(t *testing.T) { } func TestScheduler_BroadcastTargets(t *testing.T) { - s, err := NewScheduler(peer.ID("C"), testPeers(t, "A", "B", "C", "D")) - require.NoError(t, err) + // BroadcastTargets returns every peer except the local peer, in sorted order. + tests := []struct { + name string + local string + expected []peer.ID + }{ + { + name: "local first", + local: "A", + expected: []peer.ID{"B", "C", "D"}, + }, + { + name: "local middle", + local: "C", + expected: []peer.ID{"A", "B", "D"}, + }, + { + name: "local last", + local: "D", + expected: []peer.ID{"A", "B", "C"}, + }, + } - targets := s.BroadcastTargets() + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + s, err := NewScheduler(peer.ID(tc.local), testPeers(t, "A", "B", "C", "D")) + require.NoError(t, err) - // Should contain all peers except the local peer. 
- assert.Len(t, targets, s.NumTotalShards()) - assert.Equal(t, []peer.ID{"A", "B", "D"}, targets) + targets := s.BroadcastTargets() + assert.Equal(t, tc.expected, targets) + }) + } } func TestScheduler_ValidateShardOrigin(t *testing.T) { - // Setup: peers [A, B, C, D], local = C - // For publisher A: shard 0 -> A(skip, publisher), shard 1 -> B, shard 2 -> D - // Wait — publisher A is at index 0, so: - // shard 0 -> B (peers[1]), shard 1 -> C (peers[2]), shard 2 -> D (peers[3]) - // So local=C is responsible for shard 1 when publisher=A. + // peers [A, B, C, D], local = C. + // For publisher A (index 0): shard 0 -> B, shard 1 -> C, shard 2 -> D + // So local=C is the designated broadcaster for shard 1 when publisher=A. s, err := NewScheduler(peer.ID("C"), testPeers(t, "A", "B", "C", "D")) require.NoError(t, err) - publisher := peer.ID("A") - - t.Run("valid direct shard from publisher", func(t *testing.T) { - // Shard 1's designated broadcaster is C (local peer). - // A direct shard means publisher sends to the designated broadcaster. - err := s.ValidateShardOrigin(publisher, publisher, 1) - assert.NoError(t, err) - }) - - t.Run("valid broadcast shard from designated peer", func(t *testing.T) { - // Shard 0's designated broadcaster is B. - // B broadcasts shard 0 to other peers including C. 
- err := s.ValidateShardOrigin(peer.ID("B"), publisher, 0) - assert.NoError(t, err) - }) - - t.Run("self-send rejected", func(t *testing.T) { - err := s.ValidateShardOrigin(peer.ID("C"), publisher, 0) - assert.Error(t, err) - }) - - t.Run("self-published shard sent back rejected", func(t *testing.T) { - // Local peer is both the publisher and the receiver — should error - err := s.ValidateShardOrigin(peer.ID("A"), peer.ID("C"), 0) - assert.Error(t, err) - }) + tests := []struct { + name string + sender string + publisher string + shardIndex ShardIndex + wantErr bool + }{ + { + name: "valid direct shard from publisher", + sender: "A", + publisher: "A", + shardIndex: 1, // C is the designated broadcaster, so publisher sends directly + wantErr: false, + }, + { + name: "valid broadcast shard from designated peer", + sender: "B", + publisher: "A", + shardIndex: 0, // B is the designated broadcaster for shard 0 + wantErr: false, + }, + { + name: "self-send rejected", + sender: "C", + publisher: "A", + shardIndex: 0, + wantErr: true, + }, + { + name: "self-published shard sent back rejected", + sender: "A", + publisher: "C", // local peer is the publisher + shardIndex: 0, + wantErr: true, + }, + { + name: "wrong sender rejected", + sender: "D", + publisher: "A", + shardIndex: 0, // designated broadcaster is B, not D + wantErr: true, + }, + } - t.Run("wrong sender rejected", func(t *testing.T) { - // Shard 0's designated broadcaster is B, but D sends it. 
- err := s.ValidateShardOrigin(peer.ID("D"), publisher, 0) - assert.Error(t, err) - }) + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := s.ValidateShardOrigin(peer.ID(tc.sender), peer.ID(tc.publisher), tc.shardIndex) + if tc.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } } From a7549a30ea65afa1180fb41866e7e9189a3ba855 Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Tue, 14 Apr 2026 11:48:28 +0100 Subject: [PATCH 33/40] refactor: update old consensus and p2p code to match current style --- consensus/consensus.go | 4 +++- consensus/p2p/buffered/proto_broadcaster.go | 6 ++++-- consensus/p2p/validator/proposal_stream.go | 4 +++- consensus/p2p/vote/vote_broadcasters.go | 12 +++++++++--- p2p/pubsub/pubsub.go | 4 ++-- 5 files changed, 21 insertions(+), 9 deletions(-) diff --git a/consensus/consensus.go b/consensus/consensus.go index 4365e101d6..6aabd49126 100644 --- a/consensus/consensus.go +++ b/consensus/consensus.go @@ -52,7 +52,9 @@ func Init( } currentHeight := types.Height(chainHeight + 1) - tendermintDB := consensusDB.NewTendermintDB[starknet.Value, starknet.Hash, starknet.Address](database) + tendermintDB := consensusDB.NewTendermintDB[ + starknet.Value, starknet.Hash, starknet.Address, + ](database) executor := builder.NewExecutor(blockchain, vm, logger, false, false) builder := builder.New(blockchain, executor) diff --git a/consensus/p2p/buffered/proto_broadcaster.go b/consensus/p2p/buffered/proto_broadcaster.go index 42e527b20e..dc4e417258 100644 --- a/consensus/p2p/buffered/proto_broadcaster.go +++ b/consensus/p2p/buffered/proto_broadcaster.go @@ -57,7 +57,8 @@ func (b ProtoBroadcaster[M]) Loop(ctx context.Context, topic *pubsub.Topic) { } for { - if err := topic.Publish(ctx, msgBytes); err != nil && !errors.Is(err, context.Canceled) { + err := topic.Publish(ctx, msgBytes) + if err != nil && !errors.Is(err, context.Canceled) { b.logger.Error("unable to send message", zap.Error(err)) 
time.Sleep(b.retryInterval) continue @@ -70,7 +71,8 @@ func (b ProtoBroadcaster[M]) Loop(ctx context.Context, topic *pubsub.Topic) { } case <-rebroadcasted.trigger: for msgBytes := range rebroadcasted.messages { - if err := topic.Publish(ctx, msgBytes); err != nil && !errors.Is(err, context.Canceled) { + err := topic.Publish(ctx, msgBytes) + if err != nil && !errors.Is(err, context.Canceled) { b.logger.Error("unable to rebroadcast message", zap.Error(err)) } } diff --git a/consensus/p2p/validator/proposal_stream.go b/consensus/p2p/validator/proposal_stream.go index b519294704..b39a21e85b 100644 --- a/consensus/p2p/validator/proposal_stream.go +++ b/consensus/p2p/validator/proposal_stream.go @@ -48,7 +48,9 @@ func newSingleProposalStream( } } -func (s *proposalStream) start(ctx context.Context, firstMessage *consensus.StreamMessage) (types.Height, error) { +func (s *proposalStream) start( + ctx context.Context, firstMessage *consensus.StreamMessage, +) (types.Height, error) { content := firstMessage.GetContent() if content == nil { return 0, fmt.Errorf("first message has empty content") diff --git a/consensus/p2p/vote/vote_broadcasters.go b/consensus/p2p/vote/vote_broadcasters.go index 603ae12944..853aa8a02c 100644 --- a/consensus/p2p/vote/vote_broadcasters.go +++ b/consensus/p2p/vote/vote_broadcasters.go @@ -34,7 +34,9 @@ func NewVoteBroadcaster[H types.Hash, A types.Addr]( } } -func (b *voteBroadcaster[H, A]) broadcast(ctx context.Context, message *types.Vote[H, A], voteType consensus.Vote_VoteType) { +func (b *voteBroadcaster[H, A]) broadcast( + ctx context.Context, message *types.Vote[H, A], voteType consensus.Vote_VoteType, +) { msg, err := b.voteAdapter.FromVote(message, voteType) if err != nil { b.logger.Error("unable to convert vote", zap.Error(err)) @@ -60,6 +62,10 @@ func (b *prevoteBroadcaster[H, A]) Broadcast(ctx context.Context, message *types type precommitBroadcaster[H types.Hash, A types.Addr] voteBroadcaster[H, A] -func (b *precommitBroadcaster[H, 
A]) Broadcast(ctx context.Context, message *types.Precommit[H, A]) { - (*voteBroadcaster[H, A])(b).broadcast(ctx, (*types.Vote[H, A])(message), consensus.Vote_Precommit) +func (b *precommitBroadcaster[H, A]) Broadcast( + ctx context.Context, message *types.Precommit[H, A], +) { + (*voteBroadcaster[H, A])(b).broadcast( + ctx, (*types.Vote[H, A])(message), consensus.Vote_Precommit, + ) } diff --git a/p2p/pubsub/pubsub.go b/p2p/pubsub/pubsub.go index a2ee1f1419..7a2f98a7ff 100644 --- a/p2p/pubsub/pubsub.go +++ b/p2p/pubsub/pubsub.go @@ -15,8 +15,6 @@ import ( "github.com/libp2p/go-libp2p/p2p/discovery/routing" ) -const gossipSubHistory = 60 - func GetHost(hostPrivateKey crypto.PrivKey, hostAddress string) (host.Host, error) { return libp2p.New( libp2p.ListenAddrStrings(hostAddress), @@ -49,6 +47,8 @@ func Run( } params := pubsub.DefaultGossipSubParams() + + const gossipSubHistory = 60 params.HistoryLength = gossipSubHistory params.HistoryGossip = gossipSubHistory From 44a4e8fe1bf5eec5fb9fc5dc2e0ba6f3982a370a Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Tue, 14 Apr 2026 11:48:52 +0100 Subject: [PATCH 34/40] refactor: rename Validator to UnitValidator --- consensus/propeller/engine.go | 23 ++++++++++------------- consensus/propeller/processor.go | 6 +++--- consensus/propeller/validator.go | 15 ++++++++------- 3 files changed, 21 insertions(+), 23 deletions(-) diff --git a/consensus/propeller/engine.go b/consensus/propeller/engine.go index 68e75323ed..f837f74107 100644 --- a/consensus/propeller/engine.go +++ b/consensus/propeller/engine.go @@ -85,19 +85,17 @@ func (processUnit) isCommand() // Engine is the central orchestrator of the Propeller protocol. It: // -// - Manages channel registrations (each channel has its own peer set and schedule). -// - Routes incoming PropellerUnits to the correct MessageProcessor. -// - Handles broadcast requests from the application layer. -// - Collects and forwards events from processors to the application. 
-// -// The engine is designed to be run as a single long-lived goroutine via Run(). -// External callers interact with it through thread-safe methods that send -// commands on internal channels, so no locks are needed on the hot path. +// - Manages committee registrations (each committee has its own peer set and scheduler). +// - Process all incoming messages and broadcasts them when expected. +// - Handles broadcast requests from the service layer. +// - Forwards all noteworthy event to the service layer. type Engine struct { - localPeer peer.ID privKey crypto.PrivKey - config Config - log utils.StructuredLogger + localPeer peer.ID + + config Config + log utils.StructuredLogger + // processor handles validates and process all the messages received by other peers processor *Processor @@ -132,7 +130,6 @@ type Engine struct { // Call Run() to start processing. // // Parameters: -// - localPeer: this node's peer ID. // - privKey: this node's Ed25519 private key (for signing published messages). // - config: protocol parameters. // - log: structured logger. @@ -225,7 +222,7 @@ func (e *Engine) unregisterCommittee(committeeID *CommitteeID) { delete(e.committees, *committeeID) // todo(rdr): We have to clean the processors, right? // or will they shut down on their own eventually - // better to pass a context with cancelj + // better to pass a context with cancel? 
e.log.Info("unregistered propeller committee", // todo(rdr): give a proper string representation diff --git a/consensus/propeller/processor.go b/consensus/propeller/processor.go index 78168993d4..e45ae18191 100644 --- a/consensus/propeller/processor.go +++ b/consensus/propeller/processor.go @@ -29,7 +29,7 @@ type subprocessor struct { // todo(rdr): I think I would like it more if it is called UnitValidator since // is more specfic - validator Validator + validator UnitValidator } func newSubprocessor( @@ -99,7 +99,7 @@ func (s *subprocessor) beforeMessageBuiltStage(ctx context.Context) ( case unitWithSender := <-s.unitsChan: unit := unitWithSender.unit sender := unitWithSender.sender - if err := s.validator.ValidateUnit(unit, sender); err != nil { + if err := s.validator.Validate(unit, sender); err != nil { s.invalidUnitsChan <- invalidUnit{ // todo(rdr): not sure if we need message key. // We just want to penalize the sender @@ -179,7 +179,7 @@ func (s *subprocessor) beforeMessageReceivedStage( case unitWithSender := <-s.unitsChan: unit := unitWithSender.unit sender := unitWithSender.sender - if err := s.validator.ValidateUnit(unit, sender); err != nil { + if err := s.validator.Validate(unit, sender); err != nil { s.invalidUnitsChan <- invalidUnit{ messageKey: extractKey(unit), sender: sender, diff --git a/consensus/propeller/validator.go b/consensus/propeller/validator.go index 81c13001fd..20773d7af0 100644 --- a/consensus/propeller/validator.go +++ b/consensus/propeller/validator.go @@ -11,13 +11,14 @@ import ( ) // todo(rdr): A validator lifetime is attached to a `subprocessor`. A `subprocessor` is attached -// to a message key field. This logic is handled by a `Processor`. This means that a validator will // always be given units that have the same committeeID, publisher, messageRoot and Nonce (the +// to a message key field. This logic is handled by a `Processor`. 
This means that a validator will +// always be given units that have the same committeeID, publisher, messageRoot and Nonce (the // current fields of a `messageKey`). Does it makes sense for the validator to also hold a copy // of this. Is there a way of testing this invariant – where a validator only sees the same // fields. I need to add a test for that invariant // Validates all the incoming units / shards given a committee and the publisher -type Validator struct { +type UnitValidator struct { publisherPubKey crypto.PubKey scheduler *Scheduler @@ -29,14 +30,14 @@ type Validator struct { } // todo(rdr): maybe just pass the publisher? -func NewValidator(publisher peer.ID, scheduler *Scheduler) Validator { +func NewValidator(publisher peer.ID, scheduler *Scheduler) UnitValidator { pubKey, err := publisher.ExtractPublicKey() // for now we are assuming that extracting a publisher key is always successful // and done in constant time if err != nil { panic(err) } - return Validator{ + return UnitValidator{ publisherPubKey: pubKey, scheduler: scheduler, receivedShards: make(map[ShardIndex]struct{}, scheduler.NumDataShards()), @@ -44,7 +45,7 @@ func NewValidator(publisher peer.ID, scheduler *Scheduler) Validator { } } -func (v *Validator) verifyDataShards(unit *Unit) error { +func (v *UnitValidator) verifyDataShards(unit *Unit) error { if len(unit.ShardData) != 1 { return fmt.Errorf( "unexpected amount of shards. Expected %d. 
Received %d", @@ -63,7 +64,7 @@ func (v *Validator) verifyDataShards(unit *Unit) error { return errors.New("data shards verification failed") } -func (v *Validator) verifySignature(unit *Unit) error { +func (v *UnitValidator) verifySignature(unit *Unit) error { if v.verifiedSignature != nil { if bytes.Equal(v.verifiedSignature, unit.Signature) { return nil @@ -91,7 +92,7 @@ func (v *Validator) verifySignature(unit *Unit) error { return nil } -func (v *Validator) ValidateUnit(unit *Unit, sender peer.ID) error { +func (v *UnitValidator) Validate(unit *Unit, sender peer.ID) error { if _, ok := v.receivedShards[unit.ShardIndex]; ok { return fmt.Errorf("duplicated shard %d received", unit.ShardIndex) } From 5ee8ae711c90f3fd89e28539eff32266e217ca00 Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Tue, 14 Apr 2026 11:49:33 +0100 Subject: [PATCH 35/40] refactor: rename validator.go -> unit_validator.go --- consensus/propeller/{validator.go => unit_validator.go} | 0 consensus/propeller/{validator_test.go => unit_validator_test.go} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename consensus/propeller/{validator.go => unit_validator.go} (100%) rename consensus/propeller/{validator_test.go => unit_validator_test.go} (100%) diff --git a/consensus/propeller/validator.go b/consensus/propeller/unit_validator.go similarity index 100% rename from consensus/propeller/validator.go rename to consensus/propeller/unit_validator.go diff --git a/consensus/propeller/validator_test.go b/consensus/propeller/unit_validator_test.go similarity index 100% rename from consensus/propeller/validator_test.go rename to consensus/propeller/unit_validator_test.go From 88d6f597c1024993ec4240e6aa7af3b6b0bc621c Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Wed, 15 Apr 2026 14:33:30 +0100 Subject: [PATCH 36/40] refactor: improve propeller service design --- consensus/propeller/engine.go | 44 ++++---------- consensus/propeller/processor.go | 56 +++++++++++------ consensus/propeller/propeller.go | 100 
++++++++++++++++++++++++++----- consensus/propeller/sharding.go | 6 +- 4 files changed, 139 insertions(+), 67 deletions(-) diff --git a/consensus/propeller/engine.go b/consensus/propeller/engine.go index f837f74107..8f10f53742 100644 --- a/consensus/propeller/engine.go +++ b/consensus/propeller/engine.go @@ -114,12 +114,9 @@ type Engine struct { // eventCh is shared between all processors and the engine. The engine // reads from it and forwards events to the application via Events(). - eventCh chan any - - // appEventCh is the externally-visible event channel. The engine copies - // events from eventCh to appEventCh in its Run() loop, filtering out - // internal events as needed. - appEventCh chan any + // todo(rdr): currently sent directly from the processor to the service, + // does the engine needs to do any filtering? + // eventCh chan Event // cmdCh receives commands from the propeller service and act on those cmdCh <-chan engineCommand @@ -140,14 +137,14 @@ func NewEngine( privKey crypto.PrivKey, config *Config, log utils.StructuredLogger, -) (*Engine, chan<- engineCommand) { +) (*Engine, chan<- engineCommand, <-chan Event) { localPeerID, err := peer.IDFromPrivateKey(privKey) if err != nil { // todo(rdr): pannic for now, error handling for later panic(err) } - processor := NewProcessor(localPeerID, config) + processor, eventsCh := NewProcessor(localPeerID, config) cmdCh := make(chan engineCommand) @@ -162,9 +159,7 @@ func NewEngine( unitsPrepared: make(chan broadcastResult), // Unsure of the fields below connectedPeers: make(map[peer.ID]struct{}), - eventCh: make(chan any, eventChSize), - appEventCh: make(chan any, appEventChSize), - }, cmdCh + }, cmdCh, eventsCh } // registerCommittee creates the schedule and encoder for a new channel. 
@@ -230,9 +225,9 @@ func (e *Engine) unregisterCommittee(committeeID *CommitteeID) { ) } -// prepareBroadcast creates Proppeller units asynchronously since it is a very expensive +// prepareUnitsForBroadcast creates Proppeller units asynchronously since it is a very expensive // operation. -func (e *Engine) prepareBroadcast(committeeID *CommitteeID, data []byte) error { +func (e *Engine) prepareUnitsForBroadcast(committeeID *CommitteeID, data []byte) error { cs, ok := e.committees[*committeeID] if !ok { return fmt.Errorf("cannot broadcast to an unregistered committee: %s", committeeID) @@ -279,6 +274,7 @@ func (e *Engine) broadcast(units []Unit) error { } // todo(rdr): I need to do the actual sending + // I need to pass to the eventCh all the units that it should receive return nil } @@ -302,16 +298,6 @@ func (e *Engine) processUnit(ctx context.Context, unit *Unit, sender peer.ID) { } } -// forwardEvent sends an event to the application's event channel. Non-blocking -// to avoid stalling the engine if the application is slow to consume events. -func (e *Engine) forwardEvent(event any) { - select { - case e.appEventCh <- event: - default: - e.log.Warn("dropping event: application event channel full") - } -} - // handleCommand dispatches a command to the appropriate handler. func (e *Engine) handleCommand(ctx context.Context, command engineCommand) { switch cmd := command.(type) { @@ -321,16 +307,16 @@ func (e *Engine) handleCommand(ctx context.Context, command engineCommand) { case *unregisterCommittee: e.unregisterCommittee(&cmd.committeeID) case *broadcast: - err := e.prepareBroadcast(&cmd.committeeID, cmd.msg) + // we might need to pass the error channel here so that the internal go-routine + // can forward it correctly (assuming a per command error channel) + err := e.prepareUnitsForBroadcast(&cmd.committeeID, cmd.msg) cmd.errCh <- err case *processUnit: e.processUnit(ctx, cmd.unit, cmd.sender) } } -// Run starts the engine's main loop. 
It blocks until the context is cancelled. -// This should be called in its own goroutine. -// +// Run starts the engine's main loop until context is cancelled. // The loop processes three things concurrently: // 1. Commands from external callers (register, broadcast, handle incoming unit). // 2. Events from message processors (forward to application). @@ -349,10 +335,6 @@ func (e *Engine) Run(ctx context.Context) error { e.log.Error("couldn't prepare units", zap.Error(broadcastResult.err)) } e.broadcast(broadcastResult.units) - - case event := <-e.eventCh: - // Forward application-visible events from processors. - e.forwardEvent(event) } } } diff --git a/consensus/propeller/processor.go b/consensus/propeller/processor.go index e45ae18191..a7500df7ec 100644 --- a/consensus/propeller/processor.go +++ b/consensus/propeller/processor.go @@ -14,6 +14,29 @@ import ( "go.uber.org/zap" ) +type Event interface { + isEvent() +} + +type messageFinalized struct { + message []byte +} + +func (*messageFinalized) isEvent() {} + +type broadcastUnit struct { + unit *Unit + peers []peer.ID +} + +func (*broadcastUnit) isEvent() {} + +type broadcastMessage struct { + unit []Unit +} + +func (*broadcastMessage) isEvent() {} + type unitWithSender struct { unit *Unit sender peer.ID @@ -26,9 +49,8 @@ type subprocessor struct { unitsChan <-chan unitWithSender invalidUnitsChan chan<- invalidUnit + processingEvents chan<- Event - // todo(rdr): I think I would like it more if it is called UnitValidator since - // is more specfic validator UnitValidator } @@ -52,7 +74,7 @@ func newSubprocessor( } } -func (s *subprocessor) broadcastUnit(unit *Unit) error { +func (s *subprocessor) broadcastUnit(unit *Unit) { index := 0 peers := make([]peer.ID, len(s.scheduler.Peers())-2) for _, peerCommittee := range s.scheduler.Peers() { @@ -67,8 +89,10 @@ func (s *subprocessor) broadcastUnit(unit *Unit) error { peers[i], peers[j] = peers[j], peers[i] }) - // todo(rdr): This should forward the unit and the peers 
that require broadcasting - panic("not implemented") + s.processingEvents <- &broadcastUnit{ + unit: unit, + peers: peers, + } } func (s *subprocessor) beforeMessageBuiltStage(ctx context.Context) ( @@ -121,11 +145,7 @@ func (s *subprocessor) beforeMessageBuiltStage(ctx context.Context) ( // broadcast as soon as I get my shard if !localShardWasBroadcast && s.localShardIndex == unit.ShardIndex { localShardWasBroadcast = true - err := s.broadcastUnit(unit) - if err != nil { - // todo(rdr): tbd if we need an error here - panic(err) - } + s.broadcastUnit(unit) } } } @@ -155,11 +175,7 @@ func (s *subprocessor) beforeMessageBuiltStage(ctx context.Context) ( ShardIndex: s.localShardIndex, ShardData: localShardData, } - err := s.broadcastUnit(&localUnit) - if err != nil { - // todo(rdr): tbd if we need an error here - panic(err) - } + s.broadcastUnit(&localUnit) unitCount += 1 } @@ -272,8 +288,10 @@ type Processor struct { subProcessors map[messageKey]chan<- unitWithSender // channel through wich subprocessors signal they have finalized execution subProcessorsFinalized chan finalizedSubprocessor - // channel through which subprocessor sharedunits that failed validation + // channel through which subprocessor share units that failed validation invalidUnits chan invalidUnit + // channel through which important events are shared + processingEvents chan<- Event // track current open and closed tasks to avoid resource starvation mu sync.Mutex @@ -286,8 +304,9 @@ type Processor struct { log utils.StructuredLogger } -func NewProcessor(localPeer peer.ID, config *Config) *Processor { +func NewProcessor(localPeer peer.ID, config *Config) (*Processor, <-chan Event) { timeout := config.StaleMessageTimeout + processingEvents := make(chan Event) return &Processor{ finalized: timecache.New[messageKey](2048, timeout), @@ -295,6 +314,7 @@ func NewProcessor(localPeer peer.ID, config *Config) *Processor { subProcessors: make(map[messageKey]chan<- unitWithSender), subProcessorsFinalized: 
make(chan finalizedSubprocessor), invalidUnits: make(chan invalidUnit), + processingEvents: processingEvents, mu: sync.Mutex{}, publisherTasks: make(map[peer.ID]uint64), @@ -308,7 +328,7 @@ func NewProcessor(localPeer peer.ID, config *Config) *Processor { maxWorkers: 1000, maxWorkersPerPublisher: 250, }, - } + }, processingEvents } func (p *Processor) Run(ctx context.Context) { diff --git a/consensus/propeller/propeller.go b/consensus/propeller/propeller.go index 24dfce294e..5717993367 100644 --- a/consensus/propeller/propeller.go +++ b/consensus/propeller/propeller.go @@ -10,20 +10,27 @@ import ( "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" "go.uber.org/zap" "google.golang.org/protobuf/proto" ) // This would represent the propeller service that glues the whole // thing to p2p. Thing is, I've no clue how to do that. -type Service interface{} +type Service any type propellerService struct { - host host.Host - engine *Engine - cmdCh chan<- engineCommand + // P2P config + host host.Host + // Internal config config Config log utils.Logger + // Propeller communication + engine *Engine + cmdCh chan<- engineCommand + eventsCh <-chan Event + // External communication + messageRecv chan []byte } func New( @@ -32,22 +39,23 @@ func New( config *Config, log utils.Logger, ) Service { - engine, cmdCh := NewEngine( + engine, cmdCh, eventsCh := NewEngine( privKey, config, log, ) return &propellerService{ - host: host, - engine: engine, - cmdCh: cmdCh, - config: *config, - log: log, + host: host, + engine: engine, + cmdCh: cmdCh, + eventsCh: eventsCh, + config: *config, + log: log, } } -func (s *propellerService) receivePropellerUnits(stream network.Stream) { +func (s *propellerService) receiveUnits(stream network.Stream) { defer stream.Close() sender := stream.Conn().RemotePeer() @@ -77,6 +85,10 @@ func (s *propellerService) receivePropellerUnits(stream 
network.Stream) { if err != nil { s.log.Warn("received invalid unit", zap.Error(err)) // todo(rdr): penalize sender? + // If we do it here then it means it shouldn't be handled at + // subP or Processor level, and all should be handled here, + // which means, every invalid unit should be handled at Service + // level. To be determined yet. continue } // send unit to engine @@ -87,7 +99,53 @@ func (s *propellerService) receivePropellerUnits(stream network.Stream) { } } +func (s *propellerService) broadcastUnit(ctx context.Context, unit *Unit, peers []peer.ID) { + batch := &pb.PropellerUnitBatch{ + Batch: []*pb.PropellerUnit{unit.ToProto()}, + } + data, err := proto.Marshal(batch) + if err != nil { + // todo(rdr): log the error? What if this cannot get it done? + // Our batch is correct unless there is an internal bug + panic(err) + } + + for _, p := range peers { + err = s.sendToPeer(ctx, p, data) + if err != nil { + // Why would there be any error + // What should we do in this case + panic(err) + } + } +} + +func (s *propellerService) sendToPeer(ctx context.Context, p peer.ID, data []byte) error { + stream, err := s.host.NewStream(ctx, p, s.config.StreamProtocol) + if err != nil { + return err + } + defer stream.Close() + + _, err = stream.Write(data) + return err +} + +func (s *propellerService) broadcastMessage(ctx context.Context, msg []byte) { +} + +func (s *propellerService) handleEvent(ctx context.Context, event Event) { + switch event := event.(type) { + case *messageFinalized: + // if the message is finalized it should have a receive + s.messageRecv <- event.message + case *broadcastUnit: + s.broadcastUnit(ctx, event.unit, event.peers) + } +} + func (s *propellerService) Run(ctx context.Context) error { + // Start engine service in the background go func() { err := s.engine.Run(ctx) if err != nil { @@ -97,17 +155,31 @@ func (s *propellerService) Run(ctx context.Context) error { s.log.Info("shutting down propeller engine") }() - 
s.host.SetStreamHandler(s.config.StreamProtocol, s.receivePropellerUnits) + // Subscribe to receiving certain topics + s.host.SetStreamHandler(s.config.StreamProtocol, s.receiveUnits) defer s.host.RemoveStreamHandler(s.config.StreamProtocol) + // Handle Engine outputs for { + // todo(rdr): handle the engines output such as units to broadcast select { case <-ctx.Done(): return ctx.Err() + case event := <-s.eventsCh: + s.handleEvent(ctx, event) } - // todo(rdr): handle the engines output such as units to broadcast } } -func (s *propellerService) broadcast() { +func (s *propellerService) Broadcast(msg []byte) { +} + +func (s *propellerService) Recv() <-chan []byte { + return s.messageRecv +} + +func (s *propellerService) RegisterCommittee(committeeID *CommitteeID) { +} + +func (s *propellerService) UnregisterCommittee(comitteeID *CommitteeID) { } diff --git a/consensus/propeller/sharding.go b/consensus/propeller/sharding.go index fa3d5f3c4a..189d1c429b 100644 --- a/consensus/propeller/sharding.go +++ b/consensus/propeller/sharding.go @@ -22,7 +22,7 @@ func CreatePropellerUnits( ) ([]Unit, error) { publisherID, err := peer.IDFromPrivateKey(privKey) if err != nil { - return nil, fmt.Errorf("getting publisher id from private key: %w", publisherID) + return nil, fmt.Errorf("getting publisher id %s from private key: %w", publisherID, err) } paddedMessage := PadMessage(message, numDataShards) @@ -41,13 +41,11 @@ func CreatePropellerUnits( units := make([]Unit, len(encodedMessage)) for i, shard := range encodedMessage { - merkleProof := merkleTree[i] - units[i] = Unit{ CommitteeID: *committeeID, Publisher: publisherID, MessageRoot: messageRoot, - MerkleProof: merkleProof, + MerkleProof: merkleTree[i], Signature: signature, ShardIndex: ShardIndex(i), // todo(rdr): assigning one shard per unit until multi shard algo per unit From aafbc80f02a26c2ecb0fe6d0b27543ae6496267c Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Thu, 23 Apr 2026 17:21:55 +0100 Subject: [PATCH 37/40] 
refactor: move things here and there --- consensus/propeller/engine.go | 57 +++++++++++++++------- consensus/propeller/propeller.go | 9 +++- consensus/propeller/scheduler.go | 3 +- consensus/propeller/sharding.go | 3 +- consensus/propeller/signing.go | 2 +- consensus/propeller/timecache/timecache.go | 9 ++-- consensus/propeller/types.go | 13 ----- consensus/propeller/unit.go | 16 +++++- consensus/propeller/unit_validator.go | 4 +- 9 files changed, 74 insertions(+), 42 deletions(-) diff --git a/consensus/propeller/engine.go b/consensus/propeller/engine.go index 8f10f53742..881de0a74a 100644 --- a/consensus/propeller/engine.go +++ b/consensus/propeller/engine.go @@ -11,16 +11,6 @@ import ( "go.uber.org/zap" ) -// Channel buffer sizes for the engine's internal channels. These are large -// enough to absorb bursts without blocking, but bounded to prevent unbounded -// memory growth from slow consumers. -const ( - eventChSize = 256 - cleanupChSize = 256 - appEventChSize = 256 - cmdChSize = 64 -) - type broadcastResult struct { units []Unit err error @@ -60,13 +50,13 @@ type registerCommittee struct { errCh chan error } -func (registerCommittee) isCommand() +func (registerCommittee) isCommand() {} type unregisterCommittee struct { committeeID CommitteeID } -func (unregisterCommittee) isCommand() +func (unregisterCommittee) isCommand() {} type broadcast struct { committeeID CommitteeID @@ -74,14 +64,14 @@ type broadcast struct { errCh chan error } -func (broadcast) isCommand() +func (broadcast) isCommand() {} type processUnit struct { unit *Unit sender peer.ID } -func (processUnit) isCommand() +func (processUnit) isCommand() {} // Engine is the central orchestrator of the Propeller protocol. It: // @@ -119,7 +109,7 @@ type Engine struct { // eventCh chan Event // cmdCh receives commands from the propeller service and act on those - cmdCh <-chan engineCommand + cmdCh chan engineCommand } // NewEngine creates an engine instance. 
It returns the engine and the channel to @@ -173,7 +163,7 @@ func (e *Engine) registerCommittee( if _, ok := e.committees[*committeeID]; ok { e.log.Warn( "committee already registered, will ignore re-registration attempt", - // todo(rdr): give a propper string repr + // todo(rdr): give a proper string repr zap.Any("committee id", committeeID), ) return nil @@ -286,7 +276,7 @@ func (e *Engine) processUnit(ctx context.Context, unit *Unit, sender peer.ID) { if !ok { // note(rdr): maybe debug? e.log.Warn("received key for unregistered committee, dropping", - // todo(rdr): give a propper string representation + // todo(rdr): give a proper string representation zap.Any("committee id", unit.CommitteeID), ) return @@ -333,8 +323,39 @@ func (e *Engine) Run(ctx context.Context) error { case broadcastResult := <-e.unitsPrepared: if broadcastResult.err != nil { e.log.Error("couldn't prepare units", zap.Error(broadcastResult.err)) + // todo(rdr): send error to service, probably don't log it + } + err := e.broadcast(broadcastResult.units) + if err != nil { + // log it? + // send error to service? } - e.broadcast(broadcastResult.units) } } } + +func (e *Engine) Broadcast(msg []byte) { +} + +func (e *Engine) RegisterCommittee( + committeeID *CommitteeID, + peers []PeerCommittee, + // todo(rdr): peersKeys is something I don't know how to set correctly yet + peersKeys []*StakerID, +) error { + // todo(rdr): does creating an error channel per call is performant or + // should we have a pool of err channels or that is too crazy :3 + // Thinking on the GC cost... 
+ errCh := make(chan error) + e.cmdCh <- &registerCommittee{ + committeeID: *committeeID, + peers: peers, + peersKeys: peersKeys, + errCh: errCh, + } + return <-errCh +} + +func (e *Engine) UnregisterCommittee(committeeID *CommitteeID) { + e.cmdCh <- &unregisterCommittee{committeeID: *committeeID} +} diff --git a/consensus/propeller/propeller.go b/consensus/propeller/propeller.go index 5717993367..5ffa38aac8 100644 --- a/consensus/propeller/propeller.go +++ b/consensus/propeller/propeller.go @@ -178,8 +178,15 @@ func (s *propellerService) Recv() <-chan []byte { return s.messageRecv } -func (s *propellerService) RegisterCommittee(committeeID *CommitteeID) { +func (s *propellerService) RegisterCommittee( + committeeID *CommitteeID, + peers []PeerCommittee, + // todo(rdr): peersKeys is something I don't know how to set correctly yet + peersKeys []*StakerID, +) error { + return s.engine.RegisterCommittee(committeeID, peers, peersKeys) } func (s *propellerService) UnregisterCommittee(comitteeID *CommitteeID) { + s.engine.UnregisterCommittee(comitteeID) } diff --git a/consensus/propeller/scheduler.go b/consensus/propeller/scheduler.go index 76333908d5..14cc9027a4 100644 --- a/consensus/propeller/scheduler.go +++ b/consensus/propeller/scheduler.go @@ -209,7 +209,7 @@ func (s *Scheduler) ShardIndexForPublisher( // ValidateShardOrigin verifies that a shard unit was received from the expected sender. // The sender has to be either the publisher for direct shards or a designated // broadcaster for the given shard index. -// todo(rdr): Maybe the unit validator should have this implementation +// todo(rdr): This implementation should probably be part of `UnitValidator` func (s *Scheduler) ValidateShardOrigin( sender peer.ID, publisher peer.ID, @@ -252,6 +252,7 @@ func (s *Scheduler) ValidateShardOrigin( // BroadcastTargets returns all peers whom to broadcast to, in shard-index order. // The i-th element of the returned slice is the peer responsible for shard i. 
func (s *Scheduler) BroadcastTargets() []peer.ID { + // todo(rdr): I would like to not use `append` and index directly instead (it's faster) targets := make([]peer.ID, 0, s.NumTotalShards()) for i, p := range s.peers { if i == s.localPeerIDIndex { diff --git a/consensus/propeller/sharding.go b/consensus/propeller/sharding.go index 189d1c429b..43cbefbd6a 100644 --- a/consensus/propeller/sharding.go +++ b/consensus/propeller/sharding.go @@ -49,7 +49,7 @@ func CreatePropellerUnits( Signature: signature, ShardIndex: ShardIndex(i), // todo(rdr): assigning one shard per unit until multi shard algo per unit - // is clear to me + // is clear to me. ShardData: []Shard{shard}, } } @@ -73,6 +73,7 @@ func ConstructMessageFromUnits( if units[i] != nil { // todo(rdr): we are assuming that every unit only carries one shard data for now // Not sure how the matrix is built when unit carries more than one + // Probably it is an algorithm based on stake levels (?) shards[i] = units[i].ShardData[0] } } diff --git a/consensus/propeller/signing.go b/consensus/propeller/signing.go index 5b1c8a80d8..9fb93d1f23 100644 --- a/consensus/propeller/signing.go +++ b/consensus/propeller/signing.go @@ -20,7 +20,7 @@ func buildSignPayload( const prefix = "" const suffix = "" - // cumulative lenghts denoting the ranges in where each bytes of data should be stored + // cumulative lengths denoting the ranges in where each bytes of data should be stored const prefixLen = len(prefix) const rootLen = prefixLen + 32 const committeeIDLen = rootLen + 32 diff --git a/consensus/propeller/timecache/timecache.go b/consensus/propeller/timecache/timecache.go index 474027e48c..371a49fd7f 100644 --- a/consensus/propeller/timecache/timecache.go +++ b/consensus/propeller/timecache/timecache.go @@ -13,6 +13,11 @@ type timedValue[K any] struct { } type TimeCache[K comparable] struct { + // todo(rdr): there is a possibility of make the value an `index` and find the time + // information on `timestamps`. 
This would reduce duplication of time.Time + // and it would also allow to easily detect expired values (no longer need to + // perform time.Time substractions since everything before `index` is expired) + // This is might be hyper-optimising.... // Access valid keys in O(1) values map[K]time.Time // Clean expired keys O(k) where `k` is the amount of expired keys @@ -112,9 +117,7 @@ func (tc *TimeCache[K]) removeExpired(now time.Time) { // almostFull returns if the time cache will get full on the next insertion func (tc *TimeCache[K]) almostFull() bool { - nextEnd := tc.end - tc.increaseIndex(&nextEnd) - return nextEnd == tc.start + return (tc.end+1)%index(tc.size) == tc.start } func (tc *TimeCache[K]) regrowth() { diff --git a/consensus/propeller/types.go b/consensus/propeller/types.go index b516b9c00f..63b37c6093 100644 --- a/consensus/propeller/types.go +++ b/consensus/propeller/types.go @@ -11,23 +11,10 @@ import ( "fmt" "time" - "github.com/NethermindEth/juno/consensus/propeller/merkle" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/protocol" ) -// CommitteeID identifies a committee or logical broadcast group. Multiple committees -// can operate concurrently within the same engine, each with its own peer set. -type CommitteeID [32]byte - -// ShardIndex is the position of a shard within the erasure-coded output. -// Valid range is [0, N-2] where N is the total number of peers. -type ShardIndex uint32 - -// MessageRoot is the SHA-256 Merkle root over all shard leaves. It uniquely -// identifies a message and is signed by the publisher to bind authenticity. -type MessageRoot merkle.Hash - // Config holds tunable parameters for the propeller engine. Sensible defaults // are provided by DefaultConfig(). 
type Config struct { diff --git a/consensus/propeller/unit.go b/consensus/propeller/unit.go index 9d14d02a60..1c4e75bbaf 100644 --- a/consensus/propeller/unit.go +++ b/consensus/propeller/unit.go @@ -11,12 +11,24 @@ import ( "google.golang.org/protobuf/proto" ) -// The actual shard fragmen +// CommitteeID identifies a committee or logical broadcast group. Multiple committees +// can operate concurrently within the same engine, each with its own peer set. +type CommitteeID [32]byte + +// MessageRoot is the SHA-256 Merkle root over all shard leaves. It uniquely +// identifies a message and is signed by the publisher to bind authenticity. +type MessageRoot merkle.Hash + +// The actual shard fragment type Shard []byte -// Holds the shard fragments carried by the Propeller Unit +// Set of shard fragments held by the Propeller Unit type ShardData []Shard +// ShardIndex is the position of a shard within the erasure-coded output. +// Valid range is [0, N-2] where N is the total number of peers. +type ShardIndex uint32 + func (sd ShardData) MarshalProto() []byte { shards := make([]*pb.Shard, len(sd)) for i, s := range sd { diff --git a/consensus/propeller/unit_validator.go b/consensus/propeller/unit_validator.go index 20773d7af0..695e8beff9 100644 --- a/consensus/propeller/unit_validator.go +++ b/consensus/propeller/unit_validator.go @@ -22,6 +22,7 @@ type UnitValidator struct { publisherPubKey crypto.PubKey scheduler *Scheduler + // todo(rdr): `receivedShards` can surely be an boolean array (cheaper than map) // track of every shard index received receivedShards map[ShardIndex]struct{} // Once the validation is done it's stored here, subsequent validation @@ -29,11 +30,10 @@ type UnitValidator struct { verifiedSignature Signature } -// todo(rdr): maybe just pass the publisher? 
func NewValidator(publisher peer.ID, scheduler *Scheduler) UnitValidator { - pubKey, err := publisher.ExtractPublicKey() // for now we are assuming that extracting a publisher key is always successful // and done in constant time + pubKey, err := publisher.ExtractPublicKey() if err != nil { panic(err) } From c74d0a492ee0d824db27defed25aaf4d8b732de5 Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Mon, 11 May 2026 15:39:09 +0100 Subject: [PATCH 38/40] chore(consensus/propeller): minor improvements --- consensus/propeller/engine.go | 46 ++++++++++++++++++++------------ consensus/propeller/propeller.go | 10 +++++-- 2 files changed, 37 insertions(+), 19 deletions(-) diff --git a/consensus/propeller/engine.go b/consensus/propeller/engine.go index 881de0a74a..c3a0aa9aee 100644 --- a/consensus/propeller/engine.go +++ b/consensus/propeller/engine.go @@ -13,7 +13,7 @@ import ( type broadcastResult struct { units []Unit - err error + errCh chan<- error } // todo(rdr): using String until I find a better type @@ -61,7 +61,7 @@ func (unregisterCommittee) isCommand() {} type broadcast struct { committeeID CommitteeID msg []byte - errCh chan error + errCh chan<- error } func (broadcast) isCommand() {} @@ -217,7 +217,11 @@ func (e *Engine) unregisterCommittee(committeeID *CommitteeID) { // prepareUnitsForBroadcast creates Proppeller units asynchronously since it is a very expensive // operation. 
-func (e *Engine) prepareUnitsForBroadcast(committeeID *CommitteeID, data []byte) error { +func (e *Engine) prepareUnitsForBroadcast( + committeeID *CommitteeID, + data []byte, + errCh chan<- error, +) error { cs, ok := e.committees[*committeeID] if !ok { return fmt.Errorf("cannot broadcast to an unregistered committee: %s", committeeID) @@ -236,9 +240,16 @@ func (e *Engine) prepareUnitsForBroadcast(committeeID *CommitteeID, data []byte) scheduler.NumDataShards(), scheduler.NumCodingShards(), ) + if err != nil { + errCh <- err + return + } + + // todo(rdr): Why do we send this back to the engine.Run thread instead of processing + // it right here? e.unitsPrepared <- broadcastResult{ units: units, - err: err, + errCh: errCh, } }(e, cs.scheduler, *committeeID, data) @@ -246,7 +257,7 @@ func (e *Engine) prepareUnitsForBroadcast(committeeID *CommitteeID, data []byte) } // broacast receives Propeller units (built in `prepareBroadcast`) and sends them -func (e *Engine) broadcast(units []Unit) error { +func (e *Engine) broadcast(ctx context.Context, units []Unit) error { targetCommittee := units[0].CommitteeID cs, ok := e.committees[targetCommittee] @@ -263,6 +274,7 @@ func (e *Engine) broadcast(units []Unit) error { ) } + broadcastMessage // todo(rdr): I need to do the actual sending // I need to pass to the eventCh all the units that it should receive @@ -299,8 +311,7 @@ func (e *Engine) handleCommand(ctx context.Context, command engineCommand) { case *broadcast: // we might need to pass the error channel here so that the internal go-routine // can forward it correctly (assuming a per command error channel) - err := e.prepareUnitsForBroadcast(&cmd.committeeID, cmd.msg) - cmd.errCh <- err + e.prepareUnitsForBroadcast(&cmd.committeeID, cmd.msg, cmd.errCh) case *processUnit: e.processUnit(ctx, cmd.unit, cmd.sender) } @@ -321,20 +332,21 @@ func (e *Engine) Run(ctx context.Context) error { e.handleCommand(ctx, cmd) case broadcastResult := <-e.unitsPrepared: - if 
broadcastResult.err != nil { - e.log.Error("couldn't prepare units", zap.Error(broadcastResult.err)) - // todo(rdr): send error to service, probably don't log it - } - err := e.broadcast(broadcastResult.units) - if err != nil { - // log it? - // send error to service? - } + err := e.broadcast(ctx, broadcastResult.units) + broadcastResult.errCh <- err } } } -func (e *Engine) Broadcast(msg []byte) { +func (e *Engine) Broadcast(committeeID *CommitteeID, msg []byte) error { + // todo(rdr): check how costly is this? Is there a better way than creating a channel + errCh := make(chan error) + e.cmdCh <- &broadcast{ + committeeID: *committeeID, + msg: msg, + errCh: errCh, + } + return <-errCh } func (e *Engine) RegisterCommittee( diff --git a/consensus/propeller/propeller.go b/consensus/propeller/propeller.go index 5ffa38aac8..590903f15e 100644 --- a/consensus/propeller/propeller.go +++ b/consensus/propeller/propeller.go @@ -114,7 +114,7 @@ func (s *propellerService) broadcastUnit(ctx context.Context, unit *Unit, peers err = s.sendToPeer(ctx, p, data) if err != nil { // Why would there be any error - // What should we do in this case + // Based on the error type, what should we do panic(err) } } @@ -141,6 +141,8 @@ func (s *propellerService) handleEvent(ctx context.Context, event Event) { s.messageRecv <- event.message case *broadcastUnit: s.broadcastUnit(ctx, event.unit, event.peers) + case *broadcastMessage: + s.broadcastMessage(ctx, event.unit) } } @@ -171,7 +173,11 @@ func (s *propellerService) Run(ctx context.Context) error { } } -func (s *propellerService) Broadcast(msg []byte) { +// todo(rdr): I am not sure of the Propeller <-> Engine separation... 
+// TBD how it looks like or if there should be any in the future + +func (s *propellerService) Broadcast(committeeID *CommitteeID, msg []byte) error { + return s.engine.Broadcast(committeeID, msg) } func (s *propellerService) Recv() <-chan []byte { From a10ad1911500d37546b8f96278d169d01f7b65fc Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Mon, 11 May 2026 15:39:42 +0100 Subject: [PATCH 39/40] chore(starknet-p2p-spcs): update --- buf.gen.yaml | 2 +- .../p2p/proto/capabilities/capabilities.pb.go | 23 ++++---- .../proto/consensus/consensus/consensus.pb.go | 59 ++++++++++--------- 3 files changed, 45 insertions(+), 39 deletions(-) diff --git a/buf.gen.yaml b/buf.gen.yaml index 3ba58f14be..6608c4e185 100644 --- a/buf.gen.yaml +++ b/buf.gen.yaml @@ -7,4 +7,4 @@ plugins: inputs: - git_repo: https://github.com/starknet-io/starknet-p2p-specs.git branch: bcfa353a169c859e4d5d97757caccbe76f75bc06 # Latest commit as of 2025 May 6th - depth: 1 \ No newline at end of file + depth: 1 diff --git a/starknet-p2p-specs/p2p/proto/capabilities/capabilities.pb.go b/starknet-p2p-specs/p2p/proto/capabilities/capabilities.pb.go index d1ae793cea..374497c461 100644 --- a/starknet-p2p-specs/p2p/proto/capabilities/capabilities.pb.go +++ b/starknet-p2p-specs/p2p/proto/capabilities/capabilities.pb.go @@ -7,11 +7,12 @@ package capabilities import ( - protoreflect "google.golang.org/protobuf/reflect/protoreflect" - protoimpl "google.golang.org/protobuf/runtime/protoimpl" reflect "reflect" sync "sync" unsafe "unsafe" + + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" ) const ( @@ -333,14 +334,16 @@ func file_p2p_proto_capabilities_proto_rawDescGZIP() []byte { return file_p2p_proto_capabilities_proto_rawDescData } -var file_p2p_proto_capabilities_proto_msgTypes = make([]protoimpl.MessageInfo, 5) -var file_p2p_proto_capabilities_proto_goTypes = []any{ - (*SyncCapability)(nil), // 0: SyncCapability - 
(*SyncCapability_ArchiveStrategy)(nil), // 1: SyncCapability.ArchiveStrategy - (*SyncCapability_L1PruneStrategy)(nil), // 2: SyncCapability.L1PruneStrategy - (*SyncCapability_ConstSizePruneStrategy)(nil), // 3: SyncCapability.ConstSizePruneStrategy - (*SyncCapability_StaticPruneStrategy)(nil), // 4: SyncCapability.StaticPruneStrategy -} +var ( + file_p2p_proto_capabilities_proto_msgTypes = make([]protoimpl.MessageInfo, 5) + file_p2p_proto_capabilities_proto_goTypes = []any{ + (*SyncCapability)(nil), // 0: SyncCapability + (*SyncCapability_ArchiveStrategy)(nil), // 1: SyncCapability.ArchiveStrategy + (*SyncCapability_L1PruneStrategy)(nil), // 2: SyncCapability.L1PruneStrategy + (*SyncCapability_ConstSizePruneStrategy)(nil), // 3: SyncCapability.ConstSizePruneStrategy + (*SyncCapability_StaticPruneStrategy)(nil), // 4: SyncCapability.StaticPruneStrategy + } +) var file_p2p_proto_capabilities_proto_depIdxs = []int32{ 1, // 0: SyncCapability.archive_strategy:type_name -> SyncCapability.ArchiveStrategy 2, // 1: SyncCapability.l1_prune_strategy:type_name -> SyncCapability.L1PruneStrategy diff --git a/starknet-p2p-specs/p2p/proto/consensus/consensus/consensus.pb.go b/starknet-p2p-specs/p2p/proto/consensus/consensus/consensus.pb.go index 6a13e77e79..10be303364 100644 --- a/starknet-p2p-specs/p2p/proto/consensus/consensus/consensus.pb.go +++ b/starknet-p2p-specs/p2p/proto/consensus/consensus/consensus.pb.go @@ -7,13 +7,14 @@ package consensus import ( + reflect "reflect" + sync "sync" + unsafe "unsafe" + common "github.com/starknet-io/starknet-p2p-specs/p2p/proto/common" transaction "github.com/starknet-io/starknet-p2p-specs/p2p/proto/transaction" protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" - reflect "reflect" - sync "sync" - unsafe "unsafe" ) const ( @@ -1109,31 +1110,33 @@ func file_p2p_proto_consensus_consensus_proto_rawDescGZIP() []byte { return 
file_p2p_proto_consensus_consensus_proto_rawDescData } -var file_p2p_proto_consensus_consensus_proto_enumTypes = make([]protoimpl.EnumInfo, 1) -var file_p2p_proto_consensus_consensus_proto_msgTypes = make([]protoimpl.MessageInfo, 10) -var file_p2p_proto_consensus_consensus_proto_goTypes = []any{ - (Vote_VoteType)(0), // 0: Vote.VoteType - (*ConsensusTransaction)(nil), // 1: ConsensusTransaction - (*Vote)(nil), // 2: Vote - (*ConsensusStreamId)(nil), // 3: ConsensusStreamId - (*ProposalPart)(nil), // 4: ProposalPart - (*ProposalInit)(nil), // 5: ProposalInit - (*ProposalFin)(nil), // 6: ProposalFin - (*TransactionBatch)(nil), // 7: TransactionBatch - (*StreamMessage)(nil), // 8: StreamMessage - (*ProposalCommitment)(nil), // 9: ProposalCommitment - (*BlockInfo)(nil), // 10: BlockInfo - (*transaction.DeclareV3WithClass)(nil), // 11: DeclareV3WithClass - (*transaction.DeployAccountV3)(nil), // 12: DeployAccountV3 - (*transaction.InvokeV3)(nil), // 13: InvokeV3 - (*transaction.L1HandlerV0)(nil), // 14: L1HandlerV0 - (*common.Hash)(nil), // 15: Hash - (*common.Address)(nil), // 16: Address - (*common.Fin)(nil), // 17: Fin - (*common.Felt252)(nil), // 18: Felt252 - (*common.Uint128)(nil), // 19: Uint128 - (common.L1DataAvailabilityMode)(0), // 20: L1DataAvailabilityMode -} +var ( + file_p2p_proto_consensus_consensus_proto_enumTypes = make([]protoimpl.EnumInfo, 1) + file_p2p_proto_consensus_consensus_proto_msgTypes = make([]protoimpl.MessageInfo, 10) + file_p2p_proto_consensus_consensus_proto_goTypes = []any{ + (Vote_VoteType)(0), // 0: Vote.VoteType + (*ConsensusTransaction)(nil), // 1: ConsensusTransaction + (*Vote)(nil), // 2: Vote + (*ConsensusStreamId)(nil), // 3: ConsensusStreamId + (*ProposalPart)(nil), // 4: ProposalPart + (*ProposalInit)(nil), // 5: ProposalInit + (*ProposalFin)(nil), // 6: ProposalFin + (*TransactionBatch)(nil), // 7: TransactionBatch + (*StreamMessage)(nil), // 8: StreamMessage + (*ProposalCommitment)(nil), // 9: ProposalCommitment + 
(*BlockInfo)(nil), // 10: BlockInfo + (*transaction.DeclareV3WithClass)(nil), // 11: DeclareV3WithClass + (*transaction.DeployAccountV3)(nil), // 12: DeployAccountV3 + (*transaction.InvokeV3)(nil), // 13: InvokeV3 + (*transaction.L1HandlerV0)(nil), // 14: L1HandlerV0 + (*common.Hash)(nil), // 15: Hash + (*common.Address)(nil), // 16: Address + (*common.Fin)(nil), // 17: Fin + (*common.Felt252)(nil), // 18: Felt252 + (*common.Uint128)(nil), // 19: Uint128 + (common.L1DataAvailabilityMode)(0), // 20: L1DataAvailabilityMode + } +) var file_p2p_proto_consensus_consensus_proto_depIdxs = []int32{ 11, // 0: ConsensusTransaction.declare_v3:type_name -> DeclareV3WithClass 12, // 1: ConsensusTransaction.deploy_account_v3:type_name -> DeployAccountV3 From 99a155aaa856a0b4d5785cf5fe5279d10ed6ddbd Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Fri, 15 May 2026 14:34:46 +0100 Subject: [PATCH 40/40] fix(consensus, p2p): rebase --- consensus/propeller/engine.go | 46 ++++++++--------- consensus/propeller/processor.go | 30 +++++++---- consensus/propeller/propeller.go | 24 +++++---- consensus/propeller/sharding.go | 2 +- p2p/server/server.go | 88 +++++++++++++------------------- 5 files changed, 91 insertions(+), 99 deletions(-) diff --git a/consensus/propeller/engine.go b/consensus/propeller/engine.go index c3a0aa9aee..0ffcfff119 100644 --- a/consensus/propeller/engine.go +++ b/consensus/propeller/engine.go @@ -5,7 +5,7 @@ import ( "fmt" "time" - "github.com/NethermindEth/juno/utils" + "github.com/NethermindEth/juno/utils/log" "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/peer" "go.uber.org/zap" @@ -18,8 +18,8 @@ type broadcastResult struct { // todo(rdr): using String until I find a better type type StakerID struct { - peerID peer.ID - pubKey crypto.PubKey + peerID peer.ID //nolint:unused // populated once committee key wiring lands. + pubKey crypto.PubKey //nolint:unused // populated once committee key wiring lands. 
} // Holds the state for a Committee ID: @@ -84,7 +84,7 @@ type Engine struct { localPeer peer.ID config Config - log utils.StructuredLogger + logger log.StructuredLogger // processor handles validates and process all the messages received by other peers processor *Processor @@ -126,7 +126,7 @@ type Engine struct { func NewEngine( privKey crypto.PrivKey, config *Config, - log utils.StructuredLogger, + logger log.StructuredLogger, ) (*Engine, chan<- engineCommand, <-chan Event) { localPeerID, err := peer.IDFromPrivateKey(privKey) if err != nil { @@ -142,7 +142,7 @@ func NewEngine( localPeer: localPeerID, privKey: privKey, config: *config, - log: log, + logger: logger, processor: processor, committees: make(map[CommitteeID]*committeeState), cmdCh: cmdCh, @@ -153,6 +153,8 @@ func NewEngine( } // registerCommittee creates the schedule and encoder for a new channel. +// +//nolint:unparam // peersKeys is part of the public registration API; wiring is still pending. func (e *Engine) registerCommittee( committeeID *CommitteeID, peers []PeerCommittee, @@ -161,7 +163,7 @@ func (e *Engine) registerCommittee( // todo(rdr): Why re-registration should be ignored, // as far as I understand, it shouldn't happen :think: if _, ok := e.committees[*committeeID]; ok { - e.log.Warn( + e.logger.Warn( "committee already registered, will ignore re-registration attempt", // todo(rdr): give a proper string repr zap.Any("committee id", committeeID), @@ -169,16 +171,6 @@ func (e *Engine) registerCommittee( return nil } - // stakerIDs := make([]StakerID, len(peersKeys)) - // for i := range peersKeys { - // if peersKeys[i] != nil { - // stakerIDs[i] = *peersKeys[i] - // } else { - // // todo(rdr): re-check this flow once implementation is complete - // panic("received nil key, they shoudln't be nil") - // } - // } - schedule, err := NewScheduler(e.localPeer, peers) if err != nil { return fmt.Errorf("couldn't register a new committee: %w", err) @@ -190,7 +182,8 @@ func (e *Engine) 
registerCommittee( peerKeys: nil, } - e.log.Info("registered new committee", + e.logger.Info( + "registered new committee", // todo(rdr): give a proper string representation zap.Any("committeeID", committeeID), zap.Int("peers", len(peers)), @@ -209,7 +202,8 @@ func (e *Engine) unregisterCommittee(committeeID *CommitteeID) { // or will they shut down on their own eventually // better to pass a context with cancel? - e.log.Info("unregistered propeller committee", + e.logger.Info( + "unregistered propeller committee", // todo(rdr): give a proper string representation zap.Any("committee id", committeeID), ) @@ -224,7 +218,7 @@ func (e *Engine) prepareUnitsForBroadcast( ) error { cs, ok := e.committees[*committeeID] if !ok { - return fmt.Errorf("cannot broadcast to an unregistered committee: %s", committeeID) + return fmt.Errorf("cannot broadcast to an unregistered committee: %v", committeeID) } // todo(rdr): unsure if this approach of passing arguments to the go routine makes sense @@ -257,6 +251,8 @@ func (e *Engine) prepareUnitsForBroadcast( } // broacast receives Propeller units (built in `prepareBroadcast`) and sends them +// +//nolint:unparam // ctx will be used once the actual sending is wired up. func (e *Engine) broadcast(ctx context.Context, units []Unit) error { targetCommittee := units[0].CommitteeID @@ -274,7 +270,6 @@ func (e *Engine) broadcast(ctx context.Context, units []Unit) error { ) } - broadcastMessage // todo(rdr): I need to do the actual sending // I need to pass to the eventCh all the units that it should receive @@ -287,7 +282,8 @@ func (e *Engine) processUnit(ctx context.Context, unit *Unit, sender peer.ID) { cs, ok := e.committees[unit.CommitteeID] if !ok { // note(rdr): maybe debug? 
- e.log.Warn("received key for unregistered committee, dropping", + e.logger.Warn( + "received key for unregistered committee, dropping", // todo(rdr): give a proper string representation zap.Any("committee id", unit.CommitteeID), ) @@ -296,7 +292,7 @@ func (e *Engine) processUnit(ctx context.Context, unit *Unit, sender peer.ID) { err := e.processor.ProcessMessage(ctx, unit, sender, cs.scheduler) if err != nil { - e.log.Error("cannot process incoming unit", zap.Error(err)) + e.logger.Error("cannot process incoming unit", zap.Error(err)) } } @@ -311,7 +307,9 @@ func (e *Engine) handleCommand(ctx context.Context, command engineCommand) { case *broadcast: // we might need to pass the error channel here so that the internal go-routine // can forward it correctly (assuming a per command error channel) - e.prepareUnitsForBroadcast(&cmd.committeeID, cmd.msg, cmd.errCh) + if err := e.prepareUnitsForBroadcast(&cmd.committeeID, cmd.msg, cmd.errCh); err != nil { + cmd.errCh <- err + } case *processUnit: e.processUnit(ctx, cmd.unit, cmd.sender) } diff --git a/consensus/propeller/processor.go b/consensus/propeller/processor.go index a7500df7ec..713d99237e 100644 --- a/consensus/propeller/processor.go +++ b/consensus/propeller/processor.go @@ -9,7 +9,7 @@ import ( "time" "github.com/NethermindEth/juno/consensus/propeller/timecache" - "github.com/NethermindEth/juno/utils" + "github.com/NethermindEth/juno/utils/log" "github.com/libp2p/go-libp2p/core/peer" "go.uber.org/zap" ) @@ -109,7 +109,7 @@ func (s *subprocessor) beforeMessageBuiltStage(ctx context.Context) ( // will arrive. Although, that will mean we also need to validate any of those extra messages. // The question is then: Do the cost of validating missing messages reduces greatly the cost // of recovering them? Cases to consider: - // - Perfect network condition: a lot of bandwith and everybody is good. Does receiving all + // - Perfect network condition: a lot of bandwidth and everybody is good. 
Does receiving all // all the missing messages and validating them is cheaper than recovering them? What's the // performance difference? <- Write benchmark // - Bad network conditions: does the time waiting but receiving no messages will @@ -182,6 +182,7 @@ func (s *subprocessor) beforeMessageBuiltStage(ctx context.Context) ( return unitCount, fullMessage, nil } +//nolint:unparam // message will be used once the receive-stage forwarding is wired up. func (s *subprocessor) beforeMessageReceivedStage( ctx context.Context, unitCount int, @@ -286,7 +287,7 @@ type Processor struct { finalized *timecache.TimeCache[messageKey] subProcessors map[messageKey]chan<- unitWithSender - // channel through wich subprocessors signal they have finalized execution + // channel through which subprocessors signal they have finalized execution subProcessorsFinalized chan finalizedSubprocessor // channel through which subprocessor share units that failed validation invalidUnits chan invalidUnit @@ -301,15 +302,19 @@ type Processor struct { localPeer peer.ID timeout time.Duration concurrentTasksBounds concurrentTasksBounds - log utils.StructuredLogger + logger log.StructuredLogger } +// finalizedCacheSize bounds the number of recently-finalized message keys retained +// to avoid re-processing units belonging to messages already completed. 
+const finalizedCacheSize = 2048 + func NewProcessor(localPeer peer.ID, config *Config) (*Processor, <-chan Event) { timeout := config.StaleMessageTimeout processingEvents := make(chan Event) return &Processor{ - finalized: timecache.New[messageKey](2048, timeout), + finalized: timecache.New[messageKey](finalizedCacheSize, timeout), subProcessors: make(map[messageKey]chan<- unitWithSender), subProcessorsFinalized: make(chan finalizedSubprocessor), @@ -338,24 +343,26 @@ func (p *Processor) Run(ctx context.Context) { return case finalizedSubP := <-p.subProcessorsFinalized: if finalizedSubP.error != nil { - p.log.Error("subprocessor finalized with error", + p.logger.Error( + "subprocessor finalized with error", zap.String("message key", finalizedSubP.messageKey.String()), zap.Error(finalizedSubP.error), ) } else { - p.log.Info("subprocessor finalized", + p.logger.Info( + "subprocessor finalized", zap.String("message key", finalizedSubP.messageKey.String()), ) } p.finalize(&finalizedSubP.messageKey) case invalidUnit := <-p.invalidUnits: - p.log.Error("unit validation failed", + p.logger.Error( + "unit validation failed", zap.String("message key", invalidUnit.messageKey.String()), zap.Error(invalidUnit.error), ) // todo(rdr): should we mark sender to penalize? - } } } @@ -383,7 +390,7 @@ func (p *Processor) ProcessMessage( // a single one. unitChan, err := p.subprocessorChannel(ctx, &key, scheduler) if err != nil { - fmt.Errorf("couldn't get processor channel for key: %w", err) + return fmt.Errorf("couldn't get processor channel for key: %w", err) } select { @@ -423,7 +430,7 @@ func (p *Processor) createSubprocessor( p.subProcessors[*key] = unitChan // launch subprocessor - ctxWithTimeout, _ := context.WithTimeout(ctx, p.timeout) + ctxWithTimeout, cancel := context.WithTimeout(ctx, p.timeout) // todo(rdr): passing to avoid closures. Does it makes sense? // need to learn more how closures work in Go if it makes any difference // in performance. 
@@ -435,6 +442,7 @@ func (p *Processor) createSubprocessor( localShardIndex ShardIndex, unitChan <-chan unitWithSender, ) { + defer cancel() subProcessor := newSubprocessor( key.Publisher, scheduler, p.localPeer, localShardIndex, unitChan, p.invalidUnits, ) diff --git a/consensus/propeller/propeller.go b/consensus/propeller/propeller.go index 590903f15e..8e38c89dbf 100644 --- a/consensus/propeller/propeller.go +++ b/consensus/propeller/propeller.go @@ -6,7 +6,7 @@ import ( "io" pb "github.com/NethermindEth/juno/consensus/propeller/proto" - "github.com/NethermindEth/juno/utils" + "github.com/NethermindEth/juno/utils/log" "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/network" @@ -24,7 +24,7 @@ type propellerService struct { host host.Host // Internal config config Config - log utils.Logger + logger log.Logger // Propeller communication engine *Engine cmdCh chan<- engineCommand @@ -37,12 +37,12 @@ func New( host host.Host, privKey crypto.PrivKey, config *Config, - log utils.Logger, + logger log.Logger, ) Service { engine, cmdCh, eventsCh := NewEngine( privKey, config, - log, + logger, ) return &propellerService{ @@ -51,7 +51,7 @@ func New( cmdCh: cmdCh, eventsCh: eventsCh, config: *config, - log: log, + logger: logger, } } @@ -65,7 +65,8 @@ func (s *propellerService) receiveUnits(stream network.Stream) { var buf bytes.Buffer _, err := buf.ReadFrom(reader) if err != nil { - s.log.Debug("error reading inbound propeller stream", + s.logger.Debug( + "error reading inbound propeller stream", zap.Stringer("peer", sender), zap.Error(err), ) @@ -74,7 +75,8 @@ func (s *propellerService) receiveUnits(stream network.Stream) { var batch pb.PropellerUnitBatch err = proto.Unmarshal(buf.Bytes(), &batch) if err != nil { - s.log.Debug("error unmarshalling propeller batch", + s.logger.Debug( + "error unmarshalling propeller batch", zap.Stringer("peer", sender), zap.Error(err), ) @@ -83,7 +85,7 @@ func (s 
*propellerService) receiveUnits(stream network.Stream) { for _, protoUnit := range batch.GetBatch() { unit, err := UnitFromProto(protoUnit) if err != nil { - s.log.Warn("received invalid unit", zap.Error(err)) + s.logger.Warn("received invalid unit", zap.Error(err)) // todo(rdr): penalize sender? // If we do it here then it means it shouldn't be handled at // subP or Processor level, and all should be handled here, @@ -131,7 +133,7 @@ func (s *propellerService) sendToPeer(ctx context.Context, p peer.ID, data []byt return err } -func (s *propellerService) broadcastMessage(ctx context.Context, msg []byte) { +func (s *propellerService) broadcastMessage(ctx context.Context, units []Unit) { } func (s *propellerService) handleEvent(ctx context.Context, event Event) { @@ -151,10 +153,10 @@ func (s *propellerService) Run(ctx context.Context) error { go func() { err := s.engine.Run(ctx) if err != nil { - s.log.Error("shutting down propeller engine", zap.Error(err)) + s.logger.Error("shutting down propeller engine", zap.Error(err)) return } - s.log.Info("shutting down propeller engine") + s.logger.Info("shutting down propeller engine") }() // Subscribe to receiving certain topics diff --git a/consensus/propeller/sharding.go b/consensus/propeller/sharding.go index 43cbefbd6a..e2e555b6ae 100644 --- a/consensus/propeller/sharding.go +++ b/consensus/propeller/sharding.go @@ -101,7 +101,7 @@ func ConstructMessageFromUnits( if messageRoot != expectedRoot { // todo(rdr): probably need to write string methods for the MessageRoot type return nil, nil, merkle.Proof{}, fmt.Errorf( - "wrong message root hash. Expected %s but got %s", + "wrong message root hash. 
Expected %v but got %v", &expectedRoot, &messageRoot, ) diff --git a/p2p/server/server.go b/p2p/server/server.go index 148e8d8c7b..6c6301e873 100644 --- a/p2p/server/server.go +++ b/p2p/server/server.go @@ -220,40 +220,42 @@ func (h *Server) onHeadersRequest( req.Iteration, finMsg, func(it blockDataAccessor) (proto.Message, error) { - blockHeader, err := it.Header() - if err != nil { - return nil, err - } - - h.logger.Debug("Created Header Iterator", zap.Uint64("blockNumber", blockHeader.Number)) - - stateUpdate, err := h.bcReader.StateUpdateByNumber(blockHeader.Number) - if err != nil { - return nil, err - } - - blockVer, err := core.ParseBlockVersion(blockHeader.ProtocolVersion) - if err != nil { - return nil, err - } - - var commitments *core.BlockCommitments - if blockVer.LessThan(core.Ver0_13_2) { - block, err := it.Block() + blockHeader, err := it.Header() if err != nil { return nil, err } - // TODO: switch to core.NewTrieBackend once the legacy trie and state are removed. - _, commitments, err = core.Post0132Hash(block, stateUpdate.StateDiff, core.DeprecatedTrieBackend) + + h.logger.Debug("Created Header Iterator", zap.Uint64("blockNumber", blockHeader.Number)) + + stateUpdate, err := h.bcReader.StateUpdateByNumber(blockHeader.Number) if err != nil { return nil, err } - } else { - commitments, err = h.bcReader.BlockCommitmentsByNumber(blockHeader.Number) + + blockVer, err := core.ParseBlockVersion(blockHeader.ProtocolVersion) if err != nil { return nil, err } - } + + var commitments *core.BlockCommitments + if blockVer.LessThan(core.Ver0_13_2) { + block, err := it.Block() + if err != nil { + return nil, err + } + // TODO: switch to core.NewTrieBackend once the legacy trie and state are removed. 
+ _, commitments, err = core.Post0132Hash( + block, stateUpdate.StateDiff, core.DeprecatedTrieBackend, + ) + if err != nil { + return nil, err + } + } else { + commitments, err = h.bcReader.BlockCommitmentsByNumber(blockHeader.Number) + if err != nil { + return nil, err + } + } stateDiffCommitment := stateUpdate.StateDiff.Hash() return &header.BlockHeadersResponse{ @@ -262,10 +264,12 @@ func (h *Server) onHeadersRequest( blockHeader, commitments, &stateDiffCommitment, - stateUpdate.StateDiff.Length()), + stateUpdate.StateDiff.Length(), + ), }, }, nil - }) + }, + ) } func (h *Server) onEventsRequest( @@ -444,7 +448,8 @@ func (h *Server) onStateDiffRequest( } return responses, nil - }) + }, + ) } func (h *Server) onClassesRequest( @@ -453,27 +458,6 @@ func (h *Server) onClassesRequest( finMsg := &syncclass.ClassesResponse{ ClassMessage: &syncclass.ClassesResponse_Fin{}, } -<<<<<<< HEAD - return h.processIterationRequestMulti(req.Iteration, finMsg, func(it blockDataAccessor) ([]proto.Message, error) { - block, err := it.Block() - if err != nil { - return nil, err - } - blockNumber := block.Number - - stateUpdate, err := h.bcReader.StateUpdateByNumber(blockNumber) - if err != nil { - return nil, err - } - - stateReader, closer, err := h.bcReader.StateAtBlockNumber(blockNumber) - if err != nil { - return nil, err - } - defer func() { - if closeErr := closer(); closeErr != nil { - h.logger.Error("Failed to close state reader", zap.Error(closeErr)) -======= return h.processIterationRequestMulti( req.Iteration, finMsg, @@ -481,7 +465,6 @@ func (h *Server) onClassesRequest( block, err := it.Block() if err != nil { return nil, err ->>>>>>> 7470f4d1f (refactor(p2p): apply style correctness) } blockNumber := block.Number @@ -496,7 +479,7 @@ func (h *Server) onClassesRequest( } defer func() { if closeErr := closer(); closeErr != nil { - h.log.Error("Failed to close state reader", zap.Error(closeErr)) + h.logger.Error("Failed to close state reader", zap.Error(closeErr)) } }() @@ 
-529,7 +512,8 @@ func (h *Server) onClassesRequest( } return responses, nil - }) + }, + ) } // blockDataAccessor provides access to either entire block or header