From 051eddab13d6056cb9ceed77e9d884fb36df5298 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Thu, 3 Mar 2022 16:02:23 +0000 Subject: [PATCH 1/5] Add client for agent --- agent/agent.go | 8 +- agent/agent_test.go | 8 +- cli/workspaceagent.go | 56 ++++++++ coderd/cmd/root.go | 2 +- coderd/coderd.go | 13 +- coderd/coderdtest/coderdtest.go | 2 +- coderd/projectimport.go | 26 ---- coderd/provisionerdaemons.go | 9 +- coderd/provisionerjobs.go | 76 +++++++++++ coderd/workspaceagent.go | 150 +++++++++------------ coderd/workspaceagent_test.go | 174 ++++++------------------ coderd/workspaceagentauth.go | 147 +++++++++++++++++++++ coderd/workspaceagentauth_test.go | 182 ++++++++++++++++++++++++++ codersdk/client.go | 14 ++ codersdk/provisioners.go | 4 +- codersdk/provisioners_test.go | 4 +- codersdk/workspaceagent.go | 82 ++++++++++++ codersdk/workspaces.go | 13 ++ database/databasefake/databasefake.go | 13 ++ database/querier.go | 1 + database/query.sql | 8 ++ database/query.sql.go | 27 ++++ go.mod | 8 +- go.sum | 33 +++++ httpmw/workspaceagent.go | 65 +++++++++ httpmw/workspaceagent_test.go | 73 +++++++++++ httpmw/workspaceresourceparam.go | 64 +++++++++ httpmw/workspaceresourceparam_test.go | 109 +++++++++++++++ peerbroker/listen.go | 7 +- provisionersdk/agent.go | 16 ++- 30 files changed, 1124 insertions(+), 270 deletions(-) create mode 100644 cli/workspaceagent.go create mode 100644 coderd/workspaceagentauth.go create mode 100644 coderd/workspaceagentauth_test.go create mode 100644 httpmw/workspaceagent.go create mode 100644 httpmw/workspaceagent_test.go create mode 100644 httpmw/workspaceresourceparam.go create mode 100644 httpmw/workspaceresourceparam_test.go diff --git a/agent/agent.go b/agent/agent.go index 285efe3dc9836..d8dba42a47b0b 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -59,9 +59,9 @@ type Options struct { Logger slog.Logger } -type Dialer func(ctx context.Context) (*peerbroker.Listener, error) +type Dialer func(ctx context.Context, options *peer.ConnOptions) (*peerbroker.Listener, error) -func New(dialer Dialer, options *Options) io.Closer { +func New(dialer Dialer, options *peer.ConnOptions) io.Closer { ctx, cancelFunc := context.WithCancel(context.Background()) server := &server{ clientDialer: dialer, @@ -75,7 +75,7 @@ func New(dialer Dialer, options *Options) io.Closer { type server struct { clientDialer Dialer - options *Options + options *peer.ConnOptions closeCancel context.CancelFunc closeMutex sync.Mutex @@ -249,7 +249,7 @@ func (s *server) run(ctx context.Context) { // An exponential back-off occurs when the connection is failing to dial. // This is to prevent server spam in case of a coderd outage. for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); { - peerListener, err = s.clientDialer(ctx) + peerListener, err = s.clientDialer(ctx, s.options) if err != nil { if errors.Is(err, context.Canceled) { return diff --git a/agent/agent_test.go b/agent/agent_test.go index 662c054eae146..825600a47931e 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -94,11 +94,9 @@ func TestAgent(t *testing.T) { func setup(t *testing.T) proto.DRPCPeerBrokerClient { client, server := provisionersdk.TransportPipe() - closer := agent.New(func(ctx context.Context) (*peerbroker.Listener, error) { - return peerbroker.Listen(server, &peer.ConnOptions{ - Logger: slogtest.Make(t, nil), - }) - }, &agent.Options{ + closer := agent.New(func(ctx context.Context, opts *peer.ConnOptions) (*peerbroker.Listener, error) { + return peerbroker.Listen(server, opts) + }, &peer.ConnOptions{ Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug), }) t.Cleanup(func() { diff --git a/cli/workspaceagent.go b/cli/workspaceagent.go new file mode 100644 index 0000000000000..75a462c0a100e --- /dev/null +++ b/cli/workspaceagent.go @@ -0,0 +1,56 @@ +package cli + +import ( + "net/url" + "os" + + "github.com/coder/coder/agent" + "github.com/coder/coder/codersdk" + "github.com/powersj/whatsthis/pkg/cloud" + "github.com/spf13/cobra" + "golang.org/x/xerrors" +) + +func workspaceAgent() *cobra.Command { + return &cobra.Command{ + Use: "agent", + // This command isn't useful for users, and seems + // more likely to confuse. + Hidden: true, + RunE: func(cmd *cobra.Command, args []string) error { + coderURLRaw, exists := os.LookupEnv("CODER_URL") + if !exists { + return xerrors.New("CODER_URL must be set") + } + coderURL, err := url.Parse(coderURLRaw) + if err != nil { + return xerrors.Errorf("parse %q: %w", coderURLRaw, err) + } + client := codersdk.New(coderURL) + sessionToken, exists := os.LookupEnv("CODER_TOKEN") + if !exists { + probe, err := cloud.New() + if err != nil { + return xerrors.Errorf("probe cloud: %w", err) + } + if !probe.Detected { + return xerrors.Errorf("no valid authentication method found; set \"CODER_TOKEN\"") + } + switch { + case probe.GCP(): + response, err := client.AuthenticateWorkspaceAgentUsingGoogleCloudIdentity(cmd.Context(), "", nil) + if err != nil { + return xerrors.Errorf("authenticate workspace with gcp: %w", err) + } + sessionToken = response.SessionToken + default: + return xerrors.Errorf("%q authentication not supported; set \"CODER_TOKEN\" instead", probe.Name) + } + } + client.SessionToken = sessionToken + closer := agent.New(client.WorkspaceAgentServe, nil) + <-cmd.Context().Done() + return closer.Close() + }, + } +} diff --git a/coderd/cmd/root.go b/coderd/cmd/root.go index 162390898aa77..c5acf3bba3bd0 100644 --- a/coderd/cmd/root.go +++ b/coderd/cmd/root.go @@ -98,7 +98,7 @@ func newProvisionerDaemon(ctx context.Context, client *codersdk.Client, logger s if err != nil { return nil, err } - return provisionerd.New(client.ProvisionerDaemonClient, &provisionerd.Options{ + return provisionerd.New(client.ProvisionerDaemonServe, &provisionerd.Options{ Logger: logger, PollInterval: 50 * time.Millisecond, UpdateInterval: 50 * time.Millisecond, diff --git a/coderd/coderd.go b/coderd/coderd.go index 69a8432aa53bf..90a34787404a1 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -116,6 +116,10 @@ func New(options *Options) (http.Handler, func()) { r.Route("/authenticate", func(r chi.Router) { r.Post("/google-instance-identity", api.postAuthenticateWorkspaceAgentUsingGoogleInstanceIdentity) }) + r.Group(func(r chi.Router) { + r.Use(httpmw.ExtractWorkspaceAgent(options.Database)) + r.Get("/serve", api.workspaceAgentServe) + }) }) r.Route("/upload", func(r chi.Router) { @@ -134,7 +138,7 @@ func New(options *Options) (http.Handler, func()) { r.Get("/", api.provisionerJobByID) r.Get("/schemas", api.projectImportJobSchemasByID) r.Get("/parameters", api.projectImportJobParametersByID) - r.Get("/resources", api.projectImportJobResourcesByID) + r.Get("/resources", api.provisionerJobResourcesByID) r.Get("/logs", api.provisionerJobLogsByID) }) }) @@ -148,6 +152,13 @@ func New(options *Options) (http.Handler, func()) { r.Use(httpmw.ExtractProvisionerJobParam(options.Database)) r.Get("/", api.provisionerJobByID) r.Get("/logs", api.provisionerJobLogsByID) + r.Route("/resources", func(r chi.Router) { + r.Get("/", api.provisionerJobResourcesByID) + r.Route("/{workspaceresource}", func(r chi.Router) { + r.Use(httpmw.ExtractWorkspaceResourceParam(options.Database)) + r.Get("/agent", api.workspaceAgentConnectByResource) + }) + }) }) }) diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index 58dd9ceb53292..905bd85b879e2 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -125,7 +125,7 @@ func NewProvisionerDaemon(t *testing.T, client *codersdk.Client) io.Closer { require.NoError(t, err) }() - closer := provisionerd.New(client.ProvisionerDaemonClient, &provisionerd.Options{ + closer := provisionerd.New(client.ProvisionerDaemonServe, &provisionerd.Options{ Logger: slogtest.Make(t, nil).Named("provisionerd").Leveled(slog.LevelDebug), PollInterval: 50 * time.Millisecond, UpdateInterval: 50 * time.Millisecond, diff --git a/coderd/projectimport.go b/coderd/projectimport.go index 5ece718c99a90..9955e01ba3319 100644 --- a/coderd/projectimport.go +++ b/coderd/projectimport.go @@ -154,29 +154,3 @@ func (api *api) projectImportJobParametersByID(rw http.ResponseWriter, r *http.R render.Status(r, http.StatusOK) render.JSON(rw, r, values) } - -// Returns resources for an import job by ID. -func (api *api) projectImportJobResourcesByID(rw http.ResponseWriter, r *http.Request) { - job := httpmw.ProvisionerJobParam(r) - if !convertProvisionerJob(job).Status.Completed() { - httpapi.Write(rw, http.StatusPreconditionFailed, httpapi.Response{ - Message: "Job hasn't completed!", - }) - return - } - resources, err := api.Database.GetProvisionerJobResourcesByJobID(r.Context(), job.ID) - if errors.Is(err, sql.ErrNoRows) { - err = nil - } - if err != nil { - httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ - Message: fmt.Sprintf("get project import job resources: %s", err), - }) - return - } - if resources == nil { - resources = []database.ProvisionerJobResource{} - } - render.Status(r, http.StatusOK) - render.JSON(rw, r, resources) -} diff --git a/coderd/provisionerdaemons.go b/coderd/provisionerdaemons.go index d30b876b99547..b416ff1f66c93 100644 --- a/coderd/provisionerdaemons.go +++ b/coderd/provisionerdaemons.go @@ -590,12 +590,19 @@ func insertProvisionerJobResource(ctx context.Context, db database.Store, jobID Valid: true, } } + authToken := uuid.New() + if protoResource.Agent.GetToken() != "" { + authToken, err = uuid.Parse(protoResource.Agent.GetToken()) + if err != nil { + return xerrors.Errorf("invalid auth token format; must be uuid: %w", err) + } + } _, err := db.InsertProvisionerJobAgent(ctx, database.InsertProvisionerJobAgentParams{ ID: resource.AgentID.UUID, CreatedAt: database.Now(), ResourceID: resource.ID, - AuthToken: uuid.New(), + AuthToken: authToken, AuthInstanceID: instanceID, EnvironmentVariables: env, StartupScript: sql.NullString{ diff --git a/coderd/provisionerjobs.go b/coderd/provisionerjobs.go index 8c343256f15ee..e31f187d15365 100644 --- a/coderd/provisionerjobs.go +++ b/coderd/provisionerjobs.go @@ -12,6 +12,7 @@ import ( "github.com/go-chi/render" "github.com/google/uuid" + "golang.org/x/xerrors" "cdr.dev/slog" @@ -64,6 +65,7 @@ type ProvisionerJobResource struct { Transition database.WorkspaceTransition `json:"workspace_transition"` Type string `json:"type"` Name string `json:"name"` + Agent *ProvisionerJobAgent `json:"agent,omitempty"` } type ProvisionerJobAgent struct { @@ -238,6 +240,49 @@ func (api *api) provisionerJobLogsByID(rw http.ResponseWriter, r *http.Request) } } +func (api *api) provisionerJobResourcesByID(rw http.ResponseWriter, r *http.Request) { + job := httpmw.ProvisionerJobParam(r) + if !convertProvisionerJob(job).Status.Completed() { + httpapi.Write(rw, http.StatusPreconditionFailed, httpapi.Response{ + Message: "Job hasn't completed!", + }) + return + } + resources, err := api.Database.GetProvisionerJobResourcesByJobID(r.Context(), job.ID) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get provisioner job resources: %s", err), + }) + return + } + apiResources := make([]ProvisionerJobResource, 0) + for _, resource := range resources { + if !resource.AgentID.Valid { + apiResources = append(apiResources, convertProvisionerJobResource(resource, nil)) + continue + } + // TODO: This should be combined. + agents, err := api.Database.GetProvisionerJobAgentsByResourceIDs(r.Context(), []uuid.UUID{resource.ID}) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get provisioner job agent: %s", err), + }) + return + } + agent := agents[0] + apiAgent, err := convertProvisionerJobAgent(agent) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("convert provisioner job agent: %s", err), + }) + return + } + apiResources = append(apiResources, convertProvisionerJobResource(resource, &apiAgent)) + } + render.Status(r, http.StatusOK) + render.JSON(rw, r, apiResources) +} + func convertProvisionerJobLog(provisionerJobLog database.ProvisionerJobLog) ProvisionerJobLog { return ProvisionerJobLog{ ID: provisionerJobLog.ID, @@ -291,6 +336,37 @@ func convertProvisionerJob(provisionerJob database.ProvisionerJob) ProvisionerJo return job } +func convertProvisionerJobResource(resource database.ProvisionerJobResource, agent *ProvisionerJobAgent) ProvisionerJobResource { + return ProvisionerJobResource{ + ID: resource.ID, + CreatedAt: resource.CreatedAt, + JobID: resource.JobID, + Transition: resource.Transition, + Type: resource.Type, + Name: resource.Name, + Agent: agent, + } +} + +func convertProvisionerJobAgent(agent database.ProvisionerJobAgent) (ProvisionerJobAgent, error) { + var envs map[string]string + if agent.EnvironmentVariables.Valid { + err := json.Unmarshal(agent.EnvironmentVariables.RawMessage, &envs) + if err != nil { + return ProvisionerJobAgent{}, xerrors.Errorf("unmarshal: %w", err) + } + } + return ProvisionerJobAgent{ + ID: agent.ID, + CreatedAt: agent.CreatedAt, + UpdatedAt: agent.UpdatedAt.Time, + ResourceID: agent.ResourceID, + InstanceID: agent.AuthInstanceID.String, + StartupScript: agent.StartupScript.String, + EnvironmentVariables: envs, + }, nil +} + func provisionerJobLogsChannel(jobID uuid.UUID) string { return fmt.Sprintf("provisioner-log-logs:%s", jobID) } diff --git a/coderd/workspaceagent.go b/coderd/workspaceagent.go index 2e45046fa3ff0..4ebc77b100bd1 100644 --- a/coderd/workspaceagent.go +++ b/coderd/workspaceagent.go @@ -1,127 +1,103 @@ package coderd import ( - "database/sql" - "encoding/json" - "errors" "fmt" + "io" "net/http" - "github.com/go-chi/render" + "github.com/google/uuid" + "github.com/hashicorp/yamux" + "nhooyr.io/websocket" - "github.com/coder/coder/database" "github.com/coder/coder/httpapi" - - "github.com/mitchellh/mapstructure" + "github.com/coder/coder/httpmw" + "github.com/coder/coder/peerbroker" + "github.com/coder/coder/peerbroker/proto" + "github.com/coder/coder/provisionersdk" ) -type GoogleInstanceIdentityToken struct { - JSONWebToken string `json:"json_web_token" validate:"required"` -} +func (api *api) workspaceAgentConnectByResource(rw http.ResponseWriter, r *http.Request) { + api.websocketWaitGroup.Add(1) + defer api.websocketWaitGroup.Done() -// WorkspaceAgentAuthenticateResponse is returned when an instance ID -// has been exchanged for a session token. -type WorkspaceAgentAuthenticateResponse struct { - SessionToken string `json:"session_token"` -} - -// Google Compute Engine supports instance identity verification: -// https://cloud.google.com/compute/docs/instances/verifying-instance-identity -// Using this, we can exchange a signed instance payload for an agent token. -func (api *api) postAuthenticateWorkspaceAgentUsingGoogleInstanceIdentity(rw http.ResponseWriter, r *http.Request) { - var req GoogleInstanceIdentityToken - if !httpapi.Read(rw, r, &req) { - return - } - - // We leave the audience blank. It's not important we validate who made the token. - payload, err := api.GoogleTokenValidator.Validate(r.Context(), req.JSONWebToken, "") - if err != nil { - httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{ - Message: fmt.Sprintf("validate: %s", err), - }) - return - } - claims := struct { - Google struct { - ComputeEngine struct { - InstanceID string `mapstructure:"instance_id"` - } `mapstructure:"compute_engine"` - } `mapstructure:"google"` - }{} - err = mapstructure.Decode(payload.Claims, &claims) - if err != nil { + resource := httpmw.WorkspaceResourceParam(r) + if !resource.AgentID.Valid { httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{ - Message: fmt.Sprintf("decode jwt claims: %s", err), - }) - return - } - agent, err := api.Database.GetProvisionerJobAgentByInstanceID(r.Context(), claims.Google.ComputeEngine.InstanceID) - if errors.Is(err, sql.ErrNoRows) { - httpapi.Write(rw, http.StatusNotFound, httpapi.Response{ - Message: fmt.Sprintf("instance with id %q not found", claims.Google.ComputeEngine.InstanceID), + Message: "resource doesn't have an agent", }) return } + agents, err := api.Database.GetProvisionerJobAgentsByResourceIDs(r.Context(), []uuid.UUID{resource.ID}) if err != nil { - httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{ Message: fmt.Sprintf("get provisioner job agent: %s", err), }) return } - resource, err := api.Database.GetProvisionerJobResourceByID(r.Context(), agent.ResourceID) + agent := agents[0] + conn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{ + CompressionMode: websocket.CompressionDisabled, + }) if err != nil { - httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ - Message: fmt.Sprintf("get provisioner job resource: %s", err), + httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{ + Message: fmt.Sprintf("accept websocket: %s", err), }) return } - job, err := api.Database.GetProvisionerJobByID(r.Context(), resource.JobID) + defer func() { + _ = conn.Close(websocket.StatusNormalClosure, "") + }() + config := yamux.DefaultConfig() + config.LogOutput = io.Discard + session, err := yamux.Server(websocket.NetConn(r.Context(), conn, websocket.MessageBinary), config) if err != nil { - httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ - Message: fmt.Sprintf("get provisioner job: %s", err), - }) - return - } - if job.Type != database.ProvisionerJobTypeWorkspaceProvision { - httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{ - Message: fmt.Sprintf("%q jobs cannot be authenticated", job.Type), - }) + _ = conn.Close(websocket.StatusAbnormalClosure, err.Error()) return } - var jobData workspaceProvisionJob - err = json.Unmarshal(job.Input, &jobData) + err = peerbroker.ProxyListen(r.Context(), session, peerbroker.ProxyOptions{ + ChannelID: agent.ID.String(), + Logger: api.Logger.Named("peerbroker-proxy-dial"), + Pubsub: api.Pubsub, + }) if err != nil { - httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ - Message: fmt.Sprintf("extract job data: %s", err), - }) + _ = conn.Close(websocket.StatusInternalError, fmt.Sprintf("serve: %s", err)) return } - resourceHistory, err := api.Database.GetWorkspaceHistoryByID(r.Context(), jobData.WorkspaceHistoryID) +} + +func (api *api) workspaceAgentServe(rw http.ResponseWriter, r *http.Request) { + api.websocketWaitGroup.Add(1) + defer api.websocketWaitGroup.Done() + + agent := httpmw.WorkspaceAgent(r) + conn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{ + CompressionMode: websocket.CompressionDisabled, + }) if err != nil { - httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ - Message: fmt.Sprintf("get workspace history: %s", err), + httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{ + Message: fmt.Sprintf("accept websocket: %s", err), }) return } - // This token should only be exchanged if the instance ID is valid - // for the latest history. If an instance ID is recycled by a cloud, - // we'd hate to leak access to a user's workspace. - latestHistory, err := api.Database.GetWorkspaceHistoryByWorkspaceIDWithoutAfter(r.Context(), resourceHistory.WorkspaceID) + defer func() { + _ = conn.Close(websocket.StatusNormalClosure, "") + }() + config := yamux.DefaultConfig() + config.LogOutput = io.Discard + session, err := yamux.Server(websocket.NetConn(r.Context(), conn, websocket.MessageBinary), config) if err != nil { - httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ - Message: fmt.Sprintf("get latest workspace history: %s", err), - }) + _ = conn.Close(websocket.StatusAbnormalClosure, err.Error()) return } - if latestHistory.ID.String() != resourceHistory.ID.String() { - httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{ - Message: fmt.Sprintf("resource found for id %q, but isn't registered on the latest history", claims.Google.ComputeEngine.InstanceID), - }) + closer, err := peerbroker.ProxyDial(proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(session)), peerbroker.ProxyOptions{ + ChannelID: agent.ID.String(), + Pubsub: api.Pubsub, + Logger: api.Logger.Named("peerbroker-proxy-listen"), + }) + if err != nil { + _ = conn.Close(websocket.StatusAbnormalClosure, err.Error()) return } - render.Status(r, http.StatusOK) - render.JSON(rw, r, WorkspaceAgentAuthenticateResponse{ - SessionToken: agent.AuthToken.String(), - }) + defer closer.Close() + <-session.CloseChan() } diff --git a/coderd/workspaceagent_test.go b/coderd/workspaceagent_test.go index c48dfc75af1d1..3737544f508ca 100644 --- a/coderd/workspaceagent_test.go +++ b/coderd/workspaceagent_test.go @@ -1,86 +1,45 @@ package coderd_test import ( - "bytes" "context" - "crypto/rand" - "crypto/rsa" - "encoding/base64" - "encoding/json" - "io/ioutil" - "math/big" - "net/http" "testing" "time" - "cloud.google.com/go/compute/metadata" - "github.com/golang-jwt/jwt" - "github.com/stretchr/testify/require" - "google.golang.org/api/idtoken" - "google.golang.org/api/option" - + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/agent" "github.com/coder/coder/coderd" "github.com/coder/coder/coderd/coderdtest" "github.com/coder/coder/codersdk" - "github.com/coder/coder/cryptorand" "github.com/coder/coder/database" + "github.com/coder/coder/peer" + "github.com/coder/coder/peerbroker" "github.com/coder/coder/provisioner/echo" "github.com/coder/coder/provisionersdk/proto" + "github.com/google/uuid" + "github.com/stretchr/testify/require" ) -func TestPostWorkspaceAgentAuthenticateGoogleInstanceIdentity(t *testing.T) { +func TestWorkspaceAgentServe(t *testing.T) { t.Parallel() - t.Run("Expired", func(t *testing.T) { - t.Parallel() - instanceID := "instanceidentifier" - signedKey, keyID, privateKey := createSignedToken(t, instanceID, &jwt.MapClaims{}) - validator := createValidator(t, keyID, privateKey) - client := coderdtest.New(t, &coderdtest.Options{ - GoogleTokenValidator: validator, - }) - _, err := client.AuthenticateWorkspaceAgentUsingGoogleCloudIdentity(context.Background(), "", createMetadataClient(signedKey)) - var apiErr *codersdk.Error - require.ErrorAs(t, err, &apiErr) - require.Equal(t, http.StatusUnauthorized, apiErr.StatusCode()) - }) - - t.Run("InstanceNotFound", func(t *testing.T) { - t.Parallel() - instanceID := "instanceidentifier" - signedKey, keyID, privateKey := createSignedToken(t, instanceID, nil) - validator := createValidator(t, keyID, privateKey) - client := coderdtest.New(t, &coderdtest.Options{ - GoogleTokenValidator: validator, - }) - _, err := client.AuthenticateWorkspaceAgentUsingGoogleCloudIdentity(context.Background(), "", createMetadataClient(signedKey)) - var apiErr *codersdk.Error - require.ErrorAs(t, err, &apiErr) - require.Equal(t, http.StatusNotFound, apiErr.StatusCode()) - }) - t.Run("Success", func(t *testing.T) { t.Parallel() - instanceID := "instanceidentifier" - signedKey, keyID, privateKey := createSignedToken(t, instanceID, nil) - validator := createValidator(t, keyID, privateKey) - client := coderdtest.New(t, &coderdtest.Options{ - GoogleTokenValidator: validator, - }) + client := coderdtest.New(t, nil) user := coderdtest.CreateInitialUser(t, client) - coderdtest.NewProvisionerDaemon(t, client) + daemonCloser := coderdtest.NewProvisionerDaemon(t, client) + authToken := uuid.NewString() job := coderdtest.CreateProjectImportJob(t, client, user.Organization, &echo.Responses{ Parse: echo.ParseComplete, Provision: []*proto.Provision_Response{{ Type: &proto.Provision_Response_Complete{ Complete: &proto.Provision_Complete{ Resources: []*proto.Resource{{ - Name: "somename", - Type: "someinstance", + Name: "example", + Type: "aws_instance", Agent: &proto.Agent{ - Auth: &proto.Agent_GoogleInstanceIdentity{ - GoogleInstanceIdentity: &proto.GoogleInstanceIdentityAuth{ - InstanceId: instanceID, - }, + Id: uuid.NewString(), + Auth: &proto.Agent_Token{ + Token: authToken, }, }, }}, @@ -97,86 +56,33 @@ func TestPostWorkspaceAgentAuthenticateGoogleInstanceIdentity(t *testing.T) { }) require.NoError(t, err) coderdtest.AwaitWorkspaceProvisionJob(t, client, user.Organization, firstHistory.ProvisionJobID) - - _, err = client.AuthenticateWorkspaceAgentUsingGoogleCloudIdentity(context.Background(), "", createMetadataClient(signedKey)) + daemonCloser.Close() + resources, err := client.WorkspaceProvisionJobResources(context.Background(), user.Organization, firstHistory.ProvisionJobID) require.NoError(t, err) - }) -} + require.Len(t, resources, 1) -// Used to easily create an HTTP transport! -type roundTripper func(req *http.Request) (*http.Response, error) - -func (r roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - return r(req) -} - -// Create's a new Google metadata client to authenticate. -func createMetadataClient(signedKey string) *metadata.Client { - return metadata.NewClient(&http.Client{ - Transport: roundTripper(func(r *http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: http.StatusOK, - Body: ioutil.NopCloser(bytes.NewReader([]byte(signedKey))), - Header: make(http.Header), - }, nil - }), - }) -} + agentClient := codersdk.New(client.URL) + agentClient.SessionToken = authToken + agentCloser := agent.New(agentClient.WorkspaceAgentServe, &peer.ConnOptions{ + Logger: slogtest.Make(t, nil), + }) -// Create's a signed JWT with a randomly generated private key. -func createSignedToken(t *testing.T, instanceID string, claims *jwt.MapClaims) (signedKey string, keyID string, privateKey *rsa.PrivateKey) { - keyID, err := cryptorand.String(12) - require.NoError(t, err) - if claims == nil { - claims = &jwt.MapClaims{ - "exp": time.Now().AddDate(1, 0, 0).Unix(), - "google": map[string]interface{}{ - "compute_engine": map[string]string{ - "instance_id": instanceID, - }, - }, - } - } - token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) - token.Header["kid"] = keyID - privateKey, err = rsa.GenerateKey(rand.Reader, 2048) - require.NoError(t, err) - signedKey, err = token.SignedString(privateKey) - require.NoError(t, err) - return signedKey, keyID, privateKey -} + time.Sleep(time.Millisecond * 250) -// Create's a validator that verifies against the provided private key. -// In a production scenario, the validator calls against the Google OAuth API -// to obtain certificates. -func createValidator(t *testing.T, keyID string, privateKey *rsa.PrivateKey) *idtoken.Validator { - // Taken from: https://github.com/googleapis/google-api-go-client/blob/4bb729045d611fa77bdbeb971f6a1204ba23161d/idtoken/validate.go#L57-L75 - type jwk struct { - Kid string `json:"kid"` - N string `json:"n"` - E string `json:"e"` - } - type certResponse struct { - Keys []jwk `json:"keys"` - } + workspaceClient, err := client.WorkspaceAgentConnect(context.Background(), user.Organization, firstHistory.ProvisionJobID, resources[0].ID) + require.NoError(t, err) + stream, err := workspaceClient.NegotiateConnection(context.Background()) + require.NoError(t, err) + conn, err := peerbroker.Dial(stream, nil, &peer.ConnOptions{ + Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug), + }) + require.NoError(t, err) + _, err = conn.Ping() + require.NoError(t, err) - validator, err := idtoken.NewValidator(context.Background(), option.WithHTTPClient(&http.Client{ - Transport: roundTripper(func(r *http.Request) (*http.Response, error) { - data, err := json.Marshal(certResponse{ - Keys: []jwk{{ - Kid: keyID, - N: base64.RawURLEncoding.EncodeToString(privateKey.N.Bytes()), - E: base64.RawURLEncoding.EncodeToString(new(big.Int).SetInt64(int64(privateKey.E)).Bytes()), - }}, - }) - require.NoError(t, err) - return &http.Response{ - StatusCode: http.StatusOK, - Body: ioutil.NopCloser(bytes.NewReader(data)), - Header: make(http.Header), - }, nil - }), - })) - require.NoError(t, err) - return validator + workspaceClient.DRPCConn().Close() + conn.Close() + stream.Close() + agentCloser.Close() + }) } diff --git a/coderd/workspaceagentauth.go b/coderd/workspaceagentauth.go new file mode 100644 index 0000000000000..3d1913c54eedd --- /dev/null +++ b/coderd/workspaceagentauth.go @@ -0,0 +1,147 @@ +package coderd + +import ( + "database/sql" + "encoding/json" + "errors" + "fmt" + "net/http" + + "github.com/go-chi/render" + + "github.com/coder/coder/database" + "github.com/coder/coder/httpapi" + + "github.com/mitchellh/mapstructure" +) + +type GoogleInstanceIdentityToken struct { + JSONWebToken string `json:"json_web_token" validate:"required"` +} + +// WorkspaceAgentAuthenticateResponse is returned when an instance ID +// has been exchanged for a session token. +type WorkspaceAgentAuthenticateResponse struct { + SessionToken string `json:"session_token"` +} + +type WorkspaceAgentResourceMetadata struct { + MemoryTotal uint64 `json:"memory_total"` + DiskTotal uint64 `json:"disk_total"` + CPUCores uint64 `json:"cpu_cores"` + CPUModel string `json:"cpu_model"` + CPUMhz float64 `json:"cpu_mhz"` +} + +type WorkspaceAgentInstanceMetadata struct { + JailOrchestrator string `json:"jail_orchestrator"` + OperatingSystem string `json:"operating_system"` + Platform string `json:"platform"` + PlatformFamily string `json:"platform_family"` + KernelVersion string `json:"kernel_version"` + KernelArchitecture string `json:"kernel_architecture"` + Cloud string `json:"cloud"` + Jail string `json:"jail"` + VNC bool `json:"vnc"` +} + +// Google Compute Engine supports instance identity verification: +// https://cloud.google.com/compute/docs/instances/verifying-instance-identity +// Using this, we can exchange a signed instance payload for an agent token. +func (api *api) postAuthenticateWorkspaceAgentUsingGoogleInstanceIdentity(rw http.ResponseWriter, r *http.Request) { + var req GoogleInstanceIdentityToken + if !httpapi.Read(rw, r, &req) { + return + } + + // We leave the audience blank. It's not important we validate who made the token. + payload, err := api.GoogleTokenValidator.Validate(r.Context(), req.JSONWebToken, "") + if err != nil { + httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{ + Message: fmt.Sprintf("validate: %s", err), + }) + return + } + claims := struct { + Google struct { + ComputeEngine struct { + InstanceID string `mapstructure:"instance_id"` + } `mapstructure:"compute_engine"` + } `mapstructure:"google"` + }{} + err = mapstructure.Decode(payload.Claims, &claims) + if err != nil { + httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{ + Message: fmt.Sprintf("decode jwt claims: %s", err), + }) + return + } + agent, err := api.Database.GetProvisionerJobAgentByInstanceID(r.Context(), claims.Google.ComputeEngine.InstanceID) + if errors.Is(err, sql.ErrNoRows) { + httpapi.Write(rw, http.StatusNotFound, httpapi.Response{ + Message: fmt.Sprintf("instance with id %q not found", claims.Google.ComputeEngine.InstanceID), + }) + return + } + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get provisioner job agent: %s", err), + }) + return + } + resource, err := api.Database.GetProvisionerJobResourceByID(r.Context(), agent.ResourceID) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get provisioner job resource: %s", err), + }) + return + } + job, err := api.Database.GetProvisionerJobByID(r.Context(), resource.JobID) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get provisioner job: %s", err), + }) + return + } + if job.Type != database.ProvisionerJobTypeWorkspaceProvision { + httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{ + Message: fmt.Sprintf("%q jobs cannot be authenticated", job.Type), + }) + return + } + var jobData workspaceProvisionJob + err = json.Unmarshal(job.Input, &jobData) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("extract job data: %s", err), + }) + return + } + resourceHistory, err := api.Database.GetWorkspaceHistoryByID(r.Context(), jobData.WorkspaceHistoryID) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get workspace history: %s", err), + }) + return + } + // This token should only be exchanged if the instance ID is valid + // for the latest history. If an instance ID is recycled by a cloud, + // we'd hate to leak access to a user's workspace. + latestHistory, err := api.Database.GetWorkspaceHistoryByWorkspaceIDWithoutAfter(r.Context(), resourceHistory.WorkspaceID) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get latest workspace history: %s", err), + }) + return + } + if latestHistory.ID.String() != resourceHistory.ID.String() { + httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{ + Message: fmt.Sprintf("resource found for id %q, but isn't registered on the latest history", claims.Google.ComputeEngine.InstanceID), + }) + return + } + render.Status(r, http.StatusOK) + render.JSON(rw, r, WorkspaceAgentAuthenticateResponse{ + SessionToken: agent.AuthToken.String(), + }) +} diff --git a/coderd/workspaceagentauth_test.go b/coderd/workspaceagentauth_test.go new file mode 100644 index 0000000000000..c48dfc75af1d1 --- /dev/null +++ b/coderd/workspaceagentauth_test.go @@ -0,0 +1,182 @@ +package coderd_test + +import ( + "bytes" + "context" + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "encoding/json" + "io/ioutil" + "math/big" + "net/http" + "testing" + "time" + + "cloud.google.com/go/compute/metadata" + "github.com/golang-jwt/jwt" + "github.com/stretchr/testify/require" + "google.golang.org/api/idtoken" + "google.golang.org/api/option" + + "github.com/coder/coder/coderd" + "github.com/coder/coder/coderd/coderdtest" + "github.com/coder/coder/codersdk" + "github.com/coder/coder/cryptorand" + "github.com/coder/coder/database" + "github.com/coder/coder/provisioner/echo" + "github.com/coder/coder/provisionersdk/proto" +) + +func TestPostWorkspaceAgentAuthenticateGoogleInstanceIdentity(t *testing.T) { + t.Parallel() + t.Run("Expired", func(t *testing.T) { + t.Parallel() + instanceID := "instanceidentifier" + signedKey, keyID, privateKey := createSignedToken(t, instanceID, &jwt.MapClaims{}) + validator := createValidator(t, keyID, privateKey) + client := coderdtest.New(t, &coderdtest.Options{ + GoogleTokenValidator: validator, + }) + _, err := client.AuthenticateWorkspaceAgentUsingGoogleCloudIdentity(context.Background(), "", createMetadataClient(signedKey)) + var apiErr *codersdk.Error + require.ErrorAs(t, err, &apiErr) + require.Equal(t, http.StatusUnauthorized, apiErr.StatusCode()) + }) + + t.Run("InstanceNotFound", func(t *testing.T) { + t.Parallel() + instanceID := "instanceidentifier" + signedKey, keyID, privateKey := createSignedToken(t, instanceID, nil) + validator := createValidator(t, keyID, privateKey) + client := coderdtest.New(t, &coderdtest.Options{ + GoogleTokenValidator: validator, + }) + _, err := client.AuthenticateWorkspaceAgentUsingGoogleCloudIdentity(context.Background(), "", createMetadataClient(signedKey)) + var apiErr *codersdk.Error + require.ErrorAs(t, err, &apiErr) + require.Equal(t, http.StatusNotFound, apiErr.StatusCode()) + }) + + t.Run("Success", func(t *testing.T) { + t.Parallel() + instanceID := "instanceidentifier" + signedKey, keyID, privateKey := createSignedToken(t, instanceID, nil) + validator := createValidator(t, keyID, privateKey) + client := coderdtest.New(t, &coderdtest.Options{ + GoogleTokenValidator: validator, + }) + user := coderdtest.CreateInitialUser(t, client) + coderdtest.NewProvisionerDaemon(t, client) + job := coderdtest.CreateProjectImportJob(t, client, user.Organization, &echo.Responses{ + Parse: echo.ParseComplete, + Provision: []*proto.Provision_Response{{ + Type: &proto.Provision_Response_Complete{ + Complete: &proto.Provision_Complete{ + Resources: []*proto.Resource{{ + Name: "somename", + Type: "someinstance", + Agent: &proto.Agent{ + Auth: &proto.Agent_GoogleInstanceIdentity{ + GoogleInstanceIdentity: &proto.GoogleInstanceIdentityAuth{ + InstanceId: instanceID, + }, + }, + }, + }}, + }, + }, + }}, + }) + project := coderdtest.CreateProject(t, client, user.Organization, job.ID) + coderdtest.AwaitProjectImportJob(t, client, user.Organization, job.ID) + workspace := coderdtest.CreateWorkspace(t, client, "me", project.ID) + firstHistory, err := client.CreateWorkspaceHistory(context.Background(), "", workspace.Name, coderd.CreateWorkspaceHistoryRequest{ + ProjectVersionID: project.ActiveVersionID, + Transition: database.WorkspaceTransitionStart, + }) + require.NoError(t, err) + coderdtest.AwaitWorkspaceProvisionJob(t, client, user.Organization, firstHistory.ProvisionJobID) + + _, err = client.AuthenticateWorkspaceAgentUsingGoogleCloudIdentity(context.Background(), "", createMetadataClient(signedKey)) + require.NoError(t, err) + }) +} + +// Used to easily create an HTTP transport! +type roundTripper func(req *http.Request) (*http.Response, error) + +func (r roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return r(req) +} + +// Create's a new Google metadata client to authenticate. +func createMetadataClient(signedKey string) *metadata.Client { + return metadata.NewClient(&http.Client{ + Transport: roundTripper(func(r *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Body: ioutil.NopCloser(bytes.NewReader([]byte(signedKey))), + Header: make(http.Header), + }, nil + }), + }) +} + +// Create's a signed JWT with a randomly generated private key. +func createSignedToken(t *testing.T, instanceID string, claims *jwt.MapClaims) (signedKey string, keyID string, privateKey *rsa.PrivateKey) { + keyID, err := cryptorand.String(12) + require.NoError(t, err) + if claims == nil { + claims = &jwt.MapClaims{ + "exp": time.Now().AddDate(1, 0, 0).Unix(), + "google": map[string]interface{}{ + "compute_engine": map[string]string{ + "instance_id": instanceID, + }, + }, + } + } + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + token.Header["kid"] = keyID + privateKey, err = rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + signedKey, err = token.SignedString(privateKey) + require.NoError(t, err) + return signedKey, keyID, privateKey +} + +// Create's a validator that verifies against the provided private key. +// In a production scenario, the validator calls against the Google OAuth API +// to obtain certificates. +func createValidator(t *testing.T, keyID string, privateKey *rsa.PrivateKey) *idtoken.Validator { + // Taken from: https://github.com/googleapis/google-api-go-client/blob/4bb729045d611fa77bdbeb971f6a1204ba23161d/idtoken/validate.go#L57-L75 + type jwk struct { + Kid string `json:"kid"` + N string `json:"n"` + E string `json:"e"` + } + type certResponse struct { + Keys []jwk `json:"keys"` + } + + validator, err := idtoken.NewValidator(context.Background(), option.WithHTTPClient(&http.Client{ + Transport: roundTripper(func(r *http.Request) (*http.Response, error) { + data, err := json.Marshal(certResponse{ + Keys: []jwk{{ + Kid: keyID, + N: base64.RawURLEncoding.EncodeToString(privateKey.N.Bytes()), + E: base64.RawURLEncoding.EncodeToString(new(big.Int).SetInt64(int64(privateKey.E)).Bytes()), + }}, + }) + require.NoError(t, err) + return &http.Response{ + StatusCode: http.StatusOK, + Body: ioutil.NopCloser(bytes.NewReader(data)), + Header: make(http.Header), + }, nil + }), + })) + require.NoError(t, err) + return validator +} diff --git a/codersdk/client.go b/codersdk/client.go index 976c31d06f3da..d47cd5fc341ea 100644 --- a/codersdk/client.go +++ b/codersdk/client.go @@ -81,6 +81,20 @@ func (c *Client) request(ctx context.Context, method, path string, body interfac // readBodyAsError reads the response as an httpapi.Message, and // wraps it in a codersdk.Error type for easy marshaling. func readBodyAsError(res *http.Response) error { + contentType := res.Header.Get("Content-Type") + if strings.HasPrefix(contentType, "text/plain") { + resp, err := io.ReadAll(res.Body) + if err != nil { + return xerrors.Errorf("read body: %w", err) + } + return &Error{ + statusCode: res.StatusCode, + Response: httpapi.Response{ + Message: string(resp), + }, + } + } + var m httpapi.Response err := json.NewDecoder(res.Body).Decode(&m) if err != nil { diff --git a/codersdk/provisioners.go b/codersdk/provisioners.go index afef953beabb9..2cee00a626438 100644 --- a/codersdk/provisioners.go +++ b/codersdk/provisioners.go @@ -34,8 +34,8 @@ func (c *Client) ProvisionerDaemons(ctx context.Context) ([]coderd.ProvisionerDa return daemons, json.NewDecoder(res.Body).Decode(&daemons) } -// ProvisionerDaemonClient returns the gRPC service for a provisioner daemon implementation. -func (c *Client) ProvisionerDaemonClient(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { +// ProvisionerDaemonServe returns the gRPC service for a provisioner daemon implementation. +func (c *Client) ProvisionerDaemonServe(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { serverURL, err := c.URL.Parse("/api/v2/provisioners/daemons/serve") if err != nil { return nil, xerrors.Errorf("parse url: %w", err) diff --git a/codersdk/provisioners_test.go b/codersdk/provisioners_test.go index 9fbea9469303e..222a74dbac021 100644 --- a/codersdk/provisioners_test.go +++ b/codersdk/provisioners_test.go @@ -26,7 +26,7 @@ func TestProvisionerDaemonClient(t *testing.T) { t.Parallel() client := coderdtest.New(t, nil) ctx, cancelFunc := context.WithCancel(context.Background()) - daemon, err := client.ProvisionerDaemonClient(ctx) + daemon, err := client.ProvisionerDaemonServe(ctx) require.NoError(t, err) cancelFunc() _, err = daemon.AcquireJob(context.Background(), &proto.Empty{}) @@ -38,7 +38,7 @@ func TestProvisionerDaemonClient(t *testing.T) { client := coderdtest.New(t, nil) ctx, cancelFunc := context.WithCancel(context.Background()) defer cancelFunc() - daemon, err := client.ProvisionerDaemonClient(ctx) + daemon, err := client.ProvisionerDaemonServe(ctx) require.NoError(t, err) _, err = daemon.AcquireJob(ctx, &proto.Empty{}) require.NoError(t, err) diff --git a/codersdk/workspaceagent.go b/codersdk/workspaceagent.go index 7bfcab9202bfb..57507e7c578c6 100644 --- a/codersdk/workspaceagent.go +++ b/codersdk/workspaceagent.go @@ -4,12 +4,22 @@ import ( "context" "encoding/json" "fmt" + "io" "net/http" + "net/http/cookiejar" "cloud.google.com/go/compute/metadata" "golang.org/x/xerrors" + "nhooyr.io/websocket" "github.com/coder/coder/coderd" + "github.com/coder/coder/httpmw" + "github.com/coder/coder/peer" + "github.com/coder/coder/peerbroker" + "github.com/coder/coder/peerbroker/proto" + "github.com/coder/coder/provisionersdk" + "github.com/google/uuid" + "github.com/hashicorp/yamux" ) // AuthenticateWorkspaceAgentUsingGoogleCloudIdentity uses the Google Compute Engine Metadata API to @@ -42,3 +52,75 @@ func (c *Client) AuthenticateWorkspaceAgentUsingGoogleCloudIdentity(ctx context. var resp coderd.WorkspaceAgentAuthenticateResponse return resp, json.NewDecoder(res.Body).Decode(&resp) } + +func (c *Client) WorkspaceAgentConnect(ctx context.Context, organization string, job, resource uuid.UUID) (proto.DRPCPeerBrokerClient, error) { + serverURL, err := c.URL.Parse(fmt.Sprintf("/api/v2/workspaceprovision/%s/%s/resources/%s/agent", organization, job.String(), resource.String())) + if err != nil { + return nil, xerrors.Errorf("parse url: %w", err) + } + jar, err := cookiejar.New(nil) + if err != nil { + return nil, xerrors.Errorf("create cookie jar: %w", err) + } + jar.SetCookies(serverURL, []*http.Cookie{{ + Name: httpmw.AuthCookie, + Value: c.SessionToken, + }}) + httpClient := &http.Client{ + Jar: jar, + } + conn, res, err := websocket.Dial(ctx, serverURL.String(), &websocket.DialOptions{ + HTTPClient: httpClient, + // Need to disable compression to avoid a data-race. + CompressionMode: websocket.CompressionDisabled, + }) + if err != nil { + if res == nil { + return nil, err + } + return nil, readBodyAsError(res) + } + config := yamux.DefaultConfig() + config.LogOutput = io.Discard + session, err := yamux.Client(websocket.NetConn(ctx, conn, websocket.MessageBinary), config) + if err != nil { + return nil, xerrors.Errorf("multiplex client: %w", err) + } + return proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(session)), nil +} + +func (c *Client) WorkspaceAgentServe(ctx context.Context, opts *peer.ConnOptions) (*peerbroker.Listener, error) { + serverURL, err := c.URL.Parse("/api/v2/workspaceagent/serve") + if err != nil { + return nil, xerrors.Errorf("parse url: %w", err) + } + jar, err := cookiejar.New(nil) + if err != nil { + return nil, xerrors.Errorf("create cookie jar: %w", err) + } + jar.SetCookies(serverURL, []*http.Cookie{{ + Name: httpmw.AuthCookie, + Value: c.SessionToken, + }}) + httpClient := &http.Client{ + Jar: jar, + } + conn, res, err := websocket.Dial(ctx, serverURL.String(), &websocket.DialOptions{ + HTTPClient: httpClient, + // Need to disable compression to avoid a data-race. + CompressionMode: websocket.CompressionDisabled, + }) + if err != nil { + if res == nil { + return nil, err + } + return nil, readBodyAsError(res) + } + config := yamux.DefaultConfig() + config.LogOutput = io.Discard + session, err := yamux.Client(websocket.NetConn(ctx, conn, websocket.MessageBinary), config) + if err != nil { + return nil, xerrors.Errorf("multiplex client: %w", err) + } + return peerbroker.Listen(session, opts) +} diff --git a/codersdk/workspaces.go b/codersdk/workspaces.go index 28f926c518049..4e7a97e4e614c 100644 --- a/codersdk/workspaces.go +++ b/codersdk/workspaces.go @@ -157,3 +157,16 @@ func (c *Client) WorkspaceProvisionJobLogsBefore(ctx context.Context, organizati func (c *Client) WorkspaceProvisionJobLogsAfter(ctx context.Context, organization string, job uuid.UUID, after time.Time) (<-chan coderd.ProvisionerJobLog, error) { return c.provisionerJobLogsAfter(ctx, "workspaceprovision", organization, job, after) } + +func (c *Client) WorkspaceProvisionJobResources(ctx context.Context, organization string, job uuid.UUID) ([]coderd.ProvisionerJobResource, error) { + res, err := c.request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/workspaceprovision/%s/%s/resources", organization, job), nil) + if err != nil { + return nil, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return nil, readBodyAsError(res) + } + var resources []coderd.ProvisionerJobResource + return resources, json.NewDecoder(res.Body).Decode(&resources) +} diff --git a/database/databasefake/databasefake.go b/database/databasefake/databasefake.go index 1e09a2a74f4f9..ad4d791d42e07 100644 --- a/database/databasefake/databasefake.go +++ b/database/databasefake/databasefake.go @@ -508,6 +508,19 @@ func (q *fakeQuerier) GetProvisionerDaemons(_ context.Context) ([]database.Provi return q.provisionerDaemons, nil } +func (q *fakeQuerier) GetProvisionerJobAgentByAuthToken(_ context.Context, authToken uuid.UUID) (database.ProvisionerJobAgent, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + + for i := len(q.provisionerJobAgent) - 1; i >= 0; i-- { + agent := q.provisionerJobAgent[i] + if agent.AuthToken.String() == authToken.String() { + return agent, nil + } + } + return database.ProvisionerJobAgent{}, sql.ErrNoRows +} + func (q *fakeQuerier) GetProvisionerJobAgentByInstanceID(_ context.Context, instanceID string) (database.ProvisionerJobAgent, error) { q.mutex.Lock() defer q.mutex.Unlock() diff --git a/database/querier.go b/database/querier.go index faa24ac3a4970..e6d33a1a4002a 100644 --- a/database/querier.go +++ b/database/querier.go @@ -26,6 +26,7 @@ type querier interface { GetProjectsByOrganizationIDs(ctx context.Context, ids []string) ([]Project, error) GetProvisionerDaemonByID(ctx context.Context, id uuid.UUID) (ProvisionerDaemon, error) GetProvisionerDaemons(ctx context.Context) ([]ProvisionerDaemon, error) + GetProvisionerJobAgentByAuthToken(ctx context.Context, authToken uuid.UUID) (ProvisionerJobAgent, error) GetProvisionerJobAgentByInstanceID(ctx context.Context, authInstanceID string) (ProvisionerJobAgent, error) GetProvisionerJobAgentsByResourceIDs(ctx context.Context, ids []uuid.UUID) ([]ProvisionerJobAgent, error) GetProvisionerJobByID(ctx context.Context, id uuid.UUID) (ProvisionerJob, error) diff --git a/database/query.sql b/database/query.sql index eabea0c43d709..3b2574acb15f1 100644 --- a/database/query.sql +++ b/database/query.sql @@ -226,6 +226,14 @@ SELECT FROM provisioner_daemon; +-- name: GetProvisionerJobAgentByAuthToken :one +SELECT + * +FROM + provisioner_job_agent +WHERE + auth_token = $1; + -- name: GetProvisionerJobAgentByInstanceID :one SELECT * diff --git a/database/query.sql.go b/database/query.sql.go index c74ecbcbf94e5..52475635e0785 100644 --- a/database/query.sql.go +++ b/database/query.sql.go @@ -617,6 +617,33 @@ func (q *sqlQuerier) GetProvisionerDaemons(ctx context.Context) ([]ProvisionerDa return items, nil } +const getProvisionerJobAgentByAuthToken = `-- name: GetProvisionerJobAgentByAuthToken :one +SELECT + id, created_at, updated_at, resource_id, auth_token, auth_instance_id, environment_variables, startup_script, instance_metadata, resource_metadata +FROM + provisioner_job_agent +WHERE + auth_token = $1 +` + +func (q *sqlQuerier) GetProvisionerJobAgentByAuthToken(ctx context.Context, authToken uuid.UUID) (ProvisionerJobAgent, error) { + row := q.db.QueryRowContext(ctx, getProvisionerJobAgentByAuthToken, authToken) + var i ProvisionerJobAgent + err := row.Scan( + &i.ID, + &i.CreatedAt, + &i.UpdatedAt, + &i.ResourceID, + &i.AuthToken, + &i.AuthInstanceID, + &i.EnvironmentVariables, + &i.StartupScript, + &i.InstanceMetadata, + &i.ResourceMetadata, + ) + return i, err +} + const getProvisionerJobAgentByInstanceID = `-- name: GetProvisionerJobAgentByInstanceID :one SELECT id, created_at, updated_at, resource_id, auth_token, auth_instance_id, environment_variables, startup_script, instance_metadata, resource_metadata diff --git a/go.mod b/go.mod index 1180eba9f01cc..077e969ae9c69 100644 --- a/go.mod +++ b/go.mod @@ -37,6 +37,7 @@ require ( github.com/hashicorp/go-version v1.4.0 github.com/hashicorp/terraform-config-inspect v0.0.0-20211115214459-90acf1ca460f github.com/hashicorp/terraform-exec v0.15.0 + github.com/hashicorp/terraform-json v0.13.0 github.com/hashicorp/terraform-plugin-sdk/v2 v2.10.1 github.com/hashicorp/yamux v0.0.0-20211028200310-0bc27b27de87 github.com/justinas/nosurf v1.1.1 @@ -52,7 +53,9 @@ require ( github.com/pion/transport v0.13.0 github.com/pion/webrtc/v3 v3.1.24 github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 + github.com/powersj/whatsthis v1.3.0 github.com/quasilyte/go-ruleguard/dsl v0.3.17 + github.com/shirou/gopsutil v3.21.11+incompatible github.com/spf13/cobra v1.3.0 github.com/stretchr/testify v1.7.0 github.com/tabbed/pqtype v0.1.1 @@ -90,6 +93,7 @@ require ( github.com/docker/docker v20.10.12+incompatible // indirect github.com/docker/go-connections v0.4.0 // indirect github.com/docker/go-units v0.4.0 // indirect + github.com/go-ole/go-ole v1.2.6 // indirect github.com/go-playground/locales v0.14.0 // indirect github.com/go-playground/universal-translator v0.18.0 // indirect github.com/gogo/protobuf v1.3.2 // indirect @@ -109,7 +113,6 @@ require ( github.com/hashicorp/hcl v1.0.0 // indirect github.com/hashicorp/hcl/v2 v2.11.1 // indirect github.com/hashicorp/logutils v1.0.0 // indirect - github.com/hashicorp/terraform-json v0.13.0 // indirect github.com/hashicorp/terraform-plugin-go v0.5.0 // indirect github.com/hashicorp/terraform-plugin-log v0.2.0 // indirect github.com/hashicorp/terraform-registry-address v0.0.0-20210412075316-9b2996cce896 // indirect @@ -147,10 +150,13 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/sirupsen/logrus v1.8.1 // indirect github.com/spf13/pflag v1.0.5 // indirect + github.com/tklauser/go-sysconf v0.3.10 // indirect + github.com/tklauser/numcpus v0.4.0 // indirect github.com/vmihailenco/msgpack v4.0.4+incompatible // indirect github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect github.com/xeipuuv/gojsonschema v1.2.0 // indirect + github.com/yusufpapurcu/wmi v1.2.2 // indirect github.com/zclconf/go-cty v1.10.0 // indirect github.com/zeebo/errs v1.2.2 // indirect golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd // indirect diff --git a/go.sum b/go.sum index 68dd21a43d74e..d28ed5a7262da 100644 --- a/go.sum +++ b/go.sum @@ -47,6 +47,7 @@ cloud.google.com/go/compute v1.5.0 h1:b1zWmYuuHz7gO9kDcM/EpHGr06UgsYNRpNJzI2kFiL cloud.google.com/go/compute v1.5.0/go.mod h1:9SMHyhJlzhlkJqrPAc839t2BZFTSk6Jdj6mkzQJeu0M= cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7TKLgZqpHYE= cloud.google.com/go/datastore v1.1.0/go.mod h1:umbIZjpQpHh4hmRpGhH4tLFup+FVzqBi1b3c64qFpCk= +cloud.google.com/go/firestore v1.1.0/go.mod h1:ulACoGHTpvq5r8rxGJ4ddJZBZqakUQqClKRT5SZwBmk= cloud.google.com/go/firestore v1.6.1/go.mod h1:asNXNOzBdyVQmEU+ggO8UPodTkEVFW5Qx+rwHnAz+EY= cloud.google.com/go/pubsub v1.0.1/go.mod h1:R0Gpsv3s54REJCy4fxDixWD93lHJMoZTyQ2kNxGRt3I= cloud.google.com/go/pubsub v1.1.0/go.mod h1:EwwdRX2sKPjnvnqCa270oGRyludottCI76h+R3AArQw= @@ -196,6 +197,7 @@ github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932/go.mod h1:NOuUCS github.com/bitly/go-simplejson v0.5.0/go.mod h1:cXHtHw4XUPsvGaxgjIAn8PhEWG9NfngEKAMDJEczWVA= github.com/bits-and-blooms/bitset v1.2.0/go.mod h1:gIdJ4wp64HaoK2YrL1Q5/N7Y16edYb8uY+O0FJTyyDA= github.com/bkaradzic/go-lz4 v1.0.0/go.mod h1:0YdlkowM3VswSROI7qDxhRvJ3sLhlFrRRwjwegp5jy4= +github.com/bketelsen/crypt v0.0.3-0.20200106085610-5cbc8cc4026c/go.mod h1:MKsuJmJgSg28kpZDP6UIiPt0e0Oz0kqKNGyRaWEPv84= github.com/blang/semver v3.1.0+incompatible/go.mod h1:kRBLl5iJ+tD4TcOOxsy/0fnwebNt5EWlYSAyrTnjyyk= github.com/blang/semver v3.5.1+incompatible/go.mod h1:kRBLl5iJ+tD4TcOOxsy/0fnwebNt5EWlYSAyrTnjyyk= github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4= @@ -340,6 +342,7 @@ github.com/containers/ocicrypt v1.1.0/go.mod h1:b8AOe0YR67uU8OqfVNcznfFpAzu3rdgU github.com/containers/ocicrypt v1.1.1/go.mod h1:Dm55fwWm1YZAjYRaJ94z2mfZikIyIN4B0oB3dj3jFxY= github.com/coreos/bbolt v1.3.2/go.mod h1:iRUV2dpdMOn7Bo10OQBFzIJO9kkE559Wcmn+qkEiiKk= github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= +github.com/coreos/etcd v3.3.13+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= github.com/coreos/go-iptables v0.4.5/go.mod h1:/mVI274lEDI2ns62jHCDnCyBF9Iwsmekav8Dbxlm1MU= github.com/coreos/go-iptables v0.5.0/go.mod h1:/mVI274lEDI2ns62jHCDnCyBF9Iwsmekav8Dbxlm1MU= github.com/coreos/go-oidc v2.1.0+incompatible/go.mod h1:CgnwVTmzoESiwO9qyAFEMiHoZ1nMCKZlZ9V6mm3/LKc= @@ -482,6 +485,8 @@ github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9 github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= github.com/go-logr/logr v0.1.0/go.mod h1:ixOQHD9gLJUVQQ2ZOR7zLEifBX6tGkNJF4QyIY7sIas= github.com/go-logr/logr v0.2.0/go.mod h1:z6/tIYblkpsD+a4lm/fGIIU9mZ+XfAiaFtq7xTgseGU= +github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= +github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/go-openapi/jsonpointer v0.19.2/go.mod h1:3akKfEdA7DF1sugOqz1dVQHBcuDBPKZGEoHC/NkiQRg= github.com/go-openapi/jsonpointer v0.19.3/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= github.com/go-openapi/jsonreference v0.19.2/go.mod h1:jMjeRr2HHw6nAVajTXJ4eiUwohSTlpa0o73RUL1owJc= @@ -674,7 +679,9 @@ github.com/grpc-ecosystem/grpc-gateway v1.9.0/go.mod h1:vNeuVxBJEsws4ogUvrchl83t github.com/grpc-ecosystem/grpc-gateway v1.9.5/go.mod h1:vNeuVxBJEsws4ogUvrchl83t/GYV9WGTSLVdBhOQFDY= github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed/go.mod h1:tMWxXQ9wFIaZeTI9F+hmhFiGpFmhOHzyShyFUhRm0H4= +github.com/hashicorp/consul/api v1.1.0/go.mod h1:VmuI/Lkw1nC05EYQWNKwWGbkg+FbDBtguAZLlVdkD9Q= github.com/hashicorp/consul/api v1.11.0/go.mod h1:XjsvQN+RJGWI2TWy1/kqaE16HrR2J/FWgkYjdZQsX9M= +github.com/hashicorp/consul/sdk v0.1.1/go.mod h1:VKf9jXwCTEY1QZP2MOLRhb5i/I/ssyNV1vwHyQBF0x8= github.com/hashicorp/consul/sdk v0.8.0/go.mod h1:GBvyrGALthsZObzUGsfgHZQDXjg4lOjagTIwIR1vPms= github.com/hashicorp/errwrap v0.0.0-20141028054710-7554cd9344ce/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= @@ -706,6 +713,7 @@ github.com/hashicorp/go-plugin v1.3.0/go.mod h1:F9eH4LrE/ZsRdbwhfjs9k9HoDUwAHnYt github.com/hashicorp/go-plugin v1.4.1 h1:6UltRQlLN9iZO513VveELp5xyaFxVD2+1OVylE+2E+w= github.com/hashicorp/go-plugin v1.4.1/go.mod h1:5fGEH17QVwTTcR0zV7yhDPLLmFX9YSZ38b18Udy6vYQ= github.com/hashicorp/go-retryablehttp v0.5.3/go.mod h1:9B5zBasrRhHXnJnui7y6sL7es7NDiJgTc6Er0maI1Xs= +github.com/hashicorp/go-rootcerts v1.0.0/go.mod h1:K6zTfqpRlCUIjkwsN4Z+hiSfzSTQa6eBIzfwKfwNnHU= github.com/hashicorp/go-rootcerts v1.0.2/go.mod h1:pqUvnprVnM5bf7AOirdbb01K4ccR319Vf4pU3K5EGc8= github.com/hashicorp/go-sockaddr v1.0.0/go.mod h1:7Xibr9yA9JjQq1JpNB2Vw7kxv8xerXegt+ozgdvDeDU= github.com/hashicorp/go-syslog v1.0.0/go.mod h1:qPfqrKkXGihmCqbJM2mZgkZGvKG1dFdvsLplgctolz4= @@ -717,6 +725,7 @@ github.com/hashicorp/go-version v1.2.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09 github.com/hashicorp/go-version v1.3.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= github.com/hashicorp/go-version v1.4.0 h1:aAQzgqIrRKRa7w75CKpbBxYsmUoPjzVm1W59ca1L0J4= github.com/hashicorp/go-version v1.4.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= +github.com/hashicorp/go.net v0.0.1/go.mod h1:hjKkEWcCURg++eb33jQU7oqQcI9XDCnUzHA0oac0k90= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= @@ -731,10 +740,13 @@ github.com/hashicorp/hcl/v2 v2.11.1 h1:yTyWcXcm9XB0TEkyU/JCRU6rYy4K+mgLtzn2wlrJb github.com/hashicorp/hcl/v2 v2.11.1/go.mod h1:FwWsfWEjyV/CMj8s/gqAuiviY72rJ1/oayI9WftqcKg= github.com/hashicorp/logutils v1.0.0 h1:dLEQVugN8vlakKOUE3ihGLTZJRB4j+M2cdTm/ORI65Y= github.com/hashicorp/logutils v1.0.0/go.mod h1:QIAnNjmIWmVIIkWDTG1z5v++HQmx9WQRO+LraFDTW64= +github.com/hashicorp/mdns v1.0.0/go.mod h1:tL+uN++7HEJ6SQLQ2/p+z2pH24WQKWjBPkE0mNTz8vQ= github.com/hashicorp/mdns v1.0.1/go.mod h1:4gW7WsVCke5TE7EPeYliwHlRUyBtfCwuFwuMg2DmyNY= github.com/hashicorp/mdns v1.0.4/go.mod h1:mtBihi+LeNXGtG8L9dX59gAEa12BDtBQSp4v/YAJqrc= +github.com/hashicorp/memberlist v0.1.3/go.mod h1:ajVTdAv/9Im8oMAAj5G31PhhMCZJV2pPBoIllUwCN7I= github.com/hashicorp/memberlist v0.2.2/go.mod h1:MS2lj3INKhZjWNqd3N0m3J+Jxf3DAOnAH9VT3Sh9MUE= github.com/hashicorp/memberlist v0.3.0/go.mod h1:MS2lj3INKhZjWNqd3N0m3J+Jxf3DAOnAH9VT3Sh9MUE= +github.com/hashicorp/serf v0.8.2/go.mod h1:6hOLApaqBFA1NXqRQAsxw9QxuDEvNxSQRwA/JwenrHc= github.com/hashicorp/serf v0.9.5/go.mod h1:UWDWwZeL5cuWDJdl0C6wrvrUwEqtQ4ZKBKKENpqIUyk= github.com/hashicorp/serf v0.9.6/go.mod h1:TXZNMjZQijwlDvp+r0b63xZ45H7JmCmgg4gpTwn9UV4= github.com/hashicorp/terraform-json v0.13.0 h1:Li9L+lKD1FO5RVFRM1mMMIBDoUHslOniyEi5CM+FWGY= @@ -915,6 +927,7 @@ github.com/lunixbochs/vtclean v1.0.0 h1:xu2sLAri4lGiovBDQKxl5mrXyESr3gUr5m5SM5+L github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI= github.com/lyft/protoc-gen-star v0.5.3/go.mod h1:V0xaHgaf5oCCqmcxYcWiDfTiKsZsRc87/1qhoTACD8w= github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= +github.com/magiconair/properties v1.8.1/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= github.com/magiconair/properties v1.8.5/go.mod h1:y3VJvCyxH9uVvJTWEGAELF3aiYNyPKd5NZ3oSwXrF60= github.com/mailru/easyjson v0.0.0-20190614124828-94de47d64c63/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/mailru/easyjson v0.0.0-20190626092158-b2ccc519800e/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= @@ -957,9 +970,11 @@ github.com/miekg/dns v1.1.26/go.mod h1:bPDLeHnStXmXAq1m/Ch/hvfNHr14JKNPMBo3VZKju github.com/miekg/dns v1.1.41/go.mod h1:p6aan82bvRIyn+zDIv9xYNUpwa73JcSh9BKwknJysuI= github.com/miekg/pkcs11 v1.0.3/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs= github.com/mistifyio/go-zfs v2.1.2-0.20190413222219-f784269be439+incompatible/go.mod h1:8AuVvqP/mXw1px98n46wfvcGfQ4ci2FwoAjKYxuo3Z4= +github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc= github.com/mitchellh/cli v1.1.0/go.mod h1:xcISNoH86gajksDmfB23e/pu+B+GeFRMYmoHXxx3xhI= github.com/mitchellh/copystructure v1.2.0 h1:vpKXTN4ewci03Vljg/q9QvCGUDttBOGBIa15WveJJGw= github.com/mitchellh/copystructure v1.2.0/go.mod h1:qLl+cE2AmVv+CoeAwDPye/v+N2HKCj9FbZEVFJRxO9s= +github.com/mitchellh/go-homedir v1.0.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y= github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/mitchellh/go-testing-interface v0.0.0-20171004221916-a61a99592b77/go.mod h1:kRemZodwjscx+RGhAo8eIhFbs2+BFgRtFPeD/KE+zxI= @@ -970,6 +985,8 @@ github.com/mitchellh/go-wordwrap v0.0.0-20150314170334-ad45545899c7/go.mod h1:ZX github.com/mitchellh/go-wordwrap v1.0.0/go.mod h1:ZXFpozHsX6DPmq2I0TCekCxypsnAUbP2oI0UX1GXzOo= github.com/mitchellh/go-wordwrap v1.0.1 h1:TLuKupo69TCn6TQSyGxwI1EblZZEsQ0vMlAFQflz0v0= github.com/mitchellh/go-wordwrap v1.0.1/go.mod h1:R62XHJLzvMFRBbcrT7m7WgmE1eOyTSsCt+hzestvNj0= +github.com/mitchellh/gox v0.4.0/go.mod h1:Sd9lOJ0+aimLBi73mGofS1ycjY8lL3uZM3JPS42BGNg= +github.com/mitchellh/iochan v1.0.0/go.mod h1:JwYml1nuB7xOzsp52dPpHFffvOCDupsG0QubkSMEySY= github.com/mitchellh/mapstructure v0.0.0-20160808181253-ca63d7c062ee/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= github.com/mitchellh/mapstructure v0.0.0-20180220230111-00c29f56e238/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= @@ -1130,6 +1147,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/posener/complete v1.1.1/go.mod h1:em0nMJCgc9GFtwrmVmEMR/ZL6WyhyjMBndrE9hABlRI= github.com/posener/complete v1.2.3/go.mod h1:WZIdtGGp+qx0sLrYKtIRAruyNpv6hFCicSgv7Sy7s/s= +github.com/powersj/whatsthis v1.3.0 h1:FhP+pZZr6rxBC2N/ydZOvzcFOx60Ujggy2ACYxa6Xac= +github.com/powersj/whatsthis v1.3.0/go.mod h1:8NwT2j1fdsmLLVBZ0uNPb1cvHwBHm6G3e0t/Kk7AsmI= github.com/pquerna/cachecontrol v0.0.0-20171018203845-0dec1b30a021/go.mod h1:prYjPmNq4d1NPVmpShWobRqXY3q7Vp+80DqgxxUrUIA= github.com/prometheus/client_golang v0.0.0-20180209125602-c332b6f63c06/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= @@ -1193,6 +1212,8 @@ github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAm github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= github.com/sergi/go-diff v1.2.0 h1:XU+rvMAioB0UC3q1MFrIQy4Vo5/4VsRDQQXHsEya6xQ= github.com/sergi/go-diff v1.2.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= +github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI= +github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= github.com/shopspring/decimal v0.0.0-20200227202807-02e2044944cc/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= @@ -1209,6 +1230,7 @@ github.com/sirupsen/logrus v1.8.1 h1:dJKuHgqk1NNQlqoA6BTlM1Wf9DOH3NBjQyu0h9+AZZE github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/goconvey v0.0.0-20190330032615-68dc04aab96a/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= +github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= github.com/snowflakedb/gosnowflake v1.6.3/go.mod h1:6hLajn6yxuJ4xUHZegMekpq9rnQbGJ7TMwXjgTmA6lg= github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM= github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= @@ -1221,6 +1243,7 @@ github.com/spf13/cast v1.4.1/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkU github.com/spf13/cobra v0.0.2-0.20171109065643-2da4a54c5cee/go.mod h1:1l0Ry5zgKvJasoi3XT1TypsSe7PqH0Sj9dhYf7v3XqQ= github.com/spf13/cobra v0.0.3/go.mod h1:1l0Ry5zgKvJasoi3XT1TypsSe7PqH0Sj9dhYf7v3XqQ= github.com/spf13/cobra v1.0.0/go.mod h1:/6GTrnGXV9HjY+aR4k0oJ5tcvakLuG6EuKReYlHNrgE= +github.com/spf13/cobra v1.1.1/go.mod h1:WnodtKOvamDL/PwE2M4iKs8aMDBZ5Q5klgD3qfVJQMI= github.com/spf13/cobra v1.3.0 h1:R7cSvGu+Vv+qX0gW5R/85dx2kmmJT5z5NM8ifdYjdn0= github.com/spf13/cobra v1.3.0/go.mod h1:BrRVncBjOJa/eUcVVm9CE+oC6as8k+VYr4NY7WCi9V4= github.com/spf13/jwalterweatherman v1.0.0/go.mod h1:cQK4TGJAtQXfYWX+Ddv3mKDzgVb68N+wFjFa4jdeBTo= @@ -1233,6 +1256,7 @@ github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnIn github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/viper v1.4.0/go.mod h1:PTJ7Z/lr49W6bUbkmS1V3by4uWynFiR9p7+dSq/yZzE= +github.com/spf13/viper v1.7.0/go.mod h1:8WkrPz2fc9jxqZNCJI/76HCieCp4Q8HaLFoCha5qpdg= github.com/spf13/viper v1.10.0/go.mod h1:SoyBPwAtKDzypXNDFKN5kzH7ppppbGZtls1UpIy5AsM= github.com/stefanberger/go-pkcs11uri v0.0.0-20201008174630-78d3cae3a980/go.mod h1:AO3tvPzVZ/ayst6UlUKUv6rcPQInYe3IknH3jYhAKu8= github.com/stretchr/objx v0.0.0-20180129172003-8a3f7159479f/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -1255,6 +1279,10 @@ github.com/tabbed/pqtype v0.1.1 h1:PhEcb9JZ8jr7SUjJDFjRPxny0M8fkXZrxn/a9yQfoZg= github.com/tabbed/pqtype v0.1.1/go.mod h1:HLt2kLJPcUhODQkYn3mJkMHXVsuv3Z2n5NZEeKXL0Uk= github.com/tchap/go-patricia v2.2.6+incompatible/go.mod h1:bmLyhP68RS6kStMGxByiQ23RP/odRBOTVjwp2cDyi6I= github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= +github.com/tklauser/go-sysconf v0.3.10 h1:IJ1AZGZRWbY8T5Vfk04D9WOA5WSejdflXxP03OUqALw= +github.com/tklauser/go-sysconf v0.3.10/go.mod h1:C8XykCvCb+Gn0oNCWPIlcb0RuglQTYaQ2hGm7jmxEFk= +github.com/tklauser/numcpus v0.4.0 h1:E53Dm1HjH1/R2/aoCtXtPgzmElmn51aOkhCFSuZq//o= +github.com/tklauser/numcpus v0.4.0/go.mod h1:1+UI3pD8NW14VMwdgJNJ1ESk2UnwhAnz5hMwiKKqXCQ= github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= github.com/tv42/httpunix v0.0.0-20150427012821-b75d8614f926/go.mod h1:9ESjWnEqriFuLhtthL60Sar/7RFoluCcXsuvEwTV5KM= @@ -1307,6 +1335,8 @@ github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9de github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +github.com/yusufpapurcu/wmi v1.2.2 h1:KBNDSne4vP5mbSWnJbO+51IMOXJB67QiYCSBrubbPRg= +github.com/yusufpapurcu/wmi v1.2.2/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= github.com/yvasiyarov/go-metrics v0.0.0-20140926110328-57bccd1ccd43/go.mod h1:aX5oPXxHm3bOH+xeAttToC8pqch2ScQN/JoXYupl6xs= github.com/yvasiyarov/gorelic v0.0.0-20141212073537-a9bba5b9ab50/go.mod h1:NUSPSUX/bi6SeDMUh6brw0nXpxHnc96TguQh0+r/ssA= github.com/yvasiyarov/newrelic_platform_go v0.0.0-20140908184405-b21fdbd4370f/go.mod h1:GlGEuHIJweS1mbCqG+7vt2nvWLzLLnRHbXz5JKd/Qbg= @@ -1438,6 +1468,7 @@ golang.org/x/net v0.0.0-20181011144130-49bb7cea24b1/go.mod h1:mL1N/T3taQHkDXs73r golang.org/x/net v0.0.0-20181023162649-9b4f9f5ad519/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181108082009-03003ca0c849/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20181201002055-351d144fa1fc/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181220203305-927f97764cc3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -1703,6 +1734,7 @@ golang.org/x/tools v0.0.0-20190927191325-030b2cf1153e/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.0.0-20191012152004-8de300cfc20a/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191112195655-aa38f8e97acc/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191113191852-77e3bb0ad9e7/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191115202509-3a792d9c32b2/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= @@ -1952,6 +1984,7 @@ gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMy gopkg.in/gemnasium/logrus-airbrake-hook.v2 v2.1.2/go.mod h1:Xk6kEKp8OKb+X14hQBKWaSkCsqBpgog8nAV2xsGOxlo= gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s= gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= +gopkg.in/ini.v1 v1.51.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/ini.v1 v1.66.2/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/natefinch/lumberjack.v2 v2.0.0/go.mod h1:l0ndWWf7gzL7RNwBG7wST/UCcT4T24xpD6X8LsfU/+k= gopkg.in/resty.v1 v1.12.0/go.mod h1:mDo4pnntr5jdWRML875a/NmxYqAlA73dVijT2AXvQQo= diff --git a/httpmw/workspaceagent.go b/httpmw/workspaceagent.go new file mode 100644 index 0000000000000..85a46ed0001cb --- /dev/null +++ b/httpmw/workspaceagent.go @@ -0,0 +1,65 @@ +package httpmw + +import ( + "context" + "database/sql" + "errors" + "fmt" + "net/http" + + "github.com/google/uuid" + + "github.com/coder/coder/database" + "github.com/coder/coder/httpapi" +) + +type workspaceAgentContextKey struct{} + +// WorkspaceAgent returns the workspace agent from the ExtractAgent handler. +func WorkspaceAgent(r *http.Request) database.ProvisionerJobAgent { + user, ok := r.Context().Value(workspaceAgentContextKey{}).(database.ProvisionerJobAgent) + if !ok { + panic("developer error: agent middleware not provided") + } + return user +} + +// ExtractWorkspaceAgent requires authentication using a valid agent token. +func ExtractWorkspaceAgent(db database.Store) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + cookie, err := r.Cookie(AuthCookie) + if err != nil { + httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{ + Message: fmt.Sprintf("%q cookie must be provided", AuthCookie), + }) + return + } + token, err := uuid.Parse(cookie.Value) + if err != nil { + httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{ + Message: fmt.Sprintf("parse token: %s", err), + }) + return + } + agent, err := db.GetProvisionerJobAgentByAuthToken(r.Context(), token) + if errors.Is(err, sql.ErrNoRows) { + if errors.Is(err, sql.ErrNoRows) { + httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{ + Message: "agent token is invalid", + }) + return + } + } + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get workspace agent: %s", err), + }) + return + } + + ctx := context.WithValue(r.Context(), workspaceAgentContextKey{}, agent) + next.ServeHTTP(rw, r.WithContext(ctx)) + }) + } +} diff --git a/httpmw/workspaceagent_test.go b/httpmw/workspaceagent_test.go new file mode 100644 index 0000000000000..b36d74c9c23dc --- /dev/null +++ b/httpmw/workspaceagent_test.go @@ -0,0 +1,73 @@ +package httpmw_test + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-chi/chi/v5" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/database" + "github.com/coder/coder/database/databasefake" + "github.com/coder/coder/httpmw" +) + +func TestWorkspaceAgent(t *testing.T) { + t.Parallel() + + setup := func(db database.Store) (*http.Request, uuid.UUID) { + token := uuid.New() + r := httptest.NewRequest("GET", "/", nil) + r.AddCookie(&http.Cookie{ + Name: httpmw.AuthCookie, + Value: token.String(), + }) + return r, token + } + + t.Run("None", func(t *testing.T) { + t.Parallel() + db := databasefake.New() + rtr := chi.NewRouter() + rtr.Use( + httpmw.ExtractWorkspaceAgent(db), + ) + rtr.Get("/", nil) + r, _ := setup(db) + rw := httptest.NewRecorder() + rtr.ServeHTTP(rw, r) + + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusUnauthorized, res.StatusCode) + }) + + t.Run("Found", func(t *testing.T) { + t.Parallel() + db := databasefake.New() + rtr := chi.NewRouter() + rtr.Use( + httpmw.ExtractWorkspaceAgent(db), + ) + rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) { + _ = httpmw.WorkspaceAgent(r) + rw.WriteHeader(http.StatusOK) + }) + r, token := setup(db) + _, err := db.InsertProvisionerJobAgent(context.Background(), database.InsertProvisionerJobAgentParams{ + ID: uuid.New(), + AuthToken: token, + }) + require.NoError(t, err) + require.NoError(t, err) + rw := httptest.NewRecorder() + rtr.ServeHTTP(rw, r) + + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) + }) +} diff --git a/httpmw/workspaceresourceparam.go b/httpmw/workspaceresourceparam.go new file mode 100644 index 0000000000000..2f19b153dc701 --- /dev/null +++ b/httpmw/workspaceresourceparam.go @@ -0,0 +1,64 @@ +package httpmw + +import ( + "context" + "database/sql" + "errors" + "fmt" + "net/http" + + "github.com/go-chi/chi/v5" + "github.com/google/uuid" + + "github.com/coder/coder/database" + "github.com/coder/coder/httpapi" +) + +type workspaceResourceParamContextKey struct{} + +// ProvisionerJobParam returns the project from the ExtractProjectParam handler. +func WorkspaceResourceParam(r *http.Request) database.ProvisionerJobResource { + resource, ok := r.Context().Value(workspaceResourceParamContextKey{}).(database.ProvisionerJobResource) + if !ok { + panic("developer error: workspace resource param middleware not provided") + } + return resource +} + +// ExtractWorkspaceResourceParam grabs a workspace resource from the "provisionerjob" URL parameter. +func ExtractWorkspaceResourceParam(db database.Store) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + resourceID := chi.URLParam(r, "workspaceresource") + if resourceID == "" { + httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{ + Message: "workspace resource must be provided", + }) + return + } + resourceUUID, err := uuid.Parse(resourceID) + if err != nil { + httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{ + Message: "resource id must be a uuid", + }) + return + } + resource, err := db.GetProvisionerJobResourceByID(r.Context(), resourceUUID) + if errors.Is(err, sql.ErrNoRows) { + httpapi.Write(rw, http.StatusNotFound, httpapi.Response{ + Message: "resource doesn't exist with that id", + }) + return + } + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get provisioner resource: %s", err), + }) + return + } + + ctx := context.WithValue(r.Context(), workspaceResourceParamContextKey{}, resource) + next.ServeHTTP(rw, r.WithContext(ctx)) + }) + } +} diff --git a/httpmw/workspaceresourceparam_test.go b/httpmw/workspaceresourceparam_test.go new file mode 100644 index 0000000000000..454a061c665e8 --- /dev/null +++ b/httpmw/workspaceresourceparam_test.go @@ -0,0 +1,109 @@ +package httpmw_test + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-chi/chi/v5" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/database" + "github.com/coder/coder/database/databasefake" + "github.com/coder/coder/httpmw" +) + +func TestWorkspaceResourceParam(t *testing.T) { + t.Parallel() + + setup := func(db database.Store) (*http.Request, database.ProvisionerJobResource) { + r := httptest.NewRequest("GET", "/", nil) + resource, err := db.InsertProvisionerJobResource(context.Background(), database.InsertProvisionerJobResourceParams{ + ID: uuid.New(), + }) + require.NoError(t, err) + + ctx := chi.NewRouteContext() + r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, ctx)) + return r, resource + } + + t.Run("None", func(t *testing.T) { + t.Parallel() + db := databasefake.New() + rtr := chi.NewRouter() + rtr.Use( + httpmw.ExtractWorkspaceResourceParam(db), + ) + rtr.Get("/", nil) + r, _ := setup(db) + rw := httptest.NewRecorder() + rtr.ServeHTTP(rw, r) + + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusBadRequest, res.StatusCode) + }) + + t.Run("BadUUID", func(t *testing.T) { + t.Parallel() + db := databasefake.New() + rtr := chi.NewRouter() + rtr.Use( + httpmw.ExtractWorkspaceResourceParam(db), + ) + rtr.Get("/", nil) + + r, _ := setup(db) + chi.RouteContext(r.Context()).URLParams.Add("workspaceresource", "nothin") + rw := httptest.NewRecorder() + rtr.ServeHTTP(rw, r) + + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusBadRequest, res.StatusCode) + }) + + t.Run("NotFound", func(t *testing.T) { + t.Parallel() + db := databasefake.New() + rtr := chi.NewRouter() + rtr.Use( + httpmw.ExtractWorkspaceResourceParam(db), + ) + rtr.Get("/", nil) + + r, _ := setup(db) + chi.RouteContext(r.Context()).URLParams.Add("workspaceresource", uuid.NewString()) + rw := httptest.NewRecorder() + rtr.ServeHTTP(rw, r) + + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusNotFound, res.StatusCode) + }) + + t.Run("ProvisionerJob", func(t *testing.T) { + t.Parallel() + db := databasefake.New() + rtr := chi.NewRouter() + rtr.Use( + httpmw.ExtractWorkspaceResourceParam(db), + ) + rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) { + _ = httpmw.WorkspaceResourceParam(r) + rw.WriteHeader(http.StatusOK) + }) + + r, job := setup(db) + chi.RouteContext(r.Context()).URLParams.Add("workspaceresource", job.ID.String()) + rw := httptest.NewRecorder() + rtr.ServeHTTP(rw, r) + + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) + }) +} diff --git a/peerbroker/listen.go b/peerbroker/listen.go index 8e92fe7a7c82d..00c9faf125ba8 100644 --- a/peerbroker/listen.go +++ b/peerbroker/listen.go @@ -22,7 +22,8 @@ import ( func Listen(connListener net.Listener, opts *peer.ConnOptions) (*Listener, error) { ctx, cancelFunc := context.WithCancel(context.Background()) listener := &Listener{ - connectionChannel: make(chan *peer.Conn), + connectionChannel: make(chan *peer.Conn), + connectionListener: connListener, closeFunc: cancelFunc, closed: make(chan struct{}), @@ -47,7 +48,8 @@ func Listen(connListener net.Listener, opts *peer.ConnOptions) (*Listener, error } type Listener struct { - connectionChannel chan *peer.Conn + connectionChannel chan *peer.Conn + connectionListener net.Listener closeFunc context.CancelFunc closed chan struct{} @@ -79,6 +81,7 @@ func (l *Listener) closeWithError(err error) error { return l.closeError } + _ = l.connectionListener.Close() l.closeError = err l.closeFunc() close(l.closed) diff --git a/provisionersdk/agent.go b/provisionersdk/agent.go index 7d97dd38ce70a..2f2a3a74dc68e 100644 --- a/provisionersdk/agent.go +++ b/provisionersdk/agent.go @@ -18,6 +18,7 @@ var ( $ProgressPreference = "SilentlyContinue" $ErrorActionPreference = "Stop" Invoke-WebRequest -Uri ${DOWNLOAD_URL} -OutFile $env:TEMP\coder.exe +$env:CODER_URL = "${ACCESS_URL}" Start-Process -FilePath $env:TEMP\coder.exe workspaces agent `, }, @@ -28,6 +29,7 @@ set -eu pipefail BINARY_LOCATION=$(mktemp -d)/coder curl -fsSL ${DOWNLOAD_URL} -o $BINARY_LOCATION chmod +x $BINARY_LOCATION +export CODER_URL="${ACCESS_URL}" exec $BINARY_LOCATION agent `, }, @@ -38,6 +40,7 @@ set -eu pipefail BINARY_LOCATION=$(mktemp -d)/coder curl -fsSL ${DOWNLOAD_URL} -o $BINARY_LOCATION chmod +x $BINARY_LOCATION +export CODER_URL="${ACCESS_URL}" exec $BINARY_LOCATION agent `, }, @@ -63,9 +66,16 @@ func AgentScript(coderURL *url.URL, operatingSystem, architecture string) (strin } return "", xerrors.Errorf("architecture %q not supported for %q. must be in: %v", architecture, operatingSystem, list) } - parsed, err := coderURL.Parse(fmt.Sprintf("/bin/coder-%s-%s", operatingSystem, architecture)) + downloadURL, err := coderURL.Parse(fmt.Sprintf("/bin/coder-%s-%s", operatingSystem, architecture)) if err != nil { - return "", xerrors.Errorf("parse url: %w", err) + return "", xerrors.Errorf("parse download url: %w", err) } - return strings.ReplaceAll(script, "${DOWNLOAD_URL}", parsed.String()), nil + accessURL, err := coderURL.Parse("/") + if err != nil { + return "", xerrors.Errorf("parse access url: %w", err) + } + return strings.NewReplacer( + "${DOWNLOAD_URL}", downloadURL.String(), + "${ACCESS_URL}", accessURL.String(), + ).Replace(script), nil } From 1aec36d1281e6dd4ebb06f3501119dcb285ad7df Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Thu, 3 Mar 2022 16:42:28 +0000 Subject: [PATCH 2/5] Cleanup code --- coderd/coderd.go | 1 + coderd/provisionerjobs.go | 45 +++++++++++++++-- coderd/workspaceagent.go | 39 +++++++++++++-- coderd/workspaceagent_test.go | 20 +++++--- database/databasefake/databasefake.go | 32 +++++++----- database/querier.go | 3 +- database/query.sql | 12 ++++- database/query.sql.go | 71 ++++++++++++++------------- provisioner/echo/serve.go | 33 +++++++++++-- 9 files changed, 188 insertions(+), 68 deletions(-) diff --git a/coderd/coderd.go b/coderd/coderd.go index 90a34787404a1..b86de63834e00 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -156,6 +156,7 @@ func New(options *Options) (http.Handler, func()) { r.Get("/", api.provisionerJobResourcesByID) r.Route("/{workspaceresource}", func(r chi.Router) { r.Use(httpmw.ExtractWorkspaceResourceParam(options.Database)) + r.Get("/", api.provisionerJobResourceByID) r.Get("/agent", api.workspaceAgentConnectByResource) }) }) diff --git a/coderd/provisionerjobs.go b/coderd/provisionerjobs.go index e31f187d15365..3e2e665ee3fcc 100644 --- a/coderd/provisionerjobs.go +++ b/coderd/provisionerjobs.go @@ -261,15 +261,54 @@ func (api *api) provisionerJobResourcesByID(rw http.ResponseWriter, r *http.Requ apiResources = append(apiResources, convertProvisionerJobResource(resource, nil)) continue } - // TODO: This should be combined. - agents, err := api.Database.GetProvisionerJobAgentsByResourceIDs(r.Context(), []uuid.UUID{resource.ID}) + agent, err := api.Database.GetProvisionerJobAgentByResourceID(r.Context(), resource.ID) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get provisioner job agent: %s", err), + }) + return + } + apiAgent, err := convertProvisionerJobAgent(agent) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("convert provisioner job agent: %s", err), + }) + return + } + apiResources = append(apiResources, convertProvisionerJobResource(resource, &apiAgent)) + } + render.Status(r, http.StatusOK) + render.JSON(rw, r, apiResources) +} + +func (api *api) provisionerJobResourceByID(rw http.ResponseWriter, r *http.Request) { + job := httpmw.ProvisionerJobParam(r) + if !convertProvisionerJob(job).Status.Completed() { + httpapi.Write(rw, http.StatusPreconditionFailed, httpapi.Response{ + Message: "Job hasn't completed!", + }) + return + } + resources, err := api.Database.GetProvisionerJobResourcesByJobID(r.Context(), job.ID) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get provisioner job resources: %s", err), + }) + return + } + apiResources := make([]ProvisionerJobResource, 0) + for _, resource := range resources { + if !resource.AgentID.Valid { + apiResources = append(apiResources, convertProvisionerJobResource(resource, nil)) + continue + } + agent, err := api.Database.GetProvisionerJobAgentByResourceID(r.Context(), resource.ID) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ Message: fmt.Sprintf("get provisioner job agent: %s", err), }) return } - agent := agents[0] apiAgent, err := convertProvisionerJobAgent(agent) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ diff --git a/coderd/workspaceagent.go b/coderd/workspaceagent.go index 4ebc77b100bd1..f70c663416f5c 100644 --- a/coderd/workspaceagent.go +++ b/coderd/workspaceagent.go @@ -1,14 +1,16 @@ package coderd import ( + "database/sql" "fmt" "io" "net/http" + "time" - "github.com/google/uuid" "github.com/hashicorp/yamux" "nhooyr.io/websocket" + "github.com/coder/coder/database" "github.com/coder/coder/httpapi" "github.com/coder/coder/httpmw" "github.com/coder/coder/peerbroker" @@ -27,14 +29,13 @@ func (api *api) workspaceAgentConnectByResource(rw http.ResponseWriter, r *http. }) return } - agents, err := api.Database.GetProvisionerJobAgentsByResourceIDs(r.Context(), []uuid.UUID{resource.ID}) + agent, err := api.Database.GetProvisionerJobAgentByResourceID(r.Context(), resource.ID) if err != nil { httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{ Message: fmt.Sprintf("get provisioner job agent: %s", err), }) return } - agent := agents[0] conn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{ CompressionMode: websocket.CompressionDisabled, }) @@ -99,5 +100,35 @@ func (api *api) workspaceAgentServe(rw http.ResponseWriter, r *http.Request) { return } defer closer.Close() - <-session.CloseChan() + err = api.Database.UpdateProvisionerJobAgentByID(r.Context(), database.UpdateProvisionerJobAgentByIDParams{ + ID: agent.ID, + UpdatedAt: sql.NullTime{ + Time: database.Now(), + Valid: true, + }, + }) + if err != nil { + _ = conn.Close(websocket.StatusAbnormalClosure, err.Error()) + return + } + ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() + for { + select { + case <-session.CloseChan(): + return + case <-ticker.C: + err = api.Database.UpdateProvisionerJobAgentByID(r.Context(), database.UpdateProvisionerJobAgentByIDParams{ + ID: agent.ID, + UpdatedAt: sql.NullTime{ + Time: database.Now(), + Valid: true, + }, + }) + if err != nil { + _ = conn.Close(websocket.StatusAbnormalClosure, err.Error()) + return + } + } + } } diff --git a/coderd/workspaceagent_test.go b/coderd/workspaceagent_test.go index 3737544f508ca..17d56b8c7a560 100644 --- a/coderd/workspaceagent_test.go +++ b/coderd/workspaceagent_test.go @@ -29,7 +29,8 @@ func TestWorkspaceAgentServe(t *testing.T) { daemonCloser := coderdtest.NewProvisionerDaemon(t, client) authToken := uuid.NewString() job := coderdtest.CreateProjectImportJob(t, client, user.Organization, &echo.Responses{ - Parse: echo.ParseComplete, + Parse: echo.ParseComplete, + ProvisionDryRun: echo.ProvisionComplete, Provision: []*proto.Provision_Response{{ Type: &proto.Provision_Response_Complete{ Complete: &proto.Provision_Complete{ @@ -50,16 +51,13 @@ func TestWorkspaceAgentServe(t *testing.T) { project := coderdtest.CreateProject(t, client, user.Organization, job.ID) coderdtest.AwaitProjectImportJob(t, client, user.Organization, job.ID) workspace := coderdtest.CreateWorkspace(t, client, "me", project.ID) - firstHistory, err := client.CreateWorkspaceHistory(context.Background(), "", workspace.Name, coderd.CreateWorkspaceHistoryRequest{ + history, err := client.CreateWorkspaceHistory(context.Background(), "", workspace.Name, coderd.CreateWorkspaceHistoryRequest{ ProjectVersionID: project.ActiveVersionID, Transition: database.WorkspaceTransitionStart, }) require.NoError(t, err) - coderdtest.AwaitWorkspaceProvisionJob(t, client, user.Organization, firstHistory.ProvisionJobID) + coderdtest.AwaitWorkspaceProvisionJob(t, client, user.Organization, history.ProvisionJobID) daemonCloser.Close() - resources, err := client.WorkspaceProvisionJobResources(context.Background(), user.Organization, firstHistory.ProvisionJobID) - require.NoError(t, err) - require.Len(t, resources, 1) agentClient := codersdk.New(client.URL) agentClient.SessionToken = authToken @@ -67,9 +65,15 @@ func TestWorkspaceAgentServe(t *testing.T) { Logger: slogtest.Make(t, nil), }) - time.Sleep(time.Millisecond * 250) + var resources []coderd.ProvisionerJobResource + require.Eventually(t, func() bool { + resources, err = client.WorkspaceProvisionJobResources(context.Background(), user.Organization, history.ProvisionJobID) + require.NoError(t, err) + require.Len(t, resources, 1) + return !resources[0].Agent.UpdatedAt.IsZero() + }, 5*time.Second, 25*time.Millisecond) - workspaceClient, err := client.WorkspaceAgentConnect(context.Background(), user.Organization, firstHistory.ProvisionJobID, resources[0].ID) + workspaceClient, err := client.WorkspaceAgentConnect(context.Background(), user.Organization, history.ProvisionJobID, resources[0].ID) require.NoError(t, err) stream, err := workspaceClient.NegotiateConnection(context.Background()) require.NoError(t, err) diff --git a/database/databasefake/databasefake.go b/database/databasefake/databasefake.go index ad4d791d42e07..65cbbdd528d9e 100644 --- a/database/databasefake/databasefake.go +++ b/database/databasefake/databasefake.go @@ -512,8 +512,7 @@ func (q *fakeQuerier) GetProvisionerJobAgentByAuthToken(_ context.Context, authT q.mutex.Lock() defer q.mutex.Unlock() - for i := len(q.provisionerJobAgent) - 1; i >= 0; i-- { - agent := q.provisionerJobAgent[i] + for _, agent := range q.provisionerJobAgent { if agent.AuthToken.String() == authToken.String() { return agent, nil } @@ -535,22 +534,16 @@ func (q *fakeQuerier) GetProvisionerJobAgentByInstanceID(_ context.Context, inst return database.ProvisionerJobAgent{}, sql.ErrNoRows } -func (q *fakeQuerier) GetProvisionerJobAgentsByResourceIDs(_ context.Context, ids []uuid.UUID) ([]database.ProvisionerJobAgent, error) { +func (q *fakeQuerier) GetProvisionerJobAgentByResourceID(ctx context.Context, resourceID uuid.UUID) (database.ProvisionerJobAgent, error) { q.mutex.Lock() defer q.mutex.Unlock() - agents := make([]database.ProvisionerJobAgent, 0) for _, agent := range q.provisionerJobAgent { - for _, id := range ids { - if agent.ResourceID.String() == id.String() { - agents = append(agents, agent) - } + if agent.ResourceID.String() == resourceID.String() { + return agent, nil } } - if len(agents) == 0 { - return nil, sql.ErrNoRows - } - return agents, nil + return database.ProvisionerJobAgent{}, sql.ErrNoRows } func (q *fakeQuerier) GetProvisionerDaemonByID(_ context.Context, id uuid.UUID) (database.ProvisionerDaemon, error) { @@ -969,6 +962,21 @@ func (q *fakeQuerier) UpdateProvisionerDaemonByID(_ context.Context, arg databas return sql.ErrNoRows } +func (q *fakeQuerier) UpdateProvisionerJobAgentByID(ctx context.Context, arg database.UpdateProvisionerJobAgentByIDParams) error { + q.mutex.Lock() + defer q.mutex.Unlock() + + for index, agent := range q.provisionerJobAgent { + if agent.ID.String() != arg.ID.String() { + continue + } + agent.UpdatedAt = arg.UpdatedAt + q.provisionerJobAgent[index] = agent + return nil + } + return sql.ErrNoRows +} + func (q *fakeQuerier) UpdateProvisionerJobByID(_ context.Context, arg database.UpdateProvisionerJobByIDParams) error { q.mutex.Lock() defer q.mutex.Unlock() diff --git a/database/querier.go b/database/querier.go index e6d33a1a4002a..20eaf693d99bc 100644 --- a/database/querier.go +++ b/database/querier.go @@ -28,7 +28,7 @@ type querier interface { GetProvisionerDaemons(ctx context.Context) ([]ProvisionerDaemon, error) GetProvisionerJobAgentByAuthToken(ctx context.Context, authToken uuid.UUID) (ProvisionerJobAgent, error) GetProvisionerJobAgentByInstanceID(ctx context.Context, authInstanceID string) (ProvisionerJobAgent, error) - GetProvisionerJobAgentsByResourceIDs(ctx context.Context, ids []uuid.UUID) ([]ProvisionerJobAgent, error) + GetProvisionerJobAgentByResourceID(ctx context.Context, resourceID uuid.UUID) (ProvisionerJobAgent, error) GetProvisionerJobByID(ctx context.Context, id uuid.UUID) (ProvisionerJob, error) GetProvisionerJobResourceByID(ctx context.Context, id uuid.UUID) (ProvisionerJobResource, error) GetProvisionerJobResourcesByJobID(ctx context.Context, jobID uuid.UUID) ([]ProvisionerJobResource, error) @@ -63,6 +63,7 @@ type querier interface { InsertWorkspaceHistory(ctx context.Context, arg InsertWorkspaceHistoryParams) (WorkspaceHistory, error) UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error UpdateProvisionerDaemonByID(ctx context.Context, arg UpdateProvisionerDaemonByIDParams) error + UpdateProvisionerJobAgentByID(ctx context.Context, arg UpdateProvisionerJobAgentByIDParams) error UpdateProvisionerJobByID(ctx context.Context, arg UpdateProvisionerJobByIDParams) error UpdateProvisionerJobWithCompleteByID(ctx context.Context, arg UpdateProvisionerJobWithCompleteByIDParams) error UpdateWorkspaceHistoryByID(ctx context.Context, arg UpdateWorkspaceHistoryByIDParams) error diff --git a/database/query.sql b/database/query.sql index 3b2574acb15f1..8e1d9cbc4b805 100644 --- a/database/query.sql +++ b/database/query.sql @@ -354,13 +354,13 @@ FROM WHERE job_id = $1; --- name: GetProvisionerJobAgentsByResourceIDs :many +-- name: GetProvisionerJobAgentByResourceID :one SELECT * FROM provisioner_job_agent WHERE - resource_id = ANY(@ids :: uuid [ ]); + resource_id = $1; -- name: InsertAPIKey :one INSERT INTO @@ -660,6 +660,14 @@ SET WHERE id = $1; +-- name: UpdateProvisionerJobAgentByID :exec +UPDATE + provisioner_job_agent +SET + updated_at = $2 +WHERE + id = $1; + -- name: UpdateWorkspaceHistoryByID :exec UPDATE workspace_history diff --git a/database/query.sql.go b/database/query.sql.go index 52475635e0785..aa724700fa662 100644 --- a/database/query.sql.go +++ b/database/query.sql.go @@ -673,47 +673,31 @@ func (q *sqlQuerier) GetProvisionerJobAgentByInstanceID(ctx context.Context, aut return i, err } -const getProvisionerJobAgentsByResourceIDs = `-- name: GetProvisionerJobAgentsByResourceIDs :many +const getProvisionerJobAgentByResourceID = `-- name: GetProvisionerJobAgentByResourceID :one SELECT id, created_at, updated_at, resource_id, auth_token, auth_instance_id, environment_variables, startup_script, instance_metadata, resource_metadata FROM provisioner_job_agent WHERE - resource_id = ANY($1 :: uuid [ ]) + resource_id = $1 ` -func (q *sqlQuerier) GetProvisionerJobAgentsByResourceIDs(ctx context.Context, ids []uuid.UUID) ([]ProvisionerJobAgent, error) { - rows, err := q.db.QueryContext(ctx, getProvisionerJobAgentsByResourceIDs, pq.Array(ids)) - if err != nil { - return nil, err - } - defer rows.Close() - var items []ProvisionerJobAgent - for rows.Next() { - var i ProvisionerJobAgent - if err := rows.Scan( - &i.ID, - &i.CreatedAt, - &i.UpdatedAt, - &i.ResourceID, - &i.AuthToken, - &i.AuthInstanceID, - &i.EnvironmentVariables, - &i.StartupScript, - &i.InstanceMetadata, - &i.ResourceMetadata, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Close(); err != nil { - return nil, err - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil +func (q *sqlQuerier) GetProvisionerJobAgentByResourceID(ctx context.Context, resourceID uuid.UUID) (ProvisionerJobAgent, error) { + row := q.db.QueryRowContext(ctx, getProvisionerJobAgentByResourceID, resourceID) + var i ProvisionerJobAgent + err := row.Scan( + &i.ID, + &i.CreatedAt, + &i.UpdatedAt, + &i.ResourceID, + &i.AuthToken, + &i.AuthInstanceID, + &i.EnvironmentVariables, + &i.StartupScript, + &i.InstanceMetadata, + &i.ResourceMetadata, + ) + return i, err } const getProvisionerJobByID = `-- name: GetProvisionerJobByID :one @@ -2234,6 +2218,25 @@ func (q *sqlQuerier) UpdateProvisionerDaemonByID(ctx context.Context, arg Update return err } +const updateProvisionerJobAgentByID = `-- name: UpdateProvisionerJobAgentByID :exec +UPDATE + provisioner_job_agent +SET + updated_at = $2 +WHERE + id = $1 +` + +type UpdateProvisionerJobAgentByIDParams struct { + ID uuid.UUID `db:"id" json:"id"` + UpdatedAt sql.NullTime `db:"updated_at" json:"updated_at"` +} + +func (q *sqlQuerier) UpdateProvisionerJobAgentByID(ctx context.Context, arg UpdateProvisionerJobAgentByIDParams) error { + _, err := q.db.ExecContext(ctx, updateProvisionerJobAgentByID, arg.ID, arg.UpdatedAt) + return err +} + const updateProvisionerJobByID = `-- name: UpdateProvisionerJobByID :exec UPDATE provisioner_job diff --git a/provisioner/echo/serve.go b/provisioner/echo/serve.go index 0d2ac2e65d8cd..9d003ac8311cb 100644 --- a/provisioner/echo/serve.go +++ b/provisioner/echo/serve.go @@ -75,7 +75,11 @@ func (*echo) Parse(request *proto.Parse_Request, stream proto.DRPCProvisioner_Pa // Provision reads requests from the provided directory to stream responses. func (*echo) Provision(request *proto.Provision_Request, stream proto.DRPCProvisioner_ProvisionStream) error { for index := 0; ; index++ { - path := filepath.Join(request.Directory, fmt.Sprintf("%d.provision.protobuf", index)) + extension := ".protobuf" + if request.DryRun { + extension = ".dry.protobuf" + } + path := filepath.Join(request.Directory, fmt.Sprintf("%d.provision"+extension, index)) _, err := os.Stat(path) if err != nil { if index == 0 { @@ -107,14 +111,18 @@ func (*echo) Shutdown(_ context.Context, _ *proto.Empty) (*proto.Empty, error) { } type Responses struct { - Parse []*proto.Parse_Response - Provision []*proto.Provision_Response + Parse []*proto.Parse_Response + Provision []*proto.Provision_Response + ProvisionDryRun []*proto.Provision_Response } // Tar returns a tar archive of responses to provisioner operations. func Tar(responses *Responses) ([]byte, error) { if responses == nil { - responses = &Responses{ParseComplete, ProvisionComplete} + responses = &Responses{ParseComplete, ProvisionComplete, ProvisionComplete} + } + if responses.ProvisionDryRun == nil { + responses.ProvisionDryRun = responses.Provision } var buffer bytes.Buffer @@ -153,6 +161,23 @@ func Tar(responses *Responses) ([]byte, error) { return nil, err } } + for index, response := range responses.ProvisionDryRun { + data, err := protobuf.Marshal(response) + if err != nil { + return nil, err + } + err = writer.WriteHeader(&tar.Header{ + Name: fmt.Sprintf("%d.provision.dry.protobuf", index), + Size: int64(len(data)), + }) + if err != nil { + return nil, err + } + _, err = writer.Write(data) + if err != nil { + return nil, err + } + } err := writer.Flush() if err != nil { return nil, err From 74f0328d78437296a11a926d7b296fe45ae6604a Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Thu, 3 Mar 2022 16:48:35 +0000 Subject: [PATCH 3/5] Fix linting error --- cli/workspaceagent.go | 5 ++-- cli/workspaces.go | 1 + coderd/coderd.go | 1 - coderd/provisionerjobs.go | 42 ++------------------------- coderd/workspaceagent_test.go | 5 ++-- codersdk/workspaceagent.go | 5 ++-- database/databasefake/databasefake.go | 4 +-- 7 files changed, 14 insertions(+), 49 deletions(-) diff --git a/cli/workspaceagent.go b/cli/workspaceagent.go index 75a462c0a100e..3246b21bee5f1 100644 --- a/cli/workspaceagent.go +++ b/cli/workspaceagent.go @@ -4,11 +4,12 @@ import ( "net/url" "os" - "github.com/coder/coder/agent" - "github.com/coder/coder/codersdk" "github.com/powersj/whatsthis/pkg/cloud" "github.com/spf13/cobra" "golang.org/x/xerrors" + + "github.com/coder/coder/agent" + "github.com/coder/coder/codersdk" ) func workspaceAgent() *cobra.Command { diff --git a/cli/workspaces.go b/cli/workspaces.go index d405f00cea88b..b470fc7df1c60 100644 --- a/cli/workspaces.go +++ b/cli/workspaces.go @@ -6,6 +6,7 @@ func workspaces() *cobra.Command { cmd := &cobra.Command{ Use: "workspaces", } + cmd.AddCommand(workspaceAgent()) cmd.AddCommand(workspaceCreate()) return cmd diff --git a/coderd/coderd.go b/coderd/coderd.go index b86de63834e00..90a34787404a1 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -156,7 +156,6 @@ func New(options *Options) (http.Handler, func()) { r.Get("/", api.provisionerJobResourcesByID) r.Route("/{workspaceresource}", func(r chi.Router) { r.Use(httpmw.ExtractWorkspaceResourceParam(options.Database)) - r.Get("/", api.provisionerJobResourceByID) r.Get("/agent", api.workspaceAgentConnectByResource) }) }) diff --git a/coderd/provisionerjobs.go b/coderd/provisionerjobs.go index 3e2e665ee3fcc..0016ceb77edf6 100644 --- a/coderd/provisionerjobs.go +++ b/coderd/provisionerjobs.go @@ -249,47 +249,9 @@ func (api *api) provisionerJobResourcesByID(rw http.ResponseWriter, r *http.Requ return } resources, err := api.Database.GetProvisionerJobResourcesByJobID(r.Context(), job.ID) - if err != nil { - httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ - Message: fmt.Sprintf("get provisioner job resources: %s", err), - }) - return - } - apiResources := make([]ProvisionerJobResource, 0) - for _, resource := range resources { - if !resource.AgentID.Valid { - apiResources = append(apiResources, convertProvisionerJobResource(resource, nil)) - continue - } - agent, err := api.Database.GetProvisionerJobAgentByResourceID(r.Context(), resource.ID) - if err != nil { - httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ - Message: fmt.Sprintf("get provisioner job agent: %s", err), - }) - return - } - apiAgent, err := convertProvisionerJobAgent(agent) - if err != nil { - httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ - Message: fmt.Sprintf("convert provisioner job agent: %s", err), - }) - return - } - apiResources = append(apiResources, convertProvisionerJobResource(resource, &apiAgent)) - } - render.Status(r, http.StatusOK) - render.JSON(rw, r, apiResources) -} - -func (api *api) provisionerJobResourceByID(rw http.ResponseWriter, r *http.Request) { - job := httpmw.ProvisionerJobParam(r) - if !convertProvisionerJob(job).Status.Completed() { - httpapi.Write(rw, http.StatusPreconditionFailed, httpapi.Response{ - Message: "Job hasn't completed!", - }) - return + if errors.Is(err, sql.ErrNoRows) { + err = nil } - resources, err := api.Database.GetProvisionerJobResourcesByJobID(r.Context(), job.ID) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ Message: fmt.Sprintf("get provisioner job resources: %s", err), diff --git a/coderd/workspaceagent_test.go b/coderd/workspaceagent_test.go index 17d56b8c7a560..a7e6dc57376f4 100644 --- a/coderd/workspaceagent_test.go +++ b/coderd/workspaceagent_test.go @@ -5,6 +5,9 @@ import ( "testing" "time" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/agent" @@ -16,8 +19,6 @@ import ( "github.com/coder/coder/peerbroker" "github.com/coder/coder/provisioner/echo" "github.com/coder/coder/provisionersdk/proto" - "github.com/google/uuid" - "github.com/stretchr/testify/require" ) func TestWorkspaceAgentServe(t *testing.T) { diff --git a/codersdk/workspaceagent.go b/codersdk/workspaceagent.go index 57507e7c578c6..be451aa28b62e 100644 --- a/codersdk/workspaceagent.go +++ b/codersdk/workspaceagent.go @@ -12,14 +12,15 @@ import ( "golang.org/x/xerrors" "nhooyr.io/websocket" + "github.com/google/uuid" + "github.com/hashicorp/yamux" + "github.com/coder/coder/coderd" "github.com/coder/coder/httpmw" "github.com/coder/coder/peer" "github.com/coder/coder/peerbroker" "github.com/coder/coder/peerbroker/proto" "github.com/coder/coder/provisionersdk" - "github.com/google/uuid" - "github.com/hashicorp/yamux" ) // AuthenticateWorkspaceAgentUsingGoogleCloudIdentity uses the Google Compute Engine Metadata API to diff --git a/database/databasefake/databasefake.go b/database/databasefake/databasefake.go index 65cbbdd528d9e..214e6f45bb5d8 100644 --- a/database/databasefake/databasefake.go +++ b/database/databasefake/databasefake.go @@ -534,7 +534,7 @@ func (q *fakeQuerier) GetProvisionerJobAgentByInstanceID(_ context.Context, inst return database.ProvisionerJobAgent{}, sql.ErrNoRows } -func (q *fakeQuerier) GetProvisionerJobAgentByResourceID(ctx context.Context, resourceID uuid.UUID) (database.ProvisionerJobAgent, error) { +func (q *fakeQuerier) GetProvisionerJobAgentByResourceID(_ context.Context, resourceID uuid.UUID) (database.ProvisionerJobAgent, error) { q.mutex.Lock() defer q.mutex.Unlock() @@ -962,7 +962,7 @@ func (q *fakeQuerier) UpdateProvisionerDaemonByID(_ context.Context, arg databas return sql.ErrNoRows } -func (q *fakeQuerier) UpdateProvisionerJobAgentByID(ctx context.Context, arg database.UpdateProvisionerJobAgentByIDParams) error { +func (q *fakeQuerier) UpdateProvisionerJobAgentByID(_ context.Context, arg database.UpdateProvisionerJobAgentByIDParams) error { q.mutex.Lock() defer q.mutex.Unlock() From 6c2a3e2fefa1fcaacacde59a4ee17fb6f6c866fe Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Fri, 4 Mar 2022 01:40:50 +0000 Subject: [PATCH 4/5] Remove DownloadURL --- codersdk/files.go | 7 ------- 1 file changed, 7 deletions(-) diff --git a/codersdk/files.go b/codersdk/files.go index c7eb642f76b68..732cbb2c42d25 100644 --- a/codersdk/files.go +++ b/codersdk/files.go @@ -3,9 +3,7 @@ package codersdk import ( "context" "encoding/json" - "fmt" "net/http" - "net/url" "github.com/coder/coder/coderd" ) @@ -28,8 +26,3 @@ func (c *Client) UploadFile(ctx context.Context, contentType string, content []b var resp coderd.UploadFileResponse return resp, json.NewDecoder(res.Body).Decode(&resp) } - -// DownloadURL returns -func (c *Client) DownloadURL(asset string) (*url.URL, error) { - return c.URL.Parse(fmt.Sprintf("/api/v2/downloads/%s", asset)) -} From 989596547730f7863c993593440206fd96a5934c Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Fri, 4 Mar 2022 17:50:07 +0000 Subject: [PATCH 5/5] Fix listening --- agent/agent.go | 7 +---- cli/projectcreate.go | 15 +++++----- cli/root.go | 1 + cli/ssh.go | 64 +++++++++++++++++++++++++++++++++++++++++ cli/workspaceagent.go | 7 ++++- coderd/cmd/root.go | 18 ++++++++---- peerbroker/listen.go | 4 ++- provisionersdk/agent.go | 4 +-- 8 files changed, 97 insertions(+), 23 deletions(-) diff --git a/agent/agent.go b/agent/agent.go index d8dba42a47b0b..a9d774446ef54 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -135,12 +135,7 @@ func (s *server) init(ctx context.Context) { }, ServerConfigCallback: func(ctx ssh.Context) *gossh.ServerConfig { return &gossh.ServerConfig{ - Config: gossh.Config{ - // "arcfour" is the fastest SSH cipher. We prioritize throughput - // over encryption here, because the WebRTC connection is already - // encrypted. If possible, we'd disable encryption entirely here. - Ciphers: []string{"arcfour"}, - }, + Config: gossh.Config{}, NoClientAuth: true, } }, diff --git a/cli/projectcreate.go b/cli/projectcreate.go index a1e5027c26715..454833299514b 100644 --- a/cli/projectcreate.go +++ b/cli/projectcreate.go @@ -71,13 +71,6 @@ func projectCreate() *cobra.Command { if err != nil { return err } - project, err := client.CreateProject(cmd.Context(), organization.Name, coderd.CreateProjectRequest{ - Name: name, - VersionImportJobID: job.ID, - }) - if err != nil { - return err - } _, err = prompt(cmd, &promptui.Prompt{ Label: "Create project?", @@ -91,6 +84,14 @@ func projectCreate() *cobra.Command { return err } + project, err := client.CreateProject(cmd.Context(), organization.Name, coderd.CreateProjectRequest{ + Name: name, + VersionImportJobID: job.ID, + }) + if err != nil { + return err + } + _, _ = fmt.Fprintf(cmd.OutOrStdout(), "%s The %s project has been created!\n", caret, color.HiCyanString(project.Name)) _, err = prompt(cmd, &promptui.Prompt{ Label: "Create a new workspace?", diff --git a/cli/root.go b/cli/root.go index 054d9d84f942b..81c4e09e4f14b 100644 --- a/cli/root.go +++ b/cli/root.go @@ -68,6 +68,7 @@ func Root() *cobra.Command { cmd.AddCommand(projects()) cmd.AddCommand(workspaces()) cmd.AddCommand(users()) + cmd.AddCommand(ssh()) cmd.PersistentFlags().String(varGlobalConfig, configdir.LocalConfig("coder"), "Path to the global `coder` config directory") cmd.PersistentFlags().Bool(varForceTty, false, "Force the `coder` command to run as if connected to a TTY") diff --git a/cli/ssh.go b/cli/ssh.go index 7f1e458cd3abe..ae7095398b120 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -1 +1,65 @@ package cli + +import ( + "io" + + "github.com/coder/coder/agent" + "github.com/coder/coder/peer" + "github.com/coder/coder/peerbroker" + "github.com/pion/webrtc/v3" + "github.com/spf13/cobra" +) + +func ssh() *cobra.Command { + return &cobra.Command{ + Use: "ssh", + RunE: func(cmd *cobra.Command, args []string) error { + client, err := createClient(cmd) + if err != nil { + return err + } + organization, err := currentOrganization(cmd, client) + if err != nil { + return err + } + history, err := client.WorkspaceHistory(cmd.Context(), "", "kyle", "") + if err != nil { + return err + } + resources, err := client.WorkspaceProvisionJobResources(cmd.Context(), organization.Name, history.ProvisionJobID) + if err != nil { + return err + } + for _, resource := range resources { + if resource.Agent == nil { + continue + } + wagent, err := client.WorkspaceAgentConnect(cmd.Context(), organization.Name, history.ProvisionJobID, resource.ID) + if err != nil { + return err + } + stream, err := wagent.NegotiateConnection(cmd.Context()) + if err != nil { + return err + } + conn, err := peerbroker.Dial(stream, []webrtc.ICEServer{{ + URLs: []string{"stun:stun.l.google.com:19302"}, + }}, &peer.ConnOptions{ + // Logger: slog.Make(sloghuman.Sink(cmd.OutOrStdout())).Leveled(slog.LevelDebug), + }) + if err != nil { + return err + } + sshConn, err := agent.DialSSH(conn) + if err != nil { + return err + } + go func() { + _, _ = io.Copy(cmd.OutOrStdout(), sshConn) + }() + _, _ = io.Copy(sshConn, cmd.InOrStdin()) + } + return nil + }, + } +} diff --git a/cli/workspaceagent.go b/cli/workspaceagent.go index 3246b21bee5f1..ebbbee8e903fd 100644 --- a/cli/workspaceagent.go +++ b/cli/workspaceagent.go @@ -4,12 +4,15 @@ import ( "net/url" "os" + "cdr.dev/slog" + "cdr.dev/slog/sloggers/sloghuman" "github.com/powersj/whatsthis/pkg/cloud" "github.com/spf13/cobra" "golang.org/x/xerrors" "github.com/coder/coder/agent" "github.com/coder/coder/codersdk" + "github.com/coder/coder/peer" ) func workspaceAgent() *cobra.Command { @@ -49,7 +52,9 @@ func workspaceAgent() *cobra.Command { } } client.SessionToken = sessionToken - closer := agent.New(client.WorkspaceAgentServe, nil) + closer := agent.New(client.WorkspaceAgentServe, &peer.ConnOptions{ + Logger: slog.Make(sloghuman.Sink(cmd.OutOrStdout())).Leveled(slog.LevelDebug), + }) <-cmd.Context().Done() return closer.Close() }, diff --git a/coderd/cmd/root.go b/coderd/cmd/root.go index c5acf3bba3bd0..f2633e4b30607 100644 --- a/coderd/cmd/root.go +++ b/coderd/cmd/root.go @@ -12,6 +12,7 @@ import ( "github.com/spf13/cobra" "golang.org/x/xerrors" + "google.golang.org/api/idtoken" "cdr.dev/slog" "cdr.dev/slog/sloggers/sloghuman" @@ -34,14 +35,19 @@ func Root() *cobra.Command { RunE: func(cmd *cobra.Command, args []string) error { logger := slog.Make(sloghuman.Sink(os.Stderr)) accessURL := &url.URL{ - Scheme: "http", - Host: address, + Scheme: "https", + Host: "momentum-finals-representation-failed.trycloudflare.com", + } + googleTokenValidator, err := idtoken.NewValidator(cmd.Context()) + if err != nil { + return xerrors.Errorf("create google token validator: %w", err) } handler, closeCoderd := coderd.New(&coderd.Options{ - AccessURL: accessURL, - Logger: logger, - Database: databasefake.New(), - Pubsub: database.NewPubsubInMemory(), + AccessURL: accessURL, + Logger: logger, + Database: databasefake.New(), + Pubsub: database.NewPubsubInMemory(), + GoogleTokenValidator: googleTokenValidator, }) listener, err := net.Listen("tcp", address) diff --git a/peerbroker/listen.go b/peerbroker/listen.go index 00c9faf125ba8..de761044308eb 100644 --- a/peerbroker/listen.go +++ b/peerbroker/listen.go @@ -108,7 +108,9 @@ type peerBrokerService struct { // NegotiateConnection negotiates a WebRTC connection. func (b *peerBrokerService) NegotiateConnection(stream proto.DRPCPeerBroker_NegotiateConnectionStream) error { // Start with no ICE servers. They can be sent by the client if provided. - peerConn, err := peer.Server([]webrtc.ICEServer{}, b.connOptions) + peerConn, err := peer.Server([]webrtc.ICEServer{{ + URLs: []string{"stun:stun.l.google.com:19302"}, + }}, b.connOptions) if err != nil { return xerrors.Errorf("create peer connection: %w", err) } diff --git a/provisionersdk/agent.go b/provisionersdk/agent.go index 2f2a3a74dc68e..c050a20b377c9 100644 --- a/provisionersdk/agent.go +++ b/provisionersdk/agent.go @@ -30,7 +30,7 @@ BINARY_LOCATION=$(mktemp -d)/coder curl -fsSL ${DOWNLOAD_URL} -o $BINARY_LOCATION chmod +x $BINARY_LOCATION export CODER_URL="${ACCESS_URL}" -exec $BINARY_LOCATION agent +exec $BINARY_LOCATION workspaces agent `, }, "darwin": { @@ -41,7 +41,7 @@ BINARY_LOCATION=$(mktemp -d)/coder curl -fsSL ${DOWNLOAD_URL} -o $BINARY_LOCATION chmod +x $BINARY_LOCATION export CODER_URL="${ACCESS_URL}" -exec $BINARY_LOCATION agent +exec $BINARY_LOCATION workspaces agent `, }, }