Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 8e0a153

Browse files
authoredJan 22, 2024
chore: implement device auth flow for fake idp (#11707)
* chore: implement device auth flow for fake idp
1 parent 16c6cef commit 8e0a153

File tree

4 files changed

+333
-23
lines changed

4 files changed

+333
-23
lines changed
 

‎coderd/coderdtest/oidctest/idp.go

+253-15
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,14 @@ import (
1010
"errors"
1111
"fmt"
1212
"io"
13+
"math/rand"
14+
"mime"
1315
"net"
1416
"net/http"
1517
"net/http/cookiejar"
1618
"net/http/httptest"
1719
"net/url"
20+
"strconv"
1821
"strings"
1922
"testing"
2023
"time"
@@ -34,9 +37,11 @@ import (
3437
"cdr.dev/slog/sloggers/slogtest"
3538
"github.com/coder/coder/v2/coderd"
3639
"github.com/coder/coder/v2/coderd/externalauth"
40+
"github.com/coder/coder/v2/coderd/httpapi"
3741
"github.com/coder/coder/v2/coderd/promoauth"
3842
"github.com/coder/coder/v2/coderd/util/syncmap"
3943
"github.com/coder/coder/v2/codersdk"
44+
"github.com/coder/coder/v2/testutil"
4045
)
4146

4247
type token struct {
@@ -45,6 +50,13 @@ type token struct {
4550
exp time.Time
4651
}
4752

53+
type deviceFlow struct {
54+
// userInput is the expected input to authenticate the device flow.
55+
userInput string
56+
exp time.Time
57+
granted bool
58+
}
59+
4860
// FakeIDP is a functional OIDC provider.
4961
// It only supports 1 OIDC client.
5062
type FakeIDP struct {
@@ -77,6 +89,8 @@ type FakeIDP struct {
7789
refreshTokens *syncmap.Map[string, string]
7890
stateToIDTokenClaims *syncmap.Map[string, jwt.MapClaims]
7991
refreshIDTokenClaims *syncmap.Map[string, jwt.MapClaims]
92+
// Device flow
93+
deviceCode *syncmap.Map[string, deviceFlow]
8094

8195
// hooks
8296
// hookValidRedirectURL can be used to reject a redirect url from the
@@ -226,6 +240,8 @@ const (
226240
authorizePath = "/oauth2/authorize"
227241
keysPath = "/oauth2/keys"
228242
userInfoPath = "/oauth2/userinfo"
243+
deviceAuth = "/login/device/code"
244+
deviceVerify = "/login/device"
229245
)
230246

231247
func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
@@ -246,6 +262,7 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
246262
refreshTokensUsed: syncmap.New[string, bool](),
247263
stateToIDTokenClaims: syncmap.New[string, jwt.MapClaims](),
248264
refreshIDTokenClaims: syncmap.New[string, jwt.MapClaims](),
265+
deviceCode: syncmap.New[string, deviceFlow](),
249266
hookOnRefresh: func(_ string) error { return nil },
250267
hookUserInfo: func(email string) (jwt.MapClaims, error) { return jwt.MapClaims{}, nil },
251268
hookValidRedirectURL: func(redirectURL string) error { return nil },
@@ -288,11 +305,12 @@ func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) {
288305
// ProviderJSON is the JSON representation of the OpenID Connect provider
289306
// These are all the urls that the IDP will respond to.
290307
f.provider = ProviderJSON{
291-
Issuer: issuer,
292-
AuthURL: u.ResolveReference(&url.URL{Path: authorizePath}).String(),
293-
TokenURL: u.ResolveReference(&url.URL{Path: tokenPath}).String(),
294-
JWKSURL: u.ResolveReference(&url.URL{Path: keysPath}).String(),
295-
UserInfoURL: u.ResolveReference(&url.URL{Path: userInfoPath}).String(),
308+
Issuer: issuer,
309+
AuthURL: u.ResolveReference(&url.URL{Path: authorizePath}).String(),
310+
TokenURL: u.ResolveReference(&url.URL{Path: tokenPath}).String(),
311+
JWKSURL: u.ResolveReference(&url.URL{Path: keysPath}).String(),
312+
UserInfoURL: u.ResolveReference(&url.URL{Path: userInfoPath}).String(),
313+
DeviceCodeURL: u.ResolveReference(&url.URL{Path: deviceAuth}).String(),
296314
Algorithms: []string{
297315
"RS256",
298316
},
@@ -467,6 +485,31 @@ func (f *FakeIDP) ExternalLogin(t testing.TB, client *codersdk.Client, opts ...f
467485
_ = res.Body.Close()
468486
}
469487

488+
// DeviceLogin does the oauth2 device flow for external auth providers.
489+
func (*FakeIDP) DeviceLogin(t testing.TB, client *codersdk.Client, externalAuthID string) {
490+
// First we need to initiate the device flow. This will have Coder hit the
491+
// fake IDP and get a device code.
492+
device, err := client.ExternalAuthDeviceByID(context.Background(), externalAuthID)
493+
require.NoError(t, err)
494+
495+
// Now the user needs to go to the fake IDP page and click "allow" and enter
496+
// the device code input. For our purposes, we just send an http request to
497+
// the verification url. No additional user input is needed.
498+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
499+
defer cancel()
500+
resp, err := client.Request(ctx, http.MethodPost, device.VerificationURI, nil)
501+
require.NoError(t, err)
502+
defer resp.Body.Close()
503+
504+
// Now we need to exchange the device code for an access token. We do this
505+
// in this method because it is the user that does the polling for the device
506+
// auth flow, not the backend.
507+
err = client.ExternalAuthDeviceExchange(context.Background(), externalAuthID, codersdk.ExternalAuthDeviceExchange{
508+
DeviceCode: device.DeviceCode,
509+
})
510+
require.NoError(t, err)
511+
}
512+
470513
// CreateAuthCode emulates a user clicking "allow" on the IDP page. When doing
471514
// unit tests, it's easier to skip this step sometimes. It does make an actual
472515
// request to the IDP, so it should be equivalent to doing this "manually" with
@@ -536,12 +579,13 @@ func (f *FakeIDP) OIDCCallback(t testing.TB, state string, idTokenClaims jwt.Map
536579

537580
// ProviderJSON is the .well-known/configuration JSON
538581
type ProviderJSON struct {
539-
Issuer string `json:"issuer"`
540-
AuthURL string `json:"authorization_endpoint"`
541-
TokenURL string `json:"token_endpoint"`
542-
JWKSURL string `json:"jwks_uri"`
543-
UserInfoURL string `json:"userinfo_endpoint"`
544-
Algorithms []string `json:"id_token_signing_alg_values_supported"`
582+
Issuer string `json:"issuer"`
583+
AuthURL string `json:"authorization_endpoint"`
584+
TokenURL string `json:"token_endpoint"`
585+
JWKSURL string `json:"jwks_uri"`
586+
UserInfoURL string `json:"userinfo_endpoint"`
587+
DeviceCodeURL string `json:"device_authorization_endpoint"`
588+
Algorithms []string `json:"id_token_signing_alg_values_supported"`
545589
// This is custom
546590
ExternalAuthURL string `json:"external_auth_url"`
547591
}
@@ -709,8 +753,15 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
709753
}))
710754

711755
mux.Handle(tokenPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
712-
values, err := f.authenticateOIDCClientRequest(t, r)
756+
var values url.Values
757+
var err error
758+
if r.URL.Query().Get("grant_type") == "urn:ietf:params:oauth:grant-type:device_code" {
759+
values = r.URL.Query()
760+
} else {
761+
values, err = f.authenticateOIDCClientRequest(t, r)
762+
}
713763
f.logger.Info(r.Context(), "http idp call token",
764+
slog.F("url", r.URL.String()),
714765
slog.F("valid", err == nil),
715766
slog.F("grant_type", values.Get("grant_type")),
716767
slog.F("values", values.Encode()),
@@ -784,6 +835,37 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
784835
f.refreshTokensUsed.Store(refreshToken, true)
785836
// Always invalidate the refresh token after it is used.
786837
f.refreshTokens.Delete(refreshToken)
838+
case "urn:ietf:params:oauth:grant-type:device_code":
839+
// Device flow
840+
var resp externalauth.ExchangeDeviceCodeResponse
841+
deviceCode := values.Get("device_code")
842+
if deviceCode == "" {
843+
resp.Error = "invalid_request"
844+
resp.ErrorDescription = "missing device_code"
845+
httpapi.Write(r.Context(), rw, http.StatusBadRequest, resp)
846+
return
847+
}
848+
849+
deviceFlow, ok := f.deviceCode.Load(deviceCode)
850+
if !ok {
851+
resp.Error = "invalid_request"
852+
resp.ErrorDescription = "device_code provided not found"
853+
httpapi.Write(r.Context(), rw, http.StatusBadRequest, resp)
854+
return
855+
}
856+
857+
if !deviceFlow.granted {
858+
// Status code ok with the error as pending.
859+
resp.Error = "authorization_pending"
860+
resp.ErrorDescription = ""
861+
httpapi.Write(r.Context(), rw, http.StatusOK, resp)
862+
return
863+
}
864+
865+
// Would be nice to get an actual email here.
866+
claims = jwt.MapClaims{
867+
"email": "unknown-dev-auth",
868+
}
787869
default:
788870
t.Errorf("unexpected grant_type %q", values.Get("grant_type"))
789871
http.Error(rw, "invalid grant_type", http.StatusBadRequest)
@@ -807,8 +889,30 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
807889
// Store the claims for the next refresh
808890
f.refreshIDTokenClaims.Store(refreshToken, claims)
809891

810-
rw.Header().Set("Content-Type", "application/json")
811-
_ = json.NewEncoder(rw).Encode(token)
892+
mediaType, _, _ := mime.ParseMediaType(r.Header.Get("Accept"))
893+
if mediaType == "application/x-www-form-urlencoded" {
894+
// This val encode might not work for some data structures.
895+
// It's good enough for now...
896+
rw.Header().Set("Content-Type", "application/x-www-form-urlencoded")
897+
vals := url.Values{}
898+
for k, v := range token {
899+
vals.Set(k, fmt.Sprintf("%v", v))
900+
}
901+
_, _ = rw.Write([]byte(vals.Encode()))
902+
return
903+
}
904+
// Default to json since the oauth2 package doesn't use Accept headers.
905+
if mediaType == "application/json" || mediaType == "" {
906+
rw.Header().Set("Content-Type", "application/json")
907+
_ = json.NewEncoder(rw).Encode(token)
908+
return
909+
}
910+
911+
// If we get something we don't support, throw an error.
912+
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
913+
Message: "'Accept' header contains unsupported media type",
914+
Detail: fmt.Sprintf("Found %q", mediaType),
915+
})
812916
}))
813917

814918
validateMW := func(rw http.ResponseWriter, r *http.Request) (email string, ok bool) {
@@ -886,6 +990,125 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
886990
_ = json.NewEncoder(rw).Encode(set)
887991
}))
888992

