Skip to content

Commit

Permalink
Merge pull request #6 from parea-ai/PAI-583-experiment-in-typescript-…
Browse files Browse the repository at this point in the history
…easy-auto-evals

feat(tracer): experiment, evals, new context manager
  • Loading branch information
jalexanderII authored Jan 29, 2024
2 parents 89ef0d4 + 31adc24 commit 7d3b202
Show file tree
Hide file tree
Showing 8 changed files with 474 additions and 149 deletions.
29 changes: 13 additions & 16 deletions src/api-client.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import axios, { AxiosInstance, AxiosResponse, AxiosError } from 'axios';
import axios, { AxiosError, AxiosInstance, AxiosResponse } from 'axios';
import axiosRetry from 'axios-retry';

interface RequestConfig {
Expand Down Expand Up @@ -28,21 +28,6 @@ export class HTTPClient {
this.client.interceptors.response.use(this.responseInterceptor, this.errorInterceptor);
}

/**
 * Request interceptor hook: currently a pass-through that returns the
 * outgoing request config untouched.
 * NOTE(review): presumably registered on the axios instance in the
 * constructor — the registration line is not visible here; confirm.
 */
private requestInterceptor(config: any) {
// TBD: Add any request modifications here
return config;
}

/**
 * Success-side response interceptor: currently a pass-through that returns
 * the response untouched. It is handed to `interceptors.response.use` as a
 * bare method reference, so `this` is not bound when it is invoked.
 */
private responseInterceptor(response: AxiosResponse) {
// TBD: Add any response modifications here
return response;
}

/**
 * Error-side response interceptor: re-rejects with the same error so the
 * failure still reaches the caller. It is handed to
 * `interceptors.response.use` as a bare method reference, so `this` is not
 * bound when it is invoked.
 */
private errorInterceptor(error: AxiosError) {
// TBD: Add any error modifications here
return Promise.reject(error);
}

public static getInstance(): HTTPClient {
if (!HTTPClient.instance) {
HTTPClient.instance = new HTTPClient();
Expand All @@ -69,4 +54,16 @@ export class HTTPClient {
throw error;
}
}

/**
 * Request interceptor hook: a pass-through that returns the outgoing request
 * config untouched.
 * NOTE(review): presumably registered on the axios instance in the
 * constructor — the registration line is not visible here; confirm.
 */
private requestInterceptor(config: any) {
return config;
}

/**
 * Success-side response interceptor: a pass-through that returns the response
 * untouched. It is handed to `interceptors.response.use` as a bare method
 * reference, so `this` is not bound when it is invoked.
 */
private responseInterceptor(response: AxiosResponse) {
return response;
}

/**
 * Error-side response interceptor: re-rejects with the same error so the
 * failure still reaches the caller. It is handed to
 * `interceptors.response.use` as a bare method reference, so `this` is not
 * bound when it is invoked.
 */
private errorInterceptor(error: AxiosError) {
return Promise.reject(error);
}
}
68 changes: 59 additions & 9 deletions src/client.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,27 @@
import { Completion, CompletionResponse, FeedbackRequest, UseDeployedPrompt, UseDeployedPromptResponse } from './types';
import {
Completion,
CompletionResponse,
CreateExperimentRequest,
DataItem,
ExperimentSchema,
ExperimentStatsSchema,
FeedbackRequest,
UseDeployedPrompt,
UseDeployedPromptResponse,
} from './types';

import { HTTPClient } from './api-client';
import { pareaLogger } from './parea_logger';
import { genTraceId } from './helpers';
import { getCurrentTraceId, traceData } from './utils/trace_utils';
import { asyncLocalStorage } from './utils/trace_utils';
import { Experiment } from './experiment/experiment';

// API endpoint paths. The ones used in the visible code are passed as the
// `endpoint` option of `this.client.request(...)`.
const COMPLETION_ENDPOINT = '/completion';
const DEPLOYED_PROMPT_ENDPOINT = '/deployed-prompt';
const RECORD_FEEDBACK_ENDPOINT = '/feedback';
// const EXPERIMENT_ENDPOINT = "/experiment"
// const EXPERIMENT_STATS_ENDPOINT = "/experiment/{experiment_uuid}/stats"
// const EXPERIMENT_FINISHED_ENDPOINT = "/experiment/{experiment_uuid}/finished"
const EXPERIMENT_ENDPOINT = '/experiment';
// `{experiment_uuid}` in the two templates below is a literal placeholder; the
// experiment methods substitute the real UUID with `String.prototype.replace`
// before issuing the request.
const EXPERIMENT_STATS_ENDPOINT = '/experiment/{experiment_uuid}/stats';
const EXPERIMENT_FINISHED_ENDPOINT = '/experiment/{experiment_uuid}/finished';

export class Parea {
private apiKey: string;
Expand All @@ -24,15 +35,29 @@ export class Parea {
}

public async completion(data: Completion): Promise<CompletionResponse> {
const parentTraceId = getCurrentTraceId();
let experiment_uuid;
const parentStore = asyncLocalStorage.getStore();
const parentTraceId = parentStore ? Array.from(parentStore.keys())[0] : undefined; // Assuming the last traceId is the parent

const inference_id = genTraceId();
data.inference_id = inference_id;
data.parent_trace_id = parentTraceId || inference_id;

if (process.env.PAREA_OS_ENV_EXPERIMENT_UUID) {
experiment_uuid = process.env.PAREA_OS_ENV_EXPERIMENT_UUID;
data.experiment_uuid = experiment_uuid;
}

const response = await this.client.request({ method: 'POST', endpoint: COMPLETION_ENDPOINT, data });

if (parentTraceId) {
traceData[parentTraceId].children.push(inference_id);
await pareaLogger.recordLog(traceData[parentTraceId]);
if (parentStore && parentTraceId) {
const parentTraceLog = parentStore.get(parentTraceId);
if (parentTraceLog) {
parentTraceLog.traceLog.children.push(inference_id);
parentTraceLog.traceLog.experiment_uuid = experiment_uuid;
parentStore.set(parentTraceId, parentTraceLog);
await pareaLogger.recordLog(parentTraceLog.traceLog);
}
}

return response.data;
Expand All @@ -47,4 +72,29 @@ export class Parea {
await new Promise((resolve) => setTimeout(resolve, 2000)); // give logs time to update
await this.client.request({ method: 'POST', endpoint: RECORD_FEEDBACK_ENDPOINT, data });
}

/**
 * Creates an experiment by POSTing `data` to the experiment endpoint.
 *
 * @param data - Creation payload, sent as the request body.
 * @returns The `data` field of the HTTP response.
 */
public async createExperiment(data: CreateExperimentRequest): Promise<ExperimentSchema> {
  const { data: createdExperiment } = await this.client.request({
    method: 'POST',
    endpoint: EXPERIMENT_ENDPOINT,
    data,
  });
  return createdExperiment;
}

/**
 * Fetches the statistics of an experiment with a GET request.
 *
 * @param experimentUUID - Identifier substituted for `{experiment_uuid}` in the stats endpoint path.
 * @returns The `data` field of the HTTP response.
 */
public async getExperimentStats(experimentUUID: string): Promise<ExperimentStatsSchema> {
  // A replacer function is used instead of a replacement string so that `$`
  // sequences in the identifier (`$&`, `$1`, ...) are inserted literally, and
  // encodeURIComponent keeps characters such as `/`, `?` or `#` from changing
  // the shape of the path. A plain UUID is unaffected by either step.
  const endpoint = EXPERIMENT_STATS_ENDPOINT.replace('{experiment_uuid}', () =>
    encodeURIComponent(experimentUUID),
  );
  const response = await this.client.request({ method: 'GET', endpoint });
  return response.data;
}

/**
 * Marks an experiment as finished with a POST request.
 *
 * @param experimentUUID - Identifier substituted for `{experiment_uuid}` in the finished endpoint path.
 * @returns The `data` field of the HTTP response.
 */
public async finishExperiment(experimentUUID: string): Promise<ExperimentStatsSchema> {
  // A replacer function is used instead of a replacement string so that `$`
  // sequences in the identifier (`$&`, `$1`, ...) are inserted literally, and
  // encodeURIComponent keeps characters such as `/`, `?` or `#` from changing
  // the shape of the path. A plain UUID is unaffected by either step.
  const endpoint = EXPERIMENT_FINISHED_ENDPOINT.replace('{experiment_uuid}', () =>
    encodeURIComponent(experimentUUID),
  );
  const response = await this.client.request({ method: 'POST', endpoint });
  return response.data;
}

/**
 * Builds an `Experiment` from the given name, data and function, passing this
 * client instance as its last constructor argument. This call only constructs
 * the object and returns it.
 *
 * @param name - Experiment name.
 * @param data - Items the experiment function will be applied to.
 * @param func - Async function invoked with a single data item.
 * @returns The newly constructed `Experiment`.
 */
public experiment(name: string, data: Iterable<DataItem>, func: (dataItem: DataItem) => Promise<any>): Experiment {
return new Experiment(name, data, func, this);
}
}
91 changes: 91 additions & 0 deletions src/experiment/experiment.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import { DataItem, ExperimentStatsSchema, TraceStatsSchema } from '../types';
import { Parea } from '../client';
import { asyncPool } from '../helpers';

/**
 * Formats the arithmetic mean of `values` with two decimal places.
 *
 * Missing entries (`null` or `undefined`) are ignored. When no usable value
 * remains — the input is `undefined`, empty, or consists only of missing
 * entries — the string `'N/A'` is returned instead of a number.
 *
 * @param values - Samples to average; may be `undefined`.
 * @returns The mean as a fixed two-decimal string, or `'N/A'`.
 */
function calculateAvgAsString(values: number[] | undefined): string {
  // `!= null` drops both null and undefined so neither can turn the sum into NaN.
  const present = (values ?? []).filter((value) => value != null);
  // Guard after filtering: dividing by a zero count would otherwise produce "NaN".
  if (present.length === 0) {
    return 'N/A';
  }
  const total = present.reduce((sum, value) => sum + value, 0);
  return (total / present.length).toFixed(2);
}

/**
 * Summarises an experiment's parent trace stats as formatted averages.
 *
 * The result always contains `latency`, `input_tokens`, `output_tokens`,
 * `total_tokens` and `cost` (a falsy per-trace value counts as 0), followed by
 * one entry per distinct score name found on the traces. Despite the function
 * name, only averages are computed here.
 *
 * @param experimentStats - Stats object whose `parent_trace_stats` are aggregated.
 * @returns Map from stat or score name to its average formatted by `calculateAvgAsString`.
 */
function calculateAvgStdForExperiment(experimentStats: ExperimentStatsSchema): { [key: string]: string } {
  const parentTraces: TraceStatsSchema[] = experimentStats.parent_trace_stats;

  // Group every score value under its score name.
  const scoresByName: { [key: string]: number[] } = {};
  parentTraces.forEach((trace) => {
    trace.scores?.forEach((entry) => {
      const bucket = scoresByName[entry.name];
      if (bucket) {
        bucket.push(entry.score);
      } else {
        scoresByName[entry.name] = [entry.score];
      }
    });
  });

  // Average of one built-in numeric field across all traces.
  const averageOf = (field: 'latency' | 'input_tokens' | 'output_tokens' | 'total_tokens' | 'cost'): string =>
    calculateAvgAsString(parentTraces.map((trace) => trace[field] || 0));

  const summary: { [key: string]: string } = {
    latency: averageOf('latency'),
    input_tokens: averageOf('input_tokens'),
    output_tokens: averageOf('output_tokens'),
    total_tokens: averageOf('total_tokens'),
    cost: averageOf('cost'),
  };

  for (const [scoreName, scoreValues] of Object.entries(scoresByName)) {
    summary[scoreName] = calculateAvgAsString(scoreValues);
  }

  return summary;
}

/**
 * Runs `func` once per item of `data` as a Parea experiment and returns the
 * stats reported when the experiment is finished.
 *
 * Steps: create the experiment via `p.createExperiment`, publish its UUID in
 * `process.env.PAREA_OS_ENV_EXPERIMENT_UUID`, execute `func` over the data
 * through `asyncPool`, call `p.finishExperiment`, then log the averaged stats
 * and a link to the experiment.
 *
 * @param name - Experiment name sent on creation.
 * @param data - Items passed one at a time to `func`.
 * @param func - Async function executed per item; its result is discarded.
 * @param p - Parea client used to create and finish the experiment.
 * @param maxParallelCalls - Concurrency handed to `asyncPool` (default 10).
 * @returns The stats returned by `p.finishExperiment`.
 */
export async function experiment(
  name: string,
  data: Iterable<Record<string, any>>,
  func: (...args: any[]) => Promise<any>,
  p: Parea,
  maxParallelCalls: number = 10,
): Promise<ExperimentStatsSchema> {
  const { uuid: experimentUUID } = await p.createExperiment({ name });
  // The UUID is shared through the environment rather than passed as an argument.
  process.env.PAREA_OS_ENV_EXPERIMENT_UUID = experimentUUID;

  const runs = asyncPool(maxParallelCalls, data, async (dataInput) => func(dataInput));

  // Drain the pool; individual results are intentionally ignored.
  let step = await runs.next();
  while (!step.done) {
    step = await runs.next();
  }

  const finalStats: ExperimentStatsSchema = await p.finishExperiment(experimentUUID);
  const averages = calculateAvgStdForExperiment(finalStats);
  console.log(`Experiment stats:\n${JSON.stringify(averages, null, 2)}\n\n`);
  console.log(`View experiment & its traces at: https://app.parea.ai/experiments/${experimentUUID}\n`);
  return finalStats;
}

/**
 * Bundles an experiment's name, data and per-item function with the Parea
 * client used to run it. Call `run()` to execute; the resulting stats are then
 * stored on `experimentStats`.
 */
export class Experiment {
  /** Stats from the most recent `run()`; unset until a run has completed. */
  experimentStats?: ExperimentStatsSchema;

  /**
   * @param name - Experiment name.
   * @param data - Items the function is applied to.
   * @param func - Async function invoked with a single data item.
   * @param p - Parea client passed through to `experiment`.
   */
  constructor(
    public name: string,
    public data: Iterable<DataItem>,
    public func: (dataItem: DataItem) => Promise<any>,
    public p: Parea,
  ) {}

  /** Executes the experiment and stores the returned stats on this instance. */
  async run(): Promise<void> {
    const stats = await experiment(this.name, this.data, this.func, this.p);
    this.experimentStats = stats;
  }
}
40 changes: 40 additions & 0 deletions src/helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,43 @@ export function genTraceId(): string {
/**
 * Formats `date` with moment using the pattern `YYYY-MM-DD HH:mm:ss z`.
 * NOTE(review): the trailing `z` token presumably only prints a zone
 * abbreviation when moment-timezone is loaded — confirm what it renders here.
 *
 * @param date - Date to format.
 * @returns The formatted date-time string.
 */
export function toDateTimeString(date: Date): string {
return moment(date).format('YYYY-MM-DD HH:mm:ss z');
}

/**
 * Applies `iteratorFn` to every item of `iterable`, keeping at most
 * `concurrency` calls outstanding, and yields each result exactly once.
 *
 * Results are yielded as calls settle, so the output order follows completion
 * order rather than input order. A new call is only started once fewer than
 * `concurrency` launched calls are still waiting to be yielded.
 *
 * If a call rejects (or `iteratorFn` throws synchronously), the error is
 * rethrown from the generator when that call is consumed, which ends the
 * iteration; calls that were already started keep running.
 *
 * @param concurrency - Maximum number of outstanding calls. A value of 0 or
 *   less makes the first iteration step throw, because there is nothing to
 *   wait for before launching.
 * @param iterable - Source items, consumed lazily as capacity frees up.
 * @param iteratorFn - Async function applied to each item.
 * @returns An async generator over the results of `iteratorFn`.
 * @throws Error when asked to wait while no call is outstanding.
 */
export async function* asyncPool<T, R>(
  concurrency: number,
  iterable: Iterable<T>,
  iteratorFn: (item: T) => Promise<R>,
): AsyncGenerator<R, void, unknown> {
  // Launched calls whose result has not been yielded yet, keyed by launch order.
  // A call stays here until `consume` hands out its own result, so calls that
  // settle while the consumer is busy are neither lost nor have their
  // rejection swallowed.
  const pending = new Map<number, Promise<{ id: number; result: R }>>();
  let nextId = 0;

  // Waits for any outstanding call, removes exactly that call and returns its result.
  async function consume(): Promise<R> {
    if (pending.size === 0) {
      throw new Error('Attempted to consume with no promises executing.');
    }
    const { id, result } = await Promise.race(pending.values());
    pending.delete(id);
    return result;
  }

  for (const item of iterable) {
    while (pending.size >= concurrency) {
      yield await consume();
    }

    const id = nextId++;
    // The async wrapper tags the result with its id and turns a synchronous
    // throw from `iteratorFn` into a rejection handled like any other failure.
    pending.set(
      id,
      (async () => ({ id, result: await iteratorFn(item) }))(),
    );
  }

  while (pending.size > 0) {
    yield await consume();
  }
}
Loading

0 comments on commit 7d3b202

Please sign in to comment.