Enhance DataWriter to save memory during kudo serialization. #2891

Merged: 7 commits, Feb 11, 2025

src/main/java/com/nvidia/spark/rapids/jni/kudo/ByteArrayOutputStreamWriter.java (new file)
@@ -0,0 +1,103 @@
/*
 * Copyright (c) 2025, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.nvidia.spark.rapids.jni.kudo;

import static java.lang.Math.toIntExact;
import static java.util.Objects.requireNonNull;

import ai.rapids.cudf.HostMemoryBuffer;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.lang.reflect.Field;
import java.lang.reflect.Method;

/**
 * Adapter class which helps to save memory copy when shuffle manager uses
 * {@link ByteArrayOutputStream} during serialization.
 */
public class ByteArrayOutputStreamWriter implements DataWriter {
  // Reflective handles into ByteArrayOutputStream so data can be written directly
  // into its internal buffer without an intermediate copy.
  private static final Method ENSURE_CAPACITY;
  private static final Field BUF;
  private static final Field COUNT;

  static {
    try {
      ENSURE_CAPACITY = ByteArrayOutputStream.class.getDeclaredMethod("ensureCapacity", int.class);
      ENSURE_CAPACITY.setAccessible(true);

      BUF = ByteArrayOutputStream.class.getDeclaredField("buf");
      BUF.setAccessible(true);

      COUNT = ByteArrayOutputStream.class.getDeclaredField("count");
      COUNT.setAccessible(true);
    } catch (NoSuchMethodException | NoSuchFieldException e) {
      throw new RuntimeException("Failed to access ByteArrayOutputStream internals (ensureCapacity/buf/count)", e);
    }
  }

  private final ByteArrayOutputStream out;

  public ByteArrayOutputStreamWriter(ByteArrayOutputStream bout) {
    requireNonNull(bout, "Byte array output stream can't be null");
    this.out = bout;
  }

  @Override
  public void reserve(int size) throws IOException {
    try {
      ENSURE_CAPACITY.invoke(out, size);
    } catch (Exception e) {
      throw new RuntimeException("Failed to invoke ByteArrayOutputStream.ensureCapacity", e);
    }
  }

  @Override
  public void writeInt(int v) throws IOException {
    reserve(Integer.BYTES + out.size());
    byte[] bytes = new byte[4];
    bytes[0] = (byte) ((v >>> 24) & 0xFF);
    bytes[1] = (byte) ((v >>> 16) & 0xFF);
    bytes[2] = (byte) ((v >>> 8) & 0xFF);
    bytes[3] = (byte) (v & 0xFF);
    out.write(bytes);
  }

  @Override
  public void copyDataFrom(HostMemoryBuffer src, long srcOffset, long len) throws IOException {
    reserve(toIntExact(out.size() + len));

    try {
      byte[] buf = (byte[]) BUF.get(out);
      int count = out.size();

      src.getBytes(buf, count, srcOffset, len);
      COUNT.setInt(out, toIntExact(count + len));
    } catch (IllegalAccessException e) {
      throw new RuntimeException(e);
    }
  }

  @Override
  public void flush() throws IOException {
  }

  @Override
  public void write(byte[] arr, int offset, int length) throws IOException {
    out.write(arr, offset, length);
  }
}
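
A minimal usage sketch (not part of the diff) of how this adapter can be driven directly; in practice KudoSerializer selects it through writerFrom below when the caller passes a plain ByteArrayOutputStream. The 16-byte payload is made up for illustration, and because the class reaches into ByteArrayOutputStream internals reflectively, newer JDKs may need --add-opens java.base/java.io=ALL-UNNAMED:

import ai.rapids.cudf.HostMemoryBuffer;
import com.nvidia.spark.rapids.jni.kudo.ByteArrayOutputStreamWriter;
import com.nvidia.spark.rapids.jni.kudo.DataWriter;
import java.io.ByteArrayOutputStream;
import java.io.IOException;

public class ByteArrayOutputStreamWriterExample {
  public static void main(String[] args) throws IOException {
    ByteArrayOutputStream bout = new ByteArrayOutputStream();
    DataWriter writer = new ByteArrayOutputStreamWriter(bout);
    try (HostMemoryBuffer buf = HostMemoryBuffer.allocate(16)) {
      buf.setBytes(0, new byte[16], 0, 16);  // stand-in for serialized column data
      writer.writeInt(16);                   // 4-byte big-endian length prefix
      writer.copyDataFrom(buf, 0, 16);       // written straight into bout's internal buffer
      writer.flush();
    }
    System.out.println(bout.size());         // 20 = 4-byte prefix + 16 payload bytes
  }
}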

src/main/java/com/nvidia/spark/rapids/jni/kudo/DataOutputStreamWriter.java
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2024, NVIDIA CORPORATION.
+ * Copyright (c) 2024-2025, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -24,7 +24,7 @@
 /**
  * Visible for testing
  */
-class DataOutputStreamWriter extends DataWriter {
+class DataOutputStreamWriter implements DataWriter {
   private final byte[] arrayBuffer = new byte[1024];
   private final DataOutputStream dout;

27 changes: 18 additions & 9 deletions src/main/java/com/nvidia/spark/rapids/jni/kudo/DataWriter.java
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2024, NVIDIA CORPORATION.
+ * Copyright (c) 2024-2025, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -21,11 +21,19 @@
 import java.io.IOException;
 
 /**
- * Visible for testing
+ * Output data writer for kudo serializer.
  */
-abstract class DataWriter {
+public interface DataWriter {
 
-  public abstract void writeInt(int i) throws IOException;
+  /**
+   * Write int in network byte order.
+   */
+  void writeInt(int i) throws IOException;
+
+  /**
+   * Reserve space in the buffer for the given size.
+   */
+  default void reserve(int size) throws IOException {}
 
   /**
    * Copy data from src starting at srcOffset and going for len bytes.
@@ -34,11 +42,12 @@ abstract class DataWriter {
    * @param srcOffset offset to start at.
    * @param len amount to copy.
    */
-  public abstract void copyDataFrom(HostMemoryBuffer src, long srcOffset, long len) throws IOException;
+  void copyDataFrom(HostMemoryBuffer src, long srcOffset, long len) throws IOException;
 
-  public void flush() throws IOException {
-    // NOOP by default
-  }
+  void flush() throws IOException;
 
-  public abstract void write(byte[] arr, int offset, int length) throws IOException;
+  /**
+   * Copy part of byte array to this writer.
+   */
+  void write(byte[] arr, int offset, int length) throws IOException;
 }
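
To make the contract the interface now documents concrete (big-endian writeInt, optional reserve, bulk copyDataFrom), here is a hypothetical implementation backed by a java.nio.ByteBuffer. It is not part of the PR, the class name is invented, and unlike the two adapters added in this PR it stages copyDataFrom through a temporary array rather than writing in place:

import ai.rapids.cudf.HostMemoryBuffer;
import com.nvidia.spark.rapids.jni.kudo.DataWriter;
import java.io.IOException;
import java.nio.ByteBuffer;

class ByteBufferDataWriter implements DataWriter {
  private final ByteBuffer target; // assumed large enough for the whole payload

  ByteBufferDataWriter(ByteBuffer target) {
    this.target = target;
  }

  @Override
  public void writeInt(int i) throws IOException {
    target.putInt(i); // ByteBuffer defaults to big-endian, i.e. network byte order
  }

  // reserve(int) keeps its default no-op implementation from DataWriter.

  @Override
  public void copyDataFrom(HostMemoryBuffer src, long srcOffset, long len) throws IOException {
    byte[] tmp = new byte[Math.toIntExact(len)];
    src.getBytes(tmp, 0, srcOffset, len);
    target.put(tmp);
  }

  @Override
  public void flush() throws IOException {
    // nothing is buffered outside the target ByteBuffer
  }

  @Override
  public void write(byte[] arr, int offset, int length) throws IOException {
    target.put(arr, offset, length);
  }
}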
16 changes: 13 additions & 3 deletions src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializer.java
@@ -17,6 +17,7 @@
 package com.nvidia.spark.rapids.jni.kudo;
 
 import static com.nvidia.spark.rapids.jni.Preconditions.ensure;
+import static java.lang.Math.toIntExact;
 import static java.util.Objects.requireNonNull;
 
 import ai.rapids.cudf.BufferType;
@@ -28,6 +29,7 @@
 import com.nvidia.spark.rapids.jni.Pair;
 import com.nvidia.spark.rapids.jni.schema.Visitors;
 import java.io.BufferedOutputStream;
+import java.io.ByteArrayOutputStream;
 import java.io.DataInputStream;
 import java.io.DataOutputStream;
 import java.io.IOException;
@@ -328,6 +330,9 @@ private WriteMetrics writeSliced(HostColumnVector[] columns, DataWriter out, int
         new KudoTableHeaderCalc(rowOffset, numRows, flattenedColumnCount);
     withTime(() -> Visitors.visitColumns(columns, headerCalc), metrics::addCalcHeaderTime);
     KudoTableHeader header = headerCalc.getHeader();
+
+    out.reserve(toIntExact(header.getSerializedSize() + header.getTotalDataLen()));
+
     long currentTime = System.nanoTime();
     header.writeTo(out);
     metrics.addCopyHeaderTime(System.nanoTime() - currentTime);
@@ -355,10 +360,15 @@ private WriteMetrics writeSliced(HostColumnVector[] columns, DataWriter out, int
   }
 
   private static DataWriter writerFrom(OutputStream out) {
-    if (!(out instanceof DataOutputStream)) {
-      out = new DataOutputStream(new BufferedOutputStream(out));
+    if (out instanceof DataOutputStream) {
+      return new DataOutputStreamWriter((DataOutputStream) out);
+    } else if (out instanceof OpenByteArrayOutputStream) {
+      return new OpenByteArrayOutputStreamWriter((OpenByteArrayOutputStream) out);
+    } else if (out instanceof ByteArrayOutputStream) {
+      return new ByteArrayOutputStreamWriter((ByteArrayOutputStream) out);
+    } else {
+      return new DataOutputStreamWriter(new DataOutputStream(new BufferedOutputStream(out)));
     }
-    return new DataOutputStreamWriter((DataOutputStream) out);
   }
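
One detail worth noting in the new writerFrom: OpenByteArrayOutputStream extends ByteArrayOutputStream, so the more specific instanceof check has to come first; otherwise the reflection-based ByteArrayOutputStreamWriter would also be chosen for the new stream type. A tiny illustrative snippet (not from the PR):

import com.nvidia.spark.rapids.jni.kudo.OpenByteArrayOutputStream;
import java.io.ByteArrayOutputStream;

public class WriterDispatchNote {
  public static void main(String[] args) {
    // true: every OpenByteArrayOutputStream is also a ByteArrayOutputStream, so
    // writerFrom must test the subclass before the superclass to pick the
    // copy-free OpenByteArrayOutputStreamWriter instead of the reflective one.
    ByteArrayOutputStream out = new OpenByteArrayOutputStream(32);
    System.out.println(out instanceof OpenByteArrayOutputStream);
  }
}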



src/main/java/com/nvidia/spark/rapids/jni/kudo/OpenByteArrayOutputStream.java (new file)
@@ -0,0 +1,97 @@
/*
 * Copyright (c) 2025, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.nvidia.spark.rapids.jni.kudo;

import ai.rapids.cudf.HostMemoryBuffer;

import java.io.ByteArrayOutputStream;
import java.util.Arrays;

import static com.nvidia.spark.rapids.jni.Preconditions.ensure;
import static java.util.Objects.requireNonNull;

/**
 * This class extends {@link ByteArrayOutputStream} to provide some internal methods to save copy.
 */
public class OpenByteArrayOutputStream extends ByteArrayOutputStream {
  private static final int MAX_ARRAY_LENGTH = Integer.MAX_VALUE - 32;

  /**
   * Creates a new byte array output stream. The buffer capacity is
   * initially 32 bytes, though its size increases if necessary.
   */
  public OpenByteArrayOutputStream() {
    this(32);
  }

  /**
   * Creates a new byte array output stream, with a buffer capacity of
   * the specified size, in bytes.
   *
   * @param size the initial size.
   * @exception IllegalArgumentException if size is negative.
   */
  public OpenByteArrayOutputStream(int size) {
    super(size);
  }

  /**
   * Get underlying byte array.
   */
  public byte[] getBuf() {
    return buf;
  }

  /**
   * Get actual number of bytes that have been written to this output stream.
   *
   * @return Number of bytes written to this output stream. Note that this may be smaller than the
   * length of {@link OpenByteArrayOutputStream#getBuf()}.
   */
  public int getCount() {
    return count;
  }

  /**
   * Increases the capacity if necessary to ensure that it can hold
   * at least the number of elements specified by the minimum
   * capacity argument.
   *
   * @param capacity the desired minimum capacity
   * @throws IllegalStateException If {@code capacity < 0} or {@code capacity >= MAX_ARRAY_LENGTH}.
   */
  public void reserve(int capacity) {
    ensure(capacity >= 0, () -> "Requested capacity can't be negative, but was " + capacity);
    ensure(capacity < MAX_ARRAY_LENGTH, () -> "Requested capacity is too large: " + capacity);

    if (capacity > buf.length) {
      buf = Arrays.copyOf(buf, capacity);
    }
  }

  /**
   * Copy from {@link HostMemoryBuffer} to this output stream.
   *
   * @param srcBuf {@link HostMemoryBuffer} to copy from.
   * @param offset Start position in source {@link HostMemoryBuffer}.
   * @param length Number of bytes to copy.
   */
  public void write(HostMemoryBuffer srcBuf, long offset, int length) {
    requireNonNull(srcBuf, "Source buf can't be null!");
    reserve(count + length);
    srcBuf.getBytes(buf, count, offset, length);
    count += length;
  }
}
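
The reason getBuf() and getCount() are exposed is that a consumer can hand the serialized bytes onward without the defensive copy that ByteArrayOutputStream.toByteArray() makes. A minimal sketch of that pattern (not from the PR); note the wrapped view aliases the live buffer and is only valid until more data is written:

import com.nvidia.spark.rapids.jni.kudo.OpenByteArrayOutputStream;
import java.nio.ByteBuffer;

public class OpenByteArrayOutputStreamExample {
  public static void main(String[] args) {
    OpenByteArrayOutputStream out = new OpenByteArrayOutputStream(1024);
    out.write(42); // stand-in for real serialized output
    // toByteArray() would allocate and copy; wrapping the live buffer does not.
    ByteBuffer view = ByteBuffer.wrap(out.getBuf(), 0, out.getCount());
    System.out.println(view.remaining()); // 1
  }
}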

src/main/java/com/nvidia/spark/rapids/jni/kudo/OpenByteArrayOutputStreamWriter.java (new file)
@@ -0,0 +1,64 @@
/*
 * Copyright (c) 2025, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.nvidia.spark.rapids.jni.kudo;

import static java.lang.Math.toIntExact;
import static java.util.Objects.requireNonNull;

import ai.rapids.cudf.HostMemoryBuffer;
import java.io.IOException;

/**
 * Adapter class which helps to save memory copy when shuffle manager uses
 * {@link OpenByteArrayOutputStream} during serialization.
 */
public class OpenByteArrayOutputStreamWriter implements DataWriter {
  private final OpenByteArrayOutputStream out;

  public OpenByteArrayOutputStreamWriter(OpenByteArrayOutputStream bout) {
    requireNonNull(bout, "Byte array output stream can't be null");
    this.out = bout;
  }

  @Override
  public void reserve(int size) throws IOException {
    out.reserve(size);
  }

  @Override
  public void writeInt(int v) throws IOException {
    out.reserve(4 + out.size());
    out.write((v >>> 24) & 0xFF);
    out.write((v >>> 16) & 0xFF);
    out.write((v >>> 8) & 0xFF);
    out.write((v >>> 0) & 0xFF);
  }

  @Override
  public void copyDataFrom(HostMemoryBuffer src, long srcOffset, long len) throws IOException {
    out.write(src, srcOffset, toIntExact(len));
  }

  @Override
  public void flush() throws IOException {
  }

  @Override
  public void write(byte[] arr, int offset, int length) throws IOException {
    out.write(arr, offset, length);
  }
}
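
A small sanity-check sketch (not from the PR) showing that writeInt lands in the backing array most-significant byte first, i.e. the same network byte order DataOutputStream.writeInt produces:

import com.nvidia.spark.rapids.jni.kudo.DataWriter;
import com.nvidia.spark.rapids.jni.kudo.OpenByteArrayOutputStream;
import com.nvidia.spark.rapids.jni.kudo.OpenByteArrayOutputStreamWriter;
import java.io.IOException;

public class OpenByteArrayOutputStreamWriterExample {
  public static void main(String[] args) throws IOException {
    OpenByteArrayOutputStream bout = new OpenByteArrayOutputStream(8);
    DataWriter writer = new OpenByteArrayOutputStreamWriter(bout);
    writer.writeInt(0x11223344);
    byte[] buf = bout.getBuf();
    System.out.printf("%02x %02x %02x %02x%n", buf[0], buf[1], buf[2], buf[3]); // 11 22 33 44
    System.out.println(bout.getCount()); // 4
  }
}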