From 3569e56cf912aac565d7357b4efeb7ecd37a643f Mon Sep 17 00:00:00 2001 From: "gcp-cherry-pick-bot[bot]" <98988430+gcp-cherry-pick-bot[bot]@users.noreply.github.com> Date: Mon, 17 Mar 2025 11:13:40 -0500 Subject: [PATCH 01/10] fix: hide deleted users from org members query (cherry-pick #16830) (#16832) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Cherry-picked fix: hide deleted users from org members query (#16830) Co-authored-by: ケイラ --- coderd/database/queries.sql.go | 2 +- coderd/database/queries/organizationmembers.sql | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 2f28b563e98b8..37c79537c622f 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -4967,7 +4967,7 @@ SELECT FROM organization_members INNER JOIN - users ON organization_members.user_id = users.id + users ON organization_members.user_id = users.id AND users.deleted = false WHERE -- Filter by organization id CASE diff --git a/coderd/database/queries/organizationmembers.sql b/coderd/database/queries/organizationmembers.sql index 71304c8883602..8685e71129ac9 100644 --- a/coderd/database/queries/organizationmembers.sql +++ b/coderd/database/queries/organizationmembers.sql @@ -9,7 +9,7 @@ SELECT FROM organization_members INNER JOIN - users ON organization_members.user_id = users.id + users ON organization_members.user_id = users.id AND users.deleted = false WHERE -- Filter by organization id CASE From 8a3ebaca0112703297af94d9f65dedaa832e2b4e Mon Sep 17 00:00:00 2001 From: "gcp-cherry-pick-bot[bot]" <98988430+gcp-cherry-pick-bot[bot]@users.noreply.github.com> Date: Mon, 17 Mar 2025 16:00:39 -0500 Subject: [PATCH 02/10] fix: fix deployment settings navigation issues (cherry-pick #16780) (#16790) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Cherry-picked fix: fix deployment settings navigation issues (#16780) Co-authored-by: ケイラ --- site/e2e/tests/roles.spec.ts | 157 +++++++++++++++++ site/src/api/queries/organizations.ts | 17 -- site/src/contexts/auth/AuthProvider.tsx | 6 +- site/src/contexts/auth/permissions.tsx | 159 ++++++++++++------ .../modules/dashboard/DashboardProvider.tsx | 21 +-- .../dashboard/Navbar/DeploymentDropdown.tsx | 2 +- .../modules/dashboard/Navbar/MobileMenu.tsx | 2 +- site/src/modules/dashboard/Navbar/Navbar.tsx | 11 +- .../dashboard/Navbar/NavbarView.test.tsx | 2 +- .../dashboard/Navbar/ProxyMenu.stories.tsx | 4 +- .../management/DeploymentSettingsLayout.tsx | 26 ++- .../management/DeploymentSettingsProvider.tsx | 25 +-- .../management/organizationPermissions.tsx | 62 ------- .../TerminalPage/TerminalPage.stories.tsx | 6 +- site/src/router.tsx | 5 +- site/src/testHelpers/entities.ts | 58 ++++--- site/src/testHelpers/handlers.ts | 4 +- site/src/testHelpers/storybook.tsx | 4 +- 18 files changed, 350 insertions(+), 221 deletions(-) create mode 100644 site/e2e/tests/roles.spec.ts diff --git a/site/e2e/tests/roles.spec.ts b/site/e2e/tests/roles.spec.ts new file mode 100644 index 0000000000000..482436c9c9b2d --- /dev/null +++ b/site/e2e/tests/roles.spec.ts @@ -0,0 +1,157 @@ +import { type Page, expect, test } from "@playwright/test"; +import { + createOrganization, + createOrganizationMember, + setupApiCalls, +} from "../api"; +import { license, users } from "../constants"; +import { login, requiresLicense } from "../helpers"; +import { beforeCoderTest } from "../hooks"; + +test.beforeEach(async ({ page }) 
=> { + beforeCoderTest(page); +}); + +type AdminSetting = (typeof adminSettings)[number]; + +const adminSettings = [ + "Deployment", + "Organizations", + "Healthcheck", + "Audit Logs", +] as const; + +async function hasAccessToAdminSettings(page: Page, settings: AdminSetting[]) { + // Organizations and Audit Logs both require a license to be visible + const visibleSettings = license + ? settings + : settings.filter((it) => it !== "Organizations" && it !== "Audit Logs"); + const adminSettingsButton = page.getByRole("button", { + name: "Admin settings", + }); + if (visibleSettings.length < 1) { + await expect(adminSettingsButton).not.toBeVisible(); + return; + } + + await adminSettingsButton.click(); + + for (const name of visibleSettings) { + await expect(page.getByText(name, { exact: true })).toBeVisible(); + } + + const hiddenSettings = adminSettings.filter( + (it) => !visibleSettings.includes(it), + ); + for (const name of hiddenSettings) { + await expect(page.getByText(name, { exact: true })).not.toBeVisible(); + } +} + +test.describe("roles admin settings access", () => { + test("member cannot see admin settings", async ({ page }) => { + await login(page, users.member); + await page.goto("/", { waitUntil: "domcontentloaded" }); + + // None, "Admin settings" button should not be visible + await hasAccessToAdminSettings(page, []); + }); + + test("template admin can see admin settings", async ({ page }) => { + await login(page, users.templateAdmin); + await page.goto("/", { waitUntil: "domcontentloaded" }); + + await hasAccessToAdminSettings(page, ["Deployment", "Organizations"]); + }); + + test("user admin can see admin settings", async ({ page }) => { + await login(page, users.userAdmin); + await page.goto("/", { waitUntil: "domcontentloaded" }); + + await hasAccessToAdminSettings(page, ["Deployment", "Organizations"]); + }); + + test("auditor can see admin settings", async ({ page }) => { + await login(page, users.auditor); + await page.goto("/", { waitUntil: "domcontentloaded" }); + + await hasAccessToAdminSettings(page, [ + "Deployment", + "Organizations", + "Audit Logs", + ]); + }); + + test("admin can see admin settings", async ({ page }) => { + await login(page, users.admin); + await page.goto("/", { waitUntil: "domcontentloaded" }); + + await hasAccessToAdminSettings(page, [ + "Deployment", + "Organizations", + "Healthcheck", + "Audit Logs", + ]); + }); +}); + +test.describe("org-scoped roles admin settings access", () => { + requiresLicense(); + + test.beforeEach(async ({ page }) => { + await login(page); + await setupApiCalls(page); + }); + + test("org template admin can see admin settings", async ({ page }) => { + const org = await createOrganization(); + const orgTemplateAdmin = await createOrganizationMember({ + [org.id]: ["organization-template-admin"], + }); + + await login(page, orgTemplateAdmin); + await page.goto("/", { waitUntil: "domcontentloaded" }); + + await hasAccessToAdminSettings(page, ["Organizations"]); + }); + + test("org user admin can see admin settings", async ({ page }) => { + const org = await createOrganization(); + const orgUserAdmin = await createOrganizationMember({ + [org.id]: ["organization-user-admin"], + }); + + await login(page, orgUserAdmin); + await page.goto("/", { waitUntil: "domcontentloaded" }); + + await hasAccessToAdminSettings(page, ["Deployment", "Organizations"]); + }); + + test("org auditor can see admin settings", async ({ page }) => { + const org = await createOrganization(); + const orgAuditor = await createOrganizationMember({ + 
[org.id]: ["organization-auditor"], + }); + + await login(page, orgAuditor); + await page.goto("/", { waitUntil: "domcontentloaded" }); + + await hasAccessToAdminSettings(page, ["Organizations", "Audit Logs"]); + }); + + test("org admin can see admin settings", async ({ page }) => { + const org = await createOrganization(); + const orgAdmin = await createOrganizationMember({ + [org.id]: ["organization-admin"], + }); + + await login(page, orgAdmin); + await page.goto("/", { waitUntil: "domcontentloaded" }); + + await hasAccessToAdminSettings(page, [ + "Deployment", + "Organizations", + "Audit Logs", + ]); + }); +}); diff --git a/site/src/api/queries/organizations.ts b/site/src/api/queries/organizations.ts index a27514a03c161..374f9e7eacf4e 100644 --- a/site/src/api/queries/organizations.ts +++ b/site/src/api/queries/organizations.ts @@ -6,10 +6,8 @@ import type { UpdateOrganizationRequest, } from "api/typesGenerated"; import { - type AnyOrganizationPermissions, type OrganizationPermissionName, type OrganizationPermissions, - anyOrganizationPermissionChecks, organizationPermissionChecks, } from "modules/management/organizationPermissions"; import type { QueryClient } from "react-query"; @@ -266,21 +264,6 @@ export const organizationsPermissions = ( }; }; -export const anyOrganizationPermissionsKey = [ - "authorization", - "anyOrganization", -]; - -export const anyOrganizationPermissions = () => { - return { - queryKey: anyOrganizationPermissionsKey, - queryFn: () => - API.checkAuthorization({ - checks: anyOrganizationPermissionChecks, - }) as Promise, - }; -}; - export const getOrganizationIdpSyncClaimFieldValuesKey = ( organization: string, field: string, diff --git a/site/src/contexts/auth/AuthProvider.tsx b/site/src/contexts/auth/AuthProvider.tsx index ad475bddcbfb7..7418691a291e5 100644 --- a/site/src/contexts/auth/AuthProvider.tsx +++ b/site/src/contexts/auth/AuthProvider.tsx @@ -18,7 +18,7 @@ import { useContext, } from "react"; import { useMutation, useQuery, useQueryClient } from "react-query"; -import { type Permissions, permissionsToCheck } from "./permissions"; +import { type Permissions, permissionChecks } from "./permissions"; export type AuthContextValue = { isLoading: boolean; @@ -50,13 +50,13 @@ export const AuthProvider: FC = ({ children }) => { const hasFirstUserQuery = useQuery(hasFirstUser(userMetadataState)); const permissionsQuery = useQuery({ - ...checkAuthorization({ checks: permissionsToCheck }), + ...checkAuthorization({ checks: permissionChecks }), enabled: userQuery.data !== undefined, }); const queryClient = useQueryClient(); const loginMutation = useMutation( - login({ checks: permissionsToCheck }, queryClient), + login({ checks: permissionChecks }, queryClient), ); const logoutMutation = useMutation(logout(queryClient)); diff --git a/site/src/contexts/auth/permissions.tsx b/site/src/contexts/auth/permissions.tsx index 1043862942edb..0d8957627c36d 100644 --- a/site/src/contexts/auth/permissions.tsx +++ b/site/src/contexts/auth/permissions.tsx @@ -1,156 +1,205 @@ import type { AuthorizationCheck } from "api/typesGenerated"; -export const checks = { - viewAllUsers: "viewAllUsers", - updateUsers: "updateUsers", - createUser: "createUser", - createTemplates: "createTemplates", - updateTemplates: "updateTemplates", - deleteTemplates: "deleteTemplates", - viewAnyAuditLog: "viewAnyAuditLog", - viewDeploymentValues: "viewDeploymentValues", - editDeploymentValues: "editDeploymentValues", - viewUpdateCheck: "viewUpdateCheck", - viewExternalAuthConfig: 
"viewExternalAuthConfig", - viewDeploymentStats: "viewDeploymentStats", - readWorkspaceProxies: "readWorkspaceProxies", - editWorkspaceProxies: "editWorkspaceProxies", - createOrganization: "createOrganization", - viewAnyGroup: "viewAnyGroup", - createGroup: "createGroup", - viewAllLicenses: "viewAllLicenses", - viewNotificationTemplate: "viewNotificationTemplate", - viewOrganizationIDPSyncSettings: "viewOrganizationIDPSyncSettings", -} as const satisfies Record; +export type Permissions = { + [k in PermissionName]: boolean; +}; -// Type expression seems a little redundant (`keyof typeof checks` has the same -// result), just because each key-value pair is currently symmetrical; this may -// change down the line -type PermissionValue = (typeof checks)[keyof typeof checks]; +export type PermissionName = keyof typeof permissionChecks; -export const permissionsToCheck = { - [checks.viewAllUsers]: { +export const permissionChecks = { + viewAllUsers: { object: { resource_type: "user", }, action: "read", }, - [checks.updateUsers]: { + updateUsers: { object: { resource_type: "user", }, action: "update", }, - [checks.createUser]: { + createUser: { object: { resource_type: "user", }, action: "create", }, - [checks.createTemplates]: { + createTemplates: { object: { resource_type: "template", any_org: true, }, action: "update", }, - [checks.updateTemplates]: { + updateTemplates: { object: { resource_type: "template", }, action: "update", }, - [checks.deleteTemplates]: { + deleteTemplates: { object: { resource_type: "template", }, action: "delete", }, - [checks.viewAnyAuditLog]: { - object: { - resource_type: "audit_log", - any_org: true, - }, - action: "read", - }, - [checks.viewDeploymentValues]: { + viewDeploymentValues: { object: { resource_type: "deployment_config", }, action: "read", }, - [checks.editDeploymentValues]: { + editDeploymentValues: { object: { resource_type: "deployment_config", }, action: "update", }, - [checks.viewUpdateCheck]: { + viewUpdateCheck: { object: { resource_type: "deployment_config", }, action: "read", }, - [checks.viewExternalAuthConfig]: { + viewExternalAuthConfig: { object: { resource_type: "deployment_config", }, action: "read", }, - [checks.viewDeploymentStats]: { + viewDeploymentStats: { object: { resource_type: "deployment_stats", }, action: "read", }, - [checks.readWorkspaceProxies]: { + readWorkspaceProxies: { object: { resource_type: "workspace_proxy", }, action: "read", }, - [checks.editWorkspaceProxies]: { + editWorkspaceProxies: { object: { resource_type: "workspace_proxy", }, action: "create", }, - [checks.createOrganization]: { + createOrganization: { object: { resource_type: "organization", }, action: "create", }, - [checks.viewAnyGroup]: { + viewAnyGroup: { object: { resource_type: "group", }, action: "read", }, - [checks.createGroup]: { + createGroup: { object: { resource_type: "group", }, action: "create", }, - [checks.viewAllLicenses]: { + viewAllLicenses: { object: { resource_type: "license", }, action: "read", }, - [checks.viewNotificationTemplate]: { + viewNotificationTemplate: { object: { resource_type: "notification_template", }, action: "read", }, - [checks.viewOrganizationIDPSyncSettings]: { + viewOrganizationIDPSyncSettings: { object: { resource_type: "idpsync_settings", }, action: "read", }, -} as const satisfies Record; -export type Permissions = Record; + viewAnyMembers: { + object: { + resource_type: "organization_member", + any_org: true, + }, + action: "read", + }, + editAnyGroups: { + object: { + resource_type: "group", + any_org: 
true, + }, + action: "update", + }, + assignAnyRoles: { + object: { + resource_type: "assign_org_role", + any_org: true, + }, + action: "assign", + }, + viewAnyIdpSyncSettings: { + object: { + resource_type: "idpsync_settings", + any_org: true, + }, + action: "read", + }, + editAnySettings: { + object: { + resource_type: "organization", + any_org: true, + }, + action: "update", + }, + viewAnyAuditLog: { + object: { + resource_type: "audit_log", + any_org: true, + }, + action: "read", + }, + viewDebugInfo: { + object: { + resource_type: "debug_info", + }, + action: "read", + }, +} as const satisfies Record; + +export const canViewDeploymentSettings = ( + permissions: Permissions | undefined, +): permissions is Permissions => { + return ( + permissions !== undefined && + (permissions.viewDeploymentValues || + permissions.viewAllLicenses || + permissions.viewAllUsers || + permissions.viewAnyGroup || + permissions.viewNotificationTemplate || + permissions.viewOrganizationIDPSyncSettings) + ); +}; + +/** + * Checks if the user can view or edit members or groups for the organization + * that produced the given OrganizationPermissions. + */ +export const canViewAnyOrganization = ( + permissions: Permissions | undefined, +): permissions is Permissions => { + return ( + permissions !== undefined && + (permissions.viewAnyMembers || + permissions.editAnyGroups || + permissions.assignAnyRoles || + permissions.viewAnyIdpSyncSettings || + permissions.editAnySettings) + ); +}; diff --git a/site/src/modules/dashboard/DashboardProvider.tsx b/site/src/modules/dashboard/DashboardProvider.tsx index bf8e307206aea..bb5987d6546be 100644 --- a/site/src/modules/dashboard/DashboardProvider.tsx +++ b/site/src/modules/dashboard/DashboardProvider.tsx @@ -1,10 +1,7 @@ import { appearance } from "api/queries/appearance"; import { entitlements } from "api/queries/entitlements"; import { experiments } from "api/queries/experiments"; -import { - anyOrganizationPermissions, - organizations, -} from "api/queries/organizations"; +import { organizations } from "api/queries/organizations"; import type { AppearanceConfig, Entitlements, @@ -13,8 +10,9 @@ import type { } from "api/typesGenerated"; import { ErrorAlert } from "components/Alert/ErrorAlert"; import { Loader } from "components/Loader/Loader"; +import { useAuthenticated } from "contexts/auth/RequireAuth"; +import { canViewAnyOrganization } from "contexts/auth/permissions"; import { useEmbeddedMetadata } from "hooks/useEmbeddedMetadata"; -import { canViewAnyOrganization } from "modules/management/organizationPermissions"; import { type FC, type PropsWithChildren, createContext } from "react"; import { useQuery } from "react-query"; import { selectFeatureVisibility } from "./entitlements"; @@ -34,20 +32,17 @@ export const DashboardContext = createContext( export const DashboardProvider: FC = ({ children }) => { const { metadata } = useEmbeddedMetadata(); + const { permissions } = useAuthenticated(); const entitlementsQuery = useQuery(entitlements(metadata.entitlements)); const experimentsQuery = useQuery(experiments(metadata.experiments)); const appearanceQuery = useQuery(appearance(metadata.appearance)); const organizationsQuery = useQuery(organizations()); - const anyOrganizationPermissionsQuery = useQuery( - anyOrganizationPermissions(), - ); const error = entitlementsQuery.error || appearanceQuery.error || experimentsQuery.error || - organizationsQuery.error || - anyOrganizationPermissionsQuery.error; + organizationsQuery.error; if (error) { return ; @@ -57,8 +52,7 @@ 
export const DashboardProvider: FC = ({ children }) => { !entitlementsQuery.data || !appearanceQuery.data || !experimentsQuery.data || - !organizationsQuery.data || - !anyOrganizationPermissionsQuery.data; + !organizationsQuery.data; if (isLoading) { return ; @@ -79,8 +73,7 @@ export const DashboardProvider: FC = ({ children }) => { organizations: organizationsQuery.data, showOrganizations, canViewOrganizationSettings: - showOrganizations && - canViewAnyOrganization(anyOrganizationPermissionsQuery.data), + showOrganizations && canViewAnyOrganization(permissions), }} > {children} diff --git a/site/src/modules/dashboard/Navbar/DeploymentDropdown.tsx b/site/src/modules/dashboard/Navbar/DeploymentDropdown.tsx index 746ddc8f89e78..876a3eb441cf1 100644 --- a/site/src/modules/dashboard/Navbar/DeploymentDropdown.tsx +++ b/site/src/modules/dashboard/Navbar/DeploymentDropdown.tsx @@ -82,7 +82,7 @@ const DeploymentDropdownContent: FC = ({ {canViewDeployment && ( diff --git a/site/src/modules/dashboard/Navbar/MobileMenu.tsx b/site/src/modules/dashboard/Navbar/MobileMenu.tsx index 20058335eb8e5..ae5f600ba68de 100644 --- a/site/src/modules/dashboard/Navbar/MobileMenu.tsx +++ b/site/src/modules/dashboard/Navbar/MobileMenu.tsx @@ -220,7 +220,7 @@ const AdminSettingsSub: FC = ({ asChild className={cn(itemStyles.default, itemStyles.sub)} > - Deployment + Deployment )} {canViewOrganizations && ( diff --git a/site/src/modules/dashboard/Navbar/Navbar.tsx b/site/src/modules/dashboard/Navbar/Navbar.tsx index f80887e1f1aec..7dc96c791e7ca 100644 --- a/site/src/modules/dashboard/Navbar/Navbar.tsx +++ b/site/src/modules/dashboard/Navbar/Navbar.tsx @@ -1,6 +1,7 @@ import { buildInfo } from "api/queries/buildInfo"; import { useProxy } from "contexts/ProxyContext"; import { useAuthenticated } from "contexts/auth/RequireAuth"; +import { canViewDeploymentSettings } from "contexts/auth/permissions"; import { useEmbeddedMetadata } from "hooks/useEmbeddedMetadata"; import { useDashboard } from "modules/dashboard/useDashboard"; import type { FC } from "react"; @@ -11,16 +12,16 @@ import { NavbarView } from "./NavbarView"; export const Navbar: FC = () => { const { metadata } = useEmbeddedMetadata(); const buildInfoQuery = useQuery(buildInfo(metadata["build-info"])); - const { appearance, canViewOrganizationSettings } = useDashboard(); const { user: me, permissions, signOut } = useAuthenticated(); const featureVisibility = useFeatureVisibility(); + const proxyContextValue = useProxy(); + + const canViewDeployment = canViewDeploymentSettings(permissions); + const canViewOrganizations = canViewOrganizationSettings; + const canViewHealth = permissions.viewDebugInfo; const canViewAuditLog = featureVisibility.audit_log && permissions.viewAnyAuditLog; - const canViewDeployment = permissions.viewDeploymentValues; - const canViewOrganizations = canViewOrganizationSettings; - const proxyContextValue = useProxy(); - const canViewHealth = canViewDeployment; return ( { await userEvent.click(deploymentMenu); const deploymentSettingsLink = await screen.findByText(/deployment/i); - expect(deploymentSettingsLink.href).toContain("/deployment/general"); + expect(deploymentSettingsLink.href).toContain("/deployment"); }); }); diff --git a/site/src/modules/dashboard/Navbar/ProxyMenu.stories.tsx b/site/src/modules/dashboard/Navbar/ProxyMenu.stories.tsx index 883bbd0dd2f61..8e8cf7fcb8951 100644 --- a/site/src/modules/dashboard/Navbar/ProxyMenu.stories.tsx +++ b/site/src/modules/dashboard/Navbar/ProxyMenu.stories.tsx @@ -3,7 +3,7 @@ import { fn, 
userEvent, within } from "@storybook/test"; import { getAuthorizationKey } from "api/queries/authCheck"; import { getPreferredProxy } from "contexts/ProxyContext"; import { AuthProvider } from "contexts/auth/AuthProvider"; -import { permissionsToCheck } from "contexts/auth/permissions"; +import { permissionChecks } from "contexts/auth/permissions"; import { MockAuthMethodsAll, MockPermissions, @@ -45,7 +45,7 @@ const meta: Meta = { { key: ["authMethods"], data: MockAuthMethodsAll }, { key: ["hasFirstUser"], data: true }, { - key: getAuthorizationKey({ checks: permissionsToCheck }), + key: getAuthorizationKey({ checks: permissionChecks }), data: MockPermissions, }, ], diff --git a/site/src/modules/management/DeploymentSettingsLayout.tsx b/site/src/modules/management/DeploymentSettingsLayout.tsx index 676a24c936246..c40b6440a81c3 100644 --- a/site/src/modules/management/DeploymentSettingsLayout.tsx +++ b/site/src/modules/management/DeploymentSettingsLayout.tsx @@ -8,19 +8,31 @@ import { import { Loader } from "components/Loader/Loader"; import { useAuthenticated } from "contexts/auth/RequireAuth"; import { RequirePermission } from "contexts/auth/RequirePermission"; +import { canViewDeploymentSettings } from "contexts/auth/permissions"; import { type FC, Suspense } from "react"; -import { Outlet } from "react-router-dom"; +import { Navigate, Outlet, useLocation } from "react-router-dom"; import { DeploymentSidebar } from "./DeploymentSidebar"; const DeploymentSettingsLayout: FC = () => { const { permissions } = useAuthenticated(); + const location = useLocation(); - // The deployment settings page also contains users, audit logs, and groups - // so this page must be visible if you can see any of these. - const canViewDeploymentSettingsPage = - permissions.viewDeploymentValues || - permissions.viewAllUsers || - permissions.viewAnyAuditLog; + if (location.pathname === "/deployment") { + return ( + + ); + } + + // The deployment settings page also contains users and groups and more so + // this page must be visible if you can see any of these. + const canViewDeploymentSettingsPage = canViewDeploymentSettings(permissions); return ( diff --git a/site/src/modules/management/DeploymentSettingsProvider.tsx b/site/src/modules/management/DeploymentSettingsProvider.tsx index 633c67d67fe44..766d75aacd216 100644 --- a/site/src/modules/management/DeploymentSettingsProvider.tsx +++ b/site/src/modules/management/DeploymentSettingsProvider.tsx @@ -2,8 +2,6 @@ import type { DeploymentConfig } from "api/api"; import { deploymentConfig } from "api/queries/deployment"; import { ErrorAlert } from "components/Alert/ErrorAlert"; import { Loader } from "components/Loader/Loader"; -import { useAuthenticated } from "contexts/auth/RequireAuth"; -import { RequirePermission } from "contexts/auth/RequirePermission"; import { type FC, createContext, useContext } from "react"; import { useQuery } from "react-query"; import { Outlet } from "react-router-dom"; @@ -28,19 +26,8 @@ export const useDeploymentSettings = (): DeploymentSettingsValue => { }; const DeploymentSettingsProvider: FC = () => { - const { permissions } = useAuthenticated(); const deploymentConfigQuery = useQuery(deploymentConfig()); - // The deployment settings page also contains users, audit logs, and groups - // so this page must be visible if you can see any of these. 
- const canViewDeploymentSettingsPage = - permissions.viewDeploymentValues || - permissions.viewAllUsers || - permissions.viewAnyAuditLog; - - // Not a huge problem to unload the content in the event of an error, - // because the sidebar rendering isn't tied to this. Even if the user hits - // a 403 error, they'll still have navigation options if (deploymentConfigQuery.error) { return ; } @@ -50,13 +37,11 @@ const DeploymentSettingsProvider: FC = () => { } return ( - - - - - + + + ); }; diff --git a/site/src/modules/management/organizationPermissions.tsx b/site/src/modules/management/organizationPermissions.tsx index 2059d8fd6f76f..1b79e11e68ca0 100644 --- a/site/src/modules/management/organizationPermissions.tsx +++ b/site/src/modules/management/organizationPermissions.tsx @@ -135,65 +135,3 @@ export const canEditOrganization = ( permissions.createOrgRoles) ); }; - -export type AnyOrganizationPermissions = { - [k in AnyOrganizationPermissionName]: boolean; -}; - -export type AnyOrganizationPermissionName = - keyof typeof anyOrganizationPermissionChecks; - -export const anyOrganizationPermissionChecks = { - viewAnyMembers: { - object: { - resource_type: "organization_member", - any_org: true, - }, - action: "read", - }, - editAnyGroups: { - object: { - resource_type: "group", - any_org: true, - }, - action: "update", - }, - assignAnyRoles: { - object: { - resource_type: "assign_org_role", - any_org: true, - }, - action: "assign", - }, - viewAnyIdpSyncSettings: { - object: { - resource_type: "idpsync_settings", - any_org: true, - }, - action: "read", - }, - editAnySettings: { - object: { - resource_type: "organization", - any_org: true, - }, - action: "update", - }, -} as const satisfies Record; - -/** - * Checks if the user can view or edit members or groups for the organization - * that produced the given OrganizationPermissions. 
- */ -export const canViewAnyOrganization = ( - permissions: AnyOrganizationPermissions | undefined, -): permissions is AnyOrganizationPermissions => { - return ( - permissions !== undefined && - (permissions.viewAnyMembers || - permissions.editAnyGroups || - permissions.assignAnyRoles || - permissions.viewAnyIdpSyncSettings || - permissions.editAnySettings) - ); -}; diff --git a/site/src/pages/TerminalPage/TerminalPage.stories.tsx b/site/src/pages/TerminalPage/TerminalPage.stories.tsx index b9dfeba1d811d..f50b75bac4a26 100644 --- a/site/src/pages/TerminalPage/TerminalPage.stories.tsx +++ b/site/src/pages/TerminalPage/TerminalPage.stories.tsx @@ -1,11 +1,10 @@ import type { Meta, StoryObj } from "@storybook/react"; import { getAuthorizationKey } from "api/queries/authCheck"; -import { anyOrganizationPermissionsKey } from "api/queries/organizations"; import { workspaceByOwnerAndNameKey } from "api/queries/workspaces"; import type { Workspace, WorkspaceAgentLifecycle } from "api/typesGenerated"; import { AuthProvider } from "contexts/auth/AuthProvider"; import { RequireAuth } from "contexts/auth/RequireAuth"; -import { permissionsToCheck } from "contexts/auth/permissions"; +import { permissionChecks } from "contexts/auth/permissions"; import { reactRouterOutlet, reactRouterParameters, @@ -74,10 +73,9 @@ const meta = { { key: ["appearance"], data: MockAppearanceConfig }, { key: ["organizations"], data: [MockDefaultOrganization] }, { - key: getAuthorizationKey({ checks: permissionsToCheck }), + key: getAuthorizationKey({ checks: permissionChecks }), data: { editWorkspaceProxies: true }, }, - { key: anyOrganizationPermissionsKey, data: {} }, ], chromatic: { delay: 300 }, }, diff --git a/site/src/router.tsx b/site/src/router.tsx index 66d37f92aeaf1..ebb9e6763d058 100644 --- a/site/src/router.tsx +++ b/site/src/router.tsx @@ -453,8 +453,6 @@ export const router = createBrowserRouter( path="notifications" element={} /> - } /> - } /> @@ -476,6 +474,9 @@ export const router = createBrowserRouter( } /> {groupsRouter()} + + } /> + } /> }> diff --git a/site/src/testHelpers/entities.ts b/site/src/testHelpers/entities.ts index 12654bc064fee..aa87ac7fbf6fc 100644 --- a/site/src/testHelpers/entities.ts +++ b/site/src/testHelpers/entities.ts @@ -2856,6 +2856,41 @@ export const MockPermissions: Permissions = { viewAllLicenses: true, viewNotificationTemplate: true, viewOrganizationIDPSyncSettings: true, + viewDebugInfo: true, + assignAnyRoles: true, + editAnyGroups: true, + editAnySettings: true, + viewAnyIdpSyncSettings: true, + viewAnyMembers: true, +}; + +export const MockNoPermissions: Permissions = { + createTemplates: false, + createUser: false, + deleteTemplates: false, + updateTemplates: false, + viewAllUsers: false, + updateUsers: false, + viewAnyAuditLog: false, + viewDeploymentValues: false, + editDeploymentValues: false, + viewUpdateCheck: false, + viewDeploymentStats: false, + viewExternalAuthConfig: false, + readWorkspaceProxies: false, + editWorkspaceProxies: false, + createOrganization: false, + viewAnyGroup: false, + createGroup: false, + viewAllLicenses: false, + viewNotificationTemplate: false, + viewOrganizationIDPSyncSettings: false, + viewDebugInfo: false, + assignAnyRoles: false, + editAnyGroups: false, + editAnySettings: false, + viewAnyIdpSyncSettings: false, + viewAnyMembers: false, }; export const MockOrganizationPermissions: OrganizationPermissions = { @@ -2890,29 +2925,6 @@ export const MockNoOrganizationPermissions: OrganizationPermissions = { editIdpSyncSettings: false, }; 
-export const MockNoPermissions: Permissions = { - createTemplates: false, - createUser: false, - deleteTemplates: false, - updateTemplates: false, - viewAllUsers: false, - updateUsers: false, - viewAnyAuditLog: false, - viewDeploymentValues: false, - editDeploymentValues: false, - viewUpdateCheck: false, - viewDeploymentStats: false, - viewExternalAuthConfig: false, - readWorkspaceProxies: false, - editWorkspaceProxies: false, - createOrganization: false, - viewAnyGroup: false, - createGroup: false, - viewAllLicenses: false, - viewNotificationTemplate: false, - viewOrganizationIDPSyncSettings: false, -}; - export const MockDeploymentConfig: DeploymentConfig = { config: { enable_terraform_debug_mode: true, diff --git a/site/src/testHelpers/handlers.ts b/site/src/testHelpers/handlers.ts index b458956b17a1d..71e67697572e2 100644 --- a/site/src/testHelpers/handlers.ts +++ b/site/src/testHelpers/handlers.ts @@ -1,7 +1,7 @@ import fs from "node:fs"; import path from "node:path"; import type { CreateWorkspaceBuildRequest } from "api/typesGenerated"; -import { permissionsToCheck } from "contexts/auth/permissions"; +import { permissionChecks } from "contexts/auth/permissions"; import { http, HttpResponse } from "msw"; import * as M from "./entities"; import { MockGroup, MockWorkspaceQuota } from "./entities"; @@ -173,7 +173,7 @@ export const handlers = [ }), http.post("/api/v2/authcheck", () => { const permissions = [ - ...Object.keys(permissionsToCheck), + ...Object.keys(permissionChecks), "canUpdateTemplate", "updateWorkspace", ]; diff --git a/site/src/testHelpers/storybook.tsx b/site/src/testHelpers/storybook.tsx index 2b81bf16cd40f..fdaeda69f15c1 100644 --- a/site/src/testHelpers/storybook.tsx +++ b/site/src/testHelpers/storybook.tsx @@ -6,7 +6,7 @@ import { hasFirstUserKey, meKey } from "api/queries/users"; import type { Entitlements } from "api/typesGenerated"; import { GlobalSnackbar } from "components/GlobalSnackbar/GlobalSnackbar"; import { AuthProvider } from "contexts/auth/AuthProvider"; -import { permissionsToCheck } from "contexts/auth/permissions"; +import { permissionChecks } from "contexts/auth/permissions"; import { DashboardContext } from "modules/dashboard/DashboardProvider"; import { DeploymentSettingsContext } from "modules/management/DeploymentSettingsProvider"; import { OrganizationSettingsContext } from "modules/management/OrganizationSettingsLayout"; @@ -114,7 +114,7 @@ export const withAuthProvider = (Story: FC, { parameters }: StoryContext) => { queryClient.setQueryData(meKey, parameters.user); queryClient.setQueryData(hasFirstUserKey, true); queryClient.setQueryData( - getAuthorizationKey({ checks: permissionsToCheck }), + getAuthorizationKey({ checks: permissionChecks }), parameters.permissions ?? 
{}, ); From 828f149b88b56211d44ff4fcf7c273ed6c1568c9 Mon Sep 17 00:00:00 2001 From: "gcp-cherry-pick-bot[bot]" <98988430+gcp-cherry-pick-bot[bot]@users.noreply.github.com> Date: Wed, 16 Apr 2025 15:03:25 +0200 Subject: [PATCH 03/10] chore: fix gpg forwarding test (cherry-pick #17355) (#17415) Cherry-picked chore: fix gpg forwarding test (#17355) Co-authored-by: Dean Sheather --- .gitignore | 3 +++ cli/ssh_test.go | 4 +++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index f98101cd7f920..f3bd9b24938e2 100644 --- a/.gitignore +++ b/.gitignore @@ -78,3 +78,6 @@ result # Zed .zed_server + +# dlv debug binaries for go tests +__debug_bin* diff --git a/cli/ssh_test.go b/cli/ssh_test.go index d20278bbf7ced..caeb09f123a4c 100644 --- a/cli/ssh_test.go +++ b/cli/ssh_test.go @@ -1908,7 +1908,9 @@ Expire-Date: 0 tpty.WriteLine("gpg --list-keys && echo gpg-''-listkeys-command-done") listKeysOutput := tpty.ExpectMatch("gpg--listkeys-command-done") require.Contains(t, listKeysOutput, "[ultimate] Coder Test ") - require.Contains(t, listKeysOutput, "[ultimate] Dean Sheather (work key) ") + // It's fine that this key is expired. We're just testing that the key trust + // gets synced properly. + require.Contains(t, listKeysOutput, "[ expired] Dean Sheather (work key) ") // Try to sign something. This demonstrates that the forwarding is // working as expected, since the workspace doesn't have access to the From 7d0e909ca26baa9559615455fa9d968a7efd1996 Mon Sep 17 00:00:00 2001 From: "gcp-cherry-pick-bot[bot]" <98988430+gcp-cherry-pick-bot[bot]@users.noreply.github.com> Date: Wed, 16 Apr 2025 15:54:18 +0200 Subject: [PATCH 04/10] fix: reduce excessive logging when database is unreachable (cherry-pick #17363) (#17420) Cherry-picked fix: reduce excessive logging when database is unreachable (#17363) Fixes #17045 --------- Signed-off-by: Danny Kopping Signed-off-by: Danny Kopping Co-authored-by: Danny Kopping --- coderd/coderd.go | 9 ++--- coderd/tailnet.go | 16 ++++++++- coderd/tailnet_test.go | 41 ++++++++++++++++++++++ coderd/workspaceagents.go | 10 ++++++ codersdk/database.go | 7 ++++ codersdk/workspacesdk/dialer.go | 15 +++++--- codersdk/workspacesdk/workspacesdk_test.go | 35 ++++++++++++++++++ provisionerd/provisionerd.go | 30 +++++++++++----- site/src/api/typesGenerated.ts | 3 ++ tailnet/controllers.go | 14 ++++++-- 10 files changed, 160 insertions(+), 20 deletions(-) create mode 100644 codersdk/database.go diff --git a/coderd/coderd.go b/coderd/coderd.go index 1cb4c0592b66e..d2e0835f3c877 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -652,10 +652,11 @@ func New(options *Options) *API { api.Auditor.Store(&options.Auditor) api.TailnetCoordinator.Store(&options.TailnetCoordinator) dialer := &InmemTailnetDialer{ - CoordPtr: &api.TailnetCoordinator, - DERPFn: api.DERPMap, - Logger: options.Logger, - ClientID: uuid.New(), + CoordPtr: &api.TailnetCoordinator, + DERPFn: api.DERPMap, + Logger: options.Logger, + ClientID: uuid.New(), + DatabaseHealthCheck: api.Database, } stn, err := NewServerTailnet(api.ctx, options.Logger, diff --git a/coderd/tailnet.go b/coderd/tailnet.go index b06219db40a78..cfdc667f4da0f 100644 --- a/coderd/tailnet.go +++ b/coderd/tailnet.go @@ -24,9 +24,11 @@ import ( "tailscale.com/tailcfg" "cdr.dev/slog" + "github.com/coder/coder/v2/coderd/tracing" "github.com/coder/coder/v2/coderd/workspaceapps" "github.com/coder/coder/v2/coderd/workspaceapps/appurl" + "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/workspacesdk" 
"github.com/coder/coder/v2/site" "github.com/coder/coder/v2/tailnet" @@ -534,6 +536,10 @@ func NewMultiAgentController(ctx context.Context, logger slog.Logger, tracer tra return m } +type Pinger interface { + Ping(context.Context) (time.Duration, error) +} + // InmemTailnetDialer is a tailnet.ControlProtocolDialer that connects to a Coordinator and DERPMap // service running in the same memory space. type InmemTailnetDialer struct { @@ -541,9 +547,17 @@ type InmemTailnetDialer struct { DERPFn func() *tailcfg.DERPMap Logger slog.Logger ClientID uuid.UUID + // DatabaseHealthCheck is used to validate that the store is reachable. + DatabaseHealthCheck Pinger } -func (a *InmemTailnetDialer) Dial(_ context.Context, _ tailnet.ResumeTokenController) (tailnet.ControlProtocolClients, error) { +func (a *InmemTailnetDialer) Dial(ctx context.Context, _ tailnet.ResumeTokenController) (tailnet.ControlProtocolClients, error) { + if a.DatabaseHealthCheck != nil { + if _, err := a.DatabaseHealthCheck.Ping(ctx); err != nil { + return tailnet.ControlProtocolClients{}, xerrors.Errorf("%w: %v", codersdk.ErrDatabaseNotReachable, err) + } + } + coord := a.CoordPtr.Load() if coord == nil { return tailnet.ControlProtocolClients{}, xerrors.Errorf("tailnet coordinator not initialized") diff --git a/coderd/tailnet_test.go b/coderd/tailnet_test.go index b0aaaedc769c0..28265404c3eae 100644 --- a/coderd/tailnet_test.go +++ b/coderd/tailnet_test.go @@ -11,6 +11,7 @@ import ( "strconv" "sync/atomic" "testing" + "time" "github.com/google/uuid" "github.com/prometheus/client_golang/prometheus" @@ -18,6 +19,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel/trace" + "golang.org/x/xerrors" "tailscale.com/tailcfg" "github.com/coder/coder/v2/agent" @@ -25,6 +27,7 @@ import ( "github.com/coder/coder/v2/agent/proto" "github.com/coder/coder/v2/coderd" "github.com/coder/coder/v2/coderd/workspaceapps/appurl" + "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/agentsdk" "github.com/coder/coder/v2/codersdk/workspacesdk" "github.com/coder/coder/v2/tailnet" @@ -365,6 +368,44 @@ func TestServerTailnet_ReverseProxy(t *testing.T) { }) } +func TestDialFailure(t *testing.T) { + t.Parallel() + + // Setup. + ctx := testutil.Context(t, testutil.WaitShort) + logger := testutil.Logger(t) + + // Given: a tailnet coordinator. + coord := tailnet.NewCoordinator(logger) + t.Cleanup(func() { + _ = coord.Close() + }) + coordPtr := atomic.Pointer[tailnet.Coordinator]{} + coordPtr.Store(&coord) + + // Given: a fake DB healthchecker which will always fail. + fch := &failingHealthcheck{} + + // When: dialing the in-memory coordinator. + dialer := &coderd.InmemTailnetDialer{ + CoordPtr: &coordPtr, + Logger: logger, + ClientID: uuid.UUID{5}, + DatabaseHealthCheck: fch, + } + _, err := dialer.Dial(ctx, nil) + + // Then: the error returned reflects the database has failed its healthcheck. + require.ErrorIs(t, err, codersdk.ErrDatabaseNotReachable) +} + +type failingHealthcheck struct{} + +func (failingHealthcheck) Ping(context.Context) (time.Duration, error) { + // Simulate a database connection error. 
+ return 0, xerrors.New("oops") +} + type wrappedListener struct { net.Listener dials int32 diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index ddfb21a751671..aae93c0a370a0 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -897,6 +897,16 @@ func (api *API) derpMapUpdates(rw http.ResponseWriter, r *http.Request) { func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() + // Ensure the database is reachable before proceeding. + _, err := api.Database.Ping(ctx) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: codersdk.DatabaseNotReachable, + Detail: err.Error(), + }) + return + } + // This route accepts user API key auth and workspace proxy auth. The moon actor has // full permissions so should be able to pass this authz check. workspace := httpmw.WorkspaceParam(r) diff --git a/codersdk/database.go b/codersdk/database.go new file mode 100644 index 0000000000000..1a33da6362e0d --- /dev/null +++ b/codersdk/database.go @@ -0,0 +1,7 @@ +package codersdk + +import "golang.org/x/xerrors" + +const DatabaseNotReachable = "database not reachable" + +var ErrDatabaseNotReachable = xerrors.New(DatabaseNotReachable) diff --git a/codersdk/workspacesdk/dialer.go b/codersdk/workspacesdk/dialer.go index 23d618761b807..71cac0c5f04b1 100644 --- a/codersdk/workspacesdk/dialer.go +++ b/codersdk/workspacesdk/dialer.go @@ -11,17 +11,19 @@ import ( "golang.org/x/xerrors" "cdr.dev/slog" + "github.com/coder/websocket" + "github.com/coder/coder/v2/buildinfo" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/tailnet" "github.com/coder/coder/v2/tailnet/proto" - "github.com/coder/websocket" ) var permanentErrorStatuses = []int{ - http.StatusConflict, // returned if client/agent connections disabled (browser only) - http.StatusBadRequest, // returned if API mismatch - http.StatusNotFound, // returned if user doesn't have permission or agent doesn't exist + http.StatusConflict, // returned if client/agent connections disabled (browser only) + http.StatusBadRequest, // returned if API mismatch + http.StatusNotFound, // returned if user doesn't have permission or agent doesn't exist + http.StatusInternalServerError, // returned if database is not reachable, } type WebsocketDialer struct { @@ -89,6 +91,11 @@ func (w *WebsocketDialer) Dial(ctx context.Context, r tailnet.ResumeTokenControl "Ensure your client release version (%s, different than the API version) matches the server release version", buildinfo.Version()) } + + if sdkErr.Message == codersdk.DatabaseNotReachable && + sdkErr.StatusCode() == http.StatusInternalServerError { + err = xerrors.Errorf("%w: %v", codersdk.ErrDatabaseNotReachable, err) + } } w.connected <- err return tailnet.ControlProtocolClients{}, err diff --git a/codersdk/workspacesdk/workspacesdk_test.go b/codersdk/workspacesdk/workspacesdk_test.go index 317db4471319f..e7ccd96e208fa 100644 --- a/codersdk/workspacesdk/workspacesdk_test.go +++ b/codersdk/workspacesdk/workspacesdk_test.go @@ -1,13 +1,21 @@ package workspacesdk_test import ( + "net/http" + "net/http/httptest" "net/url" "testing" "github.com/stretchr/testify/require" "tailscale.com/tailcfg" + "github.com/coder/websocket" + + "github.com/coder/coder/v2/coderd/httpapi" + "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/agentsdk" + "github.com/coder/coder/v2/codersdk/workspacesdk" + "github.com/coder/coder/v2/testutil" ) func TestWorkspaceRewriteDERPMap(t 
*testing.T) { @@ -37,3 +45,30 @@ func TestWorkspaceRewriteDERPMap(t *testing.T) { require.Equal(t, "coconuts.org", node.HostName) require.Equal(t, 44558, node.DERPPort) } + +func TestWorkspaceDialerFailure(t *testing.T) { + t.Parallel() + + // Setup. + ctx := testutil.Context(t, testutil.WaitShort) + logger := testutil.Logger(t) + + // Given: a mock HTTP server which mimicks an unreachable database when calling the coordination endpoint. + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + httpapi.Write(ctx, w, http.StatusInternalServerError, codersdk.Response{ + Message: codersdk.DatabaseNotReachable, + Detail: "oops", + }) + })) + t.Cleanup(srv.Close) + + u, err := url.Parse(srv.URL) + require.NoError(t, err) + + // When: calling the coordination endpoint. + dialer := workspacesdk.NewWebsocketDialer(logger, u, &websocket.DialOptions{}) + _, err = dialer.Dial(ctx, nil) + + // Then: an error indicating a database issue is returned, to conditionalize the behavior of the caller. + require.ErrorIs(t, err, codersdk.ErrDatabaseNotReachable) +} diff --git a/provisionerd/provisionerd.go b/provisionerd/provisionerd.go index b461bc593ee36..9adf9951fa488 100644 --- a/provisionerd/provisionerd.go +++ b/provisionerd/provisionerd.go @@ -20,12 +20,13 @@ import ( "golang.org/x/xerrors" "cdr.dev/slog" + "github.com/coder/retry" + "github.com/coder/coder/v2/coderd/tracing" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/provisionerd/proto" "github.com/coder/coder/v2/provisionerd/runner" sdkproto "github.com/coder/coder/v2/provisionersdk/proto" - "github.com/coder/retry" ) // Dialer represents the function to create a daemon client connection. @@ -290,7 +291,7 @@ func (p *Server) acquireLoop() { defer p.wg.Done() defer func() { close(p.acquireDoneCh) }() ctx := p.closeContext - for { + for retrier := retry.New(10*time.Millisecond, 1*time.Second); retrier.Wait(ctx); { if p.acquireExit() { return } @@ -299,7 +300,17 @@ func (p *Server) acquireLoop() { p.opts.Logger.Debug(ctx, "shut down before client (re) connected") return } - p.acquireAndRunOne(client) + err := p.acquireAndRunOne(client) + if err != nil && ctx.Err() == nil { // Only log if context is not done. + // Short-circuit: don't wait for the retry delay to exit, if required. + if p.acquireExit() { + return + } + p.opts.Logger.Warn(ctx, "failed to acquire job, retrying", slog.F("delay", fmt.Sprintf("%vms", retrier.Delay.Milliseconds())), slog.Error(err)) + } else { + // Reset the retrier after each successful acquisition. 
+ retrier.Reset() + } } } @@ -318,7 +329,7 @@ func (p *Server) acquireExit() bool { return false } -func (p *Server) acquireAndRunOne(client proto.DRPCProvisionerDaemonClient) { +func (p *Server) acquireAndRunOne(client proto.DRPCProvisionerDaemonClient) error { ctx := p.closeContext p.opts.Logger.Debug(ctx, "start of acquireAndRunOne") job, err := p.acquireGraceful(client) @@ -327,15 +338,15 @@ func (p *Server) acquireAndRunOne(client proto.DRPCProvisionerDaemonClient) { if errors.Is(err, context.Canceled) || errors.Is(err, yamux.ErrSessionShutdown) || errors.Is(err, fasthttputil.ErrInmemoryListenerClosed) { - return + return err } p.opts.Logger.Warn(ctx, "provisionerd was unable to acquire job", slog.Error(err)) - return + return xerrors.Errorf("failed to acquire job: %w", err) } if job.JobId == "" { p.opts.Logger.Debug(ctx, "acquire job successfully canceled") - return + return nil } if len(job.TraceMetadata) > 0 { @@ -390,9 +401,9 @@ func (p *Server) acquireAndRunOne(client proto.DRPCProvisionerDaemonClient) { Error: fmt.Sprintf("failed to connect to provisioner: %s", resp.Error), }) if err != nil { - p.opts.Logger.Error(ctx, "provisioner job failed", slog.F("job_id", job.JobId), slog.Error(err)) + p.opts.Logger.Error(ctx, "failed to report provisioner job failed", slog.F("job_id", job.JobId), slog.Error(err)) } - return + return xerrors.Errorf("failed to report provisioner job failed: %w", err) } p.mutex.Lock() @@ -416,6 +427,7 @@ func (p *Server) acquireAndRunOne(client proto.DRPCProvisionerDaemonClient) { p.mutex.Lock() p.activeJob = nil p.mutex.Unlock() + return nil } // acquireGraceful attempts to acquire a job from the server, handling canceling the acquisition if we gracefully shut diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index 8c350d8f5bc31..abf2b06f1c24f 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -585,6 +585,9 @@ export interface DangerousConfig { readonly allow_all_cors: boolean; } +// From codersdk/database.go +export const DatabaseNotReachable = "database not reachable"; + // From healthsdk/healthsdk.go export interface DatabaseReport extends BaseReport { readonly healthy: boolean; diff --git a/tailnet/controllers.go b/tailnet/controllers.go index bf2ec1d964f56..5afce5b32304e 100644 --- a/tailnet/controllers.go +++ b/tailnet/controllers.go @@ -2,6 +2,7 @@ package tailnet import ( "context" + "errors" "fmt" "io" "maps" @@ -20,11 +21,12 @@ import ( "tailscale.com/util/dnsname" "cdr.dev/slog" + "github.com/coder/quartz" + "github.com/coder/retry" + "github.com/coder/coder/v2/coderd/util/ptr" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/tailnet/proto" - "github.com/coder/quartz" - "github.com/coder/retry" ) // A Controller connects to the tailnet control plane, and then uses the control protocols to @@ -1362,6 +1364,14 @@ func (c *Controller) Run(ctx context.Context) { if xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) { return } + + // If the database is unreachable by the control plane, there's not much we can do, so we'll just retry later. 
+ if errors.Is(err, codersdk.ErrDatabaseNotReachable) { + c.logger.Warn(c.ctx, "control plane lost connection to database, retrying", + slog.Error(err), slog.F("delay", fmt.Sprintf("%vms", retrier.Delay.Milliseconds()))) + continue + } + errF := slog.Error(err) var sdkErr *codersdk.Error if xerrors.As(err, &sdkErr) { From 550b81f1ea8d9eb4f09d879e34c8f2a84f428daa Mon Sep 17 00:00:00 2001 From: "gcp-cherry-pick-bot[bot]" <98988430+gcp-cherry-pick-bot[bot]@users.noreply.github.com> Date: Tue, 22 Apr 2025 09:17:54 +0200 Subject: [PATCH 05/10] feat: add path & method labels to prometheus metrics for current requests (cherry-pick #17362) (#17493) Cherry-picked feat: add path & method labels to prometheus metrics for current requests (#17362) Closes: #17212 Co-authored-by: Michael Suchacz <203725896+ibetitsmike@users.noreply.github.com> --- coderd/httpmw/prometheus.go | 52 ++++++++++++++---- coderd/httpmw/prometheus_test.go | 91 ++++++++++++++++++++++++++++++++ 2 files changed, 133 insertions(+), 10 deletions(-) diff --git a/coderd/httpmw/prometheus.go b/coderd/httpmw/prometheus.go index b96be84e879e3..8b7b33381c74d 100644 --- a/coderd/httpmw/prometheus.go +++ b/coderd/httpmw/prometheus.go @@ -3,6 +3,7 @@ package httpmw import ( "net/http" "strconv" + "strings" "time" "github.com/go-chi/chi/v5" @@ -22,18 +23,18 @@ func Prometheus(register prometheus.Registerer) func(http.Handler) http.Handler Name: "requests_processed_total", Help: "The total number of processed API requests", }, []string{"code", "method", "path"}) - requestsConcurrent := factory.NewGauge(prometheus.GaugeOpts{ + requestsConcurrent := factory.NewGaugeVec(prometheus.GaugeOpts{ Namespace: "coderd", Subsystem: "api", Name: "concurrent_requests", Help: "The number of concurrent API requests.", - }) - websocketsConcurrent := factory.NewGauge(prometheus.GaugeOpts{ + }, []string{"method", "path"}) + websocketsConcurrent := factory.NewGaugeVec(prometheus.GaugeOpts{ Namespace: "coderd", Subsystem: "api", Name: "concurrent_websockets", Help: "The total number of concurrent API websockets.", - }) + }, []string{"path"}) websocketsDist := factory.NewHistogramVec(prometheus.HistogramOpts{ Namespace: "coderd", Subsystem: "api", @@ -61,7 +62,6 @@ func Prometheus(register prometheus.Registerer) func(http.Handler) http.Handler var ( start = time.Now() method = r.Method - rctx = chi.RouteContext(r.Context()) ) sw, ok := w.(*tracing.StatusWriter) @@ -72,16 +72,18 @@ func Prometheus(register prometheus.Registerer) func(http.Handler) http.Handler var ( dist *prometheus.HistogramVec distOpts []string + path = getRoutePattern(r) ) + // We want to count WebSockets separately. 
if httpapi.IsWebsocketUpgrade(r) { - websocketsConcurrent.Inc() - defer websocketsConcurrent.Dec() + websocketsConcurrent.WithLabelValues(path).Inc() + defer websocketsConcurrent.WithLabelValues(path).Dec() dist = websocketsDist } else { - requestsConcurrent.Inc() - defer requestsConcurrent.Dec() + requestsConcurrent.WithLabelValues(method, path).Inc() + defer requestsConcurrent.WithLabelValues(method, path).Dec() dist = requestsDist distOpts = []string{method} @@ -89,7 +91,6 @@ func Prometheus(register prometheus.Registerer) func(http.Handler) http.Handler next.ServeHTTP(w, r) - path := rctx.RoutePattern() distOpts = append(distOpts, path) statusStr := strconv.Itoa(sw.Status) @@ -98,3 +99,34 @@ func Prometheus(register prometheus.Registerer) func(http.Handler) http.Handler }) } } + +func getRoutePattern(r *http.Request) string { + rctx := chi.RouteContext(r.Context()) + if rctx == nil { + return "" + } + + if pattern := rctx.RoutePattern(); pattern != "" { + // Pattern is already available + return pattern + } + + routePath := r.URL.Path + if r.URL.RawPath != "" { + routePath = r.URL.RawPath + } + + tctx := chi.NewRouteContext() + routes := rctx.Routes + if routes != nil && !routes.Match(tctx, r.Method, routePath) { + // No matching pattern. /api/* requests will be matched as "UNKNOWN" + // All other ones will be matched as "STATIC". + if strings.HasPrefix(routePath, "/api/") { + return "UNKNOWN" + } + return "STATIC" + } + + // tctx has the updated pattern, since Match mutates it + return tctx.RoutePattern() +} diff --git a/coderd/httpmw/prometheus_test.go b/coderd/httpmw/prometheus_test.go index a51eea5d00312..d40558f5ca5e7 100644 --- a/coderd/httpmw/prometheus_test.go +++ b/coderd/httpmw/prometheus_test.go @@ -8,14 +8,19 @@ import ( "github.com/go-chi/chi/v5" "github.com/prometheus/client_golang/prometheus" + cm "github.com/prometheus/client_model/go" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/coder/coder/v2/coderd/httpmw" "github.com/coder/coder/v2/coderd/tracing" + "github.com/coder/coder/v2/testutil" + "github.com/coder/websocket" ) func TestPrometheus(t *testing.T) { t.Parallel() + t.Run("All", func(t *testing.T) { t.Parallel() req := httptest.NewRequest("GET", "/", nil) @@ -29,4 +34,90 @@ func TestPrometheus(t *testing.T) { require.NoError(t, err) require.Greater(t, len(metrics), 0) }) + + t.Run("Concurrent", func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + + reg := prometheus.NewRegistry() + promMW := httpmw.Prometheus(reg) + + // Create a test handler to simulate a WebSocket connection + testHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(rw, r, nil) + if !assert.NoError(t, err, "failed to accept websocket") { + return + } + defer conn.Close(websocket.StatusGoingAway, "") + }) + + wrappedHandler := promMW(testHandler) + + r := chi.NewRouter() + r.Use(tracing.StatusWriterMiddleware, promMW) + r.Get("/api/v2/build/{build}/logs", func(rw http.ResponseWriter, r *http.Request) { + wrappedHandler.ServeHTTP(rw, r) + }) + + srv := httptest.NewServer(r) + defer srv.Close() + // nolint: bodyclose + conn, _, err := websocket.Dial(ctx, srv.URL+"/api/v2/build/1/logs", nil) + require.NoError(t, err, "failed to dial WebSocket") + defer conn.Close(websocket.StatusNormalClosure, "") + + metrics, err := reg.Gather() + require.NoError(t, err) + require.Greater(t, len(metrics), 0) + metricLabels := 
getMetricLabels(metrics) + + concurrentWebsockets, ok := metricLabels["coderd_api_concurrent_websockets"] + require.True(t, ok, "coderd_api_concurrent_websockets metric not found") + require.Equal(t, "/api/v2/build/{build}/logs", concurrentWebsockets["path"]) + }) + + t.Run("UserRoute", func(t *testing.T) { + t.Parallel() + reg := prometheus.NewRegistry() + promMW := httpmw.Prometheus(reg) + + r := chi.NewRouter() + r.With(promMW).Get("/api/v2/users/{user}", func(w http.ResponseWriter, r *http.Request) {}) + + req := httptest.NewRequest("GET", "/api/v2/users/john", nil) + + sw := &tracing.StatusWriter{ResponseWriter: httptest.NewRecorder()} + + r.ServeHTTP(sw, req) + + metrics, err := reg.Gather() + require.NoError(t, err) + require.Greater(t, len(metrics), 0) + metricLabels := getMetricLabels(metrics) + + reqProcessed, ok := metricLabels["coderd_api_requests_processed_total"] + require.True(t, ok, "coderd_api_requests_processed_total metric not found") + require.Equal(t, "/api/v2/users/{user}", reqProcessed["path"]) + require.Equal(t, "GET", reqProcessed["method"]) + + concurrentRequests, ok := metricLabels["coderd_api_concurrent_requests"] + require.True(t, ok, "coderd_api_concurrent_requests metric not found") + require.Equal(t, "/api/v2/users/{user}", concurrentRequests["path"]) + require.Equal(t, "GET", concurrentRequests["method"]) + }) +} + +func getMetricLabels(metrics []*cm.MetricFamily) map[string]map[string]string { + metricLabels := map[string]map[string]string{} + for _, metricFamily := range metrics { + metricName := metricFamily.GetName() + metricLabels[metricName] = map[string]string{} + for _, metric := range metricFamily.GetMetric() { + for _, labelPair := range metric.GetLabel() { + metricLabels[metricName][labelPair.GetName()] = labelPair.GetValue() + } + } + } + return metricLabels } From fc8454b4759368d35b6fb0b688f1af95cb17b47e Mon Sep 17 00:00:00 2001 From: Michael Suchacz <203725896+ibetitsmike@users.noreply.github.com> Date: Tue, 22 Apr 2025 09:56:03 +0200 Subject: [PATCH 06/10] feat: extend request logs with auth & DB info (#17498) --- Makefile | 7 +- coderd/coderd.go | 3 +- coderd/database/dbauthz/dbauthz.go | 31 +- coderd/database/queries.sql.go | 6 +- coderd/database/queries/users.sql | 4 +- coderd/httpmw/apikey.go | 2 + coderd/httpmw/logger.go | 76 ----- coderd/httpmw/loggermw/logger.go | 203 ++++++++++++ .../httpmw/loggermw/logger_internal_test.go | 311 ++++++++++++++++++ .../httpmw/loggermw/loggermock/loggermock.go | 83 +++++ coderd/httpmw/workspaceagentparam.go | 11 + coderd/httpmw/workspaceparam.go | 15 + coderd/provisionerjobs.go | 4 + coderd/provisionerjobs_internal_test.go | 7 + coderd/rbac/authz.go | 25 ++ coderd/workspaceagents.go | 10 + enterprise/coderd/provisionerdaemons.go | 5 + enterprise/wsproxy/wsproxy.go | 3 +- tailnet/test/integration/integration.go | 4 +- 19 files changed, 716 insertions(+), 94 deletions(-) delete mode 100644 coderd/httpmw/logger.go create mode 100644 coderd/httpmw/loggermw/logger.go create mode 100644 coderd/httpmw/loggermw/logger_internal_test.go create mode 100644 coderd/httpmw/loggermw/loggermock/loggermock.go diff --git a/Makefile b/Makefile index fbd324974f218..962d5bd71221e 100644 --- a/Makefile +++ b/Makefile @@ -564,7 +564,8 @@ GEN_FILES := \ examples/examples.gen.json \ $(TAILNETTEST_MOCKS) \ coderd/database/pubsub/psmock/psmock.go \ - agent/agentcontainers/acmock/acmock.go + agent/agentcontainers/acmock/acmock.go \ + coderd/httpmw/loggermw/loggermock/loggermock.go # all gen targets should be added here and to 
gen/mark-fresh @@ -600,6 +601,7 @@ gen/mark-fresh: $(TAILNETTEST_MOCKS) \ coderd/database/pubsub/psmock/psmock.go \ agent/agentcontainers/acmock/acmock.go \ + coderd/httpmw/loggermw/loggermock/loggermock.go " for file in $$files; do @@ -634,6 +636,9 @@ coderd/database/pubsub/psmock/psmock.go: coderd/database/pubsub/pubsub.go agent/agentcontainers/acmock/acmock.go: agent/agentcontainers/containers.go go generate ./agent/agentcontainers/acmock/ +coderd/httpmw/loggermw/loggermock/loggermock.go: coderd/httpmw/loggermw/logger.go + go generate ./coderd/httpmw/loggermw/loggermock/ + $(TAILNETTEST_MOCKS): tailnet/coordinator.go tailnet/service.go go generate ./tailnet/tailnettest/ diff --git a/coderd/coderd.go b/coderd/coderd.go index d2e0835f3c877..ebf2000a3d112 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -63,6 +63,7 @@ import ( "github.com/coder/coder/v2/coderd/healthcheck/derphealth" "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/httpmw" + "github.com/coder/coder/v2/coderd/httpmw/loggermw" "github.com/coder/coder/v2/coderd/metricscache" "github.com/coder/coder/v2/coderd/notifications" "github.com/coder/coder/v2/coderd/portsharing" @@ -788,7 +789,7 @@ func New(options *Options) *API { tracing.Middleware(api.TracerProvider), httpmw.AttachRequestID, httpmw.ExtractRealIP(api.RealIPConfig), - httpmw.Logger(api.Logger), + loggermw.Logger(api.Logger), singleSlashMW, rolestore.CustomRoleMW, prometheusMW, diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index faebb809f4599..a633c976c6494 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -24,6 +24,7 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/httpapi/httpapiconstraints" + "github.com/coder/coder/v2/coderd/httpmw/loggermw" "github.com/coder/coder/v2/coderd/rbac" "github.com/coder/coder/v2/coderd/util/slice" "github.com/coder/coder/v2/provisionersdk" @@ -162,6 +163,7 @@ func ActorFromContext(ctx context.Context) (rbac.Subject, bool) { var ( subjectProvisionerd = rbac.Subject{ + Type: rbac.SubjectTypeProvisionerd, FriendlyName: "Provisioner Daemon", ID: uuid.Nil.String(), Roles: rbac.Roles([]rbac.Role{ @@ -195,6 +197,7 @@ var ( }.WithCachedASTValue() subjectAutostart = rbac.Subject{ + Type: rbac.SubjectTypeAutostart, FriendlyName: "Autostart", ID: uuid.Nil.String(), Roles: rbac.Roles([]rbac.Role{ @@ -218,6 +221,7 @@ var ( // See unhanger package. subjectHangDetector = rbac.Subject{ + Type: rbac.SubjectTypeHangDetector, FriendlyName: "Hang Detector", ID: uuid.Nil.String(), Roles: rbac.Roles([]rbac.Role{ @@ -238,6 +242,7 @@ var ( // See cryptokeys package. subjectCryptoKeyRotator = rbac.Subject{ + Type: rbac.SubjectTypeCryptoKeyRotator, FriendlyName: "Crypto Key Rotator", ID: uuid.Nil.String(), Roles: rbac.Roles([]rbac.Role{ @@ -256,6 +261,7 @@ var ( // See cryptokeys package. 
subjectCryptoKeyReader = rbac.Subject{ + Type: rbac.SubjectTypeCryptoKeyReader, FriendlyName: "Crypto Key Reader", ID: uuid.Nil.String(), Roles: rbac.Roles([]rbac.Role{ @@ -273,6 +279,7 @@ var ( }.WithCachedASTValue() subjectNotifier = rbac.Subject{ + Type: rbac.SubjectTypeNotifier, FriendlyName: "Notifier", ID: uuid.Nil.String(), Roles: rbac.Roles([]rbac.Role{ @@ -290,6 +297,7 @@ var ( }.WithCachedASTValue() subjectResourceMonitor = rbac.Subject{ + Type: rbac.SubjectTypeResourceMonitor, FriendlyName: "Resource Monitor", ID: uuid.Nil.String(), Roles: rbac.Roles([]rbac.Role{ @@ -308,6 +316,7 @@ var ( }.WithCachedASTValue() subjectSystemRestricted = rbac.Subject{ + Type: rbac.SubjectTypeSystemRestricted, FriendlyName: "System", ID: uuid.Nil.String(), Roles: rbac.Roles([]rbac.Role{ @@ -342,6 +351,7 @@ var ( }.WithCachedASTValue() subjectSystemReadProvisionerDaemons = rbac.Subject{ + Type: rbac.SubjectTypeSystemReadProvisionerDaemons, FriendlyName: "Provisioner Daemons Reader", ID: uuid.Nil.String(), Roles: rbac.Roles([]rbac.Role{ @@ -362,53 +372,53 @@ var ( // AsProvisionerd returns a context with an actor that has permissions required // for provisionerd to function. func AsProvisionerd(ctx context.Context) context.Context { - return context.WithValue(ctx, authContextKey{}, subjectProvisionerd) + return As(ctx, subjectProvisionerd) } // AsAutostart returns a context with an actor that has permissions required // for autostart to function. func AsAutostart(ctx context.Context) context.Context { - return context.WithValue(ctx, authContextKey{}, subjectAutostart) + return As(ctx, subjectAutostart) } // AsHangDetector returns a context with an actor that has permissions required // for unhanger.Detector to function. func AsHangDetector(ctx context.Context) context.Context { - return context.WithValue(ctx, authContextKey{}, subjectHangDetector) + return As(ctx, subjectHangDetector) } // AsKeyRotator returns a context with an actor that has permissions required for rotating crypto keys. func AsKeyRotator(ctx context.Context) context.Context { - return context.WithValue(ctx, authContextKey{}, subjectCryptoKeyRotator) + return As(ctx, subjectCryptoKeyRotator) } // AsKeyReader returns a context with an actor that has permissions required for reading crypto keys. func AsKeyReader(ctx context.Context) context.Context { - return context.WithValue(ctx, authContextKey{}, subjectCryptoKeyReader) + return As(ctx, subjectCryptoKeyReader) } // AsNotifier returns a context with an actor that has permissions required for // creating/reading/updating/deleting notifications. func AsNotifier(ctx context.Context) context.Context { - return context.WithValue(ctx, authContextKey{}, subjectNotifier) + return As(ctx, subjectNotifier) } // AsResourceMonitor returns a context with an actor that has permissions required for // updating resource monitors. func AsResourceMonitor(ctx context.Context) context.Context { - return context.WithValue(ctx, authContextKey{}, subjectResourceMonitor) + return As(ctx, subjectResourceMonitor) } // AsSystemRestricted returns a context with an actor that has permissions // required for various system operations (login, logout, metrics cache). func AsSystemRestricted(ctx context.Context) context.Context { - return context.WithValue(ctx, authContextKey{}, subjectSystemRestricted) + return As(ctx, subjectSystemRestricted) } // AsSystemReadProvisionerDaemons returns a context with an actor that has permissions // to read provisioner daemons. 
func AsSystemReadProvisionerDaemons(ctx context.Context) context.Context { - return context.WithValue(ctx, authContextKey{}, subjectSystemReadProvisionerDaemons) + return As(ctx, subjectSystemReadProvisionerDaemons) } var AsRemoveActor = rbac.Subject{ @@ -426,6 +436,9 @@ func As(ctx context.Context, actor rbac.Subject) context.Context { // should be removed from the context. return context.WithValue(ctx, authContextKey{}, nil) } + if rlogger := loggermw.RequestLoggerFromContext(ctx); rlogger != nil { + rlogger.WithAuthContext(actor) + } return context.WithValue(ctx, authContextKey{}, actor) } diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 37c79537c622f..638124057f58b 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -11033,10 +11033,10 @@ func (q *sqlQuerier) GetActiveUserCount(ctx context.Context) (int64, error) { const getAuthorizationUserRoles = `-- name: GetAuthorizationUserRoles :one SELECT - -- username is returned just to help for logging purposes + -- username and email are returned just to help for logging purposes -- status is used to enforce 'suspended' users, as all roles are ignored -- when suspended. - id, username, status, + id, username, status, email, -- All user roles, including their org roles. array_cat( -- All users are members @@ -11077,6 +11077,7 @@ type GetAuthorizationUserRolesRow struct { ID uuid.UUID `db:"id" json:"id"` Username string `db:"username" json:"username"` Status UserStatus `db:"status" json:"status"` + Email string `db:"email" json:"email"` Roles []string `db:"roles" json:"roles"` Groups []string `db:"groups" json:"groups"` } @@ -11090,6 +11091,7 @@ func (q *sqlQuerier) GetAuthorizationUserRoles(ctx context.Context, userID uuid. &i.ID, &i.Username, &i.Status, + &i.Email, pq.Array(&i.Roles), pq.Array(&i.Groups), ) diff --git a/coderd/database/queries/users.sql b/coderd/database/queries/users.sql index 1f30a2c2c1d24..5e9ffbe50cd31 100644 --- a/coderd/database/queries/users.sql +++ b/coderd/database/queries/users.sql @@ -244,10 +244,10 @@ WHERE -- This function returns roles for authorization purposes. Implied member roles -- are included. SELECT - -- username is returned just to help for logging purposes + -- username and email are returned just to help for logging purposes -- status is used to enforce 'suspended' users, as all roles are ignored -- when suspended. - id, username, status, + id, username, status, email, -- All user roles, including their org roles. 
array_cat( -- All users are members diff --git a/coderd/httpmw/apikey.go b/coderd/httpmw/apikey.go index 38ba74031ba46..2600962e39171 100644 --- a/coderd/httpmw/apikey.go +++ b/coderd/httpmw/apikey.go @@ -465,7 +465,9 @@ func UserRBACSubject(ctx context.Context, db database.Store, userID uuid.UUID, s } actor := rbac.Subject{ + Type: rbac.SubjectTypeUser, FriendlyName: roles.Username, + Email: roles.Email, ID: userID.String(), Roles: rbacRoles, Groups: roles.Groups, diff --git a/coderd/httpmw/logger.go b/coderd/httpmw/logger.go deleted file mode 100644 index 79e95cf859d8e..0000000000000 --- a/coderd/httpmw/logger.go +++ /dev/null @@ -1,76 +0,0 @@ -package httpmw - -import ( - "context" - "fmt" - "net/http" - "time" - - "cdr.dev/slog" - "github.com/coder/coder/v2/coderd/httpapi" - "github.com/coder/coder/v2/coderd/tracing" -) - -func Logger(log slog.Logger) func(next http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - start := time.Now() - - sw, ok := rw.(*tracing.StatusWriter) - if !ok { - panic(fmt.Sprintf("ResponseWriter not a *tracing.StatusWriter; got %T", rw)) - } - - httplog := log.With( - slog.F("host", httpapi.RequestHost(r)), - slog.F("path", r.URL.Path), - slog.F("proto", r.Proto), - slog.F("remote_addr", r.RemoteAddr), - // Include the start timestamp in the log so that we have the - // source of truth. There is at least a theoretical chance that - // there can be a delay between `next.ServeHTTP` ending and us - // actually logging the request. This can also be useful when - // filtering logs that started at a certain time (compared to - // trying to compute the value). - slog.F("start", start), - ) - - next.ServeHTTP(sw, r) - - end := time.Now() - - // Don't log successful health check requests. - if r.URL.Path == "/api/v2" && sw.Status == http.StatusOK { - return - } - - httplog = httplog.With( - slog.F("took", end.Sub(start)), - slog.F("status_code", sw.Status), - slog.F("latency_ms", float64(end.Sub(start)/time.Millisecond)), - ) - - // For status codes 400 and higher we - // want to log the response body. - if sw.Status >= http.StatusInternalServerError { - httplog = httplog.With( - slog.F("response_body", string(sw.ResponseBody())), - ) - } - - // We should not log at level ERROR for 5xx status codes because 5xx - // includes proxy errors etc. It also causes slogtest to fail - // instantly without an error message by default. - logLevelFn := httplog.Debug - if sw.Status >= http.StatusInternalServerError { - logLevelFn = httplog.Warn - } - - // We already capture most of this information in the span (minus - // the response body which we don't want to capture anyways). 
- tracing.RunWithoutSpan(r.Context(), func(ctx context.Context) { - logLevelFn(ctx, r.Method) - }) - }) - } -} diff --git a/coderd/httpmw/loggermw/logger.go b/coderd/httpmw/loggermw/logger.go new file mode 100644 index 0000000000000..9eeb07a5f10e5 --- /dev/null +++ b/coderd/httpmw/loggermw/logger.go @@ -0,0 +1,203 @@ +package loggermw + +import ( + "context" + "fmt" + "net/http" + "sync" + "time" + + "github.com/go-chi/chi/v5" + + "cdr.dev/slog" + "github.com/coder/coder/v2/coderd/httpapi" + "github.com/coder/coder/v2/coderd/rbac" + "github.com/coder/coder/v2/coderd/tracing" +) + +func Logger(log slog.Logger) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + start := time.Now() + + sw, ok := rw.(*tracing.StatusWriter) + if !ok { + panic(fmt.Sprintf("ResponseWriter not a *tracing.StatusWriter; got %T", rw)) + } + + httplog := log.With( + slog.F("host", httpapi.RequestHost(r)), + slog.F("path", r.URL.Path), + slog.F("proto", r.Proto), + slog.F("remote_addr", r.RemoteAddr), + // Include the start timestamp in the log so that we have the + // source of truth. There is at least a theoretical chance that + // there can be a delay between `next.ServeHTTP` ending and us + // actually logging the request. This can also be useful when + // filtering logs that started at a certain time (compared to + // trying to compute the value). + slog.F("start", start), + ) + + logContext := NewRequestLogger(httplog, r.Method, start) + + ctx := WithRequestLogger(r.Context(), logContext) + + next.ServeHTTP(sw, r.WithContext(ctx)) + + // Don't log successful health check requests. + if r.URL.Path == "/api/v2" && sw.Status == http.StatusOK { + return + } + + // For status codes 500 and higher we + // want to log the response body. + if sw.Status >= http.StatusInternalServerError { + logContext.WithFields( + slog.F("response_body", string(sw.ResponseBody())), + ) + } + + logContext.WriteLog(r.Context(), sw.Status) + }) + } +} + +type RequestLogger interface { + WithFields(fields ...slog.Field) + WriteLog(ctx context.Context, status int) + WithAuthContext(actor rbac.Subject) +} + +type SlogRequestLogger struct { + log slog.Logger + written bool + message string + start time.Time + // Protects actors map for concurrent writes. + mu sync.RWMutex + actors map[rbac.SubjectType]rbac.Subject +} + +var _ RequestLogger = &SlogRequestLogger{} + +func NewRequestLogger(log slog.Logger, message string, start time.Time) RequestLogger { + return &SlogRequestLogger{ + log: log, + written: false, + message: message, + start: start, + actors: make(map[rbac.SubjectType]rbac.Subject), + } +} + +func (c *SlogRequestLogger) WithFields(fields ...slog.Field) { + c.log = c.log.With(fields...) +} + +func (c *SlogRequestLogger) WithAuthContext(actor rbac.Subject) { + c.mu.Lock() + defer c.mu.Unlock() + c.actors[actor.Type] = actor +} + +func (c *SlogRequestLogger) addAuthContextFields() { + c.mu.RLock() + defer c.mu.RUnlock() + + usr, ok := c.actors[rbac.SubjectTypeUser] + if ok { + c.log = c.log.With( + slog.F("requestor_id", usr.ID), + slog.F("requestor_name", usr.FriendlyName), + slog.F("requestor_email", usr.Email), + ) + } else { + // If there is no user, we log the requestor name for the first + // actor in a defined order. 
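+ // Non-user actors generally share the nil UUID and carry no email,
+ // so only the friendly name is worth logging here.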
+ for _, v := range actorLogOrder { + subj, ok := c.actors[v] + if !ok { + continue + } + c.log = c.log.With( + slog.F("requestor_name", subj.FriendlyName), + ) + break + } + } +} + +var actorLogOrder = []rbac.SubjectType{ + rbac.SubjectTypeAutostart, + rbac.SubjectTypeCryptoKeyReader, + rbac.SubjectTypeCryptoKeyRotator, + rbac.SubjectTypeHangDetector, + rbac.SubjectTypeNotifier, + rbac.SubjectTypePrebuildsOrchestrator, + rbac.SubjectTypeProvisionerd, + rbac.SubjectTypeResourceMonitor, + rbac.SubjectTypeSystemReadProvisionerDaemons, + rbac.SubjectTypeSystemRestricted, +} + +func (c *SlogRequestLogger) WriteLog(ctx context.Context, status int) { + if c.written { + return + } + c.written = true + end := time.Now() + + // Right before we write the log, we try to find the user in the actors + // and add the fields to the log. + c.addAuthContextFields() + + logger := c.log.With( + slog.F("took", end.Sub(c.start)), + slog.F("status_code", status), + slog.F("latency_ms", float64(end.Sub(c.start)/time.Millisecond)), + ) + + // If the request is routed, add the route parameters to the log. + if chiCtx := chi.RouteContext(ctx); chiCtx != nil { + urlParams := chiCtx.URLParams + routeParamsFields := make([]slog.Field, 0, len(urlParams.Keys)) + + for k, v := range urlParams.Keys { + if urlParams.Values[k] != "" { + routeParamsFields = append(routeParamsFields, slog.F("params_"+v, urlParams.Values[k])) + } + } + + if len(routeParamsFields) > 0 { + logger = logger.With(routeParamsFields...) + } + } + + // We already capture most of this information in the span (minus + // the response body which we don't want to capture anyways). + tracing.RunWithoutSpan(ctx, func(ctx context.Context) { + // We should not log at level ERROR for 5xx status codes because 5xx + // includes proxy errors etc. It also causes slogtest to fail + // instantly without an error message by default. 
+ if status >= http.StatusInternalServerError { + logger.Warn(ctx, c.message) + } else { + logger.Debug(ctx, c.message) + } + }) +} + +type logContextKey struct{} + +func WithRequestLogger(ctx context.Context, rl RequestLogger) context.Context { + return context.WithValue(ctx, logContextKey{}, rl) +} + +func RequestLoggerFromContext(ctx context.Context) RequestLogger { + val := ctx.Value(logContextKey{}) + if logCtx, ok := val.(RequestLogger); ok { + return logCtx + } + return nil +} diff --git a/coderd/httpmw/loggermw/logger_internal_test.go b/coderd/httpmw/loggermw/logger_internal_test.go new file mode 100644 index 0000000000000..e88f8a69c178e --- /dev/null +++ b/coderd/httpmw/loggermw/logger_internal_test.go @@ -0,0 +1,311 @@ +package loggermw + +import ( + "context" + "net/http" + "net/http/httptest" + "slices" + "strings" + "sync" + "testing" + "time" + + "github.com/go-chi/chi/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "cdr.dev/slog" + "github.com/coder/coder/v2/coderd/tracing" + "github.com/coder/coder/v2/testutil" + "github.com/coder/websocket" +) + +func TestRequestLogger_WriteLog(t *testing.T) { + t.Parallel() + ctx := context.Background() + + sink := &fakeSink{} + logger := slog.Make(sink) + logger = logger.Leveled(slog.LevelDebug) + logCtx := NewRequestLogger(logger, "GET", time.Now()) + + // Add custom fields + logCtx.WithFields( + slog.F("custom_field", "custom_value"), + ) + + // Write log for 200 status + logCtx.WriteLog(ctx, http.StatusOK) + + require.Len(t, sink.entries, 1, "log was written twice") + + require.Equal(t, sink.entries[0].Message, "GET") + + require.Equal(t, sink.entries[0].Fields[0].Value, "custom_value") + + // Attempt to write again (should be skipped). + logCtx.WriteLog(ctx, http.StatusInternalServerError) + + require.Len(t, sink.entries, 1, "log was written twice") +} + +func TestLoggerMiddleware_SingleRequest(t *testing.T) { + t.Parallel() + + sink := &fakeSink{} + logger := slog.Make(sink) + logger = logger.Leveled(slog.LevelDebug) + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + + // Create a test handler to simulate an HTTP request + testHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + rw.WriteHeader(http.StatusOK) + _, _ = rw.Write([]byte("OK")) + }) + + // Wrap the test handler with the Logger middleware + loggerMiddleware := Logger(logger) + wrappedHandler := loggerMiddleware(testHandler) + + // Create a test HTTP request + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "/test-path", nil) + require.NoError(t, err, "failed to create request") + + sw := &tracing.StatusWriter{ResponseWriter: httptest.NewRecorder()} + + // Serve the request + wrappedHandler.ServeHTTP(sw, req) + + require.Len(t, sink.entries, 1, "log was written twice") + + require.Equal(t, sink.entries[0].Message, "GET") + + fieldsMap := make(map[string]any) + for _, field := range sink.entries[0].Fields { + fieldsMap[field.Name] = field.Value + } + + // Check that the log contains the expected fields + requiredFields := []string{"host", "path", "proto", "remote_addr", "start", "took", "status_code", "latency_ms"} + for _, field := range requiredFields { + _, exists := fieldsMap[field] + require.True(t, exists, "field %q is missing in log fields", field) + } + + require.Len(t, sink.entries[0].Fields, len(requiredFields), "log should contain only the required fields") + + // Check value of the status code + require.Equal(t, fieldsMap["status_code"], 
http.StatusOK) +} + +func TestLoggerMiddleware_WebSocket(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + + sink := &fakeSink{ + newEntries: make(chan slog.SinkEntry, 2), + } + logger := slog.Make(sink) + logger = logger.Leveled(slog.LevelDebug) + done := make(chan struct{}) + wg := sync.WaitGroup{} + // Create a test handler to simulate a WebSocket connection + testHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(rw, r, nil) + if !assert.NoError(t, err, "failed to accept websocket") { + return + } + defer conn.Close(websocket.StatusGoingAway, "") + + requestLgr := RequestLoggerFromContext(r.Context()) + requestLgr.WriteLog(r.Context(), http.StatusSwitchingProtocols) + // Block so we can be sure the end of the middleware isn't being called. + wg.Wait() + }) + + // Wrap the test handler with the Logger middleware + loggerMiddleware := Logger(logger) + wrappedHandler := loggerMiddleware(testHandler) + + // RequestLogger expects the ResponseWriter to be *tracing.StatusWriter + customHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + defer close(done) + sw := &tracing.StatusWriter{ResponseWriter: rw} + wrappedHandler.ServeHTTP(sw, r) + }) + + srv := httptest.NewServer(customHandler) + defer srv.Close() + wg.Add(1) + // nolint: bodyclose + conn, _, err := websocket.Dial(ctx, srv.URL, nil) + require.NoError(t, err, "failed to dial WebSocket") + defer conn.Close(websocket.StatusNormalClosure, "") + + // Wait for the log from within the handler + newEntry := testutil.RequireRecvCtx(ctx, t, sink.newEntries) + require.Equal(t, newEntry.Message, "GET") + + // Signal the websocket handler to return (and read to handle the close frame) + wg.Done() + _, _, err = conn.Read(ctx) + require.ErrorAs(t, err, &websocket.CloseError{}, "websocket read should fail with close error") + + // Wait for the request to finish completely and verify we only logged once + _ = testutil.RequireRecvCtx(ctx, t, done) + require.Len(t, sink.entries, 1, "log was written twice") +} + +func TestRequestLogger_HTTPRouteParams(t *testing.T) { + t.Parallel() + + sink := &fakeSink{} + logger := slog.Make(sink) + logger = logger.Leveled(slog.LevelDebug) + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("workspace", "test-workspace") + chiCtx.URLParams.Add("agent", "test-agent") + + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + + // Create a test handler to simulate an HTTP request + testHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + rw.WriteHeader(http.StatusOK) + _, _ = rw.Write([]byte("OK")) + }) + + // Wrap the test handler with the Logger middleware + loggerMiddleware := Logger(logger) + wrappedHandler := loggerMiddleware(testHandler) + + // Create a test HTTP request + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "/test-path/}", nil) + require.NoError(t, err, "failed to create request") + + sw := &tracing.StatusWriter{ResponseWriter: httptest.NewRecorder()} + + // Serve the request + wrappedHandler.ServeHTTP(sw, req) + + fieldsMap := make(map[string]any) + for _, field := range sink.entries[0].Fields { + fieldsMap[field.Name] = field.Value + } + + // Check that the log contains the expected fields + requiredFields := []string{"workspace", "agent"} + for _, field := range requiredFields { + _, exists := 
fieldsMap["params_"+field] + require.True(t, exists, "field %q is missing in log fields", field) + } +} + +func TestRequestLogger_RouteParamsLogging(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + params map[string]string + expectedFields []string + }{ + { + name: "EmptyParams", + params: map[string]string{}, + expectedFields: []string{}, + }, + { + name: "SingleParam", + params: map[string]string{ + "workspace": "test-workspace", + }, + expectedFields: []string{"params_workspace"}, + }, + { + name: "MultipleParams", + params: map[string]string{ + "workspace": "test-workspace", + "agent": "test-agent", + "user": "test-user", + }, + expectedFields: []string{"params_workspace", "params_agent", "params_user"}, + }, + { + name: "EmptyValueParam", + params: map[string]string{ + "workspace": "test-workspace", + "agent": "", + }, + expectedFields: []string{"params_workspace"}, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + sink := &fakeSink{} + logger := slog.Make(sink) + logger = logger.Leveled(slog.LevelDebug) + + // Create a route context with the test parameters + chiCtx := chi.NewRouteContext() + for key, value := range tt.params { + chiCtx.URLParams.Add(key, value) + } + + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + logCtx := NewRequestLogger(logger, "GET", time.Now()) + + // Write the log + logCtx.WriteLog(ctx, http.StatusOK) + + require.Len(t, sink.entries, 1, "expected exactly one log entry") + + // Convert fields to map for easier checking + fieldsMap := make(map[string]any) + for _, field := range sink.entries[0].Fields { + fieldsMap[field.Name] = field.Value + } + + // Verify expected fields are present + for _, field := range tt.expectedFields { + value, exists := fieldsMap[field] + require.True(t, exists, "field %q should be present in log", field) + require.Equal(t, tt.params[strings.TrimPrefix(field, "params_")], value, "field %q has incorrect value", field) + } + + // Verify no unexpected fields are present + for field := range fieldsMap { + if field == "took" || field == "status_code" || field == "latency_ms" { + continue // Skip standard fields + } + require.True(t, slices.Contains(tt.expectedFields, field), "unexpected field %q in log", field) + } + }) + } +} + +type fakeSink struct { + entries []slog.SinkEntry + newEntries chan slog.SinkEntry +} + +func (s *fakeSink) LogEntry(_ context.Context, e slog.SinkEntry) { + s.entries = append(s.entries, e) + if s.newEntries != nil { + select { + case s.newEntries <- e: + default: + } + } +} + +func (*fakeSink) Sync() {} diff --git a/coderd/httpmw/loggermw/loggermock/loggermock.go b/coderd/httpmw/loggermw/loggermock/loggermock.go new file mode 100644 index 0000000000000..008f862107ae6 --- /dev/null +++ b/coderd/httpmw/loggermw/loggermock/loggermock.go @@ -0,0 +1,83 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/coder/coder/v2/coderd/httpmw/loggermw (interfaces: RequestLogger) +// +// Generated by this command: +// +// mockgen -destination=loggermock/loggermock.go -package=loggermock . RequestLogger +// + +// Package loggermock is a generated GoMock package. +package loggermock + +import ( + context "context" + reflect "reflect" + + slog "cdr.dev/slog" + rbac "github.com/coder/coder/v2/coderd/rbac" + gomock "go.uber.org/mock/gomock" +) + +// MockRequestLogger is a mock of RequestLogger interface. 
+type MockRequestLogger struct { + ctrl *gomock.Controller + recorder *MockRequestLoggerMockRecorder + isgomock struct{} +} + +// MockRequestLoggerMockRecorder is the mock recorder for MockRequestLogger. +type MockRequestLoggerMockRecorder struct { + mock *MockRequestLogger +} + +// NewMockRequestLogger creates a new mock instance. +func NewMockRequestLogger(ctrl *gomock.Controller) *MockRequestLogger { + mock := &MockRequestLogger{ctrl: ctrl} + mock.recorder = &MockRequestLoggerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockRequestLogger) EXPECT() *MockRequestLoggerMockRecorder { + return m.recorder +} + +// WithAuthContext mocks base method. +func (m *MockRequestLogger) WithAuthContext(actor rbac.Subject) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "WithAuthContext", actor) +} + +// WithAuthContext indicates an expected call of WithAuthContext. +func (mr *MockRequestLoggerMockRecorder) WithAuthContext(actor any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithAuthContext", reflect.TypeOf((*MockRequestLogger)(nil).WithAuthContext), actor) +} + +// WithFields mocks base method. +func (m *MockRequestLogger) WithFields(fields ...slog.Field) { + m.ctrl.T.Helper() + varargs := []any{} + for _, a := range fields { + varargs = append(varargs, a) + } + m.ctrl.Call(m, "WithFields", varargs...) +} + +// WithFields indicates an expected call of WithFields. +func (mr *MockRequestLoggerMockRecorder) WithFields(fields ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithFields", reflect.TypeOf((*MockRequestLogger)(nil).WithFields), fields...) +} + +// WriteLog mocks base method. +func (m *MockRequestLogger) WriteLog(ctx context.Context, status int) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "WriteLog", ctx, status) +} + +// WriteLog indicates an expected call of WriteLog. 
+func (mr *MockRequestLoggerMockRecorder) WriteLog(ctx, status any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteLog", reflect.TypeOf((*MockRequestLogger)(nil).WriteLog), ctx, status) +} diff --git a/coderd/httpmw/workspaceagentparam.go b/coderd/httpmw/workspaceagentparam.go index a47ce3c377ae0..434e057c0eccc 100644 --- a/coderd/httpmw/workspaceagentparam.go +++ b/coderd/httpmw/workspaceagentparam.go @@ -6,8 +6,11 @@ import ( "github.com/go-chi/chi/v5" + "cdr.dev/slog" + "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/httpapi" + "github.com/coder/coder/v2/coderd/httpmw/loggermw" "github.com/coder/coder/v2/codersdk" ) @@ -81,6 +84,14 @@ func ExtractWorkspaceAgentParam(db database.Store) func(http.Handler) http.Handl ctx = context.WithValue(ctx, workspaceAgentParamContextKey{}, agent) chi.RouteContext(ctx).URLParams.Add("workspace", build.WorkspaceID.String()) + + if rlogger := loggermw.RequestLoggerFromContext(ctx); rlogger != nil { + rlogger.WithFields( + slog.F("workspace_name", resource.Name), + slog.F("agent_name", agent.Name), + ) + } + next.ServeHTTP(rw, r.WithContext(ctx)) }) } diff --git a/coderd/httpmw/workspaceparam.go b/coderd/httpmw/workspaceparam.go index 21e8dcfd62863..0c4e4f77354fc 100644 --- a/coderd/httpmw/workspaceparam.go +++ b/coderd/httpmw/workspaceparam.go @@ -9,8 +9,11 @@ import ( "github.com/go-chi/chi/v5" "github.com/google/uuid" + "cdr.dev/slog" + "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/httpapi" + "github.com/coder/coder/v2/coderd/httpmw/loggermw" "github.com/coder/coder/v2/codersdk" ) @@ -48,6 +51,11 @@ func ExtractWorkspaceParam(db database.Store) func(http.Handler) http.Handler { } ctx = context.WithValue(ctx, workspaceParamContextKey{}, workspace) + + if rlogger := loggermw.RequestLoggerFromContext(ctx); rlogger != nil { + rlogger.WithFields(slog.F("workspace_name", workspace.Name)) + } + next.ServeHTTP(rw, r.WithContext(ctx)) }) } @@ -154,6 +162,13 @@ func ExtractWorkspaceAndAgentParam(db database.Store) func(http.Handler) http.Ha ctx = context.WithValue(ctx, workspaceParamContextKey{}, workspace) ctx = context.WithValue(ctx, workspaceAgentParamContextKey{}, agent) + + if rlogger := loggermw.RequestLoggerFromContext(ctx); rlogger != nil { + rlogger.WithFields( + slog.F("workspace_name", workspace.Name), + slog.F("agent_name", agent.Name), + ) + } next.ServeHTTP(rw, r.WithContext(ctx)) }) } diff --git a/coderd/provisionerjobs.go b/coderd/provisionerjobs.go index 47963798f4d32..6d75227a14ccd 100644 --- a/coderd/provisionerjobs.go +++ b/coderd/provisionerjobs.go @@ -20,6 +20,7 @@ import ( "github.com/coder/coder/v2/coderd/database/pubsub" "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/httpmw" + "github.com/coder/coder/v2/coderd/httpmw/loggermw" "github.com/coder/coder/v2/coderd/rbac" "github.com/coder/coder/v2/coderd/rbac/policy" "github.com/coder/coder/v2/coderd/util/slice" @@ -554,6 +555,9 @@ func (f *logFollower) follow() { return } + // Log the request immediately instead of after it completes. 
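+ // Log following can stream over a websocket until the job finishes;
+ // WriteLog ignores repeat calls, so the logger middleware's own write
+ // at the end of the request becomes a no-op.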
+ loggermw.RequestLoggerFromContext(f.ctx).WriteLog(f.ctx, http.StatusAccepted) + // no need to wait if the job is done if f.complete { return diff --git a/coderd/provisionerjobs_internal_test.go b/coderd/provisionerjobs_internal_test.go index af5a7d66a6f4c..f3bc2eb1dea99 100644 --- a/coderd/provisionerjobs_internal_test.go +++ b/coderd/provisionerjobs_internal_test.go @@ -19,6 +19,8 @@ import ( "github.com/coder/coder/v2/coderd/database/dbmock" "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/database/pubsub" + "github.com/coder/coder/v2/coderd/httpmw/loggermw" + "github.com/coder/coder/v2/coderd/httpmw/loggermw/loggermock" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/provisionersdk" "github.com/coder/coder/v2/testutil" @@ -305,11 +307,16 @@ func Test_logFollower_EndOfLogs(t *testing.T) { JobStatus: database.ProvisionerJobStatusRunning, } + mockLogger := loggermock.NewMockRequestLogger(ctrl) + mockLogger.EXPECT().WriteLog(gomock.Any(), http.StatusAccepted).Times(1) + ctx = loggermw.WithRequestLogger(ctx, mockLogger) + // we need an HTTP server to get a websocket srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { uut := newLogFollower(ctx, logger, mDB, ps, rw, r, job, 0) uut.follow() })) + defer srv.Close() // job was incomplete when we create the logFollower, and still incomplete when it queries diff --git a/coderd/rbac/authz.go b/coderd/rbac/authz.go index aaba7d6eae3af..02268e052d2e2 100644 --- a/coderd/rbac/authz.go +++ b/coderd/rbac/authz.go @@ -57,6 +57,23 @@ func hashAuthorizeCall(actor Subject, action policy.Action, object Object) [32]b return hashOut } +// SubjectType represents the type of subject in the RBAC system. +type SubjectType string + +const ( + SubjectTypeUser SubjectType = "user" + SubjectTypeProvisionerd SubjectType = "provisionerd" + SubjectTypeAutostart SubjectType = "autostart" + SubjectTypeHangDetector SubjectType = "hang_detector" + SubjectTypeResourceMonitor SubjectType = "resource_monitor" + SubjectTypeCryptoKeyRotator SubjectType = "crypto_key_rotator" + SubjectTypeCryptoKeyReader SubjectType = "crypto_key_reader" + SubjectTypePrebuildsOrchestrator SubjectType = "prebuilds_orchestrator" + SubjectTypeSystemReadProvisionerDaemons SubjectType = "system_read_provisioner_daemons" + SubjectTypeSystemRestricted SubjectType = "system_restricted" + SubjectTypeNotifier SubjectType = "notifier" +) + // Subject is a struct that contains all the elements of a subject in an rbac // authorize. type Subject struct { @@ -66,6 +83,14 @@ type Subject struct { // external workspace proxy or other service type actor. FriendlyName string + // Email is entirely optional and is used for logging and debugging + // It is not used in any functional way. + Email string + + // Type indicates what kind of subject this is (user, system, provisioner, etc.) + // It is not used in any functional way, only for logging. 
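+ // loggermw keys its per-request actor map by this value when choosing
+ // which requestor fields to attach to the request log.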
+ Type SubjectType + ID string Roles ExpandableRoles Groups []string diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index aae93c0a370a0..5182cd0ccb255 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -31,6 +31,7 @@ import ( "github.com/coder/coder/v2/coderd/externalauth" "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/httpmw" + "github.com/coder/coder/v2/coderd/httpmw/loggermw" "github.com/coder/coder/v2/coderd/jwtutils" "github.com/coder/coder/v2/coderd/rbac" "github.com/coder/coder/v2/coderd/rbac/policy" @@ -463,6 +464,9 @@ func (api *API) workspaceAgentLogs(rw http.ResponseWriter, r *http.Request) { t := time.NewTicker(recheckInterval) defer t.Stop() + // Log the request immediately instead of after it completes. + loggermw.RequestLoggerFromContext(ctx).WriteLog(ctx, http.StatusAccepted) + go func() { defer func() { logger.Debug(ctx, "end log streaming loop") @@ -836,6 +840,9 @@ func (api *API) derpMapUpdates(rw http.ResponseWriter, r *http.Request) { encoder := wsjson.NewEncoder[*tailcfg.DERPMap](ws, websocket.MessageBinary) defer encoder.Close(websocket.StatusGoingAway) + // Log the request immediately instead of after it completes. + loggermw.RequestLoggerFromContext(ctx).WriteLog(ctx, http.StatusAccepted) + go func(ctx context.Context) { // TODO(mafredri): Is this too frequent? Use separate ping disconnect timeout? t := time.NewTicker(api.AgentConnectionUpdateFrequency) @@ -1210,6 +1217,9 @@ func (api *API) watchWorkspaceAgentMetadata(rw http.ResponseWriter, r *http.Requ sendTicker := time.NewTicker(sendInterval) defer sendTicker.Stop() + // Log the request immediately instead of after it completes. + loggermw.RequestLoggerFromContext(ctx).WriteLog(ctx, http.StatusAccepted) + // Send initial metadata. sendMetadata() diff --git a/enterprise/coderd/provisionerdaemons.go b/enterprise/coderd/provisionerdaemons.go index f4335438654b5..ccd893bb25e82 100644 --- a/enterprise/coderd/provisionerdaemons.go +++ b/enterprise/coderd/provisionerdaemons.go @@ -24,6 +24,7 @@ import ( "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/httpmw" + "github.com/coder/coder/v2/coderd/httpmw/loggermw" "github.com/coder/coder/v2/coderd/provisionerdserver" "github.com/coder/coder/v2/coderd/rbac" "github.com/coder/coder/v2/coderd/rbac/policy" @@ -381,6 +382,10 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request) logger.Debug(ctx, "drpc server error", slog.Error(err)) }, }) + + // Log the request immediately instead of after it completes. 
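+ // server.Serve below blocks for the lifetime of the daemon connection,
+ // so deferring the log until the request ends would delay it indefinitely.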
+ loggermw.RequestLoggerFromContext(ctx).WriteLog(ctx, http.StatusAccepted) + err = server.Serve(ctx, session) srvCancel() logger.Info(ctx, "provisioner daemon disconnected", slog.Error(err)) diff --git a/enterprise/wsproxy/wsproxy.go b/enterprise/wsproxy/wsproxy.go index af4d5064f4531..81b800c536a0f 100644 --- a/enterprise/wsproxy/wsproxy.go +++ b/enterprise/wsproxy/wsproxy.go @@ -32,6 +32,7 @@ import ( "github.com/coder/coder/v2/coderd/cryptokeys" "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/httpmw" + "github.com/coder/coder/v2/coderd/httpmw/loggermw" "github.com/coder/coder/v2/coderd/tracing" "github.com/coder/coder/v2/coderd/workspaceapps" "github.com/coder/coder/v2/codersdk" @@ -336,7 +337,7 @@ func New(ctx context.Context, opts *Options) (*Server, error) { tracing.Middleware(s.TracerProvider), httpmw.AttachRequestID, httpmw.ExtractRealIP(s.Options.RealIPConfig), - httpmw.Logger(s.Logger), + loggermw.Logger(s.Logger), prometheusMW, corsMW, diff --git a/tailnet/test/integration/integration.go b/tailnet/test/integration/integration.go index 08c66b515cd53..1190a3aa98b0d 100644 --- a/tailnet/test/integration/integration.go +++ b/tailnet/test/integration/integration.go @@ -33,7 +33,7 @@ import ( "cdr.dev/slog" "github.com/coder/coder/v2/coderd/httpapi" - "github.com/coder/coder/v2/coderd/httpmw" + "github.com/coder/coder/v2/coderd/httpmw/loggermw" "github.com/coder/coder/v2/coderd/tracing" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/tailnet" @@ -200,7 +200,7 @@ func (o SimpleServerOptions) Router(t *testing.T, logger slog.Logger) *chi.Mux { }) }, tracing.StatusWriterMiddleware, - httpmw.Logger(logger), + loggermw.Logger(logger), ) r.Route("/derp", func(r chi.Router) { From b56ffe01b992c672b39b0ca83b1f2bbb1440c482 Mon Sep 17 00:00:00 2001 From: "gcp-cherry-pick-bot[bot]" <98988430+gcp-cherry-pick-bot[bot]@users.noreply.github.com> Date: Tue, 22 Apr 2025 18:40:40 +0500 Subject: [PATCH 07/10] chore: prevent null loading sync settings (cherry-pick #17430) (#17487) Co-authored-by: Steven Masley --- coderd/idpsync/group.go | 3 +++ coderd/idpsync/organization.go | 3 +++ coderd/idpsync/role.go | 3 +++ 3 files changed, 9 insertions(+) diff --git a/coderd/idpsync/group.go b/coderd/idpsync/group.go index 4524284260359..b85ce1b749e28 100644 --- a/coderd/idpsync/group.go +++ b/coderd/idpsync/group.go @@ -268,6 +268,9 @@ func (s *GroupSyncSettings) Set(v string) error { } func (s *GroupSyncSettings) String() string { + if s.Mapping == nil { + s.Mapping = make(map[string][]uuid.UUID) + } return runtimeconfig.JSONString(s) } diff --git a/coderd/idpsync/organization.go b/coderd/idpsync/organization.go index 87fd9af5e935d..22f5f6e257d3e 100644 --- a/coderd/idpsync/organization.go +++ b/coderd/idpsync/organization.go @@ -168,6 +168,9 @@ func (s *OrganizationSyncSettings) Set(v string) error { } func (s *OrganizationSyncSettings) String() string { + if s.Mapping == nil { + s.Mapping = make(map[string][]uuid.UUID) + } return runtimeconfig.JSONString(s) } diff --git a/coderd/idpsync/role.go b/coderd/idpsync/role.go index 5cb0ac172581c..254ceb3d5f981 100644 --- a/coderd/idpsync/role.go +++ b/coderd/idpsync/role.go @@ -284,5 +284,8 @@ func (s *RoleSyncSettings) Set(v string) error { } func (s *RoleSyncSettings) String() string { + if s.Mapping == nil { + s.Mapping = make(map[string][]string) + } return runtimeconfig.JSONString(s) } From 7670820b4249a40f19e7399a1378e0150fe22261 Mon Sep 17 00:00:00 2001 From: "gcp-cherry-pick-bot[bot]" 
<98988430+gcp-cherry-pick-bot[bot]@users.noreply.github.com> Date: Tue, 22 Apr 2025 19:28:51 +0500 Subject: [PATCH 08/10] test: add unit test to excercise bug when idp sync hits deleted orgs (cherry-pick #17405) (#17423) Co-authored-by: Steven Masley --- coderd/database/dbauthz/dbauthz_test.go | 5 +- coderd/database/dbfake/builder.go | 18 +++ coderd/database/dbmem/dbmem.go | 18 ++- coderd/database/queries.sql.go | 13 +- coderd/database/queries/organizations.sql | 9 +- coderd/idpsync/organization.go | 57 ++++++++- coderd/idpsync/organizations_test.go | 111 ++++++++++++++++++ coderd/users.go | 2 +- .../coderd/enidpsync/organizations_test.go | 42 ++++--- 9 files changed, 242 insertions(+), 33 deletions(-) diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 259fd941d2921..2d4dac4e2902e 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -839,7 +839,7 @@ func (s *MethodTestSuite) TestOrganization() { _ = dbgen.OrganizationMember(s.T(), db, database.OrganizationMember{UserID: u.ID, OrganizationID: a.ID}) b := dbgen.Organization(s.T(), db, database.Organization{}) _ = dbgen.OrganizationMember(s.T(), db, database.OrganizationMember{UserID: u.ID, OrganizationID: b.ID}) - check.Args(database.GetOrganizationsByUserIDParams{UserID: u.ID, Deleted: false}).Asserts(a, policy.ActionRead, b, policy.ActionRead).Returns(slice.New(a, b)) + check.Args(database.GetOrganizationsByUserIDParams{UserID: u.ID, Deleted: sql.NullBool{Valid: true, Bool: false}}).Asserts(a, policy.ActionRead, b, policy.ActionRead).Returns(slice.New(a, b)) })) s.Run("InsertOrganization", s.Subtest(func(db database.Store, check *expects) { check.Args(database.InsertOrganizationParams{ @@ -948,8 +948,7 @@ func (s *MethodTestSuite) TestOrganization() { member, policy.ActionRead, member, policy.ActionDelete). WithNotAuthorized("no rows"). - WithCancelled(cancelledErr). 
- ErrorsWithInMemDB(sql.ErrNoRows) + WithCancelled(cancelledErr) })) s.Run("UpdateOrganization", s.Subtest(func(db database.Store, check *expects) { o := dbgen.Organization(s.T(), db, database.Organization{ diff --git a/coderd/database/dbfake/builder.go b/coderd/database/dbfake/builder.go index 6803374e72445..c39b00ac79054 100644 --- a/coderd/database/dbfake/builder.go +++ b/coderd/database/dbfake/builder.go @@ -17,6 +17,7 @@ type OrganizationBuilder struct { t *testing.T db database.Store seed database.Organization + delete bool allUsersAllowance int32 members []uuid.UUID groups map[database.Group][]uuid.UUID @@ -44,6 +45,12 @@ func (b OrganizationBuilder) EveryoneAllowance(allowance int) OrganizationBuilde return b } +func (b OrganizationBuilder) Deleted(deleted bool) OrganizationBuilder { + //nolint: revive // returns modified struct + b.delete = deleted + return b +} + func (b OrganizationBuilder) Seed(seed database.Organization) OrganizationBuilder { //nolint: revive // returns modified struct b.seed = seed @@ -118,6 +125,17 @@ func (b OrganizationBuilder) Do() OrganizationResponse { } } + if b.delete { + now := dbtime.Now() + err = b.db.UpdateOrganizationDeletedByID(ctx, database.UpdateOrganizationDeletedByIDParams{ + UpdatedAt: now, + ID: org.ID, + }) + require.NoError(b.t, err) + org.Deleted = true + org.UpdatedAt = now + } + return OrganizationResponse{ Org: org, AllUsersGroup: everyone, diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index e4956958e0073..9a7533a5894d5 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -2167,10 +2167,13 @@ func (q *FakeQuerier) DeleteOrganizationMember(ctx context.Context, arg database q.mutex.Lock() defer q.mutex.Unlock() - deleted := slices.DeleteFunc(q.data.organizationMembers, func(member database.OrganizationMember) bool { - return member.OrganizationID == arg.OrganizationID && member.UserID == arg.UserID + deleted := false + q.data.organizationMembers = slices.DeleteFunc(q.data.organizationMembers, func(member database.OrganizationMember) bool { + match := member.OrganizationID == arg.OrganizationID && member.UserID == arg.UserID + deleted = deleted || match + return match }) - if len(deleted) == 0 { + if !deleted { return sql.ErrNoRows } @@ -3758,6 +3761,9 @@ func (q *FakeQuerier) GetOrganizations(_ context.Context, args database.GetOrgan if args.Name != "" && !strings.EqualFold(org.Name, args.Name) { continue } + if args.Deleted != org.Deleted { + continue + } tmp = append(tmp, org) } @@ -3774,7 +3780,11 @@ func (q *FakeQuerier) GetOrganizationsByUserID(_ context.Context, arg database.G continue } for _, organization := range q.organizations { - if organization.ID != organizationMember.OrganizationID || organization.Deleted != arg.Deleted { + if organization.ID != organizationMember.OrganizationID { + continue + } + + if arg.Deleted.Valid && organization.Deleted != arg.Deleted.Bool { continue } organizations = append(organizations, organization) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 638124057f58b..22ba10a99922e 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -5221,8 +5221,13 @@ SELECT FROM organizations WHERE - -- Optionally include deleted organizations - deleted = $2 AND + -- Optionally provide a filter for deleted organizations. 
+ CASE WHEN + $2 :: boolean IS NULL THEN + true + ELSE + deleted = $2 + END AND id = ANY( SELECT organization_id @@ -5234,8 +5239,8 @@ WHERE ` type GetOrganizationsByUserIDParams struct { - UserID uuid.UUID `db:"user_id" json:"user_id"` - Deleted bool `db:"deleted" json:"deleted"` + UserID uuid.UUID `db:"user_id" json:"user_id"` + Deleted sql.NullBool `db:"deleted" json:"deleted"` } func (q *sqlQuerier) GetOrganizationsByUserID(ctx context.Context, arg GetOrganizationsByUserIDParams) ([]Organization, error) { diff --git a/coderd/database/queries/organizations.sql b/coderd/database/queries/organizations.sql index 822b51c0aa8ba..988efb535152d 100644 --- a/coderd/database/queries/organizations.sql +++ b/coderd/database/queries/organizations.sql @@ -55,8 +55,13 @@ SELECT FROM organizations WHERE - -- Optionally include deleted organizations - deleted = @deleted AND + -- Optionally provide a filter for deleted organizations. + CASE WHEN + sqlc.narg('deleted') :: boolean IS NULL THEN + true + ELSE + deleted = sqlc.narg('deleted') + END AND id = ANY( SELECT organization_id diff --git a/coderd/idpsync/organization.go b/coderd/idpsync/organization.go index 22f5f6e257d3e..5d56bc7d239a5 100644 --- a/coderd/idpsync/organization.go +++ b/coderd/idpsync/organization.go @@ -92,14 +92,16 @@ func (s AGPLIDPSync) SyncOrganizations(ctx context.Context, tx database.Store, u return nil // No sync configured, nothing to do } - expectedOrgs, err := orgSettings.ParseClaims(ctx, tx, params.MergedClaims) + expectedOrgIDs, err := orgSettings.ParseClaims(ctx, tx, params.MergedClaims) if err != nil { return xerrors.Errorf("organization claims: %w", err) } + // Fetch all organizations, even deleted ones. This is to remove a user + // from any deleted organizations they may be in. existingOrgs, err := tx.GetOrganizationsByUserID(ctx, database.GetOrganizationsByUserIDParams{ UserID: user.ID, - Deleted: false, + Deleted: sql.NullBool{}, }) if err != nil { return xerrors.Errorf("failed to get user organizations: %w", err) @@ -109,10 +111,35 @@ func (s AGPLIDPSync) SyncOrganizations(ctx context.Context, tx database.Store, u return org.ID }) + // finalExpected is the final set of org ids the user is expected to be in. + // Deleted orgs are omitted from this set. + finalExpected := expectedOrgIDs + if len(expectedOrgIDs) > 0 { + // If you pass in an empty slice to the db arg, you get all orgs. So the slice + // has to be non-empty to get the expected set. Logically it also does not make + // sense to fetch an empty set from the db. + expectedOrganizations, err := tx.GetOrganizations(ctx, database.GetOrganizationsParams{ + IDs: expectedOrgIDs, + // Do not include deleted organizations. Omitting deleted orgs will remove the + // user from any deleted organizations they are a member of. + Deleted: false, + }) + if err != nil { + return xerrors.Errorf("failed to get expected organizations: %w", err) + } + finalExpected = db2sdk.List(expectedOrganizations, func(org database.Organization) uuid.UUID { + return org.ID + }) + } + // Find the difference in the expected and the existing orgs, and // correct the set of orgs the user is a member of. - add, remove := slice.SymmetricDifference(existingOrgIDs, expectedOrgs) - notExists := make([]uuid.UUID, 0) + add, remove := slice.SymmetricDifference(existingOrgIDs, finalExpected) + // notExists is purely for debugging. It logs when the settings want + // a user in an organization, but the organization does not exist. 
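+ // Anything mapped in the sync settings but dropped from finalExpected
+ // above (deleted or nonexistent orgs) lands in this slice.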
+	notExists := slice.DifferenceFunc(expectedOrgIDs, finalExpected, func(a, b uuid.UUID) bool {
+		return a == b
+	})
 	for _, orgID := range add {
 		_, err := tx.InsertOrganizationMember(ctx, database.InsertOrganizationMemberParams{
 			OrganizationID: orgID,
@@ -123,9 +150,30 @@ func (s AGPLIDPSync) SyncOrganizations(ctx context.Context, tx database.Store, u
 		})
 		if err != nil {
 			if xerrors.Is(err, sql.ErrNoRows) {
+				// This should not happen because we check the org existence
+				// beforehand.
 				notExists = append(notExists, orgID)
 				continue
 			}
+
+			if database.IsUniqueViolation(err, database.UniqueOrganizationMembersPkey) {
+				// If we hit this error we have a bug. The user already exists in the
+				// organization, but was not detected as a member at the start of this function.
+				// Instead of failing the function, an error will be logged. This is to avoid
+				// bringing down the entire syncing behavior because of a single failed org. Failing
+				// this can prevent user logins, so only fatal non-recoverable errors should be returned.
+				//
+				// Inserting a user is privilege escalation, so skipping this instead of failing
+				// leaves the user with fewer permissions. That makes it safe, from a security
+				// perspective, to continue.
+				s.Logger.Error(ctx, "syncing user to organization failed as they are already a member, please report this failure to Coder",
+					slog.F("user_id", user.ID),
+					slog.F("username", user.Username),
+					slog.F("organization_id", orgID),
+					slog.Error(err),
+				)
+				continue
+			}
 			return xerrors.Errorf("add user to organization: %w", err)
 		}
 	}
@@ -141,6 +189,7 @@ func (s AGPLIDPSync) SyncOrganizations(ctx context.Context, tx database.Store, u
 	}
 
 	if len(notExists) > 0 {
+		notExists = slice.Unique(notExists) // Remove duplicates
 		s.Logger.Debug(ctx, "organizations do not exist but attempted to use in org sync",
 			slog.F("not_found", notExists),
 			slog.F("user_id", user.ID),
diff --git a/coderd/idpsync/organizations_test.go b/coderd/idpsync/organizations_test.go
index 51c8a7365d22b..3a00499bdbced 100644
--- a/coderd/idpsync/organizations_test.go
+++ b/coderd/idpsync/organizations_test.go
@@ -1,6 +1,7 @@
 package idpsync_test
 
 import (
+	"database/sql"
 	"testing"
 
 	"github.com/golang-jwt/jwt/v4"
@@ -8,6 +9,11 @@ import (
 	"github.com/stretchr/testify/require"
 
 	"cdr.dev/slog/sloggers/slogtest"
+	"github.com/coder/coder/v2/coderd/database"
+	"github.com/coder/coder/v2/coderd/database/db2sdk"
+	"github.com/coder/coder/v2/coderd/database/dbfake"
+	"github.com/coder/coder/v2/coderd/database/dbgen"
+	"github.com/coder/coder/v2/coderd/database/dbtestutil"
 	"github.com/coder/coder/v2/coderd/idpsync"
 	"github.com/coder/coder/v2/coderd/runtimeconfig"
 	"github.com/coder/coder/v2/testutil"
@@ -38,3 +44,108 @@ func TestParseOrganizationClaims(t *testing.T) {
 		require.False(t, params.SyncEntitled)
 	})
 }
+
+func TestSyncOrganizations(t *testing.T) {
+	t.Parallel()
+
+	// This test creates some deleted organizations and checks the behavior is
+	// correct.
+	t.Run("SyncUserToDeletedOrg", func(t *testing.T) {
+		t.Parallel()
+
+		ctx := testutil.Context(t, testutil.WaitMedium)
+		db, _ := dbtestutil.NewDB(t)
+		user := dbgen.User(t, db, database.User{})
+
+		// Create orgs for:
+		// - stays = User is a member, and stays
+		// - leaves = User is a member, and leaves
+		// - joins = User is not a member, and joins
+		// For deleted orgs, the user **should not** be a member afterwards.
+		// - deletedStays = User is a member of deleted org, and wants to stay
+		// - deletedLeaves = User is a member of deleted org, and wants to leave
+		// - deletedJoins = User is not a member of deleted org, and wants to join
+		stays := dbfake.Organization(t, db).Members(user).Do()
+		leaves := dbfake.Organization(t, db).Members(user).Do()
+		joins := dbfake.Organization(t, db).Do()
+
+		deletedStays := dbfake.Organization(t, db).Members(user).Deleted(true).Do()
+		deletedLeaves := dbfake.Organization(t, db).Members(user).Deleted(true).Do()
+		deletedJoins := dbfake.Organization(t, db).Deleted(true).Do()
+
+		// Now sync the user to the deleted organization
+		s := idpsync.NewAGPLSync(
+			slogtest.Make(t, &slogtest.Options{}),
+			runtimeconfig.NewManager(),
+			idpsync.DeploymentSyncSettings{
+				OrganizationField: "orgs",
+				OrganizationMapping: map[string][]uuid.UUID{
+					"stay":  {stays.Org.ID, deletedStays.Org.ID},
+					"leave": {leaves.Org.ID, deletedLeaves.Org.ID},
+					"join":  {joins.Org.ID, deletedJoins.Org.ID},
+				},
+				OrganizationAssignDefault: false,
+			},
+		)
+
+		err := s.SyncOrganizations(ctx, db, user, idpsync.OrganizationParams{
+			SyncEntitled: true,
+			MergedClaims: map[string]interface{}{
+				"orgs": []string{"stay", "join"},
+			},
+		})
+		require.NoError(t, err)
+
+		orgs, err := db.GetOrganizationsByUserID(ctx, database.GetOrganizationsByUserIDParams{
+			UserID:  user.ID,
+			Deleted: sql.NullBool{},
+		})
+		require.NoError(t, err)
+		require.Len(t, orgs, 2)
+
+		// Verify the user only exists in 2 orgs: the one they stayed in, and the
+		// one they joined.
+		inIDs := db2sdk.List(orgs, func(org database.Organization) uuid.UUID {
+			return org.ID
+		})
+		require.ElementsMatch(t, []uuid.UUID{stays.Org.ID, joins.Org.ID}, inIDs)
+	})
+
+	t.Run("UserToZeroOrgs", func(t *testing.T) {
+		t.Parallel()
+
+		ctx := testutil.Context(t, testutil.WaitMedium)
+		db, _ := dbtestutil.NewDB(t)
+		user := dbgen.User(t, db, database.User{})
+
+		deletedLeaves := dbfake.Organization(t, db).Members(user).Deleted(true).Do()
+
+		// Now sync the user to the deleted organization
+		s := idpsync.NewAGPLSync(
+			slogtest.Make(t, &slogtest.Options{}),
+			runtimeconfig.NewManager(),
+			idpsync.DeploymentSyncSettings{
+				OrganizationField: "orgs",
+				OrganizationMapping: map[string][]uuid.UUID{
+					"leave": {deletedLeaves.Org.ID},
+				},
+				OrganizationAssignDefault: false,
+			},
+		)
+
+		err := s.SyncOrganizations(ctx, db, user, idpsync.OrganizationParams{
+			SyncEntitled: true,
+			MergedClaims: map[string]interface{}{
+				"orgs": []string{},
+			},
+		})
+		require.NoError(t, err)
+
+		orgs, err := db.GetOrganizationsByUserID(ctx, database.GetOrganizationsByUserIDParams{
+			UserID:  user.ID,
+			Deleted: sql.NullBool{},
+		})
+		require.NoError(t, err)
+		require.Len(t, orgs, 0)
+	})
+}
diff --git a/coderd/users.go b/coderd/users.go
index bf5b1db763fe9..4c3d1ceaa7b58 100644
--- a/coderd/users.go
+++ b/coderd/users.go
@@ -1274,7 +1274,7 @@ func (api *API) organizationsByUser(rw http.ResponseWriter, r *http.Request) {
 	organizations, err := api.Database.GetOrganizationsByUserID(ctx, database.GetOrganizationsByUserIDParams{
 		UserID:  user.ID,
-		Deleted: false,
+		Deleted: sql.NullBool{Bool: false, Valid: true},
 	})
 	if errors.Is(err, sql.ErrNoRows) {
 		err = nil
diff --git a/enterprise/coderd/enidpsync/organizations_test.go b/enterprise/coderd/enidpsync/organizations_test.go
index 391535c9478d7..b2e120592b582 100644
--- a/enterprise/coderd/enidpsync/organizations_test.go
+++ b/enterprise/coderd/enidpsync/organizations_test.go
@@ -14,6 +14,7 @@ import (
 	"github.com/coder/coder/v2/coderd/database"
 	"github.com/coder/coder/v2/coderd/database/db2sdk"
 	"github.com/coder/coder/v2/coderd/database/dbauthz"
+	"github.com/coder/coder/v2/coderd/database/dbfake"
 	"github.com/coder/coder/v2/coderd/database/dbgen"
 	"github.com/coder/coder/v2/coderd/database/dbtestutil"
 	"github.com/coder/coder/v2/coderd/entitlements"
@@ -89,7 +90,8 @@ func TestOrganizationSync(t *testing.T) {
 			Name: "SingleOrgDeployment",
 			Case: func(t *testing.T, db database.Store) OrganizationSyncTestCase {
 				def, _ := db.GetDefaultOrganization(context.Background())
-				other := dbgen.Organization(t, db, database.Organization{})
+				other := dbfake.Organization(t, db).Do()
+				deleted := dbfake.Organization(t, db).Deleted(true).Do()
 				return OrganizationSyncTestCase{
 					Entitlements: entitled,
 					Settings: idpsync.DeploymentSyncSettings{
@@ -123,11 +125,19 @@ func TestOrganizationSync(t *testing.T) {
 								})
 								dbgen.OrganizationMember(t, db, database.OrganizationMember{
 									UserID:         user.ID,
-									OrganizationID: other.ID,
+									OrganizationID: other.Org.ID,
+								})
+								dbgen.OrganizationMember(t, db, database.OrganizationMember{
+									UserID:         user.ID,
+									OrganizationID: deleted.Org.ID,
 								})
 							},
 							Sync: ExpectedUser{
-								Organizations: []uuid.UUID{def.ID, other.ID},
+								Organizations: []uuid.UUID{
+									def.ID, other.Org.ID,
+									// The user remains in the deleted org because no idp sync happens.
+									deleted.Org.ID,
+								},
 							},
 						},
 					},
@@ -138,17 +148,19 @@
 			Name: "MultiOrgWithDefault",
 			Case: func(t *testing.T, db database.Store) OrganizationSyncTestCase {
 				def, _ := db.GetDefaultOrganization(context.Background())
-				one := dbgen.Organization(t, db, database.Organization{})
-				two := dbgen.Organization(t, db, database.Organization{})
-				three := dbgen.Organization(t, db, database.Organization{})
+				one := dbfake.Organization(t, db).Do()
+				two := dbfake.Organization(t, db).Do()
+				three := dbfake.Organization(t, db).Do()
+				deleted := dbfake.Organization(t, db).Deleted(true).Do()
 				return OrganizationSyncTestCase{
 					Entitlements: entitled,
 					Settings: idpsync.DeploymentSyncSettings{
 						OrganizationField: "organizations",
 						OrganizationMapping: map[string][]uuid.UUID{
-							"first":  {one.ID},
-							"second": {two.ID},
-							"third":  {three.ID},
+							"first":   {one.Org.ID},
+							"second":  {two.Org.ID},
+							"third":   {three.Org.ID},
+							"deleted": {deleted.Org.ID},
 						},
 						OrganizationAssignDefault: true,
 					},
@@ -167,7 +179,7 @@
 						{
 							Name: "AlreadyInOrgs",
 							Claims: jwt.MapClaims{
-								"organizations": []string{"second", "extra"},
+								"organizations": []string{"second", "extra", "deleted"},
 							},
 							ExpectedParams: idpsync.OrganizationParams{
 								SyncEntitled: true,
@@ -180,18 +192,18 @@
 								})
 								dbgen.OrganizationMember(t, db, database.OrganizationMember{
 									UserID:         user.ID,
-									OrganizationID: one.ID,
+									OrganizationID: one.Org.ID,
 								})
 							},
 							Sync: ExpectedUser{
-								Organizations: []uuid.UUID{def.ID, two.ID},
+								Organizations: []uuid.UUID{def.ID, two.Org.ID},
 							},
 						},
 						{
 							Name: "ManyClaims",
 							Claims: jwt.MapClaims{
 								// Add some repeats
-								"organizations": []string{"second", "extra", "first", "third", "second", "second"},
+								"organizations": []string{"second", "extra", "first", "third", "second", "second", "deleted"},
 							},
 							ExpectedParams: idpsync.OrganizationParams{
 								SyncEntitled: true,
@@ -204,11 +216,11 @@
 								})
 								dbgen.OrganizationMember(t, db, database.OrganizationMember{
 									UserID:         user.ID,
-									OrganizationID: one.ID,
+									OrganizationID: one.Org.ID,
 								})
 							},
 							Sync: ExpectedUser{
-								Organizations: []uuid.UUID{def.ID, one.ID, two.ID, three.ID},
+								Organizations: []uuid.UUID{def.ID, one.Org.ID, two.Org.ID, three.Org.ID},
 							},
 						},
 					},

From 8fe94e8c4774862b9543d07a76cb4e121bcb6f82 Mon Sep 17 00:00:00 2001
From: "gcp-cherry-pick-bot[bot]" <98988430+gcp-cherry-pick-bot[bot]@users.noreply.github.com>
Date: Wed, 23 Apr 2025 17:30:32 +0500
Subject: [PATCH 09/10] fix(scripts/release): handle cherry-pick bot titles in
 check commit metadata (cherry-pick #17535) (#17538)

Co-authored-by: Mathias Fredriksson
---
 scripts/release/check_commit_metadata.sh | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/scripts/release/check_commit_metadata.sh b/scripts/release/check_commit_metadata.sh
index f53de8e107430..1368425d00639 100755
--- a/scripts/release/check_commit_metadata.sh
+++ b/scripts/release/check_commit_metadata.sh
@@ -118,6 +118,23 @@ main() {
 			title2=${parts2[*]:2}
 		fi
 
+		# Handle cherry-pick bot, it turns "chore: foo bar (#42)" to
+		# "chore: foo bar (cherry-pick #42) (#43)".
+		if [[ ${title1} == *"(cherry-pick #"* ]]; then
+			title1=${title1%" ("*}
+			pr=${title1##*#}
+			pr=${pr%)}
+			title1=${title1%" ("*}
+			title1="${title1} (#${pr})"$'\n'
+		fi
+		if [[ ${title2} == *"(cherry-pick #"* ]]; then
+			title2=${title2%" ("*}
+			pr=${title2##*#}
+			pr=${pr%)}
+			title2=${title2%" ("*}
+			title2="${title2} (#${pr})"$'\n'
+		fi
+
 		if [[ ${title1} != "${title2}" ]]; then
 			log "Invariant failed, cherry-picked commits have different titles: \"${title1%$'\n'}\" != \"${title2%$'\n'}\", attempting to check commit body for cherry-pick information..."
 

From 8329bea8b79607c9b9f66ad1817a3624ef79be5f Mon Sep 17 00:00:00 2001
From: Dean Sheather
Date: Wed, 30 Apr 2025 00:38:13 +1000
Subject: [PATCH 10/10] chore: apply Dockerfile architecture fix (#17602)

---
 scripts/Dockerfile.base | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/scripts/Dockerfile.base b/scripts/Dockerfile.base
index f9d2bf6594b08..8090623be8306 100644
--- a/scripts/Dockerfile.base
+++ b/scripts/Dockerfile.base
@@ -26,7 +26,7 @@ RUN apk add --no-cache \
 # Terraform was disabled in the edge repo due to a build issue.
 # https://gitlab.alpinelinux.org/alpine/aports/-/commit/f3e263d94cfac02d594bef83790c280e045eba35
 # Using wget for now. Note that busybox unzip doesn't support streaming.
-RUN ARCH="$(arch)"; if [ "${ARCH}" == "x86_64" ]; then ARCH="amd64"; elif [ "${ARCH}" == "aarch64" ]; then ARCH="arm64"; fi; wget -O /tmp/terraform.zip "https://releases.hashicorp.com/terraform/1.10.5/terraform_1.10.5_linux_${ARCH}.zip" && \
+RUN ARCH="$(arch)"; if [ "${ARCH}" == "x86_64" ]; then ARCH="amd64"; elif [ "${ARCH}" == "aarch64" ]; then ARCH="arm64"; elif [ "${ARCH}" == "armv7l" ]; then ARCH="arm"; fi; wget -O /tmp/terraform.zip "https://releases.hashicorp.com/terraform/1.10.5/terraform_1.10.5_linux_${ARCH}.zip" && \
 	busybox unzip /tmp/terraform.zip -d /usr/local/bin && \
 	rm -f /tmp/terraform.zip && \
 	chmod +x /usr/local/bin/terraform && \
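
The RUN line changed above maps the machine name reported by `arch` in the base image to the architecture suffix used in Terraform release filenames, and the fix adds the 32-bit ARM case. Read in isolation, the normalization is roughly the sketch below; this is an illustrative restatement using `case` rather than the Dockerfile's `if`/`elif` chain, and the trailing `echo` is only there to show the resulting download name:

    ARCH="$(arch)"                # e.g. x86_64, aarch64, armv7l
    case "${ARCH}" in
    	x86_64) ARCH="amd64" ;;   # Intel/AMD 64-bit -> terraform_*_linux_amd64.zip
    	aarch64) ARCH="arm64" ;;  # 64-bit ARM -> terraform_*_linux_arm64.zip
    	armv7l) ARCH="arm" ;;     # 32-bit ARM, the newly handled case
    esac
    echo "terraform_1.10.5_linux_${ARCH}.zip"

Any other machine type falls through unchanged, which matches the behavior of the original `if` chain in the Dockerfile.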