Skip to content

Commit f96b865

Browse files
committed
Use callbacks instead of channels for SSE
Reduce complexity and fix more deadlocks with disconnected SSE sessions
1 parent 4b011f1 commit f96b865

13 files changed

Lines changed: 242 additions & 211 deletions

File tree

internal/server/handlers/sse.go

Lines changed: 74 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,20 @@
11
package handlers
22

33
import (
4+
"context"
45
"errors"
56
"fmt"
67
"log"
78
"net/http"
89
"os"
910
"strings"
11+
"sync"
1012
"time"
1113

1214
"github.com/glimesh/broadcast-box/internal/environment"
1315
"github.com/glimesh/broadcast-box/internal/server/helpers"
1416
"github.com/glimesh/broadcast-box/internal/webrtc/sessions/manager"
17+
"github.com/google/uuid"
1518
)
1619

1720
func sseHandler(responseWriter http.ResponseWriter, request *http.Request) {
@@ -34,79 +37,94 @@ func sseHandler(responseWriter http.ResponseWriter, request *http.Request) {
3437
ctx := request.Context()
3538
responseController := http.NewResponseController(responseWriter)
3639

37-
// Setup WHEP/WHIP session for SSE feed
38-
sseChannel := getWhipSessionChannel(sessionId)
40+
var writeLock sync.Mutex
41+
writeEvent := func(writeCtx context.Context, msg string) bool {
42+
if msg == "" || writeCtx.Err() != nil {
43+
return false
44+
}
3945

40-
if sseChannel == nil {
41-
sseChannel = getWhepSessionChannel(sessionId)
42-
}
46+
writeLock.Lock()
47+
defer writeLock.Unlock()
4348

44-
if sseChannel == nil {
45-
helpers.LogHttpError(responseWriter, "Invalid request", http.StatusBadRequest)
46-
return
47-
}
49+
if debugSseMessages {
50+
log.Println("API.SSE Sending:", msg)
51+
}
4852

49-
for {
50-
select {
51-
case <-ctx.Done():
52-
log.Println("API.SSE: Client disconnected")
53-
return
53+
if err := responseController.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil && !errors.Is(err, http.ErrNotSupported) {
54+
log.Println("API.SSE SetWriteDeadline error:", err)
55+
return false
56+
}
5457

55-
case msg, ok := <-sseChannel:
56-
if debugSseMessages {
57-
log.Println("API.SSE Sending:", msg)
58-
}
58+
_, err := fmt.Fprintf(responseWriter, "%s\n", msg)
59+
if err == nil {
60+
flusher.Flush()
61+
}
5962

60-
if !ok || msg == "close" {
61-
log.Println("API.SSE: Channel closed")
62-
return
63-
}
63+
if deadlineErr := responseController.SetWriteDeadline(time.Time{}); deadlineErr != nil && !errors.Is(deadlineErr, http.ErrNotSupported) {
64+
log.Println("API.SSE ClearWriteDeadline error:", deadlineErr)
65+
return false
66+
}
6467

65-
if err := responseController.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil && !errors.Is(err, http.ErrNotSupported) {
66-
log.Println("API.SSE SetWriteDeadline error:", err)
67-
return
68+
if err != nil {
69+
if errors.Is(err, os.ErrDeadlineExceeded) {
70+
log.Println("API.SSE Write timeout")
71+
} else {
72+
log.Println("API.SSE Write error:", err)
6873
}
74+
return false
75+
}
6976

70-
_, err := fmt.Fprintf(responseWriter, "%s\n", msg)
71-
if err == nil {
72-
flusher.Flush()
73-
}
77+
return true
78+
}
7479

75-
if deadlineErr := responseController.SetWriteDeadline(time.Time{}); deadlineErr != nil && !errors.Is(deadlineErr, http.ErrNotSupported) {
76-
log.Println("API.SSE ClearWriteDeadline error:", deadlineErr)
77-
return
78-
}
80+
if streamSession, whepSession, foundSession := manager.SessionsManager.GetSessionAndWhepById(sessionId); foundSession {
81+
subscriberCtx, subscriberCancel := context.WithCancel(ctx)
82+
defer subscriberCancel()
7983

80-
if err != nil {
81-
if errors.Is(err, os.ErrDeadlineExceeded) {
82-
log.Println("API.SSE Write timeout")
83-
} else {
84-
log.Println("API.SSE Write error:", err)
85-
}
86-
return
87-
}
84+
subscriberID := uuid.NewString()
85+
subscriberWrite := func(msg string) bool {
86+
return writeEvent(subscriberCtx, msg)
8887
}
89-
}
90-
}
88+
if !whepSession.AddSSESubscriber(subscriberID, subscriberWrite, subscriberCancel) {
89+
helpers.LogHttpError(responseWriter, "Invalid request", http.StatusBadRequest)
90+
return
91+
}
92+
defer whepSession.RemoveSSESubscriber(subscriberID)
9193

92-
func getWhipSessionChannel(sessionId string) chan any {
93-
var channel chan any
94-
whipSession, ok := manager.SessionsManager.GetHostSessionById(sessionId)
94+
if !subscriberWrite(streamSession.GetSessionStatsEvent()) {
95+
return
96+
}
9597

96-
if ok {
97-
channel = whipSession.EventsChannel
98+
host := streamSession.Host.Load()
99+
if host != nil && !subscriberWrite(host.GetAvailableLayersEvent()) {
100+
return
101+
}
102+
103+
<-subscriberCtx.Done()
104+
log.Println("API.SSE: Client disconnected")
105+
return
98106
}
99107

100-
return channel
101-
}
108+
if streamSession, foundSession := manager.SessionsManager.GetSessionByHostSessionId(sessionId); foundSession {
109+
if !writeEvent(ctx, streamSession.GetSessionStatsEvent()) {
110+
return
111+
}
102112

103-
func getWhepSessionChannel(sessionId string) chan any {
104-
var channel chan any
105-
whepSession, ok := manager.SessionsManager.GetWhepSessionById(sessionId)
113+
ticker := time.NewTicker(5 * time.Second)
114+
defer ticker.Stop()
106115

107-
if ok {
108-
channel = whepSession.SseEventsChannel
116+
for {
117+
select {
118+
case <-ctx.Done():
119+
log.Println("API.SSE: Client disconnected")
120+
return
121+
case <-ticker.C:
122+
if !writeEvent(ctx, streamSession.GetSessionStatsEvent()) {
123+
return
124+
}
125+
}
126+
}
109127
}
110128

111-
return channel
129+
helpers.LogHttpError(responseWriter, "Invalid request", http.StatusBadRequest)
112130
}

internal/webrtc/sessions/manager/manager.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ func (manager *SessionManager) addSession(profile authorization.PublicProfile) (
4242
manager.sessions[profile.StreamKey] = s
4343
manager.sessionsLock.Unlock()
4444

45-
go s.Snapshot()
4645
go func() {
4746
<-activeContext.Done()
4847
log.Println("SessionManager.Session.Done")
@@ -175,6 +174,11 @@ func (manager *SessionManager) UpdateProfile(profile *authorization.PersonalProf
175174

176175
// Get Session by id
177176
func (manager *SessionManager) GetWhepSessionById(sessionId string) (whep *whep.WhepSession, foundSession bool) {
177+
_, whepSession, foundSession := manager.GetSessionAndWhepById(sessionId)
178+
return whepSession, foundSession
179+
}
180+
181+
func (manager *SessionManager) GetSessionAndWhepById(sessionId string) (streamSession *session.Session, whepSession *whep.WhepSession, foundSession bool) {
178182
manager.sessionsLock.RLock()
179183
defer manager.sessionsLock.RUnlock()
180184

@@ -183,11 +187,11 @@ func (manager *SessionManager) GetWhepSessionById(sessionId string) (whep *whep.
183187
whepSession, ok := session.WhepSessions[sessionId]
184188
session.WhepSessionsLock.RUnlock()
185189
if ok {
186-
return whepSession, true
190+
return session, whepSession, true
187191
}
188192
}
189193

190-
return nil, false
194+
return nil, nil, false
191195
}
192196

193197
func (manager *SessionManager) GetHostSessionById(sessionId string) (host *whip.WhipSession, foundSession bool) {

internal/webrtc/sessions/session/routines.go

Lines changed: 7 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import (
55
"time"
66

77
"github.com/glimesh/broadcast-box/internal/webrtc/sessions/whep"
8-
"github.com/glimesh/broadcast-box/internal/webrtc/sessions/whip"
98
"github.com/pion/rtcp"
109
"github.com/pion/webrtc/v4"
1110
)
@@ -24,11 +23,9 @@ import (
2423
//
2524
// }
2625

27-
// When WHEP is established, send initial messages to client
28-
func (session *Session) handleWhepConnection(whipSession *whip.WhipSession, whepSession *whep.WhepSession) {
26+
// Waits for WHEP disconnect and removes the session
27+
func (session *Session) handleWhepConnection(whepSession *whep.WhepSession) {
2928
log.Println("Session.WhepSession.Connected:", session.StreamKey)
30-
whepSession.SseEventsChannel <- session.GetSessionStatsEvent()
31-
whepSession.SseEventsChannel <- whipSession.GetAvailableLayersEvent()
3229

3330
<-whepSession.ActiveContext.Done()
3431

@@ -52,9 +49,7 @@ func (session *Session) handleWhepVideoRtcpSender(whepSession *whep.WhepSession,
5249
}
5350
}
5451

55-
// - Initializes by announcing stream start to potentially awaiting clients
56-
// - Announces layers changes to clients when layers are added or removed from the session
57-
// - Triggers a status update every 5 seconds to send to all listening WHEP sessions
52+
// Broadcast stream status to connected WHEP clients while host is active.
5853
func (session *Session) hostStatusLoop() {
5954
log.Println("Session.Host.HostStatusLoop")
6055
ticker := time.NewTicker(5 * time.Second)
@@ -87,49 +82,18 @@ func (session *Session) hostStatusLoop() {
8782
if session.isEmpty() {
8883
session.close()
8984
} else if session.Host.Load() != nil {
90-
9185
status := session.GetSessionStatsEvent()
9286

9387
session.WhepSessionsLock.RLock()
88+
whepSessions := make([]*whep.WhepSession, 0, len(session.WhepSessions))
9489
for _, whepSession := range session.WhepSessions {
95-
select {
96-
case whepSession.SseEventsChannel <- status:
97-
default:
98-
log.Println("Session.Host.HostStatusLoop: SSE channel full, skipping", whepSession.SessionId)
99-
}
90+
whepSessions = append(whepSessions, whepSession)
10091
}
10192
session.WhepSessionsLock.RUnlock()
10293

103-
}
104-
}
105-
}
106-
}
107-
108-
// Start a routing that takes snapshots of the current whep sessions in the whip session.
109-
func (session *Session) Snapshot() {
110-
ticker := time.NewTicker(1 * time.Second)
111-
defer ticker.Stop()
112-
113-
for {
114-
select {
115-
case <-session.ActiveContext.Done():
116-
if host := session.Host.Load(); host != nil {
117-
host.WhepSessionsSnapshot.Store(make(map[string]*whep.WhepSession))
118-
}
119-
return
120-
case <-ticker.C:
121-
if host := session.Host.Load(); host != nil {
122-
session.WhepSessionsLock.RLock()
123-
snapshot := make(map[string]*whep.WhepSession, len(session.WhepSessions))
124-
125-
for _, whepSession := range session.WhepSessions {
126-
if !whepSession.IsSessionClosed.Load() {
127-
snapshot[whepSession.SessionId] = whepSession
128-
}
94+
for _, whepSession := range whepSessions {
95+
whepSession.BroadcastSSE(status)
12996
}
130-
session.WhepSessionsLock.RUnlock()
131-
132-
host.WhepSessionsSnapshot.Store(snapshot)
13397
}
13498
}
13599
}

internal/webrtc/sessions/session/session.go

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,9 @@ func (session *Session) AddWhep(whepSessionId string, peerConnection *webrtc.Pee
7272
session.WhepSessionsLock.Lock()
7373
session.WhepSessions[whepSessionId] = whepSession
7474
session.WhepSessionsLock.Unlock()
75+
session.updateHostWhepSessionsSnapshot()
7576

76-
go session.handleWhepConnection(host, whepSession)
77+
go session.handleWhepConnection(whepSession)
7778
go session.handleWhepVideoRtcpSender(whepSession, videoRtcpSender)
7879

7980
return nil
@@ -101,11 +102,9 @@ func (session *Session) AddHost(peerConnection *webrtc.PeerConnection) (err erro
101102
activeContext, activeContextCancel := context.WithCancel(context.Background())
102103

103104
host := &whip.WhipSession{
104-
Id: uuid.New().String(),
105-
AudioTracks: make(map[string]*whip.AudioTrack),
106-
VideoTracks: make(map[string]*whip.VideoTrack),
107-
OnTrackChangeChannel: make(chan struct{}, 50),
108-
EventsChannel: make(chan any, 50),
105+
Id: uuid.New().String(),
106+
AudioTracks: make(map[string]*whip.AudioTrack),
107+
VideoTracks: make(map[string]*whip.VideoTrack),
109108

110109
ActiveContext: activeContext,
111110
ActiveContextCancel: activeContextCancel,
@@ -118,6 +117,8 @@ func (session *Session) AddHost(peerConnection *webrtc.PeerConnection) (err erro
118117
host.RemoveTracks()
119118
return fmt.Errorf("session already has a host")
120119
}
120+
host.WhepSessionsSnapshot.Store(make(map[string]*whep.WhepSession))
121+
session.updateHostWhepSessionsSnapshot()
121122

122123
go session.hostStatusLoop()
123124

@@ -134,6 +135,7 @@ func (session *Session) RemoveHost() {
134135

135136
log.Println("Session.RemoveHost", session.StreamKey)
136137

138+
host.WhepSessionsSnapshot.Store(make(map[string]*whep.WhepSession))
137139
host.ActiveContextCancel()
138140
host.RemovePeerConnection()
139141
host.RemoveTracks()
@@ -153,6 +155,7 @@ func (session *Session) removeWhep(whepSessionId string) {
153155
log.Println("Session.RemoveWhepSession.InvalidSession:", session.StreamKey, " - ", whepSessionId)
154156
}
155157
session.WhepSessionsLock.Unlock()
158+
session.updateHostWhepSessionsSnapshot()
156159

157160
if session.isEmpty() {
158161
session.close()
@@ -161,14 +164,19 @@ func (session *Session) removeWhep(whepSessionId string) {
161164

162165
// Remove all Hosts and clients before closing down session
163166
func (session *Session) close() {
164-
165167
session.WhepSessionsLock.Lock()
166-
for _, whep := range session.WhepSessions {
167-
whep.Close()
168+
whepSessions := make([]*whep.WhepSession, 0, len(session.WhepSessions))
169+
for _, whepSession := range session.WhepSessions {
170+
whepSessions = append(whepSessions, whepSession)
168171
}
169172
session.WhepSessions = make(map[string]*whep.WhepSession)
170173
session.WhepSessionsLock.Unlock()
171174

175+
for _, whepSession := range whepSessions {
176+
whepSession.Close()
177+
}
178+
session.updateHostWhepSessionsSnapshot()
179+
172180
session.RemoveHost()
173181

174182
session.ActiveContextCancel()
@@ -233,6 +241,24 @@ func (session *Session) hasWhepSessions() bool {
233241
return true
234242
}
235243

244+
func (session *Session) updateHostWhepSessionsSnapshot() {
245+
host := session.Host.Load()
246+
if host == nil {
247+
return
248+
}
249+
250+
session.WhepSessionsLock.RLock()
251+
snapshot := make(map[string]*whep.WhepSession, len(session.WhepSessions))
252+
for _, whepSession := range session.WhepSessions {
253+
if !whepSession.IsSessionClosed.Load() {
254+
snapshot[whepSession.SessionId] = whepSession
255+
}
256+
}
257+
session.WhepSessionsLock.RUnlock()
258+
259+
host.WhepSessionsSnapshot.Store(snapshot)
260+
}
261+
236262
// Get the status of the current session
237263
func (session *Session) GetStreamStatus() (status WhipSessionStatus) {
238264
session.WhepSessionsLock.RLock()

0 commit comments

Comments
 (0)