Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions go/base/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ const (
_workflowClientConfigKey = "workflowClient"
_mysqlConfigKey = "mysql"
_ingesterConfigKey = "ingester"
_inferenceServerConfigKey = "inferenceServer"
)

// K8sConfig is the configuration for k8s REST client.
Expand Down Expand Up @@ -130,3 +131,27 @@ func GetIngesterConfig(provider config.Provider) (IngesterConfig, error) {
err := provider.Get(_ingesterConfigKey).Populate(&ingesterConfig)
return ingesterConfig, err
}

// InferenceServerConfig groups the controller-side settings of the inference
// server controller. It is populated from the "inferenceServer" section of the
// configuration by GetInferenceServerConfig.
type InferenceServerConfig struct {
	// Gateway tells the controller where the cluster's ingress Gateway
	// Service lives; it is read from the "gateway" YAML key.
	Gateway GatewayConfig `yaml:"gateway"`
}

// GatewayConfig tells the inference server controller's EndpointSource where
// to find a cluster's ingress Gateway Service. The controller does not create
// that Service: it is provisioned out-of-band (e.g., by `ma sandbox create`
// for sandbox), and this config only points at it.
type GatewayConfig struct {
	// ServiceName is the name of the Gateway Service (YAML key "serviceName").
	ServiceName string `yaml:"serviceName"`
	// ServiceNamespace is the namespace that Service lives in (YAML key
	// "serviceNamespace").
	ServiceNamespace string `yaml:"serviceNamespace"`
	// PortName is read from the YAML key "portName"; presumably it selects
	// the Service port to publish by name — confirm against the
	// EndpointSource implementation.
	PortName string `yaml:"portName"`
}

// GetInferenceServerConfig reads the "inferenceServer" section from the given
// provider and decodes it into an InferenceServerConfig.
//
// The decoded value is returned together with whatever error Populate
// reports; callers should not rely on the config when the error is non-nil.
func GetInferenceServerConfig(provider config.Provider) (InferenceServerConfig, error) {
	var cfg InferenceServerConfig
	section := provider.Get(_inferenceServerConfigKey)
	err := section.Populate(&cfg)
	return cfg, err
}
12 changes: 11 additions & 1 deletion go/cmd/controllermgr/config/base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,14 @@ minio:
awsRegion: us-east-1
awsAccessKeyId: minioadmin
awsSecretAccessKey: minioadmin
awsEndpointUrl: localhost:9091
awsEndpointUrl: localhost:9091

inferenceServer:
gateway:
# Istio's Gateway controller materializes a Service named '<gateway>-istio'
# for each Gateway resource. The inference server controller's EndpointSource
# reads the NodePort and the node InternalIP of this Service to publish
# per-cluster EndpointSlices in the control plane.
serviceName: ma-gateway-istio
serviceNamespace: default
portName: http
3 changes: 3 additions & 0 deletions go/components/inferenceserver/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ go_library(
"//go/api/utils:go_default_library",
"//go/base/conditions/engine:go_default_library",
"//go/base/conditions/interfaces:go_default_library",
"//go/components/inferenceserver/clientfactory:go_default_library",
"//go/components/inferenceserver/endpoints:go_default_library",
"//go/components/inferenceserver/endpoints/source:go_default_library",
"//go/components/inferenceserver/plugins:go_default_library",
"//proto/api/v2:go_default_library",
"@io_k8s_api//core/v1:go_default_library",
Expand Down
24 changes: 24 additions & 0 deletions go/components/inferenceserver/clientfactory/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")

# Go library target for the clientfactory package (factory.go, interface.go,
# module.go). Publicly visible so other packages in the repo can depend on it.
go_library(
name = "go_default_library",
srcs = [
"factory.go",
"interface.go",
"module.go",
],
importpath = "github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/clientfactory",
visibility = ["//visibility:public"],
deps = [
"//go/components/inferenceserver/clientfactory/secrets:go_default_library",
"//proto/api/v2:go_default_library",
"@io_k8s_apimachinery//pkg/runtime:go_default_library",
"@io_k8s_client_go//rest:go_default_library",
"@io_k8s_client_go//tools/clientcmd:go_default_library",
"@io_k8s_client_go//tools/clientcmd/api:go_default_library",
"@io_k8s_client_go//util/flowcontrol:go_default_library",
"@io_k8s_sigs_controller_runtime//pkg/client:go_default_library",
"@org_uber_go_fx//:go_default_library",
"@org_uber_go_zap//:go_default_library",
],
)
221 changes: 221 additions & 0 deletions go/components/inferenceserver/clientfactory/factory.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
package clientfactory

import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"net/http"
"sync"
"time"

"go.uber.org/zap"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/client-go/rest"
"k8s.io/client-go/tools/clientcmd"
clientcmdapi "k8s.io/client-go/tools/clientcmd/api"
"k8s.io/client-go/util/flowcontrol"
"sigs.k8s.io/controller-runtime/pkg/client"

"github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/clientfactory/secrets"
v2pb "github.com/michelangelo-ai/michelangelo/proto-go/api/v2"
)

