Skip to content

Commit

Permalink
feat: add external memo
Browse files Browse the repository at this point in the history
  • Loading branch information
yjl9903 committed Mar 7, 2024
1 parent b97d1ee commit e4d9b70
Show file tree
Hide file tree
Showing 5 changed files with 222 additions and 3 deletions.
4 changes: 3 additions & 1 deletion src/async.ts
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ export function memoAsync<F extends AsyncFn>(
};

memoFunc.remove = async (...args) => {
const cur = walkOrBreak<F>(root, args as Parameters<F>);
const path = options.serialize ? options.serialize.bind(memoFunc)(...args) : args;
const cur = walkOrBreak<F, any[]>(root, path);

clearNode(cur);
if (options.external) {
await options.external.remove
Expand Down
93 changes: 93 additions & 0 deletions src/external.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import type { AsyncFn, MemoFunc, MemoExternalOptions } from './types';

import { State, clearNode, makeNode, walkAndCreate, walkOrBreak } from './trie';

export function memoExternal<F extends AsyncFn>(
fn: F,
options: MemoExternalOptions<F>
): MemoFunc<F> {
const root = makeNode<F>();

const memoFunc = async function (...args: Parameters<F>) {
// Serialize args
const path = options.serialize ? options.serialize.bind(memoFunc)(...args) : args;
const cur = walkAndCreate<F, any[]>(root, path);

// if (cur.state === State.Ok) {
// return cur.value;
// } else if (cur.state === State.Error) {
// throw cur.error;
// } else
if (cur.state === State.Waiting) {
return new Promise((res, rej) => {
if (!cur.callbacks) {
cur.callbacks = new Set();
}
cur.callbacks!.add({ res, rej });
});
} else {
try {
cur.state = State.Waiting;

const externalOnError = options.external.error ?? (() => undefined);
const external = await options.external.get.bind(memoFunc)(args).catch(externalOnError);

const hasExternalCache = external !== undefined && external !== null;
const value = hasExternalCache ? external : await fn(...args);

cur.state = State.Ok;
cur.value = value;

if (!hasExternalCache) {
await options.external.set.bind(memoFunc)(args, value).catch(externalOnError);
}

// Resolve other waiting callbacks
for (const callback of cur.callbacks ?? []) {
callback.res(value);
}

return value;
} catch (error) {
cur.state = State.Error;
cur.error = error;

// Reject other waiting callbacks
for (const callback of cur.callbacks ?? []) {
callback.rej(error);
}

throw error;
}
}
} as MemoFunc<F>;

memoFunc.get = (...args) => {
return memoFunc(...args);
};

memoFunc.raw = (...args) => {
return fn(...args) as ReturnType<F>;
};

memoFunc.remove = async (...args) => {
const path = options.serialize ? options.serialize.bind(memoFunc)(...args) : args;
const cur = walkOrBreak<F, any[]>(root, path);

clearNode(cur);
await options.external.remove
.bind(memoFunc)(args as Parameters<F>)
.catch(options.external?.error ?? (() => undefined));
};

memoFunc.clear = async () => {
clearNode(root);
await options.external.clear
.bind(memoFunc)()
.catch(options.external?.error ?? (() => undefined));
};

memoFunc.external = options.external;

return memoFunc;
}
4 changes: 3 additions & 1 deletion src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,6 @@ export * from './sync';

export * from './async';

export type { MemoFunc, MemoOptions } from './types';
export * from './external';

export type { MemoFunc, MemoOptions, MemoAsyncOptions, MemoExternalOptions } from './types';
17 changes: 17 additions & 0 deletions src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,23 @@ export interface MemoAsyncOptions<F extends Fn> extends MemoOptions<F> {
};
}

export interface MemoExternalOptions<F extends Fn> extends MemoOptions<F> {
external: {
get: (
this: MemoFunc<F>,
args: Parameters<F>
) => Promise<Awaited<ReturnType<F>> | undefined | null>;

set: (this: MemoFunc<F>, args: Parameters<F>, value: Awaited<ReturnType<F>>) => Promise<void>;

remove: (this: MemoFunc<F>, args: Parameters<F>) => Promise<void>;

clear: (this: MemoFunc<F>) => Promise<void>;

error?: (err: unknown) => void | Promise<void>;
};
}

export type Fn = (...params: any[]) => any;

export type AsyncFn = (...params: any[]) => Promise<any>;
Expand Down
107 changes: 106 additions & 1 deletion test/memo.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { describe, it, expect } from 'vitest';

import { memo, memoAsync } from '../src';
import { memo, memoAsync, memoExternal } from '../src';

describe('memo sync', () => {
it('should work', () => {
Expand Down Expand Up @@ -101,6 +101,111 @@ describe('memo async', () => {
});
});

describe('memo external', () => {
it('should prefer external cache', async () => {
const func = memoExternal(async () => 1, {
external: {
async get() {
return 2;
},
async set() {},
async clear() {},
async remove() {}
}
});

expect(await func()).toBe(2);
expect(await func()).toBe(2);
expect(await func()).toBe(2);
expect(await func()).toBe(2);
});

it('should skip external cache', async () => {
let cnt = 0;
const func = memoExternal(async () => ++cnt, {
external: {
async get() {
return undefined;
},
async set() {},
async clear() {},
async remove() {}
}
});

expect(await func()).toBe(1);
expect(await func()).toBe(2);
expect(await func()).toBe(3);
expect(await func()).toBe(4);
});

it('should get external cache once', async () => {
let cnt = 0;
const func = memoExternal(async () => ++cnt, {
external: {
async get() {
await sleep(100);
return undefined;
},
async set() {},
async clear() {},
async remove() {}
}
});

const tasks = await Promise.all([func(), func(), func(), func(), func()]);
expect(tasks).toStrictEqual([1, 1, 1, 1, 1]);
});

it('should get external cache twice', async () => {
let cnt = 0;
const func = memoExternal(async () => ++cnt, {
external: {
async get() {
await sleep(100);
return undefined;
},
async set() {},
async clear() {},
async remove() {}
}
});

const tasks = await Promise.all([func(), func(), func(), func(), func()]);
expect(tasks).toStrictEqual([1, 1, 1, 1, 1]);

const tasks2 = await Promise.all([func(), func(), func(), func(), func()]);
expect(tasks2).toStrictEqual([2, 2, 2, 2, 2]);
});

it('should get external cache after removing', async () => {
let cnt = 0;
const func = memoExternal(async () => 0, {
external: {
async get() {
await sleep(100);
return ++cnt;
},
async set() {},
async clear() {
cnt = 0;
},
async remove() {
cnt = 0;
}
}
});

const tasks = await Promise.all([func(), func(), func(), func(), func()]);
expect(tasks).toStrictEqual([1, 1, 1, 1, 1]);

func.clear();

const tasks2 = await Promise.all([func(), func(), func(), func(), func()]);
expect(tasks2).toStrictEqual([1, 1, 1, 1, 1]);
});
});

function sleep(time: number): Promise<void> {
return new Promise((res) => {
setTimeout(() => res(), time);
Expand Down

0 comments on commit e4d9b70

Please sign in to comment.