Skip to content

Commit a3862f0

Browse files
committed
Improve ML-DSA private key import
1 parent 16a6818 commit a3862f0

5 files changed

Lines changed: 218 additions & 90 deletions

File tree

src/ssl_load.c

Lines changed: 55 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -946,6 +946,9 @@ static int ProcessBufferTryDecodeDilithium(WOLFSSL_CTX* ctx, WOLFSSL* ssl,
946946
int ret;
947947
word32 idx;
948948
dilithium_key* key;
949+
int keyFormatTemp = 0;
950+
int keyTypeTemp;
951+
int keySizeTemp;
949952

950953
/* Allocate a Dilithium key to parse into. */
951954
key = (dilithium_key*)XMALLOC(sizeof(dilithium_key), heap,
@@ -955,106 +958,75 @@ static int ProcessBufferTryDecodeDilithium(WOLFSSL_CTX* ctx, WOLFSSL* ssl,
955958
}
956959

957960
/* Initialize Dilithium key. */
958-
ret = wc_dilithium_init(key);
959-
if (ret == 0) {
960-
/* Set up key to parse the format specified. */
961-
if ((*keyFormat == ML_DSA_LEVEL2k) || ((*keyFormat == 0) &&
962-
((der->length == ML_DSA_LEVEL2_KEY_SIZE) ||
963-
(der->length == ML_DSA_LEVEL2_PRV_KEY_SIZE)))) {
964-
ret = wc_dilithium_set_level(key, WC_ML_DSA_44);
965-
}
966-
else if ((*keyFormat == ML_DSA_LEVEL3k) || ((*keyFormat == 0) &&
967-
((der->length == ML_DSA_LEVEL3_KEY_SIZE) ||
968-
(der->length == ML_DSA_LEVEL3_PRV_KEY_SIZE)))) {
969-
ret = wc_dilithium_set_level(key, WC_ML_DSA_65);
970-
}
971-
else if ((*keyFormat == ML_DSA_LEVEL5k) || ((*keyFormat == 0) &&
972-
((der->length == ML_DSA_LEVEL5_KEY_SIZE) ||
973-
(der->length == ML_DSA_LEVEL5_PRV_KEY_SIZE)))) {
974-
ret = wc_dilithium_set_level(key, WC_ML_DSA_87);
975-
}
976-
#ifdef WOLFSSL_DILITHIUM_FIPS204_DRAFT
977-
else if ((*keyFormat == DILITHIUM_LEVEL2k) || ((*keyFormat == 0) &&
978-
((der->length == DILITHIUM_LEVEL2_KEY_SIZE) ||
979-
(der->length == DILITHIUM_LEVEL2_PRV_KEY_SIZE)))) {
980-
ret = wc_dilithium_set_level(key, WC_ML_DSA_44_DRAFT);
981-
}
982-
else if ((*keyFormat == DILITHIUM_LEVEL3k) || ((*keyFormat == 0) &&
983-
((der->length == DILITHIUM_LEVEL3_KEY_SIZE) ||
984-
(der->length == DILITHIUM_LEVEL3_PRV_KEY_SIZE)))) {
985-
ret = wc_dilithium_set_level(key, WC_ML_DSA_65_DRAFT);
986-
}
987-
else if ((*keyFormat == DILITHIUM_LEVEL5k) || ((*keyFormat == 0) &&
988-
((der->length == DILITHIUM_LEVEL5_KEY_SIZE) ||
989-
(der->length == DILITHIUM_LEVEL5_PRV_KEY_SIZE)))) {
990-
ret = wc_dilithium_set_level(key, WC_ML_DSA_87_DRAFT);
991-
}
992-
#endif /* WOLFSSL_DILITHIUM_FIPS204_DRAFT */
993-
else {
994-
wc_dilithium_free(key);
995-
ret = ALGO_ID_E;
996-
}
997-
}
998-
961+
ret = wc_dilithium_init(key);
999962
if (ret == 0) {
1000963
/* Decode as a Dilithium private key. */
1001964
idx = 0;
1002965
ret = wc_Dilithium_PrivateKeyDecode(der->buffer, &idx, key, der->length);
1003966
if (ret == 0) {
1004-
/* Get the minimum Dilithium key size from SSL or SSL context
1005-
* object. */
1006-
int minKeySz = ssl ? ssl->options.minDilithiumKeySz :
1007-
ctx->minDilithiumKeySz;
1008-
1009-
/* Format is known. */
1010-
if (*keyFormat == ML_DSA_LEVEL2k) {
1011-
*keyType = dilithium_level2_sa_algo;
1012-
*keySize = ML_DSA_LEVEL2_KEY_SIZE;
1013-
}
1014-
else if (*keyFormat == ML_DSA_LEVEL3k) {
1015-
*keyType = dilithium_level3_sa_algo;
1016-
*keySize = ML_DSA_LEVEL3_KEY_SIZE;
1017-
}
1018-
else if (*keyFormat == ML_DSA_LEVEL5k) {
1019-
*keyType = dilithium_level5_sa_algo;
1020-
*keySize = ML_DSA_LEVEL5_KEY_SIZE;
1021-
}
1022-
#ifdef WOLFSSL_DILITHIUM_FIPS204_DRAFT
1023-
else if (*keyFormat == DILITHIUM_LEVEL2k) {
1024-
*keyType = dilithium_level2_sa_algo;
1025-
*keySize = DILITHIUM_LEVEL2_KEY_SIZE;
1026-
}
1027-
else if (*keyFormat == DILITHIUM_LEVEL3k) {
1028-
*keyType = dilithium_level3_sa_algo;
1029-
*keySize = DILITHIUM_LEVEL3_KEY_SIZE;
967+
ret = dilithium_get_oid_sum(key, &keyFormatTemp);
968+
if(ret == 0) {
969+
/* Format is known. */
970+
#if defined(WOLFSSL_DILITHIUM_FIPS204_DRAFT)
971+
if (keyFormatTemp == DILITHIUM_LEVEL2k) {
972+
keyTypeTemp = dilithium_level2_sa_algo;
973+
keySizeTemp = DILITHIUM_LEVEL2_KEY_SIZE;
974+
}
975+
else if (keyFormatTemp == DILITHIUM_LEVEL3k) {
976+
keyTypeTemp = dilithium_level3_sa_algo;
977+
keySizeTemp = DILITHIUM_LEVEL3_KEY_SIZE;
978+
}
979+
else if (keyFormatTemp == DILITHIUM_LEVEL5k) {
980+
keyTypeTemp = dilithium_level5_sa_algo;
981+
keySizeTemp = DILITHIUM_LEVEL5_KEY_SIZE;
982+
}
983+
else
984+
#endif /* WOLFSSL_DILITHIUM_FIPS204_DRAFT */
985+
if (keyFormatTemp == ML_DSA_LEVEL2k) {
986+
keyTypeTemp = dilithium_level2_sa_algo;
987+
keySizeTemp = ML_DSA_LEVEL2_KEY_SIZE;
988+
}
989+
else if (keyFormatTemp == ML_DSA_LEVEL3k) {
990+
keyTypeTemp = dilithium_level3_sa_algo;
991+
keySizeTemp = ML_DSA_LEVEL3_KEY_SIZE;
992+
}
993+
else if (keyFormatTemp == ML_DSA_LEVEL5k) {
994+
keyTypeTemp = dilithium_level5_sa_algo;
995+
keySizeTemp = ML_DSA_LEVEL5_KEY_SIZE;
996+
}
997+
else {
998+
ret = ALGO_ID_E;
999+
}
10301000
}
1031-
else if (*keyFormat == DILITHIUM_LEVEL5k) {
1032-
*keyType = dilithium_level5_sa_algo;
1033-
*keySize = DILITHIUM_LEVEL5_KEY_SIZE;
1001+
1002+
if(ret == 0) {
1003+
/* Get the minimum Dilithium key size from SSL or SSL context
1004+
* object. */
1005+
int minKeySz = ssl ? ssl->options.minDilithiumKeySz :
1006+
ctx->minDilithiumKeySz;
1007+
1008+
/* Check that the size of the Dilithium key is enough. */
1009+
if (keySizeTemp < minKeySz) {
1010+
WOLFSSL_MSG("Dilithium private key too small");
1011+
ret = DILITHIUM_KEY_SIZE_E;
1012+
}
10341013
}
1035-
#endif /* WOLFSSL_DILITHIUM_FIPS204_DRAFT */
10361014

1037-
/* Check that the size of the Dilithium key is enough. */
1038-
if (*keySize < minKeySz) {
1039-
WOLFSSL_MSG("Dilithium private key too small");
1040-
ret = DILITHIUM_KEY_SIZE_E;
1015+
if(ret == 0) {
1016+
*keyFormat = keyFormatTemp;
1017+
*keyType = keyTypeTemp;
1018+
*keySize = keySizeTemp;
10411019
}
10421020
}
1043-
/* Not a Dilithium key but check whether we know what it is. */
10441021
else if (*keyFormat == 0) {
10451022
WOLFSSL_MSG("Not a Dilithium key");
1046-
/* Format unknown so keep trying. */
1023+
/* Unknowun format was not dilithium, so keep trying other formats. */
10471024
ret = 0;
10481025
}
1049-
1026+
10501027
/* Free dynamically allocated data in key. */
10511028
wc_dilithium_free(key);
10521029
}
1053-
else if ((ret == WC_NO_ERR_TRACE(ALGO_ID_E)) && (*keyFormat == 0)) {
1054-
WOLFSSL_MSG("Not a Dilithium key");
1055-
/* Format unknown so keep trying. */
1056-
ret = 0;
1057-
}
10581030

10591031
/* Dispose of allocated key. */
10601032
XFREE(key, heap, DYNAMIC_TYPE_DILITHIUM);

tests/api.c

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13933,6 +13933,119 @@ static int test_wolfSSL_PKCS8_ED448(void)
1393313933
return EXPECT_RESULT();
1393413934
}
1393513935

13936+
static int test_wolfSSL_PKCS8_MLDSA(void)
13937+
{
13938+
EXPECT_DECLS;
13939+
#if !defined(NO_ASN) && defined(HAVE_PKCS8) && \
13940+
defined(HAVE_DILITHIUM) && !defined(NO_TLS) && \
13941+
(!defined(NO_WOLFSSL_CLIENT) || !defined(NO_WOLFSSL_SERVER))
13942+
13943+
WOLFSSL_CTX* ctx = NULL;
13944+
size_t i;
13945+
const int derMaxSz = 8192; /* Largest size will be 7520 of separated format, WC_ML_DSA_87, DER */
13946+
const int tempMaxSz = 10240; /* Largest size will be 10239 of separated format, WC_MLS_DSA_87, PEM */
13947+
byte* der = NULL;
13948+
byte* temp = NULL; /* Store PEM or intermediate key */
13949+
word32 derSz = 0;
13950+
word32 pemSz = 0;
13951+
word32 keySz = 0;
13952+
dilithium_key mldsa_key;
13953+
WC_RNG rng;
13954+
word32 size;
13955+
13956+
struct {
13957+
int wcId;
13958+
int oidSum;
13959+
int keySz;
13960+
} test_variant[] = {{WC_ML_DSA_44, ML_DSA_LEVEL2k, ML_DSA_LEVEL2_PRV_KEY_SIZE},
13961+
{WC_ML_DSA_65, ML_DSA_LEVEL3k, ML_DSA_LEVEL3_PRV_KEY_SIZE},
13962+
{WC_ML_DSA_87, ML_DSA_LEVEL5k, ML_DSA_LEVEL5_PRV_KEY_SIZE}};
13963+
13964+
(void) pemSz;
13965+
13966+
ExpectNotNull(der = (byte*) XMALLOC(derMaxSz, NULL, DYNAMIC_TYPE_TMP_BUFFER));
13967+
ExpectNotNull(temp = (byte*) XMALLOC(tempMaxSz, NULL, DYNAMIC_TYPE_TMP_BUFFER));
13968+
13969+
#ifndef NO_WOLFSSL_SERVER
13970+
ExpectNotNull(ctx = wolfSSL_CTX_new(wolfSSLv23_server_method()));
13971+
#else
13972+
ExpectNotNull(ctx = wolfSSL_CTX_new(wolfSSLv23_client_method()));
13973+
#endif /* NO_WOLFSSL_SERVER */
13974+
13975+
ExpectIntEQ(wc_InitRng(&rng), 0);
13976+
ExpectIntEQ(wc_dilithium_init(&mldsa_key), 0);
13977+
13978+
/* Test private + public key (separated format) */
13979+
for(i = 0; i < sizeof(test_variant) / sizeof(test_variant[0]); ++i) {
13980+
ExpectIntEQ(wc_dilithium_set_level(&mldsa_key, test_variant[i].wcId), 0);
13981+
ExpectIntEQ(wc_dilithium_make_key(&mldsa_key, &rng), 0);
13982+
13983+
ExpectIntGT(derSz = wc_Dilithium_KeyToDer(&mldsa_key, der, derMaxSz), 0);
13984+
ExpectIntEQ(wolfSSL_CTX_use_PrivateKey_buffer(ctx, der, derSz,
13985+
WOLFSSL_FILETYPE_ASN1), WOLFSSL_SUCCESS);
13986+
13987+
#ifdef WOLFSSL_DER_TO_PEM
13988+
ExpectIntGT(pemSz = wc_DerToPem(der, derSz, temp, tempMaxSz, PKCS8_PRIVATEKEY_TYPE), 0);
13989+
ExpectIntEQ(wolfSSL_CTX_use_PrivateKey_buffer(ctx, temp, pemSz,
13990+
WOLFSSL_FILETYPE_PEM), WOLFSSL_SUCCESS);
13991+
#endif /* WOLFSSL_DER_TO_PEM */
13992+
}
13993+
13994+
/* Test private key only */
13995+
for(i = 0; i < sizeof(test_variant) / sizeof(test_variant[0]); ++i) {
13996+
ExpectIntEQ(wc_dilithium_set_level(&mldsa_key, test_variant[i].wcId), 0);
13997+
ExpectIntEQ(wc_dilithium_make_key(&mldsa_key, &rng), 0);
13998+
13999+
ExpectIntGT(derSz = wc_Dilithium_PrivateKeyToDer(&mldsa_key, der, derMaxSz), 0);
14000+
ExpectIntEQ(wolfSSL_CTX_use_PrivateKey_buffer(ctx, der, derSz,
14001+
WOLFSSL_FILETYPE_ASN1), WOLFSSL_SUCCESS);
14002+
14003+
#ifdef WOLFSSL_DER_TO_PEM
14004+
ExpectIntGT(pemSz = wc_DerToPem(der, derSz, temp, tempMaxSz, PKCS8_PRIVATEKEY_TYPE), 0);
14005+
ExpectIntEQ(wolfSSL_CTX_use_PrivateKey_buffer(ctx, temp, pemSz,
14006+
WOLFSSL_FILETYPE_PEM), WOLFSSL_SUCCESS);
14007+
#endif /* WOLFSSL_DER_TO_PEM */
14008+
}
14009+
14010+
/* Test private + public key (integrated format) */
14011+
for(i = 0; i < sizeof(test_variant) / sizeof(test_variant[0]); ++i) {
14012+
ExpectIntEQ(wc_dilithium_set_level(&mldsa_key, test_variant[i].wcId), 0);
14013+
ExpectIntEQ(wc_dilithium_make_key(&mldsa_key, &rng), 0);
14014+
14015+
keySz = 0;
14016+
temp[0] = 0x04; /* ASN.1 OCTET STRING */
14017+
temp[1] = 0x82; /* 2 bytes length field */
14018+
temp[2] = (test_variant[i].keySz >> 8) & 0xff; /* MSB of the length */
14019+
temp[3] = test_variant[i].keySz & 0xff; /* LSB of the length */
14020+
keySz += 4;
14021+
size = tempMaxSz - keySz;
14022+
ExpectIntEQ(wc_dilithium_export_private(&mldsa_key, temp + keySz, &size), 0);
14023+
keySz += size;
14024+
size = tempMaxSz - keySz;
14025+
ExpectIntEQ(wc_dilithium_export_public(&mldsa_key, temp + keySz, &size), 0);
14026+
keySz += size;
14027+
derSz = derMaxSz;
14028+
ExpectIntGT(wc_CreatePKCS8Key(der, &derSz, temp, keySz, test_variant[i].oidSum, NULL, 0), 0);
14029+
ExpectIntEQ(wolfSSL_CTX_use_PrivateKey_buffer(ctx, der, derSz,
14030+
WOLFSSL_FILETYPE_ASN1), WOLFSSL_SUCCESS);
14031+
14032+
#ifdef WOLFSSL_DER_TO_PEM
14033+
ExpectIntGT(pemSz = wc_DerToPem(der, derSz, temp, tempMaxSz, PKCS8_PRIVATEKEY_TYPE), 0);
14034+
ExpectIntEQ(wolfSSL_CTX_use_PrivateKey_buffer(ctx, temp, pemSz,
14035+
WOLFSSL_FILETYPE_PEM), WOLFSSL_SUCCESS);
14036+
#endif /* WOLFSSL_DER_TO_PEM */
14037+
}
14038+
14039+
wc_dilithium_free(&mldsa_key);
14040+
ExpectIntEQ(wc_FreeRng(&rng), 0);
14041+
wolfSSL_CTX_free(ctx);
14042+
XFREE(temp, NULL, DYNAMIC_TYPE_TMP_BUFFER);
14043+
XFREE(der, NULL, DYNAMIC_TYPE_TMP_BUFFER);
14044+
14045+
#endif
14046+
return EXPECT_RESULT();
14047+
}
14048+
1393614049
/* Testing functions dealing with PKCS5 */
1393714050
static int test_wolfSSL_PKCS5(void)
1393814051
{
@@ -67519,6 +67632,7 @@ TEST_CASE testCases[] = {
6751967632
TEST_DECL(test_wolfSSL_PKCS8),
6752067633
TEST_DECL(test_wolfSSL_PKCS8_ED25519),
6752167634
TEST_DECL(test_wolfSSL_PKCS8_ED448),
67635+
TEST_DECL(test_wolfSSL_PKCS8_MLDSA),
6752267636

6752367637
#ifdef HAVE_IO_TESTS_DEPENDENCIES
6752467638
TEST_DECL(test_wolfSSL_get_finished),

tests/api/test_mldsa.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2959,7 +2959,7 @@ int test_wc_dilithium_der(void)
29592959
idx = 0;
29602960
#ifdef WOLFSSL_DILITHIUM_FIPS204_DRAFT
29612961
ExpectIntEQ(wc_Dilithium_PrivateKeyDecode(der, &idx, key, privDerLen),
2962-
WC_NO_ERR_TRACE(BAD_FUNC_ARG));
2962+
WC_NO_ERR_TRACE(ASN_PARSE_E));
29632963
#else
29642964
ExpectIntEQ(wc_Dilithium_PrivateKeyDecode(der, &idx, key, privDerLen),
29652965
WC_NO_ERR_TRACE(ASN_PARSE_E));

wolfcrypt/src/dilithium.c

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9589,6 +9589,42 @@ static int mapOidToSecLevel(word32 oid)
95899589
}
95909590
}
95919591

9592+
/* Get OID sum from dilithium key */
9593+
int dilithium_get_oid_sum(dilithium_key* key, int* keyFormat) {
9594+
int ret = 0;
9595+
9596+
#if defined(WOLFSSL_DILITHIUM_FIPS204_DRAFT)
9597+
if (key->params == NULL) {
9598+
ret = BAD_FUNC_ARG;
9599+
}
9600+
else if (key->params->level == WC_ML_DSA_44_DRAFT) {
9601+
*keyFormat = DILITHIUM_LEVEL2k;
9602+
}
9603+
else if (key->params->level == WC_ML_DSA_65_DRAFT) {
9604+
*keyFormat = DILITHIUM_LEVEL3k;
9605+
}
9606+
else if (key->params->level == WC_ML_DSA_87_DRAFT) {
9607+
*keyFormat = DILITHIUM_LEVEL5k;
9608+
}
9609+
else
9610+
#endif /* WOLFSSL_DILITHIUM_FIPS204_DRAFT */
9611+
if (key->level == WC_ML_DSA_44) {
9612+
*keyFormat = ML_DSA_LEVEL2k;
9613+
}
9614+
else if (key->level == WC_ML_DSA_65) {
9615+
*keyFormat = ML_DSA_LEVEL3k;
9616+
}
9617+
else if (key->level == WC_ML_DSA_87) {
9618+
*keyFormat = ML_DSA_LEVEL5k;
9619+
}
9620+
else {
9621+
/* Level is not set */
9622+
ret = ALGO_ID_E;
9623+
}
9624+
9625+
return ret;
9626+
}
9627+
95929628
#if defined(WOLFSSL_DILITHIUM_PRIVATE_KEY)
95939629

95949630
/* Decode the DER encoded Dilithium key.
@@ -9627,9 +9663,13 @@ int wc_Dilithium_PrivateKeyDecode(const byte* input, word32* inOutIdx,
96279663
}
96289664

96299665
if (ret == 0) {
9630-
/* Get OID sum for level. */
9666+
/* Get OID sum for level. */
9667+
if(key->level == 0) { /* Check first, because key->params will be NULL when key->level = 0 */
9668+
/* Level not set by caller, decode from DER */
9669+
keytype = ANONk;
9670+
}
96319671
#if defined(WOLFSSL_DILITHIUM_FIPS204_DRAFT)
9632-
if (key->params == NULL) {
9672+
else if (key->params == NULL) {
96339673
ret = BAD_FUNC_ARG;
96349674
}
96359675
else if (key->params->level == WC_ML_DSA_44_DRAFT) {
@@ -9641,9 +9681,8 @@ int wc_Dilithium_PrivateKeyDecode(const byte* input, word32* inOutIdx,
96419681
else if (key->params->level == WC_ML_DSA_87_DRAFT) {
96429682
keytype = DILITHIUM_LEVEL5k;
96439683
}
9644-
else
96459684
#endif
9646-
if (key->level == WC_ML_DSA_44) {
9685+
else if (key->level == WC_ML_DSA_44) {
96479686
keytype = ML_DSA_LEVEL2k;
96489687
}
96499688
else if (key->level == WC_ML_DSA_65) {
@@ -9653,8 +9692,7 @@ int wc_Dilithium_PrivateKeyDecode(const byte* input, word32* inOutIdx,
96539692
keytype = ML_DSA_LEVEL5k;
96549693
}
96559694
else {
9656-
/* Level not set by caller, decode from DER */
9657-
keytype = ANONk; /* 0, not a valid key type in this situation*/
9695+
ret = BAD_FUNC_ARG;
96589696
}
96599697
}
96609698

wolfssl/wolfcrypt/dilithium.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -813,6 +813,10 @@ int wc_dilithium_export_key(dilithium_key* key, byte* priv, word32 *privSz,
813813
byte* pub, word32 *pubSz);
814814
#endif
815815

816+
#ifndef WOLFSSL_DILITHIUM_NO_ASN1
817+
WOLFSSL_LOCAL int dilithium_get_oid_sum(dilithium_key* key, int* keyFormat);
818+
#endif /* WOLFSSL_DILITHIUM_NO_ASN1 */
819+
816820
#ifndef WOLFSSL_DILITHIUM_NO_ASN1
817821
#if defined(WOLFSSL_DILITHIUM_PRIVATE_KEY)
818822
WOLFSSL_API int wc_Dilithium_PrivateKeyDecode(const byte* input,

0 commit comments

Comments
 (0)