diff --git a/docs/source/tree_view.rst b/docs/source/tree_view.rst index c0c24710..4d9f93db 100644 --- a/docs/source/tree_view.rst +++ b/docs/source/tree_view.rst @@ -20,11 +20,20 @@ Please visit :doc:`key bindings ` page for a complete list of avai Viewing Externally Generated Tracks *********************************** It is also possible to view tracks that were not created from the motile widget using -the synchronized Tree View and napari layers. This is not accessible from the UI, so -you will need to make a python script to create a Tracks object and load it into the -viewer. +the synchronized Tree View and napari layers. To do so, navigate to the ``Results List`` tab and select ``External tracks from CSV`` in the dropdown menu at the bottom of the widgets, and click ``Load``. +A pop up menu will allow you to select a CSV file and map its columns to the required default attributes and optional additional attributes. You may also provide the accompanying segmentation and specify scaling information. -A `SolutionTracks object`_ contains a networkx graph representing the tracking result, and optionally +The following columns have to be selected: + +- time: representing the position of the object in the time dimension. +- x: x centroid coordinate of the object. +- y: y centroid coordinate of the object. +- z (optional): z centroid coordinate of the object, if it is a 3D object. +- id: unique id of the object. +- parent_id: id of the directly connected predecessor (parent) of the object. Should be empty if the object is at the start of a lineage. +- seg_id: label value in the segmentation image data (if provided) that corresponds to the object id. + +From this, a `SolutionTracks object`_ is generated, containing a networkx graph representing the tracking result, and optionally a segmentation. The networkx graph is directed, with nodes representing detections and edges going from a detection in time t to the same object in t+n (edges go forward in time). 
Nodes must have an attribute representing time, by default named "time" but a different name @@ -32,9 +41,7 @@ can be stored in the ``Tracks.time_attr`` attribute. Nodes must also have one or representing position. The default way of storing positions on nodes is an attribute called "pos" containing a list of position values, but dimensions can also be stored in separate attributes (e.g. "x" and "y", each with one value). The name or list of names of the position attributes -should be specified in ``Tracks.pos_attr``. If you want to view tracks by area of the nodes, -you will also need to store the area of the corresponding segmentation on the nodes of the graph -in an ``area`` attribute. +should be specified in ``Tracks.pos_attr``. If a segmentation is provided but no ``area`` attribute, it will be computed automatically. The segmentation is expected to be a numpy array with time as the first dimension, followed by the position dimensions in the same order as the ``Tracks.pos_attr``. The segmentation @@ -43,7 +50,7 @@ motile_toolbox called ensure_unique_labels that relabels a segmentation to be un across time if needed. If a segmentation is provided, the node ids in the graph should match label id of the corresponding segmentation. -An example script that loads a tracks object from a CSV and segmentation array is provided in `scripts/view_external_tracks.csv`. Once you have a Tracks object in the format described above, +An example script that loads a tracks object from a CSV and segmentation array is provided in `scripts/view_external_tracks.py`. 
Once you have a Tracks object in the format described above, the following lines will view it in the Tree View and create synchronized napari layers (Points, Labels, and Tracks) to visualize the provided tracks.:: diff --git a/scripts/view_external_tracks.py b/scripts/view_external_tracks.py index c77ff893..15cb83cb 100644 --- a/scripts/view_external_tracks.py +++ b/scripts/view_external_tracks.py @@ -1,8 +1,10 @@ import napari +import pandas as pd + from motile_tracker.application_menus import MainApp from motile_tracker.data_views.views_coordinator.tracks_viewer import TracksViewer from motile_tracker.example_data import Fluo_N2DL_HeLa -from motile_tracker.utils.load_tracks import tracks_from_csv +from motile_tracker.import_export.load_tracks import tracks_from_df if __name__ == "__main__": # load the example data @@ -10,8 +12,27 @@ segmentation_arr = labels_layer_info[0] # the segmentation ids in this file correspond to the segmentation ids in the # example segmentation data, loaded above - csvfile = "hela_example_tracks.csv" - tracks = tracks_from_csv(csvfile, segmentation_arr) + csvfile = "scripts/hela_example_tracks.csv" + selected_columns = { + "time": "t", + "y": "y", + "x": "x", + "id": "id", + "parent_id": "parent_id", + "seg_id": "id", + } + + df = pd.read_csv(csvfile) + + # Create new columns for each feature based on the original column values + for feature, column in selected_columns.items(): + df[feature] = df[column] + + tracks = tracks_from_df( + df=df, + segmentation=segmentation_arr, + scale=[1, 1, 1], + ) viewer = napari.Viewer() raw_data, raw_kwargs, _ = raw_layer_info diff --git a/src/motile_tracker/data_model/solution_tracks.py b/src/motile_tracker/data_model/solution_tracks.py index a5249502..b52fd756 100644 --- a/src/motile_tracker/data_model/solution_tracks.py +++ b/src/motile_tracker/data_model/solution_tracks.py @@ -64,7 +64,9 @@ def get_next_track_id(self) -> int: return self.max_track_id def get_track_id(self, node) -> int: - track_id 
= self._get_node_attr(node, NodeAttr.TRACK_ID.value, required=True) + track_id = int( + self._get_node_attr(node, NodeAttr.TRACK_ID.value, required=True) + ) return track_id def set_track_id(self, node: Node, value: int): @@ -156,6 +158,25 @@ def export_tracks( ] if self.ndim == 3: header = [header[0]] + header[2:] # remove z + + # Add the extra attributes that are not part of the default ones + additional_attrs = { + k + for n in self.graph.nodes + for k in self.graph.nodes[n] + if k + not in ( + NodeAttr.TIME.value, + NodeAttr.SEG_ID.value, + NodeAttr.TRACK_ID.value, + NodeAttr.POS.value, + "track_id", + "lineage_id", + "color", + ) + } + header = header + list(additional_attrs) + with open(outfile, "w") as f: f.write(",".join(header)) for node_id in self.graph.nodes(): @@ -166,6 +187,9 @@ def export_tracks( position = self.get_position(node_id) lineage_id = self.get_lineage_id(node_id) color = colormap.map(track_id)[:3] * 255 + attrs = [ + self._get_node_attr(node_id, attr) for attr in additional_attrs + ] row = [ time, *position, @@ -174,6 +198,7 @@ def export_tracks( track_id, lineage_id, color, + *attrs, ] f.write("\n") f.write(",".join(map(str, row))) diff --git a/src/motile_tracker/data_model/tracks.py b/src/motile_tracker/data_model/tracks.py index e745675b..28ad6e8c 100644 --- a/src/motile_tracker/data_model/tracks.py +++ b/src/motile_tracker/data_model/tracks.py @@ -589,7 +589,7 @@ def _compute_ndim( ndim = ndims[0] if not all(d == ndim for d in ndims): raise ValueError( - f"Dimensions from segmentation {seg_ndim}, scale {scale_ndim}, and ndim {provided_ndim} must match" + f"Dimensions from segmentation: {seg_ndim}, scale: {scale_ndim}, and ndim: {provided_ndim} must match" ) return ndim diff --git a/src/motile_tracker/data_views/views/layers/track_points.py b/src/motile_tracker/data_views/views/layers/track_points.py index 2c1275e0..b333ffa0 100644 --- a/src/motile_tracker/data_views/views/layers/track_points.py +++ 
b/src/motile_tracker/data_views/views/layers/track_points.py @@ -107,6 +107,11 @@ def click(layer, event): # listen to updates of the data self.events.data.connect(self._update_data) + # connect to changing the point size in the UI + self.events.current_size.connect( + lambda: self.set_point_size(size=self.current_size) + ) + # listen to updates in the selected data (from the point selection tool) # to update the nodes in self.tracks_viewer.selected_nodes self.selected_data.events.items_changed.connect(self._update_selection) diff --git a/src/motile_tracker/data_views/views/tree_view/tree_widget.py b/src/motile_tracker/data_views/views/tree_view/tree_widget.py index 7f3dd836..175ac599 100644 --- a/src/motile_tracker/data_views/views/tree_view/tree_widget.py +++ b/src/motile_tracker/data_views/views/tree_view/tree_widget.py @@ -53,24 +53,36 @@ def mouseDragEvent(self, ev, axis=None): """Modified mouseDragEvent function to check which mouse mode to use and to submit rectangle coordinates for selecting multiple nodes if necessary""" - super().mouseDragEvent(ev, axis) - - # use RectMode when pressing shift - if ev.modifiers() == QtCore.Qt.ShiftModifier: - self.setMouseMode(self.RectMode) + # check if SHIFT is pressed + shift_down = ev.modifiers() == QtCore.Qt.ShiftModifier + if shift_down: + # if starting a shift-drag, record the scene position if ev.isStart(): self.mouse_start_pos = self.mapSceneToView(ev.scenePos()) - elif ev.isFinish(): + + # Put the ViewBox in RectMode so it draws its usual yellow rectangle + self.setMouseMode(self.RectMode) + super().mouseDragEvent(ev, axis) + + # Once the drag finishes, emit the rectangle + if ev.isFinish(): rect_end_pos = self.mapSceneToView(ev.scenePos()) rect = QtCore.QRectF(self.mouse_start_pos, rect_end_pos).normalized() self.selected_rect.emit(rect) # emit the rectangle ev.accept() - else: - ev.ignore() + + if hasattr(self, "rbScaleBox") and self.rbScaleBox: + self.rbScaleBox.hide() + else: - # Otherwise, set pan mode + # 
SHIFT not pressed - use PanMode normally self.setMouseMode(self.PanMode) + super().mouseDragEvent(ev, axis) + + # hide the leftover box if any + if hasattr(self, "rbScaleBox") and self.rbScaleBox: + self.rbScaleBox.hide() class TreePlot(pg.PlotWidget): diff --git a/src/motile_tracker/data_views/views_coordinator/tracks_list.py b/src/motile_tracker/data_views/views_coordinator/tracks_list.py index b985fb6b..55fa29de 100644 --- a/src/motile_tracker/data_views/views_coordinator/tracks_list.py +++ b/src/motile_tracker/data_views/views_coordinator/tracks_list.py @@ -7,6 +7,8 @@ from napari._qt.qt_resources import QColoredSVGIcon from qtpy.QtCore import Signal from qtpy.QtWidgets import ( + QComboBox, + QDialog, QFileDialog, QGroupBox, QHBoxLayout, @@ -19,7 +21,10 @@ ) from superqt.fonticon import icon as qticon -from motile_tracker.data_model import Tracks +from motile_tracker.data_model import SolutionTracks, Tracks +from motile_tracker.import_export.menus.import_external_tracks_dialog import ( + ImportTracksDialog, +) from motile_tracker.motile.backend.motile_run import MotileRun @@ -91,14 +96,29 @@ def __init__(self): self.tracks_list.setSelectionMode(1) # single selection self.tracks_list.itemSelectionChanged.connect(self._selection_changed) - load_button = QPushButton("Load tracks") + load_menu = QHBoxLayout() + self.dropdown_menu = QComboBox() + self.dropdown_menu.addItems(["Motile Run", "External tracks from CSV"]) + + load_button = QPushButton("Load") load_button.clicked.connect(self.load_tracks) + load_menu.addWidget(self.dropdown_menu) + load_menu.addWidget(load_button) + layout = QVBoxLayout() layout.addWidget(self.tracks_list) - layout.addWidget(load_button) + layout.addLayout(load_menu) self.setLayout(layout) + def _load_external_tracks(self): + dialog = ImportTracksDialog() + if dialog.exec_() == QDialog.Accepted: + tracks = dialog.tracks + name = dialog.name + if tracks is not None: + self.add_tracks(tracks, name, select=True) + def 
_selection_changed(self): selected = self.tracks_list.selectedItems() if selected: @@ -151,9 +171,19 @@ def remove_tracks(self, item: QListWidgetItem): self.tracks_list.takeItem(row) def load_tracks(self): + """Call the function to load tracks from disk for a Motile Run or for externally generated tracks (CSV file), + depending on the choice in the dropdown menu.""" + + if self.dropdown_menu.currentText() == "Motile Run": + self.load_motile_run() + elif self.dropdown_menu.currentText() == "External tracks from CSV": + self._load_external_tracks() + + def load_motile_run(self): """Load a set of tracks from disk. The user selects the directory created by calling save_tracks. """ + if self.file_dialog.exec_(): directory = Path(self.file_dialog.selectedFiles()[0]) name = directory.stem @@ -162,7 +192,7 @@ def load_tracks(self): self.add_tracks(tracks, name, select=True) except (ValueError, FileNotFoundError): try: - tracks = Tracks.load(directory) + tracks = SolutionTracks.load(directory) self.add_tracks(tracks, name, select=True) except (ValueError, FileNotFoundError) as e: warn(f"Could not load tracks from {directory}: {e}", stacklevel=2) diff --git a/src/motile_tracker/import_export/__init__.py b/src/motile_tracker/import_export/__init__.py new file mode 100644 index 00000000..c7971370 --- /dev/null +++ b/src/motile_tracker/import_export/__init__.py @@ -0,0 +1,2 @@ +from .menus.import_external_tracks_dialog import ImportTracksDialog # noqa +from .load_tracks import tracks_from_df # noqa diff --git a/src/motile_tracker/import_export/load_tracks.py b/src/motile_tracker/import_export/load_tracks.py new file mode 100644 index 00000000..26ea5f2c --- /dev/null +++ b/src/motile_tracker/import_export/load_tracks.py @@ -0,0 +1,271 @@ +import ast +from warnings import warn + +import networkx as nx +import numpy as np +import pandas as pd +from motile_toolbox.candidate_graph import NodeAttr + +from motile_tracker.data_model import SolutionTracks + + +def ensure_integer_ids(df: 
pd.DataFrame) -> pd.DataFrame: + """Ensure that the 'id' column in the dataframe contains integer values + + Args: + df (pd.DataFrame): A pandas dataframe with a columns named "id" and "parent_id" + + Returns: + pd.DataFrame: The same dataframe with the ids remapped to be unique integers. + Parent id column is also remapped. + """ + if not pd.api.types.is_integer_dtype(df["id"]): + unique_ids = df["id"].unique() + id_mapping = { + original_id: new_id + for new_id, original_id in enumerate(unique_ids, start=1) + } + df["id"] = df["id"].map(id_mapping) + df["parent_id"] = df["parent_id"].map(id_mapping).astype(pd.Int64Dtype()) + + return df + + +def ensure_correct_labels(df: pd.DataFrame, segmentation: np.ndarray) -> np.ndarray: + """Create a new segmentation where the values from the column df['seg_id'] are + replaced by those in df['id'] + + Args: + df (pd.DataFrame): A pandas dataframe with columns "seg_id" and "id" where + the "id" column contains unique integers + segmentation (np.ndarray): A numpy array where segmentation label values + are recorded in the "seg_id" column of the dataframe + + Returns: + np.ndarray: A numpy array similar to the input segmentation of dtype uint64 + where each segmentation now has a unique label across time that corresponds + to the ID of each node + """ + + # Create a new segmentation image + new_segmentation = np.zeros_like(segmentation).astype(np.uint64) + + # Loop through each time point + for t in df[NodeAttr.TIME.value].unique(): + # Filter the dataframe for the current time point + df_t = df[df[NodeAttr.TIME.value] == t] + + # Create a mapping from seg_id to id for the current time point + seg_id_to_id = dict(zip(df_t["seg_id"], df_t["id"], strict=True)) + + # Apply the mapping to the segmentation image for the current time point + for seg_id, new_id in seg_id_to_id.items(): + new_segmentation[t][segmentation[t] == seg_id] = new_id + + return new_segmentation + + +def _test_valid( + df: pd.DataFrame, segmentation: 
np.ndarray, scale: list[float] | None +) -> bool: + """Test if the provided segmentation, dataframe, and scale values are valid together. + Tests the following requirements: + - The scale, if provided, has same dimensions as the segmentation + - The location coordinates have the same dimensions as the segmentation + - The segmentation pixel value for the coordinates of first node corresponds + with the provided seg_id as a basic sanity check that the csv file matches with the + segmentation file + + Args: + df (pd.DataFrame): the pandas dataframe to turn into tracks, with standardized + column names + segmentation (np.ndarray): The segmentation, a 3D or 4D array of integer labels + scale (list[float] | None): A list of floats representing the relationship between + the point coordinates and the pixels in the segmentation + + Returns: + bool: True if the combination of segmentation, dataframe, and scale + pass all validity tests and can likely be loaded, and False otherwise + """ + if scale is not None: + if segmentation.ndim != len(scale): + warn( + f"Dimensions of the segmentation image ({segmentation.ndim}) " + f"do not match the number of scale values given ({len(scale)})", + stacklevel=2, + ) + return False + else: + scale = [ + 1, + ] * segmentation.ndim + + row = df.iloc[0] + pos = ( + [row[NodeAttr.TIME.value], row["z"], row["y"], row["x"]] + if "z" in df.columns + else [row[NodeAttr.TIME.value], row["y"], row["x"]] + ) + + if segmentation.ndim != len(pos): + warn( + f"Dimensions of the segmentation ({segmentation.ndim}) do not match the " + f"number of positional dimensions ({len(pos)})", + stacklevel=2, + ) + return False + + seg_id = row[NodeAttr.SEG_ID.value] + coordinates = [ + int(coord / scale_value) for coord, scale_value in zip(pos, scale, strict=True) + ] + + try: + value = segmentation[tuple(coordinates)] + except IndexError: + warn( + f"Could not get the segmentation value at index {coordinates}", stacklevel=2 + ) + return False + + return value 
== seg_id + + +def tracks_from_df( + df: pd.DataFrame, + segmentation: np.ndarray | None = None, + scale: list[float] | None = None, + features: dict[str, str] | None = None, +) -> SolutionTracks: + """Turns a pandas data frame with columns: + t,[z],y,x,id,parent_id,[seg_id], [optional custom attr 1], ... + into a SolutionTracks object. + + Cells without a parent_id will have an empty string or a -1 for the parent_id. + + Args: + df (pd.DataFrame): + a pandas DataFrame containing columns + t,[z],y,x,id,parent_id,[seg_id], [optional custom attr 1], ... + segmentation (np.ndarray | None, optional): + An optional accompanying segmentation. + If provided, assumes that the seg_id column in the dataframe exists and + corresponds to the label ids in the segmentation array. Defaults to None. + scale (list[float] | None, optional): + The scale of the segmentation (including the time dimension). Defaults to + None. + features (dict[str: str] | None, optional) + Dict mapping measurement attributes (area, volume) to value that specifies a + column from which to import. If value equals to "Recompute", recompute these + values instead of importing them from a column. Defaults to None. + + Returns: + SolutionTracks: a solution tracks object + Raises: + ValueError: if the segmentation IDs in the dataframe do not match the provided + segmentation + """ + if features is None: + features = {} + # check that the required columns are present + required_columns = ["id", NodeAttr.TIME.value, "y", "x", "parent_id"] + ndim = None + if segmentation is not None: + required_columns.append("seg_id") + ndim = segmentation.ndim + if ndim == 4: + required_columns.append("z") + for column in required_columns: + assert ( + column in df.columns + ), f"Required column {column} not found in dataframe columns {df.columns}" + + if segmentation is not None and not _test_valid(df, segmentation, scale): + raise ValueError( + "Segmentation ids in dataframe do not match values in segmentation." 
+ "Is it possible that you loaded the wrong combination of csv file and " + "segmentation, or that the scaling information you provided is incorrect?" + ) + if not df["id"].is_unique: + raise ValueError("The 'id' column must contain unique values") + + df = df.map(lambda x: None if pd.isna(x) else x) # Convert NaN values to None + + # Convert custom attributes stored as strings back to lists + for col in df.columns: + if col not in required_columns: + df[col] = df[col].apply( + lambda x: ast.literal_eval(x) + if isinstance(x, str) and x.startswith("[") and x.endswith("]") + else x + ) + + df = df.sort_values( + NodeAttr.TIME.value + ) # sort the dataframe to ensure that parents get added to the graph before children + df = ensure_integer_ids(df) # Ensure that the 'id' column contains integer values + + graph = nx.DiGraph() + for _, row in df.iterrows(): + row_dict = row.to_dict() + _id = int(row["id"]) + parent_id = row["parent_id"] + if "z" in df.columns: + pos = [row["z"], row["y"], row["x"]] + ndims = 4 + else: + pos = [row["y"], row["x"]] + ndims = 3 + + attrs = { + NodeAttr.TIME.value: int(row["time"]), + NodeAttr.POS.value: pos, + } + + # add all other columns into the attributes + for attr in required_columns: + del row_dict[attr] + attrs.update(row_dict) + + if "track_id" in df.columns: + attrs[NodeAttr.TRACK_ID.value] = row["track_id"] + + # add the node to the graph + graph.add_node(_id, **attrs) + + # add the edge to the graph, if the node has a parent + # note: this loading format does not support edge attributes + if not pd.isna(parent_id) and parent_id != -1: + assert ( + parent_id in graph.nodes + ), f"Parent id {parent_id} of node {_id} not in graph yet" + graph.add_edge(parent_id, _id) + + # in the case a different column than the id column was used for the seg_id, we need + # to update the segmentation to make sure it matches the values in the id column (it + # should be checked by now that these are unique and integers) + if segmentation is not 
None and row["seg_id"] != row["id"]: + segmentation = ensure_correct_labels(df, segmentation) + + tracks = SolutionTracks( + graph=graph, + segmentation=segmentation, + pos_attr=NodeAttr.POS.value, + time_attr=NodeAttr.TIME.value, + ndim=ndims, + scale=scale, + ) + + # compute the 'area' attribute if needed + if ( + tracks.segmentation is not None + and NodeAttr.AREA.value not in df.columns + and len(features) > 0 + ): + nodes = tracks.graph.nodes + times = tracks.get_times(nodes) + computed_attrs = tracks._compute_node_attrs(nodes, times) + areas = computed_attrs[NodeAttr.AREA.value] + tracks._set_nodes_attr(nodes, NodeAttr.AREA.value, areas) + + return tracks diff --git a/src/motile_tracker/import_export/menus/__init__.py b/src/motile_tracker/import_export/menus/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/motile_tracker/import_export/menus/csv_widget.py b/src/motile_tracker/import_export/menus/csv_widget.py new file mode 100644 index 00000000..854e5e76 --- /dev/null +++ b/src/motile_tracker/import_export/menus/csv_widget.py @@ -0,0 +1,213 @@ +import os + +import pandas as pd +from motile_toolbox.candidate_graph import NodeAttr +from psygnal import Signal +from qtpy.QtCore import Qt +from qtpy.QtWidgets import ( + QComboBox, + QFileDialog, + QFormLayout, + QHBoxLayout, + QLabel, + QLineEdit, + QMessageBox, + QPushButton, + QVBoxLayout, + QWidget, +) + + +class CSVFieldMapWidget(QWidget): + """QWidget accepting a CSV file and displaying the different column names in QComboBoxes""" + + columns_updated = Signal() + + def __init__(self, csv_columns: list[str], seg: bool = False, incl_z: bool = False): + super().__init__() + + self.standard_fields = [ + NodeAttr.TIME.value, + "y", + "x", + "id", + "parent_id", + ] + + self.csv_columns = csv_columns + self.columns_left = [] + + if incl_z: + self.standard_fields.insert(1, "z") + if seg: + self.standard_fields.insert(-2, "seg_id") + + csv_column_layout = QVBoxLayout() + 
csv_column_layout.addWidget(QLabel("Choose columns from CSV file")) + + # Field Mapping Layout + self.mapping_layout = QFormLayout() + self.mapping_widgets = {} + layout = QVBoxLayout() + + initial_mapping = self._get_initial_mapping(csv_columns) + for attribute, csv_column in initial_mapping.items(): + if attribute in self.standard_fields: + combo = QComboBox() + combo.addItems(csv_columns) + combo.setCurrentText(csv_column) + combo.currentIndexChanged.connect(self._update_columns_left) + label = QLabel(attribute) + label.setToolTip(self._get_tooltip(attribute)) + self.mapping_widgets[attribute] = combo + self.mapping_layout.addRow(label, combo) + + # Assemble layouts + csv_column_layout.addLayout(self.mapping_layout) + layout.addLayout(csv_column_layout) + self.setLayout(layout) + + def _update_columns_left(self) -> None: + """Update the list of columns that have not been mapped yet""" + + self.columns_left = [ + column + for column in self.csv_columns + if column not in self.get_name_map().values() + ] + self.columns_updated.emit() + + def _get_tooltip(self, attribute: str) -> str: + """Return the tooltip for the given attribute""" + + tooltips = { + NodeAttr.TIME.value: "The time point of the track. 
Must be an integer", + "y": "The world y-coordinate of the track.", + "x": "The world x-coordinate of the track.", + "id": "The unique identifier of the node (string or integer).", + "parent_id": "The unique identifier of the parent node (string or integer).", + "z": "The world z-coordinate of the track.", + "seg_id": "The integer label value in the segmentation file.", + } + + return tooltips.get(attribute, "") + + def _get_initial_mapping(self, csv_columns: list[str]) -> dict[str, str]: + """Make an initial guess for mapping of csv columns to fields""" + + mapping = {} + self.columns_left: list = csv_columns.copy() + + # find exact matches for standard fields + for attribute in self.standard_fields: + if attribute in self.columns_left: + mapping[attribute] = attribute + self.columns_left.remove(attribute) + + # assign first remaining column as best guess for remaining standard fields + for attribute in self.standard_fields: + if attribute in mapping: + continue + if len(self.columns_left) > 0: + mapping[attribute] = self.columns_left.pop(0) + else: + # no good guesses left - just put something + mapping[attribute] = csv_columns[-1] + + return mapping + + def get_name_map(self) -> dict[str, str]: + """Return a mapping from feature name to csv field name""" + + return { + attribute: combo.currentText() + for attribute, combo in self.mapping_widgets.items() + } + + +class CSVWidget(QWidget): + """QWidget for selecting CSV file and optional segmentation image""" + + update_buttons = Signal() + + def __init__(self, add_segmentation: bool = False, incl_z: bool = False): + super().__init__() + + self.add_segmentation = add_segmentation + self.incl_z = incl_z + self.df = None + + self.layout = QVBoxLayout(self) + + # QlineEdit for CSV file path and browse button + self.csv_path_line = QLineEdit(self) + self.csv_path_line.setFocusPolicy(Qt.StrongFocus) + self.csv_path_line.returnPressed.connect(self._on_csv_editing_finished) + self.csv_browse_button = QPushButton("Browse 
Tracks CSV file", self) + self.csv_browse_button.setAutoDefault(0) + self.csv_browse_button.clicked.connect(self._browse_csv) + + csv_layout = QHBoxLayout() + csv_layout.addWidget(QLabel("CSV File Path:")) + csv_layout.addWidget(self.csv_path_line) + csv_layout.addWidget(self.csv_browse_button) + csv_widget = QWidget() + csv_widget.setLayout(csv_layout) + + self.layout.addWidget(csv_widget) + + # Initialize the CSVFieldMapWidget as None + self.csv_field_widget = None + + def _on_csv_editing_finished(self) -> None: + """Load the CSV file when the user presses Enter in the CSV path line""" + + csv_path = self.csv_path_line.text() + self._load_csv(csv_path) + + def _browse_csv(self) -> None: + """Open File dialog to select CSV file""" + + csv_file, _ = QFileDialog.getOpenFileName( + self, "Select CSV File", "", "CSV Files (*.csv)" + ) + if csv_file: + self._load_csv(csv_file) + else: + QMessageBox.warning(self, "Input Required", "Please select a CSV file.") + + def _load_csv(self, csv_file: str) -> None: + """Load the csv file and display the CSVFieldMapWidget""" + + if csv_file == "": + self.df = None + return + if not os.path.exists(csv_file): + QMessageBox.critical(self, "Error", "The specified file was not found.") + self.df = None + return + + self.csv_path_line.setText(csv_file) + + # Ensure CSV path is valid + try: + self.df = pd.read_csv(csv_file) + if self.csv_field_widget is not None: + self.layout.removeWidget(self.csv_field_widget) + self.csv_field_widget = CSVFieldMapWidget( + list(self.df.columns), seg=self.add_segmentation, incl_z=self.incl_z + ) + self.csv_field_widget.columns_updated.connect(self.update_buttons) + self.layout.addWidget(self.csv_field_widget) + self.update_buttons.emit() + + except pd.errors.EmptyDataError: + QMessageBox.critical(self, "Error", "The file is empty or has no data.") + self.df = None + return + except pd.errors.ParserError: + self.df = None + QMessageBox.critical( + self, "Error", "The file could not be parsed as a valid 
import pandas as pd
from qtpy.QtWidgets import (
    QDialog,
    QHBoxLayout,
    QMessageBox,
    QPushButton,
    QStackedWidget,
    QVBoxLayout,
    QWidget,
)

from ..load_tracks import tracks_from_df
from .csv_widget import CSVWidget
from .measurement_widget import MeasurementWidget
from .metadata_menu import MetadataMenu
from .segmentation_widget import SegmentationWidget


class ImportTracksDialog(QDialog):
    """Multipage dialog for importing external tracks from a CSV file.

    Pages:
        1. Metadata choices (tracks name, 2D/3D, segmentation yes/no).
        2. CSV loading and column-to-attribute mapping.
        3. (optional) Segmentation file path and scaling.
        4. (optional) Measurements to recompute or import from a column.

    On success, ``self.tracks`` holds the constructed SolutionTracks object
    and ``self.name`` the user-chosen tracks name.
    """

    def __init__(self):
        super().__init__()
        self.setWindowTitle("Import external tracks from CSV")

        self.csv = None
        self.segmentation = None

        self.layout = QVBoxLayout(self)
        self.stacked_widget = QStackedWidget()
        self.layout.addWidget(self.stacked_widget)

        # Navigation buttons shared by all pages.
        self.button_layout = QHBoxLayout()
        self.previous_button = QPushButton("Previous")
        self.next_button = QPushButton("Next")
        self.finish_button = QPushButton("Finish")
        self.button_layout.addWidget(self.previous_button)
        self.button_layout.addWidget(self.next_button)
        self.button_layout.addWidget(self.finish_button)
        self.layout.addLayout(self.button_layout)

        self.previous_button.clicked.connect(self._go_to_previous_page)
        self.next_button.clicked.connect(self._go_to_next_page)
        self.finish_button.clicked.connect(self._finish)

        # Page 1: metadata choices.
        self.page1 = QWidget()
        page1_layout = QVBoxLayout()
        self.menu_widget = MetadataMenu()
        page1_layout.addWidget(self.menu_widget)
        self.page1.setLayout(page1_layout)
        self.stacked_widget.addWidget(self.page1)

        # Rebuild the downstream pages whenever the metadata choices change.
        self.menu_widget.segmentation_checkbox.stateChanged.connect(self._update_pages)
        self.menu_widget.radio_2D.clicked.connect(self._update_pages)
        self.menu_widget.radio_3D.clicked.connect(self._update_pages)

        # Page 2: CSV loading.
        self.data_widget = CSVWidget(
            add_segmentation=self.menu_widget.segmentation_checkbox.isChecked()
        )
        self.data_widget.update_buttons.connect(self._update_buttons)
        self.data_widget.update_buttons.connect(self._update_measurement_widget)
        self.stacked_widget.addWidget(self.data_widget)

        # Optional page 3: segmentation information (created in _update_pages).
        self.segmentation_page = None

        # Optional page 4: measurement attributes to compute
        # (only present when a segmentation is provided).
        self.measurement_widget = None

        self._update_buttons()

    def _update_measurement_widget(self) -> None:
        """Rebuild the measurement page from the data dimensions and the CSV
        columns that were not consumed by the field mapping."""

        if (
            self.data_widget.df is not None
            and self.menu_widget.segmentation_checkbox.isChecked()
        ):
            if self.measurement_widget is not None:
                self.stacked_widget.removeWidget(self.measurement_widget)
            self.measurement_widget = MeasurementWidget(
                self.data_widget.csv_field_widget.columns_left,
                ndim=2 if self.menu_widget.radio_2D.isChecked() else 3,
            )
            self.stacked_widget.addWidget(self.measurement_widget)
            self._update_buttons()

    def _update_pages(self) -> None:
        """Recreate the CSV, segmentation, and measurement pages when the
        user changes the options in the metadata menu."""

        self.stacked_widget.removeWidget(self.data_widget)
        if self.segmentation_page is not None:
            self.stacked_widget.removeWidget(self.segmentation_page)
        if self.measurement_widget is not None:
            self.stacked_widget.removeWidget(self.measurement_widget)

        self.data_widget = CSVWidget(
            add_segmentation=self.menu_widget.segmentation_checkbox.isChecked(),
            incl_z=self.menu_widget.radio_3D.isChecked(),
        )
        self.data_widget.update_buttons.connect(self._update_buttons)
        self.data_widget.update_buttons.connect(self._update_measurement_widget)
        self.stacked_widget.addWidget(self.data_widget)

        if self.menu_widget.segmentation_checkbox.isChecked():
            self.segmentation_page = SegmentationWidget(
                self.menu_widget.radio_3D.isChecked()
            )
            self.stacked_widget.addWidget(self.segmentation_page)
            self.segmentation_page.update_buttons.connect(self._update_buttons)

        if (
            self.data_widget.df is not None
            and self.menu_widget.segmentation_checkbox.isChecked()
        ):
            self.measurement_widget = MeasurementWidget(
                self.data_widget.csv_field_widget.columns_left,
                ndim=2 if self.menu_widget.radio_2D.isChecked() else 3,
            )
            self.stacked_widget.addWidget(self.measurement_widget)

        # Force a repaint of the stacked widget after restructuring.
        self.stacked_widget.hide()
        self.stacked_widget.show()

    def _go_to_previous_page(self) -> None:
        """Go to the previous page."""

        current_index = self.stacked_widget.currentIndex()
        if current_index > 0:
            self.stacked_widget.setCurrentIndex(current_index - 1)
        self._update_buttons()

    def _go_to_next_page(self) -> None:
        """Go to the next page."""

        current_index = self.stacked_widget.currentIndex()
        if current_index < self.stacked_widget.count() - 1:
            self.stacked_widget.setCurrentIndex(current_index + 1)
        self._update_buttons()

    def _update_buttons(self) -> None:
        """Enable or disable buttons based on the current page."""

        # Do not allow finishing if no CSV file is loaded, or if a
        # segmentation was requested but no file path has been given yet.
        # Guard against segmentation_page still being None (it is only
        # created once _update_pages runs with the checkbox checked).
        seg_missing = self.menu_widget.segmentation_checkbox.isChecked() and (
            self.segmentation_page is None
            or self.segmentation_page.image_path_line.text() == ""
        )
        self.finish_button.setEnabled(
            self.data_widget.df is not None and not seg_missing
        )

        current_index = self.stacked_widget.currentIndex()
        if current_index + 1 == self.stacked_widget.count():
            self.next_button.hide()
            self.finish_button.show()
        else:
            self.next_button.show()
            self.finish_button.hide()
        self.previous_button.setEnabled(current_index > 0)
        self.next_button.setEnabled(current_index < self.stacked_widget.count() - 1)

        self.finish_button.setAutoDefault(0)
        self.next_button.setAutoDefault(0)
        self.previous_button.setAutoDefault(0)

    def _finish(self) -> None:
        """Read the CSV file and optional segmentation image, apply the
        attribute-to-column mapping, and construct a Tracks object."""

        # Retrieve selected columns for each required field, and optional
        # columns for additional attributes.
        name_map = self.data_widget.csv_field_widget.get_name_map()

        # Create new columns for each feature based on the original column values.
        df = pd.DataFrame()
        for feature, column in name_map.items():
            df[feature] = self.data_widget.df[column]

        # Read scaling information from the spinboxes, or fall back to unit
        # scale (one entry for time plus one per spatial dimension).
        if self.segmentation_page is not None:
            scale = self.segmentation_page.get_scale()
        else:
            scale = [1, 1, 1] if self.data_widget.incl_z is False else [1, 1, 1, 1]

        if self.measurement_widget is not None:
            features = self.measurement_widget.get_measurements()
            for feature in features:
                if features[feature] != "Recompute" and feature in ("Area", "Volume"):
                    # Copy the user-chosen column into the canonical "area"
                    # column and re-point the mapping at it: the original
                    # column name does not exist in the remapped df that is
                    # handed to tracks_from_df.
                    df["area"] = self.data_widget.df[features[feature]]
                    features[feature] = "area"
        else:
            features = {}

        self.name = self.menu_widget.name_widget.text()

        # Only load a segmentation when the user asked for one; the
        # segmentation page does not exist otherwise, so dereferencing it
        # unconditionally would raise AttributeError.
        if (
            self.menu_widget.segmentation_checkbox.isChecked()
            and self.segmentation_page is not None
        ):
            self.segmentation_page._load_segmentation()
            segmentation = self.segmentation_page.segmentation
        else:
            segmentation = None

        try:
            self.tracks = tracks_from_df(df, segmentation, scale, features)
        except ValueError as e:
            QMessageBox.critical(self, "Error", f"Failed to load tracks: {e}")
            return
        self.accept()


from psygnal import Signal
from qtpy.QtCore import Qt
from qtpy.QtWidgets import (
    QButtonGroup,
    QCheckBox,
    QComboBox,
    QGroupBox,
    QLabel,
    QLineEdit,
    QRadioButton,
)


class MeasurementWidget(QWidget):
    """QWidget to choose which measurements should be calculated.

    For each available measurement ("Area" in 2D, "Volume" in 3D) the user
    can enable it and either have it recomputed from the segmentation or
    read from one of the unmapped CSV columns.
    """

    update_features = Signal()

    def __init__(self, columns_left: list[str], ndim: int):
        super().__init__()

        self.columns_left = columns_left
        self.layout = QVBoxLayout()

        # Checkboxes for measurements.
        self.measurements = []
        self.layout.addWidget(QLabel("Choose measurements to calculate"))

        if ndim == 2:
            self.measurements.append("Area")
        elif ndim == 3:
            self.measurements.append("Volume")

        self.measurement_checkboxes = {}
        self.radio_buttons = {}
        self.select_column_radio_buttons = {}
        self.column_dropdowns = {}

        for measurement in self.measurements:
            row_layout = QHBoxLayout()

            checkbox = QCheckBox(measurement)
            checkbox.setChecked(False)
            checkbox.stateChanged.connect(self.emit_update_features)
            self.measurement_checkboxes[measurement] = checkbox
            row_layout.addWidget(checkbox)

            recompute_radio = QRadioButton("Recompute")
            recompute_radio.setChecked(True)
            select_column_radio = QRadioButton("Select from column")
            # Keep the button group on self so it is not garbage collected.
            button_group = QButtonGroup()
            button_group.addButton(recompute_radio)
            button_group.addButton(select_column_radio)
            self.radio_buttons[measurement] = button_group
            row_layout.addWidget(recompute_radio)
            row_layout.addWidget(select_column_radio)

            column_dropdown = QComboBox()
            column_dropdown.addItems(self.columns_left)
            column_dropdown.setEnabled(False)
            column_dropdown.currentIndexChanged.connect(self.emit_update_features)
            self.column_dropdowns[measurement] = column_dropdown
            row_layout.addWidget(column_dropdown)

            # The dropdown is only meaningful when reading from a column.
            select_column_radio.toggled.connect(
                lambda checked, dropdown=column_dropdown: dropdown.setEnabled(checked)
            )
            select_column_radio.toggled.connect(self.emit_update_features)

            self.layout.addLayout(row_layout)

        self.layout.setAlignment(Qt.AlignTop)
        self.setLayout(self.layout)

    def emit_update_features(self):
        """Re-emit any option change as a single update_features signal."""
        self.update_features.emit()

    def get_measurements(self) -> dict[str, str]:
        """Return a mapping from each selected measurement name to either the
        literal string "Recompute" or the CSV column to read it from."""

        selected_measurements = []
        for prop_name, checkbox in self.measurement_checkboxes.items():
            if checkbox.isChecked():
                selected_measurements.append(prop_name)

        measurements = {}
        for measurement in selected_measurements:
            button_group = self.radio_buttons[measurement]
            checked_button = button_group.checkedButton()
            if checked_button is not None:
                if checked_button.text() == "Recompute":
                    measurements[measurement] = "Recompute"
                elif checked_button.text() == "Select from column":
                    # Retrieve the column name that was chosen.
                    measurements[measurement] = self.column_dropdowns[
                        measurement
                    ].currentText()

        return measurements


class MetadataMenu(QWidget):
    """Menu to choose tracks name, data dimensions, and optional segmentation."""

    def __init__(self):
        super().__init__()

        layout = QVBoxLayout()

        # Name of the tracks.
        name_layout = QVBoxLayout()
        name_box = QGroupBox("Tracks Name")
        self.name_widget = QLineEdit("External Tracks from CSV")
        name_layout.addWidget(self.name_widget)
        name_box.setLayout(name_layout)
        name_box.setMaximumHeight(100)
        layout.addWidget(name_box)

        # Dimensions of the tracks. The button group is kept on self so it
        # cannot be garbage collected while the radio buttons are alive.
        dimensions_layout = QVBoxLayout()
        dimension_box = QGroupBox("Data Dimensions")
        self.data_button_group = QButtonGroup()
        button_layout = QHBoxLayout()
        self.radio_2D = QRadioButton("2D + time")
        self.radio_2D.setChecked(True)
        self.radio_3D = QRadioButton("3D + time")
        self.data_button_group.addButton(self.radio_2D)
        self.data_button_group.addButton(self.radio_3D)
        button_layout.addWidget(self.radio_2D)
        button_layout.addWidget(self.radio_3D)
        dimensions_layout.addLayout(button_layout)
        dimension_box.setLayout(dimensions_layout)
        dimension_box.setMaximumHeight(80)
        layout.addWidget(dimension_box)

        # Whether or not a segmentation file exists.
        segmentation_layout = QVBoxLayout()
        segmentation_box = QGroupBox("Segmentation Image")
        self.segmentation_checkbox = QCheckBox("I have a segmentation image")
        segmentation_layout.addWidget(self.segmentation_checkbox)
        segmentation_box.setLayout(segmentation_layout)
        segmentation_box.setMaximumHeight(80)
        layout.addWidget(segmentation_box)

        layout.setContentsMargins(0, 3, 0, 0)
        self.setLayout(layout)
        self.setMinimumHeight(400)
import os

import tifffile
import zarr
from psygnal import Signal
from qtpy.QtCore import Qt
from qtpy.QtWidgets import (
    QDialog,
    QDoubleSpinBox,
    QFileDialog,
    QFormLayout,
    QHBoxLayout,
    QLabel,
    QLineEdit,
    QMessageBox,
    QPushButton,
    QVBoxLayout,
    QWidget,
)


class SegmentationWidget(QWidget):
    """QWidget for selecting a segmentation file and specifying pixel scaling.

    After ``_load_segmentation`` runs, ``self.segmentation`` holds the loaded
    array (tiff) or zarr handle, or None if nothing valid was loaded.
    """

    update_buttons = Signal()

    def __init__(self, incl_z=True):
        super().__init__()

        self.incl_z = incl_z
        # Always defined so callers can read it before any load attempt.
        self.segmentation = None

        layout = QVBoxLayout()
        self.image_path_line = QLineEdit(self)
        self.image_path_line.editingFinished.connect(self.update_buttons.emit)
        self.image_browse_button = QPushButton("Browse Segmentation", self)
        self.image_browse_button.setAutoDefault(0)
        self.image_browse_button.clicked.connect(self._browse_segmentation)

        image_widget = QWidget()
        image_layout = QVBoxLayout()
        image_sublayout = QHBoxLayout()
        image_sublayout.addWidget(QLabel("Segmentation File Path:"))
        image_sublayout.addWidget(self.image_path_line)
        image_sublayout.addWidget(self.image_browse_button)

        label = QLabel(
            "Segmentation files can either be a single tiff stack, or a directory inside a zarr folder."
        )
        font = label.font()
        font.setItalic(True)
        label.setFont(font)

        label.setWordWrap(True)
        image_layout.addWidget(label)

        image_layout.addLayout(image_sublayout)
        image_widget.setLayout(image_layout)
        image_widget.setMaximumHeight(100)

        layout.addWidget(image_widget)

        # Spinboxes for scaling in x, y, and z (z only for 3D data).
        layout.addWidget(QLabel("Specify scaling"))
        scale_form_layout = QFormLayout()
        self.z_spin_box = self._scale_spin_box()
        self.y_spin_box = self._scale_spin_box()
        self.x_spin_box = self._scale_spin_box()

        if self.incl_z:
            scale_form_layout.addRow(QLabel("z"), self.z_spin_box)
        scale_form_layout.addRow(QLabel("y"), self.y_spin_box)
        scale_form_layout.addRow(QLabel("x"), self.x_spin_box)

        layout.addLayout(scale_form_layout)
        layout.setAlignment(Qt.AlignTop)

        self.setLayout(layout)

    def _scale_spin_box(self) -> QDoubleSpinBox:
        """Return a QDoubleSpinBox for scaling values (default 1.0)."""

        spin_box = QDoubleSpinBox()
        spin_box.setValue(1.0)
        spin_box.setSingleStep(0.1)
        spin_box.setMinimum(0)
        spin_box.setDecimals(3)
        return spin_box

    def get_scale(self) -> list[float]:
        """Return the scaling values in the spinboxes as a list of floats.

        Since we currently require a dummy 1 value for the time dimension,
        it is prepended here."""

        if self.incl_z:
            scale = [
                1,
                self.z_spin_box.value(),
                self.y_spin_box.value(),
                self.x_spin_box.value(),
            ]
        else:
            scale = [
                1,
                self.y_spin_box.value(),
                self.x_spin_box.value(),
            ]

        return scale

    def _browse_segmentation(self) -> None:
        """Open custom dialog to select either a file or a folder."""

        dialog = FileFolderDialog(self)
        if dialog.exec_():
            selected_path = dialog.get_selected_path()
            if selected_path:
                self.image_path_line.setText(selected_path)

    def _load_segmentation(self) -> None:
        """Load the segmentation image file into ``self.segmentation``.

        Sets ``self.segmentation`` to None when the path is missing or the
        file type is not supported, so callers never see a stale value."""

        # Check if a valid path to a segmentation image file is provided
        # and if so load it.
        if os.path.exists(self.image_path_line.text()):
            if self.image_path_line.text().endswith(".tif"):
                segmentation = tifffile.imread(self.image_path_line.text())
            elif ".zarr" in self.image_path_line.text():
                segmentation = zarr.open(self.image_path_line.text())
            else:
                # Reset before bailing out so a previous (or never-made)
                # assignment cannot leak through to the caller.
                self.segmentation = None
                QMessageBox.warning(
                    self,
                    "Invalid file type",
                    "Please provide a tiff or zarr file for the segmentation image stack",
                )
                return
        else:
            segmentation = None
        self.segmentation = segmentation


class FileFolderDialog(QDialog):
    """Small dialog that lets the user pick either a tiff file or a zarr folder."""

    def __init__(self, parent=None):
        super().__init__(parent)
        self.setWindowTitle("Select Tif File or Zarr Folder")

        self.layout = QVBoxLayout(self)

        self.path_line_edit = QLineEdit(self)
        self.layout.addWidget(self.path_line_edit)

        button_layout = QHBoxLayout()

        self.file_button = QPushButton("Select tiff file", self)
        self.file_button.clicked.connect(self.select_file)
        self.file_button.setAutoDefault(False)
        self.file_button.setDefault(False)

        button_layout.addWidget(self.file_button)

        self.folder_button = QPushButton("Select zarr folder", self)
        self.folder_button.clicked.connect(self.select_folder)
        self.folder_button.setAutoDefault(False)
        self.folder_button.setDefault(False)
        button_layout.addWidget(self.folder_button)

        self.layout.addLayout(button_layout)

        self.ok_button = QPushButton("OK", self)
        self.ok_button.clicked.connect(self.accept)
        self.layout.addWidget(self.ok_button)

    def select_file(self):
        """Pick a single segmentation file via the native file dialog."""
        file, _ = QFileDialog.getOpenFileName(
            self,
            "Select Segmentation File",
            "",
            "Segmentation Files (*.tiff *.zarr *.tif)",
        )
        if file:
            self.path_line_edit.setText(file)

    def select_folder(self):
        """Pick a zarr directory via the native directory dialog."""
        folder = QFileDialog.getExistingDirectory(
            self,
            "Select Folder",
            "",
            QFileDialog.ShowDirsOnly | QFileDialog.DontResolveSymlinks,
        )
        if folder:
            self.path_line_edit.setText(folder)

    def get_selected_path(self):
        """Return the currently displayed path (may be empty)."""
        return self.path_line_edit.text()
- - Args: - csvfile (str): - path to the csv to load - segmentation (np.ndarray | None, optional): - An optional accompanying segmentation. - If provided, assumes that the seg_id column in the csv file exists and - corresponds to the label ids in the segmentation array - - Returns: - Tracks: a tracks object ready to be visualized with - TracksViewer.view_external_tracks - """ - graph = nx.DiGraph() - with open(csvfile) as f: - reader = DictReader(f) - for row in reader: - _id = int(row["id"]) - attrs = { - "pos": [float(row["y"]), float(row["x"])], - "time": int(row["t"]), - } - if "seg_id" in row: - attrs["seg_id"] = int(row["seg_id"]) - graph.add_node(_id, **attrs) - parent_id = row["parent_id"].strip() - if parent_id != "": - parent_id = int(parent_id) - if parent_id != -1: - graph.add_edge(parent_id, _id) - tracks = SolutionTracks( - graph=graph, - segmentation=segmentation, - pos_attr="pos", - time_attr="time", - ) - return tracks diff --git a/tests/data_model/test_solution_tracks.py b/tests/import_export/test_export_solution_to_csv.py similarity index 70% rename from tests/data_model/test_solution_tracks.py rename to tests/import_export/test_export_solution_to_csv.py index 5ec14f38..364851f4 100644 --- a/tests/data_model/test_solution_tracks.py +++ b/tests/import_export/test_export_solution_to_csv.py @@ -10,8 +10,30 @@ def test_export_to_csv(graph_2d, graph_3d, tmp_path, colormap): assert len(lines) == tracks.graph.number_of_nodes() + 1 # add header - header = ["t", "y", "x", "id", "parent_id", "track_id", "lineage_id", "color"] + header = [ + "t", + "y", + "x", + "id", + "parent_id", + "track_id", + "lineage_id", + "color", + "area", + ] assert lines[0].strip().split(",") == header + line1 = [ + "0", + "50", + "50", + "1", + "", + "1", + "1", + "[120. 37. 
import os

