Skip to content

Commit c47b6b8

Browse files
committed
Support Mirrors for local and inner classes.
- update child accessibility check for anonymous mirrors in whyNotGenericSum. Now check the prefix at the callsite can access the child. - for sum mirrors, compute a new prefix for each child from the callsite prefix of the parent, see TypeOps.childPrefix. For each child, subsititute its prefix at definition with the childPrefix using asSeenFrom. For polymorphic classes, perform the subsitution on the constructor before inferring constraints. - add tests for issues 13332, 13935, 11174, 12328 - add tests for local/inner classes taken from Shapeless for its Generic type, backed by mirrors
1 parent e560c2d commit c47b6b8

30 files changed

+1757
-84
lines changed

compiler/src/dotty/tools/dotc/core/TypeOps.scala

+69
Original file line numberDiff line numberDiff line change
@@ -868,4 +868,73 @@ object TypeOps:
868868
def stripTypeVars(tp: Type)(using Context): Type =
869869
new StripTypeVarsMap().apply(tp)
870870

871+
/** computes a prefix for `child`, derived from its common prefix with `pre`
872+
* - `pre` is assumed to be the prefix of `parent` at a given callsite.
873+
* - `child` is assumed to be the sealed child of `parent`, and reachable according to `whyNotGenericSum`.
874+
*/
875+
def childPrefix(pre: Type, parent: Symbol, child: Symbol)(using Context): Type =
876+
// Example, given this class hierarchy, we can see how this should work
877+
// when summoning a mirror for `wrapper.Color`:
878+
//
879+
// package example
880+
// object Outer3:
881+
// class Wrapper:
882+
// sealed trait Color
883+
// val wrapper = new Wrapper
884+
// object Inner:
885+
// case object Red extends wrapper.Color
886+
// case object Green extends wrapper.Color
887+
// case object Blue extends wrapper.Color
888+
//
889+
// summon[Mirror.SumOf[wrapper.Color]]
890+
// ^^^^^^^^^^^^^
891+
// > pre = example.Outer3.wrapper.type
892+
// > parent = sealed trait example.Outer3.Wrapper.Color
893+
// > child = module val example.Outer3.Innner.Red
894+
// > parentOwners = [example, Outer3, Wrapper] // computed from definition
895+
// > childOwners = [example, Outer3, Inner] // computed from definition
896+
// > parentRest = [Wrapper] // strip common owners from `childOwners`
897+
// > childRest = [Inner] // strip common owners from `parentOwners`
898+
// > commonPrefix = example.Outer3.type // i.e. parentRest has only 1 element, use 1st subprefix of `pre`.
899+
// > childPrefix = example.Outer3.Inner.type // select all symbols in `childRest` from `commonPrefix`
900+
901+
/** unwind the prefix into a sequence of sub-prefixes, selecting the one at `limit`
902+
* @return `NoType` if there is an unrecognised prefix type.
903+
*/
904+
def subPrefixAt(pre: Type, limit: Int): Type =
905+
def go(pre: Type, limit: Int): Type =
906+
if limit == 0 then pre // EXIT: No More prefix
907+
else pre match
908+
case pre: ThisType => go(pre.tref.prefix, limit - 1)
909+
case pre: TermRef => go(pre.prefix, limit - 1)
910+
case _:SuperType | NoPrefix => pre.ensuring(limit == 1) // EXIT: can't rewind further than this
911+
case _ => NoType // EXIT: unrecognized prefix
912+
go(pre, limit)
913+
end subPrefixAt
914+
915+
/** Successively select each symbol in the `suffix` from `pre`, such that they are reachable. */
916+
def selectAll(pre: Type, suffix: Seq[Symbol]): Type =
917+
suffix.foldLeft(pre)((pre, sym) =>
918+
pre.select(
919+
if sym.isType && sym.is(Module) then sym.sourceModule
920+
else sym
921+
)
922+
)
923+
924+
def stripCommonPrefix(xs: List[Symbol], ys: List[Symbol]): (List[Symbol], List[Symbol]) = (xs, ys) match
925+
case (x :: xs1, y :: ys1) if x eq y => stripCommonPrefix(xs1, ys1)
926+
case _ => (xs, ys)
927+
928+
val (parentRest, childRest) = stripCommonPrefix(
929+
parent.owner.ownersIterator.toList.reverse,
930+
child.owner.ownersIterator.toList.reverse
931+
)
932+
933+
val commonPrefix = subPrefixAt(pre, parentRest.size) // unwind parent owners up to common prefix
934+
935+
if commonPrefix.exists then selectAll(commonPrefix, childRest)
936+
else NoType
937+
938+
end childPrefix
939+
871940
end TypeOps

