diff --git a/stream/srt/srt.go b/stream/srt/srt.go index ad85035..6baa5e3 100644 --- a/stream/srt/srt.go +++ b/stream/srt/srt.go @@ -4,6 +4,7 @@ package srt import "C" import ( + "fmt" "gitlab.crans.org/nounous/ghostream/auth" "gitlab.crans.org/nounous/ghostream/auth/bypass" "log" @@ -27,6 +28,11 @@ type Packet struct { StreamName string } +var ( + authBackend auth.Backend + forwardingChannel chan Packet +) + // Split host and port from listen address func splitHostPort(hostport string) (string, uint16) { host, portS, err := net.SplitHostPort(hostport) @@ -44,10 +50,12 @@ func splitHostPort(hostport string) (string, uint16) { } // Serve SRT server -func Serve(cfg *Options, authBackend auth.Backend, forwardingChannel chan Packet) { - if authBackend == nil { - authBackend, _ = bypass.New() +func Serve(cfg *Options, backend auth.Backend, forwarding chan Packet) { + if backend == nil { + backend, _ = bypass.New() } + authBackend = backend + forwardingChannel = forwarding options := make(map[string]string) options["transtype"] = "live" @@ -71,63 +79,72 @@ func Serve(cfg *Options, authBackend auth.Backend, forwardingChannel chan Packet break // FIXME: should not break here } - streamID, err := s.GetSockOptString(C.SRTO_STREAMID) - if err != nil { - log.Println("Error while fetching stream key:", err) - s.Close() - continue - } - if !strings.Contains(streamID, "|") { - log.Printf("Warning: stream id must be at the format streamID|password. Input: %s", streamID) - s.Close() - continue - } - - splittedStreamID := strings.SplitN(streamID, "|", 2) - streamName, password := splittedStreamID[0], splittedStreamID[1] - loggedIn, err := authBackend.Login(streamName, password) - if !loggedIn { - log.Printf("Invalid credentials for stream %s.", streamName) - s.Close() - continue - } - - log.Printf("Starting stream %s...", streamName) - - // Create a new buffer - buff := make([]byte, 2048) - - // Setup stream forwarding - forwardingChannel <- Packet{StreamName: streamName, PacketType: "register", Data: nil} - - // Read RTP packets forever and send them to the WebRTC Client - for { - n, err := s.Read(buff, 10000) - if err != nil { - log.Println("Error occured while reading SRT socket:", err) - break - } - - if n == 0 { - // End of stream - log.Printf("Received no bytes, stopping stream.") - break - } - // log.Printf("Received %d bytes", n) - - // Send raw packet to other streams - // Copy data in another buffer to ensure that the data would not be overwritten - data := make([]byte, n) - copy(data, buff[:n]) - forwardingChannel <- Packet{StreamName: streamName, PacketType: "sendData", Data: data} - - // TODO: Send to WebRTC - // See https://github.com/ebml-go/webm/blob/master/reader.go - //err := videoTrack.WriteSample(media.Sample{Data: data, Samples: uint32(sampleCount)}) - } - - forwardingChannel <- Packet{StreamName: streamName, PacketType: "close", Data: nil} + go acceptSocket(s) } sck.Close() } + +func acceptSocket(s *srtgo.SrtSocket) { + streamName, err := authenticateSocket(s) + if err != nil { + log.Println("Authentication failure:", err) + s.Close() + return + } + + log.Printf("Starting stream %s...", streamName) + + // Create a new buffer + buff := make([]byte, 2048) + + // Setup stream forwarding + forwardingChannel <- Packet{StreamName: streamName, PacketType: "register", Data: nil} + + // Read RTP packets forever and send them to the WebRTC Client + for { + n, err := s.Read(buff, 10000) + if err != nil { + log.Println("Error occured while reading SRT socket:", err) + break + } + + if n == 0 { + // End of stream + log.Printf("Received no bytes, stopping stream.") + break + } + // log.Printf("Received %d bytes", n) + + // Send raw packet to other streams + // Copy data in another buffer to ensure that the data would not be overwritten + data := make([]byte, n) + copy(data, buff[:n]) + forwardingChannel <- Packet{StreamName: streamName, PacketType: "sendData", Data: data} + + // TODO: Send to WebRTC + // See https://github.com/ebml-go/webm/blob/master/reader.go + //err := videoTrack.WriteSample(media.Sample{Data: data, Samples: uint32(sampleCount)}) + } + + forwardingChannel <- Packet{StreamName: streamName, PacketType: "close", Data: nil} +} + +func authenticateSocket(s *srtgo.SrtSocket) (string, error) { + streamID, err := s.GetSockOptString(C.SRTO_STREAMID) + if err != nil { + return "", fmt.Errorf("error while fetching stream key: %s", err) + } + if !strings.Contains(streamID, "|") { + return streamID, fmt.Errorf("warning: stream id must be at the format streamID|password. Input: %s", streamID) + } + + splittedStreamID := strings.SplitN(streamID, "|", 2) + streamName, password := splittedStreamID[0], splittedStreamID[1] + loggedIn, err := authBackend.Login(streamName, password) + if !loggedIn { + return streamID, fmt.Errorf("invalid credentials for stream %s", streamName) + } + + return streamName, nil +}