diff --git a/tests/ast/data/Dataframe.join.asof.test b/tests/ast/data/Dataframe.join.asof.test index b05323ec7d0..2109643329a 100644 --- a/tests/ast/data/Dataframe.join.asof.test +++ b/tests/ast/data/Dataframe.join.asof.test @@ -19,7 +19,7 @@ df1 = session.create_dataframe([["A", 1, 15, 3.21], ["A", 2, 16, 3.22], ["B", 1, df2 = session.create_dataframe([["A", 1, 14, 3.19], ["B", 2, 16, 3.04]], schema=["c1", "c2", "c3", "c4"]) -df1.join(df2, on=(df1["c1"] == df2["c1"]) & (df1["c2"] == df2["c2"]), how="asof", lsuffix="_L", rsuffix="_R", match_condition=df1["c3"] >= df2["c3"]).sort("C1_L", "C2_L").collect() +df1.join(df2, on=(df1["c1"] == df2["c1"]) & (df1["c2"] == df2["c2"]), how="asof", lsuffix="_L", rsuffix="_R", match_condition=df1["c3"] >= df2["c3"]).sort("C1_L", "C2_L", ascending=None).collect() ## EXPECTED ENCODED AST @@ -522,6 +522,14 @@ body { assign { expr { sp_dataframe_sort { + ascending { + null_val { + src { + file: "SRC_POSITION_TEST_MODE" + start_line: 36 + } + } + } cols { string_val { src { diff --git a/tests/ast/data/Dataframe.join.prefix.test b/tests/ast/data/Dataframe.join.prefix.test index 06f5aac19d5..fd43d9c6424 100644 --- a/tests/ast/data/Dataframe.join.prefix.test +++ b/tests/ast/data/Dataframe.join.prefix.test @@ -22,7 +22,7 @@ df3 = df2.select((col("\"A\"") + 1).as_("\"A\""), col("\"B\""), col("\"C\""), co df3 = df3_res1.join(df3, how="inner") -df4 = df3.sort(["\"l_0004_A\"", "\"l_0004_B\"", "\"l_0004_C\"", "\"r_0000_A\"", "\"l_0000_A\"", "\"l_0002_A\"", "\"r_0006_A\"", "\"r_0006_B\"", "\"r_0006_C\"", "\"l_0001_C\"", "\"l_0003_B\""]) +df4 = df3.sort(["\"l_0004_A\"", "\"l_0004_B\"", "\"l_0004_C\"", "\"r_0000_A\"", "\"l_0000_A\"", "\"l_0002_A\"", "\"r_0006_A\"", "\"r_0006_B\"", "\"r_0006_C\"", "\"l_0001_C\"", "\"l_0003_B\""], ascending=None) df4.collect() @@ -515,6 +515,14 @@ body { assign { expr { sp_dataframe_sort { + ascending { + null_val { + src { + file: "SRC_POSITION_TEST_MODE" + start_line: 31 + } + } + } cols { string_val { src { diff --git a/tests/ast/decoder.py b/tests/ast/decoder.py index 58a5eb9e19f..c98a68b19d6 100644 --- a/tests/ast/decoder.py +++ b/tests/ast/decoder.py @@ -473,6 +473,42 @@ def decode_data_type_expr( "Unknown data type: %s" % data_type_expr.WhichOneof("variant") ) + def decode_join_type(self, join_type: proto.SpJoinType) -> str: + """ + Decode a join type expression to get the join type. + + Parameters + ---------- + join_type : proto.SpJoinType + The expression to decode. + + Returns + ------- + str + The decoded join type. + """ + match join_type.WhichOneof("variant"): + case "sp_join_type__asof": + return "asof" + case "sp_join_type__cross": + return "cross" + case "sp_join_type__full_outer": + return "full" + case "sp_join_type__inner": + return "inner" + case "sp_join_type__left_anti": + return "anti" + case "sp_join_type__left_outer": + return "left" + case "sp_join_type__left_semi": + return "semi" + case "sp_join_type__right_outer": + return "right" + case _: + raise ValueError( + "Unknown join type: %s" % join_type.WhichOneof("variant") + ) + def decode_timezone_expr(self, tz_expr: proto.PythonTimeZone) -> Any: """ Decode a Python timezone expression to get the timezone. @@ -1103,6 +1139,47 @@ def decode_expr(self, expr: proto.Expr) -> Any: other = self.decode_expr(expr.sp_dataframe_intersect.other) return df.intersect(other) + case "sp_dataframe_join": + d = MessageToDict(expr.sp_dataframe_join) + join_expr = d.get("joinExpr", None) + join_expr = ( + self.decode_expr(expr.sp_dataframe_join.join_expr) + if join_expr + else None + ) + join_type = d.get("joinType", None) + join_type = ( + self.decode_join_type(expr.sp_dataframe_join.join_type) + if join_type + else None + ) + lhs = self.decode_expr(expr.sp_dataframe_join.lhs) + rhs = self.decode_expr(expr.sp_dataframe_join.rhs) + lsuffix = d.get("lsuffix", "") + rsuffix = d.get("rsuffix", "") + match_condition = d.get("matchCondition", None) + match_condition = ( + self.decode_expr(expr.sp_dataframe_join.match_condition) + if match_condition + else None + ) + return lhs.join( + right=rhs, + on=join_expr, + how=join_type, + lsuffix=lsuffix, + rsuffix=rsuffix, + match_condition=match_condition, + ) + + case "sp_dataframe_natural_join": + lhs = self.decode_expr(expr.sp_dataframe_natural_join.lhs) + rhs = self.decode_expr(expr.sp_dataframe_natural_join.rhs) + join_type = self.decode_join_type( + expr.sp_dataframe_natural_join.join_type + ) + return lhs.natural_join(right=rhs, how=join_type) + case "sp_dataframe_na_drop__python": df = self.decode_expr(expr.sp_dataframe_na_drop__python.df) how = expr.sp_dataframe_na_drop__python.how @@ -1209,6 +1286,7 @@ def decode_expr(self, expr: proto.Expr) -> Any: self.decode_expr(col) for col in expr.sp_dataframe_sort.cols ) ascending = self.decode_expr(expr.sp_dataframe_sort.ascending) + if MessageToDict(expr.sp_dataframe_sort).get("colsVariadic", False): return df.sort(*cols, ascending=ascending) else: diff --git a/tests/ast/test_ast_driver.py b/tests/ast/test_ast_driver.py index 4b7dcaaf846..71041878e1d 100644 --- a/tests/ast/test_ast_driver.py +++ b/tests/ast/test_ast_driver.py @@ -291,6 +291,7 @@ def test_ast(session, tables, test_case): decoder = Decoder(session) session._ast_batch.reset_id_gen() # Reset the entity ID generator. session._ast_batch.flush() # Clear the AST. + global_counter.reset() # Turn base64 input into protobuf objects. ParseFromString can retrieve multiple statements. protobuf_request = base64_lines_to_request(stripped_base64_str)