Skip to content

Commit d44ce17

Browse files
committed
Resolve merge conflict
2 parents 074f7e7 + e8e84ba commit d44ce17

File tree

5 files changed

+93
-17
lines changed

5 files changed

+93
-17
lines changed

src/main/scala/scala/async/Async.scala

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package scala.async
66

77
import scala.language.experimental.macros
88
import scala.reflect.macros.Context
9+
import scala.util.continuations.{cpsParam, reset}
910

1011
object Async extends AsyncBase {
1112

@@ -60,16 +61,18 @@ abstract class AsyncBase {
6061
@deprecated("`await` must be enclosed in an `async` block", "0.1")
6162
def await[T](awaitable: futureSystem.Fut[T]): T = ???
6263

64+
def awaitFallback[T, U](awaitable: futureSystem.Fut[T], p: futureSystem.Prom[U]): T @cpsParam[U, Unit] = ???
65+
66+
def fallbackEnabled = false
67+
6368
def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[futureSystem.Fut[T]] = {
6469
import c.universe._
6570

66-
val analyzer = AsyncAnalysis[c.type](c)
71+
val analyzer = AsyncAnalysis[c.type](c, this)
6772
val utils = TransformUtils[c.type](c)
6873
import utils.{name, defn}
69-
import builder.futureSystemOps
70-
71-
analyzer.reportUnsupportedAwaits(body.tree)
7274

75+
if (!analyzer.reportUnsupportedAwaits(body.tree) || !fallbackEnabled) {
7376
// Transform to A-normal form:
7477
// - no await calls in qualifiers or arguments,
7578
// - if/match only used in statement position.
@@ -92,6 +95,7 @@ abstract class AsyncBase {
9295
}
9396

9497
val builder = ExprBuilder[c.type, futureSystem.type](c, self.futureSystem, anfTree)
98+
import builder.futureSystemOps
9599
val asyncBlock: builder.AsyncBlock = builder.build(anfTree, renameMap)
96100
import asyncBlock.asyncStates
97101
logDiagnostics(c)(anfTree, asyncStates.map(_.toString))
@@ -140,19 +144,16 @@ abstract class AsyncBase {
140144

141145
def selectStateMachine(selection: TermName) = Select(Ident(name.stateMachine), selection)
142146

143-
def spawn(tree: Tree): Tree =
144-
futureSystemOps.future(c.Expr[Unit](tree))(futureSystemOps.execContext).tree
145-
146147
val code: c.Expr[futureSystem.Fut[T]] = {
147148
val isSimple = asyncStates.size == 1
148149
val tree =
149150
if (isSimple)
150-
Block(Nil, spawn(body.tree)) // generate lean code for the simple case of `async { 1 + 1 }`
151+
Block(Nil, futureSystemOps.spawn(body.tree)) // generate lean code for the simple case of `async { 1 + 1 }`
151152
else {
152153
Block(List[Tree](
153154
stateMachine,
154155
ValDef(NoMods, name.stateMachine, stateMachineType, New(Ident(name.stateMachineT), Nil)),
155-
spawn(Apply(selectStateMachine(name.apply), Nil))
156+
futureSystemOps.spawn(Apply(selectStateMachine(name.apply), Nil))
156157
),
157158
futureSystemOps.promiseToFuture(c.Expr[futureSystem.Prom[T]](selectStateMachine(name.result))).tree)
158159
}
@@ -161,6 +162,35 @@ abstract class AsyncBase {
161162

162163
AsyncUtils.vprintln(s"async state machine transform expands to:\n ${code.tree}")
163164
code
165+
} else {
166+
// replace `await` invocations with `awaitFallback` invocations
167+
val awaitReplacer = new Transformer {
168+
override def transform(tree: Tree): Tree = tree match {
169+
case Apply(fun @ TypeApply(_, List(futArgTpt)), args) if fun.symbol == defn.Async_await =>
170+
val typeApp = treeCopy.TypeApply(fun, Ident(defn.Async_awaitFallback), List(TypeTree(futArgTpt.tpe), TypeTree(body.tree.tpe)))
171+
treeCopy.Apply(tree, typeApp, args.map(arg => c.resetAllAttrs(arg.duplicate)) :+ Ident(name.result))
172+
case _ =>
173+
super.transform(tree)
174+
}
175+
}
176+
val newBody = awaitReplacer.transform(body.tree)
177+
178+
val resetBody = reify {
179+
reset { c.Expr(c.resetAllAttrs(newBody.duplicate)).splice }
180+
}
181+
182+
val futureSystemOps = futureSystem.mkOps(c)
183+
val code = {
184+
val tree = Block(List(
185+
ValDef(NoMods, name.result, TypeTree(futureSystemOps.promType[T]), futureSystemOps.createProm[T].tree),
186+
futureSystemOps.spawn(resetBody.tree)
187+
), futureSystemOps.promiseToFuture(c.Expr[futureSystem.Prom[T]](Ident(name.result))).tree)
188+
c.Expr[futureSystem.Fut[T]](tree)
189+
}
190+
191+
AsyncUtils.vprintln(s"async CPS fallback transform expands to:\n ${code.tree}")
192+
code
193+
}
164194
}
165195

166196
def logDiagnostics(c: Context)(anfTree: c.Tree, states: Seq[String]) {

src/main/scala/scala/async/AsyncAnalysis.scala

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@ package scala.async
77
import scala.reflect.macros.Context
88
import scala.collection.mutable
99

10-
private[async] final case class AsyncAnalysis[C <: Context](c: C) {
11-
10+
private[async] final case class AsyncAnalysis[C <: Context](c: C, asyncBase: AsyncBase) {
1211
import c.universe._
1312

1413
val utils = TransformUtils[c.type](c)
@@ -21,8 +20,10 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C) {
2120
*
2221
* Must be called on the original tree, not on the ANF transformed tree.
2322
*/
24-
def reportUnsupportedAwaits(tree: Tree) {
25-
new UnsupportedAwaitAnalyzer().traverse(tree)
23+
def reportUnsupportedAwaits(tree: Tree): Boolean = {
24+
val analyzer = new UnsupportedAwaitAnalyzer
25+
analyzer.traverse(tree)
26+
analyzer.hasUnsupportedAwaits
2627
}
2728

2829
/**
@@ -40,6 +41,8 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C) {
4041
}
4142

4243
private class UnsupportedAwaitAnalyzer extends AsyncTraverser {
44+
var hasUnsupportedAwaits = false
45+
4346
override def nestedClass(classDef: ClassDef) {
4447
val kind = if (classDef.symbol.asClass.isTrait) "trait" else "class"
4548
if (!reportUnsupportedAwait(classDef, s"nested $kind")) {
@@ -96,10 +99,16 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C) {
9699
}
97100
badAwaits foreach {
98101
tree =>
99-
c.error(tree.pos, s"await must not be used under a $whyUnsupported.")
102+
reportError(tree.pos, s"await must not be used under a $whyUnsupported.")
100103
}
101104
badAwaits.nonEmpty
102105
}
106+
107+
private def reportError(pos: Position, msg: String) {
108+
hasUnsupportedAwaits = true
109+
if (!asyncBase.fallbackEnabled)
110+
c.error(pos, msg)
111+
}
103112
}
104113

105114
private class AsyncDefinitionUseAnalyzer extends AsyncTraverser {
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package scala.async
2+
3+
import scala.language.experimental.macros
4+
5+
import scala.reflect.macros.Context
6+
import scala.util.continuations._
7+
8+
object AsyncWithCPSFallback extends AsyncBase {
9+
10+
import scala.concurrent.{Future, ExecutionContext}
11+
import ExecutionContext.Implicits.global
12+
13+
lazy val futureSystem = ScalaConcurrentFutureSystem
14+
type FS = ScalaConcurrentFutureSystem.type
15+
16+
/* Fall-back for `await` when it is called at an unsupported position.
17+
*/
18+
override def awaitFallback[T, U](awaitable: futureSystem.Fut[T], p: futureSystem.Prom[U]): T @cpsParam[U, Unit] =
19+
shift {
20+
(k: (T => U)) =>
21+
awaitable onComplete {
22+
case tr => p.success(k(tr.get))
23+
}
24+
}
25+
26+
override def fallbackEnabled = true
27+
28+
def async[T](body: T) = macro asyncImpl[T]
29+
30+
override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[Future[T]] = super.asyncImpl[T](c)(body)
31+
}

src/main/scala/scala/async/FutureSystem.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,12 @@ trait FutureSystem {
5151

5252
/** Complete a promise with a value */
5353
def completeProm[A](prom: Expr[Prom[A]], value: Expr[scala.util.Try[A]]): Expr[Unit]
54+
55+
def spawn(tree: context.Tree): context.Tree =
56+
future(context.Expr[Unit](tree))(execContext).tree
5457
}
5558

56-
def mkOps(c: Context): Ops {val context: c.type}
59+
def mkOps(c: Context): Ops { val context: c.type }
5760
}
5861

5962
object ScalaConcurrentFutureSystem extends FutureSystem {

src/main/scala/scala/async/TransformUtils.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,11 +168,14 @@ private[async] final case class TransformUtils[C <: Context](c: C) {
168168
val TryAnyType = appliedType(TryClass.toType, List(definitions.AnyTpe))
169169
val NonFatalClass = c.mirror.staticModule("scala.util.control.NonFatal")
170170

171-
val Async_await = {
171+
private def asyncMember(name: String) = {
172172
val asyncMod = c.mirror.staticClass("scala.async.AsyncBase")
173173
val tpe = asyncMod.asType.toType
174-
tpe.member(c.universe.newTermName("await")).ensuring(_ != NoSymbol)
174+
tpe.member(newTermName(name)).ensuring(_ != NoSymbol)
175175
}
176+
177+
val Async_await = asyncMember("await")
178+
val Async_awaitFallback = asyncMember("awaitFallback")
176179
}
177180

178181
/** `termSym( (_: Foo).bar(null: A, null: B)` will return the symbol of `bar`, after overload resolution. */

0 commit comments

Comments
 (0)