Skip to content

Commit e7156e5

Browse files
Sean-DerNabos
andcommitted
Add Server Trickle ICE Support
Co-authored-by: Nabos <2191596-nabos@users.noreply.gitlab.com>
1 parent ee384f0 commit e7156e5

5 files changed

Lines changed: 157 additions & 9 deletions

File tree

internal/webrtc/webrtc.go

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package webrtc
33
import (
44
"context"
55
"encoding/json"
6+
"errors"
67
"fmt"
78
"io"
89
"log"
@@ -51,6 +52,8 @@ type (
5152
whipActiveContext context.Context
5253
whipActiveContextCancel func()
5354

55+
peerConnection atomic.Pointer[webrtc.PeerConnection]
56+
5457
whepSessionsLock sync.RWMutex
5558
whepSessions map[string]*whepSession
5659
}
@@ -70,6 +73,9 @@ var (
7073
streamMapLock sync.Mutex
7174
apiWhip, apiWhep *webrtc.API
7275

76+
errNoPeerConnection = errors.New("unable to find PeerConnection")
77+
errICERestartNotSupported = errors.New("ice restart not supported")
78+
7379
// nolint
7480
videoRTCPFeedback = []webrtc.RTCPFeedback{{"goog-remb", ""}, {"ccm", "fir"}, {"nack", ""}, {"nack", "pli"}}
7581
)
@@ -510,3 +516,59 @@ func GetStreamStatuses() []StreamStatus {
510516

511517
return out
512518
}
519+
520+
func HandlePatch(sessionId, body string, isWHIP bool) error {
521+
valueForKey := func(sdp, key string) string {
522+
for _, l := range strings.Split(sdp, "\n") {
523+
expectedPrefix := "a=" + key + ":"
524+
if strings.HasPrefix(l, expectedPrefix) {
525+
return strings.TrimPrefix(l, expectedPrefix)
526+
}
527+
}
528+
529+
return ""
530+
}
531+
532+
var peerConnection *webrtc.PeerConnection
533+
534+
streamMapLock.Lock()
535+
if isWHIP {
536+
if stream := streamMap[sessionId]; stream != nil {
537+
peerConnection = stream.peerConnection.Load()
538+
}
539+
} else {
540+
for _, s := range streamMap {
541+
s.whepSessionsLock.Lock()
542+
if whepSession := s.whepSessions[sessionId]; whepSession != nil {
543+
peerConnection = whepSession.peerConnection
544+
}
545+
s.whepSessionsLock.Unlock()
546+
}
547+
}
548+
streamMapLock.Unlock()
549+
550+
if peerConnection == nil {
551+
return errNoPeerConnection
552+
}
553+
554+
oldUfrag := valueForKey(peerConnection.CurrentRemoteDescription().SDP, "ice-ufrag")
555+
oldPwd := valueForKey(peerConnection.CurrentRemoteDescription().SDP, "ice-pwd")
556+
newUfrag, newPwd := valueForKey(body, "ice-ufrag"), valueForKey(body, "ice-pwd")
557+
isICERestart := oldUfrag != newUfrag || oldPwd != newPwd
558+
if isICERestart {
559+
return errICERestartNotSupported
560+
}
561+
562+
for _, l := range strings.Split(body, "\n") {
563+
expectedPrefix := "a=candidate:"
564+
if strings.HasPrefix(l, expectedPrefix) {
565+
if err := peerConnection.AddICECandidate(webrtc.ICECandidateInit{
566+
Candidate: strings.TrimSpace(strings.TrimPrefix(l, "a=")),
567+
}); err != nil {
568+
return err
569+
}
570+
}
571+
}
572+
573+
return nil
574+
}

internal/webrtc/webrtc_test.go

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
package webrtc
2+
3+
import (
4+
"context"
5+
"strings"
6+
"testing"
7+
8+
"github.com/pion/webrtc/v4"
9+
"github.com/stretchr/testify/require"
10+
)
11+
12+
func TestICETrickle(t *testing.T) {
13+
Configure()
14+
localTrack, err := webrtc.NewTrackLocalStaticSample(
15+
webrtc.RTPCodecCapability{MimeType: webrtc.MimeTypeH264}, "video", "pion",
16+
)
17+
require.NoError(t, err)
18+
19+
ctx, done := context.WithCancel(context.TODO())
20+
21+
peerConnection, err := webrtc.NewPeerConnection(webrtc.Configuration{})
22+
require.NoError(t, err)
23+
24+
peerConnection.OnConnectionStateChange(func(c webrtc.PeerConnectionState) {
25+
if c == webrtc.PeerConnectionStateConnected {
26+
done()
27+
}
28+
})
29+
30+
peerConnection.OnICECandidate(func(_ *webrtc.ICECandidate) {
31+
require.NoError(t, HandlePatch(testStreamKey, peerConnection.LocalDescription().SDP, true))
32+
})
33+
34+
_, err = peerConnection.AddTrack(localTrack)
35+
require.NoError(t, err)
36+
37+
offer, err := peerConnection.CreateOffer(nil)
38+
require.NoError(t, err)
39+
require.NoError(t, peerConnection.SetLocalDescription(offer))
40+
41+
answer, err := WHIP(offer.SDP, testStreamKey)
42+
require.NoError(t, err)
43+
44+
noCandidateAnswer := ""
45+
for _, l := range strings.Split(answer, "\n") {
46+
if !strings.HasPrefix(l, "a=candidate:") {
47+
noCandidateAnswer += l + "\n"
48+
}
49+
}
50+
51+
require.NoError(t, peerConnection.SetRemoteDescription(webrtc.SessionDescription{
52+
Type: webrtc.SDPTypeAnswer,
53+
SDP: noCandidateAnswer,
54+
}))
55+
56+
<-ctx.Done()
57+
}

internal/webrtc/whep.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ type (
2121
sequenceNumber uint16
2222
timestamp uint32
2323
packetsWritten uint64
24+
peerConnection *webrtc.PeerConnection
2425
}
2526

2627
simulcastLayerResponse struct {
@@ -151,8 +152,9 @@ func WHEP(offer, streamKey string) (string, string, error) {
151152
defer stream.whepSessionsLock.Unlock()
152153

153154
stream.whepSessions[whepSessionId] = &whepSession{
154-
videoTrack: videoTrack,
155-
timestamp: 50000,
155+
videoTrack: videoTrack,
156+
timestamp: 50000,
157+
peerConnection: peerConnection,
156158
}
157159
stream.whepSessions[whepSessionId].currentLayer.Store("")
158160
stream.whepSessions[whepSessionId].waitingForKeyframe.Store(false)

internal/webrtc/whip.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ func WHIP(offer, streamKey string) (string, error) {
156156
if err != nil {
157157
return "", err
158158
}
159+
stream.peerConnection.Store(peerConnection)
159160

160161
peerConnection.OnTrack(func(remoteTrack *webrtc.TrackRemote, rtpReceiver *webrtc.RTPReceiver) {
161162
if strings.HasPrefix(remoteTrack.Codec().MimeType, "audio") {

main.go

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"fmt"
88
"io"
99
"log"
10+
"mime"
1011
"net/http"
1112
"os"
1213
"path"
@@ -76,8 +77,23 @@ func logHTTPError(w http.ResponseWriter, err string, code int) {
7677
http.Error(w, err, code)
7778
}
7879

80+
func patchHandler(res http.ResponseWriter, r *http.Request, sessionId, body string, isWHIP bool) {
81+
mediaType, _, err := mime.ParseMediaType(r.Header.Get("Content-Type"))
82+
if err != nil || mediaType != "application/trickle-ice-sdpfrag" {
83+
logHTTPError(res, "invalid content type", http.StatusUnsupportedMediaType)
84+
return
85+
}
86+
87+
if err = webrtc.HandlePatch(sessionId, body, isWHIP); err != nil {
88+
logHTTPError(res, err.Error(), http.StatusBadRequest)
89+
return
90+
}
91+
92+
res.WriteHeader(http.StatusNoContent)
93+
}
94+
7995
func whipHandler(res http.ResponseWriter, r *http.Request) {
80-
if r.Method != "POST" {
96+
if r.Method != "POST" && r.Method != "PATCH" {
8197
return
8298
}
8399

@@ -87,13 +103,18 @@ func whipHandler(res http.ResponseWriter, r *http.Request) {
87103
return
88104
}
89105

90-
offer, err := io.ReadAll(r.Body)
106+
body, err := io.ReadAll(r.Body)
91107
if err != nil {
92108
logHTTPError(res, err.Error(), http.StatusBadRequest)
93109
return
94110
}
95111

96-
answer, err := webrtc.WHIP(string(offer), streamKey)
112+
if r.Method == "PATCH" {
113+
patchHandler(res, r, streamKey, string(body), true)
114+
return
115+
}
116+
117+
answer, err := webrtc.WHIP(string(body), streamKey)
97118
if err != nil {
98119
logHTTPError(res, err.Error(), http.StatusBadRequest)
99120
return
@@ -108,7 +129,7 @@ func whipHandler(res http.ResponseWriter, r *http.Request) {
108129
}
109130

110131
func whepHandler(res http.ResponseWriter, req *http.Request) {
111-
if req.Method != "POST" {
132+
if req.Method != "POST" && req.Method != "PATCH" {
112133
return
113134
}
114135

@@ -118,13 +139,18 @@ func whepHandler(res http.ResponseWriter, req *http.Request) {
118139
return
119140
}
120141

121-
offer, err := io.ReadAll(req.Body)
142+
body, err := io.ReadAll(req.Body)
122143
if err != nil {
123144
logHTTPError(res, err.Error(), http.StatusBadRequest)
124145
return
125146
}
126147

127-
answer, whepSessionId, err := webrtc.WHEP(string(offer), streamKey)
148+
if req.Method == "PATCH" {
149+
patchHandler(res, req, "TODO", string(body), true)
150+
return
151+
}
152+
153+
answer, whepSessionId, err := webrtc.WHEP(string(body), streamKey)
128154
if err != nil {
129155
logHTTPError(res, err.Error(), http.StatusBadRequest)
130156
return
@@ -133,7 +159,7 @@ func whepHandler(res http.ResponseWriter, req *http.Request) {
133159
apiPath := req.Host + strings.TrimSuffix(req.URL.RequestURI(), "whep")
134160
res.Header().Add("Link", `<`+apiPath+"sse/"+whepSessionId+`>; rel="urn:ietf:params:whep:ext:core:server-sent-events"; events="layers"`)
135161
res.Header().Add("Link", `<`+apiPath+"layer/"+whepSessionId+`>; rel="urn:ietf:params:whep:ext:core:layer"`)
136-
res.Header().Add("Location", "/api/whep")
162+
res.Header().Add("Location", "/api/whep/"+whepSessionId)
137163
res.Header().Add("Content-Type", "application/sdp")
138164
res.WriteHeader(http.StatusCreated)
139165
if _, err = fmt.Fprint(res, answer); err != nil {

0 commit comments

Comments
 (0)