Generalize higher-order lambda inlining

This commit is contained in:
Sergey Chernov 2026-04-21 19:12:16 +03:00
parent 1d5caaa836
commit 0c3242cbd8
2 changed files with 179 additions and 164 deletions

View File

@ -4800,9 +4800,7 @@ class BytecodeCompiler(
private fun compileMethodCall(ref: MethodCallRef): CompiledValue? {
compileListFillIntCall(ref)?.let { return it }
compileInlineUnaryLambdaMethodCall(ref)?.let { return it }
compileInlineReceiverLambdaMethodCall(ref)?.let { return it }
compileInlineIterableLambdaMethodCall(ref)?.let { return it }
compileInlineHigherOrderMethodCall(ref)?.let { return it }
val callPos = callSitePos()
val receiverClass = resolveReceiverClass(ref.receiver) ?: ObjDynamic.type
val receiver = compileRefWithFallback(ref.receiver, null, refPosOrCurrent(ref.receiver)) ?: return null
@ -4981,145 +4979,49 @@ class BytecodeCompiler(
return CompiledValue(dst, SlotType.OBJ)
}
private fun compileInlineUnaryLambdaMethodCall(ref: MethodCallRef): CompiledValue? {
val behavior = when (ref.name) {
"let" -> InlineUnaryLambdaMethodBehavior.RETURN_BLOCK_RESULT
"also" -> InlineUnaryLambdaMethodBehavior.RETURN_RECEIVER
else -> return null
}
private fun compileInlineHigherOrderMethodCall(ref: MethodCallRef): CompiledValue? {
val spec = inlineHigherOrderMethodSpec(ref.name) ?: return null
if (ref.args.size != 1 || ref.args.any { it.isSplat || it.name != null }) return null
if (!ref.explicitTypeArgs.isNullOrEmpty()) return null
val lambdaRef = extractExactLambdaRef(ref.args.first().value) ?: return null
val inlineRef = lambdaRef.inlineBodyRef ?: return null
return when (spec.kind) {
InlineHigherOrderMethodKind.UNARY_ARGUMENT -> {
if (!isMethodInlineSafe(lambdaRef, inlineRef, allowReceiverRefs = false, allowCaptures = true)) return null
val paramName = lambdaRef.inlineParamNames()?.singleOrNull() ?: return null
val receiver = compileRefWithFallback(ref.receiver, null, refPosOrCurrent(ref.receiver)) ?: return null
return if (!ref.isOptional) {
val receiverObj = ensureObjSlot(receiver)
compileOptionalInlineMethod(ref.isOptional, receiverObj) {
val receiverSlot = materializeInlineBinding(receiver)
when (behavior) {
InlineUnaryLambdaMethodBehavior.RETURN_BLOCK_RESULT ->
when (spec.result) {
InlineHigherOrderResultMode.BLOCK_RESULT ->
compileInlineLambdaBody(lambdaRef, inlineRef, listOf(paramName to receiverSlot))
InlineUnaryLambdaMethodBehavior.RETURN_RECEIVER -> {
compileInlineLambdaBody(lambdaRef, inlineRef, listOf(paramName to receiverSlot)) ?: return null
InlineHigherOrderResultMode.RETURN_RECEIVER -> {
compileInlineLambdaBody(lambdaRef, inlineRef, listOf(paramName to receiverSlot)) ?: return@compileOptionalInlineMethod null
CompiledValue(receiverSlot, receiver.type)
}
}
} else {
val receiverObj = ensureObjSlot(receiver)
val dst = allocSlot()
val nullSlot = allocSlot()
builder.emit(Opcode.CONST_NULL, nullSlot)
val cmpSlot = allocSlot()
builder.emit(Opcode.CMP_REF_EQ_OBJ, receiverObj.slot, nullSlot, cmpSlot)
val nullLabel = builder.label()
val endLabel = builder.label()
builder.emit(
Opcode.JMP_IF_TRUE,
listOf(CmdBuilder.Operand.IntVal(cmpSlot), CmdBuilder.Operand.LabelRef(nullLabel))
)
val receiverSlot = materializeInlineBinding(receiver)
when (behavior) {
InlineUnaryLambdaMethodBehavior.RETURN_BLOCK_RESULT -> {
val inlineResult =
compileInlineLambdaBody(lambdaRef, inlineRef, listOf(paramName to receiverSlot)) ?: return null
val inlineObj = ensureObjSlot(inlineResult)
builder.emit(Opcode.MOVE_OBJ, inlineObj.slot, dst)
}
InlineUnaryLambdaMethodBehavior.RETURN_RECEIVER -> {
compileInlineLambdaBody(lambdaRef, inlineRef, listOf(paramName to receiverSlot)) ?: return null
builder.emit(Opcode.MOVE_OBJ, receiverObj.slot, dst)
else -> null
}
}
builder.emit(Opcode.JMP, listOf(CmdBuilder.Operand.LabelRef(endLabel)))
builder.mark(nullLabel)
builder.emit(Opcode.CONST_NULL, dst)
builder.mark(endLabel)
updateSlotType(dst, SlotType.OBJ)
CompiledValue(dst, SlotType.OBJ)
}
}
private fun compileInlineReceiverLambdaMethodCall(ref: MethodCallRef): CompiledValue? {
val behavior = when (ref.name) {
"apply" -> InlineReceiverLambdaMethodBehavior.RETURN_RECEIVER
"run" -> InlineReceiverLambdaMethodBehavior.RETURN_BLOCK_RESULT
else -> return null
}
if (ref.args.size != 1 || ref.args.any { it.isSplat || it.name != null }) return null
if (!ref.explicitTypeArgs.isNullOrEmpty()) return null
val lambdaRef = extractExactLambdaRef(ref.args.first().value) ?: return null
InlineHigherOrderMethodKind.RECEIVER -> {
val receiverInfo = receiverInlineInfo(lambdaRef) ?: return null
val inlineRef = lambdaRef.inlineBodyRef ?: return null
if (!isMethodInlineSafe(lambdaRef, inlineRef, allowReceiverRefs = true, allowCaptures = true)) return null
val receiver = compileRefWithFallback(ref.receiver, null, refPosOrCurrent(ref.receiver)) ?: return null
val receiverObj = ensureObjSlot(receiver)
return if (!ref.isOptional) {
compileInlineReceiverLambdaInvocation(receiverObj, lambdaRef, behavior, receiverInfo)
} else {
val dst = allocSlot()
val nullSlot = allocSlot()
builder.emit(Opcode.CONST_NULL, nullSlot)
val cmpSlot = allocSlot()
builder.emit(Opcode.CMP_REF_EQ_OBJ, receiverObj.slot, nullSlot, cmpSlot)
val nullLabel = builder.label()
val endLabel = builder.label()
builder.emit(
Opcode.JMP_IF_TRUE,
listOf(CmdBuilder.Operand.IntVal(cmpSlot), CmdBuilder.Operand.LabelRef(nullLabel))
)
val nonNullResult =
compileInlineReceiverLambdaInvocation(receiverObj, lambdaRef, behavior, receiverInfo) ?: return null
val nonNullObj = ensureObjSlot(nonNullResult)
builder.emit(Opcode.MOVE_OBJ, nonNullObj.slot, dst)
builder.emit(Opcode.JMP, listOf(CmdBuilder.Operand.LabelRef(endLabel)))
builder.mark(nullLabel)
builder.emit(Opcode.CONST_NULL, dst)
builder.mark(endLabel)
updateSlotType(dst, SlotType.OBJ)
CompiledValue(dst, SlotType.OBJ)
compileOptionalInlineMethod(ref.isOptional, receiverObj) {
compileInlineReceiverLambdaInvocation(receiverObj, lambdaRef, spec.result, receiverInfo)
}
}
private fun compileInlineIterableLambdaMethodCall(ref: MethodCallRef): CompiledValue? {
val behavior = when (ref.name) {
"forEach" -> InlineIterableLambdaMethodBehavior.FOR_EACH
"map" -> InlineIterableLambdaMethodBehavior.MAP
"filter" -> InlineIterableLambdaMethodBehavior.FILTER
else -> return null
}
if (ref.args.size != 1 || ref.args.any { it.isSplat || it.name != null }) return null
if (!ref.explicitTypeArgs.isNullOrEmpty()) return null
val lambdaRef = extractExactLambdaRef(ref.args.first().value) ?: return null
val inlineRef = lambdaRef.inlineBodyRef ?: return null
InlineHigherOrderMethodKind.ITERABLE -> {
if (!isMethodInlineSafe(lambdaRef, inlineRef, allowReceiverRefs = false, allowCaptures = true)) return null
val paramNames = lambdaRef.inlineParamNames() ?: return null
if (paramNames.size != 1) return null
val paramName = lambdaRef.inlineParamNames()?.singleOrNull() ?: return null
val receiver = compileRefWithFallback(ref.receiver, null, refPosOrCurrent(ref.receiver)) ?: return null
val receiverObj = ensureObjSlot(receiver)
return if (!ref.isOptional) {
compileInlineIterableLambdaLoop(receiverObj, ref, lambdaRef, inlineRef, paramNames[0], behavior)
} else {
val dst = allocSlot()
val nullSlot = allocSlot()
builder.emit(Opcode.CONST_NULL, nullSlot)
val cmpSlot = allocSlot()
builder.emit(Opcode.CMP_REF_EQ_OBJ, receiverObj.slot, nullSlot, cmpSlot)
val nullLabel = builder.label()
val endLabel = builder.label()
builder.emit(
Opcode.JMP_IF_TRUE,
listOf(CmdBuilder.Operand.IntVal(cmpSlot), CmdBuilder.Operand.LabelRef(nullLabel))
)
val nonNullResult =
compileInlineIterableLambdaLoop(receiverObj, ref, lambdaRef, inlineRef, paramNames[0], behavior) ?: return null
val nonNullObj = ensureObjSlot(nonNullResult)
builder.emit(Opcode.MOVE_OBJ, nonNullObj.slot, dst)
builder.emit(Opcode.JMP, listOf(CmdBuilder.Operand.LabelRef(endLabel)))
builder.mark(nullLabel)
builder.emit(Opcode.CONST_NULL, dst)
builder.mark(endLabel)
updateSlotType(dst, SlotType.OBJ)
CompiledValue(dst, SlotType.OBJ)
compileOptionalInlineMethod(ref.isOptional, receiverObj) {
compileInlineIterableLambdaLoop(receiverObj, ref, lambdaRef, inlineRef, paramName, spec.result)
}
}
}
}
@ -5187,35 +5089,100 @@ class BytecodeCompiler(
return slot
}
private enum class InlineUnaryLambdaMethodBehavior {
RETURN_BLOCK_RESULT,
RETURN_RECEIVER
private enum class InlineHigherOrderMethodKind {
UNARY_ARGUMENT,
RECEIVER,
ITERABLE
}
private enum class InlineReceiverLambdaMethodBehavior {
RETURN_BLOCK_RESULT,
RETURN_RECEIVER
}
private enum class InlineIterableLambdaMethodBehavior {
private enum class InlineHigherOrderResultMode {
BLOCK_RESULT,
RETURN_RECEIVER,
FOR_EACH,
MAP,
FILTER
FILTER,
MAP_NOT_NULL,
ASSOCIATE_BY
}
private data class InlineHigherOrderMethodSpec(
val kind: InlineHigherOrderMethodKind,
val result: InlineHigherOrderResultMode
)
private data class InlineReceiverInfo(
val explicitBindings: List<Pair<String, Int>>,
val thisTypeName: String?
)
private fun hasModuleCapture(lambdaRef: LambdaFnRef): Boolean {
val captures = (lambdaCaptureEntriesByRef[lambdaRef] ?: lambdaRef.captureEntries).orEmpty()
return captures.any { it.ownerKind == CaptureOwnerFrameKind.MODULE }
private fun inlineHigherOrderMethodSpec(name: String): InlineHigherOrderMethodSpec? {
return when (name) {
"let" -> InlineHigherOrderMethodSpec(
InlineHigherOrderMethodKind.UNARY_ARGUMENT,
InlineHigherOrderResultMode.BLOCK_RESULT
)
"also" -> InlineHigherOrderMethodSpec(
InlineHigherOrderMethodKind.UNARY_ARGUMENT,
InlineHigherOrderResultMode.RETURN_RECEIVER
)
"apply" -> InlineHigherOrderMethodSpec(
InlineHigherOrderMethodKind.RECEIVER,
InlineHigherOrderResultMode.RETURN_RECEIVER
)
"run" -> InlineHigherOrderMethodSpec(
InlineHigherOrderMethodKind.RECEIVER,
InlineHigherOrderResultMode.BLOCK_RESULT
)
"forEach" -> InlineHigherOrderMethodSpec(
InlineHigherOrderMethodKind.ITERABLE,
InlineHigherOrderResultMode.FOR_EACH
)
"map" -> InlineHigherOrderMethodSpec(
InlineHigherOrderMethodKind.ITERABLE,
InlineHigherOrderResultMode.MAP
)
"filter" -> InlineHigherOrderMethodSpec(
InlineHigherOrderMethodKind.ITERABLE,
InlineHigherOrderResultMode.FILTER
)
"mapNotNull" -> InlineHigherOrderMethodSpec(
InlineHigherOrderMethodKind.ITERABLE,
InlineHigherOrderResultMode.MAP_NOT_NULL
)
"associateBy" -> InlineHigherOrderMethodSpec(
InlineHigherOrderMethodKind.ITERABLE,
InlineHigherOrderResultMode.ASSOCIATE_BY
)
else -> null
}
}
private fun hasAnyCapture(lambdaRef: LambdaFnRef): Boolean {
val captures = (lambdaCaptureEntriesByRef[lambdaRef] ?: lambdaRef.captureEntries).orEmpty()
return captures.isNotEmpty()
private fun compileOptionalInlineMethod(
isOptional: Boolean,
receiverObj: CompiledValue,
compileNonNull: () -> CompiledValue?
): CompiledValue? {
if (!isOptional) return compileNonNull()
val dst = allocSlot()
val nullSlot = allocSlot()
builder.emit(Opcode.CONST_NULL, nullSlot)
val cmpSlot = allocSlot()
builder.emit(Opcode.CMP_REF_EQ_OBJ, receiverObj.slot, nullSlot, cmpSlot)
val nullLabel = builder.label()
val endLabel = builder.label()
builder.emit(
Opcode.JMP_IF_TRUE,
listOf(CmdBuilder.Operand.IntVal(cmpSlot), CmdBuilder.Operand.LabelRef(nullLabel))
)
val nonNullResult = compileNonNull() ?: return null
val nonNullObj = ensureObjSlot(nonNullResult)
builder.emit(Opcode.MOVE_OBJ, nonNullObj.slot, dst)
builder.emit(Opcode.JMP, listOf(CmdBuilder.Operand.LabelRef(endLabel)))
builder.mark(nullLabel)
builder.emit(Opcode.CONST_NULL, dst)
builder.mark(endLabel)
updateSlotType(dst, SlotType.OBJ)
return CompiledValue(dst, SlotType.OBJ)
}
private fun isMethodInlineSafe(
@ -5311,7 +5278,7 @@ class BytecodeCompiler(
private fun compileInlineReceiverLambdaInvocation(
receiverObj: CompiledValue,
lambdaRef: LambdaFnRef,
behavior: InlineReceiverLambdaMethodBehavior,
behavior: InlineHigherOrderResultMode,
receiverInfo: InlineReceiverInfo
): CompiledValue? {
val inlineRef = lambdaRef.inlineBodyRef ?: return null
@ -5320,12 +5287,13 @@ class BytecodeCompiler(
inlineThisBindings.addLast(previousBinding)
return try {
when (behavior) {
InlineReceiverLambdaMethodBehavior.RETURN_BLOCK_RESULT ->
InlineHigherOrderResultMode.BLOCK_RESULT ->
compileInlineLambdaBody(lambdaRef, inlineRef, receiverInfo.explicitBindings)
InlineReceiverLambdaMethodBehavior.RETURN_RECEIVER -> {
InlineHigherOrderResultMode.RETURN_RECEIVER -> {
compileInlineLambdaBody(lambdaRef, inlineRef, receiverInfo.explicitBindings) ?: return null
CompiledValue(receiverSlot, SlotType.OBJ)
}
else -> null
}
} finally {
inlineThisBindings.removeLast()
@ -5340,13 +5308,21 @@ class BytecodeCompiler(
return CompiledValue(dst, SlotType.OBJ)
}
private fun createEmptyMutableMap(): CompiledValue? {
val dst = allocSlot()
emitCallDirect(ObjMap.type, 0, 0, dst)
updateSlotType(dst, SlotType.OBJ)
slotObjClass[dst] = ObjMap.type
return CompiledValue(dst, SlotType.OBJ)
}
private fun compileInlineIterableLambdaLoop(
receiverObj: CompiledValue,
ref: MethodCallRef,
lambdaRef: LambdaFnRef,
inlineRef: ObjRef,
paramName: String,
behavior: InlineIterableLambdaMethodBehavior
behavior: InlineHigherOrderResultMode
): CompiledValue? {
val iterableMethods = ObjIterable.instanceMethodIdMap(includeAbstract = true)
val iteratorMethodId = iterableMethods["iterator"]
@ -5362,14 +5338,17 @@ class BytecodeCompiler(
builder.emit(Opcode.ITER_PUSH, iterSlot)
val result = when (behavior) {
InlineIterableLambdaMethodBehavior.FOR_EACH -> CompiledValue(ensureVoidSlot(), SlotType.OBJ)
InlineIterableLambdaMethodBehavior.MAP,
InlineIterableLambdaMethodBehavior.FILTER -> createEmptyMutableList() ?: return null
InlineHigherOrderResultMode.FOR_EACH -> CompiledValue(ensureVoidSlot(), SlotType.OBJ)
InlineHigherOrderResultMode.MAP,
InlineHigherOrderResultMode.FILTER,
InlineHigherOrderResultMode.MAP_NOT_NULL -> createEmptyMutableList() ?: return null
InlineHigherOrderResultMode.ASSOCIATE_BY -> createEmptyMutableMap() ?: return null
else -> return null
}
if (behavior == InlineIterableLambdaMethodBehavior.FILTER) {
if (behavior == InlineHigherOrderResultMode.FILTER) {
listElementClassFromReceiverRef(ref.receiver)?.let { listElementClassBySlot[result.slot] = it }
}
if (behavior == InlineIterableLambdaMethodBehavior.MAP) {
if (behavior == InlineHigherOrderResultMode.MAP) {
lambdaRef.inferredReturnClass?.let { listElementClassBySlot[result.slot] = it }
}
@ -5388,14 +5367,14 @@ class BytecodeCompiler(
builder.emit(Opcode.CALL_MEMBER_SLOT, iterSlot, nextMethodId, 0, 0, nextSlot)
val nextObj = ensureObjSlot(CompiledValue(nextSlot, SlotType.UNKNOWN))
when (behavior) {
InlineIterableLambdaMethodBehavior.FOR_EACH -> {
InlineHigherOrderResultMode.FOR_EACH -> {
compileInlineLambdaBody(lambdaRef, inlineRef, listOf(paramName to nextObj.slot)) ?: return null
}
InlineIterableLambdaMethodBehavior.MAP -> {
InlineHigherOrderResultMode.MAP -> {
val mapped = compileInlineLambdaBody(lambdaRef, inlineRef, listOf(paramName to nextObj.slot)) ?: return null
appendToList(result, mapped) ?: return null
}
InlineIterableLambdaMethodBehavior.FILTER -> {
InlineHigherOrderResultMode.FILTER -> {
val predicate = compileInlineLambdaBody(lambdaRef, inlineRef, listOf(paramName to nextObj.slot)) ?: return null
val predicateBool = compileValueAsBool(predicate)
val skipLabel = builder.label()
@ -5406,6 +5385,25 @@ class BytecodeCompiler(
appendToList(result, nextObj) ?: return null
builder.mark(skipLabel)
}
InlineHigherOrderResultMode.MAP_NOT_NULL -> {
val mapped = compileInlineLambdaBody(lambdaRef, inlineRef, listOf(paramName to nextObj.slot)) ?: return null
val mappedObj = ensureObjSlot(mapped)
val nullSlot = allocSlot()
builder.emit(Opcode.CONST_NULL, nullSlot)
val cmpSlot = allocSlot()
builder.emit(Opcode.CMP_REF_EQ_OBJ, mappedObj.slot, nullSlot, cmpSlot)
val skipLabel = builder.label()
builder.emit(
Opcode.JMP_IF_TRUE,
listOf(CmdBuilder.Operand.IntVal(cmpSlot), CmdBuilder.Operand.LabelRef(skipLabel))
)
appendToList(result, mappedObj) ?: return null
builder.mark(skipLabel)
}
InlineHigherOrderResultMode.ASSOCIATE_BY -> {
val key = compileInlineLambdaBody(lambdaRef, inlineRef, listOf(paramName to nextObj.slot)) ?: return null
appendToMap(result, key, nextObj)
}
}
builder.emit(Opcode.JMP, listOf(CmdBuilder.Operand.LabelRef(loopLabel)))
builder.mark(endLabel)
@ -5428,6 +5426,13 @@ class BytecodeCompiler(
return CompiledValue(dst, SlotType.OBJ)
}
private fun appendToMap(mapValue: CompiledValue, keyValue: CompiledValue, itemValue: CompiledValue) {
val mapObj = ensureObjSlot(mapValue)
val keyObj = ensureObjSlot(keyValue)
val itemObj = ensureObjSlot(itemValue)
builder.emit(Opcode.SET_INDEX, mapObj.slot, keyObj.slot, itemObj.slot)
}
private fun compileValueAsBool(value: CompiledValue): CompiledValue {
if (value.type == SlotType.BOOL) return value
val dst = allocSlot()

View File

@ -221,9 +221,11 @@ class CompilerVmReviewRegressionTest {
val applyResult = List<Int>().apply { add(offset); add(offset + 1) }
val mapped = [1, 2, 3].map { it + offset }
val filtered = [1, 2, 3].filter { it + offset >= 12 }
val notNull = [1, 2, 3].mapNotNull { if (it + offset >= 12) it + offset else null }
val associated = [1, 2, 3].associateBy { "k" + (it + offset) }
[1, 2, 3].forEach { sum += it + offset }
[letResult, applyResult, mapped, filtered, sum]
[letResult, applyResult, mapped, filtered, notNull, associated, sum]
""".trimIndent()
),
Script.defaultImportManager
@ -242,7 +244,15 @@ class CompilerVmReviewRegressionTest {
val filtered = result.list[3] as ObjList
assertEquals(listOf(2, 3), filtered.list.map { it.toInt() })
assertEquals(36, result.list[4].toInt())
val notNull = result.list[4] as ObjList
assertEquals(listOf(12, 13), notNull.list.map { it.toInt() })
val associated = result.list[5].toString(scope).value
assertContains(associated, "\"k11\" => 1")
assertContains(associated, "\"k12\" => 2")
assertContains(associated, "\"k13\" => 3")
assertEquals(36, result.list[6].toInt())
}
@Test