diff --git a/README.md b/README.md
index 1d80ec8..6583375 100644
--- a/README.md
+++ b/README.md
@@ -10,7 +10,7 @@ With support for local or cluster-based parallelization, CIMAT provides visualiz
 ### Run workflow with Snakemake
 Run all jobs in the pipeline:
 ```bash
-snakemake --executor slurm --jobs 20 --latency-wait 10 all
+snakemake --executor slurm --jobs 20 --latency-wait 10 all --forcerun preprocess --rerun-incomplete
 ```
 
 Add `-np --printshellcmds` for a dry run with commands printed to the terminal.
@@ -24,3 +24,13 @@ Create the report:
 datavzrd workflow/resources/datavzrd_config.yaml --output workflow/results/datavzrd
 ```
 Then open the report (`index.html`) in a browser.
+
+To rerun a specific dataset, request its target file explicitly:
+```bash
+snakemake --executor slurm --jobs 20 --latency-wait 10 /ceph/margrie/laura/cimaut/derivatives/sub-1_230802CAA1120182/ses-0/funcimg/derotation/derotated_full.tif --forcerun preprocess --rerun-incomplete
+```
+
+To generate the summary plot:
+```bash
+snakemake --cores 1 --latency-wait 10 workflow/results/data/stability_metric.png
+```
diff --git a/calcium_imaging_automation/core/rules/plot_data.py b/calcium_imaging_automation/core/rules/plot_data.py
new file mode 100644
index 0000000..0ab09ee
--- /dev/null
+++ b/calcium_imaging_automation/core/rules/plot_data.py
@@ -0,0 +1,315 @@
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+import seaborn as sns
+from matplotlib import pyplot as plt
+from snakemake.script import snakemake
+
+print("Plotting data...")
+
+
+# datasets = snakemake.params.datasets
+f_path = Path(snakemake.input[0])
+
+print(f"Processing file: {f_path}")
+
+dataset_path = f_path.parent.parent.parent.parent
+dataset = dataset_path.name
+f_neu_path = f_path.parent / "Fneu.npy"
+derotated_full_csv_path = (
+    dataset_path / "ses-0" / "funcimg" / "derotation" / "derotated_full.csv"
+)
+saving_path = Path(dataset_path) / "ses-0" / "traces"
+saving_path.mkdir(exist_ok=True)
+
+print(f"Dataset path: {dataset_path}")
+print(f"Dataset: {dataset}")
+print(f"Derotated full csv path: {derotated_full_csv_path}")
+print(f"Fneu path: {f_neu_path}")
+print(f"Saving path: {saving_path}")
+
+
+f = np.load(f_path)
+fneu = np.load(f_neu_path)
+rotated_frames = pd.read_csv(derotated_full_csv_path)
+f_corrected = f - 0.7 * fneu
+
+
+F_df = pd.DataFrame(f_corrected).T
+
+print(f"Shape of F_df: {F_df.shape}")
+print(F_df.head())
+
+full_dataframe = pd.concat([F_df, rotated_frames], axis=1)
+
+# --------------------------------------------------------
+# Prepare the dataset
+
+# find where rotations start
+rotation_on = np.diff(full_dataframe["rotation_count"])
+
+
+def find_zero_chunks(arr):
+    """Return (start, end) index pairs of contiguous runs of zeros in arr."""
+    zero_chunks = []
+    start = None
+
+    for i in range(len(arr)):
+        if arr[i] == 0 and start is None:
+            start = i
+        elif arr[i] != 0 and start is not None:
+            zero_chunks.append((start, i - 1))
+            start = None
+
+    # Check if the array ends with a chunk of zeros
+    if start is not None:
+        zero_chunks.append((start, len(arr) - 1))
+
+    return zero_chunks
+
+
+starts_ends = find_zero_chunks(rotation_on)
+
+frames_before_rotation = 15
+# frames_after_rotation = 10
+
+total_len = 100
+
+full_dataframe["rotation_frames"] = np.zeros(len(full_dataframe))
+for i, (start, end) in enumerate(starts_ends):
+    frame_array = np.arange(total_len)
+    column_index_of_rotation_frames = full_dataframe.columns.get_loc(
+        "rotation_frames"
+    )
+    full_dataframe.iloc[
+        start - frames_before_rotation : total_len
+        + start
+        - frames_before_rotation,
+        column_index_of_rotation_frames,
+    ] = frame_array
+
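+    # each rotation is mapped onto a fixed window of `total_len` frames that
+    # starts `frames_before_rotation` frames before rotation onset, giving
+    # all rotations a common time axis for the averaged plots below
+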
+    # extend the speed and direction values of this rotation to the whole window
+    this_speed = full_dataframe.loc[start, "speed"]
+    this_direction = full_dataframe.loc[start, "direction"]
+
+    full_dataframe.iloc[
+        start - frames_before_rotation : total_len
+        + start
+        - frames_before_rotation,
+        full_dataframe.columns.get_loc("speed"),
+    ] = this_speed
+    full_dataframe.iloc[
+        start - frames_before_rotation : total_len
+        + start
+        - frames_before_rotation,
+        full_dataframe.columns.get_loc("direction"),
+    ] = this_direction
+
+
+# direction: map -1 to CCW and 1 to CW
+full_dataframe["direction"] = np.where(
+    full_dataframe["direction"] == -1, "CCW", "CW"
+)
+
+# print(f"Full dataframe shape: {full_dataframe.shape}")
+# print(full_dataframe.head())
+
+# # angle based calculation of ΔF/F
+# # first calculate F0, as the 20th quantile for each angle;
+# # consider angles every 5 degrees, from 0 to 355
+# full_dataframe["approximated_rotation_angle"] = (
+#     full_dataframe["rotation_angle"] // 5 * 5
+# )
+
+# print("Unique angles:")
+# print(full_dataframe["approximated_rotation_angle"].unique())
+
+# f0_as_20th_quantile_per_angle = np.zeros((360, f_corrected.shape[0]))
+# for angle in range(360):
+#     for roi in range(f_corrected.shape[0]):
+#         angle_indices = full_dataframe["approximated_rotation_angle"] == angle
+#         print(f"Angle: {angle}, ROI: {roi}")
+#         print(f"Angle indices: {angle_indices}")
+#         # check for nans / missing values in angle_indices
+#         if angle_indices.isnull().values.any():
+#             f0_as_20th_quantile_per_angle[angle, roi] = np.nan
+#         else:
+#             f0_as_20th_quantile_per_angle[angle, roi] = np.quantile(
+#                 f_corrected[roi][angle_indices], 0.2
+#             )
+# print("Shape of f0_as_20th_quantile_per_angle:")
+# print(f0_as_20th_quantile_per_angle.shape)
+# print(f0_as_20th_quantile_per_angle)
+
+# # calculate ΔF/F
+# for roi in range(f_corrected.T.shape[0]):
+#     full_dataframe[roi] = (
+#         f_corrected.T[roi] - f0_as_20th_quantile_per_angle[
+#             full_dataframe["rotation_angle"], roi
+#         ]
+#     ) / f0_as_20th_quantile_per_angle[
+#         full_dataframe["rotation_angle"], roi
+#     ]
+
+# print("Full dataframe with ΔF/F:")
+# print(full_dataframe.head())
+
+rois_selection = range(F_df.shape[1])
+
+# --------------------------------------------------------
+# Plot single traces
+
+# %%
+selected_range = (400, 2000)
+
+for roi in rois_selection:
+    roi_selected = full_dataframe.loc[
+        :, [roi, "rotation_count", "speed", "direction"]
+    ]
+
+    fig, ax = plt.subplots(figsize=(27, 5))
+    ax.plot(roi_selected.loc[selected_range[0] : selected_range[1], roi])
+    ax.set(xlabel="Frames", ylabel="Neuropil corrected (a.u.)")  # "ΔF/F")
+
+    rotation_on = (
+        np.diff(
+            roi_selected.loc[
+                selected_range[0] : selected_range[1], "rotation_count"
+            ]
+        )
+        == 0
+    )
+
+    # add a label at the beginning of every block of rotations;
+    # if the previous frame was already rotating, do not write the label
+    for i, rotation in enumerate(rotation_on):
+        if rotation and not rotation_on[i - 1]:
+            ax.text(
+                i + selected_range[0] + 3,
+                -1100,
+                f"{int(roi_selected.loc[i + 5 + selected_range[0], 'speed'])}º/s\n{roi_selected.loc[i + 5 + selected_range[0], 'direction']}",
+                fontsize=10,
+            )
+
+    # shade the intervals in which the rotation is happening, using starts_ends
+    for start, end in starts_ends:
+        if start > selected_range[0] and end < selected_range[1]:
+            ax.axvspan(start, end, color="gray", alpha=0.2)
+
+    fps = 6.74
+    # convert x ticks from frames to seconds
+    xticks = ax.get_xticks()
+    ax.set_xticks(xticks)
+    ax.set_xticklabels((xticks / fps).astype(int))
+    ax.set(xlabel="Seconds", ylabel="Neuropil corrected (a.u.)")  # "ΔF/F")
+
+    ax.set_xlim(selected_range)
+    # ax.set_ylim(-10, 10)
+
+    # leave some gap between the axes and the plot
+    plt.subplots_adjust(left=0.1, right=0.9, top=0.9, bottom=0.1)
+
+    # remove top and right spines
+    ax.spines["top"].set_visible(False)
+    ax.spines["right"].set_visible(False)
+
+    plt.savefig(saving_path / f"dff_example_{roi}.pdf")
+    plt.savefig(saving_path / f"dff_example_{roi}.png")
+    plt.close()
+
+
+# --------------------------------------------------------
+# Plot averages
+
+custom_palette = sns.color_palette("dark:#5A9_r", 4)
+
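+# note: for each speed, sns.lineplot aggregates all frames that share the
+# same `rotation_frames` value across rotations and draws the mean trace
+# with a confidence band (seaborn's default estimator)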
+for roi in rois_selection:
+    fig, ax = plt.subplots(1, 2, figsize=(20, 10))
+    for i, direction in enumerate(["CW", "CCW"]):
+        sns.lineplot(
+            x="rotation_frames",
+            y=roi,
+            data=full_dataframe[(full_dataframe["direction"] == direction)],
+            hue="speed",
+            palette=custom_palette,
+            ax=ax[i],
+        )
+        ax[i].set_title(f"Direction: {direction}")
+        ax[i].legend(title="Speed")
+
+        # remove top and right spines
+        ax[i].spines["top"].set_visible(False)
+        ax[i].spines["right"].set_visible(False)
+
+        # add a vertical line to mark rotation onset
+        # (onset is always at frames_before_rotation)
+        ax[i].axvline(x=frames_before_rotation, color="gray", linestyle="--")
+
+        # convert x axis to seconds
+        fps = 6.74
+        xticks = ax[i].get_xticks()
+        ax[i].set_xticks(xticks)
+        ax[i].set_xticklabels(np.round(xticks / fps, 1))
+        ax[i].set(
+            xlabel="Seconds", ylabel="Neuropil corrected (a.u.)"
+        )  # "ΔF/F")
+
+    plt.savefig(saving_path / f"roi_{roi}_direction_speed.pdf")
+    plt.savefig(saving_path / f"roi_{roi}_direction_speed.png")
+    plt.close()
+
+    # make another plot showing all individual traces (not averaged, no std)
+
+    fig, ax = plt.subplots(figsize=(20, 10))
+    for i, direction in enumerate(["CW", "CCW"]):
+        # sns.relplot(
+        #     x="rotation_frames",
+        #     y=roi,
+        #     data=full_dataframe[(full_dataframe["direction"] == direction)],
+        #     hue="speed",
+        #     palette=custom_palette,
+        #     kind="line",
+        #     estimator=None,
+        #     style="direction",
+        #     ax=ax,
+        # )
+        # plot single traces using matplotlib
+        for speed in full_dataframe["speed"].unique():
+            ax.plot(
+                full_dataframe[
+                    (full_dataframe["direction"] == direction)
+                    & (full_dataframe["speed"] == speed)
+                ]["rotation_frames"],
+                full_dataframe[
+                    (full_dataframe["direction"] == direction)
+                    & (full_dataframe["speed"] == speed)
+                ][roi],
+                label=f"{speed}º/s",
+                # color=custom_palette[speed],
+            )
+
+        ax.set_title(f"Direction: {direction}")
+        ax.legend(title="Speed")
+
+        # remove top and right spines
+        ax.spines["top"].set_visible(False)
+        ax.spines["right"].set_visible(False)
+
+        # add a vertical line to mark rotation onset
+        ax.axvline(x=frames_before_rotation, color="gray", linestyle="--")
+
+        # convert x axis to seconds
+        fps = 6.74
+        xticks = ax.get_xticks()
+        ax.set_xticks(xticks)
+        ax.set_xticklabels(np.round(xticks / fps, 1))
+        ax.set(xlabel="Seconds", ylabel="Neuropil corrected (a.u.)")  # "ΔF/F")
+
+    plt.savefig(saving_path / f"roi_{roi}_direction_speed_all.pdf")
+    plt.savefig(saving_path / f"roi_{roi}_direction_speed_all.png")
+
+    plt.close()
diff --git a/calcium_imaging_automation/core/rules/postprocess.py b/calcium_imaging_automation/core/rules/postprocess.py
new file mode 100644
index 0000000..59b81ef
--- /dev/null
+++ b/calcium_imaging_automation/core/rules/postprocess.py
@@ -0,0 +1,321 @@
+import traceback
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+import seaborn as sns
+from derotation.analysis.full_derotation_pipeline import FullPipeline
+from derotation.analysis.metrics import stability_of_most_detected_blob
+from matplotlib import pyplot as plt
+from snakemake.script import snakemake
+
+datasets = snakemake.params.datasets
+processed_data_base = snakemake.params.base_path
+
+csv_path = Path(snakemake.output[0]).with_suffix(".csv")
+img_path = Path(snakemake.output[0])
+
+if not img_path.exists():
+    datasets_paths = []
+    for idx, dataset in enumerate(datasets):
+        datasets_paths.append(
+            Path(
+                f"{snakemake.params.base_path}/sub-{idx}_{dataset}/ses-0/funcimg"
+            )
+        )
+
+    movie_bin_paths = []
+    for dataset in datasets_paths:
+        movie_bin_paths.extend(list(Path(dataset).rglob("*.bin")))
+
+    is_cell_paths = []
+    for dataset in datasets_paths:
+        is_cell_paths.extend(list(Path(dataset).rglob("iscell.npy")))
+
+    metric_paths = []
+    for path_to_bin_file in movie_bin_paths:
+        metric_path = (
+            path_to_bin_file.parent.parent.parent / "derotation_metrics.csv"
+        )
+        metric_paths.append(metric_path)
+
+    derotated_full_csv_paths = []
+    for path_to_bin_file in movie_bin_paths:
+        derotated_full_csv_path = (
+            path_to_bin_file.parent.parent.parent
+            / "derotation/derotated_full.csv"
+        )
+        derotated_full_csv_paths.append(derotated_full_csv_path)
+
+    all_metrics_df = pd.DataFrame(
+        columns=["dataset", "analysis_type", "metric", "value"]
+    )
+    analysis_types = [
+        "no_adj",
+        "adj_track",
+        "adj_largest",
+        "adj_track_shifted",
+    ]
+
+    for path_to_bin_file, metric_path, derotated_full_csv_path in zip(
+        movie_bin_paths, metric_paths, derotated_full_csv_paths
+    ):
+        print(
+            f"Processing dataset: {path_to_bin_file.parent.parent.parent.parent.parent.name}..."
+        )
+        try:
+            metric = pd.read_csv(metric_path)
+
+            path_to_bin_file = Path(path_to_bin_file)
+
+            rotation_df = pd.read_csv(derotated_full_csv_path)
+            num_frames = len(rotation_df)
+
+            shape_image = (num_frames, 256, 256)
+            registered = np.memmap(
+                path_to_bin_file, shape=shape_image, dtype="int16"
+            )
+
+            # plot the first frame of the registered movie to check that
+            # loading was correct
+            plt.imshow(registered[0])
+            plt.savefig(path_to_bin_file.parent / "first_frame_registered.png")
+            plt.close()
+
+            derotator = FullPipeline.__new__(FullPipeline)
+
+            angles = rotation_df["rotation_angle"].values
+            if len(angles) > len(registered):
+                angles = angles[: len(registered)]
+            elif len(angles) < len(registered):
+                angles = np.pad(angles, (0, len(registered) - len(angles)))
+
+            derotator.rot_deg_frame = angles
+            mean_images = derotator.calculate_mean_images(
+                registered, round_decimals=0
+            )
+
+            # show the first mean image
+            plt.imshow(mean_images[0])
+            plt.savefig(path_to_bin_file.parent / "first_mean_image.png")
+            plt.close()
+
+            path_plots = path_to_bin_file.parent
+            try:
+                ptd, std = stability_of_most_detected_blob(
+                    (mean_images, path_plots),
+                    # blob_log_kwargs={"min_sigma": 0, "max_sigma": 20, "threshold": 0.5, "overlap": 0},
+                    # clip=False
+                )
+                print(f"ptd: {ptd}, std: {std}")
+            except Exception as e:
+                print(e)
+                print(traceback.format_exc())
+                ptd = np.nan
+                std = np.nan
+
+            for i, analysis_type in enumerate(analysis_types):
+                row_ptd = {
+                    "dataset": path_to_bin_file.parent.parent.parent.parent.parent.name,
+                    "analysis_type": analysis_type,
+                    "metric": "ptd",
+                    "value": metric["ptd"][i],
+                }
+                row_std = {
+                    "dataset": path_to_bin_file.parent.parent.parent.parent.parent.name,
+                    "analysis_type": analysis_type,
+                    "metric": "std",
+                    "value": metric["std"][i],
+                }
+                all_metrics_df = pd.concat(
+                    [all_metrics_df, pd.DataFrame([row_ptd, row_std])],
+                    ignore_index=True,
+                )
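+            # note: growing a DataFrame with pd.concat inside a loop is
+            # quadratic in the number of rows; collecting the row dicts in a
+            # list and concatenating once after the loop would scale better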
+            # add post_suite2p metrics
+            row_ptd = {
+                "dataset": path_to_bin_file.parent.parent.parent.parent.parent.name,
+                "analysis_type": "post_suite2p",
+                "metric": "ptd",
+                "value": ptd,
+            }
+            row_std = {
+                "dataset": path_to_bin_file.parent.parent.parent.parent.parent.name,
+                "analysis_type": "post_suite2p",
+                "metric": "std",
+                "value": std,
+            }
+            all_metrics_df = pd.concat(
+                [all_metrics_df, pd.DataFrame([row_ptd, row_std])],
+                ignore_index=True,
+            )
+        except Exception as e:
+            print(e)
+            print("Error in dataset")
+            continue
+
+    # save the dataframe to a csv file
+    # (csv_path is the output path with the extension changed from png to csv)
+    all_metrics_df.to_csv(csv_path, index=False)
+
+else:
+    all_metrics_df = pd.read_csv(csv_path)
+
+sns.set_theme(style="whitegrid")
+sns.set_context("paper")
+sns.set_palette("pastel")
+
+fig, axs = plt.subplots(1, 2, figsize=(10, 5))
+
+sns.pointplot(
+    x="analysis_type",
+    y="value",
+    hue="dataset",
+    data=all_metrics_df[all_metrics_df["metric"] == "ptd"],
+    ax=axs[0],
+)
+
+sns.pointplot(
+    x="analysis_type",
+    y="value",
+    hue="dataset",
+    data=all_metrics_df[all_metrics_df["metric"] == "std"],
+    ax=axs[1],
+)
+
+axs[0].set_title("PTD")
+axs[1].set_title("STD")
+
+plt.tight_layout()
+plt.savefig(img_path)
+plt.close()
+
+# make another similar plot with these analysis types:
+# 1. no_adj
+# 2. the minimum across "no_adj", "adj_track", "adj_largest",
+#    "adj_track_shifted" (computed below)
+# 3. post_suite2p
+
+fig, axs = plt.subplots(1, 2, figsize=(10, 5))
+
+data = pd.DataFrame(columns=["dataset", "analysis_type", "metric", "value"])
+for dataset in all_metrics_df["dataset"].unique():
+    dataset_df = all_metrics_df[all_metrics_df["dataset"] == dataset]
+    # no_adj
+    no_adj_value = dataset_df[
+        (dataset_df["analysis_type"] == "no_adj")
+        & (dataset_df["metric"] == "ptd")
+    ]["value"].values[0]
+    row = {
+        "dataset": dataset,
+        "analysis_type": "no_adj",
+        "metric": "ptd",
+        "value": no_adj_value,
+    }
+    data = pd.concat([data, pd.DataFrame([row])], ignore_index=True)
+    # min, excluding post_suite2p
+    min_value = dataset_df[
+        (dataset_df["analysis_type"] != "post_suite2p")
+        & (dataset_df["metric"] == "ptd")
+    ]["value"].min()
+    row = {
+        "dataset": dataset,
+        "analysis_type": "min",
+        "metric": "ptd",
+        "value": min_value,
+    }
+    data = pd.concat([data, pd.DataFrame([row])], ignore_index=True)
+    # post_suite2p
+    post_suite2p_value = dataset_df[
+        (dataset_df["analysis_type"] == "post_suite2p")
+        & (dataset_df["metric"] == "ptd")
+    ]["value"].values[0]
+    row = {
+        "dataset": dataset,
+        "analysis_type": "post_suite2p",
+        "metric": "ptd",
+        "value": post_suite2p_value,
+    }
+    data = pd.concat([data, pd.DataFrame([row])], ignore_index=True)
+
+# save dataset
+data.to_csv(csv_path.with_name("min_analysis_types_min_ptd.csv"), index=False)
+
+sns.pointplot(
+    x="analysis_type",
+    y="value",
+    hue="dataset",
+    data=data[data["metric"] == "ptd"],
+    ax=axs[0],
+)
+
+
+axs[0].set_ylabel("Point to point distance (r)")
+axs[0].set_xlabel("Derotation adjustment")
+axs[0].set_xticklabels(["No", "Yes", "Post Suite2p"])
+
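+# a more compact equivalent of the per-dataset minimum computed above
+# (untested sketch):
+#   ptd = all_metrics_df[all_metrics_df["metric"] == "ptd"]
+#   ptd[ptd["analysis_type"] != "post_suite2p"].groupby("dataset")["value"].min()
+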
"metric": "std", + "value": no_adj_value, + } + data = pd.concat([data, pd.DataFrame([row])], ignore_index=True) + # min but not for post_suite2p + min_value = dataset_df[ + (dataset_df["analysis_type"] != "post_suite2p") + & (dataset_df["metric"] == "std") + ]["value"].min() + row = { + "dataset": dataset, + "analysis_type": "min", + "metric": "std", + "value": min_value, + } + data = pd.concat([data, pd.DataFrame([row])], ignore_index=True) + # post_suite2p + post_suite2p_value = dataset_df[ + (dataset_df["analysis_type"] == "post_suite2p") + & (dataset_df["metric"] == "std") + ]["value"].values[0] + row = { + "dataset": dataset, + "analysis_type": "post_suite2p", + "metric": "std", + "value": post_suite2p_value, + } + data = pd.concat([data, pd.DataFrame([row])], ignore_index=True) + +# save dataset +data.to_csv(csv_path.with_name("min_analysis_types_min_std.csv"), index=False) + +sns.pointplot( + x="analysis_type", + y="value", + hue="dataset", + data=data[data["metric"] == "std"], + ax=axs[1], +) + +axs[1].set_ylabel("XY standard deviation (s)") +axs[1].set_xlabel("Derotation adjustment") +axs[1].set_xticklabels(["No", "Yes", "Post Suite2p"]) + +# remove legend +axs[0].get_legend().remove() +axs[1].get_legend().remove() + +axs[0].set_title("PTD") +axs[1].set_title("STD") + +# despine +sns.despine() + +plt.tight_layout() +plt.savefig(img_path.with_name("min_analysis_types.png")) +plt.savefig(img_path.with_name("min_analysis_types.pdf")) diff --git a/calcium_imaging_automation/core/rules/preprocess.py b/calcium_imaging_automation/core/rules/preprocess.py index 27434fb..0070a4a 100644 --- a/calcium_imaging_automation/core/rules/preprocess.py +++ b/calcium_imaging_automation/core/rules/preprocess.py @@ -1,7 +1,6 @@ import traceback from pathlib import Path -from derotation.analysis.metrics import stability_of_most_detected_blob from derotation.derotate_batch import derotate from snakemake.script import snakemake @@ -9,17 +8,19 @@ read_dataset_path = Path(snakemake.input[0]) output_tif = Path(snakemake.output[0]) -output_path_dataset = output_tif.parent.parent +output_path_dataset = output_tif.parent try: - data = derotate(read_dataset_path, output_path_dataset) - metric_measured = stability_of_most_detected_blob(data) - with open(output_path_dataset / "metric.txt", "w") as f: - f.write(f"stability_of_most_detected_blob: {metric_measured}") + metrics = derotate(read_dataset_path, output_path_dataset) + # save metrics as csv (matrix is already a pandas dataframe) + metrics.to_csv(output_path_dataset / "derotation_metrics.csv", index=False) + # make empty error file with open(output_path_dataset / "error.txt", "w") as f: f.write("") except Exception: with open(output_path_dataset / "error.txt", "w") as f: f.write(traceback.format_exc()) - with open(output_path_dataset / "metric.txt", "w") as f: - f.write(f"dataset: {read_dataset_path.stem} metric: NaN") + + # make empty metrics file + with open(output_path_dataset / "derotation_metrics.csv", "w") as f: + f.write("") diff --git a/calcium_imaging_automation/core/rules/suite2p_run.py b/calcium_imaging_automation/core/rules/suite2p_run.py new file mode 100644 index 0000000..e27957f --- /dev/null +++ b/calcium_imaging_automation/core/rules/suite2p_run.py @@ -0,0 +1,54 @@ +import datetime +import traceback +from pathlib import Path + +import numpy as np +from snakemake.script import snakemake +from suite2p import run_s2p + +# Retrieve parameters and inputs from Snakemake +input_path = Path(snakemake.input[0]) +ops_file = snakemake.input[1] 
+
+# load ops
+ops = np.load(ops_file, allow_pickle=True).item()
+ops["save_folder"] = str(dataset_folder)
+ops["save_path0"] = str(dataset_folder)
+ops["fast_disk"] = str(dataset_folder)
+ops["data_path"] = [input_path.parent]
+
+# change ops for non-rigid registration
+ops["nonrigid"] = True
+ops["block_size"] = [64, 64]
+ops["snr_thresh"] = 1.7
+ops["maxregshiftNR"] = 15
+
+# db entries override the matching ops entries inside run_s2p
+db = {"data_path": [str(input_path.parent)]}
+try:
+    assert isinstance(ops, dict), f"ops is not a dict, it is {type(ops)}"
+    assert isinstance(db, dict), f"db is not a dict, it is {type(db)}"
+    ops_end = run_s2p(ops=ops, db=db)
+
+    # get registration metrics from ops
+    metrics = {
+        "regDX": ops_end.get("regDX", "NaN"),
+        "regPC": ops_end.get("regPC", "NaN"),
+        "tPC": ops_end.get("tPC", "NaN"),
+    }
+
+    # write the registration metrics to the metrics file
+    with open(dataset_folder / "suite2p_metrics.txt", "w") as f:
+        f.write("registration metrics: \n")
+        for key, value in metrics.items():
+            f.write(f"{key}: {value}\n")
+    # touch the error file so it always exists
+    with open(dataset_folder / "error.txt", "a") as f:
+        f.write("")
+except Exception:
+    with open(dataset_folder / "error.txt", "a") as f:
+        # add timestamp to the error file
+        f.write(f"Error at {datetime.datetime.now()}\n")
+        f.write(traceback.format_exc())
+    with open(dataset_folder / "suite2p_metrics.txt", "w") as f:
+        f.write("registration metrics: NaN\n")
diff --git a/workflow/Snakefile b/workflow/Snakefile
index 677d5cd..f944ee4 100644
--- a/workflow/Snakefile
+++ b/workflow/Snakefile
@@ -12,6 +12,9 @@ datasets.sort()
 # for the output
 datasets_no_underscore = [ds.replace("_", "") for ds in datasets]
 
+valid_indices = [0, 1, 6, 7, 8, 9, 11, 13, 15, 16, 18]
+subsample_datasets = [datasets_no_underscore[i] for i in valid_indices]
+
 # -----------------------------------------------------
 # Final state of the pipeline
 # Are all the output files present?
@@ -19,15 +22,20 @@ rule all:
     input:
         expand(
             [
-                f"{processed_data_base}/sub-{{index}}_{{datasets_no_underscore}}/ses-0/funcimg/derotation/derotated_full.tif",
-                f"{processed_data_base}/sub-{{index}}_{{datasets_no_underscore}}/ses-0/funcimg/derotation/derotated_full.csv",
-                f"{processed_data_base}/sub-{{index}}_{{datasets_no_underscore}}/ses-0/funcimg/metric.txt",
-                f"{processed_data_base}/sub-{{index}}_{{datasets_no_underscore}}/ses-0/funcimg/error.txt",
+                # f"{processed_data_base}/sub-{{index}}_{{datasets_no_underscore}}/ses-0/funcimg/derotation/derotated_full.tif",
+                # f"{processed_data_base}/sub-{{index}}_{{datasets_no_underscore}}/ses-0/funcimg/derotation/derotated_full.csv",
+                # f"{processed_data_base}/sub-{{index}}_{{datasets_no_underscore}}/ses-0/funcimg/derotation_metrics.csv",
+                # f"{processed_data_base}/sub-{{index}}_{{datasets_no_underscore}}/ses-0/funcimg/error.txt",
+                # f"{processed_data_base}/sub-{{index}}_{{datasets_no_underscore}}/ses-0/funcimg/plane0/stat.npy",
+                f"{processed_data_base}/sub-{{index}}_{{datasets_no_underscore}}/ses-0/traces/dff_example_100.png",
             ],
             zip,
-            index=range(len(datasets)),
-            datasets_no_underscore=datasets_no_underscore,
+            # index=range(len(datasets)),
+            index=valid_indices,
+            # datasets_no_underscore=datasets_no_underscore,
+            datasets_no_underscore=subsample_datasets,
         ),
+        # f"{processed_data_base}/stability_metric.png",
 
 # -----------------------------------------------------
 # Preprocess
@@ -35,21 +43,75 @@ rule preprocess:
     input:
         raw=lambda wildcards: f"{raw_data_base}{datasets[int(wildcards.index)]}/",
     output:
-        report(f"{processed_data_base}/sub-{{index}}_{{datasets_no_underscore}}/ses-0/funcimg/metric.txt"),
-        report(f"{processed_data_base}/sub-{{index}}_{{datasets_no_underscore}}/ses-0/funcimg/error.txt"),
+        f"{processed_data_base}/sub-{{index}}_{{datasets_no_underscore}}/ses-0/funcimg/derotation_metrics.csv",
+        f"{processed_data_base}/sub-{{index}}_{{datasets_no_underscore}}/ses-0/funcimg/error.txt",
         tiff=f"{processed_data_base}/sub-{{index}}_{{datasets_no_underscore}}/ses-0/funcimg/derotation/derotated_full.tif",
        csv=f"{processed_data_base}/sub-{{index}}_{{datasets_no_underscore}}/ses-0/funcimg/derotation/derotated_full.csv",
     params:
         index=lambda wildcards: wildcards.index
     resources:
         partition="fast",
-        mem_mb=16000,
+        mem_mb=32000,
         cpu_per_task=1,
         tasks=1,
         nodes=1,
     script:
         "../calcium_imaging_automation/core/rules/preprocess.py"
+
+# -----------------------------------------------------
+# Suite2p
+rule suite2p:
+    input:
+        f"{processed_data_base}/sub-{{index}}_{{datasets_no_underscore}}/ses-0/funcimg/derotation/derotated_full.tif",
+        f"/ceph/margrie/laura/cimaut/3p_non_rigid_ops.npy",
+    output:
+        f"{processed_data_base}/sub-{{index}}_{{datasets_no_underscore}}/ses-0/funcimg/plane0/stat.npy",
+        f"{processed_data_base}/sub-{{index}}_{{datasets_no_underscore}}/ses-0/funcimg/suite2p/plane0/data.bin",
+    params:
+        index=lambda wildcards: wildcards.index
+    resources:
+        partition="fast",
+        mem_mb=16000,
+        cpu_per_task=1,
+        tasks=1,
+        nodes=1,
+    script:
+        "../calcium_imaging_automation/core/rules/suite2p_run.py"
+
+# -----------------------------------------------------
+# Collect suite2p data and make plots
+rule postprocess:
+    output:
+        # f"{processed_data_base}/stability_metric.png",
+        "workflow/results/data/stability_metric.png",
+    params:
+        datasets=datasets_no_underscore,
+        base_path=processed_data_base
+    resources:
+        partition="fast",
+        mem_mb=8000,
+        cpu_per_task=1,
+        tasks=1,
+        nodes=1,
+    script:
+        "../calcium_imaging_automation/core/rules/postprocess.py"
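+
+# NOTE: postprocess's output is commented out of `rule all` above, so it is
+# not scheduled automatically; request it explicitly, e.g.
+#   snakemake --cores 1 --latency-wait 10 workflow/results/data/stability_metric.png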
"../calcium_imaging_automation/core/rules/postprocess.py" + +rule plot_traces: + input: + f"{processed_data_base}/sub-{{index}}_{{datasets_no_underscore}}/ses-0/funcimg/plane0/F.npy" + output: + f"{processed_data_base}/sub-{{index}}_{{datasets_no_underscore}}/ses-0/traces/dff_example_100.png", + params: + index=lambda wildcards: wildcards.index + resources: + partition="fast", + mem_mb=16000, + cpu_per_task=1, + tasks=1, + nodes=1, + script: + "../calcium_imaging_automation/core/rules/plot_data.py" + # ----------------------------------------------------- # Summarize data for datavzrd report rule summarize_data: