Skip to content

Commit

Permalink
Fix asset graph throttling (#26370)
Browse files Browse the repository at this point in the history
## Summary & Motivation

The use of Lodash's `throttle` caused incorrect promise handling due to
its behavior:

> "Subsequent calls to the throttled function return the **result of the
last func invocation**."
> ([Lodash docs](https://lodash.com/docs/#throttle))

This led to promises resolving with stale results from previous
requests.

To fix this, I replaced `throttle` with a custom `throttleLatest`
implementation that ensures the promise returned corresponds to the
current request, not a previous one.

Additionally, I resolved a potential issue with worker listeners not
being cleaned up properly, ensuring correct promise resolution.

## How I Tested These Changes

- Added Jest tests for `throttleLatest` with the required behavior.
- Verified that the asset graph renders correctly on initial load in the
UI.

## Changelog

[ui] Fixed an issue that would sometimes cause the asset graph to fail
to render on initial load.
  • Loading branch information
salazarm authored Dec 10, 2024
1 parent 319b36d commit 4eba622
Show file tree
Hide file tree
Showing 9 changed files with 241 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,17 @@ import {ComputeGraphDataMessageType} from '../../src/asset-graph/ComputeGraphDat

// eslint-disable-next-line import/no-default-export
export default class MockWorker {
onmessage = (_: any) => {};
onmessage: Array<(data: any) => void> = [];

addEventListener(_type: string, handler: any) {
this.onmessage = handler;
this.onmessage.push(handler);
}

removeEventListener(_type: string, handler: any) {
const index = this.onmessage.indexOf(handler);
if (index !== -1) {
this.onmessage.splice(index, 1);
}
}

// mock expects data: { } instead of e: { data: { } }
Expand All @@ -17,7 +24,7 @@ export default class MockWorker {
setFeatureFlagsInternal({flagAssetSelectionSyntax: true});
}
const state = await computeGraphData(data);
this.onmessage({data: state});
this.onmessage.forEach((onmessage) => onmessage({data: {...state, id: data.id}}));
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ import {
} from '../pipelines/GraphNotices';
import {ExplorerPath} from '../pipelines/PipelinePathUtils';
import {StaticSetFilter} from '../ui/BaseFilters/useStaticSetFilter';
import {Loading} from '../ui/Loading';
import {Loading, LoadingSpinner} from '../ui/Loading';

type AssetNode = AssetNodeForGraphQueryFragment;

Expand Down Expand Up @@ -152,6 +152,9 @@ export const AssetGraphExplorer = (props: Props) => {
return (
<Loading allowStaleData queryResult={fetchResult}>
{() => {
if (graphDataLoading || filteredAssetsLoading) {
return <LoadingSpinner purpose="page" />;
}
if (!assetGraphData || !allAssetKeys || !fullAssetGraphData) {
return <NonIdealState icon="error" title="Query Error" />;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ export function computeGraphData({
opsQuery,
kinds: _kinds,
hideEdgesToNodesOutsideQuery,
}: Omit<ComputeGraphDataMessageType, 'type'>): GraphDataState {
}: Omit<ComputeGraphDataMessageType, 'id' | 'type'>): GraphDataState {
if (repoFilteredNodes === undefined || graphQueryItems === undefined) {
return {
allAssetKeys: [],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import {AssetNodeForGraphQueryFragment} from './types/useAssetGraphData.types';
import {AssetGraphFetchScope, AssetGraphQueryItem} from './useAssetGraphData';

export type ComputeGraphDataMessageType = {
id: number;
type: 'computeGraphData';
repoFilteredNodes?: AssetNodeForGraphQueryFragment[];
graphQueryItems?: AssetGraphQueryItem[];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ self.addEventListener('message', async (event: MessageEvent & {data: WorkerMessa
setFeatureFlags({[FeatureFlag.flagAssetSelectionSyntax]: true});
}
const state = await computeGraphData(data);
self.postMessage(state);
self.postMessage({...state, id: data.id});
}
});

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
import {throttleLatest} from '../throttleLatest';

jest.useFakeTimers();

describe('throttleLatest', () => {
let mockFunction: jest.Mock<Promise<string>, [number]>;
let throttledFunction: (arg: number) => Promise<string>;

beforeEach(() => {
jest.clearAllMocks();
mockFunction = jest.fn((arg: number) => {
return Promise.resolve(`Result: ${arg}`);
});
throttledFunction = throttleLatest(mockFunction, 2000);
});

it('should execute the first call immediately', async () => {
const promise = throttledFunction(1);
expect(mockFunction).toHaveBeenCalledWith(1);

await expect(promise).resolves.toBe('Result: 1');
});

it('should throttle subsequent calls within wait time and reject previous promises', async () => {
const promise1 = throttledFunction(1);
const promise2 = throttledFunction(2);

await expect(promise1).rejects.toThrow('Throttled: A new call has been made.');

expect(mockFunction).toHaveBeenCalledTimes(1);

jest.runAllTimers();

await expect(promise2).resolves.toBe('Result: 2');
});

it('should allow a new call after the wait time', async () => {
const promise1 = throttledFunction(1);

jest.advanceTimersByTime(1000);

const promise2 = throttledFunction(2);

await expect(promise1).rejects.toThrow('Throttled: A new call has been made.');

jest.advanceTimersByTime(1000);

await expect(promise2).resolves.toBe('Result: 2');

const promise3 = throttledFunction(3);

await jest.runAllTimers();

await expect(promise3).resolves.toBe('Result: 3');

expect(mockFunction).toHaveBeenCalledTimes(3);
expect(mockFunction).toHaveBeenNthCalledWith(3, 3);
});

it('should handle multiple rapid calls correctly', async () => {
const promise1 = throttledFunction(1);
await Promise.resolve();

throttledFunction(2);

const promise3 = throttledFunction(3);

await jest.runAllTimers();

expect(mockFunction).toHaveBeenNthCalledWith(1, 1);
expect(mockFunction).toHaveBeenCalledTimes(2);
expect(mockFunction).toHaveBeenNthCalledWith(2, 3);
await expect(promise1).resolves.toBe('Result: 1');
await expect(promise3).resolves.toBe('Result: 3');
});

it('should reject the previous active promise when a new call is made before it resolves', async () => {
// Modify mockFunction to return a promise that doesn't resolve immediately
mockFunction.mockImplementationOnce((arg: number) => {
return new Promise((resolve) => {
setTimeout(() => resolve(`Result: ${arg}`), 5000);
});
});

const promise1 = throttledFunction(1);

// After 100ms, make a new call
jest.advanceTimersByTime(100);
const promise2 = throttledFunction(2);

// The first promise should be rejected
await expect(promise1).rejects.toThrow('Throttled: A new call has been made.');

// The second promise is scheduled to execute after the remaining time (2000 - 100 = 1900ms)
jest.advanceTimersByTime(1900);

// Now, the second call should resolve
await expect(promise2).resolves.toBe('Result: 2');
});

it('should handle function rejection correctly', async () => {
mockFunction.mockImplementationOnce(() => {
return Promise.reject(new Error('Function failed'));
});

const promise1 = throttledFunction(1);
jest.runAllTimers();

await expect(promise1).rejects.toThrow('Function failed');
});

it('should not reject promises if no new call is made within wait time', async () => {
const promise1 = throttledFunction(1);

// No subsequent calls
jest.runAllTimers();

await expect(promise1).resolves.toBe('Result: 1');
});

it('should handle multiple sequential calls with enough time between them', async () => {
const promise1 = throttledFunction(1);
jest.runAllTimers();
await expect(promise1).resolves.toBe('Result: 1');

const promise2 = throttledFunction(2);
jest.runAllTimers();
await expect(promise2).resolves.toBe('Result: 2');

const promise3 = throttledFunction(3);
jest.runAllTimers();
await expect(promise3).resolves.toBe('Result: 3');

expect(mockFunction).toHaveBeenCalledTimes(3);
});
});
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
export type ThrottledFunction<T extends (...args: any[]) => Promise<any>> = (
...args: Parameters<T>
) => ReturnType<T>;

export function throttleLatest<T extends (...args: any[]) => Promise<any>>(
func: T,
wait: number,
): ThrottledFunction<T> {
let timeout: NodeJS.Timeout | null = null;
let lastCallTime: number = 0;
let activeReject: ((reason?: any) => void) | null = null;

return function (...args: Parameters<T>): ReturnType<T> {
const now = Date.now();

return new Promise((resolve, reject) => {
// If a call is already active, reject its promise
if (activeReject) {
activeReject(new Error('Throttled: A new call has been made.'));
activeReject = null;
}

const execute = () => {
lastCallTime = Date.now();
activeReject = reject;

func(...args)
.then((result) => {
resolve(result);
activeReject = null;
})
.catch((error) => {
reject(error);
activeReject = null;
});
};

const remaining = wait - (now - lastCallTime);
if (remaining <= 0) {
if (timeout) {
clearTimeout(timeout);
timeout = null;
}
execute();
} else {
if (timeout) {
clearTimeout(timeout);
}
timeout = setTimeout(() => {
execute();
timeout = null;
}, remaining);
}
}) as ReturnType<T>;
};
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import keyBy from 'lodash/keyBy';
import memoize from 'lodash/memoize';
import reject from 'lodash/reject';
import throttle from 'lodash/throttle';
import {useEffect, useMemo, useRef, useState} from 'react';
import {FeatureFlag} from 'shared/app/FeatureFlags.oss';

Expand All @@ -10,6 +9,7 @@ import {GraphData, buildGraphData, tokenForAssetKey} from './Utils';
import {gql} from '../apollo-client';
import {computeGraphData as computeGraphDataImpl} from './ComputeGraphData';
import {ComputeGraphDataMessageType} from './ComputeGraphData.types';
import {throttleLatest} from './throttleLatest';
import {featureEnabled} from '../app/Flags';
import {
AssetGraphQuery,
Expand Down Expand Up @@ -147,15 +147,22 @@ export function useAssetGraphData(opsQuery: string, options: AssetGraphFetchScop
kinds,
hideEdgesToNodesOutsideQuery,
flagAssetSelectionSyntax: featureEnabled(FeatureFlag.flagAssetSelectionSyntax),
})?.then((data) => {
if (lastProcessedRequestRef.current < requestId) {
lastProcessedRequestRef.current = requestId;
setState(data);
})
?.then((data) => {
if (lastProcessedRequestRef.current < requestId) {
lastProcessedRequestRef.current = requestId;
setState(data);
if (requestId === currentRequestRef.current) {
setGraphDataLoading(false);
}
}
})
.catch((e) => {
console.error(e);
if (requestId === currentRequestRef.current) {
setGraphDataLoading(false);
}
}
});
});
}, [
repoFilteredNodes,
graphQueryItems,
Expand Down Expand Up @@ -299,31 +306,38 @@ export const ASSET_GRAPH_QUERY = gql`
${ASSET_NODE_FRAGMENT}
`;

const computeGraphData = throttle(
const computeGraphData = throttleLatest(
indexedDBAsyncMemoize<
ComputeGraphDataMessageType,
Omit<ComputeGraphDataMessageType, 'id' | 'type'>,
GraphDataState,
typeof computeGraphDataWrapper
>(computeGraphDataWrapper, (props) => {
return JSON.stringify(props);
}),
2000,
{leading: true},
);

const getWorker = memoize(() => new Worker(new URL('./ComputeGraphData.worker', import.meta.url)));

let _id = 0;
async function computeGraphDataWrapper(
props: Omit<ComputeGraphDataMessageType, 'type'>,
props: Omit<ComputeGraphDataMessageType, 'id' | 'type'>,
): Promise<GraphDataState> {
if (featureEnabled(FeatureFlag.flagAssetSelectionWorker)) {
const worker = getWorker();
return new Promise<GraphDataState>((resolve) => {
worker.addEventListener('message', (event) => {
resolve(event.data as GraphDataState);
});
const id = ++_id;
const callback = (event: MessageEvent) => {
const data = event.data as GraphDataState & {id: number};
if (data.id === id) {
resolve(data);
worker.removeEventListener('message', callback);
}
};
worker.addEventListener('message', callback);
const message: ComputeGraphDataMessageType = {
type: 'computeGraphData',
id,
...props,
};
worker.postMessage(message);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,12 @@ import '../../../jest/mocks/ComputeGraphData.worker';

// This file must be mocked because Jest can't handle `import.meta.url`.
jest.mock('../../graph/asyncGraphLayout', () => ({}));
jest.mock(
'lodash/throttle',
() =>
(fn: (...args: any[]) => any) =>
jest.mock('../../asset-graph/throttleLatest', () => ({
throttleLatest:
(fn: any) =>
(...args: any[]) =>
fn(...args),
);
}));

// These files must be mocked because useVirtualizer tries to create a ResizeObserver,
// and the component tree fails to mount.
Expand Down

1 comment on commit 4eba622

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Deploy preview for dagit-core-storybook ready!

✅ Preview
https://dagit-core-storybook-70yt3vzpc-elementl.vercel.app

Built with commit 4eba622.
This pull request is being automatically deployed with vercel-action

Please sign in to comment.