Skip to content

Commit

Permalink
Merge branch 'huikang/incremental_dump' into 'main'
Browse files Browse the repository at this point in the history
sok incremental dump

See merge request dl/hugectr/hugectr!1524
  • Loading branch information
minseokl committed Jan 29, 2024
2 parents af9c40c + bb71a4f commit e5f021f
Show file tree
Hide file tree
Showing 18 changed files with 934 additions and 10 deletions.
6 changes: 6 additions & 0 deletions sparse_operation_kit/kit_src/variable/impl/det_variable.cu
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,12 @@ void DETVariable<KeyType, ValueType>::eXport(KeyType* keys, ValueType* values,
CUDACHECK(cudaFree(d_values));
}

template <typename KeyType, typename ValueType>
void DETVariable<KeyType, ValueType>::eXport_if(KeyType* keys, ValueType* values, size_t* counter,
uint64_t threshold, cudaStream_t stream) {
throw std::runtime_error("SOK dynamic variable with DET backend don't support eXport_if");
}

template <typename KeyType, typename ValueType>
void DETVariable<KeyType, ValueType>::assign(const KeyType* keys, const ValueType* values,
size_t num_keys, cudaStream_t stream) {
Expand Down
2 changes: 2 additions & 0 deletions sparse_operation_kit/kit_src/variable/impl/det_variable.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ class DETVariable : public VariableBase<KeyType, ValueType> {
int64_t cols() override;

void eXport(KeyType *keys, ValueType *values, cudaStream_t stream = 0) override;
void eXport_if(KeyType *keys, ValueType *values, size_t *counter, uint64_t threshold,
cudaStream_t stream = 0) override;
void assign(const KeyType *keys, const ValueType *values, size_t num_keys,
cudaStream_t stream = 0) override;

Expand Down
50 changes: 42 additions & 8 deletions sparse_operation_kit/kit_src/variable/impl/hkv_variable.cu
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,14 @@ __global__ void generate_normal_kernel(curandState* state, T** result, bool* d_f
}
}

template <class K, class S>
struct ExportIfPredFunctor {
__forceinline__ __device__ bool operator()(const K& key, S& score, const K& pattern,
const S& threshold) {
return score > threshold;
}
};

static void set_curand_states(curandState** states, cudaStream_t stream = 0) {
int device;
CUDACHECK(cudaGetDevice(&device));
Expand Down Expand Up @@ -197,7 +205,6 @@ HKVVariable<KeyType, ValueType>::HKVVariable(int64_t dimension, int64_t initial_
nv::merlin::EvictStrategy hkv_evict_strategy;
parse_evict_strategy(evict_strategy, hkv_evict_strategy);
hkv_table_option_.evict_strategy = hkv_evict_strategy;

hkv_table_->init(hkv_table_option_);
}

Expand Down Expand Up @@ -230,22 +237,49 @@ void HKVVariable<KeyType, ValueType>::eXport(KeyType* keys, ValueType* values,
ValueType* d_values;
CUDACHECK(cudaMallocManaged(&d_values, sizeof(ValueType) * num_keys * dim));

// KeyType* d_keys;
// CUDACHECK(cudaMalloc(&d_keys, sizeof(KeyType) * num_keys));
// ValueType* d_values;
// CUDACHECK(cudaMalloc(&d_values, sizeof(ValueType) * num_keys * dim));
hkv_table_->export_batch(hkv_table_option_.max_capacity, 0, d_keys, d_values, nullptr,
stream); // Meta missing
CUDACHECK(cudaStreamSynchronize(stream));

// clang-format off
std::memcpy(keys, d_keys, sizeof(KeyType) * num_keys);
std::memcpy(values, d_values, sizeof(ValueType) * num_keys * dim);
//CUDACHECK(cudaMemcpy(keys, d_keys, sizeof(KeyType) * num_keys,cudaMemcpyDeviceToHost));
//CUDACHECK(cudaMemcpy(values, d_values, sizeof(ValueType) * num_keys * dim,cudaMemcpyDeviceToHost));
CUDACHECK(cudaFree(d_keys));
CUDACHECK(cudaFree(d_values));
}

template <typename KeyType, typename ValueType>
void HKVVariable<KeyType, ValueType>::eXport_if(KeyType* keys, ValueType* values, size_t* counter,
uint64_t threshold, cudaStream_t stream) {
int64_t num_keys = rows();
int64_t dim = cols();

// `keys` and `values` are pointers of host memory
KeyType* d_keys;
CUDACHECK(cudaMallocManaged(&d_keys, sizeof(KeyType) * num_keys));
ValueType* d_values;
CUDACHECK(cudaMallocManaged(&d_values, sizeof(ValueType) * num_keys * dim));

uint64_t* d_socre_type;
CUDACHECK(cudaMallocManaged(&d_socre_type, sizeof(uint64_t) * num_keys));

uint64_t* d_dump_counter;
CUDACHECK(cudaMallocManaged(&d_dump_counter, sizeof(uint64_t)));
// useless HKV need a input , but do nothing in the ExportIfPredFunctor
KeyType pattern = 100;

hkv_table_->template export_batch_if<ExportIfPredFunctor>(
pattern, threshold, hkv_table_->capacity(), 0, d_dump_counter, d_keys, d_values, d_socre_type,
stream);
CUDACHECK(cudaStreamSynchronize(stream));
// clang-format off
std::memcpy(keys, d_keys, sizeof(KeyType) * (*d_dump_counter));
std::memcpy(values, d_values, sizeof(ValueType) * (*d_dump_counter) * dim);
counter[0] = (size_t)(*d_dump_counter);
// clang-format on
CUDACHECK(cudaFree(d_keys));
CUDACHECK(cudaFree(d_values));
CUDACHECK(cudaFree(d_socre_type));
CUDACHECK(cudaFree(d_dump_counter));
}

template <typename KeyType, typename ValueType>
Expand Down
2 changes: 2 additions & 0 deletions sparse_operation_kit/kit_src/variable/impl/hkv_variable.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ class HKVVariable : public VariableBase<KeyType, ValueType> {
int64_t cols() override;

void eXport(KeyType *keys, ValueType *values, cudaStream_t stream = 0) override;
void eXport_if(KeyType *keys, ValueType *values, size_t *counter, uint64_t threshold,
cudaStream_t stream = 0) override;
void assign(const KeyType *keys, const ValueType *values, size_t num_keys,
cudaStream_t stream = 0) override;

Expand Down
4 changes: 4 additions & 0 deletions sparse_operation_kit/kit_src/variable/impl/variable_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ class VariableBase {
virtual int64_t cols() = 0;

virtual void eXport(KeyType *keys, ValueType *values, cudaStream_t stream = 0) = 0;

virtual void eXport_if(KeyType *keys, ValueType *values, size_t *counter, uint64_t threshold,
cudaStream_t stream = 0) = 0;

virtual void assign(const KeyType *keys, const ValueType *values, size_t num_keys,
cudaStream_t stream = 0) = 0;

Expand Down
6 changes: 6 additions & 0 deletions sparse_operation_kit/kit_src/variable/kernels/dummy_var.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,12 @@ void DummyVar<KeyType, ValueType>::Export(void* keys, void* values, cudaStream_t
var_->eXport(static_cast<KeyType*>(keys), static_cast<ValueType*>(values), stream);
}

template <typename KeyType, typename ValueType>
void DummyVar<KeyType, ValueType>::ExportIf(void* keys, void* values,size_t* counter,uint64_t threshold, cudaStream_t stream) {
check_var();
var_->eXport_if(static_cast<KeyType*>(keys), static_cast<ValueType*>(values),counter,threshold, stream);
}

template <typename KeyType, typename ValueType>
void DummyVar<KeyType, ValueType>::Assign(const void* keys, const void* values, size_t num_keys,
cudaStream_t stream) {
Expand Down
1 change: 1 addition & 0 deletions sparse_operation_kit/kit_src/variable/kernels/dummy_var.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class DummyVar : public ResourceBase {
int64_t cols();

void Export(void *keys, void *values, cudaStream_t stream);
void ExportIf(void *keys, void *values, size_t *counter, uint64_t threshold, cudaStream_t stream);
void Assign(const void *keys, const void *values, size_t num_keys, cudaStream_t stream);

void SparseRead(const void *keys, void *values, size_t num_keys, cudaStream_t stream);
Expand Down
69 changes: 69 additions & 0 deletions sparse_operation_kit/kit_src/variable/kernels/dummy_var_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,75 @@ REGISTER_GPU_KERNELS(int32_t, int32_t, float, float);
#endif
#undef REGISTER_GPU_KERNELS

// -----------------------------------------------------------------------------------------------
// DummyVarExportIf
// -----------------------------------------------------------------------------------------------
template <typename KeyType, typename ValueType>
class DummyVarExportIfOp : public OpKernel {
public:
explicit DummyVarExportIfOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

void Compute(OpKernelContext* ctx) override {
// Get DummyVar
core::RefCountPtr<DummyVar<KeyType, ValueType>> var;
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &var));

tf_shared_lock ml(*var->mu());

// Get shape
int64_t rows = var->rows();
int64_t cols = var->cols();

const Tensor* threshold_tensor = nullptr;
OP_REQUIRES_OK(ctx, ctx->input("threshold", &threshold_tensor));

AllocatorAttributes alloc_attr;
alloc_attr.set_on_host(true);
// temp buffer
Tensor tmp_indices;
OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_INT64, {rows}, &tmp_indices, alloc_attr));

Tensor tmp_values;
OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<ValueType>::v(), {rows*cols}, &tmp_values, alloc_attr));

Tensor counter;
OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_UINT64, {1}, &counter, alloc_attr));
// Get cuda stream of tensorflow
auto device_ctx = ctx->op_device_context();
OP_REQUIRES(ctx, device_ctx != nullptr, errors::Aborted("No valid device context."));
cudaStream_t stream = stream_executor::gpu::AsGpuStreamValue(device_ctx->stream());
var->ExportIf(tmp_indices.data(), tmp_values.data(),(size_t*)counter.data(),((uint64_t*)threshold_tensor->data())[0], stream);
// Allocate output
Tensor* indices = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {((size_t*)counter.data())[0]}, &indices));
Tensor* values = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(1, {((size_t*)counter.data())[0], cols}, &values));