const (
	// httpClientTimeout is the http.Client.Timeout applied to every HTTP
	// client the factory builds, bounding each request end to end.
	httpClientTimeout time.Duration = 30 * time.Second

	// userAgent identifies this component to remote API servers. It is also
	// reused as the AuthInfo name and context-name prefix of the in-memory
	// kubeconfig assembled by buildRESTConfig.
	userAgent = "michelangelo-inferenceserver"
)

// Compile-time check that *remoteClientFactory satisfies ClientFactory.
var _ ClientFactory = (*remoteClientFactory)(nil)

// remoteClientFactory is the ClientFactory implementation that builds and
// caches clients for remote ClusterTargets.
//
// Controller-runtime clients and HTTP clients live in separate caches, both
// keyed by the connection tuple (cluster_id + host + port), so a steady-state
// reconcile reuses the existing client instead of rebuilding a TLS transport
// on every actor invocation.
type remoteClientFactory struct {
	// secretProvider yields the CA bundle and bearer token for a cluster.
	secretProvider secrets.SecretProvider
	// scheme is passed to every controller-runtime client that is built.
	scheme *runtime.Scheme
	// logger is tagged with component=clientfactory by the constructor.
	logger *zap.Logger

	// kubeClients maps cache key (string) → client.Client.
	kubeClients sync.Map
	// httpClients maps cache key (string) → *http.Client.
	httpClients sync.Map
	// mu is held while a missing client is being built, for either cache.
	mu sync.Mutex
}

// NewRemoteClientFactory returns a ClientFactory backed by empty client caches.
//
// Parameters:
//   - secretProvider: where CA bundles and bearer tokens for remote clusters
//     are fetched from.
//   - scheme: the runtime.Scheme given to each remote client; it must include
//     all CRDs the controller will read/write on remote clusters.
//   - logger: structured logger; the factory logs through a child logger
//     carrying the field component=clientfactory.
func NewRemoteClientFactory(
	secretProvider secrets.SecretProvider,
	scheme *runtime.Scheme,
	logger *zap.Logger,
) ClientFactory {
	factory := new(remoteClientFactory)
	factory.secretProvider = secretProvider
	factory.scheme = scheme
	factory.logger = logger.With(zap.String("component", "clientfactory"))
	return factory
}

// GetClient returns a controller-runtime client for the given ClusterTarget.
//
// A previously built client for the same cache key is returned as-is.
// Otherwise a new one is created from a REST config assembled with
// credentials from the SecretProvider, stored in the cache, and returned.
// An error is returned when the target carries no kubernetes connection spec,
// or when building the REST config or the client fails.
func (f *remoteClientFactory) GetClient(ctx context.Context, cluster *v2pb.ClusterTarget) (client.Client, error) {
	clusterID := cluster.GetClusterId()
	k8s := cluster.GetKubernetes()
	if k8s == nil {
		return nil, fmt.Errorf("cluster %q has no kubernetes connection spec", clusterID)
	}

	key := cacheKey(cluster)

	// Fast path: no lock needed when the client already exists.
	if hit, found := f.kubeClients.Load(key); found {
		return hit.(client.Client), nil
	}

	// Slow path: building a client fetches secrets and sets up TLS state, so
	// take the mutex to keep concurrent reconciles from repeating that work.
	f.mu.Lock()
	defer f.mu.Unlock()

	// Another goroutine may have populated the cache while we waited.
	if hit, found := f.kubeClients.Load(key); found {
		return hit.(client.Client), nil
	}

	restCfg, err := f.buildRESTConfig(ctx, cluster)
	if err != nil {
		return nil, fmt.Errorf("build REST config for cluster %q: %w", clusterID, err)
	}

	built, err := client.New(restCfg, client.Options{Scheme: f.scheme})
	if err != nil {
		return nil, fmt.Errorf("create kube client for cluster %q: %w", clusterID, err)
	}

	f.kubeClients.Store(key, built)
	f.logger.Info("Built kube client for cluster",
		zap.String("cluster_id", clusterID),
		zap.String("host", k8s.GetHost()))
	return built, nil
}

