Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow overwrite when schemas refer to the same tool #175

Merged
merged 5 commits into from
May 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions src/validate_pyproject/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,16 @@ def __init__(self, plugins: Sequence["PluginProtocol"] = ()):

# Add tools using Plugins
for plugin in plugins:
allow_overwrite: Optional[str] = None
if plugin.tool in tool_properties:
_logger.warning(f"{plugin.id} overwrites `tool.{plugin.tool}` schema")
allow_overwrite = plugin.schema.get("$id")
else:
_logger.info(f"{plugin.id} defines `tool.{plugin.tool}` schema")
sid = self._ensure_compatibility(plugin.tool, plugin.schema)["$id"]
compatible = self._ensure_compatibility(
plugin.tool, plugin.schema, allow_overwrite
)
sid = compatible["$id"]
sref = f"{sid}#{plugin.fragment}" if plugin.fragment else sid
tool_properties[plugin.tool] = {"$ref": sref}
self._schemas[sid] = (f"tool.{plugin.tool}", plugin.id, plugin.schema)
Expand All @@ -133,11 +138,13 @@ def main(self) -> str:
"""Top level schema for validating a ``pyproject.toml`` file"""
return self._main_id

def _ensure_compatibility(self, reference: str, schema: Schema) -> Schema:
if "$id" not in schema:
def _ensure_compatibility(
self, reference: str, schema: Schema, allow_overwrite: Optional[str] = None
) -> Schema:
if "$id" not in schema or not schema["$id"]:
raise errors.SchemaMissingId(reference)
sid = schema["$id"]
if sid in self._schemas:
if sid in self._schemas and sid != allow_overwrite:
raise errors.SchemaWithDuplicatedId(sid)
version = schema.get("$schema")
# Support schemas with missing trailing # (incorrect, but required before 0.15)
Expand Down
4 changes: 3 additions & 1 deletion src/validate_pyproject/plugins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,9 @@ def iterate_entry_points(group: str = ENTRYPOINT_GROUP) -> Iterable[EntryPoint]:
# TODO: Once Python 3.10 becomes the oldest version supported, this fallback and
# conditional statement can be removed.
entries_ = (plugin for plugin in entries.get(group, []))
deduplicated = {e.name: e for e in sorted(entries_, key=lambda e: e.name)}
deduplicated = {
e.name: e for e in sorted(entries_, key=lambda e: (e.name, e.value))
}
return list(deduplicated.values())


Expand Down
12 changes: 10 additions & 2 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,19 @@ def test_incompatible_versions(self):
with pytest.raises(errors.InvalidSchemaVersion):
api.SchemaRegistry([plg])

def test_duplicated_id(self):
plg = [plugins.PluginWrapper("plg", self.fake_plugin) for _ in range(2)]
def test_duplicated_id_different_tools(self):
schema = self.fake_plugin("plg")
fn = wraps(self.fake_plugin)(lambda _: schema) # Same ID
plg = [plugins.PluginWrapper(f"plg{i}", fn) for i in range(2)]
with pytest.raises(errors.SchemaWithDuplicatedId):
api.SchemaRegistry(plg)

def test_allow_overwrite_same_tool(self):
plg = [plugins.PluginWrapper("plg", self.fake_plugin) for _ in range(2)]
registry = api.SchemaRegistry(plg)
sid = self.fake_plugin("plg")["$id"]
assert sid in registry

def test_missing_id(self):
def _fake_plugin(name):
plg = dict(self.fake_plugin(name))
Expand Down