Skip to content

Commit 495a38e

Browse files
committed
Merge remote-tracking branch 'origin/main' into agent-metadata
2 parents baa157f + b439c3e commit 495a38e

34 files changed

+812
-306
lines changed

agent/agent.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -988,6 +988,7 @@ func (a *agent) init(ctx context.Context) {
988988
_ = session.Exit(MagicSessionErrorCode)
989989
return
990990
}
991+
_ = session.Exit(0)
991992
},
992993
HostSigners: []ssh.Signer{randomSigner},
993994
LocalPortForwardingCallback: func(ctx ssh.Context, destinationHost string, destinationPort uint32) bool {
@@ -1240,7 +1241,9 @@ func (a *agent) handleSSHSession(session ssh.Session) (retErr error) {
12401241
if err != nil {
12411242
return xerrors.Errorf("start command: %w", err)
12421243
}
1244+
var wg sync.WaitGroup
12431245
defer func() {
1246+
defer wg.Wait()
12441247
closeErr := ptty.Close()
12451248
if closeErr != nil {
12461249
a.logger.Warn(ctx, "failed to close tty", slog.Error(closeErr))
@@ -1257,10 +1260,16 @@ func (a *agent) handleSSHSession(session ssh.Session) (retErr error) {
12571260
}
12581261
}
12591262
}()
1263+
// We don't add input copy to wait group because
1264+
// it won't return until the session is closed.
12601265
go func() {
12611266
_, _ = io.Copy(ptty.Input(), session)
12621267
}()
1268+
wg.Add(1)
12631269
go func() {
1270+
// Ensure data is flushed to session on command exit, if we
1271+
// close the session too soon, we might lose data.
1272+
defer wg.Done()
12641273
_, _ = io.Copy(session, ptty.Output())
12651274
}()
12661275
err = process.Wait()

agent/agent_test.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,57 @@ func TestAgent_Session_TTY_Hushlogin(t *testing.T) {
349349
require.NotContains(t, stdout.String(), wantNotMOTD, "should not show motd")
350350
}
351351

352+
func TestAgent_Session_TTY_FastCommandHasOutput(t *testing.T) {
353+
t.Parallel()
354+
if runtime.GOOS == "windows" {
355+
// This might be our implementation, or ConPTY itself.
356+
// It's difficult to find extensive tests for it, so
357+
// it seems like it could be either.
358+
t.Skip("ConPTY appears to be inconsistent on Windows.")
359+
}
360+
361+
// This test is here to prevent regressions where quickly executing
362+
// commands (with TTY) don't flush their output to the SSH session.
363+
//
364+
// See: https://github.com/coder/coder/issues/6656
365+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
366+
defer cancel()
367+
//nolint:dogsled
368+
conn, _, _, _, _ := setupAgent(t, agentsdk.Metadata{}, 0)
369+
sshClient, err := conn.SSHClient(ctx)
370+
require.NoError(t, err)
371+
defer sshClient.Close()
372+
373+
ptty := ptytest.New(t)
374+
375+
var stdout bytes.Buffer
376+
// NOTE(mafredri): Increase iterations to increase chance of failure,
377+
// assuming bug is present.
378+
// Using 1000 iterations is basically a guaranteed failure (but let's
379+
// not increase test times needlessly).
380+
for i := 0; i < 5; i++ {
381+
func() {
382+
stdout.Reset()
383+
384+
session, err := sshClient.NewSession()
385+
require.NoError(t, err)
386+
defer session.Close()
387+
err = session.RequestPty("xterm", 128, 128, ssh.TerminalModes{})
388+
require.NoError(t, err)
389+
390+
session.Stdout = &stdout
391+
session.Stderr = ptty.Output()
392+
session.Stdin = ptty.Input()
393+
err = session.Start("echo wazzup")
394+
require.NoError(t, err)
395+
396+
err = session.Wait()
397+
require.NoError(t, err)
398+
require.Contains(t, stdout.String(), "wazzup", "should output greeting")
399+
}()
400+
}
401+
}
402+
352403
//nolint:paralleltest // This test reserves a port.
353404
func TestAgent_TCPLocalForwarding(t *testing.T) {
354405
random, err := net.Listen("tcp", "127.0.0.1:0")

cli/cliui/cliui.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,10 @@ var Styles = struct {
4949
Keyword: defaultStyles.Keyword,
5050
Paragraph: defaultStyles.Paragraph,
5151
Placeholder: lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#585858", Dark: "#4d46b3"}),
52-
Prompt: defaultStyles.Prompt.Foreground(lipgloss.AdaptiveColor{Light: "#9B9B9B", Dark: "#5C5C5C"}),
53-
FocusedPrompt: defaultStyles.FocusedPrompt.Foreground(lipgloss.Color("#651fff")),
52+
Prompt: defaultStyles.Prompt.Copy().Foreground(lipgloss.AdaptiveColor{Light: "#9B9B9B", Dark: "#5C5C5C"}),
53+
FocusedPrompt: defaultStyles.FocusedPrompt.Copy().Foreground(lipgloss.Color("#651fff")),
5454
Fuchsia: defaultStyles.SelectedMenuItem.Copy(),
55-
Logo: defaultStyles.Logo.SetString("Coder"),
55+
Logo: defaultStyles.Logo.Copy().SetString("Coder"),
5656
Warn: lipgloss.NewStyle().Foreground(
5757
lipgloss.AdaptiveColor{Light: "#04B575", Dark: "#ECFD65"},
5858
),

cli/ssh.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,13 @@ func (r *RootCmd) ssh() *clibase.Cmd {
8282
if xerrors.Is(err, context.Canceled) {
8383
return cliui.Canceled
8484
}
85-
if xerrors.Is(err, cliui.AgentStartError) {
86-
return xerrors.New("Agent startup script exited with non-zero status, use --no-wait to login anyway.")
85+
if !xerrors.Is(err, cliui.AgentStartError) {
86+
return xerrors.Errorf("await agent: %w", err)
8787
}
88-
return xerrors.Errorf("await agent: %w", err)
88+
89+
// We don't want to fail on a startup script error because it's
90+
// natural that the user will want to fix the script and try again.
91+
// We don't print the error because cliui.Agent does that for us.
8992
}
9093

9194
conn, err := client.DialWorkspaceAgent(ctx, workspaceAgent.ID, &codersdk.DialWorkspaceAgentOptions{})

coderd/database/dbauthz/querier.go

Lines changed: 61 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,11 +88,57 @@ func (q *querier) GetAuditLogsOffset(ctx context.Context, arg database.GetAuditL
8888
}
8989

9090
func (q *querier) GetFileByHashAndCreator(ctx context.Context, arg database.GetFileByHashAndCreatorParams) (database.File, error) {
91-
return fetch(q.log, q.auth, q.db.GetFileByHashAndCreator)(ctx, arg)
91+
file, err := q.db.GetFileByHashAndCreator(ctx, arg)
92+
if err != nil {
93+
return database.File{}, err
94+
}
95+
err = q.authorizeContext(ctx, rbac.ActionRead, file)
96+
if err != nil {
97+
// Check the user's access to the file's templates.
98+
if q.authorizeUpdateFileTemplate(ctx, file) != nil {
99+
return database.File{}, err
100+
}
101+
}
102+
103+
return file, nil
92104
}
93105

94106
func (q *querier) GetFileByID(ctx context.Context, id uuid.UUID) (database.File, error) {
95-
return fetch(q.log, q.auth, q.db.GetFileByID)(ctx, id)
107+
file, err := q.db.GetFileByID(ctx, id)
108+
if err != nil {
109+
return database.File{}, err
110+
}
111+
err = q.authorizeContext(ctx, rbac.ActionRead, file)
112+
if err != nil {
113+
// Check the user's access to the file's templates.
114+
if q.authorizeUpdateFileTemplate(ctx, file) != nil {
115+
return database.File{}, err
116+
}
117+
}
118+
119+
return file, nil
120+
}
121+
122+
// authorizeReadFile is a hotfix for the fact that file permissions are
123+
// independent of template permissions. This function checks if the user has
124+
// update access to any of the file's templates.
125+
func (q *querier) authorizeUpdateFileTemplate(ctx context.Context, file database.File) error {
126+
tpls, err := q.db.GetFileTemplates(ctx, file.ID)
127+
if err != nil {
128+
return err
129+
}
130+
// There __should__ only be 1 template per file, but there can be more than
131+
// 1, so check them all.
132+
for _, tpl := range tpls {
133+
// If the user has update access to any template, they have read access to the file.
134+
if err := q.authorizeContext(ctx, rbac.ActionUpdate, tpl); err == nil {
135+
return nil
136+
}
137+
}
138+
139+
return NotAuthorizedError{
140+
Err: xerrors.Errorf("not authorized to read file %s", file.ID),
141+
}
96142
}
97143

98144
func (q *querier) InsertFile(ctx context.Context, arg database.InsertFileParams) (database.File, error) {
@@ -859,11 +905,22 @@ func (q *querier) UpdateTemplateScheduleByID(ctx context.Context, arg database.U
859905
}
860906

861907
func (q *querier) UpdateTemplateVersionByID(ctx context.Context, arg database.UpdateTemplateVersionByIDParams) (database.TemplateVersion, error) {
862-
template, err := q.db.GetTemplateByID(ctx, arg.TemplateID.UUID)
908+
// An actor is allowed to update the template version if they are authorized to update the template.
909+
tv, err := q.db.GetTemplateVersionByID(ctx, arg.ID)
863910
if err != nil {
864911
return database.TemplateVersion{}, err
865912
}
866-
if err := q.authorizeContext(ctx, rbac.ActionUpdate, template); err != nil {
913+
var obj rbac.Objecter
914+
if !tv.TemplateID.Valid {
915+
obj = rbac.ResourceTemplate.InOrg(tv.OrganizationID)
916+
} else {
917+
tpl, err := q.db.GetTemplateByID(ctx, tv.TemplateID.UUID)
918+
if err != nil {
919+
return database.TemplateVersion{}, err
920+
}
921+
obj = tpl
922+
}
923+
if err := q.authorizeContext(ctx, rbac.ActionUpdate, obj); err != nil {
867924
return database.TemplateVersion{}, err
868925
}
869926
return q.db.UpdateTemplateVersionByID(ctx, arg)

coderd/database/dbauthz/system.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,13 @@ import (
1010
"github.com/coder/coder/coderd/rbac"
1111
)
1212

13+
func (q *querier) GetFileTemplates(ctx context.Context, fileID uuid.UUID) ([]database.GetFileTemplatesRow, error) {
14+
if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil {
15+
return nil, err
16+
}
17+
return q.db.GetFileTemplates(ctx, fileID)
18+
}
19+
1320
// GetWorkspaceAppsByAgentIDs
1421
// The workspace/job is already fetched.
1522
func (q *querier) GetWorkspaceAppsByAgentIDs(ctx context.Context, ids []uuid.UUID) ([]database.WorkspaceApp, error) {

coderd/database/dbfake/databasefake.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -686,6 +686,47 @@ func (q *fakeQuerier) GetFileByID(_ context.Context, id uuid.UUID) (database.Fil
686686
return database.File{}, sql.ErrNoRows
687687
}
688688

689+
func (q *fakeQuerier) GetFileTemplates(_ context.Context, id uuid.UUID) ([]database.GetFileTemplatesRow, error) {
690+
q.mutex.RLock()
691+
defer q.mutex.RUnlock()
692+
693+
rows := make([]database.GetFileTemplatesRow, 0)
694+
var file database.File
695+
for _, f := range q.files {
696+
if f.ID == id {
697+
file = f
698+
break
699+
}
700+
}
701+
if file.Hash == "" {
702+
return rows, nil
703+
}
704+
705+
for _, job := range q.provisionerJobs {
706+
if job.FileID == id {
707+
for _, version := range q.templateVersions {
708+
if version.JobID == job.ID {
709+
for _, template := range q.templates {
710+
if template.ID == version.TemplateID.UUID {
711+
rows = append(rows, database.GetFileTemplatesRow{
712+
FileID: file.ID,
713+
FileCreatedBy: file.CreatedBy,
714+
TemplateID: template.ID,
715+
TemplateOrganizationID: template.OrganizationID,
716+
TemplateCreatedBy: template.CreatedBy,
717+
UserACL: template.UserACL,
718+
GroupACL: template.GroupACL,
719+
})
720+
}
721+
}
722+
}
723+
}
724+
}
725+
}
726+
727+
return rows, nil
728+
}
729+
689730
func (q *fakeQuerier) GetUserByEmailOrUsername(_ context.Context, arg database.GetUserByEmailOrUsernameParams) (database.User, error) {
690731
if err := validateDatabaseType(arg); err != nil {
691732
return database.User{}, err

coderd/database/modelmethods.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,13 @@ func (t Template) RBACObject() rbac.Object {
8888
WithGroupACL(t.GroupACL)
8989
}
9090

91+
func (t GetFileTemplatesRow) RBACObject() rbac.Object {
92+
return rbac.ResourceTemplate.WithID(t.TemplateID).
93+
InOrg(t.TemplateOrganizationID).
94+
WithACLUserList(t.UserACL).
95+
WithGroupACL(t.GroupACL)
96+
}
97+
9198
func (t Template) DeepCopy() Template {
9299
cpy := t
93100
cpy.UserACL = maps.Clone(t.UserACL)

coderd/database/querier.go

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/queries.sql.go

Lines changed: 69 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)