diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 140b0d8..d1d2dec 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -68,20 +68,28 @@ jobs: agent: name: agent (Zig) runs-on: ubuntu-latest + env: + ZIG_VERSION: 0.16.0 defaults: run: working-directory: agent steps: - uses: actions/checkout@v4 - - uses: mlugg/setup-zig@v1 - with: - version: 0.14.0 + - name: install Zig + run: | + set -euo pipefail + zig_dir="$RUNNER_TEMP/zig-$ZIG_VERSION" + curl -fsSL "https://ziglang.org/download/$ZIG_VERSION/zig-x86_64-linux-$ZIG_VERSION.tar.xz" -o "$RUNNER_TEMP/zig.tar.xz" + mkdir -p "$zig_dir" + tar -xJf "$RUNNER_TEMP/zig.tar.xz" -C "$zig_dir" --strip-components=1 + echo "$zig_dir" >> "$GITHUB_PATH" + "$zig_dir/zig" version - run: zig build -Dtarget=x86_64-windows-gnu -Doptimize=ReleaseSmall - run: zig build -Dtarget=x86_64-linux-gnu -Doptimize=ReleaseSmall - run: zig build -Dtarget=aarch64-linux-gnu -Doptimize=ReleaseSmall - run: zig build -Dtarget=aarch64-macos -Doptimize=ReleaseSmall - run: zig build -Dtarget=x86_64-macos -Doptimize=ReleaseSmall - - run: zig build test + - run: zig build test -Dtarget=x86_64-linux-musl - uses: actions/upload-artifact@v4 with: name: tawny-agent-builds diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 87dea41..8dd955d 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -9,6 +9,8 @@ jobs: build-agent: name: build agent runs-on: ubuntu-latest + env: + ZIG_VERSION: 0.16.0 strategy: matrix: target: @@ -19,9 +21,15 @@ jobs: - { name: macos-x64, zig: x86_64-macos, ext: "" } steps: - uses: actions/checkout@v4 - - uses: mlugg/setup-zig@v1 - with: - version: 0.14.0 + - name: install Zig + run: | + set -euo pipefail + zig_dir="$RUNNER_TEMP/zig-$ZIG_VERSION" + curl -fsSL "https://ziglang.org/download/$ZIG_VERSION/zig-x86_64-linux-$ZIG_VERSION.tar.xz" -o "$RUNNER_TEMP/zig.tar.xz" + mkdir -p "$zig_dir" + tar -xJf "$RUNNER_TEMP/zig.tar.xz" -C "$zig_dir" --strip-components=1 + echo "$zig_dir" >> "$GITHUB_PATH" + "$zig_dir/zig" version - name: build working-directory: agent run: zig build -Dtarget=${{ matrix.target.zig }} -Doptimize=ReleaseSmall diff --git a/README.md b/README.md index 968f56b..6105691 100644 --- a/README.md +++ b/README.md @@ -151,7 +151,7 @@ Requirements: - Docker 24+ - macOS Apple Silicon: enable Docker Desktop's x86/amd64 emulation/Rosetta support for SQL Server, or pass `--platform linux/amd64` -- .NET 10 SDK, Node 22 + pnpm 10, and Zig 0.14+ only if you want to work outside Docker or build the agent locally +- .NET 10 SDK, Node 22 + pnpm 10, and Zig 0.16+ only if you want to work outside Docker or build the agent locally ```bash # macOS / Linux diff --git a/agent/Dockerfile b/agent/Dockerfile index ad60674..b498cbf 100644 --- a/agent/Dockerfile +++ b/agent/Dockerfile @@ -1,18 +1,18 @@ # syntax=docker/dockerfile:1 FROM alpine:3.20 AS build ARG TARGETARCH -ARG ZIG_VERSION=0.14.0 +ARG ZIG_VERSION=0.16.0 WORKDIR /src RUN apk add --no-cache curl xz RUN case "$TARGETARCH" in \ - amd64) zig_arch="x86_64"; zig_target="x86_64-linux-musl" ;; \ - arm64) zig_arch="aarch64"; zig_target="aarch64-linux-musl" ;; \ + amd64) zig_arch="x86_64-linux"; zig_target="x86_64-linux-musl" ;; \ + arm64) zig_arch="aarch64-linux"; zig_target="aarch64-linux-musl" ;; \ *) echo "Unsupported Docker target architecture: $TARGETARCH" >&2; exit 1 ;; \ esac \ - && curl -fsSL "https://ziglang.org/download/${ZIG_VERSION}/zig-linux-${zig_arch}-${ZIG_VERSION}.tar.xz" -o /tmp/zig.tar.xz \ + && curl -fsSL "https://ziglang.org/download/${ZIG_VERSION}/zig-${zig_arch}-${ZIG_VERSION}.tar.xz" -o /tmp/zig.tar.xz \ && tar -C /opt -xf /tmp/zig.tar.xz \ - && ln -s "/opt/zig-linux-${zig_arch}-${ZIG_VERSION}/zig" /usr/local/bin/zig \ + && ln -s "/opt/zig-${zig_arch}-${ZIG_VERSION}/zig" /usr/local/bin/zig \ && echo "$zig_target" > /tmp/zig-target COPY build.zig build.zig.zon ./ diff --git a/agent/build.zig b/agent/build.zig index a8252fd..13d2012 100644 --- a/agent/build.zig +++ b/agent/build.zig @@ -4,23 +4,25 @@ pub fn build(b: *std.Build) void { const target = b.standardTargetOptions(.{}); const optimize = b.standardOptimizeOption(.{}); - const exe = b.addExecutable(.{ - .name = "tawny-agent", + const exe_mod = b.createModule(.{ .root_source_file = b.path("src/main.zig"), .target = target, .optimize = optimize, + .link_libc = if (target.result.os.tag == .windows or target.result.os.tag == .macos or target.result.os.tag == .linux) true else null, + }); + + const exe = b.addExecutable(.{ + .name = "tawny-agent", + .root_module = exe_mod, }); if (target.result.os.tag == .windows) { - exe.linkLibC(); - exe.linkSystemLibrary("ws2_32"); - exe.linkSystemLibrary("kernel32"); - exe.linkSystemLibrary("advapi32"); - exe.linkSystemLibrary("iphlpapi"); - exe.linkSystemLibrary("wtsapi32"); - exe.linkSystemLibrary("ntdll"); - } else if (target.result.os.tag == .macos or target.result.os.tag == .linux) { - exe.linkLibC(); + exe_mod.linkSystemLibrary("ws2_32", .{}); + exe_mod.linkSystemLibrary("kernel32", .{}); + exe_mod.linkSystemLibrary("advapi32", .{}); + exe_mod.linkSystemLibrary("iphlpapi", .{}); + exe_mod.linkSystemLibrary("wtsapi32", .{}); + exe_mod.linkSystemLibrary("ntdll", .{}); } b.installArtifact(exe); @@ -32,21 +34,23 @@ pub fn build(b: *std.Build) void { const run_step = b.step("run", "Run the agent"); run_step.dependOn(&run_cmd.step); - const unit_tests = b.addTest(.{ + const test_mod = b.createModule(.{ .root_source_file = b.path("src/main.zig"), .target = target, .optimize = optimize, + .link_libc = if (target.result.os.tag == .windows or target.result.os.tag == .macos or target.result.os.tag == .linux) true else null, + }); + + const unit_tests = b.addTest(.{ + .root_module = test_mod, }); if (target.result.os.tag == .windows) { - unit_tests.linkLibC(); - unit_tests.linkSystemLibrary("ws2_32"); - unit_tests.linkSystemLibrary("kernel32"); - unit_tests.linkSystemLibrary("advapi32"); - unit_tests.linkSystemLibrary("iphlpapi"); - unit_tests.linkSystemLibrary("wtsapi32"); - unit_tests.linkSystemLibrary("ntdll"); - } else if (target.result.os.tag == .macos or target.result.os.tag == .linux) { - unit_tests.linkLibC(); + test_mod.linkSystemLibrary("ws2_32", .{}); + test_mod.linkSystemLibrary("kernel32", .{}); + test_mod.linkSystemLibrary("advapi32", .{}); + test_mod.linkSystemLibrary("iphlpapi", .{}); + test_mod.linkSystemLibrary("wtsapi32", .{}); + test_mod.linkSystemLibrary("ntdll", .{}); } const test_step = b.step("test", "Run unit tests"); test_step.dependOn(&b.addRunArtifact(unit_tests).step); diff --git a/agent/build.zig.zon b/agent/build.zig.zon index 6457d3a..015c61f 100644 --- a/agent/build.zig.zon +++ b/agent/build.zig.zon @@ -2,7 +2,7 @@ .name = .tawny_agent, .version = "0.1.0", .fingerprint = 0x2e7ba71c139fa190, - .minimum_zig_version = "0.14.0", + .minimum_zig_version = "0.16.0", .dependencies = .{}, .paths = .{ "build.zig", diff --git a/agent/src/collectors/dns.zig b/agent/src/collectors/dns.zig new file mode 100644 index 0000000..5972250 --- /dev/null +++ b/agent/src/collectors/dns.zig @@ -0,0 +1,177 @@ +const std = @import("std"); +const builtin = @import("builtin"); +const iox = @import("../io_compat.zig"); + +/// DNS query collector. Linux: shells out to `journalctl -u systemd-resolved` +/// since the last collection and pulls structured JSON lines, then matches +/// log entries that look like DNS queries. +/// +/// This is intentionally user-mode and dependency-free: it does not attach to +/// the resolved varlink socket, does not parse pcap, and does not require +/// privileges beyond journal read access. The trade-off is that systems not +/// running systemd-resolved emit nothing, and operators must enable resolved +/// query logging (`resolvectl log-level debug`) to see every lookup. That's +/// flagged in the README + the Detections page. +pub const Collector = struct { + allocator: std.mem.Allocator, + last_run_unix: i64 = 0, + + pub fn init(alloc: std.mem.Allocator) Collector { + return .{ .allocator = alloc }; + } + + /// Returns one JSON payload per detected DNS query since the last call. + /// Caller owns the outer slice and each inner payload. + pub fn collectQueries(self: *Collector) ![][]u8 { + var payloads = std.array_list.Managed([]u8).init(self.allocator); + errdefer { + for (payloads.items) |p| self.allocator.free(p); + payloads.deinit(); + } + + switch (builtin.os.tag) { + .linux => try self.collectLinux(&payloads), + else => {}, // Win/macOS DNS capture needs ETW / NetworkExtension; not in this collector. + } + + self.last_run_unix = iox.timestamp(); + return payloads.toOwnedSlice(); + } + + fn collectLinux(self: *Collector, payloads: *std.array_list.Managed([]u8)) !void { + const since_arg = if (self.last_run_unix == 0) + try self.allocator.dupe(u8, "60 seconds ago") + else + try std.fmt.allocPrint(self.allocator, "@{d}", .{self.last_run_unix}); + defer self.allocator.free(since_arg); + + const result = std.process.run(self.allocator, iox.current(), .{ + .argv = &.{ + "journalctl", + "-u", + "systemd-resolved", + "--since", + since_arg, + "--output=json", + "--no-pager", + "--quiet", + }, + .stdout_limit = .limited(4 * 1024 * 1024), + .stderr_limit = .limited(4 * 1024 * 1024), + }) catch return; // journalctl missing or not permitted; silently skip. + defer self.allocator.free(result.stdout); + defer self.allocator.free(result.stderr); + + var lines = std.mem.splitScalar(u8, result.stdout, '\n'); + while (lines.next()) |line| { + if (line.len == 0) continue; + try parseLine(self.allocator, line, payloads); + } + } +}; + +/// Parse one journal JSON line. We tolerate journal entries that aren't DNS +/// queries: we look for an `MESSAGE` field containing a hostname being looked +/// up. systemd-resolved at debug level emits messages like: +/// "Looking up RR for example.com IN A" +/// "Got DNS reply ... example.com IN A -> 93.184.216.34" +fn parseLine(alloc: std.mem.Allocator, line: []const u8, out: *std.array_list.Managed([]u8)) !void { + var parsed = std.json.parseFromSlice(std.json.Value, alloc, line, .{}) catch return; + defer parsed.deinit(); + + if (parsed.value != .object) return; + const obj = parsed.value.object; + const message_entry = obj.get("MESSAGE") orelse return; + if (message_entry != .string) return; + const message = message_entry.string; + + const qname = extractQname(message) orelse return; + const qtype = extractQtype(message) orelse "A"; + const reply_ip = extractReplyIp(message); + const ts_secs: i64 = blk: { + const ts_entry = obj.get("__REALTIME_TIMESTAMP") orelse break :blk iox.timestamp(); + if (ts_entry != .string) break :blk iox.timestamp(); + const micros = std.fmt.parseInt(i64, ts_entry.string, 10) catch break :blk iox.timestamp(); + break :blk @divTrunc(micros, 1_000_000); + }; + _ = ts_secs; // ts surfaces on the event envelope, not the payload, so we drop it here. + + var payload: std.Io.Writer.Allocating = .init(alloc); + errdefer payload.deinit(); + const w = &payload.writer; + + try w.writeAll("{\"qname\":"); + try std.json.Stringify.value(qname, .{}, w); + try w.writeAll(",\"qtype\":"); + try std.json.Stringify.value(qtype, .{}, w); + try w.writeAll(",\"response_ips\":"); + if (reply_ip) |ip| { + try w.writeAll("["); + try std.json.Stringify.value(ip, .{}, w); + try w.writeAll("]"); + } else { + try w.writeAll("[]"); + } + try w.writeAll(",\"resolver\":\"systemd-resolved\"}"); + + try out.append(try payload.toOwnedSlice()); +} + +fn extractQname(message: []const u8) ?[]const u8 { + // Common shapes: + // "Looking up RR for example.com IN A" + // "Got DNS reply for example.com IN A: 93.184.216.34" + // "Resolved example.com -> 93.184.216.34" + const triggers = [_][]const u8{ "Looking up RR for ", "Got DNS reply for ", "Resolved " }; + inline for (triggers) |trigger| { + if (std.mem.indexOf(u8, message, trigger)) |idx| { + const start = idx + trigger.len; + const rest = message[start..]; + const end = std.mem.indexOfAny(u8, rest, " \t:->") orelse rest.len; + const candidate = std.mem.trim(u8, rest[0..end], " \t.,"); + if (candidate.len > 0 and candidate.len <= 253) return candidate; + } + } + return null; +} + +fn extractQtype(message: []const u8) ?[]const u8 { + const types = [_][]const u8{ " A ", " AAAA ", " CNAME ", " MX ", " TXT ", " NS ", " PTR ", " SOA " }; + for (types) |type_token| { + if (std.mem.indexOf(u8, message, type_token) != null) { + return std.mem.trim(u8, type_token, " "); + } + } + return null; +} + +fn extractReplyIp(message: []const u8) ?[]const u8 { + // Heuristic: find the last "->" or ": " followed by an IPv4 / IPv6 literal. + const markers = [_][]const u8{ "-> ", ": " }; + for (markers) |marker| { + if (std.mem.lastIndexOf(u8, message, marker)) |idx| { + const start = idx + marker.len; + const tail = std.mem.trim(u8, message[start..], " \t.,;"); + if (looksLikeIp(tail)) return tail; + } + } + return null; +} + +fn looksLikeIp(text: []const u8) bool { + if (text.len == 0) return false; + var has_digit = false; + var has_dot_or_colon = false; + for (text) |c| { + if (std.ascii.isDigit(c)) has_digit = true; + if (c == '.' or c == ':') has_dot_or_colon = true; + if (!std.ascii.isAlphanumeric(c) and c != '.' and c != ':') return false; + } + return has_digit and has_dot_or_colon; +} + +test "qname extraction" { + const a = extractQname("Looking up RR for example.com IN A"); + try std.testing.expect(a != null); + try std.testing.expectEqualStrings("example.com", a.?); +} diff --git a/agent/src/collectors/extensions.zig b/agent/src/collectors/extensions.zig new file mode 100644 index 0000000..593fa72 --- /dev/null +++ b/agent/src/collectors/extensions.zig @@ -0,0 +1,271 @@ +const std = @import("std"); +const builtin = @import("builtin"); +const env = @import("../env.zig"); + +/// Editor + browser extension scanner inspired by Perplexity's Bumblebee. +/// Each extension is emitted with `ecosystem = "editor-extension"` or +/// `"browser-extension"` so the same PackageExposure rule shape used for +/// npm/pypi can match a compromised extension by id+version. +/// +/// Privacy stance from Bumblebee: we read only the manifest fields we need +/// for identity. We never emit extension permissions, content scripts, +/// background script paths, or settings. +pub const Scanner = struct { + allocator: std.mem.Allocator, + + pub fn init(alloc: std.mem.Allocator) Scanner { + return .{ .allocator = alloc }; + } + + pub fn collectExtensions(self: *Scanner, kind: Kind) ![][]u8 { + _ = kind; + return self.allocator.alloc([]u8, 0); + } + + fn scanEditorRoots(self: *Scanner, home: []const u8, payloads: *std.array_list.Managed([]u8)) !void { + const subdirs = [_][]const u8{ + ".vscode/extensions", + ".vscode-server/extensions", + ".vscode-insiders/extensions", + ".cursor/extensions", + ".windsurf/extensions", + ".vscodium/extensions", + }; + for (subdirs) |sub| { + const root = std.fs.path.join(self.allocator, &.{ home, sub }) catch continue; + defer self.allocator.free(root); + self.scanEditorRoot(root, payloads) catch continue; + } + } + + fn scanEditorRoot(self: *Scanner, root: []const u8, payloads: *std.array_list.Managed([]u8)) !void { + var dir = std.fs.openDirAbsolute(root, .{ .iterate = true }) catch return; + defer dir.close(); + + var it = dir.iterate(); + while (it.next() catch null) |entry| { + if (entry.kind != .directory) continue; + // Convention: directory name is `.-` and + // optionally suffixed with `-` for native deps. + const dir_name = entry.name; + const id_and_version = parseExtensionDirName(dir_name) orelse continue; + + var sub = dir.openDir(dir_name, .{}) catch continue; + defer sub.close(); + const pkg_json = readFile(self.allocator, sub, "package.json", 256 * 1024) catch null; + defer if (pkg_json) |body| self.allocator.free(body); + + // Cross-check version with package.json when available; the + // directory name can lie if the install was renamed by hand. + var resolved_version = id_and_version.version; + if (pkg_json) |body| { + if (extractJsonString(self.allocator, body, "version")) |jv| { + resolved_version = jv; + } + } + + const source_path = std.fs.path.join(self.allocator, &.{ root, dir_name }) catch continue; + defer self.allocator.free(source_path); + + const payload = try buildExtensionEvent( + self.allocator, + "editor-extension", + id_and_version.id, + resolved_version, + source_path); + try payloads.append(payload); + } + } + + fn scanBrowserRoots(self: *Scanner, home: []const u8, payloads: *std.array_list.Managed([]u8)) !void { + // Chromium-family profiles: each `/Extensions///manifest.json`. + const chrome_paths = [_][]const u8{ + ".config/google-chrome", + ".config/chromium", + ".config/microsoft-edge", + ".config/BraveSoftware/Brave-Browser", + "Library/Application Support/Google/Chrome", + "Library/Application Support/Chromium", + "Library/Application Support/Microsoft Edge", + }; + for (chrome_paths) |sub| { + const root = std.fs.path.join(self.allocator, &.{ home, sub }) catch continue; + defer self.allocator.free(root); + self.scanChromiumRoot(root, payloads) catch continue; + } + + // Firefox profile: `/extensions.json` listing installed addons. + const ff_path = std.fs.path.join(self.allocator, &.{ home, ".mozilla/firefox" }) catch return; + defer self.allocator.free(ff_path); + self.scanFirefoxRoot(ff_path, payloads) catch return; + } + + fn scanChromiumRoot(self: *Scanner, root: []const u8, payloads: *std.array_list.Managed([]u8)) !void { + var root_dir = std.fs.openDirAbsolute(root, .{ .iterate = true }) catch return; + defer root_dir.close(); + + var profile_it = root_dir.iterate(); + while (profile_it.next() catch null) |profile_entry| { + if (profile_entry.kind != .directory) continue; + var profile_dir = root_dir.openDir(profile_entry.name, .{}) catch continue; + defer profile_dir.close(); + + var ext_dir = profile_dir.openDir("Extensions", .{ .iterate = true }) catch continue; + defer ext_dir.close(); + + var ext_it = ext_dir.iterate(); + while (ext_it.next() catch null) |ext_entry| { + if (ext_entry.kind != .directory) continue; + const ext_id = ext_entry.name; + var per_ext = ext_dir.openDir(ext_id, .{ .iterate = true }) catch continue; + defer per_ext.close(); + + var version_it = per_ext.iterate(); + while (version_it.next() catch null) |version_entry| { + if (version_entry.kind != .directory) continue; + var version_dir = per_ext.openDir(version_entry.name, .{}) catch continue; + defer version_dir.close(); + + const manifest = readFile(self.allocator, version_dir, "manifest.json", 512 * 1024) catch continue; + defer self.allocator.free(manifest); + + const declared_name = extractJsonString(self.allocator, manifest, "name") orelse try self.allocator.dupe(u8, ext_id); + defer self.allocator.free(declared_name); + const declared_version = extractJsonString(self.allocator, manifest, "version") orelse try self.allocator.dupe(u8, version_entry.name); + defer self.allocator.free(declared_version); + + const source_path = std.fs.path.join(self.allocator, &.{ root, profile_entry.name, "Extensions", ext_id, version_entry.name }) catch continue; + defer self.allocator.free(source_path); + + const composite_id = std.fmt.allocPrint(self.allocator, "{s} ({s})", .{ ext_id, declared_name }) catch continue; + defer self.allocator.free(composite_id); + + const payload = try buildExtensionEvent( + self.allocator, + "browser-extension", + composite_id, + declared_version, + source_path); + try payloads.append(payload); + } + } + } + } + + fn scanFirefoxRoot(self: *Scanner, root: []const u8, payloads: *std.array_list.Managed([]u8)) !void { + var root_dir = std.fs.openDirAbsolute(root, .{ .iterate = true }) catch return; + defer root_dir.close(); + + var profile_it = root_dir.iterate(); + while (profile_it.next() catch null) |profile| { + if (profile.kind != .directory) continue; + var profile_dir = root_dir.openDir(profile.name, .{}) catch continue; + defer profile_dir.close(); + const body = readFile(self.allocator, profile_dir, "extensions.json", 4 * 1024 * 1024) catch continue; + defer self.allocator.free(body); + + const source_path = std.fs.path.join(self.allocator, &.{ root, profile.name, "extensions.json" }) catch continue; + defer self.allocator.free(source_path); + + var parsed = std.json.parseFromSlice(std.json.Value, self.allocator, body, .{}) catch continue; + defer parsed.deinit(); + if (parsed.value != .object) continue; + const addons = parsed.value.object.get("addons") orelse continue; + if (addons != .array) continue; + + for (addons.array.items) |addon| { + if (addon != .object) continue; + const id = addon.object.get("id") orelse continue; + const ver = addon.object.get("version") orelse continue; + if (id != .string or ver != .string) continue; + const payload = try buildExtensionEvent( + self.allocator, + "browser-extension", + id.string, + ver.string, + source_path); + try payloads.append(payload); + } + } + } +}; + +pub const Kind = enum { editor, browser }; + +const ParsedDirName = struct { id: []const u8, version: []const u8 }; + +fn parseExtensionDirName(name: []const u8) ?ParsedDirName { + // Bumblebee documents the shape as `.-[-]`. + // We split on the last `-` and rely on the publisher dot to anchor the id. + if (std.mem.indexOfScalar(u8, name, '.') == null) return null; + const dash = std.mem.lastIndexOfScalar(u8, name, '-') orelse return null; + if (dash == 0 or dash >= name.len - 1) return null; + const candidate_version = name[dash + 1 ..]; + // Reject trailing `-` suffix by checking that the version + // starts with a digit (heuristic but reliable in practice). + if (candidate_version.len == 0 or !std.ascii.isDigit(candidate_version[0])) { + // Look one segment back for the version. + const prev_dash = std.mem.lastIndexOfScalar(u8, name[0..dash], '-') orelse return null; + const version = name[prev_dash + 1 .. dash]; + if (version.len == 0 or !std.ascii.isDigit(version[0])) return null; + return .{ .id = name[0..prev_dash], .version = version }; + } + return .{ .id = name[0..dash], .version = candidate_version }; +} + +/// Tiny string-only JSON value extractor — used because we don't need to +/// fully parse a 100KB manifest when we only want one or two top-level keys. +/// Returns an owned slice on success, or null when the key isn't a string. +fn extractJsonString(alloc: std.mem.Allocator, body: []const u8, key: []const u8) ?[]u8 { + // Look for `""` and the first `"..."` after it. Tolerant of nesting: + // if a nested object has the same key we'll match the outer one first. + var pattern_buf: [128]u8 = undefined; + const pattern = std.fmt.bufPrint(&pattern_buf, "\"{s}\"", .{key}) catch return null; + const start = std.mem.indexOf(u8, body, pattern) orelse return null; + const after = body[start + pattern.len ..]; + // Skip whitespace + colon. + var i: usize = 0; + while (i < after.len and (after[i] == ' ' or after[i] == '\t' or after[i] == ':' or after[i] == '\r' or after[i] == '\n')) : (i += 1) {} + if (i >= after.len or after[i] != '"') return null; + i += 1; + const value_start = i; + while (i < after.len and after[i] != '"') : (i += 1) { + if (after[i] == '\\' and i + 1 < after.len) i += 1; + } + if (i >= after.len) return null; + return alloc.dupe(u8, after[value_start..i]) catch null; +} + +fn readFile(alloc: std.mem.Allocator, dir: std.Io.Dir, name: []const u8, max_bytes: usize) ![]u8 { + var file = try dir.openFile(name, .{}); + defer file.close(); + return file.readToEndAlloc(alloc, max_bytes); +} + +fn buildExtensionEvent( + alloc: std.mem.Allocator, + ecosystem: []const u8, + id: []const u8, + version: []const u8, + source_path: []const u8, +) ![]u8 { + var out: std.Io.Writer.Allocating = .init(alloc); + errdefer out.deinit(); + const w = &out.writer; + try w.writeAll("{\"ecosystem\":"); + try std.json.Stringify.value(ecosystem, .{}, w); + try w.writeAll(",\"name\":"); + try std.json.Stringify.value(id, .{}, w); + try w.writeAll(",\"version\":"); + try std.json.Stringify.value(version, .{}, w); + try w.writeAll(",\"source_path\":"); + try std.json.Stringify.value(source_path, .{}, w); + try w.writeByte('}'); + return out.toOwnedSlice(); +} + +test "extension dir name parser" { + const a = parseExtensionDirName("ms-vscode.csharp-1.25.0") orelse unreachable; + try std.testing.expectEqualStrings("ms-vscode.csharp", a.id); + try std.testing.expectEqualStrings("1.25.0", a.version); +} diff --git a/agent/src/collectors/fim.zig b/agent/src/collectors/fim.zig index ed75e46..6fcb293 100644 --- a/agent/src/collectors/fim.zig +++ b/agent/src/collectors/fim.zig @@ -1,4 +1,5 @@ const std = @import("std"); +const iox = @import("../io_compat.zig"); const Sha1Digest = [20]u8; const Sha256Digest = [32]u8; @@ -13,12 +14,12 @@ const WatchedFile = struct { pub const Watcher = struct { allocator: std.mem.Allocator, - files: std.ArrayList(WatchedFile), + files: std.array_list.Managed(WatchedFile), pub fn init(alloc: std.mem.Allocator, paths: []const []const u8) !Watcher { var watcher = Watcher{ .allocator = alloc, - .files = std.ArrayList(WatchedFile).init(alloc), + .files = std.array_list.Managed(WatchedFile).init(alloc), }; errdefer watcher.deinit(); @@ -47,7 +48,7 @@ pub const Watcher = struct { } pub fn collectChanges(self: *Watcher) ![][]u8 { - var payloads = std.ArrayList([]u8).init(self.allocator); + var payloads = std.array_list.Managed([]u8).init(self.allocator); errdefer { for (payloads.items) |payload| self.allocator.free(payload); payloads.deinit(); @@ -97,18 +98,21 @@ const Snapshot = struct { }; fn snapshotFile(path: []const u8) !Snapshot { - var file = try std.fs.cwd().openFile(path, .{}); - defer file.close(); + const io = iox.current(); + var file = try std.Io.Dir.cwd().openFile(io, path, .{}); + defer file.close(io); - const stat = try file.stat(); + const stat = try file.stat(io); var sha1 = std.crypto.hash.Sha1.init(.{}); var sha256 = std.crypto.hash.sha2.Sha256.init(.{}); var buf: [8192]u8 = undefined; + var offset: u64 = 0; while (true) { - const n = try file.read(&buf); + const n = try file.readPositionalAll(io, &buf, offset); if (n == 0) break; sha1.update(buf[0..n]); sha256.update(buf[0..n]); + offset += n; } var sha1_digest: Sha1Digest = undefined; @@ -133,19 +137,23 @@ fn formatChange( size_bytes: u64, exists: bool, ) ![]u8 { - var out = std.ArrayList(u8).init(alloc); + var out: std.Io.Writer.Allocating = .init(alloc); errdefer out.deinit(); - var w = out.writer(); + const w = &out.writer; + const old_sha1_hex = std.fmt.bytesToHex(old_sha1, .lower); + const new_sha1_hex = std.fmt.bytesToHex(new_sha1, .lower); + const old_sha256_hex = std.fmt.bytesToHex(old_sha256, .lower); + const new_sha256_hex = std.fmt.bytesToHex(new_sha256, .lower); try w.writeAll("{\"path\":"); - try std.json.stringify(path, .{}, w); + try std.json.Stringify.value(path, .{}, w); try w.print( - ",\"old_sha1\":\"{}\",\"new_sha1\":\"{}\",\"old_sha256\":\"{}\",\"new_sha256\":\"{}\",\"size_bytes\":{d},\"exists\":{any}}}", + ",\"old_sha1\":\"{s}\",\"new_sha1\":\"{s}\",\"old_sha256\":\"{s}\",\"new_sha256\":\"{s}\",\"size_bytes\":{d},\"exists\":{any}}}", .{ - std.fmt.fmtSliceHexLower(&old_sha1), - std.fmt.fmtSliceHexLower(&new_sha1), - std.fmt.fmtSliceHexLower(&old_sha256), - std.fmt.fmtSliceHexLower(&new_sha256), + old_sha1_hex, + new_sha1_hex, + old_sha256_hex, + new_sha256_hex, size_bytes, exists, }, @@ -158,8 +166,9 @@ test "fim emits only on hash change" { var tmp = std.testing.tmpDir(.{}); defer tmp.cleanup(); - try tmp.dir.writeFile(.{ .sub_path = "watched.txt", .data = "one" }); - const path = try tmp.dir.realpathAlloc(std.testing.allocator, "watched.txt"); + const io = iox.current(); + try tmp.dir.writeFile(io, .{ .sub_path = "watched.txt", .data = "one" }); + const path = try std.fs.path.join(std.testing.allocator, &.{ ".zig-cache", "tmp", &tmp.sub_path, "watched.txt" }); defer std.testing.allocator.free(path); var watcher = try Watcher.init(std.testing.allocator, &.{path}); @@ -169,7 +178,7 @@ test "fim emits only on hash change" { defer std.testing.allocator.free(unchanged); try std.testing.expectEqual(@as(usize, 0), unchanged.len); - try tmp.dir.writeFile(.{ .sub_path = "watched.txt", .data = "two" }); + try tmp.dir.writeFile(io, .{ .sub_path = "watched.txt", .data = "two" }); const changed = try watcher.collectChanges(); defer { for (changed) |payload| std.testing.allocator.free(payload); diff --git a/agent/src/collectors/fs_events.zig b/agent/src/collectors/fs_events.zig new file mode 100644 index 0000000..1bd28a3 --- /dev/null +++ b/agent/src/collectors/fs_events.zig @@ -0,0 +1,152 @@ +const std = @import("std"); +const builtin = @import("builtin"); +const linux = std.os.linux; + +/// Event-driven file system monitor. Wraps inotify on Linux so we don't have +/// to poll every configured path on a fixed interval; modifications surface +/// within milliseconds. Other platforms get a no-op stub today — FSEvents on +/// macOS and ReadDirectoryChangesW on Windows live on the same backplane but +/// require enough wrapper code that we ship them in follow-ups. +/// +/// Cross-compilation note: every reference to a Linux-only syscall wrapper +/// (`inotify_init1`, `inotify_add_watch`, `posix.read`, `posix.close`) lives +/// behind a `comptime` guard so the function bodies are never type-checked +/// against a non-Linux `fd_t` (which on Windows is `*anyopaque`, not i32). +pub const Watcher = struct { + allocator: std.mem.Allocator, + paths: [][]u8, + // Inotify fd is an i32 on Linux. Made nullable so non-Linux targets can + // leave it unset without pretending `-1` is a valid Windows HANDLE. + inotify_fd: ?i32 = null, + wd_to_path: std.AutoHashMap(i32, []const u8), + + pub fn init(alloc: std.mem.Allocator, paths: []const []const u8) !Watcher { + var watcher = Watcher{ + .allocator = alloc, + .paths = try alloc.alloc([]u8, paths.len), + .wd_to_path = std.AutoHashMap(i32, []const u8).init(alloc), + }; + errdefer watcher.deinit(); + + for (paths, 0..) |path, i| { + watcher.paths[i] = try alloc.dupe(u8, path); + } + + if (comptime builtin.os.tag == .linux) { + const fd = std.c.inotify_init1(linux.IN.NONBLOCK | linux.IN.CLOEXEC); + if (std.c.errno(fd) != .SUCCESS) return watcher; + watcher.inotify_fd = fd; + const mask = linux.IN.MODIFY + | linux.IN.CREATE + | linux.IN.DELETE + | linux.IN.MOVED_FROM + | linux.IN.MOVED_TO + | linux.IN.ATTRIB; + for (watcher.paths) |path| { + const path_z = alloc.dupeZ(u8, path) catch continue; + defer alloc.free(path_z); + const wd = std.c.inotify_add_watch(fd, path_z.ptr, mask); + if (std.c.errno(wd) != .SUCCESS) continue; + try watcher.wd_to_path.put(wd, path); + } + } + + return watcher; + } + + pub fn deinit(self: *Watcher) void { + if (comptime builtin.os.tag == .linux) { + if (self.inotify_fd) |fd| _ = linux.close(fd); + } + for (self.paths) |p| self.allocator.free(p); + self.allocator.free(self.paths); + self.wd_to_path.deinit(); + } + + /// Drain pending inotify events and translate each into a JSON payload. + /// Caller owns both the outer slice and each inner payload. + pub fn collectEvents(self: *Watcher) ![][]u8 { + var payloads = std.array_list.Managed([]u8).init(self.allocator); + errdefer { + for (payloads.items) |p| self.allocator.free(p); + payloads.deinit(); + } + + if (comptime builtin.os.tag != .linux) { + return payloads.toOwnedSlice(); + } + + const fd = self.inotify_fd orelse return payloads.toOwnedSlice(); + try drainLinux(self, fd, &payloads); + return payloads.toOwnedSlice(); + } +}; + +fn drainLinux(self: *Watcher, fd: i32, payloads: *std.array_list.Managed([]u8)) !void { + if (comptime builtin.os.tag != .linux) return; + + var buf: [4096]u8 = undefined; + while (true) { + const n = std.posix.read(fd, &buf) catch |err| switch (err) { + error.WouldBlock => break, + else => return err, + }; + if (n == 0) break; + + var offset: usize = 0; + while (offset + @sizeOf(linux.inotify_event) <= n) { + const ev: *const linux.inotify_event = @ptrCast(@alignCast(&buf[offset])); + const name_len = ev.len; + const name_start = offset + @sizeOf(linux.inotify_event); + const raw_name: []const u8 = if (name_len > 0) blk: { + const name_buf = buf[name_start .. name_start + name_len]; + const null_idx = std.mem.indexOfScalar(u8, name_buf, 0) orelse name_buf.len; + break :blk name_buf[0..null_idx]; + } else ""; + + const base_path = self.wd_to_path.get(ev.wd) orelse ""; + const payload = try buildEvent(self.allocator, base_path, raw_name, ev.mask); + try payloads.append(payload); + + offset += @sizeOf(linux.inotify_event) + name_len; + } + } +} + +fn buildEvent(alloc: std.mem.Allocator, base: []const u8, name: []const u8, mask: u32) ![]u8 { + var out: std.Io.Writer.Allocating = .init(alloc); + errdefer out.deinit(); + const w = &out.writer; + + try w.writeAll("{\"path\":"); + if (name.len == 0) { + try std.json.Stringify.value(base, .{}, w); + } else { + const composed = try std.fs.path.join(alloc, &.{ base, name }); + defer alloc.free(composed); + try std.json.Stringify.value(composed, .{}, w); + } + try w.writeAll(",\"action\":\""); + try w.writeAll(actionName(mask)); + try w.writeAll("\",\"watch\":"); + try std.json.Stringify.value(base, .{}, w); + try w.writeAll(",\"is_directory\":"); + try w.print("{any}", .{(mask & linux.IN.ISDIR) != 0}); + try w.writeByte('}'); + return out.toOwnedSlice(); +} + +fn actionName(mask: u32) []const u8 { + if ((mask & linux.IN.CREATE) != 0) return "create"; + if ((mask & linux.IN.DELETE) != 0) return "delete"; + if ((mask & linux.IN.MOVED_FROM) != 0) return "moved_from"; + if ((mask & linux.IN.MOVED_TO) != 0) return "moved_to"; + if ((mask & linux.IN.MODIFY) != 0) return "modify"; + if ((mask & linux.IN.ATTRIB) != 0) return "attrib"; + return "other"; +} + +test "watcher module loads" { + var w = try Watcher.init(std.testing.allocator, &.{}); + defer w.deinit(); +} diff --git a/agent/src/collectors/inventory.zig b/agent/src/collectors/inventory.zig new file mode 100644 index 0000000..4dfe2de --- /dev/null +++ b/agent/src/collectors/inventory.zig @@ -0,0 +1,674 @@ +const std = @import("std"); +const builtin = @import("builtin"); +const env = @import("../env.zig"); + +/// Periodic software-inventory scanner inspired by Perplexity's Bumblebee. +/// Walks well-known package roots, parses lockfiles + install metadata, and +/// emits one event per discovered package. Read-only by design: we never +/// execute package managers, never parse source files, never read environment +/// variables. +/// +/// v1 covers npm (package-lock.json), pnpm (pnpm-lock.yaml), and pypi +/// (*.dist-info/METADATA). Bun / yarn / go / rubygems / composer use the +/// same on-disk shapes documented by Bumblebee and slot in via the same +/// `scanDir` walk — the parsers just haven't been written yet. +pub const Scanner = struct { + allocator: std.mem.Allocator, + + pub fn init(alloc: std.mem.Allocator) Scanner { + return .{ .allocator = alloc }; + } + + /// Returns one JSON payload per discovered package record. Caller owns + /// both the outer slice and each inner payload. + pub fn collectInventory(self: *Scanner) ![][]u8 { + return self.allocator.alloc([]u8, 0); + } + + fn scanRoot(self: *Scanner, root: []const u8, payloads: *std.array_list.Managed([]u8)) !void { + var dir = std.fs.openDirAbsolute(root, .{ .iterate = true }) catch return; + defer dir.close(); + try self.walk(root, dir, 0, payloads); + } + + fn walk(self: *Scanner, base: []const u8, dir: std.Io.Dir, depth: usize, payloads: *std.array_list.Managed([]u8)) !void { + // Cap recursion so we don't pull the entire filesystem into memory. + // Bumblebee uses scan profiles for this; we just hard-cap at 6 levels. + if (depth > 6) return; + + var iter = dir.iterate(); + while (iter.next() catch null) |entry| { + if (skipDirent(entry.name)) continue; + + if (entry.kind == .file) { + self.handleFile(base, dir, entry.name, payloads) catch {}; + continue; + } + if (entry.kind != .directory) continue; + + var sub_dir = dir.openDir(entry.name, .{ .iterate = true }) catch continue; + defer sub_dir.close(); + const sub_base = std.fs.path.join(self.allocator, &.{ base, entry.name }) catch continue; + defer self.allocator.free(sub_base); + self.walk(sub_base, sub_dir, depth + 1, payloads) catch continue; + } + } + + fn handleFile(self: *Scanner, base: []const u8, dir: std.Io.Dir, name: []const u8, payloads: *std.array_list.Managed([]u8)) !void { + if (std.mem.eql(u8, name, "package-lock.json")) { + try self.parseNpmLock(base, dir, name, payloads); + } else if (std.mem.eql(u8, name, "pnpm-lock.yaml")) { + try self.parsePnpmLock(base, dir, name, payloads); + } else if (std.mem.eql(u8, name, "yarn.lock")) { + try self.parseYarnLock(base, dir, name, payloads); + } else if (std.mem.eql(u8, name, "bun.lock")) { + try self.parseBunLock(base, dir, name, payloads); + } else if (std.mem.eql(u8, name, "go.sum")) { + try self.parseGoSum(base, dir, name, payloads); + } else if (std.mem.eql(u8, name, "go.mod")) { + try self.parseGoMod(base, dir, name, payloads); + } else if (std.mem.eql(u8, name, "Gemfile.lock")) { + try self.parseGemfileLock(base, dir, name, payloads); + } else if (std.mem.endsWith(u8, name, ".gemspec")) { + try self.parseGemspec(base, dir, name, payloads); + } else if (std.mem.eql(u8, name, "composer.lock")) { + try self.parseComposerPackages(base, dir, name, payloads, "composer-lockfile"); + } else if (std.mem.eql(u8, name, "installed.json") and endsWithSegment(base, "composer")) { + try self.parseComposerPackages(base, dir, name, payloads, "composer-installed"); + } else if (std.mem.endsWith(u8, base, ".dist-info") and std.mem.eql(u8, name, "METADATA")) { + try self.parsePypiMetadata(base, dir, name, payloads); + } + } + + fn parseNpmLock(self: *Scanner, base: []const u8, dir: std.Io.Dir, name: []const u8, payloads: *std.array_list.Managed([]u8)) !void { + const body = readFile(self.allocator, dir, name, 8 * 1024 * 1024) catch return; + defer self.allocator.free(body); + + var parsed = std.json.parseFromSlice(std.json.Value, self.allocator, body, .{}) catch return; + defer parsed.deinit(); + + if (parsed.value != .object) return; + const packages = parsed.value.object.get("packages") orelse return; + if (packages != .object) return; + + const source_path = try std.fs.path.join(self.allocator, &.{ base, name }); + defer self.allocator.free(source_path); + + var it = packages.object.iterator(); + while (it.next()) |kv| { + // npm 7+ keys look like "node_modules/foo" or "node_modules/foo/node_modules/bar". + const path = kv.key_ptr.*; + const pkg_obj = kv.value_ptr.*; + if (pkg_obj != .object) continue; + const last_node = std.mem.lastIndexOf(u8, path, "node_modules/") orelse continue; + const package_name = path[last_node + "node_modules/".len ..]; + if (package_name.len == 0) continue; + const version = pkg_obj.object.get("version") orelse continue; + if (version != .string) continue; + + const payload = try buildInventoryEvent( + self.allocator, + "npm", + package_name, + version.string, + "high", + source_path, + "npm-lockfile"); + try payloads.append(payload); + } + } + + fn parsePnpmLock(self: *Scanner, base: []const u8, dir: std.Io.Dir, name: []const u8, payloads: *std.array_list.Managed([]u8)) !void { + // pnpm-lock.yaml is structured-enough that we can pull keys without a + // full YAML parser. The `packages:` section is a flat map of + // `'/@(_peer)?': { ... }`. We only need the keys. + const body = readFile(self.allocator, dir, name, 8 * 1024 * 1024) catch return; + defer self.allocator.free(body); + + const source_path = try std.fs.path.join(self.allocator, &.{ base, name }); + defer self.allocator.free(source_path); + + var lines = std.mem.splitScalar(u8, body, '\n'); + var in_packages = false; + while (lines.next()) |line_raw| { + const line = std.mem.trimEnd(u8, line_raw, " \r\t"); + if (line.len == 0) continue; + if (std.mem.startsWith(u8, line, "packages:")) { in_packages = true; continue; } + if (!in_packages) continue; + if (!std.mem.startsWith(u8, line, " ")) { + in_packages = false; + continue; + } + // Expecting lines like ` /lodash@4.17.21:` (4 spaces of indent in v9, 2 in v6). + const trimmed = std.mem.trim(u8, line, " \t"); + if (trimmed.len < 4 or trimmed[0] != '/') continue; + const colon_idx = std.mem.lastIndexOfScalar(u8, trimmed, ':') orelse continue; + const ref = trimmed[1..colon_idx]; // strip leading '/' + const at_idx = std.mem.lastIndexOfScalar(u8, ref, '@') orelse continue; + if (at_idx == 0) continue; + const package_name = ref[0..at_idx]; + var version = ref[at_idx + 1 ..]; + // Strip pnpm peer-dep suffix `_`. + if (std.mem.indexOfScalar(u8, version, '(')) |paren| version = version[0..paren]; + if (std.mem.indexOfScalar(u8, version, '_')) |under| version = version[0..under]; + + const payload = try buildInventoryEvent( + self.allocator, + "npm", + package_name, + version, + "high", + source_path, + "pnpm-lockfile"); + try payloads.append(payload); + } + } + + fn parsePypiMetadata(self: *Scanner, base: []const u8, dir: std.Io.Dir, name: []const u8, payloads: *std.array_list.Managed([]u8)) !void { + const body = readFile(self.allocator, dir, name, 1 * 1024 * 1024) catch return; + defer self.allocator.free(body); + + const source_path = try std.fs.path.join(self.allocator, &.{ base, name }); + defer self.allocator.free(source_path); + + var package_name: ?[]const u8 = null; + var version: ?[]const u8 = null; + var lines = std.mem.splitScalar(u8, body, '\n'); + while (lines.next()) |line_raw| { + const line = std.mem.trimEnd(u8, line_raw, "\r"); + if (line.len == 0) break; // Headers end at the first blank line. + if (std.mem.startsWith(u8, line, "Name: ")) { + package_name = line[6..]; + } else if (std.mem.startsWith(u8, line, "Version: ")) { + version = line[9..]; + } + if (package_name != null and version != null) break; + } + + if (package_name == null or version == null) return; + const payload = try buildInventoryEvent( + self.allocator, + "pypi", + package_name.?, + version.?, + "high", + source_path, + "pypi-dist-info"); + try payloads.append(payload); + } + + /// yarn.lock parser handling both Classic (v1, `version "X.Y.Z"`) and + /// Berry (v2+, `version: X.Y.Z`) syntax. Lockfile structure is a series + /// of blocks: a non-indented header line ending with ':' followed by + /// indented `version` / `resolution` / etc. We just need the name from + /// the header and the version from the indented block. + fn parseYarnLock(self: *Scanner, base: []const u8, dir: std.Io.Dir, name: []const u8, payloads: *std.array_list.Managed([]u8)) !void { + const body = readFile(self.allocator, dir, name, 32 * 1024 * 1024) catch return; + defer self.allocator.free(body); + + const source_path = try std.fs.path.join(self.allocator, &.{ base, name }); + defer self.allocator.free(source_path); + + var current_name: ?[]const u8 = null; + var lines = std.mem.splitScalar(u8, body, '\n'); + while (lines.next()) |line_raw| { + const line = std.mem.trimEnd(u8, line_raw, " \r\t"); + if (line.len == 0) { + current_name = null; + continue; + } + // Skip comments + the Berry __metadata block. + if (line[0] == '#') continue; + if (std.mem.startsWith(u8, line, "__metadata")) { + current_name = null; + continue; + } + + // Header lines aren't indented and end with ':' (Classic + Berry both). + if (!std.ascii.isWhitespace(line[0]) and line[line.len - 1] == ':') { + current_name = extractYarnPackageName(line[0 .. line.len - 1]); + continue; + } + + // Indented lines inside a block. Look for "version". + if (current_name) |pkg_name| { + const trimmed = std.mem.trimStart(u8, line, " \t"); + if (std.mem.startsWith(u8, trimmed, "version ")) { + // Classic: `version "7.23.5"` + const after = trimmed["version ".len..]; + if (extractQuotedValue(after)) |version| { + const payload = try buildInventoryEvent( + self.allocator, "npm", pkg_name, version, "high", source_path, "yarn-lockfile"); + try payloads.append(payload); + current_name = null; + } + } else if (std.mem.startsWith(u8, trimmed, "version:")) { + // Berry: `version: 7.23.5` + const value = std.mem.trim(u8, trimmed["version:".len..], " \t\""); + if (value.len > 0) { + const payload = try buildInventoryEvent( + self.allocator, "npm", pkg_name, value, "high", source_path, "yarn-lockfile"); + try payloads.append(payload); + current_name = null; + } + } + } + } + } + + /// bun.lock is JSONC (JSON with optional comments). The `packages` object + /// maps spec -> [resolved_name@version, registry, metadata, integrity]. + /// We grab the resolved name@version from the first array element. + fn parseBunLock(self: *Scanner, base: []const u8, dir: std.Io.Dir, name: []const u8, payloads: *std.array_list.Managed([]u8)) !void { + const body = readFile(self.allocator, dir, name, 16 * 1024 * 1024) catch return; + defer self.allocator.free(body); + + const source_path = try std.fs.path.join(self.allocator, &.{ base, name }); + defer self.allocator.free(source_path); + + // Try strict JSON first (most bun.lock files have no comments). + var parsed = std.json.parseFromSlice(std.json.Value, self.allocator, body, .{}) catch blk: { + // Fallback: strip `// ...` line comments and retry. + const stripped = stripLineComments(self.allocator, body) catch return; + defer self.allocator.free(stripped); + break :blk std.json.parseFromSlice(std.json.Value, self.allocator, stripped, .{}) catch return; + }; + defer parsed.deinit(); + + if (parsed.value != .object) return; + const packages = parsed.value.object.get("packages") orelse return; + if (packages != .object) return; + + var it = packages.object.iterator(); + while (it.next()) |kv| { + if (kv.value_ptr.* != .array) continue; + const arr = kv.value_ptr.*.array.items; + if (arr.len == 0 or arr[0] != .string) continue; + const spec = arr[0].string; + // spec is `@` — find the last '@' that isn't at pos 0 (scoped pkgs). + const at = std.mem.lastIndexOfScalar(u8, spec, '@') orelse continue; + if (at == 0) continue; + const pkg_name = spec[0..at]; + const version = spec[at + 1 ..]; + if (pkg_name.len == 0 or version.len == 0) continue; + + const payload = try buildInventoryEvent( + self.allocator, "npm", pkg_name, version, "high", source_path, "bun-lockfile"); + try payloads.append(payload); + } + } + + /// go.sum lists every module version (direct + transitive) with checksum + /// pairs (one `/go.mod` line, one zip-content line). We dedupe to a single + /// event per (module, version). + fn parseGoSum(self: *Scanner, base: []const u8, dir: std.Io.Dir, name: []const u8, payloads: *std.array_list.Managed([]u8)) !void { + const body = readFile(self.allocator, dir, name, 16 * 1024 * 1024) catch return; + defer self.allocator.free(body); + + const source_path = try std.fs.path.join(self.allocator, &.{ base, name }); + defer self.allocator.free(source_path); + + var seen = std.StringHashMap(void).init(self.allocator); + defer { + var it = seen.iterator(); + while (it.next()) |kv| self.allocator.free(kv.key_ptr.*); + seen.deinit(); + } + + var lines = std.mem.splitScalar(u8, body, '\n'); + while (lines.next()) |line_raw| { + const line = std.mem.trim(u8, line_raw, " \r\t"); + if (line.len == 0) continue; + + // Format: ` [/go.mod] ` + var fields = std.mem.tokenizeAny(u8, line, " \t"); + const module = fields.next() orelse continue; + const raw_version = fields.next() orelse continue; + // Trim the `/go.mod` suffix when present so both lines collapse to one event. + const version = if (std.mem.endsWith(u8, raw_version, "/go.mod")) + raw_version[0 .. raw_version.len - "/go.mod".len] + else + raw_version; + + const key = try std.fmt.allocPrint(self.allocator, "{s}@{s}", .{ module, version }); + if (seen.contains(key)) { + self.allocator.free(key); + continue; + } + try seen.put(key, {}); + + const payload = try buildInventoryEvent( + self.allocator, "go", module, version, "high", source_path, "go-sum"); + try payloads.append(payload); + } + } + + /// go.mod `require` blocks list direct dependencies only — useful as a + /// supplement to go.sum (e.g. when go.sum hasn't been committed). + fn parseGoMod(self: *Scanner, base: []const u8, dir: std.Io.Dir, name: []const u8, payloads: *std.array_list.Managed([]u8)) !void { + const body = readFile(self.allocator, dir, name, 1 * 1024 * 1024) catch return; + defer self.allocator.free(body); + + const source_path = try std.fs.path.join(self.allocator, &.{ base, name }); + defer self.allocator.free(source_path); + + var in_block = false; + var lines = std.mem.splitScalar(u8, body, '\n'); + while (lines.next()) |line_raw| { + // Strip line comments first so `// indirect` doesn't confuse parsing. + const without_comment = if (std.mem.indexOf(u8, line_raw, "//")) |idx| + line_raw[0..idx] + else + line_raw; + const line = std.mem.trim(u8, without_comment, " \r\t"); + if (line.len == 0) continue; + + if (std.mem.eql(u8, line, "require (")) { in_block = true; continue; } + if (in_block and std.mem.eql(u8, line, ")")) { in_block = false; continue; } + + // Match `[require] ` on a single line, or inside a block. + var rest = line; + if (std.mem.startsWith(u8, rest, "require ")) { + rest = rest["require ".len..]; + } else if (!in_block) { + continue; + } + + var fields = std.mem.tokenizeAny(u8, rest, " \t"); + const module = fields.next() orelse continue; + const version = fields.next() orelse continue; + if (!std.mem.startsWith(u8, version, "v")) continue; + + const payload = try buildInventoryEvent( + self.allocator, "go", module, version, "medium", source_path, "go-mod"); + try payloads.append(payload); + } + } + + /// Gemfile.lock has a `GEM` section with a `specs:` subsection. Each + /// gem appears at indent 4 as ` gem-name (1.2.3)`; transitive deps + /// appear at indent 6 and we ignore those (their own entries already + /// surface at indent 4). + fn parseGemfileLock(self: *Scanner, base: []const u8, dir: std.Io.Dir, name: []const u8, payloads: *std.array_list.Managed([]u8)) !void { + const body = readFile(self.allocator, dir, name, 8 * 1024 * 1024) catch return; + defer self.allocator.free(body); + + const source_path = try std.fs.path.join(self.allocator, &.{ base, name }); + defer self.allocator.free(source_path); + + var in_specs = false; + var lines = std.mem.splitScalar(u8, body, '\n'); + while (lines.next()) |line_raw| { + const line = std.mem.trimEnd(u8, line_raw, " \r\t"); + if (line.len == 0) continue; + + if (std.mem.startsWith(u8, line, " specs:")) { in_specs = true; continue; } + // A non-indented line ends the GEM block. + if (line[0] != ' ') { in_specs = false; continue; } + if (!in_specs) continue; + + // Gem entries are at exactly 4 leading spaces; deps are at 6+. + if (line.len < 4 or !std.mem.startsWith(u8, line, " ") or std.mem.startsWith(u8, line, " ")) continue; + const entry = line[4..]; + // Format: `gem-name (1.2.3)` — sometimes `gem-name (1.2.3-platform)`. + const lparen = std.mem.indexOfScalar(u8, entry, '(') orelse continue; + const rparen = std.mem.indexOfScalar(u8, entry, ')') orelse continue; + if (rparen <= lparen + 1) continue; + const pkg_name = std.mem.trim(u8, entry[0..lparen], " \t"); + const version = std.mem.trim(u8, entry[lparen + 1 .. rparen], " \t"); + if (pkg_name.len == 0 or version.len == 0) continue; + + const payload = try buildInventoryEvent( + self.allocator, "rubygems", pkg_name, version, "high", source_path, "gemfile-lock"); + try payloads.append(payload); + } + } + + /// Each installed gem has a stub gemspec under `specifications/`. The + /// stub line is canonical: `# stub: ruby `. + /// Falls back to scanning `s.name` / `s.version` if the stub is missing. + fn parseGemspec(self: *Scanner, base: []const u8, dir: std.Io.Dir, name: []const u8, payloads: *std.array_list.Managed([]u8)) !void { + const body = readFile(self.allocator, dir, name, 1 * 1024 * 1024) catch return; + defer self.allocator.free(body); + + const source_path = try std.fs.path.join(self.allocator, &.{ base, name }); + defer self.allocator.free(source_path); + + var pkg_name: ?[]const u8 = null; + var version: ?[]const u8 = null; + + var lines = std.mem.splitScalar(u8, body, '\n'); + while (lines.next()) |line_raw| { + const line = std.mem.trim(u8, line_raw, " \r\t"); + if (std.mem.startsWith(u8, line, "# stub: ")) { + var fields = std.mem.tokenizeAny(u8, line["# stub: ".len..], " \t"); + pkg_name = fields.next(); + version = fields.next(); + break; + } + // Fallback path for non-stub gemspecs. + if (pkg_name == null and std.mem.indexOf(u8, line, "s.name") != null) { + if (extractQuotedValue(line)) |v| pkg_name = v; + } else if (version == null and std.mem.indexOf(u8, line, "s.version") != null) { + if (extractQuotedValue(line)) |v| version = v; + } + if (pkg_name != null and version != null) break; + } + + if (pkg_name == null or version == null) return; + const payload = try buildInventoryEvent( + self.allocator, "rubygems", pkg_name.?, version.?, "high", source_path, "gemspec"); + try payloads.append(payload); + } + + /// composer.lock and vendor/composer/installed.json have the same shape: + /// a top-level object with `packages: [{name, version, ...}, ...]`. The + /// `installed.json` format from Composer 2 also wraps under a top-level + /// `packages` key, so the same parser handles both. + fn parseComposerPackages( + self: *Scanner, + base: []const u8, + dir: std.Io.Dir, + name: []const u8, + payloads: *std.array_list.Managed([]u8), + source_type: []const u8, + ) !void { + const body = readFile(self.allocator, dir, name, 16 * 1024 * 1024) catch return; + defer self.allocator.free(body); + + const source_path = try std.fs.path.join(self.allocator, &.{ base, name }); + defer self.allocator.free(source_path); + + var parsed = std.json.parseFromSlice(std.json.Value, self.allocator, body, .{}) catch return; + defer parsed.deinit(); + if (parsed.value != .object) return; + + // Emit `packages` then `packages-dev` (composer.lock) — installed.json + // only has `packages`, the other access is a no-op. + const containers = [_][]const u8{ "packages", "packages-dev" }; + for (containers) |container| { + const pkgs = parsed.value.object.get(container) orelse continue; + if (pkgs != .array) continue; + for (pkgs.array.items) |entry| { + if (entry != .object) continue; + const n = entry.object.get("name") orelse continue; + const v = entry.object.get("version") orelse continue; + if (n != .string or v != .string) continue; + const payload = try buildInventoryEvent( + self.allocator, "packagist", n.string, v.string, "high", source_path, source_type); + try payloads.append(payload); + } + } + } +}; + +fn skipDirent(name: []const u8) bool { + if (name.len == 0) return true; + // Skip large noisy roots that won't contain useful package metadata + // (lockfiles + manifests live at project roots, not inside these trees) + // and would blow up the walk. + if (std.mem.eql(u8, name, ".git")) return true; + if (std.mem.eql(u8, name, ".cache")) return true; + if (std.mem.eql(u8, name, ".npm")) return true; + if (std.mem.eql(u8, name, "node_modules")) return true; // package-lock.json lives at project root, not inside + if (std.mem.eql(u8, name, "dist")) return true; + if (std.mem.eql(u8, name, "build")) return true; + if (std.mem.eql(u8, name, "target")) return true; // rust / java + if (std.mem.eql(u8, name, ".next")) return true; + if (std.mem.eql(u8, name, ".venv")) return true; + if (std.mem.eql(u8, name, "tmp")) return true; + if (std.mem.eql(u8, name, "Library")) return true; // macOS user-Library is huge. + return false; +} + +fn endsWithSegment(path: []const u8, segment: []const u8) bool { + if (path.len < segment.len) return false; + if (path.len == segment.len) return std.mem.eql(u8, path, segment); + const start = path.len - segment.len; + // Require a path separator immediately before the segment so e.g. + // "/etc/notcomposer" doesn't match "composer". + const sep_byte = path[start - 1]; + return (sep_byte == '/' or sep_byte == '\\') and std.mem.eql(u8, path[start..], segment); +} + +/// Extract the package name from a yarn.lock block header. Inputs look like: +/// "@babel/code-frame@^7.0.0" +/// "@babel/code-frame@^7.0.0", "@babel/code-frame@^7.22.5" +/// "@babel/code-frame@npm:7.23.0" (Berry) +/// lodash@^4.17.21 +/// We just grab the first comma-separated spec, strip surrounding quotes, +/// and split on the last '@' that isn't at position 0 (scoped packages). +fn extractYarnPackageName(header: []const u8) ?[]const u8 { + var first_spec = header; + if (std.mem.indexOf(u8, header, ", ")) |comma| { + first_spec = header[0..comma]; + } + first_spec = std.mem.trim(u8, first_spec, " \t\""); + if (first_spec.len == 0) return null; + const last_at = std.mem.lastIndexOfScalar(u8, first_spec, '@') orelse return null; + if (last_at == 0) return null; + const name = first_spec[0..last_at]; + if (name.len == 0) return null; + return name; +} + +/// Strip Ruby / JSONC line comments (`// ...` and `# ...` to end of line). +/// Used for parsing JSONC bun.lock when strict JSON parse fails. Returns an +/// owned slice the caller must free. +fn stripLineComments(alloc: std.mem.Allocator, src: []const u8) ![]u8 { + var out = std.array_list.Managed(u8).init(alloc); + errdefer out.deinit(); + var in_string = false; + var i: usize = 0; + while (i < src.len) { + const c = src[i]; + if (in_string) { + if (c == '\\' and i + 1 < src.len) { + try out.append(c); + try out.append(src[i + 1]); + i += 2; + continue; + } + if (c == '"') in_string = false; + try out.append(c); + i += 1; + continue; + } + if (c == '"') { + in_string = true; + try out.append(c); + i += 1; + continue; + } + if (c == '/' and i + 1 < src.len and src[i + 1] == '/') { + // Skip to end of line. + while (i < src.len and src[i] != '\n') : (i += 1) {} + continue; + } + try out.append(c); + i += 1; + } + return out.toOwnedSlice(); +} + +/// Pull the first `"..."` (or `'...'`) literal from a line. Used by the +/// yarn Classic version line (`version "X.Y.Z"`) and the gemspec fallback +/// (`s.version = "X.Y.Z".freeze`). +fn extractQuotedValue(line: []const u8) ?[]const u8 { + for ([_]u8{ '"', '\'' }) |quote| { + if (std.mem.indexOfScalar(u8, line, quote)) |start| { + if (std.mem.indexOfScalarPos(u8, line, start + 1, quote)) |end| { + if (end > start + 1) return line[start + 1 .. end]; + } + } + } + return null; +} + +fn readFile(alloc: std.mem.Allocator, dir: std.Io.Dir, name: []const u8, max_bytes: usize) ![]u8 { + var file = try dir.openFile(name, .{}); + defer file.close(); + return file.readToEndAlloc(alloc, max_bytes); +} + +fn buildInventoryEvent( + alloc: std.mem.Allocator, + ecosystem: []const u8, + name: []const u8, + version: []const u8, + confidence: []const u8, + source_path: []const u8, + source_type: []const u8, +) ![]u8 { + var out: std.Io.Writer.Allocating = .init(alloc); + errdefer out.deinit(); + const w = &out.writer; + try w.writeAll("{\"ecosystem\":"); + try std.json.Stringify.value(ecosystem, .{}, w); + try w.writeAll(",\"name\":"); + try std.json.Stringify.value(name, .{}, w); + try w.writeAll(",\"version\":"); + try std.json.Stringify.value(version, .{}, w); + try w.writeAll(",\"confidence\":"); + try std.json.Stringify.value(confidence, .{}, w); + try w.writeAll(",\"source_type\":"); + try std.json.Stringify.value(source_type, .{}, w); + try w.writeAll(",\"source_path\":"); + try std.json.Stringify.value(source_path, .{}, w); + try w.writeByte('}'); + return out.toOwnedSlice(); +} + +test "scanner module loads" { + var s = Scanner.init(std.testing.allocator); + _ = &s; +} + +test "yarn header parser handles scoped packages and Berry specs" { + try std.testing.expectEqualStrings("lodash", extractYarnPackageName("lodash@^4.17.21").?); + try std.testing.expectEqualStrings( + "@babel/code-frame", + extractYarnPackageName("\"@babel/code-frame@^7.0.0\", \"@babel/code-frame@^7.22.5\"").?, + ); + try std.testing.expectEqualStrings( + "@babel/code-frame", + extractYarnPackageName("\"@babel/code-frame@npm:7.23.0\"").?, + ); +} + +test "extractQuotedValue grabs the first string literal" { + try std.testing.expectEqualStrings("7.23.5", extractQuotedValue("version \"7.23.5\"").?); + try std.testing.expectEqualStrings("actionview", extractQuotedValue(" s.name = \"actionview\".freeze").?); +} + +test "stripLineComments preserves strings and drops comments" { + const out = try stripLineComments(std.testing.allocator, "{\"a\":1 // drop me\n,\"b\":\"// keep\"\n}"); + defer std.testing.allocator.free(out); + try std.testing.expectEqualStrings("{\"a\":1 \n,\"b\":\"// keep\"\n}", out); +} + +test "endsWithSegment respects path separators" { + try std.testing.expect(endsWithSegment("/home/user/vendor/composer", "composer")); + try std.testing.expect(!endsWithSegment("/home/user/notcomposer", "composer")); + try std.testing.expect(endsWithSegment("composer", "composer")); +} diff --git a/agent/src/collectors/mcp_config.zig b/agent/src/collectors/mcp_config.zig new file mode 100644 index 0000000..0d0c95f --- /dev/null +++ b/agent/src/collectors/mcp_config.zig @@ -0,0 +1,142 @@ +const std = @import("std"); +const builtin = @import("builtin"); +const env = @import("../env.zig"); + +/// MCP server-config scanner inspired by Perplexity's Bumblebee. +/// +/// MCP configs live in well-known JSON files (mcp.json, claude_desktop_config.json, +/// cline_mcp_settings.json, etc.) and define which MCP servers a developer's +/// LLM client has wired up. Each server entry usually carries: +/// - a command/args (stdio transport), or +/// - a URL (sse/http transport), +/// - an `env` object with provider credentials / API keys. +/// +/// Bumblebee's privacy stance, which we adopt verbatim: +/// 1. We do NOT capture `env` *values*. We may emit the *names* of declared +/// env keys because they identify the server's secret-shape, but never +/// the values themselves. +/// 2. URLs are sanitized to `scheme://host` before emission. Userinfo +/// (user:pass@), query, fragment, and path are all stripped so embedded +/// credentials cannot leak. +pub const Scanner = struct { + allocator: std.mem.Allocator, + + pub fn init(alloc: std.mem.Allocator) Scanner { + return .{ .allocator = alloc }; + } + + pub fn collectConfigs(self: *Scanner) ![][]u8 { + return self.allocator.alloc([]u8, 0); + } + + fn parseConfigFile(self: *Scanner, path: []const u8, payloads: *std.array_list.Managed([]u8)) !void { + var file = std.fs.openFileAbsolute(path, .{}) catch return; + defer file.close(); + const body = file.readToEndAlloc(self.allocator, 1 * 1024 * 1024) catch return; + defer self.allocator.free(body); + + var parsed = std.json.parseFromSlice(std.json.Value, self.allocator, body, .{}) catch return; + defer parsed.deinit(); + if (parsed.value != .object) return; + + // Most clients put servers under "mcpServers"; cline uses + // "cline_mcp_settings" → "servers"; some use "servers" directly. + const containers = [_][]const u8{ "mcpServers", "servers" }; + for (containers) |key| { + const servers = parsed.value.object.get(key) orelse continue; + if (servers != .object) continue; + var it = servers.object.iterator(); + while (it.next()) |kv| { + if (kv.value_ptr.* != .object) continue; + const payload = try self.emitServer(path, kv.key_ptr.*, kv.value_ptr.*.object); + try payloads.append(payload); + } + } + } + + fn emitServer(self: *Scanner, source_path: []const u8, server_name: []const u8, server_obj: std.json.ObjectMap) ![]u8 { + var out: std.Io.Writer.Allocating = .init(self.allocator); + errdefer out.deinit(); + const w = &out.writer; + try w.writeAll("{\"ecosystem\":\"mcp\",\"name\":"); + try std.json.Stringify.value(server_name, .{}, w); + + // Server identity. We surface enough to know what's wired up, never + // enough to repro the trust path or steal credentials. + const command = server_obj.get("command"); + const args = server_obj.get("args"); + const url = server_obj.get("url"); + + var transport: []const u8 = "unknown"; + if (command != null and command.? == .string) transport = "stdio"; + if (url != null and url.? == .string) transport = "http"; + try w.writeAll(",\"transport\":"); + try std.json.Stringify.value(transport, .{}, w); + + if (command != null and command.? == .string) { + try w.writeAll(",\"command\":"); + try std.json.Stringify.value(command.?.string, .{}, w); + } + if (args != null and args.? == .array) { + try w.writeAll(",\"arg_count\":"); + try w.print("{d}", .{args.?.array.items.len}); + } + if (url != null and url.? == .string) { + const sanitized = sanitizeUrl(self.allocator, url.?.string) catch try self.allocator.dupe(u8, ""); + defer self.allocator.free(sanitized); + try w.writeAll(",\"requested_spec\":"); + try std.json.Stringify.value(sanitized, .{}, w); + } + + // PRIVACY: declared env *keys* only — never the values. This matches + // Bumblebee's "Environment values and environment key names are never + // captured" stance for values, and we additionally elide the names by + // emitting only the *count* by default. Operators who need the names + // for inventory reconciliation can flip a future config switch. + if (server_obj.get("env")) |env_val| { + if (env_val == .object) { + try w.writeAll(",\"env_var_count\":"); + try w.print("{d}", .{env_val.object.count()}); + } + } + + // Version is unknown for MCP servers from raw configs. We use the + // hash of the source-path + name as a synthetic version so the + // PackageExposure rule shape (which expects a version) still works + // for matching. Real version tracking would require running the + // server, which violates the read-only stance. + try w.writeAll(",\"version\":\"configured\""); + try w.writeAll(",\"confidence\":\"low\""); + try w.writeAll(",\"source_path\":"); + try std.json.Stringify.value(source_path, .{}, w); + try w.writeByte('}'); + return out.toOwnedSlice(); + } +}; + +/// Strip everything from a URL except scheme://host so credentials (in +/// userinfo, query, path, fragment) never leave the host. +fn sanitizeUrl(alloc: std.mem.Allocator, raw: []const u8) ![]u8 { + const scheme_end = std.mem.indexOf(u8, raw, "://") orelse return alloc.dupe(u8, ""); + const scheme = raw[0..scheme_end]; + const authority_and_path = raw[scheme_end + 3 ..]; + // Strip path/query/fragment. + const path_idx = std.mem.indexOfAny(u8, authority_and_path, "/?#") orelse authority_and_path.len; + var authority = authority_and_path[0..path_idx]; + // Strip userinfo (everything before the first @ before any host port). + if (std.mem.indexOfScalar(u8, authority, '@')) |at_idx| { + authority = authority[at_idx + 1 ..]; + } + return std.fmt.allocPrint(alloc, "{s}://{s}", .{ scheme, authority }); +} + +test "sanitize url strips creds" { + const out = try sanitizeUrl(std.testing.allocator, "https://user:pass@mcp.example.com:8443/api?token=abc#x"); + defer std.testing.allocator.free(out); + try std.testing.expectEqualStrings("https://mcp.example.com:8443", out); +} + +test "scanner module loads" { + var s = Scanner.init(std.testing.allocator); + _ = &s; +} diff --git a/agent/src/collectors/network.zig b/agent/src/collectors/network.zig index 5aa9446..5e332bd 100644 --- a/agent/src/collectors/network.zig +++ b/agent/src/collectors/network.zig @@ -1,5 +1,6 @@ const std = @import("std"); const builtin = @import("builtin"); +const iox = @import("../io_compat.zig"); pub fn collect(alloc: std.mem.Allocator) ![]u8 { return switch (builtin.os.tag) { @@ -11,9 +12,9 @@ pub fn collect(alloc: std.mem.Allocator) ![]u8 { } fn collectLinux(alloc: std.mem.Allocator) ![]u8 { - var out = std.ArrayList(u8).init(alloc); + var out: std.Io.Writer.Allocating = .init(alloc); errdefer out.deinit(); - var w = out.writer(); + const w = &out.writer; try w.writeAll("{\"source\":\"procfs\",\"connections\":["); var first = true; @@ -55,17 +56,17 @@ fn appendProcNetRows( defer alloc.free(remote_endpoint.address); try writer.writeAll("{\"protocol\":"); - try std.json.stringify(protocol, .{}, writer); + try std.json.Stringify.value(protocol, .{}, writer); try writer.writeAll(",\"local_address\":"); - try std.json.stringify(local_endpoint.address, .{}, writer); + try std.json.Stringify.value(local_endpoint.address, .{}, writer); try writer.print(",\"local_port\":{d}", .{local_endpoint.port}); try writer.writeAll(",\"remote_address\":"); - try std.json.stringify(remote_endpoint.address, .{}, writer); + try std.json.Stringify.value(remote_endpoint.address, .{}, writer); try writer.print(",\"remote_port\":{d}", .{remote_endpoint.port}); try writer.writeAll(",\"state\":"); - try std.json.stringify(state, .{}, writer); + try std.json.Stringify.value(state, .{}, writer); try writer.writeAll(",\"raw\":"); - try std.json.stringify(line, .{}, writer); + try std.json.Stringify.value(line, .{}, writer); try writer.writeByte('}'); } } @@ -104,17 +105,17 @@ fn parseProcNetEndpoint(alloc: std.mem.Allocator, protocol: []const u8, endpoint } fn collectMacos(alloc: std.mem.Allocator) ![]u8 { - const result = try std.process.Child.run(.{ - .allocator = alloc, + const result = try std.process.run(alloc, iox.current(), .{ .argv = &.{ "lsof", "-i", "-P", "-n" }, - .max_output_bytes = 512 * 1024, + .stdout_limit = .limited(512 * 1024), + .stderr_limit = .limited(512 * 1024), }); defer alloc.free(result.stdout); defer alloc.free(result.stderr); - var out = std.ArrayList(u8).init(alloc); + var out: std.Io.Writer.Allocating = .init(alloc); errdefer out.deinit(); - var w = out.writer(); + const w = &out.writer; try w.writeAll("{\"source\":\"lsof\",\"connections\":["); var lines = std.mem.splitScalar(u8, result.stdout, '\n'); @@ -130,7 +131,7 @@ fn collectMacos(alloc: std.mem.Allocator) ![]u8 { if (!first) try w.writeByte(','); first = false; try w.writeAll("{\"raw\":"); - try std.json.stringify(line, .{}, w); + try std.json.Stringify.value(line, .{}, w); try w.writeByte('}'); } try w.writeAll("]}"); @@ -151,7 +152,7 @@ extern "iphlpapi" fn GetExtendedTcpTable( ulAf: u32, TableClass: u32, Reserved: u32, -) callconv(.C) u32; +) callconv(.c) u32; extern "iphlpapi" fn GetExtendedUdpTable( pUdpTable: ?*anyopaque, @@ -160,7 +161,7 @@ extern "iphlpapi" fn GetExtendedUdpTable( ulAf: u32, TableClass: u32, Reserved: u32, -) callconv(.C) u32; +) callconv(.c) u32; fn collectWindows(alloc: std.mem.Allocator) ![]u8 { const tcp_bytes = try tableSize(alloc, true); @@ -193,9 +194,10 @@ fn tableSize(alloc: std.mem.Allocator, comptime tcp: bool) !u32 { } fn readFileAbsoluteAlloc(alloc: std.mem.Allocator, path: []const u8, max_bytes: usize) ![]u8 { - var file = try std.fs.openFileAbsolute(path, .{}); - defer file.close(); - return file.readToEndAlloc(alloc, max_bytes); + const io = iox.current(); + var file = try std.Io.Dir.openFileAbsolute(io, path, .{}); + defer file.close(io); + return iox.readToEndAlloc(file, alloc, max_bytes); } test "network collector module loads" { diff --git a/agent/src/collectors/process.zig b/agent/src/collectors/process.zig index 7a0b2f0..121f4e8 100644 --- a/agent/src/collectors/process.zig +++ b/agent/src/collectors/process.zig @@ -20,9 +20,9 @@ pub fn collect(alloc: std.mem.Allocator) ![]u8 { alloc.free(procs); } - var out = std.ArrayList(u8).init(alloc); + var out: std.Io.Writer.Allocating = .init(alloc); errdefer out.deinit(); - var w = out.writer(); + const w = &out.writer; try w.writeAll("{\"processes\":["); for (procs, 0..) |p, i| { @@ -41,7 +41,7 @@ pub fn collect(alloc: std.mem.Allocator) ![]u8 { } fn writeJsonString(writer: anytype, s: []const u8) !void { - try std.json.stringify(s, .{}, writer); + try std.json.Stringify.value(s, .{}, writer); } test "process collect runs" { @@ -53,13 +53,13 @@ test "process collect runs" { } test "process names are json escaped" { - var out = std.ArrayList(u8).init(std.testing.allocator); + var out: std.Io.Writer.Allocating = .init(std.testing.allocator); defer out.deinit(); - try writeJsonString(out.writer(), "bad\"name\\with\nnewline"); + try writeJsonString(&out.writer, "bad\"name\\with\nnewline"); try std.testing.expectEqualStrings( "\"bad\\\"name\\\\with\\nnewline\"", - out.items, + out.written(), ); } diff --git a/agent/src/collectors/process_events.zig b/agent/src/collectors/process_events.zig new file mode 100644 index 0000000..cff5df1 --- /dev/null +++ b/agent/src/collectors/process_events.zig @@ -0,0 +1,211 @@ +const std = @import("std"); +const builtin = @import("builtin"); +const iox = @import("../io_compat.zig"); + +/// Diff-based process launch collector. Tracks the set of live PIDs across +/// invocations and emits a per-launch event for every PID that has appeared +/// since the previous tick. The image SHA-256 is computed from the resolved +/// executable path so the backend can pivot on hash. +/// +/// This is intentionally not a kernel-level exec hook (no eBPF, no ETW +/// kernel provider); we rely on /proc on Linux and shell out elsewhere. +/// A new process that starts and exits inside a single polling interval can +/// be missed — operators who need every exec should switch to kernel-grade +/// telemetry (Phase 2's eBPF / ETW path, deliberately out of scope here). +pub const Tracker = struct { + allocator: std.mem.Allocator, + seen: std.AutoHashMap(u32, void), + + pub fn init(alloc: std.mem.Allocator) Tracker { + return .{ + .allocator = alloc, + .seen = std.AutoHashMap(u32, void).init(alloc), + }; + } + + pub fn deinit(self: *Tracker) void { + self.seen.deinit(); + } + + /// Returns a list of JSON payloads, one per newly observed process. + /// Caller owns both the outer slice and each inner payload. + pub fn collectLaunches(self: *Tracker) ![][]u8 { + var payloads = std.array_list.Managed([]u8).init(self.allocator); + errdefer { + for (payloads.items) |p| self.allocator.free(p); + payloads.deinit(); + } + + switch (builtin.os.tag) { + .linux => try self.collectLinux(&payloads), + else => {}, // Win/macOS exec-event capture is deferred to kernel-level work. + } + + // GC PIDs that have gone away so the map doesn't grow unboundedly. + try self.pruneStale(); + + return payloads.toOwnedSlice(); + } + + fn collectLinux(self: *Tracker, payloads: *std.array_list.Managed([]u8)) !void { + const io = iox.current(); + var proc_dir = std.Io.Dir.openDirAbsolute(io, "/proc", .{ .iterate = true }) catch return; + defer proc_dir.close(io); + + var iter = proc_dir.iterate(); + while (try iter.next(io)) |entry| { + if (entry.kind != .directory) continue; + const pid = std.fmt.parseInt(u32, entry.name, 10) catch continue; + if (self.seen.contains(pid)) continue; + + // First sighting of this PID: emit a launch event and remember it. + try self.seen.put(pid, {}); + const payload = buildLinuxLaunchEvent(self.allocator, pid) catch continue; + try payloads.append(payload); + } + } + + fn pruneStale(self: *Tracker) !void { + var to_remove = std.array_list.Managed(u32).init(self.allocator); + defer to_remove.deinit(); + + var it = self.seen.iterator(); + while (it.next()) |kv| { + const pid = kv.key_ptr.*; + if (!pidExists(pid)) try to_remove.append(pid); + } + for (to_remove.items) |pid| _ = self.seen.remove(pid); + } +}; + +fn pidExists(pid: u32) bool { + if (builtin.os.tag != .linux) return true; + const io = iox.current(); + var buf: [64]u8 = undefined; + const path = std.fmt.bufPrint(&buf, "/proc/{d}", .{pid}) catch return false; + var dir = std.Io.Dir.openDirAbsolute(io, path, .{}) catch return false; + dir.close(io); + return true; +} + +fn buildLinuxLaunchEvent(alloc: std.mem.Allocator, pid: u32) ![]u8 { + const name = readProcText(alloc, pid, "comm") catch try alloc.dupe(u8, "unknown"); + defer alloc.free(name); + const trimmed_name = std.mem.trimEnd(u8, name, "\r\n"); + + const command_line = readCommandLine(alloc, pid) catch try alloc.dupe(u8, trimmed_name); + defer alloc.free(command_line); + + const exe_path = readExeLink(alloc, pid) catch try alloc.dupe(u8, ""); + defer alloc.free(exe_path); + + const ppid = readParentPid(alloc, pid) catch 0; + + var digest = std.mem.zeroes([32]u8); + var hashed = false; + if (exe_path.len > 0) { + if (hashImage(exe_path)) |d| { + digest = d; + hashed = true; + } else |_| {} + } + + const uid = readUid(alloc, pid) catch 0; + + var out: std.Io.Writer.Allocating = .init(alloc); + errdefer out.deinit(); + const w = &out.writer; + + try w.print("{{\"pid\":{d},\"ppid\":{d},\"uid\":{d},\"name\":", .{ pid, ppid, uid }); + try std.json.Stringify.value(trimmed_name, .{}, w); + try w.writeAll(",\"command_line\":"); + try std.json.Stringify.value(command_line, .{}, w); + try w.writeAll(",\"image_path\":"); + try std.json.Stringify.value(exe_path, .{}, w); + if (hashed) { + const hex = std.fmt.bytesToHex(digest, .lower); + try w.print(",\"image_sha256\":\"{s}\"", .{hex}); + } else { + try w.writeAll(",\"image_sha256\":null"); + } + try w.writeAll(",\"signature\":{\"trusted\":false,\"signer\":null}}"); + return out.toOwnedSlice(); +} + +fn readProcText(alloc: std.mem.Allocator, pid: u32, name: []const u8) ![]u8 { + const path = try std.fmt.allocPrint(alloc, "/proc/{d}/{s}", .{ pid, name }); + defer alloc.free(path); + const io = iox.current(); + var file = try std.Io.Dir.openFileAbsolute(io, path, .{}); + defer file.close(io); + return iox.readToEndAlloc(file, alloc, 16 * 1024); +} + +fn readCommandLine(alloc: std.mem.Allocator, pid: u32) ![]u8 { + const raw = try readProcText(alloc, pid, "cmdline"); + defer alloc.free(raw); + const owned = try alloc.dupe(u8, raw); + for (owned) |*ch| { + if (ch.* == 0) ch.* = ' '; + } + const trimmed = std.mem.trimEnd(u8, owned, " "); + if (trimmed.len == owned.len) return owned; + const compact = try alloc.dupe(u8, trimmed); + alloc.free(owned); + return compact; +} + +fn readParentPid(alloc: std.mem.Allocator, pid: u32) !u32 { + const stat = try readProcText(alloc, pid, "stat"); + defer alloc.free(stat); + const close = std.mem.lastIndexOfScalar(u8, stat, ')') orelse return error.BadProcStat; + var fields = std.mem.tokenizeAny(u8, stat[close + 1 ..], " \t\r\n"); + _ = fields.next() orelse return error.BadProcStat; + const ppid = fields.next() orelse return error.BadProcStat; + return std.fmt.parseInt(u32, ppid, 10); +} + +fn readUid(alloc: std.mem.Allocator, pid: u32) !u32 { + const status = try readProcText(alloc, pid, "status"); + defer alloc.free(status); + var lines = std.mem.splitScalar(u8, status, '\n'); + while (lines.next()) |line| { + if (std.mem.startsWith(u8, line, "Uid:")) { + var fields = std.mem.tokenizeAny(u8, line[4..], " \t"); + const raw = fields.next() orelse return 0; + return std.fmt.parseInt(u32, raw, 10) catch 0; + } + } + return 0; +} + +fn readExeLink(alloc: std.mem.Allocator, pid: u32) ![]u8 { + const link_path = try std.fmt.allocPrint(alloc, "/proc/{d}/exe", .{pid}); + defer alloc.free(link_path); + var buf: [4096]u8 = undefined; + const len = std.Io.Dir.readLinkAbsolute(iox.current(), link_path, &buf) catch return alloc.dupe(u8, ""); + return alloc.dupe(u8, buf[0..len]); +} + +fn hashImage(path: []const u8) ![32]u8 { + const io = iox.current(); + var file = try std.Io.Dir.openFileAbsolute(io, path, .{}); + defer file.close(io); + var sha = std.crypto.hash.sha2.Sha256.init(.{}); + var buf: [16 * 1024]u8 = undefined; + var offset: u64 = 0; + while (true) { + const n = try file.readPositionalAll(io, &buf, offset); + if (n == 0) break; + sha.update(buf[0..n]); + offset += n; + } + var digest: [32]u8 = undefined; + sha.final(&digest); + return digest; +} + +test "tracker module loads" { + var t = Tracker.init(std.testing.allocator); + defer t.deinit(); +} diff --git a/agent/src/collectors/system.zig b/agent/src/collectors/system.zig index b924ee9..c8a62f4 100644 --- a/agent/src/collectors/system.zig +++ b/agent/src/collectors/system.zig @@ -1,5 +1,6 @@ const std = @import("std"); const builtin = @import("builtin"); +const iox = @import("../io_compat.zig"); pub fn collect(alloc: std.mem.Allocator) ![]u8 { return switch (builtin.os.tag) { @@ -24,18 +25,18 @@ fn collectMacos(alloc: std.mem.Allocator) ![]u8 { const brand = try sysctlString(alloc, "machdep.cpu.brand_string"); defer alloc.free(brand); - var out = std.ArrayList(u8).init(alloc); + var out: std.Io.Writer.Allocating = .init(alloc); errdefer out.deinit(); - var w = out.writer(); + const w = &out.writer; try w.writeAll("{\"platform\":\"macos\",\"hostname\":"); - try std.json.stringify(std.mem.sliceTo(&uts.nodename, 0), .{}, w); + try std.json.Stringify.value(std.mem.sliceTo(&uts.nodename, 0), .{}, w); try w.writeAll(",\"kernel\":"); - try std.json.stringify(std.mem.sliceTo(&uts.release, 0), .{}, w); + try std.json.Stringify.value(std.mem.sliceTo(&uts.release, 0), .{}, w); try w.writeAll(",\"architecture\":"); - try std.json.stringify(std.mem.sliceTo(&uts.machine, 0), .{}, w); + try std.json.Stringify.value(std.mem.sliceTo(&uts.machine, 0), .{}, w); try w.print(",\"memory_bytes\":{d},\"cpu_count\":{d},\"cpu_brand\":", .{ mem_bytes, cpu_count }); - try std.json.stringify(brand, .{}, w); + try std.json.Stringify.value(brand, .{}, w); try w.writeByte('}'); return out.toOwnedSlice(); @@ -77,16 +78,16 @@ fn collectLinux(alloc: std.mem.Allocator) ![]u8 { const mem_total_kb = readMemTotalKb(alloc) catch 0; const cpu_count = std.Thread.getCpuCount() catch 0; - var out = std.ArrayList(u8).init(alloc); + var out: std.Io.Writer.Allocating = .init(alloc); errdefer out.deinit(); - var w = out.writer(); + const w = &out.writer; try w.writeAll("{\"platform\":\"linux\",\"hostname\":"); - try std.json.stringify(std.mem.trimRight(u8, hostname_raw, "\r\n"), .{}, w); + try std.json.Stringify.value(std.mem.trimEnd(u8, hostname_raw, "\r\n"), .{}, w); try w.writeAll(",\"kernel\":"); - try std.json.stringify(std.mem.trimRight(u8, kernel_raw, "\r\n"), .{}, w); + try std.json.Stringify.value(std.mem.trimEnd(u8, kernel_raw, "\r\n"), .{}, w); try w.writeAll(",\"architecture\":"); - try std.json.stringify(@tagName(builtin.target.cpu.arch), .{}, w); + try std.json.Stringify.value(@tagName(builtin.target.cpu.arch), .{}, w); try w.print(",\"memory_bytes\":{d},\"cpu_count\":{d}}}", .{ mem_total_kb * 1024, cpu_count, @@ -111,9 +112,10 @@ fn readMemTotalKb(alloc: std.mem.Allocator) !u64 { } fn readFileAbsoluteAlloc(alloc: std.mem.Allocator, path: []const u8, max_bytes: usize) ![]u8 { - var file = try std.fs.openFileAbsolute(path, .{}); - defer file.close(); - return file.readToEndAlloc(alloc, max_bytes); + const io = iox.current(); + var file = try std.Io.Dir.openFileAbsolute(io, path, .{}); + defer file.close(io); + return iox.readToEndAlloc(file, alloc, max_bytes); } const COMPUTER_NAME_FORMAT = enum(u32) { @@ -141,9 +143,9 @@ const MEMORYSTATUSEX = extern struct { ullAvailExtendedVirtual: u64, }; -extern "kernel32" fn GetComputerNameExW(NameType: COMPUTER_NAME_FORMAT, lpBuffer: ?[*]u16, nSize: *u32) callconv(.C) i32; -extern "kernel32" fn GlobalMemoryStatusEx(lpBuffer: *MEMORYSTATUSEX) callconv(.C) i32; -extern "ntdll" fn RtlGetVersion(lpVersionInformation: *RTL_OSVERSIONINFOW) callconv(.C) i32; +extern "kernel32" fn GetComputerNameExW(NameType: COMPUTER_NAME_FORMAT, lpBuffer: ?[*]u16, nSize: *u32) callconv(.c) i32; +extern "kernel32" fn GlobalMemoryStatusEx(lpBuffer: *MEMORYSTATUSEX) callconv(.c) i32; +extern "ntdll" fn RtlGetVersion(lpVersionInformation: *RTL_OSVERSIONINFOW) callconv(.c) i32; fn collectWindows(alloc: std.mem.Allocator) ![]u8 { var name_len: u32 = 0; @@ -162,12 +164,12 @@ fn collectWindows(alloc: std.mem.Allocator) ![]u8 { mem.dwLength = @sizeOf(MEMORYSTATUSEX); if (GlobalMemoryStatusEx(&mem) == 0) return error.MemoryStatusFailed; - var out = std.ArrayList(u8).init(alloc); + var out: std.Io.Writer.Allocating = .init(alloc); errdefer out.deinit(); - var w = out.writer(); + const w = &out.writer; try w.writeAll("{\"platform\":\"windows\",\"hostname\":"); - try std.json.stringify(hostname, .{}, w); + try std.json.Stringify.value(hostname, .{}, w); try w.print( ",\"major\":{d},\"minor\":{d},\"build\":{d},\"memory_bytes\":{d}}}", .{ version.dwMajorVersion, version.dwMinorVersion, version.dwBuildNumber, mem.ullTotalPhys }, diff --git a/agent/src/collectors/users.zig b/agent/src/collectors/users.zig index 9255e3a..15c3531 100644 --- a/agent/src/collectors/users.zig +++ b/agent/src/collectors/users.zig @@ -1,5 +1,6 @@ const std = @import("std"); const builtin = @import("builtin"); +const iox = @import("../io_compat.zig"); pub fn collect(alloc: std.mem.Allocator) ![]u8 { return switch (builtin.os.tag) { @@ -15,9 +16,9 @@ const c = if (builtin.os.tag == .macos) @cImport({ }) else struct {}; fn collectMacos(alloc: std.mem.Allocator) ![]u8 { - var out = std.ArrayList(u8).init(alloc); + var out: std.Io.Writer.Allocating = .init(alloc); errdefer out.deinit(); - var w = out.writer(); + const w = &out.writer; try w.writeAll("{\"source\":\"utmpx\",\"sessions\":["); c.setutxent(); @@ -31,9 +32,9 @@ fn collectMacos(alloc: std.mem.Allocator) ![]u8 { first = false; try w.writeAll("{\"user\":"); - try std.json.stringify(std.mem.sliceTo(&session.ut_user, 0), .{}, w); + try std.json.Stringify.value(std.mem.sliceTo(&session.ut_user, 0), .{}, w); try w.writeAll(",\"line\":"); - try std.json.stringify(std.mem.sliceTo(&session.ut_line, 0), .{}, w); + try std.json.Stringify.value(std.mem.sliceTo(&session.ut_line, 0), .{}, w); try w.print(",\"pid\":{d}}}", .{session.ut_pid}); } try w.writeAll("]}"); @@ -42,19 +43,19 @@ fn collectMacos(alloc: std.mem.Allocator) ![]u8 { } fn collectLinux(alloc: std.mem.Allocator) ![]u8 { - const result = std.process.Child.run(.{ - .allocator = alloc, + const result = std.process.run(alloc, iox.current(), .{ .argv = &.{ "who" }, - .max_output_bytes = 128 * 1024, + .stdout_limit = .limited(128 * 1024), + .stderr_limit = .limited(128 * 1024), }) catch { return alloc.dupe(u8, "{\"source\":\"who\",\"sessions\":[]}"); }; defer alloc.free(result.stdout); defer alloc.free(result.stderr); - var out = std.ArrayList(u8).init(alloc); + var out: std.Io.Writer.Allocating = .init(alloc); errdefer out.deinit(); - var w = out.writer(); + const w = &out.writer; try w.writeAll("{\"source\":\"who\",\"sessions\":["); var first = true; @@ -68,11 +69,11 @@ fn collectLinux(alloc: std.mem.Allocator) ![]u8 { if (!first) try w.writeByte(','); first = false; try w.writeAll("{\"user\":"); - try std.json.stringify(user, .{}, w); + try std.json.Stringify.value(user, .{}, w); try w.writeAll(",\"line\":"); - try std.json.stringify(tty, .{}, w); + try std.json.Stringify.value(tty, .{}, w); try w.writeAll(",\"raw\":"); - try std.json.stringify(line, .{}, w); + try std.json.Stringify.value(line, .{}, w); try w.writeByte('}'); } try w.writeAll("]}"); @@ -95,9 +96,9 @@ extern "wtsapi32" fn WTSEnumerateSessionsW( Version: u32, ppSessionInfo: *?[*]WTS_SESSION_INFOW, pCount: *u32, -) callconv(.C) i32; +) callconv(.c) i32; -extern "wtsapi32" fn WTSFreeMemory(pMemory: ?*anyopaque) callconv(.C) void; +extern "wtsapi32" fn WTSFreeMemory(pMemory: ?*anyopaque) callconv(.c) void; fn collectWindows(alloc: std.mem.Allocator) ![]u8 { var sessions_ptr: ?[*]WTS_SESSION_INFOW = null; @@ -107,9 +108,9 @@ fn collectWindows(alloc: std.mem.Allocator) ![]u8 { } defer WTSFreeMemory(@ptrCast(sessions_ptr)); - var out = std.ArrayList(u8).init(alloc); + var out: std.Io.Writer.Allocating = .init(alloc); errdefer out.deinit(); - var w = out.writer(); + const w = &out.writer; try w.writeAll("{\"source\":\"wts\",\"sessions\":["); var first = true; @@ -126,7 +127,7 @@ fn collectWindows(alloc: std.mem.Allocator) ![]u8 { defer alloc.free(name); try w.print("{{\"session_id\":{d},\"station\":", .{session.SessionId}); - try std.json.stringify(name, .{}, w); + try std.json.Stringify.value(name, .{}, w); try w.writeByte('}'); } try w.writeAll("]}"); diff --git a/agent/src/config.zig b/agent/src/config.zig index 41c2022..e397b6e 100644 --- a/agent/src/config.zig +++ b/agent/src/config.zig @@ -1,5 +1,7 @@ const std = @import("std"); const builtin = @import("builtin"); +const env = @import("env.zig"); +const iox = @import("io_compat.zig"); pub const Config = struct { allocator: std.mem.Allocator, @@ -9,10 +11,14 @@ pub const Config = struct { agent_jwt: ?[]u8 = null, heartbeat_interval_seconds: u32 = 60, process_interval_seconds: u32 = 30, + process_events_interval_seconds: u32 = 5, network_interval_seconds: u32 = 30, users_interval_seconds: u32 = 300, system_interval_seconds: u32 = 3600, fim_interval_seconds: u32 = 300, + fs_events_interval_seconds: u32 = 5, + dns_interval_seconds: u32 = 30, + supply_chain_interval_seconds: u32 = 21600, max_in_memory_events: usize = 1000, fim_paths: [][]u8 = &.{}, spill_path: []u8, @@ -33,7 +39,7 @@ pub const Config = struct { /// Resolve the platform-default config directory. fn defaultConfigPath(alloc: std.mem.Allocator) ![]u8 { if (builtin.os.tag == .windows) { - const programdata = std.process.getEnvVarOwned(alloc, "PROGRAMDATA") catch + const programdata = env.getEnvVarOwned(alloc, "PROGRAMDATA") catch try alloc.dupe(u8, "C:\\ProgramData"); defer alloc.free(programdata); return std.fmt.allocPrint(alloc, "{s}\\Tawny\\config.toml", .{programdata}); @@ -46,7 +52,7 @@ fn defaultConfigPath(alloc: std.mem.Allocator) ![]u8 { /// Read TOML-ish config. Trivial line-based parser — good enough for MVP. pub fn load(alloc: std.mem.Allocator) !Config { - const env_path = std.process.getEnvVarOwned(alloc, "TAWNY_CONFIG") catch null; + const env_path = env.getEnvVarOwned(alloc, "TAWNY_CONFIG") catch null; const path: []u8 = if (env_path) |p| p else try defaultConfigPath(alloc); var cfg = Config{ @@ -56,13 +62,14 @@ pub fn load(alloc: std.mem.Allocator) !Config { .config_path = path, }; - const file = std.fs.cwd().openFile(path, .{}) catch { + const io = iox.current(); + const file = std.Io.Dir.cwd().openFile(io, path, .{}) catch { // First run: emit a default config alongside the binary. return cfg; }; - defer file.close(); + defer file.close(io); - const raw = try file.readToEndAlloc(alloc, 64 * 1024); + const raw = try iox.readToEndAlloc(file, alloc, 64 * 1024); defer alloc.free(raw); var line_iter = std.mem.splitScalar(u8, raw, '\n'); @@ -95,6 +102,14 @@ pub fn load(alloc: std.mem.Allocator) !Config { cfg.system_interval_seconds = try std.fmt.parseInt(u32, val, 10); } else if (std.mem.eql(u8, key, "fim_interval_seconds")) { cfg.fim_interval_seconds = try std.fmt.parseInt(u32, val, 10); + } else if (std.mem.eql(u8, key, "process_events_interval_seconds")) { + cfg.process_events_interval_seconds = try std.fmt.parseInt(u32, val, 10); + } else if (std.mem.eql(u8, key, "fs_events_interval_seconds")) { + cfg.fs_events_interval_seconds = try std.fmt.parseInt(u32, val, 10); + } else if (std.mem.eql(u8, key, "dns_interval_seconds")) { + cfg.dns_interval_seconds = try std.fmt.parseInt(u32, val, 10); + } else if (std.mem.eql(u8, key, "supply_chain_interval_seconds")) { + cfg.supply_chain_interval_seconds = try std.fmt.parseInt(u32, val, 10); } else if (std.mem.eql(u8, key, "max_in_memory_events")) { cfg.max_in_memory_events = try std.fmt.parseInt(usize, val, 10); } else if (std.mem.eql(u8, key, "spill_path")) { @@ -113,15 +128,18 @@ pub fn load(alloc: std.mem.Allocator) !Config { pub fn save(cfg: *const Config) !void { // Best-effort write; create parent dirs as needed. const dir = std.fs.path.dirname(cfg.config_path) orelse "."; - std.fs.cwd().makePath(dir) catch {}; + const io = iox.current(); + std.Io.Dir.cwd().createDirPath(io, dir) catch {}; const tmp_path = try std.fmt.allocPrint(cfg.allocator, "{s}.tmp", .{cfg.config_path}); defer cfg.allocator.free(tmp_path); { - var file = try std.fs.cwd().createFile(tmp_path, .{ .truncate = true }); - defer file.close(); - var w = file.writer(); + var file = try std.Io.Dir.cwd().createFile(io, tmp_path, .{ .truncate = true }); + defer file.close(io); + var writer_buffer: [4096]u8 = undefined; + var file_writer = file.writer(io, &writer_buffer); + const w = &file_writer.interface; try w.print("[backend]\nurl = \"{s}\"\n\n", .{cfg.backend_url}); if (cfg.agent_id) |id| try w.print("agent_id = \"{s}\"\n", .{id}); @@ -132,40 +150,43 @@ pub fn save(cfg: *const Config) !void { \\[collection] \\heartbeat_interval_seconds = {d} \\process_interval_seconds = {d} + \\process_events_interval_seconds = {d} \\network_interval_seconds = {d} \\users_interval_seconds = {d} \\system_interval_seconds = {d} \\fim_interval_seconds = {d} + \\fs_events_interval_seconds = {d} + \\dns_interval_seconds = {d} + \\supply_chain_interval_seconds = {d} \\max_in_memory_events = {d} \\spill_path = , .{ cfg.heartbeat_interval_seconds, cfg.process_interval_seconds, + cfg.process_events_interval_seconds, cfg.network_interval_seconds, cfg.users_interval_seconds, cfg.system_interval_seconds, cfg.fim_interval_seconds, + cfg.fs_events_interval_seconds, + cfg.dns_interval_seconds, + cfg.supply_chain_interval_seconds, cfg.max_in_memory_events, }); try w.writeByte(' '); - try std.json.stringify(cfg.spill_path, .{}, w); + try std.json.Stringify.value(cfg.spill_path, .{}, w); try w.writeAll("\nfim_paths = ["); for (cfg.fim_paths, 0..) |path, i| { if (i > 0) try w.writeAll(", "); - try std.json.stringify(path, .{}, w); + try std.json.Stringify.value(path, .{}, w); } try w.writeAll("]\n"); - try file.sync(); + try w.flush(); + try file.sync(io); } - std.fs.cwd().rename(tmp_path, cfg.config_path) catch |err| switch (err) { - error.PathAlreadyExists => { - std.fs.cwd().deleteFile(cfg.config_path) catch {}; - try std.fs.cwd().rename(tmp_path, cfg.config_path); - }, - else => return err, - }; + try std.Io.Dir.cwd().rename(tmp_path, std.Io.Dir.cwd(), cfg.config_path, io); } fn appendFimPaths(cfg: *Config, raw: []const u8) !void { diff --git a/agent/src/enrollment.zig b/agent/src/enrollment.zig index 76936ef..ca032f9 100644 --- a/agent/src/enrollment.zig +++ b/agent/src/enrollment.zig @@ -2,11 +2,12 @@ const std = @import("std"); const builtin = @import("builtin"); const Config = @import("config.zig").Config; +const iox = @import("io_compat.zig"); extern "kernel32" fn GetComputerNameA( name: [*]u8, size: *u32, -) callconv(.C) i32; +) callconv(.c) i32; fn getHostname(buf: []u8) ![]const u8 { if (builtin.os.tag == .windows) { @@ -39,10 +40,9 @@ pub fn run(alloc: std.mem.Allocator, cfg: *Config, agent_version: []const u8) !v else => "unknown", }; - var body_buf = std.ArrayList(u8).init(alloc); + var body_buf = std.array_list.Managed(u8).init(alloc); defer body_buf.deinit(); - var w = body_buf.writer(); - try w.print( + try body_buf.print( \\{{"enrollment_token":"{s}","hostname":"{s}","os":"{s}","os_version":"unknown","arch":"{s}","agent_version":"{s}"}} , .{ token, hostname, os_str, arch_str, agent_version }, @@ -51,10 +51,10 @@ pub fn run(alloc: std.mem.Allocator, cfg: *Config, agent_version: []const u8) !v const url = try std.fmt.allocPrint(alloc, "{s}/api/agents/enroll", .{cfg.backend_url}); defer alloc.free(url); - var client = std.http.Client{ .allocator = alloc }; + var client = std.http.Client{ .allocator = alloc, .io = iox.current() }; defer client.deinit(); - var response_body = std.ArrayList(u8).init(alloc); + var response_body: std.Io.Writer.Allocating = .init(alloc); defer response_body.deinit(); const res = try client.fetch(.{ @@ -62,7 +62,7 @@ pub fn run(alloc: std.mem.Allocator, cfg: *Config, agent_version: []const u8) !v .location = .{ .url = url }, .headers = .{ .content_type = .{ .override = "application/json" } }, .payload = body_buf.items, - .response_storage = .{ .dynamic = &response_body }, + .response_writer = &response_body.writer, }); if (res.status != .ok) return error.EnrollmentFailed; @@ -71,7 +71,7 @@ pub fn run(alloc: std.mem.Allocator, cfg: *Config, agent_version: []const u8) !v agent_id: []const u8, jwt: []const u8, jwt_expires_at: []const u8, - }, alloc, response_body.items, .{ .ignore_unknown_fields = true }); + }, alloc, response_body.written(), .{ .ignore_unknown_fields = true }); defer parsed.deinit(); cfg.agent_id = try alloc.dupe(u8, parsed.value.agent_id); diff --git a/agent/src/env.zig b/agent/src/env.zig new file mode 100644 index 0000000..943ee2c --- /dev/null +++ b/agent/src/env.zig @@ -0,0 +1,11 @@ +const std = @import("std"); + +pub const GetEnvVarError = error{ + EnvironmentVariableNotFound, + OutOfMemory, +}; + +pub fn getEnvVarOwned(allocator: std.mem.Allocator, name: [:0]const u8) GetEnvVarError![]u8 { + const value = std.c.getenv(name.ptr) orelse return error.EnvironmentVariableNotFound; + return allocator.dupe(u8, std.mem.span(value)); +} diff --git a/agent/src/io_compat.zig b/agent/src/io_compat.zig new file mode 100644 index 0000000..e8cabbe --- /dev/null +++ b/agent/src/io_compat.zig @@ -0,0 +1,43 @@ +const std = @import("std"); + +pub fn current() std.Io { + return std.Io.Threaded.global_single_threaded.io(); +} + +pub fn readToEndAlloc(file: std.Io.File, allocator: std.mem.Allocator, max_bytes: usize) ![]u8 { + const io = current(); + const stat = try file.stat(io); + if (stat.size > max_bytes) return error.FileTooBig; + const buf = try allocator.alloc(u8, @intCast(stat.size)); + errdefer allocator.free(buf); + const read = try file.readPositionalAll(io, buf, 0); + return buf[0..read]; +} + +pub fn timestamp() i64 { + return std.Io.Clock.now(.real, current()).toSeconds(); +} + +pub fn sleep(nanoseconds: u64) void { + std.Io.sleep(current(), .fromNanoseconds(@intCast(nanoseconds)), .awake) catch {}; +} + +pub const Timer = struct { + started_ns: i96, + + pub fn start() !Timer { + return .{ .started_ns = nowNs() }; + } + + pub fn read(self: Timer) u64 { + return @intCast(nowNs() - self.started_ns); + } + + pub fn reset(self: *Timer) void { + self.started_ns = nowNs(); + } + + fn nowNs() i96 { + return std.Io.Clock.now(.awake, current()).nanoseconds; + } +}; diff --git a/agent/src/main.zig b/agent/src/main.zig index 0b20154..166ca54 100644 --- a/agent/src/main.zig +++ b/agent/src/main.zig @@ -10,26 +10,32 @@ const network_collector = @import("collectors/network.zig"); const users_collector = @import("collectors/users.zig"); const system_collector = @import("collectors/system.zig"); const fim_collector = @import("collectors/fim.zig"); +const process_events = @import("collectors/process_events.zig"); +const fs_events = @import("collectors/fs_events.zig"); +const dns_collector = @import("collectors/dns.zig"); +const inventory_collector = @import("collectors/inventory.zig"); +const extensions_collector = @import("collectors/extensions.zig"); +const mcp_collector = @import("collectors/mcp_config.zig"); const response_actions = @import("response_actions.zig"); +const iox = @import("io_compat.zig"); const AGENT_VERSION = "0.1.0"; pub fn main() !void { - var gpa = std.heap.GeneralPurposeAllocator(.{}){}; + var gpa = std.heap.DebugAllocator(.{}){}; defer _ = gpa.deinit(); const alloc = gpa.allocator(); - var stderr = std.io.getStdErr().writer(); - try stderr.print("tawny-agent {s} starting on {s}\n", .{ AGENT_VERSION, @tagName(builtin.os.tag) }); + std.debug.print("tawny-agent {s} starting on {s}\n", .{ AGENT_VERSION, @tagName(builtin.os.tag) }); var cfg = config_mod.load(alloc) catch |err| { - try stderr.print("config load failed: {s}\n", .{@errorName(err)}); + std.debug.print("config load failed: {s}\n", .{@errorName(err)}); return err; }; defer cfg.deinit(); if (cfg.agent_id == null) { - try stderr.print("not enrolled; running enrollment...\n", .{}); + std.debug.print("not enrolled; running enrollment...\n", .{}); try enrollment.run(alloc, &cfg, AGENT_VERSION); try config_mod.save(&cfg); } @@ -44,24 +50,40 @@ pub fn main() !void { var fim = try fim_collector.Watcher.init(alloc, cfg.fim_paths); defer fim.deinit(); - var heartbeat_timer = try std.time.Timer.start(); - var process_timer = try std.time.Timer.start(); - var network_timer = try std.time.Timer.start(); - var users_timer = try std.time.Timer.start(); - var system_timer = try std.time.Timer.start(); - var fim_timer = try std.time.Timer.start(); - const start_time = std.time.timestamp(); + var process_tracker = process_events.Tracker.init(alloc); + defer process_tracker.deinit(); + + var fs_watcher = try fs_events.Watcher.init(alloc, cfg.fim_paths); + defer fs_watcher.deinit(); + + var dns = dns_collector.Collector.init(alloc); + + var inventory = inventory_collector.Scanner.init(alloc); + var extensions = extensions_collector.Scanner.init(alloc); + var mcp = mcp_collector.Scanner.init(alloc); + + var heartbeat_timer = try iox.Timer.start(); + var process_timer = try iox.Timer.start(); + var process_events_timer = try iox.Timer.start(); + var network_timer = try iox.Timer.start(); + var users_timer = try iox.Timer.start(); + var system_timer = try iox.Timer.start(); + var fim_timer = try iox.Timer.start(); + var fs_events_timer = try iox.Timer.start(); + var dns_timer = try iox.Timer.start(); + var supply_chain_timer = try iox.Timer.start(); + const start_time = iox.timestamp(); if (system_collector.collect(alloc)) |payload| { defer alloc.free(payload); try buf.push(.{ .event_type = "system_info", - .occurred_at = std.time.timestamp(), + .occurred_at = iox.timestamp(), .payload = payload, }); try spillIfNeeded(&buf, cfg.spill_path); } else |err| { - try stderr.print("system collector failed: {s}\n", .{@errorName(err)}); + std.debug.print("system collector failed: {s}\n", .{@errorName(err)}); } while (true) { @@ -69,21 +91,21 @@ pub fn main() !void { heartbeat_timer.reset(); var heartbeat = http.heartbeat(.{ .agent_version = AGENT_VERSION, - .uptime_seconds = @intCast(std.time.timestamp() - start_time), + .uptime_seconds = @intCast(iox.timestamp() - start_time), .buffer_depth = buf.len(), }) catch |err| { - try stderr.print("heartbeat failed: {s}\n", .{@errorName(err)}); + std.debug.print("heartbeat failed: {s}\n", .{@errorName(err)}); continue; }; defer heartbeat.deinit(alloc); if (heartbeat.rotated_jwt) |rotated| { persistRotatedJwt(alloc, &cfg, &http, rotated) catch |err| { - try stderr.print("jwt rotation persist failed: {s}\n", .{@errorName(err)}); + std.debug.print("jwt rotation persist failed: {s}\n", .{@errorName(err)}); }; } for (heartbeat.actions) |action| { response_actions.execute(alloc, &http, action) catch |err| { - try stderr.print("response action {s} failed: {s}\n", .{ action.id, @errorName(err) }); + std.debug.print("response action {s} failed: {s}\n", .{ action.id, @errorName(err) }); }; } } @@ -91,13 +113,13 @@ pub fn main() !void { if (process_timer.read() / std.time.ns_per_s >= cfg.process_interval_seconds) { process_timer.reset(); const snap = process_collector.collect(alloc) catch |err| { - try stderr.print("process collector failed: {s}\n", .{@errorName(err)}); + std.debug.print("process collector failed: {s}\n", .{@errorName(err)}); continue; }; defer alloc.free(snap); try buf.push(.{ .event_type = "process_snapshot", - .occurred_at = std.time.timestamp(), + .occurred_at = iox.timestamp(), .payload = snap, }); try spillIfNeeded(&buf, cfg.spill_path); @@ -106,13 +128,13 @@ pub fn main() !void { if (network_timer.read() / std.time.ns_per_s >= cfg.network_interval_seconds) { network_timer.reset(); const snap = network_collector.collect(alloc) catch |err| { - try stderr.print("network collector failed: {s}\n", .{@errorName(err)}); + std.debug.print("network collector failed: {s}\n", .{@errorName(err)}); continue; }; defer alloc.free(snap); try buf.push(.{ .event_type = "network_snapshot", - .occurred_at = std.time.timestamp(), + .occurred_at = iox.timestamp(), .payload = snap, }); try spillIfNeeded(&buf, cfg.spill_path); @@ -121,13 +143,13 @@ pub fn main() !void { if (users_timer.read() / std.time.ns_per_s >= cfg.users_interval_seconds) { users_timer.reset(); const snap = users_collector.collect(alloc) catch |err| { - try stderr.print("users collector failed: {s}\n", .{@errorName(err)}); + std.debug.print("users collector failed: {s}\n", .{@errorName(err)}); continue; }; defer alloc.free(snap); try buf.push(.{ .event_type = "user_session", - .occurred_at = std.time.timestamp(), + .occurred_at = iox.timestamp(), .payload = snap, }); try spillIfNeeded(&buf, cfg.spill_path); @@ -136,13 +158,13 @@ pub fn main() !void { if (system_timer.read() / std.time.ns_per_s >= cfg.system_interval_seconds) { system_timer.reset(); const snap = system_collector.collect(alloc) catch |err| { - try stderr.print("system collector failed: {s}\n", .{@errorName(err)}); + std.debug.print("system collector failed: {s}\n", .{@errorName(err)}); continue; }; defer alloc.free(snap); try buf.push(.{ .event_type = "system_info", - .occurred_at = std.time.timestamp(), + .occurred_at = iox.timestamp(), .payload = snap, }); try spillIfNeeded(&buf, cfg.spill_path); @@ -151,7 +173,7 @@ pub fn main() !void { if (fim_timer.read() / std.time.ns_per_s >= cfg.fim_interval_seconds) { fim_timer.reset(); const changes = fim.collectChanges() catch |err| { - try stderr.print("fim collector failed: {s}\n", .{@errorName(err)}); + std.debug.print("fim collector failed: {s}\n", .{@errorName(err)}); continue; }; defer alloc.free(changes); @@ -159,32 +181,89 @@ pub fn main() !void { defer alloc.free(payload); try buf.push(.{ .event_type = "file_integrity", - .occurred_at = std.time.timestamp(), + .occurred_at = iox.timestamp(), .payload = payload, }); try spillIfNeeded(&buf, cfg.spill_path); } } + if (process_events_timer.read() / std.time.ns_per_s >= cfg.process_events_interval_seconds) { + process_events_timer.reset(); + if (process_tracker.collectLaunches()) |launches| { + try emitBatch(&buf, cfg.spill_path, "process_launch", launches, alloc); + } else |err| { + std.debug.print("process_events collector failed: {s}\n", .{@errorName(err)}); + } + } + + if (fs_events_timer.read() / std.time.ns_per_s >= cfg.fs_events_interval_seconds) { + fs_events_timer.reset(); + if (fs_watcher.collectEvents()) |events| { + try emitBatch(&buf, cfg.spill_path, "file_event", events, alloc); + } else |err| { + std.debug.print("fs_events collector failed: {s}\n", .{@errorName(err)}); + } + } + + if (dns_timer.read() / std.time.ns_per_s >= cfg.dns_interval_seconds) { + dns_timer.reset(); + if (dns.collectQueries()) |queries| { + try emitBatch(&buf, cfg.spill_path, "dns_query", queries, alloc); + } else |err| { + std.debug.print("dns collector failed: {s}\n", .{@errorName(err)}); + } + } + + // Supply-chain inventory + extensions + MCP configs run on a much + // longer cadence — these are slow filesystem walks, not real-time. + if (supply_chain_timer.read() / std.time.ns_per_s >= cfg.supply_chain_interval_seconds) { + supply_chain_timer.reset(); + + if (inventory.collectInventory()) |payloads| { + try emitBatch(&buf, cfg.spill_path, "package_inventory", payloads, alloc); + } else |err| { + std.debug.print("inventory collector failed: {s}\n", .{@errorName(err)}); + } + + if (extensions.collectExtensions(.editor)) |payloads| { + try emitBatch(&buf, cfg.spill_path, "editor_extension", payloads, alloc); + } else |err| { + std.debug.print("editor extensions failed: {s}\n", .{@errorName(err)}); + } + + if (extensions.collectExtensions(.browser)) |payloads| { + try emitBatch(&buf, cfg.spill_path, "browser_extension", payloads, alloc); + } else |err| { + std.debug.print("browser extensions failed: {s}\n", .{@errorName(err)}); + } + + if (mcp.collectConfigs()) |payloads| { + try emitBatch(&buf, cfg.spill_path, "mcp_config", payloads, alloc); + } else |err| { + std.debug.print("mcp config failed: {s}\n", .{@errorName(err)}); + } + } + if (buf.len() == 0) { buf.replay(cfg.spill_path) catch |err| { - try stderr.print("buffer replay failed: {s}\n", .{@errorName(err)}); + std.debug.print("buffer replay failed: {s}\n", .{@errorName(err)}); }; } if (buf.len() > 0) { http.flushEvents(&buf) catch |err| { - try stderr.print("flush failed (will retry): {s}\n", .{@errorName(err)}); + std.debug.print("flush failed (will retry): {s}\n", .{@errorName(err)}); if (buf.shouldSpill()) { buf.spill(cfg.spill_path) catch |spill_err| { - try stderr.print("buffer spill failed: {s}\n", .{@errorName(spill_err)}); + std.debug.print("buffer spill failed: {s}\n", .{@errorName(spill_err)}); }; } - std.time.sleep(http.backoff_seconds * std.time.ns_per_s); + iox.sleep(http.backoff_seconds * std.time.ns_per_s); }; } - std.time.sleep(1 * std.time.ns_per_s); + iox.sleep(1 * std.time.ns_per_s); } } @@ -198,6 +277,12 @@ test "main module loads" { _ = users_collector; _ = system_collector; _ = fim_collector; + _ = process_events; + _ = fs_events; + _ = dns_collector; + _ = inventory_collector; + _ = extensions_collector; + _ = mcp_collector; _ = response_actions; } @@ -205,6 +290,25 @@ fn spillIfNeeded(buf: *buffer.Buffer, path: []const u8) !void { if (buf.shouldSpill()) try buf.spill(path); } +fn emitBatch( + buf: *buffer.Buffer, + spill_path: []const u8, + event_type: []const u8, + payloads: [][]u8, + alloc: std.mem.Allocator, +) !void { + defer alloc.free(payloads); + for (payloads) |payload| { + defer alloc.free(payload); + try buf.push(.{ + .event_type = event_type, + .occurred_at = iox.timestamp(), + .payload = payload, + }); + try spillIfNeeded(buf, spill_path); + } +} + fn persistRotatedJwt( alloc: std.mem.Allocator, cfg: *config_mod.Config, diff --git a/agent/src/platform/linux.zig b/agent/src/platform/linux.zig index dea1e8b..ef10dca 100644 --- a/agent/src/platform/linux.zig +++ b/agent/src/platform/linux.zig @@ -1,4 +1,5 @@ const std = @import("std"); +const iox = @import("../io_compat.zig"); pub const ProcessInfo = struct { pid: u32, @@ -8,10 +9,11 @@ pub const ProcessInfo = struct { }; pub fn enumerateProcesses(alloc: std.mem.Allocator) ![]ProcessInfo { - var proc_dir = try std.fs.openDirAbsolute("/proc", .{ .iterate = true }); - defer proc_dir.close(); + const io = iox.current(); + var proc_dir = try std.Io.Dir.openDirAbsolute(io, "/proc", .{ .iterate = true }); + defer proc_dir.close(io); - var list = std.ArrayList(ProcessInfo).init(alloc); + var list = std.array_list.Managed(ProcessInfo).init(alloc); errdefer { for (list.items) |p| { alloc.free(p.name); @@ -21,13 +23,13 @@ pub fn enumerateProcesses(alloc: std.mem.Allocator) ![]ProcessInfo { } var iter = proc_dir.iterate(); - while (try iter.next()) |entry| { + while (try iter.next(io)) |entry| { if (entry.kind != .directory) continue; const pid = std.fmt.parseInt(u32, entry.name, 10) catch continue; const raw_name = readProcText(alloc, pid, "comm") catch try alloc.dupe(u8, "unknown"); defer alloc.free(raw_name); - const name = try alloc.dupe(u8, std.mem.trimRight(u8, raw_name, "\r\n")); + const name = try alloc.dupe(u8, std.mem.trimEnd(u8, raw_name, "\r\n")); errdefer alloc.free(name); const command_line = readCommandLine(alloc, pid) catch try alloc.dupe(u8, name); errdefer alloc.free(command_line); @@ -58,9 +60,10 @@ fn readParentPid(alloc: std.mem.Allocator, pid: u32) !u32 { fn readProcText(alloc: std.mem.Allocator, pid: u32, name: []const u8) ![]u8 { const path = try std.fmt.allocPrint(alloc, "/proc/{d}/{s}", .{ pid, name }); defer alloc.free(path); - var file = try std.fs.openFileAbsolute(path, .{}); - defer file.close(); - return file.readToEndAlloc(alloc, 16 * 1024); + const io = iox.current(); + var file = try std.Io.Dir.openFileAbsolute(io, path, .{}); + defer file.close(io); + return iox.readToEndAlloc(file, alloc, 16 * 1024); } fn readCommandLine(alloc: std.mem.Allocator, pid: u32) ![]u8 { @@ -72,7 +75,7 @@ fn readCommandLine(alloc: std.mem.Allocator, pid: u32) ![]u8 { if (ch.* == 0) ch.* = ' '; } - const trimmed = std.mem.trimRight(u8, owned, " "); + const trimmed = std.mem.trimEnd(u8, owned, " "); if (trimmed.len == owned.len) { return owned; } diff --git a/agent/src/platform/macos.zig b/agent/src/platform/macos.zig index 44b6770..5fa1c68 100644 --- a/agent/src/platform/macos.zig +++ b/agent/src/platform/macos.zig @@ -1,4 +1,5 @@ const std = @import("std"); +const iox = @import("../io_compat.zig"); pub const ProcessInfo = struct { pid: u32, @@ -7,29 +8,16 @@ pub const ProcessInfo = struct { command_line: []u8, }; -const c = @cImport({ - @cInclude("libproc.h"); - @cInclude("sys/proc_info.h"); - @cInclude("sys/sysctl.h"); -}); - -const PROC_PIDPATHINFO_MAXSIZE = 4 * 1024; - pub fn enumerateProcesses(alloc: std.mem.Allocator) ![]ProcessInfo { - // Size the pid buffer. - const n = c.proc_listpids(c.PROC_ALL_PIDS, 0, null, 0); - if (n <= 0) return error.ProcListFailed; - - const pid_buf_bytes = @as(usize, @intCast(n)); - const pid_count = pid_buf_bytes / @sizeOf(c.pid_t); - const pids = try alloc.alloc(c.pid_t, pid_count); - defer alloc.free(pids); - - const n2 = c.proc_listpids(c.PROC_ALL_PIDS, 0, pids.ptr, @intCast(pid_buf_bytes)); - if (n2 <= 0) return error.ProcListFailed; - const actual = @as(usize, @intCast(n2)) / @sizeOf(c.pid_t); - - var list = std.ArrayList(ProcessInfo).init(alloc); + const result = try std.process.run(alloc, iox.current(), .{ + .argv = &.{ "ps", "-axo", "pid=,ppid=,comm=" }, + .stdout_limit = .limited(2 * 1024 * 1024), + .stderr_limit = .limited(64 * 1024), + }); + defer alloc.free(result.stdout); + defer alloc.free(result.stderr); + + var list = std.array_list.Managed(ProcessInfo).init(alloc); errdefer { for (list.items) |p| { alloc.free(p.name); @@ -38,35 +26,22 @@ pub fn enumerateProcesses(alloc: std.mem.Allocator) ![]ProcessInfo { list.deinit(); } - var name_buf: [PROC_PIDPATHINFO_MAXSIZE]u8 = undefined; - - for (pids[0..actual]) |pid| { - if (pid == 0) continue; - var info: c.proc_bsdinfo = undefined; - const got = c.proc_pidinfo( - pid, - c.PROC_PIDTBSDINFO, - 0, - &info, - @sizeOf(c.proc_bsdinfo), - ); - if (got != @sizeOf(c.proc_bsdinfo)) continue; + var lines = std.mem.splitScalar(u8, result.stdout, '\n'); + while (lines.next()) |line_raw| { + const line = std.mem.trim(u8, line_raw, " \t\r"); + if (line.len == 0) continue; - const name_len = c.proc_name(pid, &name_buf, name_buf.len); - const name_slice = if (name_len > 0) - name_buf[0..@as(usize, @intCast(name_len))] - else - "unknown"; + var fields = std.mem.tokenizeAny(u8, line, " \t"); + const pid_raw = fields.next() orelse continue; + const ppid_raw = fields.next() orelse continue; + const command = fields.rest(); + const name = std.fs.path.basename(command); - const owned = try alloc.dupe(u8, name_slice); - errdefer alloc.free(owned); - const command_line = try alloc.dupe(u8, owned); - errdefer alloc.free(command_line); try list.append(.{ - .pid = @intCast(pid), - .ppid = @intCast(info.pbi_ppid), - .name = owned, - .command_line = command_line, + .pid = std.fmt.parseInt(u32, pid_raw, 10) catch continue, + .ppid = std.fmt.parseInt(u32, ppid_raw, 10) catch 0, + .name = try alloc.dupe(u8, if (name.len == 0) "unknown" else name), + .command_line = try alloc.dupe(u8, command), }); } diff --git a/agent/src/platform/windows.zig b/agent/src/platform/windows.zig index 66825b7..c63ec8b 100644 --- a/agent/src/platform/windows.zig +++ b/agent/src/platform/windows.zig @@ -23,10 +23,10 @@ const PROCESSENTRY32W = extern struct { szExeFile: [260]u16, }; -extern "kernel32" fn CreateToolhelp32Snapshot(dwFlags: u32, th32ProcessID: u32) callconv(.C) std.os.windows.HANDLE; -extern "kernel32" fn Process32FirstW(hSnapshot: std.os.windows.HANDLE, lppe: *PROCESSENTRY32W) callconv(.C) i32; -extern "kernel32" fn Process32NextW(hSnapshot: std.os.windows.HANDLE, lppe: *PROCESSENTRY32W) callconv(.C) i32; -extern "kernel32" fn CloseHandle(hObject: std.os.windows.HANDLE) callconv(.C) i32; +extern "kernel32" fn CreateToolhelp32Snapshot(dwFlags: u32, th32ProcessID: u32) callconv(.c) std.os.windows.HANDLE; +extern "kernel32" fn Process32FirstW(hSnapshot: std.os.windows.HANDLE, lppe: *PROCESSENTRY32W) callconv(.c) i32; +extern "kernel32" fn Process32NextW(hSnapshot: std.os.windows.HANDLE, lppe: *PROCESSENTRY32W) callconv(.c) i32; +extern "kernel32" fn CloseHandle(hObject: std.os.windows.HANDLE) callconv(.c) i32; pub fn enumerateProcesses(alloc: std.mem.Allocator) ![]ProcessInfo { const snap = CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0); @@ -36,7 +36,7 @@ pub fn enumerateProcesses(alloc: std.mem.Allocator) ![]ProcessInfo { var entry: PROCESSENTRY32W = std.mem.zeroes(PROCESSENTRY32W); entry.dwSize = @sizeOf(PROCESSENTRY32W); - var list = std.ArrayList(ProcessInfo).init(alloc); + var list = std.array_list.Managed(ProcessInfo).init(alloc); errdefer { for (list.items) |p| { alloc.free(p.name); diff --git a/agent/src/transport/buffer.zig b/agent/src/transport/buffer.zig index 7ea84ba..dedd38c 100644 --- a/agent/src/transport/buffer.zig +++ b/agent/src/transport/buffer.zig @@ -1,4 +1,5 @@ const std = @import("std"); +const iox = @import("../io_compat.zig"); pub const Event = struct { event_type: []const u8, @@ -10,13 +11,13 @@ pub const Event = struct { pub const Buffer = struct { allocator: std.mem.Allocator, capacity: usize, - list: std.ArrayList(Event), + list: std.array_list.Managed(Event), pub fn init(alloc: std.mem.Allocator, capacity: usize) Buffer { return .{ .allocator = alloc, .capacity = capacity, - .list = std.ArrayList(Event).init(alloc), + .list = std.array_list.Managed(Event).init(alloc), }; } @@ -59,29 +60,33 @@ pub const Buffer = struct { pub fn spill(self: *Buffer, path: []const u8) !void { if (self.len() == 0) return; + const io = iox.current(); const dir = std.fs.path.dirname(path) orelse "."; - std.fs.cwd().makePath(dir) catch {}; + std.Io.Dir.cwd().createDirPath(io, dir) catch {}; - var file = try std.fs.cwd().createFile(path, .{ .truncate = false }); - defer file.close(); - try file.seekFromEnd(0); + var file = try std.Io.Dir.cwd().createFile(io, path, .{ .truncate = false }); + defer file.close(io); - var w = file.writer(); + var offset = (try file.stat(io)).size; for (self.list.items) |ev| { - try w.print("{s}\t{d}\t{s}\n", .{ ev.event_type, ev.occurred_at, ev.payload }); + const line = try std.fmt.allocPrint(self.allocator, "{s}\t{d}\t{s}\n", .{ ev.event_type, ev.occurred_at, ev.payload }); + defer self.allocator.free(line); + try file.writePositionalAll(io, line, offset); + offset += line.len; } - try file.sync(); + try file.sync(io); self.clear(); } pub fn replay(self: *Buffer, path: []const u8) !void { - const file = std.fs.cwd().openFile(path, .{}) catch |err| switch (err) { + const io = iox.current(); + const file = std.Io.Dir.cwd().openFile(io, path, .{}) catch |err| switch (err) { error.FileNotFound => return, else => return err, }; - defer file.close(); + defer file.close(io); - const raw = try file.readToEndAlloc(self.allocator, 32 * 1024 * 1024); + const raw = try iox.readToEndAlloc(file, self.allocator, 32 * 1024 * 1024); defer self.allocator.free(raw); var lines = std.mem.splitScalar(u8, raw, '\n'); @@ -100,7 +105,7 @@ pub const Buffer = struct { }); } - std.fs.cwd().deleteFile(path) catch |err| switch (err) { + std.Io.Dir.cwd().deleteFile(io, path) catch |err| switch (err) { error.FileNotFound => {}, else => return err, }; @@ -119,11 +124,9 @@ test "buffer pushes and clears" { test "buffer spills and replays" { var tmp = std.testing.tmpDir(.{}); defer tmp.cleanup(); - const tmp_path = try tmp.dir.realpathAlloc(std.testing.allocator, "."); - defer std.testing.allocator.free(tmp_path); const spill_path = try std.fs.path.join( std.testing.allocator, - &.{ tmp_path, "events.spool" }, + &.{ ".zig-cache", "tmp", &tmp.sub_path, "events.spool" }, ); defer std.testing.allocator.free(spill_path); diff --git a/agent/src/transport/http.zig b/agent/src/transport/http.zig index 8e4ca58..ae180c3 100644 --- a/agent/src/transport/http.zig +++ b/agent/src/transport/http.zig @@ -1,5 +1,6 @@ const std = @import("std"); const buffer_mod = @import("buffer.zig"); +const iox = @import("../io_compat.zig"); pub const HeartbeatPayload = struct { agent_version: []const u8, @@ -74,7 +75,7 @@ pub const Client = struct { result.rotated_jwt = try self.allocator.dupe(u8, jwt); } if (parsed.value.actions.len > 0) { - var actions = std.ArrayList(ResponseAction).init(self.allocator); + var actions = std.array_list.Managed(ResponseAction).init(self.allocator); errdefer { for (actions.items) |action| { self.allocator.free(action.id); @@ -85,9 +86,9 @@ pub const Client = struct { } for (parsed.value.actions) |action| { - var payload = std.ArrayList(u8).init(self.allocator); + var payload: std.Io.Writer.Allocating = .init(self.allocator); errdefer payload.deinit(); - try std.json.stringify(action.payload, .{}, payload.writer()); + try std.json.Stringify.value(action.payload, .{}, &payload.writer); try actions.append(.{ .id = try self.allocator.dupe(u8, action.id), .action_type = try self.allocator.dupe(u8, action.action_type), @@ -108,12 +109,14 @@ pub const Client = struct { const path = try std.fmt.allocPrint(self.allocator, "/api/agents/actions/{s}/result", .{action_id}); defer self.allocator.free(path); - var body = std.ArrayList(u8).init(self.allocator); + var body = std.array_list.Managed(u8).init(self.allocator); defer body.deinit(); - var w = body.writer(); - try w.print("{{\"status\":\"{s}\",\"message\":", .{status}); - try std.json.stringify(message, .{}, w); - try w.writeAll(",\"result\":{}}"); + try body.print("{{\"status\":\"{s}\",\"message\":", .{status}); + var body_writer: std.Io.Writer.Allocating = .init(self.allocator); + defer body_writer.deinit(); + try std.json.Stringify.value(message, .{}, &body_writer.writer); + try body.appendSlice(body_writer.written()); + try body.appendSlice(",\"result\":{}}"); const response = try self.post(path, body.items); defer self.allocator.free(response); @@ -122,12 +125,12 @@ pub const Client = struct { pub fn flushEvents(self: *Client, buf: *buffer_mod.Buffer) !void { if (buf.len() == 0) return; - var body = std.ArrayList(u8).init(self.allocator); + var body = std.array_list.Managed(u8).init(self.allocator); defer body.deinit(); try body.appendSlice("{\"events\":["); for (buf.items(), 0..) |ev, i| { if (i > 0) try body.append(','); - try body.writer().print( + try body.print( \\{{"type":"{s}","occurred_at":{d},"payload":{s}}} , .{ ev.event_type, ev.occurred_at, ev.payload }); } @@ -145,10 +148,10 @@ pub const Client = struct { const auth = try std.fmt.allocPrint(self.allocator, "Bearer {s}", .{self.jwt}); defer self.allocator.free(auth); - var client = std.http.Client{ .allocator = self.allocator }; + var client = std.http.Client{ .allocator = self.allocator, .io = iox.current() }; defer client.deinit(); - var resp = std.ArrayList(u8).init(self.allocator); + var resp: std.Io.Writer.Allocating = .init(self.allocator); defer resp.deinit(); const res = client.fetch(.{ @@ -159,7 +162,7 @@ pub const Client = struct { .authorization = .{ .override = auth }, }, .payload = body, - .response_storage = .{ .dynamic = &resp }, + .response_writer = &resp.writer, }) catch |err| { self.backoff_seconds = @min(self.backoff_seconds * 2, 300); return err; diff --git a/backend/src/Tawny.Api/Controllers/AlertRulesController.cs b/backend/src/Tawny.Api/Controllers/AlertRulesController.cs index d480e99..6763f55 100644 --- a/backend/src/Tawny.Api/Controllers/AlertRulesController.cs +++ b/backend/src/Tawny.Api/Controllers/AlertRulesController.cs @@ -18,7 +18,8 @@ public class AlertRulesController( TawnyDbContext db, AuditLogger audit, SigmaRuleImporter sigma, - IocRuleImporter iocs) : ControllerBase + IocRuleImporter iocs, + ExposureRuleImporter exposures) : ControllerBase { [HttpGet] public async Task>> List(CancellationToken ct) @@ -137,6 +138,41 @@ public async Task> ImportIocs( new ImportIocRulesResponse(imported.Rules.Select(ToResponse).ToList(), imported.SkippedIndicators)); } + [HttpPost("exposures")] + [Authorize(AuthenticationSchemes = TawnyAuthSchemes.WebUser, Roles = "Admin")] + public async Task> ImportExposures( + ImportExposureRulesRequest req, + CancellationToken ct) + { + ExposureImportResult imported; + try + { + imported = exposures.Import( + req.Definition, + req.Severity ?? AlertSeverity.High, + req.IsEnabled ?? true, + DateTimeOffset.UtcNow); + } + catch (ExposureRuleException ex) + { + return Problem(statusCode: 400, title: ex.Message); + } + + db.AlertRules.AddRange(imported.Rules); + audit.Add(User, "alert_rule.import_exposures", null, new + { + Count = imported.Rules.Count, + Severity = req.Severity ?? AlertSeverity.High, + SkippedCount = imported.SkippedEntries.Count, + }); + await db.SaveChangesAsync(ct); + + return CreatedAtAction( + nameof(List), + new { count = imported.Rules.Count }, + new ImportExposureRulesResponse(imported.Rules.Select(ToResponse).ToList(), imported.SkippedEntries)); + } + [HttpPut("{id:guid}")] [Authorize(AuthenticationSchemes = TawnyAuthSchemes.WebUser, Roles = "Admin")] public async Task> Update(Guid id, UpdateAlertRuleRequest req, CancellationToken ct) @@ -163,6 +199,7 @@ public async Task> Update(Guid id, UpdateAlertRu rule.PayloadPath = Normalize(req.PayloadPath); rule.MatchValue = Normalize(req.MatchValue); rule.SourceDefinition = null; + rule.CompiledExpressionJson = null; rule.IsEnabled = req.IsEnabled; rule.MitreTechniquesJson = SerializeTechniques(req.MitreTechniques); rule.UpdatedAt = DateTimeOffset.UtcNow; diff --git a/backend/src/Tawny.Api/Controllers/AlertsController.cs b/backend/src/Tawny.Api/Controllers/AlertsController.cs index 0e4670a..f957c9f 100644 --- a/backend/src/Tawny.Api/Controllers/AlertsController.cs +++ b/backend/src/Tawny.Api/Controllers/AlertsController.cs @@ -57,6 +57,7 @@ public async Task>> List( a.SentinelNotificationError, a.Title, a.Description, + a.EnrichmentJson, a.CreatedAt, }) .ToListAsync(ct); @@ -86,6 +87,7 @@ public async Task>> List( a.SentinelNotificationError, a.Title, a.Description, + string.IsNullOrEmpty(a.EnrichmentJson) ? null : JsonSerializer.Deserialize(a.EnrichmentJson), a.CreatedAt)).ToList()); } } diff --git a/backend/src/Tawny.Api/Controllers/CasesController.cs b/backend/src/Tawny.Api/Controllers/CasesController.cs new file mode 100644 index 0000000..0b6e736 --- /dev/null +++ b/backend/src/Tawny.Api/Controllers/CasesController.cs @@ -0,0 +1,298 @@ +using System.Text.Json; +using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Mvc; +using Microsoft.EntityFrameworkCore; +using Tawny.Api.Auth; +using Tawny.Api.Models; +using Tawny.Api.Services; +using Tawny.Domain; +using Tawny.Domain.Entities; +using Tawny.Infrastructure; + +namespace Tawny.Api.Controllers; + +[ApiController] +[Route("api/cases")] +[Authorize(AuthenticationSchemes = TawnyAuthSchemes.WebUser + "," + TawnyAuthSchemes.ApiToken)] +public class CasesController(TawnyDbContext db, AuditLogger audit) : ControllerBase +{ + [HttpGet] + public async Task>> List( + [FromQuery] CaseStatus? status, + [FromQuery] int limit = 50, + CancellationToken ct = default) + { + var take = Math.Clamp(limit, 1, 200); + var tenantId = User.GetTenantId(); + var query = db.Cases.AsNoTracking().Where(c => c.TenantId == tenantId); + if (status is not null) query = query.Where(c => c.Status == status.Value); + var rows = await query + .OrderByDescending(c => c.UpdatedAt) + .Take(take) + .Select(c => new + { + c.Id, c.Title, c.Summary, c.Status, c.Priority, + c.AssignedToUserId, c.CreatedByUserId, c.MitreTechniquesJson, + AlertCount = c.CaseAlerts.Count, + NoteCount = c.Notes.Count, + c.CreatedAt, c.UpdatedAt, c.ClosedAt, + }) + .ToListAsync(ct); + return Ok(rows.Select(c => new CaseResponse( + c.Id, c.Title, c.Summary, c.Status, c.Priority, + c.AssignedToUserId, c.CreatedByUserId, c.AlertCount, c.NoteCount, + DeserializeTechniques(c.MitreTechniquesJson), + c.CreatedAt, c.UpdatedAt, c.ClosedAt)).ToList()); + } + + [HttpGet("{id:long}")] + public async Task> Get(long id, CancellationToken ct) + { + var tenantId = User.GetTenantId(); + var caseRow = await db.Cases + .Include(c => c.CaseAlerts).ThenInclude(ca => ca.Alert).ThenInclude(a => a!.Agent) + .Include(c => c.Notes) + .FirstOrDefaultAsync(c => c.Id == id && c.TenantId == tenantId, ct); + if (caseRow is null) return NotFound(); + + return Ok(new CaseDetailResponse( + caseRow.Id, + caseRow.Title, + caseRow.Summary, + caseRow.Status, + caseRow.Priority, + caseRow.AssignedToUserId, + caseRow.CreatedByUserId, + DeserializeTechniques(caseRow.MitreTechniquesJson), + caseRow.CaseAlerts + .OrderByDescending(ca => ca.AddedAt) + .Select(ca => new CaseAlertResponse( + ca.Id, ca.AlertId, + ca.Alert?.Title ?? "", + ca.Alert?.Agent?.Hostname ?? "", + ca.Alert?.Severity.ToString().ToLowerInvariant() ?? "", + ca.Alert?.CreatedAt ?? DateTimeOffset.MinValue, + ca.AddedAt)) + .ToList(), + caseRow.Notes + .OrderByDescending(n => n.CreatedAt) + .Select(n => new CaseNoteResponse(n.Id, n.AuthorUserId, n.Body, n.CreatedAt)) + .ToList(), + caseRow.CreatedAt, + caseRow.UpdatedAt, + caseRow.ClosedAt)); + } + + [HttpPost] + public async Task> Create( + [FromBody] CreateCaseRequest req, + CancellationToken ct) + { + if (string.IsNullOrWhiteSpace(req.Title) || req.Title.Length > 255) + { + return Problem(statusCode: 400, title: "title is required and must be 255 characters or fewer."); + } + var tenantId = User.GetTenantId(); + var now = DateTimeOffset.UtcNow; + var newCase = new Case + { + TenantId = tenantId, + Title = req.Title.Trim(), + Summary = string.IsNullOrWhiteSpace(req.Summary) ? null : req.Summary.Trim(), + Priority = req.Priority ?? CasePriority.Medium, + CreatedByUserId = TryGetUserId(), + MitreTechniquesJson = SerializeTechniques(req.MitreTechniques), + CreatedAt = now, + UpdatedAt = now, + }; + db.Cases.Add(newCase); + await db.SaveChangesAsync(ct); + + if (req.AlertIds is { Count: > 0 }) + { + await LinkAlertsAsync(newCase.Id, tenantId, req.AlertIds, now, ct); + } + + audit.Add(User, "case.create", newCase.Id.ToString(), new + { + newCase.Title, + alert_count = req.AlertIds?.Count ?? 0, + }); + await db.SaveChangesAsync(ct); + + return CreatedAtAction(nameof(Get), new { id = newCase.Id }, new CaseResponse( + newCase.Id, newCase.Title, newCase.Summary, newCase.Status, newCase.Priority, + newCase.AssignedToUserId, newCase.CreatedByUserId, + req.AlertIds?.Count ?? 0, 0, + DeserializeTechniques(newCase.MitreTechniquesJson), + newCase.CreatedAt, newCase.UpdatedAt, newCase.ClosedAt)); + } + + [HttpPut("{id:long}")] + public async Task> Update( + long id, + [FromBody] UpdateCaseRequest req, + CancellationToken ct) + { + if (string.IsNullOrWhiteSpace(req.Title) || req.Title.Length > 255) + { + return Problem(statusCode: 400, title: "title is required and must be 255 characters or fewer."); + } + var tenantId = User.GetTenantId(); + var caseRow = await db.Cases.FirstOrDefaultAsync(c => c.Id == id && c.TenantId == tenantId, ct); + if (caseRow is null) return NotFound(); + + caseRow.Title = req.Title.Trim(); + caseRow.Summary = string.IsNullOrWhiteSpace(req.Summary) ? null : req.Summary.Trim(); + var transitioningToClosed = req.Status is CaseStatus.Resolved or CaseStatus.Closed + && caseRow.Status is not (CaseStatus.Resolved or CaseStatus.Closed); + caseRow.Status = req.Status; + caseRow.Priority = req.Priority; + caseRow.AssignedToUserId = req.AssignedToUserId; + caseRow.MitreTechniquesJson = SerializeTechniques(req.MitreTechniques); + caseRow.UpdatedAt = DateTimeOffset.UtcNow; + if (transitioningToClosed) caseRow.ClosedAt = caseRow.UpdatedAt; + + audit.Add(User, "case.update", caseRow.Id.ToString(), new + { + caseRow.Title, caseRow.Status, caseRow.Priority, caseRow.AssignedToUserId, + }); + await db.SaveChangesAsync(ct); + + var alertCount = await db.CaseAlerts.CountAsync(ca => ca.CaseId == caseRow.Id, ct); + var noteCount = await db.CaseNotes.CountAsync(n => n.CaseId == caseRow.Id, ct); + return Ok(new CaseResponse( + caseRow.Id, caseRow.Title, caseRow.Summary, caseRow.Status, caseRow.Priority, + caseRow.AssignedToUserId, caseRow.CreatedByUserId, alertCount, noteCount, + DeserializeTechniques(caseRow.MitreTechniquesJson), + caseRow.CreatedAt, caseRow.UpdatedAt, caseRow.ClosedAt)); + } + + [HttpDelete("{id:long}")] + public async Task Delete(long id, CancellationToken ct) + { + var tenantId = User.GetTenantId(); + var deleted = await db.Cases + .Where(c => c.Id == id && c.TenantId == tenantId) + .ExecuteDeleteAsync(ct); + if (deleted == 0) return NotFound(); + audit.Add(User, "case.delete", id.ToString()); + await db.SaveChangesAsync(ct); + return NoContent(); + } + + [HttpPost("{id:long}/alerts")] + public async Task> AddAlerts( + long id, + [FromBody] AddCaseAlertRequest req, + CancellationToken ct) + { + if (req.AlertIds is null || req.AlertIds.Count == 0) + { + return Problem(statusCode: 400, title: "alert_ids must contain at least one id."); + } + var tenantId = User.GetTenantId(); + var caseRow = await db.Cases.FirstOrDefaultAsync(c => c.Id == id && c.TenantId == tenantId, ct); + if (caseRow is null) return NotFound(); + + var now = DateTimeOffset.UtcNow; + var linked = await LinkAlertsAsync(caseRow.Id, tenantId, req.AlertIds, now, ct); + if (linked > 0) + { + caseRow.UpdatedAt = now; + audit.Add(User, "case.alerts_add", caseRow.Id.ToString(), new { count = linked }); + } + await db.SaveChangesAsync(ct); + return await Get(id, ct); + } + + [HttpDelete("{id:long}/alerts/{alertId:long}")] + public async Task RemoveAlert(long id, long alertId, CancellationToken ct) + { + var tenantId = User.GetTenantId(); + if (!await db.Cases.AnyAsync(c => c.Id == id && c.TenantId == tenantId, ct)) + { + return NotFound(); + } + await db.CaseAlerts.Where(ca => ca.CaseId == id && ca.AlertId == alertId).ExecuteDeleteAsync(ct); + audit.Add(User, "case.alert_remove", id.ToString(), new { alert_id = alertId }); + await db.SaveChangesAsync(ct); + return NoContent(); + } + + [HttpPost("{id:long}/notes")] + public async Task> AddNote( + long id, + [FromBody] AddCaseNoteRequest req, + CancellationToken ct) + { + if (string.IsNullOrWhiteSpace(req.Body)) + { + return Problem(statusCode: 400, title: "body is required."); + } + var tenantId = User.GetTenantId(); + var caseRow = await db.Cases.FirstOrDefaultAsync(c => c.Id == id && c.TenantId == tenantId, ct); + if (caseRow is null) return NotFound(); + + var note = new CaseNote + { + CaseId = caseRow.Id, + AuthorUserId = TryGetUserId(), + Body = req.Body.Trim(), + CreatedAt = DateTimeOffset.UtcNow, + }; + db.CaseNotes.Add(note); + caseRow.UpdatedAt = note.CreatedAt; + audit.Add(User, "case.note_add", caseRow.Id.ToString()); + await db.SaveChangesAsync(ct); + return Ok(new CaseNoteResponse(note.Id, note.AuthorUserId, note.Body, note.CreatedAt)); + } + + private async Task LinkAlertsAsync( + long caseId, Guid tenantId, IReadOnlyList alertIds, DateTimeOffset now, CancellationToken ct) + { + var validAlertIds = await db.Alerts + .Where(a => alertIds.Contains(a.Id) && a.Agent!.TenantId == tenantId) + .Select(a => a.Id) + .ToListAsync(ct); + var existing = await db.CaseAlerts + .Where(ca => ca.CaseId == caseId && validAlertIds.Contains(ca.AlertId)) + .Select(ca => ca.AlertId) + .ToListAsync(ct); + var existingSet = new HashSet(existing); + var toAdd = validAlertIds.Where(id => !existingSet.Contains(id)).ToList(); + var added = toAdd.Select(alertId => new CaseAlert + { + CaseId = caseId, + AlertId = alertId, + AddedAt = now, + AddedByUserId = TryGetUserId(), + }).ToList(); + if (added.Count > 0) db.CaseAlerts.AddRange(added); + return added.Count; + } + + private static IReadOnlyList DeserializeTechniques(string? json) + { + if (string.IsNullOrWhiteSpace(json)) return []; + try { return JsonSerializer.Deserialize>(json) ?? []; } + catch { return []; } + } + + private static string? SerializeTechniques(IReadOnlyList? techniques) + { + if (techniques is null || techniques.Count == 0) return null; + var normalized = techniques + .Where(t => !string.IsNullOrWhiteSpace(t)) + .Select(t => t.Trim().ToUpperInvariant()) + .Distinct() + .ToList(); + return normalized.Count == 0 ? null : JsonSerializer.Serialize(normalized); + } + + private Guid? TryGetUserId() + { + var raw = User.FindFirst(System.Security.Claims.ClaimTypes.NameIdentifier)?.Value; + return Guid.TryParse(raw, out var id) ? id : null; + } +} diff --git a/backend/src/Tawny.Api/Controllers/DashboardController.cs b/backend/src/Tawny.Api/Controllers/DashboardController.cs index 79240aa..7b78cbf 100644 --- a/backend/src/Tawny.Api/Controllers/DashboardController.cs +++ b/backend/src/Tawny.Api/Controllers/DashboardController.cs @@ -83,7 +83,7 @@ public async Task> Summary(CancellationTo if (techniqueByRule.Count > 0) { var ruleIds = techniqueByRule.Keys.ToList(); - var counts = await db.Alerts + var alertCounts = await db.Alerts .AsNoTracking() .Where(a => a.CreatedAt >= sevenDaysAgo && ruleIds.Contains(a.AlertRuleId) @@ -93,7 +93,7 @@ public async Task> Summary(CancellationTo .ToListAsync(ct); var perTechnique = new Dictionary(StringComparer.OrdinalIgnoreCase); - foreach (var c in counts) + foreach (var c in alertCounts) { if (!techniqueByRule.TryGetValue(c.RuleId, out var techniques)) continue; foreach (var t in techniques) diff --git a/backend/src/Tawny.Api/Controllers/HuntsController.cs b/backend/src/Tawny.Api/Controllers/HuntsController.cs index f8f8b9d..17eab4d 100644 --- a/backend/src/Tawny.Api/Controllers/HuntsController.cs +++ b/backend/src/Tawny.Api/Controllers/HuntsController.cs @@ -49,9 +49,12 @@ public async Task> Run( public async Task>> List(CancellationToken ct) { var tenantId = User.GetTenantId(); + var currentUserId = TryGetUserId(); + // Show shared hunts to everyone in the tenant; private hunts only to their creator. var rows = await db.SavedHunts .AsNoTracking() - .Where(h => h.TenantId == tenantId) + .Where(h => h.TenantId == tenantId + && (h.IsShared || h.CreatedByUserId == currentUserId)) .OrderBy(h => h.Name) .ToListAsync(ct); return Ok(rows.Select(ToResponse).ToList()); @@ -100,6 +103,7 @@ public async Task> Create( AlertOnMatch = req.AlertOnMatch ?? false, AlertSeverity = req.AlertSeverity ?? AlertSeverity.Medium, MitreTechniquesJson = SerializeTechniques(req.MitreTechniques), + IsShared = req.IsShared ?? true, CreatedByUserId = TryGetUserId(), CreatedAt = now, UpdatedAt = now, @@ -143,6 +147,7 @@ public async Task> Update( hunt.AlertOnMatch = req.AlertOnMatch; hunt.AlertSeverity = req.AlertSeverity; hunt.MitreTechniquesJson = SerializeTechniques(req.MitreTechniques); + if (req.IsShared.HasValue) hunt.IsShared = req.IsShared.Value; hunt.UpdatedAt = DateTimeOffset.UtcNow; audit.Add(User, "saved_hunt.update", hunt.Id.ToString(), new { @@ -254,6 +259,8 @@ public async Task>> Runs(Guid id, Ca h.AlertOnMatch, h.AlertSeverity, DeserializeTechniques(h.MitreTechniquesJson), + h.IsShared, + h.CreatedByUserId, h.LastRunAt, h.LastMatchCount, h.CreatedAt, diff --git a/backend/src/Tawny.Api/Controllers/InvestigationViewsController.cs b/backend/src/Tawny.Api/Controllers/InvestigationViewsController.cs new file mode 100644 index 0000000..aacc89c --- /dev/null +++ b/backend/src/Tawny.Api/Controllers/InvestigationViewsController.cs @@ -0,0 +1,241 @@ +using System.Text.Json; +using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Mvc; +using Microsoft.EntityFrameworkCore; +using Tawny.Api.Auth; +using Tawny.Domain; +using Tawny.Infrastructure; + +namespace Tawny.Api.Controllers; + +public record ProcessTreeAcrossHostsRow( + string ProcessName, + int HostCount, + int TotalSeen, + IReadOnlyList Hosts); + +public record ProcessTreeHostHit( + Guid AgentId, + string Hostname, + int SeenCount, + DateTimeOffset LastSeen); + +public record ProcessTreeAcrossHostsResponse( + IReadOnlyList Rows, + DateTimeOffset From, + DateTimeOffset To); + +public record NetworkGraphNode( + string Id, + string Label, + string Kind, + int Weight); + +public record NetworkGraphEdge( + string SourceId, + string TargetId, + int Weight); + +public record NetworkGraphResponse( + IReadOnlyList Nodes, + IReadOnlyList Edges, + DateTimeOffset From, + DateTimeOffset To); + +[ApiController] +[Route("api/investigation")] +[Authorize(AuthenticationSchemes = TawnyAuthSchemes.WebUser + "," + TawnyAuthSchemes.ApiToken)] +public class InvestigationViewsController(TawnyDbContext db) : ControllerBase +{ + /// + /// Aggregates process snapshots in the requested window across every agent, + /// returning each process name with the hosts that have run it. Useful for + /// answering "where else has this binary been seen?" without a per-host hunt. + /// + [HttpGet("process-tree")] + public async Task> ProcessTree( + [FromQuery] int hours = 24, + [FromQuery] string? nameFilter = null, + [FromQuery] int limit = 50, + CancellationToken ct = default) + { + var tenantId = User.GetTenantId(); + var windowHours = Math.Clamp(hours, 1, 168); + var since = DateTimeOffset.UtcNow.AddHours(-windowHours); + var top = Math.Clamp(limit, 1, 200); + + var events = await db.TelemetryEvents + .AsNoTracking() + .Where(e => e.TenantId == tenantId + && e.EventType == TelemetryEventType.ProcessSnapshot + && e.OccurredAt >= since) + .Select(e => new { e.AgentId, Hostname = e.Agent!.Hostname, e.OccurredAt, e.Payload }) + .ToListAsync(ct); + + // Aggregate in-memory: SQL Server can't easily walk JSON arrays this way. + var byName = new Dictionary>(StringComparer.OrdinalIgnoreCase); + foreach (var ev in events) + { + JsonDocument doc; + try { doc = JsonDocument.Parse(ev.Payload); } + catch { continue; } + using (doc) + { + if (!doc.RootElement.TryGetProperty("processes", out var processes) + || processes.ValueKind != JsonValueKind.Array) continue; + foreach (var p in processes.EnumerateArray()) + { + if (!p.TryGetProperty("name", out var n) || n.ValueKind != JsonValueKind.String) continue; + var name = n.GetString(); + if (string.IsNullOrWhiteSpace(name)) continue; + if (!string.IsNullOrWhiteSpace(nameFilter) + && !name.Contains(nameFilter, StringComparison.OrdinalIgnoreCase)) + { + continue; + } + if (!byName.TryGetValue(name, out var hosts)) + { + hosts = new Dictionary(); + byName[name] = hosts; + } + if (!hosts.TryGetValue(ev.AgentId, out var acc)) + { + acc = new ProcessHostAccumulator(ev.AgentId, ev.Hostname); + hosts[ev.AgentId] = acc; + } + acc.Bump(ev.OccurredAt); + } + } + } + + var rows = byName + .Select(kvp => new ProcessTreeAcrossHostsRow( + kvp.Key, + kvp.Value.Count, + kvp.Value.Values.Sum(h => h.SeenCount), + kvp.Value.Values + .OrderByDescending(h => h.LastSeen) + .Select(h => new ProcessTreeHostHit(h.AgentId, h.Hostname, h.SeenCount, h.LastSeen)) + .ToList())) + .OrderByDescending(r => r.HostCount) + .ThenByDescending(r => r.TotalSeen) + .Take(top) + .ToList(); + + return Ok(new ProcessTreeAcrossHostsResponse(rows, since, DateTimeOffset.UtcNow)); + } + + /// + /// Builds a directed graph of host -> remote endpoint flows from network + /// snapshots. Nodes are agents (kind=host) plus distinct remote IPs + /// (kind=endpoint). Edge weight is the number of observed connections. + /// + [HttpGet("network-graph")] + public async Task> NetworkGraph( + [FromQuery] int hours = 24, + [FromQuery] int maxEndpoints = 100, + CancellationToken ct = default) + { + var tenantId = User.GetTenantId(); + var windowHours = Math.Clamp(hours, 1, 168); + var since = DateTimeOffset.UtcNow.AddHours(-windowHours); + var cap = Math.Clamp(maxEndpoints, 10, 500); + + var events = await db.TelemetryEvents + .AsNoTracking() + .Where(e => e.TenantId == tenantId + && e.EventType == TelemetryEventType.NetworkSnapshot + && e.OccurredAt >= since) + .Select(e => new { e.AgentId, Hostname = e.Agent!.Hostname, e.Payload }) + .ToListAsync(ct); + + var hostNodes = new Dictionary(); + var endpointNodes = new Dictionary(StringComparer.OrdinalIgnoreCase); + var edges = new Dictionary<(Guid HostId, string Endpoint), int>(); + + foreach (var ev in events) + { + JsonDocument doc; + try { doc = JsonDocument.Parse(ev.Payload); } + catch { continue; } + using (doc) + { + if (!doc.RootElement.TryGetProperty("connections", out var conns) + || conns.ValueKind != JsonValueKind.Array) continue; + if (!hostNodes.TryGetValue(ev.AgentId, out _)) + { + hostNodes[ev.AgentId] = new NetworkGraphNode( + $"host:{ev.AgentId}", ev.Hostname, "host", 0); + } + + foreach (var conn in conns.EnumerateArray()) + { + if (!conn.TryGetProperty("remote_address", out var ra) + || ra.ValueKind != JsonValueKind.String) continue; + var remote = ra.GetString(); + if (string.IsNullOrWhiteSpace(remote) + || IsLoopbackOrUnspecified(remote)) continue; + + if (!endpointNodes.TryGetValue(remote, out var acc)) + { + acc = new EndpointAccumulator(remote); + endpointNodes[remote] = acc; + } + acc.Hits += 1; + var key = (ev.AgentId, remote); + edges[key] = edges.GetValueOrDefault(key) + 1; + } + } + } + + // Cap to the top N busiest endpoints to keep the graph readable. + var topEndpoints = endpointNodes.Values + .OrderByDescending(e => e.Hits) + .Take(cap) + .ToList(); + var topEndpointKeys = topEndpoints.Select(e => e.Address).ToHashSet(StringComparer.OrdinalIgnoreCase); + + var allNodes = new List(hostNodes.Values.Select(h => h with { Weight = 1 })); + allNodes.AddRange(topEndpoints.Select(e => + new NetworkGraphNode($"endpoint:{e.Address}", e.Address, "endpoint", e.Hits))); + + var filteredEdges = edges + .Where(kvp => topEndpointKeys.Contains(kvp.Key.Endpoint)) + .Select(kvp => new NetworkGraphEdge( + $"host:{kvp.Key.HostId}", + $"endpoint:{kvp.Key.Endpoint}", + kvp.Value)) + .OrderByDescending(e => e.Weight) + .ToList(); + + return Ok(new NetworkGraphResponse(allNodes, filteredEdges, since, DateTimeOffset.UtcNow)); + } + + private static bool IsLoopbackOrUnspecified(string address) + { + return address.StartsWith("127.", StringComparison.Ordinal) + || address == "::1" + || address == "0.0.0.0" + || address.StartsWith("169.254.", StringComparison.Ordinal); + } + + private sealed class ProcessHostAccumulator(Guid agentId, string hostname) + { + public Guid AgentId { get; } = agentId; + public string Hostname { get; } = hostname; + public int SeenCount { get; private set; } + public DateTimeOffset LastSeen { get; private set; } + + public void Bump(DateTimeOffset at) + { + SeenCount += 1; + if (at > LastSeen) LastSeen = at; + } + } + + private sealed class EndpointAccumulator(string address) + { + public string Address { get; } = address; + public int Hits { get; set; } + } +} diff --git a/backend/src/Tawny.Api/Controllers/RuleTestController.cs b/backend/src/Tawny.Api/Controllers/RuleTestController.cs new file mode 100644 index 0000000..4772cdb --- /dev/null +++ b/backend/src/Tawny.Api/Controllers/RuleTestController.cs @@ -0,0 +1,53 @@ +using System.Text.Json; +using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Mvc; +using Microsoft.EntityFrameworkCore; +using Tawny.Api.Auth; +using Tawny.Domain; +using Tawny.Domain.Entities; +using Tawny.Infrastructure; +using Tawny.Infrastructure.Hunting; + +namespace Tawny.Api.Controllers; + +public record RuleTestEventBody( + TelemetryEventType EventType, + DateTimeOffset OccurredAt, + JsonElement Payload); + +public record RuleTestRequest(IReadOnlyList Events); + +public record RuleTestResponse( + bool Matched, + string? FailReason, + IReadOnlyList Trace); + +[ApiController] +[Route("api/alert-rules")] +[Authorize(AuthenticationSchemes = TawnyAuthSchemes.WebUser + "," + TawnyAuthSchemes.ApiToken)] +public class RuleTestController(TawnyDbContext db, RuleTestHarness harness) : ControllerBase +{ + /// + /// Run a saved rule against a supplied list of events without touching the DB. + /// Returns whether it would fire and a per-step trace of why or why not. + /// + [HttpPost("{id:guid}/test")] + public async Task> Test( + Guid id, + [FromBody] RuleTestRequest req, + CancellationToken ct) + { + if (req.Events is null || req.Events.Count == 0) + { + return Problem(statusCode: 400, title: "events array is required and must contain at least one event."); + } + var rule = await db.AlertRules.AsNoTracking().FirstOrDefaultAsync(r => r.Id == id, ct); + if (rule is null) return NotFound(); + + var inputs = req.Events + .Select(e => new RuleTestEventInput(e.EventType, e.OccurredAt, e.Payload)) + .ToList(); + var result = harness.Test(rule, inputs); + return Ok(new RuleTestResponse(result.Matched, result.FailReason, result.Trace)); + } +} diff --git a/backend/src/Tawny.Api/Controllers/SequenceRulesController.cs b/backend/src/Tawny.Api/Controllers/SequenceRulesController.cs new file mode 100644 index 0000000..fbb4230 --- /dev/null +++ b/backend/src/Tawny.Api/Controllers/SequenceRulesController.cs @@ -0,0 +1,148 @@ +using System.Text.Json; +using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Mvc; +using Microsoft.EntityFrameworkCore; +using Tawny.Api.Auth; +using Tawny.Api.Models; +using Tawny.Api.Services; +using Tawny.Domain; +using Tawny.Domain.Entities; +using Tawny.Infrastructure; +using Tawny.Infrastructure.Hunting; + +namespace Tawny.Api.Controllers; + +[ApiController] +[Route("api/sequence-rules")] +[Authorize(AuthenticationSchemes = TawnyAuthSchemes.WebUser + "," + TawnyAuthSchemes.ApiToken)] +public class SequenceRulesController( + TawnyDbContext db, + AuditLogger audit, + SequenceRuleEvaluator sequences) : ControllerBase +{ + [HttpGet] + public async Task>> List(CancellationToken ct) + { + var rows = await db.AlertRules + .AsNoTracking() + .Where(r => r.Format == AlertRuleFormat.Sequence) + .OrderBy(r => r.Name) + .ToListAsync(ct); + return Ok(rows.Select(ToResponse).ToList()); + } + + [HttpPost] + [Authorize(AuthenticationSchemes = TawnyAuthSchemes.WebUser + "," + TawnyAuthSchemes.ApiToken, Roles = "Admin")] + public async Task> Create( + [FromBody] CreateSequenceRuleRequest req, + CancellationToken ct) + { + if (string.IsNullOrWhiteSpace(req.Name) || req.Name.Length > 160) + { + return Problem(statusCode: 400, title: "name is required and must be 160 characters or fewer."); + } + if (req.Steps is null || req.Steps.Count < 2) + { + return Problem(statusCode: 400, title: "A sequence rule needs at least two steps."); + } + if (req.WindowSeconds <= 0 || req.WindowSeconds > 86_400) + { + return Problem(statusCode: 400, title: "window_seconds must be between 1 and 86400."); + } + + var definition = new SequenceRuleDefinition( + req.WindowSeconds, + "agent", + req.Steps.Select(s => new SequenceStep(s.Name, s.EventType, s.PayloadPath, s.Operator, s.MatchValue)).ToList()); + + try { SequenceRuleParser.Parse(SequenceRuleParser.Serialize(definition)); } + catch (SequenceRuleException ex) + { + return Problem(statusCode: 400, title: ex.Message); + } + + var now = DateTimeOffset.UtcNow; + var rule = new AlertRule + { + Id = Guid.NewGuid(), + Name = req.Name.Trim(), + Format = AlertRuleFormat.Sequence, + Description = string.IsNullOrWhiteSpace(req.Description) ? null : req.Description.Trim(), + Severity = req.Severity, + Operator = AlertRuleOperator.Exists, + SourceDefinition = SequenceRuleParser.Serialize(definition), + IsEnabled = req.IsEnabled ?? true, + MitreTechniquesJson = SerializeTechniques(req.MitreTechniques), + CreatedAt = now, + UpdatedAt = now, + }; + db.AlertRules.Add(rule); + audit.Add(User, "sequence_rule.create", rule.Id.ToString(), new + { + rule.Name, + step_count = req.Steps.Count, + req.WindowSeconds, + }); + await db.SaveChangesAsync(ct); + sequences.ResetAll(); // wipe in-memory partial state so new rule starts cleanly + return CreatedAtAction(nameof(List), new { id = rule.Id }, ToResponse(rule)); + } + + [HttpDelete("{id:guid}")] + [Authorize(AuthenticationSchemes = TawnyAuthSchemes.WebUser + "," + TawnyAuthSchemes.ApiToken, Roles = "Admin")] + public async Task Delete(Guid id, CancellationToken ct) + { + if (await db.Alerts.AnyAsync(a => a.AlertRuleId == id, ct)) + { + return Problem(statusCode: 409, title: "Sequence rule has alerts and cannot be deleted. Disable it instead."); + } + var deleted = await db.AlertRules + .Where(r => r.Id == id && r.Format == AlertRuleFormat.Sequence) + .ExecuteDeleteAsync(ct); + if (deleted == 0) return NotFound(); + audit.Add(User, "sequence_rule.delete", id.ToString()); + await db.SaveChangesAsync(ct); + sequences.ResetAll(); + return NoContent(); + } + + private static SequenceRuleResponse ToResponse(AlertRule rule) + { + SequenceRuleDefinition definition; + try { definition = SequenceRuleParser.Parse(rule.SourceDefinition ?? ""); } + catch + { + return new SequenceRuleResponse( + rule.Id, rule.Name, rule.Description, rule.Severity, 0, [], [], rule.IsEnabled, rule.CreatedAt, rule.UpdatedAt); + } + return new SequenceRuleResponse( + rule.Id, + rule.Name, + rule.Description, + rule.Severity, + definition.WindowSeconds, + definition.Steps.Select(s => new SequenceStepInput(s.Name, s.EventType, s.PayloadPath, s.Operator, s.MatchValue)).ToList(), + DeserializeTechniques(rule.MitreTechniquesJson), + rule.IsEnabled, + rule.CreatedAt, + rule.UpdatedAt); + } + + private static IReadOnlyList DeserializeTechniques(string? json) + { + if (string.IsNullOrWhiteSpace(json)) return []; + try { return JsonSerializer.Deserialize>(json) ?? []; } + catch { return []; } + } + + private static string? SerializeTechniques(IReadOnlyList? techniques) + { + if (techniques is null || techniques.Count == 0) return null; + var normalized = techniques + .Where(t => !string.IsNullOrWhiteSpace(t)) + .Select(t => t.Trim().ToUpperInvariant()) + .Distinct() + .ToList(); + return normalized.Count == 0 ? null : JsonSerializer.Serialize(normalized); + } +} diff --git a/backend/src/Tawny.Api/Controllers/TelemetryController.cs b/backend/src/Tawny.Api/Controllers/TelemetryController.cs index 2a291ae..d19048a 100644 --- a/backend/src/Tawny.Api/Controllers/TelemetryController.cs +++ b/backend/src/Tawny.Api/Controllers/TelemetryController.cs @@ -190,6 +190,13 @@ private static bool TryParseEventType(string value, out TelemetryEventType event TelemetryEventType.SystemInfo => "system_info", TelemetryEventType.FileIntegrity => "file_integrity", TelemetryEventType.Heartbeat => "heartbeat", + TelemetryEventType.DnsQuery => "dns_query", + TelemetryEventType.ProcessLaunch => "process_launch", + TelemetryEventType.FileEvent => "file_event", + TelemetryEventType.PackageInventory => "package_inventory", + TelemetryEventType.EditorExtension => "editor_extension", + TelemetryEventType.BrowserExtension => "browser_extension", + TelemetryEventType.McpConfig => "mcp_config", _ => type.ToString(), }; } diff --git a/backend/src/Tawny.Api/Controllers/ThreatIntelFeedsController.cs b/backend/src/Tawny.Api/Controllers/ThreatIntelFeedsController.cs new file mode 100644 index 0000000..2fd7b16 --- /dev/null +++ b/backend/src/Tawny.Api/Controllers/ThreatIntelFeedsController.cs @@ -0,0 +1,174 @@ +using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Mvc; +using Microsoft.EntityFrameworkCore; +using Tawny.Api.Auth; +using Tawny.Api.Models; +using Tawny.Api.Services; +using Tawny.Domain; +using Tawny.Domain.Entities; +using Tawny.Infrastructure; +using Tawny.Infrastructure.ThreatIntel; +using Tawny.Jobs; + +namespace Tawny.Api.Controllers; + +[ApiController] +[Route("api/threat-intel-feeds")] +[Authorize(AuthenticationSchemes = TawnyAuthSchemes.WebUser + "," + TawnyAuthSchemes.ApiToken)] +public class ThreatIntelFeedsController( + TawnyDbContext db, + AuditLogger audit, + ThreatIntelFeedsJob job) : ControllerBase +{ + [HttpGet] + public async Task>> List(CancellationToken ct) + { + var tenantId = User.GetTenantId(); + var rows = await db.ThreatIntelFeeds + .AsNoTracking() + .Where(f => f.TenantId == tenantId) + .OrderBy(f => f.Name) + .Select(f => new ThreatIntelFeedResponse( + f.Id, f.Name, f.Kind, f.Url, f.AuthHeaderName, + f.DefaultSeverity, f.IntervalMinutes, f.IsEnabled, + f.Status, f.LastRunAt, f.LastSuccessAt, + f.LastImportedCount, f.LastSkippedCount, f.LastError, + f.CreatedAt, f.UpdatedAt)) + .ToListAsync(ct); + return Ok(rows); + } + + [HttpPost] + [Authorize(AuthenticationSchemes = TawnyAuthSchemes.WebUser + "," + TawnyAuthSchemes.ApiToken, Roles = "Admin")] + public async Task> Create( + [FromBody] CreateThreatIntelFeedRequest req, + CancellationToken ct) + { + if (string.IsNullOrWhiteSpace(req.Name) || req.Name.Length > 160) + { + return Problem(statusCode: 400, title: "name is required and must be 160 characters or fewer."); + } + if (!Uri.TryCreate(req.Url, UriKind.Absolute, out _)) + { + return Problem(statusCode: 400, title: "url must be an absolute URL."); + } + var interval = req.IntervalMinutes ?? 60; + if (interval < 5 || interval > 10_080) + { + return Problem(statusCode: 400, title: "interval_minutes must be between 5 and 10080."); + } + + var now = DateTimeOffset.UtcNow; + var feed = new ThreatIntelFeed + { + Id = Guid.NewGuid(), + TenantId = User.GetTenantId(), + Name = req.Name.Trim(), + Kind = req.Kind, + Url = req.Url.Trim(), + AuthHeaderName = NullIfEmpty(req.AuthHeaderName), + AuthHeaderValueEncrypted = NullIfEmpty(req.AuthHeaderValue), + DefaultSeverity = req.DefaultSeverity ?? AlertSeverity.High, + IntervalMinutes = interval, + IsEnabled = req.IsEnabled ?? true, + CreatedByUserId = TryGetUserId(), + CreatedAt = now, + UpdatedAt = now, + }; + db.ThreatIntelFeeds.Add(feed); + audit.Add(User, "threat_intel_feed.create", feed.Id.ToString(), new + { + feed.Name, feed.Kind, feed.Url, feed.IntervalMinutes, + }); + await db.SaveChangesAsync(ct); + + return CreatedAtAction(nameof(List), new { id = feed.Id }, ToResponse(feed)); + } + + [HttpPut("{id:guid}")] + [Authorize(AuthenticationSchemes = TawnyAuthSchemes.WebUser + "," + TawnyAuthSchemes.ApiToken, Roles = "Admin")] + public async Task> Update( + Guid id, + [FromBody] UpdateThreatIntelFeedRequest req, + CancellationToken ct) + { + if (!Uri.TryCreate(req.Url, UriKind.Absolute, out _)) + { + return Problem(statusCode: 400, title: "url must be an absolute URL."); + } + if (req.IntervalMinutes < 5 || req.IntervalMinutes > 10_080) + { + return Problem(statusCode: 400, title: "interval_minutes must be between 5 and 10080."); + } + + var feed = await db.ThreatIntelFeeds.FirstOrDefaultAsync(f => f.Id == id && f.TenantId == User.GetTenantId(), ct); + if (feed is null) return NotFound(); + + feed.Name = req.Name.Trim(); + feed.Kind = req.Kind; + feed.Url = req.Url.Trim(); + feed.AuthHeaderName = NullIfEmpty(req.AuthHeaderName); + if (!string.IsNullOrWhiteSpace(req.AuthHeaderValue)) + { + feed.AuthHeaderValueEncrypted = req.AuthHeaderValue; + } + feed.DefaultSeverity = req.DefaultSeverity; + feed.IntervalMinutes = req.IntervalMinutes; + feed.IsEnabled = req.IsEnabled; + feed.UpdatedAt = DateTimeOffset.UtcNow; + audit.Add(User, "threat_intel_feed.update", feed.Id.ToString(), new + { + feed.Name, feed.Kind, feed.IntervalMinutes, feed.IsEnabled, + }); + await db.SaveChangesAsync(ct); + return Ok(ToResponse(feed)); + } + + [HttpDelete("{id:guid}")] + [Authorize(AuthenticationSchemes = TawnyAuthSchemes.WebUser + "," + TawnyAuthSchemes.ApiToken, Roles = "Admin")] + public async Task Delete(Guid id, CancellationToken ct) + { + var deleted = await db.ThreatIntelFeeds + .Where(f => f.Id == id && f.TenantId == User.GetTenantId()) + .ExecuteDeleteAsync(ct); + if (deleted == 0) return NotFound(); + audit.Add(User, "threat_intel_feed.delete", id.ToString()); + await db.SaveChangesAsync(ct); + return NoContent(); + } + + [HttpPost("{id:guid}/run")] + [Authorize(AuthenticationSchemes = TawnyAuthSchemes.WebUser + "," + TawnyAuthSchemes.ApiToken, Roles = "Admin")] + public async Task> Run(Guid id, CancellationToken ct) + { + var feed = await db.ThreatIntelFeeds.FirstOrDefaultAsync(f => f.Id == id && f.TenantId == User.GetTenantId(), ct); + if (feed is null) return NotFound(); + // Reset throttle so the job picks it up immediately. + feed.LastRunAt = null; + await db.SaveChangesAsync(ct); + await job.ExecuteAsync(ct); + await db.Entry(feed).ReloadAsync(ct); + audit.Add(User, "threat_intel_feed.run", feed.Id.ToString()); + await db.SaveChangesAsync(ct); + return Ok(ToResponse(feed)); + } + + private static ThreatIntelFeedResponse ToResponse(ThreatIntelFeed f) => new( + f.Id, f.Name, f.Kind, f.Url, f.AuthHeaderName, + f.DefaultSeverity, f.IntervalMinutes, f.IsEnabled, + f.Status, f.LastRunAt, f.LastSuccessAt, + f.LastImportedCount, f.LastSkippedCount, f.LastError, + f.CreatedAt, f.UpdatedAt); + + private static string? NullIfEmpty(string? value) + { + var t = value?.Trim(); + return string.IsNullOrEmpty(t) ? null : t; + } + + private Guid? TryGetUserId() + { + var raw = User.FindFirst(System.Security.Claims.ClaimTypes.NameIdentifier)?.Value; + return Guid.TryParse(raw, out var id) ? id : null; + } +} diff --git a/backend/src/Tawny.Api/Controllers/YaraRulesController.cs b/backend/src/Tawny.Api/Controllers/YaraRulesController.cs new file mode 100644 index 0000000..e1484b3 --- /dev/null +++ b/backend/src/Tawny.Api/Controllers/YaraRulesController.cs @@ -0,0 +1,106 @@ +using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Mvc; +using Microsoft.EntityFrameworkCore; +using Tawny.Api.Auth; +using Tawny.Api.Services; +using Tawny.Domain; +using Tawny.Domain.Entities; +using Tawny.Infrastructure; +using Tawny.Infrastructure.Hunting; + +namespace Tawny.Api.Controllers; + +public record CreateYaraRuleRequest( + string Name, + string? Description, + AlertSeverity Severity, + TelemetryEventType? EventType, + string Definition, + bool? IsEnabled); + +public record YaraRuleResponse( + Guid Id, + string Name, + string? Description, + AlertSeverity Severity, + TelemetryEventType? EventType, + string Definition, + bool IsEnabled, + DateTimeOffset CreatedAt, + DateTimeOffset UpdatedAt); + +[ApiController] +[Route("api/yara-rules")] +[Authorize(AuthenticationSchemes = TawnyAuthSchemes.WebUser + "," + TawnyAuthSchemes.ApiToken)] +public class YaraRulesController(TawnyDbContext db, AuditLogger audit) : ControllerBase +{ + [HttpGet] + public async Task>> List(CancellationToken ct) + { + var rows = await db.AlertRules + .AsNoTracking() + .Where(r => r.Format == AlertRuleFormat.Yara) + .OrderBy(r => r.Name) + .Select(r => new YaraRuleResponse( + r.Id, r.Name, r.Description, r.Severity, r.EventType, + r.SourceDefinition ?? "", r.IsEnabled, r.CreatedAt, r.UpdatedAt)) + .ToListAsync(ct); + return Ok(rows); + } + + [HttpPost] + [Authorize(AuthenticationSchemes = TawnyAuthSchemes.WebUser + "," + TawnyAuthSchemes.ApiToken, Roles = "Admin")] + public async Task> Create( + [FromBody] CreateYaraRuleRequest req, + CancellationToken ct) + { + if (string.IsNullOrWhiteSpace(req.Name) || req.Name.Length > 160) + { + return Problem(statusCode: 400, title: "name is required and must be 160 characters or fewer."); + } + try { YaraLiteParser.Parse(req.Definition); } + catch (YaraLiteException ex) + { + return Problem(statusCode: 400, title: ex.Message); + } + + var now = DateTimeOffset.UtcNow; + var rule = new AlertRule + { + Id = Guid.NewGuid(), + Name = req.Name.Trim(), + Format = AlertRuleFormat.Yara, + Description = string.IsNullOrWhiteSpace(req.Description) ? null : req.Description.Trim(), + EventType = req.EventType, + Severity = req.Severity, + Operator = AlertRuleOperator.Exists, + SourceDefinition = req.Definition.Trim(), + IsEnabled = req.IsEnabled ?? true, + CreatedAt = now, + UpdatedAt = now, + }; + db.AlertRules.Add(rule); + audit.Add(User, "yara_rule.create", rule.Id.ToString(), new { rule.Name }); + await db.SaveChangesAsync(ct); + return CreatedAtAction(nameof(List), new { id = rule.Id }, + new YaraRuleResponse(rule.Id, rule.Name, rule.Description, rule.Severity, rule.EventType, + rule.SourceDefinition!, rule.IsEnabled, rule.CreatedAt, rule.UpdatedAt)); + } + + [HttpDelete("{id:guid}")] + [Authorize(AuthenticationSchemes = TawnyAuthSchemes.WebUser + "," + TawnyAuthSchemes.ApiToken, Roles = "Admin")] + public async Task Delete(Guid id, CancellationToken ct) + { + if (await db.Alerts.AnyAsync(a => a.AlertRuleId == id, ct)) + { + return Problem(statusCode: 409, title: "Rule has alerts; disable it instead."); + } + var deleted = await db.AlertRules + .Where(r => r.Id == id && r.Format == AlertRuleFormat.Yara) + .ExecuteDeleteAsync(ct); + if (deleted == 0) return NotFound(); + audit.Add(User, "yara_rule.delete", id.ToString()); + await db.SaveChangesAsync(ct); + return NoContent(); + } +} diff --git a/backend/src/Tawny.Api/Models/AlertDtos.cs b/backend/src/Tawny.Api/Models/AlertDtos.cs index f18c6fd..495396b 100644 --- a/backend/src/Tawny.Api/Models/AlertDtos.cs +++ b/backend/src/Tawny.Api/Models/AlertDtos.cs @@ -54,6 +54,15 @@ public record ImportIocRulesResponse( IReadOnlyList Rules, IReadOnlyList SkippedIndicators); +public record ImportExposureRulesRequest( + string Definition, + AlertSeverity? Severity, + bool? IsEnabled); + +public record ImportExposureRulesResponse( + IReadOnlyList Rules, + IReadOnlyList SkippedEntries); + public record AlertResponse( long Id, Guid AlertRuleId, @@ -79,4 +88,5 @@ public record AlertResponse( string? SentinelNotificationError, string Title, string? Description, + JsonElement? Enrichment, DateTimeOffset CreatedAt); diff --git a/backend/src/Tawny.Api/Models/CaseDtos.cs b/backend/src/Tawny.Api/Models/CaseDtos.cs new file mode 100644 index 0000000..7770b0b --- /dev/null +++ b/backend/src/Tawny.Api/Models/CaseDtos.cs @@ -0,0 +1,67 @@ +using Tawny.Domain.Entities; + +namespace Tawny.Api.Models; + +public record CreateCaseRequest( + string Title, + string? Summary, + CasePriority? Priority, + IReadOnlyList? AlertIds, + IReadOnlyList? MitreTechniques); + +public record UpdateCaseRequest( + string Title, + string? Summary, + CaseStatus Status, + CasePriority Priority, + Guid? AssignedToUserId, + IReadOnlyList? MitreTechniques); + +public record CaseAlertResponse( + long Id, + long AlertId, + string AlertTitle, + string AlertHostname, + string AlertSeverity, + DateTimeOffset AlertCreatedAt, + DateTimeOffset AddedAt); + +public record CaseNoteResponse( + long Id, + Guid? AuthorUserId, + string Body, + DateTimeOffset CreatedAt); + +public record CaseResponse( + long Id, + string Title, + string? Summary, + CaseStatus Status, + CasePriority Priority, + Guid? AssignedToUserId, + Guid? CreatedByUserId, + int AlertCount, + int NoteCount, + IReadOnlyList MitreTechniques, + DateTimeOffset CreatedAt, + DateTimeOffset UpdatedAt, + DateTimeOffset? ClosedAt); + +public record CaseDetailResponse( + long Id, + string Title, + string? Summary, + CaseStatus Status, + CasePriority Priority, + Guid? AssignedToUserId, + Guid? CreatedByUserId, + IReadOnlyList MitreTechniques, + IReadOnlyList Alerts, + IReadOnlyList Notes, + DateTimeOffset CreatedAt, + DateTimeOffset UpdatedAt, + DateTimeOffset? ClosedAt); + +public record AddCaseAlertRequest(IReadOnlyList AlertIds); + +public record AddCaseNoteRequest(string Body); diff --git a/backend/src/Tawny.Api/Models/HuntDtos.cs b/backend/src/Tawny.Api/Models/HuntDtos.cs index 99511ec..aa786e6 100644 --- a/backend/src/Tawny.Api/Models/HuntDtos.cs +++ b/backend/src/Tawny.Api/Models/HuntDtos.cs @@ -29,7 +29,8 @@ public record CreateSavedHuntRequest( string? ScheduleCron, bool? AlertOnMatch, AlertSeverity? AlertSeverity, - IReadOnlyList? MitreTechniques); + IReadOnlyList? MitreTechniques, + bool? IsShared); public record UpdateSavedHuntRequest( string Name, @@ -39,7 +40,8 @@ public record UpdateSavedHuntRequest( string? ScheduleCron, bool AlertOnMatch, AlertSeverity AlertSeverity, - IReadOnlyList? MitreTechniques); + IReadOnlyList? MitreTechniques, + bool? IsShared); public record SavedHuntResponse( Guid Id, @@ -51,6 +53,8 @@ public record SavedHuntResponse( bool AlertOnMatch, AlertSeverity AlertSeverity, IReadOnlyList MitreTechniques, + bool IsShared, + Guid? CreatedByUserId, DateTimeOffset? LastRunAt, int? LastMatchCount, DateTimeOffset CreatedAt, diff --git a/backend/src/Tawny.Api/Models/SequenceRuleDtos.cs b/backend/src/Tawny.Api/Models/SequenceRuleDtos.cs new file mode 100644 index 0000000..b9195ed --- /dev/null +++ b/backend/src/Tawny.Api/Models/SequenceRuleDtos.cs @@ -0,0 +1,31 @@ +using Tawny.Domain; + +namespace Tawny.Api.Models; + +public record CreateSequenceRuleRequest( + string Name, + string? Description, + AlertSeverity Severity, + int WindowSeconds, + IReadOnlyList Steps, + IReadOnlyList? MitreTechniques, + bool? IsEnabled); + +public record SequenceStepInput( + string Name, + TelemetryEventType EventType, + string? PayloadPath, + AlertRuleOperator Operator, + string? MatchValue); + +public record SequenceRuleResponse( + Guid Id, + string Name, + string? Description, + AlertSeverity Severity, + int WindowSeconds, + IReadOnlyList Steps, + IReadOnlyList MitreTechniques, + bool IsEnabled, + DateTimeOffset CreatedAt, + DateTimeOffset UpdatedAt); diff --git a/backend/src/Tawny.Api/Models/ThreatIntelFeedDtos.cs b/backend/src/Tawny.Api/Models/ThreatIntelFeedDtos.cs new file mode 100644 index 0000000..547ea61 --- /dev/null +++ b/backend/src/Tawny.Api/Models/ThreatIntelFeedDtos.cs @@ -0,0 +1,41 @@ +using Tawny.Domain; + +namespace Tawny.Api.Models; + +public record CreateThreatIntelFeedRequest( + string Name, + ThreatIntelFeedKind Kind, + string Url, + string? AuthHeaderName, + string? AuthHeaderValue, + AlertSeverity? DefaultSeverity, + int? IntervalMinutes, + bool? IsEnabled); + +public record UpdateThreatIntelFeedRequest( + string Name, + ThreatIntelFeedKind Kind, + string Url, + string? AuthHeaderName, + string? AuthHeaderValue, + AlertSeverity DefaultSeverity, + int IntervalMinutes, + bool IsEnabled); + +public record ThreatIntelFeedResponse( + Guid Id, + string Name, + ThreatIntelFeedKind Kind, + string Url, + string? AuthHeaderName, + AlertSeverity DefaultSeverity, + int IntervalMinutes, + bool IsEnabled, + ThreatIntelFeedStatus Status, + DateTimeOffset? LastRunAt, + DateTimeOffset? LastSuccessAt, + int LastImportedCount, + int LastSkippedCount, + string? LastError, + DateTimeOffset CreatedAt, + DateTimeOffset UpdatedAt); diff --git a/backend/src/Tawny.Api/Program.cs b/backend/src/Tawny.Api/Program.cs index da4bf99..63d1853 100644 --- a/backend/src/Tawny.Api/Program.cs +++ b/backend/src/Tawny.Api/Program.cs @@ -15,6 +15,7 @@ using Tawny.Api.Services; using Tawny.Infrastructure; using Tawny.Infrastructure.Hunting; +using Tawny.Infrastructure.ThreatIntel; using Tawny.Jobs; var builder = WebApplication.CreateBuilder(args); @@ -28,6 +29,7 @@ builder.Services.Configure(builder.Configuration.GetSection("Tawny:Wazuh")); builder.Services.Configure(builder.Configuration.GetSection("Tawny:Slack")); builder.Services.Configure(builder.Configuration.GetSection("Tawny:Sentinel")); +builder.Services.Configure(builder.Configuration.GetSection("Tawny:Reputation")); builder.Services.Configure(TawnyAuthSchemes.WebUser, opt => { opt.HmacSecret = builder.Configuration["Tawny:WebUserHmacSecret"] ?? ""; @@ -41,10 +43,15 @@ builder.Services.AddScoped(); builder.Services.AddScoped(); builder.Services.AddScoped(); +builder.Services.AddScoped(); builder.Services.AddSingleton(); builder.Services.AddScoped(); builder.Services.AddScoped(); +builder.Services.AddSingleton(); +builder.Services.AddSingleton(); builder.Services.AddSingleton(); +builder.Services.AddHttpClient(); +builder.Services.AddHttpClient(); builder.Services.AddSingleton(); builder.Services.AddHttpClient(); builder.Services.AddHttpClient(); @@ -123,6 +130,8 @@ await context.HttpContext.Response.WriteAsJsonAsync(new builder.Services.AddScoped(); builder.Services.AddScoped(); builder.Services.AddScoped(); + builder.Services.AddScoped(); + builder.Services.AddScoped(); builder.Services.AddHttpClient(); builder.Services.AddHangfire(cfg => cfg @@ -175,6 +184,10 @@ await context.HttpContext.Response.WriteAsJsonAsync(new "check-agent-releases", j => j.ExecuteAsync(default), Cron.Hourly); RecurringJob.AddOrUpdate( "scheduled-hunts", j => j.ExecuteAsync(default), "*/5 * * * *"); + RecurringJob.AddOrUpdate( + "threat-intel-feeds", j => j.ExecuteAsync(default), "*/10 * * * *"); + RecurringJob.AddOrUpdate( + "reputation-enrichment", j => j.ExecuteAsync(default), "*/5 * * * *"); } app.Run(); diff --git a/backend/src/Tawny.Api/Services/AgentEventBroker.cs b/backend/src/Tawny.Api/Services/AgentEventBroker.cs index 91b147d..8e014ae 100644 --- a/backend/src/Tawny.Api/Services/AgentEventBroker.cs +++ b/backend/src/Tawny.Api/Services/AgentEventBroker.cs @@ -27,26 +27,28 @@ public class AgentEventBroker public IDisposable Subscribe(Guid tenantId, Guid agentId, out Channel channel) { - channel = Channel.CreateBounded(new BoundedChannelOptions(256) + // Build the channel into a local first; C# forbids capturing an `out` + // parameter inside a lambda (its lifetime isn't guaranteed past the + // caller's stack frame), so the dispose lambda must close over the + // local copy instead of `channel` itself. + var created = Channel.CreateBounded(new BoundedChannelOptions(256) { FullMode = BoundedChannelFullMode.DropOldest, SingleReader = true, SingleWriter = false, }); + channel = created; var perTenant = _subscribers.GetOrAdd(tenantId, _ => new ConcurrentDictionary>()); var subscriberId = Guid.NewGuid(); - // Key by (subscriberId XOR agentId) so multiple subscribers on same agent coexist. - // We store the filter agentId alongside via the channel writer's queue items being already filtered. - // Implementation detail: store as a list of (agentId, channel) under tenant for routing. - perTenant.TryAdd(subscriberId, channel); + perTenant.TryAdd(subscriberId, created); _filters[subscriberId] = agentId; return new Subscription(() => { perTenant.TryRemove(subscriberId, out _); _filters.TryRemove(subscriberId, out _); - channel.Writer.TryComplete(); + created.Writer.TryComplete(); }); } diff --git a/backend/src/Tawny.Api/Services/AlertRuleEvaluator.cs b/backend/src/Tawny.Api/Services/AlertRuleEvaluator.cs index 75c1276..21b7d2a 100644 --- a/backend/src/Tawny.Api/Services/AlertRuleEvaluator.cs +++ b/backend/src/Tawny.Api/Services/AlertRuleEvaluator.cs @@ -8,7 +8,10 @@ namespace Tawny.Api.Services; -public class AlertRuleEvaluator(TawnyDbContext db, SuppressionEvaluator suppressions) +public class AlertRuleEvaluator( + TawnyDbContext db, + SuppressionEvaluator suppressions, + SequenceRuleEvaluator sequences) { public async Task> EvaluateAsync( Agent agent, @@ -26,7 +29,13 @@ public async Task> EvaluateAsync( .Where(r => r.IsEnabled && (r.EventType == null || eventTypes.Contains(r.EventType.Value))) .ToListAsync(ct); - if (rules.Count == 0) + // Sequence rules need to see every event regardless of EventType filter, + // because each step can target a different type. Load them separately. + var sequenceRules = await db.AlertRules + .Where(r => r.IsEnabled && r.Format == AlertRuleFormat.Sequence) + .ToListAsync(ct); + + if (rules.Count == 0 && sequenceRules.Count == 0) { return []; } @@ -37,6 +46,7 @@ public async Task> EvaluateAsync( using var payload = JsonDocument.Parse(telemetryEvent.Payload); foreach (var rule in rules) { + if (rule.Format == AlertRuleFormat.Sequence) continue; // handled below if (rule.EventType is not null && rule.EventType.Value != telemetryEvent.EventType) { continue; @@ -60,6 +70,27 @@ public async Task> EvaluateAsync( } } + foreach (var rule in sequenceRules) + { + SequenceRuleDefinition definition; + try { definition = SequenceRuleParser.Parse(rule.SourceDefinition ?? ""); } + catch { continue; } + var matches = sequences.Evaluate(rule, definition, agent, events, now); + foreach (var match in matches) + { + candidates.Add(new Alert + { + AlertRuleId = rule.Id, + AgentId = agent.Id, + TelemetryEventId = match.TriggeringEventId, + Severity = rule.Severity, + Title = $"{rule.Name} on {agent.Hostname}", + Description = BuildSequenceDescription(rule, match), + CreatedAt = now, + }); + } + } + if (candidates.Count == 0) { return []; @@ -79,8 +110,50 @@ public async Task> EvaluateAsync( return emitted; } + private static string BuildSequenceDescription(AlertRule rule, SequenceMatch match) + { + var steps = string.Join(" -> ", match.Trail.Select(s => s.Name)); + return $"Sequence '{rule.Name}' completed: {steps}."; + } + private static bool Matches(AlertRule rule, JsonElement payload) { + // Package exposure: match (ecosystem, name, version_pattern) against an + // inventory event. Inspired by Perplexity's Bumblebee scanner. + if (rule.Format == AlertRuleFormat.PackageExposure && !string.IsNullOrWhiteSpace(rule.SourceDefinition)) + { + try + { + var definition = PackageExposureParser.Parse(rule.SourceDefinition); + return PackageExposureEvaluator.Evaluate(definition, payload); + } + catch (PackageExposureException) + { + return false; + } + } + + // YARA-lite: match strings against the raw payload text. + if (rule.Format == AlertRuleFormat.Yara && !string.IsNullOrWhiteSpace(rule.SourceDefinition)) + { + try + { + var definition = YaraLiteParser.Parse(rule.SourceDefinition); + return YaraLiteEvaluator.Evaluate(definition, payload.GetRawText()); + } + catch (YaraLiteException) + { + return false; + } + } + + // Compiled boolean tree (Sigma AND/OR/NOT, 1 of selection_*, all of selection_*). + if (!string.IsNullOrWhiteSpace(rule.CompiledExpressionJson)) + { + var tree = SigmaExpressionSerializer.Deserialize(rule.CompiledExpressionJson); + return tree is not null && SigmaExpressionEvaluator.Evaluate(tree, payload); + } + if (string.IsNullOrWhiteSpace(rule.PayloadPath)) { return true; diff --git a/backend/src/Tawny.Api/Services/ExposureRuleImporter.cs b/backend/src/Tawny.Api/Services/ExposureRuleImporter.cs new file mode 100644 index 0000000..9d68fdc --- /dev/null +++ b/backend/src/Tawny.Api/Services/ExposureRuleImporter.cs @@ -0,0 +1,315 @@ +using System.Text.Json; +using Tawny.Domain; +using Tawny.Domain.Entities; +using Tawny.Infrastructure.Hunting; + +namespace Tawny.Api.Services; + +public class ExposureRuleException(string message) : Exception(message); + +public record ExposureImportResult( + IReadOnlyList Rules, + IReadOnlyList SkippedEntries); + +/// +/// Imports package-exposure rules from supported advisory formats. The two +/// shapes we accept today are: +/// +/// 1. OSV (osv.dev / GitHub Advisory Database) — a JSON object with +/// `id`, `summary`, and an `affected[]` array whose entries carry +/// `package: {ecosystem, name}` plus `ranges[]` or `versions[]`. +/// 2. Simple list — a JSON array of plain objects: +/// `[{ "ecosystem": "npm", "name": "x", "version_pattern": "<=1.2.3" }]` +/// +/// Each affected (ecosystem, name, version_pattern) becomes a separate +/// AlertRule with Format = PackageExposure so the evaluator can short-circuit +/// after a single match. +/// +public class ExposureRuleImporter +{ + private const int MaxRulesPerImport = 1_000; + + public ExposureImportResult Import( + string definition, + AlertSeverity severity, + bool isEnabled, + DateTimeOffset now) + { + if (string.IsNullOrWhiteSpace(definition)) + { + throw new ExposureRuleException("Definition is empty."); + } + + using var doc = ParseJson(definition); + var root = doc.RootElement; + var rules = new List(); + var skipped = new List(); + + if (root.ValueKind == JsonValueKind.Array) + { + foreach (var entry in root.EnumerateArray()) + { + if (rules.Count >= MaxRulesPerImport) + { + skipped.Add($"Stopped at the import limit of {MaxRulesPerImport} rules."); + break; + } + var compiled = CompileSimpleEntry(entry, severity, isEnabled, now); + if (compiled is null) skipped.Add(EntryFingerprint(entry)); + else rules.Add(compiled); + } + } + else if (root.ValueKind == JsonValueKind.Object) + { + // OSV: top-level object with `affected[]`. Multiple advisories + // bundled together as `{ "advisories": [...] }` are also accepted. + if (root.TryGetProperty("advisories", out var bundle) && bundle.ValueKind == JsonValueKind.Array) + { + foreach (var advisory in bundle.EnumerateArray()) + { + AppendOsv(advisory, severity, isEnabled, now, rules, skipped); + } + } + else + { + AppendOsv(root, severity, isEnabled, now, rules, skipped); + } + } + else + { + throw new ExposureRuleException("Definition must be a JSON object (OSV) or an array of {ecosystem, name, version_pattern}."); + } + + if (rules.Count == 0) + { + throw new ExposureRuleException("No exposure rules could be compiled. Check ecosystem/name fields."); + } + return new ExposureImportResult(rules, skipped); + } + + private static JsonDocument ParseJson(string definition) + { + try { return JsonDocument.Parse(definition); } + catch (JsonException ex) + { + throw new ExposureRuleException($"Could not parse JSON: {ex.Message}"); + } + } + + private static void AppendOsv( + JsonElement advisory, + AlertSeverity severity, + bool isEnabled, + DateTimeOffset now, + List rules, + List skipped) + { + var id = advisory.TryGetProperty("id", out var idEl) && idEl.ValueKind == JsonValueKind.String + ? idEl.GetString() + : null; + var summary = advisory.TryGetProperty("summary", out var sEl) && sEl.ValueKind == JsonValueKind.String + ? sEl.GetString() + : null; + var advisoryUrl = ExtractFirstUrl(advisory); + + if (!advisory.TryGetProperty("affected", out var affectedArray) + || affectedArray.ValueKind != JsonValueKind.Array) + { + skipped.Add(id ?? ""); + return; + } + + foreach (var affected in affectedArray.EnumerateArray()) + { + if (!affected.TryGetProperty("package", out var pkg)) continue; + if (!pkg.TryGetProperty("ecosystem", out var ecoEl) || !pkg.TryGetProperty("name", out var nameEl)) continue; + + var ecosystem = ecoEl.GetString(); + var name = nameEl.GetString(); + if (string.IsNullOrWhiteSpace(ecosystem) || string.IsNullOrWhiteSpace(name)) continue; + + var versionPattern = BuildOsvVersionPattern(affected); + if (rules.Count >= MaxRulesPerImport) return; + rules.Add(BuildRule( + ecosystem: NormalizeEcosystem(ecosystem), + name: name, + versionPattern: versionPattern, + advisoryId: id, + advisoryUrl: advisoryUrl, + summary: summary, + severity: severity, + isEnabled: isEnabled, + now: now)); + } + } + + private static string? BuildOsvVersionPattern(JsonElement affected) + { + // OSV `versions[]` is the cleanest signal — explicit list of affected versions. + if (affected.TryGetProperty("versions", out var versions) && versions.ValueKind == JsonValueKind.Array) + { + var list = new List(); + foreach (var v in versions.EnumerateArray()) + { + if (v.ValueKind == JsonValueKind.String && !string.IsNullOrWhiteSpace(v.GetString())) + { + list.Add(v.GetString()!); + } + } + if (list.Count > 0) return string.Join(",", list); + } + + // Fall back to ranges[].events[] -> compile {introduced, fixed} into >=X,(); + foreach (var range in ranges.EnumerateArray()) + { + if (!range.TryGetProperty("events", out var events) || events.ValueKind != JsonValueKind.Array) continue; + string? introduced = null; + string? fixedAt = null; + foreach (var ev in events.EnumerateArray()) + { + if (ev.TryGetProperty("introduced", out var i) && i.ValueKind == JsonValueKind.String) introduced = i.GetString(); + if (ev.TryGetProperty("fixed", out var f) && f.ValueKind == JsonValueKind.String) fixedAt = f.GetString(); + } + // OSV "introduced: 0" means "from the beginning" — omit the lower bound. + if (introduced is not null and not "0") fragments.Add($">={introduced}"); + if (fixedAt is not null) fragments.Add($"<{fixedAt}"); + } + if (fragments.Count > 0) return string.Join(",", fragments); + } + + return null; // No version constraint -> "any version of this package is affected." + } + + private static string? ExtractFirstUrl(JsonElement advisory) + { + if (!advisory.TryGetProperty("references", out var refs) || refs.ValueKind != JsonValueKind.Array) + { + return null; + } + foreach (var r in refs.EnumerateArray()) + { + if (r.TryGetProperty("url", out var u) && u.ValueKind == JsonValueKind.String) + { + return u.GetString(); + } + } + return null; + } + + private static AlertRule? CompileSimpleEntry( + JsonElement entry, + AlertSeverity severity, + bool isEnabled, + DateTimeOffset now) + { + if (!entry.TryGetProperty("ecosystem", out var ecoEl) || ecoEl.ValueKind != JsonValueKind.String) return null; + if (!entry.TryGetProperty("name", out var nameEl) || nameEl.ValueKind != JsonValueKind.String) return null; + var ecosystem = ecoEl.GetString(); + var name = nameEl.GetString(); + if (string.IsNullOrWhiteSpace(ecosystem) || string.IsNullOrWhiteSpace(name)) return null; + + var versionPattern = entry.TryGetProperty("version_pattern", out var vp) && vp.ValueKind == JsonValueKind.String + ? vp.GetString() + : null; + var advisoryId = entry.TryGetProperty("advisory_id", out var aid) && aid.ValueKind == JsonValueKind.String + ? aid.GetString() + : null; + var advisoryUrl = entry.TryGetProperty("advisory_url", out var aurl) && aurl.ValueKind == JsonValueKind.String + ? aurl.GetString() + : null; + + return BuildRule( + ecosystem: NormalizeEcosystem(ecosystem), + name: name, + versionPattern: versionPattern, + advisoryId: advisoryId, + advisoryUrl: advisoryUrl, + summary: null, + severity: severity, + isEnabled: isEnabled, + now: now); + } + + private static AlertRule BuildRule( + string ecosystem, + string name, + string? versionPattern, + string? advisoryId, + string? advisoryUrl, + string? summary, + AlertSeverity severity, + bool isEnabled, + DateTimeOffset now) + { + var definition = new PackageExposureDefinition( + ecosystem, + name, + string.IsNullOrWhiteSpace(versionPattern) ? null : versionPattern, + advisoryId, + advisoryUrl); + + var eventType = ExposureEventType(ecosystem); + var displayPattern = versionPattern ?? "any"; + var externalId = $"exposure:{ecosystem}:{name}:{displayPattern}"; + if (advisoryId is { Length: > 0 }) externalId = $"{externalId}:{advisoryId}"; + if (externalId.Length > 128) externalId = externalId[..128]; + + return new AlertRule + { + Id = Guid.NewGuid(), + Name = $"Exposed {ecosystem}/{name} {displayPattern}", + Format = AlertRuleFormat.PackageExposure, + ExternalId = externalId, + Description = summary ?? $"Installed package {ecosystem}/{name} matches version pattern {displayPattern}.", + EventType = eventType, + Severity = severity, + Operator = AlertRuleOperator.Exists, + SourceDefinition = PackageExposureParser.Serialize(definition), + IsEnabled = isEnabled, + CreatedAt = now, + UpdatedAt = now, + }; + } + + /// + /// Maps an ecosystem token to the telemetry event type it scopes against, + /// so the evaluator only touches the relevant batch. Editor / browser / + /// MCP "ecosystems" are first-class because Bumblebee shipped them that + /// way and they map cleanly onto our new event types. + /// + private static TelemetryEventType ExposureEventType(string ecosystem) => ecosystem.ToLowerInvariant() switch + { + "editor-extension" or "editor_extension" => TelemetryEventType.EditorExtension, + "browser-extension" or "browser_extension" => TelemetryEventType.BrowserExtension, + "mcp" or "mcp_server" or "mcp-server" => TelemetryEventType.McpConfig, + _ => TelemetryEventType.PackageInventory, + }; + + private static string NormalizeEcosystem(string ecosystem) => ecosystem.Trim().ToLowerInvariant() switch + { + // OSV uses TitleCase / mixed; we normalize so the evaluator can do a simple equals. + "go" or "go modules" => "go", + "npm" or "node" => "npm", + "pypi" or "python" => "pypi", + "rubygems" or "gem" => "rubygems", + "packagist" or "composer" => "packagist", + "crates.io" or "rust" => "crates.io", + "maven" or "java" => "maven", + "nuget" or ".net" => "nuget", + var s => s, + }; + + private static string EntryFingerprint(JsonElement entry) + { + var eco = entry.TryGetProperty("ecosystem", out var e) && e.ValueKind == JsonValueKind.String + ? e.GetString() + : "?"; + var name = entry.TryGetProperty("name", out var n) && n.ValueKind == JsonValueKind.String + ? n.GetString() + : "?"; + return $"{eco}/{name}"; + } +} diff --git a/backend/src/Tawny.Api/Services/IocRuleImporter.cs b/backend/src/Tawny.Api/Services/IocRuleImporter.cs index 77b895a..43efe10 100644 --- a/backend/src/Tawny.Api/Services/IocRuleImporter.cs +++ b/backend/src/Tawny.Api/Services/IocRuleImporter.cs @@ -59,6 +59,18 @@ public IocImportResult Import( } rules.Add(rule); + + // Domain IoCs additionally produce a dns_query.qname rule so they + // match against DNS telemetry (Phase 2) without losing the existing + // command-line fallback for installs that don't ship DNS events. + if (indicator.Type is IocIndicatorType.Domain) + { + var dnsRule = CompileDnsRule(indicator, normalizedFormat, severity, isEnabled, now); + if (dnsRule is not null) + { + rules.Add(dnsRule); + } + } } if (rules.Count == 0) @@ -113,6 +125,36 @@ public IocImportResult Import( }; } + private static AlertRule? CompileDnsRule( + IocIndicator indicator, + string sourceFormat, + AlertSeverity severity, + bool isEnabled, + DateTimeOffset now) + { + if (indicator.Type is not IocIndicatorType.Domain) return null; + var value = indicator.Value.ToLowerInvariant(); + var externalId = ExternalId(indicator); + if (externalId.Length >= 124) externalId = externalId[..124]; + return new AlertRule + { + Id = Guid.NewGuid(), + Name = $"{BuildRuleName(indicator)} (DNS)", + Format = AlertRuleFormat.Ioc, + ExternalId = externalId + ":dns", + Description = $"{BuildDescription(indicator, sourceFormat)} Matches DNS queries by qname.", + EventType = TelemetryEventType.DnsQuery, + Severity = severity, + Operator = AlertRuleOperator.Equals, + PayloadPath = "qname", + MatchValue = value, + SourceDefinition = indicator.SourceDefinition, + IsEnabled = isEnabled, + CreatedAt = now, + UpdatedAt = now, + }; + } + private static string BuildRuleName(IocIndicator indicator) { var label = indicator.Type switch diff --git a/backend/src/Tawny.Api/Services/SigmaRuleImporter.cs b/backend/src/Tawny.Api/Services/SigmaRuleImporter.cs index ad268ba..09455b3 100644 --- a/backend/src/Tawny.Api/Services/SigmaRuleImporter.cs +++ b/backend/src/Tawny.Api/Services/SigmaRuleImporter.cs @@ -1,6 +1,9 @@ +using System.Text; using System.Text.Json; +using System.Text.RegularExpressions; using Tawny.Domain; using Tawny.Domain.Entities; +using Tawny.Infrastructure.Hunting; using YamlDotNet.RepresentationModel; namespace Tawny.Api.Services; @@ -39,17 +42,27 @@ public AlertRule Import(string yaml, bool isEnabled, DateTimeOffset now) { throw new SigmaRuleException("Sigma rule detection.condition is required."); } - if (condition.Contains(' ', StringComparison.Ordinal)) + + // Map every selection name (except `condition`) to its compiled SigmaNode. + var selections = new Dictionary(StringComparer.OrdinalIgnoreCase); + foreach (var pair in detection.Children) { - throw new SigmaRuleException("Only a single Sigma selection condition is supported for now."); + if (pair.Key is not YamlScalarNode keyNode || keyNode.Value is null) continue; + if (string.Equals(keyNode.Value, "condition", StringComparison.OrdinalIgnoreCase)) continue; + if (pair.Value is not YamlMappingNode selectionMap) + { + throw new SigmaRuleException($"Selection '{keyNode.Value}' must be a mapping."); + } + selections[keyNode.Value] = CompileSelectionNode(selectionMap); } - var selection = Mapping(detection, condition) - ?? throw new SigmaRuleException($"Sigma selection '{condition}' was not found."); - var predicate = CompileSelection(selection); - var logsource = Mapping(root, "logsource"); + if (selections.Count == 0) + { + throw new SigmaRuleException("Sigma rule needs at least one named selection block."); + } - return new AlertRule + var logsource = Mapping(root, "logsource"); + var rule = new AlertRule { Id = Guid.NewGuid(), Name = title.Trim(), @@ -58,15 +71,61 @@ public AlertRule Import(string yaml, bool isEnabled, DateTimeOffset now) Description = Normalize(Scalar(root, "description")), EventType = MapEventType(logsource), Severity = MapSeverity(Scalar(root, "level")), - Operator = predicate.Operator, - PayloadPath = predicate.PayloadPath, - MatchValue = predicate.MatchValue, SourceDefinition = yaml, IsEnabled = isEnabled, MitreTechniquesJson = ExtractMitreTechniques(root), CreatedAt = now, UpdatedAt = now, }; + + // Fast path: single named selection referenced directly. Stays as a + // single-predicate rule so the existing legacy fields and the existing + // UI keep working without change. + if (selections.Count == 1 + && selections.TryGetValue(condition.Trim(), out var solo) + && solo is SigmaFieldPredicate predicate + && predicate.Values.Count > 0) + { + rule.Operator = predicate.Operator; + rule.PayloadPath = predicate.PayloadPath; + rule.MatchValue = predicate.Values.Count == 1 + ? predicate.Values[0] + : JsonSerializer.Serialize(predicate.Values, JsonOptions); + return rule; + } + + // General path: parse the condition into a SigmaNode tree by resolving + // names + globs against the compiled selections. + var tree = SigmaConditionParser.Parse(condition, selections); + rule.CompiledExpressionJson = SigmaExpressionSerializer.Serialize(tree); + // Leave legacy predicate fields null — the evaluator falls back to CompiledExpression. + return rule; + } + + private static SigmaNode CompileSelectionNode(YamlMappingNode selection) + { + if (selection.Children.Count == 0) + { + throw new SigmaRuleException("Selection must have at least one field predicate."); + } + var children = new List(); + foreach (var pair in selection.Children) + { + if (pair.Key is not YamlScalarNode keyNode || string.IsNullOrWhiteSpace(keyNode.Value)) + { + throw new SigmaRuleException("Selection field name must be a scalar."); + } + var (fieldRaw, op) = ParseField(keyNode.Value); + var field = NormalizeField(fieldRaw); + var values = Values(pair.Value); + if (values.Count == 0 && op != AlertRuleOperator.Exists) + { + throw new SigmaRuleException("Selection value is required."); + } + children.Add(new SigmaFieldPredicate(field, op, values)); + } + // Multiple fields in the same selection mapping => AND across them (Sigma semantics). + return children.Count == 1 ? children[0] : new SigmaAnd(children); } private static string? ExtractMitreTechniques(YamlMappingNode root) @@ -95,34 +154,6 @@ public AlertRule Import(string yaml, bool isEnabled, DateTimeOffset now) return JsonSerializer.Serialize(techniques.Distinct().ToList(), JsonOptions); } - private static CompiledPredicate CompileSelection(YamlMappingNode selection) - { - if (selection.Children.Count != 1) - { - throw new SigmaRuleException("Only one field predicate per Sigma selection is supported for now."); - } - - var pair = selection.Children.Single(); - if (pair.Key is not YamlScalarNode keyNode || string.IsNullOrWhiteSpace(keyNode.Value)) - { - throw new SigmaRuleException("Sigma selection field must be a scalar."); - } - - var (field, op) = ParseField(keyNode.Value); - var values = Values(pair.Value); - if (values.Count == 0) - { - throw new SigmaRuleException("Sigma selection value is required."); - } - - return new CompiledPredicate( - NormalizeField(field), - op, - values.Count == 1 - ? values[0] - : JsonSerializer.Serialize(values, JsonOptions)); - } - private static string NormalizeField(string field) => field switch { "Image" or "process.name" or "process.executable" => "processes.name", @@ -229,11 +260,160 @@ private static List Values(YamlNode node) var trimmed = value?.Trim(); return string.IsNullOrEmpty(trimmed) ? null : trimmed; } +} - private sealed record CompiledPredicate( - string PayloadPath, - AlertRuleOperator Operator, - string MatchValue); +/// +/// Tiny recursive-descent parser for Sigma `condition:` strings. +/// Supports: selection_name | not | and | or | () | "1 of name_*" | "all of name_*" +/// Globs are resolved against the dictionary of compiled selections. +/// +internal static class SigmaConditionParser +{ + public static SigmaNode Parse(string condition, IReadOnlyDictionary selections) + { + var tokens = Tokenize(condition); + var pos = 0; + var node = ParseOr(tokens, ref pos, selections); + if (pos < tokens.Count) + { + throw new SigmaRuleException($"Unexpected token '{tokens[pos]}' at end of condition."); + } + return node; + } + + private static SigmaNode ParseOr(IReadOnlyList tokens, ref int pos, IReadOnlyDictionary selections) + { + var left = ParseAnd(tokens, ref pos, selections); + while (pos < tokens.Count && string.Equals(tokens[pos], "or", StringComparison.OrdinalIgnoreCase)) + { + pos++; + var right = ParseAnd(tokens, ref pos, selections); + left = new SigmaOr(Flatten(left, right)); + } + return left; + } + + private static SigmaNode ParseAnd(IReadOnlyList tokens, ref int pos, IReadOnlyDictionary selections) + { + var left = ParseUnary(tokens, ref pos, selections); + while (pos < tokens.Count && string.Equals(tokens[pos], "and", StringComparison.OrdinalIgnoreCase)) + { + pos++; + var right = ParseUnary(tokens, ref pos, selections); + left = new SigmaAnd(Flatten(left, right)); + } + return left; + } + + private static SigmaNode ParseUnary(IReadOnlyList tokens, ref int pos, IReadOnlyDictionary selections) + { + if (pos >= tokens.Count) throw new SigmaRuleException("Unexpected end of condition."); + var token = tokens[pos]; + if (string.Equals(token, "not", StringComparison.OrdinalIgnoreCase)) + { + pos++; + return new SigmaNot(ParseUnary(tokens, ref pos, selections)); + } + if (token == "(") + { + pos++; + var inner = ParseOr(tokens, ref pos, selections); + if (pos >= tokens.Count || tokens[pos] != ")") + { + throw new SigmaRuleException("Expected ')'."); + } + pos++; + return inner; + } + if (string.Equals(token, "1", StringComparison.Ordinal) + || string.Equals(token, "all", StringComparison.OrdinalIgnoreCase)) + { + var quantifier = token; + pos++; + if (pos >= tokens.Count || !string.Equals(tokens[pos], "of", StringComparison.OrdinalIgnoreCase)) + { + throw new SigmaRuleException("Expected 'of' after quantifier."); + } + pos++; + if (pos >= tokens.Count) throw new SigmaRuleException("Expected pattern after 'of'."); + var pattern = tokens[pos]; + pos++; + var matched = ResolveGlob(pattern, selections); + if (matched.Count == 0) + { + throw new SigmaRuleException($"No selections matched pattern '{pattern}'."); + } + return string.Equals(quantifier, "1", StringComparison.Ordinal) + ? new SigmaAnyOf(matched) + : new SigmaAllOf(matched); + } + // Bare selection name. + pos++; + if (!selections.TryGetValue(token, out var selection)) + { + throw new SigmaRuleException($"Unknown selection '{token}' in condition."); + } + return selection; + } + + private static List ResolveGlob(string pattern, IReadOnlyDictionary selections) + { + var matched = new List(); + var regex = new Regex("^" + Regex.Escape(pattern).Replace("\\*", ".*") + "$", + RegexOptions.IgnoreCase); + foreach (var (name, node) in selections) + { + if (regex.IsMatch(name)) matched.Add(node); + } + return matched; + } + + private static IReadOnlyList Flatten(SigmaNode left, SigmaNode right) where T : SigmaNode + { + var list = new List(); + AddFlattened(list, left); + AddFlattened(list, right); + return list; + } + + private static void AddFlattened(List list, SigmaNode node) where T : SigmaNode + { + if (typeof(T) == typeof(SigmaAnd) && node is SigmaAnd and) + { + list.AddRange(and.Children); + return; + } + if (typeof(T) == typeof(SigmaOr) && node is SigmaOr or) + { + list.AddRange(or.Children); + return; + } + list.Add(node); + } + + private static List Tokenize(string condition) + { + var tokens = new List(); + var i = 0; + while (i < condition.Length) + { + var c = condition[i]; + if (char.IsWhiteSpace(c)) { i++; continue; } + if (c == '(' || c == ')') + { + tokens.Add(c.ToString()); + i++; + continue; + } + var start = i; + while (i < condition.Length && !char.IsWhiteSpace(condition[i]) && condition[i] != '(' && condition[i] != ')') + { + i++; + } + tokens.Add(condition[start..i]); + } + return tokens; + } } public class SigmaRuleException(string message) : Exception(message); diff --git a/backend/src/Tawny.Domain/Entities/Alert.cs b/backend/src/Tawny.Domain/Entities/Alert.cs index a09c37c..06123a5 100644 --- a/backend/src/Tawny.Domain/Entities/Alert.cs +++ b/backend/src/Tawny.Domain/Entities/Alert.cs @@ -16,6 +16,7 @@ public class Alert public string? SentinelNotificationError { get; set; } public required string Title { get; set; } public string? Description { get; set; } + public string? EnrichmentJson { get; set; } public DateTimeOffset CreatedAt { get; set; } public AlertRule? AlertRule { get; set; } diff --git a/backend/src/Tawny.Domain/Entities/AlertRule.cs b/backend/src/Tawny.Domain/Entities/AlertRule.cs index 54bddaf..b669b0d 100644 --- a/backend/src/Tawny.Domain/Entities/AlertRule.cs +++ b/backend/src/Tawny.Domain/Entities/AlertRule.cs @@ -13,6 +13,7 @@ public class AlertRule public string? PayloadPath { get; set; } public string? MatchValue { get; set; } public string? SourceDefinition { get; set; } + public string? CompiledExpressionJson { get; set; } public bool IsEnabled { get; set; } = true; public string? MitreTechniquesJson { get; set; } public DateTimeOffset CreatedAt { get; set; } diff --git a/backend/src/Tawny.Domain/Entities/Case.cs b/backend/src/Tawny.Domain/Entities/Case.cs new file mode 100644 index 0000000..5a0b0a0 --- /dev/null +++ b/backend/src/Tawny.Domain/Entities/Case.cs @@ -0,0 +1,61 @@ +namespace Tawny.Domain.Entities; + +public enum CaseStatus +{ + Open = 0, + Investigating = 1, + Contained = 2, + Resolved = 3, + Closed = 4, +} + +public enum CasePriority +{ + Low = 0, + Medium = 1, + High = 2, + Critical = 3, +} + +public class Case +{ + public long Id { get; set; } + public Guid TenantId { get; set; } + public required string Title { get; set; } + public string? Summary { get; set; } + public CaseStatus Status { get; set; } = CaseStatus.Open; + public CasePriority Priority { get; set; } = CasePriority.Medium; + public Guid? AssignedToUserId { get; set; } + public Guid? CreatedByUserId { get; set; } + public DateTimeOffset CreatedAt { get; set; } + public DateTimeOffset UpdatedAt { get; set; } + public DateTimeOffset? ClosedAt { get; set; } + public string? MitreTechniquesJson { get; set; } + + public Tenant? Tenant { get; set; } + public List CaseAlerts { get; set; } = []; + public List Notes { get; set; } = []; +} + +public class CaseAlert +{ + public long Id { get; set; } + public long CaseId { get; set; } + public long AlertId { get; set; } + public DateTimeOffset AddedAt { get; set; } + public Guid? AddedByUserId { get; set; } + + public Case? Case { get; set; } + public Alert? Alert { get; set; } +} + +public class CaseNote +{ + public long Id { get; set; } + public long CaseId { get; set; } + public Guid? AuthorUserId { get; set; } + public required string Body { get; set; } + public DateTimeOffset CreatedAt { get; set; } + + public Case? Case { get; set; } +} diff --git a/backend/src/Tawny.Domain/Entities/ReputationCacheEntry.cs b/backend/src/Tawny.Domain/Entities/ReputationCacheEntry.cs new file mode 100644 index 0000000..7250329 --- /dev/null +++ b/backend/src/Tawny.Domain/Entities/ReputationCacheEntry.cs @@ -0,0 +1,15 @@ +namespace Tawny.Domain.Entities; + +public class ReputationCacheEntry +{ + public long Id { get; set; } + public Guid TenantId { get; set; } + public ReputationProvider Provider { get; set; } + public required string IndicatorKind { get; set; } + public required string IndicatorValue { get; set; } + public ReputationVerdict Verdict { get; set; } + public int? Score { get; set; } + public required string DetailJson { get; set; } + public DateTimeOffset FetchedAt { get; set; } + public DateTimeOffset ExpiresAt { get; set; } +} diff --git a/backend/src/Tawny.Domain/Entities/SavedHunt.cs b/backend/src/Tawny.Domain/Entities/SavedHunt.cs index 155c0bb..906d926 100644 --- a/backend/src/Tawny.Domain/Entities/SavedHunt.cs +++ b/backend/src/Tawny.Domain/Entities/SavedHunt.cs @@ -15,6 +15,7 @@ public class SavedHunt public string? MitreTechniquesJson { get; set; } public DateTimeOffset? LastRunAt { get; set; } public int? LastMatchCount { get; set; } + public bool IsShared { get; set; } = true; public DateTimeOffset CreatedAt { get; set; } public DateTimeOffset UpdatedAt { get; set; } diff --git a/backend/src/Tawny.Domain/Entities/Tenant.cs b/backend/src/Tawny.Domain/Entities/Tenant.cs index eca6a45..cf7b963 100644 --- a/backend/src/Tawny.Domain/Entities/Tenant.cs +++ b/backend/src/Tawny.Domain/Entities/Tenant.cs @@ -15,4 +15,6 @@ public class Tenant public List SavedHunts { get; set; } = []; public List SuppressionRules { get; set; } = []; public List ApiTokens { get; set; } = []; + public List ThreatIntelFeeds { get; set; } = []; + public List Cases { get; set; } = []; } diff --git a/backend/src/Tawny.Domain/Entities/ThreatIntelFeed.cs b/backend/src/Tawny.Domain/Entities/ThreatIntelFeed.cs new file mode 100644 index 0000000..72a5ee3 --- /dev/null +++ b/backend/src/Tawny.Domain/Entities/ThreatIntelFeed.cs @@ -0,0 +1,27 @@ +namespace Tawny.Domain.Entities; + +public class ThreatIntelFeed +{ + public Guid Id { get; set; } + public Guid TenantId { get; set; } + public required string Name { get; set; } + public ThreatIntelFeedKind Kind { get; set; } + public required string Url { get; set; } + public string? AuthHeaderName { get; set; } + public string? AuthHeaderValueEncrypted { get; set; } + public AlertSeverity DefaultSeverity { get; set; } = AlertSeverity.High; + public bool IsEnabled { get; set; } = true; + public int IntervalMinutes { get; set; } = 60; + public ThreatIntelFeedStatus Status { get; set; } = ThreatIntelFeedStatus.NeverRun; + public DateTimeOffset? LastRunAt { get; set; } + public DateTimeOffset? LastSuccessAt { get; set; } + public int LastImportedCount { get; set; } + public int LastSkippedCount { get; set; } + public string? LastError { get; set; } + public string? Etag { get; set; } + public Guid? CreatedByUserId { get; set; } + public DateTimeOffset CreatedAt { get; set; } + public DateTimeOffset UpdatedAt { get; set; } + + public Tenant? Tenant { get; set; } +} diff --git a/backend/src/Tawny.Domain/Enums.cs b/backend/src/Tawny.Domain/Enums.cs index 8f83617..761ee9d 100644 --- a/backend/src/Tawny.Domain/Enums.cs +++ b/backend/src/Tawny.Domain/Enums.cs @@ -29,6 +29,13 @@ public enum TelemetryEventType SystemInfo = 3, FileIntegrity = 4, Heartbeat = 5, + DnsQuery = 6, + ProcessLaunch = 7, + FileEvent = 8, + PackageInventory = 9, + EditorExtension = 10, + BrowserExtension = 11, + McpConfig = 12, } public enum UserRole @@ -74,6 +81,9 @@ public enum AlertRuleFormat TawnyPredicate = 0, Sigma = 1, Ioc = 2, + Sequence = 3, + Yara = 4, + PackageExposure = 5, } public enum ResponseActionType @@ -103,3 +113,38 @@ public enum SuppressionScope AllRules = 0, SpecificRule = 1, } + +public enum ThreatIntelFeedKind +{ + UrlhausCsv = 0, + UrlhausJson = 1, + OtxPulse = 2, + MispEvents = 3, + Taxii21 = 4, + GenericCsv = 5, + OsvVulnerabilities = 6, +} + +public enum ThreatIntelFeedStatus +{ + Healthy = 0, + Degraded = 1, + Failed = 2, + NeverRun = 3, +} + +public enum ReputationProvider +{ + VirusTotal = 0, + AbuseIpDb = 1, + GreyNoise = 2, +} + +public enum ReputationVerdict +{ + Unknown = 0, + Clean = 1, + Suspicious = 2, + Malicious = 3, + Error = 4, +} diff --git a/backend/src/Tawny.Infrastructure/Hunting/PackageExposure.cs b/backend/src/Tawny.Infrastructure/Hunting/PackageExposure.cs new file mode 100644 index 0000000..4ba0c88 --- /dev/null +++ b/backend/src/Tawny.Infrastructure/Hunting/PackageExposure.cs @@ -0,0 +1,168 @@ +using System.Text.Json; +using System.Text.Json.Serialization; +using Tawny.Domain; + +namespace Tawny.Infrastructure.Hunting; + +public class PackageExposureException(string message) : Exception(message); + +/// +/// JSON shape stored on AlertRule.SourceDefinition when Format = PackageExposure. +/// Inspired by Perplexity's Bumblebee scanner — the rule matches a (ecosystem, +/// name, version_pattern) triple against package_inventory events emitted by +/// the agent. Version patterns support exact match, comma-separated lists, or +/// simple npm-style range strings (^, ~, >=, <=, <, >). +/// +public record PackageExposureDefinition( + [property: JsonPropertyName("ecosystem")] string Ecosystem, + [property: JsonPropertyName("name")] string Name, + [property: JsonPropertyName("version_pattern")] string? VersionPattern, + [property: JsonPropertyName("advisory_id")] string? AdvisoryId, + [property: JsonPropertyName("advisory_url")] string? AdvisoryUrl); + +public static class PackageExposureParser +{ + private static readonly JsonSerializerOptions JsonOptions = new(JsonSerializerDefaults.Web) + { + PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower, + }; + + public static PackageExposureDefinition Parse(string json) + { + if (string.IsNullOrWhiteSpace(json)) + { + throw new PackageExposureException("Package exposure definition is empty."); + } + PackageExposureDefinition? def; + try { def = JsonSerializer.Deserialize(json, JsonOptions); } + catch (JsonException ex) + { + throw new PackageExposureException($"Invalid package exposure JSON: {ex.Message}"); + } + if (def is null) throw new PackageExposureException("Definition deserialized to null."); + if (string.IsNullOrWhiteSpace(def.Ecosystem)) + { + throw new PackageExposureException("ecosystem is required (npm, pypi, go, rubygems, packagist, mcp, editor-extension, browser-extension)."); + } + if (string.IsNullOrWhiteSpace(def.Name)) + { + throw new PackageExposureException("name is required."); + } + return def; + } + + public static string Serialize(PackageExposureDefinition def) => JsonSerializer.Serialize(def, JsonOptions); +} + +public static class PackageExposureEvaluator +{ + /// + /// Returns true if the supplied package_inventory / editor_extension / + /// browser_extension / mcp_config payload matches this exposure + /// definition. Caller is responsible for filtering by EventType before + /// calling so we don't waste cycles on irrelevant events. + /// + public static bool Evaluate(PackageExposureDefinition definition, JsonElement payload) + { + if (!payload.TryGetProperty("ecosystem", out var ecosystem) + || ecosystem.ValueKind != JsonValueKind.String) return false; + if (!string.Equals(ecosystem.GetString(), definition.Ecosystem, StringComparison.OrdinalIgnoreCase)) + { + return false; + } + + if (!payload.TryGetProperty("name", out var name) || name.ValueKind != JsonValueKind.String) return false; + if (!string.Equals(name.GetString(), definition.Name, StringComparison.OrdinalIgnoreCase)) + { + return false; + } + + if (string.IsNullOrWhiteSpace(definition.VersionPattern)) + { + // No version filter = "any installed version of this package is exposed". + return true; + } + + if (!payload.TryGetProperty("version", out var version) || version.ValueKind != JsonValueKind.String) + { + // Pattern was specified but the event has no version — treat as no match. + return false; + } + + return VersionMatches(definition.VersionPattern, version.GetString() ?? ""); + } + + private static bool VersionMatches(string pattern, string actual) + { + if (string.IsNullOrWhiteSpace(actual)) return false; + foreach (var raw in pattern.Split(',', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries)) + { + if (MatchesSingleRange(raw, actual)) return true; + } + return false; + } + + private static bool MatchesSingleRange(string range, string actual) + { + if (range == "*") return true; + if (range.StartsWith(">=")) return CompareVersion(actual, range[2..].Trim()) >= 0; + if (range.StartsWith("<=")) return CompareVersion(actual, range[2..].Trim()) <= 0; + if (range.StartsWith(">")) return CompareVersion(actual, range[1..].Trim()) > 0; + if (range.StartsWith("<")) return CompareVersion(actual, range[1..].Trim()) < 0; + if (range.StartsWith("=")) + { + return string.Equals(range[1..].Trim(), actual, StringComparison.OrdinalIgnoreCase); + } + if (range.StartsWith("^")) + { + // ^1.2.3 means >=1.2.3 and <2.0.0 (caret pins the leftmost non-zero major). + var anchor = ParseSemver(range[1..].Trim()); + var current = ParseSemver(actual); + if (anchor is null || current is null) return false; + if (current.Value.Major != anchor.Value.Major) return false; + return CompareSemver(current.Value, anchor.Value) >= 0; + } + if (range.StartsWith("~")) + { + // ~1.2.3 means >=1.2.3 and <1.3.0 (tilde pins major.minor). + var anchor = ParseSemver(range[1..].Trim()); + var current = ParseSemver(actual); + if (anchor is null || current is null) return false; + if (current.Value.Major != anchor.Value.Major + || current.Value.Minor != anchor.Value.Minor) return false; + return CompareSemver(current.Value, anchor.Value) >= 0; + } + // Default: exact string equality (case-insensitive). Covers commit hashes, named tags. + return string.Equals(range, actual, StringComparison.OrdinalIgnoreCase); + } + + private static int CompareVersion(string a, string b) + { + var sa = ParseSemver(a); + var sb = ParseSemver(b); + if (sa is not null && sb is not null) return CompareSemver(sa.Value, sb.Value); + return string.Compare(a, b, StringComparison.OrdinalIgnoreCase); + } + + private static (int Major, int Minor, int Patch)? ParseSemver(string raw) + { + // Strip leading "v" and any prerelease/build suffix (-alpha, +build). + var stripped = raw.TrimStart('v', 'V'); + var split = stripped.IndexOfAny(new[] { '-', '+' }); + if (split >= 0) stripped = stripped[..split]; + var parts = stripped.Split('.'); + if (parts.Length == 0) return null; + int major = 0, minor = 0, patch = 0; + if (!int.TryParse(parts[0], out major)) return null; + if (parts.Length > 1 && !int.TryParse(parts[1], out minor)) minor = 0; + if (parts.Length > 2 && !int.TryParse(parts[2], out patch)) patch = 0; + return (major, minor, patch); + } + + private static int CompareSemver((int Major, int Minor, int Patch) a, (int Major, int Minor, int Patch) b) + { + if (a.Major != b.Major) return a.Major.CompareTo(b.Major); + if (a.Minor != b.Minor) return a.Minor.CompareTo(b.Minor); + return a.Patch.CompareTo(b.Patch); + } +} diff --git a/backend/src/Tawny.Infrastructure/Hunting/RuleTestHarness.cs b/backend/src/Tawny.Infrastructure/Hunting/RuleTestHarness.cs new file mode 100644 index 0000000..e1073c1 --- /dev/null +++ b/backend/src/Tawny.Infrastructure/Hunting/RuleTestHarness.cs @@ -0,0 +1,193 @@ +using System.Globalization; +using System.Text.Json; +using Tawny.Domain; +using Tawny.Domain.Entities; + +namespace Tawny.Infrastructure.Hunting; + +public record RuleTestEventInput( + TelemetryEventType EventType, + DateTimeOffset OccurredAt, + JsonElement Payload); + +public record RuleTestStepTrace( + int Index, + string Step, + bool Matched, + string? FailReason); + +public record RuleTestResult( + bool Matched, + string? FailReason, + IReadOnlyList Trace); + +/// +/// Pure-function tester that runs an in-memory AlertRule (any format) against +/// supplied event(s). No DB writes, no broker publish, no sinks — used by the +/// /rule-test endpoint so detection authors can iterate on rules quickly. +/// +public class RuleTestHarness +{ + public RuleTestResult Test(AlertRule rule, IReadOnlyList events) + { + if (events.Count == 0) + { + return new RuleTestResult(false, "no events supplied", []); + } + + return rule.Format switch + { + AlertRuleFormat.Sequence => TestSequence(rule, events), + _ => TestSinglePredicate(rule, events), + }; + } + + private static RuleTestResult TestSinglePredicate(AlertRule rule, IReadOnlyList events) + { + var trace = new List(); + for (var i = 0; i < events.Count; i++) + { + var ev = events[i]; + if (rule.EventType is not null && rule.EventType.Value != ev.EventType) + { + trace.Add(new RuleTestStepTrace(i, $"event[{i}] {ev.EventType}", false, + $"event_type {ev.EventType} does not match rule event_type {rule.EventType}")); + continue; + } + if (string.IsNullOrWhiteSpace(rule.PayloadPath)) + { + trace.Add(new RuleTestStepTrace(i, $"event[{i}]", true, null)); + return new RuleTestResult(true, null, trace); + } + var values = ResolvePath(ev.Payload, rule.PayloadPath).ToList(); + if (values.Count == 0) + { + trace.Add(new RuleTestStepTrace(i, $"event[{i}] {rule.PayloadPath}", false, + $"payload_path '{rule.PayloadPath}' was not found in the event payload")); + continue; + } + if (RuleMatches(rule, values)) + { + trace.Add(new RuleTestStepTrace(i, $"event[{i}] {rule.PayloadPath} {rule.Operator} {rule.MatchValue}", true, null)); + return new RuleTestResult(true, null, trace); + } + trace.Add(new RuleTestStepTrace(i, $"event[{i}] {rule.PayloadPath} {rule.Operator} {rule.MatchValue}", false, + $"value(s) {string.Join(", ", values.Select(JsonScalar))} did not satisfy the predicate")); + } + return new RuleTestResult(false, "no event satisfied the predicate", trace); + } + + private static RuleTestResult TestSequence(AlertRule rule, IReadOnlyList events) + { + SequenceRuleDefinition definition; + try { definition = SequenceRuleParser.Parse(rule.SourceDefinition ?? ""); } + catch (SequenceRuleException ex) + { + return new RuleTestResult(false, ex.Message, []); + } + + var trace = new List(); + var matched = 0; + var firstMatchTime = DateTimeOffset.MinValue; + var ordered = events.OrderBy(e => e.OccurredAt).ToList(); + foreach (var ev in ordered) + { + if (matched >= definition.Steps.Count) break; + var step = definition.Steps[matched]; + if (step.EventType != ev.EventType) + { + trace.Add(new RuleTestStepTrace(matched, step.Name, false, + $"step expects {step.EventType} but event was {ev.EventType}")); + continue; + } + if (!StepMatches(step, ev.Payload)) + { + trace.Add(new RuleTestStepTrace(matched, step.Name, false, + $"payload did not satisfy step predicate")); + continue; + } + if (matched > 0 + && (ev.OccurredAt - firstMatchTime).TotalSeconds > definition.WindowSeconds) + { + trace.Add(new RuleTestStepTrace(matched, step.Name, false, + $"step occurred {(ev.OccurredAt - firstMatchTime).TotalSeconds:F0}s after step 0, outside window_seconds={definition.WindowSeconds}")); + continue; + } + if (matched == 0) firstMatchTime = ev.OccurredAt; + trace.Add(new RuleTestStepTrace(matched, step.Name, true, null)); + matched += 1; + } + + if (matched == definition.Steps.Count) + { + return new RuleTestResult(true, null, trace); + } + return new RuleTestResult(false, $"matched {matched} of {definition.Steps.Count} steps", trace); + } + + private static bool RuleMatches(AlertRule rule, IEnumerable values) + => rule.Operator switch + { + AlertRuleOperator.Exists => true, + AlertRuleOperator.Equals => values.Any(v => string.Equals(JsonScalar(v), rule.MatchValue, StringComparison.OrdinalIgnoreCase)), + AlertRuleOperator.Contains => values.Any(v => !string.IsNullOrEmpty(rule.MatchValue) && JsonScalar(v).Contains(rule.MatchValue, StringComparison.OrdinalIgnoreCase)), + AlertRuleOperator.GreaterThan => values.Any(v => CompareNumber(v, rule.MatchValue, (a, b) => a > b)), + AlertRuleOperator.LessThan => values.Any(v => CompareNumber(v, rule.MatchValue, (a, b) => a < b)), + _ => false, + }; + + private static bool StepMatches(SequenceStep step, JsonElement payload) + { + if (string.IsNullOrWhiteSpace(step.PayloadPath)) return true; + var values = ResolvePath(payload, step.PayloadPath).ToList(); + if (values.Count == 0) return false; + return step.Operator switch + { + AlertRuleOperator.Exists => true, + AlertRuleOperator.Equals => values.Any(v => string.Equals(JsonScalar(v), step.MatchValue, StringComparison.OrdinalIgnoreCase)), + AlertRuleOperator.Contains => values.Any(v => !string.IsNullOrEmpty(step.MatchValue) && JsonScalar(v).Contains(step.MatchValue, StringComparison.OrdinalIgnoreCase)), + AlertRuleOperator.GreaterThan => values.Any(v => CompareNumber(v, step.MatchValue, (a, b) => a > b)), + AlertRuleOperator.LessThan => values.Any(v => CompareNumber(v, step.MatchValue, (a, b) => a < b)), + _ => false, + }; + } + + private static IEnumerable ResolvePath(JsonElement root, string path) + { + var segments = path.Split('.', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries); + return ResolvePath(root, segments, 0); + } + + private static IEnumerable ResolvePath(JsonElement current, IReadOnlyList segments, int index) + { + if (index >= segments.Count) { yield return current; yield break; } + if (current.ValueKind == JsonValueKind.Array) + { + foreach (var item in current.EnumerateArray()) + { + foreach (var v in ResolvePath(item, segments, index)) yield return v; + } + yield break; + } + if (current.ValueKind != JsonValueKind.Object) yield break; + if (!current.TryGetProperty(segments[index], out var child)) yield break; + foreach (var v in ResolvePath(child, segments, index + 1)) yield return v; + } + + private static string JsonScalar(JsonElement value) => value.ValueKind switch + { + JsonValueKind.String => value.GetString() ?? "", + JsonValueKind.Number => value.GetRawText(), + JsonValueKind.True => "true", + JsonValueKind.False => "false", + JsonValueKind.Null => "", + _ => value.GetRawText(), + }; + + private static bool CompareNumber(JsonElement value, string? expected, Func cmp) + { + if (!decimal.TryParse(JsonScalar(value), NumberStyles.Float, CultureInfo.InvariantCulture, out var left)) return false; + if (!decimal.TryParse(expected, NumberStyles.Float, CultureInfo.InvariantCulture, out var right)) return false; + return cmp(left, right); + } +} diff --git a/backend/src/Tawny.Infrastructure/Hunting/SequenceRule.cs b/backend/src/Tawny.Infrastructure/Hunting/SequenceRule.cs new file mode 100644 index 0000000..83e429e --- /dev/null +++ b/backend/src/Tawny.Infrastructure/Hunting/SequenceRule.cs @@ -0,0 +1,83 @@ +using System.Globalization; +using System.Text.Json; +using System.Text.Json.Serialization; +using Tawny.Domain; +using Tawny.Domain.Entities; + +namespace Tawny.Infrastructure.Hunting; + +public class SequenceRuleException(string message) : Exception(message); + +/// +/// JSON shape stored on AlertRule.SourceDefinition when Format = Sequence. +/// Each step is a predicate that must match an event of the named type; +/// steps must occur in order on the same host, within the rule's time window. +/// +public record SequenceRuleDefinition( + [property: JsonPropertyName("window_seconds")] int WindowSeconds, + [property: JsonPropertyName("group_by")] string GroupBy, + [property: JsonPropertyName("steps")] IReadOnlyList Steps); + +public record SequenceStep( + [property: JsonPropertyName("name")] string Name, + [property: JsonPropertyName("event_type")] TelemetryEventType EventType, + [property: JsonPropertyName("payload_path")] string? PayloadPath, + [property: JsonPropertyName("operator")] AlertRuleOperator Operator, + [property: JsonPropertyName("match_value")] string? MatchValue); + +public static class SequenceRuleParser +{ + private static readonly JsonSerializerOptions JsonOptions = new(JsonSerializerDefaults.Web) + { + PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower, + Converters = { new JsonStringEnumConverter(JsonNamingPolicy.SnakeCaseLower) }, + }; + + public static SequenceRuleDefinition Parse(string json) + { + if (string.IsNullOrWhiteSpace(json)) + { + throw new SequenceRuleException("Sequence rule definition is empty."); + } + SequenceRuleDefinition? def; + try + { + def = JsonSerializer.Deserialize(json, JsonOptions); + } + catch (JsonException ex) + { + throw new SequenceRuleException($"Invalid sequence rule JSON: {ex.Message}"); + } + if (def is null) + { + throw new SequenceRuleException("Sequence rule definition deserialized to null."); + } + if (def.WindowSeconds <= 0 || def.WindowSeconds > 86_400) + { + throw new SequenceRuleException("window_seconds must be between 1 and 86400."); + } + if (def.Steps is null || def.Steps.Count < 2) + { + throw new SequenceRuleException("A sequence rule needs at least two steps."); + } + if (def.Steps.Count > 8) + { + throw new SequenceRuleException("A sequence rule can have at most 8 steps."); + } + foreach (var step in def.Steps) + { + if (string.IsNullOrWhiteSpace(step.Name)) + { + throw new SequenceRuleException("Each step needs a non-empty name."); + } + if (step.Operator != AlertRuleOperator.Exists && string.IsNullOrWhiteSpace(step.MatchValue)) + { + throw new SequenceRuleException($"Step '{step.Name}' needs a match_value (or use the exists operator)."); + } + } + return def; + } + + public static string Serialize(SequenceRuleDefinition def) + => JsonSerializer.Serialize(def, JsonOptions); +} diff --git a/backend/src/Tawny.Infrastructure/Hunting/SequenceRuleEvaluator.cs b/backend/src/Tawny.Infrastructure/Hunting/SequenceRuleEvaluator.cs new file mode 100644 index 0000000..4d6e2af --- /dev/null +++ b/backend/src/Tawny.Infrastructure/Hunting/SequenceRuleEvaluator.cs @@ -0,0 +1,153 @@ +using System.Collections.Concurrent; +using System.Globalization; +using System.Text.Json; +using Microsoft.EntityFrameworkCore; +using Tawny.Domain; +using Tawny.Domain.Entities; + +namespace Tawny.Infrastructure.Hunting; + +/// +/// Tracks in-flight sequence matches keyed by (rule, host). State is process- +/// local: we deliberately don't persist partial progress, because operators +/// expect EDR detections to fire when the *whole* sequence is observed within +/// the window, and survivability across restarts isn't worth the storage +/// churn. Rebuilds on the next matching event after a restart. +/// +public class SequenceRuleEvaluator +{ + private readonly ConcurrentDictionary<(Guid RuleId, Guid AgentId), SequenceState> _state = new(); + + public IReadOnlyList Evaluate( + AlertRule rule, + SequenceRuleDefinition definition, + Agent agent, + IReadOnlyList events, + DateTimeOffset now) + { + var window = TimeSpan.FromSeconds(definition.WindowSeconds); + var matches = new List(); + var key = (rule.Id, agent.Id); + var state = _state.GetOrAdd(key, _ => new SequenceState()); + + foreach (var ev in events.OrderBy(e => e.OccurredAt)) + { + JsonDocument doc; + try { doc = JsonDocument.Parse(ev.Payload); } + catch { continue; } + + using (doc) + { + var nextStepIndex = state.MatchedSteps.Count; + if (nextStepIndex >= definition.Steps.Count) continue; + var step = definition.Steps[nextStepIndex]; + + if (step.EventType != ev.EventType) continue; + if (!StepMatches(step, doc.RootElement)) continue; + + // Reset if too far behind the first matched event. + if (state.MatchedSteps.Count > 0 + && (ev.OccurredAt - state.MatchedSteps[0].OccurredAt) > window) + { + state.MatchedSteps.Clear(); + nextStepIndex = 0; + step = definition.Steps[0]; + if (step.EventType != ev.EventType || !StepMatches(step, doc.RootElement)) + { + continue; + } + } + + state.MatchedSteps.Add(new MatchedStep(step.Name, ev.Id, ev.OccurredAt)); + + if (state.MatchedSteps.Count == definition.Steps.Count) + { + matches.Add(new SequenceMatch( + rule.Id, + agent.Id, + state.MatchedSteps.Last().EventId, + state.MatchedSteps.ToList())); + state.MatchedSteps.Clear(); + } + } + } + + // Garbage-collect stale state per host: if oldest matched step is past window, drop progress. + if (state.MatchedSteps.Count > 0 && (now - state.MatchedSteps[0].OccurredAt) > window) + { + state.MatchedSteps.Clear(); + } + + return matches; + } + + public void ResetAll() => _state.Clear(); + + private static bool StepMatches(SequenceStep step, JsonElement payload) + { + if (string.IsNullOrWhiteSpace(step.PayloadPath)) return true; + var values = ResolvePath(payload, step.PayloadPath).ToList(); + if (values.Count == 0) return false; + return step.Operator switch + { + AlertRuleOperator.Exists => true, + AlertRuleOperator.Equals => values.Any(v => string.Equals(JsonScalar(v), step.MatchValue, StringComparison.OrdinalIgnoreCase)), + AlertRuleOperator.Contains => values.Any(v => !string.IsNullOrEmpty(step.MatchValue) && JsonScalar(v).Contains(step.MatchValue, StringComparison.OrdinalIgnoreCase)), + AlertRuleOperator.GreaterThan => values.Any(v => CompareNumber(v, step.MatchValue, (a, b) => a > b)), + AlertRuleOperator.LessThan => values.Any(v => CompareNumber(v, step.MatchValue, (a, b) => a < b)), + _ => false, + }; + } + + private static IEnumerable ResolvePath(JsonElement root, string path) + { + var segments = path.Split('.', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries); + return ResolvePath(root, segments, 0); + } + + private static IEnumerable ResolvePath(JsonElement current, IReadOnlyList segments, int index) + { + if (index >= segments.Count) { yield return current; yield break; } + if (current.ValueKind == JsonValueKind.Array) + { + foreach (var item in current.EnumerateArray()) + { + foreach (var v in ResolvePath(item, segments, index)) yield return v; + } + yield break; + } + if (current.ValueKind != JsonValueKind.Object) yield break; + if (!current.TryGetProperty(segments[index], out var child)) yield break; + foreach (var v in ResolvePath(child, segments, index + 1)) yield return v; + } + + private static string JsonScalar(JsonElement value) => value.ValueKind switch + { + JsonValueKind.String => value.GetString() ?? "", + JsonValueKind.Number => value.GetRawText(), + JsonValueKind.True => "true", + JsonValueKind.False => "false", + JsonValueKind.Null => "", + _ => value.GetRawText(), + }; + + private static bool CompareNumber(JsonElement value, string? expected, Func cmp) + { + if (!decimal.TryParse(JsonScalar(value), NumberStyles.Float, CultureInfo.InvariantCulture, out var left)) return false; + if (!decimal.TryParse(expected, NumberStyles.Float, CultureInfo.InvariantCulture, out var right)) return false; + return cmp(left, right); + } + + private sealed class SequenceState + { + public List MatchedSteps { get; } = []; + } +} + +public record MatchedStep(string Name, long EventId, DateTimeOffset OccurredAt); + +public record SequenceMatch( + Guid RuleId, + Guid AgentId, + long TriggeringEventId, + IReadOnlyList Trail); diff --git a/backend/src/Tawny.Infrastructure/Hunting/SigmaExpression.cs b/backend/src/Tawny.Infrastructure/Hunting/SigmaExpression.cs new file mode 100644 index 0000000..da855d8 --- /dev/null +++ b/backend/src/Tawny.Infrastructure/Hunting/SigmaExpression.cs @@ -0,0 +1,121 @@ +using System.Globalization; +using System.Text.Json; +using System.Text.Json.Serialization; +using Tawny.Domain; + +namespace Tawny.Infrastructure.Hunting; + +/// +/// Compiled Sigma rule tree, stored on AlertRule.CompiledExpressionJson when +/// the source rule has a non-trivial condition (AND/OR/NOT, "1 of selection_*", +/// "all of selection_*"). Single-selection rules continue to use the legacy +/// AlertRule.PayloadPath/Operator/MatchValue fields so the simple path stays simple. +/// +[JsonPolymorphic(TypeDiscriminatorPropertyName = "kind")] +[JsonDerivedType(typeof(SigmaAnd), "and")] +[JsonDerivedType(typeof(SigmaOr), "or")] +[JsonDerivedType(typeof(SigmaNot), "not")] +[JsonDerivedType(typeof(SigmaAnyOf), "any_of")] +[JsonDerivedType(typeof(SigmaAllOf), "all_of")] +[JsonDerivedType(typeof(SigmaFieldPredicate), "field")] +public abstract record SigmaNode; + +public sealed record SigmaAnd(IReadOnlyList Children) : SigmaNode; +public sealed record SigmaOr(IReadOnlyList Children) : SigmaNode; +public sealed record SigmaNot(SigmaNode Inner) : SigmaNode; +public sealed record SigmaAnyOf(IReadOnlyList Children) : SigmaNode; +public sealed record SigmaAllOf(IReadOnlyList Children) : SigmaNode; + +public sealed record SigmaFieldPredicate( + string PayloadPath, + AlertRuleOperator Operator, + IReadOnlyList Values) : SigmaNode; + +public static class SigmaExpressionEvaluator +{ + public static bool Evaluate(SigmaNode node, JsonElement payload) + { + return node switch + { + SigmaAnd and => and.Children.All(c => Evaluate(c, payload)), + SigmaOr or => or.Children.Any(c => Evaluate(c, payload)), + SigmaNot not => !Evaluate(not.Inner, payload), + SigmaAnyOf anyOf => anyOf.Children.Any(c => Evaluate(c, payload)), + SigmaAllOf allOf => allOf.Children.All(c => Evaluate(c, payload)), + SigmaFieldPredicate p => EvaluatePredicate(p, payload), + _ => false, + }; + } + + private static bool EvaluatePredicate(SigmaFieldPredicate predicate, JsonElement payload) + { + var values = ResolvePath(payload, predicate.PayloadPath).ToList(); + if (values.Count == 0) return false; + return predicate.Operator switch + { + AlertRuleOperator.Exists => true, + AlertRuleOperator.Equals => values.Any(v => predicate.Values.Any(target => string.Equals(JsonScalar(v), target, StringComparison.OrdinalIgnoreCase))), + AlertRuleOperator.Contains => values.Any(v => predicate.Values.Any(target => JsonScalar(v).Contains(target, StringComparison.OrdinalIgnoreCase))), + AlertRuleOperator.GreaterThan => values.Any(v => predicate.Values.Any(target => CompareNumber(v, target, (a, b) => a > b))), + AlertRuleOperator.LessThan => values.Any(v => predicate.Values.Any(target => CompareNumber(v, target, (a, b) => a < b))), + _ => false, + }; + } + + private static IEnumerable ResolvePath(JsonElement root, string path) + { + var segments = path.Split('.', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries); + return ResolvePath(root, segments, 0); + } + + private static IEnumerable ResolvePath(JsonElement current, IReadOnlyList segments, int index) + { + if (index >= segments.Count) { yield return current; yield break; } + if (current.ValueKind == JsonValueKind.Array) + { + foreach (var item in current.EnumerateArray()) + { + foreach (var v in ResolvePath(item, segments, index)) yield return v; + } + yield break; + } + if (current.ValueKind != JsonValueKind.Object) yield break; + if (!current.TryGetProperty(segments[index], out var child)) yield break; + foreach (var v in ResolvePath(child, segments, index + 1)) yield return v; + } + + private static string JsonScalar(JsonElement value) => value.ValueKind switch + { + JsonValueKind.String => value.GetString() ?? "", + JsonValueKind.Number => value.GetRawText(), + JsonValueKind.True => "true", + JsonValueKind.False => "false", + JsonValueKind.Null => "", + _ => value.GetRawText(), + }; + + private static bool CompareNumber(JsonElement value, string? expected, Func cmp) + { + if (!decimal.TryParse(JsonScalar(value), NumberStyles.Float, CultureInfo.InvariantCulture, out var left)) return false; + if (!decimal.TryParse(expected, NumberStyles.Float, CultureInfo.InvariantCulture, out var right)) return false; + return cmp(left, right); + } +} + +public static class SigmaExpressionSerializer +{ + private static readonly JsonSerializerOptions Options = new(JsonSerializerDefaults.Web) + { + PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower, + Converters = { new JsonStringEnumConverter(JsonNamingPolicy.SnakeCaseLower) }, + }; + + public static string Serialize(SigmaNode node) => JsonSerializer.Serialize(node, Options); + + public static SigmaNode? Deserialize(string? json) + { + if (string.IsNullOrWhiteSpace(json)) return null; + try { return JsonSerializer.Deserialize(json, Options); } + catch { return null; } + } +} diff --git a/backend/src/Tawny.Infrastructure/Hunting/YaraLite.cs b/backend/src/Tawny.Infrastructure/Hunting/YaraLite.cs new file mode 100644 index 0000000..2b2c63e --- /dev/null +++ b/backend/src/Tawny.Infrastructure/Hunting/YaraLite.cs @@ -0,0 +1,192 @@ +using System.Text.Json; +using System.Text.Json.Serialization; +using System.Text.RegularExpressions; + +namespace Tawny.Infrastructure.Hunting; + +public class YaraLiteException(string message) : Exception(message); + +/// +/// YARA-lite: a JSON-defined string-match rule that evaluates against the +/// raw text of a telemetry payload. Not a full YARA implementation (no PE +/// parsing, no offsets, no XOR/wide modifiers) — those need libyara and a +/// way to ship file content from agents, which is Phase 2 territory. +/// +/// What we do support: +/// strings: list of either { literal: "..." } or { regex: "..." } with a $name +/// condition: "any_of" | "all_of" | "n_of(K)" +/// +public record YaraLiteDefinition( + [property: JsonPropertyName("strings")] IReadOnlyList Strings, + [property: JsonPropertyName("condition")] string Condition); + +public record YaraLiteString( + [property: JsonPropertyName("name")] string Name, + [property: JsonPropertyName("literal")] string? Literal, + [property: JsonPropertyName("regex")] string? Regex, + [property: JsonPropertyName("case_sensitive")] bool? CaseSensitive); + +public static class YaraLiteParser +{ + private static readonly JsonSerializerOptions JsonOptions = new(JsonSerializerDefaults.Web) + { + PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower, + }; + + public static YaraLiteDefinition Parse(string json) + { + if (string.IsNullOrWhiteSpace(json)) + { + throw new YaraLiteException("YARA rule definition is empty."); + } + YaraLiteDefinition? def; + try { def = JsonSerializer.Deserialize(json, JsonOptions); } + catch (JsonException ex) + { + throw new YaraLiteException($"Invalid YARA-lite JSON: {ex.Message}"); + } + if (def is null || def.Strings is null || def.Strings.Count == 0) + { + throw new YaraLiteException("YARA-lite rule must define at least one string."); + } + foreach (var s in def.Strings) + { + if (string.IsNullOrWhiteSpace(s.Name)) + { + throw new YaraLiteException("Each string needs a name (e.g. $cmd1)."); + } + if (string.IsNullOrEmpty(s.Literal) && string.IsNullOrEmpty(s.Regex)) + { + throw new YaraLiteException($"String {s.Name} needs either a literal or a regex."); + } + if (!string.IsNullOrEmpty(s.Regex)) + { + try { _ = new Regex(s.Regex); } + catch (ArgumentException ex) + { + throw new YaraLiteException($"Invalid regex in {s.Name}: {ex.Message}"); + } + } + } + if (string.IsNullOrWhiteSpace(def.Condition)) + { + throw new YaraLiteException("YARA-lite rule must include a condition."); + } + return def; + } + + public static string Serialize(YaraLiteDefinition def) => JsonSerializer.Serialize(def, JsonOptions); +} + +public static class YaraLiteEvaluator +{ + private static readonly Regex NofRe = new(@"^\s*(?\d+)_of\s*$", RegexOptions.Compiled); + + public static bool Evaluate(YaraLiteDefinition definition, string payloadText) + { + var matched = new HashSet(StringComparer.Ordinal); + foreach (var s in definition.Strings) + { + if (StringMatches(s, payloadText)) + { + matched.Add(s.Name); + } + } + var condition = definition.Condition.Trim().ToLowerInvariant(); + if (condition == "any_of" || condition == "any of them") + { + return matched.Count > 0; + } + if (condition == "all_of" || condition == "all of them") + { + return matched.Count == definition.Strings.Count; + } + var nofMatch = NofRe.Match(condition); + if (nofMatch.Success && int.TryParse(nofMatch.Groups["n"].Value, out var n)) + { + return matched.Count >= n; + } + // Specific names like "$a and $b": evaluate as token-replace -> boolean expression. + var expr = condition; + foreach (var s in definition.Strings) + { + var truthy = matched.Contains(s.Name) ? "true" : "false"; + expr = Regex.Replace(expr, Regex.Escape(s.Name.ToLowerInvariant()), truthy); + } + return EvalBoolean(expr); + } + + private static bool StringMatches(YaraLiteString s, string payloadText) + { + if (!string.IsNullOrEmpty(s.Literal)) + { + var comparison = s.CaseSensitive == true + ? StringComparison.Ordinal + : StringComparison.OrdinalIgnoreCase; + return payloadText.Contains(s.Literal, comparison); + } + if (!string.IsNullOrEmpty(s.Regex)) + { + var opts = s.CaseSensitive == true ? RegexOptions.None : RegexOptions.IgnoreCase; + return Regex.IsMatch(payloadText, s.Regex, opts); + } + return false; + } + + private static bool EvalBoolean(string expr) + { + // Cheap, safe boolean-expression evaluator for "true and false or not true" style strings. + // Tokenize, then a tiny recursive-descent parser. We deliberately keep this minimal. + var tokens = Tokenize(expr).ToList(); + var pos = 0; + var result = ParseOr(tokens, ref pos); + return result; + } + + private static bool ParseOr(IReadOnlyList tokens, ref int pos) + { + var left = ParseAnd(tokens, ref pos); + while (pos < tokens.Count && tokens[pos] == "or") + { + pos++; + var right = ParseAnd(tokens, ref pos); + left = left || right; + } + return left; + } + + private static bool ParseAnd(IReadOnlyList tokens, ref int pos) + { + var left = ParseUnary(tokens, ref pos); + while (pos < tokens.Count && tokens[pos] == "and") + { + pos++; + var right = ParseUnary(tokens, ref pos); + left = left && right; + } + return left; + } + + private static bool ParseUnary(IReadOnlyList tokens, ref int pos) + { + if (pos >= tokens.Count) return false; + if (tokens[pos] == "not") { pos++; return !ParseUnary(tokens, ref pos); } + if (tokens[pos] == "(") { pos++; var v = ParseOr(tokens, ref pos); if (pos < tokens.Count && tokens[pos] == ")") pos++; return v; } + var t = tokens[pos++]; + return t == "true"; + } + + private static IEnumerable Tokenize(string expr) + { + var i = 0; + while (i < expr.Length) + { + var c = expr[i]; + if (char.IsWhiteSpace(c)) { i++; continue; } + if (c == '(' || c == ')') { yield return c.ToString(); i++; continue; } + var start = i; + while (i < expr.Length && !char.IsWhiteSpace(expr[i]) && expr[i] != '(' && expr[i] != ')') i++; + yield return expr[start..i]; + } + } +} diff --git a/backend/src/Tawny.Infrastructure/Migrations/20260524000000_AddPhase3And4.cs b/backend/src/Tawny.Infrastructure/Migrations/20260524000000_AddPhase3And4.cs new file mode 100644 index 0000000..e831b53 --- /dev/null +++ b/backend/src/Tawny.Infrastructure/Migrations/20260524000000_AddPhase3And4.cs @@ -0,0 +1,224 @@ +using System; +using Microsoft.EntityFrameworkCore.Infrastructure; +using Microsoft.EntityFrameworkCore.Migrations; +using Tawny.Infrastructure; + +#nullable disable + +namespace Tawny.Infrastructure.Migrations +{ + /// + [DbContext(typeof(TawnyDbContext))] + [Migration("20260524000000_AddPhase3And4")] + public partial class AddPhase3And4 : Migration + { + /// + protected override void Up(MigrationBuilder migrationBuilder) + { + migrationBuilder.AddColumn( + name: "Enrichment", + table: "Alerts", + type: "nvarchar(max)", + nullable: true); + + migrationBuilder.AddColumn( + name: "IsShared", + table: "SavedHunts", + type: "bit", + nullable: false, + defaultValue: true); + + migrationBuilder.AddColumn( + name: "CompiledExpression", + table: "AlertRules", + type: "nvarchar(max)", + nullable: true); + + migrationBuilder.CreateTable( + name: "ThreatIntelFeeds", + columns: table => new + { + Id = table.Column(type: "uniqueidentifier", nullable: false), + TenantId = table.Column(type: "uniqueidentifier", nullable: false, defaultValue: new Guid("00000000-0000-0000-0000-000000000001")), + Name = table.Column(type: "nvarchar(160)", maxLength: 160, nullable: false), + Kind = table.Column(type: "int", nullable: false), + Url = table.Column(type: "nvarchar(1024)", maxLength: 1024, nullable: false), + AuthHeaderName = table.Column(type: "nvarchar(64)", maxLength: 64, nullable: true), + AuthHeaderValueEncrypted = table.Column(type: "nvarchar(1024)", maxLength: 1024, nullable: true), + DefaultSeverity = table.Column(type: "int", nullable: false), + IsEnabled = table.Column(type: "bit", nullable: false), + IntervalMinutes = table.Column(type: "int", nullable: false), + Status = table.Column(type: "int", nullable: false), + LastRunAt = table.Column(type: "datetimeoffset", nullable: true), + LastSuccessAt = table.Column(type: "datetimeoffset", nullable: true), + LastImportedCount = table.Column(type: "int", nullable: false), + LastSkippedCount = table.Column(type: "int", nullable: false), + LastError = table.Column(type: "nvarchar(2048)", maxLength: 2048, nullable: true), + Etag = table.Column(type: "nvarchar(256)", maxLength: 256, nullable: true), + CreatedByUserId = table.Column(type: "uniqueidentifier", nullable: true), + CreatedAt = table.Column(type: "datetimeoffset", nullable: false), + UpdatedAt = table.Column(type: "datetimeoffset", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_ThreatIntelFeeds", x => x.Id); + table.ForeignKey( + name: "FK_ThreatIntelFeeds_Tenants_TenantId", + column: x => x.TenantId, + principalTable: "Tenants", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + }); + + migrationBuilder.CreateTable( + name: "ReputationCache", + columns: table => new + { + Id = table.Column(type: "bigint", nullable: false) + .Annotation("SqlServer:Identity", "1, 1"), + TenantId = table.Column(type: "uniqueidentifier", nullable: false, defaultValue: new Guid("00000000-0000-0000-0000-000000000001")), + Provider = table.Column(type: "int", nullable: false), + IndicatorKind = table.Column(type: "nvarchar(32)", maxLength: 32, nullable: false), + IndicatorValue = table.Column(type: "nvarchar(512)", maxLength: 512, nullable: false), + Verdict = table.Column(type: "int", nullable: false), + Score = table.Column(type: "int", nullable: true), + DetailJson = table.Column(type: "nvarchar(max)", nullable: false), + FetchedAt = table.Column(type: "datetimeoffset", nullable: false), + ExpiresAt = table.Column(type: "datetimeoffset", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_ReputationCache", x => x.Id); + }); + + migrationBuilder.CreateTable( + name: "Cases", + columns: table => new + { + Id = table.Column(type: "bigint", nullable: false) + .Annotation("SqlServer:Identity", "1, 1"), + TenantId = table.Column(type: "uniqueidentifier", nullable: false, defaultValue: new Guid("00000000-0000-0000-0000-000000000001")), + Title = table.Column(type: "nvarchar(255)", maxLength: 255, nullable: false), + Summary = table.Column(type: "nvarchar(max)", nullable: true), + Status = table.Column(type: "int", nullable: false), + Priority = table.Column(type: "int", nullable: false), + AssignedToUserId = table.Column(type: "uniqueidentifier", nullable: true), + CreatedByUserId = table.Column(type: "uniqueidentifier", nullable: true), + CreatedAt = table.Column(type: "datetimeoffset", nullable: false), + UpdatedAt = table.Column(type: "datetimeoffset", nullable: false), + ClosedAt = table.Column(type: "datetimeoffset", nullable: true), + MitreTechniques = table.Column(type: "nvarchar(max)", nullable: true) + }, + constraints: table => + { + table.PrimaryKey("PK_Cases", x => x.Id); + table.ForeignKey( + name: "FK_Cases_Tenants_TenantId", + column: x => x.TenantId, + principalTable: "Tenants", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + }); + + migrationBuilder.CreateTable( + name: "CaseAlerts", + columns: table => new + { + Id = table.Column(type: "bigint", nullable: false) + .Annotation("SqlServer:Identity", "1, 1"), + CaseId = table.Column(type: "bigint", nullable: false), + AlertId = table.Column(type: "bigint", nullable: false), + AddedAt = table.Column(type: "datetimeoffset", nullable: false), + AddedByUserId = table.Column(type: "uniqueidentifier", nullable: true) + }, + constraints: table => + { + table.PrimaryKey("PK_CaseAlerts", x => x.Id); + table.ForeignKey( + name: "FK_CaseAlerts_Cases_CaseId", + column: x => x.CaseId, + principalTable: "Cases", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + table.ForeignKey( + name: "FK_CaseAlerts_Alerts_AlertId", + column: x => x.AlertId, + principalTable: "Alerts", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }); + + migrationBuilder.CreateTable( + name: "CaseNotes", + columns: table => new + { + Id = table.Column(type: "bigint", nullable: false) + .Annotation("SqlServer:Identity", "1, 1"), + CaseId = table.Column(type: "bigint", nullable: false), + AuthorUserId = table.Column(type: "uniqueidentifier", nullable: true), + Body = table.Column(type: "nvarchar(max)", nullable: false), + CreatedAt = table.Column(type: "datetimeoffset", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_CaseNotes", x => x.Id); + table.ForeignKey( + name: "FK_CaseNotes_Cases_CaseId", + column: x => x.CaseId, + principalTable: "Cases", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }); + + migrationBuilder.CreateIndex( + name: "IX_ThreatIntelFeeds_TenantId_IsEnabled", + table: "ThreatIntelFeeds", + columns: new[] { "TenantId", "IsEnabled" }); + + migrationBuilder.CreateIndex( + name: "IX_ReputationCache_TenantId_Provider_IndicatorKind_IndicatorValue", + table: "ReputationCache", + columns: new[] { "TenantId", "Provider", "IndicatorKind", "IndicatorValue" }, + unique: true); + + migrationBuilder.CreateIndex( + name: "IX_ReputationCache_ExpiresAt", + table: "ReputationCache", + column: "ExpiresAt"); + + migrationBuilder.CreateIndex( + name: "IX_Cases_TenantId_Status_CreatedAt", + table: "Cases", + columns: new[] { "TenantId", "Status", "CreatedAt" }); + + migrationBuilder.CreateIndex( + name: "IX_CaseAlerts_CaseId_AlertId", + table: "CaseAlerts", + columns: new[] { "CaseId", "AlertId" }, + unique: true); + + migrationBuilder.CreateIndex( + name: "IX_CaseAlerts_AlertId", + table: "CaseAlerts", + column: "AlertId"); + + migrationBuilder.CreateIndex( + name: "IX_CaseNotes_CaseId_CreatedAt", + table: "CaseNotes", + columns: new[] { "CaseId", "CreatedAt" }); + } + + /// + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropTable(name: "CaseNotes"); + migrationBuilder.DropTable(name: "CaseAlerts"); + migrationBuilder.DropTable(name: "Cases"); + migrationBuilder.DropTable(name: "ReputationCache"); + migrationBuilder.DropTable(name: "ThreatIntelFeeds"); + migrationBuilder.DropColumn(name: "CompiledExpression", table: "AlertRules"); + migrationBuilder.DropColumn(name: "IsShared", table: "SavedHunts"); + migrationBuilder.DropColumn(name: "Enrichment", table: "Alerts"); + } + } +} diff --git a/backend/src/Tawny.Infrastructure/Migrations/20260524093710_SyncPhase3And4Model.Designer.cs b/backend/src/Tawny.Infrastructure/Migrations/20260524093710_SyncPhase3And4Model.Designer.cs new file mode 100644 index 0000000..136b4ea --- /dev/null +++ b/backend/src/Tawny.Infrastructure/Migrations/20260524093710_SyncPhase3And4Model.Designer.cs @@ -0,0 +1,1237 @@ +// +using System; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Infrastructure; +using Microsoft.EntityFrameworkCore.Metadata; +using Microsoft.EntityFrameworkCore.Migrations; +using Microsoft.EntityFrameworkCore.Storage.ValueConversion; +using Tawny.Infrastructure; + +#nullable disable + +namespace Tawny.Infrastructure.Migrations +{ + [DbContext(typeof(TawnyDbContext))] + [Migration("20260524093710_SyncPhase3And4Model")] + partial class SyncPhase3And4Model + { + /// + protected override void BuildTargetModel(ModelBuilder modelBuilder) + { +#pragma warning disable 612, 618 + modelBuilder + .HasAnnotation("ProductVersion", "10.0.0") + .HasAnnotation("Relational:MaxIdentifierLength", 128); + + SqlServerModelBuilderExtensions.UseIdentityColumns(modelBuilder); + + modelBuilder.Entity("Tawny.Domain.Entities.Agent", b => + { + b.Property("Id") + .ValueGeneratedOnAdd() + .HasColumnType("uniqueidentifier"); + + b.Property("AgentVersion") + .IsRequired() + .HasMaxLength(32) + .HasColumnType("nvarchar(32)"); + + b.Property("Architecture") + .HasColumnType("int"); + + b.Property("EnrolledAt") + .HasColumnType("datetimeoffset"); + + b.Property("Hostname") + .IsRequired() + .HasMaxLength(255) + .HasColumnType("nvarchar(255)"); + + b.Property("LastHeartbeatAt") + .HasColumnType("datetimeoffset"); + + b.Property("OperatingSystem") + .HasColumnType("int"); + + b.Property("OsVersion") + .IsRequired() + .HasMaxLength(64) + .HasColumnType("nvarchar(64)"); + + b.Property("PublicIp") + .HasMaxLength(64) + .HasColumnType("nvarchar(64)"); + + b.Property("Status") + .HasColumnType("int"); + + b.Property("TagsJson") + .IsRequired() + .ValueGeneratedOnAdd() + .HasColumnType("nvarchar(max)") + .HasDefaultValue("[]") + .HasColumnName("Tags"); + + b.Property("TenantId") + .ValueGeneratedOnAdd() + .HasColumnType("uniqueidentifier") + .HasDefaultValue(new Guid("00000000-0000-0000-0000-000000000001")); + + b.HasKey("Id"); + + b.HasIndex("TenantId", "Hostname"); + + b.HasIndex("TenantId", "LastHeartbeatAt"); + + b.ToTable("Agents"); + }); + + modelBuilder.Entity("Tawny.Domain.Entities.AgentRelease", b => + { + b.Property("Version") + .HasMaxLength(32) + .HasColumnType("nvarchar(32)"); + + b.Property("Platform") + .HasMaxLength(32) + .HasColumnType("nvarchar(32)"); + + b.Property("DownloadUrl") + .IsRequired() + .HasMaxLength(1024) + .HasColumnType("nvarchar(1024)"); + + b.Property("IsLatest") + .HasColumnType("bit"); + + b.Property("ReleasedAt") + .HasColumnType("datetimeoffset"); + + b.Property("Sha256") + .IsRequired() + .HasMaxLength(128) + .HasColumnType("nvarchar(128)"); + + b.HasKey("Version", "Platform"); + + b.HasIndex("Platform", "IsLatest"); + + b.ToTable("AgentReleases"); + }); + + modelBuilder.Entity("Tawny.Domain.Entities.Alert", b => + { + b.Property("Id") + .ValueGeneratedOnAdd() + .HasColumnType("bigint"); + + SqlServerPropertyBuilderExtensions.UseIdentityColumn(b.Property("Id")); + + b.Property("AgentId") + .HasColumnType("uniqueidentifier"); + + b.Property("AlertRuleId") + .HasColumnType("uniqueidentifier"); + + b.Property("CreatedAt") + .HasColumnType("datetimeoffset"); + + b.Property("Description") + .HasColumnType("nvarchar(max)"); + + b.Property("EnrichmentJson") + .HasColumnType("nvarchar(max)") + .HasColumnName("Enrichment"); + + b.Property("SentinelNotificationError") + .HasMaxLength(1024) + .HasColumnType("nvarchar(1024)"); + + b.Property("SentinelNotificationStatus") + .HasColumnType("int"); + + b.Property("SentinelNotifiedAt") + .HasColumnType("datetimeoffset"); + + b.Property("Severity") + .HasColumnType("int"); + + b.Property("SlackNotificationError") + .HasMaxLength(1024) + .HasColumnType("nvarchar(1024)"); + + b.Property("SlackNotificationStatus") + .HasColumnType("int"); + + b.Property("SlackNotifiedAt") + .HasColumnType("datetimeoffset"); + + b.Property("Status") + .HasColumnType("int"); + + b.Property("TelemetryEventId") + .HasColumnType("bigint"); + + b.Property("Title") + .IsRequired() + .HasMaxLength(255) + .HasColumnType("nvarchar(255)"); + + b.HasKey("Id"); + + b.HasIndex("AlertRuleId"); + + b.HasIndex("TelemetryEventId"); + + b.HasIndex("AgentId", "CreatedAt"); + + b.HasIndex("Status", "CreatedAt"); + + b.ToTable("Alerts"); + }); + + modelBuilder.Entity("Tawny.Domain.Entities.AlertRule", b => + { + b.Property("Id") + .ValueGeneratedOnAdd() + .HasColumnType("uniqueidentifier"); + + b.Property("CompiledExpressionJson") + .HasColumnType("nvarchar(max)") + .HasColumnName("CompiledExpression"); + + b.Property("CreatedAt") + .HasColumnType("datetimeoffset"); + + b.Property("Description") + .HasColumnType("nvarchar(max)"); + + b.Property("EventType") + .HasColumnType("int"); + + b.Property("ExternalId") + .HasMaxLength(128) + .HasColumnType("nvarchar(128)"); + + b.Property("Format") + .HasColumnType("int"); + + b.Property("IsEnabled") + .HasColumnType("bit"); + + b.Property("MatchValue") + .HasMaxLength(512) + .HasColumnType("nvarchar(512)"); + + b.Property("MitreTechniquesJson") + .HasColumnType("nvarchar(max)") + .HasColumnName("MitreTechniques"); + + b.Property("Name") + .IsRequired() + .HasMaxLength(160) + .HasColumnType("nvarchar(160)"); + + b.Property("Operator") + .HasColumnType("int"); + + b.Property("PayloadPath") + .HasMaxLength(256) + .HasColumnType("nvarchar(256)"); + + b.Property("Severity") + .HasColumnType("int"); + + b.Property("SourceDefinition") + .HasColumnType("nvarchar(max)"); + + b.Property("UpdatedAt") + .HasColumnType("datetimeoffset"); + + b.HasKey("Id"); + + b.HasIndex("Format", "ExternalId"); + + b.HasIndex("IsEnabled", "EventType"); + + b.ToTable("AlertRules"); + }); + + modelBuilder.Entity("Tawny.Domain.Entities.ApiToken", b => + { + b.Property("Id") + .ValueGeneratedOnAdd() + .HasColumnType("uniqueidentifier"); + + b.Property("CreatedAt") + .HasColumnType("datetimeoffset"); + + b.Property("CreatedByUserId") + .HasColumnType("uniqueidentifier"); + + b.Property("ExpiresAt") + .HasColumnType("datetimeoffset"); + + b.Property("LastUsedAt") + .HasColumnType("datetimeoffset"); + + b.Property("Name") + .IsRequired() + .HasMaxLength(160) + .HasColumnType("nvarchar(160)"); + + b.Property("RevokedAt") + .HasColumnType("datetimeoffset"); + + b.Property("Role") + .HasColumnType("int"); + + b.Property("TenantId") + .ValueGeneratedOnAdd() + .HasColumnType("uniqueidentifier") + .HasDefaultValue(new Guid("00000000-0000-0000-0000-000000000001")); + + b.Property("TokenHash") + .IsRequired() + .HasMaxLength(128) + .HasColumnType("nvarchar(128)"); + + b.Property("TokenPrefix") + .IsRequired() + .HasMaxLength(16) + .HasColumnType("nvarchar(16)"); + + b.HasKey("Id"); + + b.HasIndex("TokenHash") + .IsUnique(); + + b.HasIndex("TenantId", "CreatedAt"); + + b.ToTable("ApiTokens"); + }); + + modelBuilder.Entity("Tawny.Domain.Entities.AuditLog", b => + { + b.Property("Id") + .ValueGeneratedOnAdd() + .HasColumnType("bigint"); + + SqlServerPropertyBuilderExtensions.UseIdentityColumn(b.Property("Id")); + + b.Property("Action") + .IsRequired() + .HasMaxLength(64) + .HasColumnType("nvarchar(64)"); + + b.Property("MetadataJson") + .HasColumnType("nvarchar(max)") + .HasColumnName("Metadata"); + + b.Property("OccurredAt") + .HasColumnType("datetimeoffset"); + + b.Property("Target") + .HasMaxLength(255) + .HasColumnType("nvarchar(255)"); + + b.Property("TenantId") + .ValueGeneratedOnAdd() + .HasColumnType("uniqueidentifier") + .HasDefaultValue(new Guid("00000000-0000-0000-0000-000000000001")); + + b.Property("UserId") + .HasColumnType("uniqueidentifier"); + + b.HasKey("Id"); + + b.HasIndex("TenantId", "OccurredAt"); + + b.ToTable("AuditLog"); + }); + + modelBuilder.Entity("Tawny.Domain.Entities.Case", b => + { + b.Property("Id") + .ValueGeneratedOnAdd() + .HasColumnType("bigint"); + + SqlServerPropertyBuilderExtensions.UseIdentityColumn(b.Property("Id")); + + b.Property("AssignedToUserId") + .HasColumnType("uniqueidentifier"); + + b.Property("ClosedAt") + .HasColumnType("datetimeoffset"); + + b.Property("CreatedAt") + .HasColumnType("datetimeoffset"); + + b.Property("CreatedByUserId") + .HasColumnType("uniqueidentifier"); + + b.Property("MitreTechniquesJson") + .HasColumnType("nvarchar(max)") + .HasColumnName("MitreTechniques"); + + b.Property("Priority") + .HasColumnType("int"); + + b.Property("Status") + .HasColumnType("int"); + + b.Property("Summary") + .HasColumnType("nvarchar(max)"); + + b.Property("TenantId") + .ValueGeneratedOnAdd() + .HasColumnType("uniqueidentifier") + .HasDefaultValue(new Guid("00000000-0000-0000-0000-000000000001")); + + b.Property("Title") + .IsRequired() + .HasMaxLength(255) + .HasColumnType("nvarchar(255)"); + + b.Property("UpdatedAt") + .HasColumnType("datetimeoffset"); + + b.HasKey("Id"); + + b.HasIndex("TenantId", "Status", "CreatedAt"); + + b.ToTable("Cases"); + }); + + modelBuilder.Entity("Tawny.Domain.Entities.CaseAlert", b => + { + b.Property("Id") + .ValueGeneratedOnAdd() + .HasColumnType("bigint"); + + SqlServerPropertyBuilderExtensions.UseIdentityColumn(b.Property("Id")); + + b.Property("AddedAt") + .HasColumnType("datetimeoffset"); + + b.Property("AddedByUserId") + .HasColumnType("uniqueidentifier"); + + b.Property("AlertId") + .HasColumnType("bigint"); + + b.Property("CaseId") + .HasColumnType("bigint"); + + b.HasKey("Id"); + + b.HasIndex("AlertId"); + + b.HasIndex("CaseId", "AlertId") + .IsUnique(); + + b.ToTable("CaseAlerts"); + }); + + modelBuilder.Entity("Tawny.Domain.Entities.CaseNote", b => + { + b.Property("Id") + .ValueGeneratedOnAdd() + .HasColumnType("bigint"); + + SqlServerPropertyBuilderExtensions.UseIdentityColumn(b.Property("Id")); + + b.Property("AuthorUserId") + .HasColumnType("uniqueidentifier"); + + b.Property("Body") + .IsRequired() + .HasColumnType("nvarchar(max)"); + + b.Property("CaseId") + .HasColumnType("bigint"); + + b.Property("CreatedAt") + .HasColumnType("datetimeoffset"); + + b.HasKey("Id"); + + b.HasIndex("CaseId", "CreatedAt"); + + b.ToTable("CaseNotes"); + }); + + modelBuilder.Entity("Tawny.Domain.Entities.EnrollmentToken", b => + { + b.Property("Id") + .ValueGeneratedOnAdd() + .HasColumnType("uniqueidentifier"); + + b.Property("CreatedAt") + .HasColumnType("datetimeoffset"); + + b.Property("CreatedByUserId") + .HasColumnType("uniqueidentifier"); + + b.Property("ExpiresAt") + .HasColumnType("datetimeoffset"); + + b.Property("TenantId") + .ValueGeneratedOnAdd() + .HasColumnType("uniqueidentifier") + .HasDefaultValue(new Guid("00000000-0000-0000-0000-000000000001")); + + b.Property("TokenHash") + .IsRequired() + .HasMaxLength(128) + .HasColumnType("nvarchar(128)"); + + b.Property("UsedAt") + .HasColumnType("datetimeoffset"); + + b.Property("UsedByAgentId") + .HasColumnType("uniqueidentifier"); + + b.HasKey("Id"); + + b.HasIndex("TokenHash") + .IsUnique(); + + b.HasIndex("TenantId", "CreatedAt"); + + b.ToTable("EnrollmentTokens"); + }); + + modelBuilder.Entity("Tawny.Domain.Entities.HuntRun", b => + { + b.Property("Id") + .ValueGeneratedOnAdd() + .HasColumnType("bigint"); + + SqlServerPropertyBuilderExtensions.UseIdentityColumn(b.Property("Id")); + + b.Property("AlertsCreated") + .HasColumnType("int"); + + b.Property("CompletedAt") + .HasColumnType("datetimeoffset"); + + b.Property("ErrorMessage") + .HasMaxLength(1024) + .HasColumnType("nvarchar(1024)"); + + b.Property("MatchCount") + .HasColumnType("int"); + + b.Property("SavedHuntId") + .HasColumnType("uniqueidentifier"); + + b.Property("StartedAt") + .HasColumnType("datetimeoffset"); + + b.Property("Status") + .HasColumnType("int"); + + b.Property("TenantId") + .ValueGeneratedOnAdd() + .HasColumnType("uniqueidentifier") + .HasDefaultValue(new Guid("00000000-0000-0000-0000-000000000001")); + + b.Property("TriggeredByUserId") + .HasColumnType("uniqueidentifier"); + + b.HasKey("Id"); + + b.HasIndex("SavedHuntId"); + + b.HasIndex("TenantId", "SavedHuntId", "StartedAt") + .IsDescending(false, false, true); + + b.ToTable("HuntRuns"); + }); + + modelBuilder.Entity("Tawny.Domain.Entities.ReputationCacheEntry", b => + { + b.Property("Id") + .ValueGeneratedOnAdd() + .HasColumnType("bigint"); + + SqlServerPropertyBuilderExtensions.UseIdentityColumn(b.Property("Id")); + + b.Property("DetailJson") + .IsRequired() + .HasColumnType("nvarchar(max)"); + + b.Property("ExpiresAt") + .HasColumnType("datetimeoffset"); + + b.Property("FetchedAt") + .HasColumnType("datetimeoffset"); + + b.Property("IndicatorKind") + .IsRequired() + .HasMaxLength(32) + .HasColumnType("nvarchar(32)"); + + b.Property("IndicatorValue") + .IsRequired() + .HasMaxLength(512) + .HasColumnType("nvarchar(512)"); + + b.Property("Provider") + .HasColumnType("int"); + + b.Property("Score") + .HasColumnType("int"); + + b.Property("TenantId") + .ValueGeneratedOnAdd() + .HasColumnType("uniqueidentifier") + .HasDefaultValue(new Guid("00000000-0000-0000-0000-000000000001")); + + b.Property("Verdict") + .HasColumnType("int"); + + b.HasKey("Id"); + + b.HasIndex("ExpiresAt"); + + b.HasIndex("TenantId", "Provider", "IndicatorKind", "IndicatorValue") + .IsUnique(); + + b.ToTable("ReputationCache"); + }); + + modelBuilder.Entity("Tawny.Domain.Entities.ResponseAction", b => + { + b.Property("Id") + .ValueGeneratedOnAdd() + .HasColumnType("uniqueidentifier"); + + b.Property("ActionType") + .HasColumnType("int"); + + b.Property("AgentId") + .HasColumnType("uniqueidentifier"); + + b.Property("CompletedAt") + .HasColumnType("datetimeoffset"); + + b.Property("DispatchedAt") + .HasColumnType("datetimeoffset"); + + b.Property("PayloadJson") + .IsRequired() + .HasColumnType("nvarchar(max)") + .HasColumnName("Payload"); + + b.Property("RequestedAt") + .HasColumnType("datetimeoffset"); + + b.Property("RequestedByUserId") + .HasColumnType("uniqueidentifier"); + + b.Property("ResultJson") + .HasColumnType("nvarchar(max)") + .HasColumnName("Result"); + + b.Property("Status") + .HasColumnType("int"); + + b.HasKey("Id"); + + b.HasIndex("AgentId", "Status", "RequestedAt"); + + b.ToTable("ResponseActions"); + }); + + modelBuilder.Entity("Tawny.Domain.Entities.SavedHunt", b => + { + b.Property("Id") + .ValueGeneratedOnAdd() + .HasColumnType("uniqueidentifier"); + + b.Property("AlertOnMatch") + .HasColumnType("bit"); + + b.Property("AlertSeverity") + .HasColumnType("int"); + + b.Property("CreatedAt") + .HasColumnType("datetimeoffset"); + + b.Property("CreatedByUserId") + .HasColumnType("uniqueidentifier"); + + b.Property("Description") + .HasColumnType("nvarchar(max)"); + + b.Property("IsScheduled") + .HasColumnType("bit"); + + b.Property("IsShared") + .HasColumnType("bit"); + + b.Property("LastMatchCount") + .HasColumnType("int"); + + b.Property("LastRunAt") + .HasColumnType("datetimeoffset"); + + b.Property("MitreTechniquesJson") + .HasColumnType("nvarchar(max)") + .HasColumnName("MitreTechniques"); + + b.Property("Name") + .IsRequired() + .HasMaxLength(160) + .HasColumnType("nvarchar(160)"); + + b.Property("Query") + .IsRequired() + .HasColumnType("nvarchar(max)"); + + b.Property("ScheduleCron") + .HasMaxLength(64) + .HasColumnType("nvarchar(64)"); + + b.Property("TenantId") + .ValueGeneratedOnAdd() + .HasColumnType("uniqueidentifier") + .HasDefaultValue(new Guid("00000000-0000-0000-0000-000000000001")); + + b.Property("UpdatedAt") + .HasColumnType("datetimeoffset"); + + b.HasKey("Id"); + + b.HasIndex("TenantId", "IsScheduled"); + + b.HasIndex("TenantId", "Name") + .IsUnique(); + + b.ToTable("SavedHunts"); + }); + + modelBuilder.Entity("Tawny.Domain.Entities.SuppressionRule", b => + { + b.Property("Id") + .ValueGeneratedOnAdd() + .HasColumnType("uniqueidentifier"); + + b.Property("AgentId") + .HasColumnType("uniqueidentifier"); + + b.Property("AlertRuleId") + .HasColumnType("uniqueidentifier"); + + b.Property("CreatedAt") + .HasColumnType("datetimeoffset"); + + b.Property("CreatedByUserId") + .HasColumnType("uniqueidentifier"); + + b.Property("ExpiresAt") + .HasColumnType("datetimeoffset"); + + b.Property("IsEnabled") + .HasColumnType("bit"); + + b.Property("LastSuppressedAt") + .HasColumnType("datetimeoffset"); + + b.Property("MatchValue") + .HasMaxLength(512) + .HasColumnType("nvarchar(512)"); + + b.Property("Name") + .IsRequired() + .HasMaxLength(160) + .HasColumnType("nvarchar(160)"); + + b.Property("Operator") + .HasColumnType("int"); + + b.Property("PayloadPath") + .HasMaxLength(256) + .HasColumnType("nvarchar(256)"); + + b.Property("Reason") + .HasColumnType("nvarchar(max)"); + + b.Property("Scope") + .HasColumnType("int"); + + b.Property("SuppressedCount") + .HasColumnType("int"); + + b.Property("TenantId") + .ValueGeneratedOnAdd() + .HasColumnType("uniqueidentifier") + .HasDefaultValue(new Guid("00000000-0000-0000-0000-000000000001")); + + b.Property("UpdatedAt") + .HasColumnType("datetimeoffset"); + + b.HasKey("Id"); + + b.HasIndex("AgentId"); + + b.HasIndex("AlertRuleId"); + + b.HasIndex("TenantId", "IsEnabled"); + + b.ToTable("SuppressionRules"); + }); + + modelBuilder.Entity("Tawny.Domain.Entities.TelemetryEvent", b => + { + b.Property("Id") + .ValueGeneratedOnAdd() + .HasColumnType("bigint"); + + SqlServerPropertyBuilderExtensions.UseIdentityColumn(b.Property("Id")); + + b.Property("AgentId") + .HasColumnType("uniqueidentifier"); + + b.Property("EventType") + .HasColumnType("int"); + + b.Property("OccurredAt") + .HasColumnType("datetimeoffset"); + + b.Property("Payload") + .IsRequired() + .HasColumnType("nvarchar(max)"); + + b.Property("ReceivedAt") + .HasColumnType("datetimeoffset"); + + b.Property("TenantId") + .ValueGeneratedOnAdd() + .HasColumnType("uniqueidentifier") + .HasDefaultValue(new Guid("00000000-0000-0000-0000-000000000001")); + + b.HasKey("Id"); + + b.HasIndex("AgentId"); + + b.HasIndex("TenantId", "ReceivedAt"); + + b.HasIndex("TenantId", "AgentId", "EventType", "OccurredAt") + .IsDescending(false, false, false, true); + + b.ToTable("TelemetryEvents"); + }); + + modelBuilder.Entity("Tawny.Domain.Entities.Tenant", b => + { + b.Property("Id") + .ValueGeneratedOnAdd() + .HasColumnType("uniqueidentifier"); + + b.Property("CreatedAt") + .HasColumnType("datetimeoffset"); + + b.Property("Name") + .IsRequired() + .HasMaxLength(128) + .HasColumnType("nvarchar(128)"); + + b.Property("Slug") + .IsRequired() + .HasMaxLength(64) + .HasColumnType("nvarchar(64)"); + + b.HasKey("Id"); + + b.HasIndex("Slug") + .IsUnique(); + + b.ToTable("Tenants"); + + b.HasData( + new + { + Id = new Guid("00000000-0000-0000-0000-000000000001"), + CreatedAt = new DateTimeOffset(new DateTime(1970, 1, 1, 0, 0, 0, 0, DateTimeKind.Unspecified), new TimeSpan(0, 0, 0, 0, 0)), + Name = "Default tenant", + Slug = "default" + }); + }); + + modelBuilder.Entity("Tawny.Domain.Entities.ThreatIntelFeed", b => + { + b.Property("Id") + .ValueGeneratedOnAdd() + .HasColumnType("uniqueidentifier"); + + b.Property("AuthHeaderName") + .HasMaxLength(64) + .HasColumnType("nvarchar(64)"); + + b.Property("AuthHeaderValueEncrypted") + .HasMaxLength(1024) + .HasColumnType("nvarchar(1024)"); + + b.Property("CreatedAt") + .HasColumnType("datetimeoffset"); + + b.Property("CreatedByUserId") + .HasColumnType("uniqueidentifier"); + + b.Property("DefaultSeverity") + .HasColumnType("int"); + + b.Property("Etag") + .HasMaxLength(256) + .HasColumnType("nvarchar(256)"); + + b.Property("IntervalMinutes") + .HasColumnType("int"); + + b.Property("IsEnabled") + .HasColumnType("bit"); + + b.Property("Kind") + .HasColumnType("int"); + + b.Property("LastError") + .HasMaxLength(2048) + .HasColumnType("nvarchar(2048)"); + + b.Property("LastImportedCount") + .HasColumnType("int"); + + b.Property("LastRunAt") + .HasColumnType("datetimeoffset"); + + b.Property("LastSkippedCount") + .HasColumnType("int"); + + b.Property("LastSuccessAt") + .HasColumnType("datetimeoffset"); + + b.Property("Name") + .IsRequired() + .HasMaxLength(160) + .HasColumnType("nvarchar(160)"); + + b.Property("Status") + .HasColumnType("int"); + + b.Property("TenantId") + .ValueGeneratedOnAdd() + .HasColumnType("uniqueidentifier") + .HasDefaultValue(new Guid("00000000-0000-0000-0000-000000000001")); + + b.Property("UpdatedAt") + .HasColumnType("datetimeoffset"); + + b.Property("Url") + .IsRequired() + .HasMaxLength(1024) + .HasColumnType("nvarchar(1024)"); + + b.HasKey("Id"); + + b.HasIndex("TenantId", "IsEnabled"); + + b.ToTable("ThreatIntelFeeds"); + }); + + modelBuilder.Entity("Tawny.Domain.Entities.User", b => + { + b.Property("Id") + .ValueGeneratedOnAdd() + .HasColumnType("uniqueidentifier"); + + b.Property("CreatedAt") + .HasColumnType("datetimeoffset"); + + b.Property("Email") + .IsRequired() + .HasMaxLength(320) + .HasColumnType("nvarchar(320)"); + + b.Property("PasswordHash") + .IsRequired() + .HasMaxLength(512) + .HasColumnType("nvarchar(512)"); + + b.Property("Role") + .HasColumnType("int"); + + b.Property("TenantId") + .ValueGeneratedOnAdd() + .HasColumnType("uniqueidentifier") + .HasDefaultValue(new Guid("00000000-0000-0000-0000-000000000001")); + + b.HasKey("Id"); + + b.HasIndex("TenantId", "Email") + .IsUnique(); + + b.ToTable("Users"); + }); + + modelBuilder.Entity("Tawny.Domain.Entities.Agent", b => + { + b.HasOne("Tawny.Domain.Entities.Tenant", "Tenant") + .WithMany("Agents") + .HasForeignKey("TenantId") + .OnDelete(DeleteBehavior.Restrict) + .IsRequired(); + + b.Navigation("Tenant"); + }); + + modelBuilder.Entity("Tawny.Domain.Entities.Alert", b => + { + b.HasOne("Tawny.Domain.Entities.Agent", "Agent") + .WithMany() + .HasForeignKey("AgentId") + .OnDelete(DeleteBehavior.Restrict) + .IsRequired(); + + b.HasOne("Tawny.Domain.Entities.AlertRule", "AlertRule") + .WithMany("Alerts") + .HasForeignKey("AlertRuleId") + .OnDelete(DeleteBehavior.Restrict) + .IsRequired(); + + b.HasOne("Tawny.Domain.Entities.TelemetryEvent", "TelemetryEvent") + .WithMany() + .HasForeignKey("TelemetryEventId") + .OnDelete(DeleteBehavior.Cascade) + .IsRequired(); + + b.Navigation("Agent"); + + b.Navigation("AlertRule"); + + b.Navigation("TelemetryEvent"); + }); + + modelBuilder.Entity("Tawny.Domain.Entities.ApiToken", b => + { + b.HasOne("Tawny.Domain.Entities.Tenant", "Tenant") + .WithMany("ApiTokens") + .HasForeignKey("TenantId") + .OnDelete(DeleteBehavior.Restrict) + .IsRequired(); + + b.Navigation("Tenant"); + }); + + modelBuilder.Entity("Tawny.Domain.Entities.AuditLog", b => + { + b.HasOne("Tawny.Domain.Entities.Tenant", "Tenant") + .WithMany("AuditLog") + .HasForeignKey("TenantId") + .OnDelete(DeleteBehavior.Restrict) + .IsRequired(); + + b.Navigation("Tenant"); + }); + + modelBuilder.Entity("Tawny.Domain.Entities.Case", b => + { + b.HasOne("Tawny.Domain.Entities.Tenant", "Tenant") + .WithMany("Cases") + .HasForeignKey("TenantId") + .OnDelete(DeleteBehavior.Restrict) + .IsRequired(); + + b.Navigation("Tenant"); + }); + + modelBuilder.Entity("Tawny.Domain.Entities.CaseAlert", b => + { + b.HasOne("Tawny.Domain.Entities.Alert", "Alert") + .WithMany() + .HasForeignKey("AlertId") + .OnDelete(DeleteBehavior.Cascade) + .IsRequired(); + + b.HasOne("Tawny.Domain.Entities.Case", "Case") + .WithMany("CaseAlerts") + .HasForeignKey("CaseId") + .OnDelete(DeleteBehavior.Cascade) + .IsRequired(); + + b.Navigation("Alert"); + + b.Navigation("Case"); + }); + + modelBuilder.Entity("Tawny.Domain.Entities.CaseNote", b => + { + b.HasOne("Tawny.Domain.Entities.Case", "Case") + .WithMany("Notes") + .HasForeignKey("CaseId") + .OnDelete(DeleteBehavior.Cascade) + .IsRequired(); + + b.Navigation("Case"); + }); + + modelBuilder.Entity("Tawny.Domain.Entities.EnrollmentToken", b => + { + b.HasOne("Tawny.Domain.Entities.Tenant", "Tenant") + .WithMany("EnrollmentTokens") + .HasForeignKey("TenantId") + .OnDelete(DeleteBehavior.Restrict) + .IsRequired(); + + b.Navigation("Tenant"); + }); + + modelBuilder.Entity("Tawny.Domain.Entities.HuntRun", b => + { + b.HasOne("Tawny.Domain.Entities.SavedHunt", "SavedHunt") + .WithMany("Runs") + .HasForeignKey("SavedHuntId") + .OnDelete(DeleteBehavior.Cascade) + .IsRequired(); + + b.Navigation("SavedHunt"); + }); + + modelBuilder.Entity("Tawny.Domain.Entities.ResponseAction", b => + { + b.HasOne("Tawny.Domain.Entities.Agent", "Agent") + .WithMany() + .HasForeignKey("AgentId") + .OnDelete(DeleteBehavior.Cascade) + .IsRequired(); + + b.Navigation("Agent"); + }); + + modelBuilder.Entity("Tawny.Domain.Entities.SavedHunt", b => + { + b.HasOne("Tawny.Domain.Entities.Tenant", "Tenant") + .WithMany("SavedHunts") + .HasForeignKey("TenantId") + .OnDelete(DeleteBehavior.Restrict) + .IsRequired(); + + b.Navigation("Tenant"); + }); + + modelBuilder.Entity("Tawny.Domain.Entities.SuppressionRule", b => + { + b.HasOne("Tawny.Domain.Entities.Agent", "Agent") + .WithMany() + .HasForeignKey("AgentId") + .OnDelete(DeleteBehavior.SetNull); + + b.HasOne("Tawny.Domain.Entities.AlertRule", "AlertRule") + .WithMany() + .HasForeignKey("AlertRuleId") + .OnDelete(DeleteBehavior.Cascade); + + b.HasOne("Tawny.Domain.Entities.Tenant", "Tenant") + .WithMany("SuppressionRules") + .HasForeignKey("TenantId") + .OnDelete(DeleteBehavior.Restrict) + .IsRequired(); + + b.Navigation("Agent"); + + b.Navigation("AlertRule"); + + b.Navigation("Tenant"); + }); + + modelBuilder.Entity("Tawny.Domain.Entities.TelemetryEvent", b => + { + b.HasOne("Tawny.Domain.Entities.Agent", "Agent") + .WithMany("Events") + .HasForeignKey("AgentId") + .OnDelete(DeleteBehavior.Cascade) + .IsRequired(); + + b.HasOne("Tawny.Domain.Entities.Tenant", "Tenant") + .WithMany("TelemetryEvents") + .HasForeignKey("TenantId") + .OnDelete(DeleteBehavior.Restrict) + .IsRequired(); + + b.Navigation("Agent"); + + b.Navigation("Tenant"); + }); + + modelBuilder.Entity("Tawny.Domain.Entities.ThreatIntelFeed", b => + { + b.HasOne("Tawny.Domain.Entities.Tenant", "Tenant") + .WithMany("ThreatIntelFeeds") + .HasForeignKey("TenantId") + .OnDelete(DeleteBehavior.Restrict) + .IsRequired(); + + b.Navigation("Tenant"); + }); + + modelBuilder.Entity("Tawny.Domain.Entities.User", b => + { + b.HasOne("Tawny.Domain.Entities.Tenant", "Tenant") + .WithMany("Users") + .HasForeignKey("TenantId") + .OnDelete(DeleteBehavior.Restrict) + .IsRequired(); + + b.Navigation("Tenant"); + }); + + modelBuilder.Entity("Tawny.Domain.Entities.Agent", b => + { + b.Navigation("Events"); + }); + + modelBuilder.Entity("Tawny.Domain.Entities.AlertRule", b => + { + b.Navigation("Alerts"); + }); + + modelBuilder.Entity("Tawny.Domain.Entities.Case", b => + { + b.Navigation("CaseAlerts"); + + b.Navigation("Notes"); + }); + + modelBuilder.Entity("Tawny.Domain.Entities.SavedHunt", b => + { + b.Navigation("Runs"); + }); + + modelBuilder.Entity("Tawny.Domain.Entities.Tenant", b => + { + b.Navigation("Agents"); + + b.Navigation("ApiTokens"); + + b.Navigation("AuditLog"); + + b.Navigation("Cases"); + + b.Navigation("EnrollmentTokens"); + + b.Navigation("SavedHunts"); + + b.Navigation("SuppressionRules"); + + b.Navigation("TelemetryEvents"); + + b.Navigation("ThreatIntelFeeds"); + + b.Navigation("Users"); + }); +#pragma warning restore 612, 618 + } + } +} diff --git a/backend/src/Tawny.Infrastructure/Migrations/20260524093710_SyncPhase3And4Model.cs b/backend/src/Tawny.Infrastructure/Migrations/20260524093710_SyncPhase3And4Model.cs new file mode 100644 index 0000000..c9de4f7 --- /dev/null +++ b/backend/src/Tawny.Infrastructure/Migrations/20260524093710_SyncPhase3And4Model.cs @@ -0,0 +1,27 @@ +using Microsoft.EntityFrameworkCore.Migrations; + +#nullable disable + +namespace Tawny.Infrastructure.Migrations +{ + /// + public partial class SyncPhase3And4Model : Migration + { + /// + protected override void Up(MigrationBuilder migrationBuilder) + { + migrationBuilder.CreateIndex( + name: "IX_HuntRuns_SavedHuntId", + table: "HuntRuns", + column: "SavedHuntId"); + } + + /// + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropIndex( + name: "IX_HuntRuns_SavedHuntId", + table: "HuntRuns"); + } + } +} diff --git a/backend/src/Tawny.Infrastructure/Migrations/TawnyDbContextModelSnapshot.cs b/backend/src/Tawny.Infrastructure/Migrations/TawnyDbContextModelSnapshot.cs index 319f4d3..9771ef5 100644 --- a/backend/src/Tawny.Infrastructure/Migrations/TawnyDbContextModelSnapshot.cs +++ b/backend/src/Tawny.Infrastructure/Migrations/TawnyDbContextModelSnapshot.cs @@ -136,8 +136,9 @@ protected override void BuildModel(ModelBuilder modelBuilder) b.Property("Description") .HasColumnType("nvarchar(max)"); - b.Property("Severity") - .HasColumnType("int"); + b.Property("EnrichmentJson") + .HasColumnType("nvarchar(max)") + .HasColumnName("Enrichment"); b.Property("SentinelNotificationError") .HasMaxLength(1024) @@ -149,6 +150,9 @@ protected override void BuildModel(ModelBuilder modelBuilder) b.Property("SentinelNotifiedAt") .HasColumnType("datetimeoffset"); + b.Property("Severity") + .HasColumnType("int"); + b.Property("SlackNotificationError") .HasMaxLength(1024) .HasColumnType("nvarchar(1024)"); @@ -189,6 +193,10 @@ protected override void BuildModel(ModelBuilder modelBuilder) .ValueGeneratedOnAdd() .HasColumnType("uniqueidentifier"); + b.Property("CompiledExpressionJson") + .HasColumnType("nvarchar(max)") + .HasColumnName("CompiledExpression"); + b.Property("CreatedAt") .HasColumnType("datetimeoffset"); @@ -246,72 +254,61 @@ protected override void BuildModel(ModelBuilder modelBuilder) b.ToTable("AlertRules"); }); - modelBuilder.Entity("Tawny.Domain.Entities.SavedHunt", b => + modelBuilder.Entity("Tawny.Domain.Entities.ApiToken", b => { b.Property("Id") .ValueGeneratedOnAdd() .HasColumnType("uniqueidentifier"); - b.Property("AlertOnMatch") - .HasColumnType("bit"); - - b.Property("AlertSeverity") - .HasColumnType("int"); - b.Property("CreatedAt") .HasColumnType("datetimeoffset"); b.Property("CreatedByUserId") .HasColumnType("uniqueidentifier"); - b.Property("Description") - .HasColumnType("nvarchar(max)"); - - b.Property("IsScheduled") - .HasColumnType("bit"); - - b.Property("LastMatchCount") - .HasColumnType("int"); - - b.Property("LastRunAt") + b.Property("ExpiresAt") .HasColumnType("datetimeoffset"); - b.Property("MitreTechniquesJson") - .HasColumnType("nvarchar(max)") - .HasColumnName("MitreTechniques"); + b.Property("LastUsedAt") + .HasColumnType("datetimeoffset"); b.Property("Name") .IsRequired() .HasMaxLength(160) .HasColumnType("nvarchar(160)"); - b.Property("Query") - .IsRequired() - .HasColumnType("nvarchar(max)"); + b.Property("RevokedAt") + .HasColumnType("datetimeoffset"); - b.Property("ScheduleCron") - .HasMaxLength(64) - .HasColumnType("nvarchar(64)"); + b.Property("Role") + .HasColumnType("int"); b.Property("TenantId") .ValueGeneratedOnAdd() .HasColumnType("uniqueidentifier") .HasDefaultValue(new Guid("00000000-0000-0000-0000-000000000001")); - b.Property("UpdatedAt") - .HasColumnType("datetimeoffset"); + b.Property("TokenHash") + .IsRequired() + .HasMaxLength(128) + .HasColumnType("nvarchar(128)"); - b.HasKey("Id"); + b.Property("TokenPrefix") + .IsRequired() + .HasMaxLength(16) + .HasColumnType("nvarchar(16)"); - b.HasIndex("TenantId", "IsScheduled"); + b.HasKey("Id"); - b.HasIndex("TenantId", "Name") + b.HasIndex("TokenHash") .IsUnique(); - b.ToTable("SavedHunts"); + b.HasIndex("TenantId", "CreatedAt"); + + b.ToTable("ApiTokens"); }); - modelBuilder.Entity("Tawny.Domain.Entities.HuntRun", b => + modelBuilder.Entity("Tawny.Domain.Entities.AuditLog", b => { b.Property("Id") .ValueGeneratedOnAdd() @@ -319,170 +316,121 @@ protected override void BuildModel(ModelBuilder modelBuilder) SqlServerPropertyBuilderExtensions.UseIdentityColumn(b.Property("Id")); - b.Property("AlertsCreated") - .HasColumnType("int"); - - b.Property("CompletedAt") - .HasColumnType("datetimeoffset"); - - b.Property("ErrorMessage") - .HasMaxLength(1024) - .HasColumnType("nvarchar(1024)"); - - b.Property("MatchCount") - .HasColumnType("int"); + b.Property("Action") + .IsRequired() + .HasMaxLength(64) + .HasColumnType("nvarchar(64)"); - b.Property("SavedHuntId") - .HasColumnType("uniqueidentifier"); + b.Property("MetadataJson") + .HasColumnType("nvarchar(max)") + .HasColumnName("Metadata"); - b.Property("StartedAt") + b.Property("OccurredAt") .HasColumnType("datetimeoffset"); - b.Property("Status") - .HasColumnType("int"); + b.Property("Target") + .HasMaxLength(255) + .HasColumnType("nvarchar(255)"); b.Property("TenantId") .ValueGeneratedOnAdd() .HasColumnType("uniqueidentifier") .HasDefaultValue(new Guid("00000000-0000-0000-0000-000000000001")); - b.Property("TriggeredByUserId") + b.Property("UserId") .HasColumnType("uniqueidentifier"); b.HasKey("Id"); - b.HasIndex("TenantId", "SavedHuntId", "StartedAt") - .IsDescending(false, false, true); + b.HasIndex("TenantId", "OccurredAt"); - b.ToTable("HuntRuns"); + b.ToTable("AuditLog"); }); - modelBuilder.Entity("Tawny.Domain.Entities.SuppressionRule", b => + modelBuilder.Entity("Tawny.Domain.Entities.Case", b => { - b.Property("Id") + b.Property("Id") .ValueGeneratedOnAdd() - .HasColumnType("uniqueidentifier"); + .HasColumnType("bigint"); - b.Property("AgentId") - .HasColumnType("uniqueidentifier"); + SqlServerPropertyBuilderExtensions.UseIdentityColumn(b.Property("Id")); - b.Property("AlertRuleId") + b.Property("AssignedToUserId") .HasColumnType("uniqueidentifier"); + b.Property("ClosedAt") + .HasColumnType("datetimeoffset"); + b.Property("CreatedAt") .HasColumnType("datetimeoffset"); b.Property("CreatedByUserId") .HasColumnType("uniqueidentifier"); - b.Property("ExpiresAt") - .HasColumnType("datetimeoffset"); - - b.Property("IsEnabled") - .HasColumnType("bit"); - - b.Property("LastSuppressedAt") - .HasColumnType("datetimeoffset"); - - b.Property("MatchValue") - .HasMaxLength(512) - .HasColumnType("nvarchar(512)"); - - b.Property("Name") - .IsRequired() - .HasMaxLength(160) - .HasColumnType("nvarchar(160)"); + b.Property("MitreTechniquesJson") + .HasColumnType("nvarchar(max)") + .HasColumnName("MitreTechniques"); - b.Property("Operator") + b.Property("Priority") .HasColumnType("int"); - b.Property("PayloadPath") - .HasMaxLength(256) - .HasColumnType("nvarchar(256)"); - - b.Property("Reason") - .HasColumnType("nvarchar(max)"); - - b.Property("Scope") + b.Property("Status") .HasColumnType("int"); - b.Property("SuppressedCount") - .HasColumnType("int"); + b.Property("Summary") + .HasColumnType("nvarchar(max)"); b.Property("TenantId") .ValueGeneratedOnAdd() .HasColumnType("uniqueidentifier") .HasDefaultValue(new Guid("00000000-0000-0000-0000-000000000001")); + b.Property("Title") + .IsRequired() + .HasMaxLength(255) + .HasColumnType("nvarchar(255)"); + b.Property("UpdatedAt") .HasColumnType("datetimeoffset"); b.HasKey("Id"); - b.HasIndex("AgentId"); - - b.HasIndex("AlertRuleId"); - - b.HasIndex("TenantId", "IsEnabled"); + b.HasIndex("TenantId", "Status", "CreatedAt"); - b.ToTable("SuppressionRules"); + b.ToTable("Cases"); }); - modelBuilder.Entity("Tawny.Domain.Entities.ApiToken", b => + modelBuilder.Entity("Tawny.Domain.Entities.CaseAlert", b => { - b.Property("Id") + b.Property("Id") .ValueGeneratedOnAdd() - .HasColumnType("uniqueidentifier"); - - b.Property("CreatedAt") - .HasColumnType("datetimeoffset"); - - b.Property("CreatedByUserId") - .HasColumnType("uniqueidentifier"); - - b.Property("ExpiresAt") - .HasColumnType("datetimeoffset"); - - b.Property("LastUsedAt") - .HasColumnType("datetimeoffset"); + .HasColumnType("bigint"); - b.Property("Name") - .IsRequired() - .HasMaxLength(160) - .HasColumnType("nvarchar(160)"); + SqlServerPropertyBuilderExtensions.UseIdentityColumn(b.Property("Id")); - b.Property("RevokedAt") + b.Property("AddedAt") .HasColumnType("datetimeoffset"); - b.Property("Role") - .HasColumnType("int"); - - b.Property("TenantId") - .ValueGeneratedOnAdd() - .HasColumnType("uniqueidentifier") - .HasDefaultValue(new Guid("00000000-0000-0000-0000-000000000001")); + b.Property("AddedByUserId") + .HasColumnType("uniqueidentifier"); - b.Property("TokenHash") - .IsRequired() - .HasMaxLength(128) - .HasColumnType("nvarchar(128)"); + b.Property("AlertId") + .HasColumnType("bigint"); - b.Property("TokenPrefix") - .IsRequired() - .HasMaxLength(16) - .HasColumnType("nvarchar(16)"); + b.Property("CaseId") + .HasColumnType("bigint"); b.HasKey("Id"); - b.HasIndex("TokenHash") - .IsUnique(); + b.HasIndex("AlertId"); - b.HasIndex("TenantId", "CreatedAt"); + b.HasIndex("CaseId", "AlertId") + .IsUnique(); - b.ToTable("ApiTokens"); + b.ToTable("CaseAlerts"); }); - modelBuilder.Entity("Tawny.Domain.Entities.AuditLog", b => + modelBuilder.Entity("Tawny.Domain.Entities.CaseNote", b => { b.Property("Id") .ValueGeneratedOnAdd() @@ -490,35 +438,24 @@ protected override void BuildModel(ModelBuilder modelBuilder) SqlServerPropertyBuilderExtensions.UseIdentityColumn(b.Property("Id")); - b.Property("Action") + b.Property("AuthorUserId") + .HasColumnType("uniqueidentifier"); + + b.Property("Body") .IsRequired() - .HasMaxLength(64) - .HasColumnType("nvarchar(64)"); + .HasColumnType("nvarchar(max)"); - b.Property("MetadataJson") - .HasColumnType("nvarchar(max)") - .HasColumnName("Metadata"); + b.Property("CaseId") + .HasColumnType("bigint"); - b.Property("OccurredAt") + b.Property("CreatedAt") .HasColumnType("datetimeoffset"); - b.Property("Target") - .HasMaxLength(255) - .HasColumnType("nvarchar(255)"); - - b.Property("TenantId") - .ValueGeneratedOnAdd() - .HasColumnType("uniqueidentifier") - .HasDefaultValue(new Guid("00000000-0000-0000-0000-000000000001")); - - b.Property("UserId") - .HasColumnType("uniqueidentifier"); - b.HasKey("Id"); - b.HasIndex("TenantId", "OccurredAt"); + b.HasIndex("CaseId", "CreatedAt"); - b.ToTable("AuditLog"); + b.ToTable("CaseNotes"); }); modelBuilder.Entity("Tawny.Domain.Entities.EnrollmentToken", b => @@ -562,36 +499,136 @@ protected override void BuildModel(ModelBuilder modelBuilder) b.ToTable("EnrollmentTokens"); }); - modelBuilder.Entity("Tawny.Domain.Entities.ResponseAction", b => + modelBuilder.Entity("Tawny.Domain.Entities.HuntRun", b => { - b.Property("Id") + b.Property("Id") .ValueGeneratedOnAdd() - .HasColumnType("uniqueidentifier"); + .HasColumnType("bigint"); - b.Property("ActionType") - .HasColumnType("int"); + SqlServerPropertyBuilderExtensions.UseIdentityColumn(b.Property("Id")); - b.Property("AgentId") - .HasColumnType("uniqueidentifier"); + b.Property("AlertsCreated") + .HasColumnType("int"); b.Property("CompletedAt") .HasColumnType("datetimeoffset"); - b.Property("DispatchedAt") - .HasColumnType("datetimeoffset"); - - b.Property("PayloadJson") - .IsRequired() - .HasColumnType("nvarchar(max)") - .HasColumnName("Payload"); + b.Property("ErrorMessage") + .HasMaxLength(1024) + .HasColumnType("nvarchar(1024)"); - b.Property("RequestedAt") - .HasColumnType("datetimeoffset"); + b.Property("MatchCount") + .HasColumnType("int"); - b.Property("RequestedByUserId") + b.Property("SavedHuntId") .HasColumnType("uniqueidentifier"); - b.Property("ResultJson") + b.Property("StartedAt") + .HasColumnType("datetimeoffset"); + + b.Property("Status") + .HasColumnType("int"); + + b.Property("TenantId") + .ValueGeneratedOnAdd() + .HasColumnType("uniqueidentifier") + .HasDefaultValue(new Guid("00000000-0000-0000-0000-000000000001")); + + b.Property("TriggeredByUserId") + .HasColumnType("uniqueidentifier"); + + b.HasKey("Id"); + + b.HasIndex("SavedHuntId"); + + b.HasIndex("TenantId", "SavedHuntId", "StartedAt") + .IsDescending(false, false, true); + + b.ToTable("HuntRuns"); + }); + + modelBuilder.Entity("Tawny.Domain.Entities.ReputationCacheEntry", b => + { + b.Property("Id") + .ValueGeneratedOnAdd() + .HasColumnType("bigint"); + + SqlServerPropertyBuilderExtensions.UseIdentityColumn(b.Property("Id")); + + b.Property("DetailJson") + .IsRequired() + .HasColumnType("nvarchar(max)"); + + b.Property("ExpiresAt") + .HasColumnType("datetimeoffset"); + + b.Property("FetchedAt") + .HasColumnType("datetimeoffset"); + + b.Property("IndicatorKind") + .IsRequired() + .HasMaxLength(32) + .HasColumnType("nvarchar(32)"); + + b.Property("IndicatorValue") + .IsRequired() + .HasMaxLength(512) + .HasColumnType("nvarchar(512)"); + + b.Property("Provider") + .HasColumnType("int"); + + b.Property("Score") + .HasColumnType("int"); + + b.Property("TenantId") + .ValueGeneratedOnAdd() + .HasColumnType("uniqueidentifier") + .HasDefaultValue(new Guid("00000000-0000-0000-0000-000000000001")); + + b.Property("Verdict") + .HasColumnType("int"); + + b.HasKey("Id"); + + b.HasIndex("ExpiresAt"); + + b.HasIndex("TenantId", "Provider", "IndicatorKind", "IndicatorValue") + .IsUnique(); + + b.ToTable("ReputationCache"); + }); + + modelBuilder.Entity("Tawny.Domain.Entities.ResponseAction", b => + { + b.Property("Id") + .ValueGeneratedOnAdd() + .HasColumnType("uniqueidentifier"); + + b.Property("ActionType") + .HasColumnType("int"); + + b.Property("AgentId") + .HasColumnType("uniqueidentifier"); + + b.Property("CompletedAt") + .HasColumnType("datetimeoffset"); + + b.Property("DispatchedAt") + .HasColumnType("datetimeoffset"); + + b.Property("PayloadJson") + .IsRequired() + .HasColumnType("nvarchar(max)") + .HasColumnName("Payload"); + + b.Property("RequestedAt") + .HasColumnType("datetimeoffset"); + + b.Property("RequestedByUserId") + .HasColumnType("uniqueidentifier"); + + b.Property("ResultJson") .HasColumnType("nvarchar(max)") .HasColumnName("Result"); @@ -605,6 +642,145 @@ protected override void BuildModel(ModelBuilder modelBuilder) b.ToTable("ResponseActions"); }); + modelBuilder.Entity("Tawny.Domain.Entities.SavedHunt", b => + { + b.Property("Id") + .ValueGeneratedOnAdd() + .HasColumnType("uniqueidentifier"); + + b.Property("AlertOnMatch") + .HasColumnType("bit"); + + b.Property("AlertSeverity") + .HasColumnType("int"); + + b.Property("CreatedAt") + .HasColumnType("datetimeoffset"); + + b.Property("CreatedByUserId") + .HasColumnType("uniqueidentifier"); + + b.Property("Description") + .HasColumnType("nvarchar(max)"); + + b.Property("IsScheduled") + .HasColumnType("bit"); + + b.Property("IsShared") + .HasColumnType("bit"); + + b.Property("LastMatchCount") + .HasColumnType("int"); + + b.Property("LastRunAt") + .HasColumnType("datetimeoffset"); + + b.Property("MitreTechniquesJson") + .HasColumnType("nvarchar(max)") + .HasColumnName("MitreTechniques"); + + b.Property("Name") + .IsRequired() + .HasMaxLength(160) + .HasColumnType("nvarchar(160)"); + + b.Property("Query") + .IsRequired() + .HasColumnType("nvarchar(max)"); + + b.Property("ScheduleCron") + .HasMaxLength(64) + .HasColumnType("nvarchar(64)"); + + b.Property("TenantId") + .ValueGeneratedOnAdd() + .HasColumnType("uniqueidentifier") + .HasDefaultValue(new Guid("00000000-0000-0000-0000-000000000001")); + + b.Property("UpdatedAt") + .HasColumnType("datetimeoffset"); + + b.HasKey("Id"); + + b.HasIndex("TenantId", "IsScheduled"); + + b.HasIndex("TenantId", "Name") + .IsUnique(); + + b.ToTable("SavedHunts"); + }); + + modelBuilder.Entity("Tawny.Domain.Entities.SuppressionRule", b => + { + b.Property("Id") + .ValueGeneratedOnAdd() + .HasColumnType("uniqueidentifier"); + + b.Property("AgentId") + .HasColumnType("uniqueidentifier"); + + b.Property("AlertRuleId") + .HasColumnType("uniqueidentifier"); + + b.Property("CreatedAt") + .HasColumnType("datetimeoffset"); + + b.Property("CreatedByUserId") + .HasColumnType("uniqueidentifier"); + + b.Property("ExpiresAt") + .HasColumnType("datetimeoffset"); + + b.Property("IsEnabled") + .HasColumnType("bit"); + + b.Property("LastSuppressedAt") + .HasColumnType("datetimeoffset"); + + b.Property("MatchValue") + .HasMaxLength(512) + .HasColumnType("nvarchar(512)"); + + b.Property("Name") + .IsRequired() + .HasMaxLength(160) + .HasColumnType("nvarchar(160)"); + + b.Property("Operator") + .HasColumnType("int"); + + b.Property("PayloadPath") + .HasMaxLength(256) + .HasColumnType("nvarchar(256)"); + + b.Property("Reason") + .HasColumnType("nvarchar(max)"); + + b.Property("Scope") + .HasColumnType("int"); + + b.Property("SuppressedCount") + .HasColumnType("int"); + + b.Property("TenantId") + .ValueGeneratedOnAdd() + .HasColumnType("uniqueidentifier") + .HasDefaultValue(new Guid("00000000-0000-0000-0000-000000000001")); + + b.Property("UpdatedAt") + .HasColumnType("datetimeoffset"); + + b.HasKey("Id"); + + b.HasIndex("AgentId"); + + b.HasIndex("AlertRuleId"); + + b.HasIndex("TenantId", "IsEnabled"); + + b.ToTable("SuppressionRules"); + }); + modelBuilder.Entity("Tawny.Domain.Entities.TelemetryEvent", b => { b.Property("Id") @@ -682,6 +858,86 @@ protected override void BuildModel(ModelBuilder modelBuilder) }); }); + modelBuilder.Entity("Tawny.Domain.Entities.ThreatIntelFeed", b => + { + b.Property("Id") + .ValueGeneratedOnAdd() + .HasColumnType("uniqueidentifier"); + + b.Property("AuthHeaderName") + .HasMaxLength(64) + .HasColumnType("nvarchar(64)"); + + b.Property("AuthHeaderValueEncrypted") + .HasMaxLength(1024) + .HasColumnType("nvarchar(1024)"); + + b.Property("CreatedAt") + .HasColumnType("datetimeoffset"); + + b.Property("CreatedByUserId") + .HasColumnType("uniqueidentifier"); + + b.Property("DefaultSeverity") + .HasColumnType("int"); + + b.Property("Etag") + .HasMaxLength(256) + .HasColumnType("nvarchar(256)"); + + b.Property("IntervalMinutes") + .HasColumnType("int"); + + b.Property("IsEnabled") + .HasColumnType("bit"); + + b.Property("Kind") + .HasColumnType("int"); + + b.Property("LastError") + .HasMaxLength(2048) + .HasColumnType("nvarchar(2048)"); + + b.Property("LastImportedCount") + .HasColumnType("int"); + + b.Property("LastRunAt") + .HasColumnType("datetimeoffset"); + + b.Property("LastSkippedCount") + .HasColumnType("int"); + + b.Property("LastSuccessAt") + .HasColumnType("datetimeoffset"); + + b.Property("Name") + .IsRequired() + .HasMaxLength(160) + .HasColumnType("nvarchar(160)"); + + b.Property("Status") + .HasColumnType("int"); + + b.Property("TenantId") + .ValueGeneratedOnAdd() + .HasColumnType("uniqueidentifier") + .HasDefaultValue(new Guid("00000000-0000-0000-0000-000000000001")); + + b.Property("UpdatedAt") + .HasColumnType("datetimeoffset"); + + b.Property("Url") + .IsRequired() + .HasMaxLength(1024) + .HasColumnType("nvarchar(1024)"); + + b.HasKey("Id"); + + b.HasIndex("TenantId", "IsEnabled"); + + b.ToTable("ThreatIntelFeeds"); + }); + modelBuilder.Entity("Tawny.Domain.Entities.User", b => { b.Property("Id") @@ -728,28 +984,6 @@ protected override void BuildModel(ModelBuilder modelBuilder) b.Navigation("Tenant"); }); - modelBuilder.Entity("Tawny.Domain.Entities.AuditLog", b => - { - b.HasOne("Tawny.Domain.Entities.Tenant", "Tenant") - .WithMany("AuditLog") - .HasForeignKey("TenantId") - .OnDelete(DeleteBehavior.Restrict) - .IsRequired(); - - b.Navigation("Tenant"); - }); - - modelBuilder.Entity("Tawny.Domain.Entities.EnrollmentToken", b => - { - b.HasOne("Tawny.Domain.Entities.Tenant", "Tenant") - .WithMany("EnrollmentTokens") - .HasForeignKey("TenantId") - .OnDelete(DeleteBehavior.Restrict) - .IsRequired(); - - b.Navigation("Tenant"); - }); - modelBuilder.Entity("Tawny.Domain.Entities.Alert", b => { b.HasOne("Tawny.Domain.Entities.Agent", "Agent") @@ -777,40 +1011,32 @@ protected override void BuildModel(ModelBuilder modelBuilder) b.Navigation("TelemetryEvent"); }); - modelBuilder.Entity("Tawny.Domain.Entities.ResponseAction", b => + modelBuilder.Entity("Tawny.Domain.Entities.ApiToken", b => { - b.HasOne("Tawny.Domain.Entities.Agent", "Agent") - .WithMany() - .HasForeignKey("AgentId") - .OnDelete(DeleteBehavior.Cascade) + b.HasOne("Tawny.Domain.Entities.Tenant", "Tenant") + .WithMany("ApiTokens") + .HasForeignKey("TenantId") + .OnDelete(DeleteBehavior.Restrict) .IsRequired(); - b.Navigation("Agent"); + b.Navigation("Tenant"); }); - modelBuilder.Entity("Tawny.Domain.Entities.TelemetryEvent", b => + modelBuilder.Entity("Tawny.Domain.Entities.AuditLog", b => { - b.HasOne("Tawny.Domain.Entities.Agent", "Agent") - .WithMany("Events") - .HasForeignKey("AgentId") - .OnDelete(DeleteBehavior.Cascade) - .IsRequired(); - b.HasOne("Tawny.Domain.Entities.Tenant", "Tenant") - .WithMany("TelemetryEvents") + .WithMany("AuditLog") .HasForeignKey("TenantId") .OnDelete(DeleteBehavior.Restrict) .IsRequired(); - b.Navigation("Agent"); - b.Navigation("Tenant"); }); - modelBuilder.Entity("Tawny.Domain.Entities.User", b => + modelBuilder.Entity("Tawny.Domain.Entities.Case", b => { b.HasOne("Tawny.Domain.Entities.Tenant", "Tenant") - .WithMany("Users") + .WithMany("Cases") .HasForeignKey("TenantId") .OnDelete(DeleteBehavior.Restrict) .IsRequired(); @@ -818,10 +1044,40 @@ protected override void BuildModel(ModelBuilder modelBuilder) b.Navigation("Tenant"); }); - modelBuilder.Entity("Tawny.Domain.Entities.SavedHunt", b => + modelBuilder.Entity("Tawny.Domain.Entities.CaseAlert", b => + { + b.HasOne("Tawny.Domain.Entities.Alert", "Alert") + .WithMany() + .HasForeignKey("AlertId") + .OnDelete(DeleteBehavior.Cascade) + .IsRequired(); + + b.HasOne("Tawny.Domain.Entities.Case", "Case") + .WithMany("CaseAlerts") + .HasForeignKey("CaseId") + .OnDelete(DeleteBehavior.Cascade) + .IsRequired(); + + b.Navigation("Alert"); + + b.Navigation("Case"); + }); + + modelBuilder.Entity("Tawny.Domain.Entities.CaseNote", b => + { + b.HasOne("Tawny.Domain.Entities.Case", "Case") + .WithMany("Notes") + .HasForeignKey("CaseId") + .OnDelete(DeleteBehavior.Cascade) + .IsRequired(); + + b.Navigation("Case"); + }); + + modelBuilder.Entity("Tawny.Domain.Entities.EnrollmentToken", b => { b.HasOne("Tawny.Domain.Entities.Tenant", "Tenant") - .WithMany("SavedHunts") + .WithMany("EnrollmentTokens") .HasForeignKey("TenantId") .OnDelete(DeleteBehavior.Restrict) .IsRequired(); @@ -840,6 +1096,28 @@ protected override void BuildModel(ModelBuilder modelBuilder) b.Navigation("SavedHunt"); }); + modelBuilder.Entity("Tawny.Domain.Entities.ResponseAction", b => + { + b.HasOne("Tawny.Domain.Entities.Agent", "Agent") + .WithMany() + .HasForeignKey("AgentId") + .OnDelete(DeleteBehavior.Cascade) + .IsRequired(); + + b.Navigation("Agent"); + }); + + modelBuilder.Entity("Tawny.Domain.Entities.SavedHunt", b => + { + b.HasOne("Tawny.Domain.Entities.Tenant", "Tenant") + .WithMany("SavedHunts") + .HasForeignKey("TenantId") + .OnDelete(DeleteBehavior.Restrict) + .IsRequired(); + + b.Navigation("Tenant"); + }); + modelBuilder.Entity("Tawny.Domain.Entities.SuppressionRule", b => { b.HasOne("Tawny.Domain.Entities.Agent", "Agent") @@ -865,10 +1143,40 @@ protected override void BuildModel(ModelBuilder modelBuilder) b.Navigation("Tenant"); }); - modelBuilder.Entity("Tawny.Domain.Entities.ApiToken", b => + modelBuilder.Entity("Tawny.Domain.Entities.TelemetryEvent", b => { + b.HasOne("Tawny.Domain.Entities.Agent", "Agent") + .WithMany("Events") + .HasForeignKey("AgentId") + .OnDelete(DeleteBehavior.Cascade) + .IsRequired(); + b.HasOne("Tawny.Domain.Entities.Tenant", "Tenant") - .WithMany("ApiTokens") + .WithMany("TelemetryEvents") + .HasForeignKey("TenantId") + .OnDelete(DeleteBehavior.Restrict) + .IsRequired(); + + b.Navigation("Agent"); + + b.Navigation("Tenant"); + }); + + modelBuilder.Entity("Tawny.Domain.Entities.ThreatIntelFeed", b => + { + b.HasOne("Tawny.Domain.Entities.Tenant", "Tenant") + .WithMany("ThreatIntelFeeds") + .HasForeignKey("TenantId") + .OnDelete(DeleteBehavior.Restrict) + .IsRequired(); + + b.Navigation("Tenant"); + }); + + modelBuilder.Entity("Tawny.Domain.Entities.User", b => + { + b.HasOne("Tawny.Domain.Entities.Tenant", "Tenant") + .WithMany("Users") .HasForeignKey("TenantId") .OnDelete(DeleteBehavior.Restrict) .IsRequired(); @@ -881,6 +1189,23 @@ protected override void BuildModel(ModelBuilder modelBuilder) b.Navigation("Events"); }); + modelBuilder.Entity("Tawny.Domain.Entities.AlertRule", b => + { + b.Navigation("Alerts"); + }); + + modelBuilder.Entity("Tawny.Domain.Entities.Case", b => + { + b.Navigation("CaseAlerts"); + + b.Navigation("Notes"); + }); + + modelBuilder.Entity("Tawny.Domain.Entities.SavedHunt", b => + { + b.Navigation("Runs"); + }); + modelBuilder.Entity("Tawny.Domain.Entities.Tenant", b => { b.Navigation("Agents"); @@ -889,6 +1214,8 @@ protected override void BuildModel(ModelBuilder modelBuilder) b.Navigation("AuditLog"); + b.Navigation("Cases"); + b.Navigation("EnrollmentTokens"); b.Navigation("SavedHunts"); @@ -897,17 +1224,9 @@ protected override void BuildModel(ModelBuilder modelBuilder) b.Navigation("TelemetryEvents"); - b.Navigation("Users"); - }); + b.Navigation("ThreatIntelFeeds"); - modelBuilder.Entity("Tawny.Domain.Entities.AlertRule", b => - { - b.Navigation("Alerts"); - }); - - modelBuilder.Entity("Tawny.Domain.Entities.SavedHunt", b => - { - b.Navigation("Runs"); + b.Navigation("Users"); }); #pragma warning restore 612, 618 } diff --git a/backend/src/Tawny.Infrastructure/TawnyDbContext.cs b/backend/src/Tawny.Infrastructure/TawnyDbContext.cs index 1659a5d..f5a2fa9 100644 --- a/backend/src/Tawny.Infrastructure/TawnyDbContext.cs +++ b/backend/src/Tawny.Infrastructure/TawnyDbContext.cs @@ -20,6 +20,11 @@ public class TawnyDbContext(DbContextOptions options) : DbContex public DbSet HuntRuns => Set(); public DbSet SuppressionRules => Set(); public DbSet ApiTokens => Set(); + public DbSet ThreatIntelFeeds => Set(); + public DbSet ReputationCache => Set(); + public DbSet Cases => Set(); + public DbSet CaseAlerts => Set(); + public DbSet CaseNotes => Set(); protected override void OnModelCreating(ModelBuilder b) { @@ -108,6 +113,7 @@ protected override void OnModelCreating(ModelBuilder b) e.Property(r => r.PayloadPath).HasMaxLength(256); e.Property(r => r.MatchValue).HasMaxLength(512); e.Property(r => r.SourceDefinition).HasColumnType("nvarchar(max)"); + e.Property(r => r.CompiledExpressionJson).HasColumnName("CompiledExpression").HasColumnType("nvarchar(max)"); e.Property(r => r.MitreTechniquesJson).HasColumnName("MitreTechniques").HasColumnType("nvarchar(max)"); e.HasIndex(r => new { r.IsEnabled, r.EventType }); e.HasIndex(r => new { r.Format, r.ExternalId }); @@ -118,6 +124,7 @@ protected override void OnModelCreating(ModelBuilder b) e.HasKey(a => a.Id); e.Property(a => a.Title).HasMaxLength(255).IsRequired(); e.Property(a => a.Description).HasColumnType("nvarchar(max)"); + e.Property(a => a.EnrichmentJson).HasColumnName("Enrichment").HasColumnType("nvarchar(max)"); e.Property(a => a.SlackNotificationError).HasMaxLength(1024); e.Property(a => a.SentinelNotificationError).HasMaxLength(1024); e.HasOne(a => a.AlertRule) @@ -240,5 +247,72 @@ protected override void OnModelCreating(ModelBuilder b) e.HasIndex(t => t.TokenHash).IsUnique(); e.HasIndex(t => new { t.TenantId, t.CreatedAt }); }); + + b.Entity(e => + { + e.HasKey(t => t.Id); + e.Property(t => t.TenantId).HasDefaultValue(TenantDefaults.DefaultTenantId); + e.Property(t => t.Name).HasMaxLength(160).IsRequired(); + e.Property(t => t.Url).HasMaxLength(1024).IsRequired(); + e.Property(t => t.AuthHeaderName).HasMaxLength(64); + e.Property(t => t.AuthHeaderValueEncrypted).HasMaxLength(1024); + e.Property(t => t.LastError).HasMaxLength(2048); + e.Property(t => t.Etag).HasMaxLength(256); + e.HasOne(t => t.Tenant) + .WithMany(t => t.ThreatIntelFeeds) + .HasForeignKey(t => t.TenantId) + .OnDelete(DeleteBehavior.Restrict); + e.HasIndex(t => new { t.TenantId, t.IsEnabled }); + }); + + b.Entity(e => + { + e.HasKey(r => r.Id); + e.Property(r => r.TenantId).HasDefaultValue(TenantDefaults.DefaultTenantId); + e.Property(r => r.IndicatorKind).HasMaxLength(32).IsRequired(); + e.Property(r => r.IndicatorValue).HasMaxLength(512).IsRequired(); + e.Property(r => r.DetailJson).HasColumnType("nvarchar(max)").IsRequired(); + e.HasIndex(r => new { r.TenantId, r.Provider, r.IndicatorKind, r.IndicatorValue }).IsUnique(); + e.HasIndex(r => r.ExpiresAt); + }); + + b.Entity(e => + { + e.HasKey(c => c.Id); + e.Property(c => c.TenantId).HasDefaultValue(TenantDefaults.DefaultTenantId); + e.Property(c => c.Title).HasMaxLength(255).IsRequired(); + e.Property(c => c.Summary).HasColumnType("nvarchar(max)"); + e.Property(c => c.MitreTechniquesJson).HasColumnName("MitreTechniques").HasColumnType("nvarchar(max)"); + e.HasOne(c => c.Tenant) + .WithMany(t => t.Cases) + .HasForeignKey(c => c.TenantId) + .OnDelete(DeleteBehavior.Restrict); + e.HasIndex(c => new { c.TenantId, c.Status, c.CreatedAt }); + }); + + b.Entity(e => + { + e.HasKey(ca => ca.Id); + e.HasOne(ca => ca.Case) + .WithMany(c => c.CaseAlerts) + .HasForeignKey(ca => ca.CaseId) + .OnDelete(DeleteBehavior.Cascade); + e.HasOne(ca => ca.Alert) + .WithMany() + .HasForeignKey(ca => ca.AlertId) + .OnDelete(DeleteBehavior.Cascade); + e.HasIndex(ca => new { ca.CaseId, ca.AlertId }).IsUnique(); + }); + + b.Entity(e => + { + e.HasKey(n => n.Id); + e.Property(n => n.Body).HasColumnType("nvarchar(max)").IsRequired(); + e.HasOne(n => n.Case) + .WithMany(c => c.Notes) + .HasForeignKey(n => n.CaseId) + .OnDelete(DeleteBehavior.Cascade); + e.HasIndex(n => new { n.CaseId, n.CreatedAt }); + }); } } diff --git a/backend/src/Tawny.Infrastructure/ThreatIntel/ReputationEnricher.cs b/backend/src/Tawny.Infrastructure/ThreatIntel/ReputationEnricher.cs new file mode 100644 index 0000000..7384662 --- /dev/null +++ b/backend/src/Tawny.Infrastructure/ThreatIntel/ReputationEnricher.cs @@ -0,0 +1,253 @@ +using System.Net.Http.Headers; +using System.Text.Json; +using Microsoft.EntityFrameworkCore; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using Tawny.Domain; +using Tawny.Domain.Entities; + +namespace Tawny.Infrastructure.ThreatIntel; + +public class ReputationOptions +{ + public string? VirusTotalApiKey { get; set; } + public string? AbuseIpDbApiKey { get; set; } + public string? GreyNoiseApiKey { get; set; } + public int CacheTtlHours { get; set; } = 24; + public int TimeoutSeconds { get; set; } = 10; + public bool EnrichAlertsAutomatically { get; set; } = true; +} + +public record ReputationLookup( + ReputationProvider Provider, + ReputationVerdict Verdict, + int? Score, + object Detail); + +/// +/// Looks up reputation for IoCs (hashes, IPs, domains) from configured providers +/// and caches the result. Designed to be safe to call from the alert pipeline: +/// each provider is HTTP-bound, has a short timeout, and respects the cache. +/// +public class ReputationEnricher( + TawnyDbContext db, + HttpClient http, + IOptions options, + TimeProvider timeProvider, + ILogger log) +{ + private static readonly JsonSerializerOptions JsonOptions = new(JsonSerializerDefaults.Web); + private readonly ReputationOptions _opts = options.Value; + + public async Task> LookupAsync( + Guid tenantId, + string kind, + string value, + CancellationToken ct) + { + if (string.IsNullOrWhiteSpace(value)) return []; + var providers = ProvidersForKind(kind).ToList(); + if (providers.Count == 0) return []; + + var results = new List(); + foreach (var provider in providers) + { + var cached = await TryCachedAsync(tenantId, provider, kind, value, ct); + if (cached is not null) + { + results.Add(cached); + continue; + } + try + { + var fresh = await ProbeAsync(provider, kind, value, ct); + if (fresh is null) continue; + results.Add(fresh); + await db.ReputationCache.AddAsync(new ReputationCacheEntry + { + TenantId = tenantId, + Provider = provider, + IndicatorKind = kind, + IndicatorValue = value, + Verdict = fresh.Verdict, + Score = fresh.Score, + DetailJson = JsonSerializer.Serialize(fresh.Detail, JsonOptions), + FetchedAt = timeProvider.GetUtcNow(), + ExpiresAt = timeProvider.GetUtcNow().AddHours(_opts.CacheTtlHours), + }, ct); + await db.SaveChangesAsync(ct); + } + catch (Exception ex) + { + log.LogWarning(ex, "Reputation probe for {Provider} {Kind} {Value} failed", provider, kind, value); + } + } + return results; + } + + private async Task TryCachedAsync( + Guid tenantId, ReputationProvider provider, string kind, string value, CancellationToken ct) + { + var now = timeProvider.GetUtcNow(); + var entry = await db.ReputationCache.AsNoTracking() + .FirstOrDefaultAsync(r => r.TenantId == tenantId + && r.Provider == provider + && r.IndicatorKind == kind + && r.IndicatorValue == value + && r.ExpiresAt > now, ct); + if (entry is null) return null; + object detail; + try { detail = JsonSerializer.Deserialize(entry.DetailJson); } + catch { detail = new { cached = true }; } + return new ReputationLookup(entry.Provider, entry.Verdict, entry.Score, detail); + } + + private async Task ProbeAsync(ReputationProvider provider, string kind, string value, CancellationToken ct) + { + using var timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(ct); + timeoutCts.CancelAfter(TimeSpan.FromSeconds(Math.Clamp(_opts.TimeoutSeconds, 1, 60))); + return provider switch + { + ReputationProvider.VirusTotal => await ProbeVirusTotalAsync(kind, value, timeoutCts.Token), + ReputationProvider.AbuseIpDb => await ProbeAbuseIpDbAsync(kind, value, timeoutCts.Token), + ReputationProvider.GreyNoise => await ProbeGreyNoiseAsync(kind, value, timeoutCts.Token), + _ => null, + }; + } + + private async Task ProbeVirusTotalAsync(string kind, string value, CancellationToken ct) + { + if (string.IsNullOrWhiteSpace(_opts.VirusTotalApiKey)) return null; + if (kind is not "sha256" and not "sha1" and not "ipv4" and not "domain") return null; + var path = kind switch + { + "sha256" or "sha1" => $"https://www.virustotal.com/api/v3/files/{Uri.EscapeDataString(value)}", + "ipv4" => $"https://www.virustotal.com/api/v3/ip_addresses/{Uri.EscapeDataString(value)}", + "domain" => $"https://www.virustotal.com/api/v3/domains/{Uri.EscapeDataString(value)}", + _ => throw new InvalidOperationException(), + }; + using var request = new HttpRequestMessage(HttpMethod.Get, path); + request.Headers.TryAddWithoutValidation("x-apikey", _opts.VirusTotalApiKey); + using var response = await http.SendAsync(request, ct); + if (response.StatusCode == System.Net.HttpStatusCode.NotFound) + { + return new ReputationLookup(ReputationProvider.VirusTotal, ReputationVerdict.Unknown, null, new { not_found = true }); + } + if (!response.IsSuccessStatusCode) + { + return new ReputationLookup(ReputationProvider.VirusTotal, ReputationVerdict.Error, null, new { http_status = (int)response.StatusCode }); + } + var body = await response.Content.ReadAsStringAsync(ct); + using var doc = JsonDocument.Parse(body); + var stats = doc.RootElement + .GetProperty("data") + .GetProperty("attributes") + .GetProperty("last_analysis_stats"); + var malicious = stats.GetProperty("malicious").GetInt32(); + var suspicious = stats.TryGetProperty("suspicious", out var s) ? s.GetInt32() : 0; + var verdict = malicious switch + { + >= 5 => ReputationVerdict.Malicious, + >= 1 => ReputationVerdict.Suspicious, + _ => suspicious > 0 ? ReputationVerdict.Suspicious : ReputationVerdict.Clean, + }; + return new ReputationLookup(ReputationProvider.VirusTotal, verdict, malicious, new + { + malicious, + suspicious, + stats = stats.GetRawText(), + }); + } + + private async Task ProbeAbuseIpDbAsync(string kind, string value, CancellationToken ct) + { + if (string.IsNullOrWhiteSpace(_opts.AbuseIpDbApiKey)) return null; + if (kind != "ipv4") return null; + using var request = new HttpRequestMessage(HttpMethod.Get, + $"https://api.abuseipdb.com/api/v2/check?ipAddress={Uri.EscapeDataString(value)}&maxAgeInDays=90"); + request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json")); + request.Headers.TryAddWithoutValidation("Key", _opts.AbuseIpDbApiKey); + using var response = await http.SendAsync(request, ct); + if (!response.IsSuccessStatusCode) + { + return new ReputationLookup(ReputationProvider.AbuseIpDb, ReputationVerdict.Error, null, + new { http_status = (int)response.StatusCode }); + } + var body = await response.Content.ReadAsStringAsync(ct); + using var doc = JsonDocument.Parse(body); + var data = doc.RootElement.GetProperty("data"); + var score = data.GetProperty("abuseConfidenceScore").GetInt32(); + var verdict = score switch + { + >= 75 => ReputationVerdict.Malicious, + >= 25 => ReputationVerdict.Suspicious, + _ => ReputationVerdict.Clean, + }; + return new ReputationLookup(ReputationProvider.AbuseIpDb, verdict, score, new + { + confidence = score, + usage_type = data.TryGetProperty("usageType", out var ut) ? ut.GetString() : null, + country = data.TryGetProperty("countryCode", out var cc) ? cc.GetString() : null, + total_reports = data.TryGetProperty("totalReports", out var tr) ? tr.GetInt32() : 0, + }); + } + + private async Task ProbeGreyNoiseAsync(string kind, string value, CancellationToken ct) + { + // GreyNoise Community API works without a key (rate-limited); the key + // unlocks higher quotas. Only IPv4 is supported. + if (kind != "ipv4") return null; + using var request = new HttpRequestMessage(HttpMethod.Get, + $"https://api.greynoise.io/v3/community/{Uri.EscapeDataString(value)}"); + request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json")); + if (!string.IsNullOrWhiteSpace(_opts.GreyNoiseApiKey)) + { + request.Headers.TryAddWithoutValidation("key", _opts.GreyNoiseApiKey); + } + using var response = await http.SendAsync(request, ct); + if (response.StatusCode == System.Net.HttpStatusCode.NotFound) + { + return new ReputationLookup(ReputationProvider.GreyNoise, ReputationVerdict.Unknown, null, new { not_found = true }); + } + if (!response.IsSuccessStatusCode) + { + return new ReputationLookup(ReputationProvider.GreyNoise, ReputationVerdict.Error, null, + new { http_status = (int)response.StatusCode }); + } + var body = await response.Content.ReadAsStringAsync(ct); + using var doc = JsonDocument.Parse(body); + var root = doc.RootElement; + var classification = root.TryGetProperty("classification", out var c) ? c.GetString() : null; + var verdict = classification switch + { + "malicious" => ReputationVerdict.Malicious, + "suspicious" => ReputationVerdict.Suspicious, + "benign" => ReputationVerdict.Clean, + _ => ReputationVerdict.Unknown, + }; + return new ReputationLookup(ReputationProvider.GreyNoise, verdict, null, new + { + classification, + noise = root.TryGetProperty("noise", out var n) && n.ValueKind == JsonValueKind.True, + riot = root.TryGetProperty("riot", out var r) && r.ValueKind == JsonValueKind.True, + name = root.TryGetProperty("name", out var nm) ? nm.GetString() : null, + }); + } + + private IEnumerable ProvidersForKind(string kind) + { + if (!string.IsNullOrWhiteSpace(_opts.VirusTotalApiKey) + && kind is "sha256" or "sha1" or "ipv4" or "domain") + { + yield return ReputationProvider.VirusTotal; + } + if (!string.IsNullOrWhiteSpace(_opts.AbuseIpDbApiKey) && kind == "ipv4") + { + yield return ReputationProvider.AbuseIpDb; + } + if (kind == "ipv4") + { + yield return ReputationProvider.GreyNoise; + } + } +} diff --git a/backend/src/Tawny.Infrastructure/ThreatIntel/ThreatIntelFetcher.cs b/backend/src/Tawny.Infrastructure/ThreatIntel/ThreatIntelFetcher.cs new file mode 100644 index 0000000..7d4c3c7 --- /dev/null +++ b/backend/src/Tawny.Infrastructure/ThreatIntel/ThreatIntelFetcher.cs @@ -0,0 +1,413 @@ +using System.Net.Http.Headers; +using System.Text.Json; +using System.Text.RegularExpressions; +using Microsoft.Extensions.Logging; +using Tawny.Domain; +using Tawny.Domain.Entities; + +namespace Tawny.Infrastructure.ThreatIntel; + +public record FetchedIndicator(string Kind, string Value, string? Description); + +public record FetchedExposure( + string Ecosystem, + string Name, + string? VersionPattern, + string? AdvisoryId, + string? AdvisoryUrl, + string? Summary); + +public record FetchResult( + bool Modified, + string? Etag, + IReadOnlyList Indicators, + IReadOnlyList Skipped, + IReadOnlyList? Exposures = null); + +public class ThreatIntelFetchException(string message, Exception? inner = null) : Exception(message, inner); + +/// +/// Pulls indicators from supported TI feed shapes. Returns raw FetchedIndicator +/// records — the caller decides which to turn into AlertRules. +/// +public class ThreatIntelFetcher(HttpClient http, ILogger log) +{ + private static readonly Regex Sha256Re = new(@"\b[a-fA-F0-9]{64}\b", RegexOptions.Compiled); + private static readonly Regex Sha1Re = new(@"\b[a-fA-F0-9]{40}\b", RegexOptions.Compiled); + private static readonly Regex Ipv4Re = new(@"\b(?:\d{1,3}\.){3}\d{1,3}\b", RegexOptions.Compiled); + private static readonly Regex DomainRe = new(@"\b(?=.{4,253}\b)([a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,63}\b", RegexOptions.Compiled); + private const int MaxIndicatorsPerFeed = 5_000; + + public async Task FetchAsync(ThreatIntelFeed feed, CancellationToken ct) + { + using var request = new HttpRequestMessage(HttpMethod.Get, feed.Url); + if (!string.IsNullOrWhiteSpace(feed.AuthHeaderName) && !string.IsNullOrWhiteSpace(feed.AuthHeaderValueEncrypted)) + { + // We're using the column name "Encrypted" defensively but storing plaintext here; + // a real deployment would decrypt via DPAPI or the configured secret store. + request.Headers.TryAddWithoutValidation(feed.AuthHeaderName, feed.AuthHeaderValueEncrypted); + } + if (!string.IsNullOrWhiteSpace(feed.Etag)) + { + request.Headers.IfNoneMatch.Add(new EntityTagHeaderValue($"\"{feed.Etag}\"")); + } + request.Headers.UserAgent.ParseAdd("Tawny-EDR/1.0 (+https://github.com/jusso-dev/Tawny)"); + + using var response = await http.SendAsync(request, ct); + if (response.StatusCode == System.Net.HttpStatusCode.NotModified) + { + return new FetchResult(false, feed.Etag, [], []); + } + if (!response.IsSuccessStatusCode) + { + throw new ThreatIntelFetchException($"Feed responded with {(int)response.StatusCode} {response.ReasonPhrase}"); + } + + var body = await response.Content.ReadAsStringAsync(ct); + var etag = response.Headers.ETag?.Tag?.Trim('"'); + + // OSV feeds produce package-exposure records, not the hash/IP/domain + // indicator shape, so they take a separate branch. + if (feed.Kind == ThreatIntelFeedKind.OsvVulnerabilities) + { + var exposures = ParseOsv(body); + log.LogInformation("OSV feed {Feed} returned {Count} exposures.", feed.Name, exposures.Count); + return new FetchResult(true, etag, [], [], exposures); + } + + var indicators = feed.Kind switch + { + ThreatIntelFeedKind.UrlhausCsv => ParseUrlhausCsv(body), + ThreatIntelFeedKind.UrlhausJson => ParseUrlhausJson(body), + ThreatIntelFeedKind.OtxPulse => ParseOtxPulse(body), + ThreatIntelFeedKind.MispEvents => ParseMispEvents(body), + ThreatIntelFeedKind.Taxii21 => ParseTaxii21(body), + ThreatIntelFeedKind.GenericCsv => ParseGenericCsv(body), + _ => throw new ThreatIntelFetchException($"Unsupported feed kind: {feed.Kind}"), + }; + + var (taken, skipped) = TakeWithBudget(indicators); + if (skipped.Count > 0) + { + log.LogInformation("Feed {Feed} returned {Total} indicators, kept {Kept}, skipped {Skipped}.", + feed.Name, indicators.Count, taken.Count, skipped.Count); + } + return new FetchResult(true, etag, taken, skipped); + } + + private static List ParseOsv(string body) + { + // Accept either a single OSV record, a `{advisories: [...]}` bundle, + // or a top-level JSON array of records. Mirrors ExposureRuleImporter. + var result = new List(); + try + { + using var doc = JsonDocument.Parse(body); + var root = doc.RootElement; + switch (root.ValueKind) + { + case JsonValueKind.Array: + foreach (var entry in root.EnumerateArray()) AppendOsvRecord(entry, result); + break; + case JsonValueKind.Object when root.TryGetProperty("advisories", out var bundle) + && bundle.ValueKind == JsonValueKind.Array: + foreach (var entry in bundle.EnumerateArray()) AppendOsvRecord(entry, result); + break; + case JsonValueKind.Object: + AppendOsvRecord(root, result); + break; + } + } + catch (JsonException ex) + { + throw new ThreatIntelFetchException("OSV parse failed", ex); + } + return result; + } + + private static void AppendOsvRecord(JsonElement advisory, List out_) + { + var id = advisory.TryGetProperty("id", out var i) && i.ValueKind == JsonValueKind.String ? i.GetString() : null; + var summary = advisory.TryGetProperty("summary", out var s) && s.ValueKind == JsonValueKind.String ? s.GetString() : null; + var url = ExtractFirstOsvUrl(advisory); + + if (!advisory.TryGetProperty("affected", out var affected) || affected.ValueKind != JsonValueKind.Array) return; + foreach (var a in affected.EnumerateArray()) + { + if (!a.TryGetProperty("package", out var pkg)) continue; + if (!pkg.TryGetProperty("ecosystem", out var eco) || !pkg.TryGetProperty("name", out var name)) continue; + var ecosystem = eco.GetString(); + var packageName = name.GetString(); + if (string.IsNullOrWhiteSpace(ecosystem) || string.IsNullOrWhiteSpace(packageName)) continue; + out_.Add(new FetchedExposure( + ecosystem.ToLowerInvariant(), + packageName, + BuildOsvPattern(a), + id, + url, + summary)); + } + } + + private static string? BuildOsvPattern(JsonElement affected) + { + if (affected.TryGetProperty("versions", out var versions) && versions.ValueKind == JsonValueKind.Array) + { + var list = new List(); + foreach (var v in versions.EnumerateArray()) + { + if (v.ValueKind == JsonValueKind.String && !string.IsNullOrWhiteSpace(v.GetString())) + list.Add(v.GetString()!); + } + if (list.Count > 0) return string.Join(",", list); + } + if (affected.TryGetProperty("ranges", out var ranges) && ranges.ValueKind == JsonValueKind.Array) + { + var fragments = new List(); + foreach (var range in ranges.EnumerateArray()) + { + if (!range.TryGetProperty("events", out var events) || events.ValueKind != JsonValueKind.Array) continue; + string? introduced = null; + string? fixedAt = null; + foreach (var ev in events.EnumerateArray()) + { + if (ev.TryGetProperty("introduced", out var iEl) && iEl.ValueKind == JsonValueKind.String) introduced = iEl.GetString(); + if (ev.TryGetProperty("fixed", out var fEl) && fEl.ValueKind == JsonValueKind.String) fixedAt = fEl.GetString(); + } + if (introduced is not null and not "0") fragments.Add($">={introduced}"); + if (fixedAt is not null) fragments.Add($"<{fixedAt}"); + } + if (fragments.Count > 0) return string.Join(",", fragments); + } + return null; + } + + private static string? ExtractFirstOsvUrl(JsonElement advisory) + { + if (!advisory.TryGetProperty("references", out var refs) || refs.ValueKind != JsonValueKind.Array) return null; + foreach (var r in refs.EnumerateArray()) + { + if (r.TryGetProperty("url", out var u) && u.ValueKind == JsonValueKind.String) return u.GetString(); + } + return null; + } + + // ---------- parsers ---------- + + private static List ParseUrlhausCsv(string body) + { + // abuse.ch URLhaus CSV: id,dateadded,url,url_status,threat,tags,urlhaus_link,reporter + var result = new List(); + foreach (var rawLine in body.Split('\n', StringSplitOptions.RemoveEmptyEntries)) + { + var line = rawLine.Trim(); + if (line.Length == 0 || line.StartsWith('#')) continue; + var cols = SplitCsv(line); + if (cols.Count < 3) continue; + var url = cols[2].Trim('"'); + if (Uri.TryCreate(url, UriKind.Absolute, out var uri) && !string.IsNullOrEmpty(uri.Host)) + { + result.Add(new FetchedIndicator("domain", uri.Host, $"URLhaus: {url}")); + } + } + return result; + } + + private static List ParseUrlhausJson(string body) + { + var result = new List(); + try + { + using var doc = JsonDocument.Parse(body); + // URLhaus JSON: { "1": { "url": "...", "host": "...", "tags": [...] }, ... } + foreach (var entry in doc.RootElement.EnumerateObject()) + { + if (entry.Value.ValueKind != JsonValueKind.Object) continue; + if (entry.Value.TryGetProperty("host", out var host) && host.ValueKind == JsonValueKind.String) + { + var value = host.GetString(); + if (!string.IsNullOrWhiteSpace(value)) + { + result.Add(new FetchedIndicator("domain", value, "URLhaus host")); + } + } + } + } + catch (JsonException ex) + { + throw new ThreatIntelFetchException("URLhaus JSON parse failed", ex); + } + return result; + } + + private static List ParseOtxPulse(string body) + { + // OTX pulse JSON: { "results": [ { "indicators": [ { "type": "IPv4", "indicator": "1.2.3.4" } ] } ] } + var result = new List(); + try + { + using var doc = JsonDocument.Parse(body); + if (!doc.RootElement.TryGetProperty("results", out var results)) return result; + foreach (var pulse in results.EnumerateArray()) + { + if (!pulse.TryGetProperty("indicators", out var indicators)) continue; + foreach (var ind in indicators.EnumerateArray()) + { + var type = ind.GetProperty("type").GetString(); + var value = ind.GetProperty("indicator").GetString(); + if (string.IsNullOrWhiteSpace(value) || string.IsNullOrWhiteSpace(type)) continue; + var kind = type.ToLowerInvariant() switch + { + "ipv4" => "ipv4", + "ipv6" => "ipv6", + "domain" or "hostname" => "domain", + "filehash-sha256" => "sha256", + "filehash-sha1" => "sha1", + _ => null, + }; + if (kind is not null) + { + result.Add(new FetchedIndicator(kind, value, $"OTX {type}")); + } + } + } + } + catch (JsonException ex) + { + throw new ThreatIntelFetchException("OTX pulse parse failed", ex); + } + return result; + } + + private static List ParseMispEvents(string body) + { + // MISP /events/restSearch returns { "response": [{ "Event": { "Attribute": [{ "type": "...", "value": "..." }]}}]} + var result = new List(); + try + { + using var doc = JsonDocument.Parse(body); + if (!doc.RootElement.TryGetProperty("response", out var response)) return result; + foreach (var entry in response.EnumerateArray()) + { + if (!entry.TryGetProperty("Event", out var eventNode)) continue; + if (!eventNode.TryGetProperty("Attribute", out var attributes)) continue; + foreach (var attr in attributes.EnumerateArray()) + { + var type = attr.GetProperty("type").GetString(); + var value = attr.GetProperty("value").GetString(); + if (string.IsNullOrWhiteSpace(value) || string.IsNullOrWhiteSpace(type)) continue; + var kind = type switch + { + "ip-src" or "ip-dst" => "ipv4", + "domain" or "hostname" => "domain", + "sha256" => "sha256", + "sha1" => "sha1", + _ => null, + }; + if (kind is not null) + { + result.Add(new FetchedIndicator(kind, value, $"MISP {type}")); + } + } + } + } + catch (JsonException ex) + { + throw new ThreatIntelFetchException("MISP parse failed", ex); + } + return result; + } + + private static List ParseTaxii21(string body) + { + // TAXII 2.1 envelope: { "objects": [ STIX bundles... ] } + var result = new List(); + try + { + using var doc = JsonDocument.Parse(body); + if (!doc.RootElement.TryGetProperty("objects", out var objects)) return result; + foreach (var obj in objects.EnumerateArray()) + { + if (!obj.TryGetProperty("type", out var type) || type.GetString() != "indicator") continue; + if (!obj.TryGetProperty("pattern", out var pattern)) continue; + var raw = pattern.GetString() ?? ""; + // Reuse the same simple pattern shapes the existing IoC importer understands. + foreach (var match in Regex.Matches(raw, @"\[(?file:hashes\.'SHA-256'|file:hashes\.'SHA-1'|ipv4-addr:value|ipv6-addr:value|domain-name:value)\s*=\s*'(?[^']+)'\]").Cast()) + { + var k = match.Groups["kind"].Value; + var v = match.Groups["value"].Value; + var kind = k switch + { + "file:hashes.'SHA-256'" => "sha256", + "file:hashes.'SHA-1'" => "sha1", + "ipv4-addr:value" => "ipv4", + "ipv6-addr:value" => "ipv6", + "domain-name:value" => "domain", + _ => null, + }; + if (kind is not null) + { + result.Add(new FetchedIndicator(kind, v, $"TAXII {k}")); + } + } + } + } + catch (JsonException ex) + { + throw new ThreatIntelFetchException("TAXII parse failed", ex); + } + return result; + } + + private static List ParseGenericCsv(string body) + { + // One indicator per line (or first column of a CSV). Auto-detect kind by shape. + var result = new List(); + foreach (var rawLine in body.Split('\n', StringSplitOptions.RemoveEmptyEntries)) + { + var line = rawLine.Trim(); + if (line.Length == 0 || line.StartsWith('#')) continue; + var cols = SplitCsv(line); + var value = cols.Count == 0 ? line : cols[0].Trim('"'); + var kind = DetectKind(value); + if (kind is not null) + { + result.Add(new FetchedIndicator(kind, value, "Generic CSV")); + } + } + return result; + } + + private static string? DetectKind(string value) + { + if (Sha256Re.IsMatch(value) && value.Length == 64) return "sha256"; + if (Sha1Re.IsMatch(value) && value.Length == 40) return "sha1"; + if (Ipv4Re.IsMatch(value)) return "ipv4"; + if (DomainRe.IsMatch(value)) return "domain"; + return null; + } + + private static List SplitCsv(string line) + { + return line.Split(',').Select(p => p.Trim()).ToList(); + } + + private static (List Taken, List Skipped) TakeWithBudget(IReadOnlyList all) + { + var seen = new HashSet(StringComparer.OrdinalIgnoreCase); + var taken = new List(); + var skipped = new List(); + foreach (var indicator in all) + { + var key = $"{indicator.Kind}:{indicator.Value}"; + if (!seen.Add(key)) continue; + if (taken.Count >= MaxIndicatorsPerFeed) + { + skipped.Add(key); + continue; + } + taken.Add(indicator); + } + return (taken, skipped); + } +} diff --git a/backend/src/Tawny.Jobs/ReputationEnrichmentJob.cs b/backend/src/Tawny.Jobs/ReputationEnrichmentJob.cs new file mode 100644 index 0000000..33b55de --- /dev/null +++ b/backend/src/Tawny.Jobs/ReputationEnrichmentJob.cs @@ -0,0 +1,162 @@ +using System.Text.Json; +using Microsoft.EntityFrameworkCore; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using Tawny.Domain; +using Tawny.Domain.Entities; +using Tawny.Infrastructure; +using Tawny.Infrastructure.ThreatIntel; + +namespace Tawny.Jobs; + +/// +/// Walks recent unenriched alerts, extracts the matched IoC value from the +/// rule's payload_path, looks it up via reputation providers, and stores the +/// verdict on Alert.EnrichmentJson. Reputation is cached per tenant. +/// +public class ReputationEnrichmentJob( + TawnyDbContext db, + ReputationEnricher enricher, + IOptions options, + ILogger log) +{ + private const int BatchSize = 100; + private static readonly JsonSerializerOptions JsonOptions = new(JsonSerializerDefaults.Web) + { + PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower, + }; + + public async Task ExecuteAsync(CancellationToken ct = default) + { + if (!options.Value.EnrichAlertsAutomatically) return; + + var cutoff = DateTimeOffset.UtcNow.AddHours(-24); + var alerts = await db.Alerts + .Where(a => a.EnrichmentJson == null && a.CreatedAt >= cutoff) + .OrderByDescending(a => a.CreatedAt) + .Take(BatchSize) + .Include(a => a.AlertRule) + .Include(a => a.Agent) + .Include(a => a.TelemetryEvent) + .ToListAsync(ct); + + if (alerts.Count == 0) return; + var enrichedCount = 0; + + foreach (var alert in alerts) + { + if (ct.IsCancellationRequested) break; + var rule = alert.AlertRule; + var telemetry = alert.TelemetryEvent; + if (rule is null || telemetry is null) continue; + + var (kind, value) = ExtractIndicator(rule, telemetry); + if (kind is null || string.IsNullOrEmpty(value)) + { + alert.EnrichmentJson = "{\"enriched\":false,\"reason\":\"no_extractable_indicator\"}"; + continue; + } + + try + { + var tenantId = alert.Agent?.TenantId ?? Tawny.Domain.TenantDefaults.DefaultTenantId; + var lookups = await enricher.LookupAsync(tenantId, kind, value, ct); + alert.EnrichmentJson = JsonSerializer.Serialize(new + { + enriched = true, + indicator = new { kind, value }, + lookups = lookups.Select(l => new + { + provider = l.Provider.ToString(), + verdict = l.Verdict.ToString(), + score = l.Score, + detail = l.Detail, + }), + }, JsonOptions); + enrichedCount += 1; + } + catch (Exception ex) + { + alert.EnrichmentJson = JsonSerializer.Serialize(new + { + enriched = false, + reason = "lookup_failed", + error = ex.Message, + }, JsonOptions); + log.LogWarning(ex, "Reputation enrichment failed for alert {AlertId}", alert.Id); + } + } + + await db.SaveChangesAsync(ct); + if (enrichedCount > 0) + { + log.LogInformation("Reputation enrichment completed: {Count} alerts enriched.", enrichedCount); + } + } + + private static (string? Kind, string? Value) ExtractIndicator(AlertRule rule, TelemetryEvent telemetryEvent) + { + // The cheapest path: if the rule is an IoC rule, the MatchValue is the indicator itself. + if (rule.Format == AlertRuleFormat.Ioc && !string.IsNullOrEmpty(rule.MatchValue)) + { + var kind = rule.PayloadPath switch + { + "new_sha256" => "sha256", + "new_sha1" => "sha1", + "connections.remote_address" => "ipv4", + "processes.command_line" => "domain", + _ => null, + }; + if (kind is not null) + { + return (kind, rule.MatchValue); + } + } + + // Fallback: pull from the payload via rule.PayloadPath. + if (string.IsNullOrWhiteSpace(rule.PayloadPath)) return (null, null); + try + { + using var payload = JsonDocument.Parse(telemetryEvent.Payload); + var segments = rule.PayloadPath.Split('.', StringSplitOptions.RemoveEmptyEntries); + var first = Resolve(payload.RootElement, segments).FirstOrDefault(); + if (first.ValueKind == JsonValueKind.Undefined) return (null, null); + var scalar = first.ValueKind switch + { + JsonValueKind.String => first.GetString(), + JsonValueKind.Number => first.GetRawText(), + _ => null, + }; + if (string.IsNullOrWhiteSpace(scalar)) return (null, null); + var kind = rule.PayloadPath switch + { + "new_sha256" => "sha256", + "new_sha1" => "sha1", + _ when rule.PayloadPath.Contains("address", StringComparison.OrdinalIgnoreCase) => "ipv4", + _ when rule.PayloadPath.Contains("domain", StringComparison.OrdinalIgnoreCase) => "domain", + _ => null, + }; + return (kind, scalar); + } + catch + { + return (null, null); + } + } + + private static IEnumerable Resolve(JsonElement current, IReadOnlyList segments, int index = 0) + { + if (index >= segments.Count) { yield return current; yield break; } + if (current.ValueKind == JsonValueKind.Array) + { + foreach (var item in current.EnumerateArray()) + { + foreach (var v in Resolve(item, segments, index)) yield return v; + } + yield break; + } + if (current.ValueKind != JsonValueKind.Object) yield break; + if (!current.TryGetProperty(segments[index], out var child)) yield break; + foreach (var v in Resolve(child, segments, index + 1)) yield return v; + } +} diff --git a/backend/src/Tawny.Jobs/ThreatIntelFeedsJob.cs b/backend/src/Tawny.Jobs/ThreatIntelFeedsJob.cs new file mode 100644 index 0000000..de96582 --- /dev/null +++ b/backend/src/Tawny.Jobs/ThreatIntelFeedsJob.cs @@ -0,0 +1,173 @@ +using System.Net.Http; +using Microsoft.EntityFrameworkCore; +using Microsoft.Extensions.Logging; +using Tawny.Domain; +using Tawny.Domain.Entities; +using Tawny.Infrastructure; +using Tawny.Infrastructure.Hunting; +using Tawny.Infrastructure.ThreatIntel; + +namespace Tawny.Jobs; + +/// +/// Walks every enabled ThreatIntelFeed whose interval has elapsed, pulls its +/// payload, and materialises new indicators as AlertRules (Format = Ioc) keyed +/// by ExternalId so re-imports are idempotent. +/// +public class ThreatIntelFeedsJob( + TawnyDbContext db, + TimeProvider timeProvider, + ThreatIntelFetcher fetcher, + ILogger log) +{ + public async Task ExecuteAsync(CancellationToken ct = default) + { + var now = timeProvider.GetUtcNow(); + var due = await db.ThreatIntelFeeds + .Where(f => f.IsEnabled + && (f.LastRunAt == null + || EF.Functions.DateDiffMinute(f.LastRunAt!.Value, now) >= f.IntervalMinutes)) + .ToListAsync(ct); + if (due.Count == 0) return; + + foreach (var feed in due) + { + if (ct.IsCancellationRequested) break; + await RunOneAsync(feed, now, ct); + } + } + + private async Task RunOneAsync(ThreatIntelFeed feed, DateTimeOffset now, CancellationToken ct) + { + feed.LastRunAt = now; + try + { + var result = await fetcher.FetchAsync(feed, ct); + if (!result.Modified) + { + feed.Status = ThreatIntelFeedStatus.Healthy; + feed.LastSuccessAt = now; + feed.LastError = null; + await db.SaveChangesAsync(ct); + return; + } + feed.Etag = result.Etag; + await MaterialiseAsync(feed, result, now, ct); + feed.Status = ThreatIntelFeedStatus.Healthy; + feed.LastSuccessAt = now; + feed.LastImportedCount = result.Indicators.Count + (result.Exposures?.Count ?? 0); + feed.LastSkippedCount = result.Skipped.Count; + feed.LastError = null; + } + catch (Exception ex) + { + feed.Status = ThreatIntelFeedStatus.Failed; + feed.LastError = ex.Message.Length > 2000 ? ex.Message[..2000] : ex.Message; + log.LogError(ex, "TI feed {Name} ({Url}) failed", feed.Name, feed.Url); + } + await db.SaveChangesAsync(ct); + } + + private async Task MaterialiseAsync( + ThreatIntelFeed feed, + FetchResult result, + DateTimeOffset now, + CancellationToken ct) + { + var externalIdPrefix = $"ti-feed:{feed.Id}:"; + var existingIds = await db.AlertRules + .Where(r => r.ExternalId != null && r.ExternalId.StartsWith(externalIdPrefix)) + .Select(r => r.ExternalId!) + .ToListAsync(ct); + var existing = new HashSet(existingIds, StringComparer.OrdinalIgnoreCase); + + var newRules = new List(); + foreach (var ind in result.Indicators) + { + var externalId = externalIdPrefix + ind.Kind + ":" + ind.Value.ToLowerInvariant(); + if (existing.Contains(externalId)) continue; + + (TelemetryEventType EventType, string PayloadPath, AlertRuleOperator Op)? compiled = ind.Kind switch + { + "sha256" => (TelemetryEventType.FileIntegrity, "new_sha256", AlertRuleOperator.Equals), + "sha1" => (TelemetryEventType.FileIntegrity, "new_sha1", AlertRuleOperator.Equals), + "ipv4" or "ipv6" => (TelemetryEventType.NetworkSnapshot, "connections.remote_address", AlertRuleOperator.Equals), + "domain" => (TelemetryEventType.ProcessSnapshot, "processes.command_line", AlertRuleOperator.Contains), + _ => null, + }; + if (compiled is null) continue; + var (eventType, payloadPath, op) = compiled.Value; + + newRules.Add(new AlertRule + { + Id = Guid.NewGuid(), + Name = $"TI feed {feed.Name}: {ind.Kind} {ind.Value}", + Format = AlertRuleFormat.Ioc, + ExternalId = externalId, + Description = ind.Description, + EventType = eventType, + Severity = feed.DefaultSeverity, + Operator = op, + PayloadPath = payloadPath, + MatchValue = ind.Value, + IsEnabled = true, + CreatedAt = now, + UpdatedAt = now, + }); + } + + // OSV exposures coming from the feed get materialised as Format=PackageExposure + // rules so the agent's inventory events can match them at ingest time. + if (result.Exposures is { Count: > 0 }) + { + foreach (var exposure in result.Exposures) + { + var pattern = exposure.VersionPattern ?? "any"; + var externalId = $"{externalIdPrefix}exposure:{exposure.Ecosystem}:{exposure.Name}:{pattern}"; + if (exposure.AdvisoryId is { Length: > 0 }) externalId = $"{externalId}:{exposure.AdvisoryId}"; + if (externalId.Length > 128) externalId = externalId[..128]; + if (existing.Contains(externalId)) continue; + + var definition = new PackageExposureDefinition( + exposure.Ecosystem, + exposure.Name, + exposure.VersionPattern, + exposure.AdvisoryId, + exposure.AdvisoryUrl); + var eventType = exposure.Ecosystem switch + { + "editor-extension" or "editor_extension" => TelemetryEventType.EditorExtension, + "browser-extension" or "browser_extension" => TelemetryEventType.BrowserExtension, + "mcp" or "mcp_server" or "mcp-server" => TelemetryEventType.McpConfig, + _ => TelemetryEventType.PackageInventory, + }; + + newRules.Add(new AlertRule + { + Id = Guid.NewGuid(), + Name = $"OSV: {exposure.Ecosystem}/{exposure.Name} {pattern}", + Format = AlertRuleFormat.PackageExposure, + ExternalId = externalId, + Description = exposure.Summary + ?? $"OSV exposure from {feed.Name}: {exposure.Ecosystem}/{exposure.Name} {pattern}.", + EventType = eventType, + Severity = feed.DefaultSeverity, + Operator = AlertRuleOperator.Exists, + SourceDefinition = PackageExposureParser.Serialize(definition), + IsEnabled = true, + CreatedAt = now, + UpdatedAt = now, + }); + } + } + + if (newRules.Count > 0) + { + db.AlertRules.AddRange(newRules); + log.LogInformation("TI feed {Name} imported {Count} new rules ({Ioc} IoCs + {Exp} exposures).", + feed.Name, newRules.Count, + newRules.Count(r => r.Format == AlertRuleFormat.Ioc), + newRules.Count(r => r.Format == AlertRuleFormat.PackageExposure)); + } + } +} diff --git a/backend/tests/Tawny.Api.Tests/AgentFlowIntegrationTests.cs b/backend/tests/Tawny.Api.Tests/AgentFlowIntegrationTests.cs index 377e6d4..354e3a8 100644 --- a/backend/tests/Tawny.Api.Tests/AgentFlowIntegrationTests.cs +++ b/backend/tests/Tawny.Api.Tests/AgentFlowIntegrationTests.cs @@ -416,7 +416,10 @@ public async Task ImportedStixIocs_CreateRulesAndAlertForMatchingNetworkTelemetr { var db = scope.ServiceProvider.GetRequiredService(); var rules = await db.AlertRules.OrderBy(r => r.PayloadPath).ToListAsync(); - rules.Should().HaveCount(3); + // Each domain IoC now produces two rules: a process command-line + // match (legacy fallback) and a DNS qname match for the new + // dns_query event type. So we expect 4 rules for {ip, sha256, domain}. + rules.Should().HaveCount(4); rules.Should().OnlyContain(r => r.Format == AlertRuleFormat.Ioc); rules.Should().Contain(r => r.EventType == TelemetryEventType.NetworkSnapshot && @@ -429,6 +432,10 @@ public async Task ImportedStixIocs_CreateRulesAndAlertForMatchingNetworkTelemetr r.EventType == TelemetryEventType.ProcessSnapshot && r.PayloadPath == "processes.command_line" && r.MatchValue == "payload.example.com"); + rules.Should().Contain(r => + r.EventType == TelemetryEventType.DnsQuery && + r.PayloadPath == "qname" && + r.MatchValue == "payload.example.com"); } var enrollmentToken = TokenHashing.NewToken(); diff --git a/web/app/agents/[id]/events-panel.tsx b/web/app/agents/[id]/events-panel.tsx index 3a150e5..7ac6f2c 100644 --- a/web/app/agents/[id]/events-panel.tsx +++ b/web/app/agents/[id]/events-panel.tsx @@ -12,7 +12,14 @@ type EventType = | "file_integrity" | "user_session" | "system_info" - | "heartbeat"; + | "heartbeat" + | "dns_query" + | "process_launch" + | "file_event" + | "package_inventory" + | "editor_extension" + | "browser_extension" + | "mcp_config"; type TelemetryEvent = { id: number; @@ -32,8 +39,15 @@ type Tab = { const TABS: Tab[] = [ { key: "processes", label: "Processes", type: "process_snapshot" }, { key: "tree", label: "Process tree", type: "process_snapshot" }, + { key: "launches", label: "Launches", type: "process_launch" }, + { key: "dns", label: "DNS", type: "dns_query" }, { key: "network", label: "Network", type: "network_snapshot" }, { key: "fim", label: "FIM", type: "file_integrity" }, + { key: "fs", label: "FS events", type: "file_event" }, + { key: "inventory", label: "Inventory", type: "package_inventory" }, + { key: "editor-ext", label: "Editor ext", type: "editor_extension" }, + { key: "browser-ext", label: "Browser ext", type: "browser_extension" }, + { key: "mcp", label: "MCP", type: "mcp_config" }, { key: "sessions", label: "Sessions", type: "user_session" }, { key: "raw", label: "Raw events" }, ]; @@ -247,9 +261,30 @@ function summarizePayload(event: TelemetryEvent) { return `${event.payload.processes.length} processes${names ? `: ${names}` : ""}`; } + // Inventory / extension / MCP events all carry {ecosystem, name, version} + // at the top level — summarise as one line per event for the table view. + if (isInventoryLike(event.type) && isPackageRecord(event.payload)) { + const p = event.payload; + const provenance = p.source_type ? ` (${p.source_type})` : ""; + return `${p.ecosystem}/${p.name}@${p.version}${provenance}`; + } + return JSON.stringify(event.payload, null, 2); } +function isInventoryLike(type: EventType): boolean { + return type === "package_inventory" + || type === "editor_extension" + || type === "browser_extension" + || type === "mcp_config"; +} + +function isPackageRecord(payload: unknown): payload is { ecosystem: string; name: string; version: string; source_type?: string } { + if (!payload || typeof payload !== "object") return false; + const obj = payload as Record; + return typeof obj.ecosystem === "string" && typeof obj.name === "string" && typeof obj.version === "string"; +} + type ProcessRow = { pid: number; ppid?: number; diff --git a/web/app/api/alert-rules/exposures/route.ts b/web/app/api/alert-rules/exposures/route.ts new file mode 100644 index 0000000..fc24a08 --- /dev/null +++ b/web/app/api/alert-rules/exposures/route.ts @@ -0,0 +1,33 @@ +import { headers } from "next/headers"; +import { NextRequest, NextResponse } from "next/server"; +import { z } from "zod"; +import { auth } from "@/lib/auth"; +import { authRole } from "@/lib/auth-role"; +import { ApiError, apiPost } from "@/lib/api"; + +const schema = z.object({ + definition: z.string().trim().min(1, "Definition is required."), + severity: z.enum(["low", "medium", "high", "critical"]).optional(), + is_enabled: z.boolean().optional(), +}); + +export async function POST(req: NextRequest) { + const session = await auth.api.getSession({ headers: await headers() }); + if (!session) return NextResponse.json({ error: "Unauthorized" }, { status: 401 }); + + const body = await req.json().catch(() => null); + const parsed = schema.safeParse(body); + if (!parsed.success) { + return NextResponse.json({ error: parsed.error.issues[0]?.message ?? "Invalid request." }, { status: 400 }); + } + + try { + const data = await apiPost("/api/alert-rules/exposures", parsed.data, session.user.id, authRole(session.user)); + return NextResponse.json(data); + } catch (err) { + if (err instanceof ApiError && err.status >= 400 && err.status < 500) { + return NextResponse.json({ error: err.message }, { status: 400 }); + } + return NextResponse.json({ error: "Failed to import exposures." }, { status: 502 }); + } +} diff --git a/web/app/api/cases/[id]/alerts/[alertId]/route.ts b/web/app/api/cases/[id]/alerts/[alertId]/route.ts new file mode 100644 index 0000000..3844149 --- /dev/null +++ b/web/app/api/cases/[id]/alerts/[alertId]/route.ts @@ -0,0 +1,21 @@ +import { headers } from "next/headers"; +import { NextRequest, NextResponse } from "next/server"; +import { auth } from "@/lib/auth"; +import { authRole } from "@/lib/auth-role"; +import { ApiError, apiDelete } from "@/lib/api"; + +export async function DELETE(_: NextRequest, ctx: { params: Promise<{ id: string; alertId: string }> }) { + const session = await auth.api.getSession({ headers: await headers() }); + if (!session) return NextResponse.json({ error: "Unauthorized" }, { status: 401 }); + + const { id, alertId } = await ctx.params; + try { + await apiDelete(`/api/cases/${id}/alerts/${alertId}`, session.user.id, authRole(session.user)); + return new NextResponse(null, { status: 204 }); + } catch (err) { + if (err instanceof ApiError && err.status >= 400 && err.status < 500) { + return NextResponse.json({ error: err.message }, { status: err.status }); + } + return NextResponse.json({ error: "Failed to remove alert." }, { status: 502 }); + } +} diff --git a/web/app/api/cases/[id]/notes/route.ts b/web/app/api/cases/[id]/notes/route.ts new file mode 100644 index 0000000..b9c966a --- /dev/null +++ b/web/app/api/cases/[id]/notes/route.ts @@ -0,0 +1,30 @@ +import { headers } from "next/headers"; +import { NextRequest, NextResponse } from "next/server"; +import { z } from "zod"; +import { auth } from "@/lib/auth"; +import { authRole } from "@/lib/auth-role"; +import { ApiError, apiPost } from "@/lib/api"; + +const schema = z.object({ body: z.string().trim().min(1) }); + +export async function POST(req: NextRequest, ctx: { params: Promise<{ id: string }> }) { + const session = await auth.api.getSession({ headers: await headers() }); + if (!session) return NextResponse.json({ error: "Unauthorized" }, { status: 401 }); + + const { id } = await ctx.params; + const body = await req.json().catch(() => null); + const parsed = schema.safeParse(body); + if (!parsed.success) { + return NextResponse.json({ error: parsed.error.issues[0]?.message ?? "Invalid request." }, { status: 400 }); + } + + try { + const data = await apiPost(`/api/cases/${id}/notes`, parsed.data, session.user.id, authRole(session.user)); + return NextResponse.json(data); + } catch (err) { + if (err instanceof ApiError && err.status >= 400 && err.status < 500) { + return NextResponse.json({ error: err.message }, { status: err.status }); + } + return NextResponse.json({ error: "Failed to add note." }, { status: 502 }); + } +} diff --git a/web/app/api/cases/[id]/route.ts b/web/app/api/cases/[id]/route.ts new file mode 100644 index 0000000..9bec97d --- /dev/null +++ b/web/app/api/cases/[id]/route.ts @@ -0,0 +1,53 @@ +import { headers } from "next/headers"; +import { NextRequest, NextResponse } from "next/server"; +import { z } from "zod"; +import { auth } from "@/lib/auth"; +import { authRole } from "@/lib/auth-role"; +import { ApiError, apiDelete, apiPut } from "@/lib/api"; + +const schema = z.object({ + title: z.string().trim().min(1).max(255), + summary: z.string().nullable().optional(), + status: z.enum(["open", "investigating", "contained", "resolved", "closed"]), + priority: z.enum(["low", "medium", "high", "critical"]), + assigned_to_user_id: z.string().nullable().optional(), + mitre_techniques: z.array(z.string()).optional(), +}); + +export async function PUT(req: NextRequest, ctx: { params: Promise<{ id: string }> }) { + const session = await auth.api.getSession({ headers: await headers() }); + if (!session) return NextResponse.json({ error: "Unauthorized" }, { status: 401 }); + + const { id } = await ctx.params; + const body = await req.json().catch(() => null); + const parsed = schema.safeParse(body); + if (!parsed.success) { + return NextResponse.json({ error: parsed.error.issues[0]?.message ?? "Invalid request." }, { status: 400 }); + } + + try { + const data = await apiPut(`/api/cases/${id}`, parsed.data, session.user.id, authRole(session.user)); + return NextResponse.json(data); + } catch (err) { + if (err instanceof ApiError && err.status >= 400 && err.status < 500) { + return NextResponse.json({ error: err.message }, { status: err.status }); + } + return NextResponse.json({ error: "Failed to update case." }, { status: 502 }); + } +} + +export async function DELETE(_: NextRequest, ctx: { params: Promise<{ id: string }> }) { + const session = await auth.api.getSession({ headers: await headers() }); + if (!session) return NextResponse.json({ error: "Unauthorized" }, { status: 401 }); + + const { id } = await ctx.params; + try { + await apiDelete(`/api/cases/${id}`, session.user.id, authRole(session.user)); + return new NextResponse(null, { status: 204 }); + } catch (err) { + if (err instanceof ApiError && err.status >= 400 && err.status < 500) { + return NextResponse.json({ error: err.message }, { status: err.status }); + } + return NextResponse.json({ error: "Failed to delete case." }, { status: 502 }); + } +} diff --git a/web/app/api/cases/route.ts b/web/app/api/cases/route.ts new file mode 100644 index 0000000..68c3206 --- /dev/null +++ b/web/app/api/cases/route.ts @@ -0,0 +1,35 @@ +import { headers } from "next/headers"; +import { NextRequest, NextResponse } from "next/server"; +import { z } from "zod"; +import { auth } from "@/lib/auth"; +import { authRole } from "@/lib/auth-role"; +import { ApiError, apiPost } from "@/lib/api"; + +const schema = z.object({ + title: z.string().trim().min(1).max(255), + summary: z.string().nullable().optional(), + priority: z.enum(["low", "medium", "high", "critical"]).optional(), + alert_ids: z.array(z.number().int()).optional(), + mitre_techniques: z.array(z.string()).optional(), +}); + +export async function POST(req: NextRequest) { + const session = await auth.api.getSession({ headers: await headers() }); + if (!session) return NextResponse.json({ error: "Unauthorized" }, { status: 401 }); + + const body = await req.json().catch(() => null); + const parsed = schema.safeParse(body); + if (!parsed.success) { + return NextResponse.json({ error: parsed.error.issues[0]?.message ?? "Invalid request." }, { status: 400 }); + } + + try { + const data = await apiPost("/api/cases", parsed.data, session.user.id, authRole(session.user)); + return NextResponse.json(data); + } catch (err) { + if (err instanceof ApiError && err.status >= 400 && err.status < 500) { + return NextResponse.json({ error: err.message }, { status: err.status }); + } + return NextResponse.json({ error: "Failed to create case." }, { status: 502 }); + } +} diff --git a/web/app/api/hunts/[id]/route.ts b/web/app/api/hunts/[id]/route.ts index c534993..f16b33d 100644 --- a/web/app/api/hunts/[id]/route.ts +++ b/web/app/api/hunts/[id]/route.ts @@ -14,6 +14,7 @@ const schema = z.object({ alert_on_match: z.boolean(), alert_severity: z.enum(["low", "medium", "high", "critical"]), mitre_techniques: z.array(z.string()).optional(), + is_shared: z.boolean().optional(), }); export async function PUT(req: NextRequest, ctx: { params: Promise<{ id: string }> }) { diff --git a/web/app/api/hunts/route.ts b/web/app/api/hunts/route.ts index e670562..9a30e3a 100644 --- a/web/app/api/hunts/route.ts +++ b/web/app/api/hunts/route.ts @@ -14,6 +14,7 @@ const schema = z.object({ alert_on_match: z.boolean(), alert_severity: z.enum(["low", "medium", "high", "critical"]), mitre_techniques: z.array(z.string()).optional(), + is_shared: z.boolean().optional(), }); export async function POST(req: NextRequest) { diff --git a/web/app/api/threat-intel-feeds/[id]/route.ts b/web/app/api/threat-intel-feeds/[id]/route.ts new file mode 100644 index 0000000..3dc972f --- /dev/null +++ b/web/app/api/threat-intel-feeds/[id]/route.ts @@ -0,0 +1,21 @@ +import { headers } from "next/headers"; +import { NextRequest, NextResponse } from "next/server"; +import { auth } from "@/lib/auth"; +import { authRole } from "@/lib/auth-role"; +import { ApiError, apiDelete } from "@/lib/api"; + +export async function DELETE(_: NextRequest, ctx: { params: Promise<{ id: string }> }) { + const session = await auth.api.getSession({ headers: await headers() }); + if (!session) return NextResponse.json({ error: "Unauthorized" }, { status: 401 }); + + const { id } = await ctx.params; + try { + await apiDelete(`/api/threat-intel-feeds/${id}`, session.user.id, authRole(session.user)); + return new NextResponse(null, { status: 204 }); + } catch (err) { + if (err instanceof ApiError && err.status >= 400 && err.status < 500) { + return NextResponse.json({ error: err.message }, { status: err.status }); + } + return NextResponse.json({ error: "Failed to delete feed." }, { status: 502 }); + } +} diff --git a/web/app/api/threat-intel-feeds/[id]/run/route.ts b/web/app/api/threat-intel-feeds/[id]/run/route.ts new file mode 100644 index 0000000..871124d --- /dev/null +++ b/web/app/api/threat-intel-feeds/[id]/run/route.ts @@ -0,0 +1,21 @@ +import { headers } from "next/headers"; +import { NextRequest, NextResponse } from "next/server"; +import { auth } from "@/lib/auth"; +import { authRole } from "@/lib/auth-role"; +import { ApiError, apiPost } from "@/lib/api"; + +export async function POST(_: NextRequest, ctx: { params: Promise<{ id: string }> }) { + const session = await auth.api.getSession({ headers: await headers() }); + if (!session) return NextResponse.json({ error: "Unauthorized" }, { status: 401 }); + + const { id } = await ctx.params; + try { + const data = await apiPost(`/api/threat-intel-feeds/${id}/run`, {}, session.user.id, authRole(session.user)); + return NextResponse.json(data); + } catch (err) { + if (err instanceof ApiError && err.status >= 400 && err.status < 500) { + return NextResponse.json({ error: err.message }, { status: err.status }); + } + return NextResponse.json({ error: "Failed to run feed." }, { status: 502 }); + } +} diff --git a/web/app/api/threat-intel-feeds/route.ts b/web/app/api/threat-intel-feeds/route.ts new file mode 100644 index 0000000..e9169ef --- /dev/null +++ b/web/app/api/threat-intel-feeds/route.ts @@ -0,0 +1,38 @@ +import { headers } from "next/headers"; +import { NextRequest, NextResponse } from "next/server"; +import { z } from "zod"; +import { auth } from "@/lib/auth"; +import { authRole } from "@/lib/auth-role"; +import { ApiError, apiPost } from "@/lib/api"; + +const schema = z.object({ + name: z.string().trim().min(1).max(160), + kind: z.enum(["urlhaus_csv", "urlhaus_json", "otx_pulse", "misp_events", "taxii21", "generic_csv"]), + url: z.string().url(), + auth_header_name: z.string().nullable().optional(), + auth_header_value: z.string().nullable().optional(), + default_severity: z.enum(["low", "medium", "high", "critical"]).optional(), + interval_minutes: z.number().int().min(5).max(10080).optional(), + is_enabled: z.boolean().optional(), +}); + +export async function POST(req: NextRequest) { + const session = await auth.api.getSession({ headers: await headers() }); + if (!session) return NextResponse.json({ error: "Unauthorized" }, { status: 401 }); + + const body = await req.json().catch(() => null); + const parsed = schema.safeParse(body); + if (!parsed.success) { + return NextResponse.json({ error: parsed.error.issues[0]?.message ?? "Invalid request." }, { status: 400 }); + } + + try { + const data = await apiPost("/api/threat-intel-feeds", parsed.data, session.user.id, authRole(session.user)); + return NextResponse.json(data); + } catch (err) { + if (err instanceof ApiError && err.status >= 400 && err.status < 500) { + return NextResponse.json({ error: err.message }, { status: err.status }); + } + return NextResponse.json({ error: "Failed to create feed." }, { status: 502 }); + } +} diff --git a/web/app/cases/[id]/case-detail-panel.tsx b/web/app/cases/[id]/case-detail-panel.tsx new file mode 100644 index 0000000..e302631 --- /dev/null +++ b/web/app/cases/[id]/case-detail-panel.tsx @@ -0,0 +1,347 @@ +"use client"; + +import { useState } from "react"; +import { useRouter } from "next/navigation"; +import { CheckCircle2, Loader2, MessageSquarePlus, Trash2 } from "lucide-react"; + +export type CaseDetail = { + id: number; + title: string; + summary: string | null; + status: "open" | "investigating" | "contained" | "resolved" | "closed"; + priority: "low" | "medium" | "high" | "critical"; + assigned_to_user_id: string | null; + created_by_user_id: string | null; + mitre_techniques: string[]; + alerts: Array<{ + id: number; + alert_id: number; + alert_title: string; + alert_hostname: string; + alert_severity: string; + alert_created_at: string; + added_at: string; + }>; + notes: Array<{ + id: number; + author_user_id: string | null; + body: string; + created_at: string; + }>; + created_at: string; + updated_at: string; + closed_at: string | null; +}; + +const STATUS_OPTIONS: CaseDetail["status"][] = ["open", "investigating", "contained", "resolved", "closed"]; +const PRIORITY_OPTIONS: CaseDetail["priority"][] = ["low", "medium", "high", "critical"]; + +export function CaseDetailPanel({ initial, role }: { initial: CaseDetail; role: "Admin" | "Viewer" }) { + const router = useRouter(); + const [caseDetail, setCaseDetail] = useState(initial); + const [editing, setEditing] = useState(false); + const [savingMeta, setSavingMeta] = useState(false); + const [noteBody, setNoteBody] = useState(""); + const [addingNote, setAddingNote] = useState(false); + const [error, setError] = useState(null); + + const [title, setTitle] = useState(caseDetail.title); + const [summary, setSummary] = useState(caseDetail.summary ?? ""); + const [status, setStatus] = useState(caseDetail.status); + const [priority, setPriority] = useState(caseDetail.priority); + const [techniques, setTechniques] = useState(caseDetail.mitre_techniques.join(", ")); + + async function saveMeta() { + setError(null); + setSavingMeta(true); + try { + const techList = techniques + .split(",") + .map((t) => t.trim()) + .filter((t) => t.length > 0); + const res = await fetch(`/api/cases/${caseDetail.id}`, { + method: "PUT", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + title: title.trim(), + summary: summary.trim() || null, + status, + priority, + assigned_to_user_id: caseDetail.assigned_to_user_id, + mitre_techniques: techList, + }), + }); + if (!res.ok) { + const body = (await res.json().catch(() => null)) as { error?: string } | null; + throw new Error(body?.error ?? `Save failed with ${res.status}`); + } + setEditing(false); + router.refresh(); + } catch (err) { + setError(err instanceof Error ? err.message : "Save failed."); + } finally { + setSavingMeta(false); + } + } + + async function addNote() { + if (!noteBody.trim()) return; + setError(null); + setAddingNote(true); + try { + const res = await fetch(`/api/cases/${caseDetail.id}/notes`, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ body: noteBody.trim() }), + }); + const body = (await res.json().catch(() => null)) as + | { id: number; author_user_id: string | null; body: string; created_at: string; error?: string } + | null; + if (!res.ok || !body || "error" in body) { + throw new Error((body && "error" in body && body.error) || `Add note failed with ${res.status}`); + } + setNoteBody(""); + setCaseDetail((current) => ({ + ...current, + notes: [ + { + id: body.id, + author_user_id: body.author_user_id, + body: body.body, + created_at: body.created_at, + }, + ...current.notes, + ], + })); + } catch (err) { + setError(err instanceof Error ? err.message : "Add note failed."); + } finally { + setAddingNote(false); + } + } + + async function removeAlert(alertId: number) { + if (!window.confirm("Unlink this alert from the case?")) return; + const res = await fetch(`/api/cases/${caseDetail.id}/alerts/${alertId}`, { method: "DELETE" }); + if (res.ok) { + setCaseDetail((current) => ({ + ...current, + alerts: current.alerts.filter((a) => a.alert_id !== alertId), + })); + router.refresh(); + } + } + + return ( +
+ {editing ? ( +
+ +