Skip to content

Commit

Permalink
Merge pull request #415 from MannLabs/refactor_config2
Browse files (browse the repository at this point in the history)
Refactor config2
  • Loading branch information
mschwoer authored Jan 9, 2025
2 parents 8f32e51 + 7b1ad96 commit 15ee261
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 113 deletions.
28 changes: 7 additions & 21 deletions alphadia/search_step.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -32,7 +32,6 @@ def __init__(
config: dict | Config | None = None,
cli_config: dict | None = None,
extra_config: dict | None = None,
config_base_path: str | None = None,
) -> None:
"""Highest level class to plan a DIA search step.
Expand All @@ -46,24 +45,22 @@ def __init__(
output folder to save the results
config : dict, optional
values to update the default config. Overrides values in `default.yaml` and `config_base_path`.
values to update the default config. Overrides values in `default.yaml`.
cli_config : dict, optional
additional config values (parameters from the command line). Overrides values in `config`.
extra_config : dict, optional
additional config values (parameters to orchestrate multistep searches). Overrides values in `config` and `cli_config`.
config_base_path : str, optional
absolute path to yaml file containing additional config values. Overrides values in `default.yaml`.
"""

self.output_folder = output_folder
os.makedirs(output_folder, exist_ok=True)
reporting.init_logging(self.output_folder)

self._config = self._init_config(
config, cli_config, extra_config, output_folder, config_base_path
config, cli_config, extra_config, output_folder
)
logger.setLevel(logging.getLevelName(self._config["general"]["log_level"]))

Expand All @@ -86,7 +83,6 @@ def _init_config(
cli_config: dict | None,
extra_config: dict | None,
output_folder: str,
config_base_path: str | None,
) -> Config:
"""Initialize the config with default values and update with user defined values."""

Expand All @@ -98,18 +94,12 @@ def _init_config(
config.from_yaml(default_config_path)

config_updates = []
if config_base_path is not None:
logger.info(f"loading additional config from {config_base_path}")
user_config_from_file = Config(USER_DEFINED)
user_config_from_file.from_yaml(default_config_path)
config_updates.append(user_config_from_file)

if user_config is not None:
logger.info("loading additional config provided via CLI")
# load update config from dict
if isinstance(user_config, dict):
user_config_update = Config(USER_DEFINED)
user_config_update.from_dict(user_config)
user_config_update = Config(user_config, name=USER_DEFINED)
config_updates.append(user_config_update)
elif isinstance(user_config, Config):
config_updates.append(user_config)
Expand All @@ -120,14 +110,12 @@ def _init_config(

if cli_config is not None:
logger.info("loading additional config provided via CLI parameters")
cli_config_update = Config(USER_DEFINED_CLI_PARAM)
cli_config_update.from_dict(cli_config)
cli_config_update = Config(cli_config, name=USER_DEFINED_CLI_PARAM)
config_updates.append(cli_config_update)

# this needs to be last
if extra_config is not None:
extra_config_update = Config(MULTISTEP_SEARCH)
extra_config_update.from_dict(extra_config)
extra_config_update = Config(extra_config, name=MULTISTEP_SEARCH)
# need to overwrite user-defined output folder here to have correct value in config dump
extra_config[ConfigKeys.OUTPUT_DIRECTORY] = output_folder
config_updates.append(extra_config_update)
Expand Down Expand Up @@ -304,12 +292,10 @@ def run(
workflow = self._process_raw_file(dia_path, raw_name, speclib)
workflow_folder_list.append(workflow.path)

except CustomError as e:
_log_exception_event(e, raw_name, workflow)
continue

except Exception as e:
_log_exception_event(e, raw_name, workflow)
if isinstance(e, CustomError):
continue
raise e

finally:
Expand Down
62 changes: 21 additions & 41 deletions alphadia/workflow/config.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -9,9 +9,8 @@

import json
import logging
from collections import defaultdict
from collections import UserDict, defaultdict
from copy import deepcopy
from typing import Any

import yaml

Expand All @@ -25,54 +24,37 @@
MULTISTEP_SEARCH = "multistep search"


class Config:
"""Class holding a configuration.
class Config(UserDict):
"""Dict-like config class that can read from and write to yaml and json files and allows updating with other config objects."""

Can read from and write to yaml and json files.
Can be used to update the config with other config objects, and print the config in a tree structure.
"""

def __init__(self, name: str = "default") -> None:
def __init__(self, data: dict = None, name: str = DEFAULT) -> None:
# superclass __init__ deliberately not called, as it calls "update" (which we override)
self.data = (
{**data} if data is not None else {}
) # this needs to be called 'data' as we inherit from UserDict
self.name = name
self.config = {}
self.translated_config = {}

def from_yaml(self, path: str) -> None:
with open(path) as f:
self.config = yaml.safe_load(f)
self.data = yaml.safe_load(f)

def from_json(self, path: str) -> None:
with open(path) as f:
self.config = json.load(f)
self.data = json.load(f)

def to_yaml(self, path: str) -> None:
with open(path, "w") as f:
yaml.dump(self.config, f, sort_keys=False)
yaml.dump(self.data, f, sort_keys=False)

def to_json(self, path: str) -> None:
with open(path, "w") as f:
json.dump(self.config, f)

def from_dict(self, config: dict[str, Any]) -> None:
self.config = config

def to_dict(self) -> dict[str, Any]:
return self.config

def get(self, key: str, default: Any = None) -> Any:
return self.config.get(key, default)
json.dump(self.data, f)

def __getitem__(self, key: str) -> Any:
return self.config[key]
def __setitem__(self, key, item):
raise NotImplementedError("Use update() to update the config.")

def __setitem__(self, key: str, value: Any) -> None:
self.config[key] = value

def __contains__(self, key: str) -> bool:
return key in self.config

def __repr__(self) -> str:
return str(self.config)
def __delitem__(self, key):
raise NotImplementedError("Use update() to update the config.")

def update(self, configs: list["Config"], do_print: bool = False):
"""
Expand All @@ -92,29 +74,27 @@ def update(self, configs: list["Config"], do_print: bool = False):
do_print : bool, optional
Whether to print the modified config. Default is False.
"""

# we assume that self.config holds the default config
default_config = deepcopy(self.config)
# we assume that self.data holds the default config
default_config = deepcopy(self.data)

def _recursive_defaultdict():
"""Allow initialization of an infinitely nested dictionary to be able to map arbitrary structures."""
return defaultdict(_recursive_defaultdict)

tracking_dict = defaultdict(_recursive_defaultdict)

current_config = deepcopy(self.config)

current_config = deepcopy(self.data)
for config in configs:
logger.info(f"Updating config with '{config.name}'")

_update(
current_config,
config.to_dict(),
config.data,
tracking_dict,
config.name,
)

self.config = current_config
self.data = current_config

if do_print:
try:
Expand Down
6 changes: 3 additions & 3 deletions tests/e2e_tests/prepare_test_data.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -26,9 +26,9 @@ class YamlKeys:
TEST_CASES = "test_cases"
NAME = "name"
CONFIG = "config"
LIBRARY = "library"
FASTA = "fasta"
RAW_DATA = "raw_data"
LIBRARY = "library_path"
FASTA = "fasta_paths"
RAW_DATA = "raw_paths"
SOURCE_URL = "source_url"


Expand Down
Loading

0 comments on commit 15ee261

Please sign in to comment.