993+
mux.Handle(deviceVerify, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
994+
f.logger.Info(r.Context(), "http call device verify")
995+
996+
inputParam := "user_input"
997+
userInput := r.URL.Query().Get(inputParam)
998+
if userInput == "" {
999+
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
1000+
Message: "Invalid user input",
1001+
Detail: fmt.Sprintf("Hit this url again with ?%s=<user_code>", inputParam),
1002+
})
1003+
return
1004+
}
1005+
1006+
deviceCode := r.URL.Query().Get("device_code")
1007+
if deviceCode == "" {
1008+
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
1009+
Message: "Invalid device code",
1010+
Detail: "Hit this url again with ?device_code=<device_code>",
1011+
})
1012+
return
1013+
}
1014+
1015+
flow, ok := f.deviceCode.Load(deviceCode)
1016+
if !ok {
1017+
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
1018+
Message: "Invalid device code",
1019+
Detail: "Device code not found.",
1020+
})
1021+
return
1022+
}
1023+
1024+
if time.Now().After(flow.exp) {
1025+
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
1026+
Message: "Invalid device code",
1027+
Detail: "Device code expired.",
1028+
})
1029+
return
1030+
}
1031+
1032+
if strings.TrimSpace(flow.userInput) != strings.TrimSpace(userInput) {
1033+
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
1034+
Message: "Invalid device code",
1035+
Detail: "user code does not match",
1036+
})
1037+
return
1038+
}
1039+
1040+
f.deviceCode.Store(deviceCode, deviceFlow{
1041+
userInput: flow.userInput,
1042+
exp: flow.exp,
1043+
granted: true,
1044+
})
1045+
httpapi.Write(r.Context(), rw, http.StatusOK, codersdk.Response{
1046+
Message: "Device authenticated!",
1047+
})
1048+
}))
1049+
1050+
mux.Handle(deviceAuth, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
1051+
f.logger.Info(r.Context(), "http call device auth")
1052+
1053+
p := httpapi.NewQueryParamParser()
1054+
p.Required("client_id")
1055+
clientID := p.String(r.URL.Query(), "", "client_id")
1056+
_ = p.String(r.URL.Query(), "", "scopes")
1057+
if len(p.Errors) > 0 {
1058+
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
1059+
Message: "Invalid query params",
1060+
Validations: p.Errors,
1061+
})
1062+
return
1063+
}
1064+
1065+
if clientID != f.clientID {
1066+
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
1067+
Message: "Invalid client id",
1068+
})
1069+
return
1070+
}
1071+
1072+
deviceCode := uuid.NewString()
1073+
lifetime := time.Second * 900
1074+
flow := deviceFlow{
1075+
//nolint:gosec
1076+
userInput: fmt.Sprintf("%d", rand.Intn(9999999)+1e8),
1077+
}
1078+
f.deviceCode.Store(deviceCode, deviceFlow{
1079+
userInput: flow.userInput,
1080+
exp: time.Now().Add(lifetime),
1081+
})
1082+
1083+
verifyURL := f.issuerURL.ResolveReference(&url.URL{
1084+
Path: deviceVerify,
1085+
RawQuery: url.Values{
1086+
"device_code": {deviceCode},
1087+
"user_input": {flow.userInput},
1088+
}.Encode(),
1089+
}).String()
1090+
1091+
if mediaType, _, _ := mime.ParseMediaType(r.Header.Get("Accept")); mediaType == "application/json" {
1092+
httpapi.Write(r.Context(), rw, http.StatusOK, map[string]any{
1093+
"device_code": deviceCode,
1094+
"user_code": flow.userInput,
1095+
"verification_uri": verifyURL,
1096+
"expires_in": int(lifetime.Seconds()),
1097+
"interval": 3,
1098+
})
1099+
return
1100+
}
1101+
1102+
// By default, GitHub form encodes these.
1103+
_, _ = fmt.Fprint(rw, url.Values{
1104+
"device_code": {deviceCode},
1105+
"user_code": {flow.userInput},
1106+
"verification_uri": {verifyURL},
1107+
"expires_in": {strconv.Itoa(int(lifetime.Seconds()))},
1108+
"interval": {"3"},
1109+
}.Encode())
1110+
}))
1111+
8891112
mux.NotFound(func(rw http.ResponseWriter, r *http.Request) {
8901113
f.logger.Error(r.Context(), "http call not found", slog.F("path", r.URL.Path))
8911114
t.Errorf("unexpected request to IDP at path %q. Not supported", r.URL.Path)
@@ -987,6 +1210,8 @@ type ExternalAuthConfigOptions struct {
9871210
// completely customize the response. It captures all routes under the /external-auth-validate/*
9881211
// so the caller can do whatever they want and even add routes.
9891212
routes map[string]func(email string, rw http.ResponseWriter, r *http.Request)
1213+
1214+
UseDeviceAuth bool
9901215
}
9911216

9921217
func (o *ExternalAuthConfigOptions) AddRoute(route string, handle func(email string, rw http.ResponseWriter, r *http.Request)) *ExternalAuthConfigOptions {
@@ -1033,17 +1258,30 @@ func (f *FakeIDP) ExternalAuthConfig(t testing.TB, id string, custom *ExternalAu
10331258
}
10341259
}
10351260
instrumentF := promoauth.NewFactory(prometheus.NewRegistry())
1261+
oauthCfg := instrumentF.New(f.clientID, f.OIDCConfig(t, nil))
10361262
cfg := &externalauth.Config{
10371263
DisplayName: id,
1038-
InstrumentedOAuth2Config: instrumentF.New(f.clientID, f.OIDCConfig(t, nil)),
1264+
InstrumentedOAuth2Config: oauthCfg,
10391265
ID: id,
10401266
// No defaults for these fields by omitting the type
10411267
Type: "",
10421268
DisplayIcon: f.WellknownConfig().UserInfoURL,
10431269
// Omit the /user for the validate so we can easily append to it when modifying
10441270
// the cfg for advanced tests.
10451271
ValidateURL: f.issuerURL.ResolveReference(&url.URL{Path: "/external-auth-validate/"}).String(),
1272+
DeviceAuth: &externalauth.DeviceAuth{
1273+
Config: oauthCfg,
1274+
ClientID: f.clientID,
1275+
TokenURL: f.provider.TokenURL,
1276+
Scopes: []string{},
1277+
CodeURL: f.provider.DeviceCodeURL,
1278+
},
1279+
}
1280+
1281+
if !custom.UseDeviceAuth {
1282+
cfg.DeviceAuth = nil
10461283
}
1284+
10471285
for _, opt := range opts {
10481286
opt(cfg)
10491287
}
There was a problem loading the remainder of the diff.

0 commit comments

Comments
 (0)
Failed to load comments.