Skip to content

Conversation

@yujincheng08
Copy link

Adapted from https://github.com/lz4/lz4-java

This is just a basic implementation of the LZ4HC algorithm, without much optimization or integration with the original codebase.

Java code for reference
/**
 * Array-based LZ4 "high compression" (HC) block compressor.
 *
 * <p>{@link #compress} walks the input, looks up earlier occurrences of the current 4-byte window
 * through a hash table plus a chain of previous positions, and emits a stream of sequences. Each
 * sequence is: one token byte (literal-run length in the high nibble, match length minus
 * {@link #MIN_MATCH} in the low nibble), optional length-extension bytes, the literal bytes, and a
 * two-byte little-endian match offset. The stream ends with a literals-only sequence.
 *
 * <p>The compressor writes nothing to standard output; failures are reported exclusively through
 * {@link RuntimeException}s.
 */
final class LZ4HCJavaSafeCompressor  {

    // MEMORY_USAGE and NOT_COMPRESSIBLE_DETECTION_LEVEL are not referenced by the code in this class.
    static final int MEMORY_USAGE = 14;
    static final int NOT_COMPRESSIBLE_DETECTION_LEVEL = 6;

    /** Shortest match that is ever encoded; also the number of bytes hashed per position. */
    static final int MIN_MATCH = 4;

    static final int HASH_LOG = MEMORY_USAGE - 2;

    static final int COPY_LENGTH = 8;
    /** Matches are never extended into the last {@code LAST_LITERALS} bytes of the input. */
    static final int LAST_LITERALS = 5;
    /** The main loop stops looking for new matches this many bytes before the end of the input. */
    static final int MF_LIMIT = COPY_LENGTH + MIN_MATCH;

    /** Size of the chain table; candidates at least this far behind the cursor are not examined. */
    static final int MAX_DISTANCE = 1 << 16;

    /** Width and mask of the match-length field (low nibble of the token). */
    static final int ML_BITS = 4;
    static final int ML_MASK = (1 << ML_BITS) - 1;
    /** Width and mask of the literal-run-length field (high nibble of the token). */
    static final int RUN_BITS = 8 - ML_BITS;
    static final int RUN_MASK = (1 << RUN_BITS) - 1;

    static final int HASH_LOG_64K = HASH_LOG + 1;

    /** Number of bits in a hash value, and the resulting size of the head table. */
    static final int HASH_LOG_HC = 15;
    static final int HASH_TABLE_SIZE_HC = 1 << HASH_LOG_HC;
    /** Longest match whose length still fits in the token without extension bytes. */
    static final int OPTIMAL_ML = ML_MASK - 1 + MIN_MATCH;

    /** Upper bound on the number of chain candidates examined per search. */
    private final int maxAttempts;
    final int compressionLevel;

    /**
     * Creates a compressor for the given level.
     *
     * @param compressionLevel search effort; each chain walk examines at most
     *     {@code 1 << (compressionLevel - 1)} candidates. The level is not validated: Java masks
     *     the shift count to five bits, so for example a level of 0 or 32 yields a negative bound
     *     and the candidate loops never run.
     */
    LZ4HCJavaSafeCompressor(int compressionLevel) {
        this.maxAttempts = 1<<(compressionLevel-1);
        this.compressionLevel = compressionLevel;
    }

    /** Maps a 4-byte window value to an index in {@code [0, HASH_TABLE_SIZE_HC)}. */
    static int hashHC(int i) {
        return (i * -1640531535) >>> ((MIN_MATCH * 8) - HASH_LOG_HC);
    }

    static class SafeUtils {
        /** Reads four bytes starting at {@code buf[i]} as a little-endian {@code int}. */
        public static int readInt(byte[] buf, int i) {
            return (buf[i] & 0xFF) | ((buf[i + 1] & 0xFF) << 8) | ((buf[i + 2] & 0xFF) << 16) | ((buf[i + 3] & 0xFF) << 24);
        }
    }

    /** Stateless helpers for comparing input bytes and writing encoded sequences. */
    enum LZ4SafeUtils {
        ;

        /** Returns whether the four bytes at {@code buf[i..i+3]} equal those at {@code buf[j..j+3]}. */
        static boolean readIntEquals(byte[] buf, int i, int j) {
            return buf[i] == buf[j] && buf[i+1] == buf[j+1] && buf[i+2] == buf[j+2] && buf[i+3] == buf[j+3];
        }

        /** Copies exactly eight bytes from {@code src[sOff..]} to {@code dest[dOff..]}. */
        static void copy8Bytes(byte[] src, int sOff, byte[] dest, int dOff) {
            for (int i = 0; i < 8; ++i) {
                dest[dOff + i] = src[sOff + i];
            }
        }

