From 354156cb079fb16cc02004ec00547f33a037be05 Mon Sep 17 00:00:00 2001 From: andthattoo Date: Mon, 2 Mar 2026 04:17:45 +0300 Subject: [PATCH 01/57] Rewrite as single-crate v2 architecture with embedded inference Replace the old workspace (compute, executor, p2p, utils) with a single-crate binary featuring embedded GGUF inference via llama-cpp-2, QUIC networking via quinn, HuggingFace model management, and a backpressure-aware worker pool. --- .dockerignore | 12 - .env.example | 35 - ARCHITECTURE_V2.md | 773 ++++++ Cargo.lock | 3223 ++++++-------------------- Cargo.toml | 77 +- Cross.toml | 8 - Makefile | 25 - compose.yml | 64 - compute/Cargo.toml | 64 - compute/src/config.rs | 183 -- compute/src/lib.rs | 12 - compute/src/main.rs | 209 -- compute/src/node/core.rs | 167 -- compute/src/node/diagnostic.rs | 152 -- compute/src/node/mod.rs | 166 -- compute/src/node/reqres.rs | 211 -- compute/src/node/rpc.rs | 108 - compute/src/reqres/heartbeat.rs | 87 - compute/src/reqres/mod.rs | 80 - compute/src/reqres/specs.rs | 57 - compute/src/reqres/task.rs | 283 --- compute/src/utils/mod.rs | 5 - compute/src/utils/points.rs | 85 - compute/src/utils/specs.rs | 120 - compute/src/workers/mod.rs | 1 - compute/src/workers/task.rs | 312 --- docs/NODE_SPECS.md | 254 -- executor/Cargo.toml | 36 - executor/README.md | 20 - executor/examples/ollama.rs | 19 - executor/src/executors/gemini.rs | 178 -- executor/src/executors/mod.rs | 71 - executor/src/executors/ollama.rs | 253 -- executor/src/executors/openai.rs | 172 -- executor/src/executors/openrouter.rs | 98 - executor/src/lib.rs | 17 - executor/src/manager.rs | 143 -- executor/src/models.rs | 299 --- executor/src/task.rs | 168 -- p2p/Cargo.toml | 36 - p2p/README.md | 125 - p2p/src/behaviour.rs | 53 - p2p/src/client.rs | 358 --- p2p/src/commands.rs | 154 -- p2p/src/lib.rs | 14 - p2p/src/protocol.rs | 104 - p2p/tests/request_test.rs | 64 - src/config.rs | 213 ++ src/error.rs | 22 + src/identity.rs | 114 + src/inference/benchmark.rs | 71 + src/inference/engine.rs | 297 +++ src/inference/mod.rs | 7 + src/inference/proof.rs | 65 + src/inference/stream.rs | 39 + src/main.rs | 338 +++ src/models/cache.rs | 103 + src/models/download.rs | 46 + src/models/mod.rs | 8 + src/models/registry.rs | 132 ++ src/models/template.rs | 163 ++ src/network/auth.rs | 163 ++ src/network/connection.rs | 537 +++++ src/network/mod.rs | 6 + src/network/protocol.rs | 329 +++ src/worker.rs | 261 +++ utils/Cargo.toml | 40 - utils/README.md | 19 - utils/src/crypto.rs | 115 - utils/src/env.rs | 25 - utils/src/lib.rs | 30 - utils/src/message.rs | 215 -- utils/src/network.rs | 86 - utils/src/payloads/heartbeat.rs | 36 - utils/src/payloads/mod.rs | 11 - utils/src/payloads/specs.rs | 98 - utils/src/payloads/tasks.rs | 151 -- utils/src/version.rs | 95 - 78 files changed, 4374 insertions(+), 8616 deletions(-) delete mode 100644 .dockerignore delete mode 100644 .env.example create mode 100644 ARCHITECTURE_V2.md delete mode 100644 Cross.toml delete mode 100644 Makefile delete mode 100644 compose.yml delete mode 100644 compute/Cargo.toml delete mode 100644 compute/src/config.rs delete mode 100644 compute/src/lib.rs delete mode 100644 compute/src/main.rs delete mode 100644 compute/src/node/core.rs delete mode 100644 compute/src/node/diagnostic.rs delete mode 100644 compute/src/node/mod.rs delete mode 100644 compute/src/node/reqres.rs delete mode 100644 compute/src/node/rpc.rs delete mode 100644 compute/src/reqres/heartbeat.rs delete mode 100644 compute/src/reqres/mod.rs delete mode 100644 compute/src/reqres/specs.rs delete mode 100644 compute/src/reqres/task.rs delete mode 100644 compute/src/utils/mod.rs delete mode 100644 compute/src/utils/points.rs delete mode 100644 compute/src/utils/specs.rs delete mode 100644 compute/src/workers/mod.rs delete mode 100644 compute/src/workers/task.rs delete mode 100644 docs/NODE_SPECS.md delete mode 100644 executor/Cargo.toml delete mode 100644 executor/README.md delete mode 100644 executor/examples/ollama.rs delete mode 100644 executor/src/executors/gemini.rs delete mode 100644 executor/src/executors/mod.rs delete mode 100644 executor/src/executors/ollama.rs delete mode 100644 executor/src/executors/openai.rs delete mode 100644 executor/src/executors/openrouter.rs delete mode 100644 executor/src/lib.rs delete mode 100644 executor/src/manager.rs delete mode 100644 executor/src/models.rs delete mode 100644 executor/src/task.rs delete mode 100644 p2p/Cargo.toml delete mode 100644 p2p/README.md delete mode 100644 p2p/src/behaviour.rs delete mode 100644 p2p/src/client.rs delete mode 100644 p2p/src/commands.rs delete mode 100644 p2p/src/lib.rs delete mode 100644 p2p/src/protocol.rs delete mode 100644 p2p/tests/request_test.rs create mode 100644 src/config.rs create mode 100644 src/error.rs create mode 100644 src/identity.rs create mode 100644 src/inference/benchmark.rs create mode 100644 src/inference/engine.rs create mode 100644 src/inference/mod.rs create mode 100644 src/inference/proof.rs create mode 100644 src/inference/stream.rs create mode 100644 src/main.rs create mode 100644 src/models/cache.rs create mode 100644 src/models/download.rs create mode 100644 src/models/mod.rs create mode 100644 src/models/registry.rs create mode 100644 src/models/template.rs create mode 100644 src/network/auth.rs create mode 100644 src/network/connection.rs create mode 100644 src/network/mod.rs create mode 100644 src/network/protocol.rs create mode 100644 src/worker.rs delete mode 100644 utils/Cargo.toml delete mode 100644 utils/README.md delete mode 100644 utils/src/crypto.rs delete mode 100644 utils/src/env.rs delete mode 100644 utils/src/lib.rs delete mode 100644 utils/src/message.rs delete mode 100644 utils/src/network.rs delete mode 100644 utils/src/payloads/heartbeat.rs delete mode 100644 utils/src/payloads/mod.rs delete mode 100644 utils/src/payloads/specs.rs delete mode 100644 utils/src/payloads/tasks.rs delete mode 100644 utils/src/version.rs diff --git a/.dockerignore b/.dockerignore deleted file mode 100644 index 0f09f7a3..00000000 --- a/.dockerignore +++ /dev/null @@ -1,12 +0,0 @@ -target -Dockerfile -manifests -.dockerignore -.git -.gitignore -.github -.DS_Store -README.md -misc -Justfile -Makefile diff --git a/.env.example b/.env.example deleted file mode 100644 index cd053459..00000000 --- a/.env.example +++ /dev/null @@ -1,35 +0,0 @@ -## DRIA (required) ## -# Secret key of your compute node, 32 byte in hexadecimal. -# e.g.: DKN_WALLET_SECRET_KEY=0xabc...123 -DKN_WALLET_SECRET_KEY= -# model1,model2,model3,... (comma separated, case-insensitive) -# example: gemini-2.0-flash,gpt-4o-mini -DKN_MODELS= - -## DRIA (optional) ## -# P2P address, you don't need to change this unless this port is already in use. -DKN_P2P_LISTEN_ADDR=/ip4/0.0.0.0/tcp/4001 -# Batch size for task worker, you do not need to edit this. -DKN_BATCH_SIZE= -# Initial RPC address for testing purposes -# DKN_INITIAL_RPC_ADDR= - -## DRIA (profiling only, do not uncomment) ## -# Set to a number of seconds to wait before exiting, only use in profiling build! -# Otherwise, leave this empty. -# DKN_EXIT_TIMEOUT= - -## Open AI (if used, required) ## -OPENAI_API_KEY= -## Gemini (if used, required) ## -GEMINI_API_KEY= -## Open Router (if used, required) ## -OPENROUTER_API_KEY= - -## Ollama (if used, optional) ## -OLLAMA_HOST=http://localhost -# you can change the port if you would like -OLLAMA_PORT=11434 -# if "true", automatically pull models from Ollama -# if "false", you have to download manually -OLLAMA_AUTO_PULL=true diff --git a/ARCHITECTURE_V2.md b/ARCHITECTURE_V2.md new file mode 100644 index 00000000..39e9a261 --- /dev/null +++ b/ARCHITECTURE_V2.md @@ -0,0 +1,773 @@ +# DKN v2 Architecture Plan + +> A ground-up redesign of the Dria Knowledge Network for low-latency agentic inference at scale. + +## Goals + +1. **Single binary compute node** — no Ollama, no launcher, no `.env` juggling. Download, run, earn. +2. **Sub-second task routing** — from 14-hop batch pipeline to 4-hop direct routing. +3. **Cloud-agnostic** — no AWS vendor lock. Runs on any infrastructure. +4. **Agentic-first** — real-time inference, streaming tokens, multi-model sessions, sub-agent fan-out. +5. **Provable inference** — validation embedded in the architecture, not bolted on. +6. **Scale to millions of nodes** — stateless router fleet, horizontal scaling at every layer. + +## Current State (v1) — What We're Replacing + +### Problems + +| Problem | Impact | +|---|---| +| 14-hop task pipeline (Client → API → S3 → PG → EventBridge → Validator → PG → Dispatcher → RabbitMQ → RPC → Dispatcher API → libp2p → Node → Ollama) | Minutes of latency per task, unusable for agents | +| Ollama as separate installation | Friction for 292K operators, HTTP overhead for local inference, no access to model internals | +| Hardcoded model enum in Rust | Every new model requires binary release across entire fleet | +| AWS-locked orchestration (ECS, EventBridge, S3, SQS) | Cannot deploy outside AWS | +| 10+ services (Harbor, Dispatcher, RPC, Challenger, NDX, Cortex, etc.) | Operational complexity, many failure points | +| Batch-only paradigm | Cannot serve agentic workloads that need real-time responses | +| RPC gateway bottleneck (star topology, single connection per node) | Single point of failure per node cluster | +| No backpressure — nodes can't reject tasks | Overloaded nodes queue indefinitely | +| Challenger uses gameable deterministic puzzles | Bad actors can pass challenges without running models | + +### Current Service Count: 10+ + +- Harbor API, Validator, Uploader, Dashboard, Cortex, NDX, Models, Points, Status (TypeScript) +- Dispatcher (Rust) +- RPC Gateway (Rust) +- Challenger API (Python) +- Compute Node + Launcher (Rust) +- Ollama (Go) +- RabbitMQ, PostgreSQL, MongoDB, Redis, S3, SQS, EventBridge + +## v2 Architecture + +### Components: 3 + +1. **Dria Node** — single Rust binary with embedded inference (community-operated) +2. **Dria Router** — stateless routing + validation fleet (Dria-operated, horizontally scalable) +3. **Shared State** — NATS (messaging) + Redis (node registry) + PostgreSQL (persistent data) + +### System Diagram + +``` +┌──────────────────────────────────────────────────────────────┐ +│ Client Layer │ +│ │ +│ ┌──────────┐ ┌───────────┐ ┌────────────────────────┐ │ +│ │ Agent SDK│ │ Batch API │ │ Sub-agent Orchestrator │ │ +│ └────┬─────┘ └─────┬─────┘ └───────────┬────────────┘ │ +└───────┼───────────────┼────────────────────┼─────────────────┘ + │ │ │ + ▼ ▼ ▼ + (HTTPS / WebSocket / gRPC — pick per use case) + │ │ │ +┌───────┴───────────────┴───────────────────┴──────────────────┐ +│ Load Balancer │ +│ (nginx / envoy / any cloud LB) │ +└───────┬───────────────┬───────────────────┬──────────────────┘ + │ │ │ + ▼ ▼ ▼ +┌──────────────┐ ┌──────────────┐ ┌──────────────┐ +│ Router A │ │ Router B │ │ Router N │ +│ (stateless) │ │ (stateless) │ │ (stateless) │ +│ │ │ │ │ │ +│ - Routing │ │ - Routing │ │ - Routing │ +│ - Validate │ │ - Validate │ │ - Validate │ +│ - Stream │ │ - Stream │ │ - Stream │ +└──────┬───────┘ └──────┬───────┘ └──────┬───────┘ + │ │ │ + │ ┌───────────┴──────────┐ │ + │ ▼ ▼ │ + │ ┌──────┐ ┌────────────────┐ │ + ├─│ NATS │ │ Node Registry │─────┤ + │ │ │ │ (Redis/etcd) │ │ + │ └──────┘ └────────────────┘ │ + │ │ + ▼ (QUIC persistent conns) ▼ +┌─────────────────────────────────────────────────────────────┐ +│ Compute Nodes (292K+) │ +│ │ +│ ┌───────────────────────────────────────────────────────┐ │ +│ │ dria-node (single Rust binary) │ │ +│ │ │ │ +│ │ ┌─────────────┐ ┌──────────────┐ ┌─────────────┐ │ │ +│ │ │ Inference │ │ Network │ │ Identity │ │ │ +│ │ │ (llama.cpp) │ │ (QUIC) │ │ (secp256k1)│ │ │ +│ │ └─────────────┘ └──────────────┘ └─────────────┘ │ │ +│ └───────────────────────────────────────────────────────┘ │ +└─────────────────────────────────────────────────────────────┘ +``` + +--- + +## Component 1: Dria Node + +The compute node is a single statically-linked Rust binary that community operators download and run. It embeds the inference engine, manages models, connects to routers, and proves its work. + +### Install & Run + +```bash +# Install +curl -fsSL https://dria.co/install | bash + +# Run (interactive first-time setup) +dria-node start + +# Or fully non-interactive +dria-node start --wallet 0x... --model gemma3:12b + +# Multi-model +dria-node start --wallet 0x... --model gemma3:4b,llama3.1:8b +``` + +No Ollama installation. No `.env` files. No launcher binary. The node downloads GGUF weights from HuggingFace on first run, validates hardware capability, connects to a router, and starts accepting work. + +### Internal Architecture + +``` +dria-node binary (~3,000 lines Rust, single crate) +│ +├── main.rs # CLI, startup, signal handling +├── config.rs # Config from args/env, proper error handling (no panics) +│ +├── inference/ +│ ├── engine.rs # llama.cpp bindings, load model, run inference +│ ├── models.rs # GGUF download from HuggingFace, file management +│ ├── stream.rs # Token-by-token streaming callback +│ └── proof.rs # Logprob extraction, KV-cache fingerprinting +│ +├── network/ +│ ├── connection.rs # QUIC connection to router, auto-reconnect +│ ├── protocol.rs # Message types, serialization (flat, no base64) +│ └── auth.rs # secp256k1 challenge-response handshake +│ +├── worker.rs # Task execution loop, backpressure, capacity reporting +└── identity.rs # Wallet, keypair, address derivation +``` + +### Key Design Decisions + +#### Models Are Strings, Not Enums + +```rust +// v1 (current) — adding a model requires a release +enum Model { + #[serde(rename = "gemma3:4b")] + Gemma3_4b, + // ... every model is a variant +} + +// v2 — models are just identifiers +struct ModelSpec { + name: String, // "gemma3:12b" + gguf_repo: String, // "bartowski/gemma-3-12b-it-GGUF" + gguf_file: String, // "gemma-3-12b-it-Q4_K_M.gguf" + chat_template: String, // jinja2 template name or inline +} +``` + +The router maintains the model registry and pushes specs to nodes. Adding a new model is a config change on the router side — zero node updates needed. + +#### Embedded Inference via llama.cpp + +```rust +// Rust bindings to llama.cpp (via llama-cpp-2 or llama-cpp-rs crate) + +pub struct InferenceEngine { + model: LlamaModel, + ctx: LlamaContext, +} + +impl InferenceEngine { + /// Load a GGUF model from disk + pub fn load(model_path: &Path, gpu_layers: u32) -> Result; + + /// Run inference, streaming tokens via callback + pub async fn generate( + &mut self, + prompt: &str, + params: GenerateParams, + on_token: impl FnMut(Token) -> ControlFlow, // stream tokens out + ) -> Result; + + /// Extract logprobs at specific positions (for validation) + pub fn logprobs_at(&self, positions: &[usize]) -> Vec; + + /// Hash KV-cache state at (layer, position) for proof-of-inference + pub fn kv_cache_hash(&self, layer: usize, position: usize) -> [u8; 32]; + + /// Benchmark: tokens per second on this hardware + pub fn benchmark(&mut self, prompt: &str) -> TpsResult; +} +``` + +Hardware backend selection at build time (or runtime via feature flags): +- `--features cuda` — NVIDIA GPUs +- `--features metal` — Apple Silicon +- `--features rocm` — AMD GPUs +- `--features vulkan` — Cross-platform GPU +- Default: CPU (OpenBLAS/Accelerate) + +Pre-built binaries for common combinations: `dria-node-linux-cuda`, `dria-node-macos-metal`, `dria-node-linux-cpu`. + +#### Task Execution with Backpressure + +```rust +pub struct Worker { + engine: InferenceEngine, + capacity: AtomicUsize, // how many slots are free + max_concurrent: usize, // 1 for local GPU, N for multi-GPU +} + +impl Worker { + /// Returns None if at capacity (router will route elsewhere) + pub fn try_accept(&self, task: Task) -> Option { + if self.capacity.load(Ordering::Relaxed) == 0 { + return None; // REJECT — tell router to re-route + } + self.capacity.fetch_sub(1, Ordering::Relaxed); + Some(self.spawn_task(task)) + } +} +``` + +The current system queues tasks into a 1024-size channel and hopes for the best. The new system exposes real capacity — the router never sends work to a node that can't handle it. + +#### Single Tokio Runtime, Two Tasks + +```rust +#[tokio::main] +async fn main() -> Result<()> { + let config = Config::from_args_and_env()?; // no panics + let engine = InferenceEngine::load(&config.model_path, config.gpu_layers)?; + let identity = Identity::from_secret_key(&config.secret_key)?; + + let (conn, mut events) = Connection::connect(&config.router_url, &identity).await?; + let worker = Worker::new(engine, config.max_concurrent); + + // Single select loop — no commander pattern, no inter-thread channels + let cancellation = CancellationToken::new(); + loop { + tokio::select! { + event = events.recv() => match event { + Event::TaskRequest(task) => { + match worker.try_accept(task) { + Some(handle) => { /* task running, result streams back via conn */ } + None => conn.reject(task.id).await?, // backpressure + } + } + Event::ValidationChallenge(challenge) => { + let proof = worker.generate_proof(&challenge)?; + conn.submit_proof(proof).await?; + } + Event::Ping => conn.pong(worker.status()).await?, + Event::Disconnected => conn.reconnect().await?, + }, + result = worker.next_completed() => { + conn.send_result(result).await?; + } + _ = cancellation.cancelled() => break, + } + } +} +``` + +No commander pattern. No mpsc+oneshot roundtrips. No separate P2P thread. The connection and worker live in the same task, communicating directly. + +--- + +## Component 2: Dria Router + +Stateless Rust service operated by Dria. Handles task routing, node management, validation, and client-facing APIs. Horizontally scalable — add more instances behind a load balancer. + +### Responsibilities + +| Function | Description | +|---|---| +| **Client API** | Accept inference requests (real-time, streaming, batch) via HTTPS/WebSocket/gRPC | +| **Node Management** | Accept QUIC connections from compute nodes, track health and capacity | +| **Task Routing** | Match tasks to capable nodes based on model, capacity, latency, reputation | +| **Result Delivery** | Stream results back to clients, aggregate batch results | +| **Validation** | Issue proof-of-inference challenges, verify logprobs, detect anomalies | +| **Billing/Points** | Emit events to NATS for points calculation and billing | + +### Internal Architecture + +``` +dria-router binary +│ +├── main.rs +├── config.rs +│ +├── api/ +│ ├── rest.rs # Batch API (POST /v2/infer, POST /v2/batch) +│ ├── websocket.rs # Streaming API for agents +│ └── grpc.rs # Optional gRPC for high-performance clients +│ +├── nodes/ +│ ├── registry.rs # Read/write node state to Redis +│ ├── connection.rs # QUIC listener, per-node connection management +│ ├── selector.rs # Routing algorithm: model match → capacity → latency → reputation +│ └── health.rs # Heartbeat monitoring, disconnect detection +│ +├── routing/ +│ ├── realtime.rs # Single task → single node, stream result back +│ ├── batch.rs # Fan-out tasks across nodes, aggregate results +│ └── session.rs # Session affinity for multi-turn (optional KV-cache reuse) +│ +├── validation/ +│ ├── logprob.rs # Request and verify logprobs from nodes +│ ├── timing.rs # TPS anomaly detection +│ ├── kv_cache.rs # KV-cache fingerprint challenges +│ └── reputation.rs # Node reputation scoring, challenge frequency +│ +└── events.rs # Publish to NATS: points, billing, audit logs +``` + +### Scaling Model + +Each router instance handles ~10,000 concurrent node connections (QUIC is lightweight per-connection). Scaling is linear: + +| Nodes | Routers Needed | Infra | +|---|---|---| +| 50K | 5 | Small K8s cluster or 5 VMs | +| 292K | 30 | Medium cluster | +| 1M | 100 | Large cluster, any cloud or bare metal | + +Routers are stateless — they can crash and restart without data loss. Nodes reconnect to any available router. All durable state lives in Redis (node registry) and NATS (event stream). + +### Task Routing Algorithm + +``` +fn select_node(task: &Task, registry: &NodeRegistry) -> Option { + registry + .nodes_with_model(&task.model) // 1. Must have the model + .filter(|n| n.free_capacity > 0) // 2. Must have free slots + .filter(|n| n.reputation > THRESHOLD) // 3. Must not be blacklisted + .sort_by(|a, b| { + // 4. Prefer: lowest latency, then highest TPS, then lowest load + a.avg_latency.cmp(&b.avg_latency) + .then(b.tps.cmp(&a.tps)) + .then(a.load_percent.cmp(&b.load_percent)) + }) + .next() +} +``` + +If no node is available, the router returns a `503 Service Unavailable` with a retry-after hint. The client SDK handles retry with backoff. + +### Client-Facing API + +#### Real-time Inference (for agents) + +``` +POST /v2/infer +Content-Type: application/json +Authorization: Bearer + +{ + "model": "gemma3:12b", + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is the capital of France?"} + ], + "max_tokens": 256, + "stream": true // optional: stream tokens via SSE +} + +→ 200 (streaming): +data: {"token": "The", "index": 0} +data: {"token": " capital", "index": 1} +... +data: {"done": true, "usage": {"prompt_tokens": 24, "completion_tokens": 12}} + +→ 200 (non-streaming): +{ + "result": "The capital of France is Paris.", + "model": "gemma3:12b", + "usage": {"prompt_tokens": 24, "completion_tokens": 12}, + "node": "0xabc..." // optional: which node served this +} +``` + +#### Batch Inference (for data processing) + +``` +POST /v2/batch +Content-Type: application/json +Authorization: Bearer + +{ + "tasks": [ + {"id": "task-001", "model": "gemma3:12b", "messages": [...]}, + {"id": "task-002", "model": "gemma3:12b", "messages": [...]}, + ... + ], + "webhook": "https://your-app.com/callback" // optional +} + +→ 202 Accepted: +{ + "batch_id": "batch-uuid", + "status_url": "/v2/batch/batch-uuid" +} +``` + +Results stream to the webhook as they complete, or poll the status URL. No S3 upload/download cycle. For very large batches (100K+ tasks), the client can upload a JSONL file to any S3-compatible store and POST the URL. + +#### Sub-agent Fan-out (for orchestrators) + +``` +POST /v2/fan-out +Content-Type: application/json +Authorization: Bearer + +{ + "tasks": [ + {"id": "classify", "model": "gemma3:4b", "messages": [...]}, + {"id": "reason-1", "model": "llama3.3:70b", "messages": [...]}, + {"id": "reason-2", "model": "llama3.3:70b", "messages": [...]}, + {"id": "summarize", "model": "gemma3:4b", "messages": [...], + "depends_on": ["reason-1", "reason-2"]} + ] +} +``` + +The router executes independent tasks in parallel across different nodes and respects dependency ordering. `summarize` only runs after `reason-1` and `reason-2` complete, and their outputs are injected into its context. + +--- + +## Component 3: Shared State + +### NATS (replaces EventBridge + RabbitMQ + SQS) + +Single NATS cluster handles: +- **Task events**: routing notifications, completion events +- **Points/billing**: events emitted per completed task +- **Audit logs**: validation results, anomaly alerts +- **Inter-router communication**: if a node reconnects to a different router mid-task + +NATS JetStream provides persistent streams where needed (billing events must not be lost). Regular NATS pub/sub for ephemeral events. + +### Redis (node registry) + +``` +# Per-node state (SET by node via router, READ by any router) +node:{address}:models = ["gemma3:12b", "llama3.1:8b"] +node:{address}:capacity = { free: 1, max: 1 } +node:{address}:tps = { "gemma3:12b": 45.2 } +node:{address}:router = "router-a" +node:{address}:last_seen = 1709312400 +node:{address}:reputation = 0.95 + +# Model index (which nodes serve which model) +model:gemma3:12b:nodes = SET of node addresses +model:llama3.1:8b:nodes = SET of node addresses + +# Blacklist +blacklist:{address}:{model} = TTL-based key +``` + +Redis is fast enough for the read-heavy routing queries (~100K reads/sec per instance). For >1M nodes, shard by model name. + +### PostgreSQL (persistent data) + +- Batch job records, file metadata +- User accounts, API keys +- Historical task results (for large batch jobs) +- Billing records + +PostgreSQL is already cloud-agnostic. No changes needed from v1 except simpler schema (no file status state machine with 8 states). + +--- + +## Validation: Proof-of-Inference + +### Why Current Challenger Is Insufficient + +The current system sends 5 types of deterministic puzzles (addition, leg counting, letter sums, algebra, word repeat). Problems: + +1. All questions can be answered without an LLM (a calculator + regex handles 100%) +2. Only 5 question types — easy to build a specialized solver +3. Separate Python service, separate MongoDB — operational overhead +4. Challenges are infrequent and predictable + +### v2 Validation: Three Layers + +Validation is built into the router, not a separate service: + +#### Layer 1: Timing Analysis (every task, zero cost) + +Every inference result includes timing metadata from the embedded engine: +- `prompt_eval_time_ms`: how long to process the input +- `generation_time_ms`: how long to generate the output +- `tokens_per_second`: eval TPS for this specific request + +The router maintains a statistical model of expected TPS per model per hardware class. Outliers are flagged: +- Too fast → likely not running the model (cached/faked responses) +- Too slow → possible CPU fallback when claiming GPU, or overloaded hardware +- Inconsistent → TPS varies wildly between similar-length prompts + +Timing comes free from the embedded llama.cpp engine — no extra work on the node side. The router just needs to track distributions and flag statistical outliers. + +#### Layer 2: Logprob Spot-Checks (random % of tasks, low cost) + +When the router assigns a task, it can request logprobs at specific token positions: + +``` +Router → Node: { + task: { ... }, + validation: { + request_logprobs_at: [5, 12, 31] // token positions + } +} + +Node → Router: { + result: "The capital of France is Paris...", + proof: { + logprobs: [ + { position: 5, token: "capital", logprob: -0.23, top_5: [...] }, + { position: 12, token: "Paris", logprob: -0.08, top_5: [...] }, + { position: 31, token: ".", logprob: -1.42, top_5: [...] } + ] + } +} +``` + +The router validates by: +1. Checking logprob distributions are plausible for the model +2. Periodically cross-referencing with a trusted validator node running the same model +3. Building a per-node profile — consistent logprob patterns indicate legitimate inference + +Faking logprobs requires actually running the model. They cannot be derived from the text output alone. + +#### Layer 3: KV-Cache Fingerprinting (periodic challenges, highest strength) + +The strongest proof: request the SHA-256 hash of the KV-cache tensor at a specific (layer, position) during inference: + +``` +Router → Node: { + challenge: { + prompt: "The quick brown fox...", + request_kv_hash: { layer: 8, position: 15 } + } +} + +Node → Router: { + kv_hash: "a3f8b2c1d4e5..." +} +``` + +This hash is deterministic for a given model + input + position. Only a node with the model loaded and actively processing the input can produce it. The router verifies against a trusted reference. This is computationally impossible to fake without running the actual model weights. + +KV-cache proofs are the most expensive to verify (router needs a reference node to compare against) so they're issued periodically — more frequently for new/low-reputation nodes, less for established ones. + +#### Reputation Score + +Each node maintains a reputation score (0.0 to 1.0) based on: +- Validation pass rate +- Task completion rate +- Response time consistency +- Historical behavior + +New nodes start at 0.5 and must build reputation through successful validated tasks. Reputation decays slowly over time (must stay active). Nodes below 0.3 are blacklisted. + +``` +reputation_update(node, event): + match event: + TaskCompleted → +0.001 + ValidationPassed → +0.005 + ValidationFailed → -0.1 // harsh penalty + TimingAnomaly → -0.05 + Timeout → -0.02 + Rejection → no change (backpressure is fine) +``` + +High-reputation nodes get: +- Less frequent validation (cost savings for the network) +- Priority in task routing (rewarding good behavior) +- Higher point earnings multiplier + +--- + +## Network Protocol + +### Node ↔ Router: QUIC + +QUIC provides: +- **Built-in encryption** (TLS 1.3) — no need for libp2p's Noise layer +- **Multiplexed streams** — no need for Yamux +- **NAT-friendly** (UDP-based) — works behind most consumer routers +- **0-RTT reconnect** — near-instant reconnection after brief disconnects +- **Connection migration** — survives IP address changes (mobile, DHCP renewal) + +Rust implementation: `quinn` crate (mature, production-ready). + +### Authentication Handshake + +``` +1. Node opens QUIC connection to Router +2. Router sends: { challenge: random_32_bytes } +3. Node signs challenge with secp256k1 private key +4. Node sends: { + address: "0xabc...", + peer_id: "16Uiu2HAm...", + signature: "0x...", + recovery_id: 0, + models: ["gemma3:12b"], + tps: { "gemma3:12b": 45.2 }, + version: "2.0.0", + capacity: { free: 1, max: 1 } + } +5. Router recovers public key from signature, verifies address +6. Router sends: { authenticated: true, node_id: "..." } +``` + +No persistent identity storage needed. The node proves identity on every connection using its wallet key. + +### Message Format + +```rust +// Flat, simple, no base64 wrapping, no nested JSON-in-JSON +#[derive(Serialize, Deserialize)] +enum NodeMessage { + // Node → Router + TaskResult { + task_id: Uuid, + result: String, + proof: Option, + stats: TaskStats, + }, + TaskRejected { + task_id: Uuid, + reason: RejectReason, // AtCapacity, ModelUnloaded, etc. + }, + StatusUpdate { + capacity: Capacity, + models_loaded: Vec, + }, + ChallengeResponse { + challenge_id: Uuid, + proof: InferenceProof, + }, +} + +#[derive(Serialize, Deserialize)] +enum RouterMessage { + // Router → Node + TaskAssignment { + task_id: Uuid, + model: String, + messages: Vec, + max_tokens: u32, + validation: Option, + }, + Challenge { + challenge_id: Uuid, + prompt: String, + proof_request: ProofRequest, + }, + Ping, + ModelRegistryUpdate { + models: Vec, + }, +} +``` + +Serialized as MessagePack (binary, ~30% smaller than JSON, faster to parse) over QUIC streams. No base64 encoding. No JSON-in-JSON. No triple serialization. + +--- + +## Migration Path + +### Phase 1: New Node Binary (weeks 1-4) + +Build the new `dria-node` with: +- [ ] Embedded llama.cpp via `llama-cpp-2` crate +- [ ] GGUF model download from HuggingFace +- [ ] Hardware detection and TPS benchmarking +- [ ] QUIC connection to router (using `quinn`) +- [ ] secp256k1 authentication handshake +- [ ] Task execution with streaming +- [ ] Backpressure (reject when at capacity) +- [ ] Logprob extraction for validation +- [ ] Single-binary builds for Linux (CUDA, CPU), macOS (Metal) + +**Backwards compatibility**: The new node can initially speak the v1 libp2p protocol to connect to existing RPC gateways. This allows incremental rollout — operators upgrade their node binary while the backend remains unchanged. + +### Phase 2: Router MVP (weeks 3-6) + +Build the first Dria Router with: +- [ ] QUIC listener for node connections +- [ ] Node registry in Redis +- [ ] Real-time inference API (POST /v2/infer) +- [ ] Task routing algorithm (model match → capacity → latency) +- [ ] Result streaming via SSE/WebSocket +- [ ] Basic validation (timing analysis + logprob spot-checks) +- [ ] NATS integration for events + +Run alongside v1 infrastructure. Clients can use either the v1 batch API or the v2 real-time API. + +### Phase 3: Batch & Fan-out (weeks 5-8) + +- [ ] Batch API (POST /v2/batch) +- [ ] Sub-agent fan-out with dependency DAGs +- [ ] Webhook result delivery +- [ ] Large batch support (JSONL upload to S3-compatible store) +- [ ] Cross-verification validation +- [ ] Reputation system + +### Phase 4: v1 Deprecation (weeks 8-12) + +- [ ] Migrate all batch API clients to v2 +- [ ] Shut down Harbor services one by one +- [ ] Remove Dispatcher, RPC Gateway, RabbitMQ +- [ ] Remove Challenger API (validation is in-router now) +- [ ] Remove Ollama dependency from node documentation + +### Phase 5: Advanced Features (ongoing) + +- [ ] KV-cache proof-of-inference +- [ ] Session affinity (multi-turn KV-cache reuse) +- [ ] Model hot-swap (switch models without restart) +- [ ] Multi-GPU inference (tensor parallelism via llama.cpp) +- [ ] gRPC API for high-performance clients +- [ ] Geographic routing (prefer nodes close to the client) +- [ ] Node-to-node communication for collaborative inference + +--- + +## Comparison: v1 vs v2 + +| Dimension | v1 (Current) | v2 (Proposed) | +|---|---|---| +| **Task latency** | Minutes (batch pipeline) | Seconds (direct routing) | +| **Streaming** | No | Yes (token-by-token) | +| **Node setup** | Install Ollama + Launcher + configure .env | Single binary, one command | +| **Adding a model** | Code change + binary release | Config update on router | +| **Services to operate** | 10+ | 3 (Router, NATS, Redis) | +| **Cloud dependency** | AWS (ECS, EventBridge, S3, SQS) | Any cloud or bare metal | +| **Validation** | Gameable math puzzles | Logprobs, timing, KV-cache proofs | +| **Backpressure** | None (queue and hope) | Nodes reject, router re-routes | +| **Agentic support** | None (batch only) | Real-time, streaming, fan-out, DAGs | +| **P2P protocol** | libp2p (TCP+Noise+Yamux+CBOR) | QUIC (built-in encryption+mux) | +| **Inference engine** | Ollama (separate process, HTTP) | Embedded llama.cpp (in-process) | +| **Node count scaling** | ~30 RPC gateways | Stateless router fleet, linear scaling | +| **Code size (node)** | ~5,300 lines, 4 crates | ~3,000 lines, 1 crate | +| **Message format** | JSON → base64 → sign → JSON → CBOR | MessagePack, signed, direct | + +--- + +## Open Questions + +1. **QUIC vs WebSocket**: QUIC is technically superior but WebSocket has broader NAT traversal success. Could offer both — QUIC primary, WebSocket fallback for restrictive networks. + +2. **llama-cpp-2 vs candle**: llama.cpp has the best hardware coverage and GGUF ecosystem. Candle is pure Rust (simpler builds) but narrower model support. Recommend llama.cpp for now. + +3. **Model registry governance**: Who decides which models are available on the network? Currently hardcoded in Rust. Should be a curated list managed by Dria, pushed to nodes via router. + +4. **Multi-GPU nodes**: Some operators have 8x H100 setups. The new node should support tensor parallelism via llama.cpp's built-in support. How does this affect task routing? + +5. **Pricing model**: v1 is per-token. v2 could be per-token (same), per-request (simpler), or per-second-of-compute (fairer for different model sizes). + +6. **SDK design**: The agent SDK (Python/TypeScript) should abstract away routing, streaming, retries. What does the ideal developer experience look like? + +7. **Testnet**: Replace mockollama with a `--mock` flag built into the node binary. The node generates deterministic responses without loading a real model. Simpler than a separate mock server. diff --git a/Cargo.lock b/Cargo.lock index 34d76b28..00578aa5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,41 +17,6 @@ version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" -[[package]] -name = "aead" -version = "0.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d122413f284cf2d62fb1b7db97e02edb8cda96d769b16e443a4f6195e35662b0" -dependencies = [ - "crypto-common", - "generic-array", -] - -[[package]] -name = "aes" -version = "0.8.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0" -dependencies = [ - "cfg-if", - "cipher", - "cpufeatures", -] - -[[package]] -name = "aes-gcm" -version = "0.10.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "831010a0f742e1209b3bcea8fab6a8e149051ba6099432c8cb2cc117dec3ead1" -dependencies = [ - "aead", - "aes", - "cipher", - "ctr", - "ghash", - "subtle", -] - [[package]] name = "aho-corasick" version = "1.1.3" @@ -61,27 +26,6 @@ dependencies = [ "memchr", ] -[[package]] -name = "allocator-api2" -version = "0.2.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" - -[[package]] -name = "android-tzdata" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" - -[[package]] -name = "android_system_properties" -version = "0.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" -dependencies = [ - "libc", -] - [[package]] name = "anstream" version = "0.6.18" @@ -144,155 +88,12 @@ version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "76a2e8124351fda1ef8aaaa3bbd7ebbcb486bbcd4225aca0aa0d84bb2db8fecb" -[[package]] -name = "asn1-rs" -version = "0.6.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5493c3bedbacf7fd7382c6346bbd66687d12bbaad3a89a2d2c303ee6cf20b048" -dependencies = [ - "asn1-rs-derive", - "asn1-rs-impl", - "displaydoc", - "nom", - "num-traits", - "rusticata-macros", - "thiserror 1.0.69", - "time", -] - -[[package]] -name = "asn1-rs-derive" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "965c2d33e53cb6b267e148a4cb0760bc01f4904c1cd4bb4002a085bb016d1490" -dependencies = [ - "proc-macro2", - "quote", - "syn", - "synstructure", -] - -[[package]] -name = "asn1-rs-impl" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b18050c2cd6fe86c3a76584ef5e0baf286d038cda203eb6223df2cc413565f7" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "asn1_der" -version = "0.7.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "155a5a185e42c6b77ac7b88a15143d930a9e9727a5b7b77eed417404ab15c247" - -[[package]] -name = "async-io" -version = "2.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43a2b323ccce0a1d90b449fd71f2a06ca7faa7c54c2751f06c9bd851fc061059" -dependencies = [ - "async-lock", - "cfg-if", - "concurrent-queue", - "futures-io", - "futures-lite", - "parking", - "polling", - "rustix 0.38.44", - "slab", - "tracing", - "windows-sys 0.59.0", -] - -[[package]] -name = "async-lock" -version = "3.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff6e472cdea888a4bd64f342f09b3f50e1886d32afe8df3d663c01140b811b18" -dependencies = [ - "event-listener", - "event-listener-strategy", - "pin-project-lite", -] - -[[package]] -name = "async-recursion" -version = "1.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b43422f69d8ff38f95f1b2bb76517c91589a924d1559a0e935d7c8ce0274c11" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "async-stream" -version = "0.3.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b5a71a6f37880a80d1d7f19efd781e4b5de42c88f0722cc13bcb6cc2cfe8476" -dependencies = [ - "async-stream-impl", - "futures-core", - "pin-project-lite", -] - -[[package]] -name = "async-stream-impl" -version = "0.3.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "async-trait" -version = "0.1.88" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e539d3fca749fcee5236ab05e93a52867dd549cc157c8cb7f99595f3cedffdb5" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "asynchronous-codec" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a860072022177f903e59730004fb5dc13db9275b79bb2aef7ba8ce831956c233" -dependencies = [ - "bytes", - "futures-sink", - "futures-util", - "memchr", - "pin-project-lite", -] - [[package]] name = "atomic-waker" version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" -[[package]] -name = "attohttpc" -version = "0.24.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d9a9bf8b79a749ee0b911b91b671cc2b6c670bdbc7e3dfd537576ddc94bb2a2" -dependencies = [ - "http 0.2.12", - "log", - "url", -] - [[package]] name = "autocfg" version = "1.4.0" @@ -314,18 +115,6 @@ dependencies = [ "windows-targets 0.52.6", ] -[[package]] -name = "base-x" -version = "0.2.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4cbbc9d0964165b47557570cce6c952866c2678457aca742aafc9fb771d30270" - -[[package]] -name = "base16ct" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c7f02d4ea65f2c1853089ffd8d2787bdbc63de2f0d29dedbcf8ccdfa0ccd4cf" - [[package]] name = "base64" version = "0.22.1" @@ -333,31 +122,30 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" [[package]] -name = "base64ct" -version = "1.7.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89e25b6adfb930f02d1981565a6e5d9c547ac15a96606256d3b59040e5cd4ca3" - -[[package]] -name = "bitflags" -version = "1.3.2" +name = "bindgen" +version = "0.72.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" +checksum = "993776b509cfb49c750f11b8f07a46fa23e0a1386ffc01fb1e7d343efc387895" +dependencies = [ + "bitflags", + "cexpr", + "clang-sys", + "itertools", + "log", + "prettyplease", + "proc-macro2", + "quote", + "regex", + "rustc-hash", + "shlex", + "syn", +] [[package]] name = "bitflags" -version = "2.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c8214115b7bf84099f1309324e63141d4c5d7cc26862f97a0a857dbefe165bd" - -[[package]] -name = "blake2" -version = "0.10.6" +version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46502ad458c9a52b69d4d4d32775c788b7a1b85e8bc9d482d92250fc0e3f8efe" -dependencies = [ - "digest 0.10.7", -] +checksum = "843867be96c8daad0d758b57df9392b6d8d271134fce549de6ce169ff98a92af" [[package]] name = "block-buffer" @@ -377,15 +165,6 @@ dependencies = [ "generic-array", ] -[[package]] -name = "bs58" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf88ba1141d185c399bee5288d850d63b8369520c1eafc32a0430b5b6c287bf4" -dependencies = [ - "tinyvec", -] - [[package]] name = "bumpalo" version = "3.17.0" @@ -405,21 +184,30 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" [[package]] -name = "cbor4ii" -version = "0.3.3" +name = "cc" +version = "1.2.56" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "472931dd4dfcc785075b09be910147f9c6258883fc4591d0dac6116392b2daa6" +checksum = "aebf35691d1bfb0ac386a69bac2fde4dd276fb618cf8bf4f5318fe285e821bb2" dependencies = [ - "serde", + "find-msvc-tools", + "jobserver", + "libc", + "shlex", ] [[package]] -name = "cc" -version = "1.2.20" +name = "cesu8" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d43a04d8753f35258c91f8ec639f792891f748a1edbd759cf1dcea3382ad83c" + +[[package]] +name = "cexpr" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04da6a0d40b948dfc4fa8f5bbf402b0fc1a64a28dbf7d12ffd683550f2c1b63a" +checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766" dependencies = [ - "shlex", + "nom", ] [[package]] @@ -435,53 +223,63 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" [[package]] -name = "chacha20" -version = "0.9.1" +name = "clang-sys" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3613f74bd2eac03dad61bd53dbe620703d4371614fe0bc3b9f04dd36fe4e818" +checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4" dependencies = [ - "cfg-if", - "cipher", - "cpufeatures", + "glob", + "libc", + "libloading", ] [[package]] -name = "chacha20poly1305" -version = "0.10.1" +name = "clap" +version = "4.5.60" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "10cd79432192d1c0f4e1a0fef9527696cc039165d729fb41b3f4f4f354c2dc35" +checksum = "2797f34da339ce31042b27d23607e051786132987f595b02ba4f6a6dffb7030a" dependencies = [ - "aead", - "chacha20", - "cipher", - "poly1305", - "zeroize", + "clap_builder", + "clap_derive", ] [[package]] -name = "chrono" -version = "0.4.41" +name = "clap_builder" +version = "4.5.60" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c469d952047f47f91b68d1cba3f10d63c11d73e4636f24f08daf0278abf01c4d" +checksum = "24a241312cea5059b13574bb9b3861cabf758b879c15190b37b6d6fd63ab6876" dependencies = [ - "android-tzdata", - "iana-time-zone", - "js-sys", - "num-traits", - "serde", - "wasm-bindgen", - "windows-link", + "anstream", + "anstyle", + "clap_lex", + "strsim", ] [[package]] -name = "cipher" -version = "0.4.4" +name = "clap_derive" +version = "4.5.55" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad" +checksum = "a92793da1a46a5f2a02a6f4c46c6496b28c43638adea8306fcb0caa1634f24e5" dependencies = [ - "crypto-common", - "inout", - "zeroize", + "heck", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "clap_lex" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a822ea5bc7590f9d40f1ba12c0dc3c2760f3482c6984db1573ad11031420831" + +[[package]] +name = "cmake" +version = "0.1.54" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0" +dependencies = [ + "cc", ] [[package]] @@ -491,34 +289,43 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" [[package]] -name = "colored" -version = "3.0.0" +name = "combine" +version = "4.6.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fde0e0ec90c9dfb3b4b1a0891a7dcd0e2bffde2f7efed5fe7c9bb00e5bfb915e" +checksum = "ba5a308b75df32fe02788e748662718f03fde005016435c444eea572398219fd" dependencies = [ - "windows-sys 0.59.0", + "bytes", + "memchr", ] [[package]] -name = "concurrent-queue" -version = "2.5.0" +name = "console" +version = "0.15.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ca0197aee26d1ae37445ee532fefce43251d24cc7c166799f4d46817f1d3973" +checksum = "054ccb5b10f9f2cbf51eb355ca1d05c2d279ce1804688d0db74b4733a5aeafd8" dependencies = [ - "crossbeam-utils", + "encode_unicode", + "libc", + "once_cell", + "unicode-width", + "windows-sys 0.59.0", ] [[package]] -name = "const-oid" -version = "0.9.6" +name = "core-foundation" +version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +dependencies = [ + "core-foundation-sys", + "libc", +] [[package]] name = "core-foundation" -version = "0.9.4" +version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +checksum = "b2a6cd9ae233e7f62ba4e9353e81a88df7fc8a5987b8d445b4d90c879bd156f6" dependencies = [ "core-foundation-sys", "libc", @@ -531,195 +338,56 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" [[package]] -name = "core2" -version = "0.4.0" +name = "cpufeatures" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b49ba7ef1ad6107f8824dbe97de947cbaac53c44e7f9756a1fba0d37c1eec505" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" dependencies = [ - "memchr", + "libc", ] [[package]] -name = "cpufeatures" -version = "0.2.17" +name = "crc32fast" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" +checksum = "9481c1c90cbf2ac953f07c8d4a58aa3945c425b7185c9154d67a65e4230da511" dependencies = [ - "libc", + "cfg-if", ] [[package]] -name = "critical-section" -version = "1.2.0" +name = "crunchy" +version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "790eea4361631c5e7d22598ecd5723ff611904e3344ce8720784c93e3d83d40b" +checksum = "43da5946c66ffcc7745f48db692ffbb10a83bfe0afd96235c5c2a4fb23994929" [[package]] -name = "crossbeam-channel" -version = "0.5.15" +name = "crypto-common" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "82b8f8f868b36967f9606790d1903570de9ceaf870a7bf9fbbd3016d636a2cb2" +checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" dependencies = [ - "crossbeam-utils", + "generic-array", + "typenum", ] [[package]] -name = "crossbeam-deque" -version = "0.8.6" +name = "crypto-mac" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" +checksum = "b584a330336237c1eecd3e94266efb216c56ed91225d634cb2991c5f3fd1aeab" dependencies = [ - "crossbeam-epoch", - "crossbeam-utils", + "generic-array", + "subtle", ] [[package]] -name = "crossbeam-epoch" -version = "0.9.18" +name = "deranged" +version = "0.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +checksum = "7cd812cc2bc1d69d4764bd80df88b4317eaef9e773c75226407d9bc0876b211c" dependencies = [ - "crossbeam-utils", -] - -[[package]] -name = "crossbeam-utils" -version = "0.8.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" - -[[package]] -name = "crunchy" -version = "0.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43da5946c66ffcc7745f48db692ffbb10a83bfe0afd96235c5c2a4fb23994929" - -[[package]] -name = "crypto-bigint" -version = "0.5.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0dc92fb57ca44df6db8059111ab3af99a63d5d0f8375d9972e319a379c6bab76" -dependencies = [ - "generic-array", - "rand_core 0.6.4", - "subtle", - "zeroize", -] - -[[package]] -name = "crypto-common" -version = "0.1.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" -dependencies = [ - "generic-array", - "rand_core 0.6.4", - "typenum", -] - -[[package]] -name = "crypto-mac" -version = "0.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b584a330336237c1eecd3e94266efb216c56ed91225d634cb2991c5f3fd1aeab" -dependencies = [ - "generic-array", - "subtle", -] - -[[package]] -name = "ctr" -version = "0.9.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0369ee1ad671834580515889b80f2ea915f23b8be8d0daa4bbaf2ac5c7590835" -dependencies = [ - "cipher", -] - -[[package]] -name = "curve25519-dalek" -version = "4.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97fb8b7c4503de7d6ae7b42ab72a5a59857b4c937ec27a3d4539dba95b5ab2be" -dependencies = [ - "cfg-if", - "cpufeatures", - "curve25519-dalek-derive", - "digest 0.10.7", - "fiat-crypto", - "rustc_version", - "subtle", - "zeroize", -] - -[[package]] -name = "curve25519-dalek-derive" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f46882e17999c6cc590af592290432be3bce0428cb0d5f8b6715e4dc7b383eb3" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "data-encoding" -version = "2.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a2330da5de22e8a3cb63252ce2abb30116bf5265e89c0e01bc17015ce30a476" - -[[package]] -name = "data-encoding-macro" -version = "0.1.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "47ce6c96ea0102f01122a185683611bd5ac8d99e62bc59dd12e6bda344ee673d" -dependencies = [ - "data-encoding", - "data-encoding-macro-internal", -] - -[[package]] -name = "data-encoding-macro-internal" -version = "0.1.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d162beedaa69905488a8da94f5ac3edb4dd4788b732fadb7bd120b2625c1976" -dependencies = [ - "data-encoding", - "syn", -] - -[[package]] -name = "der" -version = "0.7.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7c1832837b905bbfb5101e07cc24c8deddf52f93225eee6ead5f4d63d53ddcb" -dependencies = [ - "const-oid", - "zeroize", -] - -[[package]] -name = "der-parser" -version = "9.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5cd0a5c643689626bec213c4d8bd4d96acc8ffdb4ad4bb6bc16abf27d5f4b553" -dependencies = [ - "asn1-rs", - "displaydoc", - "nom", - "num-bigint", - "num-traits", - "rusticata-macros", -] - -[[package]] -name = "deranged" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c9e6a11ca8224451684bc0d7d5a7adbf8f2fd6887261a1cfc3c0432f9d4068e" -dependencies = [ - "powerfmt", + "powerfmt", ] [[package]] @@ -738,30 +406,28 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ "block-buffer 0.10.4", - "const-oid", "crypto-common", - "subtle", ] [[package]] -name = "directories" -version = "5.0.1" +name = "dirs" +version = "6.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a49173b84e034382284f27f1af4dcbbd231ffa358c0fe316541a7337f376a35" +checksum = "c3e8aa94d75141228480295a7d0e7feb620b1a5ad9f12bc40be62411e38cce4e" dependencies = [ "dirs-sys", ] [[package]] name = "dirs-sys" -version = "0.4.1" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "520f05a5cbd335fae5a99ff7a6ab8627577660ee5cfd6a94a6a929b52ff0321c" +checksum = "e01a3366d27ee9890022452ee61b2b63a67e6f13f58900b651ff5665f0bb1fab" dependencies = [ "libc", "option-ext", "redox_users", - "windows-sys 0.48.0", + "windows-sys 0.61.2", ] [[package]] @@ -776,168 +442,37 @@ dependencies = [ ] [[package]] -name = "dkn-compute" -version = "0.6.7" +name = "dria-node" +version = "2.0.0-alpha.1" dependencies = [ - "base64", - "chrono", - "colored", - "dkn-executor", - "dkn-p2p", - "dkn-utils", - "dotenvy", - "ecies", - "env_logger", - "eyre", + "anyhow", + "bytes", + "clap", + "dirs", + "encoding_rs", + "futures", "hex", - "hex-literal", + "hf-hub", "libsecp256k1", - "log", - "openssl", - "port_check", - "public-ip-address", + "llama-cpp-2", + "quinn", "rand 0.8.5", - "reqwest", - "serde", - "serde_json", - "sysinfo", - "tokio", - "tokio-util", - "url", - "urlencoding", - "uuid", -] - -[[package]] -name = "dkn-executor" -version = "0.6.7" -dependencies = [ - "dkn-utils", - "dotenvy", - "enum-iterator", - "env_logger", - "eyre", - "log", - "ollama-rs", - "reqwest", - "rig-core", - "serde", - "serde_json", - "thiserror 2.0.12", - "tokio", - "tokio-util", -] - -[[package]] -name = "dkn-p2p" -version = "0.6.7" -dependencies = [ - "dkn-utils", - "env_logger", - "eyre", - "libp2p", - "libp2p-identity", - "log", - "serde", - "serde_json", - "tokio", - "tokio-util", -] - -[[package]] -name = "dkn-utils" -version = "0.6.7" -dependencies = [ - "base64", - "chrono", - "ecies", - "hex", - "libp2p-identity", - "libsecp256k1", - "public-ip-address", + "rcgen", + "rmp-serde", + "rustls", + "rustls-native-certs", "serde", "serde_json", "sha2 0.10.9", "sha3", "thiserror 2.0.12", + "tokio", + "tokio-util", + "tracing", + "tracing-subscriber", "uuid", ] -[[package]] -name = "dotenvy" -version = "0.15.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b" - -[[package]] -name = "dtoa" -version = "1.0.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d6add3b8cff394282be81f3fc1a0605db594ed69890078ca6e2cab1c408bcf04" - -[[package]] -name = "dyn-clone" -version = "1.0.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1c7a8fb8a9fbf66c1f703fe16184d10ca0ee9d23be5b4436400408ba54a95005" - -[[package]] -name = "ecdsa" -version = "0.16.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ee27f32b5c5292967d2d4a9d7f1e0b0aed2c15daded5a60300e4abb9d8020bca" -dependencies = [ - "der", - "digest 0.10.7", - "elliptic-curve", - "rfc6979", - "signature", - "spki", -] - -[[package]] -name = "ecies" -version = "0.2.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "011318cc6f4f1906c1dae015013fd381e92deac290a29ddcd9f2e0dd14786037" -dependencies = [ - "aes-gcm", - "getrandom 0.2.16", - "hkdf", - "libsecp256k1", - "once_cell", - "parking_lot", - "rand_core 0.6.4", - "sha2 0.10.9", - "typenum", - "wasm-bindgen", -] - -[[package]] -name = "ed25519" -version = "2.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "115531babc129696a58c64a4fef0a8bf9e9698629fb97e9e40767d235cfbcd53" -dependencies = [ - "pkcs8", - "signature", -] - -[[package]] -name = "ed25519-dalek" -version = "2.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4a3daa8e81a3963a60642bcc1f90a670680bd4a77535faa384e9d1c79d620871" -dependencies = [ - "curve25519-dalek", - "ed25519", - "rand_core 0.6.4", - "serde", - "sha2 0.10.9", - "subtle", - "zeroize", -] - [[package]] name = "either" version = "1.15.0" @@ -945,23 +480,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" [[package]] -name = "elliptic-curve" -version = "0.13.8" +name = "encode_unicode" +version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5e6043086bf7973472e0c7dff2142ea0b680d30e18d9cc40f267efbf222bd47" -dependencies = [ - "base16ct", - "crypto-bigint", - "digest 0.10.7", - "ff", - "generic-array", - "group", - "pkcs8", - "rand_core 0.6.4", - "sec1", - "subtle", - "zeroize", -] +checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" [[package]] name = "encoding_rs" @@ -973,60 +495,25 @@ dependencies = [ ] [[package]] -name = "enum-as-inner" -version = "0.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1e6a265c649f3f5979b601d26f1d05ada116434c87741c9493cb56218f76cbc" -dependencies = [ - "heck", - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "enum-iterator" -version = "2.1.0" +name = "enumflags2" +version = "0.7.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c280b9e6b3ae19e152d8e31cf47f18389781e119d4013a2a2bb0180e5facc635" +checksum = "1027f7680c853e056ebcec683615fb6fbbc07dbaa13b4d5d9442b146ded4ecef" dependencies = [ - "enum-iterator-derive", + "enumflags2_derive", ] [[package]] -name = "enum-iterator-derive" -version = "1.4.0" +name = "enumflags2_derive" +version = "0.7.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1ab991c1362ac86c61ab6f556cff143daa22e5a15e4e189df818b2fd19fe65b" +checksum = "67c78a4d8fdf9953a5c9d458f9efe940fd97a0cab0941c075a813ac594733827" dependencies = [ "proc-macro2", "quote", "syn", ] -[[package]] -name = "env_filter" -version = "0.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "186e05a59d4c50738528153b83b0b0194d3a29507dfec16eccd4b342903397d0" -dependencies = [ - "log", - "regex", -] - -[[package]] -name = "env_logger" -version = "0.11.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13c863f0904021b108aa8b2f55046443e6b1ebde8fd4a15c399893aae4fa069f" -dependencies = [ - "anstream", - "anstyle", - "env_filter", - "jiff", - "log", -] - [[package]] name = "equivalent" version = "1.0.2" @@ -1044,34 +531,15 @@ dependencies = [ ] [[package]] -name = "event-listener" -version = "5.4.0" +name = "fastbloom" +version = "0.14.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3492acde4c3fc54c845eaab3eed8bd00c7a7d881f78bfc801e43a93dec1331ae" +checksum = "4e7f34442dbe69c60fe8eaf58a8cafff81a1f278816d8ab4db255b3bef4ac3c4" dependencies = [ - "concurrent-queue", - "parking", - "pin-project-lite", -] - -[[package]] -name = "event-listener-strategy" -version = "0.5.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8be9f3dfaaffdae2972880079a491a1a8bb7cbed0b8dd7a347f668b4150a3b93" -dependencies = [ - "event-listener", - "pin-project-lite", -] - -[[package]] -name = "eyre" -version = "0.6.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7cd915d99f24784cdc19fd37ef22b97e3ff0ae756c7e492e9fbfe897d61e2aec" -dependencies = [ - "indenter", - "once_cell", + "getrandom 0.3.2", + "libm", + "rand 0.9.1", + "siphasher", ] [[package]] @@ -1081,20 +549,29 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" [[package]] -name = "ff" -version = "0.13.1" +name = "find-msvc-tools" +version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0b50bfb653653f9ca9095b427bed08ab8d75a137839d9ad64eb11810d5b6393" +checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" + +[[package]] +name = "find_cuda_helper" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9f9e65c593dd01ac77daad909ea4ad17f0d6d1776193fc8ea766356177abdad" dependencies = [ - "rand_core 0.6.4", - "subtle", + "glob", ] [[package]] -name = "fiat-crypto" -version = "0.2.9" +name = "flate2" +version = "1.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28dea519a9695b9977216879a3ebfddf92f1c08c05d984f8996aecd6ecdc811d" +checksum = "843fba2746e448b37e26a819579957415c8cef339bf08564fe8b7ddbd959573c" +dependencies = [ + "crc32fast", + "miniz_oxide", +] [[package]] name = "fnv" @@ -1102,12 +579,6 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" -[[package]] -name = "foldhash" -version = "0.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" - [[package]] name = "foreign-types" version = "0.3.2" @@ -1147,16 +618,6 @@ dependencies = [ "futures-util", ] -[[package]] -name = "futures-bounded" -version = "0.2.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91f328e7fb845fc832912fb6a34f40cf6d1888c92f974d1893a54e97b5ff542e" -dependencies = [ - "futures-timer", - "futures-util", -] - [[package]] name = "futures-channel" version = "0.3.31" @@ -1182,7 +643,6 @@ dependencies = [ "futures-core", "futures-task", "futures-util", - "num_cpus", ] [[package]] @@ -1191,16 +651,6 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" -[[package]] -name = "futures-lite" -version = "2.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f5edaec856126859abb19ed65f39e90fea3a9574b9707f13539acf4abf7eb532" -dependencies = [ - "futures-core", - "pin-project-lite", -] - [[package]] name = "futures-macro" version = "0.3.31" @@ -1212,17 +662,6 @@ dependencies = [ "syn", ] -[[package]] -name = "futures-rustls" -version = "0.26.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8f2f12607f92c69b12ed746fabf9ca4f5c482cba46679c1a75b874ed7c26adb" -dependencies = [ - "futures-io", - "rustls", - "rustls-pki-types", -] - [[package]] name = "futures-sink" version = "0.3.31" @@ -1235,12 +674,6 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" -[[package]] -name = "futures-timer" -version = "3.0.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" - [[package]] name = "futures-util" version = "0.3.31" @@ -1259,19 +692,6 @@ dependencies = [ "slab", ] -[[package]] -name = "generator" -version = "0.8.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc6bd114ceda131d3b1d665eba35788690ad37f5916457286b32ab6fd3c438dd" -dependencies = [ - "cfg-if", - "libc", - "log", - "rustversion", - "windows 0.58.0", -] - [[package]] name = "generic-array" version = "0.14.7" @@ -1280,7 +700,6 @@ checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" dependencies = [ "typenum", "version_check", - "zeroize", ] [[package]] @@ -1310,16 +729,6 @@ dependencies = [ "wasm-bindgen", ] -[[package]] -name = "ghash" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0d8a4362ccb29cb0b265253fb0a2728f592895ee6854fd9bc13f2ffda266ff1" -dependencies = [ - "opaque-debug", - "polyval", -] - [[package]] name = "gimli" version = "0.31.1" @@ -1328,20 +737,9 @@ checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" [[package]] name = "glob" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" - -[[package]] -name = "group" -version = "0.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0f9ef7462f7c099f518d754361858f86d8a07af53ba9af0fe635bbccb151a63" -dependencies = [ - "ff", - "rand_core 0.6.4", - "subtle", -] +checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280" [[package]] name = "h2" @@ -1354,30 +752,19 @@ dependencies = [ "fnv", "futures-core", "futures-sink", - "http 1.3.1", - "indexmap 2.9.0", + "http", + "indexmap", "slab", "tokio", "tokio-util", "tracing", ] -[[package]] -name = "hashbrown" -version = "0.12.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" - [[package]] name = "hashbrown" version = "0.15.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "84b26c544d002229e640969970a2e74021aadf6e2f96372b9c58eff97de08eb3" -dependencies = [ - "allocator-api2", - "equivalent", - "foldhash", -] [[package]] name = "heck" @@ -1392,77 +779,33 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" [[package]] -name = "hermit-abi" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fbf6a919d6cf397374f7dfeeea91d974c7c0a7221d0d0f4f20d859d329e53fcc" - -[[package]] -name = "hex" -version = "0.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" - -[[package]] -name = "hex-literal" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6fe2267d4ed49bc07b63801559be28c718ea06c4738b7a03c94df7386d2cde46" - -[[package]] -name = "hickory-proto" -version = "0.25.0-alpha.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d00147af6310f4392a31680db52a3ed45a2e0f68eb18e8c3fe5537ecc96d9e2" -dependencies = [ - "async-recursion", - "async-trait", - "cfg-if", - "data-encoding", - "enum-as-inner", - "futures-channel", - "futures-io", - "futures-util", - "idna", - "ipnet", - "once_cell", - "rand 0.9.1", - "socket2", - "thiserror 2.0.12", - "tinyvec", - "tokio", - "tracing", - "url", -] - -[[package]] -name = "hickory-resolver" -version = "0.25.0-alpha.5" +name = "hex" +version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5762f69ebdbd4ddb2e975cd24690bf21fe6b2604039189c26acddbc427f12887" -dependencies = [ - "cfg-if", - "futures-util", - "hickory-proto", - "ipconfig", - "moka", - "once_cell", - "parking_lot", - "rand 0.9.1", - "resolv-conf", - "smallvec", - "thiserror 2.0.12", - "tokio", - "tracing", -] +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" [[package]] -name = "hkdf" -version = "0.12.4" +name = "hf-hub" +version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b5f8eb2ad728638ea2c7d47a21db23b7b58a72ed6a38256b8a1849f15fbbdf7" +checksum = "629d8f3bbeda9d148036d6b0de0a3ab947abd08ce90626327fc3547a49d59d97" dependencies = [ - "hmac 0.12.1", + "dirs", + "futures", + "http", + "indicatif", + "libc", + "log", + "native-tls", + "num_cpus", + "rand 0.9.1", + "reqwest", + "serde", + "serde_json", + "thiserror 2.0.12", + "tokio", + "ureq", + "windows-sys 0.60.2", ] [[package]] @@ -1475,15 +818,6 @@ dependencies = [ "digest 0.9.0", ] -[[package]] -name = "hmac" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" -dependencies = [ - "digest 0.10.7", -] - [[package]] name = "hmac-drbg" version = "0.3.0" @@ -1492,18 +826,7 @@ checksum = "17ea0a1394df5b6574da6e0c1ade9e78868c9fb0a4e5ef4428e32da4676b85b1" dependencies = [ "digest 0.9.0", "generic-array", - "hmac 0.8.1", -] - -[[package]] -name = "http" -version = "0.2.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "601cbb57e577e2f5ef5be8e7b83f0f63994f25aa94d673e54a92d5c516d101f1" -dependencies = [ - "bytes", - "fnv", - "itoa", + "hmac", ] [[package]] @@ -1524,7 +847,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" dependencies = [ "bytes", - "http 1.3.1", + "http", ] [[package]] @@ -1535,7 +858,7 @@ checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" dependencies = [ "bytes", "futures-core", - "http 1.3.1", + "http", "http-body", "pin-project-lite", ] @@ -1556,7 +879,7 @@ dependencies = [ "futures-channel", "futures-util", "h2", - "http 1.3.1", + "http", "http-body", "httparse", "itoa", @@ -1573,7 +896,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2d191583f3da1305256f22463b9bb0471acad48a4e534a5218b9963e9c1f59b2" dependencies = [ "futures-util", - "http 1.3.1", + "http", "hyper", "hyper-util", "rustls", @@ -1581,7 +904,6 @@ dependencies = [ "tokio", "tokio-rustls", "tower-service", - "webpki-roots", ] [[package]] @@ -1609,7 +931,7 @@ dependencies = [ "bytes", "futures-channel", "futures-util", - "http 1.3.1", + "http", "http-body", "hyper", "libc", @@ -1620,30 +942,6 @@ dependencies = [ "tracing", ] -[[package]] -name = "iana-time-zone" -version = "0.1.63" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b0c919e5debc312ad217002b8048a17b7d83f80703865bbfcfebb0458b0b27d8" -dependencies = [ - "android_system_properties", - "core-foundation-sys", - "iana-time-zone-haiku", - "js-sys", - "log", - "wasm-bindgen", - "windows-core 0.61.0", -] - -[[package]] -name = "iana-time-zone-haiku" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" -dependencies = [ - "cc", -] - [[package]] name = "icu_collections" version = "1.5.0" @@ -1670,644 +968,247 @@ dependencies = [ ] [[package]] -name = "icu_locid_transform" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01d11ac35de8e40fdeda00d9e1e9d92525f3f9d887cdd7aa81d727596788b54e" -dependencies = [ - "displaydoc", - "icu_locid", - "icu_locid_transform_data", - "icu_provider", - "tinystr", - "zerovec", -] - -[[package]] -name = "icu_locid_transform_data" -version = "1.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7515e6d781098bf9f7205ab3fc7e9709d34554ae0b21ddbcb5febfa4bc7df11d" - -[[package]] -name = "icu_normalizer" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19ce3e0da2ec68599d193c93d088142efd7f9c5d6fc9b803774855747dc6a84f" -dependencies = [ - "displaydoc", - "icu_collections", - "icu_normalizer_data", - "icu_properties", - "icu_provider", - "smallvec", - "utf16_iter", - "utf8_iter", - "write16", - "zerovec", -] - -[[package]] -name = "icu_normalizer_data" -version = "1.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c5e8338228bdc8ab83303f16b797e177953730f601a96c25d10cb3ab0daa0cb7" - -[[package]] -name = "icu_properties" -version = "1.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93d6020766cfc6302c15dbbc9c8778c37e62c14427cb7f6e601d849e092aeef5" -dependencies = [ - "displaydoc", - "icu_collections", - "icu_locid_transform", - "icu_properties_data", - "icu_provider", - "tinystr", - "zerovec", -] - -[[package]] -name = "icu_properties_data" -version = "1.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85fb8799753b75aee8d2a21d7c14d9f38921b54b3dbda10f5a3c7a7b82dba5e2" - -[[package]] -name = "icu_provider" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ed421c8a8ef78d3e2dbc98a973be2f3770cb42b606e3ab18d6237c4dfde68d9" -dependencies = [ - "displaydoc", - "icu_locid", - "icu_provider_macros", - "stable_deref_trait", - "tinystr", - "writeable", - "yoke", - "zerofrom", - "zerovec", -] - -[[package]] -name = "icu_provider_macros" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "idna" -version = "1.0.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "686f825264d630750a544639377bae737628043f20d38bbc029e8f29ea968a7e" -dependencies = [ - "idna_adapter", - "smallvec", - "utf8_iter", -] - -[[package]] -name = "idna_adapter" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "daca1df1c957320b2cf139ac61e7bd64fed304c5040df000a745aa1de3b4ef71" -dependencies = [ - "icu_normalizer", - "icu_properties", -] - -[[package]] -name = "if-addrs" -version = "0.10.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cabb0019d51a643781ff15c9c8a3e5dedc365c47211270f4e8f82812fedd8f0a" -dependencies = [ - "libc", - "windows-sys 0.48.0", -] - -[[package]] -name = "if-watch" -version = "3.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cdf9d64cfcf380606e64f9a0bcf493616b65331199f984151a6fa11a7b3cde38" -dependencies = [ - "async-io", - "core-foundation", - "fnv", - "futures", - "if-addrs", - "ipnet", - "log", - "netlink-packet-core", - "netlink-packet-route", - "netlink-proto", - "netlink-sys", - "rtnetlink", - "system-configuration", - "tokio", - "windows 0.53.0", -] - -[[package]] -name = "igd-next" -version = "0.15.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76b0d7d4541def58a37bf8efc559683f21edce7c82f0d866c93ac21f7e098f93" -dependencies = [ - "async-trait", - "attohttpc", - "bytes", - "futures", - "http 1.3.1", - "http-body-util", - "hyper", - "hyper-util", - "log", - "rand 0.8.5", - "tokio", - "url", - "xmltree", -] - -[[package]] -name = "indenter" -version = "0.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce23b50ad8242c51a442f3ff322d56b02f08852c77e4c0b4d3fd684abc89c683" - -[[package]] -name = "indexmap" -version = "1.9.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" -dependencies = [ - "autocfg", - "hashbrown 0.12.3", - "serde", -] - -[[package]] -name = "indexmap" -version = "2.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cea70ddb795996207ad57735b50c5982d8844f38ba9ee5f1aedcfb708a2aa11e" -dependencies = [ - "equivalent", - "hashbrown 0.15.3", -] - -[[package]] -name = "inout" -version = "0.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "879f10e63c20629ecabbb64a8010319738c66a5cd0c29b02d63d272b03751d01" -dependencies = [ - "generic-array", -] - -[[package]] -name = "ipconfig" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b58db92f96b720de98181bbbe63c831e87005ab460c1bf306eb2622b4707997f" -dependencies = [ - "socket2", - "widestring", - "windows-sys 0.48.0", - "winreg", -] - -[[package]] -name = "ipnet" -version = "2.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" - -[[package]] -name = "is_terminal_polyfill" -version = "1.70.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" - -[[package]] -name = "itoa" -version = "1.0.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" - -[[package]] -name = "jiff" -version = "0.2.10" +name = "icu_locid_transform" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a064218214dc6a10fbae5ec5fa888d80c45d611aba169222fc272072bf7aef6" +checksum = "01d11ac35de8e40fdeda00d9e1e9d92525f3f9d887cdd7aa81d727596788b54e" dependencies = [ - "jiff-static", - "log", - "portable-atomic", - "portable-atomic-util", - "serde", + "displaydoc", + "icu_locid", + "icu_locid_transform_data", + "icu_provider", + "tinystr", + "zerovec", ] [[package]] -name = "jiff-static" -version = "0.2.10" +name = "icu_locid_transform_data" +version = "1.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "199b7932d97e325aff3a7030e141eafe7f2c6268e1d1b24859b753a627f45254" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] +checksum = "7515e6d781098bf9f7205ab3fc7e9709d34554ae0b21ddbcb5febfa4bc7df11d" [[package]] -name = "js-sys" -version = "0.3.77" +name = "icu_normalizer" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1cfaf33c695fc6e08064efbc1f72ec937429614f25eef83af942d0e227c3a28f" +checksum = "19ce3e0da2ec68599d193c93d088142efd7f9c5d6fc9b803774855747dc6a84f" dependencies = [ - "once_cell", - "wasm-bindgen", + "displaydoc", + "icu_collections", + "icu_normalizer_data", + "icu_properties", + "icu_provider", + "smallvec", + "utf16_iter", + "utf8_iter", + "write16", + "zerovec", ] [[package]] -name = "k256" -version = "0.13.4" +name = "icu_normalizer_data" +version = "1.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6e3919bbaa2945715f0bb6d3934a173d1e9a59ac23767fbaaef277265a7411b" -dependencies = [ - "cfg-if", - "ecdsa", - "elliptic-curve", - "once_cell", - "sha2 0.10.9", - "signature", -] +checksum = "c5e8338228bdc8ab83303f16b797e177953730f601a96c25d10cb3ab0daa0cb7" [[package]] -name = "keccak" -version = "0.1.5" +name = "icu_properties" +version = "1.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ecc2af9a1119c51f12a14607e783cb977bde58bc069ff0c3da1095e635d70654" +checksum = "93d6020766cfc6302c15dbbc9c8778c37e62c14427cb7f6e601d849e092aeef5" dependencies = [ - "cpufeatures", + "displaydoc", + "icu_collections", + "icu_locid_transform", + "icu_properties_data", + "icu_provider", + "tinystr", + "zerovec", ] [[package]] -name = "lazy_static" -version = "1.5.0" +name = "icu_properties_data" +version = "1.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" +checksum = "85fb8799753b75aee8d2a21d7c14d9f38921b54b3dbda10f5a3c7a7b82dba5e2" [[package]] -name = "libc" -version = "0.2.172" +name = "icu_provider" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d750af042f7ef4f724306de029d18836c26c1765a54a6a3f094cbd23a7267ffa" +checksum = "6ed421c8a8ef78d3e2dbc98a973be2f3770cb42b606e3ab18d6237c4dfde68d9" +dependencies = [ + "displaydoc", + "icu_locid", + "icu_provider_macros", + "stable_deref_trait", + "tinystr", + "writeable", + "yoke", + "zerofrom", + "zerovec", +] [[package]] -name = "libp2p" -version = "0.55.0" +name = "icu_provider_macros" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b72dc443ddd0254cb49a794ed6b6728400ee446a0f7ab4a07d0209ee98de20e9" +checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" dependencies = [ - "bytes", - "either", - "futures", - "futures-timer", - "getrandom 0.2.16", - "libp2p-allow-block-list", - "libp2p-connection-limits", - "libp2p-core", - "libp2p-dns", - "libp2p-identify", - "libp2p-identity", - "libp2p-mdns", - "libp2p-metrics", - "libp2p-noise", - "libp2p-quic", - "libp2p-request-response", - "libp2p-swarm", - "libp2p-tcp", - "libp2p-upnp", - "libp2p-yamux", - "multiaddr", - "pin-project", - "rw-stream-sink", - "thiserror 2.0.12", + "proc-macro2", + "quote", + "syn", ] [[package]] -name = "libp2p-allow-block-list" -version = "0.5.0" +name = "idna" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38944b7cb981cc93f2f0fb411ff82d0e983bd226fbcc8d559639a3a73236568b" +checksum = "686f825264d630750a544639377bae737628043f20d38bbc029e8f29ea968a7e" dependencies = [ - "libp2p-core", - "libp2p-identity", - "libp2p-swarm", + "idna_adapter", + "smallvec", + "utf8_iter", ] [[package]] -name = "libp2p-connection-limits" -version = "0.5.0" +name = "idna_adapter" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "efe9323175a17caa8a2ed4feaf8a548eeef5e0b72d03840a0eab4bcb0210ce1c" +checksum = "daca1df1c957320b2cf139ac61e7bd64fed304c5040df000a745aa1de3b4ef71" dependencies = [ - "libp2p-core", - "libp2p-identity", - "libp2p-swarm", + "icu_normalizer", + "icu_properties", ] [[package]] -name = "libp2p-core" -version = "0.43.0" +name = "indexmap" +version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "193c75710ba43f7504ad8f58a62ca0615b1d7e572cb0f1780bc607252c39e9ef" +checksum = "cea70ddb795996207ad57735b50c5982d8844f38ba9ee5f1aedcfb708a2aa11e" dependencies = [ - "either", - "fnv", - "futures", - "futures-timer", - "libp2p-identity", - "multiaddr", - "multihash", - "multistream-select", - "once_cell", - "parking_lot", - "pin-project", - "quick-protobuf", - "rand 0.8.5", - "rw-stream-sink", - "thiserror 2.0.12", - "tracing", - "unsigned-varint 0.8.0", - "web-time", + "equivalent", + "hashbrown", ] [[package]] -name = "libp2p-dns" -version = "0.43.0" +name = "indicatif" +version = "0.17.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b780a1150214155b0ed1cdf09fbd2e1b0442604f9146a431d1b21d23eef7bd7" +checksum = "183b3088984b400f4cfac3620d5e076c84da5364016b4f49473de574b2586235" dependencies = [ - "async-trait", - "futures", - "hickory-resolver", - "libp2p-core", - "libp2p-identity", - "parking_lot", - "smallvec", - "tracing", + "console", + "number_prefix", + "portable-atomic", + "unicode-width", + "web-time", ] [[package]] -name = "libp2p-identify" -version = "0.46.0" +name = "ipnet" +version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8c06862544f02d05d62780ff590cc25a75f5c2b9df38ec7a370dcae8bb873cf" -dependencies = [ - "asynchronous-codec", - "either", - "futures", - "futures-bounded", - "futures-timer", - "libp2p-core", - "libp2p-identity", - "libp2p-swarm", - "quick-protobuf", - "quick-protobuf-codec", - "smallvec", - "thiserror 2.0.12", - "tracing", -] +checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" [[package]] -name = "libp2p-identity" -version = "0.2.11" +name = "is_terminal_polyfill" +version = "1.70.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fbb68ea10844211a59ce46230909fd0ea040e8a192454d4cc2ee0d53e12280eb" -dependencies = [ - "asn1_der", - "bs58", - "ed25519-dalek", - "hkdf", - "k256", - "multihash", - "quick-protobuf", - "rand 0.8.5", - "sha2 0.10.9", - "thiserror 2.0.12", - "tracing", - "zeroize", -] +checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" [[package]] -name = "libp2p-mdns" -version = "0.47.0" +name = "itertools" +version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "11d0ba095e1175d797540e16b62e7576846b883cb5046d4159086837b36846cc" +checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" dependencies = [ - "futures", - "hickory-proto", - "if-watch", - "libp2p-core", - "libp2p-identity", - "libp2p-swarm", - "rand 0.8.5", - "smallvec", - "socket2", - "tokio", - "tracing", + "either", ] [[package]] -name = "libp2p-metrics" -version = "0.16.0" +name = "itoa" +version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2ce58c64292e87af624fcb86465e7dd8342e46a388d71e8fec0ab37ee789630a" -dependencies = [ - "futures", - "libp2p-core", - "libp2p-identify", - "libp2p-identity", - "libp2p-swarm", - "pin-project", - "prometheus-client", - "web-time", -] +checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" [[package]] -name = "libp2p-noise" -version = "0.46.0" +name = "jni" +version = "0.21.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "afcc133e0f3cea07acde6eb8a9665cb11b600bd61110b010593a0210b8153b16" +checksum = "1a87aa2bb7d2af34197c04845522473242e1aa17c12f4935d5856491a7fb8c97" dependencies = [ - "asynchronous-codec", - "bytes", - "futures", - "libp2p-core", - "libp2p-identity", - "multiaddr", - "multihash", - "once_cell", - "quick-protobuf", - "rand 0.8.5", - "snow", - "static_assertions", - "thiserror 2.0.12", - "tracing", - "x25519-dalek", - "zeroize", + "cesu8", + "cfg-if", + "combine", + "jni-sys", + "log", + "thiserror 1.0.69", + "walkdir", + "windows-sys 0.45.0", ] [[package]] -name = "libp2p-quic" -version = "0.12.0" +name = "jni-sys" +version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41432a159b00424a0abaa2c80d786cddff81055ac24aa127e0cf375f7858d880" -dependencies = [ - "futures", - "futures-timer", - "if-watch", - "libp2p-core", - "libp2p-identity", - "libp2p-tls", - "quinn", - "rand 0.8.5", - "ring", - "rustls", - "socket2", - "thiserror 2.0.12", - "tokio", - "tracing", -] +checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130" [[package]] -name = "libp2p-request-response" -version = "0.28.0" +name = "jobserver" +version = "0.1.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "548fe44a80ff275d400f1b26b090d441d83ef73efabbeb6415f4ce37e5aed865" +checksum = "9afb3de4395d6b3e67a780b6de64b51c978ecf11cb9a462c66be7d4ca9039d33" dependencies = [ - "async-trait", - "cbor4ii", - "futures", - "futures-bounded", - "libp2p-core", - "libp2p-identity", - "libp2p-swarm", - "rand 0.8.5", - "serde", - "smallvec", - "tracing", + "getrandom 0.3.2", + "libc", ] [[package]] -name = "libp2p-swarm" -version = "0.46.0" +name = "js-sys" +version = "0.3.77" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "803399b4b6f68adb85e63ab573ac568154b193e9a640f03e0f2890eabbcb37f8" +checksum = "1cfaf33c695fc6e08064efbc1f72ec937429614f25eef83af942d0e227c3a28f" dependencies = [ - "either", - "fnv", - "futures", - "futures-timer", - "libp2p-core", - "libp2p-identity", - "libp2p-swarm-derive", - "lru", - "multistream-select", "once_cell", - "rand 0.8.5", - "smallvec", - "tokio", - "tracing", - "web-time", + "wasm-bindgen", ] [[package]] -name = "libp2p-swarm-derive" -version = "0.35.0" +name = "keccak" +version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "206e0aa0ebe004d778d79fb0966aa0de996c19894e2c0605ba2f8524dd4443d8" +checksum = "ecc2af9a1119c51f12a14607e783cb977bde58bc069ff0c3da1095e635d70654" dependencies = [ - "heck", - "proc-macro2", - "quote", - "syn", + "cpufeatures", ] [[package]] -name = "libp2p-tcp" -version = "0.43.0" +name = "lazy_static" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "65346fb4d36035b23fec4e7be4c320436ba53537ce9b6be1d1db1f70c905cad0" -dependencies = [ - "futures", - "futures-timer", - "if-watch", - "libc", - "libp2p-core", - "socket2", - "tokio", - "tracing", -] +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" [[package]] -name = "libp2p-tls" -version = "0.6.1" +name = "libc" +version = "0.2.172" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42bbf5084fb44133267ad4caaa72a253d68d709edd2ed1cf9b42431a8ead8fd5" -dependencies = [ - "futures", - "futures-rustls", - "libp2p-core", - "libp2p-identity", - "rcgen", - "ring", - "rustls", - "rustls-webpki 0.101.7", - "thiserror 2.0.12", - "x509-parser", - "yasna", -] +checksum = "d750af042f7ef4f724306de029d18836c26c1765a54a6a3f094cbd23a7267ffa" [[package]] -name = "libp2p-upnp" -version = "0.4.0" +name = "libloading" +version = "0.8.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d457b9ecceb66e7199f049926fad447f1f17f040e8d29d690c086b4cab8ed14a" +checksum = "d7c4b02199fee7c5d21a5ae7d8cfa79a6ef5bb2fc834d6e9058e89c825efdc55" dependencies = [ - "futures", - "futures-timer", - "igd-next", - "libp2p-core", - "libp2p-swarm", - "tokio", - "tracing", + "cfg-if", + "windows-link 0.2.1", ] [[package]] -name = "libp2p-yamux" -version = "0.47.0" +name = "libm" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f15df094914eb4af272acf9adaa9e287baa269943f32ea348ba29cfb9bfc60d8" -dependencies = [ - "either", - "futures", - "libp2p-core", - "thiserror 2.0.12", - "tracing", - "yamux 0.12.1", - "yamux 0.13.4", -] +checksum = "b6d2cec3eae94f9f509c767b45932f1ada8350c4bdb85af2fcab4a3c14807981" [[package]] name = "libredox" @@ -2315,7 +1216,7 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d" dependencies = [ - "bitflags 2.9.0", + "bitflags", "libc", ] @@ -2367,12 +1268,6 @@ dependencies = [ "libsecp256k1-core", ] -[[package]] -name = "linux-raw-sys" -version = "0.4.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" - [[package]] name = "linux-raw-sys" version = "0.9.4" @@ -2386,13 +1281,31 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "23fb14cb19457329c82206317a5663005a4d404783dc74f4252769b0d5f42856" [[package]] -name = "lock_api" -version = "0.4.12" +name = "llama-cpp-2" +version = "0.1.137" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aedc4f4ca22ad992bc43fe20734b3a0f37363b9621419727821bf6572b9c0395" +dependencies = [ + "encoding_rs", + "enumflags2", + "llama-cpp-sys-2", + "thiserror 2.0.12", + "tracing", + "tracing-core", +] + +[[package]] +name = "llama-cpp-sys-2" +version = "0.1.137" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" +checksum = "da365e84fbe4d10e849fa3bfd5a0d70b3b4a59e8c5adc8b7be5c189327566bdb" dependencies = [ - "autocfg", - "scopeguard", + "bindgen", + "cc", + "cmake", + "find_cuda_helper", + "glob", + "walkdir", ] [[package]] @@ -2402,26 +1315,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94" [[package]] -name = "loom" -version = "0.7.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "419e0dc8046cb947daa77eb95ae174acfbddb7673b4151f56d1eed8e93fbfaca" -dependencies = [ - "cfg-if", - "generator", - "scoped-tls", - "tracing", - "tracing-subscriber", -] - -[[package]] -name = "lru" -version = "0.12.5" +name = "lru-slab" +version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "234cf4f4a04dc1f57e24b96cc0cd600cf2af460d4161ac5ecdd0af8e1f3b2a38" -dependencies = [ - "hashbrown 0.15.3", -] +checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" [[package]] name = "matchers" @@ -2432,17 +1329,6 @@ dependencies = [ "regex-automata 0.1.10", ] -[[package]] -name = "maybe-async" -version = "0.2.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5cf92c10c7e361d6b99666ec1c6f9805b0bea2c3bd8c78dc6fe98ac5bd78db11" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - [[package]] name = "memchr" version = "2.7.4" @@ -2455,16 +1341,6 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" -[[package]] -name = "mime_guess" -version = "2.0.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e" -dependencies = [ - "mime", - "unicase", -] - [[package]] name = "minimal-lexical" version = "0.2.1" @@ -2478,6 +1354,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3be647b768db090acb35d5ec5db2b0e1f1de11133ca123b9eacf5137868f892a" dependencies = [ "adler2", + "simd-adler32", ] [[package]] @@ -2491,79 +1368,6 @@ dependencies = [ "windows-sys 0.52.0", ] -[[package]] -name = "moka" -version = "0.12.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a9321642ca94a4282428e6ea4af8cc2ca4eac48ac7a6a4ea8f33f76d0ce70926" -dependencies = [ - "crossbeam-channel", - "crossbeam-epoch", - "crossbeam-utils", - "loom", - "parking_lot", - "portable-atomic", - "rustc_version", - "smallvec", - "tagptr", - "thiserror 1.0.69", - "uuid", -] - -[[package]] -name = "multiaddr" -version = "0.18.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe6351f60b488e04c1d21bc69e56b89cb3f5e8f5d22557d6e8031bdfd79b6961" -dependencies = [ - "arrayref", - "byteorder", - "data-encoding", - "libp2p-identity", - "multibase", - "multihash", - "percent-encoding", - "serde", - "static_assertions", - "unsigned-varint 0.8.0", - "url", -] - -[[package]] -name = "multibase" -version = "0.9.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b3539ec3c1f04ac9748a260728e855f261b4977f5c3406612c884564f329404" -dependencies = [ - "base-x", - "data-encoding", - "data-encoding-macro", -] - -[[package]] -name = "multihash" -version = "0.19.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b430e7953c29dd6a09afc29ff0bb69c6e306329ee6794700aee27b76a1aea8d" -dependencies = [ - "core2", - "unsigned-varint 0.8.0", -] - -[[package]] -name = "multistream-select" -version = "0.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea0df8e5eec2298a62b326ee4f0d7fe1a6b90a09dfcf9df37b38f947a8c42f19" -dependencies = [ - "bytes", - "futures", - "log", - "pin-project", - "smallvec", - "unsigned-varint 0.7.2", -] - [[package]] name = "native-tls" version = "0.2.14" @@ -2573,95 +1377,14 @@ dependencies = [ "libc", "log", "openssl", - "openssl-probe", + "openssl-probe 0.1.6", "openssl-sys", "schannel", - "security-framework", + "security-framework 2.11.1", "security-framework-sys", "tempfile", ] -[[package]] -name = "netlink-packet-core" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72724faf704479d67b388da142b186f916188505e7e0b26719019c525882eda4" -dependencies = [ - "anyhow", - "byteorder", - "netlink-packet-utils", -] - -[[package]] -name = "netlink-packet-route" -version = "0.17.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "053998cea5a306971f88580d0829e90f270f940befd7cf928da179d4187a5a66" -dependencies = [ - "anyhow", - "bitflags 1.3.2", - "byteorder", - "libc", - "netlink-packet-core", - "netlink-packet-utils", -] - -[[package]] -name = "netlink-packet-utils" -version = "0.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ede8a08c71ad5a95cdd0e4e52facd37190977039a4704eb82a283f713747d34" -dependencies = [ - "anyhow", - "byteorder", - "paste", - "thiserror 1.0.69", -] - -[[package]] -name = "netlink-proto" -version = "0.11.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72452e012c2f8d612410d89eea01e2d9b56205274abb35d53f60200b2ec41d60" -dependencies = [ - "bytes", - "futures", - "log", - "netlink-packet-core", - "netlink-sys", - "thiserror 2.0.12", -] - -[[package]] -name = "netlink-sys" -version = "0.8.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16c903aa70590cb93691bf97a767c8d1d6122d2cc9070433deb3bbf36ce8bd23" -dependencies = [ - "bytes", - "futures", - "libc", - "log", - "tokio", -] - -[[package]] -name = "nix" -version = "0.26.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "598beaf3cc6fdd9a5dfb1630c2800c7acd31df7aaf0f565796fba2b53ca1af1b" -dependencies = [ - "bitflags 1.3.2", - "cfg-if", - "libc", -] - -[[package]] -name = "nohash-hasher" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2bf50223579dc7cdcfb3bfcacf7069ff68243f8c363f62ffa99cf000a6b9c451" - [[package]] name = "nom" version = "7.1.3" @@ -2672,15 +1395,6 @@ dependencies = [ "minimal-lexical", ] -[[package]] -name = "ntapi" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8a3895c6391c39d7fe7ebc444a87eb2991b2a0bc718fdabd071eec617fc68e4" -dependencies = [ - "winapi", -] - [[package]] name = "nu-ansi-term" version = "0.46.0" @@ -2691,30 +1405,11 @@ dependencies = [ "winapi", ] -[[package]] -name = "num-bigint" -version = "0.4.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" -dependencies = [ - "num-integer", - "num-traits", -] - [[package]] name = "num-conv" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" - -[[package]] -name = "num-integer" -version = "0.1.46" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" -dependencies = [ - "num-traits", -] +checksum = "cf97ec579c3c42f953ef76dbf8d55ac91fb219dde70e49aa4a6b7d74e9919050" [[package]] name = "num-traits" @@ -2731,10 +1426,16 @@ version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" dependencies = [ - "hermit-abi 0.3.9", + "hermit-abi", "libc", ] +[[package]] +name = "number_prefix" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" + [[package]] name = "object" version = "0.36.7" @@ -2744,43 +1445,11 @@ dependencies = [ "memchr", ] -[[package]] -name = "oid-registry" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8d8034d9489cdaf79228eb9f6a3b8d7bb32ba00d6645ebd48eef4077ceb5bd9" -dependencies = [ - "asn1-rs", -] - -[[package]] -name = "ollama-rs" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a0bd0e2c30868e72ffca8143873c6c1e288b2efda9d3950e9ae0d0b4039c49c3" -dependencies = [ - "async-stream", - "log", - "reqwest", - "schemars", - "serde", - "serde_json", - "static_assertions", - "thiserror 2.0.12", - "tokio", - "tokio-stream", - "url", -] - [[package]] name = "once_cell" version = "1.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" -dependencies = [ - "critical-section", - "portable-atomic", -] [[package]] name = "opaque-debug" @@ -2794,7 +1463,7 @@ version = "0.10.72" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fedfea7d58a1f73118430a55da6a286e7b044961736ce96a16a17068ea25e5da" dependencies = [ - "bitflags 2.9.0", + "bitflags", "cfg-if", "foreign-types", "libc", @@ -2821,13 +1490,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" [[package]] -name = "openssl-src" -version = "300.5.0+3.5.0" +name = "openssl-probe" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8ce546f549326b0e6052b649198487d91320875da901e7bd11a06d1ee3f9c2f" -dependencies = [ - "cc", -] +checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe" [[package]] name = "openssl-sys" @@ -2837,7 +1503,6 @@ checksum = "e145e1651e858e820e4860f7b9c5e169bc1d8ce1c86043be79fa7b7634821847" dependencies = [ "cc", "libc", - "openssl-src", "pkg-config", "vcpkg", ] @@ -2848,64 +1513,20 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" -[[package]] -name = "ordered-float" -version = "4.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7bb71e1b3fa6ca1c61f383464aaf2bb0e2f8e772a1f01d486832464de363b951" -dependencies = [ - "num-traits", -] - [[package]] name = "overload" version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" -[[package]] -name = "parking" -version = "2.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f38d5652c16fde515bb1ecef450ab0f6a219d619a7274976324d5e377f7dceba" - -[[package]] -name = "parking_lot" -version = "0.12.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" -dependencies = [ - "lock_api", - "parking_lot_core", -] - -[[package]] -name = "parking_lot_core" -version = "0.9.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" -dependencies = [ - "cfg-if", - "libc", - "redox_syscall", - "smallvec", - "windows-targets 0.52.6", -] - -[[package]] -name = "paste" -version = "1.0.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" - [[package]] name = "pem" -version = "3.0.5" +version = "3.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38af38e8470ac9dee3ce1bae1af9c1671fffc44ddfd8bd1d0a3445bf349a8ef3" +checksum = "1d30c53c26bc5b31a98cd02d20f25a7c8567146caf63ed593a9d87b2775291be" dependencies = [ "base64", - "serde", + "serde_core", ] [[package]] @@ -2914,26 +1535,6 @@ version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" -[[package]] -name = "pin-project" -version = "1.1.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "677f1add503faace112b9f1373e43e9e054bfdd22ff1a63c1bc485eaec6a6a8a" -dependencies = [ - "pin-project-internal", -] - -[[package]] -name = "pin-project-internal" -version = "1.1.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e918e4ff8c4549eb882f14b3a4bc8c8bc93de829416eacf579f1207a8fbf861" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - [[package]] name = "pin-project-lite" version = "0.2.16" @@ -2943,68 +1544,14 @@ checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" [[package]] name = "pin-utils" version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" - -[[package]] -name = "pkcs8" -version = "0.10.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f950b2377845cebe5cf8b5165cb3cc1a5e0fa5cfa3e1f7f55707d8fd82e0a7b7" -dependencies = [ - "der", - "spki", -] - -[[package]] -name = "pkg-config" -version = "0.3.32" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" - -[[package]] -name = "polling" -version = "3.7.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a604568c3202727d1507653cb121dbd627a58684eb09a820fd746bee38b4442f" -dependencies = [ - "cfg-if", - "concurrent-queue", - "hermit-abi 0.4.0", - "pin-project-lite", - "rustix 0.38.44", - "tracing", - "windows-sys 0.59.0", -] - -[[package]] -name = "poly1305" -version = "0.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8159bd90725d2df49889a078b54f4f79e87f1f8a8444194cdca81d38f5393abf" -dependencies = [ - "cpufeatures", - "opaque-debug", - "universal-hash", -] - -[[package]] -name = "polyval" -version = "0.6.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d1fe60d06143b2430aa532c94cfe9e29783047f06c0d7fd359a9a51b729fa25" -dependencies = [ - "cfg-if", - "cpufeatures", - "opaque-debug", - "universal-hash", -] +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" [[package]] -name = "port_check" -version = "0.2.1" +name = "pkg-config" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2110609fb863cdb367d4e69d6c43c81ba6a8c7d18e80082fe9f3ef16b23afeed" +checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" [[package]] name = "portable-atomic" @@ -3012,15 +1559,6 @@ version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "350e9b48cbc6b0e028b0473b114454c6316e57336ee184ceab6e53f72c178b3e" -[[package]] -name = "portable-atomic-util" -version = "0.2.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8a2f0d8d040d7848a709caf78912debcc3f33ee4b3cac47d73d1e1069e83507" -dependencies = [ - "portable-atomic", -] - [[package]] name = "powerfmt" version = "0.2.0" @@ -3037,83 +1575,32 @@ dependencies = [ ] [[package]] -name = "proc-macro2" -version = "1.0.95" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02b3e5e68a3a1a02aad3ec490a98007cbc13c37cbe84a3cd7b8e406d76e7f778" -dependencies = [ - "unicode-ident", -] - -[[package]] -name = "prometheus-client" -version = "0.22.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "504ee9ff529add891127c4827eb481bd69dc0ebc72e9a682e187db4caa60c3ca" -dependencies = [ - "dtoa", - "itoa", - "parking_lot", - "prometheus-client-derive-encode", -] - -[[package]] -name = "prometheus-client-derive-encode" -version = "0.4.2" +name = "prettyplease" +version = "0.2.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "440f724eba9f6996b75d63681b0a92b06947f1457076d503a4d2e2c8f56442b8" +checksum = "6837b9e10d61f45f987d50808f83d1ee3d206c66acf650c3e4ae2e1f6ddedf55" dependencies = [ "proc-macro2", - "quote", "syn", ] [[package]] -name = "public-ip-address" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "761cf3bcffbc326e841fcbaf0849759dc2e30876b89c454e0991f20ceca40f4c" -dependencies = [ - "directories", - "log", - "maybe-async", - "reqwest", - "serde", - "serde_json", - "thiserror 1.0.69", -] - -[[package]] -name = "quick-protobuf" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d6da84cc204722a989e01ba2f6e1e276e190f22263d0cb6ce8526fcdb0d2e1f" -dependencies = [ - "byteorder", -] - -[[package]] -name = "quick-protobuf-codec" -version = "0.3.1" +name = "proc-macro2" +version = "1.0.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15a0580ab32b169745d7a39db2ba969226ca16738931be152a3209b409de2474" +checksum = "02b3e5e68a3a1a02aad3ec490a98007cbc13c37cbe84a3cd7b8e406d76e7f778" dependencies = [ - "asynchronous-codec", - "bytes", - "quick-protobuf", - "thiserror 1.0.69", - "unsigned-varint 0.8.0", + "unicode-ident", ] [[package]] name = "quinn" -version = "0.11.7" +version = "0.11.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3bd15a6f2967aef83887dcb9fec0014580467e33720d073560cf015a5683012" +checksum = "b9e20a958963c291dc322d98411f541009df2ced7b5a4f2bd52337638cfccf20" dependencies = [ "bytes", "cfg_aliases", - "futures-io", "pin-project-lite", "quinn-proto", "quinn-udp", @@ -3128,17 +1615,20 @@ dependencies = [ [[package]] name = "quinn-proto" -version = "0.11.11" +version = "0.11.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bcbafbbdbb0f638fe3f35f3c56739f77a8a1d070cb25603226c83339b391472b" +checksum = "f1906b49b0c3bc04b5fe5d86a77925ae6524a19b816ae38ce1e426255f1d8a31" dependencies = [ "bytes", + "fastbloom", "getrandom 0.3.2", + "lru-slab", "rand 0.9.1", "ring", "rustc-hash", "rustls", "rustls-pki-types", + "rustls-platform-verifier", "slab", "thiserror 2.0.12", "tinyvec", @@ -3148,16 +1638,16 @@ dependencies = [ [[package]] name = "quinn-udp" -version = "0.5.12" +version = "0.5.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ee4e529991f949c5e25755532370b8af5d114acae52326361d68d47af64aa842" +checksum = "addec6a0dcad8a8d96a771f815f0eaf55f9d1805756410b39f5fa81332574cbd" dependencies = [ "cfg_aliases", "libc", "once_cell", "socket2", "tracing", - "windows-sys 0.59.0", + "windows-sys 0.60.2", ] [[package]] @@ -3234,26 +1724,6 @@ dependencies = [ "getrandom 0.3.2", ] -[[package]] -name = "rayon" -version = "1.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" -dependencies = [ - "either", - "rayon-core", -] - -[[package]] -name = "rayon-core" -version = "1.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" -dependencies = [ - "crossbeam-deque", - "crossbeam-utils", -] - [[package]] name = "rcgen" version = "0.13.2" @@ -3267,24 +1737,15 @@ dependencies = [ "yasna", ] -[[package]] -name = "redox_syscall" -version = "0.5.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2f103c6d277498fbceb16e84d317e2a400f160f46904d5f5410848c829511a3" -dependencies = [ - "bitflags 2.9.0", -] - [[package]] name = "redox_users" -version = "0.4.6" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43" +checksum = "a4e608c6638b9c18977b00b475ac1f28d14e84b27d8d42f70e0bf1e3dec127ac" dependencies = [ "getrandom 0.2.16", "libredox", - "thiserror 1.0.69", + "thiserror 2.0.12", ] [[package]] @@ -3343,7 +1804,7 @@ dependencies = [ "futures-core", "futures-util", "h2", - "http 1.3.1", + "http", "http-body", "http-body-util", "hyper", @@ -3354,15 +1815,11 @@ dependencies = [ "js-sys", "log", "mime", - "mime_guess", "native-tls", "once_cell", "percent-encoding", "pin-project-lite", - "quinn", - "rustls", "rustls-pemfile", - "rustls-pki-types", "serde", "serde_json", "serde_urlencoded", @@ -3370,7 +1827,6 @@ dependencies = [ "system-configuration", "tokio", "tokio-native-tls", - "tokio-rustls", "tokio-util", "tower", "tower-service", @@ -3379,47 +1835,9 @@ dependencies = [ "wasm-bindgen-futures", "wasm-streams", "web-sys", - "webpki-roots", "windows-registry", ] -[[package]] -name = "resolv-conf" -version = "0.7.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc7c8f7f733062b66dc1c63f9db168ac0b97a9210e247fa90fdc9ad08f51b302" - -[[package]] -name = "rfc6979" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8dd2a808d456c4a54e300a23e9f5a67e122c3024119acbfd73e3bf664491cb2" -dependencies = [ - "hmac 0.12.1", - "subtle", -] - -[[package]] -name = "rig-core" -version = "0.11.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bb610bd7e61825e79ca79b7efcad93206256147e27cbf707dffd80b7622b5ca7" -dependencies = [ - "async-stream", - "base64", - "bytes", - "futures", - "glob", - "mime_guess", - "ordered-float", - "reqwest", - "schemars", - "serde", - "serde_json", - "thiserror 1.0.69", - "tracing", -] - [[package]] name = "ring" version = "0.17.14" @@ -3435,21 +1853,22 @@ dependencies = [ ] [[package]] -name = "rtnetlink" -version = "0.13.1" +name = "rmp" +version = "0.8.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a552eb82d19f38c3beed3f786bd23aa434ceb9ac43ab44419ca6d67a7e186c0" +checksum = "4ba8be72d372b2c9b35542551678538b562e7cf86c3315773cae48dfbfe7790c" dependencies = [ - "futures", - "log", - "netlink-packet-core", - "netlink-packet-route", - "netlink-packet-utils", - "netlink-proto", - "netlink-sys", - "nix", - "thiserror 1.0.69", - "tokio", + "num-traits", +] + +[[package]] +name = "rmp-serde" +version = "1.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72f81bee8c8ef9b577d1681a70ebbc962c232461e397b22c208c43c04b67a155" +dependencies = [ + "rmp", + "serde", ] [[package]] @@ -3464,64 +1883,46 @@ version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" -[[package]] -name = "rustc_version" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" -dependencies = [ - "semver", -] - -[[package]] -name = "rusticata-macros" -version = "4.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "faf0c4a6ece9950b9abdb62b1cfcf2a68b3b67a10ba445b3bb85be2a293d0632" -dependencies = [ - "nom", -] - -[[package]] -name = "rustix" -version = "0.38.44" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154" -dependencies = [ - "bitflags 2.9.0", - "errno", - "libc", - "linux-raw-sys 0.4.15", - "windows-sys 0.59.0", -] - [[package]] name = "rustix" version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c71e83d6afe7ff64890ec6b71d6a69bb8a610ab78ce364b3352876bb4c801266" dependencies = [ - "bitflags 2.9.0", + "bitflags", "errno", "libc", - "linux-raw-sys 0.9.4", + "linux-raw-sys", "windows-sys 0.59.0", ] [[package]] name = "rustls" -version = "0.23.26" +version = "0.23.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df51b5869f3a441595eac5e8ff14d486ff285f7b8c0df8770e49c3b56351f0f0" +checksum = "7160e3e10bf4535308537f3c4e1641468cd0e485175d6163087c0393c7d46643" dependencies = [ + "log", "once_cell", "ring", "rustls-pki-types", - "rustls-webpki 0.103.1", + "rustls-webpki", "subtle", "zeroize", ] +[[package]] +name = "rustls-native-certs" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "612460d5f7bea540c490b2b6395d8e34a953e52b491accd6c86c8164c5932a63" +dependencies = [ + "openssl-probe 0.2.1", + "rustls-pki-types", + "schannel", + "security-framework 3.7.0", +] + [[package]] name = "rustls-pemfile" version = "2.2.0" @@ -3541,20 +1942,37 @@ dependencies = [ ] [[package]] -name = "rustls-webpki" -version = "0.101.7" +name = "rustls-platform-verifier" +version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765" +checksum = "1d99feebc72bae7ab76ba994bb5e121b8d83d910ca40b36e0921f53becc41784" dependencies = [ - "ring", - "untrusted", + "core-foundation 0.10.1", + "core-foundation-sys", + "jni", + "log", + "once_cell", + "rustls", + "rustls-native-certs", + "rustls-platform-verifier-android", + "rustls-webpki", + "security-framework 3.7.0", + "security-framework-sys", + "webpki-root-certs", + "windows-sys 0.61.2", ] +[[package]] +name = "rustls-platform-verifier-android" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f87165f0995f63a9fbeea62b64d10b4d9d8e78ec6d7d51fb2125fda7bb36788f" + [[package]] name = "rustls-webpki" -version = "0.103.1" +version = "0.103.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fef8b8769aaccf73098557a87cd1816b4f9c7c16811c9c77142aa695c16f2c03" +checksum = "e4a72fe2bcf7a6ac6fd7d0b9e5cb68aeb7d4c0a0271730218b3e92d43b4eb435" dependencies = [ "ring", "rustls-pki-types", @@ -3567,17 +1985,6 @@ version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eded382c5f5f786b989652c49544c4877d9f015cc22e145a5ea8ea66c2921cd2" -[[package]] -name = "rw-stream-sink" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8c9026ff5d2f23da5e45bbc283f156383001bfb09c4e44256d02c1a685fe9a1" -dependencies = [ - "futures", - "pin-project", - "static_assertions", -] - [[package]] name = "ryu" version = "1.0.20" @@ -3585,73 +1992,44 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" [[package]] -name = "schannel" -version = "0.1.27" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1f29ebaa345f945cec9fbbc532eb307f0fdad8161f281b6369539c8d84876b3d" -dependencies = [ - "windows-sys 0.59.0", -] - -[[package]] -name = "schemars" -version = "0.8.22" +name = "same-file" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fbf2ae1b8bc8e02df939598064d22402220cd5bbcca1c76f7d6a310974d5615" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" dependencies = [ - "dyn-clone", - "indexmap 1.9.3", - "schemars_derive", - "serde", - "serde_json", + "winapi-util", ] [[package]] -name = "schemars_derive" -version = "0.8.22" +name = "schannel" +version = "0.1.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32e265784ad618884abaea0600a9adf15393368d840e0222d101a072f3f7534d" +checksum = "1f29ebaa345f945cec9fbbc532eb307f0fdad8161f281b6369539c8d84876b3d" dependencies = [ - "proc-macro2", - "quote", - "serde_derive_internals", - "syn", + "windows-sys 0.59.0", ] [[package]] -name = "scoped-tls" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294" - -[[package]] -name = "scopeguard" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" - -[[package]] -name = "sec1" -version = "0.7.3" +name = "security-framework" +version = "2.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d3e97a565f76233a6003f9f5c54be1d9c5bdfa3eccfb189469f11ec4901c47dc" +checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" dependencies = [ - "base16ct", - "der", - "generic-array", - "pkcs8", - "subtle", - "zeroize", + "bitflags", + "core-foundation 0.9.4", + "core-foundation-sys", + "libc", + "security-framework-sys", ] [[package]] name = "security-framework" -version = "2.11.1" +version = "3.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" +checksum = "b7f4bc775c73d9a02cde8bf7b2ec4c9d12743edf609006c7facc23998404cd1d" dependencies = [ - "bitflags 2.9.0", - "core-foundation", + "bitflags", + "core-foundation 0.10.1", "core-foundation-sys", "libc", "security-framework-sys", @@ -3659,45 +2037,38 @@ dependencies = [ [[package]] name = "security-framework-sys" -version = "2.14.0" +version = "2.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49db231d56a190491cb4aeda9527f1ad45345af50b0851622a7adb8c03b01c32" +checksum = "6ce2691df843ecc5d231c0b14ece2acc3efb62c0a398c7e1d875f3983ce020e3" dependencies = [ "core-foundation-sys", "libc", ] -[[package]] -name = "semver" -version = "1.0.26" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56e6fa9c48d24d85fb3de5ad847117517440f6beceb7798af16b4a87d616b8d0" - [[package]] name = "serde" -version = "1.0.219" +version = "1.0.228" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" dependencies = [ + "serde_core", "serde_derive", ] [[package]] -name = "serde_derive" -version = "1.0.219" +name = "serde_core" +version = "1.0.228" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" dependencies = [ - "proc-macro2", - "quote", - "syn", + "serde_derive", ] [[package]] -name = "serde_derive_internals" -version = "0.29.1" +name = "serde_derive" +version = "1.0.228" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18d26a20a969b9e3fdf2fc2d9f21eda6c40e2de84c9408bb5d3b05d499aae711" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" dependencies = [ "proc-macro2", "quote", @@ -3787,14 +2158,16 @@ dependencies = [ ] [[package]] -name = "signature" -version = "2.2.0" +name = "simd-adler32" +version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77549399552de45a898a580c1b41d445bf730df867cc44e6c0233bbc4b8329de" -dependencies = [ - "digest 0.10.7", - "rand_core 0.6.4", -] +checksum = "e320a6c5ad31d271ad523dcf3ad13e2767ad8b1cb8f047f75a8aeaf8da139da2" + +[[package]] +name = "siphasher" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2aa850e253778c88a04c3d7323b043aeda9d3e30d5971937c1855769763678e" [[package]] name = "slab" @@ -3811,23 +2184,6 @@ version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8917285742e9f3e1683f0a9c4e6b57960b7314d0b08d30d1ecd426713ee2eee9" -[[package]] -name = "snow" -version = "0.9.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "850948bee068e713b8ab860fe1adc4d109676ab4c3b621fd8147f06b261f2f85" -dependencies = [ - "aes-gcm", - "blake2", - "chacha20poly1305", - "curve25519-dalek", - "rand_core 0.6.4", - "ring", - "rustc_version", - "sha2 0.10.9", - "subtle", -] - [[package]] name = "socket2" version = "0.5.9" @@ -3839,13 +2195,14 @@ dependencies = [ ] [[package]] -name = "spki" -version = "0.7.3" +name = "socks" +version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d91ed6c858b01f942cd56b37a94b3e0a1798290327d1236e4d9cf4eaca44d29d" +checksum = "f0c3dbbd9ae980613c6dd8e28a9407b50509d3803b57624d5dfe8315218cd58b" dependencies = [ - "base64ct", - "der", + "byteorder", + "libc", + "winapi", ] [[package]] @@ -3855,10 +2212,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" [[package]] -name = "static_assertions" -version = "1.1.0" +name = "strsim" +version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" [[package]] name = "subtle" @@ -3897,28 +2254,14 @@ dependencies = [ "syn", ] -[[package]] -name = "sysinfo" -version = "0.33.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4fc858248ea01b66f19d8e8a6d55f41deaf91e9d495246fd01368d99935c6c01" -dependencies = [ - "core-foundation-sys", - "libc", - "memchr", - "ntapi", - "rayon", - "windows 0.57.0", -] - [[package]] name = "system-configuration" version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b" dependencies = [ - "bitflags 2.9.0", - "core-foundation", + "bitflags", + "core-foundation 0.9.4", "system-configuration-sys", ] @@ -3932,12 +2275,6 @@ dependencies = [ "libc", ] -[[package]] -name = "tagptr" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417" - [[package]] name = "tempfile" version = "3.19.1" @@ -3947,7 +2284,7 @@ dependencies = [ "fastrand", "getrandom 0.3.2", "once_cell", - "rustix 1.0.7", + "rustix", "windows-sys 0.59.0", ] @@ -4003,34 +2340,22 @@ dependencies = [ [[package]] name = "time" -version = "0.3.41" +version = "0.3.47" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a7619e19bc266e0f9c5e6686659d394bc57973859340060a69221e57dbc0c40" +checksum = "743bd48c283afc0388f9b8827b976905fb217ad9e647fae3a379a9283c4def2c" dependencies = [ "deranged", - "itoa", "num-conv", "powerfmt", - "serde", + "serde_core", "time-core", - "time-macros", ] [[package]] name = "time-core" -version = "0.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c9e9a38711f559d9e3ce1cdb06dd7c5b8ea546bc90052da6d06bb76da74bb07c" - -[[package]] -name = "time-macros" -version = "0.2.22" +version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3526739392ec93fd8b359c8e98514cb3e8e021beb4e5f597b00a0221f8ed8a49" -dependencies = [ - "num-conv", - "time-core", -] +checksum = "7694e1cfe791f8d31026952abf09c69ca6f6fa4e1a1229e18988f06a04a12dca" [[package]] name = "tinystr" @@ -4044,9 +2369,9 @@ dependencies = [ [[package]] name = "tinyvec" -version = "1.9.0" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09b3661f17e86524eccd4371ab0429194e0d7c008abb45f7a7495b1719463c71" +checksum = "bfa5fdc3bce6191a1dbc8c02d5c8bffcf557bafa17c124c5264a458f1b0613fa" dependencies = [ "tinyvec_macros", ] @@ -4067,7 +2392,6 @@ dependencies = [ "bytes", "libc", "mio", - "parking_lot", "pin-project-lite", "signal-hook-registry", "socket2", @@ -4106,17 +2430,6 @@ dependencies = [ "tokio", ] -[[package]] -name = "tokio-stream" -version = "0.1.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eca58d7bba4a75707817a2c44174253f9236b2d5fbd055602e9d5c07c139a047" -dependencies = [ - "futures-core", - "pin-project-lite", - "tokio", -] - [[package]] name = "tokio-util" version = "0.7.15" @@ -4127,7 +2440,7 @@ dependencies = [ "futures-core", "futures-sink", "futures-util", - "hashbrown 0.15.3", + "hashbrown", "pin-project-lite", "tokio", ] @@ -4165,6 +2478,7 @@ version = "0.1.41" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" dependencies = [ + "log", "pin-project-lite", "tracing-attributes", "tracing-core", @@ -4232,12 +2546,6 @@ version = "1.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f" -[[package]] -name = "unicase" -version = "2.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539" - [[package]] name = "unicode-ident" version = "1.0.18" @@ -4245,26 +2553,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" [[package]] -name = "universal-hash" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc1de2c688dc15305988b563c3854064043356019f97a4b46276fe734c4f07ea" -dependencies = [ - "crypto-common", - "subtle", -] - -[[package]] -name = "unsigned-varint" -version = "0.7.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6889a77d49f1f013504cec6bf97a2c730394adedaeb1deb5ea08949a50541105" - -[[package]] -name = "unsigned-varint" -version = "0.8.0" +name = "unicode-width" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb066959b24b5196ae73cb057f45598450d2c5f71460e98c49b738086eff9c06" +checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254" [[package]] name = "untrusted" @@ -4272,6 +2564,26 @@ version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" +[[package]] +name = "ureq" +version = "2.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02d1a66277ed75f640d608235660df48c8e3c19f3b4edb6a263315626cc3c01d" +dependencies = [ + "base64", + "flate2", + "log", + "native-tls", + "once_cell", + "rustls", + "rustls-pki-types", + "serde", + "serde_json", + "socks", + "url", + "webpki-roots", +] + [[package]] name = "url" version = "2.5.4" @@ -4283,12 +2595,6 @@ dependencies = [ "percent-encoding", ] -[[package]] -name = "urlencoding" -version = "2.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" - [[package]] name = "utf16_iter" version = "1.0.5" @@ -4335,6 +2641,16 @@ version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" +[[package]] +name = "walkdir" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" +dependencies = [ + "same-file", + "winapi-util", +] + [[package]] name = "want" version = "0.3.1" @@ -4464,19 +2780,22 @@ dependencies = [ ] [[package]] -name = "webpki-roots" -version = "0.26.10" +name = "webpki-root-certs" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37493cadf42a2a939ed404698ded7fb378bf301b5011f973361779a3a74f8c93" +checksum = "804f18a4ac2676ffb4e8b5b5fa9ae38af06df08162314f96a68d2a363e21a8ca" dependencies = [ "rustls-pki-types", ] [[package]] -name = "widestring" -version = "1.2.0" +name = "webpki-roots" +version = "0.26.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd7cf3379ca1aac9eea11fba24fd7e315d621f8dfe35c8d7d2be8b793726e07d" +checksum = "37493cadf42a2a939ed404698ded7fb378bf301b5011f973361779a3a74f8c93" +dependencies = [ + "rustls-pki-types", +] [[package]] name = "winapi" @@ -4495,160 +2814,31 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" [[package]] -name = "winapi-x86_64-pc-windows-gnu" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" - -[[package]] -name = "windows" -version = "0.53.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "efc5cf48f83140dcaab716eeaea345f9e93d0018fb81162753a3f76c3397b538" -dependencies = [ - "windows-core 0.53.0", - "windows-targets 0.52.6", -] - -[[package]] -name = "windows" -version = "0.57.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "12342cb4d8e3b046f3d80effd474a7a02447231330ef77d71daa6fbc40681143" -dependencies = [ - "windows-core 0.57.0", - "windows-targets 0.52.6", -] - -[[package]] -name = "windows" -version = "0.58.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd04d41d93c4992d421894c18c8b43496aa748dd4c081bac0dc93eb0489272b6" -dependencies = [ - "windows-core 0.58.0", - "windows-targets 0.52.6", -] - -[[package]] -name = "windows-core" -version = "0.53.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9dcc5b895a6377f1ab9fa55acedab1fd5ac0db66ad1e6c7f47e28a22e446a5dd" -dependencies = [ - "windows-result 0.1.2", - "windows-targets 0.52.6", -] - -[[package]] -name = "windows-core" -version = "0.57.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2ed2439a290666cd67ecce2b0ffaad89c2a56b976b736e6ece670297897832d" -dependencies = [ - "windows-implement 0.57.0", - "windows-interface 0.57.0", - "windows-result 0.1.2", - "windows-targets 0.52.6", -] - -[[package]] -name = "windows-core" -version = "0.58.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ba6d44ec8c2591c134257ce647b7ea6b20335bf6379a27dac5f1641fcf59f99" -dependencies = [ - "windows-implement 0.58.0", - "windows-interface 0.58.0", - "windows-result 0.2.0", - "windows-strings 0.1.0", - "windows-targets 0.52.6", -] - -[[package]] -name = "windows-core" -version = "0.61.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4763c1de310c86d75a878046489e2e5ba02c649d185f21c67d4cf8a56d098980" -dependencies = [ - "windows-implement 0.60.0", - "windows-interface 0.59.1", - "windows-link", - "windows-result 0.3.2", - "windows-strings 0.4.0", -] - -[[package]] -name = "windows-implement" -version = "0.57.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9107ddc059d5b6fbfbffdfa7a7fe3e22a226def0b2608f72e9d552763d3e1ad7" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "windows-implement" -version = "0.58.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2bbd5b46c938e506ecbce286b6628a02171d56153ba733b6c741fc627ec9579b" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "windows-implement" -version = "0.60.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a47fddd13af08290e67f4acabf4b459f647552718f683a7b415d290ac744a836" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "windows-interface" -version = "0.57.0" +name = "winapi-util" +version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29bee4b38ea3cde66011baa44dba677c432a78593e202392d1e9070cf2a7fca7" +checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" dependencies = [ - "proc-macro2", - "quote", - "syn", + "windows-sys 0.61.2", ] [[package]] -name = "windows-interface" -version = "0.58.0" +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "053c4c462dc91d3b1504c6fe5a726dd15e216ba718e84a0e46a88fbe5ded3515" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" [[package]] -name = "windows-interface" -version = "0.59.1" +name = "windows-link" +version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd9211b69f8dcdfa817bfd14bf1c97c9188afa36f4750130fcdf3f400eca9fa8" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] +checksum = "76840935b766e1b0a05c0066835fb9ec80071d4c09a16f6bd5f7e655e3c14c38" [[package]] name = "windows-link" -version = "0.1.1" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76840935b766e1b0a05c0066835fb9ec80071d4c09a16f6bd5f7e655e3c14c38" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" [[package]] name = "windows-registry" @@ -4656,27 +2846,9 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4286ad90ddb45071efd1a66dfa43eb02dd0dfbae1545ad6cc3c51cf34d7e8ba3" dependencies = [ - "windows-result 0.3.2", - "windows-strings 0.3.1", - "windows-targets 0.53.0", -] - -[[package]] -name = "windows-result" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e383302e8ec8515204254685643de10811af0ed97ea37210dc26fb0032647f8" -dependencies = [ - "windows-targets 0.52.6", -] - -[[package]] -name = "windows-result" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d1043d8214f791817bab27572aaa8af63732e11bf84aa21a45a78d6c317ae0e" -dependencies = [ - "windows-targets 0.52.6", + "windows-result", + "windows-strings", + "windows-targets 0.53.5", ] [[package]] @@ -4685,77 +2857,76 @@ version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c64fd11a4fd95df68efcfee5f44a294fe71b8bc6a91993e2791938abcc712252" dependencies = [ - "windows-link", + "windows-link 0.1.1", ] [[package]] name = "windows-strings" -version = "0.1.0" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4cd9b125c486025df0eabcb585e62173c6c9eddcec5d117d3b6e8c30e2ee4d10" +checksum = "87fa48cc5d406560701792be122a10132491cff9d0aeb23583cc2dcafc847319" dependencies = [ - "windows-result 0.2.0", - "windows-targets 0.52.6", + "windows-link 0.1.1", ] [[package]] -name = "windows-strings" -version = "0.3.1" +name = "windows-sys" +version = "0.45.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87fa48cc5d406560701792be122a10132491cff9d0aeb23583cc2dcafc847319" +checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0" dependencies = [ - "windows-link", + "windows-targets 0.42.2", ] [[package]] -name = "windows-strings" -version = "0.4.0" +name = "windows-sys" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a2ba9642430ee452d5a7aa78d72907ebe8cfda358e8cb7918a2050581322f97" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" dependencies = [ - "windows-link", + "windows-targets 0.52.6", ] [[package]] name = "windows-sys" -version = "0.48.0" +version = "0.59.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" dependencies = [ - "windows-targets 0.48.5", + "windows-targets 0.52.6", ] [[package]] name = "windows-sys" -version = "0.52.0" +version = "0.60.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb" dependencies = [ - "windows-targets 0.52.6", + "windows-targets 0.53.5", ] [[package]] name = "windows-sys" -version = "0.59.0" +version = "0.61.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" dependencies = [ - "windows-targets 0.52.6", + "windows-link 0.2.1", ] [[package]] name = "windows-targets" -version = "0.48.5" +version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" +checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071" dependencies = [ - "windows_aarch64_gnullvm 0.48.5", - "windows_aarch64_msvc 0.48.5", - "windows_i686_gnu 0.48.5", - "windows_i686_msvc 0.48.5", - "windows_x86_64_gnu 0.48.5", - "windows_x86_64_gnullvm 0.48.5", - "windows_x86_64_msvc 0.48.5", + "windows_aarch64_gnullvm 0.42.2", + "windows_aarch64_msvc 0.42.2", + "windows_i686_gnu 0.42.2", + "windows_i686_msvc 0.42.2", + "windows_x86_64_gnu 0.42.2", + "windows_x86_64_gnullvm 0.42.2", + "windows_x86_64_msvc 0.42.2", ] [[package]] @@ -4776,10 +2947,11 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.53.0" +version = "0.53.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1e4c7e8ceaaf9cb7d7507c974735728ab453b67ef8f18febdd7c11fe59dca8b" +checksum = "4945f9f551b88e0d65f3db0bc25c33b8acea4d9e41163edf90dcd0b19f9069f3" dependencies = [ + "windows-link 0.2.1", "windows_aarch64_gnullvm 0.53.0", "windows_aarch64_msvc 0.53.0", "windows_i686_gnu 0.53.0", @@ -4792,9 +2964,9 @@ dependencies = [ [[package]] name = "windows_aarch64_gnullvm" -version = "0.48.5" +version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" +checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" [[package]] name = "windows_aarch64_gnullvm" @@ -4810,9 +2982,9 @@ checksum = "86b8d5f90ddd19cb4a147a5fa63ca848db3df085e25fee3cc10b39b6eebae764" [[package]] name = "windows_aarch64_msvc" -version = "0.48.5" +version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" +checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" [[package]] name = "windows_aarch64_msvc" @@ -4828,9 +3000,9 @@ checksum = "c7651a1f62a11b8cbd5e0d42526e55f2c99886c77e007179efff86c2b137e66c" [[package]] name = "windows_i686_gnu" -version = "0.48.5" +version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" +checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" [[package]] name = "windows_i686_gnu" @@ -4858,9 +3030,9 @@ checksum = "9ce6ccbdedbf6d6354471319e781c0dfef054c81fbc7cf83f338a4296c0cae11" [[package]] name = "windows_i686_msvc" -version = "0.48.5" +version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" +checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" [[package]] name = "windows_i686_msvc" @@ -4876,9 +3048,9 @@ checksum = "581fee95406bb13382d2f65cd4a908ca7b1e4c2f1917f143ba16efe98a589b5d" [[package]] name = "windows_x86_64_gnu" -version = "0.48.5" +version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" +checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" [[package]] name = "windows_x86_64_gnu" @@ -4894,9 +3066,9 @@ checksum = "2e55b5ac9ea33f2fc1716d1742db15574fd6fc8dadc51caab1c16a3d3b4190ba" [[package]] name = "windows_x86_64_gnullvm" -version = "0.48.5" +version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" +checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" [[package]] name = "windows_x86_64_gnullvm" @@ -4912,9 +3084,9 @@ checksum = "0a6e035dd0599267ce1ee132e51c27dd29437f63325753051e71dd9e42406c57" [[package]] name = "windows_x86_64_msvc" -version = "0.48.5" +version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" +checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" [[package]] name = "windows_x86_64_msvc" @@ -4928,23 +3100,13 @@ version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486" -[[package]] -name = "winreg" -version = "0.50.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "524e57b2c537c0f9b1e69f1965311ec12182b4122e45035b1508cd24d2adadb1" -dependencies = [ - "cfg-if", - "windows-sys 0.48.0", -] - [[package]] name = "wit-bindgen-rt" version = "0.39.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6f42320e61fe2cfd34354ecb597f86f413484a798ba44a8ca1165c58d42da6c1" dependencies = [ - "bitflags 2.9.0", + "bitflags", ] [[package]] @@ -4959,81 +3121,6 @@ version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e9df38ee2d2c3c5948ea468a8406ff0db0b29ae1ffde1bcf20ef305bcc95c51" -[[package]] -name = "x25519-dalek" -version = "2.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7e468321c81fb07fa7f4c636c3972b9100f0346e5b6a9f2bd0603a52f7ed277" -dependencies = [ - "curve25519-dalek", - "rand_core 0.6.4", - "serde", - "zeroize", -] - -[[package]] -name = "x509-parser" -version = "0.16.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fcbc162f30700d6f3f82a24bf7cc62ffe7caea42c0b2cba8bf7f3ae50cf51f69" -dependencies = [ - "asn1-rs", - "data-encoding", - "der-parser", - "lazy_static", - "nom", - "oid-registry", - "rusticata-macros", - "thiserror 1.0.69", - "time", -] - -[[package]] -name = "xml-rs" -version = "0.8.26" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a62ce76d9b56901b19a74f19431b0d8b3bc7ca4ad685a746dfd78ca8f4fc6bda" - -[[package]] -name = "xmltree" -version = "0.10.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7d8a75eaf6557bb84a65ace8609883db44a29951042ada9b393151532e41fcb" -dependencies = [ - "xml-rs", -] - -[[package]] -name = "yamux" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ed0164ae619f2dc144909a9f082187ebb5893693d8c0196e8085283ccd4b776" -dependencies = [ - "futures", - "log", - "nohash-hasher", - "parking_lot", - "pin-project", - "rand 0.8.5", - "static_assertions", -] - -[[package]] -name = "yamux" -version = "0.13.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17610762a1207ee816c6fadc29220904753648aba0a9ed61c7b8336e80a559c4" -dependencies = [ - "futures", - "log", - "nohash-hasher", - "parking_lot", - "pin-project", - "rand 0.8.5", - "static_assertions", - "web-time", -] - [[package]] name = "yasna" version = "0.5.2" @@ -5113,20 +3200,6 @@ name = "zeroize" version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" -dependencies = [ - "zeroize_derive", -] - -[[package]] -name = "zeroize_derive" -version = "1.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] [[package]] name = "zerovec" diff --git a/Cargo.toml b/Cargo.toml index f975ff0f..c2d69768 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,41 +1,42 @@ -[workspace] -resolver = "2" -members = ["compute", "p2p", "utils", "executor"] - -# FIXME: removing this breaks .github workflows -default-members = ["compute"] - -[workspace.package] +[package] +name = "dria-node" +version = "2.0.0-alpha.1" edition = "2021" -version = "0.6.7" license = "Apache-2.0" -readme = "README.md" - -# profiling build for flamegraphs -[profile.profiling] -inherits = "release" -debug = true - -[workspace.dependencies] -# async stuff -tokio-util = { version = "0.7.10", features = ["rt"] } -tokio = { version = "1", features = ["macros", "rt-multi-thread", "signal"] } - -# serialize & deserialize -serde = { version = "1.0", features = ["derive"] } -serde_json = "1.0" - -# http client -reqwest = "0.12.5" - -# utilities -dotenvy = "0.15.7" -rand = "0.8.5" -uuid = { version = "1.8.0", features = ["v7", "serde"] } -chrono = { version = "0.4.40", features = ["serde"] } -# logging & errors -env_logger = "0.11.3" -log = "0.4.21" -eyre = "0.6.12" -thiserror = "2.0.12" +[[bin]] +name = "dria-node" +path = "src/main.rs" + +[features] +default = [] +cuda = ["llama-cpp-2/cuda"] +metal = ["llama-cpp-2/metal"] + +[dependencies] +clap = { version = "4", features = ["derive", "env"] } +tokio = { version = "1", features = ["macros", "rt-multi-thread", "signal", "time"] } +tokio-util = { version = "0.7", features = ["rt"] } +serde = { version = "1", features = ["derive"] } +serde_json = "1" +rmp-serde = "1" +libsecp256k1 = "0.7" +sha2 = "0.10" +sha3 = "0.10" +hex = "0.4" +rand = "0.8" +llama-cpp-2 = "0.1.137" +hf-hub = { version = "0.4", features = ["tokio"] } +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } +thiserror = "2" +anyhow = "1" +uuid = { version = "1", features = ["v7", "serde"] } +dirs = "6" +encoding_rs = "0.8" +quinn = "0.11" +rustls = { version = "0.23", default-features = false, features = ["ring", "std"] } +rustls-native-certs = "0.8" +rcgen = "0.13" +futures = "0.3" +bytes = "1" diff --git a/Cross.toml b/Cross.toml deleted file mode 100644 index fb58fe24..00000000 --- a/Cross.toml +++ /dev/null @@ -1,8 +0,0 @@ -# this setting here helps with OpenSSL problems -# but we may remove it now that we use vendored OpenSSL -# see: https://github.com/cross-rs/cross/wiki/Recipes#pre-build -[target.aarch64-unknown-linux-gnu] -pre-build = [ - "dpkg --add-architecture $CROSS_DEB_ARCH", - "apt-get update && apt-get install --assume-yes libssl-dev:$CROSS_DEB_ARCH", -] diff --git a/Makefile b/Makefile deleted file mode 100644 index 2603e139..00000000 --- a/Makefile +++ /dev/null @@ -1,25 +0,0 @@ -# load .env -ifneq (,$(wildcard ./.env)) - include ./.env - export -endif - -############################################################################### -.PHONY: debug # | Run with DEBUG logs with -debug: - RUST_LOG=warn,dkn_compute=debug,dkn_executor=debug,dkn_p2p=debug \ - cargo run --bin dkn-compute - -.PHONY: build # | Build -build: - cargo build --workspace - -.PHONY: trace # | Run with TRACE logs -trace: - RUST_LOG=warn,dkn_compute=trace,libp2p=debug \ - cargo run --bin dkn-compute - -# https://stackoverflow.com/a/45843594 -.PHONY: help # | List targets -help: - @grep '^.PHONY: .* #' Makefile | sed 's/\.PHONY: \(.*\) # \(.*\)/\1 \2/' | expand -t20 diff --git a/compose.yml b/compose.yml deleted file mode 100644 index c06fa4a2..00000000 --- a/compose.yml +++ /dev/null @@ -1,64 +0,0 @@ -services: - # Compute Node - compute: - image: "firstbatch/dkn-compute-node:latest" - # build: "./" # use this one instead if you want to build locally - environment: - RUST_LOG: ${RUST_LOG:-none,dkn_compute=info} - # Dria - DKN_WALLET_SECRET_KEY: ${DKN_WALLET_SECRET_KEY} - DKN_MODELS: ${DKN_MODELS} - DKN_P2P_LISTEN_ADDR: ${DKN_P2P_LISTEN_ADDR} - # API Keys - OPENAI_API_KEY: ${OPENAI_API_KEY} - GEMINI_API_KEY: ${GEMINI_API_KEY} - OPENROUTER_API_KEY: ${OPENROUTER_API_KEY} - # Ollama - OLLAMA_HOST: ${OLLAMA_HOST} - OLLAMA_PORT: ${OLLAMA_PORT} - OLLAMA_AUTO_PULL: ${OLLAMA_AUTO_PULL:-true} - network_mode: ${DKN_DOCKER_NETWORK_MODE:-bridge} - extra_hosts: - # for Linux, we need to add this line manually - - "host.docker.internal:host-gateway" - restart: "on-failure" - - # Ollama Container (CPU) - ollama: - image: ollama/ollama:latest - ports: - - 11434:11434 - volumes: - - ~/.ollama:/root/.ollama - profiles: [ollama-cpu] - - # Ollama Container (ROCM) - ollama-rocm: - image: ollama/ollama:rocm - ports: - - 11434:11434 - volumes: - - ~/.ollama:/root/.ollama - devices: - - "/dev/kfd" - - "/dev/dri" - profiles: [ollama-rocm] - - # Ollama Container (CUDA) - ollama-cuda: - image: ollama/ollama - ports: - - 11434:11434 - volumes: - - ~/.ollama:/root/.ollama - deploy: - resources: - reservations: - devices: - - driver: nvidia - count: 1 - capabilities: [gpu] - profiles: [ollama-cuda] - -volumes: - ollama: diff --git a/compute/Cargo.toml b/compute/Cargo.toml deleted file mode 100644 index fbd4f289..00000000 --- a/compute/Cargo.toml +++ /dev/null @@ -1,64 +0,0 @@ -[package] -name = "dkn-compute" -version.workspace = true -edition.workspace = true -license.workspace = true -readme = "README.md" -authors = ["Erhan Tezcan "] - -[dependencies] -# async stuff -tokio-util.workspace = true -tokio.workspace = true - -# serialize & deserialize -serde.workspace = true -serde_json.workspace = true - -# http & networking -reqwest.workspace = true -port_check = "0.2.1" -url = "2.5.0" -urlencoding = "2.1.3" - -# utilities -dotenvy.workspace = true -base64 = "0.22.0" -hex = "0.4.3" -hex-literal = "0.4.1" -uuid.workspace = true -rand.workspace = true - -# logging & errors -env_logger.workspace = true -log.workspace = true -eyre.workspace = true -colored = "3.0.0" - -# encryption (ecies) & signatures (ecdsa) & hashing & bloom-filters -ecies = { version = "0.2", default-features = false, features = ["pure"] } -libsecp256k1 = "0.7.1" - -# machine diagnostics -# system info -sysinfo = "0.33.1" -# gpu info TODO: this gives a build error on Windows -# wgpu = { version = "23.0.1", features = [ -# "serde", -# "dx12", -# "metal", -# ], default-features = false } -# public ip -public-ip-address = "0.3.2" - -# dria subcrates -dkn-p2p = { path = "../p2p" } -dkn-utils = { path = "../utils", features = ["crypto"] } -dkn-executor = { path = "../executor" } -chrono.workspace = true - - -# vendor OpenSSL so that its easier to build cross-platform packages -[dependencies.openssl] -version = "*" -features = ["vendored"] diff --git a/compute/src/config.rs b/compute/src/config.rs deleted file mode 100644 index e6678698..00000000 --- a/compute/src/config.rs +++ /dev/null @@ -1,183 +0,0 @@ -use dkn_executor::DriaExecutorsManager; -use dkn_p2p::libp2p::{Multiaddr, PeerId}; -use eyre::{eyre, Result}; -use libsecp256k1::{PublicKey, SecretKey}; -use std::{env, str::FromStr}; - -use dkn_utils::{ - crypto::{public_key_to_address, secret_to_keypair}, - DriaNetwork, SemanticVersion, -}; - -const DEFAULT_TASK_BATCH_SIZE: usize = 5; -const DEFAULT_P2P_LISTEN_ADDR: &str = "/ip4/0.0.0.0/tcp/4001"; - -#[derive(Clone)] -pub struct DriaComputeNodeConfig { - /// Wallet secret/private key. - pub secret_key: SecretKey, - /// Wallet public key, derived from the secret key. - pub public_key: PublicKey, - /// Wallet address in hex without `0x` prefix, derived from the public key. - pub address: String, - /// Peer ID of the node. - pub peer_id: PeerId, - /// Compute node version. - pub version: SemanticVersion, - /// P2P listen address, e.g. `/ip4/0.0.0.0/tcp/4001`. - pub p2p_listen_addr: Multiaddr, - /// Executor manager, handles models and providers. - pub executors: DriaExecutorsManager, - /// Network type of the node. - pub network: DriaNetwork, - /// Batch size for batchable tasks (e.g. API-based ones). - /// - /// A higher value will help execute more tasks concurrently, - /// at the risk of hitting rate-limits. - pub batch_size: usize, - /// An optional first-attempt RPC address, will be dialled at startup. - /// - /// TODO: this is `None` after startup due to `Option::take`, can we do any better? - pub initial_rpc_addr: Option, - /// Execution platform, mainly for diagnostics. - /// - /// Given by `DKN_EXEC_PLATFORM`. - pub exec_platform: String, -} - -#[allow(clippy::new_without_default)] -impl DriaComputeNodeConfig { - /// Creates new config from environment variables. - pub fn new(executors: DriaExecutorsManager) -> Self { - let secret_key = match env::var("DKN_WALLET_SECRET_KEY") { - Ok(secret_env) => { - let secret_dec = hex::decode(secret_env.trim_start_matches("0x")) - .expect("Secret key should be 32-bytes hex encoded."); - - // if secret key is all-zeros, create one randomly - // this is useful for testing & creating nodes on the fly - if secret_dec.iter().all(|b| b == &0) { - SecretKey::random(&mut rand::thread_rng()) - } else { - SecretKey::parse_slice(&secret_dec).expect("Secret key should be parseable.") - } - } - Err(err) => { - log::error!("No secret key provided: {err}"); - panic!("Please provide a secret key."); - } - }; - log::info!( - "Node Secret Key: 0x{}{}", - hex::encode(&secret_key.serialize()[0..1]), - ".".repeat(64) - ); - - let public_key = PublicKey::from_secret_key(&secret_key); - log::info!( - "Node Public Key: 0x{}", - hex::encode(public_key.serialize_compressed()) - ); - - // print address - let address = hex::encode(public_key_to_address(&public_key)); - log::info!("Node Address: 0x{address}"); - - // to this here to log the peer id at start - let peer_id = secret_to_keypair(&secret_key).public().to_peer_id(); - log::info!("Node PeerID: {peer_id}"); - - // parse listen address - let p2p_listen_addr_str = env::var("DKN_P2P_LISTEN_ADDR") - .map(|addr| addr.trim_matches('"').to_string()) - .unwrap_or(DEFAULT_P2P_LISTEN_ADDR.to_string()); - let p2p_listen_addr = Multiaddr::from_str(&p2p_listen_addr_str) - .expect("could not parse the given P2P listen address."); - - // parse network type - let network_type = env::var("DKN_NETWORK") - // if there is an explicit value, default to testnet on error - .map(|s| DriaNetwork::try_from(s.as_str()).unwrap_or(DriaNetwork::Testnet)) - // if there is no explicit value, default to mainnet - .unwrap_or(DriaNetwork::Mainnet); - if network_type == DriaNetwork::Testnet { - log::warn!("Using testnet network!"); - } - - // parse batch size - let batch_size = env::var("DKN_BATCH_SIZE") - .map(|s| s.parse::().unwrap_or(DEFAULT_TASK_BATCH_SIZE)) - .unwrap_or(DEFAULT_TASK_BATCH_SIZE); - - // parse version - let version = env!("CARGO_PKG_VERSION") - .parse() - .expect("could not parse version"); - - // parse initial rpc address, if any - let initial_rpc_addr = env::var("DKN_INITIAL_RPC_ADDR") - .ok() - .and_then(|addr| if addr.is_empty() { None } else { Some(addr) }) - .map(|addr| { - Multiaddr::from_str(&addr).expect("could not parse the given initial RPC address.") - }); - - // parse execution platform - let exec_platform = env::var("DKN_EXEC_PLATFORM").unwrap_or_else(|_| "unknown".to_string()); - - Self { - secret_key, - public_key, - address, - peer_id, - version, - executors, - p2p_listen_addr, - network: network_type, - batch_size, - initial_rpc_addr, - exec_platform, - } - } - - /// Asserts that the configured listen address is free. - /// Throws an error if the address is already in use. - /// - /// Uses `is_port_reachable` function internally, which makes a simple - /// TCP connection to the given address. - /// - /// Can be inlined because the function is small and called only once. - #[inline] - pub fn assert_address_not_in_use(&self) -> Result<()> { - use dkn_p2p::libp2p::multiaddr::Protocol; - use port_check::is_port_reachable; - use std::net::{Ipv4Addr, SocketAddrV4}; - - let address_in_use = self - .p2p_listen_addr - .iter() - // find the port within our multiaddr - .find_map(|protocol| match protocol { - Protocol::Tcp(port) => Some(port), - _ => None, - }) - // check if its reachable or not - .map(|port| is_port_reachable(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port))) - .unwrap_or_else(|| { - log::error!( - "could not find any TCP port in the given address: {:?}", - self.p2p_listen_addr - ); - false - }); - - if address_in_use { - return Err(eyre!( - "Listen address {} is already in use.", - self.p2p_listen_addr - )); - } - - Ok(()) - } -} diff --git a/compute/src/lib.rs b/compute/src/lib.rs deleted file mode 100644 index 56aed735..00000000 --- a/compute/src/lib.rs +++ /dev/null @@ -1,12 +0,0 @@ -pub mod config; -pub mod node; -pub mod reqres; -pub mod utils; -pub mod workers; - -/// Crate version of the compute node. -/// This value is attached within the published messages. -pub const DRIA_COMPUTE_NODE_VERSION: &str = env!("CARGO_PKG_VERSION"); - -pub use config::DriaComputeNodeConfig; -pub use node::DriaComputeNode; diff --git a/compute/src/main.rs b/compute/src/main.rs deleted file mode 100644 index bf1c9eff..00000000 --- a/compute/src/main.rs +++ /dev/null @@ -1,209 +0,0 @@ -use dkn_compute::*; -use dkn_executor::{DriaExecutorsManager, Model}; -use eyre::Result; -use std::env; -use tokio_util::{sync::CancellationToken, task::TaskTracker}; -use workers::task::TaskWorker; - -#[tokio::main] -async fn main() -> Result<()> { - // load a particular environment file specified by DKN_COMPUTE_ENV, or `.env` by default - let env_path = env::var("DKN_COMPUTE_ENV").unwrap_or_else(|_| ".env".to_string()); - let dotenv_result = dotenvy::from_path(&env_path); - - env_logger::builder() - .format_timestamp(Some(env_logger::TimestampPrecision::Millis)) - .filter(None, log::LevelFilter::Off) - .filter_module("dkn_compute", log::LevelFilter::Info) - .filter_module("dkn_p2p", log::LevelFilter::Info) - .filter_module("dkn_utils", log::LevelFilter::Info) - .filter_module("dkn_executor", log::LevelFilter::Info) - .filter_module("libp2p", log::LevelFilter::Error) - .parse_default_env() // reads RUST_LOG variable - .init(); - - log::info!( - r#" - -██████╗ ██████╗ ██╗ █████╗ -██╔══██╗██╔══██╗██║██╔══██╗ Dria Compute Node -██║ ██║██████╔╝██║███████║ v{DRIA_COMPUTE_NODE_VERSION} -██║ ██║██╔══██╗██║██╔══██║ https://dria.co -██████╔╝██║ ██║██║██║ ██║ -╚═════╝ ╚═╝ ╚═╝╚═╝╚═╝ ╚═╝ -"# - ); - - // log about env usage - match dotenv_result { - Ok(_) => log::info!("Loaded environment file from {env_path}"), - Err(err) => log::warn!("Could not load environment file from {env_path}: {err}"), - } - - // task tracker for multiple threads - let task_tracker = TaskTracker::new(); - let cancellation = CancellationToken::new(); - - // spawn the background task to wait for termination signals - let task_tracker_to_close = task_tracker.clone(); - let cancellation_token = cancellation.clone(); - task_tracker.spawn(async move { - if let Ok(Ok(duration_secs)) = - env::var("DKN_EXIT_TIMEOUT").map(|s| s.to_string().parse::()) - { - // the timeout is done for profiling only, and should not be used in production - log::warn!("Waiting for {duration_secs} seconds before exiting."); - tokio::time::sleep(tokio::time::Duration::from_secs(duration_secs)).await; - - log::warn!("Exiting due to DKN_EXIT_TIMEOUT."); - cancellation_token.cancel(); - } else if let Err(err) = wait_for_termination(cancellation_token.clone()).await { - // if there is no timeout, we wait for termination signals here - log::error!("Error waiting for termination: {err:?}"); - log::error!("Cancelling due to unexpected error."); - cancellation_token.cancel(); - }; - - // close tracker in any case - task_tracker_to_close.close(); - }); - - // create configurations - let models = Model::from_csv(env::var("DKN_MODELS").unwrap_or_default()); - let executors_config = DriaExecutorsManager::new_from_env_for_models(models.into_iter())?; - if executors_config.models.is_empty() { - return Err(eyre::eyre!("No models were provided, make sure to restart with at least one model provided within DKN_MODELS.")); - } - - log::info!( - "Initial provided models are: {}", - executors_config.get_model_names().join(", ") - ); - let mut config = DriaComputeNodeConfig::new(executors_config); - - // check address in use - config.assert_address_not_in_use()?; - - // check services & models, will exit if there is an error - // since service check can take time, we allow early-exit here as well - let model_perf = tokio::select! { - result = config.executors.check_services() => result, - _ = cancellation.cancelled() => { - log::info!("Service check cancelled, exiting."); - return Ok(()); - } - }; - - if config.executors.models.is_empty() { - return Err(eyre::eyre!( - "No valid models left after service checks, exiting." - )); - } else { - log::info!( - "Using models: {}\n{}", - config.executors.get_model_names().join(", "), - model_perf - .iter() - .map(|(model, perf)| format!("{model}: {perf}")) - .collect::>() - .join("\n") - ); - } - // create the node - let batch_size = config.batch_size; - let (mut node, p2p, worker_batch, worker_single) = - DriaComputeNode::new(config, model_perf).await?; - - // spawn p2p client first - log::info!("Spawning peer-to-peer client thread."); - task_tracker.spawn(async move { p2p.run().await }); - - // spawn batch worker thread if we are using such models (e.g. OpenAI, Gemini, OpenRouter) - if let Some(mut worker_batch) = worker_batch { - assert!( - batch_size <= TaskWorker::MAX_BATCH_SIZE, - "batch size too large" - ); - log::info!("Spawning batch executor worker thread. (batch size {batch_size})"); - task_tracker.spawn(async move { worker_batch.run_batch(batch_size).await }); - } - - // spawn single worker thread if we are using such models (e.g. Ollama) - if let Some(mut worker_single) = worker_single { - log::info!("Spawning single executor worker thread."); - task_tracker.spawn(async move { worker_single.run_series().await }); - } - - // spawn compute node thread - log::info!("Spawning compute node thread."); - let node_token = cancellation.clone(); - task_tracker.spawn(async move { - node.run(node_token).await; - log::info!("Closing node.") - }); - - // wait for all tasks to finish - task_tracker.wait().await; - log::info!("All tasks have exited succesfully."); - - log::info!("Bye!"); - Ok(()) -} - -/// Waits for various termination signals, and cancels the given token when the signal is received. -/// -/// Handles Unix and Windows [target families](https://doc.rust-lang.org/reference/conditional-compilation.html#target_family). -async fn wait_for_termination(cancellation: CancellationToken) -> Result<()> { - #[cfg(unix)] - { - use tokio::signal::unix::{signal, SignalKind}; - let mut sigterm = signal(SignalKind::terminate())?; // Docker sends SIGTERM - let mut sigint = signal(SignalKind::interrupt())?; // Ctrl+C sends SIGINT - tokio::select! { - _ = sigterm.recv() => log::warn!("Recieved SIGTERM"), - _ = sigint.recv() => log::warn!("Recieved SIGINT"), - _ = cancellation.cancelled() => { - // no need to wait if cancelled anyways - // although this is not likely to happen - return Ok(()); - } - }; - - cancellation.cancel(); - } - - #[cfg(windows)] - { - use tokio::signal::windows; - - // https://learn.microsoft.com/en-us/windows/console/handlerroutine - let mut signal_c = windows::ctrl_c()?; - let mut signal_break = windows::ctrl_break()?; - let mut signal_close = windows::ctrl_close()?; - let mut signal_shutdown = windows::ctrl_shutdown()?; - - tokio::select! { - _ = signal_c.recv() => log::warn!("Received CTRL_C"), - _ = signal_break.recv() => log::warn!("Received CTRL_BREAK"), - _ = signal_close.recv() => log::warn!("Received CTRL_CLOSE"), - _ = signal_shutdown.recv() => log::warn!("Received CTRL_SHUTDOWN"), - _ = cancellation.cancelled() => { - // no need to wait if cancelled anyways - // although this is not likely to happen - return Ok(()); - } - }; - - cancellation.cancel(); - } - - #[cfg(not(any(unix, windows)))] - { - log::error!("No signal handling for this platform: {}", env::consts::OS); - cancellation.cancel(); - } - - log::info!("Terminating the application..."); - - Ok(()) -} diff --git a/compute/src/node/core.rs b/compute/src/node/core.rs deleted file mode 100644 index e91c7f6c..00000000 --- a/compute/src/node/core.rs +++ /dev/null @@ -1,167 +0,0 @@ -use colored::Colorize; -use dkn_p2p::libp2p::{Multiaddr, PeerId}; -use dkn_utils::{ - payloads::{HEARTBEAT_TOPIC, SPECS_TOPIC}, - DriaMessage, -}; -use eyre::{eyre, Result}; -use std::time::Duration; -use tokio_util::sync::CancellationToken; - -use crate::{reqres::HeartbeatRequester, DriaComputeNode}; - -impl DriaComputeNode { - /// Runs the main loop of the compute node. - /// This method is not expected to return until cancellation occurs for the given token. - pub async fn run(&mut self, cancellation: CancellationToken) { - // initialize the points client - self.points_client.initialize().await; - - /// Duration between refreshing for diagnostic prints. - const DIAGNOSTIC_REFRESH_INTERVAL_SECS: Duration = Duration::from_secs(45); - /// Duration between refreshing for points update. - const POINTS_REFRESH_INTERVAL_SECS: Duration = Duration::from_secs(180); - /// Duration between refreshing the available nodes. - const RPC_LIVENESS_REFRESH_INTERVAL_SECS: Duration = Duration::from_secs(2 * 60); - /// Duration between each specs update sent to the RPC. - const SPECS_INTERVAL_SECS: Duration = Duration::from_secs(60 * 5); - - let mut diagnostic_refresh_interval = - tokio::time::interval(DIAGNOSTIC_REFRESH_INTERVAL_SECS); - diagnostic_refresh_interval.tick().await; // move each one tick - let mut rpc_liveness_refresh_interval = - tokio::time::interval(RPC_LIVENESS_REFRESH_INTERVAL_SECS); - rpc_liveness_refresh_interval.tick().await; // move each one tick - - // tick the first time a bit earlier - let mut points_refresh_interval = tokio::time::interval(POINTS_REFRESH_INTERVAL_SECS); - points_refresh_interval.tick().await; - points_refresh_interval.reset_after(POINTS_REFRESH_INTERVAL_SECS / 12); - - // move one tick, and wait at least a third of the diagnostics - let mut heartbeat_interval = tokio::time::interval(HeartbeatRequester::HEARTBEAT_DEADLINE); - heartbeat_interval.tick().await; - heartbeat_interval.reset_after(DIAGNOSTIC_REFRESH_INTERVAL_SECS / 3); - - // move one tick, and wait a little bit - let mut specs_interval = tokio::time::interval(SPECS_INTERVAL_SECS); - specs_interval.tick().await; - specs_interval.reset_after(DIAGNOSTIC_REFRESH_INTERVAL_SECS / 6); - - loop { - tokio::select! { - // a task is completed by the worker & should be responded to the requesting peer - task_response_msg_opt = self.task_output_rx.recv() => { - if let Some(task_response_msg) = task_response_msg_opt { - if let Err(err) = self.send_task_output(task_response_msg).await { - log::error!("Error responding to task: {err:?}"); - } - } else { - log::error!("task_output_rx channel closed unexpectedly, we still have {} batch and {} single tasks.", self.pending_tasks_batch.len(), self.pending_tasks_single.len()); - break; - } - }, - - // a Request or Response is received by the p2p client - reqres_msg_opt = self.reqres_rx.recv() => { - if let Some((peer_id, message)) = reqres_msg_opt { - self.handle_reqres(peer_id, message).await; - } else { - log::error!("reqres_rx channel closed unexpectedly."); - break; - } - }, - - // check peer count every now and then - _ = diagnostic_refresh_interval.tick() => self.handle_diagnostic_refresh().await, - - // check RPC, and get a new one if we are disconnected - _ = rpc_liveness_refresh_interval.tick() => { - let is_connected = self.handle_rpc_liveness_check().await; - if !is_connected { - // make sure we reset the heartbeat and specs intervals so that - // we dont wait the entire duration for this new connection - log::info!("Connecting was re-attempted, resetting timers."); - heartbeat_interval.reset_after(Duration::from_secs(5)); - specs_interval.reset_after(Duration::from_secs(5)); - } - }, - - // log points every now and then - _ = points_refresh_interval.tick() => self.handle_points_refresh().await, - - // send a heartbeat request to publish liveness info - _ = heartbeat_interval.tick() => { - if let Err(e) = self.send_heartbeat().await { - log::error!("Error making {}: {:?}", HEARTBEAT_TOPIC.blue(), e); - } - }, - - // send specs to the RPC - _ = specs_interval.tick() => { - if let Err(e) = self.send_specs().await { - log::error!("Error sending {}: {:?}", SPECS_TOPIC.green(), e); - } - }, - - // check if the cancellation token is cancelled - // this is expected to be cancelled by the main thread with signal handling - _ = cancellation.cancelled() => { - log::info!("Cancellation received, shutting down the node."); - break; - }, - } - } - - // print one final diagnostic as a summary - self.handle_diagnostic_refresh().await; - - // shutdown channels - if let Err(err) = self.shutdown().await { - log::error!("Could not shutdown the node gracefully: {err:?}"); - } - } - - /// Shorthand method to create a signed message with the given data and topic. - /// - /// Topic was previously used for GossipSub, but kept for verbosity. - #[inline(always)] - pub fn new_message(&self, data: impl AsRef<[u8]>, topic: impl ToString) -> DriaMessage { - DriaMessage::new_signed( - data, - topic, - self.p2p.protocol().name.clone(), - &self.config.secret_key, - self.config.version, - ) - } - - /// Dial the given peer at the given address. - pub async fn dial_with_timeout(&mut self, peer_id: PeerId, addr: Multiaddr) -> Result<()> { - // while not yet known, some people get stuck during the dialling step, - // this timeout prevents that. - const DIAL_TIMEOUT: Duration = Duration::from_secs(10); - - match tokio::time::timeout(DIAL_TIMEOUT, self.p2p.dial(peer_id, addr)).await { - Err(timeout) => Err(eyre!("Timeout dialling RPC node: {}", timeout)), - Ok(result) => result, // this is also a `Result` enum - } - } - - /// Shutdown channels between p2p, worker and yourself. - /// - /// Can be inlined as it is called only once from very few places. - #[inline] - pub async fn shutdown(&mut self) -> Result<()> { - log::debug!("Sending shutdown command to p2p client."); - self.p2p.shutdown().await?; - - log::debug!("Closing task output channel."); - self.task_output_rx.close(); - - log::debug!("Closing reqres channel."); - self.reqres_rx.close(); - - Ok(()) - } -} diff --git a/compute/src/node/diagnostic.rs b/compute/src/node/diagnostic.rs deleted file mode 100644 index e407a783..00000000 --- a/compute/src/node/diagnostic.rs +++ /dev/null @@ -1,152 +0,0 @@ -use colored::Colorize; -use std::time::Duration; - -use crate::{node::rpc::DriaRPC, DriaComputeNode, DRIA_COMPUTE_NODE_VERSION}; - -/// Number of seconds such that if the last heartbeat ACK is older than this, the node is considered unreachable. -/// This must be at least greated than the heartbeat interval duration, and the liveness check duration. -const HEARTBEAT_LIVENESS_SECS: Duration = Duration::from_secs(4 * 60); - -impl DriaComputeNode { - /// Returns the task count within the channels, `single` and `batch`. - #[inline(always)] - pub fn get_pending_task_count(&self) -> [usize; 2] { - [ - self.pending_tasks_single.len(), - self.pending_tasks_batch.len(), - ] - } - - /// Peer refresh simply reports the peer count to the user. - pub(crate) async fn handle_diagnostic_refresh(&mut self) { - let mut diagnostics = vec![format!("Diagnostics (v{}):", DRIA_COMPUTE_NODE_VERSION)]; - - // completed tasks count is printed as well in debug - if log::log_enabled!(log::Level::Debug) { - diagnostics.push(format!( - "Completed Tasks (single/batch): {} / {}", - self.completed_tasks_single, self.completed_tasks_batch - )); - - diagnostics.push(format!( - "RPC {}: {}", - self.dria_rpc.peer_id, - if self - .p2p - .is_connected(self.dria_rpc.peer_id) - .await - .unwrap_or(false) - { - "Connected".green() - } else { - "Disconnected".red() - } - )); - } - - // print peer id and address - diagnostics.push(format!("Peer ID: {}", self.config.peer_id)); - diagnostics.push(format!("Address: 0x{}", self.config.address)); - - // print models - diagnostics.push(format!( - "Models: {}", - self.config.executors.get_model_names().join(", ") - )); - - // if we have not received pings for a while, we are considered offline - let is_offline = chrono::Utc::now() > self.last_heartbeat_at + HEARTBEAT_LIVENESS_SECS; - - // if we have not yet received a heartbeat response, we are still connecting - if self.num_heartbeats == 0 { - // if we didnt have any pings, we might still be connecting - diagnostics.push(format!("Node Status: {}", "CONNECTING".yellow())); - } else { - diagnostics.push(format!( - "Node Status: {}", - if is_offline { - "OFFLINE".red() - } else { - "ONLINE".green() - } - )); - } - - log::info!("{}", diagnostics.join("\n ")); - - // if offline, print this error message as well - if is_offline { - log::error!( - "Node has not received any pings for at least {} seconds & it may be unreachable!\nPlease restart your node!", - HEARTBEAT_LIVENESS_SECS.as_secs() - ); - } - } - - /// Dials the existing RPC node if we are not connected to it. - /// - /// If there is an error while doing that, it will try to get a new RPC node and dial it. - /// - /// Returns `true` if the RPC is connected, `false` otherwise. - pub(crate) async fn handle_rpc_liveness_check(&mut self) -> bool { - log::debug!("Checking RPC connections for diagnostics."); - - // check if we are connected - let is_connected = self - .p2p - .is_connected(self.dria_rpc.peer_id) - .await - .unwrap_or(false); - - // if we are not connected, get a new RPC and dial it again - if !is_connected { - // if we also cannot dial it, get a new RPC node - log::warn!( - "Connection to RPC {} is lost, geting a new one!", - self.dria_rpc.addr, - ); - match DriaRPC::new_for_network(self.dria_rpc.network, &self.config.version).await { - Ok(new_rpc) => { - self.dria_rpc = new_rpc; - - // now dial this new RPC again - if let Err(err) = self - .dial_with_timeout(self.dria_rpc.peer_id, self.dria_rpc.addr.clone()) - .await - { - // worst-case we cant dial this one too, just leave it for the next diagnostic - log::error!("Could not dial the new RPC: {err:?}"); - } - } - Err(err) => { - log::error!("Could not get a new RPC node: {err:?}"); - } - }; - } else { - log::debug!("Connection with {} is intact.", self.dria_rpc.peer_id); - } - - // return the connection status - is_connected - } - - /// Updates the points for the given address. - #[inline] - pub(crate) async fn handle_points_refresh(&mut self) { - // get points from the API - match self.points_client.get_points().await { - Ok(steps) => { - log::info!( - "{}: {} total, {} earned in this run, within top {}%", - "$DRIA Points".purple(), - steps.score, - steps.score - self.points_client.initial, - steps.percentile - ); - } - Err(err) => { - log::error!("Could not get $DRIA points info: {err:?}"); - } - } - } -} diff --git a/compute/src/node/mod.rs b/compute/src/node/mod.rs deleted file mode 100644 index 0554c54c..00000000 --- a/compute/src/node/mod.rs +++ /dev/null @@ -1,166 +0,0 @@ -use dkn_executor::Model; -use dkn_p2p::{ - libp2p::PeerId, DriaP2PClient, DriaP2PCommander, DriaP2PProtocol, DriaReqResMessage, -}; -use dkn_utils::{crypto::secret_to_keypair, payloads::SpecModelPerformance}; -use eyre::Result; -use std::collections::{HashMap, HashSet}; -use tokio::sync::mpsc; -use uuid::Uuid; - -use crate::{ - config::*, - utils::{DriaPointsClient, SpecCollector}, - workers::task::{TaskWorker, TaskWorkerInput, TaskWorkerMetadata, TaskWorkerOutput}, -}; - -mod core; -mod diagnostic; -mod reqres; -mod rpc; -use rpc::DriaRPC; - -/// Buffer size for message publishes. -const PUBLISH_CHANNEL_BUFSIZE: usize = 1024; - -pub struct DriaComputeNode { - /// Compute node configuration. - pub config: DriaComputeNodeConfig, - /// Chosen RPC node. - pub dria_rpc: DriaRPC, - /// Peer-to-peer client commander to interact with the network. - pub p2p: DriaP2PCommander, - /// The last time the node had an acknowledged heartbeat. - /// If this is too much, we can say that the node is not reachable by RPC. - pub(crate) last_heartbeat_at: chrono::DateTime, - /// Number of pings received. - pub(crate) num_heartbeats: u64, - /// A mapping of heartbeat UUIDs to their deadlines. - /// This is used to track the heartbeats, and their acknowledgements. - pub(crate) heartbeats_reqs: HashMap>, - /// A mapping of specs UUIDs to their deadlines. - /// This is used to track the specs, and their acknowledgements. - pub(crate) specs_reqs: HashSet, - /// Request-response message receiver, can have both a request or a response. - reqres_rx: mpsc::Receiver<(PeerId, DriaReqResMessage)>, - /// Task response receiver, will respond to the request-response channel with the given result. - task_output_rx: mpsc::Receiver, - /// Task worker transmitter to send batchable tasks. - task_request_batch_tx: Option>, - /// Task worker transmitter to send single tasks. - task_request_single_tx: Option>, - /// Single tasks, key is `row_id`, which has negligible probability of collision. - pub pending_tasks_single: HashMap, - // Batchable tasks, key is `row_id`, which has negligible probability of collision. - pub pending_tasks_batch: HashMap, - /// Completed single tasks count - completed_tasks_single: usize, - /// Completed batch tasks count - completed_tasks_batch: usize, - /// Specifications collector. - spec_collector: SpecCollector, - /// Points client. - points_client: DriaPointsClient, -} - -impl DriaComputeNode { - /// Creates a new `DriaComputeNode` with the given configuration and cancellation token. - /// - /// Returns the node instance and p2p client together. P2p MUST be run in a separate task before this node is used at all. - pub async fn new( - mut config: DriaComputeNodeConfig, - model_perf: HashMap, - ) -> Result<( - DriaComputeNode, - DriaP2PClient, - Option, - Option, - )> { - // create the keypair from secret key - let keypair = secret_to_keypair(&config.secret_key); - - // dial the RPC node - let dria_rpc = if let Some(addr) = config.initial_rpc_addr.take() { - log::info!("Using initial RPC address: {addr}"); - DriaRPC::new(addr, config.network).expect("could not get RPC to connect to") - } else { - DriaRPC::new_for_network(config.network, &config.version) - .await - .expect("could not get RPC to connect to") - }; - - // we are using the major.minor version as the P2P version - // so that patch versions do not interfere with the protocol - let protocol = DriaP2PProtocol::new_major_minor(config.network.protocol_name()); - log::info!("Using identity: {protocol}"); - - // create p2p client - let (p2p_client, p2p_commander, request_rx) = DriaP2PClient::new( - keypair, - config.p2p_listen_addr.clone(), - &dria_rpc.addr, - protocol, - )?; - - // create channel for task executors, all workers use the same publish channel - let (publish_tx, publish_rx) = mpsc::channel(PUBLISH_CHANNEL_BUFSIZE); - - // check if we should create a worker for batch executor - let (task_batch_worker, task_batch_tx) = - if config.executors.providers.keys().any(|p| p.is_batchable()) { - let (worker, sender) = TaskWorker::new(publish_tx.clone()); - (Some(worker), Some(sender)) - } else { - (None, None) - }; - - // check if we should create a worker for single executor - let (task_single_worker, task_single_tx) = - if config.executors.providers.keys().any(|p| !p.is_batchable()) { - let (worker, sender) = TaskWorker::new(publish_tx); - (Some(worker), Some(sender)) - } else { - (None, None) - }; - - let model_names = config.executors.get_model_names(); - let points_client = DriaPointsClient::new(&config.address, &config.network)?; - - let spec_collector = SpecCollector::new( - model_names.clone(), - model_perf, - config.version, - config.exec_platform.clone(), - p2p_client.peer_id, - ); - Ok(( - DriaComputeNode { - config, - p2p: p2p_commander, - dria_rpc, - points_client, - // receivers - task_output_rx: publish_rx, - reqres_rx: request_rx, - // transmitters - task_request_batch_tx: task_batch_tx, - task_request_single_tx: task_single_tx, - // task trackers - pending_tasks_single: HashMap::new(), - pending_tasks_batch: HashMap::new(), - completed_tasks_single: 0, - completed_tasks_batch: 0, - // heartbeats - heartbeats_reqs: HashMap::new(), - last_heartbeat_at: chrono::Utc::now(), - num_heartbeats: 0, - // specs - specs_reqs: HashSet::new(), - spec_collector, - }, - p2p_client, - task_batch_worker, - task_single_worker, - )) - } -} diff --git a/compute/src/node/reqres.rs b/compute/src/node/reqres.rs deleted file mode 100644 index 10368c7f..00000000 --- a/compute/src/node/reqres.rs +++ /dev/null @@ -1,211 +0,0 @@ -use colored::Colorize; -use dkn_p2p::libp2p::{ - request_response::{OutboundRequestId, ResponseChannel}, - PeerId, -}; -use dkn_p2p::DriaReqResMessage; -use dkn_utils::{ - payloads::{HEARTBEAT_TOPIC, SPECS_TOPIC, TASK_REQUEST_TOPIC}, - DriaMessage, -}; -use eyre::Result; - -use crate::{reqres::*, workers::task::TaskWorkerOutput}; - -use super::DriaComputeNode; - -impl DriaComputeNode { - /// Handles a generic request-response message received from the network. - /// - /// - Request is forwarded to [`handle_request`](DriaComputeNode::handle_request) method. - /// - Response is forwarded to [`handle_response`](DriaComputeNode::handle_response) method. - /// - /// Does not return an error, but simply logs it to [`log::error`]. - pub(crate) async fn handle_reqres(&mut self, peer_id: PeerId, message: DriaReqResMessage) { - match message { - // make sure that the `channel` here is NOT DROPPED until a response is sent, - // otherwise you will get an error - DriaReqResMessage::Request { - request, - request_id, - channel, - } => { - log::debug!("Received a request ({request_id}) from {peer_id}"); - - // ensure that message is from the known RPCs - if self.dria_rpc.peer_id != peer_id { - log::warn!("Received request from unauthorized source: {peer_id}"); - log::debug!("Allowed source: {}", self.dria_rpc.peer_id); - } else if let Err(err) = self.handle_request(peer_id, &request, channel).await { - log::error!("Error handling request: {err:?}"); - } - } - - DriaReqResMessage::Response { - response, - request_id, - } => { - log::debug!("Received a response ({request_id}) from {peer_id}"); - if let Err(err) = self.handle_response(peer_id, request_id, response).await { - log::error!("Error handling response: {err:?}"); - } - } - }; - } - - /// Handles a [`request_response`] response received from the network. - /// - /// - Internally, the data is expected to be some JSON serialized data that is expected to be parsed and handled. - /// - Can be inlined because it is only called by [`DriaComputeNode::handle_reqres`]. - #[inline] - async fn handle_response( - &mut self, - peer_id: PeerId, - request_id: OutboundRequestId, - data: Vec, - ) -> Result<()> { - if peer_id != self.dria_rpc.peer_id { - log::warn!("Received response from unauthorized source: {peer_id}"); - log::debug!("Allowed source: {}", self.dria_rpc.peer_id); - } - - if let Ok(heartbeat_response) = HeartbeatRequester::try_parse_response(&data) { - log::info!( - "Received a {} response ({request_id}) from {peer_id}", - HEARTBEAT_TOPIC.blue(), - ); - HeartbeatRequester::handle_ack(self, heartbeat_response).await - } else if let Ok(spec_response) = SpecRequester::try_parse_response(&data) { - log::info!( - "Received a {} response ({request_id}) from {peer_id}", - SPECS_TOPIC.green(), - ); - SpecRequester::handle_ack(self, spec_response).await - } else { - Err(eyre::eyre!("Received unhandled request from {}", peer_id)) - } - } - - /// Handles a [`request_response`] request received from the network. - /// - /// - Internally, the data is expected to be some JSON serialized data that is expected to be parsed and handled. - /// - Can be inlined because it is only called by [`DriaComputeNode::handle_reqres`]. - async fn handle_request( - &mut self, - peer_id: PeerId, - message_data: &[u8], - channel: ResponseChannel>, - ) -> Result<()> { - let message = DriaMessage::from_slice_checked( - message_data, - self.p2p.protocol().name.clone(), - self.config.version, - )?; - - match message.topic.as_str() { - TASK_REQUEST_TOPIC => self.handle_task_request(peer_id, message, channel).await, - _ => Err(eyre::eyre!("Received unhandled request from {peer_id}")), - } - } - - /// Handles a Task request received from the network. - /// - /// Based on the task type, the task is sent to the appropriate worker & metadata is stored in memory. - /// This metadata will be used during response as well, and we can count the number of tasks at hand by - /// looking at the number metadata stored. - async fn handle_task_request( - &mut self, - peer_id: PeerId, - task_request: ::Request, - channel: ResponseChannel>, - ) -> Result<()> { - log::info!( - "Received a {} request from {peer_id}", - TASK_REQUEST_TOPIC.yellow() - ); - - let (task_input, task_metadata) = - TaskResponder::parse_task_request(self, &task_request, channel).await?; - if let Err(err) = match task_input.task.is_batchable() { - // this is a batchable task, send it to batch worker - // and keep track of the task id in pending tasks - true => match self.task_request_batch_tx { - Some(ref mut tx) => { - self.pending_tasks_batch - .insert(task_input.row_id, task_metadata); - tx.send(task_input).await - } - None => eyre::bail!("Batchable task received but no worker available."), - }, - - // this is a single task, send it to single worker - // and keep track of the task id in pending tasks - false => match self.task_request_single_tx { - Some(ref mut tx) => { - self.pending_tasks_single - .insert(task_input.row_id, task_metadata); - tx.send(task_input).await - } - None => eyre::bail!("Single task received but no worker available."), - }, - } { - log::error!("Could not send task to worker: {err:?}"); - }; - - Ok(()) - } - - pub(crate) async fn send_task_output(&mut self, task_response: TaskWorkerOutput) -> Result<()> { - // remove the task from pending tasks, and get its metadata - let task_metadata = match task_response.batchable { - true => { - self.completed_tasks_batch += 1; // TODO: this should be done in success - self.pending_tasks_batch.remove(&task_response.row_id) - } - false => { - self.completed_tasks_single += 1; // TODO: this should be done in success - self.pending_tasks_single.remove(&task_response.row_id) - } - }; - - // respond to the response channel with the result - match task_metadata { - Some(task_metadata) => { - TaskResponder::send_task_output(self, task_response, task_metadata).await?; - } - None => { - // totally unexpected case, wont happen at all - eyre::bail!("Metadata not found for {}", task_response.row_id); - } - }; - - Ok(()) - } - - /// Sends a heartbeat request to the configured RPC node. - #[inline] - pub(crate) async fn send_heartbeat(&mut self) -> Result<()> { - let peer_id = self.dria_rpc.peer_id; - let request_id = HeartbeatRequester::send_heartbeat(self, peer_id).await?; - log::info!( - "Sending {} request ({request_id}) to {peer_id}", - HEARTBEAT_TOPIC.blue() - ); - - Ok(()) - } - - /// Sends a specs request to the configured RPC node. - #[inline] - pub(crate) async fn send_specs(&mut self) -> Result<()> { - let peer_id = self.dria_rpc.peer_id; - let specs = self.spec_collector.collect().await; - let request_id = SpecRequester::send_specs(self, peer_id, specs).await?; - log::info!( - "Sending {} request ({request_id}) to {peer_id}", - SPECS_TOPIC.green() - ); - - Ok(()) - } -} diff --git a/compute/src/node/rpc.rs b/compute/src/node/rpc.rs deleted file mode 100644 index 8e3cbeb9..00000000 --- a/compute/src/node/rpc.rs +++ /dev/null @@ -1,108 +0,0 @@ -use dkn_p2p::libp2p::{multiaddr::Protocol, Multiaddr, PeerId}; -use dkn_utils::{DriaNetwork, SemanticVersion}; -use eyre::{Context, OptionExt, Result}; -use rand::seq::SliceRandom; -use std::fmt::Debug; - -/// The connected RPC node, as per the Star network topology. -#[derive(Debug, Clone)] -pub struct DriaRPC { - pub addr: Multiaddr, - pub peer_id: PeerId, - pub network: DriaNetwork, -} - -impl DriaRPC { - /// Creates a new RPC target at the given type, along with a network type for refreshing the RPC address. - pub fn new(addr: Multiaddr, network: DriaNetwork) -> Result { - let peer_id = addr - .iter() - .find_map(|p| match p { - Protocol::P2p(peer_id) => Some(peer_id), - _ => None, - }) - .ok_or_eyre("did not find peer ID within the returned RPC address")?; - - Ok(Self { - addr, - peer_id, - network, - }) - } - - /// Creates a new RPC target for the given network type and version. - pub async fn new_for_network(network: DriaNetwork, version: &SemanticVersion) -> Result { - let addr = get_rpc_for_network(&network, version).await?; - Self::new(addr, network) - } -} - -/// Calls the DKN API to get an RPC address for the given network type. -/// -/// The peer id is expected to be within the multi-address. -async fn get_rpc_for_network( - network: &DriaNetwork, - version: &SemanticVersion, -) -> Result { - const MIN_MARGIN: usize = 150; - - let response = reqwest::get(network.discovery_url(version)).await?; - let rpcs_and_peer_counts = response - .json::>() - .await - .wrap_err("could not parse API response")?; - - // ensure that the response contains at least one RPC - if rpcs_and_peer_counts.is_empty() { - eyre::bail!("no RPCs were returned by discovery API"); - } - - // get the minimum count of peers from all RPCs - let min_peer_count = rpcs_and_peer_counts - .iter() - .map(|(_, peer_count)| *peer_count) - .min() - .unwrap(); // safe to unwrap because we checked for empty earlier - - // choose the RPCs that have peers in range `[min_peer_count, min_peer_count + MIN_MARGIN]` - let rpcs_and_peer_counts: Vec<(Multiaddr, usize)> = rpcs_and_peer_counts - .into_iter() - .filter(|(_, peer_count)| { - (min_peer_count..=min_peer_count + MIN_MARGIN).contains(peer_count) - }) - .collect(); - - // pick a random RPC from the filtered list - let chosen_rpc = rpcs_and_peer_counts - .choose(&mut rand::thread_rng()) - .cloned() - .map(|(addr, _)| addr) - .unwrap(); // safe to unwrap because we checked for empty earlier - - Ok(chosen_rpc) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[tokio::test] - async fn test_dria_nodes() { - let node = - DriaRPC::new_for_network(DriaNetwork::Mainnet, &SemanticVersion::from_crate_version()) - .await; - assert!(node.is_ok()); - } - - #[test] - fn test_deserialize() { - let input = r#"[ - ["/ip4/12.34.56.78/tcp/4001/p2p/16Uiu2HAmG7qrpSh8kenjuYqyrwxgEVdzqRV4wM1hHAZRq4j25VBC", 1], - ["/ip4/78.56.34.12/tcp/4001/p2p/16Uiu2HAmG7qrpSh8kenjuYqyrwxgEVdzqRV4wM1hHAZRq4j25VBC", 4] - ]"#; - let result: Vec<(Multiaddr, usize)> = serde_json::from_str(input).unwrap(); - assert_eq!(result.len(), 2); - assert_eq!(result[0].1, 1); - assert_eq!(result[1].1, 4); - } -} diff --git a/compute/src/reqres/heartbeat.rs b/compute/src/reqres/heartbeat.rs deleted file mode 100644 index fb852cc2..00000000 --- a/compute/src/reqres/heartbeat.rs +++ /dev/null @@ -1,87 +0,0 @@ -use colored::Colorize; -use dkn_p2p::libp2p::{request_response::OutboundRequestId, PeerId}; -use dkn_utils::{ - payloads::{HeartbeatRequest, HeartbeatResponse, HEARTBEAT_TOPIC}, - DriaMessage, -}; -use eyre::{eyre, Result}; -use std::time::Duration; -use uuid::Uuid; - -use super::IsResponder; - -use crate::DriaComputeNode; - -pub struct HeartbeatRequester; - -impl IsResponder for HeartbeatRequester { - type Request = DriaMessage; // HeartbeatRequest; - type Response = HeartbeatResponse; -} - -impl HeartbeatRequester { - /// Any acknowledged heartbeat that is older than this duration is considered dead. - pub const HEARTBEAT_DEADLINE: Duration = Duration::from_secs(60); - pub(crate) async fn send_heartbeat( - node: &mut DriaComputeNode, - peer_id: PeerId, - ) -> Result { - let uuid = Uuid::now_v7(); - let deadline = chrono::Utc::now() + Self::HEARTBEAT_DEADLINE; - - let heartbeat_request = HeartbeatRequest { - heartbeat_id: uuid, - deadline, - pending_batch: node.pending_tasks_batch.len(), - pending_single: node.pending_tasks_single.len(), - batch_size: node.config.batch_size, - }; - - let heartbeat_message = node.new_message( - serde_json::to_vec(&heartbeat_request).expect("should be serializable"), - HEARTBEAT_TOPIC, - ); - let request_id = node.p2p.request(peer_id, heartbeat_message).await?; - - // add it to local heartbeats set - node.heartbeats_reqs.insert(uuid, deadline); - - Ok(request_id) - } - - /// Handles the heartbeat acknowledement by RPC. - pub(crate) async fn handle_ack( - node: &mut DriaComputeNode, - res: HeartbeatResponse, - ) -> Result<()> { - if let Some(deadline) = node.heartbeats_reqs.remove(&res.heartbeat_id) { - if let Some(err) = res.error { - Err(eyre!( - "{} was not acknowledged: {}", - HEARTBEAT_TOPIC.blue(), - err - )) - } else { - // acknowledge heartbeat - node.last_heartbeat_at = chrono::Utc::now(); - node.num_heartbeats += 1; - - // for diagnostics, we can check if the heartbeat was past its deadline as well - if chrono::Utc::now() > deadline { - log::warn!( - "Acknowledged {} was past its deadline.", - HEARTBEAT_TOPIC.blue() - ) - } - - Ok(()) - } - } else { - Err(eyre!( - "Received an unknown {} response with id {}.", - HEARTBEAT_TOPIC.blue(), - res.heartbeat_id - )) - } - } -} diff --git a/compute/src/reqres/mod.rs b/compute/src/reqres/mod.rs deleted file mode 100644 index 5903cdfb..00000000 --- a/compute/src/reqres/mod.rs +++ /dev/null @@ -1,80 +0,0 @@ -//! Request-response handlers. - -use eyre::Context; -use serde::{de::DeserializeOwned, Serialize}; - -mod specs; -pub use specs::SpecRequester; - -mod task; -pub use task::TaskResponder; - -mod heartbeat; -pub use heartbeat::HeartbeatRequester; - -/// A responder should implement a request & response type, both serializable. -/// -/// The `try_parse_request` is automatically implemented using `serde-json` for a byte slice. -pub trait IsResponder { - type Request: DeserializeOwned; - type Response: Serialize + DeserializeOwned; - - fn try_parse_request(data: &[u8]) -> eyre::Result { - serde_json::from_slice(data).wrap_err("could not parse request") - } - - fn try_parse_response(data: &[u8]) -> eyre::Result { - serde_json::from_slice(data).wrap_err("could not parse response") - } -} - -#[cfg(test)] -mod tests { - - use super::*; - - // TODO: remove this test when we migrate to enum-based bodies - #[test] - fn test_enum_serialization() { - use serde::Deserialize; - #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] - struct AEnum { - a1: bool, - a2: String, - } - - #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] - struct BEnum { - b1: u64, - b2: bool, - } - - #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] - #[serde(tag = "type", rename_all = "camelCase")] - enum TestEnum { - A(AEnum), - B(BEnum), - } - - let a_variant = TestEnum::A(AEnum { - a1: true, - a2: "test".to_string(), - }); - let b_variant = TestEnum::B(BEnum { - b1: 123456789, - b2: false, - }); - - let a_serialized = serde_json::to_string(&a_variant).unwrap(); - let b_serialized = serde_json::to_string(&b_variant).unwrap(); - - assert_eq!(a_serialized, r#"{"type":"a","a1":true,"a2":"test"}"#); - assert_eq!(b_serialized, r#"{"type":"b","b1":123456789,"b2":false}"#); - - let a_deserialized: TestEnum = serde_json::from_str(&a_serialized).unwrap(); - let b_deserialized: TestEnum = serde_json::from_str(&b_serialized).unwrap(); - - assert_eq!(a_variant, a_deserialized); - assert_eq!(b_variant, b_deserialized); - } -} diff --git a/compute/src/reqres/specs.rs b/compute/src/reqres/specs.rs deleted file mode 100644 index eb9194c0..00000000 --- a/compute/src/reqres/specs.rs +++ /dev/null @@ -1,57 +0,0 @@ -use crate::DriaComputeNode; - -use super::IsResponder; -use colored::Colorize; -use dkn_p2p::libp2p::{request_response::OutboundRequestId, PeerId}; -use dkn_utils::{ - payloads::{Specs, SpecsRequest, SpecsResponse, SPECS_TOPIC}, - DriaMessage, -}; -use eyre::{eyre, Result}; -use uuid::Uuid; - -pub struct SpecRequester; - -impl IsResponder for SpecRequester { - type Request = DriaMessage; // SpecRequest; - type Response = SpecsResponse; -} - -impl SpecRequester { - pub(crate) async fn send_specs( - node: &mut DriaComputeNode, - peer_id: PeerId, - specs: Specs, - ) -> Result { - let uuid = Uuid::now_v7(); - let specs_request = SpecsRequest { - specs_id: uuid, - specs, - address: node.config.address.clone(), - }; - - let specs_message = node.new_message( - serde_json::to_vec(&specs_request).expect("should be serializable"), - SPECS_TOPIC, - ); - let request_id = node.p2p.request(peer_id, specs_message).await?; - - // add it to local specs set - node.specs_reqs.insert(uuid); - - Ok(request_id) - } - - /// Handles the specs request received from the network. - pub(crate) async fn handle_ack(node: &mut DriaComputeNode, res: SpecsResponse) -> Result<()> { - if node.specs_reqs.remove(&res.specs_id) { - Ok(()) - } else { - Err(eyre!( - "Received an unknown {} response with id {}.", - SPECS_TOPIC.green(), - res.specs_id - )) - } - } -} diff --git a/compute/src/reqres/task.rs b/compute/src/reqres/task.rs deleted file mode 100644 index f74ae36f..00000000 --- a/compute/src/reqres/task.rs +++ /dev/null @@ -1,283 +0,0 @@ -use colored::Colorize; -use dkn_executor::{CompletionError, ModelProvider, PromptError, TaskBody}; -use dkn_p2p::libp2p::request_response::ResponseChannel; -use dkn_utils::payloads::{ - TaskError, TaskRequestPayload, TaskResponsePayload, TaskStats, TASK_RESULT_TOPIC, -}; -use dkn_utils::DriaMessage; -use eyre::{Context, Result}; - -use crate::workers::task::*; -use crate::DriaComputeNode; - -pub struct TaskResponder; - -impl super::IsResponder for TaskResponder { - type Request = DriaMessage; // TODO: can we do this typed? - type Response = DriaMessage; // TODO: can we do this typed? -} - -impl TaskResponder { - pub(crate) async fn parse_task_request( - node: &mut DriaComputeNode, - compute_message: &DriaMessage, - channel: ResponseChannel>, - ) -> Result<(TaskWorkerInput, TaskWorkerMetadata)> { - // parse this in two-steps so that if something goes wrong we know the task id - let task = compute_message - .parse_payload::>() - .wrap_err("could not parse task request payload")?; - let task_body = match serde_json::from_value::(task.input) { - Ok(task_body) => task_body, - Err(err) => { - log::error!( - "Task {}/{} failed due to parsing error: {err}", - task.file_id, - task.row_id, - ); - - // prepare error payload - let error_payload = TaskResponsePayload { - result: None, - error: Some(TaskError::ParseError(err.to_string())), - row_id: task.row_id, - file_id: task.file_id, - task_id: task.task_id, - model: "".to_string(), // no model available due to parsing error - stats: TaskStats::new(), - }; - - let error_payload_str = serde_json::to_string(&error_payload) - .wrap_err("could not serialize payload")?; - - // respond through the channel to notify about the parsing error - let response = node.new_message(error_payload_str, TASK_RESULT_TOPIC); - node.p2p.respond(response.into(), channel).await?; - - // return with error - eyre::bail!("could not parse task body: {err}") - } - }; - - let stats = TaskStats::new().record_received_at(); - log::info!( - "Handling {} {} with model {}", - "task".yellow(), - task.row_id, - task_body.model.to_string().yellow() - ); - - // check if the model is available in this node, if so - // it will return an executor that can run this model - let executor = node.config.executors.get_executor(&task_body.model).await?; - - let task_metadata = TaskWorkerMetadata { - task_id: task.task_id, - file_id: task.file_id, - model: task_body.model, - channel, - }; - let task_input = TaskWorkerInput { - executor, - task: task_body, - row_id: task.row_id, - stats, - }; - - Ok((task_input, task_metadata)) - } - - /// Handles the result of a task. - pub(crate) async fn send_task_output( - node: &mut DriaComputeNode, - task_output: TaskWorkerOutput, - task_metadata: TaskWorkerMetadata, - ) -> Result<()> { - let response = match task_output.result { - Ok(result) => { - // prepare signed and encrypted payload - log::info!( - "Publishing {} result for {}/{}", - "task".yellow(), - task_metadata.file_id, - task_output.row_id - ); - - // TODO: will get better token count from `TaskWorkerOutput` - let token_count = result.len(); - let payload = TaskResponsePayload { - result: Some(result), - error: None, - file_id: task_metadata.file_id, - task_id: task_metadata.task_id, - row_id: task_output.row_id, - model: task_metadata.model.to_string(), - stats: task_output - .stats - .record_published_at() - .record_token_count(token_count), - }; - let payload_str = - serde_json::to_string(&payload).wrap_err("could not serialize payload")?; - - node.new_message(payload_str, TASK_RESULT_TOPIC) - } - Err(err) => { - // use pretty display string for error logging with causes - log::error!( - "Task {}/{} failed: {:#}", - task_metadata.file_id, - task_output.row_id, - err - ); - - // prepare error payload - let error_payload = TaskResponsePayload { - result: None, - error: Some(map_prompt_error_to_task_error( - task_metadata.model.provider(), - err, - )), - row_id: task_output.row_id, - file_id: task_metadata.file_id, - task_id: task_metadata.task_id, - model: task_metadata.model.to_string(), - stats: task_output - .stats - .record_published_at() - .record_token_count(0), - }; - let error_payload_str = serde_json::to_string(&error_payload) - .wrap_err("could not serialize payload")?; - - node.new_message(error_payload_str, TASK_RESULT_TOPIC) - } - }; - - // respond through the channel - node.p2p - .respond(response.into(), task_metadata.channel) - .await?; - - Ok(()) - } -} - -/// Maps a [`PromptError`] to a [`TaskError`] with respect to the given provider. -fn map_prompt_error_to_task_error(provider: ModelProvider, err: PromptError) -> TaskError { - match &err { - // if the error is a provider error, we can try to parse it - PromptError::CompletionError(CompletionError::ProviderError(err_inner)) => { - /// A wrapper for `{ error: T }` to match the provider error format. - #[derive(Clone, serde::Deserialize)] - struct ErrorObject { - error: T, - } - - match provider { - // ModelProvider::Gemini => { - // /// Gemini API [error object](https://github.com/googleapis/go-genai/blob/main/api_client.go#L273). - // #[derive(Clone, serde::Deserialize)] - // pub struct GeminiError { - // code: u32, - // message: String, - // status: String, - // } - - // serde_json::from_str::>(err_inner).map( - // |ErrorObject { - // error: gemini_error, - // }| TaskError::ProviderError { - // code: format!("{} ({})", gemini_error.code, gemini_error.status), - // message: gemini_error.message, - // provider: provider.to_string(), - // }, - // ) - // } - // ModelProvider::OpenAI => { - // /// OpenAI API [error object](https://github.com/openai/openai-go/blob/main/internal/apierror/apierror.go#L17). - // #[derive(Clone, serde::Deserialize)] - // pub struct OpenAIError { - // code: String, - // message: String, - // } - - // serde_json::from_str::>(err_inner).map( - // |ErrorObject { - // error: openai_error, - // }| TaskError::ProviderError { - // code: openai_error.code, - // message: openai_error.message, - // provider: provider.to_string(), - // }, - // ) - // } - // ModelProvider::OpenRouter => { - // /// OpenRouter API [error object](https://openrouter.ai/docs/api-reference/errors). - // #[derive(Clone, serde::Deserialize)] - // pub struct OpenRouterError { - // code: u32, - // message: String, - // } - - // serde_json::from_str::>(err_inner).map( - // |ErrorObject { - // error: openrouter_error, - // }| { - // TaskError::ProviderError { - // code: openrouter_error.code.to_string(), - // message: openrouter_error.message, - // provider: provider.to_string(), - // } - // }, - // ) - // } - ModelProvider::Ollama => serde_json::from_str::>(err_inner) - .map( - // Ollama just returns a string error message - |ErrorObject { - error: ollama_error, - }| { - // based on the error message, we can come up with out own "dummy" codes - let code = if ollama_error.contains("server busy, please try again.") { - "server_busy" - } else if ollama_error.contains("model requires more system memory") { - "model_requires_more_memory" - } else if ollama_error.contains("cudaMalloc failed: out of memory") { - "cuda_malloc_failed" - } else if ollama_error.contains("CUDA error: out of memory") { - "cuda_oom" - } else if ollama_error.contains("API Error: Too Many Requests") { - "api:too_many_requests" - } else if ollama_error.contains("API Error: Bad Request") { - "api:bad_request" - } else if ollama_error.contains("not found, try pulling it first") { - "model_not_pulled" - } else if ollama_error.contains("Unexpected end of JSON input") { - "unexpected_end_of_json" - } else { - "unknown" - }; - - TaskError::ProviderError { - code: code.to_string(), - message: ollama_error, - provider: provider.to_string(), - } - }, - ), - } - // if we couldn't parse it, just return a generic prompt error - .unwrap_or(TaskError::ExecutorError(format!( - "{provider} executor error: {}", - err_inner.clone() - ))) - } - // if its a http error, we can try to parse it as well - PromptError::CompletionError(CompletionError::HttpError(err_inner)) => { - TaskError::HttpError(err_inner.to_string()) - } - // if it's not a completion error, we just return the error as is - err => TaskError::Other(err.to_string()), - } -} diff --git a/compute/src/utils/mod.rs b/compute/src/utils/mod.rs deleted file mode 100644 index e5be541e..00000000 --- a/compute/src/utils/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -mod specs; -pub use specs::*; - -mod points; -pub use points::*; diff --git a/compute/src/utils/points.rs b/compute/src/utils/points.rs deleted file mode 100644 index ac07bdf3..00000000 --- a/compute/src/utils/points.rs +++ /dev/null @@ -1,85 +0,0 @@ -use dkn_utils::DriaNetwork; -use eyre::Context; - -pub struct DriaPointsClient { - pub url: String, - client: reqwest::Client, - /// The total number of points you have accumulated at the start of the run. - pub initial: f64, -} - -#[derive(Debug, serde::Deserialize)] -pub struct DriaPoints { - /// Indicates in which top percentile your points are. - pub percentile: usize, - /// The total number of points you have accumulated. - pub score: f64, -} - -impl DriaPointsClient { - /// The base URL for the points API, w.r.t network. - pub fn base_url(network: &DriaNetwork) -> &'static str { - match network { - DriaNetwork::Mainnet => "https://mainnet.dkn.dria.co/points/v0/total/node/", - DriaNetwork::Testnet => "https://testnet.dkn.dria.co/points/v0/total/node/", - } - } - - /// Creates a new `DriaPointsClient` for the given address. - pub fn new(address: &str, network: &DriaNetwork) -> eyre::Result { - const USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION")); - - let url = format!( - "{}/0x{}", - Self::base_url(network), - address.trim_start_matches("0x") - ); - - let client = reqwest::Client::builder() - .user_agent(USER_AGENT) - .build() - .wrap_err("could not create Points client")?; - - Ok(Self { - url, - client, - initial: 0.0, - }) - } - - /// Sets the initial points to the current points. - /// - /// If there is an error, it sets to 0.0. - pub async fn initialize(&mut self) { - self.initial = self.get_points().await.map(|p| p.score).unwrap_or_default(); - } - - pub async fn get_points(&self) -> eyre::Result { - let res = self - .client - .get(&self.url) - .send() - .await - .wrap_err("could not make request")?; - res.json::() - .await - .wrap_err("could not parse response") - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[tokio::test] - async fn test_get_points() { - let client = DriaPointsClient::new( - "0xa43536a6032a3907ccf60e8109429ee1047b207c", - &DriaNetwork::Mainnet, - ) - .unwrap(); - let steps = client.get_points().await.unwrap(); - assert!(steps.score >= 0.0); - assert!(steps.percentile <= 100); - } -} diff --git a/compute/src/utils/specs.rs b/compute/src/utils/specs.rs deleted file mode 100644 index 837cfab3..00000000 --- a/compute/src/utils/specs.rs +++ /dev/null @@ -1,120 +0,0 @@ -use dkn_executor::Model; -use dkn_p2p::libp2p::PeerId; -use dkn_utils::{ - payloads::{SpecModelPerformance, Specs}, - SemanticVersion, -}; -use std::collections::HashMap; -use sysinfo::{CpuRefreshKind, MemoryRefreshKind, RefreshKind}; - -pub struct SpecCollector { - /// System information object, this is expected to be created only once - /// as per the [docs](https://github.com/GuillaumeGomez/sysinfo?tab=readme-ov-file#good-practice--performance-tips). - system: sysinfo::System, - /// Used models. - models: Vec, - /// Model performances - model_perf: HashMap, - /// Version string. - version: String, - /// Execution platform, mainly for diagnostics. - exec_platform: String, - /// Peer ID of the node, used for identification in the network. - peer_id: String, - // GPU adapter infos, showing information about the available GPUs. - // gpus: Vec, -} - -impl SpecCollector { - pub fn new( - models: Vec, - model_perf: HashMap, - version: SemanticVersion, - exec_platform: String, - peer_id: PeerId, - ) -> Self { - log::info!("Creating spec collector with version {version} and platform {exec_platform} and models {models:?}"); - SpecCollector { - system: sysinfo::System::new_with_specifics(Self::get_refresh_specifics()), - models, - model_perf: model_perf - .into_iter() - .map(|(k, v)| (k.to_string(), v)) - .collect(), - version: version.to_string(), - exec_platform, - peer_id: peer_id.to_string(), - // gpus: wgpu::Instance::default() - // .enumerate_adapters(wgpu::Backends::all()) - // .into_iter() - // .map(|a| a.get_info()) - // .collect(), - } - } - - /// Returns the selected refresh kinds. It is important to ignore - /// process values here because it will consume a lot of file-descriptors. - #[inline(always)] - fn get_refresh_specifics() -> RefreshKind { - RefreshKind::nothing() - .with_cpu(CpuRefreshKind::everything()) - .with_memory(MemoryRefreshKind::everything()) - } - - pub async fn collect(&mut self) -> Specs { - self.system.refresh_specifics(Self::get_refresh_specifics()); - - Specs { - total_mem: self.system.total_memory(), - free_mem: self.system.free_memory(), - num_cpus: self.system.physical_core_count(), - cpu_usage: self.system.global_cpu_usage(), - os: std::env::consts::OS.to_string(), - arch: std::env::consts::ARCH.to_string(), - lookup: public_ip_address::perform_lookup(None).await.ok(), - models: self.models.clone(), - version: self.version.clone(), - model_perf: self.model_perf.clone(), - exec_platform: Some(self.exec_platform.clone()), - peer_id: Some(self.peer_id.clone()), - // gpus: self.gpus.clone(), - } - } -} -#[cfg(test)] -mod tests { - use super::*; - - #[tokio::test] - async fn test_specs_serialization() { - let mut spec_collector = SpecCollector::new( - vec![Model::Gemma3_4b.to_string()], - HashMap::from_iter([ - (Model::Gemma3_4b, SpecModelPerformance::PassedWithTPS(100.0)), - (Model::Gemma3_27b, SpecModelPerformance::ExecutionFailed), - ]), - SemanticVersion { - major: 4, - minor: 5, - patch: 1, - }, - "testing".to_string(), - PeerId::random(), - ); - let specs = spec_collector.collect().await; - assert!(specs.total_mem > 0); - assert!(specs.free_mem > 0); - assert!(specs.num_cpus.is_some()); - assert!(specs.cpu_usage > 0.0); - assert!(!specs.os.is_empty()); - assert!(!specs.arch.is_empty()); - assert!(specs.lookup.is_some()); - assert!(!specs.models.is_empty()); - assert_eq!(specs.model_perf.len(), 2); - assert_eq!(specs.version, "4.5.1"); - assert_eq!(specs.exec_platform, Some("testing".to_string())); - - // should be serializable to JSON - assert!(serde_json::to_string_pretty(&specs).is_ok()) - } -} diff --git a/compute/src/workers/mod.rs b/compute/src/workers/mod.rs deleted file mode 100644 index cdafe4ad..00000000 --- a/compute/src/workers/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod task; diff --git a/compute/src/workers/task.rs b/compute/src/workers/task.rs deleted file mode 100644 index 4515951f..00000000 --- a/compute/src/workers/task.rs +++ /dev/null @@ -1,312 +0,0 @@ -use colored::Colorize; -use dkn_executor::{DriaExecutor, Model, TaskBody}; -use dkn_p2p::libp2p::request_response::ResponseChannel; -use dkn_utils::payloads::TaskStats; -use tokio::sync::mpsc; -use uuid::Uuid; - -/// A metadata object that is kept aside while the worker is doing its job. -/// -/// This is put into a map before execution, and then removed after the task is done. -pub struct TaskWorkerMetadata { - pub model: Model, - pub task_id: String, - pub file_id: Uuid, - /// If for any reason this object is dropped before `channel` is responded to, - /// the task will be lost and the channel will be abruptly closed, causing an error on - /// both the responder and the requester side, likely with an `OmissionError`. - pub channel: ResponseChannel>, -} - -pub struct TaskWorkerInput { - /// used as identifier for metadata - pub row_id: Uuid, - // actual consumed input - pub executor: DriaExecutor, - pub task: TaskBody, - // piggybacked metadata - pub stats: TaskStats, -} - -pub struct TaskWorkerOutput { - // used as identifier for metadata - pub row_id: Uuid, - // actual produced output - pub result: Result, - // piggybacked metadata - pub stats: TaskStats, - pub batchable: bool, -} - -/// It is expected to be spawned in another thread, with [`Self::run_batch`] for batch processing and [`Self::run_series`] for single processing. -pub struct TaskWorker { - /// Task channel receiver, the sender is most likely the compute node itself. - task_rx: mpsc::Receiver, - /// Publish message channel sender, the receiver is most likely the compute node itself. - publish_tx: mpsc::Sender, - // TODO: batch size must be defined here -} - -/// Buffer size for task channels (per worker). -const TASK_RX_CHANNEL_BUFSIZE: usize = 1024; - -impl TaskWorker { - /// Batch size that defines how many tasks can be executed concurrently at once. - /// - /// The `run` function is designed to handle the batch size here specifically, - /// if there are more tasks than the batch size, the function will panic. - pub const MAX_BATCH_SIZE: usize = 8; - - /// Creates a worker and returns the sender and receiver for the worker. - pub fn new( - publish_tx: mpsc::Sender, - ) -> (TaskWorker, mpsc::Sender) { - let (task_tx, task_rx) = mpsc::channel(TASK_RX_CHANNEL_BUFSIZE); - - let worker = TaskWorker { - task_rx, - publish_tx, - }; - - (worker, task_tx) - } - - /// Closes the worker's receiver channel. - fn shutdown(&mut self) { - log::info!("Closing worker."); - self.task_rx.close(); - } - - /// Launches the thread that can process tasks one by one (in series). - /// This function will block until the channel is closed. - /// - /// It is suitable for task streams that consume local resources, unlike API calls. - pub async fn run_series(&mut self) { - loop { - let task = self.task_rx.recv().await; - - if let Some(task) = task { - log::info!("Processing {} (single)", "task".yellow(),); - TaskWorker::execute((task, &self.publish_tx)).await - } else { - return self.shutdown(); - }; - } - } - - /// Launches the thread that can process tasks in batches. - /// This function will block until the channel is closed. - /// - /// It is suitable for task streams that make use of API calls, unlike Ollama-like - /// tasks that consumes local resources and would not make sense to run in parallel. - /// - /// Batch size must NOT be larger than `MAX_BATCH_SIZE`, otherwise will panic. - pub async fn run_batch(&mut self, batch_size: usize) { - assert!( - batch_size <= Self::MAX_BATCH_SIZE, - "Batch size must not be larger than {}", - Self::MAX_BATCH_SIZE - ); - - loop { - let mut tasks = Vec::new(); - - // get tasks in batch from the channel, we enter the loop if: - // (1) there are no tasks, or, - // (2) there are tasks less than the batch size and the channel is not empty - while tasks.is_empty() || (tasks.len() < batch_size && !self.task_rx.is_empty()) { - log::info!( - "Worker is waiting for tasks ({} < {})", - tasks.len(), - batch_size - ); - let limit = batch_size - tasks.len(); - match self.task_rx.recv_many(&mut tasks, limit).await { - // 0 tasks returned means that the channel is closed - 0 => return self.shutdown(), - _ => { - // wait a small amount of time to allow for more tasks to be sent into the channel - tokio::time::sleep(std::time::Duration::from_millis(256)).await; - } - } - } - - // process the batch - let num_tasks = tasks.len(); - debug_assert!( - num_tasks <= batch_size, - "number of tasks cant be larger than batch size" - ); - debug_assert!(num_tasks != 0, "number of tasks cant be zero"); - - log::info!("Processing {num_tasks} tasks in batch"); - let mut batch = tasks.into_iter().map(|b| (b, &self.publish_tx)); - match num_tasks { - 1 => { - TaskWorker::execute(batch.next().unwrap()).await; - } - 2 => { - tokio::join!( - TaskWorker::execute(batch.next().unwrap()), - TaskWorker::execute(batch.next().unwrap()) - ); - } - 3 => { - tokio::join!( - TaskWorker::execute(batch.next().unwrap()), - TaskWorker::execute(batch.next().unwrap()), - TaskWorker::execute(batch.next().unwrap()) - ); - } - 4 => { - tokio::join!( - TaskWorker::execute(batch.next().unwrap()), - TaskWorker::execute(batch.next().unwrap()), - TaskWorker::execute(batch.next().unwrap()), - TaskWorker::execute(batch.next().unwrap()) - ); - } - 5 => { - tokio::join!( - TaskWorker::execute(batch.next().unwrap()), - TaskWorker::execute(batch.next().unwrap()), - TaskWorker::execute(batch.next().unwrap()), - TaskWorker::execute(batch.next().unwrap()), - TaskWorker::execute(batch.next().unwrap()) - ); - } - 6 => { - tokio::join!( - TaskWorker::execute(batch.next().unwrap()), - TaskWorker::execute(batch.next().unwrap()), - TaskWorker::execute(batch.next().unwrap()), - TaskWorker::execute(batch.next().unwrap()), - TaskWorker::execute(batch.next().unwrap()), - TaskWorker::execute(batch.next().unwrap()) - ); - } - 7 => { - tokio::join!( - TaskWorker::execute(batch.next().unwrap()), - TaskWorker::execute(batch.next().unwrap()), - TaskWorker::execute(batch.next().unwrap()), - TaskWorker::execute(batch.next().unwrap()), - TaskWorker::execute(batch.next().unwrap()), - TaskWorker::execute(batch.next().unwrap()), - TaskWorker::execute(batch.next().unwrap()) - ); - } - 8 => { - tokio::join!( - TaskWorker::execute(batch.next().unwrap()), - TaskWorker::execute(batch.next().unwrap()), - TaskWorker::execute(batch.next().unwrap()), - TaskWorker::execute(batch.next().unwrap()), - TaskWorker::execute(batch.next().unwrap()), - TaskWorker::execute(batch.next().unwrap()), - TaskWorker::execute(batch.next().unwrap()), - TaskWorker::execute(batch.next().unwrap()) - ); - } - _ => { - unreachable!( - "number of tasks cant be larger than batch size ({} > {})", - num_tasks, - Self::MAX_BATCH_SIZE - ); - } - }; - } - } - - /// Executes a single task, and publishes the output. - pub async fn execute( - (mut input, publish_tx): (TaskWorkerInput, &mpsc::Sender), - ) { - let batchable = input.task.is_batchable(); - input.stats = input.stats.record_execution_started_at(); - let result = input.executor.execute(input.task).await; - input.stats = input.stats.record_execution_ended_at(); - - let output = TaskWorkerOutput { - result, - row_id: input.row_id, - batchable, - stats: input.stats, - }; - - if let Err(err) = publish_tx.send(output).await { - log::error!("Error sending task result: {err}"); - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use dkn_executor::{DriaExecutor, Model}; - - /// Tests the worker with a single task sent within a batch. - /// - /// ## Run command - /// - /// ```sh - /// cargo test --package dkn-compute --lib --all-features -- workers::task::tests::test_executor_worker --exact --show-output --nocapture --ignored - /// ``` - #[tokio::test] - #[ignore = "run manually with Ollama"] - async fn test_executor_worker() { - let _ = env_logger::builder() - .filter_level(log::LevelFilter::Off) - .filter_module("dkn_compute", log::LevelFilter::Debug) - .is_test(true) - .try_init(); - - let (publish_tx, mut publish_rx) = mpsc::channel(1024); - let (mut worker, task_tx) = TaskWorker::new(publish_tx); - - // create batch worker - let worker_handle = tokio::spawn(async move { - worker.run_batch(4).await; - }); - - let num_tasks = 4; - let model = Model::Llama3_2_1bInstructQ4Km; - let executor = DriaExecutor::new_from_env(model.provider()).unwrap(); - let task = TaskBody::new_prompt("Write a poem about Julius Caesar.", model.clone()); - - for i in 0..num_tasks { - log::info!("Sending task {}", i + 1); - - let task_input = TaskWorkerInput { - executor: executor.clone(), - task: task.clone(), - // dummy variables - row_id: Uuid::now_v7(), - stats: TaskStats::default(), - }; - - // send task to worker - task_tx.send(task_input).await.unwrap(); - } - - // now wait for all results - let mut results = Vec::new(); - for i in 0..num_tasks { - log::info!("Waiting for result {}", i + 1); - let result = publish_rx.recv().await.unwrap(); - log::info!("Got result {}", i + 1,); - if result.result.is_err() { - log::error!("Error: {:?}", result.result); - } - results.push(result); - } - - log::info!("Got all results, closing channel."); - publish_rx.close(); - - // FIXME: this bugs out - worker_handle.await.unwrap(); - log::info!("Done."); - } -} diff --git a/docs/NODE_SPECS.md b/docs/NODE_SPECS.md deleted file mode 100644 index 4234d9bc..00000000 --- a/docs/NODE_SPECS.md +++ /dev/null @@ -1,254 +0,0 @@ -# 🚀 LLM Node Runner's Guide: Minimum Specs - -Hello, Drians! 👋 Here's a guide to help you understand the minimum specs needed for running different LLMs. We've broken it down into two main categories: (1) **GPU-enabled** nodes and (2) **CPU-only** nodes, as you can run your nodes on machines both _with_ or _without_ GPU. - -- ## 🖥️ GPU-Enabled Nodes - -### RTX3090 Single GPU: - -| Model | TPS | -| ----------------------------------- | -------- | -| finalend/hermes-3-llama-3.1:8b-q8_0 | 76.4388 | -| phi3:14b-medium-4k-instruct-q4_1 | 75.6148 | -| phi3:14b-medium-128k-instruct-q4_1 | 76.0658 | -| phi3.5:3.8b | 195.0728 | -| phi3.5:3.8b-mini-instruct-fp16 | 88.4656 | -| gemma2:9b-instruct-q8_0 | 56.2726 | -| gemma2:9b-instruct-fp16 | 37.9404 | -| llama3.1:latest | 103.3473 | -| llama3.1:8b-instruct-q8_0 | 78.5861 | -| llama3.1:8b-instruct-fp16 | 50.9302 | -| llama3.1:8b-text-q4_K_M | 104.4776 | -| llama3.1:8b-text-q8_0 | 82.3980 | -| llama3.2:1b | 293.1785 | -| llama3.2:3b | 168.7500 | -| llama3.2:1b-text-q4_K_M | 349.2497 | -| qwen2.5:7b-instruct-q5_0 | 114.0511 | -| qwen2.5:7b-instruct-fp16 | 53.5423 | -| qwen2.5-coder:1.5b | 238.6117 | -| qwen2.5-coder:7b-instruct | 125.2194 | -| qwen2.5-coder:7b-instruct-q8_0 | 83.7696 | -| qwen2.5-coder:7b-instruct-fp16 | 53.7400 | -| qwq | 33.4434 | -| deepseek-coder:6.7b | 141.7769 | -| deepseek-r1:1.5b | 235.8560 | -| deepseek-r1:7b | 121.9637 | -| deepseek-r1:8b | 107.5933 | -| deepseek-r1:14b | 66.5972 | -| deepseek-r1:32b | 34.4669 | -| deepseek-r1 | 120.9809 | -| driaforall/tiny-agent-a:0.5b | 279.2553 | -| driaforall/tiny-agent-a:1.5b | 201.7011 | -| driaforall/tiny-agent-a:3b | 135.1052 | - -### H200 SXM Single GPU: - -| Model | TPS | -| ----------------------------------- | -------- | -| finalend/hermes-3-llama-3.1:8b-q8_0 | 121.2871 | -| phi3:14b-medium-4k-instruct-q4_1 | 128.9496 | -| phi3:14b-medium-128k-instruct-q4_1 | 124.4223 | -| phi3.5:3.8b | 184.3729 | -| phi3.5:3.8b-mini-instruct-fp16 | 155.6164 | -| gemma2:9b-instruct-q8_0 | 91.6370 | -| gemma2:9b-instruct-fp16 | 85.6672 | -| llama3.1:latest | 123.8938 | -| llama3.1:8b-instruct-q8_0 | 112.3102 | -| llama3.1:8b-instruct-fp16 | 108.9053 | -| llama3.1:8b-text-q4_K_M | 148.0687 | -| llama3.1:8b-text-q8_0 | 135.3251 | -| llama3.1:70b-instruct-q4_0 | 47.0107 | -| llama3.1:70b-instruct-q8_0 | 35.2827 | -| llama3.2:1b | 163.9058 | -| llama3.2:3b | 150.6063 | -| llama3.3:70b | 39.1993 | -| llama3.2:1b-text-q4_K_M | 233.6957 | -| qwen2.5:7b-instruct-q5_0 | 126.5432 | -| qwen2.5:7b-instruct-fp16 | 103.8552 | -| qwen2.5:32b-instruct-fp16 | 40.3735 | -| qwen2.5-coder:1.5b | 187.3554 | -| qwen2.5-coder:7b-instruct | 119.7279 | -| qwen2.5-coder:7b-instruct-q8_0 | 108.9536 | -| qwen2.5-coder:7b-instruct-fp16 | 104.0222 | -| qwq | 59.4734 | -| deepseek-coder:6.7b | 136.8015 | -| mixtral:8x7b | 94.9618 | -| deepseek-r1:1.5b | 160.8217 | -| deepseek-r1:7b | 141.2172 | -| deepseek-r1:8b | 136.8324 | -| deepseek-r1:14b | 90.3022 | -| deepseek-r1:32b | 63.1900 | -| deepseek-r1:70b | 39.4153 | -| deepseek-r1 | 121.8406 | -| driaforall/tiny-agent-a:0.5b | 148.5390 | -| driaforall/tiny-agent-a:1.5b | 180.9409 | -| driaforall/tiny-agent-a:3b | 111.1869 | - -- ## 💻 CPU-Only Nodes - -For those running without a GPU, we've got you covered too! Here are the specs for different CPU types: - -### AMD (8 CPU, 16GB RAM) - -| Model | TPS | -| ---------------------------- | ------- | -| llama3.2:1b | 22.6293 | -| llama3.2:1b-text-q4_K_M | 25.0413 | -| qwen2.5-coder:1.5b | 21.7418 | -| deepseek-r1:1.5b | 29.7842 | -| driaforall/tiny-agent-a:0.5b | 54.5455 | -| driaforall/tiny-agent-a:1.5b | 19.9501 | - -### AMD (16 CPU, 32GB RAM) - -| Model | TPS | -| ---------------------------- | ------- | -| phi3.5:3.8b | 15.3677 | -| llama3.2:1b | 25.6367 | -| llama3.2:3b | 16.3185 | -| llama3.2:1b-text-q4_K_M | 38.0039 | -| qwen2.5-coder:1.5b | 30.3651 | -| deepseek-r1:1.5b | 30.2977 | -| driaforall/tiny-agent-a:0.5b | 61.2553 | -| driaforall/tiny-agent-a:1.5b | 25.7011 | - -### AMD (32 CPU, 64GB RAM) - -| Model | TPS | -| ---------------------------- | ------- | -| phi3.5:3.8b | 22.9944 | -| llama3.2:1b | 40.6091 | -| llama3.2:3b | 26.0240 | -| llama3.2:1b-text-q4_K_M | 56.2027 | -| qwen2.5-coder:1.5b | 44.6331 | -| deepseek-coder:6.7b | 15.1620 | -| deepseek-r1:1.5b | 43.8323 | -| driaforall/tiny-agent-a:0.5b | 59.9854 | -| driaforall/tiny-agent-a:1.5b | 27.7891 | - -### AMD (48 CPU, 96GB RAM) - -| Model | TPS | -| ---------------------------- | ------- | -| phi3.5:3.8b | 29.7455 | -| llama3.1:latest | 17.4744 | -| llama3.1:8b-text-q4_K_M | 18.1928 | -| llama3.2:1b | 49.1555 | -| llama3.2:3b | 33.9283 | -| llama3.2:1b-text-q4_K_M | 72.7273 | -| qwen2.5:7b-instruct-q5_0 | 17.0779 | -| qwen2.5-coder:1.5b | 56.2710 | -| qwen2.5-coder:7b-instruct | 18.2935 | -| deepseek-coder:6.7b | 21.2014 | -| deepseek-r1:1.5b | 55.0080 | -| deepseek-r1:7b | 18.0150 | -| deepseek-r1:8b | 16.4574 | -| deepseek-r1 | 18.0991 | -| driaforall/tiny-agent-a:0.5b | 86.2903 | -| driaforall/tiny-agent-a:1.5b | 41.6198 | -| driaforall/tiny-agent-a:3b | 24.1364 | - -### AMD (64 CPU, 128GB RAM) - -| Model | TPS | -| ---------------------------- | ------- | -| phi3.5:3.8b | 33.8993 | -| llama3.1:latest | 19.3015 | -| llama3.1:8b-text-q4_K_M | 19.9081 | -| llama3.2:1b | 55.6815 | -| llama3.2:3b | 36.6654 | -| llama3.2:1b-text-q4_K_M | 68.9655 | -| qwen2.5:7b-instruct-q5_0 | 18.0591 | -| qwen2.5-coder:1.5b | 56.7301 | -| qwen2.5-coder:7b-instruct | 20.1563 | -| deepseek-coder:6.7b | 23.4261 | -| deepseek-r1:1.5b | 57.0494 | -| deepseek-r1:7b | 20.3577 | -| deepseek-r1:8b | 18.6653 | -| deepseek-r1 | 20.2571 | -| driaforall/tiny-agent-a:0.5b | 94.6503 | -| driaforall/tiny-agent-a:1.5b | 49.5431 | -| driaforall/tiny-agent-a:3b | 27.1564 | - -### AMD (96 CPU, 192GB RAM) - -| Model | TPS | -| ---------------------------- | ------- | -| phi3.5:3.8b | 34.1058 | -| llama3.1:latest | 20.2221 | -| llama3.1:8b-text-q4_K_M | 20.1473 | -| llama3.2:1b | 54.5232 | -| llama3.2:3b | 37.6344 | -| llama3.2:1b-text-q4_K_M | 65.7570 | -| qwen2.5:7b-instruct-q5_0 | 20.2058 | -| qwen2.5-coder:1.5b | 55.4435 | -| qwen2.5-coder:7b-instruct | 21.3058 | -| deepseek-coder:6.7b | 24.6414 | -| deepseek-r1:1.5b | 54.3133 | -| deepseek-r1:7b | 20.8902 | -| deepseek-r1:8b | 18.7142 | -| deepseek-r1 | 22.1564 | -| driaforall/tiny-agent-a:0.5b | 94.7864 | -| driaforall/tiny-agent-a:1.5b | 50.7868 | -| driaforall/tiny-agent-a:3b | 29.4635 | - -### AMD (192 CPU, 384GB RAM) - -| Model | TPS | -| ----------------------------------- | ------- | -| finalend/hermes-3-llama-3.1:8b-q8_0 | 16.8002 | -| phi3.5:3.8b | 26.2855 | -| phi3.5:3.8b-mini-instruct-fp16 | 16.7343 | -| llama3.1:latest | 21.9456 | -| llama3.1:8b-instruct-q8_0 | 16.7135 | -| llama3.1:8b-text-q4_K_M | 22.5764 | -| llama3.1:8b-text-q8_0 | 16.3817 | -| llama3.2:1b | 43.5632 | -| llama3.2:3b | 29.5560 | -| llama3.2:1b-text-q4_K_M | 48.6348 | -| qwen2.5:7b-instruct-q5_0 | 21.4938 | -| qwen2.5-coder:1.5b | 33.3333 | -| qwen2.5-coder:7b-instruct | 21.7933 | -| qwen2.5-coder:7b-instruct-q8_0 | 17.8134 | -| deepseek-coder:6.7b | 23.4474 | -| deepseek-r1:1.5b | 32.7795 | -| deepseek-r1:7b | 22.5376 | -| deepseek-r1:8b | 20.3057 | -| deepseek-r1 | 23.0604 | -| driaforall/tiny-agent-a:0.5b | 42.1866 | -| driaforall/tiny-agent-a:1.5b | 33.4957 | -| driaforall/tiny-agent-a:3b | 24.5138 | - -### ARM (192 CPU, 384GB RAM) - -| Model | TPS | -| ---------------------------- | ------- | -| phi3.5:3.8b | 26.3062 | -| llama3.1:latest | 18.9597 | -| llama3.1:8b-text-q4_K_M | 18.2489 | -| llama3.2:1b | 43.7856 | -| llama3.2:3b | 30.3443 | -| llama3.2:1b-text-q4_K_M | 49.6852 | -| qwen2.5:7b-instruct-q5_0 | 16.8128 | -| qwen2.5-coder:1.5b | 38.3562 | -| qwen2.5-coder:7b-instruct | 19.5582 | -| deepseek-coder:6.7b | 21.2699 | -| deepseek-r1:1.5b | 36.0020 | -| deepseek-r1:7b | 19.5293 | -| deepseek-r1:8b | 18.5300 | -| deepseek-r1 | 18.9405 | -| driaforall/tiny-agent-a:0.5b | 28.4991 | -| driaforall/tiny-agent-a:1.5b | 31.6353 | -| driaforall/tiny-agent-a:3b | 22.2788 | - -## 📝 Notes - -- CPU usage can vary significantly between tasks, especially for long context vs. multiple steps. - -- Some models may require more than the available CPU cores, which could lead to slower performance. - -- RAM usage is generally consistent but can spike for certain operations. - -- **Important**: Lower CPU count results in lower performance. Systems with fewer CPUs will process requests more slowly, especially for models that require more CPU resources than are available. - -Remember, these are minimum specs, and your experience may vary depending on the specific tasks and workload. Happy node running! 🎉 diff --git a/executor/Cargo.toml b/executor/Cargo.toml deleted file mode 100644 index 12f12c42..00000000 --- a/executor/Cargo.toml +++ /dev/null @@ -1,36 +0,0 @@ -[package] -name = "dkn-executor" -version.workspace = true -edition.workspace = true -license.workspace = true -readme = "README.md" -authors = ["Erhan Tezcan "] - - -[dependencies] -env_logger.workspace = true - -# async stuff -tokio-util.workspace = true -tokio.workspace = true - -# serialize & deserialize -serde.workspace = true -serde_json.workspace = true - -# http & networking -reqwest.workspace = true - -# logging & errors -log.workspace = true -eyre.workspace = true -thiserror.workspace = true - -enum-iterator = "2.1.0" -rig-core = "0.11.1" -ollama-rs = { version = "0.3.0", features = ["tokio", "rustls", "stream"] } -dkn-utils = { path = "../utils" } - -[dev-dependencies] -# only used for tests -dotenvy.workspace = true diff --git a/executor/README.md b/executor/README.md deleted file mode 100644 index 69f3199e..00000000 --- a/executor/README.md +++ /dev/null @@ -1,20 +0,0 @@ -# Dria Executor - -## Installation - -Add the package via `git` within your Cargo dependencies: - -```toml -dkn-executor = { git = "https://github.com/firstbatchxyz/dkn-compute-node" } -``` - -## Usage - -Dria Executor makes use of several environment variables, with respect to several model providers. - -- `OLLAMA_HOST` is used to connect to **Ollama** server -- `OLLAMA_PORT` is used to connect to **Ollama** server -- `OLLAMA_AUTO_PULL` indicates whether we should pull missing models automatically or not -- `OPENAI_API_KEY` is used for **OpenAI** requests -- `GEMINI_API_KEY` is used for **Gemini** requests -- `OPENROUTER_API_KEY` is used for **OpenRouter** requests. diff --git a/executor/examples/ollama.rs b/executor/examples/ollama.rs deleted file mode 100644 index cec6200a..00000000 --- a/executor/examples/ollama.rs +++ /dev/null @@ -1,19 +0,0 @@ -use dkn_executor::{DriaExecutorsManager, Model}; - -#[tokio::main] -async fn main() -> eyre::Result<()> { - dotenvy::dotenv().ok(); - - let model = Model::Llama3_2_1bInstructQ4Km; - let models = vec![model]; - let mut config = DriaExecutorsManager::new_from_env_for_models(models.into_iter())?; - config.check_services().await; - assert!(config.models.contains(&model)); - - let task = dkn_executor::TaskBody::new_prompt("Write a haiku about category theory.", model); - let executor = config.get_executor(&task.model).await?; - let result = executor.execute(task).await?; - - println!("{}", result); - Ok(()) -} diff --git a/executor/src/executors/gemini.rs b/executor/src/executors/gemini.rs deleted file mode 100644 index fe77fadd..00000000 --- a/executor/src/executors/gemini.rs +++ /dev/null @@ -1,178 +0,0 @@ -use dkn_utils::payloads::SpecModelPerformance; -use eyre::{eyre, Context, Result}; -use reqwest::Client; -use rig::{ - completion::{Chat, PromptError}, - providers::gemini, -}; -use serde::Deserialize; -use std::collections::{HashMap, HashSet}; - -use crate::{Model, TaskBody}; - -/// OpenAI-specific configurations. -#[derive(Clone)] -pub struct GeminiClient { - api_key: String, - client: gemini::Client, -} - -impl GeminiClient { - /// Looks at the environment variables for Gemini API key. - pub fn new(api_key: &str) -> Self { - Self { - api_key: api_key.to_string(), - client: gemini::Client::new(api_key), - } - } - - /// Creates a new client using the API key in `GEMINI_API_KEY` environment variable. - pub fn from_env() -> Result { - let api_key = std::env::var("GEMINI_API_KEY")?; - Ok(Self::new(&api_key)) - } - - pub async fn execute(&self, task: TaskBody) -> Result { - let mut model = self.client.agent(&task.model.to_string()); - if let Some(preamble) = task.preamble { - model = model.preamble(&preamble); - } - - let agent = model.build(); - - agent.chat(task.prompt, task.chat_history).await - } - - /// Check if requested models exist & are available in the OpenAI account. - pub async fn check( - &self, - models: &mut HashSet, - ) -> Result> { - let mut models_to_remove = Vec::new(); - let mut model_performances = HashMap::new(); - log::info!("Checking Gemini requirements"); - - // check if models exist and select those that are available - let gemini_models_names = self.fetch_models().await?; - for requested_model in models.iter().cloned() { - // check if model exists - if !gemini_models_names - .iter() - // due to weird naming of models in Gemini API, we need to check prefix - .any(|model| model.starts_with(&requested_model.to_string())) - { - log::warn!( - "Model {} not found in your Gemini account, ignoring it.", - requested_model - ); - models_to_remove.push(requested_model); - model_performances.insert(requested_model, SpecModelPerformance::NotFound); - continue; - } - - // make a dummy request - if let Err(err) = self - .execute(TaskBody::new_prompt("What is 2 + 2?", requested_model)) - .await - { - log::warn!( - "Model {} failed dummy request, ignoring it: {}", - requested_model, - err - ); - models_to_remove.push(requested_model); - model_performances.insert(requested_model, SpecModelPerformance::ExecutionFailed); - continue; - } - - // record the performance of the model - model_performances.insert(requested_model, SpecModelPerformance::Passed); - } - - // remove models that are not available - for model in models_to_remove.iter() { - models.remove(model); - } - - Ok(model_performances) - } - - /// Returns the list of models available to this account. - /// - /// A gemini model name in API response is given as `models/{baseModelId}-{version}` - /// the model name in Dria can include the version as well, so best bet is to check prefix - /// ignoring the `models/` part. - async fn fetch_models(&self) -> Result> { - /// [Model](https://ai.google.dev/api/models#Model) API object, fields omitted. - #[derive(Debug, Clone, Deserialize)] - struct GeminiModel { - name: String, - // other fields are ignored from API response - } - - #[derive(Debug, Clone, Deserialize)] - struct GeminiModelsResponse { - models: Vec, - } - - // fetch models - let client = Client::new(); - let request = client - // [`models.list`](https://ai.google.dev/api/models#method:-models.list) endpoint - .get("https://generativelanguage.googleapis.com/v1beta/models") - .query(&[("key", &self.api_key)]) - .build() - .wrap_err("failed to build request")?; - - let response = client - .execute(request) - .await - .wrap_err("failed to send request")?; - - // parse response - if response.status().is_client_error() { - return Err(eyre!( - "Failed to fetch Gemini models:\n{}", - response.text().await.unwrap_or_default() - )); - } - let gemini_models = response.json::().await?; - - Ok(gemini_models - .models - .into_iter() - .map(|model| model.name.trim_start_matches("models/").to_string()) - .collect()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[tokio::test] - #[ignore = "requires Gemini API key"] - async fn test_gemini_check() { - let _ = env_logger::builder() - .filter_level(log::LevelFilter::Off) - .filter_module("dkn_executor", log::LevelFilter::Debug) - .is_test(true) - .try_init(); - let _ = dotenvy::dotenv(); // read api key - - let initial_models = [Model::Gemini2_0Flash, Model::Gemini2_5ProExp]; - let mut models = HashSet::from_iter(initial_models); - GeminiClient::from_env() - .unwrap() - .check(&mut models) - .await - .unwrap(); - assert_eq!(models.len(), initial_models.len()); - - // should give error for bad API key - let res = GeminiClient::new("i-dont-work") - .check(&mut HashSet::new()) - .await; - assert!(res.is_err()); - } -} diff --git a/executor/src/executors/mod.rs b/executor/src/executors/mod.rs deleted file mode 100644 index efa1baf6..00000000 --- a/executor/src/executors/mod.rs +++ /dev/null @@ -1,71 +0,0 @@ -use crate::{Model, ModelProvider, TaskBody}; -use dkn_utils::payloads::SpecModelPerformance; -use rig::completion::PromptError; -use std::collections::{HashMap, HashSet}; - -mod ollama; -use ollama::OllamaClient; - -// mod openai; -// use openai::OpenAIClient; - -// mod gemini; -// use gemini::GeminiClient; - -// mod openrouter; -// use openrouter::OpenRouterClient; - -/// A wrapper enum for all model providers. -#[derive(Clone)] -pub enum DriaExecutor { - Ollama(OllamaClient), - // OpenAI(OpenAIClient), - // Gemini(GeminiClient), - // OpenRouter(OpenRouterClient), -} - -impl DriaExecutor { - /// Creates a new executor for the given provider using the API key in the environment variables. - pub fn new_from_env(provider: ModelProvider) -> Result { - match provider { - ModelProvider::Ollama => OllamaClient::from_env().map(DriaExecutor::Ollama), - // ModelProvider::OpenAI => OpenAIClient::from_env().map(DriaExecutor::OpenAI), - // ModelProvider::Gemini => GeminiClient::from_env().map(DriaExecutor::Gemini), - // ModelProvider::OpenRouter => OpenRouterClient::from_env().map(DriaExecutor::OpenRouter), - } - } - - /// Executes the given task using the appropriate provider. - pub async fn execute(&self, task: TaskBody) -> Result { - match self { - DriaExecutor::Ollama(provider) => provider.execute(task).await, - // DriaExecutor::OpenAI(provider) => provider.execute(task).await, - // DriaExecutor::Gemini(provider) => provider.execute(task).await, - // DriaExecutor::OpenRouter(provider) => provider.execute(task).await, - } - } - - /// Checks if the requested models exist and are available in the provider's account. - /// - /// For Ollama in particular, it also checks if the models are performant enough. - pub async fn check( - &self, - models: &mut HashSet, - ) -> eyre::Result> { - match self { - DriaExecutor::Ollama(provider) => provider.check(models).await, - // DriaExecutor::OpenAI(provider) => provider.check(models).await, - // DriaExecutor::Gemini(provider) => provider.check(models).await, - // DriaExecutor::OpenRouter(provider) => provider.check(models).await, - } - } - - pub fn name(&self) -> String { - match self { - DriaExecutor::Ollama(_) => ModelProvider::Ollama.to_string(), - // DriaExecutor::OpenAI(_) => ModelProvider::OpenAI.to_string(), - // DriaExecutor::Gemini(_) => ModelProvider::Gemini.to_string(), - // DriaExecutor::OpenRouter(_) => ModelProvider::OpenRouter.to_string(), - } - } -} diff --git a/executor/src/executors/ollama.rs b/executor/src/executors/ollama.rs deleted file mode 100644 index 766099d8..00000000 --- a/executor/src/executors/ollama.rs +++ /dev/null @@ -1,253 +0,0 @@ -use dkn_utils::payloads::SpecModelPerformance; -use eyre::{Context, Result}; -use ollama_rs::generation::completion::request::GenerationRequest; -use rig::completion::{Chat, PromptError}; -use rig::providers::ollama; -use std::collections::HashMap; -use std::time::Duration; -use std::{collections::HashSet, env}; - -use crate::{Model, TaskBody}; - -const DEFAULT_OLLAMA_HOST: &str = "http://127.0.0.1"; -const DEFAULT_OLLAMA_PORT: u16 = 11434; - -/// Timeout duration for checking model performance during a generation. -const PERFORMANCE_TIMEOUT: Duration = Duration::from_secs(120); -/// Minimum tokens per second (TPS) for checking model performance during a generation. -const PERFORMANCE_MIN_TPS: f64 = 10.0; - -/// Ollama-specific configurations. -#[derive(Clone)] -pub struct OllamaClient { - /// Whether to automatically pull models from Ollama. - auto_pull: bool, - /// Underlying Ollama client. - client: ollama::Client, - /// A more specialized Ollama client. - /// - /// - Can do pulls - /// - Can list local models - ollama_rs_client: ollama_rs::Ollama, -} - -impl OllamaClient { - /// Creates a new Ollama client using the host and port. - pub fn new(host: &str, port: u16, auto_pull: bool) -> Self { - Self { - auto_pull, - ollama_rs_client: ollama_rs::Ollama::new(host, port), - client: ollama::Client::from_url(&format!("{host}:{port}",)), - } - } - - /// Looks at the environment variables for Ollama host and port. - /// - /// If not found, defaults to `DEFAULT_OLLAMA_HOST` and `DEFAULT_OLLAMA_PORT`. - /// - /// Returns a `Result` to be compatible with other executors. - pub fn from_env() -> Result { - let host = env::var("OLLAMA_HOST") - .map(|h| h.trim_matches('"').to_string()) - .unwrap_or(DEFAULT_OLLAMA_HOST.to_string()); - let port = env::var("OLLAMA_PORT") - .and_then(|port_str| port_str.parse().map_err(|_| std::env::VarError::NotPresent)) - .unwrap_or(DEFAULT_OLLAMA_PORT); - - // auto-pull, its true by default - let auto_pull = env::var("OLLAMA_AUTO_PULL") - .map(|s| s == "true") - .unwrap_or(true); - - Ok(Self::new(&host, port, auto_pull)) - } - - /// Sets the auto-pull flag for Ollama models. - pub fn with_auto_pull(mut self, auto_pull: bool) -> Self { - self.auto_pull = auto_pull; - self - } - - pub async fn execute(&self, task: TaskBody) -> Result { - let mut model = self.client.agent(&task.model.to_string()); - if let Some(preamble) = task.preamble { - model = model.preamble(&preamble); - } - - let agent = model.build(); - - agent.chat(task.prompt, task.chat_history).await - } - - /// Check if requested models exist in Ollama & test them using a dummy prompt. - pub async fn check( - &self, - models: &mut HashSet, - ) -> Result> { - log::info!( - "Checking Ollama requirements ({}, timeout: {}s, min tps: {})", - if self.auto_pull { - "auto-pull enabled" - } else { - "auto-pull disabled" - }, - PERFORMANCE_TIMEOUT.as_secs(), - PERFORMANCE_MIN_TPS - ); - - // fetch local models - let local_models = match self.ollama_rs_client.list_local_models().await { - Ok(models) => models.into_iter().map(|m| m.name).collect::>(), - Err(e) => { - return { - log::error!("Could not fetch local models from Ollama, is it online?"); - Err(e.into()) - } - } - }; - log::info!("Found local Ollama models: {local_models:#?}"); - - // check external models & pull them if available - // iterate over models and remove bad ones - let mut models_to_remove = Vec::new(); - let mut model_performances = HashMap::new(); - for model in models.iter() { - // pull the model if it is not in the local models - if !local_models.contains(&model.to_string()) { - log::warn!("Model {model} not found in Ollama"); - if self.auto_pull { - self.try_pull(model) - .await - .wrap_err("could not pull model")?; - } else { - log::error!("Please download missing model with: ollama pull {model}"); - log::error!("Or, set OLLAMA_AUTO_PULL=true to pull automatically."); - eyre::bail!("required model not pulled in Ollama"); - } - } - - // test its performance - let perf = self.measure_tps_with_warmup(model).await; - if let SpecModelPerformance::PassedWithTPS(_) = perf { - model_performances.insert(*model, perf); - } else { - // if its anything but PassedWithTPS, remove the model - models_to_remove.push(*model); - model_performances.insert(*model, perf); - } - } - - // remove failed models - for model in models_to_remove { - models.remove(&model); - } - - if models.is_empty() { - log::warn!("No Ollama models passed the performance test! Try using a more powerful machine OR smaller models."); - } else { - log::info!("Ollama checks are finished, using models: {models:#?}"); - } - - Ok(model_performances) - } - - /// Pulls a model from Ollama. - async fn try_pull(&self, model: &Model) -> Result { - // TODO: add pull-bar here - // if auto-pull is enabled, pull the model - log::info!("Downloading missing model {model} (this may take a while)"); - self.ollama_rs_client - .pull_model(model.to_string(), false) - .await - .wrap_err("could not pull model") - } - - /// Runs a small test to test local model performance. - /// - /// This is to see if a given system can execute tasks for their chosen models, - /// e.g. if they have enough RAM/CPU and such. - pub async fn measure_tps_with_warmup(&self, model: &Model) -> SpecModelPerformance { - const TEST_PROMPT: &str = "Please write a poem about Kapadokya."; - const WARMUP_PROMPT: &str = "Write a short poem about hedgehogs and squirrels."; - - log::info!("Measuring {model}"); - - // run a dummy generation for warm-up - log::debug!("Warming up Ollama for {model}"); - if let Err(err) = self - .ollama_rs_client - .generate(GenerationRequest::new( - model.to_string(), - WARMUP_PROMPT.to_string(), - )) - .await - { - log::warn!("Ignoring {model}: {err}"); - return SpecModelPerformance::ExecutionFailed; - } - - // then, run a sample generation with timeout and measure tps - let Ok(result) = tokio::time::timeout( - PERFORMANCE_TIMEOUT, - self.ollama_rs_client.generate(GenerationRequest::new( - model.to_string(), - TEST_PROMPT.to_string(), - )), - ) - .await - else { - log::warn!("Ignoring {model}: Timed out"); - return SpecModelPerformance::Timeout; - }; - - // check the result - match result { - Ok(response) => { - let tps = (response.eval_count.unwrap_or_default() as f64) - / (response.eval_duration.unwrap_or(1) as f64) - * 1_000_000_000f64; - - if tps >= PERFORMANCE_MIN_TPS { - log::info!("{model} passed the test with tps: {tps}"); - SpecModelPerformance::PassedWithTPS(tps) - } else { - log::warn!( - "Ignoring {model}: tps too low ({tps:.3} < {PERFORMANCE_MIN_TPS:.3})" - ); - SpecModelPerformance::FailedWithTPS(tps) - } - } - Err(err) => { - log::warn!("Ignoring {model} due to: {err}"); - SpecModelPerformance::ExecutionFailed - } - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[tokio::test] - #[ignore = "requires Ollama"] - async fn test_ollama_prompt() { - let client = OllamaClient::from_env().unwrap(); - let model = Model::Llama3_2_1bInstructQ4Km; - - let stats = client.try_pull(&model).await.unwrap(); - println!("Model {}: {:#?}", model, stats); - let prompt = "The sky appears blue during the day because of a process called scattering. \ - When sunlight enters the Earth's atmosphere, it collides with air molecules such as oxygen and nitrogen. \ - These collisions cause some of the light to be absorbed or reflected, which makes the colors we see appear more vivid and vibrant. \ - Blue is one of the brightest colors that is scattered the most by the atmosphere, making it visible to our eyes during the day. \ - What may be the question this answer?".to_string(); - - let response = client - .execute(TaskBody::new_prompt(&prompt, model)) - .await - .unwrap(); - - println!("Prompt: {}\n\nResponse:{}", prompt, response); - } -} diff --git a/executor/src/executors/openai.rs b/executor/src/executors/openai.rs deleted file mode 100644 index cb98b5e1..00000000 --- a/executor/src/executors/openai.rs +++ /dev/null @@ -1,172 +0,0 @@ -use std::collections::{HashMap, HashSet}; - -use dkn_utils::payloads::SpecModelPerformance; -use eyre::{eyre, Context, Result}; -use reqwest::Client; -use rig::{ - completion::{Chat, PromptError}, - providers::openai, -}; -use serde::Deserialize; - -use crate::{Model, TaskBody}; - -/// OpenAI-specific configurations. -#[derive(Clone)] -pub struct OpenAIClient { - /// API key, if available. - api_key: String, - /// Underlying OpenAI client from [`rig`]. - client: openai::Client, -} - -impl OpenAIClient { - /// Looks at the environment variables for OpenAI API key. - pub fn new(api_key: &str) -> Self { - Self { - api_key: api_key.to_string(), - client: openai::Client::new(api_key), - } - } - - /// Creates a new OpenAI client using the API key in `OPENAI_API_KEY` environment variable. - pub fn from_env() -> Result { - let api_key = std::env::var("OPENAI_API_KEY")?; - Ok(Self::new(&api_key)) - } - - pub async fn execute(&self, task: TaskBody) -> Result { - let mut model = self.client.agent(&task.model.to_string()); - if let Some(preamble) = task.preamble { - model = model.preamble(&preamble); - } - - let agent = model.build(); - - agent.chat(task.prompt, task.chat_history).await - } - - /// Returns the list of model names available to this account. - pub async fn check( - &self, - models: &mut HashSet, - ) -> Result> { - let mut models_to_remove = Vec::new(); - let mut model_performances = HashMap::new(); - log::info!("Checking OpenAI requirements"); - - // check if models exist within the account and select those that are available - let openai_model_names = self.fetch_models().await?; - for model in models.iter().cloned() { - // check if model exists - if !openai_model_names.contains(&model.to_string()) { - log::warn!( - "Model {} not found in your OpenAI account, ignoring it.", - model - ); - models_to_remove.push(model); - model_performances.insert(model, SpecModelPerformance::NotFound); - continue; - } - - // if it exists, make a dummy request - if let Err(err) = self - .execute(TaskBody::new_prompt("What is 2 + 2?", model)) - .await - { - log::warn!("Model {} failed dummy request, ignoring it: {}", model, err); - models_to_remove.push(model); - model_performances.insert(model, SpecModelPerformance::ExecutionFailed); - continue; - } - - // record the performance of the model - model_performances.insert(model, SpecModelPerformance::Passed); - } - - // remove models that are not available - for model in models_to_remove.iter() { - models.remove(model); - } - - // log results - if models.is_empty() { - log::warn!("OpenAI checks are finished, no available models found.",); - } else { - log::info!("OpenAI checks are finished, using models: {:#?}", models); - } - - Ok(model_performances) - } - - /// Fetches the list of models available in the OpenAI account. - async fn fetch_models(&self) -> Result> { - /// [Model](https://platform.openai.com/docs/api-reference/models/object) API object, fields omitted. - #[derive(Debug, Clone, Deserialize)] - struct OpenAIModel { - /// The model identifier, which can be referenced in the API endpoints. - id: String, - } - - #[derive(Debug, Clone, Deserialize)] - struct OpenAIModelsResponse { - data: Vec, - } - - let client = Client::new(); - let request = client - .get("https://api.openai.com/v1/models") - .header("Authorization", format!("Bearer {}", self.api_key)) - .build() - .wrap_err("failed to build request")?; - - let response = client - .execute(request) - .await - .wrap_err("failed to send request")?; - - // parse response - if !response.status().is_success() { - Err(eyre!( - "Failed to fetch OpenAI models:\n{}", - response - .text() - .await - .unwrap_or("could not get error text as well".to_string()) - )) - } else { - let openai_models = response.json::().await?; - Ok(openai_models.data.into_iter().map(|m| m.id).collect()) - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[tokio::test] - #[ignore = "requires OpenAI API key"] - async fn test_openai_check() { - let _ = env_logger::builder() - .filter_level(log::LevelFilter::Off) - .filter_module("dkn_executor", log::LevelFilter::Debug) - .is_test(true) - .try_init(); - let _ = dotenvy::dotenv(); // read api key - - let initial_models = [Model::GPT4o, Model::GPT4oMini]; - let mut models = HashSet::from_iter(initial_models); - OpenAIClient::from_env() - .unwrap() - .check(&mut models) - .await - .unwrap(); - assert_eq!(models.len(), initial_models.len()); - - let res = OpenAIClient::new("i-dont-work") - .check(&mut Default::default()) - .await; - assert!(res.is_err()); - } -} diff --git a/executor/src/executors/openrouter.rs b/executor/src/executors/openrouter.rs deleted file mode 100644 index d4fc3c5e..00000000 --- a/executor/src/executors/openrouter.rs +++ /dev/null @@ -1,98 +0,0 @@ -use std::collections::{HashMap, HashSet}; - -use dkn_utils::payloads::SpecModelPerformance; -use eyre::Result; -use rig::completion::{Chat, PromptError}; -use rig::providers::openrouter; - -use crate::{Model, TaskBody}; - -/// OpenRouter-specific configurations. -#[derive(Clone)] -pub struct OpenRouterClient { - client: openrouter::Client, -} - -impl OpenRouterClient { - /// Looks at the environment variables for OpenRouter API key. - pub fn new(api_key: &str) -> Self { - Self { - client: openrouter::Client::new(api_key), - } - } - - /// Creates a new client using the API key in `OPENROUTER_API_KEY` environment variable. - pub fn from_env() -> Result { - let api_key = std::env::var("OPENROUTER_API_KEY")?; - Ok(Self::new(&api_key)) - } - - pub async fn execute(&self, task: TaskBody) -> Result { - let mut model = self.client.agent(&task.model.to_string()); - if let Some(preamble) = task.preamble { - model = model.preamble(&preamble); - } - - let agent = model.build(); - agent.chat(task.prompt, task.chat_history).await - } - - /// Checks if the API key exists. - pub async fn check( - &self, - models: &mut HashSet, - ) -> Result> { - let mut models_to_remove = Vec::new(); - let mut model_performances = HashMap::new(); - log::info!("Checking OpenRouter API key"); - - // make a dummy request with existing models - for model in models.iter().cloned() { - if let Err(err) = self - .execute(TaskBody::new_prompt("What is 2 + 2?", model)) - .await - { - log::warn!("Model {} failed dummy request, ignoring it: {}", model, err); - models_to_remove.push(model); - model_performances.insert(model, SpecModelPerformance::ExecutionFailed); - continue; - } - - // record the model performance - model_performances.insert(model, SpecModelPerformance::Passed); - } - - // remove models that failed the dummy request - for model in models_to_remove.iter() { - models.remove(model); - } - - Ok(model_performances) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[tokio::test] - #[ignore = "requires OpenRouter API key"] - async fn test_openrouter_check() { - let _ = env_logger::builder() - .filter_level(log::LevelFilter::Off) - .filter_module("dkn_executor", log::LevelFilter::Debug) - .is_test(true) - .try_init(); - let _ = dotenvy::dotenv(); // read api key - - let initial_models = [Model::OR3_5Sonnet, Model::OR3_7Sonnet]; - let mut models = HashSet::from_iter(initial_models); - let config = OpenRouterClient::from_env().unwrap(); - config.check(&mut models).await.unwrap(); - assert_eq!(models.len(), initial_models.len()); - - // create with a bad api key - let config = OpenRouterClient::new("i-dont-work"); - config.check(&mut HashSet::new()).await.unwrap(); // should not panic - } -} diff --git a/executor/src/lib.rs b/executor/src/lib.rs deleted file mode 100644 index 4dd16bcd..00000000 --- a/executor/src/lib.rs +++ /dev/null @@ -1,17 +0,0 @@ -mod executors; -pub use executors::DriaExecutor; - -mod manager; -pub use manager::DriaExecutorsManager; - -mod models; -pub use models::{Model, ModelProvider}; - -mod task; -pub use task::{TaskBody, TaskResult}; - -pub use rig::completion::CompletionModel; -pub use rig::completion::{CompletionError, PromptError}; - -// re-export ollama_rs -pub use ollama_rs; diff --git a/executor/src/manager.rs b/executor/src/manager.rs deleted file mode 100644 index 01b5e2d2..00000000 --- a/executor/src/manager.rs +++ /dev/null @@ -1,143 +0,0 @@ -use dkn_utils::payloads::SpecModelPerformance; - -use crate::{executors::DriaExecutor, Model, ModelProvider}; -use std::collections::{HashMap, HashSet}; - -#[derive(Clone)] -pub struct DriaExecutorsManager { - /// List of all models supported by this node. - /// - /// Equivalent to the union of all sets of models in the providers. - pub models: HashSet, - /// Providers and their executors along with the models they support. - pub providers: HashMap)>, -} - -impl DriaExecutorsManager { - /// Creates a new executor manager with the given models, using environment variables for the providers. - /// - /// If a provider is required (as per the chosen model) but its environment variables are missing, - /// this will return an error. - pub fn new_from_env_for_models( - models: impl Iterator, - ) -> Result { - let mut provider_set: HashMap)> = - HashMap::new(); - let mut model_set = HashSet::new(); - for model in models { - // get the provider for the model - let provider = model.provider(); - - // add model to the provider set, and create a new executor if needed - match provider_set.get_mut(&provider) { - Some((_, models)) => { - models.insert(model); - } - None => { - // create a new executor for the provider, may return an error! - match DriaExecutor::new_from_env(provider) { - Ok(executor) => { - provider_set.insert(provider, (executor, HashSet::from_iter([model]))); - } - Err(err) => { - log::error!( - "Failed to create executor for {provider}: {err}, {model} will not be supported.", - ); - continue; // skip this model if the executor creation failed - } - } - } - } - - // add the model to the global model set - model_set.insert(model); - } - - Ok(Self { - providers: provider_set, - models: model_set, - }) - } - - /// Given the model, returns a _cloned_ executor for it. - /// - /// If the model's provider is not supported, an error is returned. - /// Likewise, if the provider is supported but the model is not, an error is returned. - pub async fn get_executor(&self, model: &Model) -> eyre::Result { - let provider = model.provider(); - let (executor, models) = self - .providers - .get(&provider) - .ok_or_else(|| eyre::eyre!("Provider {provider} supported by this executor"))?; - - if models.contains(model) { - Ok(executor.clone()) - } else { - Err(eyre::eyre!("Model {model} not supported by this executor")) - } - } - - /// Returns the set of models supported by the given provider for this manager. - /// - /// If there are no models for the provider, an empty set is returned. - pub fn get_models_for_provider(&self, provider: ModelProvider) -> HashSet { - self.providers - .get(&provider) - .map(|(_, models)| models.clone()) - .unwrap_or_default() - } - - /// Returns the names of all models in the manager, in a random order. - pub fn get_model_names(&self) -> Vec { - self.models.iter().map(|m| m.to_string()).collect() - } - - /// Check if the required compute services are running. - /// - /// - If Ollama models are used the task is tested with a simple task with timeout. - /// - If API based models are used, the API key is checked and the models are tested with a dummy request. - /// - /// In the end, bad models are filtered out and we simply check if we are left if any valid models at all. - /// If there are no models left in the end, an error is thrown. - pub async fn check_services(&mut self) -> HashMap { - log::info!("Checking configured services."); - - // check all configured providers & record model performances - let mut model_perf = HashMap::new(); - for (client, models) in self.providers.values_mut() { - if let Ok(provider_model_perf) = client.check(models).await { - model_perf.extend(provider_model_perf); - } else { - log::warn!( - "Provider {} failed to check services, ignoring its models.", - client.name() - ); - model_perf.extend( - models - .iter() - .map(|m| (*m, SpecModelPerformance::ExecutionFailed)), - ); - // clear models - models.clear(); - } - } - - // obtain the final list of providers & models, removing the providers with no models left - self.providers.retain(|provider, (_, models)| { - let ok = !models.is_empty(); - if !ok { - log::warn!("Provider {provider} has no models left, removing it from the config.") - } - ok - }); - - // update the models set - self.models = self - .providers - .values() - .flat_map(|(_, models)| models.iter().cloned()) - .collect(); - - model_perf - } -} diff --git a/executor/src/models.rs b/executor/src/models.rs deleted file mode 100644 index c98160d7..00000000 --- a/executor/src/models.rs +++ /dev/null @@ -1,299 +0,0 @@ -use enum_iterator::Sequence; -use serde::{Deserialize, Serialize}; -use std::{collections::HashSet, fmt, str::FromStr}; - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, Serialize, Sequence)] -pub enum Model { - // Ollama models - /// [Meta's Llama3.1](https://ollama.com/library/llama3.1:8b-instruct-q4_K_M) - #[serde(rename = "llama3.1:8b-instruct-q4_K_M")] - Llama3_1_8bInstructQ4Km, - /// [Meta's LLama3.2](https://ollama.com/library/llama3.2:1b-instruct-q4_K_M) - #[serde(rename = "llama3.2:1b-instruct-q4_K_M")] - Llama3_2_1bInstructQ4Km, - /// [Meta's LLama3.3](https://ollama.com/library/llama3.3:70b-instruct-q4_K_M) - #[serde(rename = "llama3.3:70b-instruct-q4_K_M")] - Llama3_3_70bInstructQ4Km, - /// [Mistral's Nemo](https://ollama.com/library/mistral-nemo:12b) - #[serde(rename = "mistral-nemo:12b")] - MistralNemo12b, - /// [Google's Gemma3 4b](https://ollama.com/library/gemma3:4b) - #[serde(rename = "gemma3:4b")] - Gemma3_4b, - /// [Google's Gemma3 12b](https://ollama.com/library/gemma3:12b) - #[serde(rename = "gemma3:12b")] - Gemma3_12b, - /// [Google's Gemma3 27b](https://ollama.com/library/gemma3:27b) - #[serde(rename = "gemma3:27b")] - Gemma3_27b, - /// [Alibaba's Qwen3 32b](https://ollama.com/library/qwen3:32b) - #[serde(rename = "qwen3:32b")] - Qwen3_32b, - /// [Alibaba's Qwen3 8b](https://ollama.com/library/qwen3:8b) - #[serde(rename = "qwen3:8b")] - Qwen3_8b, - // // OpenAI models - // /// [OpenAI's GPT-4o](https://platform.openai.com/docs/models#gpt-4o) - // #[serde(rename = "gpt-4o")] - // GPT4o, - // /// [OpenAI's GPT-4o mini](https://platform.openai.com/docs/models#gpt-4o-mini) - // #[serde(rename = "gpt-4o-mini")] - // GPT4oMini, - - // // Gemini models - // /// [Google's Gemini 2.5 Pro experimental](https://ai.google.dev/gemini-api/docs/models#gemini-2.5-pro-preview-03-25) - // #[serde(rename = "gemini-2.5-pro-exp-03-25")] - // Gemini2_5ProExp, - // /// [Google's Gemini 2.0 Flash](https://ai.google.dev/gemini-api/docs/models#gemini-2.0-flash) - // #[serde(rename = "gemini-2.0-flash")] - // Gemini2_0Flash, - - // /// OpenRouter Models - // /// [Anthropic's Claude 3.5 Sonnet](https://openrouter.ai/models?q=claude-3.5-sonnet) - // #[serde(rename = "anthropic/claude-3.5-sonnet")] - // OR3_5Sonnet, - // /// [Anthropic's Claude 3.7 Sonnet](https://openrouter.ai/models?q=claude-3.7-sonnet) - // #[serde(rename = "anthropic/claude-3-7-sonnet")] - // OR3_7Sonnet, -} - -impl FromStr for Model { - type Err = String; - - /// Tries to parse the given `str` into a `Model`. - /// On failure, returns the original string back as the `Err` value. - fn from_str(value: &str) -> Result { - // serde requires quotes (for JSON) - serde_json::from_str::(&format!("\"{value}\"")) - .map_err(|err| format!("Model {value} invalid: {err}")) - } -} - -impl Model { - /// Returns a set of models from a CSV string. - /// - /// The input string should be a comma-separated list of model names. - /// - /// ## Example - /// - /// ```rs - /// let models = Model::from_csv("gpt-4o, gpt-4o-mini"); - /// assert!(models.contains(&Model::GPT4o)); - /// assert!(models.contains(&Model::GPT4oMini)); - /// ``` - pub fn from_csv(input: impl AsRef) -> HashSet { - HashSet::from_iter( - input - .as_ref() - .split(',') - .filter_map(|s| Self::try_from(s.trim()).ok()), - ) - } - - /// Returns an iterator over all models. - #[inline(always)] - pub fn all() -> impl Iterator { - enum_iterator::all::() - } - - /// Returns an iterator over all models that belong to a given provider. - #[inline(always)] - pub fn all_with_provider(provider: &ModelProvider) -> impl Iterator + '_ { - enum_iterator::all::().filter(move |m| m.provider() == *provider) - } - - /// Returns the provider that hosts the model. - #[inline] - pub fn provider(&self) -> ModelProvider { - ModelProvider::from(self) - } -} - -impl fmt::Display for Model { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - // guaranteed not to fail because this is enum to string serialization - let self_str = serde_json::to_string(&self).unwrap_or_default(); - // remove quotes from JSON - write!(f, "{}", self_str.trim_matches('"')) - } -} - -impl TryFrom for Model { - type Error = String; - fn try_from(value: String) -> Result { - value.as_str().parse() - } -} - -impl TryFrom<&str> for Model { - type Error = String; - fn try_from(value: &str) -> Result { - value.parse() - } -} - -/// A model provider is a service that hosts the chosen Model. -/// It can be derived from the model name, e.g. GPT4o is hosted by OpenAI (via API), or Phi3 is hosted by Ollama (locally). -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, Serialize, Sequence)] -pub enum ModelProvider { - #[serde(rename = "ollama")] - Ollama, - // #[serde(rename = "openai")] - // OpenAI, - // #[serde(rename = "gemini")] - // Gemini, - // #[serde(rename = "openrouter")] - // OpenRouter, -} - -impl ModelProvider { - /// Returns an iterator over all model providers. - #[inline(always)] - pub fn all() -> impl Iterator { - enum_iterator::all::() - } - - /// Returns all models that belong to the provider. - #[inline] - pub fn models(&self) -> impl Iterator + '_ { - Model::all_with_provider(self) - } - - /// Returns whether the provider is batchable - /// (can be executed concurrently) or not. - pub fn is_batchable(&self) -> bool { - match self { - // ollama models are not batchable - ModelProvider::Ollama => false, - // // api-based providers are batchable - // ModelProvider::OpenAI => true, - // ModelProvider::Gemini => true, - // ModelProvider::OpenRouter => true, - } - } -} - -impl From for ModelProvider { - fn from(value: Model) -> Self { - Self::from(&value) - } -} - -impl From<&Model> for ModelProvider { - fn from(model: &Model) -> Self { - match model { - // ollama - Model::Gemma3_4b => ModelProvider::Ollama, - Model::Gemma3_12b => ModelProvider::Ollama, - Model::Gemma3_27b => ModelProvider::Ollama, - Model::Llama3_1_8bInstructQ4Km => ModelProvider::Ollama, - Model::Llama3_2_1bInstructQ4Km => ModelProvider::Ollama, - Model::Llama3_3_70bInstructQ4Km => ModelProvider::Ollama, - Model::MistralNemo12b => ModelProvider::Ollama, - Model::Qwen3_8b => ModelProvider::Ollama, - Model::Qwen3_32b => ModelProvider::Ollama, - // // openai - // Model::GPT4o => ModelProvider::OpenAI, - // Model::GPT4oMini => ModelProvider::OpenAI, - // // gemini - // Model::Gemini2_0Flash => ModelProvider::Gemini, - // Model::Gemini2_5ProExp => ModelProvider::Gemini, - // // openrouter - // Model::OR3_5Sonnet => ModelProvider::OpenRouter, - // Model::OR3_7Sonnet => ModelProvider::OpenRouter, - } - } -} - -impl FromStr for ModelProvider { - type Err = String; - - /// Tries to parse the given `str` into a `ModelProvider`. - /// On failure, returns the original string back as the `Err` value. - fn from_str(value: &str) -> Result { - // serde requires quotes (for JSON) - serde_json::from_str::(&format!("\"{value}\"")) - .map_err(|err| format!("Model provider {value} invalid: {err}")) - } -} - -impl TryFrom for ModelProvider { - type Error = String; - fn try_from(value: String) -> Result { - value.as_str().parse() - } -} - -impl TryFrom<&str> for ModelProvider { - type Error = String; - fn try_from(value: &str) -> Result { - value.parse() - } -} - -impl fmt::Display for ModelProvider { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - // guaranteed not to fail because this is enum to string serialization - let self_str = serde_json::to_string(&self).unwrap_or_default(); - // remove quotes from JSON - write!(f, "{}", self_str.trim_matches('"')) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_model_string_conversion() { - let model = Model::Gemma3_4b; - - // convert to string - let model_str = model.clone().to_string(); - assert_eq!(model_str, "gemma3:4b"); - - // (try) convert from string - let model_from = Model::try_from(model_str).expect("should convert"); - assert_eq!(model_from, model); - - // (try) convert from string - let model = Model::try_from("this-model-does-not-will-not-exist".to_string()); - assert!(model.is_err()); - } - - #[test] - fn test_model_string_serde() { - let model = Model::Gemma3_12b; - - // serialize to string via serde - let model_str = serde_json::to_string(&model).expect("should serialize"); - assert_eq!(model_str, "\"gemma3:12b\""); - - // deserialize from string via serde - let model_from: Model = serde_json::from_str(&model_str).expect("should deserialize"); - assert_eq!(model_from, model); - - // (try) deserialize from invalid model - let bad_model = serde_json::from_str::("\"this-model-does-not-will-not-exist\""); - assert!(bad_model.is_err()); - } - - #[test] - fn test_provider_string_serde() { - let provider = ModelProvider::Ollama; - - // serialize to string via serde - let provider_str = serde_json::to_string(&provider).expect("should serialize"); - assert_eq!(provider_str, "\"ollama\""); - - // deserialize from string via serde - let provider_from: ModelProvider = - serde_json::from_str(&provider_str).expect("should deserialize"); - assert_eq!(provider_from, provider); - - // (try) deserialize from invalid model - let bad_provider = - serde_json::from_str::("\"this-provider-does-not-will-not-exist\""); - assert!(bad_provider.is_err()); - } -} diff --git a/executor/src/task.rs b/executor/src/task.rs deleted file mode 100644 index 4407be63..00000000 --- a/executor/src/task.rs +++ /dev/null @@ -1,168 +0,0 @@ -use rig::{ - completion::{CompletionRequest, PromptError}, - message::Message, -}; -use serde::{Deserialize, Deserializer}; - -use crate::{Model, ModelProvider}; - -/// A future that represents the result of a task execution, of any provider. -pub type TaskResult = Result; - -/// The body of a task request that includes the messages and the model to use. -/// -/// Implements a custom [`Deserialize`] to convert from an object of the form below to self: -/// -/// ```ts -/// { -/// "model": string, -/// "messages": { role: string, content: string }[] -/// } -/// ``` -/// -/// For the `messages` array, the following rules apply: -/// - If the first message is a system message, it will be stored in the `preamble` field. -/// - The last message must be a user message, and it will be stored in the `prompt` field. -/// - All other intermediate messages will be stored in the `chat_history` field. -#[derive(Debug, Clone)] -pub struct TaskBody { - /// An optional system prompt. - pub preamble: Option, - /// The main user prompt. - pub prompt: Message, - /// List of messages for context or chat history. - pub chat_history: Vec, - /// The model to use for the task. - pub model: Model, -} - -impl TaskBody { - /// Creates a new task body with the given prompt and model. - pub fn new_prompt(prompt: impl Into, model: Model) -> Self { - TaskBody { - preamble: None, - prompt: Message::user(prompt), - chat_history: Vec::default(), - model, - } - } - - /// Returns whether this task can be executed in parallel, w.r.t to its model. - pub fn is_batchable(&self) -> bool { - self.model.provider() != ModelProvider::Ollama - } -} - -impl From for CompletionRequest { - fn from(task_body: TaskBody) -> Self { - CompletionRequest { - prompt: task_body.prompt, - preamble: task_body.preamble, - chat_history: task_body.chat_history, - documents: Vec::default(), - tools: Vec::default(), - temperature: None, - max_tokens: None, - additional_params: None, - } - } -} - -impl<'de> Deserialize<'de> for TaskBody { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - use serde::de::Error; - - #[derive(Deserialize)] - struct RawMessage { - role: String, - content: String, - } - - #[derive(Deserialize)] - struct RawTaskBody { - model: String, - messages: Vec, - } - - let raw = RawTaskBody::deserialize(deserializer)?; - - // parse model - let model = Model::try_from(raw.model).map_err(|err_model| { - Error::custom(format!("Model {err_model} is not supported by this node.")) - })?; - - // ensure there are messages - if raw.messages.is_empty() { - return Err(Error::custom("No messages found in the task body")); - } - - // ensure the last message is from the user - if raw.messages.last().unwrap().role != "user" { - return Err(Error::custom("Last message must be from the user")); - } - - let mut preamble = None; - let mut messages = Vec::new(); - for msg in raw.messages.into_iter() { - match msg.role.as_str() { - "system" => { - // we only expect to see one system message ever - if preamble.is_some() { - return Err(Error::custom("Only one system message is allowed")); - } - preamble = Some(msg.content); - } - "user" => { - messages.push(Message::user(msg.content)); - } - "assistant" => { - messages.push(Message::assistant(msg.content)); - } - _ => { - return Err(Error::custom(format!("Invalid role: {}", msg.role))); - } - } - } - - // the last message (ensured to be role: user), will be returned as the prompt separately - let prompt = messages.pop().unwrap(); - - Ok(TaskBody { - preamble, - prompt, - chat_history: messages, - model, - }) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use serde_json::json; - - #[test] - fn test_task_body_deserialization() { - let json_data = json!({ - "model": "gemma3:4b", - "messages": [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "What is the capital of France?"}, - {"role": "assistant", "content": "The capital of France is Paris."}, - {"role": "user", "content": "How many letters are there in the answer to the last question?"}, - ] - }); - - let task_body: TaskBody = serde_json::from_value(json_data).unwrap(); - - assert_eq!(task_body.model, Model::Gemma3_4b); - assert_eq!( - task_body.preamble, - Some("You are a helpful assistant.".to_string()) - ); - assert_eq!(task_body.chat_history.len(), 2); - } -} diff --git a/p2p/Cargo.toml b/p2p/Cargo.toml deleted file mode 100644 index 31e26396..00000000 --- a/p2p/Cargo.toml +++ /dev/null @@ -1,36 +0,0 @@ -[package] -name = "dkn-p2p" -version.workspace = true -edition.workspace = true -license.workspace = true -readme = "README.md" -authors = [ - "Erhan Tezcan ", - "Anil Altuner { - todo!("handle stuff") - } - None => { - todo!("channel closed"); - break - } - } -} -``` - -### Interactions - -Here is how the whole thing works in a bit more detail: - -- **Events**: When a message is received within the Swarm event handler, it is returned via a `mpsc` channel. Here, the p2p is `Sender` and your application must be the `Receiver`. The client handles many events, and only sends GossipSub message receipts via this channel so that the application can handle them however they would like. - -```mermaid -sequenceDiagram - actor A as Application - actor P as P2P Client - - note over P: e_tx - note over A: e_rx - - loop event loop - activate P - note over A: e_rx.wait() - P ->> A: e_tx.send(message) - deactivate P - - note over A: handle message - end - -``` - -- **Commands**: To call functions within this thread-scoped client, functions must be remotely called via the command `mpsc` channel. Here, p2p is `Receiver` and your application will be the `Sender` (we provide the commander client as well). While making a function call, a `oneshot` channel is created and its `Sender` is provided to the commander, kind of like a callback, and the caller waits as the `Receiver` for this call. - -```mermaid -sequenceDiagram - actor C as P2P Commander - actor P as P2P Client - note over C: c_tx - activate C - note over P: c_rx - - note over P: c_rx.wait() - note over C: o_tx, o_rx := oneshot() - C ->> P: c_tx.send(input, o_tx) - deactivate C - activate P - note over C: o_rx.wait() - P ->> C: o_tx.send(output) - deactivate P -``` diff --git a/p2p/src/behaviour.rs b/p2p/src/behaviour.rs deleted file mode 100644 index f0723272..00000000 --- a/p2p/src/behaviour.rs +++ /dev/null @@ -1,53 +0,0 @@ -use eyre::Result; -use libp2p::identity::{Keypair, PublicKey}; -use libp2p::{identify, request_response, StreamProtocol}; -use std::time::Duration; - -use crate::DriaP2PProtocol; - -#[derive(libp2p::swarm::NetworkBehaviour)] -pub struct DriaBehaviour { - pub identify: identify::Behaviour, - pub request_response: request_response::cbor::Behaviour, Vec>, -} - -impl DriaBehaviour { - pub fn new(key: &Keypair, protocol: &DriaP2PProtocol) -> Self { - let public_key = key.public(); - - Self { - identify: create_identify_behaviour(public_key, protocol.identity()), - request_response: create_request_response_behaviour(protocol.request_response()), - } - } -} - -/// Configures the request-response behaviour for the node. -/// -/// The protocol supports bytes only. -#[inline] -fn create_request_response_behaviour( - protocol_name: StreamProtocol, -) -> request_response::cbor::Behaviour, Vec> { - use request_response::{Behaviour, Config, ProtocolSupport}; - - const REQUEST_RESPONSE_TIMEOUT: Duration = Duration::from_secs(512); - - Behaviour::new( - [(protocol_name, ProtocolSupport::Full)], - Config::default().with_request_timeout(REQUEST_RESPONSE_TIMEOUT), - ) -} - -/// Configures the Identify behavior to allow nodes to exchange information like supported protocols. -#[inline] -fn create_identify_behaviour( - local_public_key: PublicKey, - protocol_version: String, -) -> identify::Behaviour { - use identify::{Behaviour, Config}; - - Behaviour::new( - Config::new(protocol_version, local_public_key).with_push_listen_addr_updates(true), - ) -} diff --git a/p2p/src/client.rs b/p2p/src/client.rs deleted file mode 100644 index fe637ad9..00000000 --- a/p2p/src/client.rs +++ /dev/null @@ -1,358 +0,0 @@ -use eyre::Result; -use libp2p::futures::StreamExt; -use libp2p::swarm::{ - dial_opts::{DialOpts, PeerCondition}, - SwarmEvent, -}; -use libp2p::{identify, noise, request_response, tcp, yamux}; -use libp2p::{Multiaddr, PeerId, Swarm, SwarmBuilder}; -use libp2p_identity::Keypair; -use std::time::Duration; -use tokio::sync::mpsc; - -use crate::behaviour::{DriaBehaviour, DriaBehaviourEvent}; -use crate::DriaP2PProtocol; - -use super::commands::DriaP2PCommand; -use super::DriaP2PCommander; - -/// Buffer size for command channel. -const COMMAND_CHANNEL_BUFSIZE: usize = 1024; -/// Buffer size for events channel. -const MSG_CHANNEL_BUFSIZE: usize = 1024; - -/// Request-response message type for Dria protocol, accepts bytes as both request and response. -/// -/// The additional parsing must be done by the application itself (for now). -pub type DriaReqResMessage = request_response::Message, Vec>; - -/// Peer-to-peer client for Dria Knowledge Network. -pub struct DriaP2PClient { - pub peer_id: PeerId, - /// `Swarm` instance, everything p2p-related are accessed through this instace. - swarm: Swarm, - /// Dria protocol, used for identifying the client. - protocol: DriaP2PProtocol, - /// Request-response protocol messages. - reqres_tx: mpsc::Sender<(PeerId, DriaReqResMessage)>, - /// Command receiver. - cmd_rx: mpsc::Receiver, -} - -impl DriaP2PClient { - /// Creates a new P2P client with the given keypair and listen address. - /// - /// The `version` is used to create the protocol strings for the client, and its very important that - /// they match with the clients existing within the network. - /// - /// If for any reason the given `listen_addr` is not available, it will try to listen on a random port on `localhost`. - #[allow(clippy::type_complexity)] - pub fn new( - keypair: Keypair, - listen_addr: Multiaddr, - rpc_addr: &Multiaddr, - protocol: DriaP2PProtocol, - ) -> Result<( - DriaP2PClient, - DriaP2PCommander, - mpsc::Receiver<(PeerId, DriaReqResMessage)>, - )> { - let peer_id = keypair.public().to_peer_id(); - - let mut swarm = SwarmBuilder::with_existing_identity(keypair) - .with_tokio() - .with_tcp( - tcp::Config::default(), - noise::Config::new, - yamux::Config::default, - )? - .with_behaviour(|key| DriaBehaviour::new(key, &protocol))? - // do not timeout at all, as we are only connected to an authority RPC at a given time and should stick to it - .with_swarm_config(|c| c.with_idle_connection_timeout(Duration::from_secs(u64::MAX))) - .build(); - - // listen on all interfaces for incoming connections - log::info!("Listening p2p network on: {listen_addr}"); - if let Err(err) = swarm.listen_on(listen_addr) { - log::error!("Could not listen on address: {err:?}"); - log::warn!("Trying fallback address with localhost random port"); - swarm.listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap())?; - } - - // dial rpc node, this will cause `identify` event to be called on their side - log::info!("Dialing RPC node: {rpc_addr}"); - if let Err(err) = swarm.dial(rpc_addr.clone()) { - log::error!("Could not dial RPC node: {err:?}"); - }; - - // create commander - let (cmd_tx, cmd_rx) = mpsc::channel(COMMAND_CHANNEL_BUFSIZE); - let commander = DriaP2PCommander::new(cmd_tx, protocol.clone()); - - // create p2p client itself - let (reqres_tx, reqres_rx) = mpsc::channel(MSG_CHANNEL_BUFSIZE); - - let client = Self { - peer_id, - swarm, - protocol, - reqres_tx, - cmd_rx, - }; - - Ok((client, commander, reqres_rx)) - } - - /// Waits for swarm events and Node commands at the same time. - /// - /// To terminate, the command channel must be closed. - pub async fn run(mut self) { - loop { - tokio::select! { - command = self.cmd_rx.recv() => match command { - Some(c) => self.handle_command(c).await, - // channel closed, thus shutting down the network event loop - None=> { - log::info!("Closing peer-to-peer client."); - return - }, - }, - event = self.swarm.select_next_some() => self.handle_event(event).await, - } - } - } - - /// Handles a single command, which originates from `DriaP2PCommander`. - pub async fn handle_command(&mut self, command: DriaP2PCommand) { - match command { - DriaP2PCommand::Dial { - peer_id, - address, - sender, - } => { - let opts = DialOpts::peer_id(peer_id) - .addresses(vec![address]) - .condition(PeerCondition::Always) - .build(); - let _ = sender.send(self.swarm.dial(opts)); - } - DriaP2PCommand::IsConnected { peer_id, sender } => { - let _ = sender.send(self.swarm.is_connected(&peer_id)); - } - DriaP2PCommand::NetworkInfo { sender } => { - let _ = sender.send(self.swarm.network_info()); - } - DriaP2PCommand::Respond { - data, - channel, - sender, - } => { - let _ = sender.send( - self.swarm - .behaviour_mut() - .request_response - .send_response(channel, data) - .map_err(|_| eyre::eyre!("could not send response, channel is closed?")), - ); - } - DriaP2PCommand::Request { - data, - peer_id, - sender, - } => { - let _ = sender.send( - self.swarm - .behaviour_mut() - .request_response - .send_request(&peer_id, data), - ); - } - DriaP2PCommand::Shutdown { sender } => { - // close the command channel - self.cmd_rx.close(); - - let _ = sender.send(()); - } - } - } - - /// Handles a single event from the `swarm` stream. - pub async fn handle_event(&mut self, event: SwarmEvent) { - match event { - /***************************************** - * Request-response events * - *****************************************/ - SwarmEvent::Behaviour(DriaBehaviourEvent::RequestResponse( - request_response::Event::Message { message, peer, .. }, - )) => { - // whether its a request or response, we forward it to the main thread - if let Err(err) = self.reqres_tx.send((peer, message)).await { - log::error!("Could not transfer request {err:?}"); - } - } - - SwarmEvent::Behaviour(DriaBehaviourEvent::RequestResponse( - request_response::Event::ResponseSent { - peer, request_id, .. - }, - )) => { - log::debug!("Request-Response: response ({request_id}) sent to peer {peer} with",) - } - SwarmEvent::Behaviour(DriaBehaviourEvent::RequestResponse( - request_response::Event::OutboundFailure { - peer, - request_id, - error, - .. - }, - )) => { - log::error!( - "Request-Response: Outbound failure to peer {peer} with request_id {request_id}: {error:?}", - ); - } - SwarmEvent::Behaviour(DriaBehaviourEvent::RequestResponse( - request_response::Event::InboundFailure { - peer, - request_id, - error, - .. - }, - )) => { - log::error!( - "Request-Response: Inbound failure to {peer} with request_id {request_id}: {error:?}" - ); - } - - /***************************************** - * Identify events * - *****************************************/ - SwarmEvent::Behaviour(DriaBehaviourEvent::Identify(identify::Event::Received { - peer_id, - info, - .. - })) => { - if info.protocol_version != self.protocol.identity { - log::warn!( - "Identify: Peer {} has different Identify protocol: (them {}, you {})", - peer_id, - info.protocol_version, - self.protocol.identity - ); - - // disconnect them - let _ = self.swarm.disconnect_peer_id(peer_id); - } - } - - /***************************************** - * Connection events and errors handling * - *****************************************/ - SwarmEvent::NewListenAddr { address, .. } => { - log::warn!("Local node is listening on {address}"); - } - SwarmEvent::NewExternalAddrOfPeer { peer_id, address } => { - log::info!("External address of peer {peer_id} confirmed: {address}"); - } - SwarmEvent::ExternalAddrConfirmed { address } => { - log::info!("External address confirmed: {address}"); - } - - SwarmEvent::IncomingConnectionError { - local_addr, - send_back_addr, - error, - .. - } => { - log::debug!( - "Incoming connection error: from {local_addr} to {send_back_addr} - {error:?}" - ); - } - SwarmEvent::IncomingConnection { - local_addr, - send_back_addr, - .. - } => { - log::debug!("Incoming connection attempt: from {local_addr} to {send_back_addr}"); - } - - SwarmEvent::OutgoingConnectionError { peer_id, error, .. } => { - if let Some(peer_id) = peer_id { - log::warn!("Could not connect to peer {peer_id}: {error:?}"); - } else { - log::warn!("Outgoing connection error: {error:?}"); - } - } - - SwarmEvent::ConnectionEstablished { - peer_id, - connection_id, - endpoint, - .. - } => { - if endpoint.is_dialer() { - // we only care about logs about the ones that we have dialed - log::info!( - "Connection ({connection_id}) established with {peer_id} at {}", - endpoint.get_remote_address() - ); - } else { - log::debug!( - "Connection ({connection_id}) established with {peer_id} from {}", - endpoint.get_remote_address() - ); - } - } - - SwarmEvent::ConnectionClosed { - peer_id, - connection_id, - endpoint, - cause, - .. - } => { - // we only care about the connections that we have dialed - if endpoint.is_dialer() { - // if we know the cause, it may be a good idea to re-dial - if let Some(cause) = cause { - log::warn!( - "Connection ({connection_id}) closed for {peer_id} due to {cause}" - ); - - let addr = endpoint.get_remote_address(); - log::info!("Dialing {peer_id} again at {addr}"); - if let Err(err) = self.swarm.dial( - DialOpts::peer_id(peer_id) - .addresses(vec![addr.clone()]) - .condition(PeerCondition::DisconnectedAndNotDialing) - .build(), - ) { - log::error!("Could not dial peer {peer_id}: {err:?}"); - } - } else { - // if we don't know the cause, we don't want to re-dial, - // because the cause is `None` if the other side closed the connection manually - log::warn!( - "Connection ({connection_id}) closed for {peer_id} without a cause, will not re-dial!" - ); - } - } else { - log::debug!("Connection ({connection_id}) closed for {peer_id}: {cause:?}",); - } - } - - SwarmEvent::ExpiredListenAddr { - address, - listener_id, - } => { - // this may happen when your connection is lost, e.g. you turn off your machine / internet - log::warn!("Listener ({listener_id}) expired: {address}"); - } - - SwarmEvent::ListenerError { listener_id, error } => { - log::error!("Listener ({listener_id}) failed: {error}"); - } - - event => log::debug!("Unhandled Swarm Event: {event:?}"), - } - } -} diff --git a/p2p/src/commands.rs b/p2p/src/commands.rs deleted file mode 100644 index 2a1a344a..00000000 --- a/p2p/src/commands.rs +++ /dev/null @@ -1,154 +0,0 @@ -use eyre::{Context, Result}; -use libp2p::{request_response, swarm, Multiaddr, PeerId}; -use tokio::sync::{mpsc, oneshot}; - -use crate::DriaP2PProtocol; - -#[derive(Debug)] -pub enum DriaP2PCommand { - /// Returns the network information, such as the number of incoming and outgoing connections. - NetworkInfo { - sender: oneshot::Sender, - }, - /// Check if there is an active connection to the given peer. - IsConnected { - peer_id: PeerId, - sender: oneshot::Sender, - }, - /// Dial a known peer. - Dial { - peer_id: PeerId, - address: Multiaddr, - sender: oneshot::Sender>, - }, - /// Respond to a request-response message. - Respond { - data: Vec, - channel: request_response::ResponseChannel>, - sender: oneshot::Sender>, - }, - /// Request a request-response message. - /// Note that you are likely to be caught by the RPC peer id check, - /// and your messages will be ignored. - Request { - peer_id: PeerId, - data: Vec, - sender: oneshot::Sender, - }, - /// Shutsdown the client, closes the command channel. - Shutdown { sender: oneshot::Sender<()> }, -} - -pub struct DriaP2PCommander { - sender: mpsc::Sender, - protocol: DriaP2PProtocol, -} - -impl DriaP2PCommander { - pub fn new(sender: mpsc::Sender, protocol: DriaP2PProtocol) -> Self { - Self { sender, protocol } - } - - /// Returns a reference to the protocol. - pub fn protocol(&self) -> &DriaP2PProtocol { - &self.protocol - } - - /// Returns the network information, such as the number of - /// incoming and outgoing connections. - pub async fn network_info(&self) -> Result { - let (sender, receiver) = oneshot::channel(); - - self.sender - .send(DriaP2PCommand::NetworkInfo { sender }) - .await - .wrap_err("could not send")?; - - receiver.await.wrap_err("could not receive") - } - - pub async fn respond( - &mut self, - data: Vec, - channel: request_response::ResponseChannel>, - ) -> Result<()> { - let (sender, receiver) = oneshot::channel(); - - self.sender - .send(DriaP2PCommand::Respond { - data, - channel, - sender, - }) - .await - .wrap_err("could not send")?; - - receiver - .await - .wrap_err("could not receive")? - .wrap_err("could not respond") - } - - pub async fn request( - &mut self, - peer_id: PeerId, - data: impl Into>, - ) -> Result { - let data = data.into(); - let (sender, receiver) = oneshot::channel(); - - self.sender - .send(DriaP2PCommand::Request { - data, - peer_id, - sender, - }) - .await - .wrap_err("could not send")?; - - receiver.await.wrap_err("could not receive") - } - - /// Dials a given peer. - pub async fn dial(&mut self, peer_id: PeerId, address: Multiaddr) -> Result<()> { - let (sender, receiver) = oneshot::channel(); - - self.sender - .send(DriaP2PCommand::Dial { - peer_id, - address, - sender, - }) - .await - .wrap_err("could not send")?; - - receiver - .await - .wrap_err("could not receive")? - .wrap_err("could not dial") - } - - /// Checks if there is an active connection to the given peer. - pub async fn is_connected(&mut self, peer_id: PeerId) -> Result { - let (sender, receiver) = oneshot::channel(); - - self.sender - .send(DriaP2PCommand::IsConnected { peer_id, sender }) - .await - .wrap_err("could not send")?; - - receiver.await.wrap_err("could not receive") - } - - /// Sends a shutdown signal to the client. - pub async fn shutdown(&mut self) -> Result<()> { - let (sender, receiver) = oneshot::channel(); - - self.sender - .send(DriaP2PCommand::Shutdown { sender }) - .await - .wrap_err("could not send")?; - - receiver.await.wrap_err("could not receive") - } -} diff --git a/p2p/src/lib.rs b/p2p/src/lib.rs deleted file mode 100644 index 81383226..00000000 --- a/p2p/src/lib.rs +++ /dev/null @@ -1,14 +0,0 @@ -mod behaviour; - -mod client; -pub use client::{DriaP2PClient, DriaReqResMessage}; - -mod commands; -pub use commands::{DriaP2PCommand, DriaP2PCommander}; - -mod protocol; -pub use protocol::DriaP2PProtocol; - -// re-exports -pub use libp2p; -pub use libp2p_identity; diff --git a/p2p/src/protocol.rs b/p2p/src/protocol.rs deleted file mode 100644 index 52781b9f..00000000 --- a/p2p/src/protocol.rs +++ /dev/null @@ -1,104 +0,0 @@ -use libp2p::StreamProtocol; -use std::env; - -#[derive(Clone, Debug)] -pub struct DriaP2PProtocol { - /// Main protocol name, e.g. `dria`. - pub name: String, - /// Version of the protocol, e.g. `0.2`. - /// By default, this is set to the current `major.minor` version of the crate. - pub version: String, - /// Identity protocol string to be used for the Identity behaviour. - /// - /// This is usually `{name}/{version}`. - pub identity: String, - /// Request-response protocol, must match with other peers in the network. - /// - /// This is usually `/{name}/rr/{version}`, notice the `/` at the start - /// which is mandatory for a `StreamProtocol`. - /// - pub request_response: StreamProtocol, -} - -impl std::fmt::Display for DriaP2PProtocol { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.identity) - } -} - -impl Default for DriaP2PProtocol { - /// Creates a new instance of the protocol with the default name `dria`. - fn default() -> Self { - Self::new_major_minor("dria") - } -} - -impl DriaP2PProtocol { - /// Creates a new instance of the protocol with the given `name` and `version`. - pub fn new(name: impl ToString, version: impl ToString) -> Self { - let name = name.to_string(); - let version = version.to_string(); - - let identity = format!("{name}/{version}"); - let request_response = - StreamProtocol::try_from_owned(format!("/{name}/rr/{version}")).unwrap(); - - Self { - name, - version, - identity, - request_response, - } - } - - /// Creates a new instance of the protocol with the given `name` and the current version as per Cargo.toml. - /// The verison is represented with `major.minor` version numbers. - pub fn new_major_minor(name: &str) -> Self { - const VERSION: &str = concat!( - env!("CARGO_PKG_VERSION_MAJOR"), - ".", - env!("CARGO_PKG_VERSION_MINOR") - ); - - Self::new(name, VERSION) - } - - /// Returns the identity protocol, e.g. `dria/0.2`. - pub fn identity(&self) -> String { - self.identity.clone() - } - - /// Returns the request-response protocol, e.g. `/dria/rr/0.2`. - pub fn request_response(&self) -> StreamProtocol { - self.request_response.clone() - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_new() { - let protocol = DriaP2PProtocol::new("test", "1.0"); - assert_eq!(protocol.name, "test"); - assert_eq!(protocol.version, "1.0"); - assert_eq!(protocol.identity, "test/1.0"); - assert_eq!(protocol.request_response.to_string(), "/test/rr/1.0"); - } - - #[test] - fn test_new_major_minor() { - let protocol = DriaP2PProtocol::new_major_minor("test"); - assert_eq!(protocol.name, "test"); - assert_eq!( - protocol.version, - concat!( - env!("CARGO_PKG_VERSION_MAJOR"), - ".", - env!("CARGO_PKG_VERSION_MINOR") - ) - ); - assert_eq!(protocol.identity, format!("test/{}", protocol.version)); - } -} diff --git a/p2p/tests/request_test.rs b/p2p/tests/request_test.rs deleted file mode 100644 index ce8c3132..00000000 --- a/p2p/tests/request_test.rs +++ /dev/null @@ -1,64 +0,0 @@ -use std::str::FromStr; -use std::thread::sleep; -use std::time::Duration; - -use dkn_p2p::{DriaP2PClient, DriaP2PProtocol}; -use eyre::Result; -use libp2p::PeerId; -use libp2p_identity::Keypair; - -/// Makes a dummy request to some peer hardcoded within the test. -/// -/// ## Run command -/// -/// ```sh -/// cargo test --package dkn-p2p --test request_test --all-features -- test_request_message --exact --show-output --ignored -/// ``` -#[tokio::test] -#[ignore = "run this manually"] -async fn test_request_message() -> Result<()> { - let _ = env_logger::builder() - .filter_level(log::LevelFilter::Off) - .filter_module("request_test", log::LevelFilter::Debug) - .filter_module("dkn_p2p", log::LevelFilter::Debug) - .is_test(true) - .try_init(); - - // prepare nodes - let rpc_addr = "your-rpc-here".parse().unwrap(); - - // spawn P2P client in another task - let (client, mut commander, mut req_rx) = DriaP2PClient::new( - Keypair::generate_secp256k1(), - "/ip4/127.0.0.1/tcp/0".parse().unwrap(), - &rpc_addr, - DriaP2PProtocol::default(), - ) - .expect("could not create p2p client"); - - // spawn task - let task_handle = tokio::spawn(async move { client.run().await }); - - log::info!("Waiting a bit until we have enough peers"); - sleep(Duration::from_secs(10)); - - let peer_id = - PeerId::from_str("16Uiu2HAmB5HGdwLNHX81u7ey1fvDx5Mr4ofa2PdSSVxFKrrcErAN").unwrap(); - log::info!("Making a request to peer: {}", peer_id); - commander.request(peer_id, b"here is some data").await?; - - log::info!("Waiting for response logs for a few moments..."); - sleep(Duration::from_secs(5)); - - // close command channel - commander.shutdown().await.expect("could not shutdown"); - - // close other channels - req_rx.close(); - - log::info!("Waiting for p2p task to finish..."); - task_handle.await?; - - log::info!("Done!"); - Ok(()) -} diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 00000000..69b2aec5 --- /dev/null +++ b/src/config.rs @@ -0,0 +1,213 @@ +use std::path::PathBuf; + +use clap::{Parser, Subcommand}; + +use crate::error::NodeError; + +#[derive(Parser)] +#[command(name = "dria-node", version, about = "Dria Compute Node")] +pub struct Cli { + #[command(subcommand)] + pub command: Command, +} + +#[derive(Subcommand)] +pub enum Command { + /// Start the compute node + Start { + /// Wallet secret key (hex-encoded, 32 bytes) + #[arg(long, env = "DRIA_WALLET")] + wallet: String, + + /// Model(s) to serve (comma-separated shortnames, e.g. "gemma3:4b,llama3.1:8b") + #[arg(long, env = "DRIA_MODELS")] + model: String, + + /// Router URL for task coordination + #[arg(long, env = "DRIA_ROUTER_URL", default_value = "https://router.dria.co")] + router_url: String, + + /// Number of GPU layers to offload (-1 = all, 0 = CPU only) + #[arg(long, env = "DRIA_GPU_LAYERS", default_value = "0")] + gpu_layers: i32, + + /// Maximum concurrent inference requests + #[arg(long, env = "DRIA_MAX_CONCURRENT", default_value = "1")] + max_concurrent: usize, + + /// Data directory + #[arg(long, env = "DRIA_DATA_DIR")] + data_dir: Option, + + /// Skip TLS certificate verification (for development/testing) + #[arg(long, env = "DRIA_INSECURE")] + insecure: bool, + }, +} + +/// Parsed and validated configuration for the node. +pub struct Config { + pub secret_key_hex: String, + pub model_names: Vec, + pub router_url: String, + pub gpu_layers: i32, + pub max_concurrent: usize, + pub data_dir: PathBuf, + pub models_dir: PathBuf, + pub insecure: bool, +} + +impl Config { + /// Create a Config from the `start` subcommand arguments. + pub fn from_start_args( + wallet: String, + model: String, + router_url: String, + gpu_layers: i32, + max_concurrent: usize, + data_dir: Option, + insecure: bool, + ) -> Result { + // Validate wallet key + let secret_key_hex = wallet.strip_prefix("0x").unwrap_or(&wallet).to_string(); + if secret_key_hex.len() != 64 { + return Err(NodeError::Config(format!( + "wallet secret key must be 64 hex chars, got {}", + secret_key_hex.len() + ))); + } + hex::decode(&secret_key_hex) + .map_err(|e| NodeError::Config(format!("wallet key is not valid hex: {e}")))?; + + // Parse model names + let model_names: Vec = model + .split(',') + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) + .collect(); + if model_names.is_empty() { + return Err(NodeError::Config("at least one model must be specified".into())); + } + + // Resolve data directory + let data_dir = match data_dir { + Some(d) => d, + None => dirs::home_dir() + .ok_or_else(|| NodeError::Config("could not determine home directory".into()))? + .join(".dria"), + }; + let models_dir = data_dir.join("models"); + + if max_concurrent == 0 { + return Err(NodeError::Config("max-concurrent must be >= 1".into())); + } + + Ok(Config { + secret_key_hex, + model_names, + router_url, + gpu_layers, + max_concurrent, + data_dir, + models_dir, + insecure, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_config_from_valid_args() { + let cfg = Config::from_start_args( + "0x6472696164726961647269616472696164726961647269616472696164726961".into(), + "gemma3:4b, llama3.1:8b".into(), + "https://router.dria.co".into(), + 0, + 1, + Some("/tmp/dria-test".into()), + false, + ) + .unwrap(); + + assert_eq!(cfg.model_names, vec!["gemma3:4b", "llama3.1:8b"]); + assert_eq!( + cfg.secret_key_hex, + "6472696164726961647269616472696164726961647269616472696164726961" + ); + assert_eq!(cfg.models_dir, PathBuf::from("/tmp/dria-test/models")); + } + + #[test] + fn test_config_invalid_wallet_length() { + let result = Config::from_start_args( + "0xabcd".into(), + "gemma3:4b".into(), + "https://router.dria.co".into(), + 0, + 1, + None, + false, + ); + assert!(result.is_err()); + } + + #[test] + fn test_config_invalid_wallet_hex() { + let result = Config::from_start_args( + "zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz".into(), + "gemma3:4b".into(), + "https://router.dria.co".into(), + 0, + 1, + None, + false, + ); + assert!(result.is_err()); + } + + #[test] + fn test_config_empty_model() { + let result = Config::from_start_args( + "6472696164726961647269616472696164726961647269616472696164726961".into(), + "".into(), + "https://router.dria.co".into(), + 0, + 1, + None, + false, + ); + assert!(result.is_err()); + } + + #[test] + fn test_config_zero_concurrency() { + let result = Config::from_start_args( + "6472696164726961647269616472696164726961647269616472696164726961".into(), + "gemma3:4b".into(), + "https://router.dria.co".into(), + 0, + 0, + None, + false, + ); + assert!(result.is_err()); + } + + #[test] + fn test_config_insecure_flag() { + let cfg = Config::from_start_args( + "6472696164726961647269616472696164726961647269616472696164726961".into(), + "gemma3:4b".into(), + "https://router.dria.co".into(), + 0, + 1, + None, + true, + ) + .unwrap(); + assert!(cfg.insecure); + } +} diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 00000000..652e8dee --- /dev/null +++ b/src/error.rs @@ -0,0 +1,22 @@ +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum NodeError { + #[error("config error: {0}")] + Config(String), + + #[error("identity error: {0}")] + Identity(String), + + #[error("inference error: {0}")] + Inference(String), + + #[error("model error: {0}")] + Model(String), + + #[error("network error: {0}")] + Network(String), + + #[error("io error: {0}")] + Io(#[from] std::io::Error), +} diff --git a/src/identity.rs b/src/identity.rs new file mode 100644 index 00000000..86f39412 --- /dev/null +++ b/src/identity.rs @@ -0,0 +1,114 @@ +use libsecp256k1::{PublicKey, SecretKey}; +use sha2::{Digest as _, Sha256}; +use sha3::Keccak256; + +use crate::error::NodeError; + +/// Node identity derived from a secp256k1 secret key. +/// The address is an Ethereum-style address (last 20 bytes of keccak256 of uncompressed pubkey). +pub struct Identity { + pub secret_key: SecretKey, + pub public_key: PublicKey, + pub address: [u8; 20], + pub address_hex: String, +} + +impl Identity { + /// Create an identity from a hex-encoded secret key (without 0x prefix). + pub fn from_secret_hex(hex_str: &str) -> Result { + let bytes = hex::decode(hex_str) + .map_err(|e| NodeError::Identity(format!("invalid hex: {e}")))?; + let secret_key = SecretKey::parse_slice(&bytes) + .map_err(|e| NodeError::Identity(format!("invalid secret key: {e}")))?; + let public_key = PublicKey::from_secret_key(&secret_key); + let address = public_key_to_address(&public_key); + let address_hex = hex::encode(address); + + Ok(Identity { + secret_key, + public_key, + address, + address_hex, + }) + } + + /// Sign a SHA-256 digest of the given message. + /// Returns (signature, recovery_id). + pub fn sign(&self, message: &[u8]) -> (libsecp256k1::Signature, libsecp256k1::RecoveryId) { + let digest = sha256hash(message); + let msg = libsecp256k1::Message::parse_slice(&digest) + .expect("SHA-256 output is always 32 bytes"); + libsecp256k1::sign(&msg, &self.secret_key) + } +} + +/// SHA-256 hash. +#[inline(always)] +pub fn sha256hash(data: impl AsRef<[u8]>) -> [u8; 32] { + Sha256::digest(data).into() +} + +/// Keccak-256 hash. +#[inline(always)] +pub fn keccak256hash(data: impl AsRef<[u8]>) -> [u8; 32] { + Keccak256::digest(data).into() +} + +/// Derive an Ethereum address from a secp256k1 public key. +/// Serializes uncompressed (65 bytes: 0x04 || x || y), hashes (x || y) with keccak256, +/// and takes the last 20 bytes. +#[inline] +fn public_key_to_address(public_key: &PublicKey) -> [u8; 20] { + let public_key_xy = &public_key.serialize()[1..]; + let mut addr = [0u8; 20]; + addr.copy_from_slice(&keccak256hash(public_key_xy)[12..32]); + addr +} + +#[cfg(test)] +mod tests { + use super::*; + + const DUMMY_SECRET_KEY: &[u8; 32] = b"driadriadriadriadriadriadriadria"; + + #[test] + fn test_sha256() { + let expected = "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9"; + assert_eq!(hex::encode(sha256hash(b"hello world")), expected); + } + + #[test] + fn test_address_from_secret() { + let hex_key = hex::encode(DUMMY_SECRET_KEY); + let identity = Identity::from_secret_hex(&hex_key).unwrap(); + assert_eq!( + identity.address_hex, + "d79fdf178547614cfdd0df6397c53569716bd596" + ); + } + + #[test] + fn test_sign_and_recover() { + let hex_key = hex::encode(DUMMY_SECRET_KEY); + let identity = Identity::from_secret_hex(&hex_key).unwrap(); + + let message = b"hello world"; + let (signature, recid) = identity.sign(message); + + // Recover public key from signature + let digest = sha256hash(message); + let msg = libsecp256k1::Message::parse_slice(&digest).unwrap(); + let recovered = libsecp256k1::recover(&msg, &signature, &recid).unwrap(); + assert_eq!(recovered, identity.public_key); + } + + #[test] + fn test_invalid_hex() { + assert!(Identity::from_secret_hex("not-hex").is_err()); + } + + #[test] + fn test_invalid_key_length() { + assert!(Identity::from_secret_hex("abcd").is_err()); + } +} diff --git a/src/inference/benchmark.rs b/src/inference/benchmark.rs new file mode 100644 index 00000000..44c3146e --- /dev/null +++ b/src/inference/benchmark.rs @@ -0,0 +1,71 @@ +use std::ops::ControlFlow; +use std::time::Instant; + +use crate::error::NodeError; +use crate::inference::engine::{GenerateParams, InferenceEngine}; + +/// Result of a TPS benchmark run. +#[derive(Debug, Clone)] +pub struct TpsResult { + pub model_name: String, + pub prompt_eval_tps: f64, + pub generation_tps: f64, + pub total_time_ms: u64, + pub tokens_generated: u32, +} + +impl std::fmt::Display for TpsResult { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}: {:.1} tok/s generation, {:.1} tok/s prompt eval ({} tokens in {}ms)", + self.model_name, + self.generation_tps, + self.prompt_eval_tps, + self.tokens_generated, + self.total_time_ms, + ) + } +} + +const WARMUP_PROMPT: &str = "Write a short poem about hedgehogs and squirrels."; +const BENCHMARK_PROMPT: &str = "Please write a poem about Kapadokya."; +const BENCHMARK_MAX_TOKENS: u32 = 128; + +impl InferenceEngine { + /// Run a TPS benchmark: warmup generation, then timed generation. + pub fn benchmark(&self, model_name: &str) -> Result { + // Warmup: short generation to prime caches + let warmup_params = GenerateParams { + max_tokens: 16, + temperature: 0.0, + ..Default::default() + }; + let _ = self.generate(WARMUP_PROMPT, &warmup_params, |_| ControlFlow::Continue(())); + + // Timed benchmark + let bench_params = GenerateParams { + max_tokens: BENCHMARK_MAX_TOKENS, + temperature: 0.0, + ..Default::default() + }; + + let start = Instant::now(); + let result = self.generate(BENCHMARK_PROMPT, &bench_params, |_| ControlFlow::Continue(()))?; + let total_time_ms = start.elapsed().as_millis() as u64; + + let prompt_eval_tps = if result.prompt_eval_time_ms > 0 { + (result.prompt_tokens as f64) / (result.prompt_eval_time_ms as f64 / 1000.0) + } else { + 0.0 + }; + + Ok(TpsResult { + model_name: model_name.to_string(), + prompt_eval_tps, + generation_tps: result.tokens_per_second, + total_time_ms, + tokens_generated: result.tokens_generated, + }) + } +} diff --git a/src/inference/engine.rs b/src/inference/engine.rs new file mode 100644 index 00000000..b66f5f88 --- /dev/null +++ b/src/inference/engine.rs @@ -0,0 +1,297 @@ +use std::ops::ControlFlow; +use std::path::Path; +use std::time::Instant; + +use llama_cpp_2::context::params::LlamaContextParams; +use llama_cpp_2::llama_backend::LlamaBackend; +use llama_cpp_2::llama_batch::LlamaBatch; +use llama_cpp_2::model::params::LlamaModelParams; +use llama_cpp_2::model::{AddBos, LlamaModel}; +use llama_cpp_2::sampling::LlamaSampler; +use llama_cpp_2::token::LlamaToken; + +use crate::error::NodeError; +use crate::identity::sha256hash; +use crate::inference::proof::{InferenceProof, TokenLogprob}; +use crate::inference::stream::StreamToken; + +/// Parameters controlling text generation. +#[derive(Debug, Clone)] +pub struct GenerateParams { + pub max_tokens: u32, + pub temperature: f32, + pub top_p: f32, + pub seed: Option, + /// Token positions at which to extract logprobs. + pub logprob_positions: Vec, + /// Top-k alternatives to collect at each logprob position. + pub logprob_top_k: usize, +} + +impl Default for GenerateParams { + fn default() -> Self { + Self { + max_tokens: 512, + temperature: 0.7, + top_p: 0.9, + seed: None, + logprob_positions: vec![], + logprob_top_k: 5, + } + } +} + +/// Result of an inference run. +#[derive(Debug, Clone)] +pub struct InferenceResult { + pub text: String, + pub tokens_generated: u32, + pub prompt_tokens: u32, + pub generation_time_ms: u64, + pub prompt_eval_time_ms: u64, + pub tokens_per_second: f64, + pub proof: Option, +} + +/// Wraps llama-cpp-2 for model loading and inference. +/// +/// NOTE: `LlamaContext` is not Send/Sync. All inference must happen +/// via `tokio::task::spawn_blocking` with the engine moved into the closure. +pub struct InferenceEngine { + backend: LlamaBackend, + model: LlamaModel, + gpu_layers: i32, +} + +/// Helper to convert a token to a string piece using the new token_to_piece API. +fn token_to_string(model: &LlamaModel, token: LlamaToken) -> String { + let mut decoder = encoding_rs::UTF_8.new_decoder(); + model + .token_to_piece(token, &mut decoder, true, None) + .unwrap_or_default() +} + +impl InferenceEngine { + /// Load a GGUF model from disk. + pub fn load(path: &Path, gpu_layers: i32) -> Result { + let backend = LlamaBackend::init() + .map_err(|e| NodeError::Inference(format!("failed to init llama backend: {e}")))?; + + let model_params = if gpu_layers != 0 { + let layers = if gpu_layers < 0 { 1000 } else { gpu_layers as u32 }; + LlamaModelParams::default().with_n_gpu_layers(layers) + } else { + LlamaModelParams::default() + }; + + let model = LlamaModel::load_from_file(&backend, path, &model_params) + .map_err(|e| NodeError::Inference(format!("failed to load model: {e}")))?; + + Ok(InferenceEngine { + backend, + model, + gpu_layers, + }) + } + + /// Return the number of GPU layers configured. + pub fn gpu_layers(&self) -> i32 { + self.gpu_layers + } + + /// Generate text from a prompt. + /// + /// `on_token` is called for each generated token. Return `ControlFlow::Break(())` + /// to stop generation early. + pub fn generate( + &self, + prompt: &str, + params: &GenerateParams, + mut on_token: F, + ) -> Result + where + F: FnMut(StreamToken) -> ControlFlow<()>, + { + let ctx_size = std::num::NonZeroU32::new(2048); + let ctx_params = LlamaContextParams::default().with_n_ctx(ctx_size); + + let mut ctx = self + .model + .new_context(&self.backend, ctx_params) + .map_err(|e| NodeError::Inference(format!("failed to create context: {e}")))?; + + // Tokenize prompt + let tokens = self + .model + .str_to_token(prompt, AddBos::Always) + .map_err(|e| NodeError::Inference(format!("tokenization failed: {e}")))?; + let prompt_token_count = tokens.len() as u32; + + // Evaluate prompt + let prompt_start = Instant::now(); + let mut batch = LlamaBatch::new(tokens.len().max(1), 1); + for (i, &token) in tokens.iter().enumerate() { + let is_last = i == tokens.len() - 1; + batch + .add(token, i as i32, &[0], is_last) + .map_err(|e| NodeError::Inference(format!("batch add failed: {e}")))?; + } + ctx.decode(&mut batch) + .map_err(|e| NodeError::Inference(format!("prompt decode failed: {e}")))?; + let prompt_eval_time_ms = prompt_start.elapsed().as_millis() as u64; + + // Build sampler chain (seed is passed via the dist sampler) + let mut samplers = vec![]; + if params.temperature > 0.0 { + samplers.push(LlamaSampler::top_p(params.top_p, 1)); + samplers.push(LlamaSampler::temp(params.temperature)); + samplers.push(LlamaSampler::dist(params.seed.unwrap_or(0))); + } else { + samplers.push(LlamaSampler::greedy()); + } + let mut sampler = LlamaSampler::chain_simple(samplers); + + // Generation loop + let gen_start = Instant::now(); + let mut generated_text = String::new(); + let mut generated_count: u32 = 0; + let mut logprobs: Vec = Vec::new(); + let mut current_pos = tokens.len() as i32; + let mut decoder = encoding_rs::UTF_8.new_decoder(); + + for _ in 0..params.max_tokens { + let new_token = sampler.sample(&ctx, -1); + sampler.accept(new_token); + + if self.model.is_eog_token(new_token) { + break; + } + + // Extract logprobs if this position was requested + let gen_index = generated_count as usize; + if params.logprob_positions.contains(&gen_index) { + if let Some(lp) = + self.extract_logprob(&ctx, -1, gen_index, new_token, params.logprob_top_k) + { + logprobs.push(lp); + } + } + + // Decode token to text + let piece = self + .model + .token_to_piece(new_token, &mut decoder, true, None) + .unwrap_or_default(); + generated_text.push_str(&piece); + generated_count += 1; + + // Stream callback + let stream_token = StreamToken { + text: piece, + index: gen_index, + }; + if let ControlFlow::Break(()) = on_token(stream_token) { + break; + } + + // Prepare next batch + batch.clear(); + batch + .add(new_token, current_pos, &[0], true) + .map_err(|e| NodeError::Inference(format!("batch add failed: {e}")))?; + ctx.decode(&mut batch) + .map_err(|e| NodeError::Inference(format!("decode failed: {e}")))?; + current_pos += 1; + } + + let generation_time_ms = gen_start.elapsed().as_millis() as u64; + let tokens_per_second = if generation_time_ms > 0 { + (generated_count as f64) / (generation_time_ms as f64 / 1000.0) + } else { + 0.0 + }; + + let proof = if logprobs.is_empty() { + None + } else { + Some(InferenceProof { + logprobs, + kv_cache_hash: None, + }) + }; + + Ok(InferenceResult { + text: generated_text, + tokens_generated: generated_count, + prompt_tokens: prompt_token_count, + generation_time_ms, + prompt_eval_time_ms, + tokens_per_second, + proof, + }) + } + + /// Extract logprob data at a given batch index. + fn extract_logprob( + &self, + ctx: &llama_cpp_2::context::LlamaContext, + batch_idx: i32, + position: usize, + chosen_token: LlamaToken, + top_k: usize, + ) -> Option { + let logits = ctx.get_logits_ith(batch_idx); + + // Compute softmax to get log-probabilities + let max_logit = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max); + let exp_sum: f32 = logits.iter().map(|&l| (l - max_logit).exp()).sum(); + let log_sum = max_logit + exp_sum.ln(); + + // Collect (token_id, logprob) for all vocab + let mut all_logprobs: Vec<(u32, f32)> = logits + .iter() + .enumerate() + .map(|(i, &l)| (i as u32, l - log_sum)) + .collect(); + + // Sort by logprob descending + all_logprobs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); + + let chosen_id = chosen_token.0 as u32; + let chosen_logprob = all_logprobs + .iter() + .find(|(id, _)| *id == chosen_id) + .map(|(_, lp)| *lp) + .unwrap_or(f32::NEG_INFINITY); + + let chosen_text = token_to_string(&self.model, chosen_token); + + let top_k_entries: Vec<(String, f32)> = all_logprobs + .iter() + .take(top_k) + .map(|(id, lp)| { + let text = token_to_string(&self.model, LlamaToken(*id as i32)); + (text, *lp) + }) + .collect(); + + Some(TokenLogprob { + position, + token_id: chosen_id, + token_text: chosen_text, + logprob: chosen_logprob, + top_k: top_k_entries, + }) + } + + /// Compute a placeholder KV-cache hash from logits at a given position. + #[allow(dead_code)] + fn kv_cache_hash_placeholder( + ctx: &llama_cpp_2::context::LlamaContext, + batch_idx: i32, + ) -> [u8; 32] { + let logits = ctx.get_logits_ith(batch_idx); + let bytes: Vec = logits.iter().flat_map(|f| f.to_le_bytes()).collect(); + sha256hash(&bytes) + } +} diff --git a/src/inference/mod.rs b/src/inference/mod.rs new file mode 100644 index 00000000..8c077dc3 --- /dev/null +++ b/src/inference/mod.rs @@ -0,0 +1,7 @@ +pub mod benchmark; +pub mod engine; +pub mod proof; +pub mod stream; + +pub use engine::{GenerateParams, InferenceEngine, InferenceResult}; +pub use proof::InferenceProof; diff --git a/src/inference/proof.rs b/src/inference/proof.rs new file mode 100644 index 00000000..80fd2bc7 --- /dev/null +++ b/src/inference/proof.rs @@ -0,0 +1,65 @@ +use serde::{Deserialize, Serialize}; + +/// Log-probability information for a single token position. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TokenLogprob { + /// Position in the generated sequence. + pub position: usize, + /// The token ID chosen at this position. + pub token_id: u32, + /// The decoded text of the chosen token. + pub token_text: String, + /// The log-probability of the chosen token. + pub logprob: f32, + /// Top-k alternatives: (token_text, logprob). + pub top_k: Vec<(String, f32)>, +} + +/// Proof-of-inference data for validation. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct InferenceProof { + /// Log-probabilities at requested positions. + pub logprobs: Vec, + /// Optional KV-cache hash for determinism verification. + /// Placeholder: currently hashes logits at probed position. + pub kv_cache_hash: Option<[u8; 32]>, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_token_logprob_serde() { + let lp = TokenLogprob { + position: 5, + token_id: 1234, + token_text: "the".into(), + logprob: -0.5, + top_k: vec![("the".into(), -0.5), ("a".into(), -1.2)], + }; + let json = serde_json::to_string(&lp).unwrap(); + let roundtrip: TokenLogprob = serde_json::from_str(&json).unwrap(); + assert_eq!(roundtrip.position, 5); + assert_eq!(roundtrip.token_id, 1234); + assert_eq!(roundtrip.top_k.len(), 2); + } + + #[test] + fn test_inference_proof_serde() { + let proof = InferenceProof { + logprobs: vec![TokenLogprob { + position: 0, + token_id: 1, + token_text: "hello".into(), + logprob: -0.1, + top_k: vec![], + }], + kv_cache_hash: Some([0xAB; 32]), + }; + let packed = rmp_serde::to_vec(&proof).unwrap(); + let roundtrip: InferenceProof = rmp_serde::from_slice(&packed).unwrap(); + assert_eq!(roundtrip.logprobs.len(), 1); + assert!(roundtrip.kv_cache_hash.is_some()); + } +} diff --git a/src/inference/stream.rs b/src/inference/stream.rs new file mode 100644 index 00000000..e510c19b --- /dev/null +++ b/src/inference/stream.rs @@ -0,0 +1,39 @@ +use serde::{Deserialize, Serialize}; + +/// A single token emitted during streaming generation. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct StreamToken { + /// The decoded text of this token. + pub text: String, + /// The zero-based position of this token in the generated sequence. + pub index: usize, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_stream_token_serde() { + let token = StreamToken { + text: "hello".into(), + index: 0, + }; + let json = serde_json::to_string(&token).unwrap(); + let roundtrip: StreamToken = serde_json::from_str(&json).unwrap(); + assert_eq!(roundtrip.text, "hello"); + assert_eq!(roundtrip.index, 0); + } + + #[test] + fn test_stream_token_msgpack() { + let token = StreamToken { + text: "world".into(), + index: 42, + }; + let packed = rmp_serde::to_vec(&token).unwrap(); + let roundtrip: StreamToken = rmp_serde::from_slice(&packed).unwrap(); + assert_eq!(roundtrip.text, "world"); + assert_eq!(roundtrip.index, 42); + } +} diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 00000000..641ebe28 --- /dev/null +++ b/src/main.rs @@ -0,0 +1,338 @@ +// Suppress dead-code warnings for public APIs not yet wired to networking. +#![allow(dead_code)] + +mod config; +mod error; +mod identity; +mod inference; +mod models; +mod network; +mod worker; + +use std::time::Duration; + +use clap::Parser; +use tracing_subscriber::EnvFilter; + +use config::{Cli, Command, Config}; +use identity::Identity; +use models::{ModelCache, ModelDownloader, default_registry, resolve_model}; +use network::{NodeMessage, RouterMessage}; +use network::RouterConnection; +use worker::{CompletedTask, Worker}; + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + // Initialize tracing + tracing_subscriber::fmt() + .with_env_filter( + EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")), + ) + .init(); + + let cli = Cli::parse(); + + match cli.command { + Command::Start { + wallet, + model, + router_url, + gpu_layers, + max_concurrent, + data_dir, + insecure, + } => { + run_start(wallet, model, router_url, gpu_layers, max_concurrent, data_dir, insecure).await?; + } + } + + Ok(()) +} + +async fn run_start( + wallet: String, + model: String, + router_url: String, + gpu_layers: i32, + max_concurrent: usize, + data_dir: Option, + insecure: bool, +) -> anyhow::Result<()> { + // Parse config + let config = Config::from_start_args(wallet, model, router_url, gpu_layers, max_concurrent, data_dir, insecure)?; + + // Create identity + let identity = Identity::from_secret_hex(&config.secret_key_hex)?; + tracing::info!(address = %format!("0x{}", identity.address_hex), "node identity"); + + // Ensure directories exist + std::fs::create_dir_all(&config.data_dir)?; + std::fs::create_dir_all(&config.models_dir)?; + + // Resolve and download models + let registry = default_registry(); + let cache = ModelCache::new(config.models_dir.clone())?; + + // We need to keep one engine alive for inference; use the first model. + let mut chat_template = "chatml".to_string(); + let mut engine_and_tps: Option<(inference::InferenceEngine, f64)> = None; + + for model_name in &config.model_names { + let spec = resolve_model(model_name, ®istry) + .ok_or_else(|| error::NodeError::Model(format!("unknown model: {model_name}")))?; + + // Check local cache first + let model_path = if let Some(path) = cache.get_local_path(&spec) { + tracing::info!(model = %model_name, path = %path.display(), "model found in cache"); + path + } else { + // Download from HuggingFace + let hf_path = ModelDownloader::download(&spec).await?; + + // Verify SHA-256 if specified + if let Some(ref expected_sha) = spec.sha256 { + tracing::info!(model = %model_name, "verifying SHA-256"); + if !ModelCache::verify_sha256(&hf_path, expected_sha)? { + anyhow::bail!("SHA-256 mismatch for model {model_name}"); + } + } + + // Link into our cache + cache.link_model(&spec, &hf_path)? + }; + + // Remember chat template from the spec + if let Some(ref tmpl) = spec.chat_template { + chat_template = tmpl.clone(); + } + + // Load model and run benchmark in blocking thread + let model_name_owned = model_name.clone(); + let gpu = config.gpu_layers; + let (engine, tps) = tokio::task::spawn_blocking(move || { + let engine = inference::InferenceEngine::load(&model_path, gpu)?; + let tps_result = engine.benchmark(&model_name_owned)?; + Ok::<_, error::NodeError>((engine, tps_result.generation_tps)) + }) + .await??; + + tracing::info!(tps = %format!("{tps:.1}"), model = %model_name, "benchmark complete"); + engine_and_tps = Some((engine, tps)); + } + + let (engine, tps) = engine_and_tps.ok_or_else(|| { + error::NodeError::Config("no models loaded".into()) + })?; + + // Build the worker + let mut worker = Worker::new( + engine, + chat_template, + config.model_names.clone(), + config.max_concurrent, + ); + + // Attempt router connection; go offline if unavailable + let mut connection: Option = match RouterConnection::connect( + &config.router_url, + config.insecure, + &identity, + config.model_names.clone(), + tps, + worker.capacity(), + ) + .await + { + Ok(conn) => { + tracing::info!(node_id = %conn.node_id, "connected to router"); + Some(conn) + } + Err(e) => { + tracing::warn!(%e, "failed to connect to router, running in offline mode"); + None + } + }; + + tracing::info!( + router = %config.router_url, + models = ?config.model_names, + max_concurrent = config.max_concurrent, + insecure = config.insecure, + online = connection.is_some(), + "node ready" + ); + + // Main event loop + loop { + let event = tokio::select! { + msg = recv_router_msg(&mut connection) => Event::RouterMsg(msg), + Some(done) = worker.next_completed() => Event::TaskDone(done), + _ = tokio::signal::ctrl_c() => Event::Shutdown, + }; + + match event { + Event::RouterMsg(Ok(Some(msg))) => { + handle_router_message(msg, &mut worker, &mut connection).await; + } + Event::RouterMsg(Ok(None)) => { + // Stream closed cleanly + tracing::warn!("router stream closed, switching to offline mode"); + if let Some(ref conn) = connection { + conn.close(); + } + connection = None; + } + Event::RouterMsg(Err(e)) => { + tracing::warn!(%e, "router communication error"); + if let Some(ref conn) = connection { + conn.close(); + } + connection = None; + + // Attempt reconnect + tracing::info!("will attempt reconnect on next cycle"); + } + Event::TaskDone(completed) => { + handle_completed_task(completed, &mut connection).await; + } + Event::Shutdown => { + tracing::info!("shutdown signal received"); + break; + } + } + } + + // Graceful shutdown: drain in-flight tasks with 30s timeout + if worker.has_in_flight() { + tracing::info!("draining in-flight tasks (30s timeout)"); + let drain_deadline = tokio::time::Instant::now() + Duration::from_secs(30); + + loop { + tokio::select! { + Some(completed) = worker.next_completed() => { + handle_completed_task(completed, &mut connection).await; + } + _ = tokio::time::sleep_until(drain_deadline) => { + tracing::warn!("drain timeout reached, dropping remaining tasks"); + break; + } + } + if !worker.has_in_flight() { + break; + } + } + } + + if let Some(ref conn) = connection { + conn.close(); + } + tracing::info!("shutdown complete"); + + Ok(()) +} + +// --------------------------------------------------------------------------- +// Event types for the select! loop +// --------------------------------------------------------------------------- + +enum Event { + RouterMsg(Result, error::NodeError>), + TaskDone(CompletedTask), + Shutdown, +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +/// Receive a router message, or sleep 10s when offline (to allow periodic reconnect). +async fn recv_router_msg( + connection: &mut Option, +) -> Result, error::NodeError> { + match connection { + Some(ref mut conn) => conn.recv().await, + None => { + // Offline: sleep then signal a reconnect attempt + tokio::time::sleep(Duration::from_secs(10)).await; + Err(error::NodeError::Network("offline, attempting reconnect".into())) + } + } +} + +/// Handle a router message: dispatch tasks, respond to pings, etc. +async fn handle_router_message( + msg: RouterMessage, + worker: &mut Worker, + connection: &mut Option, +) { + match msg { + RouterMessage::TaskAssignment { + task_id, + model, + messages, + max_tokens, + temperature, + validation, + } => { + tracing::info!(%task_id, %model, "received task assignment"); + match worker.try_accept(task_id, &model, messages, max_tokens, temperature, validation) + { + Ok(()) => { + tracing::debug!(%task_id, "task accepted"); + } + Err(reason) => { + tracing::warn!(%task_id, ?reason, "task rejected"); + if let Some(ref mut conn) = connection { + let reject = NodeMessage::TaskRejected { task_id, reason }; + if let Err(e) = conn.send(&reject).await { + tracing::error!(%e, "failed to send rejection"); + } + } + } + } + } + RouterMessage::Ping => { + tracing::debug!("received ping"); + if let Some(ref mut conn) = connection { + let status = NodeMessage::StatusUpdate { + models: worker.model_names().to_vec(), + capacity: worker.capacity(), + version: env!("CARGO_PKG_VERSION").to_string(), + }; + if let Err(e) = conn.send(&status).await { + tracing::error!(%e, "failed to send status update"); + } + } + } + RouterMessage::Challenge { challenge } => { + tracing::debug!(?challenge, "received challenge (not yet implemented)"); + // TODO: implement challenge-response + } + RouterMessage::ModelRegistryUpdate { entries } => { + tracing::info!(count = entries.len(), "received model registry update (not yet implemented)"); + // TODO: handle model registry updates + } + } +} + +/// Handle a completed inference task: send result or log if offline. +async fn handle_completed_task( + completed: CompletedTask, + connection: &mut Option, +) { + match completed.result { + Ok(msg) => { + tracing::info!(task_id = %completed.task_id, "task completed"); + if let Some(ref mut conn) = connection { + if let Err(e) = conn.send(&msg).await { + tracing::error!(%e, task_id = %completed.task_id, "failed to send result"); + } + } else { + tracing::warn!(task_id = %completed.task_id, "task completed but offline, result dropped"); + } + } + Err(e) => { + tracing::error!(%e, task_id = %completed.task_id, "task failed"); + } + } +} diff --git a/src/models/cache.rs b/src/models/cache.rs new file mode 100644 index 00000000..590e7435 --- /dev/null +++ b/src/models/cache.rs @@ -0,0 +1,103 @@ +use std::path::{Path, PathBuf}; + +use sha2::{Digest, Sha256}; + +use crate::error::NodeError; +use crate::models::registry::ModelSpec; + +/// Manages local model file cache. +pub struct ModelCache { + pub cache_dir: PathBuf, +} + +impl ModelCache { + /// Create a new cache backed by the given directory. + pub fn new(cache_dir: PathBuf) -> Result { + std::fs::create_dir_all(&cache_dir)?; + Ok(ModelCache { cache_dir }) + } + + /// Check if a model's GGUF is already present in our cache. + pub fn get_local_path(&self, spec: &ModelSpec) -> Option { + let path = self.cache_dir.join(&spec.hf_file); + if path.exists() { + Some(path) + } else { + None + } + } + + /// Verify a file's SHA-256 against an expected hex digest. + /// Returns Ok(true) if matches, Ok(false) if mismatch, Err on I/O failure. + pub fn verify_sha256(path: &Path, expected_hex: &str) -> Result { + let mut file = std::fs::File::open(path)?; + let mut hasher = Sha256::new(); + std::io::copy(&mut file, &mut hasher)?; + let actual = hex::encode(hasher.finalize()); + Ok(actual == expected_hex.to_lowercase()) + } + + /// Create a symlink from our cache dir to the hf-hub cached file. + /// This avoids duplicating multi-GB files on disk. + pub fn link_model(&self, spec: &ModelSpec, source: &Path) -> Result { + let dest = self.cache_dir.join(&spec.hf_file); + if dest.exists() { + // Already linked or copied + return Ok(dest); + } + + #[cfg(unix)] + std::os::unix::fs::symlink(source, &dest)?; + + #[cfg(not(unix))] + std::fs::copy(source, &dest)?; + + Ok(dest) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::io::Write; + + #[test] + fn test_verify_sha256() { + let dir = std::env::temp_dir().join("dria-cache-test"); + std::fs::create_dir_all(&dir).unwrap(); + let file_path = dir.join("test.bin"); + let mut f = std::fs::File::create(&file_path).unwrap(); + f.write_all(b"hello world").unwrap(); + drop(f); + + // SHA-256 of "hello world" + let expected = "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9"; + assert!(ModelCache::verify_sha256(&file_path, expected).unwrap()); + assert!(!ModelCache::verify_sha256(&file_path, "0000000000000000000000000000000000000000000000000000000000000000").unwrap()); + + std::fs::remove_dir_all(&dir).ok(); + } + + #[test] + fn test_cache_local_path() { + let dir = std::env::temp_dir().join("dria-cache-test-2"); + let cache = ModelCache::new(dir.clone()).unwrap(); + + let spec = ModelSpec { + name: "test:1b".into(), + hf_repo: "test/repo".into(), + hf_file: "model.gguf".into(), + sha256: None, + chat_template: None, + }; + + // Not present initially + assert!(cache.get_local_path(&spec).is_none()); + + // Create the file + std::fs::write(dir.join("model.gguf"), b"fake").unwrap(); + assert!(cache.get_local_path(&spec).is_some()); + + std::fs::remove_dir_all(&dir).ok(); + } +} diff --git a/src/models/download.rs b/src/models/download.rs new file mode 100644 index 00000000..997ecea1 --- /dev/null +++ b/src/models/download.rs @@ -0,0 +1,46 @@ +use std::path::PathBuf; + +use hf_hub::api::tokio::ApiBuilder; + +use crate::error::NodeError; +use crate::models::registry::ModelSpec; + +/// Downloads GGUF models from HuggingFace using the `hf-hub` crate. +pub struct ModelDownloader; + +impl ModelDownloader { + /// Download a model's GGUF file from HuggingFace. + /// + /// Uses hf-hub's built-in cache (defaults to `~/.cache/huggingface/`) + /// and supports automatic resume of interrupted downloads. + /// + /// Returns the local path to the downloaded file. + pub async fn download(spec: &ModelSpec) -> Result { + let api = ApiBuilder::new() + .with_progress(true) + .build() + .map_err(|e| NodeError::Model(format!("failed to create HF API client: {e}")))?; + + let repo = api.model(spec.hf_repo.clone()); + + tracing::info!( + model = %spec.name, + repo = %spec.hf_repo, + file = %spec.hf_file, + "downloading model from HuggingFace" + ); + + let path = repo + .get(&spec.hf_file) + .await + .map_err(|e| NodeError::Model(format!("failed to download {}: {e}", spec.name)))?; + + tracing::info!( + model = %spec.name, + path = %path.display(), + "model download complete" + ); + + Ok(path) + } +} diff --git a/src/models/mod.rs b/src/models/mod.rs new file mode 100644 index 00000000..5d814931 --- /dev/null +++ b/src/models/mod.rs @@ -0,0 +1,8 @@ +pub mod cache; +pub mod download; +pub mod registry; +pub mod template; + +pub use cache::ModelCache; +pub use download::ModelDownloader; +pub use registry::{default_registry, resolve_model}; diff --git a/src/models/registry.rs b/src/models/registry.rs new file mode 100644 index 00000000..41dce7c7 --- /dev/null +++ b/src/models/registry.rs @@ -0,0 +1,132 @@ +use std::collections::HashMap; + +/// Specification for a model: shortname mapped to HuggingFace GGUF location. +#[derive(Debug, Clone)] +pub struct ModelSpec { + /// Short name used by users (e.g. "gemma3:4b") + pub name: String, + /// HuggingFace repository (e.g. "bartowski/gemma-3-4b-it-GGUF") + pub hf_repo: String, + /// Filename within the repo (e.g. "gemma-3-4b-it-Q4_K_M.gguf") + pub hf_file: String, + /// Expected SHA-256 hex digest for verification (None = skip verification) + pub sha256: Option, + /// Chat template identifier (e.g. "gemma", "llama3", "chatml") + pub chat_template: Option, +} + +/// Build the default model registry with all 9 supported models. +pub fn default_registry() -> HashMap { + let entries = vec![ + ModelSpec { + name: "gemma3:4b".into(), + hf_repo: "bartowski/google_gemma-3-4b-it-GGUF".into(), + hf_file: "google_gemma-3-4b-it-Q4_K_M.gguf".into(), + sha256: None, + chat_template: Some("gemma".into()), + }, + ModelSpec { + name: "gemma3:12b".into(), + hf_repo: "bartowski/google_gemma-3-12b-it-GGUF".into(), + hf_file: "google_gemma-3-12b-it-Q4_K_M.gguf".into(), + sha256: None, + chat_template: Some("gemma".into()), + }, + ModelSpec { + name: "gemma3:27b".into(), + hf_repo: "bartowski/google_gemma-3-27b-it-GGUF".into(), + hf_file: "google_gemma-3-27b-it-Q4_K_M.gguf".into(), + sha256: None, + chat_template: Some("gemma".into()), + }, + ModelSpec { + name: "llama3.1:8b".into(), + hf_repo: "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF".into(), + hf_file: "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf".into(), + sha256: None, + chat_template: Some("llama3".into()), + }, + ModelSpec { + name: "llama3.2:1b".into(), + hf_repo: "bartowski/Llama-3.2-1B-Instruct-GGUF".into(), + hf_file: "Llama-3.2-1B-Instruct-Q4_K_M.gguf".into(), + sha256: None, + chat_template: Some("llama3".into()), + }, + ModelSpec { + name: "llama3.3:70b".into(), + hf_repo: "bartowski/Llama-3.3-70B-Instruct-GGUF".into(), + hf_file: "Llama-3.3-70B-Instruct-Q4_K_M.gguf".into(), + sha256: None, + chat_template: Some("llama3".into()), + }, + ModelSpec { + name: "mistral-nemo:12b".into(), + hf_repo: "bartowski/Mistral-Nemo-Instruct-2407-GGUF".into(), + hf_file: "Mistral-Nemo-Instruct-2407-Q4_K_M.gguf".into(), + sha256: None, + chat_template: Some("chatml".into()), + }, + ModelSpec { + name: "qwen3:8b".into(), + hf_repo: "bartowski/Qwen3-8B-GGUF".into(), + hf_file: "Qwen3-8B-Q4_K_M.gguf".into(), + sha256: None, + chat_template: Some("chatml".into()), + }, + ModelSpec { + name: "qwen3:32b".into(), + hf_repo: "bartowski/Qwen3-32B-GGUF".into(), + hf_file: "Qwen3-32B-Q4_K_M.gguf".into(), + sha256: None, + chat_template: Some("chatml".into()), + }, + ]; + + entries.into_iter().map(|s| (s.name.clone(), s)).collect() +} + +/// Resolve a user-provided model name to a ModelSpec from the registry. +pub fn resolve_model(name: &str, registry: &HashMap) -> Option { + registry.get(name).cloned() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_default_registry_has_all_models() { + let reg = default_registry(); + let expected = [ + "gemma3:4b", + "gemma3:12b", + "gemma3:27b", + "llama3.1:8b", + "llama3.2:1b", + "llama3.3:70b", + "mistral-nemo:12b", + "qwen3:8b", + "qwen3:32b", + ]; + for name in &expected { + assert!(reg.contains_key(*name), "missing model: {name}"); + } + assert_eq!(reg.len(), 9); + } + + #[test] + fn test_resolve_known_model() { + let reg = default_registry(); + let spec = resolve_model("gemma3:4b", ®).expect("should resolve"); + assert_eq!(spec.name, "gemma3:4b"); + assert!(spec.hf_repo.contains("gemma")); + assert!(spec.hf_file.ends_with(".gguf")); + } + + #[test] + fn test_resolve_unknown_model() { + let reg = default_registry(); + assert!(resolve_model("nonexistent:1b", ®).is_none()); + } +} diff --git a/src/models/template.rs b/src/models/template.rs new file mode 100644 index 00000000..060d4928 --- /dev/null +++ b/src/models/template.rs @@ -0,0 +1,163 @@ +use serde::{Deserialize, Serialize}; + +/// A single message in a chat conversation. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ChatMessage { + pub role: String, + pub content: String, +} + +/// Apply a chat template to a list of messages, producing a formatted prompt string. +/// +/// Supported templates: "chatml" (qwen, mistral-nemo), "llama3", "gemma". +/// Falls back to chatml for unknown template names. +pub fn apply_chat_template(template_name: &str, messages: &[ChatMessage]) -> String { + match template_name { + "llama3" => format_llama3(messages), + "gemma" => format_gemma(messages), + _ => format_chatml(messages), // chatml is the default fallback + } +} + +/// ChatML format used by Qwen, Mistral-Nemo, and others. +/// ```text +/// <|im_start|>system +/// You are a helpful assistant.<|im_end|> +/// <|im_start|>user +/// Hello<|im_end|> +/// <|im_start|>assistant +/// ``` +fn format_chatml(messages: &[ChatMessage]) -> String { + let mut out = String::new(); + for msg in messages { + out.push_str(&format!( + "<|im_start|>{}\n{}<|im_end|>\n", + msg.role, msg.content + )); + } + out.push_str("<|im_start|>assistant\n"); + out +} + +/// Llama 3 instruct format. +/// ```text +/// <|begin_of_text|><|start_header_id|>system<|end_header_id|> +/// +/// You are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|> +/// +/// Hello<|eot_id|><|start_header_id|>assistant<|end_header_id|> +/// +/// ``` +fn format_llama3(messages: &[ChatMessage]) -> String { + let mut out = String::from("<|begin_of_text|>"); + for msg in messages { + out.push_str(&format!( + "<|start_header_id|>{}<|end_header_id|>\n\n{}<|eot_id|>", + msg.role, msg.content + )); + } + out.push_str("<|start_header_id|>assistant<|end_header_id|>\n\n"); + out +} + +/// Gemma instruct format. +/// ```text +/// user +/// Hello +/// model +/// ``` +fn format_gemma(messages: &[ChatMessage]) -> String { + let mut out = String::new(); + for msg in messages { + // Gemma uses "model" instead of "assistant" + let role = if msg.role == "assistant" { + "model" + } else { + &msg.role + }; + out.push_str(&format!( + "{}\n{}\n", + role, msg.content + )); + } + out.push_str("model\n"); + out +} + +#[cfg(test)] +mod tests { + use super::*; + + fn sample_messages() -> Vec { + vec![ + ChatMessage { + role: "system".into(), + content: "You are a helpful assistant.".into(), + }, + ChatMessage { + role: "user".into(), + content: "Hello".into(), + }, + ] + } + + #[test] + fn test_chatml_format() { + let result = apply_chat_template("chatml", &sample_messages()); + assert!(result.contains("<|im_start|>system")); + assert!(result.contains("You are a helpful assistant.<|im_end|>")); + assert!(result.contains("<|im_start|>user")); + assert!(result.contains("Hello<|im_end|>")); + assert!(result.ends_with("<|im_start|>assistant\n")); + } + + #[test] + fn test_llama3_format() { + let result = apply_chat_template("llama3", &sample_messages()); + assert!(result.starts_with("<|begin_of_text|>")); + assert!(result.contains("<|start_header_id|>system<|end_header_id|>")); + assert!(result.contains("<|start_header_id|>user<|end_header_id|>")); + assert!(result.ends_with("<|start_header_id|>assistant<|end_header_id|>\n\n")); + } + + #[test] + fn test_gemma_format() { + let msgs = vec![ + ChatMessage { + role: "user".into(), + content: "Hello".into(), + }, + ChatMessage { + role: "assistant".into(), + content: "Hi there!".into(), + }, + ChatMessage { + role: "user".into(), + content: "How are you?".into(), + }, + ]; + let result = apply_chat_template("gemma", &msgs); + assert!(result.contains("user")); + // "assistant" should be mapped to "model" + assert!(result.contains("model\nHi there!")); + assert!(result.ends_with("model\n")); + } + + #[test] + fn test_unknown_template_falls_back_to_chatml() { + let result = apply_chat_template("unknown-template", &sample_messages()); + assert!(result.contains("<|im_start|>")); + } + + #[test] + fn test_chat_message_serde() { + let msg = ChatMessage { + role: "user".into(), + content: "hello".into(), + }; + let json = serde_json::to_string(&msg).unwrap(); + let roundtrip: ChatMessage = serde_json::from_str(&json).unwrap(); + assert_eq!(roundtrip.role, "user"); + assert_eq!(roundtrip.content, "hello"); + } +} diff --git a/src/network/auth.rs b/src/network/auth.rs new file mode 100644 index 00000000..af074c6c --- /dev/null +++ b/src/network/auth.rs @@ -0,0 +1,163 @@ +use serde::{Deserialize, Serialize}; + +use crate::error::NodeError; +use crate::identity::Identity; +use crate::network::protocol::{Capacity, read_framed, write_framed}; + +// --------------------------------------------------------------------------- +// Auth protocol types (separate from NodeMessage/RouterMessage since auth +// happens before the main protocol phase) +// --------------------------------------------------------------------------- + +/// Router sends this challenge after QUIC connection is established. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ChallengeMessage { + pub challenge: [u8; 32], +} + +/// Node responds with signed challenge + metadata. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AuthRequest { + /// Ethereum address (hex, no 0x prefix). + pub address: String, + /// Signature over SHA-256(challenge). + pub signature: Vec, + /// Recovery ID for the signature. + pub recovery_id: u8, + /// Models this node can serve. + pub models: Vec, + /// Benchmark tokens-per-second. + pub tps: f64, + /// Node software version. + pub version: String, + /// Current capacity. + pub capacity: Capacity, +} + +/// Router responds with auth result. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AuthResponse { + pub authenticated: bool, + /// Assigned node ID on success. + pub node_id: Option, + /// Error message on failure. + pub error: Option, +} + +// --------------------------------------------------------------------------- +// Handshake +// --------------------------------------------------------------------------- + +/// Perform the authentication handshake on an already-opened bi-directional stream. +/// +/// 1. Read `ChallengeMessage` from the router +/// 2. Sign the challenge with our identity +/// 3. Send `AuthRequest` with node metadata +/// 4. Read `AuthResponse` and return the assigned node_id +pub async fn authenticate( + send: &mut quinn::SendStream, + recv: &mut quinn::RecvStream, + identity: &Identity, + models: Vec, + tps: f64, + capacity: Capacity, +) -> Result { + // 1. Read challenge + let challenge_msg: ChallengeMessage = read_framed(recv) + .await? + .ok_or_else(|| NodeError::Network("connection closed before challenge".into()))?; + + // 2. Sign challenge + let (signature, recovery_id) = identity.sign(&challenge_msg.challenge); + + // 3. Send auth request + let auth_req = AuthRequest { + address: identity.address_hex.clone(), + signature: signature.serialize().to_vec(), + recovery_id: recovery_id.serialize(), + models, + tps, + version: env!("CARGO_PKG_VERSION").to_string(), + capacity, + }; + write_framed(send, &auth_req).await?; + + // 4. Read auth response + let auth_resp: AuthResponse = read_framed(recv) + .await? + .ok_or_else(|| NodeError::Network("connection closed before auth response".into()))?; + + if auth_resp.authenticated { + auth_resp + .node_id + .ok_or_else(|| NodeError::Network("auth succeeded but no node_id returned".into())) + } else { + Err(NodeError::Network(format!( + "authentication failed: {}", + auth_resp.error.unwrap_or_else(|| "unknown".into()) + ))) + } +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_challenge_message_roundtrip() { + let msg = ChallengeMessage { + challenge: [0x42; 32], + }; + let packed = rmp_serde::to_vec(&msg).unwrap(); + let roundtrip: ChallengeMessage = rmp_serde::from_slice(&packed).unwrap(); + assert_eq!(roundtrip.challenge, [0x42; 32]); + } + + #[test] + fn test_auth_request_roundtrip() { + let req = AuthRequest { + address: "deadbeef".into(), + signature: vec![1, 2, 3], + recovery_id: 0, + models: vec!["gemma3:4b".into()], + tps: 42.5, + version: "2.0.0".into(), + capacity: Capacity { free: 1, max: 2 }, + }; + let packed = rmp_serde::to_vec(&req).unwrap(); + let roundtrip: AuthRequest = rmp_serde::from_slice(&packed).unwrap(); + assert_eq!(roundtrip.address, "deadbeef"); + assert_eq!(roundtrip.models, vec!["gemma3:4b"]); + assert!((roundtrip.tps - 42.5).abs() < f64::EPSILON); + } + + #[test] + fn test_auth_response_success_roundtrip() { + let resp = AuthResponse { + authenticated: true, + node_id: Some("node-123".into()), + error: None, + }; + let packed = rmp_serde::to_vec(&resp).unwrap(); + let roundtrip: AuthResponse = rmp_serde::from_slice(&packed).unwrap(); + assert!(roundtrip.authenticated); + assert_eq!(roundtrip.node_id.unwrap(), "node-123"); + } + + #[test] + fn test_auth_response_failure_roundtrip() { + let resp = AuthResponse { + authenticated: false, + node_id: None, + error: Some("bad signature".into()), + }; + let packed = rmp_serde::to_vec(&resp).unwrap(); + let roundtrip: AuthResponse = rmp_serde::from_slice(&packed).unwrap(); + assert!(!roundtrip.authenticated); + assert_eq!(roundtrip.error.unwrap(), "bad signature"); + } +} diff --git a/src/network/connection.rs b/src/network/connection.rs new file mode 100644 index 00000000..722ded67 --- /dev/null +++ b/src/network/connection.rs @@ -0,0 +1,537 @@ +use std::net::SocketAddr; +use std::sync::Arc; +use std::time::Duration; + +use quinn::{ClientConfig, Endpoint, TransportConfig}; +use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}; +use rustls::pki_types::{CertificateDer, ServerName, UnixTime}; +use rustls::DigitallySignedStruct; + +use crate::error::NodeError; +use crate::identity::Identity; +use crate::network::auth::authenticate; +use crate::network::protocol::{Capacity, NodeMessage, RouterMessage, read_framed, write_framed}; + +/// Manages a QUIC connection to the router with a single bi-directional stream. +pub struct RouterConnection { + /// The underlying QUIC endpoint (kept alive for the connection's lifetime). + endpoint: Endpoint, + /// The underlying QUIC connection. + connection: quinn::Connection, + /// Send half of the bi-directional stream. + send: quinn::SendStream, + /// Receive half of the bi-directional stream. + recv: quinn::RecvStream, + /// Router URL for reconnection. + router_url: String, + /// Whether to skip TLS verification. + insecure: bool, + /// Assigned node ID from the router. + pub node_id: String, +} + +impl RouterConnection { + /// Establish a QUIC connection to the router, open a bi-stream, and authenticate. + pub async fn connect( + router_url: &str, + insecure: bool, + identity: &Identity, + models: Vec, + tps: f64, + capacity: Capacity, + ) -> Result { + let (host, port) = parse_url(router_url)?; + let addr = resolve_addr(&host, port).await?; + + let client_config = build_client_config(insecure)?; + let mut endpoint = Endpoint::client("0.0.0.0:0".parse().unwrap()) + .map_err(|e| NodeError::Network(format!("failed to create QUIC endpoint: {e}")))?; + endpoint.set_default_client_config(client_config); + + let connection = endpoint + .connect(addr, &host) + .map_err(|e| NodeError::Network(format!("QUIC connect failed: {e}")))? + .await + .map_err(|e| NodeError::Network(format!("QUIC handshake failed: {e}")))?; + + // Accept the bi-stream opened by the router (the router initiates the stream). + let (mut send, mut recv) = connection + .accept_bi() + .await + .map_err(|e| NodeError::Network(format!("failed to accept bi-stream: {e}")))?; + let node_id = authenticate(&mut send, &mut recv, identity, models, tps, capacity).await?; + tracing::info!(%node_id, "authenticated with router"); + + Ok(RouterConnection { + endpoint, + connection, + send, + recv, + router_url: router_url.to_string(), + insecure, + node_id, + }) + } + + /// Send a message to the router. + pub async fn send(&mut self, msg: &NodeMessage) -> Result<(), NodeError> { + write_framed(&mut self.send, msg).await + } + + /// Receive a message from the router. Returns `None` on clean stream close. + pub async fn recv(&mut self) -> Result, NodeError> { + read_framed(&mut self.recv).await + } + + /// Attempt to reconnect with exponential backoff. + /// + /// Retries: 1s → 2s → 4s → 8s → ... → 60s cap. + pub async fn reconnect( + &mut self, + identity: &Identity, + models: Vec, + tps: f64, + capacity: Capacity, + ) -> Result<(), NodeError> { + let mut delay = Duration::from_secs(1); + let max_delay = Duration::from_secs(60); + + loop { + tracing::info!(delay_secs = delay.as_secs(), "attempting reconnect"); + tokio::time::sleep(delay).await; + + match Self::connect(&self.router_url, self.insecure, identity, models.clone(), tps, capacity.clone()) + .await + { + Ok(new_conn) => { + self.endpoint = new_conn.endpoint; + self.connection = new_conn.connection; + self.send = new_conn.send; + self.recv = new_conn.recv; + self.node_id = new_conn.node_id; + tracing::info!(node_id = %self.node_id, "reconnected to router"); + return Ok(()); + } + Err(e) => { + tracing::warn!(%e, "reconnect failed"); + delay = (delay * 2).min(max_delay); + } + } + } + } + + /// Close the connection and endpoint gracefully. + pub fn close(&self) { + self.connection.close(0u32.into(), b"shutdown"); + self.endpoint.close(0u32.into(), b"shutdown"); + } +} + +// --------------------------------------------------------------------------- +// TLS configuration +// --------------------------------------------------------------------------- + +fn build_client_config(insecure: bool) -> Result { + let crypto = if insecure { + rustls::ClientConfig::builder() + .dangerous() + .with_custom_certificate_verifier(Arc::new(SkipServerVerification)) + .with_no_client_auth() + } else { + let mut root_store = rustls::RootCertStore::empty(); + for cert in rustls_native_certs::load_native_certs().certs { + root_store.add(cert).ok(); + } + rustls::ClientConfig::builder() + .with_root_certificates(root_store) + .with_no_client_auth() + }; + + let mut transport = TransportConfig::default(); + transport.keep_alive_interval(Some(Duration::from_secs(20))); + transport.max_idle_timeout(Some( + Duration::from_secs(60) + .try_into() + .map_err(|e| NodeError::Network(format!("invalid idle timeout: {e}")))?, + )); + + let mut client_config = ClientConfig::new(Arc::new( + quinn::crypto::rustls::QuicClientConfig::try_from(crypto) + .map_err(|e| NodeError::Network(format!("QUIC crypto config: {e}")))?, + )); + client_config.transport_config(Arc::new(transport)); + + Ok(client_config) +} + +/// TLS verifier that accepts any certificate (for development/testing with `--insecure`). +#[derive(Debug)] +struct SkipServerVerification; + +impl ServerCertVerifier for SkipServerVerification { + fn verify_server_cert( + &self, + _end_entity: &CertificateDer<'_>, + _intermediates: &[CertificateDer<'_>], + _server_name: &ServerName<'_>, + _ocsp_response: &[u8], + _now: UnixTime, + ) -> Result { + Ok(ServerCertVerified::assertion()) + } + + fn verify_tls12_signature( + &self, + _message: &[u8], + _cert: &CertificateDer<'_>, + _dss: &DigitallySignedStruct, + ) -> Result { + Ok(HandshakeSignatureValid::assertion()) + } + + fn verify_tls13_signature( + &self, + _message: &[u8], + _cert: &CertificateDer<'_>, + _dss: &DigitallySignedStruct, + ) -> Result { + Ok(HandshakeSignatureValid::assertion()) + } + + fn supported_verify_schemes(&self) -> Vec { + rustls::crypto::ring::default_provider() + .signature_verification_algorithms + .supported_schemes() + } +} + +// --------------------------------------------------------------------------- +// URL parsing and DNS resolution +// --------------------------------------------------------------------------- + +fn parse_url(url: &str) -> Result<(String, u16), NodeError> { + // Support both "host:port" and "https://host:port" formats + let stripped = url + .strip_prefix("https://") + .or_else(|| url.strip_prefix("quic://")) + .unwrap_or(url); + + let (host, port) = if let Some((h, p)) = stripped.rsplit_once(':') { + let port: u16 = p + .parse() + .map_err(|_| NodeError::Network(format!("invalid port in URL: {url}")))?; + (h.to_string(), port) + } else { + (stripped.to_string(), 4001) // default QUIC port + }; + + Ok((host, port)) +} + +async fn resolve_addr(host: &str, port: u16) -> Result { + // Try parsing as IP address first + if let Ok(ip) = host.parse::() { + return Ok(SocketAddr::new(ip, port)); + } + + // DNS resolution + let addrs: Vec = tokio::net::lookup_host(format!("{host}:{port}")) + .await + .map_err(|e| NodeError::Network(format!("DNS resolution failed for {host}: {e}")))? + .collect(); + + addrs + .into_iter() + .next() + .ok_or_else(|| NodeError::Network(format!("no addresses found for {host}"))) +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_url_with_scheme() { + let (host, port) = parse_url("https://router.dria.co:4001").unwrap(); + assert_eq!(host, "router.dria.co"); + assert_eq!(port, 4001); + } + + #[test] + fn test_parse_url_quic_scheme() { + let (host, port) = parse_url("quic://router.dria.co:5000").unwrap(); + assert_eq!(host, "router.dria.co"); + assert_eq!(port, 5000); + } + + #[test] + fn test_parse_url_no_scheme() { + let (host, port) = parse_url("router.dria.co:4001").unwrap(); + assert_eq!(host, "router.dria.co"); + assert_eq!(port, 4001); + } + + #[test] + fn test_parse_url_default_port() { + let (host, port) = parse_url("https://router.dria.co").unwrap(); + assert_eq!(host, "router.dria.co"); + assert_eq!(port, 4001); + } + + #[test] + fn test_parse_url_ip_address() { + let (host, port) = parse_url("127.0.0.1:4001").unwrap(); + assert_eq!(host, "127.0.0.1"); + assert_eq!(port, 4001); + } + + #[test] + fn test_build_client_config_insecure() { + let config = build_client_config(true); + assert!(config.is_ok()); + } + + #[test] + fn test_build_client_config_secure() { + let config = build_client_config(false); + assert!(config.is_ok()); + } + + /// Integration test: QUIC raw stream exchange with local self-signed server. + /// Tests the full flow: connect, open stream, exchange framed messages. + #[tokio::test] + async fn test_quic_connection_with_local_server() { + tokio::time::timeout(Duration::from_secs(10), async { + // Generate self-signed cert + let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap(); + let cert_der = CertificateDer::from(cert.cert); + let key_der = + rustls::pki_types::PrivatePkcs8KeyDer::from(cert.key_pair.serialize_der()); + + // Build server config + let server_crypto = rustls::ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(vec![cert_der.clone()], key_der.into()) + .unwrap(); + + let mut server_config = quinn::ServerConfig::with_crypto(Arc::new( + quinn::crypto::rustls::QuicServerConfig::try_from(server_crypto).unwrap(), + )); + let mut transport = TransportConfig::default(); + transport.max_concurrent_bidi_streams(8u32.into()); + server_config.transport_config(Arc::new(transport)); + + // Bind server + let server_endpoint = + Endpoint::server(server_config, "127.0.0.1:0".parse().unwrap()).unwrap(); + let server_addr = server_endpoint.local_addr().unwrap(); + + // Use a oneshot to signal server completion + let (tx, rx) = tokio::sync::oneshot::channel::<()>(); + + // Spawn server task — the server opens the bi-stream (router initiates) + tokio::spawn(async move { + let incoming = server_endpoint.accept().await.unwrap(); + let server_conn = incoming.await.unwrap(); + + // Server opens a bi-stream to the client + let (mut send, mut recv) = server_conn.open_bi().await.unwrap(); + + // Send challenge + let challenge = crate::network::auth::ChallengeMessage { + challenge: [0xAA; 32], + }; + write_framed(&mut send, &challenge).await.unwrap(); + + // Read auth request + let auth_req: crate::network::auth::AuthRequest = + read_framed(&mut recv).await.unwrap().unwrap(); + assert!(!auth_req.address.is_empty()); + assert_eq!(auth_req.models, vec!["gemma3:4b"]); + + // Send auth response + let auth_resp = crate::network::auth::AuthResponse { + authenticated: true, + node_id: Some("test-node-1".into()), + error: None, + }; + write_framed(&mut send, &auth_resp).await.unwrap(); + + // Read a NodeMessage + let msg: NodeMessage = read_framed(&mut recv).await.unwrap().unwrap(); + match msg { + NodeMessage::StatusUpdate { version, .. } => { + assert_eq!(version, env!("CARGO_PKG_VERSION")); + } + _ => panic!("expected StatusUpdate"), + } + + // Signal completion + let _ = tx.send(()); + server_conn.close(0u32.into(), b"done"); + server_endpoint.close(0u32.into(), b"shutdown"); + }); + + // Build client config + let client_config = build_client_config(true).unwrap(); + let mut client_endpoint = + Endpoint::client("0.0.0.0:0".parse().unwrap()).unwrap(); + client_endpoint.set_default_client_config(client_config); + + // Connect to server + let client_conn = client_endpoint + .connect(server_addr, "localhost") + .unwrap() + .await + .unwrap(); + + // Client accepts the bi-stream opened by the server + let (mut send, mut recv) = client_conn.accept_bi().await.unwrap(); + + // Run the auth handshake + let identity = Identity::from_secret_hex( + "6472696164726961647269616472696164726961647269616472696164726961", + ) + .unwrap(); + + let node_id = authenticate( + &mut send, + &mut recv, + &identity, + vec!["gemma3:4b".into()], + 42.0, + Capacity { free: 1, max: 2 }, + ) + .await + .unwrap(); + + assert_eq!(node_id, "test-node-1"); + + // Send a status update + let status = NodeMessage::StatusUpdate { + models: vec!["gemma3:4b".into()], + capacity: Capacity { free: 1, max: 2 }, + version: env!("CARGO_PKG_VERSION").to_string(), + }; + write_framed(&mut send, &status).await.unwrap(); + + // Wait for server to confirm receipt + rx.await.expect("server did not signal completion"); + + client_conn.close(0u32.into(), b"done"); + client_endpoint.close(0u32.into(), b"shutdown"); + }) + .await + .expect("test timed out"); + } + + /// Integration test: Full RouterConnection::connect flow with a mock router. + #[tokio::test] + async fn test_router_connection_connect() { + tokio::time::timeout(Duration::from_secs(10), async { + let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap(); + let cert_der = CertificateDer::from(cert.cert); + let key_der = + rustls::pki_types::PrivatePkcs8KeyDer::from(cert.key_pair.serialize_der()); + + let server_crypto = rustls::ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(vec![cert_der.clone()], key_der.into()) + .unwrap(); + + let server_config = quinn::ServerConfig::with_crypto(Arc::new( + quinn::crypto::rustls::QuicServerConfig::try_from(server_crypto).unwrap(), + )); + + let server_endpoint = + Endpoint::server(server_config, "127.0.0.1:0".parse().unwrap()).unwrap(); + let server_addr = server_endpoint.local_addr().unwrap(); + + let (tx, rx) = tokio::sync::oneshot::channel::<()>(); + + // Mock router: accept connection, open bi-stream, run auth, read one message + tokio::spawn(async move { + let incoming = server_endpoint.accept().await.unwrap(); + let server_conn = incoming.await.unwrap(); + let (mut send, mut recv) = server_conn.open_bi().await.unwrap(); + + // Challenge-response auth + write_framed( + &mut send, + &crate::network::auth::ChallengeMessage { + challenge: [0xBB; 32], + }, + ) + .await + .unwrap(); + + let _auth_req: crate::network::auth::AuthRequest = + read_framed(&mut recv).await.unwrap().unwrap(); + + write_framed( + &mut send, + &crate::network::auth::AuthResponse { + authenticated: true, + node_id: Some("node-42".into()), + error: None, + }, + ) + .await + .unwrap(); + + // Send a ping + write_framed(&mut send, &RouterMessage::Ping).await.unwrap(); + + // Read the status update response + let msg: NodeMessage = read_framed(&mut recv).await.unwrap().unwrap(); + assert!(matches!(msg, NodeMessage::StatusUpdate { .. })); + + let _ = tx.send(()); + server_conn.close(0u32.into(), b"done"); + server_endpoint.close(0u32.into(), b"shutdown"); + }); + + // Use RouterConnection::connect + let url = format!("127.0.0.1:{}", server_addr.port()); + let identity = Identity::from_secret_hex( + "6472696164726961647269616472696164726961647269616472696164726961", + ) + .unwrap(); + + let mut conn = RouterConnection::connect( + &url, + true, + &identity, + vec!["gemma3:4b".into()], + 50.0, + Capacity { free: 2, max: 4 }, + ) + .await + .unwrap(); + + assert_eq!(conn.node_id, "node-42"); + + // Receive ping from router + let msg = conn.recv().await.unwrap().unwrap(); + assert!(matches!(msg, RouterMessage::Ping)); + + // Send status update + conn.send(&NodeMessage::StatusUpdate { + models: vec!["gemma3:4b".into()], + capacity: Capacity { free: 2, max: 4 }, + version: env!("CARGO_PKG_VERSION").to_string(), + }) + .await + .unwrap(); + + rx.await.expect("server did not signal completion"); + conn.close(); + }) + .await + .expect("test timed out"); + } +} diff --git a/src/network/mod.rs b/src/network/mod.rs new file mode 100644 index 00000000..9234956f --- /dev/null +++ b/src/network/mod.rs @@ -0,0 +1,6 @@ +pub mod auth; +pub mod connection; +pub mod protocol; + +pub use connection::RouterConnection; +pub use protocol::{NodeMessage, RouterMessage}; diff --git a/src/network/protocol.rs b/src/network/protocol.rs new file mode 100644 index 00000000..91749dfe --- /dev/null +++ b/src/network/protocol.rs @@ -0,0 +1,329 @@ +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +use crate::error::NodeError; +use crate::models::template::ChatMessage; + +// --------------------------------------------------------------------------- +// Node → Router messages +// --------------------------------------------------------------------------- + +/// Messages sent from this compute node to the router. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum NodeMessage { + /// Completed inference result for a task. + TaskResult { + task_id: Uuid, + text: String, + stats: TaskStats, + proof: Option, + }, + /// We cannot accept the assigned task. + TaskRejected { + task_id: Uuid, + reason: RejectReason, + }, + /// Periodic or on-demand status snapshot. + StatusUpdate { + models: Vec, + capacity: Capacity, + version: String, + }, + /// Response to a router challenge (placeholder). + ChallengeResponse { + challenge: [u8; 32], + signature: Vec, + recovery_id: u8, + }, +} + +// --------------------------------------------------------------------------- +// Router → Node messages +// --------------------------------------------------------------------------- + +/// Messages sent from the router to this compute node. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum RouterMessage { + /// A new inference task to execute. + TaskAssignment { + task_id: Uuid, + model: String, + messages: Vec, + max_tokens: u32, + temperature: f32, + validation: Option, + }, + /// Challenge for proof-of-liveness. + Challenge { challenge: [u8; 32] }, + /// Heartbeat / keep-alive ping. + Ping, + /// Updated model registry from the router. + ModelRegistryUpdate { + entries: Vec, + }, +} + +// --------------------------------------------------------------------------- +// Supporting types +// --------------------------------------------------------------------------- + +/// Statistics about a completed inference task. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TaskStats { + pub tokens_generated: u32, + pub prompt_tokens: u32, + pub generation_time_ms: u64, + pub tokens_per_second: f64, +} + +/// Reason a task was rejected. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum RejectReason { + /// Model not loaded on this node. + ModelNotLoaded, + /// All inference slots are busy. + AtCapacity, + /// Task parameters are invalid. + InvalidRequest(String), +} + +/// Current capacity snapshot. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Capacity { + /// Number of free inference slots. + pub free: usize, + /// Maximum concurrent inference slots. + pub max: usize, +} + +/// Optional validation parameters included with a task. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ValidationRequest { + /// Token positions at which to extract logprobs. + pub logprob_positions: Vec, + /// Top-k alternatives to collect at each logprob position. + pub logprob_top_k: usize, +} + +/// A model entry from the router's registry. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ModelRegistryEntry { + pub name: String, + pub hf_repo: String, + pub hf_file: String, + pub chat_template: Option, +} + +// --------------------------------------------------------------------------- +// Length-prefixed MessagePack framing +// --------------------------------------------------------------------------- + +/// Maximum allowed message size (16 MB). +const MAX_MESSAGE_SIZE: u32 = 16 * 1024 * 1024; + +/// Write a length-prefixed MessagePack message to a QUIC send stream. +/// +/// Wire format: `[4-byte BE length][msgpack payload]` +pub async fn write_framed( + send: &mut quinn::SendStream, + msg: &T, +) -> Result<(), NodeError> { + let payload = + rmp_serde::to_vec(msg).map_err(|e| NodeError::Network(format!("serialize: {e}")))?; + let len = payload.len() as u32; + if len > MAX_MESSAGE_SIZE { + return Err(NodeError::Network(format!( + "message too large: {len} bytes (max {MAX_MESSAGE_SIZE})" + ))); + } + send.write_all(&len.to_be_bytes()) + .await + .map_err(|e| NodeError::Network(format!("write length: {e}")))?; + send.write_all(&payload) + .await + .map_err(|e| NodeError::Network(format!("write payload: {e}")))?; + Ok(()) +} + +/// Read a length-prefixed MessagePack message from a QUIC receive stream. +/// +/// Returns `Ok(None)` on clean EOF (stream closed), `Err` on protocol violations. +pub async fn read_framed( + recv: &mut quinn::RecvStream, +) -> Result, NodeError> { + let mut len_buf = [0u8; 4]; + match recv.read_exact(&mut len_buf).await { + Ok(()) => {} + Err(quinn::ReadExactError::FinishedEarly(_)) => return Ok(None), + Err(e) => return Err(NodeError::Network(format!("read length: {e}"))), + } + let len = u32::from_be_bytes(len_buf); + if len > MAX_MESSAGE_SIZE { + return Err(NodeError::Network(format!( + "message too large: {len} bytes (max {MAX_MESSAGE_SIZE})" + ))); + } + let mut payload = vec![0u8; len as usize]; + recv.read_exact(&mut payload) + .await + .map_err(|e| NodeError::Network(format!("read payload: {e}")))?; + let msg = rmp_serde::from_slice(&payload) + .map_err(|e| NodeError::Network(format!("deserialize: {e}")))?; + Ok(Some(msg)) +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_node_message_roundtrip() { + let msg = NodeMessage::TaskResult { + task_id: Uuid::nil(), + text: "Hello world".into(), + stats: TaskStats { + tokens_generated: 10, + prompt_tokens: 5, + generation_time_ms: 100, + tokens_per_second: 100.0, + }, + proof: None, + }; + let packed = rmp_serde::to_vec(&msg).unwrap(); + let roundtrip: NodeMessage = rmp_serde::from_slice(&packed).unwrap(); + match roundtrip { + NodeMessage::TaskResult { task_id, text, .. } => { + assert_eq!(task_id, Uuid::nil()); + assert_eq!(text, "Hello world"); + } + _ => panic!("wrong variant"), + } + } + + #[test] + fn test_router_message_roundtrip() { + let msg = RouterMessage::TaskAssignment { + task_id: Uuid::nil(), + model: "gemma3:4b".into(), + messages: vec![ChatMessage { + role: "user".into(), + content: "hello".into(), + }], + max_tokens: 512, + temperature: 0.7, + validation: None, + }; + let packed = rmp_serde::to_vec(&msg).unwrap(); + let roundtrip: RouterMessage = rmp_serde::from_slice(&packed).unwrap(); + match roundtrip { + RouterMessage::TaskAssignment { model, .. } => { + assert_eq!(model, "gemma3:4b"); + } + _ => panic!("wrong variant"), + } + } + + #[test] + fn test_reject_reason_roundtrip() { + let msg = NodeMessage::TaskRejected { + task_id: Uuid::nil(), + reason: RejectReason::AtCapacity, + }; + let packed = rmp_serde::to_vec(&msg).unwrap(); + let roundtrip: NodeMessage = rmp_serde::from_slice(&packed).unwrap(); + match roundtrip { + NodeMessage::TaskRejected { reason, .. } => { + matches!(reason, RejectReason::AtCapacity); + } + _ => panic!("wrong variant"), + } + } + + #[test] + fn test_status_update_roundtrip() { + let msg = NodeMessage::StatusUpdate { + models: vec!["gemma3:4b".into()], + capacity: Capacity { free: 2, max: 4 }, + version: "2.0.0".into(), + }; + let packed = rmp_serde::to_vec(&msg).unwrap(); + let roundtrip: NodeMessage = rmp_serde::from_slice(&packed).unwrap(); + match roundtrip { + NodeMessage::StatusUpdate { + capacity, version, .. + } => { + assert_eq!(capacity.free, 2); + assert_eq!(capacity.max, 4); + assert_eq!(version, "2.0.0"); + } + _ => panic!("wrong variant"), + } + } + + #[test] + fn test_challenge_roundtrip() { + let msg = RouterMessage::Challenge { + challenge: [0xAB; 32], + }; + let packed = rmp_serde::to_vec(&msg).unwrap(); + let roundtrip: RouterMessage = rmp_serde::from_slice(&packed).unwrap(); + match roundtrip { + RouterMessage::Challenge { challenge } => { + assert_eq!(challenge, [0xAB; 32]); + } + _ => panic!("wrong variant"), + } + } + + #[test] + fn test_ping_roundtrip() { + let packed = rmp_serde::to_vec(&RouterMessage::Ping).unwrap(); + let roundtrip: RouterMessage = rmp_serde::from_slice(&packed).unwrap(); + assert!(matches!(roundtrip, RouterMessage::Ping)); + } + + #[test] + fn test_model_registry_update_roundtrip() { + let msg = RouterMessage::ModelRegistryUpdate { + entries: vec![ModelRegistryEntry { + name: "test:1b".into(), + hf_repo: "repo/model".into(), + hf_file: "model.gguf".into(), + chat_template: Some("chatml".into()), + }], + }; + let packed = rmp_serde::to_vec(&msg).unwrap(); + let roundtrip: RouterMessage = rmp_serde::from_slice(&packed).unwrap(); + match roundtrip { + RouterMessage::ModelRegistryUpdate { entries } => { + assert_eq!(entries.len(), 1); + assert_eq!(entries[0].name, "test:1b"); + } + _ => panic!("wrong variant"), + } + } + + /// Test framing over a quinn duplex (uses tokio::io::duplex via quinn test helpers). + /// Since we can't easily create a quinn stream in unit tests, test the serialization + /// logic directly and verify size limits. + #[test] + fn test_message_size_within_limit() { + let msg = NodeMessage::TaskResult { + task_id: Uuid::nil(), + text: "x".repeat(1000), + stats: TaskStats { + tokens_generated: 100, + prompt_tokens: 50, + generation_time_ms: 500, + tokens_per_second: 200.0, + }, + proof: None, + }; + let packed = rmp_serde::to_vec(&msg).unwrap(); + assert!((packed.len() as u32) < MAX_MESSAGE_SIZE); + } +} diff --git a/src/worker.rs b/src/worker.rs new file mode 100644 index 00000000..09acce20 --- /dev/null +++ b/src/worker.rs @@ -0,0 +1,261 @@ +use std::ops::ControlFlow; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; + +use futures::stream::FuturesUnordered; +use futures::StreamExt; +use tokio::task::JoinHandle; +use uuid::Uuid; + +use crate::error::NodeError; +use crate::inference::{GenerateParams, InferenceEngine, InferenceResult}; +use crate::models::template::{ChatMessage, apply_chat_template}; +use crate::network::protocol::{ + Capacity, NodeMessage, RejectReason, TaskStats, ValidationRequest, +}; + +/// A completed inference task ready to be sent back. +pub struct CompletedTask { + pub task_id: Uuid, + pub result: Result, +} + +/// Executes inference tasks with backpressure via capacity tracking. +pub struct Worker { + engine: Arc, + /// Chat template name for prompt formatting. + chat_template: String, + /// Number of available inference slots (CAS-based). + capacity: Arc, + /// Maximum concurrent slots. + max_capacity: usize, + /// Models this worker serves. + model_names: Vec, + /// In-flight tasks tracked via FuturesUnordered. + in_flight: FuturesUnordered>, +} + +impl Worker { + /// Create a new worker wrapping an inference engine. + pub fn new( + engine: InferenceEngine, + chat_template: String, + model_names: Vec, + max_concurrent: usize, + ) -> Self { + Worker { + engine: Arc::new(engine), + chat_template, + capacity: Arc::new(AtomicUsize::new(max_concurrent)), + max_capacity: max_concurrent, + model_names, + in_flight: FuturesUnordered::new(), + } + } + + /// Try to accept a task. Returns `Err(RejectReason)` if the task cannot be accepted. + /// + /// On success, spawns inference in a blocking thread and returns immediately. + pub fn try_accept( + &self, + task_id: Uuid, + model: &str, + messages: Vec, + max_tokens: u32, + temperature: f32, + validation: Option, + ) -> Result<(), RejectReason> { + // Check model + if !self.model_names.iter().any(|m| m == model) { + return Err(RejectReason::ModelNotLoaded); + } + + // Try to decrement capacity (CAS loop) + loop { + let current = self.capacity.load(Ordering::Acquire); + if current == 0 { + return Err(RejectReason::AtCapacity); + } + if self + .capacity + .compare_exchange_weak(current, current - 1, Ordering::AcqRel, Ordering::Acquire) + .is_ok() + { + break; + } + } + + // Build generate params + let params = GenerateParams { + max_tokens, + temperature, + top_p: 0.9, + seed: None, + logprob_positions: validation + .as_ref() + .map(|v| v.logprob_positions.clone()) + .unwrap_or_default(), + logprob_top_k: validation.as_ref().map(|v| v.logprob_top_k).unwrap_or(5), + }; + + let engine = Arc::clone(&self.engine); + let capacity = Arc::clone(&self.capacity); + let template = self.chat_template.clone(); + + let handle = tokio::task::spawn_blocking(move || { + let result = run_inference(&engine, &template, messages, ¶ms, task_id); + // Release capacity slot regardless of outcome + capacity.fetch_add(1, Ordering::Release); + result + }); + + self.in_flight.push(handle); + Ok(()) + } + + /// Poll for the next completed task. + /// + /// Returns `None` when no tasks are in-flight. When used in `tokio::select!`, + /// the branch will be skipped when there's nothing to poll. + pub async fn next_completed(&mut self) -> Option { + let join_result = self.in_flight.next().await?; + match join_result { + Ok(completed) => Some(completed), + Err(e) => { + tracing::error!(%e, "task panicked"); + None + } + } + } + + /// Current capacity snapshot. + pub fn capacity(&self) -> Capacity { + Capacity { + free: self.capacity.load(Ordering::Acquire), + max: self.max_capacity, + } + } + + /// Model names this worker serves. + pub fn model_names(&self) -> &[String] { + &self.model_names + } + + /// Whether there are any in-flight tasks. + pub fn has_in_flight(&self) -> bool { + !self.in_flight.is_empty() + } +} + +/// Run inference synchronously (called from `spawn_blocking`). +fn run_inference( + engine: &InferenceEngine, + template: &str, + messages: Vec, + params: &GenerateParams, + task_id: Uuid, +) -> CompletedTask { + let prompt = apply_chat_template(template, &messages); + + match engine.generate(&prompt, params, |_| ControlFlow::Continue(())) { + Ok(result) => CompletedTask { + task_id, + result: Ok(build_task_result(task_id, result)), + }, + Err(e) => CompletedTask { + task_id, + result: Err(e), + }, + } +} + +fn build_task_result(task_id: Uuid, result: InferenceResult) -> NodeMessage { + NodeMessage::TaskResult { + task_id, + text: result.text, + stats: TaskStats { + tokens_generated: result.tokens_generated, + prompt_tokens: result.prompt_tokens, + generation_time_ms: result.generation_time_ms, + tokens_per_second: result.tokens_per_second, + }, + proof: result.proof, + } +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + // Helper: create a worker with no real engine (tests that don't need inference) + // We can't easily mock InferenceEngine, so we test capacity logic directly. + + #[test] + fn test_capacity_tracking() { + let cap = Arc::new(AtomicUsize::new(3)); + + // Decrement + assert_eq!(cap.fetch_sub(1, Ordering::AcqRel), 3); + assert_eq!(cap.load(Ordering::Acquire), 2); + + // Increment back + cap.fetch_add(1, Ordering::Release); + assert_eq!(cap.load(Ordering::Acquire), 3); + } + + #[test] + fn test_capacity_struct() { + let c = Capacity { free: 2, max: 4 }; + assert_eq!(c.free, 2); + assert_eq!(c.max, 4); + } + + #[test] + fn test_reject_reason_model_not_loaded() { + let reason = RejectReason::ModelNotLoaded; + let packed = rmp_serde::to_vec(&reason).unwrap(); + let roundtrip: RejectReason = rmp_serde::from_slice(&packed).unwrap(); + assert!(matches!(roundtrip, RejectReason::ModelNotLoaded)); + } + + #[test] + fn test_reject_reason_at_capacity() { + let reason = RejectReason::AtCapacity; + let packed = rmp_serde::to_vec(&reason).unwrap(); + let roundtrip: RejectReason = rmp_serde::from_slice(&packed).unwrap(); + assert!(matches!(roundtrip, RejectReason::AtCapacity)); + } + + #[test] + fn test_completed_task_success() { + let msg = NodeMessage::TaskResult { + task_id: Uuid::nil(), + text: "Hello".into(), + stats: TaskStats { + tokens_generated: 5, + prompt_tokens: 3, + generation_time_ms: 50, + tokens_per_second: 100.0, + }, + proof: None, + }; + let completed = CompletedTask { + task_id: Uuid::nil(), + result: Ok(msg), + }; + assert!(completed.result.is_ok()); + } + + #[test] + fn test_completed_task_error() { + let completed = CompletedTask { + task_id: Uuid::nil(), + result: Err(NodeError::Inference("test error".into())), + }; + assert!(completed.result.is_err()); + } +} diff --git a/utils/Cargo.toml b/utils/Cargo.toml deleted file mode 100644 index 51f5fa36..00000000 --- a/utils/Cargo.toml +++ /dev/null @@ -1,40 +0,0 @@ -[package] -name = "dkn-utils" -version.workspace = true -edition.workspace = true -license.workspace = true -readme = "README.md" -authors = ["Erhan Tezcan "] - -[features] -crypto = [ - "ecies", - "libsecp256k1", - "libp2p-identity", - "sha2", - "sha3", - "hex", - "base64", -] - -[dependencies] -serde.workspace = true -serde_json.workspace = true - -ecies = { version = "0.2", default-features = false, features = [ - "pure", -], optional = true } -libsecp256k1 = { version = "0.7.1", optional = true } -libp2p-identity = { version = "0.2.10", features = [ - "secp256k1", - "peerid", -], optional = true } -sha2 = { version = "0.10.8", optional = true } -sha3 = { version = "0.10.8", optional = true } -hex = { version = "0.4.3", optional = true } -base64 = { version = "0.22.0", optional = true } - -public-ip-address = "0.3.2" -chrono.workspace = true -uuid.workspace = true -thiserror.workspace = true diff --git a/utils/README.md b/utils/README.md deleted file mode 100644 index 91319089..00000000 --- a/utils/README.md +++ /dev/null @@ -1,19 +0,0 @@ -# Dria Utils - -Just small utility functions such as reading environment variables or splitting strings etc. - -## Installation - -Add the package via `git` within your Cargo dependencies: - -```toml -dkn-utils = { git = "https://github.com/firstbatchxyz/dkn-compute-node" } -``` - -## Usage - -```rs -use dkn_utils::*; - -// use whatever you like! -``` diff --git a/utils/src/crypto.rs b/utils/src/crypto.rs deleted file mode 100644 index c42ef52a..00000000 --- a/utils/src/crypto.rs +++ /dev/null @@ -1,115 +0,0 @@ -use libp2p_identity; -use sha2::{Digest, Sha256}; -use sha3::Keccak256; - -/// Generic SHA256 function. -#[inline(always)] -pub fn sha256hash(data: impl AsRef<[u8]>) -> [u8; 32] { - Sha256::digest(data).into() -} - -/// Generic KECCAK256 function. -#[inline(always)] -pub fn keccak256hash(data: impl AsRef<[u8]>) -> [u8; 32] { - Keccak256::digest(data).into() -} - -/// Converts a `libsecp256k1::SecretKey` to a `libp2p_identity::secp256k1::Keypair`. -/// To do this, we serialize the secret key and create a new keypair from it. -#[inline] -pub fn secret_to_keypair(secret_key: &libsecp256k1::SecretKey) -> libp2p_identity::Keypair { - let bytes = secret_key.serialize(); - - let secret_key = libp2p_identity::secp256k1::SecretKey::try_from_bytes(bytes) - .expect("Failed to create secret key"); - libp2p_identity::secp256k1::Keypair::from(secret_key).into() -} - -/// Given a secp256k1 public key, finds the corresponding Ethereum address. -/// -/// Internally, the public key is serialized in uncompressed format at 65 bytes (0x04 || x || y), -/// and then (x || y) is hashed using Keccak256. The last 20 bytes of this hash is taken as the address. -#[inline] -pub fn public_key_to_address(public_key: &libsecp256k1::PublicKey) -> [u8; 20] { - let public_key_xy = &public_key.serialize()[1..]; - let mut addr = [0u8; 20]; - addr.copy_from_slice(&keccak256hash(public_key_xy)[12..32]); - addr -} - -/// Converts a `libsecp256k1::PublicKey` to a `libp2p_identity::PeerId`. -/// To do this, we serialize the secret key and create a new keypair from it. -#[inline] -pub fn public_key_to_peer_id(public_key: &libsecp256k1::PublicKey) -> libp2p_identity::PeerId { - let bytes = public_key.serialize_compressed(); - - let public_key = libp2p_identity::secp256k1::PublicKey::try_from_bytes(&bytes) - .expect("failed to create secret key"); - - libp2p_identity::PeerId::from_public_key(&public_key.into()) -} - -#[cfg(test)] -mod tests { - use super::*; - use ecies::{decrypt, encrypt}; - use hex::decode; - use libsecp256k1::{recover, sign, verify, Message, PublicKey, SecretKey}; - - const DUMMY_SECRET_KEY: &[u8; 32] = b"driadriadriadriadriadriadriadria"; - const MESSAGE: &[u8] = b"hello world"; - - #[test] - fn test_hash() { - // sha256 of "hello world" - let expected = "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9"; - let expected = decode(expected).expect("Should decode hex string."); - assert_eq!(sha256hash(MESSAGE), expected.as_slice()); - } - - #[test] - fn test_address() { - let sk = SecretKey::parse_slice(DUMMY_SECRET_KEY).expect("Should parse key."); - let pk = PublicKey::from_secret_key(&sk); - let addr = public_key_to_address(&pk); - assert_eq!( - "D79Fdf178547614CFdd0dF6397c53569716Bd596".to_lowercase(), - hex::encode(addr) - ); - } - - #[test] - fn test_encrypt_decrypt() { - let sk = SecretKey::parse_slice(DUMMY_SECRET_KEY).expect("Should parse private key slice."); - let pk = PublicKey::from_secret_key(&sk); - let (sk, pk) = (&sk.serialize(), &pk.serialize()); - - let ciphertext = encrypt(pk, MESSAGE).expect("Should encrypt."); - let plaintext = decrypt(sk, &ciphertext).expect("Should decyrpt."); - assert_eq!(MESSAGE, plaintext.as_slice()); - } - - #[test] - fn test_sign_verify() { - let secret_key = - SecretKey::parse_slice(DUMMY_SECRET_KEY).expect("to parse private key slice"); - - // sign the message using the secret key - let digest = sha256hash(MESSAGE); - let message = Message::parse_slice(&digest).expect("to parse message"); - let (signature, recid) = sign(&message, &secret_key); - - // recover verifying key (public key) from signature - let expected_public_key = PublicKey::from_secret_key(&secret_key); - let recovered_public_key = - recover(&message, &signature, &recid).expect("to recover public key"); - assert_eq!(expected_public_key, recovered_public_key); - - // verify the signature - let public_key = recovered_public_key; - assert!( - verify(&message, &signature, &public_key), - "could not verify signature" - ); - } -} diff --git a/utils/src/env.rs b/utils/src/env.rs deleted file mode 100644 index 9bbaba68..00000000 --- a/utils/src/env.rs +++ /dev/null @@ -1,25 +0,0 @@ -/// Reads an environment variable and trims whitespace and `"` from both ends. -/// If the trimmed value is empty, returns `None`. -#[inline] -pub fn safe_read_env(var: Result) -> Option { - var.map(|s| s.trim_matches('"').trim().to_string()) - .ok() - .filter(|s| !s.is_empty()) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_var_read() { - let var = Ok("\" value \"".to_string()); - assert_eq!(safe_read_env(var), Some("value".to_string())); - - let var = Ok("\" \"".to_string()); - assert!(safe_read_env(var).is_none()); - - let var = Err(std::env::VarError::NotPresent); - assert!(safe_read_env(var).is_none()); - } -} diff --git a/utils/src/lib.rs b/utils/src/lib.rs deleted file mode 100644 index 9a8cee2e..00000000 --- a/utils/src/lib.rs +++ /dev/null @@ -1,30 +0,0 @@ -/// Cryptography-related utilities. -#[cfg(feature = "crypto")] -pub mod crypto; - -/// Payload-related utilities. -/// Includes heartbeat, task and specs payloads and their request/response types. -pub mod payloads; - -mod env; -pub use env::safe_read_env; - -mod network; -pub use network::DriaNetwork; - -mod version; -pub use version::SemanticVersion; - -#[cfg(feature = "crypto")] -mod message; -#[cfg(feature = "crypto")] -pub use message::DriaMessage; - -// re-exports -pub use chrono; - -#[cfg(feature = "crypto")] -pub use libp2p_identity; - -#[cfg(feature = "crypto")] -pub use libsecp256k1; diff --git a/utils/src/message.rs b/utils/src/message.rs deleted file mode 100644 index adae0a08..00000000 --- a/utils/src/message.rs +++ /dev/null @@ -1,215 +0,0 @@ -use crate::crypto::sha256hash; - -use super::SemanticVersion; -use base64::{prelude::BASE64_STANDARD, Engine}; -use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use thiserror::Error; - -/// Message format for Dria network communication. -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct DriaMessage { - /// `base64` encoded message payload, can be decoded with [`Self::decode_payload`]. - /// - /// This payload is signed by the sender, and the public key can be recovered from the signature - /// using [`Self::recover_public_key`]. - pub payload: String, - // Topic identifier derived from TopicHash - pub topic: String, - // Semantic version of Dria Compute Node, of the form `X.Y.Z` - pub version: SemanticVersion, - // Protocol identifier, e.g. "dria" - pub protocol: String, - // Message timestamp in nanoseconds - pub timestamp: chrono::DateTime, - // 64-byte hex-encoded signature - pub signature: String, - // Signature recovery ID - pub recovery_id: u8, -} - -#[derive(Error, Debug)] -pub enum DriaMessageError { - #[error("Could not decode payload: {0}")] - DecodeError(base64::DecodeError), - #[error("Could not parse message: {0}")] - ParseError(serde_json::Error), - #[error("Protocol mismatch (expected {expected:?}, got {found:?})")] - ProtocolMismatch { expected: String, found: String }, - #[error("Version mismatch (expected {expected:?}, got {found:?})")] - VersionMismatch { - expected: SemanticVersion, - found: SemanticVersion, - }, - #[error("Invalid signature ({0})")] - InvalidSignature(libsecp256k1::Error), -} - -impl DriaMessage { - /// Creates a new Dria message. - /// - /// - `data` is converted to a bytes reference, and encoded into base64 to make up the `payload` within. - /// - `topic` is the name of the [gossipsub topic](https://docs.libp2p.io/concepts/pubsub/overview/). - /// - `protocol` is the protocol name, e.g. `dria`. - /// - `signing_key` is the secret key to sign the message. - pub fn new_signed( - data: impl AsRef<[u8]>, - topic: impl ToString, - protocol: String, - signing_key: &libsecp256k1::SecretKey, - version: SemanticVersion, - ) -> Self { - // base64 encode the data to obtain payload - let payload = BASE64_STANDARD.encode(data); - - // sign the SHA256 hash of the payload - let (signature, recovery_id) = libsecp256k1::sign( - &libsecp256k1::Message::parse(&sha256hash(&payload)), - signing_key, - ); - - Self { - payload, - topic: topic.to_string(), - protocol, - timestamp: chrono::Utc::now(), - version, - signature: hex::encode(signature.serialize()), - recovery_id: recovery_id.serialize(), - } - } - - /// Parses a slice of bytes into a `DriaMessage`, and checks for protocol & network matches. - pub fn from_slice_checked( - data: &[u8], - protocol: String, - version: SemanticVersion, - ) -> Result { - let message: DriaMessage = - serde_json::from_slice(data).map_err(DriaMessageError::ParseError)?; - - // ensure that protocol names match - if protocol != message.protocol { - Err(DriaMessageError::ProtocolMismatch { - expected: protocol, - found: message.protocol, - }) - } else - // ensure versions are compatible - if !version.is_compatible(&message.version) { - Err(DriaMessageError::VersionMismatch { - expected: version, - found: message.version, - }) - } else { - Ok(message) - } - } - - /// Decodes the base64 payload into bytes. - #[inline(always)] - pub fn decode_payload(&self) -> Result, DriaMessageError> { - BASE64_STANDARD - .decode(&self.payload) - .map_err(DriaMessageError::DecodeError) - } - - /// Decodes with [`Self::decode_payload`] and parses the decoded payload into JSON for the provided type `T`. - #[inline(always)] - pub fn parse_payload(&self) -> Result { - let decoded = self.decode_payload()?; - serde_json::from_slice::(&decoded).map_err(DriaMessageError::ParseError) - } - - /// Recovers the signature from the message payload. - /// - /// This may be costly to do in a hot loop. - #[inline(always)] - pub fn recover_public_key(&self) -> Result { - let message = libsecp256k1::Message::parse(&sha256hash(&self.payload)); - - // parse the signature and recovery ID - let signature = - libsecp256k1::Signature::parse_standard_slice(&hex::decode(&self.signature).unwrap()) - .map_err(DriaMessageError::InvalidSignature)?; - let recovery_id = libsecp256k1::RecoveryId::parse(self.recovery_id) - .map_err(DriaMessageError::InvalidSignature)?; - - // recover the public key from the signature - libsecp256k1::recover(&message, &signature, &recovery_id) - .map_err(DriaMessageError::InvalidSignature) - } -} - -impl From<&DriaMessage> for Vec { - fn from(message: &DriaMessage) -> Self { - serde_json::to_vec(message).expect("should not fail") - } -} - -impl From for Vec { - fn from(message: DriaMessage) -> Self { - (&message).into() - } -} - -impl std::fmt::Display for DriaMessage { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let payload_decoded = self - .decode_payload() - .unwrap_or(self.payload.as_bytes().to_vec()); - - let payload_str = String::from_utf8_lossy(&payload_decoded); - write!( - f, - "{}/{} message at {}\n{}", - self.protocol, self.topic, self.timestamp, payload_str - ) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use ecies::SecretKey; - - #[derive(Serialize, Deserialize, PartialEq, Debug)] - struct TestStruct { - hello: String, - } - - const TOPIC: &str = "test"; - - #[test] - fn test_signed_message() { - const DUMMY_SECRET_KEY: &[u8; 32] = b"driadriadriadriadriadriadriadria"; - let sk = SecretKey::parse(DUMMY_SECRET_KEY).unwrap(); - - // create payload & message with signature & body - let body = TestStruct { - hello: "hi there baby!".to_string(), - }; - let body_str = serde_json::to_string(&body).unwrap(); - let message = DriaMessage::new_signed( - body_str, - TOPIC, - "test".into(), - &sk, - SemanticVersion::default(), - ); - - // decode message - let body = message - .parse_payload::() - .expect("Should decode"); - assert_eq!( - serde_json::to_string(&body).expect("Should stringify"), - "{\"hello\":\"hi there baby!\"}" - ); - assert_eq!(message.topic, TOPIC); - assert_eq!(message.version, SemanticVersion::default()); - assert!(message.timestamp != chrono::DateTime::::default()); - - let parsed_body = message.parse_payload().expect("Should decode"); - assert_eq!(body, parsed_body); - } -} diff --git a/utils/src/network.rs b/utils/src/network.rs deleted file mode 100644 index d93e7daf..00000000 --- a/utils/src/network.rs +++ /dev/null @@ -1,86 +0,0 @@ -use crate::SemanticVersion; - -/// Network type, either mainnet or testnet. -#[derive(Debug, Clone, Copy, PartialEq)] -pub enum DriaNetwork { - Mainnet, - Testnet, -} - -impl TryFrom<&str> for DriaNetwork { - type Error = (); - - /// Converts a string to a `DriaNetwork`, using the same name as in: - /// - /// - "mainnet" for `DriaNetwork::Mainnet` - /// - "testnet" for `DriaNetwork::Testnet` - fn try_from(s: &str) -> Result { - match s { - "mainnet" => Ok(DriaNetwork::Mainnet), - "testnet" => Ok(DriaNetwork::Testnet), - _ => Err(()), - } - } -} - -impl std::fmt::Display for DriaNetwork { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - DriaNetwork::Mainnet => write!(f, "mainnet"), - DriaNetwork::Testnet => write!(f, "testnet"), - } - } -} - -impl DriaNetwork { - /// Returns the protocol name for the given network, which can be used by - /// libp2p `identify` protocol. - pub fn protocol_name(&self) -> &str { - match self { - DriaNetwork::Mainnet => "dria", - DriaNetwork::Testnet => "dria-test", - } - } - - /// Returns the discovery URL for the given version, where the - /// major.minor version is appended to the URL as a path variable. - pub fn discovery_url(&self, version: &SemanticVersion) -> String { - let base_url = match self { - DriaNetwork::Mainnet => "https://mainnet.dkn.dria.co/discovery/v0/available-nodes", - DriaNetwork::Testnet => "https://testnet.dkn.dria.co/discovery/v0/available-nodes", - }; - - format!("{}/{}", base_url, version.as_major_minor()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_dria_network() { - let mainnet = DriaNetwork::Mainnet; - let testnet = DriaNetwork::Testnet; - let version = SemanticVersion { - major: 1, - minor: 0, - patch: 42, - }; - - assert_eq!(mainnet.to_string(), "mainnet"); - assert_eq!(testnet.to_string(), "testnet"); - - assert_eq!(mainnet.protocol_name(), "dria"); - assert_eq!(testnet.protocol_name(), "dria-test"); - - assert_eq!( - mainnet.discovery_url(&version), - "https://mainnet.dkn.dria.co/discovery/v0/available-nodes/1.0" - ); - assert_eq!( - testnet.discovery_url(&version), - "https://testnet.dkn.dria.co/discovery/v0/available-nodes/1.0" - ); - } -} diff --git a/utils/src/payloads/heartbeat.rs b/utils/src/payloads/heartbeat.rs deleted file mode 100644 index a3d2d01b..00000000 --- a/utils/src/payloads/heartbeat.rs +++ /dev/null @@ -1,36 +0,0 @@ -use serde::{Deserialize, Serialize}; -use uuid::Uuid; - -/// Topic used within [`crate::DriaMessage`] for heartbeat messages. -pub const HEARTBEAT_TOPIC: &str = "heartbeat"; - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct HeartbeatRequest { - /// A unique ID for the heartbeat request. - pub heartbeat_id: Uuid, - /// Deadline for the heartbeat request, in nanoseconds. - pub deadline: chrono::DateTime, - /// Number of "single" tasks in the channel. - pub pending_single: usize, - /// Number of tasks in the channel currently, `single` and `batch`. - pub pending_batch: usize, - /// Number of batchable tasks at once. - /// - /// If `pending_batch` is greater than this value, the node will not be able to process them - /// and will stall until the channel is free to do more. - pub batch_size: usize, -} - -/// The response is an object with UUID along with an ACK (acknowledgement). -/// -/// If for any reason the `error` is `Some`, the request is considered failed. -/// This may be when `deadline` is past the current time, or if the node is deeming itself unhealthy. -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct HeartbeatResponse { - /// UUID as given in the request. - pub heartbeat_id: Uuid, - /// An associated error with the response: - /// - `None` means that the heartbeat was acknowledged. - /// - `Some` means that the heartbeat was not acknowledged for the given reason. - pub error: Option, -} diff --git a/utils/src/payloads/mod.rs b/utils/src/payloads/mod.rs deleted file mode 100644 index ce514976..00000000 --- a/utils/src/payloads/mod.rs +++ /dev/null @@ -1,11 +0,0 @@ -mod tasks; -pub use tasks::{TaskError, TaskRequestPayload, TaskResponsePayload, TaskStats}; -pub use tasks::{TASK_REQUEST_TOPIC, TASK_RESULT_TOPIC}; - -mod heartbeat; -pub use heartbeat::HEARTBEAT_TOPIC; -pub use heartbeat::{HeartbeatRequest, HeartbeatResponse}; - -mod specs; -pub use specs::SPECS_TOPIC; -pub use specs::{SpecModelPerformance, Specs, SpecsRequest, SpecsResponse}; diff --git a/utils/src/payloads/specs.rs b/utils/src/payloads/specs.rs deleted file mode 100644 index 95a58e89..00000000 --- a/utils/src/payloads/specs.rs +++ /dev/null @@ -1,98 +0,0 @@ -use serde::{Deserialize, Serialize}; -use std::collections::HashMap; -use uuid::Uuid; - -/// Topic used within [`crate::DriaMessage`] for specs messages. -pub const SPECS_TOPIC: &str = "specs"; - -#[derive(Serialize, Deserialize)] -pub struct SpecsRequest { - /// UUID of the specs request, prevents replays. - pub specs_id: Uuid, - /// Node specs. - pub specs: Specs, - /// Address of the node, used by frontend etc. instead of peer id. - pub address: String, -} - -#[derive(Serialize, Deserialize)] -pub struct SpecsResponse { - /// UUID of the specs request, prevents replays. - pub specs_id: Uuid, -} - -/// The specs of a node, containing information about the hardware and software it runs on. -/// -/// Optional values are done so for backwards compatibility, as some fields were added later. -#[derive(Debug, Serialize, Deserialize)] -pub struct Specs { - /// Total memory in bytes - pub total_mem: u64, - /// Free memory in bytes - pub free_mem: u64, - /// Number of physical CPU cores. - pub num_cpus: Option, - /// Global CPU usage, in percentage. - pub cpu_usage: f32, - /// Operating system name, e.g. `linux`, `macos`, `windows`. - pub os: String, - /// CPU architecture, e.g. `x86_64`, `aarch64`. - pub arch: String, - /// Public IP lookup response. - pub lookup: Option, - /// Models server by this node. - pub models: Vec, - /// Model performance metrics, keyed by model name. - pub model_perf: HashMap, - /// Node version, e.g. `0.1.0`. - pub version: String, - /// Name of the execution platform, e.g. Docker file or Launcher. - #[serde(skip_serializing_if = "Option::is_none")] - pub exec_platform: Option, - /// Peer id of the node. - #[serde(skip_serializing_if = "Option::is_none")] - pub peer_id: Option, - // GPU adapter infos, showing information about the available GPUs. - // gpus: Vec, -} - -/// Performance metrics for a model, used in the specs. -/// -/// These are measured at the start of the compute node, and those that are not succesfull. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub enum SpecModelPerformance { - /// Evaluation tokens per second (TPS) for the model that has passed evaluation. - PassedWithTPS(f64), - /// Evaluation tokens per second (TPS) for the model that has failed evaluation. - FailedWithTPS(f64), - /// Model has timed-out during performance evaluation. - /// - /// This can happen if the model is slow to respond or the request takes too long. - Timeout, - /// Model is not found for performance evaluation. - /// - /// Possible reasons are API key not set, or model not available in the account. - NotFound, - /// Model has failed to execute during performance evaluation. - /// - /// This can happen if the model is not available, or the request fails for some reason. - /// One example is OpenRouter, where sometimes models are not available even if they are listed. - ExecutionFailed, - /// Model has passed execution performance evaluation, however TPS was not available. - Passed, -} - -impl std::fmt::Display for SpecModelPerformance { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - SpecModelPerformance::PassedWithTPS(tps) => write!(f, "Passed with TPS: {tps:.3}"), - SpecModelPerformance::FailedWithTPS(tps) => { - write!(f, "Failed with TPS: {tps:.3}") - } - SpecModelPerformance::Timeout => write!(f, "Timeout"), - SpecModelPerformance::NotFound => write!(f, "Not Found"), - SpecModelPerformance::ExecutionFailed => write!(f, "Execution Failed"), - SpecModelPerformance::Passed => write!(f, "Passed"), - } - } -} diff --git a/utils/src/payloads/tasks.rs b/utils/src/payloads/tasks.rs deleted file mode 100644 index add26e18..00000000 --- a/utils/src/payloads/tasks.rs +++ /dev/null @@ -1,151 +0,0 @@ -use serde::{Deserialize, Serialize}; -use uuid::Uuid; - -/// Topic used within [`crate::DriaMessage`] for task request messages. -pub const TASK_REQUEST_TOPIC: &str = "task"; - -/// Topic used within [`crate::DriaMessage`] for task result messages. -pub const TASK_RESULT_TOPIC: &str = "results"; - -/// A computation task is the task of computing a result from a given input. -/// -/// `result` and `error` are mutually-exclusive, only one of them can be `Some`: -/// - if `result` is `Some`, then it contains the result. -/// - if `error` is `Some`, then it contains the error message. -/// -/// Each task belongs to a file (uniquely identified by `file_id`), and has a unique identifier (`row_id`). -/// THe `task_id` is a custom identifier given by a user. -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct TaskResponsePayload { - /// The file that this task is associated with. - pub file_id: Uuid, - /// The unique identifier of the task. - pub row_id: Uuid, - /// The custom identifier of the task, not necessarily unique. - pub task_id: String, - /// Name of the model used for this task. - pub model: String, - /// Stats about the task execution. - pub stats: TaskStats, - /// Result from the LLM, as-is. - /// - /// If this is `None`, the task failed, and you should check the `error` field. - #[serde(skip_serializing_if = "Option::is_none")] - pub result: Option, - /// An error, if any. - /// - /// If this is `Some`, you can ignore the `result` field. - #[serde(skip_serializing_if = "Option::is_none")] - pub error: Option, -} - -/// A generic task request, given by Dria. -/// -/// Each task belongs to a file (uniquely identified by `file_id`), and has a unique identifier (`row_id`). -/// THe `task_id` is a custom identifier given by a user. -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct TaskRequestPayload { - /// The file that this task is associated with. - pub file_id: Uuid, - /// The unique identifier of the task. - pub row_id: Uuid, - /// The custom identifier of the task, not necessarily unique. - pub task_id: String, - /// The input to the compute function. - pub input: T, -} - -#[derive(Debug, Clone, Serialize, Deserialize, thiserror::Error)] -pub enum TaskError { - /// A parse error occurred while parsing the task request or response. - #[error("Parse error: {0}")] - ParseError(String), - /// An error returned from the model provider. - #[error("{provider} error ({code}): {message}")] - ProviderError { - /// Not necessarily an HTTP status code, but a code that the provider uses to identify the error. - /// - /// For example, OpenAI uses a string code like "invalid_request_error". - code: String, - /// The error message returned by the provider. - /// - /// May contain additional information about the error. - message: String, - /// The source of the error. - /// - /// Can be a provider name, or RPC etc. - provider: String, - }, - /// This is a generic HTTP error, not necessarily related to the provider. - #[error("HTTP error: {0}")] - HttpError(String), - /// Any other executor error that is not a provider error. - #[error("Executor error: {0}")] - ExecutorError(String), - /// The task request had failed for some network reason. - #[error("Outbound request error: {code} - {message}")] - OutboundRequestError { - code: String, - /// The error message returned by the network. - message: String, - }, - /// Any other error - #[error("Other error: {0}")] - Other(String), -} - -/// Task stats for diagnostics. -/// -/// Returning this as the payload helps to debug the errors received at client side, and latencies. -#[derive(Default, Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct TaskStats { - /// Timestamp at which the task was received from network & parsed. - pub received_at: chrono::DateTime, - /// Timestamp at which the task was published back to network. - pub published_at: chrono::DateTime, - /// Timestamp at which the task execution had started. - pub execution_started_at: chrono::DateTime, - /// Timestamp at which the task execution had finished. - pub execution_ended_at: chrono::DateTime, - /// Number of tokens of the result. - pub token_count: usize, -} - -impl TaskStats { - pub fn new() -> Self { - Self::default() - } - - /// Records the current timestamp within `received_at`. - pub fn record_received_at(mut self) -> Self { - self.received_at = chrono::Utc::now(); - self - } - - /// Records the current timestamp within `published_at`. - pub fn record_published_at(mut self) -> Self { - self.published_at = chrono::Utc::now(); - self - } - - /// Records the execution start time within `execution_started_at`. - pub fn record_execution_started_at(mut self) -> Self { - self.execution_started_at = chrono::Utc::now(); - self - } - - /// Records the execution end time within `execution_ended_time`. - pub fn record_execution_ended_at(mut self) -> Self { - self.execution_ended_at = chrono::Utc::now(); - self - } - - /// Records the token count within `token_count`. - pub fn record_token_count(mut self, token_count: usize) -> Self { - self.token_count = token_count; - self - } -} diff --git a/utils/src/version.rs b/utils/src/version.rs deleted file mode 100644 index 107ce6d0..00000000 --- a/utils/src/version.rs +++ /dev/null @@ -1,95 +0,0 @@ -use std::str::FromStr; - -/// A tiny utility for semantic versioning. -/// This is a simple struct that holds the major, minor, and patch version numbers. -/// -/// Implements a Display trait that serializes to `{major}.{minor}.{patch}`. -#[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone, PartialEq, Eq, Copy)] -pub struct SemanticVersion { - /// Major version number. - pub major: u32, - /// Minor version number. - pub minor: u32, - /// Patch version number. - pub patch: u32, -} - -impl FromStr for SemanticVersion { - type Err = String; - - fn from_str(version: &str) -> Result { - let parts: Vec = version.split('.').filter_map(|s| s.parse().ok()).collect(); - - if parts.len() != 3 { - Err("Invalid version format".to_string()) - } else { - Ok(SemanticVersion { - major: parts[0], - minor: parts[1], - patch: parts[2], - }) - } - } -} - -impl SemanticVersion { - /// Checks if the current version is compatible with the given version. - /// Compatibility is defined as: - /// - Major and minor versions must match exactly. - /// - Patch versions dont have to match. - pub fn is_compatible(&self, other: &Self) -> bool { - self.major == other.major && self.minor == other.minor - } - - pub fn with_major(mut self, major: u32) -> Self { - self.major = major; - self - } - - pub fn with_minor(mut self, minor: u32) -> Self { - self.minor = minor; - self - } - - pub fn with_patch(mut self, patch: u32) -> Self { - self.patch = patch; - self - } - - /// Returns a string representation of the version in the format `{major}.{minor}`. - #[inline] - pub fn as_major_minor(&self) -> String { - format!("{}.{}", self.major, self.minor) - } - - /// Parses the Crate version field into `SemanticVersion`. - /// - /// Will panic if for any reason the version format is wrong. - #[inline] - pub fn from_crate_version() -> Self { - env!("CARGO_PKG_VERSION").parse().unwrap() - } -} - -impl std::fmt::Display for SemanticVersion { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}.{}.{}", self.major, self.minor, self.patch) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_compatible() { - let version1 = SemanticVersion::from_str("1.2.3").unwrap(); - let version2 = SemanticVersion::from_str("1.2.4").unwrap(); - let version3 = SemanticVersion::from_str("1.3.0").unwrap(); - let version4 = SemanticVersion::from_str("2.0.0").unwrap(); - - assert!(version1.is_compatible(&version2)); - assert!(!version1.is_compatible(&version3)); - assert!(!version1.is_compatible(&version4)); - } -} From e2c484ee23317e283f174803459a852a0611161f Mon Sep 17 00:00:00 2001 From: andthattoo Date: Mon, 2 Mar 2026 14:36:28 +0300 Subject: [PATCH 02/57] Extract dkn-protocol crate, add multi-model worker, reconnect, and challenge-response - Add dkn-protocol path dependency; re-export wire types via thin modules - Remove local copies of protocol types, proof, and chat template (now in dkn-protocol) - Worker holds per-model engines/templates; TPS is HashMap - Implement challenge signing in the main event loop - Add try_reconnect() with exponential backoff on stream close/error - Introduce NodeContext for shared state across event handlers --- Cargo.lock | 12 ++ Cargo.toml | 1 + src/config.rs | 47 ++++- src/error.rs | 6 + src/identity.rs | 2 + src/inference/engine.rs | 4 +- src/inference/mod.rs | 2 - src/inference/proof.rs | 65 ------- src/main.rs | 350 ++++++++++++++++++++++++++--------- src/models/cache.rs | 1 + src/models/mod.rs | 5 +- src/models/registry.rs | 31 ++++ src/models/template.rs | 163 ---------------- src/network/auth.rs | 107 +---------- src/network/connection.rs | 377 +++++++++++++++++++++++++++++++++----- src/network/protocol.rs | 333 +-------------------------------- src/stats.rs | 118 ++++++++++++ src/worker.rs | 78 ++++++-- 18 files changed, 878 insertions(+), 824 deletions(-) delete mode 100644 src/inference/proof.rs delete mode 100644 src/models/template.rs create mode 100644 src/stats.rs diff --git a/Cargo.lock b/Cargo.lock index 00578aa5..c079ac4d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -441,6 +441,17 @@ dependencies = [ "syn", ] +[[package]] +name = "dkn-protocol" +version = "0.1.0" +dependencies = [ + "quinn", + "rmp-serde", + "serde", + "thiserror 2.0.12", + "uuid", +] + [[package]] name = "dria-node" version = "2.0.0-alpha.1" @@ -449,6 +460,7 @@ dependencies = [ "bytes", "clap", "dirs", + "dkn-protocol", "encoding_rs", "futures", "hex", diff --git a/Cargo.toml b/Cargo.toml index c2d69768..718b4000 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,6 +34,7 @@ anyhow = "1" uuid = { version = "1", features = ["v7", "serde"] } dirs = "6" encoding_rs = "0.8" +dkn-protocol = { path = "../dkn-protocol" } quinn = "0.11" rustls = { version = "0.23", default-features = false, features = ["ring", "std"] } rustls-native-certs = "0.8" diff --git a/src/config.rs b/src/config.rs index 69b2aec5..285b2ae3 100644 --- a/src/config.rs +++ b/src/config.rs @@ -49,7 +49,7 @@ pub enum Command { pub struct Config { pub secret_key_hex: String, pub model_names: Vec, - pub router_url: String, + pub router_urls: Vec, pub gpu_layers: i32, pub max_concurrent: usize, pub data_dir: PathBuf, @@ -102,10 +102,20 @@ impl Config { return Err(NodeError::Config("max-concurrent must be >= 1".into())); } + // Parse router URLs (comma-separated) + let router_urls: Vec = router_url + .split(',') + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) + .collect(); + if router_urls.is_empty() { + return Err(NodeError::Config("at least one router URL must be specified".into())); + } + Ok(Config { secret_key_hex, model_names, - router_url, + router_urls, gpu_layers, max_concurrent, data_dir, @@ -138,6 +148,7 @@ mod tests { "6472696164726961647269616472696164726961647269616472696164726961" ); assert_eq!(cfg.models_dir, PathBuf::from("/tmp/dria-test/models")); + assert_eq!(cfg.router_urls, vec!["https://router.dria.co"]); } #[test] @@ -196,6 +207,38 @@ mod tests { assert!(result.is_err()); } + #[test] + fn test_config_comma_separated_router_urls() { + let cfg = Config::from_start_args( + "6472696164726961647269616472696164726961647269616472696164726961".into(), + "gemma3:4b".into(), + "https://router1.dria.co, https://router2.dria.co".into(), + 0, + 1, + None, + false, + ) + .unwrap(); + assert_eq!( + cfg.router_urls, + vec!["https://router1.dria.co", "https://router2.dria.co"] + ); + } + + #[test] + fn test_config_empty_router_url() { + let result = Config::from_start_args( + "6472696164726961647269616472696164726961647269616472696164726961".into(), + "gemma3:4b".into(), + "".into(), + 0, + 1, + None, + false, + ); + assert!(result.is_err()); + } + #[test] fn test_config_insecure_flag() { let cfg = Config::from_start_args( diff --git a/src/error.rs b/src/error.rs index 652e8dee..1287b72e 100644 --- a/src/error.rs +++ b/src/error.rs @@ -20,3 +20,9 @@ pub enum NodeError { #[error("io error: {0}")] Io(#[from] std::io::Error), } + +impl From for NodeError { + fn from(e: dkn_protocol::ProtocolError) -> Self { + NodeError::Network(e.to_string()) + } +} diff --git a/src/identity.rs b/src/identity.rs index 86f39412..2b755369 100644 --- a/src/identity.rs +++ b/src/identity.rs @@ -8,7 +8,9 @@ use crate::error::NodeError; /// The address is an Ethereum-style address (last 20 bytes of keccak256 of uncompressed pubkey). pub struct Identity { pub secret_key: SecretKey, + #[allow(dead_code)] pub public_key: PublicKey, + #[allow(dead_code)] pub address: [u8; 20], pub address_hex: String, } diff --git a/src/inference/engine.rs b/src/inference/engine.rs index b66f5f88..b65363a8 100644 --- a/src/inference/engine.rs +++ b/src/inference/engine.rs @@ -12,7 +12,7 @@ use llama_cpp_2::token::LlamaToken; use crate::error::NodeError; use crate::identity::sha256hash; -use crate::inference::proof::{InferenceProof, TokenLogprob}; +use dkn_protocol::{InferenceProof, TokenLogprob}; use crate::inference::stream::StreamToken; /// Parameters controlling text generation. @@ -60,6 +60,7 @@ pub struct InferenceResult { pub struct InferenceEngine { backend: LlamaBackend, model: LlamaModel, + #[allow(dead_code)] gpu_layers: i32, } @@ -95,6 +96,7 @@ impl InferenceEngine { } /// Return the number of GPU layers configured. + #[allow(dead_code)] pub fn gpu_layers(&self) -> i32 { self.gpu_layers } diff --git a/src/inference/mod.rs b/src/inference/mod.rs index 8c077dc3..7e9506cf 100644 --- a/src/inference/mod.rs +++ b/src/inference/mod.rs @@ -1,7 +1,5 @@ pub mod benchmark; pub mod engine; -pub mod proof; pub mod stream; pub use engine::{GenerateParams, InferenceEngine, InferenceResult}; -pub use proof::InferenceProof; diff --git a/src/inference/proof.rs b/src/inference/proof.rs deleted file mode 100644 index 80fd2bc7..00000000 --- a/src/inference/proof.rs +++ /dev/null @@ -1,65 +0,0 @@ -use serde::{Deserialize, Serialize}; - -/// Log-probability information for a single token position. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct TokenLogprob { - /// Position in the generated sequence. - pub position: usize, - /// The token ID chosen at this position. - pub token_id: u32, - /// The decoded text of the chosen token. - pub token_text: String, - /// The log-probability of the chosen token. - pub logprob: f32, - /// Top-k alternatives: (token_text, logprob). - pub top_k: Vec<(String, f32)>, -} - -/// Proof-of-inference data for validation. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct InferenceProof { - /// Log-probabilities at requested positions. - pub logprobs: Vec, - /// Optional KV-cache hash for determinism verification. - /// Placeholder: currently hashes logits at probed position. - pub kv_cache_hash: Option<[u8; 32]>, -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_token_logprob_serde() { - let lp = TokenLogprob { - position: 5, - token_id: 1234, - token_text: "the".into(), - logprob: -0.5, - top_k: vec![("the".into(), -0.5), ("a".into(), -1.2)], - }; - let json = serde_json::to_string(&lp).unwrap(); - let roundtrip: TokenLogprob = serde_json::from_str(&json).unwrap(); - assert_eq!(roundtrip.position, 5); - assert_eq!(roundtrip.token_id, 1234); - assert_eq!(roundtrip.top_k.len(), 2); - } - - #[test] - fn test_inference_proof_serde() { - let proof = InferenceProof { - logprobs: vec![TokenLogprob { - position: 0, - token_id: 1, - token_text: "hello".into(), - logprob: -0.1, - top_k: vec![], - }], - kv_cache_hash: Some([0xAB; 32]), - }; - let packed = rmp_serde::to_vec(&proof).unwrap(); - let roundtrip: InferenceProof = rmp_serde::from_slice(&packed).unwrap(); - assert_eq!(roundtrip.logprobs.len(), 1); - assert!(roundtrip.kv_cache_hash.is_some()); - } -} diff --git a/src/main.rs b/src/main.rs index 641ebe28..7c792871 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,24 +1,27 @@ -// Suppress dead-code warnings for public APIs not yet wired to networking. -#![allow(dead_code)] - mod config; mod error; mod identity; mod inference; mod models; mod network; +mod stats; mod worker; +use std::collections::HashMap; +use std::sync::Arc; use std::time::Duration; use clap::Parser; +use tokio::sync::mpsc; use tracing_subscriber::EnvFilter; use config::{Cli, Command, Config}; use identity::Identity; use models::{ModelCache, ModelDownloader, default_registry, resolve_model}; +use models::registry::ModelSpec; use network::{NodeMessage, RouterMessage}; use network::RouterConnection; +use stats::NodeStats; use worker::{CompletedTask, Worker}; #[tokio::main] @@ -49,6 +52,22 @@ async fn main() -> anyhow::Result<()> { Ok(()) } +/// Shared state needed by event handlers for reconnection and challenge-response. +struct NodeContext { + identity: Identity, + config: Config, + tps: HashMap, + stats: Arc, + cache: ModelCache, +} + +/// Result of a background model download + load operation. +struct ModelLoadResult { + name: String, + template: String, + result: Result<(inference::InferenceEngine, f64), error::NodeError>, +} + async fn run_start( wallet: String, model: String, @@ -73,88 +92,61 @@ async fn run_start( let registry = default_registry(); let cache = ModelCache::new(config.models_dir.clone())?; - // We need to keep one engine alive for inference; use the first model. - let mut chat_template = "chatml".to_string(); - let mut engine_and_tps: Option<(inference::InferenceEngine, f64)> = None; + // Accumulate engines and TPS per model + let mut engines: HashMap = HashMap::new(); + let mut tps_map: HashMap = HashMap::new(); for model_name in &config.model_names { let spec = resolve_model(model_name, ®istry) .ok_or_else(|| error::NodeError::Model(format!("unknown model: {model_name}")))?; - // Check local cache first - let model_path = if let Some(path) = cache.get_local_path(&spec) { - tracing::info!(model = %model_name, path = %path.display(), "model found in cache"); - path - } else { - // Download from HuggingFace - let hf_path = ModelDownloader::download(&spec).await?; - - // Verify SHA-256 if specified - if let Some(ref expected_sha) = spec.sha256 { - tracing::info!(model = %model_name, "verifying SHA-256"); - if !ModelCache::verify_sha256(&hf_path, expected_sha)? { - anyhow::bail!("SHA-256 mismatch for model {model_name}"); - } - } - - // Link into our cache - cache.link_model(&spec, &hf_path)? - }; - - // Remember chat template from the spec - if let Some(ref tmpl) = spec.chat_template { - chat_template = tmpl.clone(); - } - - // Load model and run benchmark in blocking thread - let model_name_owned = model_name.clone(); - let gpu = config.gpu_layers; - let (engine, tps) = tokio::task::spawn_blocking(move || { - let engine = inference::InferenceEngine::load(&model_path, gpu)?; - let tps_result = engine.benchmark(&model_name_owned)?; - Ok::<_, error::NodeError>((engine, tps_result.generation_tps)) - }) - .await??; + let (engine, tps) = download_and_load_model(&spec, &cache, config.gpu_layers).await?; + let chat_template = spec + .chat_template + .clone() + .unwrap_or_else(|| "chatml".to_string()); tracing::info!(tps = %format!("{tps:.1}"), model = %model_name, "benchmark complete"); - engine_and_tps = Some((engine, tps)); + engines.insert(model_name.clone(), (engine, chat_template)); + tps_map.insert(model_name.clone(), tps); } - let (engine, tps) = engine_and_tps.ok_or_else(|| { - error::NodeError::Config("no models loaded".into()) - })?; + if engines.is_empty() { + return Err(error::NodeError::Config("no models loaded".into()).into()); + } // Build the worker - let mut worker = Worker::new( - engine, - chat_template, - config.model_names.clone(), - config.max_concurrent, - ); - - // Attempt router connection; go offline if unavailable - let mut connection: Option = match RouterConnection::connect( - &config.router_url, - config.insecure, - &identity, - config.model_names.clone(), - tps, - worker.capacity(), - ) - .await - { - Ok(conn) => { - tracing::info!(node_id = %conn.node_id, "connected to router"); - Some(conn) - } - Err(e) => { - tracing::warn!(%e, "failed to connect to router, running in offline mode"); - None + let mut worker = Worker::new(engines, config.max_concurrent); + + // Attempt router connection; try each URL, go offline if all unavailable + let mut connection: Option = None; + for url in &config.router_urls { + match RouterConnection::connect( + url, + config.insecure, + &identity, + config.model_names.clone(), + tps_map.clone(), + worker.capacity(), + ) + .await + { + Ok(conn) => { + tracing::info!(node_id = %conn.node_id, router = %url, "connected to router"); + connection = Some(conn); + break; + } + Err(e) => { + tracing::warn!(%e, router = %url, "failed to connect to router"); + } } - }; + } + if connection.is_none() { + tracing::warn!("all routers unavailable, running in offline mode"); + } tracing::info!( - router = %config.router_url, + routers = ?config.router_urls, models = ?config.model_names, max_concurrent = config.max_concurrent, insecure = config.insecure, @@ -162,38 +154,71 @@ async fn run_start( "node ready" ); + // Build shared context for event handlers + let stats = Arc::new(NodeStats::new()); + let mut ctx = NodeContext { + identity, + config, + tps: tps_map, + stats: Arc::clone(&stats), + cache, + }; + + // Channel for background model load results + let (model_tx, mut model_rx) = mpsc::unbounded_channel::(); + // Main event loop + let mut stats_interval = tokio::time::interval(Duration::from_secs(60)); + stats_interval.tick().await; // consume the immediate first tick loop { let event = tokio::select! { msg = recv_router_msg(&mut connection) => Event::RouterMsg(msg), Some(done) = worker.next_completed() => Event::TaskDone(done), + Some(loaded) = model_rx.recv() => Event::ModelLoaded(loaded), + _ = stats_interval.tick() => Event::StatsLog, _ = tokio::signal::ctrl_c() => Event::Shutdown, }; match event { Event::RouterMsg(Ok(Some(msg))) => { - handle_router_message(msg, &mut worker, &mut connection).await; + handle_router_message(msg, &mut worker, &mut connection, &mut ctx, &model_tx).await; } Event::RouterMsg(Ok(None)) => { // Stream closed cleanly - tracing::warn!("router stream closed, switching to offline mode"); + tracing::warn!("router stream closed, attempting reconnect"); if let Some(ref conn) = connection { conn.close(); } - connection = None; + connection = try_reconnect(&ctx, worker.capacity()).await; } Event::RouterMsg(Err(e)) => { - tracing::warn!(%e, "router communication error"); + tracing::warn!(%e, "router communication error, attempting reconnect"); if let Some(ref conn) = connection { conn.close(); } - connection = None; - - // Attempt reconnect - tracing::info!("will attempt reconnect on next cycle"); + connection = try_reconnect(&ctx, worker.capacity()).await; } Event::TaskDone(completed) => { - handle_completed_task(completed, &mut connection).await; + handle_completed_task(completed, &mut connection, &ctx.stats).await; + } + Event::ModelLoaded(loaded) => { + match loaded.result { + Ok((engine, tps)) => { + tracing::info!( + model = %loaded.name, + tps = %format!("{tps:.1}"), + "model loaded successfully" + ); + worker.add_engine(loaded.name.clone(), engine, loaded.template); + ctx.tps.insert(loaded.name, tps); + } + Err(e) => { + tracing::error!(model = %loaded.name, %e, "failed to load model"); + } + } + } + Event::StatsLog => { + ctx.stats.log_summary(); } Event::Shutdown => { tracing::info!("shutdown signal received"); @@ -210,7 +235,7 @@ async fn run_start( loop { tokio::select! { Some(completed) = worker.next_completed() => { - handle_completed_task(completed, &mut connection).await; + handle_completed_task(completed, &mut connection, &ctx.stats).await; } _ = tokio::time::sleep_until(drain_deadline) => { tracing::warn!("drain timeout reached, dropping remaining tasks"); @@ -238,6 +263,8 @@ async fn run_start( enum Event { RouterMsg(Result, error::NodeError>), TaskDone(CompletedTask), + ModelLoaded(ModelLoadResult), + StatsLog, Shutdown, } @@ -259,11 +286,56 @@ async fn recv_router_msg( } } -/// Handle a router message: dispatch tasks, respond to pings, etc. +/// Attempt to reconnect to the router with exponential backoff. +/// +/// Tries up to 5 rounds, iterating all router URLs per round (1s → 2s → 4s → 8s → 16s), +/// then gives up and returns None so the main loop can fall back to the offline sleep-and-retry cycle. +async fn try_reconnect( + ctx: &NodeContext, + capacity: network::protocol::Capacity, +) -> Option { + let mut delay = Duration::from_secs(1); + let max_rounds = 5; + + for round in 1..=max_rounds { + tracing::info!(round, delay_secs = delay.as_secs(), "attempting reconnect"); + tokio::time::sleep(delay).await; + + for url in &ctx.config.router_urls { + match RouterConnection::connect( + url, + ctx.config.insecure, + &ctx.identity, + ctx.config.model_names.clone(), + ctx.tps.clone(), + capacity.clone(), + ) + .await + { + Ok(conn) => { + tracing::info!(node_id = %conn.node_id, router = %url, "reconnected to router"); + return Some(conn); + } + Err(e) => { + tracing::warn!(%e, router = %url, round, "reconnect attempt failed"); + } + } + } + + delay *= 2; + } + + tracing::warn!("all reconnect attempts exhausted, running in offline mode"); + None +} + +/// Handle a router message: dispatch tasks, respond to pings, sign challenges, etc. async fn handle_router_message( msg: RouterMessage, worker: &mut Worker, connection: &mut Option, + ctx: &mut NodeContext, + model_tx: &mpsc::UnboundedSender, ) { match msg { RouterMessage::TaskAssignment { @@ -281,6 +353,7 @@ async fn handle_router_message( tracing::debug!(%task_id, "task accepted"); } Err(reason) => { + ctx.stats.record_rejected(); tracing::warn!(%task_id, ?reason, "task rejected"); if let Some(ref mut conn) = connection { let reject = NodeMessage::TaskRejected { task_id, reason }; @@ -295,9 +368,10 @@ async fn handle_router_message( tracing::debug!("received ping"); if let Some(ref mut conn) = connection { let status = NodeMessage::StatusUpdate { - models: worker.model_names().to_vec(), + models: worker.model_names(), capacity: worker.capacity(), version: env!("CARGO_PKG_VERSION").to_string(), + stats: Some(ctx.stats.snapshot()), }; if let Err(e) = conn.send(&status).await { tracing::error!(%e, "failed to send status update"); @@ -305,26 +379,123 @@ async fn handle_router_message( } } RouterMessage::Challenge { challenge } => { - tracing::debug!(?challenge, "received challenge (not yet implemented)"); - // TODO: implement challenge-response + tracing::debug!("received challenge, signing response"); + let (sig, recid) = ctx.identity.sign(&challenge); + if let Some(ref mut conn) = connection { + let response = NodeMessage::ChallengeResponse { + challenge, + signature: sig.serialize().to_vec(), + recovery_id: recid.serialize(), + }; + if let Err(e) = conn.send(&response).await { + tracing::error!(%e, "failed to send challenge response"); + } + } } RouterMessage::ModelRegistryUpdate { entries } => { - tracing::info!(count = entries.len(), "received model registry update (not yet implemented)"); - // TODO: handle model registry updates + tracing::info!(count = entries.len(), "received model registry update"); + + // Compute desired set from entries + let desired: HashMap = entries + .iter() + .map(|e| (e.name.clone(), e)) + .collect(); + + // Remove models not in the desired set + let current = worker.model_names(); + for name in ¤t { + if !desired.contains_key(name) { + tracing::info!(model = %name, "removing model (not in registry)"); + worker.remove_engine(name); + ctx.tps.remove(name); + } + } + + // Spawn background download+load for new models + for entry in &entries { + if !worker.has_model(&entry.name) { + let spec = ModelSpec::from_registry_entry(entry); + let cache = ctx.cache.clone(); + let gpu_layers = ctx.config.gpu_layers; + let tx = model_tx.clone(); + let name = entry.name.clone(); + let template = entry + .chat_template + .clone() + .unwrap_or_else(|| "chatml".to_string()); + + tracing::info!(model = %name, "spawning background model download+load"); + tokio::spawn(async move { + let result = download_and_load_model(&spec, &cache, gpu_layers).await; + let _ = tx.send(ModelLoadResult { name, template, result }); + }); + } + } } } } +/// Download (if needed), verify, cache, load, and benchmark a model. +/// +/// Returns the loaded engine and its benchmark TPS. +async fn download_and_load_model( + spec: &ModelSpec, + cache: &ModelCache, + gpu_layers: i32, +) -> Result<(inference::InferenceEngine, f64), error::NodeError> { + let model_name = spec.name.clone(); + + // Check local cache first + let model_path = if let Some(path) = cache.get_local_path(spec) { + tracing::info!(model = %model_name, path = %path.display(), "model found in cache"); + path + } else { + // Download from HuggingFace + let hf_path = ModelDownloader::download(spec).await?; + + // Verify SHA-256 if specified + if let Some(ref expected_sha) = spec.sha256 { + tracing::info!(model = %model_name, "verifying SHA-256"); + if !ModelCache::verify_sha256(&hf_path, expected_sha)? { + return Err(error::NodeError::Model(format!( + "SHA-256 mismatch for model {model_name}" + ))); + } + } + + // Link into our cache + cache.link_model(spec, &hf_path)? + }; + + // Load model and run benchmark in blocking thread + let (engine, tps) = tokio::task::spawn_blocking(move || { + let engine = inference::InferenceEngine::load(&model_path, gpu_layers)?; + let tps_result = engine.benchmark(&model_name)?; + Ok::<_, error::NodeError>((engine, tps_result.generation_tps)) + }) + .await + .map_err(|e| error::NodeError::Inference(format!("task join error: {e}")))? + ?; + + Ok((engine, tps)) +} + /// Handle a completed inference task: send result or log if offline. async fn handle_completed_task( completed: CompletedTask, connection: &mut Option, + stats: &NodeStats, ) { match completed.result { - Ok(msg) => { + Ok(ref msg) => { + let tokens = match msg { + NodeMessage::TaskResult { stats: ts, .. } => ts.tokens_generated, + _ => 0, + }; + stats.record_completed(tokens); tracing::info!(task_id = %completed.task_id, "task completed"); if let Some(ref mut conn) = connection { - if let Err(e) = conn.send(&msg).await { + if let Err(e) = conn.send(msg).await { tracing::error!(%e, task_id = %completed.task_id, "failed to send result"); } } else { @@ -332,6 +503,7 @@ async fn handle_completed_task( } } Err(e) => { + stats.record_failed(); tracing::error!(%e, task_id = %completed.task_id, "task failed"); } } diff --git a/src/models/cache.rs b/src/models/cache.rs index 590e7435..6b599386 100644 --- a/src/models/cache.rs +++ b/src/models/cache.rs @@ -6,6 +6,7 @@ use crate::error::NodeError; use crate::models::registry::ModelSpec; /// Manages local model file cache. +#[derive(Clone)] pub struct ModelCache { pub cache_dir: PathBuf, } diff --git a/src/models/mod.rs b/src/models/mod.rs index 5d814931..fd5e9529 100644 --- a/src/models/mod.rs +++ b/src/models/mod.rs @@ -1,7 +1,10 @@ pub mod cache; pub mod download; pub mod registry; -pub mod template; + +pub mod template { + pub use dkn_protocol::{apply_chat_template, ChatMessage}; +} pub use cache::ModelCache; pub use download::ModelDownloader; diff --git a/src/models/registry.rs b/src/models/registry.rs index 41dce7c7..4eec6ba3 100644 --- a/src/models/registry.rs +++ b/src/models/registry.rs @@ -1,5 +1,7 @@ use std::collections::HashMap; +use dkn_protocol::ModelRegistryEntry; + /// Specification for a model: shortname mapped to HuggingFace GGUF location. #[derive(Debug, Clone)] pub struct ModelSpec { @@ -86,6 +88,19 @@ pub fn default_registry() -> HashMap { entries.into_iter().map(|s| (s.name.clone(), s)).collect() } +impl ModelSpec { + /// Create a ModelSpec from a router-provided registry entry. + pub fn from_registry_entry(entry: &ModelRegistryEntry) -> Self { + ModelSpec { + name: entry.name.clone(), + hf_repo: entry.hf_repo.clone(), + hf_file: entry.hf_file.clone(), + sha256: None, + chat_template: entry.chat_template.clone(), + } + } +} + /// Resolve a user-provided model name to a ModelSpec from the registry. pub fn resolve_model(name: &str, registry: &HashMap) -> Option { registry.get(name).cloned() @@ -129,4 +144,20 @@ mod tests { let reg = default_registry(); assert!(resolve_model("nonexistent:1b", ®).is_none()); } + + #[test] + fn test_from_registry_entry() { + let entry = ModelRegistryEntry { + name: "test:1b".into(), + hf_repo: "test/repo".into(), + hf_file: "model.gguf".into(), + chat_template: Some("chatml".into()), + }; + let spec = ModelSpec::from_registry_entry(&entry); + assert_eq!(spec.name, "test:1b"); + assert_eq!(spec.hf_repo, "test/repo"); + assert_eq!(spec.hf_file, "model.gguf"); + assert!(spec.sha256.is_none()); + assert_eq!(spec.chat_template, Some("chatml".into())); + } } diff --git a/src/models/template.rs b/src/models/template.rs deleted file mode 100644 index 060d4928..00000000 --- a/src/models/template.rs +++ /dev/null @@ -1,163 +0,0 @@ -use serde::{Deserialize, Serialize}; - -/// A single message in a chat conversation. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ChatMessage { - pub role: String, - pub content: String, -} - -/// Apply a chat template to a list of messages, producing a formatted prompt string. -/// -/// Supported templates: "chatml" (qwen, mistral-nemo), "llama3", "gemma". -/// Falls back to chatml for unknown template names. -pub fn apply_chat_template(template_name: &str, messages: &[ChatMessage]) -> String { - match template_name { - "llama3" => format_llama3(messages), - "gemma" => format_gemma(messages), - _ => format_chatml(messages), // chatml is the default fallback - } -} - -/// ChatML format used by Qwen, Mistral-Nemo, and others. -/// ```text -/// <|im_start|>system -/// You are a helpful assistant.<|im_end|> -/// <|im_start|>user -/// Hello<|im_end|> -/// <|im_start|>assistant -/// ``` -fn format_chatml(messages: &[ChatMessage]) -> String { - let mut out = String::new(); - for msg in messages { - out.push_str(&format!( - "<|im_start|>{}\n{}<|im_end|>\n", - msg.role, msg.content - )); - } - out.push_str("<|im_start|>assistant\n"); - out -} - -/// Llama 3 instruct format. -/// ```text -/// <|begin_of_text|><|start_header_id|>system<|end_header_id|> -/// -/// You are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|> -/// -/// Hello<|eot_id|><|start_header_id|>assistant<|end_header_id|> -/// -/// ``` -fn format_llama3(messages: &[ChatMessage]) -> String { - let mut out = String::from("<|begin_of_text|>"); - for msg in messages { - out.push_str(&format!( - "<|start_header_id|>{}<|end_header_id|>\n\n{}<|eot_id|>", - msg.role, msg.content - )); - } - out.push_str("<|start_header_id|>assistant<|end_header_id|>\n\n"); - out -} - -/// Gemma instruct format. -/// ```text -/// user -/// Hello -/// model -/// ``` -fn format_gemma(messages: &[ChatMessage]) -> String { - let mut out = String::new(); - for msg in messages { - // Gemma uses "model" instead of "assistant" - let role = if msg.role == "assistant" { - "model" - } else { - &msg.role - }; - out.push_str(&format!( - "{}\n{}\n", - role, msg.content - )); - } - out.push_str("model\n"); - out -} - -#[cfg(test)] -mod tests { - use super::*; - - fn sample_messages() -> Vec { - vec![ - ChatMessage { - role: "system".into(), - content: "You are a helpful assistant.".into(), - }, - ChatMessage { - role: "user".into(), - content: "Hello".into(), - }, - ] - } - - #[test] - fn test_chatml_format() { - let result = apply_chat_template("chatml", &sample_messages()); - assert!(result.contains("<|im_start|>system")); - assert!(result.contains("You are a helpful assistant.<|im_end|>")); - assert!(result.contains("<|im_start|>user")); - assert!(result.contains("Hello<|im_end|>")); - assert!(result.ends_with("<|im_start|>assistant\n")); - } - - #[test] - fn test_llama3_format() { - let result = apply_chat_template("llama3", &sample_messages()); - assert!(result.starts_with("<|begin_of_text|>")); - assert!(result.contains("<|start_header_id|>system<|end_header_id|>")); - assert!(result.contains("<|start_header_id|>user<|end_header_id|>")); - assert!(result.ends_with("<|start_header_id|>assistant<|end_header_id|>\n\n")); - } - - #[test] - fn test_gemma_format() { - let msgs = vec![ - ChatMessage { - role: "user".into(), - content: "Hello".into(), - }, - ChatMessage { - role: "assistant".into(), - content: "Hi there!".into(), - }, - ChatMessage { - role: "user".into(), - content: "How are you?".into(), - }, - ]; - let result = apply_chat_template("gemma", &msgs); - assert!(result.contains("user")); - // "assistant" should be mapped to "model" - assert!(result.contains("model\nHi there!")); - assert!(result.ends_with("model\n")); - } - - #[test] - fn test_unknown_template_falls_back_to_chatml() { - let result = apply_chat_template("unknown-template", &sample_messages()); - assert!(result.contains("<|im_start|>")); - } - - #[test] - fn test_chat_message_serde() { - let msg = ChatMessage { - role: "user".into(), - content: "hello".into(), - }; - let json = serde_json::to_string(&msg).unwrap(); - let roundtrip: ChatMessage = serde_json::from_str(&json).unwrap(); - assert_eq!(roundtrip.role, "user"); - assert_eq!(roundtrip.content, "hello"); - } -} diff --git a/src/network/auth.rs b/src/network/auth.rs index af074c6c..0ddb4d87 100644 --- a/src/network/auth.rs +++ b/src/network/auth.rs @@ -1,48 +1,10 @@ -use serde::{Deserialize, Serialize}; +use std::collections::HashMap; use crate::error::NodeError; use crate::identity::Identity; use crate::network::protocol::{Capacity, read_framed, write_framed}; -// --------------------------------------------------------------------------- -// Auth protocol types (separate from NodeMessage/RouterMessage since auth -// happens before the main protocol phase) -// --------------------------------------------------------------------------- - -/// Router sends this challenge after QUIC connection is established. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ChallengeMessage { - pub challenge: [u8; 32], -} - -/// Node responds with signed challenge + metadata. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct AuthRequest { - /// Ethereum address (hex, no 0x prefix). - pub address: String, - /// Signature over SHA-256(challenge). - pub signature: Vec, - /// Recovery ID for the signature. - pub recovery_id: u8, - /// Models this node can serve. - pub models: Vec, - /// Benchmark tokens-per-second. - pub tps: f64, - /// Node software version. - pub version: String, - /// Current capacity. - pub capacity: Capacity, -} - -/// Router responds with auth result. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct AuthResponse { - pub authenticated: bool, - /// Assigned node ID on success. - pub node_id: Option, - /// Error message on failure. - pub error: Option, -} +pub use dkn_protocol::{AuthRequest, AuthResponse, ChallengeMessage}; // --------------------------------------------------------------------------- // Handshake @@ -59,7 +21,7 @@ pub async fn authenticate( recv: &mut quinn::RecvStream, identity: &Identity, models: Vec, - tps: f64, + tps: HashMap, capacity: Capacity, ) -> Result { // 1. Read challenge @@ -98,66 +60,3 @@ pub async fn authenticate( ))) } } - -// --------------------------------------------------------------------------- -// Tests -// --------------------------------------------------------------------------- - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_challenge_message_roundtrip() { - let msg = ChallengeMessage { - challenge: [0x42; 32], - }; - let packed = rmp_serde::to_vec(&msg).unwrap(); - let roundtrip: ChallengeMessage = rmp_serde::from_slice(&packed).unwrap(); - assert_eq!(roundtrip.challenge, [0x42; 32]); - } - - #[test] - fn test_auth_request_roundtrip() { - let req = AuthRequest { - address: "deadbeef".into(), - signature: vec![1, 2, 3], - recovery_id: 0, - models: vec!["gemma3:4b".into()], - tps: 42.5, - version: "2.0.0".into(), - capacity: Capacity { free: 1, max: 2 }, - }; - let packed = rmp_serde::to_vec(&req).unwrap(); - let roundtrip: AuthRequest = rmp_serde::from_slice(&packed).unwrap(); - assert_eq!(roundtrip.address, "deadbeef"); - assert_eq!(roundtrip.models, vec!["gemma3:4b"]); - assert!((roundtrip.tps - 42.5).abs() < f64::EPSILON); - } - - #[test] - fn test_auth_response_success_roundtrip() { - let resp = AuthResponse { - authenticated: true, - node_id: Some("node-123".into()), - error: None, - }; - let packed = rmp_serde::to_vec(&resp).unwrap(); - let roundtrip: AuthResponse = rmp_serde::from_slice(&packed).unwrap(); - assert!(roundtrip.authenticated); - assert_eq!(roundtrip.node_id.unwrap(), "node-123"); - } - - #[test] - fn test_auth_response_failure_roundtrip() { - let resp = AuthResponse { - authenticated: false, - node_id: None, - error: Some("bad signature".into()), - }; - let packed = rmp_serde::to_vec(&resp).unwrap(); - let roundtrip: AuthResponse = rmp_serde::from_slice(&packed).unwrap(); - assert!(!roundtrip.authenticated); - assert_eq!(roundtrip.error.unwrap(), "bad signature"); - } -} diff --git a/src/network/connection.rs b/src/network/connection.rs index 722ded67..c6c3b5f0 100644 --- a/src/network/connection.rs +++ b/src/network/connection.rs @@ -1,3 +1,4 @@ +use std::collections::HashMap; use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; @@ -22,10 +23,6 @@ pub struct RouterConnection { send: quinn::SendStream, /// Receive half of the bi-directional stream. recv: quinn::RecvStream, - /// Router URL for reconnection. - router_url: String, - /// Whether to skip TLS verification. - insecure: bool, /// Assigned node ID from the router. pub node_id: String, } @@ -37,7 +34,7 @@ impl RouterConnection { insecure: bool, identity: &Identity, models: Vec, - tps: f64, + tps: HashMap, capacity: Capacity, ) -> Result { let (host, port) = parse_url(router_url)?; @@ -67,57 +64,18 @@ impl RouterConnection { connection, send, recv, - router_url: router_url.to_string(), - insecure, node_id, }) } /// Send a message to the router. pub async fn send(&mut self, msg: &NodeMessage) -> Result<(), NodeError> { - write_framed(&mut self.send, msg).await + Ok(write_framed(&mut self.send, msg).await?) } /// Receive a message from the router. Returns `None` on clean stream close. pub async fn recv(&mut self) -> Result, NodeError> { - read_framed(&mut self.recv).await - } - - /// Attempt to reconnect with exponential backoff. - /// - /// Retries: 1s → 2s → 4s → 8s → ... → 60s cap. - pub async fn reconnect( - &mut self, - identity: &Identity, - models: Vec, - tps: f64, - capacity: Capacity, - ) -> Result<(), NodeError> { - let mut delay = Duration::from_secs(1); - let max_delay = Duration::from_secs(60); - - loop { - tracing::info!(delay_secs = delay.as_secs(), "attempting reconnect"); - tokio::time::sleep(delay).await; - - match Self::connect(&self.router_url, self.insecure, identity, models.clone(), tps, capacity.clone()) - .await - { - Ok(new_conn) => { - self.endpoint = new_conn.endpoint; - self.connection = new_conn.connection; - self.send = new_conn.send; - self.recv = new_conn.recv; - self.node_id = new_conn.node_id; - tracing::info!(node_id = %self.node_id, "reconnected to router"); - return Ok(()); - } - Err(e) => { - tracing::warn!(%e, "reconnect failed"); - delay = (delay * 2).min(max_delay); - } - } - } + Ok(read_framed(&mut self.recv).await?) } /// Close the connection and endpoint gracefully. @@ -403,7 +361,7 @@ mod tests { &mut recv, &identity, vec!["gemma3:4b".into()], - 42.0, + HashMap::from([("gemma3:4b".to_string(), 42.0)]), Capacity { free: 1, max: 2 }, ) .await @@ -416,6 +374,7 @@ mod tests { models: vec!["gemma3:4b".into()], capacity: Capacity { free: 1, max: 2 }, version: env!("CARGO_PKG_VERSION").to_string(), + stats: None, }; write_framed(&mut send, &status).await.unwrap(); @@ -507,7 +466,7 @@ mod tests { true, &identity, vec!["gemma3:4b".into()], - 50.0, + HashMap::from([("gemma3:4b".to_string(), 50.0)]), Capacity { free: 2, max: 4 }, ) .await @@ -524,6 +483,328 @@ mod tests { models: vec!["gemma3:4b".into()], capacity: Capacity { free: 2, max: 4 }, version: env!("CARGO_PKG_VERSION").to_string(), + stats: None, + }) + .await + .unwrap(); + + rx.await.expect("server did not signal completion"); + conn.close(); + }) + .await + .expect("test timed out"); + } + + /// Helper: run the auth handshake as a mock router on an accepted connection. + async fn mock_router_auth( + send: &mut quinn::SendStream, + recv: &mut quinn::RecvStream, + node_id: &str, + ) { + write_framed( + send, + &crate::network::auth::ChallengeMessage { + challenge: [0xCC; 32], + }, + ) + .await + .unwrap(); + + let _auth_req: crate::network::auth::AuthRequest = + read_framed(recv).await.unwrap().unwrap(); + + write_framed( + send, + &crate::network::auth::AuthResponse { + authenticated: true, + node_id: Some(node_id.into()), + error: None, + }, + ) + .await + .unwrap(); + } + + /// Helper: build a mock QUIC server endpoint with self-signed cert. + fn build_mock_server_endpoint() -> Endpoint { + let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap(); + let cert_der = CertificateDer::from(cert.cert); + let key_der = + rustls::pki_types::PrivatePkcs8KeyDer::from(cert.key_pair.serialize_der()); + + let server_crypto = rustls::ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(vec![cert_der.clone()], key_der.into()) + .unwrap(); + + let mut server_config = quinn::ServerConfig::with_crypto(Arc::new( + quinn::crypto::rustls::QuicServerConfig::try_from(server_crypto).unwrap(), + )); + let mut transport = TransportConfig::default(); + transport.max_concurrent_bidi_streams(8u32.into()); + server_config.transport_config(Arc::new(transport)); + + Endpoint::server(server_config, "127.0.0.1:0".parse().unwrap()).unwrap() + } + + /// Integration test: full message flow (challenge → auth → ping → status → task assignment → rejection). + #[tokio::test] + async fn test_full_message_flow() { + tokio::time::timeout(Duration::from_secs(10), async { + let server_endpoint = build_mock_server_endpoint(); + let server_addr = server_endpoint.local_addr().unwrap(); + + let (tx, rx) = tokio::sync::oneshot::channel::<()>(); + + tokio::spawn(async move { + let incoming = server_endpoint.accept().await.unwrap(); + let server_conn = incoming.await.unwrap(); + let (mut send, mut recv) = server_conn.open_bi().await.unwrap(); + + // Auth + mock_router_auth(&mut send, &mut recv, "flow-node").await; + + // Send ping + write_framed(&mut send, &RouterMessage::Ping).await.unwrap(); + + // Read status update + let msg: NodeMessage = read_framed(&mut recv).await.unwrap().unwrap(); + assert!(matches!(msg, NodeMessage::StatusUpdate { .. })); + + // Send a task assignment (node has no real model → rejection expected) + let task_id = uuid::Uuid::nil(); + write_framed( + &mut send, + &RouterMessage::TaskAssignment { + task_id, + model: "nonexistent:1b".into(), + messages: vec![dkn_protocol::ChatMessage { + role: "user".into(), + content: "test".into(), + }], + max_tokens: 10, + temperature: 0.7, + validation: None, + }, + ) + .await + .unwrap(); + + // Read the task rejection + let reject: NodeMessage = read_framed(&mut recv).await.unwrap().unwrap(); + match reject { + NodeMessage::TaskRejected { task_id: tid, reason } => { + assert_eq!(tid, task_id); + assert!(matches!(reason, dkn_protocol::RejectReason::ModelNotLoaded)); + } + _ => panic!("expected TaskRejected, got {reject:?}"), + } + + let _ = tx.send(()); + server_conn.close(0u32.into(), b"done"); + server_endpoint.close(0u32.into(), b"shutdown"); + }); + + let url = format!("127.0.0.1:{}", server_addr.port()); + let identity = Identity::from_secret_hex( + "6472696164726961647269616472696164726961647269616472696164726961", + ) + .unwrap(); + + let mut conn = RouterConnection::connect( + &url, + true, + &identity, + vec!["gemma3:4b".into()], + HashMap::from([("gemma3:4b".to_string(), 50.0)]), + Capacity { free: 1, max: 1 }, + ) + .await + .unwrap(); + assert_eq!(conn.node_id, "flow-node"); + + // Receive ping → reply with status + let msg = conn.recv().await.unwrap().unwrap(); + assert!(matches!(msg, RouterMessage::Ping)); + + conn.send(&NodeMessage::StatusUpdate { + models: vec!["gemma3:4b".into()], + capacity: Capacity { free: 1, max: 1 }, + version: env!("CARGO_PKG_VERSION").to_string(), + stats: None, + }) + .await + .unwrap(); + + // Receive task assignment → we just forward to test; in real code the worker handles it + let task_msg = conn.recv().await.unwrap().unwrap(); + match task_msg { + RouterMessage::TaskAssignment { task_id, .. } => { + // Reject: model not loaded (only "gemma3:4b" is listed but task asks for "nonexistent:1b") + conn.send(&NodeMessage::TaskRejected { + task_id, + reason: dkn_protocol::RejectReason::ModelNotLoaded, + }) + .await + .unwrap(); + } + _ => panic!("expected TaskAssignment"), + } + + rx.await.expect("server did not signal completion"); + conn.close(); + }) + .await + .expect("test timed out"); + } + + /// Integration test: multi-router failover — first server closes immediately, second handles auth. + #[tokio::test] + async fn test_multi_router_failover() { + tokio::time::timeout(Duration::from_secs(10), async { + // First "bad" server: accepts connection then immediately closes + let bad_endpoint = build_mock_server_endpoint(); + let bad_addr = bad_endpoint.local_addr().unwrap(); + + tokio::spawn(async move { + if let Some(incoming) = bad_endpoint.accept().await { + let conn = incoming.await.unwrap(); + // Immediately close without opening a stream + conn.close(0u32.into(), b"go away"); + bad_endpoint.close(0u32.into(), b"shutdown"); + } + }); + + // Second "good" server: handles auth normally + let good_endpoint = build_mock_server_endpoint(); + let good_addr = good_endpoint.local_addr().unwrap(); + + let (tx, rx) = tokio::sync::oneshot::channel::<()>(); + + tokio::spawn(async move { + let incoming = good_endpoint.accept().await.unwrap(); + let server_conn = incoming.await.unwrap(); + let (mut send, mut recv) = server_conn.open_bi().await.unwrap(); + + mock_router_auth(&mut send, &mut recv, "failover-node").await; + + let _ = tx.send(()); + // Keep connection alive until test finishes + tokio::time::sleep(Duration::from_secs(2)).await; + server_conn.close(0u32.into(), b"done"); + good_endpoint.close(0u32.into(), b"shutdown"); + }); + + let identity = Identity::from_secret_hex( + "6472696164726961647269616472696164726961647269616472696164726961", + ) + .unwrap(); + + // Try bad server first, then good server + let urls = vec![ + format!("127.0.0.1:{}", bad_addr.port()), + format!("127.0.0.1:{}", good_addr.port()), + ]; + + let mut connected = None; + for url in &urls { + match RouterConnection::connect( + url, + true, + &identity, + vec!["gemma3:4b".into()], + HashMap::from([("gemma3:4b".to_string(), 50.0)]), + Capacity { free: 1, max: 1 }, + ) + .await + { + Ok(conn) => { + connected = Some(conn); + break; + } + Err(_) => continue, + } + } + + let conn = connected.expect("should have connected to second server"); + assert_eq!(conn.node_id, "failover-node"); + + rx.await.expect("server did not signal completion"); + conn.close(); + }) + .await + .expect("test timed out"); + } + + /// Integration test: stats field is present in StatusUpdate. + #[tokio::test] + async fn test_status_update_with_stats() { + tokio::time::timeout(Duration::from_secs(10), async { + let server_endpoint = build_mock_server_endpoint(); + let server_addr = server_endpoint.local_addr().unwrap(); + + let (tx, rx) = tokio::sync::oneshot::channel::<()>(); + + tokio::spawn(async move { + let incoming = server_endpoint.accept().await.unwrap(); + let server_conn = incoming.await.unwrap(); + let (mut send, mut recv) = server_conn.open_bi().await.unwrap(); + + mock_router_auth(&mut send, &mut recv, "stats-node").await; + + // Send ping + write_framed(&mut send, &RouterMessage::Ping).await.unwrap(); + + // Read status update — verify stats field is present + let msg: NodeMessage = read_framed(&mut recv).await.unwrap().unwrap(); + match msg { + NodeMessage::StatusUpdate { stats, version, .. } => { + assert_eq!(version, env!("CARGO_PKG_VERSION")); + let s = stats.expect("stats should be present"); + assert_eq!(s.tasks_completed, 42); + assert_eq!(s.total_tokens_generated, 1000); + } + _ => panic!("expected StatusUpdate"), + } + + let _ = tx.send(()); + server_conn.close(0u32.into(), b"done"); + server_endpoint.close(0u32.into(), b"shutdown"); + }); + + let url = format!("127.0.0.1:{}", server_addr.port()); + let identity = Identity::from_secret_hex( + "6472696164726961647269616472696164726961647269616472696164726961", + ) + .unwrap(); + + let mut conn = RouterConnection::connect( + &url, + true, + &identity, + vec!["gemma3:4b".into()], + HashMap::from([("gemma3:4b".to_string(), 50.0)]), + Capacity { free: 1, max: 1 }, + ) + .await + .unwrap(); + + // Receive ping + let msg = conn.recv().await.unwrap().unwrap(); + assert!(matches!(msg, RouterMessage::Ping)); + + // Send status update with stats + conn.send(&NodeMessage::StatusUpdate { + models: vec!["gemma3:4b".into()], + capacity: Capacity { free: 1, max: 1 }, + version: env!("CARGO_PKG_VERSION").to_string(), + stats: Some(dkn_protocol::NodeStatsSnapshot { + tasks_completed: 42, + tasks_failed: 3, + tasks_rejected: 1, + total_tokens_generated: 1000, + uptime_secs: 600, + }), }) .await .unwrap(); diff --git a/src/network/protocol.rs b/src/network/protocol.rs index 91749dfe..1aa0e3f3 100644 --- a/src/network/protocol.rs +++ b/src/network/protocol.rs @@ -1,329 +1,4 @@ -use serde::{Deserialize, Serialize}; -use uuid::Uuid; - -use crate::error::NodeError; -use crate::models::template::ChatMessage; - -// --------------------------------------------------------------------------- -// Node → Router messages -// --------------------------------------------------------------------------- - -/// Messages sent from this compute node to the router. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub enum NodeMessage { - /// Completed inference result for a task. - TaskResult { - task_id: Uuid, - text: String, - stats: TaskStats, - proof: Option, - }, - /// We cannot accept the assigned task. - TaskRejected { - task_id: Uuid, - reason: RejectReason, - }, - /// Periodic or on-demand status snapshot. - StatusUpdate { - models: Vec, - capacity: Capacity, - version: String, - }, - /// Response to a router challenge (placeholder). - ChallengeResponse { - challenge: [u8; 32], - signature: Vec, - recovery_id: u8, - }, -} - -// --------------------------------------------------------------------------- -// Router → Node messages -// --------------------------------------------------------------------------- - -/// Messages sent from the router to this compute node. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub enum RouterMessage { - /// A new inference task to execute. - TaskAssignment { - task_id: Uuid, - model: String, - messages: Vec, - max_tokens: u32, - temperature: f32, - validation: Option, - }, - /// Challenge for proof-of-liveness. - Challenge { challenge: [u8; 32] }, - /// Heartbeat / keep-alive ping. - Ping, - /// Updated model registry from the router. - ModelRegistryUpdate { - entries: Vec, - }, -} - -// --------------------------------------------------------------------------- -// Supporting types -// --------------------------------------------------------------------------- - -/// Statistics about a completed inference task. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct TaskStats { - pub tokens_generated: u32, - pub prompt_tokens: u32, - pub generation_time_ms: u64, - pub tokens_per_second: f64, -} - -/// Reason a task was rejected. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub enum RejectReason { - /// Model not loaded on this node. - ModelNotLoaded, - /// All inference slots are busy. - AtCapacity, - /// Task parameters are invalid. - InvalidRequest(String), -} - -/// Current capacity snapshot. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Capacity { - /// Number of free inference slots. - pub free: usize, - /// Maximum concurrent inference slots. - pub max: usize, -} - -/// Optional validation parameters included with a task. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ValidationRequest { - /// Token positions at which to extract logprobs. - pub logprob_positions: Vec, - /// Top-k alternatives to collect at each logprob position. - pub logprob_top_k: usize, -} - -/// A model entry from the router's registry. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ModelRegistryEntry { - pub name: String, - pub hf_repo: String, - pub hf_file: String, - pub chat_template: Option, -} - -// --------------------------------------------------------------------------- -// Length-prefixed MessagePack framing -// --------------------------------------------------------------------------- - -/// Maximum allowed message size (16 MB). -const MAX_MESSAGE_SIZE: u32 = 16 * 1024 * 1024; - -/// Write a length-prefixed MessagePack message to a QUIC send stream. -/// -/// Wire format: `[4-byte BE length][msgpack payload]` -pub async fn write_framed( - send: &mut quinn::SendStream, - msg: &T, -) -> Result<(), NodeError> { - let payload = - rmp_serde::to_vec(msg).map_err(|e| NodeError::Network(format!("serialize: {e}")))?; - let len = payload.len() as u32; - if len > MAX_MESSAGE_SIZE { - return Err(NodeError::Network(format!( - "message too large: {len} bytes (max {MAX_MESSAGE_SIZE})" - ))); - } - send.write_all(&len.to_be_bytes()) - .await - .map_err(|e| NodeError::Network(format!("write length: {e}")))?; - send.write_all(&payload) - .await - .map_err(|e| NodeError::Network(format!("write payload: {e}")))?; - Ok(()) -} - -/// Read a length-prefixed MessagePack message from a QUIC receive stream. -/// -/// Returns `Ok(None)` on clean EOF (stream closed), `Err` on protocol violations. -pub async fn read_framed( - recv: &mut quinn::RecvStream, -) -> Result, NodeError> { - let mut len_buf = [0u8; 4]; - match recv.read_exact(&mut len_buf).await { - Ok(()) => {} - Err(quinn::ReadExactError::FinishedEarly(_)) => return Ok(None), - Err(e) => return Err(NodeError::Network(format!("read length: {e}"))), - } - let len = u32::from_be_bytes(len_buf); - if len > MAX_MESSAGE_SIZE { - return Err(NodeError::Network(format!( - "message too large: {len} bytes (max {MAX_MESSAGE_SIZE})" - ))); - } - let mut payload = vec![0u8; len as usize]; - recv.read_exact(&mut payload) - .await - .map_err(|e| NodeError::Network(format!("read payload: {e}")))?; - let msg = rmp_serde::from_slice(&payload) - .map_err(|e| NodeError::Network(format!("deserialize: {e}")))?; - Ok(Some(msg)) -} - -// --------------------------------------------------------------------------- -// Tests -// --------------------------------------------------------------------------- - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_node_message_roundtrip() { - let msg = NodeMessage::TaskResult { - task_id: Uuid::nil(), - text: "Hello world".into(), - stats: TaskStats { - tokens_generated: 10, - prompt_tokens: 5, - generation_time_ms: 100, - tokens_per_second: 100.0, - }, - proof: None, - }; - let packed = rmp_serde::to_vec(&msg).unwrap(); - let roundtrip: NodeMessage = rmp_serde::from_slice(&packed).unwrap(); - match roundtrip { - NodeMessage::TaskResult { task_id, text, .. } => { - assert_eq!(task_id, Uuid::nil()); - assert_eq!(text, "Hello world"); - } - _ => panic!("wrong variant"), - } - } - - #[test] - fn test_router_message_roundtrip() { - let msg = RouterMessage::TaskAssignment { - task_id: Uuid::nil(), - model: "gemma3:4b".into(), - messages: vec![ChatMessage { - role: "user".into(), - content: "hello".into(), - }], - max_tokens: 512, - temperature: 0.7, - validation: None, - }; - let packed = rmp_serde::to_vec(&msg).unwrap(); - let roundtrip: RouterMessage = rmp_serde::from_slice(&packed).unwrap(); - match roundtrip { - RouterMessage::TaskAssignment { model, .. } => { - assert_eq!(model, "gemma3:4b"); - } - _ => panic!("wrong variant"), - } - } - - #[test] - fn test_reject_reason_roundtrip() { - let msg = NodeMessage::TaskRejected { - task_id: Uuid::nil(), - reason: RejectReason::AtCapacity, - }; - let packed = rmp_serde::to_vec(&msg).unwrap(); - let roundtrip: NodeMessage = rmp_serde::from_slice(&packed).unwrap(); - match roundtrip { - NodeMessage::TaskRejected { reason, .. } => { - matches!(reason, RejectReason::AtCapacity); - } - _ => panic!("wrong variant"), - } - } - - #[test] - fn test_status_update_roundtrip() { - let msg = NodeMessage::StatusUpdate { - models: vec!["gemma3:4b".into()], - capacity: Capacity { free: 2, max: 4 }, - version: "2.0.0".into(), - }; - let packed = rmp_serde::to_vec(&msg).unwrap(); - let roundtrip: NodeMessage = rmp_serde::from_slice(&packed).unwrap(); - match roundtrip { - NodeMessage::StatusUpdate { - capacity, version, .. - } => { - assert_eq!(capacity.free, 2); - assert_eq!(capacity.max, 4); - assert_eq!(version, "2.0.0"); - } - _ => panic!("wrong variant"), - } - } - - #[test] - fn test_challenge_roundtrip() { - let msg = RouterMessage::Challenge { - challenge: [0xAB; 32], - }; - let packed = rmp_serde::to_vec(&msg).unwrap(); - let roundtrip: RouterMessage = rmp_serde::from_slice(&packed).unwrap(); - match roundtrip { - RouterMessage::Challenge { challenge } => { - assert_eq!(challenge, [0xAB; 32]); - } - _ => panic!("wrong variant"), - } - } - - #[test] - fn test_ping_roundtrip() { - let packed = rmp_serde::to_vec(&RouterMessage::Ping).unwrap(); - let roundtrip: RouterMessage = rmp_serde::from_slice(&packed).unwrap(); - assert!(matches!(roundtrip, RouterMessage::Ping)); - } - - #[test] - fn test_model_registry_update_roundtrip() { - let msg = RouterMessage::ModelRegistryUpdate { - entries: vec![ModelRegistryEntry { - name: "test:1b".into(), - hf_repo: "repo/model".into(), - hf_file: "model.gguf".into(), - chat_template: Some("chatml".into()), - }], - }; - let packed = rmp_serde::to_vec(&msg).unwrap(); - let roundtrip: RouterMessage = rmp_serde::from_slice(&packed).unwrap(); - match roundtrip { - RouterMessage::ModelRegistryUpdate { entries } => { - assert_eq!(entries.len(), 1); - assert_eq!(entries[0].name, "test:1b"); - } - _ => panic!("wrong variant"), - } - } - - /// Test framing over a quinn duplex (uses tokio::io::duplex via quinn test helpers). - /// Since we can't easily create a quinn stream in unit tests, test the serialization - /// logic directly and verify size limits. - #[test] - fn test_message_size_within_limit() { - let msg = NodeMessage::TaskResult { - task_id: Uuid::nil(), - text: "x".repeat(1000), - stats: TaskStats { - tokens_generated: 100, - prompt_tokens: 50, - generation_time_ms: 500, - tokens_per_second: 200.0, - }, - proof: None, - }; - let packed = rmp_serde::to_vec(&msg).unwrap(); - assert!((packed.len() as u32) < MAX_MESSAGE_SIZE); - } -} +pub use dkn_protocol::{ + read_framed, write_framed, Capacity, NodeMessage, RejectReason, RouterMessage, TaskStats, + ValidationRequest, +}; diff --git a/src/stats.rs b/src/stats.rs new file mode 100644 index 00000000..1ff081c6 --- /dev/null +++ b/src/stats.rs @@ -0,0 +1,118 @@ +use std::sync::atomic::{AtomicU64, Ordering}; +use std::time::Instant; + +use dkn_protocol::NodeStatsSnapshot; + +/// Atomic counters for node-level metrics. +pub struct NodeStats { + pub tasks_completed: AtomicU64, + pub tasks_failed: AtomicU64, + pub tasks_rejected: AtomicU64, + pub total_tokens_generated: AtomicU64, + started_at: Instant, +} + +impl NodeStats { + pub fn new() -> Self { + NodeStats { + tasks_completed: AtomicU64::new(0), + tasks_failed: AtomicU64::new(0), + tasks_rejected: AtomicU64::new(0), + total_tokens_generated: AtomicU64::new(0), + started_at: Instant::now(), + } + } + + pub fn uptime_secs(&self) -> u64 { + self.started_at.elapsed().as_secs() + } + + pub fn record_completed(&self, tokens: u32) { + self.tasks_completed.fetch_add(1, Ordering::Relaxed); + self.total_tokens_generated + .fetch_add(u64::from(tokens), Ordering::Relaxed); + } + + pub fn record_failed(&self) { + self.tasks_failed.fetch_add(1, Ordering::Relaxed); + } + + pub fn record_rejected(&self) { + self.tasks_rejected.fetch_add(1, Ordering::Relaxed); + } + + pub fn log_summary(&self) { + let snap = self.snapshot(); + tracing::info!( + tasks_completed = snap.tasks_completed, + tasks_failed = snap.tasks_failed, + tasks_rejected = snap.tasks_rejected, + total_tokens = snap.total_tokens_generated, + uptime_secs = snap.uptime_secs, + "node stats" + ); + } + + pub fn snapshot(&self) -> NodeStatsSnapshot { + NodeStatsSnapshot { + tasks_completed: self.tasks_completed.load(Ordering::Relaxed), + tasks_failed: self.tasks_failed.load(Ordering::Relaxed), + tasks_rejected: self.tasks_rejected.load(Ordering::Relaxed), + total_tokens_generated: self.total_tokens_generated.load(Ordering::Relaxed), + uptime_secs: self.uptime_secs(), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_stats_initial_values() { + let stats = NodeStats::new(); + assert_eq!(stats.tasks_completed.load(Ordering::Relaxed), 0); + assert_eq!(stats.tasks_failed.load(Ordering::Relaxed), 0); + assert_eq!(stats.tasks_rejected.load(Ordering::Relaxed), 0); + assert_eq!(stats.total_tokens_generated.load(Ordering::Relaxed), 0); + } + + #[test] + fn test_stats_record_completed() { + let stats = NodeStats::new(); + stats.record_completed(50); + stats.record_completed(30); + assert_eq!(stats.tasks_completed.load(Ordering::Relaxed), 2); + assert_eq!(stats.total_tokens_generated.load(Ordering::Relaxed), 80); + } + + #[test] + fn test_stats_record_failed() { + let stats = NodeStats::new(); + stats.record_failed(); + stats.record_failed(); + assert_eq!(stats.tasks_failed.load(Ordering::Relaxed), 2); + } + + #[test] + fn test_stats_record_rejected() { + let stats = NodeStats::new(); + stats.record_rejected(); + assert_eq!(stats.tasks_rejected.load(Ordering::Relaxed), 1); + } + + #[test] + fn test_stats_snapshot() { + let stats = NodeStats::new(); + stats.record_completed(100); + stats.record_failed(); + stats.record_rejected(); + let snap = stats.snapshot(); + assert_eq!(snap.tasks_completed, 1); + assert_eq!(snap.tasks_failed, 1); + assert_eq!(snap.tasks_rejected, 1); + assert_eq!(snap.total_tokens_generated, 100); + // uptime_secs should be >= 0 (just created) + assert!(snap.uptime_secs < 5); + } +} diff --git a/src/worker.rs b/src/worker.rs index 09acce20..21e5b0a3 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -1,3 +1,4 @@ +use std::collections::HashMap; use std::ops::ControlFlow; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; @@ -21,34 +22,33 @@ pub struct CompletedTask { } /// Executes inference tasks with backpressure via capacity tracking. +/// +/// Supports multiple models, each with its own engine and chat template. pub struct Worker { - engine: Arc, - /// Chat template name for prompt formatting. - chat_template: String, + /// Map of model name → (engine, chat_template). + engines: HashMap, String)>, /// Number of available inference slots (CAS-based). capacity: Arc, /// Maximum concurrent slots. max_capacity: usize, - /// Models this worker serves. - model_names: Vec, /// In-flight tasks tracked via FuturesUnordered. in_flight: FuturesUnordered>, } impl Worker { - /// Create a new worker wrapping an inference engine. + /// Create a new worker wrapping multiple inference engines. pub fn new( - engine: InferenceEngine, - chat_template: String, - model_names: Vec, + engines: HashMap, max_concurrent: usize, ) -> Self { + let engines = engines + .into_iter() + .map(|(name, (engine, template))| (name, (Arc::new(engine), template))) + .collect(); Worker { - engine: Arc::new(engine), - chat_template, + engines, capacity: Arc::new(AtomicUsize::new(max_concurrent)), max_capacity: max_concurrent, - model_names, in_flight: FuturesUnordered::new(), } } @@ -65,10 +65,13 @@ impl Worker { temperature: f32, validation: Option, ) -> Result<(), RejectReason> { - // Check model - if !self.model_names.iter().any(|m| m == model) { - return Err(RejectReason::ModelNotLoaded); - } + // Look up engine + template for the requested model (fail fast before decrementing capacity) + let (engine, template) = self + .engines + .get(model) + .ok_or(RejectReason::ModelNotLoaded)?; + let engine = Arc::clone(engine); + let template = template.clone(); // Try to decrement capacity (CAS loop) loop { @@ -98,9 +101,7 @@ impl Worker { logprob_top_k: validation.as_ref().map(|v| v.logprob_top_k).unwrap_or(5), }; - let engine = Arc::clone(&self.engine); let capacity = Arc::clone(&self.capacity); - let template = self.chat_template.clone(); let handle = tokio::task::spawn_blocking(move || { let result = run_inference(&engine, &template, messages, ¶ms, task_id); @@ -137,14 +138,33 @@ impl Worker { } /// Model names this worker serves. - pub fn model_names(&self) -> &[String] { - &self.model_names + pub fn model_names(&self) -> Vec { + self.engines.keys().cloned().collect() } /// Whether there are any in-flight tasks. pub fn has_in_flight(&self) -> bool { !self.in_flight.is_empty() } + + /// Add a new model engine at runtime (for hot-swap). + /// + /// If a model with this name already exists, it is replaced. + pub fn add_engine(&mut self, name: String, engine: InferenceEngine, template: String) { + self.engines.insert(name, (Arc::new(engine), template)); + } + + /// Remove a model engine by name. Returns true if the model was present. + /// + /// Safe while tasks are in-flight — running tasks hold their own Arc clone. + pub fn remove_engine(&mut self, name: &str) -> bool { + self.engines.remove(name).is_some() + } + + /// Check whether the worker has a model loaded. + pub fn has_model(&self, name: &str) -> bool { + self.engines.contains_key(name) + } } /// Run inference synchronously (called from `spawn_blocking`). @@ -258,4 +278,22 @@ mod tests { }; assert!(completed.result.is_err()); } + + #[test] + fn test_worker_has_model() { + let worker = Worker::new(HashMap::new(), 1); + assert!(!worker.has_model("gemma3:4b")); + } + + #[test] + fn test_worker_remove_engine_not_present() { + let mut worker = Worker::new(HashMap::new(), 1); + assert!(!worker.remove_engine("gemma3:4b")); + } + + #[test] + fn test_worker_model_names_empty() { + let worker = Worker::new(HashMap::new(), 1); + assert!(worker.model_names().is_empty()); + } } From 68f5b45815401c2114549212bca18ea0cc5faf92 Mon Sep 17 00:00:00 2001 From: andthattoo Date: Tue, 3 Mar 2026 01:51:08 +0300 Subject: [PATCH 03/57] Update model registry and add modality (Text/Vision/Audio) validation - Replace old models (gemma, llama, mistral) with lfm2.5, qwen3.5, nanbeige, locooperator - Add ModelType field to ModelSpec and propagate through worker/main - Worker rejects tasks with image/audio content when model lacks that modality - Re-export ModelType and MessageContent from dkn-protocol --- TESTING.md | 366 ++++++++++++++++++++++++++++++++++++++ src/main.rs | 11 +- src/models/cache.rs | 1 + src/models/mod.rs | 2 +- src/models/registry.rs | 123 +++++++------ src/network/connection.rs | 5 +- src/network/protocol.rs | 4 +- src/worker.rs | 63 +++++-- 8 files changed, 497 insertions(+), 78 deletions(-) create mode 100644 TESTING.md diff --git a/TESTING.md b/TESTING.md new file mode 100644 index 00000000..03808b01 --- /dev/null +++ b/TESTING.md @@ -0,0 +1,366 @@ +# DKN Network Testing Guide + +How to test the full router + compute-node stack locally (single machine) and over the internet (two laptops). + +## Prerequisites + +- Rust toolchain (`rustup`, `cargo`) +- `openssl` CLI (for generating TLS certs) +- A HuggingFace account (models download automatically) +- ~1 GB free disk for the smallest model (`lfm2.5:1.2b`) + +Build both binaries first: + +```bash +# Router +cd dkn-router && cargo build --release + +# Compute node (CPU) +cd dkn-compute-node && cargo build --release + +# Compute node (Metal / Apple Silicon) +cd dkn-compute-node && cargo build --release --features metal + +# Compute node (CUDA) +cd dkn-compute-node && cargo build --release --features cuda +``` + +## Generate a wallet key + +Any 32-byte hex string works as a test wallet: + +```bash +openssl rand -hex 32 +# example output: a1b2c3d4...64 hex chars total +``` + +Save it — you'll pass it to the node via `--wallet`. + +--- + +## Scenario 1: Everything on localhost + +### 1. Generate self-signed TLS certs + +```bash +mkdir -p /tmp/dkn-certs + +openssl req -x509 -newkey ec -pkeyopt ec_paramgen_curve:prime256v1 \ + -keyout /tmp/dkn-certs/key.pem -out /tmp/dkn-certs/cert.pem \ + -days 365 -nodes -subj "/CN=localhost" \ + -addext "subjectAltName=DNS:localhost,IP:127.0.0.1" +``` + +### 2. Start the router + +```bash +./dkn-router/target/release/dkn-router \ + --listen-quic 127.0.0.1:4001 \ + --listen-http 127.0.0.1:8080 \ + --cert /tmp/dkn-certs/cert.pem \ + --key /tmp/dkn-certs/key.pem +``` + +You should see: + +``` +INFO starting DKN router quic=127.0.0.1:4001 http=127.0.0.1:8080 +INFO router ready ... +``` + +### 3. Start the compute node + +In a second terminal: + +```bash +./dkn-compute-node/target/release/dria-node start \ + --wallet \ + --model lfm2.5:1.2b \ + --router-url https://127.0.0.1:4001 \ + --insecure \ + --gpu-layers -1 +``` + +- `--insecure` skips TLS verification (required for self-signed certs). +- `--gpu-layers -1` offloads all layers to GPU. Use `0` for CPU-only. +- First run downloads the model from HuggingFace (~730 MB). + +You should see: + +``` +INFO node identity address=0x... +INFO model found in cache ... +INFO benchmark complete tps=... model=lfm2.5:1.2b +INFO connected to router node_id=... router=https://127.0.0.1:4001 +INFO node ready ... +``` + +### 4. Send a request + +```bash +curl -s http://127.0.0.1:8080/v1/generate \ + -H "Content-Type: application/json" \ + -d '{ + "model": "lfm2.5:1.2b", + "messages": [{"role": "user", "content": "What is 2+2?"}], + "max_tokens": 128, + "temperature": 0.7 + }' | python3 -m json.tool +``` + +Expected response: + +```json +{ + "text": "2+2 equals 4...", + "model": "lfm2.5:1.2b", + "stats": { + "tokens_generated": 12, + "prompt_tokens": 8, + "generation_time_ms": 450, + "tokens_per_second": 26.7 + } +} +``` + +### 5. Check other endpoints + +```bash +# Health check +curl -s http://127.0.0.1:8080/v1/health | python3 -m json.tool + +# List models served by connected nodes +curl -s http://127.0.0.1:8080/v1/models | python3 -m json.tool + +# Batch request +curl -s http://127.0.0.1:8080/v1/batch \ + -H "Content-Type: application/json" \ + -d '{ + "tasks": [ + {"model": "lfm2.5:1.2b", "messages": [{"role": "user", "content": "Say hi"}]}, + {"model": "lfm2.5:1.2b", "messages": [{"role": "user", "content": "Say bye"}]} + ], + "timeout_secs": 30 + }' | python3 -m json.tool +``` + +### 6. Run multiple nodes (optional) + +Start a second node with a different model and wallet on the same machine: + +```bash +./dkn-compute-node/target/release/dria-node start \ + --wallet $(openssl rand -hex 32) \ + --model nanbeige:3b \ + --router-url https://127.0.0.1:4001 \ + --insecure \ + --gpu-layers 0 +``` + +Now `/v1/models` will show both `lfm2.5:1.2b` and `nanbeige:3b`. + +--- + +## Scenario 2: Two laptops over the internet + +**Laptop A** = router, **Laptop B** = compute node. + +### 1. Find Laptop A's public IP + +If Laptop A is behind NAT (home router), you need to either: + +- **Port-forward** UDP 4001 and TCP 8080 on the home router to Laptop A's LAN IP. +- Use a cloud VM (DigitalOcean, AWS, etc.) as Laptop A instead. + +Get the public IP: + +```bash +curl -s ifconfig.me +# e.g. 203.0.113.42 +``` + +### 2. Generate TLS certs on Laptop A + +Generate certs with the public IP as a SAN: + +```bash +export ROUTER_IP=203.0.113.42 # replace with your public IP + +mkdir -p /tmp/dkn-certs + +openssl req -x509 -newkey ec -pkeyopt ec_paramgen_curve:prime256v1 \ + -keyout /tmp/dkn-certs/key.pem -out /tmp/dkn-certs/cert.pem \ + -days 365 -nodes -subj "/CN=$ROUTER_IP" \ + -addext "subjectAltName=IP:$ROUTER_IP" +``` + +If you have a domain name, use `DNS:yourdomain.com` instead of `IP:...`. + +### 3. Start the router on Laptop A + +```bash +./dkn-router/target/release/dkn-router \ + --listen-quic 0.0.0.0:4001 \ + --listen-http 0.0.0.0:8080 \ + --cert /tmp/dkn-certs/cert.pem \ + --key /tmp/dkn-certs/key.pem +``` + +Note `0.0.0.0` to listen on all interfaces. + +### 4. Verify connectivity from Laptop B + +```bash +# Check HTTP is reachable +curl -s http://203.0.113.42:8080/v1/health + +# Check QUIC port is open (UDP) +nc -z -u 203.0.113.42 4001 && echo "open" || echo "blocked" +``` + +If the health check returns `{"status":"ok",...}`, HTTP is working. If QUIC is blocked, check firewall/NAT rules for **UDP** port 4001. + +### 5. Start the compute node on Laptop B + +```bash +./dkn-compute-node/target/release/dria-node start \ + --wallet \ + --model lfm2.5:1.2b \ + --router-url https://203.0.113.42:4001 \ + --insecure \ + --gpu-layers -1 +``` + +`--insecure` is needed because the cert is self-signed. Once the node connects: + +``` +INFO connected to router node_id=... router=https://203.0.113.42:4001 +``` + +### 6. Send requests from either laptop + +From Laptop A (or any machine that can reach the router): + +```bash +curl -s http://203.0.113.42:8080/v1/generate \ + -H "Content-Type: application/json" \ + -d '{ + "model": "lfm2.5:1.2b", + "messages": [{"role": "user", "content": "Hello from the internet!"}], + "max_tokens": 64 + }' | python3 -m json.tool +``` + +The HTTP request goes to the router, which forwards it via QUIC to the node on Laptop B, which runs inference and sends the result back. + +--- + +## Scenario 3: LAN testing (two laptops, same network) + +Same as Scenario 2 but simpler — no NAT/port-forwarding needed. + +### 1. Find Laptop A's LAN IP + +```bash +# macOS +ipconfig getifaddr en0 + +# Linux +hostname -I | awk '{print $1}' +``` + +Example: `192.168.1.100` + +### 2. Generate certs and start router on Laptop A + +```bash +export ROUTER_IP=192.168.1.100 + +mkdir -p /tmp/dkn-certs + +openssl req -x509 -newkey ec -pkeyopt ec_paramgen_curve:prime256v1 \ + -keyout /tmp/dkn-certs/key.pem -out /tmp/dkn-certs/cert.pem \ + -days 365 -nodes -subj "/CN=$ROUTER_IP" \ + -addext "subjectAltName=IP:$ROUTER_IP" + +./dkn-router/target/release/dkn-router \ + --listen-quic 0.0.0.0:4001 \ + --listen-http 0.0.0.0:8080 \ + --cert /tmp/dkn-certs/cert.pem \ + --key /tmp/dkn-certs/key.pem +``` + +### 3. Start node on Laptop B + +```bash +./dkn-compute-node/target/release/dria-node start \ + --wallet $(openssl rand -hex 32) \ + --model lfm2.5:1.2b \ + --router-url https://192.168.1.100:4001 \ + --insecure \ + --gpu-layers -1 +``` + +### 4. Send requests from Laptop A + +```bash +curl -s http://192.168.1.100:8080/v1/generate \ + -H "Content-Type: application/json" \ + -d '{ + "model": "lfm2.5:1.2b", + "messages": [{"role": "user", "content": "Hello from LAN!"}], + "max_tokens": 64 + }' | python3 -m json.tool +``` + +--- + +## Available models + +| Short name | Size | Type | Notes | +|---|---|---|---| +| `lfm2.5:1.2b` | 731 MB | text | Fastest, good for testing | +| `nanbeige:3b` | 2.4 GB | text | | +| `locooperator:4b` | 2.5 GB | text | | +| `lfm2.5-vl:1.6b` | 696 MB | vision | Rejects text-only requests are fine, rejects audio | +| `lfm2.5-audio:1.5b` | 696 MB | audio | Rejects image content | +| `lfm2:24b-a2b` | 14.4 GB | text | MoE | +| `qwen3.5:27b` | 16.7 GB | text | | +| `qwen3.5:35b-a3b` | 19.9 GB | text | MoE | + +## Environment variables + +All CLI flags can be set via env vars instead: + +| Env var | Flag | Default | +|---|---|---| +| `DRIA_WALLET` | `--wallet` | (required) | +| `DRIA_MODELS` | `--model` | (required) | +| `DRIA_ROUTER_URL` | `--router-url` | `https://router.dria.co` | +| `DRIA_GPU_LAYERS` | `--gpu-layers` | `0` | +| `DRIA_MAX_CONCURRENT` | `--max-concurrent` | `1` | +| `DRIA_DATA_DIR` | `--data-dir` | `~/.dria` | +| `DRIA_INSECURE` | `--insecure` | `false` | +| `DRIA_ROUTER_QUIC_ADDR` | `--listen-quic` | `0.0.0.0:4001` | +| `DRIA_ROUTER_HTTP_ADDR` | `--listen-http` | `0.0.0.0:8080` | +| `DRIA_ROUTER_CERT` | `--cert` | (required) | +| `DRIA_ROUTER_KEY` | `--key` | (required) | + +## Troubleshooting + +| Symptom | Cause | Fix | +|---|---|---| +| Node logs `all routers unavailable` | Can't reach router QUIC port | Check firewall allows **UDP** 4001, verify IP/port | +| Node logs `TLS error` | Cert doesn't match router hostname/IP | Regenerate cert with correct SAN, or use `--insecure` | +| `curl` to `/v1/generate` returns 503 | No nodes connected | Check node logs, ensure it says `connected to router` | +| `curl` to `/v1/generate` returns 504 | Node timeout during inference | Increase `timeout_secs` in request, or use a smaller model | +| Node logs `SHA-256 mismatch` | Corrupted download | Delete `~/.dria/models/` and restart to re-download | +| `QUIC connect failed: no initial cipher suite` | TLS/QUIC version mismatch | Ensure both router and node are built from the same branch | +| Batch request partial failures | One model not loaded | Check `/v1/models` to see what's available | + +## Verbose logging + +```bash +RUST_LOG=debug ./target/release/dkn-router ... +RUST_LOG=debug ./target/release/dria-node start ... +``` diff --git a/src/main.rs b/src/main.rs index 7c792871..ca8021cc 100644 --- a/src/main.rs +++ b/src/main.rs @@ -20,6 +20,7 @@ use identity::Identity; use models::{ModelCache, ModelDownloader, default_registry, resolve_model}; use models::registry::ModelSpec; use network::{NodeMessage, RouterMessage}; +use network::protocol::ModelType; use network::RouterConnection; use stats::NodeStats; use worker::{CompletedTask, Worker}; @@ -65,6 +66,7 @@ struct NodeContext { struct ModelLoadResult { name: String, template: String, + model_type: ModelType, result: Result<(inference::InferenceEngine, f64), error::NodeError>, } @@ -93,7 +95,7 @@ async fn run_start( let cache = ModelCache::new(config.models_dir.clone())?; // Accumulate engines and TPS per model - let mut engines: HashMap = HashMap::new(); + let mut engines: HashMap = HashMap::new(); let mut tps_map: HashMap = HashMap::new(); for model_name in &config.model_names { @@ -107,7 +109,7 @@ async fn run_start( .unwrap_or_else(|| "chatml".to_string()); tracing::info!(tps = %format!("{tps:.1}"), model = %model_name, "benchmark complete"); - engines.insert(model_name.clone(), (engine, chat_template)); + engines.insert(model_name.clone(), (engine, chat_template, spec.model_type)); tps_map.insert(model_name.clone(), tps); } @@ -209,7 +211,7 @@ async fn run_start( tps = %format!("{tps:.1}"), "model loaded successfully" ); - worker.add_engine(loaded.name.clone(), engine, loaded.template); + worker.add_engine(loaded.name.clone(), engine, loaded.template, loaded.model_type); ctx.tps.insert(loaded.name, tps); } Err(e) => { @@ -423,11 +425,12 @@ async fn handle_router_message( .chat_template .clone() .unwrap_or_else(|| "chatml".to_string()); + let model_type = entry.model_type; tracing::info!(model = %name, "spawning background model download+load"); tokio::spawn(async move { let result = download_and_load_model(&spec, &cache, gpu_layers).await; - let _ = tx.send(ModelLoadResult { name, template, result }); + let _ = tx.send(ModelLoadResult { name, template, model_type, result }); }); } } diff --git a/src/models/cache.rs b/src/models/cache.rs index 6b599386..2e4a69f9 100644 --- a/src/models/cache.rs +++ b/src/models/cache.rs @@ -90,6 +90,7 @@ mod tests { hf_file: "model.gguf".into(), sha256: None, chat_template: None, + model_type: dkn_protocol::ModelType::Text, }; // Not present initially diff --git a/src/models/mod.rs b/src/models/mod.rs index fd5e9529..7d0d7324 100644 --- a/src/models/mod.rs +++ b/src/models/mod.rs @@ -3,7 +3,7 @@ pub mod download; pub mod registry; pub mod template { - pub use dkn_protocol::{apply_chat_template, ChatMessage}; + pub use dkn_protocol::{apply_chat_template, ChatMessage, MessageContent}; } pub use cache::ModelCache; diff --git a/src/models/registry.rs b/src/models/registry.rs index 4eec6ba3..a81e3071 100644 --- a/src/models/registry.rs +++ b/src/models/registry.rs @@ -1,87 +1,90 @@ use std::collections::HashMap; -use dkn_protocol::ModelRegistryEntry; +use dkn_protocol::{ModelRegistryEntry, ModelType}; /// Specification for a model: shortname mapped to HuggingFace GGUF location. #[derive(Debug, Clone)] pub struct ModelSpec { - /// Short name used by users (e.g. "gemma3:4b") + /// Short name used by users (e.g. "lfm2.5:1.2b") pub name: String, - /// HuggingFace repository (e.g. "bartowski/gemma-3-4b-it-GGUF") + /// HuggingFace repository (e.g. "LiquidAI/LFM2.5-1.2B-Instruct-GGUF") pub hf_repo: String, - /// Filename within the repo (e.g. "gemma-3-4b-it-Q4_K_M.gguf") + /// Filename within the repo (e.g. "LFM2.5-1.2B-Instruct-Q4_K_M.gguf") pub hf_file: String, /// Expected SHA-256 hex digest for verification (None = skip verification) pub sha256: Option, /// Chat template identifier (e.g. "gemma", "llama3", "chatml") pub chat_template: Option, + /// Modality this model supports. + pub model_type: ModelType, } -/// Build the default model registry with all 9 supported models. +/// Build the default model registry with all supported models. pub fn default_registry() -> HashMap { let entries = vec![ ModelSpec { - name: "gemma3:4b".into(), - hf_repo: "bartowski/google_gemma-3-4b-it-GGUF".into(), - hf_file: "google_gemma-3-4b-it-Q4_K_M.gguf".into(), + name: "lfm2.5:1.2b".into(), + hf_repo: "LiquidAI/LFM2.5-1.2B-Instruct-GGUF".into(), + hf_file: "LFM2.5-1.2B-Instruct-Q4_K_M.gguf".into(), sha256: None, - chat_template: Some("gemma".into()), - }, - ModelSpec { - name: "gemma3:12b".into(), - hf_repo: "bartowski/google_gemma-3-12b-it-GGUF".into(), - hf_file: "google_gemma-3-12b-it-Q4_K_M.gguf".into(), - sha256: None, - chat_template: Some("gemma".into()), + chat_template: Some("chatml".into()), + model_type: ModelType::Text, }, ModelSpec { - name: "gemma3:27b".into(), - hf_repo: "bartowski/google_gemma-3-27b-it-GGUF".into(), - hf_file: "google_gemma-3-27b-it-Q4_K_M.gguf".into(), + name: "qwen3.5:35b-a3b".into(), + hf_repo: "unsloth/Qwen3.5-35B-A3B-GGUF".into(), + hf_file: "Qwen3.5-35B-A3B-UD-Q4_K_M.gguf".into(), sha256: None, - chat_template: Some("gemma".into()), + chat_template: Some("chatml".into()), + model_type: ModelType::Text, }, ModelSpec { - name: "llama3.1:8b".into(), - hf_repo: "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF".into(), - hf_file: "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf".into(), + name: "lfm2:24b-a2b".into(), + hf_repo: "LiquidAI/LFM2-24B-A2B-GGUF".into(), + hf_file: "LFM2-24B-A2B-Q4_K_M.gguf".into(), sha256: None, - chat_template: Some("llama3".into()), + chat_template: Some("chatml".into()), + model_type: ModelType::Text, }, ModelSpec { - name: "llama3.2:1b".into(), - hf_repo: "bartowski/Llama-3.2-1B-Instruct-GGUF".into(), - hf_file: "Llama-3.2-1B-Instruct-Q4_K_M.gguf".into(), + name: "lfm2.5-vl:1.6b".into(), + hf_repo: "LiquidAI/LFM2.5-VL-1.6B-GGUF".into(), + hf_file: "LFM2.5-VL-1.6B-Q4_0.gguf".into(), sha256: None, - chat_template: Some("llama3".into()), + chat_template: Some("chatml".into()), + model_type: ModelType::Vision, }, ModelSpec { - name: "llama3.3:70b".into(), - hf_repo: "bartowski/Llama-3.3-70B-Instruct-GGUF".into(), - hf_file: "Llama-3.3-70B-Instruct-Q4_K_M.gguf".into(), + name: "lfm2.5-audio:1.5b".into(), + hf_repo: "LiquidAI/LFM2.5-Audio-1.5B-GGUF".into(), + hf_file: "LFM2.5-Audio-1.5B-Q4_0.gguf".into(), sha256: None, - chat_template: Some("llama3".into()), + chat_template: Some("chatml".into()), + model_type: ModelType::Audio, }, ModelSpec { - name: "mistral-nemo:12b".into(), - hf_repo: "bartowski/Mistral-Nemo-Instruct-2407-GGUF".into(), - hf_file: "Mistral-Nemo-Instruct-2407-Q4_K_M.gguf".into(), + name: "qwen3.5:27b".into(), + hf_repo: "unsloth/Qwen3.5-27B-GGUF".into(), + hf_file: "Qwen3.5-27B-Q4_K_M.gguf".into(), sha256: None, chat_template: Some("chatml".into()), + model_type: ModelType::Text, }, ModelSpec { - name: "qwen3:8b".into(), - hf_repo: "bartowski/Qwen3-8B-GGUF".into(), - hf_file: "Qwen3-8B-Q4_K_M.gguf".into(), + name: "nanbeige:3b".into(), + hf_repo: "DevQuasar/Nanbeige.Nanbeige4.1-3B-GGUF".into(), + hf_file: "Nanbeige.Nanbeige4.1-3B.Q4_K_M.gguf".into(), sha256: None, chat_template: Some("chatml".into()), + model_type: ModelType::Text, }, ModelSpec { - name: "qwen3:32b".into(), - hf_repo: "bartowski/Qwen3-32B-GGUF".into(), - hf_file: "Qwen3-32B-Q4_K_M.gguf".into(), + name: "locooperator:4b".into(), + hf_repo: "LocoreMind/LocoOperator-4B-GGUF".into(), + hf_file: "LocoOperator-4B.Q4_K_M.gguf".into(), sha256: None, chat_template: Some("chatml".into()), + model_type: ModelType::Text, }, ]; @@ -97,6 +100,7 @@ impl ModelSpec { hf_file: entry.hf_file.clone(), sha256: None, chat_template: entry.chat_template.clone(), + model_type: entry.model_type, } } } @@ -114,29 +118,29 @@ mod tests { fn test_default_registry_has_all_models() { let reg = default_registry(); let expected = [ - "gemma3:4b", - "gemma3:12b", - "gemma3:27b", - "llama3.1:8b", - "llama3.2:1b", - "llama3.3:70b", - "mistral-nemo:12b", - "qwen3:8b", - "qwen3:32b", + "lfm2.5:1.2b", + "qwen3.5:35b-a3b", + "lfm2:24b-a2b", + "lfm2.5-vl:1.6b", + "lfm2.5-audio:1.5b", + "qwen3.5:27b", + "nanbeige:3b", + "locooperator:4b", ]; for name in &expected { assert!(reg.contains_key(*name), "missing model: {name}"); } - assert_eq!(reg.len(), 9); + assert_eq!(reg.len(), 8); } #[test] fn test_resolve_known_model() { let reg = default_registry(); - let spec = resolve_model("gemma3:4b", ®).expect("should resolve"); - assert_eq!(spec.name, "gemma3:4b"); - assert!(spec.hf_repo.contains("gemma")); + let spec = resolve_model("lfm2.5:1.2b", ®).expect("should resolve"); + assert_eq!(spec.name, "lfm2.5:1.2b"); + assert!(spec.hf_repo.contains("LFM2.5")); assert!(spec.hf_file.ends_with(".gguf")); + assert_eq!(spec.model_type, ModelType::Text); } #[test] @@ -152,6 +156,7 @@ mod tests { hf_repo: "test/repo".into(), hf_file: "model.gguf".into(), chat_template: Some("chatml".into()), + model_type: ModelType::Vision, }; let spec = ModelSpec::from_registry_entry(&entry); assert_eq!(spec.name, "test:1b"); @@ -159,5 +164,15 @@ mod tests { assert_eq!(spec.hf_file, "model.gguf"); assert!(spec.sha256.is_none()); assert_eq!(spec.chat_template, Some("chatml".into())); + assert_eq!(spec.model_type, ModelType::Vision); + } + + #[test] + fn test_model_types_correct() { + let reg = default_registry(); + assert_eq!(reg["lfm2.5-vl:1.6b"].model_type, ModelType::Vision); + assert_eq!(reg["lfm2.5-audio:1.5b"].model_type, ModelType::Audio); + assert_eq!(reg["lfm2.5:1.2b"].model_type, ModelType::Text); + assert_eq!(reg["qwen3.5:27b"].model_type, ModelType::Text); } } diff --git a/src/network/connection.rs b/src/network/connection.rs index c6c3b5f0..50141c15 100644 --- a/src/network/connection.rs +++ b/src/network/connection.rs @@ -578,10 +578,7 @@ mod tests { &RouterMessage::TaskAssignment { task_id, model: "nonexistent:1b".into(), - messages: vec![dkn_protocol::ChatMessage { - role: "user".into(), - content: "test".into(), - }], + messages: vec![dkn_protocol::ChatMessage::text("user", "test")], max_tokens: 10, temperature: 0.7, validation: None, diff --git a/src/network/protocol.rs b/src/network/protocol.rs index 1aa0e3f3..647bbb0b 100644 --- a/src/network/protocol.rs +++ b/src/network/protocol.rs @@ -1,4 +1,4 @@ pub use dkn_protocol::{ - read_framed, write_framed, Capacity, NodeMessage, RejectReason, RouterMessage, TaskStats, - ValidationRequest, + read_framed, write_framed, Capacity, ModelType, NodeMessage, RejectReason, RouterMessage, + TaskStats, ValidationRequest, }; diff --git a/src/worker.rs b/src/worker.rs index 21e5b0a3..02658310 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -10,9 +10,9 @@ use uuid::Uuid; use crate::error::NodeError; use crate::inference::{GenerateParams, InferenceEngine, InferenceResult}; -use crate::models::template::{ChatMessage, apply_chat_template}; +use crate::models::template::{ChatMessage, MessageContent, apply_chat_template}; use crate::network::protocol::{ - Capacity, NodeMessage, RejectReason, TaskStats, ValidationRequest, + Capacity, ModelType, NodeMessage, RejectReason, TaskStats, ValidationRequest, }; /// A completed inference task ready to be sent back. @@ -23,10 +23,10 @@ pub struct CompletedTask { /// Executes inference tasks with backpressure via capacity tracking. /// -/// Supports multiple models, each with its own engine and chat template. +/// Supports multiple models, each with its own engine, chat template, and modality. pub struct Worker { - /// Map of model name → (engine, chat_template). - engines: HashMap, String)>, + /// Map of model name → (engine, chat_template, model_type). + engines: HashMap, String, ModelType)>, /// Number of available inference slots (CAS-based). capacity: Arc, /// Maximum concurrent slots. @@ -38,12 +38,14 @@ pub struct Worker { impl Worker { /// Create a new worker wrapping multiple inference engines. pub fn new( - engines: HashMap, + engines: HashMap, max_concurrent: usize, ) -> Self { let engines = engines .into_iter() - .map(|(name, (engine, template))| (name, (Arc::new(engine), template))) + .map(|(name, (engine, template, model_type))| { + (name, (Arc::new(engine), template, model_type)) + }) .collect(); Worker { engines, @@ -65,11 +67,31 @@ impl Worker { temperature: f32, validation: Option, ) -> Result<(), RejectReason> { - // Look up engine + template for the requested model (fail fast before decrementing capacity) - let (engine, template) = self + // Look up engine + template + model_type for the requested model (fail fast before decrementing capacity) + let (engine, template, model_type) = self .engines .get(model) .ok_or(RejectReason::ModelNotLoaded)?; + + // Check modality: reject if messages contain image/audio parts that the model can't handle + let has_image = messages + .iter() + .any(|m| m.content.has_image()); + let has_audio = messages + .iter() + .any(|m| m.content.has_audio()); + + if has_image && *model_type != ModelType::Vision { + return Err(RejectReason::InvalidRequest( + "message contains image content but model does not support vision".into(), + )); + } + if has_audio && *model_type != ModelType::Audio { + return Err(RejectReason::InvalidRequest( + "message contains audio content but model does not support audio".into(), + )); + } + let engine = Arc::clone(engine); let template = template.clone(); @@ -150,8 +172,15 @@ impl Worker { /// Add a new model engine at runtime (for hot-swap). /// /// If a model with this name already exists, it is replaced. - pub fn add_engine(&mut self, name: String, engine: InferenceEngine, template: String) { - self.engines.insert(name, (Arc::new(engine), template)); + pub fn add_engine( + &mut self, + name: String, + engine: InferenceEngine, + template: String, + model_type: ModelType, + ) { + self.engines + .insert(name, (Arc::new(engine), template, model_type)); } /// Remove a model engine by name. Returns true if the model was present. @@ -282,13 +311,13 @@ mod tests { #[test] fn test_worker_has_model() { let worker = Worker::new(HashMap::new(), 1); - assert!(!worker.has_model("gemma3:4b")); + assert!(!worker.has_model("lfm2.5:1.2b")); } #[test] fn test_worker_remove_engine_not_present() { let mut worker = Worker::new(HashMap::new(), 1); - assert!(!worker.remove_engine("gemma3:4b")); + assert!(!worker.remove_engine("lfm2.5:1.2b")); } #[test] @@ -296,4 +325,12 @@ mod tests { let worker = Worker::new(HashMap::new(), 1); assert!(worker.model_names().is_empty()); } + + #[test] + fn test_modality_check_text_content() { + // MessageContent::Text should have no image/audio + let content = MessageContent::Text("hello".into()); + assert!(!content.has_image()); + assert!(!content.has_audio()); + } } From 54fcc292f51d254bcc3ccb973064650086b7315c Mon Sep 17 00:00:00 2001 From: andthattoo Date: Tue, 3 Mar 2026 16:03:24 +0300 Subject: [PATCH 04/57] updated registry --- src/models/registry.rs | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/models/registry.rs b/src/models/registry.rs index a81e3071..21ec3ffc 100644 --- a/src/models/registry.rs +++ b/src/models/registry.rs @@ -86,6 +86,14 @@ pub fn default_registry() -> HashMap { chat_template: Some("chatml".into()), model_type: ModelType::Text, }, + ModelSpec { + name: "qwen3.5:9b".into(), + hf_repo: "lmstudio-community/Qwen3.5-9B-GGUF".into(), + hf_file: "Qwen3.5-9B-Q4_K_M.gguf".into(), + sha256: None, + chat_template: Some("chatml".into()), + model_type: ModelType::Text, + }, ]; entries.into_iter().map(|s| (s.name.clone(), s)).collect() @@ -126,11 +134,12 @@ mod tests { "qwen3.5:27b", "nanbeige:3b", "locooperator:4b", + "qwen3.5:9b", ]; for name in &expected { assert!(reg.contains_key(*name), "missing model: {name}"); } - assert_eq!(reg.len(), 8); + assert_eq!(reg.len(), 9); } #[test] From d8fce1094c11c840f410608afd3824ebef88d48c Mon Sep 17 00:00:00 2001 From: andthattoo Date: Tue, 3 Mar 2026 16:41:48 +0300 Subject: [PATCH 05/57] Add --quant CLI flag to override default GGUF quantization Allows users to specify a quantization level (e.g. Q8_0, Q5_K_M) instead of always downloading the registry default (Q4_K_M). This avoids redundant downloads when a different quantization is already cached locally via HuggingFace hub. Also fix dkn-protocol type mismatches: add ModelType enum, MessageContent enum, and model_type field to ModelRegistryEntry. Co-Authored-By: Claude Opus 4.6 --- src/config.rs | 15 +++++++++ src/main.rs | 8 +++-- src/models/registry.rs | 65 ++++++++++++++++++++++++++++++++++++--- src/network/connection.rs | 2 +- src/worker.rs | 3 +- 5 files changed, 84 insertions(+), 9 deletions(-) diff --git a/src/config.rs b/src/config.rs index 285b2ae3..8deaf678 100644 --- a/src/config.rs +++ b/src/config.rs @@ -39,6 +39,10 @@ pub enum Command { #[arg(long, env = "DRIA_DATA_DIR")] data_dir: Option, + /// Override GGUF quantization (e.g. Q8_0, Q5_K_M, Q6_K). Defaults to the registry value (usually Q4_K_M). + #[arg(long, env = "DRIA_QUANT")] + quant: Option, + /// Skip TLS certificate verification (for development/testing) #[arg(long, env = "DRIA_INSECURE")] insecure: bool, @@ -54,6 +58,7 @@ pub struct Config { pub max_concurrent: usize, pub data_dir: PathBuf, pub models_dir: PathBuf, + pub quant: Option, pub insecure: bool, } @@ -66,6 +71,7 @@ impl Config { gpu_layers: i32, max_concurrent: usize, data_dir: Option, + quant: Option, insecure: bool, ) -> Result { // Validate wallet key @@ -120,6 +126,7 @@ impl Config { max_concurrent, data_dir, models_dir, + quant, insecure, }) } @@ -138,6 +145,7 @@ mod tests { 0, 1, Some("/tmp/dria-test".into()), + None, false, ) .unwrap(); @@ -160,6 +168,7 @@ mod tests { 0, 1, None, + None, false, ); assert!(result.is_err()); @@ -174,6 +183,7 @@ mod tests { 0, 1, None, + None, false, ); assert!(result.is_err()); @@ -188,6 +198,7 @@ mod tests { 0, 1, None, + None, false, ); assert!(result.is_err()); @@ -202,6 +213,7 @@ mod tests { 0, 0, None, + None, false, ); assert!(result.is_err()); @@ -216,6 +228,7 @@ mod tests { 0, 1, None, + None, false, ) .unwrap(); @@ -234,6 +247,7 @@ mod tests { 0, 1, None, + None, false, ); assert!(result.is_err()); @@ -248,6 +262,7 @@ mod tests { 0, 1, None, + None, true, ) .unwrap(); diff --git a/src/main.rs b/src/main.rs index ca8021cc..70e0c16d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -44,9 +44,10 @@ async fn main() -> anyhow::Result<()> { gpu_layers, max_concurrent, data_dir, + quant, insecure, } => { - run_start(wallet, model, router_url, gpu_layers, max_concurrent, data_dir, insecure).await?; + run_start(wallet, model, router_url, gpu_layers, max_concurrent, data_dir, quant, insecure).await?; } } @@ -77,10 +78,11 @@ async fn run_start( gpu_layers: i32, max_concurrent: usize, data_dir: Option, + quant: Option, insecure: bool, ) -> anyhow::Result<()> { // Parse config - let config = Config::from_start_args(wallet, model, router_url, gpu_layers, max_concurrent, data_dir, insecure)?; + let config = Config::from_start_args(wallet, model, router_url, gpu_layers, max_concurrent, data_dir, quant, insecure)?; // Create identity let identity = Identity::from_secret_hex(&config.secret_key_hex)?; @@ -99,7 +101,7 @@ async fn run_start( let mut tps_map: HashMap = HashMap::new(); for model_name in &config.model_names { - let spec = resolve_model(model_name, ®istry) + let spec = resolve_model(model_name, ®istry, config.quant.as_deref()) .ok_or_else(|| error::NodeError::Model(format!("unknown model: {model_name}")))?; let (engine, tps) = download_and_load_model(&spec, &cache, config.gpu_layers).await?; diff --git a/src/models/registry.rs b/src/models/registry.rs index 21ec3ffc..34a3b898 100644 --- a/src/models/registry.rs +++ b/src/models/registry.rs @@ -111,11 +111,40 @@ impl ModelSpec { model_type: entry.model_type, } } + + /// Return a new ModelSpec with the quantization portion of `hf_file` replaced. + /// + /// GGUF filenames follow the pattern `{ModelName}-{Quant}.gguf` + /// (e.g. `Qwen3.5-9B-Q4_K_M.gguf`). This replaces the last `-{Quant}.gguf` + /// segment with the given quantization string. + pub fn with_quant(&self, quant: &str) -> Self { + let new_file = if let Some(pos) = self.hf_file.rfind('-') { + format!("{}-{}.gguf", &self.hf_file[..pos], quant) + } else { + self.hf_file.clone() + }; + ModelSpec { + hf_file: new_file, + sha256: None, // hash no longer valid for a different quant + ..self.clone() + } + } } /// Resolve a user-provided model name to a ModelSpec from the registry. -pub fn resolve_model(name: &str, registry: &HashMap) -> Option { - registry.get(name).cloned() +/// +/// When `quant` is provided, the default quantization in the registry is +/// replaced (e.g. `Q4_K_M` → `Q8_0`). +pub fn resolve_model( + name: &str, + registry: &HashMap, + quant: Option<&str>, +) -> Option { + let spec = registry.get(name)?.clone(); + Some(match quant { + Some(q) => spec.with_quant(q), + None => spec, + }) } #[cfg(test)] @@ -145,7 +174,7 @@ mod tests { #[test] fn test_resolve_known_model() { let reg = default_registry(); - let spec = resolve_model("lfm2.5:1.2b", ®).expect("should resolve"); + let spec = resolve_model("lfm2.5:1.2b", ®, None).expect("should resolve"); assert_eq!(spec.name, "lfm2.5:1.2b"); assert!(spec.hf_repo.contains("LFM2.5")); assert!(spec.hf_file.ends_with(".gguf")); @@ -155,7 +184,7 @@ mod tests { #[test] fn test_resolve_unknown_model() { let reg = default_registry(); - assert!(resolve_model("nonexistent:1b", ®).is_none()); + assert!(resolve_model("nonexistent:1b", ®, None).is_none()); } #[test] @@ -184,4 +213,32 @@ mod tests { assert_eq!(reg["lfm2.5:1.2b"].model_type, ModelType::Text); assert_eq!(reg["qwen3.5:27b"].model_type, ModelType::Text); } + + #[test] + fn test_with_quant_substitutes_suffix() { + let reg = default_registry(); + let spec = ®["qwen3.5:9b"]; + assert_eq!(spec.hf_file, "Qwen3.5-9B-Q4_K_M.gguf"); + + let q8 = spec.with_quant("Q8_0"); + assert_eq!(q8.hf_file, "Qwen3.5-9B-Q8_0.gguf"); + // Everything else stays the same + assert_eq!(q8.name, spec.name); + assert_eq!(q8.hf_repo, spec.hf_repo); + assert_eq!(q8.model_type, spec.model_type); + } + + #[test] + fn test_resolve_model_with_quant_override() { + let reg = default_registry(); + let spec = resolve_model("qwen3.5:9b", ®, Some("Q8_0")).unwrap(); + assert_eq!(spec.hf_file, "Qwen3.5-9B-Q8_0.gguf"); + } + + #[test] + fn test_resolve_model_without_quant_keeps_default() { + let reg = default_registry(); + let spec = resolve_model("qwen3.5:9b", ®, None).unwrap(); + assert_eq!(spec.hf_file, "Qwen3.5-9B-Q4_K_M.gguf"); + } } diff --git a/src/network/connection.rs b/src/network/connection.rs index 50141c15..6ab40f5e 100644 --- a/src/network/connection.rs +++ b/src/network/connection.rs @@ -578,7 +578,7 @@ mod tests { &RouterMessage::TaskAssignment { task_id, model: "nonexistent:1b".into(), - messages: vec![dkn_protocol::ChatMessage::text("user", "test")], + messages: vec![dkn_protocol::ChatMessage { role: "user".into(), content: "test".into() }], max_tokens: 10, temperature: 0.7, validation: None, diff --git a/src/worker.rs b/src/worker.rs index 02658310..267ba259 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -10,7 +10,7 @@ use uuid::Uuid; use crate::error::NodeError; use crate::inference::{GenerateParams, InferenceEngine, InferenceResult}; -use crate::models::template::{ChatMessage, MessageContent, apply_chat_template}; +use crate::models::template::{ChatMessage, apply_chat_template}; use crate::network::protocol::{ Capacity, ModelType, NodeMessage, RejectReason, TaskStats, ValidationRequest, }; @@ -328,6 +328,7 @@ mod tests { #[test] fn test_modality_check_text_content() { + use crate::models::template::MessageContent; // MessageContent::Text should have no image/audio let content = MessageContent::Text("hello".into()); assert!(!content.has_image()); From 0491fa401f0680528cdefe43705881e08aa7f055 Mon Sep 17 00:00:00 2001 From: andthattoo Date: Tue, 3 Mar 2026 22:11:32 +0300 Subject: [PATCH 06/57] Add ALPN protocol negotiation for QUIC and suppress clippy too_many_arguments Router requires ALPN "dkn" on QUIC connections. Without it, the handshake fails with "peer doesn't support any known protocol". Set ALPN on the client config and on all mock server configs in tests. --- src/config.rs | 1 + src/main.rs | 1 + src/network/connection.rs | 12 ++++++++---- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/config.rs b/src/config.rs index 8deaf678..240f906a 100644 --- a/src/config.rs +++ b/src/config.rs @@ -64,6 +64,7 @@ pub struct Config { impl Config { /// Create a Config from the `start` subcommand arguments. + #[allow(clippy::too_many_arguments)] pub fn from_start_args( wallet: String, model: String, diff --git a/src/main.rs b/src/main.rs index 70e0c16d..7f2f3f33 100644 --- a/src/main.rs +++ b/src/main.rs @@ -71,6 +71,7 @@ struct ModelLoadResult { result: Result<(inference::InferenceEngine, f64), error::NodeError>, } +#[allow(clippy::too_many_arguments)] async fn run_start( wallet: String, model: String, diff --git a/src/network/connection.rs b/src/network/connection.rs index 6ab40f5e..f431510f 100644 --- a/src/network/connection.rs +++ b/src/network/connection.rs @@ -90,7 +90,7 @@ impl RouterConnection { // --------------------------------------------------------------------------- fn build_client_config(insecure: bool) -> Result { - let crypto = if insecure { + let mut crypto = if insecure { rustls::ClientConfig::builder() .dangerous() .with_custom_certificate_verifier(Arc::new(SkipServerVerification)) @@ -104,6 +104,7 @@ fn build_client_config(insecure: bool) -> Result { .with_root_certificates(root_store) .with_no_client_auth() }; + crypto.alpn_protocols = vec![b"dkn".to_vec()]; let mut transport = TransportConfig::default(); transport.keep_alive_interval(Some(Duration::from_secs(20))); @@ -271,10 +272,11 @@ mod tests { rustls::pki_types::PrivatePkcs8KeyDer::from(cert.key_pair.serialize_der()); // Build server config - let server_crypto = rustls::ServerConfig::builder() + let mut server_crypto = rustls::ServerConfig::builder() .with_no_client_auth() .with_single_cert(vec![cert_der.clone()], key_der.into()) .unwrap(); + server_crypto.alpn_protocols = vec![b"dkn".to_vec()]; let mut server_config = quinn::ServerConfig::with_crypto(Arc::new( quinn::crypto::rustls::QuicServerConfig::try_from(server_crypto).unwrap(), @@ -397,10 +399,11 @@ mod tests { let key_der = rustls::pki_types::PrivatePkcs8KeyDer::from(cert.key_pair.serialize_der()); - let server_crypto = rustls::ServerConfig::builder() + let mut server_crypto = rustls::ServerConfig::builder() .with_no_client_auth() .with_single_cert(vec![cert_der.clone()], key_der.into()) .unwrap(); + server_crypto.alpn_protocols = vec![b"dkn".to_vec()]; let server_config = quinn::ServerConfig::with_crypto(Arc::new( quinn::crypto::rustls::QuicServerConfig::try_from(server_crypto).unwrap(), @@ -532,10 +535,11 @@ mod tests { let key_der = rustls::pki_types::PrivatePkcs8KeyDer::from(cert.key_pair.serialize_der()); - let server_crypto = rustls::ServerConfig::builder() + let mut server_crypto = rustls::ServerConfig::builder() .with_no_client_auth() .with_single_cert(vec![cert_der.clone()], key_der.into()) .unwrap(); + server_crypto.alpn_protocols = vec![b"dkn".to_vec()]; let mut server_config = quinn::ServerConfig::with_crypto(Arc::new( quinn::crypto::rustls::QuicServerConfig::try_from(server_crypto).unwrap(), From 68d1bc77a7231b8a50e65d6a6df8ce98a54fd86f Mon Sep 17 00:00:00 2001 From: andthattoo Date: Tue, 3 Mar 2026 22:32:09 +0300 Subject: [PATCH 07/57] gracefull shutdown after reconnect --- src/main.rs | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/main.rs b/src/main.rs index 7f2f3f33..c3034d86 100644 --- a/src/main.rs +++ b/src/main.rs @@ -194,14 +194,26 @@ async fn run_start( if let Some(ref conn) = connection { conn.close(); } - connection = try_reconnect(&ctx, worker.capacity()).await; + connection = tokio::select! { + result = try_reconnect(&ctx, worker.capacity()) => result, + _ = tokio::signal::ctrl_c() => { + tracing::info!("shutdown signal received during reconnect"); + break; + } + }; } Event::RouterMsg(Err(e)) => { tracing::warn!(%e, "router communication error, attempting reconnect"); if let Some(ref conn) = connection { conn.close(); } - connection = try_reconnect(&ctx, worker.capacity()).await; + connection = tokio::select! { + result = try_reconnect(&ctx, worker.capacity()) => result, + _ = tokio::signal::ctrl_c() => { + tracing::info!("shutdown signal received during reconnect"); + break; + } + }; } Event::TaskDone(completed) => { handle_completed_task(completed, &mut connection, &ctx.stats).await; From b39c75d0eaa1d75fc05e1868cbc32e60855059c0 Mon Sep 17 00:00:00 2001 From: andthattoo Date: Tue, 3 Mar 2026 22:58:24 +0300 Subject: [PATCH 08/57] Add SSE streaming support for inference tasks Refactor RouterConnection to use mpsc write channel for concurrent sends. Extend Worker with streaming path: run_inference_streaming sends StreamToken messages per token via sync_channel bridge, then StreamEnd/StreamError on completion. Update main.rs to pass stream flag and connection sender to worker. --- src/main.rs | 44 +++++++----- src/network/connection.rs | 48 ++++++++----- src/worker.rs | 143 ++++++++++++++++++++++++++++++++++++-- 3 files changed, 195 insertions(+), 40 deletions(-) diff --git a/src/main.rs b/src/main.rs index c3034d86..1b050c7a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -216,7 +216,7 @@ async fn run_start( }; } Event::TaskDone(completed) => { - handle_completed_task(completed, &mut connection, &ctx.stats).await; + handle_completed_task(completed, &connection, &ctx.stats); } Event::ModelLoaded(loaded) => { match loaded.result { @@ -252,7 +252,7 @@ async fn run_start( loop { tokio::select! { Some(completed) = worker.next_completed() => { - handle_completed_task(completed, &mut connection, &ctx.stats).await; + handle_completed_task(completed, &connection, &ctx.stats); } _ = tokio::time::sleep_until(drain_deadline) => { tracing::warn!("drain timeout reached, dropping remaining tasks"); @@ -362,9 +362,15 @@ async fn handle_router_message( max_tokens, temperature, validation, + stream, } => { - tracing::info!(%task_id, %model, "received task assignment"); - match worker.try_accept(task_id, &model, messages, max_tokens, temperature, validation) + tracing::info!(%task_id, %model, stream, "received task assignment"); + let stream_tx = if stream { + connection.as_ref().map(|conn| conn.sender()) + } else { + None + }; + match worker.try_accept(task_id, &model, messages, max_tokens, temperature, validation, stream, stream_tx) { Ok(()) => { tracing::debug!(%task_id, "task accepted"); @@ -372,9 +378,9 @@ async fn handle_router_message( Err(reason) => { ctx.stats.record_rejected(); tracing::warn!(%task_id, ?reason, "task rejected"); - if let Some(ref mut conn) = connection { + if let Some(ref conn) = connection { let reject = NodeMessage::TaskRejected { task_id, reason }; - if let Err(e) = conn.send(&reject).await { + if let Err(e) = conn.send(reject) { tracing::error!(%e, "failed to send rejection"); } } @@ -383,14 +389,14 @@ async fn handle_router_message( } RouterMessage::Ping => { tracing::debug!("received ping"); - if let Some(ref mut conn) = connection { + if let Some(ref conn) = connection { let status = NodeMessage::StatusUpdate { models: worker.model_names(), capacity: worker.capacity(), version: env!("CARGO_PKG_VERSION").to_string(), stats: Some(ctx.stats.snapshot()), }; - if let Err(e) = conn.send(&status).await { + if let Err(e) = conn.send(status) { tracing::error!(%e, "failed to send status update"); } } @@ -398,13 +404,13 @@ async fn handle_router_message( RouterMessage::Challenge { challenge } => { tracing::debug!("received challenge, signing response"); let (sig, recid) = ctx.identity.sign(&challenge); - if let Some(ref mut conn) = connection { + if let Some(ref conn) = connection { let response = NodeMessage::ChallengeResponse { challenge, signature: sig.serialize().to_vec(), recovery_id: recid.serialize(), }; - if let Err(e) = conn.send(&response).await { + if let Err(e) = conn.send(response) { tracing::error!(%e, "failed to send challenge response"); } } @@ -499,21 +505,25 @@ async fn download_and_load_model( } /// Handle a completed inference task: send result or log if offline. -async fn handle_completed_task( +fn handle_completed_task( completed: CompletedTask, - connection: &mut Option, + connection: &Option, stats: &NodeStats, ) { match completed.result { - Ok(ref msg) => { - let tokens = match msg { + Ok(msg) => { + let tokens = match &msg { NodeMessage::TaskResult { stats: ts, .. } => ts.tokens_generated, _ => 0, }; stats.record_completed(tokens); - tracing::info!(task_id = %completed.task_id, "task completed"); - if let Some(ref mut conn) = connection { - if let Err(e) = conn.send(msg).await { + tracing::info!(task_id = %completed.task_id, stream = completed.stream, "task completed"); + if completed.stream { + // Streaming tasks already sent tokens inline; nothing more to send. + return; + } + if let Some(ref conn) = connection { + if let Err(e) = conn.send(msg) { tracing::error!(%e, task_id = %completed.task_id, "failed to send result"); } } else { diff --git a/src/network/connection.rs b/src/network/connection.rs index f431510f..57305caa 100644 --- a/src/network/connection.rs +++ b/src/network/connection.rs @@ -7,6 +7,7 @@ use quinn::{ClientConfig, Endpoint, TransportConfig}; use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}; use rustls::pki_types::{CertificateDer, ServerName, UnixTime}; use rustls::DigitallySignedStruct; +use tokio::sync::mpsc; use crate::error::NodeError; use crate::identity::Identity; @@ -19,10 +20,10 @@ pub struct RouterConnection { endpoint: Endpoint, /// The underlying QUIC connection. connection: quinn::Connection, - /// Send half of the bi-directional stream. - send: quinn::SendStream, /// Receive half of the bi-directional stream. recv: quinn::RecvStream, + /// Channel for outgoing messages (drained by background write task). + outgoing_tx: mpsc::UnboundedSender, /// Assigned node ID from the router. pub node_id: String, } @@ -59,18 +60,36 @@ impl RouterConnection { let node_id = authenticate(&mut send, &mut recv, identity, models, tps, capacity).await?; tracing::info!(%node_id, "authenticated with router"); + // Spawn background write task that drains outgoing channel + let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded_channel::(); + tokio::spawn(async move { + while let Some(msg) = outgoing_rx.recv().await { + if let Err(e) = write_framed(&mut send, &msg).await { + tracing::error!(%e, "write task: failed to send message"); + break; + } + } + }); + Ok(RouterConnection { endpoint, connection, - send, recv, + outgoing_tx, node_id, }) } - /// Send a message to the router. - pub async fn send(&mut self, msg: &NodeMessage) -> Result<(), NodeError> { - Ok(write_framed(&mut self.send, msg).await?) + /// Send a message to the router (non-blocking, queues to write task). + pub fn send(&self, msg: NodeMessage) -> Result<(), NodeError> { + self.outgoing_tx + .send(msg) + .map_err(|_| NodeError::Network("write channel closed".into())) + } + + /// Get a clone of the outgoing sender for concurrent streaming use. + pub fn sender(&self) -> mpsc::UnboundedSender { + self.outgoing_tx.clone() } /// Receive a message from the router. Returns `None` on clean stream close. @@ -477,18 +496,17 @@ mod tests { assert_eq!(conn.node_id, "node-42"); - // Receive ping from router + // Receive ping from router (recv needs &mut) let msg = conn.recv().await.unwrap().unwrap(); assert!(matches!(msg, RouterMessage::Ping)); - // Send status update - conn.send(&NodeMessage::StatusUpdate { + // Send status update (send is &self via channel) + conn.send(NodeMessage::StatusUpdate { models: vec!["gemma3:4b".into()], capacity: Capacity { free: 2, max: 4 }, version: env!("CARGO_PKG_VERSION").to_string(), stats: None, }) - .await .unwrap(); rx.await.expect("server did not signal completion"); @@ -586,6 +604,7 @@ mod tests { max_tokens: 10, temperature: 0.7, validation: None, + stream: false, }, ) .await @@ -628,13 +647,12 @@ mod tests { let msg = conn.recv().await.unwrap().unwrap(); assert!(matches!(msg, RouterMessage::Ping)); - conn.send(&NodeMessage::StatusUpdate { + conn.send(NodeMessage::StatusUpdate { models: vec!["gemma3:4b".into()], capacity: Capacity { free: 1, max: 1 }, version: env!("CARGO_PKG_VERSION").to_string(), stats: None, }) - .await .unwrap(); // Receive task assignment → we just forward to test; in real code the worker handles it @@ -642,11 +660,10 @@ mod tests { match task_msg { RouterMessage::TaskAssignment { task_id, .. } => { // Reject: model not loaded (only "gemma3:4b" is listed but task asks for "nonexistent:1b") - conn.send(&NodeMessage::TaskRejected { + conn.send(NodeMessage::TaskRejected { task_id, reason: dkn_protocol::RejectReason::ModelNotLoaded, }) - .await .unwrap(); } _ => panic!("expected TaskAssignment"), @@ -795,7 +812,7 @@ mod tests { assert!(matches!(msg, RouterMessage::Ping)); // Send status update with stats - conn.send(&NodeMessage::StatusUpdate { + conn.send(NodeMessage::StatusUpdate { models: vec!["gemma3:4b".into()], capacity: Capacity { free: 1, max: 1 }, version: env!("CARGO_PKG_VERSION").to_string(), @@ -807,7 +824,6 @@ mod tests { uptime_secs: 600, }), }) - .await .unwrap(); rx.await.expect("server did not signal completion"); diff --git a/src/worker.rs b/src/worker.rs index 267ba259..4fde9a7d 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -5,6 +5,7 @@ use std::sync::Arc; use futures::stream::FuturesUnordered; use futures::StreamExt; +use tokio::sync::mpsc; use tokio::task::JoinHandle; use uuid::Uuid; @@ -19,6 +20,8 @@ use crate::network::protocol::{ pub struct CompletedTask { pub task_id: Uuid, pub result: Result, + /// Whether this task was streamed (tokens already forwarded inline). + pub stream: bool, } /// Executes inference tasks with backpressure via capacity tracking. @@ -58,6 +61,9 @@ impl Worker { /// Try to accept a task. Returns `Err(RejectReason)` if the task cannot be accepted. /// /// On success, spawns inference in a blocking thread and returns immediately. + /// When `stream` is true and `stream_tx` is provided, tokens are forwarded + /// inline via the connection's outgoing channel. + #[allow(clippy::too_many_arguments)] pub fn try_accept( &self, task_id: Uuid, @@ -66,6 +72,8 @@ impl Worker { max_tokens: u32, temperature: f32, validation: Option, + stream: bool, + stream_tx: Option>, ) -> Result<(), RejectReason> { // Look up engine + template + model_type for the requested model (fail fast before decrementing capacity) let (engine, template, model_type) = self @@ -125,14 +133,47 @@ impl Worker { let capacity = Arc::clone(&self.capacity); - let handle = tokio::task::spawn_blocking(move || { - let result = run_inference(&engine, &template, messages, ¶ms, task_id); - // Release capacity slot regardless of outcome - capacity.fetch_add(1, Ordering::Release); - result - }); + if stream { + if let Some(conn_tx) = stream_tx { + // Bridge from blocking thread to async: use std sync_channel + let (sync_tx, sync_rx) = std::sync::mpsc::sync_channel::(32); + + // Async forwarder: reads from sync_rx, sends to connection channel + tokio::spawn(async move { + // sync_rx.recv() blocks, so we wrap in spawn_blocking to keep async runtime happy + loop { + let rx = sync_rx.try_recv(); + match rx { + Ok(msg) => { + if conn_tx.send(msg).is_err() { + break; // connection gone + } + } + Err(std::sync::mpsc::TryRecvError::Empty) => { + tokio::time::sleep(std::time::Duration::from_millis(1)).await; + } + Err(std::sync::mpsc::TryRecvError::Disconnected) => break, + } + } + }); + + let handle = tokio::task::spawn_blocking(move || { + let result = + run_inference_streaming(&engine, &template, messages, ¶ms, task_id, sync_tx); + capacity.fetch_add(1, Ordering::Release); + result + }); + self.in_flight.push(handle); + } + } else { + let handle = tokio::task::spawn_blocking(move || { + let result = run_inference(&engine, &template, messages, ¶ms, task_id); + capacity.fetch_add(1, Ordering::Release); + result + }); + self.in_flight.push(handle); + } - self.in_flight.push(handle); Ok(()) } @@ -210,14 +251,78 @@ fn run_inference( Ok(result) => CompletedTask { task_id, result: Ok(build_task_result(task_id, result)), + stream: false, }, Err(e) => CompletedTask { task_id, result: Err(e), + stream: false, }, } } +/// Run streaming inference: sends tokens via `token_tx` as they're generated. +fn run_inference_streaming( + engine: &InferenceEngine, + template: &str, + messages: Vec, + params: &GenerateParams, + task_id: Uuid, + token_tx: std::sync::mpsc::SyncSender, +) -> CompletedTask { + let prompt = apply_chat_template(template, &messages); + + let tx = token_tx.clone(); + let result = engine.generate(&prompt, params, move |stream_token| { + let msg = NodeMessage::StreamToken { + task_id, + token: stream_token.text, + index: stream_token.index as u32, + }; + if tx.send(msg).is_err() { + // Receiver dropped (connection lost), stop generation + return ControlFlow::Break(()); + } + ControlFlow::Continue(()) + }); + + match result { + Ok(result) => { + // Send StreamEnd + let end_msg = NodeMessage::StreamEnd { + task_id, + text: result.text.clone(), + stats: TaskStats { + tokens_generated: result.tokens_generated, + prompt_tokens: result.prompt_tokens, + generation_time_ms: result.generation_time_ms, + tokens_per_second: result.tokens_per_second, + }, + proof: result.proof.clone(), + }; + let _ = token_tx.send(end_msg); + CompletedTask { + task_id, + result: Ok(build_task_result(task_id, result)), + stream: true, + } + } + Err(e) => { + // Send StreamError + let err_msg = NodeMessage::StreamError { + task_id, + error: e.to_string(), + }; + let _ = token_tx.send(err_msg); + CompletedTask { + task_id, + result: Err(e), + stream: true, + } + } + } +} + fn build_task_result(task_id: Uuid, result: InferenceResult) -> NodeMessage { NodeMessage::TaskResult { task_id, @@ -295,8 +400,10 @@ mod tests { let completed = CompletedTask { task_id: Uuid::nil(), result: Ok(msg), + stream: false, }; assert!(completed.result.is_ok()); + assert!(!completed.stream); } #[test] @@ -304,10 +411,32 @@ mod tests { let completed = CompletedTask { task_id: Uuid::nil(), result: Err(NodeError::Inference("test error".into())), + stream: false, }; assert!(completed.result.is_err()); } + #[test] + fn test_completed_task_streaming() { + let completed = CompletedTask { + task_id: Uuid::nil(), + result: Ok(NodeMessage::TaskResult { + task_id: Uuid::nil(), + text: "streamed".into(), + stats: TaskStats { + tokens_generated: 3, + prompt_tokens: 2, + generation_time_ms: 30, + tokens_per_second: 100.0, + }, + proof: None, + }), + stream: true, + }; + assert!(completed.stream); + assert!(completed.result.is_ok()); + } + #[test] fn test_worker_has_model() { let worker = Worker::new(HashMap::new(), 1); From 96159f120cd88944aacf47ae526b228d6cf9132a Mon Sep 17 00:00:00 2001 From: andthattoo Date: Tue, 3 Mar 2026 23:42:41 +0300 Subject: [PATCH 09/57] Use GGUF-embedded chat templates instead of hardcoded ones Add InferenceEngine::apply_template() which extracts the Jinja2 chat template from GGUF metadata and applies it via llama.cpp's built-in engine. Remove all hardcoded template plumbing: chat_template field from ModelSpec, template parameter from Worker/add_engine/run_inference, and the models::template re-export module. Co-Authored-By: Claude Opus 4.6 --- src/inference/engine.rs | 20 ++++++++++++- src/main.rs | 17 +++-------- src/models/cache.rs | 1 - src/models/mod.rs | 4 --- src/models/registry.rs | 14 --------- src/network/protocol.rs | 4 +-- src/worker.rs | 65 ++++++++++++++++++++++++----------------- 7 files changed, 63 insertions(+), 62 deletions(-) diff --git a/src/inference/engine.rs b/src/inference/engine.rs index b65363a8..22b3dea4 100644 --- a/src/inference/engine.rs +++ b/src/inference/engine.rs @@ -6,10 +6,12 @@ use llama_cpp_2::context::params::LlamaContextParams; use llama_cpp_2::llama_backend::LlamaBackend; use llama_cpp_2::llama_batch::LlamaBatch; use llama_cpp_2::model::params::LlamaModelParams; -use llama_cpp_2::model::{AddBos, LlamaModel}; +use llama_cpp_2::model::{AddBos, LlamaChatMessage, LlamaModel}; use llama_cpp_2::sampling::LlamaSampler; use llama_cpp_2::token::LlamaToken; +use dkn_protocol::ChatMessage; + use crate::error::NodeError; use crate::identity::sha256hash; use dkn_protocol::{InferenceProof, TokenLogprob}; @@ -101,6 +103,22 @@ impl InferenceEngine { self.gpu_layers } + /// Apply the GGUF-embedded chat template to produce a formatted prompt string. + pub fn apply_template(&self, messages: &[ChatMessage]) -> Result { + let template = self + .model + .chat_template(None) + .map_err(|e| NodeError::Inference(format!("no chat template in model: {e}")))?; + let llama_messages: Vec = messages + .iter() + .map(|m| LlamaChatMessage::new(m.role.clone(), m.content.to_string())) + .collect::>() + .map_err(|e| NodeError::Inference(format!("invalid chat message: {e}")))?; + self.model + .apply_chat_template(&template, &llama_messages, true) + .map_err(|e| NodeError::Inference(format!("failed to apply chat template: {e}"))) + } + /// Generate text from a prompt. /// /// `on_token` is called for each generated token. Return `ControlFlow::Break(())` diff --git a/src/main.rs b/src/main.rs index 1b050c7a..823a8ad1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -66,7 +66,6 @@ struct NodeContext { /// Result of a background model download + load operation. struct ModelLoadResult { name: String, - template: String, model_type: ModelType, result: Result<(inference::InferenceEngine, f64), error::NodeError>, } @@ -98,7 +97,7 @@ async fn run_start( let cache = ModelCache::new(config.models_dir.clone())?; // Accumulate engines and TPS per model - let mut engines: HashMap = HashMap::new(); + let mut engines: HashMap = HashMap::new(); let mut tps_map: HashMap = HashMap::new(); for model_name in &config.model_names { @@ -106,13 +105,9 @@ async fn run_start( .ok_or_else(|| error::NodeError::Model(format!("unknown model: {model_name}")))?; let (engine, tps) = download_and_load_model(&spec, &cache, config.gpu_layers).await?; - let chat_template = spec - .chat_template - .clone() - .unwrap_or_else(|| "chatml".to_string()); tracing::info!(tps = %format!("{tps:.1}"), model = %model_name, "benchmark complete"); - engines.insert(model_name.clone(), (engine, chat_template, spec.model_type)); + engines.insert(model_name.clone(), (engine, spec.model_type)); tps_map.insert(model_name.clone(), tps); } @@ -226,7 +221,7 @@ async fn run_start( tps = %format!("{tps:.1}"), "model loaded successfully" ); - worker.add_engine(loaded.name.clone(), engine, loaded.template, loaded.model_type); + worker.add_engine(loaded.name.clone(), engine, loaded.model_type); ctx.tps.insert(loaded.name, tps); } Err(e) => { @@ -442,16 +437,12 @@ async fn handle_router_message( let gpu_layers = ctx.config.gpu_layers; let tx = model_tx.clone(); let name = entry.name.clone(); - let template = entry - .chat_template - .clone() - .unwrap_or_else(|| "chatml".to_string()); let model_type = entry.model_type; tracing::info!(model = %name, "spawning background model download+load"); tokio::spawn(async move { let result = download_and_load_model(&spec, &cache, gpu_layers).await; - let _ = tx.send(ModelLoadResult { name, template, model_type, result }); + let _ = tx.send(ModelLoadResult { name, model_type, result }); }); } } diff --git a/src/models/cache.rs b/src/models/cache.rs index 2e4a69f9..1ee9d946 100644 --- a/src/models/cache.rs +++ b/src/models/cache.rs @@ -89,7 +89,6 @@ mod tests { hf_repo: "test/repo".into(), hf_file: "model.gguf".into(), sha256: None, - chat_template: None, model_type: dkn_protocol::ModelType::Text, }; diff --git a/src/models/mod.rs b/src/models/mod.rs index 7d0d7324..30886bfe 100644 --- a/src/models/mod.rs +++ b/src/models/mod.rs @@ -2,10 +2,6 @@ pub mod cache; pub mod download; pub mod registry; -pub mod template { - pub use dkn_protocol::{apply_chat_template, ChatMessage, MessageContent}; -} - pub use cache::ModelCache; pub use download::ModelDownloader; pub use registry::{default_registry, resolve_model}; diff --git a/src/models/registry.rs b/src/models/registry.rs index 34a3b898..99f474db 100644 --- a/src/models/registry.rs +++ b/src/models/registry.rs @@ -13,8 +13,6 @@ pub struct ModelSpec { pub hf_file: String, /// Expected SHA-256 hex digest for verification (None = skip verification) pub sha256: Option, - /// Chat template identifier (e.g. "gemma", "llama3", "chatml") - pub chat_template: Option, /// Modality this model supports. pub model_type: ModelType, } @@ -27,7 +25,6 @@ pub fn default_registry() -> HashMap { hf_repo: "LiquidAI/LFM2.5-1.2B-Instruct-GGUF".into(), hf_file: "LFM2.5-1.2B-Instruct-Q4_K_M.gguf".into(), sha256: None, - chat_template: Some("chatml".into()), model_type: ModelType::Text, }, ModelSpec { @@ -35,7 +32,6 @@ pub fn default_registry() -> HashMap { hf_repo: "unsloth/Qwen3.5-35B-A3B-GGUF".into(), hf_file: "Qwen3.5-35B-A3B-UD-Q4_K_M.gguf".into(), sha256: None, - chat_template: Some("chatml".into()), model_type: ModelType::Text, }, ModelSpec { @@ -43,7 +39,6 @@ pub fn default_registry() -> HashMap { hf_repo: "LiquidAI/LFM2-24B-A2B-GGUF".into(), hf_file: "LFM2-24B-A2B-Q4_K_M.gguf".into(), sha256: None, - chat_template: Some("chatml".into()), model_type: ModelType::Text, }, ModelSpec { @@ -51,7 +46,6 @@ pub fn default_registry() -> HashMap { hf_repo: "LiquidAI/LFM2.5-VL-1.6B-GGUF".into(), hf_file: "LFM2.5-VL-1.6B-Q4_0.gguf".into(), sha256: None, - chat_template: Some("chatml".into()), model_type: ModelType::Vision, }, ModelSpec { @@ -59,7 +53,6 @@ pub fn default_registry() -> HashMap { hf_repo: "LiquidAI/LFM2.5-Audio-1.5B-GGUF".into(), hf_file: "LFM2.5-Audio-1.5B-Q4_0.gguf".into(), sha256: None, - chat_template: Some("chatml".into()), model_type: ModelType::Audio, }, ModelSpec { @@ -67,7 +60,6 @@ pub fn default_registry() -> HashMap { hf_repo: "unsloth/Qwen3.5-27B-GGUF".into(), hf_file: "Qwen3.5-27B-Q4_K_M.gguf".into(), sha256: None, - chat_template: Some("chatml".into()), model_type: ModelType::Text, }, ModelSpec { @@ -75,7 +67,6 @@ pub fn default_registry() -> HashMap { hf_repo: "DevQuasar/Nanbeige.Nanbeige4.1-3B-GGUF".into(), hf_file: "Nanbeige.Nanbeige4.1-3B.Q4_K_M.gguf".into(), sha256: None, - chat_template: Some("chatml".into()), model_type: ModelType::Text, }, ModelSpec { @@ -83,7 +74,6 @@ pub fn default_registry() -> HashMap { hf_repo: "LocoreMind/LocoOperator-4B-GGUF".into(), hf_file: "LocoOperator-4B.Q4_K_M.gguf".into(), sha256: None, - chat_template: Some("chatml".into()), model_type: ModelType::Text, }, ModelSpec { @@ -91,7 +81,6 @@ pub fn default_registry() -> HashMap { hf_repo: "lmstudio-community/Qwen3.5-9B-GGUF".into(), hf_file: "Qwen3.5-9B-Q4_K_M.gguf".into(), sha256: None, - chat_template: Some("chatml".into()), model_type: ModelType::Text, }, ]; @@ -107,7 +96,6 @@ impl ModelSpec { hf_repo: entry.hf_repo.clone(), hf_file: entry.hf_file.clone(), sha256: None, - chat_template: entry.chat_template.clone(), model_type: entry.model_type, } } @@ -193,7 +181,6 @@ mod tests { name: "test:1b".into(), hf_repo: "test/repo".into(), hf_file: "model.gguf".into(), - chat_template: Some("chatml".into()), model_type: ModelType::Vision, }; let spec = ModelSpec::from_registry_entry(&entry); @@ -201,7 +188,6 @@ mod tests { assert_eq!(spec.hf_repo, "test/repo"); assert_eq!(spec.hf_file, "model.gguf"); assert!(spec.sha256.is_none()); - assert_eq!(spec.chat_template, Some("chatml".into())); assert_eq!(spec.model_type, ModelType::Vision); } diff --git a/src/network/protocol.rs b/src/network/protocol.rs index 647bbb0b..9f53ef2b 100644 --- a/src/network/protocol.rs +++ b/src/network/protocol.rs @@ -1,4 +1,4 @@ pub use dkn_protocol::{ - read_framed, write_framed, Capacity, ModelType, NodeMessage, RejectReason, RouterMessage, - TaskStats, ValidationRequest, + read_framed, write_framed, Capacity, ChatMessage, ModelType, NodeMessage, RejectReason, + RouterMessage, TaskStats, ValidationRequest, }; diff --git a/src/worker.rs b/src/worker.rs index 4fde9a7d..aa0eb829 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -11,9 +11,8 @@ use uuid::Uuid; use crate::error::NodeError; use crate::inference::{GenerateParams, InferenceEngine, InferenceResult}; -use crate::models::template::{ChatMessage, apply_chat_template}; use crate::network::protocol::{ - Capacity, ModelType, NodeMessage, RejectReason, TaskStats, ValidationRequest, + Capacity, ChatMessage, ModelType, NodeMessage, RejectReason, TaskStats, ValidationRequest, }; /// A completed inference task ready to be sent back. @@ -26,10 +25,10 @@ pub struct CompletedTask { /// Executes inference tasks with backpressure via capacity tracking. /// -/// Supports multiple models, each with its own engine, chat template, and modality. +/// Supports multiple models, each with its own engine and modality. pub struct Worker { - /// Map of model name → (engine, chat_template, model_type). - engines: HashMap, String, ModelType)>, + /// Map of model name → (engine, model_type). + engines: HashMap, ModelType)>, /// Number of available inference slots (CAS-based). capacity: Arc, /// Maximum concurrent slots. @@ -41,14 +40,12 @@ pub struct Worker { impl Worker { /// Create a new worker wrapping multiple inference engines. pub fn new( - engines: HashMap, + engines: HashMap, max_concurrent: usize, ) -> Self { let engines = engines .into_iter() - .map(|(name, (engine, template, model_type))| { - (name, (Arc::new(engine), template, model_type)) - }) + .map(|(name, (engine, model_type))| (name, (Arc::new(engine), model_type))) .collect(); Worker { engines, @@ -75,8 +72,8 @@ impl Worker { stream: bool, stream_tx: Option>, ) -> Result<(), RejectReason> { - // Look up engine + template + model_type for the requested model (fail fast before decrementing capacity) - let (engine, template, model_type) = self + // Look up engine + model_type for the requested model (fail fast before decrementing capacity) + let (engine, model_type) = self .engines .get(model) .ok_or(RejectReason::ModelNotLoaded)?; @@ -101,7 +98,6 @@ impl Worker { } let engine = Arc::clone(engine); - let template = template.clone(); // Try to decrement capacity (CAS loop) loop { @@ -159,7 +155,7 @@ impl Worker { let handle = tokio::task::spawn_blocking(move || { let result = - run_inference_streaming(&engine, &template, messages, ¶ms, task_id, sync_tx); + run_inference_streaming(&engine, messages, ¶ms, task_id, sync_tx); capacity.fetch_add(1, Ordering::Release); result }); @@ -167,7 +163,7 @@ impl Worker { } } else { let handle = tokio::task::spawn_blocking(move || { - let result = run_inference(&engine, &template, messages, ¶ms, task_id); + let result = run_inference(&engine, messages, ¶ms, task_id); capacity.fetch_add(1, Ordering::Release); result }); @@ -213,15 +209,9 @@ impl Worker { /// Add a new model engine at runtime (for hot-swap). /// /// If a model with this name already exists, it is replaced. - pub fn add_engine( - &mut self, - name: String, - engine: InferenceEngine, - template: String, - model_type: ModelType, - ) { + pub fn add_engine(&mut self, name: String, engine: InferenceEngine, model_type: ModelType) { self.engines - .insert(name, (Arc::new(engine), template, model_type)); + .insert(name, (Arc::new(engine), model_type)); } /// Remove a model engine by name. Returns true if the model was present. @@ -240,12 +230,20 @@ impl Worker { /// Run inference synchronously (called from `spawn_blocking`). fn run_inference( engine: &InferenceEngine, - template: &str, messages: Vec, params: &GenerateParams, task_id: Uuid, ) -> CompletedTask { - let prompt = apply_chat_template(template, &messages); + let prompt = match engine.apply_template(&messages) { + Ok(p) => p, + Err(e) => { + return CompletedTask { + task_id, + result: Err(e), + stream: false, + }; + } + }; match engine.generate(&prompt, params, |_| ControlFlow::Continue(())) { Ok(result) => CompletedTask { @@ -264,13 +262,26 @@ fn run_inference( /// Run streaming inference: sends tokens via `token_tx` as they're generated. fn run_inference_streaming( engine: &InferenceEngine, - template: &str, messages: Vec, params: &GenerateParams, task_id: Uuid, token_tx: std::sync::mpsc::SyncSender, ) -> CompletedTask { - let prompt = apply_chat_template(template, &messages); + let prompt = match engine.apply_template(&messages) { + Ok(p) => p, + Err(e) => { + let err_msg = NodeMessage::StreamError { + task_id, + error: e.to_string(), + }; + let _ = token_tx.send(err_msg); + return CompletedTask { + task_id, + result: Err(e), + stream: true, + }; + } + }; let tx = token_tx.clone(); let result = engine.generate(&prompt, params, move |stream_token| { @@ -457,7 +468,7 @@ mod tests { #[test] fn test_modality_check_text_content() { - use crate::models::template::MessageContent; + use dkn_protocol::MessageContent; // MessageContent::Text should have no image/audio let content = MessageContent::Text("hello".into()); assert!(!content.has_image()); From 50ade4915b46cc511587ba1023ecb564d0759c89 Mon Sep 17 00:00:00 2001 From: andthattoo Date: Wed, 4 Mar 2026 01:05:52 +0300 Subject: [PATCH 10/57] Add multimodal (vision/audio) inference support Enable llama-cpp-2 mtmd feature, add MtmdContext to InferenceEngine with generate_multimodal() path, mmproj download/cache, hf_mmproj_file on ModelSpec (populated for lfm2.5-vl and lfm2.5-audio), and multimodal branching in worker sync/streaming paths. Co-Authored-By: Claude Opus 4.6 --- Cargo.lock | 1 + Cargo.toml | 2 +- src/inference/engine.rs | 219 +++++++++++++++++++++++++++++++++++++++- src/main.rs | 15 ++- src/models/cache.rs | 66 ++++++++++++ src/models/download.rs | 39 +++++++ src/models/registry.rs | 35 +++++++ src/worker.rs | 123 ++++++++++++++-------- 8 files changed, 455 insertions(+), 45 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c079ac4d..d1e19222 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -445,6 +445,7 @@ dependencies = [ name = "dkn-protocol" version = "0.1.0" dependencies = [ + "base64", "quinn", "rmp-serde", "serde", diff --git a/Cargo.toml b/Cargo.toml index 718b4000..2bd49531 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,7 +25,7 @@ sha2 = "0.10" sha3 = "0.10" hex = "0.4" rand = "0.8" -llama-cpp-2 = "0.1.137" +llama-cpp-2 = { version = "0.1.137", features = ["mtmd"] } hf-hub = { version = "0.4", features = ["tokio"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/src/inference/engine.rs b/src/inference/engine.rs index 22b3dea4..f9f1a69d 100644 --- a/src/inference/engine.rs +++ b/src/inference/engine.rs @@ -7,6 +7,7 @@ use llama_cpp_2::llama_backend::LlamaBackend; use llama_cpp_2::llama_batch::LlamaBatch; use llama_cpp_2::model::params::LlamaModelParams; use llama_cpp_2::model::{AddBos, LlamaChatMessage, LlamaModel}; +use llama_cpp_2::mtmd::{MtmdBitmap, MtmdContext, MtmdContextParams, MtmdInputText}; use llama_cpp_2::sampling::LlamaSampler; use llama_cpp_2::token::LlamaToken; @@ -62,6 +63,7 @@ pub struct InferenceResult { pub struct InferenceEngine { backend: LlamaBackend, model: LlamaModel, + mtmd_ctx: Option, #[allow(dead_code)] gpu_layers: i32, } @@ -75,8 +77,12 @@ fn token_to_string(model: &LlamaModel, token: LlamaToken) -> String { } impl InferenceEngine { - /// Load a GGUF model from disk. - pub fn load(path: &Path, gpu_layers: i32) -> Result { + /// Load a GGUF model from disk, optionally with a multimodal projector. + pub fn load( + path: &Path, + gpu_layers: i32, + mmproj_path: Option<&Path>, + ) -> Result { let backend = LlamaBackend::init() .map_err(|e| NodeError::Inference(format!("failed to init llama backend: {e}")))?; @@ -90,13 +96,40 @@ impl InferenceEngine { let model = LlamaModel::load_from_file(&backend, path, &model_params) .map_err(|e| NodeError::Inference(format!("failed to load model: {e}")))?; + let mtmd_ctx = match mmproj_path { + Some(p) => { + let params = MtmdContextParams::default(); + let ctx = MtmdContext::init_from_file( + p.to_str() + .ok_or_else(|| NodeError::Inference("invalid mmproj path".into()))?, + &model, + ¶ms, + ) + .map_err(|e| NodeError::Inference(format!("failed to init mtmd context: {e}")))?; + tracing::info!( + path = %p.display(), + vision = ctx.support_vision(), + audio = ctx.support_audio(), + "multimodal projector loaded" + ); + Some(ctx) + } + None => None, + }; + Ok(InferenceEngine { backend, model, + mtmd_ctx, gpu_layers, }) } + /// Whether this engine has a multimodal projector loaded. + pub fn has_multimodal(&self) -> bool { + self.mtmd_ctx.is_some() + } + /// Return the number of GPU layers configured. #[allow(dead_code)] pub fn gpu_layers(&self) -> i32 { @@ -119,6 +152,28 @@ impl InferenceEngine { .map_err(|e| NodeError::Inference(format!("failed to apply chat template: {e}"))) } + /// Apply the GGUF-embedded chat template with media parts replaced by the given marker. + fn apply_template_with_marker( + &self, + messages: &[ChatMessage], + marker: &str, + ) -> Result { + let template = self + .model + .chat_template(None) + .map_err(|e| NodeError::Inference(format!("no chat template in model: {e}")))?; + let llama_messages: Vec = messages + .iter() + .map(|m| { + LlamaChatMessage::new(m.role.clone(), m.content.text_with_markers(marker)) + }) + .collect::>() + .map_err(|e| NodeError::Inference(format!("invalid chat message: {e}")))?; + self.model + .apply_chat_template(&template, &llama_messages, true) + .map_err(|e| NodeError::Inference(format!("failed to apply chat template: {e}"))) + } + /// Generate text from a prompt. /// /// `on_token` is called for each generated token. Return `ControlFlow::Break(())` @@ -251,6 +306,166 @@ impl InferenceEngine { }) } + /// Generate text from multimodal messages containing image/audio parts. + /// + /// Uses the mtmd context to process media, then runs the standard sampling loop. + pub fn generate_multimodal( + &self, + messages: &[ChatMessage], + params: &GenerateParams, + mut on_token: F, + ) -> Result + where + F: FnMut(StreamToken) -> ControlFlow<()>, + { + let mtmd_ctx = self + .mtmd_ctx + .as_ref() + .ok_or_else(|| NodeError::Inference("no multimodal context loaded".into()))?; + + // Get the default media marker used by the mtmd tokenizer + let marker = llama_cpp_2::mtmd::mtmd_default_marker(); + + // Apply chat template with media parts replaced by the marker + let prompt = self.apply_template_with_marker(messages, marker)?; + + // Collect all media byte slices in order across all messages + let mut media_blobs: Vec<&[u8]> = Vec::new(); + for msg in messages { + media_blobs.extend(msg.content.media_data()); + } + + // Create bitmaps from media blobs + let bitmaps: Vec = media_blobs + .iter() + .map(|data| { + MtmdBitmap::from_buffer(mtmd_ctx, data) + .map_err(|e| NodeError::Inference(format!("failed to create bitmap: {e}"))) + }) + .collect::, _>>()?; + + let bitmap_refs: Vec<&MtmdBitmap> = bitmaps.iter().collect(); + + // Tokenize the prompt with media markers resolved to bitmap embeddings + let input_text = MtmdInputText { + text: prompt, + add_special: false, // chat template already includes BOS + parse_special: true, + }; + let chunks = mtmd_ctx + .tokenize(input_text, &bitmap_refs) + .map_err(|e| NodeError::Inference(format!("mtmd tokenize failed: {e}")))?; + + let prompt_token_count = chunks.total_tokens() as u32; + + // Create context with larger size for multimodal + let ctx_size = std::num::NonZeroU32::new(4096); + let ctx_params = LlamaContextParams::default().with_n_ctx(ctx_size); + + let mut ctx = self + .model + .new_context(&self.backend, ctx_params) + .map_err(|e| NodeError::Inference(format!("failed to create context: {e}")))?; + + // Evaluate all chunks (text + media embeddings) + let prompt_start = Instant::now(); + let n_past = chunks + .eval_chunks(mtmd_ctx, &ctx, 0, 0, 512, true) + .map_err(|e| NodeError::Inference(format!("mtmd eval_chunks failed: {e}")))?; + let prompt_eval_time_ms = prompt_start.elapsed().as_millis() as u64; + + // Build sampler chain + let mut samplers = vec![]; + if params.temperature > 0.0 { + samplers.push(LlamaSampler::top_p(params.top_p, 1)); + samplers.push(LlamaSampler::temp(params.temperature)); + samplers.push(LlamaSampler::dist(params.seed.unwrap_or(0))); + } else { + samplers.push(LlamaSampler::greedy()); + } + let mut sampler = LlamaSampler::chain_simple(samplers); + + // Generation loop (same as text-only but starting from n_past) + let gen_start = Instant::now(); + let mut generated_text = String::new(); + let mut generated_count: u32 = 0; + let mut logprobs: Vec = Vec::new(); + let mut current_pos = n_past; + let mut decoder = encoding_rs::UTF_8.new_decoder(); + let mut batch = LlamaBatch::new(1, 1); + + for _ in 0..params.max_tokens { + let new_token = sampler.sample(&ctx, -1); + sampler.accept(new_token); + + if self.model.is_eog_token(new_token) { + break; + } + + // Extract logprobs if this position was requested + let gen_index = generated_count as usize; + if params.logprob_positions.contains(&gen_index) { + if let Some(lp) = + self.extract_logprob(&ctx, -1, gen_index, new_token, params.logprob_top_k) + { + logprobs.push(lp); + } + } + + // Decode token to text + let piece = self + .model + .token_to_piece(new_token, &mut decoder, true, None) + .unwrap_or_default(); + generated_text.push_str(&piece); + generated_count += 1; + + // Stream callback + let stream_token = StreamToken { + text: piece, + index: gen_index, + }; + if let ControlFlow::Break(()) = on_token(stream_token) { + break; + } + + // Prepare next batch + batch.clear(); + batch + .add(new_token, current_pos, &[0], true) + .map_err(|e| NodeError::Inference(format!("batch add failed: {e}")))?; + ctx.decode(&mut batch) + .map_err(|e| NodeError::Inference(format!("decode failed: {e}")))?; + current_pos += 1; + } + + let generation_time_ms = gen_start.elapsed().as_millis() as u64; + let tokens_per_second = if generation_time_ms > 0 { + (generated_count as f64) / (generation_time_ms as f64 / 1000.0) + } else { + 0.0 + }; + + let proof = if logprobs.is_empty() { + None + } else { + Some(InferenceProof { + logprobs, + kv_cache_hash: None, + }) + }; + + Ok(InferenceResult { + text: generated_text, + tokens_generated: generated_count, + prompt_tokens: prompt_token_count, + generation_time_ms, + prompt_eval_time_ms, + tokens_per_second, + proof, + }) + } + /// Extract logprob data at a given batch index. fn extract_logprob( &self, diff --git a/src/main.rs b/src/main.rs index 823a8ad1..080f9dd4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -482,9 +482,22 @@ async fn download_and_load_model( cache.link_model(spec, &hf_path)? }; + // Download mmproj if specified (for vision/audio models) + let mmproj_path = if spec.hf_mmproj_file.is_some() { + if let Some(path) = cache.get_mmproj_path(spec) { + tracing::info!(model = %model_name, path = %path.display(), "mmproj found in cache"); + Some(path) + } else { + let hf_path = ModelDownloader::download_mmproj(spec).await?; + Some(cache.link_mmproj(spec, &hf_path)?) + } + } else { + None + }; + // Load model and run benchmark in blocking thread let (engine, tps) = tokio::task::spawn_blocking(move || { - let engine = inference::InferenceEngine::load(&model_path, gpu_layers)?; + let engine = inference::InferenceEngine::load(&model_path, gpu_layers, mmproj_path.as_deref())?; let tps_result = engine.benchmark(&model_name)?; Ok::<_, error::NodeError>((engine, tps_result.generation_tps)) }) diff --git a/src/models/cache.rs b/src/models/cache.rs index 1ee9d946..9ae22e45 100644 --- a/src/models/cache.rs +++ b/src/models/cache.rs @@ -28,6 +28,17 @@ impl ModelCache { } } + /// Check if a model's mmproj GGUF is already present in our cache. + pub fn get_mmproj_path(&self, spec: &ModelSpec) -> Option { + let file = spec.hf_mmproj_file.as_ref()?; + let path = self.cache_dir.join(file); + if path.exists() { + Some(path) + } else { + None + } + } + /// Verify a file's SHA-256 against an expected hex digest. /// Returns Ok(true) if matches, Ok(false) if mismatch, Err on I/O failure. pub fn verify_sha256(path: &Path, expected_hex: &str) -> Result { @@ -55,6 +66,26 @@ impl ModelCache { Ok(dest) } + + /// Create a symlink from our cache dir to the hf-hub cached mmproj file. + pub fn link_mmproj(&self, spec: &ModelSpec, source: &Path) -> Result { + let file = spec + .hf_mmproj_file + .as_ref() + .ok_or_else(|| NodeError::Model("no mmproj file specified".into()))?; + let dest = self.cache_dir.join(file); + if dest.exists() { + return Ok(dest); + } + + #[cfg(unix)] + std::os::unix::fs::symlink(source, &dest)?; + + #[cfg(not(unix))] + std::fs::copy(source, &dest)?; + + Ok(dest) + } } #[cfg(test)] @@ -90,6 +121,7 @@ mod tests { hf_file: "model.gguf".into(), sha256: None, model_type: dkn_protocol::ModelType::Text, + hf_mmproj_file: None, }; // Not present initially @@ -101,4 +133,38 @@ mod tests { std::fs::remove_dir_all(&dir).ok(); } + + #[test] + fn test_mmproj_cache_path() { + let dir = std::env::temp_dir().join("dria-cache-test-mmproj"); + let cache = ModelCache::new(dir.clone()).unwrap(); + + let spec_no_mmproj = ModelSpec { + name: "text:1b".into(), + hf_repo: "test/repo".into(), + hf_file: "model.gguf".into(), + sha256: None, + model_type: dkn_protocol::ModelType::Text, + hf_mmproj_file: None, + }; + assert!(cache.get_mmproj_path(&spec_no_mmproj).is_none()); + + let spec_with_mmproj = ModelSpec { + name: "vl:1b".into(), + hf_repo: "test/repo".into(), + hf_file: "model.gguf".into(), + sha256: None, + model_type: dkn_protocol::ModelType::Vision, + hf_mmproj_file: Some("mmproj.gguf".into()), + }; + + // Not present initially + assert!(cache.get_mmproj_path(&spec_with_mmproj).is_none()); + + // Create the mmproj file + std::fs::write(dir.join("mmproj.gguf"), b"fake").unwrap(); + assert!(cache.get_mmproj_path(&spec_with_mmproj).is_some()); + + std::fs::remove_dir_all(&dir).ok(); + } } diff --git a/src/models/download.rs b/src/models/download.rs index 997ecea1..f199b328 100644 --- a/src/models/download.rs +++ b/src/models/download.rs @@ -43,4 +43,43 @@ impl ModelDownloader { Ok(path) } + + /// Download the multimodal projector GGUF from HuggingFace. + /// + /// Returns the local path to the downloaded mmproj file. + pub async fn download_mmproj(spec: &ModelSpec) -> Result { + let mmproj_file = spec + .hf_mmproj_file + .as_ref() + .ok_or_else(|| NodeError::Model("no mmproj file specified".into()))?; + + let api = ApiBuilder::new() + .with_progress(true) + .build() + .map_err(|e| NodeError::Model(format!("failed to create HF API client: {e}")))?; + + let repo = api.model(spec.hf_repo.clone()); + + tracing::info!( + model = %spec.name, + repo = %spec.hf_repo, + file = %mmproj_file, + "downloading mmproj from HuggingFace" + ); + + let path = repo.get(mmproj_file).await.map_err(|e| { + NodeError::Model(format!( + "failed to download mmproj for {}: {e}", + spec.name + )) + })?; + + tracing::info!( + model = %spec.name, + path = %path.display(), + "mmproj download complete" + ); + + Ok(path) + } } diff --git a/src/models/registry.rs b/src/models/registry.rs index 99f474db..1e6fa0bb 100644 --- a/src/models/registry.rs +++ b/src/models/registry.rs @@ -15,6 +15,8 @@ pub struct ModelSpec { pub sha256: Option, /// Modality this model supports. pub model_type: ModelType, + /// Optional multimodal projector GGUF filename within the same repo. + pub hf_mmproj_file: Option, } /// Build the default model registry with all supported models. @@ -26,6 +28,7 @@ pub fn default_registry() -> HashMap { hf_file: "LFM2.5-1.2B-Instruct-Q4_K_M.gguf".into(), sha256: None, model_type: ModelType::Text, + hf_mmproj_file: None, }, ModelSpec { name: "qwen3.5:35b-a3b".into(), @@ -33,6 +36,7 @@ pub fn default_registry() -> HashMap { hf_file: "Qwen3.5-35B-A3B-UD-Q4_K_M.gguf".into(), sha256: None, model_type: ModelType::Text, + hf_mmproj_file: None, }, ModelSpec { name: "lfm2:24b-a2b".into(), @@ -40,6 +44,7 @@ pub fn default_registry() -> HashMap { hf_file: "LFM2-24B-A2B-Q4_K_M.gguf".into(), sha256: None, model_type: ModelType::Text, + hf_mmproj_file: None, }, ModelSpec { name: "lfm2.5-vl:1.6b".into(), @@ -47,6 +52,7 @@ pub fn default_registry() -> HashMap { hf_file: "LFM2.5-VL-1.6B-Q4_0.gguf".into(), sha256: None, model_type: ModelType::Vision, + hf_mmproj_file: Some("mmproj-LFM2.5-VL-1.6b-F16.gguf".into()), }, ModelSpec { name: "lfm2.5-audio:1.5b".into(), @@ -54,6 +60,7 @@ pub fn default_registry() -> HashMap { hf_file: "LFM2.5-Audio-1.5B-Q4_0.gguf".into(), sha256: None, model_type: ModelType::Audio, + hf_mmproj_file: Some("mmproj-LFM2.5-Audio-1.5B-Q4_0.gguf".into()), }, ModelSpec { name: "qwen3.5:27b".into(), @@ -61,6 +68,7 @@ pub fn default_registry() -> HashMap { hf_file: "Qwen3.5-27B-Q4_K_M.gguf".into(), sha256: None, model_type: ModelType::Text, + hf_mmproj_file: None, }, ModelSpec { name: "nanbeige:3b".into(), @@ -68,6 +76,7 @@ pub fn default_registry() -> HashMap { hf_file: "Nanbeige.Nanbeige4.1-3B.Q4_K_M.gguf".into(), sha256: None, model_type: ModelType::Text, + hf_mmproj_file: None, }, ModelSpec { name: "locooperator:4b".into(), @@ -75,6 +84,7 @@ pub fn default_registry() -> HashMap { hf_file: "LocoOperator-4B.Q4_K_M.gguf".into(), sha256: None, model_type: ModelType::Text, + hf_mmproj_file: None, }, ModelSpec { name: "qwen3.5:9b".into(), @@ -82,6 +92,7 @@ pub fn default_registry() -> HashMap { hf_file: "Qwen3.5-9B-Q4_K_M.gguf".into(), sha256: None, model_type: ModelType::Text, + hf_mmproj_file: None, }, ]; @@ -97,6 +108,7 @@ impl ModelSpec { hf_file: entry.hf_file.clone(), sha256: None, model_type: entry.model_type, + hf_mmproj_file: entry.hf_mmproj_file.clone(), } } @@ -167,6 +179,7 @@ mod tests { assert!(spec.hf_repo.contains("LFM2.5")); assert!(spec.hf_file.ends_with(".gguf")); assert_eq!(spec.model_type, ModelType::Text); + assert!(spec.hf_mmproj_file.is_none()); } #[test] @@ -182,6 +195,7 @@ mod tests { hf_repo: "test/repo".into(), hf_file: "model.gguf".into(), model_type: ModelType::Vision, + hf_mmproj_file: Some("mmproj.gguf".into()), }; let spec = ModelSpec::from_registry_entry(&entry); assert_eq!(spec.name, "test:1b"); @@ -189,6 +203,7 @@ mod tests { assert_eq!(spec.hf_file, "model.gguf"); assert!(spec.sha256.is_none()); assert_eq!(spec.model_type, ModelType::Vision); + assert_eq!(spec.hf_mmproj_file.as_deref(), Some("mmproj.gguf")); } #[test] @@ -200,6 +215,15 @@ mod tests { assert_eq!(reg["qwen3.5:27b"].model_type, ModelType::Text); } + #[test] + fn test_mmproj_files_correct() { + let reg = default_registry(); + assert!(reg["lfm2.5-vl:1.6b"].hf_mmproj_file.is_some()); + assert!(reg["lfm2.5-audio:1.5b"].hf_mmproj_file.is_some()); + assert!(reg["lfm2.5:1.2b"].hf_mmproj_file.is_none()); + assert!(reg["qwen3.5:27b"].hf_mmproj_file.is_none()); + } + #[test] fn test_with_quant_substitutes_suffix() { let reg = default_registry(); @@ -212,6 +236,17 @@ mod tests { assert_eq!(q8.name, spec.name); assert_eq!(q8.hf_repo, spec.hf_repo); assert_eq!(q8.model_type, spec.model_type); + assert_eq!(q8.hf_mmproj_file, spec.hf_mmproj_file); + } + + #[test] + fn test_with_quant_preserves_mmproj() { + let reg = default_registry(); + let spec = ®["lfm2.5-vl:1.6b"]; + assert!(spec.hf_mmproj_file.is_some()); + + let q8 = spec.with_quant("Q8_0"); + assert_eq!(q8.hf_mmproj_file, spec.hf_mmproj_file); } #[test] diff --git a/src/worker.rs b/src/worker.rs index aa0eb829..28ccb6cc 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -234,28 +234,49 @@ fn run_inference( params: &GenerateParams, task_id: Uuid, ) -> CompletedTask { - let prompt = match engine.apply_template(&messages) { - Ok(p) => p, - Err(e) => { - return CompletedTask { + let has_media = messages + .iter() + .any(|m| m.content.has_image() || m.content.has_audio()); + + if has_media && engine.has_multimodal() { + // Multimodal path + match engine.generate_multimodal(&messages, params, |_| ControlFlow::Continue(())) { + Ok(result) => CompletedTask { + task_id, + result: Ok(build_task_result(task_id, result)), + stream: false, + }, + Err(e) => CompletedTask { task_id, result: Err(e), stream: false, - }; + }, } - }; + } else { + // Text-only path + let prompt = match engine.apply_template(&messages) { + Ok(p) => p, + Err(e) => { + return CompletedTask { + task_id, + result: Err(e), + stream: false, + }; + } + }; - match engine.generate(&prompt, params, |_| ControlFlow::Continue(())) { - Ok(result) => CompletedTask { - task_id, - result: Ok(build_task_result(task_id, result)), - stream: false, - }, - Err(e) => CompletedTask { - task_id, - result: Err(e), - stream: false, - }, + match engine.generate(&prompt, params, |_| ControlFlow::Continue(())) { + Ok(result) => CompletedTask { + task_id, + result: Ok(build_task_result(task_id, result)), + stream: false, + }, + Err(e) => CompletedTask { + task_id, + result: Err(e), + stream: false, + }, + } } } @@ -267,36 +288,56 @@ fn run_inference_streaming( task_id: Uuid, token_tx: std::sync::mpsc::SyncSender, ) -> CompletedTask { - let prompt = match engine.apply_template(&messages) { - Ok(p) => p, - Err(e) => { - let err_msg = NodeMessage::StreamError { + let has_media = messages + .iter() + .any(|m| m.content.has_image() || m.content.has_audio()); + + let result = if has_media && engine.has_multimodal() { + // Multimodal streaming path + let tx = token_tx.clone(); + engine.generate_multimodal(&messages, params, move |stream_token| { + let msg = NodeMessage::StreamToken { task_id, - error: e.to_string(), + token: stream_token.text, + index: stream_token.index as u32, }; - let _ = token_tx.send(err_msg); - return CompletedTask { + if tx.send(msg).is_err() { + return ControlFlow::Break(()); + } + ControlFlow::Continue(()) + }) + } else { + // Text-only streaming path + let prompt = match engine.apply_template(&messages) { + Ok(p) => p, + Err(e) => { + let err_msg = NodeMessage::StreamError { + task_id, + error: e.to_string(), + }; + let _ = token_tx.send(err_msg); + return CompletedTask { + task_id, + result: Err(e), + stream: true, + }; + } + }; + + let tx = token_tx.clone(); + engine.generate(&prompt, params, move |stream_token| { + let msg = NodeMessage::StreamToken { task_id, - result: Err(e), - stream: true, + token: stream_token.text, + index: stream_token.index as u32, }; - } + if tx.send(msg).is_err() { + return ControlFlow::Break(()); + } + ControlFlow::Continue(()) + }) }; - let tx = token_tx.clone(); - let result = engine.generate(&prompt, params, move |stream_token| { - let msg = NodeMessage::StreamToken { - task_id, - token: stream_token.text, - index: stream_token.index as u32, - }; - if tx.send(msg).is_err() { - // Receiver dropped (connection lost), stop generation - return ControlFlow::Break(()); - } - ControlFlow::Continue(()) - }); - match result { Ok(result) => { // Send StreamEnd From ef970bb87a6cf63c42ab6f45fc6d571405543192 Mon Sep 17 00:00:00 2001 From: andthattoo Date: Wed, 4 Mar 2026 01:23:31 +0300 Subject: [PATCH 11/57] Add vision integration test for multimodal inference Ignored test that downloads lfm2.5-vl:1.6b + mmproj, runs generate_multimodal() with a synthetic BMP or user-provided image via TEST_IMAGE_PATH env var. --- src/inference/engine.rs | 144 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 144 insertions(+) diff --git a/src/inference/engine.rs b/src/inference/engine.rs index f9f1a69d..1a2ed524 100644 --- a/src/inference/engine.rs +++ b/src/inference/engine.rs @@ -530,3 +530,147 @@ impl InferenceEngine { sha256hash(&bytes) } } + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + use dkn_protocol::{ContentPart, MessageContent}; + + /// Create a minimal 64x64 BMP image with a color gradient (no external deps). + fn create_test_bmp() -> Vec { + let width: u32 = 64; + let height: u32 = 64; + let row_bytes = width * 3; // 192, already 4-byte aligned + let pixel_data_size = row_bytes * height; + let file_size = 54 + pixel_data_size; + + let mut data = Vec::with_capacity(file_size as usize); + + // BMP file header (14 bytes) + data.extend_from_slice(b"BM"); + data.extend_from_slice(&file_size.to_le_bytes()); + data.extend_from_slice(&[0u8; 4]); // reserved + data.extend_from_slice(&54u32.to_le_bytes()); // pixel data offset + + // BITMAPINFOHEADER (40 bytes) + data.extend_from_slice(&40u32.to_le_bytes()); // header size + data.extend_from_slice(&width.to_le_bytes()); + data.extend_from_slice(&height.to_le_bytes()); + data.extend_from_slice(&1u16.to_le_bytes()); // planes + data.extend_from_slice(&24u16.to_le_bytes()); // bits per pixel + data.extend_from_slice(&[0u8; 24]); // compression=0, rest zeros + + // Pixel data (bottom-up, BGR) + for y in 0..height { + for x in 0..width { + let r = ((x * 255) / (width - 1)) as u8; + let g = ((y * 255) / (height - 1)) as u8; + let b = 128u8; + data.push(b); + data.push(g); + data.push(r); + } + } + + data + } + + /// Integration test: download lfm2.5-vl:1.6b + mmproj, run vision inference. + /// + /// Run with: + /// cargo test test_vision_inference -- --ignored --nocapture + /// + /// Optionally provide your own image: + /// TEST_IMAGE_PATH=/path/to/photo.jpg cargo test test_vision_inference -- --ignored --nocapture + #[tokio::test] + #[ignore] // requires ~1.5 GB download (model + mmproj) + async fn test_vision_inference() { + let registry = crate::models::default_registry(); + let spec = registry.get("lfm2.5-vl:1.6b").unwrap().clone(); + + let cache_dir = dirs::cache_dir() + .unwrap_or_else(|| std::path::PathBuf::from(".")) + .join("dria-test-models"); + let cache = crate::models::ModelCache::new(cache_dir).unwrap(); + + // Download / cache the GGUF model + let model_path = if let Some(p) = cache.get_local_path(&spec) { + println!("model found in cache: {}", p.display()); + p + } else { + println!("downloading model (this may take a while)..."); + let hf_path = crate::models::ModelDownloader::download(&spec).await.unwrap(); + cache.link_model(&spec, &hf_path).unwrap() + }; + + // Download / cache the mmproj + let mmproj_path = if let Some(p) = cache.get_mmproj_path(&spec) { + println!("mmproj found in cache: {}", p.display()); + p + } else { + println!("downloading mmproj (this may take a while)..."); + let hf_path = crate::models::ModelDownloader::download_mmproj(&spec) + .await + .unwrap(); + cache.link_mmproj(&spec, &hf_path).unwrap() + }; + + // Load engine with multimodal projector + println!("loading model + mmproj..."); + let engine = InferenceEngine::load(&model_path, 0, Some(&mmproj_path)).unwrap(); + assert!(engine.has_multimodal(), "engine should have multimodal context"); + + // Get test image: from env var or generate a synthetic BMP + let image_bytes = if let Ok(path) = std::env::var("TEST_IMAGE_PATH") { + println!("using image: {path}"); + std::fs::read(&path).expect("failed to read TEST_IMAGE_PATH") + } else { + println!("using synthetic 64x64 gradient BMP"); + create_test_bmp() + }; + + // Build multimodal chat messages + let messages = vec![ChatMessage { + role: "user".into(), + content: MessageContent::Parts(vec![ + ContentPart::Text { + text: "What do you see in this image? Describe it briefly.".into(), + }, + ContentPart::Image { + data: image_bytes, + }, + ]), + }]; + + let params = GenerateParams { + max_tokens: 256, + temperature: 0.0, + ..Default::default() + }; + + // Run multimodal inference, streaming tokens to stdout + println!("\n--- model output ---"); + let result = engine + .generate_multimodal(&messages, ¶ms, |token| { + print!("{}", token.text); + ControlFlow::Continue(()) + }) + .unwrap(); + println!("\n--- end output ---\n"); + + println!( + "tokens: {} | prompt: {} | time: {}ms | {:.1} tok/s", + result.tokens_generated, + result.prompt_tokens, + result.generation_time_ms, + result.tokens_per_second, + ); + + assert!(!result.text.is_empty(), "model should produce output"); + assert!(result.tokens_generated > 0); + } +} From 2995a97f8b605aa268229d53d9599f339db756ae Mon Sep 17 00:00:00 2001 From: andthattoo Date: Wed, 4 Mar 2026 02:10:21 +0300 Subject: [PATCH 12/57] Mark Qwen 3.5 models as Vision with mmproj files --- src/models/registry.rs | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/src/models/registry.rs b/src/models/registry.rs index 1e6fa0bb..b3800868 100644 --- a/src/models/registry.rs +++ b/src/models/registry.rs @@ -35,8 +35,8 @@ pub fn default_registry() -> HashMap { hf_repo: "unsloth/Qwen3.5-35B-A3B-GGUF".into(), hf_file: "Qwen3.5-35B-A3B-UD-Q4_K_M.gguf".into(), sha256: None, - model_type: ModelType::Text, - hf_mmproj_file: None, + model_type: ModelType::Vision, + hf_mmproj_file: Some("mmproj-BF16.gguf".into()), }, ModelSpec { name: "lfm2:24b-a2b".into(), @@ -67,8 +67,8 @@ pub fn default_registry() -> HashMap { hf_repo: "unsloth/Qwen3.5-27B-GGUF".into(), hf_file: "Qwen3.5-27B-Q4_K_M.gguf".into(), sha256: None, - model_type: ModelType::Text, - hf_mmproj_file: None, + model_type: ModelType::Vision, + hf_mmproj_file: Some("mmproj-BF16.gguf".into()), }, ModelSpec { name: "nanbeige:3b".into(), @@ -91,8 +91,8 @@ pub fn default_registry() -> HashMap { hf_repo: "lmstudio-community/Qwen3.5-9B-GGUF".into(), hf_file: "Qwen3.5-9B-Q4_K_M.gguf".into(), sha256: None, - model_type: ModelType::Text, - hf_mmproj_file: None, + model_type: ModelType::Vision, + hf_mmproj_file: Some("mmproj-Qwen3.5-9B-BF16.gguf".into()), }, ]; @@ -212,7 +212,9 @@ mod tests { assert_eq!(reg["lfm2.5-vl:1.6b"].model_type, ModelType::Vision); assert_eq!(reg["lfm2.5-audio:1.5b"].model_type, ModelType::Audio); assert_eq!(reg["lfm2.5:1.2b"].model_type, ModelType::Text); - assert_eq!(reg["qwen3.5:27b"].model_type, ModelType::Text); + assert_eq!(reg["qwen3.5:9b"].model_type, ModelType::Vision); + assert_eq!(reg["qwen3.5:27b"].model_type, ModelType::Vision); + assert_eq!(reg["qwen3.5:35b-a3b"].model_type, ModelType::Vision); } #[test] @@ -220,8 +222,10 @@ mod tests { let reg = default_registry(); assert!(reg["lfm2.5-vl:1.6b"].hf_mmproj_file.is_some()); assert!(reg["lfm2.5-audio:1.5b"].hf_mmproj_file.is_some()); + assert!(reg["qwen3.5:9b"].hf_mmproj_file.is_some()); + assert!(reg["qwen3.5:27b"].hf_mmproj_file.is_some()); + assert!(reg["qwen3.5:35b-a3b"].hf_mmproj_file.is_some()); assert!(reg["lfm2.5:1.2b"].hf_mmproj_file.is_none()); - assert!(reg["qwen3.5:27b"].hf_mmproj_file.is_none()); } #[test] From 3f31ea4b27b4757079f1b9efdd5f24031a03ac09 Mon Sep 17 00:00:00 2001 From: andthattoo Date: Wed, 4 Mar 2026 23:18:42 +0300 Subject: [PATCH 13/57] Add interactive setup command, Homebrew tap, and v1.0.0-alpha.1 prep - Add `dria-node setup` interactive command: detects RAM, shows models that fit, downloads selection, runs test inference - Update release workflow: auto-update homebrew-dkn formula with SHA256s - Update CI workflow for v2 branch structure - Add install scripts for macOS/Linux (curl) and Windows (PowerShell) - Update README with quick start, model table, and CLI reference - Bump version to 1.0.0-alpha.1 Co-Authored-By: Claude Opus 4.6 --- .github/workflows/releases.yml | 129 +++++++++++- .github/workflows/tests.yml | 26 +-- Cargo.lock | 3 +- Cargo.toml | 4 +- README.md | 132 ++++++------ scripts/install.ps1 | 61 ++++++ scripts/install.sh | 91 ++++++++ src/config.rs | 45 ++-- src/main.rs | 7 + src/setup.rs | 375 +++++++++++++++++++++++++++++++++ 10 files changed, 761 insertions(+), 112 deletions(-) create mode 100644 scripts/install.ps1 create mode 100755 scripts/install.sh create mode 100644 src/setup.rs diff --git a/.github/workflows/releases.yml b/.github/workflows/releases.yml index aa06bd88..a8ad74c1 100644 --- a/.github/workflows/releases.yml +++ b/.github/workflows/releases.yml @@ -7,6 +7,9 @@ on: permissions: contents: write +env: + GGML_NATIVE: "OFF" + jobs: check_release: runs-on: ubuntu-latest @@ -35,6 +38,7 @@ jobs: arch: arm64, target: aarch64-apple-darwin, command: build, + build_args: --features metal, } - { runner: ubuntu-latest, @@ -43,6 +47,14 @@ jobs: target: x86_64-unknown-linux-musl, command: build, } + - { + runner: ubuntu-latest, + osname: linux, + arch: amd64-noavx, + target: x86_64-unknown-linux-musl, + command: build, + noavx: true, + } - { runner: ubuntu-latest, osname: linux, @@ -59,34 +71,41 @@ jobs: command: build, extension: ".exe", } - # - { runner: windows-latest, osname: windows, arch: arm64, target: aarch64-pc-windows-msvc, command: build, extension: ".exe", toolchain: nightly } steps: - name: Checkout code - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Get the release version from the tag shell: bash run: echo "VERSION=${GITHUB_REF#refs/tags/}" >> $GITHUB_ENV + - name: Disable AVX for baseline build + if: matrix.noavx + shell: bash + run: | + echo "GGML_AVX=OFF" >> $GITHUB_ENV + echo "GGML_AVX2=OFF" >> $GITHUB_ENV + echo "GGML_FMA=OFF" >> $GITHUB_ENV + echo "GGML_F16C=OFF" >> $GITHUB_ENV + - name: Build binary uses: houseabsolute/actions-rust-cross@v0 with: command: ${{ matrix.command }} target: ${{ matrix.target }} - args: "--bin dkn-compute --locked --release ${{ matrix.build_args }}" + args: "--bin dria-node --locked --release ${{ matrix.build_args }}" strip: true - name: Prepare Release File run: | - # move the binary - mv target/${{ matrix.target }}/release/dkn-compute${{ matrix.extension }} ./dkn-compute-binary-${{ matrix.osname }}-${{ matrix.arch }}${{ matrix.extension }} + mv target/${{ matrix.target }}/release/dria-node${{ matrix.extension }} ./dria-node-${{ matrix.osname }}-${{ matrix.arch }}${{ matrix.extension }} - name: Upload Launch Artifacts uses: actions/upload-artifact@v4 with: - name: dkn-compute-binary-${{ matrix.osname }}-${{ matrix.arch }} - path: dkn-compute-binary-${{ matrix.osname }}-${{ matrix.arch }}${{ matrix.extension }} + name: dria-node-${{ matrix.osname }}-${{ matrix.arch }} + path: dria-node-${{ matrix.osname }}-${{ matrix.arch }}${{ matrix.extension }} release: needs: build @@ -94,9 +113,9 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: - fetch-depth: 0 # Fetch all tags and history + fetch-depth: 0 - name: Download Launch Artifacts uses: actions/download-artifact@v4 @@ -104,7 +123,6 @@ jobs: merge-multiple: true path: ./artifacts - # https://github.com/ncipollo/release-action - name: Create release with artifacts uses: ncipollo/release-action@v1 with: @@ -114,3 +132,94 @@ jobs: artifactContentType: application/octet-stream allowUpdates: true makeLatest: false + + update_homebrew: + needs: release + runs-on: ubuntu-latest + + steps: + - name: Download artifacts + uses: actions/download-artifact@v4 + with: + merge-multiple: true + path: ./artifacts + + - name: Checkout homebrew tap + uses: actions/checkout@v4 + with: + repository: firstbatchxyz/homebrew-dkn + token: ${{ secrets.HOMEBREW_TAP_TOKEN }} + path: homebrew-dkn + + - name: Update formula + run: | + VERSION="${{ github.event.release.tag_name }}" + VERSION="${VERSION#v}" + + SHA_MACOS_AMD64=$(sha256sum artifacts/dria-node-macOS-amd64 | cut -d' ' -f1) + SHA_MACOS_ARM64=$(sha256sum artifacts/dria-node-macOS-arm64 | cut -d' ' -f1) + SHA_LINUX_AMD64=$(sha256sum artifacts/dria-node-linux-amd64 | cut -d' ' -f1) + SHA_LINUX_ARM64=$(sha256sum artifacts/dria-node-linux-arm64 | cut -d' ' -f1) + + FORMULA=homebrew-dkn/Formula/dria-node.rb + + cat > "$FORMULA" << 'RUBY' + class DriaNode < Formula + desc "Dria Compute Node - run AI inference on the Dria network" + homepage "https://github.com/firstbatchxyz/dkn-compute-node" + version "VERSION_PLACEHOLDER" + license "Apache-2.0" + + on_macos do + on_intel do + url "https://github.com/firstbatchxyz/dkn-compute-node/releases/download/vVERSION_PLACEHOLDER/dria-node-macOS-amd64" + sha256 "SHA_MACOS_AMD64_PLACEHOLDER" + end + + on_arm do + url "https://github.com/firstbatchxyz/dkn-compute-node/releases/download/vVERSION_PLACEHOLDER/dria-node-macOS-arm64" + sha256 "SHA_MACOS_ARM64_PLACEHOLDER" + end + end + + on_linux do + on_intel do + url "https://github.com/firstbatchxyz/dkn-compute-node/releases/download/vVERSION_PLACEHOLDER/dria-node-linux-amd64" + sha256 "SHA_LINUX_AMD64_PLACEHOLDER" + end + + on_arm do + url "https://github.com/firstbatchxyz/dkn-compute-node/releases/download/vVERSION_PLACEHOLDER/dria-node-linux-arm64" + sha256 "SHA_LINUX_ARM64_PLACEHOLDER" + end + end + + def install + binary = Dir.glob("dria-node*").first + bin.install binary => "dria-node" + end + + test do + assert_match "dria-node", shell_output("#{bin}/dria-node --version") + end + end + RUBY + + # Remove leading whitespace from heredoc + sed -i 's/^ //' "$FORMULA" + + # Replace placeholders with actual values + sed -i "s/VERSION_PLACEHOLDER/${VERSION}/g" "$FORMULA" + sed -i "s/SHA_MACOS_AMD64_PLACEHOLDER/${SHA_MACOS_AMD64}/g" "$FORMULA" + sed -i "s/SHA_MACOS_ARM64_PLACEHOLDER/${SHA_MACOS_ARM64}/g" "$FORMULA" + sed -i "s/SHA_LINUX_AMD64_PLACEHOLDER/${SHA_LINUX_AMD64}/g" "$FORMULA" + sed -i "s/SHA_LINUX_ARM64_PLACEHOLDER/${SHA_LINUX_ARM64}/g" "$FORMULA" + + - name: Commit and push + run: | + cd homebrew-dkn + git config user.name "github-actions[bot]" + git config user.email "github-actions[bot]@users.noreply.github.com" + git add Formula/dria-node.rb + git commit -m "Update dria-node to ${{ github.event.release.tag_name }}" + git push diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index a908166a..4b5caca1 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -3,22 +3,15 @@ name: tests on: push: branches: - - master + - v2 paths: - # Source files in each member - - "compute/src/**" - - "p2p/src/**" - - "utils/src/**" - - "executor/src/**" - # Cargo in each member - - "compute/Cargo.toml" - - "p2p/Cargo.toml" - - "utils/Cargo.toml" - - "executor/Cargo.toml" - # root-level Cargo + - "src/**" - "Cargo.toml" - # workflow itself + - "Cargo.lock" - ".github/workflows/tests.yml" + pull_request: + branches: + - v2 workflow_dispatch: jobs: @@ -29,11 +22,14 @@ jobs: - name: Checkout repository uses: actions/checkout@v4 + - name: Install build dependencies + run: sudo apt-get update && sudo apt-get install -y cmake + - name: Install Rust toolchain uses: actions-rust-lang/setup-rust-toolchain@v1 - name: Run tests - run: cargo test --workspace + run: cargo test - name: Run linter - run: cargo clippy --workspace + run: cargo clippy diff --git a/Cargo.lock b/Cargo.lock index d1e19222..b3c604df 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -444,6 +444,7 @@ dependencies = [ [[package]] name = "dkn-protocol" version = "0.1.0" +source = "git+https://github.com/firstbatchxyz/dkn-protocol.git#19dcd03eda240eaa98665a7a839a32205b5ef912" dependencies = [ "base64", "quinn", @@ -455,7 +456,7 @@ dependencies = [ [[package]] name = "dria-node" -version = "2.0.0-alpha.1" +version = "1.0.0-alpha.1" dependencies = [ "anyhow", "bytes", diff --git a/Cargo.toml b/Cargo.toml index 2bd49531..cd8aa120 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "dria-node" -version = "2.0.0-alpha.1" +version = "1.0.0-alpha.1" edition = "2021" license = "Apache-2.0" @@ -34,7 +34,7 @@ anyhow = "1" uuid = { version = "1", features = ["v7", "serde"] } dirs = "6" encoding_rs = "0.8" -dkn-protocol = { path = "../dkn-protocol" } +dkn-protocol = { git = "https://github.com/firstbatchxyz/dkn-protocol.git" } quinn = "0.11" rustls = { version = "0.23", default-features = false, features = ["ring", "std"] } rustls-native-certs = "0.8" diff --git a/README.md b/README.md index 2e8c9b20..2a3a2e88 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ Dria Compute Node

- Dria Compute Node serves the computation results within Dria Knowledge Network. + Run AI inference on the Dria network. Earn rewards by serving models from your machine.

@@ -29,111 +29,109 @@

-> Use the [Dria Compute Launcher](https://github.com/firstbatchxyz/dkn-compute-launcher/) to run a compute node with many more features! +## Quick Start -## Releases +### Install -For _production_ images: - -- **Versioned**: With each release, a versioned image is deployed on Docker hub with the version tag `:vX.X.X`. -- **Latest**: The latest production image is always under the `:latest` tag. - -For _development_ images: - -- **Master**: On each push to `master` branch, a new image is created with the tag `master--`. -- **Unstable**: The latest development image is always under the `:unstable` tag. - -You can see the list of deployed images on [Docker Hub](https://hub.docker.com/orgs/firstbatch/members). - -## Development - -> If you have a feature that you would like to add with respect to its respective issue, or a bug fix, feel free to fork & create a PR! - -If you would like to run the node from source (which is really handy during development), you can use our shorthand scripts within the Makefile. You can see the available commands with: +**Homebrew (macOS / Linux):** ```sh -make help +brew tap firstbatchxyz/dkn +brew install dria-node ``` -You can run the binary as is: +**From GitHub Releases:** -```sh -cargo run +Download the latest binary for your platform from [Releases](https://github.com/firstbatchxyz/dkn-compute-node/releases) and place it in your `PATH`. -# specify custom .env file -DKN_COMPUTE_ENV=./path/to/.env cargo run -``` +### Setup -If you have a valid `.env` file, you can run the latest Docker image via compose as well: +Run the interactive setup to pick a model, download it, and verify it works: ```sh -docker compose up - -# Ollama without any GPUs -docker compose --profile=ollama-cpu up -# Ollama for NVIDIA gpus -docker compose --profile=ollama-cuda up -# Ollama for AMD gpus -docker compose --profile=ollama-rocm up +dria-node setup ``` -> [!TIP] -> -> You can specify a custom initial RPC address with `DKN_INITIAL_RPC_ADDR`. +This detects your system RAM, shows models that fit, downloads your selection, and runs a test inference. -### Testing +### Start -You can the tests as follows: +Once setup is complete, start the node: ```sh -cargo test --workspace +dria-node start --wallet --model ``` -We also have some benchmarking and profiling scripts, see [node performance](./docs/NODE_PERFORMANCE.md) for more details. +## Available Models -### Documentation +| Model | Type | Quant | ~Size | +|-------|------|-------|-------| +| `lfm2.5:1.2b` | Text | Q4_K_M | 0.8 GB | +| `lfm2.5-audio:1.5b` | Audio | Q4_0 | 1.0 GB | +| `lfm2.5-vl:1.6b` | Vision | Q4_0 | 1.2 GB | +| `nanbeige:3b` | Text | Q4_K_M | 2.0 GB | +| `locooperator:4b` | Text | Q4_K_M | 2.5 GB | +| `qwen3.5:9b` | Vision | Q4_K_M | 6.0 GB | +| `lfm2:24b-a2b` | Text | Q4_K_M | 14 GB | +| `qwen3.5:27b` | Vision | Q4_K_M | 16 GB | +| `qwen3.5:35b-a3b` | Vision | Q4_K_M | 20 GB | -You can view the entire crate-level documentation with: +Serve multiple models by comma-separating them: `--model "qwen3.5:9b,lfm2.5:1.2b"` -```sh -cargo doc --open --no-deps --document-private-items +Override quantization with `--quant Q8_0` (applies to all models). + +## CLI Reference + +``` +dria-node + +Commands: + setup Interactive setup: pick a model, download it, and run a test + start Start the compute node + +setup options: + --data-dir Data directory [env: DRIA_DATA_DIR] + --gpu-layers GPU layers to offload (0 = CPU only) [default: 0] + +start options: + --wallet Wallet secret key, hex-encoded [env: DRIA_WALLET] + --model Model(s) to serve, comma-separated [env: DRIA_MODELS] + --router-url Router URL [default: quic.dria.co:4001] [env: DRIA_ROUTER_URL] + --gpu-layers GPU layers to offload (-1 = all, 0 = CPU) [default: 0] + --max-concurrent Max concurrent inference requests [default: 1] + --data-dir Data directory [env: DRIA_DATA_DIR] + --quant Override GGUF quantization [env: DRIA_QUANT] + --insecure Skip TLS verification [env: DRIA_INSECURE] ``` -### Styling +All flags can also be set via environment variables. -Lint and format with: +## Building from Source ```sh -cargo clippy --workspace -cargo fmt -v +git clone https://github.com/firstbatchxyz/dkn-compute-node.git +cd dkn-compute-node +cargo build --release ``` -### Profiling +**Feature flags:** -We have scripts to profile both CPU and Memory usage. A special build is created for profiling, via a custom `profiling` feature, such that the output inherits `release` mode but also has debug symbols. +- `--features metal` — Apple Metal GPU acceleration (macOS) +- `--features cuda` — NVIDIA CUDA GPU acceleration -Furthermore, the profiling build will exit automatically after a certain time, as if CTRL+C has been pressed. This is needed by the memory profiling tool in particular. - -**CPU Profiling**: To create a [flamegraph](https://crates.io/crates/flamegraph) of the application, the command below will create a profiling build that inherits `release` mode, except with debug information: +### Testing ```sh -DKN_EXIT_TIMEOUT=120 cargo flamegraph --root --profile=profiling --bin dkn-compute +cargo test ``` -> [!NOTE] -> -> CPU profiling may require super-user access. - -**Memory Profiling**: To profile memory usage, we make use of [cargo-instruments](https://crates.io/crates/cargo-instruments): +### Linting ```sh -DKN_EXIT_TIMEOUT=120 cargo instruments --profile=profiling -t Allocations --bin dkn-compute +cargo clippy +cargo fmt --check ``` -> [!TIP] -> -> You can adjust the profiling duration via the `DKN_EXIT_TIMEOUT` variable, which takes a number of seconds until termination. - ## License This project is licensed under the [Apache License 2.0](https://opensource.org/license/Apache-2.0). diff --git a/scripts/install.ps1 b/scripts/install.ps1 new file mode 100644 index 00000000..8bc189e7 --- /dev/null +++ b/scripts/install.ps1 @@ -0,0 +1,61 @@ +# Dria Node installer for Windows +# Usage: irm https://raw.githubusercontent.com/firstbatchxyz/dkn-compute-node/v2/scripts/install.ps1 | iex +$ErrorActionPreference = "Stop" + +$Repo = "firstbatchxyz/dkn-compute-node" +$Binary = "dria-node" +$InstallDir = "$env:LOCALAPPDATA\dria" + +Write-Host "Dria Node Installer" -ForegroundColor Cyan + +# Fetch latest release +Write-Host "Fetching latest release..." -ForegroundColor Blue +try { + $Release = Invoke-RestMethod -Uri "https://api.github.com/repos/$Repo/releases/latest" + $Tag = $Release.tag_name +} catch { + Write-Host "Error: Failed to fetch latest release. Check your internet connection." -ForegroundColor Red + exit 1 +} + +Write-Host "Latest release: $Tag" -ForegroundColor Blue + +# Download binary +$Asset = "$Binary-windows-amd64.exe" +$Url = "https://github.com/$Repo/releases/download/$Tag/$Asset" + +Write-Host "Downloading $Asset..." -ForegroundColor Blue +$TmpFile = Join-Path $env:TEMP "$Binary.exe" +try { + Invoke-WebRequest -Uri $Url -OutFile $TmpFile -UseBasicParsing +} catch { + Write-Host "Error: Download failed. Asset may not exist: $Url" -ForegroundColor Red + exit 1 +} + +# Install +if (-not (Test-Path $InstallDir)) { + New-Item -ItemType Directory -Path $InstallDir -Force | Out-Null +} +$Dest = Join-Path $InstallDir "$Binary.exe" +Move-Item -Path $TmpFile -Destination $Dest -Force +Write-Host "Installed to $Dest" -ForegroundColor Blue + +# Add to PATH if not present +$UserPath = [Environment]::GetEnvironmentVariable("PATH", "User") +if ($UserPath -notlike "*$InstallDir*") { + [Environment]::SetEnvironmentVariable("PATH", "$InstallDir;$UserPath", "User") + $env:PATH = "$InstallDir;$env:PATH" + Write-Host "Added $InstallDir to user PATH." -ForegroundColor Blue + Write-Host "Restart your terminal for PATH changes to take effect." -ForegroundColor Yellow +} + +# Verify +Write-Host "" +try { + $Version = & $Dest --version 2>&1 + Write-Host "Successfully installed $Version" -ForegroundColor Green +} catch { + Write-Host "Installed successfully. Run '$Binary --version' to verify." -ForegroundColor Green +} +Write-Host "Run '$Binary start --help' to get started." -ForegroundColor Cyan diff --git a/scripts/install.sh b/scripts/install.sh new file mode 100755 index 00000000..57e5ac1a --- /dev/null +++ b/scripts/install.sh @@ -0,0 +1,91 @@ +#!/usr/bin/env bash +# Dria Node installer for macOS and Linux +# Usage: curl -sSL https://raw.githubusercontent.com/firstbatchxyz/dkn-compute-node/v2/scripts/install.sh | bash +set -euo pipefail + +REPO="firstbatchxyz/dkn-compute-node" +BINARY="dria-node" + +info() { printf '\033[1;34m%s\033[0m\n' "$*"; } +error() { printf '\033[1;31mError: %s\033[0m\n' "$*" >&2; exit 1; } + +# Detect OS +case "$(uname -s)" in + Darwin) OS="macOS" ;; + Linux) OS="linux" ;; + *) error "Unsupported OS: $(uname -s). Use Windows installer for Windows." ;; +esac + +# Detect architecture +case "$(uname -m)" in + x86_64|amd64) ARCH="amd64" ;; + aarch64|arm64) ARCH="arm64" ;; + *) error "Unsupported architecture: $(uname -m)" ;; +esac + +# On Linux x86_64, check for AVX2 support and fall back to noavx if missing +if [ "$OS" = "linux" ] && [ "$ARCH" = "amd64" ]; then + if ! grep -q avx2 /proc/cpuinfo 2>/dev/null; then + ARCH="amd64-noavx" + info "CPU does not support AVX2, using baseline binary." + fi +fi + +info "Detected: ${OS} ${ARCH}" + +# Fetch latest release tag +info "Fetching latest release..." +LATEST=$(curl -sSf "https://api.github.com/repos/${REPO}/releases/latest" \ + | grep '"tag_name"' | head -1 | cut -d'"' -f4) \ + || error "Failed to fetch latest release. Check your internet connection." + +[ -z "$LATEST" ] && error "Could not determine latest release tag." +info "Latest release: ${LATEST}" + +# Download binary +ASSET="${BINARY}-${OS}-${ARCH}" +URL="https://github.com/${REPO}/releases/download/${LATEST}/${ASSET}" + +info "Downloading ${ASSET}..." +TMPDIR=$(mktemp -d) +trap 'rm -rf "$TMPDIR"' EXIT + +curl -sSfL -o "${TMPDIR}/${BINARY}" "$URL" \ + || error "Download failed. Asset may not exist for your platform: ${URL}" + +chmod +x "${TMPDIR}/${BINARY}" + +# Install +if [ -w "/usr/local/bin" ]; then + INSTALL_DIR="/usr/local/bin" +elif [ "$(id -u)" = "0" ]; then + INSTALL_DIR="/usr/local/bin" +else + INSTALL_DIR="${HOME}/.local/bin" + mkdir -p "$INSTALL_DIR" +fi + +mv "${TMPDIR}/${BINARY}" "${INSTALL_DIR}/${BINARY}" +info "Installed to ${INSTALL_DIR}/${BINARY}" + +# Check if install dir is in PATH +case ":${PATH}:" in + *":${INSTALL_DIR}:"*) ;; + *) + info "" + info "WARNING: ${INSTALL_DIR} is not in your PATH." + info "Add it by running:" + info " export PATH=\"${INSTALL_DIR}:\$PATH\"" + info "Or add that line to your ~/.bashrc / ~/.zshrc" + ;; +esac + +# Verify +if command -v "$BINARY" &>/dev/null; then + info "" + info "Successfully installed $(${BINARY} --version)" + info "Run '${BINARY} start --help' to get started." +else + info "" + info "Installation complete. Run '${INSTALL_DIR}/${BINARY} --version' to verify." +fi diff --git a/src/config.rs b/src/config.rs index 240f906a..7749997e 100644 --- a/src/config.rs +++ b/src/config.rs @@ -13,18 +13,29 @@ pub struct Cli { #[derive(Subcommand)] pub enum Command { + /// Interactive setup: pick a model, download it, and run a test + Setup { + /// Data directory + #[arg(long, env = "DRIA_DATA_DIR")] + data_dir: Option, + + /// Number of GPU layers to offload (-1 = all, 0 = CPU only) + #[arg(long, env = "DRIA_GPU_LAYERS", default_value = "0")] + gpu_layers: i32, + }, + /// Start the compute node Start { /// Wallet secret key (hex-encoded, 32 bytes) #[arg(long, env = "DRIA_WALLET")] wallet: String, - /// Model(s) to serve (comma-separated shortnames, e.g. "gemma3:4b,llama3.1:8b") + /// Model(s) to serve (comma-separated shortnames, e.g. "qwen3.5:9b,lfm2.5:1.2b") #[arg(long, env = "DRIA_MODELS")] model: String, /// Router URL for task coordination - #[arg(long, env = "DRIA_ROUTER_URL", default_value = "https://router.dria.co")] + #[arg(long, env = "DRIA_ROUTER_URL", default_value = "quic.dria.co:4001")] router_url: String, /// Number of GPU layers to offload (-1 = all, 0 = CPU only) @@ -141,8 +152,8 @@ mod tests { fn test_config_from_valid_args() { let cfg = Config::from_start_args( "0x6472696164726961647269616472696164726961647269616472696164726961".into(), - "gemma3:4b, llama3.1:8b".into(), - "https://router.dria.co".into(), + "qwen3.5:9b, lfm2.5:1.2b".into(), + "quic.dria.co:4001".into(), 0, 1, Some("/tmp/dria-test".into()), @@ -151,21 +162,21 @@ mod tests { ) .unwrap(); - assert_eq!(cfg.model_names, vec!["gemma3:4b", "llama3.1:8b"]); + assert_eq!(cfg.model_names, vec!["qwen3.5:9b", "lfm2.5:1.2b"]); assert_eq!( cfg.secret_key_hex, "6472696164726961647269616472696164726961647269616472696164726961" ); assert_eq!(cfg.models_dir, PathBuf::from("/tmp/dria-test/models")); - assert_eq!(cfg.router_urls, vec!["https://router.dria.co"]); + assert_eq!(cfg.router_urls, vec!["quic.dria.co:4001"]); } #[test] fn test_config_invalid_wallet_length() { let result = Config::from_start_args( "0xabcd".into(), - "gemma3:4b".into(), - "https://router.dria.co".into(), + "qwen3.5:9b".into(), + "quic.dria.co:4001".into(), 0, 1, None, @@ -179,8 +190,8 @@ mod tests { fn test_config_invalid_wallet_hex() { let result = Config::from_start_args( "zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz".into(), - "gemma3:4b".into(), - "https://router.dria.co".into(), + "qwen3.5:9b".into(), + "quic.dria.co:4001".into(), 0, 1, None, @@ -195,7 +206,7 @@ mod tests { let result = Config::from_start_args( "6472696164726961647269616472696164726961647269616472696164726961".into(), "".into(), - "https://router.dria.co".into(), + "quic.dria.co:4001".into(), 0, 1, None, @@ -209,8 +220,8 @@ mod tests { fn test_config_zero_concurrency() { let result = Config::from_start_args( "6472696164726961647269616472696164726961647269616472696164726961".into(), - "gemma3:4b".into(), - "https://router.dria.co".into(), + "qwen3.5:9b".into(), + "quic.dria.co:4001".into(), 0, 0, None, @@ -224,7 +235,7 @@ mod tests { fn test_config_comma_separated_router_urls() { let cfg = Config::from_start_args( "6472696164726961647269616472696164726961647269616472696164726961".into(), - "gemma3:4b".into(), + "qwen3.5:9b".into(), "https://router1.dria.co, https://router2.dria.co".into(), 0, 1, @@ -243,7 +254,7 @@ mod tests { fn test_config_empty_router_url() { let result = Config::from_start_args( "6472696164726961647269616472696164726961647269616472696164726961".into(), - "gemma3:4b".into(), + "qwen3.5:9b".into(), "".into(), 0, 1, @@ -258,8 +269,8 @@ mod tests { fn test_config_insecure_flag() { let cfg = Config::from_start_args( "6472696164726961647269616472696164726961647269616472696164726961".into(), - "gemma3:4b".into(), - "https://router.dria.co".into(), + "qwen3.5:9b".into(), + "quic.dria.co:4001".into(), 0, 1, None, diff --git a/src/main.rs b/src/main.rs index 080f9dd4..cf7221ee 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,6 +4,7 @@ mod identity; mod inference; mod models; mod network; +mod setup; mod stats; mod worker; @@ -37,6 +38,12 @@ async fn main() -> anyhow::Result<()> { let cli = Cli::parse(); match cli.command { + Command::Setup { + data_dir, + gpu_layers, + } => { + setup::run_setup(data_dir, gpu_layers).await?; + } Command::Start { wallet, model, diff --git a/src/setup.rs b/src/setup.rs new file mode 100644 index 00000000..afab24fe --- /dev/null +++ b/src/setup.rs @@ -0,0 +1,375 @@ +use std::io::{self, Write}; +use std::ops::ControlFlow; +use std::path::PathBuf; + +use dkn_protocol::ModelType; + +use crate::error::NodeError; +use crate::inference::{GenerateParams, InferenceEngine}; +use crate::models::{ModelCache, ModelDownloader, default_registry, resolve_model}; + +/// Model metadata for the setup display. +struct SetupModel { + name: String, + model_type: ModelType, + quant: String, + size_gb: f64, + ram_needed_gb: f64, +} + +/// Hardcoded size estimates (Q4_K_M / Q4_0 defaults) for each registry model. +fn model_size_gb(name: &str) -> Option<(f64, f64)> { + // (gguf_size_gb, ram_needed_gb) + match name { + "lfm2.5:1.2b" => Some((0.8, 1.0)), + "nanbeige:3b" => Some((2.0, 2.5)), + "locooperator:4b" => Some((2.5, 3.0)), + "lfm2.5-vl:1.6b" => Some((1.2, 1.5)), + "lfm2.5-audio:1.5b" => Some((1.0, 1.5)), + "qwen3.5:9b" => Some((6.0, 7.0)), + "qwen3.5:27b" => Some((16.0, 18.0)), + "qwen3.5:35b-a3b" => Some((20.0, 22.0)), + "lfm2:24b-a2b" => Some((14.0, 16.0)), + _ => None, + } +} + +/// Extract the quantization string from a GGUF filename (e.g. "Q4_K_M" from "Foo-Q4_K_M.gguf"). +fn extract_quant(hf_file: &str) -> String { + let stem = hf_file.strip_suffix(".gguf").unwrap_or(hf_file); + match stem.rfind('-') { + Some(pos) => stem[pos + 1..].to_string(), + None => stem.to_string(), + } +} + +/// Detect total system RAM in bytes. +fn detect_ram_bytes() -> Option { + #[cfg(target_os = "linux")] + { + if let Ok(contents) = std::fs::read_to_string("/proc/meminfo") { + for line in contents.lines() { + if let Some(rest) = line.strip_prefix("MemTotal:") { + let rest = rest.trim(); + if let Some(kb_str) = rest.strip_suffix("kB").or_else(|| rest.strip_suffix("KB")) + { + if let Ok(kb) = kb_str.trim().parse::() { + return Some(kb * 1024); + } + } + } + } + } + None + } + + #[cfg(target_os = "macos")] + { + let output = std::process::Command::new("sysctl") + .args(["-n", "hw.memsize"]) + .output() + .ok()?; + let s = String::from_utf8_lossy(&output.stdout); + s.trim().parse::().ok() + } + + #[cfg(target_os = "windows")] + { + let output = std::process::Command::new("wmic") + .args(["OS", "get", "TotalVisibleMemorySize"]) + .output() + .ok()?; + let s = String::from_utf8_lossy(&output.stdout); + for line in s.lines() { + if let Ok(kb) = line.trim().parse::() { + return Some(kb * 1024); + } + } + None + } + + #[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "windows")))] + { + None + } +} + +fn model_type_label(mt: ModelType) -> &'static str { + match mt { + ModelType::Text => "Text", + ModelType::Vision => "Vision", + ModelType::Audio => "Audio", + } +} + +pub async fn run_setup(data_dir: Option, gpu_layers: i32) -> Result<(), NodeError> { + println!(); + println!(" Welcome to Dria Node setup!"); + println!(); + + // Detect RAM + let ram_gb = detect_ram_bytes().map(|b| b as f64 / (1024.0 * 1024.0 * 1024.0)); + + if let Some(gb) = ram_gb { + println!(" System: {:.0} GB RAM detected", gb); + } else { + println!(" System: could not detect RAM, showing all models"); + } + println!(); + + // Build model list from registry with size info + let registry = default_registry(); + let mut models: Vec = Vec::new(); + + for spec in registry.values() { + if let Some((size_gb, ram_needed_gb)) = model_size_gb(&spec.name) { + models.push(SetupModel { + name: spec.name.clone(), + model_type: spec.model_type, + quant: extract_quant(&spec.hf_file), + size_gb, + ram_needed_gb, + }); + } + } + + // Sort by size ascending + models.sort_by(|a, b| a.size_gb.partial_cmp(&b.size_gb).unwrap()); + + // Split into fits / too-large + let (fits, too_large): (Vec<_>, Vec<_>) = match ram_gb { + Some(gb) => models + .into_iter() + .partition(|m| m.ram_needed_gb < gb), + None => (models, vec![]), + }; + + if fits.is_empty() { + println!(" No models fit your available RAM. Minimum recommended: 2 GB."); + return Ok(()); + } + + // Print selectable list + println!(" Available models:"); + println!(); + for (i, m) in fits.iter().enumerate() { + println!( + " {}) {:<22} {:<8} {:<10} ~{:.1} GB", + i + 1, + m.name, + model_type_label(m.model_type), + m.quant, + m.size_gb, + ); + } + println!(); + + // Print too-large models + if !too_large.is_empty() { + println!(" Models too large for your system (need more RAM):"); + for m in &too_large { + println!( + " - {:<22} (~{:.0} GB) — needs ~{:.0} GB", + m.name, m.size_gb, m.ram_needed_gb, + ); + } + println!(); + } + + // Read selection + let selection = loop { + print!(" Select a model [1-{}]: ", fits.len()); + io::stdout().flush().map_err(|e| NodeError::Config(format!("stdout flush: {e}")))?; + + let mut input = String::new(); + io::stdin() + .read_line(&mut input) + .map_err(|e| NodeError::Config(format!("failed to read input: {e}")))?; + + match input.trim().parse::() { + Ok(n) if n >= 1 && n <= fits.len() => break n - 1, + _ => { + println!(" Invalid selection, please enter a number between 1 and {}.", fits.len()); + } + } + }; + + let chosen = &fits[selection]; + let model_name = &chosen.name; + println!(); + + // Resolve model spec + let spec = resolve_model(model_name, ®istry, None) + .ok_or_else(|| NodeError::Model(format!("unknown model: {model_name}")))?; + + // Set up cache dir + let data_dir = match data_dir { + Some(d) => d, + None => dirs::home_dir() + .ok_or_else(|| NodeError::Config("could not determine home directory".into()))? + .join(".dria"), + }; + let models_dir = data_dir.join("models"); + std::fs::create_dir_all(&models_dir)?; + let cache = ModelCache::new(models_dir)?; + + // Download model + println!(" Downloading {}...", model_name); + let model_path = if let Some(path) = cache.get_local_path(&spec) { + println!(" (already cached)"); + path + } else { + let hf_path = ModelDownloader::download(&spec).await?; + + // Verify SHA-256 if specified + if let Some(ref expected_sha) = spec.sha256 { + if !ModelCache::verify_sha256(&hf_path, expected_sha)? { + return Err(NodeError::Model(format!( + "SHA-256 mismatch for model {model_name}" + ))); + } + } + + cache.link_model(&spec, &hf_path)? + }; + + // Download mmproj if needed + let mmproj_path = if spec.hf_mmproj_file.is_some() { + if let Some(path) = cache.get_mmproj_path(&spec) { + Some(path) + } else { + let hf_path = ModelDownloader::download_mmproj(&spec).await?; + Some(cache.link_mmproj(&spec, &hf_path)?) + } + } else { + None + }; + + // Load model + println!(); + println!(" Loading model..."); + let engine = tokio::task::spawn_blocking({ + let model_path = model_path.clone(); + let mmproj_path = mmproj_path.clone(); + move || InferenceEngine::load(&model_path, gpu_layers, mmproj_path.as_deref()) + }) + .await + .map_err(|e| NodeError::Inference(format!("task join error: {e}")))? + ?; + + // Run test inference + println!(" Running test inference..."); + println!(); + + let model_name_owned = model_name.clone(); + let result = tokio::task::spawn_blocking(move || { + let prompt = engine + .apply_template(&[dkn_protocol::ChatMessage { + role: "user".into(), + content: dkn_protocol::MessageContent::Text("Hello!".into()), + }]) + .unwrap_or_else(|_| "Hello!".into()); + + let params = GenerateParams { + max_tokens: 64, + temperature: 0.7, + ..Default::default() + }; + + print!(" > "); + let result = engine.generate(&prompt, ¶ms, |token| { + print!("{}", token.text); + let _ = io::stdout().flush(); + ControlFlow::Continue(()) + }); + println!(); + + result.map(|r| (r, model_name_owned)) + }) + .await + .map_err(|e| NodeError::Inference(format!("task join error: {e}")))? + ?; + + let (inference_result, model_name_final) = result; + + println!(); + println!( + " Model working! {:.1} tok/s", + inference_result.tokens_per_second + ); + println!(); + println!(" To start the node:"); + println!( + " dria-node start --wallet --model {}", + model_name_final + ); + println!(); + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_detect_ram_returns_something() { + // On CI / local machines this should return Some on Linux/macOS/Windows + let ram = detect_ram_bytes(); + if cfg!(any( + target_os = "linux", + target_os = "macos", + target_os = "windows" + )) { + assert!(ram.is_some(), "should detect RAM on this platform"); + assert!(ram.unwrap() > 0); + } + } + + #[test] + fn test_extract_quant() { + assert_eq!(extract_quant("Qwen3.5-9B-Q4_K_M.gguf"), "Q4_K_M"); + assert_eq!(extract_quant("LFM2.5-VL-1.6B-Q4_0.gguf"), "Q4_0"); + assert_eq!(extract_quant("model.gguf"), "model"); + } + + #[test] + fn test_model_size_known() { + assert!(model_size_gb("lfm2.5:1.2b").is_some()); + assert!(model_size_gb("qwen3.5:9b").is_some()); + assert!(model_size_gb("nonexistent:1b").is_none()); + } + + #[test] + fn test_model_size_ordering() { + // RAM needed should always be >= size + for name in [ + "lfm2.5:1.2b", + "nanbeige:3b", + "locooperator:4b", + "lfm2.5-vl:1.6b", + "lfm2.5-audio:1.5b", + "qwen3.5:9b", + "qwen3.5:27b", + "qwen3.5:35b-a3b", + "lfm2:24b-a2b", + ] { + let (size, needed) = model_size_gb(name).unwrap(); + assert!( + needed >= size, + "{name}: ram_needed ({needed}) should be >= size ({size})" + ); + } + } + + #[test] + fn test_all_registry_models_have_sizes() { + let registry = default_registry(); + for name in registry.keys() { + assert!( + model_size_gb(name).is_some(), + "missing size estimate for registry model: {name}" + ); + } + } +} From 667041e8dee18559270b3a94935e7fb25f54ad98 Mon Sep 17 00:00:00 2001 From: andthattoo Date: Wed, 4 Mar 2026 23:21:33 +0300 Subject: [PATCH 14/57] Remove obsolete Docker image workflows --- .github/workflows/image-dev.yml | 70 ----------------------------- .github/workflows/image-release.yml | 62 ------------------------- 2 files changed, 132 deletions(-) delete mode 100644 .github/workflows/image-dev.yml delete mode 100644 .github/workflows/image-release.yml diff --git a/.github/workflows/image-dev.yml b/.github/workflows/image-dev.yml deleted file mode 100644 index d77d1255..00000000 --- a/.github/workflows/image-dev.yml +++ /dev/null @@ -1,70 +0,0 @@ -name: dev-image -on: - push: - branches: ["master"] - paths: - # Source files in each member - - "compute/src/**" - - "p2p/src/**" - - "utils/src/**" - - "executor/src/**" - # Cargo in each member - - "compute/Cargo.toml" - - "p2p/Cargo.toml" - - "utils/Cargo.toml" - - "executor/Cargo.toml" - # root-level changes - - "Cargo.lock" - - "Cross.toml" - - "Dockerfile" - - "compose.yml" - # workflow itself - - ".github/workflows/build_dev_container.yml" - -jobs: - build-and-push: - name: Build and Push - runs-on: ubuntu-latest - - steps: - - name: Checkout code - uses: actions/checkout@v3 - - - name: Set up QEMU - uses: docker/setup-qemu-action@v3 - - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v3 - - - name: Login to Docker Hub - uses: docker/login-action@v3 - with: - username: ${{ secrets.DOCKERHUB_USERNAME }} - password: ${{ secrets.DOCKERHUB_TOKEN }} - - - name: Get Unix Time - id: timestamp - run: echo "timestamp=$(date +%s)" >> $GITHUB_OUTPUT - - - name: Get SHA - id: sha - run: echo "sha=$(git rev-parse --short HEAD)" >> $GITHUB_OUTPUT - - - name: Get Branch Name - id: branch - run: echo "branch=$(echo ${GITHUB_REF#refs/heads/})" >> $GITHUB_OUTPUT - - - name: Set Image Tag - id: itag - run: echo "itag=${{ steps.branch.outputs.branch }}-${{ steps.sha.outputs.sha }}-${{ steps.timestamp.outputs.timestamp }}" >> $GITHUB_OUTPUT - - - name: Build and push - uses: docker/build-push-action@v6 - env: - IMAGE_TAG: ${{ steps.itag.outputs.itag }} - with: - platforms: linux/amd64, linux/arm64, linux/arm, linux/arm64v8 - push: true - tags: | - firstbatch/dkn-compute-node:unstable - firstbatch/dkn-compute-node:${{ env.IMAGE_TAG }} diff --git a/.github/workflows/image-release.yml b/.github/workflows/image-release.yml deleted file mode 100644 index c88e7977..00000000 --- a/.github/workflows/image-release.yml +++ /dev/null @@ -1,62 +0,0 @@ -name: release-image -on: - release: - types: [published] - workflow_dispatch: - -jobs: - check-release: - name: Check Release - runs-on: ubuntu-latest - outputs: - image_tag: ${{ steps.itag.outputs.itag }} - - steps: - - name: Checkout code - uses: actions/checkout@v3 - - - name: Set Image Tag - id: itag - run: | - CARGO_VERSION=$(awk '/\[workspace.package\]/ {flag=1} flag && /version =/ {print $3; flag=0}' Cargo.toml | sed 's/"//g') - IMAGE_TAG=v$CARGO_VERSION # set the image tag with "v" prefix - echo "Cargo.toml version: $CARGO_VERSION" - echo "Image tag: $IMAGE_TAG" - echo "itag=$IMAGE_TAG" >> $GITHUB_OUTPUT - - - name: Check Release Tag - run: | - if [[ ! "${{ steps.itag.outputs.itag }}" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then - echo "Release tag format is invalid. It should follow the pattern 'vMAJOR.MINOR.PATCH' (e.g., v1.0.0)." - exit 1 - fi - echo "Release tag format is valid." - - build-and-push: - name: Build and Push - needs: check-release - runs-on: ubuntu-latest - steps: - - name: Checkout code - uses: actions/checkout@v3 - - - name: Set up QEMU - uses: docker/setup-qemu-action@v3 - - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v3 - - - name: Login to Docker Hub - uses: docker/login-action@v3 - with: - username: ${{ secrets.DOCKERHUB_USERNAME }} - password: ${{ secrets.DOCKERHUB_TOKEN }} - - - name: Build and push - uses: docker/build-push-action@v6 - with: - platforms: linux/amd64, linux/arm64, linux/arm, linux/arm64v8 - push: true - tags: | - firstbatch/dkn-compute-node:latest - firstbatch/dkn-compute-node:${{ needs.check-release.outputs.image_tag }} From 8fce69997a04d194390a1b35689f0ec51010f484 Mon Sep 17 00:00:00 2001 From: andthattoo Date: Wed, 4 Mar 2026 23:41:58 +0300 Subject: [PATCH 15/57] Fix Linux musl builds: use rustls-tls for hf-hub, add cmake --- .github/workflows/releases.yml | 4 + Cargo.lock | 257 ++------------------------------- Cargo.toml | 2 +- 3 files changed, 14 insertions(+), 249 deletions(-) diff --git a/.github/workflows/releases.yml b/.github/workflows/releases.yml index a8ad74c1..cc53b6e9 100644 --- a/.github/workflows/releases.yml +++ b/.github/workflows/releases.yml @@ -76,6 +76,10 @@ jobs: - name: Checkout code uses: actions/checkout@v4 + - name: Install build dependencies (Linux) + if: matrix.osname == 'linux' + run: sudo apt-get update && sudo apt-get install -y cmake musl-tools + - name: Get the release version from the tag shell: bash run: echo "VERSION=${GITHUB_REF#refs/tags/}" >> $GITHUB_ENV diff --git a/Cargo.lock b/Cargo.lock index b3c604df..491ae242 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -171,12 +171,6 @@ version = "3.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1628fb46dfa0b37568d12e5edd512553eccf6a22a78e8bde00bb4aed84d5bdbf" -[[package]] -name = "byteorder" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" - [[package]] name = "bytes" version = "1.10.1" @@ -346,15 +340,6 @@ dependencies = [ "libc", ] -[[package]] -name = "crc32fast" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9481c1c90cbf2ac953f07c8d4a58aa3945c425b7185c9154d67a65e4230da511" -dependencies = [ - "cfg-if", -] - [[package]] name = "crunchy" version = "0.2.3" @@ -534,16 +519,6 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" -[[package]] -name = "errno" -version = "0.3.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "976dd42dc7e85965fe702eb8164f21f450704bdde31faefd6471dba214cb594e" -dependencies = [ - "libc", - "windows-sys 0.59.0", -] - [[package]] name = "fastbloom" version = "0.14.1" @@ -556,12 +531,6 @@ dependencies = [ "siphasher", ] -[[package]] -name = "fastrand" -version = "2.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" - [[package]] name = "find-msvc-tools" version = "0.1.9" @@ -577,37 +546,12 @@ dependencies = [ "glob", ] -[[package]] -name = "flate2" -version = "1.1.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "843fba2746e448b37e26a819579957415c8cef339bf08564fe8b7ddbd959573c" -dependencies = [ - "crc32fast", - "miniz_oxide", -] - [[package]] name = "fnv" version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" -[[package]] -name = "foreign-types" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" -dependencies = [ - "foreign-types-shared", -] - -[[package]] -name = "foreign-types-shared" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" - [[package]] name = "form_urlencoded" version = "1.2.1" @@ -806,11 +750,9 @@ checksum = "629d8f3bbeda9d148036d6b0de0a3ab947abd08ce90626327fc3547a49d59d97" dependencies = [ "dirs", "futures", - "http", "indicatif", "libc", "log", - "native-tls", "num_cpus", "rand 0.9.1", "reqwest", @@ -818,7 +760,6 @@ dependencies = [ "serde_json", "thiserror 2.0.12", "tokio", - "ureq", "windows-sys 0.60.2", ] @@ -918,22 +859,7 @@ dependencies = [ "tokio", "tokio-rustls", "tower-service", -] - -[[package]] -name = "hyper-tls" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" -dependencies = [ - "bytes", - "http-body-util", - "hyper", - "hyper-util", - "native-tls", - "tokio", - "tokio-native-tls", - "tower-service", + "webpki-roots", ] [[package]] @@ -1282,12 +1208,6 @@ dependencies = [ "libsecp256k1-core", ] -[[package]] -name = "linux-raw-sys" -version = "0.9.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd945864f07fe9f5371a27ad7b52a172b4b499999f1d97574c9fa68373937e12" - [[package]] name = "litemap" version = "0.7.5" @@ -1368,7 +1288,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3be647b768db090acb35d5ec5db2b0e1f1de11133ca123b9eacf5137868f892a" dependencies = [ "adler2", - "simd-adler32", ] [[package]] @@ -1382,23 +1301,6 @@ dependencies = [ "windows-sys 0.52.0", ] -[[package]] -name = "native-tls" -version = "0.2.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87de3442987e9dbec73158d5c715e7ad9072fda936bb03d19d7fa10e00520f0e" -dependencies = [ - "libc", - "log", - "openssl", - "openssl-probe 0.1.6", - "openssl-sys", - "schannel", - "security-framework 2.11.1", - "security-framework-sys", - "tempfile", -] - [[package]] name = "nom" version = "7.1.3" @@ -1471,56 +1373,12 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381" -[[package]] -name = "openssl" -version = "0.10.72" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fedfea7d58a1f73118430a55da6a286e7b044961736ce96a16a17068ea25e5da" -dependencies = [ - "bitflags", - "cfg-if", - "foreign-types", - "libc", - "once_cell", - "openssl-macros", - "openssl-sys", -] - -[[package]] -name = "openssl-macros" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "openssl-probe" -version = "0.1.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" - [[package]] name = "openssl-probe" version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe" -[[package]] -name = "openssl-sys" -version = "0.9.108" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e145e1651e858e820e4860f7b9c5e169bc1d8ce1c86043be79fa7b7634821847" -dependencies = [ - "cc", - "libc", - "pkg-config", - "vcpkg", -] - [[package]] name = "option-ext" version = "0.2.0" @@ -1561,12 +1419,6 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" -[[package]] -name = "pkg-config" -version = "0.3.32" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" - [[package]] name = "portable-atomic" version = "1.11.0" @@ -1823,24 +1675,25 @@ dependencies = [ "http-body-util", "hyper", "hyper-rustls", - "hyper-tls", "hyper-util", "ipnet", "js-sys", "log", "mime", - "native-tls", "once_cell", "percent-encoding", "pin-project-lite", + "quinn", + "rustls", "rustls-pemfile", + "rustls-pki-types", "serde", "serde_json", "serde_urlencoded", "sync_wrapper", "system-configuration", "tokio", - "tokio-native-tls", + "tokio-rustls", "tokio-util", "tower", "tower-service", @@ -1849,6 +1702,7 @@ dependencies = [ "wasm-bindgen-futures", "wasm-streams", "web-sys", + "webpki-roots", "windows-registry", ] @@ -1897,26 +1751,12 @@ version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" -[[package]] -name = "rustix" -version = "1.0.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c71e83d6afe7ff64890ec6b71d6a69bb8a610ab78ce364b3352876bb4c801266" -dependencies = [ - "bitflags", - "errno", - "libc", - "linux-raw-sys", - "windows-sys 0.59.0", -] - [[package]] name = "rustls" version = "0.23.28" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7160e3e10bf4535308537f3c4e1641468cd0e485175d6163087c0393c7d46643" dependencies = [ - "log", "once_cell", "ring", "rustls-pki-types", @@ -1931,10 +1771,10 @@ version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "612460d5f7bea540c490b2b6395d8e34a953e52b491accd6c86c8164c5932a63" dependencies = [ - "openssl-probe 0.2.1", + "openssl-probe", "rustls-pki-types", "schannel", - "security-framework 3.7.0", + "security-framework", ] [[package]] @@ -1970,7 +1810,7 @@ dependencies = [ "rustls-native-certs", "rustls-platform-verifier-android", "rustls-webpki", - "security-framework 3.7.0", + "security-framework", "security-framework-sys", "webpki-root-certs", "windows-sys 0.61.2", @@ -2023,19 +1863,6 @@ dependencies = [ "windows-sys 0.59.0", ] -[[package]] -name = "security-framework" -version = "2.11.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" -dependencies = [ - "bitflags", - "core-foundation 0.9.4", - "core-foundation-sys", - "libc", - "security-framework-sys", -] - [[package]] name = "security-framework" version = "3.7.0" @@ -2171,12 +1998,6 @@ dependencies = [ "libc", ] -[[package]] -name = "simd-adler32" -version = "0.3.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e320a6c5ad31d271ad523dcf3ad13e2767ad8b1cb8f047f75a8aeaf8da139da2" - [[package]] name = "siphasher" version = "1.0.2" @@ -2208,17 +2029,6 @@ dependencies = [ "windows-sys 0.52.0", ] -[[package]] -name = "socks" -version = "0.3.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0c3dbbd9ae980613c6dd8e28a9407b50509d3803b57624d5dfe8315218cd58b" -dependencies = [ - "byteorder", - "libc", - "winapi", -] - [[package]] name = "stable_deref_trait" version = "1.2.0" @@ -2289,19 +2099,6 @@ dependencies = [ "libc", ] -[[package]] -name = "tempfile" -version = "3.19.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7437ac7763b9b123ccf33c338a5cc1bac6f69b45a136c19bdd8a65e3916435bf" -dependencies = [ - "fastrand", - "getrandom 0.3.2", - "once_cell", - "rustix", - "windows-sys 0.59.0", -] - [[package]] name = "thiserror" version = "1.0.69" @@ -2424,16 +2221,6 @@ dependencies = [ "syn", ] -[[package]] -name = "tokio-native-tls" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" -dependencies = [ - "native-tls", - "tokio", -] - [[package]] name = "tokio-rustls" version = "0.26.2" @@ -2578,26 +2365,6 @@ version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" -[[package]] -name = "ureq" -version = "2.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02d1a66277ed75f640d608235660df48c8e3c19f3b4edb6a263315626cc3c01d" -dependencies = [ - "base64", - "flate2", - "log", - "native-tls", - "once_cell", - "rustls", - "rustls-pki-types", - "serde", - "serde_json", - "socks", - "url", - "webpki-roots", -] - [[package]] name = "url" version = "2.5.4" @@ -2643,12 +2410,6 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" -[[package]] -name = "vcpkg" -version = "0.2.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" - [[package]] name = "version_check" version = "0.9.5" diff --git a/Cargo.toml b/Cargo.toml index cd8aa120..63e02524 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,7 +26,7 @@ sha3 = "0.10" hex = "0.4" rand = "0.8" llama-cpp-2 = { version = "0.1.137", features = ["mtmd"] } -hf-hub = { version = "0.4", features = ["tokio"] } +hf-hub = { version = "0.4", default-features = false, features = ["tokio", "rustls-tls"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } thiserror = "2" From 60590ae7271f53a58ba99422bb39963b86f914b6 Mon Sep 17 00:00:00 2001 From: andthattoo Date: Wed, 4 Mar 2026 23:50:00 +0300 Subject: [PATCH 16/57] Switch Linux targets from musl to gnu (llama.cpp needs C++ compiler) --- .github/workflows/releases.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/releases.yml b/.github/workflows/releases.yml index cc53b6e9..32597aa6 100644 --- a/.github/workflows/releases.yml +++ b/.github/workflows/releases.yml @@ -44,14 +44,14 @@ jobs: runner: ubuntu-latest, osname: linux, arch: amd64, - target: x86_64-unknown-linux-musl, + target: x86_64-unknown-linux-gnu, command: build, } - { runner: ubuntu-latest, osname: linux, arch: amd64-noavx, - target: x86_64-unknown-linux-musl, + target: x86_64-unknown-linux-gnu, command: build, noavx: true, } @@ -59,7 +59,7 @@ jobs: runner: ubuntu-latest, osname: linux, arch: arm64, - target: aarch64-unknown-linux-musl, + target: aarch64-unknown-linux-gnu, command: build, build_args: --no-default-features, } @@ -78,7 +78,7 @@ jobs: - name: Install build dependencies (Linux) if: matrix.osname == 'linux' - run: sudo apt-get update && sudo apt-get install -y cmake musl-tools + run: sudo apt-get update && sudo apt-get install -y cmake - name: Get the release version from the tag shell: bash From bc043721086b2f881eca8d185ccb90c31b1809c5 Mon Sep 17 00:00:00 2001 From: andthattoo Date: Wed, 4 Mar 2026 23:56:00 +0300 Subject: [PATCH 17/57] Use native ARM64 runner for Linux arm64 build --- .github/workflows/releases.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/releases.yml b/.github/workflows/releases.yml index 32597aa6..74987927 100644 --- a/.github/workflows/releases.yml +++ b/.github/workflows/releases.yml @@ -56,7 +56,7 @@ jobs: noavx: true, } - { - runner: ubuntu-latest, + runner: ubuntu-24.04-arm, osname: linux, arch: arm64, target: aarch64-unknown-linux-gnu, From 91cbd3a4a16616092705602311ab85dc6c736597 Mon Sep 17 00:00:00 2001 From: andthattoo Date: Thu, 5 Mar 2026 00:15:38 +0300 Subject: [PATCH 18/57] Fix ARM64 Linux build: use cargo directly instead of cross --- .github/workflows/releases.yml | 15 +- Cargo.lock | 1114 +++++++++++++++++--------------- 2 files changed, 606 insertions(+), 523 deletions(-) diff --git a/.github/workflows/releases.yml b/.github/workflows/releases.yml index 74987927..37fbc0f7 100644 --- a/.github/workflows/releases.yml +++ b/.github/workflows/releases.yml @@ -93,7 +93,20 @@ jobs: echo "GGML_FMA=OFF" >> $GITHUB_ENV echo "GGML_F16C=OFF" >> $GITHUB_ENV - - name: Build binary + - name: Install Rust toolchain (native builds) + if: matrix.runner == 'ubuntu-24.04-arm' + uses: dtolnay/rust-toolchain@stable + + - name: Build binary (native) + if: matrix.runner == 'ubuntu-24.04-arm' + run: cargo build --bin dria-node --locked --release --target ${{ matrix.target }} ${{ matrix.build_args }} + + - name: Strip binary (native) + if: matrix.runner == 'ubuntu-24.04-arm' + run: strip target/${{ matrix.target }}/release/dria-node + + - name: Build binary (cross) + if: matrix.runner != 'ubuntu-24.04-arm' uses: houseabsolute/actions-rust-cross@v0 with: command: ${{ matrix.command }} diff --git a/Cargo.lock b/Cargo.lock index 491ae242..dbcf6403 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,35 +2,20 @@ # It is not intended for manual editing. version = 4 -[[package]] -name = "addr2line" -version = "0.24.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfbe277e56a376000877090da837660b4427aad530e3028d44e0bffe4f89a1c1" -dependencies = [ - "gimli", -] - -[[package]] -name = "adler2" -version = "2.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" - [[package]] name = "aho-corasick" -version = "1.1.3" +version = "1.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" +checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301" dependencies = [ "memchr", ] [[package]] name = "anstream" -version = "0.6.18" +version = "0.6.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8acc5369981196006228e28809f761875c0327210a891e941f4c683b3a99529b" +checksum = "43d5b281e737544384e969a5ccad3f1cdd24b48086a0fc1b2a5262a26b8f4f4a" dependencies = [ "anstyle", "anstyle-parse", @@ -43,44 +28,44 @@ dependencies = [ [[package]] name = "anstyle" -version = "1.0.10" +version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "55cc3b69f167a1ef2e161439aa98aed94e6028e5f9a59be9a6ffb47aef1651f9" +checksum = "5192cca8006f1fd4f7237516f40fa183bb07f8fbdfedaa0036de5ea9b0b45e78" [[package]] name = "anstyle-parse" -version = "0.2.6" +version = "0.2.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b2d16507662817a6a20a9ea92df6652ee4f94f914589377d69f3b21bc5798a9" +checksum = "4e7644824f0aa2c7b9384579234ef10eb7efb6a0deb83f9630a49594dd9c15c2" dependencies = [ "utf8parse", ] [[package]] name = "anstyle-query" -version = "1.1.2" +version = "1.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79947af37f4177cfead1110013d678905c37501914fba0efea834c3fe9a8d60c" +checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc" dependencies = [ - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] name = "anstyle-wincon" -version = "3.0.7" +version = "3.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca3534e77181a9cc07539ad51f2141fe32f6c3ffd4df76db8ad92346b003ae4e" +checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d" dependencies = [ "anstyle", - "once_cell", - "windows-sys 0.59.0", + "once_cell_polyfill", + "windows-sys 0.61.2", ] [[package]] name = "anyhow" -version = "1.0.98" +version = "1.0.102" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e16d2d3311acee920a9eb8d33b8cbc1787ce4a264e85f964c2404b969bdcd487" +checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" [[package]] name = "arrayref" @@ -96,24 +81,9 @@ checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" [[package]] name = "autocfg" -version = "1.4.0" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" - -[[package]] -name = "backtrace" -version = "0.3.74" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d82cb332cdfaed17ae235a638438ac4d4839913cc2af585c3c6746e8f8bee1a" -dependencies = [ - "addr2line", - "cfg-if", - "libc", - "miniz_oxide", - "object", - "rustc-demangle", - "windows-targets 0.52.6", -] +checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" [[package]] name = "base64" @@ -167,15 +137,15 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.17.0" +version = "3.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1628fb46dfa0b37568d12e5edd512553eccf6a22a78e8bde00bb4aed84d5bdbf" +checksum = "5d20789868f4b01b2f2caec9f5c4e0213b41e3e5702a50157d699ae31ced2fcb" [[package]] name = "bytes" -version = "1.10.1" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" +checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" [[package]] name = "cc" @@ -206,9 +176,9 @@ dependencies = [ [[package]] name = "cfg-if" -version = "1.0.0" +version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" [[package]] name = "cfg_aliases" @@ -269,18 +239,18 @@ checksum = "3a822ea5bc7590f9d40f1ba12c0dc3c2760f3482c6984db1573ad11031420831" [[package]] name = "cmake" -version = "0.1.54" +version = "0.1.57" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0" +checksum = "75443c44cd6b379beb8c5b45d85d0773baf31cce901fe7bb252f4eff3008ef7d" dependencies = [ "cc", ] [[package]] name = "colorchoice" -version = "1.0.3" +version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" +checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" [[package]] name = "combine" @@ -342,15 +312,15 @@ dependencies = [ [[package]] name = "crunchy" -version = "0.2.3" +version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43da5946c66ffcc7745f48db692ffbb10a83bfe0afd96235c5c2a4fb23994929" +checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5" [[package]] name = "crypto-common" -version = "0.1.6" +version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" +checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a" dependencies = [ "generic-array", "typenum", @@ -435,7 +405,7 @@ dependencies = [ "quinn", "rmp-serde", "serde", - "thiserror 2.0.12", + "thiserror 2.0.18", "uuid", ] @@ -464,7 +434,7 @@ dependencies = [ "serde_json", "sha2 0.10.9", "sha3", - "thiserror 2.0.12", + "thiserror 2.0.18", "tokio", "tokio-util", "tracing", @@ -519,15 +489,25 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" +[[package]] +name = "errno" +version = "0.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" +dependencies = [ + "libc", + "windows-sys 0.61.2", +] + [[package]] name = "fastbloom" version = "0.14.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4e7f34442dbe69c60fe8eaf58a8cafff81a1f278816d8ab4db255b3bef4ac3c4" dependencies = [ - "getrandom 0.3.2", + "getrandom 0.3.4", "libm", - "rand 0.9.1", + "rand 0.9.2", "siphasher", ] @@ -552,20 +532,26 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "foldhash" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" + [[package]] name = "form_urlencoded" -version = "1.2.1" +version = "1.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" +checksum = "cb4cb245038516f5f85277875cdaa4f7d2c9a0fa0468de06ed190163b1581fcf" dependencies = [ "percent-encoding", ] [[package]] name = "futures" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" +checksum = "8b147ee9d1f6d097cef9ce628cd2ee62288d963e16fb287bd9286455b241382d" dependencies = [ "futures-channel", "futures-core", @@ -578,9 +564,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" +checksum = "07bbe89c50d7a535e539b8c17bc0b49bdb77747034daa8087407d655f3f7cc1d" dependencies = [ "futures-core", "futures-sink", @@ -588,15 +574,15 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" +checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d" [[package]] name = "futures-executor" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" +checksum = "baf29c38818342a3b26b5b923639e7b1f4a61fc5e76102d4b1981c6dc7a7579d" dependencies = [ "futures-core", "futures-task", @@ -605,15 +591,15 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" +checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718" [[package]] name = "futures-macro" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +checksum = "e835b70203e41293343137df5c0664546da5745f82ec9b84d40be8336958447b" dependencies = [ "proc-macro2", "quote", @@ -622,21 +608,21 @@ dependencies = [ [[package]] name = "futures-sink" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" +checksum = "c39754e157331b013978ec91992bde1ac089843443c49cbc7f46150b0fad0893" [[package]] name = "futures-task" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" +checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393" [[package]] name = "futures-util" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" +checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6" dependencies = [ "futures-channel", "futures-core", @@ -646,7 +632,6 @@ dependencies = [ "futures-task", "memchr", "pin-project-lite", - "pin-utils", "slab", ] @@ -662,36 +647,43 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.16" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592" +checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0" dependencies = [ "cfg-if", "js-sys", "libc", - "wasi 0.11.0+wasi-snapshot-preview1", + "wasi", "wasm-bindgen", ] [[package]] name = "getrandom" -version = "0.3.2" +version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "73fea8450eea4bac3940448fb7ae50d91f034f941199fcd9d909a5a07aa455f0" +checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" dependencies = [ "cfg-if", "js-sys", "libc", - "r-efi", - "wasi 0.14.2+wasi-0.2.4", + "r-efi 5.3.0", + "wasip2", "wasm-bindgen", ] [[package]] -name = "gimli" -version = "0.31.1" +name = "getrandom" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" +checksum = "0de51e6874e94e7bf76d726fc5d13ba782deca734ff60d5bb2fb2607c7406555" +dependencies = [ + "cfg-if", + "libc", + "r-efi 6.0.0", + "wasip2", + "wasip3", +] [[package]] name = "glob" @@ -701,9 +693,9 @@ checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280" [[package]] name = "h2" -version = "0.4.9" +version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75249d144030531f8dee69fe9cea04d3edf809a017ae445e2abdff6629e86633" +checksum = "2f44da3a8150a6703ed5d34e164b875fd14c2cdab9af1252a9a1020bde2bdc54" dependencies = [ "atomic-waker", "bytes", @@ -720,9 +712,18 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.15.3" +version = "0.15.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "84b26c544d002229e640969970a2e74021aadf6e2f96372b9c58eff97de08eb3" +checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" +dependencies = [ + "foldhash", +] + +[[package]] +name = "hashbrown" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" [[package]] name = "heck" @@ -732,9 +733,9 @@ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" [[package]] name = "hermit-abi" -version = "0.3.9" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" +checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c" [[package]] name = "hex" @@ -754,11 +755,11 @@ dependencies = [ "libc", "log", "num_cpus", - "rand 0.9.1", + "rand 0.9.2", "reqwest", "serde", "serde_json", - "thiserror 2.0.12", + "thiserror 2.0.18", "tokio", "windows-sys 0.60.2", ] @@ -786,12 +787,11 @@ dependencies = [ [[package]] name = "http" -version = "1.3.1" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4a85d31aea989eead29a3aaf9e1115a180df8282431156e533de47660892565" +checksum = "e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a" dependencies = [ "bytes", - "fnv", "itoa", ] @@ -826,19 +826,21 @@ checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" [[package]] name = "hyper" -version = "1.6.0" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc2b571658e38e0c01b1fdca3bbbe93c00d3d71693ff2770043f8c29bc7d6f80" +checksum = "2ab2d4f250c3d7b1c9fcdff1cece94ea4e2dfbec68614f7b87cb205f24ca9d11" dependencies = [ + "atomic-waker", "bytes", "futures-channel", - "futures-util", + "futures-core", "h2", "http", "http-body", "httparse", "itoa", "pin-project-lite", + "pin-utils", "smallvec", "tokio", "want", @@ -846,11 +848,10 @@ dependencies = [ [[package]] name = "hyper-rustls" -version = "0.27.5" +version = "0.27.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2d191583f3da1305256f22463b9bb0471acad48a4e534a5218b9963e9c1f59b2" +checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58" dependencies = [ - "futures-util", "http", "hyper", "hyper-util", @@ -864,41 +865,47 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.11" +version = "0.1.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "497bbc33a26fdd4af9ed9c70d63f61cf56a938375fbb32df34db9b1cd6d643f2" +checksum = "96547c2556ec9d12fb1578c4eaf448b04993e7fb79cbaad930a656880a6bdfa0" dependencies = [ + "base64", "bytes", "futures-channel", "futures-util", "http", "http-body", "hyper", + "ipnet", "libc", + "percent-encoding", "pin-project-lite", "socket2", + "system-configuration", "tokio", "tower-service", "tracing", + "windows-registry", ] [[package]] name = "icu_collections" -version = "1.5.0" +version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db2fa452206ebee18c4b5c2274dbf1de17008e874b4dc4f0aea9d01ca79e4526" +checksum = "4c6b649701667bbe825c3b7e6388cb521c23d88644678e83c0c4d0a621a34b43" dependencies = [ "displaydoc", + "potential_utf", "yoke", "zerofrom", "zerovec", ] [[package]] -name = "icu_locid" -version = "1.5.0" +name = "icu_locale_core" +version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13acbb8371917fc971be86fc8057c41a64b521c184808a698c02acc242dbf637" +checksum = "edba7861004dd3714265b4db54a3c390e880ab658fec5f7db895fae2046b5bb6" dependencies = [ "displaydoc", "litemap", @@ -907,104 +914,72 @@ dependencies = [ "zerovec", ] -[[package]] -name = "icu_locid_transform" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01d11ac35de8e40fdeda00d9e1e9d92525f3f9d887cdd7aa81d727596788b54e" -dependencies = [ - "displaydoc", - "icu_locid", - "icu_locid_transform_data", - "icu_provider", - "tinystr", - "zerovec", -] - -[[package]] -name = "icu_locid_transform_data" -version = "1.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7515e6d781098bf9f7205ab3fc7e9709d34554ae0b21ddbcb5febfa4bc7df11d" - [[package]] name = "icu_normalizer" -version = "1.5.0" +version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19ce3e0da2ec68599d193c93d088142efd7f9c5d6fc9b803774855747dc6a84f" +checksum = "5f6c8828b67bf8908d82127b2054ea1b4427ff0230ee9141c54251934ab1b599" dependencies = [ - "displaydoc", "icu_collections", "icu_normalizer_data", "icu_properties", "icu_provider", "smallvec", - "utf16_iter", - "utf8_iter", - "write16", "zerovec", ] [[package]] name = "icu_normalizer_data" -version = "1.5.1" +version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c5e8338228bdc8ab83303f16b797e177953730f601a96c25d10cb3ab0daa0cb7" +checksum = "7aedcccd01fc5fe81e6b489c15b247b8b0690feb23304303a9e560f37efc560a" [[package]] name = "icu_properties" -version = "1.5.1" +version = "2.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93d6020766cfc6302c15dbbc9c8778c37e62c14427cb7f6e601d849e092aeef5" +checksum = "020bfc02fe870ec3a66d93e677ccca0562506e5872c650f893269e08615d74ec" dependencies = [ - "displaydoc", "icu_collections", - "icu_locid_transform", + "icu_locale_core", "icu_properties_data", "icu_provider", - "tinystr", + "zerotrie", "zerovec", ] [[package]] name = "icu_properties_data" -version = "1.5.1" +version = "2.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85fb8799753b75aee8d2a21d7c14d9f38921b54b3dbda10f5a3c7a7b82dba5e2" +checksum = "616c294cf8d725c6afcd8f55abc17c56464ef6211f9ed59cccffe534129c77af" [[package]] name = "icu_provider" -version = "1.5.0" +version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ed421c8a8ef78d3e2dbc98a973be2f3770cb42b606e3ab18d6237c4dfde68d9" +checksum = "85962cf0ce02e1e0a629cc34e7ca3e373ce20dda4c4d7294bbd0bf1fdb59e614" dependencies = [ "displaydoc", - "icu_locid", - "icu_provider_macros", - "stable_deref_trait", - "tinystr", + "icu_locale_core", "writeable", "yoke", "zerofrom", + "zerotrie", "zerovec", ] [[package]] -name = "icu_provider_macros" -version = "1.5.0" +name = "id-arena" +version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] +checksum = "3d3067d79b975e8844ca9eb072e16b31c3c1c36928edf9c6789548c524d0d954" [[package]] name = "idna" -version = "1.0.3" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "686f825264d630750a544639377bae737628043f20d38bbc029e8f29ea968a7e" +checksum = "3b0875f23caa03898994f6ddc501886a45c7d3d62d04d2d90788d47be1b1e4de" dependencies = [ "idna_adapter", "smallvec", @@ -1013,9 +988,9 @@ dependencies = [ [[package]] name = "idna_adapter" -version = "1.2.0" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "daca1df1c957320b2cf139ac61e7bd64fed304c5040df000a745aa1de3b4ef71" +checksum = "3acae9609540aa318d1bc588455225fb2085b9ed0c4f6bd0d9d5bcd86f1a0344" dependencies = [ "icu_normalizer", "icu_properties", @@ -1023,12 +998,14 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.9.0" +version = "2.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cea70ddb795996207ad57735b50c5982d8844f38ba9ee5f1aedcfb708a2aa11e" +checksum = "7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017" dependencies = [ "equivalent", - "hashbrown", + "hashbrown 0.16.1", + "serde", + "serde_core", ] [[package]] @@ -1046,30 +1023,40 @@ dependencies = [ [[package]] name = "ipnet" -version = "2.11.0" +version = "2.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" +checksum = "d98f6fed1fde3f8c21bc40a1abb88dd75e67924f9cffc3ef95607bad8017f8e2" + +[[package]] +name = "iri-string" +version = "0.7.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c91338f0783edbd6195decb37bae672fd3b165faffb89bf7b9e6942f8b1a731a" +dependencies = [ + "memchr", + "serde", +] [[package]] name = "is_terminal_polyfill" -version = "1.70.1" +version = "1.70.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" +checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695" [[package]] name = "itertools" -version = "0.12.1" +version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" dependencies = [ "either", ] [[package]] name = "itoa" -version = "1.0.15" +version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" +checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2" [[package]] name = "jni" @@ -1099,15 +1086,15 @@ version = "0.1.34" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9afb3de4395d6b3e67a780b6de64b51c978ecf11cb9a462c66be7d4ca9039d33" dependencies = [ - "getrandom 0.3.2", + "getrandom 0.3.4", "libc", ] [[package]] name = "js-sys" -version = "0.3.77" +version = "0.3.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1cfaf33c695fc6e08064efbc1f72ec937429614f25eef83af942d0e227c3a28f" +checksum = "b49715b7073f385ba4bc528e5747d02e66cb39c6146efb66b781f131f0fb399c" dependencies = [ "once_cell", "wasm-bindgen", @@ -1115,9 +1102,9 @@ dependencies = [ [[package]] name = "keccak" -version = "0.1.5" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ecc2af9a1119c51f12a14607e783cb977bde58bc069ff0c3da1095e635d70654" +checksum = "cb26cec98cce3a3d96cbb7bced3c4b16e3d13f27ec56dbd62cbc8f39cfb9d653" dependencies = [ "cpufeatures", ] @@ -1128,11 +1115,17 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" +[[package]] +name = "leb128fmt" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2" + [[package]] name = "libc" -version = "0.2.172" +version = "0.2.182" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d750af042f7ef4f724306de029d18836c26c1765a54a6a3f094cbd23a7267ffa" +checksum = "6800badb6cb2082ffd7b6a67e6125bb39f18782f793520caee8cb8846be06112" [[package]] name = "libloading" @@ -1141,7 +1134,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d7c4b02199fee7c5d21a5ae7d8cfa79a6ef5bb2fc834d6e9058e89c825efdc55" dependencies = [ "cfg-if", - "windows-link 0.2.1", + "windows-link", ] [[package]] @@ -1152,11 +1145,10 @@ checksum = "b6d2cec3eae94f9f509c767b45932f1ada8350c4bdb85af2fcab4a3c14807981" [[package]] name = "libredox" -version = "0.1.3" +version = "0.1.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d" +checksum = "1744e39d1d6a9948f4f388969627434e31128196de472883b39f148769bfe30a" dependencies = [ - "bitflags", "libc", ] @@ -1210,29 +1202,29 @@ dependencies = [ [[package]] name = "litemap" -version = "0.7.5" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23fb14cb19457329c82206317a5663005a4d404783dc74f4252769b0d5f42856" +checksum = "6373607a59f0be73a39b6fe456b8192fcc3585f602af20751600e974dd455e77" [[package]] name = "llama-cpp-2" -version = "0.1.137" +version = "0.1.138" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aedc4f4ca22ad992bc43fe20734b3a0f37363b9621419727821bf6572b9c0395" +checksum = "2947ab625c59d1fdf42e61f538c3fa66f43de2f78316971920873f359483d1d8" dependencies = [ "encoding_rs", "enumflags2", "llama-cpp-sys-2", - "thiserror 2.0.12", + "thiserror 2.0.18", "tracing", "tracing-core", ] [[package]] name = "llama-cpp-sys-2" -version = "0.1.137" +version = "0.1.138" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da365e84fbe4d10e849fa3bfd5a0d70b3b4a59e8c5adc8b7be5c189327566bdb" +checksum = "84a529006bf16af70c7485ba957820dc2bc9467d75697e97970c81d2da73c76f" dependencies = [ "bindgen", "cc", @@ -1244,9 +1236,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.27" +version = "0.4.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94" +checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" [[package]] name = "lru-slab" @@ -1256,18 +1248,18 @@ checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" [[package]] name = "matchers" -version = "0.1.0" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558" +checksum = "d1525a2a28c7f4fa0fc98bb91ae755d1e2d1505079e05539e35bc876b5d65ae9" dependencies = [ - "regex-automata 0.1.10", + "regex-automata", ] [[package]] name = "memchr" -version = "2.7.4" +version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" +checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" [[package]] name = "mime" @@ -1281,24 +1273,15 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" -[[package]] -name = "miniz_oxide" -version = "0.8.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3be647b768db090acb35d5ec5db2b0e1f1de11133ca123b9eacf5137868f892a" -dependencies = [ - "adler2", -] - [[package]] name = "mio" -version = "1.0.3" +version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2886843bf800fba2e3377cff24abf6379b4c4d5c6681eaf9ea5b0d15090450bd" +checksum = "a69bcab0ad47271a0234d9422b131806bf3968021e5dc9328caf2d4cd58557fc" dependencies = [ "libc", - "wasi 0.11.0+wasi-snapshot-preview1", - "windows-sys 0.52.0", + "wasi", + "windows-sys 0.61.2", ] [[package]] @@ -1313,12 +1296,11 @@ dependencies = [ [[package]] name = "nu-ansi-term" -version = "0.46.0" +version = "0.50.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" +checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" dependencies = [ - "overload", - "winapi", + "windows-sys 0.61.2", ] [[package]] @@ -1338,9 +1320,9 @@ dependencies = [ [[package]] name = "num_cpus" -version = "1.16.0" +version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" +checksum = "91df4bbde75afed763b708b7eee1e8e7651e02d97f6d5dd763e89367e957b23b" dependencies = [ "hermit-abi", "libc", @@ -1352,21 +1334,18 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" -[[package]] -name = "object" -version = "0.36.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62948e14d923ea95ea2c7c86c71013138b66525b86bdc08d2dcc262bdb497b87" -dependencies = [ - "memchr", -] - [[package]] name = "once_cell" version = "1.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" +[[package]] +name = "once_cell_polyfill" +version = "1.70.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" + [[package]] name = "opaque-debug" version = "0.3.1" @@ -1385,12 +1364,6 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" -[[package]] -name = "overload" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" - [[package]] name = "pem" version = "3.0.6" @@ -1403,15 +1376,15 @@ dependencies = [ [[package]] name = "percent-encoding" -version = "2.3.1" +version = "2.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" +checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" [[package]] name = "pin-project-lite" -version = "0.2.16" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" +checksum = "a89322df9ebe1c1578d689c92318e070967d1042b512afbe49518723f4e6d5cd" [[package]] name = "pin-utils" @@ -1421,9 +1394,18 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" [[package]] name = "portable-atomic" -version = "1.11.0" +version = "1.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "350e9b48cbc6b0e028b0473b114454c6316e57336ee184ceab6e53f72c178b3e" +checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49" + +[[package]] +name = "potential_utf" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b73949432f5e2a09657003c25bca5e19a0e9c84f8058ca374f49e0ebe605af77" +dependencies = [ + "zerovec", +] [[package]] name = "powerfmt" @@ -1442,9 +1424,9 @@ dependencies = [ [[package]] name = "prettyplease" -version = "0.2.34" +version = "0.2.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6837b9e10d61f45f987d50808f83d1ee3d206c66acf650c3e4ae2e1f6ddedf55" +checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" dependencies = [ "proc-macro2", "syn", @@ -1452,9 +1434,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.95" +version = "1.0.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02b3e5e68a3a1a02aad3ec490a98007cbc13c37cbe84a3cd7b8e406d76e7f778" +checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" dependencies = [ "unicode-ident", ] @@ -1473,7 +1455,7 @@ dependencies = [ "rustc-hash", "rustls", "socket2", - "thiserror 2.0.12", + "thiserror 2.0.18", "tokio", "tracing", "web-time", @@ -1487,16 +1469,16 @@ checksum = "f1906b49b0c3bc04b5fe5d86a77925ae6524a19b816ae38ce1e426255f1d8a31" dependencies = [ "bytes", "fastbloom", - "getrandom 0.3.2", + "getrandom 0.3.4", "lru-slab", - "rand 0.9.1", + "rand 0.9.2", "ring", "rustc-hash", "rustls", "rustls-pki-types", "rustls-platform-verifier", "slab", - "thiserror 2.0.12", + "thiserror 2.0.18", "tinyvec", "tracing", "web-time", @@ -1518,18 +1500,24 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.40" +version = "1.0.45" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d" +checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924" dependencies = [ "proc-macro2", ] [[package]] name = "r-efi" -version = "5.2.0" +version = "5.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74765f6d916ee2faa39bc8e68e4f3ed8949b48cccdac59983d287a7cb71ce9c5" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + +[[package]] +name = "r-efi" +version = "6.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8dcc9c7d52a811697d2151c701e0d08956f92b0e24136cf4cf27b57a6a0d9bf" [[package]] name = "rand" @@ -1544,12 +1532,12 @@ dependencies = [ [[package]] name = "rand" -version = "0.9.1" +version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fbfd9d094a40bf3ae768db9361049ace4c0e04a4fd6b359518bd7b73a73dd97" +checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" dependencies = [ "rand_chacha 0.9.0", - "rand_core 0.9.3", + "rand_core 0.9.5", ] [[package]] @@ -1569,7 +1557,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" dependencies = [ "ppv-lite86", - "rand_core 0.9.3", + "rand_core 0.9.5", ] [[package]] @@ -1578,16 +1566,16 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom 0.2.16", + "getrandom 0.2.17", ] [[package]] name = "rand_core" -version = "0.9.3" +version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38" +checksum = "76afc826de14238e6e8c374ddcc1fa19e374fd8dd986b0d2af0d02377261d83c" dependencies = [ - "getrandom 0.3.2", + "getrandom 0.3.4", ] [[package]] @@ -1609,60 +1597,45 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a4e608c6638b9c18977b00b475ac1f28d14e84b27d8d42f70e0bf1e3dec127ac" dependencies = [ - "getrandom 0.2.16", + "getrandom 0.2.17", "libredox", - "thiserror 2.0.12", + "thiserror 2.0.18", ] [[package]] name = "regex" -version = "1.11.1" +version = "1.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" +checksum = "e10754a14b9137dd7b1e3e5b0493cc9171fdd105e0ab477f51b72e7f3ac0e276" dependencies = [ "aho-corasick", "memchr", - "regex-automata 0.4.9", - "regex-syntax 0.8.5", + "regex-automata", + "regex-syntax", ] [[package]] name = "regex-automata" -version = "0.1.10" +version = "0.4.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" -dependencies = [ - "regex-syntax 0.6.29", -] - -[[package]] -name = "regex-automata" -version = "0.4.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" +checksum = "6e1dd4122fc1595e8162618945476892eefca7b88c52820e74af6262213cae8f" dependencies = [ "aho-corasick", "memchr", - "regex-syntax 0.8.5", + "regex-syntax", ] [[package]] name = "regex-syntax" -version = "0.6.29" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" - -[[package]] -name = "regex-syntax" -version = "0.8.5" +version = "0.8.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" +checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a" [[package]] name = "reqwest" -version = "0.12.15" +version = "0.12.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d19c46a6fdd48bc4dab94b6103fccc55d34c67cc0ad04653aad4ea2a07cd7bbb" +checksum = "eddd3ca559203180a307f12d114c268abf583f59b03cb906fd0b3ff8646c1147" dependencies = [ "base64", "bytes", @@ -1676,26 +1649,23 @@ dependencies = [ "hyper", "hyper-rustls", "hyper-util", - "ipnet", "js-sys", "log", "mime", - "once_cell", "percent-encoding", "pin-project-lite", "quinn", "rustls", - "rustls-pemfile", "rustls-pki-types", "serde", "serde_json", "serde_urlencoded", "sync_wrapper", - "system-configuration", "tokio", "tokio-rustls", "tokio-util", "tower", + "tower-http", "tower-service", "url", "wasm-bindgen", @@ -1703,7 +1673,6 @@ dependencies = [ "wasm-streams", "web-sys", "webpki-roots", - "windows-registry", ] [[package]] @@ -1714,7 +1683,7 @@ checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7" dependencies = [ "cc", "cfg-if", - "getrandom 0.2.16", + "getrandom 0.2.17", "libc", "untrusted", "windows-sys 0.52.0", @@ -1739,12 +1708,6 @@ dependencies = [ "serde", ] -[[package]] -name = "rustc-demangle" -version = "0.1.24" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" - [[package]] name = "rustc-hash" version = "2.1.1" @@ -1753,9 +1716,9 @@ checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" [[package]] name = "rustls" -version = "0.23.28" +version = "0.23.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7160e3e10bf4535308537f3c4e1641468cd0e485175d6163087c0393c7d46643" +checksum = "758025cb5fccfd3bc2fd74708fd4682be41d99e5dff73c377c0646c6012c73a4" dependencies = [ "once_cell", "ring", @@ -1777,22 +1740,14 @@ dependencies = [ "security-framework", ] -[[package]] -name = "rustls-pemfile" -version = "2.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dce314e5fee3f39953d46bb63bb8a46d40c2f8fb7cc5a3b6cab2bde9721d6e50" -dependencies = [ - "rustls-pki-types", -] - [[package]] name = "rustls-pki-types" -version = "1.11.0" +version = "1.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "917ce264624a4b4db1c364dcc35bfca9ded014d0a958cd47ad3e960e988ea51c" +checksum = "be040f8b0a225e40375822a563fa9524378b9d63112f53e19ffff34df5d33fdd" dependencies = [ "web-time", + "zeroize", ] [[package]] @@ -1824,9 +1779,9 @@ checksum = "f87165f0995f63a9fbeea62b64d10b4d9d8e78ec6d7d51fb2125fda7bb36788f" [[package]] name = "rustls-webpki" -version = "0.103.3" +version = "0.103.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e4a72fe2bcf7a6ac6fd7d0b9e5cb68aeb7d4c0a0271730218b3e92d43b4eb435" +checksum = "d7df23109aa6c1567d1c575b9952556388da57401e4ace1d15f79eedad0d8f53" dependencies = [ "ring", "rustls-pki-types", @@ -1835,15 +1790,15 @@ dependencies = [ [[package]] name = "rustversion" -version = "1.0.20" +version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eded382c5f5f786b989652c49544c4877d9f015cc22e145a5ea8ea66c2921cd2" +checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" [[package]] name = "ryu" -version = "1.0.20" +version = "1.0.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" +checksum = "9774ba4a74de5f7b1c1451ed6cd5285a32eddb5cccb8cc655a4e50009e06477f" [[package]] name = "same-file" @@ -1856,11 +1811,11 @@ dependencies = [ [[package]] name = "schannel" -version = "0.1.27" +version = "0.1.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1f29ebaa345f945cec9fbbc532eb307f0fdad8161f281b6369539c8d84876b3d" +checksum = "891d81b926048e76efe18581bf793546b4c0eaf8448d72be8de2bbee5fd166e1" dependencies = [ - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -1886,6 +1841,12 @@ dependencies = [ "libc", ] +[[package]] +name = "semver" +version = "1.0.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2" + [[package]] name = "serde" version = "1.0.228" @@ -1918,14 +1879,15 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.140" +version = "1.0.149" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373" +checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" dependencies = [ "itoa", "memchr", - "ryu", "serde", + "serde_core", + "zmij", ] [[package]] @@ -1991,10 +1953,11 @@ checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" [[package]] name = "signal-hook-registry" -version = "1.4.5" +version = "1.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9203b8055f63a2a00e2f593bb0510367fe707d7ff1e5c872de2f537b339e5410" +checksum = "c4db69cba1110affc0e9f7bcd48bbf87b3f4fc7c61fc9155afd4c469eb3d6c1b" dependencies = [ + "errno", "libc", ] @@ -2006,34 +1969,31 @@ checksum = "b2aa850e253778c88a04c3d7323b043aeda9d3e30d5971937c1855769763678e" [[package]] name = "slab" -version = "0.4.9" +version = "0.4.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f92a496fb766b417c996b9c5e57daf2f7ad3b0bebe1ccfca4856390e3d3bb67" -dependencies = [ - "autocfg", -] +checksum = "0c790de23124f9ab44544d7ac05d60440adc586479ce501c1d6d7da3cd8c9cf5" [[package]] name = "smallvec" -version = "1.15.0" +version = "1.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8917285742e9f3e1683f0a9c4e6b57960b7314d0b08d30d1ecd426713ee2eee9" +checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" [[package]] name = "socket2" -version = "0.5.9" +version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f5fd57c80058a56cf5c777ab8a126398ece8e442983605d280a44ce79d0edef" +checksum = "86f4aa3ad99f2088c990dfa82d367e19cb29268ed67c574d10d0a4bfe71f07e0" dependencies = [ "libc", - "windows-sys 0.52.0", + "windows-sys 0.60.2", ] [[package]] name = "stable_deref_trait" -version = "1.2.0" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" +checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596" [[package]] name = "strsim" @@ -2049,9 +2009,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" -version = "2.0.101" +version = "2.0.117" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ce2b7fc941b3a24138a0a7cf8e858bfc6a992e7978a068a5c760deb0ed43caf" +checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99" dependencies = [ "proc-macro2", "quote", @@ -2080,9 +2040,9 @@ dependencies = [ [[package]] name = "system-configuration" -version = "0.6.1" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b" +checksum = "a13f3d0daba03132c0aa9767f98351b3488edc2c100cda2d2ec2b04f3d8d3c8b" dependencies = [ "bitflags", "core-foundation 0.9.4", @@ -2110,11 +2070,11 @@ dependencies = [ [[package]] name = "thiserror" -version = "2.0.12" +version = "2.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "567b8a2dae586314f7be2a752ec7474332959c6460e02bde30d702a66d488708" +checksum = "4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4" dependencies = [ - "thiserror-impl 2.0.12", + "thiserror-impl 2.0.18", ] [[package]] @@ -2130,9 +2090,9 @@ dependencies = [ [[package]] name = "thiserror-impl" -version = "2.0.12" +version = "2.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d" +checksum = "ebc4ee7f67670e9b64d05fa4253e753e016c6c95ff35b89b7941d6b856dec1d5" dependencies = [ "proc-macro2", "quote", @@ -2141,12 +2101,11 @@ dependencies = [ [[package]] name = "thread_local" -version = "1.1.8" +version = "1.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b9ef9bad013ada3808854ceac7b46812a6465ba368859a37e2100283d2d719c" +checksum = "f60246a4944f24f6e018aa17cdeffb7818b76356965d03b07d6a9886e8962185" dependencies = [ "cfg-if", - "once_cell", ] [[package]] @@ -2170,9 +2129,9 @@ checksum = "7694e1cfe791f8d31026952abf09c69ca6f6fa4e1a1229e18988f06a04a12dca" [[package]] name = "tinystr" -version = "0.7.6" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9117f5d4db391c1cf6927e7bea3db74b9a1c1add8f7eda9ffd5364f40f57b82f" +checksum = "42d3e9c45c09de15d06dd8acf5f4e0e399e85927b7f00711024eb7ae10fa4869" dependencies = [ "displaydoc", "zerovec", @@ -2195,11 +2154,10 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.44.2" +version = "1.50.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6b88822cbe49de4185e3a4cbf8321dd487cf5fe0c5c65695fef6346371e9c48" +checksum = "27ad5e34374e03cfffefc301becb44e9dc3c17584f414349ebe29ed26661822d" dependencies = [ - "backtrace", "bytes", "libc", "mio", @@ -2207,14 +2165,14 @@ dependencies = [ "signal-hook-registry", "socket2", "tokio-macros", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] name = "tokio-macros" -version = "2.5.0" +version = "2.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8" +checksum = "5c55a2eff8b69ce66c84f85e1da1c233edc36ceb85a2058d11b0d6a3c7e7569c" dependencies = [ "proc-macro2", "quote", @@ -2223,9 +2181,9 @@ dependencies = [ [[package]] name = "tokio-rustls" -version = "0.26.2" +version = "0.26.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e727b36a1a0e8b74c376ac2211e40c2c8af09fb4013c60d910495810f008e9b" +checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61" dependencies = [ "rustls", "tokio", @@ -2233,24 +2191,23 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.15" +version = "0.7.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "66a539a9ad6d5d281510d5bd368c973d636c02dbf8a67300bfb6b950696ad7df" +checksum = "9ae9cec805b01e8fc3fd2fe289f89149a9b66dd16786abd8b19cfa7b48cb0098" dependencies = [ "bytes", "futures-core", "futures-sink", "futures-util", - "hashbrown", "pin-project-lite", "tokio", ] [[package]] name = "tower" -version = "0.5.2" +version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9" +checksum = "ebe5ef63511595f1344e2d5cfa636d973292adc0eec1f0ad45fae9f0851ab1d4" dependencies = [ "futures-core", "futures-util", @@ -2261,6 +2218,24 @@ dependencies = [ "tower-service", ] +[[package]] +name = "tower-http" +version = "0.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d4e6559d53cc268e5031cd8429d05415bc4cb4aefc4aa5d6cc35fbf5b924a1f8" +dependencies = [ + "bitflags", + "bytes", + "futures-util", + "http", + "http-body", + "iri-string", + "pin-project-lite", + "tower", + "tower-layer", + "tower-service", +] + [[package]] name = "tower-layer" version = "0.3.3" @@ -2275,9 +2250,9 @@ checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" [[package]] name = "tracing" -version = "0.1.41" +version = "0.1.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" +checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" dependencies = [ "log", "pin-project-lite", @@ -2287,9 +2262,9 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.28" +version = "0.1.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d" +checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" dependencies = [ "proc-macro2", "quote", @@ -2298,9 +2273,9 @@ dependencies = [ [[package]] name = "tracing-core" -version = "0.1.33" +version = "0.1.36" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e672c95779cf947c5311f83787af4fa8fffd12fb27e4993211a84bdfd9610f9c" +checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a" dependencies = [ "once_cell", "valuable", @@ -2319,14 +2294,14 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.19" +version = "0.3.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8189decb5ac0fa7bc8b96b7cb9b2701d60d48805aca84a238004d665fcc4008" +checksum = "2f30143827ddab0d256fd843b7a66d164e9f271cfa0dde49142c5ca0ca291f1e" dependencies = [ "matchers", "nu-ansi-term", "once_cell", - "regex", + "regex-automata", "sharded-slab", "smallvec", "thread_local", @@ -2343,15 +2318,15 @@ checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" [[package]] name = "typenum" -version = "1.18.0" +version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f" +checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" [[package]] name = "unicode-ident" -version = "1.0.18" +version = "1.0.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" +checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" [[package]] name = "unicode-width" @@ -2359,6 +2334,12 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254" +[[package]] +name = "unicode-xid" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" + [[package]] name = "untrusted" version = "0.9.0" @@ -2367,21 +2348,16 @@ checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" [[package]] name = "url" -version = "2.5.4" +version = "2.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32f8b686cadd1473f4bd0117a5d28d36b1ade384ea9b5069a1c40aefed7fda60" +checksum = "ff67a8a4397373c3ef660812acab3268222035010ab8680ec4215f38ba3d0eed" dependencies = [ "form_urlencoded", "idna", "percent-encoding", + "serde", ] -[[package]] -name = "utf16_iter" -version = "1.0.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8232dd3cdaed5356e0f716d285e4b40b932ac434100fe9b7e0e8e935b9e6246" - [[package]] name = "utf8_iter" version = "1.0.4" @@ -2396,12 +2372,14 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.16.0" +version = "1.21.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "458f7a779bf54acc9f347480ac654f68407d3aab21269a6e3c9f922acd9e2da9" +checksum = "b672338555252d43fd2240c714dc444b8c6fb0a5c5335e65a07bba7742735ddb" dependencies = [ - "getrandom 0.3.2", - "serde", + "getrandom 0.4.2", + "js-sys", + "serde_core", + "wasm-bindgen", ] [[package]] @@ -2437,52 +2415,49 @@ dependencies = [ [[package]] name = "wasi" -version = "0.11.0+wasi-snapshot-preview1" +version = "0.11.1+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" [[package]] -name = "wasi" -version = "0.14.2+wasi-0.2.4" +name = "wasip2" +version = "1.0.2+wasi-0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9683f9a5a998d873c0d21fcbe3c083009670149a8fab228644b8bd36b2c48cb3" +checksum = "9517f9239f02c069db75e65f174b3da828fe5f5b945c4dd26bd25d89c03ebcf5" dependencies = [ - "wit-bindgen-rt", + "wit-bindgen", ] [[package]] -name = "wasm-bindgen" -version = "0.2.100" +name = "wasip3" +version = "0.4.0+wasi-0.3.0-rc-2026-01-06" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1edc8929d7499fc4e8f0be2262a241556cfc54a0bea223790e71446f2aab1ef5" +checksum = "5428f8bf88ea5ddc08faddef2ac4a67e390b88186c703ce6dbd955e1c145aca5" dependencies = [ - "cfg-if", - "once_cell", - "rustversion", - "wasm-bindgen-macro", + "wit-bindgen", ] [[package]] -name = "wasm-bindgen-backend" -version = "0.2.100" +name = "wasm-bindgen" +version = "0.2.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f0a0651a5c2bc21487bde11ee802ccaf4c51935d0d3d42a6101f98161700bc6" +checksum = "6532f9a5c1ece3798cb1c2cfdba640b9b3ba884f5db45973a6f442510a87d38e" dependencies = [ - "bumpalo", - "log", - "proc-macro2", - "quote", - "syn", + "cfg-if", + "once_cell", + "rustversion", + "wasm-bindgen-macro", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-futures" -version = "0.4.50" +version = "0.4.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "555d470ec0bc3bb57890405e5d4322cc9ea83cebb085523ced7be4144dac1e61" +checksum = "e9c5522b3a28661442748e09d40924dfb9ca614b21c00d3fd135720e48b67db8" dependencies = [ "cfg-if", + "futures-util", "js-sys", "once_cell", "wasm-bindgen", @@ -2491,9 +2466,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.100" +version = "0.2.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7fe63fc6d09ed3792bd0897b314f53de8e16568c2b3f7982f468c0bf9bd0b407" +checksum = "18a2d50fcf105fb33bb15f00e7a77b772945a2ee45dcf454961fd843e74c18e6" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -2501,26 +2476,48 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.100" +version = "0.2.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de" +checksum = "03ce4caeaac547cdf713d280eda22a730824dd11e6b8c3ca9e42247b25c631e3" dependencies = [ + "bumpalo", "proc-macro2", "quote", "syn", - "wasm-bindgen-backend", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.100" +version = "0.2.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a05d73b933a847d6cccdda8f838a22ff101ad9bf93e33684f39c1f5f0eece3d" +checksum = "75a326b8c223ee17883a4251907455a2431acc2791c98c26279376490c378c16" dependencies = [ "unicode-ident", ] +[[package]] +name = "wasm-encoder" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "990065f2fe63003fe337b932cfb5e3b80e0b4d0f5ff650e6985b1048f62c8319" +dependencies = [ + "leb128fmt", + "wasmparser", +] + +[[package]] +name = "wasm-metadata" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb0e353e6a2fbdc176932bbaab493762eb1255a7900fe0fea1a2f96c296cc909" +dependencies = [ + "anyhow", + "indexmap", + "wasm-encoder", + "wasmparser", +] + [[package]] name = "wasm-streams" version = "0.4.2" @@ -2534,11 +2531,23 @@ dependencies = [ "web-sys", ] +[[package]] +name = "wasmparser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47b807c72e1bac69382b3a6fb3dbe8ea4c0ed87ff5629b8685ae6b9a611028fe" +dependencies = [ + "bitflags", + "hashbrown 0.15.5", + "indexmap", + "semver", +] + [[package]] name = "web-sys" -version = "0.3.77" +version = "0.3.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33b6dd2ef9186f1f2072e409e99cd22a975331a6b3591b12c764e0e55c60d5d2" +checksum = "854ba17bb104abfb26ba36da9729addc7ce7f06f5c0f90f3c391f8461cca21f9" dependencies = [ "js-sys", "wasm-bindgen", @@ -2565,29 +2574,13 @@ dependencies = [ [[package]] name = "webpki-roots" -version = "0.26.10" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37493cadf42a2a939ed404698ded7fb378bf301b5011f973361779a3a74f8c93" +checksum = "22cfaf3c063993ff62e73cb4311efde4db1efb31ab78a3e5c457939ad5cc0bed" dependencies = [ "rustls-pki-types", ] -[[package]] -name = "winapi" -version = "0.3.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" -dependencies = [ - "winapi-i686-pc-windows-gnu", - "winapi-x86_64-pc-windows-gnu", -] - -[[package]] -name = "winapi-i686-pc-windows-gnu" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" - [[package]] name = "winapi-util" version = "0.1.11" @@ -2597,18 +2590,6 @@ dependencies = [ "windows-sys 0.61.2", ] -[[package]] -name = "winapi-x86_64-pc-windows-gnu" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" - -[[package]] -name = "windows-link" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76840935b766e1b0a05c0066835fb9ec80071d4c09a16f6bd5f7e655e3c14c38" - [[package]] name = "windows-link" version = "0.2.1" @@ -2617,31 +2598,31 @@ checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" [[package]] name = "windows-registry" -version = "0.4.0" +version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4286ad90ddb45071efd1a66dfa43eb02dd0dfbae1545ad6cc3c51cf34d7e8ba3" +checksum = "02752bf7fbdcce7f2a27a742f798510f3e5ad88dbe84871e5168e2120c3d5720" dependencies = [ + "windows-link", "windows-result", "windows-strings", - "windows-targets 0.53.5", ] [[package]] name = "windows-result" -version = "0.3.2" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c64fd11a4fd95df68efcfee5f44a294fe71b8bc6a91993e2791938abcc712252" +checksum = "7781fa89eaf60850ac3d2da7af8e5242a5ea78d1a11c49bf2910bb5a73853eb5" dependencies = [ - "windows-link 0.1.1", + "windows-link", ] [[package]] name = "windows-strings" -version = "0.3.1" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87fa48cc5d406560701792be122a10132491cff9d0aeb23583cc2dcafc847319" +checksum = "7837d08f69c77cf6b07689544538e017c1bfcf57e34b4c0ff58e6c2cd3b37091" dependencies = [ - "windows-link 0.1.1", + "windows-link", ] [[package]] @@ -2686,7 +2667,7 @@ version = "0.61.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" dependencies = [ - "windows-link 0.2.1", + "windows-link", ] [[package]] @@ -2726,15 +2707,15 @@ version = "0.53.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4945f9f551b88e0d65f3db0bc25c33b8acea4d9e41163edf90dcd0b19f9069f3" dependencies = [ - "windows-link 0.2.1", - "windows_aarch64_gnullvm 0.53.0", - "windows_aarch64_msvc 0.53.0", - "windows_i686_gnu 0.53.0", - "windows_i686_gnullvm 0.53.0", - "windows_i686_msvc 0.53.0", - "windows_x86_64_gnu 0.53.0", - "windows_x86_64_gnullvm 0.53.0", - "windows_x86_64_msvc 0.53.0", + "windows-link", + "windows_aarch64_gnullvm 0.53.1", + "windows_aarch64_msvc 0.53.1", + "windows_i686_gnu 0.53.1", + "windows_i686_gnullvm 0.53.1", + "windows_i686_msvc 0.53.1", + "windows_x86_64_gnu 0.53.1", + "windows_x86_64_gnullvm 0.53.1", + "windows_x86_64_msvc 0.53.1", ] [[package]] @@ -2751,9 +2732,9 @@ checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" [[package]] name = "windows_aarch64_gnullvm" -version = "0.53.0" +version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "86b8d5f90ddd19cb4a147a5fa63ca848db3df085e25fee3cc10b39b6eebae764" +checksum = "a9d8416fa8b42f5c947f8482c43e7d89e73a173cead56d044f6a56104a6d1b53" [[package]] name = "windows_aarch64_msvc" @@ -2769,9 +2750,9 @@ checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" [[package]] name = "windows_aarch64_msvc" -version = "0.53.0" +version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7651a1f62a11b8cbd5e0d42526e55f2c99886c77e007179efff86c2b137e66c" +checksum = "b9d782e804c2f632e395708e99a94275910eb9100b2114651e04744e9b125006" [[package]] name = "windows_i686_gnu" @@ -2787,9 +2768,9 @@ checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" [[package]] name = "windows_i686_gnu" -version = "0.53.0" +version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1dc67659d35f387f5f6c479dc4e28f1d4bb90ddd1a5d3da2e5d97b42d6272c3" +checksum = "960e6da069d81e09becb0ca57a65220ddff016ff2d6af6a223cf372a506593a3" [[package]] name = "windows_i686_gnullvm" @@ -2799,9 +2780,9 @@ checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" [[package]] name = "windows_i686_gnullvm" -version = "0.53.0" +version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ce6ccbdedbf6d6354471319e781c0dfef054c81fbc7cf83f338a4296c0cae11" +checksum = "fa7359d10048f68ab8b09fa71c3daccfb0e9b559aed648a8f95469c27057180c" [[package]] name = "windows_i686_msvc" @@ -2817,9 +2798,9 @@ checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" [[package]] name = "windows_i686_msvc" -version = "0.53.0" +version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "581fee95406bb13382d2f65cd4a908ca7b1e4c2f1917f143ba16efe98a589b5d" +checksum = "1e7ac75179f18232fe9c285163565a57ef8d3c89254a30685b57d83a38d326c2" [[package]] name = "windows_x86_64_gnu" @@ -2835,9 +2816,9 @@ checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" [[package]] name = "windows_x86_64_gnu" -version = "0.53.0" +version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e55b5ac9ea33f2fc1716d1742db15574fd6fc8dadc51caab1c16a3d3b4190ba" +checksum = "9c3842cdd74a865a8066ab39c8a7a473c0778a3f29370b5fd6b4b9aa7df4a499" [[package]] name = "windows_x86_64_gnullvm" @@ -2853,9 +2834,9 @@ checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" [[package]] name = "windows_x86_64_gnullvm" -version = "0.53.0" +version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a6e035dd0599267ce1ee132e51c27dd29437f63325753051e71dd9e42406c57" +checksum = "0ffa179e2d07eee8ad8f57493436566c7cc30ac536a3379fdf008f47f6bb7ae1" [[package]] name = "windows_x86_64_msvc" @@ -2871,30 +2852,103 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "windows_x86_64_msvc" -version = "0.53.0" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" + +[[package]] +name = "wit-bindgen" +version = "0.51.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486" +checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5" +dependencies = [ + "wit-bindgen-rust-macro", +] [[package]] -name = "wit-bindgen-rt" -version = "0.39.0" +name = "wit-bindgen-core" +version = "0.51.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f42320e61fe2cfd34354ecb597f86f413484a798ba44a8ca1165c58d42da6c1" +checksum = "ea61de684c3ea68cb082b7a88508a8b27fcc8b797d738bfc99a82facf1d752dc" dependencies = [ + "anyhow", + "heck", + "wit-parser", +] + +[[package]] +name = "wit-bindgen-rust" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7c566e0f4b284dd6561c786d9cb0142da491f46a9fbed79ea69cdad5db17f21" +dependencies = [ + "anyhow", + "heck", + "indexmap", + "prettyplease", + "syn", + "wasm-metadata", + "wit-bindgen-core", + "wit-component", +] + +[[package]] +name = "wit-bindgen-rust-macro" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c0f9bfd77e6a48eccf51359e3ae77140a7f50b1e2ebfe62422d8afdaffab17a" +dependencies = [ + "anyhow", + "prettyplease", + "proc-macro2", + "quote", + "syn", + "wit-bindgen-core", + "wit-bindgen-rust", +] + +[[package]] +name = "wit-component" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d66ea20e9553b30172b5e831994e35fbde2d165325bec84fc43dbf6f4eb9cb2" +dependencies = [ + "anyhow", "bitflags", + "indexmap", + "log", + "serde", + "serde_derive", + "serde_json", + "wasm-encoder", + "wasm-metadata", + "wasmparser", + "wit-parser", ] [[package]] -name = "write16" -version = "1.0.0" +name = "wit-parser" +version = "0.244.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1890f4022759daae28ed4fe62859b1236caebfc61ede2f63ed4e695f3f6d936" +checksum = "ecc8ac4bc1dc3381b7f59c34f00b67e18f910c2c0f50015669dde7def656a736" +dependencies = [ + "anyhow", + "id-arena", + "indexmap", + "log", + "semver", + "serde", + "serde_derive", + "serde_json", + "unicode-xid", + "wasmparser", +] [[package]] name = "writeable" -version = "0.5.5" +version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e9df38ee2d2c3c5948ea468a8406ff0db0b29ae1ffde1bcf20ef305bcc95c51" +checksum = "9edde0db4769d2dc68579893f2306b26c6ecfbe0ef499b013d731b7b9247e0b9" [[package]] name = "yasna" @@ -2907,11 +2961,10 @@ dependencies = [ [[package]] name = "yoke" -version = "0.7.5" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "120e6aef9aa629e3d4f52dc8cc43a015c7724194c97dfaf45180d2daf2b77f40" +checksum = "72d6e5c6afb84d73944e5cedb052c4680d5657337201555f9f2a16b7406d4954" dependencies = [ - "serde", "stable_deref_trait", "yoke-derive", "zerofrom", @@ -2919,9 +2972,9 @@ dependencies = [ [[package]] name = "yoke-derive" -version = "0.7.5" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154" +checksum = "b659052874eb698efe5b9e8cf382204678a0086ebf46982b79d6ca3182927e5d" dependencies = [ "proc-macro2", "quote", @@ -2931,18 +2984,18 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.8.25" +version = "0.8.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1702d9583232ddb9174e01bb7c15a2ab8fb1bc6f227aa1233858c351a3ba0cb" +checksum = "a789c6e490b576db9f7e6b6d661bcc9799f7c0ac8352f56ea20193b2681532e5" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.8.25" +version = "0.8.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28a6e20d751156648aa063f3800b706ee209a32c0b4d9f24be3d980b01be55ef" +checksum = "f65c489a7071a749c849713807783f70672b28094011623e200cb86dcb835953" dependencies = [ "proc-macro2", "quote", @@ -2972,15 +3025,26 @@ dependencies = [ [[package]] name = "zeroize" -version = "1.8.1" +version = "1.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" +checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" + +[[package]] +name = "zerotrie" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a59c17a5562d507e4b54960e8569ebee33bee890c70aa3fe7b97e85a9fd7851" +dependencies = [ + "displaydoc", + "yoke", + "zerofrom", +] [[package]] name = "zerovec" -version = "0.10.4" +version = "0.11.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa2b893d79df23bfb12d5461018d408ea19dfafe76c2c7ef6d4eba614f8ff079" +checksum = "6c28719294829477f525be0186d13efa9a3c602f7ec202ca9e353d310fb9a002" dependencies = [ "yoke", "zerofrom", @@ -2989,11 +3053,17 @@ dependencies = [ [[package]] name = "zerovec-derive" -version = "0.10.3" +version = "0.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6" +checksum = "eadce39539ca5cb3985590102671f2567e659fca9666581ad3411d59207951f3" dependencies = [ "proc-macro2", "quote", "syn", ] + +[[package]] +name = "zmij" +version = "1.0.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa" From 123d551d4d76d9a50e0e66800101785b43767c14 Mon Sep 17 00:00:00 2001 From: andthattoo Date: Thu, 5 Mar 2026 01:03:26 +0300 Subject: [PATCH 19/57] Add install script --- install.sh | 47 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100755 install.sh diff --git a/install.sh b/install.sh new file mode 100755 index 00000000..d16fd9ac --- /dev/null +++ b/install.sh @@ -0,0 +1,47 @@ +#!/bin/sh +set -e + +REPO="firstbatchxyz/dkn-compute-node" +BINARY="dria-node" +INSTALL_DIR="/usr/local/bin" + +# Detect OS +OS=$(uname -s) +case "$OS" in + Linux*) OS_NAME="linux" ;; + Darwin*) OS_NAME="macOS" ;; + *) echo "Unsupported OS: $OS"; exit 1 ;; +esac + +# Detect architecture +ARCH=$(uname -m) +case "$ARCH" in + x86_64|amd64) ARCH_NAME="amd64" ;; + aarch64|arm64) ARCH_NAME="arm64" ;; + *) echo "Unsupported architecture: $ARCH"; exit 1 ;; +esac + +# Get latest release tag (includes pre-releases) +TAG=$(curl -fsSL "https://api.github.com/repos/${REPO}/releases" | grep '"tag_name"' | head -1 | cut -d'"' -f4) +if [ -z "$TAG" ]; then + echo "Failed to fetch latest release" + exit 1 +fi + +ASSET="${BINARY}-${OS_NAME}-${ARCH_NAME}" +URL="https://github.com/${REPO}/releases/download/${TAG}/${ASSET}" + +echo "Installing ${BINARY} ${TAG} (${OS_NAME}/${ARCH_NAME})..." + +TMPFILE=$(mktemp) +curl -fsSL "$URL" -o "$TMPFILE" +chmod +x "$TMPFILE" + +if [ -w "$INSTALL_DIR" ]; then + mv "$TMPFILE" "${INSTALL_DIR}/${BINARY}" +else + sudo mv "$TMPFILE" "${INSTALL_DIR}/${BINARY}" +fi + +echo "Installed ${BINARY} to ${INSTALL_DIR}/${BINARY}" +"${INSTALL_DIR}/${BINARY}" --version From c47af32a380e7178fc457492aae77e7b013357c4 Mon Sep 17 00:00:00 2001 From: andthattoo Date: Thu, 5 Mar 2026 01:06:10 +0300 Subject: [PATCH 20/57] Add Windows install script --- install.ps1 | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 install.ps1 diff --git a/install.ps1 b/install.ps1 new file mode 100644 index 00000000..f72a7aaa --- /dev/null +++ b/install.ps1 @@ -0,0 +1,36 @@ +$ErrorActionPreference = "Stop" + +$repo = "firstbatchxyz/dkn-compute-node" +$binary = "dria-node" + +# Get latest release tag (includes pre-releases) +$releases = Invoke-RestMethod "https://api.github.com/repos/$repo/releases" +$tag = $releases[0].tag_name +if (-not $tag) { + Write-Error "Failed to fetch latest release" + exit 1 +} + +$asset = "$binary-windows-amd64.exe" +$url = "https://github.com/$repo/releases/download/$tag/$asset" + +$installDir = "$env:LOCALAPPDATA\dria-node" +if (-not (Test-Path $installDir)) { + New-Item -ItemType Directory -Path $installDir | Out-Null +} + +$dest = Join-Path $installDir "$binary.exe" + +Write-Host "Installing $binary $tag..." +Invoke-WebRequest -Uri $url -OutFile $dest + +# Add to PATH if not already there +$userPath = [Environment]::GetEnvironmentVariable("Path", "User") +if ($userPath -notlike "*$installDir*") { + [Environment]::SetEnvironmentVariable("Path", "$userPath;$installDir", "User") + $env:Path = "$env:Path;$installDir" + Write-Host "Added $installDir to PATH" +} + +Write-Host "Installed $binary to $dest" +& $dest --version From 9576a1e7b6747a42f18d7f81e016fc11141e3065 Mon Sep 17 00:00:00 2001 From: andthattoo Date: Thu, 5 Mar 2026 02:32:14 +0300 Subject: [PATCH 21/57] Print banner after model load --- dnet.art | 18 ++++++++++++++++++ src/main.rs | 3 +++ 2 files changed, 21 insertions(+) create mode 100644 dnet.art diff --git a/dnet.art b/dnet.art new file mode 100644 index 00000000..bd4fa912 --- /dev/null +++ b/dnet.art @@ -0,0 +1,18 @@ + + +Dria [1.0.0-alpha.1] +Decentralized LLM inference + + .-----. .-.--.. .-..=+--. ----. ..... +. + ==-. -==+##==#@#+- -+#@@+====== .==-. =###@+--.....=@o+. ==-...===++=---- + =#- .ooooo+ .oooo= -@#- .#ooooo- #@- @ooo# @+. -o#=. -oooo= +o= +=o .oooo-. oooo@ =#oooooo= ## .@ooo+ @- +@- .oooo- @+ +@# .oooo. +oooo- ++ @oooo= ## . .oooo+ .@. --.++ .oooo- = +-@- -oooo. +oooo= += @oooo= ## -# .oooo+.--=o= -oooo. + ----+ooo@. #ooo@- #= @oooo= @# .@ -oooo- ++ . =oooo. + =ooo@ .#ooo#. #- ooooo@@+ ++-oooo- =#. #oooo= + =ooo@ -@oo@#. #- =@oooo# @oooo= .#+ #oooo= + .=#ooo@ +###-. .-@+ +ooo+ .+oooooo-...-#o= --#oooo@. + + https://dria.co/edge-ai + Made with <3 \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index cf7221ee..a5151056 100644 --- a/src/main.rs +++ b/src/main.rs @@ -122,6 +122,9 @@ async fn run_start( return Err(error::NodeError::Config("no models loaded".into()).into()); } + // Print banner + eprint!("{}", include_str!("../dnet.art")); + // Build the worker let mut worker = Worker::new(engines, config.max_concurrent); From 0af8ef6f1cb13ef83d745cc6082638506b327d28 Mon Sep 17 00:00:00 2001 From: andthattoo Date: Thu, 5 Mar 2026 02:44:03 +0300 Subject: [PATCH 22/57] Set macOS deployment target to 14.0 for Metal compatibility --- .github/workflows/releases.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/releases.yml b/.github/workflows/releases.yml index 37fbc0f7..04b0bc39 100644 --- a/.github/workflows/releases.yml +++ b/.github/workflows/releases.yml @@ -9,6 +9,7 @@ permissions: env: GGML_NATIVE: "OFF" + MACOSX_DEPLOYMENT_TARGET: "14.0" jobs: check_release: From e09ed5f8acad13c329fd6e45b81fb5612f585e95 Mon Sep 17 00:00:00 2001 From: andthattoo Date: Thu, 5 Mar 2026 02:45:36 +0300 Subject: [PATCH 23/57] Bump version to 1.0.0-alpha.2 --- Cargo.lock | 2 +- Cargo.toml | 2 +- dnet.art | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index dbcf6403..09fe4672 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -411,7 +411,7 @@ dependencies = [ [[package]] name = "dria-node" -version = "1.0.0-alpha.1" +version = "1.0.0-alpha.2" dependencies = [ "anyhow", "bytes", diff --git a/Cargo.toml b/Cargo.toml index 63e02524..df018aa9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "dria-node" -version = "1.0.0-alpha.1" +version = "1.0.0-alpha.2" edition = "2021" license = "Apache-2.0" diff --git a/dnet.art b/dnet.art index bd4fa912..1d555376 100644 --- a/dnet.art +++ b/dnet.art @@ -1,6 +1,6 @@ -Dria [1.0.0-alpha.1] +Dria [1.0.0-alpha.2] Decentralized LLM inference .-----. .-.--.. .-..=+--. ----. ..... +. From 195407fcfbbb71a586f16b646ec64f8cd39ae680 Mon Sep 17 00:00:00 2001 From: andthattoo Date: Thu, 5 Mar 2026 03:00:10 +0300 Subject: [PATCH 24/57] Update README with install scripts and setup details --- README.md | 34 ++++++++++++++++++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 2a3a2e88..7dc9536b 100644 --- a/README.md +++ b/README.md @@ -40,19 +40,43 @@ brew tap firstbatchxyz/dkn brew install dria-node ``` +**Shell script (macOS / Linux):** + +```sh +curl -fsSL https://raw.githubusercontent.com/firstbatchxyz/dkn-compute-node/master/install.sh | sh +``` + +**PowerShell (Windows):** + +```powershell +irm https://raw.githubusercontent.com/firstbatchxyz/dkn-compute-node/master/install.ps1 | iex +``` + **From GitHub Releases:** Download the latest binary for your platform from [Releases](https://github.com/firstbatchxyz/dkn-compute-node/releases) and place it in your `PATH`. ### Setup -Run the interactive setup to pick a model, download it, and verify it works: +Run the interactive setup: ```sh dria-node setup ``` -This detects your system RAM, shows models that fit, downloads your selection, and runs a test inference. +This will: + +1. Detect your system RAM and list models that fit +2. Let you pick a model from the available options +3. Download the GGUF model file from HuggingFace +4. Run a test inference to verify everything works +5. Print a benchmark (tokens per second) + +Use `--gpu-layers -1` to offload all layers to GPU (Metal on macOS, requires building with `--features cuda` for NVIDIA): + +```sh +dria-node setup --gpu-layers -1 +``` ### Start @@ -62,6 +86,12 @@ Once setup is complete, start the node: dria-node start --wallet --model ``` +The node will connect to the Dria network, register your models, and start serving inference requests. You can increase throughput with `--max-concurrent`: + +```sh +dria-node start --wallet --model lfm2.5:1.2b --max-concurrent 4 +``` + ## Available Models | Model | Type | Quant | ~Size | From db7954c703ab80d5771ac5f1def12dfd489a955b Mon Sep 17 00:00:00 2001 From: andthattoo Date: Thu, 5 Mar 2026 18:58:31 +0300 Subject: [PATCH 25/57] Version 0.7.1: interactive setup TUI with quant selection - Bump version from 1.0.0-alpha.2 to 0.7.1 - Replace number input with arrow-key selection (dialoguer) - Add 4-bit/8-bit quantization picker, RAM-aware - Retry loop on download/load failure instead of crashing - Fix qwen3.5:35b-a3b GGUF filename (was 404) --- Cargo.lock | 60 +++++++- Cargo.toml | 3 +- dnet.art | 2 +- src/models/registry.rs | 2 +- src/setup.rs | 338 ++++++++++++++++++++++++++--------------- 5 files changed, 280 insertions(+), 125 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 09fe4672..6416898f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -345,6 +345,19 @@ dependencies = [ "powerfmt", ] +[[package]] +name = "dialoguer" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "658bce805d770f407bc62102fca7c2c64ceef2fbcb2b8bd19d2765ce093980de" +dependencies = [ + "console", + "shell-words", + "tempfile", + "thiserror 1.0.69", + "zeroize", +] + [[package]] name = "digest" version = "0.9.0" @@ -411,11 +424,12 @@ dependencies = [ [[package]] name = "dria-node" -version = "1.0.0-alpha.2" +version = "0.7.1" dependencies = [ "anyhow", "bytes", "clap", + "dialoguer", "dirs", "dkn-protocol", "encoding_rs", @@ -511,6 +525,12 @@ dependencies = [ "siphasher", ] +[[package]] +name = "fastrand" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" + [[package]] name = "find-msvc-tools" version = "0.1.9" @@ -1200,6 +1220,12 @@ dependencies = [ "libsecp256k1-core", ] +[[package]] +name = "linux-raw-sys" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a66949e030da00e8c7d4434b251670a91556f4144941d37452769c25d58a53" + [[package]] name = "litemap" version = "0.8.1" @@ -1714,6 +1740,19 @@ version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" +[[package]] +name = "rustix" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6fe4565b9518b83ef4f91bb47ce29620ca828bd32cb7e408f0062e9930ba190" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.61.2", +] + [[package]] name = "rustls" version = "0.23.37" @@ -1945,6 +1984,12 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "shell-words" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc6fe69c597f9c37bfeeeeeb33da3530379845f10be461a66d16d03eca2ded77" + [[package]] name = "shlex" version = "1.3.0" @@ -2059,6 +2104,19 @@ dependencies = [ "libc", ] +[[package]] +name = "tempfile" +version = "3.26.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82a72c767771b47409d2345987fda8628641887d5466101319899796367354a0" +dependencies = [ + "fastrand", + "getrandom 0.4.2", + "once_cell", + "rustix", + "windows-sys 0.61.2", +] + [[package]] name = "thiserror" version = "1.0.69" diff --git a/Cargo.toml b/Cargo.toml index df018aa9..9ba097f7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "dria-node" -version = "1.0.0-alpha.2" +version = "0.7.1" edition = "2021" license = "Apache-2.0" @@ -32,6 +32,7 @@ tracing-subscriber = { version = "0.3", features = ["env-filter"] } thiserror = "2" anyhow = "1" uuid = { version = "1", features = ["v7", "serde"] } +dialoguer = "0.11" dirs = "6" encoding_rs = "0.8" dkn-protocol = { git = "https://github.com/firstbatchxyz/dkn-protocol.git" } diff --git a/dnet.art b/dnet.art index 1d555376..e99cd7d5 100644 --- a/dnet.art +++ b/dnet.art @@ -1,6 +1,6 @@ -Dria [1.0.0-alpha.2] +Dria [0.7.1] Decentralized LLM inference .-----. .-.--.. .-..=+--. ----. ..... +. diff --git a/src/models/registry.rs b/src/models/registry.rs index b3800868..4a9b3331 100644 --- a/src/models/registry.rs +++ b/src/models/registry.rs @@ -33,7 +33,7 @@ pub fn default_registry() -> HashMap { ModelSpec { name: "qwen3.5:35b-a3b".into(), hf_repo: "unsloth/Qwen3.5-35B-A3B-GGUF".into(), - hf_file: "Qwen3.5-35B-A3B-UD-Q4_K_M.gguf".into(), + hf_file: "Qwen3.5-35B-A3B-Q4_K_M.gguf".into(), sha256: None, model_type: ModelType::Vision, hf_mmproj_file: Some("mmproj-BF16.gguf".into()), diff --git a/src/setup.rs b/src/setup.rs index afab24fe..d3fdddef 100644 --- a/src/setup.rs +++ b/src/setup.rs @@ -2,6 +2,7 @@ use std::io::{self, Write}; use std::ops::ControlFlow; use std::path::PathBuf; +use dialoguer::{Select, theme::ColorfulTheme}; use dkn_protocol::ModelType; use crate::error::NodeError; @@ -149,60 +150,35 @@ pub async fn run_setup(data_dir: Option, gpu_layers: i32) -> Result<(), return Ok(()); } - // Print selectable list - println!(" Available models:"); - println!(); - for (i, m) in fits.iter().enumerate() { - println!( - " {}) {:<22} {:<8} {:<10} ~{:.1} GB", - i + 1, - m.name, - model_type_label(m.model_type), - m.quant, - m.size_gb, - ); - } - println!(); - - // Print too-large models + // Print too-large models as info if !too_large.is_empty() { - println!(" Models too large for your system (need more RAM):"); + println!(" Models too large for your system:"); for m in &too_large { println!( - " - {:<22} (~{:.0} GB) — needs ~{:.0} GB", + " - {:<22} (~{:.0} GB) — needs ~{:.0} GB RAM", m.name, m.size_gb, m.ram_needed_gb, ); } println!(); } - // Read selection - let selection = loop { - print!(" Select a model [1-{}]: ", fits.len()); - io::stdout().flush().map_err(|e| NodeError::Config(format!("stdout flush: {e}")))?; - - let mut input = String::new(); - io::stdin() - .read_line(&mut input) - .map_err(|e| NodeError::Config(format!("failed to read input: {e}")))?; - - match input.trim().parse::() { - Ok(n) if n >= 1 && n <= fits.len() => break n - 1, - _ => { - println!(" Invalid selection, please enter a number between 1 and {}.", fits.len()); - } - } - }; - - let chosen = &fits[selection]; - let model_name = &chosen.name; - println!(); - - // Resolve model spec - let spec = resolve_model(model_name, ®istry, None) - .ok_or_else(|| NodeError::Model(format!("unknown model: {model_name}")))?; - - // Set up cache dir + // Build display items for model selection + let model_items: Vec = fits + .iter() + .map(|m| { + format!( + "{:<22} {:<8} {:<10} ~{:.1} GB", + m.name, + model_type_label(m.model_type), + m.quant, + m.size_gb, + ) + }) + .collect(); + + let theme = ColorfulTheme::default(); + + // Set up cache dir once let data_dir = match data_dir { Some(d) => d, None => dirs::home_dir() @@ -213,98 +189,218 @@ pub async fn run_setup(data_dir: Option, gpu_layers: i32) -> Result<(), std::fs::create_dir_all(&models_dir)?; let cache = ModelCache::new(models_dir)?; - // Download model - println!(" Downloading {}...", model_name); - let model_path = if let Some(path) = cache.get_local_path(&spec) { - println!(" (already cached)"); - path - } else { - let hf_path = ModelDownloader::download(&spec).await?; - - // Verify SHA-256 if specified - if let Some(ref expected_sha) = spec.sha256 { - if !ModelCache::verify_sha256(&hf_path, expected_sha)? { - return Err(NodeError::Model(format!( - "SHA-256 mismatch for model {model_name}" - ))); + // Selection + download loop — retries on failure + let (spec, model_name_final, quant_override) = loop { + let selection = Select::with_theme(&theme) + .with_prompt(" Select a model") + .items(&model_items) + .default(0) + .interact() + .map_err(|e| NodeError::Config(format!("selection error: {e}")))?; + + let chosen = &fits[selection]; + let model_name = &chosen.name; + + // Quantization selection (4-bit vs 8-bit) + let q8_size = chosen.size_gb * 2.0; + let q8_ram = chosen.ram_needed_gb * 2.0; + let q8_fits = ram_gb.map_or(true, |gb| q8_ram < gb); + + let quant_override = if q8_fits { + let quant_items = vec![ + format!( + "4-bit ({}) ~{:.1} GB — smaller, faster", + chosen.quant, chosen.size_gb + ), + format!( + "8-bit (Q8_0){} ~{:.1} GB — better quality", + " ".repeat(chosen.quant.len().saturating_sub(4)), + q8_size + ), + "Back".to_string(), + ]; + + let quant_selection = Select::with_theme(&theme) + .with_prompt(" Select quantization") + .items(&quant_items) + .default(0) + .interact() + .map_err(|e| NodeError::Config(format!("selection error: {e}")))?; + + if quant_selection == 2 { + println!(); + continue; + } else if quant_selection == 1 { + Some("Q8_0") + } else { + None } - } + } else { + println!(); + println!( + " Using {} (8-bit needs ~{:.0} GB RAM, you have ~{:.0} GB)", + chosen.quant, + q8_ram, + ram_gb.unwrap_or(0.0) + ); + None + }; - cache.link_model(&spec, &hf_path)? - }; + println!(); - // Download mmproj if needed - let mmproj_path = if spec.hf_mmproj_file.is_some() { - if let Some(path) = cache.get_mmproj_path(&spec) { - Some(path) + let spec = match resolve_model(model_name, ®istry, quant_override) { + Some(s) => s, + None => { + println!(" Unknown model: {model_name}. Try again."); + println!(); + continue; + } + }; + + // Download model + println!(" Downloading {}...", model_name); + let model_path = if let Some(path) = cache.get_local_path(&spec) { + println!(" (already cached)"); + Ok(path) } else { - let hf_path = ModelDownloader::download_mmproj(&spec).await?; - Some(cache.link_mmproj(&spec, &hf_path)?) - } - } else { - None - }; + match ModelDownloader::download(&spec).await { + Ok(hf_path) => { + if let Some(ref expected_sha) = spec.sha256 { + if !ModelCache::verify_sha256(&hf_path, expected_sha)? { + println!(" SHA-256 mismatch! Try a different model."); + println!(); + continue; + } + } + cache.link_model(&spec, &hf_path).map_err(|e| e.into()) + } + Err(e) => Err(e), + } + }; - // Load model - println!(); - println!(" Loading model..."); - let engine = tokio::task::spawn_blocking({ - let model_path = model_path.clone(); - let mmproj_path = mmproj_path.clone(); - move || InferenceEngine::load(&model_path, gpu_layers, mmproj_path.as_deref()) - }) - .await - .map_err(|e| NodeError::Inference(format!("task join error: {e}")))? - ?; - - // Run test inference - println!(" Running test inference..."); - println!(); + let model_path = match model_path { + Ok(p) => p, + Err(e) => { + println!(" Download failed: {e}"); + println!(" Try a different model or quantization."); + println!(); + continue; + } + }; - let model_name_owned = model_name.clone(); - let result = tokio::task::spawn_blocking(move || { - let prompt = engine - .apply_template(&[dkn_protocol::ChatMessage { - role: "user".into(), - content: dkn_protocol::MessageContent::Text("Hello!".into()), - }]) - .unwrap_or_else(|_| "Hello!".into()); - - let params = GenerateParams { - max_tokens: 64, - temperature: 0.7, - ..Default::default() + // Download mmproj if needed + let mmproj_result = if spec.hf_mmproj_file.is_some() { + if let Some(path) = cache.get_mmproj_path(&spec) { + Ok(Some(path)) + } else { + match ModelDownloader::download_mmproj(&spec).await { + Ok(hf_path) => cache.link_mmproj(&spec, &hf_path).map(Some), + Err(e) => Err(e), + } + } + } else { + Ok(None) }; - print!(" > "); - let result = engine.generate(&prompt, ¶ms, |token| { - print!("{}", token.text); - let _ = io::stdout().flush(); - ControlFlow::Continue(()) - }); + let mmproj_path = match mmproj_result { + Ok(p) => p, + Err(e) => { + println!(" Multimodal projector download failed: {e}"); + println!(" Try a different model."); + println!(); + continue; + } + }; + + // Load model println!(); + println!(" Loading model..."); + let engine = tokio::task::spawn_blocking({ + let model_path = model_path.clone(); + let mmproj_path = mmproj_path.clone(); + move || InferenceEngine::load(&model_path, gpu_layers, mmproj_path.as_deref()) + }) + .await + .map_err(|e| NodeError::Inference(format!("task join error: {e}")))?; + + let engine = match engine { + Ok(e) => e, + Err(e) => { + println!(" Failed to load model: {e}"); + println!(" Try a different model or quantization."); + println!(); + continue; + } + }; - result.map(|r| (r, model_name_owned)) - }) - .await - .map_err(|e| NodeError::Inference(format!("task join error: {e}")))? - ?; + // Run test inference + println!(" Running test inference..."); + println!(); - let (inference_result, model_name_final) = result; + let model_name_owned = model_name.clone(); + let result = tokio::task::spawn_blocking(move || { + let prompt = engine + .apply_template(&[dkn_protocol::ChatMessage { + role: "user".into(), + content: dkn_protocol::MessageContent::Text("Hello!".into()), + }]) + .unwrap_or_else(|_| "Hello!".into()); + + let params = GenerateParams { + max_tokens: 64, + temperature: 0.7, + ..Default::default() + }; + + print!(" > "); + let result = engine.generate(&prompt, ¶ms, |token| { + print!("{}", token.text); + let _ = io::stdout().flush(); + ControlFlow::Continue(()) + }); + println!(); + + result.map(|r| (r, model_name_owned)) + }) + .await + .map_err(|e| NodeError::Inference(format!("task join error: {e}")))?; + + match result { + Ok((inference_result, name)) => { + println!(); + println!( + " Model working! {:.1} tok/s", + inference_result.tokens_per_second + ); + break (spec, name, quant_override); + } + Err(e) => { + println!(" Inference test failed: {e}"); + println!(" Try a different model."); + println!(); + continue; + } + } + }; - println!(); - println!( - " Model working! {:.1} tok/s", - inference_result.tokens_per_second - ); println!(); println!(" To start the node:"); - println!( - " dria-node start --wallet --model {}", - model_name_final - ); + if let Some(q) = quant_override { + println!( + " dria-node start --wallet --model {} --quant {}", + model_name_final, q + ); + } else { + println!( + " dria-node start --wallet --model {}", + model_name_final + ); + } println!(); + // Suppress unused variable warning + let _ = spec; + Ok(()) } From 5890067376f72cefbfb928ab4da642ca4e5f6a01 Mon Sep 17 00:00:00 2001 From: andthattoo Date: Thu, 5 Mar 2026 19:00:09 +0300 Subject: [PATCH 26/57] Fix clippy warnings --- src/setup.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/setup.rs b/src/setup.rs index d3fdddef..0ca58fc0 100644 --- a/src/setup.rs +++ b/src/setup.rs @@ -204,7 +204,7 @@ pub async fn run_setup(data_dir: Option, gpu_layers: i32) -> Result<(), // Quantization selection (4-bit vs 8-bit) let q8_size = chosen.size_gb * 2.0; let q8_ram = chosen.ram_needed_gb * 2.0; - let q8_fits = ram_gb.map_or(true, |gb| q8_ram < gb); + let q8_fits = ram_gb.is_none_or(|gb| q8_ram < gb); let quant_override = if q8_fits { let quant_items = vec![ @@ -272,7 +272,7 @@ pub async fn run_setup(data_dir: Option, gpu_layers: i32) -> Result<(), continue; } } - cache.link_model(&spec, &hf_path).map_err(|e| e.into()) + cache.link_model(&spec, &hf_path) } Err(e) => Err(e), } From 4f5ff7fbbe1a7e7fa540a20499eaa28144eb6561 Mon Sep 17 00:00:00 2001 From: andthattoo Date: Thu, 5 Mar 2026 19:56:57 +0300 Subject: [PATCH 27/57] Bump version to 0.7.2 --- Cargo.toml | 2 +- dnet.art | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 9ba097f7..11c5c74e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "dria-node" -version = "0.7.1" +version = "0.7.2" edition = "2021" license = "Apache-2.0" diff --git a/dnet.art b/dnet.art index e99cd7d5..f0c6c6c2 100644 --- a/dnet.art +++ b/dnet.art @@ -1,6 +1,6 @@ -Dria [0.7.1] +Dria [0.7.2] Decentralized LLM inference .-----. .-.--.. .-..=+--. ----. ..... +. From 961e68e00c87890281b45ea30c9f08f9cd4ee89d Mon Sep 17 00:00:00 2001 From: andthattoo Date: Thu, 5 Mar 2026 20:02:36 +0300 Subject: [PATCH 28/57] Add workflow_dispatch to releases workflow --- .github/workflows/releases.yml | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/.github/workflows/releases.yml b/.github/workflows/releases.yml index 04b0bc39..4acc22c3 100644 --- a/.github/workflows/releases.yml +++ b/.github/workflows/releases.yml @@ -3,6 +3,11 @@ name: releases on: release: types: [published] + workflow_dispatch: + inputs: + tag: + description: "Release tag (e.g. v0.7.2)" + required: true permissions: contents: write @@ -10,15 +15,14 @@ permissions: env: GGML_NATIVE: "OFF" MACOSX_DEPLOYMENT_TARGET: "14.0" + RELEASE_TAG: ${{ github.event.release.tag_name || github.event.inputs.tag }} jobs: check_release: runs-on: ubuntu-latest steps: - name: Echo tag - run: | - echo "tag name: ${{ github.event.release.tag_name }}" - echo "release name: ${{ github.event.release.name }}" + run: echo "tag name: ${{ env.RELEASE_TAG }}" build: needs: check_release @@ -76,6 +80,8 @@ jobs: steps: - name: Checkout code uses: actions/checkout@v4 + with: + ref: ${{ env.RELEASE_TAG }} - name: Install build dependencies (Linux) if: matrix.osname == 'linux' @@ -144,8 +150,8 @@ jobs: - name: Create release with artifacts uses: ncipollo/release-action@v1 with: - name: ${{ github.event.release.name }} - tag: ${{ github.event.release.tag_name }} + name: ${{ env.RELEASE_TAG }} + tag: ${{ env.RELEASE_TAG }} artifacts: "artifacts/*" artifactContentType: application/octet-stream allowUpdates: true @@ -171,7 +177,7 @@ jobs: - name: Update formula run: | - VERSION="${{ github.event.release.tag_name }}" + VERSION="${{ env.RELEASE_TAG }}" VERSION="${VERSION#v}" SHA_MACOS_AMD64=$(sha256sum artifacts/dria-node-macOS-amd64 | cut -d' ' -f1) @@ -239,5 +245,5 @@ jobs: git config user.name "github-actions[bot]" git config user.email "github-actions[bot]@users.noreply.github.com" git add Formula/dria-node.rb - git commit -m "Update dria-node to ${{ github.event.release.tag_name }}" + git commit -m "Update dria-node to ${{ env.RELEASE_TAG }}" git push From 005b441e81a0909da7e6b7fdab407c062483bde4 Mon Sep 17 00:00:00 2001 From: andthattoo Date: Thu, 5 Mar 2026 22:30:08 +0300 Subject: [PATCH 29/57] Update Cargo.lock --- Cargo.lock | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.lock b/Cargo.lock index 6416898f..4c6b3d11 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -424,7 +424,7 @@ dependencies = [ [[package]] name = "dria-node" -version = "0.7.1" +version = "0.7.2" dependencies = [ "anyhow", "bytes", From 6b25f6458cdeb1d25993911aa3ef628f9d0882c1 Mon Sep 17 00:00:00 2001 From: andthattoo Date: Thu, 5 Mar 2026 22:50:59 +0300 Subject: [PATCH 30/57] Remove --locked from release builds (git dep breaks it) --- .github/workflows/releases.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/releases.yml b/.github/workflows/releases.yml index 4acc22c3..6574e65e 100644 --- a/.github/workflows/releases.yml +++ b/.github/workflows/releases.yml @@ -106,7 +106,7 @@ jobs: - name: Build binary (native) if: matrix.runner == 'ubuntu-24.04-arm' - run: cargo build --bin dria-node --locked --release --target ${{ matrix.target }} ${{ matrix.build_args }} + run: cargo build --bin dria-node --release --target ${{ matrix.target }} ${{ matrix.build_args }} - name: Strip binary (native) if: matrix.runner == 'ubuntu-24.04-arm' @@ -118,7 +118,7 @@ jobs: with: command: ${{ matrix.command }} target: ${{ matrix.target }} - args: "--bin dria-node --locked --release ${{ matrix.build_args }}" + args: "--bin dria-node --release ${{ matrix.build_args }}" strip: true - name: Prepare Release File From a9274f685b1410cc04fda341a22c7e41a33d1521 Mon Sep 17 00:00:00 2001 From: andthattoo Date: Thu, 5 Mar 2026 22:52:27 +0300 Subject: [PATCH 31/57] Quote RELEASE_TAG expression to fix YAML syntax --- .github/workflows/releases.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/releases.yml b/.github/workflows/releases.yml index 6574e65e..ee5e7eb0 100644 --- a/.github/workflows/releases.yml +++ b/.github/workflows/releases.yml @@ -15,7 +15,7 @@ permissions: env: GGML_NATIVE: "OFF" MACOSX_DEPLOYMENT_TARGET: "14.0" - RELEASE_TAG: ${{ github.event.release.tag_name || github.event.inputs.tag }} + RELEASE_TAG: "${{ github.event.release.tag_name || github.event.inputs.tag }}" jobs: check_release: From 90c5b72f4bd055b68302f0f85190aeed9c7ea262 Mon Sep 17 00:00:00 2001 From: andthattoo Date: Thu, 5 Mar 2026 22:53:24 +0300 Subject: [PATCH 32/57] Remove check_release job (fixes YAML validation on push) --- .github/workflows/releases.yml | 7 ------- 1 file changed, 7 deletions(-) diff --git a/.github/workflows/releases.yml b/.github/workflows/releases.yml index ee5e7eb0..97eb6fb8 100644 --- a/.github/workflows/releases.yml +++ b/.github/workflows/releases.yml @@ -18,14 +18,7 @@ env: RELEASE_TAG: "${{ github.event.release.tag_name || github.event.inputs.tag }}" jobs: - check_release: - runs-on: ubuntu-latest - steps: - - name: Echo tag - run: echo "tag name: ${{ env.RELEASE_TAG }}" - build: - needs: check_release runs-on: ${{ matrix.runner }} strategy: matrix: From 29d04a428c9beacbf2806c96b6c45eb497d89e39 Mon Sep 17 00:00:00 2001 From: andthattoo Date: Sun, 8 Mar 2026 03:23:53 +0300 Subject: [PATCH 33/57] Add auto-update mechanism for v2 compute node Checks GitHub releases on startup: patch bumps warn, minor/major bumps download and replace the binary. Includes --skip-update flag and makeLatest: true in the release workflow. Co-Authored-By: Claude Opus 4.6 --- .github/workflows/releases.yml | 2 +- Cargo.lock | 15 +++ Cargo.toml | 4 + src/config.rs | 15 +++ src/error.rs | 3 + src/main.rs | 33 +++++- src/update.rs | 211 +++++++++++++++++++++++++++++++++ 7 files changed, 280 insertions(+), 3 deletions(-) create mode 100644 src/update.rs diff --git a/.github/workflows/releases.yml b/.github/workflows/releases.yml index 97eb6fb8..7c57508c 100644 --- a/.github/workflows/releases.yml +++ b/.github/workflows/releases.yml @@ -148,7 +148,7 @@ jobs: artifacts: "artifacts/*" artifactContentType: application/octet-stream allowUpdates: true - makeLatest: false + makeLatest: true update_homebrew: needs: release diff --git a/Cargo.lock b/Cargo.lock index 4c6b3d11..1436f594 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -441,13 +441,17 @@ dependencies = [ "quinn", "rand 0.8.5", "rcgen", + "reqwest", "rmp-serde", "rustls", "rustls-native-certs", + "self-replace", + "semver", "serde", "serde_json", "sha2 0.10.9", "sha3", + "tempfile", "thiserror 2.0.18", "tokio", "tokio-util", @@ -1880,6 +1884,17 @@ dependencies = [ "libc", ] +[[package]] +name = "self-replace" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03ec815b5eab420ab893f63393878d89c90fdd94c0bcc44c07abb8ad95552fb7" +dependencies = [ + "fastrand", + "tempfile", + "windows-sys 0.52.0", +] + [[package]] name = "semver" version = "1.0.27" diff --git a/Cargo.toml b/Cargo.toml index 11c5c74e..69bb2a19 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,3 +42,7 @@ rustls-native-certs = "0.8" rcgen = "0.13" futures = "0.3" bytes = "1" +reqwest = { version = "0.12", default-features = false, features = ["rustls-tls", "json"] } +semver = "1" +self-replace = "1" +tempfile = "3" diff --git a/src/config.rs b/src/config.rs index 7749997e..543b08c5 100644 --- a/src/config.rs +++ b/src/config.rs @@ -57,6 +57,10 @@ pub enum Command { /// Skip TLS certificate verification (for development/testing) #[arg(long, env = "DRIA_INSECURE")] insecure: bool, + + /// Skip automatic update check on startup + #[arg(long, env = "DRIA_SKIP_UPDATE")] + skip_update: bool, }, } @@ -71,6 +75,7 @@ pub struct Config { pub models_dir: PathBuf, pub quant: Option, pub insecure: bool, + pub skip_update: bool, } impl Config { @@ -85,6 +90,7 @@ impl Config { data_dir: Option, quant: Option, insecure: bool, + skip_update: bool, ) -> Result { // Validate wallet key let secret_key_hex = wallet.strip_prefix("0x").unwrap_or(&wallet).to_string(); @@ -140,6 +146,7 @@ impl Config { models_dir, quant, insecure, + skip_update, }) } } @@ -159,6 +166,7 @@ mod tests { Some("/tmp/dria-test".into()), None, false, + false, ) .unwrap(); @@ -182,6 +190,7 @@ mod tests { None, None, false, + false, ); assert!(result.is_err()); } @@ -197,6 +206,7 @@ mod tests { None, None, false, + false, ); assert!(result.is_err()); } @@ -212,6 +222,7 @@ mod tests { None, None, false, + false, ); assert!(result.is_err()); } @@ -227,6 +238,7 @@ mod tests { None, None, false, + false, ); assert!(result.is_err()); } @@ -242,6 +254,7 @@ mod tests { None, None, false, + false, ) .unwrap(); assert_eq!( @@ -261,6 +274,7 @@ mod tests { None, None, false, + false, ); assert!(result.is_err()); } @@ -276,6 +290,7 @@ mod tests { None, None, true, + false, ) .unwrap(); assert!(cfg.insecure); diff --git a/src/error.rs b/src/error.rs index 1287b72e..440093a1 100644 --- a/src/error.rs +++ b/src/error.rs @@ -17,6 +17,9 @@ pub enum NodeError { #[error("network error: {0}")] Network(String), + #[error("update error: {0}")] + Update(String), + #[error("io error: {0}")] Io(#[from] std::io::Error), } diff --git a/src/main.rs b/src/main.rs index a5151056..4109e0f9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,6 +6,7 @@ mod models; mod network; mod setup; mod stats; +mod update; mod worker; use std::collections::HashMap; @@ -53,8 +54,9 @@ async fn main() -> anyhow::Result<()> { data_dir, quant, insecure, + skip_update, } => { - run_start(wallet, model, router_url, gpu_layers, max_concurrent, data_dir, quant, insecure).await?; + run_start(wallet, model, router_url, gpu_layers, max_concurrent, data_dir, quant, insecure, skip_update).await?; } } @@ -87,14 +89,41 @@ async fn run_start( data_dir: Option, quant: Option, insecure: bool, + skip_update: bool, ) -> anyhow::Result<()> { // Parse config - let config = Config::from_start_args(wallet, model, router_url, gpu_layers, max_concurrent, data_dir, quant, insecure)?; + let config = Config::from_start_args(wallet, model, router_url, gpu_layers, max_concurrent, data_dir, quant, insecure, skip_update)?; // Create identity let identity = Identity::from_secret_hex(&config.secret_key_hex)?; tracing::info!(address = %format!("0x{}", identity.address_hex), "node identity"); + // Check for updates + if !config.skip_update { + match update::check_for_update().await { + Ok(update::UpdateAction::Force(version)) => { + tracing::warn!(%version, "mandatory update available, downloading..."); + if let Err(e) = update::perform_update(&version).await { + tracing::error!(%e, "auto-update failed, continuing with current version"); + } else { + tracing::info!("update complete — please restart the node"); + return Ok(()); + } + } + Ok(update::UpdateAction::Warn(version)) => { + tracing::warn!( + %version, + "new patch version available, update recommended (current: {})", + env!("CARGO_PKG_VERSION") + ); + } + Ok(update::UpdateAction::UpToDate) => {} + Err(e) => { + tracing::debug!(%e, "update check failed, continuing"); + } + } + } + // Ensure directories exist std::fs::create_dir_all(&config.data_dir)?; std::fs::create_dir_all(&config.models_dir)?; diff --git a/src/update.rs b/src/update.rs new file mode 100644 index 00000000..ca37e156 --- /dev/null +++ b/src/update.rs @@ -0,0 +1,211 @@ +use std::time::Duration; + +use semver::Version; + +use crate::error::NodeError; + +const GITHUB_RELEASES_URL: &str = + "https://api.github.com/repos/firstbatchxyz/dkn-compute-node/releases/latest"; + +#[derive(serde::Deserialize)] +struct GitHubRelease { + tag_name: String, + #[allow(dead_code)] + assets: Vec, + prerelease: bool, +} + +#[derive(serde::Deserialize)] +struct GitHubAsset { + #[allow(dead_code)] + name: String, + #[allow(dead_code)] + browser_download_url: String, +} + +pub enum UpdateAction { + UpToDate, + Warn(String), + Force(String), +} + +/// Determine the correct release asset name for this platform. +fn asset_name() -> Result<&'static str, NodeError> { + match (std::env::consts::OS, std::env::consts::ARCH) { + ("macos", "x86_64") => Ok("dria-node-macOS-amd64"), + ("macos", "aarch64") => Ok("dria-node-macOS-arm64"), + ("linux", "x86_64") => Ok("dria-node-linux-amd64"), + ("linux", "aarch64") => Ok("dria-node-linux-arm64"), + ("windows", "x86_64") => Ok("dria-node-windows-amd64.exe"), + (os, arch) => Err(NodeError::Update(format!( + "unsupported platform: {os}/{arch}" + ))), + } +} + +/// Classify the update action based on semver comparison. +fn classify_update(current: &Version, latest: &Version) -> UpdateAction { + if latest <= current { + return UpdateAction::UpToDate; + } + if current.major == latest.major && current.minor == latest.minor { + UpdateAction::Warn(latest.to_string()) + } else { + UpdateAction::Force(latest.to_string()) + } +} + +/// Check GitHub for the latest release and compare with current version. +pub async fn check_for_update() -> Result { + let current_version = env!("CARGO_PKG_VERSION"); + let current = Version::parse(current_version) + .map_err(|e| NodeError::Update(format!("invalid current version: {e}")))?; + + let client = reqwest::Client::builder() + .timeout(Duration::from_secs(3)) + .user_agent(format!("dria-node/{current_version}")) + .build() + .map_err(|e| NodeError::Update(format!("http client error: {e}")))?; + + let release: GitHubRelease = client + .get(GITHUB_RELEASES_URL) + .send() + .await + .map_err(|e| NodeError::Update(format!("failed to fetch release info: {e}")))? + .json() + .await + .map_err(|e| NodeError::Update(format!("failed to parse release info: {e}")))?; + + if release.prerelease { + return Ok(UpdateAction::UpToDate); + } + + let tag = release.tag_name.strip_prefix('v').unwrap_or(&release.tag_name); + let latest = Version::parse(tag) + .map_err(|e| NodeError::Update(format!("invalid release version '{tag}': {e}")))?; + + Ok(classify_update(¤t, &latest)) +} + +/// Download the update binary and replace the current executable. +pub async fn perform_update(version: &str) -> Result<(), NodeError> { + let asset = asset_name()?; + + let url = format!( + "https://github.com/firstbatchxyz/dkn-compute-node/releases/download/v{version}/{asset}" + ); + + tracing::info!(%url, "downloading update"); + + let client = reqwest::Client::builder() + .timeout(Duration::from_secs(120)) + .user_agent(format!("dria-node/{}", env!("CARGO_PKG_VERSION"))) + .build() + .map_err(|e| NodeError::Update(format!("http client error: {e}")))?; + + let response = client + .get(&url) + .send() + .await + .map_err(|e| NodeError::Update(format!("download failed: {e}")))?; + + if !response.status().is_success() { + return Err(NodeError::Update(format!( + "download failed with status: {}", + response.status() + ))); + } + + let bytes = response + .bytes() + .await + .map_err(|e| NodeError::Update(format!("failed to read download: {e}")))?; + + // Write to a temp file + let mut tmp = tempfile::NamedTempFile::new() + .map_err(|e| NodeError::Update(format!("failed to create temp file: {e}")))?; + + std::io::Write::write_all(&mut tmp, &bytes) + .map_err(|e| NodeError::Update(format!("failed to write temp file: {e}")))?; + + // Set executable permission on unix + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + std::fs::set_permissions(tmp.path(), std::fs::Permissions::from_mode(0o755)) + .map_err(|e| NodeError::Update(format!("failed to set permissions: {e}")))?; + } + + // Atomic self-replace + self_replace::self_replace(tmp.path()) + .map_err(|e| NodeError::Update(format!("self-replace failed: {e}")))?; + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn classify_same_version_is_up_to_date() { + let current = Version::new(0, 7, 2); + let latest = Version::new(0, 7, 2); + assert!(matches!(classify_update(¤t, &latest), UpdateAction::UpToDate)); + } + + #[test] + fn classify_patch_bump_is_warn() { + let current = Version::new(0, 7, 2); + let latest = Version::new(0, 7, 3); + assert!(matches!(classify_update(¤t, &latest), UpdateAction::Warn(v) if v == "0.7.3")); + } + + #[test] + fn classify_minor_bump_is_force() { + let current = Version::new(0, 7, 2); + let latest = Version::new(0, 8, 0); + assert!(matches!(classify_update(¤t, &latest), UpdateAction::Force(v) if v == "0.8.0")); + } + + #[test] + fn classify_major_bump_is_force() { + let current = Version::new(0, 7, 2); + let latest = Version::new(1, 0, 0); + assert!(matches!(classify_update(¤t, &latest), UpdateAction::Force(v) if v == "1.0.0")); + } + + #[test] + fn classify_older_release_is_up_to_date() { + let current = Version::new(0, 8, 0); + let latest = Version::new(0, 7, 2); + assert!(matches!(classify_update(¤t, &latest), UpdateAction::UpToDate)); + } + + #[test] + fn asset_name_returns_value_for_current_platform() { + // Should not error on any CI/dev platform we support + let name = asset_name().unwrap(); + assert!(name.starts_with("dria-node-")); + } + + #[test] + fn parse_github_release_json() { + let json = r#"{ + "tag_name": "v0.8.0", + "prerelease": false, + "assets": [ + { + "name": "dria-node-linux-amd64", + "browser_download_url": "https://github.com/firstbatchxyz/dkn-compute-node/releases/download/v0.8.0/dria-node-linux-amd64" + } + ] + }"#; + + let release: GitHubRelease = serde_json::from_str(json).unwrap(); + assert_eq!(release.tag_name, "v0.8.0"); + assert!(!release.prerelease); + assert_eq!(release.assets.len(), 1); + assert_eq!(release.assets[0].name, "dria-node-linux-amd64"); + } +} From d2242e16ae7fb017959f79bab6264c7840faafc8 Mon Sep 17 00:00:00 2001 From: andthattoo Date: Sun, 8 Mar 2026 18:40:27 +0300 Subject: [PATCH 34/57] Add prefill-only validation and stride-based logprob extraction --- Cargo.lock | 1 - Cargo.toml | 2 +- src/inference/engine.rs | 230 +++++++++++++++++++++++++++++++++++++--- src/main.rs | 21 ++++ src/worker.rs | 93 +++++++++++++++- 5 files changed, 326 insertions(+), 21 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 1436f594..0562eb9c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -412,7 +412,6 @@ dependencies = [ [[package]] name = "dkn-protocol" version = "0.1.0" -source = "git+https://github.com/firstbatchxyz/dkn-protocol.git#19dcd03eda240eaa98665a7a839a32205b5ef912" dependencies = [ "base64", "quinn", diff --git a/Cargo.toml b/Cargo.toml index 69bb2a19..94fbc986 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,7 +35,7 @@ uuid = { version = "1", features = ["v7", "serde"] } dialoguer = "0.11" dirs = "6" encoding_rs = "0.8" -dkn-protocol = { git = "https://github.com/firstbatchxyz/dkn-protocol.git" } +dkn-protocol = { path = "../dkn-protocol" } quinn = "0.11" rustls = { version = "0.23", default-features = false, features = ["ring", "std"] } rustls-native-certs = "0.8" diff --git a/src/inference/engine.rs b/src/inference/engine.rs index 1a2ed524..a842610a 100644 --- a/src/inference/engine.rs +++ b/src/inference/engine.rs @@ -25,8 +25,9 @@ pub struct GenerateParams { pub temperature: f32, pub top_p: f32, pub seed: Option, - /// Token positions at which to extract logprobs. - pub logprob_positions: Vec, + /// Extract logprobs every N tokens (0 = disabled). + /// E.g. 32 → positions [0, 32, 64, ...]. + pub logprob_every_n: usize, /// Top-k alternatives to collect at each logprob position. pub logprob_top_k: usize, } @@ -38,7 +39,7 @@ impl Default for GenerateParams { temperature: 0.7, top_p: 0.9, seed: None, - logprob_positions: vec![], + logprob_every_n: 0, logprob_top_k: 5, } } @@ -233,6 +234,9 @@ impl InferenceEngine { let mut logprobs: Vec = Vec::new(); let mut current_pos = tokens.len() as i32; let mut decoder = encoding_rs::UTF_8.new_decoder(); + // Track the sequence position where logits are available (for extract_logprob). + // sampler.sample() always uses -1 ("last logits" in the C API). + let mut logits_pos: i32 = (tokens.len() - 1) as i32; for _ in 0..params.max_tokens { let new_token = sampler.sample(&ctx, -1); @@ -242,11 +246,11 @@ impl InferenceEngine { break; } - // Extract logprobs if this position was requested + // Extract logprobs at stride positions let gen_index = generated_count as usize; - if params.logprob_positions.contains(&gen_index) { + if params.logprob_every_n > 0 && gen_index.is_multiple_of(params.logprob_every_n) { if let Some(lp) = - self.extract_logprob(&ctx, -1, gen_index, new_token, params.logprob_top_k) + self.extract_logprob(&ctx, logits_pos, gen_index, new_token, params.logprob_top_k) { logprobs.push(lp); } @@ -276,6 +280,7 @@ impl InferenceEngine { .map_err(|e| NodeError::Inference(format!("batch add failed: {e}")))?; ctx.decode(&mut batch) .map_err(|e| NodeError::Inference(format!("decode failed: {e}")))?; + logits_pos = current_pos; // logits available at the position we just decoded current_pos += 1; } @@ -389,28 +394,23 @@ impl InferenceEngine { let gen_start = Instant::now(); let mut generated_text = String::new(); let mut generated_count: u32 = 0; - let mut logprobs: Vec = Vec::new(); + let logprobs: Vec = Vec::new(); let mut current_pos = n_past; let mut decoder = encoding_rs::UTF_8.new_decoder(); let mut batch = LlamaBatch::new(1, 1); + // After eval_chunks the logits index is opaque; use -1 for first sample. + // Multimodal tasks skip validation so logprob extraction is not needed. + let mut logits_idx: i32 = -1; for _ in 0..params.max_tokens { - let new_token = sampler.sample(&ctx, -1); + let new_token = sampler.sample(&ctx, logits_idx); sampler.accept(new_token); if self.model.is_eog_token(new_token) { break; } - // Extract logprobs if this position was requested let gen_index = generated_count as usize; - if params.logprob_positions.contains(&gen_index) { - if let Some(lp) = - self.extract_logprob(&ctx, -1, gen_index, new_token, params.logprob_top_k) - { - logprobs.push(lp); - } - } // Decode token to text let piece = self @@ -436,6 +436,7 @@ impl InferenceEngine { .map_err(|e| NodeError::Inference(format!("batch add failed: {e}")))?; ctx.decode(&mut batch) .map_err(|e| NodeError::Inference(format!("decode failed: {e}")))?; + logits_idx = current_pos; current_pos += 1; } @@ -466,6 +467,105 @@ impl InferenceEngine { }) } + /// Prefill-only validation: tokenize prompt+output, run a single forward pass, + /// and extract logprobs at the same stride positions used during generation. + /// + /// Returns an `InferenceProof` that can be compared against the original. + pub fn validate_prefill( + &self, + prompt: &str, + output_text: &str, + logprob_every_n: usize, + logprob_top_k: usize, + ) -> Result { + // Tokenize prompt alone to find the split point + let prompt_tokens = self + .model + .str_to_token(prompt, AddBos::Always) + .map_err(|e| NodeError::Inference(format!("prompt tokenization failed: {e}")))?; + let n_prompt = prompt_tokens.len(); + + // Tokenize prompt + output together + let full_text = format!("{}{}", prompt, output_text); + let all_tokens = self + .model + .str_to_token(&full_text, AddBos::Always) + .map_err(|e| NodeError::Inference(format!("full tokenization failed: {e}")))?; + let n_output = all_tokens.len().saturating_sub(n_prompt); + + if n_output == 0 { + return Ok(InferenceProof { + logprobs: vec![], + kv_cache_hash: None, + }); + } + + // Compute probe positions: gen_index values [0, N, 2N, ...] where each is < n_output + let mut probe_gen_indices: Vec = Vec::new(); + if logprob_every_n > 0 { + let mut k = 0; + while k < n_output { + probe_gen_indices.push(k); + k += logprob_every_n; + } + } + + if probe_gen_indices.is_empty() { + return Ok(InferenceProof { + logprobs: vec![], + kv_cache_hash: None, + }); + } + + // Create context sized to fit all tokens + let ctx_size = std::num::NonZeroU32::new((all_tokens.len() + 64).max(2048) as u32); + let ctx_params = LlamaContextParams::default().with_n_ctx(ctx_size); + + let mut ctx = self + .model + .new_context(&self.backend, ctx_params) + .map_err(|e| NodeError::Inference(format!("failed to create context: {e}")))?; + + // Build batch with all tokens. Set output=true only at positions where we need logits. + // For probe gen_index k, we need logits at sequence position (n_prompt + k - 1) for k > 0, + // and at (n_prompt - 1) for k == 0 (last prompt token predicts first output token). + let mut output_positions: Vec = Vec::new(); + for &k in &probe_gen_indices { + let seq_pos = if k == 0 { n_prompt - 1 } else { n_prompt + k - 1 }; + output_positions.push(seq_pos); + } + + let mut batch = LlamaBatch::new(all_tokens.len().max(1), 1); + for (i, &token) in all_tokens.iter().enumerate() { + let is_output = output_positions.contains(&i); + batch + .add(token, i as i32, &[0], is_output) + .map_err(|e| NodeError::Inference(format!("batch add failed: {e}")))?; + } + + // Single forward pass + ctx.decode(&mut batch) + .map_err(|e| NodeError::Inference(format!("prefill decode failed: {e}")))?; + + // Extract logprobs at each probe position. + // get_logits_ith(pos) expects the sequence position passed to batch.add(). + let mut logprobs: Vec = Vec::new(); + for (probe_idx, &gen_index) in probe_gen_indices.iter().enumerate() { + let target_token = all_tokens[n_prompt + gen_index]; + let seq_pos = output_positions[probe_idx] as i32; + if let Some(lp) = + self.extract_logprob(&ctx, seq_pos, gen_index, target_token, logprob_top_k) + { + logprobs.push(lp); + } + } + + Ok(InferenceProof { + logprobs, + kv_cache_hash: None, + }) + } + /// Extract logprob data at a given batch index. fn extract_logprob( &self, @@ -673,4 +773,102 @@ mod tests { assert!(!result.text.is_empty(), "model should produce output"); assert!(result.tokens_generated > 0); } + + /// Helper to load lfm2.5:1.2b from cache (or download). + async fn load_text_model() -> (InferenceEngine, String) { + let registry = crate::models::default_registry(); + let spec = registry.get("lfm2.5:1.2b").unwrap().clone(); + + let cache_dir = dirs::cache_dir() + .unwrap_or_else(|| std::path::PathBuf::from(".")) + .join("dria-test-models"); + let cache = crate::models::ModelCache::new(cache_dir).unwrap(); + + let model_path = if let Some(p) = cache.get_local_path(&spec) { + println!("model found in cache: {}", p.display()); + p + } else { + println!("downloading model..."); + let hf_path = crate::models::ModelDownloader::download(&spec).await.unwrap(); + cache.link_model(&spec, &hf_path).unwrap() + }; + + let engine = InferenceEngine::load(&model_path, 0, None).unwrap(); + (engine, spec.name) + } + + /// End-to-end validation test: + /// 1. Generate text with logprob_every_n=8 (greedy so output is deterministic) + /// 2. validate_prefill() with the same prompt+output + /// 3. compare_proofs() — should Pass + /// + /// Run with: + /// cargo test test_validate_prefill_e2e -- --ignored --nocapture + #[tokio::test] + #[ignore] // requires lfm2.5:1.2b model (~800 MB) + async fn test_validate_prefill_e2e() { + let (engine, _model_name) = load_text_model().await; + + let messages = vec![ChatMessage { + role: "user".into(), + content: "What is 2 + 2? Answer in one word.".into(), + }]; + + let prompt = engine.apply_template(&messages).unwrap(); + + // Generate with logprobs every 8 tokens, greedy (deterministic) + let params = GenerateParams { + max_tokens: 64, + temperature: 0.0, + logprob_every_n: 8, + logprob_top_k: 5, + ..Default::default() + }; + + let gen_result = engine + .generate(&prompt, ¶ms, |_| ControlFlow::Continue(())) + .unwrap(); + + println!("generated: {:?}", gen_result.text); + println!("tokens: {}", gen_result.tokens_generated); + + let original_proof = gen_result.proof.as_ref().expect("should have proof with logprob_every_n=8"); + println!("original proof positions: {:?}", + original_proof.logprobs.iter().map(|lp| lp.position).collect::>() + ); + + // Now validate: prefill-only forward pass + let validator_proof = engine + .validate_prefill(&prompt, &gen_result.text, 8, 5) + .unwrap(); + + println!("validator proof positions: {:?}", + validator_proof.logprobs.iter().map(|lp| lp.position).collect::>() + ); + + // Both proofs should have the same positions + assert_eq!( + original_proof.logprobs.len(), + validator_proof.logprobs.len(), + "proof lengths should match" + ); + + // Compare position by position + for (orig, val) in original_proof.logprobs.iter().zip(validator_proof.logprobs.iter()) { + assert_eq!(orig.position, val.position, "positions should match"); + assert_eq!(orig.token_id, val.token_id, "token IDs should match at position {}", orig.position); + let diff = (orig.logprob - val.logprob).abs(); + println!( + "pos {} | token '{}' | orig_lp={:.4} | val_lp={:.4} | diff={:.4}", + orig.position, orig.token_text, orig.logprob, val.logprob, diff + ); + assert!( + diff < 0.5, + "logprob diff too large at position {}: {diff}", + orig.position + ); + } + + println!("\nall positions match — validation passed!"); + } } diff --git a/src/main.rs b/src/main.rs index 4109e0f9..ec9bfa7c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -449,6 +449,27 @@ async fn handle_router_message( } } } + RouterMessage::ValidationTask { + validation_id, + model, + messages, + output_text, + logprob_every_n, + logprob_top_k, + } => { + tracing::info!(%validation_id, %model, "received validation task"); + match worker.try_accept_validation( + validation_id, + &model, + messages, + output_text, + logprob_every_n, + logprob_top_k, + ) { + Ok(()) => tracing::debug!(%validation_id, "validation accepted"), + Err(reason) => tracing::warn!(%validation_id, ?reason, "validation rejected"), + } + } RouterMessage::ModelRegistryUpdate { entries } => { tracing::info!(count = entries.len(), "received model registry update"); diff --git a/src/worker.rs b/src/worker.rs index 28ccb6cc..2ab3b14c 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -120,10 +120,10 @@ impl Worker { temperature, top_p: 0.9, seed: None, - logprob_positions: validation + logprob_every_n: validation .as_ref() - .map(|v| v.logprob_positions.clone()) - .unwrap_or_default(), + .map(|v| v.logprob_every_n) + .unwrap_or(0), logprob_top_k: validation.as_ref().map(|v| v.logprob_top_k).unwrap_or(5), }; @@ -225,6 +225,56 @@ impl Worker { pub fn has_model(&self, name: &str) -> bool { self.engines.contains_key(name) } + + /// Try to accept a validation task. Same capacity semantics as `try_accept()`. + pub fn try_accept_validation( + &self, + validation_id: Uuid, + model: &str, + messages: Vec, + output_text: String, + logprob_every_n: usize, + logprob_top_k: usize, + ) -> Result<(), RejectReason> { + let (engine, _model_type) = self + .engines + .get(model) + .ok_or(RejectReason::ModelNotLoaded)?; + let engine = Arc::clone(engine); + + // CAS-decrement capacity + loop { + let current = self.capacity.load(Ordering::Acquire); + if current == 0 { + return Err(RejectReason::AtCapacity); + } + if self + .capacity + .compare_exchange_weak(current, current - 1, Ordering::AcqRel, Ordering::Acquire) + .is_ok() + { + break; + } + } + + let capacity = Arc::clone(&self.capacity); + + let handle = tokio::task::spawn_blocking(move || { + let result = run_validation( + &engine, + validation_id, + messages, + output_text, + logprob_every_n, + logprob_top_k, + ); + capacity.fetch_add(1, Ordering::Release); + result + }); + self.in_flight.push(handle); + + Ok(()) + } } /// Run inference synchronously (called from `spawn_blocking`). @@ -375,6 +425,43 @@ fn run_inference_streaming( } } +/// Run prefill-only validation (called from `spawn_blocking`). +fn run_validation( + engine: &InferenceEngine, + validation_id: Uuid, + messages: Vec, + output_text: String, + logprob_every_n: usize, + logprob_top_k: usize, +) -> CompletedTask { + let prompt = match engine.apply_template(&messages) { + Ok(p) => p, + Err(e) => { + return CompletedTask { + task_id: validation_id, + result: Err(e), + stream: false, + }; + } + }; + + match engine.validate_prefill(&prompt, &output_text, logprob_every_n, logprob_top_k) { + Ok(proof) => CompletedTask { + task_id: validation_id, + result: Ok(NodeMessage::ValidationResult { + validation_id, + proof, + }), + stream: false, + }, + Err(e) => CompletedTask { + task_id: validation_id, + result: Err(e), + stream: false, + }, + } +} + fn build_task_result(task_id: Uuid, result: InferenceResult) -> NodeMessage { NodeMessage::TaskResult { task_id, From 94e90b5b90c300fb2b6af2ee37787ffa52903456 Mon Sep 17 00:00:00 2001 From: andthattoo Date: Mon, 9 Mar 2026 00:59:59 +0300 Subject: [PATCH 35/57] Fix validate_prefill: use output index for get_logits_ith --- src/inference/engine.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/inference/engine.rs b/src/inference/engine.rs index a842610a..4aee719a 100644 --- a/src/inference/engine.rs +++ b/src/inference/engine.rs @@ -552,9 +552,10 @@ impl InferenceEngine { let mut logprobs: Vec = Vec::new(); for (probe_idx, &gen_index) in probe_gen_indices.iter().enumerate() { let target_token = all_tokens[n_prompt + gen_index]; - let seq_pos = output_positions[probe_idx] as i32; + // get_logits_ith takes the output index (0-based among tokens with output=true), + // NOT the sequence position. if let Some(lp) = - self.extract_logprob(&ctx, seq_pos, gen_index, target_token, logprob_top_k) + self.extract_logprob(&ctx, probe_idx as i32, gen_index, target_token, logprob_top_k) { logprobs.push(lp); } From fa3b793e6afa744bb6cfb7dbbeacaece2cee0311 Mon Sep 17 00:00:00 2001 From: andthattoo Date: Mon, 9 Mar 2026 01:08:05 +0300 Subject: [PATCH 36/57] Fix logprob extraction: always use output index 0 in generation loop --- src/inference/engine.rs | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/inference/engine.rs b/src/inference/engine.rs index 4aee719a..2329336a 100644 --- a/src/inference/engine.rs +++ b/src/inference/engine.rs @@ -234,10 +234,6 @@ impl InferenceEngine { let mut logprobs: Vec = Vec::new(); let mut current_pos = tokens.len() as i32; let mut decoder = encoding_rs::UTF_8.new_decoder(); - // Track the sequence position where logits are available (for extract_logprob). - // sampler.sample() always uses -1 ("last logits" in the C API). - let mut logits_pos: i32 = (tokens.len() - 1) as i32; - for _ in 0..params.max_tokens { let new_token = sampler.sample(&ctx, -1); sampler.accept(new_token); @@ -246,11 +242,12 @@ impl InferenceEngine { break; } - // Extract logprobs at stride positions + // Extract logprobs at stride positions. + // Each decode has exactly one output token, so the output index is always 0. let gen_index = generated_count as usize; if params.logprob_every_n > 0 && gen_index.is_multiple_of(params.logprob_every_n) { if let Some(lp) = - self.extract_logprob(&ctx, logits_pos, gen_index, new_token, params.logprob_top_k) + self.extract_logprob(&ctx, 0, gen_index, new_token, params.logprob_top_k) { logprobs.push(lp); } @@ -280,7 +277,6 @@ impl InferenceEngine { .map_err(|e| NodeError::Inference(format!("batch add failed: {e}")))?; ctx.decode(&mut batch) .map_err(|e| NodeError::Inference(format!("decode failed: {e}")))?; - logits_pos = current_pos; // logits available at the position we just decoded current_pos += 1; } From 9d3a592d4edc7fb45444dcefdaafe09120f7b41e Mon Sep 17 00:00:00 2001 From: andthattoo Date: Mon, 9 Mar 2026 01:14:56 +0300 Subject: [PATCH 37/57] Fix get_logits_ith: use batch index in both generation and validation paths --- src/inference/engine.rs | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/src/inference/engine.rs b/src/inference/engine.rs index 2329336a..23f71f69 100644 --- a/src/inference/engine.rs +++ b/src/inference/engine.rs @@ -234,6 +234,10 @@ impl InferenceEngine { let mut logprobs: Vec = Vec::new(); let mut current_pos = tokens.len() as i32; let mut decoder = encoding_rs::UTF_8.new_decoder(); + // Batch index where logits are available: + // after prompt eval → last prompt token; after each single-token decode → 0 + let mut logit_batch_idx: i32 = (tokens.len() - 1) as i32; + for _ in 0..params.max_tokens { let new_token = sampler.sample(&ctx, -1); sampler.accept(new_token); @@ -242,12 +246,11 @@ impl InferenceEngine { break; } - // Extract logprobs at stride positions. - // Each decode has exactly one output token, so the output index is always 0. + // Extract logprobs at stride positions let gen_index = generated_count as usize; if params.logprob_every_n > 0 && gen_index.is_multiple_of(params.logprob_every_n) { if let Some(lp) = - self.extract_logprob(&ctx, 0, gen_index, new_token, params.logprob_top_k) + self.extract_logprob(&ctx, logit_batch_idx, gen_index, new_token, params.logprob_top_k) { logprobs.push(lp); } @@ -277,6 +280,7 @@ impl InferenceEngine { .map_err(|e| NodeError::Inference(format!("batch add failed: {e}")))?; ctx.decode(&mut batch) .map_err(|e| NodeError::Inference(format!("decode failed: {e}")))?; + logit_batch_idx = 0; // single-token batch → logits at batch index 0 current_pos += 1; } @@ -544,14 +548,13 @@ impl InferenceEngine { .map_err(|e| NodeError::Inference(format!("prefill decode failed: {e}")))?; // Extract logprobs at each probe position. - // get_logits_ith(pos) expects the sequence position passed to batch.add(). + // get_logits_ith takes the batch index where output=true was set. let mut logprobs: Vec = Vec::new(); for (probe_idx, &gen_index) in probe_gen_indices.iter().enumerate() { let target_token = all_tokens[n_prompt + gen_index]; - // get_logits_ith takes the output index (0-based among tokens with output=true), - // NOT the sequence position. + let batch_idx = output_positions[probe_idx] as i32; if let Some(lp) = - self.extract_logprob(&ctx, probe_idx as i32, gen_index, target_token, logprob_top_k) + self.extract_logprob(&ctx, batch_idx, gen_index, target_token, logprob_top_k) { logprobs.push(lp); } From 42705aefa3821747bce3ea0048f8fbf6f2dda445 Mon Sep 17 00:00:00 2001 From: andthattoo Date: Mon, 9 Mar 2026 03:54:37 +0300 Subject: [PATCH 38/57] Switch dkn-protocol from path to git dependency Pin to rev ce026e3 so CI and other machines can build without a local checkout of dkn-protocol. Co-Authored-By: Claude Opus 4.6 --- Cargo.lock | 1 + Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/Cargo.lock b/Cargo.lock index 0562eb9c..bc960bec 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -412,6 +412,7 @@ dependencies = [ [[package]] name = "dkn-protocol" version = "0.1.0" +source = "git+https://github.com/firstbatchxyz/dkn-protocol.git?rev=ce026e3#ce026e35f646b4ae3cdf6579758cb65b60c5aa97" dependencies = [ "base64", "quinn", diff --git a/Cargo.toml b/Cargo.toml index 94fbc986..03298c20 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,7 +35,7 @@ uuid = { version = "1", features = ["v7", "serde"] } dialoguer = "0.11" dirs = "6" encoding_rs = "0.8" -dkn-protocol = { path = "../dkn-protocol" } +dkn-protocol = { git = "https://github.com/firstbatchxyz/dkn-protocol.git", rev = "ce026e3" } quinn = "0.11" rustls = { version = "0.23", default-features = false, features = ["ring", "std"] } rustls-native-certs = "0.8" From 6064734f2a6970f1dac96c5edfd1d4b3f5c6f1e2 Mon Sep 17 00:00:00 2001 From: andthattoo Date: Tue, 10 Mar 2026 02:34:08 +0300 Subject: [PATCH 39/57] Add structured output support (response_format) and fix double-accept sampler bug Plumb OpenAI-compatible response_format parameter through the stack: - ResponseFormat enum (JsonObject, JsonSchema) from dkn-protocol - Worker converts response_format to GBNF grammar via json_schema_to_grammar - Grammar sampler inserted first in chain (masks invalid tokens before sampling) - Both generate() and generate_multimodal() support grammar constraints Fix critical double-accept bug: sample() internally calls accept(), so the explicit sampler.accept() after sample() was advancing grammar stacks twice, causing GGML_ASSERT(!stacks.empty()) crash on any grammar-constrained generation. Switch dkn-protocol dependency from rev to branch = "main". --- Cargo.lock | 3 +- Cargo.toml | 2 +- src/inference/engine.rs | 286 ++++++++++++++++++++++++++++++++++++-- src/main.rs | 3 +- src/network/connection.rs | 1 + src/network/protocol.rs | 2 +- src/worker.rs | 25 +++- 7 files changed, 307 insertions(+), 15 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index bc960bec..49f36a03 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -412,12 +412,13 @@ dependencies = [ [[package]] name = "dkn-protocol" version = "0.1.0" -source = "git+https://github.com/firstbatchxyz/dkn-protocol.git?rev=ce026e3#ce026e35f646b4ae3cdf6579758cb65b60c5aa97" +source = "git+https://github.com/firstbatchxyz/dkn-protocol.git?branch=main#c3b1fb6f7884feffd8b6ccc8a1ab11debc0ab8f3" dependencies = [ "base64", "quinn", "rmp-serde", "serde", + "serde_json", "thiserror 2.0.18", "uuid", ] diff --git a/Cargo.toml b/Cargo.toml index 03298c20..99ede426 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,7 +35,7 @@ uuid = { version = "1", features = ["v7", "serde"] } dialoguer = "0.11" dirs = "6" encoding_rs = "0.8" -dkn-protocol = { git = "https://github.com/firstbatchxyz/dkn-protocol.git", rev = "ce026e3" } +dkn-protocol = { git = "https://github.com/firstbatchxyz/dkn-protocol.git", branch = "main" } quinn = "0.11" rustls = { version = "0.23", default-features = false, features = ["ring", "std"] } rustls-native-certs = "0.8" diff --git a/src/inference/engine.rs b/src/inference/engine.rs index 23f71f69..0b2eb825 100644 --- a/src/inference/engine.rs +++ b/src/inference/engine.rs @@ -30,6 +30,8 @@ pub struct GenerateParams { pub logprob_every_n: usize, /// Top-k alternatives to collect at each logprob position. pub logprob_top_k: usize, + /// Optional GBNF grammar string for constrained output. + pub grammar: Option, } impl Default for GenerateParams { @@ -41,6 +43,7 @@ impl Default for GenerateParams { seed: None, logprob_every_n: 0, logprob_top_k: 5, + grammar: None, } } } @@ -216,8 +219,14 @@ impl InferenceEngine { .map_err(|e| NodeError::Inference(format!("prompt decode failed: {e}")))?; let prompt_eval_time_ms = prompt_start.elapsed().as_millis() as u64; - // Build sampler chain (seed is passed via the dist sampler) + // Build sampler chain (grammar first to mask invalid tokens, then sampling) let mut samplers = vec![]; + if let Some(ref grammar_str) = params.grammar { + samplers.push( + LlamaSampler::grammar(&self.model, grammar_str, "root") + .map_err(|e| NodeError::Inference(format!("grammar error: {e}")))?, + ); + } if params.temperature > 0.0 { samplers.push(LlamaSampler::top_p(params.top_p, 1)); samplers.push(LlamaSampler::temp(params.temperature)); @@ -239,8 +248,8 @@ impl InferenceEngine { let mut logit_batch_idx: i32 = (tokens.len() - 1) as i32; for _ in 0..params.max_tokens { + // sample() internally calls apply + select + accept let new_token = sampler.sample(&ctx, -1); - sampler.accept(new_token); if self.model.is_eog_token(new_token) { break; @@ -379,8 +388,14 @@ impl InferenceEngine { .map_err(|e| NodeError::Inference(format!("mtmd eval_chunks failed: {e}")))?; let prompt_eval_time_ms = prompt_start.elapsed().as_millis() as u64; - // Build sampler chain + // Build sampler chain (grammar first to mask invalid tokens, then sampling) let mut samplers = vec![]; + if let Some(ref grammar_str) = params.grammar { + samplers.push( + LlamaSampler::grammar(&self.model, grammar_str, "root") + .map_err(|e| NodeError::Inference(format!("grammar error: {e}")))?, + ); + } if params.temperature > 0.0 { samplers.push(LlamaSampler::top_p(params.top_p, 1)); samplers.push(LlamaSampler::temp(params.temperature)); @@ -403,8 +418,8 @@ impl InferenceEngine { let mut logits_idx: i32 = -1; for _ in 0..params.max_tokens { + // sample() internally calls apply + select + accept let new_token = sampler.sample(&ctx, logits_idx); - sampler.accept(new_token); if self.model.is_eog_token(new_token) { break; @@ -774,11 +789,8 @@ mod tests { assert!(result.tokens_generated > 0); } - /// Helper to load lfm2.5:1.2b from cache (or download). - async fn load_text_model() -> (InferenceEngine, String) { - let registry = crate::models::default_registry(); - let spec = registry.get("lfm2.5:1.2b").unwrap().clone(); - + /// Helper to load a model from cache (or download). + async fn load_model(spec: crate::models::registry::ModelSpec) -> (InferenceEngine, String) { let cache_dir = dirs::cache_dir() .unwrap_or_else(|| std::path::PathBuf::from(".")) .join("dria-test-models"); @@ -793,8 +805,29 @@ mod tests { cache.link_model(&spec, &hf_path).unwrap() }; + let name = spec.name.clone(); let engine = InferenceEngine::load(&model_path, 0, None).unwrap(); - (engine, spec.name) + (engine, name) + } + + /// Load lfm2.5:1.2b from the default registry. + async fn load_text_model() -> (InferenceEngine, String) { + let registry = crate::models::default_registry(); + let spec = registry.get("lfm2.5:1.2b").unwrap().clone(); + load_model(spec).await + } + + /// Load a small Qwen 3.5 model for grammar-compatible testing. + async fn load_qwen_model() -> (InferenceEngine, String) { + let spec = crate::models::registry::ModelSpec { + name: "qwen3.5:0.8b".into(), + hf_repo: "unsloth/Qwen3.5-0.8B-GGUF".into(), + hf_file: "Qwen3.5-0.8B-Q4_K_M.gguf".into(), + sha256: None, + model_type: dkn_protocol::ModelType::Text, + hf_mmproj_file: None, + }; + load_model(spec).await } /// End-to-end validation test: @@ -871,4 +904,237 @@ mod tests { println!("\nall positions match — validation passed!"); } + + /// End-to-end structured output test: + /// 1. Test a trivial GBNF grammar to verify grammar sampling works + /// 2. Generate with json_object grammar (greedy) — output must be valid JSON + /// 3. Generate with json_schema grammar — output must match the schema + /// + /// Run with: + /// cargo test test_structured_output_e2e -- --ignored --nocapture + #[tokio::test] + #[ignore] // requires qwen3.5:0.8b model (~533 MB download) + async fn test_structured_output_e2e() { + let (engine, _model_name) = load_qwen_model().await; + + // --- Step 1: trivial GBNF grammar to confirm grammar sampling works --- + { + let grammar = r#"root ::= "hello""#.to_string(); + let messages = vec![ChatMessage { + role: "user".into(), + content: "Say hello".into(), + }]; + let prompt = engine.apply_template(&messages).unwrap(); + + let params = GenerateParams { + max_tokens: 16, + temperature: 0.0, + grammar: Some(grammar), + ..Default::default() + }; + + println!("\n--- trivial grammar test ---"); + let result = engine + .generate(&prompt, ¶ms, |_| ControlFlow::Continue(())) + .unwrap(); + println!("output: {:?}", result.text); + assert_eq!(result.text, "hello", "trivial grammar should constrain to 'hello'"); + println!("trivial grammar OK"); + } + + // --- Step 2: json_object mode (permissive JSON) --- + { + let json_grammar = llama_cpp_2::json_schema_to_grammar(r#"{"type": "object"}"#) + .expect("json_object grammar should convert"); + println!("\njson_object grammar length: {} chars", json_grammar.len()); + + let messages = vec![ChatMessage { + role: "user".into(), + content: "Return a JSON object with a field called 'answer' set to 42.".into(), + }]; + let prompt = engine.apply_template(&messages).unwrap(); + + let params = GenerateParams { + max_tokens: 128, + temperature: 0.0, + grammar: Some(json_grammar), + ..Default::default() + }; + + print!("\n--- json_object output ---\n"); + let result = engine + .generate(&prompt, ¶ms, |tok| { + print!("{}", tok.text); + ControlFlow::Continue(()) + }) + .unwrap(); + println!("\n--- end ---"); + + let text = result.text.trim(); + assert!(!text.is_empty(), "should produce output"); + + let parsed: serde_json::Value = + serde_json::from_str(text).expect("json_object output must be valid JSON"); + assert!(parsed.is_object(), "should be a JSON object"); + println!("parsed JSON: {parsed}"); + } + + // --- Step 3: json_schema mode (specific schema) --- + { + let schema = serde_json::json!({ + "type": "object", + "properties": { + "name": { "type": "string" }, + "age": { "type": "integer" } + }, + "required": ["name", "age"], + "additionalProperties": false + }); + + let schema_str = serde_json::to_string(&schema).unwrap(); + let schema_grammar = llama_cpp_2::json_schema_to_grammar(&schema_str) + .expect("json_schema grammar should convert"); + println!("\njson_schema grammar length: {} chars", schema_grammar.len()); + + let messages = vec![ChatMessage { + role: "user".into(), + content: "Give me a person named Alice who is 30 years old.".into(), + }]; + let prompt = engine.apply_template(&messages).unwrap(); + + let params = GenerateParams { + max_tokens: 128, + temperature: 0.0, + grammar: Some(schema_grammar), + ..Default::default() + }; + + print!("\n--- json_schema output ---\n"); + let result = engine + .generate(&prompt, ¶ms, |tok| { + print!("{}", tok.text); + ControlFlow::Continue(()) + }) + .unwrap(); + println!("\n--- end ---"); + + let text = result.text.trim(); + assert!(!text.is_empty(), "should produce output"); + + let parsed: serde_json::Value = + serde_json::from_str(text).expect("json_schema output must be valid JSON"); + assert!(parsed.is_object(), "should be a JSON object"); + assert!(parsed.get("name").is_some(), "should have 'name' field"); + assert!(parsed.get("age").is_some(), "should have 'age' field"); + assert!(parsed["name"].is_string(), "'name' should be a string"); + assert!(parsed["age"].is_number(), "'age' should be a number"); + println!("parsed JSON: {parsed}"); + } + + println!("\nstructured output test passed!"); + } + + /// Grammar test with lfm2.5:1.2b — verify grammar sampling works across tokenizer types. + /// + /// Run with: + /// cargo test test_structured_output_lfm2 -- --ignored --nocapture + #[tokio::test] + #[ignore] // requires lfm2.5:1.2b model (~800 MB) + async fn test_structured_output_lfm2() { + let (engine, _model_name) = load_text_model().await; + + // Trivial grammar + { + let grammar = r#"root ::= "hello""#.to_string(); + let messages = vec![ChatMessage { + role: "user".into(), + content: "Say hello".into(), + }]; + let prompt = engine.apply_template(&messages).unwrap(); + + let params = GenerateParams { + max_tokens: 16, + temperature: 0.0, + grammar: Some(grammar), + ..Default::default() + }; + + println!("\n--- lfm2 trivial grammar test ---"); + let result = engine + .generate(&prompt, ¶ms, |_| ControlFlow::Continue(())) + .unwrap(); + println!("output: {:?}", result.text); + assert_eq!(result.text, "hello"); + println!("trivial grammar OK"); + } + + // Class-like schema with nested object, array, and enum + { + let schema = serde_json::json!({ + "type": "object", + "properties": { + "name": { "type": "string" }, + "age": { "type": "integer" }, + "role": { "type": "string", "enum": ["admin", "user", "moderator"] }, + "address": { + "type": "object", + "properties": { + "city": { "type": "string" }, + "country": { "type": "string" } + }, + "required": ["city", "country"], + "additionalProperties": false + }, + "tags": { + "type": "array", + "items": { "type": "string" } + } + }, + "required": ["name", "age", "role", "address", "tags"], + "additionalProperties": false + }); + + let schema_str = serde_json::to_string(&schema).unwrap(); + let grammar = llama_cpp_2::json_schema_to_grammar(&schema_str) + .expect("class schema should convert"); + println!("\nclass schema grammar length: {} chars", grammar.len()); + + let messages = vec![ChatMessage { + role: "user".into(), + content: "Create a user profile for Alice, age 30, admin role, lives in Istanbul Turkey, tags: developer and lead.".into(), + }]; + let prompt = engine.apply_template(&messages).unwrap(); + + let params = GenerateParams { + max_tokens: 256, + temperature: 0.0, + grammar: Some(grammar), + ..Default::default() + }; + + print!("\n--- lfm2 class-like schema output ---\n"); + let result = engine + .generate(&prompt, ¶ms, |tok| { + print!("{}", tok.text); + ControlFlow::Continue(()) + }) + .unwrap(); + println!("\n--- end ---"); + + let parsed: serde_json::Value = + serde_json::from_str(result.text.trim()).expect("must be valid JSON"); + assert!(parsed.is_object()); + assert!(parsed["name"].is_string()); + assert!(parsed["age"].is_number()); + let role = parsed["role"].as_str().unwrap(); + assert!(["admin", "user", "moderator"].contains(&role), "role must be enum value"); + assert!(parsed["address"].is_object()); + assert!(parsed["address"]["city"].is_string()); + assert!(parsed["address"]["country"].is_string()); + assert!(parsed["tags"].is_array()); + println!("parsed: {parsed}"); + } + + println!("\nlfm2 structured output OK"); + } } diff --git a/src/main.rs b/src/main.rs index ec9bfa7c..3eda28ba 100644 --- a/src/main.rs +++ b/src/main.rs @@ -397,6 +397,7 @@ async fn handle_router_message( temperature, validation, stream, + response_format, } => { tracing::info!(%task_id, %model, stream, "received task assignment"); let stream_tx = if stream { @@ -404,7 +405,7 @@ async fn handle_router_message( } else { None }; - match worker.try_accept(task_id, &model, messages, max_tokens, temperature, validation, stream, stream_tx) + match worker.try_accept(task_id, &model, messages, max_tokens, temperature, validation, stream, stream_tx, response_format) { Ok(()) => { tracing::debug!(%task_id, "task accepted"); diff --git a/src/network/connection.rs b/src/network/connection.rs index 57305caa..68b48db1 100644 --- a/src/network/connection.rs +++ b/src/network/connection.rs @@ -605,6 +605,7 @@ mod tests { temperature: 0.7, validation: None, stream: false, + response_format: None, }, ) .await diff --git a/src/network/protocol.rs b/src/network/protocol.rs index 9f53ef2b..2b8460a2 100644 --- a/src/network/protocol.rs +++ b/src/network/protocol.rs @@ -1,4 +1,4 @@ pub use dkn_protocol::{ read_framed, write_framed, Capacity, ChatMessage, ModelType, NodeMessage, RejectReason, - RouterMessage, TaskStats, ValidationRequest, + ResponseFormat, RouterMessage, TaskStats, ValidationRequest, }; diff --git a/src/worker.rs b/src/worker.rs index 2ab3b14c..1092c6b5 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -12,7 +12,8 @@ use uuid::Uuid; use crate::error::NodeError; use crate::inference::{GenerateParams, InferenceEngine, InferenceResult}; use crate::network::protocol::{ - Capacity, ChatMessage, ModelType, NodeMessage, RejectReason, TaskStats, ValidationRequest, + Capacity, ChatMessage, ModelType, NodeMessage, RejectReason, ResponseFormat, TaskStats, + ValidationRequest, }; /// A completed inference task ready to be sent back. @@ -71,6 +72,7 @@ impl Worker { validation: Option, stream: bool, stream_tx: Option>, + response_format: Option, ) -> Result<(), RejectReason> { // Look up engine + model_type for the requested model (fail fast before decrementing capacity) let (engine, model_type) = self @@ -114,6 +116,26 @@ impl Worker { } } + // Convert response_format to GBNF grammar + let grammar = match response_format { + Some(ResponseFormat::JsonObject) => { + let schema = r#"{"type": "object"}"#; + Some( + llama_cpp_2::json_schema_to_grammar(schema) + .map_err(|e| RejectReason::InvalidRequest(format!("json grammar error: {e}")))?, + ) + } + Some(ResponseFormat::JsonSchema { ref json_schema }) => { + let schema_str = serde_json::to_string(&json_schema.schema) + .map_err(|e| RejectReason::InvalidRequest(format!("invalid schema: {e}")))?; + Some( + llama_cpp_2::json_schema_to_grammar(&schema_str) + .map_err(|e| RejectReason::InvalidRequest(format!("schema conversion failed: {e}")))?, + ) + } + None => None, + }; + // Build generate params let params = GenerateParams { max_tokens, @@ -125,6 +147,7 @@ impl Worker { .map(|v| v.logprob_every_n) .unwrap_or(0), logprob_top_k: validation.as_ref().map(|v| v.logprob_top_k).unwrap_or(5), + grammar, }; let capacity = Arc::clone(&self.capacity); From c181ade40ba58f6d916b4a9ea2bb549fe7fda053 Mon Sep 17 00:00:00 2001 From: andthattoo Date: Tue, 10 Mar 2026 18:55:46 +0300 Subject: [PATCH 40/57] Fix multimodal inference crash: use -1 sentinel for logits index in sampling generate_multimodal() was setting logits_idx to the sequence position (current_pos) after single-token decode, but get_logits_ith expects a batch output index. Single-token decode has only one output slot (index 0), so passing e.g. 55 caused a panic. Always use -1 (C API sentinel for "last logits"), matching the pattern in generate(). Co-Authored-By: Claude Opus 4.6 --- src/inference/engine.rs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/inference/engine.rs b/src/inference/engine.rs index 0b2eb825..0886f7f3 100644 --- a/src/inference/engine.rs +++ b/src/inference/engine.rs @@ -413,13 +413,13 @@ impl InferenceEngine { let mut current_pos = n_past; let mut decoder = encoding_rs::UTF_8.new_decoder(); let mut batch = LlamaBatch::new(1, 1); - // After eval_chunks the logits index is opaque; use -1 for first sample. + // Always use -1 (C API sentinel for "last logits") for sampling. + // After single-token decode, batch output index is 0, but -1 always works. // Multimodal tasks skip validation so logprob extraction is not needed. - let mut logits_idx: i32 = -1; for _ in 0..params.max_tokens { // sample() internally calls apply + select + accept - let new_token = sampler.sample(&ctx, logits_idx); + let new_token = sampler.sample(&ctx, -1); if self.model.is_eog_token(new_token) { break; @@ -451,7 +451,6 @@ impl InferenceEngine { .map_err(|e| NodeError::Inference(format!("batch add failed: {e}")))?; ctx.decode(&mut batch) .map_err(|e| NodeError::Inference(format!("decode failed: {e}")))?; - logits_idx = current_pos; current_pos += 1; } From 2c987ee28b61d88bcaf8a04552b363efcc9b4ebc Mon Sep 17 00:00:00 2001 From: andthattoo Date: Thu, 12 Mar 2026 03:38:51 +0300 Subject: [PATCH 41/57] Add qwen3.5:0.8b, qwen3.5:2b, nemotron:30b-a3b to model registry --- src/models/registry.rs | 35 ++++++++++++++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/src/models/registry.rs b/src/models/registry.rs index 4a9b3331..6459e1cb 100644 --- a/src/models/registry.rs +++ b/src/models/registry.rs @@ -94,6 +94,30 @@ pub fn default_registry() -> HashMap { model_type: ModelType::Vision, hf_mmproj_file: Some("mmproj-Qwen3.5-9B-BF16.gguf".into()), }, + ModelSpec { + name: "qwen3.5:0.8b".into(), + hf_repo: "unsloth/Qwen3.5-0.8B-GGUF".into(), + hf_file: "Qwen3.5-0.8B-Q4_K_M.gguf".into(), + sha256: None, + model_type: ModelType::Vision, + hf_mmproj_file: Some("mmproj-BF16.gguf".into()), + }, + ModelSpec { + name: "qwen3.5:2b".into(), + hf_repo: "unsloth/Qwen3.5-2B-GGUF".into(), + hf_file: "Qwen3.5-2B-Q4_K_M.gguf".into(), + sha256: None, + model_type: ModelType::Vision, + hf_mmproj_file: Some("mmproj-BF16.gguf".into()), + }, + ModelSpec { + name: "nemotron:30b-a3b".into(), + hf_repo: "unsloth/Nemotron-3-Nano-30B-A3B-GGUF".into(), + hf_file: "Nemotron-3-Nano-30B-A3B-Q4_K_M.gguf".into(), + sha256: None, + model_type: ModelType::Text, + hf_mmproj_file: None, + }, ]; entries.into_iter().map(|s| (s.name.clone(), s)).collect() @@ -164,11 +188,14 @@ mod tests { "nanbeige:3b", "locooperator:4b", "qwen3.5:9b", + "qwen3.5:0.8b", + "qwen3.5:2b", + "nemotron:30b-a3b", ]; for name in &expected { assert!(reg.contains_key(*name), "missing model: {name}"); } - assert_eq!(reg.len(), 9); + assert_eq!(reg.len(), 12); } #[test] @@ -215,6 +242,9 @@ mod tests { assert_eq!(reg["qwen3.5:9b"].model_type, ModelType::Vision); assert_eq!(reg["qwen3.5:27b"].model_type, ModelType::Vision); assert_eq!(reg["qwen3.5:35b-a3b"].model_type, ModelType::Vision); + assert_eq!(reg["qwen3.5:0.8b"].model_type, ModelType::Vision); + assert_eq!(reg["qwen3.5:2b"].model_type, ModelType::Vision); + assert_eq!(reg["nemotron:30b-a3b"].model_type, ModelType::Text); } #[test] @@ -225,7 +255,10 @@ mod tests { assert!(reg["qwen3.5:9b"].hf_mmproj_file.is_some()); assert!(reg["qwen3.5:27b"].hf_mmproj_file.is_some()); assert!(reg["qwen3.5:35b-a3b"].hf_mmproj_file.is_some()); + assert!(reg["qwen3.5:0.8b"].hf_mmproj_file.is_some()); + assert!(reg["qwen3.5:2b"].hf_mmproj_file.is_some()); assert!(reg["lfm2.5:1.2b"].hf_mmproj_file.is_none()); + assert!(reg["nemotron:30b-a3b"].hf_mmproj_file.is_none()); } #[test] From 48bf710286c86863585474479ffe628042ca779f Mon Sep 17 00:00:00 2001 From: andthattoo Date: Thu, 12 Mar 2026 11:25:16 +0300 Subject: [PATCH 42/57] Add size estimates for new models in setup mode --- src/setup.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/setup.rs b/src/setup.rs index 0ca58fc0..23c74bd3 100644 --- a/src/setup.rs +++ b/src/setup.rs @@ -31,6 +31,9 @@ fn model_size_gb(name: &str) -> Option<(f64, f64)> { "qwen3.5:27b" => Some((16.0, 18.0)), "qwen3.5:35b-a3b" => Some((20.0, 22.0)), "lfm2:24b-a2b" => Some((14.0, 16.0)), + "qwen3.5:0.8b" => Some((0.5, 1.0)), + "qwen3.5:2b" => Some((1.2, 2.0)), + "nemotron:30b-a3b" => Some((24.5, 27.0)), _ => None, } } @@ -449,6 +452,9 @@ mod tests { "qwen3.5:27b", "qwen3.5:35b-a3b", "lfm2:24b-a2b", + "qwen3.5:0.8b", + "qwen3.5:2b", + "nemotron:30b-a3b", ] { let (size, needed) = model_size_gb(name).unwrap(); assert!( From 8cc2ce3913f21a86ac773adcfe0c5aa8b5603aa6 Mon Sep 17 00:00:00 2001 From: andthattoo Date: Thu, 12 Mar 2026 11:46:55 +0300 Subject: [PATCH 43/57] Show multiple capability tags in setup model list --- src/setup.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/setup.rs b/src/setup.rs index 23c74bd3..e193e955 100644 --- a/src/setup.rs +++ b/src/setup.rs @@ -101,8 +101,8 @@ fn detect_ram_bytes() -> Option { fn model_type_label(mt: ModelType) -> &'static str { match mt { ModelType::Text => "Text", - ModelType::Vision => "Vision", - ModelType::Audio => "Audio", + ModelType::Vision => "Text, Vision", + ModelType::Audio => "Text, Audio", } } @@ -170,7 +170,7 @@ pub async fn run_setup(data_dir: Option, gpu_layers: i32) -> Result<(), .iter() .map(|m| { format!( - "{:<22} {:<8} {:<10} ~{:.1} GB", + "{:<22} {:<14} {:<10} ~{:.1} GB", m.name, model_type_label(m.model_type), m.quant, From c33aaac97d2530f6795539d2739f8a46bf3efc51 Mon Sep 17 00:00:00 2001 From: andthattoo Date: Thu, 12 Mar 2026 12:10:28 +0300 Subject: [PATCH 44/57] Add CUDA and ROCm GPU build targets to release workflow --- .github/workflows/releases.yml | 96 +++++++++++++++++++++++++++++++++- Cargo.toml | 1 + 2 files changed, 96 insertions(+), 1 deletion(-) diff --git a/.github/workflows/releases.yml b/.github/workflows/releases.yml index 7c57508c..77dad65b 100644 --- a/.github/workflows/releases.yml +++ b/.github/workflows/releases.yml @@ -124,8 +124,102 @@ jobs: name: dria-node-${{ matrix.osname }}-${{ matrix.arch }} path: dria-node-${{ matrix.osname }}-${{ matrix.arch }}${{ matrix.extension }} + build-cuda: + runs-on: ${{ matrix.runner }} + strategy: + matrix: + include: + - { + runner: ubuntu-latest, + osname: linux, + arch: amd64-cuda, + target: x86_64-unknown-linux-gnu, + cuda: "12.6.3", + } + - { + runner: windows-latest, + osname: windows, + arch: amd64-cuda, + target: x86_64-pc-windows-msvc, + cuda: "12.6.3", + extension: ".exe", + } + + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + ref: ${{ env.RELEASE_TAG }} + + - name: Install CUDA toolkit + uses: Jimver/cuda-toolkit@v0.2.19 + with: + cuda: ${{ matrix.cuda }} + method: network + + - name: Install build dependencies (Linux) + if: matrix.osname == 'linux' + run: sudo apt-get update && sudo apt-get install -y cmake + + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@stable + + - name: Build binary with CUDA + run: cargo build --bin dria-node --release --target ${{ matrix.target }} --features cuda + + - name: Strip binary (Linux) + if: matrix.osname == 'linux' + run: strip target/${{ matrix.target }}/release/dria-node + + - name: Prepare Release File + shell: bash + run: | + mv target/${{ matrix.target }}/release/dria-node${{ matrix.extension }} ./dria-node-${{ matrix.osname }}-${{ matrix.arch }}${{ matrix.extension }} + + - name: Upload Launch Artifacts + uses: actions/upload-artifact@v4 + with: + name: dria-node-${{ matrix.osname }}-${{ matrix.arch }} + path: dria-node-${{ matrix.osname }}-${{ matrix.arch }}${{ matrix.extension }} + + build-rocm: + runs-on: ubuntu-22.04 + + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + ref: ${{ env.RELEASE_TAG }} + + - name: Install ROCm HIP SDK + run: | + wget -q -O - https://repo.radeon.com/rocm/rocm.gpg.key | sudo gpg --dearmor -o /etc/apt/keyrings/rocm.gpg + echo "deb [arch=amd64 signed-by=/etc/apt/keyrings/rocm.gpg] https://repo.radeon.com/rocm/apt/6.3.3 jammy main" | sudo tee /etc/apt/sources.list.d/rocm.list + echo 'Package: *\nPin: release o=repo.radeon.com\nPin-Priority: 600' | sudo tee /etc/apt/preferences.d/rocm-pin-600 + sudo apt-get update + sudo apt-get install -y --no-install-recommends cmake hip-dev rocm-hip-runtime-dev + + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@stable + + - name: Build binary with ROCm + run: cargo build --bin dria-node --release --target x86_64-unknown-linux-gnu --features rocm + + - name: Strip binary + run: strip target/x86_64-unknown-linux-gnu/release/dria-node + + - name: Prepare Release File + run: | + mv target/x86_64-unknown-linux-gnu/release/dria-node ./dria-node-linux-amd64-rocm + + - name: Upload Launch Artifacts + uses: actions/upload-artifact@v4 + with: + name: dria-node-linux-amd64-rocm + path: dria-node-linux-amd64-rocm + release: - needs: build + needs: [build, build-cuda, build-rocm] runs-on: ubuntu-latest steps: diff --git a/Cargo.toml b/Cargo.toml index 99ede426..9783f8b4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,7 @@ path = "src/main.rs" default = [] cuda = ["llama-cpp-2/cuda"] metal = ["llama-cpp-2/metal"] +rocm = ["llama-cpp-2/rocm"] [dependencies] clap = { version = "4", features = ["derive", "env"] } From bcea3e82c4c5742b136981271e774336f2968cb0 Mon Sep 17 00:00:00 2001 From: andthattoo Date: Thu, 12 Mar 2026 13:09:09 +0300 Subject: [PATCH 45/57] Add TESTER_GUIDE.md to gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 2ee6387e..ecd4ed6e 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,4 @@ flamegraph.svg +TESTER_GUIDE.md From 7aa7e25851649dc0f67f6d49d933fd4145e409eb Mon Sep 17 00:00:00 2001 From: andthattoo Date: Thu, 12 Mar 2026 16:51:11 +0300 Subject: [PATCH 46/57] Add dynamic context sizing, error propagation, and pre-flight rejection - Replace hardcoded n_ctx=2048/4096 with model-native context window (auto-detected from GGUF metadata via n_ctx_train) - Add --context-size / DRIA_CONTEXT_SIZE to optionally cap context for limited VRAM (uses min of model native and cap) - Add pre-flight check in worker: reject tasks where prompt_tokens + max_tokens exceeds context before consuming a capacity slot - Propagate inference errors back to router via StreamError so it can retry on another node instead of waiting for timeout Co-Authored-By: Claude Opus 4.6 --- src/config.rs | 16 ++++++++++ src/inference/engine.rs | 67 +++++++++++++++++++++++++++++++++-------- src/main.rs | 21 ++++++++++--- src/setup.rs | 2 +- src/worker.rs | 21 +++++++++++++ 5 files changed, 109 insertions(+), 18 deletions(-) diff --git a/src/config.rs b/src/config.rs index 543b08c5..1285e5e8 100644 --- a/src/config.rs +++ b/src/config.rs @@ -61,6 +61,11 @@ pub enum Command { /// Skip automatic update check on startup #[arg(long, env = "DRIA_SKIP_UPDATE")] skip_update: bool, + + /// Maximum context window size (tokens). When set, engines use min(model_native, this value). + /// When unset, engines use the model's full native context window. + #[arg(long, env = "DRIA_CONTEXT_SIZE")] + context_size: Option, }, } @@ -76,6 +81,7 @@ pub struct Config { pub quant: Option, pub insecure: bool, pub skip_update: bool, + pub max_context: Option, } impl Config { @@ -91,6 +97,7 @@ impl Config { quant: Option, insecure: bool, skip_update: bool, + max_context: Option, ) -> Result { // Validate wallet key let secret_key_hex = wallet.strip_prefix("0x").unwrap_or(&wallet).to_string(); @@ -147,6 +154,7 @@ impl Config { quant, insecure, skip_update, + max_context, }) } } @@ -167,6 +175,7 @@ mod tests { None, false, false, + None, ) .unwrap(); @@ -191,6 +200,7 @@ mod tests { None, false, false, + None, ); assert!(result.is_err()); } @@ -207,6 +217,7 @@ mod tests { None, false, false, + None, ); assert!(result.is_err()); } @@ -223,6 +234,7 @@ mod tests { None, false, false, + None, ); assert!(result.is_err()); } @@ -239,6 +251,7 @@ mod tests { None, false, false, + None, ); assert!(result.is_err()); } @@ -255,6 +268,7 @@ mod tests { None, false, false, + None, ) .unwrap(); assert_eq!( @@ -275,6 +289,7 @@ mod tests { None, false, false, + None, ); assert!(result.is_err()); } @@ -291,6 +306,7 @@ mod tests { None, true, false, + None, ) .unwrap(); assert!(cfg.insecure); diff --git a/src/inference/engine.rs b/src/inference/engine.rs index 0886f7f3..9d1a0fa9 100644 --- a/src/inference/engine.rs +++ b/src/inference/engine.rs @@ -70,6 +70,8 @@ pub struct InferenceEngine { mtmd_ctx: Option, #[allow(dead_code)] gpu_layers: i32, + /// Effective context window size (tokens), auto-detected from model metadata. + ctx_limit: u32, } /// Helper to convert a token to a string piece using the new token_to_piece API. @@ -82,10 +84,14 @@ fn token_to_string(model: &LlamaModel, token: LlamaToken) -> String { impl InferenceEngine { /// Load a GGUF model from disk, optionally with a multimodal projector. + /// + /// `max_context` optionally caps the context window (e.g. for limited VRAM). + /// When `None`, the model's full native context window is used. pub fn load( path: &Path, gpu_layers: i32, mmproj_path: Option<&Path>, + max_context: Option, ) -> Result { let backend = LlamaBackend::init() .map_err(|e| NodeError::Inference(format!("failed to init llama backend: {e}")))?; @@ -100,6 +106,13 @@ impl InferenceEngine { let model = LlamaModel::load_from_file(&backend, path, &model_params) .map_err(|e| NodeError::Inference(format!("failed to load model: {e}")))?; + let n_ctx_train = model.n_ctx_train(); + let ctx_limit = match max_context { + Some(cap) => n_ctx_train.min(cap), + None => n_ctx_train, + }; + tracing::info!(model_ctx = n_ctx_train, effective_ctx = ctx_limit, "context window"); + let mtmd_ctx = match mmproj_path { Some(p) => { let params = MtmdContextParams::default(); @@ -126,6 +139,7 @@ impl InferenceEngine { model, mtmd_ctx, gpu_layers, + ctx_limit, }) } @@ -140,6 +154,26 @@ impl InferenceEngine { self.gpu_layers } + /// The model's native training context length. + #[allow(dead_code)] + pub fn n_ctx_train(&self) -> u32 { + self.model.n_ctx_train() + } + + /// The effective context limit (possibly capped by --context-size). + pub fn ctx_limit(&self) -> u32 { + self.ctx_limit + } + + /// Count prompt tokens without creating a context (LlamaModel is Send+Sync). + pub fn tokenize_count(&self, messages: &[ChatMessage]) -> Result { + let prompt = self.apply_template(messages)?; + let tokens = self.model + .str_to_token(&prompt, AddBos::Always) + .map_err(|e| NodeError::Inference(format!("tokenization failed: {e}")))?; + Ok(tokens.len() as u32) + } + /// Apply the GGUF-embedded chat template to produce a formatted prompt string. pub fn apply_template(&self, messages: &[ChatMessage]) -> Result { let template = self @@ -191,14 +225,6 @@ impl InferenceEngine { where F: FnMut(StreamToken) -> ControlFlow<()>, { - let ctx_size = std::num::NonZeroU32::new(2048); - let ctx_params = LlamaContextParams::default().with_n_ctx(ctx_size); - - let mut ctx = self - .model - .new_context(&self.backend, ctx_params) - .map_err(|e| NodeError::Inference(format!("failed to create context: {e}")))?; - // Tokenize prompt let tokens = self .model @@ -206,6 +232,23 @@ impl InferenceEngine { .map_err(|e| NodeError::Inference(format!("tokenization failed: {e}")))?; let prompt_token_count = tokens.len() as u32; + // Pre-flight: check that prompt + max_tokens fits in context + let needed = prompt_token_count + params.max_tokens; + if needed > self.ctx_limit { + return Err(NodeError::Inference(format!( + "prompt ({prompt_token_count}) + max_tokens ({}) = {needed} exceeds context ({})", + params.max_tokens, self.ctx_limit + ))); + } + + let ctx_size = std::num::NonZeroU32::new(self.ctx_limit); + let ctx_params = LlamaContextParams::default().with_n_ctx(ctx_size); + + let mut ctx = self + .model + .new_context(&self.backend, ctx_params) + .map_err(|e| NodeError::Inference(format!("failed to create context: {e}")))?; + // Evaluate prompt let prompt_start = Instant::now(); let mut batch = LlamaBatch::new(tokens.len().max(1), 1); @@ -372,8 +415,8 @@ impl InferenceEngine { let prompt_token_count = chunks.total_tokens() as u32; - // Create context with larger size for multimodal - let ctx_size = std::num::NonZeroU32::new(4096); + // Create context sized to the model's effective limit + let ctx_size = std::num::NonZeroU32::new(self.ctx_limit); let ctx_params = LlamaContextParams::default().with_n_ctx(ctx_size); let mut ctx = self @@ -735,7 +778,7 @@ mod tests { // Load engine with multimodal projector println!("loading model + mmproj..."); - let engine = InferenceEngine::load(&model_path, 0, Some(&mmproj_path)).unwrap(); + let engine = InferenceEngine::load(&model_path, 0, Some(&mmproj_path), None).unwrap(); assert!(engine.has_multimodal(), "engine should have multimodal context"); // Get test image: from env var or generate a synthetic BMP @@ -805,7 +848,7 @@ mod tests { }; let name = spec.name.clone(); - let engine = InferenceEngine::load(&model_path, 0, None).unwrap(); + let engine = InferenceEngine::load(&model_path, 0, None, None).unwrap(); (engine, name) } diff --git a/src/main.rs b/src/main.rs index 3eda28ba..f999cd87 100644 --- a/src/main.rs +++ b/src/main.rs @@ -55,8 +55,9 @@ async fn main() -> anyhow::Result<()> { quant, insecure, skip_update, + context_size, } => { - run_start(wallet, model, router_url, gpu_layers, max_concurrent, data_dir, quant, insecure, skip_update).await?; + run_start(wallet, model, router_url, gpu_layers, max_concurrent, data_dir, quant, insecure, skip_update, context_size).await?; } } @@ -90,9 +91,10 @@ async fn run_start( quant: Option, insecure: bool, skip_update: bool, + max_context: Option, ) -> anyhow::Result<()> { // Parse config - let config = Config::from_start_args(wallet, model, router_url, gpu_layers, max_concurrent, data_dir, quant, insecure, skip_update)?; + let config = Config::from_start_args(wallet, model, router_url, gpu_layers, max_concurrent, data_dir, quant, insecure, skip_update, max_context)?; // Create identity let identity = Identity::from_secret_hex(&config.secret_key_hex)?; @@ -140,7 +142,7 @@ async fn run_start( let spec = resolve_model(model_name, ®istry, config.quant.as_deref()) .ok_or_else(|| error::NodeError::Model(format!("unknown model: {model_name}")))?; - let (engine, tps) = download_and_load_model(&spec, &cache, config.gpu_layers).await?; + let (engine, tps) = download_and_load_model(&spec, &cache, config.gpu_layers, config.max_context).await?; tracing::info!(tps = %format!("{tps:.1}"), model = %model_name, "benchmark complete"); engines.insert(model_name.clone(), (engine, spec.model_type)); @@ -496,13 +498,14 @@ async fn handle_router_message( let spec = ModelSpec::from_registry_entry(entry); let cache = ctx.cache.clone(); let gpu_layers = ctx.config.gpu_layers; + let max_context = ctx.config.max_context; let tx = model_tx.clone(); let name = entry.name.clone(); let model_type = entry.model_type; tracing::info!(model = %name, "spawning background model download+load"); tokio::spawn(async move { - let result = download_and_load_model(&spec, &cache, gpu_layers).await; + let result = download_and_load_model(&spec, &cache, gpu_layers, max_context).await; let _ = tx.send(ModelLoadResult { name, model_type, result }); }); } @@ -518,6 +521,7 @@ async fn download_and_load_model( spec: &ModelSpec, cache: &ModelCache, gpu_layers: i32, + max_context: Option, ) -> Result<(inference::InferenceEngine, f64), error::NodeError> { let model_name = spec.name.clone(); @@ -558,7 +562,7 @@ async fn download_and_load_model( // Load model and run benchmark in blocking thread let (engine, tps) = tokio::task::spawn_blocking(move || { - let engine = inference::InferenceEngine::load(&model_path, gpu_layers, mmproj_path.as_deref())?; + let engine = inference::InferenceEngine::load(&model_path, gpu_layers, mmproj_path.as_deref(), max_context)?; let tps_result = engine.benchmark(&model_name)?; Ok::<_, error::NodeError>((engine, tps_result.generation_tps)) }) @@ -598,6 +602,13 @@ fn handle_completed_task( Err(e) => { stats.record_failed(); tracing::error!(%e, task_id = %completed.task_id, "task failed"); + // Propagate error back to router so it can retry on another node + if let Some(ref conn) = connection { + let _ = conn.send(NodeMessage::StreamError { + task_id: completed.task_id, + error: e.to_string(), + }); + } } } } diff --git a/src/setup.rs b/src/setup.rs index e193e955..ab840a06 100644 --- a/src/setup.rs +++ b/src/setup.rs @@ -321,7 +321,7 @@ pub async fn run_setup(data_dir: Option, gpu_layers: i32) -> Result<(), let engine = tokio::task::spawn_blocking({ let model_path = model_path.clone(); let mmproj_path = mmproj_path.clone(); - move || InferenceEngine::load(&model_path, gpu_layers, mmproj_path.as_deref()) + move || InferenceEngine::load(&model_path, gpu_layers, mmproj_path.as_deref(), None) }) .await .map_err(|e| NodeError::Inference(format!("task join error: {e}")))?; diff --git a/src/worker.rs b/src/worker.rs index 1092c6b5..1c4a2a4c 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -99,6 +99,27 @@ impl Worker { )); } + // Pre-flight context check (text-only; multimodal token counting requires a context) + let has_media = has_image || has_audio; + if !has_media { + match engine.tokenize_count(&messages) { + Ok(prompt_tokens) => { + let needed = prompt_tokens + max_tokens; + if needed > engine.ctx_limit() { + return Err(RejectReason::InvalidRequest(format!( + "prompt ({prompt_tokens}) + max_tokens ({max_tokens}) = {needed} exceeds context ({})", + engine.ctx_limit() + ))); + } + } + Err(e) => { + return Err(RejectReason::InvalidRequest(format!( + "tokenization failed: {e}" + ))); + } + } + } + let engine = Arc::clone(engine); // Try to decrement capacity (CAS loop) From 455b2826bfece3c543083b19a426b7fb0e23cf7f Mon Sep 17 00:00:00 2001 From: andthattoo Date: Thu, 12 Mar 2026 16:55:23 +0300 Subject: [PATCH 47/57] Size KV cache to request needs instead of full model context Allocate prompt_tokens + max_tokens per request instead of the full ctx_limit (e.g. 32k). This avoids OOM on machines where the model fits in RAM but a full-size KV cache does not. The ctx_limit ceiling still gates pre-flight rejection. Co-Authored-By: Claude Opus 4.6 --- src/inference/engine.rs | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/inference/engine.rs b/src/inference/engine.rs index 9d1a0fa9..e04fc359 100644 --- a/src/inference/engine.rs +++ b/src/inference/engine.rs @@ -241,7 +241,8 @@ impl InferenceEngine { ))); } - let ctx_size = std::num::NonZeroU32::new(self.ctx_limit); + // Allocate only what this request needs (saves RAM vs full ctx_limit) + let ctx_size = std::num::NonZeroU32::new(needed); let ctx_params = LlamaContextParams::default().with_n_ctx(ctx_size); let mut ctx = self @@ -415,8 +416,9 @@ impl InferenceEngine { let prompt_token_count = chunks.total_tokens() as u32; - // Create context sized to the model's effective limit - let ctx_size = std::num::NonZeroU32::new(self.ctx_limit); + // Allocate only what this request needs (saves RAM vs full ctx_limit) + let needed = prompt_token_count + params.max_tokens; + let ctx_size = std::num::NonZeroU32::new(needed); let ctx_params = LlamaContextParams::default().with_n_ctx(ctx_size); let mut ctx = self @@ -574,8 +576,8 @@ impl InferenceEngine { }); } - // Create context sized to fit all tokens - let ctx_size = std::num::NonZeroU32::new((all_tokens.len() + 64).max(2048) as u32); + // Create context sized to fit all tokens (+ small padding) + let ctx_size = std::num::NonZeroU32::new((all_tokens.len() + 64) as u32); let ctx_params = LlamaContextParams::default().with_n_ctx(ctx_size); let mut ctx = self From 36a8579854ba132388684b25e2714a2c1a8de68f Mon Sep 17 00:00:00 2001 From: andthattoo Date: Thu, 12 Mar 2026 17:02:56 +0300 Subject: [PATCH 48/57] Default KV cache to Q8_0 and add --kv-quant flag KV cache now defaults to Q8_0 quantization instead of F16, roughly halving KV memory with negligible quality loss. Operators can override with --kv-quant (f16, q8_0, q4_0, etc.) or DRIA_KV_QUANT env var. Co-Authored-By: Claude Opus 4.6 --- src/config.rs | 15 +++++++++++++++ src/inference/engine.rs | 28 +++++++++++++++++++++------- src/main.rs | 34 +++++++++++++++++++++++++++++----- src/setup.rs | 2 +- 4 files changed, 66 insertions(+), 13 deletions(-) diff --git a/src/config.rs b/src/config.rs index 1285e5e8..b143db07 100644 --- a/src/config.rs +++ b/src/config.rs @@ -66,6 +66,10 @@ pub enum Command { /// When unset, engines use the model's full native context window. #[arg(long, env = "DRIA_CONTEXT_SIZE")] context_size: Option, + + /// KV cache quantization type (q8_0, q4_0, f16). Default: q8_0 (halves KV memory vs f16). + #[arg(long, env = "DRIA_KV_QUANT", default_value = "q8_0")] + kv_quant: String, }, } @@ -82,6 +86,7 @@ pub struct Config { pub insecure: bool, pub skip_update: bool, pub max_context: Option, + pub kv_quant: String, } impl Config { @@ -98,6 +103,7 @@ impl Config { insecure: bool, skip_update: bool, max_context: Option, + kv_quant: String, ) -> Result { // Validate wallet key let secret_key_hex = wallet.strip_prefix("0x").unwrap_or(&wallet).to_string(); @@ -155,6 +161,7 @@ impl Config { insecure, skip_update, max_context, + kv_quant, }) } } @@ -176,6 +183,7 @@ mod tests { false, false, None, + "q8_0".into(), ) .unwrap(); @@ -201,6 +209,7 @@ mod tests { false, false, None, + "q8_0".into(), ); assert!(result.is_err()); } @@ -218,6 +227,7 @@ mod tests { false, false, None, + "q8_0".into(), ); assert!(result.is_err()); } @@ -235,6 +245,7 @@ mod tests { false, false, None, + "q8_0".into(), ); assert!(result.is_err()); } @@ -252,6 +263,7 @@ mod tests { false, false, None, + "q8_0".into(), ); assert!(result.is_err()); } @@ -269,6 +281,7 @@ mod tests { false, false, None, + "q8_0".into(), ) .unwrap(); assert_eq!( @@ -290,6 +303,7 @@ mod tests { false, false, None, + "q8_0".into(), ); assert!(result.is_err()); } @@ -307,6 +321,7 @@ mod tests { true, false, None, + "q8_0".into(), ) .unwrap(); assert!(cfg.insecure); diff --git a/src/inference/engine.rs b/src/inference/engine.rs index e04fc359..17920aaa 100644 --- a/src/inference/engine.rs +++ b/src/inference/engine.rs @@ -2,7 +2,7 @@ use std::ops::ControlFlow; use std::path::Path; use std::time::Instant; -use llama_cpp_2::context::params::LlamaContextParams; +use llama_cpp_2::context::params::{KvCacheType, LlamaContextParams}; use llama_cpp_2::llama_backend::LlamaBackend; use llama_cpp_2::llama_batch::LlamaBatch; use llama_cpp_2::model::params::LlamaModelParams; @@ -72,6 +72,8 @@ pub struct InferenceEngine { gpu_layers: i32, /// Effective context window size (tokens), auto-detected from model metadata. ctx_limit: u32, + /// KV cache quantization type (default Q8_0 to save memory). + kv_cache_type: KvCacheType, } /// Helper to convert a token to a string piece using the new token_to_piece API. @@ -92,7 +94,9 @@ impl InferenceEngine { gpu_layers: i32, mmproj_path: Option<&Path>, max_context: Option, + kv_cache_type: Option, ) -> Result { + let kv_cache_type = kv_cache_type.unwrap_or(KvCacheType::Q8_0); let backend = LlamaBackend::init() .map_err(|e| NodeError::Inference(format!("failed to init llama backend: {e}")))?; @@ -111,7 +115,7 @@ impl InferenceEngine { Some(cap) => n_ctx_train.min(cap), None => n_ctx_train, }; - tracing::info!(model_ctx = n_ctx_train, effective_ctx = ctx_limit, "context window"); + tracing::info!(model_ctx = n_ctx_train, effective_ctx = ctx_limit, kv_type = ?kv_cache_type, "context window"); let mtmd_ctx = match mmproj_path { Some(p) => { @@ -140,6 +144,7 @@ impl InferenceEngine { mtmd_ctx, gpu_layers, ctx_limit, + kv_cache_type, }) } @@ -243,7 +248,10 @@ impl InferenceEngine { // Allocate only what this request needs (saves RAM vs full ctx_limit) let ctx_size = std::num::NonZeroU32::new(needed); - let ctx_params = LlamaContextParams::default().with_n_ctx(ctx_size); + let ctx_params = LlamaContextParams::default() + .with_n_ctx(ctx_size) + .with_type_k(self.kv_cache_type) + .with_type_v(self.kv_cache_type); let mut ctx = self .model @@ -419,7 +427,10 @@ impl InferenceEngine { // Allocate only what this request needs (saves RAM vs full ctx_limit) let needed = prompt_token_count + params.max_tokens; let ctx_size = std::num::NonZeroU32::new(needed); - let ctx_params = LlamaContextParams::default().with_n_ctx(ctx_size); + let ctx_params = LlamaContextParams::default() + .with_n_ctx(ctx_size) + .with_type_k(self.kv_cache_type) + .with_type_v(self.kv_cache_type); let mut ctx = self .model @@ -578,7 +589,10 @@ impl InferenceEngine { // Create context sized to fit all tokens (+ small padding) let ctx_size = std::num::NonZeroU32::new((all_tokens.len() + 64) as u32); - let ctx_params = LlamaContextParams::default().with_n_ctx(ctx_size); + let ctx_params = LlamaContextParams::default() + .with_n_ctx(ctx_size) + .with_type_k(self.kv_cache_type) + .with_type_v(self.kv_cache_type); let mut ctx = self .model @@ -780,7 +794,7 @@ mod tests { // Load engine with multimodal projector println!("loading model + mmproj..."); - let engine = InferenceEngine::load(&model_path, 0, Some(&mmproj_path), None).unwrap(); + let engine = InferenceEngine::load(&model_path, 0, Some(&mmproj_path), None, None).unwrap(); assert!(engine.has_multimodal(), "engine should have multimodal context"); // Get test image: from env var or generate a synthetic BMP @@ -850,7 +864,7 @@ mod tests { }; let name = spec.name.clone(); - let engine = InferenceEngine::load(&model_path, 0, None, None).unwrap(); + let engine = InferenceEngine::load(&model_path, 0, None, None, None).unwrap(); (engine, name) } diff --git a/src/main.rs b/src/main.rs index f999cd87..85a2038e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -14,6 +14,7 @@ use std::sync::Arc; use std::time::Duration; use clap::Parser; +use llama_cpp_2::context::params::KvCacheType; use tokio::sync::mpsc; use tracing_subscriber::EnvFilter; @@ -56,8 +57,9 @@ async fn main() -> anyhow::Result<()> { insecure, skip_update, context_size, + kv_quant, } => { - run_start(wallet, model, router_url, gpu_layers, max_concurrent, data_dir, quant, insecure, skip_update, context_size).await?; + run_start(wallet, model, router_url, gpu_layers, max_concurrent, data_dir, quant, insecure, skip_update, context_size, kv_quant).await?; } } @@ -92,9 +94,10 @@ async fn run_start( insecure: bool, skip_update: bool, max_context: Option, + kv_quant: String, ) -> anyhow::Result<()> { // Parse config - let config = Config::from_start_args(wallet, model, router_url, gpu_layers, max_concurrent, data_dir, quant, insecure, skip_update, max_context)?; + let config = Config::from_start_args(wallet, model, router_url, gpu_layers, max_concurrent, data_dir, quant, insecure, skip_update, max_context, kv_quant)?; // Create identity let identity = Identity::from_secret_hex(&config.secret_key_hex)?; @@ -130,6 +133,9 @@ async fn run_start( std::fs::create_dir_all(&config.data_dir)?; std::fs::create_dir_all(&config.models_dir)?; + // Parse KV cache quantization type + let kv_cache_type = parse_kv_quant(&config.kv_quant)?; + // Resolve and download models let registry = default_registry(); let cache = ModelCache::new(config.models_dir.clone())?; @@ -142,7 +148,7 @@ async fn run_start( let spec = resolve_model(model_name, ®istry, config.quant.as_deref()) .ok_or_else(|| error::NodeError::Model(format!("unknown model: {model_name}")))?; - let (engine, tps) = download_and_load_model(&spec, &cache, config.gpu_layers, config.max_context).await?; + let (engine, tps) = download_and_load_model(&spec, &cache, config.gpu_layers, config.max_context, Some(kv_cache_type)).await?; tracing::info!(tps = %format!("{tps:.1}"), model = %model_name, "benchmark complete"); engines.insert(model_name.clone(), (engine, spec.model_type)); @@ -499,13 +505,14 @@ async fn handle_router_message( let cache = ctx.cache.clone(); let gpu_layers = ctx.config.gpu_layers; let max_context = ctx.config.max_context; + let kv_type = parse_kv_quant(&ctx.config.kv_quant).ok(); let tx = model_tx.clone(); let name = entry.name.clone(); let model_type = entry.model_type; tracing::info!(model = %name, "spawning background model download+load"); tokio::spawn(async move { - let result = download_and_load_model(&spec, &cache, gpu_layers, max_context).await; + let result = download_and_load_model(&spec, &cache, gpu_layers, max_context, kv_type).await; let _ = tx.send(ModelLoadResult { name, model_type, result }); }); } @@ -522,6 +529,7 @@ async fn download_and_load_model( cache: &ModelCache, gpu_layers: i32, max_context: Option, + kv_cache_type: Option, ) -> Result<(inference::InferenceEngine, f64), error::NodeError> { let model_name = spec.name.clone(); @@ -562,7 +570,7 @@ async fn download_and_load_model( // Load model and run benchmark in blocking thread let (engine, tps) = tokio::task::spawn_blocking(move || { - let engine = inference::InferenceEngine::load(&model_path, gpu_layers, mmproj_path.as_deref(), max_context)?; + let engine = inference::InferenceEngine::load(&model_path, gpu_layers, mmproj_path.as_deref(), max_context, kv_cache_type)?; let tps_result = engine.benchmark(&model_name)?; Ok::<_, error::NodeError>((engine, tps_result.generation_tps)) }) @@ -573,6 +581,22 @@ async fn download_and_load_model( Ok((engine, tps)) } +/// Parse a KV cache quantization string (e.g. "q8_0", "q4_0", "f16") into a `KvCacheType`. +fn parse_kv_quant(s: &str) -> Result { + match s.to_lowercase().as_str() { + "f16" => Ok(KvCacheType::F16), + "f32" => Ok(KvCacheType::F32), + "q8_0" => Ok(KvCacheType::Q8_0), + "q4_0" => Ok(KvCacheType::Q4_0), + "q4_1" => Ok(KvCacheType::Q4_1), + "q5_0" => Ok(KvCacheType::Q5_0), + "q5_1" => Ok(KvCacheType::Q5_1), + other => Err(error::NodeError::Config(format!( + "unknown kv-quant type '{other}' (supported: f16, f32, q8_0, q4_0, q4_1, q5_0, q5_1)" + ))), + } +} + /// Handle a completed inference task: send result or log if offline. fn handle_completed_task( completed: CompletedTask, diff --git a/src/setup.rs b/src/setup.rs index ab840a06..5bfc429f 100644 --- a/src/setup.rs +++ b/src/setup.rs @@ -321,7 +321,7 @@ pub async fn run_setup(data_dir: Option, gpu_layers: i32) -> Result<(), let engine = tokio::task::spawn_blocking({ let model_path = model_path.clone(); let mmproj_path = mmproj_path.clone(); - move || InferenceEngine::load(&model_path, gpu_layers, mmproj_path.as_deref(), None) + move || InferenceEngine::load(&model_path, gpu_layers, mmproj_path.as_deref(), None, None) }) .await .map_err(|e| NodeError::Inference(format!("task join error: {e}")))?; From 35a87fa1ec1faf9f980e7827c05f6b955e1733f7 Mon Sep 17 00:00:00 2001 From: andthattoo Date: Fri, 13 Mar 2026 16:18:23 +0300 Subject: [PATCH 49/57] Fix multi-model loading: init llama backend once via OnceLock LlamaBackend::init() is a global singleton that errors on second call. Use OnceLock to ensure it's initialized exactly once and shared across all InferenceEngine instances. Co-Authored-By: Claude Opus 4.6 --- src/inference/engine.rs | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/src/inference/engine.rs b/src/inference/engine.rs index 17920aaa..90cdc80b 100644 --- a/src/inference/engine.rs +++ b/src/inference/engine.rs @@ -1,9 +1,21 @@ use std::ops::ControlFlow; use std::path::Path; +use std::sync::OnceLock; use std::time::Instant; use llama_cpp_2::context::params::{KvCacheType, LlamaContextParams}; use llama_cpp_2::llama_backend::LlamaBackend; + +/// Global singleton — llama.cpp backend can only be initialized once per process. +static LLAMA_BACKEND: OnceLock = OnceLock::new(); + +fn get_backend() -> Result<&'static LlamaBackend, NodeError> { + // OnceLock guarantees the closure runs exactly once, so BackendAlreadyInitialized + // cannot happen here. If init() somehow fails, it's a fatal environment issue. + Ok(LLAMA_BACKEND.get_or_init(|| { + LlamaBackend::init().expect("failed to init llama backend") + })) +} use llama_cpp_2::llama_batch::LlamaBatch; use llama_cpp_2::model::params::LlamaModelParams; use llama_cpp_2::model::{AddBos, LlamaChatMessage, LlamaModel}; @@ -65,7 +77,7 @@ pub struct InferenceResult { /// NOTE: `LlamaContext` is not Send/Sync. All inference must happen /// via `tokio::task::spawn_blocking` with the engine moved into the closure. pub struct InferenceEngine { - backend: LlamaBackend, + backend: &'static LlamaBackend, model: LlamaModel, mtmd_ctx: Option, #[allow(dead_code)] @@ -97,8 +109,7 @@ impl InferenceEngine { kv_cache_type: Option, ) -> Result { let kv_cache_type = kv_cache_type.unwrap_or(KvCacheType::Q8_0); - let backend = LlamaBackend::init() - .map_err(|e| NodeError::Inference(format!("failed to init llama backend: {e}")))?; + let backend = get_backend()?; let model_params = if gpu_layers != 0 { let layers = if gpu_layers < 0 { 1000 } else { gpu_layers as u32 }; From 3da746c8156037812de6bac3d0173e7f10377521 Mon Sep 17 00:00:00 2001 From: andthattoo Date: Fri, 13 Mar 2026 16:24:50 +0300 Subject: [PATCH 50/57] Fix clippy: remove needless borrows on static backend ref Co-Authored-By: Claude Opus 4.6 --- src/inference/engine.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/inference/engine.rs b/src/inference/engine.rs index 90cdc80b..284987f6 100644 --- a/src/inference/engine.rs +++ b/src/inference/engine.rs @@ -118,7 +118,7 @@ impl InferenceEngine { LlamaModelParams::default() }; - let model = LlamaModel::load_from_file(&backend, path, &model_params) + let model = LlamaModel::load_from_file(backend, path, &model_params) .map_err(|e| NodeError::Inference(format!("failed to load model: {e}")))?; let n_ctx_train = model.n_ctx_train(); @@ -266,7 +266,7 @@ impl InferenceEngine { let mut ctx = self .model - .new_context(&self.backend, ctx_params) + .new_context(self.backend, ctx_params) .map_err(|e| NodeError::Inference(format!("failed to create context: {e}")))?; // Evaluate prompt @@ -445,7 +445,7 @@ impl InferenceEngine { let mut ctx = self .model - .new_context(&self.backend, ctx_params) + .new_context(self.backend, ctx_params) .map_err(|e| NodeError::Inference(format!("failed to create context: {e}")))?; // Evaluate all chunks (text + media embeddings) @@ -607,7 +607,7 @@ impl InferenceEngine { let mut ctx = self .model - .new_context(&self.backend, ctx_params) + .new_context(self.backend, ctx_params) .map_err(|e| NodeError::Inference(format!("failed to create context: {e}")))?; // Build batch with all tokens. Set output=true only at positions where we need logits. From ff5eef8cce4003e912789217da2a07b502035624 Mon Sep 17 00:00:00 2001 From: andthattoo Date: Fri, 13 Mar 2026 18:04:52 +0300 Subject: [PATCH 51/57] Fix Windows RAM detection and auto-recover corrupt cached models - Use PowerShell Get-CimInstance instead of deprecated wmic for RAM detection - Verify SHA-256 of cached model files before loading - Auto-delete and re-download corrupt files instead of failing permanently --- src/main.rs | 49 +++++++++++++++++++++++++++++++++---------------- src/setup.rs | 39 ++++++++++++++++++++++++++++++++++++--- 2 files changed, 69 insertions(+), 19 deletions(-) diff --git a/src/main.rs b/src/main.rs index 85a2038e..fc272627 100644 --- a/src/main.rs +++ b/src/main.rs @@ -533,26 +533,23 @@ async fn download_and_load_model( ) -> Result<(inference::InferenceEngine, f64), error::NodeError> { let model_name = spec.name.clone(); - // Check local cache first + // Check local cache first, verify integrity if SHA is available let model_path = if let Some(path) = cache.get_local_path(spec) { - tracing::info!(model = %model_name, path = %path.display(), "model found in cache"); - path - } else { - // Download from HuggingFace - let hf_path = ModelDownloader::download(spec).await?; - - // Verify SHA-256 if specified if let Some(ref expected_sha) = spec.sha256 { - tracing::info!(model = %model_name, "verifying SHA-256"); - if !ModelCache::verify_sha256(&hf_path, expected_sha)? { - return Err(error::NodeError::Model(format!( - "SHA-256 mismatch for model {model_name}" - ))); + if ModelCache::verify_sha256(&path, expected_sha).unwrap_or(false) { + tracing::info!(model = %model_name, path = %path.display(), "model found in cache (verified)"); + path + } else { + tracing::warn!(model = %model_name, "cached model failed integrity check, re-downloading"); + std::fs::remove_file(&path).ok(); + download_and_link(spec, cache, &model_name).await? } + } else { + tracing::info!(model = %model_name, path = %path.display(), "model found in cache"); + path } - - // Link into our cache - cache.link_model(spec, &hf_path)? + } else { + download_and_link(spec, cache, &model_name).await? }; // Download mmproj if specified (for vision/audio models) @@ -581,6 +578,26 @@ async fn download_and_load_model( Ok((engine, tps)) } +/// Download a model from HuggingFace, verify SHA-256, and link into cache. +async fn download_and_link( + spec: &ModelSpec, + cache: &ModelCache, + model_name: &str, +) -> Result { + let hf_path = ModelDownloader::download(spec).await?; + + if let Some(ref expected_sha) = spec.sha256 { + tracing::info!(model = %model_name, "verifying SHA-256"); + if !ModelCache::verify_sha256(&hf_path, expected_sha)? { + return Err(error::NodeError::Model(format!( + "SHA-256 mismatch for model {model_name}" + ))); + } + } + + cache.link_model(spec, &hf_path) +} + /// Parse a KV cache quantization string (e.g. "q8_0", "q4_0", "f16") into a `KvCacheType`. fn parse_kv_quant(s: &str) -> Result { match s.to_lowercase().as_str() { diff --git a/src/setup.rs b/src/setup.rs index 5bfc429f..d51cb574 100644 --- a/src/setup.rs +++ b/src/setup.rs @@ -79,6 +79,16 @@ fn detect_ram_bytes() -> Option { #[cfg(target_os = "windows")] { + // Try PowerShell first (works on all modern Windows) + let output = std::process::Command::new("powershell") + .args(["-NoProfile", "-Command", "(Get-CimInstance Win32_ComputerSystem).TotalPhysicalMemory"]) + .output() + .ok()?; + let s = String::from_utf8_lossy(&output.stdout); + if let Ok(bytes) = s.trim().parse::() { + return Some(bytes); + } + // Fallback to wmic for older Windows let output = std::process::Command::new("wmic") .args(["OS", "get", "TotalVisibleMemorySize"]) .output() @@ -260,11 +270,34 @@ pub async fn run_setup(data_dir: Option, gpu_layers: i32) -> Result<(), } }; - // Download model + // Download model (verify cached files too) println!(" Downloading {}...", model_name); let model_path = if let Some(path) = cache.get_local_path(&spec) { - println!(" (already cached)"); - Ok(path) + // Verify cached file integrity + let valid = match &spec.sha256 { + Some(sha) => ModelCache::verify_sha256(&path, sha).unwrap_or(false), + None => true, + }; + if valid { + println!(" (already cached)"); + Ok(path) + } else { + println!(" Cached file is corrupt, re-downloading..."); + std::fs::remove_file(&path).ok(); + match ModelDownloader::download(&spec).await { + Ok(hf_path) => { + if let Some(ref expected_sha) = spec.sha256 { + if !ModelCache::verify_sha256(&hf_path, expected_sha)? { + println!(" SHA-256 mismatch! Try a different model."); + println!(); + continue; + } + } + cache.link_model(&spec, &hf_path) + } + Err(e) => Err(e), + } + } } else { match ModelDownloader::download(&spec).await { Ok(hf_path) => { From 93490839ac0e415051710f8df35d98c86fdf0922 Mon Sep 17 00:00:00 2001 From: andthattoo Date: Fri, 13 Mar 2026 18:25:21 +0300 Subject: [PATCH 52/57] Fix quant substitution for dot-separated GGUF filenames MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit with_quant() used rfind('-') which broke for repos using dot separators (e.g. LocoOperator-4B.Q4_K_M.gguf → LocoOperator-Q8_0.gguf instead of LocoOperator-4B.Q8_0.gguf). Now finds the quant portion by matching known quant prefixes, preserving the original separator. --- src/models/registry.rs | 46 +++++++++++++++++++++++++++++++++--------- 1 file changed, 37 insertions(+), 9 deletions(-) diff --git a/src/models/registry.rs b/src/models/registry.rs index 6459e1cb..b2c84a44 100644 --- a/src/models/registry.rs +++ b/src/models/registry.rs @@ -138,15 +138,23 @@ impl ModelSpec { /// Return a new ModelSpec with the quantization portion of `hf_file` replaced. /// - /// GGUF filenames follow the pattern `{ModelName}-{Quant}.gguf` - /// (e.g. `Qwen3.5-9B-Q4_K_M.gguf`). This replaces the last `-{Quant}.gguf` - /// segment with the given quantization string. + /// GGUF filenames use either `-` or `.` before the quant string + /// (e.g. `Qwen3.5-9B-Q4_K_M.gguf` or `LocoOperator-4B.Q4_K_M.gguf`). + /// This finds the quant portion by looking for common quant prefixes + /// and replaces it while preserving the original separator. pub fn with_quant(&self, quant: &str) -> Self { - let new_file = if let Some(pos) = self.hf_file.rfind('-') { - format!("{}-{}.gguf", &self.hf_file[..pos], quant) - } else { - self.hf_file.clone() - }; + let stem = self.hf_file.strip_suffix(".gguf").unwrap_or(&self.hf_file); + // Find where the quant string starts by looking for known quant prefixes + let quant_prefixes = ["Q4_K_M", "Q4_K_S", "Q4_0", "Q4_1", "Q5_K_M", "Q5_K_S", "Q5_0", "Q5_1", "Q6_K", "Q8_0", "Q2_K", "Q3_K"]; + let new_file = quant_prefixes + .iter() + .filter_map(|prefix| stem.rfind(prefix).map(|pos| (pos, prefix))) + .max_by_key(|(pos, _)| *pos) + .map(|(pos, _)| { + // Preserve the separator character before the quant (- or .) + format!("{}{}.gguf", &self.hf_file[..pos], quant) + }) + .unwrap_or_else(|| self.hf_file.clone()); ModelSpec { hf_file: new_file, sha256: None, // hash no longer valid for a different quant @@ -262,7 +270,7 @@ mod tests { } #[test] - fn test_with_quant_substitutes_suffix() { + fn test_with_quant_substitutes_dash_separator() { let reg = default_registry(); let spec = ®["qwen3.5:9b"]; assert_eq!(spec.hf_file, "Qwen3.5-9B-Q4_K_M.gguf"); @@ -276,6 +284,26 @@ mod tests { assert_eq!(q8.hf_mmproj_file, spec.hf_mmproj_file); } + #[test] + fn test_with_quant_substitutes_dot_separator() { + let reg = default_registry(); + let spec = ®["locooperator:4b"]; + assert_eq!(spec.hf_file, "LocoOperator-4B.Q4_K_M.gguf"); + + let q8 = spec.with_quant("Q8_0"); + assert_eq!(q8.hf_file, "LocoOperator-4B.Q8_0.gguf"); + } + + #[test] + fn test_with_quant_nanbeige_dot_separator() { + let reg = default_registry(); + let spec = ®["nanbeige:3b"]; + assert_eq!(spec.hf_file, "Nanbeige.Nanbeige4.1-3B.Q4_K_M.gguf"); + + let q8 = spec.with_quant("Q8_0"); + assert_eq!(q8.hf_file, "Nanbeige.Nanbeige4.1-3B.Q8_0.gguf"); + } + #[test] fn test_with_quant_preserves_mmproj() { let reg = default_registry(); From 07ba8aae47b404f5c270d005ea2f48d6bc15eaa3 Mon Sep 17 00:00:00 2001 From: andthattoo Date: Fri, 13 Mar 2026 21:06:27 +0300 Subject: [PATCH 53/57] Fix mmproj cache collision and add LLVM to Windows guide - Prefix cached mmproj filenames with model name to avoid collisions between models sharing the same mmproj filename (e.g. multiple Qwen models all using mmproj-BF16.gguf from different repos) - Add LLVM/Clang to Windows prerequisites (needed by bindgen) - Add libclang troubleshooting entry --- TESTER_GUIDE.md | 235 ++++++++++++++++++++++++++++++++++++++++++++ src/models/cache.rs | 63 +++++++++++- 2 files changed, 293 insertions(+), 5 deletions(-) create mode 100644 TESTER_GUIDE.md diff --git a/TESTER_GUIDE.md b/TESTER_GUIDE.md new file mode 100644 index 00000000..f870798d --- /dev/null +++ b/TESTER_GUIDE.md @@ -0,0 +1,235 @@ +# Dria Node v2 — Tester Guide + +Thanks for testing! This guide walks you through building and running a Dria compute node from source. + +## 1. Install Prerequisites + +You need **Rust** and **cmake**. Pick your OS: + +### macOS + +```bash +# Install Rust +curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh +source ~/.cargo/env + +# Install cmake +brew install cmake +``` + +### Linux (Ubuntu/Debian) + +```bash +# Install Rust +curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh +source ~/.cargo/env + +# Install build tools +sudo apt-get update && sudo apt-get install -y cmake build-essential +``` + +### Windows + +Open **PowerShell as Administrator** (right-click Start → "Terminal (Admin)") and run these commands one by one: + +```powershell +# Install Rust +winget install Rustlang.Rustup + +# Install C++ build tools (needed to compile the inference engine) +winget install Microsoft.VisualStudio.2022.BuildTools --force --override "--passive --wait --add Microsoft.VisualStudio.Workload.VCTools;includeRecommended" + +# Install CMake +winget install -e --id Kitware.CMake + +# Install LLVM/Clang (needed by bindgen for llama.cpp bindings) +winget install -e --id LLVM.LLVM +``` + +**Important:** After all three finish, **close PowerShell and open a new one** so the tools are available. To verify everything installed correctly: + +```powershell +rustc --version +cmake --version +``` + +Both should print a version number. If either says "not recognized", restart your PC and try again. + +## 2. Build + +```bash +git clone https://github.com/firstbatchxyz/dkn-compute-node.git +cd dkn-compute-node +git checkout v2 +cargo build --release +``` + +This takes a few minutes (it compiles the inference engine from source). + +**Apple Silicon (M1/M2/M3/M4)?** Build with Metal GPU support instead: + +```bash +cargo build --release --features metal +``` + +**NVIDIA GPU?** Install the [CUDA Toolkit](https://developer.nvidia.com/cuda-downloads) first, then: + +```bash +cargo build --release --features cuda +``` + +## 3. Run Setup + +The setup wizard helps you pick and download a model: + +```bash +./target/release/dria-node setup +``` + +**Windows (PowerShell):** Use backslashes and `.exe`: + +```powershell +.\target\release\dria-node.exe setup +``` + +It will: +- Detect your RAM and show models that fit +- Let you pick a model and quantization +- Download it (once, cached for future runs) +- Run a test inference to confirm everything works + +If you're unsure which model to pick, start with **lfm2.5:1.2b** — it's the smallest (~0.8 GB) and works on any machine. + +## 4. Your Wallet Key + +You'll need your Ethereum wallet private key. The node uses it to sign messages and prove identity on the network. + +This is the 64-character hex string (with or without `0x` prefix). You can export it from MetaMask: Account Details → Show Private Key. + +## 5. Start the Node + +```bash +./target/release/dria-node start \ + --wallet YOUR_KEY_HERE \ + --model lfm2.5:1.2b +``` + +**Windows (PowerShell):** + +```powershell +.\target\release\dria-node.exe start --wallet YOUR_KEY_HERE --model lfm2.5:1.2b +``` + +Replace `YOUR_KEY_HERE` with the key from step 4, and `lfm2.5:1.2b` with whatever model you chose in setup. + +**If you have a GPU** and built with `--features metal` or `--features cuda`: + +```bash +./target/release/dria-node start \ + --wallet YOUR_KEY_HERE \ + --model lfm2.5:1.2b \ + --gpu-layers -1 +``` + +### What to expect + +``` +INFO node identity address=0x... +INFO benchmark complete tps=25.3 model=lfm2.5:1.2b +INFO connected to router node_id=... router=quic.dria.co:4001 +INFO node ready models=["lfm2.5:1.2b"] online=true +``` + +That's it — the node is running and accepting tasks. Leave it open. Press **Ctrl+C** to stop. + +## 6. Skip the Flags Next Time + +Instead of typing flags every time, set environment variables: + +```bash +# Add these to your shell profile (~/.bashrc, ~/.zshrc, etc.) +export DRIA_WALLET=your_key_here +export DRIA_MODELS=lfm2.5:1.2b +export DRIA_GPU_LAYERS=-1 +``` + +Then just run: + +```bash +./target/release/dria-node start +``` + +## Models + +| Model | Type | Download | Min RAM | +|---|---|---|---| +| qwen3.5:0.8b | Text, Vision | ~0.5 GB | ~1 GB | +| lfm2.5:1.2b | Text | ~0.8 GB | ~1 GB | +| lfm2.5-audio:1.5b | Text, Audio | ~1.0 GB | ~1.5 GB | +| lfm2.5-vl:1.6b | Text, Vision | ~1.2 GB | ~1.5 GB | +| qwen3.5:2b | Text, Vision | ~1.2 GB | ~2 GB | +| nanbeige:3b | Text | ~2.0 GB | ~2.5 GB | +| locooperator:4b | Text | ~2.5 GB | ~3 GB | +| qwen3.5:9b | Text, Vision | ~6.0 GB | ~7 GB | +| lfm2:24b-a2b | Text | ~14 GB | ~16 GB | +| qwen3.5:27b | Text, Vision | ~16 GB | ~18 GB | +| qwen3.5:35b-a3b | Text, Vision | ~20 GB | ~22 GB | +| nemotron:30b-a3b | Text | ~24.5 GB | ~27 GB | + +Pick one model that fits your RAM. Smaller models are faster to download and easier to test with. + +## All Options + +| Flag | Env Var | Default | What it does | +|---|---|---|---| +| `--wallet` | `DRIA_WALLET` | (required) | Your node identity key | +| `--model` | `DRIA_MODELS` | (required) | Model(s) to serve | +| `--router-url` | `DRIA_ROUTER_URL` | `quic.dria.co:4001` | Router to connect to | +| `--gpu-layers` | `DRIA_GPU_LAYERS` | `0` (CPU) | GPU layers (-1 = all) | +| `--max-concurrent` | `DRIA_MAX_CONCURRENT` | `1` | Parallel inference tasks | +| `--data-dir` | `DRIA_DATA_DIR` | `~/.dria` | Where models are cached | +| `--quant` | `DRIA_QUANT` | Q4_K_M | Override quantization | +| `--insecure` | `DRIA_INSECURE` | `false` | Skip TLS verification | +| `--skip-update` | `DRIA_SKIP_UPDATE` | `false` | Skip auto-update check | + +## Troubleshooting + +**Windows: "dria-node is not recognized"** +On Windows you must use `.\target\release\dria-node.exe` (backslashes, `.exe` extension). PowerShell does not find executables without the `.exe` suffix. + +**"cmake not found" or build errors about C compiler** +Make sure cmake is installed (step 1). On macOS: `brew install cmake`. On Linux: `sudo apt install cmake build-essential`. On Windows: `winget install -e --id Kitware.CMake` then reopen PowerShell. + +**Windows: "dria-node.exe not found in target\release"** +The build probably failed. Scroll up in your terminal and look for red error messages. The most common cause is missing C++ build tools — run `winget install Microsoft.VisualStudio.2022.BuildTools --force --override "--passive --wait --add Microsoft.VisualStudio.Workload.VCTools;includeRecommended"`, reopen PowerShell, and rebuild with `cargo build --release`. + +**Windows: "Unable to find libclang" or "couldn't find clang.dll"** +Install LLVM: `winget install -e --id LLVM.LLVM`, reopen PowerShell, and rebuild. If it still can't find it, set the path manually: `$env:LIBCLANG_PATH = "C:\Program Files\LLVM\bin"` then rebuild. + +**Build fails** +Try a clean build: `cargo clean && cargo build --release`. Make sure you're on the `v2` branch: `git checkout v2`. + +**"unknown model"** +Model names are exact. Use the names from the table above (e.g. `lfm2.5:1.2b`, not `lfm-2.5`). + +**"all routers unavailable" or "offline mode"** +The node can't reach the router. Check your internet connection. If you're behind a strict firewall, **UDP port 4001 outbound** must be allowed. + +**Slow inference** +If you have a GPU, make sure you built with `--features metal` (Mac) or `--features cuda` (NVIDIA) and are passing `--gpu-layers -1`. + +**Model download stalls or fails** +Models come from HuggingFace. Try again — it might be a temporary network issue. You can also set `HF_ENDPOINT` if HuggingFace is blocked in your region. + +**Want more detail in the logs?** + +```bash +RUST_LOG=debug ./target/release/dria-node start ... +``` + +## Reporting Issues + +If something goes wrong, please share: +1. Your OS and hardware (CPU, RAM, GPU) +2. The command you ran +3. The full error output diff --git a/src/models/cache.rs b/src/models/cache.rs index 9ae22e45..3e6e7cd8 100644 --- a/src/models/cache.rs +++ b/src/models/cache.rs @@ -29,11 +29,20 @@ impl ModelCache { } /// Check if a model's mmproj GGUF is already present in our cache. + /// Uses model-prefixed filename to avoid collisions between models + /// that share the same mmproj filename (e.g. multiple Qwen models + /// all using "mmproj-BF16.gguf" from different repos). pub fn get_mmproj_path(&self, spec: &ModelSpec) -> Option { let file = spec.hf_mmproj_file.as_ref()?; - let path = self.cache_dir.join(file); + let prefixed = format!("{}_{}", spec.name.replace(':', "-"), file); + let path = self.cache_dir.join(&prefixed); if path.exists() { - Some(path) + return Some(path); + } + // Check legacy unprefixed path for backward compat (single-model setups) + let legacy = self.cache_dir.join(file); + if legacy.exists() { + Some(legacy) } else { None } @@ -68,12 +77,14 @@ impl ModelCache { } /// Create a symlink from our cache dir to the hf-hub cached mmproj file. + /// Uses model-prefixed filename to avoid collisions. pub fn link_mmproj(&self, spec: &ModelSpec, source: &Path) -> Result { let file = spec .hf_mmproj_file .as_ref() .ok_or_else(|| NodeError::Model("no mmproj file specified".into()))?; - let dest = self.cache_dir.join(file); + let prefixed = format!("{}_{}", spec.name.replace(':', "-"), file); + let dest = self.cache_dir.join(&prefixed); if dest.exists() { return Ok(dest); } @@ -161,10 +172,52 @@ mod tests { // Not present initially assert!(cache.get_mmproj_path(&spec_with_mmproj).is_none()); - // Create the mmproj file - std::fs::write(dir.join("mmproj.gguf"), b"fake").unwrap(); + // Create the prefixed mmproj file + std::fs::write(dir.join("vl-1b_mmproj.gguf"), b"fake").unwrap(); assert!(cache.get_mmproj_path(&spec_with_mmproj).is_some()); std::fs::remove_dir_all(&dir).ok(); } + + #[test] + fn test_mmproj_no_collision() { + let dir = std::env::temp_dir().join("dria-cache-test-mmproj-collision"); + let cache = ModelCache::new(dir.clone()).unwrap(); + + let spec_a = ModelSpec { + name: "qwen3.5:0.8b".into(), + hf_repo: "unsloth/Qwen3.5-0.8B-GGUF".into(), + hf_file: "model-a.gguf".into(), + sha256: None, + model_type: dkn_protocol::ModelType::Vision, + hf_mmproj_file: Some("mmproj-BF16.gguf".into()), + }; + + let spec_b = ModelSpec { + name: "qwen3.5:27b".into(), + hf_repo: "unsloth/Qwen3.5-27B-GGUF".into(), + hf_file: "model-b.gguf".into(), + sha256: None, + model_type: dkn_protocol::ModelType::Vision, + hf_mmproj_file: Some("mmproj-BF16.gguf".into()), + }; + + // Create separate source files + std::fs::write(dir.join("mmproj_a.gguf"), b"small_model").unwrap(); + std::fs::write(dir.join("mmproj_b.gguf"), b"large_model").unwrap(); + + let path_a = cache.link_mmproj(&spec_a, &dir.join("mmproj_a.gguf")).unwrap(); + let path_b = cache.link_mmproj(&spec_b, &dir.join("mmproj_b.gguf")).unwrap(); + + // Paths must be different + assert_ne!(path_a, path_b); + assert!(path_a.exists()); + assert!(path_b.exists()); + + // Content must be independent + assert_eq!(std::fs::read(&path_a).unwrap(), b"small_model"); + assert_eq!(std::fs::read(&path_b).unwrap(), b"large_model"); + + std::fs::remove_dir_all(&dir).ok(); + } } From 191c79ed5f553137201c0886cefc48abe76ddb01 Mon Sep 17 00:00:00 2001 From: andthattoo Date: Sat, 14 Mar 2026 22:37:27 +0300 Subject: [PATCH 54/57] Remove legacy mmproj fallback to prevent cross-model collision --- src/models/cache.rs | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/models/cache.rs b/src/models/cache.rs index 3e6e7cd8..37dca696 100644 --- a/src/models/cache.rs +++ b/src/models/cache.rs @@ -37,12 +37,7 @@ impl ModelCache { let prefixed = format!("{}_{}", spec.name.replace(':', "-"), file); let path = self.cache_dir.join(&prefixed); if path.exists() { - return Some(path); - } - // Check legacy unprefixed path for backward compat (single-model setups) - let legacy = self.cache_dir.join(file); - if legacy.exists() { - Some(legacy) + Some(path) } else { None } From e10f4ba846abff8977fb7cf6790203452c341c09 Mon Sep 17 00:00:00 2001 From: andthattoo Date: Sun, 15 Mar 2026 18:43:57 +0300 Subject: [PATCH 55/57] Bump CI CUDA toolkit to 12.8.0 for Blackwell GPU support --- .github/workflows/releases.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/releases.yml b/.github/workflows/releases.yml index 77dad65b..82d6c711 100644 --- a/.github/workflows/releases.yml +++ b/.github/workflows/releases.yml @@ -134,14 +134,14 @@ jobs: osname: linux, arch: amd64-cuda, target: x86_64-unknown-linux-gnu, - cuda: "12.6.3", + cuda: "12.8.0", } - { runner: windows-latest, osname: windows, arch: amd64-cuda, target: x86_64-pc-windows-msvc, - cuda: "12.6.3", + cuda: "12.8.0", extension: ".exe", } From 12aef36078e379870308d9406295ee1e69c441f8 Mon Sep 17 00:00:00 2001 From: andthattoo Date: Mon, 16 Mar 2026 13:35:00 +0300 Subject: [PATCH 56/57] Fix crash when prompt exceeds n_batch (2048 tokens) Chunk prompt evaluation into batches of 2048 tokens instead of decoding all at once, which triggers GGML_ASSERT(n_tokens_all <= n_batch). Applied to both generate() and validate_prefill(). Co-Authored-By: Claude Opus 4.6 --- src/inference/engine.rs | 86 ++++++++++++++++++++++++++--------------- 1 file changed, 55 insertions(+), 31 deletions(-) diff --git a/src/inference/engine.rs b/src/inference/engine.rs index 284987f6..b1888e13 100644 --- a/src/inference/engine.rs +++ b/src/inference/engine.rs @@ -269,17 +269,24 @@ impl InferenceEngine { .new_context(self.backend, ctx_params) .map_err(|e| NodeError::Inference(format!("failed to create context: {e}")))?; - // Evaluate prompt + // Evaluate prompt in chunks (n_batch = 2048 default in llama.cpp) let prompt_start = Instant::now(); - let mut batch = LlamaBatch::new(tokens.len().max(1), 1); - for (i, &token) in tokens.iter().enumerate() { - let is_last = i == tokens.len() - 1; - batch - .add(token, i as i32, &[0], is_last) - .map_err(|e| NodeError::Inference(format!("batch add failed: {e}")))?; + let n_batch = 2048usize; + let mut batch = LlamaBatch::new(n_batch.min(tokens.len()).max(1), 1); + let mut prompt_pos = 0; + while prompt_pos < tokens.len() { + batch.clear(); + let chunk_end = (prompt_pos + n_batch).min(tokens.len()); + for i in prompt_pos..chunk_end { + let is_last = i == tokens.len() - 1; + batch + .add(tokens[i], i as i32, &[0], is_last) + .map_err(|e| NodeError::Inference(format!("batch add failed: {e}")))?; + } + ctx.decode(&mut batch) + .map_err(|e| NodeError::Inference(format!("prompt decode failed: {e}")))?; + prompt_pos = chunk_end; } - ctx.decode(&mut batch) - .map_err(|e| NodeError::Inference(format!("prompt decode failed: {e}")))?; let prompt_eval_time_ms = prompt_start.elapsed().as_millis() as u64; // Build sampler chain (grammar first to mask invalid tokens, then sampling) @@ -307,8 +314,8 @@ impl InferenceEngine { let mut current_pos = tokens.len() as i32; let mut decoder = encoding_rs::UTF_8.new_decoder(); // Batch index where logits are available: - // after prompt eval → last prompt token; after each single-token decode → 0 - let mut logit_batch_idx: i32 = (tokens.len() - 1) as i32; + // after chunked prompt eval → last token's position in last chunk; after single-token decode → 0 + let mut logit_batch_idx: i32 = ((tokens.len() - 1) % n_batch) as i32; for _ in 0..params.max_tokens { // sample() internally calls apply + select + accept @@ -619,29 +626,46 @@ impl InferenceEngine { output_positions.push(seq_pos); } - let mut batch = LlamaBatch::new(all_tokens.len().max(1), 1); - for (i, &token) in all_tokens.iter().enumerate() { - let is_output = output_positions.contains(&i); - batch - .add(token, i as i32, &[0], is_output) - .map_err(|e| NodeError::Inference(format!("batch add failed: {e}")))?; - } + // Evaluate in chunks and extract logprobs per-chunk (next decode overwrites logits) + let n_batch = 2048usize; + let mut batch = LlamaBatch::new(n_batch.min(all_tokens.len()).max(1), 1); + let mut logprobs: Vec = Vec::new(); - // Single forward pass - ctx.decode(&mut batch) - .map_err(|e| NodeError::Inference(format!("prefill decode failed: {e}")))?; + let mut pos = 0; + while pos < all_tokens.len() { + batch.clear(); + let chunk_end = (pos + n_batch).min(all_tokens.len()); + + // Track which probe positions fall in this chunk + let mut chunk_probes: Vec<(usize, usize)> = Vec::new(); // (probe_idx, batch_position) + + for (batch_pos, (i, &token)) in all_tokens.iter().enumerate().skip(pos).take(chunk_end - pos).enumerate() { + let is_output = output_positions.contains(&i); + batch + .add(token, i as i32, &[0], is_output) + .map_err(|e| NodeError::Inference(format!("batch add failed: {e}")))?; + if is_output { + if let Some(probe_idx) = output_positions.iter().position(|&p| p == i) { + chunk_probes.push((probe_idx, batch_pos)); + } + } + } - // Extract logprobs at each probe position. - // get_logits_ith takes the batch index where output=true was set. - let mut logprobs: Vec = Vec::new(); - for (probe_idx, &gen_index) in probe_gen_indices.iter().enumerate() { - let target_token = all_tokens[n_prompt + gen_index]; - let batch_idx = output_positions[probe_idx] as i32; - if let Some(lp) = - self.extract_logprob(&ctx, batch_idx, gen_index, target_token, logprob_top_k) - { - logprobs.push(lp); + ctx.decode(&mut batch) + .map_err(|e| NodeError::Inference(format!("prefill decode failed: {e}")))?; + + // Extract logprobs for this chunk's probes before next decode + for &(probe_idx, batch_pos) in &chunk_probes { + let gen_index = probe_gen_indices[probe_idx]; + let target_token = all_tokens[n_prompt + gen_index]; + if let Some(lp) = + self.extract_logprob(&ctx, batch_pos as i32, gen_index, target_token, logprob_top_k) + { + logprobs.push(lp); + } } + + pos = chunk_end; } Ok(InferenceProof { From e440658748d2923977843d3f37dac2a16f94563c Mon Sep 17 00:00:00 2001 From: andthattoo Date: Tue, 17 Mar 2026 18:46:40 +0300 Subject: [PATCH 57/57] Bump version to 0.7.3 and target master branch in CI --- .github/workflows/tests.yml | 4 ++-- Cargo.lock | 2 +- Cargo.toml | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 4b5caca1..36f125df 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -3,7 +3,7 @@ name: tests on: push: branches: - - v2 + - master paths: - "src/**" - "Cargo.toml" @@ -11,7 +11,7 @@ on: - ".github/workflows/tests.yml" pull_request: branches: - - v2 + - master workflow_dispatch: jobs: diff --git a/Cargo.lock b/Cargo.lock index 49f36a03..6cd86fed 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -425,7 +425,7 @@ dependencies = [ [[package]] name = "dria-node" -version = "0.7.2" +version = "0.7.3" dependencies = [ "anyhow", "bytes", diff --git a/Cargo.toml b/Cargo.toml index 9783f8b4..1d875310 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "dria-node" -version = "0.7.2" +version = "0.7.3" edition = "2021" license = "Apache-2.0"