contrib/zstd/lib/common/entropy_common.c in extzstd-0.2 vs contrib/zstd/lib/common/entropy_common.c in extzstd-0.3

- old
+ new

@@ -70,11 +70,25 @@ U32 bitStream; int bitCount; unsigned charnum = 0; int previous0 = 0; - if (hbSize < 4) return ERROR(srcSize_wrong); + if (hbSize < 4) { + /* This function only works when hbSize >= 4 */ + char buffer[4]; + memset(buffer, 0, sizeof(buffer)); + memcpy(buffer, headerBuffer, hbSize); + { size_t const countSize = FSE_readNCount(normalizedCounter, maxSVPtr, tableLogPtr, + buffer, sizeof(buffer)); + if (FSE_isError(countSize)) return countSize; + if (countSize > hbSize) return ERROR(corruption_detected); + return countSize; + } } + assert(hbSize >= 4); + + /* init */ + memset(normalizedCounter, 0, (*maxSVPtr+1) * sizeof(normalizedCounter[0])); /* all symbols not present in NCount have a frequency of 0 */ bitStream = MEM_readLE32(ip); nbBits = (bitStream & 0xF) + FSE_MIN_TABLELOG; /* extract tableLog */ if (nbBits > FSE_TABLELOG_ABSOLUTE_MAX) return ERROR(tableLog_tooLarge); bitStream >>= 4; bitCount = 4; @@ -103,9 +117,10 @@ n0 += bitStream & 3; bitCount += 2; if (n0 > *maxSVPtr) return ERROR(maxSymbolValue_tooSmall); while (charnum < n0) normalizedCounter[charnum++] = 0; if ((ip <= iend-7) || (ip + (bitCount>>3) <= iend-4)) { + assert((bitCount >> 3) <= 3); /* For first condition to work */ ip += bitCount>>3; bitCount &= 7; bitStream = MEM_readLE32(ip) >> bitCount; } else { bitStream >>= 2;