compiler/src/dotty/tools/dotc/transform/PostInlining.scala

+1-3
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,7 @@ class PostInlining extends MacroTransform, IdentityDenotTransformer:
2626
override def transform(tree: Tree)(using Context): Tree =
2727
super.transform(tree) match
2828
case tree1: Template
29-
if tree1.hasAttachment(ExtendsSingletonMirror)
30-
|| tree1.hasAttachment(ExtendsProductMirror)
31-
|| tree1.hasAttachment(ExtendsSumMirror) =>
29+
if tree1.hasAttachment(ExtendsSingletonMirror) || tree1.hasAttachment(ExtendsSumOrProductMirror) =>
3230
synthMbr.addMirrorSupport(tree1)
3331
case tree1 => tree1
3432

compiler/src/dotty/tools/dotc/transform/SymUtils.scala

+21-7
Original file line numberDiff line numberDiff line change
@@ -163,28 +163,42 @@ object SymUtils:
163163
* and also the location of the generated mirror.
164164
* - all of its children are generic products, singletons, or generic sums themselves.
165165
*/
166-
def whyNotGenericSum(using Context): String =
166+
def whyNotGenericSum(pre: Type)(using Context): String =
167167
if (!self.is(Sealed))
168168
s"it is not a sealed ${self.kindString}"
169169
else if (!self.isOneOf(AbstractOrTrait))
170170
"it is not an abstract class"
171171
else {
172172
val children = self.children
173173
val companionMirror = self.useCompanionAsSumMirror
174+
val ownerScope = if pre.isInstanceOf[SingletonType] then pre.classSymbol else NoSymbol
174175
def problem(child: Symbol) = {
175176

176-
def isAccessible(sym: Symbol): Boolean =
177-
(self.isContainedIn(sym) && (companionMirror || ctx.owner.isContainedIn(sym)))
178-
|| sym.is(Module) && isAccessible(sym.owner)
177+
def accessibleMessage(sym: Symbol): String =
178+
def inherits(sym: Symbol, scope: Symbol): Boolean =
179+
!scope.is(Package) && (scope.derivesFrom(sym) || inherits(sym, scope.owner))
180+
def isVisibleToParent(sym: Symbol): Boolean =
181+
self.isContainedIn(sym) || sym.is(Module) && isVisibleToParent(sym.owner)
182+
def isVisibleToScope(sym: Symbol): Boolean =
183+
def isReachable: Boolean = ctx.owner.isContainedIn(sym)
184+
def isMemberOfPrefix: Boolean =
185+
ownerScope.exists && inherits(sym, ownerScope)
186+
isReachable || isMemberOfPrefix || sym.is(Module) && isVisibleToScope(sym.owner)
187+
if !isVisibleToParent(sym) then i"to its parent $self"
188+
else if !companionMirror && !isVisibleToScope(sym) then i"to call site ${ctx.owner}"
189+
else ""
190+
end accessibleMessage
191+
192+
val childAccessible = accessibleMessage(child.owner)
179193

180194
if (child == self) "it has anonymous or inaccessible subclasses"
181-
else if (!isAccessible(child.owner)) i"its child $child is not accessible"
195+
else if (!childAccessible.isEmpty) i"its child $child is not accessible $childAccessible"
182196
else if (!child.isClass) "" // its a singleton enum value
183197
else {
184198
val s = child.whyNotGenericProduct
185199
if s.isEmpty then s
186200
else if child.is(Sealed) then
187-
val s = child.whyNotGenericSum
201+
val s = child.whyNotGenericSum(pre)
188202
if s.isEmpty then s
189203
else i"its child $child is not a generic sum because $s"
190204
else
@@ -195,7 +209,7 @@ object SymUtils:
195209
else children.map(problem).find(!_.isEmpty).getOrElse("")
196210
}
197211

198-
def isGenericSum(using Context): Boolean = whyNotGenericSum.isEmpty
212+
def isGenericSum(pre: Type)(using Context): Boolean = whyNotGenericSum(pre).isEmpty
199213

200214
/** If this is a constructor, its owner: otherwise this. */
201215
final def skipConstructor(using Context): Symbol =

compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala

+67-42
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,15 @@ import NullOpsDecorator._
1818

1919
object SyntheticMembers {
2020

21+
enum MirrorImpl:
22+
case OfProduct(pre: Type)
23+
case OfSum(childPres: List[Type])
24+
2125
/** Attachment marking an anonymous class as a singleton case that will extend from Mirror.Singleton */
2226
val ExtendsSingletonMirror: Property.StickyKey[Unit] = new Property.StickyKey
2327

2428
/** Attachment recording that an anonymous class should extend Mirror.Product */
25-
val ExtendsProductMirror: Property.StickyKey[Unit] = new Property.StickyKey
26-
27-
/** Attachment recording that an anonymous class should extend Mirror.Sum */
28-
val ExtendsSumMirror: Property.StickyKey[Unit] = new Property.StickyKey
29+
val ExtendsSumOrProductMirror: Property.StickyKey[MirrorImpl] = new Property.StickyKey
2930
}
3031

3132
/** Synthetic method implementations for case classes, case objects,
@@ -484,32 +485,41 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
484485
* type MirroredMonoType = C[?]
485486
* ```
486487
*/
487-
def fromProductBody(caseClass: Symbol, param: Tree)(using Context): Tree = {
488-
val (classRef, methTpe) =
489-
caseClass.primaryConstructor.info match {
488+
def fromProductBody(caseClass: Symbol, param: Tree, optInfo: Option[MirrorImpl.OfProduct])(using Context): Tree =
489+
def extractParams(tpe: Type): List[Type] =
490+
tpe.asInstanceOf[MethodType].paramInfos
491+
492+
def computeFromCaseClass: (Type, List[Type]) =
493+
val (baseRef, baseInfo) =
494+
val rawRef = caseClass.typeRef
495+
val rawInfo = caseClass.primaryConstructor.info
496+
optInfo match
497+
case Some(info) =>
498+
(rawRef.asSeenFrom(info.pre, caseClass.owner), rawInfo.asSeenFrom(info.pre, caseClass.owner))
499+
case _ =>
500+
(rawRef, rawInfo)
501+
baseInfo match
490502
case tl: PolyType =>
491503
val (tl1, tpts) = constrained(tl, untpd.EmptyTree, alwaysAddTypeVars = true)
492504
val targs =
493505
for (tpt <- tpts) yield
494506
tpt.tpe match {
495507
case tvar: TypeVar => tvar.instantiate(fromBelow = false)
496508
}
497-
(caseClass.typeRef.appliedTo(targs), tl.instantiate(targs))
509+
(baseRef.appliedTo(targs), extractParams(tl.instantiate(targs)))
498510
case methTpe =>
499-
(caseClass.typeRef, methTpe)
500-
}
501-
methTpe match {
502-
case methTpe: MethodType =>
503-
val elems =
504-
for ((formal, idx) <- methTpe.paramInfos.zipWithIndex) yield {
505-
val elem =
506-
param.select(defn.Product_productElement).appliedTo(Literal(Constant(idx)))
507-
.ensureConforms(formal.translateFromRepeated(toArray = false))
508-
if (formal.isRepeatedParam) ctx.typer.seqToRepeated(elem) else elem
509-
}
510-
New(classRef, elems)
511-
}
512-
}
511+
(baseRef, extractParams(methTpe))
512+
end computeFromCaseClass
513+
514+
val (classRefApplied, paramInfos) = computeFromCaseClass
515+
val elems =
516+
for ((formal, idx) <- paramInfos.zipWithIndex) yield
517+
val elem =
518+
param.select(defn.Product_productElement).appliedTo(Literal(Constant(idx)))
519+
.ensureConforms(formal.translateFromRepeated(toArray = false))
520+
if (formal.isRepeatedParam) ctx.typer.seqToRepeated(elem) else elem
521+
New(classRefApplied, elems)
522+
end fromProductBody
513523

514524
/** For an enum T:
515525
*
@@ -527,24 +537,36 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
527537
* a wildcard for each type parameter. The normalized type of an object
528538
* O is O.type.
529539
*/
530-
def ordinalBody(cls: Symbol, param: Tree)(using Context): Tree =
531-
if (cls.is(Enum)) param.select(nme.ordinal).ensureApplied
532-
else {
540+
def ordinalBody(cls: Symbol, param: Tree, optInfo: Option[MirrorImpl.OfSum])(using Context): Tree =
541+
if cls.is(Enum) then
542+
param.select(nme.ordinal).ensureApplied
543+
else
544+
def computeChildTypes: List[Type] =
545+
def rawRef(child: Symbol): Type =
546+
if (child.isTerm) child.reachableTermRef else child.reachableRawTypeRef
547+
optInfo match
548+
case Some(info) => info
549+
.childPres
550+
.lazyZip(cls.children)
551+
.map((pre, child) => rawRef(child).asSeenFrom(pre, child.owner))
552+
case _ =>
553+
cls.children.map(rawRef)
554+
end computeChildTypes
555+
val childTypes = computeChildTypes
533556
val cases =
534-
for ((child, idx) <- cls.children.zipWithIndex) yield {
535-
val patType = if (child.isTerm) child.reachableTermRef else child.reachableRawTypeRef
557+
for (patType, idx) <- childTypes.zipWithIndex yield
536558
val pat = Typed(untpd.Ident(nme.WILDCARD).withType(patType), TypeTree(patType))
537559
CaseDef(pat, EmptyTree, Literal(Constant(idx)))
538-
}
560+
539561
Match(param.annotated(New(defn.UncheckedAnnot.typeRef, Nil)), cases)
540-
}
562+
end ordinalBody
541563

542564
/** - If `impl` is the companion of a generic sum, add `deriving.Mirror.Sum` parent
543565
* and `MirroredMonoType` and `ordinal` members.
544566
* - If `impl` is the companion of a generic product, add `deriving.Mirror.Product` parent
545567
* and `MirroredMonoType` and `fromProduct` members.
546-
* - If `impl` is marked with one of the attachments ExtendsSingletonMirror, ExtendsProductMirror,
547-
* or ExtendsSumMirror, remove the attachment and generate the corresponding mirror support,
568+
* - If `impl` is marked with one of the attachments ExtendsSingletonMirror or ExtendsSumOfProductMirror,
569+
* remove the attachment and generate the corresponding mirror support,
548570
* On this case the represented class or object is referred to in a pre-existing `MirroredMonoType`
549571
* member of the template.
550572
*/
@@ -581,30 +603,33 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
581603
}
582604
def makeSingletonMirror() =
583605
addParent(defn.Mirror_SingletonClass.typeRef)
584-
def makeProductMirror(cls: Symbol) = {
606+
def makeProductMirror(cls: Symbol, optInfo: Option[MirrorImpl.OfProduct]) = {
585607
addParent(defn.Mirror_ProductClass.typeRef)
586608
addMethod(nme.fromProduct, MethodType(defn.ProductClass.typeRef :: Nil, monoType.typeRef), cls,
587-
fromProductBody(_, _).ensureConforms(monoType.typeRef)) // t4758.scala or i3381.scala are examples where a cast is needed
609+
fromProductBody(_, _, optInfo).ensureConforms(monoType.typeRef)) // t4758.scala or i3381.scala are examples where a cast is needed
588610
}
589-
def makeSumMirror(cls: Symbol) = {
611+
def makeSumMirror(cls: Symbol, optInfo: Option[MirrorImpl.OfSum]) = {
590612
addParent(defn.Mirror_SumClass.typeRef)
591613
addMethod(nme.ordinal, MethodType(monoType.typeRef :: Nil, defn.IntType), cls,
592-
ordinalBody(_, _))
614+
ordinalBody(_, _, optInfo))
593615
}
594616

595617
if (clazz.is(Module)) {
596618
if (clazz.is(Case)) makeSingletonMirror()
597-
else if (linked.isGenericProduct) makeProductMirror(linked)
598-
else if (linked.isGenericSum) makeSumMirror(linked)
619+
else if (linked.isGenericProduct) makeProductMirror(linked, None)
620+
else if (linked.isGenericSum(NoType)) makeSumMirror(linked, None)
599621
else if (linked.is(Sealed))
600-
derive.println(i"$linked is not a sum because ${linked.whyNotGenericSum}")
622+
derive.println(i"$linked is not a sum because ${linked.whyNotGenericSum(NoType)}")
601623
}
602624
else if (impl.removeAttachment(ExtendsSingletonMirror).isDefined)
603625
makeSingletonMirror()
604-
else if (impl.removeAttachment(ExtendsProductMirror).isDefined)
605-
makeProductMirror(monoType.typeRef.dealias.classSymbol)
606-
else if (impl.removeAttachment(ExtendsSumMirror).isDefined)
607-
makeSumMirror(monoType.typeRef.dealias.classSymbol)
626+
else
627+
impl.removeAttachment(ExtendsSumOrProductMirror).match
628+
case Some(prodImpl: MirrorImpl.OfProduct) =>
629+
makeProductMirror(monoType.typeRef.dealias.classSymbol, Some(prodImpl))
630+
case Some(sumImpl: MirrorImpl.OfSum) =>
631+
makeSumMirror(monoType.typeRef.dealias.classSymbol, Some(sumImpl))
632+
case _ =>
608633

609634
cpy.Template(impl)(parents = newParents, body = newBody)
610635
}

0 commit comments

Comments
 (0)