diff --git a/piccolo/query/operators/json.py b/piccolo/query/operators/json.py index ea6d0509..16255b9d 100644 --- a/piccolo/query/operators/json.py +++ b/piccolo/query/operators/json.py @@ -1,9 +1,10 @@ from __future__ import annotations +import json import typing as t from piccolo.querystring import QueryString -from piccolo.utils.encoding import dump_json +from piccolo.utils.encoding import dump_json, load_json if t.TYPE_CHECKING: from piccolo.columns.column_types import JSON @@ -12,9 +13,34 @@ class JSONQueryString(QueryString): def clean_value(self, value: t.Any): - if not isinstance(value, (str, QueryString)): - value = dump_json(value) - return value + """ + We need to pass a valid JSON string to Postgres. + + There are lots of different use cases to account for:: + + # A JSON string is passed in - in which case, leave it. + RecordingStudio.facilities == '{"mixing_desk": true}' + + # A Python dict is passed in - we need to convert this to JSON. + RecordingStudio.facilities == {"mixing_desk": True} + + # A string is passed in, but it isn't valid JSON, so we need to + # convert it to a JSON string (i.e. '"Alice Jones"'). + RecordingStudio.facilities["technicians"][0]["name"] == "Alice Jones" + + """ # noqa: E501 + if isinstance(value, QueryString): + return value + elif isinstance(value, str): + # The string might already be valid JSON, in which case, leave it. + try: + load_json(value) + except json.JSONDecodeError: + pass + else: + return value + + return dump_json(value) def __eq__(self, value) -> QueryString: # type: ignore[override] value = self.clean_value(value) diff --git a/tests/columns/test_jsonb.py b/tests/columns/test_jsonb.py index f38c0de0..159239a1 100644 --- a/tests/columns/test_jsonb.py +++ b/tests/columns/test_jsonb.py @@ -263,9 +263,9 @@ class TestFromPath(AsyncTableTest): tables = [RecordingStudio, Instrument] - async def test_from_path(self): + async def test_select(self): """ - Make sure ``from_path`` can be used for complex nested data. + Make sure ``from_path`` can be used when selecting complex nested data. """ await RecordingStudio( name="Abbey Road", @@ -284,3 +284,58 @@ async def test_from_path(self): ).output(load_json=True) assert response is not None self.assertListEqual(response, [{"technician_name": "Alice Jones"}]) + + async def test_where(self): + """ + Make sure ``from_path`` can be used in a ``where`` clause. + """ + await RecordingStudio.insert( + RecordingStudio( + name="Abbey Road", + facilities={ + "restaurant": False, + "mixing_desk": False, + "instruments": {"electric_guitars": 4, "drum_kits": 2}, + "technicians": [ + {"name": "Alice Jones"}, + {"name": "Bob Williams"}, + ], + }, + ), + RecordingStudio( + name="Electric Lady", + facilities={ + "restaurant": True, + "mixing_desk": True, + "instruments": {"electric_guitars": 6, "drum_kits": 3}, + "technicians": [ + {"name": "Frank Smith"}, + ], + }, + ), + ) + + # Test array indexing + response = ( + await RecordingStudio.select(RecordingStudio.name) + .where( + RecordingStudio.facilities.from_path( + ["technicians", 0, "name"] + ) + == "Alice Jones" + ) + .output(load_json=True) + ) + assert response is not None + self.assertListEqual(response, [{"name": "Abbey Road"}]) + + # Test boolean + response = ( + await RecordingStudio.select(RecordingStudio.name) + .where( + RecordingStudio.facilities.from_path(["restaurant"]).eq(True) + ) + .output(load_json=True) + ) + assert response is not None + self.assertListEqual(response, [{"name": "Electric Lady"}])