diff --git a/Cargo.lock b/Cargo.lock index 84d6c1a8..5c1e364e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -841,47 +841,20 @@ dependencies = [ "cc", ] -[[package]] -name = "axum" -version = "0.7.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f" -dependencies = [ - "async-trait", - "axum-core 0.4.5", - "bytes", - "futures-util", - "http 1.4.0", - "http-body", - "http-body-util", - "itoa", - "matchit 0.7.3", - "memchr", - "mime", - "percent-encoding", - "pin-project-lite", - "rustversion", - "serde", - "sync_wrapper", - "tower 0.5.3", - "tower-layer", - "tower-service", -] - [[package]] name = "axum" version = "0.8.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b52af3cb4058c895d37317bb27508dccc8e5f2d39454016b297bf4a400597b8" dependencies = [ - "axum-core 0.5.6", + "axum-core", "bytes", "futures-util", "http 1.4.0", "http-body", "http-body-util", "itoa", - "matchit 0.8.4", + "matchit", "memchr", "mime", "percent-encoding", @@ -893,26 +866,6 @@ dependencies = [ "tower-service", ] -[[package]] -name = "axum-core" -version = "0.4.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09f2bd6146b97ae3359fa0cc6d6b376d9539582c7b4220f041a33ec24c226199" -dependencies = [ - "async-trait", - "bytes", - "futures-util", - "http 1.4.0", - "http-body", - "http-body-util", - "mime", - "pin-project-lite", - "rustversion", - "sync_wrapper", - "tower-layer", - "tower-service", -] - [[package]] name = "axum-core" version = "0.5.6" @@ -1263,7 +1216,6 @@ dependencies = [ "iana-time-zone", "js-sys", "num-traits", - "serde", "wasm-bindgen", "windows-link", ] @@ -1438,16 +1390,6 @@ version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3d52eff69cd5e647efe296129160853a42795992097e8af39800e1060caeea9b" -[[package]] -name = "core-foundation" -version = "0.9.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" -dependencies = [ - "core-foundation-sys", - "libc", -] - [[package]] name = "core-foundation" version = "0.10.1" @@ -2024,15 +1966,6 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" -[[package]] -name = "encoding_rs" -version = "0.8.35" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75030f3c4f45dafd7586dd6780965a8c7e8e285a5ecb86713e63a79c5b2766f3" -dependencies = [ - "cfg-if", -] - [[package]] name = "enum-iterator" version = "1.5.0" @@ -2239,12 +2172,6 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "059c31d7d36c43fe39d89e55711858b4da8be7eb6dabac23c7289b1a19489406" -[[package]] -name = "fixedbitset" -version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" - [[package]] name = "fixedbitset" version = "0.5.7" @@ -2544,7 +2471,7 @@ dependencies = [ "futures-core", "futures-sink", "http 1.4.0", - "indexmap 2.13.0", + "indexmap", "slab", "tokio", "tokio-util", @@ -2637,31 +2564,17 @@ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" [[package]] name = "helius-laserstream" version = "0.1.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0df82f36800fc2faa33fc9106d9c47da75a19d4a4275c3784bd3cf92e6ce99bc" dependencies = [ "async-stream", - "bs58", - "chrono", "futures", "futures-channel", "futures-util", "laserstream-core-client", "laserstream-core-proto", - "prost 0.12.6", - "prost-types 0.12.6", - "rand 0.8.5", - "reqwest", "serde", - "serde_json", - "sha2 0.10.9", "thiserror 1.0.69", "tokio", - "tokio-stream", - "tonic 0.12.3", - "tonic-build 0.10.2", "tracing", - "tracing-subscriber", "url", "uuid", ] @@ -2826,22 +2739,6 @@ dependencies = [ "tower-service", ] -[[package]] -name = "hyper-tls" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" -dependencies = [ - "bytes", - "http-body-util", - "hyper", - "hyper-util", - "native-tls", - "tokio", - "tokio-native-tls", - "tower-service", -] - [[package]] name = "hyper-util" version = "0.1.20" @@ -2860,11 +2757,9 @@ dependencies = [ "percent-encoding", "pin-project-lite", "socket2 0.6.3", - "system-configuration", "tokio", "tower-service", "tracing", - "windows-registry", ] [[package]] @@ -3046,16 +2941,6 @@ dependencies = [ "quote", ] -[[package]] -name = "indexmap" -version = "1.9.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" -dependencies = [ - "autocfg", - "hashbrown 0.12.3", -] - [[package]] name = "indexmap" version = "2.13.0" @@ -3324,7 +3209,7 @@ dependencies = [ "futures", "laserstream-core-proto", "thiserror 1.0.69", - "tonic 0.14.5", + "tonic", "tonic-health", ] @@ -3336,8 +3221,8 @@ checksum = "a12d5ab2767a78aea87aeee99c7e62a241319a7976711e3f02f8b33844e33c03" dependencies = [ "anyhow", "bincode", - "prost 0.14.3", - "prost-types 0.14.3", + "prost", + "prost-types", "protobuf-src", "solana-account", "solana-account-decoder", @@ -3350,8 +3235,8 @@ dependencies = [ "solana-transaction-context", "solana-transaction-error", "solana-transaction-status", - "tonic 0.14.5", - "tonic-build 0.14.5", + "tonic", + "tonic-build", "tonic-prost", "tonic-prost-build", ] @@ -3534,12 +3419,6 @@ dependencies = [ "regex-automata", ] -[[package]] -name = "matchit" -version = "0.7.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" - [[package]] name = "matchit" version = "0.8.4" @@ -3678,23 +3557,6 @@ version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1d87ecb2933e8aeadb3e3a02b828fed80a7528047e68b4f424523a0981a3a084" -[[package]] -name = "native-tls" -version = "0.2.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "465500e14ea162429d264d44189adc38b199b62b1c21eea9f69e4b73cb03bbf2" -dependencies = [ - "libc", - "log", - "openssl", - "openssl-probe", - "openssl-sys", - "schannel", - "security-framework", - "security-framework-sys", - "tempfile", -] - [[package]] name = "nix" version = "0.30.1" @@ -4085,25 +3947,15 @@ dependencies = [ "num", ] -[[package]] -name = "petgraph" -version = "0.6.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db" -dependencies = [ - "fixedbitset 0.4.2", - "indexmap 2.13.0", -] - [[package]] name = "petgraph" version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8701b58ea97060d5e5b155d383a69952a60943f0e6dfe30b04c287beb0b27455" dependencies = [ - "fixedbitset 0.5.7", + "fixedbitset", "hashbrown 0.15.5", - "indexmap 2.13.0", + "indexmap", ] [[package]] @@ -4318,26 +4170,6 @@ dependencies = [ "unicode-ident", ] -[[package]] -name = "prost" -version = "0.12.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "deb1435c188b76130da55f17a466d252ff7b1418b2ad3e037d127b94e3411f29" -dependencies = [ - "bytes", - "prost-derive 0.12.6", -] - -[[package]] -name = "prost" -version = "0.13.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2796faa41db3ec313a31f7624d9286acf277b52de526150b7e69f3debf891ee5" -dependencies = [ - "bytes", - "prost-derive 0.13.5", -] - [[package]] name = "prost" version = "0.14.3" @@ -4345,28 +4177,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d2ea70524a2f82d518bce41317d0fae74151505651af45faf1ffbd6fd33f0568" dependencies = [ "bytes", - "prost-derive 0.14.3", -] - -[[package]] -name = "prost-build" -version = "0.12.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22505a5c94da8e3b7c2996394d1c933236c4d743e81a410bcca4e6989fc066a4" -dependencies = [ - "bytes", - "heck 0.5.0", - "itertools 0.12.1", - "log", - "multimap", - "once_cell", - "petgraph 0.6.5", - "prettyplease", - "prost 0.12.6", - "prost-types 0.12.6", - "regex", - "syn 2.0.117", - "tempfile", + "prost-derive", ] [[package]] @@ -4379,10 +4190,10 @@ dependencies = [ "itertools 0.14.0", "log", "multimap", - "petgraph 0.8.3", + "petgraph", "prettyplease", - "prost 0.14.3", - "prost-types 0.14.3", + "prost", + "prost-types", "pulldown-cmark", "pulldown-cmark-to-cmark", "regex", @@ -4390,32 +4201,6 @@ dependencies = [ "tempfile", ] -[[package]] -name = "prost-derive" -version = "0.12.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81bddcdb20abf9501610992b6759a4c888aef7d1a7247ef75e2404275ac24af1" -dependencies = [ - "anyhow", - "itertools 0.12.1", - "proc-macro2", - "quote", - "syn 2.0.117", -] - -[[package]] -name = "prost-derive" -version = "0.13.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a56d757972c98b346a9b766e3f02746cde6dd1cd1d1d563472929fdd74bec4d" -dependencies = [ - "anyhow", - "itertools 0.14.0", - "proc-macro2", - "quote", - "syn 2.0.117", -] - [[package]] name = "prost-derive" version = "0.14.3" @@ -4429,22 +4214,13 @@ dependencies = [ "syn 2.0.117", ] -[[package]] -name = "prost-types" -version = "0.12.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9091c90b0a32608e984ff2fa4091273cbdd755d54935c51d520887f4a1dbd5b0" -dependencies = [ - "prost 0.12.6", -] - [[package]] name = "prost-types" version = "0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8991c4cbdb8bc5b11f0b074ffe286c30e523de90fee5ba8132f1399f23cb3dd7" dependencies = [ - "prost 0.14.3", + "prost", ] [[package]] @@ -4903,22 +4679,17 @@ checksum = "eddd3ca559203180a307f12d114c268abf583f59b03cb906fd0b3ff8646c1147" dependencies = [ "base64 0.22.1", "bytes", - "encoding_rs", "futures-channel", "futures-core", "futures-util", - "h2", "http 1.4.0", "http-body", "http-body-util", "hyper", "hyper-rustls", - "hyper-tls", "hyper-util", "js-sys", "log", - "mime", - "native-tls", "percent-encoding", "pin-project-lite", "quinn", @@ -4929,7 +4700,6 @@ dependencies = [ "serde_urlencoded", "sync_wrapper", "tokio", - "tokio-native-tls", "tokio-rustls", "tower 0.5.3", "tower-http", @@ -5077,15 +4847,6 @@ dependencies = [ "security-framework", ] -[[package]] -name = "rustls-pemfile" -version = "2.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dce314e5fee3f39953d46bb63bb8a46d40c2f8fb7cc5a3b6cab2bde9721d6e50" -dependencies = [ - "rustls-pki-types", -] - [[package]] name = "rustls-pki-types" version = "1.14.0" @@ -5102,7 +4863,7 @@ version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1d99feebc72bae7ab76ba994bb5e121b8d83d910ca40b36e0921f53becc41784" dependencies = [ - "core-foundation 0.10.1", + "core-foundation", "core-foundation-sys", "jni", "log", @@ -5191,7 +4952,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b7f4bc775c73d9a02cde8bf7b2ec4c9d12743edf609006c7facc23998404cd1d" dependencies = [ "bitflags 2.11.0", - "core-foundation 0.10.1", + "core-foundation", "core-foundation-sys", "libc", "security-framework-sys", @@ -5509,9 +5270,10 @@ dependencies = [ [[package]] name = "sof" -version = "0.18.1" +version = "0.18.2" dependencies = [ "agave-transaction-view", + "ahash 0.8.12", "arcshift", "async-trait", "base64 0.22.1", @@ -5541,6 +5303,7 @@ dependencies = [ "socket2 0.5.10", "sof-gossip-tuning", "sof-solana-gossip", + "sof-support", "sof-types", "solana-entry", "solana-epoch-schedule", @@ -5557,6 +5320,7 @@ dependencies = [ "solana-signature", "solana-signer", "solana-streamer", + "solana-system-interface 3.1.0", "solana-transaction", "solana-vote", "solana-vote-program", @@ -5573,11 +5337,11 @@ dependencies = [ [[package]] name = "sof-gossip-tuning" -version = "0.18.1" +version = "0.18.2" [[package]] name = "sof-solana-compat" -version = "0.18.1" +version = "0.18.2" dependencies = [ "async-trait", "bincode", @@ -5612,7 +5376,7 @@ dependencies = [ "crossbeam-channel", "ed25519-dalek 2.2.0", "flate2", - "indexmap 2.13.0", + "indexmap", "itertools 0.12.1", "libc", "log", @@ -5663,20 +5427,28 @@ dependencies = [ "thiserror 2.0.18", ] +[[package]] +name = "sof-support" +version = "0.18.2" +dependencies = [ + "sof-types", +] + [[package]] name = "sof-tx" -version = "0.18.1" +version = "0.18.2" dependencies = [ "arcshift", "async-trait", "base64 0.22.1", "bincode", "bs58", - "prost 0.13.5", + "prost", "reqwest", "serde", "serde_json", "sof", + "sof-support", "sof-types", "solana-compute-budget-interface", "solana-connection-cache", @@ -5691,13 +5463,14 @@ dependencies = [ "solana-transaction", "thiserror 2.0.18", "tokio", - "tonic 0.12.3", + "tonic", + "tonic-prost", "xdp", ] [[package]] name = "sof-types" -version = "0.18.1" +version = "0.18.2" dependencies = [ "bs58", "serde", @@ -5807,7 +5580,7 @@ dependencies = [ "bytemuck_derive", "crossbeam-channel", "dashmap", - "indexmap 2.13.0", + "indexmap", "itertools 0.12.1", "log", "lz4", @@ -6135,7 +5908,7 @@ dependencies = [ "dashmap", "futures", "futures-util", - "indexmap 2.13.0", + "indexmap", "indicatif", "log", "rayon", @@ -6304,7 +6077,7 @@ dependencies = [ "bincode", "crossbeam-channel", "futures-util", - "indexmap 2.13.0", + "indexmap", "log", "rand 0.8.5", "rayon", @@ -7791,7 +7564,7 @@ dependencies = [ "futures-util", "governor", "histogram", - "indexmap 2.13.0", + "indexmap", "itertools 0.12.1", "libc", "log", @@ -8074,7 +7847,7 @@ dependencies = [ "async-trait", "bincode", "futures-util", - "indexmap 2.13.0", + "indexmap", "indicatif", "log", "rayon", @@ -8835,27 +8608,6 @@ dependencies = [ "syn 2.0.117", ] -[[package]] -name = "system-configuration" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a13f3d0daba03132c0aa9767f98351b3488edc2c100cda2d2ec2b04f3d8d3c8b" -dependencies = [ - "bitflags 2.11.0", - "core-foundation 0.9.4", - "system-configuration-sys", -] - -[[package]] -name = "system-configuration-sys" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e1d1b10ced5ca923a1fcb8d03e96b8d3268065d724548c0211415ff6ac6bac4" -dependencies = [ - "core-foundation-sys", - "libc", -] - [[package]] name = "tap" version = "1.0.1" @@ -9070,16 +8822,6 @@ dependencies = [ "syn 2.0.117", ] -[[package]] -name = "tokio-native-tls" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" -dependencies = [ - "native-tls", - "tokio", -] - [[package]] name = "tokio-rustls" version = "0.26.4" @@ -9147,7 +8889,7 @@ version = "0.25.8+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "16bff38f1d86c47f9ff0647e6838d7bb362522bdf44006c7068c2b1e606f1f3c" dependencies = [ - "indexmap 2.13.0", + "indexmap", "toml_datetime", "toml_parser", "winnow", @@ -9162,41 +8904,6 @@ dependencies = [ "winnow", ] -[[package]] -name = "tonic" -version = "0.12.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "877c5b330756d856ffcc4553ab34a5684481ade925ecc54bcd1bf02b1d0d4d52" -dependencies = [ - "async-stream", - "async-trait", - "axum 0.7.9", - "base64 0.22.1", - "bytes", - "flate2", - "h2", - "http 1.4.0", - "http-body", - "http-body-util", - "hyper", - "hyper-timeout", - "hyper-util", - "percent-encoding", - "pin-project", - "prost 0.13.5", - "rustls-pemfile", - "socket2 0.5.10", - "tokio", - "tokio-rustls", - "tokio-stream", - "tower 0.4.13", - "tower-layer", - "tower-service", - "tracing", - "webpki-roots 0.26.11", - "zstd", -] - [[package]] name = "tonic" version = "0.14.5" @@ -9204,7 +8911,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fec7c61a0695dc1887c1b53952990f3ad2e3a31453e1f49f10e75424943a93ec" dependencies = [ "async-trait", - "axum 0.8.8", + "axum", "base64 0.22.1", "bytes", "flate2", @@ -9227,22 +8934,10 @@ dependencies = [ "tower-layer", "tower-service", "tracing", + "webpki-roots 1.0.6", "zstd", ] -[[package]] -name = "tonic-build" -version = "0.10.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d021fc044c18582b9a2408cd0dd05b1596e3ecdb5c4df822bb0183545683889" -dependencies = [ - "prettyplease", - "proc-macro2", - "prost-build 0.12.6", - "quote", - "syn 2.0.117", -] - [[package]] name = "tonic-build" version = "0.14.5" @@ -9261,10 +8956,10 @@ version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f4ff0636fef47afb3ec02818f5bceb4377b8abb9d6a386aeade18bd6212f8eb7" dependencies = [ - "prost 0.14.3", + "prost", "tokio", "tokio-stream", - "tonic 0.14.5", + "tonic", "tonic-prost", ] @@ -9275,8 +8970,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a55376a0bbaa4975a3f10d009ad763d8f4108f067c7c2e74f3001fb49778d309" dependencies = [ "bytes", - "prost 0.14.3", - "tonic 0.14.5", + "prost", + "tonic", ] [[package]] @@ -9287,12 +8982,12 @@ checksum = "f3144df636917574672e93d0f56d7edec49f90305749c668df5101751bb8f95a" dependencies = [ "prettyplease", "proc-macro2", - "prost-build 0.14.3", - "prost-types 0.14.3", + "prost-build", + "prost-types", "quote", "syn 2.0.117", "tempfile", - "tonic-build 0.14.5", + "tonic-build", ] [[package]] @@ -9303,13 +8998,8 @@ checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" dependencies = [ "futures-core", "futures-util", - "indexmap 1.9.3", "pin-project", "pin-project-lite", - "rand 0.8.5", - "slab", - "tokio", - "tokio-util", "tower-layer", "tower-service", "tracing", @@ -9323,7 +9013,7 @@ checksum = "ebe5ef63511595f1344e2d5cfa636d973292adc0eec1f0ad45fae9f0851ab1d4" dependencies = [ "futures-core", "futures-util", - "indexmap 2.13.0", + "indexmap", "pin-project-lite", "slab", "sync_wrapper", @@ -9753,7 +9443,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bb0e353e6a2fbdc176932bbaab493762eb1255a7900fe0fea1a2f96c296cc909" dependencies = [ "anyhow", - "indexmap 2.13.0", + "indexmap", "wasm-encoder", "wasmparser", ] @@ -9766,7 +9456,7 @@ checksum = "47b807c72e1bac69382b3a6fb3dbe8ea4c0ed87ff5629b8685ae6b9a611028fe" dependencies = [ "bitflags 2.11.0", "hashbrown 0.15.5", - "indexmap 2.13.0", + "indexmap", "semver", ] @@ -9939,17 +9629,6 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" -[[package]] -name = "windows-registry" -version = "0.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02752bf7fbdcce7f2a27a742f798510f3e5ad88dbe84871e5168e2120c3d5720" -dependencies = [ - "windows-link", - "windows-result", - "windows-strings", -] - [[package]] name = "windows-result" version = "0.4.1" @@ -10236,7 +9915,7 @@ checksum = "b7c566e0f4b284dd6561c786d9cb0142da491f46a9fbed79ea69cdad5db17f21" dependencies = [ "anyhow", "heck 0.5.0", - "indexmap 2.13.0", + "indexmap", "prettyplease", "syn 2.0.117", "wasm-metadata", @@ -10267,7 +9946,7 @@ checksum = "9d66ea20e9553b30172b5e831994e35fbde2d165325bec84fc43dbf6f4eb9cb2" dependencies = [ "anyhow", "bitflags 2.11.0", - "indexmap 2.13.0", + "indexmap", "log", "serde", "serde_derive", @@ -10286,7 +9965,7 @@ checksum = "ecc8ac4bc1dc3381b7f59c34f00b67e18f910c2c0f50015669dde7def656a736" dependencies = [ "anyhow", "id-arena", - "indexmap 2.13.0", + "indexmap", "log", "semver", "serde", @@ -10356,7 +10035,7 @@ dependencies = [ "hyper-util", "thiserror 2.0.18", "tokio", - "tonic 0.14.5", + "tonic", "tonic-health", "tower 0.4.13", "yellowstone-grpc-proto", @@ -10369,10 +10048,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ab52445da59c3e710dd2128e3e49598cfa45c8ba7f30feb608cafbd404e5e8cf" dependencies = [ "anyhow", - "prost 0.14.3", - "prost-types 0.14.3", + "prost", + "prost-types", "protoc-bin-vendored", - "tonic 0.14.5", + "tonic", "tonic-prost", "tonic-prost-build", ] diff --git a/Cargo.toml b/Cargo.toml index 5953cb8c..97ba0c9c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,7 @@ members = [ "crates/sof-gossip-tuning", "crates/sof-observer", "crates/sof-solana-compat", + "crates/sof-support", "crates/sof-tx", "crates/sof-types", ] @@ -11,6 +12,7 @@ exclude = ["crates/sof-solana-gossip"] [patch.crates-io] sof-solana-gossip = { path = "crates/sof-solana-gossip" } +helius-laserstream = { path = "vendor/helius-laserstream" } [workspace.package] edition = "2024" diff --git a/crates/sof-gossip-tuning/Cargo.toml b/crates/sof-gossip-tuning/Cargo.toml index f9038f46..a34b3d06 100644 --- a/crates/sof-gossip-tuning/Cargo.toml +++ b/crates/sof-gossip-tuning/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "sof-gossip-tuning" -version = "0.18.1" +version = "0.18.2" edition.workspace = true description = "Typed gossip and ingest tuning presets for SOF hosts" license = "Apache-2.0 OR MIT" diff --git a/crates/sof-observer/Cargo.toml b/crates/sof-observer/Cargo.toml index 214d3b94..28cc7457 100644 --- a/crates/sof-observer/Cargo.toml +++ b/crates/sof-observer/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "sof" -version = "0.18.1" +version = "0.18.2" edition.workspace = true description = "Solana Observer Framework for low-latency shred ingestion and plugin-driven transaction observation" license = "Apache-2.0 OR MIT" @@ -32,8 +32,9 @@ provider-websocket = [] [dependencies] agave-transaction-view = { version = "3.1.11", features = ["agave-unstable-api"] } -sof-gossip-tuning = { version = "0.18.1", path = "../sof-gossip-tuning" } -sof-types = { version = "0.18.1", path = "../sof-types", features = ["solana-compat"] } +sof-gossip-tuning = { version = "0.18.2", path = "../sof-gossip-tuning" } +sof-support = { version = "0.18.2", path = "../sof-support" } +sof-types = { version = "0.18.2", path = "../sof-types", features = ["solana-compat"] } solana-gossip = { package = "sof-solana-gossip", version = "3.1.11-sof.9", optional = true, features = ["agave-unstable-api"] } solana-entry = { version = "3.1.11", features = ["agave-unstable-api"] } solana-epoch-schedule = "3.0.0" @@ -50,6 +51,7 @@ solana-signer = "3.0.0" solana-streamer = "3.1.11" solana-transaction = { version = "3.0.2", features = ["serde"] } solana-packet = "3.0.0" +solana-system-interface = "3.1.0" solana-vote = "3.1.11" solana-vote-program = { version = "3.1.11", features = ["agave-unstable-api"] } reed-solomon-erasure = { version = "6.0.0", features = ["simd-accel"] } @@ -81,6 +83,7 @@ futures-channel = "0.3.32" tokio-tungstenite = { version = "0.28", default-features = false, features = ["connect", "handshake", "rustls-tls-webpki-roots"] } signal-hook = "0.3" smallvec = "1.15" +ahash = "0.8.12" xdp = { version = "0.7.3", optional = true } yellowstone-grpc-client = { version = "12.2.0", optional = true } yellowstone-grpc-proto = { version = "12.1.0", optional = true } diff --git a/crates/sof-observer/README.md b/crates/sof-observer/README.md index a6d2e6f0..97787f70 100644 --- a/crates/sof-observer/README.md +++ b/crates/sof-observer/README.md @@ -581,7 +581,7 @@ cargo add sof Optional gossip bootstrap support at compile time: ```toml -sof = { version = "0.18.1", features = ["gossip-bootstrap"] } +sof = { version = "0.18.2", features = ["gossip-bootstrap"] } ``` `gossip-bootstrap` uses the vendored `sof-solana-gossip` backend, but it no longer exact-pins the @@ -590,7 +590,7 @@ Solana `3.1.11` patch line. Downstream crates can resolve newer compatible `3.1. Optional external `kernel-bypass` ingress support: ```toml -sof = { version = "0.18.1", features = ["kernel-bypass"] } +sof = { version = "0.18.2", features = ["kernel-bypass"] } ``` The bundled `sof-solana-gossip` backend defaults to SOF's lightweight in-memory duplicate/conflict @@ -1088,6 +1088,7 @@ Design references: - Queue pressure drops hook events instead of stalling ingest. - Typed host tuning is available through `sof-gossip-tuning` and `RuntimeSetup::with_gossip_tuning_profile(...)`. - `RuntimeExtension` WebSocket connectors support full `ws://` / `wss://` handshake + frame decoding. +- Runtime extensions require non-empty names and resource metadata; startup rejects empty `resource_id` / shared tags and bounds `read_buffer_bytes`. - WebSocket close frames emit `RuntimePacketEventClass::ConnectionClosed` in `on_packet_received`. - WebSocket packet events expose `websocket_frame_type` (`Text`/`Binary`/`Ping`/`Pong`) for startup-time filtering and runtime routing. - In gossip mode, SOF runs as an active bounded relay client by default (UDP relay + repair serve), not as an observer-only passive consumer. diff --git a/crates/sof-observer/fuzz/fuzz_targets/shred_fec_recover.rs b/crates/sof-observer/fuzz/fuzz_targets/shred_fec_recover.rs index 1a406739..aa067b18 100644 --- a/crates/sof-observer/fuzz/fuzz_targets/shred_fec_recover.rs +++ b/crates/sof-observer/fuzz/fuzz_targets/shred_fec_recover.rs @@ -1,5 +1,7 @@ #![no_main] +use std::sync::Arc; + use libfuzzer_sys::fuzz_target; use sof::{ protocol::shred_wire::{ @@ -230,6 +232,7 @@ fuzz_target!(|bytes: &[u8]| { }), ParsedShred::Code(code) => ParsedShredHeader::Code(code), }; + let packet = Arc::<[u8]>::from(packet); let recovered = recoverer.ingest_packet(&packet, &parsed_header); assert!(recoverer.tracked_sets() <= max_tracked_sets); diff --git a/crates/sof-observer/src/app/config/common.rs b/crates/sof-observer/src/app/config/common.rs index 294d068b..80ba050f 100644 --- a/crates/sof-observer/src/app/config/common.rs +++ b/crates/sof-observer/src/app/config/common.rs @@ -1,4 +1,4 @@ -use std::time::{Duration, SystemTime, UNIX_EPOCH}; +use std::time::Duration; pub(crate) use crate::runtime_env::read_env_var; @@ -16,11 +16,3 @@ pub fn duration_to_ms_u64(duration: Duration) -> u64 { millis as u64 } } - -pub fn current_unix_ms() -> u64 { - SystemTime::now() - .duration_since(UNIX_EPOCH) - .map_or(0, |duration| { - duration.as_millis().min(u128::from(u64::MAX)) as u64 - }) -} diff --git a/crates/sof-observer/src/app/runtime/bootstrap/gossip/handoff.rs b/crates/sof-observer/src/app/runtime/bootstrap/gossip/handoff.rs index 32f61e36..a9747abf 100644 --- a/crates/sof-observer/src/app/runtime/bootstrap/gossip/handoff.rs +++ b/crates/sof-observer/src/app/runtime/bootstrap/gossip/handoff.rs @@ -1,6 +1,8 @@ #[cfg(feature = "gossip-bootstrap")] use super::*; #[cfg(feature = "gossip-bootstrap")] +use sof_support::time_support::current_unix_ms; +#[cfg(feature = "gossip-bootstrap")] use thiserror::Error; #[cfg(feature = "gossip-bootstrap")] diff --git a/crates/sof-observer/src/app/runtime/bootstrap/repair.rs b/crates/sof-observer/src/app/runtime/bootstrap/repair.rs index ee7c43a1..d176e83f 100644 --- a/crates/sof-observer/src/app/runtime/bootstrap/repair.rs +++ b/crates/sof-observer/src/app/runtime/bootstrap/repair.rs @@ -125,6 +125,15 @@ pub(crate) struct RepairSourceHintBuffer { #[cfg(feature = "gossip-bootstrap")] impl RepairSourceHintBuffer { pub(crate) fn new(capacity: usize) -> Self { + let capacity = capacity.max(1); + Self { + counts: HashMap::with_capacity(capacity), + capacity, + } + } + + #[cfg(test)] + pub(crate) fn new_baseline(capacity: usize) -> Self { Self { counts: HashMap::new(), capacity: capacity.max(1), @@ -335,6 +344,90 @@ pub(crate) fn spawn_repair_driver( Ok((command_tx, result_rx, peer_snapshot, driver_handle)) } +#[cfg(all(test, feature = "gossip-bootstrap"))] +mod tests { + use std::{ + env, + net::{IpAddr, Ipv4Addr, SocketAddr}, + time::Instant, + }; + + use super::RepairSourceHintBuffer; + + #[test] + #[ignore = "profiling fixture for repair source hint buffer allocation"] + fn repair_source_hint_buffer_profile_fixture() { + let iterations = env::var("SOF_REPAIR_SOURCE_HINT_PROFILE_ITERS") + .ok() + .and_then(|raw| raw.parse::().ok()) + .filter(|value| *value > 0) + .unwrap_or(20_000); + let capacity = env::var("SOF_REPAIR_SOURCE_HINT_PROFILE_CAPACITY") + .ok() + .and_then(|raw| raw.parse::().ok()) + .filter(|value| *value > 0) + .unwrap_or(256); + let batch_size = env::var("SOF_REPAIR_SOURCE_HINT_PROFILE_BATCH") + .ok() + .and_then(|raw| raw.parse::().ok()) + .filter(|value| *value > 0) + .unwrap_or(capacity / 2); + let addresses = (0..capacity) + .map(|index| { + SocketAddr::new( + IpAddr::V4(Ipv4Addr::new( + 127, + 0, + u8::try_from((index / 255) % 255).unwrap_or(0), + u8::try_from((index % 255) + 1).unwrap_or(u8::MAX), + )), + u16::try_from((10_000 + index) % usize::from(u16::MAX)).unwrap_or(u16::MAX), + ) + }) + .collect::>(); + assert!(!addresses.is_empty()); + + let baseline_started_at = Instant::now(); + for _ in 0..iterations { + let mut buffer = RepairSourceHintBuffer::new_baseline(capacity); + for addr in addresses.iter().copied() { + assert!(buffer.record(addr).is_ok()); + } + let drained = buffer.drain_batch(batch_size); + assert!(!drained.is_empty()); + } + let baseline_elapsed = baseline_started_at.elapsed(); + + let optimized_started_at = Instant::now(); + for _ in 0..iterations { + let mut buffer = RepairSourceHintBuffer::new(capacity); + for addr in addresses.iter().copied() { + assert!(buffer.record(addr).is_ok()); + } + let drained = buffer.drain_batch(batch_size); + assert!(!drained.is_empty()); + } + let optimized_elapsed = optimized_started_at.elapsed(); + + let baseline_avg_ns = + baseline_elapsed.as_nanos() / u128::try_from(iterations).unwrap_or(u128::MAX); + let optimized_avg_ns = + optimized_elapsed.as_nanos() / u128::try_from(iterations).unwrap_or(u128::MAX); + let baseline_avg_us = baseline_avg_ns as f64 / 1_000.0; + let optimized_avg_us = optimized_avg_ns as f64 / 1_000.0; + println!( + "repair_source_hint_buffer_profile_fixture iterations={} baseline_us={} optimized_us={} baseline_avg_ns_per_iteration={} optimized_avg_ns_per_iteration={} baseline_avg_us_per_iteration={:.6} optimized_avg_us_per_iteration={:.6}", + iterations, + baseline_elapsed.as_micros(), + optimized_elapsed.as_micros(), + baseline_avg_ns, + optimized_avg_ns, + baseline_avg_us, + optimized_avg_us + ); + } +} + #[cfg(feature = "gossip-bootstrap")] pub(crate) fn replace_repair_driver( repair_client: crate::repair::GossipRepairClient, diff --git a/crates/sof-observer/src/app/runtime/dataset/process.rs b/crates/sof-observer/src/app/runtime/dataset/process.rs index 8d123f57..fe54fab0 100644 --- a/crates/sof-observer/src/app/runtime/dataset/process.rs +++ b/crates/sof-observer/src/app/runtime/dataset/process.rs @@ -16,6 +16,9 @@ use crate::{ }; use agave_transaction_view::transaction_view::SanitizedTransactionView; use core::mem::size_of; +use sof_support::short_vec::{ + ShortVecDecodeError as PartialParseError, decode_short_u16_len, decode_short_u16_len_partial, +}; use solana_hash::Hash; use solana_packet::PACKET_DATA_SIZE; use solana_pubkey::Pubkey; @@ -1833,50 +1836,6 @@ fn read_u64_le_partial(payload: &[u8], offset: &mut usize) -> Result Option { - let mut value = 0_usize; - let mut shift = 0_u32; - for byte_index in 0..3 { - let byte = usize::from(*payload.get(*offset)?); - *offset = (*offset).saturating_add(1); - value |= (byte & 0x7f) << shift; - if byte & 0x80 == 0 { - return Some(value); - } - shift = shift.saturating_add(7); - if byte_index == 2 { - return None; - } - } - None -} - -fn decode_short_u16_len_partial( - payload: &[u8], - offset: &mut usize, -) -> Result { - let mut value = 0_usize; - let mut shift = 0_u32; - for byte_index in 0..3 { - let byte = usize::from(*payload.get(*offset).ok_or(PartialParseError::Incomplete)?); - *offset = (*offset).saturating_add(1); - value |= (byte & 0x7f) << shift; - if byte & 0x80 == 0 { - return Ok(value); - } - shift = shift.saturating_add(7); - if byte_index == 2 { - return Err(PartialParseError::Invalid); - } - } - Err(PartialParseError::Invalid) -} - -enum PartialParseError { - Incomplete, - Invalid, -} - fn join_payload_fragments_into( buffer: &mut Vec, fragments: &[SharedPayloadFragment], @@ -1955,11 +1914,13 @@ mod tests { use solana_signer::Signer as _; use solana_transaction::Transaction; use std::{ + env, net::{Ipv4Addr, SocketAddr, SocketAddrV4}, sync::{ Arc, atomic::{AtomicUsize, Ordering}, }, + thread, time::{Duration, Instant}, }; use wincode::{ @@ -2366,7 +2327,7 @@ mod tests { #[test] #[ignore = "profiling fixture for perf"] fn multi_hook_profile_fixture() { - let iterations = std::env::var("SOF_MULTI_HOOK_PROFILE_ITERS") + let iterations = env::var("SOF_MULTI_HOOK_PROFILE_ITERS") .ok() .and_then(|raw| raw.parse::().ok()) .filter(|value| *value > 0) @@ -2442,7 +2403,7 @@ mod tests { ); assert!(matches!(outcome, DatasetProcessOutcome::Decoded)); } - std::thread::sleep(Duration::from_millis(250)); + thread::sleep(Duration::from_millis(250)); assert_eq!(dataset_decode_fail_count.load(Ordering::Relaxed), 0); assert_eq!(tx_event_drop_count.load(Ordering::Relaxed), 0); assert_eq!(plugin_host.dropped_event_count(), 0); @@ -2474,12 +2435,12 @@ mod tests { #[test] #[ignore = "profiling fixture for completed-dataset prefilter decode skip A/B"] fn completed_dataset_prefilter_profile_fixture() { - let iterations = std::env::var("SOF_COMPLETED_DATASET_PREFILTER_PROFILE_ITERS") + let iterations = env::var("SOF_COMPLETED_DATASET_PREFILTER_PROFILE_ITERS") .ok() .and_then(|raw| raw.parse::().ok()) .filter(|value| *value > 0) .unwrap_or(20_000); - let mode = std::env::var("SOF_COMPLETED_DATASET_PREFILTER_PROFILE_MODE") + let mode = env::var("SOF_COMPLETED_DATASET_PREFILTER_PROFILE_MODE") .unwrap_or_else(|_| "manual".to_owned()); let ignored_account = Pubkey::new_unique(); let payload = build_profile_payload(PROFILE_ENTRY_COUNT); @@ -2529,7 +2490,7 @@ mod tests { ); assert!(matches!(outcome, DatasetProcessOutcome::Decoded)); } - std::thread::sleep(Duration::from_millis(100)); + thread::sleep(Duration::from_millis(100)); assert_eq!(dataset_decode_fail_count.load(Ordering::Relaxed), 0); assert_eq!(tx_event_drop_count.load(Ordering::Relaxed), 0); assert_eq!(plugin_host.dropped_event_count(), 0); @@ -2697,7 +2658,7 @@ mod tests { ); assert!(matches!(outcome, DatasetProcessOutcome::Decoded)); - std::thread::sleep(Duration::from_millis(50)); + thread::sleep(Duration::from_millis(50)); assert_eq!(handled.load(Ordering::Relaxed), 1); assert_eq!(dataset_decode_fail_count.load(Ordering::Relaxed), 0); assert_eq!(tx_event_drop_count.load(Ordering::Relaxed), 0); diff --git a/crates/sof-observer/src/app/runtime/observability.rs b/crates/sof-observer/src/app/runtime/observability.rs index f139fa38..ea0b9131 100644 --- a/crates/sof-observer/src/app/runtime/observability.rs +++ b/crates/sof-observer/src/app/runtime/observability.rs @@ -9,10 +9,11 @@ use std::{ }; use tokio::{ - io::{AsyncReadExt, AsyncWriteExt}, + io::{AsyncReadExt, AsyncWrite, AsyncWriteExt}, net::{TcpListener, TcpStream}, sync::oneshot, task::JoinHandle, + time::timeout, }; #[cfg(test)] @@ -33,6 +34,8 @@ const METRICS_PATH: &str = "/metrics"; const HEALTH_PATH: &str = "/healthz"; const READY_PATH: &str = "/readyz"; const REQUEST_BUFFER_BYTES: usize = 8 * 1024; +const REQUEST_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(2); +const RESPONSE_WRITE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(2); const CONTENT_TYPE_TEXT: &str = "text/plain; charset=utf-8"; const CONTENT_TYPE_PROMETHEUS: &str = "text/plain; version=0.0.4; charset=utf-8"; @@ -267,13 +270,21 @@ async fn handle_connection( Some(_) => HttpResponse::not_found(), None => HttpResponse::bad_request(), }; - stream.write_all(response.serialize().as_bytes()).await?; - stream.shutdown().await + write_response_with_timeout(&mut stream, &response, RESPONSE_WRITE_TIMEOUT).await } async fn read_request_path(stream: &mut TcpStream) -> io::Result> { + read_request_path_with_timeout(stream, REQUEST_READ_TIMEOUT).await +} + +async fn read_request_path_with_timeout( + stream: &mut TcpStream, + request_timeout: std::time::Duration, +) -> io::Result> { let mut buffer = [0_u8; REQUEST_BUFFER_BYTES]; - let read = stream.read(&mut buffer).await?; + let read = timeout(request_timeout, stream.read(&mut buffer)) + .await + .map_err(|_elapsed| io::Error::new(io::ErrorKind::TimedOut, "request read timed out"))??; if read == 0 { return Ok(None); } @@ -302,6 +313,27 @@ async fn read_request_path(stream: &mut TcpStream) -> io::Result( + stream: &mut W, + response: &HttpResponse, + write_timeout: std::time::Duration, +) -> io::Result<()> +where + W: AsyncWrite + Unpin, +{ + let response = response.serialize(); + timeout(write_timeout, stream.write_all(response.as_bytes())) + .await + .map_err(|_elapsed| { + io::Error::new(io::ErrorKind::TimedOut, "response write timed out") + })??; + timeout(write_timeout, stream.shutdown()) + .await + .map_err(|_elapsed| { + io::Error::new(io::ErrorKind::TimedOut, "response shutdown timed out") + })? +} + fn render_metrics( handle: &RuntimeObservabilityHandle, plugin_host: &PluginHost, @@ -2716,7 +2748,7 @@ impl HttpResponse { mod tests { use super::*; use tokio::{ - io::{AsyncReadExt, AsyncWriteExt}, + io::{AsyncReadExt, AsyncWriteExt, duplex}, net::TcpStream, }; @@ -2994,4 +3026,53 @@ mod tests { .expect("response should read"); response } + + #[tokio::test(flavor = "current_thread")] + async fn read_request_path_times_out_slow_clients() { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("listener should bind"); + let addr = listener.local_addr().expect("listener addr"); + let server = tokio::spawn(async move { + let accepted = listener.accept().await; + assert!(accepted.is_ok()); + let (mut stream, _) = accepted.unwrap_or_else(|error| panic!("{error}")); + let result = + read_request_path_with_timeout(&mut stream, std::time::Duration::from_millis(25)) + .await; + assert!(result.is_err(), "slow client should time out"); + let error = match result { + Ok(value) => panic!("expected timeout, got {value:?}"), + Err(error) => error, + }; + assert_eq!(error.kind(), io::ErrorKind::TimedOut); + }); + + let client = TcpStream::connect(addr).await; + assert!(client.is_ok()); + let client = client.unwrap_or_else(|error| panic!("{error}")); + tokio::time::sleep(std::time::Duration::from_millis(50)).await; + drop(client); + + assert!(server.await.is_ok()); + } + + #[tokio::test(flavor = "current_thread")] + async fn write_response_with_timeout_times_out_slow_clients() { + let (mut writer, _reader) = duplex(64); + let response = HttpResponse::ok(CONTENT_TYPE_TEXT, "x".repeat(1024)); + + let result = write_response_with_timeout( + &mut writer, + &response, + std::time::Duration::from_millis(25), + ) + .await; + assert!(result.is_err(), "slow client should time out"); + let error = match result { + Ok(()) => panic!("expected timeout"), + Err(error) => error, + }; + assert_eq!(error.kind(), io::ErrorKind::TimedOut); + } } diff --git a/crates/sof-observer/src/app/runtime/runloop/driver.rs b/crates/sof-observer/src/app/runtime/runloop/driver.rs index 053b4eb1..96312860 100644 --- a/crates/sof-observer/src/app/runtime/runloop/driver.rs +++ b/crates/sof-observer/src/app/runtime/runloop/driver.rs @@ -30,6 +30,7 @@ use agave_transaction_view::transaction_view::SanitizedTransactionView; use crossbeam_channel::Sender as CrossbeamSender; use reassembly::dataset::CompletedDataSet; use reassembly::inline::InlineContiguousDataSetSink; +use sof_support::time_support::current_unix_ms; use solana_signature::Signature; use solana_transaction::versioned::VersionedTransaction; #[cfg(feature = "gossip-bootstrap")] @@ -2658,8 +2659,8 @@ async fn run_async_with_hosts_inner( #[cfg(feature = "gossip-bootstrap")] { if let Some(peer_snapshot) = repair_peer_snapshot.as_ref() { - packet_worker_pool - .update_known_pubkeys(peer_snapshot.shared_get().known_pubkeys.clone()); + let peer_snapshot = peer_snapshot.shared_get(); + packet_worker_pool.update_known_pubkeys(&peer_snapshot.known_pubkeys); } } #[cfg(feature = "gossip-bootstrap")] @@ -5141,7 +5142,11 @@ mod tests { use solana_perf::test_tx::{new_test_vote_tx, test_tx}; use solana_pubkey::Pubkey; use solana_transaction::versioned::VersionedTransaction; - use std::sync::{Arc, Mutex}; + use std::{ + env, + sync::{Arc, Mutex}, + thread, + }; use wincode::{ Serialize as _, containers::{Elem, Vec as WincodeVec}, @@ -5290,12 +5295,12 @@ mod tests { #[test] #[ignore = "profiling fixture for perf"] fn inline_open_dataset_profile_fixture() { - let iterations = std::env::var("SOF_INLINE_OPEN_DATASET_PROFILE_ITERS") + let iterations = env::var("SOF_INLINE_OPEN_DATASET_PROFILE_ITERS") .ok() .and_then(|raw| raw.parse::().ok()) .filter(|value| *value > 0) .unwrap_or(2_048); - let fragment_count = std::env::var("SOF_INLINE_OPEN_DATASET_PROFILE_FRAGMENTS") + let fragment_count = env::var("SOF_INLINE_OPEN_DATASET_PROFILE_FRAGMENTS") .ok() .and_then(|raw| raw.parse::().ok()) .filter(|value| *value > 0) @@ -5435,7 +5440,7 @@ mod tests { } } - std::thread::sleep(Duration::from_millis(100)); + thread::sleep(Duration::from_millis(100)); assert!(dispatched_total > 0); println!( "inline_open_dataset_profile_fixture iterations={} fragments={} dispatched={} elapsed_ms={}", @@ -5449,14 +5454,14 @@ mod tests { #[test] #[ignore = "profiling fixture for inline prefilter decode skip A/B"] fn inline_open_dataset_prefilter_profile_fixture() { - let mode = std::env::var("SOF_INLINE_PREFILTER_PROFILE_MODE") - .unwrap_or_else(|_| "manual".to_owned()); - let iterations = std::env::var("SOF_INLINE_PREFILTER_PROFILE_ITERS") + let mode = + env::var("SOF_INLINE_PREFILTER_PROFILE_MODE").unwrap_or_else(|_| "manual".to_owned()); + let iterations = env::var("SOF_INLINE_PREFILTER_PROFILE_ITERS") .ok() .and_then(|raw| raw.parse::().ok()) .filter(|value| *value > 0) .unwrap_or(2_048); - let fragment_count = std::env::var("SOF_INLINE_PREFILTER_PROFILE_FRAGMENTS") + let fragment_count = env::var("SOF_INLINE_PREFILTER_PROFILE_FRAGMENTS") .ok() .and_then(|raw| raw.parse::().ok()) .filter(|value| *value > 0) diff --git a/crates/sof-observer/src/app/runtime/runloop/packet_workers.rs b/crates/sof-observer/src/app/runtime/runloop/packet_workers.rs index efc38d6a..d1567e63 100644 --- a/crates/sof-observer/src/app/runtime/runloop/packet_workers.rs +++ b/crates/sof-observer/src/app/runtime/runloop/packet_workers.rs @@ -147,17 +147,33 @@ pub(super) struct SharedKnownPubkeys { #[cfg(feature = "gossip-bootstrap")] impl SharedKnownPubkeys { #[cfg(feature = "gossip-bootstrap")] - pub(super) fn update(&self, pubkeys: Vec<[u8; 32]>) { + pub(super) fn update(&self, pubkeys: &[[u8; 32]]) { + let current = self.pubkeys.shared_get(); + if current.as_slice() == pubkeys { + return; + } + + let mut normalized = pubkeys.to_vec(); + normalized.sort_unstable(); + normalized.dedup(); + if current.as_slice() == normalized.as_slice() { + return; + } + let mut shared_pubkeys = self.pubkeys.clone(); - shared_pubkeys.update(Arc::new(pubkeys)); - self.generation.fetch_add(1, Ordering::Relaxed); + shared_pubkeys.update(Arc::new(normalized)); + self.generation.fetch_add(1, Ordering::Release); } fn snapshot(&self) -> (u64, Arc>) { - let generation = self.generation.load(Ordering::Relaxed); - let pubkeys = self.pubkeys.shared_get(); - let pubkeys = Arc::clone(&pubkeys); - (generation, pubkeys) + loop { + let generation_before = self.generation.load(Ordering::Acquire); + let pubkeys = self.pubkeys.shared_get(); + let generation_after = self.generation.load(Ordering::Acquire); + if generation_before == generation_after { + return (generation_after, Arc::clone(&pubkeys)); + } + } } } @@ -386,7 +402,7 @@ impl PacketWorkerPool { } #[cfg(feature = "gossip-bootstrap")] - pub(super) fn update_known_pubkeys(&self, pubkeys: Vec<[u8; 32]>) { + pub(super) fn update_known_pubkeys(&self, pubkeys: &[[u8; 32]]) { self.known_pubkeys.update(pubkeys); } @@ -462,7 +478,7 @@ fn refresh_known_pubkeys( if generation == *verifier_generation { return; } - shred_verifier.set_known_pubkeys(pubkeys.as_ref().clone()); + shred_verifier.set_known_pubkeys_sorted(pubkeys.as_slice()); *verifier_generation = generation; } @@ -510,8 +526,8 @@ where parsed_header_slot(&packet.parsed_header), &mut observed_slot_leaders, ); - let recovered_packets = fec_recoverer - .ingest_packet(packet.packet_bytes.as_ref(), &packet.parsed_header); + let recovered_packets = + fec_recoverer.ingest_packet(&packet.packet_bytes, &packet.parsed_header); push_primary_shred(packet, &mut accepted_shreds); for recovered in recovered_packets { @@ -765,6 +781,9 @@ const fn derive_parent_slot(slot: u64, parent_offset: u16) -> Option { #[cfg(test)] mod tests { use super::*; + + #[cfg(feature = "gossip-bootstrap")] + use crate::verify::ShredVerifier; use crate::{ protocol::shred_wire::{ SIZE_OF_CODING_SHRED_HEADERS, SIZE_OF_CODING_SHRED_PAYLOAD, SIZE_OF_DATA_SHRED_PAYLOAD, @@ -773,6 +792,11 @@ mod tests { shred::wire::{SIZE_OF_DATA_SHRED_HEADERS, parse_shred_header}, }; use reed_solomon_erasure::galois_8::ReedSolomon; + use sof_support::{bench::avg_ns_per_iteration, env_support::read_positive_usize}; + #[cfg(feature = "gossip-bootstrap")] + use solana_keypair::Keypair; + #[cfg(feature = "gossip-bootstrap")] + use solana_signer::Signer; fn build_data_shred_packet( slot: u64, @@ -833,6 +857,61 @@ mod tests { packet[SIZE_OF_SIGNATURE..SIZE_OF_SIGNATURE + shard_len].to_vec() } + #[cfg(feature = "gossip-bootstrap")] + #[test] + fn shared_known_pubkeys_skips_equivalent_updates() { + let shared = SharedKnownPubkeys::default(); + let initial = [[2_u8; 32], [1_u8; 32], [2_u8; 32]]; + + shared.update(&initial); + let (first_generation, first_pubkeys) = shared.snapshot(); + assert_eq!(first_generation, 1); + assert_eq!(first_pubkeys.as_slice(), &[[1_u8; 32], [2_u8; 32]]); + + let reordered = [[1_u8; 32], [2_u8; 32]]; + shared.update(&reordered); + let (second_generation, second_pubkeys) = shared.snapshot(); + assert_eq!(second_generation, 1); + assert_eq!(second_pubkeys.as_slice(), &[[1_u8; 32], [2_u8; 32]]); + } + + #[cfg(feature = "gossip-bootstrap")] + #[test] + #[ignore = "profiling fixture for equivalent known-pubkey refresh churn"] + fn shared_known_pubkeys_equivalent_refresh_profile_fixture() { + let iterations = + read_positive_usize("SOF_PACKET_WORKER_KNOWN_PUBKEY_PROFILE_ITERS", 200_000); + let key_count = read_positive_usize("SOF_PACKET_WORKER_KNOWN_PUBKEY_COUNT", 64); + let mut canonical = (0..key_count) + .map(|_| Keypair::new().pubkey().to_bytes()) + .collect::>(); + canonical.sort_unstable(); + + let shared = SharedKnownPubkeys::default(); + let mut verifier = ShredVerifier::new(1024, 256, Duration::from_secs(5)); + let mut verifier_generation = 0_u64; + + shared.update(&canonical); + refresh_known_pubkeys(&shared, &mut verifier_generation, Some(&mut verifier)); + + let started_at = Instant::now(); + for _ in 0..iterations { + shared.update(&canonical); + refresh_known_pubkeys(&shared, &mut verifier_generation, Some(&mut verifier)); + } + let elapsed = started_at.elapsed(); + let avg_ns = avg_ns_per_iteration(elapsed, iterations); + println!( + "shared_known_pubkeys_equivalent_refresh_profile_fixture iterations={} key_count={} final_generation={} elapsed_ms={} avg_ns_per_iteration={} avg_us_per_iteration={:.3}", + iterations, + key_count, + verifier_generation, + elapsed.as_millis(), + avg_ns, + avg_ns as f64 / 1_000.0 + ); + } + fn build_recoverable_fec_pair(slot: u64) -> [(Arc<[u8]>, ParsedShredHeader); 2] { let data0 = build_data_shred_packet(slot, 0, 0, 1, &[1, 2, 3, 4]); let data1 = build_data_shred_packet(slot, 1, 0, 1, &[5, 6, 7, 8]); @@ -861,11 +940,7 @@ mod tests { #[test] #[ignore = "profiling fixture for packet worker primary FEC ingest"] fn packet_worker_primary_fec_profile_fixture() { - let iterations = std::env::var("SOF_PACKET_WORKER_FEC_PROFILE_ITERS") - .ok() - .and_then(|raw| raw.parse::().ok()) - .filter(|value| *value > 0) - .unwrap_or(50_000); + let iterations = read_positive_usize("SOF_PACKET_WORKER_FEC_PROFILE_ITERS", 50_000); let packets = (0..iterations) .map(|iteration| { let slot = @@ -908,22 +983,23 @@ mod tests { ); assert!(forwarded); } + let elapsed = started_at.elapsed(); + let avg_ns = avg_ns_per_iteration(elapsed, iterations); println!( - "packet_worker_primary_fec_profile_fixture iterations={} emitted={} elapsed_ms={}", + "packet_worker_primary_fec_profile_fixture iterations={} emitted={} elapsed_ms={} avg_ns_per_iteration={} avg_us_per_iteration={:.3}", iterations, emitted, - started_at.elapsed().as_millis() + elapsed.as_millis(), + avg_ns, + avg_ns as f64 / 1_000.0 ); } #[test] #[ignore = "profiling fixture for packet worker FEC recovery"] fn packet_worker_recovery_fec_profile_fixture() { - let iterations = std::env::var("SOF_PACKET_WORKER_FEC_RECOVERY_PROFILE_ITERS") - .ok() - .and_then(|raw| raw.parse::().ok()) - .filter(|value| *value > 0) - .unwrap_or(20_000); + let iterations = + read_positive_usize("SOF_PACKET_WORKER_FEC_RECOVERY_PROFILE_ITERS", 20_000); let batches = (0..iterations) .map(|iteration| { let slot = @@ -974,12 +1050,16 @@ mod tests { ); assert!(forwarded); } + let elapsed = started_at.elapsed(); + let avg_ns = avg_ns_per_iteration(elapsed, iterations); println!( - "packet_worker_recovery_fec_profile_fixture iterations={} emitted={} recovered={} elapsed_ms={}", + "packet_worker_recovery_fec_profile_fixture iterations={} emitted={} recovered={} elapsed_ms={} avg_ns_per_iteration={} avg_us_per_iteration={:.3}", iterations, emitted, recovered, - started_at.elapsed().as_millis() + elapsed.as_millis(), + avg_ns, + avg_ns as f64 / 1_000.0 ); } diff --git a/crates/sof-observer/src/app/runtime/tests.rs b/crates/sof-observer/src/app/runtime/tests.rs index e227b005..38bb241e 100644 --- a/crates/sof-observer/src/app/runtime/tests.rs +++ b/crates/sof-observer/src/app/runtime/tests.rs @@ -1,9 +1,64 @@ -use std::time::{Duration, Instant}; +use std::{ + hint::black_box, + time::{Duration, Instant}, +}; + +use sof_support::bench::{avg_ns_per_iteration, profile_iterations}; use crate::repair::{MissingShredRequest, MissingShredRequestKind}; use super::OutstandingRepairRequests; +#[test] +#[ignore = "profiling fixture for outstanding repair request churn"] +fn outstanding_repairs_profile_fixture() { + let iterations = profile_iterations(50_000); + let mut outstanding = OutstandingRepairRequests::new(Duration::from_millis(150)); + let started = Instant::now(); + + for slot in 0_u64..64 { + let request = MissingShredRequest { + slot, + index: 0, + kind: MissingShredRequestKind::HighestWindowIndex, + }; + assert!(outstanding.try_reserve(&request, started)); + } + + for iteration in 0..iterations { + let now = started + Duration::from_millis(u64::try_from(iteration % 500).unwrap_or(0)); + let request = MissingShredRequest { + slot: u64::try_from(iteration % 64).unwrap_or(0), + index: u32::try_from(iteration % 32).unwrap_or(0), + kind: if iteration % 4 == 0 { + MissingShredRequestKind::HighestWindowIndex + } else { + MissingShredRequestKind::WindowIndex + }, + }; + + black_box(outstanding.try_reserve(&request, now)); + if iteration % 3 == 0 { + black_box(outstanding.on_shred_received(request.slot, request.index)); + } + if iteration % 11 == 0 { + black_box(outstanding.purge_expired(now)); + } + } + + let elapsed = started.elapsed(); + let avg_ns_per_iteration = avg_ns_per_iteration(elapsed, iterations); + let avg_us_per_iteration = avg_ns_per_iteration as f64 / 1_000.0; + eprintln!( + "outstanding_repairs_profile_fixture iterations={} elapsed_us={} avg_ns_per_iteration={} avg_us_per_iteration={:.3} entries={}", + iterations, + elapsed.as_micros(), + avg_ns_per_iteration, + avg_us_per_iteration, + outstanding.len(), + ); +} + #[test] fn outstanding_repairs_dedup_within_timeout() { let mut outstanding = OutstandingRepairRequests::new(Duration::from_millis(150)); @@ -50,3 +105,35 @@ fn outstanding_repairs_clear_highest_on_any_slot_shred() { assert_eq!(outstanding.on_shred_received(777, 120), 1); assert_eq!(outstanding.len(), 0); } + +#[test] +fn outstanding_repairs_clear_only_matching_highest_prefix_for_slot() { + let mut outstanding = OutstandingRepairRequests::new(Duration::from_millis(150)); + let now = Instant::now(); + let first = MissingShredRequest { + slot: 800, + index: 10, + kind: MissingShredRequestKind::HighestWindowIndex, + }; + let second = MissingShredRequest { + slot: 800, + index: 25, + kind: MissingShredRequestKind::HighestWindowIndex, + }; + let other_slot = MissingShredRequest { + slot: 801, + index: 12, + kind: MissingShredRequestKind::HighestWindowIndex, + }; + + assert!(outstanding.try_reserve(&first, now)); + assert!(outstanding.try_reserve(&second, now)); + assert!(outstanding.try_reserve(&other_slot, now)); + + assert_eq!(outstanding.on_shred_received(800, 12), 1); + assert_eq!(outstanding.len(), 2); + + assert!(!outstanding.try_reserve(&second, now + Duration::from_millis(10))); + assert!(!outstanding.try_reserve(&other_slot, now + Duration::from_millis(10))); + assert!(outstanding.try_reserve(&first, now + Duration::from_millis(10))); +} diff --git a/crates/sof-observer/src/app/state/dedupe.rs b/crates/sof-observer/src/app/state/dedupe.rs index 1e5395c0..11032ef1 100644 --- a/crates/sof-observer/src/app/state/dedupe.rs +++ b/crates/sof-observer/src/app/state/dedupe.rs @@ -164,7 +164,6 @@ impl ShredDedupeCache { if matches!(observation, ShredDedupeObservation::Accepted) { existing.seen_at = now; self.order.push_back((now, key)); - self.evict(now); self.observe_depths(); } return observation; @@ -239,6 +238,8 @@ fn make_shred_dedupe_key( #[cfg(test)] mod tests { + use sof_support::bench::avg_ns_per_iteration; + use super::*; use crate::{ protocol::shred_wire::SIZE_OF_DATA_SHRED_PAYLOAD, shred::wire::parse_shred_header, @@ -529,12 +530,15 @@ mod tests { ); } let elapsed = start.elapsed(); + let avg_ns = avg_ns_per_iteration(elapsed, iterations); let metrics = cache.metrics(); println!( - "duplicate_ingress_profile_fixture iterations={} unique_keys={} elapsed_us={} queue_depth={} entries={}", + "duplicate_ingress_profile_fixture iterations={} unique_keys={} elapsed_us={} avg_ns_per_iteration={} avg_us_per_iteration={:.3} queue_depth={} entries={}", iterations, unique_keys, elapsed.as_micros(), + avg_ns, + avg_ns as f64 / 1_000.0, metrics.queue_depth, metrics.entries, ); @@ -569,11 +573,14 @@ mod tests { ); } let elapsed = start.elapsed(); + let avg_ns = avg_ns_per_iteration(elapsed, iterations); let metrics = cache.metrics(); println!( - "duplicate_parse_and_ingress_profile_fixture iterations={} elapsed_us={} queue_depth={} entries={}", + "duplicate_parse_and_ingress_profile_fixture iterations={} elapsed_us={} avg_ns_per_iteration={} avg_us_per_iteration={:.3} queue_depth={} entries={}", iterations, elapsed.as_micros(), + avg_ns, + avg_ns as f64 / 1_000.0, metrics.queue_depth, metrics.entries, ); diff --git a/crates/sof-observer/src/app/state/mod.rs b/crates/sof-observer/src/app/state/mod.rs index 7626598a..78a3e9b8 100644 --- a/crates/sof-observer/src/app/state/mod.rs +++ b/crates/sof-observer/src/app/state/mod.rs @@ -15,6 +15,8 @@ pub use fork::{ForkTracker, ForkTrackerSnapshot, ForkTrackerUpdate}; pub use latest::note_latest_shred_slot; pub use repair::OutstandingRepairRequests; +#[cfg(any(feature = "gossip-bootstrap", test))] +pub(super) use std::collections::BTreeSet; pub(super) use std::{ collections::{HashMap, VecDeque}, time::{Duration, Instant}, diff --git a/crates/sof-observer/src/app/state/repair.rs b/crates/sof-observer/src/app/state/repair.rs index 6d4f4d33..8e991af8 100644 --- a/crates/sof-observer/src/app/state/repair.rs +++ b/crates/sof-observer/src/app/state/repair.rs @@ -22,6 +22,11 @@ impl OutstandingRepairKey { pub struct OutstandingRepairRequests { entries: HashMap, #[cfg(any(feature = "gossip-bootstrap", test))] + highest_window_by_slot: BTreeSet<(u64, u32)>, + #[cfg(any(feature = "gossip-bootstrap", test))] + /// Insertion order for expiry, including stale superseded timestamps. + order: VecDeque<(OutstandingRepairKey, Instant)>, + #[cfg(any(feature = "gossip-bootstrap", test))] timeout: Duration, } @@ -32,16 +37,31 @@ impl OutstandingRepairRequests { Self { entries: HashMap::new(), #[cfg(any(feature = "gossip-bootstrap", test))] + highest_window_by_slot: BTreeSet::new(), + #[cfg(any(feature = "gossip-bootstrap", test))] + order: VecDeque::new(), + #[cfg(any(feature = "gossip-bootstrap", test))] timeout, } } - #[cfg(feature = "gossip-bootstrap")] + #[cfg(any(feature = "gossip-bootstrap", test))] pub fn purge_expired(&mut self, now: Instant) -> usize { - let before = self.entries.len(); - self.entries - .retain(|_, sent_at| now.saturating_duration_since(*sent_at) < self.timeout); - before.saturating_sub(self.entries.len()) + let mut removed = 0_usize; + while let Some((_, front_sent_at)) = self.order.front() { + if now.saturating_duration_since(*front_sent_at) < self.timeout { + break; + } + let Some((key, queued_sent_at)) = self.order.pop_front() else { + break; + }; + if self.entries.get(&key) == Some(&queued_sent_at) { + let _ = self.entries.remove(&key); + self.remove_highest_window_index(key); + removed = removed.saturating_add(1); + } + } + removed } #[cfg(any(feature = "gossip-bootstrap", test))] @@ -52,17 +72,21 @@ impl OutstandingRepairRequests { return false; } *sent_at = now; + self.insert_highest_window_index(key); + self.order.push_back((key, now)); return true; } let _ = self.entries.insert(key, now); + self.insert_highest_window_index(key); + self.order.push_back((key, now)); true } #[cfg(feature = "gossip-bootstrap")] pub fn release(&mut self, request: &MissingShredRequest) { - let _ = self - .entries - .remove(&OutstandingRepairKey::from_request(request)); + let key = OutstandingRepairKey::from_request(request); + let _ = self.entries.remove(&key); + self.remove_highest_window_index(key); } pub fn on_shred_received(&mut self, slot: u64, index: u32) -> usize { @@ -78,16 +102,48 @@ impl OutstandingRepairRequests { { removed = removed.saturating_add(1); } - let before = self.entries.len(); - self.entries.retain(|key, _| { - !(key.kind == MissingShredRequestKind::HighestWindowIndex - && key.slot == slot - && key.index <= index) - }); - removed.saturating_add(before.saturating_sub(self.entries.len())) + #[cfg(any(feature = "gossip-bootstrap", test))] + { + let highest_to_clear: Vec<_> = self + .highest_window_by_slot + .range((slot, 0)..=(slot, index)) + .copied() + .collect(); + for (highest_slot, highest_index) in highest_to_clear { + if self + .entries + .remove(&OutstandingRepairKey { + slot: highest_slot, + index: highest_index, + kind: MissingShredRequestKind::HighestWindowIndex, + }) + .is_some() + { + removed = removed.saturating_add(1); + } + let _ = self + .highest_window_by_slot + .remove(&(highest_slot, highest_index)); + } + } + removed } pub fn len(&self) -> usize { self.entries.len() } + + #[cfg(any(feature = "gossip-bootstrap", test))] + fn insert_highest_window_index(&mut self, key: OutstandingRepairKey) { + if key.kind == MissingShredRequestKind::HighestWindowIndex { + let _ = self.highest_window_by_slot.insert((key.slot, key.index)); + } + } + + #[cfg(any(feature = "gossip-bootstrap", test))] + fn remove_highest_window_index(&mut self, key: OutstandingRepairKey) { + if key.kind == MissingShredRequestKind::HighestWindowIndex { + let _ = self.highest_window_by_slot.remove(&(key.slot, key.index)); + } + } } diff --git a/crates/sof-observer/src/framework/derived_state.rs b/crates/sof-observer/src/framework/derived_state.rs index 77d320f9..af0b000b 100644 --- a/crates/sof-observer/src/framework/derived_state.rs +++ b/crates/sof-observer/src/framework/derived_state.rs @@ -5,6 +5,7 @@ //! It defines the feed envelope, event families, checkpoints, and consumer-facing //! fault types without yet wiring a runtime producer. +use sof_support::time_support::{current_unix_nanos, current_unix_secs}; use std::{ cell::RefCell, collections::HashMap, @@ -18,7 +19,7 @@ use std::{ mpsc, }, thread::JoinHandle, - time::{SystemTime, UNIX_EPOCH}, + time::SystemTime, }; use arcshift::ArcShift; @@ -40,6 +41,11 @@ use crate::{ }, }; +/// Maximum derived-state checkpoint bundle accepted from disk during restart recovery. +const MAX_CHECKPOINT_STORE_BYTES: u64 = 64 * 1024 * 1024; +/// Maximum retained replay record accepted from disk before the loader rejects the segment. +const MAX_DISK_REPLAY_RECORD_BYTES: usize = 16 * 1024 * 1024; + #[derive(Debug, Clone, Copy, Default, Eq, PartialEq)] /// Static feed subscriptions requested by one derived-state consumer during host construction. pub struct DerivedStateConsumerConfig { @@ -828,7 +834,7 @@ impl DerivedStateCheckpointStore { if !self.path.exists() { return Ok(None); } - let bytes = fs::read(&self.path).map_err(|error| { + let file = File::open(&self.path).map_err(|error| { DerivedStateConsumerFault::new( DerivedStateConsumerFaultKind::CheckpointWriteFailed, None, @@ -838,6 +844,44 @@ impl DerivedStateCheckpointStore { ), ) })?; + let file_len = file.metadata().map(|metadata| metadata.len()).unwrap_or(0); + if file_len > MAX_CHECKPOINT_STORE_BYTES { + return Err(DerivedStateConsumerFault::new( + DerivedStateConsumerFaultKind::CheckpointWriteFailed, + None, + format!( + "derived-state checkpoint {} exceeds max {} bytes", + self.path.display(), + MAX_CHECKPOINT_STORE_BYTES + ), + )); + } + let mut bytes = Vec::with_capacity( + usize::try_from(file_len.min(MAX_CHECKPOINT_STORE_BYTES)).unwrap_or(0), + ); + file.take(MAX_CHECKPOINT_STORE_BYTES.saturating_add(1)) + .read_to_end(&mut bytes) + .map_err(|error| { + DerivedStateConsumerFault::new( + DerivedStateConsumerFaultKind::CheckpointWriteFailed, + None, + format!( + "failed to read derived-state checkpoint {}: {error}", + self.path.display() + ), + ) + })?; + if u64::try_from(bytes.len()).unwrap_or(u64::MAX) > MAX_CHECKPOINT_STORE_BYTES { + return Err(DerivedStateConsumerFault::new( + DerivedStateConsumerFaultKind::CheckpointWriteFailed, + None, + format!( + "derived-state checkpoint {} exceeds max {} bytes", + self.path.display(), + MAX_CHECKPOINT_STORE_BYTES + ), + )); + } let persisted = serde_json::from_slice::>(&bytes) .map_err(|error| { DerivedStateConsumerFault::new( @@ -898,6 +942,17 @@ impl DerivedStateCheckpointStore { format!("failed to serialize derived-state checkpoint: {error}"), ) })?; + if u64::try_from(bytes.len()).unwrap_or(u64::MAX) > MAX_CHECKPOINT_STORE_BYTES { + return Err(DerivedStateConsumerFault::new( + DerivedStateConsumerFaultKind::CheckpointWriteFailed, + Some(checkpoint.last_applied_sequence), + format!( + "derived-state checkpoint {} exceeds max {} bytes", + self.path.display(), + MAX_CHECKPOINT_STORE_BYTES + ), + )); + } if let Some(parent) = self.path.parent() { fs::create_dir_all(parent).map_err(|error| { DerivedStateConsumerFault::new( @@ -1569,8 +1624,18 @@ impl DiskDerivedStateReplaySource { /// Serializes one feed envelope into an on-disk record payload. fn encode_envelope(envelope: &DerivedStateFeedEnvelope) -> io::Result> { - bincode::serialize(envelope) - .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error.to_string())) + let encoded = bincode::serialize(envelope) + .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error.to_string()))?; + if encoded.len() > MAX_DISK_REPLAY_RECORD_BYTES { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!( + "derived-state replay record exceeded max {} bytes", + MAX_DISK_REPLAY_RECORD_BYTES + ), + )); + } + Ok(encoded) } /// Deserializes one feed envelope from an on-disk record payload. @@ -1737,8 +1802,17 @@ impl DiskDerivedStateReplaySource { Err(error) if error.kind() == io::ErrorKind::UnexpectedEof => break, Err(error) => return Err(error), } - let encoded_len = u32::from_le_bytes(length_bytes); - let mut encoded = vec![0_u8; encoded_len as usize]; + let encoded_len = u32::from_le_bytes(length_bytes) as usize; + if encoded_len > MAX_DISK_REPLAY_RECORD_BYTES { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!( + "derived-state replay record exceeded max {} bytes", + MAX_DISK_REPLAY_RECORD_BYTES + ), + )); + } + let mut encoded = vec![0_u8; encoded_len]; file.read_exact(&mut encoded)?; envelopes.push(Self::decode_envelope(&encoded)?); } @@ -1764,8 +1838,17 @@ impl DiskDerivedStateReplaySource { Err(error) if error.kind() == io::ErrorKind::UnexpectedEof => break, Err(error) => return Err(error), } - let encoded_len = u32::from_le_bytes(length_bytes); - let mut encoded = vec![0_u8; encoded_len as usize]; + let encoded_len = u32::from_le_bytes(length_bytes) as usize; + if encoded_len > MAX_DISK_REPLAY_RECORD_BYTES { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!( + "derived-state replay record exceeded max {} bytes", + MAX_DISK_REPLAY_RECORD_BYTES + ), + )); + } + let mut encoded = vec![0_u8; encoded_len]; file.read_exact(&mut encoded)?; let envelope = Self::decode_envelope(&encoded)?; first_sequence.get_or_insert(envelope.sequence); @@ -1915,20 +1998,21 @@ impl DiskDerivedStateReplaySource { mut envelopes_to_remove: usize, ) -> io::Result { let mut removed = 0_usize; + let mut removed_segments = 0_usize; while envelopes_to_remove > 0 { - let Some(oldest_segment) = metadata.segments.first().cloned() else { + let Some(oldest_segment) = metadata.segments.get(removed_segments).cloned() else { break; }; if oldest_segment.retained_envelopes <= envelopes_to_remove { Self::evict_cached_appender(&oldest_segment.path); fs::remove_file(&oldest_segment.path)?; - metadata.segments.remove(0); metadata.retained_envelopes = metadata .retained_envelopes .saturating_sub(oldest_segment.retained_envelopes); envelopes_to_remove = envelopes_to_remove.saturating_sub(oldest_segment.retained_envelopes); removed = removed.saturating_add(oldest_segment.retained_envelopes); + removed_segments = removed_segments.saturating_add(1); continue; } @@ -1943,7 +2027,7 @@ impl DiskDerivedStateReplaySource { }; let new_path = self.segment_path(session_id, new_first_sequence); self.rewrite_records(&oldest_segment.path, &new_path, &retained)?; - let Some(oldest_segment_metadata) = metadata.segments.first_mut() else { + let Some(oldest_segment_metadata) = metadata.segments.get_mut(removed_segments) else { break; }; *oldest_segment_metadata = DiskDerivedStateSegmentMetadata { @@ -1958,6 +2042,9 @@ impl DiskDerivedStateReplaySource { removed = removed.saturating_add(envelopes_to_remove); envelopes_to_remove = 0; } + if removed_segments > 0 { + metadata.segments.drain(..removed_segments); + } Ok(removed) } @@ -1982,8 +2069,10 @@ impl DiskDerivedStateReplaySource { .collect::>(); retained_sessions.sort_by_key(|(session_id, _path)| *session_id); let mut removed_any = false; - while retained_sessions.len() > self.max_retained_sessions { - let Some((session_id, path)) = retained_sessions.first().cloned() else { + let mut removed_sessions = 0_usize; + while retained_sessions.len().saturating_sub(removed_sessions) > self.max_retained_sessions + { + let Some((session_id, path)) = retained_sessions.get(removed_sessions).cloned() else { break; }; if session_id == current_session_id { @@ -1992,9 +2081,12 @@ impl DiskDerivedStateReplaySource { Self::evict_cached_appenders_in_dir(&path); fs::remove_dir_all(&path)?; self.update_session_metadata(session_id, DiskDerivedStateSessionMetadata::default()); - retained_sessions.remove(0); + removed_sessions = removed_sessions.saturating_add(1); removed_any = true; } + if removed_sessions > 0 { + retained_sessions.drain(..removed_sessions); + } if removed_any { let _ = self.compactions.fetch_add(1, Ordering::Relaxed); } @@ -4393,7 +4485,7 @@ fn classify_rooted_account_observed_kind(event: &AccountUpdateEvent) -> RootedAc /// Returns the maximum wallclock skew across topology nodes when present. fn topology_max_wallclock_skew_ms(event: &ClusterTopologyEvent) -> Option { - let now_secs = SystemTime::now().duration_since(UNIX_EPOCH).ok()?.as_secs(); + let now_secs = current_unix_secs(); event .snapshot_nodes .iter() @@ -4615,9 +4707,7 @@ impl RegisteredDerivedStateConsumer { /// Generates a best-effort unique session id for one process lifetime. fn generate_session_id() -> FeedSessionId { - let now_nanos = SystemTime::now() - .duration_since(UNIX_EPOCH) - .map_or(0_u128, |duration| duration.as_nanos()); + let now_nanos = current_unix_nanos(); let pid = u128::from(std::process::id()); FeedSessionId(now_nanos ^ pid) } @@ -4639,9 +4729,7 @@ mod tests { }; fn unique_test_replay_dir(name: &str) -> PathBuf { - let unique = SystemTime::now() - .duration_since(UNIX_EPOCH) - .map_or(0_u128, |duration| duration.as_nanos()); + let unique = current_unix_nanos(); env::temp_dir().join(format!( "sof-derived-state-{name}-{}-{unique}", std::process::id() @@ -4736,6 +4824,144 @@ mod tests { } } + #[test] + fn checkpoint_store_rejects_oversized_files() { + let checkpoint_path = unique_test_checkpoint_path("store-oversized"); + let parent = checkpoint_path + .parent() + .map(Path::to_path_buf) + .unwrap_or_else(|| unique_test_replay_dir("store-oversized")); + let create_result = fs::create_dir_all(&parent); + assert!(create_result.is_ok(), "{create_result:?}"); + + let file_result = File::create(&checkpoint_path); + assert!(file_result.is_ok(), "{file_result:?}"); + let file = file_result.unwrap_or_else(|error| panic!("{error}")); + let set_len_result = file.set_len(MAX_CHECKPOINT_STORE_BYTES.saturating_add(1)); + assert!(set_len_result.is_ok(), "{set_len_result:?}"); + + let store = DerivedStateCheckpointStore::new(&checkpoint_path); + let load_result = store.load::(); + assert!(load_result.is_err(), "oversized checkpoint should fail"); + let error = match load_result { + Ok(value) => panic!("expected oversized checkpoint failure, got {value:?}"), + Err(error) => error, + }; + assert!(error.message.contains("exceeds max")); + + drop(fs::remove_file(&checkpoint_path)); + drop(fs::remove_dir_all(parent)); + } + + #[test] + fn checkpoint_store_rejects_oversized_writes() { + let checkpoint_path = unique_test_checkpoint_path("store-oversized-write"); + let store = DerivedStateCheckpointStore::new(&checkpoint_path); + let checkpoint = DerivedStateCheckpoint { + session_id: FeedSessionId(123), + last_applied_sequence: FeedSequence(9), + watermarks: FeedWatermarks::default(), + state_version: 1, + extension_version: "oversized-write-test".to_owned(), + }; + let oversized = "x".repeat( + usize::try_from(MAX_CHECKPOINT_STORE_BYTES) + .unwrap_or(0) + .saturating_add(1), + ); + + let store_result = store.store(&checkpoint, &oversized); + assert!( + store_result.is_err(), + "oversized checkpoint write should fail" + ); + let error = match store_result { + Ok(()) => panic!("expected oversized checkpoint write failure"), + Err(error) => error, + }; + assert!(error.message.contains("exceeds max")); + assert!( + !checkpoint_path.exists(), + "oversized checkpoint should not be written" + ); + + if let Some(parent) = checkpoint_path.parent() { + drop(fs::remove_dir_all(parent)); + } + } + + #[test] + fn replay_encode_rejects_oversized_records() { + let envelope = DerivedStateFeedEnvelope { + session_id: FeedSessionId(55), + sequence: FeedSequence(1), + emitted_at: SystemTime::UNIX_EPOCH, + watermarks: FeedWatermarks::default(), + event: DerivedStateFeedEvent::RootedAccountObserved(RootedAccountObservedEvent { + slot: 1, + commitment_status: TxCommitmentStatus::Finalized, + finalized_slot: Some(1), + pubkey: PubkeyBytes::from([1_u8; 32]), + owner: PubkeyBytes::from([2_u8; 32]), + lamports: 1, + executable: false, + rent_epoch: 0, + data: Arc::from( + vec![7_u8; MAX_DISK_REPLAY_RECORD_BYTES.saturating_add(1)].into_boxed_slice(), + ), + write_version: None, + txn_signature: None, + is_startup: false, + matched_filter: None, + provider_source: None, + kind: RootedAccountObservedKind::Other, + }), + }; + + let encoded = DiskDerivedStateReplaySource::encode_envelope(&envelope); + assert!(encoded.is_err(), "oversized replay record should fail"); + let error = match encoded { + Ok(value) => panic!( + "expected oversized replay encode failure, got {}", + value.len() + ), + Err(error) => error, + }; + assert_eq!(error.kind(), io::ErrorKind::InvalidData); + assert!(error.to_string().contains("exceeded max")); + } + + #[test] + fn replay_loader_rejects_oversized_segment_record() { + let replay_dir = unique_test_replay_dir("oversized-record"); + let source = DiskDerivedStateReplaySource::new(&replay_dir, 32); + assert!(source.is_ok()); + let source = source.unwrap_or_else(|error| panic!("{error}")); + + let segment_path = replay_dir.join("segment.bin"); + let file_result = File::create(&segment_path); + assert!(file_result.is_ok(), "{file_result:?}"); + let mut file = file_result.unwrap_or_else(|error| panic!("{error}")); + let record_len = + u32::try_from(MAX_DISK_REPLAY_RECORD_BYTES.saturating_add(1)).unwrap_or(u32::MAX); + let write_result = file.write_all(&record_len.to_le_bytes()); + assert!(write_result.is_ok(), "{write_result:?}"); + let flush_result = file.flush(); + assert!(flush_result.is_ok(), "{flush_result:?}"); + + let load_result = source.load_segment_from_disk(&segment_path); + assert!(load_result.is_err(), "oversized replay record should fail"); + let error = match load_result { + Ok(value) => panic!("expected oversized replay record failure, got {value:?}"), + Err(error) => error, + }; + assert_eq!(error.kind(), io::ErrorKind::InvalidData); + assert!(error.to_string().contains("exceeded max")); + + drop(fs::remove_file(&segment_path)); + drop(fs::remove_dir_all(replay_dir)); + } + #[test] fn feed_sequence_next_advances_by_one() { assert_eq!(FeedSequence(41).next(), Some(FeedSequence(42))); diff --git a/crates/sof-observer/src/framework/extension_host.rs b/crates/sof-observer/src/framework/extension_host.rs index dee3016d..b0dea21d 100644 --- a/crates/sof-observer/src/framework/extension_host.rs +++ b/crates/sof-observer/src/framework/extension_host.rs @@ -3,25 +3,35 @@ #[cfg(test)] use std::str::FromStr; use std::{ + any::{Any, type_name_of_val}, collections::HashSet, io::ErrorKind, net::SocketAddr, + panic::AssertUnwindSafe, sync::{ Arc, Mutex, RwLock, atomic::{AtomicU64, Ordering}, }, - time::{Duration, Instant, SystemTime, UNIX_EPOCH}, + time::{Duration, Instant}, }; -use futures_util::{SinkExt, StreamExt}; +use futures_util::{FutureExt, SinkExt, StreamExt}; +use sof_support::time_support::current_unix_ms; use tokio::{ io::AsyncReadExt, net::{TcpListener, TcpStream, UdpSocket}, sync::mpsc, - task::JoinHandle, + task::{JoinHandle, JoinSet}, time::timeout, }; -use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async, tungstenite::Message}; +use tokio_tungstenite::{ + MaybeTlsStream, WebSocketStream, connect_async_with_config, + tungstenite::{ + Error as WebSocketError, Message, + client::IntoClientRequest, + protocol::{CloseFrame, WebSocketConfig}, + }, +}; use crate::framework::extension::{ ExtensionCapability, ExtensionContext, ExtensionManifest, ExtensionResourceSpec, @@ -42,6 +52,12 @@ const DEFAULT_STARTUP_TIMEOUT_SECS: u64 = 5; const DEFAULT_SHUTDOWN_TIMEOUT_SECS: u64 = 3; /// Per-read fallback buffer size used for extension resource sockets. const DEFAULT_RESOURCE_READ_BUFFER_BYTES: usize = 2_048; +/// Maximum extension resource read buffer accepted from one startup manifest. +const MAX_RESOURCE_READ_BUFFER_BYTES: usize = 1024 * 1024; +/// Multiplier used to cap extension websocket frames/messages relative to chunk size. +const EXTENSION_WEBSOCKET_MESSAGE_LIMIT_MULTIPLIER: usize = 64; +/// Minimum non-zero timeout accepted for host startup/shutdown deadlines. +const MIN_EXTENSION_HOST_TIMEOUT: Duration = Duration::from_millis(1); /// Startup failure record for one extension. #[derive(Debug, Clone, Eq, PartialEq)] @@ -219,14 +235,22 @@ impl RuntimeExtensionHostBuilder { /// Sets startup timeout for `RuntimeExtension::setup`. #[must_use] pub const fn with_startup_timeout(mut self, timeout: Duration) -> Self { - self.startup_timeout = timeout; + self.startup_timeout = if timeout.is_zero() { + MIN_EXTENSION_HOST_TIMEOUT + } else { + timeout + }; self } /// Sets shutdown timeout for `RuntimeExtension::shutdown`. #[must_use] pub const fn with_shutdown_timeout(mut self, timeout: Duration) -> Self { - self.shutdown_timeout = timeout; + self.shutdown_timeout = if timeout.is_zero() { + MIN_EXTENSION_HOST_TIMEOUT + } else { + timeout + }; self } @@ -404,27 +428,15 @@ impl ExtensionDispatcher { record_max_atomic(&worker_max_dispatch_lag_us, queue_lag_us); let callback_extension = Arc::clone(&worker_extension); - let callback_result = tokio::spawn(async move { - callback_extension - .on_packet_received(queued_event.event) - .await; - }) - .await; - if let Err(error) = callback_result { - if error.is_panic() { - let payload = error.into_panic(); - let panic_message = panic_payload_to_string(payload.as_ref()); - tracing::error!( - extension = extension_name, - panic = %panic_message, - "runtime extension packet callback panicked; continuing runtime" - ); - } else { - tracing::error!( - extension = extension_name, - "runtime extension packet callback cancelled" - ); - } + if let Err(payload) = + invoke_extension_callback(callback_extension, queued_event.event).await + { + let panic_message = panic_payload_to_string(payload.as_ref()); + tracing::error!( + extension = extension_name, + panic = %panic_message, + "runtime extension packet callback panicked; continuing runtime" + ); } } }); @@ -518,6 +530,16 @@ impl ExtensionDispatcher { } } +/// Runs one extension packet callback while isolating panic unwinds from the dispatcher loop. +async fn invoke_extension_callback( + extension: Arc, + event: RuntimePacketEvent, +) -> Result<(), Box> { + AssertUnwindSafe(extension.on_packet_received(event)) + .catch_unwind() + .await +} + /// Separate runtime extension host from observer plugin host. #[derive(Clone)] pub struct RuntimeExtensionHost { @@ -665,7 +687,7 @@ impl RuntimeExtensionHost { let has_explicit_name = extension.has_explicit_name(); if !has_explicit_name { - let concrete_type_name = std::any::type_name_of_val(extension.as_ref()); + let concrete_type_name = type_name_of_val(extension.as_ref()); tracing::warn!( extension = extension_name, concrete_type = concrete_type_name, @@ -880,6 +902,7 @@ impl RuntimeExtensionHost { extension: &Arc, resources: &[ExtensionResourceSpec], ) -> Result<(), String> { + let startup_timeout = self.inner.startup_timeout; for resource in resources { let handle = match resource { ExtensionResourceSpec::UdpListener(spec) => { @@ -889,10 +912,12 @@ impl RuntimeExtensionHost { spawn_tcp_listener(self.clone(), extension, spec.clone()).await? } ExtensionResourceSpec::TcpConnector(spec) => { - spawn_tcp_connector(self.clone(), extension, spec.clone()).await? + spawn_tcp_connector(self.clone(), extension, spec.clone(), startup_timeout) + .await? } ExtensionResourceSpec::WsConnector(spec) => { - spawn_ws_connector(self.clone(), extension, spec.clone()).await? + spawn_ws_connector(self.clone(), extension, spec.clone(), startup_timeout) + .await? } }; extension.push_resource_handle(handle); @@ -987,6 +1012,22 @@ impl ExtensionResourceEmitter { event_class: RuntimePacketEventClass, websocket_frame_type: Option, bytes: Arc<[u8]>, + ) { + self.emit_event_with_remote_addr( + event_class, + websocket_frame_type, + self.remote_addr, + bytes, + ); + } + + /// Emits one runtime packet event from this resource with one explicit remote address. + fn emit_event_with_remote_addr( + &self, + event_class: RuntimePacketEventClass, + websocket_frame_type: Option, + remote_addr: Option, + bytes: Arc<[u8]>, ) { let source = RuntimePacketSource { kind: RuntimePacketSourceKind::ExtensionResource, @@ -997,7 +1038,7 @@ impl ExtensionResourceEmitter { shared_tag: self.shared_tag.clone(), websocket_frame_type, local_addr: self.local_addr, - remote_addr: self.remote_addr, + remote_addr, }; self.host.emit_extension_packet(source, bytes); } @@ -1055,26 +1096,24 @@ async fn spawn_udp_listener( continue; } if let Some(payload) = buffer.get(..len) { - ExtensionResourceEmitter { - remote_addr: Some(remote_addr), - ..emitter.clone() - } - .emit_event( + emitter.emit_event_with_remote_addr( RuntimePacketEventClass::Packet, None, + Some(remote_addr), Arc::from(payload), ); } } Err(error) => { - if error.kind() != ErrorKind::Interrupted { - tracing::warn!( - extension = owner_extension, - resource_id, - error = %error, - "udp extension listener read loop terminated" - ); + if error.kind() == ErrorKind::Interrupted { + continue; } + tracing::warn!( + extension = owner_extension, + resource_id, + error = %error, + "udp extension listener read loop terminated" + ); break; } } @@ -1099,34 +1138,42 @@ async fn spawn_tcp_listener( .read_buffer_bytes .max(DEFAULT_RESOURCE_READ_BUFFER_BYTES); let handle = tokio::spawn(async move { + let mut connections = JoinSet::new(); loop { - match listener.accept().await { - Ok((stream, remote_addr)) => { - let local_addr = stream.local_addr().ok(); - let emitter = ExtensionResourceEmitter::new( - host.clone(), - &owner_extension, - &resource_id, - shared_tag.clone(), - RuntimePacketTransport::Tcp, - local_addr, - Some(remote_addr), - ); - read_tcp_stream_packets( - ExtensionResourceReadContext::new(emitter, read_buffer_bytes), - stream, - ) - .await; - } - Err(error) => { - tracing::warn!( - extension = owner_extension, - resource_id, - error = %error, - "tcp extension listener accept loop terminated" - ); - break; + tokio::select! { + accepted = listener.accept() => { + match accepted { + Ok((stream, remote_addr)) => { + let local_addr = stream.local_addr().ok(); + let emitter = ExtensionResourceEmitter::new( + host.clone(), + &owner_extension, + &resource_id, + shared_tag.clone(), + RuntimePacketTransport::Tcp, + local_addr, + Some(remote_addr), + ); + connections.spawn(read_tcp_stream_packets( + ExtensionResourceReadContext::new(emitter, read_buffer_bytes), + stream, + )); + } + Err(error) => { + if error.kind() == ErrorKind::Interrupted { + continue; + } + tracing::warn!( + extension = owner_extension, + resource_id, + error = %error, + "tcp extension listener accept loop terminated" + ); + break; + } + } } + Some(_result) = connections.join_next(), if !connections.is_empty() => {} } } }); @@ -1138,9 +1185,17 @@ async fn spawn_tcp_connector( host: RuntimeExtensionHost, extension: &Arc, spec: TcpConnectorSpec, + startup_timeout: Duration, ) -> Result, String> { - let stream = TcpStream::connect(spec.remote_addr) + let stream = timeout(startup_timeout, TcpStream::connect(spec.remote_addr)) .await + .map_err(|_elapsed| { + format!( + "tcp connector {} timed out after {}ms during startup", + spec.remote_addr, + startup_timeout.as_millis() + ) + })? .map_err(|error| format!("failed to connect tcp {}: {error}", spec.remote_addr))?; let local_addr = stream.local_addr().ok(); let remote_addr = stream.peer_addr().ok(); @@ -1174,19 +1229,36 @@ async fn spawn_ws_connector( host: RuntimeExtensionHost, extension: &Arc, spec: WsConnectorSpec, + startup_timeout: Duration, ) -> Result, String> { - let (stream, _response) = connect_async(spec.url.as_str()) - .await - .map_err(|error| format!("failed to connect websocket {}: {error}", spec.url))?; + let max_payload_chunk_bytes = spec + .read_buffer_bytes + .max(DEFAULT_RESOURCE_READ_BUFFER_BYTES); + let (stream, _response) = timeout( + startup_timeout, + connect_async_with_config( + spec.url.as_str(), + Some(extension_websocket_transport_config( + max_payload_chunk_bytes, + )), + false, + ), + ) + .await + .map_err(|_elapsed| { + format!( + "websocket connector {} timed out after {}ms during startup", + spec.url, + startup_timeout.as_millis() + ) + })? + .map_err(|error| format!("failed to connect websocket {}: {error}", spec.url))?; let io = stream.get_ref().get_ref(); let local_addr = io.local_addr().ok(); let peer_addr = io.peer_addr().ok(); let owner_extension = extension.name.to_owned(); let resource_id = spec.resource_id; let shared_tag = visibility_tag(spec.visibility); - let max_payload_chunk_bytes = spec - .read_buffer_bytes - .max(DEFAULT_RESOURCE_READ_BUFFER_BYTES); let handle = tokio::spawn(async move { let emitter = ExtensionResourceEmitter::new( host, @@ -1206,6 +1278,16 @@ async fn spawn_ws_connector( Ok(handle) } +/// Builds one bounded websocket transport config for extension-owned connectors. +fn extension_websocket_transport_config(max_payload_chunk_bytes: usize) -> WebSocketConfig { + let max_message_size = max_payload_chunk_bytes + .max(DEFAULT_RESOURCE_READ_BUFFER_BYTES) + .saturating_mul(EXTENSION_WEBSOCKET_MESSAGE_LIMIT_MULTIPLIER); + WebSocketConfig::default() + .max_message_size(Some(max_message_size)) + .max_frame_size(Some(max_message_size)) +} + /// Reads packet chunks from one TCP stream and forwards them into the runtime. async fn read_tcp_stream_packets(context: ExtensionResourceReadContext, mut stream: TcpStream) { let mut buffer = vec![0_u8; context.max_payload_chunk_bytes.max(1)]; @@ -1222,14 +1304,15 @@ async fn read_tcp_stream_packets(context: ExtensionResourceReadContext, mut stre } } Err(error) => { - if error.kind() != ErrorKind::Interrupted { - tracing::warn!( - extension = context.emitter.owner_extension.as_str(), - resource_id = context.emitter.resource_id.as_str(), - error = %error, - "extension tcp stream read loop terminated" - ); + if error.kind() == ErrorKind::Interrupted { + continue; } + tracing::warn!( + extension = context.emitter.owner_extension.as_str(), + resource_id = context.emitter.resource_id.as_str(), + error = %error, + "extension tcp stream read loop terminated" + ); break; } } @@ -1281,37 +1364,39 @@ async fn read_websocket_messages( ); } Some(Ok(Message::Close(frame))) => { - let close_payload = frame - .as_ref() - .map(|frame| frame.reason.as_bytes()) - .unwrap_or_default(); - context.emitter.emit_event( - RuntimePacketEventClass::ConnectionClosed, - None, - Arc::from(close_payload), - ); - tracing::info!( - extension = context.emitter.owner_extension.as_str(), - resource_id = context.emitter.resource_id.as_str(), - close_code = frame.as_ref().map(|frame| u16::from(frame.code)), - close_reason = frame - .as_ref() - .map(|frame| frame.reason.to_string()) - .unwrap_or_default(), - "websocket connector closed by remote peer" - ); + emit_websocket_close_event(&context, frame.as_ref()); + if let Err(error) = stream.close(None).await + && !matches!( + error, + WebSocketError::ConnectionClosed | WebSocketError::AlreadyClosed + ) + { + tracing::warn!( + extension = context.emitter.owner_extension.as_str(), + resource_id = context.emitter.resource_id.as_str(), + error = %error, + "failed to complete websocket close handshake" + ); + } break; } Some(Ok(Message::Frame(_))) => { // Internal tungstenite frame detail; user-facing callbacks receive decoded messages. } Some(Err(error)) => { - tracing::warn!( - extension = context.emitter.owner_extension.as_str(), - resource_id = context.emitter.resource_id.as_str(), - error = %error, - "websocket connector read loop terminated" - ); + if matches!( + error, + WebSocketError::ConnectionClosed | WebSocketError::AlreadyClosed + ) { + emit_websocket_close_event(&context, None); + } else { + tracing::warn!( + extension = context.emitter.owner_extension.as_str(), + resource_id = context.emitter.resource_id.as_str(), + error = %error, + "websocket connector read loop terminated" + ); + } break; } None => break, @@ -1319,6 +1404,27 @@ async fn read_websocket_messages( } } +/// Emits one websocket connection-closed event with close-frame metadata when available. +fn emit_websocket_close_event(context: &ExtensionResourceReadContext, frame: Option<&CloseFrame>) { + let close_payload = frame + .map(|close_frame| close_frame.reason.as_bytes()) + .unwrap_or_default(); + context.emitter.emit_event( + RuntimePacketEventClass::ConnectionClosed, + None, + Arc::from(close_payload), + ); + tracing::info!( + extension = context.emitter.owner_extension.as_str(), + resource_id = context.emitter.resource_id.as_str(), + close_code = frame.map(|close_frame| u16::from(close_frame.code)), + close_reason = frame + .map(|close_frame| close_frame.reason.to_string()) + .unwrap_or_default(), + "websocket connector closed by remote peer" + ); +} + /// Converts extension stream visibility into the optional shared stream tag. fn visibility_tag(visibility: ExtensionStreamVisibility) -> Option { match visibility { @@ -1335,12 +1441,76 @@ struct ValidatedManifest { subscriptions: Vec, } +/// Validates one websocket connector URL before startup attempts any network work. +fn validate_websocket_resource_url(resource_id: &str, url: &str) -> Result<(), String> { + if url.trim().is_empty() { + return Err(format!( + "resource `{resource_id}` declares empty websocket url" + )); + } + url.into_client_request().map(|_| ()).map_err(|error| { + format!("resource `{resource_id}` declares invalid websocket url `{url}`: {error}") + }) +} + +/// Validates one packet subscription against startup invariants and granted capabilities. +fn validate_packet_subscription( + subscription: &PacketSubscription, + capabilities: &HashSet, +) -> Result<(), String> { + if matches!( + subscription.source_kind, + Some(RuntimePacketSourceKind::ObserverIngress) + ) { + if !capabilities.contains(&ExtensionCapability::ObserveObserverIngress) { + return Err( + "subscription declares ObserverIngress source without ObserveObserverIngress capability" + .to_owned(), + ); + } + if subscription.owner_extension.is_some() + || subscription.resource_id.is_some() + || subscription.shared_tag.is_some() + { + return Err( + "subscription declares ObserverIngress source with extension-resource-only selectors" + .to_owned(), + ); + } + } + if let Some(owner_extension) = subscription.owner_extension.as_ref() + && owner_extension.trim().is_empty() + { + return Err("subscription declares empty owner_extension".to_owned()); + } + if let Some(resource_id) = subscription.resource_id.as_ref() + && resource_id.trim().is_empty() + { + return Err("subscription declares empty resource_id".to_owned()); + } + if let Some(shared_tag) = subscription.shared_tag.as_ref() { + if shared_tag.trim().is_empty() { + return Err("subscription declares empty shared_tag".to_owned()); + } + if !capabilities.contains(&ExtensionCapability::ObserveSharedExtensionStream) { + return Err( + "subscription declares shared_tag without ObserveSharedExtensionStream capability" + .to_owned(), + ); + } + } + Ok(()) +} + /// Validates one extension manifest against the active runtime policy. fn validate_manifest( extension_name: &'static str, manifest: &ExtensionManifest, policy: &RuntimeExtensionCapabilityPolicy, ) -> Result { + if extension_name.trim().is_empty() { + return Err("extension declares empty name".to_owned()); + } let capabilities: HashSet = manifest.capabilities.iter().copied().collect(); for capability in &capabilities { @@ -1353,30 +1523,72 @@ fn validate_manifest( let mut resource_ids = HashSet::::new(); for resource in &manifest.resources { - let (resource_id, required_capability) = match resource { - ExtensionResourceSpec::UdpListener(spec) => { - (&spec.resource_id, ExtensionCapability::BindUdp) - } - ExtensionResourceSpec::TcpListener(spec) => { - (&spec.resource_id, ExtensionCapability::BindTcp) - } - ExtensionResourceSpec::TcpConnector(spec) => { - (&spec.resource_id, ExtensionCapability::ConnectTcp) - } - ExtensionResourceSpec::WsConnector(spec) => { - (&spec.resource_id, ExtensionCapability::ConnectWebSocket) - } + let (resource_id, visibility, read_buffer_bytes, required_capability) = match resource { + ExtensionResourceSpec::UdpListener(spec) => ( + &spec.resource_id, + &spec.visibility, + spec.read_buffer_bytes, + ExtensionCapability::BindUdp, + ), + ExtensionResourceSpec::TcpListener(spec) => ( + &spec.resource_id, + &spec.visibility, + spec.read_buffer_bytes, + ExtensionCapability::BindTcp, + ), + ExtensionResourceSpec::TcpConnector(spec) => ( + &spec.resource_id, + &spec.visibility, + spec.read_buffer_bytes, + ExtensionCapability::ConnectTcp, + ), + ExtensionResourceSpec::WsConnector(spec) => ( + &spec.resource_id, + &spec.visibility, + spec.read_buffer_bytes, + ExtensionCapability::ConnectWebSocket, + ), }; + if resource_id.trim().is_empty() { + return Err(format!( + "extension `{extension_name}` declares empty resource_id" + )); + } if !resource_ids.insert(resource_id.clone()) { return Err(format!( "duplicate resource_id `{resource_id}` in startup manifest for extension `{extension_name}`" )); } + if read_buffer_bytes == 0 { + return Err(format!( + "resource `{resource_id}` declares zero read_buffer_bytes" + )); + } + if read_buffer_bytes > MAX_RESOURCE_READ_BUFFER_BYTES { + return Err(format!( + "resource `{resource_id}` read_buffer_bytes {read_buffer_bytes} exceeds max {}", + MAX_RESOURCE_READ_BUFFER_BYTES + )); + } + if matches!( + visibility, + ExtensionStreamVisibility::Shared { tag } if tag.trim().is_empty() + ) { + return Err(format!( + "resource `{resource_id}` declares empty shared visibility tag" + )); + } if !capabilities.contains(&required_capability) { return Err(format!( "resource `{resource_id}` requires undeclared capability `{required_capability:?}`" )); } + if let ExtensionResourceSpec::WsConnector(spec) = resource { + validate_websocket_resource_url(resource_id, &spec.url)?; + } + } + for subscription in &manifest.subscriptions { + validate_packet_subscription(subscription, &capabilities)?; } Ok(ValidatedManifest { @@ -1396,16 +1608,8 @@ fn record_max_atomic(target: &AtomicU64, value: u64) { } } -/// Returns the current Unix timestamp in milliseconds. -fn current_unix_ms() -> u64 { - SystemTime::now() - .duration_since(UNIX_EPOCH) - .map(|duration| u64::try_from(duration.as_millis()).unwrap_or(u64::MAX)) - .unwrap_or_default() -} - /// Converts a panic payload into a loggable message string. -fn panic_payload_to_string(payload: &(dyn std::any::Any + Send)) -> String { +fn panic_payload_to_string(payload: &(dyn Any + Send)) -> String { payload.downcast_ref::<&str>().map_or_else( || { payload @@ -1420,10 +1624,17 @@ fn panic_payload_to_string(payload: &(dyn std::any::Any + Send)) -> String { #[cfg(test)] mod tests { use super::*; - use std::sync::atomic::{AtomicBool, AtomicUsize}; + use std::{ + net::TcpListener as StdTcpListener, + sync::atomic::{AtomicBool, AtomicUsize}, + }; use crate::framework::ExtensionSetupError; use async_trait::async_trait; + use sof_support::bench::{avg_ns_per_iteration, profile_iterations}; + use tokio::io::AsyncWriteExt; + use tokio::time::{Instant as TokioInstant, sleep}; + use tokio_tungstenite::accept_async; struct CounterExtension { name: &'static str, @@ -1453,8 +1664,72 @@ mod tests { async fn shutdown(&self, _ctx: ExtensionContext) { self.shutdown_called.store(true, Ordering::Relaxed); if !self.shutdown_wait.is_zero() { - tokio::time::sleep(self.shutdown_wait).await; + sleep(self.shutdown_wait).await; + } + } + } + + struct PanicOnceExtension { + panic_seen: AtomicBool, + packet_count: Arc, + } + + #[async_trait] + impl RuntimeExtension for PanicOnceExtension { + fn name(&self) -> &'static str { + "panic-once-extension" + } + + async fn setup( + &self, + _ctx: ExtensionContext, + ) -> Result { + Ok(ExtensionManifest { + capabilities: vec![ExtensionCapability::ObserveObserverIngress], + resources: Vec::new(), + subscriptions: vec![PacketSubscription { + source_kind: Some(RuntimePacketSourceKind::ObserverIngress), + ..PacketSubscription::default() + }], + }) + } + + async fn on_packet_received(&self, _event: RuntimePacketEvent) { + if !self.panic_seen.swap(true, Ordering::Relaxed) { + panic!("intentional extension panic"); } + self.packet_count.fetch_add(1, Ordering::Relaxed); + } + } + + fn sample_runtime_packet_event() -> RuntimePacketEvent { + RuntimePacketEvent { + source: RuntimePacketSource { + kind: RuntimePacketSourceKind::ObserverIngress, + transport: RuntimePacketTransport::Udp, + event_class: RuntimePacketEventClass::Packet, + owner_extension: None, + resource_id: None, + shared_tag: None, + websocket_frame_type: None, + local_addr: None, + remote_addr: Some(SocketAddr::from_str("127.0.0.1:9001").expect("valid addr")), + }, + bytes: Arc::from(&[7_u8; 32][..]), + observed_unix_ms: 0, + } + } + + async fn invoke_extension_callback_baseline( + extension: Arc, + event: RuntimePacketEvent, + ) -> Result<(), Box> { + let handle = tokio::spawn(async move { + extension.on_packet_received(event).await; + }); + match handle.await { + Ok(()) => Ok(()), + Err(error) => Err(error.into_panic()), } } @@ -1521,7 +1796,7 @@ mod tests { SocketAddr::from_str("127.0.0.1:8001").expect("valid addr"), &[1, 2, 3], ); - tokio::time::sleep(Duration::from_millis(50)).await; + sleep(Duration::from_millis(50)).await; assert_eq!(ok_counter.load(Ordering::Relaxed), 1); } @@ -1554,15 +1829,15 @@ mod tests { } #[tokio::test] - async fn production_defaults_deny_outbound_connectors() { - let host = RuntimeExtensionHost::production_builder() + async fn startup_rejects_empty_resource_id() { + let host = RuntimeExtensionHost::builder() .add_extension(CounterExtension { - name: "connect-tcp-extension", + name: "empty-resource-id", startup_manifest: ExtensionManifest { - capabilities: vec![ExtensionCapability::ConnectTcp], - resources: vec![ExtensionResourceSpec::TcpConnector(TcpConnectorSpec { - resource_id: "tcp-outbound".to_owned(), - remote_addr: SocketAddr::from_str("127.0.0.1:9").expect("valid addr"), + capabilities: vec![ExtensionCapability::BindUdp], + resources: vec![ExtensionResourceSpec::UdpListener(UdpListenerSpec { + resource_id: " ".to_owned(), + bind_addr: SocketAddr::from_str("127.0.0.1:0").expect("valid addr"), visibility: ExtensionStreamVisibility::Private, read_buffer_bytes: 128, })], @@ -1577,14 +1852,57 @@ mod tests { let report = host.startup().await; assert_eq!(report.active_extensions, 0); assert_eq!(report.failed_extensions, 1); - assert!(report.failures[0].reason.contains("not allowed")); + assert!(report.failures[0].reason.contains("empty resource_id")); } #[tokio::test] - async fn strict_name_policy_rejects_implicit_type_name_extensions() { + async fn startup_rejects_empty_extension_name() { let host = RuntimeExtensionHost::builder() - .with_require_explicit_extension_names(true) - .add_extension(ImplicitNameExtension) + .add_extension(CounterExtension { + name: " ", + startup_manifest: ExtensionManifest { + capabilities: vec![ExtensionCapability::BindUdp], + resources: vec![ExtensionResourceSpec::UdpListener(UdpListenerSpec { + resource_id: "udp-feed".to_owned(), + bind_addr: SocketAddr::from_str("127.0.0.1:0").expect("valid addr"), + visibility: ExtensionStreamVisibility::Private, + read_buffer_bytes: 128, + })], + subscriptions: Vec::new(), + }, + packet_count: Arc::new(AtomicUsize::new(0)), + shutdown_wait: Duration::ZERO, + shutdown_called: Arc::new(AtomicBool::new(false)), + }) + .build(); + + let report = host.startup().await; + assert_eq!(report.active_extensions, 0); + assert_eq!(report.failed_extensions, 1); + assert!(report.failures[0].reason.contains("empty name")); + } + + #[tokio::test] + async fn startup_rejects_empty_shared_visibility_tag() { + let host = RuntimeExtensionHost::builder() + .add_extension(CounterExtension { + name: "empty-shared-tag", + startup_manifest: ExtensionManifest { + capabilities: vec![ExtensionCapability::BindTcp], + resources: vec![ExtensionResourceSpec::TcpListener(TcpListenerSpec { + resource_id: "tcp-feed".to_owned(), + bind_addr: SocketAddr::from_str("127.0.0.1:0").expect("valid addr"), + visibility: ExtensionStreamVisibility::Shared { + tag: " ".to_owned(), + }, + read_buffer_bytes: 128, + })], + subscriptions: Vec::new(), + }, + packet_count: Arc::new(AtomicUsize::new(0)), + shutdown_wait: Duration::ZERO, + shutdown_called: Arc::new(AtomicBool::new(false)), + }) .build(); let report = host.startup().await; @@ -1593,37 +1911,293 @@ mod tests { assert!( report.failures[0] .reason - .contains("requires explicit stable extension names") + .contains("empty shared visibility tag") ); } #[tokio::test] - async fn owner_only_and_shared_stream_visibility() { - let owner_counter = Arc::new(AtomicUsize::new(0)); - let shared_counter = Arc::new(AtomicUsize::new(0)); + async fn startup_rejects_oversized_read_buffer_bytes() { let host = RuntimeExtensionHost::builder() .add_extension(CounterExtension { - name: "owner", + name: "oversized-read-buffer", startup_manifest: ExtensionManifest { - capabilities: vec![], - resources: Vec::new(), - subscriptions: vec![PacketSubscription { - source_kind: Some(RuntimePacketSourceKind::ExtensionResource), - owner_extension: Some("owner".to_owned()), - ..PacketSubscription::default() - }], + capabilities: vec![ExtensionCapability::ConnectWebSocket], + resources: vec![ExtensionResourceSpec::WsConnector(WsConnectorSpec { + resource_id: "ws-feed".to_owned(), + url: "ws://127.0.0.1:1/feed".to_owned(), + visibility: ExtensionStreamVisibility::Private, + read_buffer_bytes: MAX_RESOURCE_READ_BUFFER_BYTES.saturating_add(1), + })], + subscriptions: Vec::new(), }, - packet_count: Arc::clone(&owner_counter), + packet_count: Arc::new(AtomicUsize::new(0)), shutdown_wait: Duration::ZERO, shutdown_called: Arc::new(AtomicBool::new(false)), }) - .add_extension(CounterExtension { - name: "shared-reader", - startup_manifest: ExtensionManifest { - capabilities: vec![ExtensionCapability::ObserveSharedExtensionStream], - resources: Vec::new(), - subscriptions: vec![PacketSubscription { - source_kind: Some(RuntimePacketSourceKind::ExtensionResource), + .build(); + + let report = host.startup().await; + assert_eq!(report.active_extensions, 0); + assert_eq!(report.failed_extensions, 1); + assert!(report.failures[0].reason.contains("read_buffer_bytes")); + } + + #[tokio::test] + async fn startup_rejects_zero_read_buffer_bytes() { + let host = RuntimeExtensionHost::builder() + .add_extension(CounterExtension { + name: "zero-read-buffer", + startup_manifest: ExtensionManifest { + capabilities: vec![ExtensionCapability::ConnectWebSocket], + resources: vec![ExtensionResourceSpec::WsConnector(WsConnectorSpec { + resource_id: "ws-feed".to_owned(), + url: "ws://127.0.0.1:1/feed".to_owned(), + visibility: ExtensionStreamVisibility::Private, + read_buffer_bytes: 0, + })], + subscriptions: Vec::new(), + }, + packet_count: Arc::new(AtomicUsize::new(0)), + shutdown_wait: Duration::ZERO, + shutdown_called: Arc::new(AtomicBool::new(false)), + }) + .build(); + + let report = host.startup().await; + assert_eq!(report.active_extensions, 0); + assert_eq!(report.failed_extensions, 1); + assert!(report.failures[0].reason.contains("zero read_buffer_bytes")); + } + + #[tokio::test] + async fn startup_rejects_invalid_websocket_url() { + let host = RuntimeExtensionHost::builder() + .add_extension(CounterExtension { + name: "invalid-websocket-url", + startup_manifest: ExtensionManifest { + capabilities: vec![ExtensionCapability::ConnectWebSocket], + resources: vec![ExtensionResourceSpec::WsConnector(WsConnectorSpec { + resource_id: "ws-feed".to_owned(), + url: "not a websocket url".to_owned(), + visibility: ExtensionStreamVisibility::Private, + read_buffer_bytes: 128, + })], + subscriptions: Vec::new(), + }, + packet_count: Arc::new(AtomicUsize::new(0)), + shutdown_wait: Duration::ZERO, + shutdown_called: Arc::new(AtomicBool::new(false)), + }) + .build(); + + let report = host.startup().await; + assert_eq!(report.active_extensions, 0); + assert_eq!(report.failed_extensions, 1); + assert!(report.failures[0].reason.contains("invalid websocket url")); + } + + #[tokio::test] + async fn startup_rejects_empty_subscription_shared_tag() { + let host = RuntimeExtensionHost::builder() + .add_extension(CounterExtension { + name: "empty-subscription-shared-tag", + startup_manifest: ExtensionManifest { + capabilities: vec![ExtensionCapability::ObserveSharedExtensionStream], + resources: Vec::new(), + subscriptions: vec![PacketSubscription { + source_kind: Some(RuntimePacketSourceKind::ExtensionResource), + shared_tag: Some(" ".to_owned()), + ..PacketSubscription::default() + }], + }, + packet_count: Arc::new(AtomicUsize::new(0)), + shutdown_wait: Duration::ZERO, + shutdown_called: Arc::new(AtomicBool::new(false)), + }) + .build(); + + let report = host.startup().await; + assert_eq!(report.active_extensions, 0); + assert_eq!(report.failed_extensions, 1); + assert!(report.failures[0].reason.contains("empty shared_tag")); + } + + #[tokio::test] + async fn startup_rejects_shared_stream_subscription_without_capability() { + let host = RuntimeExtensionHost::builder() + .add_extension(CounterExtension { + name: "missing-shared-stream-capability", + startup_manifest: ExtensionManifest { + capabilities: vec![ExtensionCapability::ObserveObserverIngress], + resources: Vec::new(), + subscriptions: vec![PacketSubscription { + source_kind: Some(RuntimePacketSourceKind::ExtensionResource), + shared_tag: Some("shared-feed".to_owned()), + ..PacketSubscription::default() + }], + }, + packet_count: Arc::new(AtomicUsize::new(0)), + shutdown_wait: Duration::ZERO, + shutdown_called: Arc::new(AtomicBool::new(false)), + }) + .build(); + + let report = host.startup().await; + assert_eq!(report.active_extensions, 0); + assert_eq!(report.failed_extensions, 1); + assert!( + report.failures[0] + .reason + .contains("ObserveSharedExtensionStream capability") + ); + } + + #[tokio::test] + async fn startup_rejects_observer_ingress_subscription_without_capability() { + let host = RuntimeExtensionHost::builder() + .add_extension(CounterExtension { + name: "missing-observer-ingress-capability", + startup_manifest: ExtensionManifest { + capabilities: vec![ExtensionCapability::ObserveSharedExtensionStream], + resources: Vec::new(), + subscriptions: vec![PacketSubscription { + source_kind: Some(RuntimePacketSourceKind::ObserverIngress), + ..PacketSubscription::default() + }], + }, + packet_count: Arc::new(AtomicUsize::new(0)), + shutdown_wait: Duration::ZERO, + shutdown_called: Arc::new(AtomicBool::new(false)), + }) + .build(); + + let report = host.startup().await; + assert_eq!(report.active_extensions, 0); + assert_eq!(report.failed_extensions, 1); + assert!( + report.failures[0] + .reason + .contains("ObserveObserverIngress capability") + ); + } + + #[tokio::test] + async fn startup_rejects_observer_ingress_subscription_with_resource_selectors() { + let host = RuntimeExtensionHost::builder() + .add_extension(CounterExtension { + name: "observer-ingress-resource-selectors", + startup_manifest: ExtensionManifest { + capabilities: vec![ExtensionCapability::ObserveObserverIngress], + resources: Vec::new(), + subscriptions: vec![PacketSubscription { + source_kind: Some(RuntimePacketSourceKind::ObserverIngress), + owner_extension: Some("owner".to_owned()), + ..PacketSubscription::default() + }], + }, + packet_count: Arc::new(AtomicUsize::new(0)), + shutdown_wait: Duration::ZERO, + shutdown_called: Arc::new(AtomicBool::new(false)), + }) + .build(); + + let report = host.startup().await; + assert_eq!(report.active_extensions, 0); + assert_eq!(report.failed_extensions, 1); + assert!( + report.failures[0] + .reason + .contains("extension-resource-only selectors") + ); + } + + #[test] + fn builder_clamps_zero_startup_timeout() { + let host = RuntimeExtensionHost::builder() + .with_startup_timeout(Duration::ZERO) + .build(); + assert_eq!(host.inner.startup_timeout, Duration::from_millis(1)); + } + + #[test] + fn builder_clamps_zero_shutdown_timeout() { + let host = RuntimeExtensionHost::builder() + .with_shutdown_timeout(Duration::ZERO) + .build(); + assert_eq!(host.inner.shutdown_timeout, Duration::from_millis(1)); + } + + #[tokio::test] + async fn production_defaults_deny_outbound_connectors() { + let host = RuntimeExtensionHost::production_builder() + .add_extension(CounterExtension { + name: "connect-tcp-extension", + startup_manifest: ExtensionManifest { + capabilities: vec![ExtensionCapability::ConnectTcp], + resources: vec![ExtensionResourceSpec::TcpConnector(TcpConnectorSpec { + resource_id: "tcp-outbound".to_owned(), + remote_addr: SocketAddr::from_str("127.0.0.1:9").expect("valid addr"), + visibility: ExtensionStreamVisibility::Private, + read_buffer_bytes: 128, + })], + subscriptions: Vec::new(), + }, + packet_count: Arc::new(AtomicUsize::new(0)), + shutdown_wait: Duration::ZERO, + shutdown_called: Arc::new(AtomicBool::new(false)), + }) + .build(); + + let report = host.startup().await; + assert_eq!(report.active_extensions, 0); + assert_eq!(report.failed_extensions, 1); + assert!(report.failures[0].reason.contains("not allowed")); + } + + #[tokio::test] + async fn strict_name_policy_rejects_implicit_type_name_extensions() { + let host = RuntimeExtensionHost::builder() + .with_require_explicit_extension_names(true) + .add_extension(ImplicitNameExtension) + .build(); + + let report = host.startup().await; + assert_eq!(report.active_extensions, 0); + assert_eq!(report.failed_extensions, 1); + assert!( + report.failures[0] + .reason + .contains("requires explicit stable extension names") + ); + } + + #[tokio::test] + async fn owner_only_and_shared_stream_visibility() { + let owner_counter = Arc::new(AtomicUsize::new(0)); + let shared_counter = Arc::new(AtomicUsize::new(0)); + let host = RuntimeExtensionHost::builder() + .add_extension(CounterExtension { + name: "owner", + startup_manifest: ExtensionManifest { + capabilities: vec![], + resources: Vec::new(), + subscriptions: vec![PacketSubscription { + source_kind: Some(RuntimePacketSourceKind::ExtensionResource), + owner_extension: Some("owner".to_owned()), + ..PacketSubscription::default() + }], + }, + packet_count: Arc::clone(&owner_counter), + shutdown_wait: Duration::ZERO, + shutdown_called: Arc::new(AtomicBool::new(false)), + }) + .add_extension(CounterExtension { + name: "shared-reader", + startup_manifest: ExtensionManifest { + capabilities: vec![ExtensionCapability::ObserveSharedExtensionStream], + resources: Vec::new(), + subscriptions: vec![PacketSubscription { + source_kind: Some(RuntimePacketSourceKind::ExtensionResource), shared_tag: Some("shared-feed".to_owned()), ..PacketSubscription::default() }], @@ -1650,7 +2224,7 @@ mod tests { }, Arc::from(&[1_u8][..]), ); - tokio::time::sleep(Duration::from_millis(50)).await; + sleep(Duration::from_millis(50)).await; assert_eq!(owner_counter.load(Ordering::Relaxed), 1); assert_eq!(shared_counter.load(Ordering::Relaxed), 0); @@ -1668,7 +2242,7 @@ mod tests { }, Arc::from(&[2_u8][..]), ); - tokio::time::sleep(Duration::from_millis(50)).await; + sleep(Duration::from_millis(50)).await; assert_eq!(owner_counter.load(Ordering::Relaxed), 2); assert_eq!(shared_counter.load(Ordering::Relaxed), 1); } @@ -1744,7 +2318,7 @@ mod tests { Arc::from(&[2_u8][..]), ); - tokio::time::sleep(Duration::from_millis(50)).await; + sleep(Duration::from_millis(50)).await; assert_eq!(text_counter.load(Ordering::Relaxed), 1); assert_eq!(any_counter.load(Ordering::Relaxed), 2); } @@ -1788,10 +2362,57 @@ mod tests { }, Arc::from(&[][..]), ); - tokio::time::sleep(Duration::from_millis(50)).await; + sleep(Duration::from_millis(50)).await; assert_eq!(close_counter.load(Ordering::Relaxed), 1); } + #[tokio::test] + async fn tcp_listener_accepts_new_connections_while_existing_stream_stays_open() { + let probe = StdTcpListener::bind("127.0.0.1:0").expect("bind probe listener"); + let bind_addr = probe.local_addr().expect("probe local addr"); + drop(probe); + + let packet_count = Arc::new(AtomicUsize::new(0)); + let host = RuntimeExtensionHost::builder() + .add_extension(CounterExtension { + name: "tcp-listener-extension", + startup_manifest: ExtensionManifest { + capabilities: vec![ExtensionCapability::BindTcp], + resources: vec![ExtensionResourceSpec::TcpListener(TcpListenerSpec { + resource_id: "tcp-listener".to_owned(), + bind_addr, + visibility: ExtensionStreamVisibility::Private, + read_buffer_bytes: 128, + })], + subscriptions: vec![PacketSubscription { + source_kind: Some(RuntimePacketSourceKind::ExtensionResource), + transport: Some(RuntimePacketTransport::Tcp), + owner_extension: Some("tcp-listener-extension".to_owned()), + ..PacketSubscription::default() + }], + }, + packet_count: Arc::clone(&packet_count), + shutdown_wait: Duration::ZERO, + shutdown_called: Arc::new(AtomicBool::new(false)), + }) + .build(); + let report = host.startup().await; + assert_eq!(report.active_extensions, 1); + + let _first = TcpStream::connect(bind_addr) + .await + .expect("connect first tcp client"); + let mut second = TcpStream::connect(bind_addr) + .await + .expect("connect second tcp client"); + assert!(second.write_all(b"second").await.is_ok()); + + sleep(Duration::from_millis(50)).await; + assert_eq!(packet_count.load(Ordering::Relaxed), 1); + + host.shutdown().await; + } + struct SlowExtension { counter: Arc, } @@ -1817,7 +2438,7 @@ mod tests { } async fn on_packet_received(&self, _event: RuntimePacketEvent) { - tokio::time::sleep(Duration::from_millis(120)).await; + sleep(Duration::from_millis(120)).await; self.counter.fetch_add(1, Ordering::Relaxed); } } @@ -1838,7 +2459,7 @@ mod tests { for _ in 0..16 { host.on_observer_packet(source, &[7_u8; 32]); } - tokio::time::sleep(Duration::from_millis(350)).await; + sleep(Duration::from_millis(350)).await; assert!(counter.load(Ordering::Relaxed) < 16); assert!(host.dropped_event_count() > 0); @@ -1850,6 +2471,144 @@ mod tests { assert!(metrics[0].dispatched_events >= 1); } + #[tokio::test] + async fn packet_callback_panic_does_not_stop_dispatcher() { + let packet_count = Arc::new(AtomicUsize::new(0)); + let host = RuntimeExtensionHost::builder() + .add_extension(PanicOnceExtension { + panic_seen: AtomicBool::new(false), + packet_count: Arc::clone(&packet_count), + }) + .build(); + let report = host.startup().await; + assert_eq!(report.active_extensions, 1); + + let source = SocketAddr::from_str("127.0.0.1:9001").expect("valid addr"); + host.on_observer_packet(source, &[1_u8; 8]); + host.on_observer_packet(source, &[2_u8; 8]); + + sleep(Duration::from_millis(100)).await; + assert_eq!(packet_count.load(Ordering::Relaxed), 1); + + host.shutdown().await; + } + + #[tokio::test] + #[ignore = "profiling fixture for runtime extension callback isolation"] + async fn runtime_extension_callback_isolation_profile_fixture() { + let iterations = profile_iterations(50_000); + let baseline_extension = Arc::new(CounterExtension { + name: "baseline-counter", + startup_manifest: ExtensionManifest { + capabilities: Vec::new(), + resources: Vec::new(), + subscriptions: Vec::new(), + }, + packet_count: Arc::new(AtomicUsize::new(0)), + shutdown_wait: Duration::ZERO, + shutdown_called: Arc::new(AtomicBool::new(false)), + }); + let optimized_extension = Arc::new(CounterExtension { + name: "optimized-counter", + startup_manifest: ExtensionManifest { + capabilities: Vec::new(), + resources: Vec::new(), + subscriptions: Vec::new(), + }, + packet_count: Arc::new(AtomicUsize::new(0)), + shutdown_wait: Duration::ZERO, + shutdown_called: Arc::new(AtomicBool::new(false)), + }); + let event = sample_runtime_packet_event(); + + let baseline_started_at = Instant::now(); + for _ in 0..iterations { + invoke_extension_callback_baseline( + Arc::clone(&baseline_extension) as Arc, + event.clone(), + ) + .await + .expect("baseline callback"); + } + let baseline_elapsed = baseline_started_at.elapsed(); + + let optimized_started_at = Instant::now(); + for _ in 0..iterations { + invoke_extension_callback( + Arc::clone(&optimized_extension) as Arc, + event.clone(), + ) + .await + .expect("optimized callback"); + } + let optimized_elapsed = optimized_started_at.elapsed(); + let baseline_avg_ns = avg_ns_per_iteration(baseline_elapsed, iterations); + let optimized_avg_ns = avg_ns_per_iteration(optimized_elapsed, iterations); + + println!( + "runtime_extension_callback_isolation_profile_fixture iterations={} baseline_us={} optimized_us={} baseline_avg_ns_per_iteration={} optimized_avg_ns_per_iteration={} baseline_avg_us_per_iteration={:.3} optimized_avg_us_per_iteration={:.3}", + iterations, + baseline_elapsed.as_micros(), + optimized_elapsed.as_micros(), + baseline_avg_ns, + optimized_avg_ns, + baseline_avg_ns as f64 / 1_000.0, + optimized_avg_ns as f64 / 1_000.0, + ); + } + + #[test] + #[ignore = "profiling fixture for udp extension emitter remote-address churn"] + fn udp_extension_emitter_remote_addr_profile_fixture() { + let iterations = profile_iterations(200_000); + let host = RuntimeExtensionHost::default(); + let emitter = ExtensionResourceEmitter::new( + host, + "udp-extension", + "udp-listener", + None, + RuntimePacketTransport::Udp, + Some(SocketAddr::from_str("127.0.0.1:7000").expect("valid local addr")), + None, + ); + let remote_addr = Some(SocketAddr::from_str("127.0.0.1:8000").expect("valid remote addr")); + let payload = Arc::from(&[9_u8; 256][..]); + + let baseline_started_at = Instant::now(); + for _ in 0..iterations { + ExtensionResourceEmitter { + remote_addr, + ..emitter.clone() + } + .emit_event(RuntimePacketEventClass::Packet, None, Arc::clone(&payload)); + } + let baseline_elapsed = baseline_started_at.elapsed(); + + let optimized_started_at = Instant::now(); + for _ in 0..iterations { + emitter.emit_event_with_remote_addr( + RuntimePacketEventClass::Packet, + None, + remote_addr, + Arc::clone(&payload), + ); + } + let optimized_elapsed = optimized_started_at.elapsed(); + let baseline_avg_ns = avg_ns_per_iteration(baseline_elapsed, iterations); + let optimized_avg_ns = avg_ns_per_iteration(optimized_elapsed, iterations); + + println!( + "udp_extension_emitter_remote_addr_profile_fixture iterations={} baseline_us={} optimized_us={} baseline_avg_ns_per_iteration={} optimized_avg_ns_per_iteration={} baseline_avg_us_per_iteration={:.3} optimized_avg_us_per_iteration={:.3}", + iterations, + baseline_elapsed.as_micros(), + optimized_elapsed.as_micros(), + baseline_avg_ns, + optimized_avg_ns, + baseline_avg_ns as f64 / 1_000.0, + optimized_avg_ns as f64 / 1_000.0, + ); + } + #[tokio::test] async fn shutdown_deadline_then_cancel() { let shutdown_called = Arc::new(AtomicBool::new(false)); @@ -1875,7 +2634,7 @@ mod tests { let report = host.startup().await; assert_eq!(report.active_extensions, 1); - let started = tokio::time::Instant::now(); + let started = TokioInstant::now(); host.shutdown().await; let elapsed = started.elapsed(); assert!(elapsed >= shutdown_timeout); @@ -1893,11 +2652,7 @@ mod tests { let tcp_server_addr = tcp_server.local_addr().expect("tcp local addr"); let tcp_server_task = tokio::spawn(async move { if let Ok((mut stream, _)) = tcp_server.accept().await { - assert!( - tokio::io::AsyncWriteExt::write_all(&mut stream, b"tcp") - .await - .is_ok() - ); + assert!(stream.write_all(b"tcp").await.is_ok()); } }); @@ -1907,7 +2662,7 @@ mod tests { let ws_server_addr = ws_server.local_addr().expect("ws local addr"); let ws_server_task = tokio::spawn(async move { if let Ok((stream, _)) = ws_server.accept().await - && let Ok(mut websocket) = tokio_tungstenite::accept_async(stream).await + && let Ok(mut websocket) = accept_async(stream).await { assert!(websocket.send(Message::Text("ws".into())).await.is_ok()); } @@ -1966,4 +2721,119 @@ mod tests { assert!(ws_server_task.await.is_ok()); host.shutdown().await; } + + #[tokio::test] + #[ignore = "requires local socket bind/connect permissions"] + async fn websocket_connector_remote_close_dispatches_connection_closed() { + let ws_server = TcpListener::bind("127.0.0.1:0") + .await + .expect("bind ws server"); + let ws_server_addr = ws_server.local_addr().expect("ws local addr"); + let ws_server_task = tokio::spawn(async move { + if let Ok((stream, _)) = ws_server.accept().await + && let Ok(mut websocket) = accept_async(stream).await + { + assert!(websocket.close(None).await.is_ok()); + } + }); + + let closed_count = Arc::new(AtomicUsize::new(0)); + let host = RuntimeExtensionHost::builder() + .add_extension(CounterExtension { + name: "ws-close-extension", + startup_manifest: ExtensionManifest { + capabilities: vec![ExtensionCapability::ConnectWebSocket], + resources: vec![ExtensionResourceSpec::WsConnector(WsConnectorSpec { + resource_id: "ws-connector".to_owned(), + url: format!("ws://{ws_server_addr}/feed"), + visibility: ExtensionStreamVisibility::Private, + read_buffer_bytes: 128, + })], + subscriptions: vec![PacketSubscription { + source_kind: Some(RuntimePacketSourceKind::ExtensionResource), + transport: Some(RuntimePacketTransport::WebSocket), + event_class: Some(RuntimePacketEventClass::ConnectionClosed), + owner_extension: Some("ws-close-extension".to_owned()), + ..PacketSubscription::default() + }], + }, + packet_count: Arc::clone(&closed_count), + shutdown_wait: Duration::ZERO, + shutdown_called: Arc::new(AtomicBool::new(false)), + }) + .build(); + let report = host.startup().await; + assert_eq!(report.active_extensions, 1); + + sleep(Duration::from_millis(100)).await; + assert_eq!(closed_count.load(Ordering::Relaxed), 1); + + assert!(ws_server_task.await.is_ok()); + host.shutdown().await; + } + + #[tokio::test] + async fn startup_times_out_hung_websocket_connector() { + let ws_server = TcpListener::bind("127.0.0.1:0") + .await + .expect("bind ws server"); + let ws_server_addr = ws_server.local_addr().expect("ws local addr"); + let ws_server_task = tokio::spawn(async move { + if let Ok((_stream, _)) = ws_server.accept().await { + sleep(Duration::from_secs(5)).await; + } + }); + + let host = RuntimeExtensionHost::builder() + .with_startup_timeout(Duration::from_millis(50)) + .add_extension(CounterExtension { + name: "hung-ws-connector", + startup_manifest: ExtensionManifest { + capabilities: vec![ExtensionCapability::ConnectWebSocket], + resources: vec![ExtensionResourceSpec::WsConnector(WsConnectorSpec { + resource_id: "ws-connector".to_owned(), + url: format!("ws://{ws_server_addr}/feed"), + visibility: ExtensionStreamVisibility::Private, + read_buffer_bytes: 128, + })], + subscriptions: Vec::new(), + }, + packet_count: Arc::new(AtomicUsize::new(0)), + shutdown_wait: Duration::ZERO, + shutdown_called: Arc::new(AtomicBool::new(false)), + }) + .build(); + + let report = host.startup().await; + assert_eq!(report.active_extensions, 0); + assert_eq!(report.failed_extensions, 1); + assert!(report.failures[0].reason.contains("timed out")); + + ws_server_task.abort(); + drop(ws_server_task.await); + } + + #[test] + fn extension_websocket_transport_config_caps_frames_from_chunk_size() { + let transport = extension_websocket_transport_config(4_096); + let expected = 4_096_usize.saturating_mul(EXTENSION_WEBSOCKET_MESSAGE_LIMIT_MULTIPLIER); + assert_eq!(transport.max_message_size, Some(expected)); + assert_eq!(transport.max_frame_size, Some(expected)); + + let floor_transport = extension_websocket_transport_config(1); + assert_eq!( + floor_transport.max_message_size, + Some( + DEFAULT_RESOURCE_READ_BUFFER_BYTES + .saturating_mul(EXTENSION_WEBSOCKET_MESSAGE_LIMIT_MULTIPLIER) + ) + ); + assert_eq!( + floor_transport.max_frame_size, + Some( + DEFAULT_RESOURCE_READ_BUFFER_BYTES + .saturating_mul(EXTENSION_WEBSOCKET_MESSAGE_LIMIT_MULTIPLIER) + ) + ); + } } diff --git a/crates/sof-observer/src/framework/plugin.rs b/crates/sof-observer/src/framework/plugin.rs index b5041dfc..1b4c2ed6 100644 --- a/crates/sof-observer/src/framework/plugin.rs +++ b/crates/sof-observer/src/framework/plugin.rs @@ -222,7 +222,9 @@ impl TransactionPrefilter { { return TransactionInterest::Ignore; } - if transaction_matches_any_keys(event.tx, &self.account_exclude) { + if !matches!(self.account_exclude, CompiledAccountMatcher::Empty) + && transaction_matches_any_keys(event.tx, &self.account_exclude) + { return TransactionInterest::Ignore; } if !transaction_matches_all_keys(event.tx, &self.account_required) { @@ -247,7 +249,9 @@ impl TransactionPrefilter { { return TransactionInterest::Ignore; } - if transaction_view_matches_any_keys(view, &self.account_exclude) { + if !matches!(self.account_exclude, CompiledAccountMatcher::Empty) + && transaction_view_matches_any_keys(view, &self.account_exclude) + { return TransactionInterest::Ignore; } if !transaction_view_matches_all_keys(view, &self.account_required) { diff --git a/crates/sof-observer/src/ingest/receiver/core.rs b/crates/sof-observer/src/ingest/receiver/core.rs index f1eaba71..096b5bb0 100644 --- a/crates/sof-observer/src/ingest/receiver/core.rs +++ b/crates/sof-observer/src/ingest/receiver/core.rs @@ -1,12 +1,12 @@ #![allow(clippy::indexing_slicing)] -use std::io::{ErrorKind, IoSliceMut}; +use std::io::{Error as IoError, ErrorKind, IoSliceMut}; use std::net::SocketAddr; #[cfg(target_os = "linux")] use std::os::fd::{AsFd, AsRawFd}; use std::sync::Arc; use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; -use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; +use std::time::{Duration, Instant}; use crossbeam_queue::ArrayQueue; @@ -17,13 +17,15 @@ use nix::poll::{PollFd, PollFlags}; #[cfg(target_os = "linux")] use nix::sys::socket::{ControlMessageOwned, MsgFlags, SockaddrStorage, recvmsg}; use socket2::SockRef; +use sof_support::time_support::current_unix_ms; use thiserror::Error; use tokio::task::{self, JoinHandle}; use crate::ingest::RawPacketBatchSender; use crate::ingest::config::{ enable_rxq_ovfl_tracking, read_udp_batch_max_wait_ms, read_udp_batch_size, - read_udp_idle_wait_ms, read_udp_rcvbuf_bytes, read_udp_receiver_core, + read_udp_drop_on_channel_full, read_udp_idle_wait_ms, read_udp_rcvbuf_bytes, + read_udp_receiver_core, }; #[derive(Debug, Clone, Copy, Eq, PartialEq)] @@ -168,7 +170,7 @@ impl RawPacketBatch { ) -> Result<(), UdpReceiverError> { let Some(storage) = self.storage.as_mut() else { return Err(UdpReceiverError::Receive { - source: std::io::Error::other("raw packet batch storage missing"), + source: IoError::other("raw packet batch storage missing"), }); }; if len > UDP_PACKET_BUFFER_BYTES { @@ -213,25 +215,31 @@ impl RawPacketBatch { ingress: RawPacketIngress, bytes: &[u8], ) -> Result<(), UdpReceiverError> { - if self.storage.is_none() { - return Err(UdpReceiverError::Receive { - source: std::io::Error::other("raw packet batch storage missing"), - }); - } if bytes.len() > UDP_PACKET_BUFFER_BYTES { return Err(UdpReceiverError::InvalidPacketLength { len: bytes.len(), capacity: UDP_PACKET_BUFFER_BYTES, }); } - let buffer_index = self.ensure_receive_slots(1); - let buffer = - self.receive_buffer_mut(buffer_index) - .ok_or_else(|| UdpReceiverError::Receive { - source: std::io::Error::other("raw packet receive buffer missing"), - })?; - buffer[..bytes.len()].copy_from_slice(bytes); - self.push_received_metadata(source, ingress, buffer_index, bytes.len()) + let Some(storage) = self.storage.as_mut() else { + return Err(UdpReceiverError::Receive { + source: IoError::other("raw packet batch storage missing"), + }); + }; + let buffer_index = storage.packets.len(); + if buffer_index == storage.buffers.len() { + storage.buffers.push([0_u8; UDP_PACKET_BUFFER_BYTES]); + } + let packet_len = bytes.len(); + debug_assert!(buffer_index < storage.buffers.len()); + storage.buffers[buffer_index][..packet_len].copy_from_slice(bytes); + storage.packets.push(RawPacket { + source, + ingress, + buffer_index, + len: packet_len, + }); + Ok(()) } #[must_use] @@ -499,9 +507,10 @@ fn run_udp_receiver_with_socket( let mut batch = recycler.allocate(); let mut batch_started_at: Option = None; let mut last_rxq_ovfl_counter: Option = None; + let drop_on_full = read_udp_drop_on_channel_full(); loop { if should_shutdown(shutdown) { - flush_batch(tx, &mut batch, telemetry); + flush_batch(tx, &mut batch, drop_on_full, telemetry); return Ok(()); } #[cfg(target_os = "linux")] @@ -521,7 +530,7 @@ fn run_udp_receiver_with_socket( if let Some(telemetry) = telemetry { telemetry.record_packets(received); } - flush_batch(tx, &mut batch, telemetry); + flush_batch(tx, &mut batch, drop_on_full, telemetry); continue; } Err(error) @@ -574,7 +583,7 @@ fn run_udp_receiver_with_socket( .map(|started_at| started_at.elapsed()) .unwrap_or_default(); if batch.len() >= batch_size || batch_elapsed >= batch_max_wait { - flush_batch(tx, &mut batch, telemetry); + flush_batch(tx, &mut batch, drop_on_full, telemetry); batch_started_at = None; if current_wait != idle_wait { std_socket @@ -587,7 +596,7 @@ fn run_udp_receiver_with_socket( Err(error) if error.kind() == ErrorKind::WouldBlock || error.kind() == ErrorKind::TimedOut => { - flush_batch(tx, &mut batch, telemetry); + flush_batch(tx, &mut batch, drop_on_full, telemetry); batch_started_at = None; if current_wait != idle_wait { std_socket @@ -613,6 +622,6 @@ fn should_shutdown(shutdown: Option<&Arc>) -> bool { #[path = "io.rs"] mod io; use io::{ - current_unix_ms, flush_batch, maybe_pin_receiver_thread, recv_udp_batch_coalesced, - recv_udp_packet, tune_udp_socket, + flush_batch, maybe_pin_receiver_thread, recv_udp_batch_coalesced, recv_udp_packet, + tune_udp_socket, }; diff --git a/crates/sof-observer/src/ingest/receiver/io.rs b/crates/sof-observer/src/ingest/receiver/io.rs index d63f9939..0ab72c39 100644 --- a/crates/sof-observer/src/ingest/receiver/io.rs +++ b/crates/sof-observer/src/ingest/receiver/io.rs @@ -3,15 +3,29 @@ #![allow(clippy::arithmetic_side_effects)] use super::*; +use std::{ + io, + mem::{size_of, zeroed}, + net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6, UdpSocket}, + ptr::null_mut, +}; + +use crate::ingest::RawPacketBatchSender; +#[cfg(test)] +use crate::ingest::config::read_udp_drop_on_channel_full; use crate::ingest::config::{ read_udp_busy_poll_budget, read_udp_busy_poll_us, read_udp_prefer_busy_poll, }; +#[cfg(target_os = "linux")] +use nix::errno::Errno; #[cfg(all(target_os = "linux", test))] use nix::poll::PollFlags; #[cfg(target_os = "linux")] use nix::poll::{PollFd, ppoll}; #[cfg(target_os = "linux")] use nix::sys::time::TimeSpec; +#[cfg(target_os = "linux")] +const SOCKADDR_STORAGE_LEN: libc::socklen_t = size_of::() as _; pub(super) struct UdpReceive { pub(super) len: usize, @@ -19,6 +33,18 @@ pub(super) struct UdpReceive { pub(super) rxq_ovfl_counter: Option, } +fn retry_on_interrupted(mut operation: F) -> io::Result +where + F: FnMut() -> io::Result, +{ + loop { + match operation() { + Err(error) if error.kind() == ErrorKind::Interrupted => continue, + result => return result, + } + } +} + #[cfg(target_os = "linux")] pub(super) struct UdpBatchScratch { io_vectors: Vec, @@ -32,13 +58,28 @@ impl UdpBatchScratch { let capacity = capacity.max(1); // SAFETY: The libc socket structs are plain old data and immediately // initialized before each syscall use. - let io_vectors = vec![unsafe { std::mem::zeroed() }; capacity]; + let mut io_vectors = vec![unsafe { zeroed() }; capacity]; // SAFETY: The libc socket structs are plain old data and immediately // initialized before each syscall use. - let addrs = vec![unsafe { std::mem::zeroed() }; capacity]; + let mut addrs = vec![unsafe { zeroed() }; capacity]; // SAFETY: The libc socket structs are plain old data and immediately // initialized before each syscall use. - let headers = vec![unsafe { std::mem::zeroed() }; capacity]; + let mut headers = vec![unsafe { zeroed() }; capacity]; + for index in 0..capacity { + headers[index] = libc::mmsghdr { + msg_hdr: libc::msghdr { + msg_name: (&mut addrs[index]) as *mut libc::sockaddr_storage + as *mut libc::c_void, + msg_namelen: SOCKADDR_STORAGE_LEN, + msg_iov: (&mut io_vectors[index]) as *mut libc::iovec, + msg_iovlen: 1, + msg_control: null_mut(), + msg_controllen: 0, + msg_flags: 0, + }, + msg_len: 0, + }; + } Self { io_vectors, addrs, @@ -48,29 +89,33 @@ impl UdpBatchScratch { } pub(super) fn recv_udp_packet( - socket: &std::net::UdpSocket, + socket: &UdpSocket, buffer: &mut [u8], track_rxq_ovfl: bool, -) -> std::io::Result { +) -> io::Result { #[cfg(target_os = "linux")] if track_rxq_ovfl { let mut io_vectors = [IoSliceMut::new(buffer)]; let mut cmsg_space = nix::cmsg_space!([u32; 1]); - let message = recvmsg::( - socket.as_raw_fd(), - &mut io_vectors, - Some(&mut cmsg_space), - MsgFlags::empty(), - ) - .map_err(nix_errno_to_io_error)?; + let message = loop { + match recvmsg::( + socket.as_raw_fd(), + &mut io_vectors, + Some(&mut cmsg_space), + MsgFlags::empty(), + ) { + Err(Errno::EINTR) => continue, + result => break result.map_err(nix_errno_to_io_error), + } + }?; let Some(source_storage) = message.address.as_ref() else { - return Err(std::io::Error::new( + return Err(io::Error::new( ErrorKind::InvalidData, "udp recvmsg missing source address", )); }; let Some(source) = sockaddr_storage_to_socket_addr(source_storage) else { - return Err(std::io::Error::new( + return Err(io::Error::new( ErrorKind::InvalidData, "udp recvmsg source address is not inet/inet6", )); @@ -91,7 +136,7 @@ pub(super) fn recv_udp_packet( }); } - let (len, source) = socket.recv_from(buffer)?; + let (len, source) = retry_on_interrupted(|| socket.recv_from(buffer))?; Ok(UdpReceive { len, source, @@ -102,23 +147,33 @@ pub(super) fn recv_udp_packet( #[cfg(test)] #[cfg(target_os = "linux")] pub(super) fn recv_udp_batch( - socket: &std::net::UdpSocket, + socket: &UdpSocket, scratch: &mut UdpBatchScratch, batch: &mut RawPacketBatch, -) -> std::io::Result { +) -> io::Result { batch.clear(); recv_udp_batch_append(socket, scratch, batch, scratch.headers.len()) } +#[cfg(all(test, target_os = "linux"))] +fn recv_udp_batch_baseline( + socket: &UdpSocket, + scratch: &mut UdpBatchScratch, + batch: &mut RawPacketBatch, +) -> io::Result { + batch.clear(); + recv_udp_batch_append_baseline(socket, scratch, batch, scratch.headers.len()) +} + #[cfg(target_os = "linux")] pub(super) fn recv_udp_batch_coalesced( - socket: &std::net::UdpSocket, + socket: &UdpSocket, scratch: &mut UdpBatchScratch, batch: &mut RawPacketBatch, idle_wait: Duration, batch_max_wait: Duration, poll_fd: &mut [PollFd<'_>], -) -> std::io::Result { +) -> io::Result { batch.clear(); let mut total_received = 0_usize; let deadline = Instant::now() + batch_max_wait; @@ -154,7 +209,7 @@ pub(super) fn recv_udp_batch_coalesced( }; if !wait_udp_readable(poll_fd, wait)? { if total_received == 0 { - return Err(std::io::Error::from(ErrorKind::WouldBlock)); + return Err(io::Error::from(ErrorKind::WouldBlock)); } break; } @@ -168,11 +223,11 @@ pub(super) fn recv_udp_batch_coalesced( #[cfg(target_os = "linux")] fn recv_udp_batch_append( - socket: &std::net::UdpSocket, + socket: &UdpSocket, scratch: &mut UdpBatchScratch, batch: &mut RawPacketBatch, max_packets: usize, -) -> std::io::Result { +) -> io::Result { let capacity = scratch.headers.len(); let count = capacity.min(max_packets); if count == 0 { @@ -182,9 +237,91 @@ fn recv_udp_batch_append( for index in 0..count { let buffer_index = start_index.saturating_add(index); let Some(buffer) = batch.receive_buffer_mut(buffer_index) else { - return Err(std::io::Error::other( - "raw packet batch receive buffer missing", - )); + return Err(io::Error::other("raw packet batch receive buffer missing")); + }; + scratch.io_vectors[index] = libc::iovec { + iov_base: buffer.as_mut_ptr().cast(), + iov_len: buffer.len(), + }; + scratch.headers[index].msg_hdr.msg_namelen = SOCKADDR_STORAGE_LEN; + scratch.headers[index].msg_hdr.msg_flags = 0; + scratch.headers[index].msg_len = 0; + } + + let received = retry_on_interrupted(|| { + // SAFETY: All message headers, names, and iovecs point to valid writable + // memory for the duration of the syscall, and the socket fd remains live. + let received = unsafe { + libc::recvmmsg( + socket.as_raw_fd(), + scratch.headers.as_mut_ptr(), + count.min(u32::MAX as usize) as u32, + libc::MSG_WAITFORONE, + null_mut(), + ) + }; + if received < 0 { + return Err(io::Error::last_os_error()); + } + Ok(received) + })?; + let received = usize::try_from(received).unwrap_or(0); + if received == 0 { + return Ok(0); + } + + batch.reserve(received); + for index in 0..received { + let len = usize::try_from(scratch.headers[index].msg_len).unwrap_or(0); + let buffer_index = start_index.saturating_add(index); + let source = sockaddr_storage_to_socket_addr_libc( + &scratch.addrs[index], + scratch.headers[index].msg_hdr.msg_namelen, + ) + .ok_or_else(|| { + io::Error::new( + ErrorKind::InvalidData, + "udp recvmmsg source address is not inet/inet6", + ) + })?; + batch + .push_received_metadata(source, RawPacketIngress::Udp, buffer_index, len) + .map_err(|error| match error { + UdpReceiverError::InvalidPacketLength { len, capacity } => io::Error::new( + ErrorKind::InvalidData, + format!( + "udp recvmmsg returned packet length {len} beyond buffer capacity {capacity}" + ), + ), + UdpReceiverError::Receive { source: io_error } => io_error, + UdpReceiverError::BindSocket { .. } + | UdpReceiverError::SetBlockingMode { .. } + | UdpReceiverError::SetReadTimeout { .. } => io::Error::new( + ErrorKind::InvalidData, + "udp recvmmsg packet push failed", + ), + })?; + } + Ok(received) +} + +#[cfg(all(test, target_os = "linux"))] +fn recv_udp_batch_append_baseline( + socket: &UdpSocket, + scratch: &mut UdpBatchScratch, + batch: &mut RawPacketBatch, + max_packets: usize, +) -> io::Result { + let capacity = scratch.headers.len(); + let count = capacity.min(max_packets); + if count == 0 { + return Ok(0); + } + let start_index = batch.ensure_receive_slots(count); + for index in 0..count { + let buffer_index = start_index.saturating_add(index); + let Some(buffer) = batch.receive_buffer_mut(buffer_index) else { + return Err(io::Error::other("raw packet batch receive buffer missing")); }; scratch.io_vectors[index] = libc::iovec { iov_base: buffer.as_mut_ptr().cast(), @@ -194,10 +331,10 @@ fn recv_udp_batch_append( msg_hdr: libc::msghdr { msg_name: (&mut scratch.addrs[index]) as *mut libc::sockaddr_storage as *mut libc::c_void, - msg_namelen: std::mem::size_of::() as libc::socklen_t, + msg_namelen: size_of::() as libc::socklen_t, msg_iov: (&mut scratch.io_vectors[index]) as *mut libc::iovec, msg_iovlen: 1, - msg_control: std::ptr::null_mut(), + msg_control: null_mut(), msg_controllen: 0, msg_flags: 0, }, @@ -205,20 +342,23 @@ fn recv_udp_batch_append( }; } - // SAFETY: All message headers, names, and iovecs point to valid writable - // memory for the duration of the syscall, and the socket fd remains live. - let received = unsafe { - libc::recvmmsg( - socket.as_raw_fd(), - scratch.headers.as_mut_ptr(), - count.min(u32::MAX as usize) as u32, - libc::MSG_WAITFORONE, - std::ptr::null_mut(), - ) - }; - if received < 0 { - return Err(std::io::Error::last_os_error()); - } + let received = retry_on_interrupted(|| { + // SAFETY: All message headers, names, and iovecs point to valid writable + // memory for the duration of the syscall, and the socket fd remains live. + let received = unsafe { + libc::recvmmsg( + socket.as_raw_fd(), + scratch.headers.as_mut_ptr(), + count.min(u32::MAX as usize) as u32, + libc::MSG_WAITFORONE, + null_mut(), + ) + }; + if received < 0 { + return Err(io::Error::last_os_error()); + } + Ok(received) + })?; let received = usize::try_from(received).unwrap_or(0); if received == 0 { return Ok(0); @@ -228,17 +368,20 @@ fn recv_udp_batch_append( for index in 0..received { let len = usize::try_from(scratch.headers[index].msg_len).unwrap_or(0); let buffer_index = start_index.saturating_add(index); - let source = - sockaddr_storage_to_socket_addr_libc(&scratch.addrs[index]).ok_or_else(|| { - std::io::Error::new( - ErrorKind::InvalidData, - "udp recvmmsg source address is not inet/inet6", - ) - })?; + let source = sockaddr_storage_to_socket_addr_libc( + &scratch.addrs[index], + scratch.headers[index].msg_hdr.msg_namelen, + ) + .ok_or_else(|| { + io::Error::new( + ErrorKind::InvalidData, + "udp recvmmsg source address is not inet/inet6", + ) + })?; batch .push_received_metadata(source, RawPacketIngress::Udp, buffer_index, len) .map_err(|error| match error { - UdpReceiverError::InvalidPacketLength { len, capacity } => std::io::Error::new( + UdpReceiverError::InvalidPacketLength { len, capacity } => io::Error::new( ErrorKind::InvalidData, format!( "udp recvmmsg returned packet length {len} beyond buffer capacity {capacity}" @@ -247,7 +390,7 @@ fn recv_udp_batch_append( UdpReceiverError::Receive { source: io_error } => io_error, UdpReceiverError::BindSocket { .. } | UdpReceiverError::SetBlockingMode { .. } - | UdpReceiverError::SetReadTimeout { .. } => std::io::Error::new( + | UdpReceiverError::SetReadTimeout { .. } => io::Error::new( ErrorKind::InvalidData, "udp recvmmsg packet push failed", ), @@ -257,46 +400,58 @@ fn recv_udp_batch_append( } #[cfg(target_os = "linux")] -fn wait_udp_readable(poll_fd: &mut [PollFd<'_>], timeout: Duration) -> std::io::Result { +fn wait_udp_readable(poll_fd: &mut [PollFd<'_>], timeout: Duration) -> io::Result { if timeout.is_zero() { return Ok(false); } - Ok(ppoll(poll_fd, Some(TimeSpec::from_duration(timeout)), None)? > 0) + Ok(retry_on_interrupted(|| { + ppoll(poll_fd, Some(TimeSpec::from_duration(timeout)), None).map_err(nix_errno_to_io_error) + })? > 0) } #[cfg(target_os = "linux")] -fn nix_errno_to_io_error(error: nix::errno::Errno) -> std::io::Error { - std::io::Error::from_raw_os_error(error as i32) +fn nix_errno_to_io_error(error: nix::errno::Errno) -> io::Error { + io::Error::from_raw_os_error(error as i32) } #[cfg(target_os = "linux")] fn sockaddr_storage_to_socket_addr(storage: &SockaddrStorage) -> Option { storage .as_sockaddr_in() - .map(|address| SocketAddr::from(std::net::SocketAddrV4::from(*address))) + .map(|address| SocketAddr::from(SocketAddrV4::from(*address))) .or_else(|| { storage .as_sockaddr_in6() - .map(|address| SocketAddr::from(std::net::SocketAddrV6::from(*address))) + .map(|address| SocketAddr::from(SocketAddrV6::from(*address))) }) } #[cfg(target_os = "linux")] -fn sockaddr_storage_to_socket_addr_libc(storage: &libc::sockaddr_storage) -> Option { +fn sockaddr_storage_to_socket_addr_libc( + storage: &libc::sockaddr_storage, + namelen: libc::socklen_t, +) -> Option { + let namelen = usize::try_from(namelen).ok()?; match i32::from(storage.ss_family) { libc::AF_INET => { + if namelen < size_of::() { + return None; + } // SAFETY: `ss_family` confirmed AF_INET, so reinterpret as sockaddr_in. let address = unsafe { &*(storage as *const _ as *const libc::sockaddr_in) }; - Some(SocketAddr::from(std::net::SocketAddrV4::new( - std::net::Ipv4Addr::from(address.sin_addr.s_addr.to_ne_bytes()), + Some(SocketAddr::from(SocketAddrV4::new( + Ipv4Addr::from(address.sin_addr.s_addr.to_ne_bytes()), u16::from_be(address.sin_port), ))) } libc::AF_INET6 => { + if namelen < size_of::() { + return None; + } // SAFETY: `ss_family` confirmed AF_INET6, so reinterpret as sockaddr_in6. let address = unsafe { &*(storage as *const _ as *const libc::sockaddr_in6) }; - Some(SocketAddr::from(std::net::SocketAddrV6::new( - std::net::Ipv6Addr::from(address.sin6_addr.s6_addr), + Some(SocketAddr::from(SocketAddrV6::new( + Ipv6Addr::from(address.sin6_addr.s6_addr), u16::from_be(address.sin6_port), address.sin6_flowinfo, address.sin6_scope_id, @@ -307,8 +462,18 @@ fn sockaddr_storage_to_socket_addr_libc(storage: &libc::sockaddr_storage) -> Opt } pub(super) fn flush_batch( - tx: &crate::ingest::RawPacketBatchSender, + tx: &RawPacketBatchSender, batch: &mut RawPacketBatch, + drop_on_full: bool, + telemetry: Option<&ReceiverTelemetry>, +) { + flush_batch_inner(tx, batch, drop_on_full, telemetry); +} + +fn flush_batch_inner( + tx: &RawPacketBatchSender, + batch: &mut RawPacketBatch, + drop_on_full: bool, telemetry: Option<&ReceiverTelemetry>, ) { if batch.is_empty() { @@ -316,7 +481,6 @@ pub(super) fn flush_batch( } let packet_count = batch.len(); let outbound = batch.take_for_send(); - let drop_on_full = crate::ingest::config::read_udp_drop_on_channel_full(); if tx.send_batch(outbound, drop_on_full) { if let Some(telemetry) = telemetry { telemetry.record_sent_batch(packet_count); @@ -326,40 +490,45 @@ pub(super) fn flush_batch( } } -pub(super) fn current_unix_ms() -> u64 { - SystemTime::now() - .duration_since(UNIX_EPOCH) - .map_or(0, |duration| { - duration.as_millis().min(u128::from(u64::MAX)) as u64 - }) +#[cfg(test)] +fn flush_batch_baseline( + tx: &RawPacketBatchSender, + batch: &mut RawPacketBatch, + telemetry: Option<&ReceiverTelemetry>, +) { + let drop_on_full = read_udp_drop_on_channel_full(); + flush_batch_inner(tx, batch, drop_on_full, telemetry); } -pub(super) fn tune_udp_socket(socket: &std::net::UdpSocket) { +pub(super) fn tune_udp_socket(socket: &UdpSocket) { let Some(rcvbuf_bytes) = read_udp_rcvbuf_bytes() else { tune_udp_busy_poll(socket); return; }; let sockref = SockRef::from(socket); - if let Err(error) = sockref.set_recv_buffer_size(rcvbuf_bytes) { - tracing::warn!( - requested = rcvbuf_bytes, - error = %error, - "failed to set UDP receive buffer size" - ); - return; - } - if let Ok(actual) = sockref.recv_buffer_size() { - tracing::debug!( - requested = rcvbuf_bytes, - actual, - "configured UDP receive buffer size" - ); + match sockref.set_recv_buffer_size(rcvbuf_bytes) { + Ok(()) => { + if let Ok(actual) = sockref.recv_buffer_size() { + tracing::debug!( + requested = rcvbuf_bytes, + actual, + "configured UDP receive buffer size" + ); + } + } + Err(error) => { + tracing::warn!( + requested = rcvbuf_bytes, + error = %error, + "failed to set UDP receive buffer size" + ); + } } tune_udp_busy_poll(socket); } #[cfg(target_os = "linux")] -fn tune_udp_busy_poll(socket: &std::net::UdpSocket) { +fn tune_udp_busy_poll(socket: &UdpSocket) { const SO_BUSY_POLL: libc::c_int = 46; const SO_PREFER_BUSY_POLL: libc::c_int = 69; const SO_BUSY_POLL_BUDGET: libc::c_int = 70; @@ -372,23 +541,36 @@ fn tune_udp_busy_poll(socket: &std::net::UdpSocket) { } if let Some(timeout_us) = busy_poll_us { - set_udp_socket_int_sockopt( - socket, - SO_BUSY_POLL, - timeout_us as libc::c_int, - "SO_BUSY_POLL", - ); + if let Some(timeout_us) = udp_sockopt_int_value(timeout_us) { + set_udp_socket_int_sockopt(socket, SO_BUSY_POLL, timeout_us, "SO_BUSY_POLL"); + } else { + tracing::warn!( + option = "SO_BUSY_POLL", + value = timeout_us, + local_addr = ?socket.local_addr().ok(), + "skipping UDP socket option outside libc::c_int range" + ); + } } if prefer_busy_poll { set_udp_socket_int_sockopt(socket, SO_PREFER_BUSY_POLL, 1, "SO_PREFER_BUSY_POLL"); } if let Some(packet_budget) = busy_poll_budget { - set_udp_socket_int_sockopt( - socket, - SO_BUSY_POLL_BUDGET, - packet_budget as libc::c_int, - "SO_BUSY_POLL_BUDGET", - ); + if let Some(packet_budget) = udp_sockopt_int_value(packet_budget) { + set_udp_socket_int_sockopt( + socket, + SO_BUSY_POLL_BUDGET, + packet_budget, + "SO_BUSY_POLL_BUDGET", + ); + } else { + tracing::warn!( + option = "SO_BUSY_POLL_BUDGET", + value = packet_budget, + local_addr = ?socket.local_addr().ok(), + "skipping UDP socket option outside libc::c_int range" + ); + } } tracing::info!( @@ -401,11 +583,16 @@ fn tune_udp_busy_poll(socket: &std::net::UdpSocket) { } #[cfg(not(target_os = "linux"))] -fn tune_udp_busy_poll(_socket: &std::net::UdpSocket) {} +fn tune_udp_busy_poll(_socket: &UdpSocket) {} + +#[cfg(target_os = "linux")] +fn udp_sockopt_int_value(raw_value: u32) -> Option { + libc::c_int::try_from(raw_value).ok() +} #[cfg(target_os = "linux")] fn set_udp_socket_int_sockopt( - socket: &std::net::UdpSocket, + socket: &UdpSocket, option_name: libc::c_int, option_value: libc::c_int, option_label: &str, @@ -418,13 +605,13 @@ fn set_udp_socket_int_sockopt( libc::SOL_SOCKET, option_name, &option_value as *const libc::c_int as *const libc::c_void, - std::mem::size_of::() as libc::socklen_t, + size_of::() as libc::socklen_t, ) }; if result == 0 { return; } - let error = std::io::Error::last_os_error(); + let error = io::Error::last_os_error(); tracing::warn!( option = option_label, value = option_value, @@ -434,7 +621,7 @@ fn set_udp_socket_int_sockopt( ); } -pub(super) fn maybe_pin_receiver_thread(socket: &std::net::UdpSocket) { +pub(super) fn maybe_pin_receiver_thread(socket: &UdpSocket) { let local_port = socket .local_addr() .map(|address| address.port()) @@ -483,8 +670,13 @@ pub(super) fn maybe_pin_receiver_thread(socket: &std::net::UdpSocket) { #[cfg(all(test, target_os = "linux"))] mod tests { use super::*; + use std::os::fd::AsFd; use std::thread; + use crate::ingest::create_raw_packet_batch_queue; + use crate::runtime_env::with_runtime_env_overrides_for_test; + use sof_support::{bench::avg_ns_per_iteration, env_support::read_positive_usize}; + #[derive(Debug)] struct LegacyRawPacket { _source: SocketAddr, @@ -492,10 +684,10 @@ mod tests { } fn send_burst( - sender: &std::net::UdpSocket, + sender: &UdpSocket, destination: SocketAddr, packet_count: usize, - ) -> std::io::Result<()> { + ) -> io::Result<()> { let payload = [7_u8; 256]; for _ in 0..packet_count { sender.send_to(&payload, destination)?; @@ -504,12 +696,12 @@ mod tests { } fn send_staggered_burst( - sender: std::net::UdpSocket, + sender: UdpSocket, destination: SocketAddr, packet_count: usize, packets_per_chunk: usize, gap: Duration, - ) -> std::thread::JoinHandle> { + ) -> thread::JoinHandle> { thread::spawn(move || { let payload = [9_u8; 256]; let mut sent = 0_usize; @@ -527,10 +719,7 @@ mod tests { }) } - fn receive_legacy_burst( - receiver: &std::net::UdpSocket, - packet_count: usize, - ) -> std::io::Result { + fn receive_legacy_burst(receiver: &UdpSocket, packet_count: usize) -> io::Result { let mut buffer = vec![0_u8; 2048]; let mut received = 0_usize; while received < packet_count { @@ -542,8 +731,8 @@ mod tests { #[test] fn recvmmsg_batch_matches_legacy_receive_count() { - let receiver = std::net::UdpSocket::bind("127.0.0.1:0").expect("bind receiver"); - let sender = std::net::UdpSocket::bind("127.0.0.1:0").expect("bind sender"); + let receiver = UdpSocket::bind("127.0.0.1:0").expect("bind receiver"); + let sender = UdpSocket::bind("127.0.0.1:0").expect("bind sender"); receiver .set_read_timeout(Some(Duration::from_millis(200))) .expect("set read timeout"); @@ -563,22 +752,52 @@ mod tests { assert_eq!(legacy_received, packet_count); } + #[test] + fn recvmmsg_source_address_rejects_truncated_name() { + // SAFETY: The test initializes the family tag before conversion. + let mut storage: libc::sockaddr_storage = unsafe { zeroed() }; + storage.ss_family = libc::AF_INET as libc::sa_family_t; + + let truncated = size_of::().saturating_sub(1); + let namelen = libc::socklen_t::try_from(truncated).unwrap_or(0); + assert!(sockaddr_storage_to_socket_addr_libc(&storage, namelen).is_none()); + } + + #[test] + fn udp_sockopt_int_value_rejects_out_of_range_values() { + assert_eq!(udp_sockopt_int_value(1), Some(1)); + assert_eq!(udp_sockopt_int_value(i32::MAX as u32), Some(i32::MAX)); + assert_eq!( + udp_sockopt_int_value((i32::MAX as u32).saturating_add(1)), + None + ); + assert_eq!(udp_sockopt_int_value(u32::MAX), None); + } + + #[test] + fn retry_on_interrupted_retries_until_success() { + let mut calls = 0_u8; + let result = retry_on_interrupted(|| { + calls = calls.saturating_add(1); + if calls < 3 { + return Err(io::Error::from(ErrorKind::Interrupted)); + } + Ok(calls) + }) + .expect("interrupted helper should retry"); + + assert_eq!(result, 3); + assert_eq!(calls, 3); + } + #[test] #[ignore = "profiling fixture for UDP receiver ingress"] fn udp_receiver_recvmmsg_profile_fixture() { - let iterations = std::env::var("SOF_UDP_RECEIVER_PROFILE_ITERS") - .ok() - .and_then(|raw| raw.parse::().ok()) - .filter(|value| *value > 0) - .unwrap_or(1_000); - let packet_count = std::env::var("SOF_UDP_RECEIVER_PROFILE_BURST") - .ok() - .and_then(|raw| raw.parse::().ok()) - .filter(|value| *value > 0) - .unwrap_or(64); - - let receiver = std::net::UdpSocket::bind("127.0.0.1:0").expect("bind receiver"); - let sender = std::net::UdpSocket::bind("127.0.0.1:0").expect("bind sender"); + let iterations = read_positive_usize("SOF_UDP_RECEIVER_PROFILE_ITERS", 1_000); + let packet_count = read_positive_usize("SOF_UDP_RECEIVER_PROFILE_BURST", 64); + + let receiver = UdpSocket::bind("127.0.0.1:0").expect("bind receiver"); + let sender = UdpSocket::bind("127.0.0.1:0").expect("bind sender"); receiver .set_read_timeout(Some(Duration::from_millis(200))) .expect("set read timeout"); @@ -603,47 +822,154 @@ mod tests { assert_eq!(batch.len(), packet_count); } let batch_elapsed = batch_started_at.elapsed(); + let legacy_avg_ns = avg_ns_per_iteration(legacy_elapsed, iterations); + let batch_avg_ns = avg_ns_per_iteration(batch_elapsed, iterations); println!( - "udp_receiver_recvmmsg_profile_fixture iterations={} burst={} legacy_us={} recvmmsg_us={}", + "udp_receiver_recvmmsg_profile_fixture iterations={} burst={} legacy_us={} legacy_avg_ns_per_iteration={} legacy_avg_us_per_iteration={:.3} recvmmsg_us={} recvmmsg_avg_ns_per_iteration={} recvmmsg_avg_us_per_iteration={:.3}", iterations, packet_count, legacy_elapsed.as_micros(), - batch_elapsed.as_micros() + legacy_avg_ns, + legacy_avg_ns as f64 / 1_000.0, + batch_elapsed.as_micros(), + batch_avg_ns, + batch_avg_ns as f64 / 1_000.0 + ); + } + + #[test] + #[ignore = "profiling fixture for UDP receiver recvmmsg setup A/B"] + fn udp_receiver_recvmmsg_setup_profile_fixture() { + let iterations = read_positive_usize("SOF_UDP_RECEIVER_PROFILE_ITERS", 1_000); + let packet_count = read_positive_usize("SOF_UDP_RECEIVER_PROFILE_BURST", 64); + + let baseline_receiver = UdpSocket::bind("127.0.0.1:0").expect("bind receiver"); + let sender = UdpSocket::bind("127.0.0.1:0").expect("bind sender"); + baseline_receiver + .set_read_timeout(Some(Duration::from_millis(200))) + .expect("set read timeout"); + let baseline_destination = baseline_receiver.local_addr().expect("receiver addr"); + let mut baseline_scratch = UdpBatchScratch::new(packet_count); + let mut baseline_batch = RawPacketBatch::with_capacity(packet_count); + + let baseline_started_at = Instant::now(); + for _ in 0..iterations { + send_burst(&sender, baseline_destination, packet_count).expect("send baseline burst"); + let received = recv_udp_batch_baseline( + &baseline_receiver, + &mut baseline_scratch, + &mut baseline_batch, + ) + .expect("receive baseline batch"); + assert_eq!(received, packet_count); + assert_eq!(baseline_batch.len(), packet_count); + } + let baseline_elapsed = baseline_started_at.elapsed(); + + let optimized_receiver = UdpSocket::bind("127.0.0.1:0").expect("bind receiver"); + optimized_receiver + .set_read_timeout(Some(Duration::from_millis(200))) + .expect("set read timeout"); + let optimized_destination = optimized_receiver.local_addr().expect("receiver addr"); + let mut optimized_scratch = UdpBatchScratch::new(packet_count); + let mut optimized_batch = RawPacketBatch::with_capacity(packet_count); + + let optimized_started_at = Instant::now(); + for _ in 0..iterations { + send_burst(&sender, optimized_destination, packet_count).expect("send optimized burst"); + let received = recv_udp_batch( + &optimized_receiver, + &mut optimized_scratch, + &mut optimized_batch, + ) + .expect("receive optimized batch"); + assert_eq!(received, packet_count); + assert_eq!(optimized_batch.len(), packet_count); + } + let optimized_elapsed = optimized_started_at.elapsed(); + let baseline_avg_ns = avg_ns_per_iteration(baseline_elapsed, iterations); + let optimized_avg_ns = avg_ns_per_iteration(optimized_elapsed, iterations); + + println!( + "udp_receiver_recvmmsg_setup_profile_fixture iterations={} burst={} baseline_us={} optimized_us={} baseline_avg_ns_per_iteration={} optimized_avg_ns_per_iteration={} baseline_avg_us_per_iteration={:.3} optimized_avg_us_per_iteration={:.3}", + iterations, + packet_count, + baseline_elapsed.as_micros(), + optimized_elapsed.as_micros(), + baseline_avg_ns, + optimized_avg_ns, + baseline_avg_ns as f64 / 1_000.0, + optimized_avg_ns as f64 / 1_000.0 + ); + } + + #[test] + #[ignore = "profiling fixture for UDP receiver flush-path config lookup"] + fn udp_receiver_flush_batch_profile_fixture() { + let iterations = read_positive_usize("SOF_UDP_RECEIVER_PROFILE_ITERS", 10_000); + let capacity = iterations.saturating_mul(2).max(1).to_string(); + + with_runtime_env_overrides_for_test( + [("SOF_INGEST_QUEUE_CAPACITY".to_owned(), capacity)], + || { + let source: SocketAddr = "127.0.0.1:8899".parse().expect("source addr"); + let payload = [11_u8; 256]; + let recycler = RawPacketBatch::recycler_for_tests(1); + + let (baseline_tx, _baseline_rx) = create_raw_packet_batch_queue(); + let baseline_started_at = Instant::now(); + for _ in 0..iterations { + let mut batch = RawPacketBatch::from_recycler_for_tests(&recycler); + batch + .push_packet(source, RawPacketIngress::Udp, &payload) + .expect("push packet"); + flush_batch_baseline(&baseline_tx, &mut batch, None); + } + let baseline_elapsed = baseline_started_at.elapsed(); + + let (optimized_tx, _optimized_rx) = create_raw_packet_batch_queue(); + let drop_on_full = read_udp_drop_on_channel_full(); + let optimized_started_at = Instant::now(); + for _ in 0..iterations { + let mut batch = RawPacketBatch::from_recycler_for_tests(&recycler); + batch + .push_packet(source, RawPacketIngress::Udp, &payload) + .expect("push packet"); + flush_batch(&optimized_tx, &mut batch, drop_on_full, None); + } + let optimized_elapsed = optimized_started_at.elapsed(); + let baseline_avg_ns = avg_ns_per_iteration(baseline_elapsed, iterations); + let optimized_avg_ns = avg_ns_per_iteration(optimized_elapsed, iterations); + + println!( + "udp_receiver_flush_batch_profile_fixture iterations={} baseline_us={} optimized_us={} baseline_avg_ns_per_iteration={} optimized_avg_ns_per_iteration={} baseline_avg_us_per_iteration={:.3} optimized_avg_us_per_iteration={:.3}", + iterations, + baseline_elapsed.as_micros(), + optimized_elapsed.as_micros(), + baseline_avg_ns, + optimized_avg_ns, + baseline_avg_ns as f64 / 1_000.0, + optimized_avg_ns as f64 / 1_000.0 + ); + }, ); } #[test] #[ignore = "profiling fixture for UDP receiver coalesced ingress"] fn udp_receiver_recvmmsg_coalesced_profile_fixture() { - use std::os::fd::AsFd; - - let iterations = std::env::var("SOF_UDP_RECEIVER_PROFILE_ITERS") - .ok() - .and_then(|raw| raw.parse::().ok()) - .filter(|value| *value > 0) - .unwrap_or(1_000); - let packet_count = std::env::var("SOF_UDP_RECEIVER_PROFILE_BURST") - .ok() - .and_then(|raw| raw.parse::().ok()) - .filter(|value| *value > 0) - .unwrap_or(64); - let chunk_size = std::env::var("SOF_UDP_RECEIVER_PROFILE_CHUNK") - .ok() - .and_then(|raw| raw.parse::().ok()) - .filter(|value| *value > 0) - .unwrap_or(8); - let gap_us = std::env::var("SOF_UDP_RECEIVER_PROFILE_GAP_US") - .ok() - .and_then(|raw| raw.parse::().ok()) - .filter(|value| *value > 0) + let iterations = read_positive_usize("SOF_UDP_RECEIVER_PROFILE_ITERS", 1_000); + let packet_count = read_positive_usize("SOF_UDP_RECEIVER_PROFILE_BURST", 64); + let chunk_size = read_positive_usize("SOF_UDP_RECEIVER_PROFILE_CHUNK", 8); + let gap_us = u64::try_from(read_positive_usize("SOF_UDP_RECEIVER_PROFILE_GAP_US", 100)) .unwrap_or(100); let gap = Duration::from_micros(gap_us); let idle_wait = Duration::from_millis(200); let batch_max_wait = Duration::from_millis(2); - let blocking_receiver = std::net::UdpSocket::bind("127.0.0.1:0").expect("bind receiver"); - let sender = std::net::UdpSocket::bind("127.0.0.1:0").expect("bind sender"); + let blocking_receiver = UdpSocket::bind("127.0.0.1:0").expect("bind receiver"); + let sender = UdpSocket::bind("127.0.0.1:0").expect("bind sender"); blocking_receiver .set_read_timeout(Some(idle_wait)) .expect("set read timeout"); @@ -674,8 +1000,7 @@ mod tests { } let blocking_elapsed = blocking_started_at.elapsed(); - let coalesced_receiver = - std::net::UdpSocket::bind("127.0.0.1:0").expect("bind coalesced receiver"); + let coalesced_receiver = UdpSocket::bind("127.0.0.1:0").expect("bind coalesced receiver"); let destination = coalesced_receiver .local_addr() .expect("coalesced receiver addr"); @@ -711,31 +1036,29 @@ mod tests { .expect("send staggered burst"); } let coalesced_elapsed = coalesced_started_at.elapsed(); + let blocking_avg_ns = avg_ns_per_iteration(blocking_elapsed, iterations); + let coalesced_avg_ns = avg_ns_per_iteration(coalesced_elapsed, iterations); println!( - "udp_receiver_recvmmsg_coalesced_profile_fixture iterations={} burst={} chunk={} gap_us={} immediate_us={} coalesced_us={}", + "udp_receiver_recvmmsg_coalesced_profile_fixture iterations={} burst={} chunk={} gap_us={} immediate_us={} immediate_avg_ns_per_iteration={} immediate_avg_us_per_iteration={:.3} coalesced_us={} coalesced_avg_ns_per_iteration={} coalesced_avg_us_per_iteration={:.3}", iterations, packet_count, chunk_size, gap_us, blocking_elapsed.as_micros(), - coalesced_elapsed.as_micros() + blocking_avg_ns, + blocking_avg_ns as f64 / 1_000.0, + coalesced_elapsed.as_micros(), + coalesced_avg_ns, + coalesced_avg_ns as f64 / 1_000.0 ); } #[test] #[ignore = "profiling fixture for contiguous raw packet batch materialization"] fn udp_receiver_batch_materialization_profile_fixture() { - let iterations = std::env::var("SOF_UDP_RECEIVER_PROFILE_ITERS") - .ok() - .and_then(|raw| raw.parse::().ok()) - .filter(|value| *value > 0) - .unwrap_or(20_000); - let packet_count = std::env::var("SOF_UDP_RECEIVER_PROFILE_BURST") - .ok() - .and_then(|raw| raw.parse::().ok()) - .filter(|value| *value > 0) - .unwrap_or(64); + let iterations = read_positive_usize("SOF_UDP_RECEIVER_PROFILE_ITERS", 20_000); + let packet_count = read_positive_usize("SOF_UDP_RECEIVER_PROFILE_BURST", 64); let source: SocketAddr = "127.0.0.1:8899".parse().expect("source addr"); let payloads: Vec> = (0..packet_count) .map(|index| vec![u8::try_from(index % 251).unwrap_or(0); 256]) @@ -778,30 +1101,31 @@ mod tests { assert_eq!(batch.len(), packet_count); } let recycled_elapsed = recycled_started_at.elapsed(); + let legacy_avg_ns = avg_ns_per_iteration(legacy_elapsed, iterations); + let contiguous_avg_ns = avg_ns_per_iteration(contiguous_elapsed, iterations); + let recycled_avg_ns = avg_ns_per_iteration(recycled_elapsed, iterations); println!( - "udp_receiver_batch_materialization_profile_fixture iterations={} burst={} legacy_arc_us={} contiguous_us={} recycled_us={}", + "udp_receiver_batch_materialization_profile_fixture iterations={} burst={} legacy_arc_us={} legacy_arc_avg_ns_per_iteration={} legacy_arc_avg_us_per_iteration={:.3} contiguous_us={} contiguous_avg_ns_per_iteration={} contiguous_avg_us_per_iteration={:.3} recycled_us={} recycled_avg_ns_per_iteration={} recycled_avg_us_per_iteration={:.3}", iterations, packet_count, legacy_elapsed.as_micros(), + legacy_avg_ns, + legacy_avg_ns as f64 / 1_000.0, contiguous_elapsed.as_micros(), - recycled_elapsed.as_micros() + contiguous_avg_ns, + contiguous_avg_ns as f64 / 1_000.0, + recycled_elapsed.as_micros(), + recycled_avg_ns, + recycled_avg_ns as f64 / 1_000.0 ); } #[test] #[ignore = "profiling fixture for recycled raw packet batch materialization"] fn udp_receiver_batch_recycler_profile_fixture() { - let iterations = std::env::var("SOF_UDP_RECEIVER_PROFILE_ITERS") - .ok() - .and_then(|raw| raw.parse::().ok()) - .filter(|value| *value > 0) - .unwrap_or(50_000); - let packet_count = std::env::var("SOF_UDP_RECEIVER_PROFILE_BURST") - .ok() - .and_then(|raw| raw.parse::().ok()) - .filter(|value| *value > 0) - .unwrap_or(64); + let iterations = read_positive_usize("SOF_UDP_RECEIVER_PROFILE_ITERS", 50_000); + let packet_count = read_positive_usize("SOF_UDP_RECEIVER_PROFILE_BURST", 64); let source: SocketAddr = "127.0.0.1:8899".parse().expect("source addr"); let payloads: Vec> = (0..packet_count) .map(|index| vec![u8::try_from(index % 251).unwrap_or(0); 256]) @@ -819,12 +1143,15 @@ mod tests { assert_eq!(batch.len(), packet_count); } let elapsed = started_at.elapsed(); + let avg_ns = avg_ns_per_iteration(elapsed, iterations); println!( - "udp_receiver_batch_recycler_profile_fixture iterations={} burst={} recycled_us={}", + "udp_receiver_batch_recycler_profile_fixture iterations={} burst={} recycled_us={} recycled_avg_ns_per_iteration={} recycled_avg_us_per_iteration={:.3}", iterations, packet_count, - elapsed.as_micros() + elapsed.as_micros(), + avg_ns, + avg_ns as f64 / 1_000.0 ); } } diff --git a/crates/sof-observer/src/provider_stream.rs b/crates/sof-observer/src/provider_stream.rs index 65b60372..b5d116d6 100644 --- a/crates/sof-observer/src/provider_stream.rs +++ b/crates/sof-observer/src/provider_stream.rs @@ -188,6 +188,8 @@ use std::time::Duration; use thiserror::Error; use tokio::sync::mpsc; use tokio::sync::mpsc::error::SendError; +#[cfg(any(feature = "provider-grpc", feature = "provider-websocket"))] +use tokio::time::{Instant as TokioInstant, Interval, interval_at}; #[cfg(any(feature = "provider-grpc", feature = "provider-websocket"))] use std::sync::atomic::{AtomicU64, Ordering}; @@ -208,6 +210,26 @@ use solana_transaction::versioned::VersionedTransaction; /// Default queue capacity for processed provider-stream ingress. pub const DEFAULT_PROVIDER_STREAM_QUEUE_CAPACITY: usize = 8_192; +/// Smallest keepalive period accepted by provider intervals to avoid Tokio zero-period panics. +#[cfg(any(feature = "provider-grpc", feature = "provider-websocket"))] +const MIN_KEEPALIVE_INTERVAL: Duration = Duration::from_millis(1); +/// Stable compute-budget program id reused in provider transaction classifiers. +const COMPUTE_BUDGET_PROGRAM_ID: solana_pubkey::Pubkey = compute_budget::ID; +/// Stable vote program id reused in provider transaction classifiers. +const VOTE_PROGRAM_ID: solana_pubkey::Pubkey = vote::ID; + +/// Creates one keepalive interval that waits one full period before the first tick. +#[cfg(any(feature = "provider-grpc", feature = "provider-websocket"))] +pub(crate) fn keepalive_interval(period: Duration) -> Interval { + let period = if period.is_zero() { + MIN_KEEPALIVE_INTERVAL + } else { + period + }; + let start = TokioInstant::now(); + let first_tick = start.checked_add(period).unwrap_or(start); + interval_at(first_tick, period) +} /// Identifies the processed provider family driving SOF's direct plugin ingress. #[derive(Clone, Copy, Debug, Eq, PartialEq)] @@ -382,6 +404,28 @@ impl ProviderStreamUpdate { } self } + + /// Tags one provider-origin update with one shared source reference. + #[must_use] + #[cfg_attr(not(feature = "provider-grpc"), allow(dead_code))] + pub(crate) fn with_provider_source_ref(mut self, source: &Arc) -> Self { + match &mut self { + Self::Transaction(event) => event.provider_source = Some(Arc::clone(source)), + Self::SerializedTransaction(event) => event.provider_source = Some(Arc::clone(source)), + Self::TransactionLog(event) => event.provider_source = Some(Arc::clone(source)), + Self::TransactionStatus(event) => event.provider_source = Some(Arc::clone(source)), + Self::TransactionViewBatch(event) => event.provider_source = Some(Arc::clone(source)), + Self::AccountUpdate(event) => event.provider_source = Some(Arc::clone(source)), + Self::BlockMeta(event) => event.provider_source = Some(Arc::clone(source)), + Self::RecentBlockhash(event) => event.provider_source = Some(Arc::clone(source)), + Self::SlotStatus(event) => event.provider_source = Some(Arc::clone(source)), + Self::ClusterTopology(event) => event.provider_source = Some(Arc::clone(source)), + Self::LeaderSchedule(event) => event.provider_source = Some(Arc::clone(source)), + Self::Reorg(event) => event.provider_source = Some(Arc::clone(source)), + Self::Health(event) => event.source = (**source).clone(), + } + self + } } impl From for ProviderStreamUpdate { @@ -1165,14 +1209,14 @@ pub(crate) fn classify_provider_transaction_kind(tx: &VersionedTransaction) -> T let keys = tx.message.static_account_keys(); for instruction in tx.message.instructions() { if let Some(program_id) = keys.get(usize::from(instruction.program_id_index)) { - if *program_id == vote::id() { + if *program_id == VOTE_PROGRAM_ID { has_vote = true; if has_non_vote_non_budget { return TxKind::Mixed; } continue; } - if *program_id != compute_budget::id() { + if *program_id != COMPUTE_BUDGET_PROGRAM_ID { has_non_vote_non_budget = true; if has_vote { return TxKind::Mixed; @@ -1196,14 +1240,14 @@ pub(crate) fn classify_provider_transaction_kind_view( let mut has_vote = false; let mut has_non_vote_non_budget = false; for (program_id, _) in view.program_instructions_iter() { - if *program_id == vote::id() { + if *program_id == VOTE_PROGRAM_ID { has_vote = true; if has_non_vote_non_budget { return TxKind::Mixed; } continue; } - if *program_id != compute_budget::id() { + if *program_id != COMPUTE_BUDGET_PROGRAM_ID { has_non_vote_non_budget = true; if has_vote { return TxKind::Mixed; @@ -1300,6 +1344,7 @@ pub mod websocket; #[cfg(all(test, any(feature = "provider-grpc", feature = "provider-websocket")))] mod tests { use super::*; + use sof_support::bench::{avg_ns_per_iteration, profile_iterations}; use solana_instruction::Instruction; use solana_keypair::Keypair; use solana_message::{Message, VersionedMessage}; @@ -1309,14 +1354,6 @@ mod tests { use tokio::runtime::Runtime; use tokio::time::{Duration, sleep, timeout}; - fn profile_iterations(default: usize) -> usize { - std::env::var("SOF_PROFILE_ITERATIONS") - .ok() - .and_then(|value| value.parse::().ok()) - .filter(|value| *value > 0) - .unwrap_or(default) - } - fn sample_mixed_transaction() -> VersionedTransaction { let signer = Keypair::new(); let mut instructions = Vec::with_capacity(34); @@ -1337,6 +1374,37 @@ mod tests { VersionedTransaction::try_new(VersionedMessage::Legacy(message), &[&signer]).expect("tx") } + fn sample_recent_blockhash_update() -> ProviderStreamUpdate { + ProviderStreamUpdate::RecentBlockhash(ObservedRecentBlockhashEvent { + slot: 7, + recent_blockhash: solana_hash::Hash::new_unique().to_bytes(), + dataset_tx_count: 0, + provider_source: None, + }) + } + + #[tokio::test] + async fn keepalive_interval_waits_one_full_period_before_first_tick() { + let mut interval = keepalive_interval(Duration::from_millis(25)); + assert!( + timeout(Duration::from_millis(5), interval.tick()) + .await + .is_err(), + "keepalive interval should not tick immediately" + ); + timeout(Duration::from_millis(50), interval.tick()) + .await + .expect("keepalive interval should tick after one full period"); + } + + #[tokio::test] + async fn keepalive_interval_zero_period_is_clamped() { + let mut interval = keepalive_interval(Duration::ZERO); + timeout(Duration::from_millis(50), interval.tick()) + .await + .expect("zero keepalive period should not panic or stall"); + } + fn classify_provider_transaction_kind_baseline(tx: &VersionedTransaction) -> TxKind { let mut has_vote = false; let mut has_non_vote_non_budget = false; @@ -1361,6 +1429,36 @@ mod tests { } } + fn classify_provider_transaction_kind_pre_hoist(tx: &VersionedTransaction) -> TxKind { + let mut has_vote = false; + let mut has_non_vote_non_budget = false; + let keys = tx.message.static_account_keys(); + for instruction in tx.message.instructions() { + if let Some(program_id) = keys.get(usize::from(instruction.program_id_index)) { + if *program_id == vote::id() { + has_vote = true; + if has_non_vote_non_budget { + return TxKind::Mixed; + } + continue; + } + if *program_id != compute_budget::id() { + has_non_vote_non_budget = true; + if has_vote { + return TxKind::Mixed; + } + } + } + } + if has_vote && !has_non_vote_non_budget { + TxKind::VoteOnly + } else if has_vote { + TxKind::Mixed + } else { + TxKind::NonVote + } + } + #[test] fn classify_provider_transaction_kind_detects_mixed() { let tx = sample_mixed_transaction(); @@ -1696,6 +1794,37 @@ mod tests { ); } + #[test] + #[ignore = "profiling fixture for provider tx kind hoisted id comparison"] + fn provider_transaction_kind_hoist_ids_profile_fixture() { + let iterations = profile_iterations(1_000_000); + + let tx = sample_mixed_transaction(); + + let baseline_started = Instant::now(); + for _ in 0..iterations { + std::hint::black_box(classify_provider_transaction_kind_pre_hoist(&tx)); + } + let baseline_elapsed = baseline_started.elapsed(); + + let optimized_started = Instant::now(); + for _ in 0..iterations { + std::hint::black_box(classify_provider_transaction_kind(&tx)); + } + let optimized_elapsed = optimized_started.elapsed(); + + eprintln!( + "provider_transaction_kind_hoist_ids_profile_fixture iterations={} baseline_us={} optimized_us={} baseline_avg_ns_per_iteration={} optimized_avg_ns_per_iteration={} baseline_avg_us_per_iteration={:.3} optimized_avg_us_per_iteration={:.3}", + iterations, + baseline_elapsed.as_micros(), + optimized_elapsed.as_micros(), + avg_ns_per_iteration(baseline_elapsed, iterations), + avg_ns_per_iteration(optimized_elapsed, iterations), + avg_ns_per_iteration(baseline_elapsed, iterations) as f64 / 1_000.0, + avg_ns_per_iteration(optimized_elapsed, iterations) as f64 / 1_000.0, + ); + } + #[test] #[ignore = "profiling fixture for baseline provider tx kind classification"] fn provider_transaction_kind_baseline_profile_fixture() { @@ -1717,4 +1846,41 @@ mod tests { std::hint::black_box(classify_provider_transaction_kind(&tx)); } } + + #[test] + #[ignore = "profiling fixture for provider source attachment path"] + fn provider_update_source_attachment_profile_fixture() { + let iterations = profile_iterations(1_000_000); + let source = ProviderSourceIdentity::new( + ProviderSourceId::Generic(Arc::::from("custom")), + "source-a", + ); + let source_ref = Arc::new(source.clone()); + let update = sample_recent_blockhash_update(); + + let baseline_started = Instant::now(); + for _ in 0..iterations { + std::hint::black_box(update.clone().with_provider_source(source.clone())); + } + let baseline_elapsed = baseline_started.elapsed(); + + let optimized_started = Instant::now(); + for _ in 0..iterations { + std::hint::black_box(update.clone().with_provider_source_ref(&source_ref)); + } + let optimized_elapsed = optimized_started.elapsed(); + let baseline_avg_ns = avg_ns_per_iteration(baseline_elapsed, iterations); + let optimized_avg_ns = avg_ns_per_iteration(optimized_elapsed, iterations); + + eprintln!( + "provider_update_source_attachment_profile_fixture iterations={} baseline_us={} optimized_us={} baseline_avg_ns_per_iteration={} optimized_avg_ns_per_iteration={} baseline_avg_us_per_iteration={:.3} optimized_avg_us_per_iteration={:.3}", + iterations, + baseline_elapsed.as_micros(), + optimized_elapsed.as_micros(), + baseline_avg_ns, + optimized_avg_ns, + baseline_avg_ns as f64 / 1_000.0, + optimized_avg_ns as f64 / 1_000.0, + ); + } } diff --git a/crates/sof-observer/src/provider_stream/laserstream.rs b/crates/sof-observer/src/provider_stream/laserstream.rs index 6a1fcb07..58526f0a 100644 --- a/crates/sof-observer/src/provider_stream/laserstream.rs +++ b/crates/sof-observer/src/provider_stream/laserstream.rs @@ -6,8 +6,12 @@ //! built-in adapter exposes transaction, transaction-status, account-update, //! block-meta, and slot feeds through the same typed provider-stream surface. +#[cfg(test)] +use std::hint::black_box; use std::{ collections::HashMap, + fmt, + pin::Pin, str::FromStr, sync::Arc, time::{Duration, Instant}, @@ -23,7 +27,10 @@ use laserstream_core_proto::prelude::Transaction as LaserStreamTransaction; use laserstream_core_proto::tonic::{ Status, codec::CompressionEncoding, metadata::MetadataValue, transport::Endpoint, }; -use sof_types::{PubkeyBytes, SignatureBytes}; +use sof_support::bytes::{pubkey_bytes_from_slice, signature_bytes_from_slice}; +use sof_support::collections_support::prune_recent_slots; +use sof_support::time_support::{duration_secs_ceil, nonzero_duration_or}; +use sof_types::SignatureBytes; use solana_hash::Hash; use solana_message::{ Message, MessageHeader, VersionedMessage, @@ -32,6 +39,7 @@ use solana_message::{ }; use solana_pubkey::Pubkey; use solana_signature::Signature; +use solana_system_interface::MAX_PERMITTED_DATA_LENGTH; use solana_transaction::versioned::VersionedTransaction; use thiserror::Error; use tokio::sync::mpsc; @@ -57,6 +65,20 @@ use crate::{ const INTERNAL_WATERMARK_SLOT_FILTER: &str = "__sof_watermark_slots"; const LASERSTREAM_SDK_NAME: &str = "sof"; const LASERSTREAM_SDK_VERSION: &str = env!("CARGO_PKG_VERSION"); +const MAX_ACCOUNT_DATA_LEN: usize = MAX_PERMITTED_DATA_LENGTH as usize; +const SLOT_STATUS_RETAINED_LAG: u64 = 4_096; +const SLOT_STATUS_PRUNE_THRESHOLD: usize = SLOT_STATUS_RETAINED_LAG as usize * 2; +const DEFAULT_CONNECT_TIMEOUT_SECS: u64 = 10; +const DEFAULT_TIMEOUT_SECS: u64 = 30; +const DEFAULT_HTTP2_KEEP_ALIVE_INTERVAL_SECS: u64 = 30; +const DEFAULT_KEEP_ALIVE_TIMEOUT_SECS: u64 = 5; +const DEFAULT_INITIAL_STREAM_WINDOW_SIZE: u32 = 1024 * 1024 * 4; +const DEFAULT_INITIAL_CONNECTION_WINDOW_SIZE: u32 = 1024 * 1024 * 8; +const DEFAULT_BUFFER_SIZE: usize = 1024 * 64; +const DEFAULT_MAX_DECODING_MESSAGE_SIZE: usize = 1_000_000_000; +const DEFAULT_MAX_ENCODING_MESSAGE_SIZE: usize = 32_000_000; +const MIN_PROVIDER_STALL_TIMEOUT: Duration = Duration::from_millis(1); +const MIN_RECONNECT_DELAY: Duration = Duration::from_millis(1); /// LaserStream subscription commitment used for provider-stream transaction updates. #[derive(Clone, Copy, Debug, Default, Eq, PartialEq)] @@ -143,8 +165,8 @@ pub enum LaserStreamConfigOption { RequireTransactionSignature, } -impl std::fmt::Display for LaserStreamConfigOption { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl fmt::Display for LaserStreamConfigOption { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::VoteFilter => f.write_str("vote filter"), Self::FailedFilter => f.write_str("failed filter"), @@ -471,6 +493,10 @@ impl LaserStreamConfig { self } + const fn reconnect_delay_effective(&self) -> Duration { + effective_reconnect_delay(self.reconnect_delay) + } + /// Sets provider replay behavior. #[must_use] pub const fn with_replay_mode(mut self, mode: ProviderReplayMode) -> Self { @@ -519,47 +545,20 @@ impl LaserStreamConfig { } fn client_config(&self) -> ClientConfig { - let mut options = ChannelOptions::default(); - if let Some(timeout) = self.connect_timeout { - options.connect_timeout_secs = Some(timeout.as_secs()); - } - if let Some(timeout) = self.timeout { - options.timeout_secs = Some(timeout.as_secs()); - } - if let Some(bytes) = self.max_decoding_message_size { - options.max_decoding_message_size = Some(bytes); - } - if let Some(bytes) = self.max_encoding_message_size { - options.max_encoding_message_size = Some(bytes); - } - - let mut config = ClientConfig::new(self.endpoint.clone(), self.api_key.clone()) - .with_channel_options(options) - .with_replay(!matches!(self.replay_mode, ProviderReplayMode::Live)); - if let Some(attempts) = self.max_reconnect_attempts { - config = config.with_max_reconnect_attempts(attempts); - } - config + laserstream_client_config(&LaserStreamClientConfigInputs { + endpoint: &self.endpoint, + api_key: &self.api_key, + connect_timeout: self.connect_timeout, + request_timeout: self.timeout, + max_decoding_message_size: self.max_decoding_message_size, + max_encoding_message_size: self.max_encoding_message_size, + replay_mode: self.replay_mode, + max_reconnect_attempts: self.max_reconnect_attempts, + }) } const fn replay_from_slot(&self, tracked_slot: u64) -> Option { - match self.replay_mode { - ProviderReplayMode::Live => None, - ProviderReplayMode::Resume => { - if tracked_slot == 0 { - None - } else { - Some(tracked_slot) - } - } - ProviderReplayMode::FromSlot(slot) => { - if tracked_slot == 0 { - Some(slot) - } else { - Some(tracked_slot) - } - } - } + laserstream_replay_from_slot(self.replay_mode, tracked_slot) } fn transaction_filter(&self) -> grpc::SubscribeRequestFilterTransactions { @@ -650,6 +649,70 @@ impl LaserStreamConfig { } } +struct LaserStreamClientConfigInputs<'config> { + endpoint: &'config str, + api_key: &'config str, + connect_timeout: Option, + request_timeout: Option, + max_decoding_message_size: Option, + max_encoding_message_size: Option, + replay_mode: ProviderReplayMode, + max_reconnect_attempts: Option, +} + +fn laserstream_client_config(inputs: &LaserStreamClientConfigInputs<'_>) -> ClientConfig { + let mut options = ChannelOptions::default(); + if let Some(connect_timeout) = inputs.connect_timeout { + options.connect_timeout_secs = Some(duration_secs_ceil(nonzero_duration_or( + connect_timeout, + Duration::from_millis(1), + ))); + } + if let Some(request_timeout) = inputs.request_timeout { + options.timeout_secs = Some(duration_secs_ceil(nonzero_duration_or( + request_timeout, + Duration::from_millis(1), + ))); + } + if let Some(bytes) = inputs.max_decoding_message_size { + options.max_decoding_message_size = Some(bytes); + } + if let Some(bytes) = inputs.max_encoding_message_size { + options.max_encoding_message_size = Some(bytes); + } + + let mut config = ClientConfig::new(inputs.endpoint.to_owned(), inputs.api_key.to_owned()) + .with_channel_options(options) + .with_replay(!matches!(inputs.replay_mode, ProviderReplayMode::Live)); + if let Some(attempts) = inputs.max_reconnect_attempts { + config = config.with_max_reconnect_attempts(attempts); + } + config +} + +const fn laserstream_replay_from_slot( + replay_mode: ProviderReplayMode, + tracked_slot: u64, +) -> Option { + match replay_mode { + ProviderReplayMode::Live => None, + ProviderReplayMode::Resume => { + if tracked_slot == 0 { + None + } else { + Some(tracked_slot) + } + } + ProviderReplayMode::FromSlot(slot) => { + if tracked_slot == 0 { + Some(slot) + } else { + Some(tracked_slot) + } + } + } +} + /// Primary LaserStream stream families supported by the built-in adapter. #[derive(Clone, Copy, Debug, Default, Eq, PartialEq)] pub enum LaserStreamStream { @@ -843,6 +906,10 @@ impl LaserStreamSlotsConfig { self } + const fn reconnect_delay_effective(&self) -> Duration { + effective_reconnect_delay(self.reconnect_delay) + } + /// Sets provider replay behavior. #[must_use] pub const fn with_replay_mode(mut self, mode: ProviderReplayMode) -> Self { @@ -866,47 +933,28 @@ impl LaserStreamSlotsConfig { } fn client_config(&self) -> ClientConfig { - let mut options = ChannelOptions::default(); - if let Some(timeout) = self.connect_timeout { - options.connect_timeout_secs = Some(timeout.as_secs()); - } - if let Some(timeout) = self.timeout { - options.timeout_secs = Some(timeout.as_secs()); - } - if let Some(bytes) = self.max_decoding_message_size { - options.max_decoding_message_size = Some(bytes); - } - if let Some(bytes) = self.max_encoding_message_size { - options.max_encoding_message_size = Some(bytes); - } - - let mut config = ClientConfig::new(self.endpoint.clone(), self.api_key.clone()) - .with_channel_options(options) - .with_replay(!matches!(self.replay_mode, ProviderReplayMode::Live)); - if let Some(attempts) = self.max_reconnect_attempts { - config = config.with_max_reconnect_attempts(attempts); - } - config + laserstream_client_config(&LaserStreamClientConfigInputs { + endpoint: &self.endpoint, + api_key: &self.api_key, + connect_timeout: self.connect_timeout, + request_timeout: self.timeout, + max_decoding_message_size: self.max_decoding_message_size, + max_encoding_message_size: self.max_encoding_message_size, + replay_mode: self.replay_mode, + max_reconnect_attempts: self.max_reconnect_attempts, + }) } const fn replay_from_slot(&self, tracked_slot: u64) -> Option { - match self.replay_mode { - ProviderReplayMode::Live => None, - ProviderReplayMode::Resume => { - if tracked_slot == 0 { - None - } else { - Some(tracked_slot) - } - } - ProviderReplayMode::FromSlot(slot) => { - if tracked_slot == 0 { - Some(slot) - } else { - Some(tracked_slot) - } - } - } + laserstream_replay_from_slot(self.replay_mode, tracked_slot) + } +} + +const fn effective_reconnect_delay(delay: Duration) -> Duration { + if delay.is_zero() { + MIN_RECONNECT_DELAY + } else { + delay } } @@ -991,8 +1039,8 @@ pub enum LaserStreamStreamKind { Slots, } -impl std::fmt::Display for LaserStreamStreamKind { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl fmt::Display for LaserStreamStreamKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::Transaction => f.write_str("transaction"), Self::TransactionStatus => f.write_str("transaction-status"), @@ -1003,10 +1051,10 @@ impl std::fmt::Display for LaserStreamStreamKind { } } -type LaserStreamSubscribeSink = std::pin::Pin< +type LaserStreamSubscribeSink = Pin< Box + Send>, >; -type LaserStreamUpdateStream = std::pin::Pin< +type LaserStreamUpdateStream = Pin< Box< dyn futures_util::Stream< Item = Result, @@ -1227,7 +1275,7 @@ async fn spawn_laserstream_source_inner( .await?; return Err(LaserStreamProtocolError::ReconnectBudgetExhausted { attempts }.into()); } - tokio::time::sleep(config.reconnect_delay).await; + tokio::time::sleep(config.reconnect_delay_effective()).await; } })) } @@ -1291,7 +1339,7 @@ async fn spawn_laserstream_slot_source_inner( let mut attempts = 0_u32; let mut tracked_slot = 0_u64; let mut watermarks = ProviderCommitmentWatermarks::default(); - let mut slot_states = HashMap::new(); + let mut slot_states = HashMap::with_capacity(SLOT_STATUS_PRUNE_THRESHOLD); let mut first_session = Some(first_session); loop { let mut session_established = false; @@ -1397,7 +1445,7 @@ async fn spawn_laserstream_slot_source_inner( .await?; return Err(LaserStreamProtocolError::ReconnectBudgetExhausted { attempts }.into()); } - tokio::time::sleep(config.reconnect_delay).await; + tokio::time::sleep(config.reconnect_delay_effective()).await; } })) } @@ -1412,6 +1460,7 @@ async fn run_laserstream_primary_connection( ) -> Result<(), LaserStreamError> { *state.session_established = false; let commitment = config.commitment.as_tx_commitment(); + let provider_source = Arc::new(source.clone()); *state.session_established = true; send_primary_provider_health( source, @@ -1426,7 +1475,10 @@ async fn run_laserstream_primary_connection( loop { tokio::select! { () = async { - if let Some(timeout) = config.stall_timeout { + if let Some(timeout) = config + .stall_timeout + .map(|timeout| nonzero_duration_or(timeout, MIN_PROVIDER_STALL_TIMEOUT)) + { let deadline = last_progress.checked_add(timeout).unwrap_or(last_progress); tokio::time::sleep_until(deadline.into()).await; } else { @@ -1480,7 +1532,7 @@ async fn run_laserstream_primary_connection( sender .send( ProviderStreamUpdate::Transaction(event) - .with_provider_source(source.clone()), + .with_provider_source_ref(&provider_source), ) .await .map_err(|_error| LaserStreamError::QueueClosed)?; @@ -1500,7 +1552,7 @@ async fn run_laserstream_primary_connection( sender .send( ProviderStreamUpdate::TransactionStatus(event) - .with_provider_source(source.clone()), + .with_provider_source_ref(&provider_source), ) .await .map_err(|_error| LaserStreamError::QueueClosed)?; @@ -1522,7 +1574,7 @@ async fn run_laserstream_primary_connection( sender .send( ProviderStreamUpdate::AccountUpdate(event) - .with_provider_source(source.clone()), + .with_provider_source_ref(&provider_source), ) .await .map_err(|_error| LaserStreamError::QueueClosed)?; @@ -1544,7 +1596,7 @@ async fn run_laserstream_primary_connection( sender .send( ProviderStreamUpdate::BlockMeta(event) - .with_provider_source(source.clone()), + .with_provider_source_ref(&provider_source), ) .await .map_err(|_error| LaserStreamError::QueueClosed)?; @@ -1565,6 +1617,7 @@ async fn run_laserstream_slot_connection( mut stream: LaserStreamUpdateStream, ) -> Result<(), LaserStreamError> { *state.session_established = false; + let provider_source = Arc::new(source.clone()); *state.session_established = true; send_provider_slot_health( source, @@ -1579,7 +1632,10 @@ async fn run_laserstream_slot_connection( loop { tokio::select! { () = async { - if let Some(timeout) = config.stall_timeout { + if let Some(timeout) = config + .stall_timeout + .map(|timeout| nonzero_duration_or(timeout, MIN_PROVIDER_STALL_TIMEOUT)) + { let deadline = last_progress.checked_add(timeout).unwrap_or(last_progress); tokio::time::sleep_until(deadline.into()).await; } else { @@ -1607,7 +1663,7 @@ async fn run_laserstream_slot_connection( sender .send( ProviderStreamUpdate::SlotStatus(event) - .with_provider_source(source.clone()), + .with_provider_source_ref(&provider_source), ) .await .map_err(|_error| LaserStreamError::QueueClosed)?; @@ -1772,54 +1828,79 @@ impl Interceptor for SofLaserStreamInterceptor { } } -async fn connect_and_subscribe_once( - config: &LaserStreamConfig, - request: grpc::SubscribeRequest, -) -> Result<(LaserStreamSubscribeSink, LaserStreamUpdateStream), LaserStreamError> { - let options = config.client_config().channel_options; - let interceptor = SofLaserStreamInterceptor::new(&config.api_key) - .map_err(|error| LaserStreamProtocolError::InvalidApiKey(error.to_string()))?; - - let mut endpoint = Endpoint::from_shared(config.endpoint.clone()) +fn laserstream_endpoint( + endpoint: &str, + options: &ChannelOptions, +) -> Result { + let mut transport = Endpoint::from_shared(endpoint.to_owned()) .map_err(|error| LaserStreamProtocolError::InvalidEndpoint(error.to_string()))? .connect_timeout(Duration::from_secs( - options.connect_timeout_secs.unwrap_or(10), + options + .connect_timeout_secs + .unwrap_or(DEFAULT_CONNECT_TIMEOUT_SECS), + )) + .timeout(Duration::from_secs( + options.timeout_secs.unwrap_or(DEFAULT_TIMEOUT_SECS), )) - .timeout(Duration::from_secs(options.timeout_secs.unwrap_or(30))) .http2_keep_alive_interval(Duration::from_secs( - options.http2_keep_alive_interval_secs.unwrap_or(30), + options + .http2_keep_alive_interval_secs + .unwrap_or(DEFAULT_HTTP2_KEEP_ALIVE_INTERVAL_SECS), )) .keep_alive_timeout(Duration::from_secs( - options.keep_alive_timeout_secs.unwrap_or(5), + options + .keep_alive_timeout_secs + .unwrap_or(DEFAULT_KEEP_ALIVE_TIMEOUT_SECS), )) .keep_alive_while_idle(options.keep_alive_while_idle.unwrap_or(true)) - .initial_stream_window_size(options.initial_stream_window_size.or(Some(1024 * 1024 * 4))) + .initial_stream_window_size( + options + .initial_stream_window_size + .or(Some(DEFAULT_INITIAL_STREAM_WINDOW_SIZE)), + ) .initial_connection_window_size( options .initial_connection_window_size - .or(Some(1024 * 1024 * 8)), + .or(Some(DEFAULT_INITIAL_CONNECTION_WINDOW_SIZE)), ) .http2_adaptive_window(options.http2_adaptive_window.unwrap_or(true)) .tcp_nodelay(options.tcp_nodelay.unwrap_or(true)) - .buffer_size(options.buffer_size.or(Some(1024 * 64))); + .buffer_size(options.buffer_size.or(Some(DEFAULT_BUFFER_SIZE))); if let Some(tcp_keepalive_secs) = options.tcp_keepalive_secs { - endpoint = endpoint.tcp_keepalive(Some(Duration::from_secs(tcp_keepalive_secs))); + transport = transport.tcp_keepalive(Some(Duration::from_secs(tcp_keepalive_secs))); } - if endpoint_uses_tls(&config.endpoint) { - endpoint = endpoint + if endpoint_uses_tls(endpoint) { + transport = transport .tls_config(ClientTlsConfig::new().with_enabled_roots()) .map_err(|error| LaserStreamProtocolError::TlsConfig(error.to_string()))?; } + Ok(transport) +} - let channel = endpoint +async fn connect_and_subscribe_once( + config: &LaserStreamConfig, + request: grpc::SubscribeRequest, +) -> Result<(LaserStreamSubscribeSink, LaserStreamUpdateStream), LaserStreamError> { + let options = config.client_config().channel_options; + let interceptor = SofLaserStreamInterceptor::new(&config.api_key) + .map_err(|error| LaserStreamProtocolError::InvalidApiKey(error.to_string()))?; + let channel = laserstream_endpoint(&config.endpoint, &options)? .connect() .await .map_err(|error| LaserStreamProtocolError::ConnectionFailed(error.to_string()))?; let mut geyser_client = grpc::geyser_client::GeyserClient::with_interceptor(channel, interceptor); geyser_client = geyser_client - .max_decoding_message_size(options.max_decoding_message_size.unwrap_or(1_000_000_000)) - .max_encoding_message_size(options.max_encoding_message_size.unwrap_or(32_000_000)); + .max_decoding_message_size( + options + .max_decoding_message_size + .unwrap_or(DEFAULT_MAX_DECODING_MESSAGE_SIZE), + ) + .max_encoding_message_size( + options + .max_encoding_message_size + .unwrap_or(DEFAULT_MAX_ENCODING_MESSAGE_SIZE), + ); if let Some(send_comp) = options.send_compression { let encoding = match send_comp { helius_laserstream::CompressionEncoding::Gzip => CompressionEncoding::Gzip, @@ -1857,47 +1938,23 @@ async fn connect_and_subscribe_slots_once( let options = config.client_config().channel_options; let interceptor = SofLaserStreamInterceptor::new(&config.api_key) .map_err(|error| LaserStreamProtocolError::InvalidApiKey(error.to_string()))?; - - let mut endpoint = Endpoint::from_shared(config.endpoint.clone()) - .map_err(|error| LaserStreamProtocolError::InvalidEndpoint(error.to_string()))? - .connect_timeout(Duration::from_secs( - options.connect_timeout_secs.unwrap_or(10), - )) - .timeout(Duration::from_secs(options.timeout_secs.unwrap_or(30))) - .http2_keep_alive_interval(Duration::from_secs( - options.http2_keep_alive_interval_secs.unwrap_or(30), - )) - .keep_alive_timeout(Duration::from_secs( - options.keep_alive_timeout_secs.unwrap_or(5), - )) - .keep_alive_while_idle(options.keep_alive_while_idle.unwrap_or(true)) - .initial_stream_window_size(options.initial_stream_window_size.or(Some(1024 * 1024 * 4))) - .initial_connection_window_size( - options - .initial_connection_window_size - .or(Some(1024 * 1024 * 8)), - ) - .http2_adaptive_window(options.http2_adaptive_window.unwrap_or(true)) - .tcp_nodelay(options.tcp_nodelay.unwrap_or(true)) - .buffer_size(options.buffer_size.or(Some(1024 * 64))); - if let Some(tcp_keepalive_secs) = options.tcp_keepalive_secs { - endpoint = endpoint.tcp_keepalive(Some(Duration::from_secs(tcp_keepalive_secs))); - } - if endpoint_uses_tls(&config.endpoint) { - endpoint = endpoint - .tls_config(ClientTlsConfig::new().with_enabled_roots()) - .map_err(|error| LaserStreamProtocolError::TlsConfig(error.to_string()))?; - } - - let channel = endpoint + let channel = laserstream_endpoint(&config.endpoint, &options)? .connect() .await .map_err(|error| LaserStreamProtocolError::ConnectionFailed(error.to_string()))?; let mut geyser_client = grpc::geyser_client::GeyserClient::with_interceptor(channel, interceptor); geyser_client = geyser_client - .max_decoding_message_size(options.max_decoding_message_size.unwrap_or(1_000_000_000)) - .max_encoding_message_size(options.max_encoding_message_size.unwrap_or(32_000_000)); + .max_decoding_message_size( + options + .max_decoding_message_size + .unwrap_or(DEFAULT_MAX_DECODING_MESSAGE_SIZE), + ) + .max_encoding_message_size( + options + .max_encoding_message_size + .unwrap_or(DEFAULT_MAX_ENCODING_MESSAGE_SIZE), + ); if let Some(send_comp) = options.send_compression { let encoding = match send_comp { helius_laserstream::CompressionEncoding::Gzip => CompressionEncoding::Gzip, @@ -1937,21 +1994,21 @@ fn transaction_event_from_update( let transaction = transaction.ok_or(LaserStreamError::Convert("missing transaction payload"))?; let is_vote = transaction.is_vote; - let signature = Some(signature_bytes_from_slice( - transaction.signature.as_slice(), - "invalid signature", - )?); + let signature = signature_bytes_from_slice(transaction.signature.as_slice(), || { + LaserStreamError::Convert("invalid signature") + })?; let tx = convert_transaction( transaction .transaction .ok_or(LaserStreamError::Convert("missing versioned transaction"))?, + Some(signature), )?; Ok(TransactionEvent { slot, commitment_status, confirmed_slot: watermarks.confirmed_slot, finalized_slot: watermarks.finalized_slot, - signature, + signature: Some(signature), provider_source: None, kind: if is_vote { TxKind::VoteOnly @@ -1967,10 +2024,9 @@ fn transaction_status_event_from_update( watermarks: ProviderCommitmentWatermarks, update: grpc::SubscribeUpdateTransactionStatus, ) -> Result { - let signature = signature_bytes_from_slice( - update.signature.as_slice(), - "invalid transaction-status signature", - )?; + let signature = signature_bytes_from_slice(update.signature.as_slice(), || { + LaserStreamError::Convert("invalid transaction-status signature") + })?; Ok(TransactionStatusEvent { slot: update.slot, commitment_status, @@ -1992,15 +2048,23 @@ fn account_update_event_from_laserstream( let account = update .account .ok_or(LaserStreamError::Convert("missing account payload"))?; - let pubkey = pubkey_bytes_from_slice(account.pubkey.as_slice(), "invalid account pubkey")?; - let owner = pubkey_bytes_from_slice(account.owner.as_slice(), "invalid account owner")?; + let pubkey = pubkey_bytes_from_slice(account.pubkey.as_slice(), || { + LaserStreamError::Convert("invalid account pubkey") + })?; + let owner = pubkey_bytes_from_slice(account.owner.as_slice(), || { + LaserStreamError::Convert("invalid account owner") + })?; let txn_signature = match account.txn_signature { - Some(signature) => Some(signature_bytes_from_slice( - signature.as_slice(), - "invalid account txn signature", - )?), + Some(signature) => Some(signature_bytes_from_slice(signature.as_slice(), || { + LaserStreamError::Convert("invalid account txn signature") + })?), None => None, }; + if account.data.len() > MAX_ACCOUNT_DATA_LEN { + return Err(LaserStreamError::Convert( + "account data exceeds max permitted size", + )); + } Ok(AccountUpdateEvent { slot: update.slot, commitment_status, @@ -2045,26 +2109,6 @@ fn block_meta_event_from_update( }) } -fn signature_bytes_from_slice( - bytes: &[u8], - message: &'static str, -) -> Result { - let raw: [u8; 64] = bytes - .try_into() - .map_err(|_error: std::array::TryFromSliceError| LaserStreamError::Convert(message))?; - Ok(SignatureBytes::from(raw)) -} - -fn pubkey_bytes_from_slice( - bytes: &[u8], - message: &'static str, -) -> Result { - let raw: [u8; 32] = bytes - .try_into() - .map_err(|_error: std::array::TryFromSliceError| LaserStreamError::Convert(message))?; - Ok(PubkeyBytes::from(raw)) -} - fn observe_non_transaction_commitment( watermarks: &mut ProviderCommitmentWatermarks, slot: u64, @@ -2104,6 +2148,12 @@ fn slot_status_event_from_update( | grpc::SlotStatus::SlotCreatedBank => ForkSlotStatus::Processed, }; let previous_status = slot_states.insert(slot, mapped); + prune_recent_slots( + slot_states, + slot, + SLOT_STATUS_RETAINED_LAG, + SLOT_STATUS_PRUNE_THRESHOLD, + ); if previous_status == Some(mapped) { return None; } @@ -2168,15 +2218,29 @@ impl ProviderStreamFanIn { } } -#[inline] +#[inline(always)] fn convert_transaction( tx: LaserStreamTransaction, + first_signature: Option, ) -> Result { let mut signatures = Vec::with_capacity(tx.signatures.len()); - for signature in tx.signatures { - signatures.push(Signature::try_from(signature.as_slice()).map_err(|_error| { - LaserStreamError::Convert("failed to parse transaction signature") - })?); + let mut tx_signatures = tx.signatures.into_iter(); + if let Some(signature) = tx_signatures.next() { + signatures.push(match first_signature { + Some(first_signature) => first_signature.into(), + None => signature_bytes_from_slice(signature.as_slice(), || { + LaserStreamError::Convert("failed to parse transaction signature") + })? + .into(), + }); + } + for signature in tx_signatures { + signatures.push( + signature_bytes_from_slice(signature.as_slice(), || { + LaserStreamError::Convert("failed to parse transaction signature") + })? + .into(), + ); } let message = tx .message @@ -2201,8 +2265,10 @@ fn convert_transaction( let mut account_keys = Vec::with_capacity(message.account_keys.len()); for key in message.account_keys { account_keys.push( - Pubkey::try_from(key.as_slice()) - .map_err(|_error| LaserStreamError::Convert("invalid account key"))?, + pubkey_bytes_from_slice(key.as_slice(), || { + LaserStreamError::Convert("invalid account key") + })? + .into(), ); } @@ -2221,9 +2287,10 @@ fn convert_transaction( let mut address_table_lookups = Vec::with_capacity(message.address_table_lookups.len()); for lookup in message.address_table_lookups { address_table_lookups.push(MessageAddressTableLookup { - account_key: Pubkey::try_from(lookup.account_key.as_slice()).map_err(|_error| { + account_key: pubkey_bytes_from_slice(lookup.account_key.as_slice(), || { LaserStreamError::Convert("invalid address table account key") - })?, + })? + .into(), writable_indexes: lookup.writable_indexes, readonly_indexes: lookup.readonly_indexes, }); @@ -2259,8 +2326,14 @@ fn convert_transaction( )] mod tests { use super::*; - use crate::provider_stream::create_provider_stream_queue; - use crate::provider_stream::yellowstone::{YellowstoneGrpcCommitment, YellowstoneGrpcConfig}; + use crate::{ + event::TxKind, + framework::signature_bytes, + provider_stream::{ + create_provider_stream_queue, + yellowstone::{YellowstoneGrpcCommitment, YellowstoneGrpcConfig}, + }, + }; use futures_channel::mpsc as futures_mpsc; use futures_util::stream::{self, Stream}; use laserstream_core_proto::geyser::geyser_server::{Geyser, GeyserServer}; @@ -2282,16 +2355,70 @@ mod tests { use solana_message::{Message, VersionedMessage}; use solana_sdk_ids::{compute_budget, system_program, vote}; use solana_signer::Signer; - use std::{pin::Pin, time::Instant}; + use std::{ + net::{SocketAddr, TcpListener as StdTcpListener}, + pin::Pin, + time::Instant, + }; + + use sof_support::bench::{avg_ns_per_iteration, profile_iterations}; use tokio::sync::oneshot; use tokio::time::{Duration, timeout}; - fn profile_iterations(default: usize) -> usize { - std::env::var("SOF_PROFILE_ITERATIONS") - .ok() - .and_then(|value| value.parse::().ok()) - .filter(|value| *value > 0) - .unwrap_or(default) + #[test] + fn laserstream_account_update_rejects_oversized_data() { + let pubkey = Pubkey::new_unique(); + let owner = Pubkey::new_unique(); + let update = sample_account_update(92, pubkey, owner); + let account_update = match update.update_oneof { + Some(grpc::subscribe_update::UpdateOneof::Account(account_update)) => account_update, + other => panic!("expected account update, got {other:?}"), + }; + + let mut oversized = account_update; + oversized.account.as_mut().expect("account payload").data = + vec![7_u8; MAX_ACCOUNT_DATA_LEN + 1]; + + let error = account_update_event_from_laserstream( + oversized, + TxCommitmentStatus::Confirmed, + ProviderCommitmentWatermarks::default(), + ) + .expect_err("oversized account payload must fail"); + + assert!(matches!( + error, + LaserStreamError::Convert("account data exceeds max permitted size") + )); + } + + #[test] + fn laserstream_slot_state_pruning_evicts_old_slots() { + let mut slot_states = HashMap::new(); + for slot in 0..=u64::try_from(SLOT_STATUS_PRUNE_THRESHOLD).unwrap_or(u64::MAX) { + let _ = slot_states.insert(slot, ForkSlotStatus::Processed); + } + + prune_recent_slots( + &mut slot_states, + 10_000, + SLOT_STATUS_RETAINED_LAG, + SLOT_STATUS_PRUNE_THRESHOLD, + ); + + assert!( + !slot_states.contains_key(&0), + "old tracked slots should be pruned" + ); + assert!( + slot_states.contains_key(&10_000_u64.saturating_sub(SLOT_STATUS_RETAINED_LAG)), + "recent tracked slots should stay resident" + ); + assert!( + slot_states.len() + <= usize::try_from(SLOT_STATUS_RETAINED_LAG + 1).unwrap_or(usize::MAX), + "tracked slot state should stay bounded" + ); } #[test] @@ -2461,6 +2588,66 @@ mod tests { assert!(request.slots.contains_key(INTERNAL_WATERMARK_SLOT_FILTER)); } + #[test] + fn laserstream_client_config_rounds_subsecond_timeouts_up() { + let config = LaserStreamConfig::new("https://laserstream.example", "token") + .with_connect_timeout(Duration::from_millis(250)) + .with_timeout(Duration::from_millis(750)) + .client_config(); + assert_eq!(config.channel_options.connect_timeout_secs, Some(1)); + assert_eq!(config.channel_options.timeout_secs, Some(1)); + + let slots_config = LaserStreamSlotsConfig::new("https://laserstream.example", "token") + .with_connect_timeout(Duration::from_millis(400)) + .with_timeout(Duration::from_millis(900)) + .client_config(); + assert_eq!(slots_config.channel_options.connect_timeout_secs, Some(1)); + assert_eq!(slots_config.channel_options.timeout_secs, Some(1)); + } + + #[test] + fn laserstream_client_config_clamps_zero_timeouts() { + let config = LaserStreamConfig::new("https://laserstream.example", "token") + .with_connect_timeout(Duration::ZERO) + .with_timeout(Duration::ZERO) + .client_config(); + assert_eq!(config.channel_options.connect_timeout_secs, Some(1)); + assert_eq!(config.channel_options.timeout_secs, Some(1)); + + let slots_config = LaserStreamSlotsConfig::new("https://laserstream.example", "token") + .with_connect_timeout(Duration::ZERO) + .with_timeout(Duration::ZERO) + .client_config(); + assert_eq!(slots_config.channel_options.connect_timeout_secs, Some(1)); + assert_eq!(slots_config.channel_options.timeout_secs, Some(1)); + } + + #[test] + fn laserstream_reconnect_delay_never_spins() { + let config = LaserStreamConfig::new("https://laserstream.example", "token") + .with_reconnect_delay(Duration::ZERO); + assert_eq!(config.reconnect_delay_effective(), Duration::from_millis(1)); + + let slots_config = LaserStreamSlotsConfig::new("https://laserstream.example", "token") + .with_reconnect_delay(Duration::ZERO); + assert_eq!( + slots_config.reconnect_delay_effective(), + Duration::from_millis(1) + ); + } + + #[test] + fn laserstream_stall_timeout_never_zero() { + assert_eq!( + nonzero_duration_or(Duration::ZERO, MIN_PROVIDER_STALL_TIMEOUT), + Duration::from_millis(1) + ); + assert_eq!( + nonzero_duration_or(Duration::from_millis(25), MIN_PROVIDER_STALL_TIMEOUT), + Duration::from_millis(25) + ); + } + #[tokio::test] async fn laserstream_local_source_delivers_transaction_update() { let update = grpc::SubscribeUpdate { @@ -3064,12 +3251,8 @@ mod tests { async fn spawn_laserstream_test_server( service: MockLaserStream, - ) -> ( - std::net::SocketAddr, - oneshot::Sender<()>, - tokio::task::JoinHandle<()>, - ) { - let listener = std::net::TcpListener::bind("127.0.0.1:0").expect("bind LaserStream test"); + ) -> (SocketAddr, oneshot::Sender<()>, tokio::task::JoinHandle<()>) { + let listener = StdTcpListener::bind("127.0.0.1:0").expect("bind LaserStream test"); let addr = listener.local_addr().expect("LaserStream test addr"); drop(listener); @@ -3095,13 +3278,14 @@ mod tests { let transaction = transaction.ok_or(LaserStreamError::Convert("missing transaction payload"))?; let signature = Signature::try_from(transaction.signature.as_slice()) - .map(crate::framework::signature_bytes) + .map(signature_bytes) .map(Some) .map_err(|_error| LaserStreamError::Convert("invalid signature"))?; let tx = convert_transaction( transaction .transaction .ok_or(LaserStreamError::Convert("missing versioned transaction"))?, + None, )?; Ok(TransactionEvent { slot, @@ -3220,7 +3404,7 @@ mod tests { #[tokio::test] async fn laserstream_spawn_rejects_account_filters_for_block_meta_stream() { - let (tx, _rx) = crate::provider_stream::create_provider_stream_queue(1); + let (tx, _rx) = create_provider_stream_queue(1); let config = LaserStreamConfig::new("http://127.0.0.1:1", "test-api-key") .with_stream(LaserStreamStream::BlockMeta) .with_accounts([Pubkey::new_unique()]); @@ -3241,7 +3425,7 @@ mod tests { #[test] fn laserstream_local_conversion_matches_sdk_baseline() { let tx = proto_transaction_from_versioned(&sample_transaction()); - let local = convert_transaction(tx.clone()).expect("local tx"); + let local = convert_transaction(tx.clone(), None).expect("local tx"); let baseline = convert_transaction_sdk_baseline(tx).expect("baseline tx"); assert_eq!(local, baseline); } @@ -3255,22 +3439,28 @@ mod tests { let baseline_started = Instant::now(); for _ in 0..iterations { let tx = convert_transaction_sdk_baseline(tx.clone()).expect("baseline tx"); - std::hint::black_box(tx); + black_box(tx); } let baseline_elapsed = baseline_started.elapsed(); let optimized_started = Instant::now(); for _ in 0..iterations { - let tx = convert_transaction(tx.clone()).expect("optimized tx"); - std::hint::black_box(tx); + let tx = convert_transaction(tx.clone(), None).expect("optimized tx"); + black_box(tx); } let optimized_elapsed = optimized_started.elapsed(); + let baseline_avg_ns = avg_ns_per_iteration(baseline_elapsed, iterations); + let optimized_avg_ns = avg_ns_per_iteration(optimized_elapsed, iterations); eprintln!( - "laserstream_local_conversion_profile_fixture iterations={} baseline_us={} optimized_us={}", + "laserstream_local_conversion_profile_fixture iterations={} baseline_us={} optimized_us={} baseline_avg_ns_per_iteration={} optimized_avg_ns_per_iteration={} baseline_avg_us_per_iteration={:.3} optimized_avg_us_per_iteration={:.3}", iterations, baseline_elapsed.as_micros(), optimized_elapsed.as_micros(), + baseline_avg_ns, + optimized_avg_ns, + baseline_avg_ns as f64 / 1_000.0, + optimized_avg_ns as f64 / 1_000.0, ); } @@ -3284,7 +3474,7 @@ mod tests { ) .expect("event"); assert_eq!(event.slot, 77); - assert_eq!(event.kind, crate::event::TxKind::Mixed); + assert_eq!(event.kind, TxKind::Mixed); assert!(event.signature.is_some()); } @@ -3298,7 +3488,7 @@ mod tests { ) .expect("event"); assert_eq!(event.slot, 78); - assert_eq!(event.kind, crate::event::TxKind::VoteOnly); + assert_eq!(event.kind, TxKind::VoteOnly); assert!(event.signature.is_some()); } @@ -3317,7 +3507,7 @@ mod tests { TxCommitmentStatus::Processed, ) .expect("baseline event"); - std::hint::black_box(event); + black_box(event); } let baseline_elapsed = baseline_started.elapsed(); @@ -3330,15 +3520,21 @@ mod tests { ProviderCommitmentWatermarks::default(), ) .expect("optimized event"); - std::hint::black_box(event); + black_box(event); } let optimized_elapsed = optimized_started.elapsed(); + let baseline_avg_ns = avg_ns_per_iteration(baseline_elapsed, iterations); + let optimized_avg_ns = avg_ns_per_iteration(optimized_elapsed, iterations); eprintln!( - "laserstream_transaction_conversion_profile_fixture iterations={} baseline_us={} optimized_us={}", + "laserstream_transaction_conversion_profile_fixture iterations={} baseline_us={} optimized_us={} baseline_avg_ns_per_iteration={} optimized_avg_ns_per_iteration={} baseline_avg_us_per_iteration={:.3} optimized_avg_us_per_iteration={:.3}", iterations, baseline_elapsed.as_micros(), optimized_elapsed.as_micros(), + baseline_avg_ns, + optimized_avg_ns, + baseline_avg_ns as f64 / 1_000.0, + optimized_avg_ns as f64 / 1_000.0, ); } @@ -3355,7 +3551,7 @@ mod tests { TxCommitmentStatus::Processed, ) .expect("baseline event"); - std::hint::black_box(event); + black_box(event); } } @@ -3373,7 +3569,7 @@ mod tests { ProviderCommitmentWatermarks::default(), ) .expect("optimized event"); - std::hint::black_box(event); + black_box(event); } } @@ -3392,7 +3588,7 @@ mod tests { TxCommitmentStatus::Processed, ) .expect("baseline event"); - std::hint::black_box(event); + black_box(event); } let baseline_elapsed = baseline_started.elapsed(); @@ -3405,15 +3601,21 @@ mod tests { ProviderCommitmentWatermarks::default(), ) .expect("optimized event"); - std::hint::black_box(event); + black_box(event); } let optimized_elapsed = optimized_started.elapsed(); + let baseline_avg_ns = avg_ns_per_iteration(baseline_elapsed, iterations); + let optimized_avg_ns = avg_ns_per_iteration(optimized_elapsed, iterations); eprintln!( - "laserstream_vote_only_conversion_profile_fixture iterations={} baseline_us={} optimized_us={}", + "laserstream_vote_only_conversion_profile_fixture iterations={} baseline_us={} optimized_us={} baseline_avg_ns_per_iteration={} optimized_avg_ns_per_iteration={} baseline_avg_us_per_iteration={:.3} optimized_avg_us_per_iteration={:.3}", iterations, baseline_elapsed.as_micros(), optimized_elapsed.as_micros(), + baseline_avg_ns, + optimized_avg_ns, + baseline_avg_ns as f64 / 1_000.0, + optimized_avg_ns as f64 / 1_000.0, ); } } diff --git a/crates/sof-observer/src/provider_stream/websocket.rs b/crates/sof-observer/src/provider_stream/websocket.rs index 084a6f6a..f4b25292 100644 --- a/crates/sof-observer/src/provider_stream/websocket.rs +++ b/crates/sof-observer/src/provider_stream/websocket.rs @@ -7,28 +7,51 @@ //! and LaserStream by requesting full base64 transaction payloads and converting //! them into [`crate::framework::TransactionEvent`] values before dispatch. -use std::{borrow::Cow, str::FromStr, sync::Arc, time::Duration}; +use std::{ + borrow::Cow, + fmt, + future::pending, + mem, + str::FromStr, + sync::{Arc, OnceLock}, + time::Duration, +}; use base64::{Engine as _, engine::general_purpose::STANDARD}; -use futures_util::{SinkExt, StreamExt}; +use futures_util::{Sink, SinkExt, Stream, StreamExt}; +use reqwest::redirect::Policy; use serde::Deserialize; -use serde_json::{Value, json}; +use serde_json::{Value, from_slice as json_from_slice, json}; use simd_json::{Buffers as SimdJsonBuffers, serde::from_slice as simd_from_slice}; +use sof_support::short_vec::decode_short_u16_len; +use sof_support::time_support::nonzero_duration_or; use sof_types::{PubkeyBytes, SignatureBytes}; +use solana_packet::PACKET_DATA_SIZE; use solana_pubkey::Pubkey; use solana_signature::Signature; +use solana_system_interface::MAX_PERMITTED_DATA_LENGTH; use solana_transaction::versioned::VersionedTransaction; use thiserror::Error; use tokio::net::TcpStream; use tokio::sync::mpsc; use tokio::task::JoinHandle; use tokio_tungstenite::{ - MaybeTlsStream, WebSocketStream, connect_async, tungstenite::protocol::Message as WsMessage, + MaybeTlsStream, WebSocketStream, connect_async_with_config, + tungstenite::{ + Error as TungsteniteError, + protocol::{Message as WsMessage, WebSocketConfig}, + }, }; +#[cfg(test)] +use std::hint::black_box; + use crate::{ - event::TxCommitmentStatus, - framework::{AccountUpdateEvent, TransactionEvent, pubkey_bytes, signature_bytes_opt}, + event::{TxCommitmentStatus, TxKind}, + framework::{ + AccountUpdateEvent, TransactionEvent, TransactionLogEvent, pubkey_bytes, + signature_bytes_opt, + }, provider_stream::{ ProviderCommitmentWatermarks, ProviderSourceArbitrationMode, ProviderSourceHealthEvent, ProviderSourceHealthReason, ProviderSourceHealthStatus, ProviderSourceId, @@ -36,10 +59,14 @@ use crate::{ ProviderSourceReservation, ProviderSourceRole, ProviderSourceTaskGuard, ProviderStreamFanIn, ProviderStreamMode, ProviderStreamSender, ProviderStreamUpdate, SerializedTransactionEvent, classify_provider_transaction_kind, - emit_provider_source_removed_with_reservation, + emit_provider_source_removed_with_reservation, keepalive_interval, }, }; +const MIN_RECONNECT_DELAY: Duration = Duration::from_millis(1); +const MIN_STALL_TIMEOUT: Duration = Duration::from_millis(1); +const DEFAULT_WEBSOCKET_CONNECT_TIMEOUT: Duration = Duration::from_secs(10); + /// Commitment level used for websocket `transactionSubscribe` notifications. #[derive(Clone, Copy, Debug, Default, Eq, PartialEq)] pub enum WebsocketTransactionCommitment { @@ -125,8 +152,8 @@ pub enum WebsocketConfigOption { ProgramFilters, } -impl std::fmt::Display for WebsocketConfigOption { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl fmt::Display for WebsocketConfigOption { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::HttpEndpoint => f.write_str("http replay endpoint"), Self::VoteFilter => f.write_str("vote filter"), @@ -588,6 +615,10 @@ impl WebsocketTransactionConfig { WebsocketPrimaryStream::Program(_) => WebsocketStreamKind::Program, } } + + const fn reconnect_delay_effective(&self) -> Duration { + effective_reconnect_delay(self.reconnect_delay) + } } /// Primary websocket subscription families supported by the built-in adapter. @@ -814,6 +845,18 @@ impl WebsocketLogsConfig { .with_priority(self.source_priority) .with_arbitration(self.source_arbitration) } + + const fn reconnect_delay_effective(&self) -> Duration { + effective_reconnect_delay(self.reconnect_delay) + } +} + +const fn effective_reconnect_delay(delay: Duration) -> Duration { + if delay.is_zero() { + MIN_RECONNECT_DELAY + } else { + delay + } } /// Websocket `transactionSubscribe` error surface. @@ -824,7 +867,7 @@ pub enum WebsocketTransactionError { Config(#[from] WebsocketConfigError), /// Websocket transport failure. #[error(transparent)] - Transport(#[from] tokio_tungstenite::tungstenite::Error), + Transport(#[from] TungsteniteError), /// Upstream payload shape/protocol failure. #[error(transparent)] Protocol(#[from] WebsocketProtocolError), @@ -844,7 +887,7 @@ pub enum WebsocketTransactionError { pub enum WebsocketLogsError { /// Websocket transport failure. #[error(transparent)] - Transport(#[from] tokio_tungstenite::tungstenite::Error), + Transport(#[from] TungsteniteError), /// Upstream payload shape/protocol failure. #[error(transparent)] Protocol(#[from] WebsocketProtocolError), @@ -937,8 +980,8 @@ pub enum WebsocketStreamKind { Program, } -impl std::fmt::Display for WebsocketStreamKind { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl fmt::Display for WebsocketStreamKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::Transaction => f.write_str("websocket transaction"), Self::Logs => f.write_str("websocket logs"), @@ -1131,7 +1174,7 @@ async fn spawn_websocket_source_inner( .await?; return Err(WebsocketProtocolError::ReconnectBudgetExhausted { attempts }.into()); } - tokio::time::sleep(config.reconnect_delay).await; + tokio::time::sleep(config.reconnect_delay_effective()).await; } })) } @@ -1286,7 +1329,7 @@ async fn spawn_websocket_logs_source_inner( .await?; return Err(WebsocketProtocolError::ReconnectBudgetExhausted { attempts }.into()); } - tokio::time::sleep(config.reconnect_delay).await; + tokio::time::sleep(config.reconnect_delay_effective()).await; } })) } @@ -1302,6 +1345,7 @@ async fn run_websocket_primary_connection( ) -> Result<(), WebsocketTransactionError> { *session_established = false; let (mut write, mut read) = stream.split(); + let provider_source = Arc::new(source.clone()); *session_established = true; send_primary_provider_health( source, @@ -1316,9 +1360,9 @@ async fn run_websocket_primary_connection( && config.replay_on_reconnect && last_seen_slot.is_some() { - replay_websocket_gap(config, source, sender, last_seen_slot, watermarks).await?; + replay_websocket_gap(config, &provider_source, sender, last_seen_slot, watermarks).await?; } - let mut ping = config.ping_interval.map(tokio::time::interval); + let mut ping = config.ping_interval.map(keepalive_interval); let mut scratch = WebsocketParseScratch::default(); let mut last_progress = tokio::time::Instant::now(); @@ -1328,17 +1372,17 @@ async fn run_websocket_primary_connection( if let Some(interval) = ping.as_mut() { interval.tick().await; } else { - std::future::pending::<()>().await; + pending::<()>().await; } } => { write.send(WsMessage::Ping(Vec::new().into())).await?; } () = async { - if let Some(timeout) = config.stall_timeout { + if let Some(timeout) = websocket_stall_timeout(config.stall_timeout) { let deadline = last_progress.checked_add(timeout).unwrap_or(last_progress); tokio::time::sleep_until(deadline).await; } else { - std::future::pending::<()>().await; + pending::<()>().await; } } => { return Err(WebsocketProtocolError::StreamStalled { @@ -1362,7 +1406,7 @@ async fn run_websocket_primary_connection( }; handle_primary_notification( config, - source, + &provider_source, sender, frame_bytes_mut(&mut scratch.frame_bytes, text.as_str().as_bytes()), &mut state, @@ -1378,7 +1422,7 @@ async fn run_websocket_primary_connection( }; handle_primary_notification( config, - source, + &provider_source, sender, frame_bytes_mut(&mut scratch.frame_bytes, bytes.as_ref()), &mut state, @@ -1390,7 +1434,8 @@ async fn run_websocket_primary_connection( } WsMessage::Pong(_) => {} WsMessage::Close(frame) => { - return Err(WebsocketProtocolError::Closed(format!("{frame:?}")).into()); + write.send(WsMessage::Close(frame)).await.ok(); + return Ok(()); } _ => {} } @@ -1529,51 +1574,16 @@ const fn websocket_logs_health_reason(error: &WebsocketLogsError) -> ProviderSou } } -async fn wait_for_subscription_ack(read: &mut S) -> Result<(), WebsocketTransactionError> +async fn wait_for_subscription_ack( + read: &mut S, + ack_timeout: Duration, +) -> Result<(), WebsocketTransactionError> where - S: futures_util::Stream> + S: Stream> + + Sink + Unpin, { - let ack_timeout = Duration::from_secs(10); - let mut frame_bytes = Vec::new(); - tokio::time::timeout(ack_timeout, async { - loop { - let Some(frame) = read.next().await else { - return Err(WebsocketTransactionError::Protocol( - WebsocketProtocolError::ClosedBeforeSubscriptionAck, - )); - }; - let frame = frame?; - match frame { - WsMessage::Text(text) => { - if handle_subscription_text(frame_bytes_mut( - &mut frame_bytes, - text.as_str().as_bytes(), - ))? { - return Ok(()); - } - } - WsMessage::Binary(bytes) => { - if handle_subscription_text(frame_bytes_mut(&mut frame_bytes, bytes.as_ref()))? - { - return Ok(()); - } - } - WsMessage::Ping(_) | WsMessage::Pong(_) => {} - WsMessage::Close(frame) => { - return Err( - WebsocketProtocolError::ClosedBeforeSubscriptionAckWithFrame(format!( - "{frame:?}" - )) - .into(), - ); - } - _ => {} - } - } - }) - .await - .map_err(|_elapsed| WebsocketProtocolError::SubscriptionAckTimeout)? + wait_for_subscription_ack_with(read, ack_timeout, handle_subscription_text).await } fn handle_subscription_text(bytes: &mut [u8]) -> Result { @@ -1615,24 +1625,19 @@ fn parse_transaction_notification( "unsupported websocket transaction encoding", )); } - tx_bytes.clear(); - STANDARD - .decode_vec(notification.transaction.transaction.0.as_bytes(), tx_bytes) - .map_err(|_error| { - WebsocketTransactionError::Convert("invalid base64 transaction payload") - })?; + decode_transaction_wire_payload(¬ification.transaction.transaction.0, tx_bytes)?; let signature = serialized_transaction_first_signature(tx_bytes).or_else(|| { notification .signature .and_then(|signature| Signature::from_str(&signature).ok()) .map(SignatureBytes::from) }); - watermarks - .observe_transaction_commitment(notification.slot, commitment_status.as_tx_commitment()); - let tx_payload = std::mem::take(tx_bytes).into_boxed_slice(); + let commitment_status = commitment_status.as_tx_commitment(); + watermarks.observe_transaction_commitment(notification.slot, commitment_status); + let tx_payload = mem::take(tx_bytes).into_boxed_slice(); Ok(Some(SerializedTransactionEvent { slot: notification.slot, - commitment_status: commitment_status.as_tx_commitment(), + commitment_status, confirmed_slot: watermarks.confirmed_slot, finalized_slot: watermarks.finalized_slot, signature, @@ -1641,6 +1646,60 @@ fn parse_transaction_notification( })) } +const MAX_BASE64_TRANSACTION_WIRE_LEN: usize = PACKET_DATA_SIZE.div_ceil(3) * 4; +const MAX_ACCOUNT_DATA_LEN: usize = MAX_PERMITTED_DATA_LENGTH as usize; +const MAX_BASE64_ACCOUNT_DATA_LEN: usize = MAX_ACCOUNT_DATA_LEN.div_ceil(3) * 4; +const WEBSOCKET_TRANSACTION_MAX_MESSAGE_SIZE: usize = 64 * 1024; +const WEBSOCKET_ACCOUNT_MAX_MESSAGE_SIZE: usize = MAX_BASE64_ACCOUNT_DATA_LEN + (128 * 1024); +const WEBSOCKET_LOGS_MAX_MESSAGE_SIZE: usize = 4 * 1024 * 1024; +const WEBSOCKET_REPLAY_HTTP_MAX_RESPONSE_BYTES: usize = 128 * 1024 * 1024; + +fn decode_transaction_wire_payload( + encoded: &str, + tx_bytes: &mut Vec, +) -> Result<(), WebsocketTransactionError> { + if encoded.len() > MAX_BASE64_TRANSACTION_WIRE_LEN { + return Err(WebsocketTransactionError::Convert( + "websocket transaction payload exceeds max wire size", + )); + } + tx_bytes.clear(); + STANDARD + .decode_vec(encoded.as_bytes(), tx_bytes) + .map_err(|_error| { + WebsocketTransactionError::Convert("invalid base64 transaction payload") + })?; + if tx_bytes.len() > PACKET_DATA_SIZE { + tx_bytes.clear(); + return Err(WebsocketTransactionError::Convert( + "websocket transaction payload exceeds max wire size", + )); + } + Ok(()) +} + +fn decode_account_payload( + encoded: &str, + tx_bytes: &mut Vec, +) -> Result<(), WebsocketTransactionError> { + if encoded.len() > MAX_BASE64_ACCOUNT_DATA_LEN { + return Err(WebsocketTransactionError::Convert( + "websocket account payload exceeds max data size", + )); + } + tx_bytes.clear(); + STANDARD + .decode_vec(encoded.as_bytes(), tx_bytes) + .map_err(|_error| WebsocketTransactionError::Convert("invalid base64 account payload"))?; + if tx_bytes.len() > MAX_ACCOUNT_DATA_LEN { + tx_bytes.clear(); + return Err(WebsocketTransactionError::Convert( + "websocket account payload exceeds max data size", + )); + } + Ok(()) +} + fn parse_account_notification( bytes: &mut [u8], json_buffers: &mut SimdJsonBuffers, @@ -1703,20 +1762,19 @@ fn decode_account_update_event( tx_bytes: &mut Vec, ) -> Result { let owner = parse_pubkey(&account.owner)?; + let commitment_status = commitment.as_tx_commitment(); tx_bytes.clear(); if account.data.1 != "base64" { return Err(WebsocketTransactionError::Convert( "unsupported websocket account encoding", )); } - STANDARD - .decode_vec(account.data.0.as_bytes(), tx_bytes) - .map_err(|_error| WebsocketTransactionError::Convert("invalid base64 account payload"))?; - observe_non_transaction_commitment(watermarks, slot, commitment.as_tx_commitment()); - let data = std::mem::take(tx_bytes).into_boxed_slice(); + decode_account_payload(&account.data.0, tx_bytes)?; + observe_non_transaction_commitment(watermarks, slot, commitment_status); + let data = mem::take(tx_bytes).into_boxed_slice(); Ok(AccountUpdateEvent { slot, - commitment_status: commitment.as_tx_commitment(), + commitment_status, confirmed_slot: watermarks.confirmed_slot, finalized_slot: watermarks.finalized_slot, pubkey: pubkey_bytes(pubkey), @@ -1756,24 +1814,6 @@ fn serialized_transaction_first_signature(payload: &[u8]) -> Option Option { - let mut value = 0_usize; - let mut shift = 0_u32; - for byte_index in 0..3 { - let byte = usize::from(*payload.get(*offset)?); - *offset = (*offset).saturating_add(1); - value |= (byte & 0x7f) << shift; - if byte & 0x80 == 0 { - return Some(value); - } - shift = shift.saturating_add(7); - if byte_index == 2 { - return None; - } - } - None -} - fn parse_pubkey(input: &str) -> Result { Pubkey::from_str(input) .map_err(|_error| WebsocketTransactionError::Convert("invalid websocket pubkey")) @@ -1788,7 +1828,7 @@ struct WebsocketPrimaryNotificationState<'state> { async fn handle_primary_notification( config: &WebsocketTransactionConfig, - source: &ProviderSourceIdentity, + source: &Arc, sender: &ProviderStreamSender, bytes: &mut [u8], state: &mut WebsocketPrimaryNotificationState<'_>, @@ -1810,7 +1850,7 @@ async fn handle_primary_notification( sender .send( ProviderStreamUpdate::SerializedTransaction(update) - .with_provider_source(source.clone()), + .with_provider_source_ref(source), ) .await .map_err(|_error| WebsocketTransactionError::QueueClosed)?; @@ -1832,7 +1872,7 @@ async fn handle_primary_notification( sender .send( ProviderStreamUpdate::AccountUpdate(update) - .with_provider_source(source.clone()), + .with_provider_source_ref(source), ) .await .map_err(|_error| WebsocketTransactionError::QueueClosed)?; @@ -1851,6 +1891,7 @@ async fn run_websocket_logs_connection( ) -> Result<(), WebsocketLogsError> { *session_established = false; let (mut write, mut read) = stream.split(); + let provider_source = Arc::new(source.clone()); *session_established = true; send_provider_logs_health( source, @@ -1861,7 +1902,7 @@ async fn run_websocket_logs_connection( PROVIDER_SUBSCRIPTION_ACKNOWLEDGED.to_owned(), ) .await?; - let mut ping = config.ping_interval.map(tokio::time::interval); + let mut ping = config.ping_interval.map(keepalive_interval); let mut frame_bytes = Vec::new(); let mut last_progress = tokio::time::Instant::now(); @@ -1871,17 +1912,17 @@ async fn run_websocket_logs_connection( if let Some(interval) = ping.as_mut() { interval.tick().await; } else { - std::future::pending::<()>().await; + pending::<()>().await; } } => { write.send(WsMessage::Ping(Vec::new().into())).await?; } () = async { - if let Some(timeout) = config.stall_timeout { + if let Some(timeout) = websocket_stall_timeout(config.stall_timeout) { let deadline = last_progress.checked_add(timeout).unwrap_or(last_progress); tokio::time::sleep_until(deadline).await; } else { - std::future::pending::<()>().await; + pending::<()>().await; } } => { return Err(WebsocketProtocolError::StreamStalled { @@ -1904,7 +1945,7 @@ async fn run_websocket_logs_connection( sender .send( ProviderStreamUpdate::TransactionLog(update) - .with_provider_source(source.clone()), + .with_provider_source_ref(&provider_source), ) .await .map_err(|_error| WebsocketLogsError::QueueClosed)?; @@ -1918,7 +1959,7 @@ async fn run_websocket_logs_connection( sender .send( ProviderStreamUpdate::TransactionLog(update) - .with_provider_source(source.clone()), + .with_provider_source_ref(&provider_source), ) .await .map_err(|_error| WebsocketLogsError::QueueClosed)?; @@ -1929,7 +1970,8 @@ async fn run_websocket_logs_connection( } WsMessage::Pong(_) => {} WsMessage::Close(frame) => { - return Err(WebsocketProtocolError::Closed(format!("{frame:?}")).into()); + write.send(WsMessage::Close(frame)).await.ok(); + return Ok(()); } _ => {} } @@ -1941,69 +1983,91 @@ async fn run_websocket_logs_connection( async fn establish_websocket_logs_session( config: &WebsocketLogsConfig, ) -> Result { - let (mut stream, _response) = connect_async(config.endpoint()).await?; + let (mut stream, _response) = tokio::time::timeout( + websocket_connect_timeout(config.stall_timeout), + connect_async_with_config( + config.endpoint(), + Some(websocket_transport_config(WEBSOCKET_LOGS_MAX_MESSAGE_SIZE)), + false, + ), + ) + .await + .map_err(|_elapsed| WebsocketProtocolError::SubscriptionAckTimeout)??; stream .send(WsMessage::Text( config.subscribe_request().to_string().into(), )) .await?; - wait_for_logs_subscription_ack(&mut stream).await?; + wait_for_logs_subscription_ack(&mut stream, websocket_ack_timeout(config.stall_timeout)) + .await?; Ok(stream) } -async fn wait_for_logs_subscription_ack(read: &mut S) -> Result<(), WebsocketLogsError> +async fn wait_for_logs_subscription_ack( + read: &mut S, + ack_timeout: Duration, +) -> Result<(), WebsocketLogsError> +where + S: Stream> + + Sink + + Unpin, +{ + wait_for_subscription_ack_with(read, ack_timeout, handle_logs_subscription_text).await +} + +async fn wait_for_subscription_ack_with( + read: &mut S, + ack_timeout: Duration, + mut handle_text: F, +) -> Result<(), E> where - S: futures_util::Stream> + S: Stream> + + Sink + Unpin, + E: From + From, + F: FnMut(&mut [u8]) -> Result, { - let ack_timeout = Duration::from_secs(10); let mut frame_bytes = Vec::new(); tokio::time::timeout(ack_timeout, async { loop { let Some(frame) = read.next().await else { - return Err(WebsocketLogsError::Protocol( - WebsocketProtocolError::ClosedBeforeSubscriptionAck, - )); + return Err(E::from(WebsocketProtocolError::ClosedBeforeSubscriptionAck)); }; let frame = frame?; match frame { WsMessage::Text(text) => { - if handle_logs_subscription_text(frame_bytes_mut( - &mut frame_bytes, - text.as_str().as_bytes(), - ))? { + if handle_text(frame_bytes_mut(&mut frame_bytes, text.as_str().as_bytes()))? { return Ok(()); } } WsMessage::Binary(bytes) => { - if handle_logs_subscription_text(frame_bytes_mut( - &mut frame_bytes, - bytes.as_ref(), - ))? { + if handle_text(frame_bytes_mut(&mut frame_bytes, bytes.as_ref()))? { return Ok(()); } } - WsMessage::Ping(_) | WsMessage::Pong(_) => {} + WsMessage::Ping(payload) => { + read.send(WsMessage::Pong(payload)).await?; + } + WsMessage::Pong(_) => {} WsMessage::Close(frame) => { - return Err( + return Err(E::from( WebsocketProtocolError::ClosedBeforeSubscriptionAckWithFrame(format!( "{frame:?}" - )) - .into(), - ); + )), + )); } _ => {} } } }) .await - .map_err(|_elapsed| WebsocketProtocolError::SubscriptionAckTimeout)? + .map_err(|_elapsed| E::from(WebsocketProtocolError::SubscriptionAckTimeout))? } fn parse_logs_notification( bytes: &mut [u8], config: &WebsocketLogsConfig, -) -> Result, WebsocketLogsError> { +) -> Result, WebsocketLogsError> { let value: WebsocketLogsEnvelopeMessage = simd_from_slice(bytes) .map_err(|error| WebsocketProtocolError::InvalidJson(error.to_string()))?; if let Some(error) = value.error { @@ -2018,7 +2082,7 @@ fn parse_logs_notification( WebsocketLogsFilter::Mentions(pubkey) => Some(PubkeyBytes::from(pubkey)), WebsocketLogsFilter::All | WebsocketLogsFilter::AllWithVotes => None, }; - Ok(Some(crate::framework::TransactionLogEvent { + Ok(Some(TransactionLogEvent { slot: notification.context.slot, commitment_status: config.commitment.as_tx_commitment(), signature: signature.into(), @@ -2054,11 +2118,8 @@ fn materialize_transaction_baseline( "unsupported websocket transaction encoding", )); } - let tx_bytes = STANDARD - .decode(notification.transaction.transaction.0.as_bytes()) - .map_err(|_error| { - WebsocketTransactionError::Convert("invalid base64 transaction payload") - })?; + let mut tx_bytes = Vec::new(); + decode_transaction_wire_payload(¬ification.transaction.transaction.0, &mut tx_bytes)?; let tx = bincode::deserialize::(&tx_bytes).map_err(|_error| { WebsocketTransactionError::Convert("failed to deserialize transaction") })?; @@ -2067,9 +2128,10 @@ fn materialize_transaction_baseline( .signature .and_then(|signature| Signature::from_str(&signature).ok()) }); + let commitment_status = commitment_status.as_tx_commitment(); Ok(Some(TransactionEvent { slot: notification.slot, - commitment_status: commitment_status.as_tx_commitment(), + commitment_status, confirmed_slot: None, finalized_slot: None, signature: signature_bytes_opt(signature), @@ -2092,9 +2154,28 @@ fn frame_bytes_mut<'buffer>(buffer: &'buffer mut Vec, bytes: &[u8]) -> &'buf buffer.as_mut_slice() } +fn websocket_replay_http_client() -> Result<&'static reqwest::Client, WebsocketTransactionError> { + static CLIENT: OnceLock> = OnceLock::new(); + CLIENT + .get_or_init(|| { + reqwest::Client::builder() + .redirect(Policy::none()) + .build() + .map_err(|error| error.to_string()) + }) + .as_ref() + .map_err(|detail| { + WebsocketProtocolError::HttpRpcFailed { + method: "client", + detail: detail.clone(), + } + .into() + }) +} + async fn replay_websocket_gap( config: &WebsocketTransactionConfig, - source: &ProviderSourceIdentity, + source: &Arc, sender: &ProviderStreamSender, last_seen_slot: &mut Option, watermarks: &mut ProviderCommitmentWatermarks, @@ -2106,24 +2187,35 @@ async fn replay_websocket_gap( return Err(WebsocketProtocolError::MissingReplayHttpEndpoint.into()); }; - let client = reqwest::Client::new(); - let head = rpc_get_slot(&client, &http_endpoint, config.commitment).await?; + let client = websocket_replay_http_client()?; + let head = rpc_get_slot( + client, + &http_endpoint, + config.commitment, + websocket_stall_timeout(config.stall_timeout), + ) + .await?; if head < previous_slot { return Ok(()); } let start_slot = websocket_replay_start_slot(previous_slot, head, config.replay_max_slots); + let commitment_status = config.commitment.as_tx_commitment(); + let mut tx_bytes = Vec::new(); for slot in start_slot..=head { - let Some(block) = rpc_get_block(&client, &http_endpoint, slot, config.commitment).await? + let Some(block) = rpc_get_block( + client, + &http_endpoint, + slot, + config.commitment, + websocket_stall_timeout(config.stall_timeout), + ) + .await? else { continue; }; for transaction in block.transactions { - let tx_bytes = STANDARD - .decode(transaction.transaction.0.as_bytes()) - .map_err(|_error| { - WebsocketTransactionError::Convert("invalid base64 transaction payload") - })?; + decode_transaction_wire_payload(&transaction.transaction.0, &mut tx_bytes)?; let tx = bincode::deserialize::(&tx_bytes).map_err(|_error| { WebsocketTransactionError::Convert("failed to deserialize transaction") })?; @@ -2145,12 +2237,12 @@ async fn replay_websocket_gap( ) { continue; } - watermarks.observe_transaction_commitment(slot, config.commitment.as_tx_commitment()); + watermarks.observe_transaction_commitment(slot, commitment_status); sender .send( ProviderStreamUpdate::Transaction(TransactionEvent { slot, - commitment_status: config.commitment.as_tx_commitment(), + commitment_status, confirmed_slot: watermarks.confirmed_slot, finalized_slot: watermarks.finalized_slot, signature: signature_bytes_opt(tx.signatures.first().copied()), @@ -2158,7 +2250,7 @@ async fn replay_websocket_gap( kind, tx: Arc::new(tx), }) - .with_provider_source(source.clone()), + .with_provider_source_ref(source), ) .await .map_err(|_error| WebsocketTransactionError::QueueClosed)?; @@ -2177,13 +2269,22 @@ async fn establish_websocket_primary_session( { return Err(WebsocketProtocolError::MissingReplayHttpEndpoint.into()); } - let (mut stream, _response) = connect_async(config.endpoint()).await?; + let (mut stream, _response) = tokio::time::timeout( + websocket_connect_timeout(config.stall_timeout), + connect_async_with_config( + config.endpoint(), + Some(websocket_primary_transport_config(config)), + false, + ), + ) + .await + .map_err(|_elapsed| WebsocketProtocolError::SubscriptionAckTimeout)??; stream .send(WsMessage::Text( config.subscribe_request().to_string().into(), )) .await?; - wait_for_subscription_ack(&mut stream).await?; + wait_for_subscription_ack(&mut stream, websocket_ack_timeout(config.stall_timeout)).await?; Ok(stream) } @@ -2207,31 +2308,70 @@ fn websocket_http_endpoint(config: &WebsocketTransactionConfig) -> Option WebSocketConfig { + WebSocketConfig::default() + .max_message_size(Some(max_message_size)) + .max_frame_size(Some(max_message_size)) +} + +const fn websocket_connect_timeout(stall_timeout: Option) -> Duration { + match websocket_stall_timeout(stall_timeout) { + Some(timeout) => timeout, + None => DEFAULT_WEBSOCKET_CONNECT_TIMEOUT, + } +} + +const fn websocket_ack_timeout(stall_timeout: Option) -> Duration { + match websocket_stall_timeout(stall_timeout) { + Some(timeout) => timeout, + None => DEFAULT_WEBSOCKET_CONNECT_TIMEOUT, + } +} + +const fn websocket_stall_timeout(stall_timeout: Option) -> Option { + match stall_timeout { + Some(timeout) => Some(nonzero_duration_or(timeout, MIN_STALL_TIMEOUT)), + None => None, + } +} + +fn http_rpc_error_detail(error: &reqwest::Error) -> String { + if error.is_timeout() { + return format!("request timed out: {error}"); + } + error.to_string() +} + +fn websocket_primary_transport_config(config: &WebsocketTransactionConfig) -> WebSocketConfig { + match config.stream { + WebsocketPrimaryStream::Transaction => { + websocket_transport_config(WEBSOCKET_TRANSACTION_MAX_MESSAGE_SIZE) + } + WebsocketPrimaryStream::Account(_) | WebsocketPrimaryStream::Program(_) => { + websocket_transport_config(WEBSOCKET_ACCOUNT_MAX_MESSAGE_SIZE) + } + } +} + async fn rpc_get_slot( client: &reqwest::Client, endpoint: &str, commitment: WebsocketTransactionCommitment, + timeout: Option, ) -> Result { - let response: RpcJsonResponse = client - .post(endpoint) - .json(&json!({ + let response = rpc_post( + client, + endpoint, + "getSlot", + json!({ "jsonrpc": "2.0", "id": 1, "method": "getSlot", "params": [{ "commitment": commitment.as_str() }], - })) - .send() - .await - .map_err(|error| WebsocketProtocolError::HttpRpcFailed { - method: "getSlot", - detail: error.to_string(), - })? - .json() - .await - .map_err(|error| WebsocketProtocolError::HttpRpcDecodeFailed { - method: "getSlot", - detail: error.to_string(), - })?; + }), + timeout, + ) + .await?; if let Some(error) = response.error { return Err(WebsocketProtocolError::HttpRpcFailed { method: "getSlot", @@ -2249,10 +2389,13 @@ async fn rpc_get_block( endpoint: &str, slot: u64, commitment: WebsocketTransactionCommitment, + timeout: Option, ) -> Result, WebsocketTransactionError> { - let response: RpcJsonResponse = client - .post(endpoint) - .json(&json!({ + let response = rpc_post( + client, + endpoint, + "getBlock", + json!({ "jsonrpc": "2.0", "id": 1, "method": "getBlock", @@ -2266,19 +2409,10 @@ async fn rpc_get_block( "rewards": false } ], - })) - .send() - .await - .map_err(|error| WebsocketProtocolError::HttpRpcFailed { - method: "getBlock", - detail: error.to_string(), - })? - .json() - .await - .map_err(|error| WebsocketProtocolError::HttpRpcDecodeFailed { - method: "getBlock", - detail: error.to_string(), - })?; + }), + timeout, + ) + .await?; if let Some(error) = response.error { return Err(WebsocketProtocolError::HttpRpcFailed { method: "getBlock", @@ -2289,11 +2423,122 @@ async fn rpc_get_block( Ok(response.result) } +async fn rpc_post( + client: &reqwest::Client, + endpoint: &str, + method: &'static str, + payload: Value, + timeout: Option, +) -> Result, WebsocketTransactionError> +where + T: for<'de> Deserialize<'de>, +{ + rpc_post_with_max_response_bytes( + client, + endpoint, + method, + payload, + WEBSOCKET_REPLAY_HTTP_MAX_RESPONSE_BYTES, + timeout, + ) + .await +} + +async fn rpc_post_with_max_response_bytes( + client: &reqwest::Client, + endpoint: &str, + method: &'static str, + payload: Value, + max_response_bytes: usize, + timeout: Option, +) -> Result, WebsocketTransactionError> +where + T: for<'de> Deserialize<'de>, +{ + let mut request = client.post(endpoint).json(&payload); + if let Some(timeout) = timeout { + request = request.timeout(timeout); + } + let response = request + .send() + .await + .map_err(|error| WebsocketProtocolError::HttpRpcFailed { + method, + detail: http_rpc_error_detail(&error), + })?; + if response.status().is_redirection() { + return Err(WebsocketProtocolError::HttpRpcFailed { + method, + detail: format!("unexpected redirect response: {}", response.status()), + } + .into()); + } + let response = + response + .error_for_status() + .map_err(|error| WebsocketProtocolError::HttpRpcFailed { + method, + detail: error.to_string(), + })?; + let response_body = + read_http_response_bytes_bounded(response, method, max_response_bytes).await?; + Ok(json_from_slice(&response_body).map_err(|error| { + WebsocketProtocolError::HttpRpcDecodeFailed { + method, + detail: error.to_string(), + } + })?) +} + +async fn read_http_response_bytes_bounded( + mut response: reqwest::Response, + method: &'static str, + max_response_bytes: usize, +) -> Result, WebsocketTransactionError> { + if response + .content_length() + .is_some_and(|content_length| content_length > max_response_bytes as u64) + { + return Err(WebsocketProtocolError::HttpRpcFailed { + method, + detail: format!("response body exceeded max size of {max_response_bytes} bytes"), + } + .into()); + } + + let initial_capacity = response + .content_length() + .and_then(|content_length| usize::try_from(content_length).ok()) + .unwrap_or(0) + .min(max_response_bytes); + let mut body = Vec::with_capacity(initial_capacity); + while let Some(chunk) = + response + .chunk() + .await + .map_err(|error| WebsocketProtocolError::HttpRpcFailed { + method, + detail: http_rpc_error_detail(&error), + })? + { + let remaining = max_response_bytes.saturating_sub(body.len()); + if chunk.len() > remaining { + return Err(WebsocketProtocolError::HttpRpcFailed { + method, + detail: format!("response body exceeded max size of {max_response_bytes} bytes"), + } + .into()); + } + body.extend_from_slice(&chunk); + } + Ok(body) +} + fn websocket_transaction_matches_filter( config: &WebsocketTransactionConfig, tx: &VersionedTransaction, loaded_addresses: Option<&RpcLoadedAddresses>, - kind: crate::event::TxKind, + kind: TxKind, failed: bool, ) -> bool { if let Some(signature) = config.signature @@ -2302,7 +2547,7 @@ fn websocket_transaction_matches_filter( return false; } if let Some(expect_vote) = config.vote { - let is_vote = kind == crate::event::TxKind::VoteOnly; + let is_vote = kind == TxKind::VoteOnly; if is_vote != expect_vote { return false; } @@ -2556,29 +2801,29 @@ impl RpcLoadedAddresses { )] mod tests { use super::*; - use crate::event::TxKind; - use crate::provider_stream::{create_provider_stream_fan_in, create_provider_stream_queue}; + #[cfg(feature = "provider-grpc")] + use crate::provider_stream::yellowstone::{YellowstoneGrpcCommitment, YellowstoneGrpcConfig}; + use crate::{ + event::TxKind, + provider_stream::{ + ProviderSourceId, ProviderSourceIdentity, ProviderStreamUpdate, + create_provider_stream_fan_in, create_provider_stream_queue, + }, + }; use base64::engine::general_purpose::STANDARD as BASE64_STANDARD; use serde_json::json; + #[cfg(feature = "provider-grpc")] + use serde_json::to_value; + use sof_support::bench::{avg_ns_per_iteration, profile_iterations}; use solana_keypair::Keypair; use solana_message::{Message, VersionedMessage}; use solana_signer::Signer; use std::time::Instant; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpListener; use tokio::time::{Duration, timeout}; use tokio_tungstenite::{accept_async, tungstenite::protocol::Message as WsMessage}; - #[cfg(feature = "provider-grpc")] - use crate::provider_stream::yellowstone::{YellowstoneGrpcCommitment, YellowstoneGrpcConfig}; - - fn profile_iterations(default: usize) -> usize { - std::env::var("SOF_PROFILE_ITERATIONS") - .ok() - .and_then(|value| value.parse::().ok()) - .filter(|value| *value > 0) - .unwrap_or(default) - } - fn sample_notification_payload() -> Vec { let signer = Keypair::new(); let message = Message::new(&[], Some(&signer.pubkey())); @@ -2603,41 +2848,290 @@ mod tests { .into_bytes() } - #[cfg(feature = "provider-grpc")] - #[test] - fn websocket_filter_shape_matches_yellowstone_config() { - let signature = Signature::from([7_u8; 64]); - let include = [Pubkey::new_unique(), Pubkey::new_unique()]; - let exclude = [Pubkey::new_unique()]; - let required = [Pubkey::new_unique()]; + async fn spawn_raw_http_response_server(response: String) -> String { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("http listener"); + let addr = listener.local_addr().expect("http local addr"); + tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.expect("http accept"); + let mut request = [0_u8; 1024]; + let _ = stream.read(&mut request).await; + stream + .write_all(response.as_bytes()) + .await + .expect("http response write"); + }); + format!("http://{addr}") + } - let websocket = WebsocketTransactionConfig::new("wss://example.invalid") - .with_commitment(WebsocketTransactionCommitment::Confirmed) - .with_vote(true) - .with_failed(true) - .with_signature(signature) - .with_account_include(include) - .with_account_exclude(exclude) - .with_account_required(required) - .subscribe_request(); - let yellowstone = YellowstoneGrpcConfig::new("http://127.0.0.1:10000") - .with_commitment(YellowstoneGrpcCommitment::Confirmed) - .with_vote(true) - .with_failed(true) - .with_signature(signature) - .with_account_include(include) - .with_account_exclude(exclude) - .with_account_required(required) - .subscribe_request(); + async fn spawn_http_response_server(status_line: &'static str, body: &'static str) -> String { + spawn_raw_http_response_server(format!( + "{status_line}\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{body}", + body.len() + )) + .await + } - let websocket_filter = websocket["params"][0] - .as_object() - .expect("websocket filter"); - let yellowstone_filter = yellowstone.transactions.get("sof").expect("ys filter"); + async fn spawn_redirect_http_response_server( + status_line: &'static str, + location: &str, + ) -> String { + spawn_raw_http_response_server(format!( + "{status_line}\r\nlocation: {location}\r\ncontent-length: 0\r\nconnection: close\r\n\r\n" + )) + .await + } - assert_eq!( - websocket["params"][1]["commitment"].as_str(), - Some("confirmed") + async fn spawn_chunked_http_response_server( + status_line: &'static str, + chunks: Vec, + ) -> String { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("http listener"); + let addr = listener.local_addr().expect("http local addr"); + tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.expect("http accept"); + let mut request = [0_u8; 1024]; + let _ = stream.read(&mut request).await; + let header = format!( + "{status_line}\r\ncontent-type: application/json\r\ntransfer-encoding: chunked\r\nconnection: close\r\n\r\n" + ); + stream + .write_all(header.as_bytes()) + .await + .expect("http response header write"); + for chunk in chunks { + let chunk_header = format!("{:X}\r\n", chunk.len()); + stream + .write_all(chunk_header.as_bytes()) + .await + .expect("http chunk header write"); + stream + .write_all(chunk.as_bytes()) + .await + .expect("http chunk body write"); + stream + .write_all(b"\r\n") + .await + .expect("http chunk terminator write"); + } + stream + .write_all(b"0\r\n\r\n") + .await + .expect("http final chunk write"); + }); + format!("http://{addr}") + } + + async fn spawn_stalled_http_response_server(stall: Duration) -> String { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("http listener"); + let addr = listener.local_addr().expect("http local addr"); + tokio::spawn(async move { + let (_stream, _) = listener.accept().await.expect("http accept"); + tokio::time::sleep(stall).await; + }); + format!("http://{addr}") + } + + #[tokio::test] + async fn websocket_rpc_get_slot_surfaces_http_status_failures() { + let endpoint = spawn_http_response_server( + "HTTP/1.1 500 Internal Server Error", + "{\"jsonrpc\":\"2.0\",\"error\":{\"code\":-32000,\"message\":\"boom\"}}", + ) + .await; + let client = reqwest::Client::new(); + + let error = rpc_get_slot( + &client, + &endpoint, + WebsocketTransactionCommitment::Confirmed, + None, + ) + .await + .expect_err("http failure should surface as rpc failure"); + + assert!( + error.to_string().contains("500 Internal Server Error"), + "unexpected error: {error}" + ); + } + + #[tokio::test] + async fn websocket_rpc_get_block_surfaces_http_status_failures() { + let endpoint = spawn_http_response_server( + "HTTP/1.1 503 Service Unavailable", + "{\"jsonrpc\":\"2.0\",\"error\":{\"code\":-32000,\"message\":\"down\"}}", + ) + .await; + let client = reqwest::Client::new(); + + let error = rpc_get_block( + &client, + &endpoint, + 55, + WebsocketTransactionCommitment::Confirmed, + None, + ) + .await + .expect_err("http failure should surface as rpc failure"); + + assert!( + error.to_string().contains("503 Service Unavailable"), + "unexpected error: {error}" + ); + } + + #[tokio::test] + async fn websocket_rpc_get_slot_rejects_redirects() { + let target = + spawn_http_response_server("HTTP/1.1 200 OK", "{\"jsonrpc\":\"2.0\",\"result\":99}") + .await; + let endpoint = + spawn_redirect_http_response_server("HTTP/1.1 307 Temporary Redirect", &target).await; + + let client = match websocket_replay_http_client() { + Ok(client) => client, + Err(error) => panic!("shared replay client: {error}"), + }; + let error = rpc_get_slot( + client, + &endpoint, + WebsocketTransactionCommitment::Confirmed, + None, + ) + .await + .expect_err("redirect replay response should be rejected"); + + assert!( + error.to_string().contains("307 Temporary Redirect"), + "unexpected error: {error}" + ); + } + + #[tokio::test] + async fn websocket_rpc_get_slot_rejects_oversized_content_length() { + let endpoint = spawn_raw_http_response_server( + "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: 99\r\nconnection: close\r\n\r\n{}".to_owned(), + ) + .await; + let client = reqwest::Client::new(); + + let error = rpc_post_with_max_response_bytes::( + &client, + &endpoint, + "getSlot", + json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "getSlot", + "params": [] + }), + 32, + None, + ) + .await + .expect_err("oversized content-length should be rejected"); + + assert!( + error.to_string().contains("exceeded max size"), + "unexpected error: {error}" + ); + } + + #[tokio::test] + async fn websocket_rpc_get_slot_rejects_oversized_chunked_response() { + let endpoint = spawn_chunked_http_response_server( + "HTTP/1.1 200 OK", + vec![ + "{\"jsonrpc\":\"2.0\",\"result\":".to_owned(), + "12345678901234567890}".to_owned(), + ], + ) + .await; + let client = reqwest::Client::new(); + + let error = rpc_post_with_max_response_bytes::( + &client, + &endpoint, + "getSlot", + json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "getSlot", + "params": [] + }), + 16, + None, + ) + .await + .expect_err("oversized chunked response should be rejected"); + + assert!( + error.to_string().contains("exceeded max size"), + "unexpected error: {error}" + ); + } + + #[tokio::test] + async fn websocket_rpc_get_slot_times_out_stalled_response() { + let endpoint = spawn_stalled_http_response_server(Duration::from_millis(200)).await; + let client = reqwest::Client::new(); + + let error = rpc_get_slot( + &client, + &endpoint, + WebsocketTransactionCommitment::Confirmed, + Some(Duration::from_millis(25)), + ) + .await + .expect_err("stalled replay response should time out"); + + assert!( + error.to_string().contains("timed out"), + "unexpected error: {error}" + ); + } + + #[cfg(feature = "provider-grpc")] + #[test] + fn websocket_filter_shape_matches_yellowstone_config() { + let signature = Signature::from([7_u8; 64]); + let include = [Pubkey::new_unique(), Pubkey::new_unique()]; + let exclude = [Pubkey::new_unique()]; + let required = [Pubkey::new_unique()]; + + let websocket = WebsocketTransactionConfig::new("wss://example.invalid") + .with_commitment(WebsocketTransactionCommitment::Confirmed) + .with_vote(true) + .with_failed(true) + .with_signature(signature) + .with_account_include(include) + .with_account_exclude(exclude) + .with_account_required(required) + .subscribe_request(); + let yellowstone = YellowstoneGrpcConfig::new("http://127.0.0.1:10000") + .with_commitment(YellowstoneGrpcCommitment::Confirmed) + .with_vote(true) + .with_failed(true) + .with_signature(signature) + .with_account_include(include) + .with_account_exclude(exclude) + .with_account_required(required) + .subscribe_request(); + + let websocket_filter = websocket["params"][0] + .as_object() + .expect("websocket filter"); + let yellowstone_filter = yellowstone.transactions.get("sof").expect("ys filter"); + + assert_eq!( + websocket["params"][1]["commitment"].as_str(), + Some("confirmed") ); assert_eq!( websocket_filter.get("vote").and_then(Value::as_bool), @@ -2653,24 +3147,15 @@ mod tests { ); assert_eq!( websocket_filter.get("accountInclude"), - Some( - &serde_json::to_value(yellowstone_filter.account_include.clone()) - .expect("include json") - ) + Some(&to_value(yellowstone_filter.account_include.clone()).expect("include json")) ); assert_eq!( websocket_filter.get("accountExclude"), - Some( - &serde_json::to_value(yellowstone_filter.account_exclude.clone()) - .expect("exclude json") - ) + Some(&to_value(yellowstone_filter.account_exclude.clone()).expect("exclude json")) ); assert_eq!( websocket_filter.get("accountRequired"), - Some( - &serde_json::to_value(yellowstone_filter.account_required.clone()) - .expect("required json") - ) + Some(&to_value(yellowstone_filter.account_required.clone()).expect("required json")) ); } @@ -2699,6 +3184,20 @@ mod tests { assert!(!filter.contains_key("failed")); } + #[test] + fn websocket_reconnect_delay_never_spins() { + let config = WebsocketTransactionConfig::new("wss://example.invalid") + .with_reconnect_delay(Duration::ZERO); + assert_eq!(config.reconnect_delay_effective(), Duration::from_millis(1)); + + let logs_config = + WebsocketLogsConfig::new("wss://example.invalid").with_reconnect_delay(Duration::ZERO); + assert_eq!( + logs_config.reconnect_delay_effective(), + Duration::from_millis(1) + ); + } + #[test] fn websocket_logs_subscribe_request_uses_configured_filter_and_commitment() { let pubkey = Pubkey::new_unique(); @@ -2840,6 +3339,85 @@ mod tests { assert_eq!(event.finalized_slot, None); } + #[test] + fn websocket_transaction_notification_rejects_oversized_payload() { + let mut payload = json!({ + "jsonrpc":"2.0", + "method":"transactionNotification", + "params":{ + "result":{ + "slot":55, + "transaction":{ + "transaction":[BASE64_STANDARD.encode(vec![7_u8; PACKET_DATA_SIZE + 1]),"base64"] + } + } + } + }) + .to_string() + .into_bytes(); + let mut json_buffers = SimdJsonBuffers::default(); + let mut tx_bytes = Vec::new(); + let mut watermarks = ProviderCommitmentWatermarks::default(); + + let error = parse_transaction_notification( + &mut payload, + &mut json_buffers, + &mut tx_bytes, + WebsocketTransactionCommitment::Confirmed, + &mut watermarks, + ) + .expect_err("oversized payload should be rejected"); + + assert!( + error + .to_string() + .contains("websocket transaction payload exceeds max wire size"), + "unexpected error: {error}" + ); + } + + #[test] + fn websocket_transport_config_caps_transaction_stream_frames() { + let config = WebsocketTransactionConfig::new("ws://example.invalid"); + let transport = websocket_primary_transport_config(&config); + + assert_eq!( + transport.max_message_size, + Some(WEBSOCKET_TRANSACTION_MAX_MESSAGE_SIZE), + ); + assert_eq!( + transport.max_frame_size, + Some(WEBSOCKET_TRANSACTION_MAX_MESSAGE_SIZE), + ); + } + + #[test] + fn websocket_transport_config_caps_account_stream_frames() { + let config = WebsocketTransactionConfig::new("ws://example.invalid") + .with_stream(WebsocketPrimaryStream::Account(Pubkey::new_unique())); + let transport = websocket_primary_transport_config(&config); + + assert_eq!( + transport.max_message_size, + Some(WEBSOCKET_ACCOUNT_MAX_MESSAGE_SIZE), + ); + assert!(transport.max_message_size.unwrap_or_default() > MAX_BASE64_ACCOUNT_DATA_LEN,); + } + + #[test] + fn websocket_logs_transport_config_caps_frames() { + let transport = websocket_transport_config(WEBSOCKET_LOGS_MAX_MESSAGE_SIZE); + + assert_eq!( + transport.max_message_size, + Some(WEBSOCKET_LOGS_MAX_MESSAGE_SIZE), + ); + assert_eq!( + transport.max_frame_size, + Some(WEBSOCKET_LOGS_MAX_MESSAGE_SIZE) + ); + } + #[test] fn websocket_http_endpoint_derives_from_websocket_scheme() { let config = WebsocketTransactionConfig::new("wss://example.invalid/?api-key=1"); @@ -2965,6 +3543,52 @@ mod tests { assert_eq!(event.matched_filter, Some(pubkey.into())); } + #[test] + fn websocket_account_notification_rejects_oversized_payload() { + let pubkey = Pubkey::new_unique(); + let owner = Pubkey::new_unique(); + let payload = json!({ + "jsonrpc":"2.0", + "method":"accountNotification", + "params":{ + "result":{ + "context":{"slot":77}, + "value":{ + "lamports":42, + "data":[STANDARD.encode(vec![7_u8; MAX_ACCOUNT_DATA_LEN + 1]), "base64"], + "owner":owner.to_string(), + "executable":false, + "rentEpoch":9 + } + } + } + }) + .to_string() + .into_bytes(); + let mut payload = payload; + let mut json_buffers = SimdJsonBuffers::default(); + let mut scratch = Vec::new(); + let mut watermarks = ProviderCommitmentWatermarks::default(); + let config = WebsocketTransactionConfig::new("wss://example.invalid") + .with_stream(WebsocketPrimaryStream::Account(pubkey)); + + let error = parse_account_notification( + &mut payload, + &mut json_buffers, + &mut scratch, + &config, + &mut watermarks, + ) + .expect_err("oversized account payload should be rejected"); + + assert!( + error + .to_string() + .contains("websocket account payload exceeds max data size"), + "unexpected error: {error}" + ); + } + #[test] fn websocket_program_notification_decodes_account_update_event() { let program_id = Pubkey::new_unique(); @@ -3040,6 +3664,76 @@ mod tests { } } + #[tokio::test] + async fn websocket_spawn_times_out_stalled_handshake() { + let listener = TcpListener::bind("127.0.0.1:0").await.expect("listener"); + let addr = listener.local_addr().expect("local addr"); + + let server = tokio::spawn(async move { + let accepted = listener.accept().await; + assert!(accepted.is_ok()); + let (_stream, _) = accepted.unwrap_or_else(|error| panic!("{error}")); + tokio::time::sleep(Duration::from_millis(100)).await; + }); + + let (tx, _rx) = create_provider_stream_queue(1); + let config = WebsocketTransactionConfig::new(format!("ws://{addr}")) + .with_stall_timeout(Duration::from_millis(25)); + + let error = spawn_websocket_source(&config, tx) + .await + .expect_err("stalled websocket handshake should time out"); + assert!( + error.to_string().contains("timed out"), + "unexpected error: {error}" + ); + + server.abort(); + drop(server.await); + } + + #[tokio::test] + async fn websocket_spawn_times_out_stalled_subscription_ack() { + let listener = TcpListener::bind("127.0.0.1:0").await.expect("listener"); + let addr = listener.local_addr().expect("local addr"); + + let server = tokio::spawn(async move { + let accepted = listener.accept().await; + assert!(accepted.is_ok()); + let (stream, _) = accepted.unwrap_or_else(|error| panic!("{error}")); + let _ws = accept_async(stream).await.expect("websocket handshake"); + tokio::time::sleep(Duration::from_millis(100)).await; + }); + + let (tx, _rx) = create_provider_stream_queue(1); + let config = WebsocketTransactionConfig::new(format!("ws://{addr}")) + .with_stall_timeout(Duration::from_millis(25)); + + let error = spawn_websocket_source(&config, tx) + .await + .expect_err("stalled subscription ack should time out"); + assert!( + error.to_string().contains("timed out"), + "unexpected error: {error}" + ); + + server.abort(); + drop(server.await); + } + + #[test] + fn websocket_stall_timeout_never_zero() { + assert_eq!( + websocket_stall_timeout(Some(Duration::ZERO)), + Some(Duration::from_millis(1)) + ); + assert_eq!( + websocket_stall_timeout(Some(Duration::from_millis(25))), + Some(Duration::from_millis(25)) + ); + assert_eq!(websocket_stall_timeout(None), None); + } + #[tokio::test] async fn websocket_logs_source_delivers_log_update() { let listener = TcpListener::bind("127.0.0.1:0").await.expect("listener"); @@ -3112,6 +3806,83 @@ mod tests { server.await.expect("server task"); } + #[tokio::test] + async fn websocket_logs_source_replies_to_ping_before_subscription_ack() { + let listener = TcpListener::bind("127.0.0.1:0").await.expect("listener"); + let addr = listener.local_addr().expect("local addr"); + let signature = Signature::from([7_u8; 64]); + let payload = json!({ + "jsonrpc":"2.0", + "method":"logsNotification", + "params":{ + "result":{ + "context":{"slot":91}, + "value":{ + "signature":signature.to_string(), + "err":null, + "logs":["Program log: ping-before-ack"] + } + } + } + }) + .to_string(); + + let server = tokio::spawn(async move { + let (stream, _) = listener.accept().await.expect("accept"); + let mut ws = accept_async(stream).await.expect("websocket handshake"); + let subscribe = ws + .next() + .await + .expect("subscribe frame") + .expect("subscribe message"); + match subscribe { + WsMessage::Text(text) => assert!(text.contains("logsSubscribe")), + other => panic!("expected subscribe text frame, got {other:?}"), + } + ws.send(WsMessage::Ping(Vec::from(&b"probe"[..]).into())) + .await + .expect("ping"); + let pong = ws.next().await.expect("pong frame").expect("pong message"); + assert!(matches!(pong, WsMessage::Pong(payload) if payload.as_ref() == b"probe")); + ws.send(WsMessage::Text( + String::from(r#"{"jsonrpc":"2.0","id":1,"result":42}"#).into(), + )) + .await + .expect("ack"); + ws.send(WsMessage::Text(payload.into())) + .await + .expect("notification"); + ws.close(None).await.expect("close"); + }); + + let (tx, mut rx) = create_provider_stream_queue(8); + let config = WebsocketLogsConfig::new(format!("ws://{addr}")) + .with_ping_interval(Duration::from_millis(250)) + .with_reconnect_delay(Duration::from_millis(10)) + .with_max_reconnect_attempts(1); + let handle = spawn_websocket_logs_source(&config, tx) + .await + .expect("spawn websocket logs source"); + + let event = loop { + let update = timeout(Duration::from_secs(2), rx.recv()) + .await + .expect("provider update timeout") + .expect("provider update"); + match update { + ProviderStreamUpdate::TransactionLog(event) => break event, + ProviderStreamUpdate::Health(_) => continue, + other => panic!("expected log update, got {other:?}"), + } + }; + assert_eq!(event.slot, 91); + assert_eq!(event.signature, signature.into()); + + handle.abort(); + handle.await.ok(); + server.await.expect("server task"); + } + #[tokio::test] async fn websocket_transaction_source_emits_initial_health_registration() { let listener = TcpListener::bind("127.0.0.1:0").await.expect("listener"); @@ -3473,6 +4244,71 @@ mod tests { server.await.expect("server task"); } + #[tokio::test] + async fn websocket_source_replies_to_ping_before_subscription_ack() { + let listener = TcpListener::bind("127.0.0.1:0").await.expect("listener"); + let addr = listener.local_addr().expect("local addr"); + let payload = sample_notification_payload(); + + let server = tokio::spawn(async move { + let (stream, _) = listener.accept().await.expect("accept"); + let mut ws = accept_async(stream).await.expect("websocket handshake"); + let subscribe = ws + .next() + .await + .expect("subscribe frame") + .expect("subscribe message"); + match subscribe { + WsMessage::Text(text) => assert!(text.contains("transactionSubscribe")), + other => panic!("expected subscribe text frame, got {other:?}"), + } + ws.send(WsMessage::Ping(Vec::from(&b"probe"[..]).into())) + .await + .expect("ping"); + let pong = ws.next().await.expect("pong frame").expect("pong message"); + assert!(matches!(pong, WsMessage::Pong(payload) if payload.as_ref() == b"probe")); + ws.send(WsMessage::Text( + String::from(r#"{"jsonrpc":"2.0","id":1,"result":42}"#).into(), + )) + .await + .expect("ack"); + ws.send(WsMessage::Text( + String::from_utf8(payload) + .expect("notification utf8") + .into(), + )) + .await + .expect("notification"); + ws.close(None).await.expect("close"); + }); + + let (tx, mut rx) = create_provider_stream_queue(8); + let config = WebsocketTransactionConfig::new(format!("ws://{addr}")) + .with_max_reconnect_attempts(1) + .with_reconnect_delay(Duration::from_millis(10)); + let handle = spawn_websocket_source(&config, tx) + .await + .expect("spawn websocket source"); + + let event = loop { + let update = timeout(Duration::from_secs(2), rx.recv()) + .await + .expect("provider update timeout") + .expect("provider update"); + match update { + ProviderStreamUpdate::SerializedTransaction(event) => break event, + ProviderStreamUpdate::Health(_) => continue, + other => panic!("expected transaction update, got {other:?}"), + } + }; + assert_eq!(event.slot, 55); + assert!(event.signature.is_some()); + + handle.abort(); + handle.await.ok(); + server.await.expect("server task"); + } + #[tokio::test] async fn websocket_source_reconnects_and_delivers_after_disconnect() { let listener = TcpListener::bind("127.0.0.1:0").await.expect("listener"); @@ -3539,10 +4375,7 @@ mod tests { match update { ProviderStreamUpdate::Health(event) => { saw_health = true; - assert_eq!( - event.source.kind, - crate::provider_stream::ProviderSourceId::WebsocketTransaction - ); + assert_eq!(event.source.kind, ProviderSourceId::WebsocketTransaction); continue; } ProviderStreamUpdate::SerializedTransaction(event) => break event, @@ -3571,6 +4404,78 @@ mod tests { server.await.expect("server task"); } + #[tokio::test] + async fn websocket_source_reports_clean_close_as_stream_end() { + let listener = TcpListener::bind("127.0.0.1:0").await.expect("listener"); + let addr = listener.local_addr().expect("local addr"); + + let server = tokio::spawn(async move { + let (stream, _) = listener.accept().await.expect("accept"); + let mut ws = accept_async(stream).await.expect("websocket handshake"); + let subscribe = ws + .next() + .await + .expect("subscribe frame") + .expect("subscribe message"); + match subscribe { + WsMessage::Text(text) => { + assert!(text.contains("transactionSubscribe")); + } + other @ WsMessage::Binary(_) + | other @ WsMessage::Ping(_) + | other @ WsMessage::Pong(_) + | other @ WsMessage::Close(_) + | other @ WsMessage::Frame(_) => { + panic!("expected subscribe text frame, got {other:?}"); + } + } + ws.send(WsMessage::Text( + String::from(r#"{"jsonrpc":"2.0","id":1,"result":42}"#).into(), + )) + .await + .expect("ack"); + ws.close(None).await.expect("close"); + }); + + let (tx, mut rx) = create_provider_stream_queue(8); + let config = WebsocketTransactionConfig::new(format!("ws://{addr}")) + .with_max_reconnect_attempts(1) + .with_reconnect_delay(Duration::from_millis(10)); + let handle = spawn_websocket_source(&config, tx) + .await + .expect("spawn websocket source"); + + let health = loop { + let update = timeout(Duration::from_secs(2), rx.recv()) + .await + .expect("provider update timeout") + .expect("provider update"); + match update { + ProviderStreamUpdate::Health(event) + if event.status == ProviderSourceHealthStatus::Reconnecting + && event.reason + == ProviderSourceHealthReason::UpstreamStreamClosedUnexpectedly => + { + break event; + } + _ => continue, + } + }; + assert_eq!( + health.reason, + ProviderSourceHealthReason::UpstreamStreamClosedUnexpectedly + ); + assert!( + health.message.contains("stream ended unexpectedly"), + "unexpected health message: {}", + health.message + ); + + handle.abort(); + handle.await.ok(); + server.await.expect("server task"); + } + #[tokio::test] async fn websocket_fan_in_delivers_updates_from_multiple_sources() { let tx_listener = TcpListener::bind("127.0.0.1:0").await.expect("tx listener"); @@ -3798,14 +4703,8 @@ mod tests { } assert_eq!(sources.len(), 2); - assert_eq!( - sources[0].kind, - crate::provider_stream::ProviderSourceId::WebsocketTransaction - ); - assert_eq!( - sources[1].kind, - crate::provider_stream::ProviderSourceId::WebsocketTransaction - ); + assert_eq!(sources[0].kind, ProviderSourceId::WebsocketTransaction); + assert_eq!(sources[1].kind, ProviderSourceId::WebsocketTransaction); assert_ne!(sources[0], sources[1]); handle_a.abort(); @@ -3987,7 +4886,7 @@ mod tests { ) .expect("baseline parse") .expect("baseline event"); - std::hint::black_box(event); + black_box(event); } let baseline_elapsed = baseline_started.elapsed(); @@ -4007,15 +4906,73 @@ mod tests { ) .expect("optimized parse") .expect("optimized event"); - std::hint::black_box(event); + black_box(event); + } + let optimized_elapsed = optimized_started.elapsed(); + let baseline_avg_ns = avg_ns_per_iteration(baseline_elapsed, iterations); + let optimized_avg_ns = avg_ns_per_iteration(optimized_elapsed, iterations); + + eprintln!( + "websocket_transaction_parse_profile_fixture iterations={} baseline_us={} optimized_us={} baseline_avg_ns_per_iteration={} optimized_avg_ns_per_iteration={} baseline_avg_us_per_iteration={:.3} optimized_avg_us_per_iteration={:.3}", + iterations, + baseline_elapsed.as_micros(), + optimized_elapsed.as_micros(), + baseline_avg_ns, + optimized_avg_ns, + baseline_avg_ns as f64 / 1_000.0, + optimized_avg_ns as f64 / 1_000.0, + ); + } + + #[test] + #[ignore = "profiling fixture for websocket provider source attachment path"] + fn websocket_provider_source_attachment_profile_fixture() { + let iterations = profile_iterations(500_000); + let payload = sample_notification_payload(); + let mut frame_bytes = Vec::new(); + let mut json_buffers = SimdJsonBuffers::default(); + let mut tx_bytes = Vec::new(); + let mut watermarks = ProviderCommitmentWatermarks::default(); + let update = ProviderStreamUpdate::SerializedTransaction( + parse_transaction_notification( + frame_bytes_mut(&mut frame_bytes, &payload), + &mut json_buffers, + &mut tx_bytes, + WebsocketTransactionCommitment::Confirmed, + &mut watermarks, + ) + .expect("parse update") + .expect("transaction update"), + ); + let source = ProviderSourceIdentity::new( + ProviderSourceId::WebsocketTransaction, + "websocket-source-a", + ); + let source_ref = Arc::new(source.clone()); + + let baseline_started = Instant::now(); + for _ in 0..iterations { + black_box(update.clone().with_provider_source(source.clone())); + } + let baseline_elapsed = baseline_started.elapsed(); + + let optimized_started = Instant::now(); + for _ in 0..iterations { + black_box(update.clone().with_provider_source_ref(&source_ref)); } let optimized_elapsed = optimized_started.elapsed(); + let baseline_avg_ns = avg_ns_per_iteration(baseline_elapsed, iterations); + let optimized_avg_ns = avg_ns_per_iteration(optimized_elapsed, iterations); eprintln!( - "websocket_transaction_parse_profile_fixture iterations={} baseline_us={} optimized_us={}", + "websocket_provider_source_attachment_profile_fixture iterations={} baseline_us={} optimized_us={} baseline_avg_ns_per_iteration={} optimized_avg_ns_per_iteration={} baseline_avg_us_per_iteration={:.3} optimized_avg_us_per_iteration={:.3}", iterations, baseline_elapsed.as_micros(), optimized_elapsed.as_micros(), + baseline_avg_ns, + optimized_avg_ns, + baseline_avg_ns as f64 / 1_000.0, + optimized_avg_ns as f64 / 1_000.0, ); } @@ -4033,7 +4990,7 @@ mod tests { ) .expect("baseline parse") .expect("baseline event"); - std::hint::black_box(event); + black_box(event); } } @@ -4058,7 +5015,7 @@ mod tests { ) .expect("optimized parse") .expect("optimized event"); - std::hint::black_box(event); + black_box(event); } } } diff --git a/crates/sof-observer/src/provider_stream/yellowstone.rs b/crates/sof-observer/src/provider_stream/yellowstone.rs index b7a857f7..de2cd3ec 100644 --- a/crates/sof-observer/src/provider_stream/yellowstone.rs +++ b/crates/sof-observer/src/provider_stream/yellowstone.rs @@ -2,14 +2,22 @@ //! Yellowstone gRPC adapters for SOF processed provider-stream ingress. +#[cfg(test)] +use std::hint::black_box; use std::{ collections::HashMap, + fmt, + pin::Pin, str::FromStr, sync::Arc, time::{Duration, Instant}, }; use futures_util::{SinkExt, StreamExt}; +use sof_support::bytes::{pubkey_bytes_from_slice, signature_bytes_from_slice}; +use sof_support::collections_support::prune_recent_slots; +use sof_support::time_support::nonzero_duration_or; +use sof_types::SignatureBytes; use solana_hash::Hash; use solana_message::{ Message, MessageHeader, VersionedMessage, @@ -18,6 +26,7 @@ use solana_message::{ }; use solana_pubkey::Pubkey; use solana_signature::Signature; +use solana_system_interface::MAX_PERMITTED_DATA_LENGTH; use solana_transaction::versioned::VersionedTransaction; use thiserror::Error; use tokio::sync::mpsc; @@ -34,7 +43,7 @@ use crate::{ event::{ForkSlotStatus, TxCommitmentStatus, TxKind}, framework::{ AccountUpdateEvent, BlockMetaEvent, SlotStatusEvent, TransactionEvent, - TransactionStatusEvent, pubkey_bytes, signature_bytes_opt, + TransactionStatusEvent, }, provider_stream::{ ProviderCommitmentWatermarks, ProviderReplayMode, ProviderSourceArbitrationMode, @@ -43,11 +52,18 @@ use crate::{ ProviderSourceReadiness, ProviderSourceReservation, ProviderSourceRole, ProviderSourceTaskGuard, ProviderStreamFanIn, ProviderStreamMode, ProviderStreamSender, ProviderStreamUpdate, classify_provider_transaction_kind, - emit_provider_source_removed_with_reservation, + emit_provider_source_removed_with_reservation, keepalive_interval, }, }; const INTERNAL_SLOT_FILTER: &str = "__sof_internal_slots"; +const MAX_ACCOUNT_DATA_LEN: usize = MAX_PERMITTED_DATA_LENGTH as usize; +const SLOT_STATUS_RETAINED_LAG: u64 = 4_096; +const SLOT_STATUS_PRUNE_THRESHOLD: usize = SLOT_STATUS_RETAINED_LAG as usize * 2; +const DEFAULT_MAX_DECODING_MESSAGE_SIZE: usize = 64 * 1024 * 1024; +const MIN_PROVIDER_CONNECT_TIMEOUT: Duration = Duration::from_millis(1); +const MIN_PROVIDER_STALL_TIMEOUT: Duration = Duration::from_millis(1); +const MIN_RECONNECT_DELAY: Duration = Duration::from_millis(1); /// Yellowstone subscription commitment used for provider-stream transaction updates. #[derive(Clone, Copy, Debug, Default, Eq, PartialEq)] @@ -84,7 +100,7 @@ impl YellowstoneGrpcCommitment { pub struct YellowstoneGrpcConfig { endpoint: String, x_token: Option, - source_instance: Option>, + source_instance: Option>, readiness: ProviderSourceReadiness, source_role: ProviderSourceRole, source_priority: u16, @@ -133,8 +149,8 @@ pub enum YellowstoneGrpcConfigOption { RequireTransactionSignature, } -impl std::fmt::Display for YellowstoneGrpcConfigOption { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl fmt::Display for YellowstoneGrpcConfigOption { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::VoteFilter => f.write_str("vote filter"), Self::FailedFilter => f.write_str("failed filter"), @@ -198,7 +214,7 @@ impl YellowstoneGrpcConfig { accounts: Vec::new(), owners: Vec::new(), require_txn_signature: false, - max_decoding_message_size: 64 * 1024 * 1024, + max_decoding_message_size: DEFAULT_MAX_DECODING_MESSAGE_SIZE, connect_timeout: Some(Duration::from_secs(10)), stall_timeout: Some(Duration::from_secs(30)), ping_interval: Some(Duration::from_secs(30)), @@ -216,7 +232,7 @@ impl YellowstoneGrpcConfig { /// Sets one stable source instance label for observability and redundancy intent. #[must_use] - pub fn with_source_instance(mut self, instance: impl Into>) -> Self { + pub fn with_source_instance(mut self, instance: impl Into>) -> Self { self.source_instance = Some(instance.into()); self } @@ -391,6 +407,10 @@ impl YellowstoneGrpcConfig { self } + const fn reconnect_delay_effective(&self) -> Duration { + effective_reconnect_delay(self.reconnect_delay) + } + fn validate(&self) -> Result<(), YellowstoneGrpcConfigError> { match self.stream { YellowstoneGrpcStream::Transaction | YellowstoneGrpcStream::TransactionStatus => { @@ -646,7 +666,7 @@ pub enum YellowstoneGrpcStream { pub struct YellowstoneGrpcSlotsConfig { endpoint: String, x_token: Option, - source_instance: Option>, + source_instance: Option>, readiness: ProviderSourceReadiness, source_role: ProviderSourceRole, source_priority: u16, @@ -690,7 +710,7 @@ impl YellowstoneGrpcSlotsConfig { /// Sets one stable source instance label for observability and redundancy intent. #[must_use] - pub fn with_source_instance(mut self, instance: impl Into>) -> Self { + pub fn with_source_instance(mut self, instance: impl Into>) -> Self { self.source_instance = Some(instance.into()); self } @@ -802,6 +822,10 @@ impl YellowstoneGrpcSlotsConfig { self } + const fn reconnect_delay_effective(&self) -> Duration { + effective_reconnect_delay(self.reconnect_delay) + } + /// Sets the maximum reconnect attempts. `None` keeps retrying forever. #[must_use] pub const fn with_max_reconnect_attempts(mut self, attempts: u32) -> Self { @@ -863,6 +887,14 @@ impl YellowstoneGrpcSlotsConfig { } } +const fn effective_reconnect_delay(delay: Duration) -> Duration { + if delay.is_zero() { + MIN_RECONNECT_DELAY + } else { + delay + } +} + /// Yellowstone transaction-stream error surface. #[derive(Debug, Error)] pub enum YellowstoneGrpcError { @@ -929,8 +961,8 @@ pub enum YellowstoneGrpcStreamKind { Slots, } -impl std::fmt::Display for YellowstoneGrpcStreamKind { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl fmt::Display for YellowstoneGrpcStreamKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::Transaction => f.write_str("transaction"), Self::TransactionStatus => f.write_str("transaction-status"), @@ -941,10 +973,10 @@ impl std::fmt::Display for YellowstoneGrpcStreamKind { } } -type YellowstoneSubscribeSink = std::pin::Pin< +type YellowstoneSubscribeSink = Pin< Box + Send>, >; -type YellowstoneUpdateStream = std::pin::Pin< +type YellowstoneUpdateStream = Pin< Box< dyn futures_util::Stream< Item = Result, @@ -1162,7 +1194,7 @@ async fn spawn_yellowstone_grpc_source_inner( YellowstoneGrpcProtocolError::ReconnectBudgetExhausted { attempts }.into(), ); } - tokio::time::sleep(config.reconnect_delay).await; + tokio::time::sleep(config.reconnect_delay_effective()).await; } })) } @@ -1223,7 +1255,7 @@ async fn spawn_yellowstone_grpc_slot_source_inner( let mut attempts = 0_u32; let mut tracked_slot = 0_u64; let mut watermarks = ProviderCommitmentWatermarks::default(); - let mut slot_states = HashMap::new(); + let mut slot_states = HashMap::with_capacity(SLOT_STATUS_PRUNE_THRESHOLD); let mut first_session = Some(first_session); loop { let mut session_established = false; @@ -1325,7 +1357,7 @@ async fn spawn_yellowstone_grpc_slot_source_inner( YellowstoneGrpcProtocolError::ReconnectBudgetExhausted { attempts }.into(), ); } - tokio::time::sleep(config.reconnect_delay).await; + tokio::time::sleep(config.reconnect_delay_effective()).await; } })) } @@ -1339,6 +1371,7 @@ async fn run_yellowstone_primary_connection( mut stream: YellowstoneUpdateStream, ) -> Result<(), YellowstoneGrpcError> { let commitment = config.commitment.as_tx_commitment(); + let provider_source = Arc::new(source.clone()); *state.session_established = false; *state.session_established = true; send_primary_provider_health( @@ -1350,7 +1383,7 @@ async fn run_yellowstone_primary_connection( PROVIDER_SUBSCRIPTION_ACKNOWLEDGED.to_owned(), ) .await?; - let mut ping = config.ping_interval.map(tokio::time::interval); + let mut ping = config.ping_interval.map(keepalive_interval); let mut last_progress = Instant::now(); loop { tokio::select! { @@ -1370,7 +1403,10 @@ async fn run_yellowstone_primary_connection( .map_err(GeyserGrpcClientError::SubscribeSendError)?; } () = async { - if let Some(timeout) = config.stall_timeout { + if let Some(timeout) = config + .stall_timeout + .map(|timeout| nonzero_duration_or(timeout, MIN_PROVIDER_STALL_TIMEOUT)) + { let deadline = last_progress.checked_add(timeout).unwrap_or(last_progress); tokio::time::sleep_until(deadline.into()).await; } else { @@ -1405,7 +1441,7 @@ async fn run_yellowstone_primary_connection( sender .send( ProviderStreamUpdate::Transaction(event) - .with_provider_source(source.clone()), + .with_provider_source_ref(&provider_source), ) .await .map_err(|_error| YellowstoneGrpcError::QueueClosed)?; @@ -1425,7 +1461,7 @@ async fn run_yellowstone_primary_connection( sender .send( ProviderStreamUpdate::TransactionStatus(event) - .with_provider_source(source.clone()), + .with_provider_source_ref(&provider_source), ) .await .map_err(|_error| YellowstoneGrpcError::QueueClosed)?; @@ -1447,7 +1483,7 @@ async fn run_yellowstone_primary_connection( sender .send( ProviderStreamUpdate::AccountUpdate(event) - .with_provider_source(source.clone()), + .with_provider_source_ref(&provider_source), ) .await .map_err(|_error| YellowstoneGrpcError::QueueClosed)?; @@ -1469,7 +1505,7 @@ async fn run_yellowstone_primary_connection( sender .send( ProviderStreamUpdate::BlockMeta(event) - .with_provider_source(source.clone()), + .with_provider_source_ref(&provider_source), ) .await .map_err(|_error| YellowstoneGrpcError::QueueClosed)?; @@ -1511,6 +1547,7 @@ async fn run_yellowstone_slot_connection( mut subscribe_tx: YellowstoneSubscribeSink, mut stream: YellowstoneUpdateStream, ) -> Result<(), YellowstoneGrpcError> { + let provider_source = Arc::new(source.clone()); *state.session_established = false; *state.session_established = true; send_provider_slot_health( @@ -1522,7 +1559,7 @@ async fn run_yellowstone_slot_connection( PROVIDER_SUBSCRIPTION_ACKNOWLEDGED.to_owned(), ) .await?; - let mut ping = config.ping_interval.map(tokio::time::interval); + let mut ping = config.ping_interval.map(keepalive_interval); let mut last_progress = Instant::now(); loop { tokio::select! { @@ -1542,7 +1579,10 @@ async fn run_yellowstone_slot_connection( .map_err(GeyserGrpcClientError::SubscribeSendError)?; } () = async { - if let Some(timeout) = config.stall_timeout { + if let Some(timeout) = config + .stall_timeout + .map(|timeout| nonzero_duration_or(timeout, MIN_PROVIDER_STALL_TIMEOUT)) + { let deadline = last_progress.checked_add(timeout).unwrap_or(last_progress); tokio::time::sleep_until(deadline.into()).await; } else { @@ -1570,7 +1610,7 @@ async fn run_yellowstone_slot_connection( sender .send( ProviderStreamUpdate::SlotStatus(event) - .with_provider_source(source.clone()), + .with_provider_source_ref(&provider_source), ) .await .map_err(|_error| YellowstoneGrpcError::QueueClosed)?; @@ -1597,33 +1637,46 @@ async fn establish_yellowstone_session( config: &YellowstoneGrpcConfig, tracked_slot: u64, ) -> Result<(YellowstoneSubscribeSink, YellowstoneUpdateStream), YellowstoneGrpcError> { - let mut builder = GeyserGrpcClient::build_from_shared(config.endpoint.clone())? - .x_token(config.x_token.clone())? - .max_decoding_message_size(config.max_decoding_message_size); - if let Some(timeout) = config.connect_timeout { - builder = builder.connect_timeout(timeout); - } - let mut client = builder.connect().await?; - let (subscribe_tx, stream) = client - .subscribe_with_request(Some(config.subscribe_request_with_state(tracked_slot))) - .await?; - Ok((Box::pin(subscribe_tx), Box::pin(stream))) + establish_yellowstone_subscribe_session( + config.endpoint.clone(), + config.x_token.clone(), + config.max_decoding_message_size, + config.connect_timeout, + config.subscribe_request_with_state(tracked_slot), + ) + .await } async fn establish_yellowstone_slot_session( config: &YellowstoneGrpcSlotsConfig, tracked_slot: u64, ) -> Result<(YellowstoneSubscribeSink, YellowstoneUpdateStream), YellowstoneGrpcError> { - let mut builder = GeyserGrpcClient::build_from_shared(config.endpoint.clone())? - .x_token(config.x_token.clone())? - .max_decoding_message_size(64 * 1024 * 1024); - if let Some(timeout) = config.connect_timeout { - builder = builder.connect_timeout(timeout); + establish_yellowstone_subscribe_session( + config.endpoint.clone(), + config.x_token.clone(), + DEFAULT_MAX_DECODING_MESSAGE_SIZE, + config.connect_timeout, + config.subscribe_request_with_state(tracked_slot), + ) + .await +} + +async fn establish_yellowstone_subscribe_session( + endpoint: String, + x_token: Option, + max_decoding_message_size: usize, + connect_timeout: Option, + request: SubscribeRequest, +) -> Result<(YellowstoneSubscribeSink, YellowstoneUpdateStream), YellowstoneGrpcError> { + let mut builder = GeyserGrpcClient::build_from_shared(endpoint)? + .x_token(x_token)? + .max_decoding_message_size(max_decoding_message_size); + if let Some(timeout) = connect_timeout { + builder = + builder.connect_timeout(nonzero_duration_or(timeout, MIN_PROVIDER_CONNECT_TIMEOUT)); } let mut client = builder.connect().await?; - let (subscribe_tx, stream) = client - .subscribe_with_request(Some(config.subscribe_request_with_state(tracked_slot))) - .await?; + let (subscribe_tx, stream) = client.subscribe_with_request(Some(request)).await?; Ok((Box::pin(subscribe_tx), Box::pin(stream))) } @@ -1754,33 +1807,30 @@ fn transaction_event_from_update( let transaction = transaction.ok_or(YellowstoneGrpcError::Convert("missing transaction payload"))?; let is_vote = transaction.is_vote; - let signature = if is_vote { - Signature::try_from(transaction.signature.as_slice()) - .map(Some) - .map_err(|_error| YellowstoneGrpcError::Convert("invalid signature"))? - } else { - None - }; + let signature = signature_bytes_from_slice(transaction.signature.as_slice(), || { + YellowstoneGrpcError::Convert("invalid signature") + })?; let tx = convert_transaction( transaction .transaction .ok_or(YellowstoneGrpcError::Convert( "missing versioned transaction", ))?, + Some(signature), )?; Ok(TransactionEvent { slot, commitment_status, confirmed_slot: watermarks.confirmed_slot, finalized_slot: watermarks.finalized_slot, - signature: signature_bytes_opt(signature.or_else(|| tx.signatures.first().copied())), + signature: Some(signature), provider_source: None, kind: if is_vote { TxKind::VoteOnly } else { classify_provider_transaction_kind(&tx) }, - tx: std::sync::Arc::new(tx), + tx: Arc::new(tx), }) } @@ -1789,14 +1839,15 @@ fn transaction_status_event_from_update( watermarks: ProviderCommitmentWatermarks, update: yellowstone_grpc_proto::prelude::SubscribeUpdateTransactionStatus, ) -> Result { - let signature = Signature::try_from(update.signature.as_slice()) - .map_err(|_error| YellowstoneGrpcError::Convert("invalid transaction-status signature"))?; + let signature = signature_bytes_from_slice(update.signature.as_slice(), || { + YellowstoneGrpcError::Convert("invalid transaction-status signature") + })?; Ok(TransactionStatusEvent { slot: update.slot, commitment_status, confirmed_slot: watermarks.confirmed_slot, finalized_slot: watermarks.finalized_slot, - signature: signature.into(), + signature, is_vote: update.is_vote, index: Some(update.index), err: update.err.map(|error| format!("{error:?}")), @@ -1812,30 +1863,36 @@ fn account_update_event_from_yellowstone( let account = update .account .ok_or(YellowstoneGrpcError::Convert("missing account payload"))?; - let pubkey = Pubkey::try_from(account.pubkey.as_slice()) - .map_err(|_error| YellowstoneGrpcError::Convert("invalid account pubkey"))?; - let owner = Pubkey::try_from(account.owner.as_slice()) - .map_err(|_error| YellowstoneGrpcError::Convert("invalid account owner"))?; + let pubkey = pubkey_bytes_from_slice(account.pubkey.as_slice(), || { + YellowstoneGrpcError::Convert("invalid account pubkey") + })?; + let owner = pubkey_bytes_from_slice(account.owner.as_slice(), || { + YellowstoneGrpcError::Convert("invalid account owner") + })?; let txn_signature = match account.txn_signature { - Some(signature) => Some( - Signature::try_from(signature.as_slice()) - .map_err(|_error| YellowstoneGrpcError::Convert("invalid account txn signature"))?, - ), + Some(signature) => Some(signature_bytes_from_slice(signature.as_slice(), || { + YellowstoneGrpcError::Convert("invalid account txn signature") + })?), None => None, }; + if account.data.len() > MAX_ACCOUNT_DATA_LEN { + return Err(YellowstoneGrpcError::Convert( + "account data exceeds max permitted size", + )); + } Ok(AccountUpdateEvent { slot: update.slot, commitment_status, confirmed_slot: watermarks.confirmed_slot, finalized_slot: watermarks.finalized_slot, - pubkey: pubkey_bytes(pubkey), - owner: pubkey_bytes(owner), + pubkey, + owner, lamports: account.lamports, executable: account.executable, rent_epoch: account.rent_epoch, data: account.data.into(), write_version: Some(account.write_version), - txn_signature: signature_bytes_opt(txn_signature), + txn_signature, is_startup: update.is_startup, matched_filter: None, provider_source: None, @@ -1902,6 +1959,12 @@ fn slot_status_event_from_update( | SlotStatus::SlotCreatedBank => ForkSlotStatus::Processed, }; let previous_status = slot_states.insert(slot, mapped); + prune_recent_slots( + slot_states, + slot, + SLOT_STATUS_RETAINED_LAG, + SLOT_STATUS_PRUNE_THRESHOLD, + ); if previous_status == Some(mapped) { return None; } @@ -1966,15 +2029,29 @@ impl ProviderStreamFanIn { } } -#[inline] +#[inline(always)] fn convert_transaction( tx: yellowstone_grpc_proto::prelude::Transaction, + first_signature: Option, ) -> Result { let mut signatures = Vec::with_capacity(tx.signatures.len()); - for signature in tx.signatures { - signatures.push(Signature::try_from(signature.as_slice()).map_err(|_error| { - YellowstoneGrpcError::Convert("failed to parse transaction signature") - })?); + let mut tx_signatures = tx.signatures.into_iter(); + if let Some(signature) = tx_signatures.next() { + signatures.push(match first_signature { + Some(first_signature) => first_signature.into(), + None => signature_bytes_from_slice(signature.as_slice(), || { + YellowstoneGrpcError::Convert("failed to parse transaction signature") + })? + .into(), + }); + } + for signature in tx_signatures { + signatures.push( + signature_bytes_from_slice(signature.as_slice(), || { + YellowstoneGrpcError::Convert("failed to parse transaction signature") + })? + .into(), + ); } let message = convert_message( tx.message @@ -1986,7 +2063,7 @@ fn convert_transaction( }) } -#[inline] +#[inline(always)] fn convert_message( message: yellowstone_grpc_proto::prelude::Message, ) -> Result { @@ -2007,8 +2084,10 @@ fn convert_message( let mut account_keys = Vec::with_capacity(message.account_keys.len()); for key in message.account_keys { account_keys.push( - Pubkey::try_from(key.as_slice()) - .map_err(|_error| YellowstoneGrpcError::Convert("invalid account key"))?, + pubkey_bytes_from_slice(key.as_slice(), || { + YellowstoneGrpcError::Convert("invalid account key") + })? + .into(), ); } let recent_blockhash = <[u8; 32]>::try_from(message.recent_blockhash.as_slice()) @@ -2028,9 +2107,10 @@ fn convert_message( let mut address_table_lookups = Vec::with_capacity(message.address_table_lookups.len()); for lookup in message.address_table_lookups { address_table_lookups.push(MessageAddressTableLookup { - account_key: Pubkey::try_from(lookup.account_key.as_slice()).map_err(|_error| { + account_key: pubkey_bytes_from_slice(lookup.account_key.as_slice(), || { YellowstoneGrpcError::Convert("invalid address table account key") - })?, + })? + .into(), writable_indexes: lookup.writable_indexes, readonly_indexes: lookup.readonly_indexes, }); @@ -2060,8 +2140,9 @@ fn convert_message( )] mod tests { use super::*; - use crate::event::TxKind; - use crate::provider_stream::create_provider_stream_queue; + use crate::{ + event::TxKind, framework::signature_bytes, provider_stream::create_provider_stream_queue, + }; use futures_channel::mpsc as futures_mpsc; use futures_util::stream::{self, Stream}; use solana_instruction::Instruction; @@ -2071,6 +2152,8 @@ mod tests { use solana_sdk_ids::{compute_budget, vote}; use solana_signer::Signer; use std::{pin::Pin, time::Instant}; + + use sof_support::bench::{avg_ns_per_iteration, profile_iterations}; use tokio::sync::oneshot; use tokio::time::{Duration, timeout}; use yellowstone_grpc_proto::geyser::geyser_server::{Geyser, GeyserServer}; @@ -2087,12 +2170,60 @@ mod tests { }; use yellowstone_grpc_proto::tonic::{self, Request, Response, Status, transport::Server}; - fn profile_iterations(default: usize) -> usize { - std::env::var("SOF_PROFILE_ITERATIONS") - .ok() - .and_then(|value| value.parse::().ok()) - .filter(|value| *value > 0) - .unwrap_or(default) + #[test] + fn yellowstone_account_update_rejects_oversized_data() { + let pubkey = Pubkey::new_unique(); + let owner = Pubkey::new_unique(); + let update = sample_account_update(92, pubkey, owner); + let account_update = match update.update_oneof { + Some(UpdateOneof::Account(account_update)) => account_update, + other => panic!("expected account update, got {other:?}"), + }; + + let mut oversized = account_update; + oversized.account.as_mut().expect("account payload").data = + vec![7_u8; MAX_ACCOUNT_DATA_LEN + 1]; + + let error = account_update_event_from_yellowstone( + oversized, + TxCommitmentStatus::Confirmed, + ProviderCommitmentWatermarks::default(), + ) + .expect_err("oversized account payload must fail"); + + assert!(matches!( + error, + YellowstoneGrpcError::Convert("account data exceeds max permitted size") + )); + } + + #[test] + fn yellowstone_slot_state_pruning_evicts_old_slots() { + let mut slot_states = HashMap::new(); + for slot in 0..=u64::try_from(SLOT_STATUS_PRUNE_THRESHOLD).unwrap_or(u64::MAX) { + let _ = slot_states.insert(slot, ForkSlotStatus::Processed); + } + + prune_recent_slots( + &mut slot_states, + 10_000, + SLOT_STATUS_RETAINED_LAG, + SLOT_STATUS_PRUNE_THRESHOLD, + ); + + assert!( + !slot_states.contains_key(&0), + "old tracked slots should be pruned" + ); + assert!( + slot_states.contains_key(&10_000_u64.saturating_sub(SLOT_STATUS_RETAINED_LAG)), + "recent tracked slots should stay resident" + ); + assert!( + slot_states.len() + <= usize::try_from(SLOT_STATUS_RETAINED_LAG + 1).unwrap_or(usize::MAX), + "tracked slot state should stay bounded" + ); } fn sample_transaction() -> VersionedTransaction { @@ -2124,6 +2255,44 @@ mod tests { assert_eq!(filter.failed, None); } + #[test] + fn yellowstone_reconnect_delay_never_spins() { + let config = YellowstoneGrpcConfig::new("http://127.0.0.1:10000") + .with_reconnect_delay(Duration::ZERO); + assert_eq!(config.reconnect_delay_effective(), Duration::from_millis(1)); + + let slots_config = YellowstoneGrpcSlotsConfig::new("http://127.0.0.1:10000") + .with_reconnect_delay(Duration::ZERO); + assert_eq!( + slots_config.reconnect_delay_effective(), + Duration::from_millis(1) + ); + } + + #[test] + fn yellowstone_connect_timeout_never_zero() { + assert_eq!( + nonzero_duration_or(Duration::ZERO, MIN_PROVIDER_CONNECT_TIMEOUT), + Duration::from_millis(1) + ); + assert_eq!( + nonzero_duration_or(Duration::from_millis(25), MIN_PROVIDER_CONNECT_TIMEOUT), + Duration::from_millis(25) + ); + } + + #[test] + fn yellowstone_stall_timeout_never_zero() { + assert_eq!( + nonzero_duration_or(Duration::ZERO, MIN_PROVIDER_STALL_TIMEOUT), + Duration::from_millis(1) + ); + assert_eq!( + nonzero_duration_or(Duration::from_millis(25), MIN_PROVIDER_STALL_TIMEOUT), + Duration::from_millis(25) + ); + } + #[test] fn yellowstone_subscribe_request_tracks_slots_and_replay_cursor() { let request = YellowstoneGrpcConfig::new("http://127.0.0.1:10000") @@ -2827,7 +2996,7 @@ mod tests { let transaction = transaction.ok_or(YellowstoneGrpcError::Convert("missing transaction payload"))?; let signature = Signature::try_from(transaction.signature.as_slice()) - .map(crate::framework::signature_bytes) + .map(signature_bytes) .map(Some) .map_err(|_error| YellowstoneGrpcError::Convert("invalid signature"))?; let tx = { @@ -2940,7 +3109,7 @@ mod tests { finalized_slot: None, signature, kind: classify_provider_transaction_kind(&tx), - tx: std::sync::Arc::new(tx), + tx: Arc::new(tx), provider_source: None, }) } @@ -2975,7 +3144,7 @@ mod tests { #[tokio::test] async fn yellowstone_spawn_rejects_transaction_filters_for_accounts_stream() { - let (tx, _rx) = crate::provider_stream::create_provider_stream_queue(1); + let (tx, _rx) = create_provider_stream_queue(1); let config = YellowstoneGrpcConfig::new("http://127.0.0.1:1") .with_stream(YellowstoneGrpcStream::Accounts) .with_vote(true); @@ -3008,7 +3177,7 @@ mod tests { TxCommitmentStatus::Processed, ) .expect("baseline event"); - std::hint::black_box(event); + black_box(event); } let baseline_elapsed = baseline_started.elapsed(); @@ -3021,15 +3190,21 @@ mod tests { ProviderCommitmentWatermarks::default(), ) .expect("optimized event"); - std::hint::black_box(event); + black_box(event); } let optimized_elapsed = optimized_started.elapsed(); + let baseline_avg_ns = avg_ns_per_iteration(baseline_elapsed, iterations); + let optimized_avg_ns = avg_ns_per_iteration(optimized_elapsed, iterations); eprintln!( - "yellowstone_transaction_conversion_profile_fixture iterations={} baseline_us={} optimized_us={}", + "yellowstone_transaction_conversion_profile_fixture iterations={} baseline_us={} optimized_us={} baseline_avg_ns_per_iteration={} optimized_avg_ns_per_iteration={} baseline_avg_us_per_iteration={:.3} optimized_avg_us_per_iteration={:.3}", iterations, baseline_elapsed.as_micros(), optimized_elapsed.as_micros(), + baseline_avg_ns, + optimized_avg_ns, + baseline_avg_ns as f64 / 1_000.0, + optimized_avg_ns as f64 / 1_000.0, ); } @@ -3046,7 +3221,7 @@ mod tests { TxCommitmentStatus::Processed, ) .expect("baseline event"); - std::hint::black_box(event); + black_box(event); } } @@ -3064,7 +3239,7 @@ mod tests { ProviderCommitmentWatermarks::default(), ) .expect("optimized event"); - std::hint::black_box(event); + black_box(event); } } @@ -3083,7 +3258,7 @@ mod tests { TxCommitmentStatus::Processed, ) .expect("baseline event"); - std::hint::black_box(event); + black_box(event); } let baseline_elapsed = baseline_started.elapsed(); @@ -3096,15 +3271,21 @@ mod tests { ProviderCommitmentWatermarks::default(), ) .expect("optimized event"); - std::hint::black_box(event); + black_box(event); } let optimized_elapsed = optimized_started.elapsed(); + let baseline_avg_ns = avg_ns_per_iteration(baseline_elapsed, iterations); + let optimized_avg_ns = avg_ns_per_iteration(optimized_elapsed, iterations); eprintln!( - "yellowstone_vote_only_conversion_profile_fixture iterations={} baseline_us={} optimized_us={}", + "yellowstone_vote_only_conversion_profile_fixture iterations={} baseline_us={} optimized_us={} baseline_avg_ns_per_iteration={} optimized_avg_ns_per_iteration={} baseline_avg_us_per_iteration={:.3} optimized_avg_us_per_iteration={:.3}", iterations, baseline_elapsed.as_micros(), optimized_elapsed.as_micros(), + baseline_avg_ns, + optimized_avg_ns, + baseline_avg_ns as f64 / 1_000.0, + optimized_avg_ns as f64 / 1_000.0, ); } } diff --git a/crates/sof-observer/src/relay/cache.rs b/crates/sof-observer/src/relay/cache.rs index 10f88b7c..c22ae725 100644 --- a/crates/sof-observer/src/relay/cache.rs +++ b/crates/sof-observer/src/relay/cache.rs @@ -1,5 +1,5 @@ use std::{ - collections::{HashMap, VecDeque}, + collections::{HashMap, VecDeque, hash_map::Entry}, sync::{ Arc, atomic::{AtomicUsize, Ordering}, @@ -110,7 +110,9 @@ impl RecentShredRingBuffer { }, ); self.order.push_back((now, key)); - self.evict(now, &mut evicted); + if replaced.is_none() { + self.evict(now, &mut evicted); + } CacheInsertOutcome { inserted: replaced.is_none(), @@ -146,15 +148,28 @@ impl RecentShredRingBuffer { }); } - let mut matches: Vec<(u32, Arc<[u8]>)> = self - .entries - .iter() - .filter(|(key, _)| { - key.slot == request.slot - && key.index >= request.start_index - && key.index <= request.end_index - }) - .map(|(key, entry)| (key.index, entry.bytes.clone())) + let mut latest_by_index: HashMap)> = + HashMap::with_capacity(query_range_index_capacity(span, limits.max_response_shreds)); + for (key, entry) in self.entries.iter().filter(|(key, _)| { + key.slot == request.slot + && key.index >= request.start_index + && key.index <= request.end_index + }) { + let candidate = (entry.seen_at, entry.bytes.clone()); + match latest_by_index.entry(key.index) { + Entry::Occupied(mut current) => { + if candidate.0 > current.get().0 { + current.insert(candidate); + } + } + Entry::Vacant(vacant) => { + vacant.insert(candidate); + } + } + } + let mut matches: Vec<(u32, Arc<[u8]>)> = latest_by_index + .into_iter() + .map(|(index, (_, bytes))| (index, bytes)) .collect(); matches.sort_unstable_by_key(|(index, _)| *index); @@ -221,6 +236,41 @@ impl RecentShredRingBuffer { } } } + + #[cfg(test)] + fn insert_baseline( + &mut self, + packet: &[u8], + parsed_shred: &ParsedShredHeader, + now: Instant, + ) -> CacheInsertOutcome { + let mut evicted = 0usize; + self.evict(now, &mut evicted); + + let Some(key) = make_cached_shred_key(packet, parsed_shred) else { + return CacheInsertOutcome { + inserted: false, + replaced: false, + evicted, + }; + }; + + let replaced = self.entries.insert( + key, + CachedShred { + seen_at: now, + bytes: Arc::from(packet), + }, + ); + self.order.push_back((now, key)); + self.evict(now, &mut evicted); + + CacheInsertOutcome { + inserted: replaced.is_none(), + replaced: replaced.is_some(), + evicted, + } + } } #[derive(Clone, Debug)] @@ -275,7 +325,6 @@ impl SharedRelayCache { replaced = true; let previous_seen_at = previous.value().seen_at; let _ = self.order.remove(&(previous_seen_at, key)); - let _ = self.slot_index.remove(&(key.slot, key.index, key)); } else { let _ = self.len.fetch_add(1, Ordering::Relaxed); } @@ -287,7 +336,9 @@ impl SharedRelayCache { }, ); self.order.insert((now, key), ()); - self.slot_index.insert((key.slot, key.index, key), ()); + if !replaced { + self.slot_index.insert((key.slot, key.index, key), ()); + } evicted = evicted.saturating_add(self.evict(now)); CacheInsertOutcome { @@ -331,8 +382,8 @@ impl SharedRelayCache { }); } - let mut response: Vec> = Vec::new(); - let mut response_bytes = 0usize; + let mut latest_by_index: HashMap)> = + HashMap::with_capacity(query_range_index_capacity(span, limits.max_response_shreds)); let start = slot_index_range_start(request.slot, request.start_index); let end = slot_index_range_end(request.slot, request.end_index); for entry in self.slot_index.range(start..=end) { @@ -340,7 +391,27 @@ impl SharedRelayCache { let Some(cached) = self.entries.get(&key) else { continue; }; - let bytes = cached.value().bytes.clone(); + let candidate = (cached.value().seen_at, cached.value().bytes.clone()); + match latest_by_index.entry(key.index) { + Entry::Occupied(mut current) => { + if candidate.0 > current.get().0 { + current.insert(candidate); + } + } + Entry::Vacant(vacant) => { + vacant.insert(candidate); + } + } + } + let mut matches: Vec<(u32, Arc<[u8]>)> = latest_by_index + .into_iter() + .map(|(index, (_, bytes))| (index, bytes)) + .collect(); + matches.sort_unstable_by_key(|(index, _)| *index); + + let mut response: Vec> = Vec::new(); + let mut response_bytes = 0usize; + for (_, bytes) in matches { if response.len() >= limits.max_response_shreds { break; } @@ -438,6 +509,55 @@ impl SharedRelayCache { } evicted } + + #[cfg(test)] + fn insert_baseline( + &self, + packet: &[u8], + parsed_shred: &ParsedShredHeader, + now: Instant, + ) -> CacheInsertOutcome { + let mut evicted = self.evict(now); + let Some(key) = make_cached_shred_key(packet, parsed_shred) else { + return CacheInsertOutcome { + inserted: false, + replaced: false, + evicted, + }; + }; + + let mut replaced = false; + if let Some(previous) = self.entries.remove(&key) { + replaced = true; + let previous_seen_at = previous.value().seen_at; + let _ = self.order.remove(&(previous_seen_at, key)); + let _ = self.slot_index.remove(&(key.slot, key.index, key)); + } else { + let _ = self.len.fetch_add(1, Ordering::Relaxed); + } + self.entries.insert( + key, + CachedShred { + seen_at: now, + bytes: Arc::from(packet), + }, + ); + self.order.insert((now, key), ()); + self.slot_index.insert((key.slot, key.index, key), ()); + evicted = evicted.saturating_add(self.evict(now)); + + CacheInsertOutcome { + inserted: !replaced, + replaced, + evicted, + } + } +} + +fn query_range_index_capacity(span: u32, max_response_shreds: usize) -> usize { + usize::try_from(span) + .unwrap_or(usize::MAX) + .min(max_response_shreds.max(1)) } fn make_cached_shred_key( @@ -492,6 +612,10 @@ const fn max_cached_shred_key() -> CachedShredKey { #[cfg(test)] mod tests { + use std::hint::black_box; + + use sof_support::{bench::avg_ns_per_iteration, env_support::read_positive_usize}; + use super::*; use crate::{ protocol::shred_wire::{ @@ -517,6 +641,101 @@ mod tests { assert_eq!(cache.len(), 1); } + #[test] + #[ignore = "profiling fixture for relay ring-buffer replacement churn"] + fn relay_ring_buffer_replace_profile_fixture() { + let iterations = profile_iterations(200_000); + let packet = build_data_shred_packet(42, 7, 1, 9, b"hello"); + let parsed = parse_shred_header(&packet).expect("valid"); + let mut baseline = RecentShredRingBuffer::new(16, Duration::from_secs(60)); + let mut optimized = RecentShredRingBuffer::new(16, Duration::from_secs(60)); + let started_at = Instant::now(); + + let baseline_started = Instant::now(); + for i in 0..iterations { + let now = started_at + Duration::from_nanos(u64::try_from(i).expect("iterations fit")); + black_box(baseline.insert_baseline(&packet, &parsed, now)); + } + let baseline_elapsed = baseline_started.elapsed(); + + let optimized_started = Instant::now(); + for i in 0..iterations { + let now = started_at + Duration::from_nanos(u64::try_from(i).expect("iterations fit")); + black_box(optimized.insert(&packet, &parsed, now)); + } + let optimized_elapsed = optimized_started.elapsed(); + let baseline_avg_ns = avg_ns_per_iteration(baseline_elapsed, iterations); + let optimized_avg_ns = avg_ns_per_iteration(optimized_elapsed, iterations); + + eprintln!( + "relay_ring_buffer_replace_profile_fixture iterations={} baseline_us={} optimized_us={} baseline_avg_ns_per_iteration={} optimized_avg_ns_per_iteration={} baseline_avg_us_per_iteration={:.3} optimized_avg_us_per_iteration={:.3}", + iterations, + baseline_elapsed.as_micros(), + optimized_elapsed.as_micros(), + baseline_avg_ns, + optimized_avg_ns, + baseline_avg_ns as f64 / 1_000.0, + optimized_avg_ns as f64 / 1_000.0, + ); + } + + #[test] + #[ignore = "profiling fixture for relay ring-buffer query-range churn"] + fn relay_ring_buffer_query_range_profile_fixture() { + let iterations = profile_iterations(200_000); + let now = Instant::now(); + let request = RelayRangeRequest { + slot: 42, + start_index: 0, + end_index: 31, + }; + let limits = RelayRangeLimits { + max_request_span: 64, + max_response_shreds: 32, + max_response_bytes: usize::MAX, + }; + let mut baseline = RecentShredRingBuffer::new(128, Duration::from_secs(60)); + let mut optimized = RecentShredRingBuffer::new(128, Duration::from_secs(60)); + + for index in 0_u32..32_u32 { + let packet = build_data_shred_packet(42, index, index, 0, b"payload"); + let parsed = parse_shred_header(&packet).expect("valid"); + let seen_at = now + Duration::from_nanos(u64::from(index)); + let _ = baseline.insert(&packet, &parsed, seen_at); + let _ = optimized.insert(&packet, &parsed, seen_at); + } + + let baseline_started = Instant::now(); + for _ in 0..iterations { + let result = + query_range_baseline(&mut baseline, request, limits, now).expect("baseline query"); + black_box(result); + } + let baseline_elapsed = baseline_started.elapsed(); + + let optimized_started = Instant::now(); + for _ in 0..iterations { + let result = optimized + .query_range(request, limits, now) + .expect("optimized query"); + black_box(result); + } + let optimized_elapsed = optimized_started.elapsed(); + let baseline_avg_ns = avg_ns_per_iteration(baseline_elapsed, iterations); + let optimized_avg_ns = avg_ns_per_iteration(optimized_elapsed, iterations); + + eprintln!( + "relay_ring_buffer_query_range_profile_fixture iterations={} baseline_us={} optimized_us={} baseline_avg_ns_per_iteration={} optimized_avg_ns_per_iteration={} baseline_avg_us_per_iteration={:.3} optimized_avg_us_per_iteration={:.3}", + iterations, + baseline_elapsed.as_micros(), + optimized_elapsed.as_micros(), + baseline_avg_ns, + optimized_avg_ns, + baseline_avg_ns as f64 / 1_000.0, + optimized_avg_ns as f64 / 1_000.0, + ); + } + #[test] fn cache_evicts_by_capacity() { let p1 = build_data_shred_packet(100, 1, 1, 0, b"a"); @@ -621,6 +840,36 @@ mod tests { assert_eq!(result[0].as_ref(), p1.as_slice()); } + #[test] + fn query_range_prefers_latest_seen_variant_per_index() { + let old_packet = build_data_shred_packet(9, 5, 1, 0, b"old"); + let new_packet = build_data_shred_packet(9, 5, 2, 0, b"new"); + let old_header = parse_shred_header(&old_packet).expect("valid"); + let new_header = parse_shred_header(&new_packet).expect("valid"); + let mut cache = RecentShredRingBuffer::new(16, Duration::from_secs(2)); + let now = Instant::now(); + let _ = cache.insert(&old_packet, &old_header, now); + let _ = cache.insert(&new_packet, &new_header, now + Duration::from_millis(1)); + + let result = cache + .query_range( + RelayRangeRequest { + slot: 9, + start_index: 5, + end_index: 5, + }, + RelayRangeLimits { + max_request_span: 4, + max_response_shreds: 4, + max_response_bytes: usize::MAX, + }, + now + Duration::from_millis(2), + ) + .expect("query succeeds"); + assert_eq!(result.len(), 1); + assert_eq!(result[0].as_ref(), new_packet.as_slice()); + } + #[test] fn query_exact_returns_matching_shred() { let p1 = build_data_shred_packet(10, 1, 1, 0, b"a"); @@ -697,6 +946,41 @@ mod tests { assert_eq!(found.as_ref(), packet_new.as_slice()); } + #[test] + fn shared_insert_replaces_existing_key_without_duplicate_slot_index_entry() { + let packet = build_data_shred_packet(12, 42, 1, 0, b"same"); + let header = parse_shred_header(&packet).expect("valid"); + let cache = SharedRelayCache::new(RecentShredRingBuffer::new(16, Duration::from_secs(2))); + let now = Instant::now(); + + let first = cache.insert(&packet, &header, now); + let second = cache.insert(&packet, &header, now + Duration::from_millis(1)); + + assert!(first.inserted); + assert!(!first.replaced); + assert!(!second.inserted); + assert!(second.replaced); + assert_eq!(cache.len(), 1); + + let query = cache + .query_range( + RelayRangeRequest { + slot: 12, + start_index: 42, + end_index: 42, + }, + RelayRangeLimits { + max_request_span: 4, + max_response_shreds: 4, + max_response_bytes: usize::MAX, + }, + now + Duration::from_millis(2), + ) + .expect("query succeeds"); + assert_eq!(query.len(), 1); + assert_eq!(query[0].as_ref(), packet.as_slice()); + } + #[test] fn shared_query_highest_above_prefers_highest_index_then_latest_seen_at() { let packet_mid = build_data_shred_packet(13, 7, 1, 0, b"mid"); @@ -734,6 +1018,152 @@ mod tests { assert_eq!(found.as_ref(), packet_new_top.as_slice()); } + #[test] + fn shared_query_range_prefers_latest_seen_variant_per_index() { + let old_packet = build_data_shred_packet(14, 9, 1, 0, b"old-top"); + let new_packet = build_data_shred_packet(14, 9, 2, 0, b"new-top"); + let old_header = parse_shred_header(&old_packet).expect("valid"); + let new_header = parse_shred_header(&new_packet).expect("valid"); + let cache = SharedRelayCache::new(RecentShredRingBuffer::new(16, Duration::from_secs(2))); + let now = Instant::now(); + assert!(cache.insert(&old_packet, &old_header, now).inserted); + assert!( + cache + .insert(&new_packet, &new_header, now + Duration::from_millis(1)) + .inserted + ); + + let result = cache + .query_range( + RelayRangeRequest { + slot: 14, + start_index: 9, + end_index: 9, + }, + RelayRangeLimits { + max_request_span: 4, + max_response_shreds: 4, + max_response_bytes: usize::MAX, + }, + now + Duration::from_millis(2), + ) + .expect("query succeeds"); + assert_eq!(result.len(), 1); + assert_eq!(result[0].as_ref(), new_packet.as_slice()); + } + + #[test] + #[ignore = "profiling fixture for shared relay cache replacement churn"] + fn shared_relay_cache_replace_profile_fixture() { + let iterations = profile_iterations(200_000); + let packet = build_data_shred_packet(42, 7, 1, 9, b"hello"); + let parsed = parse_shred_header(&packet).expect("valid"); + let baseline = + SharedRelayCache::new(RecentShredRingBuffer::new(16, Duration::from_secs(60))); + let optimized = + SharedRelayCache::new(RecentShredRingBuffer::new(16, Duration::from_secs(60))); + let started_at = Instant::now(); + + let baseline_started = Instant::now(); + for i in 0..iterations { + let now = started_at + Duration::from_nanos(u64::try_from(i).expect("iterations fit")); + black_box(baseline.insert_baseline(&packet, &parsed, now)); + } + let baseline_elapsed = baseline_started.elapsed(); + + let optimized_started = Instant::now(); + for i in 0..iterations { + let now = started_at + Duration::from_nanos(u64::try_from(i).expect("iterations fit")); + black_box(optimized.insert(&packet, &parsed, now)); + } + let optimized_elapsed = optimized_started.elapsed(); + let baseline_avg_ns = avg_ns_per_iteration(baseline_elapsed, iterations); + let optimized_avg_ns = avg_ns_per_iteration(optimized_elapsed, iterations); + + eprintln!( + "shared_relay_cache_replace_profile_fixture iterations={} baseline_us={} optimized_us={} baseline_avg_ns_per_iteration={} optimized_avg_ns_per_iteration={} baseline_avg_us_per_iteration={:.3} optimized_avg_us_per_iteration={:.3}", + iterations, + baseline_elapsed.as_micros(), + optimized_elapsed.as_micros(), + baseline_avg_ns, + optimized_avg_ns, + baseline_avg_ns as f64 / 1_000.0, + optimized_avg_ns as f64 / 1_000.0, + ); + } + + fn query_range_baseline( + cache: &mut RecentShredRingBuffer, + request: RelayRangeRequest, + limits: RelayRangeLimits, + now: Instant, + ) -> Result>, RelayRangeQueryError> { + let mut evicted = 0usize; + cache.evict(now, &mut evicted); + + if request.start_index > request.end_index { + return Err(RelayRangeQueryError::InvalidRange { + start_index: request.start_index, + end_index: request.end_index, + }); + } + + let span = request + .end_index + .saturating_sub(request.start_index) + .saturating_add(1); + if span > limits.max_request_span { + return Err(RelayRangeQueryError::SpanTooLarge { + span, + max_request_span: limits.max_request_span, + }); + } + + let mut latest_by_index: HashMap)> = HashMap::new(); + for (key, entry) in cache.entries.iter().filter(|(key, _)| { + key.slot == request.slot + && key.index >= request.start_index + && key.index <= request.end_index + }) { + let candidate = (entry.seen_at, entry.bytes.clone()); + match latest_by_index.entry(key.index) { + Entry::Occupied(mut current) => { + if candidate.0 > current.get().0 { + current.insert(candidate); + } + } + Entry::Vacant(vacant) => { + vacant.insert(candidate); + } + } + } + let mut matches: Vec<(u32, Arc<[u8]>)> = latest_by_index + .into_iter() + .map(|(index, (_, bytes))| (index, bytes)) + .collect(); + matches.sort_unstable_by_key(|(index, _)| *index); + + let mut response: Vec> = Vec::new(); + let mut response_bytes = 0usize; + for (_, bytes) in matches { + if response.len() >= limits.max_response_shreds { + break; + } + let next_bytes = response_bytes.saturating_add(bytes.len()); + if next_bytes > limits.max_response_bytes { + break; + } + response_bytes = next_bytes; + response.push(bytes); + } + + Ok(response) + } + + fn profile_iterations(default: usize) -> usize { + read_positive_usize("SOF_PROFILE_ITERS", default) + } + fn build_data_shred_packet( slot: u64, index: u32, diff --git a/crates/sof-observer/src/repair/core/gossip/methods.rs b/crates/sof-observer/src/repair/core/gossip/methods.rs index d0e1e2ac..0750bc6c 100644 --- a/crates/sof-observer/src/repair/core/gossip/methods.rs +++ b/crates/sof-observer/src/repair/core/gossip/methods.rs @@ -1,16 +1,21 @@ use super::*; +use std::{net::SocketAddr, sync::Arc}; + use crate::{ protocol::shred_wire::SIZE_OF_CODING_SHRED_PAYLOAD, relay::SharedRelayCache, shred::wire::{ParsedShredHeader, parse_shred_header}, }; +use smallvec::SmallVec; +use sof_support::time_support::duration_millis_u64; +use solana_gossip::contact_info::ContactInfo; use solana_keypair::signable::Signable; impl GossipRepairClient { pub fn new( - cluster_info: std::sync::Arc, + cluster_info: Arc, socket: UdpSocket, - keypair: std::sync::Arc, + keypair: Arc, config: GossipRepairClientConfig, ) -> Self { let now = Instant::now(); @@ -60,18 +65,9 @@ impl GossipRepairClient { (snapshot.total_candidates, snapshot.active_candidates) } - pub fn note_shred_source(&mut self, source_addr: std::net::SocketAddr) -> usize { + pub fn note_shred_source(&mut self, source_addr: SocketAddr) -> usize { let mut updated = 0_usize; - let mut seen = HashSet::new(); - if let Some(pubkey) = self.addr_to_pubkey.get(&source_addr).copied() { - let _ = seen.insert(pubkey); - } - if let Some(pubkeys) = self.ip_to_pubkeys.get(&source_addr.ip()) { - for pubkey in pubkeys { - let _ = seen.insert(*pubkey); - } - } - for pubkey in seen { + for pubkey in self.source_pubkeys(source_addr) { self.peer_scores .entry(pubkey) .or_default() @@ -81,19 +77,10 @@ impl GossipRepairClient { updated } - pub fn note_shred_sources(&mut self, source_addrs: &[(std::net::SocketAddr, u16)]) -> usize { + pub fn note_shred_sources(&mut self, source_addrs: &[(SocketAddr, u16)]) -> usize { let mut updated = 0_usize; for (source_addr, hits) in source_addrs.iter().copied() { - let mut seen = HashSet::new(); - if let Some(pubkey) = self.addr_to_pubkey.get(&source_addr).copied() { - let _ = seen.insert(pubkey); - } - if let Some(pubkeys) = self.ip_to_pubkeys.get(&source_addr.ip()) { - for pubkey in pubkeys { - let _ = seen.insert(*pubkey); - } - } - for pubkey in seen { + for pubkey in self.source_pubkeys(source_addr) { self.peer_scores .entry(pubkey) .or_default() @@ -104,12 +91,36 @@ impl GossipRepairClient { updated } + fn source_pubkeys(&self, source_addr: SocketAddr) -> SmallVec<[Pubkey; 4]> { + let direct_pubkey = self.addr_to_pubkey.get(&source_addr).copied(); + let Some(ip_pubkeys) = self.ip_to_pubkeys.get(&source_addr.ip()) else { + return direct_pubkey.into_iter().collect(); + }; + let mut seen = SmallVec::<[Pubkey; 4]>::with_capacity( + ip_pubkeys + .len() + .saturating_add(usize::from(direct_pubkey.is_some())), + ); + if let Some(pubkey) = direct_pubkey { + seen.push(pubkey); + seen.extend( + ip_pubkeys + .iter() + .copied() + .filter(|ip_pubkey| *ip_pubkey != pubkey), + ); + } else { + seen.extend(ip_pubkeys.iter().copied()); + } + seen + } + pub async fn request_missing_shred( &mut self, slot: u64, index: u32, kind: MissingShredRequestKind, - ) -> Result, GossipRepairClientError> { + ) -> Result, GossipRepairClientError> { let Some(peer) = self.pick_peer(slot, index) else { return Ok(None); }; @@ -152,7 +163,7 @@ impl GossipRepairClient { pub async fn maybe_handle_response_ping( &mut self, packet: &[u8], - from_addr: std::net::SocketAddr, + from_addr: SocketAddr, ) -> Result { if !is_repair_response_ping_packet(packet) { return Ok(false); @@ -179,7 +190,7 @@ impl GossipRepairClient { pub async fn maybe_serve_repair_request( &mut self, packet: &[u8], - from_addr: std::net::SocketAddr, + from_addr: SocketAddr, relay_cache: Option<&SharedRelayCache>, ) -> Result, GossipRepairClientError> { let Some(request) = parse_signed_repair_request( @@ -323,7 +334,6 @@ impl GossipRepairClient { fn refresh_peers(&mut self, slot: u64) { self.decay_peer_scores(); - self.refresh_stake_map(); self.peers_by_slot .retain(|_, cached| cached.updated_at.elapsed() < self.peer_cache_ttl); self.sticky_peer_by_slot.retain(|slot_key, sticky| { @@ -355,19 +365,32 @@ impl GossipRepairClient { if !should_refresh { return; } - let mut candidates = self.collect_candidate_peers(slot); + let all_peers = self.cluster_info.all_peers(); + self.refresh_stake_map(&all_peers); + let mut candidates = self.collect_candidate_peers(slot, &all_peers); let total_candidates = candidates.len(); candidates.sort_unstable_by(|left, right| { self.score_for(right.pubkey) .cmp(&self.score_for(left.pubkey)) .then_with(|| left.pubkey.to_bytes().cmp(&right.pubkey.to_bytes())) }); - let mut peers = candidates.clone(); - if peers.len() > self.active_peer_count { - peers.truncate(self.active_peer_count); - } + let peers = candidates + .iter() + .take(self.active_peer_count) + .copied() + .collect::>(); self.addr_to_pubkey.clear(); self.ip_to_pubkeys.clear(); + self.addr_to_pubkey.reserve( + candidates + .len() + .saturating_sub(self.addr_to_pubkey.capacity()), + ); + self.ip_to_pubkeys.reserve( + candidates + .len() + .saturating_sub(self.ip_to_pubkeys.capacity()), + ); for peer in &candidates { let _ = self.addr_to_pubkey.insert(peer.addr, peer.pubkey); self.ip_to_pubkeys @@ -387,7 +410,7 @@ impl GossipRepairClient { peers: peers.clone(), }, ); - self.publish_peer_snapshot(total_candidates, &peers); + self.publish_peer_snapshot(total_candidates, &peers, &all_peers); } fn decay_peer_scores(&mut self) { @@ -405,9 +428,14 @@ impl GossipRepairClient { }); } - fn collect_candidate_peers(&self, slot: u64) -> Vec { - let mut seen = HashSet::new(); - let mut peers = Vec::new(); + fn collect_candidate_peers( + &self, + slot: u64, + all_peers: &[(ContactInfo, u64)], + ) -> Vec { + let estimated_peers = all_peers.len().saturating_add(1); + let mut seen = HashSet::with_capacity(estimated_peers); + let mut peers = Vec::with_capacity(estimated_peers); for contact_info in self.cluster_info.repair_peers(slot) { let Some(addr) = contact_info.serve_repair(Protocol::UDP) else { continue; @@ -428,7 +456,7 @@ impl GossipRepairClient { if peers.is_empty() { let self_pubkey = self.cluster_info.id(); let self_shred_version = self.cluster_info.my_shred_version(); - for (contact_info, stake_lamports) in self.cluster_info.all_peers() { + for (contact_info, stake_lamports) in all_peers { if contact_info.pubkey() == &self_pubkey || contact_info.shred_version() != self_shred_version || contact_info.tvu(Protocol::UDP).is_none() @@ -441,7 +469,7 @@ impl GossipRepairClient { let peer = RepairPeer { pubkey: *contact_info.pubkey(), addr, - stake_lamports, + stake_lamports: *stake_lamports, }; if seen.insert((peer.pubkey, peer.addr)) { peers.push(peer); @@ -451,10 +479,15 @@ impl GossipRepairClient { peers } - fn publish_peer_snapshot(&mut self, total_candidates: usize, peers: &[RepairPeer]) { - let mut known_pubkeys = Vec::new(); + fn publish_peer_snapshot( + &mut self, + total_candidates: usize, + peers: &[RepairPeer], + all_peers: &[(ContactInfo, u64)], + ) { + let mut known_pubkeys = Vec::with_capacity(all_peers.len().saturating_add(1)); known_pubkeys.push(self.cluster_info.id().to_bytes()); - for (contact_info, _) in self.cluster_info.all_peers() { + for (contact_info, _) in all_peers { known_pubkeys.push(contact_info.pubkey().to_bytes()); } known_pubkeys.sort_unstable(); @@ -468,7 +501,7 @@ impl GossipRepairClient { }); } - fn note_peer_ping(&mut self, from_addr: std::net::SocketAddr, now: Instant) { + fn note_peer_ping(&mut self, from_addr: SocketAddr, now: Instant) { let Some(sent_at) = self.last_request_sent_at.get(&from_addr).copied() else { return; }; @@ -513,12 +546,17 @@ impl GossipRepairClient { .weight() } - fn refresh_stake_map(&mut self) { + fn refresh_stake_map(&mut self, all_peers: &[(ContactInfo, u64)]) { self.stake_by_pubkey.clear(); - for (contact_info, stake_lamports) in self.cluster_info.all_peers() { + self.stake_by_pubkey.reserve( + all_peers + .len() + .saturating_sub(self.stake_by_pubkey.capacity()), + ); + for (contact_info, stake_lamports) in all_peers { let _ = self .stake_by_pubkey - .insert(*contact_info.pubkey(), stake_lamports); + .insert(*contact_info.pubkey(), *stake_lamports); } } @@ -539,11 +577,7 @@ impl GossipRepairClient { self.serve_requests_by_addr.clear(); } - fn reserve_serve_request_budget( - &mut self, - source_addr: std::net::SocketAddr, - now: Instant, - ) -> bool { + fn reserve_serve_request_budget(&mut self, source_addr: SocketAddr, now: Instant) -> bool { self.reset_serve_window_if_needed(now); let requests = self.serve_requests_by_addr.entry(source_addr).or_default(); if *requests >= self.serve_max_requests_per_peer_per_sec { @@ -629,14 +663,11 @@ impl GossipRepairClient { selected } - fn last_request_age_ms(&self, now: Instant, addr: std::net::SocketAddr) -> u64 { + fn last_request_age_ms(&self, now: Instant, addr: SocketAddr) -> u64 { self.last_request_sent_at .get(&addr) .copied() - .map(|sent_at| { - u64::try_from(now.saturating_duration_since(sent_at).as_millis()) - .unwrap_or(u64::MAX) - }) + .map(|sent_at| duration_millis_u64(now.saturating_duration_since(sent_at))) .unwrap_or(u64::MAX) } @@ -681,6 +712,12 @@ const fn mix_seed(seed: u64) -> u64 { #[cfg(test)] mod tests { use super::*; + use std::sync::Arc; + + use sof_support::{bench::avg_ns_per_iteration, env_support::read_positive_usize}; + use solana_gossip::{cluster_info::ClusterInfo, contact_info::ContactInfo, node::Node}; + use solana_signer::Signer; + use solana_streamer::socket::SocketAddrSpace; #[test] fn sticky_peer_is_kept_within_window_when_score_gap_is_small() { @@ -708,4 +745,308 @@ mod tests { 1_000 + REPAIR_PEER_SWITCH_SCORE_MARGIN, )); } + + fn profile_repair_client(peer_count: usize) -> GossipRepairClient { + let identity = Arc::new(Keypair::new()); + let node = Node::new_localhost_with_pubkey(&identity.pubkey()); + let cluster_info = Arc::new(ClusterInfo::new( + node.info, + Arc::clone(&identity), + SocketAddrSpace::Unspecified, + )); + for _ in 0..peer_count { + let peer = Keypair::new(); + cluster_info.insert_info(ContactInfo::new_localhost(&peer.pubkey(), 0)); + } + let socket = std::net::UdpSocket::bind("127.0.0.1:0").expect("bind profile udp socket"); + socket + .set_nonblocking(true) + .expect("set nonblocking profile udp socket"); + let socket = UdpSocket::from_std(socket).expect("tokio udp socket"); + GossipRepairClient::new( + cluster_info, + socket, + identity, + GossipRepairClientConfig { + peer_cache_ttl: Duration::from_secs(60), + peer_cache_capacity: 128, + active_peer_count: 32, + peer_sample_size: 8, + serve_max_bytes_per_sec: 1_000_000, + serve_unstaked_max_bytes_per_sec: 1_000_000, + serve_max_requests_per_peer_per_sec: 1_000, + }, + ) + } + + #[tokio::test(flavor = "current_thread")] + #[ignore = "profiling fixture for repair peer refresh"] + async fn repair_refresh_peer_snapshot_profile_fixture() { + let iterations = read_positive_usize("SOF_REPAIR_REFRESH_PROFILE_ITERS", 50_000); + let peer_count = read_positive_usize("SOF_REPAIR_REFRESH_PROFILE_PEERS", 64); + let slot = 77_u64; + let mut baseline_client = profile_repair_client(peer_count); + baseline_client.peer_cache_ttl = Duration::ZERO; + let mut optimized_client = profile_repair_client(peer_count); + optimized_client.peer_cache_ttl = Duration::ZERO; + + let baseline_started_at = Instant::now(); + let mut baseline_total_candidates = 0_u64; + let mut baseline_active_candidates = 0_u64; + for _ in 0..iterations { + let (total, active) = refresh_peer_snapshot_baseline(&mut baseline_client, slot); + baseline_total_candidates = + baseline_total_candidates.saturating_add(u64::try_from(total).unwrap_or(u64::MAX)); + baseline_active_candidates = baseline_active_candidates + .saturating_add(u64::try_from(active).unwrap_or(u64::MAX)); + } + let baseline_elapsed = baseline_started_at.elapsed(); + + let optimized_started_at = Instant::now(); + let mut optimized_total_candidates = 0_u64; + let mut optimized_active_candidates = 0_u64; + for _ in 0..iterations { + let (total, active) = optimized_client.refresh_peer_snapshot(slot); + optimized_total_candidates = + optimized_total_candidates.saturating_add(u64::try_from(total).unwrap_or(u64::MAX)); + optimized_active_candidates = optimized_active_candidates + .saturating_add(u64::try_from(active).unwrap_or(u64::MAX)); + } + let optimized_elapsed = optimized_started_at.elapsed(); + let baseline_avg_ns = avg_ns_per_iteration(baseline_elapsed, iterations); + let optimized_avg_ns = avg_ns_per_iteration(optimized_elapsed, iterations); + println!( + "repair_refresh_peer_snapshot_profile_fixture iterations={} peer_count={} baseline_total_candidates={} baseline_active_candidates={} optimized_total_candidates={} optimized_active_candidates={} baseline_us={} optimized_us={} baseline_avg_ns_per_iteration={} optimized_avg_ns_per_iteration={} baseline_avg_us_per_iteration={:.3} optimized_avg_us_per_iteration={:.3}", + iterations, + peer_count, + baseline_total_candidates, + baseline_active_candidates, + optimized_total_candidates, + optimized_active_candidates, + baseline_elapsed.as_micros(), + optimized_elapsed.as_micros(), + baseline_avg_ns, + optimized_avg_ns, + baseline_avg_ns as f64 / 1_000.0, + optimized_avg_ns as f64 / 1_000.0 + ); + } + + #[tokio::test(flavor = "current_thread")] + #[ignore = "profiling fixture for cached repair peer selection"] + async fn repair_pick_peer_cached_profile_fixture() { + let iterations = read_positive_usize("SOF_REPAIR_PICK_PEER_PROFILE_ITERS", 200_000); + let peer_count = read_positive_usize("SOF_REPAIR_PICK_PEER_PROFILE_PEERS", 64); + let slot = 42_u64; + let mut client = profile_repair_client(0); + let peers = (0..peer_count) + .map(|index| RepairPeer { + pubkey: Pubkey::new_unique(), + addr: SocketAddr::from(( + [127, 0, 0, 1], + u16::try_from(10_000 + index).unwrap_or(u16::MAX), + )), + stake_lamports: (u64::try_from(index).unwrap_or(0) + 1) + .saturating_mul(SOL_LAMPORTS), + }) + .collect::>(); + let _ = client.peers_by_slot.insert( + slot, + CachedPeers { + updated_at: Instant::now(), + peers, + }, + ); + + let started_at = Instant::now(); + let mut selected = 0_u64; + for iteration in 0..iterations { + let index = u32::try_from(iteration & 0xffff).unwrap_or(u32::MAX); + if client.pick_peer(slot, index).is_some() { + selected = selected.saturating_add(1); + } + } + let elapsed = started_at.elapsed(); + let avg_ns = avg_ns_per_iteration(elapsed, iterations); + println!( + "repair_pick_peer_cached_profile_fixture iterations={} peer_count={} selected={} elapsed_ms={} avg_ns_per_iteration={} avg_us_per_iteration={:.3}", + iterations, + peer_count, + selected, + elapsed.as_millis(), + avg_ns, + avg_ns as f64 / 1_000.0 + ); + } + + #[tokio::test(flavor = "current_thread")] + #[ignore = "profiling fixture for repair source note batching"] + async fn repair_note_shred_sources_profile_fixture() { + let iterations = read_positive_usize("SOF_REPAIR_NOTE_SOURCES_PROFILE_ITERS", 100_000); + let source_count = read_positive_usize("SOF_REPAIR_NOTE_SOURCES_PROFILE_BATCH", 64); + let mut client = profile_repair_client(0); + let mut sources = Vec::with_capacity(source_count); + for index in 0..source_count { + let source_addr = SocketAddr::from(( + [127, 0, 0, 1], + u16::try_from(12_000 + index).unwrap_or(u16::MAX), + )); + let direct_pubkey = Pubkey::new_unique(); + let mut ip_pubkeys = vec![direct_pubkey]; + ip_pubkeys.extend((0..3).map(|_| Pubkey::new_unique())); + let _ = client.addr_to_pubkey.insert(source_addr, direct_pubkey); + let _ = client.ip_to_pubkeys.insert(source_addr.ip(), ip_pubkeys); + sources.push((source_addr, u16::try_from((index % 4) + 1).unwrap_or(1))); + } + + let started_at = Instant::now(); + let mut updated = 0_usize; + for _ in 0..iterations { + updated = updated.saturating_add(client.note_shred_sources(&sources)); + } + let elapsed = started_at.elapsed(); + let avg_ns = avg_ns_per_iteration(elapsed, iterations); + println!( + "repair_note_shred_sources_profile_fixture iterations={} source_count={} updated={} elapsed_ms={} avg_ns_per_iteration={} avg_us_per_iteration={:.3}", + iterations, + source_count, + updated, + elapsed.as_millis(), + avg_ns, + avg_ns as f64 / 1_000.0 + ); + } + + fn refresh_peer_snapshot_baseline( + client: &mut GossipRepairClient, + slot: u64, + ) -> (usize, usize) { + refresh_peers_baseline(client, slot); + let snapshot = client.peer_snapshot.shared_get(); + (snapshot.total_candidates, snapshot.active_candidates) + } + + fn refresh_peers_baseline(client: &mut GossipRepairClient, slot: u64) { + client.decay_peer_scores(); + client + .peers_by_slot + .retain(|_, cached| cached.updated_at.elapsed() < client.peer_cache_ttl); + client.sticky_peer_by_slot.retain(|slot_key, sticky| { + client.peers_by_slot.contains_key(slot_key) + && sticky.selected_at.elapsed() < client.peer_cache_ttl + }); + if client.peers_by_slot.len() > client.peer_cache_capacity { + let mut keys: Vec<_> = client.peers_by_slot.keys().copied().collect(); + keys.sort_unstable_by_key(|key| { + client + .peers_by_slot + .get(key) + .map(|cached| cached.updated_at) + .unwrap_or_else(Instant::now) + }); + let overflow = client + .peers_by_slot + .len() + .saturating_sub(client.peer_cache_capacity); + for key in keys.into_iter().take(overflow) { + let _ = client.peers_by_slot.remove(&key); + let _ = client.sticky_peer_by_slot.remove(&key); + } + } + let should_refresh = client + .peers_by_slot + .get(&slot) + .map(|cached| cached.updated_at.elapsed() >= client.peer_cache_ttl) + .unwrap_or(true); + if !should_refresh { + return; + } + let all_peers = client.cluster_info.all_peers(); + client.refresh_stake_map(&all_peers); + let mut candidates = collect_candidate_peers_baseline(client, slot, &all_peers); + let total_candidates = candidates.len(); + candidates.sort_unstable_by(|left, right| { + client + .score_for(right.pubkey) + .cmp(&client.score_for(left.pubkey)) + .then_with(|| left.pubkey.to_bytes().cmp(&right.pubkey.to_bytes())) + }); + let mut peers = candidates.clone(); + if peers.len() > client.active_peer_count { + peers.truncate(client.active_peer_count); + } + client.addr_to_pubkey.clear(); + client.ip_to_pubkeys.clear(); + for peer in &candidates { + let _ = client.addr_to_pubkey.insert(peer.addr, peer.pubkey); + client + .ip_to_pubkeys + .entry(peer.addr.ip()) + .or_default() + .push(peer.pubkey); + let _ = client.peer_scores.entry(peer.pubkey).or_default(); + } + for pubkeys in client.ip_to_pubkeys.values_mut() { + pubkeys.sort_unstable_by_key(Pubkey::to_bytes); + pubkeys.dedup(); + } + let _ = client.peers_by_slot.insert( + slot, + CachedPeers { + updated_at: Instant::now(), + peers: peers.clone(), + }, + ); + client.publish_peer_snapshot(total_candidates, &peers, &all_peers); + } + + fn collect_candidate_peers_baseline( + client: &GossipRepairClient, + slot: u64, + all_peers: &[(ContactInfo, u64)], + ) -> Vec { + let mut seen = HashSet::new(); + let mut peers = Vec::new(); + for contact_info in client.cluster_info.repair_peers(slot) { + let Some(addr) = contact_info.serve_repair(Protocol::UDP) else { + continue; + }; + let peer = RepairPeer { + pubkey: *contact_info.pubkey(), + addr, + stake_lamports: client + .stake_by_pubkey + .get(contact_info.pubkey()) + .copied() + .unwrap_or_default(), + }; + if seen.insert((peer.pubkey, peer.addr)) { + peers.push(peer); + } + } + if peers.is_empty() { + let self_pubkey = client.cluster_info.id(); + let self_shred_version = client.cluster_info.my_shred_version(); + for (contact_info, stake_lamports) in all_peers { + if contact_info.pubkey() == &self_pubkey + || contact_info.shred_version() != self_shred_version + || contact_info.tvu(Protocol::UDP).is_none() + { + continue; + } + let Some(addr) = contact_info.serve_repair(Protocol::UDP) else { + continue; + }; + let peer = RepairPeer { + pubkey: *contact_info.pubkey(), + addr, + stake_lamports: *stake_lamports, + }; + if seen.insert((peer.pubkey, peer.addr)) { + peers.push(peer); + } + } + } + peers + } } diff --git a/crates/sof-observer/src/repair/core/request.rs b/crates/sof-observer/src/repair/core/request.rs index 1b93977d..1605f57f 100644 --- a/crates/sof-observer/src/repair/core/request.rs +++ b/crates/sof-observer/src/repair/core/request.rs @@ -1,4 +1,4 @@ -use std::time::{SystemTime, UNIX_EPOCH}; +use sof_support::time_support::current_unix_ms; use bincode::Options; use serde::{Deserialize, Serialize}; @@ -336,16 +336,7 @@ pub fn build_window_index_request( } pub(super) fn unix_timestamp_ms() -> u64 { - let now = SystemTime::now(); - let Ok(duration) = now.duration_since(UNIX_EPOCH) else { - return 0; - }; - let millis = duration.as_millis(); - if millis > u128::from(u64::MAX) { - u64::MAX - } else { - millis as u64 - } + current_unix_ms() } #[cfg(test)] diff --git a/crates/sof-observer/src/repair/core/tests.rs b/crates/sof-observer/src/repair/core/tests.rs index 9cdc4f82..6b4bcd7b 100644 --- a/crates/sof-observer/src/repair/core/tests.rs +++ b/crates/sof-observer/src/repair/core/tests.rs @@ -1,5 +1,9 @@ -use std::time::Duration; +use std::{ + hint::black_box, + time::{Duration, Instant}, +}; +use sof_support::bench::{avg_ns_per_iteration, profile_iterations}; use solana_keypair::Keypair; use solana_signer::Signer; @@ -383,3 +387,81 @@ fn unix_timestamp_is_monotonicish() { let second = unix_timestamp_ms(); assert!(second >= first); } + +#[test] +fn missing_tracker_sorted_slot_keys_matches_baseline() { + let tracker = build_profile_tracker(); + assert_eq!( + tracker.sorted_slot_keys_baseline(), + tracker.sorted_slot_keys() + ); +} + +#[test] +#[ignore = "profiling fixture for missing shred tracker slot prioritization"] +fn missing_tracker_slot_sort_profile_fixture() { + let iterations = profile_iterations(5_000); + let baseline_tracker = build_profile_tracker(); + let optimized_tracker = build_profile_tracker(); + + let baseline_started = Instant::now(); + for _ in 0..iterations { + let slot_keys = baseline_tracker.sorted_slot_keys_baseline(); + black_box(slot_keys); + } + let baseline_elapsed = baseline_started.elapsed(); + + let optimized_started = Instant::now(); + for _ in 0..iterations { + let slot_keys = optimized_tracker.sorted_slot_keys(); + black_box(slot_keys); + } + let optimized_elapsed = optimized_started.elapsed(); + + let baseline_avg_ns = avg_ns_per_iteration(baseline_elapsed, iterations); + let optimized_avg_ns = avg_ns_per_iteration(optimized_elapsed, iterations); + eprintln!( + "missing_tracker_slot_sort_profile_fixture iterations={} baseline_us={} optimized_us={} baseline_avg_ns_per_iteration={} optimized_avg_ns_per_iteration={} baseline_avg_us_per_iteration={:.3} optimized_avg_us_per_iteration={:.3}", + iterations, + baseline_elapsed.as_micros(), + optimized_elapsed.as_micros(), + baseline_avg_ns, + optimized_avg_ns, + baseline_avg_ns as f64 / 1_000.0, + optimized_avg_ns as f64 / 1_000.0, + ); +} + +fn build_profile_tracker() -> MissingShredTracker { + let mut tracker = MissingShredTracker::new( + 512, + 0, + Duration::from_millis(10), + Duration::from_millis(100), + 4, + 16, + 4, + ); + let base = Instant::now(); + + for slot in 10_000_u64..10_768 { + tracker.on_code_shred(slot, 0, 32, base); + tracker.on_data_shred(slot, 0, 0, false, 0, base); + if slot % 3 != 0 { + tracker.on_data_shred(slot, 2, 0, false, 0, base); + } + if slot % 5 == 0 { + tracker.on_data_shred(slot, 31, 0, true, 0, base); + } + if slot % 7 == 0 { + tracker.seed_highest_probe_slot(slot + 1, base); + } + if slot % 11 == 0 { + tracker.on_code_shred(slot, 32, 32, base); + tracker.on_data_shred(slot, 32, 32, false, 0, base); + tracker.on_data_shred(slot, 34, 32, false, 0, base); + } + } + + tracker +} diff --git a/crates/sof-observer/src/repair/core/tracker/collect.rs b/crates/sof-observer/src/repair/core/tracker/collect.rs index 93b4c812..e833e0b3 100644 --- a/crates/sof-observer/src/repair/core/tracker/collect.rs +++ b/crates/sof-observer/src/repair/core/tracker/collect.rs @@ -1,5 +1,69 @@ +use std::cmp::Ordering; + use super::*; +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +struct SlotCollectPriority { + probe_ready: bool, + has_received: bool, + received_upper: u32, + probe_only: bool, + has_last: bool, + gap: u32, + observed_span: u32, +} + +impl SlotCollectPriority { + fn from_state(slot_state: &SlotMissingState) -> Self { + let probe_ready = slot_state.is_highest_probe_ready(); + let received_upper = slot_state.received_upper_bound(); + let has_received = received_upper.is_some(); + let probe_only = probe_ready && !has_received; + let has_last = slot_state.last_index_seen.is_some(); + let received_upper_value = received_upper.unwrap_or(0); + let gap = received_upper + .map(|upper| upper.saturating_sub(slot_state.contiguous_data_prefix)) + .unwrap_or(u32::MAX); + let observed_span = received_upper + .map(|upper| { + upper.saturating_sub( + slot_state + .min_data_index_seen + .unwrap_or(slot_state.contiguous_data_prefix), + ) + }) + .unwrap_or(u32::MAX); + + Self { + probe_ready, + has_received, + received_upper: received_upper_value, + probe_only, + has_last, + gap, + observed_span, + } + } + + fn cmp_with_slot(self, slot: u64, other: Self, other_slot: u64) -> Ordering { + other + .probe_ready + .cmp(&self.probe_ready) + .then_with(|| other.has_received.cmp(&self.has_received)) + .then_with(|| other.received_upper.cmp(&self.received_upper)) + .then_with(|| { + if self.probe_only && other.probe_only { + slot.cmp(&other_slot) + } else { + other_slot.cmp(&slot) + } + }) + .then_with(|| other.has_last.cmp(&self.has_last)) + .then_with(|| self.gap.cmp(&other.gap)) + .then_with(|| self.observed_span.cmp(&other.observed_span)) + } +} + impl MissingShredTracker { pub fn collect_requests( &mut self, @@ -16,68 +80,8 @@ impl MissingShredTracker { let mut requests = Vec::with_capacity(max_requests); let highest_request_budget = max_requests.min(max_highest_window_requests); let mut forward_probe_requests_sent = 0_usize; - let mut slot_request_counts: HashMap = HashMap::new(); - let mut slot_keys: Vec = self.slots.keys().copied().collect(); - slot_keys.sort_unstable_by(|a, b| { - let Some(a_state) = self.slots.get(a) else { - return std::cmp::Ordering::Greater; - }; - let Some(b_state) = self.slots.get(b) else { - return std::cmp::Ordering::Less; - }; - let a_probe_ready = a_state.is_highest_probe_ready(); - let b_probe_ready = b_state.is_highest_probe_ready(); - let a_has_received = a_state.received_upper_bound().is_some(); - let b_has_received = b_state.received_upper_bound().is_some(); - let a_probe_only = a_probe_ready && !a_has_received; - let b_probe_only = b_probe_ready && !b_has_received; - let a_has_last = a_state.last_index_seen.is_some(); - let b_has_last = b_state.last_index_seen.is_some(); - let a_received_upper = a_state.received_upper_bound().unwrap_or(0); - let b_received_upper = b_state.received_upper_bound().unwrap_or(0); - let a_gap = a_state - .received_upper_bound() - .map(|upper| upper.saturating_sub(a_state.contiguous_data_prefix)) - .unwrap_or(u32::MAX); - let b_gap = b_state - .received_upper_bound() - .map(|upper| upper.saturating_sub(b_state.contiguous_data_prefix)) - .unwrap_or(u32::MAX); - let a_observed_span = a_state - .received_upper_bound() - .map(|upper| { - upper.saturating_sub( - a_state - .min_data_index_seen - .unwrap_or(a_state.contiguous_data_prefix), - ) - }) - .unwrap_or(u32::MAX); - let b_observed_span = b_state - .received_upper_bound() - .map(|upper| { - upper.saturating_sub( - b_state - .min_data_index_seen - .unwrap_or(b_state.contiguous_data_prefix), - ) - }) - .unwrap_or(u32::MAX); - b_probe_ready - .cmp(&a_probe_ready) - .then_with(|| b_has_received.cmp(&a_has_received)) - .then_with(|| b_received_upper.cmp(&a_received_upper)) - .then_with(|| { - if a_probe_only && b_probe_only { - a.cmp(b) - } else { - b.cmp(a) - } - }) - .then_with(|| b_has_last.cmp(&a_has_last)) - .then_with(|| a_gap.cmp(&b_gap)) - .then_with(|| a_observed_span.cmp(&b_observed_span)) - }); + let mut slot_request_counts: HashMap = HashMap::with_capacity(self.slots.len()); + let slot_keys = self.sorted_slot_keys(); if highest_request_budget > 0 { for slot in &slot_keys { if requests.len() >= max_requests || requests.len() >= highest_request_budget { @@ -290,4 +294,35 @@ impl MissingShredTracker { fec_set_index = next; } } + + pub(crate) fn sorted_slot_keys(&self) -> Vec { + let mut priorities: Vec<_> = self + .slots + .iter() + .map(|(&slot, slot_state)| (slot, SlotCollectPriority::from_state(slot_state))) + .collect(); + priorities.sort_unstable_by(|(slot, priority), (other_slot, other_priority)| { + priority.cmp_with_slot(*slot, *other_priority, *other_slot) + }); + priorities.into_iter().map(|(slot, _)| slot).collect() + } + + #[cfg(test)] + pub(crate) fn sorted_slot_keys_baseline(&self) -> Vec { + let mut slot_keys: Vec = self.slots.keys().copied().collect(); + slot_keys.sort_unstable_by(|a, b| { + let Some(a_state) = self.slots.get(a) else { + return Ordering::Greater; + }; + let Some(b_state) = self.slots.get(b) else { + return Ordering::Less; + }; + SlotCollectPriority::from_state(a_state).cmp_with_slot( + *a, + SlotCollectPriority::from_state(b_state), + *b, + ) + }); + slot_keys + } } diff --git a/crates/sof-observer/src/runtime.rs b/crates/sof-observer/src/runtime.rs index c1773106..8388768e 100644 --- a/crates/sof-observer/src/runtime.rs +++ b/crates/sof-observer/src/runtime.rs @@ -19,8 +19,9 @@ pub use crate::app::config::GossipRuntimeMode; use crate::app::config::read_observability_bind_addr; use crate::framework::host::TransactionDispatchScope; use crate::framework::{ - DerivedStateHost, DerivedStateReplayBackend, DerivedStateReplayDurability, PluginHost, - RuntimeExtensionHost, TransactionEvent, + DerivedStateHost, DerivedStateReplayBackend, DerivedStateReplayDurability, + ObservedRecentBlockhashEvent, PluginHost, RuntimeExtensionHost, SignatureBytes, + TransactionEvent, }; #[cfg(feature = "kernel-bypass")] use crate::ingest::{RawPacketBatchReceiver, RawPacketBatchSender, create_raw_packet_batch_queue}; @@ -43,13 +44,14 @@ use sof_gossip_tuning::{ QueueCapacity, ReceiverCoalesceWindow, RuntimeTuningPort, SofRuntimeTuning, TvuReceiveSocketCount, }; -use solana_signature::Signature; +use solana_packet::PACKET_DATA_SIZE; use solana_transaction::versioned::VersionedTransaction; use thiserror::Error; type ShutdownSignal = Pin + Send + 'static>>; const PROVIDER_REPLAY_DEDUPE_CAPACITY: usize = 65_536; const PROVIDER_REPLAY_DEDUPE_SLOT_WINDOW: u64 = 4_096; +const MAX_PROVIDER_SERIALIZED_TRANSACTION_BYTES: usize = PACKET_DATA_SIZE; #[cfg(feature = "gossip-bootstrap")] const PROVIDER_GOSSIP_CONTROL_PLANE_POLL_MS: u64 = 250; #[cfg(feature = "gossip-bootstrap")] @@ -2146,7 +2148,7 @@ const fn provider_stream_mode_accepts_source_kind( enum ProviderReplayLogicalKey { Transaction { slot: u64, - signature: Signature, + signature: SignatureBytes, commitment_status: u8, confirmed_slot: Option, finalized_slot: Option, @@ -2235,15 +2237,106 @@ impl ProviderReplayDedupe { return false; }; + let suppressed = match source.arbitration() { + provider_stream::ProviderSourceArbitrationMode::EmitAll => false, + provider_stream::ProviderSourceArbitrationMode::FirstSeen => { + match self.arbitrated.entry(logical.clone()) { + Entry::Occupied(_) => true, + Entry::Vacant(entry) => { + entry.insert(ProviderReplayArbitratedWinner { + priority: source.priority(), + source, + }); + self.arbitrated_order.push_back(logical); + false + } + } + } + provider_stream::ProviderSourceArbitrationMode::FirstSeenThenPromote => { + match self.arbitrated.entry(logical.clone()) { + Entry::Occupied(mut entry) => { + let winner = entry.get_mut(); + if source.priority() > winner.priority { + winner.priority = source.priority(); + winner.source = source; + false + } else { + true + } + } + Entry::Vacant(entry) => { + entry.insert(ProviderReplayArbitratedWinner { + priority: source.priority(), + source, + }); + self.arbitrated_order.push_back(logical); + false + } + } + } + }; + self.evict(); + suppressed + } + + fn evict(&mut self) { + let min_slot = self + .max_slot_seen + .saturating_sub(PROVIDER_REPLAY_DEDUPE_SLOT_WINDOW); + while self.order.len() > self.capacity + || self + .order + .front() + .is_some_and(|oldest| oldest.slot() < min_slot) + { + let Some(oldest) = self.order.pop_front() else { + break; + }; + self.seen.remove(&oldest); + } + while self.arbitrated_order.len() > self.capacity + || self + .arbitrated_order + .front() + .is_some_and(|oldest| oldest.slot() < min_slot) + { + let Some(oldest) = self.arbitrated_order.pop_front() else { + break; + }; + self.arbitrated.remove(&oldest); + } + } + + #[cfg(test)] + fn observe_baseline(&mut self, update: &ProviderStreamUpdate) -> bool { + let Some(logical) = provider_replay_dedupe_key(update) else { + return false; + }; + self.max_slot_seen = self.max_slot_seen.max(logical.slot()); + let source = provider_stream_update_source_ref(update); + let observed = ProviderReplayObservedKey { + source: source.clone(), + logical: logical.clone(), + }; + if !self.seen.insert(observed.clone()) { + return true; + } + self.order.push_back(observed); + + let Some(source) = source else { + self.evict_baseline(); + return false; + }; + match source.arbitration() { provider_stream::ProviderSourceArbitrationMode::EmitAll => { - self.evict(); + self.evict_baseline(); false } provider_stream::ProviderSourceArbitrationMode::FirstSeen => { match self.arbitrated.entry(logical.clone()) { Entry::Occupied(_) => { - self.evict(); + self.evict_baseline(); true } Entry::Vacant(entry) => { @@ -2252,7 +2345,7 @@ impl ProviderReplayDedupe { source, }); self.arbitrated_order.push_back(logical); - self.evict(); + self.evict_baseline(); false } } @@ -2264,10 +2357,10 @@ impl ProviderReplayDedupe { if source.priority() > winner.priority { winner.priority = source.priority(); winner.source = source; - self.evict(); + self.evict_baseline(); false } else { - self.evict(); + self.evict_baseline(); true } } @@ -2277,7 +2370,7 @@ impl ProviderReplayDedupe { source, }); self.arbitrated_order.push_back(logical); - self.evict(); + self.evict_baseline(); false } } @@ -2285,7 +2378,8 @@ impl ProviderReplayDedupe { } } - fn evict(&mut self) { + #[cfg(test)] + fn evict_baseline(&mut self) { let min_slot = self .max_slot_seen .saturating_sub(PROVIDER_REPLAY_DEDUPE_SLOT_WINDOW); @@ -2330,8 +2424,14 @@ fn provider_replay_dedupe_key(update: &ProviderStreamUpdate) -> Option event .signature - .map(framework::SignatureBytes::to_solana) - .or_else(|| event.tx.signatures.first().copied()) + .or_else(|| { + event + .tx + .signatures + .first() + .copied() + .map(SignatureBytes::from_solana) + }) .map(|signature| ProviderReplayLogicalKey::Transaction { slot: event.slot, signature, @@ -2339,29 +2439,26 @@ fn provider_replay_dedupe_key(update: &ProviderStreamUpdate) -> Option event - .signature - .map(framework::SignatureBytes::to_solana) - .map_or_else( - || { - Some(ProviderReplayLogicalKey::SerializedTransaction { - slot: event.slot, - commitment_status: provider_replay_commitment_key(event.commitment_status), - confirmed_slot: event.confirmed_slot, - finalized_slot: event.finalized_slot, - fingerprint: provider_replay_fingerprint(&event.bytes), - }) - }, - |signature| { - Some(ProviderReplayLogicalKey::Transaction { - slot: event.slot, - signature, - commitment_status: provider_replay_commitment_key(event.commitment_status), - confirmed_slot: event.confirmed_slot, - finalized_slot: event.finalized_slot, - }) - }, - ), + ProviderStreamUpdate::SerializedTransaction(event) => event.signature.map_or_else( + || { + Some(ProviderReplayLogicalKey::SerializedTransaction { + slot: event.slot, + commitment_status: provider_replay_commitment_key(event.commitment_status), + confirmed_slot: event.confirmed_slot, + finalized_slot: event.finalized_slot, + fingerprint: provider_replay_fingerprint(&event.bytes), + }) + }, + |signature| { + Some(ProviderReplayLogicalKey::Transaction { + slot: event.slot, + signature, + commitment_status: provider_replay_commitment_key(event.commitment_status), + confirmed_slot: event.confirmed_slot, + finalized_slot: event.finalized_slot, + }) + }, + ), ProviderStreamUpdate::RecentBlockhash(event) => { Some(ProviderReplayLogicalKey::ControlPlane { slot: event.slot, @@ -2488,7 +2585,7 @@ fn dispatch_provider_stream_update( match update { ProviderStreamUpdate::Transaction(event) => { if plugin_host.wants_recent_blockhash() { - plugin_host.on_recent_blockhash(framework::ObservedRecentBlockhashEvent { + plugin_host.on_recent_blockhash(ObservedRecentBlockhashEvent { slot: event.slot, recent_blockhash: event.tx.message.recent_blockhash().to_bytes(), dataset_tx_count: 1, @@ -2600,9 +2697,26 @@ fn dispatch_provider_stream_serialized_transaction( let wants_recent_blockhash = plugin_host.wants_recent_blockhash(); let wants_derived_state_transaction = !derived_state_empty && derived_state_host.wants_transaction_applied(); + let needs_transaction_event = wants_transaction || wants_derived_state_transaction; if !wants_transaction && !wants_recent_blockhash && !wants_derived_state_transaction { return; } + if event.bytes.len() > MAX_PROVIDER_SERIALIZED_TRANSACTION_BYTES { + tracing::warn!( + slot = event.slot, + payload_len = event.bytes.len(), + "provider serialized transaction exceeds max wire size" + ); + return; + } + if wants_transaction + && !wants_recent_blockhash + && !wants_derived_state_transaction + && !plugin_host.has_transaction_prefilter_at_commitment(event.commitment_status) + { + dispatch_provider_stream_serialized_transaction_decode_only(plugin_host, event); + return; + } let mut signature = event.signature; let mut recent_blockhash = None; @@ -2614,16 +2728,6 @@ fn dispatch_provider_stream_serialized_transaction( if should_try_view && let Ok(view) = SanitizedTransactionView::try_new_sanitized(event.bytes.as_ref(), true) { - if signature.is_none() { - signature = view - .signatures() - .first() - .copied() - .map(framework::SignatureBytes::from_solana); - } - kind = Some(provider_stream::classify_provider_transaction_kind_view( - &view, - )); if wants_recent_blockhash { recent_blockhash = Some(view.recent_blockhash().to_bytes()); } @@ -2638,7 +2742,7 @@ fn dispatch_provider_stream_serialized_transaction( && !wants_derived_state_transaction { if let Some(recent_blockhash) = recent_blockhash { - plugin_host.on_recent_blockhash(framework::ObservedRecentBlockhashEvent { + plugin_host.on_recent_blockhash(ObservedRecentBlockhashEvent { slot: event.slot, recent_blockhash, dataset_tx_count: 1, @@ -2648,9 +2752,9 @@ fn dispatch_provider_stream_serialized_transaction( return; } } - if !wants_transaction && !wants_derived_state_transaction { + if !needs_transaction_event { if let Some(recent_blockhash) = recent_blockhash { - plugin_host.on_recent_blockhash(framework::ObservedRecentBlockhashEvent { + plugin_host.on_recent_blockhash(ObservedRecentBlockhashEvent { slot: event.slot, recent_blockhash, dataset_tx_count: 1, @@ -2659,6 +2763,18 @@ fn dispatch_provider_stream_serialized_transaction( } return; } + if wants_transaction || wants_derived_state_transaction { + if signature.is_none() { + signature = view + .signatures() + .first() + .copied() + .map(SignatureBytes::from_solana); + } + kind = Some(provider_stream::classify_provider_transaction_kind_view( + &view, + )); + } } let Ok(tx) = bincode::deserialize::(event.bytes.as_ref()) else { @@ -2669,6 +2785,16 @@ fn dispatch_provider_stream_serialized_transaction( return; }; let tx = Arc::new(tx); + if !needs_transaction_event { + plugin_host.on_recent_blockhash(ObservedRecentBlockhashEvent { + slot: event.slot, + recent_blockhash: recent_blockhash + .unwrap_or_else(|| tx.message.recent_blockhash().to_bytes()), + dataset_tx_count: 1, + provider_source: event.provider_source.clone(), + }); + return; + } let event = TransactionEvent { slot: event.slot, commitment_status: event.commitment_status, @@ -2678,14 +2804,14 @@ fn dispatch_provider_stream_serialized_transaction( tx.signatures .first() .copied() - .map(framework::SignatureBytes::from_solana) + .map(SignatureBytes::from_solana) }), provider_source: event.provider_source.clone(), kind: kind.unwrap_or_else(|| provider_stream::classify_provider_transaction_kind(&tx)), tx, }; if wants_recent_blockhash { - plugin_host.on_recent_blockhash(framework::ObservedRecentBlockhashEvent { + plugin_host.on_recent_blockhash(ObservedRecentBlockhashEvent { slot: event.slot, recent_blockhash: recent_blockhash .unwrap_or_else(|| event.tx.message.recent_blockhash().to_bytes()), @@ -2702,6 +2828,35 @@ fn dispatch_provider_stream_serialized_transaction( } } +fn dispatch_provider_stream_serialized_transaction_decode_only( + plugin_host: &PluginHost, + event: &provider_stream::SerializedTransactionEvent, +) { + let Ok(tx) = bincode::deserialize::(event.bytes.as_ref()) else { + tracing::warn!( + slot = event.slot, + "failed to deserialize provider serialized transaction" + ); + return; + }; + let tx = Arc::new(tx); + plugin_host.on_transaction(TransactionEvent { + slot: event.slot, + commitment_status: event.commitment_status, + confirmed_slot: event.confirmed_slot, + finalized_slot: event.finalized_slot, + signature: event.signature.or_else(|| { + tx.signatures + .first() + .copied() + .map(SignatureBytes::from_solana) + }), + provider_source: event.provider_source.clone(), + kind: provider_stream::classify_provider_transaction_kind(&tx), + tx, + }); +} + fn provider_stream_unsupported_hooks( mode: ProviderStreamMode, plugin_host: &PluginHost, @@ -3282,6 +3437,7 @@ pub async fn run_async_with_hosts_and_setup( mod tests { use std::{ collections::BTreeMap, + future::Future, sync::{ Arc, Mutex, atomic::{AtomicUsize, Ordering}, @@ -3299,19 +3455,13 @@ mod tests { }; use async_trait::async_trait; use sof_gossip_tuning::{GossipTuningProfile, HostProfilePreset, IngestQueueMode}; + use sof_support::bench::{avg_ns_per_iteration, profile_iterations}; use solana_keypair::Keypair; use solana_message::{Message, VersionedMessage}; + use solana_signature::Signature; use solana_signer::Signer; use solana_transaction::versioned::VersionedTransaction; - fn profile_iterations(default: usize) -> usize { - std::env::var("SOF_PROFILE_ITERATIONS") - .ok() - .and_then(|value| value.parse::().ok()) - .filter(|value| *value > 0) - .unwrap_or(default) - } - fn with_runtime_env_overrides( overrides: impl IntoIterator, f: impl FnOnce() -> T, @@ -3324,7 +3474,7 @@ mod tests { f: impl FnOnce() -> Fut, ) -> T where - Fut: std::future::Future, + Fut: Future, { runtime_env::with_runtime_env_overrides_for_test_async(overrides, f).await } @@ -3343,7 +3493,7 @@ mod tests { .signatures .first() .copied() - .map(framework::SignatureBytes::from_solana), + .map(SignatureBytes::from_solana), provider_source: None, kind: TxKind::NonVote, tx: Arc::new(tx), @@ -3358,6 +3508,10 @@ mod tests { struct TransactionOnlyPlugin; struct RecentBlockhashPlugin; + struct TransactionAndRecentBlockhashCounterPlugin { + transaction_count: Arc, + recent_blockhash_count: Arc, + } #[cfg(feature = "gossip-bootstrap")] struct ClusterTopologyOnlyPlugin; struct StartupCounterPlugin { @@ -3424,6 +3578,23 @@ mod tests { } } + #[async_trait] + impl ObserverPlugin for TransactionAndRecentBlockhashCounterPlugin { + fn config(&self) -> PluginConfig { + PluginConfig::new() + .with_transaction() + .with_recent_blockhash() + } + + async fn on_transaction(&self, _event: &TransactionEvent) { + self.transaction_count.fetch_add(1, Ordering::Relaxed); + } + + async fn on_recent_blockhash(&self, _event: ObservedRecentBlockhashEvent) { + self.recent_blockhash_count.fetch_add(1, Ordering::Relaxed); + } + } + #[async_trait] impl ObserverPlugin for StartupCounterPlugin { fn config(&self) -> PluginConfig { @@ -3762,7 +3933,7 @@ mod tests { ) { match update { ProviderStreamUpdate::Transaction(event) => { - plugin_host.on_recent_blockhash(framework::ObservedRecentBlockhashEvent { + plugin_host.on_recent_blockhash(ObservedRecentBlockhashEvent { slot: event.slot, recent_blockhash: event.tx.message.recent_blockhash().to_bytes(), dataset_tx_count: 1, @@ -3824,6 +3995,9 @@ mod tests { if !wants_transaction && !wants_recent_blockhash && !wants_derived_state_transaction { return; } + if event.bytes.len() > MAX_PROVIDER_SERIALIZED_TRANSACTION_BYTES { + return; + } let mut signature = event.signature; let mut recent_blockhash = None; @@ -3834,7 +4008,7 @@ mod tests { .signatures() .first() .copied() - .map(framework::SignatureBytes::from_solana); + .map(SignatureBytes::from_solana); } kind = Some(provider_stream::classify_provider_transaction_kind_view( &view, @@ -3853,7 +4027,7 @@ mod tests { && !wants_derived_state_transaction { if let Some(recent_blockhash) = recent_blockhash { - plugin_host.on_recent_blockhash(framework::ObservedRecentBlockhashEvent { + plugin_host.on_recent_blockhash(ObservedRecentBlockhashEvent { slot: event.slot, recent_blockhash, dataset_tx_count: 1, @@ -3865,7 +4039,7 @@ mod tests { } if !wants_transaction && !wants_derived_state_transaction { if let Some(recent_blockhash) = recent_blockhash { - plugin_host.on_recent_blockhash(framework::ObservedRecentBlockhashEvent { + plugin_host.on_recent_blockhash(ObservedRecentBlockhashEvent { slot: event.slot, recent_blockhash, dataset_tx_count: 1, @@ -3889,14 +4063,14 @@ mod tests { tx.signatures .first() .copied() - .map(framework::SignatureBytes::from_solana) + .map(SignatureBytes::from_solana) }), provider_source: event.provider_source.clone(), kind: kind.unwrap_or_else(|| provider_stream::classify_provider_transaction_kind(&tx)), tx, }; if wants_recent_blockhash { - plugin_host.on_recent_blockhash(framework::ObservedRecentBlockhashEvent { + plugin_host.on_recent_blockhash(ObservedRecentBlockhashEvent { slot: event.slot, recent_blockhash: recent_blockhash .unwrap_or_else(|| event.tx.message.recent_blockhash().to_bytes()), @@ -3992,7 +4166,7 @@ mod tests { match sample_provider_transaction_update() { ProviderStreamUpdate::Transaction(mut event) => { event.slot = slot; - event.signature = Some(framework::SignatureBytes::from_solana(Signature::from( + event.signature = Some(SignatureBytes::from_solana(Signature::from( [slot as u8; 64], ))); ProviderStreamUpdate::Transaction(event) @@ -4013,7 +4187,7 @@ mod tests { } fn sample_provider_recent_blockhash_update(slot: u64) -> ProviderStreamUpdate { - ProviderStreamUpdate::RecentBlockhash(framework::ObservedRecentBlockhashEvent { + ProviderStreamUpdate::RecentBlockhash(ObservedRecentBlockhashEvent { slot, recent_blockhash: [slot as u8; 32], dataset_tx_count: 1, @@ -4027,7 +4201,7 @@ mod tests { commitment_status: TxCommitmentStatus::Processed, confirmed_slot: None, finalized_slot: None, - signature: framework::SignatureBytes::from_solana(Signature::from([slot as u8; 64])), + signature: SignatureBytes::from_solana(Signature::from([slot as u8; 64])), is_vote: false, index: Some(0), err: None, @@ -4068,7 +4242,7 @@ mod tests { rent_epoch: 0, data: Arc::from([1_u8, 2, 3, 4]), write_version: Some(slot), - txn_signature: Some(framework::SignatureBytes::from_solana(Signature::from( + txn_signature: Some(SignatureBytes::from_solana(Signature::from( [slot as u8; 64], ))), is_startup: false, @@ -4081,7 +4255,7 @@ mod tests { ProviderStreamUpdate::TransactionLog(framework::TransactionLogEvent { slot, commitment_status: TxCommitmentStatus::Processed, - signature: framework::SignatureBytes::from_solana(Signature::from([slot as u8; 64])), + signature: SignatureBytes::from_solana(Signature::from([slot as u8; 64])), err: None, logs: Arc::from([String::from("program log: hello")]), matched_filter: None, @@ -4914,6 +5088,34 @@ mod tests { assert!(dedupe.observe(&update)); } + #[test] + fn oversized_serialized_provider_transaction_is_dropped() { + let transaction_count = Arc::new(AtomicUsize::new(0)); + let recent_blockhash_count = Arc::new(AtomicUsize::new(0)); + let plugin_host = PluginHost::builder() + .add_plugin(TransactionAndRecentBlockhashCounterPlugin { + transaction_count: transaction_count.clone(), + recent_blockhash_count: recent_blockhash_count.clone(), + }) + .build(); + let derived_state_host = DerivedStateHost::builder().build(); + let ProviderStreamUpdate::SerializedTransaction(mut event) = + sample_serialized_provider_transaction_update() + else { + panic!("expected serialized update fixture"); + }; + event.bytes = vec![0_u8; MAX_PROVIDER_SERIALIZED_TRANSACTION_BYTES + 1].into_boxed_slice(); + + dispatch_provider_stream_update( + &plugin_host, + &derived_state_host, + ProviderStreamUpdate::SerializedTransaction(event), + ); + + assert_eq!(transaction_count.load(Ordering::Relaxed), 0); + assert_eq!(recent_blockhash_count.load(Ordering::Relaxed), 0); + } + #[test] fn provider_replay_dedupe_keeps_higher_commitment_transaction_update() { let initial = sample_provider_transaction_update(); @@ -5735,4 +5937,102 @@ mod tests { optimized_elapsed.as_micros(), ); } + + #[test] + #[ignore = "profiling fixture for provider serialized tx recent-blockhash-only path"] + fn provider_stream_serialized_transaction_recent_blockhash_profile_fixture() { + let iterations = profile_iterations(500_000); + let plugin_host = PluginHost::builder() + .add_plugin(RecentBlockhashPlugin) + .build(); + let derived_state_host = DerivedStateHost::builder().build(); + let baseline_update = sample_serialized_provider_transaction_update(); + let optimized_update = baseline_update.clone(); + + let baseline_started = Instant::now(); + for _ in 0..iterations { + let ProviderStreamUpdate::SerializedTransaction(event) = baseline_update.clone() else { + panic!("expected serialized update fixture"); + }; + dispatch_provider_stream_serialized_transaction_baseline( + &plugin_host, + &derived_state_host, + &event, + ); + } + let baseline_elapsed = baseline_started.elapsed(); + + let optimized_started = Instant::now(); + for _ in 0..iterations { + dispatch_provider_stream_update( + &plugin_host, + &derived_state_host, + optimized_update.clone(), + ); + } + let optimized_elapsed = optimized_started.elapsed(); + + eprintln!( + "provider_stream_serialized_transaction_recent_blockhash_profile_fixture iterations={} baseline_us={} optimized_us={}", + iterations, + baseline_elapsed.as_micros(), + optimized_elapsed.as_micros(), + ); + } + + #[test] + #[ignore = "profiling fixture for provider replay dedupe eviction churn"] + fn provider_replay_dedupe_eviction_profile_fixture() { + let iterations = profile_iterations(500_000); + let source_a = provider_stream::ProviderSourceIdentity::new( + provider_stream::ProviderSourceId::Generic("source-a".to_owned().into()), + "primary", + ) + .with_arbitration(provider_stream::ProviderSourceArbitrationMode::FirstSeenThenPromote) + .with_priority(100); + let source_b = provider_stream::ProviderSourceIdentity::new( + provider_stream::ProviderSourceId::Generic("source-b".to_owned().into()), + "secondary", + ) + .with_arbitration(provider_stream::ProviderSourceArbitrationMode::FirstSeenThenPromote) + .with_priority(300); + let updates: Vec<_> = (0..4096_u64) + .flat_map(|slot| { + let update = sample_provider_transaction_update_at(slot); + [ + update.clone().with_provider_source(source_a.clone()), + update.with_provider_source(source_b.clone()), + ] + }) + .collect(); + let mut baseline = ProviderReplayDedupe::new(1024); + let mut optimized = ProviderReplayDedupe::new(1024); + + let baseline_started = Instant::now(); + for i in 0..iterations { + let update = &updates[i % updates.len()]; + baseline.observe_baseline(update); + } + let baseline_elapsed = baseline_started.elapsed(); + + let optimized_started = Instant::now(); + for i in 0..iterations { + let update = &updates[i % updates.len()]; + optimized.observe(update); + } + let optimized_elapsed = optimized_started.elapsed(); + let baseline_avg_ns = avg_ns_per_iteration(baseline_elapsed, iterations); + let optimized_avg_ns = avg_ns_per_iteration(optimized_elapsed, iterations); + + eprintln!( + "provider_replay_dedupe_eviction_profile_fixture iterations={} baseline_us={} optimized_us={} baseline_avg_ns_per_iteration={} optimized_avg_ns_per_iteration={} baseline_avg_us_per_iteration={:.3} optimized_avg_us_per_iteration={:.3}", + iterations, + baseline_elapsed.as_micros(), + optimized_elapsed.as_micros(), + baseline_avg_ns, + optimized_avg_ns, + baseline_avg_ns as f64 / 1_000.0, + optimized_avg_ns as f64 / 1_000.0, + ); + } } diff --git a/crates/sof-observer/src/runtime_env.rs b/crates/sof-observer/src/runtime_env.rs index 117cfbe2..6ada52d7 100644 --- a/crates/sof-observer/src/runtime_env.rs +++ b/crates/sof-observer/src/runtime_env.rs @@ -1,5 +1,8 @@ -use std::collections::HashMap; -use std::sync::{OnceLock, RwLock}; +use std::{ + collections::HashMap, + env, + sync::{OnceLock, RwLock}, +}; /// Global runtime setup overrides used to avoid mutating process env at runtime. static ENV_OVERRIDES: OnceLock>> = OnceLock::new(); @@ -16,7 +19,7 @@ pub(crate) fn read_env_var(name: &str) -> Option { { return Some(value.clone()); } - std::env::var(name).ok() + env::var(name).ok() } /// Replaces all runtime setup environment overrides. @@ -89,9 +92,11 @@ where #[cfg(test)] mod tests { use super::*; - use std::sync::mpsc; - use std::thread; - use std::time::Duration; + use std::{ + sync::{Arc, mpsc}, + thread, + time::Duration, + }; #[tokio::test] async fn sync_and_async_override_helpers_share_the_same_lock() { @@ -110,10 +115,10 @@ mod tests { .recv() .expect("guard-ready signal should arrive"); - let started = std::sync::Arc::new(tokio::sync::Notify::new()); - let finished = std::sync::Arc::new(tokio::sync::Notify::new()); - let started_wait = std::sync::Arc::clone(&started); - let finished_wait = std::sync::Arc::clone(&finished); + let started = Arc::new(tokio::sync::Notify::new()); + let finished = Arc::new(tokio::sync::Notify::new()); + let started_wait = Arc::clone(&started); + let finished_wait = Arc::clone(&finished); let waiter = tokio::spawn(async move { started_wait.notify_one(); diff --git a/crates/sof-observer/src/shred/fec/core.rs b/crates/sof-observer/src/shred/fec/core.rs index 6ad68f87..7f5fee9b 100644 --- a/crates/sof-observer/src/shred/fec/core.rs +++ b/crates/sof-observer/src/shred/fec/core.rs @@ -1,8 +1,14 @@ -use std::collections::{HashMap, hash_map::Entry}; +use std::{ + collections::{HashMap as StdHashMap, hash_map::Entry}, + sync::Arc, +}; +use ahash::RandomState; use reed_solomon_erasure::galois_8::ReedSolomon; -use crate::shred::wire::{ParsedShredHeader, ShredVariant}; +use crate::shred::wire::{ + ParsedShredHeader, SIZE_OF_CODING_SHRED_HEADERS, SIZE_OF_CODING_SHRED_PAYLOAD, ShredVariant, +}; #[path = "recover.rs"] mod recover; @@ -12,10 +18,15 @@ use recover::{RecoveredDataPacket, parse_packet_signature, recover_missing_data} const SIZE_OF_SIGNATURE: usize = 64; const SIZE_OF_MERKLE_ROOT: usize = 32; const SIZE_OF_MERKLE_PROOF_ENTRY: usize = 20; +const INITIAL_DATA_SHARD_CAPACITY: usize = 2; +const INITIAL_CODING_SHARD_CAPACITY: usize = 1; + +type HashMap = StdHashMap; pub struct FecRecoverer { sets: HashMap<(u64, u32), ErasureSet>, reed_solomon_cache: HashMap<(usize, usize), ReedSolomon>, + recovery_scratch: RecoveryScratch, max_tracked_sets: usize, retained_slot_lag: u64, last_pruned_floor: u64, @@ -26,11 +37,26 @@ struct ErasureSet { config: Option, config_fec_set_index: Option, leader_signature: [u8; SIZE_OF_SIGNATURE], - data_shards: HashMap>, + data_shards: HashMap, coding_shards: HashMap>, present_data_shreds_in_config: usize, } +#[derive(Default)] +struct RecoveryScratch { + shards: Vec>>, + data_present: Vec, +} + +enum StoredShard { + Borrowed { + packet: Arc<[u8]>, + offset: usize, + len: usize, + }, + Owned(Vec), +} + #[derive(Clone, Copy, Debug, Eq, PartialEq)] struct SetVariant { proof_size: u8, @@ -50,8 +76,8 @@ impl ErasureSet { config: None, config_fec_set_index: None, leader_signature, - data_shards: HashMap::new(), - coding_shards: HashMap::new(), + data_shards: fast_hash_map_with_capacity(INITIAL_DATA_SHARD_CAPACITY), + coding_shards: fast_hash_map_with_capacity(INITIAL_CODING_SHARD_CAPACITY), present_data_shreds_in_config: 0, } } @@ -61,8 +87,9 @@ impl FecRecoverer { #[must_use] pub fn new(max_tracked_sets: usize, retained_slot_lag: u64) -> Self { Self { - sets: HashMap::new(), - reed_solomon_cache: HashMap::new(), + sets: HashMap::default(), + reed_solomon_cache: HashMap::default(), + recovery_scratch: RecoveryScratch::default(), max_tracked_sets, retained_slot_lag: retained_slot_lag.max(1), last_pruned_floor: 0, @@ -71,10 +98,10 @@ impl FecRecoverer { pub fn ingest_packet( &mut self, - packet: &[u8], + packet: &Arc<[u8]>, parsed: &ParsedShredHeader, ) -> Vec { - let signature = match parse_packet_signature(packet) { + let signature = match parse_packet_signature(packet.as_ref()) { Some(signature) => signature, None => return Vec::new(), }; @@ -103,8 +130,13 @@ impl FecRecoverer { return Vec::new(); } set.ingest_packet(parsed, packet); - recovered = recover_missing_data(set, fec_set_index, &mut self.reed_solomon_cache) - .unwrap_or_default(); + recovered = recover_missing_data( + set, + fec_set_index, + &mut self.reed_solomon_cache, + &mut self.recovery_scratch, + ) + .unwrap_or_default(); should_remove = set.is_data_complete_for_config(fec_set_index); } else { let mut new_set = ErasureSet::new(signature); @@ -149,12 +181,16 @@ impl FecRecoverer { } } +fn fast_hash_map_with_capacity(capacity: usize) -> HashMap { + HashMap::with_capacity_and_hasher(capacity, RandomState::default()) +} + impl ErasureSet { fn accepts_variant(&self, incoming: SetVariant) -> bool { self.variant.is_none_or(|existing| existing == incoming) } - fn ingest_packet(&mut self, parsed: &ParsedShredHeader, packet: &[u8]) { + fn ingest_packet(&mut self, parsed: &ParsedShredHeader, packet: &Arc<[u8]>) { let common_variant = match parsed { ParsedShredHeader::Data(data) => data.common.shred_variant, ParsedShredHeader::Code(code) => code.common.shred_variant, @@ -168,21 +204,19 @@ impl ErasureSet { match parsed { ParsedShredHeader::Data(data) => { + let Entry::Vacant(vacant) = self.data_shards.entry(data.common.index) else { + return; + }; let Some(shard) = extract_data_erasure_shard(packet, shard_len) else { return; }; - if let Entry::Vacant(vacant) = self.data_shards.entry(data.common.index) { - let _ = vacant.insert(shard); - if self.index_within_config(data.common.index, self.config_fec_set_index) { - self.present_data_shreds_in_config = - self.present_data_shreds_in_config.saturating_add(1); - } + let _ = vacant.insert(shard); + if self.index_within_config(data.common.index, self.config_fec_set_index) { + self.present_data_shreds_in_config = + self.present_data_shreds_in_config.saturating_add(1); } } ParsedShredHeader::Code(code) => { - let Some(shard) = extract_coding_erasure_shard(packet, shard_len) else { - return; - }; let incoming_config = ErasureConfig { num_data: usize::from(code.coding_header.num_data_shreds), num_coding: usize::from(code.coding_header.num_coding_shreds), @@ -192,22 +226,45 @@ impl ErasureSet { { return; } - if let Entry::Vacant(vacant) = self.coding_shards.entry(code.coding_header.position) - { - let _ = vacant.insert(shard); - } if self.config != Some(incoming_config) || self.config_fec_set_index != Some(code.common.fec_set_index) { - self.config = Some(incoming_config); - self.config_fec_set_index = Some(code.common.fec_set_index); - self.present_data_shreds_in_config = - self.count_present_data_shreds_in_config(code.common.fec_set_index); + self.apply_config(code.common.fec_set_index, incoming_config); } + let Entry::Vacant(vacant) = self.coding_shards.entry(code.coding_header.position) + else { + return; + }; + let Some(shard) = extract_coding_erasure_shard(packet, shard_len) else { + return; + }; + let _ = vacant.insert(shard); } } } + fn apply_config(&mut self, fec_set_index: u32, config: ErasureConfig) { + self.config = Some(config); + self.config_fec_set_index = Some(fec_set_index); + self.reserve_for_config(config); + self.present_data_shreds_in_config = + self.count_present_data_shreds_in_config(fec_set_index); + } + + fn reserve_for_config(&mut self, config: ErasureConfig) { + let data_missing_capacity = config.num_data.saturating_sub(self.data_shards.capacity()); + if data_missing_capacity > 0 { + self.data_shards.reserve(data_missing_capacity); + } + + let coding_missing_capacity = config + .num_coding + .saturating_sub(self.coding_shards.capacity()); + if coding_missing_capacity > 0 { + self.coding_shards.reserve(coding_missing_capacity); + } + } + fn is_data_complete_for_config(&self, fec_set_index: u32) -> bool { let Some(config) = self.config else { return false; @@ -250,7 +307,7 @@ impl ErasureSet { recovered_shard: Vec, ) -> bool { if let Entry::Vacant(vacant) = self.data_shards.entry(index) { - let _ = vacant.insert(recovered_shard); + let _ = vacant.insert(StoredShard::Owned(recovered_shard)); if self.index_within_config(index, Some(fec_set_index)) { self.present_data_shreds_in_config = self.present_data_shreds_in_config.saturating_add(1); @@ -261,6 +318,31 @@ impl ErasureSet { } } +impl RecoveryScratch { + fn prepare(&mut self, total: usize, data_count: usize) { + self.shards.clear(); + self.shards.resize_with(total, || None); + self.data_present.clear(); + self.data_present.resize(data_count, false); + } +} + +impl StoredShard { + fn to_owned_vec(&self) -> Option> { + match self { + Self::Borrowed { + packet, + offset, + len, + } => { + let end = offset.checked_add(*len)?; + Some(packet.get(*offset..end)?.to_vec()) + } + Self::Owned(bytes) => Some(bytes.clone()), + } + } +} + impl From for SetVariant { fn from(value: ShredVariant) -> Self { Self { @@ -270,16 +352,25 @@ impl From for SetVariant { } } -fn extract_data_erasure_shard(packet: &[u8], shard_len: usize) -> Option> { +fn extract_data_erasure_shard(packet: &Arc<[u8]>, shard_len: usize) -> Option { let start = SIZE_OF_SIGNATURE; - let end = start.checked_add(shard_len)?; - packet.get(start..end).map(ToOwned::to_owned) + shard_from_packet(packet, start, shard_len) } -fn extract_coding_erasure_shard(packet: &[u8], shard_len: usize) -> Option> { - let start = crate::shred::wire::SIZE_OF_CODING_SHRED_HEADERS; +fn extract_coding_erasure_shard(packet: &Arc<[u8]>, shard_len: usize) -> Option> { + let start = SIZE_OF_CODING_SHRED_HEADERS; let end = start.checked_add(shard_len)?; - packet.get(start..end).map(ToOwned::to_owned) + Some(packet.get(start..end)?.to_vec()) +} + +fn shard_from_packet(packet: &Arc<[u8]>, start: usize, len: usize) -> Option { + let end = start.checked_add(len)?; + let _ = packet.get(start..end)?; + Some(StoredShard::Borrowed { + packet: Arc::clone(packet), + offset: start, + len, + }) } fn coding_erasure_shard_len(variant: SetVariant) -> Option { @@ -293,8 +384,7 @@ fn coding_erasure_shard_len(variant: SetVariant) -> Option { } else { 0 })?; - crate::shred::wire::SIZE_OF_CODING_SHRED_PAYLOAD - .checked_sub(crate::shred::wire::SIZE_OF_CODING_SHRED_HEADERS.checked_add(trailer)?) + SIZE_OF_CODING_SHRED_PAYLOAD.checked_sub(SIZE_OF_CODING_SHRED_HEADERS.checked_add(trailer)?) } #[cfg(test)] @@ -324,8 +414,8 @@ mod tests { #[test] fn data_completeness_tracks_in_range_count_once_config_is_known() { let mut set = ErasureSet::new([0; SIZE_OF_SIGNATURE]); - let _ = set.data_shards.insert(10, vec![1]); - let _ = set.data_shards.insert(11, vec![2]); + let _ = set.data_shards.insert(10, StoredShard::Owned(vec![1])); + let _ = set.data_shards.insert(11, StoredShard::Owned(vec![2])); set.config = Some(ErasureConfig { num_data: 2, @@ -337,8 +427,12 @@ mod tests { assert!(set.is_data_complete_for_config(10)); let mut incomplete_set = ErasureSet::new([0; SIZE_OF_SIGNATURE]); - let _ = incomplete_set.data_shards.insert(10, vec![1]); - let _ = incomplete_set.data_shards.insert(12, vec![2]); + let _ = incomplete_set + .data_shards + .insert(10, StoredShard::Owned(vec![1])); + let _ = incomplete_set + .data_shards + .insert(12, StoredShard::Owned(vec![2])); incomplete_set.config = Some(ErasureConfig { num_data: 2, num_coding: 1, @@ -349,4 +443,20 @@ mod tests { assert!(!incomplete_set.is_data_complete_for_config(10)); } + + #[test] + fn applying_config_reserves_shard_capacity() { + let mut set = ErasureSet::new([0; SIZE_OF_SIGNATURE]); + + set.apply_config( + 10, + ErasureConfig { + num_data: 16, + num_coding: 8, + }, + ); + + assert!(set.data_shards.capacity() > INITIAL_DATA_SHARD_CAPACITY); + assert!(set.coding_shards.capacity() > INITIAL_CODING_SHARD_CAPACITY); + } } diff --git a/crates/sof-observer/src/shred/fec/recover.rs b/crates/sof-observer/src/shred/fec/recover.rs index 6f36c4f9..3bb8a249 100644 --- a/crates/sof-observer/src/shred/fec/recover.rs +++ b/crates/sof-observer/src/shred/fec/recover.rs @@ -13,6 +13,7 @@ pub(super) fn recover_missing_data( set: &mut ErasureSet, fec_set_index: u32, reed_solomon_cache: &mut HashMap<(usize, usize), ReedSolomon>, + recovery_scratch: &mut RecoveryScratch, ) -> Option> { let config = set.config?; let variant = set.variant?; @@ -22,9 +23,10 @@ pub(super) fn recover_missing_data( let _ = coding_erasure_shard_len(variant)?; let total = config.num_data.checked_add(config.num_coding)?; - let mut shards: Vec>> = vec![None; total]; + recovery_scratch.prepare(total, config.num_data); + let shards = &mut recovery_scratch.shards; let mut present = 0_usize; - let mut data_present = vec![false; config.num_data]; + let data_present = &mut recovery_scratch.data_present; for (&index, shard) in &set.data_shards { let Some(position) = index.checked_sub(fec_set_index) else { @@ -45,7 +47,10 @@ pub(super) fn recover_missing_data( continue; } if let Some(shard_slot) = shards.get_mut(position) { - *shard_slot = Some(shard.clone()); + let Some(bytes) = shard.to_owned_vec() else { + continue; + }; + *shard_slot = Some(bytes); } else { continue; } @@ -82,12 +87,12 @@ pub(super) fn recover_missing_data( let _ = vacant.insert(reed_solomon); } let reed_solomon = reed_solomon_cache.get(&key)?; - if reed_solomon.reconstruct(&mut shards).is_err() { + if reed_solomon.reconstruct(shards).is_err() { return None; } let mut recovered_payloads = Vec::new(); - for (position, was_present) in data_present.into_iter().enumerate() { + for (position, was_present) in data_present.iter().copied().enumerate() { if was_present { continue; } diff --git a/crates/sof-observer/src/shred/wire/parser.rs b/crates/sof-observer/src/shred/wire/parser.rs index 45f1061e..95a800cc 100644 --- a/crates/sof-observer/src/shred/wire/parser.rs +++ b/crates/sof-observer/src/shred/wire/parser.rs @@ -70,12 +70,6 @@ fn parse_data_shred_header( let flags = read_u8(packet, OFFSET_FLAGS)?; let payload_offset = super::SIZE_OF_DATA_SHRED_HEADERS; let payload_len = declared_size.saturating_sub(payload_offset); - if packet.get(payload_offset..declared_size).is_none() { - return Err(ParseError::PacketTooShort { - actual: packet.len(), - minimum: declared_size, - }); - } Ok(ParsedDataShredHeader { common, @@ -117,14 +111,14 @@ fn max_data_shred_size(variant: ShredVariant) -> Option { return None; } let proof = usize::from(variant.proof_size); - let trailer = SIZE_OF_MERKLE_ROOT - .checked_add(proof.checked_mul(SIZE_OF_MERKLE_PROOF_ENTRY)?)? - .checked_add(if variant.resigned { - SIZE_OF_SIGNATURE - } else { - 0 - })?; - let capacity = super::SIZE_OF_DATA_SHRED_PAYLOAD - .checked_sub(super::SIZE_OF_DATA_SHRED_HEADERS.checked_add(trailer)?)?; - super::SIZE_OF_DATA_SHRED_HEADERS.checked_add(capacity) + let proof_bytes = proof.checked_mul(SIZE_OF_MERKLE_PROOF_ENTRY)?; + let trailer = + SIZE_OF_MERKLE_ROOT + .checked_add(proof_bytes)? + .checked_add(if variant.resigned { + SIZE_OF_SIGNATURE + } else { + 0 + })?; + super::SIZE_OF_DATA_SHRED_PAYLOAD.checked_sub(trailer) } diff --git a/crates/sof-observer/src/verify/core/cache.rs b/crates/sof-observer/src/verify/core/cache.rs index 34594631..c472bfaf 100644 --- a/crates/sof-observer/src/verify/core/cache.rs +++ b/crates/sof-observer/src/verify/core/cache.rs @@ -14,22 +14,31 @@ pub(super) enum SignatureCacheEntry { #[derive(Debug)] pub(super) struct SignatureCache { - map: HashMap<[u8; SIZE_OF_SIGNATURE], SignatureCacheEntry>, - order: VecDeque<[u8; SIZE_OF_SIGNATURE]>, + map: HashMap<[u8; SIZE_OF_SIGNATURE], SignatureCacheRecord>, + order: VecDeque<([u8; SIZE_OF_SIGNATURE], u64)>, capacity: usize, + next_generation: u64, +} + +#[derive(Debug, Clone, Copy)] +struct SignatureCacheRecord { + entry: SignatureCacheEntry, + generation: u64, } impl SignatureCache { pub(super) fn new(capacity: usize) -> Self { + let capacity = capacity.max(1); Self { - map: HashMap::new(), - order: VecDeque::new(), + map: HashMap::with_capacity(capacity), + order: VecDeque::with_capacity(capacity), capacity, + next_generation: 0, } } pub(super) fn get(&self, signature: &[u8; SIZE_OF_SIGNATURE]) -> Option { - self.map.get(signature).copied() + self.map.get(signature).map(|record| record.entry) } pub(super) fn remove(&mut self, signature: &[u8; SIZE_OF_SIGNATURE]) { @@ -41,15 +50,109 @@ impl SignatureCache { signature: [u8; SIZE_OF_SIGNATURE], value: SignatureCacheEntry, ) { - if !self.map.contains_key(&signature) { - self.order.push_back(signature); - } - let _ = self.map.insert(signature, value); + let generation = self.next_generation; + self.next_generation = self.next_generation.wrapping_add(1); + self.order.push_back((signature, generation)); + let _ = self.map.insert( + signature, + SignatureCacheRecord { + entry: value, + generation, + }, + ); while self.map.len() > self.capacity { - let Some(oldest) = self.order.pop_front() else { + let Some((oldest, queued_generation)) = self.order.pop_front() else { break; }; - let _ = self.map.remove(&oldest); + if matches!( + self.map.get(&oldest), + Some(record) if record.generation == queued_generation + ) { + let _ = self.map.remove(&oldest); + } + } + } +} + +#[cfg(test)] +mod tests { + use std::{ + collections::{HashMap, VecDeque}, + hint::black_box, + time::Instant, + }; + + use sof_support::bench::profile_iterations; + + use super::{SIZE_OF_SIGNATURE, SignatureCache, SignatureCacheEntry}; + + #[test] + fn stale_queue_entries_do_not_evict_reinserted_signature() { + let mut cache = SignatureCache::new(2); + let first = [1_u8; 64]; + let second = [2_u8; 64]; + let third = [3_u8; 64]; + let old_at = Instant::now(); + let refreshed_at = Instant::now(); + + cache.insert(first, SignatureCacheEntry::Unknown(old_at)); + cache.insert(second, SignatureCacheEntry::Unknown(Instant::now())); + cache.remove(&first); + cache.insert(first, SignatureCacheEntry::Unknown(refreshed_at)); + cache.insert(third, SignatureCacheEntry::Unknown(Instant::now())); + + assert!(matches!( + cache.get(&first), + Some(SignatureCacheEntry::Unknown(value)) if value == refreshed_at + )); + assert!(cache.get(&second).is_none()); + assert!(cache.get(&third).is_some()); + } + + #[test] + #[ignore = "profiling fixture for signature cache insert churn"] + fn signature_cache_insert_profile_fixture() { + let iterations = profile_iterations(200_000); + let mut baseline = signature_cache_baseline(4096); + let mut optimized = SignatureCache::new(4096); + let started_at = Instant::now(); + + let baseline_started = Instant::now(); + for i in 0..iterations { + let signature = make_signature(i); + baseline.insert(signature, SignatureCacheEntry::Unknown(started_at)); + black_box(()); + } + let baseline_elapsed = baseline_started.elapsed(); + + let optimized_started = Instant::now(); + for i in 0..iterations { + let signature = make_signature(i); + optimized.insert(signature, SignatureCacheEntry::Unknown(started_at)); + black_box(()); + } + let optimized_elapsed = optimized_started.elapsed(); + + eprintln!( + "signature_cache_insert_profile_fixture iterations={} baseline_us={} optimized_us={}", + iterations, + baseline_elapsed.as_micros(), + optimized_elapsed.as_micros(), + ); + } + + fn signature_cache_baseline(capacity: usize) -> SignatureCache { + SignatureCache { + map: HashMap::new(), + order: VecDeque::new(), + capacity: capacity.max(1), + next_generation: 0, } } + + fn make_signature(iteration: usize) -> [u8; SIZE_OF_SIGNATURE] { + let mut signature = [0_u8; SIZE_OF_SIGNATURE]; + signature[..8].copy_from_slice(&u64::try_from(iteration).unwrap_or(0).to_le_bytes()); + signature + } } diff --git a/crates/sof-observer/src/verify/core/tests.rs b/crates/sof-observer/src/verify/core/tests.rs index 2388f7c2..02c3baa0 100644 --- a/crates/sof-observer/src/verify/core/tests.rs +++ b/crates/sof-observer/src/verify/core/tests.rs @@ -1,4 +1,8 @@ -use std::time::Duration; +use std::{ + env, + hint::black_box, + time::{Duration, Instant}, +}; use solana_keypair::Keypair; use solana_signer::Signer; @@ -127,18 +131,18 @@ fn unknown_slot_retry_short_circuits_distinct_signatures_in_same_slot() { #[test] #[ignore = "profiling fixture for unknown-slot verifier backoff"] fn unknown_slot_retry_profile_fixture() { - let iterations = std::env::var("SOF_VERIFY_UNKNOWN_SLOT_PROFILE_ITERS") + let iterations = env::var("SOF_VERIFY_UNKNOWN_SLOT_PROFILE_ITERS") .ok() .and_then(|value| value.parse::().ok()) .filter(|value| *value > 0) .unwrap_or(200_000); - let strict_unknown = std::env::var("SOF_VERIFY_STRICT_UNKNOWN_PROFILE") + let strict_unknown = env::var("SOF_VERIFY_STRICT_UNKNOWN_PROFILE") .ok() .map(|value| matches!(value.as_str(), "1" | "true" | "TRUE" | "yes" | "YES")) .unwrap_or(false); - let now = std::time::Instant::now(); + let now = Instant::now(); let mut verifier = ShredVerifier::new(1024, 256, Duration::from_secs(5)); - let started_at = std::time::Instant::now(); + let started_at = Instant::now(); for i in 0..iterations { let mut packet = build_data_packet(11, u32::try_from(i & 0xffff).unwrap_or(u32::MAX), 11); @@ -157,20 +161,20 @@ fn unknown_slot_retry_profile_fixture() { #[test] #[ignore = "profiling fixture for strict-unknown verifier short-circuit"] fn strict_unknown_known_pubkey_profile_fixture() { - let iterations = std::env::var("SOF_VERIFY_STRICT_KNOWN_PUBKEY_PROFILE_ITERS") + let iterations = env::var("SOF_VERIFY_STRICT_KNOWN_PUBKEY_PROFILE_ITERS") .ok() .and_then(|value| value.parse::().ok()) .filter(|value| *value > 0) .unwrap_or(50_000); - let strict_unknown = std::env::var("SOF_VERIFY_STRICT_UNKNOWN_PROFILE") + let strict_unknown = env::var("SOF_VERIFY_STRICT_UNKNOWN_PROFILE") .ok() .map(|value| matches!(value.as_str(), "1" | "true" | "TRUE" | "yes" | "YES")) .unwrap_or(false); let keypair = Keypair::new(); let mut verifier = ShredVerifier::new(1024, 256, Duration::from_secs(5)); verifier.set_known_pubkeys(vec![keypair.pubkey().to_bytes()]); - let now = std::time::Instant::now(); - let started_at = std::time::Instant::now(); + let now = Instant::now(); + let started_at = Instant::now(); for i in 0..iterations { let slot = 100_000_u64.saturating_add(u64::try_from(i).unwrap_or(u64::MAX)); @@ -187,6 +191,44 @@ fn strict_unknown_known_pubkey_profile_fixture() { ); } +#[test] +#[ignore = "profiling fixture for verifier slot-state allocation churn"] +fn verifier_slot_state_allocation_profile_fixture() { + let iterations = env::var("SOF_VERIFY_UNKNOWN_SLOT_PROFILE_ITERS") + .ok() + .and_then(|value| value.parse::().ok()) + .filter(|value| *value > 0) + .unwrap_or(50_000); + let now = Instant::now(); + let mut baseline = ShredVerifier::new_baseline(1024, 256, Duration::from_secs(5)); + let mut optimized = ShredVerifier::new(1024, 256, Duration::from_secs(5)); + + let baseline_started = Instant::now(); + for i in 0..iterations { + let slot = 500_000_u64.saturating_add(u64::try_from(i).unwrap_or(u64::MAX)); + let mut packet = build_data_packet(slot, 1, 1); + packet[..SIZE_OF_SIGNATURE].fill((i & 0xff) as u8); + black_box(baseline.verify_packet(&packet, now, true)); + } + let baseline_elapsed = baseline_started.elapsed(); + + let optimized_started = Instant::now(); + for i in 0..iterations { + let slot = 500_000_u64.saturating_add(u64::try_from(i).unwrap_or(u64::MAX)); + let mut packet = build_data_packet(slot, 1, 1); + packet[..SIZE_OF_SIGNATURE].fill((i & 0xff) as u8); + black_box(optimized.verify_packet(&packet, now, true)); + } + let optimized_elapsed = optimized_started.elapsed(); + + println!( + "verifier_slot_state_allocation_profile_fixture iterations={} baseline_us={} optimized_us={}", + iterations, + baseline_elapsed.as_micros(), + optimized_elapsed.as_micros() + ); +} + fn build_data_packet(slot: u64, index: u32, fec_set_index: u32) -> Vec { let mut packet = vec![0_u8; SIZE_OF_DATA_SHRED_PAYLOAD]; packet[OFFSET_SHRED_VARIANT] = VARIANT_MERKLE_DATA; // data, proof_size=0 @@ -227,7 +269,7 @@ fn strict_unknown_short_circuits_before_pubkey_probe() { verifier.set_known_pubkeys(vec![keypair.pubkey().to_bytes()]); assert_eq!( - verifier.verify_packet(&packet, std::time::Instant::now(), true), + verifier.verify_packet(&packet, Instant::now(), true), VerifyStatus::UnknownLeader ); } diff --git a/crates/sof-observer/src/verify/core/verifier.rs b/crates/sof-observer/src/verify/core/verifier.rs index 0965f7e2..7e23c691 100644 --- a/crates/sof-observer/src/verify/core/verifier.rs +++ b/crates/sof-observer/src/verify/core/verifier.rs @@ -51,6 +51,32 @@ impl ShredVerifier { signature_cache_capacity: usize, slot_leader_window: u64, unknown_retry: Duration, + ) -> Self { + let slot_state_capacity = usize::try_from(slot_leader_window) + .unwrap_or(usize::MAX) + .saturating_add(1) + .min(65_536); + Self { + known_pubkey_verifiers: Vec::new(), + slot_leaders: HashMap::with_capacity(slot_state_capacity), + unknown_slots: HashMap::with_capacity(slot_state_capacity), + pending_added_slot_leaders: HashMap::with_capacity(slot_state_capacity), + pending_updated_slot_leaders: HashMap::with_capacity(slot_state_capacity), + pending_removed_slots: HashSet::with_capacity(slot_state_capacity), + latest_slot: 0, + has_latest_slot: false, + slot_leader_window, + signature_cache: SignatureCache::new(signature_cache_capacity), + unknown_retry, + } + } + + #[cfg(test)] + #[must_use] + pub(crate) fn new_baseline( + signature_cache_capacity: usize, + slot_leader_window: u64, + unknown_retry: Duration, ) -> Self { Self { known_pubkey_verifiers: Vec::new(), @@ -70,8 +96,12 @@ impl ShredVerifier { pub fn set_known_pubkeys(&mut self, mut pubkeys: Vec<[u8; 32]>) { pubkeys.sort_unstable(); pubkeys.dedup(); + self.set_known_pubkeys_sorted(&pubkeys); + } + + pub(crate) fn set_known_pubkeys_sorted(&mut self, pubkeys: &[[u8; 32]]) { let mut known_pubkey_verifiers = Vec::with_capacity(pubkeys.len()); - for pubkey in pubkeys { + for &pubkey in pubkeys { let Ok(verifying_key) = VerifyingKey::from_bytes(&pubkey) else { continue; }; diff --git a/crates/sof-observer/tests/derived_state_runtime_restart_e2e.rs b/crates/sof-observer/tests/derived_state_runtime_restart_e2e.rs index 55e4d9df..25efc7f8 100644 --- a/crates/sof-observer/tests/derived_state_runtime_restart_e2e.rs +++ b/crates/sof-observer/tests/derived_state_runtime_restart_e2e.rs @@ -3,10 +3,11 @@ #![cfg(feature = "kernel-bypass")] use std::{ + env, net::{IpAddr, Ipv4Addr, SocketAddr}, path::PathBuf, sync::{Arc, Mutex}, - time::{Duration, SystemTime, UNIX_EPOCH}, + time::{Duration, SystemTime}, }; use sof::{ @@ -24,6 +25,7 @@ use sof::{ runtime::{self, DerivedStateReplayConfig, DerivedStateRuntimeConfig, RuntimeSetup}, shred::wire::SIZE_OF_DATA_SHRED_HEADERS, }; +use sof_support::time_support::current_unix_nanos; use sof_types::PubkeyBytes; use solana_sdk_ids::vote; use solana_signature::Signature; @@ -144,10 +146,8 @@ impl DerivedStateConsumer for PersistedCheckpointConsumer { } fn unique_test_dir(name: &str) -> PathBuf { - let unique = SystemTime::now() - .duration_since(UNIX_EPOCH) - .map_or(0_u128, |duration| duration.as_nanos()); - std::env::temp_dir().join(format!( + let unique = current_unix_nanos(); + env::temp_dir().join(format!( "sof-derived-state-runtime-{name}-{}-{unique}", std::process::id() )) diff --git a/crates/sof-solana-compat/Cargo.toml b/crates/sof-solana-compat/Cargo.toml index 86c245c2..1085cc38 100644 --- a/crates/sof-solana-compat/Cargo.toml +++ b/crates/sof-solana-compat/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "sof-solana-compat" -version = "0.18.1" +version = "0.18.2" edition.workspace = true description = "Explicit Solana-coupled compatibility layer for SOF and sof-tx" license = "Apache-2.0 OR MIT" @@ -16,8 +16,8 @@ workspace = true [dependencies] async-trait = "0.1" -sof-tx = { path = "../sof-tx", version = "0.18.1" } -sof-types = { path = "../sof-types", version = "0.18.1", features = ["solana-compat"] } +sof-tx = { path = "../sof-tx", version = "0.18.2" } +sof-types = { path = "../sof-types", version = "0.18.2", features = ["solana-compat"] } bincode = "1.3.3" solana-compute-budget-interface = "3.0.0" solana-keypair = "3.0.1" @@ -30,6 +30,6 @@ solana-transaction = { version = "3.0.2", features = ["bincode"] } thiserror = "2.0" [dev-dependencies] -sof = { path = "../sof-observer", version = "0.18.1", default-features = false } -sof-tx = { path = "../sof-tx", version = "0.18.1", features = ["sof-adapters"] } +sof = { path = "../sof-observer", version = "0.18.2", default-features = false } +sof-tx = { path = "../sof-tx", version = "0.18.2", features = ["sof-adapters"] } tokio = { version = "1.48", features = ["macros", "rt-multi-thread"] } diff --git a/crates/sof-support/Cargo.toml b/crates/sof-support/Cargo.toml new file mode 100644 index 00000000..9d2e78d2 --- /dev/null +++ b/crates/sof-support/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "sof-support" +version = "0.18.2" +edition.workspace = true +description = "Shared internal support utilities reused across SOF crates" +license = "Apache-2.0 OR MIT" +repository = "https://github.com/Lythaeon/sof" +homepage = "https://github.com/Lythaeon/sof" + +[lints] +workspace = true + +[dependencies] +sof-types = { version = "0.18.2", path = "../sof-types" } diff --git a/crates/sof-support/src/lib.rs b/crates/sof-support/src/lib.rs new file mode 100644 index 00000000..82e14880 --- /dev/null +++ b/crates/sof-support/src/lib.rs @@ -0,0 +1,331 @@ +//! Shared internal support helpers for SOF workspace crates. + +use std::{env, time::Duration}; + +use sof_types::{PubkeyBytes, SignatureBytes}; + +/// Benchmark helper utilities reused across SOF profiling fixtures. +pub mod bench { + use super::Duration; + + /// Reads a positive profiling iteration count from `SOF_PROFILE_ITERATIONS`. + #[must_use] + pub fn profile_iterations(default: usize) -> usize { + super::env_support::read_positive_usize("SOF_PROFILE_ITERATIONS", default) + } + + /// Returns the average nanoseconds spent per iteration. + #[must_use] + pub fn avg_ns_per_iteration(elapsed: Duration, iterations: I) -> u128 + where + I: TryInto, + { + let iterations = iterations + .try_into() + .ok() + .filter(|value| *value > 0) + .unwrap_or(1); + elapsed.as_nanos().checked_div(iterations).unwrap_or(0) + } +} + +/// Environment parsing helpers reused across profiling fixtures and tests. +pub mod env_support { + use super::env; + + /// Reads one positive `usize` from an environment variable, or returns `default`. + #[must_use] + pub fn read_positive_usize(name: &str, default: usize) -> usize { + env::var(name) + .ok() + .and_then(|value| value.parse::().ok()) + .filter(|value| *value > 0) + .unwrap_or(default) + } +} + +/// Collection helpers reused across runtime caches and provider adapters. +pub mod collections_support { + use std::collections::HashMap; + + /// Prunes one slot-keyed map down to the retained recent window once the threshold is crossed. + pub fn prune_recent_slots( + slot_states: &mut HashMap, + slot: u64, + retained_lag: u64, + prune_threshold: usize, + ) { + if slot_states.len() <= prune_threshold { + return; + } + let slot_floor = slot.saturating_sub(retained_lag); + slot_states.retain(|tracked_slot, _| *tracked_slot >= slot_floor); + } +} + +/// Typed byte-slice conversion helpers reused across provider adapters. +pub mod bytes { + use super::{PubkeyBytes, SignatureBytes}; + + /// Converts one 64-byte signature slice into `SignatureBytes`. + /// + /// # Errors + /// + /// Returns the error produced by `on_error` when `bytes` is not exactly 64 bytes long. + pub fn signature_bytes_from_slice(bytes: &[u8], on_error: F) -> Result + where + F: FnOnce() -> E, + { + let raw: [u8; 64] = bytes.try_into().map_err(|_error| on_error())?; + Ok(SignatureBytes::from(raw)) + } + + /// Converts one 32-byte pubkey slice into `PubkeyBytes`. + /// + /// # Errors + /// + /// Returns the error produced by `on_error` when `bytes` is not exactly 32 bytes long. + pub fn pubkey_bytes_from_slice(bytes: &[u8], on_error: F) -> Result + where + F: FnOnce() -> E, + { + let raw: [u8; 32] = bytes.try_into().map_err(|_error| on_error())?; + Ok(PubkeyBytes::from(raw)) + } +} + +/// Short-vector parsing helpers reused across serialized Solana payload readers. +pub mod short_vec { + /// Partial short-vector decode failure. + #[derive(Clone, Copy, Debug, Eq, PartialEq)] + pub enum ShortVecDecodeError { + /// The payload ended before the short-vector length was fully decoded. + Incomplete, + /// The payload encoded an invalid short-vector length. + Invalid, + } + + /// Decodes one Solana short-vector length from `payload`. + #[must_use] + pub fn decode_short_u16_len(payload: &[u8], offset: &mut usize) -> Option { + let mut value = 0_usize; + let mut shift = 0_u32; + for byte_index in 0..3 { + let byte = usize::from(*payload.get(*offset)?); + *offset = (*offset).saturating_add(1); + value |= (byte & 0x7f) << shift; + if byte & 0x80 == 0 { + return Some(value); + } + shift = shift.saturating_add(7); + if byte_index == 2 { + return None; + } + } + None + } + + /// Decodes one Solana short-vector length prefix from the start of `payload`. + /// + /// Returns the decoded length together with the payload offset immediately + /// after the prefix bytes. + #[must_use] + pub fn decode_short_u16_len_prefix(payload: &[u8]) -> Option<(usize, usize)> { + let mut offset = 0; + let value = decode_short_u16_len(payload, &mut offset)?; + Some((value, offset)) + } + + /// Decodes one Solana short-vector length from a possibly partial payload. + /// + /// # Errors + /// + /// Returns [`ShortVecDecodeError::Incomplete`] when the payload ends before + /// the length is fully decoded, and [`ShortVecDecodeError::Invalid`] for an + /// invalid short-vector encoding. + pub fn decode_short_u16_len_partial( + payload: &[u8], + offset: &mut usize, + ) -> Result { + let mut value = 0_usize; + let mut shift = 0_u32; + for byte_index in 0..3 { + let byte = usize::from( + *payload + .get(*offset) + .ok_or(ShortVecDecodeError::Incomplete)?, + ); + *offset = (*offset) + .checked_add(1) + .ok_or(ShortVecDecodeError::Invalid)?; + value |= (byte & 0x7f) << shift; + if byte & 0x80 == 0 { + return Ok(value); + } + shift = shift.saturating_add(7); + if byte_index == 2 { + return Err(ShortVecDecodeError::Invalid); + } + } + Err(ShortVecDecodeError::Invalid) + } +} + +/// Duration helpers reused across transport adapters. +pub mod time_support { + use super::Duration; + + /// Returns whole seconds rounded up, preserving non-zero sub-second values. + #[must_use] + pub const fn duration_secs_ceil(duration: Duration) -> u64 { + let secs = duration.as_secs(); + if duration.subsec_nanos() == 0 { + secs + } else { + secs.saturating_add(1) + } + } + + /// Returns one duration in whole milliseconds, saturating at `u64::MAX`. + #[must_use] + pub fn duration_millis_u64(duration: Duration) -> u64 { + duration.as_millis().min(u128::from(u64::MAX)) as u64 + } + + /// Returns `duration` unless it is zero, in which case returns `fallback`. + #[must_use] + pub const fn nonzero_duration_or(duration: Duration, fallback: Duration) -> Duration { + if duration.is_zero() { + fallback + } else { + duration + } + } + + /// Returns the current Unix timestamp in milliseconds, saturating at `u64::MAX`. + #[must_use] + pub fn current_unix_ms() -> u64 { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map_or(0, duration_millis_u64) + } + + /// Returns the current Unix timestamp in whole seconds. + #[must_use] + pub fn current_unix_secs() -> u64 { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map_or(0, |duration| duration.as_secs()) + } + + /// Returns the current Unix timestamp in nanoseconds as one `u128`. + #[must_use] + pub fn current_unix_nanos() -> u128 { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map_or(0, |duration| duration.as_nanos()) + } +} + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use super::collections_support::prune_recent_slots; + use super::short_vec::{ + ShortVecDecodeError, decode_short_u16_len, decode_short_u16_len_partial, + decode_short_u16_len_prefix, + }; + use super::time_support::{ + current_unix_ms, current_unix_nanos, current_unix_secs, duration_millis_u64, + duration_secs_ceil, nonzero_duration_or, + }; + + #[test] + fn duration_secs_ceil_rounds_subsecond_values_up() { + assert_eq!(duration_secs_ceil(Duration::from_secs(2)), 2); + assert_eq!(duration_secs_ceil(Duration::from_millis(1)), 1); + assert_eq!(duration_secs_ceil(Duration::from_millis(1500)), 2); + } + + #[test] + fn current_unix_ms_is_monotonic_enough_for_smoke_check() { + let first = current_unix_ms(); + let second = current_unix_ms(); + assert!(second >= first); + } + + #[test] + fn duration_millis_u64_saturates() { + assert_eq!(duration_millis_u64(Duration::from_millis(7)), 7); + assert_eq!(duration_millis_u64(Duration::MAX), u64::MAX); + } + + #[test] + fn nonzero_duration_or_clamps_zero() { + assert_eq!( + nonzero_duration_or(Duration::ZERO, Duration::from_secs(7)), + Duration::from_secs(7) + ); + assert_eq!( + nonzero_duration_or(Duration::from_millis(5), Duration::from_secs(7)), + Duration::from_millis(5) + ); + } + + #[test] + fn current_unix_time_helpers_are_nonzero_or_zero_safely() { + assert!(current_unix_secs() <= current_unix_ms() / 1_000 + 1); + assert!(current_unix_nanos() / 1_000_000 <= u128::from(current_unix_ms()) + 1); + } + + #[test] + fn short_vec_decode_matches_compact_lengths() { + let mut single_byte_offset = 0; + assert_eq!( + decode_short_u16_len(&[0x7f], &mut single_byte_offset), + Some(127) + ); + assert_eq!(single_byte_offset, 1); + + let mut two_byte_offset = 0; + assert_eq!( + decode_short_u16_len(&[0x80, 0x01], &mut two_byte_offset), + Some(128) + ); + assert_eq!(two_byte_offset, 2); + } + + #[test] + fn short_vec_decode_partial_distinguishes_incomplete_and_invalid() { + let mut incomplete_offset = 0; + assert_eq!( + decode_short_u16_len_partial(&[0x80], &mut incomplete_offset), + Err(ShortVecDecodeError::Incomplete) + ); + + let mut invalid_offset = 0; + assert_eq!( + decode_short_u16_len_partial(&[0x80, 0x80, 0x80], &mut invalid_offset), + Err(ShortVecDecodeError::Invalid) + ); + } + + #[test] + fn short_vec_decode_prefix_returns_offset() { + assert_eq!(decode_short_u16_len_prefix(&[0x7f]), Some((127, 1))); + assert_eq!(decode_short_u16_len_prefix(&[0x80, 0x01]), Some((128, 2))); + } + + #[test] + fn prune_recent_slots_drops_old_entries_after_threshold() { + let mut slot_states = (0_u64..10_u64).map(|slot| (slot, slot)).collect(); + + prune_recent_slots(&mut slot_states, 9, 3, 4); + + assert_eq!(slot_states.len(), 4); + assert!(!slot_states.contains_key(&5)); + assert!(slot_states.contains_key(&6)); + assert!(slot_states.contains_key(&9)); + } +} diff --git a/crates/sof-tx/Cargo.toml b/crates/sof-tx/Cargo.toml index d6ec34a6..7c40d1b7 100644 --- a/crates/sof-tx/Cargo.toml +++ b/crates/sof-tx/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "sof-tx" -version = "0.18.1" +version = "0.18.2" edition.workspace = true description = "SOF transaction SDK for building and submitting Solana transactions" license = "Apache-2.0 OR MIT" @@ -18,14 +18,14 @@ workspace = true default = [] sof-adapters = ["dep:sof"] kernel-bypass = [] -jito-grpc = ["dep:prost", "dep:tonic"] +jito-grpc = ["dep:prost", "dep:tonic", "dep:tonic-prost"] [dependencies] async-trait = "0.1" base64 = "0.22" bincode = "1.3.3" bs58 = "0.5" -prost = { version = "0.13", optional = true } +prost = { version = "0.14", optional = true } reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls"] } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" @@ -40,11 +40,13 @@ solana-signature = "3.1.0" solana-signer = "3.0.0" solana-system-interface = { version = "3.0.0", features = ["bincode"] } solana-transaction = { version = "3.0.2", features = ["bincode"] } -sof = { version = "0.18.1", path = "../sof-observer", default-features = false, optional = true } -sof-types = { version = "0.18.1", path = "../sof-types", features = ["solana-compat"] } +sof = { version = "0.18.2", path = "../sof-observer", default-features = false, optional = true } +sof-support = { version = "0.18.2", path = "../sof-support" } +sof-types = { version = "0.18.2", path = "../sof-types", features = ["solana-compat"] } thiserror = "2.0" tokio = { version = "1.48", features = ["macros", "rt-multi-thread", "net", "sync", "time"] } -tonic = { version = "0.12", optional = true, default-features = false, features = ["codegen", "prost", "transport", "tls-webpki-roots"] } +tonic = { version = "0.14", optional = true, default-features = false, features = ["codegen", "transport", "tls-webpki-roots"] } +tonic-prost = { version = "0.14", optional = true } arcshift = "0.4.2" [dev-dependencies] diff --git a/crates/sof-tx/README.md b/crates/sof-tx/README.md index ee29c86f..0f979efb 100644 --- a/crates/sof-tx/README.md +++ b/crates/sof-tx/README.md @@ -43,20 +43,20 @@ cargo add sof-tx Enable SOF runtime adapters when you want provider values from live `sof` plugin events: ```toml -sof-tx = { version = "0.18.1", features = ["sof-adapters"] } +sof-tx = { version = "0.18.2", features = ["sof-adapters"] } ``` Enable `kernel-bypass` transport hooks for kernel-bypass direct submit integrations: ```toml -sof-tx = { version = "0.18.1", features = ["kernel-bypass"] } +sof-tx = { version = "0.18.2", features = ["kernel-bypass"] } ``` Use `sof-solana-compat` when you want the Solana-native `TxBuilder` plus unsigned convenience submission helpers on top of `sof-tx`: ```toml -sof-solana-compat = "0.18.1" +sof-solana-compat = "0.18.2" ``` ## Quick Start diff --git a/crates/sof-tx/examples/kernel_bypass_af_xdp.rs b/crates/sof-tx/examples/kernel_bypass_af_xdp.rs index e320edc3..5049bf71 100644 --- a/crates/sof-tx/examples/kernel_bypass_af_xdp.rs +++ b/crates/sof-tx/examples/kernel_bypass_af_xdp.rs @@ -12,11 +12,15 @@ fn main() { #[cfg(all(target_os = "linux", feature = "kernel-bypass"))] use std::{ + env, ffi::CString, io, - net::{Ipv4Addr, SocketAddr, UdpSocket}, + net::{IpAddr, Ipv4Addr, SocketAddr, UdpSocket}, + path::Path, process::Command, + slice, sync::{Arc, Mutex}, + thread, time::Duration, }; @@ -30,6 +34,8 @@ use sof_tx::{ submit::{DirectSubmitConfig, DirectSubmitTransport}, }; #[cfg(all(target_os = "linux", feature = "kernel-bypass"))] +use tokio::runtime::Builder; +#[cfg(all(target_os = "linux", feature = "kernel-bypass"))] use xdp::{ RingConfigBuilder, Umem, WakableRings, packet::PacketError, @@ -139,8 +145,8 @@ impl KernelBypassDatagramSocket for AfXdpKernelBypassSocket { /// Example helper used by this binary. async fn send_to(&self, payload: &[u8], target: SocketAddr) -> io::Result { let dst = match target.ip() { - std::net::IpAddr::V4(ip) => ip, - std::net::IpAddr::V6(_) => { + IpAddr::V4(ip) => ip, + IpAddr::V6(_) => { return Err(io::Error::new( io::ErrorKind::InvalidInput, "AF_XDP example socket only supports IPv4 targets", @@ -457,7 +463,7 @@ fn setup_veth_pair() -> Result<(), Box> { #[cfg(all(target_os = "linux", feature = "kernel-bypass"))] /// Example helper used by this binary. -fn run_unshare(current_exe: &std::path::Path) -> Result<(), Box> { +fn run_unshare(current_exe: &Path) -> Result<(), Box> { for candidate in [ "/usr/bin/unshare", "/bin/unshare", @@ -524,15 +530,13 @@ fn run_inner() -> Result<(), Box> { ..DirectSubmitConfig::default() }; - let runtime = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build()?; + let runtime = Builder::new_current_thread().enable_all().build()?; let selected = runtime .block_on(async { transport .submit_direct( &payload, - std::slice::from_ref(&target), + slice::from_ref(&target), RoutingPolicy::default(), &config, ) @@ -543,7 +547,7 @@ fn run_inner() -> Result<(), Box> { let mut sender_tx_after = sender_tx_before; let mut receiver_rx_after = receiver_rx_before; for _ in 0..10 { - std::thread::sleep(Duration::from_millis(50)); + thread::sleep(Duration::from_millis(50)); (sender_tx_after, _) = read_link_packets(VETH_SENDER)?; (_, receiver_rx_after) = read_link_packets(VETH_RECEIVER)?; if sender_tx_after > sender_tx_before && receiver_rx_after > receiver_rx_before { @@ -576,8 +580,8 @@ fn run_inner() -> Result<(), Box> { #[cfg(all(target_os = "linux", feature = "kernel-bypass"))] /// Example helper used by this binary. fn main() -> Result<(), Box> { - if std::env::var_os(INNER_ENV).is_none() { - let current_exe = std::env::current_exe()?; + if env::var_os(INNER_ENV).is_none() { + let current_exe = env::current_exe()?; return run_unshare(¤t_exe); } run_inner() diff --git a/crates/sof-tx/src/adapters/derived_state.rs b/crates/sof-tx/src/adapters/derived_state.rs index 3bab8cb6..b873f88d 100644 --- a/crates/sof-tx/src/adapters/derived_state.rs +++ b/crates/sof-tx/src/adapters/derived_state.rs @@ -1,15 +1,17 @@ //! `sof` derived-state adapter that bridges replayable control-plane state into `sof-tx` providers. -use std::path::PathBuf; +use std::path::{Path, PathBuf}; use arcshift::ArcShift; use sof::framework::{ DerivedStateCheckpoint, DerivedStateCheckpointStore, DerivedStateConsumer, - DerivedStateConsumerFault, DerivedStateControlPlaneQuality, DerivedStateControlPlaneStateEvent, - DerivedStateFeedEnvelope, DerivedStateFeedEvent, DerivedStatePersistedCheckpoint, + DerivedStateConsumerConfig, DerivedStateConsumerFault, DerivedStateControlPlaneQuality, + DerivedStateControlPlaneStateEvent, DerivedStateFeedEnvelope, DerivedStateFeedEvent, + DerivedStatePersistedCheckpoint, }; use crate::{ + adapters::TxProviderControlPlaneQuality, adapters::common::{ TxProviderAdapterConfig, TxProviderAdapterCore, TxProviderAdapterSnapshot, TxProviderControlPlaneSnapshot, TxProviderFlowSafetyPolicy, TxProviderFlowSafetyReport, @@ -47,7 +49,7 @@ impl DerivedStateTxProviderAdapterPersistence { /// Returns the persisted checkpoint path. #[must_use] - pub fn checkpoint_path(&self) -> &std::path::Path { + pub fn checkpoint_path(&self) -> &Path { &self.checkpoint_path } } @@ -211,16 +213,10 @@ impl TxFlowSafetySource for DerivedStateTxProviderAdapter { let report = self.evaluate_flow_safety(TxProviderFlowSafetyPolicy::default()); TxFlowSafetySnapshot { quality: match report.quality { - crate::adapters::TxProviderControlPlaneQuality::Stable => { - TxFlowSafetyQuality::Stable - } - crate::adapters::TxProviderControlPlaneQuality::Degraded => { - TxFlowSafetyQuality::Degraded - } - crate::adapters::TxProviderControlPlaneQuality::Stale => { - TxFlowSafetyQuality::Stale - } - crate::adapters::TxProviderControlPlaneQuality::IncompleteControlPlane => { + TxProviderControlPlaneQuality::Stable => TxFlowSafetyQuality::Stable, + TxProviderControlPlaneQuality::Degraded => TxFlowSafetyQuality::Degraded, + TxProviderControlPlaneQuality::Stale => TxFlowSafetyQuality::Stale, + TxProviderControlPlaneQuality::IncompleteControlPlane => { TxFlowSafetyQuality::IncompleteControlPlane } }, @@ -305,8 +301,8 @@ impl DerivedStateConsumer for DerivedStateTxProviderAdapter { Ok(Some(checkpoint)) } - fn config(&self) -> sof::framework::DerivedStateConsumerConfig { - sof::framework::DerivedStateConsumerConfig::new().with_control_plane_observed() + fn config(&self) -> DerivedStateConsumerConfig { + DerivedStateConsumerConfig::new().with_control_plane_observed() } fn apply( @@ -357,25 +353,23 @@ impl DerivedStateConsumer for DerivedStateTxProviderAdapter { #[cfg(test)] #[allow(clippy::panic)] mod tests { - use std::{ - env, fs, - net::SocketAddr, - path::PathBuf, - sync::Arc, - time::{SystemTime, UNIX_EPOCH}, - }; + use sof_support::time_support::current_unix_nanos; + use std::{env, fs, net::SocketAddr, path::PathBuf, process::id, sync::Arc, time::UNIX_EPOCH}; use sof::framework::{ BranchReorgedEvent, CheckpointBarrierEvent, CheckpointBarrierReason, ClusterNodeInfo, ClusterTopologyEvent, ControlPlaneSource, DerivedStateConsumer, DerivedStateFeedEnvelope, - DerivedStateFeedEvent, FeedSequence, FeedSessionId, FeedWatermarks, LeaderScheduleEntry, - LeaderScheduleEvent, ObservedRecentBlockhashEvent, SlotStatusChangedEvent, + DerivedStateFeedEvent, FeedSequence, FeedSessionId, FeedWatermarks, ForkSlotStatus, + LeaderScheduleEntry, LeaderScheduleEvent, ObservedRecentBlockhashEvent, + SlotStatusChangedEvent, }; use sof_types::PubkeyBytes; use solana_pubkey::Pubkey; use super::*; + use crate::adapters::TxProviderFlowSafetyIssue; + fn addr(port: u16) -> SocketAddr { SocketAddr::from(([127, 0, 0, 1], port)) } @@ -414,11 +408,8 @@ mod tests { fn unique_temp_path(label: &str) -> PathBuf { env::temp_dir().join(format!( "sof-tx-{label}-{}-{}.json", - std::process::id(), - SystemTime::now() - .duration_since(UNIX_EPOCH) - .map(|duration| duration.as_nanos()) - .unwrap_or_default() + id(), + current_unix_nanos() )) } @@ -438,17 +429,17 @@ mod tests { assert!( report .issues - .contains(&crate::adapters::TxProviderFlowSafetyIssue::MissingRecentBlockhash) + .contains(&TxProviderFlowSafetyIssue::MissingRecentBlockhash) ); assert!( report .issues - .contains(&crate::adapters::TxProviderFlowSafetyIssue::MissingClusterTopology) + .contains(&TxProviderFlowSafetyIssue::MissingClusterTopology) ); assert!( report .issues - .contains(&crate::adapters::TxProviderFlowSafetyIssue::MissingLeaderSchedule) + .contains(&TxProviderFlowSafetyIssue::MissingLeaderSchedule) ); } @@ -513,7 +504,7 @@ mod tests { slot: 41, parent_slot: Some(40), previous_status: None, - status: sof::framework::ForkSlotStatus::Processed, + status: ForkSlotStatus::Processed, }, ))), ); diff --git a/crates/sof-tx/src/adapters/plugin_host.rs b/crates/sof-tx/src/adapters/plugin_host.rs index a0c42ccf..da20186a 100644 --- a/crates/sof-tx/src/adapters/plugin_host.rs +++ b/crates/sof-tx/src/adapters/plugin_host.rs @@ -5,11 +5,12 @@ use std::net::SocketAddr; use async_trait::async_trait; use sof::framework::{ ClusterTopologyEvent, LeaderScheduleEvent, ObservedRecentBlockhashEvent, ObserverPlugin, - PluginHost, + PluginConfig, PluginHost, }; use sof_types::PubkeyBytes; use crate::{ + adapters::TxProviderControlPlaneQuality, adapters::common::{ TxProviderAdapterConfig, TxProviderAdapterCore, TxProviderControlPlaneSnapshot, TxProviderFlowSafetyPolicy, TxProviderFlowSafetyReport, take_next_leader_identity_targets, @@ -139,12 +140,10 @@ impl TxFlowSafetySource for PluginHostTxProviderAdapter { ..TxProviderFlowSafetyPolicy::default() }); let quality = match report.quality { - crate::adapters::TxProviderControlPlaneQuality::Stable => TxFlowSafetyQuality::Stable, - crate::adapters::TxProviderControlPlaneQuality::Degraded => { - TxFlowSafetyQuality::Degraded - } - crate::adapters::TxProviderControlPlaneQuality::Stale => TxFlowSafetyQuality::Stale, - crate::adapters::TxProviderControlPlaneQuality::IncompleteControlPlane => { + TxProviderControlPlaneQuality::Stable => TxFlowSafetyQuality::Stable, + TxProviderControlPlaneQuality::Degraded => TxFlowSafetyQuality::Degraded, + TxProviderControlPlaneQuality::Stale => TxFlowSafetyQuality::Stale, + TxProviderControlPlaneQuality::IncompleteControlPlane => { TxFlowSafetyQuality::IncompleteControlPlane } }; @@ -177,8 +176,8 @@ impl ObserverPlugin for PluginHostTxProviderAdapter { "sof-tx-provider-adapter" } - fn config(&self) -> sof::framework::PluginConfig { - let config = sof::framework::PluginConfig::new() + fn config(&self) -> PluginConfig { + let config = PluginConfig::new() .with_recent_blockhash() .with_cluster_topology(); if self.leader_schedule_enabled { @@ -206,10 +205,18 @@ mod tests { use std::net::SocketAddr; use super::*; - use sof::framework::{ClusterNodeInfo, ControlPlaneSource, LeaderScheduleEntry, PluginHost}; + use sof::{ + event::ForkSlotStatus, + framework::{ + ClusterNodeInfo, ControlPlaneSource, LeaderScheduleEntry, PluginHost, + SlotStatusChangedEvent, + }, + }; use sof_types::PubkeyBytes; use solana_pubkey::Pubkey; + use crate::adapters::TxProviderFlowSafetyIssue; + fn addr(port: u16) -> SocketAddr { SocketAddr::from(([127, 0, 0, 1], port)) } @@ -538,14 +545,12 @@ mod tests { vec![LeaderScheduleEntry { slot: 200, leader }], )) .await; - adapter - .core - .apply_slot_status(sof::framework::SlotStatusChangedEvent { - slot: 200, - parent_slot: Some(199), - previous_status: Some(sof::event::ForkSlotStatus::Processed), - status: sof::event::ForkSlotStatus::Confirmed, - }); + adapter.core.apply_slot_status(SlotStatusChangedEvent { + slot: 200, + parent_slot: Some(199), + previous_status: Some(ForkSlotStatus::Processed), + status: ForkSlotStatus::Confirmed, + }); let report = adapter.evaluate_flow_safety(TxProviderFlowSafetyPolicy { max_recent_blockhash_slot_lag: Some(16), @@ -553,12 +558,14 @@ mod tests { }); assert!(!report.is_safe()); - assert!(report.issues.contains( - &crate::adapters::TxProviderFlowSafetyIssue::StaleRecentBlockhash { - slot_lag: 190, - max_allowed: 16, - } - )); + assert!( + report + .issues + .contains(&TxProviderFlowSafetyIssue::StaleRecentBlockhash { + slot_lag: 190, + max_allowed: 16, + }) + ); } #[tokio::test] @@ -579,17 +586,15 @@ mod tests { adapter .on_cluster_topology(topology_snapshot(vec![node(leader, 9441)])) .await; - adapter - .core - .apply_slot_status(sof::framework::SlotStatusChangedEvent { - slot: 100, - parent_slot: Some(99), - previous_status: Some(sof::event::ForkSlotStatus::Processed), - status: sof::event::ForkSlotStatus::Confirmed, - }); + adapter.core.apply_slot_status(SlotStatusChangedEvent { + slot: 100, + parent_slot: Some(99), + previous_status: Some(ForkSlotStatus::Processed), + status: ForkSlotStatus::Confirmed, + }); let snapshot = adapter.toxic_flow_snapshot(); - assert_eq!(snapshot.quality, crate::submit::TxFlowSafetyQuality::Stable); + assert_eq!(snapshot.quality, TxFlowSafetyQuality::Stable); assert!(snapshot.issues.is_empty()); assert_eq!( adapter.current_leader(), diff --git a/crates/sof-tx/src/providers.rs b/crates/sof-tx/src/providers.rs index 9e33b839..3aefee32 100644 --- a/crates/sof-tx/src/providers.rs +++ b/crates/sof-tx/src/providers.rs @@ -6,11 +6,19 @@ use std::{ time::Duration, }; +use reqwest::redirect::Policy; use serde::{Deserialize, Serialize}; +use serde_json::from_slice as json_from_slice; +use sof_support::time_support::nonzero_duration_or; use sof_types::PubkeyBytes; use crate::submit::SubmitTransportError; +/// Maximum HTTP body size accepted from `getLatestBlockhash` RPC responses. +const MAX_BLOCKHASH_RPC_RESPONSE_BYTES: usize = 64 * 1024; +/// Default timeout used for recent-blockhash HTTP requests. +const DEFAULT_RPC_REQUEST_TIMEOUT: Duration = Duration::from_secs(10); + /// One leader/validator target that can receive transactions directly. #[derive(Debug, Clone, Eq, PartialEq, Hash)] pub struct LeaderTarget { @@ -44,7 +52,7 @@ pub struct RpcRecentBlockhashProviderConfig { impl Default for RpcRecentBlockhashProviderConfig { fn default() -> Self { Self { - request_timeout: Duration::from_secs(10), + request_timeout: DEFAULT_RPC_REQUEST_TIMEOUT, } } } @@ -81,8 +89,12 @@ impl RpcRecentBlockhashProvider { config: &RpcRecentBlockhashProviderConfig, ) -> Result { let rpc_url = rpc_url.into(); + let request_timeout = + nonzero_duration_or(config.request_timeout, DEFAULT_RPC_REQUEST_TIMEOUT); let client = reqwest::Client::builder() - .timeout(config.request_timeout) + .redirect(Policy::none()) + .connect_timeout(request_timeout) + .timeout(request_timeout) .build() .map_err(|error| SubmitTransportError::Config { message: error.to_string(), @@ -250,18 +262,21 @@ async fn fetch_latest_blockhash( .map_err(|error| SubmitTransportError::Failure { message: error.to_string(), })?; + if response.status().is_redirection() { + return Err(SubmitTransportError::Failure { + message: format!("unexpected redirect response: {}", response.status()), + }); + } let response = response .error_for_status() .map_err(|error| SubmitTransportError::Failure { message: error.to_string(), })?; + let response_body = read_http_response_bytes_bounded(response).await?; let parsed: LatestBlockhashRpcResponse = - response - .json() - .await - .map_err(|error| SubmitTransportError::Failure { - message: error.to_string(), - })?; + json_from_slice(&response_body).map_err(|error| SubmitTransportError::Failure { + message: error.to_string(), + })?; if let Some(result) = parsed.result { return parse_blockhash(&result.value.blockhash); } @@ -275,6 +290,48 @@ async fn fetch_latest_blockhash( }) } +/// Reads one RPC response body while enforcing a fixed maximum byte budget. +async fn read_http_response_bytes_bounded( + mut response: reqwest::Response, +) -> Result, SubmitTransportError> { + if response + .content_length() + .is_some_and(|content_length| content_length > MAX_BLOCKHASH_RPC_RESPONSE_BYTES as u64) + { + return Err(SubmitTransportError::Failure { + message: format!( + "response body exceeded max size of {MAX_BLOCKHASH_RPC_RESPONSE_BYTES} bytes" + ), + }); + } + + let initial_capacity = response + .content_length() + .and_then(|content_length| usize::try_from(content_length).ok()) + .unwrap_or(0) + .min(MAX_BLOCKHASH_RPC_RESPONSE_BYTES); + let mut body = Vec::with_capacity(initial_capacity); + while let Some(chunk) = + response + .chunk() + .await + .map_err(|error| SubmitTransportError::Failure { + message: error.to_string(), + })? + { + let remaining = MAX_BLOCKHASH_RPC_RESPONSE_BYTES.saturating_sub(body.len()); + if chunk.len() > remaining { + return Err(SubmitTransportError::Failure { + message: format!( + "response body exceeded max size of {MAX_BLOCKHASH_RPC_RESPONSE_BYTES} bytes" + ), + }); + } + body.extend_from_slice(&chunk); + } + Ok(body) +} + /// Decodes one base58 blockhash string into the byte format used by `TxBuilder`. fn parse_blockhash(blockhash: &str) -> Result<[u8; 32], SubmitTransportError> { let decoded = @@ -300,6 +357,37 @@ mod tests { net::TcpListener, }; + async fn spawn_http_response_server(response: String) -> String { + let listener = TcpListener::bind("127.0.0.1:0").await; + assert!(listener.is_ok()); + let listener = listener.unwrap_or_else(|error| panic!("{error}")); + let addr = listener.local_addr(); + assert!(addr.is_ok()); + let addr = addr.unwrap_or_else(|error| panic!("{error}")); + tokio::spawn(async move { + let accepted = listener.accept().await; + assert!(accepted.is_ok()); + let (mut stream, _) = accepted.unwrap_or_else(|error| panic!("{error}")); + let mut buffer = [0_u8; 4096]; + let read = stream.read(&mut buffer).await; + assert!(read.is_ok()); + let write = stream.write_all(response.as_bytes()).await; + assert!(write.is_ok()); + }); + format!("http://{addr}") + } + + #[test] + fn rpc_recent_blockhash_provider_accepts_zero_timeout_config() { + let provider = RpcRecentBlockhashProvider::with_config( + "http://127.0.0.1:8899", + &RpcRecentBlockhashProviderConfig { + request_timeout: Duration::ZERO, + }, + ); + assert!(provider.is_ok()); + } + #[tokio::test] async fn rpc_recent_blockhash_provider_fetches_initial_value() { let expected = [9_u8; 32]; @@ -347,4 +435,44 @@ mod tests { let joined = server.await; assert!(joined.is_ok()); } + + #[tokio::test] + async fn rpc_recent_blockhash_provider_rejects_redirects() { + let target = spawn_http_response_server( + "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: 0\r\nconnection: close\r\n\r\n" + .to_owned(), + ) + .await; + let endpoint = spawn_http_response_server(format!( + "HTTP/1.1 307 Temporary Redirect\r\nlocation: {target}\r\ncontent-length: 0\r\nconnection: close\r\n\r\n" + )) + .await; + + let provider = RpcRecentBlockhashProvider::new(endpoint); + assert!(provider.is_ok()); + let provider = provider.unwrap_or_else(|error| panic!("{error}")); + let error = match provider.refresh().await { + Ok(_blockhash) => panic!("redirect should fail"), + Err(error) => error, + }; + assert!(error.to_string().contains("redirect")); + } + + #[tokio::test] + async fn rpc_recent_blockhash_provider_rejects_oversized_responses() { + let endpoint = spawn_http_response_server(format!( + "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n", + MAX_BLOCKHASH_RPC_RESPONSE_BYTES.saturating_add(1) + )) + .await; + + let provider = RpcRecentBlockhashProvider::new(endpoint); + assert!(provider.is_ok()); + let provider = provider.unwrap_or_else(|error| panic!("{error}")); + let error = match provider.refresh().await { + Ok(_blockhash) => panic!("oversized body should fail"), + Err(error) => error, + }; + assert!(error.to_string().contains("exceeded max size")); + } } diff --git a/crates/sof-tx/src/routing.rs b/crates/sof-tx/src/routing.rs index 1ff98df8..a80e8a47 100644 --- a/crates/sof-tx/src/routing.rs +++ b/crates/sof-tx/src/routing.rs @@ -1,7 +1,7 @@ //! Routing policy, target selection, and signature-level duplicate suppression. use std::{ - collections::{HashMap, HashSet}, + collections::{HashMap, HashSet, VecDeque, hash_map::Entry}, time::{Duration, Instant}, }; @@ -9,6 +9,9 @@ use sof_types::SignatureBytes; use crate::providers::{LeaderProvider, LeaderTarget}; +/// Initial storage reserved for the signature dedupe window before it grows. +const INITIAL_SIGNATURE_DEDUPER_CAPACITY: usize = 4_096; + /// Routing controls used for direct and hybrid submit paths. #[derive(Debug, Clone, Copy, Eq, PartialEq)] pub struct RoutingPolicy { @@ -93,6 +96,8 @@ pub struct SignatureDeduper { ttl: Duration, /// Last seen timestamps by signature. seen: HashMap, + /// Arrival order for bounded eviction without rescanning the whole map. + order: VecDeque<(SignatureBytes, Instant)>, } impl SignatureDeduper { @@ -101,18 +106,22 @@ impl SignatureDeduper { pub fn new(ttl: Duration) -> Self { Self { ttl: ttl.max(Duration::from_millis(1)), - seen: HashMap::new(), + seen: HashMap::with_capacity(INITIAL_SIGNATURE_DEDUPER_CAPACITY), + order: VecDeque::with_capacity(INITIAL_SIGNATURE_DEDUPER_CAPACITY), } } /// Returns true when signature is new (and records it), false when duplicate. pub fn check_and_insert(&mut self, signature: SignatureBytes, now: Instant) -> bool { self.evict_expired(now); - if self.seen.contains_key(&signature) { - return false; + match self.seen.entry(signature) { + Entry::Occupied(_) => false, + Entry::Vacant(entry) => { + entry.insert(now); + self.order.push_back((signature, now)); + true + } } - let _ = self.seen.insert(signature, now); - true } /// Returns number of signatures currently tracked. @@ -129,14 +138,22 @@ impl SignatureDeduper { /// Removes all expired signature entries. fn evict_expired(&mut self, now: Instant) { - let ttl = self.ttl; - self.seen - .retain(|_, first_seen| now.saturating_duration_since(*first_seen) < ttl); + while let Some((signature, first_seen)) = self.order.front().copied() { + if now.saturating_duration_since(first_seen) < self.ttl { + break; + } + self.order.pop_front(); + if self.seen.get(&signature).copied() == Some(first_seen) { + let _ = self.seen.remove(&signature); + } + } } } #[cfg(test)] mod tests { + use sof_support::{bench::avg_ns_per_iteration, env_support::read_positive_usize}; + use super::*; use crate::providers::{LeaderTarget, StaticLeaderProvider}; @@ -200,4 +217,95 @@ mod tests { assert!(!deduper.check_and_insert(signature, now + Duration::from_millis(5))); assert!(deduper.check_and_insert(signature, now + Duration::from_millis(30))); } + + #[test] + #[ignore = "profiling fixture for signature dedupe churn"] + fn signature_deduper_profile_fixture() { + let iterations = read_positive_usize("SOF_TX_SIGNATURE_DEDUPER_PROFILE_ITERS", 50_000); + let ttl_ms = u64::try_from(read_positive_usize( + "SOF_TX_SIGNATURE_DEDUPER_PROFILE_TTL_MS", + 10_000, + )) + .unwrap_or(10_000); + let mut deduper = SignatureDeduper::new(Duration::from_millis(ttl_ms)); + let start = Instant::now(); + let now = Instant::now(); + + for index in 0..iterations { + let mut signature = [0_u8; 64]; + signature[..8].copy_from_slice(&(index as u64).to_le_bytes()); + assert!(deduper.check_and_insert( + SignatureBytes::from(signature), + now + Duration::from_nanos(index as u64) + )); + } + + let elapsed = start.elapsed(); + let avg_ns = avg_ns_per_iteration(elapsed, iterations); + println!( + "signature_deduper_profile_fixture iterations={} ttl_ms={} entries={} elapsed_us={} avg_ns_per_iteration={} avg_us_per_iteration={:.3}", + iterations, + ttl_ms, + deduper.len(), + elapsed.as_micros(), + avg_ns, + avg_ns as f64 / 1_000.0 + ); + } + + #[test] + #[ignore = "profiling fixture for signature deduper allocation churn"] + fn signature_deduper_allocation_profile_fixture() { + let iterations = read_positive_usize("SOF_TX_SIGNATURE_DEDUPER_PROFILE_ITERS", 50_000); + let ttl_ms = u64::try_from(read_positive_usize( + "SOF_TX_SIGNATURE_DEDUPER_PROFILE_TTL_MS", + 10_000, + )) + .unwrap_or(10_000); + let ttl = Duration::from_millis(ttl_ms); + let mut baseline = signature_deduper_baseline(ttl); + let mut optimized = SignatureDeduper::new(ttl); + let now = Instant::now(); + + let baseline_started = Instant::now(); + for index in 0..iterations { + let signature = make_signature(index); + assert!(baseline.check_and_insert( + SignatureBytes::from(signature), + now + Duration::from_nanos(index as u64) + )); + } + let baseline_elapsed = baseline_started.elapsed(); + + let optimized_started = Instant::now(); + for index in 0..iterations { + let signature = make_signature(index); + assert!(optimized.check_and_insert( + SignatureBytes::from(signature), + now + Duration::from_nanos(index as u64) + )); + } + let optimized_elapsed = optimized_started.elapsed(); + + println!( + "signature_deduper_allocation_profile_fixture iterations={} baseline_us={} optimized_us={}", + iterations, + baseline_elapsed.as_micros(), + optimized_elapsed.as_micros(), + ); + } + + fn signature_deduper_baseline(ttl: Duration) -> SignatureDeduper { + SignatureDeduper { + ttl: ttl.max(Duration::from_millis(1)), + seen: HashMap::new(), + order: VecDeque::new(), + } + } + + fn make_signature(index: usize) -> [u8; 64] { + let mut signature = [0_u8; 64]; + signature[..8].copy_from_slice(&(index as u64).to_le_bytes()); + signature + } } diff --git a/crates/sof-tx/src/submit/client.rs b/crates/sof-tx/src/submit/client.rs index 332f0f9e..1bbf8600 100644 --- a/crates/sof-tx/src/submit/client.rs +++ b/crates/sof-tx/src/submit/client.rs @@ -14,6 +14,7 @@ use std::{ time::{Duration, Instant, SystemTime}, }; +use sof_support::{short_vec::decode_short_u16_len_prefix, time_support::duration_millis_u64}; use sof_types::SignatureBytes; use tokio::{ net::TcpStream, @@ -391,11 +392,11 @@ impl TxSubmitClient { let opportunity_age_ms = context .opportunity_created_at .and_then(|created_at| now.duration_since(created_at).ok()) - .map(|duration| duration.as_millis().min(u128::from(u64::MAX)) as u64); + .map(duration_millis_u64); if let Some(age_ms) = opportunity_age_ms && let Some(max_age) = self.guard_policy.max_opportunity_age { - let max_allowed_ms = max_age.as_millis().min(u128::from(u64::MAX)) as u64; + let max_allowed_ms = duration_millis_u64(max_age); if age_ms > max_allowed_ms { return Err(self.reject_with_outcome( TxToxicFlowRejectionReason::OpportunityStale { @@ -627,20 +628,25 @@ impl TxSubmitClient { plan: SubmitPlan, ) -> Result { let legacy_mode = plan.legacy_mode(); + let task_context = self.route_task_context(); + let direct_transport = self.direct_transport.clone(); + let leader_provider = self.leader_provider.clone(); + let backups = self.backups.clone(); + let policy = self.policy; + let direct_config = self.direct_config.clone().normalized(); let (result_tx, mut result_rx) = mpsc::unbounded_channel(); for (route_idx, route) in plan.routes.iter().copied().enumerate() { - let task_context = self.route_task_context(); + let task_context = task_context.clone(); let tx_bytes = Arc::clone(&tx_bytes); let result_tx = result_tx.clone(); let telemetry = Arc::clone(&self.telemetry); let reporter = self.outcome_reporter.clone(); let flow_safety_source = self.flow_safety_source.clone(); let plan_for_task = plan.clone(); - let direct_transport = self.direct_transport.clone(); - let leader_provider = self.leader_provider.clone(); - let backups = self.backups.clone(); - let policy = self.policy; - let direct_config = self.direct_config.clone().normalized(); + let direct_transport = direct_transport.clone(); + let leader_provider = leader_provider.clone(); + let backups = backups.clone(); + let direct_config = direct_config.clone(); tokio::spawn(async move { let result = submit_one_route_task( route, @@ -969,13 +975,12 @@ impl OutcomeReporterDispatcher { fn shared(reporter: Arc) -> Result, std::io::Error> { let key = reporter_identity(&reporter); let registry = outcome_reporter_registry(); - { - let registry = registry - .lock() - .unwrap_or_else(|poisoned| poisoned.into_inner()); - if let Some(existing) = registry.get(&key).and_then(Weak::upgrade) { - return Ok(existing); - } + let mut registry = registry + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + registry.retain(|_key, dispatcher| dispatcher.strong_count() > 0); + if let Some(existing) = registry.get(&key).and_then(Weak::upgrade) { + return Ok(existing); } let (tx, rx) = std_mpsc::sync_channel::(Self::QUEUE_CAPACITY); @@ -994,9 +999,6 @@ impl OutcomeReporterDispatcher { queue_full_warned: AtomicBool::new(false), unavailable_warned: AtomicBool::new(false), }); - let mut registry = registry - .lock() - .unwrap_or_else(|poisoned| poisoned.into_inner()); let _ = registry.insert(key, Arc::downgrade(&dispatcher)); Ok(dispatcher) } @@ -1054,7 +1056,7 @@ fn reporter_identity(reporter: &Arc) -> usize { /// Extracts the first transaction signature from serialized transaction bytes. fn extract_first_signature(tx_bytes: &[u8]) -> Result, SubmitError> { - let Some((signature_count, offset)) = decode_short_vec_len(tx_bytes) else { + let Some((signature_count, offset)) = decode_short_u16_len_prefix(tx_bytes) else { return Err(decode_signed_bytes_error( "transaction bytes did not contain a valid signature vector prefix", )); @@ -1073,20 +1075,6 @@ fn extract_first_signature(tx_bytes: &[u8]) -> Result, Su Ok(Some(SignatureBytes::new(signature))) } -/// Decodes Solana's short-vec length prefix and returns the decoded length plus payload offset. -fn decode_short_vec_len(bytes: &[u8]) -> Option<(usize, usize)> { - let mut value = 0_usize; - let mut shift = 0_u32; - for (idx, byte) in bytes.iter().copied().take(3).enumerate() { - value |= usize::from(byte & 0x7f) << shift; - if byte & 0x80 == 0 { - return Some((value, idx.saturating_add(1))); - } - shift = shift.saturating_add(7); - } - None -} - /// Builds one signed-byte decode error from a static message. fn decode_signed_bytes_error(message: &'static str) -> SubmitError { SubmitError::DecodeSignedBytes { @@ -1425,3 +1413,31 @@ fn direct_attempt_timeout(direct_config: &DirectSubmitConfig) -> Duration { .saturating_add(direct_config.rebroadcast_interval) .max(Duration::from_secs(8)) } + +#[cfg(test)] +#[allow(clippy::panic)] +mod tests { + use super::*; + + #[derive(Debug)] + struct NoopOutcomeReporter; + + impl TxSubmitOutcomeReporter for NoopOutcomeReporter { + fn record_outcome(&self, _outcome: &TxSubmitOutcome) {} + } + + #[test] + fn reporter_dispatcher_reuses_existing_instance() { + let reporter: Arc = Arc::new(NoopOutcomeReporter); + let first = match OutcomeReporterDispatcher::shared(Arc::clone(&reporter)) { + Ok(dispatcher) => dispatcher, + Err(error) => panic!("first dispatcher failed: {error}"), + }; + let second = match OutcomeReporterDispatcher::shared(reporter) { + Ok(dispatcher) => dispatcher, + Err(error) => panic!("second dispatcher failed: {error}"), + }; + + assert!(Arc::ptr_eq(&first, &second)); + } +} diff --git a/crates/sof-tx/src/submit/jito.rs b/crates/sof-tx/src/submit/jito.rs index 56330122..de6a302a 100644 --- a/crates/sof-tx/src/submit/jito.rs +++ b/crates/sof-tx/src/submit/jito.rs @@ -4,13 +4,19 @@ use std::time::Duration; use async_trait::async_trait; use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64_STANDARD}; -use reqwest::Url; +use reqwest::{Url, redirect::Policy}; use serde::{Deserialize, Serialize}; +use serde_json::from_slice as json_from_slice; +use sof_support::time_support::nonzero_duration_or; use super::{JitoSubmitConfig, JitoSubmitResponse, JitoSubmitTransport, SubmitTransportError}; /// Default Jito mainnet block-engine base URL. const DEFAULT_JITO_BLOCK_ENGINE_URL: &str = "https://mainnet.block-engine.jito.wtf"; +/// Maximum HTTP body size accepted from Jito submit responses. +const MAX_JITO_SUBMIT_RESPONSE_BYTES: usize = 64 * 1024; +/// Default timeout used for Jito HTTP requests. +const DEFAULT_JITO_REQUEST_TIMEOUT: Duration = Duration::from_secs(10); /// Typed Jito mainnet region. #[derive(Debug, Clone, Copy, Eq, PartialEq)] @@ -108,7 +114,7 @@ impl Default for JitoTransportConfig { fn default() -> Self { Self { endpoint: JitoBlockEngineEndpoint::default(), - request_timeout: Duration::from_secs(10), + request_timeout: DEFAULT_JITO_REQUEST_TIMEOUT, } } } @@ -152,8 +158,14 @@ impl JitoJsonRpcTransport { pub fn with_config( transport_config: JitoTransportConfig, ) -> Result { + let request_timeout = nonzero_duration_or( + transport_config.request_timeout, + DEFAULT_JITO_REQUEST_TIMEOUT, + ); let client = reqwest::Client::builder() - .timeout(transport_config.request_timeout) + .redirect(Policy::none()) + .connect_timeout(request_timeout) + .timeout(request_timeout) .build() .map_err(|error| SubmitTransportError::Config { message: error.to_string(), @@ -231,6 +243,11 @@ impl JitoSubmitTransport for JitoJsonRpcTransport { .map_err(|error| SubmitTransportError::Failure { message: error.to_string(), })?; + if response.status().is_redirection() { + return Err(SubmitTransportError::Failure { + message: format!("unexpected redirect response: {}", response.status()), + }); + } let response = response @@ -239,13 +256,11 @@ impl JitoSubmitTransport for JitoJsonRpcTransport { message: error.to_string(), })?; + let response_body = read_http_response_bytes_bounded(response).await?; let parsed: JsonRpcResponse = - response - .json() - .await - .map_err(|error| SubmitTransportError::Failure { - message: error.to_string(), - })?; + json_from_slice(&response_body).map_err(|error| SubmitTransportError::Failure { + message: error.to_string(), + })?; if let Some(signature) = parsed.result { return Ok(JitoSubmitResponse { @@ -265,9 +280,76 @@ impl JitoSubmitTransport for JitoJsonRpcTransport { } } +/// Reads one Jito submit response body while enforcing a fixed maximum byte budget. +async fn read_http_response_bytes_bounded( + mut response: reqwest::Response, +) -> Result, SubmitTransportError> { + if response + .content_length() + .is_some_and(|content_length| content_length > MAX_JITO_SUBMIT_RESPONSE_BYTES as u64) + { + return Err(SubmitTransportError::Failure { + message: format!( + "response body exceeded max size of {MAX_JITO_SUBMIT_RESPONSE_BYTES} bytes" + ), + }); + } + + let initial_capacity = response + .content_length() + .and_then(|content_length| usize::try_from(content_length).ok()) + .unwrap_or(0) + .min(MAX_JITO_SUBMIT_RESPONSE_BYTES); + let mut body = Vec::with_capacity(initial_capacity); + while let Some(chunk) = + response + .chunk() + .await + .map_err(|error| SubmitTransportError::Failure { + message: error.to_string(), + })? + { + let remaining = MAX_JITO_SUBMIT_RESPONSE_BYTES.saturating_sub(body.len()); + if chunk.len() > remaining { + return Err(SubmitTransportError::Failure { + message: format!( + "response body exceeded max size of {MAX_JITO_SUBMIT_RESPONSE_BYTES} bytes" + ), + }); + } + body.extend_from_slice(&chunk); + } + Ok(body) +} + #[cfg(test)] +#[allow(clippy::indexing_slicing, clippy::panic)] mod tests { use super::*; + use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + net::TcpListener, + }; + + async fn spawn_http_response_server(response: String) -> String { + let listener = TcpListener::bind("127.0.0.1:0").await; + assert!(listener.is_ok()); + let listener = listener.unwrap_or_else(|error| panic!("{error}")); + let addr = listener.local_addr(); + assert!(addr.is_ok()); + let addr = addr.unwrap_or_else(|error| panic!("{error}")); + tokio::spawn(async move { + let accepted = listener.accept().await; + assert!(accepted.is_ok()); + let (mut stream, _) = accepted.unwrap_or_else(|error| panic!("{error}")); + let mut buffer = [0_u8; 4096]; + let read = stream.read(&mut buffer).await; + assert!(read.is_ok()); + let write = stream.write_all(response.as_bytes()).await; + assert!(write.is_ok()); + }); + format!("http://{addr}") + } #[test] fn request_url_uses_transactions_path() { @@ -315,6 +397,15 @@ mod tests { assert_eq!(config.request_timeout, Duration::from_secs(10)); } + #[test] + fn transport_accepts_zero_timeout_config() { + let transport = JitoJsonRpcTransport::with_config(JitoTransportConfig { + endpoint: JitoBlockEngineEndpoint::default(), + request_timeout: Duration::ZERO, + }); + assert!(transport.is_ok()); + } + #[test] fn regional_endpoint_uses_documented_slug() { let endpoint = JitoBlockEngineEndpoint::mainnet_region(JitoBlockEngineRegion::Frankfurt); @@ -324,4 +415,57 @@ mod tests { "https://frankfurt.mainnet.block-engine.jito.wtf" ); } + + #[tokio::test] + async fn jito_transport_rejects_redirects() { + let parsed_url = Url::parse( + &spawn_http_response_server( + "HTTP/1.1 307 Temporary Redirect\r\nlocation: http://127.0.0.1/\r\ncontent-length: 0\r\nconnection: close\r\n\r\n" + .to_owned(), + ) + .await, + ); + assert!(parsed_url.is_ok()); + let parsed_url = parsed_url.unwrap_or_else(|error| panic!("{error}")); + let transport = + JitoJsonRpcTransport::with_endpoint(JitoBlockEngineEndpoint::custom(parsed_url)); + assert!(transport.is_ok()); + let transport = transport.unwrap_or_else(|error| panic!("{error}")); + + let error = transport + .submit_jito(&[1, 2, 3], &JitoSubmitConfig::default()) + .await; + assert!(error.is_err()); + let error = match error { + Ok(_response) => panic!("redirect should fail"), + Err(error) => error, + }; + assert!(error.to_string().contains("redirect")); + } + + #[tokio::test] + async fn jito_transport_rejects_oversized_responses() { + let endpoint = spawn_http_response_server(format!( + "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n", + MAX_JITO_SUBMIT_RESPONSE_BYTES.saturating_add(1) + )) + .await; + let parsed_url = Url::parse(&endpoint); + assert!(parsed_url.is_ok()); + let parsed_url = parsed_url.unwrap_or_else(|error| panic!("{error}")); + let transport = + JitoJsonRpcTransport::with_endpoint(JitoBlockEngineEndpoint::custom(parsed_url)); + assert!(transport.is_ok()); + let transport = transport.unwrap_or_else(|error| panic!("{error}")); + + let error = transport + .submit_jito(&[1, 2, 3], &JitoSubmitConfig::default()) + .await; + assert!(error.is_err()); + let error = match error { + Ok(_response) => panic!("oversized body should fail"), + Err(error) => error, + }; + assert!(error.to_string().contains("exceeded max size")); + } } diff --git a/crates/sof-tx/src/submit/jito_grpc.rs b/crates/sof-tx/src/submit/jito_grpc.rs index e0279f7a..6654b279 100644 --- a/crates/sof-tx/src/submit/jito_grpc.rs +++ b/crates/sof-tx/src/submit/jito_grpc.rs @@ -1,19 +1,25 @@ //! Jito searcher gRPC bundle transport implementation. +use std::time::Duration; + use async_trait::async_trait; +use sof_support::time_support::nonzero_duration_or; use tonic::{ Request, Status, client::Grpc, - codec::ProstCodec, codegen::http::uri::PathAndQuery, transport::{Channel, ClientTlsConfig, Endpoint}, }; +use tonic_prost::ProstCodec; use super::{ JitoBlockEngineEndpoint, JitoSubmitConfig, JitoSubmitResponse, JitoSubmitTransport, JitoTransportConfig, SubmitTransportError, }; +/// Default timeout used for Jito gRPC connect and request deadlines. +const DEFAULT_JITO_GRPC_REQUEST_TIMEOUT: Duration = Duration::from_secs(10); + /// Minimal shared header message for Jito bundle requests. #[derive(Clone, PartialEq, ::prost::Message)] struct SharedHeader {} @@ -107,6 +113,8 @@ impl JitoGrpcTransport { endpoint: block_engine_endpoint, request_timeout, } = config; + let request_timeout = + nonzero_duration_or(request_timeout, DEFAULT_JITO_GRPC_REQUEST_TIMEOUT); let endpoint_url = block_engine_endpoint.as_url().to_owned(); let mut transport_endpoint = Endpoint::from_shared(endpoint_url.clone()).map_err(|error| { @@ -176,7 +184,7 @@ impl JitoGrpcTransport { .unary( Request::new(request), PathAndQuery::from_static("/searcher.SearcherService/SendBundle"), - ProstCodec::default(), + ProstCodec::::default(), ) .await?; Ok(response.into_inner()) @@ -225,4 +233,13 @@ mod tests { .unwrap_or_default(); assert_eq!(packet_count, 1); } + + #[tokio::test(flavor = "current_thread")] + async fn jito_grpc_transport_accepts_zero_timeout_config() { + let transport = JitoGrpcTransport::with_config(JitoTransportConfig { + endpoint: JitoBlockEngineEndpoint::default(), + request_timeout: Duration::ZERO, + }); + assert!(transport.is_ok()); + } } diff --git a/crates/sof-tx/src/submit/rpc.rs b/crates/sof-tx/src/submit/rpc.rs index 944d5d33..c1d3ef9d 100644 --- a/crates/sof-tx/src/submit/rpc.rs +++ b/crates/sof-tx/src/submit/rpc.rs @@ -4,10 +4,15 @@ use std::time::Duration; use async_trait::async_trait; use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64_STANDARD}; +use reqwest::redirect::Policy; use serde::{Deserialize, Serialize}; +use serde_json::from_slice as json_from_slice; use super::{RpcSubmitConfig, RpcSubmitTransport, SubmitTransportError}; +/// Maximum HTTP body size accepted from JSON-RPC submit responses. +const MAX_RPC_SUBMIT_RESPONSE_BYTES: usize = 64 * 1024; + /// JSON-RPC transport that submits encoded transactions via `sendTransaction`. #[derive(Debug, Clone)] pub struct JsonRpcTransport { @@ -25,6 +30,8 @@ impl JsonRpcTransport { /// Returns [`SubmitTransportError::Config`] when HTTP client creation fails. pub fn new(rpc_url: impl Into) -> Result { let client = reqwest::Client::builder() + .redirect(Policy::none()) + .connect_timeout(Duration::from_secs(10)) .timeout(Duration::from_secs(10)) .build() .map_err(|error| SubmitTransportError::Config { @@ -101,6 +108,11 @@ impl RpcSubmitTransport for JsonRpcTransport { .map_err(|error| SubmitTransportError::Failure { message: error.to_string(), })?; + if response.status().is_redirection() { + return Err(SubmitTransportError::Failure { + message: format!("unexpected redirect response: {}", response.status()), + }); + } let response = response @@ -109,13 +121,11 @@ impl RpcSubmitTransport for JsonRpcTransport { message: error.to_string(), })?; + let response_body = read_http_response_bytes_bounded(response).await?; let parsed: JsonRpcResponse = - response - .json() - .await - .map_err(|error| SubmitTransportError::Failure { - message: error.to_string(), - })?; + json_from_slice(&response_body).map_err(|error| SubmitTransportError::Failure { + message: error.to_string(), + })?; if let Some(signature) = parsed.result { return Ok(signature); @@ -131,3 +141,119 @@ impl RpcSubmitTransport for JsonRpcTransport { }) } } + +/// Reads one submit response body while enforcing a fixed maximum byte budget. +async fn read_http_response_bytes_bounded( + mut response: reqwest::Response, +) -> Result, SubmitTransportError> { + if response + .content_length() + .is_some_and(|content_length| content_length > MAX_RPC_SUBMIT_RESPONSE_BYTES as u64) + { + return Err(SubmitTransportError::Failure { + message: format!( + "response body exceeded max size of {MAX_RPC_SUBMIT_RESPONSE_BYTES} bytes" + ), + }); + } + + let initial_capacity = response + .content_length() + .and_then(|content_length| usize::try_from(content_length).ok()) + .unwrap_or(0) + .min(MAX_RPC_SUBMIT_RESPONSE_BYTES); + let mut body = Vec::with_capacity(initial_capacity); + while let Some(chunk) = + response + .chunk() + .await + .map_err(|error| SubmitTransportError::Failure { + message: error.to_string(), + })? + { + let remaining = MAX_RPC_SUBMIT_RESPONSE_BYTES.saturating_sub(body.len()); + if chunk.len() > remaining { + return Err(SubmitTransportError::Failure { + message: format!( + "response body exceeded max size of {MAX_RPC_SUBMIT_RESPONSE_BYTES} bytes" + ), + }); + } + body.extend_from_slice(&chunk); + } + Ok(body) +} + +#[cfg(test)] +#[allow(clippy::indexing_slicing, clippy::panic)] +mod tests { + use super::*; + use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + net::TcpListener, + }; + + async fn spawn_http_response_server(response: String) -> String { + let listener = TcpListener::bind("127.0.0.1:0").await; + assert!(listener.is_ok()); + let listener = listener.unwrap_or_else(|error| panic!("{error}")); + let addr = listener.local_addr(); + assert!(addr.is_ok()); + let addr = addr.unwrap_or_else(|error| panic!("{error}")); + tokio::spawn(async move { + let accepted = listener.accept().await; + assert!(accepted.is_ok()); + let (mut stream, _) = accepted.unwrap_or_else(|error| panic!("{error}")); + let mut buffer = [0_u8; 4096]; + let read = stream.read(&mut buffer).await; + assert!(read.is_ok()); + let write = stream.write_all(response.as_bytes()).await; + assert!(write.is_ok()); + }); + format!("http://{addr}") + } + + #[tokio::test] + async fn json_rpc_transport_rejects_redirects() { + let endpoint = spawn_http_response_server( + "HTTP/1.1 307 Temporary Redirect\r\nlocation: http://127.0.0.1/\r\ncontent-length: 0\r\nconnection: close\r\n\r\n" + .to_owned(), + ) + .await; + let transport = JsonRpcTransport::new(endpoint); + assert!(transport.is_ok()); + let transport = transport.unwrap_or_else(|error| panic!("{error}")); + + let error = transport + .submit_rpc(&[1, 2, 3], &RpcSubmitConfig::default()) + .await; + assert!(error.is_err()); + let error = match error { + Ok(_signature) => panic!("redirect should fail"), + Err(error) => error, + }; + assert!(error.to_string().contains("redirect")); + } + + #[tokio::test] + async fn json_rpc_transport_rejects_oversized_responses() { + let endpoint = spawn_http_response_server(format!( + "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n", + MAX_RPC_SUBMIT_RESPONSE_BYTES.saturating_add(1) + )) + .await; + let transport = JsonRpcTransport::new(endpoint); + assert!(transport.is_ok()); + let transport = transport.unwrap_or_else(|error| panic!("{error}")); + + let error = transport + .submit_rpc(&[1, 2, 3], &RpcSubmitConfig::default()) + .await; + assert!(error.is_err()); + let error = match error { + Ok(_signature) => panic!("oversized body should fail"), + Err(error) => error, + }; + assert!(error.to_string().contains("exceeded max size")); + } +} diff --git a/crates/sof-tx/src/submit/tests.rs b/crates/sof-tx/src/submit/tests.rs index 7a2bc29c..ff79de02 100644 --- a/crates/sof-tx/src/submit/tests.rs +++ b/crates/sof-tx/src/submit/tests.rs @@ -12,6 +12,7 @@ use std::{ }; use async_trait::async_trait; +use sof_support::{bench::avg_ns_per_iteration, short_vec::decode_short_u16_len_prefix}; use sof_types::SignatureBytes; use solana_keypair::Keypair; use solana_signature::Signature; @@ -262,20 +263,6 @@ fn target(port: u16) -> LeaderTarget { LeaderTarget::new(None, SocketAddr::from(([127, 0, 0, 1], port))) } -/// Decodes the first short-vec length prefix in a serialized transaction. -fn decode_short_vec_len(bytes: &[u8]) -> Option<(usize, usize)> { - let mut value = 0_usize; - let mut shift = 0_u32; - for (idx, byte) in bytes.iter().copied().take(3).enumerate() { - value |= usize::from(byte & 0x7f) << shift; - if byte & 0x80 == 0 { - return Some((value, idx.saturating_add(1))); - } - shift = shift.saturating_add(7); - } - None -} - /// Rewrites the first signature bytes so repeated profile iterations do not trip dedupe. fn rewrite_first_signature(bytes: &mut [u8], seed: u64) { const BYTE_SHIFTS: [u32; 64] = [ @@ -283,7 +270,7 @@ fn rewrite_first_signature(bytes: &mut [u8], seed: u64) { 0, 8, 16, 24, 32, 40, 48, 56, 0, 8, 16, 24, 32, 40, 48, 56, 0, 8, 16, 24, 32, 40, 48, 56, 0, 8, 16, 24, 32, 40, 48, 56, 0, 8, 16, 24, 32, 40, 48, 56, ]; - let decoded = decode_short_vec_len(bytes); + let decoded = decode_short_u16_len_prefix(bytes); assert!(decoded.is_some()); let (signature_count, offset) = decoded.unwrap_or((0, 0)); assert!(signature_count > 0); @@ -1553,7 +1540,15 @@ async fn submit_rpc_only_profile_fixture() { .await; assert!(result.is_ok()); } - println!("rpc_only_us={}", start.elapsed().as_micros()); + let elapsed = start.elapsed(); + let avg_ns = avg_ns_per_iteration(elapsed, iterations); + println!( + "submit_rpc_only_profile_fixture iterations={} rpc_only_us={} avg_ns_per_iteration={} avg_us_per_iteration={:.3}", + iterations, + elapsed.as_micros(), + avg_ns, + avg_ns as f64 / 1_000.0 + ); } #[tokio::test] @@ -1585,7 +1580,15 @@ async fn submit_jito_only_profile_fixture() { .await; assert!(result.is_ok()); } - println!("jito_only_us={}", start.elapsed().as_micros()); + let elapsed = start.elapsed(); + let avg_ns = avg_ns_per_iteration(elapsed, iterations); + println!( + "submit_jito_only_profile_fixture iterations={} jito_only_us={} avg_ns_per_iteration={} avg_us_per_iteration={:.3}", + iterations, + elapsed.as_micros(), + avg_ns, + avg_ns as f64 / 1_000.0 + ); } #[tokio::test] @@ -1618,7 +1621,15 @@ async fn submit_direct_only_profile_fixture() { .await; assert!(result.is_ok()); } - println!("direct_only_us={}", start.elapsed().as_micros()); + let elapsed = start.elapsed(); + let avg_ns = avg_ns_per_iteration(elapsed, iterations); + println!( + "submit_direct_only_profile_fixture iterations={} direct_only_us={} avg_ns_per_iteration={} avg_us_per_iteration={:.3}", + iterations, + elapsed.as_micros(), + avg_ns, + avg_ns as f64 / 1_000.0 + ); } #[tokio::test] @@ -1663,7 +1674,15 @@ async fn submit_hybrid_fallback_profile_fixture() { .await; assert!(result.is_ok()); } - println!("hybrid_fallback_us={}", start.elapsed().as_micros()); + let elapsed = start.elapsed(); + let avg_ns = avg_ns_per_iteration(elapsed, iterations); + println!( + "submit_hybrid_fallback_profile_fixture iterations={} hybrid_fallback_us={} avg_ns_per_iteration={} avg_us_per_iteration={:.3}", + iterations, + elapsed.as_micros(), + avg_ns, + avg_ns as f64 / 1_000.0 + ); } #[tokio::test] @@ -1709,5 +1728,13 @@ async fn submit_all_at_once_profile_fixture() { .await; assert!(result.is_ok()); } - println!("all_at_once_us={}", start.elapsed().as_micros()); + let elapsed = start.elapsed(); + let avg_ns = avg_ns_per_iteration(elapsed, iterations); + println!( + "submit_all_at_once_profile_fixture iterations={} all_at_once_us={} avg_ns_per_iteration={} avg_us_per_iteration={:.3}", + iterations, + elapsed.as_micros(), + avg_ns, + avg_ns as f64 / 1_000.0 + ); } diff --git a/crates/sof-tx/src/submit/types.rs b/crates/sof-tx/src/submit/types.rs index be2bdbfd..5374e6df 100644 --- a/crates/sof-tx/src/submit/types.rs +++ b/crates/sof-tx/src/submit/types.rs @@ -1,7 +1,7 @@ //! Shared submission types, errors, and transport traits. use std::{ - collections::HashMap, + collections::{HashMap, VecDeque}, hash::{Hash, Hasher}, sync::{ Arc, @@ -319,6 +319,16 @@ impl DirectSubmitConfig { /// Returns this config with minimum valid retry counters. #[must_use] pub const fn normalized(self) -> Self { + let per_target_timeout = if self.per_target_timeout.is_zero() { + Duration::from_millis(1) + } else { + self.per_target_timeout + }; + let global_timeout = if self.global_timeout.is_zero() { + Duration::from_millis(1) + } else { + self.global_timeout + }; let direct_target_rounds = if self.direct_target_rounds == 0 { 1 } else { @@ -349,9 +359,14 @@ impl DirectSubmitConfig { } else { self.agave_rebroadcast_interval }; + let latency_probe_timeout = if self.latency_probe_timeout.is_zero() { + Duration::from_millis(1) + } else { + self.latency_probe_timeout + }; Self { - per_target_timeout: self.per_target_timeout, - global_timeout: self.global_timeout, + per_target_timeout, + global_timeout, direct_target_rounds, direct_submit_attempts, hybrid_direct_attempts, @@ -361,7 +376,7 @@ impl DirectSubmitConfig { agave_rebroadcast_interval, hybrid_rpc_broadcast: self.hybrid_rpc_broadcast, latency_aware_targeting: self.latency_aware_targeting, - latency_probe_timeout: self.latency_probe_timeout, + latency_probe_timeout, latency_probe_port: self.latency_probe_port, latency_probe_max_targets, } @@ -932,6 +947,8 @@ impl TxSubmitOutcomeReporter for TxToxicFlowTelemetry { pub(crate) struct TxSuppressionCache { /// Active suppression entries keyed by opportunity identity. entries: HashMap, + /// Insertion order for eviction, including stale superseded timestamps. + order: VecDeque<(TxSubmitSuppressionKey, SystemTime)>, } impl TxSuppressionCache { @@ -950,16 +967,27 @@ impl TxSuppressionCache { pub(crate) fn insert_all(&mut self, keys: &[TxSubmitSuppressionKey], now: SystemTime) { for key in keys { let _ = self.entries.insert(key.clone(), now); + self.order.push_back((key.clone(), now)); } } /// Removes entries older than the current TTL window. fn evict_expired(&mut self, now: SystemTime, ttl: Duration) { - self.entries.retain(|_, inserted_at| { - now.duration_since(*inserted_at) + while let Some((_, front_inserted_at)) = self.order.front() { + let still_live = now + .duration_since(*front_inserted_at) .map(|elapsed| elapsed <= ttl) - .unwrap_or(false) - }); + .unwrap_or(false); + if still_live { + break; + } + let Some((key, queued_inserted_at)) = self.order.pop_front() else { + break; + }; + if self.entries.get(&key) == Some(&queued_inserted_at) { + let _ = self.entries.remove(&key); + } + } } } @@ -997,3 +1025,90 @@ pub trait DirectSubmitTransport: Send + Sync { config: &DirectSubmitConfig, ) -> Result; } + +#[cfg(test)] +mod tests { + use super::*; + use std::{hint::black_box, time::Instant}; + + use sof_support::bench::profile_iterations; + + #[test] + #[ignore = "profiling fixture for submit suppression cache churn"] + fn suppression_cache_profile_fixture() { + let iterations = profile_iterations(50_000); + let ttl = Duration::from_millis(750); + let base = SystemTime::UNIX_EPOCH + Duration::from_secs(1); + let keys = (0_u8..64) + .map(|value| TxSubmitSuppressionKey::Opportunity([value; 32])) + .collect::>(); + let mut cache = TxSuppressionCache::default(); + + let started = Instant::now(); + for (iteration, key) in keys.iter().cycle().take(iterations).enumerate() { + let now = base + Duration::from_millis(u64::try_from(iteration % 2_000).unwrap_or(0)); + cache.insert_all(std::slice::from_ref(key), now); + black_box(cache.is_suppressed(std::slice::from_ref(key), now, ttl)); + } + let elapsed = started.elapsed(); + let avg_ns_per_iteration = elapsed.as_nanos() / u128::try_from(iterations).unwrap_or(1); + let avg_us_per_iteration = avg_ns_per_iteration as f64 / 1_000.0; + + eprintln!( + "suppression_cache_profile_fixture iterations={} elapsed_us={} avg_ns_per_iteration={} avg_us_per_iteration={:.3} entries={}", + iterations, + elapsed.as_micros(), + avg_ns_per_iteration, + avg_us_per_iteration, + cache.entries.len(), + ); + } + + #[test] + fn suppression_cache_keeps_refreshed_entry_live() { + let mut cache = TxSuppressionCache::default(); + let key = TxSubmitSuppressionKey::Opportunity([7_u8; 32]); + let ttl = Duration::from_millis(750); + let first_inserted_at = SystemTime::UNIX_EPOCH + Duration::from_secs(1); + let refreshed_at = first_inserted_at + Duration::from_millis(500); + + cache.insert_all(std::slice::from_ref(&key), first_inserted_at); + cache.insert_all(std::slice::from_ref(&key), refreshed_at); + + assert!(cache.is_suppressed( + std::slice::from_ref(&key), + refreshed_at + Duration::from_millis(100), + ttl, + )); + assert!(!cache.is_suppressed( + std::slice::from_ref(&key), + refreshed_at + ttl + Duration::from_millis(1), + ttl, + )); + } + + #[test] + fn direct_submit_config_clamps_zero_timeouts() { + let normalized = DirectSubmitConfig { + per_target_timeout: Duration::ZERO, + global_timeout: Duration::ZERO, + direct_target_rounds: 1, + direct_submit_attempts: 1, + hybrid_direct_attempts: 1, + rebroadcast_interval: Duration::from_millis(5), + agave_rebroadcast_enabled: false, + agave_rebroadcast_window: Duration::ZERO, + agave_rebroadcast_interval: Duration::from_millis(5), + hybrid_rpc_broadcast: false, + latency_aware_targeting: true, + latency_probe_timeout: Duration::ZERO, + latency_probe_port: None, + latency_probe_max_targets: 1, + } + .normalized(); + + assert_eq!(normalized.per_target_timeout, Duration::from_millis(1)); + assert_eq!(normalized.global_timeout, Duration::from_millis(1)); + assert_eq!(normalized.latency_probe_timeout, Duration::from_millis(1)); + } +} diff --git a/crates/sof-tx/tests/kernel_bypass_af_xdp_e2e.rs b/crates/sof-tx/tests/kernel_bypass_af_xdp_e2e.rs index 6865b478..b9e575e7 100644 --- a/crates/sof-tx/tests/kernel_bypass_af_xdp_e2e.rs +++ b/crates/sof-tx/tests/kernel_bypass_af_xdp_e2e.rs @@ -3,11 +3,15 @@ #![cfg(all(target_os = "linux", feature = "kernel-bypass"))] use std::{ + env, ffi::CString, io, - net::{Ipv4Addr, SocketAddr, UdpSocket}, + net::{IpAddr, Ipv4Addr, SocketAddr, UdpSocket}, + path::Path, process::Command, + slice, sync::{Arc, Mutex}, + thread, time::Duration, }; @@ -17,6 +21,7 @@ use sof_tx::{ KernelBypassDatagramSocket, KernelBypassDirectTransport, LeaderTarget, RoutingPolicy, submit::{DirectSubmitConfig, DirectSubmitTransport}, }; +use tokio::runtime::Builder; use xdp::{ RingConfigBuilder, Umem, WakableRings, packet::PacketError, @@ -97,8 +102,8 @@ impl AfXdpKernelBypassSocket { impl KernelBypassDatagramSocket for AfXdpKernelBypassSocket { async fn send_to(&self, payload: &[u8], target: SocketAddr) -> io::Result { let dst = match target.ip() { - std::net::IpAddr::V4(ip) => ip, - std::net::IpAddr::V6(_) => { + IpAddr::V4(ip) => ip, + IpAddr::V6(_) => { return Err(io::Error::new( io::ErrorKind::InvalidInput, "AF_XDP test socket only supports IPv4 targets", @@ -370,7 +375,7 @@ fn read_link_packets(interface_name: &str) -> Result<(u64, u64), Box Result<(), Box> { +fn run_unshare(current_exe: &Path) -> Result<(), Box> { for candidate in [ "/usr/bin/unshare", "/bin/unshare", @@ -456,15 +461,13 @@ fn run_inner() -> Result<(), Box> { ..DirectSubmitConfig::default() }; - let runtime = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build()?; + let runtime = Builder::new_current_thread().enable_all().build()?; let selected = runtime .block_on(async { transport .submit_direct( &payload, - std::slice::from_ref(&target), + slice::from_ref(&target), RoutingPolicy::default(), &config, ) @@ -481,7 +484,7 @@ fn run_inner() -> Result<(), Box> { let mut sender_tx_after = sender_tx_before; let mut receiver_rx_after = receiver_rx_before; for _ in 0..10 { - std::thread::sleep(Duration::from_millis(50)); + thread::sleep(Duration::from_millis(50)); (sender_tx_after, _) = read_link_packets(VETH_SENDER)?; (_, receiver_rx_after) = read_link_packets(VETH_RECEIVER)?; if sender_tx_after > sender_tx_before && receiver_rx_after > receiver_rx_before { @@ -532,8 +535,8 @@ fn run_inner() -> Result<(), Box> { #[test] #[ignore = "requires Linux user namespaces and AF_XDP support"] fn kernel_bypass_af_xdp_direct_submit_e2e() -> Result<(), Box> { - if std::env::var_os(INNER_ENV).is_none() { - let current_exe = std::env::current_exe()?; + if env::var_os(INNER_ENV).is_none() { + let current_exe = env::current_exe()?; run_unshare(¤t_exe)?; return Ok(()); } diff --git a/crates/sof-types/Cargo.toml b/crates/sof-types/Cargo.toml index 1b5561de..e68be189 100644 --- a/crates/sof-types/Cargo.toml +++ b/crates/sof-types/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "sof-types" -version = "0.18.1" +version = "0.18.2" edition.workspace = true description = "Stable SOF-owned primitive types shared across SOF crates" license = "Apache-2.0 OR MIT" diff --git a/docs/architecture/runtime-extension-hooks.md b/docs/architecture/runtime-extension-hooks.md index 08f8ed0b..934c4a9f 100644 --- a/docs/architecture/runtime-extension-hooks.md +++ b/docs/architecture/runtime-extension-hooks.md @@ -69,12 +69,20 @@ Startup manifests can request runtime-managed resources: 3. `TcpConnector` 4. `WsConnector` +Manifest validation rules: + +1. extension names must be non-empty, +2. `resource_id` must be non-empty, +3. `Shared { tag }` tags must be non-empty, +4. `read_buffer_bytes` is bounded by runtime startup validation. + `WsConnector` supports full WebSocket protocol handling: 1. `ws://` and `wss://` URLs, 2. opening handshake, 3. decoded message frame delivery to extension dispatch, -4. `Ping` / `Pong` handling. +4. `Ping` / `Pong` handling, +5. bounded frame/message limits derived from `read_buffer_bytes`. ## Visibility and Sharing diff --git a/docs/gitbook/crates/sof-tx.md b/docs/gitbook/crates/sof-tx.md index 377f0830..f7f3f601 100644 --- a/docs/gitbook/crates/sof-tx.md +++ b/docs/gitbook/crates/sof-tx.md @@ -309,9 +309,9 @@ If the conceptual docs stop too early for what you need to build, open these nex ## Feature Flags ```toml -sof-tx = { version = "0.18.1", features = ["sof-adapters"] } -sof-tx = { version = "0.18.1", features = ["kernel-bypass"] } -sof-tx = { version = "0.18.1", features = ["jito-grpc"] } +sof-tx = { version = "0.18.2", features = ["sof-adapters"] } +sof-tx = { version = "0.18.2", features = ["kernel-bypass"] } +sof-tx = { version = "0.18.2", features = ["jito-grpc"] } ``` ## Good Fit diff --git a/docs/gitbook/getting-started/install-sof.md b/docs/gitbook/getting-started/install-sof.md index 3a15d21f..87506ba1 100644 --- a/docs/gitbook/getting-started/install-sof.md +++ b/docs/gitbook/getting-started/install-sof.md @@ -35,8 +35,8 @@ Only add `sof-gossip-tuning` if you are embedding `sof` and want typed host/runt Common feature combinations: ```toml -sof = { version = "0.18.1", features = ["gossip-bootstrap"] } -sof-tx = { version = "0.18.1", features = ["sof-adapters"] } +sof = { version = "0.18.2", features = ["gossip-bootstrap"] } +sof-tx = { version = "0.18.2", features = ["sof-adapters"] } ``` ## Choose Your Starting Point @@ -50,7 +50,7 @@ Start with the app shape that matches what you need to build right now. ```toml [dependencies] tokio = { version = "1", features = ["macros", "rt-multi-thread"] } -sof = "0.18.1" +sof = "0.18.2" ``` `src/main.rs`: @@ -74,7 +74,7 @@ Use this when you need ingest, plugin events, datasets, or local control-plane s [dependencies] async-trait = "0.1" tokio = { version = "1", features = ["macros", "rt-multi-thread"] } -sof = "0.18.1" +sof = "0.18.2" tracing = "0.1" ``` @@ -124,7 +124,7 @@ Use this when you already know you want to consume SOF events in your own code. ```toml [dependencies] -sof-tx = "0.18.1" +sof-tx = "0.18.2" tokio = { version = "1", features = ["macros", "rt-multi-thread"] } ``` diff --git a/scripts/verify-publishable-archives.sh b/scripts/verify-publishable-archives.sh index 391a35fb..ac6e1ff0 100755 --- a/scripts/verify-publishable-archives.sh +++ b/scripts/verify-publishable-archives.sh @@ -1,7 +1,7 @@ #!/usr/bin/env bash set -euo pipefail -crates=("sof-types" "sof-gossip-tuning" "sof-solana-gossip" "sof" "sof-tx" "sof-solana-compat") +crates=("sof-types" "sof-gossip-tuning" "sof-support" "sof-solana-gossip" "sof" "sof-tx" "sof-solana-compat") package_dir="target/package" verify_root="$(mktemp -d)" cargo_home_root="$(mktemp -d)" @@ -16,6 +16,7 @@ sof-solana-gossip = { path = "$(pwd)/crates/sof-solana-gossip" } sof = { path = "$(pwd)/crates/sof-observer" } sof-tx = { path = "$(pwd)/crates/sof-tx" } sof-solana-compat = { path = "$(pwd)/crates/sof-solana-compat" } +sof-support = { path = "$(pwd)/crates/sof-support" } EOF export CARGO_HOME="${cargo_home_root}" @@ -62,6 +63,7 @@ extract_crate() { package_crate "sof-types" "" package_crate "sof-gossip-tuning" "" +package_crate "sof-support" "" package_crate "sof-solana-gossip" "--no-verify" package_crate "sof" "--no-verify" package_crate "sof-tx" "--no-verify" @@ -73,6 +75,7 @@ done sof_types_version="$(version_for "sof-types")" sof_gossip_tuning_version="$(version_for "sof-gossip-tuning")" +sof_support_version="$(version_for "sof-support")" sof_solana_gossip_version="$(version_for "sof-solana-gossip")" sof_version="$(version_for "sof")" sof_tx_version="$(version_for "sof-tx")" @@ -84,6 +87,7 @@ resolver = "3" members = [ "sof-types-${sof_types_version}", "sof-gossip-tuning-${sof_gossip_tuning_version}", + "sof-support-${sof_support_version}", "sof-${sof_version}", "sof-tx-${sof_tx_version}", "sof-solana-compat-${sof_solana_compat_version}", @@ -92,6 +96,7 @@ members = [ [patch.crates-io] sof-types = { path = "sof-types-${sof_types_version}" } sof-gossip-tuning = { path = "sof-gossip-tuning-${sof_gossip_tuning_version}" } +sof-support = { path = "sof-support-${sof_support_version}" } sof-solana-gossip = { path = "sof-solana-gossip-${sof_solana_gossip_version}" } sof = { path = "sof-${sof_version}" } sof-tx = { path = "sof-tx-${sof_tx_version}" } diff --git a/vendor/helius-laserstream/Cargo.toml b/vendor/helius-laserstream/Cargo.toml new file mode 100644 index 00000000..8d5933fd --- /dev/null +++ b/vendor/helius-laserstream/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "helius-laserstream" +version = "0.1.9" +edition = "2021" +authors = ["Helius "] +description = "Rust client for Helius LaserStream gRPC with robust reconnection and slot tracking" +license = "MIT" +repository = "https://github.com/helius-labs/laserstream" +keywords = ["solana", "laserstream", "yellowstone", "grpc", "blockchain"] +readme = "README.md" + +[lib] +name = "helius_laserstream" +path = "src/lib.rs" + +[dependencies] +async-stream = "0.3" +futures = "0.3" +futures-channel = "0.3" +futures-util = { version = "0.3", features = ["sink"] } +laserstream-core-client = "9.0.2" +laserstream-core-proto = { version = "9.0.2", default-features = false, features = ["tonic", "tonic-compression"] } +serde = { version = "1.0", features = ["derive"] } +thiserror = "1.0" +tokio = { version = "1", features = ["rt-multi-thread", "macros", "time", "sync"] } +tracing = "0.1" +url = "2.5" +uuid = { version = "1.7.0", features = ["v4"] } diff --git a/vendor/helius-laserstream/README.md b/vendor/helius-laserstream/README.md new file mode 100644 index 00000000..b813ff9b --- /dev/null +++ b/vendor/helius-laserstream/README.md @@ -0,0 +1,5 @@ +Local SOF patch of `helius-laserstream` `0.1.9`. + +This keeps the same crate API surface used by SOF while trimming unused direct +dependencies from the published crate, especially the old `tonic 0.12` path +that is not required by SOF's build. diff --git a/vendor/helius-laserstream/src/client.rs b/vendor/helius-laserstream/src/client.rs new file mode 100644 index 00000000..e101d0af --- /dev/null +++ b/vendor/helius-laserstream/src/client.rs @@ -0,0 +1,530 @@ +use crate::{LaserstreamConfig, LaserstreamError, config::CompressionEncoding as ConfigCompressionEncoding}; +use async_stream::stream; +use futures::StreamExt; +use futures_channel::mpsc as futures_mpsc; +use futures_util::{sink::SinkExt, Stream}; +use std::{pin::Pin, time::Duration}; +use tokio::sync::mpsc; +use tokio::time::sleep; +use laserstream_core_proto::tonic::{ + Status, Request, metadata::MetadataValue, transport::Endpoint, codec::CompressionEncoding, +}; +use tracing::{error, instrument, warn}; +use uuid; +use laserstream_core_client::{ClientTlsConfig, Interceptor}; +use laserstream_core_proto::prelude::{geyser_client::GeyserClient}; +use laserstream_core_proto::geyser::{ + subscribe_update::UpdateOneof, SubscribeRequest, SubscribeRequestFilterSlots, + SubscribeRequestPing, SubscribeUpdate, + SubscribePreprocessedRequest, SubscribePreprocessedUpdate, +}; + +const HARD_CAP_RECONNECT_ATTEMPTS: u32 = (20 * 60) / 5; // 20 mins / 5 sec interval +const FIXED_RECONNECT_INTERVAL_MS: u64 = 5000; // 5 seconds fixed interval +const SDK_NAME: &str = "laserstream-rust"; +const SDK_VERSION: &str = "0.1.9"; + +/// Custom interceptor that adds SDK metadata headers to all gRPC requests +#[derive(Clone)] +struct SdkMetadataInterceptor { + x_token: Option, +} + +impl SdkMetadataInterceptor { + fn new(api_key: String) -> Result { + let x_token = if !api_key.is_empty() { + Some(api_key.parse().map_err(|e| { + Status::invalid_argument(format!("Invalid API key: {}", e)) + })?) + } else { + None + }; + Ok(Self { x_token }) + } +} + +impl Interceptor for SdkMetadataInterceptor { + fn call(&mut self, mut request: Request<()>) -> Result, Status> { + // Add x-token if present + if let Some(ref x_token) = self.x_token { + request.metadata_mut().insert("x-token", x_token.clone()); + } + + // Add SDK metadata headers + request.metadata_mut().insert("x-sdk-name", MetadataValue::from_static(SDK_NAME)); + request.metadata_mut().insert("x-sdk-version", MetadataValue::from_static(SDK_VERSION)); + + Ok(request) + } +} + +/// Handle for managing a bidirectional streaming subscription. +#[derive(Clone)] +pub struct StreamHandle { + write_tx: mpsc::UnboundedSender, +} + +impl StreamHandle { + /// Send a new subscription request to update the active subscription. + pub async fn write(&self, request: SubscribeRequest) -> Result<(), LaserstreamError> { + self.write_tx + .send(request) + .map_err(|_| LaserstreamError::ConnectionError("Write channel closed".to_string())) + } +} + +/// Establishes a gRPC connection, handles the subscription lifecycle, +/// and provides a stream of updates. Automatically reconnects on failure. +#[instrument(skip(config, request))] +pub fn subscribe( + config: LaserstreamConfig, + request: SubscribeRequest, +) -> ( + impl Stream>, + StreamHandle, +) { + let (write_tx, mut write_rx) = mpsc::unbounded_channel::(); + let handle = StreamHandle { write_tx }; + let update_stream = stream! { + let mut reconnect_attempts = 0; + let mut tracked_slot: u64 = 0; + + // Determine the effective max reconnect attempts + let effective_max_attempts = config + .max_reconnect_attempts + .unwrap_or(HARD_CAP_RECONNECT_ATTEMPTS) // Default to hard cap if not set + .min(HARD_CAP_RECONNECT_ATTEMPTS); // Enforce hard cap + + // Keep original request for reconnection attempts + let mut current_request = request.clone(); + let internal_slot_sub_id = format!("internal-{}", uuid::Uuid::new_v4().to_string().split('-').next().unwrap()); + + // Get replay behavior from config + let replay_enabled = config.replay; + + // Add internal slot subscription only when replay is enabled + if replay_enabled { + current_request.slots.insert( + internal_slot_sub_id.clone(), + SubscribeRequestFilterSlots { + filter_by_commitment: Some(true), // Use same commitment as user request + ..Default::default() + } + ); + } + + // Clear any user-provided from_slot if replay is disabled + if !replay_enabled { + current_request.from_slot = None; + } + + let api_key_string = config.api_key.clone(); + + loop { + // Drain any pending write requests that arrived during reconnection delay. + // This ensures writes sent while disconnected are included in the next connection. + while let Ok(write_request) = write_rx.try_recv() { + merge_subscribe_requests(&mut current_request, &write_request, &internal_slot_sub_id); + } + + // Always update from_slot on current_request based on tracked_slot. + // This ensures reconnections always use the most recent slot, even after + // a successful connection that subsequently errors on the stream. + if tracked_slot > 0 && replay_enabled { + let commitment_level = current_request.commitment.unwrap_or(0); + let from_slot = match commitment_level { + 0 => tracked_slot.saturating_sub(31), // PROCESSED: rewind by 31 slots + 1 | 2 => tracked_slot, // CONFIRMED/FINALIZED: exact slot + _ => tracked_slot.saturating_sub(31), // Unknown: default to safe behavior + }; + current_request.from_slot = Some(from_slot); + } else if !replay_enabled { + current_request.from_slot = None; + } + + let attempt_request = current_request.clone(); + + match connect_and_subscribe_once(&config, attempt_request, api_key_string.clone()).await { + Ok((sender, stream)) => { + // Successful connection – reset attempt counter so we don't hit the cap + reconnect_attempts = 0; + + // Box sender and stream here before processing + let mut sender: Pin + Send>> = Box::pin(sender); + // Ensure the boxed stream yields Result<_, Status> + let mut stream: Pin> + Send>> = Box::pin(stream); + + // Ping interval timer + let mut ping_interval = tokio::time::interval(Duration::from_secs(30)); + ping_interval.tick().await; // Skip first immediate tick + let mut ping_id = 0i32; + + loop { + tokio::select! { + // Send periodic ping + _ = ping_interval.tick() => { + ping_id = ping_id.wrapping_add(1); + let ping_request = SubscribeRequest { + ping: Some(SubscribeRequestPing { id: ping_id }), + ..Default::default() + }; + let _ = sender.send(ping_request).await; + }, + // Handle incoming messages from the server + result = stream.next() => { + if let Some(result) = result { + match result { + Ok(update) => { + + // Handle ping/pong + if matches!(&update.update_oneof, Some(UpdateOneof::Ping(_))) { + let pong_req = SubscribeRequest { ping: Some(SubscribeRequestPing { id: 1 }), ..Default::default() }; + if let Err(e) = sender.send(pong_req).await { + warn!(error = %e, "Failed to send pong"); + break; + } + continue; + } + + // Do not forward server 'Pong' updates to consumers either + if matches!(&update.update_oneof, Some(UpdateOneof::Pong(_))) { + continue; + } + + // Track the latest slot from any slot update (including internal subscription) + if let Some(UpdateOneof::Slot(s)) = &update.update_oneof { + if replay_enabled { + tracked_slot = s.slot; + } + + // Skip if this slot update is EXCLUSIVELY from our internal subscription + if update.filters.len() == 1 && update.filters.contains(&internal_slot_sub_id) { + continue; + } + } + + // Filter out internal subscription from filters before yielding (only if replay is enabled) + let mut clean_update = update; + if replay_enabled { + clean_update.filters.retain(|f| f != &internal_slot_sub_id); + + // Only yield if there are still filters after cleaning + if !clean_update.filters.is_empty() { + yield Ok(clean_update); + } + } else { + // When replay is disabled, yield all updates as-is + yield Ok(clean_update); + } + } + Err(status) => { + // Yield the error to consumer AND continue with reconnection + warn!(error = %status, "Stream error, will reconnect after 5s delay"); + yield Err(LaserstreamError::Status(status.clone())); + break; + } + } + } else { + // Stream ended + break; + } + } + + // Handle write requests from the user + Some(write_request) = write_rx.recv() => { + // Merge the write_request into current_request so it persists across reconnections + merge_subscribe_requests(&mut current_request, &write_request, &internal_slot_sub_id); + + if let Err(e) = sender.send(write_request).await { + warn!(error = %e, "Failed to send write request"); + break; + } + } + } + } + } + Err(err) => { + // Increment reconnect attempts + reconnect_attempts += 1; + + // Log error internally but don't yield to consumer until max attempts exhausted + error!(error = %err, attempt = reconnect_attempts, max_attempts = effective_max_attempts, "Connection failed, will retry after 5s delay"); + + // Check if exceeded max reconnect attempts + if reconnect_attempts >= effective_max_attempts { + error!(attempts = effective_max_attempts, "Max reconnection attempts reached"); + // Only report error to consumer after exhausting all retries + yield Err(LaserstreamError::MaxReconnectAttempts(Status::cancelled( + format!("Connection failed after {} attempts", effective_max_attempts) + ))); + return; + } + } + } + + // Wait 5s before retry + let delay = Duration::from_millis(FIXED_RECONNECT_INTERVAL_MS); + sleep(delay).await; + } + }; + + (update_stream, handle) +} + +#[instrument(skip(config, request, api_key))] +async fn connect_and_subscribe_once( + config: &LaserstreamConfig, + request: SubscribeRequest, + api_key: String, +) -> Result< + ( + impl futures_util::Sink + Send, + impl Stream> + Send, + ), + Status, +> { + let options = &config.channel_options; + + // Create our custom interceptor with SDK metadata + let interceptor = SdkMetadataInterceptor::new(api_key)?; + + // Build endpoint with all options + let mut endpoint = Endpoint::from_shared(config.endpoint.clone()) + .map_err(|e| Status::internal(format!("Failed to parse endpoint: {}", e)))? + .connect_timeout(Duration::from_secs(options.connect_timeout_secs.unwrap_or(10))) + .timeout(Duration::from_secs(options.timeout_secs.unwrap_or(30))) + .http2_keep_alive_interval(Duration::from_secs(options.http2_keep_alive_interval_secs.unwrap_or(30))) + .keep_alive_timeout(Duration::from_secs(options.keep_alive_timeout_secs.unwrap_or(5))) + .keep_alive_while_idle(options.keep_alive_while_idle.unwrap_or(true)) + .initial_stream_window_size(options.initial_stream_window_size.or(Some(1024 * 1024 * 4))) + .initial_connection_window_size(options.initial_connection_window_size.or(Some(1024 * 1024 * 8))) + .http2_adaptive_window(options.http2_adaptive_window.unwrap_or(true)) + .tcp_nodelay(options.tcp_nodelay.unwrap_or(true)) + .buffer_size(options.buffer_size.or(Some(1024 * 64))); + + if let Some(tcp_keepalive_secs) = options.tcp_keepalive_secs { + endpoint = endpoint.tcp_keepalive(Some(Duration::from_secs(tcp_keepalive_secs))); + } + + // Configure TLS + endpoint = endpoint + .tls_config(ClientTlsConfig::new().with_enabled_roots()) + .map_err(|e| Status::internal(format!("TLS config error: {}", e)))?; + + // Connect to create channel + let channel = endpoint + .connect() + .await + .map_err(|e| Status::unavailable(format!("Connection failed: {}", e)))?; + + // Create geyser client with our custom interceptor + let mut geyser_client = GeyserClient::with_interceptor(channel, interceptor); + + // Configure message size limits + geyser_client = geyser_client + .max_decoding_message_size(options.max_decoding_message_size.unwrap_or(1_000_000_000)) + .max_encoding_message_size(options.max_encoding_message_size.unwrap_or(32_000_000)); + + // Configure compression if specified + if let Some(send_comp) = options.send_compression { + let encoding = match send_comp { + ConfigCompressionEncoding::Gzip => CompressionEncoding::Gzip, + ConfigCompressionEncoding::Zstd => CompressionEncoding::Zstd, + }; + geyser_client = geyser_client.send_compressed(encoding); + } + + // Configure accepted compression encodings + if let Some(ref accept_comps) = options.accept_compression { + for comp in accept_comps { + let encoding = match comp { + ConfigCompressionEncoding::Gzip => CompressionEncoding::Gzip, + ConfigCompressionEncoding::Zstd => CompressionEncoding::Zstd, + }; + geyser_client = geyser_client.accept_compressed(encoding); + } + } + + // Create bidirectional stream + let (mut subscribe_tx, subscribe_rx) = futures_mpsc::unbounded(); + subscribe_tx + .send(request) + .await + .map_err(|e| Status::internal(format!("Failed to send initial request: {}", e)))?; + + let response = geyser_client + .subscribe(subscribe_rx) + .await + .map_err(|e| Status::internal(format!("Subscription failed: {}", e)))?; + + Ok((subscribe_tx, response.into_inner())) +} + +/// Handle for managing a preprocessed subscription (no write support). +#[derive(Clone)] +pub struct PreprocessedStreamHandle; + +/// Establishes a gRPC connection for preprocessed transactions and provides a stream of updates. +/// Automatically reconnects on failure. No slot tracking or replay - just simple reconnection. +#[instrument(skip(config, request))] +pub fn subscribe_preprocessed( + config: LaserstreamConfig, + request: SubscribePreprocessedRequest, +) -> ( + impl Stream>, + PreprocessedStreamHandle, +) { + let handle = PreprocessedStreamHandle; + let update_stream = stream! { + let mut reconnect_attempts = 0; + + // Determine the effective max reconnect attempts + let effective_max_attempts = config + .max_reconnect_attempts + .unwrap_or(HARD_CAP_RECONNECT_ATTEMPTS) + .min(HARD_CAP_RECONNECT_ATTEMPTS); + + loop { + let api_key = config.api_key.clone(); + let request_clone = request.clone(); + + match connect_and_subscribe_preprocessed_once(&config, request_clone, api_key).await { + Ok(mut stream) => { + reconnect_attempts = 0; + + while let Some(result) = stream.next().await { + match result { + Ok(update) => yield Ok(update), + Err(e) => { + warn!(error = %e, "Stream error received"); + break; + } + } + } + } + Err(err) => { + reconnect_attempts += 1; + error!(error = %err, attempt = reconnect_attempts, max_attempts = effective_max_attempts, "Connection failed, will retry after 5s delay"); + + if reconnect_attempts >= effective_max_attempts { + error!(attempts = effective_max_attempts, "Max reconnection attempts reached"); + yield Err(LaserstreamError::MaxReconnectAttempts(Status::cancelled( + format!("Connection failed after {} attempts", effective_max_attempts) + ))); + return; + } + } + } + + let delay = Duration::from_millis(FIXED_RECONNECT_INTERVAL_MS); + sleep(delay).await; + } + }; + + (update_stream, handle) +} + +#[instrument(skip(config, request, api_key))] +async fn connect_and_subscribe_preprocessed_once( + config: &LaserstreamConfig, + request: SubscribePreprocessedRequest, + api_key: String, +) -> Result< + impl Stream> + Send, + Status, +> { + let options = &config.channel_options; + + // Create our custom interceptor with SDK metadata + let interceptor = SdkMetadataInterceptor::new(api_key)?; + + // Build endpoint with all options + let mut endpoint = Endpoint::from_shared(config.endpoint.clone()) + .map_err(|e| Status::internal(format!("Failed to parse endpoint: {}", e)))? + .connect_timeout(Duration::from_secs(options.connect_timeout_secs.unwrap_or(10))) + .timeout(Duration::from_secs(options.timeout_secs.unwrap_or(30))) + .tcp_nodelay(options.tcp_nodelay.unwrap_or(true)) + .tcp_keepalive(Some(Duration::from_secs(options.tcp_keepalive_secs.unwrap_or(30)))) + .http2_keep_alive_interval(Duration::from_secs(options.http2_keep_alive_interval_secs.unwrap_or(30))) + .keep_alive_timeout(Duration::from_secs(options.keep_alive_timeout_secs.unwrap_or(10))) + .keep_alive_while_idle(options.keep_alive_while_idle.unwrap_or(true)); + + endpoint = endpoint + .tls_config(ClientTlsConfig::new().with_enabled_roots()) + .map_err(|e| Status::internal(format!("Failed to configure TLS: {}", e)))?; + + let channel = endpoint + .connect() + .await + .map_err(|e| Status::internal(format!("Failed to connect: {}", e)))?; + + let mut geyser_client = GeyserClient::with_interceptor(channel, interceptor) + .max_decoding_message_size(options.max_decoding_message_size.unwrap_or(1_000_000_000)) + .max_encoding_message_size(options.max_encoding_message_size.unwrap_or(32_000_000)); + + // Apply compression if specified + if let Some(compression) = &options.send_compression { + let encoding = match compression { + ConfigCompressionEncoding::Gzip => CompressionEncoding::Gzip, + ConfigCompressionEncoding::Zstd => CompressionEncoding::Zstd, + }; + geyser_client = geyser_client.send_compressed(encoding).accept_compressed(encoding); + } + + let (mut subscribe_tx, subscribe_rx) = futures_mpsc::unbounded(); + + subscribe_tx + .send(request) + .await + .map_err(|e| Status::internal(format!("Failed to send initial request: {}", e)))?; + + let response = geyser_client + .subscribe_preprocessed(subscribe_rx) + .await + .map_err(|e| Status::internal(format!("Preprocessed subscription failed: {}", e)))?; + + Ok(response.into_inner()) +} + +/// Merges a write request into the current stored request so that subscription +/// changes made via `write()` persist across reconnections. +/// +/// Replaces all user subscription fields with the modification's values while +/// preserving the internal slot tracker subscription used for replay. +/// `from_slot` and `ping` are not replaced as they are connection-specific. +fn merge_subscribe_requests( + current: &mut SubscribeRequest, + modification: &SubscribeRequest, + internal_slot_sub_id: &str, +) { + // Save the internal slot tracker before replacing slots + let internal_tracker = current + .slots + .get(internal_slot_sub_id) + .cloned(); + + // Replace all subscription types (Yellowstone gRPC replaces, not merges) + current.accounts = modification.accounts.clone(); + current.slots = modification.slots.clone(); + current.transactions = modification.transactions.clone(); + current.transactions_status = modification.transactions_status.clone(); + current.blocks = modification.blocks.clone(); + current.blocks_meta = modification.blocks_meta.clone(); + current.entry = modification.entry.clone(); + current.accounts_data_slice = modification.accounts_data_slice.clone(); + + // Restore the internal slot tracker if it existed + if let Some(value) = internal_tracker { + current + .slots + .insert(internal_slot_sub_id.to_string(), value); + } + + // Update commitment if specified in the modification + if modification.commitment.is_some() { + current.commitment = modification.commitment; + } + + // Note: from_slot and ping are not replaced as they are connection-specific +} + diff --git a/vendor/helius-laserstream/src/config.rs b/vendor/helius-laserstream/src/config.rs new file mode 100644 index 00000000..11a815eb --- /dev/null +++ b/vendor/helius-laserstream/src/config.rs @@ -0,0 +1,123 @@ +use serde::{Deserialize, Serialize}; + +/// Compression encoding options +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum CompressionEncoding { + /// Gzip compression + Gzip, + /// Zstd compression + Zstd, +} + +#[derive(Debug, Clone)] +pub struct LaserstreamConfig { + /// API Key for authentication. + pub api_key: String, + /// The Laserstream endpoint URL. + pub endpoint: String, + /// Maximum number of reconnection attempts. Defaults to 10. + /// A hard cap of 240 attempts (20 minutes / 5 seconds) is enforced internally. + pub max_reconnect_attempts: Option, + /// gRPC channel options + pub channel_options: ChannelOptions, + /// When true, enable replay on reconnects (uses from_slot and internal slot tracking). + /// When false, no replay - start from current slot on reconnects. + /// Default: true + pub replay: bool, +} + +#[derive(Debug, Clone, Default)] +pub struct ChannelOptions { + /// Connect timeout in seconds. Default: 10 + pub connect_timeout_secs: Option, + /// Request timeout in seconds. Default: 30 + pub timeout_secs: Option, + /// Max message size for receiving in bytes. Default: 1GB + pub max_decoding_message_size: Option, + /// Max message size for sending in bytes. Default: 32MB + pub max_encoding_message_size: Option, + /// HTTP/2 keep-alive interval in seconds. Default: 30 + pub http2_keep_alive_interval_secs: Option, + /// Keep-alive timeout in seconds. Default: 5 + pub keep_alive_timeout_secs: Option, + /// Enable keep-alive while idle. Default: true + pub keep_alive_while_idle: Option, + /// Initial stream window size in bytes. Default: 4MB + pub initial_stream_window_size: Option, + /// Initial connection window size in bytes. Default: 8MB + pub initial_connection_window_size: Option, + /// Enable HTTP/2 adaptive window. Default: true + pub http2_adaptive_window: Option, + /// Enable TCP no-delay. Default: true + pub tcp_nodelay: Option, + /// TCP keep-alive interval in seconds. Default: 60 + pub tcp_keepalive_secs: Option, + /// Buffer size in bytes. Default: 64KB + pub buffer_size: Option, + /// Compression encodings to accept from server. Default: ["gzip", "zstd"] + pub accept_compression: Option>, + /// Compression encoding to use when sending. Default: None + pub send_compression: Option, +} + + +impl ChannelOptions { + /// Enable zstd compression for both sending and receiving + pub fn with_zstd_compression(mut self) -> Self { + self.send_compression = Some(CompressionEncoding::Zstd); + self.accept_compression = Some(vec![CompressionEncoding::Zstd, CompressionEncoding::Gzip]); + self + } + + /// Enable gzip compression for both sending and receiving + pub fn with_gzip_compression(mut self) -> Self { + self.send_compression = Some(CompressionEncoding::Gzip); + self.accept_compression = Some(vec![CompressionEncoding::Gzip, CompressionEncoding::Zstd]); + self + } +} + +impl Default for LaserstreamConfig { + fn default() -> Self { + Self { + api_key: String::new(), + endpoint: String::new(), + max_reconnect_attempts: None, // Default to None + channel_options: ChannelOptions::default(), + replay: true, // Default to true + } + } +} + +impl LaserstreamConfig { + pub fn new(endpoint: String, api_key: String) -> Self { + Self { + endpoint, + api_key, + max_reconnect_attempts: None, // Default to None + channel_options: ChannelOptions::default(), + replay: true, // Default to true + } + } + + /// Sets the maximum number of reconnection attempts. + pub fn with_max_reconnect_attempts(mut self, attempts: u32) -> Self { + self.max_reconnect_attempts = Some(attempts); + self + } + + /// Sets custom channel options. + pub fn with_channel_options(mut self, options: ChannelOptions) -> Self { + self.channel_options = options; + self + } + + /// Sets replay behavior on reconnects. + /// When true (default), uses from_slot and internal slot tracking for replay. + /// When false, starts from current slot on reconnects (no replay). + pub fn with_replay(mut self, replay: bool) -> Self { + self.replay = replay; + self + } +} diff --git a/vendor/helius-laserstream/src/error.rs b/vendor/helius-laserstream/src/error.rs new file mode 100644 index 00000000..e588d88c --- /dev/null +++ b/vendor/helius-laserstream/src/error.rs @@ -0,0 +1,41 @@ +use laserstream_core_proto::tonic::Status; +use url::ParseError; +use futures_channel::mpsc::SendError; +use laserstream_core_client::{GeyserGrpcClientError, GeyserGrpcBuilderError}; +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum LaserstreamError { + #[error("gRPC transport error: {0}")] + Transport(#[from] laserstream_core_proto::tonic::transport::Error), + + #[error("gRPC status error: {0}")] + Status(#[from] Status), + + #[error("Invalid endpoint URL: {0}")] + InvalidUrl(#[from] ParseError), + + #[error("Stream unexpectedly ended")] + StreamEnded, + + #[error("Subscription channel send error: {0}")] + SubscriptionSendError(#[from] SendError), + + #[error("Maximum reconnection attempts reached: {0}")] + MaxReconnectAttempts(Status), + + #[error("IO error: {0}")] + Io(#[from] std::io::Error), + + #[error("Laserstream client error: {0}")] + ClientError(#[from] GeyserGrpcClientError), + + #[error("Laserstream builder error: {0}")] + BuilderError(#[from] GeyserGrpcBuilderError), + + #[error("Invalid API Key format")] + InvalidApiKeyFormat, + + #[error("Connection error: {0}")] + ConnectionError(String), +} diff --git a/vendor/helius-laserstream/src/lib.rs b/vendor/helius-laserstream/src/lib.rs new file mode 100644 index 00000000..430f4aa5 --- /dev/null +++ b/vendor/helius-laserstream/src/lib.rs @@ -0,0 +1,11 @@ +pub mod client; +pub mod config; +pub mod error; + +pub use client::{subscribe, subscribe_preprocessed, StreamHandle, PreprocessedStreamHandle}; +pub use config::{ChannelOptions, LaserstreamConfig, CompressionEncoding}; +pub use error::LaserstreamError; + +// Re-export commonly used types from laserstream-core-proto +pub use laserstream_core_proto::geyser as grpc; +pub use laserstream_core_proto::solana; \ No newline at end of file