- Dria Compute Node serves the computation results within Dria Knowledge Network.
+ Run AI inference on the Dria network. Earn rewards by serving models from your machine.
@@ -29,111 +29,139 @@
-> Use the [Dria Compute Launcher](https://github.com/firstbatchxyz/dkn-compute-launcher/) to run a compute node with many more features!
+## Quick Start
-## Releases
+### Install
-For _production_ images:
+**Homebrew (macOS / Linux):**
-- **Versioned**: With each release, a versioned image is deployed on Docker hub with the version tag `:vX.X.X`.
-- **Latest**: The latest production image is always under the `:latest` tag.
+```sh
+brew tap firstbatchxyz/dkn
+brew install dria-node
+```
-For _development_ images:
+**Shell script (macOS / Linux):**
-- **Master**: On each push to `master` branch, a new image is created with the tag `master--`.
-- **Unstable**: The latest development image is always under the `:unstable` tag.
+```sh
+curl -fsSL https://raw.githubusercontent.com/firstbatchxyz/dkn-compute-node/master/install.sh | sh
+```
-You can see the list of deployed images on [Docker Hub](https://hub.docker.com/orgs/firstbatch/members).
+**PowerShell (Windows):**
-## Development
+```powershell
+irm https://raw.githubusercontent.com/firstbatchxyz/dkn-compute-node/master/install.ps1 | iex
+```
-> If you have a feature that you would like to add with respect to its respective issue, or a bug fix, feel free to fork & create a PR!
+**From GitHub Releases:**
-If you would like to run the node from source (which is really handy during development), you can use our shorthand scripts within the Makefile. You can see the available commands with:
+Download the latest binary for your platform from [Releases](https://github.com/firstbatchxyz/dkn-compute-node/releases) and place it in your `PATH`.
-```sh
-make help
-```
+### Setup
-You can run the binary as is:
+Run the interactive setup:
```sh
-cargo run
-
-# specify custom .env file
-DKN_COMPUTE_ENV=./path/to/.env cargo run
+dria-node setup
```
-If you have a valid `.env` file, you can run the latest Docker image via compose as well:
+This will:
+
+1. Detect your system RAM and list models that fit
+2. Let you pick a model from the available options
+3. Download the GGUF model file from HuggingFace
+4. Run a test inference to verify everything works
+5. Print a benchmark (tokens per second)
+
+Use `--gpu-layers -1` to offload all layers to the GPU. On macOS this uses Metal; on NVIDIA hardware it requires a build with `--features cuda`:
```sh
-docker compose up
-
-# Ollama without any GPUs
-docker compose --profile=ollama-cpu up
-# Ollama for NVIDIA gpus
-docker compose --profile=ollama-cuda up
-# Ollama for AMD gpus
-docker compose --profile=ollama-rocm up
+dria-node setup --gpu-layers -1
```
-> [!TIP]
->
-> You can specify a custom initial RPC address with `DKN_INITIAL_RPC_ADDR`.
+### Start
-### Testing
+Once setup is complete, start the node:
-You can the tests as follows:
+```sh
+dria-node start --wallet --model
+```
+
+The node will connect to the Dria network, register your models, and start serving inference requests. You can increase throughput with `--max-concurrent`:
```sh
-cargo test --workspace
+dria-node start --wallet --model lfm2.5:1.2b --max-concurrent 4
```
-We also have some benchmarking and profiling scripts, see [node performance](./docs/NODE_PERFORMANCE.md) for more details.
+## Available Models
-### Documentation
+| Model | Type | Quant | ~Size |
+|-------|------|-------|-------|
+| `lfm2.5:1.2b` | Text | Q4_K_M | 0.8 GB |
+| `lfm2.5-audio:1.5b` | Audio | Q4_0 | 1.0 GB |
+| `lfm2.5-vl:1.6b` | Vision | Q4_0 | 1.2 GB |
+| `nanbeige:3b` | Text | Q4_K_M | 2.0 GB |
+| `locooperator:4b` | Text | Q4_K_M | 2.5 GB |
+| `qwen3.5:9b` | Vision | Q4_K_M | 6.0 GB |
+| `lfm2:24b-a2b` | Text | Q4_K_M | 14 GB |
+| `qwen3.5:27b` | Vision | Q4_K_M | 16 GB |
+| `qwen3.5:35b-a3b` | Vision | Q4_K_M | 20 GB |
-You can view the entire crate-level documentation with:
+Serve multiple models by comma-separating them: `--model "qwen3.5:9b,lfm2.5:1.2b"`
+
+Override quantization with `--quant Q8_0` (applies to all models).
+
+## CLI Reference
-```sh
-cargo doc --open --no-deps --document-private-items
+```
+dria-node
+
+Commands:
+ setup Interactive setup: pick a model, download it, and run a test
+ start Start the compute node
+
+setup options:
+ --data-dir Data directory [env: DRIA_DATA_DIR]
+ --gpu-layers GPU layers to offload (0 = CPU only) [default: 0]
+
+start options:
+ --wallet Wallet secret key, hex-encoded [env: DRIA_WALLET]
+ --model Model(s) to serve, comma-separated [env: DRIA_MODELS]
+ --router-url Router URL [default: quic.dria.co:4001] [env: DRIA_ROUTER_URL]
+ --gpu-layers GPU layers to offload (-1 = all, 0 = CPU) [default: 0]
+ --max-concurrent Max concurrent inference requests [default: 1]
+ --data-dir Data directory [env: DRIA_DATA_DIR]
+ --quant Override GGUF quantization [env: DRIA_QUANT]
+ --insecure Skip TLS verification [env: DRIA_INSECURE]
```
-### Styling
+All flags can also be set via environment variables.
-Lint and format with:
+## Building from Source
```sh
-cargo clippy --workspace
-cargo fmt -v
+git clone https://github.com/firstbatchxyz/dkn-compute-node.git
+cd dkn-compute-node
+cargo build --release
```
-### Profiling
-
-We have scripts to profile both CPU and Memory usage. A special build is created for profiling, via a custom `profiling` feature, such that the output inherits `release` mode but also has debug symbols.
+**Feature flags:**
-Furthermore, the profiling build will exit automatically after a certain time, as if CTRL+C has been pressed. This is needed by the memory profiling tool in particular.
+- `--features metal` — Apple Metal GPU acceleration (macOS)
+- `--features cuda` — NVIDIA CUDA GPU acceleration
-**CPU Profiling**: To create a [flamegraph](https://crates.io/crates/flamegraph) of the application, the command below will create a profiling build that inherits `release` mode, except with debug information:
+### Testing
```sh
-DKN_EXIT_TIMEOUT=120 cargo flamegraph --root --profile=profiling --bin dkn-compute
+cargo test
```
-> [!NOTE]
->
-> CPU profiling may require super-user access.
-
-**Memory Profiling**: To profile memory usage, we make use of [cargo-instruments](https://crates.io/crates/cargo-instruments):
+### Linting
```sh
-DKN_EXIT_TIMEOUT=120 cargo instruments --profile=profiling -t Allocations --bin dkn-compute
+cargo clippy
+cargo fmt --check
```
-> [!TIP]
->
-> You can adjust the profiling duration via the `DKN_EXIT_TIMEOUT` variable, which takes a number of seconds until termination.
-
## License
This project is licensed under the [Apache License 2.0](https://opensource.org/license/Apache-2.0).
diff --git a/TESTER_GUIDE.md b/TESTER_GUIDE.md
new file mode 100644
index 00000000..f870798d
--- /dev/null
+++ b/TESTER_GUIDE.md
@@ -0,0 +1,235 @@
+# Dria Node v2 — Tester Guide
+
+Thanks for testing! This guide walks you through building and running a Dria compute node from source.
+
+## 1. Install Prerequisites
+
+You need **Rust** and **cmake**. Pick your OS:
+
+### macOS
+
+```bash
+# Install Rust
+curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
+source ~/.cargo/env
+
+# Install cmake
+brew install cmake
+```
+
+### Linux (Ubuntu/Debian)
+
+```bash
+# Install Rust
+curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
+source ~/.cargo/env
+
+# Install build tools
+sudo apt-get update && sudo apt-get install -y cmake build-essential
+```
+
+### Windows
+
+Open **PowerShell as Administrator** (right-click Start → "Terminal (Admin)") and run these commands one by one:
+
+```powershell
+# Install Rust
+winget install Rustlang.Rustup
+
+# Install C++ build tools (needed to compile the inference engine)
+winget install Microsoft.VisualStudio.2022.BuildTools --force --override "--passive --wait --add Microsoft.VisualStudio.Workload.VCTools;includeRecommended"
+
+# Install CMake
+winget install -e --id Kitware.CMake
+
+# Install LLVM/Clang (needed by bindgen for llama.cpp bindings)
+winget install -e --id LLVM.LLVM
+```
+
+**Important:** After all four finish, **close PowerShell and open a new one** so the tools are available. To verify everything installed correctly:
+
+```powershell
+rustc --version
+cmake --version
+```
+
+Both should print a version number. If either says "not recognized", restart your PC and try again.
+
+## 2. Build
+
+```bash
+git clone https://github.com/firstbatchxyz/dkn-compute-node.git
+cd dkn-compute-node
+git checkout v2
+cargo build --release
+```
+
+This takes a few minutes (it compiles the inference engine from source).
+
+**Apple Silicon (M1/M2/M3/M4)?** Build with Metal GPU support instead:
+
+```bash
+cargo build --release --features metal
+```
+
+**NVIDIA GPU?** Install the [CUDA Toolkit](https://developer.nvidia.com/cuda-downloads) first, then:
+
+```bash
+cargo build --release --features cuda
+```
+
+## 3. Run Setup
+
+The setup wizard helps you pick and download a model:
+
+```bash
+./target/release/dria-node setup
+```
+
+**Windows (PowerShell):** Use backslashes and `.exe`:
+
+```powershell
+.\target\release\dria-node.exe setup
+```
+
+It will:
+- Detect your RAM and show models that fit
+- Let you pick a model and quantization
+- Download it (once, cached for future runs)
+- Run a test inference to confirm everything works
+
+If you're unsure which model to pick, start with **lfm2.5:1.2b**: it is one of the smallest (~0.8 GB download) and needs only about 1 GB of RAM.
+
+## 4. Your Wallet Key
+
+You'll need your Ethereum wallet private key. The node uses it to sign messages and prove identity on the network.
+
+This is the 64-character hex string (with or without `0x` prefix). You can export it from MetaMask: Account Details → Show Private Key.
+
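Before passing the key to the node, a quick shape check can catch copy-paste mistakes. The sketch below is a hypothetical helper (the name `is_valid_wallet_key` is not part of `dria-node`). It only checks the format described above (64 hex characters, optional `0x` prefix), not whether the key is the one you intended to use.

```bash
# Hypothetical helper, not part of dria-node: checks the shape of a wallet key.
is_valid_wallet_key() {
  key="${1#0x}"                     # drop an optional 0x prefix
  [ "${#key}" -eq 64 ] || return 1  # must be exactly 64 characters
  case "$key" in
    *[!0-9a-fA-F]*) return 1 ;;     # reject any non-hex character
  esac
  return 0
}
```

For example: `is_valid_wallet_key "$YOUR_KEY" || echo "key looks malformed"`, where `YOUR_KEY` is whatever variable holds the key.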
+## 5. Start the Node
+
+```bash
+./target/release/dria-node start \
+ --wallet YOUR_KEY_HERE \
+ --model lfm2.5:1.2b
+```
+
+**Windows (PowerShell):**
+
+```powershell
+.\target\release\dria-node.exe start --wallet YOUR_KEY_HERE --model lfm2.5:1.2b
+```
+
+Replace `YOUR_KEY_HERE` with the key from step 4, and `lfm2.5:1.2b` with whatever model you chose in setup.
+
+**If you have a GPU** and built with `--features metal` or `--features cuda`:
+
+```bash
+./target/release/dria-node start \
+ --wallet YOUR_KEY_HERE \
+ --model lfm2.5:1.2b \
+ --gpu-layers -1
+```
+
+### What to expect
+
+```
+INFO node identity address=0x...
+INFO benchmark complete tps=25.3 model=lfm2.5:1.2b
+INFO connected to router node_id=... router=quic.dria.co:4001
+INFO node ready models=["lfm2.5:1.2b"] online=true
+```
+
+That's it — the node is running and accepting tasks. Leave it open. Press **Ctrl+C** to stop.
+
+## 6. Skip the Flags Next Time
+
+Instead of typing flags every time, set environment variables:
+
+```bash
+# Add these to your shell profile (~/.bashrc, ~/.zshrc, etc.)
+export DRIA_WALLET=your_key_here
+export DRIA_MODELS=lfm2.5:1.2b
+export DRIA_GPU_LAYERS=-1
+```
+
+Then just run:
+
+```bash
+./target/release/dria-node start
+```
+
+## Models
+
+| Model | Type | Download | Min RAM |
+|---|---|---|---|
+| qwen3.5:0.8b | Text, Vision | ~0.5 GB | ~1 GB |
+| lfm2.5:1.2b | Text | ~0.8 GB | ~1 GB |
+| lfm2.5-audio:1.5b | Text, Audio | ~1.0 GB | ~1.5 GB |
+| lfm2.5-vl:1.6b | Text, Vision | ~1.2 GB | ~1.5 GB |
+| qwen3.5:2b | Text, Vision | ~1.2 GB | ~2 GB |
+| nanbeige:3b | Text | ~2.0 GB | ~2.5 GB |
+| locooperator:4b | Text | ~2.5 GB | ~3 GB |
+| qwen3.5:9b | Text, Vision | ~6.0 GB | ~7 GB |
+| lfm2:24b-a2b | Text | ~14 GB | ~16 GB |
+| qwen3.5:27b | Text, Vision | ~16 GB | ~18 GB |
+| qwen3.5:35b-a3b | Text, Vision | ~20 GB | ~22 GB |
+| nemotron:30b-a3b | Text | ~24.5 GB | ~27 GB |
+
+Pick one model that fits your RAM. Smaller models are faster to download and easier to test with.
+
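To narrow the table down for a given machine, the "Min RAM" column can be filtered with a few lines of shell. This is only a sketch: `models_that_fit` is a hypothetical helper, the numbers are copied from the table above, and it is not how `dria-node setup` itself decides which models fit.

```bash
# Hypothetical helper: prints the models from the table above whose
# "Min RAM" (in GB) is at most the amount given as the first argument.
models_that_fit() {
  have_gb="$1"
  while read -r name need_gb; do
    # awk does the floating-point comparison that POSIX sh lacks;
    # its stdin is redirected so it does not consume the model list.
    if awk -v have="$have_gb" -v need="$need_gb" \
         'BEGIN { exit !(have + 0 >= need + 0) }' </dev/null; then
      echo "$name"
    fi
  done <<'EOF'
qwen3.5:0.8b 1
lfm2.5:1.2b 1
lfm2.5-audio:1.5b 1.5
lfm2.5-vl:1.6b 1.5
qwen3.5:2b 2
nanbeige:3b 2.5
locooperator:4b 3
qwen3.5:9b 7
lfm2:24b-a2b 16
qwen3.5:27b 18
qwen3.5:35b-a3b 22
nemotron:30b-a3b 27
EOF
}
```

Running `models_that_fit 8` lists every model up to and including `qwen3.5:9b`. Leave some headroom for the operating system when choosing the number to pass in.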
+## All Options
+
+| Flag | Env Var | Default | What it does |
+|---|---|---|---|
+| `--wallet` | `DRIA_WALLET` | (required) | Your node identity key |
+| `--model` | `DRIA_MODELS` | (required) | Model(s) to serve |
+| `--router-url` | `DRIA_ROUTER_URL` | `quic.dria.co:4001` | Router to connect to |
+| `--gpu-layers` | `DRIA_GPU_LAYERS` | `0` (CPU) | GPU layers (-1 = all) |
+| `--max-concurrent` | `DRIA_MAX_CONCURRENT` | `1` | Parallel inference tasks |
+| `--data-dir` | `DRIA_DATA_DIR` | `~/.dria` | Where models are cached |
+| `--quant` | `DRIA_QUANT` | `Q4_K_M` | Override quantization |
+| `--insecure` | `DRIA_INSECURE` | `false` | Skip TLS verification |
+| `--skip-update` | `DRIA_SKIP_UPDATE` | `false` | Skip auto-update check |
+
+## Troubleshooting
+
+**Windows: "dria-node is not recognized"**
+On Windows, run the binary by its path: `.\target\release\dria-node.exe` (backslashes, `.exe` extension). PowerShell only runs a bare command name if the executable is on your `PATH`, so typing `dria-node` from the repository folder will not find it.
+
+**"cmake not found" or build errors about C compiler**
+Make sure cmake is installed (step 1). On macOS: `brew install cmake`. On Linux: `sudo apt install cmake build-essential`. On Windows: `winget install -e --id Kitware.CMake` then reopen PowerShell.
+
+**Windows: "dria-node.exe not found in target\release"**
+The build probably failed. Scroll up in your terminal and look for red error messages. The most common cause is missing C++ build tools — run `winget install Microsoft.VisualStudio.2022.BuildTools --force --override "--passive --wait --add Microsoft.VisualStudio.Workload.VCTools;includeRecommended"`, reopen PowerShell, and rebuild with `cargo build --release`.
+
+**Windows: "Unable to find libclang" or "couldn't find clang.dll"**
+Install LLVM: `winget install -e --id LLVM.LLVM`, reopen PowerShell, and rebuild. If it still can't find it, set the path manually: `$env:LIBCLANG_PATH = "C:\Program Files\LLVM\bin"` then rebuild.
+
+**Build fails**
+Try a clean build: `cargo clean && cargo build --release`. Make sure you're on the `v2` branch: `git checkout v2`.
+
+**"unknown model"**
+Model names are exact. Use the names from the table above (e.g. `lfm2.5:1.2b`, not `lfm-2.5`).
+
+**"all routers unavailable" or "offline mode"**
+The node can't reach the router. Check your internet connection. If you're behind a strict firewall, **UDP port 4001 outbound** must be allowed.
+
+**Slow inference**
+If you have a GPU, make sure you built with `--features metal` (Mac) or `--features cuda` (NVIDIA) and are passing `--gpu-layers -1`.
+
+**Model download stalls or fails**
+Models come from HuggingFace. Try again — it might be a temporary network issue. You can also set `HF_ENDPOINT` if HuggingFace is blocked in your region.
+
+**Want more detail in the logs?**
+
+```bash
+RUST_LOG=debug ./target/release/dria-node start ...
+```
+
+## Reporting Issues
+
+If something goes wrong, please share:
+1. Your OS and hardware (CPU, RAM, GPU)
+2. The command you ran
+3. The full error output
diff --git a/TESTING.md b/TESTING.md
new file mode 100644
index 00000000..03808b01
--- /dev/null
+++ b/TESTING.md
@@ -0,0 +1,366 @@
+# DKN Network Testing Guide
+
+How to test the full router + compute-node stack locally (single machine) and over the internet (two laptops).
+
+## Prerequisites
+
+- Rust toolchain (`rustup`, `cargo`)
+- `openssl` CLI (for generating TLS certs)
+- A HuggingFace account (models download automatically)
+- ~1 GB free disk for the smallest model (`lfm2.5:1.2b`)
+
+Build both binaries first:
+
+```bash
+# Router (parentheses run the build in a subshell, so you stay in the parent directory)
+(cd dkn-router && cargo build --release)
+
+# Compute node: pick ONE of the three builds below. CPU only:
+(cd dkn-compute-node && cargo build --release)
+
+# Compute node (Metal / Apple Silicon)
+(cd dkn-compute-node && cargo build --release --features metal)
+
+# Compute node (CUDA)
+(cd dkn-compute-node && cargo build --release --features cuda)
+```
+
+## Generate a wallet key
+
+Any 32-byte hex string works as a test wallet:
+
+```bash
+openssl rand -hex 32
+# example output: a1b2c3d4...64 hex chars total
+```
+
+Save it — you'll pass it to the node via `--wallet`.
+
+---
+
+## Scenario 1: Everything on localhost
+
+### 1. Generate self-signed TLS certs
+
+```bash
+mkdir -p /tmp/dkn-certs
+
+openssl req -x509 -newkey ec -pkeyopt ec_paramgen_curve:prime256v1 \
+ -keyout /tmp/dkn-certs/key.pem -out /tmp/dkn-certs/cert.pem \
+ -days 365 -nodes -subj "/CN=localhost" \
+ -addext "subjectAltName=DNS:localhost,IP:127.0.0.1"
+```
+
+### 2. Start the router
+
+```bash
+./dkn-router/target/release/dkn-router \
+ --listen-quic 127.0.0.1:4001 \
+ --listen-http 127.0.0.1:8080 \
+ --cert /tmp/dkn-certs/cert.pem \
+ --key /tmp/dkn-certs/key.pem
+```
+
+You should see:
+
+```
+INFO starting DKN router quic=127.0.0.1:4001 http=127.0.0.1:8080
+INFO router ready ...
+```
+
+### 3. Start the compute node
+
+In a second terminal:
+
+```bash
+./dkn-compute-node/target/release/dria-node start \
+ --wallet \
+ --model lfm2.5:1.2b \
+ --router-url https://127.0.0.1:4001 \
+ --insecure \
+ --gpu-layers -1
+```
+
+- `--insecure` skips TLS verification (required for self-signed certs).
+- `--gpu-layers -1` offloads all layers to GPU. Use `0` for CPU-only.
+- First run downloads the model from HuggingFace (~730 MB).
+
+You should see:
+
+```
+INFO node identity address=0x...
+INFO model found in cache ...
+INFO benchmark complete tps=... model=lfm2.5:1.2b
+INFO connected to router node_id=... router=https://127.0.0.1:4001
+INFO node ready ...
+```
+
+### 4. Send a request
+
+```bash
+curl -s http://127.0.0.1:8080/v1/generate \
+ -H "Content-Type: application/json" \
+ -d '{
+ "model": "lfm2.5:1.2b",
+ "messages": [{"role": "user", "content": "What is 2+2?"}],
+ "max_tokens": 128,
+ "temperature": 0.7
+ }' | python3 -m json.tool
+```
+
+Expected response:
+
+```json
+{
+ "text": "2+2 equals 4...",
+ "model": "lfm2.5:1.2b",
+ "stats": {
+ "tokens_generated": 12,
+ "prompt_tokens": 8,
+ "generation_time_ms": 450,
+ "tokens_per_second": 26.7
+ }
+}
+```
+
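When scripting against the router, it can help to pull single fields out of the response instead of pretty-printing all of it. The sketch below assumes the response shape shown above and reuses `python3`, which the commands in this guide already depend on. The function names are illustrative.

```bash
# Reads a /v1/generate response on stdin and prints only the generated text.
response_text() {
  python3 -c 'import json, sys; print(json.load(sys.stdin)["text"])'
}

# Recomputes tokens per second from the raw stats, rounded to one decimal.
response_tps() {
  python3 -c '
import json, sys
stats = json.load(sys.stdin)["stats"]
print(round(stats["tokens_generated"] * 1000 / stats["generation_time_ms"], 1))
'
}
```

Pipe the `curl` output from step 4 into either function in place of `python3 -m json.tool`. With the sample values above, 12 tokens in 450 ms works out to roughly 26.7 tokens per second, which is consistent with the `tokens_per_second` field; whether the router computes the field exactly this way is an assumption.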
+### 5. Check other endpoints
+
+```bash
+# Health check
+curl -s http://127.0.0.1:8080/v1/health | python3 -m json.tool
+
+# List models served by connected nodes
+curl -s http://127.0.0.1:8080/v1/models | python3 -m json.tool
+
+# Batch request
+curl -s http://127.0.0.1:8080/v1/batch \
+ -H "Content-Type: application/json" \
+ -d '{
+ "tasks": [
+ {"model": "lfm2.5:1.2b", "messages": [{"role": "user", "content": "Say hi"}]},
+ {"model": "lfm2.5:1.2b", "messages": [{"role": "user", "content": "Say bye"}]}
+ ],
+ "timeout_secs": 30
+ }' | python3 -m json.tool
+```
+
+### 6. Run multiple nodes (optional)
+
+Start a second node with a different model and wallet on the same machine:
+
+```bash
+./dkn-compute-node/target/release/dria-node start \
+ --wallet $(openssl rand -hex 32) \
+ --model nanbeige:3b \
+ --router-url https://127.0.0.1:4001 \
+ --insecure \
+ --gpu-layers 0
+```
+
+Now `/v1/models` will show both `lfm2.5:1.2b` and `nanbeige:3b`.
+
+---
+
+## Scenario 2: Two laptops over the internet
+
+**Laptop A** = router, **Laptop B** = compute node.
+
+### 1. Find Laptop A's public IP
+
+If Laptop A is behind NAT (home router), you need to either:
+
+- **Port-forward** UDP 4001 and TCP 8080 on the home router to Laptop A's LAN IP.
+- Use a cloud VM (DigitalOcean, AWS, etc.) as Laptop A instead.
+
+Get the public IP:
+
+```bash
+curl -s ifconfig.me
+# e.g. 203.0.113.42
+```
+
+### 2. Generate TLS certs on Laptop A
+
+Generate certs with the public IP as a SAN:
+
+```bash
+export ROUTER_IP=203.0.113.42 # replace with your public IP
+
+mkdir -p /tmp/dkn-certs
+
+openssl req -x509 -newkey ec -pkeyopt ec_paramgen_curve:prime256v1 \
+ -keyout /tmp/dkn-certs/key.pem -out /tmp/dkn-certs/cert.pem \
+ -days 365 -nodes -subj "/CN=$ROUTER_IP" \
+ -addext "subjectAltName=IP:$ROUTER_IP"
+```
+
+If you have a domain name, use `DNS:yourdomain.com` instead of `IP:...`.
+
+### 3. Start the router on Laptop A
+
+```bash
+./dkn-router/target/release/dkn-router \
+ --listen-quic 0.0.0.0:4001 \
+ --listen-http 0.0.0.0:8080 \
+ --cert /tmp/dkn-certs/cert.pem \
+ --key /tmp/dkn-certs/key.pem
+```
+
+Note `0.0.0.0` to listen on all interfaces.
+
+### 4. Verify connectivity from Laptop B
+
+```bash
+# Check HTTP is reachable
+curl -s http://203.0.113.42:8080/v1/health
+
+# Check QUIC port is open (UDP)
+nc -z -u 203.0.113.42 4001 && echo "open" || echo "blocked"
+```
+
+If the health check returns `{"status":"ok",...}`, HTTP is working. If QUIC is blocked, check firewall/NAT rules for **UDP** port 4001.
+
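If the router is still starting up, the health check can be retried in a loop before launching the node. A minimal sketch, assuming only the `/v1/health` endpoint shown above; `wait_for_router` is a hypothetical helper name.

```bash
# Polls <base-url>/v1/health once per second until it answers with a
# successful HTTP status, or gives up after <tries> attempts (default 10).
wait_for_router() {
  url="$1"
  tries="${2:-10}"
  i=0
  while [ "$i" -lt "$tries" ]; do
    # -s: quiet, -f: treat HTTP error statuses as failure
    if curl -sf "$url/v1/health" >/dev/null 2>&1; then
      return 0
    fi
    i=$((i + 1))
    sleep 1
  done
  return 1
}
```

For example: `wait_for_router http://203.0.113.42:8080 30 && echo "router is up"`. This only confirms the HTTP side; the UDP QUIC port still needs the separate check above.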
+### 5. Start the compute node on Laptop B
+
+```bash
+./dkn-compute-node/target/release/dria-node start \
+ --wallet \
+ --model lfm2.5:1.2b \
+ --router-url https://203.0.113.42:4001 \
+ --insecure \
+ --gpu-layers -1
+```
+
+`--insecure` is needed because the cert is self-signed. Once the node connects:
+
+```
+INFO connected to router node_id=... router=https://203.0.113.42:4001
+```
+
+### 6. Send requests from either laptop
+
+From Laptop A (or any machine that can reach the router):
+
+```bash
+curl -s http://203.0.113.42:8080/v1/generate \
+ -H "Content-Type: application/json" \
+ -d '{
+ "model": "lfm2.5:1.2b",
+ "messages": [{"role": "user", "content": "Hello from the internet!"}],
+ "max_tokens": 64
+ }' | python3 -m json.tool
+```
+
+The HTTP request goes to the router, which forwards it via QUIC to the node on Laptop B, which runs inference and sends the result back.
+
+---
+
+## Scenario 3: LAN testing (two laptops, same network)
+
+Same as Scenario 2 but simpler — no NAT/port-forwarding needed.
+
+### 1. Find Laptop A's LAN IP
+
+```bash
+# macOS
+ipconfig getifaddr en0
+
+# Linux
+hostname -I | awk '{print $1}'
+```
+
+Example: `192.168.1.100`
+
+### 2. Generate certs and start router on Laptop A
+
+```bash
+export ROUTER_IP=192.168.1.100
+
+mkdir -p /tmp/dkn-certs
+
+openssl req -x509 -newkey ec -pkeyopt ec_paramgen_curve:prime256v1 \
+ -keyout /tmp/dkn-certs/key.pem -out /tmp/dkn-certs/cert.pem \
+ -days 365 -nodes -subj "/CN=$ROUTER_IP" \
+ -addext "subjectAltName=IP:$ROUTER_IP"
+
+./dkn-router/target/release/dkn-router \
+ --listen-quic 0.0.0.0:4001 \
+ --listen-http 0.0.0.0:8080 \
+ --cert /tmp/dkn-certs/cert.pem \
+ --key /tmp/dkn-certs/key.pem
+```
+
+### 3. Start node on Laptop B
+
+```bash
+./dkn-compute-node/target/release/dria-node start \
+ --wallet $(openssl rand -hex 32) \
+ --model lfm2.5:1.2b \
+ --router-url https://192.168.1.100:4001 \
+ --insecure \
+ --gpu-layers -1
+```
+
+### 4. Send requests from Laptop A
+
+```bash
+curl -s http://192.168.1.100:8080/v1/generate \
+ -H "Content-Type: application/json" \
+ -d '{
+ "model": "lfm2.5:1.2b",
+ "messages": [{"role": "user", "content": "Hello from LAN!"}],
+ "max_tokens": 64
+ }' | python3 -m json.tool
+```
+
+---
+
+## Available models
+
+| Short name | Size | Type | Notes |
+|---|---|---|---|
+| `lfm2.5:1.2b` | 731 MB | text | Fastest, good for testing |
+| `nanbeige:3b` | 2.4 GB | text | |
+| `locooperator:4b` | 2.5 GB | text | |
+| `lfm2.5-vl:1.6b` | 696 MB | vision | Text-only requests are fine; rejects audio content |
+| `lfm2.5-audio:1.5b` | 696 MB | audio | Rejects image content |
+| `lfm2:24b-a2b` | 14.4 GB | text | MoE |
+| `qwen3.5:27b` | 16.7 GB | text | |
+| `qwen3.5:35b-a3b` | 19.9 GB | text | MoE |
+
+## Environment variables
+
+All CLI flags can be set via env vars instead:
+
+| Env var | Flag | Default |
+|---|---|---|
+| `DRIA_WALLET` | `--wallet` | (required) |
+| `DRIA_MODELS` | `--model` | (required) |
+| `DRIA_ROUTER_URL` | `--router-url` | `quic.dria.co:4001` |
+| `DRIA_GPU_LAYERS` | `--gpu-layers` | `0` |
+| `DRIA_MAX_CONCURRENT` | `--max-concurrent` | `1` |
+| `DRIA_DATA_DIR` | `--data-dir` | `~/.dria` |
+| `DRIA_INSECURE` | `--insecure` | `false` |
+| `DRIA_ROUTER_QUIC_ADDR` | `--listen-quic` | `0.0.0.0:4001` |
+| `DRIA_ROUTER_HTTP_ADDR` | `--listen-http` | `0.0.0.0:8080` |
+| `DRIA_ROUTER_CERT` | `--cert` | (required) |
+| `DRIA_ROUTER_KEY` | `--key` | (required) |
+
+## Troubleshooting
+
+| Symptom | Cause | Fix |
+|---|---|---|
+| Node logs `all routers unavailable` | Can't reach router QUIC port | Check firewall allows **UDP** 4001, verify IP/port |
+| Node logs `TLS error` | Cert doesn't match router hostname/IP | Regenerate cert with correct SAN, or use `--insecure` |
+| `curl` to `/v1/generate` returns 503 | No nodes connected | Check node logs, ensure it says `connected to router` |
+| `curl` to `/v1/generate` returns 504 | Node timeout during inference | Increase `timeout_secs` in request, or use a smaller model |
+| Node logs `SHA-256 mismatch` | Corrupted download | Delete `~/.dria/models/` and restart to re-download |
+| `QUIC connect failed: no initial cipher suite` | TLS/QUIC version mismatch | Ensure both router and node are built from the same branch |
+| Batch request partial failures | One model not loaded | Check `/v1/models` to see what's available |
+
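For the `SHA-256 mismatch` case, you can also compute a file's hash by hand to confirm the download really is corrupted. This is a sketch, assuming `sha256sum` (common on Linux) or `shasum` (common on macOS) is installed. The helper names are hypothetical, and the expected digest has to come from a source you trust, for example a checksum published alongside the model file.

```bash
# Prints the lowercase hex SHA-256 digest of the file given as $1.
sha256_of() {
  if command -v sha256sum >/dev/null 2>&1; then
    sha256sum "$1" | awk '{ print $1 }'
  else
    shasum -a 256 "$1" | awk '{ print $1 }'
  fi
}

# Succeeds only when the file's digest equals the expected lowercase hex digest.
sha256_matches() {
  [ "$(sha256_of "$1")" = "$2" ]
}
```

Usage: `sha256_matches PATH_TO_MODEL_FILE EXPECTED_HEX || echo "re-download needed"`, where both arguments are placeholders you fill in.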
+## Verbose logging
+
+```bash
+RUST_LOG=debug ./dkn-router/target/release/dkn-router ...
+RUST_LOG=debug ./dkn-compute-node/target/release/dria-node start ...
+```
diff --git a/compose.yml b/compose.yml
deleted file mode 100644
index c06fa4a2..00000000
--- a/compose.yml
+++ /dev/null
@@ -1,64 +0,0 @@
-services:
- # Compute Node
- compute:
- image: "firstbatch/dkn-compute-node:latest"
- # build: "./" # use this one instead if you want to build locally
- environment:
- RUST_LOG: ${RUST_LOG:-none,dkn_compute=info}
- # Dria
- DKN_WALLET_SECRET_KEY: ${DKN_WALLET_SECRET_KEY}
- DKN_MODELS: ${DKN_MODELS}
- DKN_P2P_LISTEN_ADDR: ${DKN_P2P_LISTEN_ADDR}
- # API Keys
- OPENAI_API_KEY: ${OPENAI_API_KEY}
- GEMINI_API_KEY: ${GEMINI_API_KEY}
- OPENROUTER_API_KEY: ${OPENROUTER_API_KEY}
- # Ollama
- OLLAMA_HOST: ${OLLAMA_HOST}
- OLLAMA_PORT: ${OLLAMA_PORT}
- OLLAMA_AUTO_PULL: ${OLLAMA_AUTO_PULL:-true}
- network_mode: ${DKN_DOCKER_NETWORK_MODE:-bridge}
- extra_hosts:
- # for Linux, we need to add this line manually
- - "host.docker.internal:host-gateway"
- restart: "on-failure"
-
- # Ollama Container (CPU)
- ollama:
- image: ollama/ollama:latest
- ports:
- - 11434:11434
- volumes:
- - ~/.ollama:/root/.ollama
- profiles: [ollama-cpu]
-
- # Ollama Container (ROCM)
- ollama-rocm:
- image: ollama/ollama:rocm
- ports:
- - 11434:11434
- volumes:
- - ~/.ollama:/root/.ollama
- devices:
- - "/dev/kfd"
- - "/dev/dri"
- profiles: [ollama-rocm]
-
- # Ollama Container (CUDA)
- ollama-cuda:
- image: ollama/ollama
- ports:
- - 11434:11434
- volumes:
- - ~/.ollama:/root/.ollama
- deploy:
- resources:
- reservations:
- devices:
- - driver: nvidia
- count: 1
- capabilities: [gpu]
- profiles: [ollama-cuda]
-
-volumes:
- ollama:
diff --git a/compute/Cargo.toml b/compute/Cargo.toml
deleted file mode 100644
index fbd4f289..00000000
--- a/compute/Cargo.toml
+++ /dev/null
@@ -1,64 +0,0 @@
-[package]
-name = "dkn-compute"
-version.workspace = true
-edition.workspace = true
-license.workspace = true
-readme = "README.md"
-authors = ["Erhan Tezcan "]
-
-[dependencies]
-# async stuff
-tokio-util.workspace = true
-tokio.workspace = true
-
-# serialize & deserialize
-serde.workspace = true
-serde_json.workspace = true
-
-# http & networking
-reqwest.workspace = true
-port_check = "0.2.1"
-url = "2.5.0"
-urlencoding = "2.1.3"
-
-# utilities
-dotenvy.workspace = true
-base64 = "0.22.0"
-hex = "0.4.3"
-hex-literal = "0.4.1"
-uuid.workspace = true
-rand.workspace = true
-
-# logging & errors
-env_logger.workspace = true
-log.workspace = true
-eyre.workspace = true
-colored = "3.0.0"
-
-# encryption (ecies) & signatures (ecdsa) & hashing & bloom-filters
-ecies = { version = "0.2", default-features = false, features = ["pure"] }
-libsecp256k1 = "0.7.1"
-
-# machine diagnostics
-# system info
-sysinfo = "0.33.1"
-# gpu info TODO: this gives a build error on Windows
-# wgpu = { version = "23.0.1", features = [
-# "serde",
-# "dx12",
-# "metal",
-# ], default-features = false }
-# public ip
-public-ip-address = "0.3.2"
-
-# dria subcrates
-dkn-p2p = { path = "../p2p" }
-dkn-utils = { path = "../utils", features = ["crypto"] }
-dkn-executor = { path = "../executor" }
-chrono.workspace = true
-
-
-# vendor OpenSSL so that its easier to build cross-platform packages
-[dependencies.openssl]
-version = "*"
-features = ["vendored"]
diff --git a/compute/src/config.rs b/compute/src/config.rs
deleted file mode 100644
index e6678698..00000000
--- a/compute/src/config.rs
+++ /dev/null
@@ -1,183 +0,0 @@
-use dkn_executor::DriaExecutorsManager;
-use dkn_p2p::libp2p::{Multiaddr, PeerId};
-use eyre::{eyre, Result};
-use libsecp256k1::{PublicKey, SecretKey};
-use std::{env, str::FromStr};
-
-use dkn_utils::{
- crypto::{public_key_to_address, secret_to_keypair},
- DriaNetwork, SemanticVersion,
-};
-
-const DEFAULT_TASK_BATCH_SIZE: usize = 5;
-const DEFAULT_P2P_LISTEN_ADDR: &str = "/ip4/0.0.0.0/tcp/4001";
-
-#[derive(Clone)]
-pub struct DriaComputeNodeConfig {
- /// Wallet secret/private key.
- pub secret_key: SecretKey,
- /// Wallet public key, derived from the secret key.
- pub public_key: PublicKey,
- /// Wallet address in hex without `0x` prefix, derived from the public key.
- pub address: String,
- /// Peer ID of the node.
- pub peer_id: PeerId,
- /// Compute node version.
- pub version: SemanticVersion,
- /// P2P listen address, e.g. `/ip4/0.0.0.0/tcp/4001`.
- pub p2p_listen_addr: Multiaddr,
- /// Executor manager, handles models and providers.
- pub executors: DriaExecutorsManager,
- /// Network type of the node.
- pub network: DriaNetwork,
- /// Batch size for batchable tasks (e.g. API-based ones).
- ///
- /// A higher value will help execute more tasks concurrently,
- /// at the risk of hitting rate-limits.
- pub batch_size: usize,
- /// An optional first-attempt RPC address, will be dialled at startup.
- ///
- /// TODO: this is `None` after startup due to `Option::take`, can we do any better?
- pub initial_rpc_addr: Option<Multiaddr>,
- /// Execution platform, mainly for diagnostics.
- ///
- /// Given by `DKN_EXEC_PLATFORM`.
- pub exec_platform: String,
-}
-
-#[allow(clippy::new_without_default)]
-impl DriaComputeNodeConfig {
- /// Creates new config from environment variables.
- pub fn new(executors: DriaExecutorsManager) -> Self {
- let secret_key = match env::var("DKN_WALLET_SECRET_KEY") {
- Ok(secret_env) => {
- let secret_dec = hex::decode(secret_env.trim_start_matches("0x"))
- .expect("Secret key should be 32-bytes hex encoded.");
-
- // if secret key is all-zeros, create one randomly
- // this is useful for testing & creating nodes on the fly
- if secret_dec.iter().all(|b| b == &0) {
- SecretKey::random(&mut rand::thread_rng())
- } else {
- SecretKey::parse_slice(&secret_dec).expect("Secret key should be parseable.")
- }
- }
- Err(err) => {
- log::error!("No secret key provided: {err}");
- panic!("Please provide a secret key.");
- }
- };
- log::info!(
- "Node Secret Key: 0x{}{}",
- hex::encode(&secret_key.serialize()[0..1]),
- ".".repeat(64)
- );
-
- let public_key = PublicKey::from_secret_key(&secret_key);
- log::info!(
- "Node Public Key: 0x{}",
- hex::encode(public_key.serialize_compressed())
- );
-
- // print address
- let address = hex::encode(public_key_to_address(&public_key));
- log::info!("Node Address: 0x{address}");
-
- // to this here to log the peer id at start
- let peer_id = secret_to_keypair(&secret_key).public().to_peer_id();
- log::info!("Node PeerID: {peer_id}");
-
- // parse listen address
- let p2p_listen_addr_str = env::var("DKN_P2P_LISTEN_ADDR")
- .map(|addr| addr.trim_matches('"').to_string())
- .unwrap_or(DEFAULT_P2P_LISTEN_ADDR.to_string());
- let p2p_listen_addr = Multiaddr::from_str(&p2p_listen_addr_str)
- .expect("could not parse the given P2P listen address.");
-
- // parse network type
- let network_type = env::var("DKN_NETWORK")
- // if there is an explicit value, default to testnet on error
- .map(|s| DriaNetwork::try_from(s.as_str()).unwrap_or(DriaNetwork::Testnet))
- // if there is no explicit value, default to mainnet
- .unwrap_or(DriaNetwork::Mainnet);
- if network_type == DriaNetwork::Testnet {
- log::warn!("Using testnet network!");
- }
-
- // parse batch size
- let batch_size = env::var("DKN_BATCH_SIZE")
- .map(|s| s.parse::<usize>().unwrap_or(DEFAULT_TASK_BATCH_SIZE))
- .unwrap_or(DEFAULT_TASK_BATCH_SIZE);
-
- // parse version
- let version = env!("CARGO_PKG_VERSION")
- .parse()
- .expect("could not parse version");
-
- // parse initial rpc address, if any
- let initial_rpc_addr = env::var("DKN_INITIAL_RPC_ADDR")
- .ok()
- .and_then(|addr| if addr.is_empty() { None } else { Some(addr) })
- .map(|addr| {
- Multiaddr::from_str(&addr).expect("could not parse the given initial RPC address.")
- });
-
- // parse execution platform
- let exec_platform = env::var("DKN_EXEC_PLATFORM").unwrap_or_else(|_| "unknown".to_string());
-
- Self {
- secret_key,
- public_key,
- address,
- peer_id,
- version,
- executors,
- p2p_listen_addr,
- network: network_type,
- batch_size,
- initial_rpc_addr,
- exec_platform,
- }
- }
-
- /// Asserts that the configured listen address is free.
- /// Returns an error if the address is already in use.
- ///
- /// Uses `is_port_reachable` function internally, which makes a simple
- /// TCP connection to the given address.
- ///
- /// Can be inlined because the function is small and called only once.
- #[inline]
- pub fn assert_address_not_in_use(&self) -> Result<()> {
- use dkn_p2p::libp2p::multiaddr::Protocol;
- use port_check::is_port_reachable;
- use std::net::{Ipv4Addr, SocketAddrV4};
-
- let address_in_use = self
- .p2p_listen_addr
- .iter()
- // find the port within our multiaddr
- .find_map(|protocol| match protocol {
- Protocol::Tcp(port) => Some(port),
- _ => None,
- })
- // check if it's reachable or not
- .map(|port| is_port_reachable(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port)))
- .unwrap_or_else(|| {
- log::error!(
- "could not find any TCP port in the given address: {:?}",
- self.p2p_listen_addr
- );
- false
- });
-
- if address_in_use {
- return Err(eyre!(
- "Listen address {} is already in use.",
- self.p2p_listen_addr
- ));
- }
-
- Ok(())
- }
-}
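
The fallback order used for `DKN_NETWORK` and `DKN_BATCH_SIZE` in the removed config code can be summarized with a minimal, std-only sketch. The `Network` enum, the accepted spellings `"mainnet"` / `"testnet"`, and the function names are assumptions made for illustration; the real `DriaNetwork::try_from` may accept different values.

```rust
/// Hypothetical stand-in for `DriaNetwork`; not the real type.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Network {
    Mainnet,
    Testnet,
}

/// Assumed spellings; the real `try_from` may differ.
pub fn parse_network_name(name: &str) -> Option<Network> {
    match name {
        "mainnet" => Some(Network::Mainnet),
        "testnet" => Some(Network::Testnet),
        _ => None,
    }
}

/// Mirrors the fallback order of the removed code:
/// variable unset -> mainnet, variable set but unrecognized -> testnet.
pub fn resolve_network(raw: Option<&str>) -> Network {
    match raw {
        None => Network::Mainnet,
        Some(value) => parse_network_name(value).unwrap_or(Network::Testnet),
    }
}

/// Variable unset or unparsable -> the given default.
pub fn resolve_batch_size(raw: Option<&str>, default: usize) -> usize {
    raw.and_then(|s| s.parse::<usize>().ok()).unwrap_or(default)
}
```

Taking `Option<&str>` instead of reading `std::env::var` directly keeps the fallback rules testable without touching the process environment; a caller would pass `std::env::var("DKN_NETWORK").ok().as_deref()`.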
diff --git a/compute/src/lib.rs b/compute/src/lib.rs
deleted file mode 100644
index 56aed735..00000000
--- a/compute/src/lib.rs
+++ /dev/null
@@ -1,12 +0,0 @@
-pub mod config;
-pub mod node;
-pub mod reqres;
-pub mod utils;
-pub mod workers;
-
-/// Crate version of the compute node.
-/// This value is attached within the published messages.
-pub const DRIA_COMPUTE_NODE_VERSION: &str = env!("CARGO_PKG_VERSION");
-
-pub use config::DriaComputeNodeConfig;
-pub use node::DriaComputeNode;
diff --git a/compute/src/main.rs b/compute/src/main.rs
deleted file mode 100644
index bf1c9eff..00000000
--- a/compute/src/main.rs
+++ /dev/null
@@ -1,209 +0,0 @@
-use dkn_compute::*;
-use dkn_executor::{DriaExecutorsManager, Model};
-use eyre::Result;
-use std::env;
-use tokio_util::{sync::CancellationToken, task::TaskTracker};
-use workers::task::TaskWorker;
-
-#[tokio::main]
-async fn main() -> Result<()> {
- // load a particular environment file specified by DKN_COMPUTE_ENV, or `.env` by default
- let env_path = env::var("DKN_COMPUTE_ENV").unwrap_or_else(|_| ".env".to_string());
- let dotenv_result = dotenvy::from_path(&env_path);
-
- env_logger::builder()
- .format_timestamp(Some(env_logger::TimestampPrecision::Millis))
- .filter(None, log::LevelFilter::Off)
- .filter_module("dkn_compute", log::LevelFilter::Info)
- .filter_module("dkn_p2p", log::LevelFilter::Info)
- .filter_module("dkn_utils", log::LevelFilter::Info)
- .filter_module("dkn_executor", log::LevelFilter::Info)
- .filter_module("libp2p", log::LevelFilter::Error)
- .parse_default_env() // reads RUST_LOG variable
- .init();
-
- log::info!(
- r#"
-
-██████╗ ██████╗ ██╗ █████╗
-██╔══██╗██╔══██╗██║██╔══██╗ Dria Compute Node
-██║ ██║██████╔╝██║███████║ v{DRIA_COMPUTE_NODE_VERSION}
-██║ ██║██╔══██╗██║██╔══██║ https://dria.co
-██████╔╝██║ ██║██║██║ ██║
-╚═════╝ ╚═╝ ╚═╝╚═╝╚═╝ ╚═╝
-"#
- );
-
- // log about env usage
- match dotenv_result {
- Ok(_) => log::info!("Loaded environment file from {env_path}"),
- Err(err) => log::warn!("Could not load environment file from {env_path}: {err}"),
- }
-
- // task tracker for multiple threads
- let task_tracker = TaskTracker::new();
- let cancellation = CancellationToken::new();
-
- // spawn the background task to wait for termination signals
- let task_tracker_to_close = task_tracker.clone();
- let cancellation_token = cancellation.clone();
- task_tracker.spawn(async move {
- if let Ok(Ok(duration_secs)) =
- env::var("DKN_EXIT_TIMEOUT").map(|s| s.to_string().parse::<u64>())
- {
- // the timeout is done for profiling only, and should not be used in production
- log::warn!("Waiting for {duration_secs} seconds before exiting.");
- tokio::time::sleep(tokio::time::Duration::from_secs(duration_secs)).await;
-
- log::warn!("Exiting due to DKN_EXIT_TIMEOUT.");
- cancellation_token.cancel();
- } else if let Err(err) = wait_for_termination(cancellation_token.clone()).await {
- // if there is no timeout, we wait for termination signals here
- log::error!("Error waiting for termination: {err:?}");
- log::error!("Cancelling due to unexpected error.");
- cancellation_token.cancel();
- };
-
- // close tracker in any case
- task_tracker_to_close.close();
- });
-
- // create configurations
- let models = Model::from_csv(env::var("DKN_MODELS").unwrap_or_default());
- let executors_config = DriaExecutorsManager::new_from_env_for_models(models.into_iter())?;
- if executors_config.models.is_empty() {
- return Err(eyre::eyre!("No models were provided, make sure to restart with at least one model provided within DKN_MODELS."));
- }
-
- log::info!(
- "Initial provided models are: {}",
- executors_config.get_model_names().join(", ")
- );
- let mut config = DriaComputeNodeConfig::new(executors_config);
-
- // check address in use
- config.assert_address_not_in_use()?;
-
- // check services & models, will exit if there is an error
- // since service check can take time, we allow early-exit here as well
- let model_perf = tokio::select! {
- result = config.executors.check_services() => result,
- _ = cancellation.cancelled() => {
- log::info!("Service check cancelled, exiting.");
- return Ok(());
- }
- };
-
- if config.executors.models.is_empty() {
- return Err(eyre::eyre!(
- "No valid models left after service checks, exiting."
- ));
- } else {
- log::info!(
- "Using models: {}\n{}",
- config.executors.get_model_names().join(", "),
- model_perf
- .iter()
- .map(|(model, perf)| format!("{model}: {perf}"))
- .collect::<Vec<_>>()
- .join("\n")
- );
- }
- // create the node
- let batch_size = config.batch_size;
- let (mut node, p2p, worker_batch, worker_single) =
- DriaComputeNode::new(config, model_perf).await?;
-
- // spawn p2p client first
- log::info!("Spawning peer-to-peer client thread.");
- task_tracker.spawn(async move { p2p.run().await });
-
- // spawn batch worker thread if we are using such models (e.g. OpenAI, Gemini, OpenRouter)
- if let Some(mut worker_batch) = worker_batch {
- assert!(
- batch_size <= TaskWorker::MAX_BATCH_SIZE,
- "batch size too large"
- );
- log::info!("Spawning batch executor worker thread. (batch size {batch_size})");
- task_tracker.spawn(async move { worker_batch.run_batch(batch_size).await });
- }
-
- // spawn single worker thread if we are using such models (e.g. Ollama)
- if let Some(mut worker_single) = worker_single {
- log::info!("Spawning single executor worker thread.");
- task_tracker.spawn(async move { worker_single.run_series().await });
- }
-
- // spawn compute node thread
- log::info!("Spawning compute node thread.");
- let node_token = cancellation.clone();
- task_tracker.spawn(async move {
- node.run(node_token).await;
- log::info!("Closing node.")
- });
-
- // wait for all tasks to finish
- task_tracker.wait().await;
- log::info!("All tasks have exited successfully.");
-
- log::info!("Bye!");
- Ok(())
-}
-
-/// Waits for various termination signals, and cancels the given token when the signal is received.
-///
-/// Handles Unix and Windows [target families](https://doc.rust-lang.org/reference/conditional-compilation.html#target_family).
-async fn wait_for_termination(cancellation: CancellationToken) -> Result<()> {
- #[cfg(unix)]
- {
- use tokio::signal::unix::{signal, SignalKind};
- let mut sigterm = signal(SignalKind::terminate())?; // Docker sends SIGTERM
- let mut sigint = signal(SignalKind::interrupt())?; // Ctrl+C sends SIGINT
- tokio::select! {
- _ = sigterm.recv() => log::warn!("Received SIGTERM"),
- _ = sigint.recv() => log::warn!("Received SIGINT"),
- _ = cancellation.cancelled() => {
- // no need to wait if cancelled anyways
- // although this is not likely to happen
- return Ok(());
- }
- };
-
- cancellation.cancel();
- }
-
- #[cfg(windows)]
- {
- use tokio::signal::windows;
-
- // https://learn.microsoft.com/en-us/windows/console/handlerroutine
- let mut signal_c = windows::ctrl_c()?;
- let mut signal_break = windows::ctrl_break()?;
- let mut signal_close = windows::ctrl_close()?;
- let mut signal_shutdown = windows::ctrl_shutdown()?;
-
- tokio::select! {
- _ = signal_c.recv() => log::warn!("Received CTRL_C"),
- _ = signal_break.recv() => log::warn!("Received CTRL_BREAK"),
- _ = signal_close.recv() => log::warn!("Received CTRL_CLOSE"),
- _ = signal_shutdown.recv() => log::warn!("Received CTRL_SHUTDOWN"),
- _ = cancellation.cancelled() => {
- // no need to wait if cancelled anyways
- // although this is not likely to happen
- return Ok(());
- }
- };
-
- cancellation.cancel();
- }
-
- #[cfg(not(any(unix, windows)))]
- {
- log::error!("No signal handling for this platform: {}", env::consts::OS);
- cancellation.cancel();
- }
-
- log::info!("Terminating the application...");
-
- Ok(())
-}
diff --git a/compute/src/node/core.rs b/compute/src/node/core.rs
deleted file mode 100644
index e91c7f6c..00000000
--- a/compute/src/node/core.rs
+++ /dev/null
@@ -1,167 +0,0 @@
-use colored::Colorize;
-use dkn_p2p::libp2p::{Multiaddr, PeerId};
-use dkn_utils::{
- payloads::{HEARTBEAT_TOPIC, SPECS_TOPIC},
- DriaMessage,
-};
-use eyre::{eyre, Result};
-use std::time::Duration;
-use tokio_util::sync::CancellationToken;
-
-use crate::{reqres::HeartbeatRequester, DriaComputeNode};
-
-impl DriaComputeNode {
- /// Runs the main loop of the compute node.
- /// This method is not expected to return until cancellation occurs for the given token.
- pub async fn run(&mut self, cancellation: CancellationToken) {
- // initialize the points client
- self.points_client.initialize().await;
-
- /// Duration between refreshing for diagnostic prints.
- const DIAGNOSTIC_REFRESH_INTERVAL_SECS: Duration = Duration::from_secs(45);
- /// Duration between refreshing for points update.
- const POINTS_REFRESH_INTERVAL_SECS: Duration = Duration::from_secs(180);
- /// Duration between refreshing the available nodes.
- const RPC_LIVENESS_REFRESH_INTERVAL_SECS: Duration = Duration::from_secs(2 * 60);
- /// Duration between each specs update sent to the RPC.
- const SPECS_INTERVAL_SECS: Duration = Duration::from_secs(60 * 5);
-
- let mut diagnostic_refresh_interval =
- tokio::time::interval(DIAGNOSTIC_REFRESH_INTERVAL_SECS);
- diagnostic_refresh_interval.tick().await; // move each one tick
- let mut rpc_liveness_refresh_interval =
- tokio::time::interval(RPC_LIVENESS_REFRESH_INTERVAL_SECS);
- rpc_liveness_refresh_interval.tick().await; // move each one tick
-
- // tick the first time a bit earlier
- let mut points_refresh_interval = tokio::time::interval(POINTS_REFRESH_INTERVAL_SECS);
- points_refresh_interval.tick().await;
- points_refresh_interval.reset_after(POINTS_REFRESH_INTERVAL_SECS / 12);
-
- // move one tick, and wait at least a third of the diagnostics
- let mut heartbeat_interval = tokio::time::interval(HeartbeatRequester::HEARTBEAT_DEADLINE);
- heartbeat_interval.tick().await;
- heartbeat_interval.reset_after(DIAGNOSTIC_REFRESH_INTERVAL_SECS / 3);
-
- // move one tick, and wait a little bit
- let mut specs_interval = tokio::time::interval(SPECS_INTERVAL_SECS);
- specs_interval.tick().await;
- specs_interval.reset_after(DIAGNOSTIC_REFRESH_INTERVAL_SECS / 6);
-
- loop {
- tokio::select! {
- // a task is completed by the worker & should be responded to the requesting peer
- task_response_msg_opt = self.task_output_rx.recv() => {
- if let Some(task_response_msg) = task_response_msg_opt {
- if let Err(err) = self.send_task_output(task_response_msg).await {
- log::error!("Error responding to task: {err:?}");
- }
- } else {
- log::error!("task_output_rx channel closed unexpectedly, we still have {} batch and {} single tasks.", self.pending_tasks_batch.len(), self.pending_tasks_single.len());
- break;
- }
- },
-
- // a Request or Response is received by the p2p client
- reqres_msg_opt = self.reqres_rx.recv() => {
- if let Some((peer_id, message)) = reqres_msg_opt {
- self.handle_reqres(peer_id, message).await;
- } else {
- log::error!("reqres_rx channel closed unexpectedly.");
- break;
- }
- },
-
- // check peer count every now and then
- _ = diagnostic_refresh_interval.tick() => self.handle_diagnostic_refresh().await,
-
- // check RPC, and get a new one if we are disconnected
- _ = rpc_liveness_refresh_interval.tick() => {
- let is_connected = self.handle_rpc_liveness_check().await;
- if !is_connected {
- // make sure we reset the heartbeat and specs intervals so that
- // we don't wait the entire duration for this new connection
- log::info!("Connection was re-attempted, resetting timers.");
- heartbeat_interval.reset_after(Duration::from_secs(5));
- specs_interval.reset_after(Duration::from_secs(5));
- }
- },
-
- // log points every now and then
- _ = points_refresh_interval.tick() => self.handle_points_refresh().await,
-
- // send a heartbeat request to publish liveness info
- _ = heartbeat_interval.tick() => {
- if let Err(e) = self.send_heartbeat().await {
- log::error!("Error making {}: {:?}", HEARTBEAT_TOPIC.blue(), e);
- }
- },
-
- // send specs to the RPC
- _ = specs_interval.tick() => {
- if let Err(e) = self.send_specs().await {
- log::error!("Error sending {}: {:?}", SPECS_TOPIC.green(), e);
- }
- },
-
- // check if the cancellation token is cancelled
- // this is expected to be cancelled by the main thread with signal handling
- _ = cancellation.cancelled() => {
- log::info!("Cancellation received, shutting down the node.");
- break;
- },
- }
- }
-
- // print one final diagnostic as a summary
- self.handle_diagnostic_refresh().await;
-
- // shutdown channels
- if let Err(err) = self.shutdown().await {
- log::error!("Could not shutdown the node gracefully: {err:?}");
- }
- }
-
- /// Shorthand method to create a signed message with the given data and topic.
- ///
- /// Topic was previously used for GossipSub, but kept for verbosity.
- #[inline(always)]
- pub fn new_message(&self, data: impl AsRef<[u8]>, topic: impl ToString) -> DriaMessage {
- DriaMessage::new_signed(
- data,
- topic,
- self.p2p.protocol().name.clone(),
- &self.config.secret_key,
- self.config.version,
- )
- }
-
- /// Dial the given peer at the given address.
- pub async fn dial_with_timeout(&mut self, peer_id: PeerId, addr: Multiaddr) -> Result<()> {
- // the cause is not yet known, but some nodes get stuck during the dialling step;
- // this timeout prevents that.
- const DIAL_TIMEOUT: Duration = Duration::from_secs(10);
-
- match tokio::time::timeout(DIAL_TIMEOUT, self.p2p.dial(peer_id, addr)).await {
- Err(timeout) => Err(eyre!("Timeout dialling RPC node: {}", timeout)),
- Ok(result) => result, // this is also a `Result` enum
- }
- }
-
- /// Shutdown channels between p2p, worker and yourself.
- ///
- /// Can be inlined as it is called only once from very few places.
- #[inline]
- pub async fn shutdown(&mut self) -> Result<()> {
- log::debug!("Sending shutdown command to p2p client.");
- self.p2p.shutdown().await?;
-
- log::debug!("Closing task output channel.");
- self.task_output_rx.close();
-
- log::debug!("Closing reqres channel.");
- self.reqres_rx.close();
-
- Ok(())
- }
-}
diff --git a/compute/src/node/diagnostic.rs b/compute/src/node/diagnostic.rs
deleted file mode 100644
index e407a783..00000000
--- a/compute/src/node/diagnostic.rs
+++ /dev/null
@@ -1,152 +0,0 @@
-use colored::Colorize;
-use std::time::Duration;
-
-use crate::{node::rpc::DriaRPC, DriaComputeNode, DRIA_COMPUTE_NODE_VERSION};
-
-/// Number of seconds such that if the last heartbeat ACK is older than this, the node is considered unreachable.
- /// This must be greater than both the heartbeat interval duration and the liveness check duration.
-const HEARTBEAT_LIVENESS_SECS: Duration = Duration::from_secs(4 * 60);
-
-impl DriaComputeNode {
- /// Returns the task count within the channels, `single` and `batch`.
- #[inline(always)]
- pub fn get_pending_task_count(&self) -> [usize; 2] {
- [
- self.pending_tasks_single.len(),
- self.pending_tasks_batch.len(),
- ]
- }
-
- /// Diagnostic refresh reports the node's peer ID, address, models and liveness status to the user.
- pub(crate) async fn handle_diagnostic_refresh(&mut self) {
- let mut diagnostics = vec![format!("Diagnostics (v{}):", DRIA_COMPUTE_NODE_VERSION)];
-
- // completed tasks count is printed as well in debug
- if log::log_enabled!(log::Level::Debug) {
- diagnostics.push(format!(
- "Completed Tasks (single/batch): {} / {}",
- self.completed_tasks_single, self.completed_tasks_batch
- ));
-
- diagnostics.push(format!(
- "RPC {}: {}",
- self.dria_rpc.peer_id,
- if self
- .p2p
- .is_connected(self.dria_rpc.peer_id)
- .await
- .unwrap_or(false)
- {
- "Connected".green()
- } else {
- "Disconnected".red()
- }
- ));
- }
-
- // print peer id and address
- diagnostics.push(format!("Peer ID: {}", self.config.peer_id));
- diagnostics.push(format!("Address: 0x{}", self.config.address));
-
- // print models
- diagnostics.push(format!(
- "Models: {}",
- self.config.executors.get_model_names().join(", ")
- ));
-
- // if we have not received pings for a while, we are considered offline
- let is_offline = chrono::Utc::now() > self.last_heartbeat_at + HEARTBEAT_LIVENESS_SECS;
-
- // if we have not yet received a heartbeat response, we are still connecting
- if self.num_heartbeats == 0 {
- // if we didn't have any pings, we might still be connecting
- diagnostics.push(format!("Node Status: {}", "CONNECTING".yellow()));
- } else {
- diagnostics.push(format!(
- "Node Status: {}",
- if is_offline {
- "OFFLINE".red()
- } else {
- "ONLINE".green()
- }
- ));
- }
-
- log::info!("{}", diagnostics.join("\n "));
-
- // if offline, print this error message as well
- if is_offline {
- log::error!(
- "Node has not received any pings for at least {} seconds & it may be unreachable!\nPlease restart your node!",
- HEARTBEAT_LIVENESS_SECS.as_secs()
- );
- }
- }
-
- /// Dials the existing RPC node if we are not connected to it.
- ///
- /// If there is an error while doing that, it will try to get a new RPC node and dial it.
- ///
- /// Returns `true` if the RPC is connected, `false` otherwise.
- pub(crate) async fn handle_rpc_liveness_check(&mut self) -> bool {
- log::debug!("Checking RPC connections for diagnostics.");
-
- // check if we are connected
- let is_connected = self
- .p2p
- .is_connected(self.dria_rpc.peer_id)
- .await
- .unwrap_or(false);
-
- // if we are not connected, get a new RPC and dial it again
- if !is_connected {
- // if we also cannot dial it, get a new RPC node
- log::warn!(
- "Connection to RPC {} is lost, getting a new one!",
- self.dria_rpc.addr,
- );
- match DriaRPC::new_for_network(self.dria_rpc.network, &self.config.version).await {
- Ok(new_rpc) => {
- self.dria_rpc = new_rpc;
-
- // now dial this new RPC again
- if let Err(err) = self
- .dial_with_timeout(self.dria_rpc.peer_id, self.dria_rpc.addr.clone())
- .await
- {
- // worst-case we can't dial this one either, just leave it for the next diagnostic
- log::error!("Could not dial the new RPC: {err:?}");
- }
- }
- Err(err) => {
- log::error!("Could not get a new RPC node: {err:?}");
- }
- };
- } else {
- log::debug!("Connection with {} is intact.", self.dria_rpc.peer_id);
- }
-
- // return the connection status
- is_connected
- }
-
- /// Updates the points for the given address.
- #[inline]
- pub(crate) async fn handle_points_refresh(&mut self) {
- // get points from the API
- match self.points_client.get_points().await {
- Ok(steps) => {
- log::info!(
- "{}: {} total, {} earned in this run, within top {}%",
- "$DRIA Points".purple(),
- steps.score,
- steps.score - self.points_client.initial,
- steps.percentile
- );
- }
- Err(err) => {
- log::error!("Could not get $DRIA points info: {err:?}");
- }
- }
- }
-}
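
The status line printed by the removed diagnostics code follows a small decision rule: no acknowledged heartbeat yet means the node is still connecting; otherwise the age of the last acknowledgement decides between online and offline. A minimal sketch of that rule follows, with a hypothetical `NodeStatus` enum and `node_status` function that are not part of the original code. It takes the elapsed time as a `Duration` rather than comparing `chrono` timestamps.

```rust
use std::time::Duration;

/// Hypothetical status type; the removed code printed colored strings instead.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NodeStatus {
    Connecting,
    Online,
    Offline,
}

/// Same threshold as `HEARTBEAT_LIVENESS_SECS` in the removed code (4 minutes).
pub const HEARTBEAT_LIVENESS: Duration = Duration::from_secs(4 * 60);

/// Derives the status from the number of acknowledged heartbeats and
/// the time elapsed since the last acknowledgement.
pub fn node_status(num_heartbeats: u64, since_last_ack: Duration) -> NodeStatus {
    if num_heartbeats == 0 {
        // nothing acknowledged yet: still connecting, regardless of elapsed time
        NodeStatus::Connecting
    } else if since_last_ack > HEARTBEAT_LIVENESS {
        NodeStatus::Offline
    } else {
        NodeStatus::Online
    }
}
```

One difference worth noting: the removed code computed `is_offline` independently of the heartbeat count, so it could log the "may be unreachable" error while the status line still read `CONNECTING`. The sketch models only the status line.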
diff --git a/compute/src/node/mod.rs b/compute/src/node/mod.rs
deleted file mode 100644
index 0554c54c..00000000
--- a/compute/src/node/mod.rs
+++ /dev/null
@@ -1,166 +0,0 @@
-use dkn_executor::Model;
-use dkn_p2p::{
- libp2p::PeerId, DriaP2PClient, DriaP2PCommander, DriaP2PProtocol, DriaReqResMessage,
-};
-use dkn_utils::{crypto::secret_to_keypair, payloads::SpecModelPerformance};
-use eyre::Result;
-use std::collections::{HashMap, HashSet};
-use tokio::sync::mpsc;
-use uuid::Uuid;
-
-use crate::{
- config::*,
- utils::{DriaPointsClient, SpecCollector},
- workers::task::{TaskWorker, TaskWorkerInput, TaskWorkerMetadata, TaskWorkerOutput},
-};
-
-mod core;
-mod diagnostic;
-mod reqres;
-mod rpc;
-use rpc::DriaRPC;
-
-/// Buffer size for message publishes.
-const PUBLISH_CHANNEL_BUFSIZE: usize = 1024;
-
-pub struct DriaComputeNode {
- /// Compute node configuration.
- pub config: DriaComputeNodeConfig,
- /// Chosen RPC node.
- pub dria_rpc: DriaRPC,
- /// Peer-to-peer client commander to interact with the network.
- pub p2p: DriaP2PCommander,
- /// The last time the node had an acknowledged heartbeat.
- /// If this is too old, we can say that the node is not reachable by RPC.
- pub(crate) last_heartbeat_at: chrono::DateTime<chrono::Utc>,
- /// Number of pings received.
- pub(crate) num_heartbeats: u64,
- /// A mapping of heartbeat UUIDs to their deadlines.
- /// This is used to track the heartbeats, and their acknowledgements.
- pub(crate) heartbeats_reqs: HashMap<Uuid, chrono::DateTime<chrono::Utc>>,
- /// A set of specs request UUIDs.
- /// This is used to track the specs, and their acknowledgements.
- pub(crate) specs_reqs: HashSet<Uuid>,
- /// Request-response message receiver, can have both a request or a response.
- reqres_rx: mpsc::Receiver<(PeerId, DriaReqResMessage)>,
- /// Task response receiver, will respond to the request-response channel with the given result.
- task_output_rx: mpsc::Receiver<TaskWorkerOutput>,
- /// Task worker transmitter to send batchable tasks.
- task_request_batch_tx: Option<mpsc::Sender<TaskWorkerInput>>,
- /// Task worker transmitter to send single tasks.
- task_request_single_tx: Option<mpsc::Sender<TaskWorkerInput>>,
- /// Single tasks, key is `row_id`, which has negligible probability of collision.
- pub pending_tasks_single: HashMap<Uuid, TaskWorkerMetadata>,
- /// Batchable tasks, key is `row_id`, which has negligible probability of collision.
- pub pending_tasks_batch: HashMap<Uuid, TaskWorkerMetadata>,
- /// Completed single tasks count
- completed_tasks_single: usize,
- /// Completed batch tasks count
- completed_tasks_batch: usize,
- /// Specifications collector.
- spec_collector: SpecCollector,
- /// Points client.
- points_client: DriaPointsClient,
-}
-
-impl DriaComputeNode {
- /// Creates a new `DriaComputeNode` with the given configuration and model performance measurements.
- ///
- /// Returns the node instance and p2p client together. P2p MUST be run in a separate task before this node is used at all.
- pub async fn new(
- mut config: DriaComputeNodeConfig,
- model_perf: HashMap<Model, SpecModelPerformance>,
- ) -> Result<(
- DriaComputeNode,
- DriaP2PClient,
- Option<TaskWorker>,
- Option<TaskWorker>,
- )> {
- // create the keypair from secret key
- let keypair = secret_to_keypair(&config.secret_key);
-
- // dial the RPC node
- let dria_rpc = if let Some(addr) = config.initial_rpc_addr.take() {
- log::info!("Using initial RPC address: {addr}");
- DriaRPC::new(addr, config.network).expect("could not get RPC to connect to")
- } else {
- DriaRPC::new_for_network(config.network, &config.version)
- .await
- .expect("could not get RPC to connect to")
- };
-
- // we are using the major.minor version as the P2P version
- // so that patch versions do not interfere with the protocol
- let protocol = DriaP2PProtocol::new_major_minor(config.network.protocol_name());
- log::info!("Using identity: {protocol}");
-
- // create p2p client
- let (p2p_client, p2p_commander, request_rx) = DriaP2PClient::new(
- keypair,
- config.p2p_listen_addr.clone(),
- &dria_rpc.addr,
- protocol,
- )?;
-
- // create channel for task executors, all workers use the same publish channel
- let (publish_tx, publish_rx) = mpsc::channel(PUBLISH_CHANNEL_BUFSIZE);
-
- // check if we should create a worker for batch executor
- let (task_batch_worker, task_batch_tx) =
- if config.executors.providers.keys().any(|p| p.is_batchable()) {
- let (worker, sender) = TaskWorker::new(publish_tx.clone());
- (Some(worker), Some(sender))
- } else {
- (None, None)
- };
-
- // check if we should create a worker for single executor
- let (task_single_worker, task_single_tx) =
- if config.executors.providers.keys().any(|p| !p.is_batchable()) {
- let (worker, sender) = TaskWorker::new(publish_tx);
- (Some(worker), Some(sender))
- } else {
- (None, None)
- };
-
- let model_names = config.executors.get_model_names();
- let points_client = DriaPointsClient::new(&config.address, &config.network)?;
-
- let spec_collector = SpecCollector::new(
- model_names.clone(),
- model_perf,
- config.version,
- config.exec_platform.clone(),
- p2p_client.peer_id,
- );
- Ok((
- DriaComputeNode {
- config,
- p2p: p2p_commander,
- dria_rpc,
- points_client,
- // receivers
- task_output_rx: publish_rx,
- reqres_rx: request_rx,
- // transmitters
- task_request_batch_tx: task_batch_tx,
- task_request_single_tx: task_single_tx,
- // task trackers
- pending_tasks_single: HashMap::new(),
- pending_tasks_batch: HashMap::new(),
- completed_tasks_single: 0,
- completed_tasks_batch: 0,
- // heartbeats
- heartbeats_reqs: HashMap::new(),
- last_heartbeat_at: chrono::Utc::now(),
- num_heartbeats: 0,
- // specs
- specs_reqs: HashSet::new(),
- spec_collector,
- },
- p2p_client,
- task_batch_worker,
- task_single_worker,
- ))
- }
-}
diff --git a/compute/src/node/reqres.rs b/compute/src/node/reqres.rs
deleted file mode 100644
index 10368c7f..00000000
--- a/compute/src/node/reqres.rs
+++ /dev/null
@@ -1,211 +0,0 @@
-use colored::Colorize;
-use dkn_p2p::libp2p::{
- request_response::{OutboundRequestId, ResponseChannel},
- PeerId,
-};
-use dkn_p2p::DriaReqResMessage;
-use dkn_utils::{
- payloads::{HEARTBEAT_TOPIC, SPECS_TOPIC, TASK_REQUEST_TOPIC},
- DriaMessage,
-};
-use eyre::Result;
-
-use crate::{reqres::*, workers::task::TaskWorkerOutput};
-
-use super::DriaComputeNode;
-
-impl DriaComputeNode {
- /// Handles a generic request-response message received from the network.
- ///
- /// - Request is forwarded to [`handle_request`](DriaComputeNode::handle_request) method.
- /// - Response is forwarded to [`handle_response`](DriaComputeNode::handle_response) method.
- ///
- /// Does not return an error, but simply logs it to [`log::error`].
- pub(crate) async fn handle_reqres(&mut self, peer_id: PeerId, message: DriaReqResMessage) {
- match message {
- // make sure that the `channel` here is NOT DROPPED until a response is sent,
- // otherwise you will get an error
- DriaReqResMessage::Request {
- request,
- request_id,
- channel,
- } => {
- log::debug!("Received a request ({request_id}) from {peer_id}");
-
- // ensure that message is from the known RPCs
- if self.dria_rpc.peer_id != peer_id {
- log::warn!("Received request from unauthorized source: {peer_id}");
- log::debug!("Allowed source: {}", self.dria_rpc.peer_id);
- } else if let Err(err) = self.handle_request(peer_id, &request, channel).await {
- log::error!("Error handling request: {err:?}");
- }
- }
-
- DriaReqResMessage::Response {
- response,
- request_id,
- } => {
- log::debug!("Received a response ({request_id}) from {peer_id}");
- if let Err(err) = self.handle_response(peer_id, request_id, response).await {
- log::error!("Error handling response: {err:?}");
- }
- }
- };
- }
-
- /// Handles a [`request_response`] response received from the network.
- ///
- /// - Internally, the data is expected to be some JSON serialized data that is expected to be parsed and handled.
- /// - Can be inlined because it is only called by [`DriaComputeNode::handle_reqres`].
- #[inline]
- async fn handle_response(
- &mut self,
- peer_id: PeerId,
- request_id: OutboundRequestId,
- data: Vec<u8>,
- ) -> Result<()> {
- if peer_id != self.dria_rpc.peer_id {
- log::warn!("Received response from unauthorized source: {peer_id}");
- log::debug!("Allowed source: {}", self.dria_rpc.peer_id);
- }
-
- if let Ok(heartbeat_response) = HeartbeatRequester::try_parse_response(&data) {
- log::info!(
- "Received a {} response ({request_id}) from {peer_id}",
- HEARTBEAT_TOPIC.blue(),
- );
- HeartbeatRequester::handle_ack(self, heartbeat_response).await
- } else if let Ok(spec_response) = SpecRequester::try_parse_response(&data) {
- log::info!(
- "Received a {} response ({request_id}) from {peer_id}",
- SPECS_TOPIC.green(),
- );
- SpecRequester::handle_ack(self, spec_response).await
- } else {
- Err(eyre::eyre!("Received unhandled response from {}", peer_id))
- }
- }
-
- /// Handles a [`request_response`] request received from the network.
- ///
- /// - Internally, the data is expected to be some JSON serialized data that is expected to be parsed and handled.
- /// - Can be inlined because it is only called by [`DriaComputeNode::handle_reqres`].
- async fn handle_request(
- &mut self,
- peer_id: PeerId,
- message_data: &[u8],
- channel: ResponseChannel<Vec<u8>>,
- ) -> Result<()> {
- let message = DriaMessage::from_slice_checked(
- message_data,
- self.p2p.protocol().name.clone(),
- self.config.version,
- )?;
-
- match message.topic.as_str() {
- TASK_REQUEST_TOPIC => self.handle_task_request(peer_id, message, channel).await,
- _ => Err(eyre::eyre!("Received unhandled request from {peer_id}")),
- }
- }
-
- /// Handles a Task request received from the network.
- ///
- /// Based on the task type, the task is sent to the appropriate worker & metadata is stored in memory.
- /// This metadata will be used during response as well, and we can count the number of tasks at hand by
- /// looking at the number of metadata entries stored.
- async fn handle_task_request(
- &mut self,
- peer_id: PeerId,
- task_request: ::Request,
- channel: ResponseChannel<Vec<u8>>,
- ) -> Result<()> {
- log::info!(
- "Received a {} request from {peer_id}",
- TASK_REQUEST_TOPIC.yellow()
- );
-
- let (task_input, task_metadata) =
- TaskResponder::parse_task_request(self, &task_request, channel).await?;
- if let Err(err) = match task_input.task.is_batchable() {
- // this is a batchable task, send it to batch worker
- // and keep track of the task id in pending tasks
- true => match self.task_request_batch_tx {
- Some(ref mut tx) => {
- self.pending_tasks_batch
- .insert(task_input.row_id, task_metadata);
- tx.send(task_input).await
- }
- None => eyre::bail!("Batchable task received but no worker available."),
- },
-
- // this is a single task, send it to single worker
- // and keep track of the task id in pending tasks
- false => match self.task_request_single_tx {
- Some(ref mut tx) => {
- self.pending_tasks_single
- .insert(task_input.row_id, task_metadata);
- tx.send(task_input).await
- }
- None => eyre::bail!("Single task received but no worker available."),
- },
- } {
- log::error!("Could not send task to worker: {err:?}");
- };
-
- Ok(())
- }
-
- pub(crate) async fn send_task_output(&mut self, task_response: TaskWorkerOutput) -> Result<()> {
- // remove the task from pending tasks, and get its metadata
- let task_metadata = match task_response.batchable {
- true => {
- self.completed_tasks_batch += 1; // TODO: this should be done in success
- self.pending_tasks_batch.remove(&task_response.row_id)
- }
- false => {
- self.completed_tasks_single += 1; // TODO: this should be done in success
- self.pending_tasks_single.remove(&task_response.row_id)
- }
- };
-
- // respond to the response channel with the result
- match task_metadata {
- Some(task_metadata) => {
- TaskResponder::send_task_output(self, task_response, task_metadata).await?;
- }
- None => {
- // unexpected case, should not happen
- eyre::bail!("Metadata not found for {}", task_response.row_id);
- }
- };
-
- Ok(())
- }
-
- /// Sends a heartbeat request to the configured RPC node.
- #[inline]
- pub(crate) async fn send_heartbeat(&mut self) -> Result<()> {
- let peer_id = self.dria_rpc.peer_id;
- let request_id = HeartbeatRequester::send_heartbeat(self, peer_id).await?;
- log::info!(
- "Sending {} request ({request_id}) to {peer_id}",
- HEARTBEAT_TOPIC.blue()
- );
-
- Ok(())
- }
-
- /// Sends a specs request to the configured RPC node.
- #[inline]
- pub(crate) async fn send_specs(&mut self) -> Result<()> {
- let peer_id = self.dria_rpc.peer_id;
- let specs = self.spec_collector.collect().await;
- let request_id = SpecRequester::send_specs(self, peer_id, specs).await?;
- log::info!(
- "Sending {} request ({request_id}) to {peer_id}",
- SPECS_TOPIC.green()
- );
-
- Ok(())
- }
-}
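The routing removed above stores task metadata in a pending map and then hands the task to the batch or single worker channel. A minimal sketch of that bookkeeping follows, using only the standard library. `Task`, `Router`, and `route` are hypothetical names, and `std::sync::mpsc` stands in for the Tokio channels used by the node. Unlike the removed code, this variant also removes the pending entry again when the send fails, which is a design choice of the sketch and not something the diff does.

```rust
use std::collections::HashMap;
use std::sync::mpsc::{channel, Sender};

/// Hypothetical task: an id plus a flag saying which worker should get it.
struct Task {
    id: u64,
    batchable: bool,
}

/// Hypothetical router holding one optional sender and one pending map per worker kind.
struct Router {
    batch_tx: Option<Sender<Task>>,
    single_tx: Option<Sender<Task>>,
    pending_batch: HashMap<u64, String>,
    pending_single: HashMap<u64, String>,
}

impl Router {
    /// Records the metadata, then sends the task to the matching worker.
    /// The pending entry is rolled back if the worker channel is closed.
    fn route(&mut self, task: Task, metadata: String) -> Result<(), String> {
        let (tx, pending) = if task.batchable {
            (self.batch_tx.as_ref(), &mut self.pending_batch)
        } else {
            (self.single_tx.as_ref(), &mut self.pending_single)
        };
        let tx = tx.ok_or_else(|| "no worker available for this task kind".to_string())?;

        let id = task.id;
        pending.insert(id, metadata);
        if tx.send(task).is_err() {
            pending.remove(&id); // keep the map consistent with what was actually sent
            return Err("worker channel is closed".to_string());
        }
        Ok(())
    }
}
```

Rolling back on a failed send keeps the pending counts (which the heartbeat reports) from drifting upwards when a worker has gone away.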
diff --git a/compute/src/node/rpc.rs b/compute/src/node/rpc.rs
deleted file mode 100644
index 8e3cbeb9..00000000
--- a/compute/src/node/rpc.rs
+++ /dev/null
@@ -1,108 +0,0 @@
-use dkn_p2p::libp2p::{multiaddr::Protocol, Multiaddr, PeerId};
-use dkn_utils::{DriaNetwork, SemanticVersion};
-use eyre::{Context, OptionExt, Result};
-use rand::seq::SliceRandom;
-use std::fmt::Debug;
-
-/// The connected RPC node, as per the Star network topology.
-#[derive(Debug, Clone)]
-pub struct DriaRPC {
- pub addr: Multiaddr,
- pub peer_id: PeerId,
- pub network: DriaNetwork,
-}
-
-impl DriaRPC {
- /// Creates a new RPC target at the given address, along with a network type for refreshing the RPC address.
- pub fn new(addr: Multiaddr, network: DriaNetwork) -> Result<Self> {
- let peer_id = addr
- .iter()
- .find_map(|p| match p {
- Protocol::P2p(peer_id) => Some(peer_id),
- _ => None,
- })
- .ok_or_eyre("did not find peer ID within the returned RPC address")?;
-
- Ok(Self {
- addr,
- peer_id,
- network,
- })
- }
-
- /// Creates a new RPC target for the given network type and version.
- pub async fn new_for_network(network: DriaNetwork, version: &SemanticVersion) -> Result<Self> {
- let addr = get_rpc_for_network(&network, version).await?;
- Self::new(addr, network)
- }
-}
-
-/// Calls the DKN API to get an RPC address for the given network type.
-///
-/// The peer id is expected to be within the multi-address.
-async fn get_rpc_for_network(
- network: &DriaNetwork,
- version: &SemanticVersion,
-) -> Result<Multiaddr> {
- const MIN_MARGIN: usize = 150;
-
- let response = reqwest::get(network.discovery_url(version)).await?;
- let rpcs_and_peer_counts = response
- .json::<Vec<(Multiaddr, usize)>>()
- .await
- .wrap_err("could not parse API response")?;
-
- // ensure that the response contains at least one RPC
- if rpcs_and_peer_counts.is_empty() {
- eyre::bail!("no RPCs were returned by discovery API");
- }
-
- // get the minimum count of peers from all RPCs
- let min_peer_count = rpcs_and_peer_counts
- .iter()
- .map(|(_, peer_count)| *peer_count)
- .min()
- .unwrap(); // safe to unwrap because we checked for empty earlier
-
- // choose the RPCs that have peers in range `[min_peer_count, min_peer_count + MIN_MARGIN]`
- let rpcs_and_peer_counts: Vec<(Multiaddr, usize)> = rpcs_and_peer_counts
- .into_iter()
- .filter(|(_, peer_count)| {
- (min_peer_count..=min_peer_count + MIN_MARGIN).contains(peer_count)
- })
- .collect();
-
- // pick a random RPC from the filtered list
- let chosen_rpc = rpcs_and_peer_counts
- .choose(&mut rand::thread_rng())
- .cloned()
- .map(|(addr, _)| addr)
- .unwrap(); // safe to unwrap because we checked for empty earlier
-
- Ok(chosen_rpc)
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- #[tokio::test]
- async fn test_dria_nodes() {
- let node =
- DriaRPC::new_for_network(DriaNetwork::Mainnet, &SemanticVersion::from_crate_version())
- .await;
- assert!(node.is_ok());
- }
-
- #[test]
- fn test_deserialize() {
- let input = r#"[
- ["/ip4/12.34.56.78/tcp/4001/p2p/16Uiu2HAmG7qrpSh8kenjuYqyrwxgEVdzqRV4wM1hHAZRq4j25VBC", 1],
- ["/ip4/78.56.34.12/tcp/4001/p2p/16Uiu2HAmG7qrpSh8kenjuYqyrwxgEVdzqRV4wM1hHAZRq4j25VBC", 4]
- ]"#;
- let result: Vec<(Multiaddr, usize)> = serde_json::from_str(input).unwrap();
- assert_eq!(result.len(), 2);
- assert_eq!(result[0].1, 1);
- assert_eq!(result[1].1, 4);
- }
-}
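The selection rule in `get_rpc_for_network` keeps the RPCs whose peer count lies within a fixed margin of the least-loaded one, then picks one at random. The filtering step can be sketched on its own with the standard library. `RpcEntry` and `least_loaded` are hypothetical names, addresses are plain strings here instead of `Multiaddr`, and the random pick is left out.

```rust
/// Hypothetical stand-in for the (address, peer count) pairs returned by the discovery API.
type RpcEntry = (String, usize);

/// Keeps only the entries whose peer count lies in `[min, min + margin]`,
/// where `min` is the smallest peer count in the input.
/// Returns an empty vector when the input is empty.
fn least_loaded(entries: &[RpcEntry], margin: usize) -> Vec<RpcEntry> {
    let Some(min) = entries.iter().map(|(_, n)| *n).min() else {
        return Vec::new();
    };
    // saturating_add avoids overflow for very large peer counts
    let upper = min.saturating_add(margin);
    entries
        .iter()
        .filter(|(_, n)| (min..=upper).contains(n))
        .cloned()
        .collect()
}
```

Because the minimum itself always passes the filter, the result is non-empty whenever the input is, which is why the removed code could unwrap the random choice after its emptiness check.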
diff --git a/compute/src/reqres/heartbeat.rs b/compute/src/reqres/heartbeat.rs
deleted file mode 100644
index fb852cc2..00000000
--- a/compute/src/reqres/heartbeat.rs
+++ /dev/null
@@ -1,87 +0,0 @@
-use colored::Colorize;
-use dkn_p2p::libp2p::{request_response::OutboundRequestId, PeerId};
-use dkn_utils::{
- payloads::{HeartbeatRequest, HeartbeatResponse, HEARTBEAT_TOPIC},
- DriaMessage,
-};
-use eyre::{eyre, Result};
-use std::time::Duration;
-use uuid::Uuid;
-
-use super::IsResponder;
-
-use crate::DriaComputeNode;
-
-pub struct HeartbeatRequester;
-
-impl IsResponder for HeartbeatRequester {
- type Request = DriaMessage; // HeartbeatRequest;
- type Response = HeartbeatResponse;
-}
-
-impl HeartbeatRequester {
- /// Any acknowledged heartbeat that is older than this duration is considered dead.
- pub const HEARTBEAT_DEADLINE: Duration = Duration::from_secs(60);
- pub(crate) async fn send_heartbeat(
- node: &mut DriaComputeNode,
- peer_id: PeerId,
- ) -> Result<OutboundRequestId> {
- let uuid = Uuid::now_v7();
- let deadline = chrono::Utc::now() + Self::HEARTBEAT_DEADLINE;
-
- let heartbeat_request = HeartbeatRequest {
- heartbeat_id: uuid,
- deadline,
- pending_batch: node.pending_tasks_batch.len(),
- pending_single: node.pending_tasks_single.len(),
- batch_size: node.config.batch_size,
- };
-
- let heartbeat_message = node.new_message(
- serde_json::to_vec(&heartbeat_request).expect("should be serializable"),
- HEARTBEAT_TOPIC,
- );
- let request_id = node.p2p.request(peer_id, heartbeat_message).await?;
-
- // add it to local heartbeats set
- node.heartbeats_reqs.insert(uuid, deadline);
-
- Ok(request_id)
- }
-
- /// Handles the heartbeat acknowledgement by RPC.
- pub(crate) async fn handle_ack(
- node: &mut DriaComputeNode,
- res: HeartbeatResponse,
- ) -> Result<()> {
- if let Some(deadline) = node.heartbeats_reqs.remove(&res.heartbeat_id) {
- if let Some(err) = res.error {
- Err(eyre!(
- "{} was not acknowledged: {}",
- HEARTBEAT_TOPIC.blue(),
- err
- ))
- } else {
- // acknowledge heartbeat
- node.last_heartbeat_at = chrono::Utc::now();
- node.num_heartbeats += 1;
-
- // for diagnostics, we can check if the heartbeat was past its deadline as well
- if chrono::Utc::now() > deadline {
- log::warn!(
- "Acknowledged {} was past its deadline.",
- HEARTBEAT_TOPIC.blue()
- )
- }
-
- Ok(())
- }
- } else {
- Err(eyre!(
- "Received an unknown {} response with id {}.",
- HEARTBEAT_TOPIC.blue(),
- res.heartbeat_id
- ))
- }
- }
-}
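The heartbeat flow above records each request id with its deadline, then on acknowledgement removes the entry and checks whether the deadline had already passed. A minimal sketch of that bookkeeping, assuming integer ids and `std::time::Instant` in place of UUIDs and `chrono` timestamps; `HeartbeatTracker` and `Ack` are hypothetical names.

```rust
use std::collections::HashMap;
use std::time::{Duration, Instant};

/// Outcome of handling an acknowledgement.
#[derive(Debug, PartialEq)]
enum Ack {
    OnTime,
    Late,
    Unknown,
}

/// Hypothetical tracker for in-flight heartbeats, keyed by request id.
struct HeartbeatTracker {
    pending: HashMap<u64, Instant>,
    acknowledged: u64,
}

impl HeartbeatTracker {
    fn new() -> Self {
        Self { pending: HashMap::new(), acknowledged: 0 }
    }

    /// Remembers the deadline of a heartbeat that was just sent.
    fn record_sent(&mut self, id: u64, now: Instant, ttl: Duration) {
        self.pending.insert(id, now + ttl);
    }

    /// Removes the heartbeat and reports whether it was acknowledged in time.
    /// Late acknowledgements still count, mirroring the warning-only check in the diff.
    fn handle_ack(&mut self, id: u64, now: Instant) -> Ack {
        match self.pending.remove(&id) {
            None => Ack::Unknown,
            Some(deadline) => {
                self.acknowledged += 1;
                if now > deadline { Ack::Late } else { Ack::OnTime }
            }
        }
    }
}
```

Passing `now` in explicitly keeps the tracker deterministic, so deadline handling can be exercised without sleeping.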
diff --git a/compute/src/reqres/mod.rs b/compute/src/reqres/mod.rs
deleted file mode 100644
index 5903cdfb..00000000
--- a/compute/src/reqres/mod.rs
+++ /dev/null
@@ -1,80 +0,0 @@
-//! Request-response handlers.
-
-use eyre::Context;
-use serde::{de::DeserializeOwned, Serialize};
-
-mod specs;
-pub use specs::SpecRequester;
-
-mod task;
-pub use task::TaskResponder;
-
-mod heartbeat;
-pub use heartbeat::HeartbeatRequester;
-
-/// A responder should implement a request & response type, both serializable.
-///
-/// The `try_parse_request` is automatically implemented using `serde-json` for a byte slice.
-pub trait IsResponder {
- type Request: DeserializeOwned;
- type Response: Serialize + DeserializeOwned;
-
- fn try_parse_request(data: &[u8]) -> eyre::Result<Self::Request> {
- serde_json::from_slice(data).wrap_err("could not parse request")
- }
-
- fn try_parse_response(data: &[u8]) -> eyre::Result<Self::Response> {
- serde_json::from_slice(data).wrap_err("could not parse response")
- }
-}
-
-#[cfg(test)]
-mod tests {
-
- use super::*;
-
- // TODO: remove this test when we migrate to enum-based bodies
- #[test]
- fn test_enum_serialization() {
- use serde::Deserialize;
- #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
- struct AEnum {
- a1: bool,
- a2: String,
- }
-
- #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
- struct BEnum {
- b1: u64,
- b2: bool,
- }
-
- #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
- #[serde(tag = "type", rename_all = "camelCase")]
- enum TestEnum {
- A(AEnum),
- B(BEnum),
- }
-
- let a_variant = TestEnum::A(AEnum {
- a1: true,
- a2: "test".to_string(),
- });
- let b_variant = TestEnum::B(BEnum {
- b1: 123456789,
- b2: false,
- });
-
- let a_serialized = serde_json::to_string(&a_variant).unwrap();
- let b_serialized = serde_json::to_string(&b_variant).unwrap();
-
- assert_eq!(a_serialized, r#"{"type":"a","a1":true,"a2":"test"}"#);
- assert_eq!(b_serialized, r#"{"type":"b","b1":123456789,"b2":false}"#);
-
- let a_deserialized: TestEnum = serde_json::from_str(&a_serialized).unwrap();
- let b_deserialized: TestEnum = serde_json::from_str(&b_serialized).unwrap();
-
- assert_eq!(a_variant, a_deserialized);
- assert_eq!(b_variant, b_deserialized);
- }
-}
diff --git a/compute/src/reqres/specs.rs b/compute/src/reqres/specs.rs
deleted file mode 100644
index eb9194c0..00000000
--- a/compute/src/reqres/specs.rs
+++ /dev/null
@@ -1,57 +0,0 @@
-use crate::DriaComputeNode;
-
-use super::IsResponder;
-use colored::Colorize;
-use dkn_p2p::libp2p::{request_response::OutboundRequestId, PeerId};
-use dkn_utils::{
- payloads::{Specs, SpecsRequest, SpecsResponse, SPECS_TOPIC},
- DriaMessage,
-};
-use eyre::{eyre, Result};
-use uuid::Uuid;
-
-pub struct SpecRequester;
-
-impl IsResponder for SpecRequester {
- type Request = DriaMessage; // SpecRequest;
- type Response = SpecsResponse;
-}
-
-impl SpecRequester {
- pub(crate) async fn send_specs(
- node: &mut DriaComputeNode,
- peer_id: PeerId,
- specs: Specs,
- ) -> Result<OutboundRequestId> {
- let uuid = Uuid::now_v7();
- let specs_request = SpecsRequest {
- specs_id: uuid,
- specs,
- address: node.config.address.clone(),
- };
-
- let specs_message = node.new_message(
- serde_json::to_vec(&specs_request).expect("should be serializable"),
- SPECS_TOPIC,
- );
- let request_id = node.p2p.request(peer_id, specs_message).await?;
-
- // add it to local specs set
- node.specs_reqs.insert(uuid);
-
- Ok(request_id)
- }
-
- /// Handles the specs acknowledgement received from the RPC.
- pub(crate) async fn handle_ack(node: &mut DriaComputeNode, res: SpecsResponse) -> Result<()> {
- if node.specs_reqs.remove(&res.specs_id) {
- Ok(())
- } else {
- Err(eyre!(
- "Received an unknown {} response with id {}.",
- SPECS_TOPIC.green(),
- res.specs_id
- ))
- }
- }
-}
diff --git a/compute/src/reqres/task.rs b/compute/src/reqres/task.rs
deleted file mode 100644
index f74ae36f..00000000
--- a/compute/src/reqres/task.rs
+++ /dev/null
@@ -1,283 +0,0 @@
-use colored::Colorize;
-use dkn_executor::{CompletionError, ModelProvider, PromptError, TaskBody};
-use dkn_p2p::libp2p::request_response::ResponseChannel;
-use dkn_utils::payloads::{
- TaskError, TaskRequestPayload, TaskResponsePayload, TaskStats, TASK_RESULT_TOPIC,
-};
-use dkn_utils::DriaMessage;
-use eyre::{Context, Result};
-
-use crate::workers::task::*;
-use crate::DriaComputeNode;
-
-pub struct TaskResponder;
-
-impl super::IsResponder for TaskResponder {
- type Request = DriaMessage; // TODO: can we do this typed?
- type Response = DriaMessage; // TODO: can we do this typed?
-}
-
-impl TaskResponder {
- pub(crate) async fn parse_task_request(
- node: &mut DriaComputeNode,
- compute_message: &DriaMessage,
- channel: ResponseChannel<Vec<u8>>,
- ) -> Result<(TaskWorkerInput, TaskWorkerMetadata)> {
- // parse this in two-steps so that if something goes wrong we know the task id
- let task = compute_message
- .parse_payload::<TaskRequestPayload<serde_json::Value>>()
- .wrap_err("could not parse task request payload")?;
- let task_body = match serde_json::from_value::<TaskBody>(task.input) {
- Ok(task_body) => task_body,
- Err(err) => {
- log::error!(
- "Task {}/{} failed due to parsing error: {err}",
- task.file_id,
- task.row_id,
- );
-
- // prepare error payload
- let error_payload = TaskResponsePayload {
- result: None,
- error: Some(TaskError::ParseError(err.to_string())),
- row_id: task.row_id,
- file_id: task.file_id,
- task_id: task.task_id,
- model: "".to_string(), // no model available due to parsing error
- stats: TaskStats::new(),
- };
-
- let error_payload_str = serde_json::to_string(&error_payload)
- .wrap_err("could not serialize payload")?;
-
- // respond through the channel to notify about the parsing error
- let response = node.new_message(error_payload_str, TASK_RESULT_TOPIC);
- node.p2p.respond(response.into(), channel).await?;
-
- // return with error
- eyre::bail!("could not parse task body: {err}")
- }
- };
-
- let stats = TaskStats::new().record_received_at();
- log::info!(
- "Handling {} {} with model {}",
- "task".yellow(),
- task.row_id,
- task_body.model.to_string().yellow()
- );
-
- // check if the model is available in this node, if so
- // it will return an executor that can run this model
- let executor = node.config.executors.get_executor(&task_body.model).await?;
-
- let task_metadata = TaskWorkerMetadata {
- task_id: task.task_id,
- file_id: task.file_id,
- model: task_body.model,
- channel,
- };
- let task_input = TaskWorkerInput {
- executor,
- task: task_body,
- row_id: task.row_id,
- stats,
- };
-
- Ok((task_input, task_metadata))
- }
-
- /// Handles the result of a task.
- pub(crate) async fn send_task_output(
- node: &mut DriaComputeNode,
- task_output: TaskWorkerOutput,
- task_metadata: TaskWorkerMetadata,
- ) -> Result<()> {
- let response = match task_output.result {
- Ok(result) => {
- // prepare signed and encrypted payload
- log::info!(
- "Publishing {} result for {}/{}",
- "task".yellow(),
- task_metadata.file_id,
- task_output.row_id
- );
-
- // TODO: will get better token count from `TaskWorkerOutput`
- let token_count = result.len();
- let payload = TaskResponsePayload {
- result: Some(result),
- error: None,
- file_id: task_metadata.file_id,
- task_id: task_metadata.task_id,
- row_id: task_output.row_id,
- model: task_metadata.model.to_string(),
- stats: task_output
- .stats
- .record_published_at()
- .record_token_count(token_count),
- };
- let payload_str =
- serde_json::to_string(&payload).wrap_err("could not serialize payload")?;
-
- node.new_message(payload_str, TASK_RESULT_TOPIC)
- }
- Err(err) => {
- // use pretty display string for error logging with causes
- log::error!(
- "Task {}/{} failed: {:#}",
- task_metadata.file_id,
- task_output.row_id,
- err
- );
-
- // prepare error payload
- let error_payload = TaskResponsePayload {
- result: None,
- error: Some(map_prompt_error_to_task_error(
- task_metadata.model.provider(),
- err,
- )),
- row_id: task_output.row_id,
- file_id: task_metadata.file_id,
- task_id: task_metadata.task_id,
- model: task_metadata.model.to_string(),
- stats: task_output
- .stats
- .record_published_at()
- .record_token_count(0),
- };
- let error_payload_str = serde_json::to_string(&error_payload)
- .wrap_err("could not serialize payload")?;
-
- node.new_message(error_payload_str, TASK_RESULT_TOPIC)
- }
- };
-
- // respond through the channel
- node.p2p
- .respond(response.into(), task_metadata.channel)
- .await?;
-
- Ok(())
- }
-}
-
-/// Maps a [`PromptError`] to a [`TaskError`] with respect to the given provider.
-fn map_prompt_error_to_task_error(provider: ModelProvider, err: PromptError) -> TaskError {
- match &err {
- // if the error is a provider error, we can try to parse it
- PromptError::CompletionError(CompletionError::ProviderError(err_inner)) => {
- /// A wrapper for `{ error: T }` to match the provider error format.
- #[derive(Clone, serde::Deserialize)]
- struct ErrorObject<T> {
- error: T,
- }
-
- match provider {
- // ModelProvider::Gemini => {
- // /// Gemini API [error object](https://github.com/googleapis/go-genai/blob/main/api_client.go#L273).
- // #[derive(Clone, serde::Deserialize)]
- // pub struct GeminiError {
- // code: u32,
- // message: String,
- // status: String,
- // }
-
- // serde_json::from_str::<ErrorObject<GeminiError>>(err_inner).map(
- // |ErrorObject {
- // error: gemini_error,
- // }| TaskError::ProviderError {
- // code: format!("{} ({})", gemini_error.code, gemini_error.status),
- // message: gemini_error.message,
- // provider: provider.to_string(),
- // },
- // )
- // }
- // ModelProvider::OpenAI => {
- // /// OpenAI API [error object](https://github.com/openai/openai-go/blob/main/internal/apierror/apierror.go#L17).
- // #[derive(Clone, serde::Deserialize)]
- // pub struct OpenAIError {
- // code: String,
- // message: String,
- // }
-
- // serde_json::from_str::<ErrorObject<OpenAIError>>(err_inner).map(
- // |ErrorObject {
- // error: openai_error,
- // }| TaskError::ProviderError {
- // code: openai_error.code,
- // message: openai_error.message,
- // provider: provider.to_string(),
- // },
- // )
- // }
- // ModelProvider::OpenRouter => {
- // /// OpenRouter API [error object](https://openrouter.ai/docs/api-reference/errors).
- // #[derive(Clone, serde::Deserialize)]
- // pub struct OpenRouterError {
- // code: u32,
- // message: String,
- // }
-
- // serde_json::from_str::<ErrorObject<OpenRouterError>>(err_inner).map(
- // |ErrorObject {
- // error: openrouter_error,
- // }| {
- // TaskError::ProviderError {
- // code: openrouter_error.code.to_string(),
- // message: openrouter_error.message,
- // provider: provider.to_string(),
- // }
- // },
- // )
- // }
- ModelProvider::Ollama => serde_json::from_str::<ErrorObject<String>>(err_inner)
- .map(
- // Ollama just returns a string error message
- |ErrorObject {
- error: ollama_error,
- }| {
- // based on the error message, we can come up with our own "dummy" codes
- let code = if ollama_error.contains("server busy, please try again.") {
- "server_busy"
- } else if ollama_error.contains("model requires more system memory") {
- "model_requires_more_memory"
- } else if ollama_error.contains("cudaMalloc failed: out of memory") {
- "cuda_malloc_failed"
- } else if ollama_error.contains("CUDA error: out of memory") {
- "cuda_oom"
- } else if ollama_error.contains("API Error: Too Many Requests") {
- "api:too_many_requests"
- } else if ollama_error.contains("API Error: Bad Request") {
- "api:bad_request"
- } else if ollama_error.contains("not found, try pulling it first") {
- "model_not_pulled"
- } else if ollama_error.contains("Unexpected end of JSON input") {
- "unexpected_end_of_json"
- } else {
- "unknown"
- };
-
- TaskError::ProviderError {
- code: code.to_string(),
- message: ollama_error,
- provider: provider.to_string(),
- }
- },
- ),
- }
- // if we couldn't parse it, just return a generic prompt error
- .unwrap_or(TaskError::ExecutorError(format!(
- "{provider} executor error: {}",
- err_inner.clone()
- )))
- }
- // if it's an HTTP error, we can try to parse it as well
- PromptError::CompletionError(CompletionError::HttpError(err_inner)) => {
- TaskError::HttpError(err_inner.to_string())
- }
- // if it's not a completion error, we just return the error as is
- err => TaskError::Other(err.to_string()),
- }
-}
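The Ollama branch above derives a short error code from substrings of the provider's message. The same mapping can be written as an ordered lookup table, sketched here with the patterns and codes copied from the removed `if`/`else if` chain. `OLLAMA_ERROR_CODES` and `classify_ollama_error` are hypothetical names, not part of the crate.

```rust
/// Substring patterns paired with codes, checked in order; the first match wins.
const OLLAMA_ERROR_CODES: &[(&str, &str)] = &[
    ("server busy, please try again.", "server_busy"),
    ("model requires more system memory", "model_requires_more_memory"),
    ("cudaMalloc failed: out of memory", "cuda_malloc_failed"),
    ("CUDA error: out of memory", "cuda_oom"),
    ("API Error: Too Many Requests", "api:too_many_requests"),
    ("API Error: Bad Request", "api:bad_request"),
    ("not found, try pulling it first", "model_not_pulled"),
    ("Unexpected end of JSON input", "unexpected_end_of_json"),
];

/// Returns the code of the first pattern contained in `message`, or "unknown".
fn classify_ollama_error(message: &str) -> &'static str {
    OLLAMA_ERROR_CODES
        .iter()
        .find(|(pattern, _)| message.contains(*pattern))
        .map(|(_, code)| *code)
        .unwrap_or("unknown")
}
```

A table keeps the patterns in one place and makes the match order explicit, at the cost of a linear scan that is negligible for a handful of entries.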
diff --git a/compute/src/utils/mod.rs b/compute/src/utils/mod.rs
deleted file mode 100644
index e5be541e..00000000
--- a/compute/src/utils/mod.rs
+++ /dev/null
@@ -1,5 +0,0 @@
-mod specs;
-pub use specs::*;
-
-mod points;
-pub use points::*;
diff --git a/compute/src/utils/points.rs b/compute/src/utils/points.rs
deleted file mode 100644
index ac07bdf3..00000000
--- a/compute/src/utils/points.rs
+++ /dev/null
@@ -1,85 +0,0 @@
-use dkn_utils::DriaNetwork;
-use eyre::Context;
-
-pub struct DriaPointsClient {
- pub url: String,
- client: reqwest::Client,
- /// The total number of points you have accumulated at the start of the run.
- pub initial: f64,
-}
-
-#[derive(Debug, serde::Deserialize)]
-pub struct DriaPoints {
- /// Indicates in which top percentile your points are.
- pub percentile: usize,
- /// The total number of points you have accumulated.
- pub score: f64,
-}
-
-impl DriaPointsClient {
- /// The base URL for the points API, w.r.t network.
- pub fn base_url(network: &DriaNetwork) -> &'static str {
- match network {
- DriaNetwork::Mainnet => "https://mainnet.dkn.dria.co/points/v0/total/node/",
- DriaNetwork::Testnet => "https://testnet.dkn.dria.co/points/v0/total/node/",
- }
- }
-
- /// Creates a new `DriaPointsClient` for the given address.
- pub fn new(address: &str, network: &DriaNetwork) -> eyre::Result<Self> {
- const USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"));
-
- let url = format!(
- "{}/0x{}",
- Self::base_url(network),
- address.trim_start_matches("0x")
- );
-
- let client = reqwest::Client::builder()
- .user_agent(USER_AGENT)
- .build()
- .wrap_err("could not create Points client")?;
-
- Ok(Self {
- url,
- client,
- initial: 0.0,
- })
- }
-
- /// Sets the initial points to the current points.
- ///
- /// If there is an error, it sets to 0.0.
- pub async fn initialize(&mut self) {
- self.initial = self.get_points().await.map(|p| p.score).unwrap_or_default();
- }
-
- pub async fn get_points(&self) -> eyre::Result<DriaPoints> {
- let res = self
- .client
- .get(&self.url)
- .send()
- .await
- .wrap_err("could not make request")?;
- res.json::<DriaPoints>()
- .await
- .wrap_err("could not parse response")
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- #[tokio::test]
- async fn test_get_points() {
- let client = DriaPointsClient::new(
- "0xa43536a6032a3907ccf60e8109429ee1047b207c",
- &DriaNetwork::Mainnet,
- )
- .unwrap();
- let steps = client.get_points().await.unwrap();
- assert!(steps.score >= 0.0);
- assert!(steps.percentile <= 100);
- }
-}
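In `DriaPointsClient::new`, the base URLs end with `/` while the format string adds another `/` before the address, so the resulting URL contains `//` before `0x`. Whether the API tolerates that is not visible from the diff. A minimal sketch of a join that normalizes both sides follows; `points_url` is a hypothetical helper and the host in the usage note is a placeholder.

```rust
/// Joins a base URL and an address into `<base>/0x<address>`,
/// tolerating a trailing slash on the base and an optional `0x` prefix on the address.
fn points_url(base: &str, address: &str) -> String {
    format!(
        "{}/0x{}",
        base.trim_end_matches('/'),
        address.trim_start_matches("0x")
    )
}
```

For example, `points_url("https://example.invalid/points/", "0xabc")` and `points_url("https://example.invalid/points", "abc")` both produce `https://example.invalid/points/0xabc`.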
diff --git a/compute/src/utils/specs.rs b/compute/src/utils/specs.rs
deleted file mode 100644
index 837cfab3..00000000
--- a/compute/src/utils/specs.rs
+++ /dev/null
@@ -1,120 +0,0 @@
-use dkn_executor::Model;
-use dkn_p2p::libp2p::PeerId;
-use dkn_utils::{
- payloads::{SpecModelPerformance, Specs},
- SemanticVersion,
-};
-use std::collections::HashMap;
-use sysinfo::{CpuRefreshKind, MemoryRefreshKind, RefreshKind};
-
-pub struct SpecCollector {
- /// System information object, this is expected to be created only once
- /// as per the [docs](https://github.com/GuillaumeGomez/sysinfo?tab=readme-ov-file#good-practice--performance-tips).
- system: sysinfo::System,
- /// Used models.
- models: Vec<String>,
- /// Model performances
- model_perf: HashMap<String, SpecModelPerformance>,
- /// Version string.
- version: String,
- /// Execution platform, mainly for diagnostics.
- exec_platform: String,
- /// Peer ID of the node, used for identification in the network.
- peer_id: String,
- // GPU adapter infos, showing information about the available GPUs.
- // gpus: Vec<wgpu::AdapterInfo>,
-}
-
-impl SpecCollector {
- pub fn new(
- models: Vec<String>,
- model_perf: HashMap<Model, SpecModelPerformance>,
- version: SemanticVersion,
- exec_platform: String,
- peer_id: PeerId,
- ) -> Self {
- log::info!("Creating spec collector with version {version} and platform {exec_platform} and models {models:?}");
- SpecCollector {
- system: sysinfo::System::new_with_specifics(Self::get_refresh_specifics()),
- models,
- model_perf: model_perf
- .into_iter()
- .map(|(k, v)| (k.to_string(), v))
- .collect(),
- version: version.to_string(),
- exec_platform,
- peer_id: peer_id.to_string(),
- // gpus: wgpu::Instance::default()
- // .enumerate_adapters(wgpu::Backends::all())
- // .into_iter()
- // .map(|a| a.get_info())
- // .collect(),
- }
- }
-
- /// Returns the selected refresh kinds. It is important to ignore
- /// process values here because it will consume a lot of file-descriptors.
- #[inline(always)]
- fn get_refresh_specifics() -> RefreshKind {
- RefreshKind::nothing()
- .with_cpu(CpuRefreshKind::everything())
- .with_memory(MemoryRefreshKind::everything())
- }
-
- pub async fn collect(&mut self) -> Specs {
- self.system.refresh_specifics(Self::get_refresh_specifics());
-
- Specs {
- total_mem: self.system.total_memory(),
- free_mem: self.system.free_memory(),
- num_cpus: self.system.physical_core_count(),
- cpu_usage: self.system.global_cpu_usage(),
- os: std::env::consts::OS.to_string(),
- arch: std::env::consts::ARCH.to_string(),
- lookup: public_ip_address::perform_lookup(None).await.ok(),
- models: self.models.clone(),
- version: self.version.clone(),
- model_perf: self.model_perf.clone(),
- exec_platform: Some(self.exec_platform.clone()),
- peer_id: Some(self.peer_id.clone()),
- // gpus: self.gpus.clone(),
- }
- }
-}
-#[cfg(test)]
-mod tests {
- use super::*;
-
- #[tokio::test]
- async fn test_specs_serialization() {
- let mut spec_collector = SpecCollector::new(
- vec![Model::Gemma3_4b.to_string()],
- HashMap::from_iter([
- (Model::Gemma3_4b, SpecModelPerformance::PassedWithTPS(100.0)),
- (Model::Gemma3_27b, SpecModelPerformance::ExecutionFailed),
- ]),
- SemanticVersion {
- major: 4,
- minor: 5,
- patch: 1,
- },
- "testing".to_string(),
- PeerId::random(),
- );
- let specs = spec_collector.collect().await;
- assert!(specs.total_mem > 0);
- assert!(specs.free_mem > 0);
- assert!(specs.num_cpus.is_some());
- assert!(specs.cpu_usage > 0.0);
- assert!(!specs.os.is_empty());
- assert!(!specs.arch.is_empty());
- assert!(specs.lookup.is_some());
- assert!(!specs.models.is_empty());
- assert_eq!(specs.model_perf.len(), 2);
- assert_eq!(specs.version, "4.5.1");
- assert_eq!(specs.exec_platform, Some("testing".to_string()));
-
- // should be serializable to JSON
- assert!(serde_json::to_string_pretty(&specs).is_ok())
- }
-}
diff --git a/compute/src/workers/mod.rs b/compute/src/workers/mod.rs
deleted file mode 100644
index cdafe4ad..00000000
--- a/compute/src/workers/mod.rs
+++ /dev/null
@@ -1 +0,0 @@
-pub mod task;
diff --git a/compute/src/workers/task.rs b/compute/src/workers/task.rs
deleted file mode 100644
index 4515951f..00000000
--- a/compute/src/workers/task.rs
+++ /dev/null
@@ -1,312 +0,0 @@
-use colored::Colorize;
-use dkn_executor::{DriaExecutor, Model, TaskBody};
-use dkn_p2p::libp2p::request_response::ResponseChannel;
-use dkn_utils::payloads::TaskStats;
-use tokio::sync::mpsc;
-use uuid::Uuid;
-
-/// A metadata object that is kept aside while the worker is doing its job.
-///
-/// This is put into a map before execution, and then removed after the task is done.
-pub struct TaskWorkerMetadata {
- pub model: Model,
- pub task_id: String,
- pub file_id: Uuid,
- /// If for any reason this object is dropped before `channel` is responded to,
- /// the task will be lost and the channel will be abruptly closed, causing an error on
- /// both the responder and the requester side, likely with an `OmissionError`.
- pub channel: ResponseChannel<Vec<u8>>,
-}
-
-pub struct TaskWorkerInput {
- /// used as identifier for metadata
- pub row_id: Uuid,
- // actual consumed input
- pub executor: DriaExecutor,
- pub task: TaskBody,
- // piggybacked metadata
- pub stats: TaskStats,
-}
-
-pub struct TaskWorkerOutput {
- // used as identifier for metadata
- pub row_id: Uuid,
- // actual produced output
- pub result: Result<String, dkn_executor::PromptError>,
- // piggybacked metadata
- pub stats: TaskStats,
- pub batchable: bool,
-}
-
-/// It is expected to be spawned in another thread, with [`Self::run_batch`] for batch processing and [`Self::run_series`] for single processing.
-pub struct TaskWorker {
- /// Task channel receiver, the sender is most likely the compute node itself.
- task_rx: mpsc::Receiver<TaskWorkerInput>,
- /// Publish message channel sender, the receiver is most likely the compute node itself.
- publish_tx: mpsc::Sender<TaskWorkerOutput>,
- // TODO: batch size must be defined here
-}
-
-/// Buffer size for task channels (per worker).
-const TASK_RX_CHANNEL_BUFSIZE: usize = 1024;
-
-impl TaskWorker {
- /// Batch size that defines how many tasks can be executed concurrently at once.
- ///
- /// The [`Self::run_batch`] function is written to handle batches up to this size specifically;
- /// if it is called with a batch size larger than this, it will panic.
- pub const MAX_BATCH_SIZE: usize = 8;
-
- /// Creates a worker and returns the sender and receiver for the worker.
- pub fn new(
- publish_tx: mpsc::Sender<TaskWorkerOutput>,
- ) -> (TaskWorker, mpsc::Sender<TaskWorkerInput>) {
- let (task_tx, task_rx) = mpsc::channel(TASK_RX_CHANNEL_BUFSIZE);
-
- let worker = TaskWorker {
- task_rx,
- publish_tx,
- };
-
- (worker, task_tx)
- }
-
- /// Closes the worker's receiver channel.
- fn shutdown(&mut self) {
- log::info!("Closing worker.");
- self.task_rx.close();
- }
-
- /// Launches the thread that can process tasks one by one (in series).
- /// This function will block until the channel is closed.
- ///
- /// It is suitable for task streams that consume local resources, unlike API calls.
- pub async fn run_series(&mut self) {
- loop {
- let task = self.task_rx.recv().await;
-
- if let Some(task) = task {
- log::info!("Processing {} (single)", "task".yellow(),);
- TaskWorker::execute((task, &self.publish_tx)).await
- } else {
- return self.shutdown();
- };
- }
- }
-
- /// Launches the thread that can process tasks in batches.
- /// This function will block until the channel is closed.
- ///
- /// It is suitable for task streams that make use of API calls, unlike Ollama-like
- /// tasks that consume local resources and would not make sense to run in parallel.
- ///
- /// Batch size must NOT be larger than `MAX_BATCH_SIZE`, otherwise will panic.
- pub async fn run_batch(&mut self, batch_size: usize) {
- assert!(
- batch_size <= Self::MAX_BATCH_SIZE,
- "Batch size must not be larger than {}",
- Self::MAX_BATCH_SIZE
- );
-
- loop {
- let mut tasks = Vec::new();
-
- // get tasks in batch from the channel, we enter the loop if:
- // (1) there are no tasks, or,
- // (2) there are tasks less than the batch size and the channel is not empty
- while tasks.is_empty() || (tasks.len() < batch_size && !self.task_rx.is_empty()) {
- log::info!(
- "Worker is waiting for tasks ({} < {})",
- tasks.len(),
- batch_size
- );
- let limit = batch_size - tasks.len();
- match self.task_rx.recv_many(&mut tasks, limit).await {
- // 0 tasks returned means that the channel is closed
- 0 => return self.shutdown(),
- _ => {
- // wait a small amount of time to allow for more tasks to be sent into the channel
- tokio::time::sleep(std::time::Duration::from_millis(256)).await;
- }
- }
- }
-
- // process the batch
- let num_tasks = tasks.len();
- debug_assert!(
- num_tasks <= batch_size,
- "number of tasks cant be larger than batch size"
- );
- debug_assert!(num_tasks != 0, "number of tasks cant be zero");
-
- log::info!("Processing {num_tasks} tasks in batch");
- let mut batch = tasks.into_iter().map(|b| (b, &self.publish_tx));
- match num_tasks {
- 1 => {
- TaskWorker::execute(batch.next().unwrap()).await;
- }
- 2 => {
- tokio::join!(
- TaskWorker::execute(batch.next().unwrap()),
- TaskWorker::execute(batch.next().unwrap())
- );
- }
- 3 => {
- tokio::join!(
- TaskWorker::execute(batch.next().unwrap()),
- TaskWorker::execute(batch.next().unwrap()),
- TaskWorker::execute(batch.next().unwrap())
- );
- }
- 4 => {
- tokio::join!(
- TaskWorker::execute(batch.next().unwrap()),
- TaskWorker::execute(batch.next().unwrap()),
- TaskWorker::execute(batch.next().unwrap()),
- TaskWorker::execute(batch.next().unwrap())
- );
- }
- 5 => {
- tokio::join!(
- TaskWorker::execute(batch.next().unwrap()),
- TaskWorker::execute(batch.next().unwrap()),
- TaskWorker::execute(batch.next().unwrap()),
- TaskWorker::execute(batch.next().unwrap()),
- TaskWorker::execute(batch.next().unwrap())
- );
- }
- 6 => {
- tokio::join!(
- TaskWorker::execute(batch.next().unwrap()),
- TaskWorker::execute(batch.next().unwrap()),
- TaskWorker::execute(batch.next().unwrap()),
- TaskWorker::execute(batch.next().unwrap()),
- TaskWorker::execute(batch.next().unwrap()),
- TaskWorker::execute(batch.next().unwrap())
- );
- }
- 7 => {
- tokio::join!(
- TaskWorker::execute(batch.next().unwrap()),
- TaskWorker::execute(batch.next().unwrap()),
- TaskWorker::execute(batch.next().unwrap()),
- TaskWorker::execute(batch.next().unwrap()),
- TaskWorker::execute(batch.next().unwrap()),
- TaskWorker::execute(batch.next().unwrap()),
- TaskWorker::execute(batch.next().unwrap())
- );
- }
- 8 => {
- tokio::join!(
- TaskWorker::execute(batch.next().unwrap()),
- TaskWorker::execute(batch.next().unwrap()),
- TaskWorker::execute(batch.next().unwrap()),
- TaskWorker::execute(batch.next().unwrap()),
- TaskWorker::execute(batch.next().unwrap()),
- TaskWorker::execute(batch.next().unwrap()),
- TaskWorker::execute(batch.next().unwrap()),
- TaskWorker::execute(batch.next().unwrap())
- );
- }
- _ => {
- unreachable!(
- "number of tasks cant be larger than batch size ({} > {})",
- num_tasks,
- Self::MAX_BATCH_SIZE
- );
- }
- };
- }
- }
-
- /// Executes a single task, and publishes the output.
- pub async fn execute(
- (mut input, publish_tx): (TaskWorkerInput, &mpsc::Sender<TaskWorkerOutput>),
- ) {
- let batchable = input.task.is_batchable();
- input.stats = input.stats.record_execution_started_at();
- let result = input.executor.execute(input.task).await;
- input.stats = input.stats.record_execution_ended_at();
-
- let output = TaskWorkerOutput {
- result,
- row_id: input.row_id,
- batchable,
- stats: input.stats,
- };
-
- if let Err(err) = publish_tx.send(output).await {
- log::error!("Error sending task result: {err}");
- }
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
- use dkn_executor::{DriaExecutor, Model};
-
- /// Tests the worker with a single task sent within a batch.
- ///
- /// ## Run command
- ///
- /// ```sh
- /// cargo test --package dkn-compute --lib --all-features -- workers::task::tests::test_executor_worker --exact --show-output --nocapture --ignored
- /// ```
- #[tokio::test]
- #[ignore = "run manually with Ollama"]
- async fn test_executor_worker() {
- let _ = env_logger::builder()
- .filter_level(log::LevelFilter::Off)
- .filter_module("dkn_compute", log::LevelFilter::Debug)
- .is_test(true)
- .try_init();
-
- let (publish_tx, mut publish_rx) = mpsc::channel(1024);
- let (mut worker, task_tx) = TaskWorker::new(publish_tx);
-
- // create batch worker
- let worker_handle = tokio::spawn(async move {
- worker.run_batch(4).await;
- });
-
- let num_tasks = 4;
- let model = Model::Llama3_2_1bInstructQ4Km;
- let executor = DriaExecutor::new_from_env(model.provider()).unwrap();
- let task = TaskBody::new_prompt("Write a poem about Julius Caesar.", model.clone());
-
- for i in 0..num_tasks {
- log::info!("Sending task {}", i + 1);
-
- let task_input = TaskWorkerInput {
- executor: executor.clone(),
- task: task.clone(),
- // dummy variables
- row_id: Uuid::now_v7(),
- stats: TaskStats::default(),
- };
-
- // send task to worker
- task_tx.send(task_input).await.unwrap();
- }
-
- // now wait for all results
- let mut results = Vec::new();
- for i in 0..num_tasks {
- log::info!("Waiting for result {}", i + 1);
- let result = publish_rx.recv().await.unwrap();
- log::info!("Got result {}", i + 1,);
- if result.result.is_err() {
- log::error!("Error: {:?}", result.result);
- }
- results.push(result);
- }
-
- log::info!("Got all results, closing channel.");
- publish_rx.close();
-
- // FIXME: this bugs out
- worker_handle.await.unwrap();
- log::info!("Done.");
- }
-}
diff --git a/dnet.art b/dnet.art
new file mode 100644
index 00000000..f0c6c6c2
--- /dev/null
+++ b/dnet.art
@@ -0,0 +1,18 @@
+
+
+Dria [0.7.2]
+Decentralized LLM inference
+
+ .-----. .-.--.. .-..=+--. ----. ..... +.
+ ==-. -==+##==#@#+- -+#@@+====== .==-. =###@+--.....=@o+. ==-...===++=----
+ =#- .ooooo+ .oooo= -@#- .#ooooo- #@- @ooo# @+. -o#=. -oooo= +o=
+=o .oooo-. oooo@ =#oooooo= ## .@ooo+ @- +@- .oooo- @+
+@# .oooo. +oooo- ++ @oooo= ## . .oooo+ .@. --.++ .oooo- =
+-@- -oooo. +oooo= += @oooo= ## -# .oooo+.--=o= -oooo.
+ ----+ooo@. #ooo@- #= @oooo= @# .@ -oooo- ++ . =oooo.
+ =ooo@ .#ooo#. #- ooooo@@+ ++-oooo- =#. #oooo=
+ =ooo@ -@oo@#. #- =@oooo# @oooo= .#+ #oooo=
+ .=#ooo@ +###-. .-@+ +ooo+ .+oooooo-...-#o= --#oooo@.
+
+ https://dria.co/edge-ai
+ Made with <3
\ No newline at end of file
diff --git a/docs/NODE_SPECS.md b/docs/NODE_SPECS.md
deleted file mode 100644
index 4234d9bc..00000000
--- a/docs/NODE_SPECS.md
+++ /dev/null
@@ -1,254 +0,0 @@
-# 🚀 LLM Node Runner's Guide: Minimum Specs
-
-Hello, Drians! 👋 Here's a guide to help you understand the minimum specs needed for running different LLMs. We've broken it down into two main categories: (1) **GPU-enabled** nodes and (2) **CPU-only** nodes, as you can run your nodes on machines both _with_ or _without_ GPU.
-
-- ## 🖥️ GPU-Enabled Nodes
-
-### RTX3090 Single GPU:
-
-| Model | TPS |
-| ----------------------------------- | -------- |
-| finalend/hermes-3-llama-3.1:8b-q8_0 | 76.4388 |
-| phi3:14b-medium-4k-instruct-q4_1 | 75.6148 |
-| phi3:14b-medium-128k-instruct-q4_1 | 76.0658 |
-| phi3.5:3.8b | 195.0728 |
-| phi3.5:3.8b-mini-instruct-fp16 | 88.4656 |
-| gemma2:9b-instruct-q8_0 | 56.2726 |
-| gemma2:9b-instruct-fp16 | 37.9404 |
-| llama3.1:latest | 103.3473 |
-| llama3.1:8b-instruct-q8_0 | 78.5861 |
-| llama3.1:8b-instruct-fp16 | 50.9302 |
-| llama3.1:8b-text-q4_K_M | 104.4776 |
-| llama3.1:8b-text-q8_0 | 82.3980 |
-| llama3.2:1b | 293.1785 |
-| llama3.2:3b | 168.7500 |
-| llama3.2:1b-text-q4_K_M | 349.2497 |
-| qwen2.5:7b-instruct-q5_0 | 114.0511 |
-| qwen2.5:7b-instruct-fp16 | 53.5423 |
-| qwen2.5-coder:1.5b | 238.6117 |
-| qwen2.5-coder:7b-instruct | 125.2194 |
-| qwen2.5-coder:7b-instruct-q8_0 | 83.7696 |
-| qwen2.5-coder:7b-instruct-fp16 | 53.7400 |
-| qwq | 33.4434 |
-| deepseek-coder:6.7b | 141.7769 |
-| deepseek-r1:1.5b | 235.8560 |
-| deepseek-r1:7b | 121.9637 |
-| deepseek-r1:8b | 107.5933 |
-| deepseek-r1:14b | 66.5972 |
-| deepseek-r1:32b | 34.4669 |
-| deepseek-r1 | 120.9809 |
-| driaforall/tiny-agent-a:0.5b | 279.2553 |
-| driaforall/tiny-agent-a:1.5b | 201.7011 |
-| driaforall/tiny-agent-a:3b | 135.1052 |
-
-### H200 SXM Single GPU:
-
-| Model | TPS |
-| ----------------------------------- | -------- |
-| finalend/hermes-3-llama-3.1:8b-q8_0 | 121.2871 |
-| phi3:14b-medium-4k-instruct-q4_1 | 128.9496 |
-| phi3:14b-medium-128k-instruct-q4_1 | 124.4223 |
-| phi3.5:3.8b | 184.3729 |
-| phi3.5:3.8b-mini-instruct-fp16 | 155.6164 |
-| gemma2:9b-instruct-q8_0 | 91.6370 |
-| gemma2:9b-instruct-fp16 | 85.6672 |
-| llama3.1:latest | 123.8938 |
-| llama3.1:8b-instruct-q8_0 | 112.3102 |
-| llama3.1:8b-instruct-fp16 | 108.9053 |
-| llama3.1:8b-text-q4_K_M | 148.0687 |
-| llama3.1:8b-text-q8_0 | 135.3251 |
-| llama3.1:70b-instruct-q4_0 | 47.0107 |
-| llama3.1:70b-instruct-q8_0 | 35.2827 |
-| llama3.2:1b | 163.9058 |
-| llama3.2:3b | 150.6063 |
-| llama3.3:70b | 39.1993 |
-| llama3.2:1b-text-q4_K_M | 233.6957 |
-| qwen2.5:7b-instruct-q5_0 | 126.5432 |
-| qwen2.5:7b-instruct-fp16 | 103.8552 |
-| qwen2.5:32b-instruct-fp16 | 40.3735 |
-| qwen2.5-coder:1.5b | 187.3554 |
-| qwen2.5-coder:7b-instruct | 119.7279 |
-| qwen2.5-coder:7b-instruct-q8_0 | 108.9536 |
-| qwen2.5-coder:7b-instruct-fp16 | 104.0222 |
-| qwq | 59.4734 |
-| deepseek-coder:6.7b | 136.8015 |
-| mixtral:8x7b | 94.9618 |
-| deepseek-r1:1.5b | 160.8217 |
-| deepseek-r1:7b | 141.2172 |
-| deepseek-r1:8b | 136.8324 |
-| deepseek-r1:14b | 90.3022 |
-| deepseek-r1:32b | 63.1900 |
-| deepseek-r1:70b | 39.4153 |
-| deepseek-r1 | 121.8406 |
-| driaforall/tiny-agent-a:0.5b | 148.5390 |
-| driaforall/tiny-agent-a:1.5b | 180.9409 |
-| driaforall/tiny-agent-a:3b | 111.1869 |
-
-- ## 💻 CPU-Only Nodes
-
-For those running without a GPU, we've got you covered too! Here are the specs for different CPU types:
-
-### AMD (8 CPU, 16GB RAM)
-
-| Model | TPS |
-| ---------------------------- | ------- |
-| llama3.2:1b | 22.6293 |
-| llama3.2:1b-text-q4_K_M | 25.0413 |
-| qwen2.5-coder:1.5b | 21.7418 |
-| deepseek-r1:1.5b | 29.7842 |
-| driaforall/tiny-agent-a:0.5b | 54.5455 |
-| driaforall/tiny-agent-a:1.5b | 19.9501 |
-
-### AMD (16 CPU, 32GB RAM)
-
-| Model | TPS |
-| ---------------------------- | ------- |
-| phi3.5:3.8b | 15.3677 |
-| llama3.2:1b | 25.6367 |
-| llama3.2:3b | 16.3185 |
-| llama3.2:1b-text-q4_K_M | 38.0039 |
-| qwen2.5-coder:1.5b | 30.3651 |
-| deepseek-r1:1.5b | 30.2977 |
-| driaforall/tiny-agent-a:0.5b | 61.2553 |
-| driaforall/tiny-agent-a:1.5b | 25.7011 |
-
-### AMD (32 CPU, 64GB RAM)
-
-| Model | TPS |
-| ---------------------------- | ------- |
-| phi3.5:3.8b | 22.9944 |
-| llama3.2:1b | 40.6091 |
-| llama3.2:3b | 26.0240 |
-| llama3.2:1b-text-q4_K_M | 56.2027 |
-| qwen2.5-coder:1.5b | 44.6331 |
-| deepseek-coder:6.7b | 15.1620 |
-| deepseek-r1:1.5b | 43.8323 |
-| driaforall/tiny-agent-a:0.5b | 59.9854 |
-| driaforall/tiny-agent-a:1.5b | 27.7891 |
-
-### AMD (48 CPU, 96GB RAM)
-
-| Model | TPS |
-| ---------------------------- | ------- |
-| phi3.5:3.8b | 29.7455 |
-| llama3.1:latest | 17.4744 |
-| llama3.1:8b-text-q4_K_M | 18.1928 |
-| llama3.2:1b | 49.1555 |
-| llama3.2:3b | 33.9283 |
-| llama3.2:1b-text-q4_K_M | 72.7273 |
-| qwen2.5:7b-instruct-q5_0 | 17.0779 |
-| qwen2.5-coder:1.5b | 56.2710 |
-| qwen2.5-coder:7b-instruct | 18.2935 |
-| deepseek-coder:6.7b | 21.2014 |
-| deepseek-r1:1.5b | 55.0080 |
-| deepseek-r1:7b | 18.0150 |
-| deepseek-r1:8b | 16.4574 |
-| deepseek-r1 | 18.0991 |
-| driaforall/tiny-agent-a:0.5b | 86.2903 |
-| driaforall/tiny-agent-a:1.5b | 41.6198 |
-| driaforall/tiny-agent-a:3b | 24.1364 |
-
-### AMD (64 CPU, 128GB RAM)
-
-| Model | TPS |
-| ---------------------------- | ------- |
-| phi3.5:3.8b | 33.8993 |
-| llama3.1:latest | 19.3015 |
-| llama3.1:8b-text-q4_K_M | 19.9081 |
-| llama3.2:1b | 55.6815 |
-| llama3.2:3b | 36.6654 |
-| llama3.2:1b-text-q4_K_M | 68.9655 |
-| qwen2.5:7b-instruct-q5_0 | 18.0591 |
-| qwen2.5-coder:1.5b | 56.7301 |
-| qwen2.5-coder:7b-instruct | 20.1563 |
-| deepseek-coder:6.7b | 23.4261 |
-| deepseek-r1:1.5b | 57.0494 |
-| deepseek-r1:7b | 20.3577 |
-| deepseek-r1:8b | 18.6653 |
-| deepseek-r1 | 20.2571 |
-| driaforall/tiny-agent-a:0.5b | 94.6503 |
-| driaforall/tiny-agent-a:1.5b | 49.5431 |
-| driaforall/tiny-agent-a:3b | 27.1564 |
-
-### AMD (96 CPU, 192GB RAM)
-
-| Model | TPS |
-| ---------------------------- | ------- |
-| phi3.5:3.8b | 34.1058 |
-| llama3.1:latest | 20.2221 |
-| llama3.1:8b-text-q4_K_M | 20.1473 |
-| llama3.2:1b | 54.5232 |
-| llama3.2:3b | 37.6344 |
-| llama3.2:1b-text-q4_K_M | 65.7570 |
-| qwen2.5:7b-instruct-q5_0 | 20.2058 |
-| qwen2.5-coder:1.5b | 55.4435 |
-| qwen2.5-coder:7b-instruct | 21.3058 |
-| deepseek-coder:6.7b | 24.6414 |
-| deepseek-r1:1.5b | 54.3133 |
-| deepseek-r1:7b | 20.8902 |
-| deepseek-r1:8b | 18.7142 |
-| deepseek-r1 | 22.1564 |
-| driaforall/tiny-agent-a:0.5b | 94.7864 |
-| driaforall/tiny-agent-a:1.5b | 50.7868 |
-| driaforall/tiny-agent-a:3b | 29.4635 |
-
-### AMD (192 CPU, 384GB RAM)
-
-| Model | TPS |
-| ----------------------------------- | ------- |
-| finalend/hermes-3-llama-3.1:8b-q8_0 | 16.8002 |
-| phi3.5:3.8b | 26.2855 |
-| phi3.5:3.8b-mini-instruct-fp16 | 16.7343 |
-| llama3.1:latest | 21.9456 |
-| llama3.1:8b-instruct-q8_0 | 16.7135 |
-| llama3.1:8b-text-q4_K_M | 22.5764 |
-| llama3.1:8b-text-q8_0 | 16.3817 |
-| llama3.2:1b | 43.5632 |
-| llama3.2:3b | 29.5560 |
-| llama3.2:1b-text-q4_K_M | 48.6348 |
-| qwen2.5:7b-instruct-q5_0 | 21.4938 |
-| qwen2.5-coder:1.5b | 33.3333 |
-| qwen2.5-coder:7b-instruct | 21.7933 |
-| qwen2.5-coder:7b-instruct-q8_0 | 17.8134 |
-| deepseek-coder:6.7b | 23.4474 |
-| deepseek-r1:1.5b | 32.7795 |
-| deepseek-r1:7b | 22.5376 |
-| deepseek-r1:8b | 20.3057 |
-| deepseek-r1 | 23.0604 |
-| driaforall/tiny-agent-a:0.5b | 42.1866 |
-| driaforall/tiny-agent-a:1.5b | 33.4957 |
-| driaforall/tiny-agent-a:3b | 24.5138 |
-
-### ARM (192 CPU, 384GB RAM)
-
-| Model | TPS |
-| ---------------------------- | ------- |
-| phi3.5:3.8b | 26.3062 |
-| llama3.1:latest | 18.9597 |
-| llama3.1:8b-text-q4_K_M | 18.2489 |
-| llama3.2:1b | 43.7856 |
-| llama3.2:3b | 30.3443 |
-| llama3.2:1b-text-q4_K_M | 49.6852 |
-| qwen2.5:7b-instruct-q5_0 | 16.8128 |
-| qwen2.5-coder:1.5b | 38.3562 |
-| qwen2.5-coder:7b-instruct | 19.5582 |
-| deepseek-coder:6.7b | 21.2699 |
-| deepseek-r1:1.5b | 36.0020 |
-| deepseek-r1:7b | 19.5293 |
-| deepseek-r1:8b | 18.5300 |
-| deepseek-r1 | 18.9405 |
-| driaforall/tiny-agent-a:0.5b | 28.4991 |
-| driaforall/tiny-agent-a:1.5b | 31.6353 |
-| driaforall/tiny-agent-a:3b | 22.2788 |
-
-## 📝 Notes
-
-- CPU usage can vary significantly between tasks, especially for long context vs. multiple steps.
-
-- Some models may require more than the available CPU cores, which could lead to slower performance.
-
-- RAM usage is generally consistent but can spike for certain operations.
-
-- **Important**: Lower CPU count results in lower performance. Systems with fewer CPUs will process requests more slowly, especially for models that require more CPU resources than are available.
-
-Remember, these are minimum specs, and your experience may vary depending on the specific tasks and workload. Happy node running! 🎉
diff --git a/executor/Cargo.toml b/executor/Cargo.toml
deleted file mode 100644
index 12f12c42..00000000
--- a/executor/Cargo.toml
+++ /dev/null
@@ -1,36 +0,0 @@
-[package]
-name = "dkn-executor"
-version.workspace = true
-edition.workspace = true
-license.workspace = true
-readme = "README.md"
-authors = ["Erhan Tezcan "]
-
-
-[dependencies]
-env_logger.workspace = true
-
-# async stuff
-tokio-util.workspace = true
-tokio.workspace = true
-
-# serialize & deserialize
-serde.workspace = true
-serde_json.workspace = true
-
-# http & networking
-reqwest.workspace = true
-
-# logging & errors
-log.workspace = true
-eyre.workspace = true
-thiserror.workspace = true
-
-enum-iterator = "2.1.0"
-rig-core = "0.11.1"
-ollama-rs = { version = "0.3.0", features = ["tokio", "rustls", "stream"] }
-dkn-utils = { path = "../utils" }
-
-[dev-dependencies]
-# only used for tests
-dotenvy.workspace = true
diff --git a/executor/README.md b/executor/README.md
deleted file mode 100644
index 69f3199e..00000000
--- a/executor/README.md
+++ /dev/null
@@ -1,20 +0,0 @@
-# Dria Executor
-
-## Installation
-
-Add the package via `git` within your Cargo dependencies:
-
-```toml
-dkn-executor = { git = "https://github.com/firstbatchxyz/dkn-compute-node" }
-```
-
-## Usage
-
-Dria Executor makes use of several environment variables, with respect to several model providers.
-
-- `OLLAMA_HOST` is used to connect to **Ollama** server
-- `OLLAMA_PORT` is used to connect to **Ollama** server
-- `OLLAMA_AUTO_PULL` indicates whether we should pull missing models automatically or not
-- `OPENAI_API_KEY` is used for **OpenAI** requests
-- `GEMINI_API_KEY` is used for **Gemini** requests
-- `OPENROUTER_API_KEY` is used for **OpenRouter** requests.
diff --git a/executor/examples/ollama.rs b/executor/examples/ollama.rs
deleted file mode 100644
index cec6200a..00000000
--- a/executor/examples/ollama.rs
+++ /dev/null
@@ -1,19 +0,0 @@
-use dkn_executor::{DriaExecutorsManager, Model};
-
-#[tokio::main]
-async fn main() -> eyre::Result<()> {
- dotenvy::dotenv().ok();
-
- let model = Model::Llama3_2_1bInstructQ4Km;
- let models = vec![model];
- let mut config = DriaExecutorsManager::new_from_env_for_models(models.into_iter())?;
- config.check_services().await;
- assert!(config.models.contains(&model));
-
- let task = dkn_executor::TaskBody::new_prompt("Write a haiku about category theory.", model);
- let executor = config.get_executor(&task.model).await?;
- let result = executor.execute(task).await?;
-
- println!("{}", result);
- Ok(())
-}
diff --git a/executor/src/executors/gemini.rs b/executor/src/executors/gemini.rs
deleted file mode 100644
index fe77fadd..00000000
--- a/executor/src/executors/gemini.rs
+++ /dev/null
@@ -1,178 +0,0 @@
-use dkn_utils::payloads::SpecModelPerformance;
-use eyre::{eyre, Context, Result};
-use reqwest::Client;
-use rig::{
- completion::{Chat, PromptError},
- providers::gemini,
-};
-use serde::Deserialize;
-use std::collections::{HashMap, HashSet};
-
-use crate::{Model, TaskBody};
-
-/// OpenAI-specific configurations.
-#[derive(Clone)]
-pub struct GeminiClient {
- api_key: String,
- client: gemini::Client,
-}
-
-impl GeminiClient {
- /// Looks at the environment variables for Gemini API key.
- pub fn new(api_key: &str) -> Self {
- Self {
- api_key: api_key.to_string(),
- client: gemini::Client::new(api_key),
- }
- }
-
- /// Creates a new client using the API key in `GEMINI_API_KEY` environment variable.
- pub fn from_env() -> Result<Self, std::env::VarError> {
- let api_key = std::env::var("GEMINI_API_KEY")?;
- Ok(Self::new(&api_key))
- }
-
- pub async fn execute(&self, task: TaskBody) -> Result<String, PromptError> {
- let mut model = self.client.agent(&task.model.to_string());
- if let Some(preamble) = task.preamble {
- model = model.preamble(&preamble);
- }
-
- let agent = model.build();
-
- agent.chat(task.prompt, task.chat_history).await
- }
-
- /// Check if requested models exist & are available in the OpenAI account.
- pub async fn check(
- &self,
- models: &mut HashSet<Model>,
- ) -> Result<HashMap<Model, SpecModelPerformance>> {
- let mut models_to_remove = Vec::new();
- let mut model_performances = HashMap::new();
- log::info!("Checking Gemini requirements");
-
- // check if models exist and select those that are available
- let gemini_models_names = self.fetch_models().await?;
- for requested_model in models.iter().cloned() {
- // check if model exists
- if !gemini_models_names
- .iter()
- // due to weird naming of models in Gemini API, we need to check prefix
- .any(|model| model.starts_with(&requested_model.to_string()))
- {
- log::warn!(
- "Model {} not found in your Gemini account, ignoring it.",
- requested_model
- );
- models_to_remove.push(requested_model);
- model_performances.insert(requested_model, SpecModelPerformance::NotFound);
- continue;
- }
-
- // make a dummy request
- if let Err(err) = self
- .execute(TaskBody::new_prompt("What is 2 + 2?", requested_model))
- .await
- {
- log::warn!(
- "Model {} failed dummy request, ignoring it: {}",
- requested_model,
- err
- );
- models_to_remove.push(requested_model);
- model_performances.insert(requested_model, SpecModelPerformance::ExecutionFailed);
- continue;
- }
-
- // record the performance of the model
- model_performances.insert(requested_model, SpecModelPerformance::Passed);
- }
-
- // remove models that are not available
- for model in models_to_remove.iter() {
- models.remove(model);
- }
-
- Ok(model_performances)
- }
-
- /// Returns the list of models available to this account.
- ///
- /// A gemini model name in API response is given as `models/{baseModelId}-{version}`
- /// the model name in Dria can include the version as well, so best bet is to check prefix
- /// ignoring the `models/` part.
- async fn fetch_models(&self) -> Result<Vec<String>> {
- /// [Model](https://ai.google.dev/api/models#Model) API object, fields omitted.
- #[derive(Debug, Clone, Deserialize)]
- struct GeminiModel {
- name: String,
- // other fields are ignored from API response
- }
-
- #[derive(Debug, Clone, Deserialize)]
- struct GeminiModelsResponse {
- models: Vec<GeminiModel>,
- }
-
- // fetch models
- let client = Client::new();
- let request = client
- // [`models.list`](https://ai.google.dev/api/models#method:-models.list) endpoint
- .get("https://generativelanguage.googleapis.com/v1beta/models")
- .query(&[("key", &self.api_key)])
- .build()
- .wrap_err("failed to build request")?;
-
- let response = client
- .execute(request)
- .await
- .wrap_err("failed to send request")?;
-
- // parse response
- if response.status().is_client_error() {
- return Err(eyre!(
- "Failed to fetch Gemini models:\n{}",
- response.text().await.unwrap_or_default()
- ));
- }
- let gemini_models = response.json::<GeminiModelsResponse>().await?;
-
- Ok(gemini_models
- .models
- .into_iter()
- .map(|model| model.name.trim_start_matches("models/").to_string())
- .collect())
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- #[tokio::test]
- #[ignore = "requires Gemini API key"]
- async fn test_gemini_check() {
- let _ = env_logger::builder()
- .filter_level(log::LevelFilter::Off)
- .filter_module("dkn_executor", log::LevelFilter::Debug)
- .is_test(true)
- .try_init();
- let _ = dotenvy::dotenv(); // read api key
-
- let initial_models = [Model::Gemini2_0Flash, Model::Gemini2_5ProExp];
- let mut models = HashSet::from_iter(initial_models);
- GeminiClient::from_env()
- .unwrap()
- .check(&mut models)
- .await
- .unwrap();
- assert_eq!(models.len(), initial_models.len());
-
- // should give error for bad API key
- let res = GeminiClient::new("i-dont-work")
- .check(&mut HashSet::new())
- .await;
- assert!(res.is_err());
- }
-}
diff --git a/executor/src/executors/mod.rs b/executor/src/executors/mod.rs
deleted file mode 100644
index efa1baf6..00000000
--- a/executor/src/executors/mod.rs
+++ /dev/null
@@ -1,71 +0,0 @@
-use crate::{Model, ModelProvider, TaskBody};
-use dkn_utils::payloads::SpecModelPerformance;
-use rig::completion::PromptError;
-use std::collections::{HashMap, HashSet};
-
-mod ollama;
-use ollama::OllamaClient;
-
-// mod openai;
-// use openai::OpenAIClient;
-
-// mod gemini;
-// use gemini::GeminiClient;
-
-// mod openrouter;
-// use openrouter::OpenRouterClient;
-
-/// A wrapper enum for all model providers.
-#[derive(Clone)]
-pub enum DriaExecutor {
- Ollama(OllamaClient),
- // OpenAI(OpenAIClient),
- // Gemini(GeminiClient),
- // OpenRouter(OpenRouterClient),
-}
-
-impl DriaExecutor {
- /// Creates a new executor for the given provider using the API key in the environment variables.
- pub fn new_from_env(provider: ModelProvider) -> Result<Self, std::env::VarError> {
- match provider {
- ModelProvider::Ollama => OllamaClient::from_env().map(DriaExecutor::Ollama),
- // ModelProvider::OpenAI => OpenAIClient::from_env().map(DriaExecutor::OpenAI),
- // ModelProvider::Gemini => GeminiClient::from_env().map(DriaExecutor::Gemini),
- // ModelProvider::OpenRouter => OpenRouterClient::from_env().map(DriaExecutor::OpenRouter),
- }
- }
-
- /// Executes the given task using the appropriate provider.
- pub async fn execute(&self, task: TaskBody) -> Result<String, PromptError> {
- match self {
- DriaExecutor::Ollama(provider) => provider.execute(task).await,
- // DriaExecutor::OpenAI(provider) => provider.execute(task).await,
- // DriaExecutor::Gemini(provider) => provider.execute(task).await,
- // DriaExecutor::OpenRouter(provider) => provider.execute(task).await,
- }
- }
-
- /// Checks if the requested models exist and are available in the provider's account.
- ///
- /// For Ollama in particular, it also checks if the models are performant enough.
- pub async fn check(
- &self,
- models: &mut HashSet<Model>,
- ) -> eyre::Result<HashMap<Model, SpecModelPerformance>> {
- match self {
- DriaExecutor::Ollama(provider) => provider.check(models).await,
- // DriaExecutor::OpenAI(provider) => provider.check(models).await,
- // DriaExecutor::Gemini(provider) => provider.check(models).await,
- // DriaExecutor::OpenRouter(provider) => provider.check(models).await,
- }
- }
-
- pub fn name(&self) -> String {
- match self {
- DriaExecutor::Ollama(_) => ModelProvider::Ollama.to_string(),
- // DriaExecutor::OpenAI(_) => ModelProvider::OpenAI.to_string(),
- // DriaExecutor::Gemini(_) => ModelProvider::Gemini.to_string(),
- // DriaExecutor::OpenRouter(_) => ModelProvider::OpenRouter.to_string(),
- }
- }
-}
diff --git a/executor/src/executors/ollama.rs b/executor/src/executors/ollama.rs
deleted file mode 100644
index 766099d8..00000000
--- a/executor/src/executors/ollama.rs
+++ /dev/null
@@ -1,253 +0,0 @@
-use dkn_utils::payloads::SpecModelPerformance;
-use eyre::{Context, Result};
-use ollama_rs::generation::completion::request::GenerationRequest;
-use rig::completion::{Chat, PromptError};
-use rig::providers::ollama;
-use std::collections::HashMap;
-use std::time::Duration;
-use std::{collections::HashSet, env};
-
-use crate::{Model, TaskBody};
-
-const DEFAULT_OLLAMA_HOST: &str = "http://127.0.0.1";
-const DEFAULT_OLLAMA_PORT: u16 = 11434;
-
-/// Timeout duration for checking model performance during a generation.
-const PERFORMANCE_TIMEOUT: Duration = Duration::from_secs(120);
-/// Minimum tokens per second (TPS) for checking model performance during a generation.
-const PERFORMANCE_MIN_TPS: f64 = 10.0;
-
-/// Ollama-specific configurations.
-#[derive(Clone)]
-pub struct OllamaClient {
- /// Whether to automatically pull models from Ollama.
- auto_pull: bool,
- /// Underlying Ollama client.
- client: ollama::Client,
- /// A more specialized Ollama client.
- ///
- /// - Can do pulls
- /// - Can list local models
- ollama_rs_client: ollama_rs::Ollama,
-}
-
-impl OllamaClient {
- /// Creates a new Ollama client using the host and port.
- pub fn new(host: &str, port: u16, auto_pull: bool) -> Self {
- Self {
- auto_pull,
- ollama_rs_client: ollama_rs::Ollama::new(host, port),
- client: ollama::Client::from_url(&format!("{host}:{port}",)),
- }
- }
-
- /// Looks at the environment variables for Ollama host and port.
- ///
- /// If not found, defaults to `DEFAULT_OLLAMA_HOST` and `DEFAULT_OLLAMA_PORT`.
- ///
- /// Returns a `Result` to be compatible with other executors.
- pub fn from_env() -> Result<Self, std::env::VarError> {
- let host = env::var("OLLAMA_HOST")
- .map(|h| h.trim_matches('"').to_string())
- .unwrap_or(DEFAULT_OLLAMA_HOST.to_string());
- let port = env::var("OLLAMA_PORT")
- .and_then(|port_str| port_str.parse().map_err(|_| std::env::VarError::NotPresent))
- .unwrap_or(DEFAULT_OLLAMA_PORT);
-
- // auto-pull, its true by default
- let auto_pull = env::var("OLLAMA_AUTO_PULL")
- .map(|s| s == "true")
- .unwrap_or(true);
-
- Ok(Self::new(&host, port, auto_pull))
- }
-
- /// Sets the auto-pull flag for Ollama models.
- pub fn with_auto_pull(mut self, auto_pull: bool) -> Self {
- self.auto_pull = auto_pull;
- self
- }
-
- pub async fn execute(&self, task: TaskBody) -> Result<String, PromptError> {
- let mut model = self.client.agent(&task.model.to_string());
- if let Some(preamble) = task.preamble {
- model = model.preamble(&preamble);
- }
-
- let agent = model.build();
-
- agent.chat(task.prompt, task.chat_history).await
- }
-
- /// Check if requested models exist in Ollama & test them using a dummy prompt.
- pub async fn check(
- &self,
- models: &mut HashSet<Model>,
- ) -> Result<HashMap<Model, SpecModelPerformance>> {
- log::info!(
- "Checking Ollama requirements ({}, timeout: {}s, min tps: {})",
- if self.auto_pull {
- "auto-pull enabled"
- } else {
- "auto-pull disabled"
- },
- PERFORMANCE_TIMEOUT.as_secs(),
- PERFORMANCE_MIN_TPS
- );
-
- // fetch local models
- let local_models = match self.ollama_rs_client.list_local_models().await {
- Ok(models) => models.into_iter().map(|m| m.name).collect::<Vec<_>>(),
- Err(e) => {
- return {
- log::error!("Could not fetch local models from Ollama, is it online?");
- Err(e.into())
- }
- }
- };
- log::info!("Found local Ollama models: {local_models:#?}");
-
- // check external models & pull them if available
- // iterate over models and remove bad ones
- let mut models_to_remove = Vec::new();
- let mut model_performances = HashMap::new();
- for model in models.iter() {
- // pull the model if it is not in the local models
- if !local_models.contains(&model.to_string()) {
- log::warn!("Model {model} not found in Ollama");
- if self.auto_pull {
- self.try_pull(model)
- .await
- .wrap_err("could not pull model")?;
- } else {
- log::error!("Please download missing model with: ollama pull {model}");
- log::error!("Or, set OLLAMA_AUTO_PULL=true to pull automatically.");
- eyre::bail!("required model not pulled in Ollama");
- }
- }
-
- // test its performance
- let perf = self.measure_tps_with_warmup(model).await;
- if let SpecModelPerformance::PassedWithTPS(_) = perf {
- model_performances.insert(*model, perf);
- } else {
- // if its anything but PassedWithTPS, remove the model
- models_to_remove.push(*model);
- model_performances.insert(*model, perf);
- }
- }
-
- // remove failed models
- for model in models_to_remove {
- models.remove(&model);
- }
-
- if models.is_empty() {
- log::warn!("No Ollama models passed the performance test! Try using a more powerful machine OR smaller models.");
- } else {
- log::info!("Ollama checks are finished, using models: {models:#?}");
- }
-
- Ok(model_performances)
- }
-
- /// Pulls a model from Ollama.
- async fn try_pull(&self, model: &Model) -> Result {
- // TODO: add pull-bar here
- // if auto-pull is enabled, pull the model
- log::info!("Downloading missing model {model} (this may take a while)");
- self.ollama_rs_client
- .pull_model(model.to_string(), false)
- .await
- .wrap_err("could not pull model")
- }
-
- /// Runs a small test to test local model performance.
- ///
- /// This is to see if a given system can execute tasks for their chosen models,
- /// e.g. if they have enough RAM/CPU and such.
- pub async fn measure_tps_with_warmup(&self, model: &Model) -> SpecModelPerformance {
- const TEST_PROMPT: &str = "Please write a poem about Kapadokya.";
- const WARMUP_PROMPT: &str = "Write a short poem about hedgehogs and squirrels.";
-
- log::info!("Measuring {model}");
-
- // run a dummy generation for warm-up
- log::debug!("Warming up Ollama for {model}");
- if let Err(err) = self
- .ollama_rs_client
- .generate(GenerationRequest::new(
- model.to_string(),
- WARMUP_PROMPT.to_string(),
- ))
- .await
- {
- log::warn!("Ignoring {model}: {err}");
- return SpecModelPerformance::ExecutionFailed;
- }
-
- // then, run a sample generation with timeout and measure tps
- let Ok(result) = tokio::time::timeout(
- PERFORMANCE_TIMEOUT,
- self.ollama_rs_client.generate(GenerationRequest::new(
- model.to_string(),
- TEST_PROMPT.to_string(),
- )),
- )
- .await
- else {
- log::warn!("Ignoring {model}: Timed out");
- return SpecModelPerformance::Timeout;
- };
-
- // check the result
- match result {
- Ok(response) => {
- let tps = (response.eval_count.unwrap_or_default() as f64)
- / (response.eval_duration.unwrap_or(1) as f64)
- * 1_000_000_000f64;
-
- if tps >= PERFORMANCE_MIN_TPS {
- log::info!("{model} passed the test with tps: {tps}");
- SpecModelPerformance::PassedWithTPS(tps)
- } else {
- log::warn!(
- "Ignoring {model}: tps too low ({tps:.3} < {PERFORMANCE_MIN_TPS:.3})"
- );
- SpecModelPerformance::FailedWithTPS(tps)
- }
- }
- Err(err) => {
- log::warn!("Ignoring {model} due to: {err}");
- SpecModelPerformance::ExecutionFailed
- }
- }
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- #[tokio::test]
- #[ignore = "requires Ollama"]
- async fn test_ollama_prompt() {
- let client = OllamaClient::from_env().unwrap();
- let model = Model::Llama3_2_1bInstructQ4Km;
-
- let stats = client.try_pull(&model).await.unwrap();
- println!("Model {}: {:#?}", model, stats);
- let prompt = "The sky appears blue during the day because of a process called scattering. \
- When sunlight enters the Earth's atmosphere, it collides with air molecules such as oxygen and nitrogen. \
- These collisions cause some of the light to be absorbed or reflected, which makes the colors we see appear more vivid and vibrant. \
- Blue is one of the brightest colors that is scattered the most by the atmosphere, making it visible to our eyes during the day. \
- What may be the question this answer?".to_string();
-
- let response = client
- .execute(TaskBody::new_prompt(&prompt, model))
- .await
- .unwrap();
-
- println!("Prompt: {}\n\nResponse:{}", prompt, response);
- }
-}
diff --git a/executor/src/executors/openai.rs b/executor/src/executors/openai.rs
deleted file mode 100644
index cb98b5e1..00000000
--- a/executor/src/executors/openai.rs
+++ /dev/null
@@ -1,172 +0,0 @@
-use std::collections::{HashMap, HashSet};
-
-use dkn_utils::payloads::SpecModelPerformance;
-use eyre::{eyre, Context, Result};
-use reqwest::Client;
-use rig::{
- completion::{Chat, PromptError},
- providers::openai,
-};
-use serde::Deserialize;
-
-use crate::{Model, TaskBody};
-
-/// OpenAI-specific configurations.
-#[derive(Clone)]
-pub struct OpenAIClient {
- /// API key, if available.
- api_key: String,
- /// Underlying OpenAI client from [`rig`].
- client: openai::Client,
-}
-
-impl OpenAIClient {
- /// Looks at the environment variables for OpenAI API key.
- pub fn new(api_key: &str) -> Self {
- Self {
- api_key: api_key.to_string(),
- client: openai::Client::new(api_key),
- }
- }
-
- /// Creates a new OpenAI client using the API key in `OPENAI_API_KEY` environment variable.
- pub fn from_env() -> Result<Self, std::env::VarError> {
- let api_key = std::env::var("OPENAI_API_KEY")?;
- Ok(Self::new(&api_key))
- }
-
- pub async fn execute(&self, task: TaskBody) -> Result<String, PromptError> {
- let mut model = self.client.agent(&task.model.to_string());
- if let Some(preamble) = task.preamble {
- model = model.preamble(&preamble);
- }
-
- let agent = model.build();
-
- agent.chat(task.prompt, task.chat_history).await
- }
-
- /// Returns the list of model names available to this account.
- pub async fn check(
- &self,
- models: &mut HashSet<Model>,
- ) -> Result<HashMap<Model, SpecModelPerformance>> {
- let mut models_to_remove = Vec::new();
- let mut model_performances = HashMap::new();
- log::info!("Checking OpenAI requirements");
-
- // check if models exist within the account and select those that are available
- let openai_model_names = self.fetch_models().await?;
- for model in models.iter().cloned() {
- // check if model exists
- if !openai_model_names.contains(&model.to_string()) {
- log::warn!(
- "Model {} not found in your OpenAI account, ignoring it.",
- model
- );
- models_to_remove.push(model);
- model_performances.insert(model, SpecModelPerformance::NotFound);
- continue;
- }
-
- // if it exists, make a dummy request
- if let Err(err) = self
- .execute(TaskBody::new_prompt("What is 2 + 2?", model))
- .await
- {
- log::warn!("Model {} failed dummy request, ignoring it: {}", model, err);
- models_to_remove.push(model);
- model_performances.insert(model, SpecModelPerformance::ExecutionFailed);
- continue;
- }
-
- // record the performance of the model
- model_performances.insert(model, SpecModelPerformance::Passed);
- }
-
- // remove models that are not available
- for model in models_to_remove.iter() {
- models.remove(model);
- }
-
- // log results
- if models.is_empty() {
- log::warn!("OpenAI checks are finished, no available models found.",);
- } else {
- log::info!("OpenAI checks are finished, using models: {:#?}", models);
- }
-
- Ok(model_performances)
- }
-
- /// Fetches the list of models available in the OpenAI account.
- async fn fetch_models(&self) -> Result<Vec<String>> {
- /// [Model](https://platform.openai.com/docs/api-reference/models/object) API object, fields omitted.
- #[derive(Debug, Clone, Deserialize)]
- struct OpenAIModel {
- /// The model identifier, which can be referenced in the API endpoints.
- id: String,
- }
-
- #[derive(Debug, Clone, Deserialize)]
- struct OpenAIModelsResponse {
- data: Vec<OpenAIModel>,
- }
-
- let client = Client::new();
- let request = client
- .get("https://api.openai.com/v1/models")
- .header("Authorization", format!("Bearer {}", self.api_key))
- .build()
- .wrap_err("failed to build request")?;
-
- let response = client
- .execute(request)
- .await
- .wrap_err("failed to send request")?;
-
- // parse response
- if !response.status().is_success() {
- Err(eyre!(
- "Failed to fetch OpenAI models:\n{}",
- response
- .text()
- .await
- .unwrap_or("could not get error text as well".to_string())
- ))
- } else {
- let openai_models = response.json::<OpenAIModelsResponse>().await?;
- Ok(openai_models.data.into_iter().map(|m| m.id).collect())
- }
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- #[tokio::test]
- #[ignore = "requires OpenAI API key"]
- async fn test_openai_check() {
- let _ = env_logger::builder()
- .filter_level(log::LevelFilter::Off)
- .filter_module("dkn_executor", log::LevelFilter::Debug)
- .is_test(true)
- .try_init();
- let _ = dotenvy::dotenv(); // read api key
-
- let initial_models = [Model::GPT4o, Model::GPT4oMini];
- let mut models = HashSet::from_iter(initial_models);
- OpenAIClient::from_env()
- .unwrap()
- .check(&mut models)
- .await
- .unwrap();
- assert_eq!(models.len(), initial_models.len());
-
- let res = OpenAIClient::new("i-dont-work")
- .check(&mut Default::default())
- .await;
- assert!(res.is_err());
- }
-}
diff --git a/executor/src/executors/openrouter.rs b/executor/src/executors/openrouter.rs
deleted file mode 100644
index d4fc3c5e..00000000
--- a/executor/src/executors/openrouter.rs
+++ /dev/null
@@ -1,98 +0,0 @@
-use std::collections::{HashMap, HashSet};
-
-use dkn_utils::payloads::SpecModelPerformance;
-use eyre::Result;
-use rig::completion::{Chat, PromptError};
-use rig::providers::openrouter;
-
-use crate::{Model, TaskBody};
-
-/// OpenRouter-specific configurations.
-#[derive(Clone)]
-pub struct OpenRouterClient {
- client: openrouter::Client,
-}
-
-impl OpenRouterClient {
- /// Looks at the environment variables for OpenRouter API key.
- pub fn new(api_key: &str) -> Self {
- Self {
- client: openrouter::Client::new(api_key),
- }
- }
-
- /// Creates a new client using the API key in `OPENROUTER_API_KEY` environment variable.
- pub fn from_env() -> Result<Self> {
- let api_key = std::env::var("OPENROUTER_API_KEY")?;
- Ok(Self::new(&api_key))
- }
-
- pub async fn execute(&self, task: TaskBody) -> Result<String, PromptError> {
- let mut model = self.client.agent(&task.model.to_string());
- if let Some(preamble) = task.preamble {
- model = model.preamble(&preamble);
- }
-
- let agent = model.build();
- agent.chat(task.prompt, task.chat_history).await
- }
-
- /// Checks if the API key exists.
- pub async fn check(
- &self,
- models: &mut HashSet<Model>,
- ) -> Result<HashMap<Model, SpecModelPerformance>> {
- let mut models_to_remove = Vec::new();
- let mut model_performances = HashMap::new();
- log::info!("Checking OpenRouter API key");
-
- // make a dummy request with existing models
- for model in models.iter().cloned() {
- if let Err(err) = self
- .execute(TaskBody::new_prompt("What is 2 + 2?", model))
- .await
- {
- log::warn!("Model {} failed dummy request, ignoring it: {}", model, err);
- models_to_remove.push(model);
- model_performances.insert(model, SpecModelPerformance::ExecutionFailed);
- continue;
- }
-
- // record the model performance
- model_performances.insert(model, SpecModelPerformance::Passed);
- }
-
- // remove models that failed the dummy request
- for model in models_to_remove.iter() {
- models.remove(model);
- }
-
- Ok(model_performances)
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- #[tokio::test]
- #[ignore = "requires OpenRouter API key"]
- async fn test_openrouter_check() {
- let _ = env_logger::builder()
- .filter_level(log::LevelFilter::Off)
- .filter_module("dkn_executor", log::LevelFilter::Debug)
- .is_test(true)
- .try_init();
- let _ = dotenvy::dotenv(); // read api key
-
- let initial_models = [Model::OR3_5Sonnet, Model::OR3_7Sonnet];
- let mut models = HashSet::from_iter(initial_models);
- let config = OpenRouterClient::from_env().unwrap();
- config.check(&mut models).await.unwrap();
- assert_eq!(models.len(), initial_models.len());
-
- // create with a bad api key
- let config = OpenRouterClient::new("i-dont-work");
- config.check(&mut HashSet::new()).await.unwrap(); // should not panic
- }
-}
diff --git a/executor/src/lib.rs b/executor/src/lib.rs
deleted file mode 100644
index 4dd16bcd..00000000
--- a/executor/src/lib.rs
+++ /dev/null
@@ -1,17 +0,0 @@
-mod executors;
-pub use executors::DriaExecutor;
-
-mod manager;
-pub use manager::DriaExecutorsManager;
-
-mod models;
-pub use models::{Model, ModelProvider};
-
-mod task;
-pub use task::{TaskBody, TaskResult};
-
-pub use rig::completion::CompletionModel;
-pub use rig::completion::{CompletionError, PromptError};
-
-// re-export ollama_rs
-pub use ollama_rs;
diff --git a/executor/src/manager.rs b/executor/src/manager.rs
deleted file mode 100644
index 01b5e2d2..00000000
--- a/executor/src/manager.rs
+++ /dev/null
@@ -1,143 +0,0 @@
-use dkn_utils::payloads::SpecModelPerformance;
-
-use crate::{executors::DriaExecutor, Model, ModelProvider};
-use std::collections::{HashMap, HashSet};
-
-#[derive(Clone)]
-pub struct DriaExecutorsManager {
- /// List of all models supported by this node.
- ///
- /// Equivalent to the union of all sets of models in the providers.
- pub models: HashSet<Model>,
- /// Providers and their executors along with the models they support.
- pub providers: HashMap<ModelProvider, (DriaExecutor, HashSet<Model>)>,
-}
-
-impl DriaExecutorsManager {
- /// Creates a new executor manager with the given models, using environment variables for the providers.
- ///
- /// If a provider is required (as per the chosen model) but its environment variables are missing,
- /// this will return an error.
- pub fn new_from_env_for_models(
- models: impl Iterator<Item = Model>,
- ) -> Result {
- let mut provider_set: HashMap<ModelProvider, (DriaExecutor, HashSet<Model>)> =
- HashMap::new();
- let mut model_set = HashSet::new();
- for model in models {
- // get the provider for the model
- let provider = model.provider();
-
- // add model to the provider set, and create a new executor if needed
- match provider_set.get_mut(&provider) {
- Some((_, models)) => {
- models.insert(model);
- }
- None => {
- // create a new executor for the provider, may return an error!
- match DriaExecutor::new_from_env(provider) {
- Ok(executor) => {
- provider_set.insert(provider, (executor, HashSet::from_iter([model])));
- }
- Err(err) => {
- log::error!(
- "Failed to create executor for {provider}: {err}, {model} will not be supported.",
- );
- continue; // skip this model if the executor creation failed
- }
- }
- }
- }
-
- // add the model to the global model set
- model_set.insert(model);
- }
-
- Ok(Self {
- providers: provider_set,
- models: model_set,
- })
- }
-
- /// Given the model, returns a _cloned_ executor for it.
- ///
- /// If the model's provider is not supported, an error is returned.
- /// Likewise, if the provider is supported but the model is not, an error is returned.
- pub async fn get_executor(&self, model: &Model) -> eyre::Result<DriaExecutor> {
- let provider = model.provider();
- let (executor, models) = self
- .providers
- .get(&provider)
- .ok_or_else(|| eyre::eyre!("Provider {provider} supported by this executor"))?;
-
- if models.contains(model) {
- Ok(executor.clone())
- } else {
- Err(eyre::eyre!("Model {model} not supported by this executor"))
- }
- }
-
- /// Returns the set of models supported by the given provider for this manager.
- ///
- /// If there are no models for the provider, an empty set is returned.
- pub fn get_models_for_provider(&self, provider: ModelProvider) -> HashSet<Model> {
- self.providers
- .get(&provider)
- .map(|(_, models)| models.clone())
- .unwrap_or_default()
- }
-
- /// Returns the names of all models in the manager, in a random order.
- pub fn get_model_names(&self) -> Vec<String> {
- self.models.iter().map(|m| m.to_string()).collect()
- }
-
- /// Check if the required compute services are running.
- ///
- /// - If Ollama models are used the task is tested with a simple task with timeout.
- /// - If API based models are used, the API key is checked and the models are tested with a dummy request.
- ///
- /// In the end, bad models are filtered out and we simply check if we are left if any valid models at all.
- /// If there are no models left in the end, an error is thrown.
- pub async fn check_services(&mut self) -> HashMap<Model, SpecModelPerformance> {
- log::info!("Checking configured services.");
-
- // check all configured providers & record model performances
- let mut model_perf = HashMap::new();
- for (client, models) in self.providers.values_mut() {
- if let Ok(provider_model_perf) = client.check(models).await {
- model_perf.extend(provider_model_perf);
- } else {
- log::warn!(
- "Provider {} failed to check services, ignoring its models.",
- client.name()
- );
- model_perf.extend(
- models
- .iter()
- .map(|m| (*m, SpecModelPerformance::ExecutionFailed)),
- );
- // clear models
- models.clear();
- }
- }
-
- // obtain the final list of providers & models, removing the providers with no models left
- self.providers.retain(|provider, (_, models)| {
- let ok = !models.is_empty();
- if !ok {
- log::warn!("Provider {provider} has no models left, removing it from the config.")
- }
- ok
- });
-
- // update the models set
- self.models = self
- .providers
- .values()
- .flat_map(|(_, models)| models.iter().cloned())
- .collect();
-
- model_perf
- }
-}
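Reviewer note: the filtering that the removed `check_services` performed (drop failing models, then drop providers with no models left, then rebuild the global model set) can be sketched with the standard library alone. This is a simplified illustration, not the removed implementation: `filter_providers`, the string-keyed maps, and the `check` closure are hypothetical stand-ins for `DriaExecutor::check` and the `Model`/`ModelProvider` enums.

```rust
use std::collections::{HashMap, HashSet};

/// Hypothetical helper: removes every model that fails `check`, removes
/// providers that end up with no models, and returns the union of what is left.
fn filter_providers<F>(
    providers: &mut HashMap<String, HashSet<String>>,
    check: F,
) -> HashSet<String>
where
    F: Fn(&str) -> bool,
{
    // keep only the models that pass the (assumed) health check
    for models in providers.values_mut() {
        models.retain(|m| check(m.as_str()));
    }

    // a provider with no models left is removed from the configuration
    providers.retain(|_, models| !models.is_empty());

    // the global model set is the union over the remaining providers
    providers
        .values()
        .flat_map(|models| models.iter().cloned())
        .collect()
}

fn main() {
    let mut providers = HashMap::new();
    providers.insert(
        "ollama".to_string(),
        HashSet::from(["gemma3:4b".to_string(), "broken-model".to_string()]),
    );
    let models = filter_providers(&mut providers, |m| !m.starts_with("broken"));
    println!("remaining models: {models:?}");
}
```

The sketch keeps the same two-phase shape as the removed code: models are filtered per provider first, and the provider map is pruned afterwards so that the global set can be rebuilt from it.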
diff --git a/executor/src/models.rs b/executor/src/models.rs
deleted file mode 100644
index c98160d7..00000000
--- a/executor/src/models.rs
+++ /dev/null
@@ -1,299 +0,0 @@
-use enum_iterator::Sequence;
-use serde::{Deserialize, Serialize};
-use std::{collections::HashSet, fmt, str::FromStr};
-
-#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, Serialize, Sequence)]
-pub enum Model {
- // Ollama models
- /// [Meta's Llama3.1](https://ollama.com/library/llama3.1:8b-instruct-q4_K_M)
- #[serde(rename = "llama3.1:8b-instruct-q4_K_M")]
- Llama3_1_8bInstructQ4Km,
- /// [Meta's LLama3.2](https://ollama.com/library/llama3.2:1b-instruct-q4_K_M)
- #[serde(rename = "llama3.2:1b-instruct-q4_K_M")]
- Llama3_2_1bInstructQ4Km,
- /// [Meta's LLama3.3](https://ollama.com/library/llama3.3:70b-instruct-q4_K_M)
- #[serde(rename = "llama3.3:70b-instruct-q4_K_M")]
- Llama3_3_70bInstructQ4Km,
- /// [Mistral's Nemo](https://ollama.com/library/mistral-nemo:12b)
- #[serde(rename = "mistral-nemo:12b")]
- MistralNemo12b,
- /// [Google's Gemma3 4b](https://ollama.com/library/gemma3:4b)
- #[serde(rename = "gemma3:4b")]
- Gemma3_4b,
- /// [Google's Gemma3 12b](https://ollama.com/library/gemma3:12b)
- #[serde(rename = "gemma3:12b")]
- Gemma3_12b,
- /// [Google's Gemma3 27b](https://ollama.com/library/gemma3:27b)
- #[serde(rename = "gemma3:27b")]
- Gemma3_27b,
- /// [Alibaba's Qwen3 32b](https://ollama.com/library/qwen3:32b)
- #[serde(rename = "qwen3:32b")]
- Qwen3_32b,
- /// [Alibaba's Qwen3 8b](https://ollama.com/library/qwen3:8b)
- #[serde(rename = "qwen3:8b")]
- Qwen3_8b,
- // // OpenAI models
- // /// [OpenAI's GPT-4o](https://platform.openai.com/docs/models#gpt-4o)
- // #[serde(rename = "gpt-4o")]
- // GPT4o,
- // /// [OpenAI's GPT-4o mini](https://platform.openai.com/docs/models#gpt-4o-mini)
- // #[serde(rename = "gpt-4o-mini")]
- // GPT4oMini,
-
- // // Gemini models
- // /// [Google's Gemini 2.5 Pro experimental](https://ai.google.dev/gemini-api/docs/models#gemini-2.5-pro-preview-03-25)
- // #[serde(rename = "gemini-2.5-pro-exp-03-25")]
- // Gemini2_5ProExp,
- // /// [Google's Gemini 2.0 Flash](https://ai.google.dev/gemini-api/docs/models#gemini-2.0-flash)
- // #[serde(rename = "gemini-2.0-flash")]
- // Gemini2_0Flash,
-
- // /// OpenRouter Models
- // /// [Anthropic's Claude 3.5 Sonnet](https://openrouter.ai/models?q=claude-3.5-sonnet)
- // #[serde(rename = "anthropic/claude-3.5-sonnet")]
- // OR3_5Sonnet,
- // /// [Anthropic's Claude 3.7 Sonnet](https://openrouter.ai/models?q=claude-3.7-sonnet)
- // #[serde(rename = "anthropic/claude-3-7-sonnet")]
- // OR3_7Sonnet,
-}
-
-impl FromStr for Model {
- type Err = String;
-
- /// Tries to parse the given `str` into a `Model`.
- /// On failure, returns the original string back as the `Err` value.
- fn from_str(value: &str) -> Result<Self, Self::Err> {
- // serde requires quotes (for JSON)
- serde_json::from_str::<Self>(&format!("\"{value}\""))
- .map_err(|err| format!("Model {value} invalid: {err}"))
- }
-}
-
-impl Model {
- /// Returns a set of models from a CSV string.
- ///
- /// The input string should be a comma-separated list of model names.
- ///
- /// ## Example
- ///
- /// ```rs
- /// let models = Model::from_csv("gpt-4o, gpt-4o-mini");
- /// assert!(models.contains(&Model::GPT4o));
- /// assert!(models.contains(&Model::GPT4oMini));
- /// ```
- pub fn from_csv(input: impl AsRef<str>) -> HashSet<Model> {
- HashSet::from_iter(
- input
- .as_ref()
- .split(',')
- .filter_map(|s| Self::try_from(s.trim()).ok()),
- )
- }
-
- /// Returns an iterator over all models.
- #[inline(always)]
- pub fn all() -> impl Iterator<Item = Model> {
- enum_iterator::all::<Model>()
- }
-
- /// Returns an iterator over all models that belong to a given provider.
- #[inline(always)]
- pub fn all_with_provider(provider: &ModelProvider) -> impl Iterator<Item = Model> + '_ {
- enum_iterator::all::<Model>().filter(move |m| m.provider() == *provider)
- }
-
- /// Returns the provider that hosts the model.
- #[inline]
- pub fn provider(&self) -> ModelProvider {
- ModelProvider::from(self)
- }
-}
-
-impl fmt::Display for Model {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- // guaranteed not to fail because this is enum to string serialization
- let self_str = serde_json::to_string(&self).unwrap_or_default();
- // remove quotes from JSON
- write!(f, "{}", self_str.trim_matches('"'))
- }
-}
-
-impl TryFrom<String> for Model {
- type Error = String;
- fn try_from(value: String) -> Result<Self, Self::Error> {
- value.as_str().parse()
- }
-}
-
-impl TryFrom<&str> for Model {
- type Error = String;
- fn try_from(value: &str) -> Result<Self, Self::Error> {
- value.parse()
- }
-}
-
-/// A model provider is a service that hosts the chosen Model.
-/// It can be derived from the model name, e.g. GPT4o is hosted by OpenAI (via API), or Phi3 is hosted by Ollama (locally).
-#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, Serialize, Sequence)]
-pub enum ModelProvider {
- #[serde(rename = "ollama")]
- Ollama,
- // #[serde(rename = "openai")]
- // OpenAI,
- // #[serde(rename = "gemini")]
- // Gemini,
- // #[serde(rename = "openrouter")]
- // OpenRouter,
-}
-
-impl ModelProvider {
- /// Returns an iterator over all model providers.
- #[inline(always)]
- pub fn all() -> impl Iterator<Item = ModelProvider> {
- enum_iterator::all::<ModelProvider>()
- }
-
- /// Returns all models that belong to the provider.
- #[inline]
- pub fn models(&self) -> impl Iterator<Item = Model> + '_ {
- Model::all_with_provider(self)
- }
-
- /// Returns whether the provider is batchable
- /// (can be executed concurrently) or not.
- pub fn is_batchable(&self) -> bool {
- match self {
- // ollama models are not batchable
- ModelProvider::Ollama => false,
- // // api-based providers are batchable
- // ModelProvider::OpenAI => true,
- // ModelProvider::Gemini => true,
- // ModelProvider::OpenRouter => true,
- }
- }
-}
-
-impl From<Model> for ModelProvider {
- fn from(value: Model) -> Self {
- Self::from(&value)
- }
-}
-
-impl From<&Model> for ModelProvider {
- fn from(model: &Model) -> Self {
- match model {
- // ollama
- Model::Gemma3_4b => ModelProvider::Ollama,
- Model::Gemma3_12b => ModelProvider::Ollama,
- Model::Gemma3_27b => ModelProvider::Ollama,
- Model::Llama3_1_8bInstructQ4Km => ModelProvider::Ollama,
- Model::Llama3_2_1bInstructQ4Km => ModelProvider::Ollama,
- Model::Llama3_3_70bInstructQ4Km => ModelProvider::Ollama,
- Model::MistralNemo12b => ModelProvider::Ollama,
- Model::Qwen3_8b => ModelProvider::Ollama,
- Model::Qwen3_32b => ModelProvider::Ollama,
- // // openai
- // Model::GPT4o => ModelProvider::OpenAI,
- // Model::GPT4oMini => ModelProvider::OpenAI,
- // // gemini
- // Model::Gemini2_0Flash => ModelProvider::Gemini,
- // Model::Gemini2_5ProExp => ModelProvider::Gemini,
- // // openrouter
- // Model::OR3_5Sonnet => ModelProvider::OpenRouter,
- // Model::OR3_7Sonnet => ModelProvider::OpenRouter,
- }
- }
-}
-
-impl FromStr for ModelProvider {
- type Err = String;
-
- /// Tries to parse the given `str` into a `ModelProvider`.
- /// On failure, returns the original string back as the `Err` value.
- fn from_str(value: &str) -> Result<Self, Self::Err> {
- // serde requires quotes (for JSON)
- serde_json::from_str::<Self>(&format!("\"{value}\""))
- .map_err(|err| format!("Model provider {value} invalid: {err}"))
- }
-}
-
-impl TryFrom<String> for ModelProvider {
- type Error = String;
- fn try_from(value: String) -> Result<Self, Self::Error> {
- value.as_str().parse()
- }
-}
-
-impl TryFrom<&str> for ModelProvider {
- type Error = String;
- fn try_from(value: &str) -> Result<Self, Self::Error> {
- value.parse()
- }
-}
-
-impl fmt::Display for ModelProvider {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- // guaranteed not to fail because this is enum to string serialization
- let self_str = serde_json::to_string(&self).unwrap_or_default();
- // remove quotes from JSON
- write!(f, "{}", self_str.trim_matches('"'))
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- #[test]
- fn test_model_string_conversion() {
- let model = Model::Gemma3_4b;
-
- // convert to string
- let model_str = model.clone().to_string();
- assert_eq!(model_str, "gemma3:4b");
-
- // (try) convert from string
- let model_from = Model::try_from(model_str).expect("should convert");
- assert_eq!(model_from, model);
-
- // (try) convert from string
- let model = Model::try_from("this-model-does-not-will-not-exist".to_string());
- assert!(model.is_err());
- }
-
- #[test]
- fn test_model_string_serde() {
- let model = Model::Gemma3_12b;
-
- // serialize to string via serde
- let model_str = serde_json::to_string(&model).expect("should serialize");
- assert_eq!(model_str, "\"gemma3:12b\"");
-
- // deserialize from string via serde
- let model_from: Model = serde_json::from_str(&model_str).expect("should deserialize");
- assert_eq!(model_from, model);
-
- // (try) deserialize from invalid model
- let bad_model = serde_json::from_str::<Model>("\"this-model-does-not-will-not-exist\"");
- assert!(bad_model.is_err());
- }
-
- #[test]
- fn test_provider_string_serde() {
- let provider = ModelProvider::Ollama;
-
- // serialize to string via serde
- let provider_str = serde_json::to_string(&provider).expect("should serialize");
- assert_eq!(provider_str, "\"ollama\"");
-
- // deserialize from string via serde
- let provider_from: ModelProvider =
- serde_json::from_str(&provider_str).expect("should deserialize");
- assert_eq!(provider_from, provider);
-
- // (try) deserialize from invalid model
- let bad_provider =
- serde_json::from_str::<ModelProvider>("\"this-provider-does-not-will-not-exist\"");
- assert!(bad_provider.is_err());
- }
-}
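Reviewer note: the removed `Model::from_csv` trims each comma-separated entry and silently drops names that fail to parse. A minimal standard-library sketch of that behaviour follows. It is an assumption-laden illustration: the two-variant `Model` enum is a hypothetical stand-in, and a hand-written `match` replaces the serde round-trip that the removed `FromStr` implementation used.

```rust
use std::collections::HashSet;
use std::str::FromStr;

/// Hypothetical stand-in for the removed `Model` enum (two variants only).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum Model {
    Gemma3_4b,
    Qwen3_8b,
}

impl FromStr for Model {
    type Err = String;

    // A plain match is used here instead of the serde-based parsing in the diff.
    fn from_str(value: &str) -> Result<Self, Self::Err> {
        match value {
            "gemma3:4b" => Ok(Model::Gemma3_4b),
            "qwen3:8b" => Ok(Model::Qwen3_8b),
            other => Err(format!("Model {other} invalid")),
        }
    }
}

/// Parses a comma-separated list, trimming whitespace and skipping unknown names.
fn from_csv(input: &str) -> HashSet<Model> {
    input
        .split(',')
        .filter_map(|s| s.trim().parse::<Model>().ok())
        .collect()
}

fn main() {
    let models = from_csv("gemma3:4b, not-a-model ,qwen3:8b");
    println!("{models:?}");
}
```

Because unknown names are skipped rather than reported, a typo in the input list yields a smaller set instead of an error; callers that need strict validation would have to parse each entry themselves.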
diff --git a/executor/src/task.rs b/executor/src/task.rs
deleted file mode 100644
index 4407be63..00000000
--- a/executor/src/task.rs
+++ /dev/null
@@ -1,168 +0,0 @@
-use rig::{
- completion::{CompletionRequest, PromptError},
- message::Message,
-};
-use serde::{Deserialize, Deserializer};
-
-use crate::{Model, ModelProvider};
-
-/// A future that represents the result of a task execution, of any provider.
-pub type TaskResult = Result<String, PromptError>;
-
-/// The body of a task request that includes the messages and the model to use.
-///
-/// Implements a custom [`Deserialize`] to convert from an object of the form below to self:
-///
-/// ```ts
-/// {
-/// "model": string,
-/// "messages": { role: string, content: string }[]
-/// }
-/// ```
-///
-/// For the `messages` array, the following rules apply:
-/// - If the first message is a system message, it will be stored in the `preamble` field.
-/// - The last message must be a user message, and it will be stored in the `prompt` field.
-/// - All other intermediate messages will be stored in the `chat_history` field.
-#[derive(Debug, Clone)]
-pub struct TaskBody {
- /// An optional system prompt.
- pub preamble: Option<String>,
- /// The main user prompt.
- pub prompt: Message,
- /// List of messages for context or chat history.
- pub chat_history: Vec<Message>,
- /// The model to use for the task.
- pub model: Model,
-}
-
-impl TaskBody {
- /// Creates a new task body with the given prompt and model.
- pub fn new_prompt(prompt: impl Into<String>, model: Model) -> Self {
- TaskBody {
- preamble: None,
- prompt: Message::user(prompt),
- chat_history: Vec::default(),
- model,
- }
- }
-
- /// Returns whether this task can be executed in parallel, w.r.t to its model.
- pub fn is_batchable(&self) -> bool {
- self.model.provider() != ModelProvider::Ollama
- }
-}
-
-impl From<TaskBody> for CompletionRequest {
- fn from(task_body: TaskBody) -> Self {
- CompletionRequest {
- prompt: task_body.prompt,
- preamble: task_body.preamble,
- chat_history: task_body.chat_history,
- documents: Vec::default(),
- tools: Vec::default(),
- temperature: None,
- max_tokens: None,
- additional_params: None,
- }
- }
-}
-
-impl<'de> Deserialize<'de> for TaskBody {
- fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
- where
- D: Deserializer<'de>,
- {
- use serde::de::Error;
-
- #[derive(Deserialize)]
- struct RawMessage {
- role: String,
- content: String,
- }
-
- #[derive(Deserialize)]
- struct RawTaskBody {
- model: String,
- messages: Vec<RawMessage>,
- }
-
- let raw = RawTaskBody::deserialize(deserializer)?;
-
- // parse model
- let model = Model::try_from(raw.model).map_err(|err_model| {
- Error::custom(format!("Model {err_model} is not supported by this node."))
- })?;
-
- // ensure there are messages
- if raw.messages.is_empty() {
- return Err(Error::custom("No messages found in the task body"));
- }
-
- // ensure the last message is from the user
- if raw.messages.last().unwrap().role != "user" {
- return Err(Error::custom("Last message must be from the user"));
- }
-
- let mut preamble = None;
- let mut messages = Vec::new();
- for msg in raw.messages.into_iter() {
- match msg.role.as_str() {
- "system" => {
- // we only expect to see one system message ever
- if preamble.is_some() {
- return Err(Error::custom("Only one system message is allowed"));
- }
- preamble = Some(msg.content);
- }
- "user" => {
- messages.push(Message::user(msg.content));
- }
- "assistant" => {
- messages.push(Message::assistant(msg.content));
- }
- _ => {
- return Err(Error::custom(format!("Invalid role: {}", msg.role)));
- }
- }
- }
-
- // the last message (ensured to be role: user), will be returned as the prompt separately
- let prompt = messages.pop().unwrap();
-
- Ok(TaskBody {
- preamble,
- prompt,
- chat_history: messages,
- model,
- })
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
- use serde_json::json;
-
- #[test]
- fn test_task_body_deserialization() {
- let json_data = json!({
- "model": "gemma3:4b",
- "messages": [
- {"role": "system", "content": "You are a helpful assistant."},
- {"role": "user", "content": "What is the capital of France?"},
- {"role": "assistant", "content": "The capital of France is Paris."},
- {"role": "user", "content": "How many letters are there in the answer to the last question?"},
- ]
- });
-
- let task_body: TaskBody = serde_json::from_value(json_data).unwrap();
-
- assert_eq!(task_body.model, Model::Gemma3_4b);
- assert_eq!(
- task_body.preamble,
- Some("You are a helpful assistant.".to_string())
- );
- assert_eq!(task_body.chat_history.len(), 2);
- }
-}
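Reviewer note: the rules that the removed `TaskBody` deserializer applied to the `messages` array (at most one system message becomes the preamble, the final user message becomes the prompt, everything else is chat history) can be sketched without `serde` or `rig`. The `SplitMessages` struct and `split_messages` function below are hypothetical names chosen for illustration, and messages are modelled as plain `(role, content)` pairs.

```rust
/// Hypothetical, simplified result of splitting a `messages` array.
#[derive(Debug, PartialEq)]
struct SplitMessages {
    preamble: Option<String>,
    chat_history: Vec<(String, String)>, // (role, content)
    prompt: String,
}

fn split_messages(messages: &[(&str, &str)]) -> Result<SplitMessages, String> {
    // the list must be non-empty and must end with a user message
    match messages.last() {
        None => return Err("No messages found in the task body".into()),
        Some((role, _)) if *role != "user" => {
            return Err("Last message must be from the user".into());
        }
        _ => {}
    }

    let mut preamble = None;
    let mut history = Vec::new();
    for (role, content) in messages {
        match *role {
            "system" => {
                // only a single system message is accepted
                if preamble.is_some() {
                    return Err("Only one system message is allowed".into());
                }
                preamble = Some(content.to_string());
            }
            "user" | "assistant" => history.push((role.to_string(), content.to_string())),
            other => return Err(format!("Invalid role: {other}")),
        }
    }

    // the last pushed entry is the final user message checked above
    let (_, prompt) = history.pop().expect("last message is a user message");
    Ok(SplitMessages { preamble, chat_history: history, prompt })
}

fn main() {
    let split = split_messages(&[
        ("system", "You are a helpful assistant."),
        ("user", "What is the capital of France?"),
        ("assistant", "The capital of France is Paris."),
        ("user", "How many letters are in that answer?"),
    ]);
    println!("{split:?}");
}
```

With the four-message input shown in `main`, the history holds the two intermediate messages, which matches the `chat_history.len() == 2` assertion in the removed test.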
diff --git a/install.ps1 b/install.ps1
new file mode 100644
index 00000000..f72a7aaa
--- /dev/null
+++ b/install.ps1
@@ -0,0 +1,36 @@
+$ErrorActionPreference = "Stop"
+
+$repo = "firstbatchxyz/dkn-compute-node"
+$binary = "dria-node"
+
+# Get latest release tag (includes pre-releases)
+$releases = Invoke-RestMethod "https://api.github.com/repos/$repo/releases"
+$tag = $releases[0].tag_name
+if (-not $tag) {
+ Write-Error "Failed to fetch latest release"
+ exit 1
+}
+
+$asset = "$binary-windows-amd64.exe"
+$url = "https://github.com/$repo/releases/download/$tag/$asset"
+
+$installDir = "$env:LOCALAPPDATA\dria-node"
+if (-not (Test-Path $installDir)) {
+ New-Item -ItemType Directory -Path $installDir | Out-Null
+}
+
+$dest = Join-Path $installDir "$binary.exe"
+
+Write-Host "Installing $binary $tag..."
+Invoke-WebRequest -Uri $url -OutFile $dest
+
+# Add to PATH if not already there
+$userPath = [Environment]::GetEnvironmentVariable("Path", "User")
+if ($userPath -notlike "*$installDir*") {
+ [Environment]::SetEnvironmentVariable("Path", "$userPath;$installDir", "User")
+ $env:Path = "$env:Path;$installDir"
+ Write-Host "Added $installDir to PATH"
+}
+
+Write-Host "Installed $binary to $dest"
+& $dest --version
diff --git a/install.sh b/install.sh
new file mode 100755
index 00000000..d16fd9ac
--- /dev/null
+++ b/install.sh
@@ -0,0 +1,47 @@
+#!/bin/sh
+set -e
+
+REPO="firstbatchxyz/dkn-compute-node"
+BINARY="dria-node"
+INSTALL_DIR="/usr/local/bin"
+
+# Detect OS
+OS=$(uname -s)
+case "$OS" in
+ Linux*) OS_NAME="linux" ;;
+ Darwin*) OS_NAME="macOS" ;;
+ *) echo "Unsupported OS: $OS"; exit 1 ;;
+esac
+
+# Detect architecture
+ARCH=$(uname -m)
+case "$ARCH" in
+ x86_64|amd64) ARCH_NAME="amd64" ;;
+ aarch64|arm64) ARCH_NAME="arm64" ;;
+ *) echo "Unsupported architecture: $ARCH"; exit 1 ;;
+esac
+
+# Get latest release tag (includes pre-releases)
+TAG=$(curl -fsSL "https://api.github.com/repos/${REPO}/releases" | grep '"tag_name"' | head -1 | cut -d'"' -f4)
+if [ -z "$TAG" ]; then
+ echo "Failed to fetch latest release"
+ exit 1
+fi
+
+ASSET="${BINARY}-${OS_NAME}-${ARCH_NAME}"
+URL="https://github.com/${REPO}/releases/download/${TAG}/${ASSET}"
+
+echo "Installing ${BINARY} ${TAG} (${OS_NAME}/${ARCH_NAME})..."
+
+TMPFILE=$(mktemp)
+curl -fsSL "$URL" -o "$TMPFILE"
+chmod +x "$TMPFILE"
+
+if [ -w "$INSTALL_DIR" ]; then
+ mv "$TMPFILE" "${INSTALL_DIR}/${BINARY}"
+else
+ sudo mv "$TMPFILE" "${INSTALL_DIR}/${BINARY}"
+fi
+
+echo "Installed ${BINARY} to ${INSTALL_DIR}/${BINARY}"
+"${INSTALL_DIR}/${BINARY}" --version
diff --git a/p2p/Cargo.toml b/p2p/Cargo.toml
deleted file mode 100644
index 31e26396..00000000
--- a/p2p/Cargo.toml
+++ /dev/null
@@ -1,36 +0,0 @@
-[package]
-name = "dkn-p2p"
-version.workspace = true
-edition.workspace = true
-license.workspace = true
-readme = "README.md"
-authors = [
- "Erhan Tezcan ",
- "Anil Altuner {
- todo!("handle stuff")
- }
- None => {
- todo!("channel closed");
- break
- }
- }
-}
-```
-
-### Interactions
-
-Here is how the whole thing works in a bit more detail:
-
-- **Events**: When a message is received within the Swarm event handler, it is returned via a `mpsc` channel. Here, the p2p is `Sender` and your application must be the `Receiver`. The client handles many events, and only sends GossipSub message receipts via this channel so that the application can handle them however they would like.
-
-```mermaid
-sequenceDiagram
- actor A as Application
- actor P as P2P Client
-
- note over P: e_tx
- note over A: e_rx
-
- loop event loop
- activate P
- note over A: e_rx.wait()
- P ->> A: e_tx.send(message)
- deactivate P
-
- note over A: handle message
- end
-
-```
-
-- **Commands**: To call functions within this thread-scoped client, functions must be remotely called via the command `mpsc` channel. Here, p2p is `Receiver` and your application will be the `Sender` (we provide the commander client as well). While making a function call, a `oneshot` channel is created and its `Sender` is provided to the commander, kind of like a callback, and the caller waits as the `Receiver` for this call.
-
-```mermaid
-sequenceDiagram
- actor C as P2P Commander
- actor P as P2P Client
- note over C: c_tx
- activate C
- note over P: c_rx
-
- note over P: c_rx.wait()
- note over C: o_tx, o_rx := oneshot()
- C ->> P: c_tx.send(input, o_tx)
- deactivate C
- activate P
- note over C: o_rx.wait()
- P ->> C: o_tx.send(output)
- deactivate P
-```
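Reviewer note: the command flow described in the removed README (the caller sends an input together with the sender half of a reply channel, then waits on the receiver half) can be sketched as follows. This is an illustration under stated assumptions: it uses blocking `std::sync::mpsc` channels and a thread in place of the async `mpsc` and `oneshot` channels the README describes, and `Command`, `spawn_client`, and `double` are hypothetical names.

```rust
use std::sync::mpsc;
use std::thread;

/// Hypothetical command: an input plus the sender half of a reply channel,
/// which plays the role of the `oneshot` sender (`o_tx`) in the README.
enum Command {
    Double { input: u32, reply: mpsc::Sender<u32> },
    Shutdown,
}

/// Plays the role of the P2P client: owns `c_rx` and answers each command.
fn spawn_client(c_rx: mpsc::Receiver<Command>) -> thread::JoinHandle<()> {
    thread::spawn(move || {
        while let Ok(cmd) = c_rx.recv() {
            match cmd {
                Command::Double { input, reply } => {
                    // the caller may have given up waiting; ignore a failed send
                    let _ = reply.send(input * 2);
                }
                Command::Shutdown => break,
            }
        }
    })
}

/// Plays the role of the commander: sends the input and waits for the reply.
/// Returns `None` if the client is no longer running.
fn double(c_tx: &mpsc::Sender<Command>, input: u32) -> Option<u32> {
    let (o_tx, o_rx) = mpsc::channel();
    c_tx.send(Command::Double { input, reply: o_tx }).ok()?;
    o_rx.recv().ok()
}

fn main() {
    let (c_tx, c_rx) = mpsc::channel();
    let client = spawn_client(c_rx);
    println!("reply: {:?}", double(&c_tx, 21));
    c_tx.send(Command::Shutdown).unwrap();
    client.join().unwrap();
}
```

Creating a fresh reply channel per call keeps responses from different callers from interleaving, which is the same reason the README describes the `oneshot` sender as a kind of callback.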
diff --git a/p2p/src/behaviour.rs b/p2p/src/behaviour.rs
deleted file mode 100644
index f0723272..00000000
--- a/p2p/src/behaviour.rs
+++ /dev/null
@@ -1,53 +0,0 @@
-use eyre::Result;
-use libp2p::identity::{Keypair, PublicKey};
-use libp2p::{identify, request_response, StreamProtocol};
-use std::time::Duration;
-
-use crate::DriaP2PProtocol;
-
-#[derive(libp2p::swarm::NetworkBehaviour)]
-pub struct DriaBehaviour {
- pub identify: identify::Behaviour,
- pub request_response: request_response::cbor::Behaviour<Vec<u8>, Vec<u8>>,
-}
-
-impl DriaBehaviour {
- pub fn new(key: &Keypair, protocol: &DriaP2PProtocol) -> Self {
- let public_key = key.public();
-
- Self {
- identify: create_identify_behaviour(public_key, protocol.identity()),
- request_response: create_request_response_behaviour(protocol.request_response()),
- }
- }
-}
-
-/// Configures the request-response behaviour for the node.
-///
-/// The protocol supports bytes only.
-#[inline]
-fn create_request_response_behaviour(
- protocol_name: StreamProtocol,
-) -> request_response::cbor::Behaviour<Vec<u8>, Vec<u8>> {
- use request_response::{Behaviour, Config, ProtocolSupport};
-
- const REQUEST_RESPONSE_TIMEOUT: Duration = Duration::from_secs(512);
-
- Behaviour::new(
- [(protocol_name, ProtocolSupport::Full)],
- Config::default().with_request_timeout(REQUEST_RESPONSE_TIMEOUT),
- )
-}
-
-/// Configures the Identify behavior to allow nodes to exchange information like supported protocols.
-#[inline]
-fn create_identify_behaviour(
- local_public_key: PublicKey,
- protocol_version: String,
-) -> identify::Behaviour {
- use identify::{Behaviour, Config};
-
- Behaviour::new(
- Config::new(protocol_version, local_public_key).with_push_listen_addr_updates(true),
- )
-}
diff --git a/p2p/src/client.rs b/p2p/src/client.rs
deleted file mode 100644
index fe637ad9..00000000
--- a/p2p/src/client.rs
+++ /dev/null
@@ -1,358 +0,0 @@
-use eyre::Result;
-use libp2p::futures::StreamExt;
-use libp2p::swarm::{
- dial_opts::{DialOpts, PeerCondition},
- SwarmEvent,
-};
-use libp2p::{identify, noise, request_response, tcp, yamux};
-use libp2p::{Multiaddr, PeerId, Swarm, SwarmBuilder};
-use libp2p_identity::Keypair;
-use std::time::Duration;
-use tokio::sync::mpsc;
-
-use crate::behaviour::{DriaBehaviour, DriaBehaviourEvent};
-use crate::DriaP2PProtocol;
-
-use super::commands::DriaP2PCommand;
-use super::DriaP2PCommander;
-
-/// Buffer size for command channel.
-const COMMAND_CHANNEL_BUFSIZE: usize = 1024;
-/// Buffer size for events channel.
-const MSG_CHANNEL_BUFSIZE: usize = 1024;
-
-/// Request-response message type for Dria protocol, accepts bytes as both request and response.
-///
-/// The additional parsing must be done by the application itself (for now).
-pub type DriaReqResMessage = request_response::Message<Vec<u8>, Vec<u8>>;
-
-/// Peer-to-peer client for Dria Knowledge Network.
-pub struct DriaP2PClient {
- pub peer_id: PeerId,
- /// `Swarm` instance, everything p2p-related is accessed through this instance.
- swarm: Swarm<DriaBehaviour>,
- /// Dria protocol, used for identifying the client.
- protocol: DriaP2PProtocol,
- /// Request-response protocol messages.
- reqres_tx: mpsc::Sender<(PeerId, DriaReqResMessage)>,
- /// Command receiver.
- cmd_rx: mpsc::Receiver<DriaP2PCommand>,
-}
-
-impl DriaP2PClient {
- /// Creates a new P2P client with the given keypair and listen address.
- ///
- /// The `version` is used to create the protocol strings for the client, and it's very important that
- /// they match with the clients existing within the network.
- ///
- /// If for any reason the given `listen_addr` is not available, it will try to listen on a random port on `localhost`.
- #[allow(clippy::type_complexity)]
- pub fn new(
- keypair: Keypair,
- listen_addr: Multiaddr,
- rpc_addr: &Multiaddr,
- protocol: DriaP2PProtocol,
- ) -> Result<(
- DriaP2PClient,
- DriaP2PCommander,
- mpsc::Receiver<(PeerId, DriaReqResMessage)>,
- )> {
- let peer_id = keypair.public().to_peer_id();
-
- let mut swarm = SwarmBuilder::with_existing_identity(keypair)
- .with_tokio()
- .with_tcp(
- tcp::Config::default(),
- noise::Config::new,
- yamux::Config::default,
- )?
- .with_behaviour(|key| DriaBehaviour::new(key, &protocol))?
- // do not timeout at all, as we are only connected to an authority RPC at a given time and should stick to it
- .with_swarm_config(|c| c.with_idle_connection_timeout(Duration::from_secs(u64::MAX)))
- .build();
-
- // listen on all interfaces for incoming connections
- log::info!("Listening p2p network on: {listen_addr}");
- if let Err(err) = swarm.listen_on(listen_addr) {
- log::error!("Could not listen on address: {err:?}");
- log::warn!("Trying fallback address with localhost random port");
- swarm.listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap())?;
- }
-
- // dial rpc node, this will cause `identify` event to be called on their side
- log::info!("Dialing RPC node: {rpc_addr}");
- if let Err(err) = swarm.dial(rpc_addr.clone()) {
- log::error!("Could not dial RPC node: {err:?}");
- };
-
- // create commander
- let (cmd_tx, cmd_rx) = mpsc::channel(COMMAND_CHANNEL_BUFSIZE);
- let commander = DriaP2PCommander::new(cmd_tx, protocol.clone());
-
- // create p2p client itself
- let (reqres_tx, reqres_rx) = mpsc::channel(MSG_CHANNEL_BUFSIZE);
-
- let client = Self {
- peer_id,
- swarm,
- protocol,
- reqres_tx,
- cmd_rx,
- };
-
- Ok((client, commander, reqres_rx))
- }
-
- /// Waits for swarm events and Node commands at the same time.
- ///
- /// To terminate, the command channel must be closed.
- pub async fn run(mut self) {
- loop {
- tokio::select! {
- command = self.cmd_rx.recv() => match command {
- Some(c) => self.handle_command(c).await,
- // channel closed, thus shutting down the network event loop
- None=> {
- log::info!("Closing peer-to-peer client.");
- return
- },
- },
- event = self.swarm.select_next_some() => self.handle_event(event).await,
- }
- }
- }
-
- /// Handles a single command, which originates from `DriaP2PCommander`.
- pub async fn handle_command(&mut self, command: DriaP2PCommand) {
- match command {
- DriaP2PCommand::Dial {
- peer_id,
- address,
- sender,
- } => {
- let opts = DialOpts::peer_id(peer_id)
- .addresses(vec![address])
- .condition(PeerCondition::Always)
- .build();
- let _ = sender.send(self.swarm.dial(opts));
- }
- DriaP2PCommand::IsConnected { peer_id, sender } => {
- let _ = sender.send(self.swarm.is_connected(&peer_id));
- }
- DriaP2PCommand::NetworkInfo { sender } => {
- let _ = sender.send(self.swarm.network_info());
- }
- DriaP2PCommand::Respond {
- data,
- channel,
- sender,
- } => {
- let _ = sender.send(
- self.swarm
- .behaviour_mut()
- .request_response
- .send_response(channel, data)
- .map_err(|_| eyre::eyre!("could not send response, channel is closed?")),
- );
- }
- DriaP2PCommand::Request {
- data,
- peer_id,
- sender,
- } => {
- let _ = sender.send(
- self.swarm
- .behaviour_mut()
- .request_response
- .send_request(&peer_id, data),
- );
- }
- DriaP2PCommand::Shutdown { sender } => {
- // close the command channel
- self.cmd_rx.close();
-
- let _ = sender.send(());
- }
- }
- }
-
- /// Handles a single event from the `swarm` stream.
- pub async fn handle_event(&mut self, event: SwarmEvent<DriaBehaviourEvent>) {
- match event {
- /*****************************************
- * Request-response events *
- *****************************************/
- SwarmEvent::Behaviour(DriaBehaviourEvent::RequestResponse(
- request_response::Event::Message { message, peer, .. },
- )) => {
- // whether it's a request or response, we forward it to the main thread
- if let Err(err) = self.reqres_tx.send((peer, message)).await {
- log::error!("Could not transfer request {err:?}");
- }
- }
-
- SwarmEvent::Behaviour(DriaBehaviourEvent::RequestResponse(
- request_response::Event::ResponseSent {
- peer, request_id, ..
- },
- )) => {
- log::debug!("Request-Response: response ({request_id}) sent to peer {peer} with",)
- }
- SwarmEvent::Behaviour(DriaBehaviourEvent::RequestResponse(
- request_response::Event::OutboundFailure {
- peer,
- request_id,
- error,
- ..
- },
- )) => {
- log::error!(
- "Request-Response: Outbound failure to peer {peer} with request_id {request_id}: {error:?}",
- );
- }
- SwarmEvent::Behaviour(DriaBehaviourEvent::RequestResponse(
- request_response::Event::InboundFailure {
- peer,
- request_id,
- error,
- ..
- },
- )) => {
- log::error!(
- "Request-Response: Inbound failure to {peer} with request_id {request_id}: {error:?}"
- );
- }
-
- /*****************************************
- * Identify events *
- *****************************************/
- SwarmEvent::Behaviour(DriaBehaviourEvent::Identify(identify::Event::Received {
- peer_id,
- info,
- ..
- })) => {
- if info.protocol_version != self.protocol.identity {
- log::warn!(
- "Identify: Peer {} has different Identify protocol: (them {}, you {})",
- peer_id,
- info.protocol_version,
- self.protocol.identity
- );
-
- // disconnect them
- let _ = self.swarm.disconnect_peer_id(peer_id);
- }
- }
-
- /*****************************************
- * Connection events and errors handling *
- *****************************************/
- SwarmEvent::NewListenAddr { address, .. } => {
- log::warn!("Local node is listening on {address}");
- }
- SwarmEvent::NewExternalAddrOfPeer { peer_id, address } => {
- log::info!("External address of peer {peer_id} confirmed: {address}");
- }
- SwarmEvent::ExternalAddrConfirmed { address } => {
- log::info!("External address confirmed: {address}");
- }
-
- SwarmEvent::IncomingConnectionError {
- local_addr,
- send_back_addr,
- error,
- ..
- } => {
- log::debug!(
- "Incoming connection error: from {local_addr} to {send_back_addr} - {error:?}"
- );
- }
- SwarmEvent::IncomingConnection {
- local_addr,
- send_back_addr,
- ..
- } => {
- log::debug!("Incoming connection attempt: from {local_addr} to {send_back_addr}");
- }
-
- SwarmEvent::OutgoingConnectionError { peer_id, error, .. } => {
- if let Some(peer_id) = peer_id {
- log::warn!("Could not connect to peer {peer_id}: {error:?}");
- } else {
- log::warn!("Outgoing connection error: {error:?}");
- }
- }
-
- SwarmEvent::ConnectionEstablished {
- peer_id,
- connection_id,
- endpoint,
- ..
- } => {
- if endpoint.is_dialer() {
- // we only care about logs about the ones that we have dialed
- log::info!(
- "Connection ({connection_id}) established with {peer_id} at {}",
- endpoint.get_remote_address()
- );
- } else {
- log::debug!(
- "Connection ({connection_id}) established with {peer_id} from {}",
- endpoint.get_remote_address()
- );
- }
- }
-
- SwarmEvent::ConnectionClosed {
- peer_id,
- connection_id,
- endpoint,
- cause,
- ..
- } => {
- // we only care about the connections that we have dialed
- if endpoint.is_dialer() {
- // if we know the cause, it may be a good idea to re-dial
- if let Some(cause) = cause {
- log::warn!(
- "Connection ({connection_id}) closed for {peer_id} due to {cause}"
- );
-
- let addr = endpoint.get_remote_address();
- log::info!("Dialing {peer_id} again at {addr}");
- if let Err(err) = self.swarm.dial(
- DialOpts::peer_id(peer_id)
- .addresses(vec![addr.clone()])
- .condition(PeerCondition::DisconnectedAndNotDialing)
- .build(),
- ) {
- log::error!("Could not dial peer {peer_id}: {err:?}");
- }
- } else {
- // if we don't know the cause, we don't want to re-dial,
- // because the cause is `None` if the other side closed the connection manually
- log::warn!(
- "Connection ({connection_id}) closed for {peer_id} without a cause, will not re-dial!"
- );
- }
- } else {
- log::debug!("Connection ({connection_id}) closed for {peer_id}: {cause:?}",);
- }
- }
-
- SwarmEvent::ExpiredListenAddr {
- address,
- listener_id,
- } => {
- // this may happen when your connection is lost, e.g. you turn off your machine / internet
- log::warn!("Listener ({listener_id}) expired: {address}");
- }
-
- SwarmEvent::ListenerError { listener_id, error } => {
- log::error!("Listener ({listener_id}) failed: {error}");
- }
-
- event => log::debug!("Unhandled Swarm Event: {event:?}"),
- }
- }
-}
diff --git a/p2p/src/commands.rs b/p2p/src/commands.rs
deleted file mode 100644
index 2a1a344a..00000000
--- a/p2p/src/commands.rs
+++ /dev/null
@@ -1,154 +0,0 @@
-use eyre::{Context, Result};
-use libp2p::{request_response, swarm, Multiaddr, PeerId};
-use tokio::sync::{mpsc, oneshot};
-
-use crate::DriaP2PProtocol;
-
-#[derive(Debug)]
-pub enum DriaP2PCommand {
- /// Returns the network information, such as the number of incoming and outgoing connections.
- NetworkInfo {
- sender: oneshot::Sender<swarm::NetworkInfo>,
- },
- /// Check if there is an active connection to the given peer.
- IsConnected {
- peer_id: PeerId,
- sender: oneshot::Sender<bool>,
- },
- /// Dial a known peer.
- Dial {
- peer_id: PeerId,
- address: Multiaddr,
- sender: oneshot::Sender<Result<(), swarm::DialError>>,
- },
- /// Respond to a request-response message.
- Respond {
- data: Vec<u8>,
- channel: request_response::ResponseChannel<Vec<u8>>,
- sender: oneshot::Sender<Result<()>>,
- },
- /// Request a request-response message.
- /// Note that you are likely to be caught by the RPC peer id check,
- /// and your messages will be ignored.
- Request {
- peer_id: PeerId,
- data: Vec<u8>,
- sender: oneshot::Sender<request_response::OutboundRequestId>,
- },
- /// Shuts down the client, closes the command channel.
- Shutdown { sender: oneshot::Sender<()> },
-}
-
-pub struct DriaP2PCommander {
- sender: mpsc::Sender<DriaP2PCommand>,
- protocol: DriaP2PProtocol,
-}
-
-impl DriaP2PCommander {
- pub fn new(sender: mpsc::Sender<DriaP2PCommand>, protocol: DriaP2PProtocol) -> Self {
- Self { sender, protocol }
- }
-
- /// Returns a reference to the protocol.
- pub fn protocol(&self) -> &DriaP2PProtocol {
- &self.protocol
- }
-
- /// Returns the network information, such as the number of
- /// incoming and outgoing connections.
- pub async fn network_info(&self) -> Result<swarm::NetworkInfo> {
- let (sender, receiver) = oneshot::channel();
-
- self.sender
- .send(DriaP2PCommand::NetworkInfo { sender })
- .await
- .wrap_err("could not send")?;
-
- receiver.await.wrap_err("could not receive")
- }
-
- pub async fn respond(
- &mut self,
- data: Vec<u8>,
- channel: request_response::ResponseChannel<Vec<u8>>,
- ) -> Result<()> {
- let (sender, receiver) = oneshot::channel();
-
- self.sender
- .send(DriaP2PCommand::Respond {
- data,
- channel,
- sender,
- })
- .await
- .wrap_err("could not send")?;
-
- receiver
- .await
- .wrap_err("could not receive")?
- .wrap_err("could not respond")
- }
-
- pub async fn request(
- &mut self,
- peer_id: PeerId,
- data: impl Into<Vec<u8>>,
- ) -> Result<request_response::OutboundRequestId> {
- let data = data.into();
- let (sender, receiver) = oneshot::channel();
-
- self.sender
- .send(DriaP2PCommand::Request {
- data,
- peer_id,
- sender,
- })
- .await
- .wrap_err("could not send")?;
-
- receiver.await.wrap_err("could not receive")
- }
-
- /// Dials a given peer.
- pub async fn dial(&mut self, peer_id: PeerId, address: Multiaddr) -> Result<()> {
- let (sender, receiver) = oneshot::channel();
-
- self.sender
- .send(DriaP2PCommand::Dial {
- peer_id,
- address,
- sender,
- })
- .await
- .wrap_err("could not send")?;
-
- receiver
- .await
- .wrap_err("could not receive")?
- .wrap_err("could not dial")
- }
-
- /// Checks if there is an active connection to the given peer.
- pub async fn is_connected(&mut self, peer_id: PeerId) -> Result<bool> {
- let (sender, receiver) = oneshot::channel();
-
- self.sender
- .send(DriaP2PCommand::IsConnected { peer_id, sender })
- .await
- .wrap_err("could not send")?;
-
- receiver.await.wrap_err("could not receive")
- }
-
- /// Sends a shutdown signal to the client.
- pub async fn shutdown(&mut self) -> Result<()> {
- let (sender, receiver) = oneshot::channel();
-
- self.sender
- .send(DriaP2PCommand::Shutdown { sender })
- .await
- .wrap_err("could not send")?;
-
- receiver.await.wrap_err("could not receive")
- }
-}
diff --git a/p2p/src/lib.rs b/p2p/src/lib.rs
deleted file mode 100644
index 81383226..00000000
--- a/p2p/src/lib.rs
+++ /dev/null
@@ -1,14 +0,0 @@
-mod behaviour;
-
-mod client;
-pub use client::{DriaP2PClient, DriaReqResMessage};
-
-mod commands;
-pub use commands::{DriaP2PCommand, DriaP2PCommander};
-
-mod protocol;
-pub use protocol::DriaP2PProtocol;
-
-// re-exports
-pub use libp2p;
-pub use libp2p_identity;
diff --git a/p2p/src/protocol.rs b/p2p/src/protocol.rs
deleted file mode 100644
index 52781b9f..00000000
--- a/p2p/src/protocol.rs
+++ /dev/null
@@ -1,104 +0,0 @@
-use libp2p::StreamProtocol;
-use std::env;
-
-#[derive(Clone, Debug)]
-pub struct DriaP2PProtocol {
- /// Main protocol name, e.g. `dria`.
- pub name: String,
- /// Version of the protocol, e.g. `0.2`.
- /// By default, this is set to the current `major.minor` version of the crate.
- pub version: String,
- /// Identity protocol string to be used for the Identity behaviour.
- ///
- /// This is usually `{name}/{version}`.
- pub identity: String,
- /// Request-response protocol, must match with other peers in the network.
- ///
- /// This is usually `/{name}/rr/{version}`, notice the `/` at the start
- /// which is mandatory for a `StreamProtocol`.
- ///
- pub request_response: StreamProtocol,
-}
-
-impl std::fmt::Display for DriaP2PProtocol {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- write!(f, "{}", self.identity)
- }
-}
-
-impl Default for DriaP2PProtocol {
- /// Creates a new instance of the protocol with the default name `dria`.
- fn default() -> Self {
- Self::new_major_minor("dria")
- }
-}
-
-impl DriaP2PProtocol {
- /// Creates a new instance of the protocol with the given `name` and `version`.
- pub fn new(name: impl ToString, version: impl ToString) -> Self {
- let name = name.to_string();
- let version = version.to_string();
-
- let identity = format!("{name}/{version}");
- let request_response =
- StreamProtocol::try_from_owned(format!("/{name}/rr/{version}")).unwrap();
-
- Self {
- name,
- version,
- identity,
- request_response,
- }
- }
-
- /// Creates a new instance of the protocol with the given `name` and the current version as per Cargo.toml.
- /// The version is represented with `major.minor` version numbers.
- pub fn new_major_minor(name: &str) -> Self {
- const VERSION: &str = concat!(
- env!("CARGO_PKG_VERSION_MAJOR"),
- ".",
- env!("CARGO_PKG_VERSION_MINOR")
- );
-
- Self::new(name, VERSION)
- }
-
- /// Returns the identity protocol, e.g. `dria/0.2`.
- pub fn identity(&self) -> String {
- self.identity.clone()
- }
-
- /// Returns the request-response protocol, e.g. `/dria/rr/0.2`.
- pub fn request_response(&self) -> StreamProtocol {
- self.request_response.clone()
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- #[test]
- fn test_new() {
- let protocol = DriaP2PProtocol::new("test", "1.0");
- assert_eq!(protocol.name, "test");
- assert_eq!(protocol.version, "1.0");
- assert_eq!(protocol.identity, "test/1.0");
- assert_eq!(protocol.request_response.to_string(), "/test/rr/1.0");
- }
-
- #[test]
- fn test_new_major_minor() {
- let protocol = DriaP2PProtocol::new_major_minor("test");
- assert_eq!(protocol.name, "test");
- assert_eq!(
- protocol.version,
- concat!(
- env!("CARGO_PKG_VERSION_MAJOR"),
- ".",
- env!("CARGO_PKG_VERSION_MINOR")
- )
- );
- assert_eq!(protocol.identity, format!("test/{}", protocol.version));
- }
-}
diff --git a/p2p/tests/request_test.rs b/p2p/tests/request_test.rs
deleted file mode 100644
index ce8c3132..00000000
--- a/p2p/tests/request_test.rs
+++ /dev/null
@@ -1,64 +0,0 @@
-use std::str::FromStr;
-use std::thread::sleep;
-use std::time::Duration;
-
-use dkn_p2p::{DriaP2PClient, DriaP2PProtocol};
-use eyre::Result;
-use libp2p::PeerId;
-use libp2p_identity::Keypair;
-
-/// Makes a dummy request to some peer hardcoded within the test.
-///
-/// ## Run command
-///
-/// ```sh
-/// cargo test --package dkn-p2p --test request_test --all-features -- test_request_message --exact --show-output --ignored
-/// ```
-#[tokio::test]
-#[ignore = "run this manually"]
-async fn test_request_message() -> Result<()> {
- let _ = env_logger::builder()
- .filter_level(log::LevelFilter::Off)
- .filter_module("request_test", log::LevelFilter::Debug)
- .filter_module("dkn_p2p", log::LevelFilter::Debug)
- .is_test(true)
- .try_init();
-
- // prepare nodes
- let rpc_addr = "your-rpc-here".parse().unwrap();
-
- // spawn P2P client in another task
- let (client, mut commander, mut req_rx) = DriaP2PClient::new(
- Keypair::generate_secp256k1(),
- "/ip4/127.0.0.1/tcp/0".parse().unwrap(),
- &rpc_addr,
- DriaP2PProtocol::default(),
- )
- .expect("could not create p2p client");
-
- // spawn task
- let task_handle = tokio::spawn(async move { client.run().await });
-
- log::info!("Waiting a bit until we have enough peers");
- sleep(Duration::from_secs(10));
-
- let peer_id =
- PeerId::from_str("16Uiu2HAmB5HGdwLNHX81u7ey1fvDx5Mr4ofa2PdSSVxFKrrcErAN").unwrap();
- log::info!("Making a request to peer: {}", peer_id);
- commander.request(peer_id, b"here is some data").await?;
-
- log::info!("Waiting for response logs for a few moments...");
- sleep(Duration::from_secs(5));
-
- // close command channel
- commander.shutdown().await.expect("could not shutdown");
-
- // close other channels
- req_rx.close();
-
- log::info!("Waiting for p2p task to finish...");
- task_handle.await?;
-
- log::info!("Done!");
- Ok(())
-}
diff --git a/scripts/install.ps1 b/scripts/install.ps1
new file mode 100644
index 00000000..8bc189e7
--- /dev/null
+++ b/scripts/install.ps1
@@ -0,0 +1,61 @@
+# Dria Node installer for Windows
+# Usage: irm https://raw.githubusercontent.com/firstbatchxyz/dkn-compute-node/v2/scripts/install.ps1 | iex
+$ErrorActionPreference = "Stop"
+
+$Repo = "firstbatchxyz/dkn-compute-node"
+$Binary = "dria-node"
+$InstallDir = "$env:LOCALAPPDATA\dria"
+
+Write-Host "Dria Node Installer" -ForegroundColor Cyan
+
+# Fetch latest release
+Write-Host "Fetching latest release..." -ForegroundColor Blue
+try {
+ $Release = Invoke-RestMethod -Uri "https://api.github.com/repos/$Repo/releases/latest"
+ $Tag = $Release.tag_name
+} catch {
+ Write-Host "Error: Failed to fetch latest release. Check your internet connection." -ForegroundColor Red
+ exit 1
+}
+
+Write-Host "Latest release: $Tag" -ForegroundColor Blue
+
+# Download binary
+$Asset = "$Binary-windows-amd64.exe"
+$Url = "https://github.com/$Repo/releases/download/$Tag/$Asset"
+
+Write-Host "Downloading $Asset..." -ForegroundColor Blue
+$TmpFile = Join-Path $env:TEMP "$Binary.exe"
+try {
+ Invoke-WebRequest -Uri $Url -OutFile $TmpFile -UseBasicParsing
+} catch {
+ Write-Host "Error: Download failed. Asset may not exist: $Url" -ForegroundColor Red
+ exit 1
+}
+
+# Install
+if (-not (Test-Path $InstallDir)) {
+ New-Item -ItemType Directory -Path $InstallDir -Force | Out-Null
+}
+$Dest = Join-Path $InstallDir "$Binary.exe"
+Move-Item -Path $TmpFile -Destination $Dest -Force
+Write-Host "Installed to $Dest" -ForegroundColor Blue
+
+# Add to PATH if not present
+$UserPath = [Environment]::GetEnvironmentVariable("PATH", "User")
+if ($UserPath -notlike "*$InstallDir*") {
+ [Environment]::SetEnvironmentVariable("PATH", "$InstallDir;$UserPath", "User")
+ $env:PATH = "$InstallDir;$env:PATH"
+ Write-Host "Added $InstallDir to user PATH." -ForegroundColor Blue
+ Write-Host "Restart your terminal for PATH changes to take effect." -ForegroundColor Yellow
+}
+
+# Verify
+Write-Host ""
+try {
+ $Version = & $Dest --version 2>&1
+ Write-Host "Successfully installed $Version" -ForegroundColor Green
+} catch {
+ Write-Host "Installed successfully. Run '$Binary --version' to verify." -ForegroundColor Green
+}
+Write-Host "Run '$Binary start --help' to get started." -ForegroundColor Cyan
diff --git a/scripts/install.sh b/scripts/install.sh
new file mode 100755
index 00000000..57e5ac1a
--- /dev/null
+++ b/scripts/install.sh
@@ -0,0 +1,91 @@
+#!/usr/bin/env bash
+# Dria Node installer for macOS and Linux
+# Usage: curl -sSL https://raw.githubusercontent.com/firstbatchxyz/dkn-compute-node/v2/scripts/install.sh | bash
+set -euo pipefail
+
+REPO="firstbatchxyz/dkn-compute-node"
+BINARY="dria-node"
+
+info() { printf '\033[1;34m%s\033[0m\n' "$*"; }
+error() { printf '\033[1;31mError: %s\033[0m\n' "$*" >&2; exit 1; }
+
+# Detect OS
+case "$(uname -s)" in
+ Darwin) OS="macOS" ;;
+ Linux) OS="linux" ;;
+ *) error "Unsupported OS: $(uname -s). Use Windows installer for Windows." ;;
+esac
+
+# Detect architecture
+case "$(uname -m)" in
+ x86_64|amd64) ARCH="amd64" ;;
+ aarch64|arm64) ARCH="arm64" ;;
+ *) error "Unsupported architecture: $(uname -m)" ;;
+esac
+
+# On Linux x86_64, check for AVX2 support and fall back to noavx if missing
+if [ "$OS" = "linux" ] && [ "$ARCH" = "amd64" ]; then
+ if ! grep -q avx2 /proc/cpuinfo 2>/dev/null; then
+ ARCH="amd64-noavx"
+ info "CPU does not support AVX2, using baseline binary."
+ fi
+fi
+
+info "Detected: ${OS} ${ARCH}"
+
+# Fetch latest release tag
+info "Fetching latest release..."
+LATEST=$(curl -sSf "https://api.github.com/repos/${REPO}/releases/latest" \
+ | grep '"tag_name"' | head -1 | cut -d'"' -f4) \
+ || error "Failed to fetch latest release. Check your internet connection."
+
+[ -z "$LATEST" ] && error "Could not determine latest release tag."
+info "Latest release: ${LATEST}"
+
+# Download binary
+ASSET="${BINARY}-${OS}-${ARCH}"
+URL="https://github.com/${REPO}/releases/download/${LATEST}/${ASSET}"
+
+info "Downloading ${ASSET}..."
+TMPDIR=$(mktemp -d)
+trap 'rm -rf "$TMPDIR"' EXIT
+
+curl -sSfL -o "${TMPDIR}/${BINARY}" "$URL" \
+ || error "Download failed. Asset may not exist for your platform: ${URL}"
+
+chmod +x "${TMPDIR}/${BINARY}"
+
+# Install
+if [ -w "/usr/local/bin" ]; then
+ INSTALL_DIR="/usr/local/bin"
+elif [ "$(id -u)" = "0" ]; then
+ INSTALL_DIR="/usr/local/bin"
+else
+ INSTALL_DIR="${HOME}/.local/bin"
+ mkdir -p "$INSTALL_DIR"
+fi
+
+mv "${TMPDIR}/${BINARY}" "${INSTALL_DIR}/${BINARY}"
+info "Installed to ${INSTALL_DIR}/${BINARY}"
+
+# Check if install dir is in PATH
+case ":${PATH}:" in
+ *":${INSTALL_DIR}:"*) ;;
+ *)
+ info ""
+ info "WARNING: ${INSTALL_DIR} is not in your PATH."
+ info "Add it by running:"
+ info " export PATH=\"${INSTALL_DIR}:\$PATH\""
+ info "Or add that line to your ~/.bashrc / ~/.zshrc"
+ ;;
+esac
+
+# Verify
+if command -v "$BINARY" &>/dev/null; then
+ info ""
+ info "Successfully installed $(${BINARY} --version)"
+ info "Run '${BINARY} start --help' to get started."
+else
+ info ""
+ info "Installation complete. Run '${INSTALL_DIR}/${BINARY} --version' to verify."
+fi
diff --git a/src/config.rs b/src/config.rs
new file mode 100644
index 00000000..b143db07
--- /dev/null
+++ b/src/config.rs
@@ -0,0 +1,329 @@
+use std::path::PathBuf;
+
+use clap::{Parser, Subcommand};
+
+use crate::error::NodeError;
+
+#[derive(Parser)]
+#[command(name = "dria-node", version, about = "Dria Compute Node")]
+pub struct Cli {
+ #[command(subcommand)]
+ pub command: Command,
+}
+
+#[derive(Subcommand)]
+pub enum Command {
+ /// Interactive setup: pick a model, download it, and run a test
+ Setup {
+ /// Data directory
+ #[arg(long, env = "DRIA_DATA_DIR")]
+ data_dir: Option<PathBuf>,
+
+ /// Number of GPU layers to offload (-1 = all, 0 = CPU only)
+ #[arg(long, env = "DRIA_GPU_LAYERS", default_value = "0")]
+ gpu_layers: i32,
+ },
+
+ /// Start the compute node
+ Start {
+ /// Wallet secret key (hex-encoded, 32 bytes)
+ #[arg(long, env = "DRIA_WALLET")]
+ wallet: String,
+
+ /// Model(s) to serve (comma-separated shortnames, e.g. "qwen3.5:9b,lfm2.5:1.2b")
+ #[arg(long, env = "DRIA_MODELS")]
+ model: String,
+
+ /// Router URL for task coordination
+ #[arg(long, env = "DRIA_ROUTER_URL", default_value = "quic.dria.co:4001")]
+ router_url: String,
+
+ /// Number of GPU layers to offload (-1 = all, 0 = CPU only)
+ #[arg(long, env = "DRIA_GPU_LAYERS", default_value = "0")]
+ gpu_layers: i32,
+
+ /// Maximum concurrent inference requests
+ #[arg(long, env = "DRIA_MAX_CONCURRENT", default_value = "1")]
+ max_concurrent: usize,
+
+ /// Data directory
+ #[arg(long, env = "DRIA_DATA_DIR")]
+ data_dir: Option<PathBuf>,
+
+ /// Override GGUF quantization (e.g. Q8_0, Q5_K_M, Q6_K). Defaults to the registry value (usually Q4_K_M).
+ #[arg(long, env = "DRIA_QUANT")]
+ quant: Option<String>,
+
+ /// Skip TLS certificate verification (for development/testing)
+ #[arg(long, env = "DRIA_INSECURE")]
+ insecure: bool,
+
+ /// Skip automatic update check on startup
+ #[arg(long, env = "DRIA_SKIP_UPDATE")]
+ skip_update: bool,
+
+ /// Maximum context window size (tokens). When set, engines use min(model_native, this value).
+ /// When unset, engines use the model's full native context window.
+ #[arg(long, env = "DRIA_CONTEXT_SIZE")]
+ context_size: Option,
+
+ /// KV cache quantization type (q8_0, q4_0, f16). Default: q8_0 (halves KV memory vs f16).
+ #[arg(long, env = "DRIA_KV_QUANT", default_value = "q8_0")]
+ kv_quant: String,
+ },
+}
+
+/// Parsed and validated configuration for the node.
+pub struct Config {
+ pub secret_key_hex: String,
+ pub model_names: Vec<String>,
+ pub router_urls: Vec<String>,
+ pub gpu_layers: i32,
+ pub max_concurrent: usize,
+ pub data_dir: PathBuf,
+ pub models_dir: PathBuf,
+ pub quant: Option<String>,
+ pub insecure: bool,
+ pub skip_update: bool,
+ pub max_context: Option,
+ pub kv_quant: String,
+}
+
+impl Config {
+ /// Create a Config from the `start` subcommand arguments.
+ #[allow(clippy::too_many_arguments)]
+ pub fn from_start_args(
+ wallet: String,
+ model: String,
+ router_url: String,
+ gpu_layers: i32,
+ max_concurrent: usize,
+ data_dir: Option<PathBuf>,
+ quant: Option<String>,
+ insecure: bool,
+ skip_update: bool,
+ max_context: Option,
+ kv_quant: String,
+ ) -> Result<Self, NodeError> {
+ // Validate wallet key
+ let secret_key_hex = wallet.strip_prefix("0x").unwrap_or(&wallet).to_string();
+ if secret_key_hex.len() != 64 {
+ return Err(NodeError::Config(format!(
+ "wallet secret key must be 64 hex chars, got {}",
+ secret_key_hex.len()
+ )));
+ }
+ hex::decode(&secret_key_hex)
+ .map_err(|e| NodeError::Config(format!("wallet key is not valid hex: {e}")))?;
+
+ // Parse model names
+ let model_names: Vec<String> = model
+ .split(',')
+ .map(|s| s.trim().to_string())
+ .filter(|s| !s.is_empty())
+ .collect();
+ if model_names.is_empty() {
+ return Err(NodeError::Config("at least one model must be specified".into()));
+ }
+
+ // Resolve data directory
+ let data_dir = match data_dir {
+ Some(d) => d,
+ None => dirs::home_dir()
+ .ok_or_else(|| NodeError::Config("could not determine home directory".into()))?
+ .join(".dria"),
+ };
+ let models_dir = data_dir.join("models");
+
+ if max_concurrent == 0 {
+ return Err(NodeError::Config("max-concurrent must be >= 1".into()));
+ }
+
+ // Parse router URLs (comma-separated)
+ let router_urls: Vec<String> = router_url
+ .split(',')
+ .map(|s| s.trim().to_string())
+ .filter(|s| !s.is_empty())
+ .collect();
+ if router_urls.is_empty() {
+ return Err(NodeError::Config("at least one router URL must be specified".into()));
+ }
+
+ Ok(Config {
+ secret_key_hex,
+ model_names,
+ router_urls,
+ gpu_layers,
+ max_concurrent,
+ data_dir,
+ models_dir,
+ quant,
+ insecure,
+ skip_update,
+ max_context,
+ kv_quant,
+ })
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_config_from_valid_args() {
+ let cfg = Config::from_start_args(
+ "0x6472696164726961647269616472696164726961647269616472696164726961".into(),
+ "qwen3.5:9b, lfm2.5:1.2b".into(),
+ "quic.dria.co:4001".into(),
+ 0,
+ 1,
+ Some("/tmp/dria-test".into()),
+ None,
+ false,
+ false,
+ None,
+ "q8_0".into(),
+ )
+ .unwrap();
+
+ assert_eq!(cfg.model_names, vec!["qwen3.5:9b", "lfm2.5:1.2b"]);
+ assert_eq!(
+ cfg.secret_key_hex,
+ "6472696164726961647269616472696164726961647269616472696164726961"
+ );
+ assert_eq!(cfg.models_dir, PathBuf::from("/tmp/dria-test/models"));
+ assert_eq!(cfg.router_urls, vec!["quic.dria.co:4001"]);
+ }
+
+ #[test]
+ fn test_config_invalid_wallet_length() {
+ let result = Config::from_start_args(
+ "0xabcd".into(),
+ "qwen3.5:9b".into(),
+ "quic.dria.co:4001".into(),
+ 0,
+ 1,
+ None,
+ None,
+ false,
+ false,
+ None,
+ "q8_0".into(),
+ );
+ assert!(result.is_err());
+ }
+
+ #[test]
+ fn test_config_invalid_wallet_hex() {
+ let result = Config::from_start_args(
+ "zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz".into(),
+ "qwen3.5:9b".into(),
+ "quic.dria.co:4001".into(),
+ 0,
+ 1,
+ None,
+ None,
+ false,
+ false,
+ None,
+ "q8_0".into(),
+ );
+ assert!(result.is_err());
+ }
+
+ #[test]
+ fn test_config_empty_model() {
+ let result = Config::from_start_args(
+ "6472696164726961647269616472696164726961647269616472696164726961".into(),
+ "".into(),
+ "quic.dria.co:4001".into(),
+ 0,
+ 1,
+ None,
+ None,
+ false,
+ false,
+ None,
+ "q8_0".into(),
+ );
+ assert!(result.is_err());
+ }
+
+ #[test]
+ fn test_config_zero_concurrency() {
+ let result = Config::from_start_args(
+ "6472696164726961647269616472696164726961647269616472696164726961".into(),
+ "qwen3.5:9b".into(),
+ "quic.dria.co:4001".into(),
+ 0,
+ 0,
+ None,
+ None,
+ false,
+ false,
+ None,
+ "q8_0".into(),
+ );
+ assert!(result.is_err());
+ }
+
+ #[test]
+ fn test_config_comma_separated_router_urls() {
+ let cfg = Config::from_start_args(
+ "6472696164726961647269616472696164726961647269616472696164726961".into(),
+ "qwen3.5:9b".into(),
+ "https://router1.dria.co, https://router2.dria.co".into(),
+ 0,
+ 1,
+ None,
+ None,
+ false,
+ false,
+ None,
+ "q8_0".into(),
+ )
+ .unwrap();
+ assert_eq!(
+ cfg.router_urls,
+ vec!["https://router1.dria.co", "https://router2.dria.co"]
+ );
+ }
+
+ #[test]
+ fn test_config_empty_router_url() {
+ let result = Config::from_start_args(
+ "6472696164726961647269616472696164726961647269616472696164726961".into(),
+ "qwen3.5:9b".into(),
+ "".into(),
+ 0,
+ 1,
+ None,
+ None,
+ false,
+ false,
+ None,
+ "q8_0".into(),
+ );
+ assert!(result.is_err());
+ }
+
+ #[test]
+ fn test_config_insecure_flag() {
+ let cfg = Config::from_start_args(
+ "6472696164726961647269616472696164726961647269616472696164726961".into(),
+ "qwen3.5:9b".into(),
+ "quic.dria.co:4001".into(),
+ 0,
+ 1,
+ None,
+ None,
+ true,
+ false,
+ None,
+ "q8_0".into(),
+ )
+ .unwrap();
+ assert!(cfg.insecure);
+ }
+}
diff --git a/src/error.rs b/src/error.rs
new file mode 100644
index 00000000..440093a1
--- /dev/null
+++ b/src/error.rs
@@ -0,0 +1,31 @@
+use thiserror::Error;
+
+#[derive(Debug, Error)]
+pub enum NodeError {
+ #[error("config error: {0}")]
+ Config(String),
+
+ #[error("identity error: {0}")]
+ Identity(String),
+
+ #[error("inference error: {0}")]
+ Inference(String),
+
+ #[error("model error: {0}")]
+ Model(String),
+
+ #[error("network error: {0}")]
+ Network(String),
+
+ #[error("update error: {0}")]
+ Update(String),
+
+ #[error("io error: {0}")]
+ Io(#[from] std::io::Error),
+}
+
+impl From<dkn_protocol::ProtocolError> for NodeError {
+ fn from(e: dkn_protocol::ProtocolError) -> Self {
+ NodeError::Network(e.to_string())
+ }
+}
diff --git a/src/identity.rs b/src/identity.rs
new file mode 100644
index 00000000..2b755369
--- /dev/null
+++ b/src/identity.rs
@@ -0,0 +1,116 @@
+use libsecp256k1::{PublicKey, SecretKey};
+use sha2::{Digest as _, Sha256};
+use sha3::Keccak256;
+
+use crate::error::NodeError;
+
+/// Node identity derived from a secp256k1 secret key.
+/// The address is an Ethereum-style address (last 20 bytes of keccak256 of uncompressed pubkey).
+pub struct Identity {
+ pub secret_key: SecretKey,
+ #[allow(dead_code)]
+ pub public_key: PublicKey,
+ #[allow(dead_code)]
+ pub address: [u8; 20],
+ pub address_hex: String,
+}
+
+impl Identity {
+ /// Create an identity from a hex-encoded secret key (without 0x prefix).
+ pub fn from_secret_hex(hex_str: &str) -> Result<Self, NodeError> {
+ let bytes = hex::decode(hex_str)
+ .map_err(|e| NodeError::Identity(format!("invalid hex: {e}")))?;
+ let secret_key = SecretKey::parse_slice(&bytes)
+ .map_err(|e| NodeError::Identity(format!("invalid secret key: {e}")))?;
+ let public_key = PublicKey::from_secret_key(&secret_key);
+ let address = public_key_to_address(&public_key);
+ let address_hex = hex::encode(address);
+
+ Ok(Identity {
+ secret_key,
+ public_key,
+ address,
+ address_hex,
+ })
+ }
+
+ /// Sign a SHA-256 digest of the given message.
+ /// Returns (signature, recovery_id).
+ pub fn sign(&self, message: &[u8]) -> (libsecp256k1::Signature, libsecp256k1::RecoveryId) {
+ let digest = sha256hash(message);
+ let msg = libsecp256k1::Message::parse_slice(&digest)
+ .expect("SHA-256 output is always 32 bytes");
+ libsecp256k1::sign(&msg, &self.secret_key)
+ }
+}
+
+/// SHA-256 hash.
+#[inline(always)]
+pub fn sha256hash(data: impl AsRef<[u8]>) -> [u8; 32] {
+ Sha256::digest(data).into()
+}
+
+/// Keccak-256 hash.
+#[inline(always)]
+pub fn keccak256hash(data: impl AsRef<[u8]>) -> [u8; 32] {
+ Keccak256::digest(data).into()
+}
+
+/// Derive an Ethereum address from a secp256k1 public key.
+/// Serializes uncompressed (65 bytes: 0x04 || x || y), hashes (x || y) with keccak256,
+/// and takes the last 20 bytes.
+#[inline]
+fn public_key_to_address(public_key: &PublicKey) -> [u8; 20] {
+ let public_key_xy = &public_key.serialize()[1..];
+ let mut addr = [0u8; 20];
+ addr.copy_from_slice(&keccak256hash(public_key_xy)[12..32]);
+ addr
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ const DUMMY_SECRET_KEY: &[u8; 32] = b"driadriadriadriadriadriadriadria";
+
+ #[test]
+ fn test_sha256() {
+ let expected = "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9";
+ assert_eq!(hex::encode(sha256hash(b"hello world")), expected);
+ }
+
+ #[test]
+ fn test_address_from_secret() {
+ let hex_key = hex::encode(DUMMY_SECRET_KEY);
+ let identity = Identity::from_secret_hex(&hex_key).unwrap();
+ assert_eq!(
+ identity.address_hex,
+ "d79fdf178547614cfdd0df6397c53569716bd596"
+ );
+ }
+
+ #[test]
+ fn test_sign_and_recover() {
+ let hex_key = hex::encode(DUMMY_SECRET_KEY);
+ let identity = Identity::from_secret_hex(&hex_key).unwrap();
+
+ let message = b"hello world";
+ let (signature, recid) = identity.sign(message);
+
+ // Recover public key from signature
+ let digest = sha256hash(message);
+ let msg = libsecp256k1::Message::parse_slice(&digest).unwrap();
+ let recovered = libsecp256k1::recover(&msg, &signature, &recid).unwrap();
+ assert_eq!(recovered, identity.public_key);
+ }
+
+ #[test]
+ fn test_invalid_hex() {
+ assert!(Identity::from_secret_hex("not-hex").is_err());
+ }
+
+ #[test]
+ fn test_invalid_key_length() {
+ assert!(Identity::from_secret_hex("abcd").is_err());
+ }
+}
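Review note: `test_invalid_hex` and `test_invalid_key_length` exercise the two failure paths of `from_secret_hex` (bad hex, then wrong byte length). The sketch below reproduces only those two checks with the standard library. `decode_secret_hex` is a hypothetical name; the real code delegates to `hex::decode` and `SecretKey::parse_slice`, and the latter also rejects scalars that are not valid secp256k1 keys, which this sketch does not attempt.

```rust
/// Hypothetical std-only stand-in for hex decoding plus a 32-byte length check.
fn decode_secret_hex(hex_str: &str) -> Result<[u8; 32], String> {
    // Reject anything that is not a hex digit (this also guarantees ASCII,
    // so the byte-index slicing below cannot split a character).
    if !hex_str.bytes().all(|b| b.is_ascii_hexdigit()) {
        return Err("invalid hex: non-hex character".to_string());
    }
    if hex_str.len() % 2 != 0 {
        return Err("invalid hex: odd number of digits".to_string());
    }

    let mut bytes = Vec::with_capacity(hex_str.len() / 2);
    for i in (0..hex_str.len()).step_by(2) {
        let byte = u8::from_str_radix(&hex_str[i..i + 2], 16)
            .map_err(|e| format!("invalid hex: {e}"))?;
        bytes.push(byte);
    }

    // A secp256k1 secret key is exactly 32 bytes.
    <[u8; 32]>::try_from(bytes.as_slice())
        .map_err(|_| format!("invalid secret key: expected 32 bytes, got {}", bytes.len()))
}
```

With this helper, `"64726961".repeat(8)` decodes to the ASCII bytes of `DUMMY_SECRET_KEY`, while `"not-hex"` and `"abcd"` both return an error, matching the outcomes the tests above expect from the real function.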
diff --git a/src/inference/benchmark.rs b/src/inference/benchmark.rs
new file mode 100644
index 00000000..44c3146e
--- /dev/null
+++ b/src/inference/benchmark.rs
@@ -0,0 +1,71 @@
+use std::ops::ControlFlow;
+use std::time::Instant;
+
+use crate::error::NodeError;
+use crate::inference::engine::{GenerateParams, InferenceEngine};
+
+/// Result of a TPS benchmark run.
+#[derive(Debug, Clone)]
+pub struct TpsResult {
+ pub model_name: String,
+ pub prompt_eval_tps: f64,
+ pub generation_tps: f64,
+ pub total_time_ms: u64,
+ pub tokens_generated: u32,
+}
+
+impl std::fmt::Display for TpsResult {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(
+ f,
+ "{}: {:.1} tok/s generation, {:.1} tok/s prompt eval ({} tokens in {}ms)",
+ self.model_name,
+ self.generation_tps,
+ self.prompt_eval_tps,
+ self.tokens_generated,
+ self.total_time_ms,
+ )
+ }
+}
+
+const WARMUP_PROMPT: &str = "Write a short poem about hedgehogs and squirrels.";
+const BENCHMARK_PROMPT: &str = "Please write a poem about Kapadokya.";
+const BENCHMARK_MAX_TOKENS: u32 = 128;
+
+impl InferenceEngine {
+ /// Run a TPS benchmark: warmup generation, then timed generation.
+ pub fn benchmark(&self, model_name: &str) -> Result<TpsResult, NodeError> {
+ // Warmup: short generation to prime caches
+ let warmup_params = GenerateParams {
+ max_tokens: 16,
+ temperature: 0.0,
+ ..Default::default()
+ };
+ let _ = self.generate(WARMUP_PROMPT, &warmup_params, |_| ControlFlow::Continue(()));
+
+ // Timed benchmark
+ let bench_params = GenerateParams {
+ max_tokens: BENCHMARK_MAX_TOKENS,
+ temperature: 0.0,
+ ..Default::default()
+ };
+
+ let start = Instant::now();
+ let result = self.generate(BENCHMARK_PROMPT, &bench_params, |_| ControlFlow::Continue(()))?;
+ let total_time_ms = start.elapsed().as_millis() as u64;
+
+ let prompt_eval_tps = if result.prompt_eval_time_ms > 0 {
+ (result.prompt_tokens as f64) / (result.prompt_eval_time_ms as f64 / 1000.0)
+ } else {
+ 0.0
+ };
+
+ Ok(TpsResult {
+ model_name: model_name.to_string(),
+ prompt_eval_tps,
+ generation_tps: result.tokens_per_second,
+ total_time_ms,
+ tokens_generated: result.tokens_generated,
+ })
+ }
+}
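Review note: both throughput figures in `TpsResult` come from the same arithmetic, a token count divided by elapsed seconds, with a guard for a zero-millisecond measurement. A minimal sketch of that calculation follows; the standalone function `tokens_per_second` is an illustrative name, since the patch computes it inline.

```rust
/// Tokens per second from a token count and elapsed milliseconds.
/// Returns 0.0 when no whole millisecond elapsed, mirroring the guard in the patch.
fn tokens_per_second(tokens: u32, elapsed_ms: u64) -> f64 {
    if elapsed_ms > 0 {
        f64::from(tokens) / (elapsed_ms as f64 / 1000.0)
    } else {
        0.0
    }
}
```

For example, 128 tokens in 2000 ms gives 64.0 tok/s. Because the elapsed time is truncated to whole milliseconds, a run shorter than 1 ms reports 0.0 rather than a very large rate.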
diff --git a/src/inference/engine.rs b/src/inference/engine.rs
new file mode 100644
index 00000000..b1888e13
--- /dev/null
+++ b/src/inference/engine.rs
@@ -0,0 +1,1233 @@
+use std::ops::ControlFlow;
+use std::path::Path;
+use std::sync::OnceLock;
+use std::time::Instant;
+
+use llama_cpp_2::context::params::{KvCacheType, LlamaContextParams};
+use llama_cpp_2::llama_backend::LlamaBackend;
+
+/// Global singleton — llama.cpp backend can only be initialized once per process.
+static LLAMA_BACKEND: OnceLock<LlamaBackend> = OnceLock::new();
+
+fn get_backend() -> Result<&'static LlamaBackend, NodeError> {
+ // OnceLock guarantees the closure runs exactly once, so BackendAlreadyInitialized
+ // cannot happen here. If init() somehow fails, it's a fatal environment issue.
+ Ok(LLAMA_BACKEND.get_or_init(|| {
+ LlamaBackend::init().expect("failed to init llama backend")
+ }))
+}
+use llama_cpp_2::llama_batch::LlamaBatch;
+use llama_cpp_2::model::params::LlamaModelParams;
+use llama_cpp_2::model::{AddBos, LlamaChatMessage, LlamaModel};
+use llama_cpp_2::mtmd::{MtmdBitmap, MtmdContext, MtmdContextParams, MtmdInputText};
+use llama_cpp_2::sampling::LlamaSampler;
+use llama_cpp_2::token::LlamaToken;
+
+use dkn_protocol::ChatMessage;
+
+use crate::error::NodeError;
+use crate::identity::sha256hash;
+use dkn_protocol::{InferenceProof, TokenLogprob};
+use crate::inference::stream::StreamToken;
+
+/// Parameters controlling text generation.
+#[derive(Debug, Clone)]
+pub struct GenerateParams {
+ pub max_tokens: u32,
+ pub temperature: f32,
+ pub top_p: f32,
+ pub seed: Option<u32>,
+ /// Extract logprobs every N tokens (0 = disabled).
+ /// E.g. 32 → positions [0, 32, 64, ...].
+ pub logprob_every_n: usize,
+ /// Top-k alternatives to collect at each logprob position.
+ pub logprob_top_k: usize,
+ /// Optional GBNF grammar string for constrained output.
+ pub grammar: Option<String>,
+}
+
+impl Default for GenerateParams {
+ fn default() -> Self {
+ Self {
+ max_tokens: 512,
+ temperature: 0.7,
+ top_p: 0.9,
+ seed: None,
+ logprob_every_n: 0,
+ logprob_top_k: 5,
+ grammar: None,
+ }
+ }
+}
+
+/// Result of an inference run.
+#[derive(Debug, Clone)]
+pub struct InferenceResult {
+ pub text: String,
+ pub tokens_generated: u32,
+ pub prompt_tokens: u32,
+ pub generation_time_ms: u64,
+ pub prompt_eval_time_ms: u64,
+ pub tokens_per_second: f64,
+ pub proof: Option<InferenceProof>,
+}
+
+/// Wraps llama-cpp-2 for model loading and inference.
+///
+/// NOTE: `LlamaContext` is not Send/Sync. All inference must happen
+/// via `tokio::task::spawn_blocking` with the engine moved into the closure.
+pub struct InferenceEngine {
+ backend: &'static LlamaBackend,
+ model: LlamaModel,
+ mtmd_ctx: Option<MtmdContext>,
+ #[allow(dead_code)]
+ gpu_layers: i32,
+ /// Effective context window size (tokens), auto-detected from model metadata.
+ ctx_limit: u32,
+ /// KV cache quantization type (default Q8_0 to save memory).
+ kv_cache_type: KvCacheType,
+}
+
+/// Helper to convert a token to a string piece using the new token_to_piece API.
+fn token_to_string(model: &LlamaModel, token: LlamaToken) -> String {
+ let mut decoder = encoding_rs::UTF_8.new_decoder();
+ model
+ .token_to_piece(token, &mut decoder, true, None)
+ .unwrap_or_default()
+}
+
+impl InferenceEngine {
+ /// Load a GGUF model from disk, optionally with a multimodal projector.
+ ///
+ /// `max_context` optionally caps the context window (e.g. for limited VRAM).
+ /// When `None`, the model's full native context window is used.
+ pub fn load(
+ path: &Path,
+ gpu_layers: i32,
+ mmproj_path: Option<&Path>,
+ max_context: Option<u32>,
+ kv_cache_type: Option<KvCacheType>,
+ ) -> Result<Self, NodeError> {
+ let kv_cache_type = kv_cache_type.unwrap_or(KvCacheType::Q8_0);
+ let backend = get_backend()?;
+
+ let model_params = if gpu_layers != 0 {
+ let layers = if gpu_layers < 0 { 1000 } else { gpu_layers as u32 };
+ LlamaModelParams::default().with_n_gpu_layers(layers)
+ } else {
+ LlamaModelParams::default()
+ };
+
+ let model = LlamaModel::load_from_file(backend, path, &model_params)
+ .map_err(|e| NodeError::Inference(format!("failed to load model: {e}")))?;
+
+ let n_ctx_train = model.n_ctx_train();
+ let ctx_limit = match max_context {
+ Some(cap) => n_ctx_train.min(cap),
+ None => n_ctx_train,
+ };
+ tracing::info!(model_ctx = n_ctx_train, effective_ctx = ctx_limit, kv_type = ?kv_cache_type, "context window");
+
+ let mtmd_ctx = match mmproj_path {
+ Some(p) => {
+ let params = MtmdContextParams::default();
+ let ctx = MtmdContext::init_from_file(
+ p.to_str()
+ .ok_or_else(|| NodeError::Inference("invalid mmproj path".into()))?,
+ &model,
+ &params,
+ )
+ .map_err(|e| NodeError::Inference(format!("failed to init mtmd context: {e}")))?;
+ tracing::info!(
+ path = %p.display(),
+ vision = ctx.support_vision(),
+ audio = ctx.support_audio(),
+ "multimodal projector loaded"
+ );
+ Some(ctx)
+ }
+ None => None,
+ };
+
+ Ok(InferenceEngine {
+ backend,
+ model,
+ mtmd_ctx,
+ gpu_layers,
+ ctx_limit,
+ kv_cache_type,
+ })
+ }
+
+ /// Whether this engine has a multimodal projector loaded.
+ pub fn has_multimodal(&self) -> bool {
+ self.mtmd_ctx.is_some()
+ }
+
+ /// Return the number of GPU layers configured.
+ #[allow(dead_code)]
+ pub fn gpu_layers(&self) -> i32 {
+ self.gpu_layers
+ }
+
+ /// The model's native training context length.
+ #[allow(dead_code)]
+ pub fn n_ctx_train(&self) -> u32 {
+ self.model.n_ctx_train()
+ }
+
+ /// The effective context limit (possibly capped by --context-size).
+ pub fn ctx_limit(&self) -> u32 {
+ self.ctx_limit
+ }
+
+ /// Count prompt tokens without creating a context (LlamaModel is Send+Sync).
+ pub fn tokenize_count(&self, messages: &[ChatMessage]) -> Result<u32, NodeError> {
+ let prompt = self.apply_template(messages)?;
+ let tokens = self.model
+ .str_to_token(&prompt, AddBos::Always)
+ .map_err(|e| NodeError::Inference(format!("tokenization failed: {e}")))?;
+ Ok(tokens.len() as u32)
+ }
+
+ /// Apply the GGUF-embedded chat template to produce a formatted prompt string.
+ pub fn apply_template(&self, messages: &[ChatMessage]) -> Result<String, NodeError> {
+ let template = self
+ .model
+ .chat_template(None)
+ .map_err(|e| NodeError::Inference(format!("no chat template in model: {e}")))?;
+ let llama_messages: Vec<LlamaChatMessage> = messages
+ .iter()
+ .map(|m| LlamaChatMessage::new(m.role.clone(), m.content.to_string()))
+ .collect::<Result<_, _>>()
+ .map_err(|e| NodeError::Inference(format!("invalid chat message: {e}")))?;
+ self.model
+ .apply_chat_template(&template, &llama_messages, true)
+ .map_err(|e| NodeError::Inference(format!("failed to apply chat template: {e}")))
+ }
+
+ /// Apply the GGUF-embedded chat template with media parts replaced by the given marker.
+ fn apply_template_with_marker(
+ &self,
+ messages: &[ChatMessage],
+ marker: &str,
+ ) -> Result<String, NodeError> {
+ let template = self
+ .model
+ .chat_template(None)
+ .map_err(|e| NodeError::Inference(format!("no chat template in model: {e}")))?;
+ let llama_messages: Vec<LlamaChatMessage> = messages
+ .iter()
+ .map(|m| {
+ LlamaChatMessage::new(m.role.clone(), m.content.text_with_markers(marker))
+ })
+ .collect::<Result<_, _>>()
+ .map_err(|e| NodeError::Inference(format!("invalid chat message: {e}")))?;
+ self.model
+ .apply_chat_template(&template, &llama_messages, true)
+ .map_err(|e| NodeError::Inference(format!("failed to apply chat template: {e}")))
+ }
+
+ /// Generate text from a prompt.
+ ///
+ /// `on_token` is called for each generated token. Return `ControlFlow::Break(())`
+ /// to stop generation early.
+ pub fn generate<F>(
+ &self,
+ prompt: &str,
+ params: &GenerateParams,
+ mut on_token: F,
+ ) -> Result<InferenceResult, NodeError>
+ where
+ F: FnMut(StreamToken) -> ControlFlow<()>,
+ {
+ // Tokenize prompt
+ let tokens = self
+ .model
+ .str_to_token(prompt, AddBos::Always)
+ .map_err(|e| NodeError::Inference(format!("tokenization failed: {e}")))?;
+ let prompt_token_count = tokens.len() as u32;
+
+ // Pre-flight: check that prompt + max_tokens fits in context
+ let needed = prompt_token_count + params.max_tokens;
+ if needed > self.ctx_limit {
+ return Err(NodeError::Inference(format!(
+ "prompt ({prompt_token_count}) + max_tokens ({}) = {needed} exceeds context ({})",
+ params.max_tokens, self.ctx_limit
+ )));
+ }
+
+ // Allocate only what this request needs (saves RAM vs full ctx_limit)
+ let ctx_size = std::num::NonZeroU32::new(needed);
+ let ctx_params = LlamaContextParams::default()
+ .with_n_ctx(ctx_size)
+ .with_type_k(self.kv_cache_type)
+ .with_type_v(self.kv_cache_type);
+
+ let mut ctx = self
+ .model
+ .new_context(self.backend, ctx_params)
+ .map_err(|e| NodeError::Inference(format!("failed to create context: {e}")))?;
+
+ // Evaluate prompt in chunks (n_batch = 2048 default in llama.cpp)
+ let prompt_start = Instant::now();
+ let n_batch = 2048usize;
+ let mut batch = LlamaBatch::new(n_batch.min(tokens.len()).max(1), 1);
+ let mut prompt_pos = 0;
+ while prompt_pos < tokens.len() {
+ batch.clear();
+ let chunk_end = (prompt_pos + n_batch).min(tokens.len());
+ for i in prompt_pos..chunk_end {
+ let is_last = i == tokens.len() - 1;
+ batch
+ .add(tokens[i], i as i32, &[0], is_last)
+ .map_err(|e| NodeError::Inference(format!("batch add failed: {e}")))?;
+ }
+ ctx.decode(&mut batch)
+ .map_err(|e| NodeError::Inference(format!("prompt decode failed: {e}")))?;
+ prompt_pos = chunk_end;
+ }
+ let prompt_eval_time_ms = prompt_start.elapsed().as_millis() as u64;
+
+ // Build sampler chain (grammar first to mask invalid tokens, then sampling)
+ let mut samplers = vec![];
+ if let Some(ref grammar_str) = params.grammar {
+ samplers.push(
+ LlamaSampler::grammar(&self.model, grammar_str, "root")
+ .map_err(|e| NodeError::Inference(format!("grammar error: {e}")))?,
+ );
+ }
+ if params.temperature > 0.0 {
+ samplers.push(LlamaSampler::top_p(params.top_p, 1));
+ samplers.push(LlamaSampler::temp(params.temperature));
+ samplers.push(LlamaSampler::dist(params.seed.unwrap_or(0)));
+ } else {
+ samplers.push(LlamaSampler::greedy());
+ }
+ let mut sampler = LlamaSampler::chain_simple(samplers);
+
+ // Generation loop
+ let gen_start = Instant::now();
+ let mut generated_text = String::new();
+ let mut generated_count: u32 = 0;
+ let mut logprobs: Vec<TokenLogprob> = Vec::new();
+ let mut current_pos = tokens.len() as i32;
+ let mut decoder = encoding_rs::UTF_8.new_decoder();
+ // Batch index where logits are available:
+ // after chunked prompt eval → last token's position in last chunk; after single-token decode → 0
+ let mut logit_batch_idx: i32 = ((tokens.len() - 1) % n_batch) as i32;
+
+ for _ in 0..params.max_tokens {
+ // sample() internally calls apply + select + accept
+ let new_token = sampler.sample(&ctx, -1);
+
+ if self.model.is_eog_token(new_token) {
+ break;
+ }
+
+ // Extract logprobs at stride positions
+ let gen_index = generated_count as usize;
+ if params.logprob_every_n > 0 && gen_index.is_multiple_of(params.logprob_every_n) {
+ if let Some(lp) =
+ self.extract_logprob(&ctx, logit_batch_idx, gen_index, new_token, params.logprob_top_k)
+ {
+ logprobs.push(lp);
+ }
+ }
+
+ // Decode token to text
+ let piece = self
+ .model
+ .token_to_piece(new_token, &mut decoder, true, None)
+ .unwrap_or_default();
+ generated_text.push_str(&piece);
+ generated_count += 1;
+
+ // Stream callback
+ let stream_token = StreamToken {
+ text: piece,
+ index: gen_index,
+ };
+ if let ControlFlow::Break(()) = on_token(stream_token) {
+ break;
+ }
+
+ // Prepare next batch
+ batch.clear();
+ batch
+ .add(new_token, current_pos, &[0], true)
+ .map_err(|e| NodeError::Inference(format!("batch add failed: {e}")))?;
+ ctx.decode(&mut batch)
+ .map_err(|e| NodeError::Inference(format!("decode failed: {e}")))?;
+ logit_batch_idx = 0; // single-token batch → logits at batch index 0
+ current_pos += 1;
+ }
+
+ let generation_time_ms = gen_start.elapsed().as_millis() as u64;
+ let tokens_per_second = if generation_time_ms > 0 {
+ (generated_count as f64) / (generation_time_ms as f64 / 1000.0)
+ } else {
+ 0.0
+ };
+
+ let proof = if logprobs.is_empty() {
+ None
+ } else {
+ Some(InferenceProof {
+ logprobs,
+ kv_cache_hash: None,
+ })
+ };
+
+ Ok(InferenceResult {
+ text: generated_text,
+ tokens_generated: generated_count,
+ prompt_tokens: prompt_token_count,
+ generation_time_ms,
+ prompt_eval_time_ms,
+ tokens_per_second,
+ proof,
+ })
+ }
+
+ /// Generate text from multimodal messages containing image/audio parts.
+ ///
+ /// Uses the mtmd context to process media, then runs the standard sampling loop.
+ pub fn generate_multimodal<F>(
+ &self,
+ messages: &[ChatMessage],
+ params: &GenerateParams,
+ mut on_token: F,
+ ) -> Result<InferenceResult, NodeError>
+ where
+ F: FnMut(StreamToken) -> ControlFlow<()>,
+ {
+ let mtmd_ctx = self
+ .mtmd_ctx
+ .as_ref()
+ .ok_or_else(|| NodeError::Inference("no multimodal context loaded".into()))?;
+
+ // Get the default media marker used by the mtmd tokenizer
+ let marker = llama_cpp_2::mtmd::mtmd_default_marker();
+
+ // Apply chat template with media parts replaced by the marker
+ let prompt = self.apply_template_with_marker(messages, marker)?;
+
+ // Collect all media byte slices in order across all messages
+ let mut media_blobs: Vec<&[u8]> = Vec::new();
+ for msg in messages {
+ media_blobs.extend(msg.content.media_data());
+ }
+
+ // Create bitmaps from media blobs
+ let bitmaps: Vec<MtmdBitmap> = media_blobs
+ .iter()
+ .map(|data| {
+ MtmdBitmap::from_buffer(mtmd_ctx, data)
+ .map_err(|e| NodeError::Inference(format!("failed to create bitmap: {e}")))
+ })
+ .collect::<Result<Vec<_>, _>>()?;
+
+ let bitmap_refs: Vec<&MtmdBitmap> = bitmaps.iter().collect();
+
+ // Tokenize the prompt with media markers resolved to bitmap embeddings
+ let input_text = MtmdInputText {
+ text: prompt,
+ add_special: false, // chat template already includes BOS
+ parse_special: true,
+ };
+ let chunks = mtmd_ctx
+ .tokenize(input_text, &bitmap_refs)
+ .map_err(|e| NodeError::Inference(format!("mtmd tokenize failed: {e}")))?;
+
+ let prompt_token_count = chunks.total_tokens() as u32;
+
+ // Allocate only what this request needs (saves RAM vs full ctx_limit)
+ let needed = prompt_token_count + params.max_tokens;
+ let ctx_size = std::num::NonZeroU32::new(needed);
+ let ctx_params = LlamaContextParams::default()
+ .with_n_ctx(ctx_size)
+ .with_type_k(self.kv_cache_type)
+ .with_type_v(self.kv_cache_type);
+
+ let mut ctx = self
+ .model
+ .new_context(self.backend, ctx_params)
+ .map_err(|e| NodeError::Inference(format!("failed to create context: {e}")))?;
+
+ // Evaluate all chunks (text + media embeddings)
+ let prompt_start = Instant::now();
+ let n_past = chunks
+ .eval_chunks(mtmd_ctx, &ctx, 0, 0, 512, true)
+ .map_err(|e| NodeError::Inference(format!("mtmd eval_chunks failed: {e}")))?;
+ let prompt_eval_time_ms = prompt_start.elapsed().as_millis() as u64;
+
+ // Build sampler chain (grammar first to mask invalid tokens, then sampling)
+ let mut samplers = vec![];
+ if let Some(ref grammar_str) = params.grammar {
+ samplers.push(
+ LlamaSampler::grammar(&self.model, grammar_str, "root")
+ .map_err(|e| NodeError::Inference(format!("grammar error: {e}")))?,
+ );
+ }
+ if params.temperature > 0.0 {
+ samplers.push(LlamaSampler::top_p(params.top_p, 1));
+ samplers.push(LlamaSampler::temp(params.temperature));
+ samplers.push(LlamaSampler::dist(params.seed.unwrap_or(0)));
+ } else {
+ samplers.push(LlamaSampler::greedy());
+ }
+ let mut sampler = LlamaSampler::chain_simple(samplers);
+
+ // Generation loop (same as text-only but starting from n_past)
+ let gen_start = Instant::now();
+ let mut generated_text = String::new();
+ let mut generated_count: u32 = 0;
+ let logprobs: Vec<TokenLogprob> = Vec::new();
+ let mut current_pos = n_past;
+ let mut decoder = encoding_rs::UTF_8.new_decoder();
+ let mut batch = LlamaBatch::new(1, 1);
+ // Always use -1 (C API sentinel for "last logits") for sampling.
+ // After single-token decode, batch output index is 0, but -1 always works.
+ // Multimodal tasks skip validation so logprob extraction is not needed.
+
+ for _ in 0..params.max_tokens {
+ // sample() internally calls apply + select + accept
+ let new_token = sampler.sample(&ctx, -1);
+
+ if self.model.is_eog_token(new_token) {
+ break;
+ }
+
+ let gen_index = generated_count as usize;
+
+ // Decode token to text
+ let piece = self
+ .model
+ .token_to_piece(new_token, &mut decoder, true, None)
+ .unwrap_or_default();
+ generated_text.push_str(&piece);
+ generated_count += 1;
+
+ // Stream callback
+ let stream_token = StreamToken {
+ text: piece,
+ index: gen_index,
+ };
+ if let ControlFlow::Break(()) = on_token(stream_token) {
+ break;
+ }
+
+ // Prepare next batch
+ batch.clear();
+ batch
+ .add(new_token, current_pos, &[0], true)
+ .map_err(|e| NodeError::Inference(format!("batch add failed: {e}")))?;
+ ctx.decode(&mut batch)
+ .map_err(|e| NodeError::Inference(format!("decode failed: {e}")))?;
+ current_pos += 1;
+ }
+
+ let generation_time_ms = gen_start.elapsed().as_millis() as u64;
+ let tokens_per_second = if generation_time_ms > 0 {
+ (generated_count as f64) / (generation_time_ms as f64 / 1000.0)
+ } else {
+ 0.0
+ };
+
+ let proof = if logprobs.is_empty() {
+ None
+ } else {
+ Some(InferenceProof {
+ logprobs,
+ kv_cache_hash: None,
+ })
+ };
+
+ Ok(InferenceResult {
+ text: generated_text,
+ tokens_generated: generated_count,
+ prompt_tokens: prompt_token_count,
+ generation_time_ms,
+ prompt_eval_time_ms,
+ tokens_per_second,
+ proof,
+ })
+ }
+
+ /// Prefill-only validation: tokenize prompt+output, run a single forward pass,
+ /// and extract logprobs at the same stride positions used during generation.
+ ///
+ /// Returns an `InferenceProof` that can be compared against the original.
+ pub fn validate_prefill(
+ &self,
+ prompt: &str,
+ output_text: &str,
+ logprob_every_n: usize,
+ logprob_top_k: usize,
+ ) -> Result<InferenceProof, NodeError> {
+ // Tokenize prompt alone to find the split point
+ let prompt_tokens = self
+ .model
+ .str_to_token(prompt, AddBos::Always)
+ .map_err(|e| NodeError::Inference(format!("prompt tokenization failed: {e}")))?;
+ let n_prompt = prompt_tokens.len();
+
+ // Tokenize prompt + output together
+ let full_text = format!("{}{}", prompt, output_text);
+ let all_tokens = self
+ .model
+ .str_to_token(&full_text, AddBos::Always)
+ .map_err(|e| NodeError::Inference(format!("full tokenization failed: {e}")))?;
+ let n_output = all_tokens.len().saturating_sub(n_prompt);
+
+ if n_output == 0 {
+ return Ok(InferenceProof {
+ logprobs: vec![],
+ kv_cache_hash: None,
+ });
+ }
+
+ // Compute probe positions: gen_index values [0, N, 2N, ...] where each is < n_output
+ let mut probe_gen_indices: Vec<usize> = Vec::new();
+ if logprob_every_n > 0 {
+ let mut k = 0;
+ while k < n_output {
+ probe_gen_indices.push(k);
+ k += logprob_every_n;
+ }
+ }
+
+ if probe_gen_indices.is_empty() {
+ return Ok(InferenceProof {
+ logprobs: vec![],
+ kv_cache_hash: None,
+ });
+ }
+
+ // Create context sized to fit all tokens (+ small padding)
+ let ctx_size = std::num::NonZeroU32::new((all_tokens.len() + 64) as u32);
+ let ctx_params = LlamaContextParams::default()
+ .with_n_ctx(ctx_size)
+ .with_type_k(self.kv_cache_type)
+ .with_type_v(self.kv_cache_type);
+
+ let mut ctx = self
+ .model
+ .new_context(self.backend, ctx_params)
+ .map_err(|e| NodeError::Inference(format!("failed to create context: {e}")))?;
+
+ // Build batch with all tokens. Set output=true only at positions where we need logits.
+ // For probe gen_index k, we need logits at sequence position (n_prompt + k - 1) for k > 0,
+ // and at (n_prompt - 1) for k == 0 (last prompt token predicts first output token).
+ let mut output_positions: Vec<usize> = Vec::new();
+ for &k in &probe_gen_indices {
+ let seq_pos = if k == 0 { n_prompt - 1 } else { n_prompt + k - 1 };
+ output_positions.push(seq_pos);
+ }
+
+ // Evaluate in chunks and extract logprobs per-chunk (next decode overwrites logits)
+ let n_batch = 2048usize;
+ let mut batch = LlamaBatch::new(n_batch.min(all_tokens.len()).max(1), 1);
+ let mut logprobs: Vec<TokenLogprob> = Vec::new();
+
+ let mut pos = 0;
+ while pos < all_tokens.len() {
+ batch.clear();
+ let chunk_end = (pos + n_batch).min(all_tokens.len());
+
+ // Track which probe positions fall in this chunk
+ let mut chunk_probes: Vec<(usize, usize)> = Vec::new(); // (probe_idx, batch_position)
+
+ for (batch_pos, (i, &token)) in all_tokens.iter().enumerate().skip(pos).take(chunk_end - pos).enumerate() {
+ let is_output = output_positions.contains(&i);
+ batch
+ .add(token, i as i32, &[0], is_output)
+ .map_err(|e| NodeError::Inference(format!("batch add failed: {e}")))?;
+ if is_output {
+ if let Some(probe_idx) = output_positions.iter().position(|&p| p == i) {
+ chunk_probes.push((probe_idx, batch_pos));
+ }
+ }
+ }
+
+ ctx.decode(&mut batch)
+ .map_err(|e| NodeError::Inference(format!("prefill decode failed: {e}")))?;
+
+ // Extract logprobs for this chunk's probes before next decode
+ for &(probe_idx, batch_pos) in &chunk_probes {
+ let gen_index = probe_gen_indices[probe_idx];
+ let target_token = all_tokens[n_prompt + gen_index];
+ if let Some(lp) =
+ self.extract_logprob(&ctx, batch_pos as i32, gen_index, target_token, logprob_top_k)
+ {
+ logprobs.push(lp);
+ }
+ }
+
+ pos = chunk_end;
+ }
+
+ Ok(InferenceProof {
+ logprobs,
+ kv_cache_hash: None,
+ })
+ }
+
+ /// Extract logprob data at a given batch index.
+ fn extract_logprob(
+ &self,
+ ctx: &llama_cpp_2::context::LlamaContext,
+ batch_idx: i32,
+ position: usize,
+ chosen_token: LlamaToken,
+ top_k: usize,
+ ) -> Option<TokenLogprob> {
+ let logits = ctx.get_logits_ith(batch_idx);
+
+ // Compute softmax to get log-probabilities
+ let max_logit = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
+ let exp_sum: f32 = logits.iter().map(|&l| (l - max_logit).exp()).sum();
+ let log_sum = max_logit + exp_sum.ln();
+
+ // Collect (token_id, logprob) for all vocab
+ let mut all_logprobs: Vec<(u32, f32)> = logits
+ .iter()
+ .enumerate()
+ .map(|(i, &l)| (i as u32, l - log_sum))
+ .collect();
+
+ // Sort by logprob descending
+ all_logprobs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
+
+ let chosen_id = chosen_token.0 as u32;
+ let chosen_logprob = all_logprobs
+ .iter()
+ .find(|(id, _)| *id == chosen_id)
+ .map(|(_, lp)| *lp)
+ .unwrap_or(f32::NEG_INFINITY);
+
+ let chosen_text = token_to_string(&self.model, chosen_token);
+
+ let top_k_entries: Vec<(String, f32)> = all_logprobs
+ .iter()
+ .take(top_k)
+ .map(|(id, lp)| {
+ let text = token_to_string(&self.model, LlamaToken(*id as i32));
+ (text, *lp)
+ })
+ .collect();
+
+ Some(TokenLogprob {
+ position,
+ token_id: chosen_id,
+ token_text: chosen_text,
+ logprob: chosen_logprob,
+ top_k: top_k_entries,
+ })
+ }
+
+ /// Compute a placeholder KV-cache hash from logits at a given position.
+ #[allow(dead_code)]
+ fn kv_cache_hash_placeholder(
+ ctx: &llama_cpp_2::context::LlamaContext,
+ batch_idx: i32,
+ ) -> [u8; 32] {
+ let logits = ctx.get_logits_ith(batch_idx);
+ let bytes: Vec<u8> = logits.iter().flat_map(|f| f.to_le_bytes()).collect();
+ sha256hash(&bytes)
+ }
+}
+
+// ---------------------------------------------------------------------------
+// Tests
+// ---------------------------------------------------------------------------
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use dkn_protocol::{ContentPart, MessageContent};
+
+ /// Create a minimal 64x64 BMP image with a color gradient (no external deps).
+ fn create_test_bmp() -> Vec<u8> {
+ let width: u32 = 64;
+ let height: u32 = 64;
+ let row_bytes = width * 3; // 192, already 4-byte aligned
+ let pixel_data_size = row_bytes * height;
+ let file_size = 54 + pixel_data_size;
+
+ let mut data = Vec::with_capacity(file_size as usize);
+
+ // BMP file header (14 bytes)
+ data.extend_from_slice(b"BM");
+ data.extend_from_slice(&file_size.to_le_bytes());
+ data.extend_from_slice(&[0u8; 4]); // reserved
+ data.extend_from_slice(&54u32.to_le_bytes()); // pixel data offset
+
+ // BITMAPINFOHEADER (40 bytes)
+ data.extend_from_slice(&40u32.to_le_bytes()); // header size
+ data.extend_from_slice(&width.to_le_bytes());
+ data.extend_from_slice(&height.to_le_bytes());
+ data.extend_from_slice(&1u16.to_le_bytes()); // planes
+ data.extend_from_slice(&24u16.to_le_bytes()); // bits per pixel
+ data.extend_from_slice(&[0u8; 24]); // compression=0, rest zeros
+
+ // Pixel data (bottom-up, BGR)
+ for y in 0..height {
+ for x in 0..width {
+ let r = ((x * 255) / (width - 1)) as u8;
+ let g = ((y * 255) / (height - 1)) as u8;
+ let b = 128u8;
+ data.push(b);
+ data.push(g);
+ data.push(r);
+ }
+ }
+
+ data
+ }
+
+ /// Integration test: download lfm2.5-vl:1.6b + mmproj, run vision inference.
+ ///
+ /// Run with:
+ /// cargo test test_vision_inference -- --ignored --nocapture
+ ///
+ /// Optionally provide your own image:
+ /// TEST_IMAGE_PATH=/path/to/photo.jpg cargo test test_vision_inference -- --ignored --nocapture
+ #[tokio::test]
+ #[ignore] // requires ~1.5 GB download (model + mmproj)
+ async fn test_vision_inference() {
+ let registry = crate::models::default_registry();
+ let spec = registry.get("lfm2.5-vl:1.6b").unwrap().clone();
+
+ let cache_dir = dirs::cache_dir()
+ .unwrap_or_else(|| std::path::PathBuf::from("."))
+ .join("dria-test-models");
+ let cache = crate::models::ModelCache::new(cache_dir).unwrap();
+
+ // Download / cache the GGUF model
+ let model_path = if let Some(p) = cache.get_local_path(&spec) {
+ println!("model found in cache: {}", p.display());
+ p
+ } else {
+ println!("downloading model (this may take a while)...");
+ let hf_path = crate::models::ModelDownloader::download(&spec).await.unwrap();
+ cache.link_model(&spec, &hf_path).unwrap()
+ };
+
+ // Download / cache the mmproj
+ let mmproj_path = if let Some(p) = cache.get_mmproj_path(&spec) {
+ println!("mmproj found in cache: {}", p.display());
+ p
+ } else {
+ println!("downloading mmproj (this may take a while)...");
+ let hf_path = crate::models::ModelDownloader::download_mmproj(&spec)
+ .await
+ .unwrap();
+ cache.link_mmproj(&spec, &hf_path).unwrap()
+ };
+
+ // Load engine with multimodal projector
+ println!("loading model + mmproj...");
+ let engine = InferenceEngine::load(&model_path, 0, Some(&mmproj_path), None, None).unwrap();
+ assert!(engine.has_multimodal(), "engine should have multimodal context");
+
+ // Get test image: from env var or generate a synthetic BMP
+ let image_bytes = if let Ok(path) = std::env::var("TEST_IMAGE_PATH") {
+ println!("using image: {path}");
+ std::fs::read(&path).expect("failed to read TEST_IMAGE_PATH")
+ } else {
+ println!("using synthetic 64x64 gradient BMP");
+ create_test_bmp()
+ };
+
+ // Build multimodal chat messages
+ let messages = vec![ChatMessage {
+ role: "user".into(),
+ content: MessageContent::Parts(vec![
+ ContentPart::Text {
+ text: "What do you see in this image? Describe it briefly.".into(),
+ },
+ ContentPart::Image {
+ data: image_bytes,
+ },
+ ]),
+ }];
+
+ let params = GenerateParams {
+ max_tokens: 256,
+ temperature: 0.0,
+ ..Default::default()
+ };
+
+ // Run multimodal inference, streaming tokens to stdout
+ println!("\n--- model output ---");
+ let result = engine
+ .generate_multimodal(&messages, &params, |token| {
+ print!("{}", token.text);
+ ControlFlow::Continue(())
+ })
+ .unwrap();
+ println!("\n--- end output ---\n");
+
+ println!(
+ "tokens: {} | prompt: {} | time: {}ms | {:.1} tok/s",
+ result.tokens_generated,
+ result.prompt_tokens,
+ result.generation_time_ms,
+ result.tokens_per_second,
+ );
+
+ assert!(!result.text.is_empty(), "model should produce output");
+ assert!(result.tokens_generated > 0);
+ }
+
+ /// Helper to load a model from cache (or download).
+ async fn load_model(spec: crate::models::registry::ModelSpec) -> (InferenceEngine, String) {
+ let cache_dir = dirs::cache_dir()
+ .unwrap_or_else(|| std::path::PathBuf::from("."))
+ .join("dria-test-models");
+ let cache = crate::models::ModelCache::new(cache_dir).unwrap();
+
+ let model_path = if let Some(p) = cache.get_local_path(&spec) {
+ println!("model found in cache: {}", p.display());
+ p
+ } else {
+ println!("downloading model...");
+ let hf_path = crate::models::ModelDownloader::download(&spec).await.unwrap();
+ cache.link_model(&spec, &hf_path).unwrap()
+ };
+
+ let name = spec.name.clone();
+ let engine = InferenceEngine::load(&model_path, 0, None, None, None).unwrap();
+ (engine, name)
+ }
+
+ /// Load lfm2.5:1.2b from the default registry.
+ async fn load_text_model() -> (InferenceEngine, String) {
+ let registry = crate::models::default_registry();
+ let spec = registry.get("lfm2.5:1.2b").unwrap().clone();
+ load_model(spec).await
+ }
+
+ /// Load a small Qwen 3.5 model for grammar-compatible testing.
+ async fn load_qwen_model() -> (InferenceEngine, String) {
+ let spec = crate::models::registry::ModelSpec {
+ name: "qwen3.5:0.8b".into(),
+ hf_repo: "unsloth/Qwen3.5-0.8B-GGUF".into(),
+ hf_file: "Qwen3.5-0.8B-Q4_K_M.gguf".into(),
+ sha256: None,
+ model_type: dkn_protocol::ModelType::Text,
+ hf_mmproj_file: None,
+ };
+ load_model(spec).await
+ }
+
+ /// End-to-end validation test:
+ /// 1. Generate text with logprob_every_n=8 (greedy so output is deterministic)
+ /// 2. validate_prefill() with the same prompt+output
+ /// 3. compare_proofs() — should Pass
+ ///
+ /// Run with:
+ /// cargo test test_validate_prefill_e2e -- --ignored --nocapture
+ #[tokio::test]
+ #[ignore] // requires lfm2.5:1.2b model (~800 MB)
+ async fn test_validate_prefill_e2e() {
+ let (engine, _model_name) = load_text_model().await;
+
+ let messages = vec![ChatMessage {
+ role: "user".into(),
+ content: "What is 2 + 2? Answer in one word.".into(),
+ }];
+
+ let prompt = engine.apply_template(&messages).unwrap();
+
+ // Generate with logprobs every 8 tokens, greedy (deterministic)
+ let params = GenerateParams {
+ max_tokens: 64,
+ temperature: 0.0,
+ logprob_every_n: 8,
+ logprob_top_k: 5,
+ ..Default::default()
+ };
+
+ let gen_result = engine
+ .generate(&prompt, &params, |_| ControlFlow::Continue(()))
+ .unwrap();
+
+ println!("generated: {:?}", gen_result.text);
+ println!("tokens: {}", gen_result.tokens_generated);
+
+ let original_proof = gen_result.proof.as_ref().expect("should have proof with logprob_every_n=8");
+ println!("original proof positions: {:?}",
+ original_proof.logprobs.iter().map(|lp| lp.position).collect::<Vec<_>>()
+ );
+
+ // Now validate: prefill-only forward pass
+ let validator_proof = engine
+ .validate_prefill(&prompt, &gen_result.text, 8, 5)
+ .unwrap();
+
+ println!("validator proof positions: {:?}",
+ validator_proof.logprobs.iter().map(|lp| lp.position).collect::<Vec<_>>()
+ );
+
+ // Both proofs should contain the same number of logprob entries
+ assert_eq!(
+ original_proof.logprobs.len(),
+ validator_proof.logprobs.len(),
+ "proof lengths should match"
+ );
+
+ // Compare position by position
+ for (orig, val) in original_proof.logprobs.iter().zip(validator_proof.logprobs.iter()) {
+ assert_eq!(orig.position, val.position, "positions should match");
+ assert_eq!(orig.token_id, val.token_id, "token IDs should match at position {}", orig.position);
+ let diff = (orig.logprob - val.logprob).abs();
+ println!(
+ "pos {} | token '{}' | orig_lp={:.4} | val_lp={:.4} | diff={:.4}",
+ orig.position, orig.token_text, orig.logprob, val.logprob, diff
+ );
+ assert!(
+ diff < 0.5,
+ "logprob diff too large at position {}: {diff}",
+ orig.position
+ );
+ }
+
+ println!("\nall positions match — validation passed!");
+ }
+
+ /// End-to-end structured output test:
+ /// 1. Test a trivial GBNF grammar to verify grammar sampling works
+ /// 2. Generate with json_object grammar (greedy) — output must be valid JSON
+ /// 3. Generate with json_schema grammar — output must match the schema
+ ///
+ /// Run with:
+ /// cargo test test_structured_output_e2e -- --ignored --nocapture
+ #[tokio::test]
+ #[ignore] // requires qwen3.5:0.8b model (~533 MB download)
+ async fn test_structured_output_e2e() {
+ let (engine, _model_name) = load_qwen_model().await;
+
+ // --- Step 1: trivial GBNF grammar to confirm grammar sampling works ---
+ {
+ let grammar = r#"root ::= "hello""#.to_string();
+ let messages = vec![ChatMessage {
+ role: "user".into(),
+ content: "Say hello".into(),
+ }];
+ let prompt = engine.apply_template(&messages).unwrap();
+
+ let params = GenerateParams {
+ max_tokens: 16,
+ temperature: 0.0,
+ grammar: Some(grammar),
+ ..Default::default()
+ };
+
+ println!("\n--- trivial grammar test ---");
+ let result = engine
+ .generate(&prompt, &params, |_| ControlFlow::Continue(()))
+ .unwrap();
+ println!("output: {:?}", result.text);
+ assert_eq!(result.text, "hello", "trivial grammar should constrain to 'hello'");
+ println!("trivial grammar OK");
+ }
+
+ // --- Step 2: json_object mode (permissive JSON) ---
+ {
+ let json_grammar = llama_cpp_2::json_schema_to_grammar(r#"{"type": "object"}"#)
+ .expect("json_object grammar should convert");
+ println!("\njson_object grammar length: {} chars", json_grammar.len());
+
+ let messages = vec![ChatMessage {
+ role: "user".into(),
+ content: "Return a JSON object with a field called 'answer' set to 42.".into(),
+ }];
+ let prompt = engine.apply_template(&messages).unwrap();
+
+ let params = GenerateParams {
+ max_tokens: 128,
+ temperature: 0.0,
+ grammar: Some(json_grammar),
+ ..Default::default()
+ };
+
+ print!("\n--- json_object output ---\n");
+ let result = engine
+ .generate(&prompt, &params, |tok| {
+ print!("{}", tok.text);
+ ControlFlow::Continue(())
+ })
+ .unwrap();
+ println!("\n--- end ---");
+
+ let text = result.text.trim();
+ assert!(!text.is_empty(), "should produce output");
+
+ let parsed: serde_json::Value =
+ serde_json::from_str(text).expect("json_object output must be valid JSON");
+ assert!(parsed.is_object(), "should be a JSON object");
+ println!("parsed JSON: {parsed}");
+ }
+
+ // --- Step 3: json_schema mode (specific schema) ---
+ {
+ let schema = serde_json::json!({
+ "type": "object",
+ "properties": {
+ "name": { "type": "string" },
+ "age": { "type": "integer" }
+ },
+ "required": ["name", "age"],
+ "additionalProperties": false
+ });
+
+ let schema_str = serde_json::to_string(&schema).unwrap();
+ let schema_grammar = llama_cpp_2::json_schema_to_grammar(&schema_str)
+ .expect("json_schema grammar should convert");
+ println!("\njson_schema grammar length: {} chars", schema_grammar.len());
+
+ let messages = vec![ChatMessage {
+ role: "user".into(),
+ content: "Give me a person named Alice who is 30 years old.".into(),
+ }];
+ let prompt = engine.apply_template(&messages).unwrap();
+
+ let params = GenerateParams {
+ max_tokens: 128,
+ temperature: 0.0,
+ grammar: Some(schema_grammar),
+ ..Default::default()
+ };
+
+ print!("\n--- json_schema output ---\n");
+ let result = engine
+ .generate(&prompt, &params, |tok| {
+ print!("{}", tok.text);
+ ControlFlow::Continue(())
+ })
+ .unwrap();
+ println!("\n--- end ---");
+
+ let text = result.text.trim();
+ assert!(!text.is_empty(), "should produce output");
+
+ let parsed: serde_json::Value =
+ serde_json::from_str(text).expect("json_schema output must be valid JSON");
+ assert!(parsed.is_object(), "should be a JSON object");
+ assert!(parsed.get("name").is_some(), "should have 'name' field");
+ assert!(parsed.get("age").is_some(), "should have 'age' field");
+ assert!(parsed["name"].is_string(), "'name' should be a string");
+ assert!(parsed["age"].is_number(), "'age' should be a number");
+ println!("parsed JSON: {parsed}");
+ }
+
+ println!("\nstructured output test passed!");
+ }
+
+ /// Grammar test with lfm2.5:1.2b — verify grammar sampling works across tokenizer types.
+ ///
+ /// Run with:
+ /// cargo test test_structured_output_lfm2 -- --ignored --nocapture
+ #[tokio::test]
+ #[ignore] // requires lfm2.5:1.2b model (~800 MB)
+ async fn test_structured_output_lfm2() {
+ let (engine, _model_name) = load_text_model().await;
+
+ // Trivial grammar
+ {
+ let grammar = r#"root ::= "hello""#.to_string();
+ let messages = vec![ChatMessage {
+ role: "user".into(),
+ content: "Say hello".into(),
+ }];
+ let prompt = engine.apply_template(&messages).unwrap();
+
+ let params = GenerateParams {
+ max_tokens: 16,
+ temperature: 0.0,
+ grammar: Some(grammar),
+ ..Default::default()
+ };
+
+ println!("\n--- lfm2 trivial grammar test ---");
+ let result = engine
+ .generate(&prompt, &params, |_| ControlFlow::Continue(()))
+ .unwrap();
+ println!("output: {:?}", result.text);
+ assert_eq!(result.text, "hello");
+ println!("trivial grammar OK");
+ }
+
+ // Class-like schema with nested object, array, and enum
+ {
+ let schema = serde_json::json!({
+ "type": "object",
+ "properties": {
+ "name": { "type": "string" },
+ "age": { "type": "integer" },
+ "role": { "type": "string", "enum": ["admin", "user", "moderator"] },
+ "address": {
+ "type": "object",
+ "properties": {
+ "city": { "type": "string" },
+ "country": { "type": "string" }
+ },
+ "required": ["city", "country"],
+ "additionalProperties": false
+ },
+ "tags": {
+ "type": "array",
+ "items": { "type": "string" }
+ }
+ },
+ "required": ["name", "age", "role", "address", "tags"],
+ "additionalProperties": false
+ });
+
+ let schema_str = serde_json::to_string(&schema).unwrap();
+ let grammar = llama_cpp_2::json_schema_to_grammar(&schema_str)
+ .expect("class schema should convert");
+ println!("\nclass schema grammar length: {} chars", grammar.len());
+
+ let messages = vec![ChatMessage {
+ role: "user".into(),
+ content: "Create a user profile for Alice, age 30, admin role, lives in Istanbul Turkey, tags: developer and lead.".into(),
+ }];
+ let prompt = engine.apply_template(&messages).unwrap();
+
+ let params = GenerateParams {
+ max_tokens: 256,
+ temperature: 0.0,
+ grammar: Some(grammar),
+ ..Default::default()
+ };
+
+ print!("\n--- lfm2 class-like schema output ---\n");
+ let result = engine
+ .generate(&prompt, &params, |tok| {
+ print!("{}", tok.text);
+ ControlFlow::Continue(())
+ })
+ .unwrap();
+ println!("\n--- end ---");
+
+ let parsed: serde_json::Value =
+ serde_json::from_str(result.text.trim()).expect("must be valid JSON");
+ assert!(parsed.is_object());
+ assert!(parsed["name"].is_string());
+ assert!(parsed["age"].is_number());
+ let role = parsed["role"].as_str().unwrap();
+ assert!(["admin", "user", "moderator"].contains(&role), "role must be enum value");
+ assert!(parsed["address"].is_object());
+ assert!(parsed["address"]["city"].is_string());
+ assert!(parsed["address"]["country"].is_string());
+ assert!(parsed["tags"].is_array());
+ println!("parsed: {parsed}");
+ }
+
+ println!("\nlfm2 structured output OK");
+ }
+}
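Reviewer note: the position-by-position check in `test_validate_prefill_e2e` can be read as a small pure function. The sketch below is a minimal, standalone version of that idea. `LogprobEntry` and `proofs_agree` are hypothetical names introduced for illustration only; they are not the crate's proof types or its `compare_proofs()` function, whose signatures are not shown in this diff. The tolerance is passed in by the caller, and the test above uses `0.5`.

```rust
/// Hypothetical stand-in for one sampled log-probability entry in a proof.
#[derive(Debug, Clone, PartialEq)]
pub struct LogprobEntry {
    pub position: usize,
    pub token_id: i32,
    pub logprob: f32,
}

/// Returns `Ok(())` when both proofs sample the same positions and tokens
/// and every pair of log-probabilities differs by less than `tolerance`.
pub fn proofs_agree(
    original: &[LogprobEntry],
    validator: &[LogprobEntry],
    tolerance: f32,
) -> Result<(), String> {
    if original.len() != validator.len() {
        return Err(format!(
            "length mismatch: {} vs {}",
            original.len(),
            validator.len()
        ));
    }
    for (orig, val) in original.iter().zip(validator) {
        if orig.position != val.position {
            return Err(format!(
                "position mismatch: {} vs {}",
                orig.position, val.position
            ));
        }
        if orig.token_id != val.token_id {
            return Err(format!("token mismatch at position {}", orig.position));
        }
        let diff = (orig.logprob - val.logprob).abs();
        // Written as a negated `<` so that a NaN difference is rejected too.
        if !(diff < tolerance) {
            return Err(format!(
                "logprob diff {diff} at position {}",
                orig.position
            ));
        }
    }
    Ok(())
}
```

Returning a `Result` with a message instead of asserting keeps the check reusable outside tests; a test can still call `proofs_agree(&a, &b, 0.5).unwrap()` to fail with the same detail.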
diff --git a/src/inference/mod.rs b/src/inference/mod.rs
new file mode 100644
index 00000000..7e9506cf
--- /dev/null
+++ b/src/inference/mod.rs
@@ -0,0 +1,5 @@
+pub mod benchmark;
+pub mod engine;
+pub mod stream;
+
+pub use engine::{GenerateParams, InferenceEngine, InferenceResult};
diff --git a/src/inference/stream.rs b/src/inference/stream.rs
new file mode 100644
index 00000000..e510c19b
--- /dev/null
+++ b/src/inference/stream.rs
@@ -0,0 +1,39 @@
+use serde::{Deserialize, Serialize};
+
+/// A single token emitted during streaming generation.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct StreamToken {
+ /// The decoded text of this token.
+ pub text: String,
+ /// The zero-based position of this token in the generated sequence.
+ pub index: usize,
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_stream_token_serde() {
+ let token = StreamToken {
+ text: "hello".into(),
+ index: 0,
+ };
+ let json = serde_json::to_string(&token).unwrap();
+ let roundtrip: StreamToken = serde_json::from_str(&json).unwrap();
+ assert_eq!(roundtrip.text, "hello");
+ assert_eq!(roundtrip.index, 0);
+ }
+
+ #[test]
+ fn test_stream_token_msgpack() {
+ let token = StreamToken {
+ text: "world".into(),
+ index: 42,
+ };
+ let packed = rmp_serde::to_vec(&token).unwrap();
+ let roundtrip: StreamToken = rmp_serde::from_slice(&packed).unwrap();
+ assert_eq!(roundtrip.text, "world");
+ assert_eq!(roundtrip.index, 42);
+ }
+}
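Reviewer note: the tests in `engine.rs` pass closures returning `ControlFlow::Continue(())` as the per-token callback. The sketch below illustrates how a callback of that shape can stop a stream early, using only the standard library. `Token` and `stream_pieces` are hypothetical names for illustration; the real `InferenceEngine::generate` loop is not part of this hunk, so whether it includes the breaking token in its output is an assumption of this sketch, not a statement about the engine.

```rust
use std::ops::ControlFlow;

/// Hypothetical mirror of the `StreamToken` shape (text plus zero-based index).
#[derive(Debug, Clone, PartialEq)]
pub struct Token {
    pub text: String,
    pub index: usize,
}

/// Feeds each piece to `on_token` and stops as soon as it returns `Break`.
/// Returns the text emitted so far, including the piece that triggered the break.
pub fn stream_pieces<F>(pieces: &[&str], mut on_token: F) -> String
where
    F: FnMut(&Token) -> ControlFlow<()>,
{
    let mut out = String::new();
    for (index, piece) in pieces.iter().enumerate() {
        let token = Token {
            text: (*piece).to_string(),
            index,
        };
        out.push_str(&token.text);
        if on_token(&token).is_break() {
            break;
        }
    }
    out
}
```

A caller that wants to cancel after a fixed number of tokens can return `ControlFlow::Break(())` once `token.index` reaches its limit, and `ControlFlow::Continue(())` otherwise.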
diff --git a/src/main.rs b/src/main.rs
new file mode 100644
index 00000000..fc272627
--- /dev/null
+++ b/src/main.rs
@@ -0,0 +1,655 @@
+mod config;
+mod error;
+mod identity;
+mod inference;
+mod models;
+mod network;
+mod setup;
+mod stats;
+mod update;
+mod worker;
+
+use std::collections::HashMap;
+use std::sync::Arc;
+use std::time::Duration;
+
+use clap::Parser;
+use llama_cpp_2::context::params::KvCacheType;
+use tokio::sync::mpsc;
+use tracing_subscriber::EnvFilter;
+
+use config::{Cli, Command, Config};
+use identity::Identity;
+use models::{ModelCache, ModelDownloader, default_registry, resolve_model};
+use models::registry::ModelSpec;
+use network::{NodeMessage, RouterMessage};
+use network::protocol::ModelType;
+use network::RouterConnection;
+use stats::NodeStats;
+use worker::{CompletedTask, Worker};
+
+#[tokio::main]
+async fn main() -> anyhow::Result<()> {
+ // Initialize tracing
+ tracing_subscriber::fmt()
+ .with_env_filter(
+ EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")),
+ )
+ .init();
+
+ let cli = Cli::parse();
+
+ match cli.command {
+ Command::Setup {
+ data_dir,
+ gpu_layers,
+ } => {
+ setup::run_setup(data_dir, gpu_layers).await?;
+ }
+ Command::Start {
+ wallet,
+ model,
+ router_url,
+ gpu_layers,
+ max_concurrent,
+ data_dir,
+ quant,
+ insecure,
+ skip_update,
+ context_size,
+ kv_quant,
+ } => {
+ run_start(wallet, model, router_url, gpu_layers, max_concurrent, data_dir, quant, insecure, skip_update, context_size, kv_quant).await?;
+ }
+ }
+
+ Ok(())
+}
+
+/// Shared state needed by event handlers for reconnection and challenge-response.
+struct NodeContext {
+ identity: Identity,
+ config: Config,
+ tps: HashMap<String, f64>,
+ stats: Arc<NodeStats>,
+ cache: ModelCache,
+}
+
+/// Result of a background model download + load operation.
+struct ModelLoadResult {
+ name: String,
+ model_type: ModelType,
+ result: Result<(inference::InferenceEngine, f64), error::NodeError>,
+}
+
+#[allow(clippy::too_many_arguments)]
+async fn run_start(
+ wallet: String,
+ model: String,
+ router_url: String,
+ gpu_layers: i32,
+ max_concurrent: usize,
+ data_dir: Option,
+ quant: Option,
+ insecure: bool,
+ skip_update: bool,
+ max_context: Option,
+ kv_quant: String,
+) -> anyhow::Result<()> {
+ // Parse config
+ let config = Config::from_start_args(wallet, model, router_url, gpu_layers, max_concurrent, data_dir, quant, insecure, skip_update, max_context, kv_quant)?;
+
+ // Create identity
+ let identity = Identity::from_secret_hex(&config.secret_key_hex)?;
+ tracing::info!(address = %format!("0x{}", identity.address_hex), "node identity");
+
+ // Check for updates
+ if !config.skip_update {
+ match update::check_for_update().await {
+ Ok(update::UpdateAction::Force(version)) => {
+ tracing::warn!(%version, "mandatory update available, downloading...");
+ if let Err(e) = update::perform_update(&version).await {
+ tracing::error!(%e, "auto-update failed, continuing with current version");
+ } else {
+ tracing::info!("update complete — please restart the node");
+ return Ok(());
+ }
+ }
+ Ok(update::UpdateAction::Warn(version)) => {
+ tracing::warn!(
+ %version,
+ "new patch version available, update recommended (current: {})",
+ env!("CARGO_PKG_VERSION")
+ );
+ }
+ Ok(update::UpdateAction::UpToDate) => {}
+ Err(e) => {
+ tracing::debug!(%e, "update check failed, continuing");
+ }
+ }
+ }
+
+ // Ensure directories exist
+ std::fs::create_dir_all(&config.data_dir)?;
+ std::fs::create_dir_all(&config.models_dir)?;
+
+ // Parse KV cache quantization type
+ let kv_cache_type = parse_kv_quant(&config.kv_quant)?;
+
+ // Resolve and download models
+ let registry = default_registry();
+ let cache = ModelCache::new(config.models_dir.clone())?;
+
+ // Accumulate engines and TPS per model
+ let mut engines: HashMap<String, (inference::InferenceEngine, ModelType)> = HashMap::new();
+ let mut tps_map: HashMap<String, f64> = HashMap::new();
+
+ for model_name in &config.model_names {
+ let spec = resolve_model(model_name, &registry, config.quant.as_deref())
+ .ok_or_else(|| error::NodeError::Model(format!("unknown model: {model_name}")))?;
+
+ let (engine, tps) = download_and_load_model(&spec, &cache, config.gpu_layers, config.max_context, Some(kv_cache_type)).await?;
+
+ tracing::info!(tps = %format!("{tps:.1}"), model = %model_name, "benchmark complete");
+ engines.insert(model_name.clone(), (engine, spec.model_type));
+ tps_map.insert(model_name.clone(), tps);
+ }
+
+ if engines.is_empty() {
+ return Err(error::NodeError::Config("no models loaded".into()).into());
+ }
+
+ // Print banner
+ eprint!("{}", include_str!("../dnet.art"));
+
+ // Build the worker
+ let mut worker = Worker::new(engines, config.max_concurrent);
+
+ // Attempt router connection; try each URL, go offline if all unavailable
+ let mut connection: Option<RouterConnection> = None;
+ for url in &config.router_urls {
+ match RouterConnection::connect(
+ url,
+ config.insecure,
+ &identity,
+ config.model_names.clone(),
+ tps_map.clone(),
+ worker.capacity(),
+ )
+ .await
+ {
+ Ok(conn) => {
+ tracing::info!(node_id = %conn.node_id, router = %url, "connected to router");
+ connection = Some(conn);
+ break;
+ }
+ Err(e) => {
+ tracing::warn!(%e, router = %url, "failed to connect to router");
+ }
+ }
+ }
+ if connection.is_none() {
+ tracing::warn!("all routers unavailable, running in offline mode");
+ }
+
+ tracing::info!(
+ routers = ?config.router_urls,
+ models = ?config.model_names,
+ max_concurrent = config.max_concurrent,
+ insecure = config.insecure,
+ online = connection.is_some(),
+ "node ready"
+ );
+
+ // Build shared context for event handlers
+ let stats = Arc::new(NodeStats::new());
+ let mut ctx = NodeContext {
+ identity,
+ config,
+ tps: tps_map,
+ stats: Arc::clone(&stats),
+ cache,
+ };
+
+ // Channel for background model load results
+ let (model_tx, mut model_rx) = mpsc::unbounded_channel::<ModelLoadResult>();
+
+ // Main event loop
+ let mut stats_interval = tokio::time::interval(Duration::from_secs(60));
+ stats_interval.tick().await; // consume the immediate first tick
+ loop {
+ let event = tokio::select! {
+ msg = recv_router_msg(&mut connection) => Event::RouterMsg(msg),
+ Some(done) = worker.next_completed() => Event::TaskDone(done),
+ Some(loaded) = model_rx.recv() => Event::ModelLoaded(loaded),
+ _ = stats_interval.tick() => Event::StatsLog,
+ _ = tokio::signal::ctrl_c() => Event::Shutdown,
+ };
+
+ match event {
+ Event::RouterMsg(Ok(Some(msg))) => {
+ handle_router_message(msg, &mut worker, &mut connection, &mut ctx, &model_tx).await;
+ }
+ Event::RouterMsg(Ok(None)) => {
+ // Stream closed cleanly
+ tracing::warn!("router stream closed, attempting reconnect");
+ if let Some(ref conn) = connection {
+ conn.close();
+ }
+ connection = tokio::select! {
+ result = try_reconnect(&ctx, worker.capacity()) => result,
+ _ = tokio::signal::ctrl_c() => {
+ tracing::info!("shutdown signal received during reconnect");
+ break;
+ }
+ };
+ }
+ Event::RouterMsg(Err(e)) => {
+ tracing::warn!(%e, "router communication error, attempting reconnect");
+ if let Some(ref conn) = connection {
+ conn.close();
+ }
+ connection = tokio::select! {
+ result = try_reconnect(&ctx, worker.capacity()) => result,
+ _ = tokio::signal::ctrl_c() => {
+ tracing::info!("shutdown signal received during reconnect");
+ break;
+ }
+ };
+ }
+ Event::TaskDone(completed) => {
+ handle_completed_task(completed, &connection, &ctx.stats);
+ }
+ Event::ModelLoaded(loaded) => {
+ match loaded.result {
+ Ok((engine, tps)) => {
+ tracing::info!(
+ model = %loaded.name,
+ tps = %format!("{tps:.1}"),
+ "model loaded successfully"
+ );
+ worker.add_engine(loaded.name.clone(), engine, loaded.model_type);
+ ctx.tps.insert(loaded.name, tps);
+ }
+ Err(e) => {
+ tracing::error!(model = %loaded.name, %e, "failed to load model");
+ }
+ }
+ }
+ Event::StatsLog => {
+ ctx.stats.log_summary();
+ }
+ Event::Shutdown => {
+ tracing::info!("shutdown signal received");
+ break;
+ }
+ }
+ }
+
+ // Graceful shutdown: drain in-flight tasks with 30s timeout
+ if worker.has_in_flight() {
+ tracing::info!("draining in-flight tasks (30s timeout)");
+ let drain_deadline = tokio::time::Instant::now() + Duration::from_secs(30);
+
+ loop {
+ tokio::select! {
+ Some(completed) = worker.next_completed() => {
+ handle_completed_task(completed, &connection, &ctx.stats);
+ }
+ _ = tokio::time::sleep_until(drain_deadline) => {
+ tracing::warn!("drain timeout reached, dropping remaining tasks");
+ break;
+ }
+ }
+ if !worker.has_in_flight() {
+ break;
+ }
+ }
+ }
+
+ if let Some(ref conn) = connection {
+ conn.close();
+ }
+ tracing::info!("shutdown complete");
+
+ Ok(())
+}
+
+// ---------------------------------------------------------------------------
+// Event types for the select! loop
+// ---------------------------------------------------------------------------
+
+enum Event {
+ RouterMsg(Result