Skip to content

Commit c1423d4

Browse files
committed
Merge remote-tracking branch 'origin/main' into cj/rbac_table_driven
2 parents b22b723 + 63d1465 commit c1423d4

File tree

16 files changed

+344
-50
lines changed

16 files changed

+344
-50
lines changed

.github/workflows/coder.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,10 @@ jobs:
158158
terraform_version: 1.1.2
159159
terraform_wrapper: false
160160

161+
- name: Install socat
162+
if: runner.os == 'Linux'
163+
run: sudo apt-get install -y socat
164+
161165
- name: Test with Mock Database
162166
shell: bash
163167
env:

agent/agent.go

Lines changed: 35 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ import (
2121
"github.com/coder/coder/pty"
2222
"github.com/coder/retry"
2323

24+
"github.com/pkg/sftp"
25+
2426
"github.com/gliderlabs/ssh"
2527
gossh "golang.org/x/crypto/ssh"
2628
"golang.org/x/xerrors"
@@ -121,7 +123,7 @@ func (a *agent) handlePeerConn(ctx context.Context, conn *peer.Conn) {
121123

122124
switch channel.Protocol() {
123125
case "ssh":
124-
a.sshServer.HandleConn(channel.NetConn())
126+
go a.sshServer.HandleConn(channel.NetConn())
125127
default:
126128
a.options.Logger.Warn(ctx, "unhandled protocol from channel",
127129
slog.F("protocol", channel.Protocol()),
@@ -146,7 +148,10 @@ func (a *agent) init(ctx context.Context) {
146148
sshLogger := a.options.Logger.Named("ssh-server")
147149
forwardHandler := &ssh.ForwardedTCPHandler{}
148150
a.sshServer = &ssh.Server{
149-
ChannelHandlers: ssh.DefaultChannelHandlers,
151+
ChannelHandlers: map[string]ssh.ChannelHandler{
152+
"direct-tcpip": ssh.DirectTCPIPHandler,
153+
"session": ssh.DefaultSessionHandler,
154+
},
150155
ConnectionFailedCallback: func(conn net.Conn, err error) {
151156
sshLogger.Info(ctx, "ssh connection ended", slog.Error(err))
152157
},
@@ -185,61 +190,54 @@ func (a *agent) init(ctx context.Context) {
185190
NoClientAuth: true,
186191
}
187192
},
193+
SubsystemHandlers: map[string]ssh.SubsystemHandler{
194+
"sftp": func(session ssh.Session) {
195+
server, err := sftp.NewServer(session)
196+
if err != nil {
197+
a.options.Logger.Debug(session.Context(), "initialize sftp server", slog.Error(err))
198+
return
199+
}
200+
defer server.Close()
201+
err = server.Serve()
202+
if errors.Is(err, io.EOF) {
203+
return
204+
}
205+
a.options.Logger.Debug(session.Context(), "sftp server exited with error", slog.Error(err))
206+
},
207+
},
188208
}
189209

190210
go a.run(ctx)
191211
}
192212

193213
func (a *agent) handleSSHSession(session ssh.Session) error {
194-
var (
195-
command string
196-
args = []string{}
197-
err error
198-
)
199-
200214
currentUser, err := user.Current()
201215
if err != nil {
202216
return xerrors.Errorf("get current user: %w", err)
203217
}
204218
username := currentUser.Username
205219

220+
shell, err := usershell.Get(username)
221+
if err != nil {
222+
return xerrors.Errorf("get user shell: %w", err)
223+
}
224+
206225
// gliderlabs/ssh returns a command slice of zero
207226
// when a shell is requested.
227+
command := session.RawCommand()
208228
if len(session.Command()) == 0 {
209-
command, err = usershell.Get(username)
210-
if err != nil {
211-
return xerrors.Errorf("get user shell: %w", err)
212-
}
213-
} else {
214-
command = session.Command()[0]
215-
if len(session.Command()) > 1 {
216-
args = session.Command()[1:]
217-
}
229+
command = shell
218230
}
219231

220-
signals := make(chan ssh.Signal)
221-
breaks := make(chan bool)
222-
defer close(signals)
223-
defer close(breaks)
224-
go func() {
225-
for {
226-
select {
227-
case <-session.Context().Done():
228-
return
229-
// Ignore signals and breaks for now!
230-
case <-signals:
231-
case <-breaks:
232-
}
233-
}
234-
}()
235-
236-
cmd := exec.CommandContext(session.Context(), command, args...)
232+
// OpenSSH executes all commands with the users current shell.
233+
// We replicate that behavior for IDE support.
234+
cmd := exec.CommandContext(session.Context(), shell, "-c", command)
237235
cmd.Env = append(os.Environ(), session.Environ()...)
238236
executablePath, err := os.Executable()
239237
if err != nil {
240238
return xerrors.Errorf("getting os executable: %w", err)
241239
}
242-
cmd.Env = append(session.Environ(), fmt.Sprintf(`GIT_SSH_COMMAND="%s gitssh --"`, executablePath))
240+
cmd.Env = append(cmd.Env, fmt.Sprintf(`GIT_SSH_COMMAND="%s gitssh --"`, executablePath))
243241

244242
sshPty, windowSize, isPty := session.Pty()
245243
if isPty {
@@ -268,7 +266,7 @@ func (a *agent) handleSSHSession(session ssh.Session) error {
268266
}
269267

270268
cmd.Stdout = session
271-
cmd.Stderr = session
269+
cmd.Stderr = session.Stderr()
272270
// This blocks forever until stdin is received if we don't
273271
// use StdinPipe. It's unknown what causes this.
274272
stdinPipe, err := cmd.StdinPipe()
@@ -282,8 +280,7 @@ func (a *agent) handleSSHSession(session ssh.Session) error {
282280
if err != nil {
283281
return xerrors.Errorf("start: %w", err)
284282
}
285-
_ = cmd.Wait()
286-
return nil
283+
return cmd.Wait()
287284
}
288285

289286
// isClosed returns whether the API is closed or not.

agent/agent_test.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,16 @@ import (
55
"fmt"
66
"io"
77
"net"
8+
"os"
89
"os/exec"
10+
"path/filepath"
911
"runtime"
1012
"strconv"
1113
"strings"
1214
"testing"
1315

1416
"github.com/pion/webrtc/v3"
17+
"github.com/pkg/sftp"
1518
"github.com/stretchr/testify/require"
1619
"go.uber.org/goleak"
1720
"golang.org/x/crypto/ssh"
@@ -94,6 +97,7 @@ func TestAgent(t *testing.T) {
9497

9598
local, err := net.Listen("tcp", "127.0.0.1:0")
9699
require.NoError(t, err)
100+
defer local.Close()
97101
tcpAddr, valid = local.Addr().(*net.TCPAddr)
98102
require.True(t, valid)
99103
localPort := tcpAddr.Port
@@ -113,6 +117,21 @@ func TestAgent(t *testing.T) {
113117
conn.Close()
114118
<-done
115119
})
120+
121+
t.Run("SFTP", func(t *testing.T) {
122+
t.Parallel()
123+
sshClient, err := setupAgent(t).SSHClient()
124+
require.NoError(t, err)
125+
client, err := sftp.NewClient(sshClient)
126+
require.NoError(t, err)
127+
tempFile := filepath.Join(t.TempDir(), "sftp")
128+
file, err := client.Create(tempFile)
129+
require.NoError(t, err)
130+
err = file.Close()
131+
require.NoError(t, err)
132+
_, err = os.Stat(tempFile)
133+
require.NoError(t, err)
134+
})
116135
}
117136

118137
func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exec.Cmd {

cli/cliui/resources_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@ import (
44
"testing"
55
"time"
66

7+
"github.com/stretchr/testify/require"
8+
79
"github.com/coder/coder/cli/cliui"
810
"github.com/coder/coder/coderd/database"
911
"github.com/coder/coder/codersdk"
1012
"github.com/coder/coder/pty/ptytest"
11-
"github.com/stretchr/testify/require"
1213
)
1314

1415
func TestWorkspaceResources(t *testing.T) {

coderd/coderd.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ func New(options *Options) (http.Handler, func()) {
150150
r.Route("/{user}", func(r chi.Router) {
151151
r.Use(httpmw.ExtractUserParam(options.Database))
152152
r.Get("/", api.userByName)
153+
r.Put("/profile", api.putUserProfile)
153154
r.Get("/organizations", api.organizationsByUser)
154155
r.Post("/organizations", api.postOrganizationsByUser)
155156
r.Post("/keys", api.postAPIKey)

coderd/database/databasefake/databasefake.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1050,6 +1050,23 @@ func (q *fakeQuerier) InsertUser(_ context.Context, arg database.InsertUserParam
10501050
return user, nil
10511051
}
10521052

1053+
func (q *fakeQuerier) UpdateUserProfile(_ context.Context, arg database.UpdateUserProfileParams) (database.User, error) {
1054+
q.mutex.Lock()
1055+
defer q.mutex.Unlock()
1056+
1057+
for index, user := range q.users {
1058+
if user.ID != arg.ID {
1059+
continue
1060+
}
1061+
user.Name = arg.Name
1062+
user.Email = arg.Email
1063+
user.Username = arg.Username
1064+
q.users[index] = user
1065+
return user, nil
1066+
}
1067+
return database.User{}, sql.ErrNoRows
1068+
}
1069+
10531070
func (q *fakeQuerier) InsertWorkspace(_ context.Context, arg database.InsertWorkspaceParams) (database.Workspace, error) {
10541071
q.mutex.Lock()
10551072
defer q.mutex.Unlock()

coderd/database/querier.go

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/queries.sql.go

Lines changed: 43 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/queries/users.sql

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,14 @@ INSERT INTO
4040
)
4141
VALUES
4242
($1, $2, $3, $4, FALSE, $5, $6, $7, $8) RETURNING *;
43+
44+
-- name: UpdateUserProfile :one
45+
UPDATE
46+
users
47+
SET
48+
email = $2,
49+
"name" = $3,
50+
username = $4,
51+
updated_at = $5
52+
WHERE
53+
id = $1 RETURNING *;

coderd/users.go

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,70 @@ func (*api) userByName(rw http.ResponseWriter, r *http.Request) {
270270
render.JSON(rw, r, convertUser(user))
271271
}
272272

273+
func (api *api) putUserProfile(rw http.ResponseWriter, r *http.Request) {
274+
user := httpmw.UserParam(r)
275+
276+
var params codersdk.UpdateUserProfileRequest
277+
if !httpapi.Read(rw, r, &params) {
278+
return
279+
}
280+
281+
if params.Name == nil {
282+
params.Name = &user.Name
283+
}
284+
285+
existentUser, err := api.Database.GetUserByEmailOrUsername(r.Context(), database.GetUserByEmailOrUsernameParams{
286+
Email: params.Email,
287+
Username: params.Username,
288+
})
289+
isDifferentUser := existentUser.ID != user.ID
290+
291+
if err == nil && isDifferentUser {
292+
responseErrors := []httpapi.Error{}
293+
if existentUser.Email == params.Email {
294+
responseErrors = append(responseErrors, httpapi.Error{
295+
Field: "email",
296+
Code: "exists",
297+
})
298+
}
299+
if existentUser.Username == params.Username {
300+
responseErrors = append(responseErrors, httpapi.Error{
301+
Field: "username",
302+
Code: "exists",
303+
})
304+
}
305+
httpapi.Write(rw, http.StatusConflict, httpapi.Response{
306+
Message: fmt.Sprintf("user already exists"),
307+
Errors: responseErrors,
308+
})
309+
return
310+
}
311+
if !errors.Is(err, sql.ErrNoRows) && isDifferentUser {
312+
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
313+
Message: fmt.Sprintf("get user: %s", err),
314+
})
315+
return
316+
}
317+
318+
updatedUserProfile, err := api.Database.UpdateUserProfile(r.Context(), database.UpdateUserProfileParams{
319+
ID: user.ID,
320+
Name: *params.Name,
321+
Email: params.Email,
322+
Username: params.Username,
323+
UpdatedAt: database.Now(),
324+
})
325+
326+
if err != nil {
327+
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
328+
Message: fmt.Sprintf("patch user: %s", err.Error()),
329+
})
330+
return
331+
}
332+
333+
render.Status(r, http.StatusOK)
334+
render.JSON(rw, r, convertUser(updatedUserProfile))
335+
}
336+
273337
// Returns organizations the parameterized user has access to.
274338
func (api *api) organizationsByUser(rw http.ResponseWriter, r *http.Request) {
275339
user := httpmw.UserParam(r)
@@ -872,5 +936,6 @@ func convertUser(user database.User) codersdk.User {
872936
Email: user.Email,
873937
CreatedAt: user.CreatedAt,
874938
Username: user.Username,
939+
Name: user.Name,
875940
}
876941
}

0 commit comments

Comments
 (0)