diff --git a/arm.go b/arm.go index 505b607..0ec5e6e 100644 --- a/arm.go +++ b/arm.go @@ -87,11 +87,12 @@ func (cfg *SO101ArmConfig) Validate(path string) ([]string, []string, error) { type so101 struct { resource.AlwaysRebuild - name resource.Name - logger logging.Logger - cfg *SO101ArmConfig - opMgr *operation.SingleOperationManager - controller *SafeSoArmController + name resource.Name + logger logging.Logger + cfg *SO101ArmConfig + opMgr *operation.SingleOperationManager + controller *SafeSoArmController + controllerPort string // port path used to acquire the shared controller mu sync.RWMutex moveLock sync.Mutex @@ -105,10 +106,6 @@ type so101 struct { defaultAcc float32 motion motion.Service - - cancelCtx context.Context - cancelFunc func() - initCtx context.Context // Context for initialization operations } func makeSO101ModelFrame() (referenceframe.Model, error) { @@ -227,42 +224,41 @@ func NewSO101(ctx context.Context, deps resource.Dependencies, name resource.Nam model, err := makeSO101ModelFrame() if err != nil { - ReleaseSharedController() // Clean up on error + globalRegistry.ReleaseController(controllerConfig.Port) return nil, fmt.Errorf("failed to create kinematic model: %w", err) } var ms motion.Service if conf.Motion != "" { if deps == nil { + globalRegistry.ReleaseController(controllerConfig.Port) return nil, fmt.Errorf("no deps") } ms, err = motion.FromProvider(deps, conf.Motion) if err != nil { + globalRegistry.ReleaseController(controllerConfig.Port) return nil, err } } else { ms, err = motion.FromProvider(deps, "builtin") if err != nil { + globalRegistry.ReleaseController(controllerConfig.Port) return nil, err } } - cancelCtx, cancelFunc := context.WithCancel(context.Background()) - arm := &so101{ - name: name, - cfg: conf, - opMgr: operation.NewSingleOperationManager(), - logger: logger, - controller: controller, - model: model, - armServoIDs: conf.ServoIDs, // Store which servos this arm controls - defaultSpeed: speedDegsPerSec, - defaultAcc: accelerationDegsPerSec, - motion: ms, - cancelCtx: cancelCtx, - cancelFunc: cancelFunc, - initCtx: ctx, // Store initialization context + name: name, + cfg: conf, + opMgr: operation.NewSingleOperationManager(), + logger: logger, + controller: controller, + controllerPort: controllerConfig.Port, + model: model, + armServoIDs: conf.ServoIDs, // Store which servos this arm controls + defaultSpeed: speedDegsPerSec, + defaultAcc: accelerationDegsPerSec, + motion: ms, } logger.Debugf("SO-101 configured with speed: %.1f deg/s, acceleration: %.1f deg/s²", @@ -270,8 +266,8 @@ func NewSO101(ctx context.Context, deps resource.Dependencies, name resource.Nam logger.Debugf("Arm controlling servo IDs: %v", arm.armServoIDs) // Initialize and verify servo connections - if err := arm.initializeServos(); err != nil { - ReleaseSharedController() // Clean up on error + if err := arm.initializeServos(ctx); err != nil { + globalRegistry.ReleaseController(controllerConfig.Port) return nil, fmt.Errorf("failed to initialize servos: %w", err) } @@ -470,14 +466,14 @@ func (s *so101) DoCommand(ctx context.Context, cmd map[string]interface{}) (map[ }, nil case "diagnose": - err := s.diagnoseConnection() + err := s.diagnoseConnection(ctx) return map[string]interface{}{ "success": err == nil, "error": fmt.Sprintf("%v", err), }, nil case "verify_config": - err := s.verifyServoConfig() + err := s.verifyServoConfig(ctx) return map[string]interface{}{ "success": err == nil, "error": fmt.Sprintf("%v", err), @@ -488,7 +484,7 @@ func (s *so101) DoCommand(ctx context.Context, cmd map[string]interface{}) (map[ if r, ok := cmd["retries"].(float64); ok { retries = int(r) } - err := s.initializeServosWithRetry(retries) + err := s.initializeServosWithRetry(ctx, retries) return map[string]interface{}{ "success": err == nil, "error": fmt.Sprintf("%v", err), @@ -618,25 +614,24 @@ func (s *so101) Geometries(ctx context.Context, extra map[string]interface{}) ([ } func (s *so101) Close(context.Context) error { - s.cancelFunc() - ReleaseSharedController() + globalRegistry.ReleaseController(s.controllerPort) return nil } // initializeServos pings each servo and enables torque to ensure proper communication -func (s *so101) initializeServos() error { - return s.initializeServosWithRetry(3) +func (s *so101) initializeServos(ctx context.Context) error { + return s.initializeServosWithRetry(ctx, 3) } // initializeServosWithRetry attempts servo initialization with retries -func (s *so101) initializeServosWithRetry(maxRetries int) error { +func (s *so101) initializeServosWithRetry(ctx context.Context, maxRetries int) error { s.logger.Debug("Initializing SO-101 arm servos...") var lastErr error for attempt := 1; attempt <= maxRetries; attempt++ { s.logger.Debugf("Arm servo initialization attempt %d/%d", attempt, maxRetries) - if err := s.doServoInitialization(); err != nil { + if err := s.doServoInitialization(ctx); err != nil { lastErr = err s.logger.Warnf("Initialization attempt %d failed: %v", attempt, err) @@ -656,10 +651,7 @@ func (s *so101) initializeServosWithRetry(maxRetries int) error { } // doServoInitialization performs the actual initialization steps -func (s *so101) doServoInitialization() error { - // Use stored initialization context instead of creating new one - ctx := s.initCtx - +func (s *so101) doServoInitialization(ctx context.Context) error { // Ping all servos to ensure they're responding s.logger.Debug("Pinging all servos...") if err := s.controller.Ping(ctx); err != nil { @@ -690,10 +682,7 @@ func (s *so101) doServoInitialization() error { } // diagnoseConnection provides detailed diagnostics for troubleshooting -func (s *so101) diagnoseConnection() error { - // Use stored initialization context instead of creating new one - ctx := s.initCtx - +func (s *so101) diagnoseConnection(ctx context.Context) error { s.logger.Debug("Starting SO-101 arm connection diagnosis...") // Test overall ping @@ -718,10 +707,7 @@ func (s *so101) diagnoseConnection() error { } // verifyServoConfig checks servo configuration -func (s *so101) verifyServoConfig() error { - // Use stored initialization context instead of creating new one - ctx := s.initCtx - +func (s *so101) verifyServoConfig(ctx context.Context) error { s.logger.Debug("Verifying arm servo configuration...") positions, err := s.controller.GetJointPositionsForServos(ctx, s.armServoIDs) diff --git a/calibration.go b/calibration.go index aa1bb75..f0cdc59 100644 --- a/calibration.go +++ b/calibration.go @@ -108,10 +108,11 @@ func (cfg *SO101CalibrationSensorConfig) Validate(path string) ([]string, []stri type so101CalibrationSensor struct { resource.AlwaysRebuild - name resource.Name - logger logging.Logger - cfg *SO101CalibrationSensorConfig - controller *SafeSoArmController + name resource.Name + logger logging.Logger + cfg *SO101CalibrationSensorConfig + controller *SafeSoArmController + controllerPort string // port path used to acquire the shared controller // Calibration state mu sync.RWMutex @@ -205,6 +206,7 @@ func NewSO101CalibrationSensor( logger: logger, cfg: conf, controller: controller, + controllerPort: controllerConfig.Port, state: StateIdle, joints: joints, servoNames: servoNames, @@ -1230,7 +1232,7 @@ func (cs *so101CalibrationSensor) Close(ctx context.Context) error { cs.recordingActive = false if cs.controller != nil { - ReleaseSharedController() + globalRegistry.ReleaseController(cs.controllerPort) } return nil diff --git a/gripper.go b/gripper.go index 0c32f5a..5c18695 100644 --- a/gripper.go +++ b/gripper.go @@ -58,11 +58,12 @@ func (cfg *SO101GripperConfig) Validate(path string) ([]string, []string, error) type so101Gripper struct { resource.AlwaysRebuild - name resource.Name - logger logging.Logger - controller *SafeSoArmController - geometries []spatialmath.Geometry - servoID int + name resource.Name + logger logging.Logger + controller *SafeSoArmController + controllerPort string // port path used to acquire the shared controller + geometries []spatialmath.Geometry + servoID int mu sync.Mutex isMoving atomic.Bool @@ -131,6 +132,7 @@ func newSO101Gripper(ctx context.Context, deps resource.Dependencies, conf resou name: conf.ResourceName(), logger: logger, controller: controller, + controllerPort: controllerConfig.Port, geometries: geometries, servoID: cfg.ServoID, speed: 30, @@ -339,7 +341,7 @@ func (g *so101Gripper) DoCommand(ctx context.Context, cmd map[string]interface{} } func (g *so101Gripper) Close(ctx context.Context) error { - ReleaseSharedController() + globalRegistry.ReleaseController(g.controllerPort) return nil } diff --git a/lifecycle_test.go b/lifecycle_test.go new file mode 100644 index 0000000..9883ab7 --- /dev/null +++ b/lifecycle_test.go @@ -0,0 +1,173 @@ +package so_arm + +import ( + "context" + "errors" + "sync/atomic" + "testing" +) + +// TestController_PostCloseReturnsSentinel verifies that every gated controller +// method returns ErrControllerClosed when the controller's closed flag is set, +// rather than panicking or hitting the (closed) bus. Table-driven so adding a +// new method to the gated list without an entry here is a visible omission. +func TestController_PostCloseReturnsSentinel(t *testing.T) { + ctx := context.Background() + cases := []struct { + name string + call func(*SafeSoArmController) error + }{ + {"Ping", func(c *SafeSoArmController) error { return c.Ping(ctx) }}, + {"SetTorqueEnable", func(c *SafeSoArmController) error { return c.SetTorqueEnable(ctx, true) }}, + {"Stop", func(c *SafeSoArmController) error { return c.Stop(ctx) }}, + {"MoveToJointPositions", func(c *SafeSoArmController) error { + return c.MoveToJointPositions(ctx, []float64{0, 0, 0, 0, 0}, 0, 0) + }}, + {"MoveServosToPositions", func(c *SafeSoArmController) error { + return c.MoveServosToPositions(ctx, []int{1}, []float64{0}, 0, 0) + }}, + {"WriteServoRegister", func(c *SafeSoArmController) error { + return c.WriteServoRegister(ctx, 1, "goal_position", []byte{0, 0}) + }}, + {"SetCalibration", func(c *SafeSoArmController) error { + return c.SetCalibration(SO101FullCalibration{}) + }}, + {"GetJointPositions", func(c *SafeSoArmController) error { + _, err := c.GetJointPositions(ctx) + return err + }}, + {"GetJointPositionsForServos", func(c *SafeSoArmController) error { + _, err := c.GetJointPositionsForServos(ctx, []int{1}) + return err + }}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + bus, _ := newMockBus(t) + ctrl := &SafeSoArmController{ + bus: bus, + logger: newTestLogger(t), + } + ctrl.closed.Store(true) + + if err := tc.call(ctrl); !errors.Is(err, ErrControllerClosed) { + t.Errorf("%s after close: expected ErrControllerClosed, got %v", tc.name, err) + } + }) + } +} + +// TestRegistry_SamePointerForSamePort verifies that two callers acquiring a +// controller for the same port receive the *same* *SafeSoArmController, so +// that close-state propagates correctly across all consumers. +func TestRegistry_SamePointerForSamePort(t *testing.T) { + registry := NewControllerRegistry() + port := "/dev/test-port" + cfg := testConfig(port) + + // Inject a pre-built entry so we don't need a real bus. + bus, _ := newMockBus(t) + ctrl := &SafeSoArmController{ + bus: bus, + logger: cfg.Logger, + } + registry.entries[port] = &ControllerEntry{ + controller: ctrl, + config: cfg, + calibration: DefaultSO101FullCalibration, + refCount: 0, + } + + first, err := registry.GetController(port, cfg, DefaultSO101FullCalibration, false) + if err != nil { + t.Fatalf("first GetController: %v", err) + } + second, err := registry.GetController(port, cfg, DefaultSO101FullCalibration, false) + if err != nil { + t.Fatalf("second GetController: %v", err) + } + + if first != second { + t.Errorf("expected same pointer for same port; got %p and %p", first, second) + } + if first != ctrl { + t.Errorf("expected cached controller pointer to be returned") + } +} + +// TestRegistry_ReleaseClosesAllConsumers verifies that ReleaseController +// at refcount zero closes the bus and sets the closed flag on the shared +// controller, so other holders observe ErrControllerClosed on next call. +func TestRegistry_ReleaseClosesAllConsumers(t *testing.T) { + registry := NewControllerRegistry() + port := "/dev/test-port" + cfg := testConfig(port) + + bus, _ := newMockBus(t) + ctrl := &SafeSoArmController{ + bus: bus, + logger: cfg.Logger, + } + registry.entries[port] = &ControllerEntry{ + controller: ctrl, + config: cfg, + calibration: DefaultSO101FullCalibration, + refCount: 2, // simulate arm + gripper both holding + } + + // First release: refcount drops to 1, controller stays alive. + registry.ReleaseController(port) + if ctrl.closed.Load() { + t.Fatalf("controller closed prematurely at refcount > 0") + } + + // Second release: refcount drops to 0, controller closes. + registry.ReleaseController(port) + if !ctrl.closed.Load() { + t.Errorf("expected controller.closed=true after final release") + } + if err := ctrl.Ping(t.Context()); !errors.Is(err, ErrControllerClosed) { + t.Errorf("Ping after final release: expected ErrControllerClosed, got %v", err) + } +} + +// TestRegistry_ExplicitPortReleaseDecrementsRefcount verifies that callers +// can release a controller by passing the port path directly, with no +// dependence on runtime.Caller PC tracking. +func TestRegistry_ExplicitPortReleaseDecrementsRefcount(t *testing.T) { + registry := NewControllerRegistry() + port := "/dev/test-port" + cfg := testConfig(port) + bus, _ := newMockBus(t) + registry.entries[port] = &ControllerEntry{ + controller: &SafeSoArmController{bus: bus, logger: cfg.Logger}, + config: cfg, + refCount: 3, + } + + registry.ReleaseController(port) + + got := atomic.LoadInt64(®istry.entries[port].refCount) + if got != 2 { + t.Errorf("expected refCount=2 after release, got %d", got) + } +} + +// TestRegistry_ReleaseUnknownPortIsNoop verifies that releasing a port that +// was never registered does not panic and does not affect other entries. +func TestRegistry_ReleaseUnknownPortIsNoop(t *testing.T) { + registry := NewControllerRegistry() + registry.ReleaseController("/dev/never-existed") +} + +// File-scope signature assertions: these guarantee the ctx-threading helpers +// keep their (ctx context.Context) parameter. Dropping ctx would fail to +// compile here — caught at build time, not at test runtime. +var ( + _ func(context.Context) error = (*so101)(nil).doServoInitialization + _ func(context.Context) error = (*so101)(nil).diagnoseConnection + _ func(context.Context) error = (*so101)(nil).verifyServoConfig + _ func(context.Context) error = (*so101)(nil).initializeServos + _ func(context.Context, int) error = (*so101)(nil).initializeServosWithRetry +) diff --git a/manager.go b/manager.go index 0eb0d07..d91b085 100644 --- a/manager.go +++ b/manager.go @@ -2,6 +2,7 @@ package so_arm import ( "context" + "errors" "fmt" "math" "sync" @@ -12,6 +13,11 @@ import ( "go.viam.com/rdk/utils" ) +// ErrControllerClosed is returned by SafeSoArmController methods after the +// underlying bus has been closed via the registry. Callers holding a stale +// reference should treat this as a permanent failure for that controller. +var ErrControllerClosed = errors.New("so101: controller is closed") + // isGripperServo checks if a servo ID is the gripper (servo 6) func isGripperServo(servoID int) bool { return servoID == 6 @@ -26,9 +32,22 @@ type SafeSoArmController struct { logger logging.Logger calibration SO101FullCalibration mu sync.RWMutex + closed atomic.Bool +} + +// checkClosed returns ErrControllerClosed if the controller has been released. +// checkClosed is best-effort: callers must hold a registry refcount for the duration of any controller call. A concurrent ReleaseController can race the unlocked Load() with an in-flight method. +func (s *SafeSoArmController) checkClosed() error { + if s.closed.Load() { + return ErrControllerClosed + } + return nil } func (s *SafeSoArmController) MoveToJointPositions(ctx context.Context, jointAngles []float64, speed, acc int) error { + if err := s.checkClosed(); err != nil { + return err + } s.mu.Lock() defer s.mu.Unlock() @@ -58,6 +77,9 @@ func (s *SafeSoArmController) MoveToJointPositions(ctx context.Context, jointAng } func (s *SafeSoArmController) MoveServosToPositions(ctx context.Context, servoIDs []int, jointAngles []float64, speed, acc int) error { + if err := s.checkClosed(); err != nil { + return err + } s.mu.Lock() defer s.mu.Unlock() @@ -92,6 +114,9 @@ func (s *SafeSoArmController) MoveServosToPositions(ctx context.Context, servoID } func (s *SafeSoArmController) GetJointPositions(ctx context.Context) ([]float64, error) { + if err := s.checkClosed(); err != nil { + return nil, err + } s.mu.RLock() defer s.mu.RUnlock() @@ -130,6 +155,9 @@ func (s *SafeSoArmController) GetJointPositions(ctx context.Context) ([]float64, } func (s *SafeSoArmController) GetJointPositionsForServos(ctx context.Context, servoIDs []int) ([]float64, error) { + if err := s.checkClosed(); err != nil { + return nil, err + } s.mu.RLock() defer s.mu.RUnlock() @@ -159,6 +187,9 @@ func (s *SafeSoArmController) GetJointPositionsForServos(ctx context.Context, se } func (s *SafeSoArmController) SetTorqueEnable(ctx context.Context, enable bool) error { + if err := s.checkClosed(); err != nil { + return err + } s.mu.Lock() defer s.mu.Unlock() @@ -175,6 +206,9 @@ func (s *SafeSoArmController) SetTorqueEnable(ctx context.Context, enable bool) } func (s *SafeSoArmController) Stop(ctx context.Context) error { + if err := s.checkClosed(); err != nil { + return err + } s.mu.Lock() defer s.mu.Unlock() @@ -190,6 +224,7 @@ func (s *SafeSoArmController) Close() error { s.mu.Lock() defer s.mu.Unlock() + s.closed.Store(true) if s.bus != nil { return s.bus.Close() } @@ -197,6 +232,9 @@ func (s *SafeSoArmController) Close() error { } func (s *SafeSoArmController) Ping(ctx context.Context) error { + if err := s.checkClosed(); err != nil { + return err + } s.mu.RLock() defer s.mu.RUnlock() @@ -210,6 +248,9 @@ func (s *SafeSoArmController) Ping(ctx context.Context) error { // WriteServoRegister writes to a specific servo register by name func (s *SafeSoArmController) WriteServoRegister(ctx context.Context, servoID int, registerName string, data []byte) error { + if err := s.checkClosed(); err != nil { + return err + } s.mu.Lock() defer s.mu.Unlock() @@ -222,6 +263,9 @@ func (s *SafeSoArmController) WriteServoRegister(ctx context.Context, servoID in } func (s *SafeSoArmController) SetCalibration(calibration SO101FullCalibration) error { + if err := s.checkClosed(); err != nil { + return err + } s.mu.Lock() defer s.mu.Unlock() @@ -296,10 +340,6 @@ func GetSharedControllerWithCalibration(config *SoArm101Config, calibration SO10 return globalRegistry.GetController(config.Port, config, calibration, fromFile) } -func ReleaseSharedController() { - globalRegistry.releaseFromCaller() -} - func ForceCloseSharedController() error { globalRegistry.mu.RLock() portPaths := make([]string, 0, len(globalRegistry.entries)) @@ -351,12 +391,6 @@ func GetControllerStatus() (int64, bool, string) { return totalRefCount, hasController, configSummary } -// With multiple controllers, this returns the default calibration -// Use GetCurrentCalibrationForPort for port-specific calibration -func GetCurrentCalibration() SO101FullCalibration { - return DefaultSO101FullCalibration -} - func GetCurrentCalibrationForPort(portPath string) SO101FullCalibration { return globalRegistry.GetCurrentCalibration(portPath) } diff --git a/mock_bus_test.go b/mock_bus_test.go new file mode 100644 index 0000000..d283ec2 --- /dev/null +++ b/mock_bus_test.go @@ -0,0 +1,143 @@ +package so_arm + +import ( + "io" + "sync" + "testing" + "time" + + "github.com/hipsterbrown/feetech-servo/feetech" + "go.viam.com/rdk/logging" +) + +// scriptedMockTransport is a custom Transport (not feetech.MockTransport) whose +// Read responses are queued per-request. We use a custom type rather than the +// upstream MockTransport because tests in PR2-PR5 need to script multiple +// round-trips with different responses, and MockTransport's single ReadData +// buffer doesn't support that pattern cleanly. +type scriptedMockTransport struct { + mu sync.Mutex + written []byte + responses [][]byte + closed bool +} + +func (m *scriptedMockTransport) Read(p []byte) (int, error) { + m.mu.Lock() + defer m.mu.Unlock() + if len(m.responses) == 0 { + // Match feetech.MockTransport semantics: returning io.EOF on an empty + // queue lets feetech.Bus.readRawBytesLocked fall into its 1ms-sleep + // retry path instead of busy-spinning under the bus mutex. Important + // for PR5's planned 100Hz calibration-sensor reader. + return 0, io.EOF + } + resp := m.responses[0] + n := copy(p, resp) + if n >= len(resp) { + m.responses = m.responses[1:] + } else { + m.responses[0] = resp[n:] + } + return n, nil +} + +func (m *scriptedMockTransport) Write(p []byte) (int, error) { + m.mu.Lock() + defer m.mu.Unlock() + m.written = append(m.written, p...) + return len(p), nil +} + +func (m *scriptedMockTransport) Close() error { + m.mu.Lock() + defer m.mu.Unlock() + m.closed = true + return nil +} + +func (m *scriptedMockTransport) SetReadTimeout(timeout time.Duration) error { + return nil +} + +func (m *scriptedMockTransport) Flush() error { + // No-op: tests that need to drop unconsumed responses should clear + // m.responses explicitly. Mirroring SerialTransport.Flush would risk + // silently swallowing scripted frames between operations. + return nil +} + +// queueResponse appends a raw frame to the mock's response queue. +func (m *scriptedMockTransport) queueResponse(b []byte) { + m.mu.Lock() + defer m.mu.Unlock() + m.responses = append(m.responses, b) +} + +// reset clears the written-data buffer (responses queue is preserved). +func (m *scriptedMockTransport) resetWritten() { + m.mu.Lock() + defer m.mu.Unlock() + m.written = nil +} + +// snapshotWritten returns a copy of all bytes written since the last reset. +func (m *scriptedMockTransport) snapshotWritten() []byte { + m.mu.Lock() + defer m.mu.Unlock() + out := make([]byte, len(m.written)) + copy(out, m.written) + return out +} + +// pingResponse builds the ping ACK frame for a given servo ID using the STS protocol. +// Frame: 0xFF 0xFF 0x02 0x00 +func pingResponse(servoID byte) []byte { + checksum := ^(servoID + 0x02 + 0x00) + return []byte{0xFF, 0xFF, servoID, 0x02, 0x00, checksum} +} + +// newMockBus returns a *feetech.Bus backed by a scriptedMockTransport. +// Caller is responsible for queuing responses before issuing bus operations. +func newMockBus(t *testing.T) (*feetech.Bus, *scriptedMockTransport) { + t.Helper() + mock := &scriptedMockTransport{} + bus, err := feetech.NewBus(feetech.BusConfig{ + Transport: mock, + Protocol: feetech.ProtocolSTS, + Timeout: 100 * time.Millisecond, + }) + if err != nil { + t.Fatalf("newMockBus: NewBus failed: %v", err) + } + t.Cleanup(func() { _ = bus.Close() }) + return bus, mock +} + +// newTestLogger returns a logger that discards output unless the test fails. +func newTestLogger(t *testing.T) logging.Logger { + t.Helper() + return logging.NewTestLogger(t) +} + +func TestMockBus_PingRoundTrip(t *testing.T) { + bus, mock := newMockBus(t) + // Bus.Ping does two round trips: a ping ACK, then a model-number register read. + // Queue both responses so the call completes. + mock.queueResponse(pingResponse(1)) + // Model number 777 (0x0309) read response: FF FF + mock.queueResponse([]byte{0xFF, 0xFF, 0x01, 0x04, 0x00, 0x09, 0x03, 0xEE}) + + if _, err := bus.Ping(t.Context(), 1); err != nil { + t.Fatalf("ping failed: %v", err) + } + + written := mock.snapshotWritten() + if len(written) < 6 { + t.Fatalf("expected ping packet to be written, got %d bytes: %X", len(written), written) + } + // Ping packet: 0xFF 0xFF + if written[0] != 0xFF || written[1] != 0xFF || written[2] != 0x01 { + t.Errorf("malformed ping packet: %X", written[:6]) + } +} diff --git a/registry.go b/registry.go index 2493377..5479d1f 100644 --- a/registry.go +++ b/registry.go @@ -3,7 +3,6 @@ package so_arm import ( "context" "fmt" - "runtime" "strings" "sync" "sync/atomic" @@ -24,16 +23,11 @@ type ControllerEntry struct { type ControllerRegistry struct { entries map[string]*ControllerEntry // port path -> entry mu sync.RWMutex - - // For backward API compatibility - track which caller uses which port - callerPorts map[uintptr]string // caller pointer -> port path - callerMu sync.RWMutex } func NewControllerRegistry() *ControllerRegistry { return &ControllerRegistry{ - entries: make(map[string]*ControllerEntry), - callerPorts: make(map[uintptr]string), + entries: make(map[string]*ControllerEntry), } } @@ -57,7 +51,13 @@ func (r *ControllerRegistry) getExistingController(entry *ControllerEntry, confi if entry.lastError != nil { return nil, fmt.Errorf("cached controller creation error: %w", entry.lastError) } - return nil, fmt.Errorf("controller not available for port %s", entry.config.Port) + // entry.config may have been nil'd by a concurrent ReleaseController + // at refcount zero. Prefer the caller's config to avoid a nil deref. + port := config.Port + if entry.config != nil { + port = entry.config.Port + } + return nil, fmt.Errorf("controller not available for port %s", port) } if !configsEqual(entry.config, config) { @@ -93,15 +93,10 @@ func (r *ControllerRegistry) getExistingController(entry *ControllerEntry, confi } atomic.AddInt64(&entry.refCount, 1) - r.trackCaller(entry.config.Port) - return &SafeSoArmController{ - bus: entry.controller.bus, - group: entry.controller.group, - calibratedServos: entry.controller.calibratedServos, - logger: config.Logger, - calibration: entry.calibration, - }, nil + // Return the cached pointer so that all consumers observe close-state + // (and any future calibration updates) atomically. + return entry.controller, nil } func (r *ControllerRegistry) createNewController(portPath string, config *SoArm101Config, calibration SO101FullCalibration, fromFile bool) (*SafeSoArmController, error) { @@ -211,19 +206,11 @@ func (r *ControllerRegistry) createNewController(portPath string, config *SoArm1 r.entries[portPath] = entry - r.trackCaller(portPath) - if config.Logger != nil { config.Logger.Debugf("Created new feetech servo bus with %d servos for port %s", len(calibratedServos), portPath) } - return &SafeSoArmController{ - bus: bus, - group: group, - calibratedServos: calibratedServos, - logger: config.Logger, - calibration: finalCalibration, - }, nil + return entry.controller, nil } func (r *ControllerRegistry) ReleaseController(portPath string) { @@ -240,9 +227,12 @@ func (r *ControllerRegistry) ReleaseController(portPath string) { currentRefCount := atomic.AddInt64(&entry.refCount, -1) if currentRefCount <= 0 { - if entry.controller != nil && entry.controller.bus != nil { - if err := entry.controller.bus.Close(); err != nil && entry.config != nil && entry.config.Logger != nil { - entry.config.Logger.Warnf("error closing shared controller for port %s: %v", portPath, err) + if entry.controller != nil { + entry.controller.closed.Store(true) + if entry.controller.bus != nil { + if err := entry.controller.bus.Close(); err != nil && entry.config != nil && entry.config.Logger != nil { + entry.config.Logger.Warnf("error closing shared controller for port %s: %v", portPath, err) + } } } @@ -275,6 +265,7 @@ func (r *ControllerRegistry) ForceCloseController(portPath string) error { var err error if entry.controller != nil { + entry.controller.closed.Store(true) err = entry.controller.bus.Close() entry.controller = nil entry.config = nil @@ -329,36 +320,6 @@ func (r *ControllerRegistry) GetCurrentCalibration(portPath string) SO101FullCal return entry.calibration } -func (r *ControllerRegistry) trackCaller(portPath string) { - pc, _, _, ok := runtime.Caller(3) // 3 levels up to get the actual caller - if !ok { - return - } - - r.callerMu.Lock() - r.callerPorts[pc] = portPath - r.callerMu.Unlock() -} - -func (r *ControllerRegistry) releaseFromCaller() { - pc, _, _, ok := runtime.Caller(2) // 2 levels up to get the actual caller - if !ok { - return - } - - r.callerMu.RLock() - portPath, exists := r.callerPorts[pc] - r.callerMu.RUnlock() - - if exists { - r.ReleaseController(portPath) - - r.callerMu.Lock() - delete(r.callerPorts, pc) - r.callerMu.Unlock() - } -} - // compareConfigs returns a string describing the differences between two configs func compareConfigs(a, b *SoArm101Config) string { diffs := []string{} diff --git a/registry_test.go b/registry_test.go index da03fa7..c0a2696 100644 --- a/registry_test.go +++ b/registry_test.go @@ -38,10 +38,6 @@ func TestRegistryCreation(t *testing.T) { t.Fatal("Registry entries map not initialized") } - if registry.callerPorts == nil { - t.Fatal("Registry callerPorts map not initialized") - } - if len(registry.entries) != 0 { t.Fatal("Registry should start empty") }