diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeCompiler.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeCompiler.kt index 2f7b7de..b0b7e8e 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeCompiler.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeCompiler.kt @@ -4981,24 +4981,23 @@ class BytecodeCompiler( private fun compileInlineHigherOrderMethodCall(ref: MethodCallRef): CompiledValue? { val spec = inlineHigherOrderMethodSpec(ref.name) ?: return null - if (ref.args.size != 1 || ref.args.any { it.isSplat || it.name != null }) return null + if (ref.args.size != spec.argCount || ref.args.any { it.isSplat || it.name != null }) return null if (!ref.explicitTypeArgs.isNullOrEmpty()) return null - val lambdaRef = extractExactLambdaRef(ref.args.first().value) ?: return null + val lambdaRef = extractExactLambdaRef(ref.args[spec.lambdaArgIndex].value) ?: return null val inlineRef = lambdaRef.inlineBodyRef ?: return null return when (spec.kind) { InlineHigherOrderMethodKind.UNARY_ARGUMENT -> { if (!isMethodInlineSafe(lambdaRef, inlineRef, allowReceiverRefs = false, allowCaptures = true)) return null - val paramName = lambdaRef.inlineParamNames()?.singleOrNull() ?: return null val receiver = compileRefWithFallback(ref.receiver, null, refPosOrCurrent(ref.receiver)) ?: return null val receiverObj = ensureObjSlot(receiver) compileOptionalInlineMethod(ref.isOptional, receiverObj) { - val receiverSlot = materializeInlineBinding(receiver) + val bindings = prepareInlineLambdaBindingsFromValues(lambdaRef, listOf(receiver)) ?: return@compileOptionalInlineMethod null when (spec.result) { InlineHigherOrderResultMode.BLOCK_RESULT -> - compileInlineLambdaBody(lambdaRef, inlineRef, listOf(paramName to receiverSlot)) + compileInlineLambdaBody(lambdaRef, inlineRef, bindings) InlineHigherOrderResultMode.RETURN_RECEIVER -> { - compileInlineLambdaBody(lambdaRef, inlineRef, listOf(paramName to receiverSlot)) ?: return@compileOptionalInlineMethod null - CompiledValue(receiverSlot, receiver.type) + compileInlineLambdaBody(lambdaRef, inlineRef, bindings) ?: return@compileOptionalInlineMethod null + CompiledValue(receiverObj.slot, SlotType.OBJ) } else -> null } @@ -5015,11 +5014,22 @@ class BytecodeCompiler( } InlineHigherOrderMethodKind.ITERABLE -> { if (!isMethodInlineSafe(lambdaRef, inlineRef, allowReceiverRefs = false, allowCaptures = true)) return null - val paramName = lambdaRef.inlineParamNames()?.singleOrNull() ?: return null val receiver = compileRefWithFallback(ref.receiver, null, refPosOrCurrent(ref.receiver)) ?: return null val receiverObj = ensureObjSlot(receiver) compileOptionalInlineMethod(ref.isOptional, receiverObj) { - compileInlineIterableLambdaLoop(receiverObj, ref, lambdaRef, inlineRef, paramName, spec.result) + compileInlineIterableLambdaLoop(receiverObj, ref, lambdaRef, inlineRef, spec.result) + } + } + InlineHigherOrderMethodKind.MAP_GET_OR_PUT -> { + val receiverClass = resolveReceiverClass(ref.receiver) ?: return null + if (receiverClass != ObjMap.type) return null + if (!isMethodInlineSafe(lambdaRef, inlineRef, allowReceiverRefs = false, allowCaptures = true)) return null + val receiver = compileRefWithFallback(ref.receiver, null, refPosOrCurrent(ref.receiver)) ?: return null + val receiverObj = ensureObjSlot(receiver) + val key = compileArgValue(ref.args[0].value) ?: return null + val keyObj = ensureObjSlot(key) + compileOptionalInlineMethod(ref.isOptional, receiverObj) { + compileInlineMapGetOrPut(receiverObj, keyObj, lambdaRef, inlineRef) } } } @@ -5050,26 +5060,58 @@ class BytecodeCompiler( if (ref.tailBlock) return null if (!ref.explicitTypeArgs.isNullOrEmpty()) return null val inlineRef = lambdaRef.inlineBodyRef ?: return null - val bindings = prepareInlineLambdaBindings(lambdaRef, ref.args) ?: return null + if (lambdaRef.argsDeclaration == null && ref.args.size != 1 && + !isMethodInlineSafe(lambdaRef, inlineRef, allowReceiverRefs = false, allowCaptures = false) + ) { + return null + } + val bindings = prepareInlineLambdaBindingsFromArgs(lambdaRef, ref.args) ?: return null return compileInlineLambdaBody(lambdaRef, inlineRef, bindings) } - private fun prepareInlineLambdaBindings( + private fun prepareInlineLambdaBindingsFromArgs( lambdaRef: LambdaFnRef, args: List ): List>? { if (args.any { it.isSplat || it.name != null }) return null - val paramNames = lambdaRef.inlineParamNames() ?: return null - if (args.size != paramNames.size) return null + val compiledArgs = ArrayList(args.size) + for (arg in args) { + compiledArgs += compileArgValue(arg.value) ?: return null + } + return prepareInlineLambdaBindingsFromValues(lambdaRef, compiledArgs) + } + + private fun prepareInlineLambdaBindingsFromValues( + lambdaRef: LambdaFnRef, + args: List + ): List>? { + val declaration = lambdaRef.argsDeclaration + if (declaration == null) { + val implicitValue = when (args.size) { + 0 -> CompiledValue(ensureVoidSlot(), SlotType.OBJ) + 1 -> args[0] + else -> buildInlineImplicitItList(args) ?: return null + } + return listOf("it" to materializeInlineBinding(implicitValue)) + } + if (declaration.params.any { it.isEllipsis || it.defaultValue != null }) return null + if (args.size != declaration.params.size) return null if (args.isEmpty()) return emptyList() val bindings = ArrayList>(args.size) - for ((index, arg) in args.withIndex()) { - val compiled = compileArgValue(arg.value) ?: return null - bindings += paramNames[index] to materializeInlineBinding(compiled) + for ((index, param) in declaration.params.withIndex()) { + bindings += param.name to materializeInlineBinding(args[index]) } return bindings } + private fun buildInlineImplicitItList(args: List): CompiledValue? { + val list = createEmptyMutableList() ?: return null + for (arg in args) { + appendToList(list, arg) ?: return null + } + return list + } + private fun LambdaFnRef.inlineParamNames(): List? { val declaration = argsDeclaration if (declaration == null) { @@ -5092,7 +5134,8 @@ class BytecodeCompiler( private enum class InlineHigherOrderMethodKind { UNARY_ARGUMENT, RECEIVER, - ITERABLE + ITERABLE, + MAP_GET_OR_PUT } private enum class InlineHigherOrderResultMode { @@ -5107,7 +5150,9 @@ class BytecodeCompiler( private data class InlineHigherOrderMethodSpec( val kind: InlineHigherOrderMethodKind, - val result: InlineHigherOrderResultMode + val result: InlineHigherOrderResultMode, + val argCount: Int = 1, + val lambdaArgIndex: Int = 0 ) private data class InlineReceiverInfo( @@ -5153,6 +5198,12 @@ class BytecodeCompiler( InlineHigherOrderMethodKind.ITERABLE, InlineHigherOrderResultMode.ASSOCIATE_BY ) + "getOrPut" -> InlineHigherOrderMethodSpec( + InlineHigherOrderMethodKind.MAP_GET_OR_PUT, + InlineHigherOrderResultMode.BLOCK_RESULT, + argCount = 2, + lambdaArgIndex = 1 + ) else -> null } } @@ -5321,7 +5372,6 @@ class BytecodeCompiler( ref: MethodCallRef, lambdaRef: LambdaFnRef, inlineRef: ObjRef, - paramName: String, behavior: InlineHigherOrderResultMode ): CompiledValue? { val iterableMethods = ObjIterable.instanceMethodIdMap(includeAbstract = true) @@ -5366,16 +5416,17 @@ class BytecodeCompiler( val nextSlot = allocSlot() builder.emit(Opcode.CALL_MEMBER_SLOT, iterSlot, nextMethodId, 0, 0, nextSlot) val nextObj = ensureObjSlot(CompiledValue(nextSlot, SlotType.UNKNOWN)) + val bindings = prepareInlineLambdaBindingsFromValues(lambdaRef, listOf(nextObj)) ?: return null when (behavior) { InlineHigherOrderResultMode.FOR_EACH -> { - compileInlineLambdaBody(lambdaRef, inlineRef, listOf(paramName to nextObj.slot)) ?: return null + compileInlineLambdaBody(lambdaRef, inlineRef, bindings) ?: return null } InlineHigherOrderResultMode.MAP -> { - val mapped = compileInlineLambdaBody(lambdaRef, inlineRef, listOf(paramName to nextObj.slot)) ?: return null + val mapped = compileInlineLambdaBody(lambdaRef, inlineRef, bindings) ?: return null appendToList(result, mapped) ?: return null } InlineHigherOrderResultMode.FILTER -> { - val predicate = compileInlineLambdaBody(lambdaRef, inlineRef, listOf(paramName to nextObj.slot)) ?: return null + val predicate = compileInlineLambdaBody(lambdaRef, inlineRef, bindings) ?: return null val predicateBool = compileValueAsBool(predicate) val skipLabel = builder.label() builder.emit( @@ -5386,7 +5437,7 @@ class BytecodeCompiler( builder.mark(skipLabel) } InlineHigherOrderResultMode.MAP_NOT_NULL -> { - val mapped = compileInlineLambdaBody(lambdaRef, inlineRef, listOf(paramName to nextObj.slot)) ?: return null + val mapped = compileInlineLambdaBody(lambdaRef, inlineRef, bindings) ?: return null val mappedObj = ensureObjSlot(mapped) val nullSlot = allocSlot() builder.emit(Opcode.CONST_NULL, nullSlot) @@ -5401,7 +5452,7 @@ class BytecodeCompiler( builder.mark(skipLabel) } InlineHigherOrderResultMode.ASSOCIATE_BY -> { - val key = compileInlineLambdaBody(lambdaRef, inlineRef, listOf(paramName to nextObj.slot)) ?: return null + val key = compileInlineLambdaBody(lambdaRef, inlineRef, bindings) ?: return null appendToMap(result, key, nextObj) } } @@ -5433,6 +5484,34 @@ class BytecodeCompiler( builder.emit(Opcode.SET_INDEX, mapObj.slot, keyObj.slot, itemObj.slot) } + private fun compileInlineMapGetOrPut( + receiverObj: CompiledValue, + keyObj: CompiledValue, + lambdaRef: LambdaFnRef, + inlineRef: ObjRef + ): CompiledValue? { + val dst = allocSlot() + val hasKey = allocSlot() + builder.emit(Opcode.CONTAINS_OBJ, receiverObj.slot, keyObj.slot, hasKey) + val existingLabel = builder.label() + val endLabel = builder.label() + builder.emit( + Opcode.JMP_IF_TRUE, + listOf(CmdBuilder.Operand.IntVal(hasKey), CmdBuilder.Operand.LabelRef(existingLabel)) + ) + val bindings = prepareInlineLambdaBindingsFromValues(lambdaRef, emptyList()) ?: return null + val computed = compileInlineLambdaBody(lambdaRef, inlineRef, bindings) ?: return null + val computedObj = ensureObjSlot(computed) + builder.emit(Opcode.SET_INDEX, receiverObj.slot, keyObj.slot, computedObj.slot) + builder.emit(Opcode.MOVE_OBJ, computedObj.slot, dst) + builder.emit(Opcode.JMP, listOf(CmdBuilder.Operand.LabelRef(endLabel))) + builder.mark(existingLabel) + builder.emit(Opcode.GET_INDEX, receiverObj.slot, keyObj.slot, dst) + builder.mark(endLabel) + updateSlotType(dst, SlotType.OBJ) + return CompiledValue(dst, SlotType.OBJ) + } + private fun compileValueAsBool(value: CompiledValue): CompiledValue { if (value.type == SlotType.BOOL) return value val dst = allocSlot() @@ -5779,7 +5858,7 @@ class BytecodeCompiler( return when (entry.ownerKind) { CaptureOwnerFrameKind.LOCAL -> { localSlotIndexByKey[key]?.let { return scopeSlotCount + it } - null + scopeSlotMap[key] } CaptureOwnerFrameKind.MODULE -> { localSlotIndexByKey[key]?.let { return scopeSlotCount + it } diff --git a/lynglib/src/commonTest/kotlin/CompilerVmReviewRegressionTest.kt b/lynglib/src/commonTest/kotlin/CompilerVmReviewRegressionTest.kt index 761b1a9..4549a06 100644 --- a/lynglib/src/commonTest/kotlin/CompilerVmReviewRegressionTest.kt +++ b/lynglib/src/commonTest/kotlin/CompilerVmReviewRegressionTest.kt @@ -255,6 +255,54 @@ class CompilerVmReviewRegressionTest { assertEquals(36, result.list[6].toInt()) } + @Test + fun directLambdaInliningMatchesImplicitItInvocationSemantics() = runTest { + val script = Compiler.compile( + Source( + "", + """ + val zeroFn = { if (it == void) 1 else 0 } + val multiFn = { it } + val zero = zeroFn() + val multi = multiFn(1, 2, 3) + [zero, multi] + """.trimIndent() + ), + Script.defaultImportManager + ) + + val scope = Script.newScope() + val result = script.execute(scope) as ObjList + + assertEquals(1, result.list[0].toInt()) + val multi = result.list[1] as ObjList + assertEquals(listOf(1, 2, 3), multi.list.map { it.toInt() }) + } + + @Test + fun mapGetOrPutUsesInlineDefaultLambda() = runTest { + val script = Compiler.compile( + Source( + "", + """ + val offset = 10 + val m = Map() + val first = m.getOrPut("k") { offset + 1 } + val second = m.getOrPut("k") { offset + 2 } + [first, second, m["k"]] + """.trimIndent() + ), + Script.defaultImportManager + ) + + val scope = Script.newScope() + val result = script.execute(scope) as ObjList + + assertEquals(11, result.list[0].toInt()) + assertEquals(11, result.list[1].toInt()) + assertEquals(11, result.list[2].toInt()) + } + @Test fun subjectlessWhenReportsScriptError() = runTest { val ex = assertFailsWith {