@@ -6,6 +6,7 @@ package scala.async
6
6
7
7
import scala .language .experimental .macros
8
8
import scala .reflect .macros .Context
9
+ import scala .util .continuations .{cpsParam , reset }
9
10
10
11
object Async extends AsyncBase {
11
12
@@ -60,16 +61,18 @@ abstract class AsyncBase {
60
61
@ deprecated(" `await` must be enclosed in an `async` block" , " 0.1" )
61
62
def await [T ](awaitable : futureSystem.Fut [T ]): T = ???
62
63
64
+ def awaitFallback [T , U ](awaitable : futureSystem.Fut [T ], p : futureSystem.Prom [U ]): T @ cpsParam[U , Unit ] = ???
65
+
66
+ def fallbackEnabled = false
67
+
63
68
def asyncImpl [T : c.WeakTypeTag ](c : Context )(body : c.Expr [T ]): c.Expr [futureSystem.Fut [T ]] = {
64
69
import c .universe ._
65
70
66
- val analyzer = AsyncAnalysis [c.type ](c)
71
+ val analyzer = AsyncAnalysis [c.type ](c, this )
67
72
val utils = TransformUtils [c.type ](c)
68
73
import utils .{name , defn }
69
- import builder .futureSystemOps
70
-
71
- analyzer.reportUnsupportedAwaits(body.tree)
72
74
75
+ if (! analyzer.reportUnsupportedAwaits(body.tree) || ! fallbackEnabled) {
73
76
// Transform to A-normal form:
74
77
// - no await calls in qualifiers or arguments,
75
78
// - if/match only used in statement position.
@@ -92,6 +95,7 @@ abstract class AsyncBase {
92
95
}
93
96
94
97
val builder = ExprBuilder [c.type , futureSystem.type ](c, self.futureSystem, anfTree)
98
+ import builder .futureSystemOps
95
99
val asyncBlock : builder.AsyncBlock = builder.build(anfTree, renameMap)
96
100
import asyncBlock .asyncStates
97
101
logDiagnostics(c)(anfTree, asyncStates.map(_.toString))
@@ -140,19 +144,16 @@ abstract class AsyncBase {
140
144
141
145
def selectStateMachine (selection : TermName ) = Select (Ident (name.stateMachine), selection)
142
146
143
- def spawn (tree : Tree ): Tree =
144
- futureSystemOps.future(c.Expr [Unit ](tree))(futureSystemOps.execContext).tree
145
-
146
147
val code : c.Expr [futureSystem.Fut [T ]] = {
147
148
val isSimple = asyncStates.size == 1
148
149
val tree =
149
150
if (isSimple)
150
- Block (Nil , spawn(body.tree)) // generate lean code for the simple case of `async { 1 + 1 }`
151
+ Block (Nil , futureSystemOps. spawn(body.tree)) // generate lean code for the simple case of `async { 1 + 1 }`
151
152
else {
152
153
Block (List [Tree ](
153
154
stateMachine,
154
155
ValDef (NoMods , name.stateMachine, stateMachineType, New (Ident (name.stateMachineT), Nil )),
155
- spawn(Apply (selectStateMachine(name.apply), Nil ))
156
+ futureSystemOps. spawn(Apply (selectStateMachine(name.apply), Nil ))
156
157
),
157
158
futureSystemOps.promiseToFuture(c.Expr [futureSystem.Prom [T ]](selectStateMachine(name.result))).tree)
158
159
}
@@ -161,6 +162,35 @@ abstract class AsyncBase {
161
162
162
163
AsyncUtils .vprintln(s " async state machine transform expands to: \n ${code.tree}" )
163
164
code
165
+ } else {
166
+ // replace `await` invocations with `awaitFallback` invocations
167
+ val awaitReplacer = new Transformer {
168
+ override def transform (tree : Tree ): Tree = tree match {
169
+ case Apply (fun @ TypeApply (_, List (futArgTpt)), args) if fun.symbol == defn.Async_await =>
170
+ val typeApp = treeCopy.TypeApply (fun, Ident (defn.Async_awaitFallback ), List (TypeTree (futArgTpt.tpe), TypeTree (body.tree.tpe)))
171
+ treeCopy.Apply (tree, typeApp, args.map(arg => c.resetAllAttrs(arg.duplicate)) :+ Ident (name.result))
172
+ case _ =>
173
+ super .transform(tree)
174
+ }
175
+ }
176
+ val newBody = awaitReplacer.transform(body.tree)
177
+
178
+ val resetBody = reify {
179
+ reset { c.Expr (c.resetAllAttrs(newBody.duplicate)).splice }
180
+ }
181
+
182
+ val futureSystemOps = futureSystem.mkOps(c)
183
+ val code = {
184
+ val tree = Block (List (
185
+ ValDef (NoMods , name.result, TypeTree (futureSystemOps.promType[T ]), futureSystemOps.createProm[T ].tree),
186
+ futureSystemOps.spawn(resetBody.tree)
187
+ ), futureSystemOps.promiseToFuture(c.Expr [futureSystem.Prom [T ]](Ident (name.result))).tree)
188
+ c.Expr [futureSystem.Fut [T ]](tree)
189
+ }
190
+
191
+ AsyncUtils .vprintln(s " async CPS fallback transform expands to: \n ${code.tree}" )
192
+ code
193
+ }
164
194
}
165
195
166
196
def logDiagnostics (c : Context )(anfTree : c.Tree , states : Seq [String ]) {
0 commit comments