From 6aa510d54ecfca7b1163e5e84bf324982107a051 Mon Sep 17 00:00:00 2001 From: Jamie Rasmussen Date: Wed, 23 Oct 2024 11:17:08 -0500 Subject: [PATCH] in progress --- tests/trace/test_saved_view.py | 0 weave-js/src/common/hooks/useProjectInfo.ts | 84 ++++ weave-js/src/common/hooks/useViewerInfo.ts | 5 +- .../PagePanelComponents/Home/Browse3.tsx | 124 +++++- .../Browse3/pages/CallsPage/CallsPage.tsx | 367 +++++++++++++++--- .../Browse3/pages/CallsPage/CallsTable.tsx | 93 +++-- .../Home/Browse3/pages/ObjectVersionPage.tsx | 1 + .../pages/SavedViews/ConfirmDeleteDialog.tsx | 92 +++++ .../pages/SavedViews/DropdownSwitchView.tsx | 44 +++ .../pages/SavedViews/DropdownViewActions.tsx | 51 +++ .../pages/SavedViews/PanelSwitchView.tsx | 65 ++++ .../Browse3/pages/SavedViews/PanelView.tsx | 68 ++++ .../pages/SavedViews/SavedViewPrefix.tsx | 16 + .../pages/SavedViews/SavedViewSuffix.tsx | 49 +++ .../SavedViews/SavedViewSuffixTimestamp.tsx | 23 ++ .../Browse3/pages/SavedViews/ViewName.tsx | 23 ++ .../pages/SavedViews/ViewNameEditing.tsx | 72 ++++ .../Browse3/pages/SavedViews/savedViewUtil.ts | 112 ++++++ .../Browse3/pages/common/SimplePageLayout.tsx | 2 +- .../pages/common/TypeVersionCategoryChip.tsx | 1 + .../pages/wfReactInterface/constants.ts | 1 + .../generatedBuiltinObjectClasses.zod.ts | 50 +++ .../traceServerClientTypes.ts | 10 + .../traceServerDirectClient.ts | 9 + weave-js/src/components/UserName.tsx | 38 ++ weave/__init__.py | 2 + weave/flow/saved_view.py | 186 +++++++++ weave/trace/weave_init.py | 10 + .../builtin_object_registry.py | 4 + .../generated_base_object_class_schemas.json | 4 +- ...enerated_builtin_object_class_schemas.json | 208 +++++++++- .../builtin_object_classes/saved_view.py | 66 ++++ weave/wandb_interface/wandb_api.py | 9 + 33 files changed, 1799 insertions(+), 90 deletions(-) create mode 100644 tests/trace/test_saved_view.py create mode 100644 weave-js/src/common/hooks/useProjectInfo.ts create mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/pages/SavedViews/ConfirmDeleteDialog.tsx create mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/pages/SavedViews/DropdownSwitchView.tsx create mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/pages/SavedViews/DropdownViewActions.tsx create mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/pages/SavedViews/PanelSwitchView.tsx create mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/pages/SavedViews/PanelView.tsx create mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/pages/SavedViews/SavedViewPrefix.tsx create mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/pages/SavedViews/SavedViewSuffix.tsx create mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/pages/SavedViews/SavedViewSuffixTimestamp.tsx create mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/pages/SavedViews/ViewName.tsx create mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/pages/SavedViews/ViewNameEditing.tsx create mode 100644 weave-js/src/components/PagePanelComponents/Home/Browse3/pages/SavedViews/savedViewUtil.ts create mode 100644 weave-js/src/components/UserName.tsx create mode 100644 weave/flow/saved_view.py create mode 100644 weave/trace_server/interface/builtin_object_classes/saved_view.py diff --git a/tests/trace/test_saved_view.py b/tests/trace/test_saved_view.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/weave-js/src/common/hooks/useProjectInfo.ts b/weave-js/src/common/hooks/useProjectInfo.ts new file mode 100644 index 000000000000..241f596a4d12 --- /dev/null +++ b/weave-js/src/common/hooks/useProjectInfo.ts @@ -0,0 +1,84 @@ +/** + * This is a GraphQL approach to querying project information. + */ + +import {gql, useApolloClient} from '@apollo/client'; +import {useEffect, useState} from 'react'; + +// Note: id is the "external" ID, which changes when a project is renamed. +// internalId does not change. +const PROJECT_QUERY = gql` + query Project($entityName: String!, $projectName: String!) { + project(name: $projectName, entityName: $entityName) { + id + internalId + } + } +`; + +export type ProjectInfo = { + externalIdEncoded: string; + internalIdEncoded: string; +}; +type ProjectInfoResponseLoading = { + loading: true; + projectInfo: {}; +}; +export type MaybeProjectInfo = ProjectInfo | null; +type ProjectInfoResponseSuccess = { + loading: false; + projectInfo: MaybeProjectInfo; +}; +type ProjectInfoResponse = + | ProjectInfoResponseLoading + | ProjectInfoResponseSuccess; + +export const useProjectInfo = ( + entityName: string, + projectName: string +): ProjectInfoResponse => { + const [response, setResponse] = useState({ + loading: true, + projectInfo: {}, + }); + + const apolloClient = useApolloClient(); + + useEffect(() => { + let mounted = true; + apolloClient + .query({ + query: PROJECT_QUERY as any, + variables: { + entityName, + projectName, + }, + }) + .then(result => { + if (!mounted) { + return; + } + const projectInfo = result.data.project; + if (!projectInfo) { + // Invalid project + setResponse({ + loading: false, + projectInfo: null, + }); + return; + } + setResponse({ + loading: false, + projectInfo: { + externalIdEncoded: projectInfo.id, + internalIdEncoded: projectInfo.internalId, + }, + }); + }); + return () => { + mounted = false; + }; + }, [apolloClient, entityName, projectName]); + + return response; +}; diff --git a/weave-js/src/common/hooks/useViewerInfo.ts b/weave-js/src/common/hooks/useViewerInfo.ts index 89edde753a71..6b85ec485e4c 100644 --- a/weave-js/src/common/hooks/useViewerInfo.ts +++ b/weave-js/src/common/hooks/useViewerInfo.ts @@ -25,7 +25,7 @@ const VIEWER_QUERY = gql` `; // TODO: Would be useful to add admin mode flags -type UserInfo = { +export type UserInfo = { id: string; username: string; teams: string[]; @@ -35,9 +35,10 @@ type UserInfoResponseLoading = { loading: true; userInfo: {}; }; +export type MaybeUserInfo = UserInfo | null; type UserInfoResponseSuccess = { loading: false; - userInfo: UserInfo | null; + userInfo: MaybeUserInfo; }; type UserInfoResponse = UserInfoResponseLoading | UserInfoResponseSuccess; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx index c0192ddbe9f6..29abc1f9a972 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx @@ -29,9 +29,20 @@ import { useParams, } from 'react-router-dom'; +import { + ProjectInfo, + useProjectInfo, +} from '../../../common/hooks/useProjectInfo'; +import { + MaybeUserInfo, + useViewerInfo, +} from '../../../common/hooks/useViewerInfo'; import {URL_BROWSE3} from '../../../urls'; +import {useLocalStorage} from '../../../util/useLocalStorage'; +import {Alert} from '../../Alert'; import {Button} from '../../Button'; import {ErrorBoundary} from '../../ErrorBoundary'; +import {Loading} from '../../Loading'; import {Browse2EntityPage} from './Browse2/Browse2EntityPage'; import {Browse2HomePage} from './Browse2/Browse2HomePage'; import {ComparePage} from './Browse3/compare/ComparePage'; @@ -92,6 +103,7 @@ import { WFDataModelAutoProvider, } from './Browse3/pages/wfReactInterface/context'; import {useHasTraceServerClientContext} from './Browse3/pages/wfReactInterface/traceServerClientContext'; +import {sanitizeObjectId} from './Browse3/pages/wfReactInterface/traceServerDirectClient'; import {TableRowSelectionProvider} from './TableRowSelectionContext'; import {useDrawerResize} from './useDrawerResize'; @@ -659,10 +671,116 @@ const CallPageBinding = () => { ); }; -// TODO(tim/weaveflow_improved_nav): Generalize this const CallsPageBinding = () => { + const {entity, project} = useParamsDecoded(); + const {loading: loadingUserInfo, userInfo} = useViewerInfo(); + const {loading: loadingProjectInfo, projectInfo} = useProjectInfo( + entity, + project + ); + if (loadingUserInfo || loadingProjectInfo) { + return ; + } + if (!projectInfo) { + return Invalid project: {project}; + } + return ( + + ); +}; + +type ComparePageBindingWithProjectProps = { + projectInfo: ProjectInfo; + userInfo: MaybeUserInfo; +}; + +const CallsPageBindingWithProject = ({ + projectInfo, + userInfo, +}: ComparePageBindingWithProjectProps) => { + // const query = useURLSearchParamsDict(); + // const {entity, project, tab} = useParamsDecoded(); + // // Using internal ID because it doesn't change across project renames + // const storageKey = `SavedView.lastViewed.${projectInfo.internalIdEncoded}.${tab}`; + // console.log({storageKey, query}); + // const [lastView, setLastView] = useLocalStorage(storageKey, 'default'); + // if (lastView !== 'default' && Object.keys(query).length === 0) { + // console.log('returning CallsPageBindingLoadView'); + // return ( + // + // ); + // } + // console.log('returning CallsPageBindingLoaded'); + return ( + + ); +}; + +// type CallsPageBindingLoadViewProps = { +// entity: string; +// project: string; +// view: string; +// }; + +// Load a saved view +// const CallsPageBindingLoadView = ({ +// entity, +// project, +// view, +// }: CallsPageBindingLoadViewProps) => { +// const history = useHistory(); +// const getTsClient = useGetTraceServerClientContext(); +// const tsClient = getTsClient(); +// tsClient +// .objRead({ +// project_id: projectIdFromParts({ +// entity, +// project, +// }), +// object_id: view, +// digest: 'latest', +// }) +// .then((res: TraceObjReadRes) => { +// const search = savedViewObjectToQuery(res.obj); +// if (search) { +// history.replace({search}); +// } else { +// // TODO: saved view has no description. We don't want to +// // go into an infinite loop of requests. Should have a +// // way to report error. +// } +// }); +// return ; +// }; + +const CallsPageBindingLoaded = ({ + projectInfo, + userInfo, +}: ComparePageBindingWithProjectProps) => { const {entity, project, tab} = useParamsDecoded(); + + const currentViewerId = userInfo ? userInfo.id : null; + const isReadonly = !currentViewerId || !userInfo?.teams.includes(entity); + const query = useURLSearchParamsDict(); + + // Using internal ID because it doesn't change across project renames + const [lastView, setLastView] = useLocalStorage( + `SavedView.lastViewed.${projectInfo.internalIdEncoded}.${tab}`, + 'default' + ); + const onRecordLastView = (loadedView: string) => { + setLastView(loadedView); + }; + const view = query.view ? sanitizeObjectId(query.view) : lastView; + const initialFilter = useMemo(() => { if (tab === 'evaluations') { return { @@ -774,8 +892,12 @@ const CallsPageBinding = () => { return ( { }; export const CallsPage: FC<{ + currentViewerId: string | null; + isReadonly: boolean; + entity: string; project: string; + + view: string; + onRecordLastView: (view: string) => void; + initialFilter?: WFHighLevelCallFilter; // Setting this will make the component a controlled component. The parent // is responsible for updating the filter. @@ -47,6 +73,7 @@ export const CallsPage: FC<{ paginationModel: GridPaginationModel; setPaginationModel: (newModel: GridPaginationModel) => void; }> = props => { + const {entity, project, view} = props; const [filter, setFilter] = useControllableState( props.initialFilter ?? {}, props.onFilterUpdate @@ -54,61 +81,299 @@ export const CallsPage: FC<{ const isEvaluationTable = useCurrentFilterIsEvaluationsFilter( filter, - props.entity, - props.project + entity, + project ); + // table is the internal id stored in the object. + // TODO: Should we just use the capitalized version? + const table = isEvaluationTable ? 'evaluations' : 'traces'; + const defaultLabel = capitalizeFirst(table); + + const [views, setViews] = useState(null); + + const getTsClient = useGetTraceServerClientContext(); + + const tsClient = getTsClient(); + const projectId = projectIdFromParts({entity, project}); + + // const {loading, result: savedViews} = useBaseObjectInstances('SavedView', { + // project_id: projectId, + // filter: { + // // TODO: Could we filter at query time based on the page + // // so we don't have to do it on the result? + // base_object_classes: ['SavedView'], + // latest_only: true, + // }, + // }); + // console.log('after usebase object instances'); + // console.log({loading, savedViews}); + + // TODO: Memo + // const views = savedViews?.filter(v => v.val.table === table) ?? []; + // console.log({loading, savedViews, views}); + + const fetchViews = useCallback(() => { + tsClient + .objsQuery({ + project_id: projectId, + filter: { + // TODO: Could we filter at query time based on the page + // so we don't have to do it on the result? + base_object_classes: ['SavedView'], + latest_only: true, + }, + }) + .then(res => { + const viewsForPage = res.objs.filter(v => v.val.table === table); + // Add a "default" view if we don't have one + if (!viewsForPage.some(v => v.object_id === 'default')) { + viewsForPage.push(getDefaultView(projectId, defaultLabel)); + } + setViews(viewsForPage); + }) + .catch(err => { + console.error(err); + }); + }, [projectId, tsClient, table, defaultLabel]); + + // Load view data on mount + // eslint-disable-next-line react-hooks/exhaustive-deps + useEffect(fetchViews, [table]); + + const baseView = + views?.find(v => v.object_id === view) ?? + getDefaultView(projectId, defaultLabel); + const currentViewDefinition = useCurrentViewDefinition(); + + const history = useHistory(); + const onLoadView = (viewToLoad: TraceObjSchema) => { + // We want to preserve any params that are not part of view definition, + // e.g. peek drawer state. + const newQuery = new URLSearchParams(history.location.search); + + // Clear out any params related to saved views + for (const key of SAVED_PARAM_KEYS) { + newQuery.delete(key); + } - const title = useMemo(() => { - if (isEvaluationTable) { - return 'Evaluations'; + // Update with params from the view definition + for (const [key, value] of Object.entries(viewToLoad.val.definition)) { + newQuery.set(key, JSON.stringify(value)); } - if (filter.opVersionRefs?.length === 1) { - const opName = opVersionRefOpName(filter.opVersionRefs[0]); - if (opName) { - return opNiceName(opName) + ' Traces'; + + newQuery.set('view', viewToLoad.object_id); + history.push({search: newQuery.toString()}); + props.onRecordLastView(viewToLoad.object_id); + }; + const onResetView = () => { + const viewToLoad = views?.find(v => v.object_id === view); + if (viewToLoad) { + onLoadView(viewToLoad); + } + }; + + const viewDef = useCurrentViewDefinition(); + const {currentViewerId, onRecordLastView} = props; + + const onUpsertView = useCallback( + (objectId: string, label: string | null, successMessage: string) => { + if (label === null) { + // If caller doesn't provide a new label, use the existing one. + label = baseView.val.label; } + const className = 'SavedView'; + tsClient + .objCreate({ + obj: { + project_id: projectIdFromParts({entity, project}), + object_id: objectId, + val: { + _type: className, + table, + // name, + _class_name: className, + _bases: ['SavedView', 'Object', 'BaseModel'], + label, + definition: viewDef, + creatorUserId: currentViewerId, + }, + }, + }) + .then(res => { + const newQuery = new URLSearchParams(history.location.search); + newQuery.set('view', objectId); + history.push({search: newQuery.toString()}); + fetchViews(); + toast(successMessage); + onRecordLastView(objectId); + }); + }, + [ + entity, + fetchViews, + history, + project, + table, + currentViewerId, + onRecordLastView, + baseView.val.label, + tsClient, + viewDef, + ] + ); + + const onSaveNewView = () => { + // setIsEditingName(true); + // querySetString(history, 'view', 'placeholder'); + // // TODO: Set focus to name input + const now = new Date(); + const objectId = `SavedView_${table}_${now + .toISOString() + .replace('T', '_') + .replace(/[:.]/g, '-') + .slice(0, -1)}`; + onUpsertView(objectId, 'Untitled view', 'Successfully created new view.'); + }; + + const onSaveView = () => { + onUpsertView(view, null, 'Successfully saved view.'); + }; + + const onRenameView = (newName: string) => { + onUpsertView(view, newName, 'Successfully renamed view.'); + }; + + const onDeleteView = () => { + tsClient + .objDelete({ + project_id: projectIdFromParts({entity, project}), + object_id: view, + }) + .then(res => { + fetchViews(); + // TODO: Use label of view + toast(`Successfully deleted view.`); + // onRecordLastView(objectId); + const newQuery = new URLSearchParams(); + newQuery.set('view', 'default'); + history.push({search: newQuery.toString()}); + }); + }; + + const savedViewsInfo: SavedViewsInfo = { + currentViewerId: props.currentViewerId, + isLoading: views === null, + currentViewId: view, + currentViewDefinition, + isModified: !_.isEqual(currentViewDefinition, baseView?.val.definition), + views: views ?? [], + baseView, + onLoadView, + onSaveView, + onSaveNewView, + onResetView, + onDeleteView, + }; + console.log({ + currentViewDefinition, + bv: baseView?.val.definition, + savedViewsInfo, + }); + + const onNameChanged = (newName: string) => { + if (views === null) { + return; } - return 'Traces'; - }, [filter.opVersionRefs, isEvaluationTable]); + // Update the local state with the new name + const updatedViews = views.map(v => { + if (v.object_id === view) { + return {...v, val: {...v.val, name: newName}}; + } + return v; + }); + setViews(updatedViews); + // Update the server with the new name + onRenameView(newName); + }; + const activeName = + view === 'placeholder' + ? 'Untitled view' + : baseView.val.label ?? 'Untitled view'; + const [isEditingName, setIsEditingName] = useState(false); + const title = ( + + {isEditingName ? ( + setIsEditingName(false)} + /> + ) : ( + setIsEditingName(true)} + tooltip="Click to rename view" + /> + )} + + ); return ( - - ), - }, - ]} - headerExtra={} - /> + + + + ), + headerSuffix: + views !== null ? ( + + + + ) : undefined, + }}> + + ), + }, + ]} + headerExtra={} + /> + ); }; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsTable.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsTable.tsx index 3632664aaf09..cdc866177e13 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsTable.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsTable.tsx @@ -138,6 +138,29 @@ const DEFAULT_PAGINATION_CALLS: GridPaginationModel = { page: 0, }; +const resolveColumnVisibility = ( + columnVisibilityModel: GridColumnVisibilityModel, + columns: GridColDef[] +): GridColumnVisibilityModel => { + const resolvedModel: GridColumnVisibilityModel = { + ...columnVisibilityModel, + }; + + // By default columns are shown. If we have columns that we want to + // hide by default, set that in the visibility model. + for (const col of columns) { + if ( + !(col.field in resolvedModel) && + DEFAULT_HIDDEN_COLUMN_PREFIXES.some(prefix => + col.field.startsWith(prefix) + ) + ) { + resolvedModel[col.field] = false; + } + } + return resolvedModel; +}; + export const CallsTable: FC<{ entity: string; project: string; @@ -530,37 +553,43 @@ export const CallsTable: FC<{ ); // Set default hidden columns to be hidden - useEffect(() => { - if (!setColumnVisibilityModel || !columnVisibilityModel) { - return; - } - const hiddenColumns: string[] = []; - for (const hiddenColPrefix of DEFAULT_HIDDEN_COLUMN_PREFIXES) { - const cols = columns.cols.filter(col => - col.field.startsWith(hiddenColPrefix) - ); - hiddenColumns.push(...cols.map(col => col.field)); - } - // Check if we need to update - only update if any annotation columns are missing from the model - const needsUpdate = hiddenColumns.some( - col => columnVisibilityModel[col] === undefined - ); - if (!needsUpdate) { - return; - } - const hiddenColumnVisiblityFalse = hiddenColumns.reduce((acc, col) => { - // Only add columns=false when not already in the model - if (columnVisibilityModel[col] === undefined) { - acc[col] = false; - } - return acc; - }, {} as Record); + const columnVisibilityModelResolved = resolveColumnVisibility( + columnVisibilityModel ?? {}, + columns.cols + ); - setColumnVisibilityModel({ - ...columnVisibilityModel, - ...hiddenColumnVisiblityFalse, - }); - }, [columns.cols, columnVisibilityModel, setColumnVisibilityModel]); + // // Set default hidden columns to be hidden + // useEffect(() => { + // if (!setColumnVisibilityModel || !columnVisibilityModel) { + // return; + // } + // const hiddenColumns: string[] = []; + // for (const hiddenColPrefix of DEFAULT_HIDDEN_COLUMN_PREFIXES) { + // const cols = columns.cols.filter(col => + // col.field.startsWith(hiddenColPrefix) + // ); + // hiddenColumns.push(...cols.map(col => col.field)); + // } + // // Check if we need to update - only update if any annotation columns are missing from the model + // const needsUpdate = hiddenColumns.some( + // col => columnVisibilityModel[col] === undefined + // ); + // if (!needsUpdate) { + // return; + // } + // const hiddenColumnVisiblityFalse = hiddenColumns.reduce((acc, col) => { + // // Only add columns=false when not already in the model + // if (columnVisibilityModel[col] === undefined) { + // acc[col] = false; + // } + // return acc; + // }, {} as Record); + + // setColumnVisibilityModel({ + // ...columnVisibilityModel, + // ...hiddenColumnVisiblityFalse, + // }); + // }, [columns.cols, columnVisibilityModel, setColumnVisibilityModel]); // Selection Management const [selectedCalls, setSelectedCalls] = useState([]); @@ -886,7 +915,7 @@ export const CallsTable: FC<{
@@ -929,7 +958,7 @@ export const CallsTable: FC<{ rows={tableData} // initialState={initialState} onColumnVisibilityModelChange={onColumnVisibilityModelChange} - columnVisibilityModel={columnVisibilityModel} + columnVisibilityModel={columnVisibilityModelResolved} // SORT SECTION START sortingMode="server" sortModel={sortModel} diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx index 1b0bb491f1c0..df6ab9e8ee6e 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx @@ -55,6 +55,7 @@ const OBJECT_ICONS: Record = { Scorer: 'type-number-alt', ActionSpec: 'rocket-launch', AnnotationSpec: 'forum-chat-bubble', + SavedView: 'view-glasses', }; const ObjectIcon = ({baseObjectClass}: ObjectIconProps) => { if (baseObjectClass in OBJECT_ICONS) { diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/SavedViews/ConfirmDeleteDialog.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/SavedViews/ConfirmDeleteDialog.tsx new file mode 100644 index 000000000000..9d51ef7a63d4 --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/SavedViews/ConfirmDeleteDialog.tsx @@ -0,0 +1,92 @@ +import { + Dialog, + DialogActions as MaterialDialogActions, + DialogContent as MaterialDialogContent, + DialogTitle as MaterialDialogTitle, +} from '@material-ui/core'; +import {Button} from '@wandb/weave/components/Button'; +import React, {useState} from 'react'; +import styled from 'styled-components'; + +// TODO: Need to cleanup duplication with CallPage OverflowMenu +const DialogContent = styled(MaterialDialogContent)` + padding: 0 32px !important; +`; +DialogContent.displayName = 'S.DialogContent'; + +const DialogTitle = styled(MaterialDialogTitle)` + padding: 32px 32px 16px 32px !important; + + h2 { + font-weight: 600; + font-size: 24px; + line-height: 30px; + } +`; +DialogTitle.displayName = 'S.DialogTitle'; + +const DialogActions = styled(MaterialDialogActions)<{$align: string}>` + justify-content: ${({$align}) => + $align === 'left' ? 'flex-start' : 'flex-end'} !important; + padding: 32px 32px 32px 32px !important; +`; +DialogActions.displayName = 'S.DialogActions'; + +type ConfirmDeleteDialogProps = { + setConfirmDelete: (confirmDelete: boolean) => void; + onDeleteCallback: () => void; +}; + +export const ConfirmDeleteDialog = ({ + setConfirmDelete, + onDeleteCallback, +}: ConfirmDeleteDialogProps) => { + const [deleteLoading, setDeleteLoading] = useState(false); + const [error, setError] = useState(null); + + const onDelete = () => { + setDeleteLoading(true); + onDeleteCallback(); + setDeleteLoading(false); + setConfirmDelete(false); + }; + return ( + { + setConfirmDelete(false); + setError(null); + }} + maxWidth="xs" + fullWidth> + Delete this view? + + {error != null ? ( +

{error}

+ ) : ( +

+ You can delete this view if you believe it is no longer useful to + you and your team. This cannot be undone. +

+ )} +
+ + + + +
+ ); +}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/SavedViews/DropdownSwitchView.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/SavedViews/DropdownSwitchView.tsx new file mode 100644 index 000000000000..0ef19100ba47 --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/SavedViews/DropdownSwitchView.tsx @@ -0,0 +1,44 @@ +import {Button} from '@wandb/weave/components/Button'; +import * as DropdownMenu from '@wandb/weave/components/DropdownMenu'; +import React, {useState} from 'react'; + +import {TraceObjSchema} from '../wfReactInterface/traceServerClientTypes'; +import {PanelSwitchView} from './PanelSwitchView'; +import {SavedViewsInfo} from './savedViewUtil'; + +type DropdownSwitchViewProps = { + savedViewsInfo: SavedViewsInfo; +}; + +export const DropdownSwitchView = ({ + savedViewsInfo, +}: DropdownSwitchViewProps) => { + const {isLoading} = savedViewsInfo; + const [isOpen, setIsOpen] = useState(false); + + const onLoadView = (view: TraceObjSchema) => { + setIsOpen(false); + savedViewsInfo.onLoadView(view); + }; + + return ( + + + + + )} + + ); +}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/SavedViews/SavedViewSuffixTimestamp.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/SavedViews/SavedViewSuffixTimestamp.tsx new file mode 100644 index 000000000000..6367738b3dd4 --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/SavedViews/SavedViewSuffixTimestamp.tsx @@ -0,0 +1,23 @@ +import React from 'react'; + +type SavedViewSuffixTimestampProps = { + createdAt: string; +}; + +export const SavedViewSuffixTimestamp = ({ + createdAt, +}: SavedViewSuffixTimestampProps) => { + // Jul 25 at 5:22pm + const saveTime = new Date(createdAt) + .toLocaleString('en-US', { + month: 'short', + day: 'numeric', + hour: 'numeric', + minute: '2-digit', + hour12: true, + }) + .replace(', ', ' at ') + .replace(/AM|PM/, match => match.toLowerCase()); + + return <>{saveTime}; +}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/SavedViews/ViewName.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/SavedViews/ViewName.tsx new file mode 100644 index 000000000000..8f57af49bb66 --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/SavedViews/ViewName.tsx @@ -0,0 +1,23 @@ +import {Tooltip} from '@wandb/weave/components/Tooltip'; +import React from 'react'; + +type ViewNameProps = { + value: string; + onEditNameStart: () => void; + tooltip?: string; +}; + +export const ViewName = ({value, onEditNameStart, tooltip}: ViewNameProps) => { + const onClick = () => { + onEditNameStart(); + }; + const body = ( +
+ {value} +
+ ); + if (tooltip) { + return ; + } + return body; +}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/SavedViews/ViewNameEditing.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/SavedViews/ViewNameEditing.tsx new file mode 100644 index 000000000000..9661a5e7aa62 --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/SavedViews/ViewNameEditing.tsx @@ -0,0 +1,72 @@ +import React, {useEffect, useRef, useState} from 'react'; + +type ViewNameEditingProps = { + value: string; + + onChanged: (value: string) => void; + onExit: () => void; +}; + +export const ViewNameEditing = ({ + value, + onChanged, + onExit, +}: ViewNameEditingProps) => { + const [activeValue, setActiveValue] = useState(value); + const inputRef = useRef(null); + + // Select all of the text + // TODO: Make this behavior optional? + useEffect(() => { + if (inputRef.current) { + inputRef.current.select(); + } + }, []); + + const onChange = (e: React.ChangeEvent) => { + const newValue = e.currentTarget.value; + setActiveValue(newValue); + }; + const onBlur = () => { + // TODO: Trim? Disallow whitespace only? + if (activeValue !== value) { + onChanged(activeValue); + } + onExit(); + }; + const onKeyDown = (e: React.KeyboardEvent) => { + if (e.key === 'Escape') { + onExit(); + } else if (e.key === 'Enter') { + onBlur(); + } + }; + const placeholder = value; + return ( +
+ +

Enter

+
+ ); + + // onChange={e => { + // let newVal = e.currentTarget.value; + // if (this.props.type === 'url') { + // newVal = removeUrlProtocolPrefix(newVal); + // } + // this.updateValue(newVal); + // }} + // placeholder={this.props.placeholder} + // onKeyDown={this.onKeyDown} + // ref={this.inputRef} + // onBlur={this.stopEditing} +}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/SavedViews/savedViewUtil.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/SavedViews/savedViewUtil.ts new file mode 100644 index 000000000000..2367feef68ed --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/SavedViews/savedViewUtil.ts @@ -0,0 +1,112 @@ +import _ from 'lodash'; +import {useMemo} from 'react'; +import {useParams} from 'react-router-dom'; + +import {useURLSearchParamsDict} from '../util'; +import {TraceObjSchema} from '../wfReactInterface/traceServerClientTypes'; + +// Copied from Browse3 +export const useParamsDecoded = () => { + // Handle the case where entity/project (old) have spaces + const params = useParams(); + return useMemo(() => { + return Object.fromEntries( + Object.entries(params).map(([key, value]) => [ + key, + decodeURIComponent(value), + ]) + ); + }, [params]); +}; + +// Not page - always load first page +export const SAVED_PARAM_KEYS = [ + 'cols', + 'filter', + 'filters', + 'sort', + 'pin', + 'pageSize', +]; + +// Value of the object +type ViewDefinition = Record; + +// Get the current view definition from the query params +export const useCurrentViewDefinition = (): ViewDefinition => { + const query = useURLSearchParamsDict(); + const picked = _.pick(query, SAVED_PARAM_KEYS); + const parsed = _.mapValues(picked, v => { + try { + return JSON.parse(v); + } catch (e) { + return null; + } + }); + const filtered = _.pickBy(parsed, v => v !== null); + return filtered; +}; + +export const getDefaultViewDefinition = (label: string): ViewDefinition => { + return { + label, + definition: { + // TODO: Need less fragile way of setting this up + // cols: { + // 'attributes.weave.client_version': false, + // 'attributes.weave.os_name': false, + // 'attributes.weave.os_release': false, + // 'attributes.weave.os_version': false, + // 'attributes.weave.source': false, + // 'attributes.weave.sys_version': false, + // }, + }, + }; +}; + +export const getDefaultView = ( + projectId: string, + label: string +): TraceObjSchema => { + const val = getDefaultViewDefinition(label); + return { + project_id: projectId, + object_id: 'default', + created_at: '', + digest: '', + version_index: 0, + is_latest: 1, + kind: 'object', + val, + }; +}; + +export type SavedViewsInfo = { + currentViewerId: string | null; // user id of viewer, null if not logged in + + isLoading: boolean; + + currentViewId: string; // objectId Can be special value "default" + views: TraceObjSchema[]; // Only the latest version of each view + + baseView: TraceObjSchema; + currentViewDefinition: ViewDefinition; + isModified: boolean; // Whether current view is not same as base view + + onLoadView: (view: TraceObjSchema) => void; + onSaveView: () => void; + onSaveNewView: () => void; + onResetView: () => void; + onDeleteView: () => void; +}; + +export const savedViewObjectToQuery = (view: TraceObjSchema): string => { + const params = new URLSearchParams(); + params.set('view', view.object_id); + const {definition} = view.val; + Object.entries(definition).forEach(([key, value]) => { + const v = JSON.stringify(value); + params.set(key, v); + }); + return params.toString(); +}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/SimplePageLayout.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/SimplePageLayout.tsx index 9d2c6ab9718c..fe245a70eabe 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/SimplePageLayout.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/SimplePageLayout.tsx @@ -28,7 +28,7 @@ export const SimplePageLayoutContext = createContext({}); export const SimplePageLayout: FC<{ - title: string; + title: React.ReactNode; tabs: Array<{ label: string; content: ReactNode; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/TypeVersionCategoryChip.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/TypeVersionCategoryChip.tsx index 235d23e0e15b..9ba8dce1d819 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/TypeVersionCategoryChip.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/TypeVersionCategoryChip.tsx @@ -12,6 +12,7 @@ const colorMap: Record = { Scorer: 'purple', ActionSpec: 'sienna', AnnotationSpec: 'magenta', + SavedView: 'magenta', }; export const TypeVersionCategoryChip: React.FC<{ diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/constants.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/constants.ts index aeb190b02e91..331e507a3d94 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/constants.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/constants.ts @@ -28,4 +28,5 @@ export const KNOWN_BASE_OBJECT_CLASSES = [ 'Scorer', 'ActionSpec', 'AnnotationSpec', + 'SavedView', ] as const; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/generatedBuiltinObjectClasses.zod.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/generatedBuiltinObjectClasses.zod.ts index 5a1c4bbf2327..346f39257e6e 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/generatedBuiltinObjectClasses.zod.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/generatedBuiltinObjectClasses.zod.ts @@ -6,6 +6,12 @@ export type ActionType = z.infer; export const ModelSchema = z.enum(['gpt-4o', 'gpt-4o-mini']); export type Model = z.infer; +export const LogicoperatorSchema = z.enum(['and']); +export type Logicoperator = z.infer; + +export const SortSchema = z.enum(['asc', 'desc']); +export type Sort = z.infer; + export const ConfigSchema = z.object({ action_type: ActionTypeSchema.optional(), model: ModelSchema.optional(), @@ -32,6 +38,26 @@ export const LeaderboardColumnSchema = z.object({ }); export type LeaderboardColumn = z.infer; +export const FilterSchema = z.object({ + field: z.string(), + id: z.number(), + operator: z.string(), + value: z.any(), +}); +export type Filter = z.infer; + +export const PinSchema = z.object({ + left: z.array(z.string()), + right: z.array(z.string()), +}); +export type Pin = z.infer; + +export const SortClauseSchema = z.object({ + field: z.string(), + sort: SortSchema, +}); +export type SortClause = z.infer; + export const TestOnlyNestedBaseModelSchema = z.object({ a: z.number(), }); @@ -62,6 +88,12 @@ export const LeaderboardSchema = z.object({ }); export type Leaderboard = z.infer; +export const FiltersSchema = z.object({ + items: z.array(FilterSchema), + logicOperator: LogicoperatorSchema, +}); +export type Filters = z.infer; + export const TestOnlyExampleSchema = z.object({ description: z.union([z.null(), z.string()]).optional(), name: z.union([z.null(), z.string()]).optional(), @@ -71,10 +103,28 @@ export const TestOnlyExampleSchema = z.object({ }); export type TestOnlyExample = z.infer; +export const SavedViewDefinitionSchema = z.object({ + cols: z.union([z.record(z.string(), z.boolean()), z.null()]).optional(), + filters: z.union([FiltersSchema, z.null()]).optional(), + pin: z.union([PinSchema, z.null()]).optional(), + sort: z.union([z.array(SortClauseSchema), z.null()]).optional(), +}); +export type SavedViewDefinition = z.infer; + +export const SavedViewSchema = z.object({ + creatorUserId: z.string(), + definition: SavedViewDefinitionSchema, + description: z.union([z.null(), z.string()]).optional(), + name: z.union([z.null(), z.string()]).optional(), + table: z.string(), +}); +export type SavedView = z.infer; + export const builtinObjectClassRegistry = { ActionSpec: ActionSpecSchema, AnnotationSpec: AnnotationSpecSchema, Leaderboard: LeaderboardSchema, + SavedView: SavedViewSchema, TestOnlyExample: TestOnlyExampleSchema, TestOnlyNestedBaseObject: TestOnlyNestedBaseObjectSchema, }; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClientTypes.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClientTypes.ts index 7c89efd44196..69208241c00c 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClientTypes.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerClientTypes.ts @@ -239,6 +239,16 @@ export type TraceObjReadRes = { obj: TraceObjSchema; }; +export type TraceObjDeleteReq = { + project_id: string; + object_id: string; + digests?: string[]; +}; + +export type TraceObjDeleteRes = { + num_deleted: number; +}; + export type TraceObjCreateReq = { obj: { project_id: string; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerDirectClient.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerDirectClient.ts index 22d2ec797810..bb7dd0afe956 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerDirectClient.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/traceServerDirectClient.ts @@ -42,6 +42,8 @@ import { TraceFileContentReadRes, TraceObjCreateReq, TraceObjCreateRes, + TraceObjDeleteReq, + TraceObjDeleteRes, TraceObjQueryReq, TraceObjQueryRes, TraceObjReadReq, @@ -232,6 +234,13 @@ export class DirectTraceServerClient { return this.makeRequest('/obj/read', req); } + public objDelete(req: TraceObjDeleteReq): Promise { + return this.makeRequest( + '/obj/delete', + req + ); + } + public readBatch(req: TraceRefsReadBatchReq): Promise { return this.makeRequest( '/refs/read_batch', diff --git a/weave-js/src/components/UserName.tsx b/weave-js/src/components/UserName.tsx new file mode 100644 index 000000000000..a552c0b92405 --- /dev/null +++ b/weave-js/src/components/UserName.tsx @@ -0,0 +1,38 @@ +/** + * Just a username to display + */ + +import React from 'react'; + +import {useUsers} from './UserLink'; + +type UserNameProps = { + userId: string | null; + prefix?: string; + field?: 'name' | 'username'; +}; + +export const UserName = ({userId, prefix, field}: UserNameProps) => { + const users = useUsers(userId ? [userId] : []); + if (userId == null) { + return null; + } + if (users === 'load' || users === 'loading') { + return null; + } + if (users === 'error') { + return null; + } + const user = users[0]; + const value = user[field ?? 'name']; + if (!value) { + return null; + } + + return ( + + {prefix} + {value} + + ); +}; diff --git a/weave/__init__.py b/weave/__init__.py index 87ae9aee0132..3d4420cbdca7 100644 --- a/weave/__init__.py +++ b/weave/__init__.py @@ -13,6 +13,7 @@ from weave.flow.model import Model from weave.flow.obj import Object from weave.flow.prompt.prompt import EasyPrompt, MessagesPrompt, Prompt, StringPrompt +from weave.flow.saved_view import SavedView from weave.trace.util import Thread as Thread from weave.trace.util import ThreadPoolExecutor as ThreadPoolExecutor @@ -40,4 +41,5 @@ MessagesPrompt, Evaluation, Scorer, + SavedView, ] diff --git a/weave/flow/saved_view.py b/weave/flow/saved_view.py new file mode 100644 index 000000000000..8e4baad16e1b --- /dev/null +++ b/weave/flow/saved_view.py @@ -0,0 +1,186 @@ +from datetime import datetime +from typing import Any + +from weave.trace.api import publish as weave_publish +from weave.trace.api import ref as weave_ref +from weave.trace.weave_init import get_viewer_id +from weave.trace_server.interface.builtin_object_classes.saved_view import ( + Filter, + Filters, + Pin, + SortClause, + SortDirection, +) +from weave.trace_server.interface.builtin_object_classes.saved_view import ( + SavedView as SavedViewBase, +) + +DEFAULT_PIN = Pin(left=["CustomCheckbox", "op_name"], right=[]) +DEFAULT_FILTER = Filters(items=[], logicOperator="and") + +OPERATOR_MAP = { + "equals": "(string): equals", +} +OPERATOR_MAP_INV = {v: k for k, v in OPERATOR_MAP.items()} + + +class SavedView: + """A fluent-style class for working with SavedView objects.""" + + base: SavedViewBase + + def __init__(self, table: str, label: str) -> None: + creator_user_id = get_viewer_id() + self.base = SavedViewBase( + table=table, + label=label, + creatorUserId=creator_user_id, + definition={}, + ) + + def rename(self, label: str) -> "SavedView": + self.base.label = label + return self + + def add_filter(self, field: str, operator: str, value: Any) -> "SavedView": + if not self.base.definition.filters: + self.base.definition.filters = DEFAULT_FILTER.copy() + assert self.base.definition.filters is not None + op = OPERATOR_MAP.get(operator) + if not op: + raise ValueError(f"Operator {operator} not supported") + next_id = len(self.base.definition.filters.items) + filter = Filter(id=next_id, field=field, operator=op, value=value) + self.base.definition.filters.items.append(filter) + return self + + def add_sort(self, field: str, sort: SortDirection) -> "SavedView": + if self.base.definition.sort is None: + self.base.definition.sort = [] + clause = SortClause(field=field, sort=sort) + self.base.definition.sort.append(clause) + return self + + def sort_by(self, field: str, sort: SortDirection) -> "SavedView": + self.base.definition.sort = [] + return self.add_sort(field, sort) + + def show_column(self, col_name: str) -> "SavedView": + if not self.base.definition.cols: + self.base.definition.cols = {} + self.base.definition.cols[col_name] = True + return self + + def hide_column(self, col_name: str) -> "SavedView": + if not self.base.definition.cols: + self.base.definition.cols = {} + self.base.definition.cols[col_name] = False + return self + + def pin_column_left(self, col_name: str) -> "SavedView": + if not self.base.definition.pin: + self.base.definition.pin = DEFAULT_PIN.copy() + assert self.base.definition.pin is not None + if col_name in self.base.definition.pin.right: + self.base.definition.pin.right.remove(col_name) + if col_name not in self.base.definition.pin.left: + self.base.definition.pin.left.append(col_name) + return self + + def pin_column_right(self, col_name: str) -> "SavedView": + if not self.base.definition.pin: + self.base.definition.pin = DEFAULT_PIN.copy() + assert self.base.definition.pin is not None + if col_name in self.base.definition.pin.left: + self.base.definition.pin.left.remove(col_name) + if col_name not in self.base.definition.pin.right: + self.base.definition.pin.right.append(col_name) + return self + + def unpin_column(self, col_name: str) -> "SavedView": + if not self.base.definition.pin: + self.base.definition.pin = DEFAULT_PIN.copy() + assert self.base.definition.pin is not None + if col_name in self.base.definition.pin.left: + self.base.definition.pin.left.remove(col_name) + elif col_name in self.base.definition.pin.right: + self.base.definition.pin.right.remove(col_name) + return self + + def page_size(self, page_size: int) -> "SavedView": + self.base.definition.page_size = page_size + return self + + @property + def name(self) -> str: + return self.base.label + + def __str__(self) -> str: + parts = [] + parts.append(f"SavedView '{self.name}'") + + if self.base.definition.filters and self.base.definition.filters.items: + filter_strs = [] + for f in self.base.definition.filters.items: + filter_strs.append(f"{f.field} {f.operator} {f.value}") + parts.append(f"Filters: {', '.join(filter_strs)}") + + if self.base.definition.sort: + sort_strs = [] + for s in self.base.definition.sort: + sort_strs.append(f"{s.field} {s.sort}") + parts.append(f"Sort: {', '.join(sort_strs)}") + + if self.base.definition.cols: + shown = [ + col for col, visible in self.base.definition.cols.items() if visible + ] + hidden = [ + col for col, visible in self.base.definition.cols.items() if not visible + ] + if shown: + parts.append(f"Shown columns: {', '.join(shown)}") + if hidden: + parts.append(f"Hidden columns: {', '.join(hidden)}") + + if self.base.definition.pin: + if self.base.definition.pin.left: + parts.append( + f"Left-pinned columns: {', '.join(self.base.definition.pin.left)}" + ) + if self.base.definition.pin.right: + parts.append( + f"Right-pinned columns: {', '.join(self.base.definition.pin.right)}" + ) + + if self.base.definition.page_size: + parts.append(f"Page size: {self.base.definition.page_size}") + + return "\n".join(parts) + + def save(self) -> None: + # Version creator is current user + creator_user_id = get_viewer_id() + if not creator_user_id: + raise ValueError("No viewer ID found") + self.base.creatorUserId = creator_user_id + name = self.base.name + if name is None: + formatted_now = ( + datetime.now() + .isoformat() + .replace("T", "_") + .replace(":", "-") + .replace(".", "-")[:-1] + ) + name = f"SavedView_{self.base.table}_{formatted_now}" + weave_publish(self.base, name) + + @classmethod + def load(cls, ref: str) -> "SavedView": + base = weave_ref(ref).get() + instance = cls.__new__(cls) + instance.base = base + return instance + + # TODO: Where should we put method to query SavedViews? diff --git a/weave/trace/weave_init.py b/weave/trace/weave_init.py index f51d42d5018d..6595520c0e0e 100644 --- a/weave/trace/weave_init.py +++ b/weave/trace/weave_init.py @@ -18,6 +18,16 @@ def reset(self) -> None: _current_inited_client: InitializedClient | None = None +def get_viewer_id() -> str | None: + from weave.wandb_interface import wandb_api + + api = wandb_api.get_wandb_api_sync() + try: + return api.viewer_id() + except AttributeError: + return None + + def get_username() -> str | None: from weave.wandb_interface import wandb_api diff --git a/weave/trace_server/interface/builtin_object_classes/builtin_object_registry.py b/weave/trace_server/interface/builtin_object_classes/builtin_object_registry.py index 9cc76cefb2e4..2e89fa264dc1 100644 --- a/weave/trace_server/interface/builtin_object_classes/builtin_object_registry.py +++ b/weave/trace_server/interface/builtin_object_classes/builtin_object_registry.py @@ -6,6 +6,9 @@ BaseObject, ) from weave.trace_server.interface.builtin_object_classes.leaderboard import Leaderboard +from weave.trace_server.interface.builtin_object_classes.saved_view import ( + SavedView, +) from weave.trace_server.interface.builtin_object_classes.test_only_example import ( TestOnlyExample, TestOnlyNestedBaseObject, @@ -29,3 +32,4 @@ def register_base_object(cls: type[BaseObject]) -> None: register_base_object(Leaderboard) register_base_object(ActionSpec) register_base_object(AnnotationSpec) +register_base_object(SavedView) diff --git a/weave/trace_server/interface/builtin_object_classes/generated/generated_base_object_class_schemas.json b/weave/trace_server/interface/builtin_object_classes/generated/generated_base_object_class_schemas.json index a50c85688bd1..d835313d6542 100644 --- a/weave/trace_server/interface/builtin_object_classes/generated/generated_base_object_class_schemas.json +++ b/weave/trace_server/interface/builtin_object_classes/generated/generated_base_object_class_schemas.json @@ -79,10 +79,10 @@ }, "field_schema": { "default": {}, - "description": "Expected to be valid JSON Schema. Can be provided as a dict or a Pydantic model class", + "description": "Expected to be valid JSON Schema. Can be provided as a dict, a Pydantic model class, a tuple of a primitive type and a Pydantic Field, or primitive type", "examples": [ { - "max_length": 100, + "maxLength": 100, "type": "string" }, { diff --git a/weave/trace_server/interface/builtin_object_classes/generated/generated_builtin_object_class_schemas.json b/weave/trace_server/interface/builtin_object_classes/generated/generated_builtin_object_class_schemas.json index d835313d6542..70ce0daddda3 100644 --- a/weave/trace_server/interface/builtin_object_classes/generated/generated_builtin_object_class_schemas.json +++ b/weave/trace_server/interface/builtin_object_classes/generated/generated_builtin_object_class_schemas.json @@ -166,6 +166,58 @@ "title": "ContainsWordsActionConfig", "type": "object" }, + "Filter": { + "properties": { + "id": { + "title": "Id", + "type": "integer" + }, + "field": { + "title": "Field", + "type": "string" + }, + "operator": { + "title": "Operator", + "type": "string" + }, + "value": { + "title": "Value" + } + }, + "required": [ + "id", + "field", + "operator", + "value" + ], + "title": "Filter", + "type": "object" + }, + "Filters": { + "properties": { + "items": { + "items": { + "$ref": "#/$defs/Filter" + }, + "title": "Items", + "type": "array" + }, + "logicOperator": { + "const": "and", + "enum": [ + "and" + ], + "title": "Logicoperator", + "type": "string" + } + }, + "required": [ + "items", + "logicOperator" + ], + "title": "Filters", + "type": "object" + }, "Leaderboard": { "properties": { "name": { @@ -277,6 +329,156 @@ "title": "LlmJudgeActionConfig", "type": "object" }, + "Pin": { + "properties": { + "left": { + "items": { + "type": "string" + }, + "title": "Left", + "type": "array" + }, + "right": { + "items": { + "type": "string" + }, + "title": "Right", + "type": "array" + } + }, + "required": [ + "left", + "right" + ], + "title": "Pin", + "type": "object" + }, + "SavedView": { + "properties": { + "name": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Name" + }, + "description": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Description" + }, + "table": { + "title": "Table", + "type": "string" + }, + "creatorUserId": { + "title": "Creatoruserid", + "type": "string" + }, + "definition": { + "$ref": "#/$defs/SavedViewDefinition" + } + }, + "required": [ + "table", + "creatorUserId", + "definition" + ], + "title": "SavedView", + "type": "object" + }, + "SavedViewDefinition": { + "properties": { + "filters": { + "anyOf": [ + { + "$ref": "#/$defs/Filters" + }, + { + "type": "null" + } + ], + "default": null + }, + "cols": { + "anyOf": [ + { + "additionalProperties": { + "type": "boolean" + }, + "type": "object" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Cols" + }, + "pin": { + "anyOf": [ + { + "$ref": "#/$defs/Pin" + }, + { + "type": "null" + } + ], + "default": null + }, + "sort": { + "anyOf": [ + { + "items": { + "$ref": "#/$defs/SortClause" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Sort" + } + }, + "title": "SavedViewDefinition", + "type": "object" + }, + "SortClause": { + "properties": { + "field": { + "title": "Field", + "type": "string" + }, + "sort": { + "enum": [ + "asc", + "desc" + ], + "title": "Sort", + "type": "string" + } + }, + "required": [ + "field", + "sort" + ], + "title": "SortClause", + "type": "object" + }, "TestOnlyExample": { "properties": { "name": { @@ -389,6 +591,9 @@ }, "AnnotationSpec": { "$ref": "#/$defs/AnnotationSpec" + }, + "SavedView": { + "$ref": "#/$defs/SavedView" } }, "required": [ @@ -396,7 +601,8 @@ "TestOnlyNestedBaseObject", "Leaderboard", "ActionSpec", - "AnnotationSpec" + "AnnotationSpec", + "SavedView" ], "title": "CompositeBaseObject", "type": "object" diff --git a/weave/trace_server/interface/builtin_object_classes/saved_view.py b/weave/trace_server/interface/builtin_object_classes/saved_view.py new file mode 100644 index 000000000000..147adf162924 --- /dev/null +++ b/weave/trace_server/interface/builtin_object_classes/saved_view.py @@ -0,0 +1,66 @@ +from typing import Any, Literal + +from pydantic import BaseModel, Field + +from weave.trace_server.interface.builtin_object_classes import base_object_def + + +class LegacyFilter(BaseModel): + op_version_refs: list[str] | None = Field(default=None) + input_object_version_refs: list[str] | None = Field(default=None) + output_object_version_refs: list[str] | None = Field(default=None) + + +class Filter(BaseModel): + id: int + field: str + # Type of operator could be locked down more, but this is better for extensibility + operator: str + value: Any + + +class Filters(BaseModel): + items: list[Filter] + logicOperator: Literal["and"] + + +class Pin(BaseModel): + # TODO: Make them optional? But one is required? + left: list[str] + right: list[str] + + +SortDirection = Literal["asc", "desc"] + + +class SortClause(BaseModel): + field: str + sort: SortDirection + + +class SavedViewDefinition(BaseModel): + filter: LegacyFilter | None = Field(default=None) + filters: Filters | None = Field(default=None) + cols: dict[str, bool] | None = Field(default=None) + pin: Pin | None = Field(default=None) + sort: list[SortClause] | None = Field(default=None) + page_size: int | None = Field(default=None) + + +class SavedView(base_object_def.BaseObject): + # "traces" or "evaluations" + table: str + + # Avoiding confusion around object_id + name + label: str + + # TODO: We should update our general object so + # that this is a field the backend sets based on + # the authenticated user. + creatorUserId: str + + definition: SavedViewDefinition + # reference: base_object_def.RefStr + + +__all__ = ["SavedView"] diff --git a/weave/wandb_interface/wandb_api.py b/weave/wandb_interface/wandb_api.py index a5cc668cc125..36b6cb7ea29b 100644 --- a/weave/wandb_interface/wandb_api.py +++ b/weave/wandb_interface/wandb_api.py @@ -428,6 +428,7 @@ def artifact_manifest_url_from_id(self, art_id: str) -> Optional[str]: """ query DefaultEntity { viewer { + id username defaultEntity { name @@ -447,6 +448,14 @@ def default_entity_name(self) -> Optional[str]: except AttributeError: return None + def viewer_id(self) -> Optional[str]: + try: + result = self.query(self.VIEWER_DEFAULT_ENTITY_QUERY) + except gql.transport.exceptions.TransportQueryError as e: + return None + + return result.get("viewer", {}).get("id", None) + def username(self) -> Optional[str]: try: result = self.query(self.VIEWER_DEFAULT_ENTITY_QUERY)