Skip to content

Commit 26e85b0

Browse files
authored
fix: use typed wireguard public keys in database structs (#2639)
1 parent 1157303 commit 26e85b0

File tree

6 files changed

+95
-25
lines changed

6 files changed

+95
-25
lines changed

coderd/database/dbtypes/dbtypes.go

+74
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
package dbtypes
2+
3+
import (
4+
"database/sql/driver"
5+
6+
"golang.org/x/xerrors"
7+
"tailscale.com/types/key"
8+
)
9+
10+
// NodePublic is a wrapper around a key.NodePublic which represents the
11+
// Wireguard public key for an agent..
12+
type NodePublic key.NodePublic
13+
14+
func (n NodePublic) String() string {
15+
return key.NodePublic(n).String()
16+
}
17+
18+
// This is necessary so NodePublic can be serialized in JSON loggers.
19+
func (n NodePublic) MarshalJSON() ([]byte, error) {
20+
j, err := key.NodePublic(n).MarshalText()
21+
// surround in quotes to make it a JSON string
22+
j = append([]byte{'"'}, append(j, '"')...)
23+
return j, err
24+
}
25+
26+
// Value is so NodePublic can be inserted into the database.
27+
func (n NodePublic) Value() (driver.Value, error) {
28+
return key.NodePublic(n).MarshalText()
29+
}
30+
31+
// Scan is so NodePublic can be read from the database.
32+
func (n *NodePublic) Scan(value interface{}) error {
33+
switch v := value.(type) {
34+
case []byte:
35+
return (*key.NodePublic)(n).UnmarshalText(v)
36+
case string:
37+
return (*key.NodePublic)(n).UnmarshalText([]byte(v))
38+
default:
39+
return xerrors.Errorf("unexpected type: %T", v)
40+
}
41+
}
42+
43+
// NodePublic is a wrapper around a key.NodePublic which represents the
44+
// Tailscale disco key for an agent.
45+
type DiscoPublic key.DiscoPublic
46+
47+
func (n DiscoPublic) String() string {
48+
return key.DiscoPublic(n).String()
49+
}
50+
51+
// This is necessary so DiscoPublic can be serialized in JSON loggers.
52+
func (n DiscoPublic) MarshalJSON() ([]byte, error) {
53+
j, err := key.DiscoPublic(n).MarshalText()
54+
// surround in quotes to make it a JSON string
55+
j = append([]byte{'"'}, append(j, '"')...)
56+
return j, err
57+
}
58+
59+
// Value is so DiscoPublic can be inserted into the database.
60+
func (n DiscoPublic) Value() (driver.Value, error) {
61+
return key.DiscoPublic(n).MarshalText()
62+
}
63+
64+
// Scan is so DiscoPublic can be read from the database.
65+
func (n *DiscoPublic) Scan(value interface{}) error {
66+
switch v := value.(type) {
67+
case []byte:
68+
return (*key.DiscoPublic)(n).UnmarshalText(v)
69+
case string:
70+
return (*key.DiscoPublic)(n).UnmarshalText([]byte(v))
71+
default:
72+
return xerrors.Errorf("unexpected type: %T", v)
73+
}
74+
}

coderd/database/models.go

+3-2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/queries.sql.go

+6-5
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/sqlc.yaml

+4-4
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@ packages:
1717
output_db_file_name: db_tmp.go
1818

1919
overrides:
20-
- column: workspaces.wireguard_public_key
21-
go_type: tailscale.com/types/key.MachinePublic
22-
- column: workspaces.disco_public_key
23-
go_type: tailscale.com/types/key.DiscoPublic
20+
- column: workspace_agents.wireguard_node_public_key
21+
go_type: github.com/coder/coder/coderd/database/dbtypes.NodePublic
22+
- column: workspace_agents.wireguard_disco_public_key
23+
go_type: github.com/coder/coder/coderd/database/dbtypes.DiscoPublic
2424

2525
rename:
2626
api_key: APIKey

coderd/provisionerdaemons.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@ import (
1919
protobuf "google.golang.org/protobuf/proto"
2020
"storj.io/drpc/drpcmux"
2121
"storj.io/drpc/drpcserver"
22-
"tailscale.com/types/key"
2322

2423
"cdr.dev/slog"
2524

2625
"github.com/coder/coder/coderd/database"
26+
"github.com/coder/coder/coderd/database/dbtypes"
2727
"github.com/coder/coder/coderd/httpapi"
2828
"github.com/coder/coder/coderd/parameter"
2929
"github.com/coder/coder/coderd/rbac"
@@ -761,8 +761,8 @@ func insertWorkspaceResource(ctx context.Context, db database.Store, jobID uuid.
761761
Valid: prAgent.StartupScript != "",
762762
},
763763
WireguardNodeIPv6: peerwg.UUIDToInet(agentID),
764-
WireguardNodePublicKey: key.NodePublic{}.String(),
765-
WireguardDiscoPublicKey: key.DiscoPublic{}.String(),
764+
WireguardNodePublicKey: dbtypes.NodePublic{},
765+
WireguardDiscoPublicKey: dbtypes.DiscoPublic{},
766766
})
767767
if err != nil {
768768
return xerrors.Errorf("insert agent: %w", err)

coderd/workspaceagents.go

+5-11
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
"cdr.dev/slog"
2323
"github.com/coder/coder/agent"
2424
"github.com/coder/coder/coderd/database"
25+
"github.com/coder/coder/coderd/database/dbtypes"
2526
"github.com/coder/coder/coderd/httpapi"
2627
"github.com/coder/coder/coderd/httpmw"
2728
"github.com/coder/coder/coderd/rbac"
@@ -488,8 +489,8 @@ func (api *API) postWorkspaceAgentKeys(rw http.ResponseWriter, r *http.Request)
488489

489490
err := api.Database.UpdateWorkspaceAgentKeysByID(ctx, database.UpdateWorkspaceAgentKeysByIDParams{
490491
ID: workspaceAgent.ID,
491-
WireguardNodePublicKey: keys.Public.String(),
492-
WireguardDiscoPublicKey: keys.Disco.String(),
492+
WireguardNodePublicKey: dbtypes.NodePublic(keys.Public),
493+
WireguardDiscoPublicKey: dbtypes.DiscoPublic(keys.Disco),
493494
})
494495
if err != nil {
495496
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
@@ -711,15 +712,8 @@ func convertWorkspaceAgent(dbAgent database.WorkspaceAgent, apps []codersdk.Work
711712
Directory: dbAgent.Directory,
712713
Apps: apps,
713714
IPv6: inetToNetaddr(dbAgent.WireguardNodeIPv6),
714-
}
715-
716-
err := workspaceAgent.WireguardPublicKey.UnmarshalText([]byte(dbAgent.WireguardNodePublicKey))
717-
if err != nil {
718-
return codersdk.WorkspaceAgent{}, xerrors.Errorf("unmarshal wireguard node public key %q: %w", dbAgent.WireguardNodePublicKey, err)
719-
}
720-
err = workspaceAgent.DiscoPublicKey.UnmarshalText([]byte(dbAgent.WireguardDiscoPublicKey))
721-
if err != nil {
722-
return codersdk.WorkspaceAgent{}, xerrors.Errorf("unmarshal disco public key %q: %w", dbAgent.WireguardDiscoPublicKey, err)
715+
WireguardPublicKey: key.NodePublic(dbAgent.WireguardNodePublicKey),
716+
DiscoPublicKey: key.DiscoPublic(dbAgent.WireguardDiscoPublicKey),
723717
}
724718

725719
if dbAgent.FirstConnectedAt.Valid {

0 commit comments

Comments
 (0)