Skip to content

Commit

Permalink
refactor(core): move to reactor for handling execution inputs (#5383)
Browse files Browse the repository at this point in the history
related-to: #5383
  • Loading branch information
fhussonnois committed Oct 15, 2024
1 parent acd2ce9 commit 05d1eea
Show file tree
Hide file tree
Showing 5 changed files with 189 additions and 154 deletions.
89 changes: 46 additions & 43 deletions core/src/main/java/io/kestra/core/runners/FlowInputOutput.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import com.google.common.collect.ImmutableMap;
import io.kestra.core.encryption.EncryptionService;
import io.kestra.core.exceptions.IllegalVariableEvaluationException;
import io.kestra.core.exceptions.KestraRuntimeException;
import io.kestra.core.models.executions.Execution;
import io.kestra.core.models.flows.Data;
import io.kestra.core.models.flows.DependsOn;
Expand Down Expand Up @@ -34,7 +35,7 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
import reactor.core.scheduler.Schedulers;
import reactor.core.publisher.Mono;

import java.io.File;
import java.io.FileOutputStream;
Expand Down Expand Up @@ -93,13 +94,12 @@ public FlowInputOutput(
* @param data The Execution's inputs data.
* @return The list of {@link InputAndValue}.
*/
public List<InputAndValue> validateExecutionInputs(final List<Input<?>> inputs,
public Mono<List<InputAndValue>> validateExecutionInputs(final List<Input<?>> inputs,
final Execution execution,
final Publisher<CompletedPart> data) throws IOException {
if (ListUtils.isEmpty(inputs)) return Collections.emptyList();
final Publisher<CompletedPart> data) {
if (ListUtils.isEmpty(inputs)) return Mono.just(Collections.emptyList());

Map<String, ?> dataByInputId = readData(inputs, execution, data, false);
return this.resolveInputs(inputs, execution, dataByInputId);
return readData(inputs, execution, data, false).map(inputData -> resolveInputs(inputs, execution, inputData));
}

/**
Expand All @@ -110,9 +110,9 @@ public List<InputAndValue> validateExecutionInputs(final List<Input<?>> inputs,
* @param data The Execution's inputs data.
* @return The Map of typed inputs.
*/
public Map<String, Object> readExecutionInputs(final Flow flow,
public Mono<Map<String, Object>> readExecutionInputs(final Flow flow,
final Execution execution,
final Publisher<CompletedPart> data) throws IOException {
final Publisher<CompletedPart> data) {
return this.readExecutionInputs(flow.getInputs(), execution, data);
}

Expand All @@ -124,48 +124,51 @@ public Map<String, Object> readExecutionInputs(final Flow flow,
* @param data The Execution's inputs data.
* @return The Map of typed inputs.
*/
public Map<String, Object> readExecutionInputs(final List<Input<?>> inputs,
final Execution execution,
final Publisher<CompletedPart> data) throws IOException {
return this.readExecutionInputs(inputs, execution, readData(inputs, execution, data, true));
public Mono<Map<String, Object>> readExecutionInputs(final List<Input<?>> inputs,
final Execution execution,
final Publisher<CompletedPart> data) {
return readData(inputs, execution, data, true).map(inputData -> this.readExecutionInputs(inputs, execution, inputData));
}

private Map<String, ?> readData(List<Input<?>> inputs, Execution execution, Publisher<CompletedPart> data, boolean uploadFiles) throws IOException {
private Mono<Map<String, Object>> readData(List<Input<?>> inputs, Execution execution, Publisher<CompletedPart> data, boolean uploadFiles) {
return Flux.from(data)
.subscribeOn(Schedulers.boundedElastic())
.map(throwFunction(input -> {
if (input instanceof CompletedFileUpload fileUpload) {
if (!uploadFiles) {
// only build the storage URI
final String fileExtension = FileInput.findFileInputExtension(inputs, fileUpload.getFilename());
URI from = URI.create("kestra://" + StorageContext
.forInput(execution, fileUpload.getFilename(), fileUpload.getFilename() + fileExtension)
.getContextStorageURI()
);
return new AbstractMap.SimpleEntry<>(fileUpload.getFilename(), from.toString());
}
final String fileExtension = FileInput.findFileInputExtension(inputs, fileUpload.getFilename());
File tempFile = File.createTempFile(fileUpload.getFilename() + "_", fileExtension);
try (var inputStream = fileUpload.getInputStream();
var outputStream = new FileOutputStream(tempFile)) {
long transferredBytes = inputStream.transferTo(outputStream);
if (transferredBytes == 0) {
throw new RuntimeException("Can't upload file: " + fileUpload.getFilename());
.<AbstractMap.SimpleEntry<String, String>>handle((input, sink) -> {
try {
if (input instanceof CompletedFileUpload fileUpload) {
if (!uploadFiles) {
final String fileExtension = FileInput.findFileInputExtension(inputs, fileUpload.getFilename());
URI from = URI.create("kestra://" + StorageContext
.forInput(execution, fileUpload.getFilename(), fileUpload.getFilename() + fileExtension)
.getContextStorageURI()
);
sink.next(new AbstractMap.SimpleEntry<>(fileUpload.getFilename(), from.toString()));
return;
}
final String fileExtension = FileInput.findFileInputExtension(inputs, fileUpload.getFilename());

URI from = storageInterface.from(execution, fileUpload.getFilename(), tempFile);
return new AbstractMap.SimpleEntry<>(fileUpload.getFilename(), from.toString());
} finally {
if (!tempFile.delete()) {
tempFile.deleteOnExit();
}
File tempFile = File.createTempFile(fileUpload.getFilename() + "_", fileExtension);
try (var inputStream = fileUpload.getInputStream();
var outputStream = new FileOutputStream(tempFile)) {
long transferredBytes = inputStream.transferTo(outputStream);
if (transferredBytes == 0) {
sink.error(new KestraRuntimeException("Can't upload file: " + fileUpload.getFilename()));
return;
}
URI from = storageInterface.from(execution, fileUpload.getFilename(), tempFile);
sink.next(new AbstractMap.SimpleEntry<>(fileUpload.getFilename(), from.toString()));
} finally {
if (!tempFile.delete()) {
tempFile.deleteOnExit();
}
}
} else {
sink.next(new AbstractMap.SimpleEntry<>(input.getName(), new String(input.getBytes())));
}
} else {
return new AbstractMap.SimpleEntry<>(input.getName(), new String(input.getBytes()));
} catch (IOException e) {
sink.error(e);
}
}))
.collectMap(AbstractMap.SimpleEntry::getKey, AbstractMap.SimpleEntry::getValue)
.block();
})
.collectMap(AbstractMap.SimpleEntry::getKey, AbstractMap.SimpleEntry::getValue);
}

/**
Expand Down
83 changes: 54 additions & 29 deletions core/src/main/java/io/kestra/core/services/ExecutionService.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
import io.kestra.core.events.CrudEvent;
import io.kestra.core.events.CrudEventType;
import io.kestra.core.exceptions.InternalException;
import io.kestra.core.models.executions.*;
import io.kestra.core.models.executions.Execution;
import io.kestra.core.models.executions.ExecutionKilled;
import io.kestra.core.models.executions.ExecutionKilledExecution;
import io.kestra.core.models.executions.TaskRun;
import io.kestra.core.models.executions.TaskRunAttempt;
import io.kestra.core.models.flows.Flow;
import io.kestra.core.models.flows.State;
import io.kestra.core.models.flows.input.InputAndValue;
Expand All @@ -26,7 +30,6 @@
import io.kestra.plugin.core.flow.WorkingDirectory;
import io.micronaut.context.event.ApplicationEventPublisher;
import io.micronaut.core.annotation.Nullable;
import io.micronaut.http.HttpResponse;
import io.micronaut.http.multipart.CompletedPart;
import jakarta.inject.Inject;
import jakarta.inject.Named;
Expand All @@ -38,12 +41,21 @@
import lombok.extern.slf4j.Slf4j;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.io.IOException;
import java.net.URI;
import java.time.Instant;
import java.time.ZonedDateTime;
import java.util.*;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Predicate;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -449,16 +461,17 @@ public Execution resume(Execution execution, Flow flow, State.Type newState) thr
* @return the execution in the new state.
* @throws Exception if the state of the execution cannot be updated
*/
public List<InputAndValue> validateForResume(final Execution execution, Flow flow, @Nullable Publisher<CompletedPart> inputs) throws Exception {
Task task = getFirstPausedTaskOrThrow(execution, flow);
if (task instanceof Pause pauseTask) {
return flowInputOutput.validateExecutionInputs(
pauseTask.getOnResume(),
execution,
inputs
);
}
return Collections.emptyList();
public Mono<List<InputAndValue>> validateForResume(final Execution execution, Flow flow, @Nullable Publisher<CompletedPart> inputs) {
return getFirstPausedTaskOrThrow(execution, flow).handle((task, sink) -> {
if (task instanceof Pause pauseTask) {
flowInputOutput.validateExecutionInputs(
pauseTask.getOnResume(),
execution,
inputs
).subscribe(sink::next, sink::error);
}
sink.next(Collections.emptyList());
});
}

/**
Expand All @@ -472,25 +485,37 @@ public List<InputAndValue> validateForResume(final Execution execution, Flow flo
* @return the execution in the new state.
* @throws Exception if the state of the execution cannot be updated
*/
public Execution resume(final Execution execution, Flow flow, State.Type newState, @Nullable Publisher<CompletedPart> inputs) throws Exception {
var task = getFirstPausedTaskOrThrow(execution, flow);
Map<String, Object> pauseOutputs = Collections.emptyMap();
if (task instanceof Pause pauseTask) {
pauseOutputs = flowInputOutput.readExecutionInputs(
pauseTask.getOnResume(),
execution,
inputs
);
}
public Mono<Execution> resume(final Execution execution, Flow flow, State.Type newState, @Nullable Publisher<CompletedPart> inputs) {
return getFirstPausedTaskOrThrow(execution, flow).handle((task, sink) -> {
Mono<Map<String, Object>> monoOutputs;

return resume(execution, flow, newState, pauseOutputs);
if (task instanceof Pause pauseTask) {
monoOutputs = flowInputOutput.readExecutionInputs(pauseTask.getOnResume(), execution, inputs);
} else {
monoOutputs = Mono.just(Collections.emptyMap());
}
Mono<Execution> monoExecution = monoOutputs.handle((outputs, monoSink) -> {
try {
sink.next(resume(execution, flow, newState, outputs));
} catch (Exception e) {
sink.error(e);
}
});
monoExecution.subscribe(sink::next, sink::error);
});
}

private static Task getFirstPausedTaskOrThrow(Execution execution, Flow flow) throws InternalException {
var runningTaskRun = execution
.findFirstByState(State.Type.PAUSED)
.orElseThrow(() -> new IllegalArgumentException("No paused task found on execution " + execution.getId()));
return flow.findTaskByTaskId(runningTaskRun.getTaskId());
private static Mono<Task> getFirstPausedTaskOrThrow(Execution execution, Flow flow){
return Mono.create(sink -> {
try {
var runningTaskRun = execution
.findFirstByState(State.Type.PAUSED)
.orElseThrow(() -> new IllegalArgumentException("No paused task found on execution " + execution.getId()));
sink.success(flow.findTaskByTaskId(runningTaskRun.getTaskId()));
} catch (InternalException e) {
sink.error(e);
}
});
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ void shouldNotUploadFileInputAfterValidation() throws IOException {
Publisher<CompletedPart> data = Mono.just(new MemoryCompletedFileUpload("input", "input", "???".getBytes(StandardCharsets.UTF_8)));

// When
List<InputAndValue> values = flowInputOutput.validateExecutionInputs(List.of(input), DEFAULT_TEST_EXECUTION, data);
List<InputAndValue> values = flowInputOutput.validateExecutionInputs(List.of(input), DEFAULT_TEST_EXECUTION, data).block();

// Then
Assertions.assertNull(values.getFirst().exception());
Expand Down
4 changes: 2 additions & 2 deletions core/src/test/java/io/kestra/plugin/core/flow/PauseTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ public void runOnResume(RunnerUtils runnerUtils) throws Exception {
flow,
State.Type.RUNNING,
Flux.just(part1, part2)
);
).block();

execution = runnerUtils.awaitExecution(
e -> e.getId().equals(executionId) && e.getState().getCurrent() == State.Type.SUCCESS,
Expand All @@ -243,7 +243,7 @@ public void runOnResumeMissingInputs(RunnerUtils runnerUtils) throws Exception {

ConstraintViolationException e = assertThrows(
ConstraintViolationException.class,
() -> executionService.resume(execution, flow, State.Type.RUNNING, Mono.empty())
() -> executionService.resume(execution, flow, State.Type.RUNNING, Mono.empty()).block()
);

assertThat(e.getMessage(), containsString("Invalid input for `asked`, missing required input, but received `null`"));
Expand Down
Loading

0 comments on commit 05d1eea

Please sign in to comment.