diff --git a/coderd/database/awsiamrds/awsiamrds.go b/coderd/database/awsiamrds/awsiamrds.go index 1d4ded8ac2ea2..a8cd6ab495b55 100644 --- a/coderd/database/awsiamrds/awsiamrds.go +++ b/coderd/database/awsiamrds/awsiamrds.go @@ -10,7 +10,10 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/feature/rds/auth" + "github.com/lib/pq" "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd/database" ) type awsIamRdsDriver struct { @@ -18,7 +21,10 @@ type awsIamRdsDriver struct { cfg aws.Config } -var _ driver.Driver = &awsIamRdsDriver{} +var ( + _ driver.Driver = &awsIamRdsDriver{} + _ database.ConnectorCreator = &awsIamRdsDriver{} +) // Register initializes and registers our aws iam rds wrapped database driver. func Register(ctx context.Context, parentName string) (string, error) { @@ -65,6 +71,16 @@ func (d *awsIamRdsDriver) Open(name string) (driver.Conn, error) { return conn, nil } +// Connector returns a driver.Connector that fetches a new authentication token for each connection. +func (d *awsIamRdsDriver) Connector(name string) (driver.Connector, error) { + connector := &connector{ + url: name, + cfg: d.cfg, + } + + return connector, nil +} + func getAuthenticatedURL(cfg aws.Config, dbURL string) (string, error) { nURL, err := url.Parse(dbURL) if err != nil { @@ -82,3 +98,37 @@ func getAuthenticatedURL(cfg aws.Config, dbURL string) (string, error) { return nURL.String(), nil } + +type connector struct { + url string + cfg aws.Config + dialer pq.Dialer +} + +var _ database.DialerConnector = &connector{} + +func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { + nURL, err := getAuthenticatedURL(c.cfg, c.url) + if err != nil { + return nil, xerrors.Errorf("assigning authentication token to url: %w", err) + } + + nc, err := pq.NewConnector(nURL) + if err != nil { + return nil, xerrors.Errorf("creating new connector: %w", err) + } + + if c.dialer != nil { + nc.Dialer(c.dialer) + } + + return nc.Connect(ctx) +} + +func (*connector) Driver() driver.Driver { + return &pq.Driver{} +} + +func (c *connector) Dialer(dialer pq.Dialer) { + c.dialer = dialer +} diff --git a/coderd/database/awsiamrds/awsiamrds_test.go b/coderd/database/awsiamrds/awsiamrds_test.go index d4a1ce193016e..36f4ea4d8f6b2 100644 --- a/coderd/database/awsiamrds/awsiamrds_test.go +++ b/coderd/database/awsiamrds/awsiamrds_test.go @@ -7,10 +7,11 @@ import ( "github.com/stretchr/testify/require" + "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" - "github.com/coder/coder/v2/cli" - awsrdsiam "github.com/coder/coder/v2/coderd/database/awsiamrds" + "github.com/coder/coder/v2/coderd/database/awsiamrds" + "github.com/coder/coder/v2/coderd/database/pubsub" "github.com/coder/coder/v2/testutil" ) @@ -22,13 +23,15 @@ func TestDriver(t *testing.T) { // export DBAWSIAMRDS_TEST_URL="postgres://user@host:5432/dbname"; url := os.Getenv("DBAWSIAMRDS_TEST_URL") if url == "" { + t.Log("skipping test; no DBAWSIAMRDS_TEST_URL set") t.Skip() } + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) defer cancel() - sqlDriver, err := awsrdsiam.Register(ctx, "postgres") + sqlDriver, err := awsiamrds.Register(ctx, "postgres") require.NoError(t, err) db, err := cli.ConnectToPostgres(ctx, slogtest.Make(t, nil), sqlDriver, url) @@ -47,4 +50,23 @@ func TestDriver(t *testing.T) { var one int require.NoError(t, i.Scan(&one)) require.Equal(t, 1, one) + + ps, err := pubsub.New(ctx, logger, db, url) + require.NoError(t, err) + + gotChan := make(chan struct{}) + subCancel, err := ps.Subscribe("test", func(_ context.Context, _ []byte) { + close(gotChan) + }) + defer subCancel() + require.NoError(t, err) + + err = ps.Publish("test", []byte("hello")) + require.NoError(t, err) + + select { + case <-gotChan: + case <-ctx.Done(): + require.Fail(t, "timed out waiting for message") + } } diff --git a/coderd/database/connector.go b/coderd/database/connector.go new file mode 100644 index 0000000000000..5ade33ed18233 --- /dev/null +++ b/coderd/database/connector.go @@ -0,0 +1,19 @@ +package database + +import ( + "database/sql/driver" + + "github.com/lib/pq" +) + +// ConnectorCreator is a driver.Driver that can create a driver.Connector. +type ConnectorCreator interface { + driver.Driver + Connector(name string) (driver.Connector, error) +} + +// DialerConnector is a driver.Connector that can set a pq.Dialer. +type DialerConnector interface { + driver.Connector + Dialer(dialer pq.Dialer) +} diff --git a/coderd/database/dbtestutil/driver.go b/coderd/database/dbtestutil/driver.go new file mode 100644 index 0000000000000..cb2e05af78617 --- /dev/null +++ b/coderd/database/dbtestutil/driver.go @@ -0,0 +1,79 @@ +package dbtestutil + +import ( + "context" + "database/sql/driver" + + "github.com/lib/pq" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd/database" +) + +var _ database.DialerConnector = &Connector{} + +type Connector struct { + name string + driver *Driver + dialer pq.Dialer +} + +func (c *Connector) Connect(_ context.Context) (driver.Conn, error) { + if c.dialer != nil { + conn, err := pq.DialOpen(c.dialer, c.name) + if err != nil { + return nil, xerrors.Errorf("failed to dial open connection: %w", err) + } + + c.driver.Connections <- conn + + return conn, nil + } + + conn, err := pq.Driver{}.Open(c.name) + if err != nil { + return nil, xerrors.Errorf("failed to open connection: %w", err) + } + + c.driver.Connections <- conn + + return conn, nil +} + +func (c *Connector) Driver() driver.Driver { + return c.driver +} + +func (c *Connector) Dialer(dialer pq.Dialer) { + c.dialer = dialer +} + +type Driver struct { + Connections chan driver.Conn +} + +func NewDriver() *Driver { + return &Driver{ + Connections: make(chan driver.Conn, 1), + } +} + +func (d *Driver) Connector(name string) (driver.Connector, error) { + return &Connector{ + name: name, + driver: d, + }, nil +} + +func (d *Driver) Open(name string) (driver.Conn, error) { + c, err := d.Connector(name) + if err != nil { + return nil, err + } + + return c.Connect(context.Background()) +} + +func (d *Driver) Close() { + close(d.Connections) +} diff --git a/coderd/database/pubsub/pubsub.go b/coderd/database/pubsub/pubsub.go index c391a7c3eaf66..79be4bd602032 100644 --- a/coderd/database/pubsub/pubsub.go +++ b/coderd/database/pubsub/pubsub.go @@ -3,6 +3,7 @@ package pubsub import ( "context" "database/sql" + "database/sql/driver" "errors" "io" "net" @@ -15,6 +16,8 @@ import ( "github.com/prometheus/client_golang/prometheus" "golang.org/x/xerrors" + "github.com/coder/coder/v2/coderd/database" + "cdr.dev/slog" ) @@ -432,9 +435,35 @@ func (p *PGPubsub) startListener(ctx context.Context, connectURL string) error { // pq.defaultDialer uses a zero net.Dialer as well. d: net.Dialer{}, } + connector driver.Connector + err error ) + + // Create a custom connector if the database driver supports it. + connectorCreator, ok := p.db.Driver().(database.ConnectorCreator) + if ok { + connector, err = connectorCreator.Connector(connectURL) + if err != nil { + return xerrors.Errorf("create custom connector: %w", err) + } + } else { + // use the default pq connector otherwise + connector, err = pq.NewConnector(connectURL) + if err != nil { + return xerrors.Errorf("create pq connector: %w", err) + } + } + + // Set the dialer if the connector supports it. + dc, ok := connector.(database.DialerConnector) + if !ok { + p.logger.Critical(ctx, "connector does not support setting log dialer, database connection debug logs will be missing") + } else { + dc.Dialer(dialer) + } + p.pgListener = pqListenerShim{ - Listener: pq.NewDialListener(dialer, connectURL, time.Second, time.Minute, func(t pq.ListenerEventType, err error) { + Listener: pq.NewConnectorListener(connector, connectURL, time.Second, time.Minute, func(t pq.ListenerEventType, err error) { switch t { case pq.ListenerEventConnected: p.logger.Info(ctx, "pubsub connected to postgres") @@ -583,8 +612,8 @@ func (p *PGPubsub) Collect(metrics chan<- prometheus.Metric) { } // New creates a new Pubsub implementation using a PostgreSQL connection. -func New(startCtx context.Context, logger slog.Logger, database *sql.DB, connectURL string) (*PGPubsub, error) { - p := newWithoutListener(logger, database) +func New(startCtx context.Context, logger slog.Logger, db *sql.DB, connectURL string) (*PGPubsub, error) { + p := newWithoutListener(logger, db) if err := p.startListener(startCtx, connectURL); err != nil { return nil, err } @@ -594,11 +623,11 @@ func New(startCtx context.Context, logger slog.Logger, database *sql.DB, connect } // newWithoutListener creates a new PGPubsub without creating the pqListener. -func newWithoutListener(logger slog.Logger, database *sql.DB) *PGPubsub { +func newWithoutListener(logger slog.Logger, db *sql.DB) *PGPubsub { return &PGPubsub{ logger: logger, listenDone: make(chan struct{}), - db: database, + db: db, queues: make(map[string]map[uuid.UUID]*msgQueue), latencyMeasurer: NewLatencyMeasurer(logger.Named("latency-measurer")), diff --git a/coderd/database/pubsub/pubsub_test.go b/coderd/database/pubsub/pubsub_test.go index d36298bb3221d..6059b0cecbd97 100644 --- a/coderd/database/pubsub/pubsub_test.go +++ b/coderd/database/pubsub/pubsub_test.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "testing" + "time" "github.com/prometheus/client_golang/prometheus" "github.com/stretchr/testify/assert" @@ -51,7 +52,7 @@ func TestPGPubsub_Metrics(t *testing.T) { event := "test" data := "testing" messageChannel := make(chan []byte) - unsub0, err := uut.Subscribe(event, func(ctx context.Context, message []byte) { + unsub0, err := uut.Subscribe(event, func(_ context.Context, message []byte) { messageChannel <- message }) require.NoError(t, err) @@ -86,7 +87,7 @@ func TestPGPubsub_Metrics(t *testing.T) { for i := range colossalData { colossalData[i] = 'q' } - unsub1, err := uut.Subscribe(event, func(ctx context.Context, message []byte) { + unsub1, err := uut.Subscribe(event, func(_ context.Context, message []byte) { messageChannel <- message }) require.NoError(t, err) @@ -119,3 +120,74 @@ func TestPGPubsub_Metrics(t *testing.T) { !testutil.PromCounterGathered(t, metrics, "coder_pubsub_latency_measure_errs_total") }, testutil.WaitShort, testutil.IntervalFast) } + +func TestPGPubsubDriver(t *testing.T) { + t.Parallel() + if !dbtestutil.WillUsePostgres() { + t.Skip("test only with postgres") + } + + ctx := testutil.Context(t, testutil.WaitLong) + logger := slogtest.Make(t, &slogtest.Options{ + IgnoreErrors: true, + }).Leveled(slog.LevelDebug) + + connectionURL, closePg, err := dbtestutil.Open() + require.NoError(t, err) + defer closePg() + + // use a separate subber and pubber so we can keep track of listener connections + db, err := sql.Open("postgres", connectionURL) + require.NoError(t, err) + pubber, err := pubsub.New(ctx, logger, db, connectionURL) + require.NoError(t, err) + defer pubber.Close() + + // use a connector that sends us the connections for the subber + subDriver := dbtestutil.NewDriver() + defer subDriver.Close() + tconn, err := subDriver.Connector(connectionURL) + require.NoError(t, err) + tcdb := sql.OpenDB(tconn) + subber, err := pubsub.New(ctx, logger, tcdb, connectionURL) + require.NoError(t, err) + defer subber.Close() + + // test that we can publish and subscribe + gotChan := make(chan struct{}, 1) + defer close(gotChan) + subCancel, err := subber.Subscribe("test", func(_ context.Context, _ []byte) { + gotChan <- struct{}{} + }) + require.NoError(t, err) + defer subCancel() + + // send a message + err = pubber.Publish("test", []byte("hello")) + require.NoError(t, err) + + // wait for the message + _ = testutil.RequireRecvCtx(ctx, t, gotChan) + + // read out first connection + firstConn := testutil.RequireRecvCtx(ctx, t, subDriver.Connections) + + // drop the underlying connection being used by the pubsub + // the pq.Listener should reconnect and repopulate it's listeners + // so old subscriptions should still work + err = firstConn.Close() + require.NoError(t, err) + + // wait for the reconnect + _ = testutil.RequireRecvCtx(ctx, t, subDriver.Connections) + // we need to sleep because the raw connection notification + // is sent before the pq.Listener can reestablish it's listeners + time.Sleep(1 * time.Second) + + // ensure our old subscription still fires + err = pubber.Publish("test", []byte("hello-again")) + require.NoError(t, err) + + // wait for the message on the old subscription + _ = testutil.RequireRecvCtx(ctx, t, gotChan) +} diff --git a/flake.nix b/flake.nix index 3a002707196db..15ce314c7b427 100644 --- a/flake.nix +++ b/flake.nix @@ -117,7 +117,7 @@ name = "coder-${osArch}"; # Updated with ./scripts/update-flake.sh`. # This should be updated whenever go.mod changes! - vendorHash = "sha256-I/FcLT6N7Nz21QptkvCcs/SpMJFH0B5xVzIZNrEqVGo="; + vendorHash = "sha256-fQsVoD/aRjVXmvQ/Pg4O9tpJCPlf3eC2uo0z0TU7AX8="; proxyVendor = true; src = ./.; nativeBuildInputs = with pkgs; [ getopt openssl zstd ]; diff --git a/go.mod b/go.mod index 38a110e79ea56..57b0ee11e919e 100644 --- a/go.mod +++ b/go.mod @@ -62,6 +62,11 @@ replace github.com/imulab/go-scim/pkg/v2 => github.com/coder/go-scim/pkg/v2 v2.0 // Fixes https://github.com/coder/coder/issues/6685 replace github.com/pkg/sftp => github.com/mafredri/sftp v1.13.6-0.20231212144145-8218e927edb0 +// Adds support for a new Listener from a driver.Connector +// This lets us use rotating authentication tokens for passwords in connection strings +// which we use in the awsiamrds package. +replace github.com/lib/pq => github.com/coder/pq v1.10.5-0.20240813183442-0c420cb5a048 + require ( cdr.dev/slog v1.6.2-0.20240126064726-20367d4aede6 cloud.google.com/go/compute/metadata v0.5.0 diff --git a/go.sum b/go.sum index a68b8c20730dd..f017dc1b7db3c 100644 --- a/go.sum +++ b/go.sum @@ -216,6 +216,8 @@ github.com/coder/go-httpstat v0.0.0-20230801153223-321c88088322 h1:m0lPZjlQ7vdVp github.com/coder/go-httpstat v0.0.0-20230801153223-321c88088322/go.mod h1:rOLFDDVKVFiDqZFXoteXc97YXx7kFi9kYqR+2ETPkLQ= github.com/coder/go-scim/pkg/v2 v2.0.0-20230221055123-1d63c1222136 h1:0RgB61LcNs24WOxc3PBvygSNTQurm0PYPujJjLLOzs0= github.com/coder/go-scim/pkg/v2 v2.0.0-20230221055123-1d63c1222136/go.mod h1:VkD1P761nykiq75dz+4iFqIQIZka189tx1BQLOp0Skc= +github.com/coder/pq v1.10.5-0.20240813183442-0c420cb5a048 h1:3jzYUlGH7ZELIH4XggXhnTnP05FCYiAFeQpoN+gNR5I= +github.com/coder/pq v1.10.5-0.20240813183442-0c420cb5a048/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/coder/pretty v0.0.0-20230908205945-e89ba86370e0 h1:3A0ES21Ke+FxEM8CXx9n47SZOKOpgSE1bbJzlE4qPVs= github.com/coder/pretty v0.0.0-20230908205945-e89ba86370e0/go.mod h1:5UuS2Ts+nTToAMeOjNlnHFkPahrtDkmpydBen/3wgZc= github.com/coder/quartz v0.1.0 h1:cLL+0g5l7xTf6ordRnUMMiZtRE8Sq5LxpghS63vEXrQ= @@ -672,8 +674,6 @@ github.com/ledongthuc/pdf v0.0.0-20220302134840-0c2507a12d80/go.mod h1:imJHygn/1 github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= -github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= -github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/mafredri/sftp v1.13.6-0.20231212144145-8218e927edb0 h1:lG2o/EWMEOlV/RfQrf3zYfQStjnUj0Mg2gmbcBcoxFI=