diff --git a/README.md b/README.md index 94c69285..a325c2d3 100644 --- a/README.md +++ b/README.md @@ -206,6 +206,7 @@ The frontend can be configured by passing these URL Parameters. The backend can be configured with the following environment variables. +- `WEBHOOK_URL` - URL for Webhook Backend. Provides authentication and logging - `DISABLE_STATUS` - Disable the status API - `DISABLE_FRONTEND` - Disable the serving of frontend. Only REST APIs + WebRTC is enabled. - `HTTP_ADDRESS` - HTTP Server Address @@ -235,6 +236,16 @@ The backend can be configured with the following environment variables. - `DEBUG_PRINT_OFFER` - Print WebRTC Offers from client to Broadcast Box. Debug things like accepted codecs. - `DEBUG_PRINT_ANSWER` - Print WebRTC Answers from Broadcast Box to Browser. Debug things like IP/Ports returned to client. +## Authentication and Logging + +To prevent random users from streaming to your server, you can set the `WEBHOOK_URL` and validate/process requests in your code. + +If the request succeeds (meaning the stream key is accepted), broadcast-box redirects the stream to an url given +by the external server, otherwise the streaming request is dropped. + +See [here](examples/webhook-server.go). For an example Webhook Server that only allows the stream `broadcastBoxRulez` + + ## Network Test on Start When running in Docker Broadcast Box runs a network tests on startup. This tests that WebRTC traffic can be established diff --git a/examples/webhook-server.go b/examples/webhook-server.go new file mode 100644 index 00000000..bdb9e804 --- /dev/null +++ b/examples/webhook-server.go @@ -0,0 +1,51 @@ +package main + +import ( + "encoding/json" + "log" + "net/http" +) + +type webhookPayload struct { + Action string `json:"action"` + IP string `json:"ip"` + BearerToken string `json:"bearerToken"` + QueryParams map[string]string `json:"queryParams"` + UserAgent string `json:"userAgent"` +} + +type webhookResponse struct { + StreamKey string `json:"streamKey"` +} + +func main() { + http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + http.Error(w, "Only POST method is accepted", http.StatusMethodNotAllowed) + return + } + + var payload webhookPayload + if err := json.NewDecoder(r.Body).Decode(&payload); err != nil { + http.Error(w, "Invalid JSON", http.StatusBadRequest) + return + } + + if payload.BearerToken == "broadcastBoxRulez" { + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(webhookResponse{StreamKey: payload.BearerToken}); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + } else { + w.WriteHeader(http.StatusForbidden) + if err := json.NewEncoder(w).Encode(webhookResponse{}); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + } + }) + + log.Println("Server listening on port 8081") + if err := http.ListenAndServe("127.0.0.1:8081", nil); err != nil { + log.Fatalf("Could not start server: %s\n", err) + } +} diff --git a/internal/webhook/webhook.go b/internal/webhook/webhook.go index c519066e..6d17a283 100644 --- a/internal/webhook/webhook.go +++ b/internal/webhook/webhook.go @@ -8,6 +8,8 @@ import ( "time" ) +const defaultTimeout = time.Second * 5 + type webhookPayload struct { Action string `json:"action"` IP string `json:"ip"` @@ -20,7 +22,7 @@ type webhookResponse struct { StreamKey string `json:"streamKey"` } -func CallWebhook(url, action, bearerToken string, timeout int, r *http.Request) (string, error) { +func CallWebhook(url, action, bearerToken string, r *http.Request) (string, error) { start := time.Now() queryParams := make(map[string]string) @@ -41,17 +43,15 @@ func CallWebhook(url, action, bearerToken string, timeout int, r *http.Request) return "", fmt.Errorf("failed to marshal payload: %w", err) } - client := &http.Client{ - Timeout: time.Duration(timeout) * time.Millisecond, - } - req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonPayload)) if err != nil { return "", fmt.Errorf("failed to create request: %w", err) } req.Header.Set("Content-Type", "application/json") - resp, err := client.Do(req) + resp, err := (&http.Client{ + Timeout: defaultTimeout, + }).Do(req) if err != nil { return "", fmt.Errorf("webhook request failed after %v: %w", time.Since(start), err) } diff --git a/internal/webhook/webhook_test.go b/internal/webhook/webhook_test.go index 8b6c4da8..f4b52173 100644 --- a/internal/webhook/webhook_test.go +++ b/internal/webhook/webhook_test.go @@ -17,7 +17,7 @@ func TestCallWebhook(t *testing.T) { w.WriteHeader(http.StatusOK) _ = json.NewEncoder(w).Encode(webhookResponse{StreamKey: "dummy_stream_key"}) case "/timeout": - time.Sleep(2 * time.Second) + time.Sleep(7 * time.Second) case "/error": w.WriteHeader(http.StatusInternalServerError) case "/badjson": @@ -32,15 +32,14 @@ func TestCallWebhook(t *testing.T) { tests := []struct { name string url string - timeout int expectedErr bool expectedKey string }{ - {"Success Case", "/ok", 1000, false, "dummy_stream_key"}, - {"Server Timeout", "/timeout", 1000, true, ""}, - {"Server Error", "/error", 1000, true, ""}, - {"Malformed JSON", "/badjson", 1000, true, ""}, - {"Not Found", "/notfound", 1000, true, ""}, + {"Success Case", "/ok", false, "dummy_stream_key"}, + {"Server Timeout", "/timeout", true, ""}, + {"Server Error", "/error", true, ""}, + {"Malformed JSON", "/badjson", true, ""}, + {"Not Found", "/notfound", true, ""}, } for _, tt := range tests { @@ -50,7 +49,7 @@ func TestCallWebhook(t *testing.T) { req.Header.Set("User-Agent", "test-agent") // call the function with test layers - result, err := CallWebhook(fmt.Sprintf("%s%s", mockServer.URL, tt.url), "action", "bearerToken", tt.timeout, req) + result, err := CallWebhook(fmt.Sprintf("%s%s", mockServer.URL, tt.url), "action", "bearerToken", req) if tt.expectedErr && err == nil { t.Fatalf("expected an error but got none") diff --git a/main.go b/main.go index 1c80cf5b..cccad9a8 100644 --- a/main.go +++ b/main.go @@ -16,6 +16,7 @@ import ( "time" "github.com/glimesh/broadcast-box/internal/networktest" + "github.com/glimesh/broadcast-box/internal/webhook" "github.com/glimesh/broadcast-box/internal/webrtc" "github.com/joho/godotenv" ) @@ -44,7 +45,7 @@ type ( } ) -func getStreamKey(r *http.Request) (string, error) { +func getStreamKey(action string, r *http.Request) (streamKey string, err error) { authorizationHeader := r.Header.Get("Authorization") if authorizationHeader == "" { return "", errAuthorizationNotSet @@ -55,13 +56,19 @@ func getStreamKey(r *http.Request) (string, error) { return "", errInvalidStreamKey } - bearerToken := strings.TrimPrefix(authorizationHeader, bearerPrefix) - if !streamKeyRegex.MatchString(bearerToken) { - return "", errInvalidStreamKey + streamKey = strings.TrimPrefix(authorizationHeader, bearerPrefix) + if webhookUrl := os.Getenv("WEBHOOK_URL"); webhookUrl != "" { + streamKey, err = webhook.CallWebhook(webhookUrl, action, streamKey, r) + if err != nil { + return "", err + } } - return bearerToken, nil + if !streamKeyRegex.MatchString(streamKey) { + return "", errInvalidStreamKey + } + return streamKey, nil } func logHTTPError(w http.ResponseWriter, err string, code int) { @@ -74,9 +81,10 @@ func whipHandler(res http.ResponseWriter, r *http.Request) { return } - streamKey, err := getStreamKey(r) + streamKey, err := getStreamKey("whip-connect", r) if err != nil { logHTTPError(res, err.Error(), http.StatusBadRequest) + return } offer, err := io.ReadAll(r.Body) @@ -104,9 +112,10 @@ func whepHandler(res http.ResponseWriter, req *http.Request) { return } - streamKey, err := getStreamKey(req) + streamKey, err := getStreamKey("whep-connect", req) if err != nil { logHTTPError(res, err.Error(), http.StatusBadRequest) + return } offer, err := io.ReadAll(req.Body) @@ -207,25 +216,25 @@ func corsHandler(next func(w http.ResponseWriter, r *http.Request)) http.Handler } } -func main() { - loadConfigs := func() error { - if os.Getenv("APP_ENV") == "development" { - log.Println("Loading `" + envFileDev + "`") - return godotenv.Load(envFileDev) - } else { - log.Println("Loading `" + envFileProd + "`") - if err := godotenv.Load(envFileProd); err != nil { - return err - } - - if _, err := os.Stat("./web/build"); os.IsNotExist(err) && os.Getenv("DISABLE_FRONTEND") == "" { - return errNoBuildDirectoryErr - } +func loadConfigs() error { + if os.Getenv("APP_ENV") == "development" { + log.Println("Loading `" + envFileDev + "`") + return godotenv.Load(envFileDev) + } else { + log.Println("Loading `" + envFileProd + "`") + if err := godotenv.Load(envFileProd); err != nil { + return err + } - return nil + if _, err := os.Stat("./web/build"); os.IsNotExist(err) && os.Getenv("DISABLE_FRONTEND") == "" { + return errNoBuildDirectoryErr } + + return nil } +} +func main() { if err := loadConfigs(); err != nil { log.Println("Failed to find config in CWD, changing CWD to executable path")