Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 94 additions & 9 deletions src/WebSocketProtocol.h
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,24 @@ struct WebSocketProtocol {
}
}

/* Unmasks exactly `length` bytes at `src` in place; bytes at src + length and beyond are never touched.
 * `maskInt` is XORed onto every whole 8-byte word and `mask` (4 bytes) onto the 0-7 trailing bytes,
 * so maskInt is expected to hold the 4 mask bytes repeated twice.
 * Words are loaded and stored with memcpy instead of through a uint64_t pointer: src is a plain byte
 * pointer with no alignment guarantee, and dereferencing a misaligned uint64_t* is undefined behaviour. */
static inline void unmaskPrecise8(char *src, uint64_t maskInt, char* mask, unsigned int length) {
    /* Number of bytes covered by whole 8-byte words */
    const size_t wordBytes = length - length % 8;
    for (size_t i = 0; i < wordBytes; i += 8) {
        uint64_t word;
        memcpy(&word, src + i, 8);
        word ^= maskInt;
        memcpy(src + i, &word, 8);
    }
    /* wordBytes is a multiple of 4, so indexing the mask with the absolute offset keeps it in phase */
    for (size_t i = wordBytes; i < length; i++) {
        src[i] = src[i] ^ mask[i % 4];
    }
}

/* DESTINATION = 6 makes this not SIMD, DESTINATION = 4 is with SIMD but we don't want that for short messages */
template <int DESTINATION>
static inline void unmaskImprecise4(char *src, uint32_t mask, unsigned int length) {
Expand Down Expand Up @@ -322,6 +340,51 @@ struct WebSocketProtocol {
}
}

/* Unmasks exactly `length` payload bytes at `src` in place, using the 4 bytes directly in front
 * of the payload (src[-4] .. src[-1]) as the mask. Nothing at src + length or beyond is modified.
 * HEADER_SIZE == 6 processes 4-byte words; any other header size processes 8-byte words.
 * Words are moved with memcpy rather than through uint32_t / uint64_t pointers: src is a byte pointer
 * at an arbitrary offset, and dereferencing a misaligned wide pointer is undefined behaviour. */
template <int HEADER_SIZE>
static inline void unmaskPreciseCopyMask(char *src, unsigned int length) {
    /* Copy the mask out of the header, repeated twice so it can also fill a 64-bit word */
    const char mask[8] = {src[-4], src[-3], src[-2], src[-1], src[-4], src[-3], src[-2], src[-1]};

    size_t wordBytes;
    if constexpr (HEADER_SIZE != 6) {
        uint64_t maskInt;
        memcpy(&maskInt, mask, 8);

        wordBytes = length - length % 8;
        for (size_t i = 0; i < wordBytes; i += 8) {
            uint64_t word;
            memcpy(&word, src + i, 8);
            word ^= maskInt;
            memcpy(src + i, &word, 8);
        }
    } else {
        uint32_t maskInt;
        memcpy(&maskInt, mask, 4);

        wordBytes = length - length % 4;
        for (size_t i = 0; i < wordBytes; i += 4) {
            uint32_t word;
            memcpy(&word, src + i, 4);
            word ^= maskInt;
            memcpy(src + i, &word, 4);
        }
    }

    /* wordBytes is a multiple of 4 in both branches, so the absolute offset keeps the mask in phase */
    for (size_t i = wordBytes; i < length; i++) {
        src[i] = src[i] ^ mask[i % 4];
    }
}

static inline void rotateMask(unsigned int offset, char *mask) {
char originalMask[4] = {mask[0], mask[1], mask[2], mask[3]};
mask[(0 + offset) % 4] = originalMask[0];
Expand All @@ -339,6 +402,28 @@ struct WebSocketProtocol {
}
}

/* Unmasks exactly `length` bytes at `src` in place with the 4-byte `mask`, starting at mask[0].
 * Nothing at src + length or beyond is modified and `mask` itself is only read.
 * Whole 8-byte words are moved with memcpy rather than through a uint64_t pointer: src is a byte
 * pointer with no alignment guarantee, and dereferencing a misaligned uint64_t* is undefined behaviour. */
static inline void unmaskPreciseInplace(char *src, size_t length, char *mask) {
    /* Repeat the 4 mask bytes twice to fill a 64-bit word */
    uint64_t maskInt;
    memcpy(&maskInt, mask, 4);
    memcpy(((char *)&maskInt) + 4, mask, 4);

    /* Number of bytes covered by whole 8-byte words */
    const size_t wordBytes = length - length % 8;
    for (size_t i = 0; i < wordBytes; i += 8) {
        uint64_t word;
        memcpy(&word, src + i, 8);
        word ^= maskInt;
        memcpy(src + i, &word, 8);
    }
    /* wordBytes is a multiple of 4, so the absolute offset keeps the mask in phase */
    for (size_t i = wordBytes; i < length; i++) {
        src[i] = src[i] ^ mask[i % 4];
    }
}

template <unsigned int MESSAGE_HEADER, typename T>
static inline bool consumeMessage(T payLength, char *&src, unsigned int &length, WebSocketState<isServer> *wState, void *user) {
if (getOpCode(src)) {
Expand All @@ -361,9 +446,9 @@ struct WebSocketProtocol {
if (payLength + MESSAGE_HEADER <= length) {
bool fin = isFin(src);
if (isServer) {
/* This guy can never be assumed to be perfectly aligned since we can get multiple messages in one read */
unmaskImpreciseCopyMask<MESSAGE_HEADER>(src + MESSAGE_HEADER, (unsigned int) payLength);
if (Impl::handleFragment(src, payLength, 0, wState->state.opCode[wState->state.opStack], fin, wState, user)) {
/* Use the precise unmask, which is better than the imprecise one because it does not affect the data remaining after this message */
unmaskPreciseCopyMask<MESSAGE_HEADER>(src + MESSAGE_HEADER, (unsigned int) payLength);
if (Impl::handleFragment(src + MESSAGE_HEADER, payLength, 0, wState->state.opCode[wState->state.opStack], fin, wState, user)) {
return true;
}
} else {
Expand All @@ -387,10 +472,11 @@ struct WebSocketProtocol {
bool fin = isFin(src);
if constexpr (isServer) {
memcpy(wState->mask, src + MESSAGE_HEADER - 4, 4);
uint64_t mask;
memcpy(&mask, src + MESSAGE_HEADER - 4, 4);
memcpy(((char *)&mask) + 4, src + MESSAGE_HEADER - 4, 4);
unmaskImprecise8<0>(src + MESSAGE_HEADER, mask, length);
uint64_t maskInt;
memcpy(&maskInt, src + MESSAGE_HEADER - 4, 4);
memcpy(((char *)&maskInt) + 4, src + MESSAGE_HEADER - 4, 4);
char* mask = src + MESSAGE_HEADER - 4;
unmaskPrecise8(src + MESSAGE_HEADER, maskInt, mask, length - MESSAGE_HEADER);
rotateMask(4 - (length - MESSAGE_HEADER) % 4, wState->mask);
}
Impl::handleFragment(src + MESSAGE_HEADER, length - MESSAGE_HEADER, wState->remainingBytes, wState->state.opCode[wState->state.opStack], fin, wState, user);
Expand Down Expand Up @@ -435,8 +521,7 @@ struct WebSocketProtocol {
if /*constexpr*/ (LIBUS_RECV_BUFFER_LENGTH == length) {
unmaskAll(src, wState->mask);
} else {
// Slow path
unmaskInplace(src, src + ((length >> 2) + 1) * 4, wState->mask);
unmaskPreciseInplace(src, length, wState->mask);
}
}
}
Expand Down
Loading