Skip to content

Commit

Permalink
Fix and extend refining annotation comparison
Browse files Browse the repository at this point in the history
  • Loading branch information
mbovel committed Feb 10, 2025
1 parent 10d23a6 commit 049308f
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 61 deletions.
25 changes: 0 additions & 25 deletions compiler/src/dotty/tools/dotc/ast/Trees.scala
Original file line number Diff line number Diff line change
Expand Up @@ -189,31 +189,6 @@ object Trees {

override def toText(printer: Printer): Text = printer.toText(this)

def sameTree(that: Tree[?]): Boolean = {
def isSame(x: Any, y: Any): Boolean =
x.asInstanceOf[AnyRef].eq(y.asInstanceOf[AnyRef]) || {
x match {
case x: Tree[?] =>
y match {
case y: Tree[?] => x.sameTree(y)
case _ => false
}
case x: List[?] =>
y match {
case y: List[?] => x.corresponds(y)(isSame)
case _ => false
}
case _ =>
false
}
}
this.getClass == that.getClass && {
val it1 = this.productIterator
val it2 = that.productIterator
it1.corresponds(it2)(isSame)
}
}

override def hashCode(): Int = System.identityHashCode(this)
override def equals(that: Any): Boolean = this eq that.asInstanceOf[AnyRef]
}
Expand Down
54 changes: 54 additions & 0 deletions compiler/src/dotty/tools/dotc/ast/tpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1020,6 +1020,60 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
else
applyOverloaded(tree, nme.EQ, that :: Nil, Nil, defn.BooleanType)

def sameTree(that: Tree, thisParamSyms: List[Symbol] = Nil, thatParamRefs: List[TermRef] = Nil)(using Context): Boolean =
def recur(tree1: Tree, tree2: Tree) =
tree1.sameTree(tree2, thisParamSyms, thatParamRefs)

def sameTrees(trees1: List[Tree], trees2: List[Tree]) =
trees1.corresponds(trees2)(recur)

def sameType(tp1: Type, tp2: Type) =
(tp1 frozen_=:= tp2) || (tp1.subst(thisParamSyms, thatParamRefs) frozen_=:= tp2)

val res = tree match
case Literal(_) | Ident(_) =>
sameType(tree.tpe, that.tpe)
case Select(qual1, name1) =>
that match
case Select(qual2, name2) => name1 == name2 && recur(qual1, qual2)
case _ => false
case Apply(fun1, args1) =>
that match
case Apply(fun2, args2) => recur(fun1, fun2) && sameTrees(args1, args2)
case _ => false
case TypeApply(fun1, args1) =>
that match
case TypeApply(fun2, args2) =>
recur(fun1, fun2) && args1.corresponds(args2)((arg1, arg2) => sameType(arg1.tpe, arg2.tpe))
case _ => false
case tpt1: TypeTree =>
that match
case tpt2: TypeTree => sameType(tpt1.tpe, tpt2.tpe)
case _ => false
case Typed(expr1, tpt1) =>
that match
case Typed(expr2, tpt2) => recur(expr1, expr2) && sameType(tpt1.tpe, tpt2.tpe)
case _ => false
case New(tpt1) =>
that match
case New(tpt2) => sameType(tpt1.tpe, tpt2.tpe)
case _ => false
case closureDef(def1) =>
that match
case closureDef(def2) =>
val newThisParamSyms = def1.symbol.paramSymss.flatten ++ thisParamSyms
val newThatParamRefs = def2.symbol.paramSymss.flatten.map(_.termRef) ++ thatParamRefs
def1.rhs.sameTree(def2.rhs, newThisParamSyms, newThatParamRefs)
case _ => false
case Block(stats1, expr1) =>
that match
case Block(stats2, expr2) => sameTrees(stats1, stats2) && recur(expr1, expr2)
case _ => false
case _ => false

res


