Skip to content

Commit 31c8047

Browse files
committed
fix: lock log sink against concurrent write and close
1 parent e0afee1 commit 31c8047

File tree

4 files changed

+100
-5
lines changed

4 files changed

+100
-5
lines changed

cli/cliutil/sink.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package cliutil
2+
3+
import (
4+
"io"
5+
"sync"
6+
)
7+
8+
type discardAfterClose struct {
9+
sync.Mutex
10+
wc io.WriteCloser
11+
closed bool
12+
}
13+
14+
// DiscardAfterClose is an io.WriteCloser that discards writes after it is closed without errors.
15+
// It is useful as a target for a slog.Sink such that an underlying WriteCloser, like a file, can
16+
// be cleaned up without race conditions from still-active loggers.
17+
func DiscardAfterClose(wc io.WriteCloser) io.WriteCloser {
18+
return &discardAfterClose{wc: wc}
19+
}
20+
21+
func (d *discardAfterClose) Write(p []byte) (n int, err error) {
22+
d.Lock()
23+
defer d.Unlock()
24+
if d.closed {
25+
return len(p), nil
26+
}
27+
return d.wc.Write(p)
28+
}
29+
30+
func (d *discardAfterClose) Close() error {
31+
d.Lock()
32+
defer d.Unlock()
33+
if d.closed {
34+
return nil
35+
}
36+
d.closed = true
37+
return d.wc.Close()
38+
}

cli/cliutil/sink_test.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
package cliutil_test
2+
3+
import (
4+
"errors"
5+
"testing"
6+
7+
"github.com/stretchr/testify/require"
8+
9+
"github.com/coder/coder/v2/cli/cliutil"
10+
)
11+
12+
func TestDiscardAfterClose(t *testing.T) {
13+
exErr := errors.New("test")
14+
fwc := &fakeWriteCloser{err: exErr}
15+
uut := cliutil.DiscardAfterClose(fwc)
16+
17+
n, err := uut.Write([]byte("one"))
18+
require.Equal(t, 3, n)
19+
require.NoError(t, err)
20+
21+
n, err = uut.Write([]byte("two"))
22+
require.Equal(t, 3, n)
23+
require.NoError(t, err)
24+
25+
err = uut.Close()
26+
require.Equal(t, exErr, err)
27+
28+
n, err = uut.Write([]byte("three"))
29+
require.Equal(t, 5, n)
30+
require.NoError(t, err)
31+
32+
require.Len(t, fwc.writes, 2)
33+
require.EqualValues(t, "one", fwc.writes[0])
34+
require.EqualValues(t, "two", fwc.writes[1])
35+
}
36+
37+
type fakeWriteCloser struct {
38+
writes [][]byte
39+
closed bool
40+
err error
41+
}
42+
43+
func (f *fakeWriteCloser) Write(p []byte) (n int, err error) {
44+
q := make([]byte, len(p))
45+
copy(q, p)
46+
f.writes = append(f.writes, q)
47+
return len(p), nil
48+
}
49+
50+
func (f *fakeWriteCloser) Close() error {
51+
f.closed = true
52+
return f.err
53+
}

cli/ssh.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import (
2828

2929
"github.com/coder/coder/v2/cli/clibase"
3030
"github.com/coder/coder/v2/cli/cliui"
31+
"github.com/coder/coder/v2/cli/cliutil"
3132
"github.com/coder/coder/v2/coderd/autobuild/notify"
3233
"github.com/coder/coder/v2/coderd/util/ptr"
3334
"github.com/coder/coder/v2/codersdk"
@@ -114,12 +115,13 @@ func (r *RootCmd) ssh() *clibase.Cmd {
114115
if err != nil {
115116
return xerrors.Errorf("error opening %s for logging: %w", logDirPath, err)
116117
}
118+
dc := cliutil.DiscardAfterClose(logFile)
117119
go func() {
118120
wg.Wait()
119-
_ = logFile.Close()
121+
_ = dc.Close()
120122
}()
121123

122-
logger = slog.Make(sloghuman.Sink(logFile))
124+
logger = logger.AppendSinks(sloghuman.Sink(dc))
123125
if r.verbose {
124126
logger = logger.Leveled(slog.LevelDebug)
125127
}

cli/vscodessh.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"cdr.dev/slog/sloggers/sloghuman"
2222

2323
"github.com/coder/coder/v2/cli/clibase"
24+
"github.com/coder/coder/v2/cli/cliutil"
2425
"github.com/coder/coder/v2/codersdk"
2526
)
2627

@@ -137,15 +138,16 @@ func (r *RootCmd) vscodeSSH() *clibase.Cmd {
137138
// command via the ProxyCommand SSH option.
138139
pid := os.Getppid()
139140

140-
var logger slog.Logger
141+
logger := slog.Make()
141142
if logDir != "" {
142143
logFilePath := filepath.Join(logDir, fmt.Sprintf("%d.log", pid))
143144
logFile, err := fs.OpenFile(logFilePath, os.O_CREATE|os.O_WRONLY, 0o600)
144145
if err != nil {
145146
return xerrors.Errorf("open log file %q: %w", logFilePath, err)
146147
}
147-
defer logFile.Close()
148-
logger = slog.Make(sloghuman.Sink(logFile)).Leveled(slog.LevelDebug)
148+
dc := cliutil.DiscardAfterClose(logFile)
149+
defer dc.Close()
150+
logger = logger.AppendSinks(sloghuman.Sink(dc)).Leveled(slog.LevelDebug)
149151
}
150152
if r.disableDirect {
151153
logger.Info(ctx, "direct connections disabled")

0 commit comments

Comments
 (0)