Skip to content

Commit b3bb7d1

Browse files
committed
Remove hardcoded dependency to cryptohash type in the internals of SCRAM
SCRAM_KEY_LEN was a variable used in the internal routines of SCRAM to size a set of fixed-sized arrays used in the SHA and HMAC computations during the SASL exchange or when building a SCRAM password. This had a hard dependency on SHA-256, reducing the flexibility of SCRAM when it comes to the addition of more hash methods. A second issue was that SHA-256 is assumed as the cryptohash method to use all the time. This commit renames SCRAM_KEY_LEN to a more generic SCRAM_KEY_MAX_LEN, which is used as the size of the buffers used by the internal routines of SCRAM. This is aimed at tracking centrally the maximum size necessary for all the hash methods supported by SCRAM. A global variable has the advantage of keeping the code in its simplest form, reducing the need of more alloc/free logic for all the buffers used in the hash calculations. A second change is that the key length (SHA digest length) and hash types are now tracked by the state data in the backend and the frontend, the common portions being extended to handle these as arguments by the internal routines of SCRAM. There are a few RFC proposals floating around to extend the SCRAM protocol, including some to use stronger cryptohash algorithms, so this lifts some of the existing restrictions in the code. The code in charge of parsing and building SCRAM secrets is extended to rely on the key length and on the cryptohash type used for the exchange, assuming currently that only SHA-256 is supported for the moment. Note that the mock authentication simply enforces SHA-256. Author: Michael Paquier Reviewed-by: Peter Eisentraut, Jonathan Katz Discussion: https://postgr.es/m/Y5k3Qiweo/1g9CG6@paquier.xyz
1 parent eb60eb0 commit b3bb7d1

File tree

6 files changed

+206
-131
lines changed

6 files changed

+206
-131
lines changed

src/backend/libpq/auth-scram.c

Lines changed: 87 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -141,10 +141,14 @@ typedef struct
141141
Port *port;
142142
bool channel_binding_in_use;
143143

144+
/* State data depending on the hash type */
145+
pg_cryptohash_type hash_type;
146+
int key_length;
147+
144148
int iterations;
145149
char *salt; /* base64-encoded */
146-
uint8 StoredKey[SCRAM_KEY_LEN];
147-
uint8 ServerKey[SCRAM_KEY_LEN];
150+
uint8 StoredKey[SCRAM_MAX_KEY_LEN];
151+
uint8 ServerKey[SCRAM_MAX_KEY_LEN];
148152

149153
/* Fields of the first message from client */
150154
char cbind_flag;
@@ -155,7 +159,7 @@ typedef struct
155159
/* Fields from the last message from client */
156160
char *client_final_message_without_proof;
157161
char *client_final_nonce;
158-
char ClientProof[SCRAM_KEY_LEN];
162+
char ClientProof[SCRAM_MAX_KEY_LEN];
159163

160164
/* Fields generated in the server */
161165
char *server_first_message;
@@ -177,12 +181,15 @@ static char *build_server_first_message(scram_state *state);
177181
static char *build_server_final_message(scram_state *state);
178182
static bool verify_client_proof(scram_state *state);
179183
static bool verify_final_nonce(scram_state *state);
180-
static void mock_scram_secret(const char *username, int *iterations,
181-
char **salt, uint8 *stored_key, uint8 *server_key);
184+
static void mock_scram_secret(const char *username, pg_cryptohash_type *hash_type,
185+
int *iterations, int *key_length, char **salt,
186+
uint8 *stored_key, uint8 *server_key);
182187
static bool is_scram_printable(char *p);
183188
static char *sanitize_char(char c);
184189
static char *sanitize_str(const char *s);
185-
static char *scram_mock_salt(const char *username);
190+
static char *scram_mock_salt(const char *username,
191+
pg_cryptohash_type hash_type,
192+
int key_length);
186193

187194
/*
188195
* Get a list of SASL mechanisms that this module supports.
@@ -266,8 +273,11 @@ scram_init(Port *port, const char *selected_mech, const char *shadow_pass)
266273

267274
if (password_type == PASSWORD_TYPE_SCRAM_SHA_256)
268275
{
269-
if (parse_scram_secret(shadow_pass, &state->iterations, &state->salt,
270-
state->StoredKey, state->ServerKey))
276+
if (parse_scram_secret(shadow_pass, &state->iterations,
277+
&state->hash_type, &state->key_length,
278+
&state->salt,
279+
state->StoredKey,
280+
state->ServerKey))
271281
got_secret = true;
272282
else
273283
{
@@ -310,8 +320,10 @@ scram_init(Port *port, const char *selected_mech, const char *shadow_pass)
310320
*/
311321
if (!got_secret)
312322
{
313-
mock_scram_secret(state->port->user_name, &state->iterations,
314-
&state->salt, state->StoredKey, state->ServerKey);
323+
mock_scram_secret(state->port->user_name, &state->hash_type,
324+
&state->iterations, &state->key_length,
325+
&state->salt,
326+
state->StoredKey, state->ServerKey);
315327
state->doomed = true;
316328
}
317329

