Merge pull request #11 from NxNiki/dev
check how to load whole sleep data.
NxNiki authored Oct 24, 2024
2 parents aeaadbc + 1dc9521 commit ebf6ab4
Showing 10 changed files with 567 additions and 182 deletions.
181 changes: 143 additions & 38 deletions scripts/plot_activation.ipynb

Large diffs are not rendered by default.

104 changes: 54 additions & 50 deletions scripts/save_config.py
@@ -6,57 +6,61 @@
 from brain_decoding.config.config import ExperimentConfig, PipelineConfig
 from brain_decoding.config.file_path import CONFIG_FILE_PATH, DATA_PATH, RESULT_PATH

-if __name__ == "__main__":
-    experiment_config = ExperimentConfig(name="sleep", patient=562)
+# if __name__ == "__main__":
+experiment_config = ExperimentConfig(name="sleep", patient=562)

-    config = PipelineConfig(experiment=experiment_config)
-    config.model.architecture = "multi-vit"
-    config.model.learning_rate = 1e-4
-    config.model.batch_size = 128
-    config.model.weight_decay = 1e-4
-    config.model.epochs = 40
-    config.model.lr_drop = 50
-    config.model.validation_step = 10
-    config.model.early_stop = 75
-    config.model.num_labels = 8
-    config.model.merge_label = True
-    config.model.img_embedding_size = 192
-    config.model.hidden_size = 256
-    config.model.num_hidden_layers = 6
-    config.model.num_attention_heads = 8
-    config.model.patch_size = (1, 5)
-    config.model.intermediate_size = 192 * 2
-    config.model.classifier_proj_size = 192
+config = PipelineConfig(experiment=experiment_config)
+config.model.architecture = "multi-vit"
+config.model.learning_rate = 1e-4
+config.model.batch_size = 128
+config.model.weight_decay = 1e-4
+config.model.epochs = 40
+config.model.lr_drop = 50
+config.model.validation_step = 10
+config.model.early_stop = 75
+config.model.num_labels = 8
+config.model.merge_label = True
+config.model.img_embedding_size = 192
+config.model.hidden_size = 256
+config.model.num_hidden_layers = 6
+config.model.num_attention_heads = 8
+config.model.patch_size = (1, 5)
+config.model.intermediate_size = 192 * 2
+config.model.classifier_proj_size = 192

-    config.experiment.seed = 42
-    config.experiment.use_spike = True
-    config.experiment.use_lfp = False
-    config.experiment.use_combined = False
-    config.experiment.use_shuffle = True
-    config.experiment.use_bipolar = False
-    config.experiment.use_sleep = (
-        True  # set true to use sleep data as inference dataset, otherwise use free recall, is this right?
-    )
-    config.experiment.use_overlap = False
-    config.experiment.use_long_input = False
-    config.experiment.use_spontaneous = False
-    config.experiment.use_augment = False
-    config.experiment.use_shuffle_diagnostic = True
-    config.experiment.testing_mode = True  # in testing mode, a maximum of 1e4 clusterless data will be loaded.
-    config.experiment.model_aggregate_type = "sum"
-    config.experiment.train_phase = ["movie_1"]
-    config.experiment.test_phase = ["sleep_2"]
+config.experiment.seed = 42
+config.experiment.use_spike = True
+config.experiment.use_lfp = False
+config.experiment.use_combined = False
+config.experiment.use_shuffle = True
+config.experiment.use_bipolar = False
+# config.experiment.use_sleep = (
+#     True  # set true to use sleep data as inference dataset, otherwise use free recall, is this right?
+# )
+config.experiment.use_overlap = False
+config.experiment.use_long_input = False
+config.experiment.use_spontaneous = False
+config.experiment.use_augment = False
+config.experiment.use_shuffle_diagnostic = True
+config.experiment.testing_mode = False  # in testing mode, a maximum of 1e4 clusterless data will be loaded.
+config.experiment.model_aggregate_type = "sum"
+config.experiment.train_phases = ["movie_1"]
+config.experiment.test_phases = ["sleep_2"]
+config.experiment.compute_accuracy = False

-    config.data.result_path = str(RESULT_PATH)
-    config.data.spike_path = str(DATA_PATH)
-    config.data.lfp_path = "undefined"
-    config.data.lfp_data_mode = "sf2000-bipolar-region-clean"
-    config.data.spike_data_mode = "notch CAR-quant-neg"
-    config.data.spike_data_mode_inference = "notch CAR-quant-neg"
-    config.data.spike_data_sd = [3.5]
-    config.data.spike_data_sd_inference = 3.5
-    config.data.model_aggregate_type = "sum"
-    config.data.movie_label_path = str(DATA_PATH / "8concepts_merged.npy")
-    config.data.movie_sampling_rate = 30
-    config.experiment.ensure_list("train_phases")
-    config.experiment.ensure_list("test_phases")
-
-    config.export_config(CONFIG_FILE_PATH)
+config.data.result_path = str(RESULT_PATH)
+config.data.spike_path = str(DATA_PATH)
+config.data.lfp_path = "undefined"
+config.data.lfp_data_mode = "sf2000-bipolar-region-clean"
+config.data.spike_data_mode = "notch CAR-quant-neg"
+config.data.spike_data_mode_inference = "notch CAR-quant-neg"
+config.data.spike_data_sd = [3.5]
+config.data.spike_data_sd_inference = 3.5
+config.data.model_aggregate_type = "sum"
+config.data.movie_label_path = str(DATA_PATH / "8concepts_merged.npy")
+config.data.movie_sampling_rate = 30
+
+# config.export_config(CONFIG_FILE_PATH)
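
With the `if __name__ == "__main__":` guard commented out, scripts/save_config.py now builds its `config` object at module import time, and main.py (below) imports it directly. A minimal sketch of that usage, assuming the repository root is on PYTHONPATH so `scripts` resolves as a package:

from scripts.save_config import config

print(config.experiment.train_phases)  # ["movie_1"]
print(config.experiment.test_phases)  # ["sleep_2"]

# Writing the YAML snapshot is now the caller's responsibility,
# since the export_config() call is commented out:
# config.export_config(CONFIG_FILE_PATH)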
10 changes: 8 additions & 2 deletions src/brain_decoding/config/config.py
@@ -37,14 +37,19 @@ def __setattr__(self, name, value):
     def __contains__(self, key: str) -> bool:
         return hasattr(self, key)

+    def __repr__(self):
+        attrs = {k: v for k, v in self.__dict__.items() if not k.startswith("_")}
+        attr_str = "\n".join(f" {key}: {value!r}" for key, value in attrs.items())
+        return f"{self.__class__.__name__}(\n{attr_str}\n)"
+
     def set_alias(self, name: str, alias: str) -> None:
         self.__dict__["_alias"][alias] = name

     def ensure_list(self, name: str):
-        """Mark the field to always be treated as a list"""
         value = getattr(self, name, None)
         if value is not None and not isinstance(value, list):
             setattr(self, name, [value])
+        # Mark the field to always be treated as a list
         self._list_fields.add(name)


@@ -112,8 +117,9 @@ def export_config(self, output_file: Union[str, Path] = "config.yaml") -> None:
         dir_path = output_file.parent
         dir_path.mkdir(parents=True, exist_ok=True)

+        config_data = self.model_dump()
         with open(output_file, "w") as file:
-            yaml.safe_dump(self.model_dump(), file)
+            yaml.safe_dump(config_data, file)

     @property
     def _file_tag(self) -> str:
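
For reference, `ensure_list()` coerces a scalar field into a one-element list and records the field name so it is treated as a list from then on. A self-contained sketch of that behavior, using a toy stand-in for the real config class (only this method is reproduced):

class ListFieldDemo:
    """Toy stand-in for the config class; only ensure_list() is reproduced."""

    def __init__(self):
        self._list_fields = set()
        self.test_phases = "sleep_2"  # deliberately a scalar

    def ensure_list(self, name: str):
        # Coerce a scalar value to a one-element list, then mark the field.
        value = getattr(self, name, None)
        if value is not None and not isinstance(value, list):
            setattr(self, name, [value])
        self._list_fields.add(name)


demo = ListFieldDemo()
demo.ensure_list("test_phases")
assert demo.test_phases == ["sleep_2"]
assert "test_phases" in demo._list_fields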
9 changes: 6 additions & 3 deletions src/brain_decoding/dataloader/save_clusterless.py
@@ -26,6 +26,8 @@

 from brain_decoding.config.file_path import DATA_PATH

+SECONDS_PER_HOUR = 3600
+
 OFFSET = {
     "555_1": 4.58,
     "562_1": 0,
@@ -115,8 +117,9 @@

 # is there a way to select the whole duration?
 SLEEP_TIME = {
-    "562_1": (0, 2 * 3600),  # memory test
-    "562_2": (0, 5 * 3600),  # memory test
+    "562_1": (0, 2 * SECONDS_PER_HOUR),  # memory test
+    "562_2": (0, 5 * SECONDS_PER_HOUR),  # memory test
+    "562_3": (0, 10 * SECONDS_PER_HOUR),  # memory test
 }

 CONTROL = {
@@ -820,7 +823,7 @@ def sort_filename(filename):
 if __name__ == "__main__":
     version = "notch CAR-quant-neg"
     SPIKE_ROOT_PATH = "/Users/XinNiuAdmin/Library/CloudStorage/Box-Box/Vwani_Movie/Clusterless/"
-    get_oneshot_clean("562", 2000, "Experiment6_MovieParadigm_notch", category="sleep", phase=2, version=version)
+    get_oneshot_clean("562", 2000, "Experiment6_MovieParadigm_notch", category="sleep", phase=3, version=version)
     # get_oneshot_clean("562", 2000, "presleep", category="movie", phase=1, version=version)
     # get_oneshot_clean("562", 2000, "presleep", category="recall", phase="FR1", version=version)
     # get_oneshot_clean("562", 2000, "postsleep", category="recall", phase="FR2", version=version)
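
Each SLEEP_TIME entry is a (start, end) window in seconds, so "562_3" now covers up to ten hours of sleep data. Converting such a window to sample indices at the 2000 Hz rate used by the `get_oneshot_clean` call above could look like this (the helper is illustrative, not part of the module):

SECONDS_PER_HOUR = 3600
SLEEP_TIME = {"562_3": (0, 10 * SECONDS_PER_HOUR)}  # seconds, as in the diff


def sleep_window_samples(tag: str, sampling_rate_hz: int) -> tuple[int, int]:
    # Convert a (start, end) window in seconds to sample indices.
    start_s, end_s = SLEEP_TIME[tag]
    return int(start_s * sampling_rate_hz), int(end_s * sampling_rate_hz)


print(sleep_window_samples("562_3", 2000))  # (0, 72000000)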
49 changes: 7 additions & 42 deletions src/brain_decoding/dataloader/test_data.py
@@ -27,56 +27,23 @@ class InferenceDataset(Dataset):
     def __init__(self, config):
         self.config = config
         self.lfp_channel_by_region = {}
-        phases = config.data.phases

         spikes_data = None
         if self.config.experiment["use_spike"]:
             data_path = "spike_path"
-            if self.config.experiment["use_sleep"]:
-                config.experiment["spike_data_mode_inference"] = ""
-                spikes_data = self.read_recording_data(data_path, "time_sleep", phases[0])
-            else:
-                if (
-                    isinstance(self.config.experiment["free_recall_phase"], str)
-                    and "all" in self.config.experiment["free_recall_phase"]
-                ):
-                    for phase in phases:
-                        spikes_data = self.read_recording_data(data_path, "time_recall", phase)
-                elif (
-                    isinstance(self.config.experiment["free_recall_phase"], str)
-                    and "control" in self.config.experiment["free_recall_phase"]
-                ):
-                    spikes_data = self.read_recording_data(data_path, "time", None)
-                elif (
-                    isinstance(self.config.experiment["free_recall_phase"], str)
-                    and "movie" in self.config.experiment["free_recall_phase"]
-                ):
-                    spikes_data = self.read_recording_data(data_path, "time", None)
-                else:
-                    spikes_data = self.read_recording_data(data_path, "time_recall", None)
+            spikes_data = self.read_recording_data(data_path, "time", self.config.experiment.test_phases[0])

         lfp_data = None
         if self.config.experiment["use_lfp"]:
             data_path = "lfp_path"
-            if self.config.experiment.use_sleep:
-                config["spike_data_mode_inference"] = ""
-                lfp_data = self.read_recording_data(data_path, "spectrogram_sleep", "")
-            else:
-                if isinstance(self.config["free_recall_phase"], str) and "all" in self.config["free_recall_phase"]:
-                    for phase in phases:
-                        lfp_data = self.read_recording_data(data_path, "spectrogram_recall", phase)
-                elif (
-                    isinstance(self.config["free_recall_phase"], str) and "control" in self.config["free_recall_phase"]
-                ):
-                    lfp_data = self.read_recording_data(data_path, "spectrogram", None)
-                else:
-                    lfp_data = self.read_recording_data(data_path, "spectrogram_recall", None)
+            lfp_data = self.read_recording_data(data_path, "spectrogram_recall", self.config.experiment.test_phases[0])
         # self.lfp_data = {key: np.concatenate(value_list, axis=0) for key, value_list in self.lfp_data.items()}

         self.data = {"clusterless": spikes_data, "lfp": lfp_data}
         self.data_length = self.get_data_length()
         self.preprocess_data()

-    def read_recording_data(self, root_path: str, file_path_prefix: str, phase: Optional[str]) -> np.ndarray[float]:
+    def read_recording_data(self, root_path: str, file_path_prefix: str, phase: str) -> np.ndarray[float]:
         """
         read spike or lfp data.
@@ -85,10 +52,7 @@ def read_recording_data(self, root_path: str, file_path_prefix: str, phase: Optional[str]) -> np.ndarray[float]:
         :param phase:
         :return:
         """
-        if phase == "":
-            exp_file_path = file_path_prefix
-        else:
-            exp_file_path = f"{file_path_prefix}_{phase}"
+        exp_file_path = f"{file_path_prefix}_{phase}"

         recording_file_path = os.path.join(
             self.config.data[root_path],
@@ -100,7 +64,8 @@
         recording_files = sorted(recording_files, key=sort_file_name)

         if not recording_files:
-            raise ValueError(f"not files found in: {recording_files}")
+            error_msg = f"no files found in: {recording_file_path}"
+            raise ValueError(error_msg)

         if root_path == "spike_path":
             data = self.load_clustless(recording_files)
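
The loader now resolves inference files the same way for every phase: it joins the prefix and phase into "{prefix}_{phase}" and globs under the configured root, with phase names like "sleep_2" carrying the category that the removed branches used to encode. A simplified sketch of the lookup; the directory layout and "*.npz" pattern are assumptions for illustration, and the real path also includes patient and data-mode components elided in this diff:

import glob
import os

spike_path = "/data/spikes"  # stands in for config.data["spike_path"]
prefix = "time"  # as passed from __init__ above
phase = "sleep_2"  # config.experiment.test_phases[0]

exp_file_path = f"{prefix}_{phase}"  # "time_sleep_2"
recording_file_path = os.path.join(spike_path, exp_file_path, "*.npz")
recording_files = sorted(glob.glob(recording_file_path))

if not recording_files:
    raise ValueError(f"no files found in: {recording_file_path}")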
13 changes: 2 additions & 11 deletions src/brain_decoding/dataloader/train_data.py
@@ -58,22 +58,13 @@ def __init__(self, config: PipelineConfig):
         self.smoothed_label = []
         self.lfp_channel_by_region = {}

-        if self.patient in ["564", "565"]:
-            categories = ["Movie_1", "Movie_2"]
-        else:
-            categories = ["Movie_1"]
-
-        if self.use_spontaneous:
-            categories.append("Control1")
-            categories.append("Control2")
-
         # create spike data
         if self.use_spike:
-            self.data["clusterless"] = self.load_data(config.data["spike_path"], categories)
+            self.data["clusterless"] = self.load_data(config.data["spike_path"], config.experiment.train_phases)

         # create lfp data
         if self.use_lfp:
-            self.data["lfp"] = self.load_data(config.data["lfp_path"], categories)
+            self.data["lfp"] = self.load_data(config.data["lfp_path"], config.experiment.train_phases)

         # for c, category in enumerate(categories):
         #     size = sample_size[c]
31 changes: 20 additions & 11 deletions src/brain_decoding/main.py
@@ -17,6 +17,7 @@
 from brain_decoding.config.config import PipelineConfig
 from brain_decoding.config.file_path import CONFIG_FILE_PATH
 from brain_decoding.param.base_param import device
+from scripts.save_config import config

 # torch.autograd.set_detect_anomaly(True)
 # torch.backends.cuda.matmul.allow_tf32=True
@@ -28,17 +29,19 @@


 def set_config(
-    config_file: Union[str, Path],
+    config_file: Union[str, Path, PipelineConfig],
     patient_id: int,
-    phases: Union[List[str], str],
+    train_phases: Union[List[str], str],
+    test_phases: Union[List[str], str],
     spike_data_sd: Union[List[float], float, None] = None,
     spike_data_sd_inference: Optional[float] = None,
 ) -> PipelineConfig:
     """
     set parameters based on config file.
     :param config_file:
     :param patient_id:
-    :param phases:
+    :param train_phases:
+    :param test_phases:
     :param spike_data_sd:
     :param spike_data_sd_inference:
     :return:
@@ -47,15 +50,18 @@ def set_config(
     if isinstance(spike_data_sd, float):
         spike_data_sd = [spike_data_sd]

-    config = PipelineConfig.read_config(config_file)
+    if isinstance(config_file, PipelineConfig):
+        config = config_file
+    else:
+        config = PipelineConfig.read_config(config_file)

     config.experiment["patient"] = patient_id
     config.experiment.name = "8concepts"

-    if isinstance(phases, str):
-        config.data.phases = [phases]
-    else:
-        config.data.phases = phases
+    config.experiment.train_phases = train_phases
+    config.experiment.ensure_list("train_phases")
+
+    config.experiment.test_phases = test_phases
+    config.experiment.ensure_list("test_phases")

     if spike_data_sd is not None:
         config.data.spike_data_sd = spike_data_sd
@@ -110,13 +116,16 @@ def pipeline(config: PipelineConfig) -> Trainer:

 if __name__ == "__main__":
     patient = 562
-    phase = "2"
+    phase_train = "movie_1"
+    phase_test = "sleep_3"
     CONFIG_FILE = CONFIG_FILE_PATH / "config_sleep-None-None_2024-10-16-19:17:43.yaml"

     config = set_config(
-        CONFIG_FILE,
+        # CONFIG_FILE,
+        config,
         patient,
-        phase,
+        phase_train,
+        phase_test,
     )

     print("start: ", patient)
8 changes: 4 additions & 4 deletions src/brain_decoding/trainer.py
@@ -144,8 +144,8 @@ def train(self, epochs, fold):
             )
             print()
             print("WELCOME MEMORY TEST at: ", epoch)
-            stats_m = self.memory(epoch=epoch + 1, phase=self.config.data.phases[0], alongwith=[])
-            # self.memory(1, epoch=epoch+1, phase='all')
+            stats_m = self.memory(epoch=epoch + 1, phase=self.config.experiment.test_phases[0], alongwith=[])

             if stats_m is not None:
                 overall_p = list(stats_m.values())
                 print("P: ", overall_p)
@@ -360,7 +360,7 @@ def memory(self, epoch=-1, phase: str = "free_recall1", alongwith=[]):
         torch.manual_seed(self.config.experiment["seed"])
         np.random.seed(self.config.experiment["seed"])
         random.seed(self.config.experiment["seed"])
-        self.config.experiment["free_recall_phase"] = phase
+        # self.config.experiment["test_phase"] = phase
         dataloaders = initialize_inference_dataloaders(self.config)
         model = initialize_model(self.config)
         # model = torch.compile(model)
@@ -448,7 +448,7 @@ def memory(self, epoch=-1, phase: str = "free_recall1", alongwith=[]):
             predictions = predictions[:, 0:8]

         # Perform Statistic Method
-        if not self.config.experiment["use_sleep"]:
+        if self.config.experiment["compute_accuracy"]:
             sts = Permutate(
                 config=self.config,
                 phase=phase,
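
The permutation statistics are now gated by the explicit `compute_accuracy` flag instead of being inferred from `use_sleep`. A toy illustration of the new control flow; `score` is a hypothetical stand-in for the Permutate step:

def score(predictions):
    # Hypothetical stand-in for the Permutate(...) statistics step.
    return {"p_value": 0.05}


def maybe_score(predictions, compute_accuracy: bool):
    # Scoring is opt-in: sleep inference has no ground-truth labels to test against.
    return score(predictions) if compute_accuracy else None


assert maybe_score([], compute_accuracy=False) is None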
