diff --git a/coderd/database/dbtypes/dbtypes.go b/coderd/database/dbtypes/dbtypes.go new file mode 100644 index 0000000000000..3653f4f37cb62 --- /dev/null +++ b/coderd/database/dbtypes/dbtypes.go @@ -0,0 +1,74 @@ +package dbtypes + +import ( + "database/sql/driver" + + "golang.org/x/xerrors" + "tailscale.com/types/key" +) + +// NodePublic is a wrapper around a key.NodePublic which represents the +// Wireguard public key for an agent.. +type NodePublic key.NodePublic + +func (n NodePublic) String() string { + return key.NodePublic(n).String() +} + +// This is necessary so NodePublic can be serialized in JSON loggers. +func (n NodePublic) MarshalJSON() ([]byte, error) { + j, err := key.NodePublic(n).MarshalText() + // surround in quotes to make it a JSON string + j = append([]byte{'"'}, append(j, '"')...) + return j, err +} + +// Value is so NodePublic can be inserted into the database. +func (n NodePublic) Value() (driver.Value, error) { + return key.NodePublic(n).MarshalText() +} + +// Scan is so NodePublic can be read from the database. +func (n *NodePublic) Scan(value interface{}) error { + switch v := value.(type) { + case []byte: + return (*key.NodePublic)(n).UnmarshalText(v) + case string: + return (*key.NodePublic)(n).UnmarshalText([]byte(v)) + default: + return xerrors.Errorf("unexpected type: %T", v) + } +} + +// NodePublic is a wrapper around a key.NodePublic which represents the +// Tailscale disco key for an agent. +type DiscoPublic key.DiscoPublic + +func (n DiscoPublic) String() string { + return key.DiscoPublic(n).String() +} + +// This is necessary so DiscoPublic can be serialized in JSON loggers. +func (n DiscoPublic) MarshalJSON() ([]byte, error) { + j, err := key.DiscoPublic(n).MarshalText() + // surround in quotes to make it a JSON string + j = append([]byte{'"'}, append(j, '"')...) + return j, err +} + +// Value is so DiscoPublic can be inserted into the database. +func (n DiscoPublic) Value() (driver.Value, error) { + return key.DiscoPublic(n).MarshalText() +} + +// Scan is so DiscoPublic can be read from the database. +func (n *DiscoPublic) Scan(value interface{}) error { + switch v := value.(type) { + case []byte: + return (*key.DiscoPublic)(n).UnmarshalText(v) + case string: + return (*key.DiscoPublic)(n).UnmarshalText([]byte(v)) + default: + return xerrors.Errorf("unexpected type: %T", v) + } +} diff --git a/coderd/database/models.go b/coderd/database/models.go index 8664803f7d98e..660a7620df454 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -10,6 +10,7 @@ import ( "fmt" "time" + "github.com/coder/coder/coderd/database/dbtypes" "github.com/google/uuid" "github.com/tabbed/pqtype" ) @@ -521,8 +522,8 @@ type WorkspaceAgent struct { ResourceMetadata pqtype.NullRawMessage `db:"resource_metadata" json:"resource_metadata"` Directory string `db:"directory" json:"directory"` WireguardNodeIPv6 pqtype.Inet `db:"wireguard_node_ipv6" json:"wireguard_node_ipv6"` - WireguardNodePublicKey string `db:"wireguard_node_public_key" json:"wireguard_node_public_key"` - WireguardDiscoPublicKey string `db:"wireguard_disco_public_key" json:"wireguard_disco_public_key"` + WireguardNodePublicKey dbtypes.NodePublic `db:"wireguard_node_public_key" json:"wireguard_node_public_key"` + WireguardDiscoPublicKey dbtypes.DiscoPublic `db:"wireguard_disco_public_key" json:"wireguard_disco_public_key"` } type WorkspaceApp struct { diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 4eb3fdf6860c7..96efc5682a981 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -10,6 +10,7 @@ import ( "encoding/json" "time" + "github.com/coder/coder/coderd/database/dbtypes" "github.com/google/uuid" "github.com/lib/pq" "github.com/tabbed/pqtype" @@ -3105,8 +3106,8 @@ type InsertWorkspaceAgentParams struct { InstanceMetadata pqtype.NullRawMessage `db:"instance_metadata" json:"instance_metadata"` ResourceMetadata pqtype.NullRawMessage `db:"resource_metadata" json:"resource_metadata"` WireguardNodeIPv6 pqtype.Inet `db:"wireguard_node_ipv6" json:"wireguard_node_ipv6"` - WireguardNodePublicKey string `db:"wireguard_node_public_key" json:"wireguard_node_public_key"` - WireguardDiscoPublicKey string `db:"wireguard_disco_public_key" json:"wireguard_disco_public_key"` + WireguardNodePublicKey dbtypes.NodePublic `db:"wireguard_node_public_key" json:"wireguard_node_public_key"` + WireguardDiscoPublicKey dbtypes.DiscoPublic `db:"wireguard_disco_public_key" json:"wireguard_disco_public_key"` } func (q *sqlQuerier) InsertWorkspaceAgent(ctx context.Context, arg InsertWorkspaceAgentParams) (WorkspaceAgent, error) { @@ -3196,9 +3197,9 @@ WHERE ` type UpdateWorkspaceAgentKeysByIDParams struct { - ID uuid.UUID `db:"id" json:"id"` - WireguardNodePublicKey string `db:"wireguard_node_public_key" json:"wireguard_node_public_key"` - WireguardDiscoPublicKey string `db:"wireguard_disco_public_key" json:"wireguard_disco_public_key"` + ID uuid.UUID `db:"id" json:"id"` + WireguardNodePublicKey dbtypes.NodePublic `db:"wireguard_node_public_key" json:"wireguard_node_public_key"` + WireguardDiscoPublicKey dbtypes.DiscoPublic `db:"wireguard_disco_public_key" json:"wireguard_disco_public_key"` } func (q *sqlQuerier) UpdateWorkspaceAgentKeysByID(ctx context.Context, arg UpdateWorkspaceAgentKeysByIDParams) error { diff --git a/coderd/database/sqlc.yaml b/coderd/database/sqlc.yaml index 8827cc76083b3..6f42bbaa4bd2c 100644 --- a/coderd/database/sqlc.yaml +++ b/coderd/database/sqlc.yaml @@ -17,10 +17,10 @@ packages: output_db_file_name: db_tmp.go overrides: - - column: workspaces.wireguard_public_key - go_type: tailscale.com/types/key.MachinePublic - - column: workspaces.disco_public_key - go_type: tailscale.com/types/key.DiscoPublic + - column: workspace_agents.wireguard_node_public_key + go_type: github.com/coder/coder/coderd/database/dbtypes.NodePublic + - column: workspace_agents.wireguard_disco_public_key + go_type: github.com/coder/coder/coderd/database/dbtypes.DiscoPublic rename: api_key: APIKey diff --git a/coderd/provisionerdaemons.go b/coderd/provisionerdaemons.go index c9dfb92986848..7a013e15d755d 100644 --- a/coderd/provisionerdaemons.go +++ b/coderd/provisionerdaemons.go @@ -19,11 +19,11 @@ import ( protobuf "google.golang.org/protobuf/proto" "storj.io/drpc/drpcmux" "storj.io/drpc/drpcserver" - "tailscale.com/types/key" "cdr.dev/slog" "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/database/dbtypes" "github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/coderd/parameter" "github.com/coder/coder/coderd/rbac" @@ -761,8 +761,8 @@ func insertWorkspaceResource(ctx context.Context, db database.Store, jobID uuid. Valid: prAgent.StartupScript != "", }, WireguardNodeIPv6: peerwg.UUIDToInet(agentID), - WireguardNodePublicKey: key.NodePublic{}.String(), - WireguardDiscoPublicKey: key.DiscoPublic{}.String(), + WireguardNodePublicKey: dbtypes.NodePublic{}, + WireguardDiscoPublicKey: dbtypes.DiscoPublic{}, }) if err != nil { return xerrors.Errorf("insert agent: %w", err) diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index af23d11dc7922..a029d113768f7 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -22,6 +22,7 @@ import ( "cdr.dev/slog" "github.com/coder/coder/agent" "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/database/dbtypes" "github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/coderd/httpmw" "github.com/coder/coder/coderd/rbac" @@ -488,8 +489,8 @@ func (api *API) postWorkspaceAgentKeys(rw http.ResponseWriter, r *http.Request) err := api.Database.UpdateWorkspaceAgentKeysByID(ctx, database.UpdateWorkspaceAgentKeysByIDParams{ ID: workspaceAgent.ID, - WireguardNodePublicKey: keys.Public.String(), - WireguardDiscoPublicKey: keys.Disco.String(), + WireguardNodePublicKey: dbtypes.NodePublic(keys.Public), + WireguardDiscoPublicKey: dbtypes.DiscoPublic(keys.Disco), }) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ @@ -711,15 +712,8 @@ func convertWorkspaceAgent(dbAgent database.WorkspaceAgent, apps []codersdk.Work Directory: dbAgent.Directory, Apps: apps, IPv6: inetToNetaddr(dbAgent.WireguardNodeIPv6), - } - - err := workspaceAgent.WireguardPublicKey.UnmarshalText([]byte(dbAgent.WireguardNodePublicKey)) - if err != nil { - return codersdk.WorkspaceAgent{}, xerrors.Errorf("unmarshal wireguard node public key %q: %w", dbAgent.WireguardNodePublicKey, err) - } - err = workspaceAgent.DiscoPublicKey.UnmarshalText([]byte(dbAgent.WireguardDiscoPublicKey)) - if err != nil { - return codersdk.WorkspaceAgent{}, xerrors.Errorf("unmarshal disco public key %q: %w", dbAgent.WireguardDiscoPublicKey, err) + WireguardPublicKey: key.NodePublic(dbAgent.WireguardNodePublicKey), + DiscoPublicKey: key.DiscoPublic(dbAgent.WireguardDiscoPublicKey), } if dbAgent.FirstConnectedAt.Valid {