Extend lambda inlining to getOrPut and implicit it calls
This commit is contained in:
parent
0c3242cbd8
commit
f4ab2ebab4
@ -4981,24 +4981,23 @@ class BytecodeCompiler(
|
|||||||
|
|
||||||
private fun compileInlineHigherOrderMethodCall(ref: MethodCallRef): CompiledValue? {
|
private fun compileInlineHigherOrderMethodCall(ref: MethodCallRef): CompiledValue? {
|
||||||
val spec = inlineHigherOrderMethodSpec(ref.name) ?: return null
|
val spec = inlineHigherOrderMethodSpec(ref.name) ?: return null
|
||||||
if (ref.args.size != 1 || ref.args.any { it.isSplat || it.name != null }) return null
|
if (ref.args.size != spec.argCount || ref.args.any { it.isSplat || it.name != null }) return null
|
||||||
if (!ref.explicitTypeArgs.isNullOrEmpty()) return null
|
if (!ref.explicitTypeArgs.isNullOrEmpty()) return null
|
||||||
val lambdaRef = extractExactLambdaRef(ref.args.first().value) ?: return null
|
val lambdaRef = extractExactLambdaRef(ref.args[spec.lambdaArgIndex].value) ?: return null
|
||||||
val inlineRef = lambdaRef.inlineBodyRef ?: return null
|
val inlineRef = lambdaRef.inlineBodyRef ?: return null
|
||||||
return when (spec.kind) {
|
return when (spec.kind) {
|
||||||
InlineHigherOrderMethodKind.UNARY_ARGUMENT -> {
|
InlineHigherOrderMethodKind.UNARY_ARGUMENT -> {
|
||||||
if (!isMethodInlineSafe(lambdaRef, inlineRef, allowReceiverRefs = false, allowCaptures = true)) return null
|
if (!isMethodInlineSafe(lambdaRef, inlineRef, allowReceiverRefs = false, allowCaptures = true)) return null
|
||||||
val paramName = lambdaRef.inlineParamNames()?.singleOrNull() ?: return null
|
|
||||||
val receiver = compileRefWithFallback(ref.receiver, null, refPosOrCurrent(ref.receiver)) ?: return null
|
val receiver = compileRefWithFallback(ref.receiver, null, refPosOrCurrent(ref.receiver)) ?: return null
|
||||||
val receiverObj = ensureObjSlot(receiver)
|
val receiverObj = ensureObjSlot(receiver)
|
||||||
compileOptionalInlineMethod(ref.isOptional, receiverObj) {
|
compileOptionalInlineMethod(ref.isOptional, receiverObj) {
|
||||||
val receiverSlot = materializeInlineBinding(receiver)
|
val bindings = prepareInlineLambdaBindingsFromValues(lambdaRef, listOf(receiver)) ?: return@compileOptionalInlineMethod null
|
||||||
when (spec.result) {
|
when (spec.result) {
|
||||||
InlineHigherOrderResultMode.BLOCK_RESULT ->
|
InlineHigherOrderResultMode.BLOCK_RESULT ->
|
||||||
compileInlineLambdaBody(lambdaRef, inlineRef, listOf(paramName to receiverSlot))
|
compileInlineLambdaBody(lambdaRef, inlineRef, bindings)
|
||||||
InlineHigherOrderResultMode.RETURN_RECEIVER -> {
|
InlineHigherOrderResultMode.RETURN_RECEIVER -> {
|
||||||
compileInlineLambdaBody(lambdaRef, inlineRef, listOf(paramName to receiverSlot)) ?: return@compileOptionalInlineMethod null
|
compileInlineLambdaBody(lambdaRef, inlineRef, bindings) ?: return@compileOptionalInlineMethod null
|
||||||
CompiledValue(receiverSlot, receiver.type)
|
CompiledValue(receiverObj.slot, SlotType.OBJ)
|
||||||
}
|
}
|
||||||
else -> null
|
else -> null
|
||||||
}
|
}
|
||||||
@ -5015,11 +5014,22 @@ class BytecodeCompiler(
|
|||||||
}
|
}
|
||||||
InlineHigherOrderMethodKind.ITERABLE -> {
|
InlineHigherOrderMethodKind.ITERABLE -> {
|
||||||
if (!isMethodInlineSafe(lambdaRef, inlineRef, allowReceiverRefs = false, allowCaptures = true)) return null
|
if (!isMethodInlineSafe(lambdaRef, inlineRef, allowReceiverRefs = false, allowCaptures = true)) return null
|
||||||
val paramName = lambdaRef.inlineParamNames()?.singleOrNull() ?: return null
|
|
||||||
val receiver = compileRefWithFallback(ref.receiver, null, refPosOrCurrent(ref.receiver)) ?: return null
|
val receiver = compileRefWithFallback(ref.receiver, null, refPosOrCurrent(ref.receiver)) ?: return null
|
||||||
val receiverObj = ensureObjSlot(receiver)
|
val receiverObj = ensureObjSlot(receiver)
|
||||||
compileOptionalInlineMethod(ref.isOptional, receiverObj) {
|
compileOptionalInlineMethod(ref.isOptional, receiverObj) {
|
||||||
compileInlineIterableLambdaLoop(receiverObj, ref, lambdaRef, inlineRef, paramName, spec.result)
|
compileInlineIterableLambdaLoop(receiverObj, ref, lambdaRef, inlineRef, spec.result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
InlineHigherOrderMethodKind.MAP_GET_OR_PUT -> {
|
||||||
|
val receiverClass = resolveReceiverClass(ref.receiver) ?: return null
|
||||||
|
if (receiverClass != ObjMap.type) return null
|
||||||
|
if (!isMethodInlineSafe(lambdaRef, inlineRef, allowReceiverRefs = false, allowCaptures = true)) return null
|
||||||
|
val receiver = compileRefWithFallback(ref.receiver, null, refPosOrCurrent(ref.receiver)) ?: return null
|
||||||
|
val receiverObj = ensureObjSlot(receiver)
|
||||||
|
val key = compileArgValue(ref.args[0].value) ?: return null
|
||||||
|
val keyObj = ensureObjSlot(key)
|
||||||
|
compileOptionalInlineMethod(ref.isOptional, receiverObj) {
|
||||||
|
compileInlineMapGetOrPut(receiverObj, keyObj, lambdaRef, inlineRef)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -5050,26 +5060,58 @@ class BytecodeCompiler(
|
|||||||
if (ref.tailBlock) return null
|
if (ref.tailBlock) return null
|
||||||
if (!ref.explicitTypeArgs.isNullOrEmpty()) return null
|
if (!ref.explicitTypeArgs.isNullOrEmpty()) return null
|
||||||
val inlineRef = lambdaRef.inlineBodyRef ?: return null
|
val inlineRef = lambdaRef.inlineBodyRef ?: return null
|
||||||
val bindings = prepareInlineLambdaBindings(lambdaRef, ref.args) ?: return null
|
if (lambdaRef.argsDeclaration == null && ref.args.size != 1 &&
|
||||||
|
!isMethodInlineSafe(lambdaRef, inlineRef, allowReceiverRefs = false, allowCaptures = false)
|
||||||
|
) {
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
val bindings = prepareInlineLambdaBindingsFromArgs(lambdaRef, ref.args) ?: return null
|
||||||
return compileInlineLambdaBody(lambdaRef, inlineRef, bindings)
|
return compileInlineLambdaBody(lambdaRef, inlineRef, bindings)
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun prepareInlineLambdaBindings(
|
private fun prepareInlineLambdaBindingsFromArgs(
|
||||||
lambdaRef: LambdaFnRef,
|
lambdaRef: LambdaFnRef,
|
||||||
args: List<ParsedArgument>
|
args: List<ParsedArgument>
|
||||||
): List<Pair<String, Int>>? {
|
): List<Pair<String, Int>>? {
|
||||||
if (args.any { it.isSplat || it.name != null }) return null
|
if (args.any { it.isSplat || it.name != null }) return null
|
||||||
val paramNames = lambdaRef.inlineParamNames() ?: return null
|
val compiledArgs = ArrayList<CompiledValue>(args.size)
|
||||||
if (args.size != paramNames.size) return null
|
for (arg in args) {
|
||||||
|
compiledArgs += compileArgValue(arg.value) ?: return null
|
||||||
|
}
|
||||||
|
return prepareInlineLambdaBindingsFromValues(lambdaRef, compiledArgs)
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun prepareInlineLambdaBindingsFromValues(
|
||||||
|
lambdaRef: LambdaFnRef,
|
||||||
|
args: List<CompiledValue>
|
||||||
|
): List<Pair<String, Int>>? {
|
||||||
|
val declaration = lambdaRef.argsDeclaration
|
||||||
|
if (declaration == null) {
|
||||||
|
val implicitValue = when (args.size) {
|
||||||
|
0 -> CompiledValue(ensureVoidSlot(), SlotType.OBJ)
|
||||||
|
1 -> args[0]
|
||||||
|
else -> buildInlineImplicitItList(args) ?: return null
|
||||||
|
}
|
||||||
|
return listOf("it" to materializeInlineBinding(implicitValue))
|
||||||
|
}
|
||||||
|
if (declaration.params.any { it.isEllipsis || it.defaultValue != null }) return null
|
||||||
|
if (args.size != declaration.params.size) return null
|
||||||
if (args.isEmpty()) return emptyList()
|
if (args.isEmpty()) return emptyList()
|
||||||
val bindings = ArrayList<Pair<String, Int>>(args.size)
|
val bindings = ArrayList<Pair<String, Int>>(args.size)
|
||||||
for ((index, arg) in args.withIndex()) {
|
for ((index, param) in declaration.params.withIndex()) {
|
||||||
val compiled = compileArgValue(arg.value) ?: return null
|
bindings += param.name to materializeInlineBinding(args[index])
|
||||||
bindings += paramNames[index] to materializeInlineBinding(compiled)
|
|
||||||
}
|
}
|
||||||
return bindings
|
return bindings
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private fun buildInlineImplicitItList(args: List<CompiledValue>): CompiledValue? {
|
||||||
|
val list = createEmptyMutableList() ?: return null
|
||||||
|
for (arg in args) {
|
||||||
|
appendToList(list, arg) ?: return null
|
||||||
|
}
|
||||||
|
return list
|
||||||
|
}
|
||||||
|
|
||||||
private fun LambdaFnRef.inlineParamNames(): List<String>? {
|
private fun LambdaFnRef.inlineParamNames(): List<String>? {
|
||||||
val declaration = argsDeclaration
|
val declaration = argsDeclaration
|
||||||
if (declaration == null) {
|
if (declaration == null) {
|
||||||
@ -5092,7 +5134,8 @@ class BytecodeCompiler(
|
|||||||
private enum class InlineHigherOrderMethodKind {
|
private enum class InlineHigherOrderMethodKind {
|
||||||
UNARY_ARGUMENT,
|
UNARY_ARGUMENT,
|
||||||
RECEIVER,
|
RECEIVER,
|
||||||
ITERABLE
|
ITERABLE,
|
||||||
|
MAP_GET_OR_PUT
|
||||||
}
|
}
|
||||||
|
|
||||||
private enum class InlineHigherOrderResultMode {
|
private enum class InlineHigherOrderResultMode {
|
||||||
@ -5107,7 +5150,9 @@ class BytecodeCompiler(
|
|||||||
|
|
||||||
private data class InlineHigherOrderMethodSpec(
|
private data class InlineHigherOrderMethodSpec(
|
||||||
val kind: InlineHigherOrderMethodKind,
|
val kind: InlineHigherOrderMethodKind,
|
||||||
val result: InlineHigherOrderResultMode
|
val result: InlineHigherOrderResultMode,
|
||||||
|
val argCount: Int = 1,
|
||||||
|
val lambdaArgIndex: Int = 0
|
||||||
)
|
)
|
||||||
|
|
||||||
private data class InlineReceiverInfo(
|
private data class InlineReceiverInfo(
|
||||||
@ -5153,6 +5198,12 @@ class BytecodeCompiler(
|
|||||||
InlineHigherOrderMethodKind.ITERABLE,
|
InlineHigherOrderMethodKind.ITERABLE,
|
||||||
InlineHigherOrderResultMode.ASSOCIATE_BY
|
InlineHigherOrderResultMode.ASSOCIATE_BY
|
||||||
)
|
)
|
||||||
|
"getOrPut" -> InlineHigherOrderMethodSpec(
|
||||||
|
InlineHigherOrderMethodKind.MAP_GET_OR_PUT,
|
||||||
|
InlineHigherOrderResultMode.BLOCK_RESULT,
|
||||||
|
argCount = 2,
|
||||||
|
lambdaArgIndex = 1
|
||||||
|
)
|
||||||
else -> null
|
else -> null
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -5321,7 +5372,6 @@ class BytecodeCompiler(
|
|||||||
ref: MethodCallRef,
|
ref: MethodCallRef,
|
||||||
lambdaRef: LambdaFnRef,
|
lambdaRef: LambdaFnRef,
|
||||||
inlineRef: ObjRef,
|
inlineRef: ObjRef,
|
||||||
paramName: String,
|
|
||||||
behavior: InlineHigherOrderResultMode
|
behavior: InlineHigherOrderResultMode
|
||||||
): CompiledValue? {
|
): CompiledValue? {
|
||||||
val iterableMethods = ObjIterable.instanceMethodIdMap(includeAbstract = true)
|
val iterableMethods = ObjIterable.instanceMethodIdMap(includeAbstract = true)
|
||||||
@ -5366,16 +5416,17 @@ class BytecodeCompiler(
|
|||||||
val nextSlot = allocSlot()
|
val nextSlot = allocSlot()
|
||||||
builder.emit(Opcode.CALL_MEMBER_SLOT, iterSlot, nextMethodId, 0, 0, nextSlot)
|
builder.emit(Opcode.CALL_MEMBER_SLOT, iterSlot, nextMethodId, 0, 0, nextSlot)
|
||||||
val nextObj = ensureObjSlot(CompiledValue(nextSlot, SlotType.UNKNOWN))
|
val nextObj = ensureObjSlot(CompiledValue(nextSlot, SlotType.UNKNOWN))
|
||||||
|
val bindings = prepareInlineLambdaBindingsFromValues(lambdaRef, listOf(nextObj)) ?: return null
|
||||||
when (behavior) {
|
when (behavior) {
|
||||||
InlineHigherOrderResultMode.FOR_EACH -> {
|
InlineHigherOrderResultMode.FOR_EACH -> {
|
||||||
compileInlineLambdaBody(lambdaRef, inlineRef, listOf(paramName to nextObj.slot)) ?: return null
|
compileInlineLambdaBody(lambdaRef, inlineRef, bindings) ?: return null
|
||||||
}
|
}
|
||||||
InlineHigherOrderResultMode.MAP -> {
|
InlineHigherOrderResultMode.MAP -> {
|
||||||
val mapped = compileInlineLambdaBody(lambdaRef, inlineRef, listOf(paramName to nextObj.slot)) ?: return null
|
val mapped = compileInlineLambdaBody(lambdaRef, inlineRef, bindings) ?: return null
|
||||||
appendToList(result, mapped) ?: return null
|
appendToList(result, mapped) ?: return null
|
||||||
}
|
}
|
||||||
InlineHigherOrderResultMode.FILTER -> {
|
InlineHigherOrderResultMode.FILTER -> {
|
||||||
val predicate = compileInlineLambdaBody(lambdaRef, inlineRef, listOf(paramName to nextObj.slot)) ?: return null
|
val predicate = compileInlineLambdaBody(lambdaRef, inlineRef, bindings) ?: return null
|
||||||
val predicateBool = compileValueAsBool(predicate)
|
val predicateBool = compileValueAsBool(predicate)
|
||||||
val skipLabel = builder.label()
|
val skipLabel = builder.label()
|
||||||
builder.emit(
|
builder.emit(
|
||||||
@ -5386,7 +5437,7 @@ class BytecodeCompiler(
|
|||||||
builder.mark(skipLabel)
|
builder.mark(skipLabel)
|
||||||
}
|
}
|
||||||
InlineHigherOrderResultMode.MAP_NOT_NULL -> {
|
InlineHigherOrderResultMode.MAP_NOT_NULL -> {
|
||||||
val mapped = compileInlineLambdaBody(lambdaRef, inlineRef, listOf(paramName to nextObj.slot)) ?: return null
|
val mapped = compileInlineLambdaBody(lambdaRef, inlineRef, bindings) ?: return null
|
||||||
val mappedObj = ensureObjSlot(mapped)
|
val mappedObj = ensureObjSlot(mapped)
|
||||||
val nullSlot = allocSlot()
|
val nullSlot = allocSlot()
|
||||||
builder.emit(Opcode.CONST_NULL, nullSlot)
|
builder.emit(Opcode.CONST_NULL, nullSlot)
|
||||||
@ -5401,7 +5452,7 @@ class BytecodeCompiler(
|
|||||||
builder.mark(skipLabel)
|
builder.mark(skipLabel)
|
||||||
}
|
}
|
||||||
InlineHigherOrderResultMode.ASSOCIATE_BY -> {
|
InlineHigherOrderResultMode.ASSOCIATE_BY -> {
|
||||||
val key = compileInlineLambdaBody(lambdaRef, inlineRef, listOf(paramName to nextObj.slot)) ?: return null
|
val key = compileInlineLambdaBody(lambdaRef, inlineRef, bindings) ?: return null
|
||||||
appendToMap(result, key, nextObj)
|
appendToMap(result, key, nextObj)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -5433,6 +5484,34 @@ class BytecodeCompiler(
|
|||||||
builder.emit(Opcode.SET_INDEX, mapObj.slot, keyObj.slot, itemObj.slot)
|
builder.emit(Opcode.SET_INDEX, mapObj.slot, keyObj.slot, itemObj.slot)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private fun compileInlineMapGetOrPut(
|
||||||
|
receiverObj: CompiledValue,
|
||||||
|
keyObj: CompiledValue,
|
||||||
|
lambdaRef: LambdaFnRef,
|
||||||
|
inlineRef: ObjRef
|
||||||
|
): CompiledValue? {
|
||||||
|
val dst = allocSlot()
|
||||||
|
val hasKey = allocSlot()
|
||||||
|
builder.emit(Opcode.CONTAINS_OBJ, receiverObj.slot, keyObj.slot, hasKey)
|
||||||
|
val existingLabel = builder.label()
|
||||||
|
val endLabel = builder.label()
|
||||||
|
builder.emit(
|
||||||
|
Opcode.JMP_IF_TRUE,
|
||||||
|
listOf(CmdBuilder.Operand.IntVal(hasKey), CmdBuilder.Operand.LabelRef(existingLabel))
|
||||||
|
)
|
||||||
|
val bindings = prepareInlineLambdaBindingsFromValues(lambdaRef, emptyList()) ?: return null
|
||||||
|
val computed = compileInlineLambdaBody(lambdaRef, inlineRef, bindings) ?: return null
|
||||||
|
val computedObj = ensureObjSlot(computed)
|
||||||
|
builder.emit(Opcode.SET_INDEX, receiverObj.slot, keyObj.slot, computedObj.slot)
|
||||||
|
builder.emit(Opcode.MOVE_OBJ, computedObj.slot, dst)
|
||||||
|
builder.emit(Opcode.JMP, listOf(CmdBuilder.Operand.LabelRef(endLabel)))
|
||||||
|
builder.mark(existingLabel)
|
||||||
|
builder.emit(Opcode.GET_INDEX, receiverObj.slot, keyObj.slot, dst)
|
||||||
|
builder.mark(endLabel)
|
||||||
|
updateSlotType(dst, SlotType.OBJ)
|
||||||
|
return CompiledValue(dst, SlotType.OBJ)
|
||||||
|
}
|
||||||
|
|
||||||
private fun compileValueAsBool(value: CompiledValue): CompiledValue {
|
private fun compileValueAsBool(value: CompiledValue): CompiledValue {
|
||||||
if (value.type == SlotType.BOOL) return value
|
if (value.type == SlotType.BOOL) return value
|
||||||
val dst = allocSlot()
|
val dst = allocSlot()
|
||||||
@ -5779,7 +5858,7 @@ class BytecodeCompiler(
|
|||||||
return when (entry.ownerKind) {
|
return when (entry.ownerKind) {
|
||||||
CaptureOwnerFrameKind.LOCAL -> {
|
CaptureOwnerFrameKind.LOCAL -> {
|
||||||
localSlotIndexByKey[key]?.let { return scopeSlotCount + it }
|
localSlotIndexByKey[key]?.let { return scopeSlotCount + it }
|
||||||
null
|
scopeSlotMap[key]
|
||||||
}
|
}
|
||||||
CaptureOwnerFrameKind.MODULE -> {
|
CaptureOwnerFrameKind.MODULE -> {
|
||||||
localSlotIndexByKey[key]?.let { return scopeSlotCount + it }
|
localSlotIndexByKey[key]?.let { return scopeSlotCount + it }
|
||||||
|
|||||||
@ -255,6 +255,54 @@ class CompilerVmReviewRegressionTest {
|
|||||||
assertEquals(36, result.list[6].toInt())
|
assertEquals(36, result.list[6].toInt())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun directLambdaInliningMatchesImplicitItInvocationSemantics() = runTest {
|
||||||
|
val script = Compiler.compile(
|
||||||
|
Source(
|
||||||
|
"<direct-inline-it-semantics>",
|
||||||
|
"""
|
||||||
|
val zeroFn = { if (it == void) 1 else 0 }
|
||||||
|
val multiFn = { it }
|
||||||
|
val zero = zeroFn()
|
||||||
|
val multi = multiFn(1, 2, 3)
|
||||||
|
[zero, multi]
|
||||||
|
""".trimIndent()
|
||||||
|
),
|
||||||
|
Script.defaultImportManager
|
||||||
|
)
|
||||||
|
|
||||||
|
val scope = Script.newScope()
|
||||||
|
val result = script.execute(scope) as ObjList
|
||||||
|
|
||||||
|
assertEquals(1, result.list[0].toInt())
|
||||||
|
val multi = result.list[1] as ObjList
|
||||||
|
assertEquals(listOf(1, 2, 3), multi.list.map { it.toInt() })
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun mapGetOrPutUsesInlineDefaultLambda() = runTest {
|
||||||
|
val script = Compiler.compile(
|
||||||
|
Source(
|
||||||
|
"<map-get-or-put-inline>",
|
||||||
|
"""
|
||||||
|
val offset = 10
|
||||||
|
val m = Map()
|
||||||
|
val first = m.getOrPut("k") { offset + 1 }
|
||||||
|
val second = m.getOrPut("k") { offset + 2 }
|
||||||
|
[first, second, m["k"]]
|
||||||
|
""".trimIndent()
|
||||||
|
),
|
||||||
|
Script.defaultImportManager
|
||||||
|
)
|
||||||
|
|
||||||
|
val scope = Script.newScope()
|
||||||
|
val result = script.execute(scope) as ObjList
|
||||||
|
|
||||||
|
assertEquals(11, result.list[0].toInt())
|
||||||
|
assertEquals(11, result.list[1].toInt())
|
||||||
|
assertEquals(11, result.list[2].toInt())
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun subjectlessWhenReportsScriptError() = runTest {
|
fun subjectlessWhenReportsScriptError() = runTest {
|
||||||
val ex = assertFailsWith<ScriptError> {
|
val ex = assertFailsWith<ScriptError> {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user