Skip to content

Commit

Permalink
CC: Infer more self types automatically (#19425)
Browse files Browse the repository at this point in the history
Fixes #19398 

Clean up the logic how we infer self types, and add a new clause:

> If we have an externally extensible class that itself does not have a
declared self type itself and also not in any of its base classes,
assume {cap} as the self type. Previously we would install a capture
set but then check after the fact that that capture set is indeed {cap}.
So it's less verbose to just assume that from the start.
  • Loading branch information
odersky authored Jan 14, 2024
2 parents b5ecaa0 + 8e0cac4 commit 9b5815a
Show file tree
Hide file tree
Showing 26 changed files with 80 additions and 88 deletions.
17 changes: 13 additions & 4 deletions compiler/src/dotty/tools/dotc/cc/CaptureOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -236,19 +236,19 @@ extension (tp: Type)
* (2) all covariant occurrences of cap replaced by `x*`, provided there
* are no occurrences in `T` at other variances. (1) is standard, whereas
* (2) is new.
*
*
* For (2), multiple-flipped covariant occurrences of cap won't be replaced.
* In other words,
*
* - For xs: List[File^] ==> List[File^{xs*}], the cap is replaced;
* - while f: [R] -> (op: File^ => R) -> R remains unchanged.
*
*
* Without this restriction, the signature of functions like withFile:
*
*
* (path: String) -> [R] -> (op: File^ => R) -> R
*
* could be refined to
*
*
* (path: String) -> [R] -> (op: File^{withFile*} => R) -> R
*
* which is clearly unsound.
Expand Down Expand Up @@ -315,6 +315,15 @@ extension (cls: ClassSymbol)
// and err on the side of impure.
&& selfType.exists && selfType.captureSet.isAlwaysEmpty

def baseClassHasExplicitSelfType(using Context): Boolean =
cls.baseClasses.exists: bc =>
bc.is(CaptureChecked) && bc.givenSelfType.exists

def matchesExplicitRefsInBaseClass(refs: CaptureSet)(using Context): Boolean =
cls.baseClasses.tail.exists: bc =>
val selfType = bc.givenSelfType
bc.is(CaptureChecked) && selfType.exists && selfType.captureSet.elems == refs.elems

extension (sym: Symbol)

/** A class is pure if:
Expand Down
39 changes: 16 additions & 23 deletions compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala
Original file line number Diff line number Diff line change
Expand Up @@ -672,17 +672,13 @@ class CheckCaptures extends Recheck, SymTransformer:
def checkInferredResult(tp: Type, tree: ValOrDefDef)(using Context): Type =
val sym = tree.symbol

def isLocal =
sym.owner.ownersIterator.exists(_.isTerm)
|| sym.accessBoundary(defn.RootClass).isContainedIn(sym.topLevelClass)

def canUseInferred = // If canUseInferred is false, all capturing types in the type of `sym` need to be given explicitly
sym.is(Private) // private symbols can always have inferred types
|| sym.name.is(DefaultGetterName) // default getters are exempted since otherwise it would be
// too annoying. This is a hole since a defualt getter's result type
// might leak into a type variable.
|| // non-local symbols cannot have inferred types since external capture types are not inferred
isLocal // local symbols still need explicit types if
sym.isLocalToCompilationUnit // local symbols still need explicit types if
&& !sym.owner.is(Trait) // they are defined in a trait, since we do OverridingPairs checking before capture inference

def addenda(expected: Type) = new Addenda:
Expand Down Expand Up @@ -1182,7 +1178,7 @@ class CheckCaptures extends Recheck, SymTransformer:
/** Check that self types of subclasses conform to self types of super classes.
* (See comment below how this is achieved). The check assumes that classes
* without an explicit self type have the universal capture set `{cap}` on the
* self type. If a class without explicit self type is not `effectivelyFinal`
* self type. If a class without explicit self type is not `effectivelySealed`
* it is checked that the inferred self type is universal, in order to assure
* that joint and separate compilation give the same result.
*/
Expand Down Expand Up @@ -1212,23 +1208,20 @@ class CheckCaptures extends Recheck, SymTransformer:
checkSelfAgainstParents(root, root.baseClasses)
val selfType = root.asClass.classInfo.selfType
interpolator(startingVariance = -1).traverse(selfType)
if !root.isEffectivelySealed then
def matchesExplicitRefsInBaseClass(refs: CaptureSet, cls: ClassSymbol): Boolean =
cls.baseClasses.tail.exists { psym =>
val selfType = psym.asClass.givenSelfType
selfType.exists && selfType.captureSet.elems == refs.elems
}
selfType match
case CapturingType(_, refs: CaptureSet.Var)
if !refs.elems.exists(_.isRootCapability) && !matchesExplicitRefsInBaseClass(refs, root) =>
// Forbid inferred self types unless they are already implied by an explicit
// self type in a parent.
report.error(
em"""$root needs an explicitly declared self type since its
|inferred self type $selfType
|is not visible in other compilation units that define subclasses.""",
root.srcPos)
case _ =>
selfType match
case CapturingType(_, refs: CaptureSet.Var)
if !root.isEffectivelySealed
&& !refs.elems.exists(_.isRootCapability)
&& !root.matchesExplicitRefsInBaseClass(refs)
=>
// Forbid inferred self types unless they are already implied by an explicit
// self type in a parent.
report.error(
em"""$root needs an explicitly declared self type since its
|inferred self type $selfType
|is not visible in other compilation units that define subclasses.""",
root.srcPos)
case _ =>
parentTrees -= root
capt.println(i"checked $root with $selfType")
end checkSelfTypes
Expand Down
31 changes: 22 additions & 9 deletions compiler/src/dotty/tools/dotc/cc/Setup.scala
Original file line number Diff line number Diff line change
Expand Up @@ -517,21 +517,34 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
tree.symbol match
case cls: ClassSymbol =>
val cinfo @ ClassInfo(prefix, _, ps, decls, selfInfo) = cls.classInfo
if ((selfInfo eq NoType) || cls.is(ModuleClass) && !cls.isStatic)
&& !cls.isPureClass
then
// add capture set to self type of nested classes if no self type is given explicitly.
val newSelfType = CapturingType(cinfo.selfType, CaptureSet.Var(cls))
val ps1 = inContext(ctx.withOwner(cls)):
ps.mapConserve(transformExplicitType(_))
val newInfo = ClassInfo(prefix, cls, ps1, decls, newSelfType)
def innerModule = cls.is(ModuleClass) && !cls.isStatic
val selfInfo1 =
if (selfInfo ne NoType) && !innerModule then
// if selfInfo is explicitly given then use that one, except if
// self info applies to non-static modules, these still need to be inferred
selfInfo
else if cls.isPureClass then
// is cls is known to be pure, nothing needs to be added to self type
selfInfo
else if !cls.isEffectivelySealed && !cls.baseClassHasExplicitSelfType then
// assume {cap} for completely unconstrained self types of publicly extensible classes
CapturingType(cinfo.selfType, CaptureSet.universal)
else
// Infer the self type for the rest, which is all classes without explicit
// self types (to which we also add nested module classes), provided they are
// neither pure, nor are publicily extensible with an unconstrained self type.
CapturingType(cinfo.selfType, CaptureSet.Var(cls))
val ps1 = inContext(ctx.withOwner(cls)):
ps.mapConserve(transformExplicitType(_))
if (selfInfo1 ne selfInfo) || (ps1 ne ps) then
val newInfo = ClassInfo(prefix, cls, ps1, decls, selfInfo1)
updateInfo(cls, newInfo)
capt.println(i"update class info of $cls with parents $ps selfinfo $selfInfo to $newInfo")
cls.thisType.asInstanceOf[ThisType].invalidateCaches()
if cls.is(ModuleClass) then
// if it's a module, the capture set of the module reference is the capture set of the self type
val modul = cls.sourceModule
updateInfo(modul, CapturingType(modul.info, newSelfType.captureSet))
updateInfo(modul, CapturingType(modul.info, selfInfo1.asInstanceOf[Type].captureSet))
modul.termRef.invalidateCaches()
case _ =>
case _ =>
Expand Down
9 changes: 8 additions & 1 deletion compiler/src/dotty/tools/dotc/core/SymDenotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1202,7 +1202,14 @@ object SymDenotations {
* is defined in Scala 3 and is neither abstract nor open.
*/
final def isEffectivelySealed(using Context): Boolean =
isOneOf(FinalOrSealed) || isClass && !isOneOf(EffectivelyOpenFlags)
isOneOf(FinalOrSealed)
|| isClass && (!isOneOf(EffectivelyOpenFlags)
|| isLocalToCompilationUnit)

final def isLocalToCompilationUnit(using Context): Boolean =
is(Private)
|| owner.ownersIterator.exists(_.isTerm)
|| accessBoundary(defn.RootClass).isContainedIn(symbol.topLevelClass)

final def isTransparentClass(using Context): Boolean =
is(TransparentType)
Expand Down
2 changes: 0 additions & 2 deletions scala2-library-cc/src/scala/collection/IndexedSeqView.scala
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ object IndexedSeqView {

@SerialVersionUID(3L)
private[collection] class IndexedSeqViewIterator[A](self: IndexedSeqView[A]^) extends AbstractIterator[A] with Serializable {
this: IndexedSeqViewIterator[A]^ =>
private[this] var current = 0
private[this] var remainder = self.length
override def knownSize: Int = remainder
Expand Down Expand Up @@ -90,7 +89,6 @@ object IndexedSeqView {
}
@SerialVersionUID(3L)
private[collection] class IndexedSeqViewReverseIterator[A](self: IndexedSeqView[A]^) extends AbstractIterator[A] with Serializable {
this: IndexedSeqViewReverseIterator[A]^ =>
private[this] var remainder = self.length
private[this] var pos = remainder - 1
@inline private[this] def _hasNext: Boolean = remainder > 0
Expand Down
3 changes: 0 additions & 3 deletions scala2-library-cc/src/scala/collection/Iterable.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
* See the NOTICE file distributed with this work for
* additional information regarding copyright ownership.
*/

package scala
package collection

Expand All @@ -29,7 +28,6 @@ import language.experimental.captureChecking
trait Iterable[+A] extends IterableOnce[A]
with IterableOps[A, Iterable, Iterable[A]]
with IterableFactoryDefaults[A, Iterable] {
this: Iterable[A]^ =>

// The collection itself
@deprecated("toIterable is internal and will be made protected; its name is similar to `toList` or `toSeq`, but it doesn't copy non-immutable collections", "2.13.7")
Expand Down Expand Up @@ -134,7 +132,6 @@ trait Iterable[+A] extends IterableOnce[A]
* and may be nondeterministic.
*/
trait IterableOps[+A, +CC[_], +C] extends Any with IterableOnce[A] with IterableOnceOps[A, CC, C] {
this: IterableOps[A, CC, C]^ =>

/**
* @return This collection as an `Iterable[A]`. No new collection will be built if `this` is already an `Iterable[A]`.
Expand Down
1 change: 0 additions & 1 deletion scala2-library-cc/src/scala/collection/IterableOnce.scala
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ import language.experimental.captureChecking
* @define coll collection
*/
trait IterableOnce[+A] extends Any {
this: IterableOnce[A]^ =>

/** Iterator can be used only once */
def iterator: Iterator[A]^{this}
Expand Down
3 changes: 1 addition & 2 deletions scala2-library-cc/src/scala/collection/Iterator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1302,5 +1302,4 @@ object Iterator extends IterableFactory[Iterator] {
}

/** Explicit instantiation of the `Iterator` trait to reduce class file size in subclasses. */
abstract class AbstractIterator[+A] extends Iterator[A]:
this: Iterator[A]^ =>
abstract class AbstractIterator[+A] extends Iterator[A]
1 change: 0 additions & 1 deletion scala2-library-cc/src/scala/collection/Map.scala
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,6 @@ trait Map[K, +V]
trait MapOps[K, +V, +CC[_, _] <: IterableOps[_, AnyConstr, _], +C]
extends IterableOps[(K, V), Iterable, C]
with PartialFunction[K, V] {
this: MapOps[K, V, CC, C]^ =>

override def view: MapView[K, V]^{this} = new MapView.Id(this)

Expand Down
4 changes: 1 addition & 3 deletions scala2-library-cc/src/scala/collection/MapView.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import caps.unsafe.unsafeAssumePure
trait MapView[K, +V]
extends MapOps[K, V, ({ type l[X, Y] = View[(X, Y)] })#l, View[(K, V)]]
with View[(K, V)] {
this: MapView[K, V]^ =>

override def view: MapView[K, V]^{this} = this

Expand Down Expand Up @@ -191,6 +190,5 @@ trait MapViewFactory extends collection.MapFactory[({ type l[X, Y] = View[(X, Y)

/** Explicit instantiation of the `MapView` trait to reduce class file size in subclasses. */
@SerialVersionUID(3L)
abstract class AbstractMapView[K, +V] extends AbstractView[(K, V)] with MapView[K, V]:
this: AbstractMapView[K, V]^ =>
abstract class AbstractMapView[K, +V] extends AbstractView[(K, V)] with MapView[K, V]

5 changes: 0 additions & 5 deletions scala2-library-cc/src/scala/collection/Stepper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ import scala.collection.Stepper.EfficientSplit
* @tparam A the element type of the Stepper
*/
trait Stepper[@specialized(Double, Int, Long) +A] {
this: Stepper[A]^ =>

/** Check if there's an element available. */
def hasStep: Boolean
Expand Down Expand Up @@ -186,7 +185,6 @@ object Stepper {

/** A Stepper for arbitrary element types. See [[Stepper]]. */
trait AnyStepper[+A] extends Stepper[A] {
this: AnyStepper[A]^ =>

def trySplit(): AnyStepper[A]

Expand Down Expand Up @@ -258,7 +256,6 @@ object AnyStepper {

/** A Stepper for Ints. See [[Stepper]]. */
trait IntStepper extends Stepper[Int] {
this: IntStepper^ =>

def trySplit(): IntStepper

Expand Down Expand Up @@ -298,7 +295,6 @@ object IntStepper {

/** A Stepper for Doubles. See [[Stepper]]. */
trait DoubleStepper extends Stepper[Double] {
this: DoubleStepper^ =>
def trySplit(): DoubleStepper

def spliterator[B >: Double]: Spliterator.OfDouble^{this} = new DoubleStepper.DoubleStepperSpliterator(this)
Expand Down Expand Up @@ -338,7 +334,6 @@ object DoubleStepper {

/** A Stepper for Longs. See [[Stepper]]. */
trait LongStepper extends Stepper[Long] {
this: LongStepper^ =>

def trySplit(): LongStepper^{this}

Expand Down
1 change: 0 additions & 1 deletion scala2-library-cc/src/scala/collection/View.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import language.experimental.captureChecking
* @define Coll `View`
*/
trait View[+A] extends Iterable[A] with IterableOps[A, View, View[A]] with IterableFactoryDefaults[A, View] with Serializable {
this: View[A]^ =>

override def view: View[A]^{this} = this

Expand Down
1 change: 0 additions & 1 deletion scala2-library-cc/src/scala/collection/WithFilter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import language.experimental.captureChecking
*/
@SerialVersionUID(3L)
abstract class WithFilter[+A, +CC[_]] extends Serializable {
this: WithFilter[A, CC]^ =>

/** Builds a new collection by applying a function to all elements of the
* `filtered` outer $coll.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ import language.experimental.captureChecking
trait Iterable[+A] extends collection.Iterable[A]
with collection.IterableOps[A, Iterable, Iterable[A]]
with IterableFactoryDefaults[A, Iterable] {
this: Iterable[A]^ =>

override def iterableFactory: IterableFactory[Iterable] = Iterable
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,6 @@ final class LazyListIterable[+A] private(private[this] var lazyState: () => Lazy
with IterableOps[A, LazyListIterable, LazyListIterable[A]]
with IterableFactoryDefaults[A, LazyListIterable]
with Serializable {
this: LazyListIterable[A]^ =>
import LazyListIterable._

@volatile private[this] var stateEvaluated: Boolean = false
Expand Down Expand Up @@ -964,7 +963,6 @@ object LazyListIterable extends IterableFactory[LazyListIterable] {
private[this] val _empty = newLL(State.Empty).force

private sealed trait State[+A] extends Serializable {
this: State[A]^ =>
def head: A
def tail: LazyListIterable[A]^
}
Expand Down Expand Up @@ -1252,7 +1250,6 @@ object LazyListIterable extends IterableFactory[LazyListIterable] {

private class SlidingIterator[A](private[this] var lazyList: LazyListIterable[A]^, size: Int, step: Int)
extends AbstractIterator[LazyListIterable[A]] {
this: SlidingIterator[A]^ =>
private val minLen = size - step max 0
private var first = true

Expand All @@ -1273,7 +1270,6 @@ object LazyListIterable extends IterableFactory[LazyListIterable] {

private final class WithFilter[A] private[LazyListIterable](lazyList: LazyListIterable[A]^, p: A => Boolean)
extends collection.WithFilter[A, LazyListIterable] {
this: WithFilter[A]^ =>
private[this] val filtered = lazyList.filter(p)
def map[B](f: A => B): LazyListIterable[B]^{this, f} = filtered.map(f)
def flatMap[B](f: A => IterableOnce[B]^): LazyListIterable[B]^{this, f} = filtered.flatMap(f)
Expand Down Expand Up @@ -1320,7 +1316,6 @@ object LazyListIterable extends IterableFactory[LazyListIterable] {

private object LazyBuilder {
final class DeferredState[A] {
this: DeferredState[A]^ =>
private[this] var _state: (() => State[A]^) @uncheckedCaptures = _

def eval(): State[A]^ = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ package mutable
import language.experimental.captureChecking

private[mutable] trait CheckedIndexedSeqView[+A] extends IndexedSeqView[A] {
this: CheckedIndexedSeqView[A]^ =>

protected val mutationCount: () => Int

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ trait Iterable[A]
extends collection.Iterable[A]
with collection.IterableOps[A, Iterable, Iterable[A]]
with IterableFactoryDefaults[A, Iterable] {
this: Iterable[A]^ =>

override def iterableFactory: IterableFactory[Iterable] = Iterable
}
Expand All @@ -33,5 +32,4 @@ trait Iterable[A]
object Iterable extends IterableFactory.Delegate[Iterable](ArrayBuffer)

/** Explicit instantiation of the `Iterable` trait to reduce class file size in subclasses. */
abstract class AbstractIterable[A] extends scala.collection.AbstractIterable[A] with Iterable[A]:
this: AbstractIterable[A]^ =>
abstract class AbstractIterable[A] extends scala.collection.AbstractIterable[A] with Iterable[A]
13 changes: 7 additions & 6 deletions tests/neg-custom-args/captures/cc-this4.check
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
-- Error: tests/neg-custom-args/captures/cc-this4.scala:1:11 -----------------------------------------------------------
1 |open class C: // error
| ^
| class C needs an explicitly declared self type since its
| inferred self type C^{}
| is not visible in other compilation units that define subclasses.
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/cc-this4.scala:2:13 --------------------------------------
2 | val x: C = this // error
| ^^^^
| Found: (C.this : C^)
| Required: C
|
| longer explanation available when compiling with `-explain`
4 changes: 2 additions & 2 deletions tests/neg-custom-args/captures/cc-this4.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
open class C: // error
val x: C = this
open class C:
val x: C = this // error

open class D:
this: D =>
Expand Down
Loading

0 comments on commit 9b5815a

Please sign in to comment.