Skip to content

Commit 949ca4d

Browse files
committed
add unit tests for unauth'd email domains
1 parent 0fcce5f commit 949ca4d

File tree

1 file changed

+121
-0
lines changed

1 file changed

+121
-0
lines changed

coderd/domain_error_test.go

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
package coderd_test
2+
3+
import (
4+
"context"
5+
"io"
6+
"net/http"
7+
"testing"
8+
9+
"github.com/golang-jwt/jwt/v4"
10+
"github.com/google/uuid"
11+
"github.com/stretchr/testify/require"
12+
13+
"github.com/coder/coder/v2/coderd"
14+
"github.com/coder/coder/v2/coderd/coderdtest"
15+
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
16+
)
17+
18+
// TestOIDCDomainErrorMessage ensures that when a user with an unauthorized domain
19+
// attempts to login, the error message doesn't expose the list of authorized domains.
20+
func TestOIDCDomainErrorMessage(t *testing.T) {
21+
t.Parallel()
22+
23+
// Setup OIDC fake provider
24+
fake := oidctest.NewFakeIDP(t, oidctest.WithServing())
25+
26+
// Configure OIDC provider with domain restrictions
27+
allowedDomains := []string{"allowed1.com", "allowed2.org", "company.internal"}
28+
cfg := fake.OIDCConfig(t, nil, func(cfg *coderd.OIDCConfig) {
29+
cfg.EmailDomain = allowedDomains
30+
cfg.AllowSignups = true
31+
})
32+
33+
// Create a Coder server with OIDC enabled
34+
server := coderdtest.New(t, &coderdtest.Options{
35+
OIDCConfig: cfg,
36+
})
37+
38+
// Test case 1: Email domain not in allowed list
39+
t.Run("ErrorMessageOmitsDomains", func(t *testing.T) {
40+
t.Parallel()
41+
42+
// Prepare claims with email from unauthorized domain
43+
claims := jwt.MapClaims{
44+
"email": "user@unauthorized.com",
45+
"email_verified": true,
46+
"sub": uuid.NewString(),
47+
}
48+
49+
// Attempt login and check for failure
50+
_, resp := fake.AttemptLogin(t, server, claims)
51+
defer resp.Body.Close()
52+
53+
// Verify the status code
54+
require.Equal(t, http.StatusForbidden, resp.StatusCode)
55+
56+
// Check the response content
57+
data, err := io.ReadAll(resp.Body)
58+
require.NoError(t, err)
59+
60+
// Verify the message contains the generic text
61+
require.Contains(t, string(data), "is not from an authorized domain")
62+
require.Contains(t, string(data), "Please contact your administrator")
63+
64+
// Verify it doesn't contain any of the allowed domains
65+
for _, domain := range allowedDomains {
66+
require.NotContains(t, string(data), domain)
67+
}
68+
})
69+
70+
// Test case 2: Malformed email without @ symbol
71+
t.Run("MalformedEmailErrorOmitsDomains", func(t *testing.T) {
72+
t.Parallel()
73+
74+
// Prepare claims with an invalid email format (no @ symbol)
75+
claims := jwt.MapClaims{
76+
"email": "invalid-email-without-domain",
77+
"email_verified": true,
78+
"sub": uuid.NewString(),
79+
}
80+
81+
// Attempt login and check for failure
82+
_, resp := fake.AttemptLogin(t, server, claims)
83+
defer resp.Body.Close()
84+
85+
// Verify the status code
86+
require.Equal(t, http.StatusForbidden, resp.StatusCode)
87+
88+
// Check the response content
89+
data, err := io.ReadAll(resp.Body)
90+
require.NoError(t, err)
91+
92+
// Verify the message contains the generic text
93+
require.Contains(t, string(data), "is not from an authorized domain")
94+
require.Contains(t, string(data), "Please contact your administrator")
95+
96+
// Verify it doesn't contain any of the allowed domains
97+
for _, domain := range allowedDomains {
98+
require.NotContains(t, string(data), domain)
99+
}
100+
})
101+
102+
// Test case 3: Authorized domain (should succeed)
103+
t.Run("AuthorizedDomainSucceeds", func(t *testing.T) {
104+
t.Parallel()
105+
106+
// Prepare claims with an authorized domain
107+
claims := jwt.MapClaims{
108+
"email": "user@allowed1.com",
109+
"email_verified": true,
110+
"sub": uuid.NewString(),
111+
}
112+
113+
// Attempt login and expect success
114+
client, _ := fake.Login(t, server, claims)
115+
116+
// Verify the user was created correctly
117+
user, err := client.User(context.Background(), "me")
118+
require.NoError(t, err)
119+
require.Equal(t, "user", user.Username)
120+
})
121+
}

0 commit comments

Comments
 (0)