diff --git a/cli/root.go b/cli/root.go index 254368a01d97d..4fc7958772bac 100644 --- a/cli/root.go +++ b/cli/root.go @@ -19,6 +19,8 @@ import ( "golang.org/x/xerrors" + "cdr.dev/slog" + "github.com/charmbracelet/lipgloss" "github.com/kirsle/configdir" "github.com/mattn/go-isatty" @@ -179,6 +181,21 @@ func Root(subcommands []*cobra.Command) *cobra.Command { return cmd } +type contextKey int + +const ( + contextKeyLogger contextKey = iota +) + +func ContextWithLogger(ctx context.Context, l slog.Logger) context.Context { + return context.WithValue(ctx, contextKeyLogger, l) +} + +func LoggerFromContext(ctx context.Context) (slog.Logger, bool) { + l, ok := ctx.Value(contextKeyLogger).(slog.Logger) + return l, ok +} + // fixUnknownSubcommandError modifies the provided commands so that the // ones with subcommands output the correct error message when an // unknown subcommand is invoked. diff --git a/cli/speedtest.go b/cli/speedtest.go index 3ab9822b2a71e..2fc62227fdd58 100644 --- a/cli/speedtest.go +++ b/cli/speedtest.go @@ -51,7 +51,10 @@ func speedtest() *cobra.Command { if err != nil && !xerrors.Is(err, cliui.AgentStartError) { return xerrors.Errorf("await agent: %w", err) } - logger := slog.Make(sloghuman.Sink(cmd.ErrOrStderr())) + logger, ok := LoggerFromContext(ctx) + if !ok { + logger = slog.Make(sloghuman.Sink(cmd.ErrOrStderr())) + } if cliflag.IsSetBool(cmd, varVerbose) { logger = logger.Leveled(slog.LevelDebug) } diff --git a/cli/speedtest_test.go b/cli/speedtest_test.go index cdb70e97b558f..3cb2956975525 100644 --- a/cli/speedtest_test.go +++ b/cli/speedtest_test.go @@ -10,6 +10,7 @@ import ( "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/agent" + "github.com/coder/coder/cli" "github.com/coder/coder/cli/clitest" "github.com/coder/coder/coderd/coderdtest" "github.com/coder/coder/codersdk" @@ -51,10 +52,12 @@ func TestSpeedtest(t *testing.T) { clitest.SetupConfig(t, client, root) pty := ptytest.New(t) cmd.SetOut(pty.Output()) + cmd.SetErr(pty.Output()) ctx, cancel = context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() + ctx = cli.ContextWithLogger(ctx, slogtest.Make(t, nil).Named("speedtest").Leveled(slog.LevelDebug)) cmdDone := tGo(t, func() { err := cmd.ExecuteContext(ctx) assert.NoError(t, err)