Extend lambda inlining to getOrPut and implicit it calls

This commit is contained in:
Sergey Chernov 2026-04-21 19:20:25 +03:00
parent 0c3242cbd8
commit f4ab2ebab4
2 changed files with 152 additions and 25 deletions

View File

@ -4981,24 +4981,23 @@ class BytecodeCompiler(
private fun compileInlineHigherOrderMethodCall(ref: MethodCallRef): CompiledValue? {
val spec = inlineHigherOrderMethodSpec(ref.name) ?: return null
if (ref.args.size != 1 || ref.args.any { it.isSplat || it.name != null }) return null
if (ref.args.size != spec.argCount || ref.args.any { it.isSplat || it.name != null }) return null
if (!ref.explicitTypeArgs.isNullOrEmpty()) return null
val lambdaRef = extractExactLambdaRef(ref.args.first().value) ?: return null
val lambdaRef = extractExactLambdaRef(ref.args[spec.lambdaArgIndex].value) ?: return null
val inlineRef = lambdaRef.inlineBodyRef ?: return null
return when (spec.kind) {
InlineHigherOrderMethodKind.UNARY_ARGUMENT -> {
if (!isMethodInlineSafe(lambdaRef, inlineRef, allowReceiverRefs = false, allowCaptures = true)) return null
val paramName = lambdaRef.inlineParamNames()?.singleOrNull() ?: return null
val receiver = compileRefWithFallback(ref.receiver, null, refPosOrCurrent(ref.receiver)) ?: return null
val receiverObj = ensureObjSlot(receiver)
compileOptionalInlineMethod(ref.isOptional, receiverObj) {
val receiverSlot = materializeInlineBinding(receiver)
val bindings = prepareInlineLambdaBindingsFromValues(lambdaRef, listOf(receiver)) ?: return@compileOptionalInlineMethod null
when (spec.result) {
InlineHigherOrderResultMode.BLOCK_RESULT ->
compileInlineLambdaBody(lambdaRef, inlineRef, listOf(paramName to receiverSlot))
compileInlineLambdaBody(lambdaRef, inlineRef, bindings)
InlineHigherOrderResultMode.RETURN_RECEIVER -> {
compileInlineLambdaBody(lambdaRef, inlineRef, listOf(paramName to receiverSlot)) ?: return@compileOptionalInlineMethod null
CompiledValue(receiverSlot, receiver.type)
compileInlineLambdaBody(lambdaRef, inlineRef, bindings) ?: return@compileOptionalInlineMethod null
CompiledValue(receiverObj.slot, SlotType.OBJ)
}
else -> null
}
@ -5015,11 +5014,22 @@ class BytecodeCompiler(
}
InlineHigherOrderMethodKind.ITERABLE -> {
if (!isMethodInlineSafe(lambdaRef, inlineRef, allowReceiverRefs = false, allowCaptures = true)) return null
val paramName = lambdaRef.inlineParamNames()?.singleOrNull() ?: return null
val receiver = compileRefWithFallback(ref.receiver, null, refPosOrCurrent(ref.receiver)) ?: return null
val receiverObj = ensureObjSlot(receiver)
compileOptionalInlineMethod(ref.isOptional, receiverObj) {
compileInlineIterableLambdaLoop(receiverObj, ref, lambdaRef, inlineRef, paramName, spec.result)
compileInlineIterableLambdaLoop(receiverObj, ref, lambdaRef, inlineRef, spec.result)
}
}
InlineHigherOrderMethodKind.MAP_GET_OR_PUT -> {
val receiverClass = resolveReceiverClass(ref.receiver) ?: return null
if (receiverClass != ObjMap.type) return null
if (!isMethodInlineSafe(lambdaRef, inlineRef, allowReceiverRefs = false, allowCaptures = true)) return null
val receiver = compileRefWithFallback(ref.receiver, null, refPosOrCurrent(ref.receiver)) ?: return null
val receiverObj = ensureObjSlot(receiver)
val key = compileArgValue(ref.args[0].value) ?: return null
val keyObj = ensureObjSlot(key)
compileOptionalInlineMethod(ref.isOptional, receiverObj) {
compileInlineMapGetOrPut(receiverObj, keyObj, lambdaRef, inlineRef)
}
}
}
@ -5050,26 +5060,58 @@ class BytecodeCompiler(
if (ref.tailBlock) return null
if (!ref.explicitTypeArgs.isNullOrEmpty()) return null
val inlineRef = lambdaRef.inlineBodyRef ?: return null
val bindings = prepareInlineLambdaBindings(lambdaRef, ref.args) ?: return null
if (lambdaRef.argsDeclaration == null && ref.args.size != 1 &&
!isMethodInlineSafe(lambdaRef, inlineRef, allowReceiverRefs = false, allowCaptures = false)
) {
return null
}
val bindings = prepareInlineLambdaBindingsFromArgs(lambdaRef, ref.args) ?: return null
return compileInlineLambdaBody(lambdaRef, inlineRef, bindings)
}
private fun prepareInlineLambdaBindings(
private fun prepareInlineLambdaBindingsFromArgs(
lambdaRef: LambdaFnRef,
args: List<ParsedArgument>
): List<Pair<String, Int>>? {
if (args.any { it.isSplat || it.name != null }) return null
val paramNames = lambdaRef.inlineParamNames() ?: return null
if (args.size != paramNames.size) return null
val compiledArgs = ArrayList<CompiledValue>(args.size)
for (arg in args) {
compiledArgs += compileArgValue(arg.value) ?: return null
}
return prepareInlineLambdaBindingsFromValues(lambdaRef, compiledArgs)
}
private fun prepareInlineLambdaBindingsFromValues(
lambdaRef: LambdaFnRef,
args: List<CompiledValue>
): List<Pair<String, Int>>? {
val declaration = lambdaRef.argsDeclaration
if (declaration == null) {
val implicitValue = when (args.size) {
0 -> CompiledValue(ensureVoidSlot(), SlotType.OBJ)
1 -> args[0]
else -> buildInlineImplicitItList(args) ?: return null
}
return listOf("it" to materializeInlineBinding(implicitValue))
}
if (declaration.params.any { it.isEllipsis || it.defaultValue != null }) return null
if (args.size != declaration.params.size) return null
if (args.isEmpty()) return emptyList()
val bindings = ArrayList<Pair<String, Int>>(args.size)
for ((index, arg) in args.withIndex()) {
val compiled = compileArgValue(arg.value) ?: return null
bindings += paramNames[index] to materializeInlineBinding(compiled)
for ((index, param) in declaration.params.withIndex()) {
bindings += param.name to materializeInlineBinding(args[index])
}
return bindings
}
private fun buildInlineImplicitItList(args: List<CompiledValue>): CompiledValue? {
val list = createEmptyMutableList() ?: return null
for (arg in args) {
appendToList(list, arg) ?: return null
}
return list
}
private fun LambdaFnRef.inlineParamNames(): List<String>? {
val declaration = argsDeclaration
if (declaration == null) {
@ -5092,7 +5134,8 @@ class BytecodeCompiler(
private enum class InlineHigherOrderMethodKind {
UNARY_ARGUMENT,
RECEIVER,
ITERABLE
ITERABLE,
MAP_GET_OR_PUT
}
private enum class InlineHigherOrderResultMode {
@ -5107,7 +5150,9 @@ class BytecodeCompiler(
private data class InlineHigherOrderMethodSpec(
val kind: InlineHigherOrderMethodKind,
val result: InlineHigherOrderResultMode
val result: InlineHigherOrderResultMode,
val argCount: Int = 1,
val lambdaArgIndex: Int = 0
)
private data class InlineReceiverInfo(
@ -5153,6 +5198,12 @@ class BytecodeCompiler(
InlineHigherOrderMethodKind.ITERABLE,
InlineHigherOrderResultMode.ASSOCIATE_BY
)
"getOrPut" -> InlineHigherOrderMethodSpec(
InlineHigherOrderMethodKind.MAP_GET_OR_PUT,
InlineHigherOrderResultMode.BLOCK_RESULT,
argCount = 2,
lambdaArgIndex = 1
)
else -> null
}
}
@ -5321,7 +5372,6 @@ class BytecodeCompiler(
ref: MethodCallRef,
lambdaRef: LambdaFnRef,
inlineRef: ObjRef,
paramName: String,
behavior: InlineHigherOrderResultMode
): CompiledValue? {
val iterableMethods = ObjIterable.instanceMethodIdMap(includeAbstract = true)
@ -5366,16 +5416,17 @@ class BytecodeCompiler(
val nextSlot = allocSlot()
builder.emit(Opcode.CALL_MEMBER_SLOT, iterSlot, nextMethodId, 0, 0, nextSlot)
val nextObj = ensureObjSlot(CompiledValue(nextSlot, SlotType.UNKNOWN))
val bindings = prepareInlineLambdaBindingsFromValues(lambdaRef, listOf(nextObj)) ?: return null
when (behavior) {
InlineHigherOrderResultMode.FOR_EACH -> {
compileInlineLambdaBody(lambdaRef, inlineRef, listOf(paramName to nextObj.slot)) ?: return null
compileInlineLambdaBody(lambdaRef, inlineRef, bindings) ?: return null
}
InlineHigherOrderResultMode.MAP -> {
val mapped = compileInlineLambdaBody(lambdaRef, inlineRef, listOf(paramName to nextObj.slot)) ?: return null
val mapped = compileInlineLambdaBody(lambdaRef, inlineRef, bindings) ?: return null
appendToList(result, mapped) ?: return null
}
InlineHigherOrderResultMode.FILTER -> {
val predicate = compileInlineLambdaBody(lambdaRef, inlineRef, listOf(paramName to nextObj.slot)) ?: return null
val predicate = compileInlineLambdaBody(lambdaRef, inlineRef, bindings) ?: return null
val predicateBool = compileValueAsBool(predicate)
val skipLabel = builder.label()
builder.emit(
@ -5386,7 +5437,7 @@ class BytecodeCompiler(
builder.mark(skipLabel)
}
InlineHigherOrderResultMode.MAP_NOT_NULL -> {
val mapped = compileInlineLambdaBody(lambdaRef, inlineRef, listOf(paramName to nextObj.slot)) ?: return null
val mapped = compileInlineLambdaBody(lambdaRef, inlineRef, bindings) ?: return null
val mappedObj = ensureObjSlot(mapped)
val nullSlot = allocSlot()
builder.emit(Opcode.CONST_NULL, nullSlot)
@ -5401,7 +5452,7 @@ class BytecodeCompiler(
builder.mark(skipLabel)
}
InlineHigherOrderResultMode.ASSOCIATE_BY -> {
val key = compileInlineLambdaBody(lambdaRef, inlineRef, listOf(paramName to nextObj.slot)) ?: return null
val key = compileInlineLambdaBody(lambdaRef, inlineRef, bindings) ?: return null
appendToMap(result, key, nextObj)
}
}
@ -5433,6 +5484,34 @@ class BytecodeCompiler(
builder.emit(Opcode.SET_INDEX, mapObj.slot, keyObj.slot, itemObj.slot)
}
private fun compileInlineMapGetOrPut(
receiverObj: CompiledValue,
keyObj: CompiledValue,
lambdaRef: LambdaFnRef,
inlineRef: ObjRef
): CompiledValue? {
val dst = allocSlot()
val hasKey = allocSlot()
builder.emit(Opcode.CONTAINS_OBJ, receiverObj.slot, keyObj.slot, hasKey)
val existingLabel = builder.label()
val endLabel = builder.label()
builder.emit(
Opcode.JMP_IF_TRUE,
listOf(CmdBuilder.Operand.IntVal(hasKey), CmdBuilder.Operand.LabelRef(existingLabel))
)
val bindings = prepareInlineLambdaBindingsFromValues(lambdaRef, emptyList()) ?: return null
val computed = compileInlineLambdaBody(lambdaRef, inlineRef, bindings) ?: return null
val computedObj = ensureObjSlot(computed)
builder.emit(Opcode.SET_INDEX, receiverObj.slot, keyObj.slot, computedObj.slot)
builder.emit(Opcode.MOVE_OBJ, computedObj.slot, dst)
builder.emit(Opcode.JMP, listOf(CmdBuilder.Operand.LabelRef(endLabel)))
builder.mark(existingLabel)
builder.emit(Opcode.GET_INDEX, receiverObj.slot, keyObj.slot, dst)
builder.mark(endLabel)
updateSlotType(dst, SlotType.OBJ)
return CompiledValue(dst, SlotType.OBJ)
}
private fun compileValueAsBool(value: CompiledValue): CompiledValue {
if (value.type == SlotType.BOOL) return value
val dst = allocSlot()
@ -5779,7 +5858,7 @@ class BytecodeCompiler(
return when (entry.ownerKind) {
CaptureOwnerFrameKind.LOCAL -> {
localSlotIndexByKey[key]?.let { return scopeSlotCount + it }
null
scopeSlotMap[key]
}
CaptureOwnerFrameKind.MODULE -> {
localSlotIndexByKey[key]?.let { return scopeSlotCount + it }

View File

@ -255,6 +255,54 @@ class CompilerVmReviewRegressionTest {
assertEquals(36, result.list[6].toInt())
}
@Test
fun directLambdaInliningMatchesImplicitItInvocationSemantics() = runTest {
val script = Compiler.compile(
Source(
"<direct-inline-it-semantics>",
"""
val zeroFn = { if (it == void) 1 else 0 }
val multiFn = { it }
val zero = zeroFn()
val multi = multiFn(1, 2, 3)
[zero, multi]
""".trimIndent()
),
Script.defaultImportManager
)
val scope = Script.newScope()
val result = script.execute(scope) as ObjList
assertEquals(1, result.list[0].toInt())
val multi = result.list[1] as ObjList
assertEquals(listOf(1, 2, 3), multi.list.map { it.toInt() })
}
@Test
fun mapGetOrPutUsesInlineDefaultLambda() = runTest {
val script = Compiler.compile(
Source(
"<map-get-or-put-inline>",
"""
val offset = 10
val m = Map()
val first = m.getOrPut("k") { offset + 1 }
val second = m.getOrPut("k") { offset + 2 }
[first, second, m["k"]]
""".trimIndent()
),
Script.defaultImportManager
)
val scope = Script.newScope()
val result = script.execute(scope) as ObjList
assertEquals(11, result.list[0].toInt())
assertEquals(11, result.list[1].toInt())
assertEquals(11, result.list[2].toInt())
}
@Test
fun subjectlessWhenReportsScriptError() = runTest {
val ex = assertFailsWith<ScriptError> {