From 94aacd8bd69b7bfafce14fbe7639274e11d92d51 Mon Sep 17 00:00:00 2001
From: Adam Stylinski
Date: Mon, 23 Sep 2024 18:26:04 -0400
Subject: [PATCH] Try to simplify the inflate loop by collapsing most cases to
 chunksets

---
 arch/arm/arm_functions.h         |   2 +-
 arch/generic/generic_functions.h |   2 +-
 arch/power/power_functions.h     |   2 +-
 arch/riscv/riscv_functions.h     |   2 +-
 arch/x86/chunkset_avx2.c         |  23 +------
 arch/x86/x86_functions.h         |   6 +-
 chunkset_tpl.h                   | 108 ++++++++++++++++++-------------
 functable.c                      |   4 +-
 functable.h                      |   2 +-
 inffast_tpl.h                    |  22 ++++---
 inflate.c                        |   2 +-
 inflate_p.h                      |   2 +-
 12 files changed, 92 insertions(+), 85 deletions(-)

diff --git a/arch/arm/arm_functions.h b/arch/arm/arm_functions.h
index 61c68271..2b38e665 100644
--- a/arch/arm/arm_functions.h
+++ b/arch/arm/arm_functions.h
@@ -8,7 +8,7 @@
 #ifdef ARM_NEON
 uint32_t adler32_neon(uint32_t adler, const uint8_t *buf, size_t len);
 uint32_t chunksize_neon(void);
-uint8_t* chunkmemset_safe_neon(uint8_t *out, unsigned dist, unsigned len, unsigned left);
+uint8_t* chunkmemset_safe_neon(uint8_t *out, uint8_t *from, unsigned len, unsigned left);
 
 #  ifdef HAVE_BUILTIN_CTZLL
 uint32_t compare256_neon(const uint8_t *src0, const uint8_t *src1);
diff --git a/arch/generic/generic_functions.h b/arch/generic/generic_functions.h
index 997dd4d0..e243f326 100644
--- a/arch/generic/generic_functions.h
+++ b/arch/generic/generic_functions.h
@@ -22,7 +22,7 @@ typedef uint32_t (*crc32_func)(uint32_t crc32, const uint8_t *buf, size_t len);
 
 uint32_t adler32_c(uint32_t adler, const uint8_t *buf, size_t len);
 uint32_t chunksize_c(void);
-uint8_t* chunkmemset_safe_c(uint8_t *out, unsigned dist, unsigned len, unsigned left);
+uint8_t* chunkmemset_safe_c(uint8_t *out, uint8_t *from, unsigned len, unsigned left);
 void inflate_fast_c(PREFIX3(stream) *strm, uint32_t start);
 
 uint32_t PREFIX(crc32_braid)(uint32_t crc, const uint8_t *buf, size_t len);
diff --git a/arch/power/power_functions.h b/arch/power/power_functions.h
index cb6b7650..44d36af8 100644
--- a/arch/power/power_functions.h
+++ b/arch/power/power_functions.h
@@ -15,7 +15,7 @@ void slide_hash_vmx(deflate_state *s);
 #ifdef POWER8_VSX
 uint32_t adler32_power8(uint32_t adler, const uint8_t *buf, size_t len);
 uint32_t chunksize_power8(void);
-uint8_t* chunkmemset_safe_power8(uint8_t *out, unsigned dist, unsigned len, unsigned left);
+uint8_t* chunkmemset_safe_power8(uint8_t *out, uint8_t *from, unsigned len, unsigned left);
 uint32_t crc32_power8(uint32_t crc, const uint8_t *buf, size_t len);
 void slide_hash_power8(deflate_state *s);
 void inflate_fast_power8(PREFIX3(stream) *strm, uint32_t start);
diff --git a/arch/riscv/riscv_functions.h b/arch/riscv/riscv_functions.h
index 015b2fbd..1792b9d2 100644
--- a/arch/riscv/riscv_functions.h
+++ b/arch/riscv/riscv_functions.h
@@ -13,7 +13,7 @@
 uint32_t adler32_rvv(uint32_t adler, const uint8_t *buf, size_t len);
 uint32_t adler32_fold_copy_rvv(uint32_t adler, uint8_t *dst, const uint8_t *src, size_t len);
 uint32_t chunksize_rvv(void);
-uint8_t* chunkmemset_safe_rvv(uint8_t *out, unsigned dist, unsigned len, unsigned left);
+uint8_t* chunkmemset_safe_rvv(uint8_t *out, uint8_t *from, unsigned len, unsigned left);
 uint32_t compare256_rvv(const uint8_t *src0, const uint8_t *src1);
 
 uint32_t longest_match_rvv(deflate_state *const s, Pos cur_match);
diff --git a/arch/x86/chunkset_avx2.c b/arch/x86/chunkset_avx2.c
index 86cbaaa8..8cc17103 100644
--- a/arch/x86/chunkset_avx2.c
+++ b/arch/x86/chunkset_avx2.c
@@ -15,6 +15,7 @@ typedef __m128i halfchunk_t;
 #define HAVE_CHUNKMEMSET_4
 #define HAVE_CHUNKMEMSET_8
 #define HAVE_CHUNKMEMSET_16
+#define HAVE_CHUNKMEMSET_1
 #define HAVE_CHUNK_MAG
 #define HAVE_HALF_CHUNK
 
@@ -125,24 +126,6 @@ static inline chunk_t GET_CHUNK_MAG(uint8_t *buf, uint32_t *chunk_rem, uint32_t
     return ret_vec;
 }
 
-static inline void halfchunkmemset_2(uint8_t *from, halfchunk_t *chunk) {
-    int16_t tmp;
-    memcpy(&tmp, from, sizeof(tmp));
-    *chunk = _mm_set1_epi16(tmp);
-}
-
-static inline void halfchunkmemset_4(uint8_t *from, halfchunk_t *chunk) {
-    int32_t tmp;
-    memcpy(&tmp, from, sizeof(tmp));
-    *chunk = _mm_set1_epi32(tmp);
-}
-
-static inline void halfchunkmemset_8(uint8_t *from, halfchunk_t *chunk) {
-    int64_t tmp;
-    memcpy(&tmp, from, sizeof(tmp));
-    *chunk = _mm_set1_epi64x(tmp);
-}
-
 static inline void loadhalfchunk(uint8_t const *s, halfchunk_t *chunk) {
     *chunk = _mm_loadu_si128((__m128i *)s);
 }
@@ -151,10 +134,10 @@ static inline void storehalfchunk(uint8_t *out, halfchunk_t *chunk) {
     _mm_storeu_si128((__m128i *)out, *chunk);
 }
 
-static inline chunk_t halfchunk2whole(halfchunk_t chunk) {
+static inline chunk_t halfchunk2whole(halfchunk_t *chunk) {
     /* We zero extend mostly to appease some memory sanitizers. These bytes are ultimately
      * unlikely to be actually written or read from */
-    return _mm256_zextsi128_si256(chunk);
+    return _mm256_zextsi128_si256(*chunk);
 }
 
 static inline halfchunk_t GET_HALFCHUNK_MAG(uint8_t *buf, uint32_t *chunk_rem, uint32_t dist) {
diff --git a/arch/x86/x86_functions.h b/arch/x86/x86_functions.h
index 5aa9b317..5f8fcf63 100644
--- a/arch/x86/x86_functions.h
+++ b/arch/x86/x86_functions.h
@@ -8,7 +8,7 @@
 #ifdef X86_SSE2
 uint32_t chunksize_sse2(void);
-uint8_t* chunkmemset_safe_sse2(uint8_t *out, unsigned dist, unsigned len, unsigned left);
+uint8_t* chunkmemset_safe_sse2(uint8_t *out, uint8_t *from, unsigned len, unsigned left);
 
 #  ifdef HAVE_BUILTIN_CTZ
 uint32_t compare256_sse2(const uint8_t *src0, const uint8_t *src1);
@@ -21,7 +21,7 @@ uint8_t* chunkmemset_safe_sse2(uint8_t *out, unsigned dist, unsigned len, unsign
 
 #ifdef X86_SSSE3
 uint32_t adler32_ssse3(uint32_t adler, const uint8_t *buf, size_t len);
-uint8_t* chunkmemset_safe_ssse3(uint8_t *out, unsigned dist, unsigned len, unsigned left);
+uint8_t* chunkmemset_safe_ssse3(uint8_t *out, uint8_t *from, unsigned len, unsigned left);
 void inflate_fast_ssse3(PREFIX3(stream) *strm, uint32_t start);
 #endif
 
@@ -33,7 +33,7 @@ uint32_t adler32_fold_copy_sse42(uint32_t adler, uint8_t *dst, const uint8_t *sr
 uint32_t adler32_avx2(uint32_t adler, const uint8_t *buf, size_t len);
 uint32_t adler32_fold_copy_avx2(uint32_t adler, uint8_t *dst, const uint8_t *src, size_t len);
 uint32_t chunksize_avx2(void);
-uint8_t* chunkmemset_safe_avx2(uint8_t *out, unsigned dist, unsigned len, unsigned left);
+uint8_t* chunkmemset_safe_avx2(uint8_t *out, uint8_t *from, unsigned len, unsigned left);
 
 #  ifdef HAVE_BUILTIN_CTZ
 uint32_t compare256_avx2(const uint8_t *src0, const uint8_t *src1);
diff --git a/chunkset_tpl.h b/chunkset_tpl.h
index 9330e804..fc9f755e 100644
--- a/chunkset_tpl.h
+++ b/chunkset_tpl.h
@@ -4,6 +4,7 @@
 
 #include "zbuild.h"
 #include <stdint.h>
+#include <stdlib.h>
 
 /* Returns the chunk size */
 Z_INTERNAL uint32_t CHUNKSIZE(void) {
@@ -69,18 +70,18 @@ static inline uint8_t* CHUNKUNROLL(uint8_t *out, unsigned *dist, unsigned *len)
 static inline chunk_t GET_CHUNK_MAG(uint8_t *buf, uint32_t *chunk_rem, uint32_t dist) {
     /* This code takes string of length dist from "from" and repeats
      * it for as many times as can fit in a chunk_t (vector register) */
-    uint32_t cpy_dist;
-    uint32_t bytes_remaining = sizeof(chunk_t);
+    uint64_t cpy_dist;
+    uint64_t bytes_remaining = sizeof(chunk_t);
     chunk_t chunk_load;
     uint8_t *cur_chunk = (uint8_t *)&chunk_load;
     while (bytes_remaining) {
         cpy_dist = MIN(dist, bytes_remaining);
-        memcpy(cur_chunk, buf, cpy_dist);
+        memcpy(cur_chunk, buf, (size_t)cpy_dist);
         bytes_remaining -= cpy_dist;
         cur_chunk += cpy_dist;
         /* This allows us to bypass an expensive integer division since we're effectively
          * counting in this loop, anyway */
-        *chunk_rem = cpy_dist;
+        *chunk_rem = (uint32_t)cpy_dist;
     }
 
     return chunk_load;
@@ -109,21 +110,33 @@ static inline uint8_t* HALFCHUNKCOPY(uint8_t *out, uint8_t const *from, unsigned
 
 /* Copy DIST bytes from OUT - DIST into OUT + DIST * k, for 0 <= k < LEN/DIST. Return OUT + LEN. */
-static inline uint8_t* CHUNKMEMSET(uint8_t *out, unsigned dist, unsigned len) {
+static inline uint8_t* CHUNKMEMSET(uint8_t *out, uint8_t *from, unsigned len) {
     /* Debug performance related issues when len < sizeof(uint64_t):
        Assert(len >= sizeof(uint64_t), "chunkmemset should be called on larger chunks"); */
-    Assert(dist > 0, "chunkmemset cannot have a distance 0");
+    Assert(from != out, "chunkmemset cannot have a distance 0");
 
-    uint8_t *from = out - dist;
     chunk_t chunk_load;
     uint32_t chunk_mod = 0;
     uint32_t adv_amount;
+    int64_t sdist = out - from;
+    uint64_t dist = llabs(sdist);
+
+    /* We are supporting the case for when we are reading bytes from ahead in the buffer.
+     * We now have to handle this, though it wasn't _quite_ clear if this rare circumstance
+     * always needed to be handled here or if we're just now seeing it because we are
+     * dispatching to this function more often. */
+    if (sdist < 0 && dist < len) {
+        /* Here the memmove semantics match perfectly, as when this happens we are
+         * effectively sliding down the contents of memory by dist bytes */
+        memmove(out, from, len);
+        return out + len;
+    }
 
     if (dist == 1) {
         memset(out, *from, len);
         return out + len;
-    } else if (dist > sizeof(chunk_t)) {
-        return CHUNKCOPY(out, out - dist, len);
+    } else if (dist >= sizeof(chunk_t)) {
+        return CHUNKCOPY(out, from, len);
     }
 
     /* Only AVX2 as there's 128 bit vectors and 256 bit. We allow for shorter vector
@@ -135,33 +148,22 @@ static inline uint8_t* CHUNKMEMSET(uint8_t *out, unsigned dist, unsigned len) {
      * making the code a little smaller. */
 #ifdef HAVE_HALF_CHUNK
     if (len <= sizeof(halfchunk_t)) {
-        if (dist > sizeof(halfchunk_t)) {
-            return HALFCHUNKCOPY(out, out - dist, len);
+        if (dist >= sizeof(halfchunk_t))
+            return HALFCHUNKCOPY(out, from, len);
+
+        if ((dist % 2) != 0 || dist == 6) {
+            halfchunk_t halfchunk_load = GET_HALFCHUNK_MAG(from, &chunk_mod, (unsigned)dist);
+
+            adv_amount = sizeof(halfchunk_t) - chunk_mod;
+            if (len == sizeof(halfchunk_t)) {
+                storehalfchunk(out, &halfchunk_load);
+                len -= adv_amount;
+                out += adv_amount;
+            }
+
+            chunk_load = halfchunk2whole(&halfchunk_load);
+            goto rem_bytes;
         }
-
-        halfchunk_t halfchunk_load;
-
-        if (dist == 2) {
-            halfchunkmemset_2(from, &halfchunk_load);
-        } else if (dist == 4) {
-            halfchunkmemset_4(from, &halfchunk_load);
-        } else if (dist == 8) {
-            halfchunkmemset_8(from, &halfchunk_load);
-        } else if (dist == 16) {
-            loadhalfchunk(from, &halfchunk_load);
-        } else {
-            halfchunk_load = GET_HALFCHUNK_MAG(from, &chunk_mod, dist);
-        }
-
-        adv_amount = sizeof(halfchunk_t) - chunk_mod;
-        while (len >= sizeof(halfchunk_t)) {
-            storehalfchunk(out, &halfchunk_load);
-            len -= adv_amount;
-            out += adv_amount;
-        }
-
-        chunk_load = halfchunk2whole(halfchunk_load);
-        goto rem_bytes;
     }
 #endif
@@ -185,11 +187,7 @@ static inline uint8_t* CHUNKMEMSET(uint8_t *out, unsigned dist, unsigned len) {
         chunkmemset_16(from, &chunk_load);
     } else
 #endif
-    if (dist == sizeof(chunk_t)) {
-        loadchunk(from, &chunk_load);
-    } else {
-        chunk_load = GET_CHUNK_MAG(from, &chunk_mod, dist);
-    }
+    chunk_load = GET_CHUNK_MAG(from, &chunk_mod, (unsigned)dist);
 
     adv_amount = sizeof(chunk_t) - chunk_mod;
@@ -221,7 +219,7 @@ rem_bytes:
     return out;
 }
 
-Z_INTERNAL uint8_t* CHUNKMEMSET_SAFE(uint8_t *out, unsigned dist, unsigned len, unsigned left) {
+Z_INTERNAL uint8_t* CHUNKMEMSET_SAFE(uint8_t *out, uint8_t *from, unsigned len, unsigned left) {
 #if !defined(UNALIGNED64_OK)
 #  if !defined(UNALIGNED_OK)
     static const uint32_t align_mask = 7;
@@ -231,7 +229,7 @@ Z_INTERNAL uint8_t* CHUNKMEMSET_SAFE(uint8_t *out, unsigned dist, unsigned len,
 #endif
 
     len = MIN(len, left);
-    uint8_t *from = out - dist;
+
 #if !defined(UNALIGNED64_OK)
     while (((uintptr_t)out & align_mask) && (len > 0)) {
         *out++ = *from++;
@@ -239,15 +237,37 @@ Z_INTERNAL uint8_t* CHUNKMEMSET_SAFE(uint8_t *out, unsigned dist, unsigned len,
         --left;
     }
 #endif
-    if (left < (unsigned)(3 * sizeof(chunk_t))) {
+    if (UNLIKELY(left < sizeof(chunk_t))) {
         while (len > 0) {
             *out++ = *from++;
             --len;
         }
+        return out;
     }
+
     if (len)
-        return CHUNKMEMSET(out, dist, len);
+        out = CHUNKMEMSET(out, from, len);
 
     return out;
 }
+
+static inline uint8_t *CHUNKCOPY_SAFE(uint8_t *out, uint8_t *from, unsigned len, uint8_t *safe)
+{
+    if (out == from)
+        return out + len;
+
+    uint64_t safelen = (safe - out);
+    len = MIN(len, (unsigned)safelen);
+
+    uint64_t from_dist = (uint64_t)llabs(safe - from);
+    if (UNLIKELY(from_dist < sizeof(chunk_t) || safelen < sizeof(chunk_t))) {
+        while (len--) {
+            *out++ = *from++;
+        }
+
+        return out;
+    }
+
+    return CHUNKMEMSET(out, from, len);
+}
diff --git a/functable.c b/functable.c
index dd8f7731..832a57e7 100644
--- a/functable.c
+++ b/functable.c
@@ -273,9 +273,9 @@ static uint32_t adler32_fold_copy_stub(uint32_t adler, uint8_t* dst, const uint8
     return functable.adler32_fold_copy(adler, dst, src, len);
 }
 
-static uint8_t* chunkmemset_safe_stub(uint8_t* out, unsigned dist, unsigned len, unsigned left) {
+static uint8_t* chunkmemset_safe_stub(uint8_t* out, uint8_t *from, unsigned len, unsigned left) {
     init_functable();
-    return functable.chunkmemset_safe(out, dist, len, left);
+    return functable.chunkmemset_safe(out, from, len, left);
 }
 
 static uint32_t chunksize_stub(void) {
diff --git a/functable.h b/functable.h
index 173a030c..83dda880 100644
--- a/functable.h
+++ b/functable.h
@@ -27,7 +27,7 @@ struct functable_s {
     void     (* force_init)         (void);
     uint32_t (* adler32)            (uint32_t adler, const uint8_t *buf, size_t len);
     uint32_t (* adler32_fold_copy)  (uint32_t adler, uint8_t *dst, const uint8_t *src, size_t len);
-    uint8_t* (* chunkmemset_safe)   (uint8_t *out, unsigned dist, unsigned len, unsigned left);
+    uint8_t* (* chunkmemset_safe)   (uint8_t *out, uint8_t *from, unsigned len, unsigned left);
     uint32_t (* chunksize)          (void);
     uint32_t (* compare256)         (const uint8_t *src0, const uint8_t *src1);
     uint32_t (* crc32)              (uint32_t crc, const uint8_t *buf, size_t len);
diff --git a/inffast_tpl.h b/inffast_tpl.h
index 23a6abd8..afa5e04e 100644
--- a/inffast_tpl.h
+++ b/inffast_tpl.h
@@ -235,7 +235,7 @@ void Z_INTERNAL INFLATE_FAST(PREFIX3(stream) *strm, uint32_t start) {
                     from += wsize - op;
                     if (op < len) {         /* some from end of window */
                         len -= op;
-                        out = chunkcopy_safe(out, from, op, safe);
+                        out = CHUNKCOPY_SAFE(out, from, op, safe);
                         from = window;      /* more from start of window */
                         op = wnext;
                         /* This (rare) case can create a situation where
@@ -245,19 +245,23 @@ void Z_INTERNAL INFLATE_FAST(PREFIX3(stream) *strm, uint32_t start) {
                     }
                     if (op < len) {         /* still need some from output */
                         len -= op;
-                        out = chunkcopy_safe(out, from, op, safe);
-                        if (!extra_safe)
+                        if (!extra_safe) {
+                            out = CHUNKCOPY_SAFE(out, from, op, safe);
                             out = CHUNKUNROLL(out, &dist, &len);
-                        out = chunkcopy_safe(out, out - dist, len, safe);
+                            out = CHUNKCOPY_SAFE(out, out - dist, len, safe);
+                        } else {
+                            out = chunkcopy_safe(out, from, op, safe);
+                            out = chunkcopy_safe(out, out - dist, len, safe);
+                        }
                     } else {
-                        out = chunkcopy_safe(out, from, len, safe);
+                        if (!extra_safe)
+                            out = CHUNKCOPY_SAFE(out, from, len, safe);
+                        else
+                            out = chunkcopy_safe(out, from, len, safe);
                     }
                 } else if (extra_safe) {
                     /* Whole reference is in range of current output. */
-                    if (dist >= len || dist >= state->chunksize)
                         out = chunkcopy_safe(out, out - dist, len, safe);
-                    else
-                        out = CHUNKMEMSET_SAFE(out, dist, len, (unsigned)((safe - out)));
                 } else {
                     /* Whole reference is in range of current output. No range checks are
                        necessary because we start with room for at least 258 bytes of output,
@@ -267,7 +271,7 @@ void Z_INTERNAL INFLATE_FAST(PREFIX3(stream) *strm, uint32_t start) {
                     if (dist >= len || dist >= state->chunksize)
                         out = CHUNKCOPY(out, out - dist, len);
                     else
-                        out = CHUNKMEMSET(out, dist, len);
+                        out = CHUNKMEMSET(out, out - dist, len);
                 }
             } else if ((op & 64) == 0) {
                 /* 2nd level distance code */
                 here = dcode + here->val + BITS(op);
diff --git a/inflate.c b/inflate.c
index 5a89a4bb..d2ffdcb3 100644
--- a/inflate.c
+++ b/inflate.c
@@ -1090,7 +1090,7 @@ int32_t Z_EXPORT PREFIX(inflate)(PREFIX3(stream) *strm, int32_t flush) {
             } else {
                 copy = MIN(state->length, left);
 
-                put = FUNCTABLE_CALL(chunkmemset_safe)(put, state->offset, copy, left);
+                put = FUNCTABLE_CALL(chunkmemset_safe)(put, put - state->offset, copy, left);
             }
             left -= copy;
             state->length -= copy;
diff --git a/inflate_p.h b/inflate_p.h
index 59ad6d17..54c8dec9 100644
--- a/inflate_p.h
+++ b/inflate_p.h
@@ -150,7 +150,7 @@ static inline uint64_t load_64_bits(const unsigned char *in, unsigned bits) {
 
 /* Behave like chunkcopy, but avoid writing beyond of legal output. */
 static inline uint8_t* chunkcopy_safe(uint8_t *out, uint8_t *from, uint64_t len, uint8_t *safe) {
-    uint64_t safelen = (safe - out) + 1;
+    uint64_t safelen = safe - out;
     len = MIN(len, safelen);
     int32_t olap_src = from >= out && from < out + len;
     int32_t olap_dst = out >= from && out < from + len;
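--
Aside for readers of this patch (not part of the commit): a minimal standalone
sketch of the overlap classification the reworked CHUNKMEMSET performs now that
it receives a "from" pointer instead of a "dist" value. The helper name
memset_like_copy and the scalar fallbacks are illustrative stand-ins for the
vectorized chunk paths in chunkset_tpl.h.

/* Classify an LZ77-style copy the way the new CHUNKMEMSET preamble does. */
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

static uint8_t *memset_like_copy(uint8_t *out, uint8_t *from, unsigned len) {
    ptrdiff_t sdist = out - from;            /* signed distance, as in the patch */
    uint64_t dist = (uint64_t)llabs((long long)sdist);

    if (sdist < 0 && dist < len) {
        /* Source starts ahead of the destination and the ranges overlap:
         * memmove's "slide the bytes down" semantics are exactly right. */
        memmove(out, from, len);
    } else if (dist == 1) {
        /* A one-byte period is just a byte fill. */
        memset(out, *from, len);
    } else if (dist >= len) {
        /* No backward overlap: a plain copy is safe (the real code uses
         * chunked vector stores here instead). */
        memcpy(out, from, len);
    } else {
        /* Overlapping back-reference: replicate the period byte by byte,
         * the case GET_CHUNK_MAG vectorizes in the template. */
        while (len--)
            *out++ = *from++;
        return out;
    }
    return out + len;
}

int main(void) {
    uint8_t buf[16] = "abc";
    /* LZ77-style: repeat the 3-byte period "abc" for 9 more bytes. */
    memset_like_copy(buf + 3, buf, 9);
    printf("%.12s\n", buf);                  /* prints "abcabcabcabc" */
    return 0;
}

With out - from == 3 and len == 9 the call lands in the overlapped branch and
reproduces the period, which is the behavior the patch now reaches through a
single pointer-based dispatch point instead of separate dist-based entry paths.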