// GetHTTPClient returns an HTTP client whose transport authenticates with a bearer
// token over TLS validated against the cluster's CA.
//
// Clients are cached per cache key and shared between callers, so a built client
// lives for the lifetime of the factory. Because of that, the underlying transport
// is given explicit idle-connection and TLS-handshake limits; a zero-valued
// http.Transport would keep idle keep-alive connections open indefinitely.
//
// An error is returned when the target has no kubernetes connection spec, when the
// SecretProvider cannot supply credentials, or when the CA data contains no
// parsable PEM certificate.
func (f *remoteClientFactory) GetHTTPClient(ctx context.Context, cluster *v2pb.ClusterTarget) (*http.Client, error) {
	// Transport limits; the values match those of net/http's DefaultTransport.
	const (
		tlsHandshakeTimeout = 10 * time.Second
		idleConnTimeout     = 90 * time.Second
		maxIdleConns        = 100
	)

	if cluster.GetKubernetes() == nil {
		return nil, fmt.Errorf("cluster %q has no kubernetes connection spec", cluster.GetClusterId())
	}

	key := cacheKey(cluster)
	if cached, ok := f.httpClients.Load(key); ok {
		return cached.(*http.Client), nil
	}

	// Building a client hits the SecretProvider; serialize cache misses so
	// concurrent callers do not duplicate that work.
	f.mu.Lock()
	defer f.mu.Unlock()

	// Re-check: another goroutine may have stored the client while we waited.
	if cached, ok := f.httpClients.Load(key); ok {
		return cached.(*http.Client), nil
	}

	auth, err := f.secretProvider.GetClientAuth(ctx, cluster)
	if err != nil {
		return nil, fmt.Errorf("get client auth for cluster %q: %w", cluster.GetClusterId(), err)
	}

	// AppendCertsFromPEM reports false when no certificate could be parsed,
	// which also covers empty CA data.
	caPool := x509.NewCertPool()
	if !caPool.AppendCertsFromPEM([]byte(auth.CertificateAuthorityData)) {
		return nil, fmt.Errorf("parse CA certificate for cluster %q: invalid PEM", cluster.GetClusterId())
	}

	transport := &http.Transport{
		TLSClientConfig: &tls.Config{
			RootCAs:    caPool,
			MinVersion: tls.VersionTLS12,
		},
		TLSHandshakeTimeout: tlsHandshakeTimeout,
		// Reap idle keep-alive connections so a long-lived cached client does
		// not accumulate sockets to endpoints it no longer talks to.
		IdleConnTimeout: idleConnTimeout,
		MaxIdleConns:    maxIdleConns,
	}

	httpClient := &http.Client{
		Transport: &bearerTokenRoundTripper{
			token: auth.ClientTokenData,
			rt:    transport,
		},
		Timeout: httpClientTimeout,
	}

	f.httpClients.Store(key, httpClient)
	f.logger.Info("Built HTTP client for cluster",
		zap.String("cluster_id", cluster.GetClusterId()),
		zap.String("host", cluster.GetKubernetes().GetHost()))
	return httpClient, nil
}

// buildRESTConfig produces the *rest.Config used to reach a cluster's API server.
//
// It fetches the CA bundle and bearer token from the SecretProvider, wraps them
// together with the cluster's host:port in an in-memory kubeconfig, and lets
// clientcmd resolve that kubeconfig into a REST config. The result uses JSON
// content, an always-accepting rate limiter, and the factory's user agent.
func (f *remoteClientFactory) buildRESTConfig(ctx context.Context, cluster *v2pb.ClusterTarget) (*rest.Config, error) {
	auth, err := f.secretProvider.GetClientAuth(ctx, cluster)
	if err != nil {
		return nil, fmt.Errorf("get client auth: %w", err)
	}

	const clusterName = "remote"
	contextName := userAgent + "@remote"
	k8s := cluster.GetKubernetes()
	server := fmt.Sprintf("%s:%s", k8s.GetHost(), k8s.GetPort())

	// Going through clientcmd's kubeconfig schema means the CA-data and
	// bearer-token wiring is handled by its established code path.
	kubeconfig := clientcmdapi.Config{
		Kind:       "Config",
		APIVersion: "v1",
		Clusters: map[string]*clientcmdapi.Cluster{
			clusterName: {
				Server:                   server,
				CertificateAuthorityData: []byte(auth.CertificateAuthorityData),
			},
		},
		AuthInfos: map[string]*clientcmdapi.AuthInfo{
			userAgent: {Token: auth.ClientTokenData},
		},
		Contexts: map[string]*clientcmdapi.Context{
			contextName: {
				Cluster:  clusterName,
				AuthInfo: userAgent,
			},
		},
		CurrentContext: contextName,
	}

	resolver := clientcmd.NewDefaultClientConfig(kubeconfig, &clientcmd.ConfigOverrides{})
	restCfg, err := resolver.ClientConfig()
	if err != nil {
		return nil, fmt.Errorf("resolve client config: %w", err)
	}

	// No client-side throttling; the API server's Priority and Fairness is
	// relied on instead.
	restCfg.RateLimiter = flowcontrol.NewFakeAlwaysRateLimiter()
	restCfg.ContentType = runtime.ContentTypeJSON

	return rest.AddUserAgent(restCfg, userAgent), nil
}

