Skip to content

Commit

Permalink
Convert some global typing passes to mini passes (#11717)
Browse files Browse the repository at this point in the history
* Convert ShadowedPatternFields to a mini pass

* Case.Branch expression is not encountered during transformExpression

* Convert ShadowedPatternFields to Java

* Convert UnreachableMatchBranches to mini pass

* Refactor UnreachableMatchBranches to Java

* Fix NPE

* Add test with IGV graph

* CompilerTest.assertIR uses IGV dumper

* IGVDumper does not care about already existing node

* MiniPassTest can use IGV dumper

* [WIP] Skeleton of NestedPatternMatchMini

* Revert "[WIP] Skeleton of NestedPatternMatchMini"

This reverts commit babc82e.
  • Loading branch information
Akirathan authored Feb 8, 2025
1 parent 535ae00 commit efe8786
Show file tree
Hide file tree
Showing 15 changed files with 513 additions and 414 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -462,8 +462,11 @@ private ASTNode newNode(IR ir, Map<String, Object> props) {
}
bldr.id(nodeId);
props.forEach(bldr::property);
var existingNode = nodes.get(nodeId);
if (existingNode != null) {
return existingNode;
}
var node = bldr.build();
assert !nodes.containsKey(node.getId());
nodes.put(node.getId(), node);
if (currentBlockBldr() != null) {
currentBlockBldr().addNode(node);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
package org.enso.compiler.pass.lint;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.enso.compiler.context.InlineContext;
import org.enso.compiler.context.ModuleContext;
import org.enso.compiler.core.CompilerError;
import org.enso.compiler.core.IR;
import org.enso.compiler.core.ir.Expression;
import org.enso.compiler.core.ir.MetadataStorage;
import org.enso.compiler.core.ir.Name;
import org.enso.compiler.core.ir.Pattern;
import org.enso.compiler.core.ir.expression.Case;
import org.enso.compiler.core.ir.expression.Case.Branch;
import org.enso.compiler.core.ir.expression.warnings.Shadowed.PatternBinding;
import org.enso.compiler.pass.IRProcessingPass;
import org.enso.compiler.pass.MiniIRPass;
import org.enso.compiler.pass.MiniPassFactory;
import org.enso.compiler.pass.analyse.AliasAnalysis$;
import org.enso.compiler.pass.analyse.DataflowAnalysis$;
import org.enso.compiler.pass.analyse.DemandAnalysis$;
import org.enso.compiler.pass.analyse.TailCall;
import org.enso.compiler.pass.desugar.GenerateMethodBodies$;
import org.enso.compiler.pass.desugar.NestedPatternMatch$;
import org.enso.compiler.pass.resolve.IgnoredBindings$;
import scala.collection.immutable.List;
import scala.collection.immutable.Seq;
import scala.jdk.javaapi.CollectionConverters;

/**
* This pass detects and renames shadowed pattern fields.
*
* <p>This is necessary both in order to create a warning, but also to ensure that alias analysis
* doesn't get confused.
*
* <p>This pass requires no configuration.
*
* <p>This pass requires the context to provide:
*
* <p>- Nothing
*/
public final class ShadowedPatternFields implements MiniPassFactory {
public static final ShadowedPatternFields INSTANCE = new ShadowedPatternFields();

private ShadowedPatternFields() {}

@Override
public List<IRProcessingPass> precursorPasses() {
java.util.List<IRProcessingPass> list = java.util.List.of(GenerateMethodBodies$.MODULE$);
return CollectionConverters.asScala(list).toList();
}

@Override
public List<IRProcessingPass> invalidatedPasses() {
java.util.List<IRProcessingPass> list =
java.util.List.of(
AliasAnalysis$.MODULE$,
DataflowAnalysis$.MODULE$,
DemandAnalysis$.MODULE$,
IgnoredBindings$.MODULE$,
NestedPatternMatch$.MODULE$,
TailCall.INSTANCE);
return CollectionConverters.asScala(list).toList();
}

@Override
public MiniIRPass createForModuleCompilation(ModuleContext moduleContext) {
return new Mini();
}

@Override
public MiniIRPass createForInlineCompilation(InlineContext inlineContext) {
return new Mini();
}

private static final class Mini extends MiniIRPass {
@Override
@SuppressWarnings("unchecked")
public Expression transformExpression(Expression expr) {
return switch (expr) {
case Case.Branch branch -> lintCaseBranch(branch);
case Case.Expr caseExpr -> {
Seq<Branch> newBranches = caseExpr.branches().map(this::lintCaseBranch).toSeq();
yield caseExpr.copy(
caseExpr.scrutinee(),
newBranches,
caseExpr.isNested(),
caseExpr.location(),
caseExpr.passData(),
caseExpr.diagnostics(),
caseExpr.id());
}
default -> expr;
};
}

/**
* Lints for shadowed pattern variables in a case branch.
*
* @param branch the case branch to lint
* @return `branch`, with warnings for any shadowed pattern variables
*/
private Case.Branch lintCaseBranch(Case.Branch branch) {
var newPattern = lintPattern(branch.pattern());
return branch.copy(
newPattern,
branch.expression(),
branch.terminalBranch(),
branch.location(),
branch.passData(),
branch.diagnostics(),
branch.id());
}

/**
* Lints a pattern for shadowed pattern variables.
*
* <p>A later pattern variable shadows an earlier pattern variable with the same name.
*
* @param pattern the pattern to lint
* @return `pattern`, with a warning applied to any shadowed pattern variables
*/
private Pattern lintPattern(Pattern pattern) {
var seenNames = new HashSet<String>();
var lastSeen = new HashMap<String, IR>();

return go(pattern, seenNames, lastSeen);
}

private Pattern go(Pattern pattern, Set<String> seenNames, Map<String, IR> lastSeen) {
return switch (pattern) {
case Pattern.Name named -> {
var name = named.name().name();
if (seenNames.contains(name)) {
var warning = new PatternBinding(name, lastSeen.get(name), named.identifiedLocation());
lastSeen.put(name, named);
var blank = new Name.Blank(named.identifiedLocation(), new MetadataStorage());
var patternCopy = named.copyWithName(blank);
patternCopy.getDiagnostics().add(warning);
yield patternCopy;
} else if (!(named.name() instanceof Name.Blank)) {
lastSeen.put(name, named);
seenNames.add(name);
yield named;
} else {
yield named;
}
}
case Pattern.Constructor cons -> {
var newFields =
cons.fields().reverse().map(field -> go(field, seenNames, lastSeen)).reverse();
yield cons.copyWithFields(newFields);
}
case Pattern.Literal literal -> literal;
case Pattern.Type typed -> {
var name = typed.name().name();
if (seenNames.contains(name)) {
var warning = new PatternBinding(name, lastSeen.get(name), typed.identifiedLocation());
lastSeen.put(name, typed);
var blank = new Name.Blank(typed.identifiedLocation(), new MetadataStorage());
var typedCopy =
typed.copy(
blank,
typed.tpe(),
typed.location(),
typed.passData(),
typed.diagnostics(),
typed.id());
typedCopy.getDiagnostics().add(warning);
yield typedCopy;
} else if (!(typed.name() instanceof Name.Blank)) {
lastSeen.put(name, typed);
seenNames.add(name);
yield typed;
} else {
yield typed;
}
}
case Pattern.Documentation doc -> throw new CompilerError(
"Branch documentation should be desugared at an earlier stage.");
default -> pattern;
};
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
package org.enso.compiler.pass.optimise;

import java.util.ArrayList;
import java.util.stream.Stream;
import org.enso.compiler.context.InlineContext;
import org.enso.compiler.context.ModuleContext;
import org.enso.compiler.core.CompilerError;
import org.enso.compiler.core.ir.Expression;
import org.enso.compiler.core.ir.IdentifiedLocation;
import org.enso.compiler.core.ir.Pattern;
import org.enso.compiler.core.ir.expression.Case;
import org.enso.compiler.core.ir.expression.warnings.Unreachable;
import org.enso.compiler.pass.IRProcessingPass;
import org.enso.compiler.pass.MiniIRPass;
import org.enso.compiler.pass.MiniPassFactory;
import org.enso.compiler.pass.analyse.AliasAnalysis$;
import org.enso.compiler.pass.analyse.DataflowAnalysis$;
import org.enso.compiler.pass.analyse.DemandAnalysis$;
import org.enso.compiler.pass.analyse.TailCall;
import org.enso.compiler.pass.desugar.ComplexType$;
import org.enso.compiler.pass.desugar.FunctionBinding$;
import org.enso.compiler.pass.desugar.GenerateMethodBodies$;
import org.enso.compiler.pass.desugar.LambdaShorthandToLambda$;
import org.enso.compiler.pass.desugar.NestedPatternMatch$;
import org.enso.compiler.pass.resolve.DocumentationComments$;
import org.enso.compiler.pass.resolve.IgnoredBindings$;
import org.enso.scala.wrapper.ScalaConversions;
import scala.collection.immutable.List;
import scala.jdk.javaapi.CollectionConverters;

/**
* This pass discovers and optimizes away unreachable case branches.
*
* <p>It removes these unreachable expressions from the IR, and attaches a {@link
* org.enso.compiler.core.ir.Warning} diagnostic to the case expression itself.
*
* <p>Currently, a branch is considered 'unreachable' by this pass if:
*
* <ul>
* <li>It occurs after a catch-all branch.
* </ul>
*
* <p>In the future, this pass should be expanded to consider patterns that are entirely subsumed by
* previous patterns in its definition of unreachable, but this requires doing sophisticated
* coverage analysis, and hence should happen as part of the broader refactor of nested patterns
* desugaring.
*
* <p>This pass requires no configuration.
*
* <p>This pass requires the context to provide:
*
* <ul>
* <li>Nothing
* </ul>
*/
public final class UnreachableMatchBranches implements MiniPassFactory {
private UnreachableMatchBranches() {}

public static final UnreachableMatchBranches INSTANCE = new UnreachableMatchBranches();

@Override
public List<IRProcessingPass> precursorPasses() {
java.util.List<IRProcessingPass> passes = new ArrayList<>();
passes.add(ComplexType$.MODULE$);
passes.add(DocumentationComments$.MODULE$);
passes.add(FunctionBinding$.MODULE$);
passes.add(GenerateMethodBodies$.MODULE$);
passes.add(LambdaShorthandToLambda$.MODULE$);
return CollectionConverters.asScala(passes).toList();
}

@Override
public List<IRProcessingPass> invalidatedPasses() {
java.util.List<IRProcessingPass> passes = new ArrayList<>();
passes.add(AliasAnalysis$.MODULE$);
passes.add(DataflowAnalysis$.MODULE$);
passes.add(DemandAnalysis$.MODULE$);
passes.add(IgnoredBindings$.MODULE$);
passes.add(NestedPatternMatch$.MODULE$);
passes.add(TailCall.INSTANCE);
return CollectionConverters.asScala(passes).toList();
}

@Override
public MiniIRPass createForInlineCompilation(InlineContext inlineContext) {
return new Mini();
}

@Override
public MiniIRPass createForModuleCompilation(ModuleContext moduleContext) {
return new Mini();
}

private static class Mini extends MiniIRPass {
@Override
public Expression transformExpression(Expression expr) {
return switch (expr) {
case Case cse -> optimizeCase(cse);
default -> expr;
};
}

/**
* Optimizes a case expression by removing unreachable branches.
*
* <p>Additionally, it will attach a warning about unreachable branches to the case expression.
*
* @param cse the case expression to optimize
* @return `cse` with unreachable branches removed
*/
private Case optimizeCase(Case cse) {
if (cse instanceof Case.Expr expr) {
var branches = CollectionConverters.asJava(expr.branches());
var reachableNonCatchAllBranches =
branches.stream().takeWhile(branch -> !isCatchAll(branch));
var firstCatchAll = branches.stream().filter(this::isCatchAll).findFirst();
var unreachableBranches =
branches.stream().dropWhile(branch -> !isCatchAll(branch)).skip(1).toList();
List<Case.Branch> reachableBranches;
if (firstCatchAll.isPresent()) {
reachableBranches = appended(reachableNonCatchAllBranches, firstCatchAll.get());
} else {
reachableBranches = ScalaConversions.nil();
}

if (unreachableBranches.isEmpty()) {
return expr;
} else {
var firstUnreachableWithLoc =
unreachableBranches.stream()
.filter(branch -> branch.identifiedLocation() != null)
.findFirst();
var lastUnreachableWithLoc =
unreachableBranches.stream()
.filter(branch -> branch.identifiedLocation() != null)
.reduce((first, second) -> second);
IdentifiedLocation unreachableLocation = null;
if (firstUnreachableWithLoc.isPresent() && lastUnreachableWithLoc.isPresent()) {
unreachableLocation =
new IdentifiedLocation(
firstUnreachableWithLoc.get().location().get().start(),
lastUnreachableWithLoc.get().location().get().end(),
firstUnreachableWithLoc.get().id());
}

var diagnostic = new Unreachable.Branches(unreachableLocation);
var copiedExpr = expr.copyWithBranches(reachableBranches);
copiedExpr.getDiagnostics().add(diagnostic);
return copiedExpr;
}
} else {
throw new CompilerError("Unexpected case branch.");
}
}

/**
* Determines if a branch is a catch all branch.
*
* @param branch the branch to check
* @return `true` if `branch` is catch-all, otherwise `false`
*/
private boolean isCatchAll(Case.Branch branch) {
return switch (branch.pattern()) {
case Pattern.Name ignored -> true;
default -> false;
};
}

private static List<Case.Branch> appended(Stream<Case.Branch> branches, Case.Branch branch) {
var ret = new ArrayList<>(branches.toList());
ret.add(branch);
return CollectionConverters.asScala(ret).toList();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ class Passes(config: CompilerConfig) {
)
} else List())
++ List(
ShadowedPatternFields,
UnreachableMatchBranches,
ShadowedPatternFields.INSTANCE,
UnreachableMatchBranches.INSTANCE,
NestedPatternMatch,
IgnoredBindings,
TypeFunctions,
Expand Down
Loading

0 comments on commit efe8786

Please sign in to comment.