diff --git a/src/WebSocketProtocol.h b/src/WebSocketProtocol.h index 15d57a91c..280d267e2 100644 --- a/src/WebSocketProtocol.h +++ b/src/WebSocketProtocol.h @@ -295,6 +295,24 @@ struct WebSocketProtocol { } } + static inline void unmaskPrecise8(char *src, uint64_t maskInt, char* mask, unsigned int length) { + size_t lengthu64 = length/8; + uint64_t* u64I = (uint64_t*)src; + uint64_t* u64O = (uint64_t*)src; + for(size_t m=0; m 0) { + size_t roffset = length - remain; + char* rI = (src + roffset); + char* rO = (src + roffset); + for (size_t i=0; i static inline void unmaskImprecise4(char *src, uint32_t mask, unsigned int length) { @@ -322,6 +340,51 @@ struct WebSocketProtocol { } } + template + static inline void unmaskPreciseCopyMask(char *src, unsigned int length) { + if constexpr (HEADER_SIZE != 6) { + char mask[8] = {src[-4], src[-3], src[-2], src[-1], src[-4], src[-3], src[-2], src[-1]}; + uint64_t maskInt; + memcpy(&maskInt, mask, 8); + + size_t lengthu64 = length/8; + uint64_t* u64I = (uint64_t*)src; + uint64_t* u64O = (uint64_t*)src; + for(size_t m=0; m 0) { + size_t roffset = length - remain; + char* rI = (src + roffset); + char* rO = (src + roffset); + for (size_t i=0; i 0) { + size_t roffset = length - remain; + char* rI = (src + roffset); + char* rO = (src + roffset); + for (size_t i=0; i 0) { + size_t roffset = length - remain; + char* rI = (src + roffset); + char* rO = (src + roffset); + for (size_t i=0; i static inline bool consumeMessage(T payLength, char *&src, unsigned int &length, WebSocketState *wState, void *user) { if (getOpCode(src)) { @@ -361,9 +446,9 @@ struct WebSocketProtocol { if (payLength + MESSAGE_HEADER <= length) { bool fin = isFin(src); if (isServer) { - /* This guy can never be assumed to be perfectly aligned since we can get multiple messages in one read */ - unmaskImpreciseCopyMask(src + MESSAGE_HEADER, (unsigned int) payLength); - if (Impl::handleFragment(src, payLength, 0, wState->state.opCode[wState->state.opStack], fin, wState, user)) { + /* use precise mask, which is better than unprecise to avoid effect remain data */ + unmaskPreciseCopyMask(src + MESSAGE_HEADER, (unsigned int) payLength); + if (Impl::handleFragment(src + MESSAGE_HEADER, payLength, 0, wState->state.opCode[wState->state.opStack], fin, wState, user)) { return true; } } else { @@ -387,10 +472,11 @@ struct WebSocketProtocol { bool fin = isFin(src); if constexpr (isServer) { memcpy(wState->mask, src + MESSAGE_HEADER - 4, 4); - uint64_t mask; - memcpy(&mask, src + MESSAGE_HEADER - 4, 4); - memcpy(((char *)&mask) + 4, src + MESSAGE_HEADER - 4, 4); - unmaskImprecise8<0>(src + MESSAGE_HEADER, mask, length); + uint64_t maskInt; + memcpy(&maskInt, src + MESSAGE_HEADER - 4, 4); + memcpy(((char *)&maskInt) + 4, src + MESSAGE_HEADER - 4, 4); + char* mask = src + MESSAGE_HEADER - 4; + unmaskPrecise8(src + MESSAGE_HEADER, maskInt, mask, length - MESSAGE_HEADER); rotateMask(4 - (length - MESSAGE_HEADER) % 4, wState->mask); } Impl::handleFragment(src + MESSAGE_HEADER, length - MESSAGE_HEADER, wState->remainingBytes, wState->state.opCode[wState->state.opStack], fin, wState, user); @@ -435,8 +521,7 @@ struct WebSocketProtocol { if /*constexpr*/ (LIBUS_RECV_BUFFER_LENGTH == length) { unmaskAll(src, wState->mask); } else { - // Slow path - unmaskInplace(src, src + ((length >> 2) + 1) * 4, wState->mask); + unmaskPreciseInplace(src, length, wState->mask); } } }