Skip to content

Commit bf18947

Browse files
authored
Merge pull request swiftlang#67502 from slavapestov/rqm-shape-req-fixes
RequirementMachine: Fix some edge cases with shape requirements
2 parents 5abb580 + 7ce6f37 commit bf18947

File tree

10 files changed

+172
-24
lines changed

10 files changed

+172
-24
lines changed

include/swift/AST/Type.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ enum class ForeignRepresentableKind : uint8_t {
214214
/// therefore, the result type is in covariant position relative to the function
215215
/// type.
216216
struct TypePosition final {
217-
enum : uint8_t { Covariant, Contravariant, Invariant };
217+
enum : uint8_t { Covariant, Contravariant, Invariant, Shape };
218218

219219
private:
220220
decltype(Covariant) kind;
@@ -224,6 +224,7 @@ struct TypePosition final {
224224

225225
TypePosition flipped() const {
226226
switch (kind) {
227+
case Shape:
227228
case Invariant:
228229
return *this;
229230
case Covariant:

include/swift/AST/TypeMatcher.h

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,13 @@ namespace swift {
5151
/// or false to indicate that matching should exit early.
5252
template<typename ImplClass>
5353
class TypeMatcher {
54+
public:
55+
enum class Position : uint8_t {
56+
Type,
57+
Shape
58+
};
59+
60+
private:
5461
class MatchVisitor : public CanTypeVisitor<MatchVisitor, bool, Type, Type> {
5562
TypeMatcher &Matcher;
5663

@@ -192,11 +199,24 @@ class TypeMatcher {
192199

193200
bool visitPackExpansionType(CanPackExpansionType firstPE, Type secondType,
194201
Type sugaredFirstType) {
195-
if (auto secondInOut = secondType->getAs<PackExpansionType>()) {
196-
return this->visit(firstPE.getPatternType(),
197-
secondInOut->getPatternType(),
198-
sugaredFirstType->castTo<PackExpansionType>()
199-
->getPatternType());
202+
if (auto secondExpansion = secondType->getAs<PackExpansionType>()) {
203+
if (!this->visit(firstPE.getPatternType(),
204+
secondExpansion->getPatternType(),
205+
sugaredFirstType->castTo<PackExpansionType>()
206+
->getPatternType())) {
207+
return false;
208+
}
209+
210+
Matcher.asDerived().pushPosition(Position::Shape);
211+
if (!this->visit(firstPE.getCountType(),
212+
secondExpansion->getCountType(),
213+
sugaredFirstType->castTo<PackExpansionType>()
214+
->getCountType())) {
215+
return false;
216+
}
217+
Matcher.asDerived().popPosition(Position::Shape);
218+
219+
return true;
200220
}
201221

202222
return mismatch(firstPE.getPointer(), secondType, sugaredFirstType);
@@ -524,6 +544,9 @@ class TypeMatcher {
524544

525545
bool alwaysMismatchTypeParameters() const { return false; }
526546

547+
void pushPosition(Position pos) {}
548+
void popPosition(Position pos) {}
549+
527550
ImplClass &asDerived() { return static_cast<ImplClass &>(*this); }
528551

529552
const ImplClass &asDerived() const {

lib/AST/RequirementMachine/InterfaceType.cpp

Lines changed: 64 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,7 @@ Type PropertyMap::getTypeForTerm(const MutableTerm &term,
403403
/// Concrete type terms are written in terms of generic parameter types that
404404
/// have a depth of 0, and an index into an array of substitution terms.
405405
///
406-
/// See RewriteSystemBuilder::getConcreteSubstitutionSchema().
406+
/// See RewriteSystemBuilder::getSubstitutionSchemaFromType().
407407
unsigned RewriteContext::getGenericParamIndex(Type type) {
408408
auto *paramTy = type->castTo<GenericTypeParamType>();
409409
assert(paramTy->getDepth() == 0);
@@ -429,6 +429,7 @@ RewriteContext::getRelativeTermForType(CanType typeWitness,
429429
// Get the substitution S corresponding to τ_0_n.
430430
unsigned index = getGenericParamIndex(typeWitness->getRootGenericParam());
431431
result = MutableTerm(substitutions[index]);
432+
assert(result.back().getKind() != Symbol::Kind::Shape);
432433

433434
// If the substitution is a term consisting of a single protocol symbol
434435
// [P], save P for later.
@@ -471,7 +472,7 @@ RewriteContext::getRelativeTermForType(CanType typeWitness,
471472
}
472473

473474
/// Reverses the transformation performed by
474-
/// RewriteSystemBuilder::getConcreteSubstitutionSchema().
475+
/// RewriteSystemBuilder::getSubstitutionSchemaFromType().
475476
Type PropertyMap::getTypeFromSubstitutionSchema(
476477
Type schema, ArrayRef<Term> substitutions,
477478
ArrayRef<GenericTypeParamType *> genericParams,
@@ -481,11 +482,38 @@ Type PropertyMap::getTypeFromSubstitutionSchema(
481482
if (!schema->hasTypeParameter())
482483
return schema;
483484

484-
return schema.transformRec([&](Type t) -> llvm::Optional<Type> {
485+
return schema.transformWithPosition(
486+
TypePosition::Invariant,
487+
[&](Type t, TypePosition pos) -> llvm::Optional<Type> {
485488
if (t->is<GenericTypeParamType>()) {
486489
auto index = RewriteContext::getGenericParamIndex(t);
487490
auto substitution = substitutions[index];
488491

492+
bool isShapePosition = (pos == TypePosition::Shape);
493+
bool isShapeTerm = (substitution.back() == Symbol::forShape(Context));
494+
if (isShapePosition != isShapeTerm) {
495+
llvm::errs() << "Shape vs. type mixup\n\n";
496+
schema.dump(llvm::errs());
497+
llvm::errs() << "Substitutions:\n";
498+
for (auto otherSubst : substitutions) {
499+
llvm::errs() << "- ";
500+
otherSubst.dump(llvm::errs());
501+
llvm::errs() << "\n";
502+
}
503+
llvm::errs() << "\n";
504+
dump(llvm::errs());
505+
506+
abort();
507+
}
508+
509+
// Undo the thing where the count type of a PackExpansionType
510+
// becomes a shape term.
511+
if (isShapeTerm) {
512+
MutableTerm mutTerm(substitution.begin(),
513+
substitution.end() - 1);
514+
substitution = Term::get(mutTerm, Context);
515+
}
516+
489517
// Prepend the prefix of the lookup key to the substitution.
490518
if (prefix.empty()) {
491519
// Skip creation of a new MutableTerm in the case where the
@@ -535,17 +563,31 @@ RewriteContext::getRelativeSubstitutionSchemaFromType(
535563
if (!concreteType->hasTypeParameter())
536564
return concreteType;
537565

538-
return CanType(concreteType.transformRec([&](Type t) -> llvm::Optional<Type> {
566+
return CanType(concreteType.transformWithPosition(
567+
TypePosition::Invariant,
568+
[&](Type t, TypePosition pos) -> llvm::Optional<Type> {
569+
539570
if (!t->isTypeParameter())
540571
return llvm::None;
541572

542573
auto term = getRelativeTermForType(CanType(t), substitutions);
543574

544-
unsigned newIndex = result.size();
575+
// PackExpansionType(pattern=T, count=U) becomes
576+
// PackExpansionType(pattern=τ_0_0, count=τ_0_1) with
577+
//
578+
// τ_0_0 := T
579+
// τ_0_1 := U.[shape]
580+
if (pos == TypePosition::Shape) {
581+
assert(false);
582+
term.add(Symbol::forShape(*this));
583+
}
584+
585+
unsigned index = result.size();
586+
545587
result.push_back(Term::get(term, *this));
546588

547589
return CanGenericTypeParamType::get(/*isParameterPack=*/ false,
548-
/*depth=*/ 0, newIndex,
590+
/*depth=*/ 0, index,
549591
Context);
550592
}));
551593
}
@@ -566,12 +608,26 @@ RewriteContext::getSubstitutionSchemaFromType(CanType concreteType,
566608
if (!concreteType->hasTypeParameter())
567609
return concreteType;
568610

569-
return CanType(concreteType.transformRec([&](Type t) -> llvm::Optional<Type> {
611+
return CanType(concreteType.transformWithPosition(
612+
TypePosition::Invariant,
613+
[&](Type t, TypePosition pos)
614+
-> llvm::Optional<Type> {
615+
570616
if (!t->isTypeParameter())
571617
return llvm::None;
572618

619+
// PackExpansionType(pattern=T, count=U) becomes
620+
// PackExpansionType(pattern=τ_0_0, count=τ_0_1) with
621+
//
622+
// τ_0_0 := T
623+
// τ_0_1 := U.[shape]
624+
MutableTerm term = getMutableTermForType(CanType(t), proto);
625+
if (pos == TypePosition::Shape)
626+
term.add(Symbol::forShape(*this));
627+
573628
unsigned index = result.size();
574-
result.push_back(getTermForType(CanType(t), proto));
629+
630+
result.push_back(Term::get(term, *this));
575631

576632
return CanGenericTypeParamType::get(/*isParameterPack=*/ false,
577633
/*depth=*/0, index,

lib/AST/RequirementMachine/PropertyUnification.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424

2525
#include "swift/AST/Decl.h"
2626
#include "swift/AST/LayoutConstraint.h"
27-
#include "swift/AST/TypeMatcher.h"
2827
#include "swift/AST/Types.h"
2928
#include <algorithm>
3029
#include <vector>

lib/AST/RequirementMachine/RequirementLowering.cpp

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ static void desugarSameTypeRequirement(Requirement req, SourceLoc loc,
180180
SourceLoc loc;
181181
SmallVectorImpl<Requirement> &result;
182182
SmallVectorImpl<RequirementError> &errors;
183+
SmallVector<Position, 2> stack;
183184

184185
public:
185186
bool recordedErrors = false;
@@ -191,32 +192,53 @@ static void desugarSameTypeRequirement(Requirement req, SourceLoc loc,
191192

192193
bool alwaysMismatchTypeParameters() const { return true; }
193194

195+
void pushPosition(Position pos) {
196+
stack.push_back(pos);
197+
}
198+
199+
void popPosition(Position pos) {
200+
assert(stack.back() == pos);
201+
stack.pop_back();
202+
}
203+
204+
Position getPosition() const {
205+
if (stack.empty()) return Position::Type;
206+
return stack.back();
207+
}
208+
194209
bool mismatch(TypeBase *firstType, TypeBase *secondType,
195210
Type sugaredFirstType) {
211+
RequirementKind kind;
212+
switch (getPosition()) {
213+
case Position::Type:
214+
kind = RequirementKind::SameType;
215+
break;
216+
case Position::Shape:
217+
kind = RequirementKind::SameShape;
218+
break;
219+
}
220+
196221
// If one side is a parameter pack, this is a same-element requirement, which
197222
// is not yet supported.
198223
if (firstType->isParameterPack() != secondType->isParameterPack()) {
199224
errors.push_back(RequirementError::forSameElement(
200-
{RequirementKind::SameType, sugaredFirstType, secondType}, loc));
225+
{kind, sugaredFirstType, secondType}, loc));
201226
recordedErrors = true;
202227
return true;
203228
}
204229

205230
if (firstType->isTypeParameter() && secondType->isTypeParameter()) {
206-
result.emplace_back(RequirementKind::SameType,
207-
sugaredFirstType, secondType);
231+
result.emplace_back(kind, sugaredFirstType, secondType);
208232
return true;
209233
}
210234

211235
if (firstType->isTypeParameter()) {
212-
result.emplace_back(RequirementKind::SameType,
213-
sugaredFirstType, secondType);
236+
result.emplace_back(kind, sugaredFirstType, secondType);
214237
return true;
215238
}
216239

217240
if (secondType->isTypeParameter()) {
218-
result.emplace_back(RequirementKind::SameType,
219-
secondType, sugaredFirstType);
241+
result.emplace_back(kind, secondType, sugaredFirstType);
220242
return true;
221243
}
222244

lib/AST/RequirementMachine/Symbol.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ class Term;
5757
/// This transformation allows DependentMemberTypes to be manipulated as
5858
/// terms, with the actual concrete type structure remaining opaque to
5959
/// the requirement machine. This transformation is implemented in
60-
/// RewriteContext::getConcreteSubstitutionSchema().
60+
/// RewriteContext::getSubstitutionSchemaFromType().
6161
///
6262
/// For example, the superclass requirement
6363
/// "T : MyClass<U.X, (Int) -> V.A.B>" is denoted with a symbol

lib/AST/RequirementMachine/TypeDifference.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,9 @@ namespace {
159159
bool rhsAbstract = rhsType->isTypeParameter();
160160

161161
if (lhsAbstract && rhsAbstract) {
162+
// FIXME: same-element requirements
163+
assert(lhsType->isParameterPack() == rhsType->isParameterPack());
164+
162165
unsigned lhsIndex = RewriteContext::getGenericParamIndex(lhsType);
163166
unsigned rhsIndex = RewriteContext::getGenericParamIndex(rhsType);
164167

lib/AST/Type.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4677,7 +4677,7 @@ case TypeKind::Id:
46774677
return Type();
46784678

46794679
Type transformedCount =
4680-
expand->getCountType().transformWithPosition(pos, fn);
4680+
expand->getCountType().transformWithPosition(TypePosition::Shape, fn);
46814681
if (!transformedCount)
46824682
return Type();
46834683

lib/Sema/ConstraintSystem.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2199,6 +2199,7 @@ static Type typeEraseExistentialSelfReferences(Type refTy, Type baseTy,
21992199

22002200
case TypePosition::Contravariant:
22012201
case TypePosition::Invariant:
2202+
case TypePosition::Shape:
22022203
return Type(t);
22032204
}
22042205

test/Generics/pack-shape-requirements.swift

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,46 @@ protocol Q: P where A: Q {}
8383
// CHECK-LABEL: sameType2
8484
// CHECK-NEXT: Generic signature: <each T, each U where repeat each T : Q, repeat each U : Q, repeat (each T).[P]A.[P]A == (each U).[P]A.[P]A>
8585
func sameType2<each T, each U>(_: repeat (each T, each U)) where repeat each T: Q, repeat each U: Q, repeat (each T).A.A == (each U).A.A {}
86+
87+
88+
//////
89+
///
90+
/// A same-type requirement between two pack expansion types
91+
/// should desugar to a same-shape requirement between their
92+
/// count types and a same-type requirement between their
93+
/// element types.
94+
///
95+
//////
96+
97+
typealias First<T, U> = T
98+
typealias Shape<each T> = (repeat First<(), each T>)
99+
100+
// CHECK-LABEL: sameTypeDesugar1
101+
// CHECK-NEXT: Generic signature: <each T, each U where (repeat (each T, each U)) : Any>
102+
func sameTypeDesugar1<each T, each U>(t: repeat each T, u: repeat each U)
103+
where Shape<repeat each T> == Shape<repeat each U> {}
104+
105+
// CHECK-LABEL: sameTypeDesugar2
106+
// CHECK-NEXT: Generic signature: <each T, each U where repeat each T : P, (repeat (each T, each U)) : Any, repeat each U : P>
107+
func sameTypeDesugar2<each T: P, each U: P>(t: repeat each T, u: repeat each U)
108+
where Shape<repeat (each T).A> == Shape<repeat (each U).A> {}
109+
110+
/// More complex example involving concrete type matching in
111+
/// property map construction
112+
113+
protocol PP {
114+
associatedtype A
115+
}
116+
117+
struct G<each T> {}
118+
119+
// CHECK-LABEL: sameTypeMatch1
120+
// CHECK-NEXT: <T, each U, each V where T : PP, repeat each U : PP, repeat each V : PP, T.[PP]A == G<repeat (each U).[PP]A>, repeat (each U).[PP]A == (each V).[PP]A>
121+
func sameTypeMatch1<T: PP, each U: PP, each V: PP>(t: T, u: repeat each U, v: repeat each V)
122+
where T.A == G<repeat (each U).A>, T.A == G<repeat (each V).A>,
123+
(repeat (each U, each V)) : Any {}
124+
125+
// CHECK-LABEL: sameTypeMatch2
126+
// CHECK-NEXT: <T, each U, each V where T : PP, repeat each U : PP, (repeat (each U, each V)) : Any, repeat each V : PP, T.[PP]A == (/* shape: each U */ repeat ())>
127+
func sameTypeMatch2<T: PP, each U: PP, each V: PP>(t: T, u: repeat each U, v: repeat each V)
128+
where T.A == Shape<repeat each U>, T.A == Shape<repeat each V> {}

0 commit comments

Comments
 (0)