Skip to content

Commit

Permalink
clean up
Browse files Browse the repository at this point in the history
Signed-off-by: Christian Glusa <[email protected]>
  • Loading branch information
cgcgcg committed Sep 4, 2024
1 parent 27b2765 commit 6afb4a7
Show file tree
Hide file tree
Showing 5 changed files with 204 additions and 182 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,32 +13,43 @@
#include "MueLu_DroppingCommon.hpp"
#include "Kokkos_Core.hpp"
#include "Kokkos_ArithTraits.hpp"
#include "Xpetra_Matrix.hpp"
#include "MueLu_Utilities.hpp"

namespace MueLu::ClassicalDropping {

template <class local_matrix_type, class diag_view_type>
template <class Scalar, class LocalOrdinal, class GlobalOrdinal, class Node>
class AbsDropFunctor {
private:
using matrix_type = Xpetra::Matrix<Scalar, LocalOrdinal, GlobalOrdinal, Node>;
using diag_vec_type = Xpetra::MultiVector<Scalar, LocalOrdinal, GlobalOrdinal, Node>;
using local_matrix_type = typename matrix_type::local_matrix_type;
using scalar_type = typename local_matrix_type::value_type;
using local_ordinal_type = typename local_matrix_type::ordinal_type;
using memory_space = typename local_matrix_type::memory_space;
using results_view = Kokkos::View<DecisionType*, memory_space>;
using diag_view_type = typename Kokkos::DualView<const scalar_type*, Kokkos::LayoutStride, typename Node::device_type, Kokkos::MemoryUnmanaged>::t_dev;

using results_view = Kokkos::View<DecisionType*, memory_space>;

using ATS = Kokkos::ArithTraits<scalar_type>;
using magnitudeType = typename ATS::magnitudeType;
using boundary_nodes_view = Kokkos::View<const bool*, memory_space>;

local_matrix_type A;
Teuchos::RCP<diag_vec_type> diagVec;
diag_view_type diag; // corresponds to overlapped diagonal
magnitudeType eps;
results_view results;

public:
AbsDropFunctor(local_matrix_type& A_, magnitudeType threshold, diag_view_type diag_, results_view& results_)
: A(A_)
, diag(diag_)
AbsDropFunctor(matrix_type& A_, magnitudeType threshold, results_view& results_)
: A(A_.getLocalMatrixDevice())
, eps(threshold)
, results(results_) {}
, results(results_) {
diagVec = Utilities<Scalar, LocalOrdinal, GlobalOrdinal, Node>::GetMatrixOverlappedDiagonal(A_);
auto lclDiag2d = diagVec->getDeviceLocalView(Xpetra::Access::ReadOnly);
diag = Kokkos::subview(lclDiag2d, Kokkos::ALL(), 0);
}

KOKKOS_FORCEINLINE_FUNCTION
bool operator()(const local_ordinal_type rlid) const {
Expand All @@ -58,29 +69,39 @@ class AbsDropFunctor {
}
};

template <class local_matrix_type, class diag_view_type>
template <class Scalar, class LocalOrdinal, class GlobalOrdinal, class Node>
class SignedClassicalRSDropFunctor {
private:
using matrix_type = Xpetra::Matrix<Scalar, LocalOrdinal, GlobalOrdinal, Node>;
using local_matrix_type = typename matrix_type::local_matrix_type;
using scalar_type = typename local_matrix_type::value_type;
using local_ordinal_type = typename local_matrix_type::ordinal_type;
using memory_space = typename local_matrix_type::memory_space;
using results_view = Kokkos::View<DecisionType*, memory_space>;

using results_view = Kokkos::View<DecisionType*, memory_space>;

using ATS = Kokkos::ArithTraits<scalar_type>;
using magnitudeType = typename ATS::magnitudeType;
using boundary_nodes_view = Kokkos::View<const bool*, memory_space>;

using diag_vec_type = Xpetra::MultiVector<magnitudeType, LocalOrdinal, GlobalOrdinal, Node>;
using diag_view_type = typename Kokkos::DualView<const magnitudeType*, Kokkos::LayoutStride, typename Node::device_type, Kokkos::MemoryUnmanaged>::t_dev;

local_matrix_type A;
Teuchos::RCP<diag_vec_type> diagVec;
diag_view_type diag; // corresponds to overlapped diagonal
magnitudeType eps;
results_view results;

public:
SignedClassicalRSDropFunctor(local_matrix_type& A_, magnitudeType threshold, diag_view_type diag_, results_view& results_)
: A(A_)
, diag(diag_)
SignedClassicalRSDropFunctor(matrix_type& A_, magnitudeType threshold, results_view& results_)
: A(A_.getLocalMatrixDevice())
, eps(threshold)
, results(results_) {}
, results(results_) {
diagVec = Utilities<Scalar, LocalOrdinal, GlobalOrdinal, Node>::GetMatrixMaxMinusOffDiagonal(A_);
auto lclDiag2d = diagVec->getDeviceLocalView(Xpetra::Access::ReadOnly);
diag = Kokkos::subview(lclDiag2d, Kokkos::ALL(), 0);
}

KOKKOS_FORCEINLINE_FUNCTION
bool operator()(const local_ordinal_type rlid) const {
Expand All @@ -97,30 +118,40 @@ class SignedClassicalRSDropFunctor {
}
};

template <class local_matrix_type, class diag_view_type>
template <class Scalar, class LocalOrdinal, class GlobalOrdinal, class Node>
class SignedClassicalSADropFunctor {
private:
using matrix_type = Xpetra::Matrix<Scalar, LocalOrdinal, GlobalOrdinal, Node>;
using diag_vec_type = Xpetra::MultiVector<Scalar, LocalOrdinal, GlobalOrdinal, Node>;
using local_matrix_type = typename matrix_type::local_matrix_type;
using scalar_type = typename local_matrix_type::value_type;
using local_ordinal_type = typename local_matrix_type::ordinal_type;
using memory_space = typename local_matrix_type::memory_space;
using results_view = Kokkos::View<DecisionType*, memory_space>;
using diag_view_type = typename Kokkos::DualView<const scalar_type*, Kokkos::LayoutStride, typename Node::device_type, Kokkos::MemoryUnmanaged>::t_dev;

using results_view = Kokkos::View<DecisionType*, memory_space>;

using ATS = Kokkos::ArithTraits<scalar_type>;
using magnitudeType = typename ATS::magnitudeType;
using mATS = Kokkos::ArithTraits<magnitudeType>;
using boundary_nodes_view = Kokkos::View<const bool*, memory_space>;

local_matrix_type A;
Teuchos::RCP<diag_vec_type> diagVec;
diag_view_type diag; // corresponds to overlapped diagonal
magnitudeType eps;
results_view results;

public:
SignedClassicalSADropFunctor(local_matrix_type& A_, magnitudeType threshold, diag_view_type diag_, results_view& results_)
: A(A_)
, diag(diag_)
SignedClassicalSADropFunctor(matrix_type& A_, magnitudeType threshold, results_view& results_)
: A(A_.getLocalMatrixDevice())
, eps(threshold)
, results(results_) {}
, results(results_) {
// Construct ghosted matrix diagonal
diagVec = Utilities<Scalar, LocalOrdinal, GlobalOrdinal, Node>::GetMatrixOverlappedDiagonal(A_);
auto lclDiag2d = diagVec->getDeviceLocalView(Xpetra::Access::ReadOnly);
diag = Kokkos::subview(lclDiag2d, Kokkos::ALL(), 0);
}

KOKKOS_FORCEINLINE_FUNCTION
bool operator()(const local_ordinal_type rlid) const {
Expand Down
Loading

0 comments on commit 6afb4a7

Please sign in to comment.