import numpy as np
import pandas as pd
import pytest
from motile_toolbox.candidate_graph.graph_attributes import NodeAttr

from motile_tracker.import_export.load_tracks import (
    _test_valid,
    ensure_correct_labels,
    ensure_integer_ids,
    tracks_from_df,
)


class TestLoadTracks:
    """Unit tests for the CSV -> SolutionTracks loading pipeline."""

    def test_non_unique_ids(self):
        """Test that a ValueError is raised if the ids are not unique"""

        data = {
            "id": [1, 1, 2],
            "parent_id": [0, 0, 1],
            NodeAttr.TIME.value: [0, 1, 2],
            "y": [10, 20, 30],
            "x": [15, 25, 35],
        }
        df = pd.DataFrame(data)
        with pytest.raises(ValueError):
            tracks_from_df(df)

    def test_string_ids(self):
        """Test that string ids are converted to unique integers"""

        data = {
            "id": ["a", "b", "c"],
            "parent_id": [None, "b", "c"],
            NodeAttr.TIME.value: [0, 1, 2],
            "y": [10, 20, 30],
            "x": [15, 25, 35],
        }
        df = pd.DataFrame(data)
        df = ensure_integer_ids(df)
        assert pd.api.types.is_integer_dtype(df["id"])
        # parent_id keeps NaN for roots, so check only the non-null entries.
        assert (
            pd.to_numeric(df["parent_id"], errors="coerce")
            .dropna()
            .apply(lambda x: float(x).is_integer())
            .all()
        )
        assert df["id"].is_unique

    def test_set_scale(self):
        """Test that the scaling is correctly propagated to the tracks"""

        data = {
            "id": [1, 2, 3],
            "parent_id": [None, 1, 2],
            NodeAttr.TIME.value: [0, 1, 2],
            "y": [10, 20, 30],
            "x": [15, 25, 35],
        }
        df = pd.DataFrame(data)
        scale = [1, 2, 1]
        tracks = tracks_from_df(df, scale=scale)
        assert tracks.scale == scale

    def test_valid_segmentation(self):
        """Test that the segmentation value of the first node matches with its id"""

        data = {
            "id": [1, 2, 3],
            "parent_id": [-1, 1, 2],
            NodeAttr.TIME.value: [0, 1, 2],
            "y": [0.25, 2, 1.3333],
            "x": [0.75, 1.5, 1.6667],
            "seg_id": [1, 2, 3],
        }
        df = pd.DataFrame(data)
        segmentation = np.array(
            [
                [[1, 1, 1], [1, 3, 3], [2, 2, 3]],
                [[1, 1, 1], [1, 3, 3], [2, 2, 3]],
                [[1, 1, 1], [1, 3, 3], [2, 2, 3]],
            ]
        )

        assert _test_valid(df, segmentation, scale=[1, 1, 1])
        assert _test_valid(df, segmentation, scale=None)

        data = {
            "id": [1, 2, 3],
            "parent_id": [-1, 1, 2],
            NodeAttr.TIME.value: [0, 1, 2],
            "y": [1, 8, 5.3333],
            "x": [3, 6, 6.6667],
            "seg_id": [1, 2, 3],
        }
        df = pd.DataFrame(data)

        # test if False when scaling is applied incorrectly
        assert not _test_valid(df, segmentation, scale=[1, 1, 1])
        with pytest.raises(
            ValueError,
            match="Segmentation ids in dataframe do not match values in segmentation."
            "Is it possible that you loaded the wrong combination of csv file and "
            "segmentation, or that the scaling information you provided is incorrect?",
        ):
            tracks_from_df(df, segmentation=segmentation, scale=[1, 1, 1])
        # test if True when scaling is applied correctly
        assert _test_valid(df, segmentation, scale=[1, 4, 4])
        # ndim of segmentation should match with the length of provided scale
        with pytest.warns(
            UserWarning,
            match=r"Dimensions of the segmentation image \(3\) do not match the number "
            r"of scale values given \(4\)",
        ):
            assert not _test_valid(df, segmentation, scale=[1, 4, 4, 1])

        data = {
            "id": [1, 2, 3],
            "parent_id": [-1, 1, 2],
            NodeAttr.TIME.value: [0, 1, 2],
            "z": [1, 1, 1],
            "y": [1, 8, 5.3333],
            "x": [3, 6, 6.6667],
            "seg_id": [1, 2, 3],
        }
        df = pd.DataFrame(data)
        # ndim of segmentation should match with the dims specified in the dataframe
        with pytest.warns(
            UserWarning,
            match=r"Dimensions of the segmentation \(3\) do not match the number "
            r"of positional dimensions \(4\)",
        ):
            assert not _test_valid(df, segmentation, scale=[1, 4, 4])

        # test actual 4D data
        data = {
            "id": [1, 2, 3],
            "parent_id": [-1, 1, 2],
            NodeAttr.TIME.value: [0, 1, 2],
            "z": [0, 0, 0],
            "y": [1, 8, 5.3333],
            "x": [3, 6, 6.6667],
            "seg_id": [1, 2, 3],
        }
        seg_4d = np.array([segmentation, segmentation, segmentation])
        df = pd.DataFrame(data)
        assert _test_valid(df, seg_4d, scale=[1, 1, 4, 4])
        tracks = tracks_from_df(df, seg_4d, scale=[1, 1, 4, 4])
        assert tracks.graph.number_of_nodes() == 3
        assert tracks.graph.number_of_edges() == 2

    def test_relabel_segmentation(self):
        """Test relabeling the segmentation if id != seg_id"""

        data = {
            NodeAttr.TIME.value: [0, 0, 0, 1],
            "seg_id": [1, 2, 3, 3],
            "id": [10, 20, 30, 40],
            "x": [0, 0, 1, 1],
            "y": [0, 2, 2, 2],
            "parent_id": [None, None, None, None],
        }
        df = pd.DataFrame(data)
        segmentation = np.array(
            [
                [[1, 1, 2], [2, 3, 3], [1, 2, 3]],
                [[0, 0, 0], [0, 3, 3], [0, 0, 3]],
            ]
        )
        new_segmentation = ensure_correct_labels(df, segmentation)
        expected_segmentation = np.array(
            [
                [[10, 10, 20], [20, 30, 30], [10, 20, 30]],
                [[0, 0, 0], [0, 40, 40], [0, 0, 40]],
            ]
        )

        np.testing.assert_array_equal(new_segmentation, expected_segmentation)
        tracks = tracks_from_df(df, segmentation)
        np.testing.assert_array_equal(tracks.segmentation, expected_segmentation)
        assert tracks.graph.number_of_nodes() == 4
        assert tracks.graph.number_of_edges() == 0

    def test_measurements(self):
        """Test if the area is measured correctly, taking scaling into account"""

        data = {
            NodeAttr.TIME.value: [0, 0, 0, 1],
            "seg_id": [1, 2, 3, 4],
            "id": [1, 2, 3, 4],
            "parent_id": [None, 1, 2, 3],
            "y": [0, 1.6667, 1.333, 1.333],
            "x": [1, 0.33333, 1.66667, 1.66667],
        }
        df = pd.DataFrame(data)
        segmentation = np.array(
            [[[1, 1, 1], [2, 3, 3], [2, 3, 3]], [[1, 1, 0], [2, 4, 4], [2, 2, 4]]]
        )

        tracks = tracks_from_df(
            df, segmentation, scale=(1, 1, 1), features={"Area": "Recompute"}
        )

        assert tracks._get_node_attr(1, NodeAttr.AREA.value) == 3
        assert tracks._get_node_attr(2, NodeAttr.AREA.value) == 2
        assert tracks._get_node_attr(3, NodeAttr.AREA.value) == 4
        assert tracks._get_node_attr(4, NodeAttr.AREA.value) == 3

        tracks = tracks_from_df(
            df, segmentation, scale=(1, 2, 1), features={"Area": "Recompute"}
        )

        assert tracks._get_node_attr(1, NodeAttr.AREA.value) == 6
        assert tracks._get_node_attr(2, NodeAttr.AREA.value) == 4
        assert tracks._get_node_attr(3, NodeAttr.AREA.value) == 8
        assert tracks._get_node_attr(4, NodeAttr.AREA.value) == 6

        tracks = tracks_from_df(
            df, segmentation=None, scale=(1, 2, 1), features={"Area": "Recompute"}
        )  # no seg provided, should return None

        assert tracks._get_node_attr(1, NodeAttr.AREA.value) is None

        tracks = tracks_from_df(
            df, segmentation, scale=(1, 2, 1), features={}
        )  # no area measurement provided, should return None.

        assert tracks._get_node_attr(1, NodeAttr.AREA.value) is None

        data = {
            NodeAttr.TIME.value: [0, 0, 0, 1],
            "seg_id": [1, 2, 3, 4],
            "id": [1, 2, 3, 4],
            "parent_id": [None, 1, 2, 3],
            "y": [0, 1.6667, 1.333, 1.333],
            "x": [1, 0.33333, 1.66667, 1.66667],
            "area": [1, 2, 3, 4],
        }
        df = pd.DataFrame(data)

        tracks = tracks_from_df(
            df, segmentation, scale=(1, 1, 1), features={"Area": "area"}
        )  # Area column provided by the dataframe; the import dialog maps a
        # custom column to a column named 'area' (to be updated in a future
        # version that supports additional measured features)

        assert tracks._get_node_attr(1, NodeAttr.AREA.value) == 1
        assert tracks._get_node_attr(2, NodeAttr.AREA.value) == 2
        assert tracks._get_node_attr(3, NodeAttr.AREA.value) == 3
        assert tracks._get_node_attr(4, NodeAttr.AREA.value) == 4

    def test_load_sample_data(self):
        """Smoke test: the bundled sample CSV loads and yields integer node ids."""

        # Resolve scripts/hela_example_tracks.csv relative to this test file
        # (tests/import_export/ -> repo root -> scripts/).
        example_csv = os.path.abspath(
            os.path.join(
                os.path.dirname(__file__),
                "..",
                "..",
                "scripts",
                "hela_example_tracks.csv",
            )
        )

        df = pd.read_csv(example_csv)

        # Map the required fields to their CSV columns, mirroring the dialog.
        name_map = {
            "time": "t",
        }
        for feature, column in name_map.items():
            df[feature] = df[column]

        tracks = tracks_from_df(df)
        for node in tracks.graph.nodes():
            assert isinstance(node, int)