Skip to content

Commit

Permalink
SNOW-1830524 Add decoder logic for Dataframe.join (#2802)
Browse files Browse the repository at this point in the history
1. Which Jira issue is this PR addressing? Make sure that there is an
accompanying issue to your PR.

   Fixes SNOW-1830524

2. Fill out the following pre-review checklist:

- [ ] I am adding a new automated test(s) to verify correctness of my
new code
- [ ] If this test skips Local Testing mode, I'm requesting review from
@snowflakedb/local-testing
   - [ ] I am adding new logging messages
   - [ ] I am adding a new telemetry message
   - [ ] I am adding new credentials
   - [ ] I am adding a new dependency
- [ ] If this is a new feature/behavior, I'm adding the Local Testing
parity changes.
- [x] I acknowledge that I have ensured my changes to be thread-safe.
Follow the link for more information: [Thread-safe Developer
Guidelines](https://github.com/snowflakedb/snowpark-python/blob/main/CONTRIBUTING.md#thread-safe-development)

3. Please describe how your code solves the related issue.
Added decoder logic for `Dataframe.join`. All join tests should work
now.
  • Loading branch information
sfc-gh-vbudati authored Jan 14, 2025
1 parent ae5fd81 commit a7f5670
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 2 deletions.
10 changes: 9 additions & 1 deletion tests/ast/data/Dataframe.join.asof.test
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ df1 = session.create_dataframe([["A", 1, 15, 3.21], ["A", 2, 16, 3.22], ["B", 1,

df2 = session.create_dataframe([["A", 1, 14, 3.19], ["B", 2, 16, 3.04]], schema=["c1", "c2", "c3", "c4"])

df1.join(df2, on=(df1["c1"] == df2["c1"]) & (df1["c2"] == df2["c2"]), how="asof", lsuffix="_L", rsuffix="_R", match_condition=df1["c3"] >= df2["c3"]).sort("C1_L", "C2_L").collect()
df1.join(df2, on=(df1["c1"] == df2["c1"]) & (df1["c2"] == df2["c2"]), how="asof", lsuffix="_L", rsuffix="_R", match_condition=df1["c3"] >= df2["c3"]).sort("C1_L", "C2_L", ascending=None).collect()

## EXPECTED ENCODED AST

Expand Down Expand Up @@ -522,6 +522,14 @@ body {
assign {
expr {
sp_dataframe_sort {
ascending {
null_val {
src {
file: "SRC_POSITION_TEST_MODE"
start_line: 36
}
}
}
cols {
string_val {
src {
Expand Down
10 changes: 9 additions & 1 deletion tests/ast/data/Dataframe.join.prefix.test
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ df3 = df2.select((col("\"A\"") + 1).as_("\"A\""), col("\"B\""), col("\"C\""), co

df3 = df3_res1.join(df3, how="inner")

df4 = df3.sort(["\"l_0004_A\"", "\"l_0004_B\"", "\"l_0004_C\"", "\"r_0000_A\"", "\"l_0000_A\"", "\"l_0002_A\"", "\"r_0006_A\"", "\"r_0006_B\"", "\"r_0006_C\"", "\"l_0001_C\"", "\"l_0003_B\""])
df4 = df3.sort(["\"l_0004_A\"", "\"l_0004_B\"", "\"l_0004_C\"", "\"r_0000_A\"", "\"l_0000_A\"", "\"l_0002_A\"", "\"r_0006_A\"", "\"r_0006_B\"", "\"r_0006_C\"", "\"l_0001_C\"", "\"l_0003_B\""], ascending=None)

df4.collect()

Expand Down Expand Up @@ -515,6 +515,14 @@ body {
assign {
expr {
sp_dataframe_sort {
ascending {
null_val {
src {
file: "SRC_POSITION_TEST_MODE"
start_line: 31
}
}
}
cols {
string_val {
src {
Expand Down
78 changes: 78 additions & 0 deletions tests/ast/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,42 @@ def decode_data_type_expr(
"Unknown data type: %s" % data_type_expr.WhichOneof("variant")
)

def decode_join_type(self, join_type: proto.SpJoinType) -> str:
"""
Decode a join type expression to get the join type.
Parameters
----------
join_type : proto.SpJoinType
The expression to decode.
Returns
-------
str
The decoded join type.
"""
match join_type.WhichOneof("variant"):
case "sp_join_type__asof":
return "asof"
case "sp_join_type__cross":
return "cross"
case "sp_join_type__full_outer":
return "full"
case "sp_join_type__inner":
return "inner"
case "sp_join_type__left_anti":
return "anti"
case "sp_join_type__left_outer":
return "left"
case "sp_join_type__left_semi":
return "semi"
case "sp_join_type__right_outer":
return "right"
case _:
raise ValueError(
"Unknown join type: %s" % join_type.WhichOneof("variant")
)

def decode_timezone_expr(self, tz_expr: proto.PythonTimeZone) -> Any:
"""
Decode a Python timezone expression to get the timezone.
Expand Down Expand Up @@ -1103,6 +1139,47 @@ def decode_expr(self, expr: proto.Expr) -> Any:
other = self.decode_expr(expr.sp_dataframe_intersect.other)
return df.intersect(other)

case "sp_dataframe_join":
d = MessageToDict(expr.sp_dataframe_join)
join_expr = d.get("joinExpr", None)
join_expr = (
self.decode_expr(expr.sp_dataframe_join.join_expr)
if join_expr
else None
)
join_type = d.get("joinType", None)
join_type = (
self.decode_join_type(expr.sp_dataframe_join.join_type)
if join_type
else None
)
lhs = self.decode_expr(expr.sp_dataframe_join.lhs)
rhs = self.decode_expr(expr.sp_dataframe_join.rhs)
lsuffix = d.get("lsuffix", "")
rsuffix = d.get("rsuffix", "")
match_condition = d.get("matchCondition", None)
match_condition = (
self.decode_expr(expr.sp_dataframe_join.match_condition)
if match_condition
else None
)
return lhs.join(
right=rhs,
on=join_expr,
how=join_type,
lsuffix=lsuffix,
rsuffix=rsuffix,
match_condition=match_condition,
)

case "sp_dataframe_natural_join":
lhs = self.decode_expr(expr.sp_dataframe_natural_join.lhs)
rhs = self.decode_expr(expr.sp_dataframe_natural_join.rhs)
join_type = self.decode_join_type(
expr.sp_dataframe_natural_join.join_type
)
return lhs.natural_join(right=rhs, how=join_type)

case "sp_dataframe_na_drop__python":
df = self.decode_expr(expr.sp_dataframe_na_drop__python.df)
how = expr.sp_dataframe_na_drop__python.how
Expand Down Expand Up @@ -1209,6 +1286,7 @@ def decode_expr(self, expr: proto.Expr) -> Any:
self.decode_expr(col) for col in expr.sp_dataframe_sort.cols
)
ascending = self.decode_expr(expr.sp_dataframe_sort.ascending)

if MessageToDict(expr.sp_dataframe_sort).get("colsVariadic", False):
return df.sort(*cols, ascending=ascending)
else:
Expand Down
1 change: 1 addition & 0 deletions tests/ast/test_ast_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ def test_ast(session, tables, test_case):
decoder = Decoder(session)
session._ast_batch.reset_id_gen() # Reset the entity ID generator.
session._ast_batch.flush() # Clear the AST.
global_counter.reset()

# Turn base64 input into protobuf objects. ParseFromString can retrieve multiple statements.
protobuf_request = base64_lines_to_request(stripped_base64_str)
Expand Down

0 comments on commit a7f5670

Please sign in to comment.