diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeCompiler.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeCompiler.kt index 047b75a..3c8a35c 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeCompiler.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeCompiler.kt @@ -4990,7 +4990,6 @@ class BytecodeCompiler( if (ref.args.size != 1 || ref.args.any { it.isSplat || it.name != null }) return null if (!ref.explicitTypeArgs.isNullOrEmpty()) return null val lambdaRef = extractExactLambdaRef(ref.args.first().value) ?: return null - if (hasModuleCapture(lambdaRef)) return null val inlineRef = lambdaRef.inlineBodyRef ?: return null if (!isMethodInlineSafe(lambdaRef, inlineRef, allowReceiverRefs = false, allowCaptures = true)) return null val paramName = lambdaRef.inlineParamNames()?.singleOrNull() ?: return null @@ -5049,7 +5048,6 @@ class BytecodeCompiler( if (ref.args.size != 1 || ref.args.any { it.isSplat || it.name != null }) return null if (!ref.explicitTypeArgs.isNullOrEmpty()) return null val lambdaRef = extractExactLambdaRef(ref.args.first().value) ?: return null - if (hasModuleCapture(lambdaRef)) return null val receiverInfo = receiverInlineInfo(lambdaRef) ?: return null val inlineRef = lambdaRef.inlineBodyRef ?: return null if (!isMethodInlineSafe(lambdaRef, inlineRef, allowReceiverRefs = true, allowCaptures = true)) return null @@ -5092,9 +5090,8 @@ class BytecodeCompiler( if (ref.args.size != 1 || ref.args.any { it.isSplat || it.name != null }) return null if (!ref.explicitTypeArgs.isNullOrEmpty()) return null val lambdaRef = extractExactLambdaRef(ref.args.first().value) ?: return null - if (hasAnyCapture(lambdaRef)) return null val inlineRef = lambdaRef.inlineBodyRef ?: return null - if (!isMethodInlineSafe(lambdaRef, inlineRef, allowReceiverRefs = false, allowCaptures = false)) return null + if (!isMethodInlineSafe(lambdaRef, inlineRef, allowReceiverRefs = false, allowCaptures = true)) return null val paramNames = lambdaRef.inlineParamNames() ?: return null if (paramNames.size != 1) return null val receiver = compileRefWithFallback(ref.receiver, null, refPosOrCurrent(ref.receiver)) ?: return null @@ -5773,10 +5770,17 @@ class BytecodeCompiler( } private fun resolveInlineCaptureSlot(entry: LambdaCaptureEntry): Int? { - if (entry.ownerKind != CaptureOwnerFrameKind.LOCAL) return null val key = ScopeSlotKey(entry.ownerScopeId, entry.ownerSlotId) - localSlotIndexByKey[key]?.let { return scopeSlotCount + it } - return null + return when (entry.ownerKind) { + CaptureOwnerFrameKind.LOCAL -> { + localSlotIndexByKey[key]?.let { return scopeSlotCount + it } + null + } + CaptureOwnerFrameKind.MODULE -> { + localSlotIndexByKey[key]?.let { return scopeSlotCount + it } + scopeSlotMap[key] + } + } } private fun isImplicitItIdentityRef(ref: ObjRef): Boolean { diff --git a/lynglib/src/commonTest/kotlin/CompilerVmReviewRegressionTest.kt b/lynglib/src/commonTest/kotlin/CompilerVmReviewRegressionTest.kt index e36e04c..6c5b26f 100644 --- a/lynglib/src/commonTest/kotlin/CompilerVmReviewRegressionTest.kt +++ b/lynglib/src/commonTest/kotlin/CompilerVmReviewRegressionTest.kt @@ -26,6 +26,7 @@ import net.sergeych.lyng.Statement import net.sergeych.lyng.asFacade import net.sergeych.lyng.obj.ObjDynamic import net.sergeych.lyng.obj.ObjInt +import net.sergeych.lyng.obj.ObjList import net.sergeych.lyng.obj.ObjString import net.sergeych.lyng.obj.toInt import net.sergeych.lyng.pacman.ImportManager @@ -206,6 +207,44 @@ class CompilerVmReviewRegressionTest { assertEquals("barfoo=7", (dynamic.getAt(scope, ObjString("bar")) as ObjString).value) } + @Test + fun higherOrderMethodInliningSupportsCapturedValues() = runTest { + val script = Compiler.compile( + Source( + "", + """ + val suffix = "!" + val offset = 10 + var sum = 0 + + val letResult = "a".let { it + suffix } + val applyResult = List().apply { add(offset); add(offset + 1) } + val mapped = [1, 2, 3].map { it + offset } + val filtered = [1, 2, 3].filter { it + offset >= 12 } + [1, 2, 3].forEach { sum += it + offset } + + [letResult, applyResult, mapped, filtered, sum] + """.trimIndent() + ), + Script.defaultImportManager + ) + + val scope = Script.newScope() + val result = script.execute(scope) as ObjList + + assertEquals("a!", (result.list[0] as ObjString).value) + val applied = result.list[1] as ObjList + assertEquals(listOf(10, 11), applied.list.map { it.toInt() }) + + val mapped = result.list[2] as ObjList + assertEquals(listOf(11, 12, 13), mapped.list.map { it.toInt() }) + + val filtered = result.list[3] as ObjList + assertEquals(listOf(2, 3), filtered.list.map { it.toInt() }) + + assertEquals(36, result.list[4].toInt()) + } + @Test fun subjectlessWhenReportsScriptError() = runTest { val ex = assertFailsWith {