diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/ByteArrayOutputStreamWriter.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/ByteArrayOutputStreamWriter.java
new file mode 100644
index 000000000..b60872998
--- /dev/null
+++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/ByteArrayOutputStreamWriter.java
@@ -0,0 +1,138 @@
+/*
+ * Copyright (c) 2025, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.jni.kudo;
+
+import static java.lang.Math.toIntExact;
+import static java.util.Objects.requireNonNull;
+
+import ai.rapids.cudf.HostMemoryBuffer;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.lang.reflect.Field;
+import java.lang.reflect.Method;
+
+/**
+ * Adapter class which helps to save memory copy when shuffle manager uses
+ * {@link ByteArrayOutputStream} during serialization.
+ *
+ * <p>The non-public {@code ensureCapacity}, {@code buf} and {@code count} members of
+ * {@link ByteArrayOutputStream} are reached through reflection, so that host memory can be
+ * copied straight into the stream's backing array.
+ */
+public class ByteArrayOutputStreamWriter implements DataWriter {
+  private static final Method ENSURE_CAPACITY;
+  private static final Field BUF;
+  private static final Field COUNT;
+
+  static {
+    try {
+      ENSURE_CAPACITY = ByteArrayOutputStream.class.getDeclaredMethod("ensureCapacity", int.class);
+      ENSURE_CAPACITY.setAccessible(true);
+
+      BUF = ByteArrayOutputStream.class.getDeclaredField("buf");
+      BUF.setAccessible(true);
+
+      COUNT = ByteArrayOutputStream.class.getDeclaredField("count");
+      COUNT.setAccessible(true);
+    } catch (NoSuchMethodException | NoSuchFieldException e) {
+      // Any of the three lookups can end up here, so don't blame ensureCapacity alone.
+      throw new RuntimeException(
+          "Failed to find ensureCapacity/buf/count in ByteArrayOutputStream", e);
+    }
+  }
+
+  private final ByteArrayOutputStream out;
+
+  /**
+   * @param bout stream to append to, must not be null.
+   */
+  public ByteArrayOutputStreamWriter(ByteArrayOutputStream bout) {
+    requireNonNull(bout, "Byte array output stream can't be null");
+    this.out = bout;
+  }
+
+  /**
+   * Makes sure that {@code size} more bytes fit into the backing array on top of what has
+   * already been written. The serializer passes the size of the table it is about to write,
+   * so the request has to be relative to the current position to be of any use once the
+   * stream holds data.
+   *
+   * @param size number of additional bytes about to be written.
+   * @throws ArithmeticException if the required capacity does not fit into an int.
+   */
+  @Override
+  public void reserve(int size) throws IOException {
+    int minCapacity = Math.addExact(out.size(), size);
+    try {
+      ENSURE_CAPACITY.invoke(out, minCapacity);
+    } catch (Exception e) {
+      throw new RuntimeException("Failed to invoke ByteArrayOutputStream.ensureCapacity", e);
+    }
+  }
+
+  /**
+   * Appends {@code v} as four bytes, most significant byte first.
+   */
+  @Override
+  public void writeInt(int v) throws IOException {
+    reserve(Integer.BYTES);
+    byte[] bytes = new byte[Integer.BYTES];
+    bytes[0] = (byte) ((v >>> 24) & 0xFF);
+    bytes[1] = (byte) ((v >>> 16) & 0xFF);
+    bytes[2] = (byte) ((v >>> 8) & 0xFF);
+    bytes[3] = (byte) (v & 0xFF);
+    out.write(bytes);
+  }
+
+  /**
+   * Copies {@code len} bytes of {@code src}, starting at {@code srcOffset}, directly behind
+   * the bytes already held by the stream, without an intermediate array.
+   *
+   * @throws ArithmeticException if {@code len}, or the stream size after the copy, does not
+   *                             fit into an int.
+   */
+  @Override
+  public void copyDataFrom(HostMemoryBuffer src, long srcOffset, long len) throws IOException {
+    reserve(toIntExact(len));
+
+    try {
+      // Fetch the array only after reserve(), which may have replaced it with a bigger one.
+      byte[] buf = (byte[]) BUF.get(out);
+      int count = out.size();
+
+      src.getBytes(buf, count, srcOffset, len);
+      COUNT.setInt(out, toIntExact(count + len));
+    } catch (IllegalAccessException e) {
+      throw new RuntimeException("Failed to access ByteArrayOutputStream.buf/count", e);
+    }
+  }
+
+  /**
+   * Nothing to do, every write lands in the stream's array immediately.
+   */
+  @Override
+  public void flush() throws IOException {
+  }
+
+  /**
+   * Appends {@code length} bytes of {@code arr}, starting at {@code offset}.
+   */
+  @Override
+  public void write(byte[] arr, int offset, int length) throws IOException {
+    out.write(arr, offset, length);
+  }
+}
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/DataOutputStreamWriter.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/DataOutputStreamWriter.java
index c88f125b2..1d714e32b 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/kudo/DataOutputStreamWriter.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/DataOutputStreamWriter.java
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2024, NVIDIA CORPORATION.
+ * Copyright (c) 2024-2025, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -24,7 +24,7 @@
 /**
  * Visible for testing
  */
