Skip to content

Commit

Permalink
Reduced memory storages in intergal V_eff N_i N_j
Browse files Browse the repository at this point in the history
  • Loading branch information
Avirup Sircar committed Nov 12, 2024
1 parent 90dfef6 commit b0518b1
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 145 deletions.
21 changes: 11 additions & 10 deletions src/basis/EFEBDSOnTheFlyComputeDealii.t.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1800,16 +1800,17 @@ namespace dftefe
for (size_type quadId = 0; quadId < d_nQuadPointsIncell[cellId];
quadId++)
{
for (size_type iDim = 0; iDim < dim; iDim++)
{
basisGradientData.template copyFrom<memorySpace>(
iter->second->data(),
(d_dofsInCell[cellId] - d_classialDofsInCell),
(d_dofsInCell[cellId] - d_classialDofsInCell) * dim *
quadId + iDim * (d_dofsInCell[cellId] - d_classialDofsInCell),
cumulativeOffset + d_dofsInCell[cellId] * dim * quadId +
d_dofsInCell[cellId] * iDim + d_classialDofsInCell);
}
for (size_type iDim = 0; iDim < dim; iDim++)
{
basisGradientData.template copyFrom<memorySpace>(
iter->second->data(),
(d_dofsInCell[cellId] - d_classialDofsInCell),
(d_dofsInCell[cellId] - d_classialDofsInCell) * dim *
quadId +
iDim * (d_dofsInCell[cellId] - d_classialDofsInCell),
cumulativeOffset + d_dofsInCell[cellId] * dim * quadId +
d_dofsInCell[cellId] * iDim + d_classialDofsInCell);
}
}
}
cumulativeOffset +=
Expand Down
132 changes: 54 additions & 78 deletions src/basis/FEBasisOperations.t.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,6 @@ namespace dftefe
sameQuadRuleInAllCells && (!variableDofsPerCell);
linearAlgebra::blasLapack::Layout layout =
linearAlgebra::blasLapack::Layout::ColMajor;
// size_type NStartOffset = 0;
size_type NifNjStartOffset = 0;

for (size_type cellStartId = 0; cellStartId < numLocallyOwnedCells;
Expand All @@ -156,6 +155,11 @@ namespace dftefe
numCellQuad.begin() + cellEndId,
numCellsInBlockQuad.begin());

const size_type numCumulativeQuadCellsInBlock =
std::accumulate(numCellsInBlockQuad.begin(),
numCellsInBlockQuad.end(),
0);

size_type numCumulativeQuadxDofsCellsInBlock = 0;
size_type numCumulativeDofsxDofsCellsInBlock = 0;
for (size_type iCell = 0; iCell < numCellsInBlock; iCell++)
Expand All @@ -173,10 +177,11 @@ namespace dftefe
// variableStridedBlockScale

/** --- Storages --------- **/
StorageUnion fxNConjBlock(numCumulativeQuadxDofsCellsInBlock,
ValueTypeUnion());
StorageBasis JxWxNBlock(numCumulativeQuadxDofsCellsInBlock,
ValueTypeBasisData());
StorageUnion fxJxW(1 /*numComponents of f*/ *
numCumulativeQuadCellsInBlock,
ValueTypeUnion());
StorageUnion fxJxWxNConjBlock(
numCumulativeQuadxDofsCellsInBlock, ValueTypeUnion());
StorageBasis basisDataInCellRange(0);
if (!zeroStrideBasisVal)
basisDataInCellRange.resize(
Expand All @@ -187,6 +192,28 @@ namespace dftefe
ValueTypeBasisData());
/** --- Storages --------- **/

/*--------- Compute fxJxW -----------------*/
// TransposedKhatriRao product for inp and JxW
size_type cumulativeA = 0, cumulativeB = 0, cumulativeC = 0;
for (size_type iCell = 0; iCell < numCellsInBlock; iCell++)
{
linearAlgebra::blasLapack::khatriRaoProduct(
layout,
1,
1,
numCellsInBlockQuad[iCell],
jxwStorage.data() +
quadRuleContainer->getCellQuadStartId(cellStartId) +
cumulativeA,
f.begin(cellStartId) + cumulativeB,
fxJxW.data() + cumulativeC,
linAlgOpContext);
cumulativeA += numCellsInBlockQuad[iCell];
cumulativeB += numCellsInBlockQuad[iCell];
cumulativeC += numCellsInBlockQuad[iCell];
}

/*--------- Compute fxJxWxNConj -----------------*/
linearAlgebra::blasLapack::ScalarOp scalarOpA =
linearAlgebra::blasLapack::ScalarOp::Identity;
linearAlgebra::blasLapack::ScalarOp scalarOpB =
Expand All @@ -200,7 +227,7 @@ namespace dftefe

