diff --git a/conn.go b/conn.go new file mode 100644 index 0000000..ec6e28b --- /dev/null +++ b/conn.go @@ -0,0 +1,12 @@ +package srtp + +import ( + "context" +) + +// ConnCtx is a Conn controlled by context.Context instead of SetDeadline. +type ConnCtx interface { + ReadContext(context.Context, []byte) (int, error) + WriteContext(context.Context, []byte) (int, error) + Close() error +} diff --git a/go.mod b/go.mod index 0f23507..8cc1b7d 100644 --- a/go.mod +++ b/go.mod @@ -1,12 +1,11 @@ -module github.com/pion/srtp +module github.com/pion/srtp/v2 -go 1.12 +go 1.14 require ( - github.com/davecgh/go-spew v1.1.1 // indirect github.com/pion/logging v0.2.2 github.com/pion/rtcp v1.2.4 github.com/pion/rtp v1.6.1 - github.com/pion/transport v0.10.1 + github.com/pion/transport v0.11.0 github.com/stretchr/testify v1.6.1 ) diff --git a/go.sum b/go.sum index c0661c9..71b3858 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,5 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY= github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms= github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA= @@ -10,8 +8,8 @@ github.com/pion/rtcp v1.2.4 h1:NT3H5LkUGgaEapvp0HGik+a+CpflRF7KTD7H+o7OWIM= github.com/pion/rtcp v1.2.4/go.mod h1:52rMNPWFsjr39z9B9MhnkqhPLoeHTv1aN63o/42bWE0= github.com/pion/rtp v1.6.1 h1:2Y2elcVBrahYnHKN2X7rMHX/r1R4TEBMP1LaVu/wNhk= github.com/pion/rtp v1.6.1/go.mod h1:bDb5n+BFZxXx0Ea7E5qe+klMuqiBrP+w8XSjiWtCUko= -github.com/pion/transport v0.10.1 h1:2W+yJT+0mOQ160ThZYUx5Zp2skzshiNgxrNE9GUfhJM= -github.com/pion/transport v0.10.1/go.mod h1:PBis1stIILMiis0PewDw91WJeLJkyIMcEk+DwKOzf4A= +github.com/pion/transport v0.11.0 h1:Z1RhzqrWPPYj5Xed8P7pirTKTvXFoxDI3uJuuKu6akM= +github.com/pion/transport v0.11.0/go.mod h1:ORH8Ouyl1enoJyHwU+MwMeQocWbeorEk5068FOsHjog= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -20,12 +18,14 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20200625001655-4c5254603344 h1:vGXIOMxbNfDTk/aXCmfdLgkrSV+Z2tcbze+pEc3v5W4= -golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= +golang.org/x/net v0.0.0-20201031054903-ff519b6c9102 h1:42cLlJJdEh+ySyeUUbEQ5bsTiq8voBeTuweGVkY6Puw= +golang.org/x/net v0.0.0-20201031054903-ff519b6c9102/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= diff --git a/session.go b/session.go index f10c576..160c181 100644 --- a/session.go +++ b/session.go @@ -1,8 +1,8 @@ package srtp import ( + "context" "io" - "net" "sync" "github.com/pion/logging" @@ -10,8 +10,8 @@ import ( type streamSession interface { Close() error - write([]byte) (int, error) - decrypt([]byte) error + write(context.Context, []byte) (int, error) + decrypt(context.Context, []byte) error } type session struct { @@ -30,7 +30,7 @@ type session struct { log logging.LeveledLogger - nextConn net.Conn + nextConn ConnCtx } // Config is used to configure a session. @@ -102,7 +102,7 @@ func (s *session) close() error { return nil } -func (s *session) start(localMasterKey, localMasterSalt, remoteMasterKey, remoteMasterSalt []byte, profile ProtectionProfile, child streamSession) error { +func (s *session) start(ctx context.Context, localMasterKey, localMasterSalt, remoteMasterKey, remoteMasterSalt []byte, profile ProtectionProfile, child streamSession) error { var err error s.localContext, err = CreateContext(localMasterKey, localMasterSalt, profile, s.localOptions...) if err != nil { @@ -127,7 +127,7 @@ func (s *session) start(localMasterKey, localMasterSalt, remoteMasterKey, remote b := make([]byte, 8192) for { var i int - i, err = s.nextConn.Read(b) + i, err = s.nextConn.ReadContext(ctx, b) if err != nil { if err != io.EOF { s.log.Error(err.Error()) @@ -135,7 +135,7 @@ func (s *session) start(localMasterKey, localMasterSalt, remoteMasterKey, remote return } - if err = child.decrypt(b[:i]); err != nil { + if err = child.decrypt(ctx, b[:i]); err != nil { s.log.Info(err.Error()) } } diff --git a/session_srtcp.go b/session_srtcp.go index 5a762de..e9bae1d 100644 --- a/session_srtcp.go +++ b/session_srtcp.go @@ -1,7 +1,7 @@ package srtp import ( - "net" + "context" "github.com/pion/logging" "github.com/pion/rtcp" @@ -19,7 +19,7 @@ type SessionSRTCP struct { } // NewSessionSRTCP creates a SRTCP session using conn as the underlying transport. -func NewSessionSRTCP(conn net.Conn, config *Config) (*SessionSRTCP, error) { //nolint:dupl +func NewSessionSRTCP(ctx context.Context, conn ConnCtx, config *Config) (*SessionSRTCP, error) { //nolint:dupl if config == nil { return nil, errNoConfig } else if conn == nil { @@ -58,6 +58,7 @@ func NewSessionSRTCP(conn net.Conn, config *Config) (*SessionSRTCP, error) { //n s.writeStream = &WriteStreamSRTCP{s} err := s.session.start( + ctx, config.Keys.LocalMasterKey, config.Keys.LocalMasterSalt, config.Keys.RemoteMasterKey, config.Keys.RemoteMasterSalt, config.Profile, @@ -107,7 +108,7 @@ func (s *SessionSRTCP) Close() error { // Private -func (s *SessionSRTCP) write(buf []byte) (int, error) { +func (s *SessionSRTCP) write(ctx context.Context, buf []byte) (int, error) { if _, ok := <-s.session.started; ok { return 0, errStartedChannelUsedIncorrectly } @@ -119,7 +120,7 @@ func (s *SessionSRTCP) write(buf []byte) (int, error) { if err != nil { return 0, err } - return s.session.nextConn.Write(encrypted) + return s.session.nextConn.WriteContext(ctx, encrypted) } // create a list of Destination SSRCs @@ -140,7 +141,7 @@ func destinationSSRC(pkts []rtcp.Packet) []uint32 { return out } -func (s *SessionSRTCP) decrypt(buf []byte) error { +func (s *SessionSRTCP) decrypt(ctx context.Context, buf []byte) error { decrypted, err := s.remoteContext.DecryptRTCP(buf, buf, nil) if err != nil { return err @@ -164,7 +165,7 @@ func (s *SessionSRTCP) decrypt(buf []byte) error { return errFailedTypeAssertion } - _, err = readStream.write(decrypted) + _, err = readStream.write(ctx, decrypted) if err != nil { return err } diff --git a/session_srtcp_test.go b/session_srtcp_test.go index 9944699..29230f6 100644 --- a/session_srtcp_test.go +++ b/session_srtcp_test.go @@ -2,29 +2,34 @@ package srtp import ( "bytes" + "context" "io" - "net" "reflect" "sync" "testing" "time" "github.com/pion/rtcp" + "github.com/pion/transport/connctx" "github.com/pion/transport/test" ) const rtcpHeaderSize = 4 func TestSessionSRTCPBadInit(t *testing.T) { - if _, err := NewSessionSRTCP(nil, nil); err == nil { + ctx := context.Background() + + if _, err := NewSessionSRTCP(ctx, nil, nil); err == nil { t.Fatal("NewSessionSRTCP should error if no config was provided") - } else if _, err := NewSessionSRTCP(nil, &Config{}); err == nil { + } else if _, err := NewSessionSRTCP(ctx, nil, &Config{}); err == nil { t.Fatal("NewSessionSRTCP should error if no net was provided") } } func buildSessionSRTCPPair(t *testing.T) (*SessionSRTCP, *SessionSRTCP) { //nolint:dupl - aPipe, bPipe := net.Pipe() + ctx := context.Background() + + aPipe, bPipe := connctx.Pipe() config := &Config{ Profile: ProtectionProfileAes128CmHmacSha1_80, Keys: SessionKeys{ @@ -35,14 +40,14 @@ func buildSessionSRTCPPair(t *testing.T) (*SessionSRTCP, *SessionSRTCP) { //noli }, } - aSession, err := NewSessionSRTCP(aPipe, config) + aSession, err := NewSessionSRTCP(ctx, aPipe, config) if err != nil { t.Fatal(err) } else if aSession == nil { t.Fatal("NewSessionSRTCP did not error, but returned nil session") } - bSession, err := NewSessionSRTCP(bPipe, config) + bSession, err := NewSessionSRTCP(ctx, bPipe, config) if err != nil { t.Fatal(err) } else if bSession == nil { @@ -53,6 +58,8 @@ func buildSessionSRTCPPair(t *testing.T) (*SessionSRTCP, *SessionSRTCP) { //noli } func TestSessionSRTCP(t *testing.T) { + ctx := context.Background() + lim := test.TimeOut(time.Second * 10) defer lim.Stop() @@ -71,7 +78,7 @@ func TestSessionSRTCP(t *testing.T) { t.Fatal(err) } - if _, err = aWriteStream.Write(testPayload); err != nil { + if _, err = aWriteStream.Write(ctx, testPayload); err != nil { t.Fatal(err) } @@ -80,7 +87,7 @@ func TestSessionSRTCP(t *testing.T) { t.Fatal(err) } - if _, err = bReadStream.Read(readBuffer); err != nil { + if _, err = bReadStream.Read(ctx, readBuffer); err != nil { t.Fatal(err) } @@ -98,6 +105,8 @@ func TestSessionSRTCP(t *testing.T) { } func TestSessionSRTCPOpenReadStream(t *testing.T) { + ctx := context.Background() + lim := test.TimeOut(time.Second * 10) defer lim.Stop() @@ -121,11 +130,11 @@ func TestSessionSRTCPOpenReadStream(t *testing.T) { t.Fatal(err) } - if _, err = aWriteStream.Write(testPayload); err != nil { + if _, err = aWriteStream.Write(ctx, testPayload); err != nil { t.Fatal(err) } - if _, err = bReadStream.Read(readBuffer); err != nil { + if _, err = bReadStream.Read(ctx, readBuffer); err != nil { t.Fatal(err) } @@ -143,6 +152,8 @@ func TestSessionSRTCPOpenReadStream(t *testing.T) { } func TestSessionSRTCPReplayProtection(t *testing.T) { + ctx := context.Background() + lim := test.TimeOut(time.Second * 5) defer lim.Stop() @@ -181,7 +192,7 @@ func TestSessionSRTCPReplayProtection(t *testing.T) { go func() { defer wg.Done() for { - if ssrc, perr := getSenderSSRC(t, bReadStream); perr == nil { + if ssrc, perr := getSenderSSRC(ctx, t, bReadStream); perr == nil { receivedSSRC = append(receivedSSRC, ssrc) } else if perr == io.EOF { return @@ -191,17 +202,17 @@ func TestSessionSRTCPReplayProtection(t *testing.T) { // Write with replay attack for _, p := range packets { - if _, err = aSession.session.nextConn.Write(p); err != nil { + if _, err = aSession.session.nextConn.WriteContext(ctx, p); err != nil { t.Fatal(err) } // Immediately replay - if _, err = aSession.session.nextConn.Write(p); err != nil { + if _, err = aSession.session.nextConn.WriteContext(ctx, p); err != nil { t.Fatal(err) } } for _, p := range packets { // Delayed replay - if _, err = aSession.session.nextConn.Write(p); err != nil { + if _, err = aSession.session.nextConn.WriteContext(ctx, p); err != nil { t.Fatal(err) } } @@ -224,7 +235,7 @@ func TestSessionSRTCPReplayProtection(t *testing.T) { } } -func getSenderSSRC(t *testing.T, stream *ReadStreamSRTCP) (ssrc uint32, err error) { +func getSenderSSRC(ctx context.Context, t *testing.T, stream *ReadStreamSRTCP) (ssrc uint32, err error) { authTagSize, err := ProtectionProfileAes128CmHmacSha1_80.authTagLen() if err != nil { return 0, err @@ -232,7 +243,7 @@ func getSenderSSRC(t *testing.T, stream *ReadStreamSRTCP) (ssrc uint32, err erro const pliPacketSize = 8 readBuffer := make([]byte, pliPacketSize+authTagSize+srtcpIndexSize) - n, _, err := stream.ReadRTCP(readBuffer) + n, _, err := stream.ReadRTCP(ctx, readBuffer) if err == io.EOF { return 0, err } diff --git a/session_srtp.go b/session_srtp.go index a8be4bd..42ca13d 100644 --- a/session_srtp.go +++ b/session_srtp.go @@ -1,7 +1,7 @@ package srtp import ( - "net" + "context" "github.com/pion/logging" "github.com/pion/rtp" @@ -19,7 +19,7 @@ type SessionSRTP struct { } // NewSessionSRTP creates a SRTP session using conn as the underlying transport. -func NewSessionSRTP(conn net.Conn, config *Config) (*SessionSRTP, error) { //nolint:dupl +func NewSessionSRTP(ctx context.Context, conn ConnCtx, config *Config) (*SessionSRTP, error) { //nolint:dupl if config == nil { return nil, errNoConfig } else if conn == nil { @@ -58,6 +58,7 @@ func NewSessionSRTP(conn net.Conn, config *Config) (*SessionSRTP, error) { //nol s.writeStream = &WriteStreamSRTP{s} err := s.session.start( + ctx, config.Keys.LocalMasterKey, config.Keys.LocalMasterSalt, config.Keys.RemoteMasterKey, config.Keys.RemoteMasterSalt, config.Profile, @@ -106,7 +107,7 @@ func (s *SessionSRTP) Close() error { return s.session.close() } -func (s *SessionSRTP) write(b []byte) (int, error) { +func (s *SessionSRTP) write(ctx context.Context, b []byte) (int, error) { packet := &rtp.Packet{} err := packet.Unmarshal(b) @@ -114,10 +115,10 @@ func (s *SessionSRTP) write(b []byte) (int, error) { return 0, nil } - return s.writeRTP(&packet.Header, packet.Payload) + return s.writeRTP(ctx, &packet.Header, packet.Payload) } -func (s *SessionSRTP) writeRTP(header *rtp.Header, payload []byte) (int, error) { +func (s *SessionSRTP) writeRTP(ctx context.Context, header *rtp.Header, payload []byte) (int, error) { if _, ok := <-s.session.started; ok { return 0, errStartedChannelUsedIncorrectly } @@ -130,10 +131,10 @@ func (s *SessionSRTP) writeRTP(header *rtp.Header, payload []byte) (int, error) return 0, err } - return s.session.nextConn.Write(encrypted) + return s.session.nextConn.WriteContext(ctx, encrypted) } -func (s *SessionSRTP) decrypt(buf []byte) error { +func (s *SessionSRTP) decrypt(ctx context.Context, buf []byte) error { h := &rtp.Header{} if err := h.Unmarshal(buf); err != nil { return err @@ -156,7 +157,7 @@ func (s *SessionSRTP) decrypt(buf []byte) error { return err } - _, err = readStream.write(decrypted) + _, err = readStream.write(ctx, decrypted) if err != nil { return err } diff --git a/session_srtp_test.go b/session_srtp_test.go index a0d07fc..98ef59c 100644 --- a/session_srtp_test.go +++ b/session_srtp_test.go @@ -2,27 +2,32 @@ package srtp import ( "bytes" + "context" "io" - "net" "reflect" "sync" "testing" "time" "github.com/pion/rtp" + "github.com/pion/transport/connctx" "github.com/pion/transport/test" ) func TestSessionSRTPBadInit(t *testing.T) { - if _, err := NewSessionSRTP(nil, nil); err == nil { + ctx := context.Background() + + if _, err := NewSessionSRTP(ctx, nil, nil); err == nil { t.Fatal("NewSessionSRTP should error if no config was provided") - } else if _, err := NewSessionSRTP(nil, &Config{}); err == nil { + } else if _, err := NewSessionSRTP(ctx, nil, &Config{}); err == nil { t.Fatal("NewSessionSRTP should error if no net was provided") } } func buildSessionSRTPPair(t *testing.T) (*SessionSRTP, *SessionSRTP) { //nolint:dupl - aPipe, bPipe := net.Pipe() + ctx := context.Background() + + aPipe, bPipe := connctx.Pipe() config := &Config{ Profile: ProtectionProfileAes128CmHmacSha1_80, Keys: SessionKeys{ @@ -33,14 +38,14 @@ func buildSessionSRTPPair(t *testing.T) (*SessionSRTP, *SessionSRTP) { //nolint: }, } - aSession, err := NewSessionSRTP(aPipe, config) + aSession, err := NewSessionSRTP(ctx, aPipe, config) if err != nil { t.Fatal(err) } else if aSession == nil { t.Fatal("NewSessionSRTP did not error, but returned nil session") } - bSession, err := NewSessionSRTP(bPipe, config) + bSession, err := NewSessionSRTP(ctx, bPipe, config) if err != nil { t.Fatal(err) } else if bSession == nil { @@ -51,6 +56,8 @@ func buildSessionSRTPPair(t *testing.T) (*SessionSRTP, *SessionSRTP) { //nolint: } func TestSessionSRTP(t *testing.T) { + ctx := context.Background() + lim := test.TimeOut(time.Second * 5) defer lim.Stop() @@ -69,7 +76,7 @@ func TestSessionSRTP(t *testing.T) { if err != nil { t.Fatal(err) } - if _, err = aWriteStream.WriteRTP(&rtp.Header{SSRC: testSSRC}, append([]byte{}, testPayload...)); err != nil { + if _, err = aWriteStream.WriteRTP(ctx, &rtp.Header{SSRC: testSSRC}, append([]byte{}, testPayload...)); err != nil { t.Fatal(err) } @@ -80,7 +87,7 @@ func TestSessionSRTP(t *testing.T) { t.Fatalf("SSRC mismatch during accept exp(%v) actual%v)", testSSRC, ssrc) } - if _, err = bReadStream.Read(readBuffer); err != nil { + if _, err = bReadStream.Read(ctx, readBuffer); err != nil { t.Fatal(err) } @@ -98,6 +105,8 @@ func TestSessionSRTP(t *testing.T) { } func TestSessionSRTPOpenReadStream(t *testing.T) { + ctx := context.Background() + lim := test.TimeOut(time.Second * 5) defer lim.Stop() @@ -121,11 +130,11 @@ func TestSessionSRTPOpenReadStream(t *testing.T) { if err != nil { t.Fatal(err) } - if _, err = aWriteStream.WriteRTP(&rtp.Header{SSRC: testSSRC}, append([]byte{}, testPayload...)); err != nil { + if _, err = aWriteStream.WriteRTP(ctx, &rtp.Header{SSRC: testSSRC}, append([]byte{}, testPayload...)); err != nil { t.Fatal(err) } - if _, err = bReadStream.Read(readBuffer); err != nil { + if _, err = bReadStream.Read(ctx, readBuffer); err != nil { t.Fatal(err) } @@ -143,6 +152,8 @@ func TestSessionSRTPOpenReadStream(t *testing.T) { } func TestSessionSRTPMultiSSRC(t *testing.T) { + ctx := context.Background() + lim := test.TimeOut(time.Second * 5) defer lim.Stop() @@ -168,12 +179,12 @@ func TestSessionSRTPMultiSSRC(t *testing.T) { t.Fatal(err) } for _, ssrc := range ssrcs { - if _, err = aWriteStream.WriteRTP(&rtp.Header{SSRC: ssrc}, append([]byte{}, testPayload...)); err != nil { + if _, err = aWriteStream.WriteRTP(ctx, &rtp.Header{SSRC: ssrc}, append([]byte{}, testPayload...)); err != nil { t.Fatal(err) } readBuffer := make([]byte, rtpHeaderSize+len(testPayload)) - if _, err = bReadStreams[ssrc].Read(readBuffer); err != nil { + if _, err = bReadStreams[ssrc].Read(ctx, readBuffer); err != nil { t.Fatal(err) } @@ -192,6 +203,8 @@ func TestSessionSRTPMultiSSRC(t *testing.T) { } func TestSessionSRTPReplayProtection(t *testing.T) { + ctx := context.Background() + lim := test.TimeOut(time.Second * 5) defer lim.Stop() @@ -234,7 +247,7 @@ func TestSessionSRTPReplayProtection(t *testing.T) { go func() { defer wg.Done() for { - if seq, perr := assertPayloadSRTP(t, bReadStream, rtpHeaderSize, testPayload); perr == nil { + if seq, perr := assertPayloadSRTP(ctx, t, bReadStream, rtpHeaderSize, testPayload); perr == nil { receivedSequenceNumber = append(receivedSequenceNumber, seq) } else if perr == io.EOF { return @@ -244,17 +257,17 @@ func TestSessionSRTPReplayProtection(t *testing.T) { // Write with replay attack for _, p := range packets { - if _, err = aSession.session.nextConn.Write(p); err != nil { + if _, err = aSession.session.nextConn.WriteContext(ctx, p); err != nil { t.Fatal(err) } // Immediately replay - if _, err = aSession.session.nextConn.Write(p); err != nil { + if _, err = aSession.session.nextConn.WriteContext(ctx, p); err != nil { t.Fatal(err) } } for _, p := range packets { // Delayed replay - if _, err = aSession.session.nextConn.Write(p); err != nil { + if _, err = aSession.session.nextConn.WriteContext(ctx, p); err != nil { t.Fatal(err) } } @@ -277,9 +290,9 @@ func TestSessionSRTPReplayProtection(t *testing.T) { } } -func assertPayloadSRTP(t *testing.T, stream *ReadStreamSRTP, headerSize int, expectedPayload []byte) (seq uint16, err error) { +func assertPayloadSRTP(ctx context.Context, t *testing.T, stream *ReadStreamSRTP, headerSize int, expectedPayload []byte) (seq uint16, err error) { readBuffer := make([]byte, headerSize+len(expectedPayload)) - n, hdr, err := stream.ReadRTP(readBuffer) + n, hdr, err := stream.ReadRTP(ctx, readBuffer) if err == io.EOF { return 0, err } diff --git a/stream.go b/stream.go index 7b7a0cf..2ac3d6f 100644 --- a/stream.go +++ b/stream.go @@ -1,8 +1,12 @@ package srtp +import ( + "context" +) + type readStream interface { init(child streamSession, ssrc uint32) error - Read(buf []byte) (int, error) + Read(ctx context.Context, buf []byte) (int, error) GetSSRC() uint32 } diff --git a/stream_srtcp.go b/stream_srtcp.go index 617d6dc..75181b1 100644 --- a/stream_srtcp.go +++ b/stream_srtcp.go @@ -1,6 +1,7 @@ package srtp import ( + "context" "errors" "sync" @@ -24,8 +25,8 @@ type ReadStreamSRTCP struct { buffer *packetio.Buffer } -func (r *ReadStreamSRTCP) write(buf []byte) (n int, err error) { - n, err = r.buffer.Write(buf) +func (r *ReadStreamSRTCP) write(ctx context.Context, buf []byte) (n int, err error) { + n, err = r.buffer.WriteContext(ctx, buf) if errors.Is(err, packetio.ErrFull) { // Silently drop data when the buffer is full. @@ -41,8 +42,8 @@ func newReadStreamSRTCP() readStream { } // ReadRTCP reads and decrypts full RTCP packet and its header from the nextConn -func (r *ReadStreamSRTCP) ReadRTCP(buf []byte) (int, *rtcp.Header, error) { - n, err := r.Read(buf) +func (r *ReadStreamSRTCP) ReadRTCP(ctx context.Context, buf []byte) (int, *rtcp.Header, error) { + n, err := r.Read(ctx, buf) if err != nil { return 0, nil, err } @@ -57,8 +58,8 @@ func (r *ReadStreamSRTCP) ReadRTCP(buf []byte) (int, *rtcp.Header, error) { } // Read reads and decrypts full RTCP packet from the nextConn -func (r *ReadStreamSRTCP) Read(buf []byte) (int, error) { - return r.buffer.Read(buf) +func (r *ReadStreamSRTCP) Read(ctx context.Context, buf []byte) (int, error) { + return r.buffer.ReadContext(ctx, buf) } // Close removes the ReadStream from the session and cleans up any associated state @@ -118,16 +119,16 @@ type WriteStreamSRTCP struct { } // WriteRTCP encrypts a RTCP header and its payload to the nextConn -func (w *WriteStreamSRTCP) WriteRTCP(header *rtcp.Header, payload []byte) (int, error) { +func (w *WriteStreamSRTCP) WriteRTCP(ctx context.Context, header *rtcp.Header, payload []byte) (int, error) { headerRaw, err := header.Marshal() if err != nil { return 0, err } - return w.session.write(append(headerRaw, payload...)) + return w.session.write(ctx, append(headerRaw, payload...)) } // Write encrypts and writes a full RTCP packets to the nextConn -func (w *WriteStreamSRTCP) Write(b []byte) (int, error) { - return w.session.write(b) +func (w *WriteStreamSRTCP) Write(ctx context.Context, b []byte) (int, error) { + return w.session.write(ctx, b) } diff --git a/stream_srtp.go b/stream_srtp.go index 9c3bb6f..5b1da61 100644 --- a/stream_srtp.go +++ b/stream_srtp.go @@ -1,6 +1,7 @@ package srtp import ( + "context" "errors" "sync" @@ -53,8 +54,8 @@ func (r *ReadStreamSRTP) init(child streamSession, ssrc uint32) error { return nil } -func (r *ReadStreamSRTP) write(buf []byte) (n int, err error) { - n, err = r.buffer.Write(buf) +func (r *ReadStreamSRTP) write(ctx context.Context, buf []byte) (n int, err error) { + n, err = r.buffer.WriteContext(ctx, buf) if errors.Is(err, packetio.ErrFull) { // Silently drop data when the buffer is full. @@ -65,13 +66,13 @@ func (r *ReadStreamSRTP) write(buf []byte) (n int, err error) { } // Read reads and decrypts full RTP packet from the nextConn -func (r *ReadStreamSRTP) Read(buf []byte) (int, error) { - return r.buffer.Read(buf) +func (r *ReadStreamSRTP) Read(ctx context.Context, buf []byte) (int, error) { + return r.buffer.ReadContext(ctx, buf) } // ReadRTP reads and decrypts full RTP packet and its header from the nextConn -func (r *ReadStreamSRTP) ReadRTP(buf []byte) (int, *rtp.Header, error) { - n, err := r.Read(buf) +func (r *ReadStreamSRTP) ReadRTP(ctx context.Context, buf []byte) (int, *rtp.Header, error) { + n, err := r.Read(ctx, buf) if err != nil { return 0, nil, err } @@ -120,11 +121,11 @@ type WriteStreamSRTP struct { } // WriteRTP encrypts a RTP packet and writes to the connection -func (w *WriteStreamSRTP) WriteRTP(header *rtp.Header, payload []byte) (int, error) { - return w.session.writeRTP(header, payload) +func (w *WriteStreamSRTP) WriteRTP(ctx context.Context, header *rtp.Header, payload []byte) (int, error) { + return w.session.writeRTP(ctx, header, payload) } // Write encrypts and writes a full RTP packets to the nextConn -func (w *WriteStreamSRTP) Write(b []byte) (int, error) { - return w.session.write(b) +func (w *WriteStreamSRTP) Write(ctx context.Context, b []byte) (int, error) { + return w.session.write(ctx, b) } diff --git a/stream_srtp_test.go b/stream_srtp_test.go index b6ae601..9361dda 100644 --- a/stream_srtp_test.go +++ b/stream_srtp_test.go @@ -1,27 +1,25 @@ package srtp import ( + "context" "io" - "net" "testing" - "time" "github.com/pion/rtp" ) type noopConn struct{ closed chan struct{} } -func newNoopConn() *noopConn { return &noopConn{closed: make(chan struct{})} } -func (c *noopConn) Read(b []byte) (n int, err error) { <-c.closed; return 0, io.EOF } -func (c *noopConn) Write(b []byte) (n int, err error) { return len(b), nil } -func (c *noopConn) Close() error { close(c.closed); return nil } -func (c *noopConn) LocalAddr() net.Addr { return nil } -func (c *noopConn) RemoteAddr() net.Addr { return nil } -func (c *noopConn) SetDeadline(t time.Time) error { return nil } -func (c *noopConn) SetReadDeadline(t time.Time) error { return nil } -func (c *noopConn) SetWriteDeadline(t time.Time) error { return nil } +func newNoopConn() *noopConn { return &noopConn{closed: make(chan struct{})} } +func (c *noopConn) ReadContext(ctx context.Context, b []byte) (n int, err error) { + <-c.closed + return 0, io.EOF +} +func (c *noopConn) WriteContext(ctx context.Context, b []byte) (n int, err error) { return len(b), nil } +func (c *noopConn) Close() error { close(c.closed); return nil } func BenchmarkWrite(b *testing.B) { + ctx := context.Background() conn := newNoopConn() config := &Config{ @@ -34,7 +32,7 @@ func BenchmarkWrite(b *testing.B) { Profile: ProtectionProfileAes128CmHmacSha1_80, } - session, err := NewSessionSRTP(conn, config) + session, err := NewSessionSRTP(ctx, conn, config) if err != nil { b.Fatal(err) } @@ -62,7 +60,7 @@ func BenchmarkWrite(b *testing.B) { for i := 0; i < b.N; i++ { packet.Header.SequenceNumber++ - _, err = ws.Write(packetRaw) + _, err = ws.Write(ctx, packetRaw) if err != nil { b.Fatal(err) } @@ -75,6 +73,7 @@ func BenchmarkWrite(b *testing.B) { } func BenchmarkWriteRTP(b *testing.B) { + ctx := context.Background() conn := &noopConn{ closed: make(chan struct{}), } @@ -89,7 +88,7 @@ func BenchmarkWriteRTP(b *testing.B) { Profile: ProtectionProfileAes128CmHmacSha1_80, } - session, err := NewSessionSRTP(conn, config) + session, err := NewSessionSRTP(ctx, conn, config) if err != nil { b.Fatal(err) } @@ -111,7 +110,7 @@ func BenchmarkWriteRTP(b *testing.B) { for i := 0; i < b.N; i++ { header.SequenceNumber++ - _, err = ws.WriteRTP(header, payload) + _, err = ws.WriteRTP(ctx, header, payload) if err != nil { b.Fatal(err) }