From 022f6fb0984215252e0d02a11f68a8be9765a4c3 Mon Sep 17 00:00:00 2001 From: Yohann D'ANELLO Date: Mon, 5 Oct 2020 22:00:08 +0200 Subject: [PATCH] :poop: Split webrtc tracks by stream id (need to clean this, stream ID must pass between the session descriptor and the webrtc flux transmit) --- main.go | 5 ++++- stream/webrtc/ingest.go | 37 +++++++++++++++++++++++-------------- stream/webrtc/webrtc.go | 38 ++++++++++++++++++++++++++++---------- web/handler.go | 25 +++++++++++++++++++++++-- web/web.go | 12 +++++++++--- 5 files changed, 87 insertions(+), 30 deletions(-) diff --git a/main.go b/main.go index ff2c566..cd4c73a 100644 --- a/main.go +++ b/main.go @@ -104,7 +104,10 @@ func main() { defer authBackend.Close() // WebRTC session description channels - remoteSdpChan := make(chan webrtc.SessionDescription) + remoteSdpChan := make(chan struct { + StreamID string + RemoteDescription webrtc.SessionDescription + }) localSdpChan := make(chan webrtc.SessionDescription) // SRT channel for forwarding and webrtc diff --git a/stream/webrtc/ingest.go b/stream/webrtc/ingest.go index 155af1e..83e0385 100644 --- a/stream/webrtc/ingest.go +++ b/stream/webrtc/ingest.go @@ -2,6 +2,7 @@ package webrtc import ( "bufio" + "github.com/pion/webrtc/v3" "io" "log" "net" @@ -18,10 +19,10 @@ func ingestFrom(inputChannel chan srt.Packet) { for { var err error = nil - packet := <-inputChannel - switch packet.PacketType { + srtPacket := <-inputChannel + switch srtPacket.PacketType { case "register": - log.Printf("WebRTC RegisterStream %s", packet.StreamName) + log.Printf("WebRTC RegisterStream %s", srtPacket.StreamName) // Open a UDP Listener for RTP Packets on port 5004 videoListener, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 5004}) @@ -74,13 +75,17 @@ func ingestFrom(inputChannel chan srt.Packet) { } packet := &rtp.Packet{} if err := packet.Unmarshal(inboundRTPPacket[:n]); err != nil { - log.Printf("Failed to unmarshal RTP packet: %s", err) + log.Printf("Failed to unmarshal RTP srtPacket: %s", err) continue } - // Write RTP packet to all video tracks + if videoTracks[srtPacket.StreamName] == nil { + videoTracks[srtPacket.StreamName] = make([]*webrtc.Track, 0) + } + + // Write RTP srtPacket to all video tracks // Adapt payload and SSRC to match destination - for _, videoTrack := range videoTracks { + for _, videoTrack := range videoTracks[srtPacket.StreamName] { packet.Header.PayloadType = videoTrack.PayloadType() packet.Header.SSRC = videoTrack.SSRC() if writeErr := videoTrack.WriteRTP(packet); writeErr != nil { @@ -102,13 +107,17 @@ func ingestFrom(inputChannel chan srt.Packet) { } packet := &rtp.Packet{} if err := packet.Unmarshal(inboundRTPPacket[:n]); err != nil { - log.Printf("Failed to unmarshal RTP packet: %s", err) + log.Printf("Failed to unmarshal RTP srtPacket: %s", err) continue } - // Write RTP packet to all audio tracks + if audioTracks[srtPacket.StreamName] == nil { + audioTracks[srtPacket.StreamName] = make([]*webrtc.Track, 0) + } + + // Write RTP srtPacket to all audio tracks // Adapt payload and SSRC to match destination - for _, audioTrack := range audioTracks { + for _, audioTrack := range audioTracks[srtPacket.StreamName] { packet.Header.PayloadType = audioTrack.PayloadType() packet.Header.SSRC = audioTrack.SSRC() if writeErr := audioTrack.WriteRTP(packet); writeErr != nil { @@ -127,20 +136,20 @@ func ingestFrom(inputChannel chan srt.Packet) { }() break case "sendData": - // FIXME send to stream packet.StreamName - if _, err := ffmpegInput.Write(packet.Data); err != nil { + // FIXME send to stream srtPacket.StreamName + if _, err := ffmpegInput.Write(srtPacket.Data); err != nil { log.Printf("Failed to write data to ffmpeg input: %s", err) } break case "close": - log.Printf("WebRTC CloseConnection %s", packet.StreamName) + log.Printf("WebRTC CloseConnection %s", srtPacket.StreamName) break default: - log.Println("Unknown SRT packet type:", packet.PacketType) + log.Println("Unknown SRT srtPacket type:", srtPacket.PacketType) break } if err != nil { - log.Printf("Error occured while receiving SRT packet of type %s: %s", packet.PacketType, err) + log.Printf("Error occured while receiving SRT srtPacket of type %s: %s", srtPacket.PacketType, err) } } } diff --git a/stream/webrtc/webrtc.go b/stream/webrtc/webrtc.go index 82fdac6..6f2b0dd 100644 --- a/stream/webrtc/webrtc.go +++ b/stream/webrtc/webrtc.go @@ -23,8 +23,8 @@ type Options struct { type SessionDescription = webrtc.SessionDescription var ( - videoTracks []*webrtc.Track - audioTracks []*webrtc.Track + videoTracks map[string][]*webrtc.Track + audioTracks map[string][]*webrtc.Track ) // Helper to reslice tracks @@ -44,10 +44,13 @@ func GetNumberConnectedSessions() int { // newPeerHandler is called when server receive a new session description // this initiates a WebRTC connection and return server description -func newPeerHandler(remoteSdp webrtc.SessionDescription, cfg *Options) webrtc.SessionDescription { +func newPeerHandler(remoteSdp struct { + StreamID string + RemoteDescription webrtc.SessionDescription +}, cfg *Options) webrtc.SessionDescription { // Create media engine using client SDP mediaEngine := webrtc.MediaEngine{} - if err := mediaEngine.PopulateFromSDP(remoteSdp); err != nil { + if err := mediaEngine.PopulateFromSDP(remoteSdp.RemoteDescription); err != nil { log.Println("Failed to create new media engine", err) return webrtc.SessionDescription{} } @@ -95,7 +98,7 @@ func newPeerHandler(remoteSdp webrtc.SessionDescription, cfg *Options) webrtc.Se } // Set the remote SessionDescription - if err = peerConnection.SetRemoteDescription(remoteSdp); err != nil { + if err = peerConnection.SetRemoteDescription(remoteSdp.RemoteDescription); err != nil { log.Println("Failed to set remote description", err) return webrtc.SessionDescription{} } @@ -116,19 +119,27 @@ func newPeerHandler(remoteSdp webrtc.SessionDescription, cfg *Options) webrtc.Se return webrtc.SessionDescription{} } + streamID := remoteSdp.StreamID + // Set the handler for ICE connection state // This will notify you when the peer has connected/disconnected peerConnection.OnICEConnectionStateChange(func(connectionState webrtc.ICEConnectionState) { log.Printf("Connection State has changed %s \n", connectionState.String()) + if videoTracks[streamID] == nil { + videoTracks[streamID] = make([]*webrtc.Track, 0, 1) + } + if audioTracks[streamID] == nil { + audioTracks[streamID] = make([]*webrtc.Track, 0, 1) + } if connectionState == webrtc.ICEConnectionStateConnected { // Register tracks - videoTracks = append(videoTracks, videoTrack) - audioTracks = append(audioTracks, audioTrack) + videoTracks[streamID] = append(videoTracks[streamID], videoTrack) + audioTracks[streamID] = append(audioTracks[streamID], audioTrack) monitoring.WebRTCConnectedSessions.Inc() } else if connectionState == webrtc.ICEConnectionStateDisconnected { // Unregister tracks - videoTracks = removeTrack(videoTracks, videoTrack) - audioTracks = removeTrack(audioTracks, audioTrack) + videoTracks[streamID] = removeTrack(videoTracks[streamID], videoTrack) + audioTracks[streamID] = removeTrack(audioTracks[streamID], audioTrack) monitoring.WebRTCConnectedSessions.Dec() } }) @@ -155,9 +166,16 @@ func getPayloadType(m webrtc.MediaEngine, codecType webrtc.RTPCodecType, codecNa } // Serve WebRTC media streaming server -func Serve(remoteSdpChan, localSdpChan chan webrtc.SessionDescription, inputChannel chan srt.Packet, cfg *Options) { +func Serve(remoteSdpChan chan struct { + StreamID string + RemoteDescription webrtc.SessionDescription +}, localSdpChan chan webrtc.SessionDescription, inputChannel chan srt.Packet, cfg *Options) { log.Printf("WebRTC server using UDP from port %d to %d", cfg.MinPortUDP, cfg.MaxPortUDP) + // Allocate memory + videoTracks = make(map[string][]*webrtc.Track) + audioTracks = make(map[string][]*webrtc.Track) + // Ingest data from SRT go ingestFrom(inputChannel) diff --git a/web/handler.go b/web/handler.go index c8e74e0..46f003d 100644 --- a/web/handler.go +++ b/web/handler.go @@ -19,6 +19,21 @@ func viewerPostHandler(w http.ResponseWriter, r *http.Request) { // Limit response body to 128KB r.Body = http.MaxBytesReader(w, r.Body, 131072) + // Get stream ID from URL, or from domain name + path := r.URL.Path[1:] + if cfg.OneStreamPerDomain { + host := r.Host + if strings.Contains(host, ":") { + realHost, _, err := net.SplitHostPort(r.Host) + if err != nil { + log.Printf("Failed to split host and port from %s", r.Host) + return + } + host = realHost + } + path = host + } + // Decode client description dec := json.NewDecoder(r.Body) dec.DisallowUnknownFields() @@ -29,7 +44,10 @@ func viewerPostHandler(w http.ResponseWriter, r *http.Request) { } // Exchange session descriptions with WebRTC stream server - remoteSdpChan <- remoteDescription + remoteSdpChan <- struct { + StreamID string + RemoteDescription webrtc.SessionDescription + }{StreamID: path, RemoteDescription: remoteDescription} localDescription := <-localSdpChan // Send server description as JSON @@ -40,7 +58,10 @@ func viewerPostHandler(w http.ResponseWriter, r *http.Request) { return } w.Header().Set("Content-Type", "application/json") - _, _ = w.Write(jsonDesc) + _, err = w.Write(jsonDesc) + if err != nil { + log.Println("An error occurred while sending session description", err) + } // Increment monitoring monitoring.WebSessions.Inc() diff --git a/web/web.go b/web/web.go index cecae91..f28b3da 100644 --- a/web/web.go +++ b/web/web.go @@ -30,8 +30,11 @@ var ( cfg *Options // WebRTC session description channels - remoteSdpChan chan webrtc.SessionDescription - localSdpChan chan webrtc.SessionDescription + remoteSdpChan chan struct { + StreamID string + RemoteDescription webrtc.SessionDescription + } + localSdpChan chan webrtc.SessionDescription // Preload templates templates *template.Template @@ -71,7 +74,10 @@ func loadTemplates() error { } // Serve HTTP server -func Serve(rSdpChan chan webrtc.SessionDescription, lSdpChan chan webrtc.SessionDescription, c *Options) { +func Serve(rSdpChan chan struct { + StreamID string + RemoteDescription webrtc.SessionDescription +}, lSdpChan chan webrtc.SessionDescription, c *Options) { remoteSdpChan = rSdpChan localSdpChan = lSdpChan cfg = c