diff --git a/pkg/sip/inbound.go b/pkg/sip/inbound.go index 2d5f7bd9..2282183e 100644 --- a/pkg/sip/inbound.go +++ b/pkg/sip/inbound.go @@ -312,6 +312,26 @@ func (s *Server) handleInviteAuth(tid traceid.ID, log logger.Logger, req *sip.Re return true, false } +func sdpBodyFromRequest(req *sip.Request) []byte { + ct := req.ContentType() + if ct != nil && ct.Value() != "application/sdp" { + return nil + } + return req.Body() +} + +func updateRemoteFromSDP(media *MediaPort, log logger.Logger, body []byte) { + if len(body) == 0 || media == nil { + return + } + desc, err := sdp.ParseWith(msdk.GlobalCodecs(), body) + if err != nil { + log.Warnw("failed to parse re-INVITE SDP, RTP destination not updated", err) + return + } + media.UpdateRemote(desc.Addr) +} + func (s *Server) onInvite(log *slog.Logger, req *sip.Request, tx sip.ServerTransaction) { // Error processed in defer _ = s.processInvite(req, tx) @@ -380,7 +400,8 @@ func (s *Server) processInvite(req *sip.Request, tx sip.ServerTransaction) (retE existing := s.byLocalTag[cc.ID()] s.cmu.RUnlock() if existing != nil && existing.cc.InviteCSeq() < cc.InviteCSeq() { - existing.log().Infow("reinvite", "content-type", req.ContentType(), "content-length", req.ContentLength(), "cseq", cc.InviteCSeq()) + existing.log().Infow("reinvite", "content-length", req.ContentLength(), "cseq", cc.InviteCSeq()) + updateRemoteFromSDP(existing.media, existing.log(), sdpBodyFromRequest(req)) cc.AcceptAsKeepAlive(existing.cc.OwnSDP()) return nil } @@ -388,11 +409,12 @@ func (s *Server) processInvite(req *sip.Request, tx sip.ServerTransaction) (retE oc := s.cli.getActiveCall(cc.ID()) newCSeq := cc.InviteCSeq() if oc != nil && oc.cc != nil && oc.cc.InviteCSeq() < newCSeq { - sdp := oc.cc.LocalSDP() - if len(sdp) != 0 { - oc.log.Infow("accepting reinvite", "content-type", req.ContentType(), "content-length", req.ContentLength(), "cseq", cc.InviteCSeq()) + localSDP := oc.cc.LocalSDP() + if len(localSDP) != 0 { + oc.log.Infow("accepting reinvite", "content-length", req.ContentLength(), "cseq", cc.InviteCSeq()) + updateRemoteFromSDP(oc.media, oc.log, sdpBodyFromRequest(req)) oc.cc.RecordInvite(newCSeq) - cc.AcceptAsKeepAlive(sdp) + cc.AcceptAsKeepAlive(localSDP) return nil } } diff --git a/pkg/sip/media_port.go b/pkg/sip/media_port.go index 00567906..9ab48b65 100644 --- a/pkg/sip/media_port.go +++ b/pkg/sip/media_port.go @@ -241,7 +241,7 @@ func (c *udpConn) SetDst(addr netip.AddrPort) { if addr.IsValid() { prev := c.dst.Swap(&addr) if prev == nil || !prev.IsValid() { - c.log.Infow("setting media destination", "prev", prev, "addr", addr.String()) + c.log.Infow("setting media destination", "addr", addr.String()) } else if *prev != addr { changeCount := c.dstChangeCount.Add(1) now := time.Now().UnixNano() @@ -257,7 +257,7 @@ func (c *udpConn) Read(b []byte) (n int, err error) { n, addr, err := c.ReadFromUDPAddrPort(b) prev := c.src.Swap(&addr) if prev == nil || !prev.IsValid() { - c.log.Infow("setting media source", "prev", prev, "addr", addr.String()) + c.log.Infow("setting media source", "addr", addr.String()) } else if *prev != addr { changeCount := c.srcChangeCount.Add(1) now := time.Now().UnixNano() @@ -630,6 +630,20 @@ func (p *MediaPort) Port() int { return p.port.LocalAddr().(*net.UDPAddr).Port } +func (p *MediaPort) RemoteAddr() netip.AddrPort { + dst := p.port.dst.Load() + if dst == nil { + return netip.AddrPort{} + } + return *dst +} + +func (p *MediaPort) UpdateRemote(addr netip.AddrPort) { + if addr.IsValid() && !addr.Addr().IsUnspecified() { + p.port.SetDst(addr) + } +} + func (p *MediaPort) Received() <-chan struct{} { return p.mediaReceived.Watch() } diff --git a/pkg/sip/media_port_test.go b/pkg/sip/media_port_test.go index f8d5823c..e2ee1222 100644 --- a/pkg/sip/media_port_test.go +++ b/pkg/sip/media_port_test.go @@ -165,6 +165,35 @@ func newIP(v string) netip.Addr { return ip } +func TestMediaPortUpdateRemote(t *testing.T) { + log := logger.GetLogger() + mon := newTestCallMonitor(t) + + // newUDPPipe wires two in-memory testUDPConn together. + c1, _ := newUDPPipe() + mp, err := NewMediaPortWith(1, log, mon, c1, &MediaOptions{ + IP: netip.MustParseAddr("127.0.0.1"), + }, 8000) + require.NoError(t, err) + defer mp.Close() + + // Initially no destination is set. + require.False(t, mp.RemoteAddr().IsValid(), "RemoteAddr should be invalid before any update") + + // Update to a valid address. + addr := netip.MustParseAddrPort("9.8.7.6:12345") + mp.UpdateRemote(addr) + require.Equal(t, addr, mp.RemoteAddr(), "RemoteAddr should reflect the updated address") + + // UpdateRemote with invalid addr should be a no-op. + mp.UpdateRemote(netip.AddrPort{}) + require.Equal(t, addr, mp.RemoteAddr(), "UpdateRemote with invalid addr should not change RemoteAddr") + + // UpdateRemote with unspecified address (c=0.0.0.0 hold form) should be a no-op. + mp.UpdateRemote(netip.MustParseAddrPort("0.0.0.0:12345")) + require.Equal(t, addr, mp.RemoteAddr(), "UpdateRemote with unspecified addr should not change RemoteAddr") +} + func TestMediaPort(t *testing.T) { // Main resampler has unpredictable (although tiny) output delay // and other randomness in the generated samples. diff --git a/pkg/sip/signaling_test.go b/pkg/sip/signaling_test.go index 389f4255..8744f46f 100644 --- a/pkg/sip/signaling_test.go +++ b/pkg/sip/signaling_test.go @@ -641,7 +641,7 @@ func TestReinvite(t *testing.T) { t.Run("inbound", func(t *testing.T) { t.Run("normal", func(t *testing.T) { st := NewServiceTest(t, nil) - call, _ := st.CreateInboundCall(t) + call, ic := st.CreateInboundCall(t) serverLocalSDP := call.remoteSDP // Re-INVITE @@ -661,6 +661,9 @@ func TestReinvite(t *testing.T) { resp = st.TestUA.TransactionRequest(t, req, true) require.Equal(t, sip.StatusCode(200), resp.StatusCode, "reinvite for outbound call should get 200 OK") require.Equal(t, serverLocalSDP, resp.Body(), "reinvite 200 OK should return server local SDP") + + // After the re-INVITE with new offer, the media port destination must be updated. + require.Equal(t, newOffer.Addr, ic.media.RemoteAddr(), "re-INVITE should redirect RTP to the new remote address") }) t.Run("miss", func(t *testing.T) { @@ -683,6 +686,20 @@ func TestReinvite(t *testing.T) { require.Equal(t, sip.StatusCode(200), resp.StatusCode, "reinvite for outbound call should get 200 OK") require.NotEqual(t, serverLocalSDP, resp.Body(), "reinvite for new call should return new server local SDP") }) + + t.Run("no_body", func(t *testing.T) { + st := NewServiceTest(t, nil) + call, ic := st.CreateInboundCall(t) + serverLocalSDP := call.remoteSDP + initialRemote := ic.media.RemoteAddr() + + // Re-INVITE with no SDP body — destination must not change. + req := call.NewRequest(sip.INVITE) // no body, no Content-Type + resp := st.TestUA.TransactionRequest(t, req, true) + require.Equal(t, sip.StatusCode(200), resp.StatusCode, "body-less re-INVITE should still get 200 OK") + require.Equal(t, serverLocalSDP, resp.Body(), "body-less re-INVITE should return server local SDP") + require.Equal(t, initialRemote, ic.media.RemoteAddr(), "body-less re-INVITE must not change RTP destination") + }) }) t.Run("outbound", func(t *testing.T) { t.Run("normal", func(t *testing.T) { @@ -708,6 +725,23 @@ func TestReinvite(t *testing.T) { resp = st.TestUA.TransactionRequest(t, req, false) require.Equal(t, sip.StatusCode(200), resp.StatusCode, "reinvite for outbound call should get 200 OK") require.Equal(t, serverLocalSDP, resp.Body(), "reinvite 200 OK should return server local SDP") + + // After the re-INVITE with new offer, the media port destination must be updated. + require.Equal(t, newOffer.Addr, oc.media.RemoteAddr(), "re-INVITE should redirect outbound call RTP to the new remote address") + }) + + t.Run("no_body", func(t *testing.T) { + st := NewServiceTest(t, nil) + call, oc, _ := st.CreateOutboundCall(t) + serverLocalSDP := oc.cc.LocalSDP() + initialRemote := oc.media.RemoteAddr() + + // Re-INVITE with no SDP body — destination must not change. + req := call.NewRequest(sip.INVITE) // no body, no Content-Type + resp := st.TestUA.TransactionRequest(t, req, false) + require.Equal(t, sip.StatusCode(200), resp.StatusCode, "body-less re-INVITE should still get 200 OK") + require.Equal(t, serverLocalSDP, resp.Body(), "body-less re-INVITE should return server local SDP") + require.Equal(t, initialRemote, oc.media.RemoteAddr(), "body-less re-INVITE must not change RTP destination") }) t.Run("miss", func(t *testing.T) {