Skip to content

Commit 97c1c57

Browse files
Address PR comments for high-level APIs
Add Rust sibling functions and tests, add Python tests, fix Julia versioning for CI.
1 parent 11b88dd commit 97c1c57

File tree

7 files changed

+466
-59
lines changed

7 files changed

+466
-59
lines changed

julia/LibCEED.jl/src/LibCEED.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ include("Request.jl")
150150
include("Operator.jl")
151151
include("Misc.jl")
152152

153-
const minimum_libceed_version = v"0.10.0"
153+
const minimum_libceed_version = v"0.12.0"
154154

155155
function __init__()
156156
if !ceedversion_ge(minimum_libceed_version)

julia/LibCEED.jl/src/generated/libceed_bindings.jl

+55-53
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,57 @@ end
2424
CEED_ERROR_UNSUPPORTED = -3
2525
end
2626

27+
@cenum CeedMemType::UInt32 begin
28+
CEED_MEM_HOST = 0
29+
CEED_MEM_DEVICE = 1
30+
end
31+
32+
@cenum CeedCopyMode::UInt32 begin
33+
CEED_COPY_VALUES = 0
34+
CEED_USE_POINTER = 1
35+
CEED_OWN_POINTER = 2
36+
end
37+
38+
@cenum CeedNormType::UInt32 begin
39+
CEED_NORM_1 = 0
40+
CEED_NORM_2 = 1
41+
CEED_NORM_MAX = 2
42+
end
43+
44+
@cenum CeedTransposeMode::UInt32 begin
45+
CEED_NOTRANSPOSE = 0
46+
CEED_TRANSPOSE = 1
47+
end
48+
49+
@cenum CeedEvalMode::UInt32 begin
50+
CEED_EVAL_NONE = 0
51+
CEED_EVAL_INTERP = 1
52+
CEED_EVAL_GRAD = 2
53+
CEED_EVAL_DIV = 4
54+
CEED_EVAL_CURL = 8
55+
CEED_EVAL_WEIGHT = 16
56+
end
57+
58+
@cenum CeedQuadMode::UInt32 begin
59+
CEED_GAUSS = 0
60+
CEED_GAUSS_LOBATTO = 1
61+
end
62+
63+
@cenum CeedElemTopology::UInt32 begin
64+
CEED_TOPOLOGY_LINE = 65536
65+
CEED_TOPOLOGY_TRIANGLE = 131073
66+
CEED_TOPOLOGY_QUAD = 131074
67+
CEED_TOPOLOGY_TET = 196611
68+
CEED_TOPOLOGY_PYRAMID = 196612
69+
CEED_TOPOLOGY_PRISM = 196613
70+
CEED_TOPOLOGY_HEX = 196614
71+
end
72+
73+
@cenum CeedContextFieldType::UInt32 begin
74+
CEED_CONTEXT_FIELD_DOUBLE = 1
75+
CEED_CONTEXT_FIELD_INT32 = 2
76+
end
77+
2778
mutable struct Ceed_private end
2879

2980
const Ceed = Ptr{Ceed_private}
@@ -123,27 +174,10 @@ function CeedGetScalarType(scalar_type)
123174
ccall((:CeedGetScalarType, libceed), Cint, (Ptr{CeedScalarType},), scalar_type)
124175
end
125176

126-
@cenum CeedMemType::UInt32 begin
127-
CEED_MEM_HOST = 0
128-
CEED_MEM_DEVICE = 1
129-
end
130-
131177
function CeedGetPreferredMemType(ceed, type)
132178
ccall((:CeedGetPreferredMemType, libceed), Cint, (Ceed, Ptr{CeedMemType}), ceed, type)
133179
end
134180

135-
@cenum CeedCopyMode::UInt32 begin
136-
CEED_COPY_VALUES = 0
137-
CEED_USE_POINTER = 1
138-
CEED_OWN_POINTER = 2
139-
end
140-
141-
@cenum CeedNormType::UInt32 begin
142-
CEED_NORM_1 = 0
143-
CEED_NORM_2 = 1
144-
CEED_NORM_MAX = 2
145-
end
146-
147181
function CeedVectorCreate(ceed, len, vec)
148182
ccall((:CeedVectorCreate, libceed), Cint, (Ceed, CeedSize, Ptr{CeedVector}), ceed, len, vec)
149183
end
@@ -240,11 +274,6 @@ function CeedRequestWait(req)
240274
ccall((:CeedRequestWait, libceed), Cint, (Ptr{CeedRequest},), req)
241275
end
242276

