diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java
index 3cdbfd3d1..c62160e12 100644
--- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java
+++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java
@@ -333,17 +333,37 @@ private AggregateCall fromMeasure(Aggregate.Measure measure) {
   @Override
   public RelNode visit(Sort sort) throws RuntimeException {
     RelNode child = sort.getInput().accept(this);
-    List relFieldCollations =
+    List<RexNode> sortExpressions =
         sort.getSortFields().stream()
-            .map(sortField -> toRelFieldCollation(sortField))
+            .map(this::directedRexNode)
             .collect(java.util.stream.Collectors.toList());
-    if (relFieldCollations.isEmpty()) {
-      return relBuilder.push(child).sort(Collections.EMPTY_LIST).build();
-    }
-    RelNode node = relBuilder.push(child).sort(RelCollations.of(relFieldCollations)).build();
+    RelNode node = relBuilder.push(child).sort(sortExpressions).build();
     return applyRemap(node, sort.getRemap());
   }
 
+  /**
+   * Converts a Substrait sort field into a Calcite {@link RexNode} whose sort direction and null
+   * ordering are expressed through {@code relBuilder}'s {@code desc}, {@code nullsFirst} and
+   * {@code nullsLast} wrappers.
+   *
+   * @param sortField the sort field whose expression and direction are converted
+   * @return the converted expression wrapped according to the field's sort direction
+   * @throws UnsupportedOperationException if the direction is {@code CLUSTERED}
+   */
+  private RexNode directedRexNode(Expression.SortField sortField) {
+    var expression = sortField.expr();
+    var rexNode = expression.accept(expressionRexConverter);
+    var sortDirection = sortField.direction();
+    return switch (sortDirection) {
+      case ASC_NULLS_FIRST -> relBuilder.nullsFirst(rexNode);
+      case ASC_NULLS_LAST -> relBuilder.nullsLast(rexNode);
+      case DESC_NULLS_FIRST -> relBuilder.nullsFirst(relBuilder.desc(rexNode));
+      case DESC_NULLS_LAST -> relBuilder.nullsLast(relBuilder.desc(rexNode));
+      case CLUSTERED -> throw new UnsupportedOperationException(
+          "Unexpected Expression.SortDirection: Clustered!");
+    };
+  }
+
   @Override
   public RelNode visit(Fetch fetch) throws RuntimeException {
     RelNode child = fetch.getInput().accept(this);
diff --git a/isthmus/src/test/java/io/substrait/isthmus/ComplexSortTest.java b/isthmus/src/test/java/io/substrait/isthmus/ComplexSortTest.java
new file mode 
100644
index 000000000..4972fce50
--- /dev/null
+++ b/isthmus/src/test/java/io/substrait/isthmus/ComplexSortTest.java
@@ -0,0 +1,210 @@
+package io.substrait.isthmus;
+
+import static io.substrait.isthmus.SqlConverterBase.EXTENSION_COLLECTION;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+import io.substrait.dsl.SubstraitBuilder;
+import io.substrait.expression.Expression;
+import io.substrait.relation.Rel;
+import io.substrait.type.TypeCreator;
+import java.io.PrintWriter;
+import java.io.StringWriter;
+import java.util.List;
+import org.apache.calcite.rel.RelCollation;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.externalize.RelWriterImpl;
+import org.apache.calcite.sql.SqlExplainLevel;
+import org.apache.calcite.util.Pair;
+import org.checkerframework.checker.nullness.qual.Nullable;
+import org.junit.jupiter.api.Test;
+
+/** Tests conversion of Substrait sorts with expression sort fields into Calcite plans. */
+public class ComplexSortTest extends PlanTestBase {
+
+  final TypeCreator R = TypeCreator.of(false);
+  final SubstraitBuilder b = new SubstraitBuilder(extensions);
+
+  final SubstraitToCalcite substraitToCalcite =
+      new SubstraitToCalcite(EXTENSION_COLLECTION, typeFactory);
+
+  /**
+   * A {@link RelWriterImpl} that annotates each {@link RelNode} with its {@link RelCollation} trait
+   * information. A {@link RelNode} is only annotated if its {@link RelCollation} trait is present
+   * and is not the default one.
+   */
+  public static class CollationRelWriter extends RelWriterImpl {
+    public CollationRelWriter(StringWriter sw) {
+      super(new PrintWriter(sw), SqlExplainLevel.EXPPLAN_ATTRIBUTES, false);
+    }
+
+    @Override
+    protected void explain_(RelNode rel, List<Pair<String, @Nullable Object>> values) {
+      var collation = rel.getTraitSet().getCollation();
+      // Guard against a trait set that carries no collation trait before dereferencing it.
+      if (collation != null && !collation.isDefault()) {
+        StringBuilder s = new StringBuilder();
+        // Apply the writer's current indentation so the annotation lines up with the plan output.
+        spacer.spaces(s);
+        s.append("Collation: ").append(collation);
+        pw.println(s);
+      }
+      super.explain_(rel, values);
+    }
+  }
+
+  /**
+   * Converts {@code rel} to a Calcite plan and asserts that its explain output, annotated by
+   * {@link CollationRelWriter}, equals {@code expected}.
+   */
+  private void assertCalcitePlan(Rel rel, String expected) {
+    RelNode relReturned = substraitToCalcite.convert(rel);
+    var sw = new StringWriter();
+    relReturned.explain(new CollationRelWriter(sw));
+    assertEquals(expected, sw.toString());
+  }
+
+  @Test
+  void handleInputReferenceSort() {
+    // CREATE TABLE example (a VARCHAR)
+    // SELECT a FROM example ORDER BY a
+
+    Rel rel =
+        b.project(
+            input -> b.fieldReferences(input, 0),
+            b.remap(1),
+            b.sort(
+                input ->
+                    List.of(
+                        b.sortField(
+                            b.fieldReference(input, 0), Expression.SortDirection.ASC_NULLS_LAST)),
+                b.namedScan(List.of("example"), List.of("a"), List.of(R.STRING))));
+
+    String expected =
+        """
+        Collation: [0]
+        LogicalSort(sort0=[$0], dir0=[ASC])
+          LogicalTableScan(table=[[example]])
+        """;
+
+    assertCalcitePlan(rel, expected);
+  }
+
+  @Test
+  void handleCastExpressionSort() {
+    // CREATE TABLE example (a VARCHAR)
+    // SELECT a FROM example ORDER BY a::INT
+
+    Rel rel =
+        b.project(
+            input -> b.fieldReferences(input, 0),
+            b.remap(1),
+            b.sort(
+                input ->
+                    List.of(
+                        b.sortField(
+                            b.cast(b.fieldReference(input, 0), R.I32),
+                            Expression.SortDirection.ASC_NULLS_LAST)),
+                b.namedScan(List.of("example"), List.of("a"), List.of(R.STRING))));
+
+    String expected =
+        """
+        LogicalProject(a0=[$0])
+          Collation: [1]
+          LogicalSort(sort0=[$1], dir0=[ASC])
+            LogicalProject(a=[$0], a0=[CAST($0):INTEGER NOT NULL])
+              LogicalTableScan(table=[[example]])
+        """;
+
+    assertCalcitePlan(rel, expected);
+  }
+
+  @Test
+  void handleCastProjectAndSortWithSortDirection() {
+    // CREATE TABLE example (a VARCHAR)
+    // SELECT a::INT FROM example ORDER BY a::INT DESC NULLS LAST
+
+    Rel rel =
+        b.project(
+            input -> List.of(b.cast(b.fieldReference(input, 0), R.I32)),
+            b.remap(1),
+            b.sort(
+                input ->
+                    List.of(
+                        b.sortField(
+                            b.cast(b.fieldReference(input, 0), R.I32),
+                            Expression.SortDirection.DESC_NULLS_LAST)),
+                b.namedScan(List.of("example"), List.of("a"), List.of(R.STRING))));
+
+    String expected =
+        """
+        LogicalProject(a0=[CAST($0):INTEGER NOT NULL])
+          Collation: [1 DESC-nulls-last]
+          LogicalSort(sort0=[$1], dir0=[DESC-nulls-last])
+            LogicalProject(a=[$0], a0=[CAST($0):INTEGER NOT NULL])
+              LogicalTableScan(table=[[example]])
+        """;
+
+    assertCalcitePlan(rel, expected);
+  }
+
+  @Test
+  void handleCastSortToOriginalType() {
+    // CREATE TABLE example (a VARCHAR)
+    // SELECT a FROM example ORDER BY a::VARCHAR DESC NULLS LAST
+
+    Rel rel =
+        b.project(
+            input -> List.of(b.fieldReference(input, 0)),
+            b.remap(1),
+            b.sort(
+                input ->
+                    List.of(
+                        b.sortField(
+                            b.cast(b.fieldReference(input, 0), R.STRING),
+                            Expression.SortDirection.DESC_NULLS_LAST)),
+                b.namedScan(List.of("example"), List.of("a"), List.of(R.STRING))));
+
+    String expected =
+        """
+        LogicalProject(a0=[$0])
+          Collation: [1 DESC-nulls-last]
+          LogicalSort(sort0=[$1], dir0=[DESC-nulls-last])
+            LogicalProject(a=[$0], a0=[$0])
+              LogicalTableScan(table=[[example]])
+        """;
+
+    assertCalcitePlan(rel, expected);
+  }
+
+  @Test
+  void handleComplex2ExpressionSort() {
+    // CREATE TABLE example (a VARCHAR, b INT)
+    // SELECT a, b FROM example ORDER BY a::INT DESC NULLS FIRST, -b + 42 ASC NULLS LAST
+
+    Rel rel =
+        b.project(
+            input -> List.of(b.fieldReference(input, 0), b.fieldReference(input, 1)),
+            b.remap(2, 3),
+            b.sort(
+                input ->
+                    List.of(
+                        b.sortField(
+                            b.cast(b.fieldReference(input, 0), R.I32),
+                            Expression.SortDirection.DESC_NULLS_FIRST),
+                        b.sortField(
+                            b.add(b.negate(b.fieldReference(input, 1)), b.i32(42)),
+                            Expression.SortDirection.ASC_NULLS_LAST)),
+                b.namedScan(List.of("example"), List.of("a", "b"), List.of(R.STRING, R.I32))));
+
+    String expected =
+        """
+        LogicalProject(a0=[$0], b0=[$1])
+          Collation: [2 DESC, 3]
+          LogicalSort(sort0=[$2], sort1=[$3], dir0=[DESC], dir1=[ASC])
+            LogicalProject(a=[$0], b=[$1], a0=[CAST($0):INTEGER NOT NULL], $f3=[+(-($1), 42)])
+              LogicalTableScan(table=[[example]])
+        """;
+
+    assertCalcitePlan(rel, expected);
+  }
+}