contrib/zstd/lib/common/entropy_common.c in extzstd-0.2 vs contrib/zstd/lib/common/entropy_common.c in extzstd-0.3
- old
+ new
@@ -70,11 +70,25 @@
U32 bitStream;
int bitCount;
unsigned charnum = 0;
int previous0 = 0;
- if (hbSize < 4) return ERROR(srcSize_wrong);
+ if (hbSize < 4) {
+ /* This function only works when hbSize >= 4 */
+ char buffer[4];
+ memset(buffer, 0, sizeof(buffer));
+ memcpy(buffer, headerBuffer, hbSize);
+ { size_t const countSize = FSE_readNCount(normalizedCounter, maxSVPtr, tableLogPtr,
+ buffer, sizeof(buffer));
+ if (FSE_isError(countSize)) return countSize;
+ if (countSize > hbSize) return ERROR(corruption_detected);
+ return countSize;
+ } }
+ assert(hbSize >= 4);
+
+ /* init */
+ memset(normalizedCounter, 0, (*maxSVPtr+1) * sizeof(normalizedCounter[0])); /* all symbols not present in NCount have a frequency of 0 */
bitStream = MEM_readLE32(ip);
nbBits = (bitStream & 0xF) + FSE_MIN_TABLELOG; /* extract tableLog */
if (nbBits > FSE_TABLELOG_ABSOLUTE_MAX) return ERROR(tableLog_tooLarge);
bitStream >>= 4;
bitCount = 4;
@@ -103,9 +117,10 @@
n0 += bitStream & 3;
bitCount += 2;
if (n0 > *maxSVPtr) return ERROR(maxSymbolValue_tooSmall);
while (charnum < n0) normalizedCounter[charnum++] = 0;
if ((ip <= iend-7) || (ip + (bitCount>>3) <= iend-4)) {
+ assert((bitCount >> 3) <= 3); /* For first condition to work */
ip += bitCount>>3;
bitCount &= 7;
bitStream = MEM_readLE32(ip) >> bitCount;
} else {
bitStream >>= 2;