Skip to content

Commit

Permalink
Fix several bugs in cross-multiplication implementation (#2123)
Browse files Browse the repository at this point in the history
  • Loading branch information
Vampire authored Mar 6, 2025
1 parent db76b97 commit a3a77a2
Show file tree
Hide file tree
Showing 4 changed files with 257 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ protected int estimateNumIterations(@Nullable Object dataProvider) {
return UNKNOWN_ITERATIONS;
}

if (dataProvider == null) {
return UNKNOWN_ITERATIONS;
}

// an IDataIterator probably already has the estimated size
// or knows better how to calculate it
if (dataProvider instanceof IDataIterator) {
Expand Down Expand Up @@ -107,18 +111,21 @@ protected int estimateNumIterations(Object[] dataProviders) {
}

protected boolean haveNext(Iterator<?>[] iterators, List<DataProviderInfo> dataProviderInfos) {
Assert.that(iterators.length == dataProviderInfos.size());
boolean result = true;

for (int i = 0; i < iterators.length; i++)
try {
boolean hasNext = iterators[i].hasNext();
if (i == 0) {
result = hasNext;
} else if (result != hasNext) {
DataProviderInfo provider = dataProviderInfos.get(i);
supervisor.error(context.getErrorInfoCollector(), new ErrorInfo(provider.getDataProviderMethod(),
createDifferentNumberOfDataValuesException(provider, hasNext), getErrorContext()));
return false;
result = iterators[0].hasNext();
} else if (iterators[i] != null) {
boolean hasNext = iterators[i].hasNext();
if (result != hasNext) {
DataProviderInfo provider = dataProviderInfos.get(i);
supervisor.error(context.getErrorInfoCollector(), new ErrorInfo(provider.getDataProviderMethod(),
createDifferentNumberOfDataValuesException(provider, hasNext), getErrorContext()));
return false;
}
}

} catch (Throwable t) {
Expand Down Expand Up @@ -353,6 +360,9 @@ public FeatureDataProviderIterator(IRunSupervisor supervisor, SpockExecutionCont
dataVariableNames = dataVariableNames();
dataProviders = createDataProviders();
dataProviderIterators = createDataProviderIterators();
if ((dataProviderIterators != null) && (dataProviders != null)) {
Assert.that(dataProviderIterators.length == dataProviders.length);
}
estimatedNumIterations = estimateNumIterations(dataProviderIterators);
}

Expand Down Expand Up @@ -380,12 +390,16 @@ public Object[] next() {
if (context.getErrorInfoCollector().hasErrors()) {
return null;
}
Assert.that(dataProviders.length == context.getCurrentFeature().getDataProviders().size());
firstIteration = false;

// advances iterators and computes args
Object[] next = new Object[dataProviders.length];
for (int i = 0; i < dataProviders.length; ) {
try {
if (dataProviderIterators[i] == null) {
continue;
}
// if the filter block excluded an iteration
// this might be called after the last iteration
// so just return null if no further data is available
Expand Down Expand Up @@ -476,14 +490,14 @@ private IDataIterator[] createDataProviderIterators() {
}

List<DataProviderInfo> dataProviderInfos = context.getCurrentFeature().getDataProviders();
List<IDataIterator> dataIterators = new ArrayList<>(dataProviders.length);
IDataIterator[] dataIterators = new IDataIterator[dataProviders.length];
for (int dataProviderIndex = 0, dataVariableNameIndex = 0; dataProviderIndex < dataProviders.length; dataProviderIndex++, dataVariableNameIndex++) {
String nextDataVariableName = dataVariableNames.get(dataVariableNameIndex);
if ((nextDataVariableMultiplication != null)
&& (nextDataVariableMultiplication.getDataVariables()[0].equals(nextDataVariableName))) {

// a cross multiplication starts
dataIterators.add(createDataProviderMultiplier(nextDataVariableMultiplication, dataProviderIndex));
dataIterators[dataProviderIndex] = createDataProviderMultiplier(nextDataVariableMultiplication, dataProviderIndex);
// skip processed providers and variables
int remainingVariables = nextDataVariableMultiplication.getDataVariables().length;
dataVariableNameIndex += remainingVariables - 1;
Expand All @@ -497,12 +511,14 @@ private IDataIterator[] createDataProviderIterators() {
nextDataVariableMultiplication = dataVariableMultiplications.hasNext() ? dataVariableMultiplications.next() : null;
} else {
// not a cross multiplication, just use a data provider iterator
dataIterators.add(new DataProviderIterator(
DataProviderInfo dataProviderInfo = dataProviderInfos.get(dataProviderIndex);
dataIterators[dataProviderIndex] = new DataProviderIterator(
supervisor, context, nextDataVariableName,
dataProviderInfos.get(dataProviderIndex), dataProviders[dataProviderIndex]));
dataProviderInfo, dataProviders[dataProviderIndex]);
dataVariableNameIndex += dataProviderInfo.getDataVariables().size() - 1;
}
}
return dataIterators.toArray(new IDataIterator[0]);
return dataIterators;
}

/**
Expand Down Expand Up @@ -534,6 +550,7 @@ private DataProviderMultiplier createDataProviderMultiplier(DataVariableMultipli

return new DataProviderMultiplier(supervisor, context,
Arrays.asList(dataVariableMultiplication.getDataVariables()),
multiplierProvider.multiplierProviderInfos.subList(0, 1),
multiplicandProviderInfos, multiplierProvider,
multiplicandProviders.toArray(new Object[0]));
} else {
Expand Down Expand Up @@ -610,9 +627,9 @@ public DataProviderIterator(IRunSupervisor supervisor, SpockExecutionContext con
DataProviderInfo providerInfo, Object provider) {
super(supervisor, context);
this.dataVariableNames = singletonList(Objects.requireNonNull(dataVariableName));
estimatedNumIterations = estimateNumIterations(Objects.requireNonNull(provider));
estimatedNumIterations = estimateNumIterations(provider);
this.providerInfo = Objects.requireNonNull(providerInfo);
this.provider = provider;
this.provider = Objects.requireNonNull(provider);
iterator = createIterator(provider, providerInfo);
}

Expand Down Expand Up @@ -670,8 +687,6 @@ private static class DataProviderMultiplier extends BaseDataIterator {
/**
* The provider infos for the multiplier data providers.
* These are only used for constructing meaningful errors.
* If {@link #multiplierProviders} is set, this will be set too,
* if it is {@code null}, this will be {@code null} too.
*/
private final List<DataProviderInfo> multiplierProviderInfos;

Expand Down Expand Up @@ -788,6 +803,10 @@ public DataProviderMultiplier(IRunSupervisor supervisor, SpockExecutionContext c
multiplierIterators = createIterators(this.multiplierProviders, this.multiplierProviderInfos);
multiplicandIterators = createIterators(this.multiplicandProviders, this.multiplicandProviderInfos);

if (multiplicandIterators != null) {
Assert.that(multiplicandProviderInfos.size() == multiplicandIterators.length);
}

int estimatedMultiplierIterations = estimateNumIterations(Objects.requireNonNull(multiplierProviders));
int estimatedMultiplicandIterations = estimateNumIterations(Objects.requireNonNull(multiplicandProviders));
estimatedNumIterations =
Expand All @@ -812,11 +831,11 @@ public DataProviderMultiplier(IRunSupervisor supervisor, SpockExecutionContext c
* @param multiplicandProviders the actual providers for the sets of multiplicand values
*/
public DataProviderMultiplier(IRunSupervisor supervisor, SpockExecutionContext context, List<String> dataVariableNames,
List<DataProviderInfo> multiplicandProviderInfos,
List<DataProviderInfo> multiplierProviderInfos, List<DataProviderInfo> multiplicandProviderInfos,
DataProviderMultiplier multiplierProvider, Object[] multiplicandProviders) {
super(supervisor, context);
this.dataVariableNames = Objects.requireNonNull(dataVariableNames);
multiplierProviderInfos = null;
this.multiplierProviderInfos = multiplierProviderInfos;
this.multiplicandProviderInfos = multiplicandProviderInfos;
this.multiplierProvider = multiplierProvider;
multiplierProviders = null;
Expand All @@ -826,6 +845,10 @@ public DataProviderMultiplier(IRunSupervisor supervisor, SpockExecutionContext c
multiplierIterators = new Iterator[]{multiplierProvider};
multiplicandIterators = createIterators(this.multiplicandProviders, this.multiplicandProviderInfos);

if (multiplicandIterators != null) {
Assert.that(multiplicandProviderInfos.size() == multiplicandIterators.length);
}

int estimatedMultiplierIterations = Objects.requireNonNull(multiplierProvider).getEstimatedNumIterations();
int estimatedMultiplicandIterations = estimateNumIterations(Objects.requireNonNull(multiplicandProviders));
estimatedNumIterations =
Expand Down Expand Up @@ -943,6 +966,7 @@ private Iterator<?>[] createIterators(Object[] dataProviders, List<DataProviderI
if (context.getErrorInfoCollector().hasErrors()) {
return null;
}
Assert.that(dataProviders.length == dataProviderInfos.size());

Iterator<?>[] iterators = new Iterator<?>[dataProviders.length];
for (int i = 0; i < dataProviders.length; i++) {
Expand All @@ -957,6 +981,7 @@ private Iterator<?>[] createIterators(Object[] dataProviders, List<DataProviderI
}

protected Object[] extractNextValues(Iterator<?>[] iterators, List<DataProviderInfo> providerInfos) {
Assert.that(iterators.length == providerInfos.size());
Object[] result = new Object[iterators.length];
for (int i = 0; i < iterators.length; i++) {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ class DataSpec extends EmbeddedSpecification {
}`""" == expected
}

def "only single data providers are combined"() {
def "only single data providers are combined - part 1"() {
given:
def expected = '''
tag::single-data-providers-combined-result1[]
Expand Down Expand Up @@ -239,9 +239,11 @@ class DataSpec extends EmbeddedSpecification {
*.displayName
.join('`\n- `')
}`""" == expected
}

when:
expected = '''
def "only single data providers are combined - part 2"() {
given:
def expected = '''
tag::single-data-providers-combined-result2[]
- `feature [a: 1, b: 3, c: 5, #0]`
- `feature [a: 1, b: 3, c: 6, #1]`
Expand All @@ -254,7 +256,9 @@ class DataSpec extends EmbeddedSpecification {
.findAll {it.startsWith('-') }
.join('\n')
.trim()
result = runner.runSpecBody '''

when:
def result = runner.runSpecBody '''
def feature() {
expect:
true
Expand Down
Loading

0 comments on commit a3a77a2

Please sign in to comment.