Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(athena): implement proper support for inserting data #10770

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 19 additions & 7 deletions ibis/backends/athena/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import contextlib
import getpass
import os
import re
import sys
import tempfile
from pathlib import Path
Expand Down Expand Up @@ -159,11 +160,7 @@
if location is None:
location = f"{self._s3_staging_dir}/{name}"

property_list = [
sge.ExternalProperty(),
sge.FileFormatProperty(this=compiler.v[stored_as]),
sge.LocationProperty(this=sge.convert(location)),
]
property_list = []

for k, v in (properties or {}).items():
name = sg.to_identifier(k)
Expand Down Expand Up @@ -196,6 +193,9 @@
).from_(compiler.to_sqlglot(table).subquery())
else:
select = None
property_list.append(sge.ExternalProperty())
property_list.append(sge.FileFormatProperty(this=compiler.v[stored_as]))
property_list.append(sge.LocationProperty(this=sge.convert(location)))

create_stmt = sge.Create(
kind="TABLE",
Expand Down Expand Up @@ -287,8 +287,20 @@
def _safe_raw_sql(self, query, *args, unload: bool = True, **kwargs):
with contextlib.suppress(AttributeError):
query = query.sql(self.dialect)
with self.con.cursor(unload=unload) as cur:
yield cur.execute(query, *args, **kwargs)
try:
with self.con.cursor(unload=unload) as cur:
yield cur.execute(query, *args, **kwargs)
except pyathena.error.OperationalError as e:
# apparently unload=True can just nope out and not tell you
# why, but unload=False is "fine"
#
# if the error isn't this opaque "internal" error, then we raise the original
# exception, otherwise try to execute the query again with unload=False
if unload and re.search("ErrorCode: INTERNAL_ERROR_QUERY_ENGINE", str(e)):
with self.con.cursor(unload=False) as cur:
yield cur.execute(query, *args, **kwargs)

Check warning on line 301 in ibis/backends/athena/__init__.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/athena/__init__.py#L300-L301

Added lines #L300 - L301 were not covered by tests
else:
raise

def list_catalogs(self, like: str | None = None) -> list[str]:
response = self.con.client.list_data_catalogs()
Expand Down
8 changes: 0 additions & 8 deletions ibis/backends/tests/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,11 +186,3 @@
from pyathena.error import OperationalError as PyAthenaOperationalError
except ImportError:
PyAthenaDatabaseError = PyAthenaOperationalError = None


try:
from botocore.errorfactory import (
InvalidRequestException as BotoInvalidRequestException,
)
except ImportError:
BotoInvalidRequestException = None
12 changes: 2 additions & 10 deletions ibis/backends/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import ibis.expr.operations as ops
from ibis.backends.conftest import ALL_BACKENDS
from ibis.backends.tests.errors import (
BotoInvalidRequestException,
DatabricksServerOperationError,
ExaQueryError,
ImpalaHiveServer2Error,
Expand Down Expand Up @@ -97,11 +96,6 @@ def _create_temp_table_with_schema(backend, con, temp_table_name, schema, data=N
ids=["no_schema", "dict_schema", "tuples", "schema"],
)
@pytest.mark.notimpl(["druid"])
@pytest.mark.notimpl(
["athena"],
raises=BotoInvalidRequestException,
reason="create table requires a location",
)
@pytest.mark.notimpl(
["flink"],
reason="Flink backend supports creating only TEMPORARY VIEW for in-memory data.",
Expand Down Expand Up @@ -952,6 +946,7 @@ def test_self_join_memory_table(backend, con, monkeypatch):
"sqlite",
"trino",
"databricks",
"athena",
]
)
],
Expand Down Expand Up @@ -982,6 +977,7 @@ def test_self_join_memory_table(backend, con, monkeypatch):
"sqlite",
"trino",
"databricks",
"athena",
],
raises=com.UnsupportedOperationError,
reason="we don't materialize datasets to avoid perf footguns",
Expand Down Expand Up @@ -1033,7 +1029,6 @@ def test_self_join_memory_table(backend, con, monkeypatch):
],
)
@pytest.mark.notimpl(["druid"])
@pytest.mark.notimpl(["athena"], raises=BotoInvalidRequestException)
@pytest.mark.notimpl(
["flink"],
reason="Flink backend supports creating only TEMPORARY VIEW for in-memory data.",
Expand Down Expand Up @@ -1417,7 +1412,6 @@ def create_and_destroy_db(con):
reason="unclear whether Flink supports cross catalog/database inserts",
raises=Py4JJavaError,
)
@pytest.mark.notimpl(["athena"])
def test_insert_with_database_specified(con_create_database):
con = con_create_database

Expand Down Expand Up @@ -1604,7 +1598,6 @@ def test_schema_with_caching(alltypes):
["druid"], raises=NotImplementedError, reason="doesn't support create_table"
)
@pytest.mark.notyet(["polars"], reason="Doesn't support insert")
@pytest.mark.notyet(["athena"])
@pytest.mark.notyet(
["datafusion"], reason="Doesn't support table creation from records"
)
Expand Down Expand Up @@ -1698,7 +1691,6 @@ def test_no_accidental_cross_database_table_load(con_create_database):


@pytest.mark.notyet(["druid"], reason="can't create tables")
@pytest.mark.notimpl(["athena"], reason="can't create tables correctly in some cases")
@pytest.mark.notyet(
["flink"], reason="can't create non-temporary tables from in-memory data"
)
Expand Down
Loading