        /**
         * Counts how many consecutive bytes starting at {@code o1} and {@code o2} are equal,
         * stopping as soon as {@code o2} reaches {@code limit}.
         */
        static int commonBytes(byte[] b, int o1, int o2, int limit) {
            int count = 0;
            while (o2 < limit && b[o1++] == b[o2++]) {
                ++count;
            }
            return count;
        }

        /**
         * Counts how many consecutive bytes immediately before {@code o1} and {@code o2} are
         * equal, without moving {@code o1} below {@code l1} or {@code o2} below {@code l2}.
         */
        static int commonBytesBackward(byte[] b, int o1, int o2, int l1, int l2) {
            int count = 0;
            while (o1 > l1 && o2 > l2 && b[--o1] == b[--o2]) {
                ++count;
            }
            return count;
        }

        /**
         * Copies {@code len} bytes rounded up to the next multiple of eight, so up to seven bytes
         * past the requested range are read and written. Callers must leave that much slack.
         *
         * @throws RuntimeException if the rounded-up copy runs past either array; the original
         *     {@link ArrayIndexOutOfBoundsException} is attached as the cause
         */
        static void wildArraycopy(byte[] src, int sOff, byte[] dest, int dOff, int len) {
            try {
                for (int i = 0; i < len; i += 8) {
                    copy8Bytes(src, sOff + i, dest, dOff + i);
                }
            } catch (ArrayIndexOutOfBoundsException e) {
                throw new RuntimeException("Malformed input at offset " + sOff, e);
            }
        }

        /**
         * Writes one sequence: the literals {@code src[anchor..matchOff)} followed by a match of
         * {@code matchLen} bytes located at {@code matchRef}.
         *
         * @param src input buffer
         * @param anchor index of the first literal not yet emitted
         * @param matchOff index where the match starts
         * @param matchRef index of the earlier occurrence being referenced
         * @param matchLen match length in bytes, including the {@link #MIN_MATCH} minimum
         * @param dest output buffer
         * @param dOff index in {@code dest} at which to start writing
         * @param destEnd exclusive end of the writable region of {@code dest}
         * @return the index in {@code dest} just past the written sequence
         * @throws RuntimeException if the sequence does not fit before {@code destEnd}
         */
        static int encodeSequence(byte[] src, int anchor, int matchOff, int matchRef, int matchLen, byte[] dest, int dOff, int destEnd) {
            final int runLen = matchOff - anchor;
            // The token is written last, once both nibbles are known.
            final int tokenOff = dOff++;

            if (dOff + runLen + (2 + 1 + LAST_LITERALS) + (runLen >>> 8) > destEnd) {
                throw new RuntimeException("maxDestLen is too small");
            }

            int token;
            if (runLen >= RUN_MASK) {
                token = (byte) (RUN_MASK << ML_BITS);
                dOff = writeLen(runLen - RUN_MASK, dest, dOff);
            } else {
                token = runLen << ML_BITS;
            }

            // copy literals
            wildArraycopy(src, anchor, dest, dOff, runLen);
            dOff += runLen;

            // encode offset (little-endian, two bytes)
            final int matchDec = matchOff - matchRef;
            dest[dOff++] = (byte) matchDec;
            dest[dOff++] = (byte) (matchDec >>> 8);

            // encode match len
            matchLen -= MIN_MATCH;
            if (dOff + (1 + LAST_LITERALS) + (matchLen >>> 8) > destEnd) {
                throw new RuntimeException("maxDestLen is too small");
            }
            if (matchLen >= ML_MASK) {
                token |= ML_MASK;
                // The match-length overflow is relative to the match-length mask.
                dOff = writeLen(matchLen - ML_MASK, dest, dOff);
            } else {
                token |= matchLen;
            }

            dest[tokenOff] = (byte) token;

            return dOff;
        }

        /**
         * Writes the final, literals-only sequence.
         *
         * @param src input buffer
         * @param sOff index of the first remaining literal
         * @param srcLen number of remaining literals to emit
         * @param dest output buffer
         * @param dOff index in {@code dest} at which to start writing
         * @param destEnd exclusive end of the writable region of {@code dest}
         * @return the index in {@code dest} just past the written bytes
         * @throws RuntimeException if the literals do not fit before {@code destEnd}
         */
        static int lastLiterals(byte[] src, int sOff, int srcLen, byte[] dest, int dOff, int destEnd) {
            final int runLen = srcLen;

            if (dOff + runLen + 1 + (runLen + 255 - RUN_MASK) / 255 > destEnd) {
                throw new RuntimeException("maxDestLen is too small");
            }

            if (runLen >= RUN_MASK) {
                dest[dOff++] = (byte) (RUN_MASK << ML_BITS);
                dOff = writeLen(runLen - RUN_MASK, dest, dOff);
            } else {
                dest[dOff++] = (byte) (runLen << ML_BITS);
            }
            // copy literals
            System.arraycopy(src, sOff, dest, dOff, runLen);
            dOff += runLen;

            return dOff;
        }

