diff --git a/src/viur/core/bones/json.py b/src/viur/core/bones/json.py index 4aeae0458..621f984ab 100644 --- a/src/viur/core/bones/json.py +++ b/src/viur/core/bones/json.py @@ -1,11 +1,10 @@ import ast import json -import typing as t - import jsonschema - +import typing as t from viur.core.bones.base import ReadFromClientError, ReadFromClientErrorSeverity from viur.core.bones.raw import RawBone +from viur.core import utils class JsonBone(RawBone): @@ -23,9 +22,14 @@ class JsonBone(RawBone): type = "raw.json" - def __init__(self, indexed: bool = False, multiple: bool = False, languages: bool = None, schema: t.Mapping = {}, - *args, - **kwargs): + def __init__( + self, + indexed: bool = False, + multiple: bool = False, + languages: bool = None, + schema: t.Mapping = {}, + *args, **kwargs + ): super().__init__(*args, **kwargs) assert not multiple assert not languages @@ -36,7 +40,7 @@ def __init__(self, indexed: bool = False, multiple: bool = False, languages: boo def serialize(self, skel: 'SkeletonInstance', name: str, parentIndexed: bool) -> bool: if name in skel.accessedValues: - skel.dbEntity[name] = json.dumps(skel.accessedValues[name]) + skel.dbEntity[name] = utils.json.dumps(skel.accessedValues[name]) # Ensure this bone is NOT indexed! skel.dbEntity.exclude_from_indexes.add(name) @@ -47,7 +51,7 @@ def serialize(self, skel: 'SkeletonInstance', name: str, parentIndexed: bool) -> def unserialize(self, skel: 'viur.core.skeleton.SkeletonInstance', name: str) -> bool: if data := skel.dbEntity.get(name): - skel.accessedValues[name] = json.loads(data) + skel.accessedValues[name] = utils.json.loads(data) return True return False @@ -59,7 +63,7 @@ def singleValueFromClient(self, value: str | list | dict, skel, bone_name, clien # Try to parse a JSON string try: - value = json.loads(value) + value = utils.json.loads(value) except json.decoder.JSONDecodeError as e: # Try to parse a Python dict as fallback @@ -76,8 +80,11 @@ def singleValueFromClient(self, value: str | list | dict, skel, bone_name, clien jsonschema.validate(value, self.schema) except (jsonschema.exceptions.ValidationError, jsonschema.exceptions.SchemaError) as e: return self.getEmptyValue(), [ - ReadFromClientError(ReadFromClientErrorSeverity.Invalid, - f"Invalid JSON for schema supplied: {e!s}")] + ReadFromClientError( + ReadFromClientErrorSeverity.Invalid, + f"Invalid JSON for schema supplied: {e!s}") + ] + return super().singleValueFromClient(value, skel, bone_name, client_data) def structure(self) -> dict: diff --git a/src/viur/core/tasks.py b/src/viur/core/tasks.py index a3d705978..9e0e518e6 100644 --- a/src/viur/core/tasks.py +++ b/src/viur/core/tasks.py @@ -1,12 +1,10 @@ import abc -import base64 import datetime import functools import grpc import json import logging import os -import pytz import requests import sys import time @@ -43,52 +41,6 @@ def restore(self, obj: CUSTOM_OBJ) -> None: ... -def _preprocess_json_object(obj): - """ - Add support for db.Key, datetime, bytes and db.Entity in deferred tasks, - and converts the provided obj into a special dict with JSON-serializable values. - """ - if isinstance(obj, db.Key): - return {".__key__": db.encodeKey(obj)} - elif isinstance(obj, datetime.datetime): - return {".__datetime__": obj.astimezone(pytz.UTC).isoformat()} - elif isinstance(obj, bytes): - return {".__bytes__": base64.b64encode(obj).decode("ASCII")} - elif isinstance(obj, db.Entity): - # TODO: Support Skeleton instances as well? - return { - ".__entity__": _preprocess_json_object(dict(obj)), - ".__ekey__": db.encodeKey(obj.key) if obj.key else None - } - elif isinstance(obj, dict): - return {_preprocess_json_object(k): _preprocess_json_object(v) for k, v in obj.items()} - elif isinstance(obj, (list, tuple, set)): - return [_preprocess_json_object(x) for x in obj] - - return obj - - -def _decode_object_hook(obj): - """ - Inverse for _preprocess_json_object, which is an object-hook for json.loads. - Check if the object matches a custom ViUR type and recreate it accordingly. - """ - if len(obj) == 1: - if key := obj.get(".__key__"): - return db.Key.from_legacy_urlsafe(key) - elif date := obj.get(".__datetime__"): - return datetime.datetime.fromisoformat(date) - elif buf := obj.get(".__bytes__"): - return base64.b64decode(buf) - - elif len(obj) == 2 and ".__entity__" in obj and ".__ekey__" in obj: - entity = db.Entity(db.Key.from_legacy_urlsafe(obj[".__ekey__"]) if obj[".__ekey__"] else None) - entity.update(obj[".__entity__"]) - return entity - - return obj - - _gaeApp = os.environ.get("GAE_APPLICATION") queueRegion = None @@ -221,7 +173,7 @@ def queryIter(self, *args, **kwargs): """ req = current.request.get().request self._validate_request() - data = json.loads(req.body, object_hook=_decode_object_hook) + data = utils.json.loads(req.body) if data["classID"] not in MetaQueryIter._classCache: logging.error(f"""Could not continue queryIter - {data["classID"]} not known on this instance""") MetaQueryIter._classCache[data["classID"]]._qryStep(data) @@ -242,7 +194,7 @@ def deferred(self, *args, **kwargs): f"""Task {req.headers.get("X-Appengine-Taskname", "")} is retried for the {retryCount}th time.""" ) - cmd, data = json.loads(req.body, object_hook=_decode_object_hook) + cmd, data = utils.json.loads(req.body) funcPath, args, kwargs, env = data logging.debug(f"Call task {funcPath} with {cmd=} {args=} {kwargs=} {env=}") @@ -612,7 +564,7 @@ def task(): # Create task description task = tasks_v2.Task( app_engine_http_request=tasks_v2.AppEngineHttpRequest( - body=json.dumps(_preprocess_json_object((command, (funcPath, args, kwargs, env)))).encode("UTF-8"), + body=utils.json.dumps((command, (funcPath, args, kwargs, env))).encode(), http_method=tasks_v2.HttpMethod.POST, relative_uri=taskargs["url"], app_engine_routing=tasks_v2.AppEngineRouting( @@ -787,7 +739,7 @@ def _requeueStep(cls, qryDict: dict[str, t.Any]) -> None: parent=taskClient.queue_path(conf.instance.project_id, queueRegion, cls.queueName), task=tasks_v2.Task( app_engine_http_request=tasks_v2.AppEngineHttpRequest( - body=json.dumps(_preprocess_json_object(qryDict)).encode("UTF-8"), + body=utils.json.dumps(qryDict).encode(), http_method=tasks_v2.HttpMethod.POST, relative_uri="/_tasks/queryIter", app_engine_routing=tasks_v2.AppEngineRouting( diff --git a/src/viur/core/utils/__init__.py b/src/viur/core/utils/__init__.py index 0fff028b0..834ef12c6 100644 --- a/src/viur/core/utils/__init__.py +++ b/src/viur/core/utils/__init__.py @@ -9,7 +9,7 @@ from urllib.parse import quote from viur.core import current, db from viur.core.config import conf -from . import string, parse +from . import string, parse, json def utcNow() -> datetime: diff --git a/src/viur/core/utils/json.py b/src/viur/core/utils/json.py new file mode 100644 index 000000000..aec0aafe8 --- /dev/null +++ b/src/viur/core/utils/json.py @@ -0,0 +1,75 @@ +import base64 +import datetime +import json +import pytz +import typing as t +from viur.core import db + + +class ViURJsonEncoder(json.JSONEncoder): + """ + Adds support for db.Key, db.Entity, datetime, bytes and and converts the provided obj + into a special dict with JSON-serializable values. + """ + def default(self, obj: t.Any) -> t.Any: + if isinstance(obj, bytes): + return {".__bytes__": base64.b64encode(obj).decode("ASCII")} + elif isinstance(obj, datetime.datetime): + return {".__datetime__": obj.astimezone(pytz.UTC).isoformat()} + elif isinstance(obj, datetime.timedelta): + return {".__timedelta__": obj / datetime.timedelta(microseconds=1)} + elif isinstance(obj, set): + return {".__set__": list(obj)} + elif hasattr(obj, "__iter__"): + return tuple(obj) + # cannot be tested in tests... + elif isinstance(obj, db.Key): + return {".__key__": db.encodeKey(obj)} + elif isinstance(obj, db.Entity): + # TODO: Handle SkeletonInstance as well? + return { + ".__entity__": dict(obj), + ".__key__": db.encodeKey(obj.key) if obj.key else None + } + + return super().default(obj) + + +def dumps(obj: t.Any, *, cls=ViURJsonEncoder, **kwargs) -> str: + """ + Wrapper for json.dumps() which converts additional ViUR datatypes. + """ + return json.dumps(obj, cls=cls, **kwargs) + + +def _decode_object_hook(obj: t.Any): + """ + Inverse for _preprocess_json_object, which is an object-hook for json.loads. + Check if the object matches a custom ViUR type and recreate it accordingly. + """ + if len(obj) == 1: + if buf := obj.get(".__bytes__"): + return base64.b64decode(buf) + elif date := obj.get(".__datetime__"): + return datetime.datetime.fromisoformat(date) + elif microseconds := obj.get(".__timedelta__"): + return datetime.timedelta(microseconds=microseconds) + elif key := obj.get(".__key__"): + return db.Key.from_legacy_urlsafe(key) + elif items := obj.get(".__set__"): + return set(items) + + elif len(obj) == 2 and all(k in obj for k in (".__entity__", ".__key__")): + # TODO: Handle SkeletonInstance as well? + entity = db.Entity(db.Key.from_legacy_urlsafe(obj[".__key__"]) if obj[".__key__"] else None) + entity.update(obj[".__entity__"]) + return entity + + return obj + + +def loads(s: str, *, object_hook=_decode_object_hook, **kwargs) -> t.Any: + """ + Wrapper for json.loads() which recreates additional ViUR datatypes. + """ + return json.loads(s, object_hook=object_hook, **kwargs) diff --git a/tests/main.py b/tests/main.py index 8befef005..f74581ef6 100755 --- a/tests/main.py +++ b/tests/main.py @@ -85,8 +85,10 @@ def __init__(self, *args, **kwargs): ) viur_datastore = mock.Mock() + for attr in db_attr: setattr(viur_datastore, attr, mock.MagicMock()) + viur_datastore.config = {} sys.modules["viur.datastore"] = viur_datastore diff --git a/tests/test_utils.py b/tests/test_utils.py index c17ce60e3..4267f1925 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -25,3 +25,47 @@ def test_string_escape(self): self.assertEqual("abcde", utils.string.escape("abcdefghi", max_length=5)) self.assertEqual("<html> &</html>", utils.string.escape("\n&\0")) self.assertEqual(utils.string.escape(S), E) + + def test_json(self): + from viur.core import utils, db + import datetime + + # key = db.Key("test", "hello world") + now = datetime.datetime.fromisoformat("2024-02-28T14:43:17.125207+00:00") + duration = datetime.timedelta(minutes=13, microseconds=37) + + example = { + "datetime": now, + "false": False, + "float": 42.5, + "generator": (x for x in "Hello"), + "int": 1337, + # "key": key, # cannot use in tests + "list": [1, 2, 3], + "none": None, + "set": {1, 2, 3}, + "str": "World", + "timedelta": duration, + "true": True, + "tuple": (1, 2, 3), + } + + # serialize example into string + s = utils.json.dumps(example) + + # check if string is as expected + self.assertEqual( + s, + """{"datetime": {".__datetime__": "2024-02-28T14:43:17.125207+00:00"}, "false": false, "float": 42.5, "generator": ["H", "e", "l", "l", "o"], "int": 1337, "list": [1, 2, 3], "none": null, "set": {".__set__": [1, 2, 3]}, "str": "World", "timedelta": {".__timedelta__": 780000037.0}, "true": true, "tuple": [1, 2, 3]}""", # noqa + ) + + # deserialize string into object again + o = utils.json.loads(s) + + # patch tuple as a list + example["tuple"] = list(example["tuple"]) + example["generator"] = [x for x in "Hello"] + + # self.assertEqual(example, o) + for k, v in example.items(): + self.assertEqual(o[k], v)