@@ -1096,7 +1096,7 @@ WOLFSSH* SshInit(WOLFSSH* ssh, WOLFSSH_CTX* ctx)
10961096 ssh->fs = NULL;
10971097 ssh->acceptState = ACCEPT_BEGIN;
10981098 ssh->clientState = CLIENT_BEGIN;
1099- ssh->isKeying = 1;
1099+ ssh->isKeying = 0; /* initial state of not keying yet */
11001100 ssh->authId = ID_USERAUTH_PUBLICKEY;
11011101 ssh->supportedAuth[0] = ID_USERAUTH_PUBLICKEY;
11021102 ssh->supportedAuth[1] = ID_USERAUTH_PASSWORD;
@@ -4058,6 +4058,15 @@ static int DoKexInit(WOLFSSH* ssh, byte* buf, word32 len, word32* idx)
40584058 ret = WS_BAD_ARGUMENT;
40594059 }
40604060
4061+ if (ret == WS_SUCCESS) {
4062+ /* Check if already in process of keying and error out if so. */
4063+ if (ssh->isKeying & WOLFSSH_PEER_IS_KEYING) {
4064+ WLOG(WS_LOG_ERROR,
4065+ "Already in keying process and got KEX init");
4066+ ret = WS_INVALID_STATE_E;
4067+ }
4068+ }
4069+
40614070 /*
40624071 * I don't need to save what the client sends here. I should decode
40634072 * each list into a local array of IDs, and pick the one the peer is
@@ -4067,6 +4076,8 @@ static int DoKexInit(WOLFSSH* ssh, byte* buf, word32 len, word32* idx)
40674076 */
40684077
40694078 if (ret == WS_SUCCESS) {
4079+ /* Set peer is keying flag after receiving SSH_MSG_KEX_INIT */
4080+ ssh->isKeying |= WOLFSSH_PEER_IS_KEYING;
40704081 if (ssh->handshake == NULL) {
40714082 ssh->handshake = HandshakeInfoNew(ssh->ctx->heap);
40724083 if (ssh->handshake == NULL) {
@@ -5881,6 +5892,13 @@ static int DoNewKeys(WOLFSSH* ssh, byte* buf, word32 len, word32* idx)
58815892 if (ssh == NULL || ssh->handshake == NULL)
58825893 ret = WS_BAD_ARGUMENT;
58835894
5895+ if (ret == WS_SUCCESS) {
5896+ if (ssh->isKeying & WOLFSSH_SELF_IS_KEYING) {
5897+ WLOG(WS_LOG_ERROR, "Keying failed");
5898+ ret = WS_INVALID_STATE_E;
5899+ }
5900+ }
5901+
58845902 if (ret == WS_SUCCESS) {
58855903 ssh->peerEncryptId = ssh->handshake->encryptId;
58865904 ssh->peerMacId = ssh->handshake->macId;
@@ -5941,7 +5959,9 @@ static int DoNewKeys(WOLFSSH* ssh, byte* buf, word32 len, word32* idx)
59415959 if (ret == WS_SUCCESS) {
59425960 ssh->rxCount = 0;
59435961 ssh->highwaterFlag = 0;
5944- ssh->isKeying = 0;
5962+
5963+ /* Clear peer is keying flag */
5964+ ssh->isKeying &= ~WOLFSSH_PEER_IS_KEYING;
59455965 HandshakeInfoFree(ssh->handshake, ssh->ctx->heap);
59465966 ssh->handshake = NULL;
59475967 WLOG(WS_LOG_DEBUG, "Keying completed");
@@ -9405,7 +9425,7 @@ static int DoPacket(WOLFSSH* ssh, byte* bufferConsumed)
94059425 case MSGID_KEXINIT:
94069426 WLOG(WS_LOG_DEBUG, "Decoding MSGID_KEXINIT");
94079427 ret = DoKexInit(ssh, buf + idx, payloadSz, &payloadIdx);
9408- if (ssh->isKeying == 1 &&
9428+ if (ssh->isKeying &&
94099429 ssh->connectState == CONNECT_SERVER_CHANNEL_REQUEST_DONE) {
94109430 if (ssh->handshake->kexId == ID_DH_GEX_SHA256) {
94119431#if !defined(WOLFSSH_NO_DH) && !defined(WOLFSSH_NO_DH_GEX_SHA256)
@@ -10501,7 +10521,8 @@ int SendKexInit(WOLFSSH* ssh)
1050110521 }
1050210522
1050310523 if (ret == WS_SUCCESS) {
10504- ssh->isKeying = 1;
10524+ /* Set self is keying flag since we started sending the KEX init msg */
10525+ ssh->isKeying |= WOLFSSH_SELF_IS_KEYING;
1050510526 if (ssh->handshake == NULL) {
1050610527 ssh->handshake = HandshakeInfoNew(ssh->ctx->heap);
1050710528 if (ssh->handshake == NULL) {
@@ -12534,9 +12555,13 @@ int SendNewKeys(WOLFSSH* ssh)
1253412555 ssh->txCount = 0;
1253512556 }
1253612557
12537- if (ret == WS_SUCCESS)
12558+ if (ret == WS_SUCCESS) {
1253812559 ret = wolfSSH_SendPacket(ssh);
1253912560
12561+ /* Clear self is keying flag */
12562+ ssh->isKeying &= ~WOLFSSH_SELF_IS_KEYING;
12563+ }
12564+
1254012565 WLOG(WS_LOG_DEBUG, "Leaving SendNewKeys(), ret = %d", ret);
1254112566 return ret;
1254212567}
0 commit comments