diff --git a/Cargo.toml b/Cargo.toml index a74e65d..67114ab 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,7 +7,7 @@ members = [ ] [workspace.package] -version = "0.7.8" +version = "0.7.9" authors = [ "Benjamin Bolte ", "Denys Bezmenov ", diff --git a/kos-py/pykos/__init__.py b/kos-py/pykos/__init__.py index 41941fa..c77043f 100644 --- a/kos-py/pykos/__init__.py +++ b/kos-py/pykos/__init__.py @@ -1,6 +1,6 @@ """KOS Python client.""" -__version__ = "0.7.8" +__version__ = "0.7.9" from . import services from .client import KOS diff --git a/kos-py/pykos/client.py b/kos-py/pykos/client.py index 4d8e64f..c397b11 100644 --- a/kos-py/pykos/client.py +++ b/kos-py/pykos/client.py @@ -12,6 +12,7 @@ from pykos.services.imu import IMUServiceClient from pykos.services.inference import InferenceServiceClient from pykos.services.led_matrix import LEDMatrixServiceClient +from pykos.services.policy import PolicyServiceClient from pykos.services.process_manager import ProcessManagerServiceClient from pykos.services.sim import SimServiceClient from pykos.services.sound import SoundServiceClient @@ -40,6 +41,7 @@ def __init__(self, ip: str = "localhost", port: int = 50051) -> None: self._process_manager: ProcessManagerServiceClient | None = None self._inference: InferenceServiceClient | None = None self._sim: SimServiceClient | None = None + self._policy: PolicyServiceClient | None = None @property def imu(self) -> IMUServiceClient: @@ -81,6 +83,14 @@ def process_manager(self) -> ProcessManagerServiceClient: raise RuntimeError("Process Manager client not initialized! Must call `connect()` manually.") return self._process_manager + @property + def policy(self) -> PolicyServiceClient: + if self._policy is None: + self.connect() + if self._policy is None: + raise RuntimeError("Policy client not initialized! Must call `connect()` manually.") + return self._policy + @property def inference(self) -> InferenceServiceClient: if self._inference is None: @@ -108,6 +118,7 @@ def connect(self) -> None: self._led_matrix = LEDMatrixServiceClient(self._channel) self._sound = SoundServiceClient(self._channel) self._process_manager = ProcessManagerServiceClient(self._channel) + self._policy = PolicyServiceClient(self._channel) self._inference = InferenceServiceClient(self._channel) self._sim = SimServiceClient(self._channel) diff --git a/kos-py/pykos/services/policy.py b/kos-py/pykos/services/policy.py new file mode 100644 index 0000000..9cdaf17 --- /dev/null +++ b/kos-py/pykos/services/policy.py @@ -0,0 +1,54 @@ +"""Policy service client.""" + +import grpc.aio +from google.protobuf.empty_pb2 import Empty + +from kos_protos import policy_pb2, policy_pb2_grpc +from kos_protos.policy_pb2 import StartPolicyRequest +from pykos.services import AsyncClientBase + + +class PolicyServiceClient(AsyncClientBase): + """Client for the PolicyService.""" + + def __init__(self, channel: grpc.aio.Channel) -> None: + super().__init__() + self.stub = policy_pb2_grpc.PolicyServiceStub(channel) + + async def start_policy( + self, action: str, action_scale: float, episode_length: int, dry_run: bool + ) -> policy_pb2.StartPolicyResponse: + """Start policy execution. + + Args: + action: The action string for the policy + action_scale: Scale factor for actions + episode_length: Length of the episode + dry_run: Whether to perform a dry run + + Returns: + The response from the server. + """ + request = StartPolicyRequest( + action=action, + action_scale=action_scale, + episode_length=episode_length, + dry_run=dry_run, + ) + return await self.stub.StartPolicy(request) + + async def stop_policy(self, request: Empty = Empty()) -> policy_pb2.StopPolicyResponse: + """Stop policy execution. + + Returns: + The response from the server. + """ + return await self.stub.StopPolicy(request) + + async def get_state(self, request: Empty = Empty()) -> policy_pb2.GetStateResponse: + """Get the current policy state. + + Returns: + The response from the server containing the policy state. + """ + return await self.stub.GetState(request) diff --git a/kos-stub/src/lib.rs b/kos-stub/src/lib.rs index 8cba942..decf255 100644 --- a/kos-stub/src/lib.rs +++ b/kos-stub/src/lib.rs @@ -1,16 +1,22 @@ mod actuator; mod imu; +mod policy; mod process_manager; use crate::actuator::StubActuator; use crate::imu::StubIMU; +use crate::policy::StubPolicy; use crate::process_manager::StubProcessManager; use async_trait::async_trait; use kos::hal::Operation; use kos::kos_proto::actuator::actuator_service_server::ActuatorServiceServer; use kos::kos_proto::imu::imu_service_server::ImuServiceServer; +use kos::kos_proto::policy::policy_service_server::PolicyServiceServer; use kos::kos_proto::process_manager::process_manager_service_server::ProcessManagerServiceServer; -use kos::services::{ActuatorServiceImpl, IMUServiceImpl, ProcessManagerServiceImpl}; +use kos::services::{ + ActuatorServiceImpl, IMUServiceImpl, PolicyServiceImpl, ProcessManagerServiceImpl, +}; use kos::{services::OperationsServiceImpl, Platform, ServiceEnum}; + use std::future::Future; use std::pin::Pin; use std::sync::Arc; @@ -52,6 +58,7 @@ impl Platform for StubPlatform { let actuator = StubActuator::new(operations_service.clone()); let imu = StubIMU::new(operations_service.clone()); let process_manager = StubProcessManager::new(); + let policy = StubPolicy::new(); Ok(vec![ ServiceEnum::Actuator(ActuatorServiceServer::new(ActuatorServiceImpl::new( @@ -61,6 +68,10 @@ impl Platform for StubPlatform { ProcessManagerServiceImpl::new(Arc::new(process_manager)), )), ServiceEnum::Imu(ImuServiceServer::new(IMUServiceImpl::new(Arc::new(imu)))), + ServiceEnum::Policy(PolicyServiceServer::new( + // Add this block + PolicyServiceImpl::new(Arc::new(policy)), + )), ]) }) } diff --git a/kos-stub/src/policy.rs b/kos-stub/src/policy.rs new file mode 100644 index 0000000..89d603e --- /dev/null +++ b/kos-stub/src/policy.rs @@ -0,0 +1,96 @@ +use async_trait::async_trait; +use eyre::Result; +use kos::hal::{GetStateResponse, Policy, StartPolicyResponse, StopPolicyResponse}; +use kos::kos_proto::common::{Error, ErrorCode}; +use std::collections::HashMap; +use std::sync::Mutex; +use uuid::Uuid; + +pub struct StubPolicy { + policy_uuid: Mutex>, + state: Mutex>, +} + +impl Default for StubPolicy { + fn default() -> Self { + Self::new() + } +} + +impl StubPolicy { + pub fn new() -> Self { + StubPolicy { + policy_uuid: Mutex::new(None), + state: Mutex::new(HashMap::new()), + } + } +} + +#[async_trait] +impl Policy for StubPolicy { + async fn start_policy( + &self, + action: String, + action_scale: f32, + episode_length: i32, + dry_run: bool, + ) -> Result { + let mut policy_uuid = self.policy_uuid.lock().unwrap(); + if policy_uuid.is_some() { + return Ok(StartPolicyResponse { + policy_uuid: None, + error: Some(Error { + code: ErrorCode::InvalidArgument as i32, + message: "Policy is already running".to_string(), + }), + }); + } + + let new_uuid = Uuid::new_v4().to_string(); + *policy_uuid = Some(new_uuid.clone()); + + // Update state with policy parameters + let mut state = self.state.lock().unwrap(); + state.insert("action".to_string(), action); + state.insert("action_scale".to_string(), action_scale.to_string()); + state.insert("episode_length".to_string(), episode_length.to_string()); + state.insert("dry_run".to_string(), dry_run.to_string()); + + Ok(StartPolicyResponse { + policy_uuid: Some(new_uuid), + error: None, + }) + } + + async fn stop_policy(&self) -> Result { + let mut policy_uuid = self.policy_uuid.lock().unwrap(); + if policy_uuid.is_none() { + return Ok(StopPolicyResponse { + policy_uuid: None, + error: Some(Error { + code: ErrorCode::InvalidArgument as i32, + message: "Policy is not running".to_string(), + }), + }); + } + + let stopped_uuid = policy_uuid.take().unwrap(); + + // Clear the state when stopping + let mut state = self.state.lock().unwrap(); + state.clear(); + + Ok(StopPolicyResponse { + policy_uuid: Some(stopped_uuid), + error: None, + }) + } + + async fn get_state(&self) -> Result { + let state = self.state.lock().unwrap(); + Ok(GetStateResponse { + state: state.clone(), + error: None, + }) + } +} diff --git a/kos/build.rs b/kos/build.rs index cc4a475..621d5db 100644 --- a/kos/build.rs +++ b/kos/build.rs @@ -16,6 +16,7 @@ fn main() { "kos/sim.proto", "kos/inference.proto", "kos/process_manager.proto", + "kos/policy.proto", "kos/system.proto", "kos/led_matrix.proto", "kos/sound.proto", diff --git a/kos/proto/kos/policy.proto b/kos/proto/kos/policy.proto new file mode 100644 index 0000000..61964cc --- /dev/null +++ b/kos/proto/kos/policy.proto @@ -0,0 +1,44 @@ +syntax = "proto3"; + +package kos.policy; + +import "google/protobuf/empty.proto"; +import "kos/common.proto"; + +option go_package = "kos/policy;policy"; +option java_package = "com.kos.policy"; +option csharp_namespace = "KOS.Policy"; + +// The PolicyService manages policy execution. +service PolicyService { + // Starts policy execution. + rpc StartPolicy(StartPolicyRequest) returns (StartPolicyResponse); + + // Stops policy execution. + rpc StopPolicy(google.protobuf.Empty) returns (StopPolicyResponse); + + // Gets the current policy state. + rpc GetState(google.protobuf.Empty) returns (GetStateResponse); +} + +message StartPolicyRequest { + string action = 1; + float action_scale = 2; + int32 episode_length = 3; + bool dry_run = 4; +} + +message StartPolicyResponse { + optional string policy_uuid = 1; + kos.common.Error error = 2; +} + +message StopPolicyResponse { + optional string policy_uuid = 1; + kos.common.Error error = 2; +} + +message GetStateResponse { + map state = 1; + kos.common.Error error = 2; +} \ No newline at end of file diff --git a/kos/src/daemon.rs b/kos/src/daemon.rs index a69ba83..ee8b067 100644 --- a/kos/src/daemon.rs +++ b/kos/src/daemon.rs @@ -40,6 +40,7 @@ fn add_service_to_router( ServiceEnum::Inference(svc) => router.add_service(svc), ServiceEnum::LEDMatrix(svc) => router.add_service(svc), ServiceEnum::Sound(svc) => router.add_service(svc), + ServiceEnum::Policy(svc) => router.add_service(svc), } } diff --git a/kos/src/grpc_interface.rs b/kos/src/grpc_interface.rs index 8c21486..cc9554c 100644 --- a/kos/src/grpc_interface.rs +++ b/kos/src/grpc_interface.rs @@ -19,6 +19,10 @@ pub mod kos { tonic::include_proto!("kos/kos.processmanager"); } + pub mod policy { + tonic::include_proto!("kos/kos.policy"); + } + pub mod system { tonic::include_proto!("kos/kos.system"); } diff --git a/kos/src/hal.rs b/kos/src/hal.rs index 4d043c9..da941a9 100644 --- a/kos/src/hal.rs +++ b/kos/src/hal.rs @@ -2,8 +2,8 @@ pub use crate::grpc_interface::google::longrunning::*; pub use crate::grpc_interface::kos; pub use crate::grpc_interface::kos::common::ActionResponse; pub use crate::kos_proto::{ - actuator::*, common::ActionResult, imu::*, inference::*, led_matrix::*, process_manager::*, - sound::*, + actuator::*, common::ActionResult, imu::*, inference::*, led_matrix::*, policy::*, + process_manager::*, sound::*, }; use async_trait::async_trait; use bytes::Bytes; @@ -55,6 +55,19 @@ pub trait ProcessManager: Send + Sync { async fn stop_kclip(&self) -> Result; } +#[async_trait] +pub trait Policy: Send + Sync { + async fn start_policy( + &self, + action: String, + action_scale: f32, + episode_length: i32, + dry_run: bool, + ) -> Result; + async fn stop_policy(&self) -> Result; + async fn get_state(&self) -> Result; +} + #[async_trait] pub trait Inference: Send + Sync { async fn upload_model( diff --git a/kos/src/lib.rs b/kos/src/lib.rs index b4da246..a472b62 100644 --- a/kos/src/lib.rs +++ b/kos/src/lib.rs @@ -19,11 +19,12 @@ use hal::imu_service_server::ImuServiceServer; use hal::inference_service_server::InferenceServiceServer; use hal::led_matrix_service_server::LedMatrixServiceServer; use hal::process_manager_service_server::ProcessManagerServiceServer; +use hal::policy_service_server::PolicyServiceServer; use hal::sound_service_server::SoundServiceServer; use services::OperationsServiceImpl; use services::{ ActuatorServiceImpl, IMUServiceImpl, InferenceServiceImpl, LEDMatrixServiceImpl, - ProcessManagerServiceImpl, SoundServiceImpl, + ProcessManagerServiceImpl, SoundServiceImpl, PolicyServiceImpl, }; use std::fmt::Debug; use std::future::Future; @@ -48,6 +49,12 @@ impl Debug for ProcessManagerServiceImpl { } } +impl Debug for PolicyServiceImpl { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "PolicyServiceImpl") + } +} + impl Debug for InferenceServiceImpl { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "InferenceServiceImpl") @@ -74,6 +81,7 @@ pub enum ServiceEnum { Inference(InferenceServiceServer), LEDMatrix(LedMatrixServiceServer), Sound(SoundServiceServer), + Policy(PolicyServiceServer), } #[async_trait] diff --git a/kos/src/services/mod.rs b/kos/src/services/mod.rs index 075c804..3f8b31c 100644 --- a/kos/src/services/mod.rs +++ b/kos/src/services/mod.rs @@ -4,6 +4,7 @@ mod inference; mod krec_logger; mod led_matrix; mod operations; +mod policy; mod process_manager; mod sound; @@ -13,5 +14,6 @@ pub use inference::*; pub use krec_logger::*; pub use led_matrix::*; pub use operations::*; +pub use policy::*; pub use process_manager::*; pub use sound::*; diff --git a/kos/src/services/policy.rs b/kos/src/services/policy.rs new file mode 100644 index 0000000..71abd7a --- /dev/null +++ b/kos/src/services/policy.rs @@ -0,0 +1,58 @@ +use crate::hal::Policy; +use crate::kos_proto::policy::policy_service_server::PolicyService; +use crate::kos_proto::policy::*; +use std::sync::Arc; +use tonic::{Request, Response, Status}; +use tracing::trace; + +pub struct PolicyServiceImpl { + policy: Arc, +} + +impl PolicyServiceImpl { + pub fn new(policy: Arc) -> Self { + Self { policy } + } +} + +#[tonic::async_trait] +impl PolicyService for PolicyServiceImpl { + async fn start_policy( + &self, + request: Request, + ) -> Result, Status> { + trace!("Starting Policy"); + let req = request.get_ref(); + + Ok(Response::new( + self.policy + .start_policy( + req.action.clone(), + req.action_scale, + req.episode_length, + req.dry_run, + ) + .await + .map_err(|e| Status::internal(format!("Failed to start policy: {:?}", e)))?, + )) + } + + async fn stop_policy( + &self, + _request: Request<()>, + ) -> Result, Status> { + trace!("Stopping Policy"); + + Ok(Response::new(self.policy.stop_policy().await.map_err( + |e| Status::internal(format!("Failed to stop policy: {:?}", e)), + )?)) + } + + async fn get_state(&self, _request: Request<()>) -> Result, Status> { + trace!("Getting Policy State"); + + Ok(Response::new(self.policy.get_state().await.map_err( + |e| Status::internal(format!("Failed to get policy state: {:?}", e)), + )?)) + } +}