        /**
         * Writes a length extension: one {@code 0xFF} byte per full 255, then the remainder.
         *
         * @return the index in {@code dest} just past the written bytes
         */
        static int writeLen(int len, byte[] dest, int dOff) {
            while (len >= 0xFF) {
                dest[dOff++] = (byte) 0xFF;
                len -= 0xFF;
            }
            dest[dOff++] = (byte) len;
            return dOff;
        }
    }


    /** A candidate match: {@code len} bytes at {@code start} that repeat the bytes at {@code ref}. */
    static class Match {
        int start, ref, len;

        /** Drops the first {@code correction} bytes of the match. */
        void fix(int correction) {
            start += correction;
            ref += correction;
            len -= correction;
        }

        /** Returns the exclusive end index of the match in the input. */
        int end() {
            return start + len;
        }
    }

    /** Copies every field of {@code m1} into {@code m2}. */
    static void copyTo(Match m1, Match m2) {
        m2.len = m1.len;
        m2.start = m1.start;
        m2.ref = m1.ref;
    }


    /**
     * Match finder state: {@code hashTable} holds the most recent position seen for each hash
     * value ({@code -1} if none), and {@code chainTable} holds, per position modulo
     * {@link #MAX_DISTANCE}, the distance back to the previous position with the same hash.
     */
    private class HashTable {
        static final int MASK = MAX_DISTANCE - 1;
        /** First input position that has not been inserted into the tables yet. */
        int nextToUpdate;
        /** Lowest input position that may be referenced. */
        private final int base;
        private final int[] hashTable;
        private final short[] chainTable;

        HashTable(int base) {
            this.base = base;
            nextToUpdate = base;
            hashTable = new int[HASH_TABLE_SIZE_HC];
            // Fully qualified so the class compiles without a separate import line.
            java.util.Arrays.fill(hashTable, -1);
            chainTable = new short[MAX_DISTANCE];
        }

        /** Returns the most recent position recorded for the window at {@code bytes[off..off+3]}. */
        private int hashPointer(byte[] bytes, int off) {
            final int v = SafeUtils.readInt(bytes, off);
            return hashPointer(v);
        }

        private int hashPointer(int v) {
            final int h = hashHC(v);
            return hashTable[h];
        }

        /** Follows the chain one step back from {@code off}. */
        private int next(int off) {
            return off - (chainTable[off & MASK] & 0xFFFF);
        }

        private void addHash(byte[] bytes, int off) {
            final int v = SafeUtils.readInt(bytes, off);
            addHash(v, off);
        }

        /** Records {@code off} as the newest position for window value {@code v}. */
        private void addHash(int v, int off) {
            final int h = hashHC(v);
            int delta = off - hashTable[h];
            assert delta > 0 : delta;
            // Distances are stored in 16 bits, so longer ones are capped.
            if (delta >= MAX_DISTANCE) {
                delta = MAX_DISTANCE - 1;
            }
            chainTable[off & MASK] = (short) delta;
            hashTable[h] = off;
        }

        /** Inserts every position from {@code nextToUpdate} up to, but excluding, {@code off}. */
        void insert(int off, byte[] bytes) {
            for (; nextToUpdate < off; ++nextToUpdate) {
                addHash(bytes, nextToUpdate);
            }
        }

