Skip to content

Commit 7a83f0a

Browse files
committed
Speedup DTLS max fragment size calculation for MTU limits
Refactor wolfssl_local_GetRecordSize to compute the record size directly from cipher specs instead of calling BuildMessage. Add a unit test that compares the new calculation to BuildMessage's size-only output across every registered cipher suite and supported (D)TLS version.
1 parent 3181e2b commit 7a83f0a

4 files changed

Lines changed: 269 additions & 22 deletions

File tree

src/internal.c

Lines changed: 87 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -42252,30 +42252,100 @@ int wolfSSL_AsyncPush(WOLFSSL* ssl, WC_ASYNC_DEV* asyncDev)
4225242252
*/
4225342253
int wolfssl_local_GetRecordSize(WOLFSSL *ssl, int payloadSz, int isEncrypted)
4225442254
{
42255-
int recordSz;
42255+
int sz;
42256+
int headerSz;
42257+
int digestSz;
42258+
int ivSz;
42259+
#ifdef WOLFSSL_DTLS_CID
42260+
byte cidSz;
42261+
#endif
42262+
#ifndef WOLFSSL_AEAD_ONLY
42263+
int blockSz;
42264+
int pad;
42265+
#endif
4225642266

4225742267
if (ssl == NULL)
4225842268
return BAD_FUNC_ARG;
4225942269

42260-
if (isEncrypted) {
42261-
recordSz = BuildMessage(ssl, NULL, 0, NULL, payloadSz, application_data,
42262-
0, 1, 0, CUR_ORDER);
42263-
/* use a safe upper bound in case of error */
42264-
if (recordSz < 0) {
42265-
recordSz = payloadSz + RECORD_HEADER_SZ
42266-
+ cipherExtraData(ssl) + COMP_EXTRA;
42267-
if (ssl->options.dtls) {
42268-
recordSz += DTLS_RECORD_EXTRA;
42269-
}
42270-
}
42270+
if (!isEncrypted) {
42271+
sz = payloadSz + RECORD_HEADER_SZ;
42272+
if (ssl->options.dtls)
42273+
sz += DTLS_RECORD_EXTRA;
42274+
return sz;
4227142275
}
42272-
else {
42273-
recordSz = payloadSz + RECORD_HEADER_SZ;
42274-
if (ssl->options.dtls) {
42275-
recordSz += DTLS_RECORD_EXTRA;
42276+
42277+
#ifdef WOLFSSL_TLS13
42278+
if (ssl->options.tls1_3) {
42279+
#ifdef WOLFSSL_DTLS13
42280+
if (ssl->options.dtls)
42281+
headerSz = Dtls13GetRlHeaderLength(ssl, 1);
42282+
else
42283+
#endif
42284+
headerSz = RECORD_HEADER_SZ;
42285+
sz = payloadSz + headerSz + 1 /* inner type */
42286+
+ ssl->specs.aead_mac_size;
42287+
#ifdef WOLFSSL_DTLS13
42288+
if (ssl->options.dtls && sz < Dtls13MinimumRecordLength(ssl))
42289+
sz = Dtls13MinimumRecordLength(ssl);
42290+
#endif
42291+
return sz;
42292+
}
42293+
#endif
42294+
42295+
/* TLS 1.2 / TLS 1.1 / DTLS 1.2 path. Mirror BuildMessage's size
42296+
* calculation so the result matches exactly. */
42297+
headerSz = RECORD_HEADER_SZ;
42298+
sz = payloadSz + RECORD_HEADER_SZ;
42299+
42300+
if (ssl->options.dtls) {
42301+
sz += DTLS_RECORD_EXTRA;
42302+
headerSz += DTLS_RECORD_EXTRA;
42303+
#ifdef WOLFSSL_DTLS_CID
42304+
cidSz = DtlsGetCidTxSize(ssl);
42305+
if (cidSz > 0) {
42306+
sz += cidSz;
42307+
headerSz += cidSz;
42308+
sz++; /* real_type byte appended */
4227642309
}
42310+
#endif
42311+
}
42312+
42313+
digestSz = (int)ssl->specs.hash_size;
42314+
#ifdef HAVE_TRUNCATED_HMAC
42315+
if (ssl->truncated_hmac)
42316+
digestSz = min(TRUNCATED_HMAC_SZ, digestSz);
42317+
#endif
42318+
sz += digestSz;
42319+
42320+
#ifndef WOLFSSL_AEAD_ONLY
42321+
if (ssl->specs.cipher_type == block) {
42322+
blockSz = (int)ssl->specs.block_size;
42323+
42324+
if (ssl->options.tls1_1)
42325+
sz += blockSz; /* explicit IV */
42326+
sz += 1; /* pad-length byte */
42327+
42328+
#if defined(HAVE_ENCRYPT_THEN_MAC) && !defined(WOLFSSL_AEAD_ONLY)
42329+
if (ssl->options.startedETMWrite)
42330+
pad = blockSz != 0 ?
42331+
(sz - headerSz - digestSz) % blockSz : 0;
42332+
else
42333+
#endif
42334+
pad = blockSz != 0 ? (sz - headerSz) % blockSz : 0;
42335+
if (pad != 0)
42336+
pad = blockSz - pad;
42337+
sz += pad;
4227742338
}
42278-
return recordSz;
42339+
else
42340+
#endif /* WOLFSSL_AEAD_ONLY */
42341+
if (ssl->specs.cipher_type == aead) {
42342+
ivSz = 0;
42343+
if (ssl->specs.bulk_cipher_algorithm != wolfssl_chacha)
42344+
ivSz = AESGCM_EXP_IV_SZ;
42345+
sz += ivSz + (int)ssl->specs.aead_mac_size - digestSz;
42346+
}
42347+
42348+
return sz;
4227942349
}
4228042350
#endif
4228142351

tests/api/test_tls.c

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1121,3 +1121,178 @@ int test_tls12_peerauth_failsafe(void)
11211121
#endif
11221122
return EXPECT_RESULT();
11231123
}
1124+
1125+
/* Verify that wolfssl_local_GetRecordSize agrees exactly with
1126+
* BuildMessage's size-only output. Iterates every cipher suite registered
1127+
* in internal.c via GetCipherNames(), tries each against every supported
1128+
* (D)TLS protocol version (with and without DTLS connection ID where
1129+
* applicable), and on successful handshake compares the two
1130+
* record-size calculations across a range of payload sizes. Logs each
1131+
* skipped combination so coverage gaps are visible. Handshake failures
1132+
* outside the explicit allow-list fail the test instead of being
1133+
* silently skipped. */
1134+
int test_record_size_matches_build_message(void)
1135+
{
1136+
EXPECT_DECLS;
1137+
#if defined(HAVE_MANUAL_MEMIO_TESTS_DEPENDENCIES)
1138+
const int sizes[] = {
1139+
1, 8, 15, 16, 17, 31, 32, 33, 64, 100, 256, 1000, 4096, 16000
1140+
};
1141+
struct {
1142+
method_provider client;
1143+
method_provider server;
1144+
int use_cid;
1145+
int is_tls13;
1146+
const char* label;
1147+
} versions[] = {
1148+
#ifndef WOLFSSL_NO_TLS12
1149+
{ wolfTLSv1_2_client_method, wolfTLSv1_2_server_method, 0, 0,
1150+
"TLSv1.2" },
1151+
#ifdef WOLFSSL_DTLS
1152+
{ wolfDTLSv1_2_client_method, wolfDTLSv1_2_server_method, 0, 0,
1153+
"DTLSv1.2" },
1154+
#ifdef WOLFSSL_DTLS_CID
1155+
{ wolfDTLSv1_2_client_method, wolfDTLSv1_2_server_method, 1, 0,
1156+
"DTLSv1.2+CID" },
1157+
#endif
1158+
#endif
1159+
#endif
1160+
#ifdef WOLFSSL_TLS13
1161+
{ wolfTLSv1_3_client_method, wolfTLSv1_3_server_method, 0, 1,
1162+
"TLSv1.3" },
1163+
#ifdef WOLFSSL_DTLS13
1164+
{ wolfDTLSv1_3_client_method, wolfDTLSv1_3_server_method, 0, 1,
1165+
"DTLSv1.3" },
1166+
#ifdef WOLFSSL_DTLS_CID
1167+
{ wolfDTLSv1_3_client_method, wolfDTLSv1_3_server_method, 1, 1,
1168+
"DTLSv1.3+CID" },
1169+
#endif
1170+
#endif
1171+
#endif
1172+
};
1173+
const CipherSuiteInfo* allCiphers = GetCipherNames();
1174+
int numCiphers = GetCipherNamesSize();
1175+
int tested = 0;
1176+
size_t v, j;
1177+
int i;
1178+
1179+
fprintf(stderr, "\n");
1180+
for (v = 0; v < XELEM_CNT(versions) && EXPECT_SUCCESS(); v++) {
1181+
for (i = 0; i < numCiphers && EXPECT_SUCCESS(); i++) {
1182+
WOLFSSL_CTX *ctx_c = NULL, *ctx_s = NULL;
1183+
WOLFSSL *ssl_c = NULL, *ssl_s = NULL;
1184+
struct test_memio_ctx test_ctx;
1185+
const char* name = allCiphers[i].name;
1186+
int isTls13Cipher = (XSTRSTR(name, "TLS13-") != NULL);
1187+
int handshakeRet;
1188+
1189+
XMEMSET(&test_ctx, 0, sizeof(test_ctx));
1190+
1191+
ExpectIntEQ(test_memio_setup(&test_ctx, &ctx_c, &ctx_s, &ssl_c,
1192+
&ssl_s, versions[v].client, versions[v].server), 0);
1193+
1194+
/* Skip ciphers that aren't valid for this version/build. */
1195+
if (wolfSSL_set_cipher_list(ssl_c, name) != 1 ||
1196+
wolfSSL_set_cipher_list(ssl_s, name) != 1) {
1197+
fprintf(stderr,
1198+
" [SKIP %-12s %-40s] cipher not selectable\n",
1199+
versions[v].label, name);
1200+
goto next_iter;
1201+
}
1202+
1203+
#ifdef WOLFSSL_DTLS_CID
1204+
if (versions[v].use_cid) {
1205+
unsigned char cid_c[] = { 0, 1, 2, 3 };
1206+
unsigned char cid_s[] = { 4, 5, 6, 7, 8, 9 };
1207+
ExpectIntEQ(wolfSSL_dtls_cid_use(ssl_c), 1);
1208+
ExpectIntEQ(wolfSSL_dtls_cid_use(ssl_s), 1);
1209+
ExpectIntEQ(wolfSSL_dtls_cid_set(ssl_c, cid_s,
1210+
(int)sizeof(cid_s)), 1);
1211+
ExpectIntEQ(wolfSSL_dtls_cid_set(ssl_s, cid_c,
1212+
(int)sizeof(cid_c)), 1);
1213+
}
1214+
#endif
1215+
1216+
handshakeRet = test_memio_do_handshake(ssl_c, ssl_s, 10, NULL);
1217+
if (handshakeRet != 0) {
1218+
/* Allow-list of ciphers that legitimately can't negotiate
1219+
* with the default test_memio configuration. Anything else
1220+
* is a real failure. */
1221+
int expected = 0;
1222+
const char* reason = NULL;
1223+
1224+
if (isTls13Cipher != versions[v].is_tls13) {
1225+
expected = 1;
1226+
reason = "version mismatch";
1227+
}
1228+
else if (XSTRSTR(name, "ECDSA") != NULL) {
1229+
expected = 1;
1230+
reason = "no ECDSA cert";
1231+
}
1232+
else if (XSTRSTR(name, "ECDH-") != NULL) {
1233+
expected = 1;
1234+
reason = "no static ECDH cert";
1235+
}
1236+
else if (XSTRSTR(name, "PSK") != NULL) {
1237+
expected = 1;
1238+
reason = "no PSK callback";
1239+
}
1240+
else if (XSTRSTR(name, "ANON") != NULL ||
1241+
XSTRSTR(name, "anon") != NULL) {
1242+
expected = 1;
1243+
reason = "anon not enabled";
1244+
}
1245+
else if (XSTRSTR(name, "SRP") != NULL) {
1246+
expected = 1;
1247+
reason = "no SRP setup";
1248+
}
1249+
1250+
if (!expected) {
1251+
fprintf(stderr,
1252+
" [FAIL %-12s %-40s] unexpected handshake "
1253+
"failure (%d)\n",
1254+
versions[v].label, name, handshakeRet);
1255+
}
1256+
else {
1257+
fprintf(stderr,
1258+
" [SKIP %-12s %-40s] %s\n",
1259+
versions[v].label, name, reason);
1260+
}
1261+
ExpectIntEQ(expected, 1);
1262+
goto next_iter;
1263+
}
1264+
1265+
for (j = 0; j < XELEM_CNT(sizes) && EXPECT_SUCCESS(); j++) {
1266+
int payload = sizes[j];
1267+
int recordSz, buildSz;
1268+
1269+
recordSz = wolfssl_local_GetRecordSize(ssl_c, payload, 1);
1270+
buildSz = BuildMessage(ssl_c, NULL, 0, NULL, payload,
1271+
application_data, 0, 1, 0, CUR_ORDER);
1272+
1273+
ExpectIntGE(recordSz, 0);
1274+
ExpectIntGE(buildSz, 0);
1275+
ExpectIntEQ(recordSz, buildSz);
1276+
if (recordSz != buildSz) {
1277+
fprintf(stderr,
1278+
" [MISMATCH %-12s %-40s payload=%d]"
1279+
" recordSz=%d buildSz=%d\n",
1280+
versions[v].label, name, payload,
1281+
recordSz, buildSz);
1282+
}
1283+
}
1284+
tested++;
1285+
1286+
next_iter:
1287+
wolfSSL_free(ssl_c); wolfSSL_CTX_free(ctx_c);
1288+
wolfSSL_free(ssl_s); wolfSSL_CTX_free(ctx_s);
1289+
}
1290+
}
1291+
1292+
/* Sanity: at least one cipher/version combination must have been
1293+
* exercised per supported version, otherwise the test is silently a
1294+
* no-op. */
1295+
ExpectIntGE(tested, (int)XELEM_CNT(versions));
1296+
#endif
1297+
return EXPECT_RESULT();
1298+
}

tests/api/test_tls.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ int test_tls12_etm_failed_resumption(void);
3535
int test_tls_set_curves_list_ecc_fallback(void);
3636
int test_tls12_corrupted_finished(void);
3737
int test_tls12_peerauth_failsafe(void);
38+
int test_record_size_matches_build_message(void);
3839

3940
#define TEST_TLS_DECLS \
4041
TEST_DECL_GROUP("tls", test_utils_memio_move_message), \
@@ -49,6 +50,7 @@ int test_tls12_peerauth_failsafe(void);
4950
TEST_DECL_GROUP("tls", test_tls12_etm_failed_resumption), \
5051
TEST_DECL_GROUP("tls", test_tls_set_curves_list_ecc_fallback), \
5152
TEST_DECL_GROUP("tls", test_tls12_corrupted_finished), \
52-
TEST_DECL_GROUP("tls", test_tls12_peerauth_failsafe)
53+
TEST_DECL_GROUP("tls", test_tls12_peerauth_failsafe), \
54+
TEST_DECL_GROUP("tls", test_record_size_matches_build_message)
5355

5456
#endif /* TESTS_API_TEST_TLS_H */

wolfssl/internal.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6815,7 +6815,7 @@ WOLFSSL_LOCAL int VerifyClientSuite(word16 havePSK, byte cipherSuite0,
68156815
byte cipherSuite);
68166816

68176817
WOLFSSL_LOCAL int SetTicket(WOLFSSL* ssl, const byte* ticket, word32 length);
6818-
WOLFSSL_LOCAL int wolfssl_local_GetRecordSize(WOLFSSL *ssl, int payloadSz,
6818+
WOLFSSL_TEST_VIS int wolfssl_local_GetRecordSize(WOLFSSL *ssl, int payloadSz,
68196819
int isEncrypted);
68206820
WOLFSSL_LOCAL int wolfssl_local_GetMaxPlaintextSize(WOLFSSL *ssl);
68216821
WOLFSSL_LOCAL int wolfSSL_GetMaxFragSize(WOLFSSL* ssl);
@@ -7106,8 +7106,8 @@ typedef struct CipherSuiteInfo {
71067106
byte flags;
71077107
} CipherSuiteInfo;
71087108

7109-
WOLFSSL_LOCAL const CipherSuiteInfo* GetCipherNames(void);
7110-
WOLFSSL_LOCAL int GetCipherNamesSize(void);
7109+
WOLFSSL_TEST_VIS const CipherSuiteInfo* GetCipherNames(void);
7110+
WOLFSSL_TEST_VIS int GetCipherNamesSize(void);
71117111
WOLFSSL_LOCAL const char* GetCipherNameInternal(byte cipherSuite0, byte cipherSuite);
71127112
#if defined(OPENSSL_ALL) || defined(WOLFSSL_QT)
71137113
/* used in wolfSSL_sk_CIPHER_description */
@@ -7187,7 +7187,7 @@ WOLFSSL_LOCAL int InitHandshakeHashesAndCopy(WOLFSSL* ssl, HS_Hashes* source,
71877187
#ifndef WOLFSSL_NO_TLS12
71887188
WOLFSSL_LOCAL void FreeBuildMsgArgs(WOLFSSL* ssl, BuildMsgArgs* args);
71897189
#endif
7190-
WOLFSSL_LOCAL int BuildMessage(WOLFSSL* ssl, byte* output, int outSz,
7190+
WOLFSSL_TEST_VIS int BuildMessage(WOLFSSL* ssl, byte* output, int outSz,
71917191
const byte* input, int inSz, int type, int hashOutput,
71927192
int sizeOnly, int asyncOkay, int epochOrder);
71937193

0 commit comments

Comments
 (0)