diff --git a/cli/server.go b/cli/server.go index a2574593b18b6..c0d7d6fcee13e 100644 --- a/cli/server.go +++ b/cli/server.go @@ -95,6 +95,7 @@ import ( "github.com/coder/coder/v2/coderd/tracing" "github.com/coder/coder/v2/coderd/unhanger" "github.com/coder/coder/v2/coderd/updatecheck" + "github.com/coder/coder/v2/coderd/util/ptr" "github.com/coder/coder/v2/coderd/util/slice" stringutil "github.com/coder/coder/v2/coderd/util/strings" "github.com/coder/coder/v2/coderd/workspaceapps/appurl" @@ -779,7 +780,10 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd. // Manage push notifications. experiments := coderd.ReadExperiments(options.Logger, options.DeploymentValues.Experiments.Value()) if experiments.Enabled(codersdk.ExperimentWebPush) { - webpusher, err := webpush.New(ctx, &options.Logger, options.Database) + if !strings.HasPrefix(options.AccessURL.String(), "https://") { + options.Logger.Warn(ctx, "access URL is not HTTPS, so web push notifications may not work on some browsers", slog.F("access_url", options.AccessURL.String())) + } + webpusher, err := webpush.New(ctx, ptr.Ref(options.Logger.Named("webpush")), options.Database, options.AccessURL.String()) if err != nil { options.Logger.Error(ctx, "failed to create web push dispatcher", slog.Error(err)) options.Logger.Warn(ctx, "web push notifications will not work until the VAPID keys are regenerated") diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index ca0bc25e29647..b9097863a5f67 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -284,7 +284,7 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can if options.WebpushDispatcher == nil { // nolint:gocritic // Gets/sets VAPID keys. - pushNotifier, err := webpush.New(dbauthz.AsNotifier(context.Background()), options.Logger, options.Database) + pushNotifier, err := webpush.New(dbauthz.AsNotifier(context.Background()), options.Logger, options.Database, "http://example.com") if err != nil { panic(xerrors.Errorf("failed to create web push notifier: %w", err)) } diff --git a/coderd/webpush/webpush.go b/coderd/webpush/webpush.go index a6c0790d2dce1..eb35685402c21 100644 --- a/coderd/webpush/webpush.go +++ b/coderd/webpush/webpush.go @@ -41,13 +41,14 @@ type Dispatcher interface { // for updates inside of a workspace, which we want to be immediate. // // See: https://github.com/coder/internal/issues/528 -func New(ctx context.Context, log *slog.Logger, db database.Store) (Dispatcher, error) { +func New(ctx context.Context, log *slog.Logger, db database.Store, vapidSub string) (Dispatcher, error) { keys, err := db.GetWebpushVAPIDKeys(ctx) if err != nil { if !errors.Is(err, sql.ErrNoRows) { return nil, xerrors.Errorf("get notification vapid keys: %w", err) } } + if keys.VapidPublicKey == "" || keys.VapidPrivateKey == "" { // Generate new VAPID keys. This also deletes all existing push // subscriptions as part of the transaction, as they are no longer @@ -62,6 +63,7 @@ func New(ctx context.Context, log *slog.Logger, db database.Store) (Dispatcher, } return &Webpusher{ + vapidSub: vapidSub, store: db, log: log, VAPIDPublicKey: keys.VapidPublicKey, @@ -72,7 +74,13 @@ func New(ctx context.Context, log *slog.Logger, db database.Store) (Dispatcher, type Webpusher struct { store database.Store log *slog.Logger + // VAPID allows us to identify the sender of the message. + // This must be a https:// URL or an email address. + // Some push services (such as Apple's) require this to be set. + vapidSub string + // public and private keys for VAPID. These are used to sign and encrypt + // the message payload. VAPIDPublicKey string VAPIDPrivateKey string } @@ -148,10 +156,12 @@ func (n *Webpusher) webpushSend(ctx context.Context, msg []byte, endpoint string Endpoint: endpoint, Keys: keys, }, &webpush.Options{ + Subscriber: n.vapidSub, VAPIDPublicKey: n.VAPIDPublicKey, VAPIDPrivateKey: n.VAPIDPrivateKey, }) if err != nil { + n.log.Error(ctx, "failed to send webpush notification", slog.Error(err), slog.F("endpoint", endpoint)) return -1, nil, xerrors.Errorf("send webpush notification: %w", err) } defer resp.Body.Close() diff --git a/coderd/webpush/webpush_test.go b/coderd/webpush/webpush_test.go index 2566e0edb348d..0c01c55fca86b 100644 --- a/coderd/webpush/webpush_test.go +++ b/coderd/webpush/webpush_test.go @@ -2,6 +2,8 @@ package webpush_test import ( "context" + "encoding/json" + "io" "net/http" "net/http/httptest" "testing" @@ -32,7 +34,9 @@ func TestPush(t *testing.T) { t.Run("SuccessfulDelivery", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitShort) - manager, store, serverURL := setupPushTest(ctx, t, func(w http.ResponseWriter, _ *http.Request) { + msg := randomWebpushMessage(t) + manager, store, serverURL := setupPushTest(ctx, t, func(w http.ResponseWriter, r *http.Request) { + assertWebpushPayload(t, r) w.WriteHeader(http.StatusOK) }) user := dbgen.User(t, store, database.User{}) @@ -45,16 +49,7 @@ func TestPush(t *testing.T) { }) require.NoError(t, err) - notification := codersdk.WebpushMessage{ - Title: "Test Title", - Body: "Test Body", - Actions: []codersdk.WebpushMessageAction{ - {Label: "View", URL: "https://coder.com/view"}, - }, - Icon: "workspace", - } - - err = manager.Dispatch(ctx, user.ID, notification) + err = manager.Dispatch(ctx, user.ID, msg) require.NoError(t, err) subscriptions, err := store.GetWebpushSubscriptionsByUserID(ctx, user.ID) @@ -66,7 +61,8 @@ func TestPush(t *testing.T) { t.Run("ExpiredSubscription", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitShort) - manager, store, serverURL := setupPushTest(ctx, t, func(w http.ResponseWriter, _ *http.Request) { + manager, store, serverURL := setupPushTest(ctx, t, func(w http.ResponseWriter, r *http.Request) { + assertWebpushPayload(t, r) w.WriteHeader(http.StatusGone) }) user := dbgen.User(t, store, database.User{}) @@ -79,12 +75,8 @@ func TestPush(t *testing.T) { }) require.NoError(t, err) - notification := codersdk.WebpushMessage{ - Title: "Test Title", - Body: "Test Body", - } - - err = manager.Dispatch(ctx, user.ID, notification) + msg := randomWebpushMessage(t) + err = manager.Dispatch(ctx, user.ID, msg) require.NoError(t, err) subscriptions, err := store.GetWebpushSubscriptionsByUserID(ctx, user.ID) @@ -95,7 +87,8 @@ func TestPush(t *testing.T) { t.Run("FailedDelivery", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitShort) - manager, store, serverURL := setupPushTest(ctx, t, func(w http.ResponseWriter, _ *http.Request) { + manager, store, serverURL := setupPushTest(ctx, t, func(w http.ResponseWriter, r *http.Request) { + assertWebpushPayload(t, r) w.WriteHeader(http.StatusBadRequest) w.Write([]byte("Invalid request")) }) @@ -110,12 +103,8 @@ func TestPush(t *testing.T) { }) require.NoError(t, err) - notification := codersdk.WebpushMessage{ - Title: "Test Title", - Body: "Test Body", - } - - err = manager.Dispatch(ctx, user.ID, notification) + msg := randomWebpushMessage(t) + err = manager.Dispatch(ctx, user.ID, msg) require.Error(t, err) assert.Contains(t, err.Error(), "Invalid request") @@ -130,13 +119,15 @@ func TestPush(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) var okEndpointCalled bool var goneEndpointCalled bool - manager, store, serverOKURL := setupPushTest(ctx, t, func(w http.ResponseWriter, _ *http.Request) { + manager, store, serverOKURL := setupPushTest(ctx, t, func(w http.ResponseWriter, r *http.Request) { okEndpointCalled = true + assertWebpushPayload(t, r) w.WriteHeader(http.StatusOK) }) - serverGone := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + serverGone := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { goneEndpointCalled = true + assertWebpushPayload(t, r) w.WriteHeader(http.StatusGone) })) defer serverGone.Close() @@ -163,15 +154,8 @@ func TestPush(t *testing.T) { }) require.NoError(t, err) - notification := codersdk.WebpushMessage{ - Title: "Test Title", - Body: "Test Body", - Actions: []codersdk.WebpushMessageAction{ - {Label: "View", URL: "https://coder.com/view"}, - }, - } - - err = manager.Dispatch(ctx, user.ID, notification) + msg := randomWebpushMessage(t) + err = manager.Dispatch(ctx, user.ID, msg) require.NoError(t, err) assert.True(t, okEndpointCalled, "The valid endpoint should be called") assert.True(t, goneEndpointCalled, "The expired endpoint should be called") @@ -189,8 +173,9 @@ func TestPush(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) var requestReceived bool - manager, store, serverURL := setupPushTest(ctx, t, func(w http.ResponseWriter, _ *http.Request) { + manager, store, serverURL := setupPushTest(ctx, t, func(w http.ResponseWriter, r *http.Request) { requestReceived = true + assertWebpushPayload(t, r) w.WriteHeader(http.StatusOK) }) @@ -205,17 +190,8 @@ func TestPush(t *testing.T) { }) require.NoError(t, err, "Failed to insert push subscription") - notification := codersdk.WebpushMessage{ - Title: "Test Notification", - Body: "This is a test notification body", - Actions: []codersdk.WebpushMessageAction{ - {Label: "View Workspace", URL: "https://coder.com/workspace/123"}, - {Label: "Cancel", URL: "https://coder.com/cancel"}, - }, - Icon: "workspace-icon", - } - - err = manager.Dispatch(ctx, user.ID, notification) + msg := randomWebpushMessage(t) + err = manager.Dispatch(ctx, user.ID, msg) require.NoError(t, err, "The push notification should be dispatched successfully") require.True(t, requestReceived, "The push notification request should have been received by the server") }) @@ -242,15 +218,42 @@ func TestPush(t *testing.T) { }) } +func randomWebpushMessage(t testing.TB) codersdk.WebpushMessage { + t.Helper() + return codersdk.WebpushMessage{ + Title: testutil.GetRandomName(t), + Body: testutil.GetRandomName(t), + + Actions: []codersdk.WebpushMessageAction{ + {Label: "A", URL: "https://example.com/a"}, + {Label: "B", URL: "https://example.com/b"}, + }, + Icon: "https://example.com/icon.png", + } +} + +func assertWebpushPayload(t testing.TB, r *http.Request) { + t.Helper() + assert.Equal(t, http.MethodPost, r.Method) + assert.Equal(t, "application/octet-stream", r.Header.Get("Content-Type")) + assert.Equal(t, r.Header.Get("content-encoding"), "aes128gcm") + assert.Contains(t, r.Header.Get("Authorization"), "vapid") + + // Attempting to decode the request body as JSON should fail as it is + // encrypted. + assert.Error(t, json.NewDecoder(r.Body).Decode(io.Discard)) +} + // setupPushTest creates a common test setup for webpush notification tests func setupPushTest(ctx context.Context, t *testing.T, handlerFunc func(w http.ResponseWriter, r *http.Request)) (webpush.Dispatcher, database.Store, string) { + t.Helper() logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) db, _ := dbtestutil.NewDB(t) server := httptest.NewServer(http.HandlerFunc(handlerFunc)) t.Cleanup(server.Close) - manager, err := webpush.New(ctx, &logger, db) + manager, err := webpush.New(ctx, &logger, db, "http://example.com") require.NoError(t, err, "Failed to create webpush manager") return manager, db, server.URL diff --git a/site/src/api/api.ts b/site/src/api/api.ts index b042735357ab0..85953bbce736f 100644 --- a/site/src/api/api.ts +++ b/site/src/api/api.ts @@ -2371,6 +2371,28 @@ class ApiMethods { await this.axios.post("/api/v2/notifications/test"); }; + createWebPushSubscription = async ( + userId: string, + req: TypesGen.WebpushSubscription, + ) => { + await this.axios.post( + `/api/v2/users/${userId}/webpush/subscription`, + req, + ); + }; + + deleteWebPushSubscription = async ( + userId: string, + req: TypesGen.DeleteWebpushSubscription, + ) => { + await this.axios.delete( + `/api/v2/users/${userId}/webpush/subscription`, + { + data: req, + }, + ); + }; + requestOneTimePassword = async ( req: TypesGen.RequestOneTimePasscodeRequest, ) => { diff --git a/site/src/contexts/useWebpushNotifications.ts b/site/src/contexts/useWebpushNotifications.ts new file mode 100644 index 0000000000000..0f3949135c287 --- /dev/null +++ b/site/src/contexts/useWebpushNotifications.ts @@ -0,0 +1,110 @@ +import { API } from "api/api"; +import { buildInfo } from "api/queries/buildInfo"; +import { experiments } from "api/queries/experiments"; +import { useEmbeddedMetadata } from "hooks/useEmbeddedMetadata"; +import { useEffect, useState } from "react"; +import { useQuery } from "react-query"; + +interface WebpushNotifications { + readonly enabled: boolean; + readonly subscribed: boolean; + readonly loading: boolean; + + subscribe(): Promise; + unsubscribe(): Promise; +} + +export const useWebpushNotifications = (): WebpushNotifications => { + const { metadata } = useEmbeddedMetadata(); + const buildInfoQuery = useQuery(buildInfo(metadata["build-info"])); + const enabledExperimentsQuery = useQuery(experiments(metadata.experiments)); + + const [subscribed, setSubscribed] = useState(false); + const [loading, setLoading] = useState(true); + const [enabled, setEnabled] = useState(false); + + useEffect(() => { + // Check if the experiment is enabled. + if (enabledExperimentsQuery.data?.includes("web-push")) { + setEnabled(true); + } + + // Check if browser supports push notifications + if (!("Notification" in window) || !("serviceWorker" in navigator)) { + setSubscribed(false); + setLoading(false); + return; + } + + const checkSubscription = async () => { + try { + const registration = await navigator.serviceWorker.ready; + const subscription = await registration.pushManager.getSubscription(); + setSubscribed(!!subscription); + } catch (error) { + console.error("Error checking push subscription:", error); + setSubscribed(false); + } finally { + setLoading(false); + } + }; + + checkSubscription(); + }, [enabledExperimentsQuery.data]); + + const subscribe = async (): Promise => { + try { + setLoading(true); + const registration = await navigator.serviceWorker.ready; + const vapidPublicKey = buildInfoQuery.data?.webpush_public_key; + + const subscription = await registration.pushManager.subscribe({ + userVisibleOnly: true, + applicationServerKey: vapidPublicKey, + }); + const json = subscription.toJSON(); + if (!json.keys || !json.endpoint) { + throw new Error("No keys or endpoint found"); + } + + await API.createWebPushSubscription("me", { + endpoint: json.endpoint, + auth_key: json.keys.auth, + p256dh_key: json.keys.p256dh, + }); + + setSubscribed(true); + } catch (error) { + console.error("Subscription failed:", error); + throw error; + } finally { + setLoading(false); + } + }; + + const unsubscribe = async (): Promise => { + try { + setLoading(true); + const registration = await navigator.serviceWorker.ready; + const subscription = await registration.pushManager.getSubscription(); + + if (subscription) { + await subscription.unsubscribe(); + setSubscribed(false); + } + } catch (error) { + console.error("Unsubscription failed:", error); + throw error; + } finally { + setLoading(false); + } + }; + + return { + subscribed, + enabled, + loading: loading || buildInfoQuery.isLoading, + subscribe, + unsubscribe, + }; +}; diff --git a/site/src/index.tsx b/site/src/index.tsx index aef10d6c64f4d..85d66b9833d3e 100644 --- a/site/src/index.tsx +++ b/site/src/index.tsx @@ -14,5 +14,10 @@ if (element === null) { throw new Error("root element is null"); } +// The service worker handles push notifications. +if ("serviceWorker" in navigator) { + navigator.serviceWorker.register("/serviceWorker.js"); +} + const root = createRoot(element); root.render(); diff --git a/site/src/modules/dashboard/Navbar/NavbarView.tsx b/site/src/modules/dashboard/Navbar/NavbarView.tsx index 204828c2fd8ac..a581e2b2434f7 100644 --- a/site/src/modules/dashboard/Navbar/NavbarView.tsx +++ b/site/src/modules/dashboard/Navbar/NavbarView.tsx @@ -1,10 +1,15 @@ import { API } from "api/api"; +import { experiments } from "api/queries/experiments"; import type * as TypesGen from "api/typesGenerated"; +import { Button } from "components/Button/Button"; import { ExternalImage } from "components/ExternalImage/ExternalImage"; import { CoderIcon } from "components/Icons/CoderIcon"; import type { ProxyContextValue } from "contexts/ProxyContext"; +import { useWebpushNotifications } from "contexts/useWebpushNotifications"; +import { useEmbeddedMetadata } from "hooks/useEmbeddedMetadata"; import { NotificationsInbox } from "modules/notifications/NotificationsInbox/NotificationsInbox"; import type { FC } from "react"; +import { useQuery } from "react-query"; import { NavLink, useLocation } from "react-router-dom"; import { cn } from "utils/cn"; import { DeploymentDropdown } from "./DeploymentDropdown"; @@ -43,6 +48,9 @@ export const NavbarView: FC = ({ canViewAuditLog, proxyContextValue, }) => { + const { subscribed, enabled, loading, subscribe, unsubscribe } = + useWebpushNotifications(); + return (
@@ -71,6 +79,18 @@ export const NavbarView: FC = ({ />
+ {enabled ? ( + subscribed ? ( + + ) : ( + + ) + ) : null} + + +import type { WebpushMessage } from "api/typesGenerated"; + +// @ts-ignore +declare const self: ServiceWorkerGlobalScope; + +self.addEventListener("install", (event) => { + self.skipWaiting(); +}); + +self.addEventListener("activate", (event) => { + event.waitUntil(self.clients.claim()); +}); + +self.addEventListener("push", (event) => { + if (!event.data) { + return; + } + + let payload: WebpushMessage; + try { + payload = event.data?.json(); + } catch (e) { + console.error("Error parsing push payload:", e); + return; + } + + event.waitUntil( + self.registration.showNotification(payload.title, { + body: payload.body || "", + icon: payload.icon || "/favicon.ico", + }), + ); +}); + +// Handle notification click +self.addEventListener("notificationclick", (event) => { + event.notification.close(); +}); diff --git a/site/src/testHelpers/entities.ts b/site/src/testHelpers/entities.ts index d956e09957c7e..f80171122826c 100644 --- a/site/src/testHelpers/entities.ts +++ b/site/src/testHelpers/entities.ts @@ -227,6 +227,7 @@ export const MockBuildInfo: TypesGen.BuildInfoResponse = { workspace_proxy: false, upgrade_message: "My custom upgrade message", deployment_id: "510d407f-e521-4180-b559-eab4a6d802b8", + webpush_public_key: "fake-public-key", telemetry: true, }; diff --git a/site/vite.config.mts b/site/vite.config.mts index 436565c491240..89c5c924a8563 100644 --- a/site/vite.config.mts +++ b/site/vite.config.mts @@ -1,5 +1,6 @@ import * as path from "node:path"; import react from "@vitejs/plugin-react"; +import { buildSync } from "esbuild"; import { visualizer } from "rollup-plugin-visualizer"; import { type PluginOption, defineConfig } from "vite"; import checker from "vite-plugin-checker"; @@ -28,6 +29,19 @@ export default defineConfig({ emptyOutDir: false, // 'hidden' works like true except that the corresponding sourcemap comments in the bundled files are suppressed sourcemap: "hidden", + rollupOptions: { + input: { + index: path.resolve(__dirname, "./index.html"), + serviceWorker: path.resolve(__dirname, "./src/serviceWorker.ts"), + }, + output: { + entryFileNames: (chunkInfo) => { + return chunkInfo.name === "serviceWorker" + ? "[name].js" + : "assets/[name]-[hash].js"; + }, + }, + }, }, define: { "process.env": { @@ -89,6 +103,10 @@ export default defineConfig({ target: process.env.CODER_HOST || "http://localhost:3000", secure: process.env.NODE_ENV === "production", }, + "/serviceWorker.js": { + target: process.env.CODER_HOST || "http://localhost:3000", + secure: process.env.NODE_ENV === "production", + }, }, allowedHosts: true, },