Fix closure locals for tail blocks; unignore stdlib tests

This commit is contained in:
Sergey Chernov 2026-01-29 10:31:27 +03:00
parent d8e18e4a0c
commit 20b8464591
2 changed files with 146 additions and 9 deletions

View File

@ -65,6 +65,7 @@ class BytecodeCompiler(
private val loopStack = ArrayDeque<LoopContext>() private val loopStack = ArrayDeque<LoopContext>()
private val effectiveScopeDepthByRef = IdentityHashMap<LocalSlotRef, Int>() private val effectiveScopeDepthByRef = IdentityHashMap<LocalSlotRef, Int>()
private val effectiveLocalDepthByKey = LinkedHashMap<ScopeSlotKey, Int>() private val effectiveLocalDepthByKey = LinkedHashMap<ScopeSlotKey, Int>()
private var forceScopeSlots = false
private data class LoopContext( private data class LoopContext(
val label: String?, val label: String?,
@ -205,6 +206,7 @@ class BytecodeCompiler(
} }
is LocalVarRef -> { is LocalVarRef -> {
if (allowLocalSlots) { if (allowLocalSlots) {
if (!forceScopeSlots) {
loopSlotOverrides[ref.name]?.let { slot -> loopSlotOverrides[ref.name]?.let { slot ->
val resolved = slotTypes[slot] ?: SlotType.UNKNOWN val resolved = slotTypes[slot] ?: SlotType.UNKNOWN
return CompiledValue(slot, resolved) return CompiledValue(slot, resolved)
@ -216,6 +218,7 @@ class BytecodeCompiler(
return CompiledValue(slot, resolved) return CompiledValue(slot, resolved)
} }
} }
}
compileNameLookup(ref.name) compileNameLookup(ref.name)
} }
is ValueFnRef -> { is ValueFnRef -> {
@ -2865,6 +2868,10 @@ class BytecodeCompiler(
private fun refPos(ref: BinaryOpRef): Pos = Pos.builtIn private fun refPos(ref: BinaryOpRef): Pos = Pos.builtIn
private fun resolveSlot(ref: LocalSlotRef): Int? { private fun resolveSlot(ref: LocalSlotRef): Int? {
if (forceScopeSlots) {
val scopeKey = ScopeSlotKey(effectiveScopeDepth(ref), refSlot(ref))
return scopeSlotMap[scopeKey]
}
loopSlotOverrides[ref.name]?.let { return it } loopSlotOverrides[ref.name]?.let { return it }
val localKey = ScopeSlotKey(refScopeDepth(ref), refSlot(ref)) val localKey = ScopeSlotKey(refScopeDepth(ref), refSlot(ref))
val localIndex = localSlotIndexByKey[localKey] val localIndex = localSlotIndexByKey[localKey]
@ -2906,6 +2913,7 @@ class BytecodeCompiler(
loopStack.clear() loopStack.clear()
effectiveScopeDepthByRef.clear() effectiveScopeDepthByRef.clear()
effectiveLocalDepthByKey.clear() effectiveLocalDepthByKey.clear()
forceScopeSlots = allowLocalSlots && containsValueFnRef(stmt)
if (allowLocalSlots) { if (allowLocalSlots) {
collectLoopVarNames(stmt) collectLoopVarNames(stmt)
} }
@ -2981,7 +2989,7 @@ class BytecodeCompiler(
is VarDeclStatement -> { is VarDeclStatement -> {
val slotIndex = stmt.slotIndex val slotIndex = stmt.slotIndex
val slotDepth = stmt.slotDepth val slotDepth = stmt.slotDepth
if (allowLocalSlots && slotIndex != null && slotDepth != null) { if (allowLocalSlots && !forceScopeSlots && slotIndex != null && slotDepth != null) {
val key = ScopeSlotKey(slotDepth, slotIndex) val key = ScopeSlotKey(slotDepth, slotIndex)
declaredLocalKeys.add(key) declaredLocalKeys.add(key)
if (!localSlotInfoMap.containsKey(key)) { if (!localSlotInfoMap.containsKey(key)) {
@ -2992,6 +3000,14 @@ class BytecodeCompiler(
localRangeRefs[key] = range localRangeRefs[key] = range
} }
} }
} else if (slotIndex != null && slotDepth != null) {
val key = ScopeSlotKey(slotDepth, slotIndex)
if (!scopeSlotMap.containsKey(key)) {
scopeSlotMap[key] = scopeSlotMap.size
}
if (!scopeSlotNameMap.containsKey(key)) {
scopeSlotNameMap[key] = stmt.name
}
} }
stmt.initializer?.let { collectScopeSlots(it) } stmt.initializer?.let { collectScopeSlots(it) }
} }
@ -3033,6 +3049,51 @@ class BytecodeCompiler(
collectLoopSlotPlans(stmt.original, scopeDepth) collectLoopSlotPlans(stmt.original, scopeDepth)
return return
} }
if (forceScopeSlots) {
when (stmt) {
is net.sergeych.lyng.ForInStatement -> {
collectLoopSlotPlans(stmt.source, scopeDepth)
val loopDepth = scopeDepth + 1
collectLoopSlotPlans(stmt.body, loopDepth)
stmt.elseStatement?.let { collectLoopSlotPlans(it, loopDepth) }
}
is net.sergeych.lyng.WhileStatement -> {
collectLoopSlotPlans(stmt.condition, scopeDepth)
val loopDepth = scopeDepth + 1
collectLoopSlotPlans(stmt.body, loopDepth)
stmt.elseStatement?.let { collectLoopSlotPlans(it, loopDepth) }
}
is net.sergeych.lyng.DoWhileStatement -> {
val loopDepth = scopeDepth + 1
collectLoopSlotPlans(stmt.body, loopDepth)
collectLoopSlotPlans(stmt.condition, loopDepth)
stmt.elseStatement?.let { collectLoopSlotPlans(it, loopDepth) }
}
is BlockStatement -> {
val nextDepth = scopeDepth + 1
for (child in stmt.statements()) {
collectLoopSlotPlans(child, nextDepth)
}
}
is IfStatement -> {
collectLoopSlotPlans(stmt.condition, scopeDepth)
collectLoopSlotPlans(stmt.ifBody, scopeDepth)
stmt.elseBody?.let { collectLoopSlotPlans(it, scopeDepth) }
}
is VarDeclStatement -> {
stmt.initializer?.let { collectLoopSlotPlans(it, scopeDepth) }
}
is ExpressionStatement -> {}
is net.sergeych.lyng.ReturnStatement -> {
stmt.resultExpr?.let { collectLoopSlotPlans(it, scopeDepth) }
}
is net.sergeych.lyng.ThrowStatement -> {
collectLoopSlotPlans(stmt.throwExpr, scopeDepth)
}
else -> {}
}
return
}
when (stmt) { when (stmt) {
is net.sergeych.lyng.ForInStatement -> { is net.sergeych.lyng.ForInStatement -> {
collectLoopSlotPlans(stmt.source, scopeDepth) collectLoopSlotPlans(stmt.source, scopeDepth)
@ -3186,8 +3247,8 @@ class BytecodeCompiler(
when (ref) { when (ref) {
is LocalSlotRef -> { is LocalSlotRef -> {
val localKey = ScopeSlotKey(refScopeDepth(ref), refSlot(ref)) val localKey = ScopeSlotKey(refScopeDepth(ref), refSlot(ref))
val shouldLocalize = (refDepth(ref) == 0) || val shouldLocalize = !forceScopeSlots && ((refDepth(ref) == 0) ||
intLoopVarNames.contains(ref.name) intLoopVarNames.contains(ref.name))
if (allowLocalSlots && !ref.isDelegated && shouldLocalize) { if (allowLocalSlots && !ref.isDelegated && shouldLocalize) {
if (!localSlotInfoMap.containsKey(localKey)) { if (!localSlotInfoMap.containsKey(localKey)) {
localSlotInfoMap[localKey] = LocalSlotInfo(ref.name, ref.isMutable, localKey.depth) localSlotInfoMap[localKey] = LocalSlotInfo(ref.name, ref.isMutable, localKey.depth)
@ -3212,8 +3273,8 @@ class BytecodeCompiler(
val target = assignTarget(ref) val target = assignTarget(ref)
if (target != null) { if (target != null) {
val localKey = ScopeSlotKey(refScopeDepth(target), refSlot(target)) val localKey = ScopeSlotKey(refScopeDepth(target), refSlot(target))
val shouldLocalize = (refDepth(target) == 0) || val shouldLocalize = !forceScopeSlots && ((refDepth(target) == 0) ||
intLoopVarNames.contains(target.name) intLoopVarNames.contains(target.name))
if (allowLocalSlots && !target.isDelegated && shouldLocalize) { if (allowLocalSlots && !target.isDelegated && shouldLocalize) {
if (!localSlotInfoMap.containsKey(localKey)) { if (!localSlotInfoMap.containsKey(localKey)) {
localSlotInfoMap[localKey] = LocalSlotInfo(target.name, target.isMutable, localKey.depth) localSlotInfoMap[localKey] = LocalSlotInfo(target.name, target.isMutable, localKey.depth)
@ -3274,6 +3335,86 @@ class BytecodeCompiler(
} }
} }
private fun containsValueFnRef(stmt: Statement): Boolean {
if (stmt is BytecodeStatement) return containsValueFnRef(stmt.original)
return when (stmt) {
is ExpressionStatement -> containsValueFnRef(stmt.ref)
is BlockStatement -> stmt.statements().any { containsValueFnRef(it) }
is VarDeclStatement -> stmt.initializer?.let { containsValueFnRef(it) } ?: false
is DestructuringVarDeclStatement -> {
containsValueFnRef(stmt.initializer) || containsValueFnRef(stmt.pattern)
}
is net.sergeych.lyng.ForInStatement -> {
containsValueFnRef(stmt.source) ||
containsValueFnRef(stmt.body) ||
(stmt.elseStatement?.let { containsValueFnRef(it) } ?: false)
}
is net.sergeych.lyng.WhileStatement -> {
containsValueFnRef(stmt.condition) ||
containsValueFnRef(stmt.body) ||
(stmt.elseStatement?.let { containsValueFnRef(it) } ?: false)
}
is net.sergeych.lyng.DoWhileStatement -> {
containsValueFnRef(stmt.body) ||
containsValueFnRef(stmt.condition) ||
(stmt.elseStatement?.let { containsValueFnRef(it) } ?: false)
}
is IfStatement -> {
containsValueFnRef(stmt.condition) ||
containsValueFnRef(stmt.ifBody) ||
(stmt.elseBody?.let { containsValueFnRef(it) } ?: false)
}
is net.sergeych.lyng.ReturnStatement -> {
stmt.resultExpr?.let { containsValueFnRef(it) } ?: false
}
is net.sergeych.lyng.ThrowStatement -> containsValueFnRef(stmt.throwExpr)
else -> false
}
}
private fun containsValueFnRef(ref: ObjRef): Boolean {
return when (ref) {
is ValueFnRef -> true
is BinaryOpRef -> containsValueFnRef(binaryLeft(ref)) || containsValueFnRef(binaryRight(ref))
is UnaryOpRef -> containsValueFnRef(unaryOperand(ref))
is AssignRef -> {
val target = assignTarget(ref)
(target != null && containsValueFnRef(target)) || containsValueFnRef(assignValue(ref))
}
is AssignOpRef -> containsValueFnRef(ref.target) || containsValueFnRef(ref.value)
is AssignIfNullRef -> containsValueFnRef(ref.target) || containsValueFnRef(ref.value)
is IncDecRef -> containsValueFnRef(ref.target)
is ConditionalRef -> {
containsValueFnRef(ref.condition) ||
containsValueFnRef(ref.ifTrue) ||
containsValueFnRef(ref.ifFalse)
}
is ElvisRef -> containsValueFnRef(ref.left) || containsValueFnRef(ref.right)
is FieldRef -> containsValueFnRef(ref.target)
is IndexRef -> containsValueFnRef(ref.targetRef) || containsValueFnRef(ref.indexRef)
is CallRef -> ref.tailBlock || containsValueFnRef(ref.target) || ref.args.any { arg ->
val stmt = arg.value
stmt is ExpressionStatement && containsValueFnRef(stmt.ref)
}
is MethodCallRef -> ref.tailBlock || containsValueFnRef(ref.receiver) || ref.args.any { arg ->
val stmt = arg.value
stmt is ExpressionStatement && containsValueFnRef(stmt.ref)
}
is ThisMethodSlotCallRef -> ref.hasTailBlock() || ref.arguments().any { arg ->
val stmt = arg.value
stmt is ExpressionStatement && containsValueFnRef(stmt.ref)
}
is ListLiteralRef -> ref.entries().any { entry ->
when (entry) {
is net.sergeych.lyng.ListEntry.Element -> containsValueFnRef(entry.ref)
is net.sergeych.lyng.ListEntry.Spread -> containsValueFnRef(entry.ref)
}
}
is StatementRef -> containsValueFnRef(ref.statement)
else -> false
}
}
private fun collectEffectiveDepths( private fun collectEffectiveDepths(
stmt: Statement, stmt: Statement,
scopeDepth: Int, scopeDepth: Int,

View File

@ -22,7 +22,6 @@ import kotlin.test.Test
class StdlibTest { class StdlibTest {
@Test @Test
@Ignore("TODO(bytecode-only): iterable filter mismatch")
fun testIterableFilter() = runTest { fun testIterableFilter() = runTest {
eval(""" eval("""
assertEquals([2,4,6,8], (1..8).filter{ println("call2"); it % 2 == 0 }.toList() ) assertEquals([2,4,6,8], (1..8).filter{ println("call2"); it % 2 == 0 }.toList() )
@ -95,7 +94,6 @@ class StdlibTest {
} }
@Test @Test
@Ignore("TODO(bytecode-only): flatten/filter mismatch")
fun testFlattenAndFilter() = runTest { fun testFlattenAndFilter() = runTest {
eval(""" eval("""
assertEquals([1,2,3,4,5,6], [1,3,5].map { [it, it+1] }.flatten() ) assertEquals([1,2,3,4,5,6], [1,3,5].map { [it, it+1] }.flatten() )
@ -111,7 +109,6 @@ class StdlibTest {
} }
@Test @Test
@Ignore("TODO(bytecode-only): count mismatch")
fun testCount() = runTest { fun testCount() = runTest {
eval(""" eval("""
assertEquals(5, (1..10).toList().count { it % 2 == 1 } ) assertEquals(5, (1..10).toList().count { it % 2 == 1 } )
@ -119,7 +116,6 @@ class StdlibTest {
} }
@Test @Test
@Ignore("TODO(bytecode-only): with mismatch")
fun testWith() = runTest { fun testWith() = runTest {
eval(""" eval("""
class Person(val name, var age) class Person(val name, var age)