diff --git a/main.go b/main.go index 48c4600e..1c80cf5b 100644 --- a/main.go +++ b/main.go @@ -29,7 +29,13 @@ const ( networkTestFailedMessage = "\033[0;31mNetwork Test failed.\n%s\nPlease see the README and join Discord for help\033[0m" ) -var errNoBuildDirectoryErr = errors.New("\033[0;31mBuild directory does not exist, run `npm install` and `npm run build` in the web directory.\033[0m") +var ( + errNoBuildDirectoryErr = errors.New("\033[0;31mBuild directory does not exist, run `npm install` and `npm run build` in the web directory.\033[0m") + errAuthorizationNotSet = errors.New("authorization was not set") + errInvalidStreamKey = errors.New("invalid stream key format") + + streamKeyRegex = regexp.MustCompile(`^[a-zA-Z0-9_\-\.~]+$`) +) type ( whepLayerRequestJSON struct { @@ -38,38 +44,39 @@ type ( } ) -func logHTTPError(w http.ResponseWriter, err string, code int) { - log.Println(err) - http.Error(w, err, code) -} - -func validateStreamKey(streamKey string) bool { - return regexp.MustCompile(`^[a-zA-Z0-9_\-\.~]+$`).MatchString(streamKey) -} +func getStreamKey(r *http.Request) (string, error) { + authorizationHeader := r.Header.Get("Authorization") + if authorizationHeader == "" { + return "", errAuthorizationNotSet + } -func extractBearerToken(authHeader string) (string, bool) { const bearerPrefix = "Bearer " - if strings.HasPrefix(authHeader, bearerPrefix) { - return strings.TrimPrefix(authHeader, bearerPrefix), true + if !strings.HasPrefix(authorizationHeader, bearerPrefix) { + return "", errInvalidStreamKey } - return "", false -} -func whipHandler(res http.ResponseWriter, r *http.Request) { - if r.Method == "DELETE" { - return + bearerToken := strings.TrimPrefix(authorizationHeader, bearerPrefix) + if !streamKeyRegex.MatchString(bearerToken) { + return "", errInvalidStreamKey } - streamKeyHeader := r.Header.Get("Authorization") - if streamKeyHeader == "" { - logHTTPError(res, "Authorization was not set", http.StatusBadRequest) + return bearerToken, nil + +} + +func logHTTPError(w http.ResponseWriter, err string, code int) { + log.Println(err) + http.Error(w, err, code) +} + +func whipHandler(res http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { return } - streamKey, ok := extractBearerToken(streamKeyHeader) - if !ok || !validateStreamKey(streamKey) { - logHTTPError(res, "Invalid stream key format", http.StatusBadRequest) - return + streamKey, err := getStreamKey(r) + if err != nil { + logHTTPError(res, err.Error(), http.StatusBadRequest) } offer, err := io.ReadAll(r.Body) @@ -93,16 +100,13 @@ func whipHandler(res http.ResponseWriter, r *http.Request) { } func whepHandler(res http.ResponseWriter, req *http.Request) { - streamKeyHeader := req.Header.Get("Authorization") - if streamKeyHeader == "" { - logHTTPError(res, "Authorization was not set", http.StatusBadRequest) + if req.Method != "POST" { return } - streamKey, ok := extractBearerToken(streamKeyHeader) - if !ok || !validateStreamKey(streamKey) { - logHTTPError(res, "Invalid stream key format", http.StatusBadRequest) - return + streamKey, err := getStreamKey(req) + if err != nil { + logHTTPError(res, err.Error(), http.StatusBadRequest) } offer, err := io.ReadAll(req.Body)