diff --git a/tests/test_plugins.py b/tests/test_plugins.py index 73869db..2cb536f 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -5,6 +5,7 @@ from variantlib.base import PluginBase from variantlib.config import KeyConfig, ProviderConfig +from variantlib.meta import VariantDescription, VariantMeta from variantlib.plugins import PluginLoader @@ -20,6 +21,12 @@ def get_supported_configs(self) -> Optional[ProviderConfig]: ], ) + def get_variant_labels(self, variant_desc: VariantDescription) -> list[str]: + for meta in variant_desc: + if meta.namespace == self.namespace and meta.key == "key1": + return [meta.value.removeprefix("val")] + return [] + # NB: this plugin deliberately does not inherit from PluginBase # to test that we don't rely on that inheritance @@ -34,9 +41,28 @@ def get_supported_configs(self) -> Optional[ProviderConfig]: ], ) + def get_variant_labels(self, variant_desc: VariantDescription) -> list[str]: + if VariantMeta(self.namespace, "key3", "val3a") in variant_desc: + return ["sec"] + return [] + class MockedPluginC(PluginBase): - namespace = "incompatible_plugin" + namespace = "other_plugin" + + def get_supported_configs(self) -> Optional[ProviderConfig]: + return None + + def get_variant_labels(self, variant_desc: VariantDescription) -> list[str]: + ret = [] + for meta in variant_desc: + if meta.namespace == self.namespace and meta.value == "on": + ret.append(meta.key) + return ret + + +class MockedPluginD: + namespace = "plugin_without_labels" def get_supported_configs(self) -> Optional[ProviderConfig]: return None @@ -48,6 +74,13 @@ class ClashingPlugin(PluginBase): def get_supported_configs(self) -> Optional[ProviderConfig]: return None + def get_variant_labels(self, variant_desc: VariantDescription) -> list[str]: + ret = [] + for meta in variant_desc: + if meta.namespace == self.namespace and meta.value == "on": + ret.append(meta.key) + return ret + @dataclass class MockedDistribution: @@ -87,6 +120,11 @@ def mocked_plugin_loader(session_mocker): value="tests.test_plugins:MockedPluginC", plugin=MockedPluginC, ), + MockedEntryPoint( + name="no_labels", + value="tests.test_plugins:MockedPluginD", + plugin=MockedPluginD, + ), ] yield PluginLoader() @@ -137,3 +175,36 @@ def test_namespace_clash(mocker): assert "same namespace test_plugin" in str(exc) assert "test-plugin" in str(exc) assert "clashing-plugin" in str(exc) + + +@pytest.mark.parametrize("variant_desc,expected", +[ + (VariantDescription([ + VariantMeta("test_plugin", "key1", "val1a"), + VariantMeta("test_plugin", "key2", "val2b"), + VariantMeta("second_plugin", "key3", "val3a"), + VariantMeta("other_plugin", "flag2", "on"), + ]), ["1a", "sec", "flag2"]), + (VariantDescription([ + # note that VariantMetas don't actually have to be supported + # by the system in question -- we could be cross-building + # for another system + VariantMeta("test_plugin", "key1", "val1f"), + VariantMeta("test_plugin", "key2", "val2b"), + VariantMeta("second_plugin", "key3", "val3a"), + ]), ["1f", "sec"]), + (VariantDescription([ + VariantMeta("test_plugin", "key2", "val2b"), + VariantMeta("second_plugin", "key3", "val3a"), + ]), ["sec"]), + (VariantDescription([ + VariantMeta("test_plugin", "key2", "val2b"), + ]), []), + (VariantDescription([ + VariantMeta("test_plugin", "key2", "val2b"), + VariantMeta("other_plugin", "flag1", "on"), + VariantMeta("other_plugin", "flag2", "on"), + ]), ["flag1", "flag2"]), +]) +def test_get_variant_labels(mocked_plugin_loader, variant_desc, expected): + assert mocked_plugin_loader.get_variant_labels(variant_desc) == expected diff --git a/variantlib/base.py b/variantlib/base.py index 80b17c1..480a311 100644 --- a/variantlib/base.py +++ b/variantlib/base.py @@ -2,9 +2,9 @@ from typing import Protocol, runtime_checkable from variantlib.config import ProviderConfig +from variantlib.meta import VariantDescription -@runtime_checkable class PluginType(Protocol): """A protocol for plugin classes""" @@ -17,6 +17,10 @@ def get_supported_configs(self) -> ProviderConfig: """Get supported configs for the current system""" ... + def get_variant_labels(self, variant_desc: VariantDescription) -> list[str]: + """Get list of short labels to describe the variant""" + ... + class PluginBase(ABC): """An abstract base class that can be used to implement plugins""" @@ -26,3 +30,6 @@ def namespace(self) -> str: ... @abstractmethod def get_supported_configs(self) -> ProviderConfig: ... + + def get_variant_labels(self, variant_desc: VariantDescription) -> list[str]: + return [] diff --git a/variantlib/plugins.py b/variantlib/plugins.py index 8fe7068..ca3e2c8 100644 --- a/variantlib/plugins.py +++ b/variantlib/plugins.py @@ -44,7 +44,12 @@ def load_plugins(self) -> None: # Instantiate the plugin plugin_instance = plugin_class() - assert isinstance(plugin_instance, PluginType) + + # Check for obligatory members + for attr in ("namespace", "get_supported_configs"): + assert hasattr( + plugin_instance, attr + ), f"Plugin is missing required member: {attr}" except Exception: logging.exception("An unknown error happened - Ignoring plugin") else: @@ -92,3 +97,13 @@ def get_dist_name_mapping(self) -> dict[str, str]: """Get a mapping from plugin names to distribution names""" return self._dist_names + + def get_variant_labels(self, variant_desc: VariantDescription) -> list[str]: + """Get list of short labels to describe the variant""" + + labels = [] + for plugin in self._plugins.values(): + if hasattr(plugin, "get_variant_labels"): + labels += plugin.get_variant_labels(variant_desc) + + return labels