From 9a1a3d07a9c5ccb492fa0af869d70c51bc522848 Mon Sep 17 00:00:00 2001 From: Aritra Date: Wed, 29 Apr 2026 20:25:05 -0700 Subject: [PATCH 1/2] refactor(inferenceserver): adjust actors to support provisioning inferenceservers in multiple clusters --- go/components/inferenceserver/BUILD.bazel | 1 + .../inferenceserver/clientfactory/BUILD.bazel | 24 ++ .../inferenceserver/clientfactory/factory.go | 221 ++++++++++++++++++ .../clientfactory/interface.go | 26 +++ .../inferenceserver/clientfactory/module.go | 22 ++ .../clientfactory/secrets/BUILD.bazel | 14 ++ .../clientfactory/secrets/provider.go | 95 ++++++++ go/components/inferenceserver/module.go | 3 + .../inferenceserver/plugins/oss/BUILD.bazel | 2 +- .../plugins/oss/common/BUILD.bazel | 12 +- .../plugins/oss/common/rollout.go | 68 ++++++ .../plugins/oss/creation/BUILD.bazel | 1 + .../plugins/oss/creation/backend_provision.go | 87 ++++--- .../plugins/oss/creation/condition_plugin.go | 15 +- .../plugins/oss/creation/health_check.go | 89 +++++-- .../oss/creation/model_config_provision.go | 54 +++-- .../plugins/oss/creation/validation.go | 10 + .../plugins/oss/deletion/BUILD.bazel | 1 + .../plugins/oss/deletion/cleanup.go | 92 +++++--- .../plugins/oss/deletion/condition_plugin.go | 11 +- .../inferenceserver/plugins/oss/plugin.go | 96 +++++--- proto/api/v2/inference_server.proto | 38 +++ .../demo/inference/inferenceserver.yaml | 7 + .../resources/rbac-inferenceserver.yaml | 30 +++ python/michelangelo/cli/sandbox/sandbox.py | 69 ++++++ 25 files changed, 942 insertions(+), 146 deletions(-) create mode 100644 go/components/inferenceserver/clientfactory/BUILD.bazel create mode 100644 go/components/inferenceserver/clientfactory/factory.go create mode 100644 go/components/inferenceserver/clientfactory/interface.go create mode 100644 go/components/inferenceserver/clientfactory/module.go create mode 100644 go/components/inferenceserver/clientfactory/secrets/BUILD.bazel create mode 100644 
go/components/inferenceserver/clientfactory/secrets/provider.go create mode 100644 go/components/inferenceserver/plugins/oss/common/rollout.go create mode 100644 python/michelangelo/cli/sandbox/resources/rbac-inferenceserver.yaml diff --git a/go/components/inferenceserver/BUILD.bazel b/go/components/inferenceserver/BUILD.bazel index ecf5e7917..158270504 100644 --- a/go/components/inferenceserver/BUILD.bazel +++ b/go/components/inferenceserver/BUILD.bazel @@ -15,6 +15,7 @@ go_library( "//go/api/utils:go_default_library", "//go/base/conditions/engine:go_default_library", "//go/base/conditions/interfaces:go_default_library", + "//go/components/inferenceserver/clientfactory:go_default_library", "//go/components/inferenceserver/plugins:go_default_library", "//proto/api/v2:go_default_library", "@io_k8s_api//core/v1:go_default_library", diff --git a/go/components/inferenceserver/clientfactory/BUILD.bazel b/go/components/inferenceserver/clientfactory/BUILD.bazel new file mode 100644 index 000000000..4589a77f9 --- /dev/null +++ b/go/components/inferenceserver/clientfactory/BUILD.bazel @@ -0,0 +1,24 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "go_default_library", + srcs = [ + "factory.go", + "interface.go", + "module.go", + ], + importpath = "github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/clientfactory", + visibility = ["//visibility:public"], + deps = [ + "//go/components/inferenceserver/clientfactory/secrets:go_default_library", + "//proto/api/v2:go_default_library", + "@io_k8s_apimachinery//pkg/runtime:go_default_library", + "@io_k8s_client_go//rest:go_default_library", + "@io_k8s_client_go//tools/clientcmd:go_default_library", + "@io_k8s_client_go//tools/clientcmd/api:go_default_library", + "@io_k8s_client_go//util/flowcontrol:go_default_library", + "@io_k8s_sigs_controller_runtime//pkg/client:go_default_library", + "@org_uber_go_fx//:go_default_library", + "@org_uber_go_zap//:go_default_library", + ], +) diff 
--git a/go/components/inferenceserver/clientfactory/factory.go b/go/components/inferenceserver/clientfactory/factory.go new file mode 100644 index 000000000..c420e4d58 --- /dev/null +++ b/go/components/inferenceserver/clientfactory/factory.go @@ -0,0 +1,221 @@ +package clientfactory + +import ( + "context" + "crypto/tls" + "crypto/x509" + "fmt" + "net/http" + "sync" + "time" + + "go.uber.org/zap" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/client-go/rest" + "k8s.io/client-go/tools/clientcmd" + clientcmdapi "k8s.io/client-go/tools/clientcmd/api" + "k8s.io/client-go/util/flowcontrol" + "sigs.k8s.io/controller-runtime/pkg/client" + + "github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/clientfactory/secrets" + v2pb "github.com/michelangelo-ai/michelangelo/proto-go/api/v2" +) + +const ( + userAgent = "michelangelo-inferenceserver" + + // httpClientTimeout caps any single HTTP request issued through the factory. + httpClientTimeout = 30 * time.Second +) + +var _ ClientFactory = &remoteClientFactory{} + +// remoteClientFactory builds and caches Kubernetes clients for ClusterTargets. +// +// Both controller-runtime clients and HTTP clients are cached keyed by the connection +// tuple (cluster_id + host + port) so a steady-state reconcile does not rebuild a TLS +// transport on every actor invocation. +type remoteClientFactory struct { + secretProvider secrets.SecretProvider + scheme *runtime.Scheme + logger *zap.Logger + + kubeClients sync.Map // key string → client.Client + httpClients sync.Map // key string → *http.Client + mu sync.Mutex +} + +// NewRemoteClientFactory constructs a ClientFactory. +// +// Parameters: +// - secretProvider: source of CA bundles and bearer tokens for remote clusters. +// - scheme: the runtime.Scheme used to build remote clients (must include all CRDs +// the controller will read/write on remote clusters). +// - logger: structured logger. 
+func NewRemoteClientFactory( + secretProvider secrets.SecretProvider, + scheme *runtime.Scheme, + logger *zap.Logger, +) ClientFactory { + return &remoteClientFactory{ + secretProvider: secretProvider, + scheme: scheme, + logger: logger.With(zap.String("component", "clientfactory")), + } +} + +// GetClient returns a controller-runtime client for the given ClusterTarget. The +// client is built (and cached) using credentials retrieved from the SecretProvider. +func (f *remoteClientFactory) GetClient(ctx context.Context, cluster *v2pb.ClusterTarget) (client.Client, error) { + if cluster.GetKubernetes() == nil { + return nil, fmt.Errorf("cluster %q has no kubernetes connection spec", cluster.GetClusterId()) + } + + key := cacheKey(cluster) + if cached, ok := f.kubeClients.Load(key); ok { + return cached.(client.Client), nil + } + + // Building a kube client requires hitting the SecretProvider and constructing TLS + // state. Guard with a mutex so concurrent reconciles for the same cluster don't + // duplicate work. + f.mu.Lock() + defer f.mu.Unlock() + + if cached, ok := f.kubeClients.Load(key); ok { + return cached.(client.Client), nil + } + + cfg, err := f.buildRESTConfig(ctx, cluster) + if err != nil { + return nil, fmt.Errorf("build REST config for cluster %q: %w", cluster.GetClusterId(), err) + } + + kubeClient, err := client.New(cfg, client.Options{Scheme: f.scheme}) + if err != nil { + return nil, fmt.Errorf("create kube client for cluster %q: %w", cluster.GetClusterId(), err) + } + + f.kubeClients.Store(key, kubeClient) + f.logger.Info("Built kube client for cluster", + zap.String("cluster_id", cluster.GetClusterId()), + zap.String("host", cluster.GetKubernetes().GetHost())) + return kubeClient, nil +} + +// GetHTTPClient returns an HTTP client whose transport authenticates with a bearer +// token over TLS validated against the cluster's CA. 
+func (f *remoteClientFactory) GetHTTPClient(ctx context.Context, cluster *v2pb.ClusterTarget) (*http.Client, error) { + if cluster.GetKubernetes() == nil { + return nil, fmt.Errorf("cluster %q has no kubernetes connection spec", cluster.GetClusterId()) + } + + key := cacheKey(cluster) + if cached, ok := f.httpClients.Load(key); ok { + return cached.(*http.Client), nil + } + + f.mu.Lock() + defer f.mu.Unlock() + + if cached, ok := f.httpClients.Load(key); ok { + return cached.(*http.Client), nil + } + + auth, err := f.secretProvider.GetClientAuth(ctx, cluster) + if err != nil { + return nil, fmt.Errorf("get client auth for cluster %q: %w", cluster.GetClusterId(), err) + } + + caPool := x509.NewCertPool() + if !caPool.AppendCertsFromPEM([]byte(auth.CertificateAuthorityData)) { + return nil, fmt.Errorf("parse CA certificate for cluster %q: invalid PEM", cluster.GetClusterId()) + } + + httpClient := &http.Client{ + Transport: &bearerTokenRoundTripper{ + token: auth.ClientTokenData, + rt: &http.Transport{ + TLSClientConfig: &tls.Config{ + RootCAs: caPool, + MinVersion: tls.VersionTLS12, + }, + }, + }, + Timeout: httpClientTimeout, + } + + f.httpClients.Store(key, httpClient) + f.logger.Info("Built HTTP client for cluster", + zap.String("cluster_id", cluster.GetClusterId()), + zap.String("host", cluster.GetKubernetes().GetHost())) + return httpClient, nil +} + +// buildRESTConfig assembles a *rest.Config for a cluster from the connection spec +// and credentials retrieved from the SecretProvider. 
+func (f *remoteClientFactory) buildRESTConfig(ctx context.Context, cluster *v2pb.ClusterTarget) (*rest.Config, error) { + auth, err := f.secretProvider.GetClientAuth(ctx, cluster) + if err != nil { + return nil, fmt.Errorf("get client auth: %w", err) + } + + server := fmt.Sprintf("%s:%s", cluster.GetKubernetes().GetHost(), cluster.GetKubernetes().GetPort()) + + // Build a kubeconfig in-memory and resolve it through clientcmd, which handles + // CA-data + bearer-token wiring via its established schema. + apiCfg := &clientcmdapi.Config{ + Kind: "Config", + APIVersion: "v1", + Clusters: map[string]*clientcmdapi.Cluster{ + "remote": { + Server: server, + CertificateAuthorityData: []byte(auth.CertificateAuthorityData), + }, + }, + AuthInfos: map[string]*clientcmdapi.AuthInfo{ + userAgent: {Token: auth.ClientTokenData}, + }, + Contexts: map[string]*clientcmdapi.Context{ + userAgent + "@remote": { + Cluster: "remote", + AuthInfo: userAgent, + }, + }, + CurrentContext: userAgent + "@remote", + } + + cfg, err := clientcmd.NewDefaultClientConfig(*apiCfg, &clientcmd.ConfigOverrides{}).ClientConfig() + if err != nil { + return nil, fmt.Errorf("resolve client config: %w", err) + } + + // Disable client-side rate limiting; rely on the API server's Priority and Fairness. + cfg.RateLimiter = flowcontrol.NewFakeAlwaysRateLimiter() + cfg.ContentType = runtime.ContentTypeJSON + + return rest.AddUserAgent(cfg, userAgent), nil +} + +// cacheKey produces a stable cache key for a ClusterTarget. +func cacheKey(cluster *v2pb.ClusterTarget) string { + return fmt.Sprintf("%s|%s:%s", + cluster.GetClusterId(), + cluster.GetKubernetes().GetHost(), + cluster.GetKubernetes().GetPort(), + ) +} + +// bearerTokenRoundTripper injects a bearer Authorization header on every request. +type bearerTokenRoundTripper struct { + token string + rt http.RoundTripper +} + +// RoundTrip clones the request to avoid mutating the caller's headers, then forwards +// to the underlying transport. 
+func (rt *bearerTokenRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + clone := req.Clone(req.Context()) + clone.Header.Set("Authorization", "Bearer "+rt.token) + return rt.rt.RoundTrip(clone) +} diff --git a/go/components/inferenceserver/clientfactory/interface.go b/go/components/inferenceserver/clientfactory/interface.go new file mode 100644 index 000000000..2c9a1a7ff --- /dev/null +++ b/go/components/inferenceserver/clientfactory/interface.go @@ -0,0 +1,26 @@ +// Package clientfactory provides Kubernetes API clients for ClusterTargets that an +// InferenceServer is provisioned across. +// +//go:generate mamockgen ClientFactory +package clientfactory + +import ( + "context" + "net/http" + + "sigs.k8s.io/controller-runtime/pkg/client" + + v2pb "github.com/michelangelo-ai/michelangelo/proto-go/api/v2" +) + +// ClientFactory returns Kubernetes API clients for a ClusterTarget. The clients are +// built using credentials retrieved from the SecretProvider. +type ClientFactory interface { + // GetClient returns a controller-runtime client for the given cluster. + GetClient(ctx context.Context, cluster *v2pb.ClusterTarget) (client.Client, error) + + // GetHTTPClient returns an HTTP client for talking to user-space services in the + // given cluster. The client is configured with TLS using the cluster's CA bundle + // and authenticates with the bearer token. 
+ GetHTTPClient(ctx context.Context, cluster *v2pb.ClusterTarget) (*http.Client, error) +} diff --git a/go/components/inferenceserver/clientfactory/module.go b/go/components/inferenceserver/clientfactory/module.go new file mode 100644 index 000000000..d862d1faa --- /dev/null +++ b/go/components/inferenceserver/clientfactory/module.go @@ -0,0 +1,22 @@ +package clientfactory + +import ( + "go.uber.org/fx" + "go.uber.org/zap" + "sigs.k8s.io/controller-runtime/pkg/client" + + "github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/clientfactory/secrets" +) + +// Module wires the ClientFactory into the fx graph. +var Module = fx.Options( + fx.Provide(newClientFactory), +) + +func newClientFactory(kubeClient client.Client, logger *zap.Logger) ClientFactory { + return NewRemoteClientFactory( + secrets.NewProvider(kubeClient), + kubeClient.Scheme(), + logger, + ) +} diff --git a/go/components/inferenceserver/clientfactory/secrets/BUILD.bazel b/go/components/inferenceserver/clientfactory/secrets/BUILD.bazel new file mode 100644 index 000000000..5bfa2eaba --- /dev/null +++ b/go/components/inferenceserver/clientfactory/secrets/BUILD.bazel @@ -0,0 +1,14 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "go_default_library", + srcs = ["provider.go"], + importpath = "github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/clientfactory/secrets", + visibility = ["//visibility:public"], + deps = [ + "//proto/api/v2:go_default_library", + "@io_k8s_api//core/v1:go_default_library", + "@io_k8s_apimachinery//pkg/types:go_default_library", + "@io_k8s_sigs_controller_runtime//pkg/client:go_default_library", + ], +) diff --git a/go/components/inferenceserver/clientfactory/secrets/provider.go b/go/components/inferenceserver/clientfactory/secrets/provider.go new file mode 100644 index 000000000..b2ed86c3c --- /dev/null +++ b/go/components/inferenceserver/clientfactory/secrets/provider.go @@ -0,0 +1,95 @@ +// Package secrets 
retrieves Kubernetes API credentials for ClusterTargets from +// a Kubernetes-backed secret store. +// +//go:generate mamockgen SecretProvider +package secrets + +import ( + "context" + "fmt" + + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/client" + + v2pb "github.com/michelangelo-ai/michelangelo/proto-go/api/v2" +) + +// secretsNamespace is where cluster-specific credential secrets are expected to live. +// The secret names themselves are taken from ClusterTarget.Kubernetes.{CaDataTag,TokenTag}. +const secretsNamespace = "default" + +// Keys within each Kubernetes Secret's .data map. +const ( + caDataKey = "cadata" + tokenKey = "token" +) + +// ClientAuth contains the credentials needed to authenticate to a Kubernetes cluster. +type ClientAuth struct { + // CertificateAuthorityData is the PEM-encoded CA bundle that signs the API server's + // serving certificate. + CertificateAuthorityData string + // ClientTokenData is the bearer token presented to the API server. + ClientTokenData string +} + +// SecretProvider retrieves cluster authentication credentials for a ClusterTarget. +type SecretProvider interface { + GetClientAuth(ctx context.Context, cluster *v2pb.ClusterTarget) (ClientAuth, error) +} + +// Provider implements SecretProvider by reading two Kubernetes Secret objects from the +// control-plane cluster: +// 1. One for the CA bundle +// 2. One for the bearer token +// The secret names are pulled from the ClusterTarget's `CaDataTag` and `TokenTag` fields. +// +// NOTE: This implementation is intended for sandbox and testing use. Production deployments +// should use an external secret manager (e.g. HashiCorp Vault, AWS Secrets Manager, GCP +// Secret Manager) and provide their own SecretProvider implementation. +type Provider struct { + kubeClient client.Client +} + +// NewProvider returns a SecretProvider backed by the given Kubernetes client. 
+func NewProvider(kubeClient client.Client) *Provider { + return &Provider{kubeClient: kubeClient} +} + +// GetClientAuth fetches the CA certificate and bearer token secrets for the given +// ClusterTarget and returns them as a ClientAuth value. +func (p *Provider) GetClientAuth(ctx context.Context, cluster *v2pb.ClusterTarget) (ClientAuth, error) { + if cluster.GetKubernetes() == nil { + return ClientAuth{}, fmt.Errorf("cluster %q has no kubernetes connection spec", cluster.GetClusterId()) + } + + caSecretName := cluster.GetKubernetes().GetCaDataTag() + caSecret, err := p.fetchSecret(ctx, caSecretName) + if err != nil { + return ClientAuth{}, fmt.Errorf("CA secret for cluster %q: %w", cluster.GetClusterId(), err) + } + + tokenSecretName := cluster.GetKubernetes().GetTokenTag() + tokenSecret, err := p.fetchSecret(ctx, tokenSecretName) + if err != nil { + return ClientAuth{}, fmt.Errorf("token secret for cluster %q: %w", cluster.GetClusterId(), err) + } + + return ClientAuth{ + CertificateAuthorityData: string(caSecret.Data[caDataKey]), + ClientTokenData: string(tokenSecret.Data[tokenKey]), + }, nil +} + +func (p *Provider) fetchSecret(ctx context.Context, name string) (*corev1.Secret, error) { + if name == "" { + return nil, fmt.Errorf("empty secret name") + } + secret := &corev1.Secret{} + key := types.NamespacedName{Name: name, Namespace: secretsNamespace} + if err := p.kubeClient.Get(ctx, key, secret); err != nil { + return nil, fmt.Errorf("get secret %s/%s: %w", secretsNamespace, name, err) + } + return secret, nil +} diff --git a/go/components/inferenceserver/module.go b/go/components/inferenceserver/module.go index 29e9bda4d..f9a544495 100644 --- a/go/components/inferenceserver/module.go +++ b/go/components/inferenceserver/module.go @@ -4,10 +4,13 @@ import ( "go.uber.org/fx" "k8s.io/client-go/tools/record" ctrl "sigs.k8s.io/controller-runtime" + + "github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/clientfactory" ) // Module provides the 
inference server controller with all dependencies var Module = fx.Options( + clientfactory.Module, fx.Provide(newEventRecorder), fx.Provide(NewReconciler), fx.Invoke(register), diff --git a/go/components/inferenceserver/plugins/oss/BUILD.bazel b/go/components/inferenceserver/plugins/oss/BUILD.bazel index af9d88a19..111e5e3ec 100644 --- a/go/components/inferenceserver/plugins/oss/BUILD.bazel +++ b/go/components/inferenceserver/plugins/oss/BUILD.bazel @@ -11,6 +11,7 @@ go_library( deps = [ "//go/base/conditions/interfaces:go_default_library", "//go/components/inferenceserver/backends:go_default_library", + "//go/components/inferenceserver/clientfactory:go_default_library", "//go/components/inferenceserver/modelconfig:go_default_library", "//go/components/inferenceserver/plugins:go_default_library", "//go/components/inferenceserver/plugins/oss/creation:go_default_library", @@ -19,7 +20,6 @@ go_library( "//proto/api/v2:go_default_library", "@io_k8s_api//core/v1:go_default_library", "@io_k8s_client_go//tools/record:go_default_library", - "@io_k8s_sigs_controller_runtime//pkg/client:go_default_library", "@org_uber_go_fx//:go_default_library", "@org_uber_go_zap//:go_default_library", ], diff --git a/go/components/inferenceserver/plugins/oss/common/BUILD.bazel b/go/components/inferenceserver/plugins/oss/common/BUILD.bazel index e95143c78..4f2fde7ef 100644 --- a/go/components/inferenceserver/plugins/oss/common/BUILD.bazel +++ b/go/components/inferenceserver/plugins/oss/common/BUILD.bazel @@ -2,7 +2,17 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library") go_library( name = "go_default_library", - srcs = ["constants.go"], + srcs = [ + "constants.go", + "rollout.go", + ], importpath = "github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/plugins/oss/common", visibility = ["//visibility:public"], + deps = [ + "//go/base/conditions/utils:go_default_library", + "//go/components/inferenceserver/clientfactory:go_default_library", + 
"//proto/api:go_default_library", + "//proto/api/v2:go_default_library", + "@io_k8s_sigs_controller_runtime//pkg/client:go_default_library", + ], ) diff --git a/go/components/inferenceserver/plugins/oss/common/rollout.go b/go/components/inferenceserver/plugins/oss/common/rollout.go new file mode 100644 index 000000000..a43601f9a --- /dev/null +++ b/go/components/inferenceserver/plugins/oss/common/rollout.go @@ -0,0 +1,68 @@ +package common + +import ( + "context" + "fmt" + + "sigs.k8s.io/controller-runtime/pkg/client" + + conditionsutil "github.com/michelangelo-ai/michelangelo/go/base/conditions/utils" + "github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/clientfactory" + apipb "github.com/michelangelo-ai/michelangelo/proto-go/api" + v2pb "github.com/michelangelo-ai/michelangelo/proto-go/api/v2" +) + +const ( + // ClusterRolloutStrategyAnnotation is the annotation key for specifying the per-cluster rollout strategy. + ClusterRolloutStrategyAnnotation = "michelangelo.ai/cluster-rollout-strategy" + rollingStrategy = "rolling" +) + +// GetRolloutStrategy reads the rollout strategy annotation. Defaults to "rolling" when absent. +func GetRolloutStrategy(resource *v2pb.InferenceServer) string { + if anns := resource.GetMetadata().GetAnnotations(); anns != nil { + if v, ok := anns[ClusterRolloutStrategyAnnotation]; ok { + return v + } + } + return rollingStrategy +} + +// IsKnownRolloutStrategy reports whether strategy is a recognized rollout strategy value. +func IsKnownRolloutStrategy(strategy string) bool { + return strategy == rollingStrategy +} + +// RunRolling iterates cluster_targets in spec order and calls doWork on the first cluster +// that isDone returns false for. Returns TRUE once all clusters are done. 
+func RunRolling( + ctx context.Context, + factory clientfactory.ClientFactory, + targets []*v2pb.ClusterTarget, + condition *apipb.Condition, + isDone func(ctx context.Context, kubeClient client.Client, target *v2pb.ClusterTarget) (bool, error), + doWork func(ctx context.Context, kubeClient client.Client, target *v2pb.ClusterTarget) error, +) (*apipb.Condition, error) { + for _, target := range targets { + kubeClient, err := factory.GetClient(ctx, target) + if err != nil { + return conditionsutil.GenerateFalseCondition(condition, "ClientError", + fmt.Sprintf("%s: %v", target.GetClusterId(), err)), nil + } + done, err := isDone(ctx, kubeClient, target) + if err != nil { + return conditionsutil.GenerateFalseCondition(condition, "StatusCheckFailed", + fmt.Sprintf("%s: %v", target.GetClusterId(), err)), nil + } + if done { + continue + } + if err := doWork(ctx, kubeClient, target); err != nil { + return conditionsutil.GenerateFalseCondition(condition, "ProvisionFailed", + fmt.Sprintf("%s: %v", target.GetClusterId(), err)), nil + } + return conditionsutil.GenerateFalseCondition(condition, "RollingInProgress", + fmt.Sprintf("provisioning cluster %s", target.GetClusterId())), nil + } + return conditionsutil.GenerateTrueCondition(condition), nil +} diff --git a/go/components/inferenceserver/plugins/oss/creation/BUILD.bazel b/go/components/inferenceserver/plugins/oss/creation/BUILD.bazel index 637f90d2e..4955679f7 100644 --- a/go/components/inferenceserver/plugins/oss/creation/BUILD.bazel +++ b/go/components/inferenceserver/plugins/oss/creation/BUILD.bazel @@ -15,6 +15,7 @@ go_library( "//go/base/conditions/interfaces:go_default_library", "//go/base/conditions/utils:go_default_library", "//go/components/inferenceserver/backends:go_default_library", + "//go/components/inferenceserver/clientfactory:go_default_library", "//go/components/inferenceserver/modelconfig:go_default_library", "//go/components/inferenceserver/plugins/oss/common:go_default_library", 
"//proto/api:go_default_library", diff --git a/go/components/inferenceserver/plugins/oss/creation/backend_provision.go b/go/components/inferenceserver/plugins/oss/creation/backend_provision.go index 101c69542..634d13145 100644 --- a/go/components/inferenceserver/plugins/oss/creation/backend_provision.go +++ b/go/components/inferenceserver/plugins/oss/creation/backend_provision.go @@ -3,14 +3,15 @@ package creation import ( "context" "fmt" + "strings" "go.uber.org/zap" - "sigs.k8s.io/controller-runtime/pkg/client" conditionInterfaces "github.com/michelangelo-ai/michelangelo/go/base/conditions/interfaces" conditionUtils "github.com/michelangelo-ai/michelangelo/go/base/conditions/utils" "github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/backends" + "github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/clientfactory" "github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/plugins/oss/common" apipb "github.com/michelangelo-ai/michelangelo/proto-go/api" v2pb "github.com/michelangelo-ai/michelangelo/proto-go/api/v2" @@ -20,17 +21,17 @@ var _ conditionInterfaces.ConditionActor[*v2pb.InferenceServer] = &BackendProvis // BackendProvisioningActor provisions Kubernetes resources for inference servers. type BackendProvisionActor struct { - client client.Client - registry *backends.Registry - logger *zap.Logger + clientFactory clientfactory.ClientFactory + registry *backends.Registry + logger *zap.Logger } // NewBackendProvisionActor creates a condition actor for inference server provisioning. 
-func NewBackendProvisionActor(client client.Client, registry *backends.Registry, logger *zap.Logger) conditionInterfaces.ConditionActor[*v2pb.InferenceServer] { +func NewBackendProvisionActor(clientFactory clientfactory.ClientFactory, registry *backends.Registry, logger *zap.Logger) conditionInterfaces.ConditionActor[*v2pb.InferenceServer] { return &BackendProvisionActor{ - client: client, - registry: registry, - logger: logger, + clientFactory: clientFactory, + registry: registry, + logger: logger, } } @@ -48,23 +49,36 @@ func (a *BackendProvisionActor) Retrieve(ctx context.Context, resource *v2pb.Inf return conditionUtils.GenerateFalseCondition(condition, "BackendNotFound", fmt.Sprintf("Failed to get backend: %v", err)), nil } - // Check if inference server resources exist - status, err := backend.GetServerStatus(ctx, a.logger, a.client, resource.Name, resource.Namespace) - if err != nil { - a.logger.Error("Failed to check backend provisioning status", - zap.Error(err), - zap.String("operation", "get_backend_provisioning_status"), - zap.String("namespace", resource.Namespace), - zap.String("backend", resource.Name)) - return conditionUtils.GenerateFalseCondition(condition, "BackendProvisioningCheckFailed", fmt.Sprintf("Failed to check backend status: %v", err)), nil + var failures []string + for _, target := range resource.Spec.ClusterTargets { + kubeClient, err := a.clientFactory.GetClient(ctx, target) + if err != nil { + failures = append(failures, fmt.Sprintf("%s: client error: %v", target.GetClusterId(), err)) + continue + } + + // Check if inference server resources exist + status, err := backend.GetServerStatus(ctx, a.logger, kubeClient, resource.Name, resource.Namespace) + if err != nil { + a.logger.Error("Failed to check backend provisioning status", + zap.Error(err), + zap.String("operation", "get_backend_provisioning_status"), + zap.String("namespace", resource.Namespace), + zap.String("backend", resource.Name), + zap.String("cluster_id", 
target.GetClusterId())) + failures = append(failures, fmt.Sprintf("%s: %v", target.GetClusterId(), err)) + continue + } + + if status.State != v2pb.INFERENCE_SERVER_STATE_SERVING { + failures = append(failures, fmt.Sprintf("%s: state %s", target.GetClusterId(), status.State)) + } } - switch status.State { - case v2pb.INFERENCE_SERVER_STATE_SERVING: - return conditionUtils.GenerateTrueCondition(condition), nil - default: - return conditionUtils.GenerateFalseCondition(condition, "BackendProvisioningFailed", fmt.Sprintf("Backend state is not serving: %v", status.State)), nil + if len(failures) > 0 { + return conditionUtils.GenerateFalseCondition(condition, "BackendProvisioningFailed", strings.Join(failures, "; ")), nil } + return conditionUtils.GenerateTrueCondition(condition), nil } // Run creates the Kubernetes deployment, service, and related resources for inference servers. @@ -76,15 +90,24 @@ func (a *BackendProvisionActor) Run(ctx context.Context, resource *v2pb.Inferenc return conditionUtils.GenerateFalseCondition(condition, "BackendNotFound", fmt.Sprintf("Failed to get backend: %v", err)), nil } - _, err = backend.CreateServer(ctx, a.logger, a.client, resource) - if err != nil { - a.logger.Error("Failed to create backend", - zap.Error(err), - zap.String("operation", "create_backend"), - zap.String("namespace", resource.Namespace), - zap.String("inferenceServer", resource.Name)) - return conditionUtils.GenerateFalseCondition(condition, "BackendProvisionFailed", fmt.Sprintf("Failed to provision backend: %v", err)), err + isDone := func(ctx context.Context, kubeClient client.Client, target *v2pb.ClusterTarget) (bool, error) { + status, err := backend.GetServerStatus(ctx, a.logger, kubeClient, resource.Name, resource.Namespace) + if err != nil { + return false, err + } + return status.State == v2pb.INFERENCE_SERVER_STATE_SERVING, nil } - - return conditionUtils.GenerateTrueCondition(condition), nil + doWork := func(ctx context.Context, kubeClient client.Client, 
target *v2pb.ClusterTarget) error { + _, err := backend.CreateServer(ctx, a.logger, kubeClient, resource) + if err != nil { + a.logger.Error("Failed to create backend", + zap.Error(err), + zap.String("operation", "create_backend"), + zap.String("namespace", resource.Namespace), + zap.String("inferenceServer", resource.Name), + zap.String("cluster_id", target.GetClusterId())) + } + return err + } + return common.RunRolling(ctx, a.clientFactory, resource.Spec.ClusterTargets, condition, isDone, doWork) } diff --git a/go/components/inferenceserver/plugins/oss/creation/condition_plugin.go b/go/components/inferenceserver/plugins/oss/creation/condition_plugin.go index fbbd2d8a6..d7ad4b39c 100644 --- a/go/components/inferenceserver/plugins/oss/creation/condition_plugin.go +++ b/go/components/inferenceserver/plugins/oss/creation/condition_plugin.go @@ -3,10 +3,9 @@ package creation import ( "go.uber.org/zap" - "sigs.k8s.io/controller-runtime/pkg/client" - conditionInterfaces "github.com/michelangelo-ai/michelangelo/go/base/conditions/interfaces" "github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/backends" + "github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/clientfactory" modelconfig "github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/modelconfig" apipb "github.com/michelangelo-ai/michelangelo/proto-go/api" v2pb "github.com/michelangelo-ai/michelangelo/proto-go/api/v2" @@ -14,16 +13,16 @@ import ( // CreationPlugin orchestrates the condition actors for inference server creation. type CreationPlugin struct { - client client.Client + clientFactory clientfactory.ClientFactory registry *backends.Registry modelConfigProvider modelconfig.ModelConfigProvider logger *zap.Logger } // NewCreationPlugin creates a plugin that manages validation, provisioning, health checks, and routing. 
-func NewCreationPlugin(client client.Client, registry *backends.Registry, modelConfigProvider modelconfig.ModelConfigProvider, logger *zap.Logger) conditionInterfaces.Plugin[*v2pb.InferenceServer] { +func NewCreationPlugin(clientFactory clientfactory.ClientFactory, registry *backends.Registry, modelConfigProvider modelconfig.ModelConfigProvider, logger *zap.Logger) conditionInterfaces.Plugin[*v2pb.InferenceServer] { return &CreationPlugin{ - client: client, + clientFactory: clientFactory, registry: registry, modelConfigProvider: modelConfigProvider, logger: logger, @@ -34,9 +33,9 @@ func NewCreationPlugin(client client.Client, registry *backends.Registry, modelC func (p *CreationPlugin) GetActors() []conditionInterfaces.ConditionActor[*v2pb.InferenceServer] { return []conditionInterfaces.ConditionActor[*v2pb.InferenceServer]{ NewValidationActor(p.registry, p.logger), - NewBackendProvisionActor(p.client, p.registry, p.logger), - NewModelConfigProvisionActor(p.client, p.modelConfigProvider, p.logger), - NewHealthCheckActor(p.client, p.registry, p.logger), + NewBackendProvisionActor(p.clientFactory, p.registry, p.logger), + NewModelConfigProvisionActor(p.clientFactory, p.modelConfigProvider, p.logger), + NewHealthCheckActor(p.clientFactory, p.registry, p.logger), } } diff --git a/go/components/inferenceserver/plugins/oss/creation/health_check.go b/go/components/inferenceserver/plugins/oss/creation/health_check.go index 97c48acd8..1c7acf2fd 100644 --- a/go/components/inferenceserver/plugins/oss/creation/health_check.go +++ b/go/components/inferenceserver/plugins/oss/creation/health_check.go @@ -3,13 +3,14 @@ package creation import ( "context" "fmt" + "strings" "go.uber.org/zap" - "sigs.k8s.io/controller-runtime/pkg/client" conditionInterfaces "github.com/michelangelo-ai/michelangelo/go/base/conditions/interfaces" conditionUtils "github.com/michelangelo-ai/michelangelo/go/base/conditions/utils" 
"github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/backends" + "github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/clientfactory" "github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/plugins/oss/common" apipb "github.com/michelangelo-ai/michelangelo/proto-go/api" v2pb "github.com/michelangelo-ai/michelangelo/proto-go/api/v2" @@ -19,17 +20,17 @@ var _ conditionInterfaces.ConditionActor[*v2pb.InferenceServer] = &HealthCheckAc // HealthCheckActor verifies inference server health by polling backend health endpoints. type HealthCheckActor struct { - registry *backends.Registry - logger *zap.Logger - client client.Client + registry *backends.Registry + logger *zap.Logger + clientFactory clientfactory.ClientFactory } // NewHealthCheckActor creates a condition actor for inference server health verification. -func NewHealthCheckActor(client client.Client, registry *backends.Registry, logger *zap.Logger) conditionInterfaces.ConditionActor[*v2pb.InferenceServer] { +func NewHealthCheckActor(clientFactory clientfactory.ClientFactory, registry *backends.Registry, logger *zap.Logger) conditionInterfaces.ConditionActor[*v2pb.InferenceServer] { return &HealthCheckActor{ - client: client, - registry: registry, - logger: logger, + clientFactory: clientFactory, + registry: registry, + logger: logger, } } @@ -47,24 +48,66 @@ func (a *HealthCheckActor) Retrieve(ctx context.Context, resource *v2pb.Inferenc return conditionUtils.GenerateFalseCondition(condition, "BackendNotFound", fmt.Sprintf("Failed to get backend: %v", err)), nil } - healthy, err := backend.IsHealthy(ctx, a.logger, a.client, resource.Name, resource.Namespace) - if err == nil && healthy { - return conditionUtils.GenerateTrueCondition(condition), nil - } else if err != nil { - a.logger.Error("Health check failed", - zap.Error(err), - zap.String("operation", "health_check"), - zap.String("namespace", resource.Namespace), - zap.String("inferenceServer", 
resource.Name)) - return conditionUtils.GenerateFalseCondition(condition, "HealthCheckFailed", fmt.Sprintf("Health check error: %v", err)), nil - } + clusterStatuses, failures := a.checkClusterHealth(ctx, backend, resource) + resource.Status.ClusterStatuses = clusterStatuses - return conditionUtils.GenerateFalseCondition(condition, "HealthCheckFailed", "Server is not healthy"), nil + if len(failures) > 0 { + return conditionUtils.GenerateFalseCondition(condition, "HealthCheckFailed", strings.Join(failures, "; ")), nil + } + return conditionUtils.GenerateTrueCondition(condition), nil } -// Run returns a failed condition since health check failures cannot be automatically remediated. +// Run returns the condition unchanged. Health check failures are not auto-remediable. func (a *HealthCheckActor) Run(ctx context.Context, resource *v2pb.InferenceServer, condition *apipb.Condition) (*apipb.Condition, error) { - // This method is only run when Retrieve() fails. - // If Retrieve() failed, then there's nothing we can do here, simply return the condition. return condition, nil } + +// checkClusterHealth polls IsHealthy on each cluster target and returns per-cluster statuses +// along with a list of failure messages for clusters that are not yet healthy. 
+func (a *HealthCheckActor) checkClusterHealth(ctx context.Context, backend backends.Backend, resource *v2pb.InferenceServer) ([]*v2pb.ClusterTargetStatus, []string) { + clusterStatuses := make([]*v2pb.ClusterTargetStatus, 0, len(resource.Spec.ClusterTargets)) + var failures []string + + for _, target := range resource.Spec.ClusterTargets { + kubeClient, err := a.clientFactory.GetClient(ctx, target) + if err != nil { + failures = append(failures, fmt.Sprintf("%s: client error: %v", target.GetClusterId(), err)) + clusterStatuses = append(clusterStatuses, &v2pb.ClusterTargetStatus{ + ClusterId: target.GetClusterId(), + State: v2pb.INFERENCE_SERVER_STATE_CREATING, + Message: err.Error(), + }) + continue + } + + healthy, err := backend.IsHealthy(ctx, a.logger, kubeClient, resource.Name, resource.Namespace) + if err != nil { + a.logger.Error("Health check failed", + zap.Error(err), + zap.String("operation", "health_check"), + zap.String("namespace", resource.Namespace), + zap.String("inferenceServer", resource.Name), + zap.String("cluster_id", target.GetClusterId())) + failures = append(failures, fmt.Sprintf("%s: %v", target.GetClusterId(), err)) + clusterStatuses = append(clusterStatuses, &v2pb.ClusterTargetStatus{ + ClusterId: target.GetClusterId(), + State: v2pb.INFERENCE_SERVER_STATE_CREATING, + Message: err.Error(), + }) + continue + } + + state := v2pb.INFERENCE_SERVER_STATE_CREATING + if healthy { + state = v2pb.INFERENCE_SERVER_STATE_SERVING + } else { + failures = append(failures, fmt.Sprintf("%s: not healthy", target.GetClusterId())) + } + clusterStatuses = append(clusterStatuses, &v2pb.ClusterTargetStatus{ + ClusterId: target.GetClusterId(), + State: state, + }) + } + + return clusterStatuses, failures +} diff --git a/go/components/inferenceserver/plugins/oss/creation/model_config_provision.go b/go/components/inferenceserver/plugins/oss/creation/model_config_provision.go index ab75a55a9..87d450e6f 100644 --- 
a/go/components/inferenceserver/plugins/oss/creation/model_config_provision.go +++ b/go/components/inferenceserver/plugins/oss/creation/model_config_provision.go @@ -3,13 +3,15 @@ package creation import ( "context" "fmt" + "strings" "go.uber.org/zap" "sigs.k8s.io/controller-runtime/pkg/client" conditionInterfaces "github.com/michelangelo-ai/michelangelo/go/base/conditions/interfaces" conditionsutil "github.com/michelangelo-ai/michelangelo/go/base/conditions/utils" + "github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/clientfactory" modelconfig "github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/modelconfig" "github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/plugins/oss/common" apipb "github.com/michelangelo-ai/michelangelo/proto-go/api" @@ -20,14 +22,14 @@ var _ conditionInterfaces.ConditionActor[*v2pb.InferenceServer] = &ModelConfigPr // ModelConfigProvisionActor provisions model configuration for inference servers. type ModelConfigProvisionActor struct { - client client.Client + clientFactory clientfactory.ClientFactory modelConfigProvider modelconfig.ModelConfigProvider logger *zap.Logger } -func NewModelConfigProvisionActor(client client.Client, modelConfigProvider modelconfig.ModelConfigProvider, logger *zap.Logger) conditionInterfaces.ConditionActor[*v2pb.InferenceServer] { +func NewModelConfigProvisionActor(clientFactory clientfactory.ClientFactory, modelConfigProvider modelconfig.ModelConfigProvider, logger *zap.Logger) conditionInterfaces.ConditionActor[*v2pb.InferenceServer] { return &ModelConfigProvisionActor{ - client: client, + clientFactory: clientFactory, modelConfigProvider: modelConfigProvider, logger: logger, } @@ -40,14 +42,28 @@ func (a *ModelConfigProvisionActor) GetType() string { func (a *ModelConfigProvisionActor) Retrieve(ctx context.Context, resource *v2pb.InferenceServer, condition *apipb.Condition) (*apipb.Condition, error) { a.logger.Info("Retrieving model config provisioning 
condition") - exists, err := a.modelConfigProvider.CheckModelConfigExists(ctx, a.logger, a.client, resource.Name, resource.Namespace) - if err != nil { - a.logger.Error("Failed to check model config existence", zap.Error(err)) - return conditionsutil.GenerateFalseCondition(condition, "ModelConfigProvisionFailed", fmt.Sprintf("Failed to check model config existence: %v", err)), err + var failures []string + for _, target := range resource.Spec.ClusterTargets { + kubeClient, err := a.clientFactory.GetClient(ctx, target) + if err != nil { + failures = append(failures, fmt.Sprintf("%s: client error: %v", target.GetClusterId(), err)) + continue + } + + exists, err := a.modelConfigProvider.CheckModelConfigExists(ctx, a.logger, kubeClient, resource.Name, resource.Namespace) + if err != nil { + a.logger.Error("Failed to check model config existence", zap.Error(err), zap.String("cluster_id", target.GetClusterId())) + failures = append(failures, fmt.Sprintf("%s: %v", target.GetClusterId(), err)) + continue + } + + if !exists { + failures = append(failures, fmt.Sprintf("%s: model config not found", target.GetClusterId())) + } } - if !exists { - return conditionsutil.GenerateFalseCondition(condition, "ModelConfigNotFound", "Model config not found"), nil + if len(failures) > 0 { + return conditionsutil.GenerateFalseCondition(condition, "ModelConfigNotFound", strings.Join(failures, "; ")), nil } return conditionsutil.GenerateTrueCondition(condition), nil } @@ -55,9 +70,20 @@ func (a *ModelConfigProvisionActor) Retrieve(ctx context.Context, resource *v2pb func (a *ModelConfigProvisionActor) Run(ctx context.Context, resource *v2pb.InferenceServer, condition *apipb.Condition) (*apipb.Condition, error) { a.logger.Info("Running model config provisioning") - err := a.modelConfigProvider.CreateModelConfig(ctx, a.logger, a.client, resource.Name, resource.Namespace, map[string]string{}, map[string]string{}) - if err != nil { - return conditionsutil.GenerateFalseCondition(condition, 
"ModelConfigProvisionFailed", fmt.Sprintf("Failed to create config map: %v", err)), err + isDone := func(ctx context.Context, kubeClient client.Client, target *v2pb.ClusterTarget) (bool, error) { + return a.modelConfigProvider.CheckModelConfigExists(ctx, a.logger, kubeClient, resource.Name, resource.Namespace) } - return conditionsutil.GenerateTrueCondition(condition), nil + doWork := func(ctx context.Context, kubeClient client.Client, target *v2pb.ClusterTarget) error { + err := a.modelConfigProvider.CreateModelConfig(ctx, a.logger, kubeClient, resource.Name, resource.Namespace, map[string]string{}, map[string]string{}) + if err != nil { + a.logger.Error("Failed to create model config", + zap.Error(err), + zap.String("operation", "create_modelconfig"), + zap.String("namespace", resource.Namespace), + zap.String("inferenceServer", resource.Name), + zap.String("cluster_id", target.GetClusterId())) + } + return err + } + return common.RunRolling(ctx, a.clientFactory, resource.Spec.ClusterTargets, condition, isDone, doWork) } diff --git a/go/components/inferenceserver/plugins/oss/creation/validation.go b/go/components/inferenceserver/plugins/oss/creation/validation.go index 48d006774..a136d51a4 100644 --- a/go/components/inferenceserver/plugins/oss/creation/validation.go +++ b/go/components/inferenceserver/plugins/oss/creation/validation.go @@ -45,6 +45,16 @@ func (a *ValidationActor) Retrieve(ctx context.Context, resource *v2pb.Inference return conditionUtils.GenerateFalseCondition(condition, "InvalidBackendType", fmt.Sprintf("unsupported backend type: %v", resource.Spec.BackendType)), nil } + if len(resource.Spec.ClusterTargets) == 0 { + return conditionUtils.GenerateFalseCondition(condition, "NoClusterTargets", "spec.cluster_targets must declare at least one cluster"), nil + } + + // Validate cluster rollout strategy annotation before operational actors attempt multi-cluster iteration. 
+ if strategy := common.GetRolloutStrategy(resource); !common.IsKnownRolloutStrategy(strategy) { + return conditionUtils.GenerateFalseCondition(condition, "InvalidRolloutStrategy", + fmt.Sprintf("unknown cluster rollout strategy %q; supported: rolling", strategy)), nil + } + return conditionUtils.GenerateTrueCondition(condition), nil } diff --git a/go/components/inferenceserver/plugins/oss/deletion/BUILD.bazel b/go/components/inferenceserver/plugins/oss/deletion/BUILD.bazel index 57bdaa3cb..b976c5f9c 100644 --- a/go/components/inferenceserver/plugins/oss/deletion/BUILD.bazel +++ b/go/components/inferenceserver/plugins/oss/deletion/BUILD.bazel @@ -12,6 +12,7 @@ go_library( "//go/base/conditions/interfaces:go_default_library", "//go/base/conditions/utils:go_default_library", "//go/components/inferenceserver/backends:go_default_library", + "//go/components/inferenceserver/clientfactory:go_default_library", "//go/components/inferenceserver/modelconfig:go_default_library", "//go/components/inferenceserver/plugins/oss/common:go_default_library", "//proto/api:go_default_library", diff --git a/go/components/inferenceserver/plugins/oss/deletion/cleanup.go b/go/components/inferenceserver/plugins/oss/deletion/cleanup.go index b7a99025c..9bf60ee5b 100644 --- a/go/components/inferenceserver/plugins/oss/deletion/cleanup.go +++ b/go/components/inferenceserver/plugins/oss/deletion/cleanup.go @@ -3,14 +3,16 @@ package deletion import ( "context" "fmt" + "strings" "go.uber.org/zap" "sigs.k8s.io/controller-runtime/pkg/client" conditionInterfaces "github.com/michelangelo-ai/michelangelo/go/base/conditions/interfaces" conditionsUtil "github.com/michelangelo-ai/michelangelo/go/base/conditions/utils" "github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/backends" + "github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/clientfactory" "github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/modelconfig" 
"github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/plugins/oss/common" apipb "github.com/michelangelo-ai/michelangelo/proto-go/api" @@ -21,16 +22,16 @@ var _ conditionInterfaces.ConditionActor[*v2pb.InferenceServer] = &CleanupActor{ // CleanupActor removes all Kubernetes resources associated with an inference server. type CleanupActor struct { - client client.Client + clientFactory clientfactory.ClientFactory registry *backends.Registry modelConfigProvider modelconfig.ModelConfigProvider logger *zap.Logger } // NewCleanupActor creates a condition actor for inference server cleanup during deletion. -func NewCleanupActor(client client.Client, registry *backends.Registry, modelConfigProvider modelconfig.ModelConfigProvider, logger *zap.Logger) conditionInterfaces.ConditionActor[*v2pb.InferenceServer] { +func NewCleanupActor(clientFactory clientfactory.ClientFactory, registry *backends.Registry, modelConfigProvider modelconfig.ModelConfigProvider, logger *zap.Logger) conditionInterfaces.ConditionActor[*v2pb.InferenceServer] { return &CleanupActor{ - client: client, + clientFactory: clientFactory, registry: registry, modelConfigProvider: modelConfigProvider, logger: logger, @@ -52,12 +53,24 @@ func (a *CleanupActor) Retrieve(ctx context.Context, resource *v2pb.InferenceSer return conditionsUtil.GenerateTrueCondition(condition), nil } - // Check if inference server still exists - _, err = backend.GetServerStatus(ctx, a.logger, a.client, resource.Name, resource.Namespace) - if err == nil { - return conditionsUtil.GenerateFalseCondition(condition, "CleanupInProgress", "Inference server cleanup in progress"), nil + var failures []string + for _, target := range resource.Spec.ClusterTargets { + kubeClient, err := a.clientFactory.GetClient(ctx, target) + if err != nil { + failures = append(failures, fmt.Sprintf("%s: client error: %v", target.GetClusterId(), err)) + continue + } + + // Check if inference server still exists + _, err = 
backend.GetServerStatus(ctx, a.logger, kubeClient, resource.Name, resource.Namespace) + if err == nil { + failures = append(failures, fmt.Sprintf("%s: still present", target.GetClusterId())) + } } + if len(failures) > 0 { + return conditionsUtil.GenerateFalseCondition(condition, "CleanupInProgress", strings.Join(failures, "; ")), nil + } return conditionsUtil.GenerateTrueCondition(condition), nil } @@ -65,21 +78,6 @@ func (a *CleanupActor) Retrieve(ctx context.Context, resource *v2pb.InferenceSer func (a *CleanupActor) Run(ctx context.Context, resource *v2pb.InferenceServer, condition *apipb.Condition) (*apipb.Condition, error) { a.logger.Info("Running inference server cleanup with ConfigMap cleanup") - // Delete Model Config first - a.logger.Info("Cleaning up Model Config for inference server", zap.String("inferenceServer", resource.Name)) - - // Clean up model-config - if err := a.modelConfigProvider.DeleteModelConfig(ctx, a.logger, a.client, resource.Name, resource.Namespace); err != nil { - a.logger.Error("Failed to delete Model Config", - zap.Error(err), - zap.String("operation", "delete_modelconfig"), - zap.String("namespace", resource.Namespace), - zap.String("inferenceServer", resource.Name), - ) - } else { - a.logger.Info("Successfully deleted Model Config for inference server", zap.String("inferenceServer", resource.Name)) - } - // Get backend from registry backend, err := a.registry.GetBackend(resource.Spec.BackendType) if err != nil { @@ -87,18 +85,42 @@ func (a *CleanupActor) Run(ctx context.Context, resource *v2pb.InferenceServer, return conditionsUtil.GenerateTrueCondition(condition), nil } - // Delete inference server - a.logger.Info("Cleaning up inference server", zap.String("inferenceServer", resource.Name)) - err = backend.DeleteServer(ctx, a.logger, a.client, resource.Name, resource.Namespace) - if err != nil { - a.logger.Error("Failed to delete inference server", - zap.Error(err), - zap.String("operation", "delete_server"), + isDone := func(ctx 
context.Context, kubeClient client.Client, target *v2pb.ClusterTarget) (bool, error) { + _, err := backend.GetServerStatus(ctx, a.logger, kubeClient, resource.Name, resource.Namespace) + return err != nil, nil // resource gone (error) means done + } + doWork := func(ctx context.Context, kubeClient client.Client, target *v2pb.ClusterTarget) error { + // Delete Model Config first (preserving existing ordering) + a.logger.Info("Cleaning up Model Config for inference server", + zap.String("namespace", resource.Namespace), + zap.String("inferenceServer", resource.Name), + zap.String("cluster_id", target.GetClusterId())) + if mcErr := a.modelConfigProvider.DeleteModelConfig(ctx, a.logger, kubeClient, resource.Name, resource.Namespace); mcErr != nil { + a.logger.Error("Failed to delete Model Config", + zap.Error(mcErr), + zap.String("operation", "delete_modelconfig"), + zap.String("namespace", resource.Namespace), + zap.String("inferenceServer", resource.Name), + zap.String("cluster_id", target.GetClusterId()), + ) + } else { + a.logger.Info("Successfully deleted Model Config for inference server", + zap.String("namespace", resource.Namespace), + zap.String("inferenceServer", resource.Name), + zap.String("cluster_id", target.GetClusterId())) + } + a.logger.Info("Cleaning up inference server", + zap.String("namespace", resource.Namespace), + zap.String("inferenceServer", resource.Name), + zap.String("cluster_id", target.GetClusterId())) + if err := backend.DeleteServer(ctx, a.logger, kubeClient, resource.Name, resource.Namespace); err != nil { + return err + } + a.logger.Info("Inference server cleanup completed successfully", zap.String("namespace", resource.Namespace), - zap.String("inferenceServer", resource.Name)) - return conditionsUtil.GenerateFalseCondition(condition, "ServerCleanupFailed", fmt.Sprintf("Failed to cleanup inference server: %v", err)), fmt.Errorf("delete inference server %s/%s: %w", resource.Namespace, resource.Name, err) + 
zap.String("inferenceServer", resource.Name), + zap.String("cluster_id", target.GetClusterId())) + return nil } - - a.logger.Info("Inference server cleanup completed successfully", zap.String("inferenceServer", resource.Name)) - return conditionsUtil.GenerateTrueCondition(condition), nil + return common.RunRolling(ctx, a.clientFactory, resource.Spec.ClusterTargets, condition, isDone, doWork) } diff --git a/go/components/inferenceserver/plugins/oss/deletion/condition_plugin.go b/go/components/inferenceserver/plugins/oss/deletion/condition_plugin.go index 3d639b6f2..f29c829bf 100644 --- a/go/components/inferenceserver/plugins/oss/deletion/condition_plugin.go +++ b/go/components/inferenceserver/plugins/oss/deletion/condition_plugin.go @@ -3,10 +3,9 @@ package deletion import ( "go.uber.org/zap" - "sigs.k8s.io/controller-runtime/pkg/client" - conditionInterfaces "github.com/michelangelo-ai/michelangelo/go/base/conditions/interfaces" "github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/backends" + "github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/clientfactory" "github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/modelconfig" apipb "github.com/michelangelo-ai/michelangelo/proto-go/api" v2pb "github.com/michelangelo-ai/michelangelo/proto-go/api/v2" @@ -14,16 +13,16 @@ import ( // DeletionPlugin orchestrates the condition actors for inference server deletion. type DeletionPlugin struct { - client client.Client + clientFactory clientfactory.ClientFactory registry *backends.Registry modelConfigProvider modelconfig.ModelConfigProvider logger *zap.Logger } // NewDeletionPlugin creates a plugin that manages cleanup of all inference server resources. 
-func NewDeletionPlugin(client client.Client, registry *backends.Registry, modelConfigProvider modelconfig.ModelConfigProvider, logger *zap.Logger) conditionInterfaces.Plugin[*v2pb.InferenceServer] { +func NewDeletionPlugin(clientFactory clientfactory.ClientFactory, registry *backends.Registry, modelConfigProvider modelconfig.ModelConfigProvider, logger *zap.Logger) conditionInterfaces.Plugin[*v2pb.InferenceServer] { return &DeletionPlugin{ - client: client, + clientFactory: clientFactory, registry: registry, modelConfigProvider: modelConfigProvider, logger: logger, @@ -33,7 +32,7 @@ func NewDeletionPlugin(client client.Client, registry *backends.Registry, modelC // GetActors returns the condition actors for deletion workflow. func (p *DeletionPlugin) GetActors() []conditionInterfaces.ConditionActor[*v2pb.InferenceServer] { return []conditionInterfaces.ConditionActor[*v2pb.InferenceServer]{ - NewCleanupActor(p.client, p.registry, p.modelConfigProvider, p.logger), + NewCleanupActor(p.clientFactory, p.registry, p.modelConfigProvider, p.logger), } } diff --git a/go/components/inferenceserver/plugins/oss/plugin.go b/go/components/inferenceserver/plugins/oss/plugin.go index 3058a63d8..4840b9296 100644 --- a/go/components/inferenceserver/plugins/oss/plugin.go +++ b/go/components/inferenceserver/plugins/oss/plugin.go @@ -8,10 +8,9 @@ import ( corev1 "k8s.io/api/core/v1" - "sigs.k8s.io/controller-runtime/pkg/client" - conditionInterfaces "github.com/michelangelo-ai/michelangelo/go/base/conditions/interfaces" "github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/backends" + "github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/clientfactory" modelconfig "github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/modelconfig" "github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/plugins" "github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/plugins/oss/creation" @@ -28,22 +27,22 @@ type Plugin 
struct { creationPlugin conditionInterfaces.Plugin[*v2pb.InferenceServer] deletionPlugin conditionInterfaces.Plugin[*v2pb.InferenceServer] - registry *backends.Registry - client client.Client - Recorder record.EventRecorder - logger *zap.Logger + registry *backends.Registry + clientFactory clientfactory.ClientFactory + Recorder record.EventRecorder + logger *zap.Logger } // NewPlugin creates a plugin with creation and deletion workflows. -func NewOSSPlugin(client client.Client, registry *backends.Registry, modelConfigProvider modelconfig.ModelConfigProvider, recorder record.EventRecorder, logger *zap.Logger) plugins.Plugin { +func NewOSSPlugin(clientFactory clientfactory.ClientFactory, registry *backends.Registry, modelConfigProvider modelconfig.ModelConfigProvider, recorder record.EventRecorder, logger *zap.Logger) plugins.Plugin { return &Plugin{ - creationPlugin: creation.NewCreationPlugin(client, registry, modelConfigProvider, logger), - deletionPlugin: deletion.NewDeletionPlugin(client, registry, modelConfigProvider, logger), + creationPlugin: creation.NewCreationPlugin(clientFactory, registry, modelConfigProvider, logger), + deletionPlugin: deletion.NewDeletionPlugin(clientFactory, registry, modelConfigProvider, logger), - client: client, - registry: registry, - Recorder: recorder, - logger: logger, + clientFactory: clientFactory, + registry: registry, + Recorder: recorder, + logger: logger, } } @@ -122,28 +121,23 @@ func (p *Plugin) UpdateDetails(ctx context.Context, resource *v2pb.InferenceServ return nil } - // Get current status from backend - status, err := backend.GetServerStatus(ctx, p.logger, p.client, resource.Name, resource.Namespace) - if err != nil { - // Don't fail reconciliation for status check errors - p.logger.Error("Failed to get server status", - zap.Error(err), - zap.String("operation", "get_server_status"), - zap.String("namespace", resource.Namespace), - zap.String("inferenceServer", resource.Name)) + // Get current status from backend, 
aggregated across cluster targets + aggregateState, ok := p.aggregateBackendState(ctx, backend, resource) + if !ok { + // No conclusive aggregate this reconcile — keep the existing state. return nil } // Update status based on external state - if status.State != resource.Status.State { + if aggregateState != resource.Status.State { p.logger.Info("External state change detected", zap.String("currentState", resource.Status.State.String()), - zap.String("externalState", status.State.String())) + zap.String("externalState", aggregateState.String())) - resource.Status.State = status.State + resource.Status.State = aggregateState // Record state transition events - switch status.State { + switch aggregateState { case v2pb.INFERENCE_SERVER_STATE_SERVING: p.Recorder.Event(resource, corev1.EventTypeNormal, "CreationCompleted", "InferenceServer creation completed successfully") case v2pb.INFERENCE_SERVER_STATE_FAILED: @@ -153,6 +147,56 @@ func (p *Plugin) UpdateDetails(ctx context.Context, resource *v2pb.InferenceServ return nil } +// aggregateBackendState polls the backend on every target cluster and reduces the +// per-cluster states into a single InferenceServerState. The boolean return is false +// when no conclusive aggregate exists (e.g., all status fetches errored). +// +// FAILED on any cluster wins; otherwise SERVING requires every reachable cluster to +// be SERVING. 
+func (p *Plugin) aggregateBackendState(ctx context.Context, backend backends.Backend, resource *v2pb.InferenceServer) (v2pb.InferenceServerState, bool) { + servingCount := 0 + totalCount := 0 + hasFailure := false + + for _, target := range resource.Spec.ClusterTargets { + kubeClient, err := p.clientFactory.GetClient(ctx, target) + if err != nil { + p.logger.Error("Failed to resolve client", + zap.Error(err), + zap.String("operation", "resolve_client"), + zap.String("inferenceServer", resource.Name), + zap.String("cluster_id", target.GetClusterId())) + continue + } + totalCount++ + status, err := backend.GetServerStatus(ctx, p.logger, kubeClient, resource.Name, resource.Namespace) + if err != nil { + // Don't fail reconciliation for status check errors + p.logger.Error("Failed to get server status", + zap.Error(err), + zap.String("operation", "get_server_status"), + zap.String("namespace", resource.Namespace), + zap.String("inferenceServer", resource.Name), + zap.String("cluster_id", target.GetClusterId())) + continue + } + switch status.State { + case v2pb.INFERENCE_SERVER_STATE_SERVING: + servingCount++ + case v2pb.INFERENCE_SERVER_STATE_FAILED: + hasFailure = true + } + } + + if hasFailure { + return v2pb.INFERENCE_SERVER_STATE_FAILED, true + } + if totalCount > 0 && servingCount == totalCount { + return v2pb.INFERENCE_SERVER_STATE_SERVING, true + } + return v2pb.INFERENCE_SERVER_STATE_INVALID, false +} + // UpdateConditions filters the resource conditions to only those relevant to the current plugin workflow. 
func (p *Plugin) UpdateConditions(resource *v2pb.InferenceServer, conditionPlugin conditionInterfaces.Plugin[*v2pb.InferenceServer]) { actors := conditionPlugin.GetActors() diff --git a/proto/api/v2/inference_server.proto b/proto/api/v2/inference_server.proto index 6b4826e60..0fc4eb474 100644 --- a/proto/api/v2/inference_server.proto +++ b/proto/api/v2/inference_server.proto @@ -9,6 +9,7 @@ option go_package = "v2"; import "k8s.io/apimachinery/pkg/apis/meta/v1/generated.proto"; import "michelangelo/api/options.proto"; import "michelangelo/api/conditions.proto"; +import "michelangelo/api/v2/kubernetes.proto"; import "michelangelo/api/v2/pod.proto"; import "michelangelo/api/v2/project.proto"; import "michelangelo/api/v2/user.proto"; @@ -79,6 +80,21 @@ enum BackendType { BACKEND_TYPE_TORCHSERVE = 4; } +// ClusterTarget identifies a Kubernetes cluster where the InferenceServer's +// workloads should be provisioned. +message ClusterTarget { + // Stable, human-readable identifier for the cluster (e.g. "k3d-test-a", + // "gke-prod-us-west-1"). Must be unique across the InferenceServer's + // cluster_targets list. + string cluster_id = 1; + + // How to connect to the cluster. + oneof connection { + option (michelangelo.api.required) = true; + ConnectionSpec kubernetes = 2; + } +} + // Inference Server spec. message InferenceServerSpec { // Tenancy type of this InferenceServer. @@ -94,6 +110,10 @@ message InferenceServerSpec { UserInfo owner = 5; // Backend type for this InferenceServer (e.g. Triton, LLM-D, Dynamo). BackendType backend_type = 6; + // Clusters where this InferenceServer should be provisioned. One workload + // set per ClusterTarget. + // At least one ClusterTarget is required. + repeated ClusterTarget cluster_targets = 7; } // The state of the Inference Server. @@ -109,6 +129,22 @@ enum InferenceServerState { INFERENCE_SERVER_STATE_DELETED = 8; } +// ClusterTargetStatus reports the observed state of an InferenceServer on +// a single ClusterTarget. 
+message ClusterTargetStatus { + // Matches ClusterTarget.cluster_id in the spec. + string cluster_id = 1; + + // Per-cluster lifecycle state. + InferenceServerState state = 2; + + // Per-cluster conditions. + repeated Condition conditions = 3; + + // Free-form message for human-readable detail (e.g., last error). + string message = 4; +} + // The status of the InferenceServer. message InferenceServerStatus { // The uuid of the InferenceServer. @@ -132,6 +168,8 @@ message InferenceServerStatus { // Available environments correspond to the list of environments that are available for // this InferenceServer resource. repeated string available_environments = 8; + // Per-cluster status, one entry per ClusterTarget in spec.cluster_targets. + repeated ClusterTargetStatus cluster_statuses = 9; } // Defines a managed InferenceServer service for model serving. diff --git a/python/michelangelo/cli/sandbox/demo/inference/inferenceserver.yaml b/python/michelangelo/cli/sandbox/demo/inference/inferenceserver.yaml index 8d44f6974..bb3f5fadf 100644 --- a/python/michelangelo/cli/sandbox/demo/inference/inferenceserver.yaml +++ b/python/michelangelo/cli/sandbox/demo/inference/inferenceserver.yaml @@ -21,3 +21,10 @@ spec: decommission: false owner: name: "user-example" + clusterTargets: + - clusterId: michelangelo-sandbox + kubernetes: + host: https://kubernetes.default.svc + port: "443" + tokenTag: cluster-michelangelo-sandbox-is-token + caDataTag: cluster-michelangelo-sandbox-ca-data diff --git a/python/michelangelo/cli/sandbox/resources/rbac-inferenceserver.yaml b/python/michelangelo/cli/sandbox/resources/rbac-inferenceserver.yaml new file mode 100644 index 000000000..4eec63115 --- /dev/null +++ b/python/michelangelo/cli/sandbox/resources/rbac-inferenceserver.yaml @@ -0,0 +1,30 @@ +apiVersion: v1 +kind: ServiceAccount +metadata: + name: inference-server-manager + namespace: default +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRole +metadata: + name: 
def _setup_inference_server_secrets():
    """Create RBAC and credentials for inference server cluster access.

    Applies an inference-server-manager ServiceAccount with permissions to
    manage Deployments, Services, and ConfigMaps (required for Triton
    provisioning). Stores a long-lived bearer token as a Secret so the
    clientfactory can build a remote kube client for the sandbox cluster
    using kubernetes.default.svc:443.

    The CA secret (cluster-michelangelo-sandbox-ca-data) is already created by
    the sandbox create flow; we only need to provision the token here.

    Idempotent: if the token secret already exists this function is a no-op.
    """
    import os  # local import: only needed for the temp-file cleanup below

    cluster_name = _michelangelo_sandbox_kube_cluster_name
    token_secret_name = f"cluster-{cluster_name}-is-token"

    # Check if the token secret already exists to make this idempotent.
    exists = (
        subprocess.run(
            ["kubectl", "get", "secret", token_secret_name],
            capture_output=True,
        ).returncode
        == 0
    )
    if exists:
        print(
            f"Secret '{token_secret_name}' already exists — "
            "skipping inference server credential setup."
        )
        return

    # Apply ServiceAccount + ClusterRole + ClusterRoleBinding.
    _kube_apply(_dir / "resources" / "rbac-inferenceserver.yaml")

    # Mint a long-lived token (same duration as ray-manager) so the sandbox
    # does not require frequent re-creation.
    token_decoded = (
        subprocess.check_output(
            [
                "kubectl",
                "create",
                "token",
                "inference-server-manager",
                "-n",
                "default",
                "--duration=87600h",
            ]
        )
        .decode()
        .strip()
    )

    token_secret = {
        "apiVersion": "v1",
        "kind": "Secret",
        "metadata": {"name": token_secret_name, "namespace": "default"},
        "stringData": {"token": token_decoded},
    }

    # BUG FIX: the original left the NamedTemporaryFile(delete=False) on disk
    # forever, persisting the long-lived bearer token in the temp directory,
    # and ran `kubectl apply` while the handle was still open. Close the file
    # first, apply it, and always unlink it afterwards.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
        yaml.dump(token_secret, f)
        tmp_path = f.name
    try:
        _exec("kubectl", "apply", "-f", tmp_path)
    finally:
        os.unlink(tmp_path)

    print(f"Created inference server credentials for cluster '{cluster_name}'")
+ _setup_inference_server_secrets() + inference_demo_dir = _dir / "demo" / "inference" # Create inference server CR inference_server_path = inference_demo_dir / "inferenceserver.yaml" From d71bed1d9f22665e89a6305853144160423c3465 Mon Sep 17 00:00:00 2001 From: Aritra Date: Wed, 29 Apr 2026 20:25:05 -0700 Subject: [PATCH 2/2] feat(inferenceserver): make services in target clusters discoverable within control plane cluster --- go/base/config/config.go | 25 ++ go/cmd/controllermgr/config/base.yaml | 12 +- go/components/inferenceserver/BUILD.bazel | 2 + .../inferenceserver/endpoints/BUILD.bazel | 46 +++ .../inferenceserver/endpoints/interface.go | 46 +++ .../inferenceserver/endpoints/module.go | 26 ++ .../inferenceserver/endpoints/publisher.go | 266 ++++++++++++++++++ .../endpoints/publisher_test.go | 207 ++++++++++++++ .../endpoints/source/BUILD.bazel | 22 ++ .../inferenceserver/endpoints/source/k3d.go | 108 +++++++ .../endpoints/source/module.go | 20 ++ go/components/inferenceserver/module.go | 4 + .../inferenceserver/plugins/oss/BUILD.bazel | 1 + .../plugins/oss/common/constants.go | 1 + .../plugins/oss/creation/BUILD.bazel | 2 + .../plugins/oss/creation/condition_plugin.go | 8 +- .../plugins/oss/creation/endpoint_publish.go | 143 ++++++++++ .../inferenceserver/plugins/oss/plugin.go | 5 +- .../sandbox/resources/gateway-api-setup.yaml | 2 +- .../resources/michelangelo-controllermgr.yaml | 10 + .../michelangelo-temporal-controllermgr.yaml | 10 + .../resources/rbac-inferenceserver.yaml | 6 + 22 files changed, 967 insertions(+), 5 deletions(-) create mode 100644 go/components/inferenceserver/endpoints/BUILD.bazel create mode 100644 go/components/inferenceserver/endpoints/interface.go create mode 100644 go/components/inferenceserver/endpoints/module.go create mode 100644 go/components/inferenceserver/endpoints/publisher.go create mode 100644 go/components/inferenceserver/endpoints/publisher_test.go create mode 100644 
// InferenceServerConfig is the controller-side configuration for the inference
// server controller. Populated from the "inferenceServer" key of the typed
// config provider (see GetInferenceServerConfig).
type InferenceServerConfig struct {
	Gateway GatewayConfig `yaml:"gateway"`
}

// GatewayConfig describes how the inference server controller locates a
// cluster's ingress Gateway Service. The Service is set up out-of-band (e.g.,
// by `ma sandbox create` for sandbox); this config tells the EndpointSource
// where to find it.
type GatewayConfig struct {
	ServiceName      string `yaml:"serviceName"`      // Gateway Service name, e.g. "ma-gateway-istio"
	ServiceNamespace string `yaml:"serviceNamespace"` // namespace holding that Service
	PortName         string `yaml:"portName"`         // named port on the Service to read, e.g. "http"
}

// GetInferenceServerConfig parses the configuration file and returns the
// inference server controller configuration. Returns the zero value plus the
// Populate error when the "inferenceServer" key is missing or malformed.
func GetInferenceServerConfig(provider config.Provider) (InferenceServerConfig, error) {
	inferenceServerConfig := InferenceServerConfig{}
	err := provider.Get(_inferenceServerConfigKey).Populate(&inferenceServerConfig)
	return inferenceServerConfig, err
}
b/go/components/inferenceserver/endpoints/BUILD.bazel @@ -0,0 +1,46 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "go_default_library", + srcs = [ + "interface.go", + "module.go", + "publisher.go", + ], + importpath = "github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/endpoints", + visibility = ["//visibility:public"], + deps = [ + "//go/base/config:go_default_library", + "//proto/api/v2:go_default_library", + "@io_k8s_api//core/v1:go_default_library", + "@io_k8s_api//discovery/v1:go_default_library", + "@io_k8s_apimachinery//pkg/api/errors:go_default_library", + "@io_k8s_apimachinery//pkg/apis/meta/v1:go_default_library", + "@io_k8s_apimachinery//pkg/runtime:go_default_library", + "@io_k8s_apimachinery//pkg/types:go_default_library", + "@io_k8s_apimachinery//pkg/util/intstr:go_default_library", + "@io_k8s_sigs_controller_runtime//pkg/client:go_default_library", + "@io_k8s_sigs_controller_runtime//pkg/controller/controllerutil:go_default_library", + "@org_uber_go_config//:go_default_library", + "@org_uber_go_fx//:go_default_library", + ], +) + +go_test( + name = "go_default_test", + srcs = ["publisher_test.go"], + embed = [":go_default_library"], + deps = [ + "//proto/api/v2:go_default_library", + "@com_github_stretchr_testify//assert:go_default_library", + "@com_github_stretchr_testify//require:go_default_library", + "@io_k8s_api//core/v1:go_default_library", + "@io_k8s_api//discovery/v1:go_default_library", + "@io_k8s_apimachinery//pkg/api/errors:go_default_library", + "@io_k8s_apimachinery//pkg/apis/meta/v1:go_default_library", + "@io_k8s_apimachinery//pkg/runtime:go_default_library", + "@io_k8s_apimachinery//pkg/types:go_default_library", + "@io_k8s_sigs_controller_runtime//pkg/client:go_default_library", + "@io_k8s_sigs_controller_runtime//pkg/client/fake:go_default_library", + ], +) diff --git a/go/components/inferenceserver/endpoints/interface.go 
// Package endpoints defines the abstractions the InferenceServer controller
// uses to publish per-cluster service-discovery information about an
// InferenceServer. EndpointSource resolves the network address at which one
// cluster admits traffic for an InferenceServer. EndpointPublisher maintains
// a per-server cluster ID to Endpoint map readable by other components.
package endpoints

import (
	"context"

	v2pb "github.com/michelangelo-ai/michelangelo/proto-go/api/v2"
)

// Endpoint is the network address (host, port, scheme) at which one cluster
// admits traffic for an InferenceServer.
type Endpoint struct {
	Host   string // IP or DNS name of the cluster's ingress
	Port   int32  // TCP port at that host
	Scheme string // "http" | "https"
}

// EndpointSource resolves the ingress endpoint for one ClusterTarget.
// Implementations abstract away the per-environment differences in how a
// cluster's ingress address is discovered. Implementations bind via fx;
// callers do not branch on environment.
type EndpointSource interface {
	Resolve(ctx context.Context, target *v2pb.ClusterTarget) (Endpoint, error)
}

// EndpointPublisher maintains the published cluster ID to Endpoint map for one
// InferenceServer. The interface defines the contract (Sync the desired map,
// Get it back, Delete it) and is agnostic about how the map is stored and how
// other components observe it.
type EndpointPublisher interface {
	// Sync reconciles the published map to match `endpoints`. Idempotent.
	// Cluster IDs in `endpoints` are upserted. Cluster IDs previously published
	// but absent from `endpoints` are removed.
	Sync(ctx context.Context, server *v2pb.InferenceServer, endpoints map[string]Endpoint) error

	// Get returns the currently published cluster ID to Endpoint map for the
	// server. The map is empty when nothing has been published yet.
	Get(ctx context.Context, server *v2pb.InferenceServer) (map[string]Endpoint, error)

	// Delete removes everything the publisher has created for the server.
	Delete(ctx context.Context, server *v2pb.InferenceServer) error
}
+var Module = fx.Options( + fx.Provide(newDefaultPublisher), + fx.Provide(newInferenceServerConfig), +) + +func newDefaultPublisher(kubeClient client.Client) EndpointPublisher { + return NewDefaultPublisher(kubeClient, kubeClient.Scheme()) +} + +func newInferenceServerConfig(provider config.Provider) (maconfig.InferenceServerConfig, error) { + return maconfig.GetInferenceServerConfig(provider) +} diff --git a/go/components/inferenceserver/endpoints/publisher.go b/go/components/inferenceserver/endpoints/publisher.go new file mode 100644 index 000000000..98c82a5f8 --- /dev/null +++ b/go/components/inferenceserver/endpoints/publisher.go @@ -0,0 +1,266 @@ +package endpoints + +import ( + "context" + "fmt" + + corev1 "k8s.io/api/core/v1" + discoveryv1 "k8s.io/api/discovery/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/types" + "k8s.io/apimachinery/pkg/util/intstr" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" + + v2pb "github.com/michelangelo-ai/michelangelo/proto-go/api/v2" +) + +// Label keys on the published Service and EndpointSlices. +const ( + // kubeServiceNameLabel is the well-known kubernetes label that links an + // EndpointSlice to its parent Service. Gateway controllers (Istio, Envoy + // Gateway, GKE Gateway) read EndpointSlices keyed on this label when + // resolving the Service's backends. + kubeServiceNameLabel = "kubernetes.io/service-name" + + // clusterIDLabel records which ClusterTarget this EndpointSlice represents. + // Used by the publisher to find slices for orphan deletion in Sync. + clusterIDLabel = "michelangelo.ai/cluster-id" + + // portName is the named port shared between the parent Service and each + // EndpointSlice. 
The Gateway controller resolves the Service's port number + // to its name, then finds the matching name on the EndpointSlice and uses + // the EndpointSlice's port (the actual cluster Gateway NodePort or + // LoadBalancer port) as the upstream destination. The shared name is the + // join key. + portName = "http" + + // servicePort is the logical port published on the parent Service. It is + // not the upstream destination, which comes from the EndpointSlice port. + servicePort int32 = 80 + + // endpointsServiceSuffix is appended to the InferenceServer name to form + // the per-server Service name in the control plane. + endpointsServiceSuffix = "-endpoints" +) + +var _ EndpointPublisher = &defaultPublisher{} + +// defaultPublisher implements EndpointPublisher by writing into the local +// Kubernetes API. The published surface for a server is one ClusterIP Service +// named "{is-name}-endpoints" (no selector) plus one EndpointSlice per +// cluster, named "{is-name}-endpoints-{cluster-id}" and labeled with the +// parent service name. +// +// Gateway API implementations (Istio, Envoy Gateway, GKE Gateway) resolve a +// Service backend by reading its EndpointSlices, so consumers can reference +// "{is-name}-endpoints" as a single Service even though traffic actually fans +// out across cluster gateways. The Service must be ClusterIP rather than +// headless: headless Services do not work as backends in Gateway API +// implementations that take the ClusterIP resolution path. +type defaultPublisher struct { + kubeClient client.Client + scheme *runtime.Scheme +} + +// NewDefaultPublisher returns an EndpointPublisher that targets the local +// (control-plane) cluster via the supplied client. The scheme is needed to +// stamp each published object with a Kubernetes owner reference back to the +// InferenceServer, so the kube garbage collector wipes the Service and +// EndpointSlices automatically when the InferenceServer is deleted. 
+func NewDefaultPublisher(kubeClient client.Client, scheme *runtime.Scheme) EndpointPublisher { + return &defaultPublisher{kubeClient: kubeClient, scheme: scheme} +} + +// Sync makes the control-plane Service and per-cluster EndpointSlices match +// `endpoints`. Idempotent. +func (p *defaultPublisher) Sync(ctx context.Context, server *v2pb.InferenceServer, endpoints map[string]Endpoint) error { + if err := p.ensureService(ctx, server); err != nil { + return fmt.Errorf("ensure service: %w", err) + } + for clusterID, ep := range endpoints { + if err := p.upsertSlice(ctx, server, clusterID, ep); err != nil { + return fmt.Errorf("upsert endpoint slice %q: %w", clusterID, err) + } + } + // Delete orphan EndpointSlices: ones whose cluster_id label is no longer in + // the desired endpoints map (cluster removed from spec). + existing, err := p.listSlices(ctx, server) + if err != nil { + return fmt.Errorf("list existing slices: %w", err) + } + for _, slice := range existing.Items { + clusterID := slice.Labels[clusterIDLabel] + if _, keep := endpoints[clusterID]; keep { + continue + } + if err := p.kubeClient.Delete(ctx, &slice); err != nil && !apierrors.IsNotFound(err) { + return fmt.Errorf("delete orphan slice %q: %w", slice.Name, err) + } + } + return nil +} + +// Get returns the currently published cluster ID to Endpoint map for `server`. +// Empty when nothing is published yet. 
+func (p *defaultPublisher) Get(ctx context.Context, server *v2pb.InferenceServer) (map[string]Endpoint, error) { + slices, err := p.listSlices(ctx, server) + if err != nil { + return nil, fmt.Errorf("list slices: %w", err) + } + out := make(map[string]Endpoint, len(slices.Items)) + for _, slice := range slices.Items { + clusterID := slice.Labels[clusterIDLabel] + if clusterID == "" || len(slice.Endpoints) == 0 || len(slice.Ports) == 0 { + continue + } + ep := slice.Endpoints[0] + port := slice.Ports[0] + if len(ep.Addresses) == 0 || port.Port == nil { + continue + } + out[clusterID] = Endpoint{ + Host: ep.Addresses[0], + Port: *port.Port, + Scheme: "http", + } + } + return out, nil +} + +// Delete removes the per-server Service and every EndpointSlice for `server`. +func (p *defaultPublisher) Delete(ctx context.Context, server *v2pb.InferenceServer) error { + svc := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: serviceName(server), + Namespace: server.Namespace, + }, + } + if err := p.kubeClient.Delete(ctx, svc); err != nil && !apierrors.IsNotFound(err) { + return fmt.Errorf("delete service: %w", err) + } + slices, err := p.listSlices(ctx, server) + if err != nil { + return fmt.Errorf("list slices: %w", err) + } + for _, slice := range slices.Items { + if err := p.kubeClient.Delete(ctx, &slice); err != nil && !apierrors.IsNotFound(err) { + return fmt.Errorf("delete slice %q: %w", slice.Name, err) + } + } + return nil +} + +// ensureService creates the per-server discovery Service if it does not already exist in the control plane. +// The Service has no selector because its EndpointSlices are populated +// explicitly by upsertSlice. ClusterIP (rather than headless) is required so +// the Service works as an HTTPRoute backend in Gateway implementations that +// resolve via the Service ClusterIP. 
+func (p *defaultPublisher) ensureService(ctx context.Context, server *v2pb.InferenceServer) error { + key := types.NamespacedName{Name: serviceName(server), Namespace: server.Namespace} + existing := &corev1.Service{} + err := p.kubeClient.Get(ctx, key, existing) + if err == nil { + return nil + } + if !apierrors.IsNotFound(err) { + return fmt.Errorf("get service: %w", err) + } + svc := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: key.Name, + Namespace: key.Namespace, + }, + Spec: corev1.ServiceSpec{ + Type: corev1.ServiceTypeClusterIP, + Ports: []corev1.ServicePort{{ + Name: portName, + Port: servicePort, + Protocol: corev1.ProtocolTCP, + TargetPort: intstr.FromString(portName), + }}, + }, + } + if refErr := controllerutil.SetControllerReference(server, svc, p.scheme); refErr != nil { + return fmt.Errorf("set owner reference on service: %w", refErr) + } + if createErr := p.kubeClient.Create(ctx, svc); createErr != nil && !apierrors.IsAlreadyExists(createErr) { + return fmt.Errorf("create service: %w", createErr) + } + return nil +} + +// upsertSlice creates or updates the EndpointSlice for one cluster. The slice +// is named "{is-name}-endpoints-{cluster-id}" and labeled with the parent +// service name (consumed by Gateway controllers via EDS) and the cluster ID +// (consumed by the publisher itself for orphan detection). 
+func (p *defaultPublisher) upsertSlice(ctx context.Context, server *v2pb.InferenceServer, clusterID string, ep Endpoint) error { + name := sliceName(server, clusterID) + key := types.NamespacedName{Name: name, Namespace: server.Namespace} + port := ep.Port + desired := &discoveryv1.EndpointSlice{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: server.Namespace, + Labels: map[string]string{ + kubeServiceNameLabel: serviceName(server), + clusterIDLabel: clusterID, + }, + }, + AddressType: discoveryv1.AddressTypeIPv4, + Endpoints: []discoveryv1.Endpoint{{ + Addresses: []string{ep.Host}, + }}, + Ports: []discoveryv1.EndpointPort{{ + Name: ptr(portName), + Port: &port, + Protocol: ptr(corev1.ProtocolTCP), + }}, + } + if err := controllerutil.SetControllerReference(server, desired, p.scheme); err != nil { + return fmt.Errorf("set owner reference on slice: %w", err) + } + existing := &discoveryv1.EndpointSlice{} + err := p.kubeClient.Get(ctx, key, existing) + if apierrors.IsNotFound(err) { + if createErr := p.kubeClient.Create(ctx, desired); createErr != nil && !apierrors.IsAlreadyExists(createErr) { + return fmt.Errorf("create slice: %w", createErr) + } + return nil + } + if err != nil { + return fmt.Errorf("get slice: %w", err) + } + existing.Labels = desired.Labels + existing.AddressType = desired.AddressType + existing.Endpoints = desired.Endpoints + existing.Ports = desired.Ports + if err := p.kubeClient.Update(ctx, existing); err != nil { + return fmt.Errorf("update slice: %w", err) + } + return nil +} + +// listSlices returns every EndpointSlice in the server's namespace whose +// kubernetes.io/service-name label matches the published Service. Used by +// Sync (orphan detection) and Get (drift check). 
+func (p *defaultPublisher) listSlices(ctx context.Context, server *v2pb.InferenceServer) (*discoveryv1.EndpointSliceList, error) { + out := &discoveryv1.EndpointSliceList{} + err := p.kubeClient.List(ctx, out, + client.InNamespace(server.Namespace), + client.MatchingLabels{kubeServiceNameLabel: serviceName(server)}, + ) + return out, err +} + +func serviceName(server *v2pb.InferenceServer) string { + return server.Name + endpointsServiceSuffix +} + +func sliceName(server *v2pb.InferenceServer, clusterID string) string { + return server.Name + endpointsServiceSuffix + "-" + clusterID +} + +func ptr[T any](v T) *T { return &v } diff --git a/go/components/inferenceserver/endpoints/publisher_test.go b/go/components/inferenceserver/endpoints/publisher_test.go new file mode 100644 index 000000000..a9072d214 --- /dev/null +++ b/go/components/inferenceserver/endpoints/publisher_test.go @@ -0,0 +1,207 @@ +package endpoints + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + corev1 "k8s.io/api/core/v1" + discoveryv1 "k8s.io/api/discovery/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/fake" + + v2pb "github.com/michelangelo-ai/michelangelo/proto-go/api/v2" +) + +const ( + testNamespace = "default" + testServer = "test-is" +) + +func newFixture(t *testing.T, existing ...client.Object) (EndpointPublisher, client.Client, *v2pb.InferenceServer) { + t.Helper() + scheme := runtime.NewScheme() + require.NoError(t, corev1.AddToScheme(scheme)) + require.NoError(t, discoveryv1.AddToScheme(scheme)) + require.NoError(t, v2pb.AddToScheme(scheme)) + + server := &v2pb.InferenceServer{ + ObjectMeta: metav1.ObjectMeta{ + Name: testServer, + Namespace: testNamespace, + UID: "test-uid", + }, + } + objects := 
append([]client.Object{server}, existing...) + c := fake.NewClientBuilder().WithScheme(scheme).WithObjects(objects...).Build() + return NewDefaultPublisher(c, scheme), c, server +} + +// TestSync_CreatesServiceAndSlices covers the cold-start path: no existing +// objects, Sync should create the parent Service plus one EndpointSlice per +// cluster ID with the port-name join wired up correctly. +func TestSync_CreatesServiceAndSlices(t *testing.T) { + pub, c, server := newFixture(t) + endpoints := map[string]Endpoint{ + "clusterA": {Host: "10.0.0.1", Port: 31001, Scheme: "http"}, + "clusterB": {Host: "10.0.0.2", Port: 31002, Scheme: "http"}, + } + + require.NoError(t, pub.Sync(context.Background(), server, endpoints)) + + svc := &corev1.Service{} + require.NoError(t, c.Get(context.Background(), types.NamespacedName{Name: "test-is-endpoints", Namespace: testNamespace}, svc)) + assert.Equal(t, corev1.ServiceTypeClusterIP, svc.Spec.Type) + require.Len(t, svc.Spec.Ports, 1) + assert.Equal(t, portName, svc.Spec.Ports[0].Name, "service port name must match endpointslice port name (the gateway-controller join key)") + assert.Equal(t, servicePort, svc.Spec.Ports[0].Port) + assert.Empty(t, svc.Spec.Selector, "service must have no selector — its EndpointSlices are managed explicitly") + + slices := &discoveryv1.EndpointSliceList{} + require.NoError(t, c.List(context.Background(), slices, client.MatchingLabels{kubeServiceNameLabel: "test-is-endpoints"})) + require.Len(t, slices.Items, 2) + + for _, slice := range slices.Items { + require.Len(t, slice.Ports, 1) + assert.Equal(t, portName, *slice.Ports[0].Name, "endpointslice port name must match service port name") + require.Len(t, slice.Endpoints, 1) + require.Len(t, slice.Endpoints[0].Addresses, 1) + clusterID := slice.Labels[clusterIDLabel] + expected, ok := endpoints[clusterID] + require.True(t, ok, "slice has unexpected cluster ID label %q", clusterID) + assert.Equal(t, expected.Host, slice.Endpoints[0].Addresses[0]) + 
assert.Equal(t, expected.Port, *slice.Ports[0].Port) + } +} + +// TestSync_Idempotent covers the warm-state path: calling Sync twice with the +// same input should not create duplicate objects or error. +func TestSync_Idempotent(t *testing.T) { + pub, c, server := newFixture(t) + endpoints := map[string]Endpoint{ + "clusterA": {Host: "10.0.0.1", Port: 31001, Scheme: "http"}, + } + require.NoError(t, pub.Sync(context.Background(), server, endpoints)) + require.NoError(t, pub.Sync(context.Background(), server, endpoints)) + + svcs := &corev1.ServiceList{} + require.NoError(t, c.List(context.Background(), svcs, client.InNamespace(testNamespace))) + assert.Len(t, svcs.Items, 1) + + slices := &discoveryv1.EndpointSliceList{} + require.NoError(t, c.List(context.Background(), slices, client.MatchingLabels{kubeServiceNameLabel: "test-is-endpoints"})) + assert.Len(t, slices.Items, 1) +} + +// TestSync_OrphanDeletion covers the convergence path: when a cluster ID is +// removed from the desired map, Sync must delete its stale EndpointSlice. +func TestSync_OrphanDeletion(t *testing.T) { + pub, c, server := newFixture(t) + require.NoError(t, pub.Sync(context.Background(), server, map[string]Endpoint{ + "clusterA": {Host: "10.0.0.1", Port: 31001}, + "clusterB": {Host: "10.0.0.2", Port: 31002}, + })) + + require.NoError(t, pub.Sync(context.Background(), server, map[string]Endpoint{ + "clusterA": {Host: "10.0.0.1", Port: 31001}, + })) + + slices := &discoveryv1.EndpointSliceList{} + require.NoError(t, c.List(context.Background(), slices, client.MatchingLabels{kubeServiceNameLabel: "test-is-endpoints"})) + require.Len(t, slices.Items, 1) + assert.Equal(t, "clusterA", slices.Items[0].Labels[clusterIDLabel]) +} + +// TestSync_UpdatesExistingSlice covers spec drift: when an existing slice's +// host/port changes, Sync must update the slice rather than create a duplicate. 
+func TestSync_UpdatesExistingSlice(t *testing.T) { + pub, c, server := newFixture(t) + require.NoError(t, pub.Sync(context.Background(), server, map[string]Endpoint{ + "clusterA": {Host: "10.0.0.1", Port: 31001}, + })) + require.NoError(t, pub.Sync(context.Background(), server, map[string]Endpoint{ + "clusterA": {Host: "10.0.0.99", Port: 31999}, + })) + + slice := &discoveryv1.EndpointSlice{} + require.NoError(t, c.Get(context.Background(), types.NamespacedName{ + Name: "test-is-endpoints-clusterA", + Namespace: testNamespace, + }, slice)) + assert.Equal(t, "10.0.0.99", slice.Endpoints[0].Addresses[0]) + assert.Equal(t, int32(31999), *slice.Ports[0].Port) +} + +// TestGet_ReturnsPublishedEndpoints covers the read path used by the actor's +// Retrieve to detect drift: Get must round-trip whatever Sync wrote. +func TestGet_ReturnsPublishedEndpoints(t *testing.T) { + pub, _, server := newFixture(t) + desired := map[string]Endpoint{ + "clusterA": {Host: "10.0.0.1", Port: 31001, Scheme: "http"}, + "clusterB": {Host: "10.0.0.2", Port: 31002, Scheme: "http"}, + } + require.NoError(t, pub.Sync(context.Background(), server, desired)) + + got, err := pub.Get(context.Background(), server) + require.NoError(t, err) + assert.Equal(t, desired, got) +} + +// TestGet_EmptyWhenNoSlices covers the cold-start branch of Retrieve. +func TestGet_EmptyWhenNoSlices(t *testing.T) { + pub, _, server := newFixture(t) + got, err := pub.Get(context.Background(), server) + require.NoError(t, err) + assert.Empty(t, got) +} + +// TestGet_FiltersByServiceLabel ensures Get does not return EndpointSlices +// belonging to a different InferenceServer in the same namespace. 
+func TestGet_FiltersByServiceLabel(t *testing.T) { + otherSlice := &discoveryv1.EndpointSlice{ + ObjectMeta: metav1.ObjectMeta{ + Name: "other-is-endpoints-clusterX", + Namespace: testNamespace, + Labels: map[string]string{ + kubeServiceNameLabel: "other-is-endpoints", + clusterIDLabel: "clusterX", + }, + }, + AddressType: discoveryv1.AddressTypeIPv4, + Endpoints: []discoveryv1.Endpoint{{Addresses: []string{"10.99.0.1"}}}, + Ports: []discoveryv1.EndpointPort{{Name: ptr(portName), Port: ptr(int32(99999))}}, + } + pub, _, server := newFixture(t, otherSlice) + require.NoError(t, pub.Sync(context.Background(), server, map[string]Endpoint{ + "clusterA": {Host: "10.0.0.1", Port: 31001, Scheme: "http"}, + })) + + got, err := pub.Get(context.Background(), server) + require.NoError(t, err) + assert.Equal(t, map[string]Endpoint{ + "clusterA": {Host: "10.0.0.1", Port: 31001, Scheme: "http"}, + }, got) +} + +// TestDelete_RemovesServiceAndSlices covers the explicit-teardown path used +// when something other than IS deletion needs to wipe the published surface. 
+func TestDelete_RemovesServiceAndSlices(t *testing.T) { + pub, c, server := newFixture(t) + require.NoError(t, pub.Sync(context.Background(), server, map[string]Endpoint{ + "clusterA": {Host: "10.0.0.1", Port: 31001}, + })) + + require.NoError(t, pub.Delete(context.Background(), server)) + + err := c.Get(context.Background(), types.NamespacedName{Name: "test-is-endpoints", Namespace: testNamespace}, &corev1.Service{}) + assert.True(t, apierrors.IsNotFound(err), "service should be gone after Delete; got %v", err) + + slices := &discoveryv1.EndpointSliceList{} + require.NoError(t, c.List(context.Background(), slices, client.MatchingLabels{kubeServiceNameLabel: "test-is-endpoints"})) + assert.Empty(t, slices.Items) +} diff --git a/go/components/inferenceserver/endpoints/source/BUILD.bazel b/go/components/inferenceserver/endpoints/source/BUILD.bazel new file mode 100644 index 000000000..5afb70618 --- /dev/null +++ b/go/components/inferenceserver/endpoints/source/BUILD.bazel @@ -0,0 +1,22 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "go_default_library", + srcs = [ + "k3d.go", + "module.go", + ], + importpath = "github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/endpoints/source", + visibility = ["//visibility:public"], + deps = [ + "//go/base/config:go_default_library", + "//go/components/inferenceserver/clientfactory:go_default_library", + "//go/components/inferenceserver/endpoints:go_default_library", + "//proto/api/v2:go_default_library", + "@io_k8s_api//core/v1:go_default_library", + "@io_k8s_apimachinery//pkg/types:go_default_library", + "@io_k8s_sigs_controller_runtime//pkg/client:go_default_library", + "@org_uber_go_fx//:go_default_library", + "@org_uber_go_zap//:go_default_library", + ], +) diff --git a/go/components/inferenceserver/endpoints/source/k3d.go b/go/components/inferenceserver/endpoints/source/k3d.go new file mode 100644 index 000000000..a9c84d7eb --- /dev/null +++ 
b/go/components/inferenceserver/endpoints/source/k3d.go @@ -0,0 +1,108 @@ +// Package source provides EndpointSource implementations for specific +// cluster environments. +package source + +import ( + "context" + "fmt" + + "go.uber.org/zap" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/client" + + maconfig "github.com/michelangelo-ai/michelangelo/go/base/config" + "github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/clientfactory" + "github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/endpoints" + v2pb "github.com/michelangelo-ai/michelangelo/proto-go/api/v2" +) + +var _ endpoints.EndpointSource = &K3dSource{} + +// K3dSource resolves a cluster's ingress address using the k3d networking +// model: the gateway Service is read for its NodePort, and a Node's InternalIP +// is read as a routable address. Both reads target the cluster identified by +// the ClusterTarget via the supplied ClientFactory. +type K3dSource struct { + clientFactory clientfactory.ClientFactory + config maconfig.InferenceServerConfig + logger *zap.Logger +} + +// NewK3dSource returns an EndpointSource for k3d-style clusters where the +// gateway Service is exposed via NodePort and addressable from peers on the +// same docker network at any node's InternalIP. +func NewK3dSource(clientFactory clientfactory.ClientFactory, config maconfig.InferenceServerConfig, logger *zap.Logger) *K3dSource { + return &K3dSource{ + clientFactory: clientFactory, + config: config, + logger: logger.With(zap.String("component", "k3d-endpoint-source")), + } +} + +// Resolve returns the Endpoint at which the target cluster's ingress gateway +// admits traffic. Returns an error when the gateway Service is missing, has +// no NodePort on the configured named port, or no node has an InternalIP. 
+func (s *K3dSource) Resolve(ctx context.Context, target *v2pb.ClusterTarget) (endpoints.Endpoint, error) { + kubeClient, err := s.clientFactory.GetClient(ctx, target) + if err != nil { + return endpoints.Endpoint{}, fmt.Errorf("get client for cluster %q: %w", target.GetClusterId(), err) + } + + nodePort, err := s.gatewayNodePort(ctx, kubeClient, target.GetClusterId()) + if err != nil { + return endpoints.Endpoint{}, err + } + nodeAddr, err := s.firstNodeInternalIP(ctx, kubeClient, target.GetClusterId()) + if err != nil { + return endpoints.Endpoint{}, err + } + + return endpoints.Endpoint{ + Host: nodeAddr, + Port: nodePort, + Scheme: "http", + }, nil +} + +// gatewayNodePort fetches the gateway Service identified by config and returns +// the NodePort assigned to the named port. +func (s *K3dSource) gatewayNodePort(ctx context.Context, kubeClient client.Client, clusterID string) (int32, error) { + gw := s.config.Gateway + svc := &corev1.Service{} + key := types.NamespacedName{Name: gw.ServiceName, Namespace: gw.ServiceNamespace} + if err := kubeClient.Get(ctx, key, svc); err != nil { + return 0, fmt.Errorf("get gateway service %s/%s on cluster %q: %w", + gw.ServiceNamespace, gw.ServiceName, clusterID, err) + } + for _, port := range svc.Spec.Ports { + if port.Name != gw.PortName { + continue + } + if port.NodePort == 0 { + return 0, fmt.Errorf("gateway service %s/%s on cluster %q has no NodePort on port %q (Service type may not be NodePort)", + gw.ServiceNamespace, gw.ServiceName, clusterID, gw.PortName) + } + return port.NodePort, nil + } + return 0, fmt.Errorf("gateway service %s/%s on cluster %q has no port named %q", + gw.ServiceNamespace, gw.ServiceName, clusterID, gw.PortName) +} + +// firstNodeInternalIP returns the InternalIP of the first Node listed by the +// target cluster's API. Any node's InternalIP suffices because the gateway is +// a NodePort exposed on every node. 
+func (s *K3dSource) firstNodeInternalIP(ctx context.Context, kubeClient client.Client, clusterID string) (string, error) { + nodes := &corev1.NodeList{} + if err := kubeClient.List(ctx, nodes); err != nil { + return "", fmt.Errorf("list nodes on cluster %q: %w", clusterID, err) + } + for _, node := range nodes.Items { + for _, addr := range node.Status.Addresses { + if addr.Type == corev1.NodeInternalIP && addr.Address != "" { + return addr.Address, nil + } + } + } + return "", fmt.Errorf("no node on cluster %q reported an InternalIP", clusterID) +} diff --git a/go/components/inferenceserver/endpoints/source/module.go b/go/components/inferenceserver/endpoints/source/module.go new file mode 100644 index 000000000..ec0f3c6ef --- /dev/null +++ b/go/components/inferenceserver/endpoints/source/module.go @@ -0,0 +1,20 @@ +package source + +import ( + "go.uber.org/fx" + "go.uber.org/zap" + + maconfig "github.com/michelangelo-ai/michelangelo/go/base/config" + "github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/clientfactory" + "github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/endpoints" +) + +// Module binds the k3d EndpointSource into the fx graph. Include this module +// alongside endpoints.Module when running against k3d clusters. 
+var Module = fx.Options( + fx.Provide(newK3dSource), +) + +func newK3dSource(clientFactory clientfactory.ClientFactory, isConfig maconfig.InferenceServerConfig, logger *zap.Logger) endpoints.EndpointSource { + return NewK3dSource(clientFactory, isConfig, logger) +} diff --git a/go/components/inferenceserver/module.go b/go/components/inferenceserver/module.go index f9a544495..41a35453f 100644 --- a/go/components/inferenceserver/module.go +++ b/go/components/inferenceserver/module.go @@ -6,11 +6,15 @@ import ( ctrl "sigs.k8s.io/controller-runtime" "github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/clientfactory" + "github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/endpoints" + "github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/endpoints/source" ) // Module provides the inference server controller with all dependencies var Module = fx.Options( clientfactory.Module, + endpoints.Module, + source.Module, fx.Provide(newEventRecorder), fx.Provide(NewReconciler), fx.Invoke(register), diff --git a/go/components/inferenceserver/plugins/oss/BUILD.bazel b/go/components/inferenceserver/plugins/oss/BUILD.bazel index 111e5e3ec..79be51254 100644 --- a/go/components/inferenceserver/plugins/oss/BUILD.bazel +++ b/go/components/inferenceserver/plugins/oss/BUILD.bazel @@ -12,6 +12,7 @@ go_library( "//go/base/conditions/interfaces:go_default_library", "//go/components/inferenceserver/backends:go_default_library", "//go/components/inferenceserver/clientfactory:go_default_library", + "//go/components/inferenceserver/endpoints:go_default_library", "//go/components/inferenceserver/modelconfig:go_default_library", "//go/components/inferenceserver/plugins:go_default_library", "//go/components/inferenceserver/plugins/oss/creation:go_default_library", diff --git a/go/components/inferenceserver/plugins/oss/common/constants.go b/go/components/inferenceserver/plugins/oss/common/constants.go index b5c9686fc..d0405439c 100644 --- 
a/go/components/inferenceserver/plugins/oss/common/constants.go +++ b/go/components/inferenceserver/plugins/oss/common/constants.go @@ -6,4 +6,5 @@ const ( BackendProvisionConditionType = "BackendProvision" ModelConfigProvisionConditionType = "ModelConfigProvision" ValidationConditionType = "Validation" + EndpointPublishConditionType = "EndpointPublish" ) diff --git a/go/components/inferenceserver/plugins/oss/creation/BUILD.bazel b/go/components/inferenceserver/plugins/oss/creation/BUILD.bazel index 4955679f7..44d54ceaf 100644 --- a/go/components/inferenceserver/plugins/oss/creation/BUILD.bazel +++ b/go/components/inferenceserver/plugins/oss/creation/BUILD.bazel @@ -5,6 +5,7 @@ go_library( srcs = [ "backend_provision.go", "condition_plugin.go", + "endpoint_publish.go", "health_check.go", "model_config_provision.go", "validation.go", @@ -16,6 +17,7 @@ go_library( "//go/base/conditions/utils:go_default_library", "//go/components/inferenceserver/backends:go_default_library", "//go/components/inferenceserver/clientfactory:go_default_library", + "//go/components/inferenceserver/endpoints:go_default_library", "//go/components/inferenceserver/modelconfig:go_default_library", "//go/components/inferenceserver/plugins/oss/common:go_default_library", "//proto/api:go_default_library", diff --git a/go/components/inferenceserver/plugins/oss/creation/condition_plugin.go b/go/components/inferenceserver/plugins/oss/creation/condition_plugin.go index d7ad4b39c..87d59732a 100644 --- a/go/components/inferenceserver/plugins/oss/creation/condition_plugin.go +++ b/go/components/inferenceserver/plugins/oss/creation/condition_plugin.go @@ -6,6 +6,7 @@ import ( conditionInterfaces "github.com/michelangelo-ai/michelangelo/go/base/conditions/interfaces" "github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/backends" "github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/clientfactory" + 
"github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/endpoints" modelconfig "github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/modelconfig" apipb "github.com/michelangelo-ai/michelangelo/proto-go/api" v2pb "github.com/michelangelo-ai/michelangelo/proto-go/api/v2" @@ -16,15 +17,19 @@ type CreationPlugin struct { clientFactory clientfactory.ClientFactory registry *backends.Registry modelConfigProvider modelconfig.ModelConfigProvider + endpointPublisher endpoints.EndpointPublisher + endpointSource endpoints.EndpointSource logger *zap.Logger } // NewCreationPlugin creates a plugin that manages validation, provisioning, health checks, and routing. -func NewCreationPlugin(clientFactory clientfactory.ClientFactory, registry *backends.Registry, modelConfigProvider modelconfig.ModelConfigProvider, logger *zap.Logger) conditionInterfaces.Plugin[*v2pb.InferenceServer] { +func NewCreationPlugin(clientFactory clientfactory.ClientFactory, registry *backends.Registry, modelConfigProvider modelconfig.ModelConfigProvider, endpointPublisher endpoints.EndpointPublisher, endpointSource endpoints.EndpointSource, logger *zap.Logger) conditionInterfaces.Plugin[*v2pb.InferenceServer] { return &CreationPlugin{ clientFactory: clientFactory, registry: registry, modelConfigProvider: modelConfigProvider, + endpointPublisher: endpointPublisher, + endpointSource: endpointSource, logger: logger, } } @@ -34,6 +39,7 @@ func (p *CreationPlugin) GetActors() []conditionInterfaces.ConditionActor[*v2pb. 
return []conditionInterfaces.ConditionActor[*v2pb.InferenceServer]{ NewValidationActor(p.registry, p.logger), NewBackendProvisionActor(p.clientFactory, p.registry, p.logger), + NewEndpointPublishActor(p.endpointPublisher, p.endpointSource, p.logger), NewModelConfigProvisionActor(p.clientFactory, p.modelConfigProvider, p.logger), NewHealthCheckActor(p.clientFactory, p.registry, p.logger), } diff --git a/go/components/inferenceserver/plugins/oss/creation/endpoint_publish.go b/go/components/inferenceserver/plugins/oss/creation/endpoint_publish.go new file mode 100644 index 000000000..3ea21a868 --- /dev/null +++ b/go/components/inferenceserver/plugins/oss/creation/endpoint_publish.go @@ -0,0 +1,143 @@ +package creation + +import ( + "context" + "fmt" + "sort" + "strings" + + "go.uber.org/zap" + + conditionInterfaces "github.com/michelangelo-ai/michelangelo/go/base/conditions/interfaces" + conditionsutil "github.com/michelangelo-ai/michelangelo/go/base/conditions/utils" + "github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/endpoints" + "github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/plugins/oss/common" + apipb "github.com/michelangelo-ai/michelangelo/proto-go/api" + v2pb "github.com/michelangelo-ai/michelangelo/proto-go/api/v2" +) + +var _ conditionInterfaces.ConditionActor[*v2pb.InferenceServer] = &EndpointPublishActor{} + +// EndpointPublishActor reconciles the per-cluster endpoints published for an +// InferenceServer, so other components in the control plane can address the +// server across all the clusters it is reachable in. +type EndpointPublishActor struct { + publisher endpoints.EndpointPublisher + source endpoints.EndpointSource + logger *zap.Logger +} + +// NewEndpointPublishActor creates the condition actor that maintains the +// control-plane Service + per-cluster EndpointSlices for an InferenceServer. 
+func NewEndpointPublishActor(publisher endpoints.EndpointPublisher, source endpoints.EndpointSource, logger *zap.Logger) conditionInterfaces.ConditionActor[*v2pb.InferenceServer] { + return &EndpointPublishActor{ + publisher: publisher, + source: source, + logger: logger, + } +} + +// GetType returns the condition type identifier for endpoint publishing. +func (a *EndpointPublishActor) GetType() string { + return common.EndpointPublishConditionType +} + +// Retrieve checks that the published EndpointSlices match the spec's cluster +// set: every ClusterTarget has a slice, and there are no orphan slices for +// clusters removed from the spec. +func (a *EndpointPublishActor) Retrieve(ctx context.Context, resource *v2pb.InferenceServer, condition *apipb.Condition) (*apipb.Condition, error) { + a.logger.Info("Retrieving endpoint publish condition") + + desired := desiredClusterIDs(resource) + observed, err := a.publisher.Get(ctx, resource) + if err != nil { + a.logger.Error("Failed to read published endpoints", + zap.Error(err), + zap.String("operation", "get_published_endpoints"), + zap.String("namespace", resource.Namespace), + zap.String("inferenceServer", resource.Name)) + return conditionsutil.GenerateFalseCondition(condition, "GetFailed", err.Error()), nil + } + observedIDs := observedClusterIDs(observed) + + if missing := setDiff(desired, observedIDs); len(missing) > 0 { + return conditionsutil.GenerateFalseCondition(condition, "EndpointSliceMissing", strings.Join(missing, ",")), nil + } + if extra := setDiff(observedIDs, desired); len(extra) > 0 { + return conditionsutil.GenerateFalseCondition(condition, "OrphanEndpointSlice", strings.Join(extra, ",")), nil + } + return conditionsutil.GenerateTrueCondition(condition), nil +} + +// Run resolves each ClusterTarget's Gateway endpoint one cluster at a time, +// then a single Sync call reconciles the published Service + EndpointSlices +// (creating missing, deleting orphans). 
Partial resolve failures surface as +// UNKNOWN so a transient error on one cluster does not flip the +// IS to STATE_FAILED. +func (a *EndpointPublishActor) Run(ctx context.Context, resource *v2pb.InferenceServer, condition *apipb.Condition) (*apipb.Condition, error) { + a.logger.Info("Running endpoint publish") + + endpointMap := map[string]endpoints.Endpoint{} + var failures []string + for _, target := range resource.Spec.ClusterTargets { + ep, err := a.source.Resolve(ctx, target) + if err != nil { + a.logger.Error("Failed to resolve cluster endpoint", + zap.Error(err), + zap.String("operation", "resolve_endpoint"), + zap.String("namespace", resource.Namespace), + zap.String("inferenceServer", resource.Name), + zap.String("cluster_id", target.GetClusterId())) + failures = append(failures, fmt.Sprintf("%s: resolve: %v", target.GetClusterId(), err)) + continue + } + endpointMap[target.GetClusterId()] = ep + } + + if err := a.publisher.Sync(ctx, resource, endpointMap); err != nil { + a.logger.Error("Failed to sync published endpoints", + zap.Error(err), + zap.String("operation", "sync_endpoints"), + zap.String("namespace", resource.Namespace), + zap.String("inferenceServer", resource.Name)) + return conditionsutil.GenerateFalseCondition(condition, "SyncFailed", err.Error()), nil + } + + if len(failures) > 0 { + // Partial resolve failures are transient, so report UNKNOWN. + return conditionsutil.GenerateUnknownCondition(condition, "PartialEndpointPublish", strings.Join(failures, "; ")), nil + } + return conditionsutil.GenerateTrueCondition(condition), nil +} + +// desiredClusterIDs returns the set of cluster IDs currently in the IS spec. 
+func desiredClusterIDs(resource *v2pb.InferenceServer) map[string]struct{} { + out := make(map[string]struct{}, len(resource.Spec.ClusterTargets)) + for _, target := range resource.Spec.ClusterTargets { + out[target.GetClusterId()] = struct{}{} + } + return out +} + +// observedClusterIDs lifts a published endpoint map to a key set so it can be +// compared against the desired set with a single setDiff helper. +func observedClusterIDs(m map[string]endpoints.Endpoint) map[string]struct{} { + out := make(map[string]struct{}, len(m)) + for k := range m { + out[k] = struct{}{} + } + return out +} + +// setDiff returns the sorted slice of keys in `a` that are not in `b`. +// Sorted output keeps condition messages deterministic between reconciles. +func setDiff(a, b map[string]struct{}) []string { + var out []string + for k := range a { + if _, found := b[k]; !found { + out = append(out, k) + } + } + sort.Strings(out) + return out +} diff --git a/go/components/inferenceserver/plugins/oss/plugin.go b/go/components/inferenceserver/plugins/oss/plugin.go index 4840b9296..de4fb8497 100644 --- a/go/components/inferenceserver/plugins/oss/plugin.go +++ b/go/components/inferenceserver/plugins/oss/plugin.go @@ -11,6 +11,7 @@ import ( conditionInterfaces "github.com/michelangelo-ai/michelangelo/go/base/conditions/interfaces" "github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/backends" "github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/clientfactory" + "github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/endpoints" modelconfig "github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/modelconfig" "github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/plugins" "github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/plugins/oss/creation" @@ -34,9 +35,9 @@ type Plugin struct { } // NewPlugin creates a plugin with creation and deletion workflows. 
-func NewOSSPlugin(clientFactory clientfactory.ClientFactory, registry *backends.Registry, modelConfigProvider modelconfig.ModelConfigProvider, recorder record.EventRecorder, logger *zap.Logger) plugins.Plugin { +func NewOSSPlugin(clientFactory clientfactory.ClientFactory, registry *backends.Registry, modelConfigProvider modelconfig.ModelConfigProvider, endpointPublisher endpoints.EndpointPublisher, endpointSource endpoints.EndpointSource, recorder record.EventRecorder, logger *zap.Logger) plugins.Plugin { return &Plugin{ - creationPlugin: creation.NewCreationPlugin(clientFactory, registry, modelConfigProvider, logger), + creationPlugin: creation.NewCreationPlugin(clientFactory, registry, modelConfigProvider, endpointPublisher, endpointSource, logger), deletionPlugin: deletion.NewDeletionPlugin(clientFactory, registry, modelConfigProvider, logger), clientFactory: clientFactory, diff --git a/python/michelangelo/cli/sandbox/resources/gateway-api-setup.yaml b/python/michelangelo/cli/sandbox/resources/gateway-api-setup.yaml index 2f28b650b..c1c8861b9 100644 --- a/python/michelangelo/cli/sandbox/resources/gateway-api-setup.yaml +++ b/python/michelangelo/cli/sandbox/resources/gateway-api-setup.yaml @@ -19,7 +19,7 @@ metadata: name: ma-gateway namespace: default annotations: - networking.istio.io/service-type: ClusterIP + networking.istio.io/service-type: NodePort spec: gatewayClassName: istio listeners: diff --git a/python/michelangelo/cli/sandbox/resources/michelangelo-controllermgr.yaml b/python/michelangelo/cli/sandbox/resources/michelangelo-controllermgr.yaml index b9c492377..b1fbbd8c9 100644 --- a/python/michelangelo/cli/sandbox/resources/michelangelo-controllermgr.yaml +++ b/python/michelangelo/cli/sandbox/resources/michelangelo-controllermgr.yaml @@ -103,3 +103,13 @@ data: domain: default taskList: default executionUrlFormat: http://localhost:8088/namespaces/{{.Domain}}/workflows/{{.ExecutionID}} + + inferenceServer: + gateway: + # Istio's Gateway controller 
materializes a Service named '<gateway-name>-istio' + # for each Gateway resource. The IS controller's EndpointSource reads the + # NodePort + node InternalIP of this Service to publish per-cluster + # EndpointSlices in the control plane. + serviceName: ma-gateway-istio + serviceNamespace: default + portName: http diff --git a/python/michelangelo/cli/sandbox/resources/michelangelo-temporal-controllermgr.yaml b/python/michelangelo/cli/sandbox/resources/michelangelo-temporal-controllermgr.yaml index 88be80756..fde05c577 100644 --- a/python/michelangelo/cli/sandbox/resources/michelangelo-temporal-controllermgr.yaml +++ b/python/michelangelo/cli/sandbox/resources/michelangelo-temporal-controllermgr.yaml @@ -68,3 +68,13 @@ data: transport: grpc domain: default provider: Temporal + + inferenceServer: + gateway: + # Istio's Gateway controller materializes a Service named '<gateway-name>-istio' + # for each Gateway resource. The IS controller's EndpointSource reads the + # NodePort + node InternalIP of this Service to publish per-cluster + # EndpointSlices in the control plane. + serviceName: ma-gateway-istio + serviceNamespace: default + portName: http diff --git a/python/michelangelo/cli/sandbox/resources/rbac-inferenceserver.yaml b/python/michelangelo/cli/sandbox/resources/rbac-inferenceserver.yaml index 4eec63115..e49b92da6 100644 --- a/python/michelangelo/cli/sandbox/resources/rbac-inferenceserver.yaml +++ b/python/michelangelo/cli/sandbox/resources/rbac-inferenceserver.yaml @@ -15,6 +15,12 @@ rules: - apiGroups: [""] resources: ["services", "configmaps"] verbs: ["create", "get", "list", "watch", "update", "patch", "delete"] +- apiGroups: [""] + resources: ["nodes"] + verbs: ["get", "list", "watch"] +- apiGroups: ["discovery.k8s.io"] + resources: ["endpointslices"] + verbs: ["create", "get", "list", "watch", "update", "patch", "delete"] --- apiVersion: rbac.authorization.k8s.io/v1 kind: ClusterRoleBinding