Skip to content

Commit

Permalink
Support initial data attribute on startWorkflow (#267)
Browse files Browse the repository at this point in the history
  • Loading branch information
longquanzheng authored Nov 15, 2024
1 parent 815987c commit 4484a27
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 3 deletions.
21 changes: 21 additions & 0 deletions src/main/java/io/iworkflow/core/Client.java
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ public String startWorkflow(
final Map<String, SearchAttributeValueType> saTypes = registry.getSearchAttributeKeyToTypeMap(wfType);
final List<SearchAttribute> convertedSAs = convertToSearchAttributeList(saTypes, options.getInitialSearchAttribute());
unregisterWorkflowOptions.initialSearchAttribute(convertedSAs);
checkInitialDataAttributes(registry.getDataAttributeTypeStore(wfType), options.getInitialDataAttribute());
unregisterWorkflowOptions.initialDataAttribute(options.getInitialDataAttribute());
}

final Optional<StateDef> stateDefOptional = registry.getWorkflowStartingState(wfType);
Expand Down Expand Up @@ -180,6 +182,25 @@ public String startWorkflow(
return unregisteredClient.startWorkflow(wfType, startStateId, workflowId, workflowTimeoutSeconds, input, unregisterWorkflowOptions.build());
}

private void checkInitialDataAttributes(final TypeStore dataAttributeTypeStore, final Map<String, Object> initialDataAttribute) {
if (initialDataAttribute.size() > 0) {
initialDataAttribute.forEach((key, val) -> {
if (!dataAttributeTypeStore.isValidNameOrPrefix(key)) {
throw new IllegalArgumentException(String.format("data attribute %s is not registered", key));
}
final Class<?> registeredType = dataAttributeTypeStore.getType(key);
final Class<?> requestedType = val.getClass();
if (!requestedType.isAssignableFrom(registeredType)) {
throw new IllegalArgumentException(
String.format(
"registered type %s is not assignable from %s",
registeredType.getName(),
requestedType.getName()));
}
});
}
}

private void checkWorkflowTypeExists(String wfType) {
final ObjectWorkflow wf = registry.getWorkflow(wfType);
if (wf == null) {
Expand Down
8 changes: 8 additions & 0 deletions src/main/java/io/iworkflow/core/UnregisteredClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,14 @@ public String startWorkflow(
});
startOptions.searchAttributes(options.getInitialSearchAttribute());
}
if(options.getInitialDataAttribute().size()>0){
List<KeyValue> dataAttributes = options.getInitialDataAttribute().entrySet().stream()
.map(entry -> new KeyValue()
.key(entry.getKey())
.value(clientOptions.getObjectEncoder().encode(entry.getValue())))
.collect(Collectors.toList());
startOptions.dataAttributes(dataAttributes);
}

if (options.getStartStateOptions().isPresent()) {
request.stateOptions(options.getStartStateOptions().get());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import org.immutables.value.Value;

import java.util.List;
import java.util.Map;
import java.util.Optional;

@Value.Immutable
Expand All @@ -25,6 +26,8 @@ public abstract class UnregisteredWorkflowOptions {

public abstract List<SearchAttribute> getInitialSearchAttribute();

public abstract Map<String, Object> getInitialDataAttribute();

public abstract Optional<WorkflowConfig> getWorkflowConfigOverride();

public abstract Optional<Boolean> getUsingMemoForDataAttributes();
Expand Down
2 changes: 2 additions & 0 deletions src/main/java/io/iworkflow/core/WorkflowOptions.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ public abstract class WorkflowOptions {

public abstract Map<String, Object> getInitialSearchAttribute();

public abstract Map<String, Object> getInitialDataAttribute();

public abstract Optional<WorkflowConfig> getWorkflowConfigOverride();

public abstract List<String> getWaitForCompletionStateIds();
Expand Down
12 changes: 9 additions & 3 deletions src/test/java/io/iworkflow/integ/PersistenceTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import com.google.common.collect.ImmutableMap;
import io.iworkflow.core.Client;
import io.iworkflow.core.ClientOptions;
import io.iworkflow.core.WorkflowOptions;
import io.iworkflow.gen.models.SearchAttribute;
import io.iworkflow.gen.models.SearchAttributeValueType;
import io.iworkflow.integ.persistence.BasicPersistenceWorkflow;
Expand Down Expand Up @@ -52,18 +53,23 @@ public void testPersistenceWorkflow() throws InterruptedException {
final Client client = new Client(WorkflowRegistry.registry, ClientOptions.localDefault);
final String wfId = "basic-persistence-test-id" + System.currentTimeMillis() / 1000;
final String runId = client.startWorkflow(
BasicPersistenceWorkflow.class, wfId, 10, "start");
BasicPersistenceWorkflow.class, wfId, 10, "start",
WorkflowOptions.basicBuilder().initialDataAttribute(
ImmutableMap.of(BasicPersistenceWorkflow.TEST_INIT_DATA_OBJECT_KEY, "init-test-value"))
.build());
final String output = client.getSimpleWorkflowResultWithWait(String.class, wfId);
Assertions.assertEquals("test-value-2", output);

Map<String, Object> map =
client.getWorkflowDataAttributes(BasicPersistenceWorkflow.class, wfId, runId,
Arrays.asList(
BasicPersistenceWorkflow.TEST_INIT_DATA_OBJECT_KEY,
BasicPersistenceWorkflow.TEST_DATA_OBJECT_KEY,
BasicPersistenceWorkflow.TEST_DATA_OBJECT_PREFIX + "1",
BasicPersistenceWorkflow.TEST_DATA_OBJECT_PREFIX + "2"));
Assertions.assertEquals(
"query-start-query-decide", map.get(BasicPersistenceWorkflow.TEST_DATA_OBJECT_KEY));
Assertions.assertEquals("init-test-value", map.get(BasicPersistenceWorkflow.TEST_INIT_DATA_OBJECT_KEY));
Assertions.assertEquals(
11L, map.get(BasicPersistenceWorkflow.TEST_DATA_OBJECT_PREFIX + "1"));
Assertions.assertNull(map.get(BasicPersistenceWorkflow.TEST_DATA_OBJECT_PREFIX + "2"));
Expand All @@ -75,13 +81,13 @@ public void testPersistenceWorkflow() throws InterruptedException {
"query-start-query-decide", map2.get(BasicPersistenceWorkflow.TEST_DATA_OBJECT_KEY));

Map<String, Object> allDataObjects = client.getAllDataAttributes(BasicPersistenceWorkflow.class, wfId, runId);
Assertions.assertEquals(4, allDataObjects.size());
Assertions.assertEquals(5, allDataObjects.size());
Assertions.assertEquals("query-start-query-decide", allDataObjects.get(BasicPersistenceWorkflow.TEST_DATA_OBJECT_KEY));
Assertions.assertEquals(11L, allDataObjects.get(BasicPersistenceWorkflow.TEST_DATA_OBJECT_PREFIX + "1"));

// test no runId
Map<String, Object> allDataObjects2 = client.getAllDataAttributes(BasicPersistenceWorkflow.class, wfId);
Assertions.assertEquals(4, allDataObjects2.size());
Assertions.assertEquals(5, allDataObjects2.size());
Assertions.assertEquals("query-start-query-decide", allDataObjects2.get(BasicPersistenceWorkflow.TEST_DATA_OBJECT_KEY));
Assertions.assertEquals(11L, allDataObjects.get(BasicPersistenceWorkflow.TEST_DATA_OBJECT_PREFIX + "1"));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

@Component
public class BasicPersistenceWorkflow implements ObjectWorkflow {
public static final String TEST_INIT_DATA_OBJECT_KEY = "data-obj-0";
public static final String TEST_DATA_OBJECT_KEY = "data-obj-1";
public static final String TEST_DATA_OBJECT_MODEL_1 = "data-obj-2";

Expand All @@ -34,6 +35,7 @@ public List<StateDef> getWorkflowStates() {
@Override
public List<PersistenceFieldDef> getPersistenceSchema() {
return Arrays.asList(
DataAttributeDef.create(String.class, TEST_INIT_DATA_OBJECT_KEY),
DataAttributeDef.create(String.class, TEST_DATA_OBJECT_KEY),
DataAttributeDef.create(Context.class, TEST_DATA_OBJECT_MODEL_1),
DataAttributeDef.create(FakContextImpl.class, TEST_DATA_OBJECT_MODEL_2),
Expand Down

0 comments on commit 4484a27

Please sign in to comment.