fix: Use atomic value for logger in peer #1257


Merged 1 commit on May 2, 2022
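This PR replaces direct reads of the `opts *ConnOptions` field (`c.opts.Logger`) with a `logger()` accessor backed by an `atomic.Value`. pion/webrtc can fire callbacks such as `OnClose` and the various state-change handlers after `CloseWithError` has returned, so a plain struct field offers no safe way to turn logging off during teardown; several handlers had to guard every log call with `isClosed()`. With the logger stored atomically, `CloseWithError` can swap in the zero-value `slog.Logger{}` (which drops log calls), late callbacks stay race-free, and most of the per-handler guards can be deleted, as the diff below shows.

A minimal sketch of the pattern, with illustrative names rather than the actual `coder/coder` types:

```go
package main

import (
	"fmt"
	"sync/atomic"
)

// logFunc stands in for slog.Logger. atomic.Value requires that every
// Store use the same concrete type, so a logFunc is always stored.
type logFunc func(msg string)

type conn struct {
	loggerValue atomic.Value // holds a logFunc
}

// logger mirrors Conn.logger() in this PR: callers always receive a
// usable logger, even before the first Store or after close.
func (c *conn) logger() logFunc {
	log, valid := c.loggerValue.Load().(logFunc)
	if !valid {
		return func(string) {} // fallback: silently drop the message
	}
	return log
}

func (c *conn) close() {
	c.logger()("closing conn")
	// Disable logging. Callbacks that fire after this point still Load
	// a valid (but silent) logger; concurrent Load/Store is safe.
	c.loggerValue.Store(logFunc(func(string) {}))
}

func main() {
	c := &conn{}
	c.loggerValue.Store(logFunc(func(msg string) { fmt.Println("log:", msg) }))

	c.logger()("datachannel opening")  // printed
	c.close()                          // printed, then logging disabled
	c.logger()("callback after close") // dropped without a data race
}
```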
peer/channel.go (3 additions, 3 deletions)

@@ -118,14 +118,14 @@ func (c *Channel) init() {
 		}
 	})
 	c.dc.OnClose(func() {
-		c.conn.opts.Logger.Debug(context.Background(), "datachannel closing from OnClose", slog.F("id", c.dc.ID()), slog.F("label", c.dc.Label()))
+		c.conn.logger().Debug(context.Background(), "datachannel closing from OnClose", slog.F("id", c.dc.ID()), slog.F("label", c.dc.Label()))
 		_ = c.closeWithError(ErrClosed)
 	})
 	c.dc.OnOpen(func() {
 		c.closeMutex.Lock()
 		defer c.closeMutex.Unlock()
 
-		c.conn.opts.Logger.Debug(context.Background(), "datachannel opening", slog.F("id", c.dc.ID()), slog.F("label", c.dc.Label()))
+		c.conn.logger().Debug(context.Background(), "datachannel opening", slog.F("id", c.dc.ID()), slog.F("label", c.dc.Label()))
 		var err error
 		c.rwc, err = c.dc.Detach()
 		if err != nil {
@@ -289,7 +289,7 @@ func (c *Channel) closeWithError(err error) error {
 		return c.closeError
 	}
 
-	c.conn.opts.Logger.Debug(context.Background(), "datachannel closing with error", slog.F("id", c.dc.ID()), slog.F("label", c.dc.Label()), slog.Error(err))
+	c.conn.logger().Debug(context.Background(), "datachannel closing with error", slog.F("id", c.dc.ID()), slog.F("label", c.dc.Label()), slog.Error(err))
 	if err == nil {
 		c.closeError = ErrClosed
 	} else {

peer/conn.go (37 additions, 38 deletions)

@@ -63,7 +63,6 @@ func newWithClientOrServer(servers []webrtc.ICEServer, client bool, opts *ConnOptions
 	conn := &Conn{
 		pingChannelID:                   1,
 		pingEchoChannelID:               2,
-		opts:                            opts,
 		rtc:                             rtc,
 		offerer:                         client,
 		closed:                          make(chan struct{}),
@@ -75,7 +74,9 @@
 		localCandidateChannel:           make(chan webrtc.ICECandidateInit),
 		localSessionDescriptionChannel:  make(chan webrtc.SessionDescription, 1),
 		remoteSessionDescriptionChannel: make(chan webrtc.SessionDescription, 1),
+		settingEngine:                   opts.SettingEngine,
 	}
+	conn.loggerValue.Store(opts.Logger)
 	if client {
 		// If we're the client, we want to flip the echo and
 		// ping channel IDs so pings don't accidentally hit each other.
@@ -100,8 +101,7 @@ type ConnOptions struct {
 // This struct wraps webrtc.PeerConnection to add bidirectional pings,
 // concurrent-safe webrtc.DataChannel, and standardized errors for connection state.
 type Conn struct {
-	rtc  *webrtc.PeerConnection
-	opts *ConnOptions
+	rtc *webrtc.PeerConnection
 	// Determines whether this connection will send the offer or the answer.
 	offerer bool
 
@@ -127,6 +127,9 @@ type Conn struct {
 	negotiateMutex sync.Mutex
 	hasNegotiated  bool
 
+	loggerValue   atomic.Value
+	settingEngine webrtc.SettingEngine
+
 	pingChannelID     uint16
 	pingEchoChannelID uint16
 
@@ -139,6 +142,14 @@ type Conn struct {
 	pingError error
 }
 
+func (c *Conn) logger() slog.Logger {
+	log, valid := c.loggerValue.Load().(slog.Logger)
+	if !valid {
+		return slog.Logger{}
+	}
+	return log
+}
+
 func (c *Conn) init() error {
 	// The negotiation needed callback can take a little bit to execute!
 	c.negotiateMutex.Lock()
@@ -152,7 +163,7 @@ func (c *Conn) init() error {
 			// Don't log more state changes if we've already closed.
 			return
 		default:
-			c.opts.Logger.Debug(context.Background(), "ice connection state updated",
+			c.logger().Debug(context.Background(), "ice connection state updated",
 				slog.F("state", iceConnectionState))
 
 			if iceConnectionState == webrtc.ICEConnectionStateClosed {
@@ -171,7 +182,7 @@ func (c *Conn) init() error {
 			// Don't log more state changes if we've already closed.
 			return
 		default:
-			c.opts.Logger.Debug(context.Background(), "ice gathering state updated",
+			c.logger().Debug(context.Background(), "ice gathering state updated",
 				slog.F("state", iceGatherState))
 
 			if iceGatherState == webrtc.ICEGathererStateClosed {
@@ -189,7 +200,7 @@ func (c *Conn) init() error {
 		if c.isClosed() {
 			return
 		}
-		c.opts.Logger.Debug(context.Background(), "rtc connection updated",
+		c.logger().Debug(context.Background(), "rtc connection updated",
 			slog.F("state", peerConnectionState))
 	}()
 
@@ -225,38 +236,25 @@ func (c *Conn) init() error {
 	// These functions need to check if the conn is closed, because they can be
 	// called after being closed.
 	c.rtc.OnSignalingStateChange(func(signalState webrtc.SignalingState) {
-		if c.isClosed() {
-			return
-		}
-		c.opts.Logger.Debug(context.Background(), "signaling state updated",
+		c.logger().Debug(context.Background(), "signaling state updated",
 			slog.F("state", signalState))
 	})
 	c.rtc.SCTP().Transport().OnStateChange(func(dtlsTransportState webrtc.DTLSTransportState) {
-		if c.isClosed() {
-			return
-		}
-		c.opts.Logger.Debug(context.Background(), "dtls transport state updated",
+		c.logger().Debug(context.Background(), "dtls transport state updated",
 			slog.F("state", dtlsTransportState))
 	})
 	c.rtc.SCTP().Transport().ICETransport().OnSelectedCandidatePairChange(func(candidatePair *webrtc.ICECandidatePair) {
-		if c.isClosed() {
-			return
-		}
-		c.opts.Logger.Debug(context.Background(), "selected candidate pair changed",
+		c.logger().Debug(context.Background(), "selected candidate pair changed",
 			slog.F("local", candidatePair.Local), slog.F("remote", candidatePair.Remote))
 	})
 	c.rtc.OnICECandidate(func(iceCandidate *webrtc.ICECandidate) {
-		if c.isClosed() {
-			return
-		}
-
 		if iceCandidate == nil {
 			return
 		}
 		// Run this in a goroutine so we don't block pion/webrtc
 		// from continuing.
 		go func() {
-			c.opts.Logger.Debug(context.Background(), "sending local candidate", slog.F("candidate", iceCandidate.ToJSON().Candidate))
+			c.logger().Debug(context.Background(), "sending local candidate", slog.F("candidate", iceCandidate.ToJSON().Candidate))
 			select {
 			case <-c.closed:
 				break
@@ -287,7 +285,7 @@ func (c *Conn) init() error {
 // negotiate is triggered when a connection is ready to be established.
 // See trickle ICE for the expected exchange: https://webrtchacks.com/trickle-ice/
 func (c *Conn) negotiate() {
-	c.opts.Logger.Debug(context.Background(), "negotiating")
+	c.logger().Debug(context.Background(), "negotiating")
 	// ICE candidates cannot be added until SessionDescriptions have been
 	// exchanged between peers.
 	if c.hasNegotiated {
@@ -311,23 +309,23 @@ func (c *Conn) negotiate() {
 			_ = c.CloseWithError(xerrors.Errorf("set local description: %w", err))
 			return
 		}
-		c.opts.Logger.Debug(context.Background(), "sending offer", slog.F("offer", offer))
+		c.logger().Debug(context.Background(), "sending offer", slog.F("offer", offer))
 		select {
 		case <-c.closed:
 			return
 		case c.localSessionDescriptionChannel <- offer:
 		}
-		c.opts.Logger.Debug(context.Background(), "sent offer")
+		c.logger().Debug(context.Background(), "sent offer")
 	}
 
 	var sessionDescription webrtc.SessionDescription
-	c.opts.Logger.Debug(context.Background(), "awaiting remote description...")
+	c.logger().Debug(context.Background(), "awaiting remote description...")
 	select {
 	case <-c.closed:
 		return
 	case sessionDescription = <-c.remoteSessionDescriptionChannel:
 	}
-	c.opts.Logger.Debug(context.Background(), "setting remote description")
+	c.logger().Debug(context.Background(), "setting remote description")
 
 	err := c.rtc.SetRemoteDescription(sessionDescription)
 	if err != nil {
@@ -350,13 +348,13 @@ func (c *Conn) negotiate() {
 			_ = c.CloseWithError(xerrors.Errorf("set local description: %w", err))
 			return
 		}
-		c.opts.Logger.Debug(context.Background(), "sending answer", slog.F("answer", answer))
+		c.logger().Debug(context.Background(), "sending answer", slog.F("answer", answer))
 		select {
 		case <-c.closed:
 			return
 		case c.localSessionDescriptionChannel <- answer:
 		}
-		c.opts.Logger.Debug(context.Background(), "sent answer")
+		c.logger().Debug(context.Background(), "sent answer")
 	}
 }
 
@@ -373,7 +371,7 @@ func (c *Conn) AddRemoteCandidate(i webrtc.ICECandidateInit) {
 		if c.isClosed() {
 			return
 		}
-		c.opts.Logger.Debug(context.Background(), "accepting candidate", slog.F("candidate", i.Candidate))
+		c.logger().Debug(context.Background(), "accepting candidate", slog.F("candidate", i.Candidate))
 		err := c.rtc.AddICECandidate(i)
 		if err != nil {
 			if c.rtc.ConnectionState() == webrtc.PeerConnectionStateClosed {
@@ -482,7 +480,7 @@ func (c *Conn) Dial(ctx context.Context, label string, opts *ChannelOptions) (*Channel, error) {
 }
 
 func (c *Conn) dialChannel(ctx context.Context, label string, opts *ChannelOptions) (*Channel, error) {
-	c.opts.Logger.Debug(ctx, "creating data channel", slog.F("label", label), slog.F("opts", opts))
+	c.logger().Debug(ctx, "creating data channel", slog.F("label", label), slog.F("opts", opts))
 	var id *uint16
 	if opts.ID != 0 {
 		id = &opts.ID
@@ -531,7 +529,7 @@ func (c *Conn) Ping() (time.Duration, error) {
 	if err != nil {
 		return 0, xerrors.Errorf("send ping: %w", err)
 	}
-	c.opts.Logger.Debug(context.Background(), "wrote ping",
+	c.logger().Debug(context.Background(), "wrote ping",
 		slog.F("connection_state", c.rtc.ConnectionState()))
 
 	pingDataReceived := make([]byte, pingDataLength)
@@ -568,12 +566,11 @@ func (c *Conn) isClosed() bool {
 func (c *Conn) CloseWithError(err error) error {
 	c.closeMutex.Lock()
 	defer c.closeMutex.Unlock()
-
 	if c.isClosed() {
 		return c.closeError
 	}
 
-	c.opts.Logger.Debug(context.Background(), "closing conn with error", slog.Error(err))
+	c.logger().Debug(context.Background(), "closing conn with error", slog.Error(err))
 	if err == nil {
 		c.closeError = ErrClosed
 	} else {
@@ -591,19 +588,21 @@ func (c *Conn) CloseWithError(err error) error {
 	// Waiting for pion/webrtc to report closed state on both of these
 	// ensures no goroutine leaks.
 	if c.rtc.ConnectionState() != webrtc.PeerConnectionStateNew {
-		c.opts.Logger.Debug(context.Background(), "waiting for rtc connection close...")
+		c.logger().Debug(context.Background(), "waiting for rtc connection close...")
 		<-c.closedRTC
 	}
 	if c.rtc.ICEConnectionState() != webrtc.ICEConnectionStateNew {
-		c.opts.Logger.Debug(context.Background(), "waiting for ice connection close...")
+		c.logger().Debug(context.Background(), "waiting for ice connection close...")
 		<-c.closedICE
 	}
 
 	// Waits for all DataChannels to exit before officially labeling as closed.
 	// All logging, goroutines, and async functionality is cleaned up after this.
 	c.dcClosedWaitGroup.Wait()
 
-	c.opts.Logger.Debug(context.Background(), "closed")
+	c.logger().Debug(context.Background(), "closed")
+	// Disable logging!
+	c.loggerValue.Store(slog.Logger{})
 	close(c.closed)
 	return err
 }
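Two details make this swap safe. First, `atomic.Value` panics if successive `Store` calls use different concrete types, which is why both the constructor and `CloseWithError` store a `slog.Logger` value. Second, the zero value `slog.Logger{}` that `logger()` falls back to, and that `CloseWithError` stores to disable logging, simply discards log calls; that is what allows the `isClosed()` guards to be removed from the pion/webrtc callbacks above.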