From 049308f7c134d78d72cbaa61a3f1fc0ce991564d Mon Sep 17 00:00:00 2001 From: Matt Bovel Date: Mon, 10 Feb 2025 13:52:15 +0000 Subject: [PATCH] Fix and extend refining annotation comparison --- compiler/src/dotty/tools/dotc/ast/Trees.scala | 25 ------- compiler/src/dotty/tools/dotc/ast/tpd.scala | 54 +++++++++++++++ .../dotty/tools/dotc/core/Annotations.scala | 6 +- tests/neg/annot-refining-infer.scala | 40 +++++++++++ tests/neg/annot-refining-sub.scala | 69 +++++++++---------- 5 files changed, 133 insertions(+), 61 deletions(-) create mode 100644 tests/neg/annot-refining-infer.scala diff --git a/compiler/src/dotty/tools/dotc/ast/Trees.scala b/compiler/src/dotty/tools/dotc/ast/Trees.scala index fdefc14aadd6..95ef5f5e1be9 100644 --- a/compiler/src/dotty/tools/dotc/ast/Trees.scala +++ b/compiler/src/dotty/tools/dotc/ast/Trees.scala @@ -189,31 +189,6 @@ object Trees { override def toText(printer: Printer): Text = printer.toText(this) - def sameTree(that: Tree[?]): Boolean = { - def isSame(x: Any, y: Any): Boolean = - x.asInstanceOf[AnyRef].eq(y.asInstanceOf[AnyRef]) || { - x match { - case x: Tree[?] => - y match { - case y: Tree[?] => x.sameTree(y) - case _ => false - } - case x: List[?] => - y match { - case y: List[?] => x.corresponds(y)(isSame) - case _ => false - } - case _ => - false - } - } - this.getClass == that.getClass && { - val it1 = this.productIterator - val it2 = that.productIterator - it1.corresponds(it2)(isSame) - } - } - override def hashCode(): Int = System.identityHashCode(this) override def equals(that: Any): Boolean = this eq that.asInstanceOf[AnyRef] } diff --git a/compiler/src/dotty/tools/dotc/ast/tpd.scala b/compiler/src/dotty/tools/dotc/ast/tpd.scala index 77e3387c5ce0..68c01d3477ff 100644 --- a/compiler/src/dotty/tools/dotc/ast/tpd.scala +++ b/compiler/src/dotty/tools/dotc/ast/tpd.scala @@ -1020,6 +1020,60 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo { else applyOverloaded(tree, nme.EQ, that :: Nil, Nil, defn.BooleanType) + def sameTree(that: Tree, thisParamSyms: List[Symbol] = Nil, thatParamRefs: List[TermRef] = Nil)(using Context): Boolean = + def recur(tree1: Tree, tree2: Tree) = + tree1.sameTree(tree2, thisParamSyms, thatParamRefs) + + def sameTrees(trees1: List[Tree], trees2: List[Tree]) = + trees1.corresponds(trees2)(recur) + + def sameType(tp1: Type, tp2: Type) = + (tp1 frozen_=:= tp2) || (tp1.subst(thisParamSyms, thatParamRefs) frozen_=:= tp2) + + val res = tree match + case Literal(_) | Ident(_) => + sameType(tree.tpe, that.tpe) + case Select(qual1, name1) => + that match + case Select(qual2, name2) => name1 == name2 && recur(qual1, qual2) + case _ => false + case Apply(fun1, args1) => + that match + case Apply(fun2, args2) => recur(fun1, fun2) && sameTrees(args1, args2) + case _ => false + case TypeApply(fun1, args1) => + that match + case TypeApply(fun2, args2) => + recur(fun1, fun2) && args1.corresponds(args2)((arg1, arg2) => sameType(arg1.tpe, arg2.tpe)) + case _ => false + case tpt1: TypeTree => + that match + case tpt2: TypeTree => sameType(tpt1.tpe, tpt2.tpe) + case _ => false + case Typed(expr1, tpt1) => + that match + case Typed(expr2, tpt2) => recur(expr1, expr2) && sameType(tpt1.tpe, tpt2.tpe) + case _ => false + case New(tpt1) => + that match + case New(tpt2) => sameType(tpt1.tpe, tpt2.tpe) + case _ => false + case closureDef(def1) => + that match + case closureDef(def2) => + val newThisParamSyms = def1.symbol.paramSymss.flatten ++ thisParamSyms + val newThatParamRefs = def2.symbol.paramSymss.flatten.map(_.termRef) ++ thatParamRefs + def1.rhs.sameTree(def2.rhs, newThisParamSyms, newThatParamRefs) + case _ => false + case Block(stats1, expr1) => + that match + case Block(stats2, expr2) => sameTrees(stats1, stats2) && recur(expr1, expr2) + case _ => false + case _ => false + + res + + /** `tree.isInstanceOf[tp]`, with special treatment of singleton types */ def isInstance(tp: Type)(using Context): Tree = tp.dealias match { case ConstantType(c) if c.tag == StringTag => diff --git a/compiler/src/dotty/tools/dotc/core/Annotations.scala b/compiler/src/dotty/tools/dotc/core/Annotations.scala index 1a5cf2b03e06..2e787cfa3c7b 100644 --- a/compiler/src/dotty/tools/dotc/core/Annotations.scala +++ b/compiler/src/dotty/tools/dotc/core/Annotations.scala @@ -43,6 +43,9 @@ object Annotations { def argumentConstantString(i: Int)(using Context): Option[String] = for (case Constant(s: String) <- argumentConstant(i)) yield s + /** All type and term argument trees of this annotation in a single flat list */ + private def allArguments(using Context): List[Tree] = tpd.allArguments(tree) + /** The tree evaluation is in progress. */ def isEvaluating: Boolean = false @@ -88,7 +91,8 @@ object Annotations { def ensureCompleted(using Context): Unit = tree def sameAnnotation(that: Annotation)(using Context): Boolean = - symbol == that.symbol && tree.sameTree(that.tree) + def sameArg(arg1: Tree, arg2: Tree): Boolean = tpd.stripNamedArg(arg1).sameTree(tpd.stripNamedArg(arg2)) + symbol == that.symbol && allArguments.corresponds(that.allArguments)(sameArg) def hasOneOfMetaAnnotation(metaSyms: Set[Symbol], orNoneOf: Set[Symbol] = Set.empty)(using Context): Boolean = atPhaseNoLater(erasurePhase) { def go(metaSyms: Set[Symbol]) = diff --git a/tests/neg/annot-refining-infer.scala b/tests/neg/annot-refining-infer.scala new file mode 100644 index 000000000000..43dac83a4cd9 --- /dev/null +++ b/tests/neg/annot-refining-infer.scala @@ -0,0 +1,40 @@ +class MyAnnotation(x: Any) extends scala.annotation.RefiningAnnotation + +def id[T](x: T): T = x +def id2[T](x: T, y: T): T = x + +def foo1[T](x: T, g: T => Unit): T = x +def foo2[T](x: T, y: T, g: T => Unit): T = x +def foo3[T](g: T => Unit, x: T, y: T): T = x +def foo4[T](x: T, g: T => Unit, h: T => Unit): T = x + +def take42[T](x: T @MyAnnotation(42)): Unit = () +def take43[T](x: T @MyAnnotation(43)): Unit = () +def take42or43[S](x: S @MyAnnotation(42) | S @MyAnnotation(43)): Unit = () +def take42or43Int(x: Int @MyAnnotation(42) | Int @MyAnnotation(43)): Unit = () + +def main = + val c42: Int @MyAnnotation(42) = ??? + val c43: Int @MyAnnotation(43) = ??? + + val v01 = id2[Int @MyAnnotation(42) | Int @MyAnnotation(43)](c42, c43) + val v02: Int @MyAnnotation(42) | Int @MyAnnotation(43) = c42 + val v03: Int @MyAnnotation(42) | Int @MyAnnotation(43) = id2(c42, c43) + + val v04 = foo1(c42, take42) + val v05: Int @MyAnnotation(42) = v13 + val v06 = foo1(c42, take43) // error + val v07 = foo1(c42, take42or43) + + val v08 = foo2(c42, c42, take42) + val v09: Int @MyAnnotation(42) = v15 + val v10 = foo2(c42, c43, take42) // error + val v11 = foo2(c42, c43, take42or43) // error + val v12 = foo2(c42, c43, take42or43Int) + + val v13 = foo3(take42or43, c42, c43) // error + val v14 = foo3(take42or43Int, c42, c43) + + val v15 = foo4(c42, take42, take42) + val v16: Int @MyAnnotation(42) = v15 + val v17 = foo4(c42, take42, take43) // error diff --git a/tests/neg/annot-refining-sub.scala b/tests/neg/annot-refining-sub.scala index c42602f4ffed..538595113229 100644 --- a/tests/neg/annot-refining-sub.scala +++ b/tests/neg/annot-refining-sub.scala @@ -22,18 +22,18 @@ def main = val c: Int = 42 val o: O.type = O - val v1: Int @annot1(1) = ??? : Int @annot1(1) // error: fixme (constants are equal) + val v1: Int @annot1(1) = ??? : Int @annot1(1) val v2: Int @annot1(c) = ??? : Int @annot1(c) val v3: Int @annot1(O.d) = ??? : Int @annot1(O.d) - val v4: Int @annot1(O.d) = ??? : Int @annot1(o.d) // error: fixme? - val v5: Int @annot1((1, 2)) = ??? : Int @annot1((1, 2)) // error: fixme - val v6: Int @annot1(1 + 2) = ??? : Int @annot1(1 + 2) // error: fixme - val v7: Int @annot1(1 + 2) = ??? : Int @annot1(2 + 1) // error: fixme? should constant fold? - val v8: Int @annot1(1 + c) = ??? : Int @annot1(1 + c) // error: fixme - val v9: Int @annot1(1 + c) = ??? : Int @annot1(c + 1) // error (no algebraic normalization) - val v10: Int @annot1(Box(1)) = ??? : Int @annot1(Box(1)) // error: fixme + val v4: Int @annot1(O.d) = ??? : Int @annot1(o.d) + val v5: Int @annot1((1, 2)) = ??? : Int @annot1((1, 2)) + val v6: Int @annot1(1 + 2) = ??? : Int @annot1(1 + 2) + val v7: Int @annot1(1 + 2) = ??? : Int @annot1(2 + 1) // error: no constant folding + val v8: Int @annot1(1 + c) = ??? : Int @annot1(1 + c) + val v9: Int @annot1(1 + c) = ??? : Int @annot1(c + 1) // error: no algebraic simplification + val v10: Int @annot1(Box(1)) = ??? : Int @annot1(Box(1)) val v11: Int @annot1(Box(c)) = ??? : Int @annot1(Box(c)) - val v12: Int @annot1(Box2(1)) = ??? : Int @annot1(Box2(1)) // error: fixme + val v12: Int @annot1(Box2(1)) = ??? : Int @annot1(Box2(1)) val v13: Int @annot1(Box2(c)) = ??? : Int @annot1(Box2(c)) val v14: Int @annot1(c: Int) = ??? : Int @annot1(c: Int) val v15: Int @annot1(c) = ??? : Int @annot1(c: Int) // error @@ -50,41 +50,40 @@ def main = val v26: Int @annot1(??? : Box3 {type T = Int}) = ??? : Int @annot1(??? : Box3 {type T = String}) // error val v27: Int @annot1(??? : Box3 {type T = Int}) = ??? : Int @annot1(??? : Box3) // error val v28: Int @annot1(a=c) = ??? : Int @annot1(a=c) - val v29: Int @annot1(a=c) = ??? : Int @annot1(c) // error: fixme (same arguments, named vs positional) - val v30: Int @annot1(c) = ??? : Int @annot1(a=c) // error: fixme + val v29: Int @annot1(a=c) = ??? : Int @annot1(c) + val v30: Int @annot1(c) = ??? : Int @annot1(a=c) val v31: Int @annot1((d: Int) => d) = ??? : Int @annot1((d: Int) => d) - val v32: Int @annot1((d: Int) => d) = ??? : Int @annot1((e: Int) => e) // error: fixme (alpha equivalence) - val v33: Int @annot1((e: Int) => e) = ??? : Int @annot1((d: Int) => d) // error: fixme - val v34: Int @annot1((d: Int) => d + 1) = ??? : Int @annot1((e: Int) => e + 1) // error: fixme - val v35: Int @annot1((d: Int) => d + 1) = ??? : Int @annot1((e: Int) => e + 1) // error: fixme - val v36: Int @annot1((d: Int) => id[d.type]) = ??? : Int @annot1((e: Int) => id[e.type]) // error: fixme - val v37: Int @annot1((d: Box3) => id[d.T]) = ??? : Int @annot1((e: Box3) => id[e.T]) // error: fixme - val v38: Int @annot1((d: Int) => (d: Int) => d) = ??? : Int @annot1((e: Int) => (e: Int) => e) // error: fixme - val v39: Int @annot1((d: Int) => ((e: Int) => d)(2)) = ??? : Int @annot1((e: Int) => ((e: Int) => e)(2)) // error: fixme - - val v40: Int @annot2(1, 2) = ??? : Int @annot2(1, 2) // error: fixme + val v32: Int @annot1((d: Int) => d) = ??? : Int @annot1((e: Int) => e) + val v33: Int @annot1((e: Int) => e) = ??? : Int @annot1((d: Int) => d) + val v34: Int @annot1((d: Int) => d + 1) = ??? : Int @annot1((e: Int) => e + 1) + val v35: Int @annot1((d: Int) => id(d)) = ??? : Int @annot1((e: Int) => id(e)) + val v36: Int @annot1((d: Int) => id[d.type]) = ??? : Int @annot1((e: Int) => id[e.type]) + val v37: Int @annot1((d: Box3) => id[d.T]) = ??? : Int @annot1((e: Box3) => id[e.T]) + val v38: Int @annot1((d: Int) => (d: Int) => d) = ??? : Int @annot1((e: Int) => (e: Int) => e) + val v39: Int @annot1((d: Int) => ((e: Int) => e)(2)) = ??? : Int @annot1((e: Int) => ((e: Int) => e)(2)) + val v40: Int @annot2(1, 2) = ??? : Int @annot2(1, 2) val v41: Int @annot2(c, c) = ??? : Int @annot2(c, c) - val v42: Int @annot2(c, c) = ??? : Int @annot2(a=c, b=c) // error: fixme - val v43: Int @annot2(a=c, c) = ??? : Int @annot2(c, b=c) // error: fixme - val v44: Int @annot2(a=c, b=c) = ??? : Int @annot2(c, c) // error: fixme + val v42: Int @annot2(c, c) = ??? : Int @annot2(a=c, b=c) + val v43: Int @annot2(a=c, c) = ??? : Int @annot2(c, b=c) + val v44: Int @annot2(a=c, b=c) = ??? : Int @annot2(c, c) - val v45: Int @annot3(1) = ??? : Int @annot3(1) // error: fixme + val v45: Int @annot3(1) = ??? : Int @annot3(1) val v46: Int @annot3(c) = ??? : Int @annot3(c) - val v47: Int @annot3(1) = ??? : Int @annot3(1, 3) // error: fixme - val v48: Int @annot3(1, 3) = ??? : Int @annot3(1) // error: fixme - val v49: Int @annot3(c) = ??? : Int @annot3(c, 3) // error: fixme - val v50: Int @annot3(c, 3) = ??? : Int @annot3(c) // error: fixme + val v47: Int @annot3(1) = ??? : Int @annot3(1, 3) // error: default arg tree is different, fix in the future? + val v48: Int @annot3(1, 3) = ??? : Int @annot3(1) // error: same as above + val v49: Int @annot3(c) = ??? : Int @annot3(c, 3) // error: same as above + val v50: Int @annot3(c, 3) = ??? : Int @annot3(c) // error: same as above - val v51: Int @annot4[1] = ??? : Int @annot4[1] // error: fixme + val v51: Int @annot4[1] = ??? : Int @annot4[1] val v52: Int @annot4[c.type] = ??? : Int @annot4[c.type] val v53: Int @annot4[O.d.type] = ??? : Int @annot4[O.d.type] - val v54: Int @annot4[O.d.type] = ??? : Int @annot4[o.d.type]// error: fixme? + val v54: Int @annot4[O.d.type] = ??? : Int @annot4[o.d.type] val v55: Int @annot4[Int] = ??? : Int @annot4[Int] val v56: Int @annot4[Int] = ??? : Int @annot4[1] // error - val v57: Int @annot4[(1, 2)] = ??? : Int @annot4[(1, 2)] // error: fixme - val v58: Int @annot4[1 + 2] = ??? : Int @annot4[1 + 2] // error: fixme - val v59: Int @annot4[1 + 2] = ??? : Int @annot4[2 + 1] // error: fixme - val v60: Int @annot4[1 + c.type] = ??? : Int @annot4[1 + c.type] // error: fixme + val v57: Int @annot4[(1, 2)] = ??? : Int @annot4[(1, 2)] + val v58: Int @annot4[1 + 2] = ??? : Int @annot4[1 + 2] + val v59: Int @annot4[1 + 2] = ??? : Int @annot4[2 + 1] + val v60: Int @annot4[1 + c.type] = ??? : Int @annot4[1 + c.type] val v61: Int @annot4[1 + c.type] = ??? : Int @annot4[c.type + 1] // error val v62: Int @annot4[Box[Int]] = ??? : Int @annot4[Box[Int]] val v63: Int @annot4[Box[String]] = ??? : Int @annot4[Box[Int]] // error