QB: Re-introduction of contains Filter Operator for SQLite (#6619)
This commit re-introduces the `contains` filter operator for the SQLite backend. The primary goal is to replicate the functionality of `contains` from PostgreSQL, achieving feature parity between the two storage backends.

Due to the particular (and sometimes unexpected) behavior of PostgreSQL's containment operator (#6618), a pure SQL implementation would be highly complex and difficult to maintain. As the behavior of the `contains` operator should be _exactly_ the same, irrespective of which backend is being used, custom Python functions were chosen for the implementation (rather than SQLAlchemy's high-level abstractions).

Comprehensive tests and benchmarks that cover all potential use cases for `contains` on the SQLite backend have been added.
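For reference, the operator is exposed through the `QueryBuilder` filter syntax. A minimal sketch (assumes a loaded AiiDA profile; the stored dictionary and key names are illustrative, not taken from the commit):

```python
from aiida import orm
from aiida.orm.querybuilder import QueryBuilder

# Store a Dict node, then match it with a `contains` filter on a nested attribute.
orm.Dict({'kinds': ['Si', 'O'], 'meta': {'source': 'example'}}).store()

qb = QueryBuilder().append(
    orm.Dict,
    filters={'attributes.kinds': {'contains': ['Si']}},
)
print(qb.all())  # the node above matches, since ['Si', 'O'] contains ['Si']
```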
rabbull authored Jan 21, 2025
1 parent c88fc05 commit aa0aa26
Showing 7 changed files with 584 additions and 224 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -43,3 +43,6 @@ pplot_out/

# docker
docker-bake.override.json

# benchmark
.benchmarks/
19 changes: 16 additions & 3 deletions src/aiida/storage/sqlite_zip/orm.py
@@ -19,7 +19,7 @@

from sqlalchemy import JSON, case, func, select
from sqlalchemy.orm.util import AliasedClass
-from sqlalchemy.sql import ColumnElement
+from sqlalchemy.sql import ColumnElement, null

from aiida.common.lang import type_check
from aiida.storage.psql_dos.orm import authinfos, comments, computers, entities, groups, logs, nodes, users, utils
@@ -285,8 +285,21 @@ def _cast_json_type(comparator: JSON.Comparator, value: Any) -> Tuple[ColumnElem
        return case((type_filter, casted_entity.ilike(value, escape='\\')), else_=False)

    if operator == 'contains':
-        # to-do, see: https://github.com/sqlalchemy/sqlalchemy/discussions/7836
-        raise NotImplementedError('The operator `contains` is not implemented for SQLite-based storage plugins.')
+        # The PostgreSQL backend returns NULL when `attr_key` does not exist, so
+        # the SQLite implementation must mirror that behavior: an additional CASE
+        # statement returns NULL directly in such cases.
+        #
+        # Simply returning `database_entity` would not work: a missing key would
+        # be rendered as the string 'null' in SQL rather than a proper NULL, so the
+        # CASE statement guarantees a genuine NULL when `attr_key` does not exist.
+        #
+        # Original implementation:
+        # return func.json_contains(database_entity, json.dumps(value))
+
+        return case(
+            (func.json_extract(column, '$.' + '.'.join(attr_key)).is_(null()), null()),
+            else_=func.json_contains(database_entity, json.dumps(value)),
+        )

    if operator == 'has_key':
        return (
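For intuition, here is a standalone sketch of the same three-valued CASE logic built with SQLAlchemy (the model, column, and JSON path are illustrative stand-ins, not the actual AiiDA schema):

```python
from sqlalchemy import JSON, Column, Integer, case, func, null
from sqlalchemy.orm import declarative_base

Base = declarative_base()


class DbNode(Base):  # illustrative stand-in for the real AiiDA model
    __tablename__ = 'db_dbnode'
    id = Column(Integer, primary_key=True)
    attributes = Column(JSON)


# NULL when the JSON path is absent; otherwise defer to the custom
# json_contains() function registered on each SQLite connection.
expr = case(
    (func.json_extract(DbNode.attributes, '$.data').is_(null()), null()),
    else_=func.json_contains(func.json_extract(DbNode.attributes, '$.data'), '{"a": 1}'),
)
print(expr)  # CASE WHEN json_extract(...) IS NULL THEN NULL ELSE json_contains(...) END
```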
32 changes: 32 additions & 0 deletions src/aiida/storage/sqlite_zip/utils.py
@@ -48,12 +48,44 @@ def sqlite_case_sensitive_like(dbapi_connection, _):
    cursor.close()


def _contains(lhs: Union[dict, list], rhs: Union[dict, list]):
    """Return whether `lhs` recursively contains `rhs`, mirroring PostgreSQL's `@>` operator."""
    if isinstance(lhs, dict) and isinstance(rhs, dict):
        # Every key of `rhs` must be present in `lhs` with a recursively contained value.
        for key in rhs:
            if key not in lhs or not _contains(lhs[key], rhs[key]):
                return False
        return True

    elif isinstance(lhs, list) and isinstance(rhs, list):
        # Every element of `rhs` must be contained in at least one element of `lhs`.
        for item in rhs:
            if not any(_contains(e, item) for e in lhs):
                return False
        return True
    else:
        # Primitives (and mismatched types) are compared by equality.
        return lhs == rhs


def _json_contains(lhs: Union[str, bytes, bytearray, dict, list], rhs: Union[str, bytes, bytearray, dict, list]):
    """SQLite implementation of `json_contains`; returns 1 or 0 since SQLite has no boolean type."""
    try:
        if isinstance(lhs, (str, bytes, bytearray)):
            lhs = json.loads(lhs)
        if isinstance(rhs, (str, bytes, bytearray)):
            rhs = json.loads(rhs)
    except json.JSONDecodeError:
        return 0
    return int(_contains(lhs, rhs))


def register_json_contains(dbapi_connection, _):
    """Register the custom `json_contains` function on every new SQLite connection."""
    dbapi_connection.create_function('json_contains', 2, _json_contains)


def create_sqla_engine(path: Union[str, Path], *, enforce_foreign_keys: bool = True, **kwargs) -> Engine:
    """Create a new engine instance."""
    engine = create_engine(f'sqlite:///{path}', json_serializer=json.dumps, json_deserializer=json.loads, **kwargs)
    event.listen(engine, 'connect', sqlite_case_sensitive_like)
    if enforce_foreign_keys:
        event.listen(engine, 'connect', sqlite_enforce_foreign_keys)
    event.listen(engine, 'connect', register_json_contains)
    return engine


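Once registered, `json_contains` can be called from any SQL statement executed through the engine. A minimal smoke test (a sketch, assuming an in-memory database and that the module above is importable as shown):

```python
from sqlalchemy import text

from aiida.storage.sqlite_zip.utils import create_sqla_engine

engine = create_sqla_engine(':memory:')
with engine.connect() as conn:
    # Nested list containment: expected to return 1.
    hit = conn.execute(text('''SELECT json_contains('{"a": [1, 2]}', '{"a": [2]}')''')).scalar()
    # Key missing on the left-hand side: expected to return 0.
    miss = conn.execute(text('''SELECT json_contains('{"a": 1}', '{"b": 1}')''')).scalar()
print(hit, miss)  # 1 0
```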
138 changes: 138 additions & 0 deletions tests/benchmark/test_json_contains.py
@@ -0,0 +1,138 @@
import random
import string

import pytest

from aiida import orm
from aiida.orm.querybuilder import QueryBuilder

GROUP_NAME = 'json-contains'


COMPLEX_JSON_DEPTH_RANGE = [2**i for i in range(4)]
COMPLEX_JSON_BREADTH_RANGE = [2**i for i in range(4)]
LARGE_TABLE_SIZE_RANGE = [2**i for i in range(1, 11)]


def gen_json(depth: int, breadth: int):
    """Generate a random JSON-serializable structure with the given nesting depth and breadth."""

    def gen_str(n: int, with_digits: bool = True):
        population = string.ascii_letters
        if with_digits:
            population += string.digits
        return ''.join(random.choices(population, k=n))

    if depth == 0:  # random primitive value
        # real numbers are not included as their equivalence is tricky
        return random.choice(
            [
                random.randint(-114, 514),  # integers
                gen_str(6),  # strings
                random.choice([True, False]),  # booleans
                None,  # nulls
            ]
        )

    else:
        gen_dict = random.choice([True, False])
        data = [gen_json(depth - 1, breadth) for _ in range(breadth)]
        if gen_dict:
            keys = set()
            while len(keys) < breadth:
                keys.add(gen_str(6, False))
            data = dict(zip(list(keys), data))
        return data


def extract_component(data, p: float = -1):
    """Randomly extract a nested component of `data`; with probability `p`, return `data` itself.

    Note that `p` is not propagated to recursive calls, so the early return
    only applies at the level where it is passed.
    """
    if random.random() < p:
        return data

    if isinstance(data, dict) and data:
        key = random.choice(list(data.keys()))
        return {key: extract_component(data[key])}
    elif isinstance(data, list) and data:
        element = random.choice(data)
        return [extract_component(element)]
    else:
        return data
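To make the helpers concrete, one possible draw (a sketch; the output is random, so exact values will differ from run to run):

```python
import random

random.seed(0)  # only so this sketch is reproducible
data = gen_json(depth=2, breadth=2)
component = extract_component(data, p=0.5)
print(data)       # e.g. a two-level dict/list mixture of random primitives
print(component)  # a sub-structure that `data` contains
```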


@pytest.mark.benchmark(group=GROUP_NAME)
@pytest.mark.parametrize('depth', [1, 2, 4, 8])
@pytest.mark.parametrize('breadth', [1, 2, 4])
@pytest.mark.usefixtures('aiida_profile_clean')
def test_deep_json(benchmark, depth, breadth):
    lhs = gen_json(depth, breadth)
    rhs = extract_component(lhs, p=1.0 / depth)
    assert 0 == len(QueryBuilder().append(orm.Dict).all())  # clean profile: no Dict nodes yet

    orm.Dict(
        {
            'id': f'{depth}-{breadth}',
            'data': lhs,
        }
    ).store()
    qb = QueryBuilder().append(
        orm.Dict,
        filters={
            'attributes.data': {'contains': rhs},
        },
        project=['attributes.id'],
    )
    qb.all()  # warm-up run before the measured iterations
    result = benchmark(qb.all)
    assert len(result) == 1


@pytest.mark.benchmark(group=GROUP_NAME)
@pytest.mark.parametrize('depth', [2])
@pytest.mark.parametrize('breadth', [1, 10, 100])
@pytest.mark.usefixtures('aiida_profile_clean')
def test_wide_json(benchmark, depth, breadth):
    lhs = gen_json(depth, breadth)
    rhs = extract_component(lhs, p=1.0 / depth)
    assert 0 == len(QueryBuilder().append(orm.Dict).all())

    orm.Dict(
        {
            'id': f'{depth}-{breadth}',
            'data': lhs,
        }
    ).store()
    qb = QueryBuilder().append(
        orm.Dict,
        filters={
            'attributes.data': {'contains': rhs},
        },
        project=['attributes.id'],
    )
    qb.all()  # warm-up run before the measured iterations
    result = benchmark(qb.all)
    assert len(result) == 1


@pytest.mark.benchmark(group=GROUP_NAME)
@pytest.mark.parametrize('num_entries', LARGE_TABLE_SIZE_RANGE)
@pytest.mark.usefixtures('aiida_profile_clean')
def test_large_table(benchmark, num_entries):
    data = gen_json(2, 10)
    rhs = extract_component(data)
    assert 0 == len(QueryBuilder().append(orm.Dict).all())

    for i in range(num_entries):
        orm.Dict(
            {
                'id': f'N={num_entries}, i={i}',
                'data': data,
            }
        ).store()
    qb = QueryBuilder().append(
        orm.Dict,
        filters={
            'attributes.data': {'contains': rhs},
        },
        project=['attributes.id'],
    )
    qb.all()  # warm-up run before the measured iterations
    result = benchmark(qb.all)
    assert len(result) == num_entries
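Assuming the `benchmark` fixture above comes from `pytest-benchmark` (consistent with the new `.benchmarks/` entry in `.gitignore`), the suite can be run with `pytest tests/benchmark/test_json_contains.py --benchmark-only`.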
