@@ -141,10 +141,14 @@ typedef struct
141
141
Port * port ;
142
142
bool channel_binding_in_use ;
143
143
144
+ /* State data depending on the hash type */
145
+ pg_cryptohash_type hash_type ;
146
+ int key_length ;
147
+
144
148
int iterations ;
145
149
char * salt ; /* base64-encoded */
146
- uint8 StoredKey [SCRAM_KEY_LEN ];
147
- uint8 ServerKey [SCRAM_KEY_LEN ];
150
+ uint8 StoredKey [SCRAM_MAX_KEY_LEN ];
151
+ uint8 ServerKey [SCRAM_MAX_KEY_LEN ];
148
152
149
153
/* Fields of the first message from client */
150
154
char cbind_flag ;
@@ -155,7 +159,7 @@ typedef struct
155
159
/* Fields from the last message from client */
156
160
char * client_final_message_without_proof ;
157
161
char * client_final_nonce ;
158
- char ClientProof [SCRAM_KEY_LEN ];
162
+ char ClientProof [SCRAM_MAX_KEY_LEN ];
159
163
160
164
/* Fields generated in the server */
161
165
char * server_first_message ;
@@ -177,12 +181,15 @@ static char *build_server_first_message(scram_state *state);
177
181
static char * build_server_final_message (scram_state * state );
178
182
static bool verify_client_proof (scram_state * state );
179
183
static bool verify_final_nonce (scram_state * state );
180
- static void mock_scram_secret (const char * username , int * iterations ,
181
- char * * salt , uint8 * stored_key , uint8 * server_key );
184
+ static void mock_scram_secret (const char * username , pg_cryptohash_type * hash_type ,
185
+ int * iterations , int * key_length , char * * salt ,
186
+ uint8 * stored_key , uint8 * server_key );
182
187
static bool is_scram_printable (char * p );
183
188
static char * sanitize_char (char c );
184
189
static char * sanitize_str (const char * s );
185
- static char * scram_mock_salt (const char * username );
190
+ static char * scram_mock_salt (const char * username ,
191
+ pg_cryptohash_type hash_type ,
192
+ int key_length );
186
193
187
194
/*
188
195
* Get a list of SASL mechanisms that this module supports.
@@ -266,8 +273,11 @@ scram_init(Port *port, const char *selected_mech, const char *shadow_pass)
266
273
267
274
if (password_type == PASSWORD_TYPE_SCRAM_SHA_256 )
268
275
{
269
- if (parse_scram_secret (shadow_pass , & state -> iterations , & state -> salt ,
270
- state -> StoredKey , state -> ServerKey ))
276
+ if (parse_scram_secret (shadow_pass , & state -> iterations ,
277
+ & state -> hash_type , & state -> key_length ,
278
+ & state -> salt ,
279
+ state -> StoredKey ,
280
+ state -> ServerKey ))
271
281
got_secret = true;
272
282
else
273
283
{
@@ -310,8 +320,10 @@ scram_init(Port *port, const char *selected_mech, const char *shadow_pass)
310
320
*/
311
321
if (!got_secret )
312
322
{
313
- mock_scram_secret (state -> port -> user_name , & state -> iterations ,
314
- & state -> salt , state -> StoredKey , state -> ServerKey );
323
+ mock_scram_secret (state -> port -> user_name , & state -> hash_type ,
324
+ & state -> iterations , & state -> key_length ,
325
+ & state -> salt ,
326
+ state -> StoredKey , state -> ServerKey );
315
327
state -> doomed = true;
316
328
}
317
329
@@ -482,7 +494,8 @@ pg_be_scram_build_secret(const char *password)
482
494
(errcode (ERRCODE_INTERNAL_ERROR ),
483
495
errmsg ("could not generate random salt" )));
484
496
485
- result = scram_build_secret (saltbuf , SCRAM_DEFAULT_SALT_LEN ,
497
+ result = scram_build_secret (PG_SHA256 , SCRAM_SHA_256_KEY_LEN ,
498
+ saltbuf , SCRAM_DEFAULT_SALT_LEN ,
486
499
SCRAM_DEFAULT_ITERATIONS , password ,
487
500
& errstr );
488
501
@@ -505,16 +518,18 @@ scram_verify_plain_password(const char *username, const char *password,
505
518
char * salt ;
506
519
int saltlen ;
507
520
int iterations ;
508
- uint8 salted_password [SCRAM_KEY_LEN ];
509
- uint8 stored_key [SCRAM_KEY_LEN ];
510
- uint8 server_key [SCRAM_KEY_LEN ];
511
- uint8 computed_key [SCRAM_KEY_LEN ];
521
+ int key_length = 0 ;
522
+ pg_cryptohash_type hash_type ;
523
+ uint8 salted_password [SCRAM_MAX_KEY_LEN ];
524
+ uint8 stored_key [SCRAM_MAX_KEY_LEN ];
525
+ uint8 server_key [SCRAM_MAX_KEY_LEN ];
526
+ uint8 computed_key [SCRAM_MAX_KEY_LEN ];
512
527
char * prep_password ;
513
528
pg_saslprep_rc rc ;
514
529
const char * errstr = NULL ;
515
530
516
- if (!parse_scram_secret (secret , & iterations , & encoded_salt ,
517
- stored_key , server_key ))
531
+ if (!parse_scram_secret (secret , & iterations , & hash_type , & key_length ,
532
+ & encoded_salt , stored_key , server_key ))
518
533
{
519
534
/*
520
535
* The password looked like a SCRAM secret, but could not be parsed.
@@ -541,9 +556,11 @@ scram_verify_plain_password(const char *username, const char *password,
541
556
password = prep_password ;
542
557
543
558
/* Compute Server Key based on the user-supplied plaintext password */
544
- if (scram_SaltedPassword (password , salt , saltlen , iterations ,
559
+ if (scram_SaltedPassword (password , hash_type , key_length ,
560
+ salt , saltlen , iterations ,
545
561
salted_password , & errstr ) < 0 ||
546
- scram_ServerKey (salted_password , computed_key , & errstr ) < 0 )
562
+ scram_ServerKey (salted_password , hash_type , key_length ,
563
+ computed_key , & errstr ) < 0 )
547
564
{
548
565
elog (ERROR , "could not compute server key: %s" , errstr );
549
566
}
@@ -555,7 +572,7 @@ scram_verify_plain_password(const char *username, const char *password,
555
572
* Compare the secret's Server Key with the one computed from the
556
573
* user-supplied password.
557
574
*/
558
- return memcmp (computed_key , server_key , SCRAM_KEY_LEN ) == 0 ;
575
+ return memcmp (computed_key , server_key , key_length ) == 0 ;
559
576
}
560
577
561
578
@@ -565,14 +582,15 @@ scram_verify_plain_password(const char *username, const char *password,
565
582
* On success, the iteration count, salt, stored key, and server key are
566
583
* extracted from the secret, and returned to the caller. For 'stored_key'
567
584
* and 'server_key', the caller must pass pre-allocated buffers of size
568
- * SCRAM_KEY_LEN . Salt is returned as a base64-encoded, null-terminated
585
+ * SCRAM_MAX_KEY_LEN . Salt is returned as a base64-encoded, null-terminated
569
586
* string. The buffer for the salt is palloc'd by this function.
570
587
*
571
588
* Returns true if the SCRAM secret has been parsed, and false otherwise.
572
589
*/
573
590
bool
574
- parse_scram_secret (const char * secret , int * iterations , char * * salt ,
575
- uint8 * stored_key , uint8 * server_key )
591
+ parse_scram_secret (const char * secret , int * iterations ,
592
+ pg_cryptohash_type * hash_type , int * key_length ,
593
+ char * * salt , uint8 * stored_key , uint8 * server_key )
576
594
{
577
595
char * v ;
578
596
char * p ;
@@ -606,6 +624,8 @@ parse_scram_secret(const char *secret, int *iterations, char **salt,
606
624
/* Parse the fields */
607
625
if (strcmp (scheme_str , "SCRAM-SHA-256" ) != 0 )
608
626
goto invalid_secret ;
627
+ * hash_type = PG_SHA256 ;
628
+ * key_length = SCRAM_SHA_256_KEY_LEN ;
609
629
610
630
errno = 0 ;
611
631
* iterations = strtol (iterations_str , & p , 10 );
@@ -631,17 +651,17 @@ parse_scram_secret(const char *secret, int *iterations, char **salt,
631
651
decoded_stored_buf = palloc (decoded_len );
632
652
decoded_len = pg_b64_decode (storedkey_str , strlen (storedkey_str ),
633
653
decoded_stored_buf , decoded_len );
634
- if (decoded_len != SCRAM_KEY_LEN )
654
+ if (decoded_len != * key_length )
635
655
goto invalid_secret ;
636
- memcpy (stored_key , decoded_stored_buf , SCRAM_KEY_LEN );
656
+ memcpy (stored_key , decoded_stored_buf , * key_length );
637
657
638
658
decoded_len = pg_b64_dec_len (strlen (serverkey_str ));
639
659
decoded_server_buf = palloc (decoded_len );
640
660
decoded_len = pg_b64_decode (serverkey_str , strlen (serverkey_str ),
641
661
decoded_server_buf , decoded_len );
642
- if (decoded_len != SCRAM_KEY_LEN )
662
+ if (decoded_len != * key_length )
643
663
goto invalid_secret ;
644
- memcpy (server_key , decoded_server_buf , SCRAM_KEY_LEN );
664
+ memcpy (server_key , decoded_server_buf , * key_length );
645
665
646
666
return true;
647
667
@@ -655,20 +675,25 @@ parse_scram_secret(const char *secret, int *iterations, char **salt,
655
675
*
656
676
* In a normal authentication, these are extracted from the secret
657
677
* stored in the server. This function generates values that look
658
- * realistic, for when there is no stored secret.
678
+ * realistic, for when there is no stored secret, using SCRAM-SHA-256 .
659
679
*
660
680
* Like in parse_scram_secret(), for 'stored_key' and 'server_key', the
661
- * caller must pass pre-allocated buffers of size SCRAM_KEY_LEN , and
681
+ * caller must pass pre-allocated buffers of size SCRAM_MAX_KEY_LEN , and
662
682
* the buffer for the salt is palloc'd by this function.
663
683
*/
664
684
static void
665
- mock_scram_secret (const char * username , int * iterations , char * * salt ,
685
+ mock_scram_secret (const char * username , pg_cryptohash_type * hash_type ,
686
+ int * iterations , int * key_length , char * * salt ,
666
687
uint8 * stored_key , uint8 * server_key )
667
688
{
668
689
char * raw_salt ;
669
690
char * encoded_salt ;
670
691
int encoded_len ;
671
692
693
+ /* Enforce the use of SHA-256, which would be realistic enough */
694
+ * hash_type = PG_SHA256 ;
695
+ * key_length = SCRAM_SHA_256_KEY_LEN ;
696
+
672
697
/*
673
698
* Generate deterministic salt.
674
699
*
@@ -677,7 +702,7 @@ mock_scram_secret(const char *username, int *iterations, char **salt,
677
702
* as the salt generated for mock authentication uses the cluster's nonce
678
703
* value.
679
704
*/
680
- raw_salt = scram_mock_salt (username );
705
+ raw_salt = scram_mock_salt (username , * hash_type , * key_length );
681
706
if (raw_salt == NULL )
682
707
elog (ERROR , "could not encode salt" );
683
708
@@ -695,8 +720,8 @@ mock_scram_secret(const char *username, int *iterations, char **salt,
695
720
* iterations = SCRAM_DEFAULT_ITERATIONS ;
696
721
697
722
/* StoredKey and ServerKey are not used in a doomed authentication */
698
- memset (stored_key , 0 , SCRAM_KEY_LEN );
699
- memset (server_key , 0 , SCRAM_KEY_LEN );
723
+ memset (stored_key , 0 , SCRAM_MAX_KEY_LEN );
724
+ memset (server_key , 0 , SCRAM_MAX_KEY_LEN );
700
725
}
701
726
702
727
/*
@@ -1111,10 +1136,10 @@ verify_final_nonce(scram_state *state)
1111
1136
static bool
1112
1137
verify_client_proof (scram_state * state )
1113
1138
{
1114
- uint8 ClientSignature [SCRAM_KEY_LEN ];
1115
- uint8 ClientKey [SCRAM_KEY_LEN ];
1116
- uint8 client_StoredKey [SCRAM_KEY_LEN ];
1117
- pg_hmac_ctx * ctx = pg_hmac_create (PG_SHA256 );
1139
+ uint8 ClientSignature [SCRAM_MAX_KEY_LEN ];
1140
+ uint8 ClientKey [SCRAM_MAX_KEY_LEN ];
1141
+ uint8 client_StoredKey [SCRAM_MAX_KEY_LEN ];
1142
+ pg_hmac_ctx * ctx = pg_hmac_create (state -> hash_type );
1118
1143
int i ;
1119
1144
const char * errstr = NULL ;
1120
1145
@@ -1123,7 +1148,7 @@ verify_client_proof(scram_state *state)
1123
1148
* here even when processing the calculations as this could involve a mock
1124
1149
* authentication.
1125
1150
*/
1126
- if (pg_hmac_init (ctx , state -> StoredKey , SCRAM_KEY_LEN ) < 0 ||
1151
+ if (pg_hmac_init (ctx , state -> StoredKey , state -> key_length ) < 0 ||
1127
1152
pg_hmac_update (ctx ,
1128
1153
(uint8 * ) state -> client_first_message_bare ,
1129
1154
strlen (state -> client_first_message_bare )) < 0 ||
@@ -1135,7 +1160,7 @@ verify_client_proof(scram_state *state)
1135
1160
pg_hmac_update (ctx ,
1136
1161
(uint8 * ) state -> client_final_message_without_proof ,
1137
1162
strlen (state -> client_final_message_without_proof )) < 0 ||
1138
- pg_hmac_final (ctx , ClientSignature , sizeof ( ClientSignature ) ) < 0 )
1163
+ pg_hmac_final (ctx , ClientSignature , state -> key_length ) < 0 )
1139
1164
{
1140
1165
elog (ERROR , "could not calculate client signature: %s" ,
1141
1166
pg_hmac_error (ctx ));
@@ -1144,14 +1169,15 @@ verify_client_proof(scram_state *state)
1144
1169
pg_hmac_free (ctx );
1145
1170
1146
1171
/* Extract the ClientKey that the client calculated from the proof */
1147
- for (i = 0 ; i < SCRAM_KEY_LEN ; i ++ )
1172
+ for (i = 0 ; i < state -> key_length ; i ++ )
1148
1173
ClientKey [i ] = state -> ClientProof [i ] ^ ClientSignature [i ];
1149
1174
1150
1175
/* Hash it one more time, and compare with StoredKey */
1151
- if (scram_H (ClientKey , SCRAM_KEY_LEN , client_StoredKey , & errstr ) < 0 )
1176
+ if (scram_H (ClientKey , state -> hash_type , state -> key_length ,
1177
+ client_StoredKey , & errstr ) < 0 )
1152
1178
elog (ERROR , "could not hash stored key: %s" , errstr );
1153
1179
1154
- if (memcmp (client_StoredKey , state -> StoredKey , SCRAM_KEY_LEN ) != 0 )
1180
+ if (memcmp (client_StoredKey , state -> StoredKey , state -> key_length ) != 0 )
1155
1181
return false;
1156
1182
1157
1183
return true;
@@ -1349,12 +1375,12 @@ read_client_final_message(scram_state *state, const char *input)
1349
1375
client_proof_len = pg_b64_dec_len (strlen (value ));
1350
1376
client_proof = palloc (client_proof_len );
1351
1377
if (pg_b64_decode (value , strlen (value ), client_proof ,
1352
- client_proof_len ) != SCRAM_KEY_LEN )
1378
+ client_proof_len ) != state -> key_length )
1353
1379
ereport (ERROR ,
1354
1380
(errcode (ERRCODE_PROTOCOL_VIOLATION ),
1355
1381
errmsg ("malformed SCRAM message" ),
1356
1382
errdetail ("Malformed proof in client-final-message." )));
1357
- memcpy (state -> ClientProof , client_proof , SCRAM_KEY_LEN );
1383
+ memcpy (state -> ClientProof , client_proof , state -> key_length );
1358
1384
pfree (client_proof );
1359
1385
1360
1386
if (* p != '\0' )
@@ -1374,13 +1400,13 @@ read_client_final_message(scram_state *state, const char *input)
1374
1400
static char *
1375
1401
build_server_final_message (scram_state * state )
1376
1402
{
1377
- uint8 ServerSignature [SCRAM_KEY_LEN ];
1403
+ uint8 ServerSignature [SCRAM_MAX_KEY_LEN ];
1378
1404
char * server_signature_base64 ;
1379
1405
int siglen ;
1380
- pg_hmac_ctx * ctx = pg_hmac_create (PG_SHA256 );
1406
+ pg_hmac_ctx * ctx = pg_hmac_create (state -> hash_type );
1381
1407
1382
1408
/* calculate ServerSignature */
1383
- if (pg_hmac_init (ctx , state -> ServerKey , SCRAM_KEY_LEN ) < 0 ||
1409
+ if (pg_hmac_init (ctx , state -> ServerKey , state -> key_length ) < 0 ||
1384
1410
pg_hmac_update (ctx ,
1385
1411
(uint8 * ) state -> client_first_message_bare ,
1386
1412
strlen (state -> client_first_message_bare )) < 0 ||
@@ -1392,19 +1418,19 @@ build_server_final_message(scram_state *state)
1392
1418
pg_hmac_update (ctx ,
1393
1419
(uint8 * ) state -> client_final_message_without_proof ,
1394
1420
strlen (state -> client_final_message_without_proof )) < 0 ||
1395
- pg_hmac_final (ctx , ServerSignature , sizeof ( ServerSignature ) ) < 0 )
1421
+ pg_hmac_final (ctx , ServerSignature , state -> key_length ) < 0 )
1396
1422
{
1397
1423
elog (ERROR , "could not calculate server signature: %s" ,
1398
1424
pg_hmac_error (ctx ));
1399
1425
}
1400
1426
1401
1427
pg_hmac_free (ctx );
1402
1428
1403
- siglen = pg_b64_enc_len (SCRAM_KEY_LEN );
1429
+ siglen = pg_b64_enc_len (state -> key_length );
1404
1430
/* don't forget the zero-terminator */
1405
1431
server_signature_base64 = palloc (siglen + 1 );
1406
1432
siglen = pg_b64_encode ((const char * ) ServerSignature ,
1407
- SCRAM_KEY_LEN , server_signature_base64 ,
1433
+ state -> key_length , server_signature_base64 ,
1408
1434
siglen );
1409
1435
if (siglen < 0 )
1410
1436
elog (ERROR , "could not encode server signature" );
@@ -1431,10 +1457,11 @@ build_server_final_message(scram_state *state)
1431
1457
* pointer to a static buffer of size SCRAM_DEFAULT_SALT_LEN, or NULL.
1432
1458
*/
1433
1459
static char *
1434
- scram_mock_salt (const char * username )
1460
+ scram_mock_salt (const char * username , pg_cryptohash_type hash_type ,
1461
+ int key_length )
1435
1462
{
1436
1463
pg_cryptohash_ctx * ctx ;
1437
- static uint8 sha_digest [PG_SHA256_DIGEST_LENGTH ];
1464
+ static uint8 sha_digest [SCRAM_MAX_KEY_LEN ];
1438
1465
char * mock_auth_nonce = GetMockAuthenticationNonce ();
1439
1466
1440
1467
/*
@@ -1446,11 +1473,17 @@ scram_mock_salt(const char *username)
1446
1473
StaticAssertDecl (PG_SHA256_DIGEST_LENGTH >= SCRAM_DEFAULT_SALT_LEN ,
1447
1474
"salt length greater than SHA256 digest length" );
1448
1475
1449
- ctx = pg_cryptohash_create (PG_SHA256 );
1476
+ /*
1477
+ * This may be worth refreshing if support for more hash methods is\
1478
+ * added.
1479
+ */
1480
+ Assert (hash_type == PG_SHA256 );
1481
+
1482
+ ctx = pg_cryptohash_create (hash_type );
1450
1483
if (pg_cryptohash_init (ctx ) < 0 ||
1451
1484
pg_cryptohash_update (ctx , (uint8 * ) username , strlen (username )) < 0 ||
1452
1485
pg_cryptohash_update (ctx , (uint8 * ) mock_auth_nonce , MOCK_AUTH_NONCE_LEN ) < 0 ||
1453
- pg_cryptohash_final (ctx , sha_digest , sizeof ( sha_digest ) ) < 0 )
1486
+ pg_cryptohash_final (ctx , sha_digest , key_length ) < 0 )
1454
1487
{
1455
1488
pg_cryptohash_free (ctx );
1456
1489
return NULL ;
0 commit comments