Skip to content

Commit

Permalink
Add config support for enum values
Browse files Browse the repository at this point in the history
  • Loading branch information
grahamtt committed Dec 13, 2023
1 parent 561285b commit ef27006
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 1 deletion.
25 changes: 24 additions & 1 deletion prosper_shared/omni_config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
import argparse
from copy import deepcopy
from decimal import Decimal
from enum import Enum, EnumMeta
from importlib.util import find_spec
from numbers import Number
from os import getcwd
from os.path import join
from typing import List, Union
from typing import List, Optional, Type, Union

import dpath
from platformdirs import user_config_dir
Expand Down Expand Up @@ -120,6 +121,28 @@ def get_as_bool(self, key: str, default: bool = False):

return False

def get_as_enum(
self, key: str, enum_type: EnumMeta, default: Optional[Enum] = None
) -> Optional[Enum]:
"""Gets a config value by enum name or value.
Args:
key (str): The named config to get.
enum_type (EnumMeta): Interpret the resulting value as an enum of this type.
default (Optional[Enum]): The value to return if the config key doesn't exist.
Returns:
Optional[Enum]: The config value interpreted as the given enum type or the default value.
"""
value = self.get(key)
if value is None:
return default

if value in enum_type.__members__.keys():
return enum_type[value]

return enum_type(value)

@classmethod
def autoconfig(
cls,
Expand Down
25 changes: 25 additions & 0 deletions tests/omni_config/test_config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from decimal import Decimal
from enum import Enum
from os import getcwd
from os.path import join

Expand Down Expand Up @@ -41,6 +42,10 @@
}


class TestEnum(Enum):
KEY = "value"


class TestConfig:
def test_get(self):
config = Config(config_dict=TEST_CONFIG, schema=TEST_SCHEMA)
Expand Down Expand Up @@ -84,6 +89,26 @@ def test_get_as_decimal(self):
"testSection.testDecimalNotFound", Decimal("0")
) == Decimal("0")

@pytest.mark.parametrize(
["config_value", "given_default", "expected_value"],
[
("KEY", None, TestEnum.KEY),
("value", None, TestEnum.KEY),
(None, TestEnum.KEY, TestEnum.KEY),
],
)
def test_get_as_enum_happy(self, config_value, given_default, expected_value):
config = Config(config_dict={"enum_config": config_value})
actual_value = config.get_as_enum(
"enum_config", TestEnum, default=given_default
)
assert actual_value == expected_value

def test_get_as_enum_unhappy(self):
config = Config(config_dict={"enum_config": "BAD_KEY"})
with pytest.raises(ValueError):
config.get_as_enum("enum_config", TestEnum)

def test_get_invalid_key(self):
config = Config(config_dict=TEST_CONFIG)

Expand Down

0 comments on commit ef27006

Please sign in to comment.