Skip to content

Commit 5ef5e8a

Browse files
authored
Merge pull request #53 from PRONTOLab/as
Allow scalar
2 parents d3bb253 + 0d0068a commit 5ef5e8a

File tree

1 file changed

+61
-9
lines changed

1 file changed

+61
-9
lines changed

src/data_free_ocean_climate_simulation.jl

+61-9
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,47 @@ using Dates
1212
using Printf
1313
using Profile
1414

15+
# https://github.com/CliMA/Oceananigans.jl/blob/da9959f3e5d8ee7cf2fb42b74ecc892874ec1687/src/AbstractOperations/conditional_operations.jl#L8
16+
Base.@nospecializeinfer function Reactant.traced_type_inner(
17+
@nospecialize(OA::Type{Oceananigans.AbstractOperations.ConditionalOperation{LX, LY, LZ, O, F, G, C, M, T}}),
18+
seen,
19+
mode::Reactant.TraceMode,
20+
@nospecialize(track_numbers::Type),
21+
@nospecialize(sharding),
22+
@nospecialize(runtime)
23+
) where {LX, LY, LZ, O, F, G, C, M, T}
24+
LX2 = Reactant.traced_type_inner(LX, seen, mode, track_numbers, sharding, runtime)
25+
LY2 = Reactant.traced_type_inner(LY, seen, mode, track_numbers, sharding, runtime)
26+
LZ2 = Reactant.traced_type_inner(LZ, seen, mode, track_numbers, sharding, runtime)
27+
O2 = Reactant.traced_type_inner(O, seen, mode, track_numbers, sharding, runtime)
28+
F2 = Reactant.traced_type_inner(F, seen, mode, track_numbers, sharding, runtime)
29+
G2 = Reactant.traced_type_inner(G, seen, mode, track_numbers, sharding, runtime)
30+
C2 = Reactant.traced_type_inner(C, seen, mode, track_numbers, sharding, runtime)
31+
M2 = Reactant.traced_type_inner(M, seen, mode, track_numbers, sharding, runtime)
32+
T2 = eltype(O2)
33+
return Oceananigans.AbstractOperations.ConditionalOperation{LX2, LY2, LZ2, O2, F2, G2, C2, M2, T2}
34+
end
35+
36+
# https://github.com/CliMA/Oceananigans.jl/blob/da9959f3e5d8ee7cf2fb42b74ecc892874ec1687/src/AbstractOperations/kernel_function_operation.jl#L3
37+
# struct KernelFunctionOperation{LX, LY, LZ, G, T, K, D} <: AbstractOperation{LX, LY, LZ, G, T}
38+
Base.@nospecializeinfer function Reactant.traced_type_inner(
39+
@nospecialize(OA::Type{Oceananigans.AbstractOperations.KernelFunctionOperation{LX, LY, LZ, G, T, K, D}}),
40+
seen,
41+
mode::Reactant.TraceMode,
42+
@nospecialize(track_numbers::Type),
43+
@nospecialize(sharding),
44+
@nospecialize(runtime)
45+
) where {LX, LY, LZ, G, T, K, D}
46+
LX2 = Reactant.traced_type_inner(LX, seen, mode, track_numbers, sharding, runtime)
47+
LY2 = Reactant.traced_type_inner(LY, seen, mode, track_numbers, sharding, runtime)
48+
LZ2 = Reactant.traced_type_inner(LZ, seen, mode, track_numbers, sharding, runtime)
49+
G2 = Reactant.traced_type_inner(G, seen, mode, track_numbers, sharding, runtime)
50+
K2 = Reactant.traced_type_inner(K, seen, mode, track_numbers, sharding, runtime)
51+
D2 = Reactant.traced_type_inner(D, seen, mode, track_numbers, sharding, runtime)
52+
T2 = eltype(G2)
53+
return Oceananigans.AbstractOperations.KernelFunctionOperation{LX2, LY2, LZ2, G2, T2, K2, D2}
54+
end
55+
1556
const PROFILE = Ref(false)
1657

1758
macro gbprofile(name::String, expr::Expr)
@@ -121,19 +162,31 @@ function gaussian_islands_tripolar_grid(arch::Architectures.AbstractArchitecture
121162
active_cells_map=false)
122163
end
123164

165+
function set_tracers(T, Ta, u, ua, shortwave, Qs)
166+
T .= Ta .+ 273.15
167+
u .= ua
168+
shortwave .= Qs
169+
nothing
170+
end
171+
124172
function data_free_ocean_climate_simulation_init(
125173
arch::Architectures.AbstractArchitecture=Architectures.ReactantState();
126174
# Horizontal resolution
127175
resolution::Real = 2, # 1/4 for quarter degree
128176
# Vertical resolution
129177
Nz::Int = 20, # eventually we want to increase this to between 100-600
178+
output::Bool = false
130179
)
131180

132181
grid = gaussian_islands_tripolar_grid(arch, resolution, Nz)
133182

134183
# See visualize_ocean_climate_simulation.jl for information about how to
135184
# visualize the results of this run.
136-
ocean = @gbprofile "ocean_simulation" ocean_simulation(grid)
185+
Δt=30seconds
186+
ocean = @gbprofile "ocean_simulation" ocean_simulation(grid;
187+
Δt,
188+
free_surface=ClimaOcean.OceanSimulations.default_free_surface(grid, fixed_Δt=Δt)
189+
)
137190

138191
@gbprofile "set_ocean_model" set!(ocean.model, T=Tᵢ, S=Sᵢ)
139192

@@ -156,10 +209,13 @@ function data_free_ocean_climate_simulation_init(
156209
set!(ua, zonal_wind)
157210
set!(Qs, sunlight)
158211

159-
parent(atmosphere.tracers.T) .= parent(Ta) .+ 273.15
160-
parent(atmosphere.velocities.u) .= parent(ua)
212+
if arch isa Architectures.ReactantState
213+
@jit set_tracers(parent(atmosphere.tracers.T), parent(Ta), parent(atmosphere.velocities.u), parent(ua), parent(atmosphere.downwelling_radiation.shortwave), parent(Qs))
214+
else
215+
set_tracers(parent(atmosphere.tracers.T), parent(Ta), parent(atmosphere.velocities.u), parent(ua), parent(atmosphere.downwelling_radiation.shortwave), parent(Qs))
216+
end
217+
161218
parent(atmosphere.tracers.q) .= 0
162-
parent(atmosphere.downwelling_radiation.shortwave) .= parent(Qs)
163219

164220
# Atmospheric model
165221
radiation = Radiation(arch)
@@ -174,10 +230,6 @@ function data_free_ocean_climate_simulation_init(
174230

175231
wall_time[] = time_ns()
176232

177-
if !(arch isa Architectures.ReactantState)
178-
add_callback!(simulation, progress, IterationInterval(10))
179-
end
180-
181233
# Output
182234
prefix = if arch isa Distributed
183235
"ocean_climate_simulation_rank$(arch.local_rank)"
@@ -187,7 +239,7 @@ function data_free_ocean_climate_simulation_init(
187239

188240
Nz = size(grid, 3)
189241
outputs = merge(ocean.model.velocities, ocean.model.tracers)
190-
if !(arch isa Architectures.ReactantState)
242+
if output && !(arch isa Architectures.ReactantState)
191243
surface_writer = JLD2OutputWriter(ocean.model, outputs,
192244
filename = prefix * "_surface.jld2",
193245
indices = (:, :, Nz),

0 commit comments

Comments
 (0)