diff --git a/crates/basic-api/Cargo.lock b/crates/basic-api/Cargo.lock index 7d95ade..848d218 100644 --- a/crates/basic-api/Cargo.lock +++ b/crates/basic-api/Cargo.lock @@ -85,7 +85,7 @@ dependencies = [ "quote", "regex", "rustc-hash", - "shlex", + "shlex 1.3.0", "syn", ] @@ -101,6 +101,16 @@ version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" +[[package]] +name = "cc" +version = "1.2.63" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "556e016178bb5662a08681bbe0f00f8e17631781a4dfc8c45e466e4b185ec27f" +dependencies = [ + "find-msvc-tools", + "shlex 2.0.1", +] + [[package]] name = "cexpr" version = "0.6.0" @@ -143,6 +153,12 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "find-msvc-tools" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" + [[package]] name = "form_urlencoded" version = "1.2.2" @@ -559,6 +575,12 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" +[[package]] +name = "shlex" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8fadd59c855ef2080decdef8ff161eb6661b86933c9d82e5ba29dc602a55aba" + [[package]] name = "signal-hook-registry" version = "1.4.8" @@ -695,6 +717,7 @@ name = "vdb-ffi" version = "0.1.0" dependencies = [ "bindgen", + "cc", ] [[package]] diff --git a/crates/basic-api/src/api/handlers.rs b/crates/basic-api/src/api/handlers.rs index cd9b94b..b9ee97a 100644 --- a/crates/basic-api/src/api/handlers.rs +++ b/crates/basic-api/src/api/handlers.rs @@ -1,88 +1,23 @@ use axum::{extract::State, Json}; use crate::{ - api::response::{bad_request, ApiError, ApiResult}, - models::{ - DeleteRequest, InsertRequest, MessageResponse, 
SearchRequest, SearchResponse, - }, - state::AppState, + api::response::ApiResult, + models::{DeleteRequest, InsertRequest, MessageResponse, SearchRequest, SearchResponse}, + service::VectorService, }; pub async fn health() -> &'static str { "ok" } -pub async fn insert( - State(state): State, - Json(request): Json, -) -> ApiResult { - if request.id.trim().is_empty() { - return Err(bad_request("id must not be empty")); - } - - validate_vector(&request.vector)?; - - let mut engine = state.engine.lock().expect("engine mutex poisoned"); - - if let Some(expected_dimension) = engine.dimension() { - if request.vector.len() != expected_dimension { - return Err(bad_request("vector dimension does not match existing vectors")); - } - } - - engine.insert(request.id, request.vector); - - Ok(Json(MessageResponse { - message: "vector inserted".to_string(), - })) +pub async fn insert(State(service): State, Json(request): Json) -> ApiResult { + Ok(Json(service.insert(request.id, request.vector)?)) } -pub async fn search( - State(state): State, - Json(request): Json, -) -> ApiResult { - if request.k == 0 { - return Err(bad_request("k must be greater than 0")); - } - - validate_vector(&request.query)?; - - let engine = state.engine.lock().expect("engine mutex poisoned"); - - if let Some(expected_dimension) = engine.dimension() { - if request.query.len() != expected_dimension { - return Err(bad_request("query dimension does not match stored vectors")); - } - } - - let results = engine.search(request.query, request.k); - Ok(Json(SearchResponse { results })) +pub async fn search(State(service): State, Json(request): Json) -> ApiResult { + Ok(Json(service.search(request.query, request.k)?)) } -pub async fn delete_vector( - State(state): State, - Json(request): Json, -) -> ApiResult { - if request.id.trim().is_empty() { - return Err(bad_request("id must not be empty")); - } - - let mut engine = state.engine.lock().expect("engine mutex poisoned"); - engine.delete(&request.id); - - 
Ok(Json(MessageResponse { - message: "vector deleted".to_string(), - })) -} - -fn validate_vector(vector: &[f32]) -> Result<(), ApiError> { - if vector.is_empty() { - return Err(bad_request("vector must not be empty")); - } - - if vector.iter().any(|value| !value.is_finite()) { - return Err(bad_request("vector must contain only finite floats")); - } - - Ok(()) +pub async fn delete_vector(State(service): State, Json(request): Json) -> ApiResult { + Ok(Json(service.delete(&request.id)?)) } diff --git a/crates/basic-api/src/engine.rs b/crates/basic-api/src/engine.rs index 89200da..3cb3c62 100644 --- a/crates/basic-api/src/engine.rs +++ b/crates/basic-api/src/engine.rs @@ -12,7 +12,7 @@ pub trait VectorEngine { pub struct FfiEngineAdapter { engine: FfiVectorEngine, - stored_dimensions: HashMap, // maps vector ID : dimension + stored_dimensions: HashMap, dimension: Option, } @@ -38,12 +38,12 @@ impl VectorEngine for FfiEngineAdapter { fn search(&self, query: Vec, k: usize) -> Vec { let results = self.engine.search(&query, k); - (0..results.len()) - .map(|index| SearchResult { + (0..results.len()) // creating range from 0 to len - 1 + .map(|index| SearchResult { // for every number, build SearchResult id: results.id_at(index), score: results.score_at(index), }) - .collect() + .collect() // return vector of SearchResults } fn delete(&mut self, id: &str) { @@ -60,64 +60,3 @@ impl VectorEngine for FfiEngineAdapter { self.dimension } } - - -// Old hardcoded implementation minus the FFI -// ------------------------------------------------ -// #[derive(Default)] -// pub struct FlatIndex { -// vectors: Vec<(String, Vec)>, -// } - -// impl FlatIndex { -// pub fn new() -> Self { -// Self::default() -// } -// } -// impl VectorEngine for FlatIndex { -// fn insert(&mut self, id: String, vector: Vec) { -// self.delete(&id); -// self.vectors.push((id, vector)); -// } - -// fn search(&self, query: Vec, k: usize) -> Vec { -// let mut results: Vec = self -// .vectors -// .iter() -// 
.map(|(id, vector)| SearchResult { -// id: id.clone(), -// score: cosine_similarity(&query, vector), -// }) -// .collect(); - -// results.sort_by(|left, right| right.score.total_cmp(&left.score)); -// results.truncate(k); -// results -// } - -// fn delete(&mut self, id: &str) { -// self.vectors.retain(|(stored_id, _)| stored_id != id); -// } - -// fn dimension(&self) -> Option { -// self.vectors.first().map(|(_, vector)| vector.len()) -// } -// } - -// fn cosine_similarity(left: &[f32], right: &[f32]) -> f32 { -// let mut dot = 0.0; -// let mut left_norm = 0.0; -// let mut right_norm = 0.0; - -// for (left_value, right_value) in left.iter().zip(right.iter()) { -// dot += left_value * right_value; -// left_norm += left_value * left_value; -// right_norm += right_value * right_value; -// } - -// if left_norm == 0.0 || right_norm == 0.0 { -// return 0.0; -// } - -// dot / (left_norm.sqrt() * right_norm.sqrt()) -// } \ No newline at end of file diff --git a/crates/basic-api/src/main.rs b/crates/basic-api/src/main.rs index 76e0f5d..f2a2c03 100644 --- a/crates/basic-api/src/main.rs +++ b/crates/basic-api/src/main.rs @@ -1,23 +1,24 @@ mod api; mod engine; mod models; +mod service; mod state; use axum::{routing::{get, post}, Router}; use tokio::net::TcpListener; -use crate::{api::handlers, state::AppState}; +use crate::{api::handlers, service::VectorService, state::AppState}; #[tokio::main] async fn main() { - let state = AppState::new(); + let service = VectorService::new(AppState::new()); let app = Router::new() .route("/health", get(handlers::health)) .route("/insert", post(handlers::insert)) .route("/search", post(handlers::search)) .route("/delete", post(handlers::delete_vector)) - .with_state(state); + .with_state(service); let listener = TcpListener::bind("127.0.0.1:3000") .await diff --git a/crates/basic-api/src/service.rs b/crates/basic-api/src/service.rs new file mode 100644 index 0000000..8c31fc3 --- /dev/null +++ b/crates/basic-api/src/service.rs @@ -0,0 +1,93 
@@ +use crate::{ + api::response::{bad_request, ApiError}, + models::{MessageResponse, SearchResponse}, + state::AppState, +}; + +#[derive(Clone)] +pub struct VectorService { + state: AppState, +} + +impl VectorService { + pub fn new(state: AppState) -> Self { + Self { state } + } + + pub fn insert(&self, id: String, vector: Vec) -> Result { + self.validate_id(&id)?; + self.validate_vector(&vector)?; + + let mut engine = self.state.engine.lock().expect("engine lock failed"); + + self.validate_dimension(engine.dimension(), vector.len(), "vector dimension does not match existing vectors")?; + + engine.insert(id, vector); + + Ok(MessageResponse { + message: "vector inserted".to_string(), + }) + } + + pub fn search(&self, query: Vec, k: usize) -> Result { + self.validate_k(k)?; + self.validate_vector(&query)?; + + let engine = self.state.engine.lock().expect("engine lock failed"); + + self.validate_dimension(engine.dimension(), query.len(), "query dimension does not match stored vectors")?; + + let results = engine.search(query, k); + Ok(SearchResponse { results }) + } + + pub fn delete(&self, id: &str) -> Result { + self.validate_id(id)?; + + let mut engine = self.state.engine.lock().expect("engine lock failed"); + engine.delete(id); + + Ok(MessageResponse { + message: "vector deleted".to_string(), + }) + } + + fn validate_id(&self, id: &str) -> Result<(), ApiError> { + if id.trim().is_empty() { + return Err(bad_request("id must not be empty")); + } + + Ok(()) + } + + fn validate_k(&self, k: usize) -> Result<(), ApiError> { + if k == 0 { + return Err(bad_request("k must be greater than 0")); + } + + Ok(()) + } + + fn validate_vector(&self, vector: &[f32]) -> Result<(), ApiError> { + if vector.is_empty() { + return Err(bad_request("vector must not be empty")); + } + + if vector.iter().any(|value| !value.is_finite()) { + return Err(bad_request("vector must contain only finite floats")); + } + + Ok(()) + } + + fn validate_dimension(&self, expected_dimension: Option, 
actual_dimension: usize, message: &str) -> Result<(), ApiError> { + // `expected_dimension` is either `None` or `Some(value)`; the sizes are only compared when a value is present + if let Some(expected_dimension) = expected_dimension { + if actual_dimension != expected_dimension { + return Err(bad_request(message)); + } + } + + Ok(()) + } +}