Skip to content

Commit 063bea1

Browse files
Address PR comments
1 parent 5a9d33d commit 063bea1

File tree

8 files changed

+191
-143
lines changed

8 files changed

+191
-143
lines changed

backends/ref/ceed-ref-basis.c

+8-8
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@ static int CeedBasisApply_Ref(CeedBasis basis, CeedInt num_elem, CeedTransposeMo
2424
CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
2525
CeedCallBackend(CeedBasisGetNumNodes(basis, &num_nodes));
2626
CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis, &num_qpts));
27-
CeedFESpace fe_space;
28-
CeedCall(CeedBasisGetFESpace(basis, &fe_space));
2927
CeedTensorContract contract;
3028
CeedCallBackend(CeedBasisGetTensorContract(basis, &contract));
3129
const CeedInt add = (t_mode == CEED_TRANSPOSE);
@@ -195,12 +193,13 @@ static int CeedBasisApply_Ref(CeedBasis basis, CeedInt num_elem, CeedTransposeMo
195193
switch (eval_mode) {
196194
// Interpolate to/from quadrature points
197195
case CEED_EVAL_INTERP: {
198-
CeedInt qdim = (fe_space == CEED_FE_SPACE_H1) ? 1 : dim;
199-
CeedInt P = num_nodes, Q = qdim * num_qpts;
196+
CeedInt q_comp;
197+
CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, &q_comp));
198+
CeedInt P = num_nodes, Q = q_comp * num_qpts;
200199
const CeedScalar *interp;
201200
CeedCallBackend(CeedBasisGetInterp(basis, &interp));
202201
if (t_mode == CEED_TRANSPOSE) {
203-
P = qdim * num_qpts;
202+
P = q_comp * num_qpts;
204203
Q = num_nodes;
205204
}
206205
CeedCallBackend(CeedTensorContractApply(contract, num_comp, P, num_elem, Q, interp, t_mode, add, u, v));
@@ -250,12 +249,13 @@ static int CeedBasisApply_Ref(CeedBasis basis, CeedInt num_elem, CeedTransposeMo
250249
} break;
251250
// Evaluate the curl to/from the quadrature points
252251
case CEED_EVAL_CURL: {
253-
CeedInt cdim = (dim < 3) ? 1 : dim;
254-
CeedInt P = num_nodes, Q = cdim * num_qpts;
252+
CeedInt curl_comp;
253+
CeedCallBackend(CeedBasisGetNumCurlComponents(basis, &curl_comp));
254+
CeedInt P = num_nodes, Q = curl_comp * num_qpts;
255255
const CeedScalar *curl;
256256
CeedCallBackend(CeedBasisGetCurl(basis, &curl));
257257
if (t_mode == CEED_TRANSPOSE) {
258-
P = cdim * num_qpts;
258+
P = curl_comp * num_qpts;
259259
Q = num_nodes;
260260
}
261261
CeedCallBackend(CeedTensorContractApply(contract, num_comp, P, num_elem, Q, curl, t_mode, add, u, v));

include/ceed-impl.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -191,8 +191,8 @@ struct CeedBasis_private {
191191
CeedScalar *grad; /* row-major matrix of shape [dim * Q, P] matrix expressing derivatives of nodal basis functions at quadrature points */
192192
CeedScalar *grad_1d; /* row-major matrix of shape [Q1d, P1d] matrix expressing derivatives of nodal basis functions at quadrature points */
193193
CeedScalar *div; /* row-major matrix of shape [Q, P] expressing the divergence of basis functions at quadrature points for H(div) discretizations */
194-
CeedScalar *curl; /* row-major matrix of shape [cdim * Q, P], cdim = 1 if dim < 3 else dim, expressing the curl of basis functions at quadrature
195-
points for H(curl) discretizations */
194+
CeedScalar *curl; /* row-major matrix of shape [curl_dim * Q, P], curl_dim = 1 if dim < 3 else dim, expressing the curl of basis functions at
195+
quadrature points for H(curl) discretizations */
196196
void *data; /* place for the backend to store any data */
197197
};
198198

include/ceed/backend.h

+5
Original file line numberDiff line numberDiff line change
@@ -280,4 +280,9 @@ CEED_EXTERN int CeedOperatorSetData(CeedOperator op, void *data);
280280
CEED_EXTERN int CeedOperatorReference(CeedOperator op);
281281
CEED_EXTERN int CeedOperatorSetSetupDone(CeedOperator op);
282282

283+
CEED_EXTERN int CeedMatrixMatrixMultiply(Ceed ceed, const CeedScalar *mat_A, const CeedScalar *mat_B, CeedScalar *mat_C, CeedInt m, CeedInt n,
284+
CeedInt kk);
285+
CEED_EXTERN int CeedHouseholderApplyQ(CeedScalar *A, const CeedScalar *Q, const CeedScalar *tau, CeedTransposeMode t_mode, CeedInt m, CeedInt n,
286+
CeedInt k, CeedInt row, CeedInt col);
287+
283288
#endif

include/ceed/ceed.h

+2-4
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,8 @@ CEED_EXTERN int CeedBasisGetCeed(CeedBasis basis, Ceed *ceed);
372372
CEED_EXTERN int CeedBasisGetDimension(CeedBasis basis, CeedInt *dim);
373373
CEED_EXTERN int CeedBasisGetTopology(CeedBasis basis, CeedElemTopology *topo);
374374
CEED_EXTERN int CeedBasisGetNumComponents(CeedBasis basis, CeedInt *num_comp);
375+
CEED_EXTERN int CeedBasisGetNumQuadratureComponents(CeedBasis basis, CeedInt *q_comp);
376+
CEED_EXTERN int CeedBasisGetNumCurlComponents(CeedBasis basis, CeedInt *curl_comp);
375377
CEED_EXTERN int CeedBasisGetNumNodes(CeedBasis basis, CeedInt *P);
376378
CEED_EXTERN int CeedBasisGetNumNodes1D(CeedBasis basis, CeedInt *P_1d);
377379
CEED_EXTERN int CeedBasisGetNumQuadraturePoints(CeedBasis basis, CeedInt *Q);
@@ -391,10 +393,6 @@ CEED_EXTERN int CeedLobattoQuadrature(CeedInt Q, CeedScalar *q_ref_1d, CeedScala
391393
CEED_EXTERN int CeedQRFactorization(Ceed ceed, CeedScalar *mat, CeedScalar *tau, CeedInt m, CeedInt n);
392394
CEED_EXTERN int CeedSymmetricSchurDecomposition(Ceed ceed, CeedScalar *mat, CeedScalar *lambda, CeedInt n);
393395
CEED_EXTERN int CeedSimultaneousDiagonalization(Ceed ceed, CeedScalar *mat_A, CeedScalar *mat_B, CeedScalar *x, CeedScalar *lambda, CeedInt n);
394-
CEED_EXTERN int CeedHouseholderApplyQ(CeedScalar *A, const CeedScalar *Q, const CeedScalar *tau, CeedTransposeMode t_mode, CeedInt m, CeedInt n,
395-
CeedInt k, CeedInt row, CeedInt col);
396-
CEED_EXTERN int CeedMatrixMatrixMultiply(Ceed ceed, const CeedScalar *mat_A, const CeedScalar *mat_B, CeedScalar *mat_C, CeedInt m, CeedInt n,
397-
CeedInt kk);
398396

399397
/** Handle for the user provided CeedQFunction callback function
400398

0 commit comments

Comments
 (0)