Skip to content

Commit 63b1361

Browse files
Update Julia/Python/Rust/Fortran bindings
1 parent c74fa7d commit 63b1361

File tree

5 files changed

+401
-20
lines changed

5 files changed

+401
-20
lines changed

interface/ceed-fortran.c

+34
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,40 @@ CEED_EXTERN void fCeedBasisCreateH1(int *ceed, int *topo, int *num_comp, int *nn
449449
}
450450
}
451451

452+
#define fCeedBasisCreateHdiv FORTRAN_NAME(ceedbasiscreatehdiv, CEEDBASISCREATEHDIV)
453+
CEED_EXTERN void fCeedBasisCreateHdiv(int *ceed, int *topo, int *num_comp, int *nnodes, int *nqpts, const CeedScalar *interp, const CeedScalar *div,
454+
const CeedScalar *qref, const CeedScalar *qweight, int *basis, int *err) {
455+
if (CeedBasis_count == CeedBasis_count_max) {
456+
CeedBasis_count_max += CeedBasis_count_max / 2 + 1;
457+
CeedRealloc(CeedBasis_count_max, &CeedBasis_dict);
458+
}
459+
460+
*err = CeedBasisCreateHdiv(Ceed_dict[*ceed], (CeedElemTopology)*topo, *num_comp, *nnodes, *nqpts, interp, div, qref, qweight,
461+
&CeedBasis_dict[CeedBasis_count]);
462+
463+
if (*err == 0) {
464+
*basis = CeedBasis_count++;
465+
CeedBasis_n++;
466+
}
467+
}
468+
469+
#define fCeedBasisCreateHcurl FORTRAN_NAME(ceedbasiscreatehcurl, CEEDBASISCREATEHCURL)
470+
CEED_EXTERN void fCeedBasisCreateHcurl(int *ceed, int *topo, int *num_comp, int *nnodes, int *nqpts, const CeedScalar *interp, const CeedScalar *curl,
471+
const CeedScalar *qref, const CeedScalar *qweight, int *basis, int *err) {
472+
if (CeedBasis_count == CeedBasis_count_max) {
473+
CeedBasis_count_max += CeedBasis_count_max / 2 + 1;
474+
CeedRealloc(CeedBasis_count_max, &CeedBasis_dict);
475+
}
476+
477+
*err = CeedBasisCreateHcurl(Ceed_dict[*ceed], (CeedElemTopology)*topo, *num_comp, *nnodes, *nqpts, interp, curl, qref, qweight,
478+
&CeedBasis_dict[CeedBasis_count]);
479+
480+
if (*err == 0) {
481+
*basis = CeedBasis_count++;
482+
CeedBasis_n++;
483+
}
484+
}
485+
452486
#define fCeedBasisView FORTRAN_NAME(ceedbasisview, CEEDBASISVIEW)
453487
CEED_EXTERN void fCeedBasisView(int *basis, int *err) { *err = CeedBasisView(CeedBasis_dict[*basis], stdout); }
454488

julia/LibCEED.jl/src/Basis.jl

