Use reference to Stream

This commit is contained in:
Alexandre Iooss 2020-10-17 13:43:16 +02:00
parent 5b85eed646
commit 5b8c73057b
No known key found for this signature in database
GPG Key ID: 6C79278F3FCDCC02
6 changed files with 30 additions and 18 deletions

View File

@ -11,6 +11,9 @@ type Stream struct {
// Use a map to be able to delete an item // Use a map to be able to delete an item
outputs map[chan []byte]struct{} outputs map[chan []byte]struct{}
// Count clients for statistics
nbClients int
// Mutex to lock this ressource // Mutex to lock this ressource
lock sync.Mutex lock sync.Mutex
} }
@ -21,6 +24,7 @@ func New() *Stream {
broadcast := make(chan []byte, 64) broadcast := make(chan []byte, 64)
s.Broadcast = broadcast s.Broadcast = broadcast
s.outputs = make(map[chan []byte]struct{}) s.outputs = make(map[chan []byte]struct{})
s.nbClients = 0
go s.run(broadcast) go s.run(broadcast)
return s return s
} }
@ -56,15 +60,20 @@ func (s *Stream) Close() {
close(s.Broadcast) close(s.Broadcast)
} }
// Register a new output on a stream // Register a new output on a stream.
func (s *Stream) Register(output chan []byte) { // If hidden in true, then do not count this client.
func (s *Stream) Register(output chan []byte, hidden bool) {
s.lock.Lock() s.lock.Lock()
defer s.lock.Unlock() defer s.lock.Unlock()
s.outputs[output] = struct{}{} s.outputs[output] = struct{}{}
if !hidden {
s.nbClients++
}
} }
// Unregister removes an output // Unregister removes an output.
func (s *Stream) Unregister(output chan []byte) { // If hidden in true, then do not count this client.
func (s *Stream) Unregister(output chan []byte, hidden bool) {
s.lock.Lock() s.lock.Lock()
defer s.lock.Unlock() defer s.lock.Unlock()
@ -73,10 +82,13 @@ func (s *Stream) Unregister(output chan []byte) {
if ok { if ok {
delete(s.outputs, output) delete(s.outputs, output)
close(output) close(output)
if !hidden {
s.nbClients--
}
} }
} }
// Count number of outputs // Count number of clients
func (s *Stream) Count() int { func (s *Stream) Count() int {
return len(s.outputs) return s.nbClients
} }

View File

@ -8,7 +8,7 @@ import (
"gitlab.crans.org/nounous/ghostream/stream" "gitlab.crans.org/nounous/ghostream/stream"
) )
func handleStreamer(socket *srtgo.SrtSocket, streams map[string]stream.Stream, name string) { func handleStreamer(socket *srtgo.SrtSocket, streams map[string]*stream.Stream, name string) {
// Check stream does not exist // Check stream does not exist
if _, ok := streams[name]; ok { if _, ok := streams[name]; ok {
log.Print("Stream already exists, refusing new streamer") log.Print("Stream already exists, refusing new streamer")
@ -18,7 +18,7 @@ func handleStreamer(socket *srtgo.SrtSocket, streams map[string]stream.Stream, n
// Create stream // Create stream
log.Printf("New SRT streamer for stream %s", name) log.Printf("New SRT streamer for stream %s", name)
st := *stream.New() st := stream.New()
streams[name] = st streams[name] = st
// Create a new buffer // Create a new buffer
@ -54,7 +54,7 @@ func handleStreamer(socket *srtgo.SrtSocket, streams map[string]stream.Stream, n
delete(streams, name) delete(streams, name)
} }
func handleViewer(s *srtgo.SrtSocket, streams map[string]stream.Stream, name string) { func handleViewer(s *srtgo.SrtSocket, streams map[string]*stream.Stream, name string) {
log.Printf("New SRT viewer for stream %s", name) log.Printf("New SRT viewer for stream %s", name)
// Get requested stream // Get requested stream
@ -66,16 +66,16 @@ func handleViewer(s *srtgo.SrtSocket, streams map[string]stream.Stream, name str
// Register new output // Register new output
c := make(chan []byte, 128) c := make(chan []byte, 128)
st.Register(c) st.Register(c, false)
// Receive data and send them // Receive data and send them
for { for data := range c {
data := <-c
if len(data) < 1 { if len(data) < 1 {
log.Print("Remove SRT viewer because of end of stream") log.Print("Remove SRT viewer because of end of stream")
break break
} }
// Send data
_, err := s.Write(data, 1000) _, err := s.Write(data, 1000)
if err != nil { if err != nil {
log.Printf("Remove SRT viewer because of sending error, %s", err) log.Printf("Remove SRT viewer because of sending error, %s", err)
@ -84,6 +84,6 @@ func handleViewer(s *srtgo.SrtSocket, streams map[string]stream.Stream, name str
} }
// Close output // Close output
st.Unregister(c) st.Unregister(c, false)
s.Close() s.Close()
} }

View File

@ -39,7 +39,7 @@ func splitHostPort(hostport string) (string, uint16, error) {
} }
// Serve SRT server // Serve SRT server
func Serve(streams map[string]stream.Stream, authBackend auth.Backend, cfg *Options) { func Serve(streams map[string]*stream.Stream, authBackend auth.Backend, cfg *Options) {
if !cfg.Enabled { if !cfg.Enabled {
// SRT is not enabled, ignore // SRT is not enabled, ignore
return return

View File

@ -58,7 +58,7 @@ func TestServeSRT(t *testing.T) {
} }
// Init streams messaging and SRT server // Init streams messaging and SRT server
streams := make(map[string]stream.Stream) streams := make(map[string]*stream.Stream)
go Serve(streams, nil, &Options{Enabled: true, ListenAddress: ":9711", MaxClients: 2}) go Serve(streams, nil, &Options{Enabled: true, ListenAddress: ":9711", MaxClients: 2})
ffmpeg := exec.Command("ffmpeg", "-hide_banner", "-loglevel", "error", ffmpeg := exec.Command("ffmpeg", "-hide_banner", "-loglevel", "error",

View File

@ -44,7 +44,7 @@ var (
templates *template.Template templates *template.Template
// Streams to get statistics // Streams to get statistics
streams map[string]stream.Stream streams map[string]*stream.Stream
) )
// Load templates with pkger // Load templates with pkger
@ -78,7 +78,7 @@ func loadTemplates() error {
} }
// Serve HTTP server // Serve HTTP server
func Serve(s map[string]stream.Stream, rSdpChan chan struct { func Serve(s map[string]*stream.Stream, rSdpChan chan struct {
StreamID string StreamID string
RemoteDescription webrtc.SessionDescription RemoteDescription webrtc.SessionDescription
}, lSdpChan chan webrtc.SessionDescription, c *Options) { }, lSdpChan chan webrtc.SessionDescription, c *Options) {

View File

@ -11,7 +11,7 @@ import (
// TestHTTPServe tries to serve a real HTTP server and load some pages // TestHTTPServe tries to serve a real HTTP server and load some pages
func TestHTTPServe(t *testing.T) { func TestHTTPServe(t *testing.T) {
// Init streams messaging // Init streams messaging
streams := make(map[string]stream.Stream) streams := make(map[string]*stream.Stream)
// Create a disabled web server // Create a disabled web server
go Serve(streams, nil, nil, &Options{Enabled: false, ListenAddress: "127.0.0.1:8081"}) go Serve(streams, nil, nil, &Options{Enabled: false, ListenAddress: "127.0.0.1:8081"})