Extend lambda inlining to getOrPut and implicit it calls
This commit is contained in:
parent
0c3242cbd8
commit
f4ab2ebab4
@ -4981,24 +4981,23 @@ class BytecodeCompiler(
|
||||
|
||||
private fun compileInlineHigherOrderMethodCall(ref: MethodCallRef): CompiledValue? {
|
||||
val spec = inlineHigherOrderMethodSpec(ref.name) ?: return null
|
||||
if (ref.args.size != 1 || ref.args.any { it.isSplat || it.name != null }) return null
|
||||
if (ref.args.size != spec.argCount || ref.args.any { it.isSplat || it.name != null }) return null
|
||||
if (!ref.explicitTypeArgs.isNullOrEmpty()) return null
|
||||
val lambdaRef = extractExactLambdaRef(ref.args.first().value) ?: return null
|
||||
val lambdaRef = extractExactLambdaRef(ref.args[spec.lambdaArgIndex].value) ?: return null
|
||||
val inlineRef = lambdaRef.inlineBodyRef ?: return null
|
||||
return when (spec.kind) {
|
||||
InlineHigherOrderMethodKind.UNARY_ARGUMENT -> {
|
||||
if (!isMethodInlineSafe(lambdaRef, inlineRef, allowReceiverRefs = false, allowCaptures = true)) return null
|
||||
val paramName = lambdaRef.inlineParamNames()?.singleOrNull() ?: return null
|
||||
val receiver = compileRefWithFallback(ref.receiver, null, refPosOrCurrent(ref.receiver)) ?: return null
|
||||
val receiverObj = ensureObjSlot(receiver)
|
||||
compileOptionalInlineMethod(ref.isOptional, receiverObj) {
|
||||
val receiverSlot = materializeInlineBinding(receiver)
|
||||
val bindings = prepareInlineLambdaBindingsFromValues(lambdaRef, listOf(receiver)) ?: return@compileOptionalInlineMethod null
|
||||
when (spec.result) {
|
||||
InlineHigherOrderResultMode.BLOCK_RESULT ->
|
||||
compileInlineLambdaBody(lambdaRef, inlineRef, listOf(paramName to receiverSlot))
|
||||
compileInlineLambdaBody(lambdaRef, inlineRef, bindings)
|
||||
InlineHigherOrderResultMode.RETURN_RECEIVER -> {
|
||||
compileInlineLambdaBody(lambdaRef, inlineRef, listOf(paramName to receiverSlot)) ?: return@compileOptionalInlineMethod null
|
||||
CompiledValue(receiverSlot, receiver.type)
|
||||
compileInlineLambdaBody(lambdaRef, inlineRef, bindings) ?: return@compileOptionalInlineMethod null
|
||||
CompiledValue(receiverObj.slot, SlotType.OBJ)
|
||||
}
|
||||
else -> null
|
||||
}
|
||||
@ -5015,11 +5014,22 @@ class BytecodeCompiler(
|
||||
}
|
||||
InlineHigherOrderMethodKind.ITERABLE -> {
|
||||
if (!isMethodInlineSafe(lambdaRef, inlineRef, allowReceiverRefs = false, allowCaptures = true)) return null
|
||||
val paramName = lambdaRef.inlineParamNames()?.singleOrNull() ?: return null
|
||||
val receiver = compileRefWithFallback(ref.receiver, null, refPosOrCurrent(ref.receiver)) ?: return null
|
||||
val receiverObj = ensureObjSlot(receiver)
|
||||
compileOptionalInlineMethod(ref.isOptional, receiverObj) {
|
||||
compileInlineIterableLambdaLoop(receiverObj, ref, lambdaRef, inlineRef, paramName, spec.result)
|
||||
compileInlineIterableLambdaLoop(receiverObj, ref, lambdaRef, inlineRef, spec.result)
|
||||
}
|
||||
}
|
||||
InlineHigherOrderMethodKind.MAP_GET_OR_PUT -> {
|
||||
val receiverClass = resolveReceiverClass(ref.receiver) ?: return null
|
||||
if (receiverClass != ObjMap.type) return null
|
||||
if (!isMethodInlineSafe(lambdaRef, inlineRef, allowReceiverRefs = false, allowCaptures = true)) return null
|
||||
val receiver = compileRefWithFallback(ref.receiver, null, refPosOrCurrent(ref.receiver)) ?: return null
|
||||
val receiverObj = ensureObjSlot(receiver)
|
||||
val key = compileArgValue(ref.args[0].value) ?: return null
|
||||
val keyObj = ensureObjSlot(key)
|
||||
compileOptionalInlineMethod(ref.isOptional, receiverObj) {
|
||||
compileInlineMapGetOrPut(receiverObj, keyObj, lambdaRef, inlineRef)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -5050,26 +5060,58 @@ class BytecodeCompiler(
|
||||
if (ref.tailBlock) return null
|
||||
if (!ref.explicitTypeArgs.isNullOrEmpty()) return null
|
||||
val inlineRef = lambdaRef.inlineBodyRef ?: return null
|
||||
val bindings = prepareInlineLambdaBindings(lambdaRef, ref.args) ?: return null
|
||||
if (lambdaRef.argsDeclaration == null && ref.args.size != 1 &&
|
||||
!isMethodInlineSafe(lambdaRef, inlineRef, allowReceiverRefs = false, allowCaptures = false)
|
||||
) {
|
||||
return null
|
||||
}
|
||||
val bindings = prepareInlineLambdaBindingsFromArgs(lambdaRef, ref.args) ?: return null
|
||||
return compileInlineLambdaBody(lambdaRef, inlineRef, bindings)
|
||||
}
|
||||
|
||||
private fun prepareInlineLambdaBindings(
|
||||
private fun prepareInlineLambdaBindingsFromArgs(
|
||||
lambdaRef: LambdaFnRef,
|
||||
args: List<ParsedArgument>
|
||||
): List<Pair<String, Int>>? {
|
||||
if (args.any { it.isSplat || it.name != null }) return null
|
||||
val paramNames = lambdaRef.inlineParamNames() ?: return null
|
||||
if (args.size != paramNames.size) return null
|
||||
val compiledArgs = ArrayList<CompiledValue>(args.size)
|
||||
for (arg in args) {
|
||||
compiledArgs += compileArgValue(arg.value) ?: return null
|
||||
}
|
||||
return prepareInlineLambdaBindingsFromValues(lambdaRef, compiledArgs)
|
||||
}
|
||||
|
||||
private fun prepareInlineLambdaBindingsFromValues(
|
||||
lambdaRef: LambdaFnRef,
|
||||
args: List<CompiledValue>
|
||||
): List<Pair<String, Int>>? {
|
||||
val declaration = lambdaRef.argsDeclaration
|
||||
if (declaration == null) {
|
||||
val implicitValue = when (args.size) {
|
||||
0 -> CompiledValue(ensureVoidSlot(), SlotType.OBJ)
|
||||
1 -> args[0]
|
||||
else -> buildInlineImplicitItList(args) ?: return null
|
||||
}
|
||||
return listOf("it" to materializeInlineBinding(implicitValue))
|
||||
}
|
||||
if (declaration.params.any { it.isEllipsis || it.defaultValue != null }) return null
|
||||
if (args.size != declaration.params.size) return null
|
||||
if (args.isEmpty()) return emptyList()
|
||||
val bindings = ArrayList<Pair<String, Int>>(args.size)
|
||||
for ((index, arg) in args.withIndex()) {
|
||||
val compiled = compileArgValue(arg.value) ?: return null
|
||||
bindings += paramNames[index] to materializeInlineBinding(compiled)
|
||||
for ((index, param) in declaration.params.withIndex()) {
|
||||
bindings += param.name to materializeInlineBinding(args[index])
|
||||
}
|
||||
return bindings
|
||||
}
|
||||
|
||||
private fun buildInlineImplicitItList(args: List<CompiledValue>): CompiledValue? {
|
||||
val list = createEmptyMutableList() ?: return null
|
||||
for (arg in args) {
|
||||
appendToList(list, arg) ?: return null
|
||||
}
|
||||
return list
|
||||
}
|
||||
|
||||
private fun LambdaFnRef.inlineParamNames(): List<String>? {
|
||||
val declaration = argsDeclaration
|
||||
if (declaration == null) {
|
||||
@ -5092,7 +5134,8 @@ class BytecodeCompiler(
|
||||
private enum class InlineHigherOrderMethodKind {
|
||||
UNARY_ARGUMENT,
|
||||
RECEIVER,
|
||||
ITERABLE
|
||||
ITERABLE,
|
||||
MAP_GET_OR_PUT
|
||||
}
|
||||
|
||||
private enum class InlineHigherOrderResultMode {
|
||||
@ -5107,7 +5150,9 @@ class BytecodeCompiler(
|
||||
|
||||
private data class InlineHigherOrderMethodSpec(
|
||||
val kind: InlineHigherOrderMethodKind,
|
||||
val result: InlineHigherOrderResultMode
|
||||
val result: InlineHigherOrderResultMode,
|
||||
val argCount: Int = 1,
|
||||
val lambdaArgIndex: Int = 0
|
||||
)
|
||||
|
||||
private data class InlineReceiverInfo(
|
||||
@ -5153,6 +5198,12 @@ class BytecodeCompiler(
|
||||
InlineHigherOrderMethodKind.ITERABLE,
|
||||
InlineHigherOrderResultMode.ASSOCIATE_BY
|
||||
)
|
||||
"getOrPut" -> InlineHigherOrderMethodSpec(
|
||||
InlineHigherOrderMethodKind.MAP_GET_OR_PUT,
|
||||
InlineHigherOrderResultMode.BLOCK_RESULT,
|
||||
argCount = 2,
|
||||
lambdaArgIndex = 1
|
||||
)
|
||||
else -> null
|
||||
}
|
||||
}
|
||||
@ -5321,7 +5372,6 @@ class BytecodeCompiler(
|
||||
ref: MethodCallRef,
|
||||
lambdaRef: LambdaFnRef,
|
||||
inlineRef: ObjRef,
|
||||
paramName: String,
|
||||
behavior: InlineHigherOrderResultMode
|
||||
): CompiledValue? {
|
||||
val iterableMethods = ObjIterable.instanceMethodIdMap(includeAbstract = true)
|
||||
@ -5366,16 +5416,17 @@ class BytecodeCompiler(
|
||||
val nextSlot = allocSlot()
|
||||
builder.emit(Opcode.CALL_MEMBER_SLOT, iterSlot, nextMethodId, 0, 0, nextSlot)
|
||||
val nextObj = ensureObjSlot(CompiledValue(nextSlot, SlotType.UNKNOWN))
|
||||
val bindings = prepareInlineLambdaBindingsFromValues(lambdaRef, listOf(nextObj)) ?: return null
|
||||
when (behavior) {
|
||||
InlineHigherOrderResultMode.FOR_EACH -> {
|
||||
compileInlineLambdaBody(lambdaRef, inlineRef, listOf(paramName to nextObj.slot)) ?: return null
|
||||
compileInlineLambdaBody(lambdaRef, inlineRef, bindings) ?: return null
|
||||
}
|
||||
InlineHigherOrderResultMode.MAP -> {
|
||||
val mapped = compileInlineLambdaBody(lambdaRef, inlineRef, listOf(paramName to nextObj.slot)) ?: return null
|
||||
val mapped = compileInlineLambdaBody(lambdaRef, inlineRef, bindings) ?: return null
|
||||
appendToList(result, mapped) ?: return null
|
||||
}
|
||||
InlineHigherOrderResultMode.FILTER -> {
|
||||
val predicate = compileInlineLambdaBody(lambdaRef, inlineRef, listOf(paramName to nextObj.slot)) ?: return null
|
||||
val predicate = compileInlineLambdaBody(lambdaRef, inlineRef, bindings) ?: return null
|
||||
val predicateBool = compileValueAsBool(predicate)
|
||||
val skipLabel = builder.label()
|
||||
builder.emit(
|
||||
@ -5386,7 +5437,7 @@ class BytecodeCompiler(
|
||||
builder.mark(skipLabel)
|
||||
}
|
||||
InlineHigherOrderResultMode.MAP_NOT_NULL -> {
|
||||
val mapped = compileInlineLambdaBody(lambdaRef, inlineRef, listOf(paramName to nextObj.slot)) ?: return null
|
||||
val mapped = compileInlineLambdaBody(lambdaRef, inlineRef, bindings) ?: return null
|
||||
val mappedObj = ensureObjSlot(mapped)
|
||||
val nullSlot = allocSlot()
|
||||
builder.emit(Opcode.CONST_NULL, nullSlot)
|
||||
@ -5401,7 +5452,7 @@ class BytecodeCompiler(
|
||||
builder.mark(skipLabel)
|
||||
}
|
||||
InlineHigherOrderResultMode.ASSOCIATE_BY -> {
|
||||
val key = compileInlineLambdaBody(lambdaRef, inlineRef, listOf(paramName to nextObj.slot)) ?: return null
|
||||
val key = compileInlineLambdaBody(lambdaRef, inlineRef, bindings) ?: return null
|
||||
appendToMap(result, key, nextObj)
|
||||
}
|
||||
}
|
||||
@ -5433,6 +5484,34 @@ class BytecodeCompiler(
|
||||
builder.emit(Opcode.SET_INDEX, mapObj.slot, keyObj.slot, itemObj.slot)
|
||||
}
|
||||
|
||||
private fun compileInlineMapGetOrPut(
|
||||
receiverObj: CompiledValue,
|
||||
keyObj: CompiledValue,
|
||||
lambdaRef: LambdaFnRef,
|
||||
inlineRef: ObjRef
|
||||
): CompiledValue? {
|
||||
val dst = allocSlot()
|
||||
val hasKey = allocSlot()
|
||||
builder.emit(Opcode.CONTAINS_OBJ, receiverObj.slot, keyObj.slot, hasKey)
|
||||
val existingLabel = builder.label()
|
||||
val endLabel = builder.label()
|
||||
builder.emit(
|
||||
Opcode.JMP_IF_TRUE,
|
||||
listOf(CmdBuilder.Operand.IntVal(hasKey), CmdBuilder.Operand.LabelRef(existingLabel))
|
||||
)
|
||||
val bindings = prepareInlineLambdaBindingsFromValues(lambdaRef, emptyList()) ?: return null
|
||||
val computed = compileInlineLambdaBody(lambdaRef, inlineRef, bindings) ?: return null
|
||||
val computedObj = ensureObjSlot(computed)
|
||||
builder.emit(Opcode.SET_INDEX, receiverObj.slot, keyObj.slot, computedObj.slot)
|
||||
builder.emit(Opcode.MOVE_OBJ, computedObj.slot, dst)
|
||||
builder.emit(Opcode.JMP, listOf(CmdBuilder.Operand.LabelRef(endLabel)))
|
||||
builder.mark(existingLabel)
|
||||
builder.emit(Opcode.GET_INDEX, receiverObj.slot, keyObj.slot, dst)
|
||||
builder.mark(endLabel)
|
||||
updateSlotType(dst, SlotType.OBJ)
|
||||
return CompiledValue(dst, SlotType.OBJ)
|
||||
}
|
||||
|
||||
private fun compileValueAsBool(value: CompiledValue): CompiledValue {
|
||||
if (value.type == SlotType.BOOL) return value
|
||||
val dst = allocSlot()
|
||||
@ -5779,7 +5858,7 @@ class BytecodeCompiler(
|
||||
return when (entry.ownerKind) {
|
||||
CaptureOwnerFrameKind.LOCAL -> {
|
||||
localSlotIndexByKey[key]?.let { return scopeSlotCount + it }
|
||||
null
|
||||
scopeSlotMap[key]
|
||||
}
|
||||
CaptureOwnerFrameKind.MODULE -> {
|
||||
localSlotIndexByKey[key]?.let { return scopeSlotCount + it }
|
||||
|
||||
@ -255,6 +255,54 @@ class CompilerVmReviewRegressionTest {
|
||||
assertEquals(36, result.list[6].toInt())
|
||||
}
|
||||
|
||||
@Test
|
||||
fun directLambdaInliningMatchesImplicitItInvocationSemantics() = runTest {
|
||||
val script = Compiler.compile(
|
||||
Source(
|
||||
"<direct-inline-it-semantics>",
|
||||
"""
|
||||
val zeroFn = { if (it == void) 1 else 0 }
|
||||
val multiFn = { it }
|
||||
val zero = zeroFn()
|
||||
val multi = multiFn(1, 2, 3)
|
||||
[zero, multi]
|
||||
""".trimIndent()
|
||||
),
|
||||
Script.defaultImportManager
|
||||
)
|
||||
|
||||
val scope = Script.newScope()
|
||||
val result = script.execute(scope) as ObjList
|
||||
|
||||
assertEquals(1, result.list[0].toInt())
|
||||
val multi = result.list[1] as ObjList
|
||||
assertEquals(listOf(1, 2, 3), multi.list.map { it.toInt() })
|
||||
}
|
||||
|
||||
@Test
|
||||
fun mapGetOrPutUsesInlineDefaultLambda() = runTest {
|
||||
val script = Compiler.compile(
|
||||
Source(
|
||||
"<map-get-or-put-inline>",
|
||||
"""
|
||||
val offset = 10
|
||||
val m = Map()
|
||||
val first = m.getOrPut("k") { offset + 1 }
|
||||
val second = m.getOrPut("k") { offset + 2 }
|
||||
[first, second, m["k"]]
|
||||
""".trimIndent()
|
||||
),
|
||||
Script.defaultImportManager
|
||||
)
|
||||
|
||||
val scope = Script.newScope()
|
||||
val result = script.execute(scope) as ObjList
|
||||
|
||||
assertEquals(11, result.list[0].toInt())
|
||||
assertEquals(11, result.list[1].toInt())
|
||||
assertEquals(11, result.list[2].toInt())
|
||||
}
|
||||
|
||||
@Test
|
||||
fun subjectlessWhenReportsScriptError() = runTest {
|
||||
val ex = assertFailsWith<ScriptError> {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user