+156-2
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ created using one of:
1717
- [`create_tensor_h1_lagrange_basis`](@ref)
1818
- [`create_tensor_h1_basis`](@ref)
1919
- [`create_h1_basis`](@ref)
20+
- [`create_hdiv_basis`](@ref)
21+
- [`create_hcurl_basis`](@ref)
2022
"""
2123
mutable struct Basis <: AbstractBasis
2224
ref::RefValue{C.CeedBasis}
@@ -112,7 +114,7 @@ end
112114
@doc raw"""
113115
create_h1_basis(c::Ceed, topo::Topology, ncomp, nnodes, nqpts, interp, grad, qref, qweight)
114116
115-
Create a non tensor-product basis for H^1 discretizations
117+
Create a non tensor-product basis for $H^1$ discretizations
116118
117119
# Arguments:
118120
- `ceed`: A [`Ceed`](@ref) object where the [`Basis`](@ref) will be created.
@@ -166,6 +168,121 @@ function create_h1_basis(
166168
Basis(ref)
167169
end
168170

171+
@doc raw"""
172+
create_hdiv_basis(c::Ceed, topo::Topology, ncomp, nnodes, nqpts, interp, div, qref, qweight)
173+
174+
Create a non tensor-product basis for H(div) discretizations
175+
176+
# Arguments:
177+
- `ceed`: A [`Ceed`](@ref) object where the [`Basis`](@ref) will be created.
178+
- `topo`: [`Topology`](@ref) of element, e.g. hypercube, simplex, etc.
179+
- `ncomp`: Number of field components (1 for scalar fields).
180+
- `nnodes`: Total number of nodes.
181+
- `nqpts`: Total number of quadrature points.
182+
- `interp`: Matrix of size `(dim, nqpts, nnodes)` expressing the values of basis functions
183+
at quadrature points.
184+
- `div`: Array of size `(nqpts, nnodes)` expressing divergence of basis functions at
185+
quadrature points.
186+
- `qref`: Array of length `nqpts` holding the locations of quadrature points on the
187+
reference element $[-1, 1]$.
188+
- `qweight`: Array of length `nqpts` holding the quadrature weights on the reference
189+
element.
190+
"""
191+
function create_hdiv_basis(
192+
c::Ceed,
193+
topo::Topology,
194+
ncomp,
195+
nnodes,
196+
nqpts,
197+
interp::AbstractArray{CeedScalar},
198+
div::AbstractArray{CeedScalar},
199+
qref::AbstractArray{CeedScalar},
200+
qweight::AbstractArray{CeedScalar},
201+
)
202+
dim = getdimension(topo)
203+
@assert size(interp) == (dim, nqpts, nnodes)
204+
@assert size(div) == (nqpts, nnodes)
205+
@assert length(qref) == nqpts
206+
@assert length(qweight) == nqpts
207+
208+
# Convert from Julia matrices and tensors (column-major) to row-major format
209+
interp_rowmajor = permutedims(interp, [3, 2, 1])
210+
div_rowmajor = collect(div')
211+
212+
ref = Ref{C.CeedBasis}()
213+
C.CeedBasisCreateHdiv(
214+
c[],
215+
topo,
216+
ncomp,
217+
nnodes,
218+
nqpts,
219+
interp_rowmajor,
220+
div_rowmajor,
221+
qref,
222+
qweight,
223+
ref,
224+
)
225+
Basis(ref)
226+
end
227+
228+
@doc raw"""
229+
create_hdiv_basis(c::Ceed, topo::Topology, ncomp, nnodes, nqpts, interp, curl, qref, qweight)
230+
231+
Create a non tensor-product basis for H(div) discretizations
232+
233+
# Arguments:
234+
- `ceed`: A [`Ceed`](@ref) object where the [`Basis`](@ref) will be created.
235+
- `topo`: [`Topology`](@ref) of element, e.g. hypercube, simplex, etc.
236+
- `ncomp`: Number of field components (1 for scalar fields).
237+
- `nnodes`: Total number of nodes.
238+
- `nqpts`: Total number of quadrature points.
239+
- `interp`: Matrix of size `(dim, nqpts, nnodes)` expressing the values of basis functions
240+
at quadrature points.
241+
- `curl`: Matrix of size `(curlcomp, nqpts, nnodes)`, `curlcomp = 1 if dim < 3 else dim`)
242+
matrix expressing curl of basis functions at quadrature points.
243+
- `qref`: Array of length `nqpts` holding the locations of quadrature points on the
244+
reference element $[-1, 1]$.
245+
- `qweight`: Array of length `nqpts` holding the quadrature weights on the reference
246+
element.
247+
"""
248+
function create_hdiv_basis(
249+
c::Ceed,
250+
topo::Topology,
251+
ncomp,
252+
nnodes,
253+
nqpts,
254+
interp::AbstractArray{CeedScalar},
255+
curl::AbstractArray{CeedScalar},
256+
qref::AbstractArray{CeedScalar},
257+
qweight::AbstractArray{CeedScalar},
258+
)
259+
dim = getdimension(topo)
260+
curlcomp = dim < 3 ? 1 : dim
261+
@assert size(interp) == (dim, nqpts, nnodes)
262+
@assert size(curl) == (curlcomp, nqpts, nnodes)
263+
@assert length(qref) == nqpts
264+
@assert length(qweight) == nqpts
265+
266+
# Convert from Julia matrices and tensors (column-major) to row-major format
267+
interp_rowmajor = permutedims(interp, [3, 2, 1])
268+
curl_rowmajor = permutedims(curl, [3, 2, 1])
269+
270+
ref = Ref{C.CeedBasis}()
271+
C.CeedBasisCreateHcurl(
272+
c[],
273+
topo,
274+
ncomp,
275+
nnodes,
276+
nqpts,
277+
interp_rowmajor,
278+
curl_rowmajor,
279+
qref,
280+
qweight,
281+
ref,
282+
)
283+
Basis(ref)
284+
end
285+
169286
"""
170287
apply!(b::Basis, nelem, tmode::TransposeMode, emode::EvalMode, u::AbstractCeedVector, v::AbstractCeedVector)
171288
@@ -353,7 +470,13 @@ function getinterp(b::Basis)
353470
C.CeedBasisGetInterp(b[], ref)
354471
q = getnumqpts(b)
355472
p = getnumnodes(b)
356-
collect(unsafe_wrap(Array, ref[], (p, q))')
473+
qcomp = Ref{CeedInt}()
474+
C.CeedBasisGetNumQuadratureComponents(b[], C.CEED_EVAL_INTERP, qcomp)
475+
if qcomp == 1
476+
collect(unsafe_wrap(Array, ref[], (p, q))')
477+
else
478+
permutedims(unsafe_wrap(Array, ref[], (p, q, qcomp)), [3, 2, 1])
479+
end
357480
end
358481

359482
"""
@@ -399,3 +522,34 @@ function getgrad1d(b::Basis)
399522
p = getnumnodes1d(b)
400523
collect(unsafe_wrap(Array, ref[], (p, q))')
401524
end
525+
526+
"""
527+
getdiv(b::Basis)
528+
529+
Get the divergence matrix of the given [`Basis`](@ref). Returns a tensor of size
530+
`(getnumqpts(b), getnumnodes(b))`.
531+
"""
532+
function getdiv(b::Basis)
533+
ref = Ref{Ptr{CeedScalar}}()
534+
C.CeedBasisGetDiv(b[], ref)
535+
q = getnumqpts(b)
536+
p = getnumnodes(b)
537+
collect(unsafe_wrap(Array, ref[], (p, q))')
538+
end
539+
540+
"""
541+
getcurl(b::Basis)
542+
543+
Get the curl matrix of the given [`Basis`](@ref). Returns a tensor of size
544+
`(curlcomp, getnumqpts(b), getnumnodes(b))`, `curlcomp = 1 if getdimension(b) < 3 else
545+
getdimension(b)`.
546+
"""
547+
function getcurl(b::Basis)
548+
ref = Ref{Ptr{CeedScalar}}()
549+
C.CeedBasisGetCurl(b[], ref)
550+
q = getnumqpts(b)
551+
p = getnumnodes(b)
552+
qcomp = Ref{CeedInt}()
553+
C.CeedBasisGetNumQuadratureComponents(b[], C.CEED_EVAL_CURL, qcomp)
554+
permutedims(unsafe_wrap(Array, ref[], (p, q, qcomp)), [3, 2, 1])
555+
end

julia/LibCEED.jl/src/generated/libceed_bindings.jl

+51-18
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,10 @@ function CeedVectorReferenceCopy(vec, vec_copy)
152152
ccall((:CeedVectorReferenceCopy, libceed), Cint, (CeedVector, Ptr{CeedVector}), vec, vec_copy)
153153
end
154154

155+
function CeedVectorCopy(vec, vec_copy)
156+
ccall((:CeedVectorCopy, libceed), Cint, (CeedVector, CeedVector), vec, vec_copy)
157+
end
158+
155159
function CeedVectorSetArray(vec, mem_type, copy_mode, array)
156160
ccall((:CeedVectorSetArray, libceed), Cint, (CeedVector, CeedMemType, CeedCopyMode, Ptr{CeedScalar}), vec, mem_type, copy_mode, array)
157161
end
@@ -200,6 +204,10 @@ function CeedVectorAXPY(y, alpha, x)
200204
ccall((:CeedVectorAXPY, libceed), Cint, (CeedVector, CeedScalar, CeedVector), y, alpha, x)
201205
end
202206

207+
function CeedVectorAXPBY(y, alpha, beta, x)
208+
ccall((:CeedVectorAXPBY, libceed), Cint, (CeedVector, CeedScalar, CeedScalar, CeedVector), y, alpha, beta, x)
209+
end
210+
203211
function CeedVectorPointwiseMult(w, x, y)
204212
ccall((:CeedVectorPointwiseMult, libceed), Cint, (CeedVector, CeedVector, CeedVector), w, x, y)
205213
end
@@ -208,6 +216,10 @@ function CeedVectorReciprocal(vec)
208216
ccall((:CeedVectorReciprocal, libceed), Cint, (CeedVector,), vec)
209217
end
210218

219+
function CeedVectorViewRange(vec, start, stop, step, fp_fmt, stream)
220+
ccall((:CeedVectorViewRange, libceed), Cint, (CeedVector, CeedSize, CeedSize, CeedInt, Ptr{Cchar}, Ptr{Libc.FILE}), vec, start, stop, step, fp_fmt, stream)
221+
end
222+
211223
function CeedVectorView(vec, fp_fmt, stream)
212224
ccall((:CeedVectorView, libceed), Cint, (CeedVector, Ptr{Cchar}, Ptr{Libc.FILE}), vec, fp_fmt, stream)
213225
end
@@ -353,6 +365,10 @@ function CeedBasisCreateHdiv(ceed, topo, num_comp, num_nodes, nqpts, interp, div
353365
ccall((:CeedBasisCreateHdiv, libceed), Cint, (Ceed, CeedElemTopology, CeedInt, CeedInt, CeedInt, Ptr{CeedScalar}, Ptr{CeedScalar}, Ptr{CeedScalar}, Ptr{CeedScalar}, Ptr{CeedBasis}), ceed, topo, num_comp, num_nodes, nqpts, interp, div, q_ref, q_weights, basis)
354366
end
355367

368+
function CeedBasisCreateHcurl(ceed, topo, num_comp, num_nodes, nqpts, interp, curl, q_ref, q_weights, basis)
369+
ccall((:CeedBasisCreateHcurl, libceed), Cint, (Ceed, CeedElemTopology, CeedInt, CeedInt, CeedInt, Ptr{CeedScalar}, Ptr{CeedScalar}, Ptr{CeedScalar}, Ptr{CeedScalar}, Ptr{CeedBasis}), ceed, topo, num_comp, num_nodes, nqpts, interp, curl, q_ref, q_weights, basis)
370+
end
371+
356372
function CeedBasisCreateProjection(basis_from, basis_to, basis_project)
357373
ccall((:CeedBasisCreateProjection, libceed), Cint, (CeedBasis, CeedBasis, Ptr{CeedBasis}), basis_from, basis_to, basis_project)
358374
end
@@ -381,10 +397,6 @@ function CeedBasisGetTopology(basis, topo)
381397
ccall((:CeedBasisGetTopology, libceed), Cint, (CeedBasis, Ptr{CeedElemTopology}), basis, topo)
382398
end
383399

384-
function CeedBasisGetNumQuadratureComponents(basis, Q_comp)
385-
ccall((:CeedBasisGetNumQuadratureComponents, libceed), Cint, (CeedBasis, Ptr{CeedInt}), basis, Q_comp)
386-
end
387-
388400
function CeedBasisGetNumComponents(basis, num_comp)
389401
ccall((:CeedBasisGetNumComponents, libceed), Cint, (CeedBasis, Ptr{CeedInt}), basis, num_comp)
390402
end
@@ -433,6 +445,10 @@ function CeedBasisGetDiv(basis, div)
433445
ccall((:CeedBasisGetDiv, libceed), Cint, (CeedBasis, Ptr{Ptr{CeedScalar}}), basis, div)
434446
end
435447

448+
function CeedBasisGetCurl(basis, curl)
449+
ccall((:CeedBasisGetCurl, libceed), Cint, (CeedBasis, Ptr{Ptr{CeedScalar}}), basis, curl)
450+
end
451+
436452
function CeedBasisDestroy(basis)
437453
ccall((:CeedBasisDestroy, libceed), Cint, (Ptr{CeedBasis},), basis)
438454
end
@@ -716,32 +732,36 @@ function CeedOperatorGetFlopsEstimate(op, flops)
716732
ccall((:CeedOperatorGetFlopsEstimate, libceed), Cint, (CeedOperator, Ptr{CeedSize}), op, flops)
717733
end
718734

719-
function CeedOperatorContextGetFieldLabel(op, field_name, field_label)
720-
ccall((:CeedOperatorContextGetFieldLabel, libceed), Cint, (CeedOperator, Ptr{Cchar}, Ptr{CeedContextFieldLabel}), op, field_name, field_label)
735+
function CeedOperatorGetContext(op, ctx)
736+
ccall((:CeedOperatorGetContext, libceed), Cint, (CeedOperator, Ptr{CeedQFunctionContext}), op, ctx)
721737
end
722738

723-
function CeedOperatorContextSetDouble(op, field_label, values)
724-
ccall((:CeedOperatorContextSetDouble, libceed), Cint, (CeedOperator, CeedContextFieldLabel, Ptr{Cdouble}), op, field_label, values)
739+
function CeedOperatorGetContextFieldLabel(op, field_name, field_label)
740+
ccall((:CeedOperatorGetContextFieldLabel, libceed), Cint, (CeedOperator, Ptr{Cchar}, Ptr{CeedContextFieldLabel}), op, field_name, field_label)
725741
end
726742

727-
function CeedOperatorContextGetDoubleRead(op, field_label, num_values, values)
728-
ccall((:CeedOperatorContextGetDoubleRead, libceed), Cint, (CeedOperator, CeedContextFieldLabel, Ptr{Csize_t}, Ptr{Ptr{Cdouble}}), op, field_label, num_values, values)
743+
function CeedOperatorSetContextDouble(op, field_label, values)
744+
ccall((:CeedOperatorSetContextDouble, libceed), Cint, (CeedOperator, CeedContextFieldLabel, Ptr{Cdouble}), op, field_label, values)
729745
end
730746

731-
function CeedOperatorContextRestoreDoubleRead(op, field_label, values)
732-
ccall((:CeedOperatorContextRestoreDoubleRead, libceed), Cint, (CeedOperator, CeedContextFieldLabel, Ptr{Ptr{Cdouble}}), op, field_label, values)
747+
function CeedOperatorGetContextDoubleRead(op, field_label, num_values, values)
748+
ccall((:CeedOperatorGetContextDoubleRead, libceed), Cint, (CeedOperator, CeedContextFieldLabel, Ptr{Csize_t}, Ptr{Ptr{Cdouble}}), op, field_label, num_values, values)
733749
end
734750

735-
function CeedOperatorContextSetInt32(op, field_label, values)
736-
ccall((:CeedOperatorContextSetInt32, libceed), Cint, (CeedOperator, CeedContextFieldLabel, Ptr{Cint}), op, field_label, values)
751+
function CeedOperatorRestoreContextDoubleRead(op, field_label, values)
752+
ccall((:CeedOperatorRestoreContextDoubleRead, libceed), Cint, (CeedOperator, CeedContextFieldLabel, Ptr{Ptr{Cdouble}}), op, field_label, values)
737753
end
738754

739-
function CeedOperatorContextGetInt32Read(op, field_label, num_values, values)
740-
ccall((:CeedOperatorContextGetInt32Read, libceed), Cint, (CeedOperator, CeedContextFieldLabel, Ptr{Csize_t}, Ptr{Ptr{Cint}}), op, field_label, num_values, values)
755+
function CeedOperatorSetContextInt32(op, field_label, values)
756+
ccall((:CeedOperatorSetContextInt32, libceed), Cint, (CeedOperator, CeedContextFieldLabel, Ptr{Cint}), op, field_label, values)
741757
end
742758

743-
function CeedOperatorContextRestoreInt32Read(op, field_label, values)
744-
ccall((:CeedOperatorContextRestoreInt32Read, libceed), Cint, (CeedOperator, CeedContextFieldLabel, Ptr{Ptr{Cint}}), op, field_label, values)
759+
function CeedOperatorGetContextInt32Read(op, field_label, num_values, values)
760+
ccall((:CeedOperatorGetContextInt32Read, libceed), Cint, (CeedOperator, CeedContextFieldLabel, Ptr{Csize_t}, Ptr{Ptr{Cint}}), op, field_label, num_values, values)
761+
end
762+
763+
function CeedOperatorRestoreContextInt32Read(op, field_label, values)
764+
ccall((:CeedOperatorRestoreContextInt32Read, libceed), Cint, (CeedOperator, CeedContextFieldLabel, Ptr{Ptr{Cint}}), op, field_label, values)
745765
end
746766

747767
function CeedOperatorApply(op, in, out, request)
@@ -987,6 +1007,7 @@ end
9871007
@cenum CeedFESpace::UInt32 begin
9881008
CEED_FE_SPACE_H1 = 1
9891009
CEED_FE_SPACE_HDIV = 2
1010+
CEED_FE_SPACE_HCURL = 3
9901011
end
9911012

9921013
function CeedBasisGetCollocatedGrad(basis, colo_grad_1d)
@@ -1009,10 +1030,18 @@ function CeedBasisReference(basis)
10091030
ccall((:CeedBasisReference, libceed), Cint, (CeedBasis,), basis)
10101031
end
10111032

1033+
function CeedBasisGetNumQuadratureComponents(basis, eval_mode, q_comp)
1034+
ccall((:CeedBasisGetNumQuadratureComponents, libceed), Cint, (CeedBasis, CeedEvalMode, Ptr{CeedInt}), basis, eval_mode, q_comp)
1035+
end
1036+
10121037
function CeedBasisGetFlopsEstimate(basis, t_mode, eval_mode, flops)
10131038
ccall((:CeedBasisGetFlopsEstimate, libceed), Cint, (CeedBasis, CeedTransposeMode, CeedEvalMode, Ptr{CeedSize}), basis, t_mode, eval_mode, flops)
10141039
end
10151040

1041+
function CeedBasisGetFESpace(basis, fe_space)
1042+
ccall((:CeedBasisGetFESpace, libceed), Cint, (CeedBasis, Ptr{CeedFESpace}), basis, fe_space)
1043+
end
1044+
10161045
function CeedBasisGetTopologyDimension(topo, dim)
10171046
ccall((:CeedBasisGetTopologyDimension, libceed), Cint, (CeedElemTopology, Ptr{CeedInt}), topo, dim)
10181047
end
@@ -1033,6 +1062,10 @@ function CeedTensorContractApply(contract, A, B, C, J, t, t_mode, Add, u, v)
10331062
ccall((:CeedTensorContractApply, libceed), Cint, (CeedTensorContract, CeedInt, CeedInt, CeedInt, CeedInt, Ptr{CeedScalar}, CeedTransposeMode, CeedInt, Ptr{CeedScalar}, Ptr{CeedScalar}), contract, A, B, C, J, t, t_mode, Add, u, v)
10341063
end
10351064

1065+
function CeedTensorContractStridedApply(contract, A, B, C, D, J, t, t_mode, add, u, v)
1066+
ccall((:CeedTensorContractStridedApply, libceed), Cint, (CeedTensorContract, CeedInt, CeedInt, CeedInt, CeedInt, CeedInt, Ptr{CeedScalar}, CeedTransposeMode, CeedInt, Ptr{CeedScalar}, Ptr{CeedScalar}), contract, A, B, C, D, J, t, t_mode, add, u, v)
1067+
end
1068+
10361069
function CeedTensorContractGetCeed(contract, ceed)
10371070
ccall((:CeedTensorContractGetCeed, libceed), Cint, (CeedTensorContract, Ptr{Ceed}), contract, ceed)
10381071
end

0 commit comments

Comments
 (0)