Skip to content

Commit 20f7db9

Browse files
committed
Make Session.Host atomic
Before was guarded by Mutex and was used unsafely
1 parent 390e261 commit 20f7db9

6 files changed

Lines changed: 90 additions & 67 deletions

File tree

internal/webrtc/sessions/manager/manager.go

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -115,10 +115,11 @@ func (manager *SessionManager) GetSessionStates(includePrivateStreams bool) (res
115115

116116
s.StatusLock.RUnlock()
117117

118-
if s.Host != nil {
119-
s.Host.TracksLock.RLock()
118+
host := s.Host.Load()
119+
if host != nil {
120+
host.TracksLock.RLock()
120121

121-
for _, audioTrack := range s.Host.AudioTracks {
122+
for _, audioTrack := range host.AudioTracks {
122123
streamSession.AudioTracks = append(
123124
streamSession.AudioTracks,
124125
session.AudioTrackState{
@@ -128,7 +129,7 @@ func (manager *SessionManager) GetSessionStates(includePrivateStreams bool) (res
128129
})
129130
}
130131

131-
for _, videoTrack := range s.Host.VideoTracks {
132+
for _, videoTrack := range host.VideoTracks {
132133
var lastKeyFrame time.Time
133134
if value, ok := videoTrack.LastKeyFrame.Load().(time.Time); ok {
134135
lastKeyFrame = value
@@ -145,7 +146,7 @@ func (manager *SessionManager) GetSessionStates(includePrivateStreams bool) (res
145146
})
146147
}
147148

148-
s.Host.TracksLock.RUnlock()
149+
host.TracksLock.RUnlock()
149150
}
150151

151152
s.WhepSessionsLock.RLock()
@@ -199,13 +200,13 @@ func (manager *SessionManager) GetHostSessionById(sessionId string) (host *whip.
199200
defer manager.sessionsLock.RUnlock()
200201

201202
for _, session := range manager.sessions {
202-
203-
if session.Host == nil {
204-
return nil, false
203+
host := session.Host.Load()
204+
if host == nil {
205+
continue
205206
}
206207

207-
if sessionId == session.Host.Id {
208-
return session.Host, true
208+
if sessionId == host.Id {
209+
return host, true
209210
}
210211
}
211212

internal/webrtc/sessions/session/routines.go

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -68,13 +68,16 @@ func (session *Session) handleWhepVideoRtcpSender(rtcpSender *webrtc.RTPSender)
6868
return
6969
}
7070

71-
if session.HasHost.Load() {
72-
for _, packet := range rtcpPackets {
73-
if _, isPLI := packet.(*rtcp.PictureLossIndication); isPLI {
74-
select {
75-
case session.Host.PacketLossIndicationChannel <- true:
76-
default:
77-
}
71+
host := session.Host.Load()
72+
if host == nil {
73+
continue
74+
}
75+
76+
for _, packet := range rtcpPackets {
77+
if _, isPLI := packet.(*rtcp.PictureLossIndication); isPLI {
78+
select {
79+
case host.PacketLossIndicationChannel <- true:
80+
default:
7881
}
7982
}
8083
}
@@ -89,8 +92,13 @@ func (session *Session) handleWhepChannels(whepSession *whep.WhepSession) {
8992
return
9093

9194
case <-whepSession.ConnectionChannel:
95+
host := session.Host.Load()
96+
if host == nil {
97+
continue
98+
}
99+
92100
select {
93-
case session.Host.PacketLossIndicationChannel <- true:
101+
case host.PacketLossIndicationChannel <- true:
94102
default:
95103
}
96104
}
@@ -106,7 +114,7 @@ func (session *Session) hostStatusLoop() {
106114
defer ticker.Stop()
107115

108116
for {
109-
host := session.Host
117+
host := session.Host.Load()
110118
if host == nil {
111119
if session.isEmpty() {
112120
session.close()
@@ -131,7 +139,7 @@ func (session *Session) hostStatusLoop() {
131139
case <-ticker.C:
132140
if session.isEmpty() {
133141
session.close()
134-
} else if session.Host != nil {
142+
} else if session.Host.Load() != nil {
135143

136144
status := session.GetSessionStatsEvent()
137145
session.WhepSessionsLock.RLock()
@@ -153,12 +161,12 @@ func (session *Session) Snapshot() {
153161
for {
154162
select {
155163
case <-session.ActiveContext.Done():
156-
if session.Host != nil {
157-
session.Host.WhepSessionsSnapshot.Store(make(map[string]*whep.WhepSession))
164+
if host := session.Host.Load(); host != nil {
165+
host.WhepSessionsSnapshot.Store(make(map[string]*whep.WhepSession))
158166
}
159167
return
160168
case <-ticker.C:
161-
if session.Host != nil {
169+
if host := session.Host.Load(); host != nil {
162170
session.WhepSessionsLock.RLock()
163171
snapshot := make(map[string]*whep.WhepSession, len(session.WhepSessions))
164172

@@ -169,7 +177,7 @@ func (session *Session) Snapshot() {
169177
}
170178
session.WhepSessionsLock.RUnlock()
171179

172-
session.Host.WhepSessionsSnapshot.Store(snapshot)
180+
host.WhepSessionsSnapshot.Store(snapshot)
173181
}
174182
}
175183
}

internal/webrtc/sessions/session/session.go

Lines changed: 41 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,11 @@ import (
1717
func (session *Session) GetHost(streamKey string) (host *whip.WhipSession, foundSession bool) {
1818
log.Println("Session.GetHost")
1919

20-
if session.Host == nil {
20+
host = session.Host.Load()
21+
if host == nil {
2122
return nil, false
2223
}
2324

24-
session.HostLock.RLock()
25-
host = session.Host
26-
session.HostLock.RUnlock()
27-
2825
return host, true
2926
}
3027

@@ -56,16 +53,17 @@ func (session *Session) UpdateStreamStatus(profile authorization.PublicProfile)
5653
func (session *Session) AddWhep(whepSessionId string, peerConnection *webrtc.PeerConnection, audioTrack *codecs.TrackMultiCodec, videoTrack *codecs.TrackMultiCodec, videoRtcpSender *webrtc.RTPSender) (err error) {
5754
log.Println("WhipSessionManager.WhipSession.AddWhepSession")
5855

59-
if session.Host == nil {
56+
host := session.Host.Load()
57+
if host == nil {
6058
return fmt.Errorf("no host was found on the current session")
6159
}
6260

6361
whepSession := whep.CreateNewWhep(
6462
whepSessionId,
6563
audioTrack,
66-
session.Host.GetHighestPrioritizedAudioTrack(),
64+
host.GetHighestPrioritizedAudioTrack(),
6765
videoTrack,
68-
session.Host.GetHighestPrioritizedVideoTrack(),
66+
host.GetHighestPrioritizedVideoTrack(),
6967
peerConnection)
7068

7169
whepSession.RegisterWhepHandlers(peerConnection)
@@ -74,12 +72,12 @@ func (session *Session) AddWhep(whepSessionId string, peerConnection *webrtc.Pee
7472
session.WhepSessions[whepSessionId] = whepSession
7573
session.WhepSessionsLock.Unlock()
7674

77-
go session.handleWhepConnection(session.Host, whepSession)
75+
go session.handleWhepConnection(host, whepSession)
7876
go session.handleWhepChannels(whepSession)
7977
go session.handleWhepVideoRtcpSender(videoRtcpSender)
8078

8179
// TODO: Implement
82-
// go session.handleWhepLayerChange(session.Host, whepSession)
80+
// go session.handleWhepLayerChange(host, whepSession)
8381

8482
return nil
8583
}
@@ -88,20 +86,24 @@ func (session *Session) AddWhep(whepSessionId string, peerConnection *webrtc.Pee
8886
func (session *Session) AddHost(peerConnection *webrtc.PeerConnection) (err error) {
8987
log.Println("Session.AddHost")
9088

91-
session.HostLock.Lock()
89+
for {
90+
host := session.Host.Load()
91+
if host == nil {
92+
break
93+
}
9294

93-
if session.Host != nil && session.Host.PeerConnection.ConnectionState() == webrtc.PeerConnectionStateClosed {
94-
if session.ActiveContext.Err() != nil {
95-
session.Host = nil
96-
} else {
97-
session.HostLock.Unlock()
95+
if host.PeerConnection.ConnectionState() != webrtc.PeerConnectionStateClosed || session.ActiveContext.Err() == nil {
9896
return fmt.Errorf("session already has a host")
9997
}
98+
99+
if session.Host.CompareAndSwap(host, nil) {
100+
break
101+
}
100102
}
101103

102104
activeContext, activeContextCancel := context.WithCancel(context.Background())
103105

104-
session.Host = &whip.WhipSession{
106+
host := &whip.WhipSession{
105107
Id: uuid.New().String(),
106108
AudioTracks: make(map[string]*whip.AudioTrack),
107109
VideoTracks: make(map[string]*whip.VideoTrack),
@@ -113,8 +115,13 @@ func (session *Session) AddHost(peerConnection *webrtc.PeerConnection) (err erro
113115
ActiveContextCancel: activeContextCancel,
114116
}
115117

116-
session.Host.AddPeerConnection(peerConnection, session.StreamKey)
117-
session.HostLock.Unlock()
118+
host.AddPeerConnection(peerConnection, session.StreamKey)
119+
if !session.Host.CompareAndSwap(nil, host) {
120+
host.ActiveContextCancel()
121+
host.RemovePeerConnection()
122+
host.RemoveTracks()
123+
return fmt.Errorf("session already has a host")
124+
}
118125

119126
go session.hostStatusLoop()
120127

@@ -123,20 +130,17 @@ func (session *Session) AddHost(peerConnection *webrtc.PeerConnection) (err erro
123130

124131
func (session *Session) RemoveHost() {
125132

126-
if session.Host == nil {
133+
host := session.Host.Swap(nil)
134+
if host == nil {
127135
log.Println("Session.RemoveHost", session.StreamKey, "- No host to remove")
128136
return
129137
}
130138

131139
log.Println("Session.RemoveHost", session.StreamKey)
132140

133-
session.Host.ActiveContextCancel()
134-
session.Host.RemovePeerConnection()
135-
session.Host.RemoveTracks()
136-
137-
session.HostLock.Lock()
138-
session.Host = nil
139-
session.HostLock.Unlock()
141+
host.ActiveContextCancel()
142+
host.RemovePeerConnection()
143+
host.RemoveTracks()
140144
}
141145

142146
// Remove Whep session from Whip session
@@ -197,24 +201,25 @@ func (session *Session) isEmpty() bool {
197201
// Returns true if any tracks are available for the session
198202
func (session *Session) isStreaming() bool {
199203

200-
if session.Host == nil {
204+
host := session.Host.Load()
205+
if host == nil {
201206
return false
202207
}
203208

204-
session.Host.TracksLock.RLock()
209+
host.TracksLock.RLock()
205210

206-
if len(session.Host.AudioTracks) != 0 {
207-
log.Println("Session.IsActive.AudioTracks", len(session.Host.AudioTracks))
208-
session.Host.TracksLock.RUnlock()
211+
if len(host.AudioTracks) != 0 {
212+
log.Println("Session.IsActive.AudioTracks", len(host.AudioTracks))
213+
host.TracksLock.RUnlock()
209214
return true
210215
}
211-
if len(session.Host.VideoTracks) != 0 {
212-
log.Println("Session.IsActive.VideoTracks", len(session.Host.VideoTracks))
213-
session.Host.TracksLock.RUnlock()
216+
if len(host.VideoTracks) != 0 {
217+
log.Println("Session.IsActive.VideoTracks", len(host.VideoTracks))
218+
host.TracksLock.RUnlock()
214219
return true
215220
}
216221

217-
session.Host.TracksLock.RUnlock()
222+
host.TracksLock.RUnlock()
218223
return false
219224
}
220225

internal/webrtc/sessions/session/type.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,7 @@ type Session struct {
2222
IsPublic bool
2323
StreamStart time.Time
2424

25-
HostLock sync.RWMutex
26-
Host *whip.WhipSession
25+
Host atomic.Pointer[whip.WhipSession]
2726

2827
// Context
2928
ActiveContext context.Context

internal/webrtc/webrtc.go

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,12 +68,17 @@ func HandleWhipPatch(sessionId, body string) error {
6868
return errors.New("no session found")
6969
}
7070

71-
session.Host.PeerConnectionLock.Lock()
72-
if err := patchPeerConnection(session.Host.PeerConnection, body); err != nil {
73-
session.Host.PeerConnectionLock.Unlock()
71+
host := session.Host.Load()
72+
if host == nil {
73+
return errors.New("no host found")
74+
}
75+
76+
host.PeerConnectionLock.Lock()
77+
if err := patchPeerConnection(host.PeerConnection, body); err != nil {
78+
host.PeerConnectionLock.Unlock()
7479
return err
7580
}
76-
session.Host.PeerConnectionLock.Unlock()
81+
host.PeerConnectionLock.Unlock()
7782

7883
return nil
7984
}

internal/webrtc/whip.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,13 @@ func WHIP(offer string, profile authorization.PublicProfile) (sdp string, sessio
3434
return "", "", err
3535
}
3636

37+
host := session.Host.Load()
38+
if host == nil {
39+
return "", "", errors.New("host session not available")
40+
}
41+
3742
sdp = utils.DebugOutputAnswer(utils.AppendCandidateToAnswer(peerConnection.LocalDescription().SDP))
38-
sessionId = session.Host.Id
43+
sessionId = host.Id
3944
err = nil
4045
log.Println("WHIP.Offer.Accepted", profile.StreamKey, profile.MOTD)
4146
return

0 commit comments

Comments
 (0)