@@ -482,7 +494,8 @@ pg_be_scram_build_secret(const char *password)
482494
(errcode(ERRCODE_INTERNAL_ERROR),
483495
errmsg("could not generate random salt")));
484496

485-
result = scram_build_secret(saltbuf, SCRAM_DEFAULT_SALT_LEN,
497+
result = scram_build_secret(PG_SHA256, SCRAM_SHA_256_KEY_LEN,
498+
saltbuf, SCRAM_DEFAULT_SALT_LEN,
486499
SCRAM_DEFAULT_ITERATIONS, password,
487500
&errstr);
488501

@@ -505,16 +518,18 @@ scram_verify_plain_password(const char *username, const char *password,
505518
char *salt;
506519
int saltlen;
507520
int iterations;
508-
uint8 salted_password[SCRAM_KEY_LEN];
509-
uint8 stored_key[SCRAM_KEY_LEN];
510-
uint8 server_key[SCRAM_KEY_LEN];
511-
uint8 computed_key[SCRAM_KEY_LEN];
521+
int key_length = 0;
522+
pg_cryptohash_type hash_type;
523+
uint8 salted_password[SCRAM_MAX_KEY_LEN];
524+
uint8 stored_key[SCRAM_MAX_KEY_LEN];
525+
uint8 server_key[SCRAM_MAX_KEY_LEN];
526+
uint8 computed_key[SCRAM_MAX_KEY_LEN];
512527
char *prep_password;
513528
pg_saslprep_rc rc;
514529
const char *errstr = NULL;
515530

516-
if (!parse_scram_secret(secret, &iterations, &encoded_salt,
517-
stored_key, server_key))
531+
if (!parse_scram_secret(secret, &iterations, &hash_type, &key_length,
532+
&encoded_salt, stored_key, server_key))
518533
{
519534
/*
520535
* The password looked like a SCRAM secret, but could not be parsed.
@@ -541,9 +556,11 @@ scram_verify_plain_password(const char *username, const char *password,
541556
password = prep_password;
542557

543558
/* Compute Server Key based on the user-supplied plaintext password */
544-
if (scram_SaltedPassword(password, salt, saltlen, iterations,
559+
if (scram_SaltedPassword(password, hash_type, key_length,
560+
salt, saltlen, iterations,
545561
salted_password, &errstr) < 0 ||
546-
scram_ServerKey(salted_password, computed_key, &errstr) < 0)
562+
scram_ServerKey(salted_password, hash_type, key_length,
563+
computed_key, &errstr) < 0)
547564
{
548565
elog(ERROR, "could not compute server key: %s", errstr);
549566
}
@@ -555,7 +572,7 @@ scram_verify_plain_password(const char *username, const char *password,
555572
* Compare the secret's Server Key with the one computed from the
556573
* user-supplied password.
557574
*/
558-
return memcmp(computed_key, server_key, SCRAM_KEY_LEN) == 0;
575+
return memcmp(computed_key, server_key, key_length) == 0;
559576
}
560577

561578

@@ -565,14 +582,15 @@ scram_verify_plain_password(const char *username, const char *password,
565582
* On success, the iteration count, salt, stored key, and server key are
566583
* extracted from the secret, and returned to the caller. For 'stored_key'
567584
* and 'server_key', the caller must pass pre-allocated buffers of size
568-
* SCRAM_KEY_LEN. Salt is returned as a base64-encoded, null-terminated
585+
* SCRAM_MAX_KEY_LEN. Salt is returned as a base64-encoded, null-terminated
569586
* string. The buffer for the salt is palloc'd by this function.
570587
*
571588
* Returns true if the SCRAM secret has been parsed, and false otherwise.
572589
*/
573590
bool
574-
parse_scram_secret(const char *secret, int *iterations, char **salt,
575-
uint8 *stored_key, uint8 *server_key)
591+
parse_scram_secret(const char *secret, int *iterations,
592+
pg_cryptohash_type *hash_type, int *key_length,
593+
char **salt, uint8 *stored_key, uint8 *server_key)
576594
{
577595
char *v;
578596
char *p;
@@ -606,6 +624,8 @@ parse_scram_secret(const char *secret, int *iterations, char **salt,
606624
/* Parse the fields */
607625
if (strcmp(scheme_str, "SCRAM-SHA-256") != 0)
608626
goto invalid_secret;
627+
*hash_type = PG_SHA256;
628+
*key_length = SCRAM_SHA_256_KEY_LEN;
609629

610630
errno = 0;
611631
*iterations = strtol(iterations_str, &p, 10);
@@ -631,17 +651,17 @@ parse_scram_secret(const char *secret, int *iterations, char **salt,
631651
decoded_stored_buf = palloc(decoded_len);
632652
decoded_len = pg_b64_decode(storedkey_str, strlen(storedkey_str),
633653
decoded_stored_buf, decoded_len);
634-
if (decoded_len != SCRAM_KEY_LEN)
654+
if (decoded_len != *key_length)
635655
goto invalid_secret;
636-
memcpy(stored_key, decoded_stored_buf, SCRAM_KEY_LEN);
656+
memcpy(stored_key, decoded_stored_buf, *key_length);
637657

638658
decoded_len = pg_b64_dec_len(strlen(serverkey_str));
639659
decoded_server_buf = palloc(decoded_len);
640660
decoded_len = pg_b64_decode(serverkey_str, strlen(serverkey_str),
641661
decoded_server_buf, decoded_len);
642-
if (decoded_len != SCRAM_KEY_LEN)
662+
if (decoded_len != *key_length)
643663
goto invalid_secret;
644-
memcpy(server_key, decoded_server_buf, SCRAM_KEY_LEN);
664+
memcpy(server_key, decoded_server_buf, *key_length);
645665

646666
return true;
647667

@@ -655,20 +675,25 @@ parse_scram_secret(const char *secret, int *iterations, char **salt,
655675
*
656676
* In a normal authentication, these are extracted from the secret
657677
* stored in the server. This function generates values that look
658-
* realistic, for when there is no stored secret.
678+
* realistic, for when there is no stored secret, using SCRAM-SHA-256.
659679
*
660680
* Like in parse_scram_secret(), for 'stored_key' and 'server_key', the
661-
* caller must pass pre-allocated buffers of size SCRAM_KEY_LEN, and
681+
* caller must pass pre-allocated buffers of size SCRAM_MAX_KEY_LEN, and
662682
* the buffer for the salt is palloc'd by this function.
663683
*/
664684
static void
665-
mock_scram_secret(const char *username, int *iterations, char **salt,
685+
mock_scram_secret(const char *username, pg_cryptohash_type *hash_type,
686+
int *iterations, int *key_length, char **salt,
666687
uint8 *stored_key, uint8 *server_key)
667688
{
668689
char *raw_salt;
669690
char *encoded_salt;
670691
int encoded_len;
671692

693+
/* Enforce the use of SHA-256, which would be realistic enough */
694+
*hash_type = PG_SHA256;
695+
*key_length = SCRAM_SHA_256_KEY_LEN;
696+
672697
/*
673698
* Generate deterministic salt.
674699
*
@@ -677,7 +702,7 @@ mock_scram_secret(const char *username, int *iterations, char **salt,
677702
* as the salt generated for mock authentication uses the cluster's nonce
678703
* value.
679704
*/
680-
raw_salt = scram_mock_salt(username);
705+
raw_salt = scram_mock_salt(username, *hash_type, *key_length);
681706
if (raw_salt == NULL)
682707
elog(ERROR, "could not encode salt");
683708

@@ -695,8 +720,8 @@ mock_scram_secret(const char *username, int *iterations, char **salt,
695720
*iterations = SCRAM_DEFAULT_ITERATIONS;
696721

697722
/* StoredKey and ServerKey are not used in a doomed authentication */
698-
memset(stored_key, 0, SCRAM_KEY_LEN);
699-
memset(server_key, 0, SCRAM_KEY_LEN);
723+
memset(stored_key, 0, SCRAM_MAX_KEY_LEN);
724+
memset(server_key, 0, SCRAM_MAX_KEY_LEN);
700725
}
701726

702727
/*
@@ -1111,10 +1136,10 @@ verify_final_nonce(scram_state *state)
11111136
static bool
11121137
verify_client_proof(scram_state *state)
11131138
{
1114-
uint8 ClientSignature[SCRAM_KEY_LEN];
1115-
uint8 ClientKey[SCRAM_KEY_LEN];
1116-
uint8 client_StoredKey[SCRAM_KEY_LEN];
1117-
pg_hmac_ctx *ctx = pg_hmac_create(PG_SHA256);
1139+
uint8 ClientSignature[SCRAM_MAX_KEY_LEN];
1140+
uint8 ClientKey[SCRAM_MAX_KEY_LEN];
1141+
uint8 client_StoredKey[SCRAM_MAX_KEY_LEN];
1142+
pg_hmac_ctx *ctx = pg_hmac_create(state->hash_type);
11181143
int i;
11191144
const char *errstr = NULL;
11201145

@@ -1123,7 +1148,7 @@ verify_client_proof(scram_state *state)
11231148
* here even when processing the calculations as this could involve a mock
11241149
* authentication.
11251150
*/
1126-
if (pg_hmac_init(ctx, state->StoredKey, SCRAM_KEY_LEN) < 0 ||
1151+
if (pg_hmac_init(ctx, state->StoredKey, state->key_length) < 0 ||
11271152
pg_hmac_update(ctx,
11281153
(uint8 *) state->client_first_message_bare,
11291154
strlen(state->client_first_message_bare)) < 0 ||
@@ -1135,7 +1160,7 @@ verify_client_proof(scram_state *state)
11351160
pg_hmac_update(ctx,
11361161
(uint8 *) state->client_final_message_without_proof,
11371162
strlen(state->client_final_message_without_proof)) < 0 ||
1138-
pg_hmac_final(ctx, ClientSignature, sizeof(ClientSignature)) < 0)
1163+
pg_hmac_final(ctx, ClientSignature, state->key_length) < 0)
11391164
{
11401165
elog(ERROR, "could not calculate client signature: %s",
11411166
pg_hmac_error(ctx));
@@ -1144,14 +1169,15 @@ verify_client_proof(scram_state *state)
11441169
pg_hmac_free(ctx);
11451170

11461171
/* Extract the ClientKey that the client calculated from the proof */
1147-
for (i = 0; i < SCRAM_KEY_LEN; i++)
1172+
for (i = 0; i < state->key_length; i++)
11481173
ClientKey[i] = state->ClientProof[i] ^ ClientSignature[i];
11491174

11501175
/* Hash it one more time, and compare with StoredKey */
1151-
if (scram_H(ClientKey, SCRAM_KEY_LEN, client_StoredKey, &errstr) < 0)
1176+
if (scram_H(ClientKey, state->hash_type, state->key_length,
1177+
client_StoredKey, &errstr) < 0)
11521178
elog(ERROR, "could not hash stored key: %s", errstr);
11531179

1154-
if (memcmp(client_StoredKey, state->StoredKey, SCRAM_KEY_LEN) != 0)
1180+
if (memcmp(client_StoredKey, state->StoredKey, state->key_length) != 0)
11551181
return false;
11561182

11571183
return true;
@@ -1349,12 +1375,12 @@ read_client_final_message(scram_state *state, const char *input)
13491375
client_proof_len = pg_b64_dec_len(strlen(value));
13501376
client_proof = palloc(client_proof_len);
13511377
if (pg_b64_decode(value, strlen(value), client_proof,
1352-
client_proof_len) != SCRAM_KEY_LEN)
1378+
client_proof_len) != state->key_length)
13531379
ereport(ERROR,
13541380
(errcode(ERRCODE_PROTOCOL_VIOLATION),
13551381
errmsg("malformed SCRAM message"),
13561382
errdetail("Malformed proof in client-final-message.")));
1357-
memcpy(state->ClientProof, client_proof, SCRAM_KEY_LEN);
1383+
memcpy(state->ClientProof, client_proof, state->key_length);
13581384
pfree(client_proof);
13591385

13601386
if (*p != '\0')
@@ -1374,13 +1400,13 @@ read_client_final_message(scram_state *state, const char *input)
13741400
static char *
13751401
build_server_final_message(scram_state *state)
13761402
{
1377-
uint8 ServerSignature[SCRAM_KEY_LEN];
1403+
uint8 ServerSignature[SCRAM_MAX_KEY_LEN];
13781404
char *server_signature_base64;
13791405
int siglen;
1380-
pg_hmac_ctx *ctx = pg_hmac_create(PG_SHA256);
1406+
pg_hmac_ctx *ctx = pg_hmac_create(state->hash_type);
13811407

13821408
/* calculate ServerSignature */
1383-
if (pg_hmac_init(ctx, state->ServerKey, SCRAM_KEY_LEN) < 0 ||
1409+
if (pg_hmac_init(ctx, state->ServerKey, state->key_length) < 0 ||
13841410
pg_hmac_update(ctx,
13851411
(uint8 *) state->client_first_message_bare,
13861412
strlen(state->client_first_message_bare)) < 0 ||
@@ -1392,19 +1418,19 @@ build_server_final_message(scram_state *state)
13921418
pg_hmac_update(ctx,
13931419
(uint8 *) state->client_final_message_without_proof,
13941420
strlen(state->client_final_message_without_proof)) < 0 ||
1395-
pg_hmac_final(ctx, ServerSignature, sizeof(ServerSignature)) < 0)
1421+
pg_hmac_final(ctx, ServerSignature, state->key_length) < 0)
13961422
{
13971423
elog(ERROR, "could not calculate server signature: %s",
13981424
pg_hmac_error(ctx));
13991425
}
14001426

14011427
pg_hmac_free(ctx);
14021428

1403-
siglen = pg_b64_enc_len(SCRAM_KEY_LEN);
1429+
siglen = pg_b64_enc_len(state->key_length);
14041430
/* don't forget the zero-terminator */
14051431
server_signature_base64 = palloc(siglen + 1);
14061432
siglen = pg_b64_encode((const char *) ServerSignature,
1407-
SCRAM_KEY_LEN, server_signature_base64,
1433+
state->key_length, server_signature_base64,
14081434
siglen);
14091435
if (siglen < 0)
14101436
elog(ERROR, "could not encode server signature");
@@ -1431,10 +1457,11 @@ build_server_final_message(scram_state *state)
14311457
* pointer to a static buffer of size SCRAM_DEFAULT_SALT_LEN, or NULL.
14321458
*/
14331459
static char *
1434-
scram_mock_salt(const char *username)
1460+
scram_mock_salt(const char *username, pg_cryptohash_type hash_type,
1461+
int key_length)
14351462
{
14361463
pg_cryptohash_ctx *ctx;
1437-
static uint8 sha_digest[PG_SHA256_DIGEST_LENGTH];
1464+
static uint8 sha_digest[SCRAM_MAX_KEY_LEN];
14381465
char *mock_auth_nonce = GetMockAuthenticationNonce();
14391466

14401467
/*
@@ -1446,11 +1473,17 @@ scram_mock_salt(const char *username)
14461473
StaticAssertDecl(PG_SHA256_DIGEST_LENGTH >= SCRAM_DEFAULT_SALT_LEN,
14471474
"salt length greater than SHA256 digest length");
14481475

1449-
ctx = pg_cryptohash_create(PG_SHA256);
1476+
/*
1477+
* This may be worth refreshing if support for more hash methods is\
1478+
* added.
1479+
*/
1480+
Assert(hash_type == PG_SHA256);
1481+
1482+
ctx = pg_cryptohash_create(hash_type);
14501483
if (pg_cryptohash_init(ctx) < 0 ||
14511484
pg_cryptohash_update(ctx, (uint8 *) username, strlen(username)) < 0 ||
14521485
pg_cryptohash_update(ctx, (uint8 *) mock_auth_nonce, MOCK_AUTH_NONCE_LEN) < 0 ||
1453-
pg_cryptohash_final(ctx, sha_digest, sizeof(sha_digest)) < 0)
1486+
pg_cryptohash_final(ctx, sha_digest, key_length) < 0)
14541487
{
14551488
pg_cryptohash_free(ctx);
14561489
return NULL;

src/backend/libpq/crypt.c

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,15 +90,17 @@ get_password_type(const char *shadow_pass)
9090
{
9191
char *encoded_salt;
9292
int iterations;
93-
uint8 stored_key[SCRAM_KEY_LEN];
94-
uint8 server_key[SCRAM_KEY_LEN];
93+
int key_length = 0;
94+
pg_cryptohash_type hash_type;
95+
uint8 stored_key[SCRAM_MAX_KEY_LEN];
96+
uint8 server_key[SCRAM_MAX_KEY_LEN];
9597

9698
if (strncmp(shadow_pass, "md5", 3) == 0 &&
9799
strlen(shadow_pass) == MD5_PASSWD_LEN &&
98100
strspn(shadow_pass + 3, MD5_PASSWD_CHARSET) == MD5_PASSWD_LEN - 3)
99101
return PASSWORD_TYPE_MD5;
100-
if (parse_scram_secret(shadow_pass, &iterations, &encoded_salt,
101-
stored_key, server_key))
102+
if (parse_scram_secret(shadow_pass, &iterations, &hash_type, &key_length,
103+
&encoded_salt, stored_key, server_key))
102104
return PASSWORD_TYPE_SCRAM_SHA_256;
103105
return PASSWORD_TYPE_PLAINTEXT;
104106
}

0 commit comments

Comments
 (0)