Skip to content

Commit

Permalink
Remove 'sample' parameter from stats::mean API (#2389)
Browse files Browse the repository at this point in the history
This PR removes the sample-parameter from the `raft::stats::mean` API to prevent people from using it by accident when for example computing the mean for a sampled variance computation.

This also invalidates some of the testcases. Within raft only test-code is affected by this change as the active usage of the sample parameter was already removed in #2381. 

This PR is based on #2381 but was separated for tracking purposes.

~~Note that this requires adaption of downstream libraries using the API. I am aware of at least one occurrence in `cuml`.~~
The old API remains in the code marked as deprecated which allows us to adapt downstream libraries at least for the duration of one release cycle.

Authors:
  - Malte Förster (https://github.com/mfoerste4)
  - Tamas Bela Feher (https://github.com/tfeher)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #2389
  • Loading branch information
mfoerste4 authored Jan 30, 2025
1 parent 31d3151 commit cceb37d
Show file tree
Hide file tree
Showing 8 changed files with 184 additions and 187 deletions.
20 changes: 19 additions & 1 deletion cpp/include/raft/stats/detail/mean.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,25 @@ namespace stats {
namespace detail {

template <typename Type, typename IdxType = int>
void mean(
void mean(Type* mu, const Type* data, IdxType D, IdxType N, bool rowMajor, cudaStream_t stream)
{
Type ratio = Type(1) / Type(N);
raft::linalg::reduce(mu,
data,
D,
N,
Type(0),
rowMajor,
false,
stream,
false,
raft::identity_op(),
raft::add_op(),
raft::mul_const_op<Type>(ratio));
}

template <typename Type, typename IdxType = int>
[[deprecated]] void mean(
Type* mu, const Type* data, IdxType D, IdxType N, bool sample, bool rowMajor, cudaStream_t stream)
{
Type ratio = Type(1) / ((sample) ? Type(N - 1) : Type(N));
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/stats/detail/scores.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ math_t r2_score(math_t* y, math_t* y_hat, int n, cudaStream_t stream)
{
rmm::device_scalar<math_t> y_bar(stream);

raft::stats::mean(y_bar.data(), y, 1, n, false, false, stream);
raft::stats::mean(y_bar.data(), y, 1, n, false, stream);
RAFT_CUDA_TRY(cudaPeekAtLastError());

rmm::device_uvector<math_t> sse_arr(n, stream);
Expand Down
66 changes: 60 additions & 6 deletions cpp/include/raft/stats/mean.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2018-2023, NVIDIA CORPORATION.
* Copyright (c) 2018-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -38,14 +38,35 @@ namespace stats {
* @param data: the input matrix
* @param D: number of columns of data
* @param N: number of rows of data
* @param rowMajor: whether the input data is row or col major
* @param stream: cuda stream
*/
template <typename Type, typename IdxType = int>
void mean(Type* mu, const Type* data, IdxType D, IdxType N, bool rowMajor, cudaStream_t stream)
{
detail::mean(mu, data, D, N, rowMajor, stream);
}

/**
* @brief Compute mean of the input matrix
*
* Mean operation is assumed to be performed on a given column.
* Note: This call is deprecated, please use `mean` call without `sample` parameter.
*
* @tparam Type: the data type
* @tparam IdxType Integer type used to for addressing
* @param mu: the output mean vector
* @param data: the input matrix
* @param D: number of columns of data
* @param N: number of rows of data
* @param sample: whether to evaluate sample mean or not. In other words,
* whether
* to normalize the output using N-1 or N, for true or false, respectively
* @param rowMajor: whether the input data is row or col major
* @param stream: cuda stream
*/
template <typename Type, typename IdxType = int>
void mean(
[[deprecated("'sample' parameter deprecated")]] void mean(
Type* mu, const Type* data, IdxType D, IdxType N, bool sample, bool rowMajor, cudaStream_t stream)
{
detail::mean(mu, data, D, N, sample, rowMajor, stream);
Expand All @@ -67,14 +88,47 @@ void mean(
* @param[in] handle the raft handle
* @param[in] data: the input matrix
* @param[out] mu: the output mean vector
* @param[in] sample: whether to evaluate sample mean or not. In other words, whether
* to normalize the output using N-1 or N, for true or false, respectively
*/
template <typename value_t, typename idx_t, typename layout_t>
void mean(raft::resources const& handle,
raft::device_matrix_view<const value_t, idx_t, layout_t> data,
raft::device_vector_view<value_t, idx_t> mu,
bool sample)
raft::device_vector_view<value_t, idx_t> mu)
{
static_assert(
std::is_same_v<layout_t, raft::row_major> || std::is_same_v<layout_t, raft::col_major>,
"Data layout not supported");
RAFT_EXPECTS(data.extent(1) == mu.extent(0), "Size mismatch between data and mu");
RAFT_EXPECTS(mu.is_exhaustive(), "mu must be contiguous");
RAFT_EXPECTS(data.is_exhaustive(), "data must be contiguous");
detail::mean(mu.data_handle(),
data.data_handle(),
data.extent(1),
data.extent(0),
std::is_same_v<layout_t, raft::row_major>,
resource::get_cuda_stream(handle));
}

/**
* @brief Compute mean of the input matrix
*
* Mean operation is assumed to be performed on a given column.
* Note: This call is deprecated, please use `mean` call without `sample` parameter.
*
* @tparam value_t the data type
* @tparam idx_t index type
* @tparam layout_t Layout type of the input matrix.
* @param[in] handle the raft handle
* @param[in] data: the input matrix
* @param[out] mu: the output mean vector
* @param[in] sample: whether to evaluate sample mean or not. In other words, whether
* to normalize the output using N-1 or N, for true or false, respectively
*/
template <typename value_t, typename idx_t, typename layout_t>
[[deprecated("'sample' parameter deprecated")]] void mean(
raft::resources const& handle,
raft::device_matrix_view<const value_t, idx_t, layout_t> data,
raft::device_vector_view<value_t, idx_t> mu,
bool sample)
{
static_assert(
std::is_same_v<layout_t, raft::row_major> || std::is_same_v<layout_t, raft::col_major>,
Expand Down
3 changes: 1 addition & 2 deletions cpp/tests/random/rng.cu
Original file line number Diff line number Diff line change
Expand Up @@ -407,8 +407,7 @@ TEST(Rng, MeanError)
RngState r(seed, rtype);
normal(handle, r, data.data(), len, 3.3f, 0.23f);
// uniform(r, data, len, -1.0, 2.0);
raft::stats::mean(
mean_result.data(), data.data(), num_samples, num_experiments, false, false, stream);
raft::stats::mean(mean_result.data(), data.data(), num_samples, num_experiments, false, stream);
raft::stats::stddev(std_result.data(),
data.data(),
mean_result.data(),
Expand Down
4 changes: 2 additions & 2 deletions cpp/tests/stats/cov.cu
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class CovTest : public ::testing::TestWithParam<CovInputs<T>> {
cov_act.resize(cols * cols, stream);

normal(handle, r, data.data(), len, params.mean, var);
raft::stats::mean(mean_act.data(), data.data(), cols, rows, false, params.rowMajor, stream);
raft::stats::mean(mean_act.data(), data.data(), cols, rows, params.rowMajor, stream);
if (params.rowMajor) {
using layout = raft::row_major;
cov(handle,
Expand Down Expand Up @@ -102,7 +102,7 @@ class CovTest : public ::testing::TestWithParam<CovInputs<T>> {
raft::update_device(data_cm.data(), data_h, 6, stream);
raft::update_device(cov_cm_ref.data(), cov_cm_ref_h, 4, stream);

raft::stats::mean(mean_cm.data(), data_cm.data(), 2, 3, false, false, stream);
raft::stats::mean(mean_cm.data(), data_cm.data(), 2, 3, false, stream);
cov(handle, cov_cm.data(), data_cm.data(), mean_cm.data(), 2, 3, true, false, true, stream);
}

Expand Down
121 changes: 49 additions & 72 deletions cpp/tests/stats/mean.cu
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ template <typename T>
struct MeanInputs {
T tolerance, mean;
int rows, cols;
bool sample, rowMajor;
bool rowMajor;
unsigned long long int seed;
T stddev = (T)1.0;
};
Expand All @@ -42,7 +42,7 @@ template <typename T>
::std::ostream& operator<<(::std::ostream& os, const MeanInputs<T>& dims)
{
return os << "{ " << dims.tolerance << ", " << dims.rows << ", " << dims.cols << ", "
<< dims.sample << ", " << dims.rowMajor << ", " << dims.stddev << "}" << std::endl;
<< ", " << dims.rowMajor << ", " << dims.stddev << "}" << std::endl;
}

template <typename T>
Expand Down Expand Up @@ -74,14 +74,12 @@ class MeanTest : public ::testing::TestWithParam<MeanInputs<T>> {
using layout = raft::row_major;
mean(handle,
raft::make_device_matrix_view<const T, int, layout>(data, rows, cols),
raft::make_device_vector_view<T, int>(mean_act.data(), cols),
params.sample);
raft::make_device_vector_view<T, int>(mean_act.data(), cols));
} else {
using layout = raft::col_major;
mean(handle,
raft::make_device_matrix_view<const T, int, layout>(data, rows, cols),
raft::make_device_vector_view<T, int>(mean_act.data(), cols),
params.sample);
raft::make_device_vector_view<T, int>(mean_act.data(), cols));
}
}

Expand All @@ -98,72 +96,51 @@ class MeanTest : public ::testing::TestWithParam<MeanInputs<T>> {
// measured mean (of a normal distribution) will fall outside of an epsilon of
// 0.15 only 4/10000 times. (epsilon of 0.1 will fail 30/100 times)
const std::vector<MeanInputs<float>> inputsf = {
{0.15f, 1.f, 1024, 32, true, false, 1234ULL},
{0.15f, 1.f, 1024, 64, true, false, 1234ULL},
{0.15f, 1.f, 1024, 128, true, false, 1234ULL},
{0.15f, 1.f, 1024, 256, true, false, 1234ULL},
{0.15f, -1.f, 1024, 32, false, false, 1234ULL},
{0.15f, -1.f, 1024, 64, false, false, 1234ULL},
{0.15f, -1.f, 1024, 128, false, false, 1234ULL},
{0.15f, -1.f, 1024, 256, false, false, 1234ULL},
{0.15f, 1.f, 1024, 32, true, true, 1234ULL},
{0.15f, 1.f, 1024, 64, true, true, 1234ULL},
{0.15f, 1.f, 1024, 128, true, true, 1234ULL},
{0.15f, 1.f, 1024, 256, true, true, 1234ULL},
{0.15f, -1.f, 1024, 32, false, true, 1234ULL},
{0.15f, -1.f, 1024, 64, false, true, 1234ULL},
{0.15f, -1.f, 1024, 128, false, true, 1234ULL},
{0.15f, -1.f, 1024, 256, false, true, 1234ULL},
{0.15f, -1.f, 1030, 1, false, false, 1234ULL},
{0.15f, -1.f, 1030, 60, true, false, 1234ULL},
{2.0f, -1.f, 31, 120, false, false, 1234ULL},
{2.0f, -1.f, 1, 130, false, false, 1234ULL},
{0.15f, -1.f, 1030, 1, false, true, 1234ULL},
{0.15f, -1.f, 1030, 60, true, true, 1234ULL},
{2.0f, -1.f, 31, 120, false, true, 1234ULL},
{2.0f, -1.f, 1, 130, false, true, 1234ULL},
{2.0f, -1.f, 1, 1, false, false, 1234ULL},
{2.0f, -1.f, 1, 1, false, true, 1234ULL},
{2.0f, -1.f, 7, 23, false, false, 1234ULL},
{2.0f, -1.f, 7, 23, false, true, 1234ULL},
{2.0f, -1.f, 17, 5, false, false, 1234ULL},
{2.0f, -1.f, 17, 5, false, true, 1234ULL},
{0.0001f, 0.1f, 1 << 27, 2, false, false, 1234ULL, 0.0001f},
{0.0001f, 0.1f, 1 << 27, 2, false, true, 1234ULL, 0.0001f}};

const std::vector<MeanInputs<double>> inputsd = {
{0.15, 1.0, 1024, 32, true, false, 1234ULL},
{0.15, 1.0, 1024, 64, true, false, 1234ULL},
{0.15, 1.0, 1024, 128, true, false, 1234ULL},
{0.15, 1.0, 1024, 256, true, false, 1234ULL},
{0.15, -1.0, 1024, 32, false, false, 1234ULL},
{0.15, -1.0, 1024, 64, false, false, 1234ULL},
{0.15, -1.0, 1024, 128, false, false, 1234ULL},
{0.15, -1.0, 1024, 256, false, false, 1234ULL},
{0.15, 1.0, 1024, 32, true, true, 1234ULL},
{0.15, 1.0, 1024, 64, true, true, 1234ULL},
{0.15, 1.0, 1024, 128, true, true, 1234ULL},
{0.15, 1.0, 1024, 256, true, true, 1234ULL},
{0.15, -1.0, 1024, 32, false, true, 1234ULL},
{0.15, -1.0, 1024, 64, false, true, 1234ULL},
{0.15, -1.0, 1024, 128, false, true, 1234ULL},
{0.15, -1.0, 1024, 256, false, true, 1234ULL},
{0.15, -1.0, 1030, 1, false, false, 1234ULL},
{0.15, -1.0, 1030, 60, true, false, 1234ULL},
{2.0, -1.0, 31, 120, false, false, 1234ULL},
{2.0, -1.0, 1, 130, false, false, 1234ULL},
{0.15, -1.0, 1030, 1, false, true, 1234ULL},
{0.15, -1.0, 1030, 60, true, true, 1234ULL},
{2.0, -1.0, 31, 120, false, true, 1234ULL},
{2.0, -1.0, 1, 130, false, true, 1234ULL},
{2.0, -1.0, 1, 1, false, false, 1234ULL},
{2.0, -1.0, 1, 1, false, true, 1234ULL},
{2.0, -1.0, 7, 23, false, false, 1234ULL},
{2.0, -1.0, 7, 23, false, true, 1234ULL},
{2.0, -1.0, 17, 5, false, false, 1234ULL},
{2.0, -1.0, 17, 5, false, true, 1234ULL},
{1e-8, 1e-1, 1 << 27, 2, false, false, 1234ULL, 0.0001},
{1e-8, 1e-1, 1 << 27, 2, false, true, 1234ULL, 0.0001}};
{0.15f, -1.f, 1024, 32, false, 1234ULL},
{0.15f, -1.f, 1024, 64, false, 1234ULL},
{0.15f, -1.f, 1024, 128, false, 1234ULL},
{0.15f, -1.f, 1024, 256, false, 1234ULL},
{0.15f, -1.f, 1024, 32, true, 1234ULL},
{0.15f, -1.f, 1024, 64, true, 1234ULL},
{0.15f, -1.f, 1024, 128, true, 1234ULL},
{0.15f, -1.f, 1024, 256, true, 1234ULL},
{0.15f, -1.f, 1030, 1, false, 1234ULL},
{2.0f, -1.f, 31, 120, false, 1234ULL},
{2.0f, -1.f, 1, 130, false, 1234ULL},
{0.15f, -1.f, 1030, 1, true, 1234ULL},
{2.0f, -1.f, 31, 120, true, 1234ULL},
{2.0f, -1.f, 1, 130, true, 1234ULL},
{2.0f, -1.f, 1, 1, false, 1234ULL},
{2.0f, -1.f, 1, 1, true, 1234ULL},
{2.0f, -1.f, 7, 23, false, 1234ULL},
{2.0f, -1.f, 7, 23, true, 1234ULL},
{2.0f, -1.f, 17, 5, false, 1234ULL},
{2.0f, -1.f, 17, 5, true, 1234ULL},
{0.0001f, 0.1f, 1 << 27, 2, false, 1234ULL, 0.0001f},
{0.0001f, 0.1f, 1 << 27, 2, true, 1234ULL, 0.0001f}};

const std::vector<MeanInputs<double>> inputsd = {{0.15, -1.0, 1024, 32, false, 1234ULL},
{0.15, -1.0, 1024, 64, false, 1234ULL},
{0.15, -1.0, 1024, 128, false, 1234ULL},
{0.15, -1.0, 1024, 256, false, 1234ULL},
{0.15, -1.0, 1024, 32, true, 1234ULL},
{0.15, -1.0, 1024, 64, true, 1234ULL},
{0.15, -1.0, 1024, 128, true, 1234ULL},
{0.15, -1.0, 1024, 256, true, 1234ULL},
{0.15, -1.0, 1030, 1, false, 1234ULL},
{2.0, -1.0, 31, 120, false, 1234ULL},
{2.0, -1.0, 1, 130, false, 1234ULL},
{0.15, -1.0, 1030, 1, true, 1234ULL},
{2.0, -1.0, 31, 120, true, 1234ULL},
{2.0, -1.0, 1, 130, true, 1234ULL},
{2.0, -1.0, 1, 1, false, 1234ULL},
{2.0, -1.0, 1, 1, true, 1234ULL},
{2.0, -1.0, 7, 23, false, 1234ULL},
{2.0, -1.0, 7, 23, true, 1234ULL},
{2.0, -1.0, 17, 5, false, 1234ULL},
{2.0, -1.0, 17, 5, true, 1234ULL},
{1e-8, 1e-1, 1 << 27, 2, false, 1234ULL, 0.0001},
{1e-8, 1e-1, 1 << 27, 2, true, 1234ULL, 0.0001}};

typedef MeanTest<float> MeanTestF;
TEST_P(MeanTestF, Result)
Expand Down
Loading

0 comments on commit cceb37d

Please sign in to comment.