diff --git a/.claude/docs/DATABASE.md b/.claude/docs/DATABASE.md index f6ba4bd78859b..090054772fc32 100644 --- a/.claude/docs/DATABASE.md +++ b/.claude/docs/DATABASE.md @@ -58,31 +58,6 @@ If adding fields to auditable types: - `ActionSecret`: Field contains sensitive data 3. Run `make gen` to verify no audit errors -## In-Memory Database (dbmem) Updates - -### Critical Requirements - -When adding new fields to database structs: - -- **CRITICAL**: Update `coderd/database/dbmem/dbmem.go` in-memory implementations -- The `Insert*` functions must include ALL new fields, not just basic ones -- Common issue: Tests pass with real database but fail with in-memory database due to missing field mappings -- Always verify in-memory database functions match the real database schema after migrations - -### Example Pattern - -```go -// In dbmem.go - ensure ALL fields are included -code := database.OAuth2ProviderAppCode{ - ID: arg.ID, - CreatedAt: arg.CreatedAt, - // ... existing fields ... - ResourceUri: arg.ResourceUri, // New field - CodeChallenge: arg.CodeChallenge, // New field - CodeChallengeMethod: arg.CodeChallengeMethod, // New field -} -``` - ## Database Architecture ### Core Components @@ -116,7 +91,6 @@ roles, err := db.GetAuthorizationUserRoles(dbauthz.AsSystemRestricted(ctx), user 1. **Nullable field errors**: Use `sql.Null*` types consistently 2. **Missing audit entries**: Update `enterprise/audit/table.go` -3. **dbmem inconsistencies**: Ensure in-memory implementations match schema ### Query Issues @@ -139,19 +113,6 @@ func TestDatabaseFunction(t *testing.T) { } ``` -### In-Memory Testing - -```go -func TestInMemoryDatabase(t *testing.T) { - db := dbmem.New() - - // Test with in-memory database - result, err := db.GetSomething(ctx, param) - require.NoError(t, err) - require.Equal(t, expected, result) -} -``` - ## Best Practices ### Schema Design diff --git a/.claude/docs/OAUTH2.md b/.claude/docs/OAUTH2.md index 2c766dd083516..9fb34f093042a 100644 --- a/.claude/docs/OAUTH2.md +++ b/.claude/docs/OAUTH2.md @@ -112,14 +112,13 @@ Always run the full test suite after OAuth2 changes: ## Common OAuth2 Issues 1. **OAuth2 endpoints returning wrong error format** - Ensure OAuth2 endpoints return RFC 6749 compliant errors -2. **OAuth2 tests failing but scripts working** - Check in-memory database implementations in `dbmem.go` -3. **Resource indicator validation failing** - Ensure database stores and retrieves resource parameters correctly -4. **PKCE tests failing** - Verify both authorization code storage and token exchange handle PKCE fields -5. **RFC compliance failures** - Verify against actual RFC specifications, not assumptions -6. **Authorization context errors in public endpoints** - Use `dbauthz.AsSystemRestricted(ctx)` pattern -7. **Default value mismatches** - Ensure database migrations match application code defaults -8. **Bearer token authentication issues** - Check token extraction precedence and format validation -9. **URI validation failures** - Support both standard schemes and custom schemes per protocol requirements +2. **Resource indicator validation failing** - Ensure database stores and retrieves resource parameters correctly +3. **PKCE tests failing** - Verify both authorization code storage and token exchange handle PKCE fields +4. **RFC compliance failures** - Verify against actual RFC specifications, not assumptions +5. **Authorization context errors in public endpoints** - Use `dbauthz.AsSystemRestricted(ctx)` pattern +6. **Default value mismatches** - Ensure database migrations match application code defaults +7. **Bearer token authentication issues** - Check token extraction precedence and format validation +8. **URI validation failures** - Support both standard schemes and custom schemes per protocol requirements ## Authorization Context Patterns diff --git a/.claude/docs/TESTING.md b/.claude/docs/TESTING.md index b8f92f531bb1c..eff655b0acadc 100644 --- a/.claude/docs/TESTING.md +++ b/.claude/docs/TESTING.md @@ -39,31 +39,6 @@ 2. **Verify information disclosure protections** 3. **Test token security and proper invalidation** -## Database Testing - -### In-Memory Database Testing - -When adding new database fields: - -- **CRITICAL**: Update `coderd/database/dbmem/dbmem.go` in-memory implementations -- The `Insert*` functions must include ALL new fields, not just basic ones -- Common issue: Tests pass with real database but fail with in-memory database due to missing field mappings -- Always verify in-memory database functions match the real database schema after migrations - -Example pattern: - -```go -// In dbmem.go - ensure ALL fields are included -code := database.OAuth2ProviderAppCode{ - ID: arg.ID, - CreatedAt: arg.CreatedAt, - // ... existing fields ... - ResourceUri: arg.ResourceUri, // New field - CodeChallenge: arg.CodeChallenge, // New field - CodeChallengeMethod: arg.CodeChallengeMethod, // New field -} -``` - ## Test Organization ### Test File Structure @@ -107,15 +82,13 @@ coderd/ ### Database-Related -1. **Tests passing locally but failing in CI** - Check if `dbmem` implementation needs updating -2. **SQL type errors** - Use `sql.Null*` types for nullable fields -3. **Race conditions in tests** - Use unique identifiers instead of hardcoded names +1. **SQL type errors** - Use `sql.Null*` types for nullable fields +2. **Race conditions in tests** - Use unique identifiers instead of hardcoded names ### OAuth2 Testing -1. **OAuth2 tests failing but scripts working** - Check in-memory database implementations in `dbmem.go` -2. **PKCE tests failing** - Verify both authorization code storage and token exchange handle PKCE fields -3. **Resource indicator validation failing** - Ensure database stores and retrieves resource parameters correctly +1. **PKCE tests failing** - Verify both authorization code storage and token exchange handle PKCE fields +2. **Resource indicator validation failing** - Ensure database stores and retrieves resource parameters correctly ### General Issues diff --git a/.claude/docs/TROUBLESHOOTING.md b/.claude/docs/TROUBLESHOOTING.md index 2b4bb3ee064cc..19c05a7a0cd62 100644 --- a/.claude/docs/TROUBLESHOOTING.md +++ b/.claude/docs/TROUBLESHOOTING.md @@ -21,59 +21,50 @@ } ``` -3. **Tests passing locally but failing in CI** - - **Solution**: Check if `dbmem` implementation needs updating - - Update `coderd/database/dbmem/dbmem.go` for Insert/Update methods - - Missing fields in dbmem can cause tests to fail even if main implementation is correct - ### Testing Issues -4. **"package should be X_test"** +3. **"package should be X_test"** - **Solution**: Use `package_test` naming for test files - Example: `identityprovider_test` for black-box testing -5. **Race conditions in tests** +4. **Race conditions in tests** - **Solution**: Use unique identifiers instead of hardcoded names - Example: `fmt.Sprintf("test-client-%s-%d", t.Name(), time.Now().UnixNano())` - Never use hardcoded names in concurrent tests -6. **Missing newlines** +5. **Missing newlines** - **Solution**: Ensure files end with newline character - Most editors can be configured to add this automatically ### OAuth2 Issues -7. **OAuth2 endpoints returning wrong error format** +6. **OAuth2 endpoints returning wrong error format** - **Solution**: Ensure OAuth2 endpoints return RFC 6749 compliant errors - Use standard error codes: `invalid_client`, `invalid_grant`, `invalid_request` - Format: `{"error": "code", "error_description": "details"}` -8. **OAuth2 tests failing but scripts working** - - **Solution**: Check in-memory database implementations in `dbmem.go` - - Ensure all OAuth2 fields are properly copied in Insert/Update methods - -9. **Resource indicator validation failing** +7. **Resource indicator validation failing** - **Solution**: Ensure database stores and retrieves resource parameters correctly - Check both authorization code storage and token exchange handling -10. **PKCE tests failing** +8. **PKCE tests failing** - **Solution**: Verify both authorization code storage and token exchange handle PKCE fields - Check `CodeChallenge` and `CodeChallengeMethod` field handling ### RFC Compliance Issues -11. **RFC compliance failures** +9. **RFC compliance failures** - **Solution**: Verify against actual RFC specifications, not assumptions - Use WebFetch tool to get current RFC content for compliance verification - Read the actual RFC specifications before implementation -12. **Default value mismatches** +10. **Default value mismatches** - **Solution**: Ensure database migrations match application code defaults - Example: RFC 7591 specifies `client_secret_basic` as default, not `client_secret_post` ### Authorization Issues -13. **Authorization context errors in public endpoints** +11. **Authorization context errors in public endpoints** - **Solution**: Use `dbauthz.AsSystemRestricted(ctx)` pattern - Example: @@ -84,17 +75,17 @@ ### Authentication Issues -14. **Bearer token authentication issues** +12. **Bearer token authentication issues** - **Solution**: Check token extraction precedence and format validation - Ensure proper RFC 6750 Bearer Token Support implementation -15. **URI validation failures** +13. **URI validation failures** - **Solution**: Support both standard schemes and custom schemes per protocol requirements - Native OAuth2 apps may use custom schemes ### General Development Issues -16. **Log message formatting errors** +14. **Log message formatting errors** - **Solution**: Use lowercase, descriptive messages without special characters - Follow Go logging conventions diff --git a/.claude/docs/WORKFLOWS.md b/.claude/docs/WORKFLOWS.md index 1bd595c8a4b34..b846110d589d8 100644 --- a/.claude/docs/WORKFLOWS.md +++ b/.claude/docs/WORKFLOWS.md @@ -81,11 +81,6 @@ - Add each new field with appropriate action (ActionTrack, ActionIgnore, ActionSecret) - Run `make gen` to verify no audit errors -6. **In-memory database (dbmem) updates**: - - When adding new fields to database structs, ensure `dbmem` implementation copies all fields - - Check `coderd/database/dbmem/dbmem.go` for Insert/Update methods - - Missing fields in dbmem can cause tests to fail even if main implementation is correct - ### Database Generation Process 1. Modify SQL files in `coderd/database/queries/` @@ -164,9 +159,8 @@ 1. **Development server won't start** - Use `./scripts/develop.sh` instead of manual commands 2. **Database migration errors** - Check migration file format and use helper scripts -3. **Test failures after database changes** - Update `dbmem` implementations -4. **Audit table errors** - Update `enterprise/audit/table.go` with new fields -5. **OAuth2 compliance issues** - Ensure RFC-compliant error responses +3. **Audit table errors** - Update `enterprise/audit/table.go` with new fields +4. **OAuth2 compliance issues** - Ensure RFC-compliant error responses ### Debug Commands diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 8c4e21466c03d..3566f77982c1c 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -311,94 +311,6 @@ jobs: - name: Check for unstaged files run: ./scripts/check_unstaged.sh - test-go: - runs-on: ${{ matrix.os == 'ubuntu-latest' && github.repository_owner == 'coder' && 'depot-ubuntu-22.04-4' || matrix.os == 'macos-latest' && github.repository_owner == 'coder' && 'depot-macos-latest' || matrix.os == 'windows-2022' && github.repository_owner == 'coder' && 'depot-windows-2022-16' || matrix.os }} - needs: changes - if: needs.changes.outputs.go == 'true' || needs.changes.outputs.ci == 'true' || github.ref == 'refs/heads/main' - timeout-minutes: 20 - strategy: - fail-fast: false - matrix: - os: - - ubuntu-latest - - macos-latest - - windows-2022 - steps: - - name: Harden Runner - # Harden Runner is only supported on Ubuntu runners. - if: runner.os == 'Linux' - uses: step-security/harden-runner@6c439dc8bdf85cadbbce9ed30d1c7b959517bc49 # v2.12.2 - with: - egress-policy: audit - - # Set up RAM disks to speed up the rest of the job. This action is in - # a separate repository to allow its use before actions/checkout. - - name: Setup RAM Disks - if: runner.os == 'Windows' - uses: coder/setup-ramdisk-action@e1100847ab2d7bcd9d14bcda8f2d1b0f07b36f1b - - - name: Checkout - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - fetch-depth: 1 - - - name: Setup Go Paths - uses: ./.github/actions/setup-go-paths - - - name: Setup Go - uses: ./.github/actions/setup-go - with: - # Runners have Go baked-in and Go will automatically - # download the toolchain configured in go.mod, so we don't - # need to reinstall it. It's faster on Windows runners. - use-preinstalled-go: ${{ runner.os == 'Windows' }} - - - name: Setup Terraform - uses: ./.github/actions/setup-tf - - - name: Download Test Cache - id: download-cache - uses: ./.github/actions/test-cache/download - with: - key-prefix: test-go-${{ runner.os }}-${{ runner.arch }} - - - name: Test with Mock Database - id: test - shell: bash - run: | - # if macOS, install google-chrome for scaletests. As another concern, - # should we really have this kind of external dependency requirement - # on standard CI? - if [ "${{ matrix.os }}" == "macos-latest" ]; then - brew install google-chrome - fi - - # By default Go will use the number of logical CPUs, which - # is a fine default. - PARALLEL_FLAG="" - - # macOS will output "The default interactive shell is now zsh" - # intermittently in CI... - if [ "${{ matrix.os }}" == "macos-latest" ]; then - touch ~/.bash_profile && echo "export BASH_SILENCE_DEPRECATION_WARNING=1" >> ~/.bash_profile - fi - export TS_DEBUG_DISCO=true - gotestsum --junitfile="gotests.xml" --jsonfile="gotests.json" --rerun-fails=2 \ - --packages="./..." -- $PARALLEL_FLAG -short - - - name: Upload Test Cache - uses: ./.github/actions/test-cache/upload - with: - cache-key: ${{ steps.download-cache.outputs.cache-key }} - - - name: Upload test stats to Datadog - timeout-minutes: 1 - continue-on-error: true - uses: ./.github/actions/upload-datadog - if: success() || failure() - with: - api-key: ${{ secrets.DATADOG_API_KEY }} - test-go-pg: # make sure to adjust NUM_PARALLEL_PACKAGES and NUM_PARALLEL_TESTS below # when changing runner sizes @@ -539,16 +451,21 @@ jobs: # Postgres tends not to choke. NUM_PARALLEL_PACKAGES=8 NUM_PARALLEL_TESTS=16 + # Only the CLI and Agent are officially supported on Windows and the rest are too flaky + PACKAGES="./cli/... ./enterprise/cli/... ./agent/..." elif [ "${{ runner.os }}" == "macOS" ]; then # Our macOS runners have 8 cores. We set NUM_PARALLEL_TESTS to 16 # because the tests complete faster and Postgres doesn't choke. It seems # that macOS's tmpfs is faster than the one on Windows. NUM_PARALLEL_PACKAGES=8 NUM_PARALLEL_TESTS=16 + # Only the CLI and Agent are officially supported on macOS and the rest are too flaky + PACKAGES="./cli/... ./enterprise/cli/... ./agent/..." elif [ "${{ runner.os }}" == "Linux" ]; then # Our Linux runners have 8 cores. NUM_PARALLEL_PACKAGES=8 NUM_PARALLEL_TESTS=8 + PACKAGES="./..." fi # by default, run tests with cache @@ -565,10 +482,7 @@ jobs: # invalidated. See scripts/normalize_path.sh for more details. normalize_path_with_symlinks "$RUNNER_TEMP/sym" "$(dirname $(which terraform))" - # We rerun failing tests to counteract flakiness coming from Postgres - # choking on macOS and Windows sometimes. - DB=ci gotestsum --rerun-fails=2 --rerun-fails-max-failures=50 \ - --format standard-quiet --packages "./..." \ + gotestsum --format standard-quiet --packages "$PACKAGES" \ -- -timeout=20m -v -p $NUM_PARALLEL_PACKAGES -parallel=$NUM_PARALLEL_TESTS $TESTCOUNT - name: Upload Go Build Cache @@ -638,7 +552,6 @@ jobs: env: POSTGRES_VERSION: "17" TS_DEBUG_DISCO: "true" - TEST_RETRIES: 2 run: | make test-postgres @@ -655,55 +568,6 @@ jobs: with: api-key: ${{ secrets.DATADOG_API_KEY }} - test-go-race: - runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-16' || 'ubuntu-latest' }} - needs: changes - if: needs.changes.outputs.go == 'true' || needs.changes.outputs.ci == 'true' || github.ref == 'refs/heads/main' - timeout-minutes: 25 - steps: - - name: Harden Runner - uses: step-security/harden-runner@6c439dc8bdf85cadbbce9ed30d1c7b959517bc49 # v2.12.2 - with: - egress-policy: audit - - - name: Checkout - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - fetch-depth: 1 - - - name: Setup Go - uses: ./.github/actions/setup-go - - - name: Setup Terraform - uses: ./.github/actions/setup-tf - - - name: Download Test Cache - id: download-cache - uses: ./.github/actions/test-cache/download - with: - key-prefix: test-go-race-${{ runner.os }}-${{ runner.arch }} - - # We run race tests with reduced parallelism because they use more CPU and we were finding - # instances where tests appear to hang for multiple seconds, resulting in flaky tests when - # short timeouts are used. - # c.f. discussion on https://github.com/coder/coder/pull/15106 - - name: Run Tests - run: | - gotestsum --junitfile="gotests.xml" --packages="./..." --rerun-fails=2 --rerun-fails-abort-on-data-race -- -race -parallel 4 -p 4 - - - name: Upload Test Cache - uses: ./.github/actions/test-cache/upload - with: - cache-key: ${{ steps.download-cache.outputs.cache-key }} - - - name: Upload test stats to Datadog - timeout-minutes: 1 - continue-on-error: true - uses: ./.github/actions/upload-datadog - if: always() - with: - api-key: ${{ secrets.DATADOG_API_KEY }} - test-go-race-pg: runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-16' || 'ubuntu-latest' }} needs: changes @@ -741,7 +605,7 @@ jobs: POSTGRES_VERSION: "17" run: | make test-postgres-docker - DB=ci gotestsum --junitfile="gotests.xml" --packages="./..." --rerun-fails=2 --rerun-fails-abort-on-data-race -- -race -parallel 4 -p 4 + gotestsum --junitfile="gotests.xml" --packages="./..." -- -race -parallel 4 -p 4 - name: Upload Test Cache uses: ./.github/actions/test-cache/upload @@ -863,7 +727,6 @@ jobs: if: ${{ !matrix.variant.premium }} env: DEBUG: pw:api - CODER_E2E_TEST_RETRIES: 2 working-directory: site # Run all of the tests with a premium license @@ -873,7 +736,6 @@ jobs: DEBUG: pw:api CODER_E2E_LICENSE: ${{ secrets.CODER_E2E_LICENSE }} CODER_E2E_REQUIRE_PREMIUM_TESTS: "1" - CODER_E2E_TEST_RETRIES: 2 working-directory: site - name: Upload Playwright Failed Tests @@ -1037,9 +899,7 @@ jobs: - fmt - lint - gen - - test-go - test-go-pg - - test-go-race - test-go-race-pg - test-js - test-e2e @@ -1060,9 +920,7 @@ jobs: echo "- fmt: ${{ needs.fmt.result }}" echo "- lint: ${{ needs.lint.result }}" echo "- gen: ${{ needs.gen.result }}" - echo "- test-go: ${{ needs.test-go.result }}" echo "- test-go-pg: ${{ needs.test-go-pg.result }}" - echo "- test-go-race: ${{ needs.test-go-race.result }}" echo "- test-go-race-pg: ${{ needs.test-go-race-pg.result }}" echo "- test-js: ${{ needs.test-js.result }}" echo "- test-e2e: ${{ needs.test-e2e.result }}" @@ -1278,6 +1136,8 @@ jobs: # do (see above). CODER_SIGN_WINDOWS: "1" CODER_WINDOWS_RESOURCES: "1" + CODER_SIGN_GPG: "1" + CODER_GPG_RELEASE_KEY_BASE64: ${{ secrets.GPG_RELEASE_KEY_BASE64 }} EV_KEY: ${{ secrets.EV_KEY }} EV_KEYSTORE: ${{ secrets.EV_KEYSTORE }} EV_TSA_URL: ${{ secrets.EV_TSA_URL }} @@ -1545,7 +1405,7 @@ jobs: uses: google-github-actions/setup-gcloud@77e7a554d41e2ee56fc945c52dfd3f33d12def9a # v2.1.4 - name: Set up Flux CLI - uses: fluxcd/flux2/action@bda4c8187e436462be0d072e728b67afa215c593 # v2.6.3 + uses: fluxcd/flux2/action@6bf37f6a560fd84982d67f853162e4b3c2235edb # v2.6.4 with: # Keep this and the github action up to date with the version of flux installed in dogfood cluster version: "2.5.1" diff --git a/.github/workflows/docs-ci.yaml b/.github/workflows/docs-ci.yaml index f65ab434a9309..39954783f1ba8 100644 --- a/.github/workflows/docs-ci.yaml +++ b/.github/workflows/docs-ci.yaml @@ -28,7 +28,7 @@ jobs: - name: Setup Node uses: ./.github/actions/setup-node - - uses: tj-actions/changed-files@cf79a64fed8a943fb1073260883d08fe0dfb4e56 # v45.0.7 + - uses: tj-actions/changed-files@055970845dd036d7345da7399b7e89f2e10f2b04 # v45.0.7 id: changed-files with: files: | diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 5e793d81397dc..5a1faa9bd1528 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -323,6 +323,8 @@ jobs: env: CODER_SIGN_WINDOWS: "1" CODER_SIGN_DARWIN: "1" + CODER_SIGN_GPG: "1" + CODER_GPG_RELEASE_KEY_BASE64: ${{ secrets.GPG_RELEASE_KEY_BASE64 }} CODER_WINDOWS_RESOURCES: "1" AC_CERTIFICATE_FILE: /tmp/apple_cert.p12 AC_CERTIFICATE_PASSWORD_FILE: /tmp/apple_cert_password.txt @@ -693,6 +695,8 @@ jobs: gsutil -h "Cache-Control:no-cache,max-age=0" cp build/helm/provisioner_helm_${version}.tgz gs://helm.coder.com/v2 gsutil -h "Cache-Control:no-cache,max-age=0" cp build/helm/index.yaml gs://helm.coder.com/v2 gsutil -h "Cache-Control:no-cache,max-age=0" cp helm/artifacthub-repo.yml gs://helm.coder.com/v2 + helm push build/coder_helm_${version}.tgz oci://ghcr.io/coder/chart + helm push build/provisioner_helm_${version}.tgz oci://ghcr.io/coder/chart - name: Upload artifacts to actions (if dry-run) if: ${{ inputs.dry_run }} diff --git a/.github/workflows/start-workspace.yaml b/.github/workflows/start-workspace.yaml index 975acd7e1d939..9c1106a040a0e 100644 --- a/.github/workflows/start-workspace.yaml +++ b/.github/workflows/start-workspace.yaml @@ -19,7 +19,7 @@ jobs: timeout-minutes: 5 steps: - name: Start Coder workspace - uses: coder/start-workspace-action@35a4608cefc7e8cc56573cae7c3b85304575cb72 + uses: coder/start-workspace-action@f97a681b4cc7985c9eef9963750c7cc6ebc93a19 with: github-token: ${{ secrets.GITHUB_TOKEN }} github-username: >- diff --git a/.golangci.yaml b/.golangci.yaml index 2e1e853a0425a..aeebaf47e29a6 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -181,7 +181,6 @@ linters-settings: issues: exclude-dirs: - - coderd/database/dbmem - node_modules - .git diff --git a/CLAUDE.md b/CLAUDE.md index 4df2514a45863..d5335a6d4d0b3 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -44,7 +44,6 @@ 2. Run `make gen` 3. If audit errors: update `enterprise/audit/table.go` 4. Run `make gen` again -5. Update `coderd/database/dbmem/dbmem.go` in-memory implementations ### LSP Navigation (USE FIRST) @@ -116,9 +115,8 @@ app, err := api.Database.GetOAuth2ProviderAppByClientID(ctx, clientID) 1. **Audit table errors** → Update `enterprise/audit/table.go` 2. **OAuth2 errors** → Return RFC-compliant format -3. **dbmem failures** → Update in-memory implementations -4. **Race conditions** → Use unique test identifiers -5. **Missing newlines** → Ensure files end with newline +3. **Race conditions** → Use unique test identifiers +4. **Missing newlines** → Ensure files end with newline --- diff --git a/Makefile b/Makefile index 0ed464ba23a80..bd3f04a4874cd 100644 --- a/Makefile +++ b/Makefile @@ -252,6 +252,10 @@ $(CODER_ALL_BINARIES): go.mod go.sum \ fi cp "$@" "./site/out/bin/coder-$$os-$$arch$$dot_ext" + + if [[ "$${CODER_SIGN_GPG:-0}" == "1" ]]; then + cp "$@.asc" "./site/out/bin/coder-$$os-$$arch$$dot_ext.asc" + fi fi # This task builds Coder Desktop dylibs @@ -599,7 +603,6 @@ DB_GEN_FILES := \ coderd/database/dump.sql \ coderd/database/querier.go \ coderd/database/unique_constraint.go \ - coderd/database/dbmem/dbmem.go \ coderd/database/dbmetrics/dbmetrics.go \ coderd/database/dbauthz/dbauthz.go \ coderd/database/dbmock/dbmock.go @@ -973,7 +976,7 @@ sqlc-vet: test-postgres-docker test-postgres: test-postgres-docker # The postgres test is prone to failure, so we limit parallelism for # more consistent execution. - $(GIT_FLAGS) DB=ci gotestsum \ + $(GIT_FLAGS) gotestsum \ --junitfile="gotests.xml" \ --jsonfile="gotests.json" \ $(GOTESTSUM_RETRY_FLAGS) \ diff --git a/agent/agent.go b/agent/agent.go index 3c02b5f2790f0..75117769d8e2d 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -336,18 +336,16 @@ func (a *agent) init() { // will not report anywhere. a.scriptRunner.RegisterMetrics(a.prometheusRegistry) - if a.devcontainers { - containerAPIOpts := []agentcontainers.Option{ - agentcontainers.WithExecer(a.execer), - agentcontainers.WithCommandEnv(a.sshServer.CommandEnv), - agentcontainers.WithScriptLogger(func(logSourceID uuid.UUID) agentcontainers.ScriptLogger { - return a.logSender.GetScriptLogger(logSourceID) - }), - } - containerAPIOpts = append(containerAPIOpts, a.containerAPIOptions...) - - a.containerAPI = agentcontainers.NewAPI(a.logger.Named("containers"), containerAPIOpts...) + containerAPIOpts := []agentcontainers.Option{ + agentcontainers.WithExecer(a.execer), + agentcontainers.WithCommandEnv(a.sshServer.CommandEnv), + agentcontainers.WithScriptLogger(func(logSourceID uuid.UUID) agentcontainers.ScriptLogger { + return a.logSender.GetScriptLogger(logSourceID) + }), } + containerAPIOpts = append(containerAPIOpts, a.containerAPIOptions...) + + a.containerAPI = agentcontainers.NewAPI(a.logger.Named("containers"), containerAPIOpts...) a.reconnectingPTYServer = reconnectingpty.NewServer( a.logger.Named("reconnecting-pty"), @@ -1162,7 +1160,7 @@ func (a *agent) handleManifest(manifestOK *checkpoint) func(ctx context.Context, scripts = manifest.Scripts devcontainerScripts map[uuid.UUID]codersdk.WorkspaceAgentScript ) - if a.containerAPI != nil { + if a.devcontainers { // Init the container API with the manifest and client so that // we can start accepting requests. The final start of the API // happens after the startup scripts have been executed to @@ -1197,7 +1195,7 @@ func (a *agent) handleManifest(manifestOK *checkpoint) func(ctx context.Context, // autostarted devcontainer will be included in this time. err := a.scriptRunner.Execute(a.gracefulCtx, agentscripts.ExecuteStartScripts) - if a.containerAPI != nil { + if a.devcontainers { // Start the container API after the startup scripts have // been executed to ensure that the required tools can be // installed. @@ -1928,10 +1926,8 @@ func (a *agent) Close() error { a.logger.Error(a.hardCtx, "script runner close", slog.Error(err)) } - if a.containerAPI != nil { - if err := a.containerAPI.Close(); err != nil { - a.logger.Error(a.hardCtx, "container API close", slog.Error(err)) - } + if err := a.containerAPI.Close(); err != nil { + a.logger.Error(a.hardCtx, "container API close", slog.Error(err)) } // Wait for the graceful shutdown to complete, but don't wait forever so diff --git a/agent/agent_test.go b/agent/agent_test.go index 4a9141bd37f9e..d87148be9ad15 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -2441,7 +2441,8 @@ func TestAgent_DevcontainersDisabledForSubAgent(t *testing.T) { // Setup the agent with devcontainers enabled initially. //nolint:dogsled - conn, _, _, _, _ := setupAgent(t, manifest, 0, func(*agenttest.Client, *agent.Options) { + conn, _, _, _, _ := setupAgent(t, manifest, 0, func(_ *agenttest.Client, o *agent.Options) { + o.Devcontainers = true }) // Query the containers API endpoint. This should fail because @@ -2453,8 +2454,8 @@ func TestAgent_DevcontainersDisabledForSubAgent(t *testing.T) { require.Error(t, err) // Verify the error message contains the expected text. - require.Contains(t, err.Error(), "The agent dev containers feature is experimental and not enabled by default.") - require.Contains(t, err.Error(), "To enable this feature, set CODER_AGENT_DEVCONTAINERS_ENABLE=true in your template.") + require.Contains(t, err.Error(), "Dev Container feature not supported.") + require.Contains(t, err.Error(), "Dev Container integration inside other Dev Containers is explicitly not supported.") } func TestAgent_Dial(t *testing.T) { diff --git a/agent/agentcontainers/api.go b/agent/agentcontainers/api.go index d749bf88a522e..dc92a4d38d9a2 100644 --- a/agent/agentcontainers/api.go +++ b/agent/agentcontainers/api.go @@ -2,8 +2,10 @@ package agentcontainers import ( "context" + "encoding/json" "errors" "fmt" + "maps" "net/http" "os" "path" @@ -30,6 +32,7 @@ import ( "github.com/coder/coder/v2/codersdk/agentsdk" "github.com/coder/coder/v2/provisioner" "github.com/coder/quartz" + "github.com/coder/websocket" ) const ( @@ -74,6 +77,7 @@ type API struct { mu sync.RWMutex // Protects the following fields. initDone chan struct{} // Closed by Init. + updateChans []chan struct{} closed bool containers codersdk.WorkspaceAgentListContainersResponse // Output from the last list operation. containersErr error // Error from the last list operation. @@ -535,6 +539,7 @@ func (api *API) Routes() http.Handler { r.Use(ensureInitDoneMW) r.Get("/", api.handleList) + r.Get("/watch", api.watchContainers) // TODO(mafredri): Simplify this route as the previous /devcontainers // /-route was dropped. We can drop the /devcontainers prefix here too. r.Route("/devcontainers/{devcontainer}", func(r chi.Router) { @@ -544,6 +549,88 @@ func (api *API) Routes() http.Handler { return r } +func (api *API) broadcastUpdatesLocked() { + // Broadcast state changes to WebSocket listeners. + for _, ch := range api.updateChans { + select { + case ch <- struct{}{}: + default: + } + } +} + +func (api *API) watchContainers(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + conn, err := websocket.Accept(rw, r, nil) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to upgrade connection to websocket.", + Detail: err.Error(), + }) + return + } + + // Here we close the websocket for reading, so that the websocket library will handle pings and + // close frames. + _ = conn.CloseRead(context.Background()) + + ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageText) + defer wsNetConn.Close() + + go httpapi.Heartbeat(ctx, conn) + + updateCh := make(chan struct{}, 1) + + api.mu.Lock() + api.updateChans = append(api.updateChans, updateCh) + api.mu.Unlock() + + defer func() { + api.mu.Lock() + api.updateChans = slices.DeleteFunc(api.updateChans, func(ch chan struct{}) bool { + return ch == updateCh + }) + close(updateCh) + api.mu.Unlock() + }() + + encoder := json.NewEncoder(wsNetConn) + + ct, err := api.getContainers() + if err != nil { + api.logger.Error(ctx, "unable to get containers", slog.Error(err)) + return + } + + if err := encoder.Encode(ct); err != nil { + api.logger.Error(ctx, "encode container list", slog.Error(err)) + return + } + + for { + select { + case <-api.ctx.Done(): + return + + case <-ctx.Done(): + return + + case <-updateCh: + ct, err := api.getContainers() + if err != nil { + api.logger.Error(ctx, "unable to get containers", slog.Error(err)) + continue + } + + if err := encoder.Encode(ct); err != nil { + api.logger.Error(ctx, "encode container list", slog.Error(err)) + return + } + } + } +} + // handleList handles the HTTP request to list containers. func (api *API) handleList(rw http.ResponseWriter, r *http.Request) { ct, err := api.getContainers() @@ -583,8 +670,26 @@ func (api *API) updateContainers(ctx context.Context) error { api.mu.Lock() defer api.mu.Unlock() + var previouslyKnownDevcontainers map[string]codersdk.WorkspaceAgentDevcontainer + if len(api.updateChans) > 0 { + previouslyKnownDevcontainers = maps.Clone(api.knownDevcontainers) + } + api.processUpdatedContainersLocked(ctx, updated) + if len(api.updateChans) > 0 { + statesAreEqual := maps.EqualFunc( + previouslyKnownDevcontainers, + api.knownDevcontainers, + func(dc1, dc2 codersdk.WorkspaceAgentDevcontainer) bool { + return dc1.Equals(dc2) + }) + + if !statesAreEqual { + api.broadcastUpdatesLocked() + } + } + api.logger.Debug(ctx, "containers updated successfully", slog.F("container_count", len(api.containers.Containers)), slog.F("warning_count", len(api.containers.Warnings)), slog.F("devcontainer_count", len(api.knownDevcontainers))) return nil @@ -955,6 +1060,8 @@ func (api *API) handleDevcontainerRecreate(w http.ResponseWriter, r *http.Reques dc.Container = nil dc.Error = "" api.knownDevcontainers[dc.WorkspaceFolder] = dc + api.broadcastUpdatesLocked() + go func() { _ = api.CreateDevcontainer(dc.WorkspaceFolder, dc.ConfigPath, WithRemoveExistingContainer()) }() @@ -1070,6 +1177,7 @@ func (api *API) CreateDevcontainer(workspaceFolder, configPath string, opts ...D dc.Error = "" api.recreateSuccessTimes[dc.WorkspaceFolder] = api.clock.Now("agentcontainers", "recreate", "successTimes") api.knownDevcontainers[dc.WorkspaceFolder] = dc + api.broadcastUpdatesLocked() api.mu.Unlock() // Ensure an immediate refresh to accurately reflect the diff --git a/agent/agentcontainers/api_test.go b/agent/agentcontainers/api_test.go index 37ce66e2c150b..75b9342379a35 100644 --- a/agent/agentcontainers/api_test.go +++ b/agent/agentcontainers/api_test.go @@ -36,6 +36,7 @@ import ( "github.com/coder/coder/v2/pty" "github.com/coder/coder/v2/testutil" "github.com/coder/quartz" + "github.com/coder/websocket" ) // fakeContainerCLI implements the agentcontainers.ContainerCLI interface for @@ -441,6 +442,178 @@ func TestAPI(t *testing.T) { logbuf.Reset() }) + t.Run("Watch", func(t *testing.T) { + t.Parallel() + + fakeContainer1 := fakeContainer(t, func(c *codersdk.WorkspaceAgentContainer) { + c.ID = "container1" + c.FriendlyName = "devcontainer1" + c.Image = "busybox:latest" + c.Labels = map[string]string{ + agentcontainers.DevcontainerLocalFolderLabel: "/home/coder/project1", + agentcontainers.DevcontainerConfigFileLabel: "/home/coder/project1/.devcontainer/devcontainer.json", + } + }) + + fakeContainer2 := fakeContainer(t, func(c *codersdk.WorkspaceAgentContainer) { + c.ID = "container2" + c.FriendlyName = "devcontainer2" + c.Image = "ubuntu:latest" + c.Labels = map[string]string{ + agentcontainers.DevcontainerLocalFolderLabel: "/home/coder/project2", + agentcontainers.DevcontainerConfigFileLabel: "/home/coder/project2/.devcontainer/devcontainer.json", + } + }) + + stages := []struct { + containers []codersdk.WorkspaceAgentContainer + expected codersdk.WorkspaceAgentListContainersResponse + }{ + { + containers: []codersdk.WorkspaceAgentContainer{fakeContainer1}, + expected: codersdk.WorkspaceAgentListContainersResponse{ + Containers: []codersdk.WorkspaceAgentContainer{fakeContainer1}, + Devcontainers: []codersdk.WorkspaceAgentDevcontainer{ + { + Name: "project1", + WorkspaceFolder: fakeContainer1.Labels[agentcontainers.DevcontainerLocalFolderLabel], + ConfigPath: fakeContainer1.Labels[agentcontainers.DevcontainerConfigFileLabel], + Status: "running", + Container: &fakeContainer1, + }, + }, + }, + }, + { + containers: []codersdk.WorkspaceAgentContainer{fakeContainer1, fakeContainer2}, + expected: codersdk.WorkspaceAgentListContainersResponse{ + Containers: []codersdk.WorkspaceAgentContainer{fakeContainer1, fakeContainer2}, + Devcontainers: []codersdk.WorkspaceAgentDevcontainer{ + { + Name: "project1", + WorkspaceFolder: fakeContainer1.Labels[agentcontainers.DevcontainerLocalFolderLabel], + ConfigPath: fakeContainer1.Labels[agentcontainers.DevcontainerConfigFileLabel], + Status: "running", + Container: &fakeContainer1, + }, + { + Name: "project2", + WorkspaceFolder: fakeContainer2.Labels[agentcontainers.DevcontainerLocalFolderLabel], + ConfigPath: fakeContainer2.Labels[agentcontainers.DevcontainerConfigFileLabel], + Status: "running", + Container: &fakeContainer2, + }, + }, + }, + }, + { + containers: []codersdk.WorkspaceAgentContainer{fakeContainer2}, + expected: codersdk.WorkspaceAgentListContainersResponse{ + Containers: []codersdk.WorkspaceAgentContainer{fakeContainer2}, + Devcontainers: []codersdk.WorkspaceAgentDevcontainer{ + { + Name: "", + WorkspaceFolder: fakeContainer1.Labels[agentcontainers.DevcontainerLocalFolderLabel], + ConfigPath: fakeContainer1.Labels[agentcontainers.DevcontainerConfigFileLabel], + Status: "stopped", + Container: nil, + }, + { + Name: "project2", + WorkspaceFolder: fakeContainer2.Labels[agentcontainers.DevcontainerLocalFolderLabel], + ConfigPath: fakeContainer2.Labels[agentcontainers.DevcontainerConfigFileLabel], + Status: "running", + Container: &fakeContainer2, + }, + }, + }, + }, + } + + var ( + ctx = testutil.Context(t, testutil.WaitShort) + mClock = quartz.NewMock(t) + updaterTickerTrap = mClock.Trap().TickerFunc("updaterLoop") + mCtrl = gomock.NewController(t) + mLister = acmock.NewMockContainerCLI(mCtrl) + logger = slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + ) + + // Set up initial state for immediate send on connection + mLister.EXPECT().List(gomock.Any()).Return(codersdk.WorkspaceAgentListContainersResponse{Containers: stages[0].containers}, nil) + mLister.EXPECT().DetectArchitecture(gomock.Any(), gomock.Any()).Return("", nil).AnyTimes() + + api := agentcontainers.NewAPI(logger, + agentcontainers.WithClock(mClock), + agentcontainers.WithContainerCLI(mLister), + agentcontainers.WithWatcher(watcher.NewNoop()), + ) + api.Start() + defer api.Close() + + srv := httptest.NewServer(api.Routes()) + defer srv.Close() + + updaterTickerTrap.MustWait(ctx).MustRelease(ctx) + defer updaterTickerTrap.Close() + + client, res, err := websocket.Dial(ctx, srv.URL+"/watch", nil) + require.NoError(t, err) + if res != nil && res.Body != nil { + defer res.Body.Close() + } + + // Read initial state sent immediately on connection + mt, msg, err := client.Read(ctx) + require.NoError(t, err) + require.Equal(t, websocket.MessageText, mt) + + var got codersdk.WorkspaceAgentListContainersResponse + err = json.Unmarshal(msg, &got) + require.NoError(t, err) + + require.Equal(t, stages[0].expected.Containers, got.Containers) + require.Len(t, got.Devcontainers, len(stages[0].expected.Devcontainers)) + for j, expectedDev := range stages[0].expected.Devcontainers { + gotDev := got.Devcontainers[j] + require.Equal(t, expectedDev.Name, gotDev.Name) + require.Equal(t, expectedDev.WorkspaceFolder, gotDev.WorkspaceFolder) + require.Equal(t, expectedDev.ConfigPath, gotDev.ConfigPath) + require.Equal(t, expectedDev.Status, gotDev.Status) + require.Equal(t, expectedDev.Container, gotDev.Container) + } + + // Process remaining stages through updater loop + for i, stage := range stages[1:] { + mLister.EXPECT().List(gomock.Any()).Return(codersdk.WorkspaceAgentListContainersResponse{Containers: stage.containers}, nil) + + // Given: We allow the update loop to progress + _, aw := mClock.AdvanceNext() + aw.MustWait(ctx) + + // When: We attempt to read a message from the socket. + mt, msg, err := client.Read(ctx) + require.NoError(t, err) + require.Equal(t, websocket.MessageText, mt) + + // Then: We expect the receieved message matches the expected response. + var got codersdk.WorkspaceAgentListContainersResponse + err = json.Unmarshal(msg, &got) + require.NoError(t, err) + + require.Equal(t, stages[i+1].expected.Containers, got.Containers) + require.Len(t, got.Devcontainers, len(stages[i+1].expected.Devcontainers)) + for j, expectedDev := range stages[i+1].expected.Devcontainers { + gotDev := got.Devcontainers[j] + require.Equal(t, expectedDev.Name, gotDev.Name) + require.Equal(t, expectedDev.WorkspaceFolder, gotDev.WorkspaceFolder) + require.Equal(t, expectedDev.ConfigPath, gotDev.ConfigPath) + require.Equal(t, expectedDev.Status, gotDev.Status) + require.Equal(t, expectedDev.Container, gotDev.Container) + } + } + }) + // List tests the API.getContainers method using a mock // implementation. It specifically tests caching behavior. t.Run("List", func(t *testing.T) { diff --git a/agent/agentcontainers/devcontainercli.go b/agent/agentcontainers/devcontainercli.go index 55e4708d46134..d7cd25f85a7b3 100644 --- a/agent/agentcontainers/devcontainercli.go +++ b/agent/agentcontainers/devcontainercli.go @@ -106,63 +106,63 @@ type DevcontainerCLI interface { // DevcontainerCLIUpOptions are options for the devcontainer CLI Up // command. -type DevcontainerCLIUpOptions func(*devcontainerCLIUpConfig) +type DevcontainerCLIUpOptions func(*DevcontainerCLIUpConfig) -type devcontainerCLIUpConfig struct { - args []string // Additional arguments for the Up command. - stdout io.Writer - stderr io.Writer +type DevcontainerCLIUpConfig struct { + Args []string // Additional arguments for the Up command. + Stdout io.Writer + Stderr io.Writer } // WithRemoveExistingContainer is an option to remove the existing // container. func WithRemoveExistingContainer() DevcontainerCLIUpOptions { - return func(o *devcontainerCLIUpConfig) { - o.args = append(o.args, "--remove-existing-container") + return func(o *DevcontainerCLIUpConfig) { + o.Args = append(o.Args, "--remove-existing-container") } } // WithUpOutput sets additional stdout and stderr writers for logs // during Up operations. func WithUpOutput(stdout, stderr io.Writer) DevcontainerCLIUpOptions { - return func(o *devcontainerCLIUpConfig) { - o.stdout = stdout - o.stderr = stderr + return func(o *DevcontainerCLIUpConfig) { + o.Stdout = stdout + o.Stderr = stderr } } // DevcontainerCLIExecOptions are options for the devcontainer CLI Exec // command. -type DevcontainerCLIExecOptions func(*devcontainerCLIExecConfig) +type DevcontainerCLIExecOptions func(*DevcontainerCLIExecConfig) -type devcontainerCLIExecConfig struct { - args []string // Additional arguments for the Exec command. - stdout io.Writer - stderr io.Writer +type DevcontainerCLIExecConfig struct { + Args []string // Additional arguments for the Exec command. + Stdout io.Writer + Stderr io.Writer } // WithExecOutput sets additional stdout and stderr writers for logs // during Exec operations. func WithExecOutput(stdout, stderr io.Writer) DevcontainerCLIExecOptions { - return func(o *devcontainerCLIExecConfig) { - o.stdout = stdout - o.stderr = stderr + return func(o *DevcontainerCLIExecConfig) { + o.Stdout = stdout + o.Stderr = stderr } } // WithExecContainerID sets the container ID to target a specific // container. func WithExecContainerID(id string) DevcontainerCLIExecOptions { - return func(o *devcontainerCLIExecConfig) { - o.args = append(o.args, "--container-id", id) + return func(o *DevcontainerCLIExecConfig) { + o.Args = append(o.Args, "--container-id", id) } } // WithRemoteEnv sets environment variables for the Exec command. func WithRemoteEnv(env ...string) DevcontainerCLIExecOptions { - return func(o *devcontainerCLIExecConfig) { + return func(o *DevcontainerCLIExecConfig) { for _, e := range env { - o.args = append(o.args, "--remote-env", e) + o.Args = append(o.Args, "--remote-env", e) } } } @@ -185,8 +185,8 @@ func WithReadConfigOutput(stdout, stderr io.Writer) DevcontainerCLIReadConfigOpt } } -func applyDevcontainerCLIUpOptions(opts []DevcontainerCLIUpOptions) devcontainerCLIUpConfig { - conf := devcontainerCLIUpConfig{stdout: io.Discard, stderr: io.Discard} +func applyDevcontainerCLIUpOptions(opts []DevcontainerCLIUpOptions) DevcontainerCLIUpConfig { + conf := DevcontainerCLIUpConfig{Stdout: io.Discard, Stderr: io.Discard} for _, opt := range opts { if opt != nil { opt(&conf) @@ -195,8 +195,8 @@ func applyDevcontainerCLIUpOptions(opts []DevcontainerCLIUpOptions) devcontainer return conf } -func applyDevcontainerCLIExecOptions(opts []DevcontainerCLIExecOptions) devcontainerCLIExecConfig { - conf := devcontainerCLIExecConfig{stdout: io.Discard, stderr: io.Discard} +func applyDevcontainerCLIExecOptions(opts []DevcontainerCLIExecOptions) DevcontainerCLIExecConfig { + conf := DevcontainerCLIExecConfig{Stdout: io.Discard, Stderr: io.Discard} for _, opt := range opts { if opt != nil { opt(&conf) @@ -241,7 +241,7 @@ func (d *devcontainerCLI) Up(ctx context.Context, workspaceFolder, configPath st if configPath != "" { args = append(args, "--config", configPath) } - args = append(args, conf.args...) + args = append(args, conf.Args...) cmd := d.execer.CommandContext(ctx, "devcontainer", args...) // Capture stdout for parsing and stream logs for both default and provided writers. @@ -251,14 +251,14 @@ func (d *devcontainerCLI) Up(ctx context.Context, workspaceFolder, configPath st &devcontainerCLILogWriter{ ctx: ctx, logger: logger.With(slog.F("stdout", true)), - writer: conf.stdout, + writer: conf.Stdout, }, ) // Stream stderr logs and provided writer if any. cmd.Stderr = &devcontainerCLILogWriter{ ctx: ctx, logger: logger.With(slog.F("stderr", true)), - writer: conf.stderr, + writer: conf.Stderr, } if err := cmd.Run(); err != nil { @@ -293,17 +293,17 @@ func (d *devcontainerCLI) Exec(ctx context.Context, workspaceFolder, configPath if configPath != "" { args = append(args, "--config", configPath) } - args = append(args, conf.args...) + args = append(args, conf.Args...) args = append(args, cmd) args = append(args, cmdArgs...) c := d.execer.CommandContext(ctx, "devcontainer", args...) - c.Stdout = io.MultiWriter(conf.stdout, &devcontainerCLILogWriter{ + c.Stdout = io.MultiWriter(conf.Stdout, &devcontainerCLILogWriter{ ctx: ctx, logger: logger.With(slog.F("stdout", true)), writer: io.Discard, }) - c.Stderr = io.MultiWriter(conf.stderr, &devcontainerCLILogWriter{ + c.Stderr = io.MultiWriter(conf.Stderr, &devcontainerCLILogWriter{ ctx: ctx, logger: logger.With(slog.F("stderr", true)), writer: io.Discard, diff --git a/agent/agentssh/agentssh_test.go b/agent/agentssh/agentssh_test.go index 08fa02ddb4565..159fe345483d2 100644 --- a/agent/agentssh/agentssh_test.go +++ b/agent/agentssh/agentssh_test.go @@ -453,7 +453,7 @@ func TestSSHServer_ClosesStdin(t *testing.T) { // exit code 1 if it hits EOF, which is what we want to test. cmdErrCh := make(chan error, 1) go func() { - cmdErrCh <- sess.Start(fmt.Sprintf("echo started; read; echo \"read exit code: $?\" > %s", filePath)) + cmdErrCh <- sess.Start(fmt.Sprintf(`echo started; echo "read exit code: $(read && echo 0 || echo 1)" > %s`, filePath)) }() cmdErr := testutil.RequireReceive(ctx, t, cmdErrCh) diff --git a/agent/api.go b/agent/api.go index 0458df7c58e1f..ca0760e130ffe 100644 --- a/agent/api.go +++ b/agent/api.go @@ -6,6 +6,7 @@ import ( "time" "github.com/go-chi/chi/v5" + "github.com/google/uuid" "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/codersdk" @@ -36,12 +37,19 @@ func (a *agent) apiHandler() http.Handler { cacheDuration: cacheDuration, } - if a.containerAPI != nil { + if a.devcontainers { r.Mount("/api/v0/containers", a.containerAPI.Routes()) + } else if manifest := a.manifest.Load(); manifest != nil && manifest.ParentID != uuid.Nil { + r.HandleFunc("/api/v0/containers", func(w http.ResponseWriter, r *http.Request) { + httpapi.Write(r.Context(), w, http.StatusForbidden, codersdk.Response{ + Message: "Dev Container feature not supported.", + Detail: "Dev Container integration inside other Dev Containers is explicitly not supported.", + }) + }) } else { r.HandleFunc("/api/v0/containers", func(w http.ResponseWriter, r *http.Request) { httpapi.Write(r.Context(), w, http.StatusForbidden, codersdk.Response{ - Message: "The agent dev containers feature is experimental and not enabled by default.", + Message: "Dev Container feature not enabled.", Detail: "To enable this feature, set CODER_AGENT_DEVCONTAINERS_ENABLE=true in your template.", }) }) diff --git a/cli/cliui/resources.go b/cli/cliui/resources.go index be112ea177200..36ce4194d72c8 100644 --- a/cli/cliui/resources.go +++ b/cli/cliui/resources.go @@ -12,6 +12,7 @@ import ( "golang.org/x/mod/semver" "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/coder/v2/coderd/util/slice" "github.com/coder/coder/v2/codersdk" "github.com/coder/pretty" ) @@ -29,6 +30,7 @@ type WorkspaceResourcesOptions struct { ServerVersion string ListeningPorts map[uuid.UUID]codersdk.WorkspaceAgentListeningPortsResponse Devcontainers map[uuid.UUID]codersdk.WorkspaceAgentListContainersResponse + ShowDetails bool } // WorkspaceResources displays the connection status and tree-view of provided resources. @@ -69,7 +71,11 @@ func WorkspaceResources(writer io.Writer, resources []codersdk.WorkspaceResource totalAgents := 0 for _, resource := range resources { - totalAgents += len(resource.Agents) + for _, agent := range resource.Agents { + if !agent.ParentID.Valid { + totalAgents++ + } + } } for _, resource := range resources { @@ -94,12 +100,15 @@ func WorkspaceResources(writer io.Writer, resources []codersdk.WorkspaceResource "", }) // Display all agents associated with the resource. - for index, agent := range resource.Agents { + agents := slice.Filter(resource.Agents, func(agent codersdk.WorkspaceAgent) bool { + return !agent.ParentID.Valid + }) + for index, agent := range agents { tableWriter.AppendRow(renderAgentRow(agent, index, totalAgents, options)) for _, row := range renderListeningPorts(options, agent.ID, index, totalAgents) { tableWriter.AppendRow(row) } - for _, row := range renderDevcontainers(options, agent.ID, index, totalAgents) { + for _, row := range renderDevcontainers(resources, options, agent.ID, index, totalAgents) { tableWriter.AppendRow(row) } } @@ -125,7 +134,7 @@ func renderAgentRow(agent codersdk.WorkspaceAgent, index, totalAgents int, optio } if !options.HideAccess { sshCommand := "coder ssh " + options.WorkspaceName - if totalAgents > 1 { + if totalAgents > 1 || len(options.Devcontainers) > 0 { sshCommand += "." + agent.Name } sshCommand = pretty.Sprint(DefaultStyles.Code, sshCommand) @@ -164,45 +173,129 @@ func renderPortRow(port codersdk.WorkspaceAgentListeningPort, idx, total int) ta return table.Row{sb.String()} } -func renderDevcontainers(wro WorkspaceResourcesOptions, agentID uuid.UUID, index, totalAgents int) []table.Row { +func renderDevcontainers(resources []codersdk.WorkspaceResource, wro WorkspaceResourcesOptions, agentID uuid.UUID, index, totalAgents int) []table.Row { var rows []table.Row if wro.Devcontainers == nil { return []table.Row{} } dc, ok := wro.Devcontainers[agentID] - if !ok || len(dc.Containers) == 0 { + if !ok || len(dc.Devcontainers) == 0 { return []table.Row{} } rows = append(rows, table.Row{ fmt.Sprintf(" %s─ %s", renderPipe(index, totalAgents), "Devcontainers"), }) - for idx, container := range dc.Containers { - rows = append(rows, renderDevcontainerRow(container, idx, len(dc.Containers))) + for idx, devcontainer := range dc.Devcontainers { + rows = append(rows, renderDevcontainerRow(resources, devcontainer, idx, len(dc.Devcontainers), wro)...) } return rows } -func renderDevcontainerRow(container codersdk.WorkspaceAgentContainer, index, total int) table.Row { - var row table.Row - var sb strings.Builder - _, _ = sb.WriteString(" ") - _, _ = sb.WriteString(renderPipe(index, total)) - _, _ = sb.WriteString("─ ") - _, _ = sb.WriteString(pretty.Sprintf(DefaultStyles.Code, "%s", container.FriendlyName)) - row = append(row, sb.String()) - sb.Reset() - if container.Running { - _, _ = sb.WriteString(pretty.Sprintf(DefaultStyles.Keyword, "(%s)", container.Status)) - } else { - _, _ = sb.WriteString(pretty.Sprintf(DefaultStyles.Error, "(%s)", container.Status)) +func renderDevcontainerRow(resources []codersdk.WorkspaceResource, devcontainer codersdk.WorkspaceAgentDevcontainer, index, total int, wro WorkspaceResourcesOptions) []table.Row { + var rows []table.Row + + // If the devcontainer is running and has an associated agent, we want to + // display the agent's details. Otherwise, we just display the devcontainer + // name and status. + var subAgent *codersdk.WorkspaceAgent + displayName := devcontainer.Name + if devcontainer.Agent != nil && devcontainer.Status == codersdk.WorkspaceAgentDevcontainerStatusRunning { + for _, resource := range resources { + if agent, found := slice.Find(resource.Agents, func(agent codersdk.WorkspaceAgent) bool { + return agent.ID == devcontainer.Agent.ID + }); found { + subAgent = &agent + break + } + } + if subAgent != nil { + displayName = subAgent.Name + displayName += fmt.Sprintf(" (%s, %s)", subAgent.OperatingSystem, subAgent.Architecture) + } + } + + if devcontainer.Container != nil { + displayName += " " + pretty.Sprint(DefaultStyles.Keyword, "["+devcontainer.Container.FriendlyName+"]") + } + + // Build the main row. + row := table.Row{ + fmt.Sprintf(" %s─ %s", renderPipe(index, total), displayName), + } + + // Add status, health, and version columns. + if !wro.HideAgentState { + if subAgent != nil { + row = append(row, renderAgentStatus(*subAgent)) + row = append(row, renderAgentHealth(*subAgent)) + row = append(row, renderAgentVersion(subAgent.Version, wro.ServerVersion)) + } else { + row = append(row, renderDevcontainerStatus(devcontainer.Status)) + row = append(row, "") // No health for devcontainer without agent. + row = append(row, "") // No version for devcontainer without agent. + } + } + + // Add access column. + if !wro.HideAccess { + if subAgent != nil { + accessString := fmt.Sprintf("coder ssh %s.%s", wro.WorkspaceName, subAgent.Name) + row = append(row, pretty.Sprint(DefaultStyles.Code, accessString)) + } else { + row = append(row, "") // No access for devcontainers without agent. + } + } + + rows = append(rows, row) + + // Add error message if present. + if errorMessage := devcontainer.Error; errorMessage != "" { + // Cap error message length for display. + if !wro.ShowDetails && len(errorMessage) > 80 { + errorMessage = errorMessage[:79] + "…" + } + errorRow := table.Row{ + " × " + pretty.Sprint(DefaultStyles.Error, errorMessage), + "", + "", + "", + } + if !wro.HideAccess { + errorRow = append(errorRow, "") + } + rows = append(rows, errorRow) + } + + // Add listening ports for the devcontainer agent. + if subAgent != nil { + portRows := renderListeningPorts(wro, subAgent.ID, index, total) + for _, portRow := range portRows { + // Adjust indentation for ports under devcontainer agent. + if len(portRow) > 0 { + if str, ok := portRow[0].(string); ok { + portRow[0] = " " + str // Add extra indentation. + } + } + rows = append(rows, portRow) + } + } + + return rows +} + +func renderDevcontainerStatus(status codersdk.WorkspaceAgentDevcontainerStatus) string { + switch status { + case codersdk.WorkspaceAgentDevcontainerStatusRunning: + return pretty.Sprint(DefaultStyles.Keyword, "▶ running") + case codersdk.WorkspaceAgentDevcontainerStatusStopped: + return pretty.Sprint(DefaultStyles.Placeholder, "⏹ stopped") + case codersdk.WorkspaceAgentDevcontainerStatusStarting: + return pretty.Sprint(DefaultStyles.Warn, "⧗ starting") + case codersdk.WorkspaceAgentDevcontainerStatusError: + return pretty.Sprint(DefaultStyles.Error, "✘ error") + default: + return pretty.Sprint(DefaultStyles.Placeholder, "○ "+string(status)) } - row = append(row, sb.String()) - sb.Reset() - // "health" is not applicable here. - row = append(row, sb.String()) - _, _ = sb.WriteString(container.Image) - row = append(row, sb.String()) - return row } func renderAgentStatus(agent codersdk.WorkspaceAgent) string { diff --git a/cli/exp_rpty.go b/cli/exp_rpty.go index 48074c7ef5fb9..70154c57ea9bc 100644 --- a/cli/exp_rpty.go +++ b/cli/exp_rpty.go @@ -97,7 +97,7 @@ func handleRPTY(inv *serpent.Invocation, client *codersdk.Client, args handleRPT reconnectID = uuid.New() } - ws, agt, err := getWorkspaceAndAgent(ctx, inv, client, true, args.NamedWorkspace) + ws, agt, _, err := getWorkspaceAndAgent(ctx, inv, client, true, args.NamedWorkspace) if err != nil { return err } diff --git a/cli/open.go b/cli/open.go index ff950b552a853..cc21ea863430d 100644 --- a/cli/open.go +++ b/cli/open.go @@ -11,7 +11,9 @@ import ( "runtime" "slices" "strings" + "time" + "github.com/google/uuid" "github.com/skratchdot/open-golang/open" "golang.org/x/xerrors" @@ -42,7 +44,6 @@ func (r *RootCmd) openVSCode() *serpent.Command { generateToken bool testOpenError bool appearanceConfig codersdk.AppearanceConfig - containerName string ) client := new(codersdk.Client) @@ -71,7 +72,7 @@ func (r *RootCmd) openVSCode() *serpent.Command { // need to wait for the agent to start. workspaceQuery := inv.Args[0] autostart := true - workspace, workspaceAgent, err := getWorkspaceAndAgent(ctx, inv, client, autostart, workspaceQuery) + workspace, workspaceAgent, otherWorkspaceAgents, err := getWorkspaceAndAgent(ctx, inv, client, autostart, workspaceQuery) if err != nil { return xerrors.Errorf("get workspace and agent: %w", err) } @@ -79,6 +80,70 @@ func (r *RootCmd) openVSCode() *serpent.Command { workspaceName := workspace.Name + "." + workspaceAgent.Name insideThisWorkspace := insideAWorkspace && inWorkspaceName == workspaceName + // To properly work with devcontainers, VS Code has to connect to + // parent workspace agent. It will then proceed to enter the + // container given the correct parameters. There is inherently no + // dependency on the devcontainer agent in this scenario, but + // relying on it simplifies the logic and ensures the devcontainer + // is ready. To eliminate the dependency we would need to know that + // a sub-agent that hasn't been created yet may be a devcontainer, + // and thus will be created at a later time as well as expose the + // container folder on the API response. + var parentWorkspaceAgent codersdk.WorkspaceAgent + var devcontainer codersdk.WorkspaceAgentDevcontainer + if workspaceAgent.ParentID.Valid { + // This is likely a devcontainer agent, so we need to find the + // parent workspace agent as well as the devcontainer. + for _, otherAgent := range otherWorkspaceAgents { + if otherAgent.ID == workspaceAgent.ParentID.UUID { + parentWorkspaceAgent = otherAgent + break + } + } + if parentWorkspaceAgent.ID == uuid.Nil { + return xerrors.Errorf("parent workspace agent %s not found", workspaceAgent.ParentID.UUID) + } + + printedWaiting := false + for { + resp, err := client.WorkspaceAgentListContainers(ctx, parentWorkspaceAgent.ID, nil) + if err != nil { + return xerrors.Errorf("list parent workspace agent containers: %w", err) + } + + for _, dc := range resp.Devcontainers { + if dc.Agent.ID == workspaceAgent.ID { + devcontainer = dc + break + } + } + if devcontainer.ID == uuid.Nil { + cliui.Warnf(inv.Stderr, "Devcontainer %q not found, opening as a regular workspace...", workspaceAgent.Name) + parentWorkspaceAgent = codersdk.WorkspaceAgent{} // Reset to empty, so we don't use it later. + break + } + + // Precondition, the devcontainer must be running to enter + // it. Once running, devcontainer.Container will be set. + if devcontainer.Status == codersdk.WorkspaceAgentDevcontainerStatusRunning { + break + } + if devcontainer.Status != codersdk.WorkspaceAgentDevcontainerStatusStarting { + return xerrors.Errorf("devcontainer %q is in unexpected status %q, expected %q or %q", + devcontainer.Name, devcontainer.Status, + codersdk.WorkspaceAgentDevcontainerStatusRunning, + codersdk.WorkspaceAgentDevcontainerStatusStarting, + ) + } + + if !printedWaiting { + _, _ = fmt.Fprintf(inv.Stderr, "Waiting for devcontainer %q status to change from %q to %q...\n", devcontainer.Name, devcontainer.Status, codersdk.WorkspaceAgentDevcontainerStatusRunning) + printedWaiting = true + } + time.Sleep(5 * time.Second) // Wait a bit before retrying. + } + } + if !insideThisWorkspace { // Wait for the agent to connect, we don't care about readiness // otherwise (e.g. wait). @@ -99,6 +164,9 @@ func (r *RootCmd) openVSCode() *serpent.Command { // the created state, so we need to wait for that to happen. // However, if no directory is set, the expanded directory will // not be set either. + // + // Note that this is irrelevant for devcontainer sub agents, as + // they always have a directory set. if workspaceAgent.Directory != "" { workspace, workspaceAgent, err = waitForAgentCond(ctx, client, workspace, workspaceAgent, func(_ codersdk.WorkspaceAgent) bool { return workspaceAgent.LifecycleState != codersdk.WorkspaceAgentLifecycleCreated @@ -114,41 +182,6 @@ func (r *RootCmd) openVSCode() *serpent.Command { directory = inv.Args[1] } - if containerName != "" { - containers, err := client.WorkspaceAgentListContainers(ctx, workspaceAgent.ID, map[string]string{"devcontainer.local_folder": ""}) - if err != nil { - return xerrors.Errorf("list workspace agent containers: %w", err) - } - - var foundContainer bool - - for _, container := range containers.Containers { - if container.FriendlyName != containerName { - continue - } - - foundContainer = true - - if directory == "" { - localFolder, ok := container.Labels["devcontainer.local_folder"] - if !ok { - return xerrors.New("container missing `devcontainer.local_folder` label") - } - - directory, ok = container.Volumes[localFolder] - if !ok { - return xerrors.New("container missing volume for `devcontainer.local_folder`") - } - } - - break - } - - if !foundContainer { - return xerrors.New("no container found") - } - } - directory, err = resolveAgentAbsPath(workspaceAgent.ExpandedDirectory, directory, workspaceAgent.OperatingSystem, insideThisWorkspace) if err != nil { return xerrors.Errorf("resolve agent path: %w", err) @@ -174,14 +207,16 @@ func (r *RootCmd) openVSCode() *serpent.Command { u *url.URL qp url.Values ) - if containerName != "" { + if devcontainer.ID != uuid.Nil { u, qp = buildVSCodeWorkspaceDevContainerLink( token, client.URL.String(), workspace, - workspaceAgent, - containerName, + parentWorkspaceAgent, + devcontainer.Container.FriendlyName, directory, + devcontainer.WorkspaceFolder, + devcontainer.ConfigPath, ) } else { u, qp = buildVSCodeWorkspaceLink( @@ -247,13 +282,6 @@ func (r *RootCmd) openVSCode() *serpent.Command { ), Value: serpent.BoolOf(&generateToken), }, - { - Flag: "container", - FlagShorthand: "c", - Description: "Container name to connect to in the workspace.", - Value: serpent.StringOf(&containerName), - Hidden: true, // Hidden until this features is at least in beta. - }, { Flag: "test.open-error", Description: "Don't run the open command.", @@ -288,7 +316,7 @@ func (r *RootCmd) openApp() *serpent.Command { } workspaceName := inv.Args[0] - ws, agt, err := getWorkspaceAndAgent(ctx, inv, client, false, workspaceName) + ws, agt, _, err := getWorkspaceAndAgent(ctx, inv, client, false, workspaceName) if err != nil { var sdkErr *codersdk.Error if errors.As(err, &sdkErr) && sdkErr.StatusCode() == http.StatusNotFound { @@ -430,8 +458,14 @@ func buildVSCodeWorkspaceDevContainerLink( workspaceAgent codersdk.WorkspaceAgent, containerName string, containerFolder string, + localWorkspaceFolder string, + localConfigFile string, ) (*url.URL, url.Values) { containerFolder = filepath.ToSlash(containerFolder) + localWorkspaceFolder = filepath.ToSlash(localWorkspaceFolder) + if localConfigFile != "" { + localConfigFile = filepath.ToSlash(localConfigFile) + } qp := url.Values{} qp.Add("url", clientURL) @@ -440,6 +474,8 @@ func buildVSCodeWorkspaceDevContainerLink( qp.Add("agent", workspaceAgent.Name) qp.Add("devContainerName", containerName) qp.Add("devContainerFolder", containerFolder) + qp.Add("localWorkspaceFolder", localWorkspaceFolder) + qp.Add("localConfigFile", localConfigFile) if token != "" { qp.Add("token", token) @@ -469,7 +505,7 @@ func waitForAgentCond(ctx context.Context, client *codersdk.Client, workspace co } for workspace = range wc { - workspaceAgent, err = getWorkspaceAgent(workspace, workspaceAgent.Name) + workspaceAgent, _, err = getWorkspaceAgent(workspace, workspaceAgent.Name) if err != nil { return workspace, workspaceAgent, xerrors.Errorf("get workspace agent: %w", err) } diff --git a/cli/open_test.go b/cli/open_test.go index b76b603d35b1e..e8d4aa3e65b2e 100644 --- a/cli/open_test.go +++ b/cli/open_test.go @@ -1,8 +1,10 @@ package cli_test import ( + "context" "net/url" "os" + "path" "path/filepath" "runtime" "strings" @@ -11,11 +13,11 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "go.uber.org/mock/gomock" + "golang.org/x/xerrors" "github.com/coder/coder/v2/agent" "github.com/coder/coder/v2/agent/agentcontainers" - "github.com/coder/coder/v2/agent/agentcontainers/acmock" + "github.com/coder/coder/v2/agent/agentcontainers/watcher" "github.com/coder/coder/v2/agent/agenttest" "github.com/coder/coder/v2/cli/clitest" "github.com/coder/coder/v2/coderd/coderdtest" @@ -289,238 +291,145 @@ func TestOpenVSCode_NoAgentDirectory(t *testing.T) { } } -func TestOpenVSCodeDevContainer(t *testing.T) { - t.Parallel() +type fakeContainerCLI struct { + resp codersdk.WorkspaceAgentListContainersResponse +} - if runtime.GOOS != "linux" { - t.Skip("DevContainers are only supported for agents on Linux") - } +func (f *fakeContainerCLI) List(ctx context.Context) (codersdk.WorkspaceAgentListContainersResponse, error) { + return f.resp, nil +} - agentName := "agent1" - agentDir, err := filepath.Abs(filepath.FromSlash("/tmp")) - require.NoError(t, err) +func (*fakeContainerCLI) DetectArchitecture(ctx context.Context, containerID string) (string, error) { + return runtime.GOARCH, nil +} - containerName := testutil.GetRandomName(t) - containerFolder := "/workspace/coder" +func (*fakeContainerCLI) Copy(ctx context.Context, containerID, src, dst string) error { + return nil +} - ctrl := gomock.NewController(t) - mccli := acmock.NewMockContainerCLI(ctrl) - mccli.EXPECT().List(gomock.Any()).Return( - codersdk.WorkspaceAgentListContainersResponse{ - Containers: []codersdk.WorkspaceAgentContainer{ - { - ID: uuid.NewString(), - CreatedAt: dbtime.Now(), - FriendlyName: containerName, - Image: "busybox:latest", - Labels: map[string]string{ - "devcontainer.local_folder": "/home/coder/coder", - }, - Running: true, - Status: "running", - Volumes: map[string]string{ - "/home/coder/coder": containerFolder, - }, - }, - }, - }, nil, - ).AnyTimes() +func (*fakeContainerCLI) ExecAs(ctx context.Context, containerID, user string, args ...string) ([]byte, error) { + return nil, nil +} - client, workspace, agentToken := setupWorkspaceForAgent(t, func(agents []*proto.Agent) []*proto.Agent { - agents[0].Directory = agentDir - agents[0].Name = agentName - agents[0].OperatingSystem = runtime.GOOS - return agents - }) +type fakeDevcontainerCLI struct { + config agentcontainers.DevcontainerConfig + execAgent func(ctx context.Context, token string) error +} - _ = agenttest.New(t, client.URL, agentToken, func(o *agent.Options) { - o.Devcontainers = true - o.DevcontainerAPIOptions = append(o.DevcontainerAPIOptions, - agentcontainers.WithContainerCLI(mccli), - agentcontainers.WithContainerLabelIncludeFilter("this.label.does.not.exist.ignore.devcontainers", "true"), - ) - }) - _ = coderdtest.NewWorkspaceAgentWaiter(t, client, workspace.ID).Wait() +func (f *fakeDevcontainerCLI) ReadConfig(ctx context.Context, workspaceFolder, configFile string, env []string, opts ...agentcontainers.DevcontainerCLIReadConfigOptions) (agentcontainers.DevcontainerConfig, error) { + return f.config, nil +} - insideWorkspaceEnv := map[string]string{ - "CODER": "true", - "CODER_WORKSPACE_NAME": workspace.Name, - "CODER_WORKSPACE_AGENT_NAME": agentName, +func (f *fakeDevcontainerCLI) Exec(ctx context.Context, workspaceFolder, configFile string, name string, args []string, opts ...agentcontainers.DevcontainerCLIExecOptions) error { + var opt agentcontainers.DevcontainerCLIExecConfig + for _, o := range opts { + o(&opt) } - - wd, err := os.Getwd() - require.NoError(t, err) - - tests := []struct { - name string - env map[string]string - args []string - wantDir string - wantError bool - wantToken bool - }{ - { - name: "nonexistent container", - args: []string{"--test.open-error", workspace.Name, "--container", containerName + "bad"}, - wantError: true, - }, - { - name: "ok", - args: []string{"--test.open-error", workspace.Name, "--container", containerName}, - wantDir: containerFolder, - wantError: false, - }, - { - name: "ok with absolute path", - args: []string{"--test.open-error", workspace.Name, "--container", containerName, containerFolder}, - wantDir: containerFolder, - wantError: false, - }, - { - name: "ok with relative path", - args: []string{"--test.open-error", workspace.Name, "--container", containerName, "my/relative/path"}, - wantDir: filepath.Join(agentDir, filepath.FromSlash("my/relative/path")), - wantError: false, - }, - { - name: "ok with token", - args: []string{"--test.open-error", workspace.Name, "--container", containerName, "--generate-token"}, - wantDir: containerFolder, - wantError: false, - wantToken: true, - }, - // Inside workspace, does not require --test.open-error - { - name: "ok inside workspace", - env: insideWorkspaceEnv, - args: []string{workspace.Name, "--container", containerName}, - wantDir: containerFolder, - }, - { - name: "ok inside workspace relative path", - env: insideWorkspaceEnv, - args: []string{workspace.Name, "--container", containerName, "foo"}, - wantDir: filepath.Join(wd, "foo"), - }, - { - name: "ok inside workspace token", - env: insideWorkspaceEnv, - args: []string{workspace.Name, "--container", containerName, "--generate-token"}, - wantDir: containerFolder, - wantToken: true, - }, + var token string + for _, arg := range opt.Args { + if strings.HasPrefix(arg, "CODER_AGENT_TOKEN=") { + token = strings.TrimPrefix(arg, "CODER_AGENT_TOKEN=") + break + } } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - inv, root := clitest.New(t, append([]string{"open", "vscode"}, tt.args...)...) - clitest.SetupConfig(t, client, root) - - pty := ptytest.New(t) - inv.Stdin = pty.Input() - inv.Stdout = pty.Output() - - ctx := testutil.Context(t, testutil.WaitLong) - inv = inv.WithContext(ctx) - - for k, v := range tt.env { - inv.Environ.Set(k, v) - } - - w := clitest.StartWithWaiter(t, inv) - - if tt.wantError { - w.RequireError() - return - } - - me, err := client.User(ctx, codersdk.Me) - require.NoError(t, err) - - line := pty.ReadLine(ctx) - u, err := url.ParseRequestURI(line) - require.NoError(t, err, "line: %q", line) - - qp := u.Query() - assert.Equal(t, client.URL.String(), qp.Get("url")) - assert.Equal(t, me.Username, qp.Get("owner")) - assert.Equal(t, workspace.Name, qp.Get("workspace")) - assert.Equal(t, agentName, qp.Get("agent")) - assert.Equal(t, containerName, qp.Get("devContainerName")) - - if tt.wantDir != "" { - assert.Equal(t, tt.wantDir, qp.Get("devContainerFolder")) - } else { - assert.Equal(t, containerFolder, qp.Get("devContainerFolder")) - } - - if tt.wantToken { - assert.NotEmpty(t, qp.Get("token")) - } else { - assert.Empty(t, qp.Get("token")) - } - - w.RequireSuccess() - }) + if token == "" { + return xerrors.New("no agent token provided in args") + } + if f.execAgent == nil { + return nil } + return f.execAgent(ctx, token) +} + +func (*fakeDevcontainerCLI) Up(ctx context.Context, workspaceFolder, configFile string, opts ...agentcontainers.DevcontainerCLIUpOptions) (string, error) { + return "", nil } -func TestOpenVSCodeDevContainer_NoAgentDirectory(t *testing.T) { +func TestOpenVSCodeDevContainer(t *testing.T) { t.Parallel() if runtime.GOOS != "linux" { t.Skip("DevContainers are only supported for agents on Linux") } - agentName := "agent1" + parentAgentName := "agent1" + devcontainerID := uuid.New() + devcontainerName := "wilson" + workspaceFolder := "/home/coder/wilson" + configFile := path.Join(workspaceFolder, ".devcontainer", "devcontainer.json") + + containerID := uuid.NewString() containerName := testutil.GetRandomName(t) - containerFolder := "/workspace/coder" + containerFolder := "/workspaces/wilson" + + client, workspace, agentToken := setupWorkspaceForAgent(t, func(agents []*proto.Agent) []*proto.Agent { + agents[0].Name = parentAgentName + agents[0].OperatingSystem = runtime.GOOS + return agents + }) - ctrl := gomock.NewController(t) - mccli := acmock.NewMockContainerCLI(ctrl) - mccli.EXPECT().List(gomock.Any()).Return( - codersdk.WorkspaceAgentListContainersResponse{ + fCCLI := &fakeContainerCLI{ + resp: codersdk.WorkspaceAgentListContainersResponse{ Containers: []codersdk.WorkspaceAgentContainer{ { - ID: uuid.NewString(), + ID: containerID, CreatedAt: dbtime.Now(), FriendlyName: containerName, Image: "busybox:latest", Labels: map[string]string{ - "devcontainer.local_folder": "/home/coder/coder", + agentcontainers.DevcontainerLocalFolderLabel: workspaceFolder, + agentcontainers.DevcontainerConfigFileLabel: configFile, + agentcontainers.DevcontainerIsTestRunLabel: "true", + "coder.test": t.Name(), }, Running: true, Status: "running", - Volumes: map[string]string{ - "/home/coder/coder": containerFolder, - }, }, }, - }, nil, - ).AnyTimes() - - client, workspace, agentToken := setupWorkspaceForAgent(t, func(agents []*proto.Agent) []*proto.Agent { - agents[0].Name = agentName - agents[0].OperatingSystem = runtime.GOOS - return agents - }) + }, + } + fDCCLI := &fakeDevcontainerCLI{ + config: agentcontainers.DevcontainerConfig{ + Workspace: agentcontainers.DevcontainerWorkspace{ + WorkspaceFolder: containerFolder, + }, + }, + execAgent: func(ctx context.Context, token string) error { + t.Logf("Starting devcontainer subagent with token: %s", token) + _ = agenttest.New(t, client.URL, token) + <-ctx.Done() + return ctx.Err() + }, + } _ = agenttest.New(t, client.URL, agentToken, func(o *agent.Options) { o.Devcontainers = true o.DevcontainerAPIOptions = append(o.DevcontainerAPIOptions, - agentcontainers.WithContainerCLI(mccli), - agentcontainers.WithContainerLabelIncludeFilter("this.label.does.not.exist.ignore.devcontainers", "true"), + agentcontainers.WithContainerCLI(fCCLI), + agentcontainers.WithDevcontainerCLI(fDCCLI), + agentcontainers.WithWatcher(watcher.NewNoop()), + agentcontainers.WithDevcontainers( + []codersdk.WorkspaceAgentDevcontainer{{ + ID: devcontainerID, + Name: devcontainerName, + WorkspaceFolder: workspaceFolder, + Status: codersdk.WorkspaceAgentDevcontainerStatusStopped, + }}, + []codersdk.WorkspaceAgentScript{{ + ID: devcontainerID, + LogSourceID: uuid.New(), + }}, + ), + agentcontainers.WithContainerLabelIncludeFilter("coder.test", t.Name()), ) }) - _ = coderdtest.NewWorkspaceAgentWaiter(t, client, workspace.ID).Wait() + coderdtest.NewWorkspaceAgentWaiter(t, client, workspace.ID).AgentNames([]string{parentAgentName, devcontainerName}).Wait() insideWorkspaceEnv := map[string]string{ "CODER": "true", "CODER_WORKSPACE_NAME": workspace.Name, - "CODER_WORKSPACE_AGENT_NAME": agentName, + "CODER_WORKSPACE_AGENT_NAME": devcontainerName, } wd, err := os.Getwd() @@ -535,41 +444,48 @@ func TestOpenVSCodeDevContainer_NoAgentDirectory(t *testing.T) { wantToken bool }{ { - name: "ok", - args: []string{"--test.open-error", workspace.Name, "--container", containerName}, + name: "nonexistent container", + args: []string{"--test.open-error", workspace.Name + "." + devcontainerName + "bad"}, + wantError: true, }, { - name: "no agent dir error relative path", - args: []string{"--test.open-error", workspace.Name, "--container", containerName, "my/relative/path"}, - wantDir: filepath.FromSlash("my/relative/path"), - wantError: true, + name: "ok", + args: []string{"--test.open-error", workspace.Name + "." + devcontainerName}, + wantError: false, }, { - name: "ok with absolute path", - args: []string{"--test.open-error", workspace.Name, "--container", containerName, "/home/coder"}, - wantDir: "/home/coder", + name: "ok with absolute path", + args: []string{"--test.open-error", workspace.Name + "." + devcontainerName, containerFolder}, + wantError: false, + }, + { + name: "ok with relative path", + args: []string{"--test.open-error", workspace.Name + "." + devcontainerName, "my/relative/path"}, + wantDir: path.Join(containerFolder, "my/relative/path"), + wantError: false, }, { name: "ok with token", - args: []string{"--test.open-error", workspace.Name, "--container", containerName, "--generate-token"}, + args: []string{"--test.open-error", workspace.Name + "." + devcontainerName, "--generate-token"}, + wantError: false, wantToken: true, }, // Inside workspace, does not require --test.open-error { name: "ok inside workspace", env: insideWorkspaceEnv, - args: []string{workspace.Name, "--container", containerName}, + args: []string{workspace.Name + "." + devcontainerName}, }, { name: "ok inside workspace relative path", env: insideWorkspaceEnv, - args: []string{workspace.Name, "--container", containerName, "foo"}, + args: []string{workspace.Name + "." + devcontainerName, "foo"}, wantDir: filepath.Join(wd, "foo"), }, { name: "ok inside workspace token", env: insideWorkspaceEnv, - args: []string{workspace.Name, "--container", containerName, "--generate-token"}, + args: []string{workspace.Name + "." + devcontainerName, "--generate-token"}, wantToken: true, }, } @@ -610,8 +526,10 @@ func TestOpenVSCodeDevContainer_NoAgentDirectory(t *testing.T) { assert.Equal(t, client.URL.String(), qp.Get("url")) assert.Equal(t, me.Username, qp.Get("owner")) assert.Equal(t, workspace.Name, qp.Get("workspace")) - assert.Equal(t, agentName, qp.Get("agent")) + assert.Equal(t, parentAgentName, qp.Get("agent")) assert.Equal(t, containerName, qp.Get("devContainerName")) + assert.Equal(t, workspaceFolder, qp.Get("localWorkspaceFolder")) + assert.Equal(t, configFile, qp.Get("localConfigFile")) if tt.wantDir != "" { assert.Equal(t, tt.wantDir, qp.Get("devContainerFolder")) diff --git a/cli/ping.go b/cli/ping.go index ec094ea1a317b..0836aa8a135db 100644 --- a/cli/ping.go +++ b/cli/ping.go @@ -110,7 +110,7 @@ func (r *RootCmd) ping() *serpent.Command { defer notifyCancel() workspaceName := inv.Args[0] - _, workspaceAgent, err := getWorkspaceAndAgent( + _, workspaceAgent, _, err := getWorkspaceAndAgent( ctx, inv, client, false, // Do not autostart for a ping. workspaceName, diff --git a/cli/portforward.go b/cli/portforward.go index e6ef2eb11bca8..7a7723213f760 100644 --- a/cli/portforward.go +++ b/cli/portforward.go @@ -84,7 +84,7 @@ func (r *RootCmd) portForward() *serpent.Command { return xerrors.New("no port-forwards requested") } - workspace, workspaceAgent, err := getWorkspaceAndAgent(ctx, inv, client, !disableAutostart, inv.Args[0]) + workspace, workspaceAgent, _, err := getWorkspaceAndAgent(ctx, inv, client, !disableAutostart, inv.Args[0]) if err != nil { return err } diff --git a/cli/provisionerjobs.go b/cli/provisionerjobs.go index c2b6b78658447..2ddd04c5b6a29 100644 --- a/cli/provisionerjobs.go +++ b/cli/provisionerjobs.go @@ -166,7 +166,7 @@ func (r *RootCmd) provisionerJobsCancel() *serpent.Command { err = client.CancelTemplateVersion(ctx, ptr.NilToEmpty(job.Input.TemplateVersionID)) case codersdk.ProvisionerJobTypeWorkspaceBuild: _, _ = fmt.Fprintf(inv.Stdout, "Canceling workspace build job %s...\n", job.ID) - err = client.CancelWorkspaceBuild(ctx, ptr.NilToEmpty(job.Input.WorkspaceBuildID)) + err = client.CancelWorkspaceBuild(ctx, ptr.NilToEmpty(job.Input.WorkspaceBuildID), codersdk.CancelWorkspaceBuildParams{}) } if err != nil { return xerrors.Errorf("cancel provisioner job: %w", err) diff --git a/cli/server.go b/cli/server.go index 5074bffc3a342..602f05d028b66 100644 --- a/cli/server.go +++ b/cli/server.go @@ -77,7 +77,6 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/awsiamrds" "github.com/coder/coder/v2/coderd/database/dbauthz" - "github.com/coder/coder/v2/coderd/database/dbmem" "github.com/coder/coder/v2/coderd/database/dbmetrics" "github.com/coder/coder/v2/coderd/database/dbpurge" "github.com/coder/coder/v2/coderd/database/migrations" @@ -423,7 +422,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd. builtinPostgres := false // Only use built-in if PostgreSQL URL isn't specified! - if !vals.InMemoryDatabase && vals.PostgresURL == "" { + if vals.PostgresURL == "" { var closeFunc func() error cliui.Infof(inv.Stdout, "Using built-in PostgreSQL (%s)", config.PostgresPath()) customPostgresCacheDir := "" @@ -726,42 +725,37 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd. // nil, that case of the select will just never fire, but it's important not to have a // "bare" read on this channel. var pubsubWatchdogTimeout <-chan struct{} - if vals.InMemoryDatabase { - // This is only used for testing. - options.Database = dbmem.New() - options.Pubsub = pubsub.NewInMemory() - } else { - sqlDB, dbURL, err := getAndMigratePostgresDB(ctx, logger, vals.PostgresURL.String(), codersdk.PostgresAuth(vals.PostgresAuth), sqlDriver) - if err != nil { - return xerrors.Errorf("connect to postgres: %w", err) - } - defer func() { - _ = sqlDB.Close() - }() - if options.DeploymentValues.Prometheus.Enable { - // At this stage we don't think the database name serves much purpose in these metrics. - // It requires parsing the DSN to determine it, which requires pulling in another dependency - // (i.e. https://github.com/jackc/pgx), but it's rather heavy. - // The conn string (https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING) can - // take different forms, which make parsing non-trivial. - options.PrometheusRegistry.MustRegister(collectors.NewDBStatsCollector(sqlDB, "")) - } + sqlDB, dbURL, err := getAndMigratePostgresDB(ctx, logger, vals.PostgresURL.String(), codersdk.PostgresAuth(vals.PostgresAuth), sqlDriver) + if err != nil { + return xerrors.Errorf("connect to postgres: %w", err) + } + defer func() { + _ = sqlDB.Close() + }() - options.Database = database.New(sqlDB) - ps, err := pubsub.New(ctx, logger.Named("pubsub"), sqlDB, dbURL) - if err != nil { - return xerrors.Errorf("create pubsub: %w", err) - } - options.Pubsub = ps - if options.DeploymentValues.Prometheus.Enable { - options.PrometheusRegistry.MustRegister(ps) - } - defer options.Pubsub.Close() - psWatchdog := pubsub.NewWatchdog(ctx, logger.Named("pswatch"), ps) - pubsubWatchdogTimeout = psWatchdog.Timeout() - defer psWatchdog.Close() + if options.DeploymentValues.Prometheus.Enable { + // At this stage we don't think the database name serves much purpose in these metrics. + // It requires parsing the DSN to determine it, which requires pulling in another dependency + // (i.e. https://github.com/jackc/pgx), but it's rather heavy. + // The conn string (https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING) can + // take different forms, which make parsing non-trivial. + options.PrometheusRegistry.MustRegister(collectors.NewDBStatsCollector(sqlDB, "")) + } + + options.Database = database.New(sqlDB) + ps, err := pubsub.New(ctx, logger.Named("pubsub"), sqlDB, dbURL) + if err != nil { + return xerrors.Errorf("create pubsub: %w", err) + } + options.Pubsub = ps + if options.DeploymentValues.Prometheus.Enable { + options.PrometheusRegistry.MustRegister(ps) } + defer options.Pubsub.Close() + psWatchdog := pubsub.NewWatchdog(ctx, logger.Named("pswatch"), ps) + pubsubWatchdogTimeout = psWatchdog.Timeout() + defer psWatchdog.Close() if options.DeploymentValues.Prometheus.Enable && options.DeploymentValues.Prometheus.CollectDBMetrics { options.Database = dbmetrics.NewQueryMetrics(options.Database, options.Logger, options.PrometheusRegistry) diff --git a/cli/server_test.go b/cli/server_test.go index 2d0bbdd24e83b..435ed2879c9a3 100644 --- a/cli/server_test.go +++ b/cli/server_test.go @@ -59,9 +59,6 @@ import ( ) func dbArg(t *testing.T) string { - if !dbtestutil.WillUsePostgres() { - return "--in-memory" - } dbURL, err := dbtestutil.Open(t) require.NoError(t, err) return "--postgres-url=" + dbURL diff --git a/cli/show.go b/cli/show.go index f2d3df3ecc3c5..284e8581f5dda 100644 --- a/cli/show.go +++ b/cli/show.go @@ -8,6 +8,7 @@ import ( "github.com/google/uuid" + "github.com/coder/coder/v2/agent/agentcontainers" "github.com/coder/coder/v2/cli/cliui" "github.com/coder/coder/v2/codersdk" "github.com/coder/serpent" @@ -15,9 +16,18 @@ import ( func (r *RootCmd) show() *serpent.Command { client := new(codersdk.Client) + var details bool return &serpent.Command{ Use: "show ", Short: "Display details of a workspace's resources and agents", + Options: serpent.OptionSet{ + { + Flag: "details", + Description: "Show full error messages and additional details.", + Default: "false", + Value: serpent.BoolOf(&details), + }, + }, Middleware: serpent.Chain( serpent.RequireNArgs(1), r.InitClient(client), @@ -35,6 +45,7 @@ func (r *RootCmd) show() *serpent.Command { options := cliui.WorkspaceResourcesOptions{ WorkspaceName: workspace.Name, ServerVersion: buildInfo.Version, + ShowDetails: details, } if workspace.LatestBuild.Status == codersdk.WorkspaceStatusRunning { // Get listening ports for each agent. @@ -42,6 +53,7 @@ func (r *RootCmd) show() *serpent.Command { options.ListeningPorts = ports options.Devcontainers = devcontainers } + return cliui.WorkspaceResources(inv.Stdout, workspace.LatestBuild.Resources, options) }, } @@ -68,13 +80,17 @@ func fetchRuntimeResources(inv *serpent.Invocation, client *codersdk.Client, res ports[agent.ID] = lp mu.Unlock() }() + + if agent.ParentID.Valid { + continue + } wg.Add(1) go func() { defer wg.Done() dc, err := client.WorkspaceAgentListContainers(inv.Context(), agent.ID, map[string]string{ // Labels set by VSCode Remote Containers and @devcontainers/cli. - "devcontainer.config_file": "", - "devcontainer.local_folder": "", + agentcontainers.DevcontainerConfigFileLabel: "", + agentcontainers.DevcontainerLocalFolderLabel: "", }) if err != nil { cliui.Warnf(inv.Stderr, "Failed to get devcontainers for agent %s: %v", agent.Name, err) diff --git a/cli/show_test.go b/cli/show_test.go index 7191898f8c0ec..36a5824174fc4 100644 --- a/cli/show_test.go +++ b/cli/show_test.go @@ -1,12 +1,19 @@ package cli_test import ( + "bytes" "testing" + "time" + "github.com/google/uuid" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/coder/coder/v2/agent/agentcontainers" "github.com/coder/coder/v2/cli/clitest" + "github.com/coder/coder/v2/cli/cliui" "github.com/coder/coder/v2/coderd/coderdtest" + "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/pty/ptytest" ) @@ -53,3 +60,354 @@ func TestShow(t *testing.T) { <-doneChan }) } + +func TestShowDevcontainers_Golden(t *testing.T) { + t.Parallel() + + mainAgentID := uuid.MustParse("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa") + agentID := mainAgentID + + testCases := []struct { + name string + showDetails bool + devcontainers []codersdk.WorkspaceAgentDevcontainer + listeningPorts map[uuid.UUID]codersdk.WorkspaceAgentListeningPortsResponse + }{ + { + name: "running_devcontainer_with_agent", + devcontainers: []codersdk.WorkspaceAgentDevcontainer{ + { + ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"), + Name: "web-dev", + WorkspaceFolder: "/workspaces/web-dev", + ConfigPath: "/workspaces/web-dev/.devcontainer/devcontainer.json", + Status: codersdk.WorkspaceAgentDevcontainerStatusRunning, + Dirty: false, + Container: &codersdk.WorkspaceAgentContainer{ + ID: "container-web-dev", + FriendlyName: "quirky_lovelace", + Image: "mcr.microsoft.com/devcontainers/typescript-node:1.0.0", + Running: true, + Status: "running", + CreatedAt: time.Now().Add(-1 * time.Hour), + Labels: map[string]string{ + agentcontainers.DevcontainerConfigFileLabel: "/workspaces/web-dev/.devcontainer/devcontainer.json", + agentcontainers.DevcontainerLocalFolderLabel: "/workspaces/web-dev", + }, + }, + Agent: &codersdk.WorkspaceAgentDevcontainerAgent{ + ID: uuid.MustParse("22222222-2222-2222-2222-222222222222"), + Name: "web-dev", + Directory: "/workspaces/web-dev", + }, + }, + }, + listeningPorts: map[uuid.UUID]codersdk.WorkspaceAgentListeningPortsResponse{ + uuid.MustParse("22222222-2222-2222-2222-222222222222"): { + Ports: []codersdk.WorkspaceAgentListeningPort{ + { + ProcessName: "node", + Network: "tcp", + Port: 3000, + }, + { + ProcessName: "webpack-dev-server", + Network: "tcp", + Port: 8080, + }, + }, + }, + }, + }, + { + name: "running_devcontainer_without_agent", + devcontainers: []codersdk.WorkspaceAgentDevcontainer{ + { + ID: uuid.MustParse("33333333-3333-3333-3333-333333333333"), + Name: "web-server", + WorkspaceFolder: "/workspaces/web-server", + ConfigPath: "/workspaces/web-server/.devcontainer/devcontainer.json", + Status: codersdk.WorkspaceAgentDevcontainerStatusRunning, + Dirty: false, + Container: &codersdk.WorkspaceAgentContainer{ + ID: "container-web-server", + FriendlyName: "amazing_turing", + Image: "nginx:latest", + Running: true, + Status: "running", + CreatedAt: time.Now().Add(-30 * time.Minute), + Labels: map[string]string{ + agentcontainers.DevcontainerConfigFileLabel: "/workspaces/web-server/.devcontainer/devcontainer.json", + agentcontainers.DevcontainerLocalFolderLabel: "/workspaces/web-server", + }, + }, + Agent: nil, // No agent for this running container. + }, + }, + }, + { + name: "stopped_devcontainer", + devcontainers: []codersdk.WorkspaceAgentDevcontainer{ + { + ID: uuid.MustParse("44444444-4444-4444-4444-444444444444"), + Name: "api-dev", + WorkspaceFolder: "/workspaces/api-dev", + ConfigPath: "/workspaces/api-dev/.devcontainer/devcontainer.json", + Status: codersdk.WorkspaceAgentDevcontainerStatusStopped, + Dirty: false, + Container: &codersdk.WorkspaceAgentContainer{ + ID: "container-api-dev", + FriendlyName: "clever_darwin", + Image: "mcr.microsoft.com/devcontainers/go:1.0.0", + Running: false, + Status: "exited", + CreatedAt: time.Now().Add(-2 * time.Hour), + Labels: map[string]string{ + agentcontainers.DevcontainerConfigFileLabel: "/workspaces/api-dev/.devcontainer/devcontainer.json", + agentcontainers.DevcontainerLocalFolderLabel: "/workspaces/api-dev", + }, + }, + Agent: nil, // No agent for stopped container. + }, + }, + }, + { + name: "starting_devcontainer", + devcontainers: []codersdk.WorkspaceAgentDevcontainer{ + { + ID: uuid.MustParse("55555555-5555-5555-5555-555555555555"), + Name: "database-dev", + WorkspaceFolder: "/workspaces/database-dev", + ConfigPath: "/workspaces/database-dev/.devcontainer/devcontainer.json", + Status: codersdk.WorkspaceAgentDevcontainerStatusStarting, + Dirty: false, + Container: &codersdk.WorkspaceAgentContainer{ + ID: "container-database-dev", + FriendlyName: "nostalgic_hawking", + Image: "mcr.microsoft.com/devcontainers/postgres:1.0.0", + Running: false, + Status: "created", + CreatedAt: time.Now().Add(-5 * time.Minute), + Labels: map[string]string{ + agentcontainers.DevcontainerConfigFileLabel: "/workspaces/database-dev/.devcontainer/devcontainer.json", + agentcontainers.DevcontainerLocalFolderLabel: "/workspaces/database-dev", + }, + }, + Agent: nil, // No agent yet while starting. + }, + }, + }, + { + name: "error_devcontainer", + devcontainers: []codersdk.WorkspaceAgentDevcontainer{ + { + ID: uuid.MustParse("66666666-6666-6666-6666-666666666666"), + Name: "failed-dev", + WorkspaceFolder: "/workspaces/failed-dev", + ConfigPath: "/workspaces/failed-dev/.devcontainer/devcontainer.json", + Status: codersdk.WorkspaceAgentDevcontainerStatusError, + Dirty: false, + Error: "Failed to pull image mcr.microsoft.com/devcontainers/go:latest: timeout after 5m0s", + Container: nil, // No container due to error. + Agent: nil, // No agent due to error. + }, + }, + }, + + { + name: "mixed_devcontainer_states", + devcontainers: []codersdk.WorkspaceAgentDevcontainer{ + { + ID: uuid.MustParse("88888888-8888-8888-8888-888888888888"), + Name: "frontend", + WorkspaceFolder: "/workspaces/frontend", + Status: codersdk.WorkspaceAgentDevcontainerStatusRunning, + Container: &codersdk.WorkspaceAgentContainer{ + ID: "container-frontend", + FriendlyName: "vibrant_tesla", + Image: "node:18", + Running: true, + Status: "running", + CreatedAt: time.Now().Add(-30 * time.Minute), + }, + Agent: &codersdk.WorkspaceAgentDevcontainerAgent{ + ID: uuid.MustParse("99999999-9999-9999-9999-999999999999"), + Name: "frontend", + Directory: "/workspaces/frontend", + }, + }, + { + ID: uuid.MustParse("aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"), + Name: "backend", + WorkspaceFolder: "/workspaces/backend", + Status: codersdk.WorkspaceAgentDevcontainerStatusStopped, + Container: &codersdk.WorkspaceAgentContainer{ + ID: "container-backend", + FriendlyName: "peaceful_curie", + Image: "python:3.11", + Running: false, + Status: "exited", + CreatedAt: time.Now().Add(-1 * time.Hour), + }, + Agent: nil, + }, + { + ID: uuid.MustParse("bbbbbbbb-cccc-dddd-eeee-ffffffffffff"), + Name: "error-container", + WorkspaceFolder: "/workspaces/error-container", + Status: codersdk.WorkspaceAgentDevcontainerStatusError, + Error: "Container build failed: dockerfile syntax error on line 15", + Container: nil, + Agent: nil, + }, + }, + listeningPorts: map[uuid.UUID]codersdk.WorkspaceAgentListeningPortsResponse{ + uuid.MustParse("99999999-9999-9999-9999-999999999999"): { + Ports: []codersdk.WorkspaceAgentListeningPort{ + { + ProcessName: "vite", + Network: "tcp", + Port: 5173, + }, + }, + }, + }, + }, + { + name: "running_devcontainer_with_agent_and_error", + devcontainers: []codersdk.WorkspaceAgentDevcontainer{ + { + ID: uuid.MustParse("cccccccc-dddd-eeee-ffff-000000000000"), + Name: "problematic-dev", + WorkspaceFolder: "/workspaces/problematic-dev", + ConfigPath: "/workspaces/problematic-dev/.devcontainer/devcontainer.json", + Status: codersdk.WorkspaceAgentDevcontainerStatusRunning, + Dirty: false, + Error: "Warning: Container started but healthcheck failed", + Container: &codersdk.WorkspaceAgentContainer{ + ID: "container-problematic", + FriendlyName: "cranky_mendel", + Image: "mcr.microsoft.com/devcontainers/python:1.0.0", + Running: true, + Status: "running", + CreatedAt: time.Now().Add(-15 * time.Minute), + Labels: map[string]string{ + agentcontainers.DevcontainerConfigFileLabel: "/workspaces/problematic-dev/.devcontainer/devcontainer.json", + agentcontainers.DevcontainerLocalFolderLabel: "/workspaces/problematic-dev", + }, + }, + Agent: &codersdk.WorkspaceAgentDevcontainerAgent{ + ID: uuid.MustParse("dddddddd-eeee-ffff-aaaa-111111111111"), + Name: "problematic-dev", + Directory: "/workspaces/problematic-dev", + }, + }, + }, + listeningPorts: map[uuid.UUID]codersdk.WorkspaceAgentListeningPortsResponse{ + uuid.MustParse("dddddddd-eeee-ffff-aaaa-111111111111"): { + Ports: []codersdk.WorkspaceAgentListeningPort{ + { + ProcessName: "python", + Network: "tcp", + Port: 8000, + }, + }, + }, + }, + }, + { + name: "long_error_message", + devcontainers: []codersdk.WorkspaceAgentDevcontainer{ + { + ID: uuid.MustParse("eeeeeeee-ffff-0000-1111-222222222222"), + Name: "long-error-dev", + WorkspaceFolder: "/workspaces/long-error-dev", + ConfigPath: "/workspaces/long-error-dev/.devcontainer/devcontainer.json", + Status: codersdk.WorkspaceAgentDevcontainerStatusError, + Dirty: false, + Error: "Failed to build devcontainer: dockerfile parse error at line 25: unknown instruction 'INSTALL', did you mean 'RUN apt-get install'? This is a very long error message that should be truncated when detail flag is not used", + Container: nil, + Agent: nil, + }, + }, + }, + { + name: "long_error_message_with_detail", + showDetails: true, + devcontainers: []codersdk.WorkspaceAgentDevcontainer{ + { + ID: uuid.MustParse("eeeeeeee-ffff-0000-1111-222222222222"), + Name: "long-error-dev", + WorkspaceFolder: "/workspaces/long-error-dev", + ConfigPath: "/workspaces/long-error-dev/.devcontainer/devcontainer.json", + Status: codersdk.WorkspaceAgentDevcontainerStatusError, + Dirty: false, + Error: "Failed to build devcontainer: dockerfile parse error at line 25: unknown instruction 'INSTALL', did you mean 'RUN apt-get install'? This is a very long error message that should be truncated when detail flag is not used", + Container: nil, + Agent: nil, + }, + }, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + var allAgents []codersdk.WorkspaceAgent + mainAgent := codersdk.WorkspaceAgent{ + ID: mainAgentID, + Name: "main", + OperatingSystem: "linux", + Architecture: "amd64", + Status: codersdk.WorkspaceAgentConnected, + Health: codersdk.WorkspaceAgentHealth{Healthy: true}, + Version: "v2.15.0", + } + allAgents = append(allAgents, mainAgent) + + for _, dc := range tc.devcontainers { + if dc.Agent != nil { + devcontainerAgent := codersdk.WorkspaceAgent{ + ID: dc.Agent.ID, + ParentID: uuid.NullUUID{UUID: mainAgentID, Valid: true}, + Name: dc.Agent.Name, + OperatingSystem: "linux", + Architecture: "amd64", + Status: codersdk.WorkspaceAgentConnected, + Health: codersdk.WorkspaceAgentHealth{Healthy: true}, + Version: "v2.15.0", + } + allAgents = append(allAgents, devcontainerAgent) + } + } + + resources := []codersdk.WorkspaceResource{ + { + Type: "compute", + Name: "main", + Agents: allAgents, + }, + } + options := cliui.WorkspaceResourcesOptions{ + WorkspaceName: "test-workspace", + ServerVersion: "v2.15.0", + ShowDetails: tc.showDetails, + Devcontainers: map[uuid.UUID]codersdk.WorkspaceAgentListContainersResponse{ + agentID: { + Devcontainers: tc.devcontainers, + }, + }, + ListeningPorts: tc.listeningPorts, + } + + var buf bytes.Buffer + err := cliui.WorkspaceResources(&buf, resources, options) + require.NoError(t, err) + + replacements := map[string]string{} + clitest.TestGoldenFile(t, "TestShowDevcontainers_Golden/"+tc.name, buf.Bytes(), replacements) + }) + } +} diff --git a/cli/speedtest.go b/cli/speedtest.go index 0d9f839d6b458..08112f50cce2c 100644 --- a/cli/speedtest.go +++ b/cli/speedtest.go @@ -83,7 +83,7 @@ func (r *RootCmd) speedtest() *serpent.Command { return xerrors.Errorf("--direct (-d) is incompatible with --%s", varDisableDirect) } - _, workspaceAgent, err := getWorkspaceAndAgent(ctx, inv, client, false, inv.Args[0]) + _, workspaceAgent, _, err := getWorkspaceAndAgent(ctx, inv, client, false, inv.Args[0]) if err != nil { return err } diff --git a/cli/ssh.go b/cli/ssh.go index 56ab0b2a0d3af..9327a0101c0cf 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -754,7 +754,8 @@ func findWorkspaceAndAgentByHostname( hostname = strings.TrimSuffix(hostname, qualifiedSuffix) } hostname = normalizeWorkspaceInput(hostname) - return getWorkspaceAndAgent(ctx, inv, client, !disableAutostart, hostname) + ws, agent, _, err := getWorkspaceAndAgent(ctx, inv, client, !disableAutostart, hostname) + return ws, agent, err } // watchAndClose ensures closer is called if the context is canceled or @@ -827,9 +828,10 @@ startWatchLoop: } // getWorkspaceAgent returns the workspace and agent selected using either the -// `[.]` syntax via `in`. +// `[.]` syntax via `in`. It will also return any other agents +// in the workspace as a slice for use in child->parent lookups. // If autoStart is true, the workspace will be started if it is not already running. -func getWorkspaceAndAgent(ctx context.Context, inv *serpent.Invocation, client *codersdk.Client, autostart bool, input string) (codersdk.Workspace, codersdk.WorkspaceAgent, error) { //nolint:revive +func getWorkspaceAndAgent(ctx context.Context, inv *serpent.Invocation, client *codersdk.Client, autostart bool, input string) (codersdk.Workspace, codersdk.WorkspaceAgent, []codersdk.WorkspaceAgent, error) { //nolint:revive var ( workspace codersdk.Workspace // The input will be `owner/name.agent` @@ -840,27 +842,27 @@ func getWorkspaceAndAgent(ctx context.Context, inv *serpent.Invocation, client * workspace, err = namedWorkspace(ctx, client, workspaceParts[0]) if err != nil { - return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, err + return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, nil, err } if workspace.LatestBuild.Transition != codersdk.WorkspaceTransitionStart { if !autostart { - return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.New("workspace must be started") + return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, nil, xerrors.New("workspace must be started") } // Autostart the workspace for the user. // For some failure modes, return a better message. if workspace.LatestBuild.Transition == codersdk.WorkspaceTransitionDelete { // Any sort of deleting status, we should reject with a nicer error. - return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.Errorf("workspace %q is deleted", workspace.Name) + return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, nil, xerrors.Errorf("workspace %q is deleted", workspace.Name) } if workspace.LatestBuild.Job.Status == codersdk.ProvisionerJobFailed { - return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, + return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, nil, xerrors.Errorf("workspace %q is in failed state, unable to autostart the workspace", workspace.Name) } // The workspace needs to be stopped before we can start it. // It cannot be in any pending or failed state. if workspace.LatestBuild.Status != codersdk.WorkspaceStatusStopped { - return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, + return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, nil, xerrors.Errorf("workspace must be started; was unable to autostart as the last build job is %q, expected %q", workspace.LatestBuild.Status, codersdk.WorkspaceStatusStopped, @@ -881,48 +883,48 @@ func getWorkspaceAndAgent(ctx context.Context, inv *serpent.Invocation, client * case http.StatusForbidden: _, err = startWorkspace(inv, client, workspace, workspaceParameterFlags{}, buildFlags{}, WorkspaceUpdate) if err != nil { - return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.Errorf("start workspace with active template version: %w", err) + return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, nil, xerrors.Errorf("start workspace with active template version: %w", err) } _, _ = fmt.Fprintln(inv.Stdout, "Unable to start the workspace with template version from last build. Your workspace has been updated to the current active template version.") } } else if err != nil { - return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.Errorf("start workspace with current template version: %w", err) + return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, nil, xerrors.Errorf("start workspace with current template version: %w", err) } // Refresh workspace state so that `outdated`, `build`,`template_*` fields are up-to-date. workspace, err = namedWorkspace(ctx, client, workspaceParts[0]) if err != nil { - return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, err + return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, nil, err } } if workspace.LatestBuild.Job.CompletedAt == nil { err := cliui.WorkspaceBuild(ctx, inv.Stderr, client, workspace.LatestBuild.ID) if err != nil { - return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, err + return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, nil, err } // Fetch up-to-date build information after completion. workspace.LatestBuild, err = client.WorkspaceBuild(ctx, workspace.LatestBuild.ID) if err != nil { - return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, err + return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, nil, err } } if workspace.LatestBuild.Transition == codersdk.WorkspaceTransitionDelete { - return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.Errorf("workspace %q is being deleted", workspace.Name) + return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, nil, xerrors.Errorf("workspace %q is being deleted", workspace.Name) } var agentName string if len(workspaceParts) >= 2 { agentName = workspaceParts[1] } - workspaceAgent, err := getWorkspaceAgent(workspace, agentName) + workspaceAgent, otherWorkspaceAgents, err := getWorkspaceAgent(workspace, agentName) if err != nil { - return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, err + return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, nil, err } - return workspace, workspaceAgent, nil + return workspace, workspaceAgent, otherWorkspaceAgents, nil } -func getWorkspaceAgent(workspace codersdk.Workspace, agentName string) (workspaceAgent codersdk.WorkspaceAgent, err error) { +func getWorkspaceAgent(workspace codersdk.Workspace, agentName string) (workspaceAgent codersdk.WorkspaceAgent, otherAgents []codersdk.WorkspaceAgent, err error) { resources := workspace.LatestBuild.Resources var ( @@ -936,22 +938,23 @@ func getWorkspaceAgent(workspace codersdk.Workspace, agentName string) (workspac } } if len(agents) == 0 { - return codersdk.WorkspaceAgent{}, xerrors.Errorf("workspace %q has no agents", workspace.Name) + return codersdk.WorkspaceAgent{}, nil, xerrors.Errorf("workspace %q has no agents", workspace.Name) } slices.Sort(availableNames) if agentName != "" { - for _, otherAgent := range agents { - if otherAgent.Name != agentName { + for i, agent := range agents { + if agent.Name != agentName || agent.ID.String() == agentName { continue } - return otherAgent, nil + otherAgents := slices.Delete(agents, i, i+1) + return agent, otherAgents, nil } - return codersdk.WorkspaceAgent{}, xerrors.Errorf("agent not found by name %q, available agents: %v", agentName, availableNames) + return codersdk.WorkspaceAgent{}, nil, xerrors.Errorf("agent not found by name %q, available agents: %v", agentName, availableNames) } if len(agents) == 1 { - return agents[0], nil + return agents[0], nil, nil } - return codersdk.WorkspaceAgent{}, xerrors.Errorf("multiple agents found, please specify the agent name, available agents: %v", availableNames) + return codersdk.WorkspaceAgent{}, nil, xerrors.Errorf("multiple agents found, please specify the agent name, available agents: %v", availableNames) } // Attempt to poll workspace autostop. We write a per-workspace lockfile to diff --git a/cli/ssh_internal_test.go b/cli/ssh_internal_test.go index 0d956def68938..a7fac11c7254c 100644 --- a/cli/ssh_internal_test.go +++ b/cli/ssh_internal_test.go @@ -376,7 +376,7 @@ func Test_getWorkspaceAgent(t *testing.T) { agent := createAgent("main") workspace := createWorkspaceWithAgents([]codersdk.WorkspaceAgent{agent}) - result, err := getWorkspaceAgent(workspace, "") + result, _, err := getWorkspaceAgent(workspace, "") require.NoError(t, err) assert.Equal(t, agent.ID, result.ID) assert.Equal(t, "main", result.Name) @@ -388,7 +388,7 @@ func Test_getWorkspaceAgent(t *testing.T) { agent2 := createAgent("main2") workspace := createWorkspaceWithAgents([]codersdk.WorkspaceAgent{agent1, agent2}) - _, err := getWorkspaceAgent(workspace, "") + _, _, err := getWorkspaceAgent(workspace, "") require.Error(t, err) assert.Contains(t, err.Error(), "multiple agents found") assert.Contains(t, err.Error(), "available agents: [main1 main2]") @@ -400,10 +400,13 @@ func Test_getWorkspaceAgent(t *testing.T) { agent2 := createAgent("main2") workspace := createWorkspaceWithAgents([]codersdk.WorkspaceAgent{agent1, agent2}) - result, err := getWorkspaceAgent(workspace, "main1") + result, other, err := getWorkspaceAgent(workspace, "main1") require.NoError(t, err) assert.Equal(t, agent1.ID, result.ID) assert.Equal(t, "main1", result.Name) + assert.Len(t, other, 1) + assert.Equal(t, agent2.ID, other[0].ID) + assert.Equal(t, "main2", other[0].Name) }) t.Run("AgentNameSpecified_NotFound", func(t *testing.T) { @@ -412,7 +415,7 @@ func Test_getWorkspaceAgent(t *testing.T) { agent2 := createAgent("main2") workspace := createWorkspaceWithAgents([]codersdk.WorkspaceAgent{agent1, agent2}) - _, err := getWorkspaceAgent(workspace, "nonexistent") + _, _, err := getWorkspaceAgent(workspace, "nonexistent") require.Error(t, err) assert.Contains(t, err.Error(), `agent not found by name "nonexistent"`) assert.Contains(t, err.Error(), "available agents: [main1 main2]") @@ -422,7 +425,7 @@ func Test_getWorkspaceAgent(t *testing.T) { t.Parallel() workspace := createWorkspaceWithAgents([]codersdk.WorkspaceAgent{}) - _, err := getWorkspaceAgent(workspace, "") + _, _, err := getWorkspaceAgent(workspace, "") require.Error(t, err) assert.Contains(t, err.Error(), `workspace "test-workspace" has no agents`) }) @@ -435,7 +438,7 @@ func Test_getWorkspaceAgent(t *testing.T) { agent3 := createAgent("krypton") workspace := createWorkspaceWithAgents([]codersdk.WorkspaceAgent{agent2, agent1, agent3}) - _, err := getWorkspaceAgent(workspace, "nonexistent") + _, _, err := getWorkspaceAgent(workspace, "nonexistent") require.Error(t, err) // Available agents should be sorted alphabetically. assert.Contains(t, err.Error(), "available agents: [clark krypton zod]") diff --git a/cli/ssh_test.go b/cli/ssh_test.go index 582f8a3fdf691..7a91cfa3ce365 100644 --- a/cli/ssh_test.go +++ b/cli/ssh_test.go @@ -2104,7 +2104,7 @@ func TestSSH_Container(t *testing.T) { clitest.SetupConfig(t, client, root) err := inv.WithContext(ctx).Run() - require.ErrorContains(t, err, "The agent dev containers feature is experimental and not enabled by default.") + require.ErrorContains(t, err, "Dev Container feature not enabled.") }) } diff --git a/cli/testdata/TestShowDevcontainers_Golden/error_devcontainer.golden b/cli/testdata/TestShowDevcontainers_Golden/error_devcontainer.golden new file mode 100644 index 0000000000000..03a19f16df4e1 --- /dev/null +++ b/cli/testdata/TestShowDevcontainers_Golden/error_devcontainer.golden @@ -0,0 +1,9 @@ +┌─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +│ RESOURCE STATUS HEALTH VERSION ACCESS │ +├─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ +│ compute.main │ +│ └─ main (linux, amd64) ⦿ connected ✔ healthy v2.15.0 coder ssh test-workspace.main │ +│ └─ Devcontainers │ +│ └─ failed-dev ✘ error │ +│ × Failed to pull image mcr.microsoft.com/devcontainers/go:latest: timeout after 5… │ +└─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┘ diff --git a/cli/testdata/TestShowDevcontainers_Golden/long_error_message.golden b/cli/testdata/TestShowDevcontainers_Golden/long_error_message.golden new file mode 100644 index 0000000000000..1e80d338a74a8 --- /dev/null +++ b/cli/testdata/TestShowDevcontainers_Golden/long_error_message.golden @@ -0,0 +1,9 @@ +┌─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +│ RESOURCE STATUS HEALTH VERSION ACCESS │ +├─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ +│ compute.main │ +│ └─ main (linux, amd64) ⦿ connected ✔ healthy v2.15.0 coder ssh test-workspace.main │ +│ └─ Devcontainers │ +│ └─ long-error-dev ✘ error │ +│ × Failed to build devcontainer: dockerfile parse error at line 25: unknown instru… │ +└─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┘ diff --git a/cli/testdata/TestShowDevcontainers_Golden/long_error_message_with_detail.golden b/cli/testdata/TestShowDevcontainers_Golden/long_error_message_with_detail.golden new file mode 100644 index 0000000000000..9310f7f19a350 --- /dev/null +++ b/cli/testdata/TestShowDevcontainers_Golden/long_error_message_with_detail.golden @@ -0,0 +1,9 @@ +┌────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +│ RESOURCE STATUS HEALTH VERSION ACCESS │ +├────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ +│ compute.main │ +│ └─ main (linux, amd64) ⦿ connected ✔ healthy v2.15.0 coder ssh test-workspace.main │ +│ └─ Devcontainers │ +│ └─ long-error-dev ✘ error │ +│ × Failed to build devcontainer: dockerfile parse error at line 25: unknown instruction 'INSTALL', did you mean 'RUN apt-get install'? This is a very long error message that should be truncated when detail flag is not used │ +└────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┘ diff --git a/cli/testdata/TestShowDevcontainers_Golden/mixed_devcontainer_states.golden b/cli/testdata/TestShowDevcontainers_Golden/mixed_devcontainer_states.golden new file mode 100644 index 0000000000000..dfbd677cc3dbe --- /dev/null +++ b/cli/testdata/TestShowDevcontainers_Golden/mixed_devcontainer_states.golden @@ -0,0 +1,13 @@ +┌───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +│ RESOURCE STATUS HEALTH VERSION ACCESS │ +├───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ +│ compute.main │ +│ └─ main (linux, amd64) ⦿ connected ✔ healthy v2.15.0 coder ssh test-workspace.main │ +│ └─ Devcontainers │ +│ ├─ frontend (linux, amd64) [vibrant_tesla] ⦿ connected ✔ healthy v2.15.0 coder ssh test-workspace.frontend │ +│ ├─ Open Ports │ +│ └─ 5173/tcp [vite] │ +│ ├─ backend [peaceful_curie] ⏹ stopped │ +│ └─ error-container ✘ error │ +│ × Container build failed: dockerfile syntax error on line 15 │ +└───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┘ diff --git a/cli/testdata/TestShowDevcontainers_Golden/running_devcontainer_with_agent.golden b/cli/testdata/TestShowDevcontainers_Golden/running_devcontainer_with_agent.golden new file mode 100644 index 0000000000000..ab5d2a2085227 --- /dev/null +++ b/cli/testdata/TestShowDevcontainers_Golden/running_devcontainer_with_agent.golden @@ -0,0 +1,11 @@ +┌──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +│ RESOURCE STATUS HEALTH VERSION ACCESS │ +├──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ +│ compute.main │ +│ └─ main (linux, amd64) ⦿ connected ✔ healthy v2.15.0 coder ssh test-workspace.main │ +│ └─ Devcontainers │ +│ └─ web-dev (linux, amd64) [quirky_lovelace] ⦿ connected ✔ healthy v2.15.0 coder ssh test-workspace.web-dev │ +│ └─ Open Ports │ +│ ├─ 3000/tcp [node] │ +│ └─ 8080/tcp [webpack-dev-server] │ +└──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┘ diff --git a/cli/testdata/TestShowDevcontainers_Golden/running_devcontainer_with_agent_and_error.golden b/cli/testdata/TestShowDevcontainers_Golden/running_devcontainer_with_agent_and_error.golden new file mode 100644 index 0000000000000..6b73f7175bac8 --- /dev/null +++ b/cli/testdata/TestShowDevcontainers_Golden/running_devcontainer_with_agent_and_error.golden @@ -0,0 +1,11 @@ +┌─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +│ RESOURCE STATUS HEALTH VERSION ACCESS │ +├─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ +│ compute.main │ +│ └─ main (linux, amd64) ⦿ connected ✔ healthy v2.15.0 coder ssh test-workspace.main │ +│ └─ Devcontainers │ +│ └─ problematic-dev (linux, amd64) [cranky_mendel] ⦿ connected ✔ healthy v2.15.0 coder ssh test-workspace.problematic-dev │ +│ × Warning: Container started but healthcheck failed │ +│ └─ Open Ports │ +│ └─ 8000/tcp [python] │ +└─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┘ diff --git a/cli/testdata/TestShowDevcontainers_Golden/running_devcontainer_without_agent.golden b/cli/testdata/TestShowDevcontainers_Golden/running_devcontainer_without_agent.golden new file mode 100644 index 0000000000000..70c3874acc774 --- /dev/null +++ b/cli/testdata/TestShowDevcontainers_Golden/running_devcontainer_without_agent.golden @@ -0,0 +1,8 @@ +┌──────────────────────────────────────────────────────────────────────────────────────────────────────┐ +│ RESOURCE STATUS HEALTH VERSION ACCESS │ +├──────────────────────────────────────────────────────────────────────────────────────────────────────┤ +│ compute.main │ +│ └─ main (linux, amd64) ⦿ connected ✔ healthy v2.15.0 coder ssh test-workspace.main │ +│ └─ Devcontainers │ +│ └─ web-server [amazing_turing] ▶ running │ +└──────────────────────────────────────────────────────────────────────────────────────────────────────┘ diff --git a/cli/testdata/TestShowDevcontainers_Golden/starting_devcontainer.golden b/cli/testdata/TestShowDevcontainers_Golden/starting_devcontainer.golden new file mode 100644 index 0000000000000..472201ecc7818 --- /dev/null +++ b/cli/testdata/TestShowDevcontainers_Golden/starting_devcontainer.golden @@ -0,0 +1,8 @@ +┌───────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +│ RESOURCE STATUS HEALTH VERSION ACCESS │ +├───────────────────────────────────────────────────────────────────────────────────────────────────────────┤ +│ compute.main │ +│ └─ main (linux, amd64) ⦿ connected ✔ healthy v2.15.0 coder ssh test-workspace.main │ +│ └─ Devcontainers │ +│ └─ database-dev [nostalgic_hawking] ⧗ starting │ +└───────────────────────────────────────────────────────────────────────────────────────────────────────────┘ diff --git a/cli/testdata/TestShowDevcontainers_Golden/stopped_devcontainer.golden b/cli/testdata/TestShowDevcontainers_Golden/stopped_devcontainer.golden new file mode 100644 index 0000000000000..41313b235acc7 --- /dev/null +++ b/cli/testdata/TestShowDevcontainers_Golden/stopped_devcontainer.golden @@ -0,0 +1,8 @@ +┌──────────────────────────────────────────────────────────────────────────────────────────────────┐ +│ RESOURCE STATUS HEALTH VERSION ACCESS │ +├──────────────────────────────────────────────────────────────────────────────────────────────────┤ +│ compute.main │ +│ └─ main (linux, amd64) ⦿ connected ✔ healthy v2.15.0 coder ssh test-workspace.main │ +│ └─ Devcontainers │ +│ └─ api-dev [clever_darwin] ⏹ stopped │ +└──────────────────────────────────────────────────────────────────────────────────────────────────┘ diff --git a/cli/testdata/coder_show_--help.golden b/cli/testdata/coder_show_--help.golden index fc048aa067ea6..76555221e4602 100644 --- a/cli/testdata/coder_show_--help.golden +++ b/cli/testdata/coder_show_--help.golden @@ -1,9 +1,13 @@ coder v0.0.0-devel USAGE: - coder show + coder show [flags] Display details of a workspace's resources and agents +OPTIONS: + --details bool (default: false) + Show full error messages and additional details. + ——— Run `coder --help` for a list of global options. diff --git a/cli/testdata/server-config.yaml.golden b/cli/testdata/server-config.yaml.golden index dabeab8ca48bd..e23274e442078 100644 --- a/cli/testdata/server-config.yaml.golden +++ b/cli/testdata/server-config.yaml.golden @@ -462,9 +462,6 @@ enableSwagger: false # to be configured as a shared directory across coderd/provisionerd replicas. # (default: [cache dir], type: string) cacheDir: [cache dir] -# Controls whether data will be stored in an in-memory database. -# (default: , type: bool) -inMemoryDatabase: false # Controls whether Coder data, including built-in Postgres, will be stored in a # temporary directory and deleted when the server is stopped. # (default: , type: bool) diff --git a/cli/vscodessh.go b/cli/vscodessh.go index 872f7d837c0cd..e0b963b7ed80d 100644 --- a/cli/vscodessh.go +++ b/cli/vscodessh.go @@ -102,7 +102,7 @@ func (r *RootCmd) vscodeSSH() *serpent.Command { // will call this command after the workspace is started. autostart := false - workspace, workspaceAgent, err := getWorkspaceAndAgent(ctx, inv, client, autostart, fmt.Sprintf("%s/%s", owner, name)) + workspace, workspaceAgent, _, err := getWorkspaceAndAgent(ctx, inv, client, autostart, fmt.Sprintf("%s/%s", owner, name)) if err != nil { return xerrors.Errorf("find workspace and agent: %w", err) } diff --git a/coderd/agentapi/api.go b/coderd/agentapi/api.go index c409f8ea89e9b..dbcb8ea024914 100644 --- a/coderd/agentapi/api.go +++ b/coderd/agentapi/api.go @@ -19,7 +19,7 @@ import ( agentproto "github.com/coder/coder/v2/agent/proto" "github.com/coder/coder/v2/coderd/agentapi/resourcesmonitor" "github.com/coder/coder/v2/coderd/appearance" - "github.com/coder/coder/v2/coderd/audit" + "github.com/coder/coder/v2/coderd/connectionlog" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/pubsub" "github.com/coder/coder/v2/coderd/externalauth" @@ -50,7 +50,7 @@ type API struct { *ResourcesMonitoringAPI *LogsAPI *ScriptsAPI - *AuditAPI + *ConnLogAPI *SubAgentAPI *tailnet.DRPCService @@ -71,7 +71,7 @@ type Options struct { Database database.Store NotificationsEnqueuer notifications.Enqueuer Pubsub pubsub.Pubsub - Auditor *atomic.Pointer[audit.Auditor] + ConnectionLogger *atomic.Pointer[connectionlog.ConnectionLogger] DerpMapFn func() *tailcfg.DERPMap TailnetCoordinator *atomic.Pointer[tailnet.Coordinator] StatsReporter *workspacestats.Reporter @@ -180,11 +180,11 @@ func New(opts Options) *API { Database: opts.Database, } - api.AuditAPI = &AuditAPI{ - AgentFn: api.agent, - Auditor: opts.Auditor, - Database: opts.Database, - Log: opts.Log, + api.ConnLogAPI = &ConnLogAPI{ + AgentFn: api.agent, + ConnectionLogger: opts.ConnectionLogger, + Database: opts.Database, + Log: opts.Log, } api.DRPCService = &tailnet.DRPCService{ diff --git a/coderd/agentapi/audit.go b/coderd/agentapi/audit.go deleted file mode 100644 index 2025b2d6cd92b..0000000000000 --- a/coderd/agentapi/audit.go +++ /dev/null @@ -1,105 +0,0 @@ -package agentapi - -import ( - "context" - "encoding/json" - "strconv" - "sync/atomic" - - "github.com/google/uuid" - "golang.org/x/xerrors" - "google.golang.org/protobuf/types/known/emptypb" - - "cdr.dev/slog" - - agentproto "github.com/coder/coder/v2/agent/proto" - "github.com/coder/coder/v2/coderd/audit" - "github.com/coder/coder/v2/coderd/database" - "github.com/coder/coder/v2/coderd/database/db2sdk" - "github.com/coder/coder/v2/codersdk/agentsdk" -) - -type AuditAPI struct { - AgentFn func(context.Context) (database.WorkspaceAgent, error) - Auditor *atomic.Pointer[audit.Auditor] - Database database.Store - Log slog.Logger -} - -func (a *AuditAPI) ReportConnection(ctx context.Context, req *agentproto.ReportConnectionRequest) (*emptypb.Empty, error) { - // We will use connection ID as request ID, typically this is the - // SSH session ID as reported by the agent. - connectionID, err := uuid.FromBytes(req.GetConnection().GetId()) - if err != nil { - return nil, xerrors.Errorf("connection id from bytes: %w", err) - } - - action, err := db2sdk.AuditActionFromAgentProtoConnectionAction(req.GetConnection().GetAction()) - if err != nil { - return nil, err - } - connectionType, err := agentsdk.ConnectionTypeFromProto(req.GetConnection().GetType()) - if err != nil { - return nil, err - } - - // Fetch contextual data for this audit event. - workspaceAgent, err := a.AgentFn(ctx) - if err != nil { - return nil, xerrors.Errorf("get agent: %w", err) - } - workspace, err := a.Database.GetWorkspaceByAgentID(ctx, workspaceAgent.ID) - if err != nil { - return nil, xerrors.Errorf("get workspace by agent id: %w", err) - } - build, err := a.Database.GetLatestWorkspaceBuildByWorkspaceID(ctx, workspace.ID) - if err != nil { - return nil, xerrors.Errorf("get latest workspace build by workspace id: %w", err) - } - - // We pass the below information to the Auditor so that it - // can form a friendly string for the user to view in the UI. - type additionalFields struct { - audit.AdditionalFields - - ConnectionType agentsdk.ConnectionType `json:"connection_type"` - Reason string `json:"reason,omitempty"` - } - resourceInfo := additionalFields{ - AdditionalFields: audit.AdditionalFields{ - WorkspaceID: workspace.ID, - WorkspaceName: workspace.Name, - WorkspaceOwner: workspace.OwnerUsername, - BuildNumber: strconv.FormatInt(int64(build.BuildNumber), 10), - BuildReason: database.BuildReason(string(build.Reason)), - }, - ConnectionType: connectionType, - Reason: req.GetConnection().GetReason(), - } - - riBytes, err := json.Marshal(resourceInfo) - if err != nil { - a.Log.Error(ctx, "marshal resource info for agent connection failed", slog.Error(err)) - riBytes = []byte("{}") - } - - audit.BackgroundAudit(ctx, &audit.BackgroundAuditParams[database.WorkspaceAgent]{ - Audit: *a.Auditor.Load(), - Log: a.Log, - Time: req.GetConnection().GetTimestamp().AsTime(), - OrganizationID: workspace.OrganizationID, - RequestID: connectionID, - Action: action, - New: workspaceAgent, - Old: workspaceAgent, - IP: req.GetConnection().GetIp(), - Status: int(req.GetConnection().GetStatusCode()), - AdditionalFields: riBytes, - - // It's not possible to tell which user connected. Once we have - // the capability, this may be reported by the agent. - UserID: uuid.Nil, - }) - - return &emptypb.Empty{}, nil -} diff --git a/coderd/agentapi/connectionlog.go b/coderd/agentapi/connectionlog.go new file mode 100644 index 0000000000000..f26f835746981 --- /dev/null +++ b/coderd/agentapi/connectionlog.go @@ -0,0 +1,106 @@ +package agentapi + +import ( + "context" + "database/sql" + "sync/atomic" + + "github.com/google/uuid" + "golang.org/x/xerrors" + "google.golang.org/protobuf/types/known/emptypb" + + "cdr.dev/slog" + agentproto "github.com/coder/coder/v2/agent/proto" + "github.com/coder/coder/v2/coderd/connectionlog" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/db2sdk" +) + +type ConnLogAPI struct { + AgentFn func(context.Context) (database.WorkspaceAgent, error) + ConnectionLogger *atomic.Pointer[connectionlog.ConnectionLogger] + Database database.Store + Log slog.Logger +} + +func (a *ConnLogAPI) ReportConnection(ctx context.Context, req *agentproto.ReportConnectionRequest) (*emptypb.Empty, error) { + // We use the connection ID to identify which connection log event to mark + // as closed, when we receive a close action for that ID. + connectionID, err := uuid.FromBytes(req.GetConnection().GetId()) + if err != nil { + return nil, xerrors.Errorf("connection id from bytes: %w", err) + } + + if connectionID == uuid.Nil { + return nil, xerrors.New("connection ID cannot be nil") + } + action, err := db2sdk.ConnectionLogStatusFromAgentProtoConnectionAction(req.GetConnection().GetAction()) + if err != nil { + return nil, err + } + connectionType, err := db2sdk.ConnectionLogConnectionTypeFromAgentProtoConnectionType(req.GetConnection().GetType()) + if err != nil { + return nil, err + } + + var code sql.NullInt32 + if action == database.ConnectionStatusDisconnected { + code = sql.NullInt32{ + Int32: req.GetConnection().GetStatusCode(), + Valid: true, + } + } + + // Fetch contextual data for this connection log event. + workspaceAgent, err := a.AgentFn(ctx) + if err != nil { + return nil, xerrors.Errorf("get agent: %w", err) + } + workspace, err := a.Database.GetWorkspaceByAgentID(ctx, workspaceAgent.ID) + if err != nil { + return nil, xerrors.Errorf("get workspace by agent id: %w", err) + } + + reason := req.GetConnection().GetReason() + connLogger := *a.ConnectionLogger.Load() + err = connLogger.Upsert(ctx, database.UpsertConnectionLogParams{ + ID: uuid.New(), + Time: req.GetConnection().GetTimestamp().AsTime(), + OrganizationID: workspace.OrganizationID, + WorkspaceOwnerID: workspace.OwnerID, + WorkspaceID: workspace.ID, + WorkspaceName: workspace.Name, + AgentName: workspaceAgent.Name, + Type: connectionType, + Code: code, + Ip: database.ParseIP(req.GetConnection().GetIp()), + ConnectionID: uuid.NullUUID{ + UUID: connectionID, + Valid: true, + }, + DisconnectReason: sql.NullString{ + String: reason, + Valid: reason != "", + }, + // We supply the action: + // - So the DB can handle duplicate connections or disconnections properly. + // - To make it clear whether this is a connection or disconnection + // prior to it's insertion into the DB (logs) + ConnectionStatus: action, + + // It's not possible to tell which user connected. Once we have + // the capability, this may be reported by the agent. + UserID: uuid.NullUUID{ + Valid: false, + }, + // N/A + UserAgent: sql.NullString{}, + // N/A + SlugOrPort: sql.NullString{}, + }) + if err != nil { + return nil, xerrors.Errorf("export connection log: %w", err) + } + + return &emptypb.Empty{}, nil +} diff --git a/coderd/agentapi/audit_test.go b/coderd/agentapi/connectionlog_test.go similarity index 62% rename from coderd/agentapi/audit_test.go rename to coderd/agentapi/connectionlog_test.go index b881fde5d22bc..4a060b8f16faf 100644 --- a/coderd/agentapi/audit_test.go +++ b/coderd/agentapi/connectionlog_test.go @@ -2,7 +2,7 @@ package agentapi_test import ( "context" - "encoding/json" + "database/sql" "net" "sync/atomic" "testing" @@ -16,15 +16,14 @@ import ( agentproto "github.com/coder/coder/v2/agent/proto" "github.com/coder/coder/v2/coderd/agentapi" - "github.com/coder/coder/v2/coderd/audit" + "github.com/coder/coder/v2/coderd/connectionlog" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/db2sdk" "github.com/coder/coder/v2/coderd/database/dbmock" "github.com/coder/coder/v2/coderd/database/dbtime" - "github.com/coder/coder/v2/codersdk/agentsdk" ) -func TestAuditReport(t *testing.T) { +func TestConnectionLog(t *testing.T) { t.Parallel() var ( @@ -38,10 +37,6 @@ func TestAuditReport(t *testing.T) { OwnerID: owner.ID, Name: "cool-workspace", } - build = database.WorkspaceBuild{ - ID: uuid.New(), - WorkspaceID: workspace.ID, - } agent = database.WorkspaceAgent{ ID: uuid.New(), } @@ -62,7 +57,7 @@ func TestAuditReport(t *testing.T) { id: uuid.New(), action: agentproto.Connection_CONNECT.Enum(), typ: agentproto.Connection_SSH.Enum(), - time: time.Now(), + time: dbtime.Now(), ip: "127.0.0.1", status: 200, }, @@ -71,7 +66,7 @@ func TestAuditReport(t *testing.T) { id: uuid.New(), action: agentproto.Connection_CONNECT.Enum(), typ: agentproto.Connection_VSCODE.Enum(), - time: time.Now(), + time: dbtime.Now(), ip: "8.8.8.8", }, { @@ -79,28 +74,28 @@ func TestAuditReport(t *testing.T) { id: uuid.New(), action: agentproto.Connection_CONNECT.Enum(), typ: agentproto.Connection_JETBRAINS.Enum(), - time: time.Now(), + time: dbtime.Now(), }, { name: "Reconnecting PTY Connect", id: uuid.New(), action: agentproto.Connection_CONNECT.Enum(), typ: agentproto.Connection_RECONNECTING_PTY.Enum(), - time: time.Now(), + time: dbtime.Now(), }, { name: "SSH Disconnect", id: uuid.New(), action: agentproto.Connection_DISCONNECT.Enum(), typ: agentproto.Connection_SSH.Enum(), - time: time.Now(), + time: dbtime.Now(), }, { name: "SSH Disconnect", id: uuid.New(), action: agentproto.Connection_DISCONNECT.Enum(), typ: agentproto.Connection_SSH.Enum(), - time: time.Now(), + time: dbtime.Now(), status: 500, reason: "because error says so", }, @@ -110,15 +105,14 @@ func TestAuditReport(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - mAudit := audit.NewMock() + connLogger := connectionlog.NewFake() mDB := dbmock.NewMockStore(gomock.NewController(t)) mDB.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agent.ID).Return(workspace, nil) - mDB.EXPECT().GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), workspace.ID).Return(build, nil) - api := &agentapi.AuditAPI{ - Auditor: asAtomicPointer[audit.Auditor](mAudit), - Database: mDB, + api := &agentapi.ConnLogAPI{ + ConnectionLogger: asAtomicPointer[connectionlog.ConnectionLogger](connLogger), + Database: mDB, AgentFn: func(context.Context) (database.WorkspaceAgent, error) { return agent, nil }, @@ -135,41 +129,48 @@ func TestAuditReport(t *testing.T) { }, }) - require.True(t, mAudit.Contains(t, database.AuditLog{ - Time: dbtime.Time(tt.time).In(time.UTC), - Action: agentProtoConnectionActionToAudit(t, *tt.action), - OrganizationID: workspace.OrganizationID, - UserID: uuid.Nil, - RequestID: tt.id, - ResourceType: database.ResourceTypeWorkspaceAgent, - ResourceID: agent.ID, - ResourceTarget: agent.Name, - Ip: pqtype.Inet{Valid: true, IPNet: net.IPNet{IP: net.ParseIP(tt.ip), Mask: net.CIDRMask(32, 32)}}, - StatusCode: tt.status, - })) + require.True(t, connLogger.Contains(t, database.UpsertConnectionLogParams{ + Time: dbtime.Time(tt.time).In(time.UTC), + OrganizationID: workspace.OrganizationID, + WorkspaceOwnerID: workspace.OwnerID, + WorkspaceID: workspace.ID, + WorkspaceName: workspace.Name, + AgentName: agent.Name, + UserID: uuid.NullUUID{ + UUID: uuid.Nil, + Valid: false, + }, + ConnectionStatus: agentProtoConnectionActionToConnectionLog(t, *tt.action), - // Check some additional fields. - var m map[string]any - err := json.Unmarshal(mAudit.AuditLogs()[0].AdditionalFields, &m) - require.NoError(t, err) - require.Equal(t, string(agentProtoConnectionTypeToSDK(t, *tt.typ)), m["connection_type"].(string)) - if tt.reason != "" { - require.Equal(t, tt.reason, m["reason"]) - } + Code: sql.NullInt32{ + Int32: tt.status, + Valid: *tt.action == agentproto.Connection_DISCONNECT, + }, + Ip: pqtype.Inet{Valid: true, IPNet: net.IPNet{IP: net.ParseIP(tt.ip), Mask: net.CIDRMask(32, 32)}}, + Type: agentProtoConnectionTypeToConnectionLog(t, *tt.typ), + DisconnectReason: sql.NullString{ + String: tt.reason, + Valid: tt.reason != "", + }, + ConnectionID: uuid.NullUUID{ + UUID: tt.id, + Valid: tt.id != uuid.Nil, + }, + })) }) } } -func agentProtoConnectionActionToAudit(t *testing.T, action agentproto.Connection_Action) database.AuditAction { - a, err := db2sdk.AuditActionFromAgentProtoConnectionAction(action) +func agentProtoConnectionTypeToConnectionLog(t *testing.T, typ agentproto.Connection_Type) database.ConnectionType { + a, err := db2sdk.ConnectionLogConnectionTypeFromAgentProtoConnectionType(typ) require.NoError(t, err) return a } -func agentProtoConnectionTypeToSDK(t *testing.T, typ agentproto.Connection_Type) agentsdk.ConnectionType { - action, err := agentsdk.ConnectionTypeFromProto(typ) +func agentProtoConnectionActionToConnectionLog(t *testing.T, action agentproto.Connection_Action) database.ConnectionStatus { + a, err := db2sdk.ConnectionLogStatusFromAgentProtoConnectionAction(action) require.NoError(t, err) - return action + return a } func asAtomicPointer[T any](v T) *atomic.Pointer[T] { diff --git a/coderd/apidoc/docs.go b/coderd/apidoc/docs.go index 1ee6ea77af5d9..bcc7443c1c928 100644 --- a/coderd/apidoc/docs.go +++ b/coderd/apidoc/docs.go @@ -8778,6 +8778,41 @@ const docTemplate = `{ } } }, + "/workspaceagents/{workspaceagent}/containers/watch": { + "get": { + "security": [ + { + "CoderSessionToken": [] + } + ], + "produces": [ + "application/json" + ], + "tags": [ + "Agents" + ], + "summary": "Watch workspace agent for container updates.", + "operationId": "watch-workspace-agent-for-container-updates", + "parameters": [ + { + "type": "string", + "format": "uuid", + "description": "Workspace agent ID", + "name": "workspaceagent", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/codersdk.WorkspaceAgentListContainersResponse" + } + } + } + } + }, "/workspaceagents/{workspaceagent}/coordinate": { "get": { "security": [ @@ -9122,6 +9157,16 @@ const docTemplate = `{ "name": "workspacebuild", "in": "path", "required": true + }, + { + "enum": [ + "running", + "pending" + ], + "type": "string", + "description": "Expected status of the job. If expect_status is supplied, the request will be rejected with 412 Precondition Failed if the job doesn't match the state when performing the cancellation.", + "name": "expect_status", + "in": "query" } ], "responses": { @@ -11356,12 +11401,14 @@ const docTemplate = `{ "enum": [ "initiator", "autostart", - "autostop" + "autostop", + "dormancy" ], "x-enum-varnames": [ "BuildReasonInitiator", "BuildReasonAutostart", - "BuildReasonAutostop" + "BuildReasonAutostop", + "BuildReasonDormancy" ] }, "codersdk.ChangePasswordWithOneTimePasscodeRequest": { @@ -12293,9 +12340,6 @@ const docTemplate = `{ "http_cookies": { "$ref": "#/definitions/codersdk.HTTPCookieConfig" }, - "in_memory_database": { - "type": "boolean" - }, "job_hang_detector_interval": { "type": "integer" }, @@ -15320,6 +15364,7 @@ const docTemplate = `{ "assign_org_role", "assign_role", "audit_log", + "connection_log", "crypto_key", "debug_info", "deployment_config", @@ -15359,6 +15404,7 @@ const docTemplate = `{ "ResourceAssignOrgRole", "ResourceAssignRole", "ResourceAuditLog", + "ResourceConnectionLog", "ResourceCryptoKey", "ResourceDebugInfo", "ResourceDeploymentConfig", diff --git a/coderd/apidoc/swagger.json b/coderd/apidoc/swagger.json index b55a08caa8ec6..8485df8f2a745 100644 --- a/coderd/apidoc/swagger.json +++ b/coderd/apidoc/swagger.json @@ -7751,6 +7751,37 @@ } } }, + "/workspaceagents/{workspaceagent}/containers/watch": { + "get": { + "security": [ + { + "CoderSessionToken": [] + } + ], + "produces": ["application/json"], + "tags": ["Agents"], + "summary": "Watch workspace agent for container updates.", + "operationId": "watch-workspace-agent-for-container-updates", + "parameters": [ + { + "type": "string", + "format": "uuid", + "description": "Workspace agent ID", + "name": "workspaceagent", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/codersdk.WorkspaceAgentListContainersResponse" + } + } + } + } + }, "/workspaceagents/{workspaceagent}/coordinate": { "get": { "security": [ @@ -8065,6 +8096,13 @@ "name": "workspacebuild", "in": "path", "required": true + }, + { + "enum": ["running", "pending"], + "type": "string", + "description": "Expected status of the job. If expect_status is supplied, the request will be rejected with 412 Precondition Failed if the job doesn't match the state when performing the cancellation.", + "name": "expect_status", + "in": "query" } ], "responses": { @@ -10099,11 +10137,12 @@ }, "codersdk.BuildReason": { "type": "string", - "enum": ["initiator", "autostart", "autostop"], + "enum": ["initiator", "autostart", "autostop", "dormancy"], "x-enum-varnames": [ "BuildReasonInitiator", "BuildReasonAutostart", - "BuildReasonAutostop" + "BuildReasonAutostop", + "BuildReasonDormancy" ] }, "codersdk.ChangePasswordWithOneTimePasscodeRequest": { @@ -10981,9 +11020,6 @@ "http_cookies": { "$ref": "#/definitions/codersdk.HTTPCookieConfig" }, - "in_memory_database": { - "type": "boolean" - }, "job_hang_detector_interval": { "type": "integer" }, @@ -13900,6 +13936,7 @@ "assign_org_role", "assign_role", "audit_log", + "connection_log", "crypto_key", "debug_info", "deployment_config", @@ -13939,6 +13976,7 @@ "ResourceAssignOrgRole", "ResourceAssignRole", "ResourceAuditLog", + "ResourceConnectionLog", "ResourceCryptoKey", "ResourceDebugInfo", "ResourceDeploymentConfig", diff --git a/coderd/audit/request.go b/coderd/audit/request.go index 0fa88fa40e2ea..ae6a57e6c2775 100644 --- a/coderd/audit/request.go +++ b/coderd/audit/request.go @@ -6,13 +6,11 @@ import ( "encoding/json" "flag" "fmt" - "net" "net/http" "strconv" "time" "github.com/google/uuid" - "github.com/sqlc-dev/pqtype" "go.opentelemetry.io/otel/baggage" "golang.org/x/xerrors" @@ -434,7 +432,7 @@ func InitRequest[T Auditable](w http.ResponseWriter, p *RequestParams) (*Request action = req.Action } - ip := ParseIP(p.Request.RemoteAddr) + ip := database.ParseIP(p.Request.RemoteAddr) auditLog := database.AuditLog{ ID: uuid.New(), Time: dbtime.Now(), @@ -466,7 +464,7 @@ func InitRequest[T Auditable](w http.ResponseWriter, p *RequestParams) (*Request // BackgroundAudit creates an audit log for a background event. // The audit log is committed upon invocation. func BackgroundAudit[T Auditable](ctx context.Context, p *BackgroundAuditParams[T]) { - ip := ParseIP(p.IP) + ip := database.ParseIP(p.IP) diff := Diff(p.Audit, p.Old, p.New) var err error @@ -581,19 +579,3 @@ func either[T Auditable, R any](old, newVal T, fn func(T) R, auditAction databas panic("both old and new are nil") } } - -func ParseIP(ipStr string) pqtype.Inet { - ip := net.ParseIP(ipStr) - ipNet := net.IPNet{} - if ip != nil { - ipNet = net.IPNet{ - IP: ip, - Mask: net.CIDRMask(len(ip)*8, len(ip)*8), - } - } - - return pqtype.Inet{ - IPNet: ipNet, - Valid: ip != nil, - } -} diff --git a/coderd/autobuild/lifecycle_executor.go b/coderd/autobuild/lifecycle_executor.go index 1846b1ea18284..d49bf831515d0 100644 --- a/coderd/autobuild/lifecycle_executor.go +++ b/coderd/autobuild/lifecycle_executor.go @@ -520,6 +520,8 @@ func isEligibleForAutostart(user database.User, ws database.Workspace, build dat return false } + // Get the next allowed autostart time after the build's creation time, + // based on the workspace's schedule and the template's allowed days. nextTransition, err := schedule.NextAllowedAutostart(build.CreatedAt, ws.AutostartSchedule.String, templateSchedule) if err != nil { return false diff --git a/coderd/autobuild/lifecycle_executor_test.go b/coderd/autobuild/lifecycle_executor_test.go index 3bca6856534fa..0229a907cbb2e 100644 --- a/coderd/autobuild/lifecycle_executor_test.go +++ b/coderd/autobuild/lifecycle_executor_test.go @@ -2,9 +2,16 @@ package autobuild_test import ( "context" + "database/sql" + "errors" "testing" "time" + "github.com/coder/coder/v2/coderd/database/dbgen" + "github.com/coder/coder/v2/coderd/database/pubsub" + "github.com/coder/coder/v2/coderd/rbac" + "github.com/coder/quartz" + "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -1183,6 +1190,348 @@ func TestNotifications(t *testing.T) { }) } +// TestExecutorPrebuilds verifies AGPL behavior for prebuilt workspaces. +// It ensures that workspace schedules do not trigger while the workspace +// is still in a prebuilt state. Scheduling behavior only applies after the +// workspace has been claimed and becomes a regular user workspace. +// For enterprise-related functionality, see enterprise/coderd/workspaces_test.go. +func TestExecutorPrebuilds(t *testing.T) { + t.Parallel() + + if !dbtestutil.WillUsePostgres() { + t.Skip("this test requires postgres") + } + + // Prebuild workspaces should not be autostopped when the deadline is reached. + // After being claimed, the workspace should stop at the deadline. + t.Run("OnlyStopsAfterClaimed", func(t *testing.T) { + t.Parallel() + + // Setup + ctx := testutil.Context(t, testutil.WaitShort) + clock := quartz.NewMock(t) + db, pb := dbtestutil.NewDB(t, dbtestutil.WithDumpOnFailure()) + var ( + tickCh = make(chan time.Time) + statsCh = make(chan autobuild.Stats) + client = coderdtest.New(t, &coderdtest.Options{ + Database: db, + Pubsub: pb, + AutobuildTicker: tickCh, + IncludeProvisionerDaemon: true, + AutobuildStats: statsCh, + }) + ) + + // Setup user, template and template version + owner := coderdtest.CreateFirstUser(t, client) + _, user := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID, rbac.RoleMember()) + version := coderdtest.CreateTemplateVersion(t, client, owner.OrganizationID, nil) + coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID) + template := coderdtest.CreateTemplate(t, client, owner.OrganizationID, version.ID) + + // Database setup of a preset with a prebuild instance + preset := setupTestDBPreset(t, db, version.ID, int32(1)) + + // Given: a running prebuilt workspace with a deadline and ready to be claimed + dbPrebuild := setupTestDBPrebuiltWorkspace( + ctx, t, clock, db, pb, + owner.OrganizationID, + template.ID, + version.ID, + preset.ID, + ) + prebuild := coderdtest.MustWorkspace(t, client, dbPrebuild.ID) + require.Equal(t, codersdk.WorkspaceTransitionStart, prebuild.LatestBuild.Transition) + require.NotZero(t, prebuild.LatestBuild.Deadline) + + // When: the autobuild executor ticks *after* the deadline: + go func() { + tickCh <- prebuild.LatestBuild.Deadline.Time.Add(time.Minute) + }() + + // Then: the prebuilt workspace should remain in a start transition + prebuildStats := <-statsCh + require.Len(t, prebuildStats.Errors, 0) + require.Len(t, prebuildStats.Transitions, 0) + require.Equal(t, codersdk.WorkspaceTransitionStart, prebuild.LatestBuild.Transition) + prebuild = coderdtest.MustWorkspace(t, client, prebuild.ID) + require.Equal(t, codersdk.BuildReasonInitiator, prebuild.LatestBuild.Reason) + + // Given: a user claims the prebuilt workspace + dbWorkspace := dbgen.ClaimPrebuild(t, db, user.ID, "claimedWorkspace-autostop", preset.ID) + workspace := coderdtest.MustWorkspace(t, client, dbWorkspace.ID) + + // When: the autobuild executor ticks *after* the deadline: + go func() { + tickCh <- workspace.LatestBuild.Deadline.Time.Add(time.Minute) + close(tickCh) + }() + + // Then: the workspace should be stopped + workspaceStats := <-statsCh + require.Len(t, workspaceStats.Errors, 0) + require.Len(t, workspaceStats.Transitions, 1) + require.Contains(t, workspaceStats.Transitions, workspace.ID) + require.Equal(t, database.WorkspaceTransitionStop, workspaceStats.Transitions[workspace.ID]) + workspace = coderdtest.MustWorkspace(t, client, workspace.ID) + require.Equal(t, codersdk.BuildReasonAutostop, workspace.LatestBuild.Reason) + }) + + // Prebuild workspaces should not be autostarted when the autostart scheduled is reached. + // After being claimed, the workspace should autostart at the schedule. + t.Run("OnlyStartsAfterClaimed", func(t *testing.T) { + t.Parallel() + + // Setup + ctx := testutil.Context(t, testutil.WaitShort) + clock := quartz.NewMock(t) + db, pb := dbtestutil.NewDB(t, dbtestutil.WithDumpOnFailure()) + var ( + tickCh = make(chan time.Time) + statsCh = make(chan autobuild.Stats) + client = coderdtest.New(t, &coderdtest.Options{ + Database: db, + Pubsub: pb, + AutobuildTicker: tickCh, + IncludeProvisionerDaemon: true, + AutobuildStats: statsCh, + }) + ) + + // Setup user, template and template version + owner := coderdtest.CreateFirstUser(t, client) + _, user := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID, rbac.RoleMember()) + version := coderdtest.CreateTemplateVersion(t, client, owner.OrganizationID, nil) + coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID) + template := coderdtest.CreateTemplate(t, client, owner.OrganizationID, version.ID) + + // Database setup of a preset with a prebuild instance + preset := setupTestDBPreset(t, db, version.ID, int32(1)) + + // Given: prebuilt workspace is stopped and set to autostart daily at midnight + sched := mustSchedule(t, "CRON_TZ=UTC 0 0 * * *") + autostartSched := sql.NullString{ + String: sched.String(), + Valid: true, + } + dbPrebuild := setupTestDBPrebuiltWorkspace( + ctx, t, clock, db, pb, + owner.OrganizationID, + template.ID, + version.ID, + preset.ID, + WithAutostartSchedule(autostartSched), + WithIsStopped(true), + ) + prebuild := coderdtest.MustWorkspace(t, client, dbPrebuild.ID) + require.Equal(t, codersdk.WorkspaceTransitionStop, prebuild.LatestBuild.Transition) + require.NotNil(t, prebuild.AutostartSchedule) + + // Tick at the next scheduled time after the prebuild’s LatestBuild.CreatedAt, + // since the next allowed autostart is calculated starting from that point. + // When: the autobuild executor ticks after the scheduled time + go func() { + tickCh <- sched.Next(prebuild.LatestBuild.CreatedAt).Add(time.Minute) + }() + + // Then: the prebuilt workspace should remain in a stop transition + prebuildStats := <-statsCh + require.Len(t, prebuildStats.Errors, 0) + require.Len(t, prebuildStats.Transitions, 0) + require.Equal(t, codersdk.WorkspaceTransitionStop, prebuild.LatestBuild.Transition) + prebuild = coderdtest.MustWorkspace(t, client, prebuild.ID) + require.Equal(t, codersdk.BuildReasonInitiator, prebuild.LatestBuild.Reason) + + // Given: prebuilt workspace is in a start status + setupTestDBWorkspaceBuild( + ctx, t, clock, db, pb, + owner.OrganizationID, + prebuild.ID, + version.ID, + preset.ID, + database.WorkspaceTransitionStart) + + // Given: a user claims the prebuilt workspace + dbWorkspace := dbgen.ClaimPrebuild(t, db, user.ID, "claimedWorkspace-autostart", preset.ID) + workspace := coderdtest.MustWorkspace(t, client, dbWorkspace.ID) + + // Given: the prebuilt workspace goes to a stop status + setupTestDBWorkspaceBuild( + ctx, t, clock, db, pb, + owner.OrganizationID, + prebuild.ID, + version.ID, + preset.ID, + database.WorkspaceTransitionStop) + + // Tick at the next scheduled time after the prebuild’s LatestBuild.CreatedAt, + // since the next allowed autostart is calculated starting from that point. + // When: the autobuild executor ticks after the scheduled time + go func() { + tickCh <- sched.Next(workspace.LatestBuild.CreatedAt).Add(time.Minute) + close(tickCh) + }() + + // Then: the workspace should eventually be started + workspaceStats := <-statsCh + require.Len(t, workspaceStats.Errors, 0) + require.Len(t, workspaceStats.Transitions, 1) + require.Contains(t, workspaceStats.Transitions, workspace.ID) + require.Equal(t, database.WorkspaceTransitionStart, workspaceStats.Transitions[workspace.ID]) + workspace = coderdtest.MustWorkspace(t, client, workspace.ID) + require.Equal(t, codersdk.BuildReasonAutostart, workspace.LatestBuild.Reason) + }) +} + +func setupTestDBPreset( + t *testing.T, + db database.Store, + templateVersionID uuid.UUID, + desiredInstances int32, +) database.TemplateVersionPreset { + t.Helper() + + preset := dbgen.Preset(t, db, database.InsertPresetParams{ + TemplateVersionID: templateVersionID, + Name: "preset-test", + DesiredInstances: sql.NullInt32{ + Valid: true, + Int32: desiredInstances, + }, + }) + dbgen.PresetParameter(t, db, database.InsertPresetParametersParams{ + TemplateVersionPresetID: preset.ID, + Names: []string{"test-name"}, + Values: []string{"test-value"}, + }) + + return preset +} + +type SetupPrebuiltOptions struct { + AutostartSchedule sql.NullString + IsStopped bool +} + +func WithAutostartSchedule(sched sql.NullString) func(*SetupPrebuiltOptions) { + return func(o *SetupPrebuiltOptions) { + o.AutostartSchedule = sched + } +} + +func WithIsStopped(isStopped bool) func(*SetupPrebuiltOptions) { + return func(o *SetupPrebuiltOptions) { + o.IsStopped = isStopped + } +} + +func setupTestDBWorkspaceBuild( + ctx context.Context, + t *testing.T, + clock quartz.Clock, + db database.Store, + ps pubsub.Pubsub, + orgID uuid.UUID, + workspaceID uuid.UUID, + templateVersionID uuid.UUID, + presetID uuid.UUID, + transition database.WorkspaceTransition, +) (database.ProvisionerJob, database.WorkspaceBuild) { + t.Helper() + + var buildNumber int32 = 1 + latestWorkspaceBuild, err := db.GetLatestWorkspaceBuildByWorkspaceID(ctx, workspaceID) + if !errors.Is(err, sql.ErrNoRows) { + buildNumber = latestWorkspaceBuild.BuildNumber + 1 + } + + job := dbgen.ProvisionerJob(t, db, ps, database.ProvisionerJob{ + InitiatorID: database.PrebuildsSystemUserID, + CreatedAt: clock.Now().Add(-time.Hour * 2), + StartedAt: sql.NullTime{Time: clock.Now().Add(-time.Hour * 2), Valid: true}, + CompletedAt: sql.NullTime{Time: clock.Now().Add(-time.Hour), Valid: true}, + OrganizationID: orgID, + JobStatus: database.ProvisionerJobStatusSucceeded, + }) + workspaceBuild := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{ + WorkspaceID: workspaceID, + InitiatorID: database.PrebuildsSystemUserID, + TemplateVersionID: templateVersionID, + BuildNumber: buildNumber, + JobID: job.ID, + TemplateVersionPresetID: uuid.NullUUID{UUID: presetID, Valid: true}, + Transition: transition, + CreatedAt: clock.Now(), + }) + dbgen.WorkspaceBuildParameters(t, db, []database.WorkspaceBuildParameter{ + { + WorkspaceBuildID: workspaceBuild.ID, + Name: "test", + Value: "test", + }, + }) + + workspaceResource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{ + JobID: job.ID, + Transition: database.WorkspaceTransitionStart, + Type: "compute", + Name: "main", + }) + + // Workspaces are eligible to be claimed once their agent is marked "ready" + dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{ + Name: "test", + ResourceID: workspaceResource.ID, + Architecture: "i386", + OperatingSystem: "linux", + LifecycleState: database.WorkspaceAgentLifecycleStateReady, + StartedAt: sql.NullTime{Time: time.Now().Add(time.Hour), Valid: true}, + ReadyAt: sql.NullTime{Time: time.Now().Add(-1 * time.Hour), Valid: true}, + APIKeyScope: database.AgentKeyScopeEnumAll, + }) + + return job, workspaceBuild +} + +func setupTestDBPrebuiltWorkspace( + ctx context.Context, + t *testing.T, + clock quartz.Clock, + db database.Store, + ps pubsub.Pubsub, + orgID uuid.UUID, + templateID uuid.UUID, + templateVersionID uuid.UUID, + presetID uuid.UUID, + opts ...func(*SetupPrebuiltOptions), +) database.WorkspaceTable { + t.Helper() + + // Optional parameters + options := &SetupPrebuiltOptions{} + for _, opt := range opts { + opt(options) + } + + buildTransition := database.WorkspaceTransitionStart + if options.IsStopped { + buildTransition = database.WorkspaceTransitionStop + } + + workspace := dbgen.Workspace(t, db, database.WorkspaceTable{ + TemplateID: templateID, + OrganizationID: orgID, + OwnerID: database.PrebuildsSystemUserID, + Deleted: false, + CreatedAt: time.Now().Add(-time.Hour * 2), + AutostartSchedule: options.AutostartSchedule, + }) + setupTestDBWorkspaceBuild(ctx, t, clock, db, ps, orgID, workspace.ID, templateVersionID, presetID, buildTransition) + + return workspace +} + func mustProvisionWorkspace(t *testing.T, client *codersdk.Client, mut ...func(*codersdk.CreateWorkspaceRequest)) codersdk.Workspace { t.Helper() user := coderdtest.CreateFirstUser(t, client) diff --git a/coderd/coderd.go b/coderd/coderd.go index 72316d1ea18e5..c3c1fb09cc6cc 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -59,6 +59,7 @@ import ( "github.com/coder/coder/v2/coderd/appearance" "github.com/coder/coder/v2/coderd/audit" "github.com/coder/coder/v2/coderd/awsidentity" + "github.com/coder/coder/v2/coderd/connectionlog" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/database/dbrollup" @@ -154,6 +155,7 @@ type Options struct { CacheDir string Auditor audit.Auditor + ConnectionLogger connectionlog.ConnectionLogger AgentConnectionUpdateFrequency time.Duration AgentInactiveDisconnectTimeout time.Duration AWSCertificates awsidentity.Certificates @@ -400,6 +402,9 @@ func New(options *Options) *API { if options.Auditor == nil { options.Auditor = audit.NewNop() } + if options.ConnectionLogger == nil { + options.ConnectionLogger = connectionlog.NewNop() + } if options.SSHConfig.HostnamePrefix == "" { options.SSHConfig.HostnamePrefix = "coder." } @@ -568,6 +573,7 @@ func New(options *Options) *API { }, metricsCache: metricsCache, Auditor: atomic.Pointer[audit.Auditor]{}, + ConnectionLogger: atomic.Pointer[connectionlog.ConnectionLogger]{}, TailnetCoordinator: atomic.Pointer[tailnet.Coordinator]{}, UpdatesProvider: updatesProvider, TemplateScheduleStore: options.TemplateScheduleStore, @@ -589,7 +595,7 @@ func New(options *Options) *API { options.Logger.Named("workspaceapps"), options.AccessURL, options.Authorizer, - &api.Auditor, + &api.ConnectionLogger, options.Database, options.DeploymentValues, oauthConfigs, @@ -691,6 +697,7 @@ func New(options *Options) *API { } api.Auditor.Store(&options.Auditor) + api.ConnectionLogger.Store(&options.ConnectionLogger) api.TailnetCoordinator.Store(&options.TailnetCoordinator) dialer := &InmemTailnetDialer{ CoordPtr: &api.TailnetCoordinator, @@ -1351,6 +1358,7 @@ func New(options *Options) *API { r.Get("/listening-ports", api.workspaceAgentListeningPorts) r.Get("/connection", api.workspaceAgentConnection) r.Get("/containers", api.workspaceAgentListContainers) + r.Get("/containers/watch", api.watchWorkspaceAgentContainers) r.Post("/containers/devcontainers/{devcontainer}/recreate", api.workspaceAgentRecreateDevcontainer) r.Get("/coordinate", api.workspaceAgentClientCoordinate) @@ -1612,6 +1620,7 @@ type API struct { // specific replica. ID uuid.UUID Auditor atomic.Pointer[audit.Auditor] + ConnectionLogger atomic.Pointer[connectionlog.ConnectionLogger] WorkspaceClientCoordinateOverride atomic.Pointer[func(rw http.ResponseWriter) bool] TailnetCoordinator atomic.Pointer[tailnet.Coordinator] NetworkTelemetryBatcher *tailnet.NetworkTelemetryBatcher diff --git a/coderd/coderdtest/authorize.go b/coderd/coderdtest/authorize.go index 67551d0e3d2dd..68ab5a27e5a18 100644 --- a/coderd/coderdtest/authorize.go +++ b/coderd/coderdtest/authorize.go @@ -451,6 +451,7 @@ func randomRBACType() string { all := []string{ rbac.ResourceWorkspace.Type, rbac.ResourceAuditLog.Type, + rbac.ResourceConnectionLog.Type, rbac.ResourceTemplate.Type, rbac.ResourceGroup.Type, rbac.ResourceFile.Type, diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index 55e62561af60a..96030b215e5dd 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -1,6 +1,7 @@ package coderdtest import ( + "archive/tar" "bytes" "context" "crypto" @@ -52,6 +53,7 @@ import ( "cdr.dev/slog" "cdr.dev/slog/sloggers/sloghuman" "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/v2/archive" "github.com/coder/coder/v2/coderd/files" "github.com/coder/quartz" @@ -59,6 +61,7 @@ import ( "github.com/coder/coder/v2/coderd/audit" "github.com/coder/coder/v2/coderd/autobuild" "github.com/coder/coder/v2/coderd/awsidentity" + "github.com/coder/coder/v2/coderd/connectionlog" "github.com/coder/coder/v2/coderd/cryptokeys" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/db2sdk" @@ -123,6 +126,7 @@ type Options struct { TemplateScheduleStore schedule.TemplateScheduleStore Coordinator tailnet.Coordinator CoordinatorResumeTokenProvider tailnet.ResumeTokenProvider + ConnectionLogger connectionlog.ConnectionLogger HealthcheckFunc func(ctx context.Context, apiKey string) *healthsdk.HealthcheckReport HealthcheckTimeout time.Duration @@ -354,6 +358,12 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can } auditor.Store(&options.Auditor) + var connectionLogger atomic.Pointer[connectionlog.ConnectionLogger] + if options.ConnectionLogger == nil { + options.ConnectionLogger = connectionlog.NewNop() + } + connectionLogger.Store(&options.ConnectionLogger) + ctx, cancelFunc := context.WithCancel(context.Background()) experiments := coderd.ReadExperiments(*options.Logger, options.DeploymentValues.Experiments) lifecycleExecutor := autobuild.NewExecutor( @@ -541,6 +551,7 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can ExternalAuthConfigs: options.ExternalAuthConfigs, Auditor: options.Auditor, + ConnectionLogger: options.ConnectionLogger, AWSCertificates: options.AWSCertificates, AzureCertificates: options.AzureCertificates, GithubOAuth2Config: options.GithubOAuth2Config, @@ -886,14 +897,22 @@ func createAnotherUserRetry(t testing.TB, client *codersdk.Client, organizationI return other, user } -// CreateTemplateVersion creates a template import provisioner job -// with the responses provided. It uses the "echo" provisioner for compatibility -// with testing. -func CreateTemplateVersion(t testing.TB, client *codersdk.Client, organizationID uuid.UUID, res *echo.Responses, mutators ...func(*codersdk.CreateTemplateVersionRequest)) codersdk.TemplateVersion { +func CreateTemplateVersionMimeType(t testing.TB, client *codersdk.Client, mimeType string, organizationID uuid.UUID, res *echo.Responses, mutators ...func(*codersdk.CreateTemplateVersionRequest)) codersdk.TemplateVersion { t.Helper() data, err := echo.TarWithOptions(context.Background(), client.Logger(), res) require.NoError(t, err) - file, err := client.Upload(context.Background(), codersdk.ContentTypeTar, bytes.NewReader(data)) + + switch mimeType { + case codersdk.ContentTypeTar: + // do nothing + case codersdk.ContentTypeZip: + data, err = archive.CreateZipFromTar(tar.NewReader(bytes.NewBuffer(data)), int64(len(data))) + require.NoError(t, err, "creating zip") + default: + t.Fatal("unexpected mime type", mimeType) + } + + file, err := client.Upload(context.Background(), mimeType, bytes.NewReader(data)) require.NoError(t, err) req := codersdk.CreateTemplateVersionRequest{ @@ -910,6 +929,13 @@ func CreateTemplateVersion(t testing.TB, client *codersdk.Client, organizationID return templateVersion } +// CreateTemplateVersion creates a template import provisioner job +// with the responses provided. It uses the "echo" provisioner for compatibility +// with testing. +func CreateTemplateVersion(t testing.TB, client *codersdk.Client, organizationID uuid.UUID, res *echo.Responses, mutators ...func(*codersdk.CreateTemplateVersionRequest)) codersdk.TemplateVersion { + return CreateTemplateVersionMimeType(t, client, codersdk.ContentTypeTar, organizationID, res, mutators...) +} + // CreateWorkspaceBuild creates a workspace build for the given workspace and transition. func CreateWorkspaceBuild( t *testing.T, diff --git a/coderd/coderdtest/dynamicparameters.go b/coderd/coderdtest/dynamicparameters.go index 5d03f9fde9639..28e01885560ca 100644 --- a/coderd/coderdtest/dynamicparameters.go +++ b/coderd/coderdtest/dynamicparameters.go @@ -20,6 +20,9 @@ type DynamicParameterTemplateParams struct { Plan json.RawMessage ModulesArchive []byte + // Uses a zip archive instead of a tar + Zip bool + // StaticParams is used if the provisioner daemon version does not support dynamic parameters. StaticParams []*proto.RichParameter @@ -45,7 +48,11 @@ func DynamicParameterTemplate(t *testing.T, client *codersdk.Client, org uuid.UU }, }} - version := CreateTemplateVersion(t, client, org, files, func(request *codersdk.CreateTemplateVersionRequest) { + mime := codersdk.ContentTypeTar + if args.Zip { + mime = codersdk.ContentTypeZip + } + version := CreateTemplateVersionMimeType(t, client, mime, org, files, func(request *codersdk.CreateTemplateVersionRequest) { if args.TemplateID != uuid.Nil { request.TemplateID = args.TemplateID } diff --git a/coderd/connectionlog/connectionlog.go b/coderd/connectionlog/connectionlog.go new file mode 100644 index 0000000000000..1b56ffc288fd3 --- /dev/null +++ b/coderd/connectionlog/connectionlog.go @@ -0,0 +1,121 @@ +package connectionlog + +import ( + "context" + "sync" + "testing" + + "github.com/google/uuid" + + "github.com/coder/coder/v2/coderd/database" +) + +type ConnectionLogger interface { + Upsert(ctx context.Context, clog database.UpsertConnectionLogParams) error +} + +type nop struct{} + +func NewNop() ConnectionLogger { + return nop{} +} + +func (nop) Upsert(context.Context, database.UpsertConnectionLogParams) error { + return nil +} + +func NewFake() *FakeConnectionLogger { + return &FakeConnectionLogger{} +} + +type FakeConnectionLogger struct { + mu sync.Mutex + upsertions []database.UpsertConnectionLogParams +} + +func (m *FakeConnectionLogger) Reset() { + m.mu.Lock() + defer m.mu.Unlock() + m.upsertions = make([]database.UpsertConnectionLogParams, 0) +} + +func (m *FakeConnectionLogger) ConnectionLogs() []database.UpsertConnectionLogParams { + m.mu.Lock() + defer m.mu.Unlock() + return m.upsertions +} + +func (m *FakeConnectionLogger) Upsert(_ context.Context, clog database.UpsertConnectionLogParams) error { + m.mu.Lock() + defer m.mu.Unlock() + + m.upsertions = append(m.upsertions, clog) + + return nil +} + +func (m *FakeConnectionLogger) Contains(t testing.TB, expected database.UpsertConnectionLogParams) bool { + m.mu.Lock() + defer m.mu.Unlock() + for idx, cl := range m.upsertions { + if expected.ID != uuid.Nil && cl.ID != expected.ID { + t.Logf("connection log %d: expected ID %s, got %s", idx+1, expected.ID, cl.ID) + continue + } + if !expected.Time.IsZero() && expected.Time != cl.Time { + t.Logf("connection log %d: expected Time %s, got %s", idx+1, expected.Time, cl.Time) + continue + } + if expected.OrganizationID != uuid.Nil && cl.OrganizationID != expected.OrganizationID { + t.Logf("connection log %d: expected OrganizationID %s, got %s", idx+1, expected.OrganizationID, cl.OrganizationID) + continue + } + if expected.WorkspaceOwnerID != uuid.Nil && cl.WorkspaceOwnerID != expected.WorkspaceOwnerID { + t.Logf("connection log %d: expected WorkspaceOwnerID %s, got %s", idx+1, expected.WorkspaceOwnerID, cl.WorkspaceOwnerID) + continue + } + if expected.WorkspaceID != uuid.Nil && cl.WorkspaceID != expected.WorkspaceID { + t.Logf("connection log %d: expected WorkspaceID %s, got %s", idx+1, expected.WorkspaceID, cl.WorkspaceID) + continue + } + if expected.WorkspaceName != "" && cl.WorkspaceName != expected.WorkspaceName { + t.Logf("connection log %d: expected WorkspaceName %s, got %s", idx+1, expected.WorkspaceName, cl.WorkspaceName) + continue + } + if expected.AgentName != "" && cl.AgentName != expected.AgentName { + t.Logf("connection log %d: expected AgentName %s, got %s", idx+1, expected.AgentName, cl.AgentName) + continue + } + if expected.Type != "" && cl.Type != expected.Type { + t.Logf("connection log %d: expected Type %s, got %s", idx+1, expected.Type, cl.Type) + continue + } + if expected.Code.Valid && cl.Code.Int32 != expected.Code.Int32 { + t.Logf("connection log %d: expected Code %d, got %d", idx+1, expected.Code.Int32, cl.Code.Int32) + continue + } + if expected.Ip.Valid && cl.Ip.IPNet.String() != expected.Ip.IPNet.String() { + t.Logf("connection log %d: expected IP %s, got %s", idx+1, expected.Ip.IPNet, cl.Ip.IPNet) + continue + } + if expected.UserAgent.Valid && cl.UserAgent.String != expected.UserAgent.String { + t.Logf("connection log %d: expected UserAgent %s, got %s", idx+1, expected.UserAgent.String, cl.UserAgent.String) + continue + } + if expected.UserID.Valid && cl.UserID.UUID != expected.UserID.UUID { + t.Logf("connection log %d: expected UserID %s, got %s", idx+1, expected.UserID.UUID, cl.UserID.UUID) + continue + } + if expected.SlugOrPort.Valid && cl.SlugOrPort.String != expected.SlugOrPort.String { + t.Logf("connection log %d: expected SlugOrPort %s, got %s", idx+1, expected.SlugOrPort.String, cl.SlugOrPort.String) + continue + } + if expected.ConnectionID.Valid && cl.ConnectionID.UUID != expected.ConnectionID.UUID { + t.Logf("connection log %d: expected ConnectionID %s, got %s", idx+1, expected.ConnectionID.UUID, cl.ConnectionID.UUID) + continue + } + return true + } + + return false +} diff --git a/coderd/database/db2sdk/db2sdk.go b/coderd/database/db2sdk/db2sdk.go index 5e9be4d61a57c..320a90b09430b 100644 --- a/coderd/database/db2sdk/db2sdk.go +++ b/coderd/database/db2sdk/db2sdk.go @@ -781,26 +781,31 @@ func TemplateRoleActions(role codersdk.TemplateRole) []policy.Action { return []policy.Action{} } -func AuditActionFromAgentProtoConnectionAction(action agentproto.Connection_Action) (database.AuditAction, error) { - switch action { - case agentproto.Connection_CONNECT: - return database.AuditActionConnect, nil - case agentproto.Connection_DISCONNECT: - return database.AuditActionDisconnect, nil +func ConnectionLogConnectionTypeFromAgentProtoConnectionType(typ agentproto.Connection_Type) (database.ConnectionType, error) { + switch typ { + case agentproto.Connection_SSH: + return database.ConnectionTypeSsh, nil + case agentproto.Connection_JETBRAINS: + return database.ConnectionTypeJetbrains, nil + case agentproto.Connection_VSCODE: + return database.ConnectionTypeVscode, nil + case agentproto.Connection_RECONNECTING_PTY: + return database.ConnectionTypeReconnectingPty, nil default: - // Also Connection_ACTION_UNSPECIFIED, no mapping. - return "", xerrors.Errorf("unknown agent connection action %q", action) + // Also Connection_TYPE_UNSPECIFIED, no mapping. + return "", xerrors.Errorf("unknown agent connection type %q", typ) } } -func AgentProtoConnectionActionToAuditAction(action database.AuditAction) (agentproto.Connection_Action, error) { +func ConnectionLogStatusFromAgentProtoConnectionAction(action agentproto.Connection_Action) (database.ConnectionStatus, error) { switch action { - case database.AuditActionConnect: - return agentproto.Connection_CONNECT, nil - case database.AuditActionDisconnect: - return agentproto.Connection_DISCONNECT, nil + case agentproto.Connection_CONNECT: + return database.ConnectionStatusConnected, nil + case agentproto.Connection_DISCONNECT: + return database.ConnectionStatusDisconnected, nil default: - return agentproto.Connection_ACTION_UNSPECIFIED, xerrors.Errorf("unknown agent connection action %q", action) + // Also Connection_ACTION_UNSPECIFIED, no mapping. + return "", xerrors.Errorf("unknown agent connection action %q", action) } } diff --git a/coderd/database/db_test.go b/coderd/database/db_test.go index 68b60a788fd3d..f9442942e53e1 100644 --- a/coderd/database/db_test.go +++ b/coderd/database/db_test.go @@ -85,6 +85,10 @@ func TestNestedInTx(t *testing.T) { func testSQLDB(t testing.TB) *sql.DB { t.Helper() + if !dbtestutil.WillUsePostgres() { + t.Skip("this test requires postgres") + } + connection, err := dbtestutil.Open(t) require.NoError(t, err) diff --git a/coderd/database/dbauthz/customroles_test.go b/coderd/database/dbauthz/customroles_test.go index 5e19f43ab5376..54541d4670c2c 100644 --- a/coderd/database/dbauthz/customroles_test.go +++ b/coderd/database/dbauthz/customroles_test.go @@ -12,7 +12,7 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/db2sdk" "github.com/coder/coder/v2/coderd/database/dbauthz" - "github.com/coder/coder/v2/coderd/database/dbmem" + "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/rbac" "github.com/coder/coder/v2/coderd/rbac/policy" "github.com/coder/coder/v2/codersdk" @@ -202,7 +202,7 @@ func TestInsertCustomRoles(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) rec := &coderdtest.RecordingAuthorizer{ Wrapped: rbac.NewAuthorizer(prometheus.NewRegistry()), } diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index eea1b04a51fc5..a1c758ce03415 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -306,6 +306,24 @@ var ( Scope: rbac.ScopeAll, }.WithCachedASTValue() + subjectConnectionLogger = rbac.Subject{ + Type: rbac.SubjectTypeConnectionLogger, + FriendlyName: "Connection Logger", + ID: uuid.Nil.String(), + Roles: rbac.Roles([]rbac.Role{ + { + Identifier: rbac.RoleIdentifier{Name: "connectionlogger"}, + DisplayName: "Connection Logger", + Site: rbac.Permissions(map[string][]policy.Action{ + rbac.ResourceConnectionLog.Type: {policy.ActionUpdate, policy.ActionRead}, + }), + Org: map[string][]rbac.Permission{}, + User: []rbac.Permission{}, + }, + }), + Scope: rbac.ScopeAll, + }.WithCachedASTValue() + subjectNotifier = rbac.Subject{ Type: rbac.SubjectTypeNotifier, FriendlyName: "Notifier", @@ -521,6 +539,10 @@ func AsKeyReader(ctx context.Context) context.Context { return As(ctx, subjectCryptoKeyReader) } +func AsConnectionLogger(ctx context.Context) context.Context { + return As(ctx, subjectConnectionLogger) +} + // AsNotifier returns a context with an actor that has permissions required for // creating/reading/updating/deleting notifications. func AsNotifier(ctx context.Context) context.Context { @@ -1182,6 +1204,27 @@ func (q *querier) customRoleCheck(ctx context.Context, role database.CustomRole) return nil } +func (q *querier) authorizeProvisionerJob(ctx context.Context, job database.ProvisionerJob) error { + switch job.Type { + case database.ProvisionerJobTypeWorkspaceBuild: + // Authorized call to get workspace build. If we can read the build, we + // can read the job. + _, err := q.GetWorkspaceBuildByJobID(ctx, job.ID) + if err != nil { + return xerrors.Errorf("fetch related workspace build: %w", err) + } + case database.ProvisionerJobTypeTemplateVersionDryRun, database.ProvisionerJobTypeTemplateVersionImport: + // Authorized call to get template version. + _, err := authorizedTemplateVersionFromJob(ctx, q, job) + if err != nil { + return xerrors.Errorf("fetch related template version: %w", err) + } + default: + return xerrors.Errorf("unknown job type: %q", job.Type) + } + return nil +} + func (q *querier) AcquireLock(ctx context.Context, id int64) error { return q.db.AcquireLock(ctx, id) } @@ -1835,6 +1878,21 @@ func (q *querier) GetAuthorizationUserRoles(ctx context.Context, userID uuid.UUI return q.db.GetAuthorizationUserRoles(ctx, userID) } +func (q *querier) GetConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams) ([]database.GetConnectionLogsOffsetRow, error) { + // Just like with the audit logs query, shortcut if the user is an owner. + err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceConnectionLog) + if err == nil { + return q.db.GetConnectionLogsOffset(ctx, arg) + } + + prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceConnectionLog.Type) + if err != nil { + return nil, xerrors.Errorf("(dev error) prepare sql filter: %w", err) + } + + return q.db.GetAuthorizedConnectionLogsOffset(ctx, arg, prep) +} + func (q *querier) GetCoordinatorResumeTokenSigningKey(ctx context.Context) (string, error) { if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil { return "", err @@ -2445,32 +2503,24 @@ func (q *querier) GetProvisionerJobByID(ctx context.Context, id uuid.UUID) (data return database.ProvisionerJob{}, err } - switch job.Type { - case database.ProvisionerJobTypeWorkspaceBuild: - // Authorized call to get workspace build. If we can read the build, we - // can read the job. - _, err := q.GetWorkspaceBuildByJobID(ctx, id) - if err != nil { - return database.ProvisionerJob{}, xerrors.Errorf("fetch related workspace build: %w", err) - } - case database.ProvisionerJobTypeTemplateVersionDryRun, database.ProvisionerJobTypeTemplateVersionImport: - // Authorized call to get template version. - _, err := authorizedTemplateVersionFromJob(ctx, q, job) - if err != nil { - return database.ProvisionerJob{}, xerrors.Errorf("fetch related template version: %w", err) - } - default: - return database.ProvisionerJob{}, xerrors.Errorf("unknown job type: %q", job.Type) + if err := q.authorizeProvisionerJob(ctx, job); err != nil { + return database.ProvisionerJob{}, err } return job, nil } func (q *querier) GetProvisionerJobByIDForUpdate(ctx context.Context, id uuid.UUID) (database.ProvisionerJob, error) { - if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceProvisionerJobs); err != nil { + job, err := q.db.GetProvisionerJobByIDForUpdate(ctx, id) + if err != nil { return database.ProvisionerJob{}, err } - return q.db.GetProvisionerJobByIDForUpdate(ctx, id) + + if err := q.authorizeProvisionerJob(ctx, job); err != nil { + return database.ProvisionerJob{}, err + } + + return job, nil } func (q *querier) GetProvisionerJobTimingsByJobID(ctx context.Context, jobID uuid.UUID) ([]database.ProvisionerJobTiming, error) { @@ -2583,6 +2633,14 @@ func (q *querier) GetRunningPrebuiltWorkspaces(ctx context.Context) ([]database. return q.db.GetRunningPrebuiltWorkspaces(ctx) } +func (q *querier) GetRunningPrebuiltWorkspacesOptimized(ctx context.Context) ([]database.GetRunningPrebuiltWorkspacesOptimizedRow, error) { + // This query returns only prebuilt workspaces, but we decided to require permissions for all workspaces. + if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceWorkspace.All()); err != nil { + return nil, err + } + return q.db.GetRunningPrebuiltWorkspacesOptimized(ctx) +} + func (q *querier) GetRuntimeConfig(ctx context.Context, key string) (string, error) { if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil { return "", err @@ -5078,6 +5136,13 @@ func (q *querier) UpsertApplicationName(ctx context.Context, value string) error return q.db.UpsertApplicationName(ctx, value) } +func (q *querier) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) { + if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceConnectionLog); err != nil { + return database.ConnectionLog{}, err + } + return q.db.UpsertConnectionLog(ctx, arg) +} + func (q *querier) UpsertCoordinatorResumeTokenSigningKey(ctx context.Context, value string) error { if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil { return err @@ -5323,3 +5388,7 @@ func (q *querier) GetAuthorizedAuditLogsOffset(ctx context.Context, arg database func (q *querier) CountAuthorizedAuditLogs(ctx context.Context, arg database.CountAuditLogsParams, _ rbac.PreparedAuthorized) (int64, error) { return q.CountAuditLogs(ctx, arg) } + +func (q *querier) GetAuthorizedConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams, _ rbac.PreparedAuthorized) ([]database.GetConnectionLogsOffsetRow, error) { + return q.GetConnectionLogsOffset(ctx, arg) +} diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 006320ef459a4..5416f33e521ec 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -19,7 +19,6 @@ import ( "cdr.dev/slog" "github.com/coder/coder/v2/coderd/database/db2sdk" - "github.com/coder/coder/v2/coderd/database/dbmem" "github.com/coder/coder/v2/coderd/notifications" "github.com/coder/coder/v2/coderd/rbac/policy" "github.com/coder/coder/v2/codersdk" @@ -179,7 +178,8 @@ func TestDBAuthzRecursive(t *testing.T) { if method.Name == "InTx" || method.Name == "Ping" || method.Name == "Wrappers" || - method.Name == "PGLocks" { + method.Name == "PGLocks" || + method.Name == "GetRunningPrebuiltWorkspacesOptimized" { continue } // easy to know which method failed. @@ -339,6 +339,75 @@ func (s *MethodTestSuite) TestAuditLogs() { })) } +func (s *MethodTestSuite) TestConnectionLogs() { + createWorkspace := func(t *testing.T, db database.Store) database.WorkspaceTable { + u := dbgen.User(s.T(), db, database.User{}) + o := dbgen.Organization(s.T(), db, database.Organization{}) + tpl := dbgen.Template(s.T(), db, database.Template{ + OrganizationID: o.ID, + CreatedBy: u.ID, + }) + return dbgen.Workspace(s.T(), db, database.WorkspaceTable{ + ID: uuid.New(), + OwnerID: u.ID, + OrganizationID: o.ID, + AutomaticUpdates: database.AutomaticUpdatesNever, + TemplateID: tpl.ID, + }) + } + s.Run("UpsertConnectionLog", s.Subtest(func(db database.Store, check *expects) { + ws := createWorkspace(s.T(), db) + check.Args(database.UpsertConnectionLogParams{ + Ip: defaultIPAddress(), + Type: database.ConnectionTypeSsh, + WorkspaceID: ws.ID, + OrganizationID: ws.OrganizationID, + ConnectionStatus: database.ConnectionStatusConnected, + WorkspaceOwnerID: ws.OwnerID, + }).Asserts(rbac.ResourceConnectionLog, policy.ActionUpdate) + })) + s.Run("GetConnectionLogsOffset", s.Subtest(func(db database.Store, check *expects) { + ws := createWorkspace(s.T(), db) + _ = dbgen.ConnectionLog(s.T(), db, database.UpsertConnectionLogParams{ + Ip: defaultIPAddress(), + Type: database.ConnectionTypeSsh, + WorkspaceID: ws.ID, + OrganizationID: ws.OrganizationID, + WorkspaceOwnerID: ws.OwnerID, + }) + _ = dbgen.ConnectionLog(s.T(), db, database.UpsertConnectionLogParams{ + Ip: defaultIPAddress(), + Type: database.ConnectionTypeSsh, + WorkspaceID: ws.ID, + OrganizationID: ws.OrganizationID, + WorkspaceOwnerID: ws.OwnerID, + }) + check.Args(database.GetConnectionLogsOffsetParams{ + LimitOpt: 10, + }).Asserts(rbac.ResourceConnectionLog, policy.ActionRead).WithNotAuthorized("nil") + })) + s.Run("GetAuthorizedConnectionLogsOffset", s.Subtest(func(db database.Store, check *expects) { + ws := createWorkspace(s.T(), db) + _ = dbgen.ConnectionLog(s.T(), db, database.UpsertConnectionLogParams{ + Ip: defaultIPAddress(), + Type: database.ConnectionTypeSsh, + WorkspaceID: ws.ID, + OrganizationID: ws.OrganizationID, + WorkspaceOwnerID: ws.OwnerID, + }) + _ = dbgen.ConnectionLog(s.T(), db, database.UpsertConnectionLogParams{ + Ip: defaultIPAddress(), + Type: database.ConnectionTypeSsh, + WorkspaceID: ws.ID, + OrganizationID: ws.OrganizationID, + WorkspaceOwnerID: ws.OwnerID, + }) + check.Args(database.GetConnectionLogsOffsetParams{ + LimitOpt: 10, + }, emptyPreparedAuthorized{}).Asserts(rbac.ResourceConnectionLog, policy.ActionRead) + })) +} + func (s *MethodTestSuite) TestFile() { s.Run("GetFileByHashAndCreator", s.Subtest(func(db database.Store, check *expects) { f := dbgen.File(s.T(), db, database.File{}) @@ -3661,148 +3730,119 @@ func (s *MethodTestSuite) TestExtraMethods() { func (s *MethodTestSuite) TestTailnetFunctions() { s.Run("CleanTailnetCoordinators", s.Subtest(func(_ database.Store, check *expects) { check.Args(). - Asserts(rbac.ResourceTailnetCoordinator, policy.ActionDelete). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTailnetCoordinator, policy.ActionDelete) })) s.Run("CleanTailnetLostPeers", s.Subtest(func(_ database.Store, check *expects) { check.Args(). - Asserts(rbac.ResourceTailnetCoordinator, policy.ActionDelete). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTailnetCoordinator, policy.ActionDelete) })) s.Run("CleanTailnetTunnels", s.Subtest(func(_ database.Store, check *expects) { check.Args(). - Asserts(rbac.ResourceTailnetCoordinator, policy.ActionDelete). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTailnetCoordinator, policy.ActionDelete) })) s.Run("DeleteAllTailnetClientSubscriptions", s.Subtest(func(_ database.Store, check *expects) { check.Args(database.DeleteAllTailnetClientSubscriptionsParams{}). - Asserts(rbac.ResourceTailnetCoordinator, policy.ActionDelete). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTailnetCoordinator, policy.ActionDelete) })) s.Run("DeleteAllTailnetTunnels", s.Subtest(func(_ database.Store, check *expects) { check.Args(database.DeleteAllTailnetTunnelsParams{}). - Asserts(rbac.ResourceTailnetCoordinator, policy.ActionDelete). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTailnetCoordinator, policy.ActionDelete) })) s.Run("DeleteCoordinator", s.Subtest(func(_ database.Store, check *expects) { check.Args(uuid.New()). - Asserts(rbac.ResourceTailnetCoordinator, policy.ActionDelete). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTailnetCoordinator, policy.ActionDelete) })) s.Run("DeleteTailnetAgent", s.Subtest(func(_ database.Store, check *expects) { check.Args(database.DeleteTailnetAgentParams{}). - Asserts(rbac.ResourceTailnetCoordinator, policy.ActionUpdate).Errors(sql.ErrNoRows). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTailnetCoordinator, policy.ActionUpdate).Errors(sql.ErrNoRows) })) s.Run("DeleteTailnetClient", s.Subtest(func(_ database.Store, check *expects) { check.Args(database.DeleteTailnetClientParams{}). - Asserts(rbac.ResourceTailnetCoordinator, policy.ActionDelete).Errors(sql.ErrNoRows). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTailnetCoordinator, policy.ActionDelete).Errors(sql.ErrNoRows) })) s.Run("DeleteTailnetClientSubscription", s.Subtest(func(_ database.Store, check *expects) { check.Args(database.DeleteTailnetClientSubscriptionParams{}). - Asserts(rbac.ResourceTailnetCoordinator, policy.ActionDelete). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTailnetCoordinator, policy.ActionDelete) })) s.Run("DeleteTailnetPeer", s.Subtest(func(_ database.Store, check *expects) { check.Args(database.DeleteTailnetPeerParams{}). - Asserts(rbac.ResourceTailnetCoordinator, policy.ActionDelete). - ErrorsWithInMemDB(dbmem.ErrUnimplemented). - ErrorsWithPG(sql.ErrNoRows) + Asserts(rbac.ResourceTailnetCoordinator, policy.ActionDelete).Errors(sql.ErrNoRows) })) s.Run("DeleteTailnetTunnel", s.Subtest(func(_ database.Store, check *expects) { check.Args(database.DeleteTailnetTunnelParams{}). - Asserts(rbac.ResourceTailnetCoordinator, policy.ActionDelete). - ErrorsWithInMemDB(dbmem.ErrUnimplemented). - ErrorsWithPG(sql.ErrNoRows) + Asserts(rbac.ResourceTailnetCoordinator, policy.ActionDelete).Errors(sql.ErrNoRows) })) s.Run("GetAllTailnetAgents", s.Subtest(func(_ database.Store, check *expects) { check.Args(). - Asserts(rbac.ResourceTailnetCoordinator, policy.ActionRead). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTailnetCoordinator, policy.ActionRead) })) s.Run("GetTailnetAgents", s.Subtest(func(_ database.Store, check *expects) { check.Args(uuid.New()). - Asserts(rbac.ResourceTailnetCoordinator, policy.ActionRead). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTailnetCoordinator, policy.ActionRead) })) s.Run("GetTailnetClientsForAgent", s.Subtest(func(_ database.Store, check *expects) { check.Args(uuid.New()). - Asserts(rbac.ResourceTailnetCoordinator, policy.ActionRead). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTailnetCoordinator, policy.ActionRead) })) s.Run("GetTailnetPeers", s.Subtest(func(_ database.Store, check *expects) { check.Args(uuid.New()). - Asserts(rbac.ResourceTailnetCoordinator, policy.ActionRead). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTailnetCoordinator, policy.ActionRead) })) s.Run("GetTailnetTunnelPeerBindings", s.Subtest(func(_ database.Store, check *expects) { check.Args(uuid.New()). - Asserts(rbac.ResourceTailnetCoordinator, policy.ActionRead). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTailnetCoordinator, policy.ActionRead) })) s.Run("GetTailnetTunnelPeerIDs", s.Subtest(func(_ database.Store, check *expects) { check.Args(uuid.New()). - Asserts(rbac.ResourceTailnetCoordinator, policy.ActionRead). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTailnetCoordinator, policy.ActionRead) })) s.Run("GetAllTailnetCoordinators", s.Subtest(func(_ database.Store, check *expects) { check.Args(). - Asserts(rbac.ResourceTailnetCoordinator, policy.ActionRead). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTailnetCoordinator, policy.ActionRead) })) s.Run("GetAllTailnetPeers", s.Subtest(func(_ database.Store, check *expects) { check.Args(). - Asserts(rbac.ResourceTailnetCoordinator, policy.ActionRead). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTailnetCoordinator, policy.ActionRead) })) s.Run("GetAllTailnetTunnels", s.Subtest(func(_ database.Store, check *expects) { check.Args(). - Asserts(rbac.ResourceTailnetCoordinator, policy.ActionRead). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTailnetCoordinator, policy.ActionRead) })) s.Run("UpsertTailnetAgent", s.Subtest(func(db database.Store, check *expects) { dbtestutil.DisableForeignKeysAndTriggers(s.T(), db) check.Args(database.UpsertTailnetAgentParams{Node: json.RawMessage("{}")}). - Asserts(rbac.ResourceTailnetCoordinator, policy.ActionUpdate). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTailnetCoordinator, policy.ActionUpdate) })) s.Run("UpsertTailnetClient", s.Subtest(func(db database.Store, check *expects) { dbtestutil.DisableForeignKeysAndTriggers(s.T(), db) check.Args(database.UpsertTailnetClientParams{Node: json.RawMessage("{}")}). - Asserts(rbac.ResourceTailnetCoordinator, policy.ActionUpdate). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTailnetCoordinator, policy.ActionUpdate) })) s.Run("UpsertTailnetClientSubscription", s.Subtest(func(db database.Store, check *expects) { dbtestutil.DisableForeignKeysAndTriggers(s.T(), db) check.Args(database.UpsertTailnetClientSubscriptionParams{}). - Asserts(rbac.ResourceTailnetCoordinator, policy.ActionUpdate). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTailnetCoordinator, policy.ActionUpdate) })) s.Run("UpsertTailnetCoordinator", s.Subtest(func(_ database.Store, check *expects) { check.Args(uuid.New()). - Asserts(rbac.ResourceTailnetCoordinator, policy.ActionUpdate). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTailnetCoordinator, policy.ActionUpdate) })) s.Run("UpsertTailnetPeer", s.Subtest(func(db database.Store, check *expects) { dbtestutil.DisableForeignKeysAndTriggers(s.T(), db) check.Args(database.UpsertTailnetPeerParams{ Status: database.TailnetStatusOk, }). - Asserts(rbac.ResourceTailnetCoordinator, policy.ActionCreate). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTailnetCoordinator, policy.ActionCreate) })) s.Run("UpsertTailnetTunnel", s.Subtest(func(db database.Store, check *expects) { dbtestutil.DisableForeignKeysAndTriggers(s.T(), db) check.Args(database.UpsertTailnetTunnelParams{}). - Asserts(rbac.ResourceTailnetCoordinator, policy.ActionCreate). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTailnetCoordinator, policy.ActionCreate) })) s.Run("UpdateTailnetPeerStatusByCoordinator", s.Subtest(func(db database.Store, check *expects) { dbtestutil.DisableForeignKeysAndTriggers(s.T(), db) check.Args(database.UpdateTailnetPeerStatusByCoordinatorParams{Status: database.TailnetStatusOk}). - Asserts(rbac.ResourceTailnetCoordinator, policy.ActionUpdate). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTailnetCoordinator, policy.ActionUpdate) })) } @@ -4655,8 +4695,59 @@ func (s *MethodTestSuite) TestSystemFunctions() { VapidPrivateKey: "test", }).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate) })) - s.Run("GetProvisionerJobByIDForUpdate", s.Subtest(func(db database.Store, check *expects) { - check.Args(uuid.New()).Asserts(rbac.ResourceProvisionerJobs, policy.ActionRead).Errors(sql.ErrNoRows) + s.Run("Build/GetProvisionerJobByIDForUpdate", s.Subtest(func(db database.Store, check *expects) { + u := dbgen.User(s.T(), db, database.User{}) + o := dbgen.Organization(s.T(), db, database.Organization{}) + tpl := dbgen.Template(s.T(), db, database.Template{ + OrganizationID: o.ID, + CreatedBy: u.ID, + }) + w := dbgen.Workspace(s.T(), db, database.WorkspaceTable{ + OwnerID: u.ID, + OrganizationID: o.ID, + TemplateID: tpl.ID, + }) + j := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{ + Type: database.ProvisionerJobTypeWorkspaceBuild, + }) + tv := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ + TemplateID: uuid.NullUUID{UUID: tpl.ID, Valid: true}, + JobID: j.ID, + OrganizationID: o.ID, + CreatedBy: u.ID, + }) + _ = dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{ + JobID: j.ID, + WorkspaceID: w.ID, + TemplateVersionID: tv.ID, + }) + check.Args(j.ID).Asserts(w, policy.ActionRead).Returns(j) + })) + s.Run("TemplateVersion/GetProvisionerJobByIDForUpdate", s.Subtest(func(db database.Store, check *expects) { + dbtestutil.DisableForeignKeysAndTriggers(s.T(), db) + j := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{ + Type: database.ProvisionerJobTypeTemplateVersionImport, + }) + tpl := dbgen.Template(s.T(), db, database.Template{}) + v := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ + TemplateID: uuid.NullUUID{UUID: tpl.ID, Valid: true}, + JobID: j.ID, + }) + check.Args(j.ID).Asserts(v.RBACObject(tpl), policy.ActionRead).Returns(j) + })) + s.Run("TemplateVersionDryRun/GetProvisionerJobByIDForUpdate", s.Subtest(func(db database.Store, check *expects) { + dbtestutil.DisableForeignKeysAndTriggers(s.T(), db) + tpl := dbgen.Template(s.T(), db, database.Template{}) + v := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{ + TemplateID: uuid.NullUUID{UUID: tpl.ID, Valid: true}, + }) + j := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{ + Type: database.ProvisionerJobTypeTemplateVersionDryRun, + Input: must(json.Marshal(struct { + TemplateVersionID uuid.UUID `json:"template_version_id"` + }{TemplateVersionID: v.ID})), + }) + check.Args(j.ID).Asserts(v.RBACObject(tpl), policy.ActionRead).Returns(j) })) s.Run("HasTemplateVersionsWithAITask", s.Subtest(func(db database.Store, check *expects) { check.Args().Asserts() @@ -4736,21 +4827,18 @@ func (s *MethodTestSuite) TestNotifications() { dbtestutil.DisableForeignKeysAndTriggers(s.T(), db) user := dbgen.User(s.T(), db, database.User{}) check.Args(user.ID).Asserts(rbac.ResourceNotificationTemplate, policy.ActionRead). - ErrorsWithPG(sql.ErrNoRows). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + ErrorsWithPG(sql.ErrNoRows) })) s.Run("GetNotificationTemplatesByKind", s.Subtest(func(db database.Store, check *expects) { check.Args(database.NotificationTemplateKindSystem). - Asserts(). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts() // TODO(dannyk): add support for other database.NotificationTemplateKind types once implemented. })) s.Run("UpdateNotificationTemplateMethodByID", s.Subtest(func(db database.Store, check *expects) { check.Args(database.UpdateNotificationTemplateMethodByIDParams{ Method: database.NullNotificationMethod{NotificationMethod: database.NotificationMethodWebhook, Valid: true}, ID: notifications.TemplateWorkspaceDormant, - }).Asserts(rbac.ResourceNotificationTemplate, policy.ActionUpdate). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + }).Asserts(rbac.ResourceNotificationTemplate, policy.ActionUpdate) })) // Notification preferences @@ -5064,8 +5152,7 @@ func (s *MethodTestSuite) TestPrebuilds() { rbac.ResourceWorkspace.WithOwner(user.ID.String()).InOrg(org.ID), policy.ActionCreate, template, policy.ActionRead, template, policy.ActionUse, - ).ErrorsWithInMemDB(dbmem.ErrUnimplemented). - ErrorsWithPG(sql.ErrNoRows) + ).Errors(sql.ErrNoRows) })) s.Run("GetPrebuildMetrics", s.Subtest(func(_ database.Store, check *expects) { check.Args(). @@ -5079,29 +5166,24 @@ func (s *MethodTestSuite) TestPrebuilds() { })) s.Run("CountInProgressPrebuilds", s.Subtest(func(_ database.Store, check *expects) { check.Args(). - Asserts(rbac.ResourceWorkspace.All(), policy.ActionRead). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceWorkspace.All(), policy.ActionRead) })) s.Run("GetPresetsAtFailureLimit", s.Subtest(func(_ database.Store, check *expects) { check.Args(int64(0)). - Asserts(rbac.ResourceTemplate.All(), policy.ActionViewInsights). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTemplate.All(), policy.ActionViewInsights) })) s.Run("GetPresetsBackoff", s.Subtest(func(_ database.Store, check *expects) { check.Args(time.Time{}). - Asserts(rbac.ResourceTemplate.All(), policy.ActionViewInsights). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTemplate.All(), policy.ActionViewInsights) })) s.Run("GetRunningPrebuiltWorkspaces", s.Subtest(func(_ database.Store, check *expects) { check.Args(). - Asserts(rbac.ResourceWorkspace.All(), policy.ActionRead). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceWorkspace.All(), policy.ActionRead) })) s.Run("GetTemplatePresetsWithPrebuilds", s.Subtest(func(db database.Store, check *expects) { user := dbgen.User(s.T(), db, database.User{}) check.Args(uuid.NullUUID{UUID: user.ID, Valid: true}). - Asserts(rbac.ResourceTemplate.All(), policy.ActionRead). - ErrorsWithInMemDB(dbmem.ErrUnimplemented) + Asserts(rbac.ResourceTemplate.All(), policy.ActionRead) })) s.Run("GetPresetByID", s.Subtest(func(db database.Store, check *expects) { org := dbgen.Organization(s.T(), db, database.Organization{}) diff --git a/coderd/database/dbauthz/setup_test.go b/coderd/database/dbauthz/setup_test.go index 555a17fb2070f..23effafc632e0 100644 --- a/coderd/database/dbauthz/setup_test.go +++ b/coderd/database/dbauthz/setup_test.go @@ -41,6 +41,8 @@ var skipMethods = map[string]string{ "Wrappers": "Not relevant", "AcquireLock": "Not relevant", "TryAcquireLock": "Not relevant", + // This method will be removed once we know this works correctly. + "GetRunningPrebuiltWorkspacesOptimized": "Not relevant", } // TestMethodTestSuite runs MethodTestSuite. diff --git a/coderd/database/dbgen/dbgen.go b/coderd/database/dbgen/dbgen.go index 0bb7bde403297..9720050a43cb1 100644 --- a/coderd/database/dbgen/dbgen.go +++ b/coderd/database/dbgen/dbgen.go @@ -73,6 +73,53 @@ func AuditLog(t testing.TB, db database.Store, seed database.AuditLog) database. return log } +func ConnectionLog(t testing.TB, db database.Store, seed database.UpsertConnectionLogParams) database.ConnectionLog { + log, err := db.UpsertConnectionLog(genCtx, database.UpsertConnectionLogParams{ + ID: takeFirst(seed.ID, uuid.New()), + Time: takeFirst(seed.Time, dbtime.Now()), + OrganizationID: takeFirst(seed.OrganizationID, uuid.New()), + WorkspaceOwnerID: takeFirst(seed.WorkspaceOwnerID, uuid.New()), + WorkspaceID: takeFirst(seed.WorkspaceID, uuid.New()), + WorkspaceName: takeFirst(seed.WorkspaceName, testutil.GetRandomName(t)), + AgentName: takeFirst(seed.AgentName, testutil.GetRandomName(t)), + Type: takeFirst(seed.Type, database.ConnectionTypeSsh), + Code: sql.NullInt32{ + Int32: takeFirst(seed.Code.Int32, 0), + Valid: takeFirst(seed.Code.Valid, false), + }, + Ip: pqtype.Inet{ + IPNet: net.IPNet{ + IP: net.IPv4(127, 0, 0, 1), + Mask: net.IPv4Mask(255, 255, 255, 255), + }, + Valid: true, + }, + UserAgent: sql.NullString{ + String: takeFirst(seed.UserAgent.String, ""), + Valid: takeFirst(seed.UserAgent.Valid, false), + }, + UserID: uuid.NullUUID{ + UUID: takeFirst(seed.UserID.UUID, uuid.Nil), + Valid: takeFirst(seed.UserID.Valid, false), + }, + SlugOrPort: sql.NullString{ + String: takeFirst(seed.SlugOrPort.String, ""), + Valid: takeFirst(seed.SlugOrPort.Valid, false), + }, + ConnectionID: uuid.NullUUID{ + UUID: takeFirst(seed.ConnectionID.UUID, uuid.Nil), + Valid: takeFirst(seed.ConnectionID.Valid, false), + }, + DisconnectReason: sql.NullString{ + String: takeFirst(seed.DisconnectReason.String, ""), + Valid: takeFirst(seed.DisconnectReason.Valid, false), + }, + ConnectionStatus: takeFirst(seed.ConnectionStatus, database.ConnectionStatusConnected), + }) + require.NoError(t, err, "insert connection log") + return log +} + func Template(t testing.TB, db database.Store, seed database.Template) database.Template { id := takeFirst(seed.ID, uuid.New()) if seed.GroupACL == nil { @@ -204,6 +251,17 @@ func WorkspaceAgent(t testing.TB, db database.Store, orig database.WorkspaceAgen require.NoError(t, err, "update workspace agent first connected at") } + // If the lifecycle state is "ready", update the agent with the corresponding timestamps + if orig.LifecycleState == database.WorkspaceAgentLifecycleStateReady && orig.StartedAt.Valid && orig.ReadyAt.Valid { + err := db.UpdateWorkspaceAgentLifecycleStateByID(genCtx, database.UpdateWorkspaceAgentLifecycleStateByIDParams{ + ID: agt.ID, + LifecycleState: orig.LifecycleState, + StartedAt: orig.StartedAt, + ReadyAt: orig.ReadyAt, + }) + require.NoError(t, err, "update workspace agent lifecycle state") + } + if orig.ParentID.UUID == uuid.Nil { // Add a test antagonist. For every agent we add a deleted sub agent // to discover cases where deletion should be handled. @@ -348,6 +406,14 @@ func Workspace(t testing.TB, db database.Store, orig database.WorkspaceTable) da NextStartAt: orig.NextStartAt, }) require.NoError(t, err, "insert workspace") + if orig.Deleted { + err = db.UpdateWorkspaceDeletedByID(genCtx, database.UpdateWorkspaceDeletedByIDParams{ + ID: workspace.ID, + Deleted: true, + }) + require.NoError(t, err, "set workspace as deleted") + workspace.Deleted = true + } return workspace } @@ -1352,6 +1418,17 @@ func PresetParameter(t testing.TB, db database.Store, seed database.InsertPreset return parameters } +func ClaimPrebuild(t testing.TB, db database.Store, newUserID uuid.UUID, newName string, presetID uuid.UUID) database.ClaimPrebuiltWorkspaceRow { + claimedWorkspace, err := db.ClaimPrebuiltWorkspace(genCtx, database.ClaimPrebuiltWorkspaceParams{ + NewUserID: newUserID, + NewName: newName, + PresetID: presetID, + }) + require.NoError(t, err, "claim prebuilt workspace") + + return claimedWorkspace +} + func provisionerJobTiming(t testing.TB, db database.Store, seed database.ProvisionerJobTiming) database.ProvisionerJobTiming { timing, err := db.InsertProvisionerJobTimings(genCtx, database.InsertProvisionerJobTimingsParams{ JobID: takeFirst(seed.JobID, uuid.New()), diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go deleted file mode 100644 index d106e6a5858fb..0000000000000 --- a/coderd/database/dbmem/dbmem.go +++ /dev/null @@ -1,14242 +0,0 @@ -package dbmem - -import ( - "bytes" - "context" - "database/sql" - "encoding/json" - "errors" - "fmt" - "math" - insecurerand "math/rand" //#nosec // this is only used for shuffling an array to pick random jobs to reap - "reflect" - "regexp" - "slices" - "sort" - "strings" - "sync" - "time" - - "github.com/google/uuid" - "github.com/lib/pq" - "golang.org/x/exp/constraints" - "golang.org/x/exp/maps" - "golang.org/x/xerrors" - - "github.com/coder/coder/v2/coderd/database" - "github.com/coder/coder/v2/coderd/database/dbtime" - "github.com/coder/coder/v2/coderd/notifications/types" - "github.com/coder/coder/v2/coderd/rbac" - "github.com/coder/coder/v2/coderd/rbac/regosql" - "github.com/coder/coder/v2/coderd/util/slice" - "github.com/coder/coder/v2/coderd/workspaceapps/appurl" - "github.com/coder/coder/v2/codersdk" - "github.com/coder/coder/v2/provisionersdk" -) - -var validProxyByHostnameRegex = regexp.MustCompile(`^[a-zA-Z0-9._-]+$`) - -// A full mapping of error codes from pq v1.10.9 can be found here: -// https://github.com/lib/pq/blob/2a217b94f5ccd3de31aec4152a541b9ff64bed05/error.go#L75 -var ( - errForeignKeyConstraint = &pq.Error{ - Code: "23503", // "foreign_key_violation" - Message: "update or delete on table violates foreign key constraint", - } - errUniqueConstraint = &pq.Error{ - Code: "23505", // "unique_violation" - Message: "duplicate key value violates unique constraint", - } -) - -// New returns an in-memory fake of the database. -func New() database.Store { - q := &FakeQuerier{ - mutex: &sync.RWMutex{}, - data: &data{ - apiKeys: make([]database.APIKey, 0), - auditLogs: make([]database.AuditLog, 0), - customRoles: make([]database.CustomRole, 0), - dbcryptKeys: make([]database.DBCryptKey, 0), - externalAuthLinks: make([]database.ExternalAuthLink, 0), - files: make([]database.File, 0), - gitSSHKey: make([]database.GitSSHKey, 0), - groups: make([]database.Group, 0), - groupMembers: make([]database.GroupMemberTable, 0), - licenses: make([]database.License, 0), - locks: map[int64]struct{}{}, - notificationMessages: make([]database.NotificationMessage, 0), - notificationPreferences: make([]database.NotificationPreference, 0), - organizationMembers: make([]database.OrganizationMember, 0), - organizations: make([]database.Organization, 0), - inboxNotifications: make([]database.InboxNotification, 0), - parameterSchemas: make([]database.ParameterSchema, 0), - presets: make([]database.TemplateVersionPreset, 0), - presetParameters: make([]database.TemplateVersionPresetParameter, 0), - presetPrebuildSchedules: make([]database.TemplateVersionPresetPrebuildSchedule, 0), - provisionerDaemons: make([]database.ProvisionerDaemon, 0), - provisionerJobs: make([]database.ProvisionerJob, 0), - provisionerJobLogs: make([]database.ProvisionerJobLog, 0), - provisionerKeys: make([]database.ProvisionerKey, 0), - runtimeConfig: map[string]string{}, - telemetryItems: make([]database.TelemetryItem, 0), - templateVersions: make([]database.TemplateVersionTable, 0), - templateVersionTerraformValues: make([]database.TemplateVersionTerraformValue, 0), - templates: make([]database.TemplateTable, 0), - users: make([]database.User, 0), - userConfigs: make([]database.UserConfig, 0), - userStatusChanges: make([]database.UserStatusChange, 0), - workspaceAgents: make([]database.WorkspaceAgent, 0), - workspaceResources: make([]database.WorkspaceResource, 0), - workspaceModules: make([]database.WorkspaceModule, 0), - workspaceResourceMetadata: make([]database.WorkspaceResourceMetadatum, 0), - workspaceAgentStats: make([]database.WorkspaceAgentStat, 0), - workspaceAgentLogs: make([]database.WorkspaceAgentLog, 0), - workspaceBuilds: make([]database.WorkspaceBuild, 0), - workspaceApps: make([]database.WorkspaceApp, 0), - workspaceAppAuditSessions: make([]database.WorkspaceAppAuditSession, 0), - workspaces: make([]database.WorkspaceTable, 0), - workspaceProxies: make([]database.WorkspaceProxy, 0), - }, - } - // Always start with a default org. Matching migration 198. - defaultOrg, err := q.InsertOrganization(context.Background(), database.InsertOrganizationParams{ - ID: uuid.New(), - Name: "coder", - DisplayName: "Coder", - Description: "Builtin default organization.", - Icon: "", - CreatedAt: dbtime.Now(), - UpdatedAt: dbtime.Now(), - }) - if err != nil { - panic(xerrors.Errorf("failed to create default organization: %w", err)) - } - - _, err = q.InsertAllUsersGroup(context.Background(), defaultOrg.ID) - if err != nil { - panic(xerrors.Errorf("failed to create default group: %w", err)) - } - - q.defaultProxyDisplayName = "Default" - q.defaultProxyIconURL = "/emojis/1f3e1.png" - - _, err = q.InsertProvisionerKey(context.Background(), database.InsertProvisionerKeyParams{ - ID: codersdk.ProvisionerKeyUUIDBuiltIn, - OrganizationID: defaultOrg.ID, - CreatedAt: dbtime.Now(), - HashedSecret: []byte{}, - Name: codersdk.ProvisionerKeyNameBuiltIn, - Tags: map[string]string{}, - }) - if err != nil { - panic(xerrors.Errorf("failed to create built-in provisioner key: %w", err)) - } - _, err = q.InsertProvisionerKey(context.Background(), database.InsertProvisionerKeyParams{ - ID: codersdk.ProvisionerKeyUUIDUserAuth, - OrganizationID: defaultOrg.ID, - CreatedAt: dbtime.Now(), - HashedSecret: []byte{}, - Name: codersdk.ProvisionerKeyNameUserAuth, - Tags: map[string]string{}, - }) - if err != nil { - panic(xerrors.Errorf("failed to create user-auth provisioner key: %w", err)) - } - _, err = q.InsertProvisionerKey(context.Background(), database.InsertProvisionerKeyParams{ - ID: codersdk.ProvisionerKeyUUIDPSK, - OrganizationID: defaultOrg.ID, - CreatedAt: dbtime.Now(), - HashedSecret: []byte{}, - Name: codersdk.ProvisionerKeyNamePSK, - Tags: map[string]string{}, - }) - if err != nil { - panic(xerrors.Errorf("failed to create psk provisioner key: %w", err)) - } - - q.mutex.Lock() - // We can't insert this user using the interface, because it's a system user. - q.data.users = append(q.data.users, database.User{ - ID: database.PrebuildsSystemUserID, - Email: "prebuilds@coder.com", - Username: "prebuilds", - CreatedAt: dbtime.Now(), - UpdatedAt: dbtime.Now(), - Status: "active", - LoginType: "none", - HashedPassword: []byte{}, - IsSystem: true, - Deleted: false, - }) - q.mutex.Unlock() - - return q -} - -type rwMutex interface { - Lock() - RLock() - Unlock() - RUnlock() -} - -// inTxMutex is a no op, since inside a transaction we are already locked. -type inTxMutex struct{} - -func (inTxMutex) Lock() {} -func (inTxMutex) RLock() {} -func (inTxMutex) Unlock() {} -func (inTxMutex) RUnlock() {} - -// FakeQuerier replicates database functionality to enable quick testing. It's an exported type so that our test code -// can do type checks. -type FakeQuerier struct { - mutex rwMutex - *data -} - -func (*FakeQuerier) Wrappers() []string { - return []string{} -} - -type fakeTx struct { - *FakeQuerier - locks map[int64]struct{} -} - -type data struct { - // Legacy tables - apiKeys []database.APIKey - organizations []database.Organization - organizationMembers []database.OrganizationMember - users []database.User - userLinks []database.UserLink - - // New tables - auditLogs []database.AuditLog - cryptoKeys []database.CryptoKey - dbcryptKeys []database.DBCryptKey - files []database.File - externalAuthLinks []database.ExternalAuthLink - gitSSHKey []database.GitSSHKey - groupMembers []database.GroupMemberTable - groups []database.Group - licenses []database.License - notificationMessages []database.NotificationMessage - notificationPreferences []database.NotificationPreference - notificationReportGeneratorLogs []database.NotificationReportGeneratorLog - inboxNotifications []database.InboxNotification - oauth2ProviderApps []database.OAuth2ProviderApp - oauth2ProviderAppSecrets []database.OAuth2ProviderAppSecret - oauth2ProviderAppCodes []database.OAuth2ProviderAppCode - oauth2ProviderAppTokens []database.OAuth2ProviderAppToken - parameterSchemas []database.ParameterSchema - provisionerDaemons []database.ProvisionerDaemon - provisionerJobLogs []database.ProvisionerJobLog - provisionerJobs []database.ProvisionerJob - provisionerKeys []database.ProvisionerKey - replicas []database.Replica - templateVersions []database.TemplateVersionTable - templateVersionParameters []database.TemplateVersionParameter - templateVersionTerraformValues []database.TemplateVersionTerraformValue - templateVersionVariables []database.TemplateVersionVariable - templateVersionWorkspaceTags []database.TemplateVersionWorkspaceTag - templates []database.TemplateTable - templateUsageStats []database.TemplateUsageStat - userConfigs []database.UserConfig - webpushSubscriptions []database.WebpushSubscription - workspaceAgents []database.WorkspaceAgent - workspaceAgentMetadata []database.WorkspaceAgentMetadatum - workspaceAgentLogs []database.WorkspaceAgentLog - workspaceAgentLogSources []database.WorkspaceAgentLogSource - workspaceAgentPortShares []database.WorkspaceAgentPortShare - workspaceAgentScriptTimings []database.WorkspaceAgentScriptTiming - workspaceAgentScripts []database.WorkspaceAgentScript - workspaceAgentStats []database.WorkspaceAgentStat - workspaceAgentMemoryResourceMonitors []database.WorkspaceAgentMemoryResourceMonitor - workspaceAgentVolumeResourceMonitors []database.WorkspaceAgentVolumeResourceMonitor - workspaceAgentDevcontainers []database.WorkspaceAgentDevcontainer - workspaceApps []database.WorkspaceApp - workspaceAppStatuses []database.WorkspaceAppStatus - workspaceAppAuditSessions []database.WorkspaceAppAuditSession - workspaceAppStatsLastInsertID int64 - workspaceAppStats []database.WorkspaceAppStat - workspaceBuilds []database.WorkspaceBuild - workspaceBuildParameters []database.WorkspaceBuildParameter - workspaceResourceMetadata []database.WorkspaceResourceMetadatum - workspaceResources []database.WorkspaceResource - workspaceModules []database.WorkspaceModule - workspaces []database.WorkspaceTable - workspaceProxies []database.WorkspaceProxy - customRoles []database.CustomRole - provisionerJobTimings []database.ProvisionerJobTiming - runtimeConfig map[string]string - // Locks is a map of lock names. Any keys within the map are currently - // locked. - locks map[int64]struct{} - deploymentID string - derpMeshKey string - lastUpdateCheck []byte - announcementBanners []byte - healthSettings []byte - notificationsSettings []byte - oauth2GithubDefaultEligible *bool - applicationName string - logoURL string - appSecurityKey string - oauthSigningKey string - coordinatorResumeTokenSigningKey string - lastLicenseID int32 - defaultProxyDisplayName string - defaultProxyIconURL string - webpushVAPIDPublicKey string - webpushVAPIDPrivateKey string - userStatusChanges []database.UserStatusChange - telemetryItems []database.TelemetryItem - presets []database.TemplateVersionPreset - presetParameters []database.TemplateVersionPresetParameter - presetPrebuildSchedules []database.TemplateVersionPresetPrebuildSchedule - prebuildsSettings []byte -} - -func tryPercentileCont(fs []float64, p float64) float64 { - if len(fs) == 0 { - return -1 - } - sort.Float64s(fs) - pos := p * (float64(len(fs)) - 1) / 100 - lower, upper := int(pos), int(math.Ceil(pos)) - if lower == upper { - return fs[lower] - } - return fs[lower] + (fs[upper]-fs[lower])*(pos-float64(lower)) -} - -func tryPercentileDisc(fs []float64, p float64) float64 { - if len(fs) == 0 { - return -1 - } - sort.Float64s(fs) - return fs[max(int(math.Ceil(float64(len(fs))*p/100-1)), 0)] -} - -func validateDatabaseTypeWithValid(v reflect.Value) (handled bool, err error) { - if v.Kind() == reflect.Struct { - return false, nil - } - - if v.CanInterface() { - if !strings.Contains(v.Type().PkgPath(), "coderd/database") { - return true, nil - } - if valid, ok := v.Interface().(interface{ Valid() bool }); ok { - if !valid.Valid() { - return true, xerrors.Errorf("invalid %s: %q", v.Type().Name(), v.Interface()) - } - } - return true, nil - } - return false, nil -} - -// validateDatabaseType uses reflect to check if struct properties are types -// with a Valid() bool function set. If so, call it and return an error -// if false. -// -// Note that we only check immediate values and struct fields. We do not -// recurse into nested structs. -func validateDatabaseType(args interface{}) error { - v := reflect.ValueOf(args) - - // Note: database.Null* types don't have a Valid method, we skip them here - // because their embedded types may have a Valid method and we don't want - // to bother with checking both that the Valid field is true and that the - // type it embeds validates to true. We would need to check: - // - // dbNullEnum.Valid && dbNullEnum.Enum.Valid() - if strings.HasPrefix(v.Type().Name(), "Null") { - return nil - } - - if ok, err := validateDatabaseTypeWithValid(v); ok { - return err - } - switch v.Kind() { - case reflect.Struct: - var errs []string - for i := 0; i < v.NumField(); i++ { - field := v.Field(i) - if ok, err := validateDatabaseTypeWithValid(field); ok && err != nil { - errs = append(errs, fmt.Sprintf("%s.%s: %s", v.Type().Name(), v.Type().Field(i).Name, err.Error())) - } - } - if len(errs) > 0 { - return xerrors.Errorf("invalid database type fields:\n\t%s", strings.Join(errs, "\n\t")) - } - default: - panic(fmt.Sprintf("unhandled type: %s", v.Type().Name())) - } - return nil -} - -func newUniqueConstraintError(uc database.UniqueConstraint) *pq.Error { - newErr := *errUniqueConstraint - newErr.Constraint = string(uc) - - return &newErr -} - -func (*FakeQuerier) Ping(_ context.Context) (time.Duration, error) { - return 0, nil -} - -func (*FakeQuerier) PGLocks(_ context.Context) (database.PGLocks, error) { - return []database.PGLock{}, nil -} - -func (tx *fakeTx) AcquireLock(_ context.Context, id int64) error { - if _, ok := tx.FakeQuerier.locks[id]; ok { - return xerrors.Errorf("cannot acquire lock %d: already held", id) - } - tx.FakeQuerier.locks[id] = struct{}{} - tx.locks[id] = struct{}{} - return nil -} - -func (tx *fakeTx) TryAcquireLock(_ context.Context, id int64) (bool, error) { - if _, ok := tx.FakeQuerier.locks[id]; ok { - return false, nil - } - tx.FakeQuerier.locks[id] = struct{}{} - tx.locks[id] = struct{}{} - return true, nil -} - -func (tx *fakeTx) releaseLocks() { - for id := range tx.locks { - delete(tx.FakeQuerier.locks, id) - } - tx.locks = map[int64]struct{}{} -} - -// InTx doesn't rollback data properly for in-memory yet. -func (q *FakeQuerier) InTx(fn func(database.Store) error, opts *database.TxOptions) error { - q.mutex.Lock() - defer q.mutex.Unlock() - tx := &fakeTx{ - FakeQuerier: &FakeQuerier{mutex: inTxMutex{}, data: q.data}, - locks: map[int64]struct{}{}, - } - defer tx.releaseLocks() - - if opts != nil { - database.IncrementExecutionCount(opts) - } - return fn(tx) -} - -// getUserByIDNoLock is used by other functions in the database fake. -func (q *FakeQuerier) getUserByIDNoLock(id uuid.UUID) (database.User, error) { - for _, user := range q.users { - if user.ID == id { - return user, nil - } - } - return database.User{}, sql.ErrNoRows -} - -func convertUsers(users []database.User, count int64) []database.GetUsersRow { - rows := make([]database.GetUsersRow, len(users)) - for i, u := range users { - rows[i] = database.GetUsersRow{ - ID: u.ID, - Email: u.Email, - Username: u.Username, - Name: u.Name, - HashedPassword: u.HashedPassword, - CreatedAt: u.CreatedAt, - UpdatedAt: u.UpdatedAt, - Status: u.Status, - RBACRoles: u.RBACRoles, - LoginType: u.LoginType, - AvatarURL: u.AvatarURL, - Deleted: u.Deleted, - LastSeenAt: u.LastSeenAt, - Count: count, - IsSystem: u.IsSystem, - } - } - - return rows -} - -// mapAgentStatus determines the agent status based on different timestamps like created_at, last_connected_at, disconnected_at, etc. -// The function must be in sync with: coderd/workspaceagents.go:convertWorkspaceAgent. -func mapAgentStatus(dbAgent database.WorkspaceAgent, agentInactiveDisconnectTimeoutSeconds int64) string { - var status string - connectionTimeout := time.Duration(dbAgent.ConnectionTimeoutSeconds) * time.Second - switch { - case !dbAgent.FirstConnectedAt.Valid: - switch { - case connectionTimeout > 0 && dbtime.Now().Sub(dbAgent.CreatedAt) > connectionTimeout: - // If the agent took too long to connect the first time, - // mark it as timed out. - status = "timeout" - default: - // If the agent never connected, it's waiting for the compute - // to start up. - status = "connecting" - } - case dbAgent.DisconnectedAt.Time.After(dbAgent.LastConnectedAt.Time): - // If we've disconnected after our last connection, we know the - // agent is no longer connected. - status = "disconnected" - case dbtime.Now().Sub(dbAgent.LastConnectedAt.Time) > time.Duration(agentInactiveDisconnectTimeoutSeconds)*time.Second: - // The connection died without updating the last connected. - status = "disconnected" - case dbAgent.LastConnectedAt.Valid: - // The agent should be assumed connected if it's under inactivity timeouts - // and last connected at has been properly set. - status = "connected" - default: - panic("unknown agent status: " + status) - } - return status -} - -func (q *FakeQuerier) convertToWorkspaceRowsNoLock(ctx context.Context, workspaces []database.WorkspaceTable, count int64, withSummary bool) []database.GetWorkspacesRow { //nolint:revive // withSummary flag ensures the extra technical row - rows := make([]database.GetWorkspacesRow, 0, len(workspaces)) - for _, w := range workspaces { - extended := q.extendWorkspace(w) - - wr := database.GetWorkspacesRow{ - ID: w.ID, - CreatedAt: w.CreatedAt, - UpdatedAt: w.UpdatedAt, - OwnerID: w.OwnerID, - OrganizationID: w.OrganizationID, - TemplateID: w.TemplateID, - Deleted: w.Deleted, - Name: w.Name, - AutostartSchedule: w.AutostartSchedule, - Ttl: w.Ttl, - LastUsedAt: w.LastUsedAt, - DormantAt: w.DormantAt, - DeletingAt: w.DeletingAt, - AutomaticUpdates: w.AutomaticUpdates, - Favorite: w.Favorite, - NextStartAt: w.NextStartAt, - - OwnerAvatarUrl: extended.OwnerAvatarUrl, - OwnerUsername: extended.OwnerUsername, - OwnerName: extended.OwnerName, - - OrganizationName: extended.OrganizationName, - OrganizationDisplayName: extended.OrganizationDisplayName, - OrganizationIcon: extended.OrganizationIcon, - OrganizationDescription: extended.OrganizationDescription, - - TemplateName: extended.TemplateName, - TemplateDisplayName: extended.TemplateDisplayName, - TemplateIcon: extended.TemplateIcon, - TemplateDescription: extended.TemplateDescription, - - Count: count, - - // These fields are missing! - // Try to resolve them below - TemplateVersionID: uuid.UUID{}, - TemplateVersionName: sql.NullString{}, - LatestBuildCompletedAt: sql.NullTime{}, - LatestBuildCanceledAt: sql.NullTime{}, - LatestBuildError: sql.NullString{}, - LatestBuildTransition: "", - LatestBuildStatus: "", - } - - if build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, w.ID); err == nil { - for _, tv := range q.templateVersions { - if tv.ID == build.TemplateVersionID { - wr.TemplateVersionID = tv.ID - wr.TemplateVersionName = sql.NullString{ - Valid: true, - String: tv.Name, - } - break - } - } - - if pj, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID); err == nil { - wr.LatestBuildStatus = pj.JobStatus - wr.LatestBuildCanceledAt = pj.CanceledAt - wr.LatestBuildCompletedAt = pj.CompletedAt - wr.LatestBuildError = pj.Error - } - - wr.LatestBuildTransition = build.Transition - } - - rows = append(rows, wr) - } - if withSummary { - rows = append(rows, database.GetWorkspacesRow{ - Name: "**TECHNICAL_ROW**", - Count: count, - }) - } - return rows -} - -func (q *FakeQuerier) getWorkspaceByIDNoLock(_ context.Context, id uuid.UUID) (database.Workspace, error) { - return q.getWorkspaceNoLock(func(w database.WorkspaceTable) bool { - return w.ID == id - }) -} - -func (q *FakeQuerier) getWorkspaceNoLock(find func(w database.WorkspaceTable) bool) (database.Workspace, error) { - w, found := slice.Find(q.workspaces, find) - if found { - return q.extendWorkspace(w), nil - } - return database.Workspace{}, sql.ErrNoRows -} - -func (q *FakeQuerier) extendWorkspace(w database.WorkspaceTable) database.Workspace { - var extended database.Workspace - // This is a cheeky way to copy the fields over without explicitly listing them all. - d, _ := json.Marshal(w) - _ = json.Unmarshal(d, &extended) - - org, _ := slice.Find(q.organizations, func(o database.Organization) bool { - return o.ID == w.OrganizationID - }) - extended.OrganizationName = org.Name - extended.OrganizationDescription = org.Description - extended.OrganizationDisplayName = org.DisplayName - extended.OrganizationIcon = org.Icon - - tpl, _ := slice.Find(q.templates, func(t database.TemplateTable) bool { - return t.ID == w.TemplateID - }) - extended.TemplateName = tpl.Name - extended.TemplateDisplayName = tpl.DisplayName - extended.TemplateDescription = tpl.Description - extended.TemplateIcon = tpl.Icon - - owner, _ := slice.Find(q.users, func(u database.User) bool { - return u.ID == w.OwnerID - }) - extended.OwnerUsername = owner.Username - extended.OwnerName = owner.Name - extended.OwnerAvatarUrl = owner.AvatarURL - - return extended -} - -func (q *FakeQuerier) getWorkspaceByAgentIDNoLock(_ context.Context, agentID uuid.UUID) (database.Workspace, error) { - var agent database.WorkspaceAgent - for _, _agent := range q.workspaceAgents { - if _agent.ID == agentID { - agent = _agent - break - } - } - if agent.ID == uuid.Nil { - return database.Workspace{}, sql.ErrNoRows - } - - var resource database.WorkspaceResource - for _, _resource := range q.workspaceResources { - if _resource.ID == agent.ResourceID { - resource = _resource - break - } - } - if resource.ID == uuid.Nil { - return database.Workspace{}, sql.ErrNoRows - } - - var build database.WorkspaceBuild - for _, _build := range q.workspaceBuilds { - if _build.JobID == resource.JobID { - build = q.workspaceBuildWithUserNoLock(_build) - break - } - } - if build.ID == uuid.Nil { - return database.Workspace{}, sql.ErrNoRows - } - - return q.getWorkspaceNoLock(func(w database.WorkspaceTable) bool { - return w.ID == build.WorkspaceID - }) -} - -func (q *FakeQuerier) getWorkspaceBuildByIDNoLock(_ context.Context, id uuid.UUID) (database.WorkspaceBuild, error) { - for _, build := range q.workspaceBuilds { - if build.ID == id { - return q.workspaceBuildWithUserNoLock(build), nil - } - } - return database.WorkspaceBuild{}, sql.ErrNoRows -} - -func (q *FakeQuerier) getLatestWorkspaceBuildByWorkspaceIDNoLock(_ context.Context, workspaceID uuid.UUID) (database.WorkspaceBuild, error) { - var row database.WorkspaceBuild - var buildNum int32 = -1 - for _, workspaceBuild := range q.workspaceBuilds { - if workspaceBuild.WorkspaceID == workspaceID && workspaceBuild.BuildNumber > buildNum { - row = q.workspaceBuildWithUserNoLock(workspaceBuild) - buildNum = workspaceBuild.BuildNumber - } - } - if buildNum == -1 { - return database.WorkspaceBuild{}, sql.ErrNoRows - } - return row, nil -} - -func (q *FakeQuerier) getTemplateByIDNoLock(_ context.Context, id uuid.UUID) (database.Template, error) { - for _, template := range q.templates { - if template.ID == id { - return q.templateWithNameNoLock(template), nil - } - } - return database.Template{}, sql.ErrNoRows -} - -func (q *FakeQuerier) templatesWithUserNoLock(tpl []database.TemplateTable) []database.Template { - cpy := make([]database.Template, 0, len(tpl)) - for _, t := range tpl { - cpy = append(cpy, q.templateWithNameNoLock(t)) - } - return cpy -} - -func (q *FakeQuerier) templateWithNameNoLock(tpl database.TemplateTable) database.Template { - var user database.User - for _, _user := range q.users { - if _user.ID == tpl.CreatedBy { - user = _user - break - } - } - - var org database.Organization - for _, _org := range q.organizations { - if _org.ID == tpl.OrganizationID { - org = _org - break - } - } - - var withNames database.Template - // This is a cheeky way to copy the fields over without explicitly listing them all. - d, _ := json.Marshal(tpl) - _ = json.Unmarshal(d, &withNames) - withNames.CreatedByUsername = user.Username - withNames.CreatedByAvatarURL = user.AvatarURL - withNames.OrganizationName = org.Name - withNames.OrganizationDisplayName = org.DisplayName - withNames.OrganizationIcon = org.Icon - return withNames -} - -func (q *FakeQuerier) templateVersionWithUserNoLock(tpl database.TemplateVersionTable) database.TemplateVersion { - var user database.User - for _, _user := range q.users { - if _user.ID == tpl.CreatedBy { - user = _user - break - } - } - var withUser database.TemplateVersion - // This is a cheeky way to copy the fields over without explicitly listing them all. - d, _ := json.Marshal(tpl) - _ = json.Unmarshal(d, &withUser) - withUser.CreatedByUsername = user.Username - withUser.CreatedByAvatarURL = user.AvatarURL - return withUser -} - -func (q *FakeQuerier) workspaceBuildWithUserNoLock(tpl database.WorkspaceBuild) database.WorkspaceBuild { - var user database.User - for _, _user := range q.users { - if _user.ID == tpl.InitiatorID { - user = _user - break - } - } - var withUser database.WorkspaceBuild - // This is a cheeky way to copy the fields over without explicitly listing them all. - d, _ := json.Marshal(tpl) - _ = json.Unmarshal(d, &withUser) - withUser.InitiatorByUsername = user.Username - withUser.InitiatorByAvatarUrl = user.AvatarURL - return withUser -} - -func (q *FakeQuerier) getTemplateVersionByIDNoLock(_ context.Context, templateVersionID uuid.UUID) (database.TemplateVersion, error) { - for _, templateVersion := range q.templateVersions { - if templateVersion.ID != templateVersionID { - continue - } - return q.templateVersionWithUserNoLock(templateVersion), nil - } - return database.TemplateVersion{}, sql.ErrNoRows -} - -func (q *FakeQuerier) getWorkspaceAgentByIDNoLock(_ context.Context, id uuid.UUID) (database.WorkspaceAgent, error) { - // The schema sorts this by created at, so we iterate the array backwards. - for i := len(q.workspaceAgents) - 1; i >= 0; i-- { - agent := q.workspaceAgents[i] - if !agent.Deleted && agent.ID == id { - return agent, nil - } - } - return database.WorkspaceAgent{}, sql.ErrNoRows -} - -func (q *FakeQuerier) getWorkspaceAgentsByResourceIDsNoLock(_ context.Context, resourceIDs []uuid.UUID) ([]database.WorkspaceAgent, error) { - workspaceAgents := make([]database.WorkspaceAgent, 0) - for _, agent := range q.workspaceAgents { - if agent.Deleted { - continue - } - for _, resourceID := range resourceIDs { - if agent.ResourceID != resourceID { - continue - } - workspaceAgents = append(workspaceAgents, agent) - } - } - return workspaceAgents, nil -} - -func (q *FakeQuerier) getWorkspaceAppByAgentIDAndSlugNoLock(_ context.Context, arg database.GetWorkspaceAppByAgentIDAndSlugParams) (database.WorkspaceApp, error) { - for _, app := range q.workspaceApps { - if app.AgentID != arg.AgentID { - continue - } - if app.Slug != arg.Slug { - continue - } - return app, nil - } - return database.WorkspaceApp{}, sql.ErrNoRows -} - -func (q *FakeQuerier) getProvisionerJobByIDNoLock(_ context.Context, id uuid.UUID) (database.ProvisionerJob, error) { - for _, provisionerJob := range q.provisionerJobs { - if provisionerJob.ID != id { - continue - } - // clone the Tags before returning, since maps are reference types and - // we don't want the caller to be able to mutate the map we have inside - // dbmem! - provisionerJob.Tags = maps.Clone(provisionerJob.Tags) - return provisionerJob, nil - } - return database.ProvisionerJob{}, sql.ErrNoRows -} - -func (q *FakeQuerier) getWorkspaceResourcesByJobIDNoLock(_ context.Context, jobID uuid.UUID) ([]database.WorkspaceResource, error) { - resources := make([]database.WorkspaceResource, 0) - for _, resource := range q.workspaceResources { - if resource.JobID != jobID { - continue - } - resources = append(resources, resource) - } - return resources, nil -} - -func (q *FakeQuerier) getGroupByIDNoLock(_ context.Context, id uuid.UUID) (database.Group, error) { - for _, group := range q.groups { - if group.ID == id { - return group, nil - } - } - - return database.Group{}, sql.ErrNoRows -} - -// ErrUnimplemented is returned by methods only used by the enterprise/tailnet.pgCoord. This coordinator explicitly -// depends on postgres triggers that announce changes on the pubsub. Implementing support for this in the fake -// database would strongly couple the FakeQuerier to the pubsub, which is undesirable. Furthermore, it makes little -// sense to directly test the pgCoord against anything other than postgres. The FakeQuerier is designed to allow us to -// test the Coderd API, and for that kind of test, the in-memory, AGPL tailnet coordinator is sufficient. Therefore, -// these methods remain unimplemented in the FakeQuerier. -var ErrUnimplemented = xerrors.New("unimplemented") - -func uniqueSortedUUIDs(uuids []uuid.UUID) []uuid.UUID { - set := make(map[uuid.UUID]struct{}) - for _, id := range uuids { - set[id] = struct{}{} - } - unique := make([]uuid.UUID, 0, len(set)) - for id := range set { - unique = append(unique, id) - } - slices.SortFunc(unique, func(a, b uuid.UUID) int { - return slice.Ascending(a.String(), b.String()) - }) - return unique -} - -func (q *FakeQuerier) getOrganizationMemberNoLock(orgID uuid.UUID) []database.OrganizationMember { - var members []database.OrganizationMember - for _, member := range q.organizationMembers { - if member.OrganizationID == orgID { - members = append(members, member) - } - } - - return members -} - -var errUserDeleted = xerrors.New("user deleted") - -// getGroupMemberNoLock fetches a group member by user ID and group ID. -func (q *FakeQuerier) getGroupMemberNoLock(ctx context.Context, userID, groupID uuid.UUID) (database.GroupMember, error) { - groupName := "Everyone" - orgID := groupID - groupIsEveryone := q.isEveryoneGroup(groupID) - if !groupIsEveryone { - group, err := q.getGroupByIDNoLock(ctx, groupID) - if err != nil { - return database.GroupMember{}, err - } - groupName = group.Name - orgID = group.OrganizationID - } - - user, err := q.getUserByIDNoLock(userID) - if err != nil { - return database.GroupMember{}, err - } - if user.Deleted { - return database.GroupMember{}, errUserDeleted - } - - return database.GroupMember{ - UserID: user.ID, - UserEmail: user.Email, - UserUsername: user.Username, - UserHashedPassword: user.HashedPassword, - UserCreatedAt: user.CreatedAt, - UserUpdatedAt: user.UpdatedAt, - UserStatus: user.Status, - UserRbacRoles: user.RBACRoles, - UserLoginType: user.LoginType, - UserAvatarUrl: user.AvatarURL, - UserDeleted: user.Deleted, - UserLastSeenAt: user.LastSeenAt, - UserQuietHoursSchedule: user.QuietHoursSchedule, - UserName: user.Name, - UserGithubComUserID: user.GithubComUserID, - OrganizationID: orgID, - GroupName: groupName, - GroupID: groupID, - }, nil -} - -// getEveryoneGroupMembersNoLock fetches all the users in an organization. -func (q *FakeQuerier) getEveryoneGroupMembersNoLock(ctx context.Context, orgID uuid.UUID) []database.GroupMember { - var ( - everyone []database.GroupMember - orgMembers = q.getOrganizationMemberNoLock(orgID) - ) - for _, member := range orgMembers { - groupMember, err := q.getGroupMemberNoLock(ctx, member.UserID, orgID) - if errors.Is(err, errUserDeleted) { - continue - } - if err != nil { - return nil - } - everyone = append(everyone, groupMember) - } - return everyone -} - -// isEveryoneGroup returns true if the provided ID matches -// an organization ID. -func (q *FakeQuerier) isEveryoneGroup(id uuid.UUID) bool { - for _, org := range q.organizations { - if org.ID == id { - return true - } - } - return false -} - -func (q *FakeQuerier) GetActiveDBCryptKeys(_ context.Context) ([]database.DBCryptKey, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - ks := make([]database.DBCryptKey, 0, len(q.dbcryptKeys)) - for _, k := range q.dbcryptKeys { - if !k.ActiveKeyDigest.Valid { - continue - } - ks = append([]database.DBCryptKey{}, k) - } - return ks, nil -} - -func maxTime(t, u time.Time) time.Time { - if t.After(u) { - return t - } - return u -} - -func minTime(t, u time.Time) time.Time { - if t.Before(u) { - return t - } - return u -} - -func provisionerJobStatus(j database.ProvisionerJob) database.ProvisionerJobStatus { - if isNotNull(j.CompletedAt) { - if j.Error.String != "" { - return database.ProvisionerJobStatusFailed - } - if isNotNull(j.CanceledAt) { - return database.ProvisionerJobStatusCanceled - } - return database.ProvisionerJobStatusSucceeded - } - - if isNotNull(j.CanceledAt) { - return database.ProvisionerJobStatusCanceling - } - if isNull(j.StartedAt) { - return database.ProvisionerJobStatusPending - } - return database.ProvisionerJobStatusRunning -} - -// isNull is only used in dbmem, so reflect is ok. Use this to make the logic -// look more similar to the postgres. -func isNull(v interface{}) bool { - return !isNotNull(v) -} - -func isNotNull(v interface{}) bool { - return reflect.ValueOf(v).FieldByName("Valid").Bool() -} - -// Took the error from the real database. -var deletedUserLinkError = &pq.Error{ - Severity: "ERROR", - // "raise_exception" error - Code: "P0001", - Message: "Cannot create user_link for deleted user", - Where: "PL/pgSQL function insert_user_links_fail_if_user_deleted() line 7 at RAISE", - File: "pl_exec.c", - Line: "3864", - Routine: "exec_stmt_raise", -} - -// m1 and m2 are equal iff |m1| = |m2| ^ m2 ⊆ m1 -func tagsEqual(m1, m2 map[string]string) bool { - return len(m1) == len(m2) && tagsSubset(m1, m2) -} - -// m2 is a subset of m1 if each key in m1 exists in m2 -// with the same value -func tagsSubset(m1, m2 map[string]string) bool { - for k, v1 := range m1 { - if v2, found := m2[k]; !found || v1 != v2 { - return false - } - } - return true -} - -// default tags when no tag is specified for a provisioner or job -var tagsUntagged = provisionersdk.MutateTags(uuid.Nil, nil) - -func least[T constraints.Ordered](a, b T) T { - if a < b { - return a - } - return b -} - -func (q *FakeQuerier) getLatestWorkspaceAppByTemplateIDUserIDSlugNoLock(ctx context.Context, templateID, userID uuid.UUID, slug string) (database.WorkspaceApp, error) { - /* - SELECT - app.display_name, - app.icon, - app.slug - FROM - workspace_apps AS app - JOIN - workspace_agents AS agent - ON - agent.id = app.agent_id - JOIN - workspace_resources AS resource - ON - resource.id = agent.resource_id - JOIN - workspace_builds AS build - ON - build.job_id = resource.job_id - JOIN - workspaces AS workspace - ON - workspace.id = build.workspace_id - WHERE - -- Requires lateral join. - app.slug = app_usage.key - AND workspace.owner_id = tus.user_id - AND workspace.template_id = tus.template_id - ORDER BY - app.created_at DESC - LIMIT 1 - */ - - var workspaces []database.WorkspaceTable - for _, w := range q.workspaces { - if w.TemplateID != templateID || w.OwnerID != userID { - continue - } - workspaces = append(workspaces, w) - } - slices.SortFunc(workspaces, func(a, b database.WorkspaceTable) int { - if a.CreatedAt.Before(b.CreatedAt) { - return 1 - } else if a.CreatedAt.Equal(b.CreatedAt) { - return 0 - } - return -1 - }) - - for _, workspace := range workspaces { - build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, workspace.ID) - if err != nil { - continue - } - - resources, err := q.getWorkspaceResourcesByJobIDNoLock(ctx, build.JobID) - if err != nil { - continue - } - var resourceIDs []uuid.UUID - for _, resource := range resources { - resourceIDs = append(resourceIDs, resource.ID) - } - - agents, err := q.getWorkspaceAgentsByResourceIDsNoLock(ctx, resourceIDs) - if err != nil { - continue - } - - for _, agent := range agents { - app, err := q.getWorkspaceAppByAgentIDAndSlugNoLock(ctx, database.GetWorkspaceAppByAgentIDAndSlugParams{ - AgentID: agent.ID, - Slug: slug, - }) - if err != nil { - continue - } - return app, nil - } - } - - return database.WorkspaceApp{}, sql.ErrNoRows -} - -// getOrganizationByIDNoLock is used by other functions in the database fake. -func (q *FakeQuerier) getOrganizationByIDNoLock(id uuid.UUID) (database.Organization, error) { - for _, organization := range q.organizations { - if organization.ID == id { - return organization, nil - } - } - return database.Organization{}, sql.ErrNoRows -} - -func (q *FakeQuerier) getWorkspaceAgentScriptsByAgentIDsNoLock(ids []uuid.UUID) ([]database.WorkspaceAgentScript, error) { - scripts := make([]database.WorkspaceAgentScript, 0) - for _, script := range q.workspaceAgentScripts { - for _, id := range ids { - if script.WorkspaceAgentID == id { - scripts = append(scripts, script) - break - } - } - } - return scripts, nil -} - -// getOwnerFromTags returns the lowercase owner from tags, matching SQL's COALESCE(tags ->> 'owner', ”) -func getOwnerFromTags(tags map[string]string) string { - if owner, ok := tags["owner"]; ok { - return strings.ToLower(owner) - } - return "" -} - -// provisionerTagsetContains checks if daemonTags contain all key-value pairs from jobTags -func provisionerTagsetContains(daemonTags, jobTags map[string]string) bool { - for jobKey, jobValue := range jobTags { - if daemonValue, exists := daemonTags[jobKey]; !exists || daemonValue != jobValue { - return false - } - } - return true -} - -// GetProvisionerJobsByIDsWithQueuePosition mimics the SQL logic in pure Go -func (q *FakeQuerier) getProvisionerJobsByIDsWithQueuePositionLockedTagBasedQueue(_ context.Context, jobIDs []uuid.UUID) ([]database.GetProvisionerJobsByIDsWithQueuePositionRow, error) { - // Step 1: Filter provisionerJobs based on jobIDs - filteredJobs := make(map[uuid.UUID]database.ProvisionerJob) - for _, job := range q.provisionerJobs { - for _, id := range jobIDs { - if job.ID == id { - filteredJobs[job.ID] = job - } - } - } - - // Step 2: Identify pending jobs - pendingJobs := make(map[uuid.UUID]database.ProvisionerJob) - for _, job := range q.provisionerJobs { - if job.JobStatus == "pending" { - pendingJobs[job.ID] = job - } - } - - // Step 3: Identify pending jobs that have a matching provisioner - matchedJobs := make(map[uuid.UUID]struct{}) - for _, job := range pendingJobs { - for _, daemon := range q.provisionerDaemons { - if provisionerTagsetContains(daemon.Tags, job.Tags) { - matchedJobs[job.ID] = struct{}{} - break - } - } - } - - // Step 4: Rank pending jobs per provisioner - jobRanks := make(map[uuid.UUID][]database.ProvisionerJob) - for _, job := range pendingJobs { - for _, daemon := range q.provisionerDaemons { - if provisionerTagsetContains(daemon.Tags, job.Tags) { - jobRanks[daemon.ID] = append(jobRanks[daemon.ID], job) - } - } - } - - // Sort jobs per provisioner by CreatedAt - for daemonID := range jobRanks { - sort.Slice(jobRanks[daemonID], func(i, j int) bool { - return jobRanks[daemonID][i].CreatedAt.Before(jobRanks[daemonID][j].CreatedAt) - }) - } - - // Step 5: Compute queue position & max queue size across all provisioners - jobQueueStats := make(map[uuid.UUID]database.GetProvisionerJobsByIDsWithQueuePositionRow) - for _, jobs := range jobRanks { - queueSize := int64(len(jobs)) // Queue size per provisioner - for i, job := range jobs { - queuePosition := int64(i + 1) - - // If the job already exists, update only if this queuePosition is better - if existing, exists := jobQueueStats[job.ID]; exists { - jobQueueStats[job.ID] = database.GetProvisionerJobsByIDsWithQueuePositionRow{ - ID: job.ID, - CreatedAt: job.CreatedAt, - ProvisionerJob: job, - QueuePosition: min(existing.QueuePosition, queuePosition), - QueueSize: max(existing.QueueSize, queueSize), // Take the maximum queue size across provisioners - } - } else { - jobQueueStats[job.ID] = database.GetProvisionerJobsByIDsWithQueuePositionRow{ - ID: job.ID, - CreatedAt: job.CreatedAt, - ProvisionerJob: job, - QueuePosition: queuePosition, - QueueSize: queueSize, - } - } - } - } - - // Step 6: Compute the final results with minimal checks - var results []database.GetProvisionerJobsByIDsWithQueuePositionRow - for _, job := range filteredJobs { - // If the job has a computed rank, use it - if rank, found := jobQueueStats[job.ID]; found { - results = append(results, rank) - } else { - // Otherwise, return (0,0) for non-pending jobs and unranked pending jobs - results = append(results, database.GetProvisionerJobsByIDsWithQueuePositionRow{ - ID: job.ID, - CreatedAt: job.CreatedAt, - ProvisionerJob: job, - QueuePosition: 0, - QueueSize: 0, - }) - } - } - - // Step 7: Sort results by CreatedAt - sort.Slice(results, func(i, j int) bool { - return results[i].CreatedAt.Before(results[j].CreatedAt) - }) - - return results, nil -} - -func (q *FakeQuerier) getProvisionerJobsByIDsWithQueuePositionLockedGlobalQueue(_ context.Context, ids []uuid.UUID) ([]database.GetProvisionerJobsByIDsWithQueuePositionRow, error) { - // WITH pending_jobs AS ( - // SELECT - // id, created_at - // FROM - // provisioner_jobs - // WHERE - // started_at IS NULL - // AND - // canceled_at IS NULL - // AND - // completed_at IS NULL - // AND - // error IS NULL - // ), - type pendingJobRow struct { - ID uuid.UUID - CreatedAt time.Time - } - pendingJobs := make([]pendingJobRow, 0) - for _, job := range q.provisionerJobs { - if job.StartedAt.Valid || - job.CanceledAt.Valid || - job.CompletedAt.Valid || - job.Error.Valid { - continue - } - pendingJobs = append(pendingJobs, pendingJobRow{ - ID: job.ID, - CreatedAt: job.CreatedAt, - }) - } - - // queue_position AS ( - // SELECT - // id, - // ROW_NUMBER() OVER (ORDER BY created_at ASC) AS queue_position - // FROM - // pending_jobs - // ), - slices.SortFunc(pendingJobs, func(a, b pendingJobRow) int { - c := a.CreatedAt.Compare(b.CreatedAt) - return c - }) - - queuePosition := make(map[uuid.UUID]int64) - for idx, pj := range pendingJobs { - queuePosition[pj.ID] = int64(idx + 1) - } - - // queue_size AS ( - // SELECT COUNT(*) AS count FROM pending_jobs - // ), - queueSize := len(pendingJobs) - - // SELECT - // sqlc.embed(pj), - // COALESCE(qp.queue_position, 0) AS queue_position, - // COALESCE(qs.count, 0) AS queue_size - // FROM - // provisioner_jobs pj - // LEFT JOIN - // queue_position qp ON pj.id = qp.id - // LEFT JOIN - // queue_size qs ON TRUE - // WHERE - // pj.id IN (...) - jobs := make([]database.GetProvisionerJobsByIDsWithQueuePositionRow, 0) - for _, job := range q.provisionerJobs { - if ids != nil && !slices.Contains(ids, job.ID) { - continue - } - // clone the Tags before appending, since maps are reference types and - // we don't want the caller to be able to mutate the map we have inside - // dbmem! - job.Tags = maps.Clone(job.Tags) - job := database.GetProvisionerJobsByIDsWithQueuePositionRow{ - // sqlc.embed(pj), - ProvisionerJob: job, - // COALESCE(qp.queue_position, 0) AS queue_position, - QueuePosition: queuePosition[job.ID], - // COALESCE(qs.count, 0) AS queue_size - QueueSize: int64(queueSize), - } - jobs = append(jobs, job) - } - - return jobs, nil -} - -// isDeprecated returns true if the template is deprecated. -// A template is considered deprecated when it has a deprecation message. -func isDeprecated(template database.Template) bool { - return template.Deprecated != "" -} - -func (q *FakeQuerier) getWorkspaceBuildParametersNoLock(workspaceBuildID uuid.UUID) ([]database.WorkspaceBuildParameter, error) { - params := make([]database.WorkspaceBuildParameter, 0) - for _, param := range q.workspaceBuildParameters { - if param.WorkspaceBuildID != workspaceBuildID { - continue - } - params = append(params, param) - } - return params, nil -} - -func (*FakeQuerier) AcquireLock(_ context.Context, _ int64) error { - return xerrors.New("AcquireLock must only be called within a transaction") -} - -// AcquireNotificationMessages implements the *basic* business logic, but is *not* exhaustive or meant to be 1:1 with -// the real AcquireNotificationMessages query. -func (q *FakeQuerier) AcquireNotificationMessages(_ context.Context, arg database.AcquireNotificationMessagesParams) ([]database.AcquireNotificationMessagesRow, error) { - err := validateDatabaseType(arg) - if err != nil { - return nil, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - // Shift the first "Count" notifications off the slice (FIFO). - sz := len(q.notificationMessages) - if sz > int(arg.Count) { - sz = int(arg.Count) - } - - list := q.notificationMessages[:sz] - q.notificationMessages = q.notificationMessages[sz:] - - var out []database.AcquireNotificationMessagesRow - for _, nm := range list { - acquirableStatuses := []database.NotificationMessageStatus{database.NotificationMessageStatusPending, database.NotificationMessageStatusTemporaryFailure} - if !slices.Contains(acquirableStatuses, nm.Status) { - continue - } - - // Mimic mutation in database query. - nm.UpdatedAt = sql.NullTime{Time: dbtime.Now(), Valid: true} - nm.Status = database.NotificationMessageStatusLeased - nm.StatusReason = sql.NullString{String: fmt.Sprintf("Enqueued by notifier %d", arg.NotifierID), Valid: true} - nm.LeasedUntil = sql.NullTime{Time: dbtime.Now().Add(time.Second * time.Duration(arg.LeaseSeconds)), Valid: true} - - out = append(out, database.AcquireNotificationMessagesRow{ - ID: nm.ID, - Payload: nm.Payload, - Method: nm.Method, - TitleTemplate: "This is a title with {{.Labels.variable}}", - BodyTemplate: "This is a body with {{.Labels.variable}}", - TemplateID: nm.NotificationTemplateID, - }) - } - - return out, nil -} - -func (q *FakeQuerier) AcquireProvisionerJob(_ context.Context, arg database.AcquireProvisionerJobParams) (database.ProvisionerJob, error) { - if err := validateDatabaseType(arg); err != nil { - return database.ProvisionerJob{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, provisionerJob := range q.provisionerJobs { - if provisionerJob.OrganizationID != arg.OrganizationID { - continue - } - if provisionerJob.StartedAt.Valid { - continue - } - found := false - for _, provisionerType := range arg.Types { - if provisionerJob.Provisioner != provisionerType { - continue - } - found = true - break - } - if !found { - continue - } - tags := map[string]string{} - if arg.ProvisionerTags != nil { - err := json.Unmarshal(arg.ProvisionerTags, &tags) - if err != nil { - return provisionerJob, xerrors.Errorf("unmarshal: %w", err) - } - } - - // Special case for untagged provisioners: only match untagged jobs. - // Ref: coderd/database/queries/provisionerjobs.sql:24-30 - // CASE WHEN nested.tags :: jsonb = '{"scope": "organization", "owner": ""}' :: jsonb - // THEN nested.tags :: jsonb = @tags :: jsonb - if tagsEqual(provisionerJob.Tags, tagsUntagged) && !tagsEqual(provisionerJob.Tags, tags) { - continue - } - // ELSE nested.tags :: jsonb <@ @tags :: jsonb - if !tagsSubset(provisionerJob.Tags, tags) { - continue - } - provisionerJob.StartedAt = arg.StartedAt - provisionerJob.UpdatedAt = arg.StartedAt.Time - provisionerJob.WorkerID = arg.WorkerID - provisionerJob.JobStatus = provisionerJobStatus(provisionerJob) - q.provisionerJobs[index] = provisionerJob - // clone the Tags before returning, since maps are reference types and - // we don't want the caller to be able to mutate the map we have inside - // dbmem! - provisionerJob.Tags = maps.Clone(provisionerJob.Tags) - return provisionerJob, nil - } - return database.ProvisionerJob{}, sql.ErrNoRows -} - -func (q *FakeQuerier) ActivityBumpWorkspace(ctx context.Context, arg database.ActivityBumpWorkspaceParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - workspace, err := q.getWorkspaceByIDNoLock(ctx, arg.WorkspaceID) - if err != nil { - return err - } - latestBuild, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, arg.WorkspaceID) - if err != nil { - return err - } - - now := dbtime.Now() - for i := range q.workspaceBuilds { - if q.workspaceBuilds[i].BuildNumber != latestBuild.BuildNumber { - continue - } - // If the build is not active, do not bump. - if q.workspaceBuilds[i].Transition != database.WorkspaceTransitionStart { - return nil - } - // If the provisioner job is not completed, do not bump. - pj, err := q.getProvisionerJobByIDNoLock(ctx, q.workspaceBuilds[i].JobID) - if err != nil { - return err - } - if !pj.CompletedAt.Valid { - return nil - } - // Do not bump if the deadline is not set. - if q.workspaceBuilds[i].Deadline.IsZero() { - return nil - } - - // Check the template default TTL. - template, err := q.getTemplateByIDNoLock(ctx, workspace.TemplateID) - if err != nil { - return err - } - if template.ActivityBump == 0 { - return nil - } - activityBump := time.Duration(template.ActivityBump) - - var ttlDur time.Duration - if now.Add(activityBump).After(arg.NextAutostart) && arg.NextAutostart.After(now) { - // Extend to TTL (NOT activity bump) - add := arg.NextAutostart.Sub(now) - if workspace.Ttl.Valid && template.AllowUserAutostop { - add += time.Duration(workspace.Ttl.Int64) - } else { - add += time.Duration(template.DefaultTTL) - } - ttlDur = add - } else { - // Otherwise, default to regular activity bump duration. - ttlDur = activityBump - } - - // Only bump if 5% of the deadline has passed. - ttlDur95 := ttlDur - (ttlDur / 20) - minBumpDeadline := q.workspaceBuilds[i].Deadline.Add(-ttlDur95) - if now.Before(minBumpDeadline) { - return nil - } - - // Bump. - newDeadline := now.Add(ttlDur) - // Never decrease deadlines from a bump - newDeadline = maxTime(newDeadline, q.workspaceBuilds[i].Deadline) - q.workspaceBuilds[i].UpdatedAt = now - if !q.workspaceBuilds[i].MaxDeadline.IsZero() { - q.workspaceBuilds[i].Deadline = minTime(newDeadline, q.workspaceBuilds[i].MaxDeadline) - } else { - q.workspaceBuilds[i].Deadline = newDeadline - } - return nil - } - - return sql.ErrNoRows -} - -// nolint:revive // It's not a control flag, it's a filter. -func (q *FakeQuerier) AllUserIDs(_ context.Context, includeSystem bool) ([]uuid.UUID, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - userIDs := make([]uuid.UUID, 0, len(q.users)) - for idx := range q.users { - if !includeSystem && q.users[idx].IsSystem { - continue - } - - userIDs = append(userIDs, q.users[idx].ID) - } - return userIDs, nil -} - -func (q *FakeQuerier) ArchiveUnusedTemplateVersions(_ context.Context, arg database.ArchiveUnusedTemplateVersionsParams) ([]uuid.UUID, error) { - err := validateDatabaseType(arg) - if err != nil { - return nil, err - } - q.mutex.Lock() - defer q.mutex.Unlock() - type latestBuild struct { - Number int32 - Version uuid.UUID - } - latest := make(map[uuid.UUID]latestBuild) - - for _, b := range q.workspaceBuilds { - v, ok := latest[b.WorkspaceID] - if ok || b.BuildNumber < v.Number { - // Not the latest - continue - } - // Ignore deleted workspaces. - if b.Transition == database.WorkspaceTransitionDelete { - continue - } - latest[b.WorkspaceID] = latestBuild{ - Number: b.BuildNumber, - Version: b.TemplateVersionID, - } - } - - usedVersions := make(map[uuid.UUID]bool) - for _, l := range latest { - usedVersions[l.Version] = true - } - for _, tpl := range q.templates { - usedVersions[tpl.ActiveVersionID] = true - } - - var archived []uuid.UUID - for i, v := range q.templateVersions { - if arg.TemplateVersionID != uuid.Nil { - if v.ID != arg.TemplateVersionID { - continue - } - } - if v.Archived { - continue - } - - if _, ok := usedVersions[v.ID]; !ok { - var job *database.ProvisionerJob - for i, j := range q.provisionerJobs { - if v.JobID == j.ID { - job = &q.provisionerJobs[i] - break - } - } - - if arg.JobStatus.Valid { - if job.JobStatus != arg.JobStatus.ProvisionerJobStatus { - continue - } - } - - if job.JobStatus == database.ProvisionerJobStatusRunning || job.JobStatus == database.ProvisionerJobStatusPending { - continue - } - - v.Archived = true - q.templateVersions[i] = v - archived = append(archived, v.ID) - } - } - - return archived, nil -} - -func (q *FakeQuerier) BatchUpdateWorkspaceLastUsedAt(_ context.Context, arg database.BatchUpdateWorkspaceLastUsedAtParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - // temporary map to avoid O(q.workspaces*arg.workspaceIds) - m := make(map[uuid.UUID]struct{}) - for _, id := range arg.IDs { - m[id] = struct{}{} - } - n := 0 - for i := 0; i < len(q.workspaces); i++ { - if _, found := m[q.workspaces[i].ID]; !found { - continue - } - // WHERE last_used_at < @last_used_at - if !q.workspaces[i].LastUsedAt.Before(arg.LastUsedAt) { - continue - } - q.workspaces[i].LastUsedAt = arg.LastUsedAt - n++ - } - return nil -} - -func (q *FakeQuerier) BatchUpdateWorkspaceNextStartAt(_ context.Context, arg database.BatchUpdateWorkspaceNextStartAtParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, workspace := range q.workspaces { - for j, workspaceID := range arg.IDs { - if workspace.ID != workspaceID { - continue - } - - nextStartAt := arg.NextStartAts[j] - if nextStartAt.IsZero() { - q.workspaces[i].NextStartAt = sql.NullTime{} - } else { - q.workspaces[i].NextStartAt = sql.NullTime{Valid: true, Time: nextStartAt} - } - - break - } - } - - return nil -} - -func (*FakeQuerier) BulkMarkNotificationMessagesFailed(_ context.Context, arg database.BulkMarkNotificationMessagesFailedParams) (int64, error) { - err := validateDatabaseType(arg) - if err != nil { - return 0, err - } - return int64(len(arg.IDs)), nil -} - -func (*FakeQuerier) BulkMarkNotificationMessagesSent(_ context.Context, arg database.BulkMarkNotificationMessagesSentParams) (int64, error) { - err := validateDatabaseType(arg) - if err != nil { - return 0, err - } - return int64(len(arg.IDs)), nil -} - -func (q *FakeQuerier) ClaimPrebuiltWorkspace(ctx context.Context, arg database.ClaimPrebuiltWorkspaceParams) (database.ClaimPrebuiltWorkspaceRow, error) { - return database.ClaimPrebuiltWorkspaceRow{}, ErrUnimplemented -} - -func (*FakeQuerier) CleanTailnetCoordinators(_ context.Context) error { - return ErrUnimplemented -} - -func (*FakeQuerier) CleanTailnetLostPeers(context.Context) error { - return ErrUnimplemented -} - -func (*FakeQuerier) CleanTailnetTunnels(context.Context) error { - return ErrUnimplemented -} - -func (q *FakeQuerier) CountAuditLogs(ctx context.Context, arg database.CountAuditLogsParams) (int64, error) { - return q.CountAuthorizedAuditLogs(ctx, arg, nil) -} - -func (q *FakeQuerier) CountInProgressPrebuilds(ctx context.Context) ([]database.CountInProgressPrebuildsRow, error) { - return nil, ErrUnimplemented -} - -func (q *FakeQuerier) CountUnreadInboxNotificationsByUserID(_ context.Context, userID uuid.UUID) (int64, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - var count int64 - for _, notification := range q.inboxNotifications { - if notification.UserID != userID { - continue - } - - if notification.ReadAt.Valid { - continue - } - - count++ - } - - return count, nil -} - -func (q *FakeQuerier) CustomRoles(_ context.Context, arg database.CustomRolesParams) ([]database.CustomRole, error) { - q.mutex.Lock() - defer q.mutex.Unlock() - - found := make([]database.CustomRole, 0) - for _, role := range q.data.customRoles { - if len(arg.LookupRoles) > 0 { - if !slices.ContainsFunc(arg.LookupRoles, func(pair database.NameOrganizationPair) bool { - if pair.Name != role.Name { - return false - } - - if role.OrganizationID.Valid { - // Expect org match - return role.OrganizationID.UUID == pair.OrganizationID - } - // Expect no org - return pair.OrganizationID == uuid.Nil - }) { - continue - } - } - - if arg.ExcludeOrgRoles && role.OrganizationID.Valid { - continue - } - - if arg.OrganizationID != uuid.Nil && role.OrganizationID.UUID != arg.OrganizationID { - continue - } - - found = append(found, role) - } - - return found, nil -} - -func (q *FakeQuerier) DeleteAPIKeyByID(_ context.Context, id string) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, apiKey := range q.apiKeys { - if apiKey.ID != id { - continue - } - q.apiKeys[index] = q.apiKeys[len(q.apiKeys)-1] - q.apiKeys = q.apiKeys[:len(q.apiKeys)-1] - return nil - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) DeleteAPIKeysByUserID(_ context.Context, userID uuid.UUID) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - for i := len(q.apiKeys) - 1; i >= 0; i-- { - if q.apiKeys[i].UserID == userID { - q.apiKeys = append(q.apiKeys[:i], q.apiKeys[i+1:]...) - } - } - - return nil -} - -func (*FakeQuerier) DeleteAllTailnetClientSubscriptions(_ context.Context, arg database.DeleteAllTailnetClientSubscriptionsParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - return ErrUnimplemented -} - -func (*FakeQuerier) DeleteAllTailnetTunnels(_ context.Context, arg database.DeleteAllTailnetTunnelsParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - return ErrUnimplemented -} - -func (q *FakeQuerier) DeleteAllWebpushSubscriptions(_ context.Context) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - q.webpushSubscriptions = make([]database.WebpushSubscription, 0) - return nil -} - -func (q *FakeQuerier) DeleteApplicationConnectAPIKeysByUserID(_ context.Context, userID uuid.UUID) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - for i := len(q.apiKeys) - 1; i >= 0; i-- { - if q.apiKeys[i].UserID == userID && q.apiKeys[i].Scope == database.APIKeyScopeApplicationConnect { - q.apiKeys = append(q.apiKeys[:i], q.apiKeys[i+1:]...) - } - } - - return nil -} - -func (*FakeQuerier) DeleteCoordinator(context.Context, uuid.UUID) error { - return ErrUnimplemented -} - -func (q *FakeQuerier) DeleteCryptoKey(_ context.Context, arg database.DeleteCryptoKeyParams) (database.CryptoKey, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.CryptoKey{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, key := range q.cryptoKeys { - if key.Feature == arg.Feature && key.Sequence == arg.Sequence { - q.cryptoKeys[i].Secret.String = "" - q.cryptoKeys[i].Secret.Valid = false - q.cryptoKeys[i].SecretKeyID.String = "" - q.cryptoKeys[i].SecretKeyID.Valid = false - return q.cryptoKeys[i], nil - } - } - return database.CryptoKey{}, sql.ErrNoRows -} - -func (q *FakeQuerier) DeleteCustomRole(_ context.Context, arg database.DeleteCustomRoleParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - initial := len(q.data.customRoles) - q.data.customRoles = slices.DeleteFunc(q.data.customRoles, func(role database.CustomRole) bool { - return role.OrganizationID.UUID == arg.OrganizationID.UUID && role.Name == arg.Name - }) - if initial == len(q.data.customRoles) { - return sql.ErrNoRows - } - - // Emulate the trigger 'remove_organization_member_custom_role' - for i, mem := range q.organizationMembers { - if mem.OrganizationID == arg.OrganizationID.UUID { - mem.Roles = slices.DeleteFunc(mem.Roles, func(role string) bool { - return role == arg.Name - }) - q.organizationMembers[i] = mem - } - } - return nil -} - -func (q *FakeQuerier) DeleteExternalAuthLink(_ context.Context, arg database.DeleteExternalAuthLinkParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, key := range q.externalAuthLinks { - if key.UserID != arg.UserID { - continue - } - if key.ProviderID != arg.ProviderID { - continue - } - q.externalAuthLinks[index] = q.externalAuthLinks[len(q.externalAuthLinks)-1] - q.externalAuthLinks = q.externalAuthLinks[:len(q.externalAuthLinks)-1] - return nil - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) DeleteGitSSHKey(_ context.Context, userID uuid.UUID) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, key := range q.gitSSHKey { - if key.UserID != userID { - continue - } - q.gitSSHKey[index] = q.gitSSHKey[len(q.gitSSHKey)-1] - q.gitSSHKey = q.gitSSHKey[:len(q.gitSSHKey)-1] - return nil - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) DeleteGroupByID(_ context.Context, id uuid.UUID) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, group := range q.groups { - if group.ID == id { - q.groups = append(q.groups[:i], q.groups[i+1:]...) - return nil - } - } - - return sql.ErrNoRows -} - -func (q *FakeQuerier) DeleteGroupMemberFromGroup(_ context.Context, arg database.DeleteGroupMemberFromGroupParams) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, member := range q.groupMembers { - if member.UserID == arg.UserID && member.GroupID == arg.GroupID { - q.groupMembers = append(q.groupMembers[:i], q.groupMembers[i+1:]...) - } - } - return nil -} - -func (q *FakeQuerier) DeleteLicense(_ context.Context, id int32) (int32, error) { - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, l := range q.licenses { - if l.ID == id { - q.licenses[index] = q.licenses[len(q.licenses)-1] - q.licenses = q.licenses[:len(q.licenses)-1] - return id, nil - } - } - return 0, sql.ErrNoRows -} - -func (q *FakeQuerier) DeleteOAuth2ProviderAppByClientID(ctx context.Context, id uuid.UUID) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, app := range q.oauth2ProviderApps { - if app.ID == id { - q.oauth2ProviderApps = append(q.oauth2ProviderApps[:i], q.oauth2ProviderApps[i+1:]...) - - // Also delete related secrets and tokens - for j := len(q.oauth2ProviderAppSecrets) - 1; j >= 0; j-- { - if q.oauth2ProviderAppSecrets[j].AppID == id { - q.oauth2ProviderAppSecrets = append(q.oauth2ProviderAppSecrets[:j], q.oauth2ProviderAppSecrets[j+1:]...) - } - } - - // Delete tokens for the app's secrets - for j := len(q.oauth2ProviderAppTokens) - 1; j >= 0; j-- { - token := q.oauth2ProviderAppTokens[j] - for _, secret := range q.oauth2ProviderAppSecrets { - if secret.AppID == id && token.AppSecretID == secret.ID { - q.oauth2ProviderAppTokens = append(q.oauth2ProviderAppTokens[:j], q.oauth2ProviderAppTokens[j+1:]...) - break - } - } - } - - return nil - } - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) DeleteOAuth2ProviderAppByID(_ context.Context, id uuid.UUID) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - index := slices.IndexFunc(q.oauth2ProviderApps, func(app database.OAuth2ProviderApp) bool { - return app.ID == id - }) - - if index < 0 { - return sql.ErrNoRows - } - - q.oauth2ProviderApps[index] = q.oauth2ProviderApps[len(q.oauth2ProviderApps)-1] - q.oauth2ProviderApps = q.oauth2ProviderApps[:len(q.oauth2ProviderApps)-1] - - // Cascade delete secrets associated with the deleted app. - var deletedSecretIDs []uuid.UUID - q.oauth2ProviderAppSecrets = slices.DeleteFunc(q.oauth2ProviderAppSecrets, func(secret database.OAuth2ProviderAppSecret) bool { - matches := secret.AppID == id - if matches { - deletedSecretIDs = append(deletedSecretIDs, secret.ID) - } - return matches - }) - - // Cascade delete tokens through the deleted secrets. - var keyIDsToDelete []string - q.oauth2ProviderAppTokens = slices.DeleteFunc(q.oauth2ProviderAppTokens, func(token database.OAuth2ProviderAppToken) bool { - matches := slice.Contains(deletedSecretIDs, token.AppSecretID) - if matches { - keyIDsToDelete = append(keyIDsToDelete, token.APIKeyID) - } - return matches - }) - - // Cascade delete API keys linked to the deleted tokens. - q.apiKeys = slices.DeleteFunc(q.apiKeys, func(key database.APIKey) bool { - return slices.Contains(keyIDsToDelete, key.ID) - }) - - return nil -} - -func (q *FakeQuerier) DeleteOAuth2ProviderAppCodeByID(_ context.Context, id uuid.UUID) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, code := range q.oauth2ProviderAppCodes { - if code.ID == id { - q.oauth2ProviderAppCodes[index] = q.oauth2ProviderAppCodes[len(q.oauth2ProviderAppCodes)-1] - q.oauth2ProviderAppCodes = q.oauth2ProviderAppCodes[:len(q.oauth2ProviderAppCodes)-1] - return nil - } - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) DeleteOAuth2ProviderAppCodesByAppAndUserID(_ context.Context, arg database.DeleteOAuth2ProviderAppCodesByAppAndUserIDParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, code := range q.oauth2ProviderAppCodes { - if code.AppID == arg.AppID && code.UserID == arg.UserID { - q.oauth2ProviderAppCodes[index] = q.oauth2ProviderAppCodes[len(q.oauth2ProviderAppCodes)-1] - q.oauth2ProviderAppCodes = q.oauth2ProviderAppCodes[:len(q.oauth2ProviderAppCodes)-1] - return nil - } - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) DeleteOAuth2ProviderAppSecretByID(_ context.Context, id uuid.UUID) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - index := slices.IndexFunc(q.oauth2ProviderAppSecrets, func(secret database.OAuth2ProviderAppSecret) bool { - return secret.ID == id - }) - - if index < 0 { - return sql.ErrNoRows - } - - q.oauth2ProviderAppSecrets[index] = q.oauth2ProviderAppSecrets[len(q.oauth2ProviderAppSecrets)-1] - q.oauth2ProviderAppSecrets = q.oauth2ProviderAppSecrets[:len(q.oauth2ProviderAppSecrets)-1] - - // Cascade delete tokens created through the deleted secret. - var keyIDsToDelete []string - q.oauth2ProviderAppTokens = slices.DeleteFunc(q.oauth2ProviderAppTokens, func(token database.OAuth2ProviderAppToken) bool { - matches := token.AppSecretID == id - if matches { - keyIDsToDelete = append(keyIDsToDelete, token.APIKeyID) - } - return matches - }) - - // Cascade delete API keys linked to the deleted tokens. - q.apiKeys = slices.DeleteFunc(q.apiKeys, func(key database.APIKey) bool { - return slices.Contains(keyIDsToDelete, key.ID) - }) - - return nil -} - -func (q *FakeQuerier) DeleteOAuth2ProviderAppTokensByAppAndUserID(_ context.Context, arg database.DeleteOAuth2ProviderAppTokensByAppAndUserIDParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - var keyIDsToDelete []string - q.oauth2ProviderAppTokens = slices.DeleteFunc(q.oauth2ProviderAppTokens, func(token database.OAuth2ProviderAppToken) bool { - // Join secrets and keys to see if the token matches. - secretIdx := slices.IndexFunc(q.oauth2ProviderAppSecrets, func(secret database.OAuth2ProviderAppSecret) bool { - return secret.ID == token.AppSecretID - }) - keyIdx := slices.IndexFunc(q.apiKeys, func(key database.APIKey) bool { - return key.ID == token.APIKeyID - }) - matches := secretIdx != -1 && - q.oauth2ProviderAppSecrets[secretIdx].AppID == arg.AppID && - keyIdx != -1 && q.apiKeys[keyIdx].UserID == arg.UserID - if matches { - keyIDsToDelete = append(keyIDsToDelete, token.APIKeyID) - } - return matches - }) - - // Cascade delete API keys linked to the deleted tokens. - q.apiKeys = slices.DeleteFunc(q.apiKeys, func(key database.APIKey) bool { - return slices.Contains(keyIDsToDelete, key.ID) - }) - - return nil -} - -func (*FakeQuerier) DeleteOldNotificationMessages(_ context.Context) error { - return nil -} - -func (q *FakeQuerier) DeleteOldProvisionerDaemons(_ context.Context) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - now := dbtime.Now() - weekInterval := 7 * 24 * time.Hour - weekAgo := now.Add(-weekInterval) - - var validDaemons []database.ProvisionerDaemon - for _, p := range q.provisionerDaemons { - if (p.CreatedAt.Before(weekAgo) && !p.LastSeenAt.Valid) || (p.LastSeenAt.Valid && p.LastSeenAt.Time.Before(weekAgo)) { - continue - } - validDaemons = append(validDaemons, p) - } - q.provisionerDaemons = validDaemons - return nil -} - -func (q *FakeQuerier) DeleteOldWorkspaceAgentLogs(_ context.Context, threshold time.Time) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - /* - WITH - latest_builds AS ( - SELECT - workspace_id, max(build_number) AS max_build_number - FROM - workspace_builds - GROUP BY - workspace_id - ), - */ - latestBuilds := make(map[uuid.UUID]int32) - for _, wb := range q.workspaceBuilds { - if lastBuildNumber, found := latestBuilds[wb.WorkspaceID]; found && lastBuildNumber > wb.BuildNumber { - continue - } - // not found or newer build number - latestBuilds[wb.WorkspaceID] = wb.BuildNumber - } - - /* - old_agents AS ( - SELECT - wa.id - FROM - workspace_agents AS wa - JOIN - workspace_resources AS wr - ON - wa.resource_id = wr.id - JOIN - workspace_builds AS wb - ON - wb.job_id = wr.job_id - LEFT JOIN - latest_builds - ON - latest_builds.workspace_id = wb.workspace_id - AND - latest_builds.max_build_number = wb.build_number - WHERE - -- Filter out the latest builds for each workspace. - latest_builds.workspace_id IS NULL - AND CASE - -- If the last time the agent connected was before @threshold - WHEN wa.last_connected_at IS NOT NULL THEN - wa.last_connected_at < @threshold :: timestamptz - -- The agent never connected, and was created before @threshold - ELSE wa.created_at < @threshold :: timestamptz - END - ) - */ - oldAgents := make(map[uuid.UUID]struct{}) - for _, wa := range q.workspaceAgents { - for _, wr := range q.workspaceResources { - if wr.ID != wa.ResourceID { - continue - } - for _, wb := range q.workspaceBuilds { - if wb.JobID != wr.JobID { - continue - } - latestBuildNumber, found := latestBuilds[wb.WorkspaceID] - if !found { - panic("workspaceBuilds got modified somehow while q was locked! This is a bug in dbmem!") - } - if latestBuildNumber == wb.BuildNumber { - continue - } - if wa.LastConnectedAt.Valid && wa.LastConnectedAt.Time.Before(threshold) || wa.CreatedAt.Before(threshold) { - oldAgents[wa.ID] = struct{}{} - } - } - } - } - /* - DELETE FROM workspace_agent_logs WHERE agent_id IN (SELECT id FROM old_agents); - */ - var validLogs []database.WorkspaceAgentLog - for _, log := range q.workspaceAgentLogs { - if _, found := oldAgents[log.AgentID]; found { - continue - } - validLogs = append(validLogs, log) - } - q.workspaceAgentLogs = validLogs - return nil -} - -func (q *FakeQuerier) DeleteOldWorkspaceAgentStats(_ context.Context) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - /* - DELETE FROM - workspace_agent_stats - WHERE - created_at < ( - SELECT - COALESCE( - -- When generating initial template usage stats, all the - -- raw agent stats are needed, after that only ~30 mins - -- from last rollup is needed. Deployment stats seem to - -- use between 15 mins and 1 hour of data. We keep a - -- little bit more (1 day) just in case. - MAX(start_time) - '1 days'::interval, - -- Fall back to ~6 months ago if there are no template - -- usage stats so that we don't delete the data before - -- it's rolled up. - NOW() - '180 days'::interval - ) - FROM - template_usage_stats - ) - AND created_at < ( - -- Delete at most in batches of 3 days (with a batch size of 3 days, we - -- can clear out the previous 6 months of data in ~60 iterations) whilst - -- keeping the DB load relatively low. - SELECT - COALESCE(MIN(created_at) + '3 days'::interval, NOW()) - FROM - workspace_agent_stats - ); - */ - - now := dbtime.Now() - var limit time.Time - // MAX - for _, stat := range q.templateUsageStats { - if stat.StartTime.After(limit) { - limit = stat.StartTime.AddDate(0, 0, -1) - } - } - // COALESCE - if limit.IsZero() { - limit = now.AddDate(0, 0, -180) - } - - var validStats []database.WorkspaceAgentStat - var batchLimit time.Time - for _, stat := range q.workspaceAgentStats { - if batchLimit.IsZero() || stat.CreatedAt.Before(batchLimit) { - batchLimit = stat.CreatedAt - } - } - if batchLimit.IsZero() { - batchLimit = time.Now() - } else { - batchLimit = batchLimit.AddDate(0, 0, 3) - } - for _, stat := range q.workspaceAgentStats { - if stat.CreatedAt.Before(limit) && stat.CreatedAt.Before(batchLimit) { - continue - } - validStats = append(validStats, stat) - } - q.workspaceAgentStats = validStats - return nil -} - -func (q *FakeQuerier) DeleteOrganizationMember(ctx context.Context, arg database.DeleteOrganizationMemberParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - deleted := false - q.data.organizationMembers = slices.DeleteFunc(q.data.organizationMembers, func(member database.OrganizationMember) bool { - match := member.OrganizationID == arg.OrganizationID && member.UserID == arg.UserID - deleted = deleted || match - return match - }) - if !deleted { - return sql.ErrNoRows - } - - // Delete group member trigger - q.groupMembers = slices.DeleteFunc(q.groupMembers, func(member database.GroupMemberTable) bool { - if member.UserID != arg.UserID { - return false - } - g, _ := q.getGroupByIDNoLock(ctx, member.GroupID) - return g.OrganizationID == arg.OrganizationID - }) - - return nil -} - -func (q *FakeQuerier) DeleteProvisionerKey(_ context.Context, id uuid.UUID) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, key := range q.provisionerKeys { - if key.ID == id { - q.provisionerKeys = append(q.provisionerKeys[:i], q.provisionerKeys[i+1:]...) - return nil - } - } - - return sql.ErrNoRows -} - -func (q *FakeQuerier) DeleteReplicasUpdatedBefore(_ context.Context, before time.Time) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, replica := range q.replicas { - if replica.UpdatedAt.Before(before) { - q.replicas = append(q.replicas[:i], q.replicas[i+1:]...) - } - } - - return nil -} - -func (q *FakeQuerier) DeleteRuntimeConfig(_ context.Context, key string) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - delete(q.runtimeConfig, key) - return nil -} - -func (*FakeQuerier) DeleteTailnetAgent(context.Context, database.DeleteTailnetAgentParams) (database.DeleteTailnetAgentRow, error) { - return database.DeleteTailnetAgentRow{}, ErrUnimplemented -} - -func (*FakeQuerier) DeleteTailnetClient(context.Context, database.DeleteTailnetClientParams) (database.DeleteTailnetClientRow, error) { - return database.DeleteTailnetClientRow{}, ErrUnimplemented -} - -func (*FakeQuerier) DeleteTailnetClientSubscription(context.Context, database.DeleteTailnetClientSubscriptionParams) error { - return ErrUnimplemented -} - -func (*FakeQuerier) DeleteTailnetPeer(_ context.Context, arg database.DeleteTailnetPeerParams) (database.DeleteTailnetPeerRow, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.DeleteTailnetPeerRow{}, err - } - - return database.DeleteTailnetPeerRow{}, ErrUnimplemented -} - -func (*FakeQuerier) DeleteTailnetTunnel(_ context.Context, arg database.DeleteTailnetTunnelParams) (database.DeleteTailnetTunnelRow, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.DeleteTailnetTunnelRow{}, err - } - - return database.DeleteTailnetTunnelRow{}, ErrUnimplemented -} - -func (q *FakeQuerier) DeleteWebpushSubscriptionByUserIDAndEndpoint(_ context.Context, arg database.DeleteWebpushSubscriptionByUserIDAndEndpointParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, subscription := range q.webpushSubscriptions { - if subscription.UserID == arg.UserID && subscription.Endpoint == arg.Endpoint { - q.webpushSubscriptions[i] = q.webpushSubscriptions[len(q.webpushSubscriptions)-1] - q.webpushSubscriptions = q.webpushSubscriptions[:len(q.webpushSubscriptions)-1] - return nil - } - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) DeleteWebpushSubscriptions(_ context.Context, ids []uuid.UUID) error { - q.mutex.Lock() - defer q.mutex.Unlock() - for i, subscription := range q.webpushSubscriptions { - if slices.Contains(ids, subscription.ID) { - q.webpushSubscriptions[i] = q.webpushSubscriptions[len(q.webpushSubscriptions)-1] - q.webpushSubscriptions = q.webpushSubscriptions[:len(q.webpushSubscriptions)-1] - return nil - } - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) DeleteWorkspaceAgentPortShare(_ context.Context, arg database.DeleteWorkspaceAgentPortShareParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, share := range q.workspaceAgentPortShares { - if share.WorkspaceID == arg.WorkspaceID && share.AgentName == arg.AgentName && share.Port == arg.Port { - q.workspaceAgentPortShares = append(q.workspaceAgentPortShares[:i], q.workspaceAgentPortShares[i+1:]...) - return nil - } - } - - return nil -} - -func (q *FakeQuerier) DeleteWorkspaceAgentPortSharesByTemplate(_ context.Context, templateID uuid.UUID) error { - err := validateDatabaseType(templateID) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for _, workspace := range q.workspaces { - if workspace.TemplateID != templateID { - continue - } - for i, share := range q.workspaceAgentPortShares { - if share.WorkspaceID != workspace.ID { - continue - } - q.workspaceAgentPortShares = append(q.workspaceAgentPortShares[:i], q.workspaceAgentPortShares[i+1:]...) - } - } - - return nil -} - -func (q *FakeQuerier) DeleteWorkspaceSubAgentByID(_ context.Context, id uuid.UUID) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, agent := range q.workspaceAgents { - if agent.ID == id && agent.ParentID.Valid { - q.workspaceAgents[i].Deleted = true - return nil - } - } - - return nil -} - -func (*FakeQuerier) DisableForeignKeysAndTriggers(_ context.Context) error { - // This is a no-op in the in-memory database. - return nil -} - -func (q *FakeQuerier) EnqueueNotificationMessage(_ context.Context, arg database.EnqueueNotificationMessageParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - var payload types.MessagePayload - err = json.Unmarshal(arg.Payload, &payload) - if err != nil { - return err - } - - nm := database.NotificationMessage{ - ID: arg.ID, - UserID: arg.UserID, - Method: arg.Method, - Payload: arg.Payload, - NotificationTemplateID: arg.NotificationTemplateID, - Targets: arg.Targets, - CreatedBy: arg.CreatedBy, - // Default fields. - CreatedAt: dbtime.Now(), - Status: database.NotificationMessageStatusPending, - } - - q.notificationMessages = append(q.notificationMessages, nm) - - return err -} - -func (q *FakeQuerier) FavoriteWorkspace(_ context.Context, arg uuid.UUID) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i := 0; i < len(q.workspaces); i++ { - if q.workspaces[i].ID != arg { - continue - } - q.workspaces[i].Favorite = true - return nil - } - return nil -} - -func (q *FakeQuerier) FetchMemoryResourceMonitorsByAgentID(_ context.Context, agentID uuid.UUID) (database.WorkspaceAgentMemoryResourceMonitor, error) { - for _, monitor := range q.workspaceAgentMemoryResourceMonitors { - if monitor.AgentID == agentID { - return monitor, nil - } - } - - return database.WorkspaceAgentMemoryResourceMonitor{}, sql.ErrNoRows -} - -func (q *FakeQuerier) FetchMemoryResourceMonitorsUpdatedAfter(_ context.Context, updatedAt time.Time) ([]database.WorkspaceAgentMemoryResourceMonitor, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - monitors := []database.WorkspaceAgentMemoryResourceMonitor{} - for _, monitor := range q.workspaceAgentMemoryResourceMonitors { - if monitor.UpdatedAt.After(updatedAt) { - monitors = append(monitors, monitor) - } - } - return monitors, nil -} - -func (q *FakeQuerier) FetchNewMessageMetadata(_ context.Context, arg database.FetchNewMessageMetadataParams) (database.FetchNewMessageMetadataRow, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.FetchNewMessageMetadataRow{}, err - } - - user, err := q.getUserByIDNoLock(arg.UserID) - if err != nil { - return database.FetchNewMessageMetadataRow{}, xerrors.Errorf("fetch user: %w", err) - } - - // Mimic COALESCE in query - userName := user.Name - if userName == "" { - userName = user.Username - } - - actions, err := json.Marshal([]types.TemplateAction{{URL: "http://xyz.com", Label: "XYZ"}}) - if err != nil { - return database.FetchNewMessageMetadataRow{}, err - } - - return database.FetchNewMessageMetadataRow{ - UserEmail: user.Email, - UserName: userName, - UserUsername: user.Username, - NotificationName: "Some notification", - Actions: actions, - UserID: arg.UserID, - }, nil -} - -func (q *FakeQuerier) FetchVolumesResourceMonitorsByAgentID(_ context.Context, agentID uuid.UUID) ([]database.WorkspaceAgentVolumeResourceMonitor, error) { - monitors := []database.WorkspaceAgentVolumeResourceMonitor{} - - for _, monitor := range q.workspaceAgentVolumeResourceMonitors { - if monitor.AgentID == agentID { - monitors = append(monitors, monitor) - } - } - - return monitors, nil -} - -func (q *FakeQuerier) FetchVolumesResourceMonitorsUpdatedAfter(_ context.Context, updatedAt time.Time) ([]database.WorkspaceAgentVolumeResourceMonitor, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - monitors := []database.WorkspaceAgentVolumeResourceMonitor{} - for _, monitor := range q.workspaceAgentVolumeResourceMonitors { - if monitor.UpdatedAt.After(updatedAt) { - monitors = append(monitors, monitor) - } - } - return monitors, nil -} - -func (q *FakeQuerier) GetAPIKeyByID(_ context.Context, id string) (database.APIKey, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, apiKey := range q.apiKeys { - if apiKey.ID == id { - return apiKey, nil - } - } - return database.APIKey{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetAPIKeyByName(_ context.Context, params database.GetAPIKeyByNameParams) (database.APIKey, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - if params.TokenName == "" { - return database.APIKey{}, sql.ErrNoRows - } - for _, apiKey := range q.apiKeys { - if params.UserID == apiKey.UserID && params.TokenName == apiKey.TokenName { - return apiKey, nil - } - } - return database.APIKey{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetAPIKeysByLoginType(_ context.Context, t database.LoginType) ([]database.APIKey, error) { - if err := validateDatabaseType(t); err != nil { - return nil, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - apiKeys := make([]database.APIKey, 0) - for _, key := range q.apiKeys { - if key.LoginType == t { - apiKeys = append(apiKeys, key) - } - } - return apiKeys, nil -} - -func (q *FakeQuerier) GetAPIKeysByUserID(_ context.Context, params database.GetAPIKeysByUserIDParams) ([]database.APIKey, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - apiKeys := make([]database.APIKey, 0) - for _, key := range q.apiKeys { - if key.UserID == params.UserID && key.LoginType == params.LoginType { - apiKeys = append(apiKeys, key) - } - } - return apiKeys, nil -} - -func (q *FakeQuerier) GetAPIKeysLastUsedAfter(_ context.Context, after time.Time) ([]database.APIKey, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - apiKeys := make([]database.APIKey, 0) - for _, key := range q.apiKeys { - if key.LastUsed.After(after) { - apiKeys = append(apiKeys, key) - } - } - return apiKeys, nil -} - -func (q *FakeQuerier) GetActivePresetPrebuildSchedules(ctx context.Context) ([]database.TemplateVersionPresetPrebuildSchedule, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - var activeSchedules []database.TemplateVersionPresetPrebuildSchedule - - // Create a map of active template version IDs for quick lookup - activeTemplateVersions := make(map[uuid.UUID]bool) - for _, template := range q.templates { - if !template.Deleted && template.Deprecated == "" { - activeTemplateVersions[template.ActiveVersionID] = true - } - } - - // Create a map of presets for quick lookup - presetMap := make(map[uuid.UUID]database.TemplateVersionPreset) - for _, preset := range q.presets { - presetMap[preset.ID] = preset - } - - // Filter preset prebuild schedules to only include those for active template versions - for _, schedule := range q.presetPrebuildSchedules { - // Look up the preset using the map - preset, exists := presetMap[schedule.PresetID] - if !exists { - continue - } - - // Check if preset's template version is active - if !activeTemplateVersions[preset.TemplateVersionID] { - continue - } - - activeSchedules = append(activeSchedules, schedule) - } - - return activeSchedules, nil -} - -// nolint:revive // It's not a control flag, it's a filter. -func (q *FakeQuerier) GetActiveUserCount(_ context.Context, includeSystem bool) (int64, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - active := int64(0) - for _, u := range q.users { - if !includeSystem && u.IsSystem { - continue - } - - if u.Status == database.UserStatusActive && !u.Deleted { - active++ - } - } - return active, nil -} - -func (q *FakeQuerier) GetActiveWorkspaceBuildsByTemplateID(ctx context.Context, templateID uuid.UUID) ([]database.WorkspaceBuild, error) { - workspaceIDs := func() []uuid.UUID { - q.mutex.RLock() - defer q.mutex.RUnlock() - - ids := []uuid.UUID{} - for _, workspace := range q.workspaces { - if workspace.TemplateID == templateID { - ids = append(ids, workspace.ID) - } - } - return ids - }() - - builds, err := q.GetLatestWorkspaceBuildsByWorkspaceIDs(ctx, workspaceIDs) - if err != nil { - return nil, err - } - - filteredBuilds := []database.WorkspaceBuild{} - for _, build := range builds { - if build.Transition == database.WorkspaceTransitionStart { - filteredBuilds = append(filteredBuilds, build) - } - } - return filteredBuilds, nil -} - -func (*FakeQuerier) GetAllTailnetAgents(_ context.Context) ([]database.TailnetAgent, error) { - return nil, ErrUnimplemented -} - -func (*FakeQuerier) GetAllTailnetCoordinators(context.Context) ([]database.TailnetCoordinator, error) { - return nil, ErrUnimplemented -} - -func (*FakeQuerier) GetAllTailnetPeers(context.Context) ([]database.TailnetPeer, error) { - return nil, ErrUnimplemented -} - -func (*FakeQuerier) GetAllTailnetTunnels(context.Context) ([]database.TailnetTunnel, error) { - return nil, ErrUnimplemented -} - -func (q *FakeQuerier) GetAnnouncementBanners(_ context.Context) (string, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - if q.announcementBanners == nil { - return "", sql.ErrNoRows - } - - return string(q.announcementBanners), nil -} - -func (q *FakeQuerier) GetAppSecurityKey(_ context.Context) (string, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - return q.appSecurityKey, nil -} - -func (q *FakeQuerier) GetApplicationName(_ context.Context) (string, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - if q.applicationName == "" { - return "", sql.ErrNoRows - } - - return q.applicationName, nil -} - -func (q *FakeQuerier) GetAuditLogsOffset(ctx context.Context, arg database.GetAuditLogsOffsetParams) ([]database.GetAuditLogsOffsetRow, error) { - return q.GetAuthorizedAuditLogsOffset(ctx, arg, nil) -} - -func (q *FakeQuerier) GetAuthorizationUserRoles(_ context.Context, userID uuid.UUID) (database.GetAuthorizationUserRolesRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - var user *database.User - roles := make([]string, 0) - for _, u := range q.users { - if u.ID == userID { - roles = append(roles, u.RBACRoles...) - roles = append(roles, "member") - user = &u - break - } - } - - for _, mem := range q.organizationMembers { - if mem.UserID == userID { - for _, orgRole := range mem.Roles { - roles = append(roles, orgRole+":"+mem.OrganizationID.String()) - } - roles = append(roles, "organization-member:"+mem.OrganizationID.String()) - } - } - - var groups []string - for _, member := range q.groupMembers { - if member.UserID == userID { - groups = append(groups, member.GroupID.String()) - } - } - - if user == nil { - return database.GetAuthorizationUserRolesRow{}, sql.ErrNoRows - } - - return database.GetAuthorizationUserRolesRow{ - ID: userID, - Username: user.Username, - Status: user.Status, - Roles: roles, - Groups: groups, - }, nil -} - -func (q *FakeQuerier) GetCoordinatorResumeTokenSigningKey(_ context.Context) (string, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - if q.coordinatorResumeTokenSigningKey == "" { - return "", sql.ErrNoRows - } - return q.coordinatorResumeTokenSigningKey, nil -} - -func (q *FakeQuerier) GetCryptoKeyByFeatureAndSequence(_ context.Context, arg database.GetCryptoKeyByFeatureAndSequenceParams) (database.CryptoKey, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.CryptoKey{}, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, key := range q.cryptoKeys { - if key.Feature == arg.Feature && key.Sequence == arg.Sequence { - // Keys with NULL secrets are considered deleted. - if key.Secret.Valid { - return key, nil - } - return database.CryptoKey{}, sql.ErrNoRows - } - } - - return database.CryptoKey{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetCryptoKeys(_ context.Context) ([]database.CryptoKey, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - keys := make([]database.CryptoKey, 0) - for _, key := range q.cryptoKeys { - if key.Secret.Valid { - keys = append(keys, key) - } - } - return keys, nil -} - -func (q *FakeQuerier) GetCryptoKeysByFeature(_ context.Context, feature database.CryptoKeyFeature) ([]database.CryptoKey, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - keys := make([]database.CryptoKey, 0) - for _, key := range q.cryptoKeys { - if key.Feature == feature && key.Secret.Valid { - keys = append(keys, key) - } - } - // We want to return the highest sequence number first. - slices.SortFunc(keys, func(i, j database.CryptoKey) int { - return int(j.Sequence - i.Sequence) - }) - return keys, nil -} - -func (q *FakeQuerier) GetDBCryptKeys(_ context.Context) ([]database.DBCryptKey, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - ks := make([]database.DBCryptKey, 0) - ks = append(ks, q.dbcryptKeys...) - return ks, nil -} - -func (q *FakeQuerier) GetDERPMeshKey(_ context.Context) (string, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - if q.derpMeshKey == "" { - return "", sql.ErrNoRows - } - return q.derpMeshKey, nil -} - -func (q *FakeQuerier) GetDefaultOrganization(_ context.Context) (database.Organization, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, org := range q.organizations { - if org.IsDefault { - return org, nil - } - } - return database.Organization{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetDefaultProxyConfig(_ context.Context) (database.GetDefaultProxyConfigRow, error) { - return database.GetDefaultProxyConfigRow{ - DisplayName: q.defaultProxyDisplayName, - IconUrl: q.defaultProxyIconURL, - }, nil -} - -func (q *FakeQuerier) GetDeploymentDAUs(_ context.Context, tzOffset int32) ([]database.GetDeploymentDAUsRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - seens := make(map[time.Time]map[uuid.UUID]struct{}) - - for _, as := range q.workspaceAgentStats { - if as.ConnectionCount == 0 { - continue - } - date := as.CreatedAt.UTC().Add(time.Duration(tzOffset) * -1 * time.Hour).Truncate(time.Hour * 24) - - dateEntry := seens[date] - if dateEntry == nil { - dateEntry = make(map[uuid.UUID]struct{}) - } - dateEntry[as.UserID] = struct{}{} - seens[date] = dateEntry - } - - seenKeys := maps.Keys(seens) - sort.Slice(seenKeys, func(i, j int) bool { - return seenKeys[i].Before(seenKeys[j]) - }) - - var rs []database.GetDeploymentDAUsRow - for _, key := range seenKeys { - ids := seens[key] - for id := range ids { - rs = append(rs, database.GetDeploymentDAUsRow{ - Date: key, - UserID: id, - }) - } - } - - return rs, nil -} - -func (q *FakeQuerier) GetDeploymentID(_ context.Context) (string, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - return q.deploymentID, nil -} - -func (q *FakeQuerier) GetDeploymentWorkspaceAgentStats(_ context.Context, createdAfter time.Time) (database.GetDeploymentWorkspaceAgentStatsRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - agentStatsCreatedAfter := make([]database.WorkspaceAgentStat, 0) - for _, agentStat := range q.workspaceAgentStats { - if agentStat.CreatedAt.After(createdAfter) { - agentStatsCreatedAfter = append(agentStatsCreatedAfter, agentStat) - } - } - - latestAgentStats := map[uuid.UUID]database.WorkspaceAgentStat{} - for _, agentStat := range q.workspaceAgentStats { - if agentStat.CreatedAt.After(createdAfter) { - latestAgentStats[agentStat.AgentID] = agentStat - } - } - - stat := database.GetDeploymentWorkspaceAgentStatsRow{} - for _, agentStat := range latestAgentStats { - stat.SessionCountVSCode += agentStat.SessionCountVSCode - stat.SessionCountJetBrains += agentStat.SessionCountJetBrains - stat.SessionCountReconnectingPTY += agentStat.SessionCountReconnectingPTY - stat.SessionCountSSH += agentStat.SessionCountSSH - } - - latencies := make([]float64, 0) - for _, agentStat := range agentStatsCreatedAfter { - if agentStat.ConnectionMedianLatencyMS <= 0 { - continue - } - stat.WorkspaceRxBytes += agentStat.RxBytes - stat.WorkspaceTxBytes += agentStat.TxBytes - latencies = append(latencies, agentStat.ConnectionMedianLatencyMS) - } - - stat.WorkspaceConnectionLatency50 = tryPercentileCont(latencies, 50) - stat.WorkspaceConnectionLatency95 = tryPercentileCont(latencies, 95) - - return stat, nil -} - -func (q *FakeQuerier) GetDeploymentWorkspaceAgentUsageStats(_ context.Context, createdAt time.Time) (database.GetDeploymentWorkspaceAgentUsageStatsRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - stat := database.GetDeploymentWorkspaceAgentUsageStatsRow{} - sessions := make(map[uuid.UUID]database.WorkspaceAgentStat) - agentStatsCreatedAfter := make([]database.WorkspaceAgentStat, 0) - for _, agentStat := range q.workspaceAgentStats { - // WHERE workspace_agent_stats.created_at > $1 - if agentStat.CreatedAt.After(createdAt) { - agentStatsCreatedAfter = append(agentStatsCreatedAfter, agentStat) - } - // WHERE - // created_at > $1 - // AND created_at < date_trunc('minute', now()) -- Exclude current partial minute - // AND usage = true - if agentStat.Usage && - (agentStat.CreatedAt.After(createdAt) || agentStat.CreatedAt.Equal(createdAt)) && - agentStat.CreatedAt.Before(time.Now().Truncate(time.Minute)) { - val, ok := sessions[agentStat.AgentID] - if !ok { - sessions[agentStat.AgentID] = agentStat - } else if agentStat.CreatedAt.After(val.CreatedAt) { - sessions[agentStat.AgentID] = agentStat - } else if agentStat.CreatedAt.Truncate(time.Minute).Equal(val.CreatedAt.Truncate(time.Minute)) { - val.SessionCountVSCode += agentStat.SessionCountVSCode - val.SessionCountJetBrains += agentStat.SessionCountJetBrains - val.SessionCountReconnectingPTY += agentStat.SessionCountReconnectingPTY - val.SessionCountSSH += agentStat.SessionCountSSH - sessions[agentStat.AgentID] = val - } - } - } - - latencies := make([]float64, 0) - for _, agentStat := range agentStatsCreatedAfter { - if agentStat.ConnectionMedianLatencyMS <= 0 { - continue - } - stat.WorkspaceRxBytes += agentStat.RxBytes - stat.WorkspaceTxBytes += agentStat.TxBytes - latencies = append(latencies, agentStat.ConnectionMedianLatencyMS) - } - stat.WorkspaceConnectionLatency50 = tryPercentileCont(latencies, 50) - stat.WorkspaceConnectionLatency95 = tryPercentileCont(latencies, 95) - - for _, agentStat := range sessions { - stat.SessionCountVSCode += agentStat.SessionCountVSCode - stat.SessionCountJetBrains += agentStat.SessionCountJetBrains - stat.SessionCountReconnectingPTY += agentStat.SessionCountReconnectingPTY - stat.SessionCountSSH += agentStat.SessionCountSSH - } - - return stat, nil -} - -func (q *FakeQuerier) GetDeploymentWorkspaceStats(ctx context.Context) (database.GetDeploymentWorkspaceStatsRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - stat := database.GetDeploymentWorkspaceStatsRow{} - for _, workspace := range q.workspaces { - build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, workspace.ID) - if err != nil { - return stat, err - } - job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID) - if err != nil { - return stat, err - } - if !job.StartedAt.Valid { - stat.PendingWorkspaces++ - continue - } - if job.StartedAt.Valid && - !job.CanceledAt.Valid && - time.Since(job.UpdatedAt) <= 30*time.Second && - !job.CompletedAt.Valid { - stat.BuildingWorkspaces++ - continue - } - if job.CompletedAt.Valid && - !job.CanceledAt.Valid && - !job.Error.Valid { - if build.Transition == database.WorkspaceTransitionStart { - stat.RunningWorkspaces++ - } - if build.Transition == database.WorkspaceTransitionStop { - stat.StoppedWorkspaces++ - } - continue - } - if job.CanceledAt.Valid || job.Error.Valid { - stat.FailedWorkspaces++ - continue - } - } - return stat, nil -} - -func (q *FakeQuerier) GetEligibleProvisionerDaemonsByProvisionerJobIDs(_ context.Context, provisionerJobIds []uuid.UUID) ([]database.GetEligibleProvisionerDaemonsByProvisionerJobIDsRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - results := make([]database.GetEligibleProvisionerDaemonsByProvisionerJobIDsRow, 0) - seen := make(map[string]struct{}) // Track unique combinations - - for _, jobID := range provisionerJobIds { - var job database.ProvisionerJob - found := false - for _, j := range q.provisionerJobs { - if j.ID == jobID { - job = j - found = true - break - } - } - if !found { - continue - } - - for _, daemon := range q.provisionerDaemons { - if daemon.OrganizationID != job.OrganizationID { - continue - } - - if !tagsSubset(job.Tags, daemon.Tags) { - continue - } - - provisionerMatches := false - for _, p := range daemon.Provisioners { - if p == job.Provisioner { - provisionerMatches = true - break - } - } - if !provisionerMatches { - continue - } - - key := jobID.String() + "-" + daemon.ID.String() - if _, exists := seen[key]; exists { - continue - } - seen[key] = struct{}{} - - results = append(results, database.GetEligibleProvisionerDaemonsByProvisionerJobIDsRow{ - JobID: jobID, - ProvisionerDaemon: daemon, - }) - } - } - - return results, nil -} - -func (q *FakeQuerier) GetExternalAuthLink(_ context.Context, arg database.GetExternalAuthLinkParams) (database.ExternalAuthLink, error) { - if err := validateDatabaseType(arg); err != nil { - return database.ExternalAuthLink{}, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - for _, gitAuthLink := range q.externalAuthLinks { - if arg.UserID != gitAuthLink.UserID { - continue - } - if arg.ProviderID != gitAuthLink.ProviderID { - continue - } - return gitAuthLink, nil - } - return database.ExternalAuthLink{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetExternalAuthLinksByUserID(_ context.Context, userID uuid.UUID) ([]database.ExternalAuthLink, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - gals := make([]database.ExternalAuthLink, 0) - for _, gal := range q.externalAuthLinks { - if gal.UserID == userID { - gals = append(gals, gal) - } - } - return gals, nil -} - -func (q *FakeQuerier) GetFailedWorkspaceBuildsByTemplateID(ctx context.Context, arg database.GetFailedWorkspaceBuildsByTemplateIDParams) ([]database.GetFailedWorkspaceBuildsByTemplateIDRow, error) { - err := validateDatabaseType(arg) - if err != nil { - return nil, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - workspaceBuildStats := []database.GetFailedWorkspaceBuildsByTemplateIDRow{} - for _, wb := range q.workspaceBuilds { - job, err := q.getProvisionerJobByIDNoLock(ctx, wb.JobID) - if err != nil { - return nil, xerrors.Errorf("get provisioner job by ID: %w", err) - } - - if job.JobStatus != database.ProvisionerJobStatusFailed { - continue - } - - if !job.CompletedAt.Valid { - continue - } - - if wb.CreatedAt.Before(arg.Since) { - continue - } - - w, err := q.getWorkspaceByIDNoLock(ctx, wb.WorkspaceID) - if err != nil { - return nil, xerrors.Errorf("get workspace by ID: %w", err) - } - - t, err := q.getTemplateByIDNoLock(ctx, w.TemplateID) - if err != nil { - return nil, xerrors.Errorf("get template by ID: %w", err) - } - - if t.ID != arg.TemplateID { - continue - } - - workspaceOwner, err := q.getUserByIDNoLock(w.OwnerID) - if err != nil { - return nil, xerrors.Errorf("get user by ID: %w", err) - } - - templateVersion, err := q.getTemplateVersionByIDNoLock(ctx, wb.TemplateVersionID) - if err != nil { - return nil, xerrors.Errorf("get template version by ID: %w", err) - } - - workspaceBuildStats = append(workspaceBuildStats, database.GetFailedWorkspaceBuildsByTemplateIDRow{ - WorkspaceID: w.ID, - WorkspaceName: w.Name, - WorkspaceOwnerUsername: workspaceOwner.Username, - TemplateVersionName: templateVersion.Name, - WorkspaceBuildNumber: wb.BuildNumber, - }) - } - - sort.Slice(workspaceBuildStats, func(i, j int) bool { - if workspaceBuildStats[i].TemplateVersionName != workspaceBuildStats[j].TemplateVersionName { - return workspaceBuildStats[i].TemplateVersionName < workspaceBuildStats[j].TemplateVersionName - } - return workspaceBuildStats[i].WorkspaceBuildNumber > workspaceBuildStats[j].WorkspaceBuildNumber - }) - return workspaceBuildStats, nil -} - -func (q *FakeQuerier) GetFileByHashAndCreator(_ context.Context, arg database.GetFileByHashAndCreatorParams) (database.File, error) { - if err := validateDatabaseType(arg); err != nil { - return database.File{}, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, file := range q.files { - if file.Hash == arg.Hash && file.CreatedBy == arg.CreatedBy { - return file, nil - } - } - return database.File{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetFileByID(_ context.Context, id uuid.UUID) (database.File, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, file := range q.files { - if file.ID == id { - return file, nil - } - } - return database.File{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetFileIDByTemplateVersionID(ctx context.Context, templateVersionID uuid.UUID) (uuid.UUID, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, v := range q.templateVersions { - if v.ID == templateVersionID { - jobID := v.JobID - for _, j := range q.provisionerJobs { - if j.ID == jobID { - if j.StorageMethod == database.ProvisionerStorageMethodFile { - return j.FileID, nil - } - // We found the right job id but it wasn't a proper match. - break - } - } - // We found the right template version but it wasn't a proper match. - break - } - } - - return uuid.Nil, sql.ErrNoRows -} - -func (q *FakeQuerier) GetFileTemplates(_ context.Context, id uuid.UUID) ([]database.GetFileTemplatesRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - rows := make([]database.GetFileTemplatesRow, 0) - var file database.File - for _, f := range q.files { - if f.ID == id { - file = f - break - } - } - if file.Hash == "" { - return rows, nil - } - - for _, job := range q.provisionerJobs { - if job.FileID == id { - for _, version := range q.templateVersions { - if version.JobID == job.ID { - for _, template := range q.templates { - if template.ID == version.TemplateID.UUID { - rows = append(rows, database.GetFileTemplatesRow{ - FileID: file.ID, - FileCreatedBy: file.CreatedBy, - TemplateID: template.ID, - TemplateOrganizationID: template.OrganizationID, - TemplateCreatedBy: template.CreatedBy, - UserACL: template.UserACL, - GroupACL: template.GroupACL, - }) - } - } - } - } - } - } - - return rows, nil -} - -func (q *FakeQuerier) GetFilteredInboxNotificationsByUserID(_ context.Context, arg database.GetFilteredInboxNotificationsByUserIDParams) ([]database.InboxNotification, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - notifications := make([]database.InboxNotification, 0) - // TODO : after using go version >= 1.23 , we can change this one to https://pkg.go.dev/slices#Backward - for idx := len(q.inboxNotifications) - 1; idx >= 0; idx-- { - notification := q.inboxNotifications[idx] - - if notification.UserID == arg.UserID { - if !arg.CreatedAtOpt.IsZero() && !notification.CreatedAt.Before(arg.CreatedAtOpt) { - continue - } - - templateFound := false - for _, template := range arg.Templates { - if notification.TemplateID == template { - templateFound = true - } - } - - if len(arg.Templates) > 0 && !templateFound { - continue - } - - targetsFound := true - for _, target := range arg.Targets { - targetFound := false - for _, insertedTarget := range notification.Targets { - if insertedTarget == target { - targetFound = true - break - } - } - - if !targetFound { - targetsFound = false - break - } - } - - if !targetsFound { - continue - } - - if (arg.LimitOpt == 0 && len(notifications) == 25) || - (arg.LimitOpt != 0 && len(notifications) == int(arg.LimitOpt)) { - break - } - - notifications = append(notifications, notification) - } - } - - return notifications, nil -} - -func (q *FakeQuerier) GetGitSSHKey(_ context.Context, userID uuid.UUID) (database.GitSSHKey, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, key := range q.gitSSHKey { - if key.UserID == userID { - return key, nil - } - } - return database.GitSSHKey{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetGroupByID(ctx context.Context, id uuid.UUID) (database.Group, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - return q.getGroupByIDNoLock(ctx, id) -} - -func (q *FakeQuerier) GetGroupByOrgAndName(_ context.Context, arg database.GetGroupByOrgAndNameParams) (database.Group, error) { - if err := validateDatabaseType(arg); err != nil { - return database.Group{}, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, group := range q.groups { - if group.OrganizationID == arg.OrganizationID && - group.Name == arg.Name { - return group, nil - } - } - - return database.Group{}, sql.ErrNoRows -} - -//nolint:revive // It's not a control flag, its a filter -func (q *FakeQuerier) GetGroupMembers(ctx context.Context, includeSystem bool) ([]database.GroupMember, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - members := make([]database.GroupMemberTable, 0, len(q.groupMembers)) - members = append(members, q.groupMembers...) - for _, org := range q.organizations { - for _, user := range q.users { - if !includeSystem && user.IsSystem { - continue - } - members = append(members, database.GroupMemberTable{ - UserID: user.ID, - GroupID: org.ID, - }) - } - } - - var groupMembers []database.GroupMember - for _, member := range members { - groupMember, err := q.getGroupMemberNoLock(ctx, member.UserID, member.GroupID) - if errors.Is(err, errUserDeleted) { - continue - } - if err != nil { - return nil, err - } - groupMembers = append(groupMembers, groupMember) - } - - return groupMembers, nil -} - -func (q *FakeQuerier) GetGroupMembersByGroupID(ctx context.Context, arg database.GetGroupMembersByGroupIDParams) ([]database.GroupMember, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - if q.isEveryoneGroup(arg.GroupID) { - return q.getEveryoneGroupMembersNoLock(ctx, arg.GroupID), nil - } - - var groupMembers []database.GroupMember - for _, member := range q.groupMembers { - if member.GroupID == arg.GroupID { - groupMember, err := q.getGroupMemberNoLock(ctx, member.UserID, member.GroupID) - if errors.Is(err, errUserDeleted) { - continue - } - if err != nil { - return nil, err - } - groupMembers = append(groupMembers, groupMember) - } - } - - return groupMembers, nil -} - -func (q *FakeQuerier) GetGroupMembersCountByGroupID(ctx context.Context, arg database.GetGroupMembersCountByGroupIDParams) (int64, error) { - users, err := q.GetGroupMembersByGroupID(ctx, database.GetGroupMembersByGroupIDParams(arg)) - if err != nil { - return 0, err - } - return int64(len(users)), nil -} - -func (q *FakeQuerier) GetGroups(_ context.Context, arg database.GetGroupsParams) ([]database.GetGroupsRow, error) { - err := validateDatabaseType(arg) - if err != nil { - return nil, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - userGroupIDs := make(map[uuid.UUID]struct{}) - if arg.HasMemberID != uuid.Nil { - for _, member := range q.groupMembers { - if member.UserID == arg.HasMemberID { - userGroupIDs[member.GroupID] = struct{}{} - } - } - - // Handle the everyone group - for _, orgMember := range q.organizationMembers { - if orgMember.UserID == arg.HasMemberID { - userGroupIDs[orgMember.OrganizationID] = struct{}{} - } - } - } - - orgDetailsCache := make(map[uuid.UUID]struct{ name, displayName string }) - filtered := make([]database.GetGroupsRow, 0) - for _, group := range q.groups { - if len(arg.GroupIds) > 0 { - if !slices.Contains(arg.GroupIds, group.ID) { - continue - } - } - - if arg.OrganizationID != uuid.Nil && group.OrganizationID != arg.OrganizationID { - continue - } - - _, ok := userGroupIDs[group.ID] - if arg.HasMemberID != uuid.Nil && !ok { - continue - } - - if len(arg.GroupNames) > 0 && !slices.Contains(arg.GroupNames, group.Name) { - continue - } - - orgDetails, ok := orgDetailsCache[group.ID] - if !ok { - for _, org := range q.organizations { - if group.OrganizationID == org.ID { - orgDetails = struct{ name, displayName string }{ - name: org.Name, displayName: org.DisplayName, - } - break - } - } - orgDetailsCache[group.ID] = orgDetails - } - - filtered = append(filtered, database.GetGroupsRow{ - Group: group, - OrganizationName: orgDetails.name, - OrganizationDisplayName: orgDetails.displayName, - }) - } - - return filtered, nil -} - -func (q *FakeQuerier) GetHealthSettings(_ context.Context) (string, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - if q.healthSettings == nil { - return "{}", nil - } - - return string(q.healthSettings), nil -} - -func (q *FakeQuerier) GetInboxNotificationByID(_ context.Context, id uuid.UUID) (database.InboxNotification, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, notification := range q.inboxNotifications { - if notification.ID == id { - return notification, nil - } - } - - return database.InboxNotification{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetInboxNotificationsByUserID(_ context.Context, params database.GetInboxNotificationsByUserIDParams) ([]database.InboxNotification, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - notifications := make([]database.InboxNotification, 0) - for _, notification := range q.inboxNotifications { - if notification.UserID == params.UserID { - notifications = append(notifications, notification) - } - } - - return notifications, nil -} - -func (q *FakeQuerier) GetLastUpdateCheck(_ context.Context) (string, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - if q.lastUpdateCheck == nil { - return "", sql.ErrNoRows - } - return string(q.lastUpdateCheck), nil -} - -func (q *FakeQuerier) GetLatestCryptoKeyByFeature(_ context.Context, feature database.CryptoKeyFeature) (database.CryptoKey, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - var latestKey database.CryptoKey - for _, key := range q.cryptoKeys { - if key.Feature == feature && latestKey.Sequence < key.Sequence { - latestKey = key - } - } - if latestKey.StartsAt.IsZero() { - return database.CryptoKey{}, sql.ErrNoRows - } - return latestKey, nil -} - -func (q *FakeQuerier) GetLatestWorkspaceAppStatusesByWorkspaceIDs(_ context.Context, ids []uuid.UUID) ([]database.WorkspaceAppStatus, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - // Map to track latest status per workspace ID - latestByWorkspace := make(map[uuid.UUID]database.WorkspaceAppStatus) - - // Find latest status for each workspace ID - for _, appStatus := range q.workspaceAppStatuses { - if !slices.Contains(ids, appStatus.WorkspaceID) { - continue - } - - current, exists := latestByWorkspace[appStatus.WorkspaceID] - if !exists || appStatus.CreatedAt.After(current.CreatedAt) { - latestByWorkspace[appStatus.WorkspaceID] = appStatus - } - } - - // Convert map to slice - appStatuses := make([]database.WorkspaceAppStatus, 0, len(latestByWorkspace)) - for _, status := range latestByWorkspace { - appStatuses = append(appStatuses, status) - } - - return appStatuses, nil -} - -func (q *FakeQuerier) GetLatestWorkspaceBuildByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) (database.WorkspaceBuild, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - return q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, workspaceID) -} - -func (q *FakeQuerier) GetLatestWorkspaceBuilds(_ context.Context) ([]database.WorkspaceBuild, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - builds := make(map[uuid.UUID]database.WorkspaceBuild) - buildNumbers := make(map[uuid.UUID]int32) - for _, workspaceBuild := range q.workspaceBuilds { - id := workspaceBuild.WorkspaceID - if workspaceBuild.BuildNumber > buildNumbers[id] { - builds[id] = q.workspaceBuildWithUserNoLock(workspaceBuild) - buildNumbers[id] = workspaceBuild.BuildNumber - } - } - var returnBuilds []database.WorkspaceBuild - for i, n := range buildNumbers { - if n > 0 { - b := builds[i] - returnBuilds = append(returnBuilds, b) - } - } - if len(returnBuilds) == 0 { - return nil, sql.ErrNoRows - } - return returnBuilds, nil -} - -func (q *FakeQuerier) GetLatestWorkspaceBuildsByWorkspaceIDs(_ context.Context, ids []uuid.UUID) ([]database.WorkspaceBuild, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - builds := make(map[uuid.UUID]database.WorkspaceBuild) - buildNumbers := make(map[uuid.UUID]int32) - for _, workspaceBuild := range q.workspaceBuilds { - for _, id := range ids { - if id == workspaceBuild.WorkspaceID && workspaceBuild.BuildNumber > buildNumbers[id] { - builds[id] = q.workspaceBuildWithUserNoLock(workspaceBuild) - buildNumbers[id] = workspaceBuild.BuildNumber - } - } - } - var returnBuilds []database.WorkspaceBuild - for i, n := range buildNumbers { - if n > 0 { - b := builds[i] - returnBuilds = append(returnBuilds, b) - } - } - if len(returnBuilds) == 0 { - return nil, sql.ErrNoRows - } - return returnBuilds, nil -} - -func (q *FakeQuerier) GetLicenseByID(_ context.Context, id int32) (database.License, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, license := range q.licenses { - if license.ID == id { - return license, nil - } - } - return database.License{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetLicenses(_ context.Context) ([]database.License, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - results := append([]database.License{}, q.licenses...) - sort.Slice(results, func(i, j int) bool { return results[i].ID < results[j].ID }) - return results, nil -} - -func (q *FakeQuerier) GetLogoURL(_ context.Context) (string, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - if q.logoURL == "" { - return "", sql.ErrNoRows - } - - return q.logoURL, nil -} - -func (q *FakeQuerier) GetNotificationMessagesByStatus(_ context.Context, arg database.GetNotificationMessagesByStatusParams) ([]database.NotificationMessage, error) { - err := validateDatabaseType(arg) - if err != nil { - return nil, err - } - - var out []database.NotificationMessage - for _, m := range q.notificationMessages { - if len(out) > int(arg.Limit) { - return out, nil - } - - if m.Status == arg.Status { - out = append(out, m) - } - } - - return out, nil -} - -func (q *FakeQuerier) GetNotificationReportGeneratorLogByTemplate(_ context.Context, templateID uuid.UUID) (database.NotificationReportGeneratorLog, error) { - err := validateDatabaseType(templateID) - if err != nil { - return database.NotificationReportGeneratorLog{}, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, record := range q.notificationReportGeneratorLogs { - if record.NotificationTemplateID == templateID { - return record, nil - } - } - return database.NotificationReportGeneratorLog{}, sql.ErrNoRows -} - -func (*FakeQuerier) GetNotificationTemplateByID(_ context.Context, _ uuid.UUID) (database.NotificationTemplate, error) { - // Not implementing this function because it relies on state in the database which is created with migrations. - // We could consider using code-generation to align the database state and dbmem, but it's not worth it right now. - return database.NotificationTemplate{}, ErrUnimplemented -} - -func (*FakeQuerier) GetNotificationTemplatesByKind(_ context.Context, _ database.NotificationTemplateKind) ([]database.NotificationTemplate, error) { - // Not implementing this function because it relies on state in the database which is created with migrations. - // We could consider using code-generation to align the database state and dbmem, but it's not worth it right now. - return nil, ErrUnimplemented -} - -func (q *FakeQuerier) GetNotificationsSettings(_ context.Context) (string, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - if q.notificationsSettings == nil { - return "{}", nil - } - - return string(q.notificationsSettings), nil -} - -func (q *FakeQuerier) GetOAuth2GithubDefaultEligible(_ context.Context) (bool, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - if q.oauth2GithubDefaultEligible == nil { - return false, sql.ErrNoRows - } - return *q.oauth2GithubDefaultEligible, nil -} - -func (q *FakeQuerier) GetOAuth2ProviderAppByClientID(ctx context.Context, id uuid.UUID) (database.OAuth2ProviderApp, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, app := range q.oauth2ProviderApps { - if app.ID == id { - return app, nil - } - } - return database.OAuth2ProviderApp{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetOAuth2ProviderAppByID(_ context.Context, id uuid.UUID) (database.OAuth2ProviderApp, error) { - q.mutex.Lock() - defer q.mutex.Unlock() - - for _, app := range q.oauth2ProviderApps { - if app.ID == id { - return app, nil - } - } - return database.OAuth2ProviderApp{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetOAuth2ProviderAppByRegistrationToken(ctx context.Context, registrationAccessToken sql.NullString) (database.OAuth2ProviderApp, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, app := range q.data.oauth2ProviderApps { - if app.RegistrationAccessToken.Valid && registrationAccessToken.Valid && - app.RegistrationAccessToken.String == registrationAccessToken.String { - return app, nil - } - } - return database.OAuth2ProviderApp{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetOAuth2ProviderAppCodeByID(_ context.Context, id uuid.UUID) (database.OAuth2ProviderAppCode, error) { - q.mutex.Lock() - defer q.mutex.Unlock() - - for _, code := range q.oauth2ProviderAppCodes { - if code.ID == id { - return code, nil - } - } - return database.OAuth2ProviderAppCode{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetOAuth2ProviderAppCodeByPrefix(_ context.Context, secretPrefix []byte) (database.OAuth2ProviderAppCode, error) { - q.mutex.Lock() - defer q.mutex.Unlock() - - for _, code := range q.oauth2ProviderAppCodes { - if bytes.Equal(code.SecretPrefix, secretPrefix) { - return code, nil - } - } - return database.OAuth2ProviderAppCode{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetOAuth2ProviderAppSecretByID(_ context.Context, id uuid.UUID) (database.OAuth2ProviderAppSecret, error) { - q.mutex.Lock() - defer q.mutex.Unlock() - - for _, secret := range q.oauth2ProviderAppSecrets { - if secret.ID == id { - return secret, nil - } - } - return database.OAuth2ProviderAppSecret{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetOAuth2ProviderAppSecretByPrefix(_ context.Context, secretPrefix []byte) (database.OAuth2ProviderAppSecret, error) { - q.mutex.Lock() - defer q.mutex.Unlock() - - for _, secret := range q.oauth2ProviderAppSecrets { - if bytes.Equal(secret.SecretPrefix, secretPrefix) { - return secret, nil - } - } - return database.OAuth2ProviderAppSecret{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetOAuth2ProviderAppSecretsByAppID(_ context.Context, appID uuid.UUID) ([]database.OAuth2ProviderAppSecret, error) { - q.mutex.Lock() - defer q.mutex.Unlock() - - for _, app := range q.oauth2ProviderApps { - if app.ID == appID { - secrets := []database.OAuth2ProviderAppSecret{} - for _, secret := range q.oauth2ProviderAppSecrets { - if secret.AppID == appID { - secrets = append(secrets, secret) - } - } - - slices.SortFunc(secrets, func(a, b database.OAuth2ProviderAppSecret) int { - if a.CreatedAt.Before(b.CreatedAt) { - return -1 - } else if a.CreatedAt.Equal(b.CreatedAt) { - return 0 - } - return 1 - }) - return secrets, nil - } - } - - return []database.OAuth2ProviderAppSecret{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetOAuth2ProviderAppTokenByAPIKeyID(_ context.Context, apiKeyID string) (database.OAuth2ProviderAppToken, error) { - q.mutex.Lock() - defer q.mutex.Unlock() - - for _, token := range q.oauth2ProviderAppTokens { - if token.APIKeyID == apiKeyID { - return token, nil - } - } - - return database.OAuth2ProviderAppToken{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetOAuth2ProviderAppTokenByPrefix(_ context.Context, hashPrefix []byte) (database.OAuth2ProviderAppToken, error) { - q.mutex.Lock() - defer q.mutex.Unlock() - - for _, token := range q.oauth2ProviderAppTokens { - if bytes.Equal(token.HashPrefix, hashPrefix) { - return token, nil - } - } - return database.OAuth2ProviderAppToken{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetOAuth2ProviderApps(_ context.Context) ([]database.OAuth2ProviderApp, error) { - q.mutex.Lock() - defer q.mutex.Unlock() - - slices.SortFunc(q.oauth2ProviderApps, func(a, b database.OAuth2ProviderApp) int { - return slice.Ascending(a.Name, b.Name) - }) - return q.oauth2ProviderApps, nil -} - -func (q *FakeQuerier) GetOAuth2ProviderAppsByUserID(_ context.Context, userID uuid.UUID) ([]database.GetOAuth2ProviderAppsByUserIDRow, error) { - q.mutex.Lock() - defer q.mutex.Unlock() - - rows := []database.GetOAuth2ProviderAppsByUserIDRow{} - for _, app := range q.oauth2ProviderApps { - tokens := []database.OAuth2ProviderAppToken{} - for _, secret := range q.oauth2ProviderAppSecrets { - if secret.AppID == app.ID { - for _, token := range q.oauth2ProviderAppTokens { - if token.AppSecretID == secret.ID { - keyIdx := slices.IndexFunc(q.apiKeys, func(key database.APIKey) bool { - return key.ID == token.APIKeyID - }) - if keyIdx != -1 && q.apiKeys[keyIdx].UserID == userID { - tokens = append(tokens, token) - } - } - } - } - } - if len(tokens) > 0 { - rows = append(rows, database.GetOAuth2ProviderAppsByUserIDRow{ - OAuth2ProviderApp: app, - TokenCount: int64(len(tokens)), - }) - } - } - return rows, nil -} - -func (q *FakeQuerier) GetOAuthSigningKey(_ context.Context) (string, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - return q.oauthSigningKey, nil -} - -func (q *FakeQuerier) GetOrganizationByID(_ context.Context, id uuid.UUID) (database.Organization, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - return q.getOrganizationByIDNoLock(id) -} - -func (q *FakeQuerier) GetOrganizationByName(_ context.Context, params database.GetOrganizationByNameParams) (database.Organization, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, organization := range q.organizations { - if organization.Name == params.Name && organization.Deleted == params.Deleted { - return organization, nil - } - } - return database.Organization{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetOrganizationIDsByMemberIDs(_ context.Context, ids []uuid.UUID) ([]database.GetOrganizationIDsByMemberIDsRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - getOrganizationIDsByMemberIDRows := make([]database.GetOrganizationIDsByMemberIDsRow, 0, len(ids)) - for _, userID := range ids { - userOrganizationIDs := make([]uuid.UUID, 0) - for _, membership := range q.organizationMembers { - if membership.UserID == userID { - userOrganizationIDs = append(userOrganizationIDs, membership.OrganizationID) - } - } - getOrganizationIDsByMemberIDRows = append(getOrganizationIDsByMemberIDRows, database.GetOrganizationIDsByMemberIDsRow{ - UserID: userID, - OrganizationIDs: userOrganizationIDs, - }) - } - return getOrganizationIDsByMemberIDRows, nil -} - -func (q *FakeQuerier) GetOrganizationResourceCountByID(_ context.Context, organizationID uuid.UUID) (database.GetOrganizationResourceCountByIDRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - workspacesCount := 0 - for _, workspace := range q.workspaces { - if workspace.OrganizationID == organizationID { - workspacesCount++ - } - } - - groupsCount := 0 - for _, group := range q.groups { - if group.OrganizationID == organizationID { - groupsCount++ - } - } - - templatesCount := 0 - for _, template := range q.templates { - if template.OrganizationID == organizationID { - templatesCount++ - } - } - - organizationMembersCount := 0 - for _, organizationMember := range q.organizationMembers { - if organizationMember.OrganizationID == organizationID { - organizationMembersCount++ - } - } - - provKeyCount := 0 - for _, provKey := range q.provisionerKeys { - if provKey.OrganizationID == organizationID { - provKeyCount++ - } - } - - return database.GetOrganizationResourceCountByIDRow{ - WorkspaceCount: int64(workspacesCount), - GroupCount: int64(groupsCount), - TemplateCount: int64(templatesCount), - MemberCount: int64(organizationMembersCount), - ProvisionerKeyCount: int64(provKeyCount), - }, nil -} - -func (q *FakeQuerier) GetOrganizations(_ context.Context, args database.GetOrganizationsParams) ([]database.Organization, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - tmp := make([]database.Organization, 0) - for _, org := range q.organizations { - if len(args.IDs) > 0 { - if !slices.Contains(args.IDs, org.ID) { - continue - } - } - if args.Name != "" && !strings.EqualFold(org.Name, args.Name) { - continue - } - if args.Deleted != org.Deleted { - continue - } - tmp = append(tmp, org) - } - - return tmp, nil -} - -func (q *FakeQuerier) GetOrganizationsByUserID(_ context.Context, arg database.GetOrganizationsByUserIDParams) ([]database.Organization, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - organizations := make([]database.Organization, 0) - for _, organizationMember := range q.organizationMembers { - if organizationMember.UserID != arg.UserID { - continue - } - for _, organization := range q.organizations { - if organization.ID != organizationMember.OrganizationID { - continue - } - - if arg.Deleted.Valid && organization.Deleted != arg.Deleted.Bool { - continue - } - organizations = append(organizations, organization) - } - } - - return organizations, nil -} - -func (q *FakeQuerier) GetParameterSchemasByJobID(_ context.Context, jobID uuid.UUID) ([]database.ParameterSchema, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - parameters := make([]database.ParameterSchema, 0) - for _, parameterSchema := range q.parameterSchemas { - if parameterSchema.JobID != jobID { - continue - } - parameters = append(parameters, parameterSchema) - } - if len(parameters) == 0 { - return nil, sql.ErrNoRows - } - sort.Slice(parameters, func(i, j int) bool { - return parameters[i].Index < parameters[j].Index - }) - return parameters, nil -} - -func (*FakeQuerier) GetPrebuildMetrics(_ context.Context) ([]database.GetPrebuildMetricsRow, error) { - return make([]database.GetPrebuildMetricsRow, 0), nil -} - -func (q *FakeQuerier) GetPrebuildsSettings(_ context.Context) (string, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - return string(slices.Clone(q.prebuildsSettings)), nil -} - -func (q *FakeQuerier) GetPresetByID(_ context.Context, presetID uuid.UUID) (database.GetPresetByIDRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - empty := database.GetPresetByIDRow{} - - // Create an index for faster lookup - versionMap := make(map[uuid.UUID]database.TemplateVersionTable) - for _, tv := range q.templateVersions { - versionMap[tv.ID] = tv - } - - for _, preset := range q.presets { - if preset.ID == presetID { - tv, ok := versionMap[preset.TemplateVersionID] - if !ok { - return empty, xerrors.Errorf("template version %v does not exist", preset.TemplateVersionID) - } - return database.GetPresetByIDRow{ - ID: preset.ID, - TemplateVersionID: preset.TemplateVersionID, - Name: preset.Name, - CreatedAt: preset.CreatedAt, - DesiredInstances: preset.DesiredInstances, - InvalidateAfterSecs: preset.InvalidateAfterSecs, - PrebuildStatus: preset.PrebuildStatus, - TemplateID: tv.TemplateID, - OrganizationID: tv.OrganizationID, - }, nil - } - } - - return empty, xerrors.Errorf("preset %v does not exist", presetID) -} - -func (q *FakeQuerier) GetPresetByWorkspaceBuildID(_ context.Context, workspaceBuildID uuid.UUID) (database.TemplateVersionPreset, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, workspaceBuild := range q.workspaceBuilds { - if workspaceBuild.ID != workspaceBuildID { - continue - } - for _, preset := range q.presets { - if preset.TemplateVersionID == workspaceBuild.TemplateVersionID { - return preset, nil - } - } - } - return database.TemplateVersionPreset{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetPresetParametersByPresetID(_ context.Context, presetID uuid.UUID) ([]database.TemplateVersionPresetParameter, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - parameters := make([]database.TemplateVersionPresetParameter, 0) - for _, parameter := range q.presetParameters { - if parameter.TemplateVersionPresetID != presetID { - continue - } - parameters = append(parameters, parameter) - } - - return parameters, nil -} - -func (q *FakeQuerier) GetPresetParametersByTemplateVersionID(_ context.Context, templateVersionID uuid.UUID) ([]database.TemplateVersionPresetParameter, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - presets := make([]database.TemplateVersionPreset, 0) - parameters := make([]database.TemplateVersionPresetParameter, 0) - for _, preset := range q.presets { - if preset.TemplateVersionID != templateVersionID { - continue - } - presets = append(presets, preset) - } - for _, parameter := range q.presetParameters { - for _, preset := range presets { - if parameter.TemplateVersionPresetID != preset.ID { - continue - } - parameters = append(parameters, parameter) - } - } - - return parameters, nil -} - -func (q *FakeQuerier) GetPresetsAtFailureLimit(ctx context.Context, hardLimit int64) ([]database.GetPresetsAtFailureLimitRow, error) { - return nil, ErrUnimplemented -} - -func (*FakeQuerier) GetPresetsBackoff(_ context.Context, _ time.Time) ([]database.GetPresetsBackoffRow, error) { - return nil, ErrUnimplemented -} - -func (q *FakeQuerier) GetPresetsByTemplateVersionID(_ context.Context, templateVersionID uuid.UUID) ([]database.TemplateVersionPreset, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - presets := make([]database.TemplateVersionPreset, 0) - for _, preset := range q.presets { - if preset.TemplateVersionID == templateVersionID { - presets = append(presets, preset) - } - } - return presets, nil -} - -func (q *FakeQuerier) GetPreviousTemplateVersion(_ context.Context, arg database.GetPreviousTemplateVersionParams) (database.TemplateVersion, error) { - if err := validateDatabaseType(arg); err != nil { - return database.TemplateVersion{}, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - var currentTemplateVersion database.TemplateVersion - for _, templateVersion := range q.templateVersions { - if templateVersion.TemplateID != arg.TemplateID { - continue - } - if templateVersion.Name != arg.Name { - continue - } - if templateVersion.OrganizationID != arg.OrganizationID { - continue - } - currentTemplateVersion = q.templateVersionWithUserNoLock(templateVersion) - break - } - - previousTemplateVersions := make([]database.TemplateVersion, 0) - for _, templateVersion := range q.templateVersions { - if templateVersion.ID == currentTemplateVersion.ID { - continue - } - if templateVersion.OrganizationID != arg.OrganizationID { - continue - } - if templateVersion.TemplateID != currentTemplateVersion.TemplateID { - continue - } - - if templateVersion.CreatedAt.Before(currentTemplateVersion.CreatedAt) { - previousTemplateVersions = append(previousTemplateVersions, q.templateVersionWithUserNoLock(templateVersion)) - } - } - - if len(previousTemplateVersions) == 0 { - return database.TemplateVersion{}, sql.ErrNoRows - } - - sort.Slice(previousTemplateVersions, func(i, j int) bool { - return previousTemplateVersions[i].CreatedAt.After(previousTemplateVersions[j].CreatedAt) - }) - - return previousTemplateVersions[0], nil -} - -func (q *FakeQuerier) GetProvisionerDaemons(_ context.Context) ([]database.ProvisionerDaemon, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - if len(q.provisionerDaemons) == 0 { - // Returning err=nil here for consistency with real querier - return []database.ProvisionerDaemon{}, nil - } - // copy the data so that the caller can't manipulate any data inside dbmem - // after returning - out := make([]database.ProvisionerDaemon, len(q.provisionerDaemons)) - copy(out, q.provisionerDaemons) - for i := range out { - // maps are reference types, so we need to clone them - out[i].Tags = maps.Clone(out[i].Tags) - } - return out, nil -} - -func (q *FakeQuerier) GetProvisionerDaemonsByOrganization(_ context.Context, arg database.GetProvisionerDaemonsByOrganizationParams) ([]database.ProvisionerDaemon, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - daemons := make([]database.ProvisionerDaemon, 0) - for _, daemon := range q.provisionerDaemons { - if daemon.OrganizationID != arg.OrganizationID { - continue - } - // Special case for untagged provisioners: only match untagged jobs. - // Ref: coderd/database/queries/provisionerjobs.sql:24-30 - // CASE WHEN nested.tags :: jsonb = '{"scope": "organization", "owner": ""}' :: jsonb - // THEN nested.tags :: jsonb = @tags :: jsonb - if tagsEqual(arg.WantTags, tagsUntagged) && !tagsEqual(arg.WantTags, daemon.Tags) { - continue - } - // ELSE nested.tags :: jsonb <@ @tags :: jsonb - if !tagsSubset(arg.WantTags, daemon.Tags) { - continue - } - daemon.Tags = maps.Clone(daemon.Tags) - daemons = append(daemons, daemon) - } - - return daemons, nil -} - -func (q *FakeQuerier) GetProvisionerDaemonsWithStatusByOrganization(ctx context.Context, arg database.GetProvisionerDaemonsWithStatusByOrganizationParams) ([]database.GetProvisionerDaemonsWithStatusByOrganizationRow, error) { - err := validateDatabaseType(arg) - if err != nil { - return nil, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - var rows []database.GetProvisionerDaemonsWithStatusByOrganizationRow - for _, daemon := range q.provisionerDaemons { - if daemon.OrganizationID != arg.OrganizationID { - continue - } - if len(arg.IDs) > 0 && !slices.Contains(arg.IDs, daemon.ID) { - continue - } - - if len(arg.Tags) > 0 { - // Special case for untagged provisioners: only match untagged jobs. - // Ref: coderd/database/queries/provisionerjobs.sql:24-30 - // CASE WHEN nested.tags :: jsonb = '{"scope": "organization", "owner": ""}' :: jsonb - // THEN nested.tags :: jsonb = @tags :: jsonb - if tagsEqual(arg.Tags, tagsUntagged) && !tagsEqual(arg.Tags, daemon.Tags) { - continue - } - // ELSE nested.tags :: jsonb <@ @tags :: jsonb - if !tagsSubset(arg.Tags, daemon.Tags) { - continue - } - } - - var status database.ProvisionerDaemonStatus - var currentJob database.ProvisionerJob - if !daemon.LastSeenAt.Valid || daemon.LastSeenAt.Time.Before(time.Now().Add(-time.Duration(arg.StaleIntervalMS)*time.Millisecond)) { - status = database.ProvisionerDaemonStatusOffline - } else { - for _, job := range q.provisionerJobs { - if job.WorkerID.Valid && job.WorkerID.UUID == daemon.ID && !job.CompletedAt.Valid && !job.Error.Valid { - currentJob = job - break - } - } - - if currentJob.ID != uuid.Nil { - status = database.ProvisionerDaemonStatusBusy - } else { - status = database.ProvisionerDaemonStatusIdle - } - } - var currentTemplate database.Template - if currentJob.ID != uuid.Nil { - var input codersdk.ProvisionerJobInput - err := json.Unmarshal(currentJob.Input, &input) - if err != nil { - return nil, err - } - if input.WorkspaceBuildID != nil { - b, err := q.getWorkspaceBuildByIDNoLock(ctx, *input.WorkspaceBuildID) - if err != nil { - return nil, err - } - input.TemplateVersionID = &b.TemplateVersionID - } - if input.TemplateVersionID != nil { - v, err := q.getTemplateVersionByIDNoLock(ctx, *input.TemplateVersionID) - if err != nil { - return nil, err - } - currentTemplate, err = q.getTemplateByIDNoLock(ctx, v.TemplateID.UUID) - if err != nil { - return nil, err - } - } - } - - var previousJob database.ProvisionerJob - for _, job := range q.provisionerJobs { - if !job.WorkerID.Valid || job.WorkerID.UUID != daemon.ID { - continue - } - - if job.StartedAt.Valid || - job.CanceledAt.Valid || - job.CompletedAt.Valid || - job.Error.Valid { - if job.CompletedAt.Time.After(previousJob.CompletedAt.Time) { - previousJob = job - } - } - } - var previousTemplate database.Template - if previousJob.ID != uuid.Nil { - var input codersdk.ProvisionerJobInput - err := json.Unmarshal(previousJob.Input, &input) - if err != nil { - return nil, err - } - if input.WorkspaceBuildID != nil { - b, err := q.getWorkspaceBuildByIDNoLock(ctx, *input.WorkspaceBuildID) - if err != nil { - return nil, err - } - input.TemplateVersionID = &b.TemplateVersionID - } - if input.TemplateVersionID != nil { - v, err := q.getTemplateVersionByIDNoLock(ctx, *input.TemplateVersionID) - if err != nil { - return nil, err - } - previousTemplate, err = q.getTemplateByIDNoLock(ctx, v.TemplateID.UUID) - if err != nil { - return nil, err - } - } - } - - // Get the provisioner key name - var keyName string - for _, key := range q.provisionerKeys { - if key.ID == daemon.KeyID { - keyName = key.Name - break - } - } - - rows = append(rows, database.GetProvisionerDaemonsWithStatusByOrganizationRow{ - ProvisionerDaemon: daemon, - Status: status, - KeyName: keyName, - CurrentJobID: uuid.NullUUID{UUID: currentJob.ID, Valid: currentJob.ID != uuid.Nil}, - CurrentJobStatus: database.NullProvisionerJobStatus{ProvisionerJobStatus: currentJob.JobStatus, Valid: currentJob.ID != uuid.Nil}, - CurrentJobTemplateName: currentTemplate.Name, - CurrentJobTemplateDisplayName: currentTemplate.DisplayName, - CurrentJobTemplateIcon: currentTemplate.Icon, - PreviousJobID: uuid.NullUUID{UUID: previousJob.ID, Valid: previousJob.ID != uuid.Nil}, - PreviousJobStatus: database.NullProvisionerJobStatus{ProvisionerJobStatus: previousJob.JobStatus, Valid: previousJob.ID != uuid.Nil}, - PreviousJobTemplateName: previousTemplate.Name, - PreviousJobTemplateDisplayName: previousTemplate.DisplayName, - PreviousJobTemplateIcon: previousTemplate.Icon, - }) - } - - slices.SortFunc(rows, func(a, b database.GetProvisionerDaemonsWithStatusByOrganizationRow) int { - return b.ProvisionerDaemon.CreatedAt.Compare(a.ProvisionerDaemon.CreatedAt) - }) - - if arg.Limit.Valid && arg.Limit.Int32 > 0 && len(rows) > int(arg.Limit.Int32) { - rows = rows[:arg.Limit.Int32] - } - - return rows, nil -} - -func (q *FakeQuerier) GetProvisionerJobByID(ctx context.Context, id uuid.UUID) (database.ProvisionerJob, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - return q.getProvisionerJobByIDNoLock(ctx, id) -} - -func (q *FakeQuerier) GetProvisionerJobByIDForUpdate(ctx context.Context, id uuid.UUID) (database.ProvisionerJob, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - return q.getProvisionerJobByIDNoLock(ctx, id) -} - -func (q *FakeQuerier) GetProvisionerJobTimingsByJobID(_ context.Context, jobID uuid.UUID) ([]database.ProvisionerJobTiming, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - timings := make([]database.ProvisionerJobTiming, 0) - for _, timing := range q.provisionerJobTimings { - if timing.JobID == jobID { - timings = append(timings, timing) - } - } - if len(timings) == 0 { - return nil, sql.ErrNoRows - } - sort.Slice(timings, func(i, j int) bool { - return timings[i].StartedAt.Before(timings[j].StartedAt) - }) - - return timings, nil -} - -func (q *FakeQuerier) GetProvisionerJobsByIDs(_ context.Context, ids []uuid.UUID) ([]database.ProvisionerJob, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - jobs := make([]database.ProvisionerJob, 0) - for _, job := range q.provisionerJobs { - for _, id := range ids { - if id == job.ID { - // clone the Tags before appending, since maps are reference types and - // we don't want the caller to be able to mutate the map we have inside - // dbmem! - job.Tags = maps.Clone(job.Tags) - jobs = append(jobs, job) - break - } - } - } - if len(jobs) == 0 { - return nil, sql.ErrNoRows - } - - return jobs, nil -} - -func (q *FakeQuerier) GetProvisionerJobsByIDsWithQueuePosition(ctx context.Context, arg database.GetProvisionerJobsByIDsWithQueuePositionParams) ([]database.GetProvisionerJobsByIDsWithQueuePositionRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - if arg.IDs == nil { - arg.IDs = []uuid.UUID{} - } - return q.getProvisionerJobsByIDsWithQueuePositionLockedTagBasedQueue(ctx, arg.IDs) -} - -func (q *FakeQuerier) GetProvisionerJobsByOrganizationAndStatusWithQueuePositionAndProvisioner(ctx context.Context, arg database.GetProvisionerJobsByOrganizationAndStatusWithQueuePositionAndProvisionerParams) ([]database.GetProvisionerJobsByOrganizationAndStatusWithQueuePositionAndProvisionerRow, error) { - err := validateDatabaseType(arg) - if err != nil { - return nil, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - /* - -- name: GetProvisionerJobsByOrganizationAndStatusWithQueuePositionAndProvisioner :many - WITH pending_jobs AS ( - SELECT - id, created_at - FROM - provisioner_jobs - WHERE - started_at IS NULL - AND - canceled_at IS NULL - AND - completed_at IS NULL - AND - error IS NULL - ), - queue_position AS ( - SELECT - id, - ROW_NUMBER() OVER (ORDER BY created_at ASC) AS queue_position - FROM - pending_jobs - ), - queue_size AS ( - SELECT COUNT(*) AS count FROM pending_jobs - ) - SELECT - sqlc.embed(pj), - COALESCE(qp.queue_position, 0) AS queue_position, - COALESCE(qs.count, 0) AS queue_size, - array_agg(DISTINCT pd.id) FILTER (WHERE pd.id IS NOT NULL)::uuid[] AS available_workers - FROM - provisioner_jobs pj - LEFT JOIN - queue_position qp ON qp.id = pj.id - LEFT JOIN - queue_size qs ON TRUE - LEFT JOIN - provisioner_daemons pd ON ( - -- See AcquireProvisionerJob. - pj.started_at IS NULL - AND pj.organization_id = pd.organization_id - AND pj.provisioner = ANY(pd.provisioners) - AND provisioner_tagset_contains(pd.tags, pj.tags) - ) - WHERE - (sqlc.narg('organization_id')::uuid IS NULL OR pj.organization_id = @organization_id) - AND (COALESCE(array_length(@status::provisioner_job_status[], 1), 1) > 0 OR pj.job_status = ANY(@status::provisioner_job_status[])) - GROUP BY - pj.id, - qp.queue_position, - qs.count - ORDER BY - pj.created_at DESC - LIMIT - sqlc.narg('limit')::int; - */ - rowsWithQueuePosition, err := q.getProvisionerJobsByIDsWithQueuePositionLockedGlobalQueue(ctx, nil) - if err != nil { - return nil, err - } - - var rows []database.GetProvisionerJobsByOrganizationAndStatusWithQueuePositionAndProvisionerRow - for _, rowQP := range rowsWithQueuePosition { - job := rowQP.ProvisionerJob - - if job.OrganizationID != arg.OrganizationID { - continue - } - if len(arg.Status) > 0 && !slices.Contains(arg.Status, job.JobStatus) { - continue - } - if len(arg.IDs) > 0 && !slices.Contains(arg.IDs, job.ID) { - continue - } - if len(arg.Tags) > 0 && !tagsSubset(job.Tags, arg.Tags) { - continue - } - - row := database.GetProvisionerJobsByOrganizationAndStatusWithQueuePositionAndProvisionerRow{ - ProvisionerJob: rowQP.ProvisionerJob, - QueuePosition: rowQP.QueuePosition, - QueueSize: rowQP.QueueSize, - } - - // Start add metadata. - var input codersdk.ProvisionerJobInput - err := json.Unmarshal([]byte(job.Input), &input) - if err != nil { - return nil, err - } - templateVersionID := input.TemplateVersionID - if input.WorkspaceBuildID != nil { - workspaceBuild, err := q.getWorkspaceBuildByIDNoLock(ctx, *input.WorkspaceBuildID) - if err != nil { - return nil, err - } - workspace, err := q.getWorkspaceByIDNoLock(ctx, workspaceBuild.WorkspaceID) - if err != nil { - return nil, err - } - row.WorkspaceID = uuid.NullUUID{UUID: workspace.ID, Valid: true} - row.WorkspaceName = workspace.Name - if templateVersionID == nil { - templateVersionID = &workspaceBuild.TemplateVersionID - } - } - if templateVersionID != nil { - templateVersion, err := q.getTemplateVersionByIDNoLock(ctx, *templateVersionID) - if err != nil { - return nil, err - } - row.TemplateVersionName = templateVersion.Name - template, err := q.getTemplateByIDNoLock(ctx, templateVersion.TemplateID.UUID) - if err != nil { - return nil, err - } - row.TemplateID = uuid.NullUUID{UUID: template.ID, Valid: true} - row.TemplateName = template.Name - row.TemplateDisplayName = template.DisplayName - } - // End add metadata. - - if row.QueuePosition > 0 { - var availableWorkers []database.ProvisionerDaemon - for _, daemon := range q.provisionerDaemons { - if daemon.OrganizationID == job.OrganizationID && slices.Contains(daemon.Provisioners, job.Provisioner) { - if tagsEqual(job.Tags, tagsUntagged) { - if tagsEqual(job.Tags, daemon.Tags) { - availableWorkers = append(availableWorkers, daemon) - } - } else if tagsSubset(job.Tags, daemon.Tags) { - availableWorkers = append(availableWorkers, daemon) - } - } - } - slices.SortFunc(availableWorkers, func(a, b database.ProvisionerDaemon) int { - return a.CreatedAt.Compare(b.CreatedAt) - }) - for _, worker := range availableWorkers { - row.AvailableWorkers = append(row.AvailableWorkers, worker.ID) - } - } - - // Add daemon name to provisioner job - for _, daemon := range q.provisionerDaemons { - if job.WorkerID.Valid && job.WorkerID.UUID == daemon.ID { - row.WorkerName = daemon.Name - } - } - rows = append(rows, row) - } - - slices.SortFunc(rows, func(a, b database.GetProvisionerJobsByOrganizationAndStatusWithQueuePositionAndProvisionerRow) int { - return b.ProvisionerJob.CreatedAt.Compare(a.ProvisionerJob.CreatedAt) - }) - if arg.Limit.Valid && arg.Limit.Int32 > 0 && len(rows) > int(arg.Limit.Int32) { - rows = rows[:arg.Limit.Int32] - } - return rows, nil -} - -func (q *FakeQuerier) GetProvisionerJobsCreatedAfter(_ context.Context, after time.Time) ([]database.ProvisionerJob, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - jobs := make([]database.ProvisionerJob, 0) - for _, job := range q.provisionerJobs { - if job.CreatedAt.After(after) { - // clone the Tags before appending, since maps are reference types and - // we don't want the caller to be able to mutate the map we have inside - // dbmem! - job.Tags = maps.Clone(job.Tags) - jobs = append(jobs, job) - } - } - return jobs, nil -} - -func (q *FakeQuerier) GetProvisionerJobsToBeReaped(_ context.Context, arg database.GetProvisionerJobsToBeReapedParams) ([]database.ProvisionerJob, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - maxJobs := arg.MaxJobs - - hungJobs := []database.ProvisionerJob{} - for _, provisionerJob := range q.provisionerJobs { - if !provisionerJob.CompletedAt.Valid { - if (provisionerJob.StartedAt.Valid && provisionerJob.UpdatedAt.Before(arg.HungSince)) || - (!provisionerJob.StartedAt.Valid && provisionerJob.UpdatedAt.Before(arg.PendingSince)) { - // clone the Tags before appending, since maps are reference types and - // we don't want the caller to be able to mutate the map we have inside - // dbmem! - provisionerJob.Tags = maps.Clone(provisionerJob.Tags) - hungJobs = append(hungJobs, provisionerJob) - if len(hungJobs) >= int(maxJobs) { - break - } - } - } - } - insecurerand.Shuffle(len(hungJobs), func(i, j int) { - hungJobs[i], hungJobs[j] = hungJobs[j], hungJobs[i] - }) - return hungJobs, nil -} - -func (q *FakeQuerier) GetProvisionerKeyByHashedSecret(_ context.Context, hashedSecret []byte) (database.ProvisionerKey, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, key := range q.provisionerKeys { - if bytes.Equal(key.HashedSecret, hashedSecret) { - return key, nil - } - } - - return database.ProvisionerKey{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetProvisionerKeyByID(_ context.Context, id uuid.UUID) (database.ProvisionerKey, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, key := range q.provisionerKeys { - if key.ID == id { - return key, nil - } - } - - return database.ProvisionerKey{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetProvisionerKeyByName(_ context.Context, arg database.GetProvisionerKeyByNameParams) (database.ProvisionerKey, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, key := range q.provisionerKeys { - if strings.EqualFold(key.Name, arg.Name) && key.OrganizationID == arg.OrganizationID { - return key, nil - } - } - - return database.ProvisionerKey{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetProvisionerLogsAfterID(_ context.Context, arg database.GetProvisionerLogsAfterIDParams) ([]database.ProvisionerJobLog, error) { - if err := validateDatabaseType(arg); err != nil { - return nil, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - logs := make([]database.ProvisionerJobLog, 0) - for _, jobLog := range q.provisionerJobLogs { - if jobLog.JobID != arg.JobID { - continue - } - if jobLog.ID <= arg.CreatedAfter { - continue - } - logs = append(logs, jobLog) - } - return logs, nil -} - -func (q *FakeQuerier) GetQuotaAllowanceForUser(_ context.Context, params database.GetQuotaAllowanceForUserParams) (int64, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - var sum int64 - for _, member := range q.groupMembers { - if member.UserID != params.UserID { - continue - } - if _, err := q.getOrganizationByIDNoLock(member.GroupID); err == nil { - // This should never happen, but it has been reported in customer deployments. - // The SQL handles this case, and omits `group_members` rows in the - // Everyone group. It counts these distinctly via `organization_members` table. - continue - } - for _, group := range q.groups { - if group.ID == member.GroupID { - sum += int64(group.QuotaAllowance) - continue - } - } - } - - // Grab the quota for the Everyone group iff the user is a member of - // said organization. - for _, mem := range q.organizationMembers { - if mem.UserID != params.UserID { - continue - } - - group, err := q.getGroupByIDNoLock(context.Background(), mem.OrganizationID) - if err != nil { - return -1, xerrors.Errorf("failed to get everyone group for org %q", mem.OrganizationID.String()) - } - if group.OrganizationID != params.OrganizationID { - continue - } - sum += int64(group.QuotaAllowance) - } - - return sum, nil -} - -func (q *FakeQuerier) GetQuotaConsumedForUser(_ context.Context, params database.GetQuotaConsumedForUserParams) (int64, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - var sum int64 - for _, workspace := range q.workspaces { - if workspace.OwnerID != params.OwnerID { - continue - } - if workspace.OrganizationID != params.OrganizationID { - continue - } - if workspace.Deleted { - continue - } - - var lastBuild database.WorkspaceBuild - for _, build := range q.workspaceBuilds { - if build.WorkspaceID != workspace.ID { - continue - } - if build.CreatedAt.After(lastBuild.CreatedAt) { - lastBuild = build - } - } - sum += int64(lastBuild.DailyCost) - } - return sum, nil -} - -func (q *FakeQuerier) GetReplicaByID(_ context.Context, id uuid.UUID) (database.Replica, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, replica := range q.replicas { - if replica.ID == id { - return replica, nil - } - } - - return database.Replica{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetReplicasUpdatedAfter(_ context.Context, updatedAt time.Time) ([]database.Replica, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - replicas := make([]database.Replica, 0) - for _, replica := range q.replicas { - if replica.UpdatedAt.After(updatedAt) && !replica.StoppedAt.Valid { - replicas = append(replicas, replica) - } - } - return replicas, nil -} - -func (q *FakeQuerier) GetRunningPrebuiltWorkspaces(ctx context.Context) ([]database.GetRunningPrebuiltWorkspacesRow, error) { - return nil, ErrUnimplemented -} - -func (q *FakeQuerier) GetRuntimeConfig(_ context.Context, key string) (string, error) { - q.mutex.Lock() - defer q.mutex.Unlock() - - val, ok := q.runtimeConfig[key] - if !ok { - return "", sql.ErrNoRows - } - - return val, nil -} - -func (*FakeQuerier) GetTailnetAgents(context.Context, uuid.UUID) ([]database.TailnetAgent, error) { - return nil, ErrUnimplemented -} - -func (*FakeQuerier) GetTailnetClientsForAgent(context.Context, uuid.UUID) ([]database.TailnetClient, error) { - return nil, ErrUnimplemented -} - -func (*FakeQuerier) GetTailnetPeers(context.Context, uuid.UUID) ([]database.TailnetPeer, error) { - return nil, ErrUnimplemented -} - -func (*FakeQuerier) GetTailnetTunnelPeerBindings(context.Context, uuid.UUID) ([]database.GetTailnetTunnelPeerBindingsRow, error) { - return nil, ErrUnimplemented -} - -func (*FakeQuerier) GetTailnetTunnelPeerIDs(context.Context, uuid.UUID) ([]database.GetTailnetTunnelPeerIDsRow, error) { - return nil, ErrUnimplemented -} - -func (q *FakeQuerier) GetTelemetryItem(_ context.Context, key string) (database.TelemetryItem, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, item := range q.telemetryItems { - if item.Key == key { - return item, nil - } - } - - return database.TelemetryItem{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetTelemetryItems(_ context.Context) ([]database.TelemetryItem, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - return slices.Clone(q.telemetryItems), nil -} - -func (q *FakeQuerier) GetTemplateAppInsights(ctx context.Context, arg database.GetTemplateAppInsightsParams) ([]database.GetTemplateAppInsightsRow, error) { - err := validateDatabaseType(arg) - if err != nil { - return nil, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - /* - WITH - */ - - /* - -- Create a list of all unique apps by template, this is used to - -- filter out irrelevant template usage stats. - apps AS ( - SELECT DISTINCT ON (ws.template_id, app.slug) - ws.template_id, - app.slug, - app.display_name, - app.icon - FROM - workspaces ws - JOIN - workspace_builds AS build - ON - build.workspace_id = ws.id - JOIN - workspace_resources AS resource - ON - resource.job_id = build.job_id - JOIN - workspace_agents AS agent - ON - agent.resource_id = resource.id - JOIN - workspace_apps AS app - ON - app.agent_id = agent.id - WHERE - -- Partial query parameter filter. - CASE WHEN COALESCE(array_length(@template_ids::uuid[], 1), 0) > 0 THEN ws.template_id = ANY(@template_ids::uuid[]) ELSE TRUE END - ORDER BY - ws.template_id, app.slug, app.created_at DESC - ), - -- Join apps and template usage stats to filter out irrelevant rows. - -- Note that this way of joining will eliminate all data-points that - -- aren't for "real" apps. That means ports are ignored (even though - -- they're part of the dataset), as well as are "[terminal]" entries - -- which are alternate datapoints for reconnecting pty usage. - template_usage_stats_with_apps AS ( - SELECT - tus.start_time, - tus.template_id, - tus.user_id, - apps.slug, - apps.display_name, - apps.icon, - tus.app_usage_mins - FROM - apps - JOIN - template_usage_stats AS tus - ON - -- Query parameter filter. - tus.start_time >= @start_time::timestamptz - AND tus.end_time <= @end_time::timestamptz - AND CASE WHEN COALESCE(array_length(@template_ids::uuid[], 1), 0) > 0 THEN tus.template_id = ANY(@template_ids::uuid[]) ELSE TRUE END - -- Primary join condition. - AND tus.template_id = apps.template_id - AND apps.slug IN (SELECT jsonb_object_keys(tus.app_usage_mins)) - ), - -- Group the app insights by interval, user and unique app. This - -- allows us to deduplicate a user using the same app across - -- multiple templates. - app_insights AS ( - SELECT - user_id, - slug, - display_name, - icon, - -- See motivation in GetTemplateInsights for LEAST(SUM(n), 30). - LEAST(SUM(app_usage.value::smallint), 30) AS usage_mins - FROM - template_usage_stats_with_apps, jsonb_each(app_usage_mins) AS app_usage - WHERE - app_usage.key = slug - GROUP BY - start_time, user_id, slug, display_name, icon - ), - -- Analyze the users unique app usage across all templates. Count - -- usage across consecutive intervals as continuous usage. - times_used AS ( - SELECT DISTINCT ON (user_id, slug, display_name, icon, uniq) - slug, - display_name, - icon, - -- Turn start_time into a unique identifier that identifies a users - -- continuous app usage. The value of uniq is otherwise garbage. - -- - -- Since we're aggregating per user app usage across templates, - -- there can be duplicate start_times. To handle this, we use the - -- dense_rank() function, otherwise row_number() would suffice. - start_time - ( - dense_rank() OVER ( - PARTITION BY - user_id, slug, display_name, icon - ORDER BY - start_time - ) * '30 minutes'::interval - ) AS uniq - FROM - template_usage_stats_with_apps - ), - */ - - // Due to query optimizations, this logic is somewhat inverted from - // the above query. - type appInsightsGroupBy struct { - StartTime time.Time - UserID uuid.UUID - Slug string - DisplayName string - Icon string - } - type appTimesUsedGroupBy struct { - UserID uuid.UUID - Slug string - DisplayName string - Icon string - } - type appInsightsRow struct { - appInsightsGroupBy - TemplateIDs []uuid.UUID - AppUsageMins int64 - } - appInsightRows := make(map[appInsightsGroupBy]appInsightsRow) - appTimesUsedRows := make(map[appTimesUsedGroupBy]map[time.Time]struct{}) - // FROM - for _, stat := range q.templateUsageStats { - // WHERE - if stat.StartTime.Before(arg.StartTime) || stat.EndTime.After(arg.EndTime) { - continue - } - if len(arg.TemplateIDs) > 0 && !slices.Contains(arg.TemplateIDs, stat.TemplateID) { - continue - } - - // json_each - for slug, appUsage := range stat.AppUsageMins { - // FROM apps JOIN template_usage_stats - app, _ := q.getLatestWorkspaceAppByTemplateIDUserIDSlugNoLock(ctx, stat.TemplateID, stat.UserID, slug) - if app.Slug == "" { - continue - } - - // SELECT - key := appInsightsGroupBy{ - StartTime: stat.StartTime, - UserID: stat.UserID, - Slug: slug, - DisplayName: app.DisplayName, - Icon: app.Icon, - } - row, ok := appInsightRows[key] - if !ok { - row = appInsightsRow{ - appInsightsGroupBy: key, - } - } - row.TemplateIDs = append(row.TemplateIDs, stat.TemplateID) - row.AppUsageMins = least(row.AppUsageMins+appUsage, 30) - appInsightRows[key] = row - - // Prepare to do times_used calculation, distinct start times. - timesUsedKey := appTimesUsedGroupBy{ - UserID: stat.UserID, - Slug: slug, - DisplayName: app.DisplayName, - Icon: app.Icon, - } - if appTimesUsedRows[timesUsedKey] == nil { - appTimesUsedRows[timesUsedKey] = make(map[time.Time]struct{}) - } - // This assigns a distinct time, so we don't need to - // dense_rank() later on, we can simply do row_number(). - appTimesUsedRows[timesUsedKey][stat.StartTime] = struct{}{} - } - } - - appTimesUsedTempRows := make(map[appTimesUsedGroupBy][]time.Time) - for key, times := range appTimesUsedRows { - for t := range times { - appTimesUsedTempRows[key] = append(appTimesUsedTempRows[key], t) - } - } - for _, times := range appTimesUsedTempRows { - slices.SortFunc(times, func(a, b time.Time) int { - return int(a.Sub(b)) - }) - } - for key, times := range appTimesUsedTempRows { - uniq := make(map[time.Time]struct{}) - for i, t := range times { - uniq[t.Add(-(30 * time.Minute * time.Duration(i)))] = struct{}{} - } - appTimesUsedRows[key] = uniq - } - - /* - -- Even though we allow identical apps to be aggregated across - -- templates, we still want to be able to report which templates - -- the data comes from. - templates AS ( - SELECT - slug, - display_name, - icon, - array_agg(DISTINCT template_id)::uuid[] AS template_ids - FROM - template_usage_stats_with_apps - GROUP BY - slug, display_name, icon - ) - */ - - type appGroupBy struct { - Slug string - DisplayName string - Icon string - } - type templateRow struct { - appGroupBy - TemplateIDs []uuid.UUID - } - - templateRows := make(map[appGroupBy]templateRow) - for _, aiRow := range appInsightRows { - key := appGroupBy{ - Slug: aiRow.Slug, - DisplayName: aiRow.DisplayName, - Icon: aiRow.Icon, - } - row, ok := templateRows[key] - if !ok { - row = templateRow{ - appGroupBy: key, - } - } - row.TemplateIDs = uniqueSortedUUIDs(append(row.TemplateIDs, aiRow.TemplateIDs...)) - templateRows[key] = row - } - - /* - SELECT - t.template_ids, - COUNT(DISTINCT ai.user_id) AS active_users, - ai.slug, - ai.display_name, - ai.icon, - (SUM(ai.usage_mins) * 60)::bigint AS usage_seconds - FROM - app_insights AS ai - JOIN - templates AS t - ON - t.slug = ai.slug - AND t.display_name = ai.display_name - AND t.icon = ai.icon - GROUP BY - t.template_ids, ai.slug, ai.display_name, ai.icon; - */ - - type templateAppInsightsRow struct { - TemplateIDs []uuid.UUID - ActiveUserIDs []uuid.UUID - UsageSeconds int64 - } - groupedRows := make(map[appGroupBy]templateAppInsightsRow) - for _, aiRow := range appInsightRows { - key := appGroupBy{ - Slug: aiRow.Slug, - DisplayName: aiRow.DisplayName, - Icon: aiRow.Icon, - } - row := groupedRows[key] - row.ActiveUserIDs = append(row.ActiveUserIDs, aiRow.UserID) - row.UsageSeconds += aiRow.AppUsageMins * 60 - groupedRows[key] = row - } - - var rows []database.GetTemplateAppInsightsRow - for key, gr := range groupedRows { - row := database.GetTemplateAppInsightsRow{ - TemplateIDs: templateRows[key].TemplateIDs, - ActiveUsers: int64(len(uniqueSortedUUIDs(gr.ActiveUserIDs))), - Slug: key.Slug, - DisplayName: key.DisplayName, - Icon: key.Icon, - UsageSeconds: gr.UsageSeconds, - } - for tuk, uniq := range appTimesUsedRows { - if key.Slug == tuk.Slug && key.DisplayName == tuk.DisplayName && key.Icon == tuk.Icon { - row.TimesUsed += int64(len(uniq)) - } - } - rows = append(rows, row) - } - - // NOTE(mafredri): Add sorting if we decide on how to handle PostgreSQL collations. - // ORDER BY slug_or_port, display_name, icon, is_app - return rows, nil -} - -func (q *FakeQuerier) GetTemplateAppInsightsByTemplate(ctx context.Context, arg database.GetTemplateAppInsightsByTemplateParams) ([]database.GetTemplateAppInsightsByTemplateRow, error) { - err := validateDatabaseType(arg) - if err != nil { - return nil, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - type uniqueKey struct { - TemplateID uuid.UUID - DisplayName string - Slug string - } - - // map (TemplateID + DisplayName + Slug) x time.Time x UserID x - usageByTemplateAppUser := map[uniqueKey]map[time.Time]map[uuid.UUID]int64{} - - // Review agent stats in terms of usage - for _, s := range q.workspaceAppStats { - // (was.session_started_at >= ts.from_ AND was.session_started_at < ts.to_) - // OR (was.session_ended_at > ts.from_ AND was.session_ended_at < ts.to_) - // OR (was.session_started_at < ts.from_ AND was.session_ended_at >= ts.to_) - if !(((s.SessionStartedAt.After(arg.StartTime) || s.SessionStartedAt.Equal(arg.StartTime)) && s.SessionStartedAt.Before(arg.EndTime)) || - (s.SessionEndedAt.After(arg.StartTime) && s.SessionEndedAt.Before(arg.EndTime)) || - (s.SessionStartedAt.Before(arg.StartTime) && (s.SessionEndedAt.After(arg.EndTime) || s.SessionEndedAt.Equal(arg.EndTime)))) { - continue - } - - w, err := q.getWorkspaceByIDNoLock(ctx, s.WorkspaceID) - if err != nil { - return nil, err - } - - app, _ := q.getWorkspaceAppByAgentIDAndSlugNoLock(ctx, database.GetWorkspaceAppByAgentIDAndSlugParams{ - AgentID: s.AgentID, - Slug: s.SlugOrPort, - }) - - key := uniqueKey{ - TemplateID: w.TemplateID, - DisplayName: app.DisplayName, - Slug: app.Slug, - } - - t := s.SessionStartedAt.Truncate(time.Minute) - if t.Before(arg.StartTime) { - t = arg.StartTime - } - for t.Before(s.SessionEndedAt) && t.Before(arg.EndTime) { - if _, ok := usageByTemplateAppUser[key]; !ok { - usageByTemplateAppUser[key] = map[time.Time]map[uuid.UUID]int64{} - } - if _, ok := usageByTemplateAppUser[key][t]; !ok { - usageByTemplateAppUser[key][t] = map[uuid.UUID]int64{} - } - if _, ok := usageByTemplateAppUser[key][t][s.UserID]; !ok { - usageByTemplateAppUser[key][t][s.UserID] = 60 // 1 minute - } - t = t.Add(1 * time.Minute) - } - } - - // Sort usage data - usageKeys := make([]uniqueKey, len(usageByTemplateAppUser)) - var i int - for key := range usageByTemplateAppUser { - usageKeys[i] = key - i++ - } - - slices.SortFunc(usageKeys, func(a, b uniqueKey) int { - if a.TemplateID != b.TemplateID { - return slice.Ascending(a.TemplateID.String(), b.TemplateID.String()) - } - if a.DisplayName != b.DisplayName { - return slice.Ascending(a.DisplayName, b.DisplayName) - } - return slice.Ascending(a.Slug, b.Slug) - }) - - // Build result - var result []database.GetTemplateAppInsightsByTemplateRow - for _, usageKey := range usageKeys { - r := database.GetTemplateAppInsightsByTemplateRow{ - TemplateID: usageKey.TemplateID, - DisplayName: usageKey.DisplayName, - SlugOrPort: usageKey.Slug, - } - for _, mUserUsage := range usageByTemplateAppUser[usageKey] { - r.ActiveUsers += int64(len(mUserUsage)) - for _, usage := range mUserUsage { - r.UsageSeconds += usage - } - } - result = append(result, r) - } - return result, nil -} - -func (q *FakeQuerier) GetTemplateAverageBuildTime(ctx context.Context, arg database.GetTemplateAverageBuildTimeParams) (database.GetTemplateAverageBuildTimeRow, error) { - if err := validateDatabaseType(arg); err != nil { - return database.GetTemplateAverageBuildTimeRow{}, err - } - - var emptyRow database.GetTemplateAverageBuildTimeRow - var ( - startTimes []float64 - stopTimes []float64 - deleteTimes []float64 - ) - q.mutex.RLock() - defer q.mutex.RUnlock() - for _, wb := range q.workspaceBuilds { - version, err := q.getTemplateVersionByIDNoLock(ctx, wb.TemplateVersionID) - if err != nil { - return emptyRow, err - } - if version.TemplateID != arg.TemplateID { - continue - } - - job, err := q.getProvisionerJobByIDNoLock(ctx, wb.JobID) - if err != nil { - return emptyRow, err - } - if job.CompletedAt.Valid { - took := job.CompletedAt.Time.Sub(job.StartedAt.Time).Seconds() - switch wb.Transition { - case database.WorkspaceTransitionStart: - startTimes = append(startTimes, took) - case database.WorkspaceTransitionStop: - stopTimes = append(stopTimes, took) - case database.WorkspaceTransitionDelete: - deleteTimes = append(deleteTimes, took) - } - } - } - - var row database.GetTemplateAverageBuildTimeRow - row.Delete50, row.Delete95 = tryPercentileDisc(deleteTimes, 50), tryPercentileDisc(deleteTimes, 95) - row.Stop50, row.Stop95 = tryPercentileDisc(stopTimes, 50), tryPercentileDisc(stopTimes, 95) - row.Start50, row.Start95 = tryPercentileDisc(startTimes, 50), tryPercentileDisc(startTimes, 95) - return row, nil -} - -func (q *FakeQuerier) GetTemplateByID(ctx context.Context, id uuid.UUID) (database.Template, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - return q.getTemplateByIDNoLock(ctx, id) -} - -func (q *FakeQuerier) GetTemplateByOrganizationAndName(_ context.Context, arg database.GetTemplateByOrganizationAndNameParams) (database.Template, error) { - if err := validateDatabaseType(arg); err != nil { - return database.Template{}, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, template := range q.templates { - if template.OrganizationID != arg.OrganizationID { - continue - } - if !strings.EqualFold(template.Name, arg.Name) { - continue - } - if template.Deleted != arg.Deleted { - continue - } - return q.templateWithNameNoLock(template), nil - } - return database.Template{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetTemplateDAUs(_ context.Context, arg database.GetTemplateDAUsParams) ([]database.GetTemplateDAUsRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - seens := make(map[time.Time]map[uuid.UUID]struct{}) - - for _, as := range q.workspaceAgentStats { - if as.TemplateID != arg.TemplateID { - continue - } - if as.ConnectionCount == 0 { - continue - } - - date := as.CreatedAt.UTC().Add(time.Duration(arg.TzOffset) * time.Hour * -1).Truncate(time.Hour * 24) - - dateEntry := seens[date] - if dateEntry == nil { - dateEntry = make(map[uuid.UUID]struct{}) - } - dateEntry[as.UserID] = struct{}{} - seens[date] = dateEntry - } - - seenKeys := maps.Keys(seens) - sort.Slice(seenKeys, func(i, j int) bool { - return seenKeys[i].Before(seenKeys[j]) - }) - - var rs []database.GetTemplateDAUsRow - for _, key := range seenKeys { - ids := seens[key] - for id := range ids { - rs = append(rs, database.GetTemplateDAUsRow{ - Date: key, - UserID: id, - }) - } - } - - return rs, nil -} - -func (q *FakeQuerier) GetTemplateInsights(_ context.Context, arg database.GetTemplateInsightsParams) (database.GetTemplateInsightsRow, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.GetTemplateInsightsRow{}, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - /* - WITH - */ - - /* - insights AS ( - SELECT - user_id, - -- See motivation in GetTemplateInsights for LEAST(SUM(n), 30). - LEAST(SUM(usage_mins), 30) AS usage_mins, - LEAST(SUM(ssh_mins), 30) AS ssh_mins, - LEAST(SUM(sftp_mins), 30) AS sftp_mins, - LEAST(SUM(reconnecting_pty_mins), 30) AS reconnecting_pty_mins, - LEAST(SUM(vscode_mins), 30) AS vscode_mins, - LEAST(SUM(jetbrains_mins), 30) AS jetbrains_mins - FROM - template_usage_stats - WHERE - start_time >= @start_time::timestamptz - AND end_time <= @end_time::timestamptz - AND CASE WHEN COALESCE(array_length(@template_ids::uuid[], 1), 0) > 0 THEN template_id = ANY(@template_ids::uuid[]) ELSE TRUE END - GROUP BY - start_time, user_id - ), - */ - - type insightsGroupBy struct { - StartTime time.Time - UserID uuid.UUID - } - type insightsRow struct { - insightsGroupBy - UsageMins int16 - SSHMins int16 - SFTPMins int16 - ReconnectingPTYMins int16 - VSCodeMins int16 - JetBrainsMins int16 - } - insights := make(map[insightsGroupBy]insightsRow) - for _, stat := range q.templateUsageStats { - if stat.StartTime.Before(arg.StartTime) || stat.EndTime.After(arg.EndTime) { - continue - } - if len(arg.TemplateIDs) > 0 && !slices.Contains(arg.TemplateIDs, stat.TemplateID) { - continue - } - key := insightsGroupBy{ - StartTime: stat.StartTime, - UserID: stat.UserID, - } - row, ok := insights[key] - if !ok { - row = insightsRow{ - insightsGroupBy: key, - } - } - row.UsageMins = least(row.UsageMins+stat.UsageMins, 30) - row.SSHMins = least(row.SSHMins+stat.SshMins, 30) - row.SFTPMins = least(row.SFTPMins+stat.SftpMins, 30) - row.ReconnectingPTYMins = least(row.ReconnectingPTYMins+stat.ReconnectingPtyMins, 30) - row.VSCodeMins = least(row.VSCodeMins+stat.VscodeMins, 30) - row.JetBrainsMins = least(row.JetBrainsMins+stat.JetbrainsMins, 30) - insights[key] = row - } - - /* - templates AS ( - SELECT - array_agg(DISTINCT template_id) AS template_ids, - array_agg(DISTINCT template_id) FILTER (WHERE ssh_mins > 0) AS ssh_template_ids, - array_agg(DISTINCT template_id) FILTER (WHERE sftp_mins > 0) AS sftp_template_ids, - array_agg(DISTINCT template_id) FILTER (WHERE reconnecting_pty_mins > 0) AS reconnecting_pty_template_ids, - array_agg(DISTINCT template_id) FILTER (WHERE vscode_mins > 0) AS vscode_template_ids, - array_agg(DISTINCT template_id) FILTER (WHERE jetbrains_mins > 0) AS jetbrains_template_ids - FROM - template_usage_stats - WHERE - start_time >= @start_time::timestamptz - AND end_time <= @end_time::timestamptz - AND CASE WHEN COALESCE(array_length(@template_ids::uuid[], 1), 0) > 0 THEN template_id = ANY(@template_ids::uuid[]) ELSE TRUE END - ) - */ - - type templateRow struct { - TemplateIDs []uuid.UUID - SSHTemplateIDs []uuid.UUID - SFTPTemplateIDs []uuid.UUID - ReconnectingPTYIDs []uuid.UUID - VSCodeTemplateIDs []uuid.UUID - JetBrainsTemplateIDs []uuid.UUID - } - templates := templateRow{} - for _, stat := range q.templateUsageStats { - if stat.StartTime.Before(arg.StartTime) || stat.EndTime.After(arg.EndTime) { - continue - } - if len(arg.TemplateIDs) > 0 && !slices.Contains(arg.TemplateIDs, stat.TemplateID) { - continue - } - templates.TemplateIDs = append(templates.TemplateIDs, stat.TemplateID) - if stat.SshMins > 0 { - templates.SSHTemplateIDs = append(templates.SSHTemplateIDs, stat.TemplateID) - } - if stat.SftpMins > 0 { - templates.SFTPTemplateIDs = append(templates.SFTPTemplateIDs, stat.TemplateID) - } - if stat.ReconnectingPtyMins > 0 { - templates.ReconnectingPTYIDs = append(templates.ReconnectingPTYIDs, stat.TemplateID) - } - if stat.VscodeMins > 0 { - templates.VSCodeTemplateIDs = append(templates.VSCodeTemplateIDs, stat.TemplateID) - } - if stat.JetbrainsMins > 0 { - templates.JetBrainsTemplateIDs = append(templates.JetBrainsTemplateIDs, stat.TemplateID) - } - } - - /* - SELECT - COALESCE((SELECT template_ids FROM templates), '{}')::uuid[] AS template_ids, -- Includes app usage. - COALESCE((SELECT ssh_template_ids FROM templates), '{}')::uuid[] AS ssh_template_ids, - COALESCE((SELECT sftp_template_ids FROM templates), '{}')::uuid[] AS sftp_template_ids, - COALESCE((SELECT reconnecting_pty_template_ids FROM templates), '{}')::uuid[] AS reconnecting_pty_template_ids, - COALESCE((SELECT vscode_template_ids FROM templates), '{}')::uuid[] AS vscode_template_ids, - COALESCE((SELECT jetbrains_template_ids FROM templates), '{}')::uuid[] AS jetbrains_template_ids, - COALESCE(COUNT(DISTINCT user_id), 0)::bigint AS active_users, -- Includes app usage. - COALESCE(SUM(usage_mins) * 60, 0)::bigint AS usage_total_seconds, -- Includes app usage. - COALESCE(SUM(ssh_mins) * 60, 0)::bigint AS usage_ssh_seconds, - COALESCE(SUM(sftp_mins) * 60, 0)::bigint AS usage_sftp_seconds, - COALESCE(SUM(reconnecting_pty_mins) * 60, 0)::bigint AS usage_reconnecting_pty_seconds, - COALESCE(SUM(vscode_mins) * 60, 0)::bigint AS usage_vscode_seconds, - COALESCE(SUM(jetbrains_mins) * 60, 0)::bigint AS usage_jetbrains_seconds - FROM - insights; - */ - - var row database.GetTemplateInsightsRow - row.TemplateIDs = uniqueSortedUUIDs(templates.TemplateIDs) - row.SshTemplateIds = uniqueSortedUUIDs(templates.SSHTemplateIDs) - row.SftpTemplateIds = uniqueSortedUUIDs(templates.SFTPTemplateIDs) - row.ReconnectingPtyTemplateIds = uniqueSortedUUIDs(templates.ReconnectingPTYIDs) - row.VscodeTemplateIds = uniqueSortedUUIDs(templates.VSCodeTemplateIDs) - row.JetbrainsTemplateIds = uniqueSortedUUIDs(templates.JetBrainsTemplateIDs) - activeUserIDs := make(map[uuid.UUID]struct{}) - for _, insight := range insights { - activeUserIDs[insight.UserID] = struct{}{} - row.UsageTotalSeconds += int64(insight.UsageMins) * 60 - row.UsageSshSeconds += int64(insight.SSHMins) * 60 - row.UsageSftpSeconds += int64(insight.SFTPMins) * 60 - row.UsageReconnectingPtySeconds += int64(insight.ReconnectingPTYMins) * 60 - row.UsageVscodeSeconds += int64(insight.VSCodeMins) * 60 - row.UsageJetbrainsSeconds += int64(insight.JetBrainsMins) * 60 - } - row.ActiveUsers = int64(len(activeUserIDs)) - - return row, nil -} - -func (q *FakeQuerier) GetTemplateInsightsByInterval(_ context.Context, arg database.GetTemplateInsightsByIntervalParams) ([]database.GetTemplateInsightsByIntervalRow, error) { - err := validateDatabaseType(arg) - if err != nil { - return nil, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - /* - WITH - ts AS ( - SELECT - d::timestamptz AS from_, - CASE - WHEN (d::timestamptz + (@interval_days::int || ' day')::interval) <= @end_time::timestamptz - THEN (d::timestamptz + (@interval_days::int || ' day')::interval) - ELSE @end_time::timestamptz - END AS to_ - FROM - -- Subtract 1 microsecond from end_time to avoid including the next interval in the results. - generate_series(@start_time::timestamptz, (@end_time::timestamptz) - '1 microsecond'::interval, (@interval_days::int || ' day')::interval) AS d - ) - - SELECT - ts.from_ AS start_time, - ts.to_ AS end_time, - array_remove(array_agg(DISTINCT tus.template_id), NULL)::uuid[] AS template_ids, - COUNT(DISTINCT tus.user_id) AS active_users - FROM - ts - LEFT JOIN - template_usage_stats AS tus - ON - tus.start_time >= ts.from_ - AND tus.end_time <= ts.to_ - AND CASE WHEN COALESCE(array_length(@template_ids::uuid[], 1), 0) > 0 THEN tus.template_id = ANY(@template_ids::uuid[]) ELSE TRUE END - GROUP BY - ts.from_, ts.to_; - */ - - type interval struct { - From time.Time - To time.Time - } - var ts []interval - for d := arg.StartTime; d.Before(arg.EndTime); d = d.AddDate(0, 0, int(arg.IntervalDays)) { - to := d.AddDate(0, 0, int(arg.IntervalDays)) - if to.After(arg.EndTime) { - to = arg.EndTime - } - ts = append(ts, interval{From: d, To: to}) - } - - type grouped struct { - TemplateIDs map[uuid.UUID]struct{} - UserIDs map[uuid.UUID]struct{} - } - groupedByInterval := make(map[interval]grouped) - for _, tus := range q.templateUsageStats { - for _, t := range ts { - if tus.StartTime.Before(t.From) || tus.EndTime.After(t.To) { - continue - } - if len(arg.TemplateIDs) > 0 && !slices.Contains(arg.TemplateIDs, tus.TemplateID) { - continue - } - g, ok := groupedByInterval[t] - if !ok { - g = grouped{ - TemplateIDs: make(map[uuid.UUID]struct{}), - UserIDs: make(map[uuid.UUID]struct{}), - } - } - g.TemplateIDs[tus.TemplateID] = struct{}{} - g.UserIDs[tus.UserID] = struct{}{} - groupedByInterval[t] = g - } - } - - var rows []database.GetTemplateInsightsByIntervalRow - for _, t := range ts { // Ordered by interval. - row := database.GetTemplateInsightsByIntervalRow{ - StartTime: t.From, - EndTime: t.To, - } - row.TemplateIDs = uniqueSortedUUIDs(maps.Keys(groupedByInterval[t].TemplateIDs)) - row.ActiveUsers = int64(len(groupedByInterval[t].UserIDs)) - rows = append(rows, row) - } - - return rows, nil -} - -func (q *FakeQuerier) GetTemplateInsightsByTemplate(_ context.Context, arg database.GetTemplateInsightsByTemplateParams) ([]database.GetTemplateInsightsByTemplateRow, error) { - err := validateDatabaseType(arg) - if err != nil { - return nil, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - // map time.Time x TemplateID x UserID x - appUsageByTemplateAndUser := map[time.Time]map[uuid.UUID]map[uuid.UUID]database.GetTemplateInsightsByTemplateRow{} - - // Review agent stats in terms of usage - templateIDSet := make(map[uuid.UUID]struct{}) - - for _, s := range q.workspaceAgentStats { - if s.CreatedAt.Before(arg.StartTime) || s.CreatedAt.Equal(arg.EndTime) || s.CreatedAt.After(arg.EndTime) { - continue - } - if s.ConnectionCount == 0 { - continue - } - - t := s.CreatedAt.Truncate(time.Minute) - templateIDSet[s.TemplateID] = struct{}{} - - if _, ok := appUsageByTemplateAndUser[t]; !ok { - appUsageByTemplateAndUser[t] = make(map[uuid.UUID]map[uuid.UUID]database.GetTemplateInsightsByTemplateRow) - } - - if _, ok := appUsageByTemplateAndUser[t][s.TemplateID]; !ok { - appUsageByTemplateAndUser[t][s.TemplateID] = make(map[uuid.UUID]database.GetTemplateInsightsByTemplateRow) - } - - if _, ok := appUsageByTemplateAndUser[t][s.TemplateID][s.UserID]; !ok { - appUsageByTemplateAndUser[t][s.TemplateID][s.UserID] = database.GetTemplateInsightsByTemplateRow{} - } - - u := appUsageByTemplateAndUser[t][s.TemplateID][s.UserID] - if s.SessionCountJetBrains > 0 { - u.UsageJetbrainsSeconds = 60 - } - if s.SessionCountVSCode > 0 { - u.UsageVscodeSeconds = 60 - } - if s.SessionCountReconnectingPTY > 0 { - u.UsageReconnectingPtySeconds = 60 - } - if s.SessionCountSSH > 0 { - u.UsageSshSeconds = 60 - } - appUsageByTemplateAndUser[t][s.TemplateID][s.UserID] = u - } - - // Sort used templates - templateIDs := make([]uuid.UUID, 0, len(templateIDSet)) - for templateID := range templateIDSet { - templateIDs = append(templateIDs, templateID) - } - slices.SortFunc(templateIDs, func(a, b uuid.UUID) int { - return slice.Ascending(a.String(), b.String()) - }) - - // Build result - var result []database.GetTemplateInsightsByTemplateRow - for _, templateID := range templateIDs { - r := database.GetTemplateInsightsByTemplateRow{ - TemplateID: templateID, - } - - uniqueUsers := map[uuid.UUID]struct{}{} - - for _, mTemplateUserUsage := range appUsageByTemplateAndUser { - mUserUsage, ok := mTemplateUserUsage[templateID] - if !ok { - continue // template was not used in this time window - } - - for userID, usage := range mUserUsage { - uniqueUsers[userID] = struct{}{} - - r.UsageJetbrainsSeconds += usage.UsageJetbrainsSeconds - r.UsageVscodeSeconds += usage.UsageVscodeSeconds - r.UsageReconnectingPtySeconds += usage.UsageReconnectingPtySeconds - r.UsageSshSeconds += usage.UsageSshSeconds - } - } - - r.ActiveUsers = int64(len(uniqueUsers)) - - result = append(result, r) - } - return result, nil -} - -func (q *FakeQuerier) GetTemplateParameterInsights(ctx context.Context, arg database.GetTemplateParameterInsightsParams) ([]database.GetTemplateParameterInsightsRow, error) { - err := validateDatabaseType(arg) - if err != nil { - return nil, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - // WITH latest_workspace_builds ... - latestWorkspaceBuilds := make(map[uuid.UUID]database.WorkspaceBuild) - for _, wb := range q.workspaceBuilds { - if wb.CreatedAt.Before(arg.StartTime) || wb.CreatedAt.Equal(arg.EndTime) || wb.CreatedAt.After(arg.EndTime) { - continue - } - if latestWorkspaceBuilds[wb.WorkspaceID].BuildNumber < wb.BuildNumber { - latestWorkspaceBuilds[wb.WorkspaceID] = wb - } - } - if len(arg.TemplateIDs) > 0 { - for wsID := range latestWorkspaceBuilds { - ws, err := q.getWorkspaceByIDNoLock(ctx, wsID) - if err != nil { - return nil, err - } - if slices.Contains(arg.TemplateIDs, ws.TemplateID) { - delete(latestWorkspaceBuilds, wsID) - } - } - } - // WITH unique_template_params ... - num := int64(0) - uniqueTemplateParams := make(map[string]*database.GetTemplateParameterInsightsRow) - uniqueTemplateParamWorkspaceBuildIDs := make(map[string][]uuid.UUID) - for _, wb := range latestWorkspaceBuilds { - tv, err := q.getTemplateVersionByIDNoLock(ctx, wb.TemplateVersionID) - if err != nil { - return nil, err - } - for _, tvp := range q.templateVersionParameters { - if tvp.TemplateVersionID != tv.ID { - continue - } - // GROUP BY tvp.name, tvp.type, tvp.display_name, tvp.description, tvp.options - key := fmt.Sprintf("%s:%s:%s:%s:%s", tvp.Name, tvp.Type, tvp.DisplayName, tvp.Description, tvp.Options) - if _, ok := uniqueTemplateParams[key]; !ok { - num++ - uniqueTemplateParams[key] = &database.GetTemplateParameterInsightsRow{ - Num: num, - Name: tvp.Name, - Type: tvp.Type, - DisplayName: tvp.DisplayName, - Description: tvp.Description, - Options: tvp.Options, - } - } - uniqueTemplateParams[key].TemplateIDs = append(uniqueTemplateParams[key].TemplateIDs, tv.TemplateID.UUID) - uniqueTemplateParamWorkspaceBuildIDs[key] = append(uniqueTemplateParamWorkspaceBuildIDs[key], wb.ID) - } - } - // SELECT ... - counts := make(map[string]map[string]int64) - for key, utp := range uniqueTemplateParams { - for _, wbp := range q.workspaceBuildParameters { - if !slices.Contains(uniqueTemplateParamWorkspaceBuildIDs[key], wbp.WorkspaceBuildID) { - continue - } - if wbp.Name != utp.Name { - continue - } - if counts[key] == nil { - counts[key] = make(map[string]int64) - } - counts[key][wbp.Value]++ - } - } - - var rows []database.GetTemplateParameterInsightsRow - for key, utp := range uniqueTemplateParams { - for value, count := range counts[key] { - rows = append(rows, database.GetTemplateParameterInsightsRow{ - Num: utp.Num, - TemplateIDs: uniqueSortedUUIDs(utp.TemplateIDs), - Name: utp.Name, - DisplayName: utp.DisplayName, - Type: utp.Type, - Description: utp.Description, - Options: utp.Options, - Value: value, - Count: count, - }) - } - } - - // NOTE(mafredri): Add sorting if we decide on how to handle PostgreSQL collations. - // ORDER BY utp.name, utp.type, utp.display_name, utp.description, utp.options, wbp.value - return rows, nil -} - -func (*FakeQuerier) GetTemplatePresetsWithPrebuilds(_ context.Context, _ uuid.NullUUID) ([]database.GetTemplatePresetsWithPrebuildsRow, error) { - return nil, ErrUnimplemented -} - -func (q *FakeQuerier) GetTemplateUsageStats(_ context.Context, arg database.GetTemplateUsageStatsParams) ([]database.TemplateUsageStat, error) { - err := validateDatabaseType(arg) - if err != nil { - return nil, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - var stats []database.TemplateUsageStat - for _, stat := range q.templateUsageStats { - // Exclude all chunks that don't fall exactly within the range. - if stat.StartTime.Before(arg.StartTime) || stat.EndTime.After(arg.EndTime) { - continue - } - if len(arg.TemplateIDs) > 0 && !slices.Contains(arg.TemplateIDs, stat.TemplateID) { - continue - } - stats = append(stats, stat) - } - - if len(stats) == 0 { - return nil, sql.ErrNoRows - } - - return stats, nil -} - -func (q *FakeQuerier) GetTemplateVersionByID(ctx context.Context, templateVersionID uuid.UUID) (database.TemplateVersion, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - return q.getTemplateVersionByIDNoLock(ctx, templateVersionID) -} - -func (q *FakeQuerier) GetTemplateVersionByJobID(_ context.Context, jobID uuid.UUID) (database.TemplateVersion, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, templateVersion := range q.templateVersions { - if templateVersion.JobID != jobID { - continue - } - return q.templateVersionWithUserNoLock(templateVersion), nil - } - return database.TemplateVersion{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetTemplateVersionByTemplateIDAndName(_ context.Context, arg database.GetTemplateVersionByTemplateIDAndNameParams) (database.TemplateVersion, error) { - if err := validateDatabaseType(arg); err != nil { - return database.TemplateVersion{}, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, templateVersion := range q.templateVersions { - if templateVersion.TemplateID != arg.TemplateID { - continue - } - if !strings.EqualFold(templateVersion.Name, arg.Name) { - continue - } - return q.templateVersionWithUserNoLock(templateVersion), nil - } - return database.TemplateVersion{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetTemplateVersionParameters(_ context.Context, templateVersionID uuid.UUID) ([]database.TemplateVersionParameter, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - parameters := make([]database.TemplateVersionParameter, 0) - for _, param := range q.templateVersionParameters { - if param.TemplateVersionID != templateVersionID { - continue - } - parameters = append(parameters, param) - } - sort.Slice(parameters, func(i, j int) bool { - if parameters[i].DisplayOrder != parameters[j].DisplayOrder { - return parameters[i].DisplayOrder < parameters[j].DisplayOrder - } - return strings.ToLower(parameters[i].Name) < strings.ToLower(parameters[j].Name) - }) - return parameters, nil -} - -func (q *FakeQuerier) GetTemplateVersionTerraformValues(ctx context.Context, templateVersionID uuid.UUID) (database.TemplateVersionTerraformValue, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, tvtv := range q.templateVersionTerraformValues { - if tvtv.TemplateVersionID == templateVersionID { - return tvtv, nil - } - } - - return database.TemplateVersionTerraformValue{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetTemplateVersionVariables(_ context.Context, templateVersionID uuid.UUID) ([]database.TemplateVersionVariable, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - variables := make([]database.TemplateVersionVariable, 0) - for _, variable := range q.templateVersionVariables { - if variable.TemplateVersionID != templateVersionID { - continue - } - variables = append(variables, variable) - } - return variables, nil -} - -func (q *FakeQuerier) GetTemplateVersionWorkspaceTags(_ context.Context, templateVersionID uuid.UUID) ([]database.TemplateVersionWorkspaceTag, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - workspaceTags := make([]database.TemplateVersionWorkspaceTag, 0) - for _, workspaceTag := range q.templateVersionWorkspaceTags { - if workspaceTag.TemplateVersionID != templateVersionID { - continue - } - workspaceTags = append(workspaceTags, workspaceTag) - } - - sort.Slice(workspaceTags, func(i, j int) bool { - return workspaceTags[i].Key < workspaceTags[j].Key - }) - return workspaceTags, nil -} - -func (q *FakeQuerier) GetTemplateVersionsByIDs(_ context.Context, ids []uuid.UUID) ([]database.TemplateVersion, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - versions := make([]database.TemplateVersion, 0) - for _, version := range q.templateVersions { - for _, id := range ids { - if id == version.ID { - versions = append(versions, q.templateVersionWithUserNoLock(version)) - break - } - } - } - if len(versions) == 0 { - return nil, sql.ErrNoRows - } - - return versions, nil -} - -func (q *FakeQuerier) GetTemplateVersionsByTemplateID(_ context.Context, arg database.GetTemplateVersionsByTemplateIDParams) (version []database.TemplateVersion, err error) { - if err := validateDatabaseType(arg); err != nil { - return version, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, templateVersion := range q.templateVersions { - if templateVersion.TemplateID.UUID != arg.TemplateID { - continue - } - if arg.Archived.Valid && arg.Archived.Bool != templateVersion.Archived { - continue - } - version = append(version, q.templateVersionWithUserNoLock(templateVersion)) - } - - // Database orders by created_at - slices.SortFunc(version, func(a, b database.TemplateVersion) int { - if a.CreatedAt.Equal(b.CreatedAt) { - // Technically the postgres database also orders by uuid. So match - // that behavior - return slice.Ascending(a.ID.String(), b.ID.String()) - } - if a.CreatedAt.Before(b.CreatedAt) { - return -1 - } - return 1 - }) - - if arg.AfterID != uuid.Nil { - found := false - for i, v := range version { - if v.ID == arg.AfterID { - // We want to return all users after index i. - version = version[i+1:] - found = true - break - } - } - - // If no users after the time, then we return an empty list. - if !found { - return nil, sql.ErrNoRows - } - } - - if arg.OffsetOpt > 0 { - if int(arg.OffsetOpt) > len(version)-1 { - return nil, sql.ErrNoRows - } - version = version[arg.OffsetOpt:] - } - - if arg.LimitOpt > 0 { - if int(arg.LimitOpt) > len(version) { - // #nosec G115 - Safe conversion as version slice length is expected to be within int32 range - arg.LimitOpt = int32(len(version)) - } - version = version[:arg.LimitOpt] - } - - if len(version) == 0 { - return nil, sql.ErrNoRows - } - - return version, nil -} - -func (q *FakeQuerier) GetTemplateVersionsCreatedAfter(_ context.Context, after time.Time) ([]database.TemplateVersion, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - versions := make([]database.TemplateVersion, 0) - for _, version := range q.templateVersions { - if version.CreatedAt.After(after) { - versions = append(versions, q.templateVersionWithUserNoLock(version)) - } - } - return versions, nil -} - -func (q *FakeQuerier) GetTemplates(_ context.Context) ([]database.Template, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - templates := slices.Clone(q.templates) - slices.SortFunc(templates, func(a, b database.TemplateTable) int { - if a.Name != b.Name { - return slice.Ascending(a.Name, b.Name) - } - return slice.Ascending(a.ID.String(), b.ID.String()) - }) - - return q.templatesWithUserNoLock(templates), nil -} - -func (q *FakeQuerier) GetTemplatesWithFilter(ctx context.Context, arg database.GetTemplatesWithFilterParams) ([]database.Template, error) { - if err := validateDatabaseType(arg); err != nil { - return nil, err - } - - return q.GetAuthorizedTemplates(ctx, arg, nil) -} - -func (q *FakeQuerier) GetUnexpiredLicenses(_ context.Context) ([]database.License, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - now := time.Now() - var results []database.License - for _, l := range q.licenses { - if l.Exp.After(now) { - results = append(results, l) - } - } - sort.Slice(results, func(i, j int) bool { return results[i].ID < results[j].ID }) - return results, nil -} - -func (q *FakeQuerier) GetUserActivityInsights(_ context.Context, arg database.GetUserActivityInsightsParams) ([]database.GetUserActivityInsightsRow, error) { - err := validateDatabaseType(arg) - if err != nil { - return nil, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - /* - WITH - */ - /* - deployment_stats AS ( - SELECT - start_time, - user_id, - array_agg(template_id) AS template_ids, - -- See motivation in GetTemplateInsights for LEAST(SUM(n), 30). - LEAST(SUM(usage_mins), 30) AS usage_mins - FROM - template_usage_stats - WHERE - start_time >= @start_time::timestamptz - AND end_time <= @end_time::timestamptz - AND CASE WHEN COALESCE(array_length(@template_ids::uuid[], 1), 0) > 0 THEN template_id = ANY(@template_ids::uuid[]) ELSE TRUE END - GROUP BY - start_time, user_id - ), - */ - - type deploymentStatsGroupBy struct { - StartTime time.Time - UserID uuid.UUID - } - type deploymentStatsRow struct { - deploymentStatsGroupBy - TemplateIDs []uuid.UUID - UsageMins int16 - } - deploymentStatsRows := make(map[deploymentStatsGroupBy]deploymentStatsRow) - for _, stat := range q.templateUsageStats { - if stat.StartTime.Before(arg.StartTime) || stat.EndTime.After(arg.EndTime) { - continue - } - if len(arg.TemplateIDs) > 0 && !slices.Contains(arg.TemplateIDs, stat.TemplateID) { - continue - } - key := deploymentStatsGroupBy{ - StartTime: stat.StartTime, - UserID: stat.UserID, - } - row, ok := deploymentStatsRows[key] - if !ok { - row = deploymentStatsRow{ - deploymentStatsGroupBy: key, - } - } - row.TemplateIDs = append(row.TemplateIDs, stat.TemplateID) - row.UsageMins = least(row.UsageMins+stat.UsageMins, 30) - deploymentStatsRows[key] = row - } - - /* - template_ids AS ( - SELECT - user_id, - array_agg(DISTINCT template_id) AS ids - FROM - deployment_stats, unnest(template_ids) template_id - GROUP BY - user_id - ) - */ - - type templateIDsRow struct { - UserID uuid.UUID - TemplateIDs []uuid.UUID - } - templateIDs := make(map[uuid.UUID]templateIDsRow) - for _, dsRow := range deploymentStatsRows { - row, ok := templateIDs[dsRow.UserID] - if !ok { - row = templateIDsRow{ - UserID: row.UserID, - } - } - row.TemplateIDs = uniqueSortedUUIDs(append(row.TemplateIDs, dsRow.TemplateIDs...)) - templateIDs[dsRow.UserID] = row - } - - /* - SELECT - ds.user_id, - u.username, - u.avatar_url, - t.ids::uuid[] AS template_ids, - (SUM(ds.usage_mins) * 60)::bigint AS usage_seconds - FROM - deployment_stats ds - JOIN - users u - ON - u.id = ds.user_id - JOIN - template_ids t - ON - ds.user_id = t.user_id - GROUP BY - ds.user_id, u.username, u.avatar_url, t.ids - ORDER BY - ds.user_id ASC; - */ - - var rows []database.GetUserActivityInsightsRow - groupedRows := make(map[uuid.UUID]database.GetUserActivityInsightsRow) - for _, dsRow := range deploymentStatsRows { - row, ok := groupedRows[dsRow.UserID] - if !ok { - user, err := q.getUserByIDNoLock(dsRow.UserID) - if err != nil { - return nil, err - } - row = database.GetUserActivityInsightsRow{ - UserID: user.ID, - Username: user.Username, - AvatarURL: user.AvatarURL, - TemplateIDs: templateIDs[user.ID].TemplateIDs, - } - } - row.UsageSeconds += int64(dsRow.UsageMins) * 60 - groupedRows[dsRow.UserID] = row - } - for _, row := range groupedRows { - rows = append(rows, row) - } - if len(rows) == 0 { - return nil, sql.ErrNoRows - } - slices.SortFunc(rows, func(a, b database.GetUserActivityInsightsRow) int { - return slice.Ascending(a.UserID.String(), b.UserID.String()) - }) - - return rows, nil -} - -func (q *FakeQuerier) GetUserByEmailOrUsername(_ context.Context, arg database.GetUserByEmailOrUsernameParams) (database.User, error) { - if err := validateDatabaseType(arg); err != nil { - return database.User{}, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, user := range q.users { - if !user.Deleted && (strings.EqualFold(user.Email, arg.Email) || strings.EqualFold(user.Username, arg.Username)) { - return user, nil - } - } - return database.User{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetUserByID(_ context.Context, id uuid.UUID) (database.User, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - return q.getUserByIDNoLock(id) -} - -// nolint:revive // It's not a control flag, it's a filter. -func (q *FakeQuerier) GetUserCount(_ context.Context, includeSystem bool) (int64, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - existing := int64(0) - for _, u := range q.users { - if !includeSystem && u.IsSystem { - continue - } - if !u.Deleted { - existing++ - } - - if !includeSystem && u.IsSystem { - continue - } - } - return existing, nil -} - -func (q *FakeQuerier) GetUserLatencyInsights(_ context.Context, arg database.GetUserLatencyInsightsParams) ([]database.GetUserLatencyInsightsRow, error) { - err := validateDatabaseType(arg) - if err != nil { - return nil, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - /* - SELECT - tus.user_id, - u.username, - u.avatar_url, - array_agg(DISTINCT tus.template_id)::uuid[] AS template_ids, - COALESCE((PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY tus.median_latency_ms)), -1)::float AS workspace_connection_latency_50, - COALESCE((PERCENTILE_CONT(0.95) WITHIN GROUP (ORDER BY tus.median_latency_ms)), -1)::float AS workspace_connection_latency_95 - FROM - template_usage_stats tus - JOIN - users u - ON - u.id = tus.user_id - WHERE - tus.start_time >= @start_time::timestamptz - AND tus.end_time <= @end_time::timestamptz - AND CASE WHEN COALESCE(array_length(@template_ids::uuid[], 1), 0) > 0 THEN tus.template_id = ANY(@template_ids::uuid[]) ELSE TRUE END - GROUP BY - tus.user_id, u.username, u.avatar_url - ORDER BY - tus.user_id ASC; - */ - - latenciesByUserID := make(map[uuid.UUID][]float64) - seenTemplatesByUserID := make(map[uuid.UUID][]uuid.UUID) - for _, stat := range q.templateUsageStats { - if stat.StartTime.Before(arg.StartTime) || stat.EndTime.After(arg.EndTime) { - continue - } - if len(arg.TemplateIDs) > 0 && !slices.Contains(arg.TemplateIDs, stat.TemplateID) { - continue - } - - if stat.MedianLatencyMs.Valid { - latenciesByUserID[stat.UserID] = append(latenciesByUserID[stat.UserID], stat.MedianLatencyMs.Float64) - } - seenTemplatesByUserID[stat.UserID] = uniqueSortedUUIDs(append(seenTemplatesByUserID[stat.UserID], stat.TemplateID)) - } - - var rows []database.GetUserLatencyInsightsRow - for userID, latencies := range latenciesByUserID { - user, err := q.getUserByIDNoLock(userID) - if err != nil { - return nil, err - } - row := database.GetUserLatencyInsightsRow{ - UserID: userID, - Username: user.Username, - AvatarURL: user.AvatarURL, - TemplateIDs: seenTemplatesByUserID[userID], - WorkspaceConnectionLatency50: tryPercentileCont(latencies, 50), - WorkspaceConnectionLatency95: tryPercentileCont(latencies, 95), - } - rows = append(rows, row) - } - slices.SortFunc(rows, func(a, b database.GetUserLatencyInsightsRow) int { - return slice.Ascending(a.UserID.String(), b.UserID.String()) - }) - - return rows, nil -} - -func (q *FakeQuerier) GetUserLinkByLinkedID(_ context.Context, id string) (database.UserLink, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, link := range q.userLinks { - user, err := q.getUserByIDNoLock(link.UserID) - if err == nil && user.Deleted { - continue - } - if link.LinkedID == id { - return link, nil - } - } - return database.UserLink{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetUserLinkByUserIDLoginType(_ context.Context, params database.GetUserLinkByUserIDLoginTypeParams) (database.UserLink, error) { - if err := validateDatabaseType(params); err != nil { - return database.UserLink{}, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, link := range q.userLinks { - if link.UserID == params.UserID && link.LoginType == params.LoginType { - return link, nil - } - } - return database.UserLink{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetUserLinksByUserID(_ context.Context, userID uuid.UUID) ([]database.UserLink, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - uls := make([]database.UserLink, 0) - for _, ul := range q.userLinks { - if ul.UserID == userID { - uls = append(uls, ul) - } - } - return uls, nil -} - -func (q *FakeQuerier) GetUserNotificationPreferences(_ context.Context, userID uuid.UUID) ([]database.NotificationPreference, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - out := make([]database.NotificationPreference, 0, len(q.notificationPreferences)) - for _, np := range q.notificationPreferences { - if np.UserID != userID { - continue - } - - out = append(out, np) - } - - return out, nil -} - -func (q *FakeQuerier) GetUserStatusCounts(_ context.Context, arg database.GetUserStatusCountsParams) ([]database.GetUserStatusCountsRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - err := validateDatabaseType(arg) - if err != nil { - return nil, err - } - - result := make([]database.GetUserStatusCountsRow, 0) - for _, change := range q.userStatusChanges { - if change.ChangedAt.Before(arg.StartTime) || change.ChangedAt.After(arg.EndTime) { - continue - } - date := time.Date(change.ChangedAt.Year(), change.ChangedAt.Month(), change.ChangedAt.Day(), 0, 0, 0, 0, time.UTC) - if !slices.ContainsFunc(result, func(r database.GetUserStatusCountsRow) bool { - return r.Status == change.NewStatus && r.Date.Equal(date) - }) { - result = append(result, database.GetUserStatusCountsRow{ - Status: change.NewStatus, - Date: date, - Count: 1, - }) - } else { - for i, r := range result { - if r.Status == change.NewStatus && r.Date.Equal(date) { - result[i].Count++ - break - } - } - } - } - - return result, nil -} - -func (q *FakeQuerier) GetUserTerminalFont(ctx context.Context, userID uuid.UUID) (string, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, uc := range q.userConfigs { - if uc.UserID != userID || uc.Key != "terminal_font" { - continue - } - return uc.Value, nil - } - - return "", sql.ErrNoRows -} - -func (q *FakeQuerier) GetUserThemePreference(_ context.Context, userID uuid.UUID) (string, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, uc := range q.userConfigs { - if uc.UserID != userID || uc.Key != "theme_preference" { - continue - } - return uc.Value, nil - } - - return "", sql.ErrNoRows -} - -func (q *FakeQuerier) GetUserWorkspaceBuildParameters(_ context.Context, params database.GetUserWorkspaceBuildParametersParams) ([]database.GetUserWorkspaceBuildParametersRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - userWorkspaceIDs := make(map[uuid.UUID]struct{}) - for _, ws := range q.workspaces { - if ws.OwnerID != params.OwnerID { - continue - } - if ws.TemplateID != params.TemplateID { - continue - } - userWorkspaceIDs[ws.ID] = struct{}{} - } - - userWorkspaceBuilds := make(map[uuid.UUID]struct{}) - for _, wb := range q.workspaceBuilds { - if _, ok := userWorkspaceIDs[wb.WorkspaceID]; !ok { - continue - } - userWorkspaceBuilds[wb.ID] = struct{}{} - } - - templateVersions := make(map[uuid.UUID]struct{}) - for _, tv := range q.templateVersions { - if tv.TemplateID.UUID != params.TemplateID { - continue - } - templateVersions[tv.ID] = struct{}{} - } - - tvps := make(map[string]struct{}) - for _, tvp := range q.templateVersionParameters { - if _, ok := templateVersions[tvp.TemplateVersionID]; !ok { - continue - } - - if _, ok := tvps[tvp.Name]; !ok && !tvp.Ephemeral { - tvps[tvp.Name] = struct{}{} - } - } - - userWorkspaceBuildParameters := make(map[string]database.GetUserWorkspaceBuildParametersRow) - for _, wbp := range q.workspaceBuildParameters { - if _, ok := userWorkspaceBuilds[wbp.WorkspaceBuildID]; !ok { - continue - } - if _, ok := tvps[wbp.Name]; !ok { - continue - } - userWorkspaceBuildParameters[wbp.Name] = database.GetUserWorkspaceBuildParametersRow{ - Name: wbp.Name, - Value: wbp.Value, - } - } - - return maps.Values(userWorkspaceBuildParameters), nil -} - -func (q *FakeQuerier) GetUsers(_ context.Context, params database.GetUsersParams) ([]database.GetUsersRow, error) { - if err := validateDatabaseType(params); err != nil { - return nil, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - // Avoid side-effect of sorting. - users := make([]database.User, len(q.users)) - copy(users, q.users) - - // Database orders by username - slices.SortFunc(users, func(a, b database.User) int { - return slice.Ascending(strings.ToLower(a.Username), strings.ToLower(b.Username)) - }) - - // Filter out deleted since they should never be returned.. - tmp := make([]database.User, 0, len(users)) - for _, user := range users { - if !user.Deleted { - tmp = append(tmp, user) - } - } - users = tmp - - if params.AfterID != uuid.Nil { - found := false - for i, v := range users { - if v.ID == params.AfterID { - // We want to return all users after index i. - users = users[i+1:] - found = true - break - } - } - - // If no users after the time, then we return an empty list. - if !found { - return []database.GetUsersRow{}, nil - } - } - - if params.Search != "" { - tmp := make([]database.User, 0, len(users)) - for i, user := range users { - if strings.Contains(strings.ToLower(user.Email), strings.ToLower(params.Search)) { - tmp = append(tmp, users[i]) - } else if strings.Contains(strings.ToLower(user.Username), strings.ToLower(params.Search)) { - tmp = append(tmp, users[i]) - } - } - users = tmp - } - - if len(params.Status) > 0 { - usersFilteredByStatus := make([]database.User, 0, len(users)) - for i, user := range users { - if slice.ContainsCompare(params.Status, user.Status, func(a, b database.UserStatus) bool { - return strings.EqualFold(string(a), string(b)) - }) { - usersFilteredByStatus = append(usersFilteredByStatus, users[i]) - } - } - users = usersFilteredByStatus - } - - if len(params.RbacRole) > 0 && !slice.Contains(params.RbacRole, rbac.RoleMember().String()) { - usersFilteredByRole := make([]database.User, 0, len(users)) - for i, user := range users { - if slice.OverlapCompare(params.RbacRole, user.RBACRoles, strings.EqualFold) { - usersFilteredByRole = append(usersFilteredByRole, users[i]) - } - } - users = usersFilteredByRole - } - - if len(params.LoginType) > 0 { - usersFilteredByLoginType := make([]database.User, 0, len(users)) - for i, user := range users { - if slice.ContainsCompare(params.LoginType, user.LoginType, func(a, b database.LoginType) bool { - return strings.EqualFold(string(a), string(b)) - }) { - usersFilteredByLoginType = append(usersFilteredByLoginType, users[i]) - } - } - users = usersFilteredByLoginType - } - - if !params.CreatedBefore.IsZero() { - usersFilteredByCreatedAt := make([]database.User, 0, len(users)) - for i, user := range users { - if user.CreatedAt.Before(params.CreatedBefore) { - usersFilteredByCreatedAt = append(usersFilteredByCreatedAt, users[i]) - } - } - users = usersFilteredByCreatedAt - } - - if !params.CreatedAfter.IsZero() { - usersFilteredByCreatedAt := make([]database.User, 0, len(users)) - for i, user := range users { - if user.CreatedAt.After(params.CreatedAfter) { - usersFilteredByCreatedAt = append(usersFilteredByCreatedAt, users[i]) - } - } - users = usersFilteredByCreatedAt - } - - if !params.LastSeenBefore.IsZero() { - usersFilteredByLastSeen := make([]database.User, 0, len(users)) - for i, user := range users { - if user.LastSeenAt.Before(params.LastSeenBefore) { - usersFilteredByLastSeen = append(usersFilteredByLastSeen, users[i]) - } - } - users = usersFilteredByLastSeen - } - - if !params.LastSeenAfter.IsZero() { - usersFilteredByLastSeen := make([]database.User, 0, len(users)) - for i, user := range users { - if user.LastSeenAt.After(params.LastSeenAfter) { - usersFilteredByLastSeen = append(usersFilteredByLastSeen, users[i]) - } - } - users = usersFilteredByLastSeen - } - - if !params.IncludeSystem { - users = slices.DeleteFunc(users, func(u database.User) bool { - return u.IsSystem - }) - } - - if params.GithubComUserID != 0 { - usersFilteredByGithubComUserID := make([]database.User, 0, len(users)) - for i, user := range users { - if user.GithubComUserID.Int64 == params.GithubComUserID { - usersFilteredByGithubComUserID = append(usersFilteredByGithubComUserID, users[i]) - } - } - users = usersFilteredByGithubComUserID - } - - beforePageCount := len(users) - - if params.OffsetOpt > 0 { - if int(params.OffsetOpt) > len(users)-1 { - return []database.GetUsersRow{}, nil - } - users = users[params.OffsetOpt:] - } - - if params.LimitOpt > 0 { - if int(params.LimitOpt) > len(users) { - // #nosec G115 - Safe conversion as users slice length is expected to be within int32 range - params.LimitOpt = int32(len(users)) - } - users = users[:params.LimitOpt] - } - - return convertUsers(users, int64(beforePageCount)), nil -} - -func (q *FakeQuerier) GetUsersByIDs(_ context.Context, ids []uuid.UUID) ([]database.User, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - users := make([]database.User, 0) - for _, user := range q.users { - for _, id := range ids { - if user.ID != id { - continue - } - users = append(users, user) - } - } - return users, nil -} - -func (q *FakeQuerier) GetWebpushSubscriptionsByUserID(_ context.Context, userID uuid.UUID) ([]database.WebpushSubscription, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - out := make([]database.WebpushSubscription, 0) - for _, subscription := range q.webpushSubscriptions { - if subscription.UserID == userID { - out = append(out, subscription) - } - } - - return out, nil -} - -func (q *FakeQuerier) GetWebpushVAPIDKeys(_ context.Context) (database.GetWebpushVAPIDKeysRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - if q.webpushVAPIDPublicKey == "" && q.webpushVAPIDPrivateKey == "" { - return database.GetWebpushVAPIDKeysRow{}, sql.ErrNoRows - } - - return database.GetWebpushVAPIDKeysRow{ - VapidPublicKey: q.webpushVAPIDPublicKey, - VapidPrivateKey: q.webpushVAPIDPrivateKey, - }, nil -} - -func (q *FakeQuerier) GetWorkspaceAgentAndLatestBuildByAuthToken(_ context.Context, authToken uuid.UUID) (database.GetWorkspaceAgentAndLatestBuildByAuthTokenRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - rows := []database.GetWorkspaceAgentAndLatestBuildByAuthTokenRow{} - // We want to return the latest build number for each workspace - latestBuildNumber := make(map[uuid.UUID]int32) - - for _, agt := range q.workspaceAgents { - if agt.Deleted { - continue - } - - // get the related workspace and user - for _, res := range q.workspaceResources { - if agt.ResourceID != res.ID { - continue - } - for _, build := range q.workspaceBuilds { - if build.JobID != res.JobID { - continue - } - for _, ws := range q.workspaces { - if build.WorkspaceID != ws.ID { - continue - } - if ws.Deleted { - continue - } - row := database.GetWorkspaceAgentAndLatestBuildByAuthTokenRow{ - WorkspaceTable: database.WorkspaceTable{ - ID: ws.ID, - TemplateID: ws.TemplateID, - }, - WorkspaceAgent: agt, - WorkspaceBuild: build, - } - usr, err := q.getUserByIDNoLock(ws.OwnerID) - if err != nil { - return database.GetWorkspaceAgentAndLatestBuildByAuthTokenRow{}, sql.ErrNoRows - } - row.WorkspaceTable.OwnerID = usr.ID - - // Keep track of the latest build number - rows = append(rows, row) - if build.BuildNumber > latestBuildNumber[ws.ID] { - latestBuildNumber[ws.ID] = build.BuildNumber - } - } - } - } - } - - for i := range rows { - if rows[i].WorkspaceAgent.AuthToken != authToken { - continue - } - - if rows[i].WorkspaceBuild.BuildNumber != latestBuildNumber[rows[i].WorkspaceTable.ID] { - continue - } - - return rows[i], nil - } - - return database.GetWorkspaceAgentAndLatestBuildByAuthTokenRow{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetWorkspaceAgentByID(ctx context.Context, id uuid.UUID) (database.WorkspaceAgent, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - return q.getWorkspaceAgentByIDNoLock(ctx, id) -} - -func (q *FakeQuerier) GetWorkspaceAgentByInstanceID(_ context.Context, instanceID string) (database.WorkspaceAgent, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - // The schema sorts this by created at, so we iterate the array backwards. - for i := len(q.workspaceAgents) - 1; i >= 0; i-- { - agent := q.workspaceAgents[i] - if !agent.Deleted && agent.AuthInstanceID.Valid && agent.AuthInstanceID.String == instanceID { - return agent, nil - } - } - return database.WorkspaceAgent{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetWorkspaceAgentDevcontainersByAgentID(_ context.Context, workspaceAgentID uuid.UUID) ([]database.WorkspaceAgentDevcontainer, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - devcontainers := make([]database.WorkspaceAgentDevcontainer, 0) - for _, dc := range q.workspaceAgentDevcontainers { - if dc.WorkspaceAgentID == workspaceAgentID { - devcontainers = append(devcontainers, dc) - } - } - if len(devcontainers) == 0 { - return nil, sql.ErrNoRows - } - return devcontainers, nil -} - -func (q *FakeQuerier) GetWorkspaceAgentLifecycleStateByID(ctx context.Context, id uuid.UUID) (database.GetWorkspaceAgentLifecycleStateByIDRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - agent, err := q.getWorkspaceAgentByIDNoLock(ctx, id) - if err != nil { - return database.GetWorkspaceAgentLifecycleStateByIDRow{}, err - } - return database.GetWorkspaceAgentLifecycleStateByIDRow{ - LifecycleState: agent.LifecycleState, - StartedAt: agent.StartedAt, - ReadyAt: agent.ReadyAt, - }, nil -} - -func (q *FakeQuerier) GetWorkspaceAgentLogSourcesByAgentIDs(_ context.Context, ids []uuid.UUID) ([]database.WorkspaceAgentLogSource, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - logSources := make([]database.WorkspaceAgentLogSource, 0) - for _, logSource := range q.workspaceAgentLogSources { - for _, id := range ids { - if logSource.WorkspaceAgentID == id { - logSources = append(logSources, logSource) - break - } - } - } - return logSources, nil -} - -func (q *FakeQuerier) GetWorkspaceAgentLogsAfter(_ context.Context, arg database.GetWorkspaceAgentLogsAfterParams) ([]database.WorkspaceAgentLog, error) { - if err := validateDatabaseType(arg); err != nil { - return nil, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - logs := []database.WorkspaceAgentLog{} - for _, log := range q.workspaceAgentLogs { - if log.AgentID != arg.AgentID { - continue - } - if arg.CreatedAfter != 0 && log.ID <= arg.CreatedAfter { - continue - } - logs = append(logs, log) - } - return logs, nil -} - -func (q *FakeQuerier) GetWorkspaceAgentMetadata(_ context.Context, arg database.GetWorkspaceAgentMetadataParams) ([]database.WorkspaceAgentMetadatum, error) { - if err := validateDatabaseType(arg); err != nil { - return nil, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - metadata := make([]database.WorkspaceAgentMetadatum, 0) - for _, m := range q.workspaceAgentMetadata { - if m.WorkspaceAgentID == arg.WorkspaceAgentID { - if len(arg.Keys) > 0 && !slices.Contains(arg.Keys, m.Key) { - continue - } - metadata = append(metadata, m) - } - } - return metadata, nil -} - -func (q *FakeQuerier) GetWorkspaceAgentPortShare(_ context.Context, arg database.GetWorkspaceAgentPortShareParams) (database.WorkspaceAgentPortShare, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.WorkspaceAgentPortShare{}, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, share := range q.workspaceAgentPortShares { - if share.WorkspaceID == arg.WorkspaceID && share.AgentName == arg.AgentName && share.Port == arg.Port { - return share, nil - } - } - - return database.WorkspaceAgentPortShare{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetWorkspaceAgentScriptTimingsByBuildID(ctx context.Context, id uuid.UUID) ([]database.GetWorkspaceAgentScriptTimingsByBuildIDRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - build, err := q.getWorkspaceBuildByIDNoLock(ctx, id) - if err != nil { - return nil, xerrors.Errorf("get build: %w", err) - } - - resources, err := q.getWorkspaceResourcesByJobIDNoLock(ctx, build.JobID) - if err != nil { - return nil, xerrors.Errorf("get resources: %w", err) - } - resourceIDs := make([]uuid.UUID, 0, len(resources)) - for _, res := range resources { - resourceIDs = append(resourceIDs, res.ID) - } - - agents, err := q.getWorkspaceAgentsByResourceIDsNoLock(ctx, resourceIDs) - if err != nil { - return nil, xerrors.Errorf("get agents: %w", err) - } - agentIDs := make([]uuid.UUID, 0, len(agents)) - for _, agent := range agents { - agentIDs = append(agentIDs, agent.ID) - } - - scripts, err := q.getWorkspaceAgentScriptsByAgentIDsNoLock(agentIDs) - if err != nil { - return nil, xerrors.Errorf("get scripts: %w", err) - } - scriptIDs := make([]uuid.UUID, 0, len(scripts)) - for _, script := range scripts { - scriptIDs = append(scriptIDs, script.ID) - } - - rows := []database.GetWorkspaceAgentScriptTimingsByBuildIDRow{} - for _, t := range q.workspaceAgentScriptTimings { - if !slice.Contains(scriptIDs, t.ScriptID) { - continue - } - - var script database.WorkspaceAgentScript - for _, s := range scripts { - if s.ID == t.ScriptID { - script = s - break - } - } - if script.ID == uuid.Nil { - return nil, xerrors.Errorf("script with ID %s not found", t.ScriptID) - } - - var agent database.WorkspaceAgent - for _, a := range agents { - if a.ID == script.WorkspaceAgentID { - agent = a - break - } - } - if agent.ID == uuid.Nil { - return nil, xerrors.Errorf("agent with ID %s not found", t.ScriptID) - } - - rows = append(rows, database.GetWorkspaceAgentScriptTimingsByBuildIDRow{ - ScriptID: t.ScriptID, - StartedAt: t.StartedAt, - EndedAt: t.EndedAt, - ExitCode: t.ExitCode, - Stage: t.Stage, - Status: t.Status, - DisplayName: script.DisplayName, - WorkspaceAgentID: agent.ID, - WorkspaceAgentName: agent.Name, - }) - } - - // We want to only return the first script run for each Script ID. - slices.SortFunc(rows, func(a, b database.GetWorkspaceAgentScriptTimingsByBuildIDRow) int { - return a.StartedAt.Compare(b.StartedAt) - }) - rows = slices.CompactFunc(rows, func(e1, e2 database.GetWorkspaceAgentScriptTimingsByBuildIDRow) bool { - return e1.ScriptID == e2.ScriptID - }) - - return rows, nil -} - -func (q *FakeQuerier) GetWorkspaceAgentScriptsByAgentIDs(_ context.Context, ids []uuid.UUID) ([]database.WorkspaceAgentScript, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - return q.getWorkspaceAgentScriptsByAgentIDsNoLock(ids) -} - -func (q *FakeQuerier) GetWorkspaceAgentStats(_ context.Context, createdAfter time.Time) ([]database.GetWorkspaceAgentStatsRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - agentStatsCreatedAfter := make([]database.WorkspaceAgentStat, 0) - for _, agentStat := range q.workspaceAgentStats { - if agentStat.CreatedAt.After(createdAfter) || agentStat.CreatedAt.Equal(createdAfter) { - agentStatsCreatedAfter = append(agentStatsCreatedAfter, agentStat) - } - } - - latestAgentStats := map[uuid.UUID]database.WorkspaceAgentStat{} - for _, agentStat := range q.workspaceAgentStats { - if agentStat.CreatedAt.After(createdAfter) || agentStat.CreatedAt.Equal(createdAfter) { - latestAgentStats[agentStat.AgentID] = agentStat - } - } - - statByAgent := map[uuid.UUID]database.GetWorkspaceAgentStatsRow{} - for agentID, agentStat := range latestAgentStats { - stat := statByAgent[agentID] - stat.AgentID = agentStat.AgentID - stat.TemplateID = agentStat.TemplateID - stat.UserID = agentStat.UserID - stat.WorkspaceID = agentStat.WorkspaceID - stat.SessionCountVSCode += agentStat.SessionCountVSCode - stat.SessionCountJetBrains += agentStat.SessionCountJetBrains - stat.SessionCountReconnectingPTY += agentStat.SessionCountReconnectingPTY - stat.SessionCountSSH += agentStat.SessionCountSSH - statByAgent[stat.AgentID] = stat - } - - latenciesByAgent := map[uuid.UUID][]float64{} - minimumDateByAgent := map[uuid.UUID]time.Time{} - for _, agentStat := range agentStatsCreatedAfter { - if agentStat.ConnectionMedianLatencyMS <= 0 { - continue - } - stat := statByAgent[agentStat.AgentID] - minimumDate := minimumDateByAgent[agentStat.AgentID] - if agentStat.CreatedAt.Before(minimumDate) || minimumDate.IsZero() { - minimumDateByAgent[agentStat.AgentID] = agentStat.CreatedAt - } - stat.WorkspaceRxBytes += agentStat.RxBytes - stat.WorkspaceTxBytes += agentStat.TxBytes - statByAgent[agentStat.AgentID] = stat - latenciesByAgent[agentStat.AgentID] = append(latenciesByAgent[agentStat.AgentID], agentStat.ConnectionMedianLatencyMS) - } - - for _, stat := range statByAgent { - stat.AggregatedFrom = minimumDateByAgent[stat.AgentID] - statByAgent[stat.AgentID] = stat - - latencies, ok := latenciesByAgent[stat.AgentID] - if !ok { - continue - } - stat.WorkspaceConnectionLatency50 = tryPercentileCont(latencies, 50) - stat.WorkspaceConnectionLatency95 = tryPercentileCont(latencies, 95) - statByAgent[stat.AgentID] = stat - } - - stats := make([]database.GetWorkspaceAgentStatsRow, 0, len(statByAgent)) - for _, agent := range statByAgent { - stats = append(stats, agent) - } - return stats, nil -} - -func (q *FakeQuerier) GetWorkspaceAgentStatsAndLabels(ctx context.Context, createdAfter time.Time) ([]database.GetWorkspaceAgentStatsAndLabelsRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - agentStatsCreatedAfter := make([]database.WorkspaceAgentStat, 0) - latestAgentStats := map[uuid.UUID]database.WorkspaceAgentStat{} - - for _, agentStat := range q.workspaceAgentStats { - if agentStat.CreatedAt.After(createdAfter) { - agentStatsCreatedAfter = append(agentStatsCreatedAfter, agentStat) - latestAgentStats[agentStat.AgentID] = agentStat - } - } - - statByAgent := map[uuid.UUID]database.GetWorkspaceAgentStatsAndLabelsRow{} - - // Session and connection metrics - for _, agentStat := range latestAgentStats { - stat := statByAgent[agentStat.AgentID] - stat.SessionCountVSCode += agentStat.SessionCountVSCode - stat.SessionCountJetBrains += agentStat.SessionCountJetBrains - stat.SessionCountReconnectingPTY += agentStat.SessionCountReconnectingPTY - stat.SessionCountSSH += agentStat.SessionCountSSH - stat.ConnectionCount += agentStat.ConnectionCount - if agentStat.ConnectionMedianLatencyMS >= 0 && stat.ConnectionMedianLatencyMS < agentStat.ConnectionMedianLatencyMS { - stat.ConnectionMedianLatencyMS = agentStat.ConnectionMedianLatencyMS - } - statByAgent[agentStat.AgentID] = stat - } - - // Tx, Rx metrics - for _, agentStat := range agentStatsCreatedAfter { - stat := statByAgent[agentStat.AgentID] - stat.RxBytes += agentStat.RxBytes - stat.TxBytes += agentStat.TxBytes - statByAgent[agentStat.AgentID] = stat - } - - // Labels - for _, agentStat := range agentStatsCreatedAfter { - stat := statByAgent[agentStat.AgentID] - - user, err := q.getUserByIDNoLock(agentStat.UserID) - if err != nil { - return nil, err - } - - stat.Username = user.Username - - workspace, err := q.getWorkspaceByIDNoLock(ctx, agentStat.WorkspaceID) - if err != nil { - return nil, err - } - stat.WorkspaceName = workspace.Name - - agent, err := q.getWorkspaceAgentByIDNoLock(ctx, agentStat.AgentID) - if err != nil { - return nil, err - } - stat.AgentName = agent.Name - - statByAgent[agentStat.AgentID] = stat - } - - stats := make([]database.GetWorkspaceAgentStatsAndLabelsRow, 0, len(statByAgent)) - for _, agent := range statByAgent { - stats = append(stats, agent) - } - return stats, nil -} - -func (q *FakeQuerier) GetWorkspaceAgentUsageStats(_ context.Context, createdAt time.Time) ([]database.GetWorkspaceAgentUsageStatsRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - type agentStatsKey struct { - UserID uuid.UUID - AgentID uuid.UUID - WorkspaceID uuid.UUID - TemplateID uuid.UUID - } - - type minuteStatsKey struct { - agentStatsKey - MinuteBucket time.Time - } - - latestAgentStats := map[agentStatsKey]database.GetWorkspaceAgentUsageStatsRow{} - latestAgentLatencies := map[agentStatsKey][]float64{} - for _, agentStat := range q.workspaceAgentStats { - key := agentStatsKey{ - UserID: agentStat.UserID, - AgentID: agentStat.AgentID, - WorkspaceID: agentStat.WorkspaceID, - TemplateID: agentStat.TemplateID, - } - if agentStat.CreatedAt.After(createdAt) { - val, ok := latestAgentStats[key] - if ok { - val.WorkspaceRxBytes += agentStat.RxBytes - val.WorkspaceTxBytes += agentStat.TxBytes - latestAgentStats[key] = val - } else { - latestAgentStats[key] = database.GetWorkspaceAgentUsageStatsRow{ - UserID: agentStat.UserID, - AgentID: agentStat.AgentID, - WorkspaceID: agentStat.WorkspaceID, - TemplateID: agentStat.TemplateID, - AggregatedFrom: createdAt, - WorkspaceRxBytes: agentStat.RxBytes, - WorkspaceTxBytes: agentStat.TxBytes, - } - } - - latencies, ok := latestAgentLatencies[key] - if !ok { - latestAgentLatencies[key] = []float64{agentStat.ConnectionMedianLatencyMS} - } else { - latestAgentLatencies[key] = append(latencies, agentStat.ConnectionMedianLatencyMS) - } - } - } - - for key, latencies := range latestAgentLatencies { - val, ok := latestAgentStats[key] - if ok { - val.WorkspaceConnectionLatency50 = tryPercentileCont(latencies, 50) - val.WorkspaceConnectionLatency95 = tryPercentileCont(latencies, 95) - } - latestAgentStats[key] = val - } - - type bucketRow struct { - database.GetWorkspaceAgentUsageStatsRow - MinuteBucket time.Time - } - - minuteBuckets := make(map[minuteStatsKey]bucketRow) - for _, agentStat := range q.workspaceAgentStats { - if agentStat.Usage && - (agentStat.CreatedAt.After(createdAt) || agentStat.CreatedAt.Equal(createdAt)) && - agentStat.CreatedAt.Before(time.Now().Truncate(time.Minute)) { - key := minuteStatsKey{ - agentStatsKey: agentStatsKey{ - UserID: agentStat.UserID, - AgentID: agentStat.AgentID, - WorkspaceID: agentStat.WorkspaceID, - TemplateID: agentStat.TemplateID, - }, - MinuteBucket: agentStat.CreatedAt.Truncate(time.Minute), - } - val, ok := minuteBuckets[key] - if ok { - val.SessionCountVSCode += agentStat.SessionCountVSCode - val.SessionCountJetBrains += agentStat.SessionCountJetBrains - val.SessionCountReconnectingPTY += agentStat.SessionCountReconnectingPTY - val.SessionCountSSH += agentStat.SessionCountSSH - minuteBuckets[key] = val - } else { - minuteBuckets[key] = bucketRow{ - GetWorkspaceAgentUsageStatsRow: database.GetWorkspaceAgentUsageStatsRow{ - UserID: agentStat.UserID, - AgentID: agentStat.AgentID, - WorkspaceID: agentStat.WorkspaceID, - TemplateID: agentStat.TemplateID, - SessionCountVSCode: agentStat.SessionCountVSCode, - SessionCountSSH: agentStat.SessionCountSSH, - SessionCountJetBrains: agentStat.SessionCountJetBrains, - SessionCountReconnectingPTY: agentStat.SessionCountReconnectingPTY, - }, - MinuteBucket: agentStat.CreatedAt.Truncate(time.Minute), - } - } - } - } - - // Get the latest minute bucket for each agent. - latestBuckets := make(map[uuid.UUID]bucketRow) - for key, bucket := range minuteBuckets { - latest, ok := latestBuckets[key.AgentID] - if !ok || key.MinuteBucket.After(latest.MinuteBucket) { - latestBuckets[key.AgentID] = bucket - } - } - - for key, stat := range latestAgentStats { - bucket, ok := latestBuckets[stat.AgentID] - if ok { - stat.SessionCountVSCode = bucket.SessionCountVSCode - stat.SessionCountJetBrains = bucket.SessionCountJetBrains - stat.SessionCountReconnectingPTY = bucket.SessionCountReconnectingPTY - stat.SessionCountSSH = bucket.SessionCountSSH - } - latestAgentStats[key] = stat - } - return maps.Values(latestAgentStats), nil -} - -func (q *FakeQuerier) GetWorkspaceAgentUsageStatsAndLabels(_ context.Context, createdAt time.Time) ([]database.GetWorkspaceAgentUsageStatsAndLabelsRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - type statsKey struct { - AgentID uuid.UUID - UserID uuid.UUID - WorkspaceID uuid.UUID - } - - latestAgentStats := map[statsKey]database.WorkspaceAgentStat{} - maxConnMedianLatency := 0.0 - for _, agentStat := range q.workspaceAgentStats { - key := statsKey{ - AgentID: agentStat.AgentID, - UserID: agentStat.UserID, - WorkspaceID: agentStat.WorkspaceID, - } - // WHERE workspace_agent_stats.created_at > $1 - // GROUP BY user_id, agent_id, workspace_id - if agentStat.CreatedAt.After(createdAt) { - val, ok := latestAgentStats[key] - if !ok { - val = agentStat - val.SessionCountJetBrains = 0 - val.SessionCountReconnectingPTY = 0 - val.SessionCountSSH = 0 - val.SessionCountVSCode = 0 - } else { - val.RxBytes += agentStat.RxBytes - val.TxBytes += agentStat.TxBytes - } - if agentStat.ConnectionMedianLatencyMS > maxConnMedianLatency { - val.ConnectionMedianLatencyMS = agentStat.ConnectionMedianLatencyMS - } - latestAgentStats[key] = val - } - // WHERE usage = true AND created_at > now() - '1 minute'::interval - // GROUP BY user_id, agent_id, workspace_id - if agentStat.Usage && agentStat.CreatedAt.After(dbtime.Now().Add(-time.Minute)) { - val, ok := latestAgentStats[key] - if !ok { - latestAgentStats[key] = agentStat - } else { - val.SessionCountVSCode += agentStat.SessionCountVSCode - val.SessionCountJetBrains += agentStat.SessionCountJetBrains - val.SessionCountReconnectingPTY += agentStat.SessionCountReconnectingPTY - val.SessionCountSSH += agentStat.SessionCountSSH - val.ConnectionCount += agentStat.ConnectionCount - latestAgentStats[key] = val - } - } - } - - stats := make([]database.GetWorkspaceAgentUsageStatsAndLabelsRow, 0, len(latestAgentStats)) - for key, agentStat := range latestAgentStats { - user, err := q.getUserByIDNoLock(key.UserID) - if err != nil { - return nil, err - } - workspace, err := q.getWorkspaceByIDNoLock(context.Background(), key.WorkspaceID) - if err != nil { - return nil, err - } - agent, err := q.getWorkspaceAgentByIDNoLock(context.Background(), key.AgentID) - if err != nil { - return nil, err - } - stats = append(stats, database.GetWorkspaceAgentUsageStatsAndLabelsRow{ - Username: user.Username, - AgentName: agent.Name, - WorkspaceName: workspace.Name, - RxBytes: agentStat.RxBytes, - TxBytes: agentStat.TxBytes, - SessionCountVSCode: agentStat.SessionCountVSCode, - SessionCountSSH: agentStat.SessionCountSSH, - SessionCountJetBrains: agentStat.SessionCountJetBrains, - SessionCountReconnectingPTY: agentStat.SessionCountReconnectingPTY, - ConnectionCount: agentStat.ConnectionCount, - ConnectionMedianLatencyMS: agentStat.ConnectionMedianLatencyMS, - }) - } - return stats, nil -} - -func (q *FakeQuerier) GetWorkspaceAgentsByParentID(_ context.Context, parentID uuid.UUID) ([]database.WorkspaceAgent, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - workspaceAgents := make([]database.WorkspaceAgent, 0) - for _, agent := range q.workspaceAgents { - if !agent.ParentID.Valid || agent.ParentID.UUID != parentID || agent.Deleted { - continue - } - - workspaceAgents = append(workspaceAgents, agent) - } - - return workspaceAgents, nil -} - -func (q *FakeQuerier) GetWorkspaceAgentsByResourceIDs(ctx context.Context, resourceIDs []uuid.UUID) ([]database.WorkspaceAgent, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - return q.getWorkspaceAgentsByResourceIDsNoLock(ctx, resourceIDs) -} - -func (q *FakeQuerier) GetWorkspaceAgentsByWorkspaceAndBuildNumber(ctx context.Context, arg database.GetWorkspaceAgentsByWorkspaceAndBuildNumberParams) ([]database.WorkspaceAgent, error) { - err := validateDatabaseType(arg) - if err != nil { - return nil, err - } - - build, err := q.GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx, database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams(arg)) - if err != nil { - return nil, err - } - - resources, err := q.getWorkspaceResourcesByJobIDNoLock(ctx, build.JobID) - if err != nil { - return nil, err - } - - var resourceIDs []uuid.UUID - for _, resource := range resources { - resourceIDs = append(resourceIDs, resource.ID) - } - - return q.GetWorkspaceAgentsByResourceIDs(ctx, resourceIDs) -} - -func (q *FakeQuerier) GetWorkspaceAgentsCreatedAfter(_ context.Context, after time.Time) ([]database.WorkspaceAgent, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - workspaceAgents := make([]database.WorkspaceAgent, 0) - for _, agent := range q.workspaceAgents { - if agent.Deleted { - continue - } - if agent.CreatedAt.After(after) { - workspaceAgents = append(workspaceAgents, agent) - } - } - return workspaceAgents, nil -} - -func (q *FakeQuerier) GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) ([]database.WorkspaceAgent, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - // Get latest build for workspace. - workspaceBuild, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, workspaceID) - if err != nil { - return nil, xerrors.Errorf("get latest workspace build: %w", err) - } - - // Get resources for build. - resources, err := q.getWorkspaceResourcesByJobIDNoLock(ctx, workspaceBuild.JobID) - if err != nil { - return nil, xerrors.Errorf("get workspace resources: %w", err) - } - if len(resources) == 0 { - return []database.WorkspaceAgent{}, nil - } - - resourceIDs := make([]uuid.UUID, len(resources)) - for i, resource := range resources { - resourceIDs[i] = resource.ID - } - - agents, err := q.getWorkspaceAgentsByResourceIDsNoLock(ctx, resourceIDs) - if err != nil { - return nil, xerrors.Errorf("get workspace agents: %w", err) - } - - return agents, nil -} - -func (q *FakeQuerier) GetWorkspaceAppByAgentIDAndSlug(ctx context.Context, arg database.GetWorkspaceAppByAgentIDAndSlugParams) (database.WorkspaceApp, error) { - if err := validateDatabaseType(arg); err != nil { - return database.WorkspaceApp{}, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - return q.getWorkspaceAppByAgentIDAndSlugNoLock(ctx, arg) -} - -func (q *FakeQuerier) GetWorkspaceAppStatusesByAppIDs(_ context.Context, ids []uuid.UUID) ([]database.WorkspaceAppStatus, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - statuses := make([]database.WorkspaceAppStatus, 0) - for _, status := range q.workspaceAppStatuses { - for _, id := range ids { - if status.AppID == id { - statuses = append(statuses, status) - } - } - } - return statuses, nil -} - -func (q *FakeQuerier) GetWorkspaceAppsByAgentID(_ context.Context, id uuid.UUID) ([]database.WorkspaceApp, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - apps := make([]database.WorkspaceApp, 0) - for _, app := range q.workspaceApps { - if app.AgentID == id { - apps = append(apps, app) - } - } - return apps, nil -} - -func (q *FakeQuerier) GetWorkspaceAppsByAgentIDs(_ context.Context, ids []uuid.UUID) ([]database.WorkspaceApp, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - apps := make([]database.WorkspaceApp, 0) - for _, app := range q.workspaceApps { - for _, id := range ids { - if app.AgentID == id { - apps = append(apps, app) - break - } - } - } - return apps, nil -} - -func (q *FakeQuerier) GetWorkspaceAppsCreatedAfter(_ context.Context, after time.Time) ([]database.WorkspaceApp, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - apps := make([]database.WorkspaceApp, 0) - for _, app := range q.workspaceApps { - if app.CreatedAt.After(after) { - apps = append(apps, app) - } - } - return apps, nil -} - -func (q *FakeQuerier) GetWorkspaceBuildByID(ctx context.Context, id uuid.UUID) (database.WorkspaceBuild, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - return q.getWorkspaceBuildByIDNoLock(ctx, id) -} - -func (q *FakeQuerier) GetWorkspaceBuildByJobID(_ context.Context, jobID uuid.UUID) (database.WorkspaceBuild, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, build := range q.workspaceBuilds { - if build.JobID == jobID { - return q.workspaceBuildWithUserNoLock(build), nil - } - } - return database.WorkspaceBuild{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetWorkspaceBuildByWorkspaceIDAndBuildNumber(_ context.Context, arg database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams) (database.WorkspaceBuild, error) { - if err := validateDatabaseType(arg); err != nil { - return database.WorkspaceBuild{}, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, workspaceBuild := range q.workspaceBuilds { - if workspaceBuild.WorkspaceID != arg.WorkspaceID { - continue - } - if workspaceBuild.BuildNumber != arg.BuildNumber { - continue - } - return q.workspaceBuildWithUserNoLock(workspaceBuild), nil - } - return database.WorkspaceBuild{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetWorkspaceBuildParameters(_ context.Context, workspaceBuildID uuid.UUID) ([]database.WorkspaceBuildParameter, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - return q.getWorkspaceBuildParametersNoLock(workspaceBuildID) -} - -func (q *FakeQuerier) GetWorkspaceBuildParametersByBuildIDs(ctx context.Context, workspaceBuildIDs []uuid.UUID) ([]database.WorkspaceBuildParameter, error) { - // No auth filter. - return q.GetAuthorizedWorkspaceBuildParametersByBuildIDs(ctx, workspaceBuildIDs, nil) -} - -func (q *FakeQuerier) GetWorkspaceBuildStatsByTemplates(ctx context.Context, since time.Time) ([]database.GetWorkspaceBuildStatsByTemplatesRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - templateStats := map[uuid.UUID]database.GetWorkspaceBuildStatsByTemplatesRow{} - for _, wb := range q.workspaceBuilds { - job, err := q.getProvisionerJobByIDNoLock(ctx, wb.JobID) - if err != nil { - return nil, xerrors.Errorf("get provisioner job by ID: %w", err) - } - - if !job.CompletedAt.Valid { - continue - } - - if wb.CreatedAt.Before(since) { - continue - } - - w, err := q.getWorkspaceByIDNoLock(ctx, wb.WorkspaceID) - if err != nil { - return nil, xerrors.Errorf("get workspace by ID: %w", err) - } - - if _, ok := templateStats[w.TemplateID]; !ok { - t, err := q.getTemplateByIDNoLock(ctx, w.TemplateID) - if err != nil { - return nil, xerrors.Errorf("get template by ID: %w", err) - } - - templateStats[w.TemplateID] = database.GetWorkspaceBuildStatsByTemplatesRow{ - TemplateID: w.TemplateID, - TemplateName: t.Name, - TemplateDisplayName: t.DisplayName, - TemplateOrganizationID: w.OrganizationID, - } - } - - s := templateStats[w.TemplateID] - s.TotalBuilds++ - if job.JobStatus == database.ProvisionerJobStatusFailed { - s.FailedBuilds++ - } - templateStats[w.TemplateID] = s - } - - rows := make([]database.GetWorkspaceBuildStatsByTemplatesRow, 0, len(templateStats)) - for _, ts := range templateStats { - rows = append(rows, ts) - } - - sort.Slice(rows, func(i, j int) bool { - return rows[i].TemplateName < rows[j].TemplateName - }) - return rows, nil -} - -func (q *FakeQuerier) GetWorkspaceBuildsByWorkspaceID(_ context.Context, - params database.GetWorkspaceBuildsByWorkspaceIDParams, -) ([]database.WorkspaceBuild, error) { - if err := validateDatabaseType(params); err != nil { - return nil, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - history := make([]database.WorkspaceBuild, 0) - for _, workspaceBuild := range q.workspaceBuilds { - if workspaceBuild.CreatedAt.Before(params.Since) { - continue - } - if workspaceBuild.WorkspaceID == params.WorkspaceID { - history = append(history, q.workspaceBuildWithUserNoLock(workspaceBuild)) - } - } - - // Order by build_number - slices.SortFunc(history, func(a, b database.WorkspaceBuild) int { - return slice.Descending(a.BuildNumber, b.BuildNumber) - }) - - if params.AfterID != uuid.Nil { - found := false - for i, v := range history { - if v.ID == params.AfterID { - // We want to return all builds after index i. - history = history[i+1:] - found = true - break - } - } - - // If no builds after the time, then we return an empty list. - if !found { - return nil, sql.ErrNoRows - } - } - - if params.OffsetOpt > 0 { - if int(params.OffsetOpt) > len(history)-1 { - return nil, sql.ErrNoRows - } - history = history[params.OffsetOpt:] - } - - if params.LimitOpt > 0 { - if int(params.LimitOpt) > len(history) { - // #nosec G115 - Safe conversion as history slice length is expected to be within int32 range - params.LimitOpt = int32(len(history)) - } - history = history[:params.LimitOpt] - } - - if len(history) == 0 { - return nil, sql.ErrNoRows - } - return history, nil -} - -func (q *FakeQuerier) GetWorkspaceBuildsCreatedAfter(_ context.Context, after time.Time) ([]database.WorkspaceBuild, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - workspaceBuilds := make([]database.WorkspaceBuild, 0) - for _, workspaceBuild := range q.workspaceBuilds { - if workspaceBuild.CreatedAt.After(after) { - workspaceBuilds = append(workspaceBuilds, q.workspaceBuildWithUserNoLock(workspaceBuild)) - } - } - return workspaceBuilds, nil -} - -func (q *FakeQuerier) GetWorkspaceByAgentID(ctx context.Context, agentID uuid.UUID) (database.Workspace, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - w, err := q.getWorkspaceByAgentIDNoLock(ctx, agentID) - if err != nil { - return database.Workspace{}, err - } - - return w, nil -} - -func (q *FakeQuerier) GetWorkspaceByID(ctx context.Context, id uuid.UUID) (database.Workspace, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - return q.getWorkspaceByIDNoLock(ctx, id) -} - -func (q *FakeQuerier) GetWorkspaceByOwnerIDAndName(_ context.Context, arg database.GetWorkspaceByOwnerIDAndNameParams) (database.Workspace, error) { - if err := validateDatabaseType(arg); err != nil { - return database.Workspace{}, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - var found *database.WorkspaceTable - for _, workspace := range q.workspaces { - if workspace.OwnerID != arg.OwnerID { - continue - } - if !strings.EqualFold(workspace.Name, arg.Name) { - continue - } - if workspace.Deleted != arg.Deleted { - continue - } - - // Return the most recent workspace with the given name - if found == nil || workspace.CreatedAt.After(found.CreatedAt) { - found = &workspace - } - } - if found != nil { - return q.extendWorkspace(*found), nil - } - return database.Workspace{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetWorkspaceByResourceID(ctx context.Context, resourceID uuid.UUID) (database.Workspace, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, resource := range q.workspaceResources { - if resource.ID != resourceID { - continue - } - - for _, build := range q.workspaceBuilds { - if build.JobID != resource.JobID { - continue - } - - for _, workspace := range q.workspaces { - if workspace.ID != build.WorkspaceID { - continue - } - - return q.extendWorkspace(workspace), nil - } - } - } - - return database.Workspace{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetWorkspaceByWorkspaceAppID(_ context.Context, workspaceAppID uuid.UUID) (database.Workspace, error) { - if err := validateDatabaseType(workspaceAppID); err != nil { - return database.Workspace{}, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, workspaceApp := range q.workspaceApps { - if workspaceApp.ID == workspaceAppID { - return q.getWorkspaceByAgentIDNoLock(context.Background(), workspaceApp.AgentID) - } - } - return database.Workspace{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetWorkspaceModulesByJobID(_ context.Context, jobID uuid.UUID) ([]database.WorkspaceModule, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - modules := make([]database.WorkspaceModule, 0) - for _, module := range q.workspaceModules { - if module.JobID == jobID { - modules = append(modules, module) - } - } - return modules, nil -} - -func (q *FakeQuerier) GetWorkspaceModulesCreatedAfter(_ context.Context, createdAt time.Time) ([]database.WorkspaceModule, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - modules := make([]database.WorkspaceModule, 0) - for _, module := range q.workspaceModules { - if module.CreatedAt.After(createdAt) { - modules = append(modules, module) - } - } - return modules, nil -} - -func (q *FakeQuerier) GetWorkspaceProxies(_ context.Context) ([]database.WorkspaceProxy, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - cpy := make([]database.WorkspaceProxy, 0, len(q.workspaceProxies)) - - for _, p := range q.workspaceProxies { - if !p.Deleted { - cpy = append(cpy, p) - } - } - return cpy, nil -} - -func (q *FakeQuerier) GetWorkspaceProxyByHostname(_ context.Context, params database.GetWorkspaceProxyByHostnameParams) (database.WorkspaceProxy, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - // Return zero rows if this is called with a non-sanitized hostname. The SQL - // version of this query does the same thing. - if !validProxyByHostnameRegex.MatchString(params.Hostname) { - return database.WorkspaceProxy{}, sql.ErrNoRows - } - - // This regex matches the SQL version. - accessURLRegex := regexp.MustCompile(`[^:]*://` + regexp.QuoteMeta(params.Hostname) + `([:/]?.)*`) - - for _, proxy := range q.workspaceProxies { - if proxy.Deleted { - continue - } - if params.AllowAccessUrl && accessURLRegex.MatchString(proxy.Url) { - return proxy, nil - } - - // Compile the app hostname regex. This is slow sadly. - if params.AllowWildcardHostname { - wildcardRegexp, err := appurl.CompileHostnamePattern(proxy.WildcardHostname) - if err != nil { - return database.WorkspaceProxy{}, xerrors.Errorf("compile hostname pattern %q for proxy %q (%s): %w", proxy.WildcardHostname, proxy.Name, proxy.ID.String(), err) - } - if _, ok := appurl.ExecuteHostnamePattern(wildcardRegexp, params.Hostname); ok { - return proxy, nil - } - } - } - - return database.WorkspaceProxy{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetWorkspaceProxyByID(_ context.Context, id uuid.UUID) (database.WorkspaceProxy, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, proxy := range q.workspaceProxies { - if proxy.ID == id { - return proxy, nil - } - } - return database.WorkspaceProxy{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetWorkspaceProxyByName(_ context.Context, name string) (database.WorkspaceProxy, error) { - q.mutex.Lock() - defer q.mutex.Unlock() - - for _, proxy := range q.workspaceProxies { - if proxy.Deleted { - continue - } - if proxy.Name == name { - return proxy, nil - } - } - return database.WorkspaceProxy{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetWorkspaceResourceByID(_ context.Context, id uuid.UUID) (database.WorkspaceResource, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, resource := range q.workspaceResources { - if resource.ID == id { - return resource, nil - } - } - return database.WorkspaceResource{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetWorkspaceResourceMetadataByResourceIDs(_ context.Context, ids []uuid.UUID) ([]database.WorkspaceResourceMetadatum, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - metadata := make([]database.WorkspaceResourceMetadatum, 0) - for _, metadatum := range q.workspaceResourceMetadata { - for _, id := range ids { - if metadatum.WorkspaceResourceID == id { - metadata = append(metadata, metadatum) - } - } - } - return metadata, nil -} - -func (q *FakeQuerier) GetWorkspaceResourceMetadataCreatedAfter(ctx context.Context, after time.Time) ([]database.WorkspaceResourceMetadatum, error) { - resources, err := q.GetWorkspaceResourcesCreatedAfter(ctx, after) - if err != nil { - return nil, err - } - resourceIDs := map[uuid.UUID]struct{}{} - for _, resource := range resources { - resourceIDs[resource.ID] = struct{}{} - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - metadata := make([]database.WorkspaceResourceMetadatum, 0) - for _, m := range q.workspaceResourceMetadata { - _, ok := resourceIDs[m.WorkspaceResourceID] - if !ok { - continue - } - metadata = append(metadata, m) - } - return metadata, nil -} - -func (q *FakeQuerier) GetWorkspaceResourcesByJobID(ctx context.Context, jobID uuid.UUID) ([]database.WorkspaceResource, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - return q.getWorkspaceResourcesByJobIDNoLock(ctx, jobID) -} - -func (q *FakeQuerier) GetWorkspaceResourcesByJobIDs(_ context.Context, jobIDs []uuid.UUID) ([]database.WorkspaceResource, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - resources := make([]database.WorkspaceResource, 0) - for _, resource := range q.workspaceResources { - for _, jobID := range jobIDs { - if resource.JobID != jobID { - continue - } - resources = append(resources, resource) - } - } - return resources, nil -} - -func (q *FakeQuerier) GetWorkspaceResourcesCreatedAfter(_ context.Context, after time.Time) ([]database.WorkspaceResource, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - resources := make([]database.WorkspaceResource, 0) - for _, resource := range q.workspaceResources { - if resource.CreatedAt.After(after) { - resources = append(resources, resource) - } - } - return resources, nil -} - -func (q *FakeQuerier) GetWorkspaceUniqueOwnerCountByTemplateIDs(_ context.Context, templateIds []uuid.UUID) ([]database.GetWorkspaceUniqueOwnerCountByTemplateIDsRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - workspaceOwners := make(map[uuid.UUID]map[uuid.UUID]struct{}) - for _, workspace := range q.workspaces { - if workspace.Deleted { - continue - } - if !slices.Contains(templateIds, workspace.TemplateID) { - continue - } - _, ok := workspaceOwners[workspace.TemplateID] - if !ok { - workspaceOwners[workspace.TemplateID] = make(map[uuid.UUID]struct{}) - } - workspaceOwners[workspace.TemplateID][workspace.OwnerID] = struct{}{} - } - resp := make([]database.GetWorkspaceUniqueOwnerCountByTemplateIDsRow, 0) - for _, templateID := range templateIds { - count := len(workspaceOwners[templateID]) - resp = append(resp, database.GetWorkspaceUniqueOwnerCountByTemplateIDsRow{ - TemplateID: templateID, - UniqueOwnersSum: int64(count), - }) - } - - return resp, nil -} - -func (q *FakeQuerier) GetWorkspaces(ctx context.Context, arg database.GetWorkspacesParams) ([]database.GetWorkspacesRow, error) { - if err := validateDatabaseType(arg); err != nil { - return nil, err - } - - // A nil auth filter means no auth filter. - workspaceRows, err := q.GetAuthorizedWorkspaces(ctx, arg, nil) - return workspaceRows, err -} - -func (q *FakeQuerier) GetWorkspacesAndAgentsByOwnerID(ctx context.Context, ownerID uuid.UUID) ([]database.GetWorkspacesAndAgentsByOwnerIDRow, error) { - // No auth filter. - return q.GetAuthorizedWorkspacesAndAgentsByOwnerID(ctx, ownerID, nil) -} - -func (q *FakeQuerier) GetWorkspacesByTemplateID(_ context.Context, templateID uuid.UUID) ([]database.WorkspaceTable, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - workspaces := []database.WorkspaceTable{} - for _, workspace := range q.workspaces { - if workspace.TemplateID == templateID { - workspaces = append(workspaces, workspace) - } - } - - return workspaces, nil -} - -func (q *FakeQuerier) GetWorkspacesEligibleForTransition(ctx context.Context, now time.Time) ([]database.GetWorkspacesEligibleForTransitionRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - workspaces := []database.GetWorkspacesEligibleForTransitionRow{} - for _, workspace := range q.workspaces { - build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, workspace.ID) - if err != nil { - return nil, xerrors.Errorf("get workspace build by ID: %w", err) - } - - user, err := q.getUserByIDNoLock(workspace.OwnerID) - if err != nil { - return nil, xerrors.Errorf("get user by ID: %w", err) - } - - job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID) - if err != nil { - return nil, xerrors.Errorf("get provisioner job by ID: %w", err) - } - - template, err := q.getTemplateByIDNoLock(ctx, workspace.TemplateID) - if err != nil { - return nil, xerrors.Errorf("get template by ID: %w", err) - } - - if workspace.Deleted { - continue - } - - if job.JobStatus != database.ProvisionerJobStatusFailed && - !workspace.DormantAt.Valid && - build.Transition == database.WorkspaceTransitionStart && - (user.Status == database.UserStatusSuspended || (!build.Deadline.IsZero() && build.Deadline.Before(now))) { - workspaces = append(workspaces, database.GetWorkspacesEligibleForTransitionRow{ - ID: workspace.ID, - Name: workspace.Name, - }) - continue - } - - if user.Status == database.UserStatusActive && - job.JobStatus != database.ProvisionerJobStatusFailed && - build.Transition == database.WorkspaceTransitionStop && - workspace.AutostartSchedule.Valid && - // We do not know if workspace with a zero next start is eligible - // for autostart, so we accept this false-positive. This can occur - // when a coder version is upgraded and next_start_at has yet to - // be set. - (workspace.NextStartAt.Time.IsZero() || - !now.Before(workspace.NextStartAt.Time)) { - workspaces = append(workspaces, database.GetWorkspacesEligibleForTransitionRow{ - ID: workspace.ID, - Name: workspace.Name, - }) - continue - } - - if !workspace.DormantAt.Valid && - template.TimeTilDormant > 0 && - now.Sub(workspace.LastUsedAt) >= time.Duration(template.TimeTilDormant) { - workspaces = append(workspaces, database.GetWorkspacesEligibleForTransitionRow{ - ID: workspace.ID, - Name: workspace.Name, - }) - continue - } - - if workspace.DormantAt.Valid && - workspace.DeletingAt.Valid && - workspace.DeletingAt.Time.Before(now) && - template.TimeTilDormantAutoDelete > 0 { - if build.Transition == database.WorkspaceTransitionDelete && - job.JobStatus == database.ProvisionerJobStatusFailed { - if job.CanceledAt.Valid && now.Sub(job.CanceledAt.Time) <= 24*time.Hour { - continue - } - - if job.CompletedAt.Valid && now.Sub(job.CompletedAt.Time) <= 24*time.Hour { - continue - } - } - - workspaces = append(workspaces, database.GetWorkspacesEligibleForTransitionRow{ - ID: workspace.ID, - Name: workspace.Name, - }) - continue - } - - if template.FailureTTL > 0 && - build.Transition == database.WorkspaceTransitionStart && - job.JobStatus == database.ProvisionerJobStatusFailed && - job.CompletedAt.Valid && - now.Sub(job.CompletedAt.Time) > time.Duration(template.FailureTTL) { - workspaces = append(workspaces, database.GetWorkspacesEligibleForTransitionRow{ - ID: workspace.ID, - Name: workspace.Name, - }) - continue - } - } - - return workspaces, nil -} - -func (q *FakeQuerier) HasTemplateVersionsWithAITask(_ context.Context) (bool, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, templateVersion := range q.templateVersions { - if templateVersion.HasAITask.Valid && templateVersion.HasAITask.Bool { - return true, nil - } - } - - return false, nil -} - -func (q *FakeQuerier) InsertAPIKey(_ context.Context, arg database.InsertAPIKeyParams) (database.APIKey, error) { - if err := validateDatabaseType(arg); err != nil { - return database.APIKey{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - if arg.LifetimeSeconds == 0 { - arg.LifetimeSeconds = 86400 - } - - for _, u := range q.users { - if u.ID == arg.UserID && u.Deleted { - return database.APIKey{}, xerrors.Errorf("refusing to create APIKey for deleted user") - } - } - - //nolint:gosimple - key := database.APIKey{ - ID: arg.ID, - LifetimeSeconds: arg.LifetimeSeconds, - HashedSecret: arg.HashedSecret, - IPAddress: arg.IPAddress, - UserID: arg.UserID, - ExpiresAt: arg.ExpiresAt, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - LastUsed: arg.LastUsed, - LoginType: arg.LoginType, - Scope: arg.Scope, - TokenName: arg.TokenName, - } - q.apiKeys = append(q.apiKeys, key) - return key, nil -} - -func (q *FakeQuerier) InsertAllUsersGroup(ctx context.Context, orgID uuid.UUID) (database.Group, error) { - return q.InsertGroup(ctx, database.InsertGroupParams{ - ID: orgID, - Name: database.EveryoneGroup, - DisplayName: "", - OrganizationID: orgID, - AvatarURL: "", - QuotaAllowance: 0, - }) -} - -func (q *FakeQuerier) InsertAuditLog(_ context.Context, arg database.InsertAuditLogParams) (database.AuditLog, error) { - if err := validateDatabaseType(arg); err != nil { - return database.AuditLog{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - alog := database.AuditLog(arg) - - q.auditLogs = append(q.auditLogs, alog) - slices.SortFunc(q.auditLogs, func(a, b database.AuditLog) int { - if a.Time.Before(b.Time) { - return -1 - } else if a.Time.Equal(b.Time) { - return 0 - } - return 1 - }) - - return alog, nil -} - -func (q *FakeQuerier) InsertCryptoKey(_ context.Context, arg database.InsertCryptoKeyParams) (database.CryptoKey, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.CryptoKey{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - key := database.CryptoKey{ - Feature: arg.Feature, - Sequence: arg.Sequence, - Secret: arg.Secret, - SecretKeyID: arg.SecretKeyID, - StartsAt: arg.StartsAt, - } - - q.cryptoKeys = append(q.cryptoKeys, key) - - return key, nil -} - -func (q *FakeQuerier) InsertCustomRole(_ context.Context, arg database.InsertCustomRoleParams) (database.CustomRole, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.CustomRole{}, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - for i := range q.customRoles { - if strings.EqualFold(q.customRoles[i].Name, arg.Name) && - q.customRoles[i].OrganizationID.UUID == arg.OrganizationID.UUID { - return database.CustomRole{}, errUniqueConstraint - } - } - - role := database.CustomRole{ - ID: uuid.New(), - Name: arg.Name, - DisplayName: arg.DisplayName, - OrganizationID: arg.OrganizationID, - SitePermissions: arg.SitePermissions, - OrgPermissions: arg.OrgPermissions, - UserPermissions: arg.UserPermissions, - CreatedAt: dbtime.Now(), - UpdatedAt: dbtime.Now(), - } - q.customRoles = append(q.customRoles, role) - - return role, nil -} - -func (q *FakeQuerier) InsertDBCryptKey(_ context.Context, arg database.InsertDBCryptKeyParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - for _, key := range q.dbcryptKeys { - if key.Number == arg.Number { - return errUniqueConstraint - } - } - - q.dbcryptKeys = append(q.dbcryptKeys, database.DBCryptKey{ - Number: arg.Number, - ActiveKeyDigest: sql.NullString{String: arg.ActiveKeyDigest, Valid: true}, - Test: arg.Test, - }) - return nil -} - -func (q *FakeQuerier) InsertDERPMeshKey(_ context.Context, id string) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - q.derpMeshKey = id - return nil -} - -func (q *FakeQuerier) InsertDeploymentID(_ context.Context, id string) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - q.deploymentID = id - return nil -} - -func (q *FakeQuerier) InsertExternalAuthLink(_ context.Context, arg database.InsertExternalAuthLinkParams) (database.ExternalAuthLink, error) { - if err := validateDatabaseType(arg); err != nil { - return database.ExternalAuthLink{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - // nolint:gosimple - gitAuthLink := database.ExternalAuthLink{ - ProviderID: arg.ProviderID, - UserID: arg.UserID, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - OAuthAccessToken: arg.OAuthAccessToken, - OAuthAccessTokenKeyID: arg.OAuthAccessTokenKeyID, - OAuthRefreshToken: arg.OAuthRefreshToken, - OAuthRefreshTokenKeyID: arg.OAuthRefreshTokenKeyID, - OAuthExpiry: arg.OAuthExpiry, - OAuthExtra: arg.OAuthExtra, - } - q.externalAuthLinks = append(q.externalAuthLinks, gitAuthLink) - return gitAuthLink, nil -} - -func (q *FakeQuerier) InsertFile(_ context.Context, arg database.InsertFileParams) (database.File, error) { - if err := validateDatabaseType(arg); err != nil { - return database.File{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - if slices.ContainsFunc(q.files, func(file database.File) bool { - return file.CreatedBy == arg.CreatedBy && file.Hash == arg.Hash - }) { - return database.File{}, newUniqueConstraintError(database.UniqueFilesHashCreatedByKey) - } - - //nolint:gosimple - file := database.File{ - ID: arg.ID, - Hash: arg.Hash, - CreatedAt: arg.CreatedAt, - CreatedBy: arg.CreatedBy, - Mimetype: arg.Mimetype, - Data: arg.Data, - } - q.files = append(q.files, file) - return file, nil -} - -func (q *FakeQuerier) InsertGitSSHKey(_ context.Context, arg database.InsertGitSSHKeyParams) (database.GitSSHKey, error) { - if err := validateDatabaseType(arg); err != nil { - return database.GitSSHKey{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - //nolint:gosimple - gitSSHKey := database.GitSSHKey{ - UserID: arg.UserID, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - PrivateKey: arg.PrivateKey, - PublicKey: arg.PublicKey, - } - q.gitSSHKey = append(q.gitSSHKey, gitSSHKey) - return gitSSHKey, nil -} - -func (q *FakeQuerier) InsertGroup(_ context.Context, arg database.InsertGroupParams) (database.Group, error) { - if err := validateDatabaseType(arg); err != nil { - return database.Group{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for _, group := range q.groups { - if group.OrganizationID == arg.OrganizationID && - group.Name == arg.Name { - return database.Group{}, errUniqueConstraint - } - } - - //nolint:gosimple - group := database.Group{ - ID: arg.ID, - Name: arg.Name, - DisplayName: arg.DisplayName, - OrganizationID: arg.OrganizationID, - AvatarURL: arg.AvatarURL, - QuotaAllowance: arg.QuotaAllowance, - Source: database.GroupSourceUser, - } - - q.groups = append(q.groups, group) - - return group, nil -} - -func (q *FakeQuerier) InsertGroupMember(_ context.Context, arg database.InsertGroupMemberParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for _, member := range q.groupMembers { - if member.GroupID == arg.GroupID && - member.UserID == arg.UserID { - return errUniqueConstraint - } - } - - //nolint:gosimple - q.groupMembers = append(q.groupMembers, database.GroupMemberTable{ - GroupID: arg.GroupID, - UserID: arg.UserID, - }) - - return nil -} - -func (q *FakeQuerier) InsertInboxNotification(_ context.Context, arg database.InsertInboxNotificationParams) (database.InboxNotification, error) { - if err := validateDatabaseType(arg); err != nil { - return database.InboxNotification{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - notification := database.InboxNotification{ - ID: arg.ID, - UserID: arg.UserID, - TemplateID: arg.TemplateID, - Targets: arg.Targets, - Title: arg.Title, - Content: arg.Content, - Icon: arg.Icon, - Actions: arg.Actions, - CreatedAt: arg.CreatedAt, - } - - q.inboxNotifications = append(q.inboxNotifications, notification) - return notification, nil -} - -func (q *FakeQuerier) InsertLicense( - _ context.Context, arg database.InsertLicenseParams, -) (database.License, error) { - if err := validateDatabaseType(arg); err != nil { - return database.License{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - l := database.License{ - ID: q.lastLicenseID + 1, - UploadedAt: arg.UploadedAt, - JWT: arg.JWT, - Exp: arg.Exp, - } - q.lastLicenseID = l.ID - q.licenses = append(q.licenses, l) - return l, nil -} - -func (q *FakeQuerier) InsertMemoryResourceMonitor(_ context.Context, arg database.InsertMemoryResourceMonitorParams) (database.WorkspaceAgentMemoryResourceMonitor, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.WorkspaceAgentMemoryResourceMonitor{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - //nolint:unconvert // The structs field-order differs so this is needed. - monitor := database.WorkspaceAgentMemoryResourceMonitor(database.WorkspaceAgentMemoryResourceMonitor{ - AgentID: arg.AgentID, - Enabled: arg.Enabled, - State: arg.State, - Threshold: arg.Threshold, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - DebouncedUntil: arg.DebouncedUntil, - }) - - q.workspaceAgentMemoryResourceMonitors = append(q.workspaceAgentMemoryResourceMonitors, monitor) - return monitor, nil -} - -func (q *FakeQuerier) InsertMissingGroups(_ context.Context, arg database.InsertMissingGroupsParams) ([]database.Group, error) { - err := validateDatabaseType(arg) - if err != nil { - return nil, err - } - - groupNameMap := make(map[string]struct{}) - for _, g := range arg.GroupNames { - groupNameMap[g] = struct{}{} - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for _, g := range q.groups { - if g.OrganizationID != arg.OrganizationID { - continue - } - delete(groupNameMap, g.Name) - } - - newGroups := make([]database.Group, 0, len(groupNameMap)) - for k := range groupNameMap { - g := database.Group{ - ID: uuid.New(), - Name: k, - OrganizationID: arg.OrganizationID, - AvatarURL: "", - QuotaAllowance: 0, - DisplayName: "", - Source: arg.Source, - } - q.groups = append(q.groups, g) - newGroups = append(newGroups, g) - } - - return newGroups, nil -} - -func (q *FakeQuerier) InsertOAuth2ProviderApp(_ context.Context, arg database.InsertOAuth2ProviderAppParams) (database.OAuth2ProviderApp, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.OAuth2ProviderApp{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - //nolint:gosimple // Go wants database.OAuth2ProviderApp(arg), but we cannot be sure the structs will remain identical. - app := database.OAuth2ProviderApp{ - ID: arg.ID, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - Name: arg.Name, - Icon: arg.Icon, - CallbackURL: arg.CallbackURL, - RedirectUris: arg.RedirectUris, - ClientType: arg.ClientType, - DynamicallyRegistered: arg.DynamicallyRegistered, - ClientIDIssuedAt: arg.ClientIDIssuedAt, - ClientSecretExpiresAt: arg.ClientSecretExpiresAt, - GrantTypes: arg.GrantTypes, - ResponseTypes: arg.ResponseTypes, - TokenEndpointAuthMethod: arg.TokenEndpointAuthMethod, - Scope: arg.Scope, - Contacts: arg.Contacts, - ClientUri: arg.ClientUri, - LogoUri: arg.LogoUri, - TosUri: arg.TosUri, - PolicyUri: arg.PolicyUri, - JwksUri: arg.JwksUri, - Jwks: arg.Jwks, - SoftwareID: arg.SoftwareID, - SoftwareVersion: arg.SoftwareVersion, - RegistrationAccessToken: arg.RegistrationAccessToken, - RegistrationClientUri: arg.RegistrationClientUri, - } - - // Apply RFC-compliant defaults to match database migration defaults - if !app.ClientType.Valid { - app.ClientType = sql.NullString{String: "confidential", Valid: true} - } - if !app.DynamicallyRegistered.Valid { - app.DynamicallyRegistered = sql.NullBool{Bool: false, Valid: true} - } - if len(app.GrantTypes) == 0 { - app.GrantTypes = []string{"authorization_code", "refresh_token"} - } - if len(app.ResponseTypes) == 0 { - app.ResponseTypes = []string{"code"} - } - if !app.TokenEndpointAuthMethod.Valid { - app.TokenEndpointAuthMethod = sql.NullString{String: "client_secret_basic", Valid: true} - } - if !app.Scope.Valid { - app.Scope = sql.NullString{String: "", Valid: true} - } - if app.Contacts == nil { - app.Contacts = []string{} - } - q.oauth2ProviderApps = append(q.oauth2ProviderApps, app) - - return app, nil -} - -func (q *FakeQuerier) InsertOAuth2ProviderAppCode(_ context.Context, arg database.InsertOAuth2ProviderAppCodeParams) (database.OAuth2ProviderAppCode, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.OAuth2ProviderAppCode{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for _, app := range q.oauth2ProviderApps { - if app.ID == arg.AppID { - code := database.OAuth2ProviderAppCode{ - ID: arg.ID, - CreatedAt: arg.CreatedAt, - ExpiresAt: arg.ExpiresAt, - SecretPrefix: arg.SecretPrefix, - HashedSecret: arg.HashedSecret, - UserID: arg.UserID, - AppID: arg.AppID, - ResourceUri: arg.ResourceUri, - CodeChallenge: arg.CodeChallenge, - CodeChallengeMethod: arg.CodeChallengeMethod, - } - q.oauth2ProviderAppCodes = append(q.oauth2ProviderAppCodes, code) - return code, nil - } - } - - return database.OAuth2ProviderAppCode{}, sql.ErrNoRows -} - -func (q *FakeQuerier) InsertOAuth2ProviderAppSecret(_ context.Context, arg database.InsertOAuth2ProviderAppSecretParams) (database.OAuth2ProviderAppSecret, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.OAuth2ProviderAppSecret{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for _, app := range q.oauth2ProviderApps { - if app.ID == arg.AppID { - secret := database.OAuth2ProviderAppSecret{ - ID: arg.ID, - CreatedAt: arg.CreatedAt, - SecretPrefix: arg.SecretPrefix, - HashedSecret: arg.HashedSecret, - DisplaySecret: arg.DisplaySecret, - AppID: arg.AppID, - } - q.oauth2ProviderAppSecrets = append(q.oauth2ProviderAppSecrets, secret) - return secret, nil - } - } - - return database.OAuth2ProviderAppSecret{}, sql.ErrNoRows -} - -func (q *FakeQuerier) InsertOAuth2ProviderAppToken(_ context.Context, arg database.InsertOAuth2ProviderAppTokenParams) (database.OAuth2ProviderAppToken, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.OAuth2ProviderAppToken{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for _, secret := range q.oauth2ProviderAppSecrets { - if secret.ID == arg.AppSecretID { - //nolint:gosimple // Go wants database.OAuth2ProviderAppToken(arg), but we cannot be sure the structs will remain identical. - token := database.OAuth2ProviderAppToken{ - ID: arg.ID, - CreatedAt: arg.CreatedAt, - ExpiresAt: arg.ExpiresAt, - HashPrefix: arg.HashPrefix, - RefreshHash: arg.RefreshHash, - APIKeyID: arg.APIKeyID, - AppSecretID: arg.AppSecretID, - UserID: arg.UserID, - Audience: arg.Audience, - } - q.oauth2ProviderAppTokens = append(q.oauth2ProviderAppTokens, token) - return token, nil - } - } - - return database.OAuth2ProviderAppToken{}, sql.ErrNoRows -} - -func (q *FakeQuerier) InsertOrganization(_ context.Context, arg database.InsertOrganizationParams) (database.Organization, error) { - if err := validateDatabaseType(arg); err != nil { - return database.Organization{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - organization := database.Organization{ - ID: arg.ID, - Name: arg.Name, - DisplayName: arg.DisplayName, - Description: arg.Description, - Icon: arg.Icon, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - IsDefault: len(q.organizations) == 0, - } - q.organizations = append(q.organizations, organization) - return organization, nil -} - -func (q *FakeQuerier) InsertOrganizationMember(_ context.Context, arg database.InsertOrganizationMemberParams) (database.OrganizationMember, error) { - if err := validateDatabaseType(arg); err != nil { - return database.OrganizationMember{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - if slices.IndexFunc(q.data.organizationMembers, func(member database.OrganizationMember) bool { - return member.OrganizationID == arg.OrganizationID && member.UserID == arg.UserID - }) >= 0 { - // Error pulled from a live db error - return database.OrganizationMember{}, &pq.Error{ - Severity: "ERROR", - Code: "23505", - Message: "duplicate key value violates unique constraint \"organization_members_pkey\"", - Detail: "Key (organization_id, user_id)=(f7de1f4e-5833-4410-a28d-0a105f96003f, 36052a80-4a7f-4998-a7ca-44cefa608d3e) already exists.", - Table: "organization_members", - Constraint: "organization_members_pkey", - } - } - - //nolint:gosimple - organizationMember := database.OrganizationMember{ - OrganizationID: arg.OrganizationID, - UserID: arg.UserID, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - Roles: arg.Roles, - } - q.organizationMembers = append(q.organizationMembers, organizationMember) - return organizationMember, nil -} - -func (q *FakeQuerier) InsertPreset(_ context.Context, arg database.InsertPresetParams) (database.TemplateVersionPreset, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.TemplateVersionPreset{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - //nolint:gosimple // arg needs to keep its type for interface reasons and that type is not appropriate for preset below. - preset := database.TemplateVersionPreset{ - ID: uuid.New(), - TemplateVersionID: arg.TemplateVersionID, - Name: arg.Name, - CreatedAt: arg.CreatedAt, - DesiredInstances: arg.DesiredInstances, - InvalidateAfterSecs: sql.NullInt32{ - Int32: 0, - Valid: true, - }, - PrebuildStatus: database.PrebuildStatusHealthy, - IsDefault: arg.IsDefault, - } - q.presets = append(q.presets, preset) - return preset, nil -} - -func (q *FakeQuerier) InsertPresetParameters(_ context.Context, arg database.InsertPresetParametersParams) ([]database.TemplateVersionPresetParameter, error) { - err := validateDatabaseType(arg) - if err != nil { - return nil, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - presetParameters := make([]database.TemplateVersionPresetParameter, 0, len(arg.Names)) - for i, v := range arg.Names { - presetParameter := database.TemplateVersionPresetParameter{ - ID: uuid.New(), - TemplateVersionPresetID: arg.TemplateVersionPresetID, - Name: v, - Value: arg.Values[i], - } - presetParameters = append(presetParameters, presetParameter) - q.presetParameters = append(q.presetParameters, presetParameter) - } - - return presetParameters, nil -} - -func (q *FakeQuerier) InsertPresetPrebuildSchedule(ctx context.Context, arg database.InsertPresetPrebuildScheduleParams) (database.TemplateVersionPresetPrebuildSchedule, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.TemplateVersionPresetPrebuildSchedule{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - presetPrebuildSchedule := database.TemplateVersionPresetPrebuildSchedule{ - ID: uuid.New(), - PresetID: arg.PresetID, - CronExpression: arg.CronExpression, - DesiredInstances: arg.DesiredInstances, - } - q.presetPrebuildSchedules = append(q.presetPrebuildSchedules, presetPrebuildSchedule) - return presetPrebuildSchedule, nil -} - -func (q *FakeQuerier) InsertProvisionerJob(_ context.Context, arg database.InsertProvisionerJobParams) (database.ProvisionerJob, error) { - if err := validateDatabaseType(arg); err != nil { - return database.ProvisionerJob{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - job := database.ProvisionerJob{ - ID: arg.ID, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - OrganizationID: arg.OrganizationID, - InitiatorID: arg.InitiatorID, - Provisioner: arg.Provisioner, - StorageMethod: arg.StorageMethod, - FileID: arg.FileID, - Type: arg.Type, - Input: arg.Input, - Tags: maps.Clone(arg.Tags), - TraceMetadata: arg.TraceMetadata, - } - job.JobStatus = provisionerJobStatus(job) - q.provisionerJobs = append(q.provisionerJobs, job) - return job, nil -} - -func (q *FakeQuerier) InsertProvisionerJobLogs(_ context.Context, arg database.InsertProvisionerJobLogsParams) ([]database.ProvisionerJobLog, error) { - if err := validateDatabaseType(arg); err != nil { - return nil, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - logs := make([]database.ProvisionerJobLog, 0) - id := int64(1) - if len(q.provisionerJobLogs) > 0 { - id = q.provisionerJobLogs[len(q.provisionerJobLogs)-1].ID - } - for index, output := range arg.Output { - id++ - logs = append(logs, database.ProvisionerJobLog{ - ID: id, - JobID: arg.JobID, - CreatedAt: arg.CreatedAt[index], - Source: arg.Source[index], - Level: arg.Level[index], - Stage: arg.Stage[index], - Output: output, - }) - } - q.provisionerJobLogs = append(q.provisionerJobLogs, logs...) - return logs, nil -} - -func (q *FakeQuerier) InsertProvisionerJobTimings(_ context.Context, arg database.InsertProvisionerJobTimingsParams) ([]database.ProvisionerJobTiming, error) { - err := validateDatabaseType(arg) - if err != nil { - return nil, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - insertedTimings := make([]database.ProvisionerJobTiming, 0, len(arg.StartedAt)) - for i := range arg.StartedAt { - timing := database.ProvisionerJobTiming{ - JobID: arg.JobID, - StartedAt: arg.StartedAt[i], - EndedAt: arg.EndedAt[i], - Stage: arg.Stage[i], - Source: arg.Source[i], - Action: arg.Action[i], - Resource: arg.Resource[i], - } - q.provisionerJobTimings = append(q.provisionerJobTimings, timing) - insertedTimings = append(insertedTimings, timing) - } - - return insertedTimings, nil -} - -func (q *FakeQuerier) InsertProvisionerKey(_ context.Context, arg database.InsertProvisionerKeyParams) (database.ProvisionerKey, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.ProvisionerKey{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for _, key := range q.provisionerKeys { - if key.ID == arg.ID || (key.OrganizationID == arg.OrganizationID && strings.EqualFold(key.Name, arg.Name)) { - return database.ProvisionerKey{}, newUniqueConstraintError(database.UniqueProvisionerKeysOrganizationIDNameIndex) - } - } - - //nolint:gosimple - provisionerKey := database.ProvisionerKey{ - ID: arg.ID, - CreatedAt: arg.CreatedAt, - OrganizationID: arg.OrganizationID, - Name: strings.ToLower(arg.Name), - HashedSecret: arg.HashedSecret, - Tags: arg.Tags, - } - q.provisionerKeys = append(q.provisionerKeys, provisionerKey) - - return provisionerKey, nil -} - -func (q *FakeQuerier) InsertReplica(_ context.Context, arg database.InsertReplicaParams) (database.Replica, error) { - if err := validateDatabaseType(arg); err != nil { - return database.Replica{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - replica := database.Replica{ - ID: arg.ID, - CreatedAt: arg.CreatedAt, - StartedAt: arg.StartedAt, - UpdatedAt: arg.UpdatedAt, - Hostname: arg.Hostname, - RegionID: arg.RegionID, - RelayAddress: arg.RelayAddress, - Version: arg.Version, - DatabaseLatency: arg.DatabaseLatency, - Primary: arg.Primary, - } - q.replicas = append(q.replicas, replica) - return replica, nil -} - -func (q *FakeQuerier) InsertTelemetryItemIfNotExists(_ context.Context, arg database.InsertTelemetryItemIfNotExistsParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for _, item := range q.telemetryItems { - if item.Key == arg.Key { - return nil - } - } - - q.telemetryItems = append(q.telemetryItems, database.TelemetryItem{ - Key: arg.Key, - Value: arg.Value, - CreatedAt: time.Now(), - UpdatedAt: time.Now(), - }) - return nil -} - -func (q *FakeQuerier) InsertTemplate(_ context.Context, arg database.InsertTemplateParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - //nolint:gosimple - template := database.TemplateTable{ - ID: arg.ID, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - OrganizationID: arg.OrganizationID, - Name: arg.Name, - Provisioner: arg.Provisioner, - ActiveVersionID: arg.ActiveVersionID, - Description: arg.Description, - CreatedBy: arg.CreatedBy, - UserACL: arg.UserACL, - GroupACL: arg.GroupACL, - DisplayName: arg.DisplayName, - Icon: arg.Icon, - AllowUserCancelWorkspaceJobs: arg.AllowUserCancelWorkspaceJobs, - AllowUserAutostart: true, - AllowUserAutostop: true, - MaxPortSharingLevel: arg.MaxPortSharingLevel, - UseClassicParameterFlow: arg.UseClassicParameterFlow, - } - q.templates = append(q.templates, template) - return nil -} - -func (q *FakeQuerier) InsertTemplateVersion(_ context.Context, arg database.InsertTemplateVersionParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - if len(arg.Message) > 1048576 { - return xerrors.New("message too long") - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - //nolint:gosimple - version := database.TemplateVersionTable{ - ID: arg.ID, - TemplateID: arg.TemplateID, - OrganizationID: arg.OrganizationID, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - Name: arg.Name, - Message: arg.Message, - Readme: arg.Readme, - JobID: arg.JobID, - CreatedBy: arg.CreatedBy, - SourceExampleID: arg.SourceExampleID, - } - q.templateVersions = append(q.templateVersions, version) - return nil -} - -func (q *FakeQuerier) InsertTemplateVersionParameter(_ context.Context, arg database.InsertTemplateVersionParameterParams) (database.TemplateVersionParameter, error) { - if err := validateDatabaseType(arg); err != nil { - return database.TemplateVersionParameter{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - //nolint:gosimple - param := database.TemplateVersionParameter{ - TemplateVersionID: arg.TemplateVersionID, - Name: arg.Name, - DisplayName: arg.DisplayName, - Description: arg.Description, - Type: arg.Type, - FormType: arg.FormType, - Mutable: arg.Mutable, - DefaultValue: arg.DefaultValue, - Icon: arg.Icon, - Options: arg.Options, - ValidationError: arg.ValidationError, - ValidationRegex: arg.ValidationRegex, - ValidationMin: arg.ValidationMin, - ValidationMax: arg.ValidationMax, - ValidationMonotonic: arg.ValidationMonotonic, - Required: arg.Required, - DisplayOrder: arg.DisplayOrder, - Ephemeral: arg.Ephemeral, - } - q.templateVersionParameters = append(q.templateVersionParameters, param) - return param, nil -} - -func (q *FakeQuerier) InsertTemplateVersionTerraformValuesByJobID(_ context.Context, arg database.InsertTemplateVersionTerraformValuesByJobIDParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - // Find the template version by the job_id - templateVersion, ok := slice.Find(q.templateVersions, func(v database.TemplateVersionTable) bool { - return v.JobID == arg.JobID - }) - if !ok { - return sql.ErrNoRows - } - - if !json.Valid(arg.CachedPlan) { - return xerrors.Errorf("cached plan must be valid json, received %q", string(arg.CachedPlan)) - } - - // Insert the new row - row := database.TemplateVersionTerraformValue{ - TemplateVersionID: templateVersion.ID, - UpdatedAt: arg.UpdatedAt, - CachedPlan: arg.CachedPlan, - CachedModuleFiles: arg.CachedModuleFiles, - ProvisionerdVersion: arg.ProvisionerdVersion, - } - q.templateVersionTerraformValues = append(q.templateVersionTerraformValues, row) - return nil -} - -func (q *FakeQuerier) InsertTemplateVersionVariable(_ context.Context, arg database.InsertTemplateVersionVariableParams) (database.TemplateVersionVariable, error) { - if err := validateDatabaseType(arg); err != nil { - return database.TemplateVersionVariable{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - //nolint:gosimple - variable := database.TemplateVersionVariable{ - TemplateVersionID: arg.TemplateVersionID, - Name: arg.Name, - Description: arg.Description, - Type: arg.Type, - Value: arg.Value, - DefaultValue: arg.DefaultValue, - Required: arg.Required, - Sensitive: arg.Sensitive, - } - q.templateVersionVariables = append(q.templateVersionVariables, variable) - return variable, nil -} - -func (q *FakeQuerier) InsertTemplateVersionWorkspaceTag(_ context.Context, arg database.InsertTemplateVersionWorkspaceTagParams) (database.TemplateVersionWorkspaceTag, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.TemplateVersionWorkspaceTag{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - //nolint:gosimple - workspaceTag := database.TemplateVersionWorkspaceTag{ - TemplateVersionID: arg.TemplateVersionID, - Key: arg.Key, - Value: arg.Value, - } - q.templateVersionWorkspaceTags = append(q.templateVersionWorkspaceTags, workspaceTag) - return workspaceTag, nil -} - -func (q *FakeQuerier) InsertUser(_ context.Context, arg database.InsertUserParams) (database.User, error) { - if err := validateDatabaseType(arg); err != nil { - return database.User{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for _, user := range q.users { - if user.Username == arg.Username && !user.Deleted { - return database.User{}, errUniqueConstraint - } - } - - status := database.UserStatusDormant - if arg.Status != "" { - status = database.UserStatus(arg.Status) - } - - user := database.User{ - ID: arg.ID, - Email: arg.Email, - HashedPassword: arg.HashedPassword, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - Username: arg.Username, - Name: arg.Name, - Status: status, - RBACRoles: arg.RBACRoles, - LoginType: arg.LoginType, - IsSystem: false, - } - q.users = append(q.users, user) - sort.Slice(q.users, func(i, j int) bool { - return q.users[i].CreatedAt.Before(q.users[j].CreatedAt) - }) - - q.userStatusChanges = append(q.userStatusChanges, database.UserStatusChange{ - UserID: user.ID, - NewStatus: user.Status, - ChangedAt: user.UpdatedAt, - }) - return user, nil -} - -func (q *FakeQuerier) InsertUserGroupsByID(_ context.Context, arg database.InsertUserGroupsByIDParams) ([]uuid.UUID, error) { - err := validateDatabaseType(arg) - if err != nil { - return nil, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - var groupIDs []uuid.UUID - for _, group := range q.groups { - for _, groupID := range arg.GroupIds { - if group.ID == groupID { - q.groupMembers = append(q.groupMembers, database.GroupMemberTable{ - UserID: arg.UserID, - GroupID: groupID, - }) - groupIDs = append(groupIDs, group.ID) - } - } - } - - return groupIDs, nil -} - -func (q *FakeQuerier) InsertUserGroupsByName(_ context.Context, arg database.InsertUserGroupsByNameParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - var groupIDs []uuid.UUID - for _, group := range q.groups { - for _, groupName := range arg.GroupNames { - if group.Name == groupName { - groupIDs = append(groupIDs, group.ID) - } - } - } - - for _, groupID := range groupIDs { - q.groupMembers = append(q.groupMembers, database.GroupMemberTable{ - UserID: arg.UserID, - GroupID: groupID, - }) - } - - return nil -} - -func (q *FakeQuerier) InsertUserLink(_ context.Context, args database.InsertUserLinkParams) (database.UserLink, error) { - q.mutex.Lock() - defer q.mutex.Unlock() - - if u, err := q.getUserByIDNoLock(args.UserID); err == nil && u.Deleted { - return database.UserLink{}, deletedUserLinkError - } - - //nolint:gosimple - link := database.UserLink{ - UserID: args.UserID, - LoginType: args.LoginType, - LinkedID: args.LinkedID, - OAuthAccessToken: args.OAuthAccessToken, - OAuthAccessTokenKeyID: args.OAuthAccessTokenKeyID, - OAuthRefreshToken: args.OAuthRefreshToken, - OAuthRefreshTokenKeyID: args.OAuthRefreshTokenKeyID, - OAuthExpiry: args.OAuthExpiry, - Claims: args.Claims, - } - - q.userLinks = append(q.userLinks, link) - - return link, nil -} - -func (q *FakeQuerier) InsertVolumeResourceMonitor(_ context.Context, arg database.InsertVolumeResourceMonitorParams) (database.WorkspaceAgentVolumeResourceMonitor, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.WorkspaceAgentVolumeResourceMonitor{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - monitor := database.WorkspaceAgentVolumeResourceMonitor{ - AgentID: arg.AgentID, - Path: arg.Path, - Enabled: arg.Enabled, - State: arg.State, - Threshold: arg.Threshold, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - DebouncedUntil: arg.DebouncedUntil, - } - - q.workspaceAgentVolumeResourceMonitors = append(q.workspaceAgentVolumeResourceMonitors, monitor) - return monitor, nil -} - -func (q *FakeQuerier) InsertWebpushSubscription(_ context.Context, arg database.InsertWebpushSubscriptionParams) (database.WebpushSubscription, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.WebpushSubscription{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - newSub := database.WebpushSubscription{ - ID: uuid.New(), - UserID: arg.UserID, - CreatedAt: arg.CreatedAt, - Endpoint: arg.Endpoint, - EndpointP256dhKey: arg.EndpointP256dhKey, - EndpointAuthKey: arg.EndpointAuthKey, - } - q.webpushSubscriptions = append(q.webpushSubscriptions, newSub) - return newSub, nil -} - -func (q *FakeQuerier) InsertWorkspace(_ context.Context, arg database.InsertWorkspaceParams) (database.WorkspaceTable, error) { - if err := validateDatabaseType(arg); err != nil { - return database.WorkspaceTable{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - //nolint:gosimple - workspace := database.WorkspaceTable{ - ID: arg.ID, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - OwnerID: arg.OwnerID, - OrganizationID: arg.OrganizationID, - TemplateID: arg.TemplateID, - Name: arg.Name, - AutostartSchedule: arg.AutostartSchedule, - Ttl: arg.Ttl, - LastUsedAt: arg.LastUsedAt, - AutomaticUpdates: arg.AutomaticUpdates, - NextStartAt: arg.NextStartAt, - } - q.workspaces = append(q.workspaces, workspace) - return workspace, nil -} - -func (q *FakeQuerier) InsertWorkspaceAgent(_ context.Context, arg database.InsertWorkspaceAgentParams) (database.WorkspaceAgent, error) { - if err := validateDatabaseType(arg); err != nil { - return database.WorkspaceAgent{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - agent := database.WorkspaceAgent{ - ID: arg.ID, - ParentID: arg.ParentID, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - ResourceID: arg.ResourceID, - AuthToken: arg.AuthToken, - AuthInstanceID: arg.AuthInstanceID, - EnvironmentVariables: arg.EnvironmentVariables, - Name: arg.Name, - Architecture: arg.Architecture, - OperatingSystem: arg.OperatingSystem, - Directory: arg.Directory, - InstanceMetadata: arg.InstanceMetadata, - ResourceMetadata: arg.ResourceMetadata, - ConnectionTimeoutSeconds: arg.ConnectionTimeoutSeconds, - TroubleshootingURL: arg.TroubleshootingURL, - MOTDFile: arg.MOTDFile, - LifecycleState: database.WorkspaceAgentLifecycleStateCreated, - DisplayApps: arg.DisplayApps, - DisplayOrder: arg.DisplayOrder, - APIKeyScope: arg.APIKeyScope, - } - - q.workspaceAgents = append(q.workspaceAgents, agent) - return agent, nil -} - -func (q *FakeQuerier) InsertWorkspaceAgentDevcontainers(_ context.Context, arg database.InsertWorkspaceAgentDevcontainersParams) ([]database.WorkspaceAgentDevcontainer, error) { - err := validateDatabaseType(arg) - if err != nil { - return nil, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for _, agent := range q.workspaceAgents { - if agent.ID == arg.WorkspaceAgentID { - var devcontainers []database.WorkspaceAgentDevcontainer - for i, id := range arg.ID { - devcontainers = append(devcontainers, database.WorkspaceAgentDevcontainer{ - WorkspaceAgentID: arg.WorkspaceAgentID, - CreatedAt: arg.CreatedAt, - ID: id, - Name: arg.Name[i], - WorkspaceFolder: arg.WorkspaceFolder[i], - ConfigPath: arg.ConfigPath[i], - }) - } - q.workspaceAgentDevcontainers = append(q.workspaceAgentDevcontainers, devcontainers...) - return devcontainers, nil - } - } - - return nil, errForeignKeyConstraint -} - -func (q *FakeQuerier) InsertWorkspaceAgentLogSources(_ context.Context, arg database.InsertWorkspaceAgentLogSourcesParams) ([]database.WorkspaceAgentLogSource, error) { - err := validateDatabaseType(arg) - if err != nil { - return nil, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - logSources := make([]database.WorkspaceAgentLogSource, 0) - for index, source := range arg.ID { - logSource := database.WorkspaceAgentLogSource{ - ID: source, - WorkspaceAgentID: arg.WorkspaceAgentID, - CreatedAt: arg.CreatedAt, - DisplayName: arg.DisplayName[index], - Icon: arg.Icon[index], - } - logSources = append(logSources, logSource) - } - q.workspaceAgentLogSources = append(q.workspaceAgentLogSources, logSources...) - return logSources, nil -} - -func (q *FakeQuerier) InsertWorkspaceAgentLogs(_ context.Context, arg database.InsertWorkspaceAgentLogsParams) ([]database.WorkspaceAgentLog, error) { - if err := validateDatabaseType(arg); err != nil { - return nil, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - logs := []database.WorkspaceAgentLog{} - id := int64(0) - if len(q.workspaceAgentLogs) > 0 { - id = q.workspaceAgentLogs[len(q.workspaceAgentLogs)-1].ID - } - outputLength := int32(0) - for index, output := range arg.Output { - id++ - logs = append(logs, database.WorkspaceAgentLog{ - ID: id, - AgentID: arg.AgentID, - CreatedAt: arg.CreatedAt, - Level: arg.Level[index], - LogSourceID: arg.LogSourceID, - Output: output, - }) - // #nosec G115 - Safe conversion as log output length is expected to be within int32 range - outputLength += int32(len(output)) - } - for index, agent := range q.workspaceAgents { - if agent.ID != arg.AgentID { - continue - } - // Greater than 1MB, same as the PostgreSQL constraint! - if agent.LogsLength+outputLength > (1 << 20) { - return nil, &pq.Error{ - Constraint: "max_logs_length", - Table: "workspace_agents", - } - } - agent.LogsLength += outputLength - q.workspaceAgents[index] = agent - break - } - q.workspaceAgentLogs = append(q.workspaceAgentLogs, logs...) - return logs, nil -} - -func (q *FakeQuerier) InsertWorkspaceAgentMetadata(_ context.Context, arg database.InsertWorkspaceAgentMetadataParams) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - //nolint:gosimple - metadatum := database.WorkspaceAgentMetadatum{ - WorkspaceAgentID: arg.WorkspaceAgentID, - Script: arg.Script, - DisplayName: arg.DisplayName, - Key: arg.Key, - Timeout: arg.Timeout, - Interval: arg.Interval, - DisplayOrder: arg.DisplayOrder, - } - - q.workspaceAgentMetadata = append(q.workspaceAgentMetadata, metadatum) - return nil -} - -func (q *FakeQuerier) InsertWorkspaceAgentScriptTimings(_ context.Context, arg database.InsertWorkspaceAgentScriptTimingsParams) (database.WorkspaceAgentScriptTiming, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.WorkspaceAgentScriptTiming{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - timing := database.WorkspaceAgentScriptTiming(arg) - q.workspaceAgentScriptTimings = append(q.workspaceAgentScriptTimings, timing) - - return timing, nil -} - -func (q *FakeQuerier) InsertWorkspaceAgentScripts(_ context.Context, arg database.InsertWorkspaceAgentScriptsParams) ([]database.WorkspaceAgentScript, error) { - err := validateDatabaseType(arg) - if err != nil { - return nil, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - scripts := make([]database.WorkspaceAgentScript, 0) - for index, source := range arg.LogSourceID { - script := database.WorkspaceAgentScript{ - LogSourceID: source, - WorkspaceAgentID: arg.WorkspaceAgentID, - ID: arg.ID[index], - LogPath: arg.LogPath[index], - Script: arg.Script[index], - Cron: arg.Cron[index], - StartBlocksLogin: arg.StartBlocksLogin[index], - RunOnStart: arg.RunOnStart[index], - RunOnStop: arg.RunOnStop[index], - TimeoutSeconds: arg.TimeoutSeconds[index], - CreatedAt: arg.CreatedAt, - } - scripts = append(scripts, script) - } - q.workspaceAgentScripts = append(q.workspaceAgentScripts, scripts...) - return scripts, nil -} - -func (q *FakeQuerier) InsertWorkspaceAgentStats(_ context.Context, arg database.InsertWorkspaceAgentStatsParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - var connectionsByProto []map[string]int64 - if err := json.Unmarshal(arg.ConnectionsByProto, &connectionsByProto); err != nil { - return err - } - for i := 0; i < len(arg.ID); i++ { - cbp, err := json.Marshal(connectionsByProto[i]) - if err != nil { - return xerrors.Errorf("failed to marshal connections_by_proto: %w", err) - } - stat := database.WorkspaceAgentStat{ - ID: arg.ID[i], - CreatedAt: arg.CreatedAt[i], - WorkspaceID: arg.WorkspaceID[i], - AgentID: arg.AgentID[i], - UserID: arg.UserID[i], - ConnectionsByProto: cbp, - ConnectionCount: arg.ConnectionCount[i], - RxPackets: arg.RxPackets[i], - RxBytes: arg.RxBytes[i], - TxPackets: arg.TxPackets[i], - TxBytes: arg.TxBytes[i], - TemplateID: arg.TemplateID[i], - SessionCountVSCode: arg.SessionCountVSCode[i], - SessionCountJetBrains: arg.SessionCountJetBrains[i], - SessionCountReconnectingPTY: arg.SessionCountReconnectingPTY[i], - SessionCountSSH: arg.SessionCountSSH[i], - ConnectionMedianLatencyMS: arg.ConnectionMedianLatencyMS[i], - Usage: arg.Usage[i], - } - q.workspaceAgentStats = append(q.workspaceAgentStats, stat) - } - - return nil -} - -func (q *FakeQuerier) InsertWorkspaceAppStats(_ context.Context, arg database.InsertWorkspaceAppStatsParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - -InsertWorkspaceAppStatsLoop: - for i := 0; i < len(arg.UserID); i++ { - stat := database.WorkspaceAppStat{ - ID: q.workspaceAppStatsLastInsertID + 1, - UserID: arg.UserID[i], - WorkspaceID: arg.WorkspaceID[i], - AgentID: arg.AgentID[i], - AccessMethod: arg.AccessMethod[i], - SlugOrPort: arg.SlugOrPort[i], - SessionID: arg.SessionID[i], - SessionStartedAt: arg.SessionStartedAt[i], - SessionEndedAt: arg.SessionEndedAt[i], - Requests: arg.Requests[i], - } - for j, s := range q.workspaceAppStats { - // Check unique constraint for upsert. - if s.UserID == stat.UserID && s.AgentID == stat.AgentID && s.SessionID == stat.SessionID { - q.workspaceAppStats[j].SessionEndedAt = stat.SessionEndedAt - q.workspaceAppStats[j].Requests = stat.Requests - continue InsertWorkspaceAppStatsLoop - } - } - q.workspaceAppStats = append(q.workspaceAppStats, stat) - q.workspaceAppStatsLastInsertID++ - } - - return nil -} - -func (q *FakeQuerier) InsertWorkspaceAppStatus(_ context.Context, arg database.InsertWorkspaceAppStatusParams) (database.WorkspaceAppStatus, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.WorkspaceAppStatus{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - status := database.WorkspaceAppStatus{ - ID: arg.ID, - CreatedAt: arg.CreatedAt, - WorkspaceID: arg.WorkspaceID, - AgentID: arg.AgentID, - AppID: arg.AppID, - State: arg.State, - Message: arg.Message, - Uri: arg.Uri, - } - q.workspaceAppStatuses = append(q.workspaceAppStatuses, status) - return status, nil -} - -func (q *FakeQuerier) InsertWorkspaceBuild(_ context.Context, arg database.InsertWorkspaceBuildParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - workspaceBuild := database.WorkspaceBuild{ - ID: arg.ID, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - WorkspaceID: arg.WorkspaceID, - TemplateVersionID: arg.TemplateVersionID, - BuildNumber: arg.BuildNumber, - Transition: arg.Transition, - InitiatorID: arg.InitiatorID, - JobID: arg.JobID, - ProvisionerState: arg.ProvisionerState, - Deadline: arg.Deadline, - MaxDeadline: arg.MaxDeadline, - Reason: arg.Reason, - TemplateVersionPresetID: arg.TemplateVersionPresetID, - } - q.workspaceBuilds = append(q.workspaceBuilds, workspaceBuild) - return nil -} - -func (q *FakeQuerier) InsertWorkspaceBuildParameters(_ context.Context, arg database.InsertWorkspaceBuildParametersParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, name := range arg.Name { - q.workspaceBuildParameters = append(q.workspaceBuildParameters, database.WorkspaceBuildParameter{ - WorkspaceBuildID: arg.WorkspaceBuildID, - Name: name, - Value: arg.Value[index], - }) - } - return nil -} - -func (q *FakeQuerier) InsertWorkspaceModule(_ context.Context, arg database.InsertWorkspaceModuleParams) (database.WorkspaceModule, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.WorkspaceModule{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - workspaceModule := database.WorkspaceModule(arg) - q.workspaceModules = append(q.workspaceModules, workspaceModule) - return workspaceModule, nil -} - -func (q *FakeQuerier) InsertWorkspaceProxy(_ context.Context, arg database.InsertWorkspaceProxyParams) (database.WorkspaceProxy, error) { - q.mutex.Lock() - defer q.mutex.Unlock() - - lastRegionID := int32(0) - for _, p := range q.workspaceProxies { - if !p.Deleted && p.Name == arg.Name { - return database.WorkspaceProxy{}, errUniqueConstraint - } - if p.RegionID > lastRegionID { - lastRegionID = p.RegionID - } - } - - p := database.WorkspaceProxy{ - ID: arg.ID, - Name: arg.Name, - DisplayName: arg.DisplayName, - Icon: arg.Icon, - DerpEnabled: arg.DerpEnabled, - DerpOnly: arg.DerpOnly, - TokenHashedSecret: arg.TokenHashedSecret, - RegionID: lastRegionID + 1, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - Deleted: false, - } - q.workspaceProxies = append(q.workspaceProxies, p) - return p, nil -} - -func (q *FakeQuerier) InsertWorkspaceResource(_ context.Context, arg database.InsertWorkspaceResourceParams) (database.WorkspaceResource, error) { - if err := validateDatabaseType(arg); err != nil { - return database.WorkspaceResource{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - //nolint:gosimple - resource := database.WorkspaceResource{ - ID: arg.ID, - CreatedAt: arg.CreatedAt, - JobID: arg.JobID, - Transition: arg.Transition, - Type: arg.Type, - Name: arg.Name, - Hide: arg.Hide, - Icon: arg.Icon, - DailyCost: arg.DailyCost, - ModulePath: arg.ModulePath, - } - q.workspaceResources = append(q.workspaceResources, resource) - return resource, nil -} - -func (q *FakeQuerier) InsertWorkspaceResourceMetadata(_ context.Context, arg database.InsertWorkspaceResourceMetadataParams) ([]database.WorkspaceResourceMetadatum, error) { - if err := validateDatabaseType(arg); err != nil { - return nil, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - metadata := make([]database.WorkspaceResourceMetadatum, 0) - id := int64(1) - if len(q.workspaceResourceMetadata) > 0 { - id = q.workspaceResourceMetadata[len(q.workspaceResourceMetadata)-1].ID - } - for index, key := range arg.Key { - id++ - value := arg.Value[index] - metadata = append(metadata, database.WorkspaceResourceMetadatum{ - ID: id, - WorkspaceResourceID: arg.WorkspaceResourceID, - Key: key, - Value: sql.NullString{ - String: value, - Valid: value != "", - }, - Sensitive: arg.Sensitive[index], - }) - } - q.workspaceResourceMetadata = append(q.workspaceResourceMetadata, metadata...) - return metadata, nil -} - -func (q *FakeQuerier) ListProvisionerKeysByOrganization(_ context.Context, organizationID uuid.UUID) ([]database.ProvisionerKey, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - keys := make([]database.ProvisionerKey, 0) - for _, key := range q.provisionerKeys { - if key.OrganizationID == organizationID { - keys = append(keys, key) - } - } - - return keys, nil -} - -func (q *FakeQuerier) ListProvisionerKeysByOrganizationExcludeReserved(_ context.Context, organizationID uuid.UUID) ([]database.ProvisionerKey, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - keys := make([]database.ProvisionerKey, 0) - for _, key := range q.provisionerKeys { - if key.ID.String() == codersdk.ProvisionerKeyIDBuiltIn || - key.ID.String() == codersdk.ProvisionerKeyIDUserAuth || - key.ID.String() == codersdk.ProvisionerKeyIDPSK { - continue - } - if key.OrganizationID == organizationID { - keys = append(keys, key) - } - } - - return keys, nil -} - -func (q *FakeQuerier) ListWorkspaceAgentPortShares(_ context.Context, workspaceID uuid.UUID) ([]database.WorkspaceAgentPortShare, error) { - q.mutex.Lock() - defer q.mutex.Unlock() - - shares := []database.WorkspaceAgentPortShare{} - for _, share := range q.workspaceAgentPortShares { - if share.WorkspaceID == workspaceID { - shares = append(shares, share) - } - } - - return shares, nil -} - -func (q *FakeQuerier) MarkAllInboxNotificationsAsRead(_ context.Context, arg database.MarkAllInboxNotificationsAsReadParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - for idx, notif := range q.inboxNotifications { - if notif.UserID == arg.UserID && !notif.ReadAt.Valid { - q.inboxNotifications[idx].ReadAt = arg.ReadAt - } - } - - return nil -} - -// nolint:forcetypeassert -func (q *FakeQuerier) OIDCClaimFieldValues(_ context.Context, args database.OIDCClaimFieldValuesParams) ([]string, error) { - orgMembers := q.getOrganizationMemberNoLock(args.OrganizationID) - - var values []string - for _, link := range q.userLinks { - if args.OrganizationID != uuid.Nil { - inOrg := slices.ContainsFunc(orgMembers, func(organizationMember database.OrganizationMember) bool { - return organizationMember.UserID == link.UserID - }) - if !inOrg { - continue - } - } - - if link.LoginType != database.LoginTypeOIDC { - continue - } - - if len(link.Claims.MergedClaims) == 0 { - continue - } - - value, ok := link.Claims.MergedClaims[args.ClaimField] - if !ok { - continue - } - switch value := value.(type) { - case string: - values = append(values, value) - case []string: - values = append(values, value...) - case []any: - for _, v := range value { - if sv, ok := v.(string); ok { - values = append(values, sv) - } - } - default: - continue - } - } - - return slice.Unique(values), nil -} - -func (q *FakeQuerier) OIDCClaimFields(_ context.Context, organizationID uuid.UUID) ([]string, error) { - orgMembers := q.getOrganizationMemberNoLock(organizationID) - - var fields []string - for _, link := range q.userLinks { - if organizationID != uuid.Nil { - inOrg := slices.ContainsFunc(orgMembers, func(organizationMember database.OrganizationMember) bool { - return organizationMember.UserID == link.UserID - }) - if !inOrg { - continue - } - } - - if link.LoginType != database.LoginTypeOIDC { - continue - } - - for k := range link.Claims.MergedClaims { - fields = append(fields, k) - } - } - - return slice.Unique(fields), nil -} - -func (q *FakeQuerier) OrganizationMembers(_ context.Context, arg database.OrganizationMembersParams) ([]database.OrganizationMembersRow, error) { - if err := validateDatabaseType(arg); err != nil { - return []database.OrganizationMembersRow{}, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - tmp := make([]database.OrganizationMembersRow, 0) - for _, organizationMember := range q.organizationMembers { - if arg.OrganizationID != uuid.Nil && organizationMember.OrganizationID != arg.OrganizationID { - continue - } - - if arg.UserID != uuid.Nil && organizationMember.UserID != arg.UserID { - continue - } - - user, _ := q.getUserByIDNoLock(organizationMember.UserID) - tmp = append(tmp, database.OrganizationMembersRow{ - OrganizationMember: organizationMember, - Username: user.Username, - AvatarURL: user.AvatarURL, - Name: user.Name, - Email: user.Email, - GlobalRoles: user.RBACRoles, - }) - } - return tmp, nil -} - -func (q *FakeQuerier) PaginatedOrganizationMembers(_ context.Context, arg database.PaginatedOrganizationMembersParams) ([]database.PaginatedOrganizationMembersRow, error) { - err := validateDatabaseType(arg) - if err != nil { - return nil, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - // All of the members in the organization - orgMembers := make([]database.OrganizationMember, 0) - for _, mem := range q.organizationMembers { - if mem.OrganizationID != arg.OrganizationID { - continue - } - - orgMembers = append(orgMembers, mem) - } - - selectedMembers := make([]database.PaginatedOrganizationMembersRow, 0) - - skippedMembers := 0 - for _, organizationMember := range orgMembers { - if skippedMembers < int(arg.OffsetOpt) { - skippedMembers++ - continue - } - - // if the limit is set to 0 we treat that as returning all of the org members - if int(arg.LimitOpt) != 0 && len(selectedMembers) >= int(arg.LimitOpt) { - break - } - - user, _ := q.getUserByIDNoLock(organizationMember.UserID) - selectedMembers = append(selectedMembers, database.PaginatedOrganizationMembersRow{ - OrganizationMember: organizationMember, - Username: user.Username, - AvatarURL: user.AvatarURL, - Name: user.Name, - Email: user.Email, - GlobalRoles: user.RBACRoles, - Count: int64(len(orgMembers)), - }) - } - return selectedMembers, nil -} - -func (q *FakeQuerier) ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(_ context.Context, templateID uuid.UUID) error { - err := validateDatabaseType(templateID) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for _, workspace := range q.workspaces { - if workspace.TemplateID != templateID { - continue - } - for i, share := range q.workspaceAgentPortShares { - if share.WorkspaceID != workspace.ID { - continue - } - if share.ShareLevel == database.AppSharingLevelPublic { - share.ShareLevel = database.AppSharingLevelAuthenticated - } - q.workspaceAgentPortShares[i] = share - } - } - - return nil -} - -func (q *FakeQuerier) RegisterWorkspaceProxy(_ context.Context, arg database.RegisterWorkspaceProxyParams) (database.WorkspaceProxy, error) { - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, p := range q.workspaceProxies { - if p.ID == arg.ID { - p.Url = arg.Url - p.WildcardHostname = arg.WildcardHostname - p.DerpEnabled = arg.DerpEnabled - p.DerpOnly = arg.DerpOnly - p.Version = arg.Version - p.UpdatedAt = dbtime.Now() - q.workspaceProxies[i] = p - return p, nil - } - } - return database.WorkspaceProxy{}, sql.ErrNoRows -} - -func (q *FakeQuerier) RemoveUserFromAllGroups(_ context.Context, userID uuid.UUID) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - newMembers := q.groupMembers[:0] - for _, member := range q.groupMembers { - if member.UserID == userID { - continue - } - newMembers = append(newMembers, member) - } - q.groupMembers = newMembers - - return nil -} - -func (q *FakeQuerier) RemoveUserFromGroups(_ context.Context, arg database.RemoveUserFromGroupsParams) ([]uuid.UUID, error) { - err := validateDatabaseType(arg) - if err != nil { - return nil, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - removed := make([]uuid.UUID, 0) - q.data.groupMembers = slices.DeleteFunc(q.data.groupMembers, func(groupMember database.GroupMemberTable) bool { - // Delete all group members that match the arguments. - if groupMember.UserID != arg.UserID { - // Not the right user, ignore. - return false - } - - if !slices.Contains(arg.GroupIds, groupMember.GroupID) { - return false - } - - removed = append(removed, groupMember.GroupID) - return true - }) - - return removed, nil -} - -func (q *FakeQuerier) RevokeDBCryptKey(_ context.Context, activeKeyDigest string) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - for i := range q.dbcryptKeys { - key := q.dbcryptKeys[i] - - // Is the key already revoked? - if !key.ActiveKeyDigest.Valid { - continue - } - - if key.ActiveKeyDigest.String != activeKeyDigest { - continue - } - - // Check for foreign key constraints. - for _, ul := range q.userLinks { - if (ul.OAuthAccessTokenKeyID.Valid && ul.OAuthAccessTokenKeyID.String == activeKeyDigest) || - (ul.OAuthRefreshTokenKeyID.Valid && ul.OAuthRefreshTokenKeyID.String == activeKeyDigest) { - return errForeignKeyConstraint - } - } - for _, gal := range q.externalAuthLinks { - if (gal.OAuthAccessTokenKeyID.Valid && gal.OAuthAccessTokenKeyID.String == activeKeyDigest) || - (gal.OAuthRefreshTokenKeyID.Valid && gal.OAuthRefreshTokenKeyID.String == activeKeyDigest) { - return errForeignKeyConstraint - } - } - - // Revoke the key. - q.dbcryptKeys[i].RevokedAt = sql.NullTime{Time: dbtime.Now(), Valid: true} - q.dbcryptKeys[i].RevokedKeyDigest = sql.NullString{String: key.ActiveKeyDigest.String, Valid: true} - q.dbcryptKeys[i].ActiveKeyDigest = sql.NullString{} - return nil - } - - return sql.ErrNoRows -} - -func (*FakeQuerier) TryAcquireLock(_ context.Context, _ int64) (bool, error) { - return false, xerrors.New("TryAcquireLock must only be called within a transaction") -} - -func (q *FakeQuerier) UnarchiveTemplateVersion(_ context.Context, arg database.UnarchiveTemplateVersionParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, v := range q.data.templateVersions { - if v.ID == arg.TemplateVersionID { - v.Archived = false - v.UpdatedAt = arg.UpdatedAt - q.data.templateVersions[i] = v - return nil - } - } - - return sql.ErrNoRows -} - -func (q *FakeQuerier) UnfavoriteWorkspace(_ context.Context, arg uuid.UUID) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i := 0; i < len(q.workspaces); i++ { - if q.workspaces[i].ID != arg { - continue - } - q.workspaces[i].Favorite = false - return nil - } - - return nil -} - -func (q *FakeQuerier) UpdateAPIKeyByID(_ context.Context, arg database.UpdateAPIKeyByIDParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, apiKey := range q.apiKeys { - if apiKey.ID != arg.ID { - continue - } - apiKey.LastUsed = arg.LastUsed - apiKey.ExpiresAt = arg.ExpiresAt - apiKey.IPAddress = arg.IPAddress - q.apiKeys[index] = apiKey - return nil - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateCryptoKeyDeletesAt(_ context.Context, arg database.UpdateCryptoKeyDeletesAtParams) (database.CryptoKey, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.CryptoKey{}, err - } - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, key := range q.cryptoKeys { - if key.Feature == arg.Feature && key.Sequence == arg.Sequence { - key.DeletesAt = arg.DeletesAt - q.cryptoKeys[i] = key - return key, nil - } - } - - return database.CryptoKey{}, sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateCustomRole(_ context.Context, arg database.UpdateCustomRoleParams) (database.CustomRole, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.CustomRole{}, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - for i := range q.customRoles { - if strings.EqualFold(q.customRoles[i].Name, arg.Name) && - q.customRoles[i].OrganizationID.UUID == arg.OrganizationID.UUID { - q.customRoles[i].DisplayName = arg.DisplayName - q.customRoles[i].OrganizationID = arg.OrganizationID - q.customRoles[i].SitePermissions = arg.SitePermissions - q.customRoles[i].OrgPermissions = arg.OrgPermissions - q.customRoles[i].UserPermissions = arg.UserPermissions - q.customRoles[i].UpdatedAt = dbtime.Now() - return q.customRoles[i], nil - } - } - return database.CustomRole{}, sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateExternalAuthLink(_ context.Context, arg database.UpdateExternalAuthLinkParams) (database.ExternalAuthLink, error) { - if err := validateDatabaseType(arg); err != nil { - return database.ExternalAuthLink{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - for index, gitAuthLink := range q.externalAuthLinks { - if gitAuthLink.ProviderID != arg.ProviderID { - continue - } - if gitAuthLink.UserID != arg.UserID { - continue - } - gitAuthLink.UpdatedAt = arg.UpdatedAt - gitAuthLink.OAuthAccessToken = arg.OAuthAccessToken - gitAuthLink.OAuthAccessTokenKeyID = arg.OAuthAccessTokenKeyID - gitAuthLink.OAuthRefreshToken = arg.OAuthRefreshToken - gitAuthLink.OAuthRefreshTokenKeyID = arg.OAuthRefreshTokenKeyID - gitAuthLink.OAuthExpiry = arg.OAuthExpiry - gitAuthLink.OAuthExtra = arg.OAuthExtra - q.externalAuthLinks[index] = gitAuthLink - - return gitAuthLink, nil - } - return database.ExternalAuthLink{}, sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateExternalAuthLinkRefreshToken(_ context.Context, arg database.UpdateExternalAuthLinkRefreshTokenParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - for index, gitAuthLink := range q.externalAuthLinks { - if gitAuthLink.ProviderID != arg.ProviderID { - continue - } - if gitAuthLink.UserID != arg.UserID { - continue - } - gitAuthLink.UpdatedAt = arg.UpdatedAt - gitAuthLink.OAuthRefreshToken = arg.OAuthRefreshToken - q.externalAuthLinks[index] = gitAuthLink - - return nil - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateGitSSHKey(_ context.Context, arg database.UpdateGitSSHKeyParams) (database.GitSSHKey, error) { - if err := validateDatabaseType(arg); err != nil { - return database.GitSSHKey{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, key := range q.gitSSHKey { - if key.UserID != arg.UserID { - continue - } - key.UpdatedAt = arg.UpdatedAt - key.PrivateKey = arg.PrivateKey - key.PublicKey = arg.PublicKey - q.gitSSHKey[index] = key - return key, nil - } - return database.GitSSHKey{}, sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateGroupByID(_ context.Context, arg database.UpdateGroupByIDParams) (database.Group, error) { - if err := validateDatabaseType(arg); err != nil { - return database.Group{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, group := range q.groups { - if group.ID == arg.ID { - group.DisplayName = arg.DisplayName - group.Name = arg.Name - group.AvatarURL = arg.AvatarURL - group.QuotaAllowance = arg.QuotaAllowance - q.groups[i] = group - return group, nil - } - } - return database.Group{}, sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateInactiveUsersToDormant(_ context.Context, params database.UpdateInactiveUsersToDormantParams) ([]database.UpdateInactiveUsersToDormantRow, error) { - q.mutex.Lock() - defer q.mutex.Unlock() - - var updated []database.UpdateInactiveUsersToDormantRow - for index, user := range q.users { - if user.Status == database.UserStatusActive && user.LastSeenAt.Before(params.LastSeenAfter) && !user.IsSystem { - q.users[index].Status = database.UserStatusDormant - q.users[index].UpdatedAt = params.UpdatedAt - updated = append(updated, database.UpdateInactiveUsersToDormantRow{ - ID: user.ID, - Email: user.Email, - Username: user.Username, - LastSeenAt: user.LastSeenAt, - }) - q.userStatusChanges = append(q.userStatusChanges, database.UserStatusChange{ - UserID: user.ID, - NewStatus: database.UserStatusDormant, - ChangedAt: params.UpdatedAt, - }) - } - } - - if len(updated) == 0 { - return nil, sql.ErrNoRows - } - - return updated, nil -} - -func (q *FakeQuerier) UpdateInboxNotificationReadStatus(_ context.Context, arg database.UpdateInboxNotificationReadStatusParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i := range q.inboxNotifications { - if q.inboxNotifications[i].ID == arg.ID { - q.inboxNotifications[i].ReadAt = arg.ReadAt - } - } - - return nil -} - -func (q *FakeQuerier) UpdateMemberRoles(_ context.Context, arg database.UpdateMemberRolesParams) (database.OrganizationMember, error) { - if err := validateDatabaseType(arg); err != nil { - return database.OrganizationMember{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, mem := range q.organizationMembers { - if mem.UserID == arg.UserID && mem.OrganizationID == arg.OrgID { - uniqueRoles := make([]string, 0, len(arg.GrantedRoles)) - exist := make(map[string]struct{}) - for _, r := range arg.GrantedRoles { - if _, ok := exist[r]; ok { - continue - } - exist[r] = struct{}{} - uniqueRoles = append(uniqueRoles, r) - } - sort.Strings(uniqueRoles) - - mem.Roles = uniqueRoles - q.organizationMembers[i] = mem - return mem, nil - } - } - - return database.OrganizationMember{}, sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateMemoryResourceMonitor(_ context.Context, arg database.UpdateMemoryResourceMonitorParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, monitor := range q.workspaceAgentMemoryResourceMonitors { - if monitor.AgentID != arg.AgentID { - continue - } - - monitor.State = arg.State - monitor.UpdatedAt = arg.UpdatedAt - monitor.DebouncedUntil = arg.DebouncedUntil - q.workspaceAgentMemoryResourceMonitors[i] = monitor - return nil - } - - return nil -} - -func (*FakeQuerier) UpdateNotificationTemplateMethodByID(_ context.Context, _ database.UpdateNotificationTemplateMethodByIDParams) (database.NotificationTemplate, error) { - // Not implementing this function because it relies on state in the database which is created with migrations. - // We could consider using code-generation to align the database state and dbmem, but it's not worth it right now. - return database.NotificationTemplate{}, ErrUnimplemented -} - -func (q *FakeQuerier) UpdateOAuth2ProviderAppByClientID(ctx context.Context, arg database.UpdateOAuth2ProviderAppByClientIDParams) (database.OAuth2ProviderApp, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.OAuth2ProviderApp{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, app := range q.oauth2ProviderApps { - if app.ID == arg.ID { - app.UpdatedAt = arg.UpdatedAt - app.Name = arg.Name - app.Icon = arg.Icon - app.CallbackURL = arg.CallbackURL - app.RedirectUris = arg.RedirectUris - app.GrantTypes = arg.GrantTypes - app.ResponseTypes = arg.ResponseTypes - app.TokenEndpointAuthMethod = arg.TokenEndpointAuthMethod - app.Scope = arg.Scope - app.Contacts = arg.Contacts - app.ClientUri = arg.ClientUri - app.LogoUri = arg.LogoUri - app.TosUri = arg.TosUri - app.PolicyUri = arg.PolicyUri - app.JwksUri = arg.JwksUri - app.Jwks = arg.Jwks - app.SoftwareID = arg.SoftwareID - app.SoftwareVersion = arg.SoftwareVersion - - // Apply RFC-compliant defaults to match database migration defaults - if !app.ClientType.Valid { - app.ClientType = sql.NullString{String: "confidential", Valid: true} - } - if !app.DynamicallyRegistered.Valid { - app.DynamicallyRegistered = sql.NullBool{Bool: false, Valid: true} - } - if len(app.GrantTypes) == 0 { - app.GrantTypes = []string{"authorization_code", "refresh_token"} - } - if len(app.ResponseTypes) == 0 { - app.ResponseTypes = []string{"code"} - } - if !app.TokenEndpointAuthMethod.Valid { - app.TokenEndpointAuthMethod = sql.NullString{String: "client_secret_basic", Valid: true} - } - if !app.Scope.Valid { - app.Scope = sql.NullString{String: "", Valid: true} - } - if app.Contacts == nil { - app.Contacts = []string{} - } - - q.oauth2ProviderApps[i] = app - return app, nil - } - } - return database.OAuth2ProviderApp{}, sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateOAuth2ProviderAppByID(_ context.Context, arg database.UpdateOAuth2ProviderAppByIDParams) (database.OAuth2ProviderApp, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.OAuth2ProviderApp{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for _, app := range q.oauth2ProviderApps { - if app.Name == arg.Name && app.ID != arg.ID { - return database.OAuth2ProviderApp{}, errUniqueConstraint - } - } - - for index, app := range q.oauth2ProviderApps { - if app.ID == arg.ID { - app.UpdatedAt = arg.UpdatedAt - app.Name = arg.Name - app.Icon = arg.Icon - app.CallbackURL = arg.CallbackURL - app.RedirectUris = arg.RedirectUris - app.ClientType = arg.ClientType - app.DynamicallyRegistered = arg.DynamicallyRegistered - app.ClientSecretExpiresAt = arg.ClientSecretExpiresAt - app.GrantTypes = arg.GrantTypes - app.ResponseTypes = arg.ResponseTypes - app.TokenEndpointAuthMethod = arg.TokenEndpointAuthMethod - app.Scope = arg.Scope - app.Contacts = arg.Contacts - app.ClientUri = arg.ClientUri - app.LogoUri = arg.LogoUri - app.TosUri = arg.TosUri - app.PolicyUri = arg.PolicyUri - app.JwksUri = arg.JwksUri - app.Jwks = arg.Jwks - app.SoftwareID = arg.SoftwareID - app.SoftwareVersion = arg.SoftwareVersion - - // Apply RFC-compliant defaults to match database migration defaults - if !app.ClientType.Valid { - app.ClientType = sql.NullString{String: "confidential", Valid: true} - } - if !app.DynamicallyRegistered.Valid { - app.DynamicallyRegistered = sql.NullBool{Bool: false, Valid: true} - } - if len(app.GrantTypes) == 0 { - app.GrantTypes = []string{"authorization_code", "refresh_token"} - } - if len(app.ResponseTypes) == 0 { - app.ResponseTypes = []string{"code"} - } - if !app.TokenEndpointAuthMethod.Valid { - app.TokenEndpointAuthMethod = sql.NullString{String: "client_secret_basic", Valid: true} - } - if !app.Scope.Valid { - app.Scope = sql.NullString{String: "", Valid: true} - } - if app.Contacts == nil { - app.Contacts = []string{} - } - - q.oauth2ProviderApps[index] = app - return app, nil - } - } - return database.OAuth2ProviderApp{}, sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateOAuth2ProviderAppSecretByID(_ context.Context, arg database.UpdateOAuth2ProviderAppSecretByIDParams) (database.OAuth2ProviderAppSecret, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.OAuth2ProviderAppSecret{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, secret := range q.oauth2ProviderAppSecrets { - if secret.ID == arg.ID { - newSecret := database.OAuth2ProviderAppSecret{ - ID: arg.ID, - CreatedAt: secret.CreatedAt, - SecretPrefix: secret.SecretPrefix, - HashedSecret: secret.HashedSecret, - DisplaySecret: secret.DisplaySecret, - AppID: secret.AppID, - LastUsedAt: arg.LastUsedAt, - } - q.oauth2ProviderAppSecrets[index] = newSecret - return newSecret, nil - } - } - return database.OAuth2ProviderAppSecret{}, sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateOrganization(_ context.Context, arg database.UpdateOrganizationParams) (database.Organization, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.Organization{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - // Enforce the unique constraint, because the API endpoint relies on the database catching - // non-unique names during updates. - for _, org := range q.organizations { - if org.Name == arg.Name && org.ID != arg.ID { - return database.Organization{}, errUniqueConstraint - } - } - - for i, org := range q.organizations { - if org.ID == arg.ID { - org.Name = arg.Name - org.DisplayName = arg.DisplayName - org.Description = arg.Description - org.Icon = arg.Icon - q.organizations[i] = org - return org, nil - } - } - return database.Organization{}, sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateOrganizationDeletedByID(_ context.Context, arg database.UpdateOrganizationDeletedByIDParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, organization := range q.organizations { - if organization.ID != arg.ID || organization.IsDefault { - continue - } - organization.Deleted = true - organization.UpdatedAt = arg.UpdatedAt - q.organizations[index] = organization - return nil - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdatePresetPrebuildStatus(ctx context.Context, arg database.UpdatePresetPrebuildStatusParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, preset := range q.presets { - if preset.ID == arg.PresetID { - preset.PrebuildStatus = arg.Status - return nil - } - } - - return xerrors.Errorf("preset %v does not exist", arg.PresetID) -} - -func (q *FakeQuerier) UpdateProvisionerDaemonLastSeenAt(_ context.Context, arg database.UpdateProvisionerDaemonLastSeenAtParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for idx := range q.provisionerDaemons { - if q.provisionerDaemons[idx].ID != arg.ID { - continue - } - if q.provisionerDaemons[idx].LastSeenAt.Time.After(arg.LastSeenAt.Time) { - continue - } - q.provisionerDaemons[idx].LastSeenAt = arg.LastSeenAt - return nil - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateProvisionerJobByID(_ context.Context, arg database.UpdateProvisionerJobByIDParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, job := range q.provisionerJobs { - if arg.ID != job.ID { - continue - } - job.UpdatedAt = arg.UpdatedAt - job.JobStatus = provisionerJobStatus(job) - q.provisionerJobs[index] = job - return nil - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateProvisionerJobWithCancelByID(_ context.Context, arg database.UpdateProvisionerJobWithCancelByIDParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, job := range q.provisionerJobs { - if arg.ID != job.ID { - continue - } - job.CanceledAt = arg.CanceledAt - job.CompletedAt = arg.CompletedAt - job.JobStatus = provisionerJobStatus(job) - q.provisionerJobs[index] = job - return nil - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateProvisionerJobWithCompleteByID(_ context.Context, arg database.UpdateProvisionerJobWithCompleteByIDParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, job := range q.provisionerJobs { - if arg.ID != job.ID { - continue - } - job.UpdatedAt = arg.UpdatedAt - job.CompletedAt = arg.CompletedAt - job.Error = arg.Error - job.ErrorCode = arg.ErrorCode - job.JobStatus = provisionerJobStatus(job) - q.provisionerJobs[index] = job - return nil - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateProvisionerJobWithCompleteWithStartedAtByID(_ context.Context, arg database.UpdateProvisionerJobWithCompleteWithStartedAtByIDParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, job := range q.provisionerJobs { - if arg.ID != job.ID { - continue - } - job.UpdatedAt = arg.UpdatedAt - job.CompletedAt = arg.CompletedAt - job.Error = arg.Error - job.ErrorCode = arg.ErrorCode - job.StartedAt = arg.StartedAt - job.JobStatus = provisionerJobStatus(job) - q.provisionerJobs[index] = job - return nil - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateReplica(_ context.Context, arg database.UpdateReplicaParams) (database.Replica, error) { - if err := validateDatabaseType(arg); err != nil { - return database.Replica{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, replica := range q.replicas { - if replica.ID != arg.ID { - continue - } - replica.Hostname = arg.Hostname - replica.StartedAt = arg.StartedAt - replica.StoppedAt = arg.StoppedAt - replica.UpdatedAt = arg.UpdatedAt - replica.RelayAddress = arg.RelayAddress - replica.RegionID = arg.RegionID - replica.Version = arg.Version - replica.Error = arg.Error - replica.DatabaseLatency = arg.DatabaseLatency - replica.Primary = arg.Primary - q.replicas[index] = replica - return replica, nil - } - return database.Replica{}, sql.ErrNoRows -} - -func (*FakeQuerier) UpdateTailnetPeerStatusByCoordinator(context.Context, database.UpdateTailnetPeerStatusByCoordinatorParams) error { - return ErrUnimplemented -} - -func (q *FakeQuerier) UpdateTemplateACLByID(_ context.Context, arg database.UpdateTemplateACLByIDParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, template := range q.templates { - if template.ID == arg.ID { - template.GroupACL = arg.GroupACL - template.UserACL = arg.UserACL - - q.templates[i] = template - return nil - } - } - - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateTemplateAccessControlByID(_ context.Context, arg database.UpdateTemplateAccessControlByIDParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for idx, tpl := range q.templates { - if tpl.ID != arg.ID { - continue - } - q.templates[idx].RequireActiveVersion = arg.RequireActiveVersion - q.templates[idx].Deprecated = arg.Deprecated - return nil - } - - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateTemplateActiveVersionByID(_ context.Context, arg database.UpdateTemplateActiveVersionByIDParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, template := range q.templates { - if template.ID != arg.ID { - continue - } - template.ActiveVersionID = arg.ActiveVersionID - template.UpdatedAt = arg.UpdatedAt - q.templates[index] = template - return nil - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateTemplateDeletedByID(_ context.Context, arg database.UpdateTemplateDeletedByIDParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, template := range q.templates { - if template.ID != arg.ID { - continue - } - template.Deleted = arg.Deleted - template.UpdatedAt = arg.UpdatedAt - q.templates[index] = template - return nil - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateTemplateMetaByID(_ context.Context, arg database.UpdateTemplateMetaByIDParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for idx, tpl := range q.templates { - if tpl.ID != arg.ID { - continue - } - tpl.UpdatedAt = dbtime.Now() - tpl.Name = arg.Name - tpl.DisplayName = arg.DisplayName - tpl.Description = arg.Description - tpl.Icon = arg.Icon - tpl.GroupACL = arg.GroupACL - tpl.AllowUserCancelWorkspaceJobs = arg.AllowUserCancelWorkspaceJobs - tpl.MaxPortSharingLevel = arg.MaxPortSharingLevel - tpl.UseClassicParameterFlow = arg.UseClassicParameterFlow - q.templates[idx] = tpl - return nil - } - - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateTemplateScheduleByID(_ context.Context, arg database.UpdateTemplateScheduleByIDParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for idx, tpl := range q.templates { - if tpl.ID != arg.ID { - continue - } - tpl.AllowUserAutostart = arg.AllowUserAutostart - tpl.AllowUserAutostop = arg.AllowUserAutostop - tpl.UpdatedAt = dbtime.Now() - tpl.DefaultTTL = arg.DefaultTTL - tpl.ActivityBump = arg.ActivityBump - tpl.AutostopRequirementDaysOfWeek = arg.AutostopRequirementDaysOfWeek - tpl.AutostopRequirementWeeks = arg.AutostopRequirementWeeks - tpl.AutostartBlockDaysOfWeek = arg.AutostartBlockDaysOfWeek - tpl.FailureTTL = arg.FailureTTL - tpl.TimeTilDormant = arg.TimeTilDormant - tpl.TimeTilDormantAutoDelete = arg.TimeTilDormantAutoDelete - q.templates[idx] = tpl - return nil - } - - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateTemplateVersionAITaskByJobID(_ context.Context, arg database.UpdateTemplateVersionAITaskByJobIDParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, templateVersion := range q.templateVersions { - if templateVersion.JobID != arg.JobID { - continue - } - templateVersion.HasAITask = arg.HasAITask - templateVersion.UpdatedAt = arg.UpdatedAt - q.templateVersions[index] = templateVersion - return nil - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateTemplateVersionByID(_ context.Context, arg database.UpdateTemplateVersionByIDParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, templateVersion := range q.templateVersions { - if templateVersion.ID != arg.ID { - continue - } - templateVersion.TemplateID = arg.TemplateID - templateVersion.UpdatedAt = arg.UpdatedAt - templateVersion.Name = arg.Name - templateVersion.Message = arg.Message - q.templateVersions[index] = templateVersion - return nil - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateTemplateVersionDescriptionByJobID(_ context.Context, arg database.UpdateTemplateVersionDescriptionByJobIDParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, templateVersion := range q.templateVersions { - if templateVersion.JobID != arg.JobID { - continue - } - templateVersion.Readme = arg.Readme - templateVersion.UpdatedAt = arg.UpdatedAt - q.templateVersions[index] = templateVersion - return nil - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateTemplateVersionExternalAuthProvidersByJobID(_ context.Context, arg database.UpdateTemplateVersionExternalAuthProvidersByJobIDParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, templateVersion := range q.templateVersions { - if templateVersion.JobID != arg.JobID { - continue - } - templateVersion.ExternalAuthProviders = arg.ExternalAuthProviders - templateVersion.UpdatedAt = arg.UpdatedAt - q.templateVersions[index] = templateVersion - return nil - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateTemplateWorkspacesLastUsedAt(_ context.Context, arg database.UpdateTemplateWorkspacesLastUsedAtParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, ws := range q.workspaces { - if ws.TemplateID != arg.TemplateID { - continue - } - ws.LastUsedAt = arg.LastUsedAt - q.workspaces[i] = ws - } - - return nil -} - -func (q *FakeQuerier) UpdateUserDeletedByID(_ context.Context, id uuid.UUID) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, u := range q.users { - if u.ID == id { - u.Deleted = true - q.users[i] = u - // NOTE: In the real world, this is done by a trigger. - q.apiKeys = slices.DeleteFunc(q.apiKeys, func(u database.APIKey) bool { - return id == u.UserID - }) - - q.userLinks = slices.DeleteFunc(q.userLinks, func(u database.UserLink) bool { - return id == u.UserID - }) - return nil - } - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateUserGithubComUserID(_ context.Context, arg database.UpdateUserGithubComUserIDParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, user := range q.users { - if user.ID != arg.ID { - continue - } - user.GithubComUserID = arg.GithubComUserID - q.users[i] = user - return nil - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateUserHashedOneTimePasscode(_ context.Context, arg database.UpdateUserHashedOneTimePasscodeParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, user := range q.users { - if user.ID != arg.ID { - continue - } - user.HashedOneTimePasscode = arg.HashedOneTimePasscode - user.OneTimePasscodeExpiresAt = arg.OneTimePasscodeExpiresAt - q.users[i] = user - } - return nil -} - -func (q *FakeQuerier) UpdateUserHashedPassword(_ context.Context, arg database.UpdateUserHashedPasswordParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, user := range q.users { - if user.ID != arg.ID { - continue - } - user.HashedPassword = arg.HashedPassword - user.HashedOneTimePasscode = nil - user.OneTimePasscodeExpiresAt = sql.NullTime{} - q.users[i] = user - return nil - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateUserLastSeenAt(_ context.Context, arg database.UpdateUserLastSeenAtParams) (database.User, error) { - if err := validateDatabaseType(arg); err != nil { - return database.User{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, user := range q.users { - if user.ID != arg.ID { - continue - } - user.LastSeenAt = arg.LastSeenAt - user.UpdatedAt = arg.UpdatedAt - q.users[index] = user - return user, nil - } - return database.User{}, sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateUserLink(_ context.Context, params database.UpdateUserLinkParams) (database.UserLink, error) { - if err := validateDatabaseType(params); err != nil { - return database.UserLink{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - if u, err := q.getUserByIDNoLock(params.UserID); err == nil && u.Deleted { - return database.UserLink{}, deletedUserLinkError - } - - for i, link := range q.userLinks { - if link.UserID == params.UserID && link.LoginType == params.LoginType { - link.OAuthAccessToken = params.OAuthAccessToken - link.OAuthAccessTokenKeyID = params.OAuthAccessTokenKeyID - link.OAuthRefreshToken = params.OAuthRefreshToken - link.OAuthRefreshTokenKeyID = params.OAuthRefreshTokenKeyID - link.OAuthExpiry = params.OAuthExpiry - link.Claims = params.Claims - - q.userLinks[i] = link - return link, nil - } - } - - return database.UserLink{}, sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateUserLinkedID(_ context.Context, params database.UpdateUserLinkedIDParams) (database.UserLink, error) { - if err := validateDatabaseType(params); err != nil { - return database.UserLink{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, link := range q.userLinks { - if link.UserID == params.UserID && link.LoginType == params.LoginType { - link.LinkedID = params.LinkedID - - q.userLinks[i] = link - return link, nil - } - } - - return database.UserLink{}, sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateUserLoginType(_ context.Context, arg database.UpdateUserLoginTypeParams) (database.User, error) { - if err := validateDatabaseType(arg); err != nil { - return database.User{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, u := range q.users { - if u.ID == arg.UserID { - u.LoginType = arg.NewLoginType - if arg.NewLoginType != database.LoginTypePassword { - u.HashedPassword = []byte{} - } - q.users[i] = u - return u, nil - } - } - return database.User{}, sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateUserNotificationPreferences(_ context.Context, arg database.UpdateUserNotificationPreferencesParams) (int64, error) { - err := validateDatabaseType(arg) - if err != nil { - return 0, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - var upserted int64 - for i := range arg.NotificationTemplateIds { - var ( - found bool - templateID = arg.NotificationTemplateIds[i] - disabled = arg.Disableds[i] - ) - - for j, np := range q.notificationPreferences { - if np.UserID != arg.UserID { - continue - } - - if np.NotificationTemplateID != templateID { - continue - } - - np.Disabled = disabled - np.UpdatedAt = dbtime.Now() - q.notificationPreferences[j] = np - - upserted++ - found = true - break - } - - if !found { - np := database.NotificationPreference{ - Disabled: disabled, - UserID: arg.UserID, - NotificationTemplateID: templateID, - CreatedAt: dbtime.Now(), - UpdatedAt: dbtime.Now(), - } - q.notificationPreferences = append(q.notificationPreferences, np) - upserted++ - } - } - - return upserted, nil -} - -func (q *FakeQuerier) UpdateUserProfile(_ context.Context, arg database.UpdateUserProfileParams) (database.User, error) { - if err := validateDatabaseType(arg); err != nil { - return database.User{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, user := range q.users { - if user.ID != arg.ID { - continue - } - user.Email = arg.Email - user.Username = arg.Username - user.AvatarURL = arg.AvatarURL - user.Name = arg.Name - q.users[index] = user - return user, nil - } - return database.User{}, sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateUserQuietHoursSchedule(_ context.Context, arg database.UpdateUserQuietHoursScheduleParams) (database.User, error) { - if err := validateDatabaseType(arg); err != nil { - return database.User{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, user := range q.users { - if user.ID != arg.ID { - continue - } - user.QuietHoursSchedule = arg.QuietHoursSchedule - q.users[index] = user - return user, nil - } - return database.User{}, sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateUserRoles(_ context.Context, arg database.UpdateUserRolesParams) (database.User, error) { - if err := validateDatabaseType(arg); err != nil { - return database.User{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, user := range q.users { - if user.ID != arg.ID { - continue - } - - // Set new roles - user.RBACRoles = slice.Unique(arg.GrantedRoles) - // Remove duplicates and sort - uniqueRoles := make([]string, 0, len(user.RBACRoles)) - exist := make(map[string]struct{}) - for _, r := range user.RBACRoles { - if _, ok := exist[r]; ok { - continue - } - exist[r] = struct{}{} - uniqueRoles = append(uniqueRoles, r) - } - sort.Strings(uniqueRoles) - user.RBACRoles = uniqueRoles - - q.users[index] = user - return user, nil - } - return database.User{}, sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateUserStatus(_ context.Context, arg database.UpdateUserStatusParams) (database.User, error) { - if err := validateDatabaseType(arg); err != nil { - return database.User{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, user := range q.users { - if user.ID != arg.ID { - continue - } - user.Status = arg.Status - user.UpdatedAt = arg.UpdatedAt - q.users[index] = user - - q.userStatusChanges = append(q.userStatusChanges, database.UserStatusChange{ - UserID: user.ID, - NewStatus: user.Status, - ChangedAt: user.UpdatedAt, - }) - return user, nil - } - return database.User{}, sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateUserTerminalFont(ctx context.Context, arg database.UpdateUserTerminalFontParams) (database.UserConfig, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.UserConfig{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, uc := range q.userConfigs { - if uc.UserID != arg.UserID || uc.Key != "terminal_font" { - continue - } - uc.Value = arg.TerminalFont - q.userConfigs[i] = uc - return uc, nil - } - - uc := database.UserConfig{ - UserID: arg.UserID, - Key: "terminal_font", - Value: arg.TerminalFont, - } - q.userConfigs = append(q.userConfigs, uc) - return uc, nil -} - -func (q *FakeQuerier) UpdateUserThemePreference(_ context.Context, arg database.UpdateUserThemePreferenceParams) (database.UserConfig, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.UserConfig{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, uc := range q.userConfigs { - if uc.UserID != arg.UserID || uc.Key != "theme_preference" { - continue - } - uc.Value = arg.ThemePreference - q.userConfigs[i] = uc - return uc, nil - } - - uc := database.UserConfig{ - UserID: arg.UserID, - Key: "theme_preference", - Value: arg.ThemePreference, - } - q.userConfigs = append(q.userConfigs, uc) - return uc, nil -} - -func (q *FakeQuerier) UpdateVolumeResourceMonitor(_ context.Context, arg database.UpdateVolumeResourceMonitorParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, monitor := range q.workspaceAgentVolumeResourceMonitors { - if monitor.AgentID != arg.AgentID || monitor.Path != arg.Path { - continue - } - - monitor.State = arg.State - monitor.UpdatedAt = arg.UpdatedAt - monitor.DebouncedUntil = arg.DebouncedUntil - q.workspaceAgentVolumeResourceMonitors[i] = monitor - return nil - } - - return nil -} - -func (q *FakeQuerier) UpdateWorkspace(_ context.Context, arg database.UpdateWorkspaceParams) (database.WorkspaceTable, error) { - if err := validateDatabaseType(arg); err != nil { - return database.WorkspaceTable{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, workspace := range q.workspaces { - if workspace.Deleted || workspace.ID != arg.ID { - continue - } - for _, other := range q.workspaces { - if other.Deleted || other.ID == workspace.ID || workspace.OwnerID != other.OwnerID { - continue - } - if other.Name == arg.Name { - return database.WorkspaceTable{}, errUniqueConstraint - } - } - - workspace.Name = arg.Name - q.workspaces[i] = workspace - - return workspace, nil - } - - return database.WorkspaceTable{}, sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateWorkspaceAgentConnectionByID(_ context.Context, arg database.UpdateWorkspaceAgentConnectionByIDParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, agent := range q.workspaceAgents { - if agent.ID != arg.ID { - continue - } - agent.FirstConnectedAt = arg.FirstConnectedAt - agent.LastConnectedAt = arg.LastConnectedAt - agent.DisconnectedAt = arg.DisconnectedAt - agent.UpdatedAt = arg.UpdatedAt - agent.LastConnectedReplicaID = arg.LastConnectedReplicaID - q.workspaceAgents[index] = agent - return nil - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateWorkspaceAgentLifecycleStateByID(_ context.Context, arg database.UpdateWorkspaceAgentLifecycleStateByIDParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - for i, agent := range q.workspaceAgents { - if agent.ID == arg.ID { - agent.LifecycleState = arg.LifecycleState - agent.StartedAt = arg.StartedAt - agent.ReadyAt = arg.ReadyAt - q.workspaceAgents[i] = agent - return nil - } - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateWorkspaceAgentLogOverflowByID(_ context.Context, arg database.UpdateWorkspaceAgentLogOverflowByIDParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - for i, agent := range q.workspaceAgents { - if agent.ID == arg.ID { - agent.LogsOverflowed = arg.LogsOverflowed - q.workspaceAgents[i] = agent - return nil - } - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateWorkspaceAgentMetadata(_ context.Context, arg database.UpdateWorkspaceAgentMetadataParams) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, m := range q.workspaceAgentMetadata { - if m.WorkspaceAgentID != arg.WorkspaceAgentID { - continue - } - for j := 0; j < len(arg.Key); j++ { - if m.Key == arg.Key[j] { - q.workspaceAgentMetadata[i].Value = arg.Value[j] - q.workspaceAgentMetadata[i].Error = arg.Error[j] - q.workspaceAgentMetadata[i].CollectedAt = arg.CollectedAt[j] - return nil - } - } - } - - return nil -} - -func (q *FakeQuerier) UpdateWorkspaceAgentStartupByID(_ context.Context, arg database.UpdateWorkspaceAgentStartupByIDParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - if len(arg.Subsystems) > 0 { - seen := map[database.WorkspaceAgentSubsystem]struct{}{ - arg.Subsystems[0]: {}, - } - for i := 1; i < len(arg.Subsystems); i++ { - s := arg.Subsystems[i] - if _, ok := seen[s]; ok { - return xerrors.Errorf("duplicate subsystem %q", s) - } - seen[s] = struct{}{} - - if arg.Subsystems[i-1] > arg.Subsystems[i] { - return xerrors.Errorf("subsystems not sorted: %q > %q", arg.Subsystems[i-1], arg.Subsystems[i]) - } - } - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, agent := range q.workspaceAgents { - if agent.ID != arg.ID { - continue - } - - agent.Version = arg.Version - agent.APIVersion = arg.APIVersion - agent.ExpandedDirectory = arg.ExpandedDirectory - agent.Subsystems = arg.Subsystems - q.workspaceAgents[index] = agent - return nil - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateWorkspaceAppHealthByID(_ context.Context, arg database.UpdateWorkspaceAppHealthByIDParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, app := range q.workspaceApps { - if app.ID != arg.ID { - continue - } - app.Health = arg.Health - q.workspaceApps[index] = app - return nil - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateWorkspaceAutomaticUpdates(_ context.Context, arg database.UpdateWorkspaceAutomaticUpdatesParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, workspace := range q.workspaces { - if workspace.ID != arg.ID { - continue - } - workspace.AutomaticUpdates = arg.AutomaticUpdates - q.workspaces[index] = workspace - return nil - } - - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateWorkspaceAutostart(_ context.Context, arg database.UpdateWorkspaceAutostartParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, workspace := range q.workspaces { - if workspace.ID != arg.ID { - continue - } - workspace.AutostartSchedule = arg.AutostartSchedule - workspace.NextStartAt = arg.NextStartAt - q.workspaces[index] = workspace - return nil - } - - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateWorkspaceBuildAITaskByID(_ context.Context, arg database.UpdateWorkspaceBuildAITaskByIDParams) error { - if arg.HasAITask.Bool && !arg.SidebarAppID.Valid { - return xerrors.Errorf("ai_task_sidebar_app_id is required when has_ai_task is true") - } - if !arg.HasAITask.Valid && arg.SidebarAppID.Valid { - return xerrors.Errorf("ai_task_sidebar_app_id is can only be set when has_ai_task is true") - } - - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, workspaceBuild := range q.workspaceBuilds { - if workspaceBuild.ID != arg.ID { - continue - } - workspaceBuild.HasAITask = arg.HasAITask - workspaceBuild.AITaskSidebarAppID = arg.SidebarAppID - workspaceBuild.UpdatedAt = dbtime.Now() - q.workspaceBuilds[index] = workspaceBuild - return nil - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateWorkspaceBuildCostByID(_ context.Context, arg database.UpdateWorkspaceBuildCostByIDParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, workspaceBuild := range q.workspaceBuilds { - if workspaceBuild.ID != arg.ID { - continue - } - workspaceBuild.DailyCost = arg.DailyCost - q.workspaceBuilds[index] = workspaceBuild - return nil - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateWorkspaceBuildDeadlineByID(_ context.Context, arg database.UpdateWorkspaceBuildDeadlineByIDParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for idx, build := range q.workspaceBuilds { - if build.ID != arg.ID { - continue - } - build.Deadline = arg.Deadline - build.MaxDeadline = arg.MaxDeadline - build.UpdatedAt = arg.UpdatedAt - q.workspaceBuilds[idx] = build - return nil - } - - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateWorkspaceBuildProvisionerStateByID(_ context.Context, arg database.UpdateWorkspaceBuildProvisionerStateByIDParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for idx, build := range q.workspaceBuilds { - if build.ID != arg.ID { - continue - } - build.ProvisionerState = arg.ProvisionerState - build.UpdatedAt = arg.UpdatedAt - q.workspaceBuilds[idx] = build - return nil - } - - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateWorkspaceDeletedByID(_ context.Context, arg database.UpdateWorkspaceDeletedByIDParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, workspace := range q.workspaces { - if workspace.ID != arg.ID { - continue - } - workspace.Deleted = arg.Deleted - q.workspaces[index] = workspace - return nil - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateWorkspaceDormantDeletingAt(_ context.Context, arg database.UpdateWorkspaceDormantDeletingAtParams) (database.WorkspaceTable, error) { - if err := validateDatabaseType(arg); err != nil { - return database.WorkspaceTable{}, err - } - q.mutex.Lock() - defer q.mutex.Unlock() - for index, workspace := range q.workspaces { - if workspace.ID != arg.ID { - continue - } - workspace.DormantAt = arg.DormantAt - if workspace.DormantAt.Time.IsZero() { - workspace.LastUsedAt = dbtime.Now() - workspace.DeletingAt = sql.NullTime{} - } - if !workspace.DormantAt.Time.IsZero() { - var template database.TemplateTable - for _, t := range q.templates { - if t.ID == workspace.TemplateID { - template = t - break - } - } - if template.ID == uuid.Nil { - return database.WorkspaceTable{}, xerrors.Errorf("unable to find workspace template") - } - if template.TimeTilDormantAutoDelete > 0 { - workspace.DeletingAt = sql.NullTime{ - Valid: true, - Time: workspace.DormantAt.Time.Add(time.Duration(template.TimeTilDormantAutoDelete)), - } - } - } - q.workspaces[index] = workspace - return workspace, nil - } - return database.WorkspaceTable{}, sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateWorkspaceLastUsedAt(_ context.Context, arg database.UpdateWorkspaceLastUsedAtParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, workspace := range q.workspaces { - if workspace.ID != arg.ID { - continue - } - workspace.LastUsedAt = arg.LastUsedAt - q.workspaces[index] = workspace - return nil - } - - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateWorkspaceNextStartAt(_ context.Context, arg database.UpdateWorkspaceNextStartAtParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, workspace := range q.workspaces { - if workspace.ID != arg.ID { - continue - } - - workspace.NextStartAt = arg.NextStartAt - q.workspaces[index] = workspace - - return nil - } - - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateWorkspaceProxy(_ context.Context, arg database.UpdateWorkspaceProxyParams) (database.WorkspaceProxy, error) { - q.mutex.Lock() - defer q.mutex.Unlock() - - for _, p := range q.workspaceProxies { - if p.Name == arg.Name && p.ID != arg.ID { - return database.WorkspaceProxy{}, errUniqueConstraint - } - } - - for i, p := range q.workspaceProxies { - if p.ID == arg.ID { - p.Name = arg.Name - p.DisplayName = arg.DisplayName - p.Icon = arg.Icon - if len(p.TokenHashedSecret) > 0 { - p.TokenHashedSecret = arg.TokenHashedSecret - } - q.workspaceProxies[i] = p - return p, nil - } - } - return database.WorkspaceProxy{}, sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateWorkspaceProxyDeleted(_ context.Context, arg database.UpdateWorkspaceProxyDeletedParams) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, p := range q.workspaceProxies { - if p.ID == arg.ID { - p.Deleted = arg.Deleted - p.UpdatedAt = dbtime.Now() - q.workspaceProxies[i] = p - return nil - } - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateWorkspaceTTL(_ context.Context, arg database.UpdateWorkspaceTTLParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for index, workspace := range q.workspaces { - if workspace.ID != arg.ID { - continue - } - workspace.Ttl = arg.Ttl - q.workspaces[index] = workspace - return nil - } - - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateWorkspacesDormantDeletingAtByTemplateID(_ context.Context, arg database.UpdateWorkspacesDormantDeletingAtByTemplateIDParams) ([]database.WorkspaceTable, error) { - q.mutex.Lock() - defer q.mutex.Unlock() - - err := validateDatabaseType(arg) - if err != nil { - return nil, err - } - - affectedRows := []database.WorkspaceTable{} - for i, ws := range q.workspaces { - if ws.TemplateID != arg.TemplateID { - continue - } - - if ws.DormantAt.Time.IsZero() { - continue - } - - if !arg.DormantAt.IsZero() { - ws.DormantAt = sql.NullTime{ - Valid: true, - Time: arg.DormantAt, - } - } - - deletingAt := sql.NullTime{ - Valid: arg.TimeTilDormantAutodeleteMs > 0, - } - if arg.TimeTilDormantAutodeleteMs > 0 { - deletingAt.Time = ws.DormantAt.Time.Add(time.Duration(arg.TimeTilDormantAutodeleteMs) * time.Millisecond) - } - ws.DeletingAt = deletingAt - q.workspaces[i] = ws - affectedRows = append(affectedRows, ws) - } - - return affectedRows, nil -} - -func (q *FakeQuerier) UpdateWorkspacesTTLByTemplateID(_ context.Context, arg database.UpdateWorkspacesTTLByTemplateIDParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, ws := range q.workspaces { - if ws.TemplateID != arg.TemplateID { - continue - } - - q.workspaces[i].Ttl = arg.Ttl - } - - return nil -} - -func (q *FakeQuerier) UpsertAnnouncementBanners(_ context.Context, data string) error { - q.mutex.RLock() - defer q.mutex.RUnlock() - - q.announcementBanners = []byte(data) - return nil -} - -func (q *FakeQuerier) UpsertAppSecurityKey(_ context.Context, data string) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - q.appSecurityKey = data - return nil -} - -func (q *FakeQuerier) UpsertApplicationName(_ context.Context, data string) error { - q.mutex.RLock() - defer q.mutex.RUnlock() - - q.applicationName = data - return nil -} - -func (q *FakeQuerier) UpsertCoordinatorResumeTokenSigningKey(_ context.Context, value string) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - q.coordinatorResumeTokenSigningKey = value - return nil -} - -func (q *FakeQuerier) UpsertDefaultProxy(_ context.Context, arg database.UpsertDefaultProxyParams) error { - q.defaultProxyDisplayName = arg.DisplayName - q.defaultProxyIconURL = arg.IconUrl - return nil -} - -func (q *FakeQuerier) UpsertHealthSettings(_ context.Context, data string) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - q.healthSettings = []byte(data) - return nil -} - -func (q *FakeQuerier) UpsertLastUpdateCheck(_ context.Context, data string) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - q.lastUpdateCheck = []byte(data) - return nil -} - -func (q *FakeQuerier) UpsertLogoURL(_ context.Context, data string) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - q.logoURL = data - return nil -} - -func (q *FakeQuerier) UpsertNotificationReportGeneratorLog(_ context.Context, arg database.UpsertNotificationReportGeneratorLogParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, record := range q.notificationReportGeneratorLogs { - if arg.NotificationTemplateID == record.NotificationTemplateID { - q.notificationReportGeneratorLogs[i].LastGeneratedAt = arg.LastGeneratedAt - return nil - } - } - - q.notificationReportGeneratorLogs = append(q.notificationReportGeneratorLogs, database.NotificationReportGeneratorLog(arg)) - return nil -} - -func (q *FakeQuerier) UpsertNotificationsSettings(_ context.Context, data string) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - q.notificationsSettings = []byte(data) - return nil -} - -func (q *FakeQuerier) UpsertOAuth2GithubDefaultEligible(_ context.Context, eligible bool) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - q.oauth2GithubDefaultEligible = &eligible - return nil -} - -func (q *FakeQuerier) UpsertOAuthSigningKey(_ context.Context, value string) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - q.oauthSigningKey = value - return nil -} - -func (q *FakeQuerier) UpsertPrebuildsSettings(_ context.Context, value string) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - q.prebuildsSettings = []byte(value) - return nil -} - -func (q *FakeQuerier) UpsertProvisionerDaemon(_ context.Context, arg database.UpsertProvisionerDaemonParams) (database.ProvisionerDaemon, error) { - if err := validateDatabaseType(arg); err != nil { - return database.ProvisionerDaemon{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - // Look for existing daemon using the same composite key as SQL - for i, d := range q.provisionerDaemons { - if d.OrganizationID == arg.OrganizationID && - d.Name == arg.Name && - getOwnerFromTags(d.Tags) == getOwnerFromTags(arg.Tags) { - d.Provisioners = arg.Provisioners - d.Tags = maps.Clone(arg.Tags) - d.LastSeenAt = arg.LastSeenAt - d.Version = arg.Version - d.APIVersion = arg.APIVersion - d.OrganizationID = arg.OrganizationID - d.KeyID = arg.KeyID - q.provisionerDaemons[i] = d - return d, nil - } - } - d := database.ProvisionerDaemon{ - ID: uuid.New(), - CreatedAt: arg.CreatedAt, - Name: arg.Name, - Provisioners: arg.Provisioners, - Tags: maps.Clone(arg.Tags), - LastSeenAt: arg.LastSeenAt, - Version: arg.Version, - APIVersion: arg.APIVersion, - OrganizationID: arg.OrganizationID, - KeyID: arg.KeyID, - } - q.provisionerDaemons = append(q.provisionerDaemons, d) - return d, nil -} - -func (q *FakeQuerier) UpsertRuntimeConfig(_ context.Context, arg database.UpsertRuntimeConfigParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - q.runtimeConfig[arg.Key] = arg.Value - return nil -} - -func (*FakeQuerier) UpsertTailnetAgent(context.Context, database.UpsertTailnetAgentParams) (database.TailnetAgent, error) { - return database.TailnetAgent{}, ErrUnimplemented -} - -func (*FakeQuerier) UpsertTailnetClient(context.Context, database.UpsertTailnetClientParams) (database.TailnetClient, error) { - return database.TailnetClient{}, ErrUnimplemented -} - -func (*FakeQuerier) UpsertTailnetClientSubscription(context.Context, database.UpsertTailnetClientSubscriptionParams) error { - return ErrUnimplemented -} - -func (*FakeQuerier) UpsertTailnetCoordinator(context.Context, uuid.UUID) (database.TailnetCoordinator, error) { - return database.TailnetCoordinator{}, ErrUnimplemented -} - -func (*FakeQuerier) UpsertTailnetPeer(_ context.Context, arg database.UpsertTailnetPeerParams) (database.TailnetPeer, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.TailnetPeer{}, err - } - - return database.TailnetPeer{}, ErrUnimplemented -} - -func (*FakeQuerier) UpsertTailnetTunnel(_ context.Context, arg database.UpsertTailnetTunnelParams) (database.TailnetTunnel, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.TailnetTunnel{}, err - } - - return database.TailnetTunnel{}, ErrUnimplemented -} - -func (q *FakeQuerier) UpsertTelemetryItem(_ context.Context, arg database.UpsertTelemetryItemParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, item := range q.telemetryItems { - if item.Key == arg.Key { - q.telemetryItems[i].Value = arg.Value - q.telemetryItems[i].UpdatedAt = time.Now() - return nil - } - } - - q.telemetryItems = append(q.telemetryItems, database.TelemetryItem{ - Key: arg.Key, - Value: arg.Value, - CreatedAt: time.Now(), - UpdatedAt: time.Now(), - }) - - return nil -} - -func (q *FakeQuerier) UpsertTemplateUsageStats(ctx context.Context) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - /* - WITH - */ - - /* - latest_start AS ( - SELECT - -- Truncate to hour so that we always look at even ranges of data. - date_trunc('hour', COALESCE( - MAX(start_time) - '1 hour'::interval), - -- Fallback when there are no template usage stats yet. - -- App stats can exist before this, but not agent stats, - -- limit the lookback to avoid inconsistency. - (SELECT MIN(created_at) FROM workspace_agent_stats) - )) AS t - FROM - template_usage_stats - ), - */ - - now := time.Now() - latestStart := time.Time{} - for _, stat := range q.templateUsageStats { - if stat.StartTime.After(latestStart) { - latestStart = stat.StartTime.Add(-time.Hour) - } - } - if latestStart.IsZero() { - for _, stat := range q.workspaceAgentStats { - if latestStart.IsZero() || stat.CreatedAt.Before(latestStart) { - latestStart = stat.CreatedAt - } - } - } - if latestStart.IsZero() { - return nil - } - latestStart = latestStart.Truncate(time.Hour) - - /* - workspace_app_stat_buckets AS ( - SELECT - -- Truncate the minute to the nearest half hour, this is the bucket size - -- for the data. - date_trunc('hour', s.minute_bucket) + trunc(date_part('minute', s.minute_bucket) / 30) * 30 * '1 minute'::interval AS time_bucket, - w.template_id, - was.user_id, - -- Both app stats and agent stats track web terminal usage, but - -- by different means. The app stats value should be more - -- accurate so we don't want to discard it just yet. - CASE - WHEN was.access_method = 'terminal' - THEN '[terminal]' -- Unique name, app names can't contain brackets. - ELSE was.slug_or_port - END AS app_name, - COUNT(DISTINCT s.minute_bucket) AS app_minutes, - -- Store each unique minute bucket for later merge between datasets. - array_agg(DISTINCT s.minute_bucket) AS minute_buckets - FROM - workspace_app_stats AS was - JOIN - workspaces AS w - ON - w.id = was.workspace_id - -- Generate a series of minute buckets for each session for computing the - -- mintes/bucket. - CROSS JOIN - generate_series( - date_trunc('minute', was.session_started_at), - -- Subtract 1 microsecond to avoid creating an extra series. - date_trunc('minute', was.session_ended_at - '1 microsecond'::interval), - '1 minute'::interval - ) AS s(minute_bucket) - WHERE - -- s.minute_bucket >= @start_time::timestamptz - -- AND s.minute_bucket < @end_time::timestamptz - s.minute_bucket >= (SELECT t FROM latest_start) - AND s.minute_bucket < NOW() - GROUP BY - time_bucket, w.template_id, was.user_id, was.access_method, was.slug_or_port - ), - */ - - type workspaceAppStatGroupBy struct { - TimeBucket time.Time - TemplateID uuid.UUID - UserID uuid.UUID - AccessMethod string - SlugOrPort string - } - type workspaceAppStatRow struct { - workspaceAppStatGroupBy - AppName string - AppMinutes int - MinuteBuckets map[time.Time]struct{} - } - workspaceAppStatRows := make(map[workspaceAppStatGroupBy]workspaceAppStatRow) - for _, was := range q.workspaceAppStats { - // Preflight: s.minute_bucket >= (SELECT t FROM latest_start) - if was.SessionEndedAt.Before(latestStart) { - continue - } - // JOIN workspaces - w, err := q.getWorkspaceByIDNoLock(ctx, was.WorkspaceID) - if err != nil { - return err - } - // CROSS JOIN generate_series - for t := was.SessionStartedAt.Truncate(time.Minute); t.Before(was.SessionEndedAt); t = t.Add(time.Minute) { - // WHERE - if t.Before(latestStart) || t.After(now) || t.Equal(now) { - continue - } - - bucket := t.Truncate(30 * time.Minute) - // GROUP BY - key := workspaceAppStatGroupBy{ - TimeBucket: bucket, - TemplateID: w.TemplateID, - UserID: was.UserID, - AccessMethod: was.AccessMethod, - SlugOrPort: was.SlugOrPort, - } - // SELECT - row, ok := workspaceAppStatRows[key] - if !ok { - row = workspaceAppStatRow{ - workspaceAppStatGroupBy: key, - AppName: was.SlugOrPort, - AppMinutes: 0, - MinuteBuckets: make(map[time.Time]struct{}), - } - if was.AccessMethod == "terminal" { - row.AppName = "[terminal]" - } - } - row.MinuteBuckets[t] = struct{}{} - row.AppMinutes = len(row.MinuteBuckets) - workspaceAppStatRows[key] = row - } - } - - /* - agent_stats_buckets AS ( - SELECT - -- Truncate the minute to the nearest half hour, this is the bucket size - -- for the data. - date_trunc('hour', created_at) + trunc(date_part('minute', created_at) / 30) * 30 * '1 minute'::interval AS time_bucket, - template_id, - user_id, - -- Store each unique minute bucket for later merge between datasets. - array_agg( - DISTINCT CASE - WHEN - session_count_ssh > 0 - -- TODO(mafredri): Enable when we have the column. - -- OR session_count_sftp > 0 - OR session_count_reconnecting_pty > 0 - OR session_count_vscode > 0 - OR session_count_jetbrains > 0 - THEN - date_trunc('minute', created_at) - ELSE - NULL - END - ) AS minute_buckets, - COUNT(DISTINCT CASE WHEN session_count_ssh > 0 THEN date_trunc('minute', created_at) ELSE NULL END) AS ssh_mins, - -- TODO(mafredri): Enable when we have the column. - -- COUNT(DISTINCT CASE WHEN session_count_sftp > 0 THEN date_trunc('minute', created_at) ELSE NULL END) AS sftp_mins, - COUNT(DISTINCT CASE WHEN session_count_reconnecting_pty > 0 THEN date_trunc('minute', created_at) ELSE NULL END) AS reconnecting_pty_mins, - COUNT(DISTINCT CASE WHEN session_count_vscode > 0 THEN date_trunc('minute', created_at) ELSE NULL END) AS vscode_mins, - COUNT(DISTINCT CASE WHEN session_count_jetbrains > 0 THEN date_trunc('minute', created_at) ELSE NULL END) AS jetbrains_mins, - -- NOTE(mafredri): The agent stats are currently very unreliable, and - -- sometimes the connections are missing, even during active sessions. - -- Since we can't fully rely on this, we check for "any connection - -- during this half-hour". A better solution here would be preferable. - MAX(connection_count) > 0 AS has_connection - FROM - workspace_agent_stats - WHERE - -- created_at >= @start_time::timestamptz - -- AND created_at < @end_time::timestamptz - created_at >= (SELECT t FROM latest_start) - AND created_at < NOW() - -- Inclusion criteria to filter out empty results. - AND ( - session_count_ssh > 0 - -- TODO(mafredri): Enable when we have the column. - -- OR session_count_sftp > 0 - OR session_count_reconnecting_pty > 0 - OR session_count_vscode > 0 - OR session_count_jetbrains > 0 - ) - GROUP BY - time_bucket, template_id, user_id - ), - */ - - type agentStatGroupBy struct { - TimeBucket time.Time - TemplateID uuid.UUID - UserID uuid.UUID - } - type agentStatRow struct { - agentStatGroupBy - MinuteBuckets map[time.Time]struct{} - SSHMinuteBuckets map[time.Time]struct{} - SSHMins int - SFTPMinuteBuckets map[time.Time]struct{} - SFTPMins int - ReconnectingPTYMinuteBuckets map[time.Time]struct{} - ReconnectingPTYMins int - VSCodeMinuteBuckets map[time.Time]struct{} - VSCodeMins int - JetBrainsMinuteBuckets map[time.Time]struct{} - JetBrainsMins int - HasConnection bool - } - agentStatRows := make(map[agentStatGroupBy]agentStatRow) - for _, was := range q.workspaceAgentStats { - // WHERE - if was.CreatedAt.Before(latestStart) || was.CreatedAt.After(now) || was.CreatedAt.Equal(now) { - continue - } - if was.SessionCountSSH == 0 && was.SessionCountReconnectingPTY == 0 && was.SessionCountVSCode == 0 && was.SessionCountJetBrains == 0 { - continue - } - // GROUP BY - key := agentStatGroupBy{ - TimeBucket: was.CreatedAt.Truncate(30 * time.Minute), - TemplateID: was.TemplateID, - UserID: was.UserID, - } - // SELECT - row, ok := agentStatRows[key] - if !ok { - row = agentStatRow{ - agentStatGroupBy: key, - MinuteBuckets: make(map[time.Time]struct{}), - SSHMinuteBuckets: make(map[time.Time]struct{}), - SFTPMinuteBuckets: make(map[time.Time]struct{}), - ReconnectingPTYMinuteBuckets: make(map[time.Time]struct{}), - VSCodeMinuteBuckets: make(map[time.Time]struct{}), - JetBrainsMinuteBuckets: make(map[time.Time]struct{}), - } - } - minute := was.CreatedAt.Truncate(time.Minute) - row.MinuteBuckets[minute] = struct{}{} - if was.SessionCountSSH > 0 { - row.SSHMinuteBuckets[minute] = struct{}{} - row.SSHMins = len(row.SSHMinuteBuckets) - } - // TODO(mafredri): Enable when we have the column. - // if was.SessionCountSFTP > 0 { - // row.SFTPMinuteBuckets[minute] = struct{}{} - // row.SFTPMins = len(row.SFTPMinuteBuckets) - // } - _ = row.SFTPMinuteBuckets - if was.SessionCountReconnectingPTY > 0 { - row.ReconnectingPTYMinuteBuckets[minute] = struct{}{} - row.ReconnectingPTYMins = len(row.ReconnectingPTYMinuteBuckets) - } - if was.SessionCountVSCode > 0 { - row.VSCodeMinuteBuckets[minute] = struct{}{} - row.VSCodeMins = len(row.VSCodeMinuteBuckets) - } - if was.SessionCountJetBrains > 0 { - row.JetBrainsMinuteBuckets[minute] = struct{}{} - row.JetBrainsMins = len(row.JetBrainsMinuteBuckets) - } - if !row.HasConnection { - row.HasConnection = was.ConnectionCount > 0 - } - agentStatRows[key] = row - } - - /* - stats AS ( - SELECT - stats.time_bucket AS start_time, - stats.time_bucket + '30 minutes'::interval AS end_time, - stats.template_id, - stats.user_id, - -- Sum/distinct to handle zero/duplicate values due union and to unnest. - COUNT(DISTINCT minute_bucket) AS usage_mins, - array_agg(DISTINCT minute_bucket) AS minute_buckets, - SUM(DISTINCT stats.ssh_mins) AS ssh_mins, - SUM(DISTINCT stats.sftp_mins) AS sftp_mins, - SUM(DISTINCT stats.reconnecting_pty_mins) AS reconnecting_pty_mins, - SUM(DISTINCT stats.vscode_mins) AS vscode_mins, - SUM(DISTINCT stats.jetbrains_mins) AS jetbrains_mins, - -- This is what we unnested, re-nest as json. - jsonb_object_agg(stats.app_name, stats.app_minutes) FILTER (WHERE stats.app_name IS NOT NULL) AS app_usage_mins - FROM ( - SELECT - time_bucket, - template_id, - user_id, - 0 AS ssh_mins, - 0 AS sftp_mins, - 0 AS reconnecting_pty_mins, - 0 AS vscode_mins, - 0 AS jetbrains_mins, - app_name, - app_minutes, - minute_buckets - FROM - workspace_app_stat_buckets - - UNION ALL - - SELECT - time_bucket, - template_id, - user_id, - ssh_mins, - -- TODO(mafredri): Enable when we have the column. - 0 AS sftp_mins, - reconnecting_pty_mins, - vscode_mins, - jetbrains_mins, - NULL AS app_name, - NULL AS app_minutes, - minute_buckets - FROM - agent_stats_buckets - WHERE - -- See note in the agent_stats_buckets CTE. - has_connection - ) AS stats, unnest(minute_buckets) AS minute_bucket - GROUP BY - stats.time_bucket, stats.template_id, stats.user_id - ), - */ - - type statsGroupBy struct { - TimeBucket time.Time - TemplateID uuid.UUID - UserID uuid.UUID - } - type statsRow struct { - statsGroupBy - UsageMinuteBuckets map[time.Time]struct{} - UsageMins int - SSHMins int - SFTPMins int - ReconnectingPTYMins int - VSCodeMins int - JetBrainsMins int - AppUsageMinutes map[string]int - } - statsRows := make(map[statsGroupBy]statsRow) - for _, was := range workspaceAppStatRows { - // GROUP BY - key := statsGroupBy{ - TimeBucket: was.TimeBucket, - TemplateID: was.TemplateID, - UserID: was.UserID, - } - // SELECT - row, ok := statsRows[key] - if !ok { - row = statsRow{ - statsGroupBy: key, - UsageMinuteBuckets: make(map[time.Time]struct{}), - AppUsageMinutes: make(map[string]int), - } - } - for t := range was.MinuteBuckets { - row.UsageMinuteBuckets[t] = struct{}{} - } - row.UsageMins = len(row.UsageMinuteBuckets) - row.AppUsageMinutes[was.AppName] = was.AppMinutes - statsRows[key] = row - } - for _, was := range agentStatRows { - // GROUP BY - key := statsGroupBy{ - TimeBucket: was.TimeBucket, - TemplateID: was.TemplateID, - UserID: was.UserID, - } - // SELECT - row, ok := statsRows[key] - if !ok { - row = statsRow{ - statsGroupBy: key, - UsageMinuteBuckets: make(map[time.Time]struct{}), - AppUsageMinutes: make(map[string]int), - } - } - for t := range was.MinuteBuckets { - row.UsageMinuteBuckets[t] = struct{}{} - } - row.UsageMins = len(row.UsageMinuteBuckets) - row.SSHMins += was.SSHMins - row.SFTPMins += was.SFTPMins - row.ReconnectingPTYMins += was.ReconnectingPTYMins - row.VSCodeMins += was.VSCodeMins - row.JetBrainsMins += was.JetBrainsMins - statsRows[key] = row - } - - /* - minute_buckets AS ( - -- Create distinct minute buckets for user-activity, so we can filter out - -- irrelevant latencies. - SELECT DISTINCT ON (stats.start_time, stats.template_id, stats.user_id, minute_bucket) - stats.start_time, - stats.template_id, - stats.user_id, - minute_bucket - FROM - stats, unnest(minute_buckets) AS minute_bucket - ), - latencies AS ( - -- Select all non-zero latencies for all the minutes that a user used the - -- workspace in some way. - SELECT - mb.start_time, - mb.template_id, - mb.user_id, - -- TODO(mafredri): We're doing medians on medians here, we may want to - -- improve upon this at some point. - PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY was.connection_median_latency_ms)::real AS median_latency_ms - FROM - minute_buckets AS mb - JOIN - workspace_agent_stats AS was - ON - date_trunc('minute', was.created_at) = mb.minute_bucket - AND was.template_id = mb.template_id - AND was.user_id = mb.user_id - AND was.connection_median_latency_ms >= 0 - GROUP BY - mb.start_time, mb.template_id, mb.user_id - ) - */ - - type latenciesGroupBy struct { - StartTime time.Time - TemplateID uuid.UUID - UserID uuid.UUID - } - type latenciesRow struct { - latenciesGroupBy - Latencies []float64 - MedianLatencyMS float64 - } - latenciesRows := make(map[latenciesGroupBy]latenciesRow) - for _, stat := range statsRows { - for t := range stat.UsageMinuteBuckets { - // GROUP BY - key := latenciesGroupBy{ - StartTime: stat.TimeBucket, - TemplateID: stat.TemplateID, - UserID: stat.UserID, - } - // JOIN - for _, was := range q.workspaceAgentStats { - if !t.Equal(was.CreatedAt.Truncate(time.Minute)) { - continue - } - if was.TemplateID != stat.TemplateID || was.UserID != stat.UserID { - continue - } - if was.ConnectionMedianLatencyMS < 0 { - continue - } - // SELECT - row, ok := latenciesRows[key] - if !ok { - row = latenciesRow{ - latenciesGroupBy: key, - } - } - row.Latencies = append(row.Latencies, was.ConnectionMedianLatencyMS) - sort.Float64s(row.Latencies) - if len(row.Latencies) == 1 { - row.MedianLatencyMS = was.ConnectionMedianLatencyMS - } else if len(row.Latencies)%2 == 0 { - row.MedianLatencyMS = (row.Latencies[len(row.Latencies)/2-1] + row.Latencies[len(row.Latencies)/2]) / 2 - } else { - row.MedianLatencyMS = row.Latencies[len(row.Latencies)/2] - } - latenciesRows[key] = row - } - } - } - - /* - INSERT INTO template_usage_stats AS tus ( - start_time, - end_time, - template_id, - user_id, - usage_mins, - median_latency_ms, - ssh_mins, - sftp_mins, - reconnecting_pty_mins, - vscode_mins, - jetbrains_mins, - app_usage_mins - ) ( - SELECT - stats.start_time, - stats.end_time, - stats.template_id, - stats.user_id, - stats.usage_mins, - latencies.median_latency_ms, - stats.ssh_mins, - stats.sftp_mins, - stats.reconnecting_pty_mins, - stats.vscode_mins, - stats.jetbrains_mins, - stats.app_usage_mins - FROM - stats - LEFT JOIN - latencies - ON - -- The latencies group-by ensures there at most one row. - latencies.start_time = stats.start_time - AND latencies.template_id = stats.template_id - AND latencies.user_id = stats.user_id - ) - ON CONFLICT - (start_time, template_id, user_id) - DO UPDATE - SET - usage_mins = EXCLUDED.usage_mins, - median_latency_ms = EXCLUDED.median_latency_ms, - ssh_mins = EXCLUDED.ssh_mins, - sftp_mins = EXCLUDED.sftp_mins, - reconnecting_pty_mins = EXCLUDED.reconnecting_pty_mins, - vscode_mins = EXCLUDED.vscode_mins, - jetbrains_mins = EXCLUDED.jetbrains_mins, - app_usage_mins = EXCLUDED.app_usage_mins - WHERE - (tus.*) IS DISTINCT FROM (EXCLUDED.*); - */ - -TemplateUsageStatsInsertLoop: - for _, stat := range statsRows { - // LEFT JOIN latencies - latency, latencyOk := latenciesRows[latenciesGroupBy{ - StartTime: stat.TimeBucket, - TemplateID: stat.TemplateID, - UserID: stat.UserID, - }] - - // SELECT - tus := database.TemplateUsageStat{ - StartTime: stat.TimeBucket, - EndTime: stat.TimeBucket.Add(30 * time.Minute), - TemplateID: stat.TemplateID, - UserID: stat.UserID, - // #nosec G115 - Safe conversion for usage minutes which are expected to be within int16 range - UsageMins: int16(stat.UsageMins), - MedianLatencyMs: sql.NullFloat64{Float64: latency.MedianLatencyMS, Valid: latencyOk}, - // #nosec G115 - Safe conversion for SSH minutes which are expected to be within int16 range - SshMins: int16(stat.SSHMins), - // #nosec G115 - Safe conversion for SFTP minutes which are expected to be within int16 range - SftpMins: int16(stat.SFTPMins), - // #nosec G115 - Safe conversion for ReconnectingPTY minutes which are expected to be within int16 range - ReconnectingPtyMins: int16(stat.ReconnectingPTYMins), - // #nosec G115 - Safe conversion for VSCode minutes which are expected to be within int16 range - VscodeMins: int16(stat.VSCodeMins), - // #nosec G115 - Safe conversion for JetBrains minutes which are expected to be within int16 range - JetbrainsMins: int16(stat.JetBrainsMins), - } - if len(stat.AppUsageMinutes) > 0 { - tus.AppUsageMins = make(map[string]int64, len(stat.AppUsageMinutes)) - for k, v := range stat.AppUsageMinutes { - tus.AppUsageMins[k] = int64(v) - } - } - - // ON CONFLICT - for i, existing := range q.templateUsageStats { - if existing.StartTime.Equal(tus.StartTime) && existing.TemplateID == tus.TemplateID && existing.UserID == tus.UserID { - q.templateUsageStats[i] = tus - continue TemplateUsageStatsInsertLoop - } - } - // INSERT INTO - q.templateUsageStats = append(q.templateUsageStats, tus) - } - - return nil -} - -func (q *FakeQuerier) UpsertWebpushVAPIDKeys(_ context.Context, arg database.UpsertWebpushVAPIDKeysParams) error { - err := validateDatabaseType(arg) - if err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - q.webpushVAPIDPublicKey = arg.VapidPublicKey - q.webpushVAPIDPrivateKey = arg.VapidPrivateKey - return nil -} - -func (q *FakeQuerier) UpsertWorkspaceAgentPortShare(_ context.Context, arg database.UpsertWorkspaceAgentPortShareParams) (database.WorkspaceAgentPortShare, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.WorkspaceAgentPortShare{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, share := range q.workspaceAgentPortShares { - if share.WorkspaceID == arg.WorkspaceID && share.Port == arg.Port && share.AgentName == arg.AgentName { - share.ShareLevel = arg.ShareLevel - share.Protocol = arg.Protocol - q.workspaceAgentPortShares[i] = share - return share, nil - } - } - - //nolint:gosimple // casts are not a simplification - psl := database.WorkspaceAgentPortShare{ - WorkspaceID: arg.WorkspaceID, - AgentName: arg.AgentName, - Port: arg.Port, - ShareLevel: arg.ShareLevel, - Protocol: arg.Protocol, - } - q.workspaceAgentPortShares = append(q.workspaceAgentPortShares, psl) - - return psl, nil -} - -func (q *FakeQuerier) UpsertWorkspaceApp(ctx context.Context, arg database.UpsertWorkspaceAppParams) (database.WorkspaceApp, error) { - err := validateDatabaseType(arg) - if err != nil { - return database.WorkspaceApp{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - if arg.SharingLevel == "" { - arg.SharingLevel = database.AppSharingLevelOwner - } - if arg.OpenIn == "" { - arg.OpenIn = database.WorkspaceAppOpenInSlimWindow - } - - buildApp := func(id uuid.UUID, createdAt time.Time) database.WorkspaceApp { - return database.WorkspaceApp{ - ID: id, - CreatedAt: createdAt, - AgentID: arg.AgentID, - Slug: arg.Slug, - DisplayName: arg.DisplayName, - Icon: arg.Icon, - Command: arg.Command, - Url: arg.Url, - External: arg.External, - Subdomain: arg.Subdomain, - SharingLevel: arg.SharingLevel, - HealthcheckUrl: arg.HealthcheckUrl, - HealthcheckInterval: arg.HealthcheckInterval, - HealthcheckThreshold: arg.HealthcheckThreshold, - Health: arg.Health, - Hidden: arg.Hidden, - DisplayOrder: arg.DisplayOrder, - OpenIn: arg.OpenIn, - DisplayGroup: arg.DisplayGroup, - } - } - - for i, app := range q.workspaceApps { - if app.ID == arg.ID { - q.workspaceApps[i] = buildApp(app.ID, app.CreatedAt) - return q.workspaceApps[i], nil - } - } - - workspaceApp := buildApp(arg.ID, arg.CreatedAt) - q.workspaceApps = append(q.workspaceApps, workspaceApp) - return workspaceApp, nil -} - -func (q *FakeQuerier) UpsertWorkspaceAppAuditSession(_ context.Context, arg database.UpsertWorkspaceAppAuditSessionParams) (bool, error) { - err := validateDatabaseType(arg) - if err != nil { - return false, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, s := range q.workspaceAppAuditSessions { - if s.AgentID != arg.AgentID { - continue - } - if s.AppID != arg.AppID { - continue - } - if s.UserID != arg.UserID { - continue - } - if s.Ip != arg.Ip { - continue - } - if s.UserAgent != arg.UserAgent { - continue - } - if s.SlugOrPort != arg.SlugOrPort { - continue - } - if s.StatusCode != arg.StatusCode { - continue - } - - staleTime := dbtime.Now().Add(-(time.Duration(arg.StaleIntervalMS) * time.Millisecond)) - fresh := s.UpdatedAt.After(staleTime) - - q.workspaceAppAuditSessions[i].UpdatedAt = arg.UpdatedAt - if !fresh { - q.workspaceAppAuditSessions[i].ID = arg.ID - q.workspaceAppAuditSessions[i].StartedAt = arg.StartedAt - return true, nil - } - return false, nil - } - - q.workspaceAppAuditSessions = append(q.workspaceAppAuditSessions, database.WorkspaceAppAuditSession{ - AgentID: arg.AgentID, - AppID: arg.AppID, - UserID: arg.UserID, - Ip: arg.Ip, - UserAgent: arg.UserAgent, - SlugOrPort: arg.SlugOrPort, - StatusCode: arg.StatusCode, - StartedAt: arg.StartedAt, - UpdatedAt: arg.UpdatedAt, - }) - return true, nil -} - -func (q *FakeQuerier) GetAuthorizedTemplates(ctx context.Context, arg database.GetTemplatesWithFilterParams, prepared rbac.PreparedAuthorized) ([]database.Template, error) { - if err := validateDatabaseType(arg); err != nil { - return nil, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - // Call this to match the same function calls as the SQL implementation. - if prepared != nil { - _, err := prepared.CompileToSQL(ctx, rbac.ConfigWithACL()) - if err != nil { - return nil, err - } - } - - var templates []database.Template - for _, templateTable := range q.templates { - template := q.templateWithNameNoLock(templateTable) - if prepared != nil && prepared.Authorize(ctx, template.RBACObject()) != nil { - continue - } - - if template.Deleted != arg.Deleted { - continue - } - if arg.OrganizationID != uuid.Nil && template.OrganizationID != arg.OrganizationID { - continue - } - - if arg.ExactName != "" && !strings.EqualFold(template.Name, arg.ExactName) { - continue - } - // Filters templates based on the search query filter 'Deprecated' status - // Matching SQL logic: - // -- Filter by deprecated - // AND CASE - // WHEN :deprecated IS NOT NULL THEN - // CASE - // WHEN :deprecated THEN deprecated != '' - // ELSE deprecated = '' - // END - // ELSE true - if arg.Deprecated.Valid && arg.Deprecated.Bool != isDeprecated(template) { - continue - } - if arg.FuzzyName != "" { - if !strings.Contains(strings.ToLower(template.Name), strings.ToLower(arg.FuzzyName)) { - continue - } - } - - if len(arg.IDs) > 0 { - match := false - for _, id := range arg.IDs { - if template.ID == id { - match = true - break - } - } - if !match { - continue - } - } - - if arg.HasAITask.Valid { - tv, err := q.getTemplateVersionByIDNoLock(ctx, template.ActiveVersionID) - if err != nil { - return nil, xerrors.Errorf("get template version: %w", err) - } - tvHasAITask := tv.HasAITask.Valid && tv.HasAITask.Bool - if tvHasAITask != arg.HasAITask.Bool { - continue - } - } - - templates = append(templates, template) - } - if len(templates) > 0 { - slices.SortFunc(templates, func(a, b database.Template) int { - if a.Name != b.Name { - return slice.Ascending(a.Name, b.Name) - } - return slice.Ascending(a.ID.String(), b.ID.String()) - }) - return templates, nil - } - - return nil, sql.ErrNoRows -} - -func (q *FakeQuerier) GetTemplateGroupRoles(_ context.Context, id uuid.UUID) ([]database.TemplateGroup, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - var template database.TemplateTable - for _, t := range q.templates { - if t.ID == id { - template = t - break - } - } - - if template.ID == uuid.Nil { - return nil, sql.ErrNoRows - } - - groups := make([]database.TemplateGroup, 0, len(template.GroupACL)) - for k, v := range template.GroupACL { - group, err := q.getGroupByIDNoLock(context.Background(), uuid.MustParse(k)) - if err != nil && !xerrors.Is(err, sql.ErrNoRows) { - return nil, xerrors.Errorf("get group by ID: %w", err) - } - // We don't delete groups from the map if they - // get deleted so just skip. - if xerrors.Is(err, sql.ErrNoRows) { - continue - } - - groups = append(groups, database.TemplateGroup{ - Group: group, - Actions: v, - }) - } - - return groups, nil -} - -func (q *FakeQuerier) GetTemplateUserRoles(_ context.Context, id uuid.UUID) ([]database.TemplateUser, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - var template database.TemplateTable - for _, t := range q.templates { - if t.ID == id { - template = t - break - } - } - - if template.ID == uuid.Nil { - return nil, sql.ErrNoRows - } - - users := make([]database.TemplateUser, 0, len(template.UserACL)) - for k, v := range template.UserACL { - user, err := q.getUserByIDNoLock(uuid.MustParse(k)) - if err != nil && xerrors.Is(err, sql.ErrNoRows) { - return nil, xerrors.Errorf("get user by ID: %w", err) - } - // We don't delete users from the map if they - // get deleted so just skip. - if xerrors.Is(err, sql.ErrNoRows) { - continue - } - - if user.Deleted || user.Status == database.UserStatusSuspended { - continue - } - - users = append(users, database.TemplateUser{ - User: user, - Actions: v, - }) - } - - return users, nil -} - -func (q *FakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, prepared rbac.PreparedAuthorized) ([]database.GetWorkspacesRow, error) { - if err := validateDatabaseType(arg); err != nil { - return nil, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - if prepared != nil { - // Call this to match the same function calls as the SQL implementation. - _, err := prepared.CompileToSQL(ctx, rbac.ConfigWithoutACL()) - if err != nil { - return nil, err - } - } - - workspaces := make([]database.WorkspaceTable, 0) - for _, workspace := range q.workspaces { - if arg.OwnerID != uuid.Nil && workspace.OwnerID != arg.OwnerID { - continue - } - - if len(arg.HasParam) > 0 || len(arg.ParamNames) > 0 { - build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, workspace.ID) - if err != nil { - return nil, xerrors.Errorf("get latest build: %w", err) - } - - params := make([]database.WorkspaceBuildParameter, 0) - for _, param := range q.workspaceBuildParameters { - if param.WorkspaceBuildID != build.ID { - continue - } - params = append(params, param) - } - - index := slices.IndexFunc(params, func(buildParam database.WorkspaceBuildParameter) bool { - // If hasParam matches, then we are done. This is a good match. - if slices.ContainsFunc(arg.HasParam, func(name string) bool { - return strings.EqualFold(buildParam.Name, name) - }) { - return true - } - - // Check name + value - match := false - for i := range arg.ParamNames { - matchName := arg.ParamNames[i] - if !strings.EqualFold(matchName, buildParam.Name) { - continue - } - - matchValue := arg.ParamValues[i] - if !strings.EqualFold(matchValue, buildParam.Value) { - continue - } - match = true - break - } - - return match - }) - if index < 0 { - continue - } - } - - if arg.OrganizationID != uuid.Nil { - if workspace.OrganizationID != arg.OrganizationID { - continue - } - } - - if arg.OwnerUsername != "" { - owner, err := q.getUserByIDNoLock(workspace.OwnerID) - if err == nil && !strings.EqualFold(arg.OwnerUsername, owner.Username) { - continue - } - } - - if arg.TemplateName != "" { - template, err := q.getTemplateByIDNoLock(ctx, workspace.TemplateID) - if err == nil && !strings.EqualFold(arg.TemplateName, template.Name) { - continue - } - } - - if arg.UsingActive.Valid { - build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, workspace.ID) - if err != nil { - return nil, xerrors.Errorf("get latest build: %w", err) - } - - template, err := q.getTemplateByIDNoLock(ctx, workspace.TemplateID) - if err != nil { - return nil, xerrors.Errorf("get template: %w", err) - } - - updated := build.TemplateVersionID == template.ActiveVersionID - if arg.UsingActive.Bool != updated { - continue - } - } - - if !arg.Deleted && workspace.Deleted { - continue - } - - if arg.Name != "" && !strings.Contains(strings.ToLower(workspace.Name), strings.ToLower(arg.Name)) { - continue - } - - if !arg.LastUsedBefore.IsZero() { - if workspace.LastUsedAt.After(arg.LastUsedBefore) { - continue - } - } - - if !arg.LastUsedAfter.IsZero() { - if workspace.LastUsedAt.Before(arg.LastUsedAfter) { - continue - } - } - - if arg.Status != "" { - build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, workspace.ID) - if err != nil { - return nil, xerrors.Errorf("get latest build: %w", err) - } - - job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID) - if err != nil { - return nil, xerrors.Errorf("get provisioner job: %w", err) - } - - // This logic should match the logic in the workspace.sql file. - var statusMatch bool - switch database.WorkspaceStatus(arg.Status) { - case database.WorkspaceStatusStarting: - statusMatch = job.JobStatus == database.ProvisionerJobStatusRunning && - build.Transition == database.WorkspaceTransitionStart - case database.WorkspaceStatusStopping: - statusMatch = job.JobStatus == database.ProvisionerJobStatusRunning && - build.Transition == database.WorkspaceTransitionStop - case database.WorkspaceStatusDeleting: - statusMatch = job.JobStatus == database.ProvisionerJobStatusRunning && - build.Transition == database.WorkspaceTransitionDelete - - case "started": - statusMatch = job.JobStatus == database.ProvisionerJobStatusSucceeded && - build.Transition == database.WorkspaceTransitionStart - case database.WorkspaceStatusDeleted: - statusMatch = job.JobStatus == database.ProvisionerJobStatusSucceeded && - build.Transition == database.WorkspaceTransitionDelete - case database.WorkspaceStatusStopped: - statusMatch = job.JobStatus == database.ProvisionerJobStatusSucceeded && - build.Transition == database.WorkspaceTransitionStop - case database.WorkspaceStatusRunning: - statusMatch = job.JobStatus == database.ProvisionerJobStatusSucceeded && - build.Transition == database.WorkspaceTransitionStart - default: - statusMatch = job.JobStatus == database.ProvisionerJobStatus(arg.Status) - } - if !statusMatch { - continue - } - } - - if arg.HasAgent != "" { - build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, workspace.ID) - if err != nil { - return nil, xerrors.Errorf("get latest build: %w", err) - } - - job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID) - if err != nil { - return nil, xerrors.Errorf("get provisioner job: %w", err) - } - - workspaceResources, err := q.getWorkspaceResourcesByJobIDNoLock(ctx, job.ID) - if err != nil { - return nil, xerrors.Errorf("get workspace resources: %w", err) - } - - var workspaceResourceIDs []uuid.UUID - for _, wr := range workspaceResources { - workspaceResourceIDs = append(workspaceResourceIDs, wr.ID) - } - - workspaceAgents, err := q.getWorkspaceAgentsByResourceIDsNoLock(ctx, workspaceResourceIDs) - if err != nil { - return nil, xerrors.Errorf("get workspace agents: %w", err) - } - - var hasAgentMatched bool - for _, wa := range workspaceAgents { - if mapAgentStatus(wa, arg.AgentInactiveDisconnectTimeoutSeconds) == arg.HasAgent { - hasAgentMatched = true - } - } - - if !hasAgentMatched { - continue - } - } - - if arg.Dormant && !workspace.DormantAt.Valid { - continue - } - - if len(arg.TemplateIDs) > 0 { - match := false - for _, id := range arg.TemplateIDs { - if workspace.TemplateID == id { - match = true - break - } - } - if !match { - continue - } - } - - if len(arg.WorkspaceIds) > 0 { - match := false - for _, id := range arg.WorkspaceIds { - if workspace.ID == id { - match = true - break - } - } - if !match { - continue - } - } - - if arg.HasAITask.Valid { - hasAITask, err := func() (bool, error) { - build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, workspace.ID) - if err != nil { - return false, xerrors.Errorf("get latest build: %w", err) - } - if build.HasAITask.Valid { - return build.HasAITask.Bool, nil - } - // If the build has a nil AI task, check if the job is in progress - // and if it has a non-empty AI Prompt parameter - job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID) - if err != nil { - return false, xerrors.Errorf("get provisioner job: %w", err) - } - if job.CompletedAt.Valid { - return false, nil - } - parameters, err := q.getWorkspaceBuildParametersNoLock(build.ID) - if err != nil { - return false, xerrors.Errorf("get workspace build parameters: %w", err) - } - for _, param := range parameters { - if param.Name == "AI Prompt" && param.Value != "" { - return true, nil - } - } - return false, nil - }() - if err != nil { - return nil, xerrors.Errorf("get hasAITask: %w", err) - } - if hasAITask != arg.HasAITask.Bool { - continue - } - } - - // If the filter exists, ensure the object is authorized. - if prepared != nil && prepared.Authorize(ctx, workspace.RBACObject()) != nil { - continue - } - workspaces = append(workspaces, workspace) - } - - // Sort workspaces (ORDER BY) - isRunning := func(build database.WorkspaceBuild, job database.ProvisionerJob) bool { - return job.CompletedAt.Valid && !job.CanceledAt.Valid && !job.Error.Valid && build.Transition == database.WorkspaceTransitionStart - } - - preloadedWorkspaceBuilds := map[uuid.UUID]database.WorkspaceBuild{} - preloadedProvisionerJobs := map[uuid.UUID]database.ProvisionerJob{} - preloadedUsers := map[uuid.UUID]database.User{} - - for _, w := range workspaces { - build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, w.ID) - if err == nil { - preloadedWorkspaceBuilds[w.ID] = build - } else if !errors.Is(err, sql.ErrNoRows) { - return nil, xerrors.Errorf("get latest build: %w", err) - } - - job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID) - if err == nil { - preloadedProvisionerJobs[w.ID] = job - } else if !errors.Is(err, sql.ErrNoRows) { - return nil, xerrors.Errorf("get provisioner job: %w", err) - } - - user, err := q.getUserByIDNoLock(w.OwnerID) - if err == nil { - preloadedUsers[w.ID] = user - } else if !errors.Is(err, sql.ErrNoRows) { - return nil, xerrors.Errorf("get user: %w", err) - } - } - - sort.Slice(workspaces, func(i, j int) bool { - w1 := workspaces[i] - w2 := workspaces[j] - - // Order by: favorite first - if arg.RequesterID == w1.OwnerID && w1.Favorite { - return true - } - if arg.RequesterID == w2.OwnerID && w2.Favorite { - return false - } - - // Order by: running - w1IsRunning := isRunning(preloadedWorkspaceBuilds[w1.ID], preloadedProvisionerJobs[w1.ID]) - w2IsRunning := isRunning(preloadedWorkspaceBuilds[w2.ID], preloadedProvisionerJobs[w2.ID]) - - if w1IsRunning && !w2IsRunning { - return true - } - - if !w1IsRunning && w2IsRunning { - return false - } - - // Order by: usernames - if strings.Compare(preloadedUsers[w1.ID].Username, preloadedUsers[w2.ID].Username) < 0 { - return true - } - - // Order by: workspace names - return strings.Compare(w1.Name, w2.Name) < 0 - }) - - beforePageCount := len(workspaces) - - if arg.Offset > 0 { - if int(arg.Offset) > len(workspaces) { - return q.convertToWorkspaceRowsNoLock(ctx, []database.WorkspaceTable{}, int64(beforePageCount), arg.WithSummary), nil - } - workspaces = workspaces[arg.Offset:] - } - if arg.Limit > 0 { - if int(arg.Limit) > len(workspaces) { - return q.convertToWorkspaceRowsNoLock(ctx, workspaces, int64(beforePageCount), arg.WithSummary), nil - } - workspaces = workspaces[:arg.Limit] - } - - return q.convertToWorkspaceRowsNoLock(ctx, workspaces, int64(beforePageCount), arg.WithSummary), nil -} - -func (q *FakeQuerier) GetAuthorizedWorkspacesAndAgentsByOwnerID(ctx context.Context, ownerID uuid.UUID, prepared rbac.PreparedAuthorized) ([]database.GetWorkspacesAndAgentsByOwnerIDRow, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - if prepared != nil { - // Call this to match the same function calls as the SQL implementation. - _, err := prepared.CompileToSQL(ctx, rbac.ConfigWithoutACL()) - if err != nil { - return nil, err - } - } - workspaces := make([]database.WorkspaceTable, 0) - for _, workspace := range q.workspaces { - if workspace.OwnerID == ownerID && !workspace.Deleted { - workspaces = append(workspaces, workspace) - } - } - - out := make([]database.GetWorkspacesAndAgentsByOwnerIDRow, 0, len(workspaces)) - for _, w := range workspaces { - // these always exist - build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, w.ID) - if err != nil { - return nil, xerrors.Errorf("get latest build: %w", err) - } - - job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID) - if err != nil { - return nil, xerrors.Errorf("get provisioner job: %w", err) - } - - outAgents := make([]database.AgentIDNamePair, 0) - resources, err := q.getWorkspaceResourcesByJobIDNoLock(ctx, job.ID) - if err != nil { - return nil, xerrors.Errorf("get workspace resources: %w", err) - } - if len(resources) > 0 { - agents, err := q.getWorkspaceAgentsByResourceIDsNoLock(ctx, []uuid.UUID{resources[0].ID}) - if err != nil { - return nil, xerrors.Errorf("get workspace agents: %w", err) - } - for _, a := range agents { - outAgents = append(outAgents, database.AgentIDNamePair{ - ID: a.ID, - Name: a.Name, - }) - } - } - - out = append(out, database.GetWorkspacesAndAgentsByOwnerIDRow{ - ID: w.ID, - Name: w.Name, - JobStatus: job.JobStatus, - Transition: build.Transition, - Agents: outAgents, - }) - } - - return out, nil -} - -func (q *FakeQuerier) GetAuthorizedWorkspaceBuildParametersByBuildIDs(ctx context.Context, workspaceBuildIDs []uuid.UUID, prepared rbac.PreparedAuthorized) ([]database.WorkspaceBuildParameter, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - if prepared != nil { - // Call this to match the same function calls as the SQL implementation. - _, err := prepared.CompileToSQL(ctx, rbac.ConfigWithoutACL()) - if err != nil { - return nil, err - } - } - - filteredParameters := make([]database.WorkspaceBuildParameter, 0) - for _, buildID := range workspaceBuildIDs { - parameters, err := q.GetWorkspaceBuildParameters(ctx, buildID) - if err != nil { - return nil, err - } - filteredParameters = append(filteredParameters, parameters...) - } - - return filteredParameters, nil -} - -func (q *FakeQuerier) GetAuthorizedUsers(ctx context.Context, arg database.GetUsersParams, prepared rbac.PreparedAuthorized) ([]database.GetUsersRow, error) { - if err := validateDatabaseType(arg); err != nil { - return nil, err - } - - // Call this to match the same function calls as the SQL implementation. - if prepared != nil { - _, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{ - VariableConverter: regosql.UserConverter(), - }) - if err != nil { - return nil, err - } - } - - users, err := q.GetUsers(ctx, arg) - if err != nil { - return nil, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - filteredUsers := make([]database.GetUsersRow, 0, len(users)) - for _, user := range users { - // If the filter exists, ensure the object is authorized. - if prepared != nil && prepared.Authorize(ctx, user.RBACObject()) != nil { - continue - } - - filteredUsers = append(filteredUsers, user) - } - return filteredUsers, nil -} - -func (q *FakeQuerier) GetAuthorizedAuditLogsOffset(ctx context.Context, arg database.GetAuditLogsOffsetParams, prepared rbac.PreparedAuthorized) ([]database.GetAuditLogsOffsetRow, error) { - if err := validateDatabaseType(arg); err != nil { - return nil, err - } - - // Call this to match the same function calls as the SQL implementation. - // It functionally does nothing for filtering. - if prepared != nil { - _, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{ - VariableConverter: regosql.AuditLogConverter(), - }) - if err != nil { - return nil, err - } - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - if arg.LimitOpt == 0 { - // Default to 100 is set in the SQL query. - arg.LimitOpt = 100 - } - - logs := make([]database.GetAuditLogsOffsetRow, 0, arg.LimitOpt) - - // q.auditLogs are already sorted by time DESC, so no need to sort after the fact. - for _, alog := range q.auditLogs { - if arg.OffsetOpt > 0 { - arg.OffsetOpt-- - continue - } - if arg.RequestID != uuid.Nil && arg.RequestID != alog.RequestID { - continue - } - if arg.OrganizationID != uuid.Nil && arg.OrganizationID != alog.OrganizationID { - continue - } - if arg.Action != "" && string(alog.Action) != arg.Action { - continue - } - if arg.ResourceType != "" && !strings.Contains(string(alog.ResourceType), arg.ResourceType) { - continue - } - if arg.ResourceID != uuid.Nil && alog.ResourceID != arg.ResourceID { - continue - } - if arg.Username != "" { - user, err := q.getUserByIDNoLock(alog.UserID) - if err == nil && !strings.EqualFold(arg.Username, user.Username) { - continue - } - } - if arg.Email != "" { - user, err := q.getUserByIDNoLock(alog.UserID) - if err == nil && !strings.EqualFold(arg.Email, user.Email) { - continue - } - } - if !arg.DateFrom.IsZero() { - if alog.Time.Before(arg.DateFrom) { - continue - } - } - if !arg.DateTo.IsZero() { - if alog.Time.After(arg.DateTo) { - continue - } - } - if arg.BuildReason != "" { - workspaceBuild, err := q.getWorkspaceBuildByIDNoLock(context.Background(), alog.ResourceID) - if err == nil && !strings.EqualFold(arg.BuildReason, string(workspaceBuild.Reason)) { - continue - } - } - // If the filter exists, ensure the object is authorized. - if prepared != nil && prepared.Authorize(ctx, alog.RBACObject()) != nil { - continue - } - - user, err := q.getUserByIDNoLock(alog.UserID) - userValid := err == nil - - org, _ := q.getOrganizationByIDNoLock(alog.OrganizationID) - - cpy := alog - logs = append(logs, database.GetAuditLogsOffsetRow{ - AuditLog: cpy, - OrganizationName: org.Name, - OrganizationDisplayName: org.DisplayName, - OrganizationIcon: org.Icon, - UserUsername: sql.NullString{String: user.Username, Valid: userValid}, - UserName: sql.NullString{String: user.Name, Valid: userValid}, - UserEmail: sql.NullString{String: user.Email, Valid: userValid}, - UserCreatedAt: sql.NullTime{Time: user.CreatedAt, Valid: userValid}, - UserUpdatedAt: sql.NullTime{Time: user.UpdatedAt, Valid: userValid}, - UserLastSeenAt: sql.NullTime{Time: user.LastSeenAt, Valid: userValid}, - UserLoginType: database.NullLoginType{LoginType: user.LoginType, Valid: userValid}, - UserDeleted: sql.NullBool{Bool: user.Deleted, Valid: userValid}, - UserQuietHoursSchedule: sql.NullString{String: user.QuietHoursSchedule, Valid: userValid}, - UserStatus: database.NullUserStatus{UserStatus: user.Status, Valid: userValid}, - UserRoles: user.RBACRoles, - }) - - if len(logs) >= int(arg.LimitOpt) { - break - } - } - - return logs, nil -} - -func (q *FakeQuerier) CountAuthorizedAuditLogs(ctx context.Context, arg database.CountAuditLogsParams, prepared rbac.PreparedAuthorized) (int64, error) { - if err := validateDatabaseType(arg); err != nil { - return 0, err - } - - // Call this to match the same function calls as the SQL implementation. - // It functionally does nothing for filtering. - if prepared != nil { - _, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{ - VariableConverter: regosql.AuditLogConverter(), - }) - if err != nil { - return 0, err - } - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - var count int64 - - // q.auditLogs are already sorted by time DESC, so no need to sort after the fact. - for _, alog := range q.auditLogs { - if arg.RequestID != uuid.Nil && arg.RequestID != alog.RequestID { - continue - } - if arg.OrganizationID != uuid.Nil && arg.OrganizationID != alog.OrganizationID { - continue - } - if arg.Action != "" && string(alog.Action) != arg.Action { - continue - } - if arg.ResourceType != "" && !strings.Contains(string(alog.ResourceType), arg.ResourceType) { - continue - } - if arg.ResourceID != uuid.Nil && alog.ResourceID != arg.ResourceID { - continue - } - if arg.Username != "" { - user, err := q.getUserByIDNoLock(alog.UserID) - if err == nil && !strings.EqualFold(arg.Username, user.Username) { - continue - } - } - if arg.Email != "" { - user, err := q.getUserByIDNoLock(alog.UserID) - if err == nil && !strings.EqualFold(arg.Email, user.Email) { - continue - } - } - if !arg.DateFrom.IsZero() { - if alog.Time.Before(arg.DateFrom) { - continue - } - } - if !arg.DateTo.IsZero() { - if alog.Time.After(arg.DateTo) { - continue - } - } - if arg.BuildReason != "" { - workspaceBuild, err := q.getWorkspaceBuildByIDNoLock(context.Background(), alog.ResourceID) - if err == nil && !strings.EqualFold(arg.BuildReason, string(workspaceBuild.Reason)) { - continue - } - } - // If the filter exists, ensure the object is authorized. - if prepared != nil && prepared.Authorize(ctx, alog.RBACObject()) != nil { - continue - } - - count++ - } - - return count, nil -} diff --git a/coderd/database/dbmem/dbmem_test.go b/coderd/database/dbmem/dbmem_test.go deleted file mode 100644 index c3df828b95c98..0000000000000 --- a/coderd/database/dbmem/dbmem_test.go +++ /dev/null @@ -1,209 +0,0 @@ -package dbmem_test - -import ( - "context" - "database/sql" - "sort" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/coder/coder/v2/coderd/database" - "github.com/coder/coder/v2/coderd/database/dbgen" - "github.com/coder/coder/v2/coderd/database/dbmem" - "github.com/coder/coder/v2/coderd/database/dbtime" -) - -// test that transactions don't deadlock, and that we don't see intermediate state. -func TestInTx(t *testing.T) { - t.Parallel() - - uut := dbmem.New() - - inTx := make(chan any) - queriesDone := make(chan any) - queriesStarted := make(chan any) - go func() { - err := uut.InTx(func(tx database.Store) error { - close(inTx) - _, err := tx.InsertOrganization(context.Background(), database.InsertOrganizationParams{ - Name: "1", - }) - assert.NoError(t, err) - <-queriesStarted - time.Sleep(5 * time.Millisecond) - _, err = tx.InsertOrganization(context.Background(), database.InsertOrganizationParams{ - Name: "2", - }) - assert.NoError(t, err) - return nil - }, nil) - assert.NoError(t, err) - }() - var nums []int - go func() { - <-inTx - for i := 0; i < 20; i++ { - orgs, err := uut.GetOrganizations(context.Background(), database.GetOrganizationsParams{}) - if err != nil { - assert.ErrorIs(t, err, sql.ErrNoRows) - } - nums = append(nums, len(orgs)) - time.Sleep(time.Millisecond) - } - close(queriesDone) - }() - close(queriesStarted) - <-queriesDone - // ensure we never saw 1 org, only 0 or 2. - for i := 0; i < 20; i++ { - assert.NotEqual(t, 1, nums[i]) - } -} - -// TestUserOrder ensures that the fake database returns users sorted by username. -func TestUserOrder(t *testing.T) { - t.Parallel() - - db := dbmem.New() - now := dbtime.Now() - - usernames := []string{"b-user", "d-user", "a-user", "c-user", "e-user"} - for _, username := range usernames { - dbgen.User(t, db, database.User{Username: username, CreatedAt: now}) - } - - users, err := db.GetUsers(context.Background(), database.GetUsersParams{}) - require.NoError(t, err) - require.Lenf(t, users, len(usernames), "expected %d users", len(usernames)) - - sort.Strings(usernames) - for i, user := range users { - require.Equal(t, usernames[i], user.Username) - } -} - -func TestProxyByHostname(t *testing.T) { - t.Parallel() - - db := dbmem.New() - - // Insert a bunch of different proxies. - proxies := []struct { - name string - accessURL string - wildcardHostname string - }{ - { - name: "one", - accessURL: "https://one.coder.com", - wildcardHostname: "*.wildcard.one.coder.com", - }, - { - name: "two", - accessURL: "https://two.coder.com", - wildcardHostname: "*--suffix.two.coder.com", - }, - } - for _, p := range proxies { - dbgen.WorkspaceProxy(t, db, database.WorkspaceProxy{ - Name: p.name, - Url: p.accessURL, - WildcardHostname: p.wildcardHostname, - }) - } - - cases := []struct { - name string - testHostname string - allowAccessURL bool - allowWildcardHost bool - matchProxyName string - }{ - { - name: "NoMatch", - testHostname: "test.com", - allowAccessURL: true, - allowWildcardHost: true, - matchProxyName: "", - }, - { - name: "MatchAccessURL", - testHostname: "one.coder.com", - allowAccessURL: true, - allowWildcardHost: true, - matchProxyName: "one", - }, - { - name: "MatchWildcard", - testHostname: "something.wildcard.one.coder.com", - allowAccessURL: true, - allowWildcardHost: true, - matchProxyName: "one", - }, - { - name: "MatchSuffix", - testHostname: "something--suffix.two.coder.com", - allowAccessURL: true, - allowWildcardHost: true, - matchProxyName: "two", - }, - { - name: "ValidateHostname/1", - testHostname: ".*ne.coder.com", - allowAccessURL: true, - allowWildcardHost: true, - matchProxyName: "", - }, - { - name: "ValidateHostname/2", - testHostname: "https://one.coder.com", - allowAccessURL: true, - allowWildcardHost: true, - matchProxyName: "", - }, - { - name: "ValidateHostname/3", - testHostname: "one.coder.com:8080/hello", - allowAccessURL: true, - allowWildcardHost: true, - matchProxyName: "", - }, - { - name: "IgnoreAccessURLMatch", - testHostname: "one.coder.com", - allowAccessURL: false, - allowWildcardHost: true, - matchProxyName: "", - }, - { - name: "IgnoreWildcardMatch", - testHostname: "hi.wildcard.one.coder.com", - allowAccessURL: true, - allowWildcardHost: false, - matchProxyName: "", - }, - } - - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - t.Parallel() - - proxy, err := db.GetWorkspaceProxyByHostname(context.Background(), database.GetWorkspaceProxyByHostnameParams{ - Hostname: c.testHostname, - AllowAccessUrl: c.allowAccessURL, - AllowWildcardHostname: c.allowWildcardHost, - }) - if c.matchProxyName == "" { - require.ErrorIs(t, err, sql.ErrNoRows) - require.Empty(t, proxy) - } else { - require.NoError(t, err) - require.NotEmpty(t, proxy) - require.Equal(t, c.matchProxyName, proxy.Name) - } - }) - } -} diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index debb8c2b89f56..e353a4688281d 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -656,6 +656,13 @@ func (m queryMetricsStore) GetAuthorizationUserRoles(ctx context.Context, userID return row, err } +func (m queryMetricsStore) GetConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams) ([]database.GetConnectionLogsOffsetRow, error) { + start := time.Now() + r0, r1 := m.s.GetConnectionLogsOffset(ctx, arg) + m.queryLatencies.WithLabelValues("GetConnectionLogsOffset").Observe(time.Since(start).Seconds()) + return r0, r1 +} + func (m queryMetricsStore) GetCoordinatorResumeTokenSigningKey(ctx context.Context) (string, error) { start := time.Now() r0, r1 := m.s.GetCoordinatorResumeTokenSigningKey(ctx) @@ -1335,6 +1342,13 @@ func (m queryMetricsStore) GetRunningPrebuiltWorkspaces(ctx context.Context) ([] return r0, r1 } +func (m queryMetricsStore) GetRunningPrebuiltWorkspacesOptimized(ctx context.Context) ([]database.GetRunningPrebuiltWorkspacesOptimizedRow, error) { + start := time.Now() + r0, r1 := m.s.GetRunningPrebuiltWorkspacesOptimized(ctx) + m.queryLatencies.WithLabelValues("GetRunningPrebuiltWorkspacesOptimized").Observe(time.Since(start).Seconds()) + return r0, r1 +} + func (m queryMetricsStore) GetRuntimeConfig(ctx context.Context, key string) (string, error) { start := time.Now() r0, r1 := m.s.GetRuntimeConfig(ctx, key) @@ -3155,6 +3169,13 @@ func (m queryMetricsStore) UpsertApplicationName(ctx context.Context, value stri return r0 } +func (m queryMetricsStore) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) { + start := time.Now() + r0, r1 := m.s.UpsertConnectionLog(ctx, arg) + m.queryLatencies.WithLabelValues("UpsertConnectionLog").Observe(time.Since(start).Seconds()) + return r0, r1 +} + func (m queryMetricsStore) UpsertCoordinatorResumeTokenSigningKey(ctx context.Context, value string) error { start := time.Now() r0 := m.s.UpsertCoordinatorResumeTokenSigningKey(ctx, value) @@ -3385,3 +3406,10 @@ func (m queryMetricsStore) CountAuthorizedAuditLogs(ctx context.Context, arg dat m.queryLatencies.WithLabelValues("CountAuthorizedAuditLogs").Observe(time.Since(start).Seconds()) return r0, r1 } + +func (m queryMetricsStore) GetAuthorizedConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams, prepared rbac.PreparedAuthorized) ([]database.GetConnectionLogsOffsetRow, error) { + start := time.Now() + r0, r1 := m.s.GetAuthorizedConnectionLogsOffset(ctx, arg, prepared) + m.queryLatencies.WithLabelValues("GetAuthorizedConnectionLogsOffset").Observe(time.Since(start).Seconds()) + return r0, r1 +} diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 059f37f8852b9..14e5344325b9b 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -1248,6 +1248,21 @@ func (mr *MockStoreMockRecorder) GetAuthorizedAuditLogsOffset(ctx, arg, prepared return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorizedAuditLogsOffset", reflect.TypeOf((*MockStore)(nil).GetAuthorizedAuditLogsOffset), ctx, arg, prepared) } +// GetAuthorizedConnectionLogsOffset mocks base method. +func (m *MockStore) GetAuthorizedConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams, prepared rbac.PreparedAuthorized) ([]database.GetConnectionLogsOffsetRow, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAuthorizedConnectionLogsOffset", ctx, arg, prepared) + ret0, _ := ret[0].([]database.GetConnectionLogsOffsetRow) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAuthorizedConnectionLogsOffset indicates an expected call of GetAuthorizedConnectionLogsOffset. +func (mr *MockStoreMockRecorder) GetAuthorizedConnectionLogsOffset(ctx, arg, prepared any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorizedConnectionLogsOffset", reflect.TypeOf((*MockStore)(nil).GetAuthorizedConnectionLogsOffset), ctx, arg, prepared) +} + // GetAuthorizedTemplates mocks base method. func (m *MockStore) GetAuthorizedTemplates(ctx context.Context, arg database.GetTemplatesWithFilterParams, prepared rbac.PreparedAuthorized) ([]database.Template, error) { m.ctrl.T.Helper() @@ -1323,6 +1338,21 @@ func (mr *MockStoreMockRecorder) GetAuthorizedWorkspacesAndAgentsByOwnerID(ctx, return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorizedWorkspacesAndAgentsByOwnerID", reflect.TypeOf((*MockStore)(nil).GetAuthorizedWorkspacesAndAgentsByOwnerID), ctx, ownerID, prepared) } +// GetConnectionLogsOffset mocks base method. +func (m *MockStore) GetConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams) ([]database.GetConnectionLogsOffsetRow, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetConnectionLogsOffset", ctx, arg) + ret0, _ := ret[0].([]database.GetConnectionLogsOffsetRow) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetConnectionLogsOffset indicates an expected call of GetConnectionLogsOffset. +func (mr *MockStoreMockRecorder) GetConnectionLogsOffset(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetConnectionLogsOffset", reflect.TypeOf((*MockStore)(nil).GetConnectionLogsOffset), ctx, arg) +} + // GetCoordinatorResumeTokenSigningKey mocks base method. func (m *MockStore) GetCoordinatorResumeTokenSigningKey(ctx context.Context) (string, error) { m.ctrl.T.Helper() @@ -2778,6 +2808,21 @@ func (mr *MockStoreMockRecorder) GetRunningPrebuiltWorkspaces(ctx any) *gomock.C return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRunningPrebuiltWorkspaces", reflect.TypeOf((*MockStore)(nil).GetRunningPrebuiltWorkspaces), ctx) } +// GetRunningPrebuiltWorkspacesOptimized mocks base method. +func (m *MockStore) GetRunningPrebuiltWorkspacesOptimized(ctx context.Context) ([]database.GetRunningPrebuiltWorkspacesOptimizedRow, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetRunningPrebuiltWorkspacesOptimized", ctx) + ret0, _ := ret[0].([]database.GetRunningPrebuiltWorkspacesOptimizedRow) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetRunningPrebuiltWorkspacesOptimized indicates an expected call of GetRunningPrebuiltWorkspacesOptimized. +func (mr *MockStoreMockRecorder) GetRunningPrebuiltWorkspacesOptimized(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRunningPrebuiltWorkspacesOptimized", reflect.TypeOf((*MockStore)(nil).GetRunningPrebuiltWorkspacesOptimized), ctx) +} + // GetRuntimeConfig mocks base method. func (m *MockStore) GetRuntimeConfig(ctx context.Context, key string) (string, error) { m.ctrl.T.Helper() @@ -6683,6 +6728,21 @@ func (mr *MockStoreMockRecorder) UpsertApplicationName(ctx, value any) *gomock.C return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertApplicationName", reflect.TypeOf((*MockStore)(nil).UpsertApplicationName), ctx, value) } +// UpsertConnectionLog mocks base method. +func (m *MockStore) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpsertConnectionLog", ctx, arg) + ret0, _ := ret[0].(database.ConnectionLog) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpsertConnectionLog indicates an expected call of UpsertConnectionLog. +func (mr *MockStoreMockRecorder) UpsertConnectionLog(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertConnectionLog", reflect.TypeOf((*MockStore)(nil).UpsertConnectionLog), ctx, arg) +} + // UpsertCoordinatorResumeTokenSigningKey mocks base method. func (m *MockStore) UpsertCoordinatorResumeTokenSigningKey(ctx context.Context, value string) error { m.ctrl.T.Helper() diff --git a/coderd/database/dbtestutil/db.go b/coderd/database/dbtestutil/db.go index fa3567c490826..f67e3206b09d1 100644 --- a/coderd/database/dbtestutil/db.go +++ b/coderd/database/dbtestutil/db.go @@ -20,14 +20,15 @@ import ( "cdr.dev/slog" "github.com/coder/coder/v2/coderd/database" - "github.com/coder/coder/v2/coderd/database/dbmem" "github.com/coder/coder/v2/coderd/database/pubsub" "github.com/coder/coder/v2/testutil" ) // WillUsePostgres returns true if a call to NewDB() will return a real, postgres-backed Store and Pubsub. +// TODO(hugodutka): since we removed the in-memory database, this is always true, +// and we need to remove this function. https://github.com/coder/internal/issues/758 func WillUsePostgres() bool { - return os.Getenv("DB") != "" + return true } type options struct { @@ -109,52 +110,48 @@ func NewDB(t testing.TB, opts ...Option) (database.Store, pubsub.Pubsub) { var db database.Store var ps pubsub.Pubsub - if WillUsePostgres() { - connectionURL := os.Getenv("CODER_PG_CONNECTION_URL") - if connectionURL == "" && o.url != "" { - connectionURL = o.url - } - if connectionURL == "" { - var err error - connectionURL, err = Open(t) - require.NoError(t, err) - } - - if o.fixedTimezone == "" { - // To make sure we find timezone-related issues, we set the timezone - // of the database to a non-UTC one. - // The below was picked due to the following properties: - // - It has a non-UTC offset - // - It has a fractional hour UTC offset - // - It includes a daylight savings time component - o.fixedTimezone = DefaultTimezone - } - dbName := dbNameFromConnectionURL(t, connectionURL) - setDBTimezone(t, connectionURL, dbName, o.fixedTimezone) - sqlDB, err := sql.Open("postgres", connectionURL) + connectionURL := os.Getenv("CODER_PG_CONNECTION_URL") + if connectionURL == "" && o.url != "" { + connectionURL = o.url + } + if connectionURL == "" { + var err error + connectionURL, err = Open(t) require.NoError(t, err) - t.Cleanup(func() { - _ = sqlDB.Close() - }) - if o.returnSQLDB != nil { - o.returnSQLDB(sqlDB) - } - if o.dumpOnFailure { - t.Cleanup(func() { DumpOnFailure(t, connectionURL) }) - } - // Unit tests should not retry serial transaction failures. - db = database.New(sqlDB, database.WithSerialRetryCount(1)) + } - ps, err = pubsub.New(context.Background(), o.logger, sqlDB, connectionURL) - require.NoError(t, err) - t.Cleanup(func() { - _ = ps.Close() - }) - } else { - db = dbmem.New() - ps = pubsub.NewInMemory() + if o.fixedTimezone == "" { + // To make sure we find timezone-related issues, we set the timezone + // of the database to a non-UTC one. + // The below was picked due to the following properties: + // - It has a non-UTC offset + // - It has a fractional hour UTC offset + // - It includes a daylight savings time component + o.fixedTimezone = DefaultTimezone + } + dbName := dbNameFromConnectionURL(t, connectionURL) + setDBTimezone(t, connectionURL, dbName, o.fixedTimezone) + + sqlDB, err := sql.Open("postgres", connectionURL) + require.NoError(t, err) + t.Cleanup(func() { + _ = sqlDB.Close() + }) + if o.returnSQLDB != nil { + o.returnSQLDB(sqlDB) + } + if o.dumpOnFailure { + t.Cleanup(func() { DumpOnFailure(t, connectionURL) }) } + // Unit tests should not retry serial transaction failures. + db = database.New(sqlDB, database.WithSerialRetryCount(1)) + + ps, err = pubsub.New(context.Background(), o.logger, sqlDB, connectionURL) + require.NoError(t, err) + t.Cleanup(func() { + _ = ps.Close() + }) return db, ps } diff --git a/coderd/database/dbtestutil/postgres.go b/coderd/database/dbtestutil/postgres.go index c1cfa383577de..e5aa4b14de83b 100644 --- a/coderd/database/dbtestutil/postgres.go +++ b/coderd/database/dbtestutil/postgres.go @@ -81,7 +81,7 @@ func initDefaultConnection(t TBSubset) error { } var dbErr error - // Retry up to 3 seconds for temporary errors. + // Retry up to 10 seconds for temporary errors. ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() for r := retry.New(10*time.Millisecond, 500*time.Millisecond); r.Wait(ctx); { @@ -93,7 +93,7 @@ func initDefaultConnection(t TBSubset) error { if !containsAnySubstring(errString, retryableErrSubstrings) { break } - t.Logf("failed to connect to postgres, retrying: %s", errString) + t.Logf("%s failed to connect to postgres, retrying: %s", time.Now().Format(time.StampMilli), errString) } // After the loop dbErr is the last connection error (if any). diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index 54f984294fa4e..26818fbf6c99d 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -38,6 +38,8 @@ CREATE TYPE audit_action AS ENUM ( 'close' ); +COMMENT ON TYPE audit_action IS 'NOTE: `connect`, `disconnect`, `open`, and `close` are deprecated and no longer used - these events are now tracked in the connection_logs table.'; + CREATE TYPE automatic_updates AS ENUM ( 'always', 'never' @@ -52,6 +54,20 @@ CREATE TYPE build_reason AS ENUM ( 'autodelete' ); +CREATE TYPE connection_status AS ENUM ( + 'connected', + 'disconnected' +); + +CREATE TYPE connection_type AS ENUM ( + 'ssh', + 'vscode', + 'jetbrains', + 'reconnecting_pty', + 'workspace_app', + 'port_forwarding' +); + CREATE TYPE crypto_key_feature AS ENUM ( 'workspace_apps_token', 'workspace_apps_api_key', @@ -823,6 +839,39 @@ CREATE TABLE audit_logs ( resource_icon text NOT NULL ); +CREATE TABLE connection_logs ( + id uuid NOT NULL, + connect_time timestamp with time zone NOT NULL, + organization_id uuid NOT NULL, + workspace_owner_id uuid NOT NULL, + workspace_id uuid NOT NULL, + workspace_name text NOT NULL, + agent_name text NOT NULL, + type connection_type NOT NULL, + ip inet NOT NULL, + code integer, + user_agent text, + user_id uuid, + slug_or_port text, + connection_id uuid, + disconnect_time timestamp with time zone, + disconnect_reason text +); + +COMMENT ON COLUMN connection_logs.code IS 'Either the HTTP status code of the web request, or the exit code of an SSH connection. For non-web connections, this is Null until we receive a disconnect event for the same connection_id.'; + +COMMENT ON COLUMN connection_logs.user_agent IS 'Null for SSH events. For web connections, this is the User-Agent header from the request.'; + +COMMENT ON COLUMN connection_logs.user_id IS 'Null for SSH events. For web connections, this is the ID of the user that made the request.'; + +COMMENT ON COLUMN connection_logs.slug_or_port IS 'Null for SSH events. For web connections, this is the slug of the app or the port number being forwarded.'; + +COMMENT ON COLUMN connection_logs.connection_id IS 'The SSH connection ID. Used to correlate connections and disconnections. As it originates from the agent, it is not guaranteed to be unique.'; + +COMMENT ON COLUMN connection_logs.disconnect_time IS 'The time the connection was closed. Null for web connections. For other connections, this is null until we receive a disconnect event for the same connection_id.'; + +COMMENT ON COLUMN connection_logs.disconnect_reason IS 'The reason the connection was closed. Null for web connections. For other connections, this is null until we receive a disconnect event for the same connection_id.'; + CREATE TABLE crypto_keys ( feature crypto_key_feature NOT NULL, sequence integer NOT NULL, @@ -2413,6 +2462,9 @@ ALTER TABLE ONLY api_keys ALTER TABLE ONLY audit_logs ADD CONSTRAINT audit_logs_pkey PRIMARY KEY (id); +ALTER TABLE ONLY connection_logs + ADD CONSTRAINT connection_logs_pkey PRIMARY KEY (id); + ALTER TABLE ONLY crypto_keys ADD CONSTRAINT crypto_keys_pkey PRIMARY KEY (feature, sequence); @@ -2699,6 +2751,18 @@ CREATE INDEX idx_audit_log_user_id ON audit_logs USING btree (user_id); CREATE INDEX idx_audit_logs_time_desc ON audit_logs USING btree ("time" DESC); +CREATE INDEX idx_connection_logs_connect_time_desc ON connection_logs USING btree (connect_time DESC); + +CREATE UNIQUE INDEX idx_connection_logs_connection_id_workspace_id_agent_name ON connection_logs USING btree (connection_id, workspace_id, agent_name); + +COMMENT ON INDEX idx_connection_logs_connection_id_workspace_id_agent_name IS 'Connection ID is NULL for web events, but present for SSH events. Therefore, this index allows multiple web events for the same workspace & agent. For SSH events, the upsertion query handles duplicates on this index by upserting the disconnect_time and disconnect_reason for the same connection_id when the connection is closed.'; + +CREATE INDEX idx_connection_logs_organization_id ON connection_logs USING btree (organization_id); + +CREATE INDEX idx_connection_logs_workspace_id ON connection_logs USING btree (workspace_id); + +CREATE INDEX idx_connection_logs_workspace_owner_id ON connection_logs USING btree (workspace_owner_id); + CREATE INDEX idx_custom_roles_id ON custom_roles USING btree (id); CREATE UNIQUE INDEX idx_custom_roles_name_lower ON custom_roles USING btree (lower(name)); @@ -2906,6 +2970,15 @@ forward without requiring a migration to clean up historical data.'; ALTER TABLE ONLY api_keys ADD CONSTRAINT api_keys_user_id_uuid_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE; +ALTER TABLE ONLY connection_logs + ADD CONSTRAINT connection_logs_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE; + +ALTER TABLE ONLY connection_logs + ADD CONSTRAINT connection_logs_workspace_id_fkey FOREIGN KEY (workspace_id) REFERENCES workspaces(id) ON DELETE CASCADE; + +ALTER TABLE ONLY connection_logs + ADD CONSTRAINT connection_logs_workspace_owner_id_fkey FOREIGN KEY (workspace_owner_id) REFERENCES users(id) ON DELETE CASCADE; + ALTER TABLE ONLY crypto_keys ADD CONSTRAINT crypto_keys_secret_key_id_fkey FOREIGN KEY (secret_key_id) REFERENCES dbcrypt_keys(active_key_digest); diff --git a/coderd/database/foreign_key_constraint.go b/coderd/database/foreign_key_constraint.go index b3b2d631aaa4d..c3aaf7342a97c 100644 --- a/coderd/database/foreign_key_constraint.go +++ b/coderd/database/foreign_key_constraint.go @@ -7,6 +7,9 @@ type ForeignKeyConstraint string // ForeignKeyConstraint enums. const ( ForeignKeyAPIKeysUserIDUUID ForeignKeyConstraint = "api_keys_user_id_uuid_fkey" // ALTER TABLE ONLY api_keys ADD CONSTRAINT api_keys_user_id_uuid_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE; + ForeignKeyConnectionLogsOrganizationID ForeignKeyConstraint = "connection_logs_organization_id_fkey" // ALTER TABLE ONLY connection_logs ADD CONSTRAINT connection_logs_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE; + ForeignKeyConnectionLogsWorkspaceID ForeignKeyConstraint = "connection_logs_workspace_id_fkey" // ALTER TABLE ONLY connection_logs ADD CONSTRAINT connection_logs_workspace_id_fkey FOREIGN KEY (workspace_id) REFERENCES workspaces(id) ON DELETE CASCADE; + ForeignKeyConnectionLogsWorkspaceOwnerID ForeignKeyConstraint = "connection_logs_workspace_owner_id_fkey" // ALTER TABLE ONLY connection_logs ADD CONSTRAINT connection_logs_workspace_owner_id_fkey FOREIGN KEY (workspace_owner_id) REFERENCES users(id) ON DELETE CASCADE; ForeignKeyCryptoKeysSecretKeyID ForeignKeyConstraint = "crypto_keys_secret_key_id_fkey" // ALTER TABLE ONLY crypto_keys ADD CONSTRAINT crypto_keys_secret_key_id_fkey FOREIGN KEY (secret_key_id) REFERENCES dbcrypt_keys(active_key_digest); ForeignKeyFkOauth2ProviderAppTokensUserID ForeignKeyConstraint = "fk_oauth2_provider_app_tokens_user_id" // ALTER TABLE ONLY oauth2_provider_app_tokens ADD CONSTRAINT fk_oauth2_provider_app_tokens_user_id FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE; ForeignKeyGitAuthLinksOauthAccessTokenKeyID ForeignKeyConstraint = "git_auth_links_oauth_access_token_key_id_fkey" // ALTER TABLE ONLY external_auth_links ADD CONSTRAINT git_auth_links_oauth_access_token_key_id_fkey FOREIGN KEY (oauth_access_token_key_id) REFERENCES dbcrypt_keys(active_key_digest); diff --git a/coderd/database/migrations/000349_connection_logs.down.sql b/coderd/database/migrations/000349_connection_logs.down.sql new file mode 100644 index 0000000000000..1a00797086402 --- /dev/null +++ b/coderd/database/migrations/000349_connection_logs.down.sql @@ -0,0 +1,11 @@ +DROP INDEX IF EXISTS idx_connection_logs_workspace_id; +DROP INDEX IF EXISTS idx_connection_logs_workspace_owner_id; +DROP INDEX IF EXISTS idx_connection_logs_organization_id; +DROP INDEX IF EXISTS idx_connection_logs_connect_time_desc; +DROP INDEX IF EXISTS idx_connection_logs_connection_id_workspace_id_agent_name; + +DROP TABLE IF EXISTS connection_logs; + +DROP TYPE IF EXISTS connection_type; + +DROP TYPE IF EXISTS connection_status; diff --git a/coderd/database/migrations/000349_connection_logs.up.sql b/coderd/database/migrations/000349_connection_logs.up.sql new file mode 100644 index 0000000000000..b9d7f0cdda41c --- /dev/null +++ b/coderd/database/migrations/000349_connection_logs.up.sql @@ -0,0 +1,68 @@ +CREATE TYPE connection_status AS ENUM ( + 'connected', + 'disconnected' +); + +CREATE TYPE connection_type AS ENUM ( + -- SSH events + 'ssh', + 'vscode', + 'jetbrains', + 'reconnecting_pty', + -- Web events + 'workspace_app', + 'port_forwarding' +); + +CREATE TABLE connection_logs ( + id uuid NOT NULL, + connect_time timestamp with time zone NOT NULL, + organization_id uuid NOT NULL REFERENCES organizations (id) ON DELETE CASCADE, + workspace_owner_id uuid NOT NULL REFERENCES users (id) ON DELETE CASCADE, + workspace_id uuid NOT NULL REFERENCES workspaces (id) ON DELETE CASCADE, + workspace_name text NOT NULL, + agent_name text NOT NULL, + type connection_type NOT NULL, + ip inet NOT NULL, + code integer, + + -- Only set for web events + user_agent text, + user_id uuid, + slug_or_port text, + + -- Null for web events + connection_id uuid, + disconnect_time timestamp with time zone, -- Null until we upsert a disconnect log for the same connection_id. + disconnect_reason text, + + PRIMARY KEY (id) +); + + +COMMENT ON COLUMN connection_logs.code IS 'Either the HTTP status code of the web request, or the exit code of an SSH connection. For non-web connections, this is Null until we receive a disconnect event for the same connection_id.'; + +COMMENT ON COLUMN connection_logs.user_agent IS 'Null for SSH events. For web connections, this is the User-Agent header from the request.'; + +COMMENT ON COLUMN connection_logs.user_id IS 'Null for SSH events. For web connections, this is the ID of the user that made the request.'; + +COMMENT ON COLUMN connection_logs.slug_or_port IS 'Null for SSH events. For web connections, this is the slug of the app or the port number being forwarded.'; + +COMMENT ON COLUMN connection_logs.connection_id IS 'The SSH connection ID. Used to correlate connections and disconnections. As it originates from the agent, it is not guaranteed to be unique.'; + +COMMENT ON COLUMN connection_logs.disconnect_time IS 'The time the connection was closed. Null for web connections. For other connections, this is null until we receive a disconnect event for the same connection_id.'; + +COMMENT ON COLUMN connection_logs.disconnect_reason IS 'The reason the connection was closed. Null for web connections. For other connections, this is null until we receive a disconnect event for the same connection_id.'; + +COMMENT ON TYPE audit_action IS 'NOTE: `connect`, `disconnect`, `open`, and `close` are deprecated and no longer used - these events are now tracked in the connection_logs table.'; + +-- To associate connection closure events with the connection start events. +CREATE UNIQUE INDEX idx_connection_logs_connection_id_workspace_id_agent_name +ON connection_logs (connection_id, workspace_id, agent_name); + +COMMENT ON INDEX idx_connection_logs_connection_id_workspace_id_agent_name IS 'Connection ID is NULL for web events, but present for SSH events. Therefore, this index allows multiple web events for the same workspace & agent. For SSH events, the upsertion query handles duplicates on this index by upserting the disconnect_time and disconnect_reason for the same connection_id when the connection is closed.'; + +CREATE INDEX idx_connection_logs_connect_time_desc ON connection_logs USING btree (connect_time DESC); +CREATE INDEX idx_connection_logs_organization_id ON connection_logs USING btree (organization_id); +CREATE INDEX idx_connection_logs_workspace_owner_id ON connection_logs USING btree (workspace_owner_id); +CREATE INDEX idx_connection_logs_workspace_id ON connection_logs USING btree (workspace_id); diff --git a/coderd/database/migrations/testdata/fixtures/000349_connection_logs.up.sql b/coderd/database/migrations/testdata/fixtures/000349_connection_logs.up.sql new file mode 100644 index 0000000000000..bbddf5226bc29 --- /dev/null +++ b/coderd/database/migrations/testdata/fixtures/000349_connection_logs.up.sql @@ -0,0 +1,53 @@ +INSERT INTO connection_logs ( + id, + connect_time, + organization_id, + workspace_owner_id, + workspace_id, + workspace_name, + agent_name, + type, + code, + ip, + user_agent, + user_id, + slug_or_port, + connection_id, + disconnect_time, + disconnect_reason +) VALUES ( + '00000000-0000-0000-0000-000000000001', -- log id + '2023-10-01 12:00:00+00', -- start time + 'bb640d07-ca8a-4869-b6bc-ae61ebb2fda1', -- organization id + 'a0061a8e-7db7-4585-838c-3116a003dd21', -- workspace owner id + '3a9a1feb-e89d-457c-9d53-ac751b198ebe', -- workspace id + 'Test Workspace', -- workspace name + 'test-agent', -- agent name + 'ssh', -- type + 0, -- code + '127.0.0.1', -- ip + NULL, -- user agent + NULL, -- user id + NULL, -- slug or port + '00000000-0000-0000-0000-000000000003', -- connection id + '2023-10-01 12:00:10+00', -- close time + 'server shut down' -- reason +), +( + '00000000-0000-0000-0000-000000000002', -- log id + '2023-10-01 12:05:00+00', -- start time + 'bb640d07-ca8a-4869-b6bc-ae61ebb2fda1', -- organization id + 'a0061a8e-7db7-4585-838c-3116a003dd21', -- workspace owner id + '3a9a1feb-e89d-457c-9d53-ac751b198ebe', -- workspace id + 'Test Workspace', -- workspace name + 'test-agent', -- agent name + 'workspace_app', -- type + 200, -- code + '127.0.0.1', + 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/100.0.4896.127 Safari/537.36', + 'a0061a8e-7db7-4585-838c-3116a003dd21', -- user id + 'code-server', -- slug or port + NULL, -- connection id (request ID) + NULL, -- close time + NULL -- reason +); diff --git a/coderd/database/modelmethods.go b/coderd/database/modelmethods.go index 07e1f2dc32352..b49fa113d4b12 100644 --- a/coderd/database/modelmethods.go +++ b/coderd/database/modelmethods.go @@ -117,6 +117,19 @@ func (w AuditLog) RBACObject() rbac.Object { return obj } +func (w GetConnectionLogsOffsetRow) RBACObject() rbac.Object { + return w.ConnectionLog.RBACObject() +} + +func (w ConnectionLog) RBACObject() rbac.Object { + obj := rbac.ResourceConnectionLog.WithID(w.ID) + if w.OrganizationID != uuid.Nil { + obj = obj.InOrg(w.OrganizationID) + } + + return obj +} + func (s APIKeyScope) ToRBAC() rbac.ScopeName { switch s { case APIKeyScopeAll: diff --git a/coderd/database/modelqueries.go b/coderd/database/modelqueries.go index 785ccf86afd27..c0892aebdeb01 100644 --- a/coderd/database/modelqueries.go +++ b/coderd/database/modelqueries.go @@ -50,6 +50,7 @@ type customQuerier interface { workspaceQuerier userQuerier auditLogQuerier + connectionLogQuerier } type templateQuerier interface { @@ -611,6 +612,81 @@ func (q *sqlQuerier) CountAuthorizedAuditLogs(ctx context.Context, arg CountAudi return count, nil } +type connectionLogQuerier interface { + GetAuthorizedConnectionLogsOffset(ctx context.Context, arg GetConnectionLogsOffsetParams, prepared rbac.PreparedAuthorized) ([]GetConnectionLogsOffsetRow, error) +} + +func (q *sqlQuerier) GetAuthorizedConnectionLogsOffset(ctx context.Context, arg GetConnectionLogsOffsetParams, prepared rbac.PreparedAuthorized) ([]GetConnectionLogsOffsetRow, error) { + authorizedFilter, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{ + VariableConverter: regosql.ConnectionLogConverter(), + }) + if err != nil { + return nil, xerrors.Errorf("compile authorized filter: %w", err) + } + filtered, err := insertAuthorizedFilter(getConnectionLogsOffset, fmt.Sprintf(" AND %s", authorizedFilter)) + if err != nil { + return nil, xerrors.Errorf("insert authorized filter: %w", err) + } + + query := fmt.Sprintf("-- name: GetAuthorizedConnectionLogsOffset :many\n%s", filtered) + rows, err := q.db.QueryContext(ctx, query, + arg.OffsetOpt, + arg.LimitOpt, + ) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetConnectionLogsOffsetRow + for rows.Next() { + var i GetConnectionLogsOffsetRow + if err := rows.Scan( + &i.ConnectionLog.ID, + &i.ConnectionLog.ConnectTime, + &i.ConnectionLog.OrganizationID, + &i.ConnectionLog.WorkspaceOwnerID, + &i.ConnectionLog.WorkspaceID, + &i.ConnectionLog.WorkspaceName, + &i.ConnectionLog.AgentName, + &i.ConnectionLog.Type, + &i.ConnectionLog.Ip, + &i.ConnectionLog.Code, + &i.ConnectionLog.UserAgent, + &i.ConnectionLog.UserID, + &i.ConnectionLog.SlugOrPort, + &i.ConnectionLog.ConnectionID, + &i.ConnectionLog.DisconnectTime, + &i.ConnectionLog.DisconnectReason, + &i.UserUsername, + &i.UserName, + &i.UserEmail, + &i.UserCreatedAt, + &i.UserUpdatedAt, + &i.UserLastSeenAt, + &i.UserStatus, + &i.UserLoginType, + &i.UserRoles, + &i.UserAvatarUrl, + &i.UserDeleted, + &i.UserQuietHoursSchedule, + &i.WorkspaceOwnerUsername, + &i.OrganizationName, + &i.OrganizationDisplayName, + &i.OrganizationIcon, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + func insertAuthorizedFilter(query string, replaceWith string) (string, error) { if !strings.Contains(query, authorizedQueryPlaceholder) { return "", xerrors.Errorf("query does not contain authorized replace string, this is not an authorized query") diff --git a/coderd/database/models.go b/coderd/database/models.go index 749de51118152..169f6a60be709 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -196,6 +196,7 @@ func AllAppSharingLevelValues() []AppSharingLevel { } } +// NOTE: `connect`, `disconnect`, `open`, and `close` are deprecated and no longer used - these events are now tracked in the connection_logs table. type AuditAction string const ( @@ -415,6 +416,134 @@ func AllBuildReasonValues() []BuildReason { } } +type ConnectionStatus string + +const ( + ConnectionStatusConnected ConnectionStatus = "connected" + ConnectionStatusDisconnected ConnectionStatus = "disconnected" +) + +func (e *ConnectionStatus) Scan(src interface{}) error { + switch s := src.(type) { + case []byte: + *e = ConnectionStatus(s) + case string: + *e = ConnectionStatus(s) + default: + return fmt.Errorf("unsupported scan type for ConnectionStatus: %T", src) + } + return nil +} + +type NullConnectionStatus struct { + ConnectionStatus ConnectionStatus `json:"connection_status"` + Valid bool `json:"valid"` // Valid is true if ConnectionStatus is not NULL +} + +// Scan implements the Scanner interface. +func (ns *NullConnectionStatus) Scan(value interface{}) error { + if value == nil { + ns.ConnectionStatus, ns.Valid = "", false + return nil + } + ns.Valid = true + return ns.ConnectionStatus.Scan(value) +} + +// Value implements the driver Valuer interface. +func (ns NullConnectionStatus) Value() (driver.Value, error) { + if !ns.Valid { + return nil, nil + } + return string(ns.ConnectionStatus), nil +} + +func (e ConnectionStatus) Valid() bool { + switch e { + case ConnectionStatusConnected, + ConnectionStatusDisconnected: + return true + } + return false +} + +func AllConnectionStatusValues() []ConnectionStatus { + return []ConnectionStatus{ + ConnectionStatusConnected, + ConnectionStatusDisconnected, + } +} + +type ConnectionType string + +const ( + ConnectionTypeSsh ConnectionType = "ssh" + ConnectionTypeVscode ConnectionType = "vscode" + ConnectionTypeJetbrains ConnectionType = "jetbrains" + ConnectionTypeReconnectingPty ConnectionType = "reconnecting_pty" + ConnectionTypeWorkspaceApp ConnectionType = "workspace_app" + ConnectionTypePortForwarding ConnectionType = "port_forwarding" +) + +func (e *ConnectionType) Scan(src interface{}) error { + switch s := src.(type) { + case []byte: + *e = ConnectionType(s) + case string: + *e = ConnectionType(s) + default: + return fmt.Errorf("unsupported scan type for ConnectionType: %T", src) + } + return nil +} + +type NullConnectionType struct { + ConnectionType ConnectionType `json:"connection_type"` + Valid bool `json:"valid"` // Valid is true if ConnectionType is not NULL +} + +// Scan implements the Scanner interface. +func (ns *NullConnectionType) Scan(value interface{}) error { + if value == nil { + ns.ConnectionType, ns.Valid = "", false + return nil + } + ns.Valid = true + return ns.ConnectionType.Scan(value) +} + +// Value implements the driver Valuer interface. +func (ns NullConnectionType) Value() (driver.Value, error) { + if !ns.Valid { + return nil, nil + } + return string(ns.ConnectionType), nil +} + +func (e ConnectionType) Valid() bool { + switch e { + case ConnectionTypeSsh, + ConnectionTypeVscode, + ConnectionTypeJetbrains, + ConnectionTypeReconnectingPty, + ConnectionTypeWorkspaceApp, + ConnectionTypePortForwarding: + return true + } + return false +} + +func AllConnectionTypeValues() []ConnectionType { + return []ConnectionType{ + ConnectionTypeSsh, + ConnectionTypeVscode, + ConnectionTypeJetbrains, + ConnectionTypeReconnectingPty, + ConnectionTypeWorkspaceApp, + ConnectionTypePortForwarding, + } +} + type CryptoKeyFeature string const ( @@ -2784,6 +2913,32 @@ type AuditLog struct { ResourceIcon string `db:"resource_icon" json:"resource_icon"` } +type ConnectionLog struct { + ID uuid.UUID `db:"id" json:"id"` + ConnectTime time.Time `db:"connect_time" json:"connect_time"` + OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"` + WorkspaceOwnerID uuid.UUID `db:"workspace_owner_id" json:"workspace_owner_id"` + WorkspaceID uuid.UUID `db:"workspace_id" json:"workspace_id"` + WorkspaceName string `db:"workspace_name" json:"workspace_name"` + AgentName string `db:"agent_name" json:"agent_name"` + Type ConnectionType `db:"type" json:"type"` + Ip pqtype.Inet `db:"ip" json:"ip"` + // Either the HTTP status code of the web request, or the exit code of an SSH connection. For non-web connections, this is Null until we receive a disconnect event for the same connection_id. + Code sql.NullInt32 `db:"code" json:"code"` + // Null for SSH events. For web connections, this is the User-Agent header from the request. + UserAgent sql.NullString `db:"user_agent" json:"user_agent"` + // Null for SSH events. For web connections, this is the ID of the user that made the request. + UserID uuid.NullUUID `db:"user_id" json:"user_id"` + // Null for SSH events. For web connections, this is the slug of the app or the port number being forwarded. + SlugOrPort sql.NullString `db:"slug_or_port" json:"slug_or_port"` + // The SSH connection ID. Used to correlate connections and disconnections. As it originates from the agent, it is not guaranteed to be unique. + ConnectionID uuid.NullUUID `db:"connection_id" json:"connection_id"` + // The time the connection was closed. Null for web connections. For other connections, this is null until we receive a disconnect event for the same connection_id. + DisconnectTime sql.NullTime `db:"disconnect_time" json:"disconnect_time"` + // The reason the connection was closed. Null for web connections. For other connections, this is null until we receive a disconnect event for the same connection_id. + DisconnectReason sql.NullString `db:"disconnect_reason" json:"disconnect_reason"` +} + type CryptoKey struct { Feature CryptoKeyFeature `db:"feature" json:"feature"` Sequence int32 `db:"sequence" json:"sequence"` diff --git a/coderd/database/oidcclaims_test.go b/coderd/database/oidcclaims_test.go index f9fe1711b19b8..fe4a10d83495e 100644 --- a/coderd/database/oidcclaims_test.go +++ b/coderd/database/oidcclaims_test.go @@ -222,7 +222,6 @@ func (g userGenerator) withLink(lt database.LoginType, rawJSON json.RawMessage) err := sql.UpdateUserLinkRawJSON(context.Background(), user.ID, rawJSON) require.NoError(t, err) } else { - // no need to test the json key logic in dbmem. Everything is type safe. var claims database.UserLinkClaims err := json.Unmarshal(rawJSON, &claims) require.NoError(t, err) diff --git a/coderd/database/querier.go b/coderd/database/querier.go index dcbac88611dd0..8af37596cb5c6 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -156,6 +156,7 @@ type sqlcQuerier interface { // This function returns roles for authorization purposes. Implied member roles // are included. GetAuthorizationUserRoles(ctx context.Context, userID uuid.UUID) (GetAuthorizationUserRolesRow, error) + GetConnectionLogsOffset(ctx context.Context, arg GetConnectionLogsOffsetParams) ([]GetConnectionLogsOffsetRow, error) GetCoordinatorResumeTokenSigningKey(ctx context.Context) (string, error) GetCryptoKeyByFeatureAndSequence(ctx context.Context, arg GetCryptoKeyByFeatureAndSequenceParams) (CryptoKey, error) GetCryptoKeys(ctx context.Context) ([]CryptoKey, error) @@ -298,6 +299,7 @@ type sqlcQuerier interface { GetReplicaByID(ctx context.Context, id uuid.UUID) (Replica, error) GetReplicasUpdatedAfter(ctx context.Context, updatedAt time.Time) ([]Replica, error) GetRunningPrebuiltWorkspaces(ctx context.Context) ([]GetRunningPrebuiltWorkspacesRow, error) + GetRunningPrebuiltWorkspacesOptimized(ctx context.Context) ([]GetRunningPrebuiltWorkspacesOptimizedRow, error) GetRuntimeConfig(ctx context.Context, key string) (string, error) GetTailnetAgents(ctx context.Context, id uuid.UUID) ([]TailnetAgent, error) GetTailnetClientsForAgent(ctx context.Context, agentID uuid.UUID) ([]TailnetClient, error) @@ -646,6 +648,7 @@ type sqlcQuerier interface { UpsertAnnouncementBanners(ctx context.Context, value string) error UpsertAppSecurityKey(ctx context.Context, value string) error UpsertApplicationName(ctx context.Context, value string) error + UpsertConnectionLog(ctx context.Context, arg UpsertConnectionLogParams) (ConnectionLog, error) UpsertCoordinatorResumeTokenSigningKey(ctx context.Context, value string) error // The default proxy is implied and not actually stored in the database. // So we need to store it's configuration here for display purposes. diff --git a/coderd/database/querier_test.go b/coderd/database/querier_test.go index f80f68115ad2c..298813276f902 100644 --- a/coderd/database/querier_test.go +++ b/coderd/database/querier_test.go @@ -6,6 +6,7 @@ import ( "encoding/json" "errors" "fmt" + "net" "sort" "testing" "time" @@ -13,6 +14,7 @@ import ( "github.com/google/uuid" "github.com/lib/pq" "github.com/prometheus/client_golang/prometheus" + "github.com/sqlc-dev/pqtype" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -1400,7 +1402,6 @@ func TestGetUsers_IncludeSystem(t *testing.T) { // Given: a system user // postgres: introduced by migration coderd/database/migrations/00030*_system_user.up.sql - // dbmem: created in dbmem/dbmem.go db, _ := dbtestutil.NewDB(t) other := dbgen.User(t, db, database.User{}) users, err := db.GetUsers(ctx, database.GetUsersParams{ @@ -2086,6 +2087,447 @@ func auditOnlyIDs[T database.AuditLog | database.GetAuditLogsOffsetRow](logs []T return ids } +func TestGetAuthorizedConnectionLogsOffset(t *testing.T) { + t.Parallel() + + var allLogs []database.ConnectionLog + db, _ := dbtestutil.NewDB(t) + authz := rbac.NewAuthorizer(prometheus.NewRegistry()) + authDb := dbauthz.New(db, authz, slogtest.Make(t, &slogtest.Options{}), coderdtest.AccessControlStorePointer()) + + orgA := dbfake.Organization(t, db).Do() + orgB := dbfake.Organization(t, db).Do() + + user := dbgen.User(t, db, database.User{}) + + tpl := dbgen.Template(t, db, database.Template{ + OrganizationID: orgA.Org.ID, + CreatedBy: user.ID, + }) + + wsID := uuid.New() + createTemplateVersion(t, db, tpl, tvArgs{ + WorkspaceTransition: database.WorkspaceTransitionStart, + Status: database.ProvisionerJobStatusSucceeded, + CreateWorkspace: true, + WorkspaceID: wsID, + }) + + // This map is a simple way to insert a given number of organizations + // and audit logs for each organization. + // map[orgID][]ConnectionLogID + orgConnectionLogs := map[uuid.UUID][]uuid.UUID{ + orgA.Org.ID: {uuid.New(), uuid.New()}, + orgB.Org.ID: {uuid.New(), uuid.New()}, + } + orgIDs := make([]uuid.UUID, 0, len(orgConnectionLogs)) + for orgID := range orgConnectionLogs { + orgIDs = append(orgIDs, orgID) + } + for orgID, ids := range orgConnectionLogs { + for _, id := range ids { + allLogs = append(allLogs, dbgen.ConnectionLog(t, authDb, database.UpsertConnectionLogParams{ + WorkspaceID: wsID, + WorkspaceOwnerID: user.ID, + ID: id, + OrganizationID: orgID, + })) + } + } + + // Now fetch all the logs + ctx := testutil.Context(t, testutil.WaitLong) + auditorRole, err := rbac.RoleByName(rbac.RoleAuditor()) + require.NoError(t, err) + + memberRole, err := rbac.RoleByName(rbac.RoleMember()) + require.NoError(t, err) + + orgAuditorRoles := func(t *testing.T, orgID uuid.UUID) rbac.Role { + t.Helper() + + role, err := rbac.RoleByName(rbac.ScopedRoleOrgAuditor(orgID)) + require.NoError(t, err) + return role + } + + t.Run("NoAccess", func(t *testing.T) { + t.Parallel() + + // Given: A user who is a member of 0 organizations + memberCtx := dbauthz.As(ctx, rbac.Subject{ + FriendlyName: "member", + ID: uuid.NewString(), + Roles: rbac.Roles{memberRole}, + Scope: rbac.ScopeAll, + }) + + // When: The user queries for connection logs + logs, err := authDb.GetConnectionLogsOffset(memberCtx, database.GetConnectionLogsOffsetParams{}) + require.NoError(t, err) + // Then: No logs returned + require.Len(t, logs, 0, "no logs should be returned") + }) + + t.Run("SiteWideAuditor", func(t *testing.T) { + t.Parallel() + + // Given: A site wide auditor + siteAuditorCtx := dbauthz.As(ctx, rbac.Subject{ + FriendlyName: "owner", + ID: uuid.NewString(), + Roles: rbac.Roles{auditorRole}, + Scope: rbac.ScopeAll, + }) + + // When: the auditor queries for connection logs + logs, err := authDb.GetConnectionLogsOffset(siteAuditorCtx, database.GetConnectionLogsOffsetParams{}) + require.NoError(t, err) + // Then: All logs are returned + require.ElementsMatch(t, connectionOnlyIDs(allLogs), connectionOnlyIDs(logs)) + }) + + t.Run("SingleOrgAuditor", func(t *testing.T) { + t.Parallel() + + orgID := orgIDs[0] + // Given: An organization scoped auditor + orgAuditCtx := dbauthz.As(ctx, rbac.Subject{ + FriendlyName: "org-auditor", + ID: uuid.NewString(), + Roles: rbac.Roles{orgAuditorRoles(t, orgID)}, + Scope: rbac.ScopeAll, + }) + + // When: The auditor queries for connection logs + logs, err := authDb.GetConnectionLogsOffset(orgAuditCtx, database.GetConnectionLogsOffsetParams{}) + require.NoError(t, err) + // Then: Only the logs for the organization are returned + require.ElementsMatch(t, orgConnectionLogs[orgID], connectionOnlyIDs(logs)) + }) + + t.Run("TwoOrgAuditors", func(t *testing.T) { + t.Parallel() + + first := orgIDs[0] + second := orgIDs[1] + // Given: A user who is an auditor for two organizations + multiOrgAuditCtx := dbauthz.As(ctx, rbac.Subject{ + FriendlyName: "org-auditor", + ID: uuid.NewString(), + Roles: rbac.Roles{orgAuditorRoles(t, first), orgAuditorRoles(t, second)}, + Scope: rbac.ScopeAll, + }) + + // When: The user queries for connection logs + logs, err := authDb.GetConnectionLogsOffset(multiOrgAuditCtx, database.GetConnectionLogsOffsetParams{}) + require.NoError(t, err) + // Then: All logs for both organizations are returned + require.ElementsMatch(t, append(orgConnectionLogs[first], orgConnectionLogs[second]...), connectionOnlyIDs(logs)) + }) + + t.Run("ErroneousOrg", func(t *testing.T) { + t.Parallel() + + // Given: A user who is an auditor for an organization that has 0 logs + userCtx := dbauthz.As(ctx, rbac.Subject{ + FriendlyName: "org-auditor", + ID: uuid.NewString(), + Roles: rbac.Roles{orgAuditorRoles(t, uuid.New())}, + Scope: rbac.ScopeAll, + }) + + // When: The user queries for audit logs + logs, err := authDb.GetConnectionLogsOffset(userCtx, database.GetConnectionLogsOffsetParams{}) + require.NoError(t, err) + // Then: No logs are returned + require.Len(t, logs, 0, "no logs should be returned") + }) +} + +func connectionOnlyIDs[T database.ConnectionLog | database.GetConnectionLogsOffsetRow](logs []T) []uuid.UUID { + ids := make([]uuid.UUID, 0, len(logs)) + for _, log := range logs { + switch log := any(log).(type) { + case database.ConnectionLog: + ids = append(ids, log.ID) + case database.GetConnectionLogsOffsetRow: + ids = append(ids, log.ConnectionLog.ID) + default: + panic("unreachable") + } + } + return ids +} + +func TestUpsertConnectionLog(t *testing.T) { + t.Parallel() + createWorkspace := func(t *testing.T, db database.Store) database.WorkspaceTable { + u := dbgen.User(t, db, database.User{}) + o := dbgen.Organization(t, db, database.Organization{}) + tpl := dbgen.Template(t, db, database.Template{ + OrganizationID: o.ID, + CreatedBy: u.ID, + }) + return dbgen.Workspace(t, db, database.WorkspaceTable{ + ID: uuid.New(), + OwnerID: u.ID, + OrganizationID: o.ID, + AutomaticUpdates: database.AutomaticUpdatesNever, + TemplateID: tpl.ID, + }) + } + + t.Run("ConnectThenDisconnect", func(t *testing.T) { + t.Parallel() + db, _ := dbtestutil.NewDB(t) + ctx := context.Background() + + ws := createWorkspace(t, db) + + connectionID := uuid.New() + agentName := "test-agent" + + // 1. Insert a 'connect' event. + connectTime := dbtime.Now() + connectParams := database.UpsertConnectionLogParams{ + ID: uuid.New(), + Time: connectTime, + OrganizationID: ws.OrganizationID, + WorkspaceOwnerID: ws.OwnerID, + WorkspaceID: ws.ID, + WorkspaceName: ws.Name, + AgentName: agentName, + Type: database.ConnectionTypeSsh, + ConnectionID: uuid.NullUUID{UUID: connectionID, Valid: true}, + ConnectionStatus: database.ConnectionStatusConnected, + Ip: pqtype.Inet{ + IPNet: net.IPNet{ + IP: net.IPv4(127, 0, 0, 1), + Mask: net.IPv4Mask(255, 255, 255, 255), + }, + Valid: true, + }, + } + + log1, err := db.UpsertConnectionLog(ctx, connectParams) + require.NoError(t, err) + require.Equal(t, connectParams.ID, log1.ID) + require.False(t, log1.DisconnectTime.Valid, "CloseTime should not be set on connect") + + // Check that one row exists. + rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10}) + require.NoError(t, err) + require.Len(t, rows, 1) + + // 2. Insert a 'disconnected' event for the same connection. + disconnectTime := connectTime.Add(time.Second) + disconnectParams := database.UpsertConnectionLogParams{ + ConnectionID: uuid.NullUUID{UUID: connectionID, Valid: true}, + WorkspaceID: ws.ID, + AgentName: agentName, + ConnectionStatus: database.ConnectionStatusDisconnected, + + // Updated to: + Time: disconnectTime, + DisconnectReason: sql.NullString{String: "test disconnect", Valid: true}, + Code: sql.NullInt32{Int32: 1, Valid: true}, + + // Ignored + ID: uuid.New(), + OrganizationID: ws.OrganizationID, + WorkspaceOwnerID: ws.OwnerID, + WorkspaceName: ws.Name, + Type: database.ConnectionTypeSsh, + Ip: pqtype.Inet{ + IPNet: net.IPNet{ + IP: net.IPv4(127, 0, 0, 1), + Mask: net.IPv4Mask(255, 255, 255, 254), + }, + Valid: true, + }, + } + + log2, err := db.UpsertConnectionLog(ctx, disconnectParams) + require.NoError(t, err) + + // Updated + require.Equal(t, log1.ID, log2.ID) + require.True(t, log2.DisconnectTime.Valid) + require.True(t, disconnectTime.Equal(log2.DisconnectTime.Time)) + require.Equal(t, disconnectParams.DisconnectReason.String, log2.DisconnectReason.String) + + rows, err = db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{}) + require.NoError(t, err) + require.Len(t, rows, 1) + }) + + t.Run("ConnectDoesNotUpdate", func(t *testing.T) { + t.Parallel() + db, _ := dbtestutil.NewDB(t) + ctx := context.Background() + + ws := createWorkspace(t, db) + + connectionID := uuid.New() + agentName := "test-agent" + + // 1. Insert a 'connect' event. + connectTime := dbtime.Now() + connectParams := database.UpsertConnectionLogParams{ + ID: uuid.New(), + Time: connectTime, + OrganizationID: ws.OrganizationID, + WorkspaceOwnerID: ws.OwnerID, + WorkspaceID: ws.ID, + WorkspaceName: ws.Name, + AgentName: agentName, + Type: database.ConnectionTypeSsh, + ConnectionID: uuid.NullUUID{UUID: connectionID, Valid: true}, + ConnectionStatus: database.ConnectionStatusConnected, + Ip: pqtype.Inet{ + IPNet: net.IPNet{ + IP: net.IPv4(127, 0, 0, 1), + Mask: net.IPv4Mask(255, 255, 255, 255), + }, + Valid: true, + }, + } + + log, err := db.UpsertConnectionLog(ctx, connectParams) + require.NoError(t, err) + + // 2. Insert another 'connect' event for the same connection. + connectTime2 := connectTime.Add(time.Second) + connectParams2 := database.UpsertConnectionLogParams{ + ConnectionID: uuid.NullUUID{UUID: connectionID, Valid: true}, + WorkspaceID: ws.ID, + AgentName: agentName, + ConnectionStatus: database.ConnectionStatusConnected, + + // Ignored + ID: uuid.New(), + Time: connectTime2, + OrganizationID: ws.OrganizationID, + WorkspaceOwnerID: ws.OwnerID, + WorkspaceName: ws.Name, + Type: database.ConnectionTypeSsh, + Code: sql.NullInt32{Int32: 0, Valid: false}, + Ip: pqtype.Inet{ + IPNet: net.IPNet{ + IP: net.IPv4(127, 0, 0, 1), + Mask: net.IPv4Mask(255, 255, 255, 254), + }, + Valid: true, + }, + } + + origLog, err := db.UpsertConnectionLog(ctx, connectParams2) + require.NoError(t, err) + require.Equal(t, log, origLog, "connect update should be a no-op") + + // Check that still only one row exists. + rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{}) + require.NoError(t, err) + require.Len(t, rows, 1) + require.Equal(t, log, rows[0].ConnectionLog) + }) + + t.Run("DisconnectThenConnect", func(t *testing.T) { + t.Parallel() + + db, _ := dbtestutil.NewDB(t) + ctx := context.Background() + + ws := createWorkspace(t, db) + + connectionID := uuid.New() + agentName := "test-agent" + + // Insert just a 'disconect' event + disconnectTime := dbtime.Now() + disconnectParams := database.UpsertConnectionLogParams{ + ID: uuid.New(), + Time: disconnectTime, + OrganizationID: ws.OrganizationID, + WorkspaceOwnerID: ws.OwnerID, + WorkspaceID: ws.ID, + WorkspaceName: ws.Name, + AgentName: agentName, + Type: database.ConnectionTypeSsh, + ConnectionID: uuid.NullUUID{UUID: connectionID, Valid: true}, + ConnectionStatus: database.ConnectionStatusDisconnected, + DisconnectReason: sql.NullString{String: "server shutting down", Valid: true}, + Ip: pqtype.Inet{ + IPNet: net.IPNet{ + IP: net.IPv4(127, 0, 0, 1), + Mask: net.IPv4Mask(255, 255, 255, 255), + }, + Valid: true, + }, + } + + _, err := db.UpsertConnectionLog(ctx, disconnectParams) + require.NoError(t, err) + + firstRows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{}) + require.NoError(t, err) + require.Len(t, firstRows, 1) + + // We expect the connection event to be marked as closed with the start + // and close time being the same. + require.True(t, firstRows[0].ConnectionLog.DisconnectTime.Valid) + require.Equal(t, disconnectTime, firstRows[0].ConnectionLog.DisconnectTime.Time.UTC()) + require.Equal(t, firstRows[0].ConnectionLog.ConnectTime.UTC(), firstRows[0].ConnectionLog.DisconnectTime.Time.UTC()) + + // Now insert a 'connect' event for the same connection. + // This should be a no op + connectTime := disconnectTime.Add(time.Second) + connectParams := database.UpsertConnectionLogParams{ + ID: uuid.New(), + Time: connectTime, + OrganizationID: ws.OrganizationID, + WorkspaceOwnerID: ws.OwnerID, + WorkspaceID: ws.ID, + WorkspaceName: ws.Name, + AgentName: agentName, + Type: database.ConnectionTypeSsh, + ConnectionID: uuid.NullUUID{UUID: connectionID, Valid: true}, + ConnectionStatus: database.ConnectionStatusConnected, + DisconnectReason: sql.NullString{String: "reconnected", Valid: true}, + Code: sql.NullInt32{Int32: 0, Valid: false}, + Ip: pqtype.Inet{ + IPNet: net.IPNet{ + IP: net.IPv4(127, 0, 0, 1), + Mask: net.IPv4Mask(255, 255, 255, 255), + }, + Valid: true, + }, + } + + _, err = db.UpsertConnectionLog(ctx, connectParams) + require.NoError(t, err) + + secondRows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{}) + require.NoError(t, err) + require.Len(t, secondRows, 1) + require.Equal(t, firstRows, secondRows) + + // Upsert a disconnection, which should also be a no op + disconnectParams.DisconnectReason = sql.NullString{ + String: "updated close reason", + Valid: true, + } + _, err = db.UpsertConnectionLog(ctx, disconnectParams) + require.NoError(t, err) + thirdRows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{}) + require.NoError(t, err) + require.Len(t, secondRows, 1) + // The close reason shouldn't be updated + require.Equal(t, secondRows, thirdRows) + }) +} + type tvArgs struct { Status database.ProvisionerJobStatus // CreateWorkspace is true if we should create a workspace for the template version @@ -5021,3 +5463,102 @@ func requireUsersMatch(t testing.TB, expected []database.User, found []database. t.Helper() require.ElementsMatch(t, expected, database.ConvertUserRows(found), msg) } + +// TestGetRunningPrebuiltWorkspaces ensures the correct behavior of the +// GetRunningPrebuiltWorkspaces query. +func TestGetRunningPrebuiltWorkspaces(t *testing.T) { + t.Parallel() + + if !dbtestutil.WillUsePostgres() { + t.Skip("Test requires PostgreSQL for complex queries") + } + + ctx := testutil.Context(t, testutil.WaitLong) + db, _ := dbtestutil.NewDB(t) + now := dbtime.Now() + + // Given: a prebuilt workspace with a successful start build and a stop build. + org := dbgen.Organization(t, db, database.Organization{}) + user := dbgen.User(t, db, database.User{}) + template := dbgen.Template(t, db, database.Template{ + CreatedBy: user.ID, + OrganizationID: org.ID, + }) + templateVersion := dbgen.TemplateVersion(t, db, database.TemplateVersion{ + TemplateID: uuid.NullUUID{UUID: template.ID, Valid: true}, + OrganizationID: org.ID, + CreatedBy: user.ID, + }) + preset := dbgen.Preset(t, db, database.InsertPresetParams{ + TemplateVersionID: templateVersion.ID, + DesiredInstances: sql.NullInt32{Int32: 1, Valid: true}, + }) + + setupFixture := func(t *testing.T, db database.Store, name string, deleted bool, transition database.WorkspaceTransition, jobStatus database.ProvisionerJobStatus) database.WorkspaceTable { + t.Helper() + ws := dbgen.Workspace(t, db, database.WorkspaceTable{ + OwnerID: database.PrebuildsSystemUserID, + TemplateID: template.ID, + Name: name, + Deleted: deleted, + }) + var canceledAt sql.NullTime + var jobError sql.NullString + switch jobStatus { + case database.ProvisionerJobStatusFailed: + jobError = sql.NullString{String: assert.AnError.Error(), Valid: true} + case database.ProvisionerJobStatusCanceled: + canceledAt = sql.NullTime{Time: now, Valid: true} + } + pj := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{ + OrganizationID: org.ID, + InitiatorID: database.PrebuildsSystemUserID, + Provisioner: database.ProvisionerTypeEcho, + Type: database.ProvisionerJobTypeWorkspaceBuild, + StartedAt: sql.NullTime{Time: now.Add(-time.Minute), Valid: true}, + CanceledAt: canceledAt, + CompletedAt: sql.NullTime{Time: now, Valid: true}, + Error: jobError, + }) + wb := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{ + WorkspaceID: ws.ID, + TemplateVersionID: templateVersion.ID, + TemplateVersionPresetID: uuid.NullUUID{UUID: preset.ID, Valid: true}, + JobID: pj.ID, + BuildNumber: 1, + Transition: transition, + InitiatorID: database.PrebuildsSystemUserID, + Reason: database.BuildReasonInitiator, + }) + // Ensure things are set up as expectd + require.Equal(t, transition, wb.Transition) + require.Equal(t, int32(1), wb.BuildNumber) + require.Equal(t, jobStatus, pj.JobStatus) + require.Equal(t, deleted, ws.Deleted) + + return ws + } + + // Given: a number of prebuild workspaces with different states exist. + runningPrebuild := setupFixture(t, db, "running-prebuild", false, database.WorkspaceTransitionStart, database.ProvisionerJobStatusSucceeded) + _ = setupFixture(t, db, "stopped-prebuild", false, database.WorkspaceTransitionStop, database.ProvisionerJobStatusSucceeded) + _ = setupFixture(t, db, "failed-prebuild", false, database.WorkspaceTransitionStart, database.ProvisionerJobStatusFailed) + _ = setupFixture(t, db, "canceled-prebuild", false, database.WorkspaceTransitionStart, database.ProvisionerJobStatusCanceled) + _ = setupFixture(t, db, "deleted-prebuild", true, database.WorkspaceTransitionStart, database.ProvisionerJobStatusSucceeded) + + // Given: a regular workspace also exists. + _ = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ + OwnerID: user.ID, + TemplateID: template.ID, + Name: "test-running-regular-workspace", + Deleted: false, + }) + + // When: we query for running prebuild workspaces + runningPrebuilds, err := db.GetRunningPrebuiltWorkspaces(ctx) + require.NoError(t, err) + + // Then: only the running prebuild workspace should be returned. + require.Len(t, runningPrebuilds, 1, "expected only one running prebuilt workspace") + require.Equal(t, runningPrebuild.ID, runningPrebuilds[0].ID, "expected the running prebuilt workspace to be returned") +} diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 15f4be06a3fa0..23f7cf3bfbca0 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -880,6 +880,246 @@ func (q *sqlQuerier) InsertAuditLog(ctx context.Context, arg InsertAuditLogParam return i, err } +const getConnectionLogsOffset = `-- name: GetConnectionLogsOffset :many +SELECT + connection_logs.id, connection_logs.connect_time, connection_logs.organization_id, connection_logs.workspace_owner_id, connection_logs.workspace_id, connection_logs.workspace_name, connection_logs.agent_name, connection_logs.type, connection_logs.ip, connection_logs.code, connection_logs.user_agent, connection_logs.user_id, connection_logs.slug_or_port, connection_logs.connection_id, connection_logs.disconnect_time, connection_logs.disconnect_reason, + -- sqlc.embed(users) would be nice but it does not seem to play well with + -- left joins. This user metadata is necessary for parity with the audit logs + -- API. + users.username AS user_username, + users.name AS user_name, + users.email AS user_email, + users.created_at AS user_created_at, + users.updated_at AS user_updated_at, + users.last_seen_at AS user_last_seen_at, + users.status AS user_status, + users.login_type AS user_login_type, + users.rbac_roles AS user_roles, + users.avatar_url AS user_avatar_url, + users.deleted AS user_deleted, + users.quiet_hours_schedule AS user_quiet_hours_schedule, + workspace_owner.username AS workspace_owner_username, + organizations.name AS organization_name, + organizations.display_name AS organization_display_name, + organizations.icon AS organization_icon +FROM + connection_logs +JOIN users AS workspace_owner ON + connection_logs.workspace_owner_id = workspace_owner.id +LEFT JOIN users ON + connection_logs.user_id = users.id +JOIN organizations ON + connection_logs.organization_id = organizations.id +WHERE TRUE + -- Authorize Filter clause will be injected below in + -- GetAuthorizedConnectionLogsOffset + -- @authorize_filter +ORDER BY + connect_time DESC +LIMIT + -- a limit of 0 means "no limit". The connection log table is unbounded + -- in size, and is expected to be quite large. Implement a default + -- limit of 100 to prevent accidental excessively large queries. + COALESCE(NULLIF($2 :: int, 0), 100) +OFFSET + $1 +` + +type GetConnectionLogsOffsetParams struct { + OffsetOpt int32 `db:"offset_opt" json:"offset_opt"` + LimitOpt int32 `db:"limit_opt" json:"limit_opt"` +} + +type GetConnectionLogsOffsetRow struct { + ConnectionLog ConnectionLog `db:"connection_log" json:"connection_log"` + UserUsername sql.NullString `db:"user_username" json:"user_username"` + UserName sql.NullString `db:"user_name" json:"user_name"` + UserEmail sql.NullString `db:"user_email" json:"user_email"` + UserCreatedAt sql.NullTime `db:"user_created_at" json:"user_created_at"` + UserUpdatedAt sql.NullTime `db:"user_updated_at" json:"user_updated_at"` + UserLastSeenAt sql.NullTime `db:"user_last_seen_at" json:"user_last_seen_at"` + UserStatus NullUserStatus `db:"user_status" json:"user_status"` + UserLoginType NullLoginType `db:"user_login_type" json:"user_login_type"` + UserRoles pq.StringArray `db:"user_roles" json:"user_roles"` + UserAvatarUrl sql.NullString `db:"user_avatar_url" json:"user_avatar_url"` + UserDeleted sql.NullBool `db:"user_deleted" json:"user_deleted"` + UserQuietHoursSchedule sql.NullString `db:"user_quiet_hours_schedule" json:"user_quiet_hours_schedule"` + WorkspaceOwnerUsername string `db:"workspace_owner_username" json:"workspace_owner_username"` + OrganizationName string `db:"organization_name" json:"organization_name"` + OrganizationDisplayName string `db:"organization_display_name" json:"organization_display_name"` + OrganizationIcon string `db:"organization_icon" json:"organization_icon"` +} + +func (q *sqlQuerier) GetConnectionLogsOffset(ctx context.Context, arg GetConnectionLogsOffsetParams) ([]GetConnectionLogsOffsetRow, error) { + rows, err := q.db.QueryContext(ctx, getConnectionLogsOffset, arg.OffsetOpt, arg.LimitOpt) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetConnectionLogsOffsetRow + for rows.Next() { + var i GetConnectionLogsOffsetRow + if err := rows.Scan( + &i.ConnectionLog.ID, + &i.ConnectionLog.ConnectTime, + &i.ConnectionLog.OrganizationID, + &i.ConnectionLog.WorkspaceOwnerID, + &i.ConnectionLog.WorkspaceID, + &i.ConnectionLog.WorkspaceName, + &i.ConnectionLog.AgentName, + &i.ConnectionLog.Type, + &i.ConnectionLog.Ip, + &i.ConnectionLog.Code, + &i.ConnectionLog.UserAgent, + &i.ConnectionLog.UserID, + &i.ConnectionLog.SlugOrPort, + &i.ConnectionLog.ConnectionID, + &i.ConnectionLog.DisconnectTime, + &i.ConnectionLog.DisconnectReason, + &i.UserUsername, + &i.UserName, + &i.UserEmail, + &i.UserCreatedAt, + &i.UserUpdatedAt, + &i.UserLastSeenAt, + &i.UserStatus, + &i.UserLoginType, + &i.UserRoles, + &i.UserAvatarUrl, + &i.UserDeleted, + &i.UserQuietHoursSchedule, + &i.WorkspaceOwnerUsername, + &i.OrganizationName, + &i.OrganizationDisplayName, + &i.OrganizationIcon, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const upsertConnectionLog = `-- name: UpsertConnectionLog :one +INSERT INTO connection_logs ( + id, + connect_time, + organization_id, + workspace_owner_id, + workspace_id, + workspace_name, + agent_name, + type, + code, + ip, + user_agent, + user_id, + slug_or_port, + connection_id, + disconnect_reason, + disconnect_time +) VALUES + ($1, $15, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, + -- If we've only received a disconnect event, mark the event as immediately + -- closed. + CASE + WHEN $16::connection_status = 'disconnected' + THEN $15 :: timestamp with time zone + ELSE NULL + END) +ON CONFLICT (connection_id, workspace_id, agent_name) +DO UPDATE SET + -- No-op if the connection is still open. + disconnect_time = CASE + WHEN $16::connection_status = 'disconnected' + -- Can only be set once + AND connection_logs.disconnect_time IS NULL + THEN EXCLUDED.connect_time + ELSE connection_logs.disconnect_time + END, + disconnect_reason = CASE + WHEN $16::connection_status = 'disconnected' + -- Can only be set once + AND connection_logs.disconnect_reason IS NULL + THEN EXCLUDED.disconnect_reason + ELSE connection_logs.disconnect_reason + END, + code = CASE + WHEN $16::connection_status = 'disconnected' + -- Can only be set once + AND connection_logs.code IS NULL + THEN EXCLUDED.code + ELSE connection_logs.code + END +RETURNING id, connect_time, organization_id, workspace_owner_id, workspace_id, workspace_name, agent_name, type, ip, code, user_agent, user_id, slug_or_port, connection_id, disconnect_time, disconnect_reason +` + +type UpsertConnectionLogParams struct { + ID uuid.UUID `db:"id" json:"id"` + OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"` + WorkspaceOwnerID uuid.UUID `db:"workspace_owner_id" json:"workspace_owner_id"` + WorkspaceID uuid.UUID `db:"workspace_id" json:"workspace_id"` + WorkspaceName string `db:"workspace_name" json:"workspace_name"` + AgentName string `db:"agent_name" json:"agent_name"` + Type ConnectionType `db:"type" json:"type"` + Code sql.NullInt32 `db:"code" json:"code"` + Ip pqtype.Inet `db:"ip" json:"ip"` + UserAgent sql.NullString `db:"user_agent" json:"user_agent"` + UserID uuid.NullUUID `db:"user_id" json:"user_id"` + SlugOrPort sql.NullString `db:"slug_or_port" json:"slug_or_port"` + ConnectionID uuid.NullUUID `db:"connection_id" json:"connection_id"` + DisconnectReason sql.NullString `db:"disconnect_reason" json:"disconnect_reason"` + Time time.Time `db:"time" json:"time"` + ConnectionStatus ConnectionStatus `db:"connection_status" json:"connection_status"` +} + +func (q *sqlQuerier) UpsertConnectionLog(ctx context.Context, arg UpsertConnectionLogParams) (ConnectionLog, error) { + row := q.db.QueryRowContext(ctx, upsertConnectionLog, + arg.ID, + arg.OrganizationID, + arg.WorkspaceOwnerID, + arg.WorkspaceID, + arg.WorkspaceName, + arg.AgentName, + arg.Type, + arg.Code, + arg.Ip, + arg.UserAgent, + arg.UserID, + arg.SlugOrPort, + arg.ConnectionID, + arg.DisconnectReason, + arg.Time, + arg.ConnectionStatus, + ) + var i ConnectionLog + err := row.Scan( + &i.ID, + &i.ConnectTime, + &i.OrganizationID, + &i.WorkspaceOwnerID, + &i.WorkspaceID, + &i.WorkspaceName, + &i.AgentName, + &i.Type, + &i.Ip, + &i.Code, + &i.UserAgent, + &i.UserID, + &i.SlugOrPort, + &i.ConnectionID, + &i.DisconnectTime, + &i.DisconnectReason, + ) + return i, err +} + const deleteCryptoKey = `-- name: DeleteCryptoKey :one UPDATE crypto_keys SET secret = NULL, secret_key_id = NULL @@ -6843,6 +7083,7 @@ FROM workspace_prebuilds p INNER JOIN workspace_latest_builds b ON b.workspace_id = p.id WHERE (b.transition = 'start'::workspace_transition AND b.job_status = 'succeeded'::provisioner_job_status) +ORDER BY p.id ` type GetRunningPrebuiltWorkspacesRow struct { @@ -6886,6 +7127,106 @@ func (q *sqlQuerier) GetRunningPrebuiltWorkspaces(ctx context.Context) ([]GetRun return items, nil } +const getRunningPrebuiltWorkspacesOptimized = `-- name: GetRunningPrebuiltWorkspacesOptimized :many +WITH latest_prebuilds AS ( + -- All workspaces that match the following criteria: + -- 1. Owned by prebuilds user + -- 2. Not deleted + -- 3. Latest build is a 'start' transition + -- 4. Latest build was successful + SELECT + workspaces.id, + workspaces.name, + workspaces.template_id, + workspace_latest_builds.template_version_id, + workspace_latest_builds.job_id, + workspaces.created_at + FROM workspace_latest_builds + JOIN workspaces ON workspaces.id = workspace_latest_builds.workspace_id + WHERE workspace_latest_builds.transition = 'start'::workspace_transition + AND workspace_latest_builds.job_status = 'succeeded'::provisioner_job_status + AND workspaces.owner_id = 'c42fdf75-3097-471c-8c33-fb52454d81c0'::UUID + AND NOT workspaces.deleted +), +workspace_latest_presets AS ( + -- For each of the above workspaces, the preset_id of the most recent + -- successful start transition. + SELECT DISTINCT ON (latest_prebuilds.id) + latest_prebuilds.id AS workspace_id, + workspace_builds.template_version_preset_id AS current_preset_id + FROM latest_prebuilds + JOIN workspace_builds ON workspace_builds.workspace_id = latest_prebuilds.id + WHERE workspace_builds.transition = 'start'::workspace_transition + AND workspace_builds.template_version_preset_id IS NOT NULL + ORDER BY latest_prebuilds.id, workspace_builds.build_number DESC +), +ready_agents AS ( + -- For each of the above workspaces, check if all agents are ready. + SELECT + latest_prebuilds.job_id, + BOOL_AND(workspace_agents.lifecycle_state = 'ready'::workspace_agent_lifecycle_state)::boolean AS ready + FROM latest_prebuilds + JOIN workspace_resources ON workspace_resources.job_id = latest_prebuilds.job_id + JOIN workspace_agents ON workspace_agents.resource_id = workspace_resources.id + WHERE workspace_agents.deleted = false + AND workspace_agents.parent_id IS NULL + GROUP BY latest_prebuilds.job_id +) +SELECT + latest_prebuilds.id, + latest_prebuilds.name, + latest_prebuilds.template_id, + latest_prebuilds.template_version_id, + workspace_latest_presets.current_preset_id, + COALESCE(ready_agents.ready, false)::boolean AS ready, + latest_prebuilds.created_at +FROM latest_prebuilds +LEFT JOIN ready_agents ON ready_agents.job_id = latest_prebuilds.job_id +LEFT JOIN workspace_latest_presets ON workspace_latest_presets.workspace_id = latest_prebuilds.id +ORDER BY latest_prebuilds.id +` + +type GetRunningPrebuiltWorkspacesOptimizedRow struct { + ID uuid.UUID `db:"id" json:"id"` + Name string `db:"name" json:"name"` + TemplateID uuid.UUID `db:"template_id" json:"template_id"` + TemplateVersionID uuid.UUID `db:"template_version_id" json:"template_version_id"` + CurrentPresetID uuid.NullUUID `db:"current_preset_id" json:"current_preset_id"` + Ready bool `db:"ready" json:"ready"` + CreatedAt time.Time `db:"created_at" json:"created_at"` +} + +func (q *sqlQuerier) GetRunningPrebuiltWorkspacesOptimized(ctx context.Context) ([]GetRunningPrebuiltWorkspacesOptimizedRow, error) { + rows, err := q.db.QueryContext(ctx, getRunningPrebuiltWorkspacesOptimized) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetRunningPrebuiltWorkspacesOptimizedRow + for rows.Next() { + var i GetRunningPrebuiltWorkspacesOptimizedRow + if err := rows.Scan( + &i.ID, + &i.Name, + &i.TemplateID, + &i.TemplateVersionID, + &i.CurrentPresetID, + &i.Ready, + &i.CreatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getTemplatePresetsWithPrebuilds = `-- name: GetTemplatePresetsWithPrebuilds :many SELECT t.id AS template_id, @@ -19971,7 +20312,12 @@ WHERE provisioner_jobs.completed_at IS NOT NULL AND ($1 :: timestamptz) - provisioner_jobs.completed_at > (INTERVAL '1 millisecond' * (templates.failure_ttl / 1000000)) ) - ) AND workspaces.deleted = 'false' + ) + AND workspaces.deleted = 'false' + -- Prebuilt workspaces (identified by having the prebuilds system user as owner_id) + -- should not be considered by the lifecycle executor, as they are handled by the + -- prebuilds reconciliation loop. + AND workspaces.owner_id != 'c42fdf75-3097-471c-8c33-fb52454d81c0'::UUID ` type GetWorkspacesEligibleForTransitionRow struct { diff --git a/coderd/database/queries/connectionlogs.sql b/coderd/database/queries/connectionlogs.sql new file mode 100644 index 0000000000000..172a7c533d7d5 --- /dev/null +++ b/coderd/database/queries/connectionlogs.sql @@ -0,0 +1,97 @@ +-- name: GetConnectionLogsOffset :many +SELECT + sqlc.embed(connection_logs), + -- sqlc.embed(users) would be nice but it does not seem to play well with + -- left joins. This user metadata is necessary for parity with the audit logs + -- API. + users.username AS user_username, + users.name AS user_name, + users.email AS user_email, + users.created_at AS user_created_at, + users.updated_at AS user_updated_at, + users.last_seen_at AS user_last_seen_at, + users.status AS user_status, + users.login_type AS user_login_type, + users.rbac_roles AS user_roles, + users.avatar_url AS user_avatar_url, + users.deleted AS user_deleted, + users.quiet_hours_schedule AS user_quiet_hours_schedule, + workspace_owner.username AS workspace_owner_username, + organizations.name AS organization_name, + organizations.display_name AS organization_display_name, + organizations.icon AS organization_icon +FROM + connection_logs +JOIN users AS workspace_owner ON + connection_logs.workspace_owner_id = workspace_owner.id +LEFT JOIN users ON + connection_logs.user_id = users.id +JOIN organizations ON + connection_logs.organization_id = organizations.id +WHERE TRUE + -- Authorize Filter clause will be injected below in + -- GetAuthorizedConnectionLogsOffset + -- @authorize_filter +ORDER BY + connect_time DESC +LIMIT + -- a limit of 0 means "no limit". The connection log table is unbounded + -- in size, and is expected to be quite large. Implement a default + -- limit of 100 to prevent accidental excessively large queries. + COALESCE(NULLIF(@limit_opt :: int, 0), 100) +OFFSET + @offset_opt; + + +-- name: UpsertConnectionLog :one +INSERT INTO connection_logs ( + id, + connect_time, + organization_id, + workspace_owner_id, + workspace_id, + workspace_name, + agent_name, + type, + code, + ip, + user_agent, + user_id, + slug_or_port, + connection_id, + disconnect_reason, + disconnect_time +) VALUES + ($1, @time, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, + -- If we've only received a disconnect event, mark the event as immediately + -- closed. + CASE + WHEN @connection_status::connection_status = 'disconnected' + THEN @time :: timestamp with time zone + ELSE NULL + END) +ON CONFLICT (connection_id, workspace_id, agent_name) +DO UPDATE SET + -- No-op if the connection is still open. + disconnect_time = CASE + WHEN @connection_status::connection_status = 'disconnected' + -- Can only be set once + AND connection_logs.disconnect_time IS NULL + THEN EXCLUDED.connect_time + ELSE connection_logs.disconnect_time + END, + disconnect_reason = CASE + WHEN @connection_status::connection_status = 'disconnected' + -- Can only be set once + AND connection_logs.disconnect_reason IS NULL + THEN EXCLUDED.disconnect_reason + ELSE connection_logs.disconnect_reason + END, + code = CASE + WHEN @connection_status::connection_status = 'disconnected' + -- Can only be set once + AND connection_logs.code IS NULL + THEN EXCLUDED.code + ELSE connection_logs.code + END +RETURNING *; diff --git a/coderd/database/queries/prebuilds.sql b/coderd/database/queries/prebuilds.sql index 2fc9f3f4a67f6..7e1dbc71f4a26 100644 --- a/coderd/database/queries/prebuilds.sql +++ b/coderd/database/queries/prebuilds.sql @@ -48,6 +48,64 @@ WHERE tvp.desired_instances IS NOT NULL -- Consider only presets that have a pre -- AND NOT t.deleted -- We don't exclude deleted templates because there's no constraint in the DB preventing a soft deletion on a template while workspaces are running. AND (t.id = sqlc.narg('template_id')::uuid OR sqlc.narg('template_id') IS NULL); +-- name: GetRunningPrebuiltWorkspacesOptimized :many +WITH latest_prebuilds AS ( + -- All workspaces that match the following criteria: + -- 1. Owned by prebuilds user + -- 2. Not deleted + -- 3. Latest build is a 'start' transition + -- 4. Latest build was successful + SELECT + workspaces.id, + workspaces.name, + workspaces.template_id, + workspace_latest_builds.template_version_id, + workspace_latest_builds.job_id, + workspaces.created_at + FROM workspace_latest_builds + JOIN workspaces ON workspaces.id = workspace_latest_builds.workspace_id + WHERE workspace_latest_builds.transition = 'start'::workspace_transition + AND workspace_latest_builds.job_status = 'succeeded'::provisioner_job_status + AND workspaces.owner_id = 'c42fdf75-3097-471c-8c33-fb52454d81c0'::UUID + AND NOT workspaces.deleted +), +workspace_latest_presets AS ( + -- For each of the above workspaces, the preset_id of the most recent + -- successful start transition. + SELECT DISTINCT ON (latest_prebuilds.id) + latest_prebuilds.id AS workspace_id, + workspace_builds.template_version_preset_id AS current_preset_id + FROM latest_prebuilds + JOIN workspace_builds ON workspace_builds.workspace_id = latest_prebuilds.id + WHERE workspace_builds.transition = 'start'::workspace_transition + AND workspace_builds.template_version_preset_id IS NOT NULL + ORDER BY latest_prebuilds.id, workspace_builds.build_number DESC +), +ready_agents AS ( + -- For each of the above workspaces, check if all agents are ready. + SELECT + latest_prebuilds.job_id, + BOOL_AND(workspace_agents.lifecycle_state = 'ready'::workspace_agent_lifecycle_state)::boolean AS ready + FROM latest_prebuilds + JOIN workspace_resources ON workspace_resources.job_id = latest_prebuilds.job_id + JOIN workspace_agents ON workspace_agents.resource_id = workspace_resources.id + WHERE workspace_agents.deleted = false + AND workspace_agents.parent_id IS NULL + GROUP BY latest_prebuilds.job_id +) +SELECT + latest_prebuilds.id, + latest_prebuilds.name, + latest_prebuilds.template_id, + latest_prebuilds.template_version_id, + workspace_latest_presets.current_preset_id, + COALESCE(ready_agents.ready, false)::boolean AS ready, + latest_prebuilds.created_at +FROM latest_prebuilds +LEFT JOIN ready_agents ON ready_agents.job_id = latest_prebuilds.job_id +LEFT JOIN workspace_latest_presets ON workspace_latest_presets.workspace_id = latest_prebuilds.id +ORDER BY latest_prebuilds.id; + -- name: GetRunningPrebuiltWorkspaces :many SELECT p.id, @@ -60,7 +118,8 @@ SELECT FROM workspace_prebuilds p INNER JOIN workspace_latest_builds b ON b.workspace_id = p.id WHERE (b.transition = 'start'::workspace_transition - AND b.job_status = 'succeeded'::provisioner_job_status); + AND b.job_status = 'succeeded'::provisioner_job_status) +ORDER BY p.id; -- name: CountInProgressPrebuilds :many -- CountInProgressPrebuilds returns the number of in-progress prebuilds, grouped by preset ID and transition. diff --git a/coderd/database/queries/workspaces.sql b/coderd/database/queries/workspaces.sql index 25e4d4f97a46b..f166d16f742cd 100644 --- a/coderd/database/queries/workspaces.sql +++ b/coderd/database/queries/workspaces.sql @@ -758,7 +758,12 @@ WHERE provisioner_jobs.completed_at IS NOT NULL AND (@now :: timestamptz) - provisioner_jobs.completed_at > (INTERVAL '1 millisecond' * (templates.failure_ttl / 1000000)) ) - ) AND workspaces.deleted = 'false'; + ) + AND workspaces.deleted = 'false' + -- Prebuilt workspaces (identified by having the prebuilds system user as owner_id) + -- should not be considered by the lifecycle executor, as they are handled by the + -- prebuilds reconciliation loop. + AND workspaces.owner_id != 'c42fdf75-3097-471c-8c33-fb52454d81c0'::UUID; -- name: UpdateWorkspaceDormantDeletingAt :one UPDATE diff --git a/coderd/database/types.go b/coderd/database/types.go index a4a723d02b466..6d0f036fe692c 100644 --- a/coderd/database/types.go +++ b/coderd/database/types.go @@ -4,10 +4,12 @@ import ( "database/sql/driver" "encoding/json" "fmt" + "net" "strings" "time" "github.com/google/uuid" + "github.com/sqlc-dev/pqtype" "golang.org/x/xerrors" "github.com/coder/coder/v2/coderd/rbac/policy" @@ -237,3 +239,19 @@ func (a *UserLinkClaims) Scan(src interface{}) error { func (a UserLinkClaims) Value() (driver.Value, error) { return json.Marshal(a) } + +func ParseIP(ipStr string) pqtype.Inet { + ip := net.ParseIP(ipStr) + ipNet := net.IPNet{} + if ip != nil { + ipNet = net.IPNet{ + IP: ip, + Mask: net.CIDRMask(len(ip)*8, len(ip)*8), + } + } + + return pqtype.Inet{ + IPNet: ipNet, + Valid: ip != nil, + } +} diff --git a/coderd/database/unique_constraint.go b/coderd/database/unique_constraint.go index b3af136997c9c..38c95e67410c9 100644 --- a/coderd/database/unique_constraint.go +++ b/coderd/database/unique_constraint.go @@ -9,6 +9,7 @@ const ( UniqueAgentStatsPkey UniqueConstraint = "agent_stats_pkey" // ALTER TABLE ONLY workspace_agent_stats ADD CONSTRAINT agent_stats_pkey PRIMARY KEY (id); UniqueAPIKeysPkey UniqueConstraint = "api_keys_pkey" // ALTER TABLE ONLY api_keys ADD CONSTRAINT api_keys_pkey PRIMARY KEY (id); UniqueAuditLogsPkey UniqueConstraint = "audit_logs_pkey" // ALTER TABLE ONLY audit_logs ADD CONSTRAINT audit_logs_pkey PRIMARY KEY (id); + UniqueConnectionLogsPkey UniqueConstraint = "connection_logs_pkey" // ALTER TABLE ONLY connection_logs ADD CONSTRAINT connection_logs_pkey PRIMARY KEY (id); UniqueCryptoKeysPkey UniqueConstraint = "crypto_keys_pkey" // ALTER TABLE ONLY crypto_keys ADD CONSTRAINT crypto_keys_pkey PRIMARY KEY (feature, sequence); UniqueCustomRolesUniqueKey UniqueConstraint = "custom_roles_unique_key" // ALTER TABLE ONLY custom_roles ADD CONSTRAINT custom_roles_unique_key UNIQUE (name, organization_id); UniqueDbcryptKeysActiveKeyDigestKey UniqueConstraint = "dbcrypt_keys_active_key_digest_key" // ALTER TABLE ONLY dbcrypt_keys ADD CONSTRAINT dbcrypt_keys_active_key_digest_key UNIQUE (active_key_digest); @@ -100,6 +101,7 @@ const ( UniqueWorkspaceResourcesPkey UniqueConstraint = "workspace_resources_pkey" // ALTER TABLE ONLY workspace_resources ADD CONSTRAINT workspace_resources_pkey PRIMARY KEY (id); UniqueWorkspacesPkey UniqueConstraint = "workspaces_pkey" // ALTER TABLE ONLY workspaces ADD CONSTRAINT workspaces_pkey PRIMARY KEY (id); UniqueIndexAPIKeyName UniqueConstraint = "idx_api_key_name" // CREATE UNIQUE INDEX idx_api_key_name ON api_keys USING btree (user_id, token_name) WHERE (login_type = 'token'::login_type); + UniqueIndexConnectionLogsConnectionIDWorkspaceIDAgentName UniqueConstraint = "idx_connection_logs_connection_id_workspace_id_agent_name" // CREATE UNIQUE INDEX idx_connection_logs_connection_id_workspace_id_agent_name ON connection_logs USING btree (connection_id, workspace_id, agent_name); UniqueIndexCustomRolesNameLower UniqueConstraint = "idx_custom_roles_name_lower" // CREATE UNIQUE INDEX idx_custom_roles_name_lower ON custom_roles USING btree (lower(name)); UniqueIndexOrganizationNameLower UniqueConstraint = "idx_organization_name_lower" // CREATE UNIQUE INDEX idx_organization_name_lower ON organizations USING btree (lower(name)) WHERE (deleted = false); UniqueIndexProvisionerDaemonsOrgNameOwnerKey UniqueConstraint = "idx_provisioner_daemons_org_name_owner_key" // CREATE UNIQUE INDEX idx_provisioner_daemons_org_name_owner_key ON provisioner_daemons USING btree (organization_id, name, lower(COALESCE((tags ->> 'owner'::text), ''::text))); diff --git a/coderd/dynamicparameters/resolver.go b/coderd/dynamicparameters/resolver.go index bd8e2294cf136..7fc67d29a0d55 100644 --- a/coderd/dynamicparameters/resolver.go +++ b/coderd/dynamicparameters/resolver.go @@ -55,19 +55,21 @@ func ResolveParameters( values[preset.Name] = parameterValue{Source: sourcePreset, Value: preset.Value} } - // originalValues is going to be used to detect if a user tried to change + // originalInputValues is going to be used to detect if a user tried to change // an immutable parameter after the first build. - originalValues := make(map[string]parameterValue, len(values)) + // The actual input values are mutated based on attributes like mutability + // and ephemerality. + originalInputValues := make(map[string]parameterValue, len(values)) for name, value := range values { // Store the original values for later use. - originalValues[name] = value + originalInputValues[name] = value } // Render the parameters using the values that were supplied to the previous build. // // This is how the form should look to the user on their workspace settings page. // This is the original form truth that our validations should initially be based on. - output, diags := renderer.Render(ctx, ownerID, values.ValuesMap()) + output, diags := renderer.Render(ctx, ownerID, previousValuesMap) if diags.HasErrors() { // Top level diagnostics should break the build. Previous values (and new) should // always be valid. If there is a case where this is not true, then this has to @@ -91,22 +93,6 @@ func ResolveParameters( delete(values, parameter.Name) } } - - // Immutable parameters should also not be allowed to be changed from - // the previous build. Remove any values taken from the preset or - // new build params. This forces the value to be the same as it was before. - // - // We do this so the next form render uses the original immutable value. - if !firstBuild && !parameter.Mutable { - delete(values, parameter.Name) - prev, ok := previousValuesMap[parameter.Name] - if ok { - values[parameter.Name] = parameterValue{ - Value: prev, - Source: sourcePrevious, - } - } - } } // This is the final set of values that will be used. Any errors at this stage @@ -116,7 +102,7 @@ func ResolveParameters( return nil, parameterValidationError(diags) } - // parameterNames is going to be used to remove any excess values that were left + // parameterNames is going to be used to remove any excess values left // around without a parameter. parameterNames := make(map[string]struct{}, len(output.Parameters)) parameterError := parameterValidationError(nil) @@ -124,15 +110,20 @@ func ResolveParameters( parameterNames[parameter.Name] = struct{}{} if !firstBuild && !parameter.Mutable { - originalValue, ok := originalValues[parameter.Name] + // previousValuesMap should be used over the first render output + // for the previous state of parameters. The previous build + // should emit all values, so the previousValuesMap should be + // complete with all parameter values (user specified and defaults) + originalValue, ok := previousValuesMap[parameter.Name] + // Immutable parameters should not be changed after the first build. - // If the value matches the original value, that is fine. + // If the value matches the previous input value, that is fine. // - // If the original value is not set, that means this is a new parameter. New + // If the previous value is not set, that means this is a new parameter. New // immutable parameters are allowed. This is an opinionated choice to prevent // workspaces failing to update or delete. Ideally we would block this, as // immutable parameters should only be able to be set at creation time. - if ok && parameter.Value.AsString() != originalValue.Value { + if ok && parameter.Value.AsString() != originalValue { var src *hcl.Range if parameter.Source != nil { src = ¶meter.Source.HCLBlock().TypeRange diff --git a/coderd/dynamicparameters/resolver_test.go b/coderd/dynamicparameters/resolver_test.go index ec5218613ff03..e6675e6f4c7dc 100644 --- a/coderd/dynamicparameters/resolver_test.go +++ b/coderd/dynamicparameters/resolver_test.go @@ -10,6 +10,7 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/dynamicparameters" "github.com/coder/coder/v2/coderd/dynamicparameters/rendermock" + "github.com/coder/coder/v2/coderd/httpapi/httperror" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/testutil" "github.com/coder/preview" @@ -56,4 +57,69 @@ func TestResolveParameters(t *testing.T) { require.NoError(t, err) require.Equal(t, map[string]string{"immutable": "foo"}, values) }) + + // Tests a parameter going from mutable -> immutable + t.Run("BecameImmutable", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + render := rendermock.NewMockRenderer(ctrl) + + mutable := previewtypes.ParameterData{ + Name: "immutable", + Type: previewtypes.ParameterTypeString, + FormType: provider.ParameterFormTypeInput, + Mutable: true, + DefaultValue: previewtypes.StringLiteral("foo"), + Required: true, + } + immutable := mutable + immutable.Mutable = false + + // A single immutable parameter with no previous value. + render.EXPECT(). + Render(gomock.Any(), gomock.Any(), gomock.Any()). + // Return the mutable param first + Return(&preview.Output{ + Parameters: []previewtypes.Parameter{ + { + ParameterData: mutable, + Value: previewtypes.StringLiteral("foo"), + Diagnostics: nil, + }, + }, + }, nil) + + render.EXPECT(). + Render(gomock.Any(), gomock.Any(), gomock.Any()). + // Then the immutable param + Return(&preview.Output{ + Parameters: []previewtypes.Parameter{ + { + ParameterData: immutable, + // The user set the value to bar + Value: previewtypes.StringLiteral("bar"), + Diagnostics: nil, + }, + }, + }, nil) + + ctx := testutil.Context(t, testutil.WaitShort) + _, err := dynamicparameters.ResolveParameters(ctx, uuid.New(), render, false, + []database.WorkspaceBuildParameter{ + {Name: "immutable", Value: "foo"}, // Previous value foo + }, + []codersdk.WorkspaceBuildParameter{ + {Name: "immutable", Value: "bar"}, // New value + }, + []database.TemplateVersionPresetParameter{}, // No preset values + ) + require.Error(t, err) + resp, ok := httperror.IsResponder(err) + require.True(t, ok) + + _, respErr := resp.Response() + require.Len(t, respErr.Validations, 1) + require.Contains(t, respErr.Validations[0].Error(), "is not mutable") + }) } diff --git a/coderd/files/cache.go b/coderd/files/cache.go index 159f1b8aee053..d9e54a66e1c91 100644 --- a/coderd/files/cache.go +++ b/coderd/files/cache.go @@ -303,10 +303,21 @@ func fetch(store database.Store, fileID uuid.UUID) (CacheEntryValue, error) { return CacheEntryValue{}, xerrors.Errorf("failed to read file from database: %w", err) } - content := bytes.NewBuffer(file.Data) + var files fs.FS + switch file.Mimetype { + case "application/zip", "application/x-zip-compressed": + files, err = archivefs.FromZipReader(bytes.NewReader(file.Data), int64(len(file.Data))) + if err != nil { + return CacheEntryValue{}, xerrors.Errorf("failed to read zip file: %w", err) + } + default: + // Assume '"application/x-tar"' as the default mimetype. + files = archivefs.FromTarReader(bytes.NewBuffer(file.Data)) + } + return CacheEntryValue{ Object: file.RBACObject(), - FS: archivefs.FromTarReader(content), + FS: files, Size: int64(len(file.Data)), }, nil } diff --git a/coderd/notifications/notifications_test.go b/coderd/notifications/notifications_test.go index ec9edee4c8514..e213a62df9996 100644 --- a/coderd/notifications/notifications_test.go +++ b/coderd/notifications/notifications_test.go @@ -769,7 +769,7 @@ func TestNotificationTemplates_Golden(t *testing.T) { hello = "localhost" from = "system@coder.com" - hint = "run \"DB=ci make gen/golden-files\" and commit the changes" + hint = "run \"make gen/golden-files\" and commit the changes" ) tests := []struct { diff --git a/coderd/rbac/authz.go b/coderd/rbac/authz.go index f57ed2585c068..fcb6621a34cee 100644 --- a/coderd/rbac/authz.go +++ b/coderd/rbac/authz.go @@ -65,6 +65,7 @@ const ( SubjectTypeUser SubjectType = "user" SubjectTypeProvisionerd SubjectType = "provisionerd" SubjectTypeAutostart SubjectType = "autostart" + SubjectTypeConnectionLogger SubjectType = "connection_logger" SubjectTypeJobReaper SubjectType = "job_reaper" SubjectTypeResourceMonitor SubjectType = "resource_monitor" SubjectTypeCryptoKeyRotator SubjectType = "crypto_key_rotator" diff --git a/coderd/rbac/object_gen.go b/coderd/rbac/object_gen.go index d0d5dc4aab0fe..5fb3cc2bd8a3b 100644 --- a/coderd/rbac/object_gen.go +++ b/coderd/rbac/object_gen.go @@ -54,6 +54,14 @@ var ( Type: "audit_log", } + // ResourceConnectionLog + // Valid Actions + // - "ActionRead" :: read connection logs + // - "ActionUpdate" :: upsert connection log entries + ResourceConnectionLog = Object{ + Type: "connection_log", + } + // ResourceCryptoKey // Valid Actions // - "ActionCreate" :: create crypto keys @@ -368,6 +376,7 @@ func AllResources() []Objecter { ResourceAssignOrgRole, ResourceAssignRole, ResourceAuditLog, + ResourceConnectionLog, ResourceCryptoKey, ResourceDebugInfo, ResourceDeploymentConfig, diff --git a/coderd/rbac/policy/policy.go b/coderd/rbac/policy/policy.go index a3ad614439c9a..a10abfb9605ca 100644 --- a/coderd/rbac/policy/policy.go +++ b/coderd/rbac/policy/policy.go @@ -138,6 +138,12 @@ var RBACPermissions = map[string]PermissionDefinition{ ActionCreate: actDef("create new audit log entries"), }, }, + "connection_log": { + Actions: map[Action]ActionDefinition{ + ActionRead: actDef("read connection logs"), + ActionUpdate: actDef("upsert connection log entries"), + }, + }, "deployment_config": { Actions: map[Action]ActionDefinition{ ActionRead: actDef("read deployment config"), diff --git a/coderd/rbac/regosql/configs.go b/coderd/rbac/regosql/configs.go index 2cb03b238f471..69d425d9dba2f 100644 --- a/coderd/rbac/regosql/configs.go +++ b/coderd/rbac/regosql/configs.go @@ -50,6 +50,20 @@ func AuditLogConverter() *sqltypes.VariableConverter { return matcher } +func ConnectionLogConverter() *sqltypes.VariableConverter { + matcher := sqltypes.NewVariableConverter().RegisterMatcher( + resourceIDMatcher(), + sqltypes.StringVarMatcher("COALESCE(connection_logs.organization_id :: text, '')", []string{"input", "object", "org_owner"}), + // Connection logs have no user owner, only owner by an organization. + sqltypes.AlwaysFalse(userOwnerMatcher()), + ) + matcher.RegisterMatcher( + sqltypes.AlwaysFalse(groupACLMatcher(matcher)), + sqltypes.AlwaysFalse(userACLMatcher(matcher)), + ) + return matcher +} + func UserConverter() *sqltypes.VariableConverter { matcher := sqltypes.NewVariableConverter().RegisterMatcher( resourceIDMatcher(), diff --git a/coderd/rbac/roles.go b/coderd/rbac/roles.go index ebc7ff8f12070..b8d3f959ce477 100644 --- a/coderd/rbac/roles.go +++ b/coderd/rbac/roles.go @@ -315,6 +315,7 @@ func ReloadBuiltinRoles(opts *RoleOptions) { Site: Permissions(map[string][]policy.Action{ ResourceAssignOrgRole.Type: {policy.ActionRead}, ResourceAuditLog.Type: {policy.ActionRead}, + ResourceConnectionLog.Type: {policy.ActionRead}, // Allow auditors to see the resources that audit logs reflect. ResourceTemplate.Type: {policy.ActionRead, policy.ActionViewInsights}, ResourceUser.Type: {policy.ActionRead}, @@ -456,7 +457,8 @@ func ReloadBuiltinRoles(opts *RoleOptions) { Site: []Permission{}, Org: map[string][]Permission{ organizationID.String(): Permissions(map[string][]policy.Action{ - ResourceAuditLog.Type: {policy.ActionRead}, + ResourceAuditLog.Type: {policy.ActionRead}, + ResourceConnectionLog.Type: {policy.ActionRead}, // Allow auditors to see the resources that audit logs reflect. ResourceTemplate.Type: {policy.ActionRead, policy.ActionViewInsights}, ResourceGroup.Type: {policy.ActionRead}, diff --git a/coderd/rbac/roles_test.go b/coderd/rbac/roles_test.go index 3e6f7d1e330d5..267a99993e642 100644 --- a/coderd/rbac/roles_test.go +++ b/coderd/rbac/roles_test.go @@ -849,6 +849,15 @@ func TestRolePermissions(t *testing.T) { }, }, }, + { + Name: "ConnectionLogs", + Actions: []policy.Action{policy.ActionRead, policy.ActionUpdate}, + Resource: rbac.ResourceConnectionLog, + AuthorizeMap: map[bool][]hasAuthSubjects{ + true: {owner}, + false: {setOtherOrg, setOrgNotMe, memberMe, orgMemberMe, templateAdmin, userAdmin}, + }, + }, } // We expect every permission to be tested above. diff --git a/coderd/rbac/rolestore/rolestore_test.go b/coderd/rbac/rolestore/rolestore_test.go index b7712357d0721..47289704d8e49 100644 --- a/coderd/rbac/rolestore/rolestore_test.go +++ b/coderd/rbac/rolestore/rolestore_test.go @@ -8,7 +8,7 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbgen" - "github.com/coder/coder/v2/coderd/database/dbmem" + "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/rbac" "github.com/coder/coder/v2/coderd/rbac/rolestore" "github.com/coder/coder/v2/testutil" @@ -17,7 +17,7 @@ import ( func TestExpandCustomRoleRoles(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) org := dbgen.Organization(t, db, database.Organization{}) diff --git a/coderd/runtimeconfig/entry_test.go b/coderd/runtimeconfig/entry_test.go index 3092dae88c4cd..f8e2a925e29d8 100644 --- a/coderd/runtimeconfig/entry_test.go +++ b/coderd/runtimeconfig/entry_test.go @@ -5,7 +5,7 @@ import ( "github.com/stretchr/testify/require" - "github.com/coder/coder/v2/coderd/database/dbmem" + "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/runtimeconfig" "github.com/coder/coder/v2/testutil" "github.com/coder/serpent" @@ -32,7 +32,7 @@ func TestEntry(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) mgr := runtimeconfig.NewManager() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) override := serpent.String("dogfood@dev.coder.com") @@ -54,7 +54,7 @@ func TestEntry(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) mgr := runtimeconfig.NewManager() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) override := serpent.Struct[map[string]string]{ Value: map[string]string{ diff --git a/coderd/schedule/autostart.go b/coderd/schedule/autostart.go index 0a7f583e4f9b2..538d3dd346fcd 100644 --- a/coderd/schedule/autostart.go +++ b/coderd/schedule/autostart.go @@ -33,6 +33,8 @@ func NextAutostart(at time.Time, wsSchedule string, templateSchedule TemplateSch return zonedTransition, allowed } +// NextAllowedAutostart returns the next valid autostart time after 'at', based on the workspace's +// cron schedule and the template's allowed days. It searches up to 7 days ahead to find a match. func NextAllowedAutostart(at time.Time, wsSchedule string, templateSchedule TemplateScheduleOptions) (time.Time, error) { next := at diff --git a/coderd/searchquery/search_test.go b/coderd/searchquery/search_test.go index 2b7f4f402e008..ad5f2df966ef9 100644 --- a/coderd/searchquery/search_test.go +++ b/coderd/searchquery/search_test.go @@ -14,7 +14,7 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbgen" - "github.com/coder/coder/v2/coderd/database/dbmem" + "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/searchquery" "github.com/coder/coder/v2/codersdk" ) @@ -300,7 +300,7 @@ func TestSearchWorkspace(t *testing.T) { t.Run(c.Name, func(t *testing.T) { t.Parallel() // TODO: Replace this with the mock database. - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) if c.Setup != nil { c.Setup(t, db) } @@ -331,7 +331,8 @@ func TestSearchWorkspace(t *testing.T) { query := `` timeout := 1337 * time.Second - values, errs := searchquery.Workspaces(context.Background(), dbmem.New(), query, codersdk.Pagination{}, timeout) + db, _ := dbtestutil.NewDB(t) + values, errs := searchquery.Workspaces(context.Background(), db, query, codersdk.Pagination{}, timeout) require.Empty(t, errs) require.Equal(t, int64(timeout.Seconds()), values.AgentInactiveDisconnectTimeoutSeconds) }) @@ -389,7 +390,7 @@ func TestSearchAudit(t *testing.T) { t.Parallel() // Do not use a real database, this is only used for an // organization lookup. - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) values, countValues, errs := searchquery.AuditLogs(context.Background(), db, c.Query) if c.ExpectedErrorContains != "" { require.True(t, len(errs) > 0, "expect some errors") @@ -628,7 +629,7 @@ func TestSearchTemplates(t *testing.T) { t.Parallel() // Do not use a real database, this is only used for an // organization lookup. - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) values, errs := searchquery.Templates(context.Background(), db, c.Query) if c.ExpectedErrorContains != "" { require.True(t, len(errs) > 0, "expect some errors") diff --git a/coderd/updatecheck/updatecheck_test.go b/coderd/updatecheck/updatecheck_test.go index 725ceb44d9d6f..2e616a550f231 100644 --- a/coderd/updatecheck/updatecheck_test.go +++ b/coderd/updatecheck/updatecheck_test.go @@ -14,7 +14,7 @@ import ( "cdr.dev/slog/sloggers/slogtest" - "github.com/coder/coder/v2/coderd/database/dbmem" + "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/updatecheck" "github.com/coder/coder/v2/testutil" ) @@ -49,7 +49,7 @@ func TestChecker_Notify(t *testing.T) { })) defer srv.Close() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Named(t.Name()) notify := make(chan updatecheck.Result, len(wantVersion)) c := updatecheck.New(db, logger, updatecheck.Options{ @@ -130,7 +130,7 @@ func TestChecker_Latest(t *testing.T) { })) defer srv.Close() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Named(t.Name()) c := updatecheck.New(db, logger, updatecheck.Options{ URL: srv.URL, diff --git a/coderd/users_test.go b/coderd/users_test.go index bd0f138b6a339..9d695f37c9906 100644 --- a/coderd/users_test.go +++ b/coderd/users_test.go @@ -32,7 +32,6 @@ import ( "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/database/dbfake" "github.com/coder/coder/v2/coderd/database/dbgen" - "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/rbac" "github.com/coder/coder/v2/coderd/util/ptr" @@ -1794,15 +1793,6 @@ func TestUsersFilter(t *testing.T) { } } - // TODO: This can be removed with dbmem - if !dbtestutil.WillUsePostgres() { - for i := range matched.Users { - if len(matched.Users[i].OrganizationIDs) == 0 { - matched.Users[i].OrganizationIDs = nil - } - } - } - require.ElementsMatch(t, exp, matched.Users, "expected users returned") }) } diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index 0ab28b340a1d1..3ae57d8394d43 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -801,6 +801,106 @@ func (api *API) workspaceAgentListeningPorts(rw http.ResponseWriter, r *http.Req httpapi.Write(ctx, rw, http.StatusOK, portsResponse) } +// @Summary Watch workspace agent for container updates. +// @ID watch-workspace-agent-for-container-updates +// @Security CoderSessionToken +// @Produce json +// @Tags Agents +// @Param workspaceagent path string true "Workspace agent ID" format(uuid) +// @Success 200 {object} codersdk.WorkspaceAgentListContainersResponse +// @Router /workspaceagents/{workspaceagent}/containers/watch [get] +func (api *API) watchWorkspaceAgentContainers(rw http.ResponseWriter, r *http.Request) { + var ( + ctx = r.Context() + workspaceAgent = httpmw.WorkspaceAgentParam(r) + ) + + // If the agent is unreachable, the request will hang. Assume that if we + // don't get a response after 30s that the agent is unreachable. + dialCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + apiAgent, err := db2sdk.WorkspaceAgent( + api.DERPMap(), + *api.TailnetCoordinator.Load(), + workspaceAgent, + nil, + nil, + nil, + api.AgentInactiveDisconnectTimeout, + api.DeploymentValues.AgentFallbackTroubleshootingURL.String(), + ) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error reading workspace agent.", + Detail: err.Error(), + }) + return + } + if apiAgent.Status != codersdk.WorkspaceAgentConnected { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: fmt.Sprintf("Agent state is %q, it must be in the %q state.", apiAgent.Status, codersdk.WorkspaceAgentConnected), + }) + return + } + + agentConn, release, err := api.agentProvider.AgentConn(dialCtx, workspaceAgent.ID) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error dialing workspace agent.", + Detail: err.Error(), + }) + return + } + defer release() + + watcherLogger := api.Logger.Named("agent_container_watcher").With(slog.F("agent_id", workspaceAgent.ID)) + containersCh, closer, err := agentConn.WatchContainers(ctx, watcherLogger) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error watching agent's containers.", + Detail: err.Error(), + }) + return + } + defer closer.Close() + + conn, err := websocket.Accept(rw, r, nil) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to upgrade connection to websocket.", + Detail: err.Error(), + }) + return + } + + // Here we close the websocket for reading, so that the websocket library will handle pings and + // close frames. + _ = conn.CloseRead(context.Background()) + + ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageText) + defer wsNetConn.Close() + + go httpapi.Heartbeat(ctx, conn) + + encoder := json.NewEncoder(wsNetConn) + + for { + select { + case <-api.ctx.Done(): + return + + case <-ctx.Done(): + return + + case containers := <-containersCh: + if err := encoder.Encode(containers); err != nil { + api.Logger.Error(ctx, "encode containers", slog.Error(err)) + return + } + } + } +} + // @Summary Get running containers for workspace agent // @ID get-running-containers-for-workspace-agent // @Security CoderSessionToken diff --git a/coderd/workspaceagents_test.go b/coderd/workspaceagents_test.go index 4a37a1bf7bc52..30859cb6391e6 100644 --- a/coderd/workspaceagents_test.go +++ b/coderd/workspaceagents_test.go @@ -46,7 +46,6 @@ import ( "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/database/dbfake" "github.com/coder/coder/v2/coderd/database/dbgen" - "github.com/coder/coder/v2/coderd/database/dbmem" "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/database/pubsub" @@ -1387,6 +1386,192 @@ func TestWorkspaceAgentContainers(t *testing.T) { }) } +func TestWatchWorkspaceAgentDevcontainers(t *testing.T) { + t.Parallel() + + var ( + ctx = testutil.Context(t, testutil.WaitLong) + logger = slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + mClock = quartz.NewMock(t) + updaterTickerTrap = mClock.Trap().TickerFunc("updaterLoop") + mCtrl = gomock.NewController(t) + mCCLI = acmock.NewMockContainerCLI(mCtrl) + + client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{Logger: &logger}) + user = coderdtest.CreateFirstUser(t, client) + r = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ + OrganizationID: user.OrganizationID, + OwnerID: user.UserID, + }).WithAgent(func(agents []*proto.Agent) []*proto.Agent { + return agents + }).Do() + + fakeContainer1 = codersdk.WorkspaceAgentContainer{ + ID: "container1", + CreatedAt: dbtime.Now(), + FriendlyName: "container1", + Image: "busybox:latest", + Labels: map[string]string{ + agentcontainers.DevcontainerLocalFolderLabel: "/home/coder/project1", + agentcontainers.DevcontainerConfigFileLabel: "/home/coder/project1/.devcontainer/devcontainer.json", + }, + Running: true, + Status: "running", + } + + fakeContainer2 = codersdk.WorkspaceAgentContainer{ + ID: "container1", + CreatedAt: dbtime.Now(), + FriendlyName: "container2", + Image: "busybox:latest", + Labels: map[string]string{ + agentcontainers.DevcontainerLocalFolderLabel: "/home/coder/project2", + agentcontainers.DevcontainerConfigFileLabel: "/home/coder/project2/.devcontainer/devcontainer.json", + }, + Running: true, + Status: "running", + } + ) + + stages := []struct { + containers []codersdk.WorkspaceAgentContainer + expected codersdk.WorkspaceAgentListContainersResponse + }{ + { + containers: []codersdk.WorkspaceAgentContainer{fakeContainer1}, + expected: codersdk.WorkspaceAgentListContainersResponse{ + Containers: []codersdk.WorkspaceAgentContainer{fakeContainer1}, + Devcontainers: []codersdk.WorkspaceAgentDevcontainer{ + { + Name: "project1", + WorkspaceFolder: fakeContainer1.Labels[agentcontainers.DevcontainerLocalFolderLabel], + ConfigPath: fakeContainer1.Labels[agentcontainers.DevcontainerConfigFileLabel], + Status: "running", + Container: &fakeContainer1, + }, + }, + }, + }, + { + containers: []codersdk.WorkspaceAgentContainer{fakeContainer1, fakeContainer2}, + expected: codersdk.WorkspaceAgentListContainersResponse{ + Containers: []codersdk.WorkspaceAgentContainer{fakeContainer1, fakeContainer2}, + Devcontainers: []codersdk.WorkspaceAgentDevcontainer{ + { + Name: "project1", + WorkspaceFolder: fakeContainer1.Labels[agentcontainers.DevcontainerLocalFolderLabel], + ConfigPath: fakeContainer1.Labels[agentcontainers.DevcontainerConfigFileLabel], + Status: "running", + Container: &fakeContainer1, + }, + { + Name: "project2", + WorkspaceFolder: fakeContainer2.Labels[agentcontainers.DevcontainerLocalFolderLabel], + ConfigPath: fakeContainer2.Labels[agentcontainers.DevcontainerConfigFileLabel], + Status: "running", + Container: &fakeContainer2, + }, + }, + }, + }, + { + containers: []codersdk.WorkspaceAgentContainer{fakeContainer2}, + expected: codersdk.WorkspaceAgentListContainersResponse{ + Containers: []codersdk.WorkspaceAgentContainer{fakeContainer2}, + Devcontainers: []codersdk.WorkspaceAgentDevcontainer{ + { + Name: "", + WorkspaceFolder: fakeContainer1.Labels[agentcontainers.DevcontainerLocalFolderLabel], + ConfigPath: fakeContainer1.Labels[agentcontainers.DevcontainerConfigFileLabel], + Status: "stopped", + Container: nil, + }, + { + Name: "project2", + WorkspaceFolder: fakeContainer2.Labels[agentcontainers.DevcontainerLocalFolderLabel], + ConfigPath: fakeContainer2.Labels[agentcontainers.DevcontainerConfigFileLabel], + Status: "running", + Container: &fakeContainer2, + }, + }, + }, + }, + } + + // Set up initial state for immediate send on connection + mCCLI.EXPECT().List(gomock.Any()).Return(codersdk.WorkspaceAgentListContainersResponse{Containers: stages[0].containers}, nil) + mCCLI.EXPECT().DetectArchitecture(gomock.Any(), gomock.Any()).Return("", nil).AnyTimes() + + _ = agenttest.New(t, client.URL, r.AgentToken, func(o *agent.Options) { + o.Logger = logger.Named("agent") + o.Devcontainers = true + o.DevcontainerAPIOptions = []agentcontainers.Option{ + agentcontainers.WithClock(mClock), + agentcontainers.WithContainerCLI(mCCLI), + agentcontainers.WithWatcher(watcher.NewNoop()), + } + }) + + resources := coderdtest.NewWorkspaceAgentWaiter(t, client, r.Workspace.ID).Wait() + require.Len(t, resources, 1, "expected one resource") + require.Len(t, resources[0].Agents, 1, "expected one agent") + agentID := resources[0].Agents[0].ID + + updaterTickerTrap.MustWait(ctx).MustRelease(ctx) + defer updaterTickerTrap.Close() + + containers, closer, err := client.WatchWorkspaceAgentContainers(ctx, agentID) + require.NoError(t, err) + defer func() { + closer.Close() + }() + + // Read initial state sent immediately on connection + var got codersdk.WorkspaceAgentListContainersResponse + select { + case <-ctx.Done(): + case got = <-containers: + } + require.NoError(t, ctx.Err()) + + require.Equal(t, stages[0].expected.Containers, got.Containers) + require.Len(t, got.Devcontainers, len(stages[0].expected.Devcontainers)) + for j, expectedDev := range stages[0].expected.Devcontainers { + gotDev := got.Devcontainers[j] + require.Equal(t, expectedDev.Name, gotDev.Name) + require.Equal(t, expectedDev.WorkspaceFolder, gotDev.WorkspaceFolder) + require.Equal(t, expectedDev.ConfigPath, gotDev.ConfigPath) + require.Equal(t, expectedDev.Status, gotDev.Status) + require.Equal(t, expectedDev.Container, gotDev.Container) + } + + // Process remaining stages through updater loop + for i, stage := range stages[1:] { + mCCLI.EXPECT().List(gomock.Any()).Return(codersdk.WorkspaceAgentListContainersResponse{Containers: stage.containers}, nil) + + _, aw := mClock.AdvanceNext() + aw.MustWait(ctx) + + var got codersdk.WorkspaceAgentListContainersResponse + select { + case <-ctx.Done(): + case got = <-containers: + } + require.NoError(t, ctx.Err()) + + require.Equal(t, stages[i+1].expected.Containers, got.Containers) + require.Len(t, got.Devcontainers, len(stages[i+1].expected.Devcontainers)) + for j, expectedDev := range stages[i+1].expected.Devcontainers { + gotDev := got.Devcontainers[j] + require.Equal(t, expectedDev.Name, gotDev.Name) + require.Equal(t, expectedDev.WorkspaceFolder, gotDev.WorkspaceFolder) + require.Equal(t, expectedDev.ConfigPath, gotDev.ConfigPath) + require.Equal(t, expectedDev.Status, gotDev.Status) + require.Equal(t, expectedDev.Container, gotDev.Container) + } + } +} + func TestWorkspaceAgentRecreateDevcontainer(t *testing.T) { t.Parallel() @@ -1989,8 +2174,8 @@ func (s *testWAMErrorStore) GetWorkspaceAgentMetadata(ctx context.Context, arg d func TestWorkspaceAgent_Metadata_CatchMemoryLeak(t *testing.T) { t.Parallel() - db := &testWAMErrorStore{Store: dbmem.New()} - psub := pubsub.NewInMemory() + store, psub := dbtestutil.NewDB(t) + db := &testWAMErrorStore{Store: store} logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Named("coderd").Leveled(slog.LevelDebug) client := coderdtest.New(t, &coderdtest.Options{ Database: db, diff --git a/coderd/workspaceagentsrpc.go b/coderd/workspaceagentsrpc.go index 1cbabad8ea622..0806118f2a832 100644 --- a/coderd/workspaceagentsrpc.go +++ b/coderd/workspaceagentsrpc.go @@ -139,7 +139,7 @@ func (api *API) workspaceAgentRPC(rw http.ResponseWriter, r *http.Request) { Database: api.Database, NotificationsEnqueuer: api.NotificationsEnqueuer, Pubsub: api.Pubsub, - Auditor: &api.Auditor, + ConnectionLogger: &api.ConnectionLogger, DerpMapFn: api.DERPMap, TailnetCoordinator: &api.TailnetCoordinator, AppearanceFetcher: &api.AppearanceFetcher, diff --git a/coderd/workspaceapps/db.go b/coderd/workspaceapps/db.go index 0b598a6f0aab9..61a9e218edc7f 100644 --- a/coderd/workspaceapps/db.go +++ b/coderd/workspaceapps/db.go @@ -3,7 +3,6 @@ package workspaceapps import ( "context" "database/sql" - "encoding/json" "fmt" "net/http" "net/url" @@ -18,7 +17,7 @@ import ( "golang.org/x/xerrors" "cdr.dev/slog" - "github.com/coder/coder/v2/coderd/audit" + "github.com/coder/coder/v2/coderd/connectionlog" "github.com/coder/coder/v2/coderd/cryptokeys" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbauthz" @@ -40,7 +39,7 @@ type DBTokenProvider struct { // DashboardURL is the main dashboard access URL for error pages. DashboardURL *url.URL Authorizer rbac.Authorizer - Auditor *atomic.Pointer[audit.Auditor] + ConnectionLogger *atomic.Pointer[connectionlog.ConnectionLogger] Database database.Store DeploymentValues *codersdk.DeploymentValues OAuth2Configs *httpmw.OAuth2Configs @@ -54,7 +53,7 @@ var _ SignedTokenProvider = &DBTokenProvider{} func NewDBTokenProvider(log slog.Logger, accessURL *url.URL, authz rbac.Authorizer, - auditor *atomic.Pointer[audit.Auditor], + connectionLogger *atomic.Pointer[connectionlog.ConnectionLogger], db database.Store, cfg *codersdk.DeploymentValues, oauth2Cfgs *httpmw.OAuth2Configs, @@ -73,7 +72,7 @@ func NewDBTokenProvider(log slog.Logger, Logger: log, DashboardURL: accessURL, Authorizer: authz, - Auditor: auditor, + ConnectionLogger: connectionLogger, Database: db, DeploymentValues: cfg, OAuth2Configs: oauth2Cfgs, @@ -95,7 +94,7 @@ func (p *DBTokenProvider) Issue(ctx context.Context, rw http.ResponseWriter, r * // // permissions. dangerousSystemCtx := dbauthz.AsSystemRestricted(ctx) - aReq, commitAudit := p.auditInitRequest(ctx, rw, r) + aReq, commitAudit := p.connLogInitRequest(ctx, rw, r) defer commitAudit() appReq := issueReq.AppRequest.Normalize() @@ -386,20 +385,20 @@ func (p *DBTokenProvider) authorizeRequest(ctx context.Context, roles *rbac.Subj return false, warnings, nil } -type auditRequest struct { +type connLogRequest struct { time time.Time apiKey *database.APIKey dbReq *databaseRequest } -// auditInitRequest creates a new audit session and audit log for the given -// request, if one does not already exist. If an audit session already exists, -// it will be updated with the current timestamp. A session is used to reduce -// the number of audit logs created. +// connLogInitRequest creates a new connection log session and connect log for the +// given request, if one does not already exist. If a connection log session +// already exists, it will be updated with the current timestamp. A session is used to +// reduce the number of connection logs created. // // A session is unique to the agent, app, user and users IP. If any of these -// values change, a new session and audit log is created. -func (p *DBTokenProvider) auditInitRequest(ctx context.Context, w http.ResponseWriter, r *http.Request) (aReq *auditRequest, commit func()) { +// values change, a new session and connect log is created. +func (p *DBTokenProvider) connLogInitRequest(ctx context.Context, w http.ResponseWriter, r *http.Request) (aReq *connLogRequest, commit func()) { // Get the status writer from the request context so we can figure // out the HTTP status and autocommit the audit log. sw, ok := w.(*tracing.StatusWriter) @@ -407,12 +406,12 @@ func (p *DBTokenProvider) auditInitRequest(ctx context.Context, w http.ResponseW panic("dev error: http.ResponseWriter is not *tracing.StatusWriter") } - aReq = &auditRequest{ + aReq = &connLogRequest{ time: dbtime.Now(), } - // Set the commit function on the status writer to create an audit - // log, this ensures that the status and response body are available. + // Set the commit function on the status writer to create a connection log + // this ensures that the status and response body are available. var committed bool return aReq, func() { if committed { @@ -422,7 +421,7 @@ func (p *DBTokenProvider) auditInitRequest(ctx context.Context, w http.ResponseW if aReq.dbReq == nil { // App doesn't exist, there's information in the Request - // struct but we need UUIDs for audit logging. + // struct but we need UUIDs for connection logging. return } @@ -434,28 +433,25 @@ func (p *DBTokenProvider) auditInitRequest(ctx context.Context, w http.ResponseW ip := r.RemoteAddr // Approximation of the status code. - statusCode := sw.Status + // #nosec G115 - Safe conversion as HTTP status code is expected to be within int32 range (typically 100-599) + var statusCode int32 = int32(sw.Status) if statusCode == 0 { statusCode = http.StatusOK } - type additionalFields struct { - audit.AdditionalFields - SlugOrPort string `json:"slug_or_port,omitempty"` - } - appInfo := additionalFields{ - AdditionalFields: audit.AdditionalFields{ - WorkspaceOwner: aReq.dbReq.Workspace.OwnerUsername, - WorkspaceName: aReq.dbReq.Workspace.Name, - WorkspaceID: aReq.dbReq.Workspace.ID, - }, - } + var ( + connType database.ConnectionType + slugOrPort = aReq.dbReq.AppSlugOrPort + ) + switch { case aReq.dbReq.AccessMethod == AccessMethodTerminal: - appInfo.SlugOrPort = "terminal" + connType = database.ConnectionTypeWorkspaceApp + slugOrPort = "terminal" case aReq.dbReq.App.ID == uuid.Nil: - // If this isn't an app or a terminal, it's a port. - appInfo.SlugOrPort = aReq.dbReq.AppSlugOrPort + connType = database.ConnectionTypePortForwarding + default: + connType = database.ConnectionTypeWorkspaceApp } // If we end up logging, ensure relevant fields are set. @@ -465,7 +461,7 @@ func (p *DBTokenProvider) auditInitRequest(ctx context.Context, w http.ResponseW slog.F("app_id", aReq.dbReq.App.ID), slog.F("user_id", userID), slog.F("user_agent", userAgent), - slog.F("app_slug_or_port", appInfo.SlugOrPort), + slog.F("app_slug_or_port", slugOrPort), slog.F("status_code", statusCode), ) @@ -485,9 +481,8 @@ func (p *DBTokenProvider) auditInitRequest(ctx context.Context, w http.ResponseW UserID: userID, // Can be unset, in which case uuid.Nil is fine. Ip: ip, UserAgent: userAgent, - SlugOrPort: appInfo.SlugOrPort, - // #nosec G115 - Safe conversion as HTTP status code is expected to be within int32 range (typically 100-599) - StatusCode: int32(statusCode), + SlugOrPort: slugOrPort, + StatusCode: statusCode, StartedAt: aReq.time, UpdatedAt: aReq.time, }) @@ -500,7 +495,7 @@ func (p *DBTokenProvider) auditInitRequest(ctx context.Context, w http.ResponseW if err != nil { logger.Error(ctx, "update workspace app audit session failed", slog.Error(err)) - // Avoid spamming the audit log if deduplication failed, this should + // Avoid spamming the connection log if deduplication failed, this should // only happen if there are problems communicating with the database. return } @@ -511,51 +506,37 @@ func (p *DBTokenProvider) auditInitRequest(ctx context.Context, w http.ResponseW return } - // Marshal additional fields only if we're writing an audit log entry. - appInfoBytes, err := json.Marshal(appInfo) - if err != nil { - logger.Error(ctx, "marshal additional fields failed", slog.Error(err)) - } + connLogger := *p.ConnectionLogger.Load() + + err = connLogger.Upsert(ctx, database.UpsertConnectionLogParams{ + ID: uuid.New(), + Time: aReq.time, + OrganizationID: aReq.dbReq.Workspace.OrganizationID, + WorkspaceOwnerID: aReq.dbReq.Workspace.OwnerID, + WorkspaceID: aReq.dbReq.Workspace.ID, + WorkspaceName: aReq.dbReq.Workspace.Name, + AgentName: aReq.dbReq.Agent.Name, + Type: connType, + Code: sql.NullInt32{ + Int32: statusCode, + Valid: true, + }, + Ip: database.ParseIP(ip), + UserAgent: sql.NullString{Valid: userAgent != "", String: userAgent}, + UserID: uuid.NullUUID{ + UUID: userID, + Valid: userID != uuid.Nil, + }, + SlugOrPort: sql.NullString{Valid: slugOrPort != "", String: slugOrPort}, + ConnectionStatus: database.ConnectionStatusConnected, - // We use the background audit function instead of init request - // here because we don't know the resource type ahead of time. - // This also allows us to log unauthenticated access. - auditor := *p.Auditor.Load() - requestID := httpmw.RequestID(r) - switch { - case aReq.dbReq.App.ID != uuid.Nil: - audit.BackgroundAudit(ctx, &audit.BackgroundAuditParams[database.WorkspaceApp]{ - Audit: auditor, - Log: logger, - - Action: database.AuditActionOpen, - OrganizationID: aReq.dbReq.Workspace.OrganizationID, - UserID: userID, - RequestID: requestID, - Time: aReq.time, - Status: statusCode, - IP: ip, - UserAgent: userAgent, - New: aReq.dbReq.App, - AdditionalFields: appInfoBytes, - }) - default: - // Web terminal, port app, etc. - audit.BackgroundAudit(ctx, &audit.BackgroundAuditParams[database.WorkspaceAgent]{ - Audit: auditor, - Log: logger, - - Action: database.AuditActionOpen, - OrganizationID: aReq.dbReq.Workspace.OrganizationID, - UserID: userID, - RequestID: requestID, - Time: aReq.time, - Status: statusCode, - IP: ip, - UserAgent: userAgent, - New: aReq.dbReq.Agent, - AdditionalFields: appInfoBytes, - }) + // N/A + ConnectionID: uuid.NullUUID{}, + DisconnectReason: sql.NullString{}, + }) + if err != nil { + logger.Error(ctx, "upsert connection log failed", slog.Error(err)) + return } } } diff --git a/coderd/workspaceapps/db_test.go b/coderd/workspaceapps/db_test.go index a1f3fb452fbe5..e78762c035565 100644 --- a/coderd/workspaceapps/db_test.go +++ b/coderd/workspaceapps/db_test.go @@ -3,7 +3,6 @@ package workspaceapps_test import ( "context" "database/sql" - "encoding/json" "fmt" "io" "net" @@ -22,10 +21,9 @@ import ( "github.com/stretchr/testify/require" "github.com/coder/coder/v2/agent/agenttest" - "github.com/coder/coder/v2/coderd/audit" "github.com/coder/coder/v2/coderd/coderdtest" + "github.com/coder/coder/v2/coderd/connectionlog" "github.com/coder/coder/v2/coderd/database" - "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/httpmw" "github.com/coder/coder/v2/coderd/jwtutils" "github.com/coder/coder/v2/coderd/tracing" @@ -83,12 +81,12 @@ func Test_ResolveRequest(t *testing.T) { deploymentValues.Dangerous.AllowPathAppSharing = true deploymentValues.Dangerous.AllowPathAppSiteOwnerAccess = true - auditor := audit.NewMock() + connLogger := connectionlog.NewFake() t.Cleanup(func() { if t.Failed() { return } - assert.Len(t, auditor.AuditLogs(), 0, "one or more test cases produced unexpected audit logs, did you replace the auditor or forget to call ResetLogs?") + assert.Len(t, connLogger.ConnectionLogs(), 0, "one or more test cases produced unexpected connection logs, did you replace the auditor or forget to call ResetLogs?") }) client, closer, api := coderdtest.NewWithAPI(t, &coderdtest.Options{ AppHostname: "*.test.coder.com", @@ -105,7 +103,7 @@ func Test_ResolveRequest(t *testing.T) { "CF-Connecting-IP", }, }, - Auditor: auditor, + ConnectionLogger: connLogger, }) t.Cleanup(func() { _ = closer.Close() @@ -231,23 +229,8 @@ func Test_ResolveRequest(t *testing.T) { } require.NotEqual(t, uuid.Nil, agentID) - //nolint:gocritic // This is a test, allow dbauthz.AsSystemRestricted. - agent, err := api.Database.GetWorkspaceAgentByID(dbauthz.AsSystemRestricted(ctx), agentID) - require.NoError(t, err) - - //nolint:gocritic // This is a test, allow dbauthz.AsSystemRestricted. - apps, err := api.Database.GetWorkspaceAppsByAgentID(dbauthz.AsSystemRestricted(ctx), agentID) - require.NoError(t, err) - appsBySlug := make(map[string]database.WorkspaceApp, len(apps)) - for _, app := range apps { - appsBySlug[app.Slug] = app - } - // Reset audit logs so cleanup check can pass. - auditor.ResetLogs() - - assertAuditAgent := auditAsserter[database.WorkspaceAgent](workspace) - assertAuditApp := auditAsserter[database.WorkspaceApp](workspace) + connLogger.Reset() t.Run("OK", func(t *testing.T) { t.Parallel() @@ -285,9 +268,9 @@ func Test_ResolveRequest(t *testing.T) { AppSlugOrPort: app, }).Normalize() - auditor := audit.NewMock() + connLogger := connectionlog.NewFake() auditableIP := testutil.RandomIPv6(t) - auditableUA := "Tidua" + auditableUA := "Noitcennoc" t.Log("app", app) rw := httptest.NewRecorder() @@ -297,7 +280,7 @@ func Test_ResolveRequest(t *testing.T) { r.Header.Set("User-Agent", auditableUA) // Try resolving the request without a token. - token, ok := workspaceappsResolveRequest(t, auditor, rw, r, workspaceapps.ResolveRequestOptions{ + token, ok := workspaceappsResolveRequest(t, connLogger, rw, r, workspaceapps.ResolveRequestOptions{ Logger: api.Logger, SignedTokenProvider: api.WorkspaceAppsProvider, DashboardURL: api.AccessURL, @@ -333,8 +316,8 @@ func Test_ResolveRequest(t *testing.T) { require.Equal(t, codersdk.SignedAppTokenCookie, cookie.Name) require.Equal(t, req.BasePath, cookie.Path) - assertAuditApp(t, rw, r, auditor, appsBySlug[app], me.ID, nil) - require.Len(t, auditor.AuditLogs(), 1, "audit log count") + assertConnLogContains(t, rw, r, connLogger, workspace, agentName, app, database.ConnectionTypeWorkspaceApp, me.ID) + require.Len(t, connLogger.ConnectionLogs(), 1) var parsedToken workspaceapps.SignedToken err := jwtutils.Verify(ctx, api.AppSigningKeyCache, cookie.Value, &parsedToken) @@ -350,7 +333,7 @@ func Test_ResolveRequest(t *testing.T) { r.AddCookie(cookie) r.RemoteAddr = auditableIP - secondToken, ok := workspaceappsResolveRequest(t, auditor, rw, r, workspaceapps.ResolveRequestOptions{ + secondToken, ok := workspaceappsResolveRequest(t, connLogger, rw, r, workspaceapps.ResolveRequestOptions{ Logger: api.Logger, SignedTokenProvider: api.WorkspaceAppsProvider, DashboardURL: api.AccessURL, @@ -363,7 +346,7 @@ func Test_ResolveRequest(t *testing.T) { require.WithinDuration(t, token.Expiry.Time(), secondToken.Expiry.Time(), 2*time.Second) secondToken.Expiry = token.Expiry require.Equal(t, token, secondToken) - require.Len(t, auditor.AuditLogs(), 1, "no new audit log, FromRequest returned the same token and is not audited") + require.Len(t, connLogger.ConnectionLogs(), 1, "no new connection log, FromRequest returned the same token and is not logged") } }) } @@ -382,7 +365,7 @@ func Test_ResolveRequest(t *testing.T) { AppSlugOrPort: app, }).Normalize() - auditor := audit.NewMock() + connLogger := connectionlog.NewFake() auditableIP := testutil.RandomIPv6(t) t.Log("app", app) @@ -391,7 +374,7 @@ func Test_ResolveRequest(t *testing.T) { r.Header.Set(codersdk.SessionTokenHeader, secondUserClient.SessionToken()) r.RemoteAddr = auditableIP - token, ok := workspaceappsResolveRequest(t, auditor, rw, r, workspaceapps.ResolveRequestOptions{ + token, ok := workspaceappsResolveRequest(t, connLogger, rw, r, workspaceapps.ResolveRequestOptions{ Logger: api.Logger, SignedTokenProvider: api.WorkspaceAppsProvider, DashboardURL: api.AccessURL, @@ -406,14 +389,15 @@ func Test_ResolveRequest(t *testing.T) { require.Nil(t, token) require.NotZero(t, w.StatusCode) require.Equal(t, http.StatusNotFound, w.StatusCode) + require.Len(t, connLogger.ConnectionLogs(), 1) return } require.True(t, ok) require.NotNil(t, token) require.Zero(t, w.StatusCode) - assertAuditApp(t, rw, r, auditor, appsBySlug[app], secondUser.ID, nil) - require.Len(t, auditor.AuditLogs(), 1, "single audit log") + assertConnLogContains(t, rw, r, connLogger, workspace, agentName, app, database.ConnectionTypeWorkspaceApp, secondUser.ID) + require.Len(t, connLogger.ConnectionLogs(), 1) } }) @@ -430,14 +414,14 @@ func Test_ResolveRequest(t *testing.T) { AppSlugOrPort: app, }).Normalize() - auditor := audit.NewMock() + connLogger := connectionlog.NewFake() auditableIP := testutil.RandomIPv6(t) t.Log("app", app) rw := httptest.NewRecorder() r := httptest.NewRequest("GET", "/app", nil) r.RemoteAddr = auditableIP - token, ok := workspaceappsResolveRequest(t, auditor, rw, r, workspaceapps.ResolveRequestOptions{ + token, ok := workspaceappsResolveRequest(t, connLogger, rw, r, workspaceapps.ResolveRequestOptions{ Logger: api.Logger, SignedTokenProvider: api.WorkspaceAppsProvider, DashboardURL: api.AccessURL, @@ -452,8 +436,8 @@ func Test_ResolveRequest(t *testing.T) { require.NotZero(t, rw.Code) require.NotEqual(t, http.StatusOK, rw.Code) - assertAuditApp(t, rw, r, auditor, appsBySlug[app], uuid.Nil, nil) - require.Len(t, auditor.AuditLogs(), 1, "audit log for unauthenticated requests") + assertConnLogContains(t, rw, r, connLogger, workspace, agentName, app, database.ConnectionTypeWorkspaceApp, uuid.Nil) + require.Len(t, connLogger.ConnectionLogs(), 1) } else { if !assert.True(t, ok) { dump, err := httputil.DumpResponse(w, true) @@ -466,8 +450,8 @@ func Test_ResolveRequest(t *testing.T) { t.Fatalf("expected 200 (or unset) response code, got %d", rw.Code) } - assertAuditApp(t, rw, r, auditor, appsBySlug[app], uuid.Nil, nil) - require.Len(t, auditor.AuditLogs(), 1, "single audit log") + assertConnLogContains(t, rw, r, connLogger, workspace, agentName, app, database.ConnectionTypeWorkspaceApp, uuid.Nil) + require.Len(t, connLogger.ConnectionLogs(), 1) } _ = w.Body.Close() } @@ -479,12 +463,12 @@ func Test_ResolveRequest(t *testing.T) { req := (workspaceapps.Request{ AccessMethod: "invalid", }).Normalize() - auditor := audit.NewMock() + connLogger := connectionlog.NewFake() auditableIP := testutil.RandomIPv6(t) rw := httptest.NewRecorder() r := httptest.NewRequest("GET", "/app", nil) r.RemoteAddr = auditableIP - token, ok := workspaceappsResolveRequest(t, auditor, rw, r, workspaceapps.ResolveRequestOptions{ + token, ok := workspaceappsResolveRequest(t, connLogger, rw, r, workspaceapps.ResolveRequestOptions{ Logger: api.Logger, SignedTokenProvider: api.WorkspaceAppsProvider, DashboardURL: api.AccessURL, @@ -494,7 +478,7 @@ func Test_ResolveRequest(t *testing.T) { }) require.False(t, ok) require.Nil(t, token) - require.Len(t, auditor.AuditLogs(), 0, "no audit logs for invalid requests") + require.Len(t, connLogger.ConnectionLogs(), 0) }) t.Run("SplitWorkspaceAndAgent", func(t *testing.T) { @@ -562,7 +546,7 @@ func Test_ResolveRequest(t *testing.T) { AppSlugOrPort: appNamePublic, }).Normalize() - auditor := audit.NewMock() + connLogger := connectionlog.NewFake() auditableIP := testutil.RandomIPv6(t) rw := httptest.NewRecorder() @@ -570,7 +554,7 @@ func Test_ResolveRequest(t *testing.T) { r.Header.Set(codersdk.SessionTokenHeader, client.SessionToken()) r.RemoteAddr = auditableIP - token, ok := workspaceappsResolveRequest(t, auditor, rw, r, workspaceapps.ResolveRequestOptions{ + token, ok := workspaceappsResolveRequest(t, connLogger, rw, r, workspaceapps.ResolveRequestOptions{ Logger: api.Logger, SignedTokenProvider: api.WorkspaceAppsProvider, DashboardURL: api.AccessURL, @@ -591,11 +575,11 @@ func Test_ResolveRequest(t *testing.T) { require.Equal(t, token.AgentNameOrID, c.agent) require.Equal(t, token.WorkspaceID, workspace.ID) require.Equal(t, token.AgentID, agentID) - assertAuditApp(t, rw, r, auditor, appsBySlug[token.AppSlugOrPort], me.ID, nil) - require.Len(t, auditor.AuditLogs(), 1, "single audit log") + assertConnLogContains(t, rw, r, connLogger, workspace, agentName, token.AppSlugOrPort, database.ConnectionTypeWorkspaceApp, me.ID) + require.Len(t, connLogger.ConnectionLogs(), 1) } else { require.Nil(t, token) - require.Len(t, auditor.AuditLogs(), 0, "no audit logs") + require.Len(t, connLogger.ConnectionLogs(), 0) } _ = w.Body.Close() }) @@ -637,7 +621,7 @@ func Test_ResolveRequest(t *testing.T) { AppSlugOrPort: appNameOwner, }).Normalize() - auditor := audit.NewMock() + connLogger := connectionlog.NewFake() auditableIP := testutil.RandomIPv6(t) rw := httptest.NewRecorder() @@ -651,7 +635,7 @@ func Test_ResolveRequest(t *testing.T) { // Even though the token is invalid, we should still perform request // resolution without failure since we'll just ignore the bad token. - token, ok := workspaceappsResolveRequest(t, auditor, rw, r, workspaceapps.ResolveRequestOptions{ + token, ok := workspaceappsResolveRequest(t, connLogger, rw, r, workspaceapps.ResolveRequestOptions{ Logger: api.Logger, SignedTokenProvider: api.WorkspaceAppsProvider, DashboardURL: api.AccessURL, @@ -676,8 +660,8 @@ func Test_ResolveRequest(t *testing.T) { require.NoError(t, err) require.Equal(t, appNameOwner, parsedToken.AppSlugOrPort) - assertAuditApp(t, rw, r, auditor, appsBySlug[appNameOwner], me.ID, nil) - require.Len(t, auditor.AuditLogs(), 1, "single audit log") + assertConnLogContains(t, rw, r, connLogger, workspace, agentName, appNameOwner, database.ConnectionTypeWorkspaceApp, me.ID) + require.Len(t, connLogger.ConnectionLogs(), 1) }) t.Run("PortPathBlocked", func(t *testing.T) { @@ -692,7 +676,7 @@ func Test_ResolveRequest(t *testing.T) { AppSlugOrPort: "8080", }).Normalize() - auditor := audit.NewMock() + connLogger := connectionlog.NewFake() auditableIP := testutil.RandomIPv6(t) rw := httptest.NewRecorder() @@ -700,7 +684,7 @@ func Test_ResolveRequest(t *testing.T) { r.Header.Set(codersdk.SessionTokenHeader, client.SessionToken()) r.RemoteAddr = auditableIP - token, ok := workspaceappsResolveRequest(t, auditor, rw, r, workspaceapps.ResolveRequestOptions{ + token, ok := workspaceappsResolveRequest(t, connLogger, rw, r, workspaceapps.ResolveRequestOptions{ Logger: api.Logger, SignedTokenProvider: api.WorkspaceAppsProvider, DashboardURL: api.AccessURL, @@ -715,7 +699,7 @@ func Test_ResolveRequest(t *testing.T) { _ = w.Body.Close() // TODO(mafredri): Verify this is the correct status code. require.Equal(t, http.StatusInternalServerError, w.StatusCode) - require.Len(t, auditor.AuditLogs(), 0, "no audit logs for port path blocked requests") + require.Len(t, connLogger.ConnectionLogs(), 0, "no connection logs for port path blocked requests") }) t.Run("PortSubdomain", func(t *testing.T) { @@ -730,7 +714,7 @@ func Test_ResolveRequest(t *testing.T) { AppSlugOrPort: "9090", }).Normalize() - auditor := audit.NewMock() + connLogger := connectionlog.NewFake() auditableIP := testutil.RandomIPv6(t) rw := httptest.NewRecorder() @@ -738,7 +722,7 @@ func Test_ResolveRequest(t *testing.T) { r.Header.Set(codersdk.SessionTokenHeader, client.SessionToken()) r.RemoteAddr = auditableIP - token, ok := workspaceappsResolveRequest(t, auditor, rw, r, workspaceapps.ResolveRequestOptions{ + token, ok := workspaceappsResolveRequest(t, connLogger, rw, r, workspaceapps.ResolveRequestOptions{ Logger: api.Logger, SignedTokenProvider: api.WorkspaceAppsProvider, DashboardURL: api.AccessURL, @@ -749,11 +733,8 @@ func Test_ResolveRequest(t *testing.T) { require.True(t, ok) require.Equal(t, req.AppSlugOrPort, token.AppSlugOrPort) require.Equal(t, "http://127.0.0.1:9090", token.AppURL) - - assertAuditAgent(t, rw, r, auditor, agent, me.ID, map[string]any{ - "slug_or_port": "9090", - }) - require.Len(t, auditor.AuditLogs(), 1, "single audit log") + assertConnLogContains(t, rw, r, connLogger, workspace, agentName, "9090", database.ConnectionTypePortForwarding, me.ID) + require.Len(t, connLogger.ConnectionLogs(), 1) }) t.Run("PortSubdomainHTTPSS", func(t *testing.T) { @@ -768,7 +749,7 @@ func Test_ResolveRequest(t *testing.T) { AppSlugOrPort: "9090ss", }).Normalize() - auditor := audit.NewMock() + connLogger := connectionlog.NewFake() auditableIP := testutil.RandomIPv6(t) rw := httptest.NewRecorder() @@ -776,7 +757,7 @@ func Test_ResolveRequest(t *testing.T) { r.Header.Set(codersdk.SessionTokenHeader, client.SessionToken()) r.RemoteAddr = auditableIP - _, ok := workspaceappsResolveRequest(t, auditor, rw, r, workspaceapps.ResolveRequestOptions{ + _, ok := workspaceappsResolveRequest(t, connLogger, rw, r, workspaceapps.ResolveRequestOptions{ Logger: api.Logger, SignedTokenProvider: api.WorkspaceAppsProvider, DashboardURL: api.AccessURL, @@ -792,7 +773,7 @@ func Test_ResolveRequest(t *testing.T) { require.NoError(t, err) require.Contains(t, string(b), "404 - Application Not Found") require.Equal(t, http.StatusNotFound, w.StatusCode) - require.Len(t, auditor.AuditLogs(), 0, "no audit logs for invalid requests") + require.Len(t, connLogger.ConnectionLogs(), 0) }) t.Run("SubdomainEndsInS", func(t *testing.T) { @@ -807,7 +788,7 @@ func Test_ResolveRequest(t *testing.T) { AppSlugOrPort: appNameEndsInS, }).Normalize() - auditor := audit.NewMock() + connLogger := connectionlog.NewFake() auditableIP := testutil.RandomIPv6(t) rw := httptest.NewRecorder() @@ -815,7 +796,7 @@ func Test_ResolveRequest(t *testing.T) { r.Header.Set(codersdk.SessionTokenHeader, client.SessionToken()) r.RemoteAddr = auditableIP - token, ok := workspaceappsResolveRequest(t, auditor, rw, r, workspaceapps.ResolveRequestOptions{ + token, ok := workspaceappsResolveRequest(t, connLogger, rw, r, workspaceapps.ResolveRequestOptions{ Logger: api.Logger, SignedTokenProvider: api.WorkspaceAppsProvider, DashboardURL: api.AccessURL, @@ -825,8 +806,8 @@ func Test_ResolveRequest(t *testing.T) { }) require.True(t, ok) require.Equal(t, req.AppSlugOrPort, token.AppSlugOrPort) - assertAuditApp(t, rw, r, auditor, appsBySlug[appNameEndsInS], me.ID, nil) - require.Len(t, auditor.AuditLogs(), 1, "single audit log") + assertConnLogContains(t, rw, r, connLogger, workspace, agentName, appNameEndsInS, database.ConnectionTypeWorkspaceApp, me.ID) + require.Len(t, connLogger.ConnectionLogs(), 1) }) t.Run("Terminal", func(t *testing.T) { @@ -838,7 +819,7 @@ func Test_ResolveRequest(t *testing.T) { AgentNameOrID: agentID.String(), }).Normalize() - auditor := audit.NewMock() + connLogger := connectionlog.NewFake() auditableIP := testutil.RandomIPv6(t) rw := httptest.NewRecorder() @@ -846,7 +827,7 @@ func Test_ResolveRequest(t *testing.T) { r.Header.Set(codersdk.SessionTokenHeader, client.SessionToken()) r.RemoteAddr = auditableIP - token, ok := workspaceappsResolveRequest(t, auditor, rw, r, workspaceapps.ResolveRequestOptions{ + token, ok := workspaceappsResolveRequest(t, connLogger, rw, r, workspaceapps.ResolveRequestOptions{ Logger: api.Logger, SignedTokenProvider: api.WorkspaceAppsProvider, DashboardURL: api.AccessURL, @@ -862,10 +843,8 @@ func Test_ResolveRequest(t *testing.T) { require.Equal(t, req.AgentNameOrID, token.Request.AgentNameOrID) require.Empty(t, token.AppSlugOrPort) require.Empty(t, token.AppURL) - assertAuditAgent(t, rw, r, auditor, agent, me.ID, map[string]any{ - "slug_or_port": "terminal", - }) - require.Len(t, auditor.AuditLogs(), 1, "single audit log") + assertConnLogContains(t, rw, r, connLogger, workspace, agentName, "terminal", database.ConnectionTypeWorkspaceApp, me.ID) + require.Len(t, connLogger.ConnectionLogs(), 1) }) t.Run("InsufficientPermissions", func(t *testing.T) { @@ -880,7 +859,7 @@ func Test_ResolveRequest(t *testing.T) { AppSlugOrPort: appNameOwner, }).Normalize() - auditor := audit.NewMock() + connLogger := connectionlog.NewFake() auditableIP := testutil.RandomIPv6(t) rw := httptest.NewRecorder() @@ -888,7 +867,7 @@ func Test_ResolveRequest(t *testing.T) { r.Header.Set(codersdk.SessionTokenHeader, secondUserClient.SessionToken()) r.RemoteAddr = auditableIP - token, ok := workspaceappsResolveRequest(t, auditor, rw, r, workspaceapps.ResolveRequestOptions{ + token, ok := workspaceappsResolveRequest(t, connLogger, rw, r, workspaceapps.ResolveRequestOptions{ Logger: api.Logger, SignedTokenProvider: api.WorkspaceAppsProvider, DashboardURL: api.AccessURL, @@ -898,8 +877,8 @@ func Test_ResolveRequest(t *testing.T) { }) require.False(t, ok) require.Nil(t, token) - assertAuditApp(t, rw, r, auditor, appsBySlug[appNameOwner], secondUser.ID, nil) - require.Len(t, auditor.AuditLogs(), 1, "single audit log") + assertConnLogContains(t, rw, r, connLogger, workspace, agentName, appNameOwner, database.ConnectionTypeWorkspaceApp, secondUser.ID) + require.Len(t, connLogger.ConnectionLogs(), 1) }) t.Run("UserNotFound", func(t *testing.T) { @@ -913,7 +892,7 @@ func Test_ResolveRequest(t *testing.T) { AppSlugOrPort: appNameOwner, }).Normalize() - auditor := audit.NewMock() + connLogger := connectionlog.NewFake() auditableIP := testutil.RandomIPv6(t) rw := httptest.NewRecorder() @@ -921,7 +900,7 @@ func Test_ResolveRequest(t *testing.T) { r.Header.Set(codersdk.SessionTokenHeader, client.SessionToken()) r.RemoteAddr = auditableIP - token, ok := workspaceappsResolveRequest(t, auditor, rw, r, workspaceapps.ResolveRequestOptions{ + token, ok := workspaceappsResolveRequest(t, connLogger, rw, r, workspaceapps.ResolveRequestOptions{ Logger: api.Logger, SignedTokenProvider: api.WorkspaceAppsProvider, DashboardURL: api.AccessURL, @@ -931,7 +910,7 @@ func Test_ResolveRequest(t *testing.T) { }) require.False(t, ok) require.Nil(t, token) - require.Len(t, auditor.AuditLogs(), 0, "no audit logs for user not found") + require.Len(t, connLogger.ConnectionLogs(), 0) }) t.Run("RedirectSubdomainAuth", func(t *testing.T) { @@ -946,7 +925,7 @@ func Test_ResolveRequest(t *testing.T) { AppSlugOrPort: appNameOwner, }).Normalize() - auditor := audit.NewMock() + connLogger := connectionlog.NewFake() auditableIP := testutil.RandomIPv6(t) rw := httptest.NewRecorder() @@ -955,7 +934,7 @@ func Test_ResolveRequest(t *testing.T) { r.Host = "app.com" r.RemoteAddr = auditableIP - token, ok := workspaceappsResolveRequest(t, auditor, rw, r, workspaceapps.ResolveRequestOptions{ + token, ok := workspaceappsResolveRequest(t, connLogger, rw, r, workspaceapps.ResolveRequestOptions{ Logger: api.Logger, SignedTokenProvider: api.WorkspaceAppsProvider, DashboardURL: api.AccessURL, @@ -972,8 +951,8 @@ func Test_ResolveRequest(t *testing.T) { require.Equal(t, http.StatusSeeOther, w.StatusCode) // Note that we don't capture the owner UUID here because the apiKey // check/authorization exits early. - assertAuditApp(t, rw, r, auditor, appsBySlug[appNameOwner], uuid.Nil, nil) - require.Len(t, auditor.AuditLogs(), 1, "autit log entry for redirect") + assertConnLogContains(t, rw, r, connLogger, workspace, agentName, appNameOwner, database.ConnectionTypeWorkspaceApp, uuid.Nil) + require.Len(t, connLogger.ConnectionLogs(), 1) loc, err := w.Location() require.NoError(t, err) @@ -1012,7 +991,7 @@ func Test_ResolveRequest(t *testing.T) { AppSlugOrPort: appNameAgentUnhealthy, }).Normalize() - auditor := audit.NewMock() + connLogger := connectionlog.NewFake() auditableIP := testutil.RandomIPv6(t) rw := httptest.NewRecorder() @@ -1020,7 +999,7 @@ func Test_ResolveRequest(t *testing.T) { r.Header.Set(codersdk.SessionTokenHeader, client.SessionToken()) r.RemoteAddr = auditableIP - token, ok := workspaceappsResolveRequest(t, auditor, rw, r, workspaceapps.ResolveRequestOptions{ + token, ok := workspaceappsResolveRequest(t, connLogger, rw, r, workspaceapps.ResolveRequestOptions{ Logger: api.Logger, SignedTokenProvider: api.WorkspaceAppsProvider, DashboardURL: api.AccessURL, @@ -1034,8 +1013,8 @@ func Test_ResolveRequest(t *testing.T) { w := rw.Result() defer w.Body.Close() require.Equal(t, http.StatusBadGateway, w.StatusCode) - assertAuditApp(t, rw, r, auditor, appsBySlug[appNameAgentUnhealthy], me.ID, nil) - require.Len(t, auditor.AuditLogs(), 1, "single audit log") + assertConnLogContains(t, rw, r, connLogger, workspace, agentNameUnhealthy, appNameAgentUnhealthy, database.ConnectionTypeWorkspaceApp, me.ID) + require.Len(t, connLogger.ConnectionLogs(), 1) body, err := io.ReadAll(w.Body) require.NoError(t, err) @@ -1075,7 +1054,7 @@ func Test_ResolveRequest(t *testing.T) { AppSlugOrPort: appNameInitializing, }).Normalize() - auditor := audit.NewMock() + connLogger := connectionlog.NewFake() auditableIP := testutil.RandomIPv6(t) rw := httptest.NewRecorder() @@ -1083,7 +1062,7 @@ func Test_ResolveRequest(t *testing.T) { r.Header.Set(codersdk.SessionTokenHeader, client.SessionToken()) r.RemoteAddr = auditableIP - token, ok := workspaceappsResolveRequest(t, auditor, rw, r, workspaceapps.ResolveRequestOptions{ + token, ok := workspaceappsResolveRequest(t, connLogger, rw, r, workspaceapps.ResolveRequestOptions{ Logger: api.Logger, SignedTokenProvider: api.WorkspaceAppsProvider, DashboardURL: api.AccessURL, @@ -1093,8 +1072,8 @@ func Test_ResolveRequest(t *testing.T) { }) require.True(t, ok, "ResolveRequest failed, should pass even though app is initializing") require.NotNil(t, token) - assertAuditApp(t, rw, r, auditor, appsBySlug[token.AppSlugOrPort], me.ID, nil) - require.Len(t, auditor.AuditLogs(), 1, "single audit log") + assertConnLogContains(t, rw, r, connLogger, workspace, agentName, token.AppSlugOrPort, database.ConnectionTypeWorkspaceApp, me.ID) + require.Len(t, connLogger.ConnectionLogs(), 1) }) // Unhealthy apps are now permitted to connect anyways. This wasn't always @@ -1133,7 +1112,7 @@ func Test_ResolveRequest(t *testing.T) { AppSlugOrPort: appNameUnhealthy, }).Normalize() - auditor := audit.NewMock() + connLogger := connectionlog.NewFake() auditableIP := testutil.RandomIPv6(t) rw := httptest.NewRecorder() @@ -1141,7 +1120,7 @@ func Test_ResolveRequest(t *testing.T) { r.Header.Set(codersdk.SessionTokenHeader, client.SessionToken()) r.RemoteAddr = auditableIP - token, ok := workspaceappsResolveRequest(t, auditor, rw, r, workspaceapps.ResolveRequestOptions{ + token, ok := workspaceappsResolveRequest(t, connLogger, rw, r, workspaceapps.ResolveRequestOptions{ Logger: api.Logger, SignedTokenProvider: api.WorkspaceAppsProvider, DashboardURL: api.AccessURL, @@ -1151,11 +1130,11 @@ func Test_ResolveRequest(t *testing.T) { }) require.True(t, ok, "ResolveRequest failed, should pass even though app is unhealthy") require.NotNil(t, token) - assertAuditApp(t, rw, r, auditor, appsBySlug[token.AppSlugOrPort], me.ID, nil) - require.Len(t, auditor.AuditLogs(), 1, "single audit log") + assertConnLogContains(t, rw, r, connLogger, workspace, agentName, token.AppSlugOrPort, database.ConnectionTypeWorkspaceApp, me.ID) + require.Len(t, connLogger.ConnectionLogs(), 1) }) - t.Run("AuditLogging", func(t *testing.T) { + t.Run("ConnectionLogging", func(t *testing.T) { t.Parallel() for _, app := range allApps { @@ -1168,18 +1147,18 @@ func Test_ResolveRequest(t *testing.T) { AppSlugOrPort: app, }).Normalize() - auditor := audit.NewMock() + connLogger := connectionlog.NewFake() auditableIP := testutil.RandomIPv6(t) t.Log("app", app) - // First request, new audit log. + // First request, new connection log. rw := httptest.NewRecorder() r := httptest.NewRequest("GET", "/app", nil) r.Header.Set(codersdk.SessionTokenHeader, client.SessionToken()) r.RemoteAddr = auditableIP - _, ok := workspaceappsResolveRequest(t, auditor, rw, r, workspaceapps.ResolveRequestOptions{ + _, ok := workspaceappsResolveRequest(t, connLogger, rw, r, workspaceapps.ResolveRequestOptions{ Logger: api.Logger, SignedTokenProvider: api.WorkspaceAppsProvider, DashboardURL: api.AccessURL, @@ -1188,8 +1167,8 @@ func Test_ResolveRequest(t *testing.T) { AppRequest: req, }) require.True(t, ok) - assertAuditApp(t, rw, r, auditor, appsBySlug[app], me.ID, nil) - require.Len(t, auditor.AuditLogs(), 1, "single audit log") + assertConnLogContains(t, rw, r, connLogger, workspace, agentName, app, database.ConnectionTypeWorkspaceApp, me.ID) + require.Len(t, connLogger.ConnectionLogs(), 1) // Second request, no audit log because the session is active. rw = httptest.NewRecorder() @@ -1197,7 +1176,7 @@ func Test_ResolveRequest(t *testing.T) { r.Header.Set(codersdk.SessionTokenHeader, client.SessionToken()) r.RemoteAddr = auditableIP - _, ok = workspaceappsResolveRequest(t, auditor, rw, r, workspaceapps.ResolveRequestOptions{ + _, ok = workspaceappsResolveRequest(t, connLogger, rw, r, workspaceapps.ResolveRequestOptions{ Logger: api.Logger, SignedTokenProvider: api.WorkspaceAppsProvider, DashboardURL: api.AccessURL, @@ -1206,7 +1185,7 @@ func Test_ResolveRequest(t *testing.T) { AppRequest: req, }) require.True(t, ok) - require.Len(t, auditor.AuditLogs(), 1, "single audit log, previous session active") + require.Len(t, connLogger.ConnectionLogs(), 1, "single connection log, previous session active") // Third request, session timed out, new audit log. rw = httptest.NewRecorder() @@ -1214,7 +1193,7 @@ func Test_ResolveRequest(t *testing.T) { r.Header.Set(codersdk.SessionTokenHeader, client.SessionToken()) r.RemoteAddr = auditableIP - sessionTimeoutTokenProvider := signedTokenProviderWithAuditor(t, api.WorkspaceAppsProvider, auditor, 0) + sessionTimeoutTokenProvider := signedTokenProviderWithConnLogger(t, api.WorkspaceAppsProvider, connLogger, 0) _, ok = workspaceappsResolveRequest(t, nil, rw, r, workspaceapps.ResolveRequestOptions{ Logger: api.Logger, SignedTokenProvider: sessionTimeoutTokenProvider, @@ -1224,8 +1203,8 @@ func Test_ResolveRequest(t *testing.T) { AppRequest: req, }) require.True(t, ok) - assertAuditApp(t, rw, r, auditor, appsBySlug[app], me.ID, nil) - require.Len(t, auditor.AuditLogs(), 2, "two audit logs, session timed out") + assertConnLogContains(t, rw, r, connLogger, workspace, agentName, app, database.ConnectionTypeWorkspaceApp, me.ID) + require.Len(t, connLogger.ConnectionLogs(), 2, "two connection logs, session timed out") // Fourth request, new IP produces new audit log. auditableIP = testutil.RandomIPv6(t) @@ -1234,7 +1213,7 @@ func Test_ResolveRequest(t *testing.T) { r.Header.Set(codersdk.SessionTokenHeader, client.SessionToken()) r.RemoteAddr = auditableIP - _, ok = workspaceappsResolveRequest(t, auditor, rw, r, workspaceapps.ResolveRequestOptions{ + _, ok = workspaceappsResolveRequest(t, connLogger, rw, r, workspaceapps.ResolveRequestOptions{ Logger: api.Logger, SignedTokenProvider: api.WorkspaceAppsProvider, DashboardURL: api.AccessURL, @@ -1243,16 +1222,16 @@ func Test_ResolveRequest(t *testing.T) { AppRequest: req, }) require.True(t, ok) - assertAuditApp(t, rw, r, auditor, appsBySlug[app], me.ID, nil) - require.Len(t, auditor.AuditLogs(), 3, "three audit logs, new IP") + assertConnLogContains(t, rw, r, connLogger, workspace, agentName, app, database.ConnectionTypeWorkspaceApp, me.ID) + require.Len(t, connLogger.ConnectionLogs(), 3, "three connection logs, new IP") } }) } -func workspaceappsResolveRequest(t testing.TB, auditor audit.Auditor, w http.ResponseWriter, r *http.Request, opts workspaceapps.ResolveRequestOptions) (token *workspaceapps.SignedToken, ok bool) { +func workspaceappsResolveRequest(t testing.TB, connLogger connectionlog.ConnectionLogger, w http.ResponseWriter, r *http.Request, opts workspaceapps.ResolveRequestOptions) (token *workspaceapps.SignedToken, ok bool) { t.Helper() - if opts.SignedTokenProvider != nil && auditor != nil { - opts.SignedTokenProvider = signedTokenProviderWithAuditor(t, opts.SignedTokenProvider, auditor, time.Hour) + if opts.SignedTokenProvider != nil && connLogger != nil { + opts.SignedTokenProvider = signedTokenProviderWithConnLogger(t, opts.SignedTokenProvider, connLogger, time.Hour) } tracing.StatusWriterMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -1264,52 +1243,41 @@ func workspaceappsResolveRequest(t testing.TB, auditor audit.Auditor, w http.Res return token, ok } -func signedTokenProviderWithAuditor(t testing.TB, provider workspaceapps.SignedTokenProvider, auditor audit.Auditor, sessionTimeout time.Duration) workspaceapps.SignedTokenProvider { +func signedTokenProviderWithConnLogger(t testing.TB, provider workspaceapps.SignedTokenProvider, connLogger connectionlog.ConnectionLogger, sessionTimeout time.Duration) workspaceapps.SignedTokenProvider { t.Helper() p, ok := provider.(*workspaceapps.DBTokenProvider) require.True(t, ok, "provider is not a DBTokenProvider") shallowCopy := *p - shallowCopy.Auditor = &atomic.Pointer[audit.Auditor]{} - shallowCopy.Auditor.Store(&auditor) + shallowCopy.ConnectionLogger = &atomic.Pointer[connectionlog.ConnectionLogger]{} + shallowCopy.ConnectionLogger.Store(&connLogger) shallowCopy.WorkspaceAppAuditSessionTimeout = sessionTimeout return &shallowCopy } -func auditAsserter[T audit.Auditable](workspace codersdk.Workspace) func(t testing.TB, rr *httptest.ResponseRecorder, r *http.Request, auditor *audit.MockAuditor, auditable T, userID uuid.UUID, additionalFields map[string]any) { - return func(t testing.TB, rr *httptest.ResponseRecorder, r *http.Request, auditor *audit.MockAuditor, auditable T, userID uuid.UUID, additionalFields map[string]any) { - t.Helper() - - resp := rr.Result() - defer resp.Body.Close() - - require.True(t, auditor.Contains(t, database.AuditLog{ - OrganizationID: workspace.OrganizationID, - Action: database.AuditActionOpen, - ResourceType: audit.ResourceType(auditable), - ResourceID: audit.ResourceID(auditable), - ResourceTarget: audit.ResourceTarget(auditable), - UserID: userID, - Ip: audit.ParseIP(r.RemoteAddr), - UserAgent: sql.NullString{Valid: r.UserAgent() != "", String: r.UserAgent()}, - StatusCode: int32(resp.StatusCode), //nolint:gosec - }), "audit log") - - // Verify additional fields, assume the last log entry. - alog := auditor.AuditLogs()[len(auditor.AuditLogs())-1] - - // Contains does not verify uuid.Nil. - if userID == uuid.Nil { - require.Equal(t, uuid.Nil, alog.UserID, "unauthenticated user") - } +func assertConnLogContains(t *testing.T, rr *httptest.ResponseRecorder, r *http.Request, connLogger *connectionlog.FakeConnectionLogger, workspace codersdk.Workspace, agentName string, slugOrPort string, typ database.ConnectionType, userID uuid.UUID) { + t.Helper() - add := make(map[string]any) - if len(alog.AdditionalFields) > 0 { - err := json.Unmarshal([]byte(alog.AdditionalFields), &add) - require.NoError(t, err, "audit log unmarhsal additional fields") - } - for k, v := range additionalFields { - require.Equal(t, v, add[k], "audit log additional field %s: additional fields: %v", k, add) - } - } + resp := rr.Result() + defer resp.Body.Close() + + require.True(t, connLogger.Contains(t, database.UpsertConnectionLogParams{ + OrganizationID: workspace.OrganizationID, + WorkspaceOwnerID: workspace.OwnerID, + WorkspaceID: workspace.ID, + WorkspaceName: workspace.Name, + AgentName: agentName, + Type: typ, + Ip: database.ParseIP(r.RemoteAddr), + UserAgent: sql.NullString{Valid: r.UserAgent() != "", String: r.UserAgent()}, + Code: sql.NullInt32{ + Int32: int32(resp.StatusCode), // nolint:gosec + Valid: true, + }, + UserID: uuid.NullUUID{ + UUID: userID, + Valid: true, + }, + SlugOrPort: sql.NullString{Valid: slugOrPort != "", String: slugOrPort}, + })) } diff --git a/coderd/workspacebuilds.go b/coderd/workspacebuilds.go index 6c3321625c9b3..c8b1008280b09 100644 --- a/coderd/workspacebuilds.go +++ b/coderd/workspacebuilds.go @@ -581,10 +581,24 @@ func (api *API) notifyWorkspaceUpdated( // @Produce json // @Tags Builds // @Param workspacebuild path string true "Workspace build ID" +// @Param expect_status query string false "Expected status of the job. If expect_status is supplied, the request will be rejected with 412 Precondition Failed if the job doesn't match the state when performing the cancellation." Enums(running, pending) // @Success 200 {object} codersdk.Response // @Router /workspacebuilds/{workspacebuild}/cancel [patch] func (api *API) patchCancelWorkspaceBuild(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() + + var expectStatus database.ProvisionerJobStatus + expectStatusParam := r.URL.Query().Get("expect_status") + if expectStatusParam != "" { + if expectStatusParam != "running" && expectStatusParam != "pending" { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: fmt.Sprintf("Invalid expect_status %q. Only 'running' or 'pending' are allowed.", expectStatusParam), + }) + return + } + expectStatus = database.ProvisionerJobStatus(expectStatusParam) + } + workspaceBuild := httpmw.WorkspaceBuildParam(r) workspace, err := api.Database.GetWorkspaceByID(ctx, workspaceBuild.WorkspaceID) if err != nil { @@ -594,58 +608,78 @@ func (api *API) patchCancelWorkspaceBuild(rw http.ResponseWriter, r *http.Reques return } - valid, err := api.verifyUserCanCancelWorkspaceBuilds(ctx, httpmw.APIKey(r).UserID, workspace.TemplateID) - if err != nil { - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Internal error verifying permission to cancel workspace build.", - Detail: err.Error(), - }) - return - } - if !valid { - httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{ - Message: "User is not allowed to cancel workspace builds. Owner role is required.", - }) - return + code := http.StatusInternalServerError + resp := codersdk.Response{ + Message: "Internal error canceling workspace build.", } + err = api.Database.InTx(func(db database.Store) error { + valid, err := verifyUserCanCancelWorkspaceBuilds(ctx, db, httpmw.APIKey(r).UserID, workspace.TemplateID, expectStatus) + if err != nil { + code = http.StatusInternalServerError + resp.Message = "Internal error verifying permission to cancel workspace build." + resp.Detail = err.Error() - job, err := api.Database.GetProvisionerJobByID(ctx, workspaceBuild.JobID) - if err != nil { - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Internal error fetching provisioner job.", - Detail: err.Error(), - }) - return - } - if job.CompletedAt.Valid { - httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: "Job has already completed!", - }) - return - } - if job.CanceledAt.Valid { - httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: "Job has already been marked as canceled!", + return xerrors.Errorf("verify user can cancel workspace builds: %w", err) + } + if !valid { + code = http.StatusForbidden + resp.Message = "User is not allowed to cancel workspace builds. Owner role is required." + + return xerrors.New("user is not allowed to cancel workspace builds") + } + + job, err := db.GetProvisionerJobByIDForUpdate(ctx, workspaceBuild.JobID) + if err != nil { + code = http.StatusInternalServerError + resp.Message = "Internal error fetching provisioner job." + resp.Detail = err.Error() + + return xerrors.Errorf("get provisioner job: %w", err) + } + if job.CompletedAt.Valid { + code = http.StatusBadRequest + resp.Message = "Job has already completed!" + + return xerrors.New("job has already completed") + } + if job.CanceledAt.Valid { + code = http.StatusBadRequest + resp.Message = "Job has already been marked as canceled!" + + return xerrors.New("job has already been marked as canceled") + } + + if expectStatus != "" && job.JobStatus != expectStatus { + code = http.StatusPreconditionFailed + resp.Message = "Job is not in the expected state." + + return xerrors.Errorf("job is not in the expected state: expected: %q, got %q", expectStatus, job.JobStatus) + } + + err = db.UpdateProvisionerJobWithCancelByID(ctx, database.UpdateProvisionerJobWithCancelByIDParams{ + ID: job.ID, + CanceledAt: sql.NullTime{ + Time: dbtime.Now(), + Valid: true, + }, + CompletedAt: sql.NullTime{ + Time: dbtime.Now(), + // If the job is running, don't mark it completed! + Valid: !job.WorkerID.Valid, + }, }) - return - } - err = api.Database.UpdateProvisionerJobWithCancelByID(ctx, database.UpdateProvisionerJobWithCancelByIDParams{ - ID: job.ID, - CanceledAt: sql.NullTime{ - Time: dbtime.Now(), - Valid: true, - }, - CompletedAt: sql.NullTime{ - Time: dbtime.Now(), - // If the job is running, don't mark it completed! - Valid: !job.WorkerID.Valid, - }, - }) + if err != nil { + code = http.StatusInternalServerError + resp.Message = "Internal error updating provisioner job." + resp.Detail = err.Error() + + return xerrors.Errorf("update provisioner job: %w", err) + } + + return nil + }, nil) if err != nil { - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Internal error updating provisioner job.", - Detail: err.Error(), - }) + httpapi.Write(ctx, rw, code, resp) return } @@ -659,8 +693,14 @@ func (api *API) patchCancelWorkspaceBuild(rw http.ResponseWriter, r *http.Reques }) } -func (api *API) verifyUserCanCancelWorkspaceBuilds(ctx context.Context, userID uuid.UUID, templateID uuid.UUID) (bool, error) { - template, err := api.Database.GetTemplateByID(ctx, templateID) +func verifyUserCanCancelWorkspaceBuilds(ctx context.Context, store database.Store, userID uuid.UUID, templateID uuid.UUID, jobStatus database.ProvisionerJobStatus) (bool, error) { + // If the jobStatus is pending, we always allow cancellation regardless of + // the template setting as it's non-destructive to Terraform resources. + if jobStatus == database.ProvisionerJobStatusPending { + return true, nil + } + + template, err := store.GetTemplateByID(ctx, templateID) if err != nil { return false, xerrors.New("no template exists for this workspace") } @@ -669,7 +709,7 @@ func (api *API) verifyUserCanCancelWorkspaceBuilds(ctx context.Context, userID u return true, nil // all users can cancel workspace builds } - user, err := api.Database.GetUserByID(ctx, userID) + user, err := store.GetUserByID(ctx, userID) if err != nil { return false, xerrors.New("user does not exist") } diff --git a/coderd/workspacebuilds_test.go b/coderd/workspacebuilds_test.go index b9d32a00b139a..ebab0770b71b4 100644 --- a/coderd/workspacebuilds_test.go +++ b/coderd/workspacebuilds_test.go @@ -573,7 +573,7 @@ func TestPatchCancelWorkspaceBuild(t *testing.T) { build, err = client.WorkspaceBuild(ctx, workspace.LatestBuild.ID) return assert.NoError(t, err) && build.Job.Status == codersdk.ProvisionerJobRunning }, testutil.WaitShort, testutil.IntervalFast) - err := client.CancelWorkspaceBuild(ctx, build.ID) + err := client.CancelWorkspaceBuild(ctx, build.ID, codersdk.CancelWorkspaceBuildParams{}) require.NoError(t, err) require.Eventually(t, func() bool { var err error @@ -618,11 +618,199 @@ func TestPatchCancelWorkspaceBuild(t *testing.T) { build, err = userClient.WorkspaceBuild(ctx, workspace.LatestBuild.ID) return assert.NoError(t, err) && build.Job.Status == codersdk.ProvisionerJobRunning }, testutil.WaitShort, testutil.IntervalFast) - err := userClient.CancelWorkspaceBuild(ctx, build.ID) + err := userClient.CancelWorkspaceBuild(ctx, build.ID, codersdk.CancelWorkspaceBuildParams{}) var apiErr *codersdk.Error require.ErrorAs(t, err, &apiErr) require.Equal(t, http.StatusForbidden, apiErr.StatusCode()) }) + + t.Run("Cancel with expect_state=pending", func(t *testing.T) { + t.Parallel() + if !dbtestutil.WillUsePostgres() { + t.Skip("this test requires postgres") + } + // Given: a coderd instance with a provisioner daemon + store, ps, db := dbtestutil.NewDBWithSQLDB(t) + client, closeDaemon := coderdtest.NewWithProvisionerCloser(t, &coderdtest.Options{ + Database: store, + Pubsub: ps, + IncludeProvisionerDaemon: true, + }) + defer closeDaemon.Close() + // Given: a user, template, and workspace + user := coderdtest.CreateFirstUser(t, client) + version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil) + coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID) + template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) + workspace := coderdtest.CreateWorkspace(t, client, template.ID) + coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) + + // Stop the provisioner daemon. + require.NoError(t, closeDaemon.Close()) + ctx := testutil.Context(t, testutil.WaitLong) + // Given: no provisioner daemons exist. + _, err := db.ExecContext(ctx, `DELETE FROM provisioner_daemons;`) + require.NoError(t, err) + + // When: a new workspace build is created + build, err := client.CreateWorkspaceBuild(ctx, workspace.ID, codersdk.CreateWorkspaceBuildRequest{ + TemplateVersionID: template.ActiveVersionID, + Transition: codersdk.WorkspaceTransitionStart, + }) + // Then: the request should succeed. + require.NoError(t, err) + // Then: the provisioner job should remain pending. + require.Equal(t, codersdk.ProvisionerJobPending, build.Job.Status) + + // Then: the response should indicate no provisioners are available. + if assert.NotNil(t, build.MatchedProvisioners) { + assert.Zero(t, build.MatchedProvisioners.Count) + assert.Zero(t, build.MatchedProvisioners.Available) + assert.Zero(t, build.MatchedProvisioners.MostRecentlySeen.Time) + assert.False(t, build.MatchedProvisioners.MostRecentlySeen.Valid) + } + + // When: the workspace build is canceled + err = client.CancelWorkspaceBuild(ctx, build.ID, codersdk.CancelWorkspaceBuildParams{ + ExpectStatus: codersdk.CancelWorkspaceBuildStatusPending, + }) + require.NoError(t, err) + + // Then: the workspace build should be canceled. + build, err = client.WorkspaceBuild(ctx, build.ID) + require.NoError(t, err) + require.Equal(t, codersdk.ProvisionerJobCanceled, build.Job.Status) + }) + + t.Run("Cancel with expect_state=pending when job is running - should fail with 412", func(t *testing.T) { + t.Parallel() + + client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) + user := coderdtest.CreateFirstUser(t, client) + version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{ + Parse: echo.ParseComplete, + ProvisionApply: []*proto.Response{{ + Type: &proto.Response_Log{ + Log: &proto.Log{}, + }, + }}, + ProvisionPlan: echo.PlanComplete, + }) + coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID) + template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) + workspace := coderdtest.CreateWorkspace(t, client, template.ID) + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + var build codersdk.WorkspaceBuild + require.Eventually(t, func() bool { + var err error + build, err = client.WorkspaceBuild(ctx, workspace.LatestBuild.ID) + return assert.NoError(t, err) && build.Job.Status == codersdk.ProvisionerJobRunning + }, testutil.WaitShort, testutil.IntervalFast) + + // When: a cancel request is made with expect_state=pending + err := client.CancelWorkspaceBuild(ctx, build.ID, codersdk.CancelWorkspaceBuildParams{ + ExpectStatus: codersdk.CancelWorkspaceBuildStatusPending, + }) + // Then: the request should fail with 412. + require.Error(t, err) + + var apiErr *codersdk.Error + require.ErrorAs(t, err, &apiErr) + require.Equal(t, http.StatusPreconditionFailed, apiErr.StatusCode()) + }) + + t.Run("Cancel with expect_state=running when job is pending - should fail with 412", func(t *testing.T) { + t.Parallel() + if !dbtestutil.WillUsePostgres() { + t.Skip("this test requires postgres") + } + // Given: a coderd instance with a provisioner daemon + store, ps, db := dbtestutil.NewDBWithSQLDB(t) + client, closeDaemon := coderdtest.NewWithProvisionerCloser(t, &coderdtest.Options{ + Database: store, + Pubsub: ps, + IncludeProvisionerDaemon: true, + }) + defer closeDaemon.Close() + // Given: a user, template, and workspace + user := coderdtest.CreateFirstUser(t, client) + version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil) + coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID) + template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) + workspace := coderdtest.CreateWorkspace(t, client, template.ID) + coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) + + // Stop the provisioner daemon. + require.NoError(t, closeDaemon.Close()) + ctx := testutil.Context(t, testutil.WaitLong) + // Given: no provisioner daemons exist. + _, err := db.ExecContext(ctx, `DELETE FROM provisioner_daemons;`) + require.NoError(t, err) + + // When: a new workspace build is created + build, err := client.CreateWorkspaceBuild(ctx, workspace.ID, codersdk.CreateWorkspaceBuildRequest{ + TemplateVersionID: template.ActiveVersionID, + Transition: codersdk.WorkspaceTransitionStart, + }) + // Then: the request should succeed. + require.NoError(t, err) + // Then: the provisioner job should remain pending. + require.Equal(t, codersdk.ProvisionerJobPending, build.Job.Status) + + // Then: the response should indicate no provisioners are available. + if assert.NotNil(t, build.MatchedProvisioners) { + assert.Zero(t, build.MatchedProvisioners.Count) + assert.Zero(t, build.MatchedProvisioners.Available) + assert.Zero(t, build.MatchedProvisioners.MostRecentlySeen.Time) + assert.False(t, build.MatchedProvisioners.MostRecentlySeen.Valid) + } + + // When: a cancel request is made with expect_state=running + err = client.CancelWorkspaceBuild(ctx, build.ID, codersdk.CancelWorkspaceBuildParams{ + ExpectStatus: codersdk.CancelWorkspaceBuildStatusRunning, + }) + // Then: the request should fail with 412. + require.Error(t, err) + + var apiErr *codersdk.Error + require.ErrorAs(t, err, &apiErr) + require.Equal(t, http.StatusPreconditionFailed, apiErr.StatusCode()) + }) + + t.Run("Cancel with expect_state - invalid status", func(t *testing.T) { + t.Parallel() + + // Given: a coderd instance with a provisioner daemon + client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) + user := coderdtest.CreateFirstUser(t, client) + version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{ + Parse: echo.ParseComplete, + ProvisionApply: []*proto.Response{{ + Type: &proto.Response_Log{ + Log: &proto.Log{}, + }, + }}, + ProvisionPlan: echo.PlanComplete, + }) + coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID) + template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) + workspace := coderdtest.CreateWorkspace(t, client, template.ID) + + ctx := testutil.Context(t, testutil.WaitLong) + + // When: a cancel request is made with invalid expect_state + err := client.CancelWorkspaceBuild(ctx, workspace.LatestBuild.ID, codersdk.CancelWorkspaceBuildParams{ + ExpectStatus: "invalid_status", + }) + // Then: the request should fail with 400. + var apiErr *codersdk.Error + require.ErrorAs(t, err, &apiErr) + require.Equal(t, http.StatusBadRequest, apiErr.StatusCode()) + require.Contains(t, apiErr.Message, "Invalid expect_status") + }) } func TestWorkspaceBuildResources(t *testing.T) { @@ -968,7 +1156,7 @@ func TestWorkspaceBuildStatus(t *testing.T) { _ = closeDaemon.Close() // after successful cancel is "canceled" build = coderdtest.CreateWorkspaceBuild(t, client, workspace, database.WorkspaceTransitionStart) - err = client.CancelWorkspaceBuild(ctx, build.ID) + err = client.CancelWorkspaceBuild(ctx, build.ID, codersdk.CancelWorkspaceBuildParams{}) require.NoError(t, err) workspace, err = client.Workspace(ctx, workspace.ID) diff --git a/coderd/workspaces_test.go b/coderd/workspaces_test.go index d51a228a3f7a1..f99e3b9e3ec3f 100644 --- a/coderd/workspaces_test.go +++ b/coderd/workspaces_test.go @@ -8,7 +8,6 @@ import ( "fmt" "math" "net/http" - "os" "slices" "strings" "testing" @@ -1426,9 +1425,6 @@ func TestWorkspaceByOwnerAndName(t *testing.T) { // TestWorkspaceFilterAllStatus tests workspace status is correctly set given a set of conditions. func TestWorkspaceFilterAllStatus(t *testing.T) { t.Parallel() - if os.Getenv("DB") != "" { - t.Skip(`This test takes too long with an actual database. Takes 10s on local machine`) - } // For this test, we do not care about permissions. // nolint:gocritic // unit testing @@ -3245,7 +3241,7 @@ func TestWorkspaceWatcher(t *testing.T) { closeFunc.Close() build := coderdtest.CreateWorkspaceBuild(t, client, workspace, database.WorkspaceTransitionStart) wait("first is for the workspace build itself", nil) - err = client.CancelWorkspaceBuild(ctx, build.ID) + err = client.CancelWorkspaceBuild(ctx, build.ID, codersdk.CancelWorkspaceBuildParams{}) require.NoError(t, err) wait("second is for the build cancel", nil) } diff --git a/codersdk/agentsdk/convert.go b/codersdk/agentsdk/convert.go index d01c9e527fce9..775ce06c73c69 100644 --- a/codersdk/agentsdk/convert.go +++ b/codersdk/agentsdk/convert.go @@ -408,40 +408,6 @@ func ProtoFromLifecycleState(s codersdk.WorkspaceAgentLifecycle) (proto.Lifecycl return proto.Lifecycle_State(caps), nil } -func ConnectionTypeFromProto(typ proto.Connection_Type) (ConnectionType, error) { - switch typ { - case proto.Connection_TYPE_UNSPECIFIED: - return ConnectionTypeUnspecified, nil - case proto.Connection_SSH: - return ConnectionTypeSSH, nil - case proto.Connection_VSCODE: - return ConnectionTypeVSCode, nil - case proto.Connection_JETBRAINS: - return ConnectionTypeJetBrains, nil - case proto.Connection_RECONNECTING_PTY: - return ConnectionTypeReconnectingPTY, nil - default: - return "", xerrors.Errorf("unknown connection type %q", typ) - } -} - -func ProtoFromConnectionType(typ ConnectionType) (proto.Connection_Type, error) { - switch typ { - case ConnectionTypeUnspecified: - return proto.Connection_TYPE_UNSPECIFIED, nil - case ConnectionTypeSSH: - return proto.Connection_SSH, nil - case ConnectionTypeVSCode: - return proto.Connection_VSCODE, nil - case ConnectionTypeJetBrains: - return proto.Connection_JETBRAINS, nil - case ConnectionTypeReconnectingPTY: - return proto.Connection_RECONNECTING_PTY, nil - default: - return 0, xerrors.Errorf("unknown connection type %q", typ) - } -} - func DevcontainersFromProto(pdcs []*proto.WorkspaceAgentDevcontainer) ([]codersdk.WorkspaceAgentDevcontainer, error) { ret := make([]codersdk.WorkspaceAgentDevcontainer, len(pdcs)) for i, pdc := range pdcs { diff --git a/codersdk/audit.go b/codersdk/audit.go index 49e597845b964..1e529202b5285 100644 --- a/codersdk/audit.go +++ b/codersdk/audit.go @@ -38,8 +38,12 @@ const ( ResourceTypeIdpSyncSettingsOrganization ResourceType = "idp_sync_settings_organization" ResourceTypeIdpSyncSettingsGroup ResourceType = "idp_sync_settings_group" ResourceTypeIdpSyncSettingsRole ResourceType = "idp_sync_settings_role" - ResourceTypeWorkspaceAgent ResourceType = "workspace_agent" - ResourceTypeWorkspaceApp ResourceType = "workspace_app" + // Deprecated: Workspace Agent connections are now included in the + // connection log. + ResourceTypeWorkspaceAgent ResourceType = "workspace_agent" + // Deprecated: Workspace App connections are now included in the + // connection log. + ResourceTypeWorkspaceApp ResourceType = "workspace_app" ) func (r ResourceType) FriendlyString() string { @@ -113,10 +117,17 @@ const ( AuditActionLogout AuditAction = "logout" AuditActionRegister AuditAction = "register" AuditActionRequestPasswordReset AuditAction = "request_password_reset" - AuditActionConnect AuditAction = "connect" - AuditActionDisconnect AuditAction = "disconnect" - AuditActionOpen AuditAction = "open" - AuditActionClose AuditAction = "close" + // Deprecated: Workspace connections are now included in the + // connection log. + AuditActionConnect AuditAction = "connect" + // Deprecated: Workspace disconnections are now included in the + // connection log. + AuditActionDisconnect AuditAction = "disconnect" + // Deprecated: Workspace App connections are now included in the + // connection log. + AuditActionOpen AuditAction = "open" + // Deprecated: This action is unused. + AuditActionClose AuditAction = "close" ) func (a AuditAction) Friendly() string { diff --git a/codersdk/deployment.go b/codersdk/deployment.go index b24e321b8e434..61c3c805a29a9 100644 --- a/codersdk/deployment.go +++ b/codersdk/deployment.go @@ -67,6 +67,7 @@ type FeatureName string const ( FeatureUserLimit FeatureName = "user_limit" FeatureAuditLog FeatureName = "audit_log" + FeatureConnectionLog FeatureName = "connection_log" FeatureBrowserOnly FeatureName = "browser_only" FeatureSCIM FeatureName = "scim" FeatureTemplateRBAC FeatureName = "template_rbac" @@ -90,6 +91,7 @@ const ( var FeatureNames = []FeatureName{ FeatureUserLimit, FeatureAuditLog, + FeatureConnectionLog, FeatureBrowserOnly, FeatureSCIM, FeatureTemplateRBAC, @@ -354,7 +356,6 @@ type DeploymentValues struct { ProxyTrustedHeaders serpent.StringArray `json:"proxy_trusted_headers,omitempty" typescript:",notnull"` ProxyTrustedOrigins serpent.StringArray `json:"proxy_trusted_origins,omitempty" typescript:",notnull"` CacheDir serpent.String `json:"cache_directory,omitempty" typescript:",notnull"` - InMemoryDatabase serpent.Bool `json:"in_memory_database,omitempty" typescript:",notnull"` EphemeralDeployment serpent.Bool `json:"ephemeral_deployment,omitempty" typescript:",notnull"` PostgresURL serpent.String `json:"pg_connection_url,omitempty" typescript:",notnull"` PostgresAuth string `json:"pg_auth,omitempty" typescript:",notnull"` @@ -2404,15 +2405,6 @@ func (c *DeploymentValues) Options() serpent.OptionSet { Value: &c.CacheDir, YAML: "cacheDir", }, - { - Name: "In Memory Database", - Description: "Controls whether data will be stored in an in-memory database.", - Flag: "in-memory", - Env: "CODER_IN_MEMORY", - Hidden: true, - Value: &c.InMemoryDatabase, - YAML: "inMemoryDatabase", - }, { Name: "Ephemeral Deployment", Description: "Controls whether Coder data, including built-in Postgres, will be stored in a temporary directory and deleted when the server is stopped.", diff --git a/codersdk/rbacresources_gen.go b/codersdk/rbacresources_gen.go index 5ffcfed6b4c35..3e22d29c73297 100644 --- a/codersdk/rbacresources_gen.go +++ b/codersdk/rbacresources_gen.go @@ -9,6 +9,7 @@ const ( ResourceAssignOrgRole RBACResource = "assign_org_role" ResourceAssignRole RBACResource = "assign_role" ResourceAuditLog RBACResource = "audit_log" + ResourceConnectionLog RBACResource = "connection_log" ResourceCryptoKey RBACResource = "crypto_key" ResourceDebugInfo RBACResource = "debug_info" ResourceDeploymentConfig RBACResource = "deployment_config" @@ -72,6 +73,7 @@ var RBACResourceActions = map[RBACResource][]RBACAction{ ResourceAssignOrgRole: {ActionAssign, ActionCreate, ActionDelete, ActionRead, ActionUnassign, ActionUpdate}, ResourceAssignRole: {ActionAssign, ActionRead, ActionUnassign}, ResourceAuditLog: {ActionCreate, ActionRead}, + ResourceConnectionLog: {ActionRead, ActionUpdate}, ResourceCryptoKey: {ActionCreate, ActionDelete, ActionRead, ActionUpdate}, ResourceDebugInfo: {ActionRead}, ResourceDeploymentConfig: {ActionRead, ActionUpdate}, diff --git a/codersdk/toolsdk/toolsdk_test.go b/codersdk/toolsdk/toolsdk_test.go index d08191a614a99..09b919a428a84 100644 --- a/codersdk/toolsdk/toolsdk_test.go +++ b/codersdk/toolsdk/toolsdk_test.go @@ -164,7 +164,7 @@ func TestTools(t *testing.T) { // Important: cancel the build. We don't run any provisioners, so this // will remain in the 'pending' state indefinitely. - require.NoError(t, client.CancelWorkspaceBuild(ctx, result.ID)) + require.NoError(t, client.CancelWorkspaceBuild(ctx, result.ID, codersdk.CancelWorkspaceBuildParams{})) }) t.Run("Start", func(t *testing.T) { @@ -184,7 +184,7 @@ func TestTools(t *testing.T) { // Important: cancel the build. We don't run any provisioners, so this // will remain in the 'pending' state indefinitely. - require.NoError(t, client.CancelWorkspaceBuild(ctx, result.ID)) + require.NoError(t, client.CancelWorkspaceBuild(ctx, result.ID, codersdk.CancelWorkspaceBuildParams{})) }) t.Run("TemplateVersionChange", func(t *testing.T) { @@ -216,7 +216,7 @@ func TestTools(t *testing.T) { require.Equal(t, r.Workspace.ID.String(), updateBuild.WorkspaceID.String()) require.Equal(t, newVersion.TemplateVersion.ID.String(), updateBuild.TemplateVersionID.String()) // Cancel the build so it doesn't remain in the 'pending' state indefinitely. - require.NoError(t, client.CancelWorkspaceBuild(ctx, updateBuild.ID)) + require.NoError(t, client.CancelWorkspaceBuild(ctx, updateBuild.ID, codersdk.CancelWorkspaceBuildParams{})) // Roll back to the original version rollbackBuild, err := testTool(t, toolsdk.CreateWorkspaceBuild, tb, toolsdk.CreateWorkspaceBuildArgs{ @@ -229,7 +229,7 @@ func TestTools(t *testing.T) { require.Equal(t, r.Workspace.ID.String(), rollbackBuild.WorkspaceID.String()) require.Equal(t, originalVersionID.String(), rollbackBuild.TemplateVersionID.String()) // Cancel the build so it doesn't remain in the 'pending' state indefinitely. - require.NoError(t, client.CancelWorkspaceBuild(ctx, rollbackBuild.ID)) + require.NoError(t, client.CancelWorkspaceBuild(ctx, rollbackBuild.ID, codersdk.CancelWorkspaceBuildParams{})) }) }) diff --git a/codersdk/workspaceagents.go b/codersdk/workspaceagents.go index 2bfae8aac36cf..1eb37bb07c989 100644 --- a/codersdk/workspaceagents.go +++ b/codersdk/workspaceagents.go @@ -421,6 +421,19 @@ type WorkspaceAgentDevcontainer struct { Error string `json:"error,omitempty"` } +func (d WorkspaceAgentDevcontainer) Equals(other WorkspaceAgentDevcontainer) bool { + return d.ID == other.ID && + d.Name == other.Name && + d.WorkspaceFolder == other.WorkspaceFolder && + d.Status == other.Status && + d.Dirty == other.Dirty && + (d.Container == nil && other.Container == nil || + (d.Container != nil && other.Container != nil && d.Container.ID == other.Container.ID)) && + (d.Agent == nil && other.Agent == nil || + (d.Agent != nil && other.Agent != nil && *d.Agent == *other.Agent)) && + d.Error == other.Error +} + // WorkspaceAgentDevcontainerAgent represents the sub agent for a // devcontainer. type WorkspaceAgentDevcontainerAgent struct { @@ -520,6 +533,40 @@ func (c *Client) WorkspaceAgentListContainers(ctx context.Context, agentID uuid. return cr, json.NewDecoder(res.Body).Decode(&cr) } +func (c *Client) WatchWorkspaceAgentContainers(ctx context.Context, agentID uuid.UUID) (<-chan WorkspaceAgentListContainersResponse, io.Closer, error) { + reqURL, err := c.URL.Parse(fmt.Sprintf("/api/v2/workspaceagents/%s/containers/watch", agentID)) + if err != nil { + return nil, nil, err + } + + jar, err := cookiejar.New(nil) + if err != nil { + return nil, nil, xerrors.Errorf("create cookie jar: %w", err) + } + + jar.SetCookies(reqURL, []*http.Cookie{{ + Name: SessionTokenCookie, + Value: c.SessionToken(), + }}) + + conn, res, err := websocket.Dial(ctx, reqURL.String(), &websocket.DialOptions{ + CompressionMode: websocket.CompressionDisabled, + HTTPClient: &http.Client{ + Jar: jar, + Transport: c.HTTPClient.Transport, + }, + }) + if err != nil { + if res == nil { + return nil, nil, err + } + return nil, nil, ReadBodyAsError(res) + } + + d := wsjson.NewDecoder[WorkspaceAgentListContainersResponse](conn, websocket.MessageText, c.logger) + return d.Chan(), d, nil +} + // WorkspaceAgentRecreateDevcontainer recreates the devcontainer with the given ID. func (c *Client) WorkspaceAgentRecreateDevcontainer(ctx context.Context, agentID uuid.UUID, devcontainerID string) (Response, error) { res, err := c.Request(ctx, http.MethodPost, fmt.Sprintf("/api/v2/workspaceagents/%s/containers/devcontainers/%s/recreate", agentID, devcontainerID), nil) diff --git a/codersdk/workspacebuilds.go b/codersdk/workspacebuilds.go index 328b8bc26566f..0960c6789dea4 100644 --- a/codersdk/workspacebuilds.go +++ b/codersdk/workspacebuilds.go @@ -37,15 +37,18 @@ const ( type BuildReason string const ( - // "initiator" is used when a workspace build is triggered by a user. + // BuildReasonInitiator "initiator" is used when a workspace build is triggered by a user. // Combined with the initiator id/username, it indicates which user initiated the build. BuildReasonInitiator BuildReason = "initiator" - // "autostart" is used when a build to start a workspace is triggered by Autostart. + // BuildReasonAutostart "autostart" is used when a build to start a workspace is triggered by Autostart. // The initiator id/username in this case is the workspace owner and can be ignored. BuildReasonAutostart BuildReason = "autostart" - // "autostop" is used when a build to stop a workspace is triggered by Autostop. + // BuildReasonAutostop "autostop" is used when a build to stop a workspace is triggered by Autostop. // The initiator id/username in this case is the workspace owner and can be ignored. BuildReasonAutostop BuildReason = "autostop" + // BuildReasonDormancy "dormancy" is used when a build to stop a workspace is triggered due to inactivity (dormancy). + // The initiator id/username in this case is the workspace owner and can be ignored. + BuildReasonDormancy BuildReason = "dormancy" ) // WorkspaceBuild is an at-point representation of a workspace state. @@ -123,9 +126,29 @@ func (c *Client) WorkspaceBuild(ctx context.Context, id uuid.UUID) (WorkspaceBui return workspaceBuild, json.NewDecoder(res.Body).Decode(&workspaceBuild) } +type CancelWorkspaceBuildStatus string + +const ( + CancelWorkspaceBuildStatusRunning CancelWorkspaceBuildStatus = "running" + CancelWorkspaceBuildStatusPending CancelWorkspaceBuildStatus = "pending" +) + +type CancelWorkspaceBuildParams struct { + // ExpectStatus ensures the build is in the expected status before canceling. + ExpectStatus CancelWorkspaceBuildStatus `json:"expect_status,omitempty"` +} + +func (c *CancelWorkspaceBuildParams) asRequestOption() RequestOption { + return func(r *http.Request) { + q := r.URL.Query() + q.Set("expect_status", string(c.ExpectStatus)) + r.URL.RawQuery = q.Encode() + } +} + // CancelWorkspaceBuild marks a workspace build job as canceled. -func (c *Client) CancelWorkspaceBuild(ctx context.Context, id uuid.UUID) error { - res, err := c.Request(ctx, http.MethodPatch, fmt.Sprintf("/api/v2/workspacebuilds/%s/cancel", id), nil) +func (c *Client) CancelWorkspaceBuild(ctx context.Context, id uuid.UUID, req CancelWorkspaceBuildParams) error { + res, err := c.Request(ctx, http.MethodPatch, fmt.Sprintf("/api/v2/workspacebuilds/%s/cancel", id), nil, req.asRequestOption()) if err != nil { return err } diff --git a/codersdk/workspacesdk/agentconn.go b/codersdk/workspacesdk/agentconn.go index ee0b36e5a0c23..ce66d5e1b8a70 100644 --- a/codersdk/workspacesdk/agentconn.go +++ b/codersdk/workspacesdk/agentconn.go @@ -20,10 +20,14 @@ import ( "tailscale.com/ipn/ipnstate" "tailscale.com/net/speedtest" + "cdr.dev/slog" + "github.com/coder/coder/v2/coderd/tracing" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/healthsdk" + "github.com/coder/coder/v2/codersdk/wsjson" "github.com/coder/coder/v2/tailnet" + "github.com/coder/websocket" ) // NewAgentConn creates a new WorkspaceAgentConn. `conn` may be unique @@ -387,6 +391,30 @@ func (c *AgentConn) ListContainers(ctx context.Context) (codersdk.WorkspaceAgent return resp, json.NewDecoder(res.Body).Decode(&resp) } +func (c *AgentConn) WatchContainers(ctx context.Context, logger slog.Logger) (<-chan codersdk.WorkspaceAgentListContainersResponse, io.Closer, error) { + ctx, span := tracing.StartSpan(ctx) + defer span.End() + + host := net.JoinHostPort(c.agentAddress().String(), strconv.Itoa(AgentHTTPAPIServerPort)) + url := fmt.Sprintf("http://%s%s", host, "/api/v0/containers/watch") + + conn, res, err := websocket.Dial(ctx, url, &websocket.DialOptions{ + HTTPClient: c.apiClient(), + }) + if err != nil { + if res == nil { + return nil, nil, err + } + return nil, nil, codersdk.ReadBodyAsError(res) + } + if res != nil && res.Body != nil { + defer res.Body.Close() + } + + d := wsjson.NewDecoder[codersdk.WorkspaceAgentListContainersResponse](conn, websocket.MessageText, logger) + return d.Chan(), d, nil +} + // RecreateDevcontainer recreates a devcontainer with the given container. // This is a blocking call and will wait for the container to be recreated. func (c *AgentConn) RecreateDevcontainer(ctx context.Context, devcontainerID string) (codersdk.Response, error) { diff --git a/docs/about/contributing/backend.md b/docs/about/contributing/backend.md index fd1a80dc6b73c..ad5d91bcda879 100644 --- a/docs/about/contributing/backend.md +++ b/docs/about/contributing/backend.md @@ -16,9 +16,9 @@ Need help or have questions? Join the conversation on our [Discord server](https To understand how the backend fits into the broader system, we recommend reviewing the following resources: -* [General Concepts](../admin/infrastructure/validated-architectures/index.md#general-concepts): Essential concepts and language used to describe how Coder is structured and operated. +* [General Concepts](../../admin/infrastructure/validated-architectures/index.md#general-concepts): Essential concepts and language used to describe how Coder is structured and operated. -* [Architecture](../admin/infrastructure/architecture.md): A high-level overview of the infrastructure layout, key services, and how components interact. +* [Architecture](../../admin/infrastructure/architecture.md): A high-level overview of the infrastructure layout, key services, and how components interact. These sections provide the necessary context for navigating and contributing to the backend effectively. @@ -68,7 +68,6 @@ The Coder backend is organized into multiple packages and directories, each with * [dbauthz](https://github.com/coder/coder/tree/main/coderd/database/dbauthz): AuthZ wrappers for database queries, ideally, every query should verify first if the accessor is eligible to see the query results. * [dbfake](https://github.com/coder/coder/tree/main/coderd/database/dbfake): helper functions to quickly prepare the initial database state for testing purposes (e.g. create N healthy workspaces and templates), operates on higher level than [dbgen](https://github.com/coder/coder/tree/main/coderd/database/dbgen) * [dbgen](https://github.com/coder/coder/tree/main/coderd/database/dbgen): helper functions to insert raw records to the database store, used for testing purposes - * [dbmem](https://github.com/coder/coder/tree/main/coderd/database/dbmem): in-memory implementation of the database store, ideally, every real query should have a complimentary Go implementation * [dbmock](https://github.com/coder/coder/tree/main/coderd/database/dbmock): a store wrapper for database queries, useful to verify if the function has been called, used for testing purposes * [dbpurge](https://github.com/coder/coder/tree/main/coderd/database/dbpurge): simple wrapper for periodic database cleanup operations * [migrations](https://github.com/coder/coder/tree/main/coderd/database/migrations): an ordered list of up/down database migrations, use `./create_migration.sh my_migration_name` to modify the database schema @@ -169,9 +168,9 @@ There are two types of fixtures that are used to test that migrations don't break existing Coder deployments: * Partial fixtures - [`migrations/testdata/fixtures`](../../coderd/database/migrations/testdata/fixtures) + [`migrations/testdata/fixtures`](../../../coderd/database/migrations/testdata/fixtures) * Full database dumps - [`migrations/testdata/full_dumps`](../../coderd/database/migrations/testdata/full_dumps) + [`migrations/testdata/full_dumps`](../../../coderd/database/migrations/testdata/full_dumps) Both types behave like database migrations (they also [`migrate`](https://github.com/golang-migrate/migrate)). Their behavior mirrors @@ -194,7 +193,7 @@ To add a new partial fixture, run the following command: ``` Then add some queries to insert data and commit the file to the repo. See -[`000024_example.up.sql`](../../coderd/database/migrations/testdata/fixtures/000024_example.up.sql) +[`000024_example.up.sql`](../../../coderd/database/migrations/testdata/fixtures/000024_example.up.sql) for an example. To create a full dump, run a fully fledged Coder deployment and use it to diff --git a/docs/about/screenshots.md b/docs/about/screenshots.md index ddf71b823f7fc..dff7ea75946d8 100644 --- a/docs/about/screenshots.md +++ b/docs/about/screenshots.md @@ -2,19 +2,19 @@ ## Log in -![Install Coder in your cloud or air-gapped on-premises. Developers simply log in via their browser to access their Workspaces.](../images/screenshots/login.png) +![Install Coder in your cloud or air-gapped on-premises. Developers simply log in via their browser to access their Workspaces.](../images/screenshots/coder-login.png) Install Coder in your cloud or air-gapped on-premises. Developers simply log in via their browser to access their Workspaces. ## Templates -![Developers provision their own ephemeral Workspaces in minutes using pre-defined Templates that include approved tooling and infrastructure.](../images/screenshots/templates_listing.png) +![Developers provision their own ephemeral Workspaces in minutes using pre-defined Templates that include approved tooling and infrastructure.](../images/screenshots/templates-listing.png) Developers provision their own ephemeral Workspaces in minutes using pre-defined Templates that include approved tooling and infrastructure. -![Template administrators can either create a new Template from scratch or choose a Starter Template](../images/screenshots/starter_templates.png) +![Template administrators can either create a new Template from scratch or choose a Starter Template](../images/screenshots/starter-templates.png) Template administrators can either create a new Template from scratch or choose a Starter Template. @@ -26,25 +26,25 @@ underlying infrastructure that Coder Workspaces run on. ## Workspaces -![Developers create and delete their own workspaces. Coder administrators can easily enforce Workspace scheduling and autostop policies to ensure idle Workspaces don’t burn unnecessary cloud budget.](../images/screenshots/workspaces_listing.png) +![Developers create and delete their own workspaces. Coder administrators can easily enforce Workspace scheduling and autostop policies to ensure idle Workspaces don’t burn unnecessary cloud budget.](../images/screenshots/workspaces-listing.png) Developers create and delete their own workspaces. Coder administrators can easily enforce Workspace scheduling and autostop policies to ensure idle Workspaces don’t burn unnecessary cloud budget. -![Developers launch their favorite web-based or desktop IDE, browse files, or access their Workspace’s Terminal.](../images/screenshots/workspace_launch.png) +![Developers launch their favorite web-based or desktop IDE, browse files, or access their Workspace’s Terminal.](../images/screenshots/workspace-running-with-topbar.png) Developers launch their favorite web-based or desktop IDE, browse files, or access their Workspace’s Terminal. ## Administration -![Coder administrators can access Template usage insights to understand which Templates are most popular and how well they perform for developers.](../images/screenshots/templates_insights.png) +![Coder administrators can access Template usage insights to understand which Templates are most popular and how well they perform for developers.](../images/screenshots/template-insights.png) Coder administrators can access Template usage insights to understand which Templates are most popular and how well they perform for developers. -![Coder administrators can control *every* aspect of their Coder deployment.](../images/screenshots/settings.png) +![Coder administrators can control *every* aspect of their Coder deployment.](../images/screenshots/admin-settings.png) Coder administrators can control *every* aspect of their Coder deployment. diff --git a/docs/admin/templates/extending-templates/dynamic-parameters.md b/docs/admin/templates/extending-templates/dynamic-parameters.md new file mode 100644 index 0000000000000..d676c3bcf3148 --- /dev/null +++ b/docs/admin/templates/extending-templates/dynamic-parameters.md @@ -0,0 +1,833 @@ +# Dynamic Parameters + +Coder v2.24.0 introduces Dynamic Parameters to extend Coder [parameters](./parameters.md) with conditional form controls, +enriched input types, and user identity awareness. +This allows template authors to create interactive workspace creation forms with more environment customization, +and that means fewer templates to maintain. + +![Dynamic Parameters in Action](https://i.imgur.com/uR8mpRJ.gif) + +All parameters are parsed from Terraform, so your workspace creation forms live in the same location as your provisioning code. +You can use all the native Terraform functions and conditionality to create a self-service tooling catalog for every template. + +Administrators can use Dynamic Parameters to: + +- Create parameters which respond to the inputs of others. +- Only show parameters when other input criteria are met. +- Only show select parameters to target Coder roles or groups. + +You can try the Dynamic Parameter syntax and any of the code examples below in the +[Parameters Playground](https://playground.coder.app/parameters). +You should experiment with parameters in the playground before you upgrade live templates. + +## When You Should Upgrade to Dynamic Parameters + +While Dynamic parameters introduce a variety of new powerful tools, all functionality is backwards compatible with +existing coder templates. +When you opt-in to the new experience, no functional changes will be applied to your production parameters. + +Some reasons Coder template admins should try Dynamic Parameters: + +- You maintain or support many templates for teams with unique expectations or use cases. +- You want to selectively expose privileged workspace options to admins, power users, or personas. +- You want to make the workspace creation flow more ergonomic for developers. + +Dynamic Parameters help you reduce template duplication by setting the conditions for which users should see specific parameters. +They reduce the potential complexity of user-facing configuration by allowing administrators to organize a long list of options into interactive, branching paths for workspace customization. +They allow you to set resource guardrails by referencing Coder identity in the `coder_workspace_owner` data source. + +## How to enable Dynamic Parameters + +In Coder v2.24.0, you can opt-in to Dynamic Parameters on a per-template basis. + +1. Go to your template's settings and enable the **Enable dynamic parameters for workspace creation** option. + + ![Enable dynamic parameters for workspace creation](../../../images/admin/templates/extend-templates/dyn-params/enable-dynamic-parameters.png) + +1. Update your template to use version >=2.4.0 of the Coder provider with the following Terraform block. + + ```terraform + terraform { + required_providers { + coder = { + source = "coder/coder" + version = ">=2.4.0" + } + } + } + ``` + +1. This enables Dynamic Parameters in the template. + Add some [conditional parameters](#available-form-input-types). + + Note that these new features must be declared in your Terraform to start leveraging Dynamic Parameters. + +1. Save and publish the template. + +1. Users should see the updated workspace creation form. + +Dynamic Parameters features are backwards compatible, so all existing templates may be upgraded in-place. +If you decide to revert to the legacy flow later, disable Dynamic Parameters in the template's settings. + +## Features and Capabilities + +Dynamic Parameters introduces three primary enhancements to the standard parameter system: + +- **Conditional Parameters** + + - Parameters can respond to changes in other parameters + - Show or hide parameters based on other selections + - Modify validation rules conditionally + - Create branching paths in workspace creation forms + +- **Reference User Properties** + + - Read user data at build time from [`coder_workspace_owner`](https://registry.terraform.io/providers/coder/coder/latest/docs/data-sources/workspace_owner) + - Conditionally hide parameters based on user's role + - Change parameter options based on user groups + - Reference user name, groups, and roles in parameter text + +- **Additional Form Inputs** + + - Searchable dropdown lists for easier selection + - Multi-select options for choosing multiple items + - Secret text inputs for sensitive information + - Slider input for disk size, model temperature + - Disabled parameters to display immutable data + +> [!IMPORTANT] +> Dynamic Parameters does not support external data fetching via HTTP endpoints at workspace build time. +> +> External fetching would introduce unpredictability in workspace builds after publishing a template. +> Instead, we recommend that template administrators pull in any required data for a workspace build as a +> [locals](https://developer.hashicorp.com/terraform/tutorials/configuration-language/locals) or JSON file, +> then reference that data in Terraform. +> +> If you have a use case for external data fetching, please file an issue or create a discussion in the +> [Coder GitHub repository](https://github.com/coder/coder). + +## Available Form Input Types + +Dynamic Parameters supports a variety of form types to create rich, interactive user experiences. + +![Old vs New Parameters](../../../images/admin/templates/extend-templates/dyn-params/dynamic-params-compare.png) + +Different parameter types support different form types. +You can specify the form type using the +[`form_type`](https://registry.terraform.io/providers/coder/coder/latest/docs/data-sources/parameter#form_type-1) attribute. + +The **Options** column in the table below indicates whether the form type supports options (**Yes**) or doesn't support them (**No**). +When supported, you can specify options using one or more `option` blocks in your parameter definition, +where each option has a `name` (displayed to the user) and a `value` (used in your template logic). + +| Form Type | Parameter Types | Options | Notes | +|----------------|--------------------------------------------|---------|------------------------------------------------------------------------------------------------------------------------| +| `radio` | `string`, `number`, `bool`, `list(string)` | Yes | Radio buttons for selecting a single option with all choices visible at once.
The classic parameter option. | +| `dropdown` | `string`, `number` | Yes | Choose a single option from a searchable dropdown list.
Default for `string` or `number` parameters with options. | +| `multi-select` | `list(string)` | Yes | Select multiple items from a list with checkboxes. | +| `tag-select` | `list(string)` | No | Default for `list(string)` parameters without options. | +| `input` | `string`, `number` | No | Standard single-line text input field.
Default for `string/number` parameters without options. | +| `textarea` | `string` | No | Multi-line text input field for longer content. | +| `slider` | `number` | No | Slider selection with min/max validation for numeric values. | +| `checkbox` | `bool` | No | A single checkbox for boolean parameters.
Default for boolean parameters. | + +### Available Styling Options + +The `coder_parameter` resource supports an additional `styling` attribute for special cosmetic changes that can be used +to further customize the workspace creation form. + +This can be used for: + +- Masking private inputs +- Marking inputs as read-only +- Setting placeholder text + +Note that the `styling` attribute should not be used as a governance tool, since it only changes how the interactive +form is displayed. +Users can avoid restrictions like `disabled` if they create a workspace via the CLI. + +This attribute accepts JSON like so: + +```terraform +data "coder_parameter" "styled_parameter" { + ... + styling = jsonencode({ + disabled = true + }) +} +``` + +Not all styling attributes are supported by all form types, use the reference below for syntax: + +| Styling Option | Compatible parameter types | Compatible form types | Notes | +|----------------|----------------------------|-----------------------|-------------------------------------------------------------------------------------| +| `disabled` | All parameter types | All form types | Disables the form control when `true`. | +| `placeholder` | `string` | `input`, `textarea` | Sets placeholder text.
This is overwritten by user entry. | +| `mask_input` | `string`, `number` | `input`, `textarea` | Masks inputs as asterisks (`*`). Used to cosmetically hide token or password entry. | + +## Use Case Examples + +### New Form Types + +The following examples show some basic usage of the +[`form_type`](https://registry.terraform.io/providers/coder/coder/latest/docs/data-sources/parameter#form_type-1) +attribute [explained above](#available-form-input-types). +These are used to change the input style of form controls in the create workspace form. + +
+ +### Dropdowns + +Single-select parameters with options can use the `form_type="dropdown"` attribute for better organization. + +[Try dropdown lists on the Parameter Playground](https://playground.coder.app/parameters/kgNBpjnz7x) + +```terraform +locals { + ides = [ + "VS Code", + "JetBrains IntelliJ", + "PyCharm", + "GoLand", + "WebStorm", + "Vim", + "Emacs", + "Neovim" + ] +} + +data "coder_parameter" "ides_dropdown" { + name = "ides_dropdown" + display_name = "Select your IDEs" + type = "string" + + form_type = "dropdown" + + dynamic "option" { + for_each = local.ides + content { + name = option.value + value = option.value + } + } +} +``` + +### Text Area + +The large text entry option can be used to enter long strings like AI prompts, scripts, or natural language. + +[Try textarea parameters on the Parameter Playground](https://playground.coder.app/parameters/RCAHA1Oi1_) + +```terraform + +data "coder_parameter" "text_area" { + name = "text_area" + description = "Enter multi-line text." + mutable = true + display_name = "Textarea" + + form_type = "textarea" + type = "string" + + default = <<-EOT + This is an example of multi-line text entry. + + The 'textarea' form_type is useful for + - AI prompts + - Scripts + - Read-only info (try the 'disabled' styling option) + EOT +} + +``` + +### Multi-select + +Multi-select parameters allow users to select one or many options from a single list of options. +For example, adding multiple IDEs with a single parameter. + +[Try multi-select parameters on the Parameter Playground](https://playground.coder.app/parameters/XogX54JV_f) + +```terraform +locals { + ides = [ + "VS Code", "JetBrains IntelliJ", + "GoLand", "WebStorm", + "Vim", "Emacs", + "Neovim", "PyCharm", + "Databricks", "Jupyter Notebook", + ] +} + +data "coder_parameter" "ide_selector" { + name = "ide_selector" + description = "Choose any IDEs for your workspace." + mutable = true + display_name = "Select multiple IDEs" + + + # Allows users to select multiple IDEs from the list. + form_type = "multi-select" + type = "list(string)" + + + dynamic "option" { + for_each = local.ides + content { + name = option.value + value = option.value + } + } +} +``` + +### Radio + +Radio buttons are used to select a single option with high visibility. +This is the original styling for list parameters. + +[Try radio parameters on the Parameter Playground](https://playground.coder.app/parameters/3OMDp5ANZI). + +```terraform +data "coder_parameter" "environment" { + name = "environment" + display_name = "Environment" + description = "An example of environment listing with the radio form type." + type = "string" + default = "dev" + + form_type = "radio" + + option { + name = "Development" + value = "dev" + } + option { + name = "Experimental" + value = "exp" + } + option { + name = "Staging" + value = "staging" + } + option { + name = "Production" + value = "prod" + } +} +``` + +### Checkboxes + +A single checkbox for boolean values. +This can be used for a TOS confirmation or to expose advanced options. + +[Try checkbox parameters on the Parameters Playground](https://playground.coder.app/parameters/ycWuQJk2Py). + +```terraform +data "coder_parameter" "enable_gpu" { + name = "enable_gpu" + display_name = "Enable GPU" + type = "bool" + form_type = "checkbox" # This is the default for boolean parameters + default = false +} +``` + +### Slider + +Sliders can be used for configuration on a linear scale, like resource allocation. +The `validation` block is used to constrain (or clamp) the minimum and maximum values for the parameter. + +[Try slider parameters on the Parameters Playground](https://playground.coder.app/parameters/RsBNcWVvfm). + +```terraform +data "coder_parameter" "cpu_cores" { + name = "cpu_cores" + display_name = "CPU Cores" + type = "number" + form_type = "slider" + default = 2 + validation { + min = 1 + max = 8 + } +} +``` + +### Masked Input + +Masked input parameters can be used to visually hide secret values in the workspace creation form. +Note that this does not secure information on the backend and is purely cosmetic. + +[Try private parameters on the Parameters Playground](https://playground.coder.app/parameters/wmiP7FM3Za). + +Note: This text may not be properly hidden in the Playground. +The `mask_input` styling attribute is supported in v2.24.0 and later. + +```terraform +data "coder_parameter" "private_api_key" { + name = "private_api_key" + display_name = "Your super secret API key" + type = "string" + + form_type = "input" # | "textarea" + + # Will render as "**********" + default = "privatekey" + + styling = jsonencode({ + mask_input = true + }) +} +``` + +
+ +### Conditional Parameters + +Using native Terraform syntax and parameter attributes like `count`, we can allow some parameters to react to user inputs. + +This means: + +- Hiding parameters unless activated +- Conditionally setting default values +- Changing available options based on other parameter inputs + +Use these in conjunction to build intuitive, reactive forms for workspace creation. + +
+ +### Hide/Show Options + +Use Terraform conditionals and the `count` block to allow a checkbox to expose or hide a subsequent parameter. + +[Try conditional parameters on the Parameter Playground](https://playground.coder.app/parameters/xmG5MKEGNM). + +```terraform +data "coder_parameter" "show_cpu_cores" { + name = "show_cpu_cores" + display_name = "Toggles next parameter" + description = "Select this checkbox to show the CPU cores parameter." + type = "bool" + form_type = "checkbox" + default = false + order = 1 +} + +data "coder_parameter" "cpu_cores" { + # Only show this parameter if the previous box is selected. + count = data.coder_parameter.show_cpu_cores.value ? 1 : 0 + + name = "cpu_cores" + display_name = "CPU Cores" + type = "number" + form_type = "slider" + default = 2 + order = 2 + validation { + min = 1 + max = 8 + } +} +``` + +### Dynamic Defaults + +Influence which option is selected by default for one parameter based on the selection of another. +This allows you to suggest an option dynamically without strict enforcement. + +[Try dynamic defaults in the Parameter Playground](https://playground.coder.app/parameters/DEi-Bi6DVe). + +```terraform +locals { + ides = [ + "VS Code", + "IntelliJ", "GoLand", + "WebStorm", "PyCharm", + "Databricks", "Jupyter Notebook", + ] + mlkit_ides = jsonencode(["Databricks", "PyCharm"]) + core_ides = jsonencode(["VS Code", "GoLand"]) +} + +data "coder_parameter" "git_repo" { + name = "git_repo" + display_name = "Git repo" + description = "Select a git repo to work on." + order = 1 + mutable = true + type = "string" + form_type = "dropdown" + + option { + # A Go-heavy repository + name = "coder/coder" + value = "coder/coder" + } + + option { + # A python-heavy repository + name = "coder/mlkit" + value = "coder/mlkit" + } +} + +data "coder_parameter" "ide_selector" { + # Conditionally expose this parameter + count = try(data.coder_parameter.git_repo.value, "") != "" ? 1 : 0 + + name = "ide_selector" + description = "Choose any IDEs for your workspace." + order = 2 + mutable = true + + display_name = "Select IDEs" + form_type = "multi-select" + type = "list(string)" + default = try(data.coder_parameter.git_repo.value, "") == "coder/mlkit" ? local.mlkit_ides : local.core_ides + + + dynamic "option" { + for_each = local.ides + content { + name = option.value + value = option.value + } + } +} +``` + +## Dynamic Validation + +A parameter's validation block can leverage inputs from other parameters. + +[Try dynamic validation in the Parameter Playground](https://playground.coder.app/parameters/sdbzXxagJ4). + +```terraform +data "coder_parameter" "git_repo" { + name = "git_repo" + display_name = "Git repo" + description = "Select a git repo to work on." + order = 1 + mutable = true + type = "string" + form_type = "dropdown" + + option { + # A Go-heavy repository + name = "coder/coder" + value = "coder/coder" + } + + option { + # A python-heavy repository + name = "coder/mlkit" + value = "coder/mlkit" + } +} + +data "coder_parameter" "cpu_cores" { + # Only show this parameter if the previous box is selected. + count = data.coder_parameter.show_cpu_cores.value ? 1 : 0 + + name = "cpu_cores" + display_name = "CPU Cores" + type = "number" + form_type = "slider" + order = 2 + + # Dynamically set default + default = try(data.coder_parameter.git_repo.value, "") == "coder/mlkit" ? 12 : 6 + + validation { + min = 1 + + # Dynamically set max validation + max = try(data.coder_parameter.git_repo.value, "") == "coder/mlkit" ? 16 : 8 + } +} +``` + + + +
+ +## Identity-Aware Parameters (Premium) + +Premium users can leverage our roles and groups to conditionally expose or change parameters based on user identity. +This is helpful for establishing governance policy directly in the workspace creation form, +rather than creating multiple templates to manage RBAC. + +User identity is referenced in Terraform by reading the +[`coder_workspace_owner`](https://registry.terraform.io/providers/coder/coder/latest/docs/data-sources/workspace_owner) data source. + +
+ +### Role-aware Options + +Template administrators often want to expose certain experimental or unstable options only to those with elevated roles. +You can now do this by setting `count` based on a user's group or role, referencing the +[`coder_workspace_owner`](https://registry.terraform.io/providers/coder/coder/latest/docs/data-sources/workspace_owner) +data source. + +[Try out admin-only options in the Playground](https://playground.coder.app/parameters/5Gn9W3hYs7). + +```terraform + +locals { + roles = [for r in data.coder_workspace_owner.me.rbac_roles: r.name] + is_admin = contains(data.coder_workspace_owner.me.groups, "admin") + has_admin_role = contains(local.roles, "owner") +} + +data "coder_workspace_owner" "me" {} + +data "coder_parameter" "advanced_settings" { + # This parameter is only visible when the user is an administrator + count = local.is_admin ? 1 : 0 + + name = "advanced_settings" + display_name = "Add an arbitrary script" + description = "An advanced configuration option only available to admins." + type = "string" + form_type = "textarea" + mutable = true + order = 5 + + styling = jsonencode({ + placeholder = <<-EOT + #!/usr/bin/env bash + while true; do + echo "hello world" + sleep 1 + done + EOT + }) +} + +``` + +### Group-aware Regions + +You can expose regions depending on which group a user belongs to. +This way developers can't accidentally induce low-latency with world-spanning connections. + +[Try user-aware regions in the parameter playground](https://playground.coder.app/parameters/tBD-mbZRGm) + +```terraform + +locals { + eu_regions = [ + "eu-west-1 (Ireland)", + "eu-central-1 (Frankfurt)", + "eu-north-1 (Stockholm)", + "eu-west-3 (Paris)", + "eu-south-1 (Milan)" + ] + + us_regions = [ + "us-east-1 (N. Virginia)", + "us-west-1 (California)", + "us-west-2 (Oregon)", + "us-east-2 (Ohio)", + "us-central-1 (Iowa)" + ] + + eu_group_name = "eu-helsinki" + is_eu_dev = contains(data.coder_workspace_owner.me.groups, local.eu_group_name) + region_desc_tag = local.is_eu_dev ? "european" : "american" +} + +data "coder_parameter" "region" { + name = "region" + display_name = "Select a Region" + description = "Select from ${local.region_desc_tag} region options." + type = "string" + form_type = "dropdown" + order = 5 + default = local.is_eu_dev ? local.eu_regions[0] : local.us_regions[0] + + dynamic "option" { + for_each = local.is_eu_dev ? local.eu_regions : local.us_regions + content { + name = option.value + value = option.value + description = "Use ${option.value}" + } + } +} +``` + +### Groups As Namespaces + +A slightly unorthodox way to leverage this is by filling the selections of a parameter from the user's groups. +Some users associate groups with namespaces, such as Kubernetes, then allow users to target that namespace with a parameter. + +[Try groups as options in the Parameter Playground](https://playground.coder.app/parameters/lKbU53nYjl). + +```terraform +locals { + groups = data.coder_workspace_owner.me.groups +} + +data "coder_workspace_owner" "me" {} + +data "coder_parameter" "your_groups" { + type = "string" + name = "your_groups" + display_name = "Your Coder Groups" + description = "Select your namespace..." + default = "target-${local.groups[0]}" + mutable = true + form_type = "dropdown" + + dynamic "option" { + # options populated directly from groups + for_each = local.groups + content { + name = option.value + # Native terraform be used to decorate output + value = "target-${option.value}" + } + } +} +``` + +
+ +## Troubleshooting + +Dynamic Parameters is still in Beta as we continue to polish and improve the workflow. +If you have any issues during upgrade, please file an issue in our +[GitHub repository](https://github.com/coder/coder/issues/new?labels=parameters) and include a +[Playground link](https://playground.coder.app/parameters) where applicable. +We appreciate the feedback and look forward to what the community creates with this system! + +You can also [search or track the list of known issues](https://github.com/coder/coder/issues?q=is%3Aissue%20state%3Aopen%20label%3Aparameters). + +You can share anything you build with Dynamic Parameters in our [Discord](https://coder.com/chat). + +### Enabled Dynamic Parameters, but my template looks the same + +Ensure that the following version requirements are met: + +- `coder/coder`: >= [v2.24.0](https://github.com/coder/coder/releases/tag/v2.24.0) +- `coder/terraform-provider-coder`: >= [v2.5.3](https://github.com/coder/terraform-provider-coder/releases/tag/v2.5.3) + +Enabling Dynamic Parameters on an existing template requires administrators to publish a new template version. +This will resolve the necessary template metadata to render the form. + +### Reverting to classic parameters + +To revert Dynamic Parameters on a template: + +1. Prepare your template by removing any conditional logic or user data references in parameters. +1. As a template administrator or owner, go to your template's settings: + + **Templates** > **Your template** > **Settings** + +1. Uncheck the **Enable dynamic parameters for workspace creation** option. +1. Create a new template version and publish to the active version. + +### Template variables not showing up + +In beta, template variables are not supported in Dynamic Parameters. + +This issue will be resolved by the next minor release of `coder/coder`. +If this is issue is blocking your usage of Dynamic Parameters, please let us know in [this thread](https://github.com/coder/coder/issues/18671). + +### Can I use registry modules with Dynamic Parameters? + +Yes, registry modules are supported with Dynamic Parameters. + +Unless explicitly mentioned, no registry modules require Dynamic Parameters. +Later in 2025, more registry modules will be converted to Dynamic Parameters to improve their UX. + +In the meantime, you can safely convert existing templates and build new parameters on top of the functionality provided in the registry. diff --git a/docs/admin/templates/extending-templates/parameters.md b/docs/admin/templates/extending-templates/parameters.md index 6977d4d3b4c0b..5b380645c1b36 100644 --- a/docs/admin/templates/extending-templates/parameters.md +++ b/docs/admin/templates/extending-templates/parameters.md @@ -207,8 +207,8 @@ data "coder_parameter" "dotfiles_url" { Immutable parameters can only be set in these situations: - Creating a workspace for the first time. -- Updating a workspace to a new template version. This sets the initial value - for required parameters. +- Updating a workspace to a new template version. + This sets the initial value for required parameters. The idea is to prevent users from modifying fragile or persistent workspace resources like volumes, regions, and so on. @@ -224,9 +224,8 @@ data "coder_parameter" "region" { } ``` -You can modify a parameter's `mutable` attribute state anytime. In case of -emergency, you can temporarily allow for changing immutable parameters to fix an -operational issue, but it is not advised to overuse this opportunity. +If a required parameter is empty or if the workspace creation page detects an incompatibility between selected +parameters, the **Create workspace** button is disabled until the issues are resolved. ## Ephemeral parameters @@ -394,544 +393,10 @@ parameters in one of two ways: ## Dynamic Parameters (beta) -Dynamic Parameters enhances Coder's existing parameter system with real-time validation, -conditional parameter behavior, and richer input types. -This feature allows template authors to create more interactive and responsive workspace creation experiences. +Coder v2.24.0 introduces [Dynamic Parameters](./dynamic-parameters.md) to extend the existing parameter system with +conditional form controls, enriched input types, and user identity awareness. +This feature allows template authors to create interactive workspace creation forms, meaning more environment +customization and fewer templates to maintain. -### Enable Dynamic Parameters - -To use Dynamic Parameters, enable the experiment flag or set the environment variable. - -Note that as of v2.22.0, Dynamic parameters are an unsafe experiment and will not be enabled with the experiment wildcard. - -
- -#### Flag - -```shell -coder server --experiments=dynamic-parameters -``` - -#### Env Variable - -```shell -CODER_EXPERIMENTS=dynamic-parameters -``` - -
- -Dynamic Parameters also require version >=2.4.0 of the Coder provider. - -Enable the experiment, then include the following at the top of your template: - -```terraform -terraform { - required_providers { - coder = { - source = "coder/coder" - version = ">=2.4.0" - } - } -} -``` - -Once enabled, users can toggle between the experimental and classic interfaces during -workspace creation using an escape hatch in the workspace creation form. - -## Features and Capabilities - -Dynamic Parameters introduces three primary enhancements to the standard parameter system: - -- **Conditional Parameters** - - - Parameters can respond to changes in other parameters - - Show or hide parameters based on other selections - - Modify validation rules conditionally - - Create branching paths in workspace creation forms - -- **Reference User Properties** - - - Read user data at build time from [`coder_workspace_owner`](https://registry.terraform.io/providers/coder/coder/latest/docs/data-sources/workspace_owner) - - Conditionally hide parameters based on user's role - - Change parameter options based on user groups - - Reference user name in parameters - -- **Additional Form Inputs** - - - Searchable dropdown lists for easier selection - - Multi-select options for choosing multiple items - - Secret text inputs for sensitive information - - Key-value pair inputs for complex data - - Button parameters for toggling sections - -## Available Form Input Types - -Dynamic Parameters supports a variety of form types to create rich, interactive user experiences. - -You can specify the form type using the `form_type` property. -Different parameter types support different form types. - -The "Options" column in the table below indicates whether the form type requires options to be defined (Yes) or doesn't support/require them (No). When required, options are specified using one or more `option` blocks in your parameter definition, where each option has a `name` (displayed to the user) and a `value` (used in your template logic). - -| Form Type | Parameter Types | Options | Notes | -|----------------|--------------------------------------------|---------|------------------------------------------------------------------------------------------------------------------------------| -| `checkbox` | `bool` | No | A single checkbox for boolean parameters. Default for boolean parameters. | -| `dropdown` | `string`, `number` | Yes | Searchable dropdown list for choosing a single option from a list. Default for `string` or `number` parameters with options. | -| `input` | `string`, `number` | No | Standard single-line text input field. Default for string/number parameters without options. | -| `multi-select` | `list(string)` | Yes | Select multiple items from a list with checkboxes. | -| `radio` | `string`, `number`, `bool`, `list(string)` | Yes | Radio buttons for selecting a single option with all choices visible at once. | -| `slider` | `number` | No | Slider selection with min/max validation for numeric values. | -| `switch` | `bool` | No | Toggle switch alternative for boolean parameters. | -| `tag-select` | `list(string)` | No | Default for list(string) parameters without options. | -| `textarea` | `string` | No | Multi-line text input field for longer content. | -| `error` | | No | Used to display an error message when a parameter form_type is unknown | - -### Form Type Examples - -
checkbox: A single checkbox for boolean values - -```tf -data "coder_parameter" "enable_gpu" { - name = "enable_gpu" - display_name = "Enable GPU" - type = "bool" - form_type = "checkbox" # This is the default for boolean parameters - default = false -} -``` - -
- -
dropdown: A searchable select menu for choosing a single option from a list - -```tf -data "coder_parameter" "region" { - name = "region" - display_name = "Region" - description = "Select a region" - type = "string" - form_type = "dropdown" # This is the default for string parameters with options - - option { - name = "US East" - value = "us-east-1" - } - option { - name = "US West" - value = "us-west-2" - } -} -``` - -
- -
input: A standard text input field - -```tf -data "coder_parameter" "custom_domain" { - name = "custom_domain" - display_name = "Custom Domain" - type = "string" - form_type = "input" # This is the default for string parameters without options - default = "" -} -``` - -
- -
key-value: Input for entering key-value pairs - -```tf -data "coder_parameter" "environment_vars" { - name = "environment_vars" - display_name = "Environment Variables" - type = "string" - form_type = "key-value" - default = jsonencode({"NODE_ENV": "development"}) -} -``` - -
- -
multi-select: Checkboxes for selecting multiple options from a list - -```tf -data "coder_parameter" "tools" { - name = "tools" - display_name = "Developer Tools" - type = "list(string)" - form_type = "multi-select" - default = jsonencode(["git", "docker"]) - - option { - name = "Git" - value = "git" - } - option { - name = "Docker" - value = "docker" - } - option { - name = "Kubernetes CLI" - value = "kubectl" - } -} -``` - -
- -
password: A text input that masks sensitive information - -```tf -data "coder_parameter" "api_key" { - name = "api_key" - display_name = "API Key" - type = "string" - form_type = "password" - secret = true -} -``` - -
- -
radio: Radio buttons for selecting a single option with high visibility - -```tf -data "coder_parameter" "environment" { - name = "environment" - display_name = "Environment" - type = "string" - form_type = "radio" - default = "dev" - - option { - name = "Development" - value = "dev" - } - option { - name = "Staging" - value = "staging" - } -} -``` - -
- -
slider: A slider for selecting numeric values within a range - -```tf -data "coder_parameter" "cpu_cores" { - name = "cpu_cores" - display_name = "CPU Cores" - type = "number" - form_type = "slider" - default = 2 - validation { - min = 1 - max = 8 - } -} -``` - -
- -
switch: A toggle switch for boolean values - -```tf -data "coder_parameter" "advanced_mode" { - name = "advanced_mode" - display_name = "Advanced Mode" - type = "bool" - form_type = "switch" - default = false -} -``` - -
- -
textarea: A multi-line text input field for longer content - -```tf -data "coder_parameter" "init_script" { - name = "init_script" - display_name = "Initialization Script" - type = "string" - form_type = "textarea" - default = "#!/bin/bash\necho 'Hello World'" -} -``` - -
- -## Dynamic Parameter Use Case Examples - -
Conditional Parameters: Region and Instance Types - -This example shows instance types based on the selected region: - -```tf -data "coder_parameter" "region" { - name = "region" - display_name = "Region" - description = "Select a region for your workspace" - type = "string" - default = "us-east-1" - - option { - name = "US East (N. Virginia)" - value = "us-east-1" - } - - option { - name = "US West (Oregon)" - value = "us-west-2" - } -} - -data "coder_parameter" "instance_type" { - name = "instance_type" - display_name = "Instance Type" - description = "Select an instance type available in the selected region" - type = "string" - - # This option will only appear when us-east-1 is selected - dynamic "option" { - for_each = data.coder_parameter.region.value == "us-east-1" ? [1] : [] - content { - name = "t3.large (US East)" - value = "t3.large" - } - } - - # This option will only appear when us-west-2 is selected - dynamic "option" { - for_each = data.coder_parameter.region.value == "us-west-2" ? [1] : [] - content { - name = "t3.medium (US West)" - value = "t3.medium" - } - } -} -``` - -
- -
Advanced Options Toggle - -This example shows how to create an advanced options section: - -```tf -data "coder_parameter" "show_advanced" { - name = "show_advanced" - display_name = "Show Advanced Options" - description = "Enable to show advanced configuration options" - type = "bool" - default = false - order = 0 -} - -data "coder_parameter" "advanced_setting" { - # This parameter is only visible when show_advanced is true - count = data.coder_parameter.show_advanced.value ? 1 : 0 - name = "advanced_setting" - display_name = "Advanced Setting" - description = "An advanced configuration option" - type = "string" - default = "default_value" - mutable = true - order = 1 -} - -
- -
Multi-select IDE Options - -This example allows selecting multiple IDEs to install: - -```tf -data "coder_parameter" "ides" { - name = "ides" - display_name = "IDEs to Install" - description = "Select which IDEs to install in your workspace" - type = "list(string)" - default = jsonencode(["vscode"]) - mutable = true - form_type = "multi-select" - - option { - name = "VS Code" - value = "vscode" - icon = "/icon/vscode.png" - } - - option { - name = "JetBrains IntelliJ" - value = "intellij" - icon = "/icon/intellij.png" - } - - option { - name = "JupyterLab" - value = "jupyter" - icon = "/icon/jupyter.png" - } -} -``` - -
- -
Team-specific Resources - -This example filters resources based on user group membership: - -```tf -data "coder_parameter" "instance_type" { - name = "instance_type" - display_name = "Instance Type" - description = "Select an instance type for your workspace" - type = "string" - - # Show GPU options only if user belongs to the "data-science" group - dynamic "option" { - for_each = contains(data.coder_workspace_owner.me.groups, "data-science") ? [1] : [] - content { - name = "p3.2xlarge (GPU)" - value = "p3.2xlarge" - } - } - - # Standard options for all users - option { - name = "t3.medium (Standard)" - value = "t3.medium" - } -} -``` - -### Advanced Usage Patterns - -
Creating Branching Paths - -For templates serving multiple teams or use cases, you can create comprehensive branching paths: - -```tf -data "coder_parameter" "environment_type" { - name = "environment_type" - display_name = "Environment Type" - description = "Select your preferred development environment" - type = "string" - default = "container" - - option { - name = "Container" - value = "container" - } - - option { - name = "Virtual Machine" - value = "vm" - } -} - -# Container-specific parameters -data "coder_parameter" "container_image" { - name = "container_image" - display_name = "Container Image" - description = "Select a container image for your environment" - type = "string" - default = "ubuntu:latest" - - # Only show when container environment is selected - condition { - field = data.coder_parameter.environment_type.name - value = "container" - } - - option { - name = "Ubuntu" - value = "ubuntu:latest" - } - - option { - name = "Python" - value = "python:3.9" - } -} - -# VM-specific parameters -data "coder_parameter" "vm_image" { - name = "vm_image" - display_name = "VM Image" - description = "Select a VM image for your environment" - type = "string" - default = "ubuntu-20.04" - - # Only show when VM environment is selected - condition { - field = data.coder_parameter.environment_type.name - value = "vm" - } - - option { - name = "Ubuntu 20.04" - value = "ubuntu-20.04" - } - - option { - name = "Debian 11" - value = "debian-11" - } -} -``` - -
- -
Conditional Validation - -Adjust validation rules dynamically based on parameter values: - -```tf -data "coder_parameter" "team" { - name = "team" - display_name = "Team" - type = "string" - default = "engineering" - - option { - name = "Engineering" - value = "engineering" - } - - option { - name = "Data Science" - value = "data-science" - } -} - -data "coder_parameter" "cpu_count" { - name = "cpu_count" - display_name = "CPU Count" - type = "number" - default = 2 - - # Engineering team has lower limits - dynamic "validation" { - for_each = data.coder_parameter.team.value == "engineering" ? [1] : [] - content { - min = 1 - max = 4 - } - } - - # Data Science team has higher limits - dynamic "validation" { - for_each = data.coder_parameter.team.value == "data-science" ? [1] : [] - content { - min = 2 - max = 8 - } - } -} -``` - -
+You can read more in the [Dynamic Parameters documentation](./dynamic-parameters.md) and try it out in the +[Parameters Playground](https://playground.coder.app/parameters). diff --git a/docs/admin/templates/extending-templates/prebuilt-workspaces.md b/docs/admin/templates/extending-templates/prebuilt-workspaces.md index 2c5e73ad289b4..8e61687ce0f01 100644 --- a/docs/admin/templates/extending-templates/prebuilt-workspaces.md +++ b/docs/admin/templates/extending-templates/prebuilt-workspaces.md @@ -2,13 +2,10 @@ > [!WARNING] > Prebuilds Compatibility Limitations: -> Prebuilt workspaces are currently not compatible with configurations that have Workspace schedule (autostart/autostop), or Dormancy enabled. -> If these features are configured, prebuilt workspaces may fail to run correctly. +> Prebuilt workspaces currently do not work reliably with [DevContainers feature](../managing-templates/devcontainers/index.md). +> If your project relies on DevContainer configuration, we recommend disabling prebuilds or carefully testing behavior before enabling them. > -> In addition, prebuilds currently do not work reliably with [DevContainers feature](../managing-templates/devcontainers/index.md). -> If your project relies on DevContainer configuration, we recommend disabling prebuilds or carefully testing behavior before enabling them in production. -> -> We’re actively working to improve compatibility, but for now, please avoid using prebuilds with these features to ensure stability and expected behavior. +> We’re actively working to improve compatibility, but for now, please avoid using prebuilds with this feature to ensure stability and expected behavior. Prebuilt workspaces allow template administrators to improve the developer experience by reducing workspace creation time with an automatically maintained pool of ready-to-use workspaces for specific parameter presets. @@ -26,7 +23,7 @@ Prebuilt workspaces are: ## Relationship to workspace presets -Prebuilt workspaces are tightly integrated with [workspace presets](./parameters.md#workspace-presets-beta): +Prebuilt workspaces are tightly integrated with [workspace presets](./parameters.md#workspace-presets): 1. Each prebuilt workspace is associated with a specific template preset. 1. The preset must define all required parameters needed to build the workspace. diff --git a/docs/images/admin/templates/extend-templates/dyn-params/dynamic-params-compare.png b/docs/images/admin/templates/extend-templates/dyn-params/dynamic-params-compare.png new file mode 100644 index 0000000000000..31f02506bfb22 Binary files /dev/null and b/docs/images/admin/templates/extend-templates/dyn-params/dynamic-params-compare.png differ diff --git a/docs/images/admin/templates/extend-templates/dyn-params/enable-dynamic-parameters.png b/docs/images/admin/templates/extend-templates/dyn-params/enable-dynamic-parameters.png new file mode 100644 index 0000000000000..13732661e7eb7 Binary files /dev/null and b/docs/images/admin/templates/extend-templates/dyn-params/enable-dynamic-parameters.png differ diff --git a/docs/images/guides/ai-agents/duplicate.png b/docs/images/guides/ai-agents/duplicate.png new file mode 100644 index 0000000000000..0122671424792 Binary files /dev/null and b/docs/images/guides/ai-agents/duplicate.png differ diff --git a/docs/images/guides/ai-agents/landing.png b/docs/images/guides/ai-agents/landing.png new file mode 100644 index 0000000000000..b1c09a4f222c7 Binary files /dev/null and b/docs/images/guides/ai-agents/landing.png differ diff --git a/docs/images/platforms/docker/create-workspace.png b/docs/images/platforms/docker/create-workspace.png deleted file mode 100644 index 9959244a96f1c..0000000000000 Binary files a/docs/images/platforms/docker/create-workspace.png and /dev/null differ diff --git a/docs/images/platforms/docker/ides.png b/docs/images/platforms/docker/ides.png deleted file mode 100755 index 2293b7af636f1..0000000000000 Binary files a/docs/images/platforms/docker/ides.png and /dev/null differ diff --git a/docs/images/platforms/docker/login.png b/docs/images/platforms/docker/login.png deleted file mode 100755 index c5bad763e92a8..0000000000000 Binary files a/docs/images/platforms/docker/login.png and /dev/null differ diff --git a/docs/images/platforms/kubernetes/region-picker.png b/docs/images/platforms/kubernetes/region-picker.png deleted file mode 100644 index f40a3379010d7..0000000000000 Binary files a/docs/images/platforms/kubernetes/region-picker.png and /dev/null differ diff --git a/docs/images/platforms/kubernetes/starter-template.png b/docs/images/platforms/kubernetes/starter-template.png deleted file mode 100644 index ff81645d73f73..0000000000000 Binary files a/docs/images/platforms/kubernetes/starter-template.png and /dev/null differ diff --git a/docs/images/platforms/kubernetes/template-variables.png b/docs/images/platforms/kubernetes/template-variables.png deleted file mode 100644 index 2d0a9993e4385..0000000000000 Binary files a/docs/images/platforms/kubernetes/template-variables.png and /dev/null differ diff --git a/docs/images/screenshots/admin-settings.png b/docs/images/screenshots/admin-settings.png new file mode 100644 index 0000000000000..0b5c249544e83 Binary files /dev/null and b/docs/images/screenshots/admin-settings.png differ diff --git a/docs/images/screenshots/audit.png b/docs/images/screenshots/audit.png index 5538c67afd8e3..1340179ebc141 100644 Binary files a/docs/images/screenshots/audit.png and b/docs/images/screenshots/audit.png differ diff --git a/docs/images/screenshots/coder-login.png b/docs/images/screenshots/coder-login.png new file mode 100644 index 0000000000000..2757c225afff5 Binary files /dev/null and b/docs/images/screenshots/coder-login.png differ diff --git a/docs/images/screenshots/create-template.png b/docs/images/screenshots/create-template.png index e442a8557c42b..ef54f45d47319 100644 Binary files a/docs/images/screenshots/create-template.png and b/docs/images/screenshots/create-template.png differ diff --git a/docs/images/screenshots/healthcheck.png b/docs/images/screenshots/healthcheck.png index 5b42f716ca7b6..73143fbc9f1d7 100644 Binary files a/docs/images/screenshots/healthcheck.png and b/docs/images/screenshots/healthcheck.png differ diff --git a/docs/images/screenshots/login.png b/docs/images/screenshots/login.png deleted file mode 100644 index 9bfe85e9f4cea..0000000000000 Binary files a/docs/images/screenshots/login.png and /dev/null differ diff --git a/docs/images/screenshots/settings.png b/docs/images/screenshots/settings.png deleted file mode 100644 index cf3f19116fb13..0000000000000 Binary files a/docs/images/screenshots/settings.png and /dev/null differ diff --git a/docs/images/screenshots/starter-templates.png b/docs/images/screenshots/starter-templates.png new file mode 100644 index 0000000000000..51ac42c4bce5f Binary files /dev/null and b/docs/images/screenshots/starter-templates.png differ diff --git a/docs/images/screenshots/starter_templates.png b/docs/images/screenshots/starter_templates.png deleted file mode 100644 index 1eab19f2901cd..0000000000000 Binary files a/docs/images/screenshots/starter_templates.png and /dev/null differ diff --git a/docs/images/screenshots/template-insights.png b/docs/images/screenshots/template-insights.png new file mode 100644 index 0000000000000..605f49d780d8e Binary files /dev/null and b/docs/images/screenshots/template-insights.png differ diff --git a/docs/images/screenshots/templates-listing.png b/docs/images/screenshots/templates-listing.png new file mode 100644 index 0000000000000..e70158a4d7733 Binary files /dev/null and b/docs/images/screenshots/templates-listing.png differ diff --git a/docs/images/screenshots/templates_insights.png b/docs/images/screenshots/templates_insights.png deleted file mode 100644 index 8375661da2603..0000000000000 Binary files a/docs/images/screenshots/templates_insights.png and /dev/null differ diff --git a/docs/images/screenshots/templates_listing.png b/docs/images/screenshots/templates_listing.png deleted file mode 100644 index e887de4f4e2aa..0000000000000 Binary files a/docs/images/screenshots/templates_listing.png and /dev/null differ diff --git a/docs/images/screenshots/terraform.png b/docs/images/screenshots/terraform.png index d8780d650ea1f..654acb936bbd6 100644 Binary files a/docs/images/screenshots/terraform.png and b/docs/images/screenshots/terraform.png differ diff --git a/docs/images/screenshots/welcome-create-admin-user.png b/docs/images/screenshots/welcome-create-admin-user.png index fcb099bf888d2..c2fb24ebd9730 100644 Binary files a/docs/images/screenshots/welcome-create-admin-user.png and b/docs/images/screenshots/welcome-create-admin-user.png differ diff --git a/docs/images/screenshots/workspace-running-with-topbar.png b/docs/images/screenshots/workspace-running-with-topbar.png index ab3f6a78a9e6e..62b32d46bc3fa 100644 Binary files a/docs/images/screenshots/workspace-running-with-topbar.png and b/docs/images/screenshots/workspace-running-with-topbar.png differ diff --git a/docs/images/screenshots/workspace_launch.png b/docs/images/screenshots/workspace_launch.png deleted file mode 100644 index ab2092e7f5d7d..0000000000000 Binary files a/docs/images/screenshots/workspace_launch.png and /dev/null differ diff --git a/docs/images/screenshots/workspaces-listing.png b/docs/images/screenshots/workspaces-listing.png new file mode 100644 index 0000000000000..078dfbb4f6532 Binary files /dev/null and b/docs/images/screenshots/workspaces-listing.png differ diff --git a/docs/images/screenshots/workspaces_listing.png b/docs/images/screenshots/workspaces_listing.png deleted file mode 100644 index ee206c100f5ba..0000000000000 Binary files a/docs/images/screenshots/workspaces_listing.png and /dev/null differ diff --git a/docs/images/start/blank-workspaces.png b/docs/images/start/blank-workspaces.png deleted file mode 100644 index 3dcc74020e4b8..0000000000000 Binary files a/docs/images/start/blank-workspaces.png and /dev/null differ diff --git a/docs/images/templates/coder-login-web.png b/docs/images/templates/coder-login-web.png deleted file mode 100644 index 854c305d1b162..0000000000000 Binary files a/docs/images/templates/coder-login-web.png and /dev/null differ diff --git a/docs/install/kubernetes.md b/docs/install/kubernetes.md index b57169fd1d9e4..72c51e0da3e8c 100644 --- a/docs/install/kubernetes.md +++ b/docs/install/kubernetes.md @@ -44,7 +44,7 @@ on the internet that explain sensible configurations for this chart. Example: ```console # Install PostgreSQL helm repo add bitnami https://charts.bitnami.com/bitnami -helm install coder-db bitnami/postgresql \ +helm install postgresql bitnami/postgresql \ --namespace coder \ --set auth.username=coder \ --set auth.password=coder \ @@ -55,7 +55,7 @@ helm install coder-db bitnami/postgresql \ The cluster-internal DB URL for the above database is: ```shell -postgres://coder:coder@coder-db-postgresql.coder.svc.cluster.local:5432/coder?sslmode=disable +postgres://coder:coder@postgresql.coder.svc.cluster.local:5432/coder?sslmode=disable ``` You can optionally use the @@ -69,7 +69,7 @@ self-managed PostgreSQL, the address will be: ```sh kubectl create secret generic coder-db-url -n coder \ - --from-literal=url="postgres://coder:coder@coder-db-postgresql.coder.svc.cluster.local:5432/coder?sslmode=disable" + --from-literal=url="postgres://coder:coder@postgresql.coder.svc.cluster.local:5432/coder?sslmode=disable" ``` ## 4. Install Coder with Helm @@ -127,25 +127,51 @@ We support two release channels: mainline and stable - read the - **Mainline** Coder release: - + - **Chart Registry** - ```shell - helm install coder coder-v2/coder \ - --namespace coder \ - --values values.yaml \ - --version 2.23.1 - ``` + + + ```shell + helm install coder coder-v2/coder \ + --namespace coder \ + --values values.yaml \ + --version 2.23.1 + ``` + + - **OCI Registry** + + + + ```shell + helm install coder oci://ghcr.io/coder/chart/coder \ + --namespace coder \ + --values values.yaml \ + --version 2.23.1 + ``` - **Stable** Coder release: - + - **Chart Registry** + + + + ```shell + helm install coder coder-v2/coder \ + --namespace coder \ + --values values.yaml \ + --version 2.22.1 + ``` + + - **OCI Registry** + + - ```shell - helm install coder coder-v2/coder \ - --namespace coder \ - --values values.yaml \ - --version 2.22.1 - ``` + ```shell + helm install coder oci://ghcr.io/coder/chart/coder \ + --namespace coder \ + --values values.yaml \ + --version 2.22.1 + ``` You can watch Coder start up by running `kubectl get pods -n coder`. Once Coder has started, the `coder-*` pods should enter the `Running` state. diff --git a/docs/manifest.json b/docs/manifest.json index 9b85e634dce14..65555caa0df4f 100644 --- a/docs/manifest.json +++ b/docs/manifest.json @@ -527,6 +527,12 @@ "description": "Use parameters to customize workspaces at build", "path": "./admin/templates/extending-templates/parameters.md" }, + { + "title": "Dynamic Parameters", + "description": "Conditional, identity-aware parameter syntax for advanced users.", + "path": "./admin/templates/extending-templates/dynamic-parameters.md", + "state": ["beta"] + }, { "title": "Prebuilt workspaces", "description": "Pre-provision a ready-to-deploy workspace with a defined set of parameters", diff --git a/docs/reference/api/agents.md b/docs/reference/api/agents.md index cff5fef6f3f8a..54e9b0e6ad628 100644 --- a/docs/reference/api/agents.md +++ b/docs/reference/api/agents.md @@ -899,6 +899,111 @@ curl -X POST http://coder-server:8080/api/v2/workspaceagents/{workspaceagent}/co To perform this operation, you must be authenticated. [Learn more](authentication.md). +## Watch workspace agent for container updates + +### Code samples + +```shell +# Example request using curl +curl -X GET http://coder-server:8080/api/v2/workspaceagents/{workspaceagent}/containers/watch \ + -H 'Accept: application/json' \ + -H 'Coder-Session-Token: API_KEY' +``` + +`GET /workspaceagents/{workspaceagent}/containers/watch` + +### Parameters + +| Name | In | Type | Required | Description | +|------------------|------|--------------|----------|--------------------| +| `workspaceagent` | path | string(uuid) | true | Workspace agent ID | + +### Example responses + +> 200 Response + +```json +{ + "containers": [ + { + "created_at": "2019-08-24T14:15:22Z", + "id": "string", + "image": "string", + "labels": { + "property1": "string", + "property2": "string" + }, + "name": "string", + "ports": [ + { + "host_ip": "string", + "host_port": 0, + "network": "string", + "port": 0 + } + ], + "running": true, + "status": "string", + "volumes": { + "property1": "string", + "property2": "string" + } + } + ], + "devcontainers": [ + { + "agent": { + "directory": "string", + "id": "497f6eca-6276-4993-bfeb-53cbbbba6f08", + "name": "string" + }, + "config_path": "string", + "container": { + "created_at": "2019-08-24T14:15:22Z", + "id": "string", + "image": "string", + "labels": { + "property1": "string", + "property2": "string" + }, + "name": "string", + "ports": [ + { + "host_ip": "string", + "host_port": 0, + "network": "string", + "port": 0 + } + ], + "running": true, + "status": "string", + "volumes": { + "property1": "string", + "property2": "string" + } + }, + "dirty": true, + "error": "string", + "id": "497f6eca-6276-4993-bfeb-53cbbbba6f08", + "name": "string", + "status": "running", + "workspace_folder": "string" + } + ], + "warnings": [ + "string" + ] +} +``` + +### Responses + +| Status | Meaning | Description | Schema | +|--------|---------------------------------------------------------|-------------|----------------------------------------------------------------------------------------------------------| +| 200 | [OK](https://tools.ietf.org/html/rfc7231#section-6.3.1) | OK | [codersdk.WorkspaceAgentListContainersResponse](schemas.md#codersdkworkspaceagentlistcontainersresponse) | + +To perform this operation, you must be authenticated. [Learn more](authentication.md). + ## Coordinate workspace agent ### Code samples diff --git a/docs/reference/api/builds.md b/docs/reference/api/builds.md index bb279f5825f6e..686f19316a8c0 100644 --- a/docs/reference/api/builds.md +++ b/docs/reference/api/builds.md @@ -491,9 +491,17 @@ curl -X PATCH http://coder-server:8080/api/v2/workspacebuilds/{workspacebuild}/c ### Parameters -| Name | In | Type | Required | Description | -|------------------|------|--------|----------|--------------------| -| `workspacebuild` | path | string | true | Workspace build ID | +| Name | In | Type | Required | Description | +|------------------|-------|--------|----------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| `workspacebuild` | path | string | true | Workspace build ID | +| `expect_status` | query | string | false | Expected status of the job. If expect_status is supplied, the request will be rejected with 412 Precondition Failed if the job doesn't match the state when performing the cancellation. | + +#### Enumerated Values + +| Parameter | Value | +|-----------------|-----------| +| `expect_status` | `running` | +| `expect_status` | `pending` | ### Example responses diff --git a/docs/reference/api/general.md b/docs/reference/api/general.md index 8f440c55b42d6..72543f6774dfd 100644 --- a/docs/reference/api/general.md +++ b/docs/reference/api/general.md @@ -265,7 +265,6 @@ curl -X GET http://coder-server:8080/api/v2/deployment/config \ "same_site": "string", "secure_auth_cookie": true }, - "in_memory_database": true, "job_hang_detector_interval": 0, "logging": { "human": "string", diff --git a/docs/reference/api/members.md b/docs/reference/api/members.md index b19c859aa10c1..4b0adbf45e338 100644 --- a/docs/reference/api/members.md +++ b/docs/reference/api/members.md @@ -187,6 +187,7 @@ Status Code **200** | `resource_type` | `assign_org_role` | | `resource_type` | `assign_role` | | `resource_type` | `audit_log` | +| `resource_type` | `connection_log` | | `resource_type` | `crypto_key` | | `resource_type` | `debug_info` | | `resource_type` | `deployment_config` | @@ -356,6 +357,7 @@ Status Code **200** | `resource_type` | `assign_org_role` | | `resource_type` | `assign_role` | | `resource_type` | `audit_log` | +| `resource_type` | `connection_log` | | `resource_type` | `crypto_key` | | `resource_type` | `debug_info` | | `resource_type` | `deployment_config` | @@ -525,6 +527,7 @@ Status Code **200** | `resource_type` | `assign_org_role` | | `resource_type` | `assign_role` | | `resource_type` | `audit_log` | +| `resource_type` | `connection_log` | | `resource_type` | `crypto_key` | | `resource_type` | `debug_info` | | `resource_type` | `deployment_config` | @@ -663,6 +666,7 @@ Status Code **200** | `resource_type` | `assign_org_role` | | `resource_type` | `assign_role` | | `resource_type` | `audit_log` | +| `resource_type` | `connection_log` | | `resource_type` | `crypto_key` | | `resource_type` | `debug_info` | | `resource_type` | `deployment_config` | @@ -1023,6 +1027,7 @@ Status Code **200** | `resource_type` | `assign_org_role` | | `resource_type` | `assign_role` | | `resource_type` | `audit_log` | +| `resource_type` | `connection_log` | | `resource_type` | `crypto_key` | | `resource_type` | `debug_info` | | `resource_type` | `deployment_config` | diff --git a/docs/reference/api/schemas.md b/docs/reference/api/schemas.md index 973797d52d554..3788d97753457 100644 --- a/docs/reference/api/schemas.md +++ b/docs/reference/api/schemas.md @@ -1049,6 +1049,7 @@ AuthorizationObject can represent a "set" of objects, such as: all workspaces in | `initiator` | | `autostart` | | `autostop` | +| `dormancy` | ## codersdk.ChangePasswordWithOneTimePasscodeRequest @@ -1986,7 +1987,6 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o "same_site": "string", "secure_auth_cookie": true }, - "in_memory_database": true, "job_hang_detector_interval": 0, "logging": { "human": "string", @@ -2474,7 +2474,6 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o "same_site": "string", "secure_auth_cookie": true }, - "in_memory_database": true, "job_hang_detector_interval": 0, "logging": { "human": "string", @@ -2771,7 +2770,6 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o | `hide_ai_tasks` | boolean | false | | | | `http_address` | string | false | | Http address is a string because it may be set to zero to disable. | | `http_cookies` | [codersdk.HTTPCookieConfig](#codersdkhttpcookieconfig) | false | | | -| `in_memory_database` | boolean | false | | | | `job_hang_detector_interval` | integer | false | | | | `logging` | [codersdk.LoggingConfig](#codersdkloggingconfig) | false | | | | `metrics_cache_refresh_interval` | integer | false | | | @@ -6054,6 +6052,7 @@ Git clone makes use of this by parsing the URL from: 'Username for "https://gith | `assign_org_role` | | `assign_role` | | `audit_log` | +| `connection_log` | | `crypto_key` | | `debug_info` | | `deployment_config` | diff --git a/docs/reference/cli/show.md b/docs/reference/cli/show.md index 87c527ed939f9..c6fb9a2c81f64 100644 --- a/docs/reference/cli/show.md +++ b/docs/reference/cli/show.md @@ -6,5 +6,16 @@ Display details of a workspace's resources and agents ## Usage ```console -coder show +coder show [flags] ``` + +## Options + +### --details + +| | | +|---------|--------------------| +| Type | bool | +| Default | false | + +Show full error messages and additional details. diff --git a/docs/tutorials/quickstart.md b/docs/tutorials/quickstart.md index 595414fd63ccd..7f684fd28c266 100644 --- a/docs/tutorials/quickstart.md +++ b/docs/tutorials/quickstart.md @@ -116,7 +116,7 @@ is installed. ![Create template](../images/screenshots/create-template.png)_Create template_ -1. Select **Create template**. +1. Select **Save**. 1. After the template is ready, select **Create Workspace**. diff --git a/docs/tutorials/template-from-scratch.md b/docs/tutorials/template-from-scratch.md index 22c4c5392001e..3abfdbf940c10 100644 --- a/docs/tutorials/template-from-scratch.md +++ b/docs/tutorials/template-from-scratch.md @@ -351,11 +351,11 @@ use the Coder CLI. 1. In your web browser, enter your credentials: - Log in to your Coder deployment + ![Log in to your Coder deployment](../images/screenshots/coder-login.png) 1. Copy the session token to the clipboard: - Copy session token + ![Copy session token](../images/templates/coder-session-token.png) 1. Paste it into the CLI: diff --git a/docs/user-guides/desktop/index.md b/docs/user-guides/desktop/index.md index d47c2d2a604de..116f7d4d6de69 100644 --- a/docs/user-guides/desktop/index.md +++ b/docs/user-guides/desktop/index.md @@ -1,13 +1,19 @@ # Coder Desktop Coder Desktop provides seamless access to your remote workspaces without the need to install a CLI or configure manual port forwarding. -Connect to workspace services using simple hostnames like `myworkspace.coder`, launch native applications with one click, and synchronize files between local and remote environments. +Connect to workspace services using simple hostnames like `myworkspace.coder`, launch native applications with one click, +and synchronize files between local and remote environments. -> [!NOTE] -> Coder Desktop requires a Coder deployment running [v2.20.0](https://github.com/coder/coder/releases/tag/v2.20.0) or later. +Coder Desktop requires a Coder deployment running [v2.20.0](https://github.com/coder/coder/releases/tag/v2.20.0) or later. ## Install Coder Desktop +> [!IMPORTANT] +> Coder Desktop can't connect through a corporate VPN. +> +> Due to a [known issue](#coder-desktop-cant-connect-through-another-vpn), +> if your Coder deployment requires that you connect through a corporate VPN, Desktop will timeout when it tries to connect. +
You can install Coder Desktop on macOS or Windows. @@ -113,7 +119,7 @@ Before you can use Coder Desktop, you will need to sign in. ![Coder Desktop on Windows - enable Coder Connect](../../images/user-guides/desktop/coder-desktop-win-enable-coder-connect.png) - This may take a few moments, as Coder Desktop will download the necessary components from the Coder server if they have been updated. + This may take a few moments, because Coder Desktop will download the necessary components from the Coder server if they have been updated. 1. macOS: You may be prompted to enter your password to allow Coder Connect to start. @@ -121,7 +127,26 @@ Before you can use Coder Desktop, you will need to sign in. ## Troubleshooting -Do not install more than one copy of Coder Desktop. To avoid system VPN configuration conflicts, only one copy of `Coder Desktop.app` should exist on your Mac, and it must remain in `/Applications`. +If you encounter an issue with Coder Desktop that is not listed here, file an issue in the GitHub repository for +Coder Desktop for [macOS](https://github.com/coder/coder-desktop-macos/issues) or +[Windows](https://github.com/coder/coder-desktop-windows/issues), in the +[main Coder repository](https://github.com/coder/coder/issues), or consult the +[community on Discord](https://coder.com/chat). + +### Known Issues + +#### macOS: Do not install more than one copy of Coder Desktop + +To avoid system VPN configuration conflicts, only one copy of `Coder Desktop.app` should exist on your Mac, and it must remain in `/Applications`. + +#### Coder Desktop can't connect through another VPN + +If the logged in Coder deployment requires a corporate VPN to connect, Coder Connect can't establish communication +through the VPN, and will time out. + +This is due to known issues with [macOS](https://github.com/coder/coder-desktop-macos/issues/201) and +[Windows](https://github.com/coder/coder-desktop-windows/issues/147) networking. +A resolution is in progress. ## Next Steps diff --git a/dogfood/coder/main.tf b/dogfood/coder/main.tf index a9cb72b9c7984..7b8058d676328 100644 --- a/dogfood/coder/main.tf +++ b/dogfood/coder/main.tf @@ -306,8 +306,10 @@ module "vscode-web" { module "jetbrains" { count = data.coder_workspace.me.start_count - source = "git::https://github.com/coder/registry.git//registry/coder/modules/jetbrains?ref=jetbrains" + source = "dev.registry.coder.com/coder/jetbrains/coder" + version = "1.0.0" agent_id = coder_agent.dev.id + agent_name = "dev" folder = local.repo_dir major_version = "latest" } @@ -315,7 +317,7 @@ module "jetbrains" { module "filebrowser" { count = data.coder_workspace.me.start_count source = "dev.registry.coder.com/coder/filebrowser/coder" - version = "1.0.31" + version = "1.1.1" agent_id = coder_agent.dev.id agent_name = "dev" } @@ -337,7 +339,7 @@ module "cursor" { module "windsurf" { count = data.coder_workspace.me.start_count - source = "registry.coder.com/coder/windsurf/coder" + source = "dev.registry.coder.com/coder/windsurf/coder" version = "1.0.0" agent_id = coder_agent.dev.id folder = local.repo_dir @@ -345,7 +347,8 @@ module "windsurf" { module "zed" { count = data.coder_workspace.me.start_count - source = "./zed" + source = "dev.registry.coder.com/coder/zed/coder" + version = "1.0.0" agent_id = coder_agent.dev.id agent_name = "dev" folder = local.repo_dir diff --git a/dogfood/coder/zed/main.tf b/dogfood/coder/zed/main.tf deleted file mode 100644 index 96466ba258a1b..0000000000000 --- a/dogfood/coder/zed/main.tf +++ /dev/null @@ -1,39 +0,0 @@ -terraform { - required_version = ">= 1.0" - required_providers { - coder = { - source = "coder/coder" - version = ">= 0.17" - } - } -} - -variable "agent_id" { - type = string -} - -variable "agent_name" { - type = string - default = "" -} - -variable "folder" { - type = string -} - -data "coder_workspace" "me" {} - -locals { - workspace_name = lower(data.coder_workspace.me.name) - agent_name = lower(var.agent_name) - hostname = var.agent_name != "" ? "${local.agent_name}.${local.workspace_name}.me.coder" : "${local.workspace_name}.coder" -} - -resource "coder_app" "zed" { - agent_id = var.agent_id - display_name = "Zed" - slug = "zed" - icon = "/icon/zed.svg" - external = true - url = "zed://ssh/${local.hostname}/${var.folder}" -} diff --git a/enterprise/audit/audit_test.go b/enterprise/audit/audit_test.go index bf9393612d65c..dd5d6274f65e9 100644 --- a/enterprise/audit/audit_test.go +++ b/enterprise/audit/audit_test.go @@ -8,7 +8,7 @@ import ( "golang.org/x/xerrors" "github.com/coder/coder/v2/coderd/database" - "github.com/coder/coder/v2/coderd/database/dbmem" + "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/enterprise/audit" "github.com/coder/coder/v2/enterprise/audit/audittest" ) @@ -86,11 +86,12 @@ func TestAuditor(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { t.Parallel() + db, _ := dbtestutil.NewDB(t) var ( backend = &testBackend{decision: test.backendDecision, err: test.backendError} exporter = audit.NewAuditor( - dbmem.New(), + db, audit.FilterFunc(func(_ context.Context, _ database.AuditLog) (audit.FilterDecision, error) { return test.filterDecision, test.filterError }), diff --git a/enterprise/audit/backends/postgres_test.go b/enterprise/audit/backends/postgres_test.go index d9a517ca62eaf..5d0032e207ed3 100644 --- a/enterprise/audit/backends/postgres_test.go +++ b/enterprise/audit/backends/postgres_test.go @@ -7,7 +7,7 @@ import ( "github.com/stretchr/testify/require" "github.com/coder/coder/v2/coderd/database" - "github.com/coder/coder/v2/coderd/database/dbmem" + "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/enterprise/audit" "github.com/coder/coder/v2/enterprise/audit/audittest" "github.com/coder/coder/v2/enterprise/audit/backends" @@ -20,7 +20,7 @@ func TestPostgresBackend(t *testing.T) { var ( ctx, cancel = context.WithCancel(context.Background()) - db = dbmem.New() + db, _ = dbtestutil.NewDB(t) pgb = backends.NewPostgres(db, true) alog = audittest.RandomLog() ) diff --git a/enterprise/audit/backends/slog.go b/enterprise/audit/backends/slog.go index c49ebae296ff0..7418070b49c38 100644 --- a/enterprise/audit/backends/slog.go +++ b/enterprise/audit/backends/slog.go @@ -12,38 +12,34 @@ import ( "github.com/coder/coder/v2/enterprise/audit" ) -type slogBackend struct { +type SlogExporter struct { log slog.Logger } -func NewSlog(logger slog.Logger) audit.Backend { - return &slogBackend{log: logger} +func NewSlogExporter(logger slog.Logger) *SlogExporter { + return &SlogExporter{log: logger} } -func (*slogBackend) Decision() audit.FilterDecision { - return audit.FilterDecisionExport -} - -func (b *slogBackend) Export(ctx context.Context, alog database.AuditLog, details audit.BackendDetails) error { +func (e *SlogExporter) ExportStruct(ctx context.Context, data any, message string, extraFields ...slog.Field) error { // We don't use structs.Map because we don't want to recursively convert // fields into maps. When we keep the type information, slog can more // pleasantly format the output. For example, the clean result of // (*NullString).Value() may be printed instead of {String: "foo", Valid: true}. - sfs := structs.Fields(alog) + sfs := structs.Fields(data) var fields []any for _, sf := range sfs { - fields = append(fields, b.fieldToSlog(sf)) + fields = append(fields, e.fieldToSlog(sf)) } - if details.Actor != nil { - fields = append(fields, slog.F("actor", details.Actor)) + for _, field := range extraFields { + fields = append(fields, field) } - b.log.Info(ctx, "audit_log", fields...) + e.log.Info(ctx, message, fields...) return nil } -func (*slogBackend) fieldToSlog(field *structs.Field) slog.Field { +func (*SlogExporter) fieldToSlog(field *structs.Field) slog.Field { val := field.Value() switch ty := field.Value().(type) { @@ -55,3 +51,26 @@ func (*slogBackend) fieldToSlog(field *structs.Field) slog.Field { return slog.F(field.Name(), val) } + +type auditSlogBackend struct { + exporter *SlogExporter +} + +func NewSlog(logger slog.Logger) audit.Backend { + return &auditSlogBackend{ + exporter: NewSlogExporter(logger), + } +} + +func (*auditSlogBackend) Decision() audit.FilterDecision { + return audit.FilterDecisionExport +} + +func (b *auditSlogBackend) Export(ctx context.Context, alog database.AuditLog, details audit.BackendDetails) error { + var extraFields []slog.Field + if details.Actor != nil { + extraFields = append(extraFields, slog.F("actor", details.Actor)) + } + + return b.exporter.ExportStruct(ctx, alog, "audit_log", extraFields...) +} diff --git a/enterprise/audit/backends/slog_test.go b/enterprise/audit/backends/slog_test.go index 5fe3cf70c519a..99be36b3f9d15 100644 --- a/enterprise/audit/backends/slog_test.go +++ b/enterprise/audit/backends/slog_test.go @@ -24,7 +24,7 @@ import ( "github.com/coder/coder/v2/enterprise/audit/backends" ) -func TestSlogBackend(t *testing.T) { +func TestSlogExporter(t *testing.T) { t.Parallel() t.Run("OK", func(t *testing.T) { t.Parallel() @@ -32,30 +32,29 @@ func TestSlogBackend(t *testing.T) { var ( ctx, cancel = context.WithCancel(context.Background()) - sink = &fakeSink{} - logger = slog.Make(sink) - backend = backends.NewSlog(logger) + sink = &fakeSink{} + logger = slog.Make(sink) + exporter = backends.NewSlogExporter(logger) alog = audittest.RandomLog() ) defer cancel() - err := backend.Export(ctx, alog, audit.BackendDetails{}) + err := exporter.ExportStruct(ctx, alog, "audit_log") require.NoError(t, err) require.Len(t, sink.entries, 1) require.Equal(t, sink.entries[0].Message, "audit_log") require.Len(t, sink.entries[0].Fields, len(structs.Fields(alog))) }) - t.Run("FormatsCorrectly", func(t *testing.T) { t.Parallel() var ( ctx, cancel = context.WithCancel(context.Background()) - buf = bytes.NewBuffer(nil) - logger = slog.Make(slogjson.Sink(buf)) - backend = backends.NewSlog(logger) + buf = bytes.NewBuffer(nil) + logger = slog.Make(slogjson.Sink(buf)) + exporter = backends.NewSlogExporter(logger) _, inet, _ = net.ParseCIDR("127.0.0.1/32") alog = database.AuditLog{ @@ -81,11 +80,11 @@ func TestSlogBackend(t *testing.T) { ) defer cancel() - err := backend.Export(ctx, alog, audit.BackendDetails{Actor: &audit.Actor{ + err := exporter.ExportStruct(ctx, alog, "audit_log", slog.F("actor", &audit.Actor{ ID: uuid.UUID{2}, Username: "coadler", Email: "doug@coder.com", - }}) + })) require.NoError(t, err) logger.Sync() diff --git a/enterprise/cli/server.go b/enterprise/cli/server.go index 1bf4f31a8506b..3b1fd63ab1c4c 100644 --- a/enterprise/cli/server.go +++ b/enterprise/cli/server.go @@ -87,6 +87,7 @@ func (r *RootCmd) Server(_ func()) *serpent.Command { o := &coderd.Options{ Options: options, AuditLogging: true, + ConnectionLogging: true, BrowserOnly: options.DeploymentValues.BrowserOnly.Value(), SCIMAPIKey: []byte(options.DeploymentValues.SCIMAPIKey.Value()), RBAC: true, diff --git a/enterprise/cli/server_test.go b/enterprise/cli/server_test.go index d913bc443c19f..7489699a6f3dd 100644 --- a/enterprise/cli/server_test.go +++ b/enterprise/cli/server_test.go @@ -12,10 +12,17 @@ import ( "github.com/coder/coder/v2/cli/clitest" "github.com/coder/coder/v2/cli/config" + "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/enterprise/cli" "github.com/coder/coder/v2/testutil" ) +func dbArg(t *testing.T) string { + dbURL, err := dbtestutil.Open(t) + require.NoError(t, err) + return "--postgres-url=" + dbURL +} + // TestServer runs the enterprise server command // and waits for /healthz to return "OK". func TestServer_Single(t *testing.T) { @@ -27,9 +34,10 @@ func TestServer_Single(t *testing.T) { var root cli.RootCmd cmd, err := root.Command(root.EnterpriseSubcommands()) require.NoError(t, err) + inv, cfg := clitest.NewWithCommand(t, cmd, "server", - "--in-memory", + dbArg(t), "--http-address", ":0", "--access-url", "http://example.com", ) diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index 5f79608275f96..6d523e9226b88 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -22,6 +22,7 @@ import ( agplportsharing "github.com/coder/coder/v2/coderd/portsharing" agplprebuilds "github.com/coder/coder/v2/coderd/prebuilds" "github.com/coder/coder/v2/coderd/rbac/policy" + "github.com/coder/coder/v2/enterprise/coderd/connectionlog" "github.com/coder/coder/v2/enterprise/coderd/enidpsync" "github.com/coder/coder/v2/enterprise/coderd/portsharing" @@ -36,6 +37,7 @@ import ( "github.com/coder/coder/v2/coderd" agplaudit "github.com/coder/coder/v2/coderd/audit" + agplconnectionlog "github.com/coder/coder/v2/coderd/connectionlog" agpldbauthz "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/healthcheck" @@ -123,6 +125,13 @@ func New(ctx context.Context, options *Options) (_ *API, err error) { options.IDPSync = enidpsync.NewSync(options.Logger, options.RuntimeConfig, options.Entitlements, idpsync.FromDeploymentValues(options.DeploymentValues)) } + if options.ConnectionLogger == nil { + options.ConnectionLogger = connectionlog.NewConnectionLogger( + connectionlog.NewDBBackend(options.Database), + connectionlog.NewSlogBackend(options.Logger), + ) + } + api := &API{ ctx: ctx, cancel: cancelFunc, @@ -593,8 +602,9 @@ func New(ctx context.Context, options *Options) (_ *API, err error) { type Options struct { *coderd.Options - RBAC bool - AuditLogging bool + RBAC bool + AuditLogging bool + ConnectionLogging bool // Whether to block non-browser connections. BrowserOnly bool SCIMAPIKey []byte @@ -695,6 +705,7 @@ func (api *API) updateEntitlements(ctx context.Context) error { ctx, api.Database, len(agedReplicas), len(api.ExternalAuthConfigs), api.LicenseKeys, map[codersdk.FeatureName]bool{ codersdk.FeatureAuditLog: api.AuditLogging, + codersdk.FeatureConnectionLog: api.ConnectionLogging, codersdk.FeatureBrowserOnly: api.BrowserOnly, codersdk.FeatureSCIM: len(api.SCIMAPIKey) != 0, codersdk.FeatureMultipleExternalAuth: len(api.ExternalAuthConfigs) > 1, @@ -733,6 +744,14 @@ func (api *API) updateEntitlements(ctx context.Context) error { api.AGPL.Auditor.Store(&auditor) } + if initial, changed, enabled := featureChanged(codersdk.FeatureConnectionLog); shouldUpdate(initial, changed, enabled) { + connectionLogger := agplconnectionlog.NewNop() + if enabled { + connectionLogger = api.AGPL.Options.ConnectionLogger + } + api.AGPL.ConnectionLogger.Store(&connectionLogger) + } + if initial, changed, enabled := featureChanged(codersdk.FeatureBrowserOnly); shouldUpdate(initial, changed, enabled) { var handler func(rw http.ResponseWriter) bool if enabled { @@ -780,13 +799,7 @@ func (api *API) updateEntitlements(ctx context.Context) error { if initial, changed, enabled := featureChanged(codersdk.FeatureHighAvailability); shouldUpdate(initial, changed, enabled) { var coordinator agpltailnet.Coordinator - // If HA is enabled, but the database is in-memory, we can't actually - // run HA and the PG coordinator. So throw a log line, and continue to use - // the in memory AGPL coordinator. - if enabled && api.DeploymentValues.InMemoryDatabase.Value() { - api.Logger.Warn(ctx, "high availability is enabled, but cannot be configured due to the database being set to in-memory") - } - if enabled && !api.DeploymentValues.InMemoryDatabase.Value() { + if enabled { haCoordinator, err := tailnet.NewPGCoord(api.ctx, api.Logger, api.Pubsub, api.Database) if err != nil { api.Logger.Error(ctx, "unable to set up high availability coordinator", slog.Error(err)) diff --git a/enterprise/coderd/coderd_test.go b/enterprise/coderd/coderd_test.go index 89a61c657e21a..52301f6dae034 100644 --- a/enterprise/coderd/coderd_test.go +++ b/enterprise/coderd/coderd_test.go @@ -42,7 +42,6 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/database/dbfake" - "github.com/coder/coder/v2/coderd/database/dbmem" "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/rbac" @@ -323,10 +322,11 @@ func TestAuditLogging(t *testing.T) { t.Parallel() t.Run("Enabled", func(t *testing.T) { t.Parallel() + db, _ := dbtestutil.NewDB(t) _, _, api, _ := coderdenttest.NewWithAPI(t, &coderdenttest.Options{ AuditLogging: true, Options: &coderdtest.Options{ - Auditor: audit.NewAuditor(dbmem.New(), audit.DefaultFilter), + Auditor: audit.NewAuditor(db, audit.DefaultFilter), }, LicenseOptions: &coderdenttest.LicenseOptions{ Features: license.Features{ @@ -334,8 +334,9 @@ func TestAuditLogging(t *testing.T) { }, }, }) + db, _ = dbtestutil.NewDB(t) auditor := *api.AGPL.Auditor.Load() - ea := audit.NewAuditor(dbmem.New(), audit.DefaultFilter) + ea := audit.NewAuditor(db, audit.DefaultFilter) t.Logf("%T = %T", auditor, ea) assert.EqualValues(t, reflect.ValueOf(ea).Type(), reflect.ValueOf(auditor).Type()) }) diff --git a/enterprise/coderd/coderdenttest/coderdenttest.go b/enterprise/coderd/coderdenttest/coderdenttest.go index bd81e5a039599..e4088e83d09f5 100644 --- a/enterprise/coderd/coderdenttest/coderdenttest.go +++ b/enterprise/coderd/coderdenttest/coderdenttest.go @@ -22,7 +22,6 @@ import ( "github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/coderd/database" - "github.com/coder/coder/v2/coderd/database/dbmem" "github.com/coder/coder/v2/coderd/database/pubsub" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/drpcsdk" @@ -149,8 +148,6 @@ func NewWithAPI(t *testing.T, options *Options) ( // we check for the in-memory test types so that the real types don't have to exported _, ok := coderAPI.Pubsub.(*pubsub.MemoryPubsub) require.False(t, ok, "FeatureHighAvailability is incompatible with MemoryPubsub") - _, ok = coderAPI.Database.(*dbmem.FakeQuerier) - require.False(t, ok, "FeatureHighAvailability is incompatible with dbmem") } } _ = AddLicense(t, client, lo) diff --git a/enterprise/coderd/connectionlog/connectionlog.go b/enterprise/coderd/connectionlog/connectionlog.go new file mode 100644 index 0000000000000..e428a13baf183 --- /dev/null +++ b/enterprise/coderd/connectionlog/connectionlog.go @@ -0,0 +1,66 @@ +package connectionlog + +import ( + "context" + + "github.com/hashicorp/go-multierror" + + "cdr.dev/slog" + agpl "github.com/coder/coder/v2/coderd/connectionlog" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbauthz" + auditbackends "github.com/coder/coder/v2/enterprise/audit/backends" +) + +type Backend interface { + Upsert(ctx context.Context, clog database.UpsertConnectionLogParams) error +} + +func NewConnectionLogger(backends ...Backend) agpl.ConnectionLogger { + return &connectionLogger{ + backends: backends, + } +} + +type connectionLogger struct { + backends []Backend +} + +func (c *connectionLogger) Upsert(ctx context.Context, clog database.UpsertConnectionLogParams) error { + var errs error + for _, backend := range c.backends { + err := backend.Upsert(ctx, clog) + if err != nil { + errs = multierror.Append(errs, err) + } + } + return errs +} + +type dbBackend struct { + db database.Store +} + +func NewDBBackend(db database.Store) Backend { + return &dbBackend{db: db} +} + +func (b *dbBackend) Upsert(ctx context.Context, clog database.UpsertConnectionLogParams) error { + //nolint:gocritic // This is the Connection Logger + _, err := b.db.UpsertConnectionLog(dbauthz.AsConnectionLogger(ctx), clog) + return err +} + +type connectionSlogBackend struct { + exporter *auditbackends.SlogExporter +} + +func NewSlogBackend(logger slog.Logger) Backend { + return &connectionSlogBackend{ + exporter: auditbackends.NewSlogExporter(logger), + } +} + +func (b *connectionSlogBackend) Upsert(ctx context.Context, clog database.UpsertConnectionLogParams) error { + return b.exporter.ExportStruct(ctx, clog, "connection_log") +} diff --git a/enterprise/coderd/dormancy/dormantusersjob_test.go b/enterprise/coderd/dormancy/dormantusersjob_test.go index bb3e0b4170baf..e5e5276fe67a9 100644 --- a/enterprise/coderd/dormancy/dormantusersjob_test.go +++ b/enterprise/coderd/dormancy/dormantusersjob_test.go @@ -12,7 +12,7 @@ import ( "github.com/coder/coder/v2/coderd/audit" "github.com/coder/coder/v2/coderd/database" - "github.com/coder/coder/v2/coderd/database/dbmem" + "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/enterprise/coderd/dormancy" "github.com/coder/quartz" ) @@ -26,7 +26,7 @@ func TestCheckInactiveUsers(t *testing.T) { // Add some dormant accounts logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) ctx, cancelFunc := context.WithCancel(context.Background()) t.Cleanup(cancelFunc) @@ -75,7 +75,7 @@ func TestCheckInactiveUsers(t *testing.T) { allUsers := ignoreUpdatedAt(database.ConvertUserRows(rows)) // Verify user status - expectedUsers := []database.User{ + expectedUsers := ignoreUpdatedAt([]database.User{ asDormant(inactiveUser1), asDormant(inactiveUser2), asDormant(inactiveUser3), @@ -85,14 +85,24 @@ func TestCheckInactiveUsers(t *testing.T) { suspendedUser1, suspendedUser2, suspendedUser3, - } + }) + require.ElementsMatch(t, allUsers, expectedUsers) } func setupUser(ctx context.Context, t *testing.T, db database.Store, email string, status database.UserStatus, lastSeenAt time.Time) database.User { t.Helper() - user, err := db.InsertUser(ctx, database.InsertUserParams{ID: uuid.New(), LoginType: database.LoginTypePassword, Username: uuid.NewString()[:8], Email: email}) + now := dbtestutil.NowInDefaultTimezone() + user, err := db.InsertUser(ctx, database.InsertUserParams{ + ID: uuid.New(), + LoginType: database.LoginTypePassword, + Username: uuid.NewString()[:8], + Email: email, + RBACRoles: []string{}, + CreatedAt: now, + UpdatedAt: now, + }) require.NoError(t, err) // At the beginning of the test all users are marked as active user, err = db.UpdateUserStatus(ctx, database.UpdateUserStatusParams{ID: user.ID, Status: status}) diff --git a/enterprise/coderd/dynamicparameters_test.go b/enterprise/coderd/dynamicparameters_test.go index e13d370a059ad..94a4158dc8354 100644 --- a/enterprise/coderd/dynamicparameters_test.go +++ b/enterprise/coderd/dynamicparameters_test.go @@ -338,7 +338,6 @@ func TestDynamicParameterBuild(t *testing.T) { bld, err := templateAdmin.CreateWorkspaceBuild(ctx, wrk.ID, codersdk.CreateWorkspaceBuildRequest{ TemplateVersionID: immutable.ID, // Use the new template version with the immutable parameter Transition: codersdk.WorkspaceTransitionDelete, - DryRun: false, }) require.NoError(t, err) coderdtest.AwaitWorkspaceBuildJobCompleted(t, templateAdmin, bld.ID) @@ -354,6 +353,75 @@ func TestDynamicParameterBuild(t *testing.T) { require.NoError(t, err) require.Equal(t, wrk.ID, deleted.ID, "workspace should be deleted") }) + + t.Run("PreviouslyImmutable", func(t *testing.T) { + // Ok this is a weird test to document how things are working. + // What if a parameter flips it's immutability based on a value? + // The current behavior is to source immutability from the new state. + // So the value is allowed to be changed. + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + // Start with a new template that has 1 parameter that is immutable + immutable, _ := coderdtest.DynamicParameterTemplate(t, templateAdmin, orgID, coderdtest.DynamicParameterTemplateParams{ + MainTF: "# PreviouslyImmutable\n" + string(must(os.ReadFile("testdata/parameters/dynamicimmutable/main.tf"))), + }) + + // Create the workspace with the immutable parameter + wrk, err := templateAdmin.CreateUserWorkspace(ctx, codersdk.Me, codersdk.CreateWorkspaceRequest{ + TemplateID: immutable.ID, + Name: coderdtest.RandomUsername(t), + RichParameterValues: []codersdk.WorkspaceBuildParameter{ + {Name: "isimmutable", Value: "true"}, + {Name: "immutable", Value: "coder"}, + }, + }) + require.NoError(t, err) + coderdtest.AwaitWorkspaceBuildJobCompleted(t, templateAdmin, wrk.LatestBuild.ID) + + // Try new values + _, err = templateAdmin.CreateWorkspaceBuild(ctx, wrk.ID, codersdk.CreateWorkspaceBuildRequest{ + Transition: codersdk.WorkspaceTransitionStart, + RichParameterValues: []codersdk.WorkspaceBuildParameter{ + {Name: "isimmutable", Value: "false"}, + {Name: "immutable", Value: "not-coder"}, + }, + }) + require.NoError(t, err) + }) + + t.Run("PreviouslyMutable", func(t *testing.T) { + // The value cannot be changed because it becomes immutable. + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + immutable, _ := coderdtest.DynamicParameterTemplate(t, templateAdmin, orgID, coderdtest.DynamicParameterTemplateParams{ + MainTF: "# PreviouslyMutable\n" + string(must(os.ReadFile("testdata/parameters/dynamicimmutable/main.tf"))), + }) + + // Create the workspace with the mutable parameter + wrk, err := templateAdmin.CreateUserWorkspace(ctx, codersdk.Me, codersdk.CreateWorkspaceRequest{ + TemplateID: immutable.ID, + Name: coderdtest.RandomUsername(t), + RichParameterValues: []codersdk.WorkspaceBuildParameter{ + {Name: "isimmutable", Value: "false"}, + {Name: "immutable", Value: "coder"}, + }, + }) + require.NoError(t, err) + coderdtest.AwaitWorkspaceBuildJobCompleted(t, templateAdmin, wrk.LatestBuild.ID) + + // Switch it to immutable, which breaks the validation + _, err = templateAdmin.CreateWorkspaceBuild(ctx, wrk.ID, codersdk.CreateWorkspaceBuildRequest{ + Transition: codersdk.WorkspaceTransitionStart, + RichParameterValues: []codersdk.WorkspaceBuildParameter{ + {Name: "isimmutable", Value: "true"}, + {Name: "immutable", Value: "not-coder"}, + }, + }) + require.Error(t, err) + require.ErrorContains(t, err, "is not mutable") + }) }) } diff --git a/enterprise/coderd/enidpsync/organizations_test.go b/enterprise/coderd/enidpsync/organizations_test.go index d2a5aafece558..13a9bd69ed8fd 100644 --- a/enterprise/coderd/enidpsync/organizations_test.go +++ b/enterprise/coderd/enidpsync/organizations_test.go @@ -53,10 +53,6 @@ type OrganizationSyncTestCase struct { func TestOrganizationSync(t *testing.T) { t.Parallel() - if dbtestutil.WillUsePostgres() { - t.Skip("Skipping test because it populates a lot of db entries, which is slow on postgres") - } - requireUserOrgs := func(t *testing.T, db database.Store, user database.User, expected []uuid.UUID) { t.Helper() diff --git a/enterprise/coderd/license/license_test.go b/enterprise/coderd/license/license_test.go index bf6d6448205e0..5ec28ffa9c294 100644 --- a/enterprise/coderd/license/license_test.go +++ b/enterprise/coderd/license/license_test.go @@ -12,7 +12,7 @@ import ( "github.com/stretchr/testify/require" "github.com/coder/coder/v2/coderd/database" - "github.com/coder/coder/v2/coderd/database/dbmem" + "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/enterprise/coderd/coderdenttest" @@ -30,7 +30,7 @@ func TestEntitlements(t *testing.T) { t.Run("Defaults", func(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) entitlements, err := license.Entitlements(context.Background(), db, 1, 1, coderdenttest.Keys, all) require.NoError(t, err) require.False(t, entitlements.HasLicense) @@ -42,7 +42,7 @@ func TestEntitlements(t *testing.T) { }) t.Run("Always return the current user count", func(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) entitlements, err := license.Entitlements(context.Background(), db, 1, 1, coderdenttest.Keys, all) require.NoError(t, err) require.False(t, entitlements.HasLicense) @@ -51,7 +51,7 @@ func TestEntitlements(t *testing.T) { }) t.Run("SingleLicenseNothing", func(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) db.InsertLicense(context.Background(), database.InsertLicenseParams{ JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{}), Exp: dbtime.Now().Add(time.Hour), @@ -67,7 +67,7 @@ func TestEntitlements(t *testing.T) { }) t.Run("SingleLicenseAll", func(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) db.InsertLicense(context.Background(), database.InsertLicenseParams{ JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{ Features: func() license.Features { @@ -90,7 +90,7 @@ func TestEntitlements(t *testing.T) { }) t.Run("SingleLicenseGrace", func(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) db.InsertLicense(context.Background(), database.InsertLicenseParams{ JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{ Features: license.Features{ @@ -116,7 +116,7 @@ func TestEntitlements(t *testing.T) { }) t.Run("Expiration warning", func(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) db.InsertLicense(context.Background(), database.InsertLicenseParams{ JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{ Features: license.Features{ @@ -145,7 +145,7 @@ func TestEntitlements(t *testing.T) { t.Run("Expiration warning for license expiring in 1 day", func(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) db.InsertLicense(context.Background(), database.InsertLicenseParams{ JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{ Features: license.Features{ @@ -174,7 +174,7 @@ func TestEntitlements(t *testing.T) { t.Run("Expiration warning for trials", func(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) db.InsertLicense(context.Background(), database.InsertLicenseParams{ JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{ Features: license.Features{ @@ -204,7 +204,7 @@ func TestEntitlements(t *testing.T) { t.Run("Expiration warning for non trials", func(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) db.InsertLicense(context.Background(), database.InsertLicenseParams{ JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{ Features: license.Features{ @@ -233,7 +233,7 @@ func TestEntitlements(t *testing.T) { t.Run("SingleLicenseNotEntitled", func(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) db.InsertLicense(context.Background(), database.InsertLicenseParams{ JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{}), Exp: time.Now().Add(time.Hour), @@ -261,11 +261,13 @@ func TestEntitlements(t *testing.T) { }) t.Run("TooManyUsers", func(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) activeUser1, err := db.InsertUser(context.Background(), database.InsertUserParams{ ID: uuid.New(), Username: "test1", + Email: "test1@coder.com", LoginType: database.LoginTypePassword, + RBACRoles: []string{}, }) require.NoError(t, err) _, err = db.UpdateUserStatus(context.Background(), database.UpdateUserStatusParams{ @@ -277,7 +279,9 @@ func TestEntitlements(t *testing.T) { activeUser2, err := db.InsertUser(context.Background(), database.InsertUserParams{ ID: uuid.New(), Username: "test2", + Email: "test2@coder.com", LoginType: database.LoginTypePassword, + RBACRoles: []string{}, }) require.NoError(t, err) _, err = db.UpdateUserStatus(context.Background(), database.UpdateUserStatusParams{ @@ -289,7 +293,9 @@ func TestEntitlements(t *testing.T) { _, err = db.InsertUser(context.Background(), database.InsertUserParams{ ID: uuid.New(), Username: "dormant-user", + Email: "dormant-user@coder.com", LoginType: database.LoginTypePassword, + RBACRoles: []string{}, }) require.NoError(t, err) db.InsertLicense(context.Background(), database.InsertLicenseParams{ @@ -307,7 +313,7 @@ func TestEntitlements(t *testing.T) { }) t.Run("MaximizeUserLimit", func(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) db.InsertUser(context.Background(), database.InsertUserParams{}) db.InsertUser(context.Background(), database.InsertUserParams{}) db.InsertLicense(context.Background(), database.InsertLicenseParams{ @@ -335,7 +341,7 @@ func TestEntitlements(t *testing.T) { }) t.Run("MultipleLicenseEnabled", func(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) // One trial db.InsertLicense(context.Background(), database.InsertLicenseParams{ Exp: time.Now().Add(time.Hour), @@ -359,7 +365,7 @@ func TestEntitlements(t *testing.T) { t.Run("Enterprise", func(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) _, err := db.InsertLicense(context.Background(), database.InsertLicenseParams{ Exp: time.Now().Add(time.Hour), JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{ @@ -390,7 +396,7 @@ func TestEntitlements(t *testing.T) { t.Run("Premium", func(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) _, err := db.InsertLicense(context.Background(), database.InsertLicenseParams{ Exp: time.Now().Add(time.Hour), JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{ @@ -421,7 +427,7 @@ func TestEntitlements(t *testing.T) { t.Run("SetNone", func(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) _, err := db.InsertLicense(context.Background(), database.InsertLicenseParams{ Exp: time.Now().Add(time.Hour), JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{ @@ -443,7 +449,7 @@ func TestEntitlements(t *testing.T) { // AllFeatures uses the deprecated 'AllFeatures' boolean. t.Run("AllFeatures", func(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) db.InsertLicense(context.Background(), database.InsertLicenseParams{ Exp: time.Now().Add(time.Hour), JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{ @@ -473,7 +479,7 @@ func TestEntitlements(t *testing.T) { t.Run("AllFeaturesAlwaysEnable", func(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) db.InsertLicense(context.Background(), database.InsertLicenseParams{ Exp: dbtime.Now().Add(time.Hour), JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{ @@ -504,7 +510,7 @@ func TestEntitlements(t *testing.T) { t.Run("AllFeaturesGrace", func(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) db.InsertLicense(context.Background(), database.InsertLicenseParams{ Exp: dbtime.Now().Add(time.Hour), JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{ @@ -535,7 +541,7 @@ func TestEntitlements(t *testing.T) { t.Run("MultipleReplicasNoLicense", func(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) entitlements, err := license.Entitlements(context.Background(), db, 2, 1, coderdenttest.Keys, all) require.NoError(t, err) require.False(t, entitlements.HasLicense) @@ -545,7 +551,7 @@ func TestEntitlements(t *testing.T) { t.Run("MultipleReplicasNotEntitled", func(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) db.InsertLicense(context.Background(), database.InsertLicenseParams{ Exp: time.Now().Add(time.Hour), JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{ @@ -565,7 +571,7 @@ func TestEntitlements(t *testing.T) { t.Run("MultipleReplicasGrace", func(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) db.InsertLicense(context.Background(), database.InsertLicenseParams{ JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{ Features: license.Features{ @@ -587,7 +593,7 @@ func TestEntitlements(t *testing.T) { t.Run("MultipleGitAuthNoLicense", func(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) entitlements, err := license.Entitlements(context.Background(), db, 1, 2, coderdenttest.Keys, all) require.NoError(t, err) require.False(t, entitlements.HasLicense) @@ -597,7 +603,7 @@ func TestEntitlements(t *testing.T) { t.Run("MultipleGitAuthNotEntitled", func(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) db.InsertLicense(context.Background(), database.InsertLicenseParams{ Exp: time.Now().Add(time.Hour), JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{ @@ -617,7 +623,7 @@ func TestEntitlements(t *testing.T) { t.Run("MultipleGitAuthGrace", func(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) db.InsertLicense(context.Background(), database.InsertLicenseParams{ JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{ GraceAt: time.Now().Add(-time.Hour), @@ -649,6 +655,7 @@ func TestLicenseEntitlements(t *testing.T) { // maybe some should be moved to "AlwaysEnabled" instead. defaultEnablements := map[codersdk.FeatureName]bool{ codersdk.FeatureAuditLog: true, + codersdk.FeatureConnectionLog: true, codersdk.FeatureBrowserOnly: true, codersdk.FeatureSCIM: true, codersdk.FeatureMultipleExternalAuth: true, diff --git a/enterprise/coderd/prebuilds/reconcile.go b/enterprise/coderd/prebuilds/reconcile.go index 44e0e82c8881a..cce39ea251323 100644 --- a/enterprise/coderd/prebuilds/reconcile.go +++ b/enterprise/coderd/prebuilds/reconcile.go @@ -12,6 +12,7 @@ import ( "sync/atomic" "time" + "github.com/google/go-cmp/cmp" "github.com/hashicorp/go-multierror" "github.com/prometheus/client_golang/prometheus" @@ -398,11 +399,21 @@ func (c *StoreReconciler) SnapshotState(ctx context.Context, store database.Stor return xerrors.Errorf("failed to get preset prebuild schedules: %w", err) } + // Get results from both original and optimized queries for comparison allRunningPrebuilds, err := db.GetRunningPrebuiltWorkspaces(ctx) if err != nil { return xerrors.Errorf("failed to get running prebuilds: %w", err) } + // Compare with optimized query to ensure behavioral correctness + optimized, err := db.GetRunningPrebuiltWorkspacesOptimized(ctx) + if err != nil { + // Log the error but continue with original results + c.logger.Error(ctx, "optimized GetRunningPrebuiltWorkspacesOptimized failed", slog.Error(err)) + } else { + CompareGetRunningPrebuiltWorkspacesResults(ctx, c.logger, allRunningPrebuilds, optimized) + } + allPrebuildsInProgress, err := db.CountInProgressPrebuilds(ctx) if err != nil { return xerrors.Errorf("failed to get prebuilds in progress: %w", err) @@ -922,3 +933,30 @@ func SetPrebuildsReconciliationPaused(ctx context.Context, db database.Store, pa } return db.UpsertPrebuildsSettings(ctx, string(settingsJSON)) } + +// CompareGetRunningPrebuiltWorkspacesResults compares the original and optimized +// query results and logs any differences found. This function can be easily +// removed once we're confident the optimized query works correctly. +// TODO(Cian): Remove this function once the optimized query is stable and correct. +func CompareGetRunningPrebuiltWorkspacesResults( + ctx context.Context, + logger slog.Logger, + original []database.GetRunningPrebuiltWorkspacesRow, + optimized []database.GetRunningPrebuiltWorkspacesOptimizedRow, +) { + if len(original) == 0 && len(optimized) == 0 { + return + } + // Convert optimized results to the same type as original for comparison + optimizedConverted := make([]database.GetRunningPrebuiltWorkspacesRow, len(optimized)) + for i, row := range optimized { + optimizedConverted[i] = database.GetRunningPrebuiltWorkspacesRow(row) + } + + // Compare the results and log an error if they differ. + // NOTE: explicitly not sorting here as both query results are ordered by ID. + if diff := cmp.Diff(original, optimizedConverted); diff != "" { + logger.Error(ctx, "results differ for GetRunningPrebuiltWorkspacesOptimized", + slog.F("diff", diff)) + } +} diff --git a/enterprise/coderd/prebuilds/reconcile_test.go b/enterprise/coderd/prebuilds/reconcile_test.go index fce5269214ed1..858b01abc00b9 100644 --- a/enterprise/coderd/prebuilds/reconcile_test.go +++ b/enterprise/coderd/prebuilds/reconcile_test.go @@ -5,6 +5,7 @@ import ( "database/sql" "fmt" "sort" + "strings" "sync" "testing" "time" @@ -26,6 +27,7 @@ import ( "tailscale.com/types/ptr" "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogjson" "cdr.dev/slog/sloggers/slogtest" "github.com/coder/quartz" @@ -370,6 +372,8 @@ func TestPrebuildReconciliation(t *testing.T) { templateVersionID, ) + setupTestDBPrebuildAntagonists(t, db, pubSub, org) + if !templateVersionActive { // Create a new template version and mark it as active // This marks the template version that we care about as inactive @@ -2116,6 +2120,115 @@ func setupTestDBWorkspaceAgent(t *testing.T, db database.Store, workspaceID uuid return agent } +// setupTestDBAntagonists creates test antagonists that should not influence running prebuild workspace tests. +// 1. A stopped prebuilt workspace (STOP then START transitions, owned by +// prebuilds system user). +// 2. A running regular workspace (not owned by the prebuilds system user). +func setupTestDBPrebuildAntagonists(t *testing.T, db database.Store, ps pubsub.Pubsub, org database.Organization) { + t.Helper() + + templateAdmin := dbgen.User(t, db, database.User{RBACRoles: []string{codersdk.RoleTemplateAdmin}}) + _ = dbgen.OrganizationMember(t, db, database.OrganizationMember{ + OrganizationID: org.ID, + UserID: templateAdmin.ID, + }) + member := dbgen.User(t, db, database.User{}) + _ = dbgen.OrganizationMember(t, db, database.OrganizationMember{ + OrganizationID: org.ID, + UserID: member.ID, + }) + tpl := dbgen.Template(t, db, database.Template{ + OrganizationID: org.ID, + CreatedBy: templateAdmin.ID, + }) + tv := dbgen.TemplateVersion(t, db, database.TemplateVersion{ + TemplateID: uuid.NullUUID{UUID: tpl.ID, Valid: true}, + OrganizationID: org.ID, + CreatedBy: templateAdmin.ID, + }) + + // 1) Stopped prebuilt workspace (owned by prebuilds system user) + stoppedPrebuild := dbgen.Workspace(t, db, database.WorkspaceTable{ + OwnerID: database.PrebuildsSystemUserID, + TemplateID: tpl.ID, + Name: "prebuild-antagonist-stopped", + Deleted: false, + }) + + // STOP build (build number 2, most recent) + stoppedJob2 := dbgen.ProvisionerJob(t, db, ps, database.ProvisionerJob{ + OrganizationID: org.ID, + InitiatorID: database.PrebuildsSystemUserID, + Provisioner: database.ProvisionerTypeEcho, + Type: database.ProvisionerJobTypeWorkspaceBuild, + StartedAt: sql.NullTime{Time: dbtime.Now().Add(-30 * time.Second), Valid: true}, + CompletedAt: sql.NullTime{Time: dbtime.Now().Add(-20 * time.Second), Valid: true}, + Error: sql.NullString{}, + ErrorCode: sql.NullString{}, + }) + dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{ + WorkspaceID: stoppedPrebuild.ID, + TemplateVersionID: tv.ID, + JobID: stoppedJob2.ID, + BuildNumber: 2, + Transition: database.WorkspaceTransitionStop, + InitiatorID: database.PrebuildsSystemUserID, + Reason: database.BuildReasonInitiator, + // Explicitly not using a preset here. This shouldn't normally be possible, + // but without this the reconciler will try to create a new prebuild for + // this preset, which will affect the tests. + TemplateVersionPresetID: uuid.NullUUID{}, + }) + + // START build (build number 1, older) + stoppedJob1 := dbgen.ProvisionerJob(t, db, ps, database.ProvisionerJob{ + OrganizationID: org.ID, + InitiatorID: database.PrebuildsSystemUserID, + Provisioner: database.ProvisionerTypeEcho, + Type: database.ProvisionerJobTypeWorkspaceBuild, + StartedAt: sql.NullTime{Time: dbtime.Now().Add(-60 * time.Second), Valid: true}, + CompletedAt: sql.NullTime{Time: dbtime.Now().Add(-50 * time.Second), Valid: true}, + Error: sql.NullString{}, + ErrorCode: sql.NullString{}, + }) + dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{ + WorkspaceID: stoppedPrebuild.ID, + TemplateVersionID: tv.ID, + JobID: stoppedJob1.ID, + BuildNumber: 1, + Transition: database.WorkspaceTransitionStart, + InitiatorID: database.PrebuildsSystemUserID, + Reason: database.BuildReasonInitiator, + }) + + // 2) Running regular workspace (not owned by prebuilds system user) + regularWorkspace := dbgen.Workspace(t, db, database.WorkspaceTable{ + OwnerID: member.ID, + TemplateID: tpl.ID, + Name: "antagonist-regular-workspace", + Deleted: false, + }) + regularJob := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{ + OrganizationID: org.ID, + InitiatorID: member.ID, + Provisioner: database.ProvisionerTypeEcho, + Type: database.ProvisionerJobTypeWorkspaceBuild, + StartedAt: sql.NullTime{Time: dbtime.Now().Add(-40 * time.Second), Valid: true}, + CompletedAt: sql.NullTime{Time: dbtime.Now().Add(-30 * time.Second), Valid: true}, + Error: sql.NullString{}, + ErrorCode: sql.NullString{}, + }) + dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{ + WorkspaceID: regularWorkspace.ID, + TemplateVersionID: tv.ID, + JobID: regularJob.ID, + BuildNumber: 1, + Transition: database.WorkspaceTransitionStart, + InitiatorID: member.ID, + Reason: database.BuildReasonInitiator, + }) +} + var allTransitions = []database.WorkspaceTransition{ database.WorkspaceTransitionStart, database.WorkspaceTransitionStop, @@ -2220,3 +2333,164 @@ func TestReconciliationRespectsPauseSetting(t *testing.T) { require.NoError(t, err) require.Len(t, workspaces, 2, "should have recreated 2 prebuilds after resuming") } + +func TestCompareGetRunningPrebuiltWorkspacesResults(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + // Helper to create test data + createWorkspaceRow := func(id string, name string, ready bool) database.GetRunningPrebuiltWorkspacesRow { + uid := uuid.MustParse(id) + return database.GetRunningPrebuiltWorkspacesRow{ + ID: uid, + Name: name, + TemplateID: uuid.New(), + TemplateVersionID: uuid.New(), + CurrentPresetID: uuid.NullUUID{UUID: uuid.New(), Valid: true}, + Ready: ready, + CreatedAt: time.Now(), + } + } + + createOptimizedRow := func(row database.GetRunningPrebuiltWorkspacesRow) database.GetRunningPrebuiltWorkspacesOptimizedRow { + return database.GetRunningPrebuiltWorkspacesOptimizedRow(row) + } + + t.Run("identical results - no logging", func(t *testing.T) { + t.Parallel() + + var sb strings.Builder + logger := slog.Make(slogjson.Sink(&sb)) + + original := []database.GetRunningPrebuiltWorkspacesRow{ + createWorkspaceRow("550e8400-e29b-41d4-a716-446655440000", "workspace1", true), + createWorkspaceRow("550e8400-e29b-41d4-a716-446655440001", "workspace2", false), + } + + optimized := []database.GetRunningPrebuiltWorkspacesOptimizedRow{ + createOptimizedRow(original[0]), + createOptimizedRow(original[1]), + } + + prebuilds.CompareGetRunningPrebuiltWorkspacesResults(ctx, logger, original, optimized) + + // Should not log any errors when results are identical + require.Empty(t, strings.TrimSpace(sb.String())) + }) + + t.Run("count mismatch - logs error", func(t *testing.T) { + t.Parallel() + + var sb strings.Builder + logger := slog.Make(slogjson.Sink(&sb)) + + original := []database.GetRunningPrebuiltWorkspacesRow{ + createWorkspaceRow("550e8400-e29b-41d4-a716-446655440000", "workspace1", true), + } + + optimized := []database.GetRunningPrebuiltWorkspacesOptimizedRow{ + createOptimizedRow(original[0]), + createOptimizedRow(createWorkspaceRow("550e8400-e29b-41d4-a716-446655440001", "workspace2", false)), + } + + prebuilds.CompareGetRunningPrebuiltWorkspacesResults(ctx, logger, original, optimized) + + // Should log exactly one error. + if lines := strings.Split(strings.TrimSpace(sb.String()), "\n"); assert.NotEmpty(t, lines) { + require.Len(t, lines, 1) + assert.Contains(t, lines[0], "ERROR") + assert.Contains(t, lines[0], "workspace2") + assert.Contains(t, lines[0], "CurrentPresetID") + } + }) + + t.Run("count mismatch - other direction", func(t *testing.T) { + t.Parallel() + + var sb strings.Builder + logger := slog.Make(slogjson.Sink(&sb)) + + original := []database.GetRunningPrebuiltWorkspacesRow{} + + optimized := []database.GetRunningPrebuiltWorkspacesOptimizedRow{ + createOptimizedRow(createWorkspaceRow("550e8400-e29b-41d4-a716-446655440001", "workspace2", false)), + } + + prebuilds.CompareGetRunningPrebuiltWorkspacesResults(ctx, logger, original, optimized) + + if lines := strings.Split(strings.TrimSpace(sb.String()), "\n"); assert.NotEmpty(t, lines) { + require.Len(t, lines, 1) + assert.Contains(t, lines[0], "ERROR") + assert.Contains(t, lines[0], "workspace2") + assert.Contains(t, lines[0], "CurrentPresetID") + } + }) + + t.Run("field differences - logs errors", func(t *testing.T) { + t.Parallel() + + var sb strings.Builder + logger := slog.Make(slogjson.Sink(&sb)) + + workspace1 := createWorkspaceRow("550e8400-e29b-41d4-a716-446655440000", "workspace1", true) + workspace2 := createWorkspaceRow("550e8400-e29b-41d4-a716-446655440001", "workspace2", false) + + original := []database.GetRunningPrebuiltWorkspacesRow{workspace1, workspace2} + + // Create optimized with different values + optimized1 := createOptimizedRow(workspace1) + optimized1.Name = "different-name" // Different name + optimized1.Ready = false // Different ready status + + optimized2 := createOptimizedRow(workspace2) + optimized2.CurrentPresetID = uuid.NullUUID{Valid: false} // Different preset ID (NULL) + + optimized := []database.GetRunningPrebuiltWorkspacesOptimizedRow{optimized1, optimized2} + + prebuilds.CompareGetRunningPrebuiltWorkspacesResults(ctx, logger, original, optimized) + + // Should log exactly one error with a cmp.Diff output + if lines := strings.Split(strings.TrimSpace(sb.String()), "\n"); assert.NotEmpty(t, lines) { + require.Len(t, lines, 1) + assert.Contains(t, lines[0], "ERROR") + assert.Contains(t, lines[0], "different-name") + assert.Contains(t, lines[0], "workspace1") + assert.Contains(t, lines[0], "Ready") + assert.Contains(t, lines[0], "CurrentPresetID") + } + }) + + t.Run("empty results - no logging", func(t *testing.T) { + t.Parallel() + + var sb strings.Builder + logger := slog.Make(slogjson.Sink(&sb)) + + original := []database.GetRunningPrebuiltWorkspacesRow{} + optimized := []database.GetRunningPrebuiltWorkspacesOptimizedRow{} + + prebuilds.CompareGetRunningPrebuiltWorkspacesResults(ctx, logger, original, optimized) + + // Should not log any errors when both results are empty + require.Empty(t, strings.TrimSpace(sb.String())) + }) + + t.Run("nil original", func(t *testing.T) { + t.Parallel() + var sb strings.Builder + logger := slog.Make(slogjson.Sink(&sb)) + prebuilds.CompareGetRunningPrebuiltWorkspacesResults(ctx, logger, nil, []database.GetRunningPrebuiltWorkspacesOptimizedRow{}) + // Should not log any errors when original is nil + require.Empty(t, strings.TrimSpace(sb.String())) + }) + + t.Run("nil optimized ", func(t *testing.T) { + t.Parallel() + var sb strings.Builder + logger := slog.Make(slogjson.Sink(&sb)) + prebuilds.CompareGetRunningPrebuiltWorkspacesResults(ctx, logger, []database.GetRunningPrebuiltWorkspacesRow{}, nil) + // Should not log any errors when optimized is nil + require.Empty(t, strings.TrimSpace(sb.String())) + }) +} diff --git a/enterprise/coderd/proxyhealth/proxyhealth_test.go b/enterprise/coderd/proxyhealth/proxyhealth_test.go index 6879382192116..a002b6d9e7a09 100644 --- a/enterprise/coderd/proxyhealth/proxyhealth_test.go +++ b/enterprise/coderd/proxyhealth/proxyhealth_test.go @@ -12,7 +12,7 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbgen" - "github.com/coder/coder/v2/coderd/database/dbmem" + "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/enterprise/coderd/proxyhealth" @@ -46,7 +46,7 @@ func TestProxyHealth_Nil(t *testing.T) { func TestProxyHealth_Unregistered(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) proxies := []database.WorkspaceProxy{ insertProxy(t, db, ""), @@ -72,7 +72,7 @@ func TestProxyHealth_Unregistered(t *testing.T) { func TestProxyHealth_Unhealthy(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) srvBadReport := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { httpapi.Write(context.Background(), w, http.StatusOK, codersdk.ProxyHealthReport{ @@ -112,7 +112,7 @@ func TestProxyHealth_Unhealthy(t *testing.T) { func TestProxyHealth_Reachable(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { httpapi.Write(context.Background(), w, http.StatusOK, codersdk.ProxyHealthReport{ @@ -147,7 +147,7 @@ func TestProxyHealth_Reachable(t *testing.T) { func TestProxyHealth_Unreachable(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) cli := &http.Client{ Transport: &http.Transport{ diff --git a/enterprise/coderd/testdata/parameters/dynamicimmutable/main.tf b/enterprise/coderd/testdata/parameters/dynamicimmutable/main.tf new file mode 100644 index 0000000000000..08bdd3336faa9 --- /dev/null +++ b/enterprise/coderd/testdata/parameters/dynamicimmutable/main.tf @@ -0,0 +1,23 @@ +terraform { + required_providers { + coder = { + source = "coder/coder" + } + } +} + +data "coder_workspace_owner" "me" {} + +data "coder_parameter" "isimmutable" { + name = "isimmutable" + type = "bool" + mutable = true + default = "true" +} + +data "coder_parameter" "immutable" { + name = "immutable" + type = "string" + mutable = data.coder_parameter.isimmutable.value == "false" + default = "Hello World" +} diff --git a/enterprise/coderd/workspaces_test.go b/enterprise/coderd/workspaces_test.go index 3bed052702637..1030536f2111d 100644 --- a/enterprise/coderd/workspaces_test.go +++ b/enterprise/coderd/workspaces_test.go @@ -10,10 +10,17 @@ import ( "os" "os/exec" "path/filepath" + "strings" "sync/atomic" "testing" "time" + "github.com/prometheus/client_golang/prometheus" + + "github.com/coder/coder/v2/coderd/files" + agplprebuilds "github.com/coder/coder/v2/coderd/prebuilds" + "github.com/coder/coder/v2/enterprise/coderd/prebuilds" + "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -287,7 +294,9 @@ func TestCreateUserWorkspace(t *testing.T) { OrganizationID: first.OrganizationID, }) - template, _ := coderdtest.DynamicParameterTemplate(t, admin, first.OrganizationID, coderdtest.DynamicParameterTemplateParams{}) + template, _ := coderdtest.DynamicParameterTemplate(t, admin, first.OrganizationID, coderdtest.DynamicParameterTemplateParams{ + Zip: true, + }) ctx = testutil.Context(t, testutil.WaitLong) @@ -1713,6 +1722,793 @@ func TestTemplateDoesNotAllowUserAutostop(t *testing.T) { }) } +func TestExecutorPrebuilds(t *testing.T) { + t.Parallel() + + if !dbtestutil.WillUsePostgres() { + t.Skip("this test requires postgres") + } + + getRunningPrebuilds := func( + t *testing.T, + ctx context.Context, + db database.Store, + prebuildInstances int, + ) []database.GetRunningPrebuiltWorkspacesRow { + t.Helper() + + var runningPrebuilds []database.GetRunningPrebuiltWorkspacesRow + testutil.Eventually(ctx, t, func(context.Context) bool { + rows, err := db.GetRunningPrebuiltWorkspaces(ctx) + if err != nil { + return false + } + + for _, row := range rows { + runningPrebuilds = append(runningPrebuilds, row) + + agents, err := db.GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx, row.ID) + if err != nil { + return false + } + + for _, agent := range agents { + err = db.UpdateWorkspaceAgentLifecycleStateByID(ctx, database.UpdateWorkspaceAgentLifecycleStateByIDParams{ + ID: agent.ID, + LifecycleState: database.WorkspaceAgentLifecycleStateReady, + StartedAt: sql.NullTime{Time: time.Now().Add(time.Hour), Valid: true}, + ReadyAt: sql.NullTime{Time: time.Now().Add(-1 * time.Hour), Valid: true}, + }) + if err != nil { + return false + } + } + } + + t.Logf("found %d running prebuilds so far, want %d", len(runningPrebuilds), prebuildInstances) + return len(runningPrebuilds) == prebuildInstances + }, testutil.IntervalSlow, "prebuilds not running") + + return runningPrebuilds + } + + runReconciliationLoop := func( + t *testing.T, + ctx context.Context, + db database.Store, + reconciler *prebuilds.StoreReconciler, + presets []codersdk.Preset, + ) { + t.Helper() + + state, err := reconciler.SnapshotState(ctx, db) + require.NoError(t, err) + ps, err := state.FilterByPreset(presets[0].ID) + require.NoError(t, err) + require.NotNil(t, ps) + actions, err := reconciler.CalculateActions(ctx, *ps) + require.NoError(t, err) + require.NotNil(t, actions) + require.NoError(t, reconciler.ReconcilePreset(ctx, *ps)) + } + + claimPrebuild := func( + t *testing.T, + ctx context.Context, + client *codersdk.Client, + userClient *codersdk.Client, + username string, + version codersdk.TemplateVersion, + presetID uuid.UUID, + ) codersdk.Workspace { + t.Helper() + + workspaceName := strings.ReplaceAll(testutil.GetRandomName(t), "_", "-") + userWorkspace, err := userClient.CreateUserWorkspace(ctx, username, codersdk.CreateWorkspaceRequest{ + TemplateVersionID: version.ID, + Name: workspaceName, + TemplateVersionPresetID: presetID, + }) + require.NoError(t, err) + build := coderdtest.AwaitWorkspaceBuildJobCompleted(t, userClient, userWorkspace.LatestBuild.ID) + require.Equal(t, build.Job.Status, codersdk.ProvisionerJobSucceeded) + workspace := coderdtest.MustWorkspace(t, client, userWorkspace.ID) + assert.Equal(t, codersdk.WorkspaceTransitionStart, workspace.LatestBuild.Transition) + + return workspace + } + + // Prebuilt workspaces should not be autostopped based on the default TTL. + // This test ensures that DefaultTTLMillis is ignored while the workspace is in a prebuild state. + // Once the workspace is claimed, the default autostop timer should take effect. + t.Run("DefaultTTLOnlyTriggersAfterClaim", func(t *testing.T) { + t.Parallel() + + // Set the clock to Monday, January 1st, 2024 at 8:00 AM UTC to keep the test deterministic + clock := quartz.NewMock(t) + clock.Set(time.Date(2024, 1, 1, 8, 0, 0, 0, time.UTC)) + + // Setup + ctx := testutil.Context(t, testutil.WaitSuperLong) + db, pb := dbtestutil.NewDB(t, dbtestutil.WithDumpOnFailure()) + logger := testutil.Logger(t) + tickCh := make(chan time.Time) + statsCh := make(chan autobuild.Stats) + notificationsNoop := notifications.NewNoopEnqueuer() + client, _, api, owner := coderdenttest.NewWithAPI(t, &coderdenttest.Options{ + Options: &coderdtest.Options{ + Database: db, + Pubsub: pb, + AutobuildTicker: tickCh, + IncludeProvisionerDaemon: true, + AutobuildStats: statsCh, + Clock: clock, + TemplateScheduleStore: schedule.NewEnterpriseTemplateScheduleStore( + agplUserQuietHoursScheduleStore(), + notificationsNoop, + logger, + clock, + ), + }, + LicenseOptions: &coderdenttest.LicenseOptions{ + Features: license.Features{codersdk.FeatureAdvancedTemplateScheduling: 1}, + }, + }) + + // Setup Prebuild reconciler + cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) + reconciler := prebuilds.NewStoreReconciler( + db, pb, cache, + codersdk.PrebuildsConfig{}, + logger, + clock, + prometheus.NewRegistry(), + notificationsNoop, + ) + var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer(db) + api.AGPL.PrebuildsClaimer.Store(&claimer) + + // Setup user, template and template version with a preset with 1 prebuild instance + prebuildInstances := int32(1) + ttlTime := 2 * time.Hour + userClient, user := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID, rbac.RoleMember()) + version := coderdtest.CreateTemplateVersion(t, client, owner.OrganizationID, templateWithAgentAndPresetsWithPrebuilds(prebuildInstances)) + coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID) + coderdtest.CreateTemplate(t, client, owner.OrganizationID, version.ID, func(ctr *codersdk.CreateTemplateRequest) { + // Set a template level TTL to trigger the autostop + // Template level TTL can only be set if autostop is disabled for users + ctr.AllowUserAutostop = ptr.Ref[bool](false) + ctr.DefaultTTLMillis = ptr.Ref[int64](ttlTime.Milliseconds()) + }) + presets, err := client.TemplateVersionPresets(ctx, version.ID) + require.NoError(t, err) + require.Len(t, presets, 1) + + // Given: Reconciliation loop runs and starts prebuilt workspace + runReconciliationLoop(t, ctx, db, reconciler, presets) + runningPrebuilds := getRunningPrebuilds(t, ctx, db, int(prebuildInstances)) + require.Len(t, runningPrebuilds, int(prebuildInstances)) + + // Given: a running prebuilt workspace with a deadline, ready to be claimed + prebuild := coderdtest.MustWorkspace(t, client, runningPrebuilds[0].ID) + require.Equal(t, codersdk.WorkspaceTransitionStart, prebuild.LatestBuild.Transition) + require.NotZero(t, prebuild.LatestBuild.Deadline) + + // When: the autobuild executor ticks *after* the deadline + next := prebuild.LatestBuild.Deadline.Time.Add(time.Minute) + clock.Set(next) + go func() { + tickCh <- next + }() + + // Then: the prebuilt workspace should remain in a start transition + prebuildStats := <-statsCh + require.Len(t, prebuildStats.Errors, 0) + require.Len(t, prebuildStats.Transitions, 0) + require.Equal(t, codersdk.WorkspaceTransitionStart, prebuild.LatestBuild.Transition) + prebuild = coderdtest.MustWorkspace(t, client, prebuild.ID) + require.Equal(t, codersdk.BuildReasonInitiator, prebuild.LatestBuild.Reason) + + // Given: a user claims the prebuilt workspace sometime later + clock.Set(clock.Now().Add(ttlTime)) + workspace := claimPrebuild(t, ctx, client, userClient, user.Username, version, presets[0].ID) + require.Equal(t, prebuild.ID, workspace.ID) + // Workspace deadline must be ttlTime from the time it is claimed + require.True(t, workspace.LatestBuild.Deadline.Time.Equal(clock.Now().Add(ttlTime))) + + // When: the autobuild executor ticks *after* the deadline + next = workspace.LatestBuild.Deadline.Time.Add(time.Minute) + clock.Set(next) + go func() { + tickCh <- next + close(tickCh) + }() + + // Then: the workspace should be stopped + workspaceStats := <-statsCh + require.Len(t, workspaceStats.Errors, 0) + require.Len(t, workspaceStats.Transitions, 1) + require.Contains(t, workspaceStats.Transitions, workspace.ID) + require.Equal(t, database.WorkspaceTransitionStop, workspaceStats.Transitions[workspace.ID]) + workspace = coderdtest.MustWorkspace(t, client, workspace.ID) + require.Equal(t, codersdk.BuildReasonAutostop, workspace.LatestBuild.Reason) + }) + + // Prebuild workspaces should not follow the autostop schedule. + // This test verifies that AutostopRequirement (autostop schedule) is ignored while the workspace is a prebuild. + // After being claimed, the workspace should be stopped according to the autostop schedule. + t.Run("AutostopScheduleOnlyTriggersAfterClaim", func(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + isClaimedBeforeDeadline bool + }{ + // If the prebuild is claimed before the scheduled deadline, + // the claimed workspace should inherit and respect that same deadline. + { + name: "ClaimedBeforeDeadline_UsesSameDeadline", + isClaimedBeforeDeadline: true, + }, + // If the prebuild is claimed after the scheduled deadline, + // the workspace should not stop immediately, but instead respect the next + // valid scheduled deadline (the next day). + { + name: "ClaimedAfterDeadline_SchedulesForNextDay", + isClaimedBeforeDeadline: false, + }, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + // Set the clock to Monday, January 1st, 2024 at 8:00 AM UTC to keep the test deterministic + clock := quartz.NewMock(t) + clock.Set(time.Date(2024, 1, 1, 8, 0, 0, 0, time.UTC)) + + // Setup + ctx := testutil.Context(t, testutil.WaitSuperLong) + db, pb := dbtestutil.NewDB(t, dbtestutil.WithDumpOnFailure()) + logger := testutil.Logger(t) + tickCh := make(chan time.Time) + statsCh := make(chan autobuild.Stats) + notificationsNoop := notifications.NewNoopEnqueuer() + client, _, api, owner := coderdenttest.NewWithAPI(t, &coderdenttest.Options{ + Options: &coderdtest.Options{ + Database: db, + Pubsub: pb, + AutobuildTicker: tickCh, + IncludeProvisionerDaemon: true, + AutobuildStats: statsCh, + Clock: clock, + TemplateScheduleStore: schedule.NewEnterpriseTemplateScheduleStore( + agplUserQuietHoursScheduleStore(), + notificationsNoop, + logger, + clock, + ), + }, + LicenseOptions: &coderdenttest.LicenseOptions{ + Features: license.Features{codersdk.FeatureAdvancedTemplateScheduling: 1}, + }, + }) + + // Setup Prebuild reconciler + cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) + reconciler := prebuilds.NewStoreReconciler( + db, pb, cache, + codersdk.PrebuildsConfig{}, + logger, + clock, + prometheus.NewRegistry(), + notificationsNoop, + ) + var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer(db) + api.AGPL.PrebuildsClaimer.Store(&claimer) + + // Setup user, template and template version with a preset with 1 prebuild instance + prebuildInstances := int32(1) + userClient, user := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID, rbac.RoleMember()) + version := coderdtest.CreateTemplateVersion(t, client, owner.OrganizationID, templateWithAgentAndPresetsWithPrebuilds(prebuildInstances)) + coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID) + coderdtest.CreateTemplate(t, client, owner.OrganizationID, version.ID, func(ctr *codersdk.CreateTemplateRequest) { + // Set a template level Autostop schedule to trigger the autostop daily + ctr.AutostopRequirement = ptr.Ref[codersdk.TemplateAutostopRequirement]( + codersdk.TemplateAutostopRequirement{ + DaysOfWeek: []string{"monday", "tuesday", "wednesday", "thursday", "friday", "saturday", "sunday"}, + Weeks: 1, + }) + }) + presets, err := client.TemplateVersionPresets(ctx, version.ID) + require.NoError(t, err) + require.Len(t, presets, 1) + + // Given: Reconciliation loop runs and starts prebuilt workspace + runReconciliationLoop(t, ctx, db, reconciler, presets) + runningPrebuilds := getRunningPrebuilds(t, ctx, db, int(prebuildInstances)) + require.Len(t, runningPrebuilds, int(prebuildInstances)) + + // Given: a running prebuilt workspace with a deadline, ready to be claimed + prebuild := coderdtest.MustWorkspace(t, client, runningPrebuilds[0].ID) + require.Equal(t, codersdk.WorkspaceTransitionStart, prebuild.LatestBuild.Transition) + require.NotZero(t, prebuild.LatestBuild.Deadline) + + next := clock.Now() + if tc.isClaimedBeforeDeadline { + // When: the autobuild executor ticks *before* the deadline: + next = next.Add(time.Minute) + } else { + // When: the autobuild executor ticks *after* the deadline: + next = next.Add(24 * time.Hour) + } + + clock.Set(next) + go func() { + tickCh <- next + }() + + // Then: the prebuilt workspace should remain in a start transition + prebuildStats := <-statsCh + require.Len(t, prebuildStats.Errors, 0) + require.Len(t, prebuildStats.Transitions, 0) + require.Equal(t, codersdk.WorkspaceTransitionStart, prebuild.LatestBuild.Transition) + prebuild = coderdtest.MustWorkspace(t, client, prebuild.ID) + require.Equal(t, codersdk.BuildReasonInitiator, prebuild.LatestBuild.Reason) + + // Given: a user claims the prebuilt workspace + workspace := claimPrebuild(t, ctx, client, userClient, user.Username, version, presets[0].ID) + require.Equal(t, prebuild.ID, workspace.ID) + + if tc.isClaimedBeforeDeadline { + // Then: the claimed workspace should inherit and respect that same deadline. + require.True(t, workspace.LatestBuild.Deadline.Time.Equal(prebuild.LatestBuild.Deadline.Time)) + } else { + // Then: the claimed workspace should respect the next valid scheduled deadline (next day). + require.True(t, workspace.LatestBuild.Deadline.Time.Equal(clock.Now().Truncate(24*time.Hour).Add(24*time.Hour))) + } + + // When: the autobuild executor ticks *after* the deadline: + next = workspace.LatestBuild.Deadline.Time.Add(time.Minute) + clock.Set(next) + go func() { + tickCh <- next + close(tickCh) + }() + + // Then: the workspace should be stopped + workspaceStats := <-statsCh + require.Len(t, workspaceStats.Errors, 0) + require.Len(t, workspaceStats.Transitions, 1) + require.Contains(t, workspaceStats.Transitions, workspace.ID) + require.Equal(t, database.WorkspaceTransitionStop, workspaceStats.Transitions[workspace.ID]) + workspace = coderdtest.MustWorkspace(t, client, workspace.ID) + require.Equal(t, codersdk.BuildReasonAutostop, workspace.LatestBuild.Reason) + }) + } + }) + + // Prebuild workspaces should not follow the autostart schedule. + // This test verifies that AutostartRequirement (autostart schedule) is ignored while the workspace is a prebuild. + t.Run("AutostartScheduleOnlyTriggersAfterClaim", func(t *testing.T) { + t.Parallel() + + // Set the clock to dbtime.Now() to match the workspace build's CreatedAt + clock := quartz.NewMock(t) + clock.Set(dbtime.Now()) + + // Setup + ctx := testutil.Context(t, testutil.WaitSuperLong) + db, pb := dbtestutil.NewDB(t, dbtestutil.WithDumpOnFailure()) + logger := testutil.Logger(t) + tickCh := make(chan time.Time) + statsCh := make(chan autobuild.Stats) + notificationsNoop := notifications.NewNoopEnqueuer() + client, _, api, owner := coderdenttest.NewWithAPI(t, &coderdenttest.Options{ + Options: &coderdtest.Options{ + Database: db, + Pubsub: pb, + AutobuildTicker: tickCh, + IncludeProvisionerDaemon: true, + AutobuildStats: statsCh, + Clock: clock, + TemplateScheduleStore: schedule.NewEnterpriseTemplateScheduleStore( + agplUserQuietHoursScheduleStore(), + notificationsNoop, + logger, + clock, + ), + }, + LicenseOptions: &coderdenttest.LicenseOptions{ + Features: license.Features{codersdk.FeatureAdvancedTemplateScheduling: 1}, + }, + }) + + // Setup Prebuild reconciler + cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) + reconciler := prebuilds.NewStoreReconciler( + db, pb, cache, + codersdk.PrebuildsConfig{}, + logger, + clock, + prometheus.NewRegistry(), + notificationsNoop, + ) + var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer(db) + api.AGPL.PrebuildsClaimer.Store(&claimer) + + // Setup user, template and template version with a preset with 1 prebuild instance + prebuildInstances := int32(1) + userClient, user := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID, rbac.RoleMember()) + version := coderdtest.CreateTemplateVersion(t, client, owner.OrganizationID, templateWithAgentAndPresetsWithPrebuilds(prebuildInstances)) + coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID) + coderdtest.CreateTemplate(t, client, owner.OrganizationID, version.ID, func(ctr *codersdk.CreateTemplateRequest) { + // Set a template level Autostart schedule to trigger the autostart daily + ctr.AllowUserAutostart = ptr.Ref[bool](true) + ctr.AutostartRequirement = &codersdk.TemplateAutostartRequirement{DaysOfWeek: codersdk.AllDaysOfWeek} + }) + presets, err := client.TemplateVersionPresets(ctx, version.ID) + require.NoError(t, err) + require.Len(t, presets, 1) + + // Given: Reconciliation loop runs and starts prebuilt workspace + runReconciliationLoop(t, ctx, db, reconciler, presets) + runningPrebuilds := getRunningPrebuilds(t, ctx, db, int(prebuildInstances)) + require.Len(t, runningPrebuilds, int(prebuildInstances)) + + // Given: prebuilt workspace has autostart schedule daily at midnight + prebuild := coderdtest.MustWorkspace(t, client, runningPrebuilds[0].ID) + sched, err := cron.Weekly("CRON_TZ=UTC 0 0 * * *") + require.NoError(t, err) + err = client.UpdateWorkspaceAutostart(ctx, prebuild.ID, codersdk.UpdateWorkspaceAutostartRequest{ + Schedule: ptr.Ref(sched.String()), + }) + require.NoError(t, err) + + // Given: prebuilt workspace is stopped + prebuild = coderdtest.MustTransitionWorkspace(t, client, prebuild.ID, codersdk.WorkspaceTransitionStart, codersdk.WorkspaceTransitionStop) + coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, prebuild.LatestBuild.ID) + + // Tick at the next scheduled time after the prebuild’s LatestBuild.CreatedAt, + // since the next allowed autostart is calculated starting from that point. + // When: the autobuild executor ticks after the scheduled time + go func() { + tickCh <- sched.Next(prebuild.LatestBuild.CreatedAt).Add(time.Minute) + }() + + // Then: the prebuilt workspace should remain in a stop transition + prebuildStats := <-statsCh + require.Len(t, prebuildStats.Errors, 0) + require.Len(t, prebuildStats.Transitions, 0) + require.Equal(t, codersdk.WorkspaceTransitionStop, prebuild.LatestBuild.Transition) + prebuild = coderdtest.MustWorkspace(t, client, prebuild.ID) + require.Equal(t, codersdk.BuildReasonInitiator, prebuild.LatestBuild.Reason) + + // Given: a prebuilt workspace that is running and ready to be claimed + prebuild = coderdtest.MustTransitionWorkspace(t, client, prebuild.ID, codersdk.WorkspaceTransitionStop, codersdk.WorkspaceTransitionStart) + coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, prebuild.LatestBuild.ID) + + // Make sure the workspace's agent is again ready + getRunningPrebuilds(t, ctx, db, int(prebuildInstances)) + + // Given: a user claims the prebuilt workspace + workspace := claimPrebuild(t, ctx, client, userClient, user.Username, version, presets[0].ID) + require.Equal(t, prebuild.ID, workspace.ID) + require.NotNil(t, workspace.NextStartAt) + + // Given: workspace is stopped + workspace = coderdtest.MustTransitionWorkspace(t, client, workspace.ID, codersdk.WorkspaceTransitionStart, codersdk.WorkspaceTransitionStop) + coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) + + // Then: the claimed workspace should inherit and respect that same NextStartAt + require.True(t, workspace.NextStartAt.Equal(*prebuild.NextStartAt)) + + // Tick at the next scheduled time after the prebuild’s LatestBuild.CreatedAt, + // since the next allowed autostart is calculated starting from that point. + // When: the autobuild executor ticks after the scheduled time + go func() { + tickCh <- sched.Next(prebuild.LatestBuild.CreatedAt).Add(time.Minute) + }() + + // Then: the workspace should have a NextStartAt equal to the next autostart schedule + workspaceStats := <-statsCh + require.Len(t, workspaceStats.Errors, 0) + require.Len(t, workspaceStats.Transitions, 1) + workspace = coderdtest.MustWorkspace(t, client, workspace.ID) + require.NotNil(t, workspace.NextStartAt) + require.Equal(t, sched.Next(clock.Now()), workspace.NextStartAt.UTC()) + }) + + // Prebuild workspaces should not transition to dormant when the inactive TTL is reached. + // This test verifies that TimeTilDormantMillis is ignored while the workspace is a prebuild. + // After being claimed, the workspace should become dormant according to the configured inactivity period. + t.Run("DormantOnlyAfterClaimed", func(t *testing.T) { + t.Parallel() + + // Set the clock to Monday, January 1st, 2024 at 8:00 AM UTC to keep the test deterministic + clock := quartz.NewMock(t) + clock.Set(time.Date(2024, 1, 1, 8, 0, 0, 0, time.UTC)) + + // Setup + ctx := testutil.Context(t, testutil.WaitSuperLong) + db, pb := dbtestutil.NewDB(t, dbtestutil.WithDumpOnFailure()) + logger := testutil.Logger(t) + tickCh := make(chan time.Time) + statsCh := make(chan autobuild.Stats) + notificationsNoop := notifications.NewNoopEnqueuer() + client, _, api, owner := coderdenttest.NewWithAPI(t, &coderdenttest.Options{ + Options: &coderdtest.Options{ + Database: db, + Pubsub: pb, + AutobuildTicker: tickCh, + IncludeProvisionerDaemon: true, + AutobuildStats: statsCh, + Clock: clock, + TemplateScheduleStore: schedule.NewEnterpriseTemplateScheduleStore( + agplUserQuietHoursScheduleStore(), + notificationsNoop, + logger, + clock, + ), + }, + LicenseOptions: &coderdenttest.LicenseOptions{ + Features: license.Features{codersdk.FeatureAdvancedTemplateScheduling: 1}, + }, + }) + + // Setup Prebuild reconciler + cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) + reconciler := prebuilds.NewStoreReconciler( + db, pb, cache, + codersdk.PrebuildsConfig{}, + logger, + clock, + prometheus.NewRegistry(), + notificationsNoop, + ) + var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer(db) + api.AGPL.PrebuildsClaimer.Store(&claimer) + + // Setup user, template and template version with a preset with 1 prebuild instance + prebuildInstances := int32(1) + inactiveTTL := 2 * time.Hour + userClient, user := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID, rbac.RoleMember()) + version := coderdtest.CreateTemplateVersion(t, client, owner.OrganizationID, templateWithAgentAndPresetsWithPrebuilds(prebuildInstances)) + coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID) + coderdtest.CreateTemplate(t, client, owner.OrganizationID, version.ID, func(ctr *codersdk.CreateTemplateRequest) { + // Set a template level inactive TTL to trigger dormancy + ctr.TimeTilDormantMillis = ptr.Ref[int64](inactiveTTL.Milliseconds()) + }) + presets, err := client.TemplateVersionPresets(ctx, version.ID) + require.NoError(t, err) + require.Len(t, presets, 1) + + // Given: reconciliation loop runs and starts prebuilt workspace + runReconciliationLoop(t, ctx, db, reconciler, presets) + runningPrebuilds := getRunningPrebuilds(t, ctx, db, int(prebuildInstances)) + require.Len(t, runningPrebuilds, int(prebuildInstances)) + + // Given: a running prebuilt workspace, ready to be claimed + prebuild := coderdtest.MustWorkspace(t, client, runningPrebuilds[0].ID) + require.Equal(t, codersdk.WorkspaceTransitionStart, prebuild.LatestBuild.Transition) + + // When: the autobuild executor ticks *after* the inactive TTL + go func() { + tickCh <- prebuild.LastUsedAt.Add(inactiveTTL).Add(time.Minute) + }() + + // Then: the prebuilt workspace should remain in a start transition + prebuildStats := <-statsCh + require.Len(t, prebuildStats.Errors, 0) + require.Len(t, prebuildStats.Transitions, 0) + require.Equal(t, codersdk.WorkspaceTransitionStart, prebuild.LatestBuild.Transition) + prebuild = coderdtest.MustWorkspace(t, client, prebuild.ID) + require.Equal(t, codersdk.BuildReasonInitiator, prebuild.LatestBuild.Reason) + + // Given: a user claims the prebuilt workspace sometime later + clock.Set(clock.Now().Add(inactiveTTL)) + workspace := claimPrebuild(t, ctx, client, userClient, user.Username, version, presets[0].ID) + require.Equal(t, prebuild.ID, workspace.ID) + require.Nil(t, prebuild.DormantAt) + + // When: the autobuild executor ticks *after* the inactive TTL + go func() { + tickCh <- prebuild.LastUsedAt.Add(inactiveTTL).Add(time.Minute) + close(tickCh) + }() + + // Then: the workspace should transition to stopped state for breaching failure TTL + workspaceStats := <-statsCh + require.Len(t, workspaceStats.Errors, 0) + require.Len(t, workspaceStats.Transitions, 1) + require.Contains(t, workspaceStats.Transitions, workspace.ID) + require.Equal(t, database.WorkspaceTransitionStop, workspaceStats.Transitions[workspace.ID]) + workspace = coderdtest.MustWorkspace(t, client, workspace.ID) + require.Equal(t, codersdk.BuildReasonDormancy, workspace.LatestBuild.Reason) + require.NotNil(t, workspace.DormantAt) + }) + + // Prebuild workspaces should not be deleted when the failure TTL is reached. + // This test verifies that FailureTTLMillis is ignored while the workspace is a prebuild. + t.Run("FailureTTLOnlyAfterClaimed", func(t *testing.T) { + t.Parallel() + + // Set the clock to Monday, January 1st, 2024 at 8:00 AM UTC to keep the test deterministic + clock := quartz.NewMock(t) + clock.Set(time.Date(2024, 1, 1, 8, 0, 0, 0, time.UTC)) + + // Setup + ctx := testutil.Context(t, testutil.WaitSuperLong) + db, pb := dbtestutil.NewDB(t, dbtestutil.WithDumpOnFailure()) + logger := testutil.Logger(t) + tickCh := make(chan time.Time) + statsCh := make(chan autobuild.Stats) + notificationsNoop := notifications.NewNoopEnqueuer() + client, _, api, owner := coderdenttest.NewWithAPI(t, &coderdenttest.Options{ + Options: &coderdtest.Options{ + Database: db, + Pubsub: pb, + AutobuildTicker: tickCh, + IncludeProvisionerDaemon: true, + AutobuildStats: statsCh, + Clock: clock, + TemplateScheduleStore: schedule.NewEnterpriseTemplateScheduleStore( + agplUserQuietHoursScheduleStore(), + notificationsNoop, + logger, + clock, + ), + }, + LicenseOptions: &coderdenttest.LicenseOptions{ + Features: license.Features{ + codersdk.FeatureAdvancedTemplateScheduling: 1, + }, + }, + }) + + // Setup Prebuild reconciler + cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) + reconciler := prebuilds.NewStoreReconciler( + db, pb, cache, + codersdk.PrebuildsConfig{}, + logger, + clock, + prometheus.NewRegistry(), + notificationsNoop, + ) + var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer(db) + api.AGPL.PrebuildsClaimer.Store(&claimer) + + // Setup user, template and template version with a preset with 1 prebuild instance + prebuildInstances := int32(1) + failureTTL := 2 * time.Hour + version := coderdtest.CreateTemplateVersion(t, client, owner.OrganizationID, templateWithFailedResponseAndPresetsWithPrebuilds(prebuildInstances)) + coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID) + template := coderdtest.CreateTemplate(t, client, owner.OrganizationID, version.ID, func(ctr *codersdk.CreateTemplateRequest) { + // Set a template level Failure TTL to trigger workspace deletion + ctr.FailureTTLMillis = ptr.Ref[int64](failureTTL.Milliseconds()) + }) + presets, err := client.TemplateVersionPresets(ctx, version.ID) + require.NoError(t, err) + require.Len(t, presets, 1) + + // Given: reconciliation loop runs and starts prebuilt workspace in failed state + runReconciliationLoop(t, ctx, db, reconciler, presets) + + var failedWorkspaceBuilds []database.GetFailedWorkspaceBuildsByTemplateIDRow + require.Eventually(t, func() bool { + rows, err := db.GetFailedWorkspaceBuildsByTemplateID(ctx, database.GetFailedWorkspaceBuildsByTemplateIDParams{ + TemplateID: template.ID, + }) + if err != nil { + return false + } + + failedWorkspaceBuilds = append(failedWorkspaceBuilds, rows...) + + t.Logf("found %d failed prebuilds so far, want %d", len(failedWorkspaceBuilds), prebuildInstances) + return len(failedWorkspaceBuilds) == int(prebuildInstances) + }, testutil.WaitSuperLong, testutil.IntervalSlow) + require.Len(t, failedWorkspaceBuilds, int(prebuildInstances)) + + // Given: a failed prebuilt workspace + prebuild := coderdtest.MustWorkspace(t, client, failedWorkspaceBuilds[0].WorkspaceID) + require.Equal(t, codersdk.WorkspaceStatusFailed, prebuild.LatestBuild.Status) + + // When: the autobuild executor ticks *after* the failure TTL + go func() { + tickCh <- prebuild.LatestBuild.Job.CompletedAt.Add(failureTTL * 2) + }() + + // Then: the prebuilt workspace should remain in a start transition + prebuildStats := <-statsCh + require.Len(t, prebuildStats.Errors, 0) + require.Len(t, prebuildStats.Transitions, 0) + require.Equal(t, codersdk.WorkspaceTransitionStart, prebuild.LatestBuild.Transition) + prebuild = coderdtest.MustWorkspace(t, client, prebuild.ID) + require.Equal(t, codersdk.BuildReasonInitiator, prebuild.LatestBuild.Reason) + }) +} + +func templateWithAgentAndPresetsWithPrebuilds(desiredInstances int32) *echo.Responses { + return &echo.Responses{ + Parse: echo.ParseComplete, + ProvisionPlan: []*proto.Response{ + { + Type: &proto.Response_Plan{ + Plan: &proto.PlanComplete{ + Presets: []*proto.Preset{ + { + Name: "preset-test", + Parameters: []*proto.PresetParameter{ + { + Name: "k1", + Value: "v1", + }, + }, + Prebuild: &proto.Prebuild{ + Instances: desiredInstances, + }, + }, + }, + }, + }, + }, + }, + ProvisionApply: []*proto.Response{ + { + Type: &proto.Response_Apply{ + Apply: &proto.ApplyComplete{ + Resources: []*proto.Resource{ + { + Type: "compute", + Name: "main", + Agents: []*proto.Agent{ + { + Name: "smith", + OperatingSystem: "linux", + Architecture: "i386", + }, + }, + }, + }, + }, + }, + }, + }, + } +} + +func templateWithFailedResponseAndPresetsWithPrebuilds(desiredInstances int32) *echo.Responses { + return &echo.Responses{ + Parse: echo.ParseComplete, + ProvisionPlan: []*proto.Response{ + { + Type: &proto.Response_Plan{ + Plan: &proto.PlanComplete{ + Presets: []*proto.Preset{ + { + Name: "preset-test", + Parameters: []*proto.PresetParameter{ + { + Name: "k1", + Value: "v1", + }, + }, + Prebuild: &proto.Prebuild{ + Instances: desiredInstances, + }, + }, + }, + }, + }, + }, + }, + ProvisionApply: echo.ApplyFailed, + } +} + // TestWorkspaceTemplateParamsChange tests a workspace with a parameter that // validation changes on apply. The params used in create workspace are invalid // according to the static params on import. diff --git a/enterprise/replicasync/replicasync_test.go b/enterprise/replicasync/replicasync_test.go index 1a9fd50e81223..0438db8e21673 100644 --- a/enterprise/replicasync/replicasync_test.go +++ b/enterprise/replicasync/replicasync_test.go @@ -16,10 +16,8 @@ import ( "go.uber.org/goleak" "github.com/coder/coder/v2/coderd/database" - "github.com/coder/coder/v2/coderd/database/dbmem" "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/database/dbtime" - "github.com/coder/coder/v2/coderd/database/pubsub" "github.com/coder/coder/v2/enterprise/replicasync" "github.com/coder/coder/v2/testutil" ) @@ -215,11 +213,7 @@ func TestReplica(t *testing.T) { t.Parallel() ctx, cancelCtx := context.WithCancel(context.Background()) defer cancelCtx() - // This doesn't use the database fake because creating - // this many PostgreSQL connections takes some - // configuration tweaking. - db := dbmem.New() - pubsub := pubsub.NewInMemory() + db, pubsub := dbtestutil.NewDB(t) logger := testutil.Logger(t) dh := &derpyHandler{} defer dh.requireOnlyDERPPaths(t) diff --git a/enterprise/trialer/trialer_test.go b/enterprise/trialer/trialer_test.go index 7149044a3e89f..575c945fe3d8f 100644 --- a/enterprise/trialer/trialer_test.go +++ b/enterprise/trialer/trialer_test.go @@ -8,7 +8,7 @@ import ( "github.com/stretchr/testify/require" - "github.com/coder/coder/v2/coderd/database/dbmem" + "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/enterprise/coderd/coderdenttest" "github.com/coder/coder/v2/enterprise/trialer" @@ -24,10 +24,12 @@ func TestTrialer(t *testing.T) { _, _ = w.Write([]byte(license)) })) defer srv.Close() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) + err := db.InsertDeploymentID(context.Background(), "test-deployment") + require.NoError(t, err) gen := trialer.New(db, srv.URL, coderdenttest.Keys) - err := gen(context.Background(), codersdk.LicensorTrialRequest{Email: "kyle+colin@coder.com"}) + err = gen(context.Background(), codersdk.LicensorTrialRequest{Email: "kyle+colin@coder.com"}) require.NoError(t, err) licenses, err := db.GetLicenses(context.Background()) require.NoError(t, err) diff --git a/go.mod b/go.mod index 886515cf29dbf..fa91932ceaecf 100644 --- a/go.mod +++ b/go.mod @@ -341,7 +341,7 @@ require ( github.com/hashicorp/go-terraform-address v0.0.0-20240523040243-ccea9d309e0c github.com/hashicorp/go-uuid v1.0.3 // indirect github.com/hashicorp/hcl v1.0.1-vault-7 // indirect - github.com/hashicorp/hcl/v2 v2.23.0 + github.com/hashicorp/hcl/v2 v2.24.0 github.com/hashicorp/logutils v1.0.0 // indirect github.com/hashicorp/terraform-plugin-go v0.27.0 // indirect github.com/hashicorp/terraform-plugin-log v0.9.0 // indirect @@ -485,7 +485,7 @@ require ( github.com/coder/aisdk-go v0.0.9 github.com/coder/preview v1.0.3-0.20250701142654-c3d6e86b9393 github.com/fsnotify/fsnotify v1.9.0 - github.com/mark3labs/mcp-go v0.32.0 + github.com/mark3labs/mcp-go v0.33.0 ) require ( diff --git a/go.sum b/go.sum index ded3464d585b3..e46a4eb61a477 100644 --- a/go.sum +++ b/go.sum @@ -1376,8 +1376,8 @@ github.com/hashicorp/hc-install v0.9.2 h1:v80EtNX4fCVHqzL9Lg/2xkp62bbvQMnvPQ0G+O github.com/hashicorp/hc-install v0.9.2/go.mod h1:XUqBQNnuT4RsxoxiM9ZaUk0NX8hi2h+Lb6/c0OZnC/I= github.com/hashicorp/hcl v1.0.1-vault-7 h1:ag5OxFVy3QYTFTJODRzTKVZ6xvdfLLCA1cy/Y6xGI0I= github.com/hashicorp/hcl v1.0.1-vault-7/go.mod h1:XYhtn6ijBSAj6n4YqAaf7RBPS4I06AItNorpy+MoQNM= -github.com/hashicorp/hcl/v2 v2.23.0 h1:Fphj1/gCylPxHutVSEOf2fBOh1VE4AuLV7+kbJf3qos= -github.com/hashicorp/hcl/v2 v2.23.0/go.mod h1:62ZYHrXgPoX8xBnzl8QzbWq4dyDsDtfCRgIq1rbJEvA= +github.com/hashicorp/hcl/v2 v2.24.0 h1:2QJdZ454DSsYGoaE6QheQZjtKZSUs9Nh2izTWiwQxvE= +github.com/hashicorp/hcl/v2 v2.24.0/go.mod h1:oGoO1FIQYfn/AgyOhlg9qLC6/nOJPX3qGbkZpYAcqfM= github.com/hashicorp/logutils v1.0.0 h1:dLEQVugN8vlakKOUE3ihGLTZJRB4j+M2cdTm/ORI65Y= github.com/hashicorp/logutils v1.0.0/go.mod h1:QIAnNjmIWmVIIkWDTG1z5v++HQmx9WQRO+LraFDTW64= github.com/hashicorp/terraform-exec v0.23.0 h1:MUiBM1s0CNlRFsCLJuM5wXZrzA3MnPYEsiXmzATMW/I= @@ -1503,8 +1503,8 @@ github.com/makeworld-the-better-one/dither/v2 v2.4.0 h1:Az/dYXiTcwcRSe59Hzw4RI1r github.com/makeworld-the-better-one/dither/v2 v2.4.0/go.mod h1:VBtN8DXO7SNtyGmLiGA7IsFeKrBkQPze1/iAeM95arc= github.com/marekm4/color-extractor v1.2.1 h1:3Zb2tQsn6bITZ8MBVhc33Qn1k5/SEuZ18mrXGUqIwn0= github.com/marekm4/color-extractor v1.2.1/go.mod h1:90VjmiHI6M8ez9eYUaXLdcKnS+BAOp7w+NpwBdkJmpA= -github.com/mark3labs/mcp-go v0.32.0 h1:fgwmbfL2gbd67obg57OfV2Dnrhs1HtSdlY/i5fn7MU8= -github.com/mark3labs/mcp-go v0.32.0/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4= +github.com/mark3labs/mcp-go v0.33.0 h1:naxhjnTIs/tyPZmWUZFuG0lDmdA6sUyYGGf3gsHvTCc= +github.com/mark3labs/mcp-go v0.33.0/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4= github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= github.com/mattn/go-colorable v0.1.9/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= diff --git a/scaletest/workspacebuild/run.go b/scaletest/workspacebuild/run.go index f19c556823faf..3b369a4e48a72 100644 --- a/scaletest/workspacebuild/run.go +++ b/scaletest/workspacebuild/run.go @@ -150,7 +150,7 @@ func (r *CleanupRunner) Run(ctx context.Context, _ string, logs io.Writer) error if err == nil && build.Job.Status.Active() { // mark the build as canceled logger.Info(ctx, "canceling workspace build", slog.F("build_id", build.ID), slog.F("workspace_id", r.workspaceID)) - if err = r.client.CancelWorkspaceBuild(ctx, build.ID); err == nil { + if err = r.client.CancelWorkspaceBuild(ctx, build.ID, codersdk.CancelWorkspaceBuildParams{}); err == nil { // Wait for the job to cancel before we delete it _ = waitForBuild(ctx, logs, r.client, build.ID) // it will return a "build canceled" error } else { diff --git a/scripts/build_go.sh b/scripts/build_go.sh index 97d9431beb544..b3b074b183f91 100755 --- a/scripts/build_go.sh +++ b/scripts/build_go.sh @@ -20,6 +20,9 @@ # binary will be signed using ./sign_darwin.sh. Read that file for more details # on the requirements. # +# If the --sign-gpg parameter is specified, the output binary will be signed using ./sign_with_gpg.sh. +# Read that file for more details on the requirements. +# # If the --agpl parameter is specified, builds only the AGPL-licensed code (no # Coder enterprise features). # @@ -41,6 +44,7 @@ slim="${CODER_SLIM_BUILD:-0}" agpl="${CODER_BUILD_AGPL:-0}" sign_darwin="${CODER_SIGN_DARWIN:-0}" sign_windows="${CODER_SIGN_WINDOWS:-0}" +sign_gpg="${CODER_SIGN_GPG:-0}" boringcrypto=${CODER_BUILD_BORINGCRYPTO:-0} dylib=0 windows_resources="${CODER_WINDOWS_RESOURCES:-0}" @@ -85,6 +89,10 @@ while true; do sign_windows=1 shift ;; + --sign-gpg) + sign_gpg=1 + shift + ;; --boringcrypto) boringcrypto=1 shift @@ -319,4 +327,9 @@ if [[ "$sign_windows" == 1 ]] && [[ "$os" == "windows" ]]; then execrelative ./sign_windows.sh "$output_path" 1>&2 fi +# Platform agnostic signing +if [[ "$sign_gpg" == 1 ]]; then + execrelative ./sign_with_gpg.sh "$output_path" 1>&2 +fi + echo "$output_path" diff --git a/scripts/dbgen/main.go b/scripts/dbgen/main.go index 7396a5140d605..561a46199a6ef 100644 --- a/scripts/dbgen/main.go +++ b/scripts/dbgen/main.go @@ -7,7 +7,6 @@ import ( "go/format" "go/token" "os" - "path" "path/filepath" "reflect" "runtime" @@ -52,14 +51,6 @@ func run() error { return err } databasePath := filepath.Join(localPath, "..", "..", "..", "coderd", "database") - - err = orderAndStubDatabaseFunctions(filepath.Join(databasePath, "dbmem", "dbmem.go"), "q", "FakeQuerier", func(_ stubParams) string { - return `panic("not implemented")` - }) - if err != nil { - return xerrors.Errorf("stub dbmem: %w", err) - } - err = orderAndStubDatabaseFunctions(filepath.Join(databasePath, "dbmetrics", "querymetrics.go"), "m", "queryMetricsStore", func(params stubParams) string { return fmt.Sprintf(` start := time.Now() @@ -257,13 +248,13 @@ func orderAndStubDatabaseFunctions(filePath, receiver, structName string, stub f contents, err := os.ReadFile(filePath) if err != nil { - return xerrors.Errorf("read dbmem: %w", err) + return xerrors.Errorf("read file: %w", err) } // Required to preserve imports! f, err := decorator.NewDecoratorWithImports(token.NewFileSet(), packageName, goast.New()).Parse(contents) if err != nil { - return xerrors.Errorf("parse dbmem: %w", err) + return xerrors.Errorf("parse file: %w", err) } pointer := false @@ -298,76 +289,6 @@ func orderAndStubDatabaseFunctions(filePath, receiver, structName string, stub f for _, fn := range funcs { var bodyStmts []dst.Stmt - // Add input validation, only relevant for dbmem. - if strings.Contains(filePath, "dbmem") && len(fn.Func.Params.List) == 2 && fn.Func.Params.List[1].Names[0].Name == "arg" { - /* - err := validateDatabaseType(arg) - if err != nil { - return database.User{}, err - } - */ - bodyStmts = append(bodyStmts, &dst.AssignStmt{ - Lhs: []dst.Expr{dst.NewIdent("err")}, - Tok: token.DEFINE, - Rhs: []dst.Expr{ - &dst.CallExpr{ - Fun: &dst.Ident{ - Name: "validateDatabaseType", - }, - Args: []dst.Expr{dst.NewIdent("arg")}, - }, - }, - }) - returnStmt := &dst.ReturnStmt{ - Results: []dst.Expr{}, // Filled below. - } - bodyStmts = append(bodyStmts, &dst.IfStmt{ - Cond: &dst.BinaryExpr{ - X: dst.NewIdent("err"), - Op: token.NEQ, - Y: dst.NewIdent("nil"), - }, - Body: &dst.BlockStmt{ - List: []dst.Stmt{ - returnStmt, - }, - }, - Decs: dst.IfStmtDecorations{ - NodeDecs: dst.NodeDecs{ - After: dst.EmptyLine, - }, - }, - }) - for _, r := range fn.Func.Results.List { - switch typ := r.Type.(type) { - case *dst.StarExpr, *dst.ArrayType, *dst.SelectorExpr: - returnStmt.Results = append(returnStmt.Results, dst.NewIdent("nil")) - case *dst.Ident: - if typ.Path != "" { - returnStmt.Results = append(returnStmt.Results, dst.NewIdent(fmt.Sprintf("%s.%s{}", path.Base(typ.Path), typ.Name))) - } else { - switch typ.Name { - case "uint8", "uint16", "uint32", "uint64", "uint", "uintptr", - "int8", "int16", "int32", "int64", "int", - "byte", "rune", - "float32", "float64", - "complex64", "complex128": - returnStmt.Results = append(returnStmt.Results, dst.NewIdent("0")) - case "string": - returnStmt.Results = append(returnStmt.Results, dst.NewIdent("\"\"")) - case "bool": - returnStmt.Results = append(returnStmt.Results, dst.NewIdent("false")) - case "error": - returnStmt.Results = append(returnStmt.Results, dst.NewIdent("err")) - default: - panic(fmt.Sprintf("unknown ident: %#v", r.Type)) - } - } - default: - panic(fmt.Sprintf("unknown return type: %T", r.Type)) - } - } - } decl, ok := declByName[fn.Name] if !ok { typeName := structName diff --git a/scripts/release/publish.sh b/scripts/release/publish.sh index df28d46ad2710..5ffd40aeb65cb 100755 --- a/scripts/release/publish.sh +++ b/scripts/release/publish.sh @@ -129,26 +129,9 @@ if [[ "$dry_run" == 0 ]] && [[ "${CODER_GPG_RELEASE_KEY_BASE64:-}" != "" ]]; the log "--- Signing checksums file" log - # Import the GPG key. - old_gnupg_home="${GNUPGHOME:-}" - gnupg_home_temp="$(mktemp -d)" - export GNUPGHOME="$gnupg_home_temp" - echo "$CODER_GPG_RELEASE_KEY_BASE64" | base64 -d | gpg --import 1>&2 - - # Sign the checksums file. This generates a file in the same directory and - # with the same name as the checksums file but ending in ".asc". - # - # We pipe `true` into `gpg` so that it never tries to be interactive (i.e. - # ask for a passphrase). The key we import above is not password protected. - true | gpg --detach-sign --armor "${temp_dir}/${checksum_file}" 1>&2 - - rm -rf "$gnupg_home_temp" - unset GNUPGHOME - if [[ "$old_gnupg_home" != "" ]]; then - export GNUPGHOME="$old_gnupg_home" - fi - + execrelative ../sign_with_gpg.sh "${temp_dir}/${checksum_file}" signed_checksum_path="${temp_dir}/${checksum_file}.asc" + if [[ ! -e "$signed_checksum_path" ]]; then log "Signed checksum file not found: ${signed_checksum_path}" log diff --git a/scripts/sign_with_gpg.sh b/scripts/sign_with_gpg.sh new file mode 100755 index 0000000000000..fb75df5ca1bb9 --- /dev/null +++ b/scripts/sign_with_gpg.sh @@ -0,0 +1,59 @@ +#!/usr/bin/env bash + +# This script signs a given binary using GPG. +# It expects the binary to be signed as the first argument. +# +# Usage: ./sign_with_gpg.sh path/to/binary +# +# On success, the input file will be signed using the GPG key and the signature output file will moved to /site/out/bin/ (happens in the Makefile) +# +# Depends on the GPG utility. Requires the following environment variables to be set: +# - $CODER_GPG_RELEASE_KEY_BASE64: The base64 encoded private key to use. + +set -euo pipefail +# shellcheck source=scripts/lib.sh +source "$(dirname "${BASH_SOURCE[0]}")/lib.sh" + +requiredenvs CODER_GPG_RELEASE_KEY_BASE64 + +FILE_TO_SIGN="$1" + +if [[ -z "$FILE_TO_SIGN" ]]; then + error "Usage: $0 " +fi + +if [[ ! -f "$FILE_TO_SIGN" ]]; then + error "File not found: $FILE_TO_SIGN" +fi + +# Import the GPG key. +old_gnupg_home="${GNUPGHOME:-}" +gnupg_home_temp="$(mktemp -d)" +export GNUPGHOME="$gnupg_home_temp" + +# Ensure GPG uses the temporary directory +echo "$CODER_GPG_RELEASE_KEY_BASE64" | base64 -d | gpg --homedir "$gnupg_home_temp" --import 1>&2 + +# Sign the binary. This generates a file in the same directory and +# with the same name as the binary but ending in ".asc". +# +# We pipe `true` into `gpg` so that it never tries to be interactive (i.e. +# ask for a passphrase). The key we import above is not password protected. +true | gpg --homedir "$gnupg_home_temp" --detach-sign --armor "$FILE_TO_SIGN" 1>&2 + +# Verify the signature and capture the exit status +gpg --homedir "$gnupg_home_temp" --verify "${FILE_TO_SIGN}.asc" "$FILE_TO_SIGN" 1>&2 +verification_result=$? + +# Clean up the temporary GPG home +rm -rf "$gnupg_home_temp" +unset GNUPGHOME +if [[ "$old_gnupg_home" != "" ]]; then + export GNUPGHOME="$old_gnupg_home" +fi + +if [[ $verification_result -eq 0 ]]; then + echo "${FILE_TO_SIGN}.asc" +else + error "Signature verification failed!" +fi diff --git a/site/e2e/playwright.config.ts b/site/e2e/playwright.config.ts index 436af99240493..4b3e5c5c86fc6 100644 --- a/site/e2e/playwright.config.ts +++ b/site/e2e/playwright.config.ts @@ -72,7 +72,7 @@ export default defineConfig({ "--global-config $(mktemp -d -t e2e-XXXXXXXXXX)", `--access-url=http://localhost:${coderPort}`, `--http-address=0.0.0.0:${coderPort}`, - "--in-memory", + "--ephemeral", "--telemetry=false", "--dangerous-disable-rate-limits", "--provisioner-daemons 10", diff --git a/site/package.json b/site/package.json index 1512a803b0a96..e3a99b9d8eebf 100644 --- a/site/package.json +++ b/site/package.json @@ -120,6 +120,7 @@ "undici": "6.21.2", "unique-names-generator": "4.7.1", "uuid": "9.0.1", + "websocket-ts": "2.2.1", "yup": "1.6.1" }, "devDependencies": { diff --git a/site/pnpm-lock.yaml b/site/pnpm-lock.yaml index 62cdc6176092a..3c7f5176b5b6b 100644 --- a/site/pnpm-lock.yaml +++ b/site/pnpm-lock.yaml @@ -274,6 +274,9 @@ importers: uuid: specifier: 9.0.1 version: 9.0.1 + websocket-ts: + specifier: 2.2.1 + version: 2.2.1 yup: specifier: 1.6.1 version: 1.6.1 @@ -6344,6 +6347,9 @@ packages: webpack-virtual-modules@0.5.0: resolution: {integrity: sha512-kyDivFZ7ZM0BVOUteVbDFhlRt7Ah/CSPwJdi8hBpkK7QLumUqdLtVfm/PX/hkcnrvr0i77fO5+TjZ94Pe+C9iw==, tarball: https://registry.npmjs.org/webpack-virtual-modules/-/webpack-virtual-modules-0.5.0.tgz} + websocket-ts@2.2.1: + resolution: {integrity: sha512-YKPDfxlK5qOheLZ2bTIiktZO1bpfGdNCPJmTEaPW7G9UXI1GKjDdeacOrsULUS000OPNxDVOyAuKLuIWPqWM0Q==, tarball: https://registry.npmjs.org/websocket-ts/-/websocket-ts-2.2.1.tgz} + whatwg-encoding@2.0.0: resolution: {integrity: sha512-p41ogyeMUrw3jWclHWTQg1k05DSVXPLcVxRTYsXUk+ZooOCZLcoYgPZ/HL/D/N+uQPOtcp1me1WhBEaX02mhWg==, tarball: https://registry.npmjs.org/whatwg-encoding/-/whatwg-encoding-2.0.0.tgz} engines: {node: '>=12'} @@ -13266,6 +13272,8 @@ snapshots: webpack-virtual-modules@0.5.0: {} + websocket-ts@2.2.1: {} + whatwg-encoding@2.0.0: dependencies: iconv-lite: 0.6.3 diff --git a/site/site_test.go b/site/site_test.go index f7301debba2be..fa3c0809f22a7 100644 --- a/site/site_test.go +++ b/site/site_test.go @@ -27,7 +27,6 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/db2sdk" "github.com/coder/coder/v2/coderd/database/dbgen" - "github.com/coder/coder/v2/coderd/database/dbmem" "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/httpmw" @@ -46,7 +45,7 @@ func TestInjection(t *testing.T) { }, } binFs := http.FS(fstest.MapFS{}) - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) handler := site.New(&site.Options{ Telemetry: telemetry.NewNoop(), BinFS: binFs, @@ -73,13 +72,17 @@ func TestInjection(t *testing.T) { // This will update as part of the request! got.LastSeenAt = user.LastSeenAt + // json.Unmarshal doesn't parse the timezone correctly + got.CreatedAt = got.CreatedAt.In(user.CreatedAt.Location()) + got.UpdatedAt = got.UpdatedAt.In(user.CreatedAt.Location()) + require.Equal(t, db2sdk.User(user, []uuid.UUID{}), got) } func TestInjectionFailureProducesCleanHTML(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) // Create an expired user with a refresh token, but provide no OAuth2 // configuration so that refresh is impossible, this should result in diff --git a/site/src/api/api.ts b/site/src/api/api.ts index 2b13c77faffa1..7c10188648121 100644 --- a/site/src/api/api.ts +++ b/site/src/api/api.ts @@ -129,6 +129,14 @@ export const watchWorkspace = ( }); }; +export const watchAgentContainers = ( + agentId: string, +): OneWayWebSocket => { + return new OneWayWebSocket({ + apiRoute: `/api/v2/workspaceagents/${agentId}/containers/watch`, + }); +}; + type WatchInboxNotificationsParams = Readonly<{ read_status?: "read" | "unread" | "all"; }>; @@ -1277,9 +1285,12 @@ class ApiMethods { cancelWorkspaceBuild = async ( workspaceBuildId: TypesGen.WorkspaceBuild["id"], + params?: TypesGen.CancelWorkspaceBuildParams, ): Promise => { const response = await this.axios.patch( `/api/v2/workspacebuilds/${workspaceBuildId}/cancel`, + null, + { params }, ); return response.data; diff --git a/site/src/api/queries/workspaces.ts b/site/src/api/queries/workspaces.ts index 5a4cdb46dd4e9..05fb09314d741 100644 --- a/site/src/api/queries/workspaces.ts +++ b/site/src/api/queries/workspaces.ts @@ -266,7 +266,12 @@ export const startWorkspace = ( export const cancelBuild = (workspace: Workspace, queryClient: QueryClient) => { return { mutationFn: () => { - return API.cancelWorkspaceBuild(workspace.latest_build.id); + const { status } = workspace.latest_build; + const params = + status === "pending" || status === "running" + ? { expect_status: status } + : undefined; + return API.cancelWorkspaceBuild(workspace.latest_build.id, params); }, onSuccess: async () => { await queryClient.invalidateQueries({ diff --git a/site/src/api/rbacresourcesGenerated.ts b/site/src/api/rbacresourcesGenerated.ts index de09b245ff049..5d632d57fad95 100644 --- a/site/src/api/rbacresourcesGenerated.ts +++ b/site/src/api/rbacresourcesGenerated.ts @@ -31,6 +31,10 @@ export const RBACResourceActions: Partial< create: "create new audit log entries", read: "read audit logs", }, + connection_log: { + read: "read connection logs", + update: "upsert connection log entries", + }, crypto_key: { create: "create crypto keys", delete: "delete crypto keys", diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index 4ab5403081a60..23a739df063de 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -275,11 +275,12 @@ export interface BuildInfoResponse { } // From codersdk/workspacebuilds.go -export type BuildReason = "autostart" | "autostop" | "initiator"; +export type BuildReason = "autostart" | "autostop" | "dormancy" | "initiator"; export const BuildReasons: BuildReason[] = [ "autostart", "autostop", + "dormancy", "initiator", ]; @@ -292,6 +293,19 @@ export const BypassRatelimitHeader = "X-Coder-Bypass-Ratelimit"; // From codersdk/client.go export const CLITelemetryHeader = "Coder-CLI-Telemetry"; +// From codersdk/workspacebuilds.go +export interface CancelWorkspaceBuildParams { + readonly expect_status?: CancelWorkspaceBuildStatus; +} + +// From codersdk/workspacebuilds.go +export type CancelWorkspaceBuildStatus = "pending" | "running"; + +export const CancelWorkspaceBuildStatuses: CancelWorkspaceBuildStatus[] = [ + "pending", + "running", +]; + // From codersdk/users.go export interface ChangePasswordWithOneTimePasscodeRequest { readonly email: string; @@ -654,7 +668,6 @@ export interface DeploymentValues { readonly proxy_trusted_headers?: string; readonly proxy_trusted_origins?: string; readonly cache_directory?: string; - readonly in_memory_database?: boolean; readonly ephemeral_deployment?: boolean; readonly pg_connection_url?: string; readonly pg_auth?: string; @@ -907,6 +920,7 @@ export type FeatureName = | "appearance" | "audit_log" | "browser_only" + | "connection_log" | "control_shared_ports" | "custom_roles" | "external_provisioner_daemons" @@ -928,6 +942,7 @@ export const FeatureNames: FeatureName[] = [ "appearance", "audit_log", "browser_only", + "connection_log", "control_shared_ports", "custom_roles", "external_provisioner_daemons", @@ -2228,6 +2243,7 @@ export type RBACResource = | "assign_org_role" | "assign_role" | "audit_log" + | "connection_log" | "crypto_key" | "debug_info" | "deployment_config" @@ -2267,6 +2283,7 @@ export const RBACResources: RBACResource[] = [ "assign_org_role", "assign_role", "audit_log", + "connection_log", "crypto_key", "debug_info", "deployment_config", diff --git a/site/src/hooks/index.ts b/site/src/hooks/index.ts index 4453e36fa4bb4..901fee8a50ded 100644 --- a/site/src/hooks/index.ts +++ b/site/src/hooks/index.ts @@ -3,4 +3,3 @@ export * from "./useClickable"; export * from "./useClickableTableRow"; export * from "./useClipboard"; export * from "./usePagination"; -export * from "./useWithRetry"; diff --git a/site/src/hooks/useWithRetry.test.ts b/site/src/hooks/useWithRetry.test.ts deleted file mode 100644 index 7ed7b4331f21e..0000000000000 --- a/site/src/hooks/useWithRetry.test.ts +++ /dev/null @@ -1,329 +0,0 @@ -import { act, renderHook } from "@testing-library/react"; -import { useWithRetry } from "./useWithRetry"; - -// Mock timers -jest.useFakeTimers(); - -describe("useWithRetry", () => { - let mockFn: jest.Mock; - - beforeEach(() => { - mockFn = jest.fn(); - jest.clearAllTimers(); - }); - - afterEach(() => { - jest.clearAllMocks(); - }); - - it("should initialize with correct default state", () => { - const { result } = renderHook(() => useWithRetry(mockFn)); - - expect(result.current.isLoading).toBe(false); - expect(result.current.nextRetryAt).toBe(undefined); - }); - - it("should execute function successfully on first attempt", async () => { - mockFn.mockResolvedValue(undefined); - - const { result } = renderHook(() => useWithRetry(mockFn)); - - await act(async () => { - await result.current.call(); - }); - - expect(mockFn).toHaveBeenCalledTimes(1); - expect(result.current.isLoading).toBe(false); - expect(result.current.nextRetryAt).toBe(undefined); - }); - - it("should set isLoading to true during execution", async () => { - let resolvePromise: () => void; - const promise = new Promise((resolve) => { - resolvePromise = resolve; - }); - mockFn.mockReturnValue(promise); - - const { result } = renderHook(() => useWithRetry(mockFn)); - - act(() => { - result.current.call(); - }); - - expect(result.current.isLoading).toBe(true); - - await act(async () => { - resolvePromise!(); - await promise; - }); - - expect(result.current.isLoading).toBe(false); - }); - - it("should retry on failure with exponential backoff", async () => { - mockFn - .mockRejectedValueOnce(new Error("First failure")) - .mockRejectedValueOnce(new Error("Second failure")) - .mockResolvedValueOnce(undefined); - - const { result } = renderHook(() => useWithRetry(mockFn)); - - // Start the call - await act(async () => { - await result.current.call(); - }); - - expect(mockFn).toHaveBeenCalledTimes(1); - expect(result.current.isLoading).toBe(false); - expect(result.current.nextRetryAt).not.toBe(null); - - // Fast-forward to first retry (1 second) - await act(async () => { - jest.advanceTimersByTime(1000); - }); - - expect(mockFn).toHaveBeenCalledTimes(2); - expect(result.current.isLoading).toBe(false); - expect(result.current.nextRetryAt).not.toBe(null); - - // Fast-forward to second retry (2 seconds) - await act(async () => { - jest.advanceTimersByTime(2000); - }); - - expect(mockFn).toHaveBeenCalledTimes(3); - expect(result.current.isLoading).toBe(false); - expect(result.current.nextRetryAt).toBe(undefined); - }); - - it("should continue retrying without limit", async () => { - mockFn.mockRejectedValue(new Error("Always fails")); - - const { result } = renderHook(() => useWithRetry(mockFn)); - - // Start the call - await act(async () => { - await result.current.call(); - }); - - expect(mockFn).toHaveBeenCalledTimes(1); - expect(result.current.isLoading).toBe(false); - expect(result.current.nextRetryAt).not.toBe(null); - - // Fast-forward through multiple retries to verify it continues - for (let i = 1; i < 15; i++) { - const delay = Math.min(1000 * 2 ** (i - 1), 600000); // exponential backoff with max delay - await act(async () => { - jest.advanceTimersByTime(delay); - }); - expect(mockFn).toHaveBeenCalledTimes(i + 1); - expect(result.current.isLoading).toBe(false); - expect(result.current.nextRetryAt).not.toBe(null); - } - - // Should still be retrying after 15 attempts - expect(result.current.nextRetryAt).not.toBe(null); - }); - - it("should respect max delay of 10 minutes", async () => { - mockFn.mockRejectedValue(new Error("Always fails")); - - const { result } = renderHook(() => useWithRetry(mockFn)); - - // Start the call - await act(async () => { - await result.current.call(); - }); - - expect(result.current.isLoading).toBe(false); - - // Fast-forward through several retries to reach max delay - // After attempt 9, delay would be 1000 * 2^9 = 512000ms, which is less than 600000ms (10 min) - // After attempt 10, delay would be 1000 * 2^10 = 1024000ms, which should be capped at 600000ms - - // Skip to attempt 9 (delay calculation: 1000 * 2^8 = 256000ms) - for (let i = 1; i < 9; i++) { - const delay = 1000 * 2 ** (i - 1); - await act(async () => { - jest.advanceTimersByTime(delay); - }); - } - - expect(mockFn).toHaveBeenCalledTimes(9); - expect(result.current.nextRetryAt).not.toBe(null); - - // The 9th retry should use max delay (600000ms = 10 minutes) - await act(async () => { - jest.advanceTimersByTime(600000); - }); - - expect(mockFn).toHaveBeenCalledTimes(10); - expect(result.current.isLoading).toBe(false); - expect(result.current.nextRetryAt).not.toBe(null); - - // Continue with more retries at max delay to verify it continues indefinitely - await act(async () => { - jest.advanceTimersByTime(600000); - }); - - expect(mockFn).toHaveBeenCalledTimes(11); - expect(result.current.nextRetryAt).not.toBe(null); - }); - - it("should cancel previous retry when call is invoked again", async () => { - mockFn - .mockRejectedValueOnce(new Error("First failure")) - .mockResolvedValueOnce(undefined); - - const { result } = renderHook(() => useWithRetry(mockFn)); - - // Start the first call - await act(async () => { - await result.current.call(); - }); - - expect(mockFn).toHaveBeenCalledTimes(1); - expect(result.current.isLoading).toBe(false); - expect(result.current.nextRetryAt).not.toBe(null); - - // Call again before retry happens - await act(async () => { - await result.current.call(); - }); - - expect(mockFn).toHaveBeenCalledTimes(2); - expect(result.current.isLoading).toBe(false); - expect(result.current.nextRetryAt).toBe(undefined); - - // Advance time to ensure previous retry was cancelled - await act(async () => { - jest.advanceTimersByTime(5000); - }); - - expect(mockFn).toHaveBeenCalledTimes(2); // Should not have been called again - }); - - it("should set nextRetryAt when scheduling retry", async () => { - mockFn - .mockRejectedValueOnce(new Error("Failure")) - .mockResolvedValueOnce(undefined); - - const { result } = renderHook(() => useWithRetry(mockFn)); - - // Start the call - await act(async () => { - await result.current.call(); - }); - - const nextRetryAt = result.current.nextRetryAt; - expect(nextRetryAt).not.toBe(null); - expect(nextRetryAt).toBeInstanceOf(Date); - - // nextRetryAt should be approximately 1 second in the future - const expectedTime = Date.now() + 1000; - const actualTime = nextRetryAt!.getTime(); - expect(Math.abs(actualTime - expectedTime)).toBeLessThan(100); // Allow 100ms tolerance - - // Advance past retry time - await act(async () => { - jest.advanceTimersByTime(1000); - }); - - expect(result.current.nextRetryAt).toBe(undefined); - }); - - it("should cleanup timer on unmount", async () => { - mockFn.mockRejectedValue(new Error("Failure")); - - const { result, unmount } = renderHook(() => useWithRetry(mockFn)); - - // Start the call to create timer - await act(async () => { - await result.current.call(); - }); - - expect(result.current.isLoading).toBe(false); - expect(result.current.nextRetryAt).not.toBe(null); - - // Unmount should cleanup timer - unmount(); - - // Advance time to ensure timer was cleared - await act(async () => { - jest.advanceTimersByTime(5000); - }); - - // Function should not have been called again - expect(mockFn).toHaveBeenCalledTimes(1); - }); - - it("should prevent scheduling retries when function completes after unmount", async () => { - let rejectPromise: (error: Error) => void; - const promise = new Promise((_, reject) => { - rejectPromise = reject; - }); - mockFn.mockReturnValue(promise); - - const { result, unmount } = renderHook(() => useWithRetry(mockFn)); - - // Start the call - this will make the function in-flight - act(() => { - result.current.call(); - }); - - expect(result.current.isLoading).toBe(true); - - // Unmount while function is still in-flight - unmount(); - - // Function completes with error after unmount - await act(async () => { - rejectPromise!(new Error("Failed after unmount")); - await promise.catch(() => {}); // Suppress unhandled rejection - }); - - // Advance time to ensure no retry timers were scheduled - await act(async () => { - jest.advanceTimersByTime(5000); - }); - - // Function should only have been called once (no retries after unmount) - expect(mockFn).toHaveBeenCalledTimes(1); - }); - - it("should do nothing when call() is invoked while function is already loading", async () => { - let resolvePromise: () => void; - const promise = new Promise((resolve) => { - resolvePromise = resolve; - }); - mockFn.mockReturnValue(promise); - - const { result } = renderHook(() => useWithRetry(mockFn)); - - // Start the first call - this will set isLoading to true - act(() => { - result.current.call(); - }); - - expect(result.current.isLoading).toBe(true); - expect(mockFn).toHaveBeenCalledTimes(1); - - // Try to call again while loading - should do nothing - act(() => { - result.current.call(); - }); - - // Function should not have been called again - expect(mockFn).toHaveBeenCalledTimes(1); - expect(result.current.isLoading).toBe(true); - - // Complete the original promise - await act(async () => { - resolvePromise!(); - await promise; - }); - - expect(result.current.isLoading).toBe(false); - expect(mockFn).toHaveBeenCalledTimes(1); - }); -}); diff --git a/site/src/hooks/useWithRetry.ts b/site/src/hooks/useWithRetry.ts deleted file mode 100644 index 1310da221efc5..0000000000000 --- a/site/src/hooks/useWithRetry.ts +++ /dev/null @@ -1,106 +0,0 @@ -import { useCallback, useEffect, useRef, useState } from "react"; -import { useEffectEvent } from "./hookPolyfills"; - -const DELAY_MS = 1_000; -const MAX_DELAY_MS = 600_000; // 10 minutes -// Determines how much the delay between retry attempts increases after each -// failure. -const MULTIPLIER = 2; - -interface UseWithRetryResult { - call: () => void; - nextRetryAt: Date | undefined; - isLoading: boolean; -} - -interface RetryState { - isLoading: boolean; - nextRetryAt: Date | undefined; -} - -/** - * Hook that wraps a function with automatic retry functionality - * Provides a simple interface for executing functions with exponential backoff retry - */ -export function useWithRetry(fn: () => Promise): UseWithRetryResult { - const [state, setState] = useState({ - isLoading: false, - nextRetryAt: undefined, - }); - - const timeoutRef = useRef(null); - const mountedRef = useRef(true); - - const clearTimeout = useCallback(() => { - if (timeoutRef.current) { - window.clearTimeout(timeoutRef.current); - timeoutRef.current = null; - } - }, []); - - const stableFn = useEffectEvent(fn); - - const call = useCallback(() => { - if (state.isLoading) { - return; - } - - clearTimeout(); - - const executeAttempt = async (attempt = 0): Promise => { - if (!mountedRef.current) { - return; - } - setState({ - isLoading: true, - nextRetryAt: undefined, - }); - - try { - await stableFn(); - if (mountedRef.current) { - setState({ isLoading: false, nextRetryAt: undefined }); - } - } catch (error) { - if (!mountedRef.current) { - return; - } - const delayMs = Math.min( - DELAY_MS * MULTIPLIER ** attempt, - MAX_DELAY_MS, - ); - - setState({ - isLoading: false, - nextRetryAt: new Date(Date.now() + delayMs), - }); - - timeoutRef.current = window.setTimeout(() => { - if (!mountedRef.current) { - return; - } - setState({ - isLoading: false, - nextRetryAt: undefined, - }); - executeAttempt(attempt + 1); - }, delayMs); - } - }; - - executeAttempt(); - }, [state.isLoading, stableFn, clearTimeout]); - - useEffect(() => { - return () => { - mountedRef.current = false; - clearTimeout(); - }; - }, [clearTimeout]); - - return { - call, - nextRetryAt: state.nextRetryAt, - isLoading: state.isLoading, - }; -} diff --git a/site/src/modules/management/DeploymentSidebarView.stories.tsx b/site/src/modules/management/DeploymentSidebarView.stories.tsx index d7fee99bc2ade..2465556110e98 100644 --- a/site/src/modules/management/DeploymentSidebarView.stories.tsx +++ b/site/src/modules/management/DeploymentSidebarView.stories.tsx @@ -1,5 +1,9 @@ import type { Meta, StoryObj } from "@storybook/react"; -import { MockNoPermissions, MockPermissions } from "testHelpers/entities"; +import { + MockBuildInfo, + MockNoPermissions, + MockPermissions, +} from "testHelpers/entities"; import { withDashboardProvider } from "testHelpers/storybook"; import { DeploymentSidebarView } from "./DeploymentSidebarView"; @@ -10,6 +14,8 @@ const meta: Meta = { parameters: { showOrganizations: true }, args: { permissions: MockPermissions, + experiments: [], + buildInfo: MockBuildInfo, }, }; diff --git a/site/src/modules/resources/AgentDevcontainerCard.tsx b/site/src/modules/resources/AgentDevcontainerCard.tsx index c7516dde15c39..bd2f05b123cad 100644 --- a/site/src/modules/resources/AgentDevcontainerCard.tsx +++ b/site/src/modules/resources/AgentDevcontainerCard.tsx @@ -130,12 +130,6 @@ export const AgentDevcontainerCard: FC = ({ return { previousData }; }, - onSuccess: async () => { - // Invalidate the containers query to refetch updated data. - await queryClient.invalidateQueries({ - queryKey: ["agents", parentAgent.id, "containers"], - }); - }, onError: (error, _, context) => { // If the mutation fails, use the context returned from // onMutate to roll back. diff --git a/site/src/modules/resources/AgentRow.tsx b/site/src/modules/resources/AgentRow.tsx index 3d0888f7872b1..0b5d8a5dc15c3 100644 --- a/site/src/modules/resources/AgentRow.tsx +++ b/site/src/modules/resources/AgentRow.tsx @@ -2,14 +2,12 @@ import type { Interpolation, Theme } from "@emotion/react"; import Collapse from "@mui/material/Collapse"; import Divider from "@mui/material/Divider"; import Skeleton from "@mui/material/Skeleton"; -import { API } from "api/api"; import type { Template, Workspace, WorkspaceAgent, WorkspaceAgentMetadata, } from "api/typesGenerated"; -import { isAxiosError } from "axios"; import { Button } from "components/Button/Button"; import { DropdownArrow } from "components/DropdownArrow/DropdownArrow"; import { Stack } from "components/Stack/Stack"; @@ -25,7 +23,6 @@ import { useRef, useState, } from "react"; -import { useQuery } from "react-query"; import AutoSizer from "react-virtualized-auto-sizer"; import type { FixedSizeList as List, ListOnScrollProps } from "react-window"; import { AgentApps, organizeAgentApps } from "./AgentApps/AgentApps"; @@ -41,6 +38,7 @@ import { PortForwardButton } from "./PortForwardButton"; import { AgentSSHButton } from "./SSHButton/SSHButton"; import { TerminalLink } from "./TerminalLink/TerminalLink"; import { VSCodeDesktopButton } from "./VSCodeDesktopButton/VSCodeDesktopButton"; +import { useAgentContainers } from "./useAgentContainers"; import { useAgentLogs } from "./useAgentLogs"; interface AgentRowProps { @@ -133,20 +131,7 @@ export const AgentRow: FC = ({ setBottomOfLogs(distanceFromBottom < AGENT_LOG_LINE_HEIGHT); }, []); - const { data: devcontainers } = useQuery({ - queryKey: ["agents", agent.id, "containers"], - queryFn: () => API.getAgentContainers(agent.id), - enabled: agent.status === "connected", - select: (res) => res.devcontainers, - // TODO: Implement a websocket connection to get updates on containers - // without having to poll. - refetchInterval: ({ state }) => { - const { error } = state; - return isAxiosError(error) && error.response?.status === 403 - ? false - : 10_000; - }, - }); + const devcontainers = useAgentContainers(agent); // This is used to show the parent apps of the devcontainer. const [showParentApps, setShowParentApps] = useState(false); diff --git a/site/src/modules/resources/useAgentContainers.test.tsx b/site/src/modules/resources/useAgentContainers.test.tsx new file mode 100644 index 0000000000000..922941e04c074 --- /dev/null +++ b/site/src/modules/resources/useAgentContainers.test.tsx @@ -0,0 +1,196 @@ +import { renderHook, waitFor } from "@testing-library/react"; +import * as API from "api/api"; +import type { WorkspaceAgentListContainersResponse } from "api/typesGenerated"; +import * as GlobalSnackbar from "components/GlobalSnackbar/utils"; +import { http, HttpResponse } from "msw"; +import type { FC, PropsWithChildren } from "react"; +import { QueryClient, QueryClientProvider } from "react-query"; +import { + MockWorkspaceAgent, + MockWorkspaceAgentDevcontainer, +} from "testHelpers/entities"; +import { server } from "testHelpers/server"; +import type { OneWayWebSocket } from "utils/OneWayWebSocket"; +import { useAgentContainers } from "./useAgentContainers"; + +const createWrapper = (): FC => { + const queryClient = new QueryClient({ + defaultOptions: { + queries: { + retry: false, + }, + }, + }); + return ({ children }) => ( + {children} + ); +}; + +describe("useAgentContainers", () => { + it("returns containers when agent is connected", async () => { + server.use( + http.get( + `/api/v2/workspaceagents/${MockWorkspaceAgent.id}/containers`, + () => { + return HttpResponse.json({ + devcontainers: [MockWorkspaceAgentDevcontainer], + containers: [], + }); + }, + ), + ); + + const { result } = renderHook( + () => useAgentContainers(MockWorkspaceAgent), + { + wrapper: createWrapper(), + }, + ); + + await waitFor(() => { + expect(result.current).toEqual([MockWorkspaceAgentDevcontainer]); + }); + }); + + it("returns undefined when agent is not connected", () => { + const disconnectedAgent = { + ...MockWorkspaceAgent, + status: "disconnected" as const, + }; + + const { result } = renderHook(() => useAgentContainers(disconnectedAgent), { + wrapper: createWrapper(), + }); + + expect(result.current).toBeUndefined(); + }); + + it("handles API errors gracefully", async () => { + server.use( + http.get( + `/api/v2/workspaceagents/${MockWorkspaceAgent.id}/containers`, + () => { + return HttpResponse.error(); + }, + ), + ); + + const { result } = renderHook( + () => useAgentContainers(MockWorkspaceAgent), + { + wrapper: createWrapper(), + }, + ); + + await waitFor(() => { + expect(result.current).toBeUndefined(); + }); + }); + + it("handles parsing errors from WebSocket", async () => { + const displayErrorSpy = jest.spyOn(GlobalSnackbar, "displayError"); + const watchAgentContainersSpy = jest.spyOn(API, "watchAgentContainers"); + + const mockSocket = { + addEventListener: jest.fn(), + close: jest.fn(), + }; + watchAgentContainersSpy.mockReturnValue( + mockSocket as unknown as OneWayWebSocket, + ); + + server.use( + http.get( + `/api/v2/workspaceagents/${MockWorkspaceAgent.id}/containers`, + () => { + return HttpResponse.json({ + devcontainers: [MockWorkspaceAgentDevcontainer], + containers: [], + }); + }, + ), + ); + + const { unmount } = renderHook( + () => useAgentContainers(MockWorkspaceAgent), + { + wrapper: createWrapper(), + }, + ); + + // Simulate message event with parsing error + const messageHandler = mockSocket.addEventListener.mock.calls.find( + (call) => call[0] === "message", + )?.[1]; + + if (messageHandler) { + messageHandler({ + parseError: new Error("Parse error"), + parsedMessage: null, + }); + } + + await waitFor(() => { + expect(displayErrorSpy).toHaveBeenCalledWith( + "Failed to update containers", + "Please try refreshing the page", + ); + }); + + unmount(); + displayErrorSpy.mockRestore(); + watchAgentContainersSpy.mockRestore(); + }); + + it("handles WebSocket errors", async () => { + const displayErrorSpy = jest.spyOn(GlobalSnackbar, "displayError"); + const watchAgentContainersSpy = jest.spyOn(API, "watchAgentContainers"); + + const mockSocket = { + addEventListener: jest.fn(), + close: jest.fn(), + }; + watchAgentContainersSpy.mockReturnValue( + mockSocket as unknown as OneWayWebSocket, + ); + + server.use( + http.get( + `/api/v2/workspaceagents/${MockWorkspaceAgent.id}/containers`, + () => { + return HttpResponse.json({ + devcontainers: [MockWorkspaceAgentDevcontainer], + containers: [], + }); + }, + ), + ); + + const { unmount } = renderHook( + () => useAgentContainers(MockWorkspaceAgent), + { + wrapper: createWrapper(), + }, + ); + + // Simulate error event + const errorHandler = mockSocket.addEventListener.mock.calls.find( + (call) => call[0] === "error", + )?.[1]; + + if (errorHandler) { + errorHandler(new Error("WebSocket error")); + } + + await waitFor(() => { + expect(displayErrorSpy).toHaveBeenCalledWith( + "Failed to load containers", + "Please try refreshing the page", + ); + }); + + unmount(); + displayErrorSpy.mockRestore(); + watchAgentContainersSpy.mockRestore(); + }); +}); diff --git a/site/src/modules/resources/useAgentContainers.ts b/site/src/modules/resources/useAgentContainers.ts new file mode 100644 index 0000000000000..0db4e2fc4b613 --- /dev/null +++ b/site/src/modules/resources/useAgentContainers.ts @@ -0,0 +1,59 @@ +import { API, watchAgentContainers } from "api/api"; +import type { + WorkspaceAgent, + WorkspaceAgentDevcontainer, + WorkspaceAgentListContainersResponse, +} from "api/typesGenerated"; +import { displayError } from "components/GlobalSnackbar/utils"; +import { useEffectEvent } from "hooks/hookPolyfills"; +import { useEffect } from "react"; +import { useQuery, useQueryClient } from "react-query"; + +export function useAgentContainers( + agent: WorkspaceAgent, +): readonly WorkspaceAgentDevcontainer[] | undefined { + const queryClient = useQueryClient(); + + const { data: devcontainers } = useQuery({ + queryKey: ["agents", agent.id, "containers"], + queryFn: () => API.getAgentContainers(agent.id), + enabled: agent.status === "connected", + select: (res) => res.devcontainers, + staleTime: Number.POSITIVE_INFINITY, + }); + + const updateDevcontainersCache = useEffectEvent( + async (data: WorkspaceAgentListContainersResponse) => { + const queryKey = ["agents", agent.id, "containers"]; + + queryClient.setQueryData(queryKey, data); + }, + ); + + useEffect(() => { + const socket = watchAgentContainers(agent.id); + + socket.addEventListener("message", (event) => { + if (event.parseError) { + displayError( + "Failed to update containers", + "Please try refreshing the page", + ); + return; + } + + updateDevcontainersCache(event.parsedMessage); + }); + + socket.addEventListener("error", () => { + displayError( + "Failed to load containers", + "Please try refreshing the page", + ); + }); + + return () => socket.close(); + }, [agent.id, updateDevcontainersCache]); + + return devcontainers; +} diff --git a/site/src/modules/workspaces/WorkspaceBuildCancelDialog/WorkspaceBuildCancelDialog.tsx b/site/src/modules/workspaces/WorkspaceBuildCancelDialog/WorkspaceBuildCancelDialog.tsx new file mode 100644 index 0000000000000..cbd7eb7d0c1ed --- /dev/null +++ b/site/src/modules/workspaces/WorkspaceBuildCancelDialog/WorkspaceBuildCancelDialog.tsx @@ -0,0 +1,32 @@ +import type { Workspace } from "api/typesGenerated"; +import { ConfirmDialog } from "components/Dialogs/ConfirmDialog/ConfirmDialog"; +import type { FC } from "react"; + +interface WorkspaceBuildCancelDialogProps { + open: boolean; + onClose: () => void; + onConfirm: () => void; + workspace: Workspace; +} + +export const WorkspaceBuildCancelDialog: FC< + WorkspaceBuildCancelDialogProps +> = ({ open, onClose, onConfirm, workspace }) => { + const action = + workspace.latest_build.status === "pending" + ? "remove the current build from the build queue" + : "stop the current build process"; + + return ( + + ); +}; diff --git a/site/src/modules/workspaces/actions.ts b/site/src/modules/workspaces/actions.ts index f109c4d9ad1b9..8b17d3e937c74 100644 --- a/site/src/modules/workspaces/actions.ts +++ b/site/src/modules/workspaces/actions.ts @@ -145,7 +145,7 @@ export const abilitiesByWorkspaceStatus = ( case "pending": { return { actions: ["pending"], - canCancel: false, + canCancel: true, canAcceptJobs: false, }; } diff --git a/site/src/pages/TerminalPage/TerminalAlerts.tsx b/site/src/pages/TerminalPage/TerminalAlerts.tsx index 07740135769f3..6a06a76964128 100644 --- a/site/src/pages/TerminalPage/TerminalAlerts.tsx +++ b/site/src/pages/TerminalPage/TerminalAlerts.tsx @@ -170,14 +170,16 @@ const TerminalAlert: FC = (props) => { ); }; +// Since the terminal connection is always trying to reconnect, we show this +// alert to indicate that the terminal is trying to connect. const DisconnectedAlert: FC = (props) => { return ( } > - Disconnected + Trying to connect... ); }; diff --git a/site/src/pages/TerminalPage/TerminalPage.test.tsx b/site/src/pages/TerminalPage/TerminalPage.test.tsx index 7600fa5257d43..4591190ad9904 100644 --- a/site/src/pages/TerminalPage/TerminalPage.test.tsx +++ b/site/src/pages/TerminalPage/TerminalPage.test.tsx @@ -85,7 +85,7 @@ describe("TerminalPage", () => { await expectTerminalText(container, Language.workspaceErrorMessagePrefix); }); - it("shows an error if the websocket fails", async () => { + it("shows reconnect message when websocket fails", async () => { server.use( http.get("/api/v2/workspaceagents/:agentId/pty", () => { return HttpResponse.json({}, { status: 500 }); @@ -94,7 +94,9 @@ describe("TerminalPage", () => { const { container } = await renderTerminal(); - await expectTerminalText(container, Language.websocketErrorMessagePrefix); + await waitFor(() => { + expect(container.textContent).toContain("Trying to connect..."); + }); }); it("renders data from the backend", async () => { diff --git a/site/src/pages/TerminalPage/TerminalPage.tsx b/site/src/pages/TerminalPage/TerminalPage.tsx index 2023bdb0eeb29..5c13e89c30005 100644 --- a/site/src/pages/TerminalPage/TerminalPage.tsx +++ b/site/src/pages/TerminalPage/TerminalPage.tsx @@ -26,6 +26,13 @@ import { openMaybePortForwardedURL } from "utils/portForward"; import { terminalWebsocketUrl } from "utils/terminal"; import { getMatchingAgentOrFirst } from "utils/workspace"; import { v4 as uuidv4 } from "uuid"; +// Use websocket-ts for better WebSocket handling and auto-reconnection. +import { + ExponentialBackoff, + type Websocket, + WebsocketBuilder, + WebsocketEvent, +} from "websocket-ts"; import { TerminalAlerts } from "./TerminalAlerts"; import type { ConnectionStatus } from "./types"; @@ -221,7 +228,7 @@ const TerminalPage: FC = () => { } // Hook up terminal events to the websocket. - let websocket: WebSocket | null; + let websocket: Websocket | null; const disposers = [ terminal.onData((data) => { websocket?.send( @@ -259,9 +266,11 @@ const TerminalPage: FC = () => { if (disposed) { return; // Unmounted while we waited for the async call. } - websocket = new WebSocket(url); + websocket = new WebsocketBuilder(url) + .withBackoff(new ExponentialBackoff(1000, 6)) + .build(); websocket.binaryType = "arraybuffer"; - websocket.addEventListener("open", () => { + websocket.addEventListener(WebsocketEvent.open, () => { // Now that we are connected, allow user input. terminal.options = { disableStdin: false, @@ -278,18 +287,16 @@ const TerminalPage: FC = () => { ); setConnectionStatus("connected"); }); - websocket.addEventListener("error", () => { + websocket.addEventListener(WebsocketEvent.error, (_, event) => { + console.error("WebSocket error:", event); terminal.options.disableStdin = true; - terminal.writeln( - `${Language.websocketErrorMessagePrefix}socket errored`, - ); setConnectionStatus("disconnected"); }); - websocket.addEventListener("close", () => { + websocket.addEventListener(WebsocketEvent.close, () => { terminal.options.disableStdin = true; setConnectionStatus("disconnected"); }); - websocket.addEventListener("message", (event) => { + websocket.addEventListener(WebsocketEvent.message, (_, event) => { if (typeof event.data === "string") { // This exclusively occurs when testing. // "jest-websocket-mock" doesn't support ArrayBuffer. @@ -298,12 +305,25 @@ const TerminalPage: FC = () => { terminal.write(new Uint8Array(event.data)); } }); + websocket.addEventListener(WebsocketEvent.reconnect, () => { + if (websocket) { + websocket.binaryType = "arraybuffer"; + websocket.send( + new TextEncoder().encode( + JSON.stringify({ + height: terminal.rows, + width: terminal.cols, + }), + ), + ); + } + }); }) .catch((error) => { if (disposed) { return; // Unmounted while we waited for the async call. } - terminal.writeln(Language.websocketErrorMessagePrefix + error.message); + console.error("WebSocket connection failed:", error); setConnectionStatus("disconnected"); }); diff --git a/site/src/pages/WorkspacePage/WorkspacePage.test.tsx b/site/src/pages/WorkspacePage/WorkspacePage.test.tsx index ad320018da9fb..645c03380501a 100644 --- a/site/src/pages/WorkspacePage/WorkspacePage.test.tsx +++ b/site/src/pages/WorkspacePage/WorkspacePage.test.tsx @@ -19,6 +19,7 @@ import { MockFailedWorkspace, MockOrganization, MockOutdatedWorkspace, + MockPendingWorkspace, MockStartingWorkspace, MockStoppedWorkspace, MockTemplate, @@ -224,11 +225,59 @@ describe("WorkspacePage", () => { }), ); + const user = userEvent.setup({ delay: 0 }); const cancelWorkspaceMock = jest .spyOn(API, "cancelWorkspaceBuild") .mockImplementation(() => Promise.resolve({ message: "job canceled" })); + await renderWorkspacePage(MockStartingWorkspace); + + // Click on Cancel + const cancelButton = await screen.findByRole("button", { name: "Cancel" }); + await user.click(cancelButton); + + // Get dialog and confirm + const dialog = await screen.findByTestId("dialog"); + const confirmButton = within(dialog).getByRole("button", { + name: "Confirm", + hidden: false, + }); + await user.click(confirmButton); + + expect(cancelWorkspaceMock).toHaveBeenCalledWith( + MockStartingWorkspace.latest_build.id, + undefined, + ); + }); - await testButton(MockStartingWorkspace, "Cancel", cancelWorkspaceMock); + it("requests cancellation when the user presses Cancel and the workspace is pending", async () => { + server.use( + http.get("/api/v2/users/:userId/workspace/:workspaceName", () => { + return HttpResponse.json(MockPendingWorkspace); + }), + ); + + const user = userEvent.setup({ delay: 0 }); + const cancelWorkspaceMock = jest + .spyOn(API, "cancelWorkspaceBuild") + .mockImplementation(() => Promise.resolve({ message: "job canceled" })); + await renderWorkspacePage(MockPendingWorkspace); + + // Click on Cancel + const cancelButton = await screen.findByRole("button", { name: "Cancel" }); + await user.click(cancelButton); + + // Get dialog and confirm + const dialog = await screen.findByTestId("dialog"); + const confirmButton = within(dialog).getByRole("button", { + name: "Confirm", + hidden: false, + }); + await user.click(confirmButton); + + expect(cancelWorkspaceMock).toHaveBeenCalledWith( + MockPendingWorkspace.latest_build.id, + { expect_status: "pending" }, + ); }); it("requests an update when the user presses Update", async () => { diff --git a/site/src/pages/WorkspacePage/WorkspaceReadyPage.tsx b/site/src/pages/WorkspacePage/WorkspaceReadyPage.tsx index 79ec5ad11a2b5..4034cc144e127 100644 --- a/site/src/pages/WorkspacePage/WorkspaceReadyPage.tsx +++ b/site/src/pages/WorkspacePage/WorkspaceReadyPage.tsx @@ -20,6 +20,7 @@ import { displayError } from "components/GlobalSnackbar/utils"; import { useWorkspaceBuildLogs } from "hooks/useWorkspaceBuildLogs"; import { EphemeralParametersDialog } from "modules/workspaces/EphemeralParametersDialog/EphemeralParametersDialog"; import { WorkspaceErrorDialog } from "modules/workspaces/ErrorDialog/WorkspaceErrorDialog"; +import { WorkspaceBuildCancelDialog } from "modules/workspaces/WorkspaceBuildCancelDialog/WorkspaceBuildCancelDialog"; import { WorkspaceUpdateDialogs, useWorkspaceUpdate, @@ -80,6 +81,8 @@ export const WorkspaceReadyPage: FC = ({ ephemeralParameters: TypesGen.TemplateVersionParameter[]; }>({ open: false, action: "start", ephemeralParameters: [] }); + const [isCancelConfirmOpen, setIsCancelConfirmOpen] = useState(false); + const { mutate: mutateRestartWorkspace, isPending: isRestarting } = useMutation({ mutationFn: API.restartWorkspace, @@ -316,7 +319,7 @@ export const WorkspaceReadyPage: FC = ({ } }} handleUpdate={workspaceUpdate.update} - handleCancel={cancelBuildMutation.mutate} + handleCancel={() => setIsCancelConfirmOpen(true)} handleRetry={handleRetry} handleDebug={handleDebug} handleDormantActivate={async () => { @@ -352,6 +355,16 @@ export const WorkspaceReadyPage: FC = ({ } /> + setIsCancelConfirmOpen(false)} + onConfirm={() => { + cancelBuildMutation.mutate(); + setIsCancelConfirmOpen(false); + }} + workspace={workspace} + /> + diff --git a/site/src/pages/WorkspacesPage/WorkspacesPage.tsx b/site/src/pages/WorkspacesPage/WorkspacesPage.tsx index 22ba0d15f1f9a..fa96191501379 100644 --- a/site/src/pages/WorkspacesPage/WorkspacesPage.tsx +++ b/site/src/pages/WorkspacesPage/WorkspacesPage.tsx @@ -2,7 +2,7 @@ import { getErrorDetail, getErrorMessage } from "api/errors"; import { workspacePermissionsByOrganization } from "api/queries/organizations"; import { templates } from "api/queries/templates"; import { workspaces } from "api/queries/workspaces"; -import type { Workspace } from "api/typesGenerated"; +import type { Workspace, WorkspaceStatus } from "api/typesGenerated"; import { useFilter } from "components/Filter/Filter"; import { useUserFilterMenu } from "components/Filter/UserFilter"; import { displayError } from "components/GlobalSnackbar/utils"; @@ -22,6 +22,18 @@ import { WorkspacesPageView } from "./WorkspacesPageView"; import { useBatchActions } from "./batchActions"; import { useStatusFilterMenu, useTemplateFilterMenu } from "./filter/menus"; +// To reduce the number of fetches, we reduce the fetch interval if there are no +// active workspace builds. +const ACTIVE_BUILD_STATUSES: WorkspaceStatus[] = [ + "canceling", + "deleting", + "pending", + "starting", + "stopping", +]; +const ACTIVE_BUILDS_REFRESH_INTERVAL = 5_000; +const NO_ACTIVE_BUILDS_REFRESH_INTERVAL = 30_000; + function useSafeSearchParams() { // Have to wrap setSearchParams because React Router doesn't make sure that // the function's memory reference stays stable on each render, even though @@ -78,8 +90,23 @@ const WorkspacesPage: FC = () => { const { data, error, refetch } = useQuery({ ...workspacesQueryOptions, refetchInterval: ({ state }) => { - return state.error ? false : 5_000; + if (state.error) return false; + + // Default to 5s interval until first fetch completes + if (!state.data) return ACTIVE_BUILDS_REFRESH_INTERVAL; + + // Check if any workspace has an active build + const hasActiveBuilds = state.data.workspaces?.some((workspace) => { + const status = workspace.latest_build.status; + return ACTIVE_BUILD_STATUSES.includes(status); + }); + + // Poll every 5s if there are active builds, otherwise every 30s + return hasActiveBuilds + ? ACTIVE_BUILDS_REFRESH_INTERVAL + : NO_ACTIVE_BUILDS_REFRESH_INTERVAL; }, + refetchOnWindowFocus: "always", }); const [checkedWorkspaces, setCheckedWorkspaces] = useState< diff --git a/site/src/pages/WorkspacesPage/WorkspacesTable.tsx b/site/src/pages/WorkspacesPage/WorkspacesTable.tsx index 6213224dea602..23695d8fbc591 100644 --- a/site/src/pages/WorkspacesPage/WorkspacesTable.tsx +++ b/site/src/pages/WorkspacesPage/WorkspacesTable.tsx @@ -62,6 +62,7 @@ import { import { useAppLink } from "modules/apps/useAppLink"; import { useDashboard } from "modules/dashboard/useDashboard"; import { WorkspaceAppStatus } from "modules/workspaces/WorkspaceAppStatus/WorkspaceAppStatus"; +import { WorkspaceBuildCancelDialog } from "modules/workspaces/WorkspaceBuildCancelDialog/WorkspaceBuildCancelDialog"; import { WorkspaceDormantBadge } from "modules/workspaces/WorkspaceDormantBadge/WorkspaceDormantBadge"; import { WorkspaceMoreActions } from "modules/workspaces/WorkspaceMoreActions/WorkspaceMoreActions"; import { WorkspaceOutdatedTooltip } from "modules/workspaces/WorkspaceOutdatedTooltip/WorkspaceOutdatedTooltip"; @@ -495,8 +496,8 @@ const WorkspaceActionsCell: FC = ({ onError: onActionError, }); - // State for stop confirmation dialog const [isStopConfirmOpen, setIsStopConfirmOpen] = useState(false); + const [isCancelConfirmOpen, setIsCancelConfirmOpen] = useState(false); const isRetrying = startWorkspaceMutation.isPending || @@ -606,7 +607,7 @@ const WorkspaceActionsCell: FC = ({ {abilities.canCancel && ( setIsCancelConfirmOpen(true)} isLoading={cancelBuildMutation.isPending} label="Cancel build" > @@ -643,6 +644,16 @@ const WorkspaceActionsCell: FC = ({ }} type="delete" /> + + setIsCancelConfirmOpen(false)} + onConfirm={() => { + cancelBuildMutation.mutate(); + setIsCancelConfirmOpen(false); + }} + workspace={workspace} + /> ); }; diff --git a/site/src/utils/workspace.tsx b/site/src/utils/workspace.tsx index 135b965589054..c88ffc9d8edaa 100644 --- a/site/src/utils/workspace.tsx +++ b/site/src/utils/workspace.tsx @@ -75,14 +75,16 @@ export const getDisplayWorkspaceBuildStatus = ( export const getDisplayWorkspaceBuildInitiatedBy = ( build: TypesGen.WorkspaceBuild, -): string => { +): string | undefined => { switch (build.reason) { case "initiator": return build.initiator_name; case "autostart": case "autostop": + case "dormancy": return "Coder"; } + return undefined; }; const getWorkspaceBuildDurationInSeconds = (