Skip to content

Commit

Permalink
Standardize workspace abstraction (#2524)
Browse files Browse the repository at this point in the history
  • Loading branch information
amberhassaan authored Dec 16, 2023
1 parent b17d080 commit 62a0534
Show file tree
Hide file tree
Showing 17 changed files with 595 additions and 581 deletions.
21 changes: 21 additions & 0 deletions src/ocl/convolutionocl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,18 @@ static inline void ValidateGroupCount(const TensorDescriptor& x,
MIOPEN_THROW(miopenStatusBadParm, "Invalid group number");
}

static inline void ValidateWorkspace(Data_t workSpace, const size_t workSpaceSize)
{

[[maybe_unused]] bool x = (workSpace != nullptr);
[[maybe_unused]] bool y = (workSpaceSize != 0);

assert(((x && y) || (!x && !y)) && "workspace pointer and size don't match. Either both should "
"be zero or both should be non-zero");

/// \todo could add a check here that workSpace points to GPU memory
}

static Invoker PrepareInvoker(ExecutionContext ctx,
const conv::ProblemDescription& problem,
const NetworkConfig& config,
Expand Down Expand Up @@ -260,6 +272,7 @@ void ConvolutionDescriptor::FindConvFwdAlgorithm(Handle& handle,
bool exhaustiveSearch) const
{
MIOPEN_LOG_I("requestAlgoCount = " << requestAlgoCount << ", workspace = " << workSpaceSize);
ValidateWorkspace(workSpace, workSpaceSize);
if(x == nullptr || w == nullptr || y == nullptr)
MIOPEN_THROW(miopenStatusBadParm, "Buffers cannot be NULL");
if(returnedAlgoCount == nullptr)
Expand Down Expand Up @@ -495,6 +508,7 @@ void ConvolutionDescriptor::ConvolutionForward(Handle& handle,
size_t workSpaceSize) const
{
MIOPEN_LOG_I("algo = " << algo << ", workspace = " << workSpaceSize);
ValidateWorkspace(workSpace, workSpaceSize);

const auto tensors = ConvFwdTensors{xDesc, x, wDesc, w, yDesc, y};
ValidateTensors(tensors);
Expand Down Expand Up @@ -812,6 +826,7 @@ void ConvolutionDescriptor::ConvolutionForwardImmediate(Handle& handle,
const solver::Id solver_id) const
{
MIOPEN_LOG_I("solver_id = " << solver_id.ToString() << ", workspace = " << workSpaceSize);
ValidateWorkspace(workSpace, workSpaceSize);
const auto tensors = ConvFwdTensors{xDesc, x, wDesc, w, yDesc, y};

ValidateTensors(tensors);
Expand Down Expand Up @@ -846,6 +861,7 @@ void ConvolutionDescriptor::FindConvBwdDataAlgorithm(Handle& handle,
bool exhaustiveSearch) const
{
MIOPEN_LOG_I("requestAlgoCount = " << requestAlgoCount << ", workspace = " << workSpaceSize);
ValidateWorkspace(workSpace, workSpaceSize);
if(dx == nullptr || w == nullptr || dy == nullptr)
MIOPEN_THROW(miopenStatusBadParm, "Buffers cannot be NULL");
if(returnedAlgoCount == nullptr)
Expand Down Expand Up @@ -944,6 +960,7 @@ void ConvolutionDescriptor::ConvolutionBackwardData(Handle& handle,
size_t workSpaceSize) const
{
MIOPEN_LOG_I("algo = " << algo << ", workspace = " << workSpaceSize);
ValidateWorkspace(workSpace, workSpaceSize);

auto tensors = ConvBwdTensors{dyDesc, dy, wDesc, w, dxDesc, dx};

Expand Down Expand Up @@ -1015,6 +1032,7 @@ void ConvolutionDescriptor::ConvolutionBackwardImmediate(Handle& handle,
solver::Id solver_id) const
{
MIOPEN_LOG_I("solver_id = " << solver_id.ToString() << ", workspace = " << workSpaceSize);
ValidateWorkspace(workSpace, workSpaceSize);
auto tensors = ConvBwdTensors{dyDesc, dy, wDesc, w, dxDesc, dx};

ValidateTensors(tensors);
Expand Down Expand Up @@ -1055,6 +1073,7 @@ void ConvolutionDescriptor::FindConvBwdWeightsAlgorithm(Handle& handle,
bool exhaustiveSearch) const
{
MIOPEN_LOG_I("requestAlgoCount = " << requestAlgoCount << ", workspace = " << workSpaceSize);
ValidateWorkspace(workSpace, workSpaceSize);
if(x == nullptr || dw == nullptr || dy == nullptr)
MIOPEN_THROW(miopenStatusBadParm, "Buffers cannot be NULL");
if(returnedAlgoCount == nullptr)
Expand Down Expand Up @@ -1151,6 +1170,7 @@ void ConvolutionDescriptor::ConvolutionBackwardWeights(const Handle& handle,
size_t workSpaceSize) const
{
MIOPEN_LOG_I("algo = " << algo << ", workspace = " << workSpaceSize);
ValidateWorkspace(workSpace, workSpaceSize);
decltype(auto) tensors = ConvWrwTensors{dyDesc, dy, xDesc, x, dwDesc, dw};
ValidateTensors(tensors);
ValidateAlphaBeta(alpha, beta);
Expand Down Expand Up @@ -1218,6 +1238,7 @@ void ConvolutionDescriptor::ConvolutionWrwImmediate(Handle& handle,
solver::Id solver_id) const
{
MIOPEN_LOG_I("solver_id = " << solver_id.ToString() << ", workspace = " << workSpaceSize);
ValidateWorkspace(workSpace, workSpaceSize);
auto tensors = ConvWrwTensors{dyDesc, dy, xDesc, x, dwDesc, dw};
ValidateTensors(tensors);

Expand Down
Loading

0 comments on commit 62a0534

Please sign in to comment.