        /**
         * Searches for the longest match starting exactly at {@code off}.
         *
         * @param buf input buffer
         * @param off position to match at
         * @param matchLimit exclusive bound on how far a match may extend
         * @param match receives the result; {@code start} and {@code len} are always overwritten
         * @return {@code true} if a match of at least {@link #MIN_MATCH} bytes was found
         */
        boolean insertAndFindBestMatch(byte[] buf, int off, int matchLimit, Match match) {
            match.start = off;
            match.len = 0;
            int delta = 0;
            int repl = 0;

            insert(off, buf);

            int ref = hashPointer(buf, off);

            if (ref >= off - 4 && ref <= off && ref >= base) { // potential repetition
                if (LZ4SafeUtils.readIntEquals(buf, ref, off)) { // confirmed
                    delta = off - ref;
                    repl = match.len = MIN_MATCH + LZ4SafeUtils.commonBytes(buf, ref + MIN_MATCH, off + MIN_MATCH, matchLimit);
                    match.ref = ref;
                }
                ref = next(ref);
            }

            for (int i = 0; i < maxAttempts; ++i) {
                if (ref < Math.max(base, off - MAX_DISTANCE + 1) || ref > off) {
                    break;
                }
                if (LZ4SafeUtils.readIntEquals(buf, ref, off)) {
                    final int matchLen = MIN_MATCH + LZ4SafeUtils.commonBytes(buf, ref + MIN_MATCH, off + MIN_MATCH, matchLimit);
                    if (matchLen > match.len) {
                        match.ref = ref;
                        match.len = matchLen;
                    }
                }
                ref = next(ref);
            }

            // A short-period repetition was found: fill in the tables for the positions it covers
            // directly, and move nextToUpdate past them so insert() does not process them again.
            if (repl != 0) {
                int ptr = off;
                final int end = off + repl - (MIN_MATCH - 1);
                while (ptr < end - delta) {
                    chainTable[ptr & MASK] = (short) delta; // pre load
                    ++ptr;
                }
                do {
                    chainTable[ptr & MASK] = (short) delta;
                    hashTable[hashHC(SafeUtils.readInt(buf, ptr))] = ptr;
                    ++ptr;
                } while (ptr < end);
                nextToUpdate = end;
            }

            return match.len != 0;
        }

        /**
         * Searches for a match through {@code off} that is longer than {@code minLen}, allowing it
         * to be extended backwards as far as {@code startLimit}.
         *
         * @param buf input buffer
         * @param off position whose 4-byte window is looked up
         * @param startLimit lowest position the match may start at
         * @param matchLimit exclusive bound on how far a match may extend
         * @param minLen length that has to be beaten
         * @param match receives the result; {@code len} is always overwritten, {@code start} and
         *     {@code ref} only when a longer match is found
         * @return {@code true} if a match longer than {@code minLen} was found
         */
        boolean insertAndFindWiderMatch(byte[] buf, int off, int startLimit, int matchLimit, int minLen, Match match) {
            match.len = minLen;

            insert(off, buf);

            int ref = hashPointer(buf, off);
            for (int i = 0; i < maxAttempts; ++i) {
                if (ref < Math.max(base, off - MAX_DISTANCE + 1) || ref > off) {
                    break;
                }
                if (LZ4SafeUtils.readIntEquals(buf, ref, off)) {
                    final int matchLenForward = MIN_MATCH +LZ4SafeUtils.commonBytes(buf, ref + MIN_MATCH, off + MIN_MATCH, matchLimit);
                    final int matchLenBackward = LZ4SafeUtils.commonBytesBackward(buf, ref, off, base, startLimit);
                    final int matchLen = matchLenBackward + matchLenForward;
                    if (matchLen > match.len) {
                        match.len = matchLen;
                        match.ref = ref - matchLenBackward;
                        match.start = off - matchLenBackward;
                    }
                }
                ref = next(ref);
            }

            return match.len > minLen;
        }


    }


