Skip to content

Commit

Permalink
Merge pull request #9 from parea-ai/fix-experiment
Browse files Browse the repository at this point in the history
fix(experiment): parsing of inputs for experiments
  • Loading branch information
jalexanderII authored Jan 30, 2024
2 parents ee288dc + 4746439 commit 6f0e926
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 11 deletions.
5 changes: 3 additions & 2 deletions src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ export class Parea {
return response.data;
}

public experiment(name: string, data: Iterable<DataItem>, func: (dataItem: DataItem) => Promise<any>): Experiment {
return new Experiment(name, data, func, this);
public experiment(name: string, data: Iterable<DataItem>, func: (...dataItem: any[]) => Promise<any>): Experiment {
const convertedData: Iterable<any[]> = Array.from(data).map((item) => Object.values(item));
return new Experiment(name, convertedData, func, this);
}
}
18 changes: 10 additions & 8 deletions src/experiment/experiment.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { DataItem, ExperimentStatsSchema, TraceStatsSchema } from '../types';
import { ExperimentStatsSchema, TraceStatsSchema } from '../types';
import { Parea } from '../client';
import { asyncPool } from '../helpers';

Expand Down Expand Up @@ -46,8 +46,8 @@ function calculateAvgStdForExperiment(experimentStats: ExperimentStatsSchema): {

async function experiment(
name: string,
data: Iterable<Record<string, any>>,
func: (...args: any[]) => Promise<any>,
data: Iterable<any[]>,
func: (...dataItem: any[]) => Promise<any>,
p: Parea,
maxParallelCalls: number = 10,
): Promise<ExperimentStatsSchema> {
Expand All @@ -56,7 +56,7 @@ async function experiment(
process.env.PAREA_OS_ENV_EXPERIMENT_UUID = experimentUUID;

const tasksGenerator = asyncPool(maxParallelCalls, data, async (dataInput) => {
return func(dataInput);
return func(...dataInput);
});

for await (const _ of tasksGenerator) {
Expand All @@ -73,19 +73,21 @@ async function experiment(

export class Experiment {
name: string;
data: Iterable<DataItem>;
func: (dataItem: DataItem) => Promise<any>;
data: Iterable<any[]>;
func: (...dataItem: any[]) => Promise<any>;
p: Parea;
experimentStats?: ExperimentStatsSchema;

constructor(name: string, data: Iterable<DataItem>, func: (dataItem: DataItem) => Promise<any>, p: Parea) {
constructor(name: string, data: Iterable<any[]>, func: (...dataItem: any[]) => Promise<any>, p: Parea) {
this.name = name;
this.data = data;
this.func = func;
this.p = p;
}

async run(): Promise<void> {
this.experimentStats = await experiment(this.name, this.data, this.func, this.p);
this.experimentStats = new ExperimentStatsSchema(
(await experiment(this.name, this.data, this.func, this.p)).parent_trace_stats,
);
}
}
10 changes: 9 additions & 1 deletion src/utils/trace_utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,14 @@ function extractFunctionParams(func: Function, args: any[]): { [key: string]: an

// Constructing an object of paramName: value
return paramNames.reduce((acc, paramName, index) => {
return { ...acc, [paramName]: typeof args[index] === 'string' ? args[index] : JSON.stringify(args[index]) };
return {
...acc,
[paramName]:
typeof args[index] === 'string'
? args[index]
: Array.isArray(args[index])
? args[index]
: JSON.stringify(args[index]),
};
}, {});
}

0 comments on commit 6f0e926

Please sign in to comment.