for (size_type iCell = 0; iCell < numCellsInBlock; iCell++)
{
mTmp[iCell] = 1; // only for f numComponents = 1
mTmp[iCell] = 1; // only for fxJxW numComponents = 1
nTmp[iCell] = numCellsInBlockDofs[iCell];
kTmp[iCell] = numCellsInBlockQuad[iCell];
stATmp[iCell] = mTmp[iCell] * kTmp[iCell];
Expand Down Expand Up @@ -234,70 +261,22 @@ namespace dftefe
linearAlgebra::blasLapack::scaleStridedVarBatched<
ValueTypeBasisCoeff,
ValueTypeBasisData,
memorySpace>(
numCellsInBlock,
layout,
scalarOpA,
scalarOpB,
stA.data(),
stB.data(),
stC.data(),
mSize.data(),
nSize.data(),
kSize.data(),
f.begin(cellStartId),
basisDataInCellRange.data(),
// (feBasisDataStorage->getBasisDataInAllCells()).begin() +
// NStartOffset,
fxNConjBlock.data(),
linAlgOpContext);

/*
size_type cumulativeA = 0, cumulativeB = 0, cumulativeC = 0;
for (size_type iCell = 0; iCell < numCellsInBlock; ++iCell)
{
linearAlgebra::blasLapack::transposedKhatriRaoProduct(linearAlgebra::blasLapack::Layout::ColMajor,
1,
numCellsInBlockDofs[iCell],
numCellsInBlockQuad[iCell],
f.data()
+ cumulativeA, (feBasisDataStorage->getBasisDataInAllCells()).
data()
+ NStartOffset + cumulativeB, fxNConjBlock.data() + cumulativeC,
linAlgOpContext);
cumulativeA += numCellsInBlockQuad[iCell];
if(!zeroStrideBasisVal)
cumulativeB += numCellsInBlockQuad[iCell] *
numCellsInBlockDofs[iCell]; cumulativeC +=
numCellsInBlockDofs[iCell] * numCellsInBlockQuad[iCell];
}
*/

scalarOpB = linearAlgebra::blasLapack::ScalarOp::Identity;

/* Other params from previous declarations*/
linearAlgebra::blasLapack::scaleStridedVarBatched<
ValueTypeBasisData,
ValueTypeBasisData,
memorySpace>(
numCellsInBlock,
layout,
scalarOpA,
scalarOpB,
stA.data(),
stB.data(),
stC.data(),
mSize.data(),
nSize.data(),
kSize.data(),
jxwStorage.data() +
quadRuleContainer->getCellQuadStartId(cellStartId),
basisDataInCellRange.data(),
// (feBasisDataStorage->getBasisDataInAllCells()).begin() +
// NStartOffset,
JxWxNBlock.data(),
linAlgOpContext);
memorySpace>(numCellsInBlock,
layout,
scalarOpA,
scalarOpB,
stA.data(),
stB.data(),
stC.data(),
mSize.data(),
nSize.data(),
kSize.data(),
fxJxW.data(),
basisDataInCellRange.data(),
fxJxWxNConjBlock.data(),
linAlgOpContext);

/*--------- Do the integration -----------------*/
std::vector<linearAlgebra::blasLapack::Op> transA(
numCellsInBlock, linearAlgebra::blasLapack::Op::NoTrans);
std::vector<linearAlgebra::blasLapack::Op> transB(
Expand All @@ -320,9 +299,10 @@ namespace dftefe
ldaSizesTmp[iCell] = mSizesTmp[iCell];
ldbSizesTmp[iCell] = nSizesTmp[iCell];
ldcSizesTmp[iCell] = mSizesTmp[iCell];
strideATmp[iCell] = mSizesTmp[iCell] * kSizesTmp[iCell];
strideCTmp[iCell] = mSizesTmp[iCell] * nSizesTmp[iCell];
strideBTmp[iCell] = kSizesTmp[iCell] * nSizesTmp[iCell];
if (!zeroStrideBasisVal)
strideATmp[iCell] = mSizesTmp[iCell] * kSizesTmp[iCell];
strideBTmp[iCell] = kSizesTmp[iCell] * nSizesTmp[iCell];
strideCTmp[iCell] = mSizesTmp[iCell] * nSizesTmp[iCell];
}

utils::MemoryStorage<size_type, memorySpace> mSizes(
Expand Down Expand Up @@ -389,19 +369,15 @@ namespace dftefe
nSizes.data(),
kSizes.data(),
alpha,
JxWxNBlock.data(),
basisDataInCellRange.data(),
ldaSizes.data(),
fxNConjBlock.data(),
fxJxWxNConjBlock.data(),
ldbSizes.data(),
beta,
C,
ldcSizes.data(),
linAlgOpContext);

// if (!zeroStrideBasisVal)
// {
// NStartOffset += numCumulativeQuadxDofsCellsInBlock;
// }
NifNjStartOffset += numCumulativeDofsxDofsCellsInBlock;
}
}
Expand Down
Loading

0 comments on commit b0518b1

Please sign in to comment.