diff --git a/src/main/java/org/xerial/snappy/Snappy.java b/src/main/java/org/xerial/snappy/Snappy.java index 31d0d2ee..0075fe3d 100755 --- a/src/main/java/org/xerial/snappy/Snappy.java +++ b/src/main/java/org/xerial/snappy/Snappy.java @@ -535,17 +535,19 @@ public static int uncompress(ByteBuffer compressed, ByteBuffer uncompressed) if (!compressed.isDirect()) { throw new SnappyError(SnappyErrorCode.NOT_A_DIRECT_BUFFER, "input is not a direct buffer"); } - if (!uncompressed.isDirect()) { - throw new SnappyError(SnappyErrorCode.NOT_A_DIRECT_BUFFER, "destination is not a direct buffer"); - } int cPos = compressed.position(); int cLen = compressed.remaining(); // pos limit // [ ......UUUUUU.........] - int decompressedSize = impl.rawUncompress(compressed, cPos, cLen, uncompressed, - uncompressed.position()); + final int decompressedSize; + if (uncompressed.isDirect()) { + decompressedSize = impl.rawUncompress(compressed, cPos, cLen, uncompressed, uncompressed.position()); + } else { + decompressedSize = impl.rawUncompressDirectToHeap(compressed, cPos, cLen, + uncompressed.array(), uncompressed.position()); + } uncompressed.limit(uncompressed.position() + decompressedSize); return decompressedSize; diff --git a/src/main/java/org/xerial/snappy/SnappyNative.cpp b/src/main/java/org/xerial/snappy/SnappyNative.cpp index aacd6290..d6df7121 100755 --- a/src/main/java/org/xerial/snappy/SnappyNative.cpp +++ b/src/main/java/org/xerial/snappy/SnappyNative.cpp @@ -297,3 +297,35 @@ JNIEXPORT void JNICALL Java_org_xerial_snappy_SnappyNative_arrayCopy env->ReleasePrimitiveArrayCritical((jarray) output, dest, 0); } +/* + * Class: org_xerial_snappy_SnappyNative + * Method: rawUncompressDirectToHeap + * Signature: (Ljava/nio/ByteBuffer;IILjava/lang/Object;I)I + */ +JNIEXPORT jint JNICALL Java_org_xerial_snappy_SnappyNative_rawUncompressDirectToHeap + (JNIEnv* env, jobject self, jobject compressedBuffer, jint inputPos, jint inputLength, + jobject uncompressedArray, jint outputOffset) +{ + char* in = (char*) env->GetDirectBufferAddress(compressedBuffer); + if (in == 0) { + throw_exception(env, self, 3); + return (jint) 0; + } + char* out = (char*) env->GetPrimitiveArrayCritical((jarray) uncompressedArray, 0); + if (out == 0) { + // out of memory + throw_exception(env, self, 4); + return (jint) 0; + } + size_t decompressedLength; + snappy::GetUncompressedLength(in + inputPos, (size_t) inputLength, &decompressedLength); + bool ret = snappy::RawUncompress(in + inputPos, (size_t) inputLength, out + outputOffset); + env->ReleasePrimitiveArrayCritical((jarray) uncompressedArray, out, 0); + if(!ret) { + // failed to decompress + throw_exception(env, self, 5); + return (jint) 0; + } + return (jint) decompressedLength; +} + diff --git a/src/main/java/org/xerial/snappy/SnappyNative.h b/src/main/java/org/xerial/snappy/SnappyNative.h index d1f3e580..fe1adbdc 100644 --- a/src/main/java/org/xerial/snappy/SnappyNative.h +++ b/src/main/java/org/xerial/snappy/SnappyNative.h @@ -55,6 +55,14 @@ JNIEXPORT jint JNICALL Java_org_xerial_snappy_SnappyNative_rawCompress__Ljava_la JNIEXPORT jint JNICALL Java_org_xerial_snappy_SnappyNative_rawUncompress__Ljava_nio_ByteBuffer_2IILjava_nio_ByteBuffer_2I (JNIEnv *, jobject, jobject, jint, jint, jobject, jint); +/* + * Class: org_xerial_snappy_SnappyNative + * Method: rawUncompressDirectToHeap + * Signature: (Ljava/nio/ByteBuffer;IILjava/lang/Object;I)I + */ +JNIEXPORT jint JNICALL Java_org_xerial_snappy_SnappyNative_rawUncompressDirectToHeap + (JNIEnv *, jobject, jobject, jint, jint, jobject, jint); + /* * Class: org_xerial_snappy_SnappyNative * Method: rawUncompress diff --git a/src/main/java/org/xerial/snappy/SnappyNative.java b/src/main/java/org/xerial/snappy/SnappyNative.java index 95a6f419..98909d4f 100755 --- a/src/main/java/org/xerial/snappy/SnappyNative.java +++ b/src/main/java/org/xerial/snappy/SnappyNative.java @@ -63,6 +63,10 @@ public native int rawUncompress(ByteBuffer compressed, int inputOffset, int inpu int outputOffset) throws IOException; + public native int rawUncompressDirectToHeap(ByteBuffer compressed, int inputOffset, int inputLength, + Object uncompressed, int outputOffset) + throws IOException; + public native int rawUncompress(Object input, int inputOffset, int inputLength, Object output, int outputOffset) throws IOException; diff --git a/src/main/resources/org/xerial/snappy/native/Linux/x86_64/libsnappyjava.so b/src/main/resources/org/xerial/snappy/native/Linux/x86_64/libsnappyjava.so index 69998c9a..719dd370 100755 Binary files a/src/main/resources/org/xerial/snappy/native/Linux/x86_64/libsnappyjava.so and b/src/main/resources/org/xerial/snappy/native/Linux/x86_64/libsnappyjava.so differ diff --git a/src/test/java/org/xerial/snappy/SnappyGenerativeTest.java b/src/test/java/org/xerial/snappy/SnappyGenerativeTest.java new file mode 100644 index 00000000..9602f19d --- /dev/null +++ b/src/test/java/org/xerial/snappy/SnappyGenerativeTest.java @@ -0,0 +1,66 @@ +package org.xerial.snappy; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ThreadLocalRandom; + +import static org.junit.Assert.assertArrayEquals; + +@RunWith(Parameterized.class) +public class SnappyGenerativeTest { + + @Parameterized.Parameters + public static Iterable data() { + List testCases = new ArrayList<>(100); + for (int i = 0 ; i < 100; ++i) { + testCases.add(randomData()); + } + return testCases; + } + + + private final byte[] input; + + public SnappyGenerativeTest(byte[] input) { + this.input = input; + } + + @Test + public void roundTripDirectToDirect() throws IOException { + ByteBuffer in = ByteBuffer.allocateDirect(input.length); + in.put(input); + ByteBuffer compressed = ByteBuffer.allocateDirect(input.length * 2); + Snappy.compress(in, compressed); + Snappy.uncompress(compressed, in); + byte[] result = new byte[input.length]; + in.flip(); + in.get(result); + assertArrayEquals(input, result); + } + + @Test + public void roundTripDirectToHeap() throws IOException { + ByteBuffer in = ByteBuffer.allocateDirect(input.length); + in.put(input); + in.flip(); + ByteBuffer compressed = ByteBuffer.allocateDirect(input.length * 2); + Snappy.compress(in, compressed); + ByteBuffer out = ByteBuffer.allocate(input.length); + Snappy.uncompress(compressed, out); + out.flip(); + assertArrayEquals(input, out.array()); + } + + private static Object[] randomData() { + int length = Math.abs(ThreadLocalRandom.current().nextInt(10,10_000)); + byte[] data = new byte[length]; + ThreadLocalRandom.current().nextBytes(data); + return new Object[] {data}; + } +} diff --git a/src/test/java/org/xerial/snappy/SnappyTest.java b/src/test/java/org/xerial/snappy/SnappyTest.java index 18b39e92..488bb11d 100755 --- a/src/test/java/org/xerial/snappy/SnappyTest.java +++ b/src/test/java/org/xerial/snappy/SnappyTest.java @@ -113,6 +113,51 @@ public void directBuffer() assertEquals(origStr, decompressed); } + + @Test + public void directBufferToHeapBuffer() throws Exception + { + + StringBuilder s = new StringBuilder(); + for (int i = 0; i < 20; ++i) { + s.append("Hello world!"); + } + String origStr = s.toString(); + byte[] orig = origStr.getBytes(); + ByteBuffer src = ByteBuffer.allocateDirect(orig.length); + src.put(orig); + src.flip(); + _logger.debug("input size: " + src.remaining()); + int maxCompressedLen = Snappy.maxCompressedLength(src.remaining()); + _logger.debug("max compressed length:" + maxCompressedLen); + + ByteBuffer compressed = ByteBuffer.allocateDirect(maxCompressedLen); + int compressedSize = Snappy.compress(src, compressed); + _logger.debug("compressed length: " + compressedSize); + + assertEquals(0, src.position()); + assertEquals(orig.length, src.remaining()); + assertEquals(orig.length, src.limit()); + + assertEquals(0, compressed.position()); + assertEquals(compressedSize, compressed.limit()); + assertEquals(compressedSize, compressed.remaining()); + + int uncompressedLen = Snappy.uncompressedLength(compressed); + _logger.debug("uncompressed length: " + uncompressedLen); + ByteBuffer extract = ByteBuffer.allocate(uncompressedLen); + int uncompressedLen2 = Snappy.uncompress(compressed, extract); + assertEquals(uncompressedLen, uncompressedLen2); + assertEquals(uncompressedLen, extract.remaining()); + + byte[] b = new byte[uncompressedLen]; + extract.get(b); + String decompressed = new String(b); + _logger.debug(decompressed); + + assertEquals(origStr, decompressed); + } + @Test public void bufferOffset() throws Exception