    /**
     * Compresses {@code src[srcOff .. srcOff + srcLen)} into {@code dest} starting at
     * {@code destOff}.
     *
     * @param src input buffer
     * @param srcOff index of the first input byte
     * @param srcLen number of input bytes
     * @param dest output buffer
     * @param destOff index in {@code dest} at which to start writing
     * @param maxDestLen number of bytes that may be written to {@code dest}
     * @return the number of bytes written to {@code dest}
     * @throws RuntimeException if the compressed data does not fit in {@code maxDestLen} bytes
     */
    public int compress(byte[] src, int srcOff, int srcLen, byte[] dest, int destOff, int maxDestLen) {

        final int srcEnd = srcOff + srcLen;
        final int destEnd = destOff + maxDestLen;
        final int mfLimit = srcEnd - MF_LIMIT;
        final int matchLimit = srcEnd - LAST_LITERALS;

        int sOff = srcOff;
        int dOff = destOff;
        // anchor marks the first input byte not yet emitted; the search starts one byte later.
        int anchor = sOff++;

        final HashTable ht = new HashTable(srcOff);
        final Match match0 = new Match();
        final Match match1 = new Match();
        final Match match2 = new Match();
        final Match match3 = new Match();

        main:
        while (sOff < mfLimit) {
            if (!ht.insertAndFindBestMatch(src, sOff, matchLimit, match1)) {
                ++sOff;
                continue;
            }

            // saved, in case we would skip too much
            copyTo(match1, match0);

            search2:
            while (true) {
                assert match1.start >= anchor;
                if (match1.end() >= mfLimit
                        || !ht.insertAndFindWiderMatch(src, match1.end() - 2, match1.start + 1, matchLimit, match1.len, match2)) {
                    // no better match
                    dOff = LZ4SafeUtils.encodeSequence(src, anchor, match1.start, match1.ref, match1.len, dest, dOff, destEnd);
                    anchor = sOff = match1.end();
                    continue main;
                }

                if (match0.start < match1.start) {
                    if (match2.start < match1.start + match0.len) { // empirical
                        copyTo(match0, match1);
                    }
                }
                assert match2.start > match1.start;

                if (match2.start - match1.start < 3) { // First Match too small : removed
                    copyTo(match2, match1);
                    continue search2;
                }

                search3:
                while (true) {
                    // Trim the front of match2 so that match1 keeps up to OPTIMAL_ML bytes.
                    if (match2.start - match1.start < OPTIMAL_ML) {
                        int newMatchLen = match1.len;
                        if (newMatchLen > OPTIMAL_ML) {
                            newMatchLen = OPTIMAL_ML;
                        }
                        if (match1.start + newMatchLen > match2.end() - MIN_MATCH) {
                            newMatchLen = match2.start - match1.start + match2.len - MIN_MATCH;
                        }
                        final int correction = newMatchLen - (match2.start - match1.start);
                        if (correction > 0) {
                            match2.fix(correction);
                        }
                    }

                    if (match2.start + match2.len >= mfLimit
                            || !ht.insertAndFindWiderMatch(src, match2.end() - 3, match2.start, matchLimit, match2.len, match3)) {
                        // no better match -> 2 sequences to encode
                        if (match2.start < match1.end()) {
                            match1.len = match2.start - match1.start;
                        }
                        // encode seq 1
                        dOff = LZ4SafeUtils.encodeSequence(src, anchor, match1.start, match1.ref, match1.len, dest, dOff, destEnd);
                        anchor = sOff = match1.end();
                        // encode seq 2
                        dOff = LZ4SafeUtils.encodeSequence(src, anchor, match2.start, match2.ref, match2.len, dest, dOff, destEnd);
                        anchor = sOff = match2.end();
                        continue main;
                    }

                    if (match3.start < match1.end() + 3) { // Not enough space for match 2 : remove it
                        if (match3.start >= match1.end()) { // can write Seq1 immediately ==> Seq2 is removed, so Seq3 becomes Seq1
                            if (match2.start < match1.end()) {
                                final int correction = match1.end() - match2.start;
                                match2.fix(correction);
                                if (match2.len < MIN_MATCH) {
                                    copyTo(match3, match2);
                                }
                            }

                            dOff = LZ4SafeUtils.encodeSequence(src, anchor, match1.start, match1.ref, match1.len, dest, dOff, destEnd);
                            anchor = sOff = match1.end();

                            copyTo(match3, match1);
                            copyTo(match2, match0);

                            continue search2;
                        }

                        copyTo(match3, match2);
                        continue search3;
                    }

                    // OK, now we have 3 ascending matches; let's write at least the first one
                    if (match2.start < match1.end()) {
                        if (match2.start - match1.start < ML_MASK) {
                            if (match1.len > OPTIMAL_ML) {
                                match1.len = OPTIMAL_ML;
                            }
                            if (match1.end() > match2.end() - MIN_MATCH) {
                                match1.len = match2.end() - match1.start - MIN_MATCH;
                            }
                            final int correction = match1.end() - match2.start;
                            match2.fix(correction);
                        } else {
                            match1.len = match2.start - match1.start;
                        }
                    }

                    dOff = LZ4SafeUtils.encodeSequence(src, anchor, match1.start, match1.ref, match1.len, dest, dOff, destEnd);
                    anchor = sOff = match1.end();

                    copyTo(match2, match1);
                    copyTo(match3, match2);

                    continue search3;
                }

            }

        }

        // Everything from anchor to the end of the input goes out as literals.
        dOff = LZ4SafeUtils.lastLiterals(src, anchor, srcEnd - anchor, dest, dOff, destEnd);
        return dOff - destOff;
    }
}

@yujincheng08 yujincheng08 marked this pull request as ready for review August 27, 2025 13:17
@PSeitz
Copy link
Owner

PSeitz commented Oct 28, 2025

Thanks! Can you connect it to the proptests and add it to the benchmarks in benches/binggan_bench.rs?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants