Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

reduce simulation memory usage for long plans #1337

Draft
wants to merge 6 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,8 @@ private <Return> void stepEffectModel(

// Based on the task's return status, update its execution state and schedule its resumption.
if (status instanceof TaskStatus.Completed<Return>) {
final var children = new LinkedList<>(this.taskChildren.getOrDefault(task, Collections.emptySet()));
final var children = new LinkedList<>(Optional.ofNullable(this.taskChildren.remove(task))
.orElseGet(Collections::emptySet));

this.tasks.put(task, progress.completedAt(currentTime, children));
this.scheduledJobs.schedule(JobId.forTask(task), SubInstant.Tasks.at(currentTime));
Expand All @@ -218,14 +219,19 @@ private <Return> void stepEffectModel(
this.tasks.put(task, progress.continueWith(s.continuation()));
this.scheduledJobs.schedule(JobId.forTask(task), SubInstant.Tasks.at(currentTime.plus(s.delay())));
} else if (status instanceof TaskStatus.CallingTask<Return> s) {
final var target = TaskId.generate();
SimulationEngine.this.tasks.put(target, new ExecutionState.InProgress<>(currentTime, s.child().create(this.executor)));
SimulationEngine.this.taskParent.put(target, task);
SimulationEngine.this.taskChildren.computeIfAbsent(task, $ -> new HashSet<>()).add(target);
frame.signal(JobId.forTask(target));

this.tasks.put(task, progress.continueWith(s.continuation()));
this.waitingTasks.subscribeQuery(task, Set.of(SignalId.forTask(target)));
if (s.tailCall()) {
this.tasks.put(task, new ExecutionState.InProgress<>(progress.startOffset, s.child().create(this.executor)));
this.scheduledJobs.schedule(JobId.forTask(task), SubInstant.Tasks.at(currentTime));
} else {
final var target = TaskId.generate();
this.tasks.put(target, new ExecutionState.InProgress<>(currentTime, s.child().create(this.executor)));
this.taskParent.put(target, task);
this.taskChildren.computeIfAbsent(task, $ -> new HashSet<>()).add(target);
frame.signal(JobId.forTask(target));

this.tasks.put(task, progress.continueWith(s.continuation()));
this.waitingTasks.subscribeQuery(task, Set.of(SignalId.forTask(target)));
}
} else if (status instanceof TaskStatus.AwaitingCondition<Return> s) {
final var condition = ConditionId.generate();
this.conditions.put(condition, s.condition());
Expand Down Expand Up @@ -438,19 +444,20 @@ public static SimulationResults computeResults(

final var name = id.id();
final var resource = state.resource();
final boolean allowRLE = resource.allowRunLengthCompression();

switch (resource.getType()) {
case "real" -> realProfiles.put(
name,
Pair.of(
resource.getOutputType().getSchema(),
serializeProfile(elapsedTime, state, SimulationEngine::extractRealDynamics)));
serializeProfile(elapsedTime, state, SimulationEngine::extractRealDynamics, allowRLE)));

case "discrete" -> discreteProfiles.put(
name,
Pair.of(
resource.getOutputType().getSchema(),
serializeProfile(elapsedTime, state, SimulationEngine::extractDiscreteDynamics)));
serializeProfile(elapsedTime, state, SimulationEngine::extractDiscreteDynamics, allowRLE)));

default ->
throw new IllegalArgumentException(
Expand Down Expand Up @@ -602,11 +609,24 @@ private interface Translator<Target> {
<Dynamics> Target apply(Resource<Dynamics> resource, Dynamics dynamics);
}

private static <Target>
void appendProfileSegment(ArrayList<ProfileSegment<Target>> profile, Duration duration, Target value,
boolean allowRunLengthCompression) {
final int s = profile.size();
final ProfileSegment lastSeg = s > 0 ? profile.get(s - 1) : null;
if (allowRunLengthCompression && lastSeg != null && value.equals(lastSeg.dynamics())) {
profile.set(s - 1, new ProfileSegment<>(lastSeg.extent().plus(duration), value));
} else {
profile.add(new ProfileSegment<>(duration, value));
}
}

private static <Target, Dynamics>
List<ProfileSegment<Target>> serializeProfile(
final Duration elapsedTime,
final ProfilingState<Dynamics> state,
final Translator<Target> translator
final Translator<Target> translator,
final boolean allowRunLengthCompression
) {
final var profile = new ArrayList<ProfileSegment<Target>>(state.profile().segments().size());

Expand All @@ -615,18 +635,21 @@ List<ProfileSegment<Target>> serializeProfile(
var segment = iter.next();
while (iter.hasNext()) {
final var nextSegment = iter.next();

profile.add(new ProfileSegment<>(
nextSegment.startOffset().minus(segment.startOffset()),
translator.apply(state.resource(), segment.dynamics())));
appendProfileSegment(profile,
nextSegment.startOffset().minus(segment.startOffset()),
translator.apply(state.resource(), segment.dynamics()),
allowRunLengthCompression);
segment = nextSegment;
}

profile.add(new ProfileSegment<>(
elapsedTime.minus(segment.startOffset()),
translator.apply(state.resource(), segment.dynamics())));
appendProfileSegment(profile,
elapsedTime.minus(segment.startOffset()),
translator.apply(state.resource(), segment.dynamics()),
allowRunLengthCompression);
}

profile.trimToSize();

return profile;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,12 @@ enum ContextType { Initializing, Reacting, Querying }
<Event> void emit(Event event, Topic<Event> topic);

void spawn(TaskFactory<?> task);

<Return> void call(TaskFactory<Return> task);

<Return> void tailCall(TaskFactory<Return> task);

void delay(Duration duration);

void waitUntil(Condition condition);
}
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ public <Return> void call(final TaskFactory<Return> task) {
throw new IllegalStateException("Cannot yield during initialization");
}

@Override
public <Return> void tailCall(final TaskFactory<Return> task) {
throw new IllegalStateException("Cannot yield during initialization");
}

@Override
public void delay(final Duration duration) {
throw new IllegalStateException("Cannot yield during initialization");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,18 @@ public static <T> void call(final TaskFactory<T> task) {
context.get().call(task);
}

public static void tailCall(final Runnable task) {
tailCall(threaded(task));
}

public static <T> void tailCall(final Supplier<T> task) {
tailCall(threaded(task));
}

public static <T> void tailCall(final TaskFactory<T> task) {
context.get().tailCall(task);
}

public static void defer(final Duration duration, final Runnable task) {
spawn(replaying(() -> { delay(duration); spawn(task); }));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ public <Return> void call(final TaskFactory<Return> task) {
throw new IllegalStateException("Cannot schedule tasks in a query-only context");
}

@Override
public <Return> void tailCall(final TaskFactory<Return> task) {
throw new IllegalStateException("Cannot schedule tasks in a query-only context");
}

@Override
public void delay(final Duration duration) {
throw new IllegalStateException("Cannot yield in a query-only context");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,34 @@
import java.util.function.UnaryOperator;

public final class Registrar {

/**
* Whether to allow run length compression when saving resource profiles at the end of simulation by default.
*
* This compression is lossless in terms of the overall shape of the profile, but it will combine adjacent profile
* segments with the same value, thus obscuring the fact that multiple resource samples (again, all returning the same
* value) were taken within the segment.
*/
private static final boolean ALLOW_RUN_LENGTH_COMPRESSION_BY_DEFAULT = false;

private final Initializer builder;
private boolean allowRunLengthCompression = ALLOW_RUN_LENGTH_COMPRESSION_BY_DEFAULT;

public Registrar(final Initializer builder) {
this.builder = Objects.requireNonNull(builder);
}

public void allowRunLengthCompression(final boolean allow) {
this.allowRunLengthCompression = allow;
}

public boolean isInitializationComplete() {
return (ModelActions.context.get().getContextType() != Context.ContextType.Initializing);
}

public <Value> void discrete(final String name, final Resource<Value> resource, final ValueMapper<Value> mapper) {
this.builder.resource(name, makeResource("discrete", resource, mapper.getValueSchema(), mapper::serializeValue));
this.builder.resource(name, makeResource("discrete", resource, mapper.getValueSchema(), mapper::serializeValue,
allowRunLengthCompression));
}

public void real(final String name, final Resource<RealDynamics> resource) {
Expand All @@ -46,14 +62,16 @@ private void real(final String name, final Resource<RealDynamics> resource, Unar
"rate", ValueSchema.REAL))),
dynamics -> SerializedValue.of(Map.of(
"initial", SerializedValue.of(dynamics.initial),
"rate", SerializedValue.of(dynamics.rate)))));
"rate", SerializedValue.of(dynamics.rate))),
allowRunLengthCompression));
}

private static <Value> gov.nasa.jpl.aerie.merlin.protocol.model.Resource<Value> makeResource(
final String type,
final Resource<Value> resource,
final ValueSchema valueSchema,
final Function<Value, SerializedValue> serializer
final Function<Value, SerializedValue> serializer,
final boolean allowRunLengthCompression
) {
return new gov.nasa.jpl.aerie.merlin.protocol.model.Resource<>() {
@Override
Expand Down Expand Up @@ -82,6 +100,11 @@ public Value getDynamics(final Querier querier) {
return resource.getDynamics();
}
}

@Override
public boolean allowRunLengthCompression() {
return allowRunLengthCompression;
}
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,14 @@ public <T> void call(final TaskFactory<T> task) {
});
}

@Override
public <T> void tailCall(final TaskFactory<T> task) {
this.memory.doOnce(() -> {
this.scheduler = null; // Relinquish the current scheduler before yielding, in case an exception is thrown.
this.scheduler = this.handle.tailCall(task);
});
}

@Override
public void delay(final Duration duration) {
this.memory.doOnce(() -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ public Scheduler call(final TaskFactory<?> child) {
return this.yield(TaskStatus.calling(child, ReplayingTask.this));
}

@Override
public Scheduler tailCall(final TaskFactory<?> child) {
return this.yield(TaskStatus.tailCalling(child, ReplayingTask.this));
}

@Override
public Scheduler await(final gov.nasa.jpl.aerie.merlin.protocol.model.Condition condition) {
return this.yield(TaskStatus.awaiting(condition, ReplayingTask.this));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,7 @@ public interface TaskHandle {

Scheduler call(TaskFactory<?> child);

Scheduler tailCall(TaskFactory<?> child);

Scheduler await(gov.nasa.jpl.aerie.merlin.protocol.model.Condition condition);
}
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,12 @@ public <T> void call(final TaskFactory<T> task) {
this.scheduler = this.handle.call(task);
}

@Override
public <T> void tailCall(final TaskFactory<T> task) {
this.scheduler = null; // Relinquish the current scheduler before yielding, in case an exception is thrown.
this.scheduler = this.handle.tailCall(task);
}

@Override
public void delay(final Duration duration) {
this.scheduler = null; // Relinquish the current scheduler before yielding, in case an exception is thrown.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,11 @@ public Scheduler call(final TaskFactory<?> child) {
return this.yield(TaskStatus.calling(child, ThreadedTask.this));
}

@Override
public Scheduler tailCall(final TaskFactory<?> child) {
return this.yield(TaskStatus.tailCalling(child, ThreadedTask.this));
}

@Override
public Scheduler await(final gov.nasa.jpl.aerie.merlin.protocol.model.Condition condition) {
return this.yield(TaskStatus.awaiting(condition, ThreadedTask.this));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,18 @@ public interface Resource<Dynamics> {
* this resource. In other words, it cannot depend on any hidden state. </p>
*/
Dynamics getDynamics(Querier querier);

/**
* After a simulation completes the entire evolution of the dynamics of this resource will typically be serialized as
* a resource profile consisting of some number of sequential segments.
*
* If run length compression is allowed for this resource then whenever there is a "run" of two or more such segments,
* one after another with the same dynamics, they will be compressed into a single segment during that serialization.
* This does not change the represented evolution of the dynamics of the resource, but it loses the information that a
* sample was taken at the start of each segment after the first in such a run. If a mission model prefers not to
* lose that information then it can return false here.
*/
default boolean allowRunLengthCompression() {
return false;
}
}
Loading