/** `tree.isInstanceOf[tp]`, with special treatment of singleton types */
def isInstance(tp: Type)(using Context): Tree = tp.dealias match {
case ConstantType(c) if c.tag == StringTag =>
Expand Down
6 changes: 5 additions & 1 deletion compiler/src/dotty/tools/dotc/core/Annotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ object Annotations {
def argumentConstantString(i: Int)(using Context): Option[String] =
for (case Constant(s: String) <- argumentConstant(i)) yield s

/** All type and term argument trees of this annotation in a single flat list */
private def allArguments(using Context): List[Tree] = tpd.allArguments(tree)

/** The tree evaluation is in progress. */
def isEvaluating: Boolean = false

Expand Down Expand Up @@ -88,7 +91,8 @@ object Annotations {
def ensureCompleted(using Context): Unit = tree

def sameAnnotation(that: Annotation)(using Context): Boolean =
symbol == that.symbol && tree.sameTree(that.tree)
def sameArg(arg1: Tree, arg2: Tree): Boolean = tpd.stripNamedArg(arg1).sameTree(tpd.stripNamedArg(arg2))
symbol == that.symbol && allArguments.corresponds(that.allArguments)(sameArg)

def hasOneOfMetaAnnotation(metaSyms: Set[Symbol], orNoneOf: Set[Symbol] = Set.empty)(using Context): Boolean = atPhaseNoLater(erasurePhase) {
def go(metaSyms: Set[Symbol]) =
Expand Down
40 changes: 40 additions & 0 deletions tests/neg/annot-refining-infer.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
class MyAnnotation(x: Any) extends scala.annotation.RefiningAnnotation

def id[T](x: T): T = x
def id2[T](x: T, y: T): T = x

def foo1[T](x: T, g: T => Unit): T = x
def foo2[T](x: T, y: T, g: T => Unit): T = x
def foo3[T](g: T => Unit, x: T, y: T): T = x
def foo4[T](x: T, g: T => Unit, h: T => Unit): T = x

def take42[T](x: T @MyAnnotation(42)): Unit = ()
def take43[T](x: T @MyAnnotation(43)): Unit = ()
def take42or43[S](x: S @MyAnnotation(42) | S @MyAnnotation(43)): Unit = ()
def take42or43Int(x: Int @MyAnnotation(42) | Int @MyAnnotation(43)): Unit = ()

def main =
val c42: Int @MyAnnotation(42) = ???
val c43: Int @MyAnnotation(43) = ???

val v01 = id2[Int @MyAnnotation(42) | Int @MyAnnotation(43)](c42, c43)
val v02: Int @MyAnnotation(42) | Int @MyAnnotation(43) = c42
val v03: Int @MyAnnotation(42) | Int @MyAnnotation(43) = id2(c42, c43)

val v04 = foo1(c42, take42)
val v05: Int @MyAnnotation(42) = v13
val v06 = foo1(c42, take43) // error
val v07 = foo1(c42, take42or43)

val v08 = foo2(c42, c42, take42)
val v09: Int @MyAnnotation(42) = v15
val v10 = foo2(c42, c43, take42) // error
val v11 = foo2(c42, c43, take42or43) // error
val v12 = foo2(c42, c43, take42or43Int)

val v13 = foo3(take42or43, c42, c43) // error
val v14 = foo3(take42or43Int, c42, c43)

val v15 = foo4(c42, take42, take42)
val v16: Int @MyAnnotation(42) = v15
val v17 = foo4(c42, take42, take43) // error
69 changes: 34 additions & 35 deletions tests/neg/annot-refining-sub.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,18 @@ def main =
val c: Int = 42
val o: O.type = O

val v1: Int @annot1(1) = ??? : Int @annot1(1) // error: fixme (constants are equal)
val v1: Int @annot1(1) = ??? : Int @annot1(1)
val v2: Int @annot1(c) = ??? : Int @annot1(c)
val v3: Int @annot1(O.d) = ??? : Int @annot1(O.d)
val v4: Int @annot1(O.d) = ??? : Int @annot1(o.d) // error: fixme?
val v5: Int @annot1((1, 2)) = ??? : Int @annot1((1, 2)) // error: fixme
val v6: Int @annot1(1 + 2) = ??? : Int @annot1(1 + 2) // error: fixme
val v7: Int @annot1(1 + 2) = ??? : Int @annot1(2 + 1) // error: fixme? should constant fold?
val v8: Int @annot1(1 + c) = ??? : Int @annot1(1 + c) // error: fixme
val v9: Int @annot1(1 + c) = ??? : Int @annot1(c + 1) // error (no algebraic normalization)
val v10: Int @annot1(Box(1)) = ??? : Int @annot1(Box(1)) // error: fixme
val v4: Int @annot1(O.d) = ??? : Int @annot1(o.d)
val v5: Int @annot1((1, 2)) = ??? : Int @annot1((1, 2))
val v6: Int @annot1(1 + 2) = ??? : Int @annot1(1 + 2)
val v7: Int @annot1(1 + 2) = ??? : Int @annot1(2 + 1) // error: no constant folding
val v8: Int @annot1(1 + c) = ??? : Int @annot1(1 + c)
val v9: Int @annot1(1 + c) = ??? : Int @annot1(c + 1) // error: no algebraic simplification
val v10: Int @annot1(Box(1)) = ??? : Int @annot1(Box(1))
val v11: Int @annot1(Box(c)) = ??? : Int @annot1(Box(c))
val v12: Int @annot1(Box2(1)) = ??? : Int @annot1(Box2(1)) // error: fixme
val v12: Int @annot1(Box2(1)) = ??? : Int @annot1(Box2(1))
val v13: Int @annot1(Box2(c)) = ??? : Int @annot1(Box2(c))
val v14: Int @annot1(c: Int) = ??? : Int @annot1(c: Int)
val v15: Int @annot1(c) = ??? : Int @annot1(c: Int) // error
Expand All @@ -50,41 +50,40 @@ def main =
val v26: Int @annot1(??? : Box3 {type T = Int}) = ??? : Int @annot1(??? : Box3 {type T = String}) // error
val v27: Int @annot1(??? : Box3 {type T = Int}) = ??? : Int @annot1(??? : Box3) // error
val v28: Int @annot1(a=c) = ??? : Int @annot1(a=c)
val v29: Int @annot1(a=c) = ??? : Int @annot1(c) // error: fixme (same arguments, named vs positional)
val v30: Int @annot1(c) = ??? : Int @annot1(a=c) // error: fixme
val v29: Int @annot1(a=c) = ??? : Int @annot1(c)
val v30: Int @annot1(c) = ??? : Int @annot1(a=c)
val v31: Int @annot1((d: Int) => d) = ??? : Int @annot1((d: Int) => d)
val v32: Int @annot1((d: Int) => d) = ??? : Int @annot1((e: Int) => e) // error: fixme (alpha equivalence)
val v33: Int @annot1((e: Int) => e) = ??? : Int @annot1((d: Int) => d) // error: fixme
val v34: Int @annot1((d: Int) => d + 1) = ??? : Int @annot1((e: Int) => e + 1) // error: fixme
val v35: Int @annot1((d: Int) => d + 1) = ??? : Int @annot1((e: Int) => e + 1) // error: fixme
val v36: Int @annot1((d: Int) => id[d.type]) = ??? : Int @annot1((e: Int) => id[e.type]) // error: fixme
val v37: Int @annot1((d: Box3) => id[d.T]) = ??? : Int @annot1((e: Box3) => id[e.T]) // error: fixme
val v38: Int @annot1((d: Int) => (d: Int) => d) = ??? : Int @annot1((e: Int) => (e: Int) => e) // error: fixme
val v39: Int @annot1((d: Int) => ((e: Int) => d)(2)) = ??? : Int @annot1((e: Int) => ((e: Int) => e)(2)) // error: fixme

val v40: Int @annot2(1, 2) = ??? : Int @annot2(1, 2) // error: fixme
val v32: Int @annot1((d: Int) => d) = ??? : Int @annot1((e: Int) => e)
val v33: Int @annot1((e: Int) => e) = ??? : Int @annot1((d: Int) => d)
val v34: Int @annot1((d: Int) => d + 1) = ??? : Int @annot1((e: Int) => e + 1)
val v35: Int @annot1((d: Int) => id(d)) = ??? : Int @annot1((e: Int) => id(e))
val v36: Int @annot1((d: Int) => id[d.type]) = ??? : Int @annot1((e: Int) => id[e.type])
val v37: Int @annot1((d: Box3) => id[d.T]) = ??? : Int @annot1((e: Box3) => id[e.T])
val v38: Int @annot1((d: Int) => (d: Int) => d) = ??? : Int @annot1((e: Int) => (e: Int) => e)
val v39: Int @annot1((d: Int) => ((e: Int) => e)(2)) = ??? : Int @annot1((e: Int) => ((e: Int) => e)(2))
val v40: Int @annot2(1, 2) = ??? : Int @annot2(1, 2)
val v41: Int @annot2(c, c) = ??? : Int @annot2(c, c)
val v42: Int @annot2(c, c) = ??? : Int @annot2(a=c, b=c) // error: fixme
val v43: Int @annot2(a=c, c) = ??? : Int @annot2(c, b=c) // error: fixme
val v44: Int @annot2(a=c, b=c) = ??? : Int @annot2(c, c) // error: fixme
val v42: Int @annot2(c, c) = ??? : Int @annot2(a=c, b=c)
val v43: Int @annot2(a=c, c) = ??? : Int @annot2(c, b=c)
val v44: Int @annot2(a=c, b=c) = ??? : Int @annot2(c, c)

val v45: Int @annot3(1) = ??? : Int @annot3(1) // error: fixme
val v45: Int @annot3(1) = ??? : Int @annot3(1)
val v46: Int @annot3(c) = ??? : Int @annot3(c)
val v47: Int @annot3(1) = ??? : Int @annot3(1, 3) // error: fixme
val v48: Int @annot3(1, 3) = ??? : Int @annot3(1) // error: fixme
val v49: Int @annot3(c) = ??? : Int @annot3(c, 3) // error: fixme
val v50: Int @annot3(c, 3) = ??? : Int @annot3(c) // error: fixme
val v47: Int @annot3(1) = ??? : Int @annot3(1, 3) // error: default arg tree is different, fix in the future?
val v48: Int @annot3(1, 3) = ??? : Int @annot3(1) // error: same as above
val v49: Int @annot3(c) = ??? : Int @annot3(c, 3) // error: same as above
val v50: Int @annot3(c, 3) = ??? : Int @annot3(c) // error: same as above

val v51: Int @annot4[1] = ??? : Int @annot4[1] // error: fixme
val v51: Int @annot4[1] = ??? : Int @annot4[1]
val v52: Int @annot4[c.type] = ??? : Int @annot4[c.type]
val v53: Int @annot4[O.d.type] = ??? : Int @annot4[O.d.type]
val v54: Int @annot4[O.d.type] = ??? : Int @annot4[o.d.type]// error: fixme?
val v54: Int @annot4[O.d.type] = ??? : Int @annot4[o.d.type]
val v55: Int @annot4[Int] = ??? : Int @annot4[Int]
val v56: Int @annot4[Int] = ??? : Int @annot4[1] // error
val v57: Int @annot4[(1, 2)] = ??? : Int @annot4[(1, 2)] // error: fixme
val v58: Int @annot4[1 + 2] = ??? : Int @annot4[1 + 2] // error: fixme
val v59: Int @annot4[1 + 2] = ??? : Int @annot4[2 + 1] // error: fixme
val v60: Int @annot4[1 + c.type] = ??? : Int @annot4[1 + c.type] // error: fixme
val v57: Int @annot4[(1, 2)] = ??? : Int @annot4[(1, 2)]
val v58: Int @annot4[1 + 2] = ??? : Int @annot4[1 + 2]
val v59: Int @annot4[1 + 2] = ??? : Int @annot4[2 + 1]
val v60: Int @annot4[1 + c.type] = ??? : Int @annot4[1 + c.type]
val v61: Int @annot4[1 + c.type] = ??? : Int @annot4[c.type + 1] // error
val v62: Int @annot4[Box[Int]] = ??? : Int @annot4[Box[Int]]
val v63: Int @annot4[Box[String]] = ??? : Int @annot4[Box[Int]] // error
Expand Down

0 comments on commit 049308f

Please sign in to comment.