Commit

Merge pull request #12 from NxNiki/dev
add analysis
NxNiki authored Nov 13, 2024
2 parents ebf6ab4 + 64b253f commit de7811c
Showing 31 changed files with 2,201 additions and 1,101 deletions.
1 change: 1 addition & 0 deletions .gitattributes
@@ -0,0 +1 @@
results/**/*.html filter=lfs diff=lfs merge=lfs -text
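(This is the attribute line that a command such as `git lfs track "results/**/*.html"` would typically write: the result HTML reports are stored via Git LFS rather than in the regular object database.)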
7 changes: 6 additions & 1 deletion .gitignore
@@ -35,5 +35,10 @@ src/brain_decoding/__pycache__/
data/
._data
config/*.yaml
-results/

+results/**/*.npy
+results/**/*.png
+results/**/*.tar
+!results/**/*.html

wandb/
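(The blanket `results/` ignore is replaced with per-extension patterns, likely because Git cannot re-include a file, here `!results/**/*.html`, when a parent directory itself is excluded; ignoring by file type keeps the LFS-tracked HTML reports trackable.)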
25 changes: 25 additions & 0 deletions .run/run_twilight_merge.run.xml
@@ -0,0 +1,25 @@
<component name="ProjectRunConfigurationManager">
<configuration default="false" name="run_twilight_merge" type="PythonConfigurationType" factoryName="Python">
<module name="brain_decoding" />
<option name="ENV_FILES" value="" />
<option name="INTERPRETER_OPTIONS" value="" />
<option name="PARENT_ENVS" value="true" />
<envs>
<env name="PYTHONUNBUFFERED" value="1" />
</envs>
<option name="SDK_HOME" value="" />
<option name="SDK_NAME" value="movie_decoding" />
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
<option name="IS_MODULE_SDK" value="false" />
<option name="ADD_CONTENT_ROOTS" value="true" />
<option name="ADD_SOURCE_ROOTS" value="true" />
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/src/scripts/run_model_twilight_merge.py" />
<option name="PARAMETERS" value="" />
<option name="SHOW_COMMAND_LINE" value="false" />
<option name="EMULATE_TERMINAL" value="false" />
<option name="MODULE_MODE" value="false" />
<option name="REDIRECT_INPUT" value="false" />
<option name="INPUT_FILE" value="" />
<method v="2" />
</configuration>
</component>
25 changes: 25 additions & 0 deletions .run/run_twilight_vs_24.run.xml
@@ -0,0 +1,25 @@
<component name="ProjectRunConfigurationManager">
<configuration default="false" name="run_twilight_vs_24" type="PythonConfigurationType" factoryName="Python">
<module name="brain_decoding" />
<option name="ENV_FILES" value="" />
<option name="INTERPRETER_OPTIONS" value="" />
<option name="PARENT_ENVS" value="true" />
<envs>
<env name="PYTHONUNBUFFERED" value="1" />
</envs>
<option name="SDK_HOME" value="" />
<option name="SDK_NAME" value="movie_decoding" />
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
<option name="IS_MODULE_SDK" value="false" />
<option name="ADD_CONTENT_ROOTS" value="true" />
<option name="ADD_SOURCE_ROOTS" value="true" />
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/src/scripts/run_model_twilight_vs_24.py" />
<option name="PARAMETERS" value="" />
<option name="SHOW_COMMAND_LINE" value="false" />
<option name="EMULATE_TERMINAL" value="false" />
<option name="MODULE_MODE" value="false" />
<option name="REDIRECT_INPUT" value="false" />
<option name="INPUT_FILE" value="" />
<method v="2" />
</configuration>
</component>
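(The two run configurations are identical apart from their names and SCRIPT_NAME targets: run_model_twilight_merge.py versus run_model_twilight_vs_24.py.)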
2 Git LFS files not shown
236 changes: 0 additions & 236 deletions scripts/plot_activation.ipynb

This file was deleted.

46 changes: 46 additions & 0 deletions src/brain_decoding/config/ensure_list.py
@@ -0,0 +1,46 @@
from pydantic import BaseModel


class Config(BaseModel):
    class Config:
        # arbitrary_types_allowed = True
        extra = "allow"  # Allow arbitrary attributes

    _list_fields = set()  # Tracks which fields should be treated as lists (class-level, so shared across instances)

    def ensure_list(self, name: str):
        value = getattr(self, name, None)
        if value is not None and not isinstance(value, list):
            setattr(self, name, [value])
        # Mark the field to always be treated as a list
        self._list_fields.add(name)

    def __setattr__(self, name, value):
        if name in self._list_fields and not isinstance(value, list):
            # Automatically convert to a list if it's in the list fields
            value = [value]
        super().__setattr__(name, value)


class SupConfig(Config):
    pass


# Example usage
config = SupConfig()

# Dynamically adding attributes
config.param1 = "a"

# Ensuring param1 is a list
config.ensure_list("param1")
print(config.param1)  # Output: ['a']

# Assigning a new value to param1
config.param1 = "ab"
print(config.param1)  # Output: ['ab'] (the bare string is automatically converted to a list)

# Adding another parameter and ensuring it's a list
config.ensure_list("param2")
config.param2 = 123
print(config.param2)  # Output: [123]
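One caveat worth noting: `_list_fields` is a class-level set, so it is shared by every instance and subclass. A minimal sketch of the resulting behavior, assuming the classes above:

a = SupConfig()
b = SupConfig()

a.ensure_list("param1")            # mutates the shared class-level set
print("param1" in b._list_fields)  # True: the marker leaks to the other instance
b.param1 = "x"
print(b.param1)                    # ['x'], although ensure_list was never called on b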
64 changes: 64 additions & 0 deletions src/brain_decoding/config/export_extra.py
@@ -0,0 +1,64 @@
from typing import Any, Dict, Set

from pydantic import BaseModel, Field


class BaseConfig(BaseModel):
    class Config:
        extra = "allow"  # Allow arbitrary attributes

    def __init__(self, **data: Any) -> None:
        super().__init__(**data)
        self.__dict__["_list_fields"]: Set[str] = set()
        self.__dict__["_alias"]: Dict[str, str] = {}

    def __getitem__(self, key: str) -> Any:
        return getattr(self, key)

    def __setitem__(self, key: str, value: Any):
        setattr(self, key, value)

    def __getattr__(self, name):
        """Handles alias access and custom parameters."""
        if name in self._alias:
            return getattr(self, self._alias[name])
        # Defer to pydantic's own lookup so unknown names raise AttributeError
        # instead of silently returning None.
        return super().__getattr__(name)

    def __setattr__(self, name, value):
        """Handles alias resolution and list coercion before normal assignment."""
        if name in self._alias:
            name = self._alias[name]
        if name in self._list_fields and not isinstance(value, list):
            value = [value]
        super().__setattr__(name, value)

    def __contains__(self, key: str) -> bool:
        return hasattr(self, key)

    def __repr__(self):
        attrs = {k: v for k, v in self.__dict__.items() if not k.startswith("_")}
        attr_str = "\n".join(f" {key}: {value!r}" for key, value in attrs.items())
        return f"{self.__class__.__name__}(\n{attr_str}\n)"

    def set_alias(self, name: str, alias: str) -> None:
        self.__dict__["_alias"][alias] = name

    def ensure_list(self, name: str):
        """Mark the field to always be treated as a list."""
        value = getattr(self, name, None)
        if value is not None and not isinstance(value, list):
            setattr(self, name, [value])
        self._list_fields.add(name)


class Foo(BaseConfig):
    a: int = 1

    class Config:
        extra = "allow"


print(Foo(**{"a": 1, "b": 2}).model_dump())  # == {'a': 1, 'b': 2}

foo = Foo()
foo.b = 2
print(foo.model_dump())
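For illustration, a minimal sketch of the alias mechanism, assuming the `Foo` model above and pydantic v2 (consistent with the file's own `model_dump()` call):

cfg = Foo()
cfg.set_alias("a", "alpha")  # "alpha" becomes an alias for the real field "a"
cfg.alpha = 5                # __setattr__ routes the assignment to "a"
print(cfg.a)                 # 5
print(cfg.alpha)             # 5, resolved via __getattr__
cfg.ensure_list("a")
cfg.alpha = 7
print(cfg.a)                 # [7]: alias assignment still applies the list coercion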
4 changes: 4 additions & 0 deletions src/brain_decoding/config/file_path.py
@@ -6,3 +6,7 @@
SURROGATE_FILE_PATH = ROOT_PATH / "data/surrogate_windows"
CONFIG_FILE_PATH = ROOT_PATH / "config"
RESULT_PATH = ROOT_PATH / "results"
MOVIE24_LABEL_PATH = f"{DATA_PATH}/8concepts_merged.npy"
TWILIGHT_LABEL_PATH = f"{DATA_PATH}/twilight_concepts.npy"
TWILIGHT_MERGE_LABEL_PATH = f"{DATA_PATH}/twilight_concepts_merged.npy"
MOVIE_LABEL_TWILIGHT_VS_24 = f"{DATA_PATH}/twilight_vs_24.npy"
(Additional changed file; the file name is not shown in the capture.)
@@ -3,6 +3,8 @@
Custom parameters can be added to any of the three fields of config (experiment, model, data).
"""

from torch import nn

from brain_decoding.config.config import ExperimentConfig, PipelineConfig
from brain_decoding.config.file_path import CONFIG_FILE_PATH, DATA_PATH, RESULT_PATH

@@ -18,7 +20,7 @@
config.model.lr_drop = 50
config.model.validation_step = 10
config.model.early_stop = 75
-config.model.num_labels = 8
+config.model.num_labels = 18  # 8 for the movie '24', 18 for 'Twilight'
config.model.merge_label = True
config.model.img_embedding_size = 192
config.model.hidden_size = 256
@@ -27,6 +29,7 @@
config.model.patch_size = (1, 5)
config.model.intermediate_size = 192 * 2
config.model.classifier_proj_size = 192
config.model.train_loss = nn.BCEWithLogitsLoss(reduction="none")

config.experiment.seed = 42
config.experiment.use_spike = True
@@ -44,8 +47,8 @@
config.experiment.use_shuffle_diagnostic = True
config.experiment.testing_mode = False  # in testing mode, at most 1e4 clusterless samples will be loaded.
config.experiment.model_aggregate_type = "sum"
-config.experiment.train_phases = ["movie_1"]
-config.experiment.test_phases = ["sleep_2"]
+config.experiment.train_phases = ["twilight_1"]
+config.experiment.test_phases = ["sleep_1"]
config.experiment.compute_accuracy = False

config.experiment.ensure_list("train_phases")
@@ -61,6 +64,8 @@
config.data.spike_data_sd_inference = 3.5
config.data.model_aggregate_type = "sum"
config.data.movie_label_path = str(DATA_PATH / "8concepts_merged.npy")
config.data.movie_label_sr = 1
config.data.movie_sampling_rate = 30
config.data.filter_low_occurrence_samples = True

# config.export_config(CONFIG_FILE_PATH)
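As the module docstring says, custom parameters can be attached to any of the three config sections; a hypothetical example (the parameter name is illustrative only, not part of the repo):

config.experiment.smoothing_window = 5             # custom parameter, stored via extra="allow"
config.experiment.ensure_list("smoothing_window")  # treated as [5] from here on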
12 changes: 7 additions & 5 deletions src/brain_decoding/dataloader/clusterless_clean.py
@@ -1,6 +1,7 @@
import glob
import os
import re
from typing import List, Union

import numpy as np
import pandas as pd
@@ -16,6 +17,11 @@ def __init__(self, operator: str, threshold: int):
        self.threshold = threshold


def sort_file_name(filenames: str) -> List[Union[int, str]]:
    """Extract the numeric part of the filename and use it as the sort key."""
    return [int(x) if x.isdigit() else x for x in re.findall(r"\d+|\D+", filenames)]


def find_true_indices(mask, op_thresh: OpThresh = None):
    """
    Returns an nx3 matrix containing start, end, and length of all true samples in a 1D boolean mask.
@@ -161,14 +167,10 @@ def load_data_from_bundle(clu_bundle_filepaths):


def get_oneshot_clean(patient_number, desired_samplerate, mode, category="recall", phase=None, version="notch"):
-    def sort_filename(filename):
-        """Extract the numeric part of the filename and use it as the sort key"""
-        return [int(x) if x.isdigit() else x for x in re.findall(r"\d+|\D+", filename)]
-
    # this folder contains the clusterless data; the folder downloaded from the drive is saved as '562/clustless_raw'
    spike_path = f"/mnt/SSD2/yyding/Datasets/neuron/spike_data/{patient_number}/raw_{mode}/"
    spike_files = glob.glob(os.path.join(spike_path, "*.csv"))
-    spike_files = sorted(spike_files, key=sort_filename)
+    spike_files = sorted(spike_files, key=sort_file_name)

    for bundle in range(0, len(spike_files), 8):
        df = load_data_from_bundle(spike_files[bundle : bundle + 8])
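For reference, a small sketch of what the natural-sort key buys over plain lexicographic sorting (the file names are hypothetical):

files = ["channel_10.csv", "channel_2.csv", "channel_1.csv"]
print(sorted(files, key=sort_file_name))
# ['channel_1.csv', 'channel_2.csv', 'channel_10.csv'], not the lexicographic 1, 10, 2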
(The remaining changed-file diffs are not shown.)
