diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeCompiler.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeCompiler.kt index a3fc4c2..cf4f918 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeCompiler.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeCompiler.kt @@ -5470,19 +5470,7 @@ class BytecodeCompiler( } private fun resolveInlineCallableLambda(target: ObjRef): LambdaFnRef? { - val lambdaRef = when (target) { - is LambdaFnRef -> target - is LocalSlotRef -> { - val ownerScopeId = target.captureOwnerScopeId ?: target.scopeId - val ownerSlot = target.captureOwnerSlot ?: target.slot - exactLambdaRefByScopeId[ownerScopeId]?.get(ownerSlot) - ?: resolveLocalSlotByRefOrName(target)?.let { exactLambdaRefBySlot[it] } - } - is LocalVarRef -> resolveDirectNameSlot(target.name)?.slot?.let { exactLambdaRefBySlot[it] } - is FastLocalVarRef -> resolveDirectNameSlot(target.name)?.slot?.let { exactLambdaRefBySlot[it] } - is BoundLocalVarRef -> exactLambdaRefBySlot[target.slotIndex()] - else -> null - } + val lambdaRef = resolveExactLambdaRef(target) return lambdaRef?.takeUnless { activeInlineLambdas.contains(it) } } @@ -5574,13 +5562,28 @@ class BytecodeCompiler( } private fun extractExactLambdaRef(value: Obj?): LambdaFnRef? { - val expr = value as? ExpressionStatement ?: return null - return when (val ref = expr.ref) { - is LambdaFnRef -> ref - is LocalSlotRef -> resolveLocalSlotByRefOrName(ref)?.let { exactLambdaRefBySlot[it] } - is LocalVarRef -> resolveDirectNameSlot(ref.name)?.slot?.let { exactLambdaRefBySlot[it] } - is FastLocalVarRef -> resolveDirectNameSlot(ref.name)?.slot?.let { exactLambdaRefBySlot[it] } - is BoundLocalVarRef -> exactLambdaRefBySlot[ref.slotIndex()] + return when (value) { + is ExpressionStatement -> resolveExactLambdaRef(value.ref) + is IfStatement -> { + val thenRef = extractExactLambdaRef(value.ifBody) + val elseRef = value.elseBody?.let { extractExactLambdaRef(it) } + if (thenRef != null && thenRef === elseRef) thenRef else null + } + is WhenStatement -> { + var candidate: LambdaFnRef? = null + for (case in value.cases) { + val current = extractExactLambdaRef(case.block) ?: return null + if (candidate == null) { + candidate = current + } else if (candidate !== current) { + return null + } + } + val elseRef = value.elseCase?.let { extractExactLambdaRef(it) } + if (candidate == null) return elseRef + if (elseRef == null || candidate !== elseRef) return null + candidate + } else -> null } } @@ -5648,6 +5651,7 @@ class BytecodeCompiler( private fun resolveExactCallableObj(target: ObjRef): Obj? { return when (target) { is ConstRef -> target.constValue.takeUnless { it === ObjNull || it === ObjUnset || it is ObjExternCallable } + is CastRef -> resolveExactCallableObj(target.castValueRef()) is ElvisRef -> { val left = resolveExactCallableObj(target.left) if (left != null) return left @@ -5666,6 +5670,7 @@ class BytecodeCompiler( val elseObj = statement.elseBody?.let { extractExactCallableObj(it) } if (thenObj != null && thenObj === elseObj) thenObj else null } + is WhenStatement -> extractExactCallableObj(statement) else -> null } } @@ -5686,6 +5691,41 @@ class BytecodeCompiler( } } + private fun resolveExactLambdaRef(target: ObjRef): LambdaFnRef? { + return when (target) { + is LambdaFnRef -> target + is CastRef -> resolveExactLambdaRef(target.castValueRef()) + is ElvisRef -> { + val left = resolveExactLambdaRef(target.left) + if (left != null) return left + if (isDefinitelyNullRef(target.left)) resolveExactLambdaRef(target.right) else null + } + is ConditionalRef -> { + val thenRef = resolveExactLambdaRef(target.ifTrue) + val elseRef = resolveExactLambdaRef(target.ifFalse) + if (thenRef != null && thenRef === elseRef) thenRef else null + } + is StatementRef -> { + when (val statement = target.statement) { + is ExpressionStatement -> resolveExactLambdaRef(statement.ref) + is IfStatement, + is WhenStatement -> extractExactLambdaRef(statement) + else -> null + } + } + is LocalSlotRef -> { + val ownerScopeId = target.captureOwnerScopeId ?: target.scopeId + val ownerSlot = target.captureOwnerSlot ?: target.slot + exactLambdaRefByScopeId[ownerScopeId]?.get(ownerSlot) + ?: resolveLocalSlotByRefOrName(target)?.let { exactLambdaRefBySlot[it] } + } + is LocalVarRef -> resolveDirectNameSlot(target.name)?.slot?.let { exactLambdaRefBySlot[it] } + is FastLocalVarRef -> resolveDirectNameSlot(target.name)?.slot?.let { exactLambdaRefBySlot[it] } + is BoundLocalVarRef -> exactLambdaRefBySlot[target.slotIndex()] + else -> null + } + } + private fun compileInlineListFillInt(size: CompiledValue, lambdaRef: LambdaFnRef, inlineRef: ObjRef): CompiledValue { if (isImplicitItIdentityRef(inlineRef)) { val dst = allocSlot() diff --git a/lynglib/src/commonTest/kotlin/BytecodeRecentOpsTest.kt b/lynglib/src/commonTest/kotlin/BytecodeRecentOpsTest.kt index a29d27e..aeb04bf 100644 --- a/lynglib/src/commonTest/kotlin/BytecodeRecentOpsTest.kt +++ b/lynglib/src/commonTest/kotlin/BytecodeRecentOpsTest.kt @@ -378,6 +378,55 @@ class BytecodeRecentOpsTest { assertEquals(11, scope.eval("calc()").toInt()) } + @Test + fun conditionalExactLambdaCallUsesInlineBytecode() = runTest { + val scope = Script.newScope() + scope.eval( + """ + val base = { x -> x + 1 } + fun calc(flag: Bool) { + (if(flag) base else base)(10) + } + """.trimIndent() + ) + val disasm = scope.disassembleSymbol("calc") + assertFalse(disasm.contains("CALL_SLOT"), disasm) + assertEquals(11, scope.eval("calc(true)").toInt()) + } + + @Test + fun elvisExactLambdaCallUsesInlineBytecode() = runTest { + val scope = Script.newScope() + scope.eval( + """ + val base = { x -> x + 1 } + fun calc() { + (null ?: base)(10) + } + """.trimIndent() + ) + val disasm = scope.disassembleSymbol("calc") + assertFalse(disasm.contains("CALL_SLOT"), disasm) + assertEquals(11, scope.eval("calc()").toInt()) + } + + @Test + fun castExactLambdaCallUsesInlineBytecode() = runTest { + val scope = Script.newScope() + scope.eval( + """ + type IntFn = (Int)->Int + val base: IntFn = { x -> x + 1 } + fun calc() { + (base as IntFn)(10) + } + """.trimIndent() + ) + val disasm = scope.disassembleSymbol("calc") + assertFalse(disasm.contains("CALL_SLOT"), disasm) + assertEquals(11, scope.eval("calc()").toInt()) + } + @Test fun letLiteralUsesInlineBytecode() = runTest { val scope = Script.newScope()