// cacheKey derives the client-cache key for a ClusterTarget, formatted as
// "<cluster_id>|<host>:<port>".
func cacheKey(cluster *v2pb.ClusterTarget) string {
	k8s := cluster.GetKubernetes()
	id, host, port := cluster.GetClusterId(), k8s.GetHost(), k8s.GetPort()
	return fmt.Sprintf("%s|%s:%s", id, host, port)
}

// bearerTokenRoundTripper is an http.RoundTripper decorator that adds a bearer
// Authorization header to every request before delegating to rt.
type bearerTokenRoundTripper struct {
	// token is the raw bearer token (without the "Bearer " prefix).
	token string
	// rt is the wrapped transport that actually performs the request.
	rt http.RoundTripper
}

// RoundTrip sends an authenticated copy of req through the wrapped transport.
// The request is cloned first so the caller's headers are never mutated.
func (b *bearerTokenRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	authed := req.Clone(req.Context())
	authed.Header.Set("Authorization", "Bearer "+b.token)
	return b.rt.RoundTrip(authed)
}
26 changes: 26 additions & 0 deletions go/components/inferenceserver/clientfactory/interface.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Package clientfactory provides Kubernetes API clients for ClusterTargets that an
// InferenceServer is provisioned across.
//
//go:generate mamockgen ClientFactory
package clientfactory

import (
"context"
"net/http"

"sigs.k8s.io/controller-runtime/pkg/client"

v2pb "github.com/michelangelo-ai/michelangelo/proto-go/api/v2"
)

// ClientFactory hands out clients for talking to a ClusterTarget. Credentials
// for the clients are obtained from the SecretProvider.
type ClientFactory interface {
	// GetClient returns a controller-runtime client bound to the given
	// cluster's Kubernetes API.
	GetClient(ctx context.Context, cluster *v2pb.ClusterTarget) (client.Client, error)

	// GetHTTPClient returns an HTTP client intended for user-space services
	// running in the given cluster. Its TLS configuration trusts the
	// cluster's CA bundle and each request carries the cluster's bearer token.
	GetHTTPClient(ctx context.Context, cluster *v2pb.ClusterTarget) (*http.Client, error)
}
22 changes: 22 additions & 0 deletions go/components/inferenceserver/clientfactory/module.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package clientfactory

import (
"go.uber.org/fx"
"go.uber.org/zap"
"sigs.k8s.io/controller-runtime/pkg/client"

"github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/clientfactory/secrets"
)

// Module registers newClientFactory as a provider so that a ClientFactory is
// available in the fx graph.
var Module = fx.Options(fx.Provide(newClientFactory))

// newClientFactory is the fx constructor for ClientFactory. The injected
// kubeClient backs the secrets provider and supplies the runtime.Scheme that
// remote clients are built with.
func newClientFactory(kubeClient client.Client, logger *zap.Logger) ClientFactory {
	provider := secrets.NewProvider(kubeClient)
	scheme := kubeClient.Scheme()
	return NewRemoteClientFactory(provider, scheme, logger)
}
14 changes: 14 additions & 0 deletions go/components/inferenceserver/clientfactory/secrets/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")

# Go library target for the clientfactory/secrets package (provider.go).
# Publicly visible so other packages in the repo can depend on it.
go_library(
name = "go_default_library",
srcs = ["provider.go"],
importpath = "github.com/michelangelo-ai/michelangelo/go/components/inferenceserver/clientfactory/secrets",
visibility = ["//visibility:public"],
deps = [
"//proto/api/v2:go_default_library",
"@io_k8s_api//core/v1:go_default_library",
"@io_k8s_apimachinery//pkg/types:go_default_library",
"@io_k8s_sigs_controller_runtime//pkg/client:go_default_library",
],
)
Loading