Skip to content

Commit 2cc299f

Browse files
Alexey AndreevAlexey Andreev
Alexey Andreev
authored and
Alexey Andreev
committed
JS: refactor coroutines to support inlining of suspend functions
1 parent e56d735 commit 2cc299f

File tree

10 files changed

+193
-79
lines changed

10 files changed

+193
-79
lines changed

compiler/testData/codegen/box/coroutines/inlineSuspendFunction.kt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
// WITH_RUNTIME
22
// WITH_REFLECT
3-
// TARGET_BACKEND: JVM
43
class Controller {
54
fun withValue(v: String, x: Continuation<String>) {
65
x.resume(v)

js/js.dart-ast/src/com/google/dart/compiler/backend/js/ast/metadata/metadataProperties.kt

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2010-2014 JetBrains s.r.o.
2+
* Copyright 2010-2016 JetBrains s.r.o.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -62,10 +62,39 @@ var JsFunction.coroutineType: ClassDescriptor? by MetadataProperty(default = nul
6262

6363
var JsFunction.controllerType: ClassDescriptor? by MetadataProperty(default = null)
6464

65+
/**
66+
* Denotes a suspension call-site that is to be processed by coroutine transformer.
67+
* More clearly, denotes invocation that should immediately return from coroutine state machine
68+
*/
6569
var JsInvocation.isSuspend: Boolean by MetadataProperty(default = false)
6670

71+
/**
72+
* Denotes a pre-suspend call-site that is to be processed by coroutine transformer.
73+
* For normal suspend call-sites both [isSuspend] and [isPreSuspend] present.
74+
* For inlined suspend calls fake calls are generated before and after inlined function body.
75+
*/
76+
var JsInvocation.isPreSuspend: Boolean by MetadataProperty(default = false)
77+
78+
/**
79+
* Denotes a fake suspend call for inlining purposes.
80+
*/
81+
var JsInvocation.isFakeSuspend: Boolean by MetadataProperty(default = false)
82+
83+
/**
84+
* Denotes a call to coroutine's controller `handleResult` function.
85+
* See coroutine spec for explanation.
86+
*/
6787
var JsInvocation.isHandleResult: Boolean by MetadataProperty(default = false)
6888

89+
/**
90+
* Denotes a reference to coroutine's `result` field that contains result of
91+
* last suspended invocation.
92+
*/
93+
var JsNameRef.coroutineResult by MetadataProperty(default = false)
94+
95+
/**
96+
* Denotes a reference to coroutine's `controller` field that contains coroutines's controller
97+
*/
6998
var JsNameRef.coroutineController by MetadataProperty(default = false)
7099

71100
enum class TypeCheck {

js/js.inliner/src/org/jetbrains/kotlin/js/coroutine/CoroutineBodyTransformer.kt

Lines changed: 30 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@
1717
package org.jetbrains.kotlin.js.coroutine
1818

1919
import com.google.dart.compiler.backend.js.ast.*
20-
import com.google.dart.compiler.backend.js.ast.metadata.MetadataProperty
21-
import com.google.dart.compiler.backend.js.ast.metadata.isSuspend
20+
import com.google.dart.compiler.backend.js.ast.metadata.*
2221
import org.jetbrains.kotlin.js.translate.utils.JsAstUtils
2322
import org.jetbrains.kotlin.utils.DFS
2423
import org.jetbrains.kotlin.utils.singletonOrEmptyList
@@ -45,6 +44,7 @@ class CoroutineBodyTransformer(val program: JsProgram, val scope: JsScope, val t
4544
private lateinit var nodesToSplit: Set<JsNode>
4645
private var currentCatchBlock = globalCatchBlock
4746
private val tryStack = mutableListOf(TryBlock(globalCatchBlock, null))
47+
private var suspendTarget: CoroutineBlock? = null
4848

4949
var hasFinallyBlocks = false
5050
get
@@ -188,8 +188,6 @@ class CoroutineBodyTransformer(val program: JsProgram, val scope: JsScope, val t
188188
}
189189

190190
override fun visitIf(x: JsIf) = splitIfNecessary(x) {
191-
x.ifExpression = handleExpression(x.ifExpression)
192-
193191
val ifBlock = currentBlock
194192

195193
val thenEntryBlock = CoroutineBlock()
@@ -433,39 +431,26 @@ class CoroutineBodyTransformer(val program: JsProgram, val scope: JsScope, val t
433431

434432
override fun visitExpressionStatement(x: JsExpressionStatement) {
435433
val expression = x.expression
436-
if (expression is JsInvocation && expression.isSuspend) {
437-
handleSuspend(expression)
434+
val splitExpression = handleExpression(expression)
435+
if (splitExpression == expression) {
436+
currentStatements += x
438437
}
439-
else {
440-
val splitExpression = handleExpression(x.expression)
441-
currentStatements += if (splitExpression == expression) x else JsExpressionStatement(expression)
438+
else if (splitExpression != null) {
439+
currentStatements += JsExpressionStatement(splitExpression).apply { synthetic = true }
442440
}
443441
}
444442

445443
override fun visitVars(x: JsVars) {
446-
super.visitVars(x)
447444
currentStatements += x
448445
}
449446

450-
override fun visit(x: JsVars.JsVar) {
451-
val initExpression = x.initExpression
452-
if (initExpression != null) {
453-
x.initExpression = handleExpression(initExpression)
454-
}
455-
}
456-
457447
override fun visitReturn(x: JsReturn) {
458448
val returnBlock = CoroutineBlock()
459449
val isInFinally = hasEnclosingFinallyBlock()
460450
if (isInFinally) {
461451
jumpWithFinally(0, returnBlock)
462452
}
463453

464-
val returnExpression = x.expression
465-
if (returnExpression != null) {
466-
x.expression = handleExpression(returnExpression)
467-
}
468-
469454
if (isInFinally) {
470455
currentStatements += x.expression?.makeStmt().singletonOrEmptyList()
471456
currentStatements += jump()
@@ -479,9 +464,8 @@ class CoroutineBodyTransformer(val program: JsProgram, val scope: JsScope, val t
479464

480465
override fun visitThrow(x: JsThrow) {
481466
if (throwFunctionName != null) {
482-
val exception = handleExpression(x.expression)
483467
val methodRef = JsNameRef(throwFunctionName, JsNameRef(controllerFieldName, JsLiteral.THIS))
484-
val invocation = JsInvocation(methodRef, exception).apply {
468+
val invocation = JsInvocation(methodRef, x.expression).apply {
485469
source = x.source
486470
}
487471
currentStatements += JsReturn(invocation)
@@ -491,28 +475,35 @@ class CoroutineBodyTransformer(val program: JsProgram, val scope: JsScope, val t
491475
}
492476
}
493477

494-
private fun handleExpression(expression: JsExpression): JsExpression {
495-
if (expression !in nodesToSplit) return expression
496-
497-
val visitor = object : JsVisitorWithContextImpl() {
498-
override fun endVisit(x: JsInvocation, ctx: JsContext<in JsExpression>) {
499-
if (x.isSuspend) {
500-
ctx.replaceMe(handleSuspend(x))
501-
}
502-
super.endVisit(x, ctx)
478+
private fun handleExpression(expression: JsExpression): JsExpression? {
479+
return if (expression is JsInvocation) {
480+
var result: JsExpression? = expression
481+
if (expression.isPreSuspend) {
482+
result = handlePreSuspend(expression)
503483
}
484+
if (expression.isSuspend) {
485+
handleSuspend(expression)
486+
result = null
487+
}
488+
result
489+
}
490+
else {
491+
expression
504492
}
505-
506-
return visitor.accept(expression)
507493
}
508494

509-
private fun handleSuspend(invocation: JsInvocation): JsExpression {
495+
private fun handlePreSuspend(invocation: JsInvocation): JsExpression? {
510496
val nextBlock = CoroutineBlock()
511497
currentStatements += state(nextBlock)
512-
currentStatements += JsReturn(invocation)
513-
currentBlock = nextBlock
498+
suspendTarget = nextBlock
499+
500+
return if (invocation.isFakeSuspend) null else invocation
501+
}
514502

515-
return JsNameRef(resultFieldName, JsLiteral.THIS)
503+
private fun handleSuspend(invocation: JsInvocation) {
504+
val invokeExpression = if (invocation.isFakeSuspend) invocation.arguments.getOrNull(0) else invocation
505+
currentStatements += JsReturn(invokeExpression)
506+
currentBlock = suspendTarget!!
516507
}
517508

518509
private fun state(target: CoroutineBlock): List<JsStatement> {

js/js.inliner/src/org/jetbrains/kotlin/js/coroutine/CoroutineFunctionTransformer.kt

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,12 @@
1717
package org.jetbrains.kotlin.js.coroutine
1818

1919
import com.google.dart.compiler.backend.js.ast.*
20+
import com.google.dart.compiler.backend.js.ast.metadata.SideEffectKind
2021
import com.google.dart.compiler.backend.js.ast.metadata.coroutineController
21-
import com.google.dart.compiler.backend.js.ast.metadata.isSuspend
22+
import com.google.dart.compiler.backend.js.ast.metadata.coroutineResult
23+
import com.google.dart.compiler.backend.js.ast.metadata.sideEffects
2224
import org.jetbrains.kotlin.descriptors.ClassDescriptor
2325
import org.jetbrains.kotlin.descriptors.FunctionDescriptor
24-
import org.jetbrains.kotlin.js.inline.ExpressionDecomposer
2526
import org.jetbrains.kotlin.js.inline.clean.FunctionPostProcessor
2627
import org.jetbrains.kotlin.js.inline.util.collectLocalVariables
2728
import org.jetbrains.kotlin.js.inline.util.getInnerFunction
@@ -44,19 +45,6 @@ class CoroutineFunctionTransformer(
4445
private val className = function.scope.parent.declareFreshName("Coroutine\$${function.name}")
4546

4647
fun transform(): List<JsStatement> {
47-
val visitor = object : JsVisitorWithContextImpl() {
48-
override fun <T : JsNode?> doTraverse(node: T, ctx: JsContext<in JsStatement>) {
49-
super.doTraverse(node, ctx)
50-
if (node is JsStatement) {
51-
val statements = ExpressionDecomposer.preserveEvaluationOrder(function.scope, node) {
52-
it is JsInvocation && it.isSuspend
53-
}
54-
ctx.addPrevious(statements)
55-
}
56-
}
57-
}
58-
visitor.accept(body)
59-
6048
val throwFunction = controllerType.findFunction("handleException")
6149
val throwName = throwFunction?.let {
6250
val throwId = nameSuggestion.suggest(it)!!.names.last()
@@ -238,12 +226,23 @@ class CoroutineFunctionTransformer(
238226

239227
val visitor = object : JsVisitorWithContextImpl() {
240228
override fun endVisit(x: JsNameRef, ctx: JsContext<in JsNode>) {
241-
if (x.coroutineController) {
242-
ctx.replaceMe(JsNameRef(transformer.controllerFieldName, x.qualifier))
243-
}
244-
if (x.qualifier == null && x.name in localVariables) {
245-
val fieldName = scope.getFieldName(x.name!!)
246-
ctx.replaceMe(JsNameRef(fieldName, JsLiteral.THIS))
229+
when {
230+
x.coroutineController -> {
231+
ctx.replaceMe(JsNameRef(transformer.controllerFieldName, x.qualifier).apply {
232+
sideEffects = SideEffectKind.PURE
233+
})
234+
}
235+
236+
x.coroutineResult -> {
237+
ctx.replaceMe(JsNameRef(transformer.resultFieldName, x.qualifier).apply {
238+
sideEffects = SideEffectKind.DEPENDS_ON_STATE
239+
})
240+
}
241+
242+
x.qualifier == null && x.name in localVariables -> {
243+
val fieldName = scope.getFieldName(x.name!!)
244+
ctx.replaceMe(JsNameRef(fieldName, JsLiteral.THIS))
245+
}
247246
}
248247
}
249248

js/js.inliner/src/org/jetbrains/kotlin/js/inline/ExpressionDecomposer.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,7 @@ internal open class JsExpressionVisitor() : JsVisitorWithContextImpl() {
455455
/**
456456
* Returns descendants of receiver, matched by [predicate].
457457
*/
458-
fun JsNode.match(predicate: (JsNode) -> Boolean): Set<JsNode> {
458+
private fun JsNode.match(predicate: (JsNode) -> Boolean): Set<JsNode> {
459459
val visitor = object : JsExpressionVisitor() {
460460
val matched = IdentitySet<JsNode>()
461461

@@ -475,7 +475,7 @@ fun JsNode.match(predicate: (JsNode) -> Boolean): Set<JsNode> {
475475
/**
476476
* Returns set of nodes, that satisfy transitive closure of `is parent` relation, starting from [nodes].
477477
*/
478-
fun JsNode.withParentsOfNodes(nodes: Set<JsNode>): Set<JsNode> {
478+
private fun JsNode.withParentsOfNodes(nodes: Set<JsNode>): Set<JsNode> {
479479
val visitor = object : JsExpressionVisitor() {
480480
private val stack = SmartList<JsNode>()
481481
val matched = IdentitySet<JsNode>()

js/js.inliner/src/org/jetbrains/kotlin/js/inline/FunctionInlineMutator.kt

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,14 @@
1717
package org.jetbrains.kotlin.js.inline
1818

1919
import com.google.dart.compiler.backend.js.ast.*
20-
import com.google.dart.compiler.backend.js.ast.metadata.staticRef
21-
import com.google.dart.compiler.backend.js.ast.metadata.synthetic
20+
import com.google.dart.compiler.backend.js.ast.metadata.*
2221
import org.jetbrains.kotlin.js.inline.clean.removeDefaultInitializers
22+
import org.jetbrains.kotlin.js.inline.clean.removeFakeSuspend
2323
import org.jetbrains.kotlin.js.inline.context.InliningContext
2424
import org.jetbrains.kotlin.js.inline.context.NamingContext
2525
import org.jetbrains.kotlin.js.inline.util.*
2626
import org.jetbrains.kotlin.js.inline.util.rewriters.ReturnReplacingVisitor
27+
import org.jetbrains.kotlin.js.translate.utils.JsAstUtils
2728

2829
class FunctionInlineMutator
2930
private constructor(
@@ -34,17 +35,28 @@ private constructor(
3435
private val namingContext: NamingContext
3536
private val body: JsBlock
3637
private var resultExpr: JsExpression? = null
38+
private var resultName: JsName? = null
3739
private var breakLabel: JsLabel? = null
3840
private val currentStatement = inliningContext.statementContext.currentNode
3941

4042
init {
4143
namingContext = inliningContext.newNamingContext()
4244
val functionContext = inliningContext.functionContext
4345
invokedFunction = uncoverClosure(functionContext.getFunctionDefinition(call).deepCopy())
44-
body = invokedFunction.body
46+
47+
// Removing fakeSuspend is not just an optimization.
48+
// Reentrant suspends are not supported by coroutine transformers.
49+
body = if (call.isSuspend) invokedFunction.body.removeFakeSuspend() else invokedFunction.body
4550
}
4651

4752
private fun process() {
53+
if (call.isSuspend) {
54+
val fakeSuspendCall = JsInvocation(JsAstUtils.pureFqn("fakeSuspend", JsAstUtils.pureFqn("Kotlin", null)))
55+
fakeSuspendCall.isPreSuspend = true
56+
fakeSuspendCall.isFakeSuspend = true
57+
body.statements.add(0, JsAstUtils.asSyntheticStatement(fakeSuspendCall))
58+
}
59+
4860
val arguments = getArguments()
4961
val parameters = getParameters()
5062

@@ -111,14 +123,19 @@ private constructor(
111123
val breakName = namingContext.getFreshName(getBreakLabel())
112124
this.breakLabel = JsLabel(breakName).apply { synthetic = true }
113125

114-
val visitor = ReturnReplacingVisitor(resultExpr as? JsNameRef, breakName.makeRef(), invokedFunction)
126+
val visitor = ReturnReplacingVisitor(resultExpr as? JsNameRef, breakName.makeRef(), invokedFunction, call.isSuspend)
115127
visitor.accept(body)
128+
129+
visitor.makeFakeSuspendCall(null)?.let { fakeSuspend ->
130+
body.statements += JsAstUtils.asSyntheticStatement(fakeSuspend)
131+
}
116132
}
117133

118134
private fun getResultReference(): JsNameRef? {
119135
if (!isResultNeeded(call)) return null
120136

121137
val resultName = namingContext.getFreshName(getResultLabel())
138+
this.resultName = resultName
122139
namingContext.newVar(resultName, null)
123140
return resultName.makeRef()
124141
}
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
/*
2+
* Copyright 2010-2016 JetBrains s.r.o.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.jetbrains.kotlin.js.inline.clean
18+
19+
import com.google.dart.compiler.backend.js.ast.*
20+
import com.google.dart.compiler.backend.js.ast.metadata.isFakeSuspend
21+
import com.google.dart.compiler.backend.js.ast.metadata.isPreSuspend
22+
import com.google.dart.compiler.backend.js.ast.metadata.isSuspend
23+
import com.google.dart.compiler.backend.js.ast.metadata.synthetic
24+
import org.jetbrains.kotlin.js.translate.context.Namer
25+
26+
fun <T : JsNode> T.removeFakeSuspend(): T {
27+
val visitor = object : JsVisitorWithContextImpl() {
28+
override fun endVisit(x: JsInvocation, ctx: JsContext<in JsNode>) {
29+
if (x.isFakeSuspend) {
30+
ctx.replaceMe(x.arguments.getOrElse(0) { Namer.getUndefinedExpression() })
31+
}
32+
else {
33+
x.isSuspend = false
34+
x.isPreSuspend = false
35+
}
36+
super.endVisit(x, ctx)
37+
}
38+
39+
override fun visit(x: JsExpressionStatement, ctx: JsContext<*>): Boolean {
40+
val expression = x.expression
41+
if (expression is JsInvocation && expression.isFakeSuspend && expression.arguments.isEmpty()) {
42+
x.synthetic = true
43+
}
44+
return super.visit(x, ctx)
45+
}
46+
}
47+
return visitor.accept(this)
48+
}

0 commit comments

Comments
 (0)