243-
@cenum CeedTransposeMode::UInt32 begin
244-
CEED_NOTRANSPOSE = 0
245-
CEED_TRANSPOSE = 1
246-
end
247-
248277
function CeedElemRestrictionCreate(ceed, num_elem, elem_size, num_comp, comp_stride, l_size, mem_type, copy_mode, offsets, rstr)
249278
ccall((:CeedElemRestrictionCreate, libceed), Cint, (Ceed, CeedInt, CeedInt, CeedInt, CeedInt, CeedSize, CeedMemType, CeedCopyMode, Ptr{CeedInt}, Ptr{CeedElemRestriction}), ceed, num_elem, elem_size, num_comp, comp_stride, l_size, mem_type, copy_mode, offsets, rstr)
250279
end
@@ -325,30 +354,6 @@ function CeedElemRestrictionDestroy(rstr)
325354
ccall((:CeedElemRestrictionDestroy, libceed), Cint, (Ptr{CeedElemRestriction},), rstr)
326355
end
327356

328-
@cenum CeedEvalMode::UInt32 begin
329-
CEED_EVAL_NONE = 0
330-
CEED_EVAL_INTERP = 1
331-
CEED_EVAL_GRAD = 2
332-
CEED_EVAL_DIV = 4
333-
CEED_EVAL_CURL = 8
334-
CEED_EVAL_WEIGHT = 16
335-
end
336-
337-
@cenum CeedQuadMode::UInt32 begin
338-
CEED_GAUSS = 0
339-
CEED_GAUSS_LOBATTO = 1
340-
end
341-
342-
@cenum CeedElemTopology::UInt32 begin
343-
CEED_TOPOLOGY_LINE = 65536
344-
CEED_TOPOLOGY_TRIANGLE = 131073
345-
CEED_TOPOLOGY_QUAD = 131074
346-
CEED_TOPOLOGY_TET = 196611
347-
CEED_TOPOLOGY_PYRAMID = 196612
348-
CEED_TOPOLOGY_PRISM = 196613
349-
CEED_TOPOLOGY_HEX = 196614
350-
end
351-
352357
function CeedBasisCreateTensorH1Lagrange(ceed, dim, num_comp, P, Q, quad_mode, basis)
353358
ccall((:CeedBasisCreateTensorH1Lagrange, libceed), Cint, (Ceed, CeedInt, CeedInt, CeedInt, CeedInt, CeedQuadMode, Ptr{CeedBasis}), ceed, dim, num_comp, P, Q, quad_mode, basis)
354359
end
@@ -532,11 +537,6 @@ function CeedQFunctionFieldGetEvalMode(qf_field, eval_mode)
532537
ccall((:CeedQFunctionFieldGetEvalMode, libceed), Cint, (CeedQFunctionField, Ptr{CeedEvalMode}), qf_field, eval_mode)
533538
end
534539

535-
@cenum CeedContextFieldType::UInt32 begin
536-
CEED_CONTEXT_FIELD_DOUBLE = 1
537-
CEED_CONTEXT_FIELD_INT32 = 2
538-
end
539-
540540
# typedef int ( * CeedQFunctionContextDataDestroyUser ) ( void * data )
541541
const CeedQFunctionContextDataDestroyUser = Ptr{Cvoid}
542542

@@ -1364,7 +1364,9 @@ end
13641364

13651365
# Skipping MacroDefinition: CEED_EXTERN extern CEED_VISIBILITY ( default )
13661366

1367-
# Skipping MacroDefinition: CEED_QFUNCTION_HELPER CEED_QFUNCTION_ATTR static inline
1367+
# Skipping MacroDefinition: CEED_QFUNCTION_HELPER_ATTR CEED_QFUNCTION_ATTR __attribute__ ( ( always_inline ) )
1368+
1369+
# Skipping MacroDefinition: CEED_QFUNCTION_HELPER CEED_QFUNCTION_HELPER_ATTR static inline
13681370

13691371
const CeedInt_FMT = "d"
13701372

@@ -1374,7 +1376,7 @@ const CEED_VERSION_MINOR = 11
13741376

13751377
const CEED_VERSION_PATCH = 0
13761378

1377-
const CEED_VERSION_RELEASE = true
1379+
const CEED_VERSION_RELEASE = false
13781380

13791381
# Skipping MacroDefinition: CEED_INTERN extern CEED_VISIBILITY ( hidden )
13801382

