Skip to content

Commit fcd5e39

Browse files
committed
fix: lock log sink against concurrent write and close
1 parent e0afee1 commit fcd5e39

File tree

4 files changed

+101
-5
lines changed

4 files changed

+101
-5
lines changed

cli/cliutil/sink.go

+38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package cliutil
2+
3+
import (
4+
"io"
5+
"sync"
6+
)
7+
8+
type discardAfterClose struct {
9+
sync.Mutex
10+
wc io.WriteCloser
11+
closed bool
12+
}
13+
14+
// DiscardAfterClose is an io.WriteCloser that discards writes after it is closed without errors.
15+
// It is useful as a target for a slog.Sink such that an underlying WriteCloser, like a file, can
16+
// be cleaned up without race conditions from still-active loggers.
17+
func DiscardAfterClose(wc io.WriteCloser) io.WriteCloser {
18+
return &discardAfterClose{wc: wc}
19+
}
20+
21+
func (d *discardAfterClose) Write(p []byte) (n int, err error) {
22+
d.Lock()
23+
defer d.Unlock()
24+
if d.closed {
25+
return len(p), nil
26+
}
27+
return d.wc.Write(p)
28+
}
29+
30+
func (d *discardAfterClose) Close() error {
31+
d.Lock()
32+
defer d.Unlock()
33+
if d.closed {
34+
return nil
35+
}
36+
d.closed = true
37+
return d.wc.Close()
38+
}

cli/cliutil/sink_test.go

+54
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
package cliutil_test
2+
3+
import (
4+
"errors"
5+
"testing"
6+
7+
"github.com/stretchr/testify/require"
8+
9+
"github.com/coder/coder/v2/cli/cliutil"
10+
)
11+
12+
func TestDiscardAfterClose(t *testing.T) {
13+
t.Parallel()
14+
exErr := errors.New("test")
15+
fwc := &fakeWriteCloser{err: exErr}
16+
uut := cliutil.DiscardAfterClose(fwc)
17+
18+
n, err := uut.Write([]byte("one"))
19+
require.Equal(t, 3, n)
20+
require.NoError(t, err)
21+
22+
n, err = uut.Write([]byte("two"))
23+
require.Equal(t, 3, n)
24+
require.NoError(t, err)
25+
26+
err = uut.Close()
27+
require.Equal(t, exErr, err)
28+
29+
n, err = uut.Write([]byte("three"))
30+
require.Equal(t, 5, n)
31+
require.NoError(t, err)
32+
33+
require.Len(t, fwc.writes, 2)
34+
require.EqualValues(t, "one", fwc.writes[0])
35+
require.EqualValues(t, "two", fwc.writes[1])
36+
}
37+
38+
type fakeWriteCloser struct {
39+
writes [][]byte
40+
closed bool
41+
err error
42+
}
43+
44+
func (f *fakeWriteCloser) Write(p []byte) (n int, err error) {
45+
q := make([]byte, len(p))
46+
copy(q, p)
47+
f.writes = append(f.writes, q)
48+
return len(p), nil
49+
}
50+
51+
func (f *fakeWriteCloser) Close() error {
52+
f.closed = true
53+
return f.err
54+
}

cli/ssh.go

+4-2
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import (
2828

2929
"github.com/coder/coder/v2/cli/clibase"
3030
"github.com/coder/coder/v2/cli/cliui"
31+
"github.com/coder/coder/v2/cli/cliutil"
3132
"github.com/coder/coder/v2/coderd/autobuild/notify"
3233
"github.com/coder/coder/v2/coderd/util/ptr"
3334
"github.com/coder/coder/v2/codersdk"
@@ -114,12 +115,13 @@ func (r *RootCmd) ssh() *clibase.Cmd {
114115
if err != nil {
115116
return xerrors.Errorf("error opening %s for logging: %w", logDirPath, err)
116117
}
118+
dc := cliutil.DiscardAfterClose(logFile)
117119
go func() {
118120
wg.Wait()
119-
_ = logFile.Close()
121+
_ = dc.Close()
120122
}()
121123

122-
logger = slog.Make(sloghuman.Sink(logFile))
124+
logger = logger.AppendSinks(sloghuman.Sink(dc))
123125
if r.verbose {
124126
logger = logger.Leveled(slog.LevelDebug)
125127
}

cli/vscodessh.go

+5-3
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"cdr.dev/slog/sloggers/sloghuman"
2222

2323
"github.com/coder/coder/v2/cli/clibase"
24+
"github.com/coder/coder/v2/cli/cliutil"
2425
"github.com/coder/coder/v2/codersdk"
2526
)
2627

@@ -137,15 +138,16 @@ func (r *RootCmd) vscodeSSH() *clibase.Cmd {
137138
// command via the ProxyCommand SSH option.
138139
pid := os.Getppid()
139140

140-
var logger slog.Logger
141+
logger := slog.Make()
141142
if logDir != "" {
142143
logFilePath := filepath.Join(logDir, fmt.Sprintf("%d.log", pid))
143144
logFile, err := fs.OpenFile(logFilePath, os.O_CREATE|os.O_WRONLY, 0o600)
144145
if err != nil {
145146
return xerrors.Errorf("open log file %q: %w", logFilePath, err)
146147
}
147-
defer logFile.Close()
148-
logger = slog.Make(sloghuman.Sink(logFile)).Leveled(slog.LevelDebug)
148+
dc := cliutil.DiscardAfterClose(logFile)
149+
defer dc.Close()
150+
logger = logger.AppendSinks(sloghuman.Sink(dc)).Leveled(slog.LevelDebug)
149151
}
150152
if r.disableDirect {
151153
logger.Info(ctx, "direct connections disabled")

0 commit comments

Comments
 (0)