std::memcpy(indices->data(), tmp_indices.data(), sizeof(KeyType) * ((size_t*)counter.data())[0]);
std::memcpy(values->data(), tmp_values.data(), sizeof(ValueType) * ((size_t*)counter.data())[0] * cols);
}
};

#define REGISTER_GPU_KERNELS(key_type_tf, key_type, dtype_tf, dtype) \
REGISTER_KERNEL_BUILDER(Name("DummyVarExportIf") \
.Device(DEVICE_GPU) \
.HostMemory("resource") \
.HostMemory("threshold") \
.HostMemory("indices") \
.HostMemory("values") \
.TypeConstraint<key_type_tf>("key_type") \
.TypeConstraint<dtype_tf>("dtype"), \
DummyVarExportIfOp<key_type, dtype>)
#if TF_VERSION_MAJOR == 1
REGISTER_GPU_KERNELS(int64, int64_t, float, float);
REGISTER_GPU_KERNELS(int32, int32_t, float, float);
#else
REGISTER_GPU_KERNELS(int64_t, int64_t, float, float);
REGISTER_GPU_KERNELS(int32_t, int32_t, float, float);
#endif
#undef REGISTER_GPU_KERNELS


// -----------------------------------------------------------------------------------------------
// DummyVarSparseRead
// -----------------------------------------------------------------------------------------------
Expand Down
10 changes: 10 additions & 0 deletions sparse_operation_kit/kit_src/variable/ops/dummy_var_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,16 @@ REGISTER_OP("DummyVarExport")
.Attr("dtype: {float32} = DT_FLOAT")
.SetShapeFn([](InferenceContext* c) { return sok_tsl_status(); });

REGISTER_OP("DummyVarExportIf")
.Input("resource: resource")
.Input("threshold: uint64")
.Output("indices: key_type")
.Output("values: dtype")
.Attr("key_type: {int32, int64} = DT_INT64")
.Attr("dtype: {float32} = DT_FLOAT")
.SetShapeFn([](InferenceContext* c) { return sok_tsl_status(); });


REGISTER_OP("DummyVarSparseRead")
.Input("resource: resource")
.Input("indices: key_type")
Expand Down
2 changes: 1 addition & 1 deletion sparse_operation_kit/sparse_operation_kit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
from sparse_operation_kit.lookup import lookup_sparse
from sparse_operation_kit.lookup import all2all_dense_embedding

from sparse_operation_kit.dump_load import dump, load
from sparse_operation_kit.dump_load import dump, load, incremental_model_dump


# a specific code path for dl framework tf2.11.0
Expand Down
Loading

0 comments on commit e5f021f

Please sign in to comment.