python/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from .ceed import Ceed
99
from .ceed_vector import Vector
10-
from .ceed_basis import Basis, BasisTensorH1, BasisTensorH1Lagrange, BasisH1
10+
from .ceed_basis import Basis, BasisTensorH1, BasisTensorH1Lagrange, BasisH1, BasisHdiv, BasisHcurl
1111
from .ceed_elemrestriction import ElemRestriction, StridedElemRestriction, BlockedElemRestriction, BlockedStridedElemRestriction
1212
from .ceed_qfunction import QFunction, QFunctionByName, IdentityQFunction
1313
from .ceed_operator import Operator, CompositeOperator
@@ -18,7 +18,7 @@
1818
# ------------------------------------------------------------------------------
1919
__all__ = ["Ceed",
2020
"Vector",
21-
"Basis", "BasisTensorH1", "BasisTensorH1Lagrange", "BasisH1",
21+
"Basis", "BasisTensorH1", "BasisTensorH1Lagrange", "BasisH1", "BasisHdiv", "BasisHcurl",
2222
"ElemRestriction", "StridedElemRestriction", "BlockedElemRestriction", "BlockedStridedelemRestriction",
2323
"QFunction", "QFunctionByName", "IdentityQFunction",
2424
"Operator", "CompositeOperator",

python/ceed.py

+55-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import tempfile
1414
from abc import ABC
1515
from .ceed_vector import Vector
16-
from .ceed_basis import BasisTensorH1, BasisTensorH1Lagrange, BasisH1
16+
from .ceed_basis import BasisTensorH1, BasisTensorH1Lagrange, BasisH1, BasisHdiv, BasisHcurl
1717
from .ceed_elemrestriction import ElemRestriction, StridedElemRestriction, BlockedElemRestriction, BlockedStridedElemRestriction
1818
from .ceed_qfunction import QFunction, QFunctionByName, IdentityQFunction
1919
from .ceed_qfunctioncontext import QFunctionContext
@@ -356,7 +356,7 @@ def BasisH1(self, topo, ncomp, nnodes, nqpts, interp, grad, qref, qweight):
356356
*interp: Numpy array holding the row-major (nqpts * nnodes) matrix
357357
expressing the values of nodal basis functions at
358358
quadrature points
359-
*grad: Numpy array holding the row-major (nqpts * dim * nnodes)
359+
*grad: Numpy array holding the row-major (dim * nqpts * nnodes)
360360
matrix expressing the derivatives of nodal basis functions
361361
at quadrature points
362362
*qref: Numpy array of length (nqpts * dim) holding the locations of
@@ -370,6 +370,59 @@ def BasisH1(self, topo, ncomp, nnodes, nqpts, interp, grad, qref, qweight):
370370
return BasisH1(self, topo, ncomp, nnodes, nqpts,
371371
interp, grad, qref, qweight)
372372

373+
def BasisHdiv(self, topo, ncomp, nnodes, nqpts, interp, div, qref, qweight):
374+
"""Ceed Hdiv Basis: finite element non tensor-product basis for H(div)
375+
discretizations.
376+
377+
Args:
378+
topo: topology of the element, e.g. hypercube, simplex, etc
379+
ncomp: number of field components (1 for scalar fields)
380+
nnodes: total number of nodes
381+
nqpts: total number of quadrature points
382+
*interp: Numpy array holding the row-major (dim * nqpts * nnodes)
383+
matrix expressing the values of basis functions at
384+
quadrature points
385+
*div: Numpy array holding the row-major (nqpts * nnodes) matrix
386+
expressing the divergence of basis functions at
387+
quadrature points
388+
*qref: Numpy array of length (nqpts * dim) holding the locations of
389+
quadrature points on the reference element [-1, 1]
390+
*qweight: Numpy array of length nnodes holding the quadrature
391+
weights on the reference element
392+
393+
Returns:
394+
basis: Ceed Basis"""
395+
396+
return BasisHdiv(self, topo, ncomp, nnodes, nqpts,
397+
interp, div, qref, qweight)
398+
399+
def BasisHcurl(self, topo, ncomp, nnodes, nqpts,
400+
interp, curl, qref, qweight):
401+
"""Ceed Hcurl Basis: finite element non tensor-product basis for H(curl)
402+
discretizations.
403+
404+
Args:
405+
topo: topology of the element, e.g. hypercube, simplex, etc
406+
ncomp: number of field components (1 for scalar fields)
407+
nnodes: total number of nodes
408+
nqpts: total number of quadrature points
409+
*interp: Numpy array holding the row-major (dim * nqpts * nnodes)
410+
matrix expressing the values of basis functions at
411+
quadrature points
412+
*curl: Numpy array holding the row-major (curlcomp * nqpts * nnodes),
413+
curlcomp = 1 if dim < 3 else dim, matrix expressing the curl
414+
of basis functions at quadrature points
415+
*qref: Numpy array of length (nqpts * dim) holding the locations of
416+
quadrature points on the reference element [-1, 1]
417+
*qweight: Numpy array of length nnodes holding the quadrature
418+
weights on the reference element
419+
420+
Returns:
421+
basis: Ceed Basis"""
422+
423+
return BasisHcurl(self, topo, ncomp, nnodes, nqpts,
424+
interp, curl, qref, qweight)
425+
373426
# CeedQFunction
374427
def QFunction(self, vlength, f, source):
375428
"""Ceed QFunction: point-wise operation at quadrature points for

python/tests/buildmats.py

+77
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,80 @@ def buildmats(qref, qweight, mat_dtype="float64"):
4747
grad[(i + Q) * P + 5] = 2. * (1. * (x2 - 1. / 2.) + x2 * 1.)
4848

4949
return interp, grad
50+
51+
52+
def buildmatshdiv(qref, qweight, mat_dtype="float64"):
53+
P, Q, dim = 4, 4, 2
54+
interp = np.empty(dim * P * Q, dtype=mat_dtype)
55+
div = np.empty(P * Q, dtype=mat_dtype)
56+
57+
qref[0] = -1. / np.sqrt(3.)
58+
qref[1] = qref[0]
59+
qref[2] = qref[0]
60+
qref[3] = -qref[0]
61+
qref[4] = -qref[0]
62+
qref[5] = -qref[0]
63+
qref[6] = qref[0]
64+
qref[7] = qref[0]
65+
qweight[0] = 1.
66+
qweight[1] = 1.
67+
qweight[2] = 1.
68+
qweight[3] = 1.
69+
70+
# Loop over quadrature points
71+
for i in range(Q):
72+
x1 = qref[0 * Q + i]
73+
x2 = qref[1 * Q + i]
74+
# Interp
75+
interp[(i + 0) * P + 0] = 0.
76+
interp[(i + Q) * P + 0] = 1. - x2
77+
interp[(i + 0) * P + 1] = x1 - 1.
78+
interp[(i + Q) * P + 1] = 0.
79+
interp[(i + 0) * P + 2] = -x1
80+
interp[(i + Q) * P + 2] = 0.
81+
interp[(i + 0) * P + 3] = 0.
82+
interp[(i + Q) * P + 3] = x2
83+
# Div
84+
div[i * P + 0] = -1.
85+
div[i * P + 1] = 1.
86+
div[i * P + 2] = -1.
87+
div[i * P + 3] = 1.
88+
89+
return interp, div
90+
91+
92+
def buildmatshcurl(qref, qweight, mat_dtype="float64"):
93+
P, Q, dim = 3, 4, 2
94+
interp = np.empty(dim * P * Q, dtype=mat_dtype)
95+
curl = np.empty(P * Q, dtype=mat_dtype)
96+
97+
qref[0] = 0.2
98+
qref[1] = 0.6
99+
qref[2] = 1. / 3.
100+
qref[3] = 0.2
101+
qref[4] = 0.2
102+
qref[5] = 0.2
103+
qref[6] = 1. / 3.
104+
qref[7] = 0.6
105+
qweight[0] = 25. / 96.
106+
qweight[1] = 25. / 96.
107+
qweight[2] = -27. / 96.
108+
qweight[3] = 25. / 96.
109+
110+
# Loop over quadrature points
111+
for i in range(Q):
112+
x1 = qref[0 * Q + i]
113+
x2 = qref[1 * Q + i]
114+
# Interp
115+
interp[(i + 0) * P + 0] = -x2
116+
interp[(i + Q) * P + 0] = x1
117+
interp[(i + 0) * P + 1] = x2
118+
interp[(i + Q) * P + 1] = 1. - x1
119+
interp[(i + 0) * P + 2] = 1. - x2
120+
interp[(i + Q) * P + 2] = x1
121+
# Curl
122+
curl[i * P + 0] = 2.
123+
curl[i * P + 1] = -2.
124+
curl[i * P + 2] = -2.
125+
126+
return interp, curl

0 commit comments

Comments
 (0)