From b0518b1ea8d227db2c430b8c2938710516af6a8b Mon Sep 17 00:00:00 2001 From: Avirup Sircar Date: Tue, 12 Nov 2024 08:50:11 -0500 Subject: [PATCH] Reduced memory storage in integral V_eff N_i N_j --- src/basis/EFEBDSOnTheFlyComputeDealii.t.cpp | 21 +-- src/basis/FEBasisOperations.t.cpp | 132 +++++++----------- .../OrthoEFEOverlapOperatorContext.t.cpp | 128 +++++++++-------- src/ksdft/Defaults.cpp | 2 +- 4 files changed, 138 insertions(+), 145 deletions(-) diff --git a/src/basis/EFEBDSOnTheFlyComputeDealii.t.cpp b/src/basis/EFEBDSOnTheFlyComputeDealii.t.cpp index 5fae6cd6..82d93fb0 100644 --- a/src/basis/EFEBDSOnTheFlyComputeDealii.t.cpp +++ b/src/basis/EFEBDSOnTheFlyComputeDealii.t.cpp @@ -1800,16 +1800,17 @@ namespace dftefe for (size_type quadId = 0; quadId < d_nQuadPointsIncell[cellId]; quadId++) { - for (size_type iDim = 0; iDim < dim; iDim++) - { - basisGradientData.template copyFrom( - iter->second->data(), - (d_dofsInCell[cellId] - d_classialDofsInCell), - (d_dofsInCell[cellId] - d_classialDofsInCell) * dim * - quadId + iDim * (d_dofsInCell[cellId] - d_classialDofsInCell), - cumulativeOffset + d_dofsInCell[cellId] * dim * quadId + - d_dofsInCell[cellId] * iDim + d_classialDofsInCell); - } + for (size_type iDim = 0; iDim < dim; iDim++) + { + basisGradientData.template copyFrom( + iter->second->data(), + (d_dofsInCell[cellId] - d_classialDofsInCell), + (d_dofsInCell[cellId] - d_classialDofsInCell) * dim * + quadId + + iDim * (d_dofsInCell[cellId] - d_classialDofsInCell), + cumulativeOffset + d_dofsInCell[cellId] * dim * quadId + + d_dofsInCell[cellId] * iDim + d_classialDofsInCell); + } } } cumulativeOffset += diff --git a/src/basis/FEBasisOperations.t.cpp b/src/basis/FEBasisOperations.t.cpp index 0177eb4b..d1bfd878 100644 --- a/src/basis/FEBasisOperations.t.cpp +++ b/src/basis/FEBasisOperations.t.cpp @@ -137,7 +137,6 @@ namespace dftefe sameQuadRuleInAllCells && (!variableDofsPerCell); linearAlgebra::blasLapack::Layout layout = 
linearAlgebra::blasLapack::Layout::ColMajor; - // size_type NStartOffset = 0; size_type NifNjStartOffset = 0; for (size_type cellStartId = 0; cellStartId < numLocallyOwnedCells; @@ -156,6 +155,11 @@ namespace dftefe numCellQuad.begin() + cellEndId, numCellsInBlockQuad.begin()); + const size_type numCumulativeQuadCellsInBlock = + std::accumulate(numCellsInBlockQuad.begin(), + numCellsInBlockQuad.end(), + 0); + size_type numCumulativeQuadxDofsCellsInBlock = 0; size_type numCumulativeDofsxDofsCellsInBlock = 0; for (size_type iCell = 0; iCell < numCellsInBlock; iCell++) @@ -173,10 +177,11 @@ namespace dftefe // variableStridedBlockScale /** --- Storages --------- **/ - StorageUnion fxNConjBlock(numCumulativeQuadxDofsCellsInBlock, - ValueTypeUnion()); - StorageBasis JxWxNBlock(numCumulativeQuadxDofsCellsInBlock, - ValueTypeBasisData()); + StorageUnion fxJxW(1 /*numComponents of f*/ * + numCumulativeQuadCellsInBlock, + ValueTypeUnion()); + StorageUnion fxJxWxNConjBlock( + numCumulativeQuadxDofsCellsInBlock, ValueTypeUnion()); StorageBasis basisDataInCellRange(0); if (!zeroStrideBasisVal) basisDataInCellRange.resize( @@ -187,6 +192,28 @@ namespace dftefe ValueTypeBasisData()); /** --- Storages --------- **/ + /*--------- Compute fxJxW -----------------*/ + // TransposedKhatriRao product for inp and JxW + size_type cumulativeA = 0, cumulativeB = 0, cumulativeC = 0; + for (size_type iCell = 0; iCell < numCellsInBlock; iCell++) + { + linearAlgebra::blasLapack::khatriRaoProduct( + layout, + 1, + 1, + numCellsInBlockQuad[iCell], + jxwStorage.data() + + quadRuleContainer->getCellQuadStartId(cellStartId) + + cumulativeA, + f.begin(cellStartId) + cumulativeB, + fxJxW.data() + cumulativeC, + linAlgOpContext); + cumulativeA += numCellsInBlockQuad[iCell]; + cumulativeB += numCellsInBlockQuad[iCell]; + cumulativeC += numCellsInBlockQuad[iCell]; + } + + /*--------- Compute fxJxWxNConj -----------------*/ linearAlgebra::blasLapack::ScalarOp scalarOpA = 
linearAlgebra::blasLapack::ScalarOp::Identity; linearAlgebra::blasLapack::ScalarOp scalarOpB = @@ -200,7 +227,7 @@ namespace dftefe for (size_type iCell = 0; iCell < numCellsInBlock; iCell++) { - mTmp[iCell] = 1; // only for f numComponents = 1 + mTmp[iCell] = 1; // only for fxJxW numComponents = 1 nTmp[iCell] = numCellsInBlockDofs[iCell]; kTmp[iCell] = numCellsInBlockQuad[iCell]; stATmp[iCell] = mTmp[iCell] * kTmp[iCell]; @@ -234,70 +261,22 @@ namespace dftefe linearAlgebra::blasLapack::scaleStridedVarBatched< ValueTypeBasisCoeff, ValueTypeBasisData, - memorySpace>( - numCellsInBlock, - layout, - scalarOpA, - scalarOpB, - stA.data(), - stB.data(), - stC.data(), - mSize.data(), - nSize.data(), - kSize.data(), - f.begin(cellStartId), - basisDataInCellRange.data(), - // (feBasisDataStorage->getBasisDataInAllCells()).begin() + - // NStartOffset, - fxNConjBlock.data(), - linAlgOpContext); - - /* - size_type cumulativeA = 0, cumulativeB = 0, cumulativeC = 0; - for (size_type iCell = 0; iCell < numCellsInBlock; ++iCell) - { - linearAlgebra::blasLapack::transposedKhatriRaoProduct(linearAlgebra::blasLapack::Layout::ColMajor, - 1, - numCellsInBlockDofs[iCell], - numCellsInBlockQuad[iCell], - f.data() - + cumulativeA, (feBasisDataStorage->getBasisDataInAllCells()). 
- data() - + NStartOffset + cumulativeB, fxNConjBlock.data() + cumulativeC, - linAlgOpContext); - cumulativeA += numCellsInBlockQuad[iCell]; - if(!zeroStrideBasisVal) - cumulativeB += numCellsInBlockQuad[iCell] * - numCellsInBlockDofs[iCell]; cumulativeC += - numCellsInBlockDofs[iCell] * numCellsInBlockQuad[iCell]; - } - */ - - scalarOpB = linearAlgebra::blasLapack::ScalarOp::Identity; - - /* Other params from previous declarations*/ - linearAlgebra::blasLapack::scaleStridedVarBatched< - ValueTypeBasisData, - ValueTypeBasisData, - memorySpace>( - numCellsInBlock, - layout, - scalarOpA, - scalarOpB, - stA.data(), - stB.data(), - stC.data(), - mSize.data(), - nSize.data(), - kSize.data(), - jxwStorage.data() + - quadRuleContainer->getCellQuadStartId(cellStartId), - basisDataInCellRange.data(), - // (feBasisDataStorage->getBasisDataInAllCells()).begin() + - // NStartOffset, - JxWxNBlock.data(), - linAlgOpContext); + memorySpace>(numCellsInBlock, + layout, + scalarOpA, + scalarOpB, + stA.data(), + stB.data(), + stC.data(), + mSize.data(), + nSize.data(), + kSize.data(), + fxJxW.data(), + basisDataInCellRange.data(), + fxJxWxNConjBlock.data(), + linAlgOpContext); + /*--------- Do the integration -----------------*/ std::vector transA( numCellsInBlock, linearAlgebra::blasLapack::Op::NoTrans); std::vector transB( @@ -320,9 +299,10 @@ namespace dftefe ldaSizesTmp[iCell] = mSizesTmp[iCell]; ldbSizesTmp[iCell] = nSizesTmp[iCell]; ldcSizesTmp[iCell] = mSizesTmp[iCell]; - strideATmp[iCell] = mSizesTmp[iCell] * kSizesTmp[iCell]; - strideCTmp[iCell] = mSizesTmp[iCell] * nSizesTmp[iCell]; - strideBTmp[iCell] = kSizesTmp[iCell] * nSizesTmp[iCell]; + if (!zeroStrideBasisVal) + strideATmp[iCell] = mSizesTmp[iCell] * kSizesTmp[iCell]; + strideBTmp[iCell] = kSizesTmp[iCell] * nSizesTmp[iCell]; + strideCTmp[iCell] = mSizesTmp[iCell] * nSizesTmp[iCell]; } utils::MemoryStorage mSizes( @@ -389,19 +369,15 @@ namespace dftefe nSizes.data(), kSizes.data(), alpha, - JxWxNBlock.data(), + 
basisDataInCellRange.data(), ldaSizes.data(), - fxNConjBlock.data(), + fxJxWxNConjBlock.data(), ldbSizes.data(), beta, C, ldcSizes.data(), linAlgOpContext); - // if (!zeroStrideBasisVal) - // { - // NStartOffset += numCumulativeQuadxDofsCellsInBlock; - // } NifNjStartOffset += numCumulativeDofsxDofsCellsInBlock; } } diff --git a/src/basis/OrthoEFEOverlapOperatorContext.t.cpp b/src/basis/OrthoEFEOverlapOperatorContext.t.cpp index 57b87cbe..3d2e4c01 100644 --- a/src/basis/OrthoEFEOverlapOperatorContext.t.cpp +++ b/src/basis/OrthoEFEOverlapOperatorContext.t.cpp @@ -599,25 +599,26 @@ namespace dftefe JxWxNCellConj.resize( dofsPerCell * nQuadPointInCellEnrichmentBlockEnrichment, 0); - m = 1, n = dofsPerCell, k = nQuadPointInCellEnrichmentBlockEnrichment; - - linearAlgebra::blasLapack::scaleStridedVarBatched( - 1, - linearAlgebra::blasLapack::Layout::ColMajor, - linearAlgebra::blasLapack::ScalarOp::Identity, - linearAlgebra::blasLapack::ScalarOp::Conj, - &stride, - &stride, - &stride, - &m, - &n, - &k, - cellJxWValuesEnrichmentBlockEnrichment.data(), - cumulativeEnrichmentBlockEnrichmentDofQuadPoints, - JxWxNCellConj.data(), - linAlgOpContext); + m = 1, n = dofsPerCell, + k = nQuadPointInCellEnrichmentBlockEnrichment; + + linearAlgebra::blasLapack::scaleStridedVarBatched< + ValueTypeOperator, + ValueTypeOperator, + memorySpace>(1, + linearAlgebra::blasLapack::Layout::ColMajor, + linearAlgebra::blasLapack::ScalarOp::Identity, + linearAlgebra::blasLapack::ScalarOp::Conj, + &stride, + &stride, + &stride, + &m, + &n, + &k, + cellJxWValuesEnrichmentBlockEnrichment.data(), + cumulativeEnrichmentBlockEnrichmentDofQuadPoints, + JxWxNCellConj.data(), + linAlgOpContext); linearAlgebra::blasLapack:: gemm( @@ -637,29 +638,30 @@ namespace dftefe dofsPerCell, linAlgOpContext); - JxWxNCellConj.resize( - dofsPerCellCFE * nQuadPointInCellEnrichmentBlockClassical, 0); + JxWxNCellConj.resize(dofsPerCellCFE * + nQuadPointInCellEnrichmentBlockClassical, + 0); m = 1, n = dofsPerCellCFE, - k = 
nQuadPointInCellEnrichmentBlockClassical; - - linearAlgebra::blasLapack::scaleStridedVarBatched( - 1, - linearAlgebra::blasLapack::Layout::ColMajor, - linearAlgebra::blasLapack::ScalarOp::Identity, - linearAlgebra::blasLapack::ScalarOp::Conj, - &stride, - &stride, - &stride, - &m, - &n, - &k, - cellJxWValuesEnrichmentBlockClassical.data(), - cumulativeEnrichmentBlockClassicalDofQuadPoints, - JxWxNCellConj.data(), - linAlgOpContext); + k = nQuadPointInCellEnrichmentBlockClassical; + + linearAlgebra::blasLapack::scaleStridedVarBatched< + ValueTypeOperator, + ValueTypeOperator, + memorySpace>(1, + linearAlgebra::blasLapack::Layout::ColMajor, + linearAlgebra::blasLapack::ScalarOp::Identity, + linearAlgebra::blasLapack::ScalarOp::Conj, + &stride, + &stride, + &stride, + &m, + &n, + &k, + cellJxWValuesEnrichmentBlockClassical.data(), + cumulativeEnrichmentBlockClassicalDofQuadPoints, + JxWxNCellConj.data(), + linAlgOpContext); linearAlgebra::blasLapack:: gemm( @@ -714,7 +716,8 @@ namespace dftefe // ciNciNcj = (ValueTypeOperator)0; // // Ni_pristine*Ni_classical at quadpoints // for (unsigned int qPoint = 0; - // qPoint < nQuadPointInCellEnrichmentBlockEnrichment; + // qPoint < + // nQuadPointInCellEnrichmentBlockEnrichment; // qPoint++) // { // NpiNcj += @@ -722,7 +725,8 @@ namespace dftefe // (iNode - dofsPerCellCFE) * // nQuadPointInCellEnrichmentBlockEnrichment + // qPoint) * - // *(cumulativeEnrichmentBlockEnrichmentDofQuadPoints + + // *(cumulativeEnrichmentBlockEnrichmentDofQuadPoints + // + // dofsPerCell * qPoint + jNode // /*nQuadPointInCellEnrichmentBlockEnrichment * // jNode + @@ -730,17 +734,20 @@ namespace dftefe // cellJxWValuesEnrichmentBlockEnrichment[qPoint]; // } - // // Ni_classical using Mc = d quadrature * interpolated + // // Ni_classical using Mc = d quadrature * + // interpolated // // ci's in Ni_classicalQuadrature of Mc = d // for (unsigned int qPoint = 0; - // qPoint < nQuadPointInCellEnrichmentBlockClassical; + // qPoint < + // 
nQuadPointInCellEnrichmentBlockClassical; // qPoint++) // { // ciNciNcj += // classicalComponentInQuadValuesEC // [numEnrichmentIdsInCell * qPoint + // (iNode - dofsPerCellCFE)] * - // *(cumulativeEnrichmentBlockClassicalDofQuadPoints + + // *(cumulativeEnrichmentBlockClassicalDofQuadPoints + // + // dofsPerCellCFE * qPoint + jNode // /*nQuadPointInCellEnrichmentBlockClassical * // jNode + qPoint*/) * @@ -748,9 +755,11 @@ namespace dftefe // } // *basisOverlapTmpIter += NpiNcj - ciNciNcj; - *basisOverlapTmpIter = - *(basisOverlapECBlockEnrich.data() + (iNode - dofsPerCellCFE) * dofsPerCell + jNode) - - *(basisOverlapECBlockClass.data() + (iNode - dofsPerCellCFE) * dofsPerCellCFE + jNode); + *basisOverlapTmpIter = + *(basisOverlapECBlockEnrich.data() + + (iNode - dofsPerCellCFE) * dofsPerCell + jNode) - + *(basisOverlapECBlockClass.data() + + (iNode - dofsPerCellCFE) * dofsPerCellCFE + jNode); } else if (iNode < dofsPerCellCFE && @@ -760,11 +769,13 @@ namespace dftefe // NcicjNcj = (ValueTypeOperator)0; // // Ni_pristine*Ni_classical at quadpoints // for (unsigned int qPoint = 0; - // qPoint < nQuadPointInCellEnrichmentBlockEnrichment; + // qPoint < + // nQuadPointInCellEnrichmentBlockEnrichment; // qPoint++) // { // NciNpj += - // *(cumulativeEnrichmentBlockEnrichmentDofQuadPoints + + // *(cumulativeEnrichmentBlockEnrichmentDofQuadPoints + // + // dofsPerCell * qPoint + iNode // /*nQuadPointInCellEnrichmentBlockEnrichment * // iNode + @@ -776,14 +787,17 @@ namespace dftefe // cellJxWValuesEnrichmentBlockEnrichment[qPoint]; // } - // // Ni_classical using Mc = d quadrature * interpolated + // // Ni_classical using Mc = d quadrature * + // interpolated // // ci's in Ni_classicalQuadrature of Mc = d // for (unsigned int qPoint = 0; - // qPoint < nQuadPointInCellEnrichmentBlockClassical; + // qPoint < + // nQuadPointInCellEnrichmentBlockClassical; // qPoint++) // { // NcicjNcj += - // *(cumulativeEnrichmentBlockClassicalDofQuadPoints + + // 
*(cumulativeEnrichmentBlockClassicalDofQuadPoints + // + // dofsPerCellCFE * qPoint + iNode // /*nQuadPointInCellEnrichmentBlockClassical * // iNode + qPoint*/) * @@ -794,9 +808,11 @@ namespace dftefe // } // *basisOverlapTmpIter += NciNpj - NcicjNcj; - *basisOverlapTmpIter = - *(basisOverlapECBlockEnrich.data() + (jNode - dofsPerCellCFE) * dofsPerCell + iNode) - - *(basisOverlapECBlockClass.data() + (jNode - dofsPerCellCFE) * dofsPerCellCFE + iNode); + *basisOverlapTmpIter = + *(basisOverlapECBlockEnrich.data() + + (jNode - dofsPerCellCFE) * dofsPerCell + iNode) - + *(basisOverlapECBlockClass.data() + + (jNode - dofsPerCellCFE) * dofsPerCellCFE + iNode); } else if (iNode >= dofsPerCellCFE && jNode >= dofsPerCellCFE) diff --git a/src/ksdft/Defaults.cpp b/src/ksdft/Defaults.cpp index e04c6fdf..5771619e 100644 --- a/src/ksdft/Defaults.cpp +++ b/src/ksdft/Defaults.cpp @@ -92,7 +92,7 @@ namespace dftefe */ const size_type KSDFTDefaults::MAX_WAVEFN_BATCH_SIZE = 400; const size_type KSDFTDefaults::MAX_KINENG_WAVEFN_BATCH_SIZE = 50; - const size_type KSDFTDefaults::CELL_BATCH_SIZE = 20; + const size_type KSDFTDefaults::CELL_BATCH_SIZE = 1; const size_type KSDFTDefaults::CELL_BATCH_SIZE_GRAD_EVAL = 1; } // end of namespace ksdft