Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: port the internals to use koerce #10078

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions docs/_quarto.yml
Original file line number Diff line number Diff line change
Expand Up @@ -298,8 +298,6 @@ quartodoc:
package: ibis.expr.types.generic
- name: Column
package: ibis.expr.types.generic
- name: Deferred
package: ibis.common.deferred
- name: Scalar
package: ibis.expr.types.generic

Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/bigquery/tests/unit/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import ibis.expr.operations as ops
from ibis import _
from ibis.backends.sql.compilers import BigQueryCompiler
from ibis.common.annotations import ValidationError
from ibis.common.grounds import ValidationError

to_sql = ibis.bigquery.compile

Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/clickhouse/tests/test_aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pytest

import ibis
from ibis.common.annotations import ValidationError
from ibis.common.grounds import ValidationError

pytest.importorskip("clickhouse_connect")

Expand Down
5 changes: 4 additions & 1 deletion ibis/backends/duckdb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1640,7 +1640,10 @@ def _register_udf(self, udf_node: ops.ScalarUDF):
name = type(udf_node).__name__
type_mapper = self.compiler.type_mapper
input_types = [
type_mapper.to_string(param.annotation.pattern.dtype)
# TODO(kszucs): the data type of the input parameters should be
# retrieved differently rather than relying on the validator
# in the signature
type_mapper.to_string(param.pattern.func.dtype)
for param in udf_node.__signature__.parameters.values()
]
output_type = type_mapper.to_string(udf_node.dtype)
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/impala/tests/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import ibis.expr.types as ir
from ibis import util
from ibis.backends.impala import ddl
from ibis.common.annotations import ValidationError
from ibis.common.grounds import ValidationError
from ibis.expr import rules

pytest.importorskip("impala")
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/impala/tests/test_unary_builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import ibis
import ibis.expr.types as ir
from ibis.backends.impala.tests.conftest import translate
from ibis.common.annotations import ValidationError
from ibis.common.grounds import ValidationError


@pytest.fixture(scope="module")
Expand Down
3 changes: 1 addition & 2 deletions ibis/backends/polars/rewrites.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from __future__ import annotations

from koerce import attribute, replace
from public import public

import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.common.annotations import attribute
from ibis.common.collections import FrozenDict
from ibis.common.patterns import replace
from ibis.common.typing import VarTuple # noqa: TCH001
from ibis.expr.schema import Schema

Expand Down
10 changes: 5 additions & 5 deletions ibis/backends/sql/compilers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@

import sqlglot as sg
import sqlglot.expressions as sge
from koerce import Replace # noqa: TCH002
from public import public