-class DataOutputStreamWriter extends DataWriter {
+class DataOutputStreamWriter implements DataWriter {
   private final byte[] arrayBuffer = new byte[1024];
   private final DataOutputStream dout;
 
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/DataWriter.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/DataWriter.java
index 1f2e8f3dc..1e5a6e8e6 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/kudo/DataWriter.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/DataWriter.java
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2024, NVIDIA CORPORATION.
+ * Copyright (c) 2024-2025, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -21,11 +21,19 @@
 import java.io.IOException;
 
 /**
- * Visible for testing
+ * Output data writer for kudo serializer.
  */
-abstract class DataWriter {
+public interface DataWriter {
 
-  public abstract void writeInt(int i) throws IOException;
+  /**
+   * Write int in network byte order.
+   */
+  void writeInt(int i) throws IOException;
+
+  /**
+   * Hint that {@code size} more bytes are about to be written, so a buffer can be pre-sized.
+   */
+  default void reserve(int size) throws IOException {}
 
   /**
    * Copy data from src starting at srcOffset and going for len bytes.
@@ -34,11 +42,12 @@ abstract class DataWriter {
    * @param srcOffset offset to start at.
    * @param len amount to copy.
    */
-  public abstract void copyDataFrom(HostMemoryBuffer src, long srcOffset, long len) throws IOException;
+  void copyDataFrom(HostMemoryBuffer src, long srcOffset, long len) throws IOException;
 
-  public void flush() throws IOException {
-    // NOOP by default
-  }
+  void flush() throws IOException;
 
-  public abstract void write(byte[] arr, int offset, int length) throws IOException;
+  /**
+   * Copy part of byte array to this writer.
+   */
+  void write(byte[] arr, int offset, int length) throws IOException;
 }
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializer.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializer.java
index 8b096813b..990a9faee 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializer.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializer.java
@@ -17,6 +17,7 @@
 package com.nvidia.spark.rapids.jni.kudo;
 
 import static com.nvidia.spark.rapids.jni.Preconditions.ensure;
+import static java.lang.Math.toIntExact;
 import static java.util.Objects.requireNonNull;
 
 import ai.rapids.cudf.BufferType;
@@ -28,6 +29,7 @@
 import com.nvidia.spark.rapids.jni.Pair;
 import com.nvidia.spark.rapids.jni.schema.Visitors;
 import java.io.BufferedOutputStream;
+import java.io.ByteArrayOutputStream;
 import java.io.DataInputStream;
 import java.io.DataOutputStream;
 import java.io.IOException;
@@ -328,6 +330,9 @@ private WriteMetrics writeSliced(HostColumnVector[] columns, DataWriter out, int
         new KudoTableHeaderCalc(rowOffset, numRows, flattenedColumnCount);
     withTime(() -> Visitors.visitColumns(columns, headerCalc), metrics::addCalcHeaderTime);
     KudoTableHeader header = headerCalc.getHeader();
+    // Let the writer pre-size its buffer for this table's header plus data before writing any.
+    out.reserve(toIntExact(header.getSerializedSize() + header.getTotalDataLen()));
+
     long currentTime = System.nanoTime();
     header.writeTo(out);
     metrics.addCopyHeaderTime(System.nanoTime() - currentTime);
@@ -355,10 +360,15 @@ private WriteMetrics writeSliced(HostColumnVector[] columns, DataWriter out, int
   }
 
   private static DataWriter writerFrom(OutputStream out) {
-    if (!(out instanceof DataOutputStream)) {
-      out = new DataOutputStream(new BufferedOutputStream(out));
+    if (out instanceof DataOutputStream) {
+      return new DataOutputStreamWriter((DataOutputStream) out);
+    } else if (out instanceof OpenByteArrayOutputStream) { // subclass, so test it first
+      return new OpenByteArrayOutputStreamWriter((OpenByteArrayOutputStream) out);
+    } else if (out instanceof ByteArrayOutputStream) {
+      return new ByteArrayOutputStreamWriter((ByteArrayOutputStream) out);
+    } else { // any other stream gets buffered and wrapped
+      return new DataOutputStreamWriter(new DataOutputStream(new BufferedOutputStream(out)));
     }
-    return new DataOutputStreamWriter((DataOutputStream) out);
   }
 
 
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/OpenByteArrayOutputStream.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/OpenByteArrayOutputStream.java
new file mode 100644
index 000000000..ca8fadf0a
--- /dev/null
+++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/OpenByteArrayOutputStream.java
@@ -0,0 +1,97 @@
+/*
+ * Copyright (c) 2025, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.jni.kudo;
+
+import ai.rapids.cudf.HostMemoryBuffer;
+
+import java.io.ByteArrayOutputStream;
+import java.util.Arrays;
+
+import static com.nvidia.spark.rapids.jni.Preconditions.ensure;
+import static java.util.Objects.requireNonNull;
+
+/**
+ * This class extends {@link ByteArrayOutputStream} to provide some internal methods to save copy.
+ */
+public class OpenByteArrayOutputStream extends ByteArrayOutputStream {
+  private static final int MAX_ARRAY_LENGTH = Integer.MAX_VALUE - 32;
+
+  /**
+   * Creates a new byte array output stream. The buffer capacity is
+   * initially 32 bytes, though its size increases if necessary.
+   */
+  public OpenByteArrayOutputStream() {
+    this(32);
+  }
+
+  /**
+   * Creates a new byte array output stream, with a buffer capacity of
+   * the specified size, in bytes.
+   *
+   * @param size the initial size.
+   * @exception IllegalArgumentException if size is negative.
+   */
+  public OpenByteArrayOutputStream(int size) {
+    super(size);
+  }
+
+  /**
+   * Get the underlying byte array without copying; only the first getCount() bytes are data.
+   */
+  public byte[] getBuf() {
+    return buf;
+  }
+
+  /**
+   * Get actual number of bytes that have been written to this output stream.
+   * @return Number of bytes written to this output stream. Note that this may be smaller than
+   * the length of {@link OpenByteArrayOutputStream#getBuf()}.
+   */
+  public int getCount() {
+    return count;
+  }
+
+  /**
+   * Increases the capacity if necessary to ensure that it can hold at least {@code capacity}
+   * bytes. A buffer that has to grow at least doubles (up to MAX_ARRAY_LENGTH), so that a run
+   * of small appends does not copy the whole buffer on every call.
+   *
+   * @param capacity the desired minimum capacity
+   * @throws IllegalStateException If {@code capacity < 0} or {@code capacity >= MAX_ARRAY_LENGTH}.
+   */
+  public void reserve(int capacity) {
+    ensure(capacity >= 0, () -> "Requested capacity must be non-negative, but was " + capacity);
+    ensure(capacity < MAX_ARRAY_LENGTH, () -> "Requested capacity is too large: " + capacity);
+    if (capacity > buf.length) {
+      long doubled = Math.min(MAX_ARRAY_LENGTH, 2L * buf.length);
+      buf = Arrays.copyOf(buf, (int) Math.max(capacity, doubled));
+    }
+  }
+
+  /**
+   * Copy from {@link HostMemoryBuffer} to this output stream.
+   * @param srcBuf {@link HostMemoryBuffer} to copy from.
+   * @param offset Start position in source {@link HostMemoryBuffer}.
+   * @param length Number of bytes to copy.
+   */
+  public void write(HostMemoryBuffer srcBuf, long offset, int length) {
+    requireNonNull(srcBuf, "Source buf can't be null!");
+    reserve(count + length);
+    srcBuf.getBytes(buf, count, offset, length);
+    count += length;
+  }
+}
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/OpenByteArrayOutputStreamWriter.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/OpenByteArrayOutputStreamWriter.java
new file mode 100644
index 000000000..6976995f2
--- /dev/null
+++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/OpenByteArrayOutputStreamWriter.java
@@ -0,0 +1,90 @@
+/*
+ * Copyright (c) 2025, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.jni.kudo;
+
+import static java.lang.Math.toIntExact;
+import static java.util.Objects.requireNonNull;
+
+import ai.rapids.cudf.HostMemoryBuffer;
+import java.io.IOException;
+
+/**
+ * Adapter class which helps to save memory copy when shuffle manager uses
+ * {@link OpenByteArrayOutputStream} during serialization.
+ */
+public class OpenByteArrayOutputStreamWriter implements DataWriter {
+  private final OpenByteArrayOutputStream out;
+
+  /**
+   * @param bout stream to append to, must not be null.
+   */
+  public OpenByteArrayOutputStreamWriter(OpenByteArrayOutputStream bout) {
+    requireNonNull(bout, "Byte array output stream can't be null");
+    this.out = bout;
+  }
+
+  /**
+   * Makes sure that {@code size} more bytes fit into the stream on top of what has already
+   * been written. The serializer passes the size of the table it is about to write, so the
+   * request has to be relative to the current position to be of any use once the stream
+   * holds data.
+   *
+   * @param size number of additional bytes about to be written.
+   * @throws ArithmeticException if the required capacity does not fit into an int.
+   */
+  @Override
+  public void reserve(int size) throws IOException {
+    out.reserve(Math.addExact(out.size(), size));
+  }
+
+  /**
+   * Appends {@code v} as four bytes, most significant byte first.
+   */
+  @Override
+  public void writeInt(int v) throws IOException {
+    reserve(Integer.BYTES);
+    out.write((v >>> 24) & 0xFF);
+    out.write((v >>> 16) & 0xFF);
+    out.write((v >>> 8) & 0xFF);
+    out.write(v & 0xFF);
+  }
+
+  /**
+   * Copies {@code len} bytes of {@code src}, starting at {@code srcOffset}, into the stream.
+   *
+   * @throws ArithmeticException if {@code len} does not fit into an int.
+   */
+  @Override
+  public void copyDataFrom(HostMemoryBuffer src, long srcOffset, long len) throws IOException {
+    out.write(src, srcOffset, toIntExact(len));
+  }
+
+  /**
+   * Nothing to do, every write lands in the stream's array immediately.
+   */
+  @Override
+  public void flush() throws IOException {
+  }
+
+  /**
+   * Appends {@code length} bytes of {@code arr}, starting at {@code offset}.
+   */
+  @Override
+  public void write(byte[] arr, int offset, int length) throws IOException {
+    out.write(arr, offset, length);
+  }
+}
diff --git a/src/test/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializerTest.java b/src/test/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializerTest.java
index 3ffcb5e61..f98a7459f 100644
--- a/src/test/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializerTest.java
+++ b/src/test/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializerTest.java
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2024, NVIDIA CORPORATION.
+ * Copyright (c) 2024-2025, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -25,6 +25,7 @@
 import java.io.DataInputStream;
 import java.util.ArrayList;
 import java.util.List;
+import java.util.concurrent.ThreadLocalRandom;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 
@@ -54,7 +55,7 @@ public void testSerializeAndDeserializeTable() {
 
   @Test
   public void testRowCountOnly() throws Exception {
-    ByteArrayOutputStream out = new ByteArrayOutputStream();
+    OpenByteArrayOutputStream out = new OpenByteArrayOutputStream();
     long bytesWritten = KudoSerializer.writeRowCountToStream(out, 5);
     assertEquals(28, bytesWritten);
 
@@ -74,7 +75,7 @@ public void testWriteSimple() throws Exception {
     KudoSerializer serializer = new KudoSerializer(buildSimpleTestSchema());
 
     try (Table t = buildSimpleTable()) {
-      ByteArrayOutputStream out = new ByteArrayOutputStream();
+      OpenByteArrayOutputStream out = new OpenByteArrayOutputStream();
       long bytesWritten = serializer.writeToStreamWithMetrics(t, out, 0, 4).getWrittenBytes();
       assertEquals(189, bytesWritten);
 
@@ -193,6 +194,36 @@ public void testSerializeValidity() {
     });
   }
 
+  @Test
+  public void testByteArrayOutputStreamWriter() throws Exception {
+    ByteArrayOutputStream bout = new ByteArrayOutputStream(32);
+    DataWriter writer = new ByteArrayOutputStreamWriter(bout);
+    // 0x12345678 must come out most significant byte first (network byte order).
+    writer.writeInt(0x12345678);
+    // Plain byte array, larger than the 32 bytes the stream was created with.
+    byte[] testByteArr1 = new byte[2097];
+    ThreadLocalRandom.current().nextBytes(testByteArr1);
+    writer.write(testByteArr1, 0, testByteArr1.length);
+    // Same random content, but handed over through a host memory buffer.
+    byte[] testByteArr2 = new byte[7896];
+    ThreadLocalRandom.current().nextBytes(testByteArr2);
+    try(HostMemoryBuffer buffer = HostMemoryBuffer.allocate(testByteArr2.length)) {
+      buffer.setBytes(0, testByteArr2, 0, testByteArr2.length);
+      writer.copyDataFrom(buffer, 0, testByteArr2.length);
+    }
+    // All three pieces are expected back to back, in the order they were written.
+    byte[] expected = new byte[4 + testByteArr1.length + testByteArr2.length];
+    expected[0] = 0x12;
+    expected[1] = 0x34;
+    expected[2] = 0x56;
+    expected[3] = 0x78;
+    System.arraycopy(testByteArr1, 0, expected, 4, testByteArr1.length);
+    System.arraycopy(testByteArr2, 0, expected, 4 + testByteArr1.length,
+        testByteArr2.length);
+
+    assertArrayEquals(expected, bout.toByteArray());
+  }
+
   private static Schema buildSimpleTestSchema() {
     Schema.Builder builder = Schema.builder();
 
@@ -363,7 +394,7 @@ private static void checkMergeTable(Table expected, List tableSlices
     try {
       KudoSerializer serializer = new KudoSerializer(schemaOf(expected));
 
-      ByteArrayOutputStream bout = new ByteArrayOutputStream();
+      OpenByteArrayOutputStream bout = new OpenByteArrayOutputStream();
       for (TableSlice slice : tableSlices) {
         serializer.writeToStreamWithMetrics(slice.getBaseTable(), bout, slice.getStartRow(), slice.getNumRows());
       }