From d51557ee3a0940f535d121a49d039be695c18efb Mon Sep 17 00:00:00 2001 From: Robin Davies Date: Sat, 19 Oct 2024 02:20:57 -0400 Subject: [PATCH] Activitions incorrectly applied to row-based blocks in Wavnet. --- NAM/activations.h | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/NAM/activations.h b/NAM/activations.h index fe47203..c518385 100644 --- a/NAM/activations.h +++ b/NAM/activations.h @@ -53,11 +53,24 @@ class Activation Activation() = default; virtual ~Activation() = default; virtual void apply(Eigen::MatrixXf& matrix) { apply(matrix.data(), matrix.rows() * matrix.cols()); } - virtual void apply(Eigen::Block block) { apply(block.data(), block.rows() * block.cols()); } - virtual void apply(Eigen::Block block) - { - apply(block.data(), block.rows() * block.cols()); + + + virtual void apply(Eigen::Block block) { + // true -> A set of columns in column major order, or a set of rows in row major order. All data is contiguous. + this->apply(block.data(),(long)(block.rows()*block.cols())); + } + + virtual void apply(Eigen::Block block) { + // Overload for non-contiguous blocks. Apply column by column + for (int c = 0; c < block.cols(); ++c) + { + float *mem = &block.coeffRef(0,c); + this->apply(mem,(long)block.rows()); + } } + + + virtual void apply(float* data, long size) {} static Activation* get_activation(const std::string name);