@@ -10,11 +10,14 @@ import (
10
10
"errors"
11
11
"fmt"
12
12
"io"
13
+ "math/rand"
14
+ "mime"
13
15
"net"
14
16
"net/http"
15
17
"net/http/cookiejar"
16
18
"net/http/httptest"
17
19
"net/url"
20
+ "strconv"
18
21
"strings"
19
22
"testing"
20
23
"time"
@@ -34,9 +37,11 @@ import (
34
37
"cdr.dev/slog/sloggers/slogtest"
35
38
"github.com/coder/coder/v2/coderd"
36
39
"github.com/coder/coder/v2/coderd/externalauth"
40
+ "github.com/coder/coder/v2/coderd/httpapi"
37
41
"github.com/coder/coder/v2/coderd/promoauth"
38
42
"github.com/coder/coder/v2/coderd/util/syncmap"
39
43
"github.com/coder/coder/v2/codersdk"
44
+ "github.com/coder/coder/v2/testutil"
40
45
)
41
46
42
47
type token struct {
@@ -45,6 +50,13 @@ type token struct {
45
50
exp time.Time
46
51
}
47
52
53
+ type deviceFlow struct {
54
+ // userInput is the expected input to authenticate the device flow.
55
+ userInput string
56
+ exp time.Time
57
+ granted bool
58
+ }
59
+
48
60
// FakeIDP is a functional OIDC provider.
49
61
// It only supports 1 OIDC client.
50
62
type FakeIDP struct {
@@ -77,6 +89,8 @@ type FakeIDP struct {
77
89
refreshTokens * syncmap.Map [string , string ]
78
90
stateToIDTokenClaims * syncmap.Map [string , jwt.MapClaims ]
79
91
refreshIDTokenClaims * syncmap.Map [string , jwt.MapClaims ]
92
+ // Device flow
93
+ deviceCode * syncmap.Map [string , deviceFlow ]
80
94
81
95
// hooks
82
96
// hookValidRedirectURL can be used to reject a redirect url from the
@@ -226,6 +240,8 @@ const (
226
240
authorizePath = "/oauth2/authorize"
227
241
keysPath = "/oauth2/keys"
228
242
userInfoPath = "/oauth2/userinfo"
243
+ deviceAuth = "/login/device/code"
244
+ deviceVerify = "/login/device"
229
245
)
230
246
231
247
func NewFakeIDP (t testing.TB , opts ... FakeIDPOpt ) * FakeIDP {
@@ -246,6 +262,7 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
246
262
refreshTokensUsed : syncmap .New [string , bool ](),
247
263
stateToIDTokenClaims : syncmap .New [string , jwt.MapClaims ](),
248
264
refreshIDTokenClaims : syncmap .New [string , jwt.MapClaims ](),
265
+ deviceCode : syncmap .New [string , deviceFlow ](),
249
266
hookOnRefresh : func (_ string ) error { return nil },
250
267
hookUserInfo : func (email string ) (jwt.MapClaims , error ) { return jwt.MapClaims {}, nil },
251
268
hookValidRedirectURL : func (redirectURL string ) error { return nil },
@@ -288,11 +305,12 @@ func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) {
288
305
// ProviderJSON is the JSON representation of the OpenID Connect provider
289
306
// These are all the urls that the IDP will respond to.
290
307
f .provider = ProviderJSON {
291
- Issuer : issuer ,
292
- AuthURL : u .ResolveReference (& url.URL {Path : authorizePath }).String (),
293
- TokenURL : u .ResolveReference (& url.URL {Path : tokenPath }).String (),
294
- JWKSURL : u .ResolveReference (& url.URL {Path : keysPath }).String (),
295
- UserInfoURL : u .ResolveReference (& url.URL {Path : userInfoPath }).String (),
308
+ Issuer : issuer ,
309
+ AuthURL : u .ResolveReference (& url.URL {Path : authorizePath }).String (),
310
+ TokenURL : u .ResolveReference (& url.URL {Path : tokenPath }).String (),
311
+ JWKSURL : u .ResolveReference (& url.URL {Path : keysPath }).String (),
312
+ UserInfoURL : u .ResolveReference (& url.URL {Path : userInfoPath }).String (),
313
+ DeviceCodeURL : u .ResolveReference (& url.URL {Path : deviceAuth }).String (),
296
314
Algorithms : []string {
297
315
"RS256" ,
298
316
},
@@ -467,6 +485,31 @@ func (f *FakeIDP) ExternalLogin(t testing.TB, client *codersdk.Client, opts ...f
467
485
_ = res .Body .Close ()
468
486
}
469
487
488
+ // DeviceLogin does the oauth2 device flow for external auth providers.
489
+ func (* FakeIDP ) DeviceLogin (t testing.TB , client * codersdk.Client , externalAuthID string ) {
490
+ // First we need to initiate the device flow. This will have Coder hit the
491
+ // fake IDP and get a device code.
492
+ device , err := client .ExternalAuthDeviceByID (context .Background (), externalAuthID )
493
+ require .NoError (t , err )
494
+
495
+ // Now the user needs to go to the fake IDP page and click "allow" and enter
496
+ // the device code input. For our purposes, we just send an http request to
497
+ // the verification url. No additional user input is needed.
498
+ ctx , cancel := context .WithTimeout (context .Background (), testutil .WaitShort )
499
+ defer cancel ()
500
+ resp , err := client .Request (ctx , http .MethodPost , device .VerificationURI , nil )
501
+ require .NoError (t , err )
502
+ defer resp .Body .Close ()
503
+
504
+ // Now we need to exchange the device code for an access token. We do this
505
+ // in this method because it is the user that does the polling for the device
506
+ // auth flow, not the backend.
507
+ err = client .ExternalAuthDeviceExchange (context .Background (), externalAuthID , codersdk.ExternalAuthDeviceExchange {
508
+ DeviceCode : device .DeviceCode ,
509
+ })
510
+ require .NoError (t , err )
511
+ }
512
+
470
513
// CreateAuthCode emulates a user clicking "allow" on the IDP page. When doing
471
514
// unit tests, it's easier to skip this step sometimes. It does make an actual
472
515
// request to the IDP, so it should be equivalent to doing this "manually" with
@@ -536,12 +579,13 @@ func (f *FakeIDP) OIDCCallback(t testing.TB, state string, idTokenClaims jwt.Map
536
579
537
580
// ProviderJSON is the .well-known/configuration JSON
538
581
type ProviderJSON struct {
539
- Issuer string `json:"issuer"`
540
- AuthURL string `json:"authorization_endpoint"`
541
- TokenURL string `json:"token_endpoint"`
542
- JWKSURL string `json:"jwks_uri"`
543
- UserInfoURL string `json:"userinfo_endpoint"`
544
- Algorithms []string `json:"id_token_signing_alg_values_supported"`
582
+ Issuer string `json:"issuer"`
583
+ AuthURL string `json:"authorization_endpoint"`
584
+ TokenURL string `json:"token_endpoint"`
585
+ JWKSURL string `json:"jwks_uri"`
586
+ UserInfoURL string `json:"userinfo_endpoint"`
587
+ DeviceCodeURL string `json:"device_authorization_endpoint"`
588
+ Algorithms []string `json:"id_token_signing_alg_values_supported"`
545
589
// This is custom
546
590
ExternalAuthURL string `json:"external_auth_url"`
547
591
}
@@ -709,8 +753,15 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
709
753
}))
710
754
711
755
mux .Handle (tokenPath , http .HandlerFunc (func (rw http.ResponseWriter , r * http.Request ) {
712
- values , err := f .authenticateOIDCClientRequest (t , r )
756
+ var values url.Values
757
+ var err error
758
+ if r .URL .Query ().Get ("grant_type" ) == "urn:ietf:params:oauth:grant-type:device_code" {
759
+ values = r .URL .Query ()
760
+ } else {
761
+ values , err = f .authenticateOIDCClientRequest (t , r )
762
+ }
713
763
f .logger .Info (r .Context (), "http idp call token" ,
764
+ slog .F ("url" , r .URL .String ()),
714
765
slog .F ("valid" , err == nil ),
715
766
slog .F ("grant_type" , values .Get ("grant_type" )),
716
767
slog .F ("values" , values .Encode ()),
@@ -784,6 +835,37 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
784
835
f .refreshTokensUsed .Store (refreshToken , true )
785
836
// Always invalidate the refresh token after it is used.
786
837
f .refreshTokens .Delete (refreshToken )
838
+ case "urn:ietf:params:oauth:grant-type:device_code" :
839
+ // Device flow
840
+ var resp externalauth.ExchangeDeviceCodeResponse
841
+ deviceCode := values .Get ("device_code" )
842
+ if deviceCode == "" {
843
+ resp .Error = "invalid_request"
844
+ resp .ErrorDescription = "missing device_code"
845
+ httpapi .Write (r .Context (), rw , http .StatusBadRequest , resp )
846
+ return
847
+ }
848
+
849
+ deviceFlow , ok := f .deviceCode .Load (deviceCode )
850
+ if ! ok {
851
+ resp .Error = "invalid_request"
852
+ resp .ErrorDescription = "device_code provided not found"
853
+ httpapi .Write (r .Context (), rw , http .StatusBadRequest , resp )
854
+ return
855
+ }
856
+
857
+ if ! deviceFlow .granted {
858
+ // Status code ok with the error as pending.
859
+ resp .Error = "authorization_pending"
860
+ resp .ErrorDescription = ""
861
+ httpapi .Write (r .Context (), rw , http .StatusOK , resp )
862
+ return
863
+ }
864
+
865
+ // Would be nice to get an actual email here.
866
+ claims = jwt.MapClaims {
867
+ "email" : "unknown-dev-auth" ,
868
+ }
787
869
default :
788
870
t .Errorf ("unexpected grant_type %q" , values .Get ("grant_type" ))
789
871
http .Error (rw , "invalid grant_type" , http .StatusBadRequest )
@@ -807,8 +889,30 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
807
889
// Store the claims for the next refresh
808
890
f .refreshIDTokenClaims .Store (refreshToken , claims )
809
891
810
- rw .Header ().Set ("Content-Type" , "application/json" )
811
- _ = json .NewEncoder (rw ).Encode (token )
892
+ mediaType , _ , _ := mime .ParseMediaType (r .Header .Get ("Accept" ))
893
+ if mediaType == "application/x-www-form-urlencoded" {
894
+ // This val encode might not work for some data structures.
895
+ // It's good enough for now...
896
+ rw .Header ().Set ("Content-Type" , "application/x-www-form-urlencoded" )
897
+ vals := url.Values {}
898
+ for k , v := range token {
899
+ vals .Set (k , fmt .Sprintf ("%v" , v ))
900
+ }
901
+ _ , _ = rw .Write ([]byte (vals .Encode ()))
902
+ return
903
+ }
904
+ // Default to json since the oauth2 package doesn't use Accept headers.
905
+ if mediaType == "application/json" || mediaType == "" {
906
+ rw .Header ().Set ("Content-Type" , "application/json" )
907
+ _ = json .NewEncoder (rw ).Encode (token )
908
+ return
909
+ }
910
+
911
+ // If we get something we don't support, throw an error.
912
+ httpapi .Write (r .Context (), rw , http .StatusBadRequest , codersdk.Response {
913
+ Message : "'Accept' header contains unsupported media type" ,
914
+ Detail : fmt .Sprintf ("Found %q" , mediaType ),
915
+ })
812
916
}))
813
917
814
918
validateMW := func (rw http.ResponseWriter , r * http.Request ) (email string , ok bool ) {
@@ -886,6 +990,125 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
886
990
_ = json .NewEncoder (rw ).Encode (set )
887
991
}))
888
992
993
+ mux .Handle (deviceVerify , http .HandlerFunc (func (rw http.ResponseWriter , r * http.Request ) {
994
+ f .logger .Info (r .Context (), "http call device verify" )
995
+
996
+ inputParam := "user_input"
997
+ userInput := r .URL .Query ().Get (inputParam )
998
+ if userInput == "" {
999
+ httpapi .Write (r .Context (), rw , http .StatusBadRequest , codersdk.Response {
1000
+ Message : "Invalid user input" ,
1001
+ Detail : fmt .Sprintf ("Hit this url again with ?%s=<user_code>" , inputParam ),
1002
+ })
1003
+ return
1004
+ }
1005
+
1006
+ deviceCode := r .URL .Query ().Get ("device_code" )
1007
+ if deviceCode == "" {
1008
+ httpapi .Write (r .Context (), rw , http .StatusBadRequest , codersdk.Response {
1009
+ Message : "Invalid device code" ,
1010
+ Detail : "Hit this url again with ?device_code=<device_code>" ,
1011
+ })
1012
+ return
1013
+ }
1014
+
1015
+ flow , ok := f .deviceCode .Load (deviceCode )
1016
+ if ! ok {
1017
+ httpapi .Write (r .Context (), rw , http .StatusBadRequest , codersdk.Response {
1018
+ Message : "Invalid device code" ,
1019
+ Detail : "Device code not found." ,
1020
+ })
1021
+ return
1022
+ }
1023
+
1024
+ if time .Now ().After (flow .exp ) {
1025
+ httpapi .Write (r .Context (), rw , http .StatusBadRequest , codersdk.Response {
1026
+ Message : "Invalid device code" ,
1027
+ Detail : "Device code expired." ,
1028
+ })
1029
+ return
1030
+ }
1031
+
1032
+ if strings .TrimSpace (flow .userInput ) != strings .TrimSpace (userInput ) {
1033
+ httpapi .Write (r .Context (), rw , http .StatusBadRequest , codersdk.Response {
1034
+ Message : "Invalid device code" ,
1035
+ Detail : "user code does not match" ,
1036
+ })
1037
+ return
1038
+ }
1039
+
1040
+ f .deviceCode .Store (deviceCode , deviceFlow {
1041
+ userInput : flow .userInput ,
1042
+ exp : flow .exp ,
1043
+ granted : true ,
1044
+ })
1045
+ httpapi .Write (r .Context (), rw , http .StatusOK , codersdk.Response {
1046
+ Message : "Device authenticated!" ,
1047
+ })
1048
+ }))
1049
+
1050
+ mux .Handle (deviceAuth , http .HandlerFunc (func (rw http.ResponseWriter , r * http.Request ) {
1051
+ f .logger .Info (r .Context (), "http call device auth" )
1052
+
1053
+ p := httpapi .NewQueryParamParser ()
1054
+ p .Required ("client_id" )
1055
+ clientID := p .String (r .URL .Query (), "" , "client_id" )
1056
+ _ = p .String (r .URL .Query (), "" , "scopes" )
1057
+ if len (p .Errors ) > 0 {
1058
+ httpapi .Write (r .Context (), rw , http .StatusBadRequest , codersdk.Response {
1059
+ Message : "Invalid query params" ,
1060
+ Validations : p .Errors ,
1061
+ })
1062
+ return
1063
+ }
1064
+
1065
+ if clientID != f .clientID {
1066
+ httpapi .Write (r .Context (), rw , http .StatusBadRequest , codersdk.Response {
1067
+ Message : "Invalid client id" ,
1068
+ })
1069
+ return
1070
+ }
1071
+
1072
+ deviceCode := uuid .NewString ()
1073
+ lifetime := time .Second * 900
1074
+ flow := deviceFlow {
1075
+ //nolint:gosec
1076
+ userInput : fmt .Sprintf ("%d" , rand .Intn (9999999 )+ 1e8 ),
1077
+ }
1078
+ f .deviceCode .Store (deviceCode , deviceFlow {
1079
+ userInput : flow .userInput ,
1080
+ exp : time .Now ().Add (lifetime ),
1081
+ })
1082
+
1083
+ verifyURL := f .issuerURL .ResolveReference (& url.URL {
1084
+ Path : deviceVerify ,
1085
+ RawQuery : url.Values {
1086
+ "device_code" : {deviceCode },
1087
+ "user_input" : {flow .userInput },
1088
+ }.Encode (),
1089
+ }).String ()
1090
+
1091
+ if mediaType , _ , _ := mime .ParseMediaType (r .Header .Get ("Accept" )); mediaType == "application/json" {
1092
+ httpapi .Write (r .Context (), rw , http .StatusOK , map [string ]any {
1093
+ "device_code" : deviceCode ,
1094
+ "user_code" : flow .userInput ,
1095
+ "verification_uri" : verifyURL ,
1096
+ "expires_in" : int (lifetime .Seconds ()),
1097
+ "interval" : 3 ,
1098
+ })
1099
+ return
1100
+ }
1101
+
1102
+ // By default, GitHub form encodes these.
1103
+ _ , _ = fmt .Fprint (rw , url.Values {
1104
+ "device_code" : {deviceCode },
1105
+ "user_code" : {flow .userInput },
1106
+ "verification_uri" : {verifyURL },
1107
+ "expires_in" : {strconv .Itoa (int (lifetime .Seconds ()))},
1108
+ "interval" : {"3" },
1109
+ }.Encode ())
1110
+ }))
1111
+
889
1112
mux .NotFound (func (rw http.ResponseWriter , r * http.Request ) {
890
1113
f .logger .Error (r .Context (), "http call not found" , slog .F ("path" , r .URL .Path ))
891
1114
t .Errorf ("unexpected request to IDP at path %q. Not supported" , r .URL .Path )
@@ -987,6 +1210,8 @@ type ExternalAuthConfigOptions struct {
987
1210
// completely customize the response. It captures all routes under the /external-auth-validate/*
988
1211
// so the caller can do whatever they want and even add routes.
989
1212
routes map [string ]func (email string , rw http.ResponseWriter , r * http.Request )
1213
+
1214
+ UseDeviceAuth bool
990
1215
}
991
1216
992
1217
func (o * ExternalAuthConfigOptions ) AddRoute (route string , handle func (email string , rw http.ResponseWriter , r * http.Request )) * ExternalAuthConfigOptions {
@@ -1033,17 +1258,30 @@ func (f *FakeIDP) ExternalAuthConfig(t testing.TB, id string, custom *ExternalAu
1033
1258
}
1034
1259
}
1035
1260
instrumentF := promoauth .NewFactory (prometheus .NewRegistry ())
1261
+ oauthCfg := instrumentF .New (f .clientID , f .OIDCConfig (t , nil ))
1036
1262
cfg := & externalauth.Config {
1037
1263
DisplayName : id ,
1038
- InstrumentedOAuth2Config : instrumentF . New ( f . clientID , f . OIDCConfig ( t , nil )) ,
1264
+ InstrumentedOAuth2Config : oauthCfg ,
1039
1265
ID : id ,
1040
1266
// No defaults for these fields by omitting the type
1041
1267
Type : "" ,
1042
1268
DisplayIcon : f .WellknownConfig ().UserInfoURL ,
1043
1269
// Omit the /user for the validate so we can easily append to it when modifying
1044
1270
// the cfg for advanced tests.
1045
1271
ValidateURL : f .issuerURL .ResolveReference (& url.URL {Path : "/external-auth-validate/" }).String (),
1272
+ DeviceAuth : & externalauth.DeviceAuth {
1273
+ Config : oauthCfg ,
1274
+ ClientID : f .clientID ,
1275
+ TokenURL : f .provider .TokenURL ,
1276
+ Scopes : []string {},
1277
+ CodeURL : f .provider .DeviceCodeURL ,
1278
+ },
1279
+ }
1280
+
1281
+ if ! custom .UseDeviceAuth {
1282
+ cfg .DeviceAuth = nil
1046
1283
}
1284
+
1047
1285
for _ , opt := range opts {
1048
1286
opt (cfg )
1049
1287
}
0 commit comments