Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ members = [
]

[workspace.package]
version = "0.7.8"
version = "0.7.9"
authors = [
"Benjamin Bolte <ben@kscale.dev>",
"Denys Bezmenov <denys@kscale.dev>",
Expand Down
2 changes: 1 addition & 1 deletion kos-py/pykos/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""KOS Python client."""

__version__ = "0.7.8"
__version__ = "0.7.9"

from . import services
from .client import KOS
11 changes: 11 additions & 0 deletions kos-py/pykos/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from pykos.services.imu import IMUServiceClient
from pykos.services.inference import InferenceServiceClient
from pykos.services.led_matrix import LEDMatrixServiceClient
from pykos.services.policy import PolicyServiceClient
from pykos.services.process_manager import ProcessManagerServiceClient
from pykos.services.sim import SimServiceClient
from pykos.services.sound import SoundServiceClient
Expand Down Expand Up @@ -40,6 +41,7 @@ def __init__(self, ip: str = "localhost", port: int = 50051) -> None:
self._process_manager: ProcessManagerServiceClient | None = None
self._inference: InferenceServiceClient | None = None
self._sim: SimServiceClient | None = None
self._policy: PolicyServiceClient | None = None

@property
def imu(self) -> IMUServiceClient:
Expand Down Expand Up @@ -81,6 +83,14 @@ def process_manager(self) -> ProcessManagerServiceClient:
raise RuntimeError("Process Manager client not initialized! Must call `connect()` manually.")
return self._process_manager

@property
def policy(self) -> PolicyServiceClient:
if self._policy is None:
self.connect()
if self._policy is None:
raise RuntimeError("Policy client not initialized! Must call `connect()` manually.")
return self._policy

@property
def inference(self) -> InferenceServiceClient:
if self._inference is None:
Expand Down Expand Up @@ -108,6 +118,7 @@ def connect(self) -> None:
self._led_matrix = LEDMatrixServiceClient(self._channel)
self._sound = SoundServiceClient(self._channel)
self._process_manager = ProcessManagerServiceClient(self._channel)
self._policy = PolicyServiceClient(self._channel)
self._inference = InferenceServiceClient(self._channel)
self._sim = SimServiceClient(self._channel)

Expand Down
54 changes: 54 additions & 0 deletions kos-py/pykos/services/policy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""Policy service client."""

import grpc.aio
from google.protobuf.empty_pb2 import Empty

from kos_protos import policy_pb2, policy_pb2_grpc
from kos_protos.policy_pb2 import StartPolicyRequest
from pykos.services import AsyncClientBase


class PolicyServiceClient(AsyncClientBase):
"""Client for the PolicyService."""

def __init__(self, channel: grpc.aio.Channel) -> None:
super().__init__()
self.stub = policy_pb2_grpc.PolicyServiceStub(channel)

async def start_policy(
self, action: str, action_scale: float, episode_length: int, dry_run: bool
) -> policy_pb2.StartPolicyResponse:
"""Start policy execution.

Args:
action: The action string for the policy
action_scale: Scale factor for actions
episode_length: Length of the episode
dry_run: Whether to perform a dry run

Returns:
The response from the server.
"""
request = StartPolicyRequest(
action=action,
action_scale=action_scale,
episode_length=episode_length,
dry_run=dry_run,
)
return await self.stub.StartPolicy(request)

async def stop_policy(self, request: Empty = Empty()) -> policy_pb2.StopPolicyResponse:
"""Stop policy execution.

Returns:
The response from the server.
"""
return await self.stub.StopPolicy(request)

async def get_state(self, request: Empty = Empty()) -> policy_pb2.GetStateResponse:
"""Get the current policy state.

Returns:
The response from the server containing the policy state.
"""
return await self.stub.GetState(request)
13 changes: 12 additions & 1 deletion kos-stub/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
mod actuator;
mod imu;
mod policy;
mod process_manager;
use crate::actuator::StubActuator;
use crate::imu::StubIMU;
use crate::policy::StubPolicy;
use crate::process_manager::StubProcessManager;
use async_trait::async_trait;
use kos::hal::Operation;
use kos::kos_proto::actuator::actuator_service_server::ActuatorServiceServer;
use kos::kos_proto::imu::imu_service_server::ImuServiceServer;
use kos::kos_proto::policy::policy_service_server::PolicyServiceServer;
use kos::kos_proto::process_manager::process_manager_service_server::ProcessManagerServiceServer;
use kos::services::{ActuatorServiceImpl, IMUServiceImpl, ProcessManagerServiceImpl};
use kos::services::{
ActuatorServiceImpl, IMUServiceImpl, PolicyServiceImpl, ProcessManagerServiceImpl,
};
use kos::{services::OperationsServiceImpl, Platform, ServiceEnum};

use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
Expand Down Expand Up @@ -52,6 +58,7 @@ impl Platform for StubPlatform {
let actuator = StubActuator::new(operations_service.clone());
let imu = StubIMU::new(operations_service.clone());
let process_manager = StubProcessManager::new();
let policy = StubPolicy::new();

Ok(vec![
ServiceEnum::Actuator(ActuatorServiceServer::new(ActuatorServiceImpl::new(
Expand All @@ -61,6 +68,10 @@ impl Platform for StubPlatform {
ProcessManagerServiceImpl::new(Arc::new(process_manager)),
)),
ServiceEnum::Imu(ImuServiceServer::new(IMUServiceImpl::new(Arc::new(imu)))),
ServiceEnum::Policy(PolicyServiceServer::new(
// Add this block
PolicyServiceImpl::new(Arc::new(policy)),
)),
])
})
}
Expand Down
96 changes: 96 additions & 0 deletions kos-stub/src/policy.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
use async_trait::async_trait;
use eyre::Result;
use kos::hal::{GetStateResponse, Policy, StartPolicyResponse, StopPolicyResponse};
use kos::kos_proto::common::{Error, ErrorCode};
use std::collections::HashMap;
use std::sync::Mutex;
use uuid::Uuid;

pub struct StubPolicy {
policy_uuid: Mutex<Option<String>>,
state: Mutex<HashMap<String, String>>,
}

impl Default for StubPolicy {
fn default() -> Self {
Self::new()
}
}

impl StubPolicy {
pub fn new() -> Self {
StubPolicy {
policy_uuid: Mutex::new(None),
state: Mutex::new(HashMap::new()),
}
}
}

#[async_trait]
impl Policy for StubPolicy {
async fn start_policy(
&self,
action: String,
action_scale: f32,
episode_length: i32,
dry_run: bool,
) -> Result<StartPolicyResponse> {
let mut policy_uuid = self.policy_uuid.lock().unwrap();
if policy_uuid.is_some() {
return Ok(StartPolicyResponse {
policy_uuid: None,
error: Some(Error {
code: ErrorCode::InvalidArgument as i32,
message: "Policy is already running".to_string(),
}),
});
}

let new_uuid = Uuid::new_v4().to_string();
*policy_uuid = Some(new_uuid.clone());

// Update state with policy parameters
let mut state = self.state.lock().unwrap();
state.insert("action".to_string(), action);
state.insert("action_scale".to_string(), action_scale.to_string());
state.insert("episode_length".to_string(), episode_length.to_string());
state.insert("dry_run".to_string(), dry_run.to_string());

Ok(StartPolicyResponse {
policy_uuid: Some(new_uuid),
error: None,
})
}

async fn stop_policy(&self) -> Result<StopPolicyResponse> {
let mut policy_uuid = self.policy_uuid.lock().unwrap();
if policy_uuid.is_none() {
return Ok(StopPolicyResponse {
policy_uuid: None,
error: Some(Error {
code: ErrorCode::InvalidArgument as i32,
message: "Policy is not running".to_string(),
}),
});
}

let stopped_uuid = policy_uuid.take().unwrap();

// Clear the state when stopping
let mut state = self.state.lock().unwrap();
state.clear();

Ok(StopPolicyResponse {
policy_uuid: Some(stopped_uuid),
error: None,
})
}

async fn get_state(&self) -> Result<GetStateResponse> {
let state = self.state.lock().unwrap();
Ok(GetStateResponse {
state: state.clone(),
error: None,
})
}
}
1 change: 1 addition & 0 deletions kos/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ fn main() {
"kos/sim.proto",
"kos/inference.proto",
"kos/process_manager.proto",
"kos/policy.proto",
"kos/system.proto",
"kos/led_matrix.proto",
"kos/sound.proto",
Expand Down
44 changes: 44 additions & 0 deletions kos/proto/kos/policy.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
syntax = "proto3";

package kos.policy;

import "google/protobuf/empty.proto";
import "kos/common.proto";

option go_package = "kos/policy;policy";
option java_package = "com.kos.policy";
option csharp_namespace = "KOS.Policy";

// The PolicyService manages policy execution.
service PolicyService {
// Starts policy execution.
rpc StartPolicy(StartPolicyRequest) returns (StartPolicyResponse);

// Stops policy execution.
rpc StopPolicy(google.protobuf.Empty) returns (StopPolicyResponse);

// Gets the current policy state.
rpc GetState(google.protobuf.Empty) returns (GetStateResponse);
}

message StartPolicyRequest {
string action = 1;
float action_scale = 2;
int32 episode_length = 3;
bool dry_run = 4;
}

message StartPolicyResponse {
optional string policy_uuid = 1;
kos.common.Error error = 2;
}

message StopPolicyResponse {
optional string policy_uuid = 1;
kos.common.Error error = 2;
}

message GetStateResponse {
map<string, string> state = 1;
kos.common.Error error = 2;
}
1 change: 1 addition & 0 deletions kos/src/daemon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ fn add_service_to_router(
ServiceEnum::Inference(svc) => router.add_service(svc),
ServiceEnum::LEDMatrix(svc) => router.add_service(svc),
ServiceEnum::Sound(svc) => router.add_service(svc),
ServiceEnum::Policy(svc) => router.add_service(svc),
}
}

Expand Down
4 changes: 4 additions & 0 deletions kos/src/grpc_interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ pub mod kos {
tonic::include_proto!("kos/kos.processmanager");
}

pub mod policy {
tonic::include_proto!("kos/kos.policy");
}

pub mod system {
tonic::include_proto!("kos/kos.system");
}
Expand Down
17 changes: 15 additions & 2 deletions kos/src/hal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ pub use crate::grpc_interface::google::longrunning::*;
pub use crate::grpc_interface::kos;
pub use crate::grpc_interface::kos::common::ActionResponse;
pub use crate::kos_proto::{
actuator::*, common::ActionResult, imu::*, inference::*, led_matrix::*, process_manager::*,
sound::*,
actuator::*, common::ActionResult, imu::*, inference::*, led_matrix::*, policy::*,
process_manager::*, sound::*,
};
use async_trait::async_trait;
use bytes::Bytes;
Expand Down Expand Up @@ -55,6 +55,19 @@ pub trait ProcessManager: Send + Sync {
async fn stop_kclip(&self) -> Result<KClipStopResponse>;
}

#[async_trait]
pub trait Policy: Send + Sync {
async fn start_policy(
&self,
action: String,
action_scale: f32,
episode_length: i32,
dry_run: bool,
) -> Result<StartPolicyResponse>;
async fn stop_policy(&self) -> Result<StopPolicyResponse>;
async fn get_state(&self) -> Result<GetStateResponse>;
}

#[async_trait]
pub trait Inference: Send + Sync {
async fn upload_model(
Expand Down
10 changes: 9 additions & 1 deletion kos/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@ use hal::imu_service_server::ImuServiceServer;
use hal::inference_service_server::InferenceServiceServer;
use hal::led_matrix_service_server::LedMatrixServiceServer;
use hal::process_manager_service_server::ProcessManagerServiceServer;
use hal::policy_service_server::PolicyServiceServer;
use hal::sound_service_server::SoundServiceServer;
use services::OperationsServiceImpl;
use services::{
ActuatorServiceImpl, IMUServiceImpl, InferenceServiceImpl, LEDMatrixServiceImpl,
ProcessManagerServiceImpl, SoundServiceImpl,
ProcessManagerServiceImpl, SoundServiceImpl, PolicyServiceImpl,
};
use std::fmt::Debug;
use std::future::Future;
Expand All @@ -48,6 +49,12 @@ impl Debug for ProcessManagerServiceImpl {
}
}

impl Debug for PolicyServiceImpl {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "PolicyServiceImpl")
}
}

impl Debug for InferenceServiceImpl {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "InferenceServiceImpl")
Expand All @@ -74,6 +81,7 @@ pub enum ServiceEnum {
Inference(InferenceServiceServer<InferenceServiceImpl>),
LEDMatrix(LedMatrixServiceServer<LEDMatrixServiceImpl>),
Sound(SoundServiceServer<SoundServiceImpl>),
Policy(PolicyServiceServer<PolicyServiceImpl>),
}

#[async_trait]
Expand Down
Loading