import ibis.common.exceptions as com
import ibis.common.patterns as pats
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.backends.sql.rewrites import (
Expand Down Expand Up @@ -239,15 +239,15 @@ class SQLGlotCompiler(abc.ABC):
agg = AggGen()
"""A generator for handling aggregate functions"""

rewrites: tuple[type[pats.Replace], ...] = (
rewrites: tuple[type[Replace], ...] = (
empty_in_values_right_side,
add_order_by_to_empty_ranking_window_functions,
one_to_zero_index,
add_one_to_nth_value_input,
)
"""A sequence of rewrites to apply to the expression tree before SQL-specific transforms."""

post_rewrites: tuple[type[pats.Replace], ...] = ()
post_rewrites: tuple[type[Replace], ...] = ()
"""A sequence of rewrites to apply to the expression tree after SQL-specific transforms."""

no_limit_value: sge.Null | None = None
Expand Down Expand Up @@ -290,7 +290,7 @@ class SQLGlotCompiler(abc.ABC):
UNSUPPORTED_OPS: tuple[type[ops.Node], ...] = ()
"""Tuple of operations the backend doesn't support."""

LOWERED_OPS: dict[type[ops.Node], pats.Replace | None] = {
LOWERED_OPS: dict[type[ops.Node], Replace | None] = {
ops.Bucket: lower_bucket,
ops.Capitalize: lower_capitalize,
ops.Sample: lower_sample,
Expand Down Expand Up @@ -431,7 +431,7 @@ class SQLGlotCompiler(abc.ABC):
# Constructed dynamically in `__init_subclass__` from their respective
# UPPERCASE values to handle inheritance, do not modify directly here.
extra_supported_ops: ClassVar[frozenset[type[ops.Node]]] = frozenset()
lowered_ops: ClassVar[dict[type[ops.Node], pats.Replace]] = {}
lowered_ops: ClassVar[dict[type[ops.Node], Replace]] = {}

def __init__(self) -> None:
self.f = FuncGen(copy=self.__class__.copy_func_args)
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/sql/compilers/bigquery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def _compile_python_udf(self, udf_node: ops.ScalarUDF) -> sge.Create:
signature = [
sge.ColumnDef(
this=sg.to_identifier(name, quoted=self.quoted),
kind=type_mapper.from_ibis(param.annotation.pattern.dtype),
kind=type_mapper.from_ibis(dt.dtype(param.typehint)),
)
for name, param in udf_node.__signature__.parameters.items()
]
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/sql/compilers/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import sqlglot as sg
import sqlglot.expressions as sge
from koerce import var

import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
Expand All @@ -26,7 +27,6 @@
replace,
split_select_distinct_with_order_by,
)
from ibis.common.deferred import var

if TYPE_CHECKING:
from collections.abc import Mapping
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/sql/compilers/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import sqlglot as sg
import sqlglot.expressions as sge
from koerce import replace

import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
Expand All @@ -18,7 +19,6 @@
exclude_unsupported_window_frame_from_row_number,
rewrite_empty_order_by_window,
)
from ibis.common.patterns import replace
from ibis.expr.rewrites import p


Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/sql/compilers/pyspark.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import sqlglot as sg
import sqlglot.expressions as sge
from koerce import replace

import ibis
import ibis.common.exceptions as com
Expand All @@ -21,7 +22,6 @@
p,
split_select_distinct_with_order_by,
)
from ibis.common.patterns import replace
from ibis.config import options
from ibis.expr.operations.udf import InputType
from ibis.util import gen_name
Expand Down
10 changes: 4 additions & 6 deletions ibis/backends/sql/rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,14 @@
from typing import TYPE_CHECKING, Any

import toolz
from koerce import Is, Object, Pattern, attribute, replace, var
from public import public

import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.common.annotations import attribute
from ibis.common.collections import FrozenDict # noqa: TCH001
from ibis.common.deferred import var
from ibis.common.graph import Graph
from ibis.common.patterns import InstanceOf, Object, Pattern, replace
from ibis.common.typing import VarTuple # noqa: TCH001
from ibis.expr.rewrites import d, p, replace_parameter
from ibis.expr.schema import Schema
Expand Down Expand Up @@ -330,7 +328,7 @@ def extract_ctes(node: ops.Relation) -> set[ops.Relation]:
cte_types = (Select, ops.Aggregate, ops.JoinChain, ops.Set, ops.Limit, ops.Sample)
dont_count = (ops.Field, ops.CountStar, ops.CountDistinctStar)

g = Graph.from_bfs(node, filter=~InstanceOf(dont_count))
g = Graph.from_bfs(node, filter=~Is(dont_count))
result = set()
for op, dependents in g.invert().items():
if isinstance(op, ops.View) or (
Expand Down Expand Up @@ -403,7 +401,7 @@ def sqlize(
if ctes:

def apply_ctes(node, kwargs):
new = node.__recreate__(kwargs) if kwargs else node
new = node.__class__(**kwargs) if kwargs else node
return CTE(new) if node in ctes else new

result = result.replace(apply_ctes)
Expand Down Expand Up @@ -454,7 +452,7 @@ def split_select_distinct_with_order_by(_):
return _


@replace(p.WindowFunction(func=p.NTile(y), order_by=()))
@replace(p.WindowFunction(func=p.NTile(+y), order_by=()))
def add_order_by_to_empty_ranking_window_functions(_, **kwargs):
"""Add an ORDER BY clause to rank window functions that don't have one."""
return _.copy(order_by=(y,))
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
SnowflakeProgrammingError,
TrinoUserError,
)
from ibis.common.annotations import ValidationError
from ibis.common.grounds import ValidationError

np = pytest.importorskip("numpy")
pd = pytest.importorskip("pandas")
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/tests/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
PsycoPg2InternalError,
PyODBCProgrammingError,
)
from ibis.common.annotations import ValidationError
from ibis.common.grounds import ValidationError
from ibis.util import gen_name

np = pytest.importorskip("numpy")
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/tests/test_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
SnowflakeProgrammingError,
TrinoUserError,
)
from ibis.common.annotations import ValidationError
from ibis.common.grounds import ValidationError
from ibis.conftest import IS_SPARK_REMOTE

np = pytest.importorskip("numpy")
Expand Down
Loading
Loading