Enable fast local refs behind compiler flag

This commit is contained in:
Sergey Chernov 2026-02-09 10:28:32 +03:00
parent f0dc0d2396
commit 541738646f
5 changed files with 127 additions and 7 deletions

View File

@ -46,10 +46,10 @@ Goal: migrate the compiler so all values live in frames/bytecode, keeping JVM te
- [x] Step 13: Qualified `this` value refs in bytecode. - [x] Step 13: Qualified `this` value refs in bytecode.
- [x] Compile `QualifiedThisRef` (`this@Type`) via `LOAD_THIS_VARIANT`. - [x] Compile `QualifiedThisRef` (`this@Type`) via `LOAD_THIS_VARIANT`.
- [x] Add a JVM test that evaluates `this@Type` as a value inside nested classes. - [x] Add a JVM test that evaluates `this@Type` as a value inside nested classes.
- [ ] Step 14: Fast local ref reads in bytecode. - [x] Step 14: Fast local ref reads in bytecode.
- [ ] Support `FastLocalVarRef` reads with the same slot resolution as `LocalVarRef`. - [x] Support `FastLocalVarRef` reads with the same slot resolution as `LocalVarRef`.
- [ ] If `BoundLocalVarRef` is still emitted, map it to a direct slot read instead of failing. - [x] If `BoundLocalVarRef` is still emitted, map it to a direct slot read instead of failing.
- [ ] Add a JVM test that exercises fast-local reads in a bytecode-compiled function. - [x] Add a JVM test that exercises fast-local reads in a bytecode-compiled function.
- [ ] Step 15: Class-scope `?=` in bytecode. - [ ] Step 15: Class-scope `?=` in bytecode.
- [ ] Handle `C.x ?= v` and `C?.x ?= v` for class-scope members without falling back. - [ ] Handle `C.x ?= v` and `C?.x ?= v` for class-scope members without falling back.
- [ ] Add a JVM test for class-scope `?=` on static vars. - [ ] Add a JVM test for class-scope `?=` on static vars.

View File

@ -46,8 +46,11 @@ class Compiler(
// Track identifiers known to be locals/parameters in the current function for fast local emission // Track identifiers known to be locals/parameters in the current function for fast local emission
private val localNamesStack = mutableListOf<MutableSet<String>>() private val localNamesStack = mutableListOf<MutableSet<String>>()
private val localShadowedNamesStack = mutableListOf<MutableSet<String>>()
private val currentLocalNames: MutableSet<String>? private val currentLocalNames: MutableSet<String>?
get() = localNamesStack.lastOrNull() get() = localNamesStack.lastOrNull()
private val currentShadowedLocalNames: MutableSet<String>?
get() = localShadowedNamesStack.lastOrNull()
private data class SlotEntry(val index: Int, val isMutable: Boolean, val isDelegated: Boolean) private data class SlotEntry(val index: Int, val isMutable: Boolean, val isDelegated: Boolean)
private data class SlotPlan(val slots: MutableMap<String, SlotEntry>, var nextIndex: Int, val id: Int) private data class SlotPlan(val slots: MutableMap<String, SlotEntry>, var nextIndex: Int, val id: Int)
@ -94,9 +97,11 @@ class Compiler(
private inline fun <T> withLocalNames(names: Set<String>, block: () -> T): T { private inline fun <T> withLocalNames(names: Set<String>, block: () -> T): T {
localNamesStack.add(names.toMutableSet()) localNamesStack.add(names.toMutableSet())
localShadowedNamesStack.add(mutableSetOf())
return try { return try {
block() block()
} finally { } finally {
localShadowedNamesStack.removeLast()
localNamesStack.removeLast() localNamesStack.removeLast()
} }
} }
@ -104,6 +109,9 @@ class Compiler(
private fun declareLocalName(name: String, isMutable: Boolean, isDelegated: Boolean = false) { private fun declareLocalName(name: String, isMutable: Boolean, isDelegated: Boolean = false) {
// Add to current function's local set; only count if it was newly added (avoid duplicates) // Add to current function's local set; only count if it was newly added (avoid duplicates)
val added = currentLocalNames?.add(name) == true val added = currentLocalNames?.add(name) == true
if (!added) {
currentShadowedLocalNames?.add(name)
}
if (added && localDeclCountStack.isNotEmpty()) { if (added && localDeclCountStack.isNotEmpty()) {
localDeclCountStack[localDeclCountStack.lastIndex] = currentLocalDeclCount + 1 localDeclCountStack[localDeclCountStack.lastIndex] = currentLocalDeclCount + 1
} }
@ -842,6 +850,16 @@ class Compiler(
return ref return ref
} }
val captureOwner = capturePlanStack.lastOrNull()?.captureOwners?.get(name) val captureOwner = capturePlanStack.lastOrNull()?.captureOwners?.get(name)
if (useFastLocalRefs &&
slotLoc.depth == 0 &&
captureOwner == null &&
currentLocalNames?.contains(name) == true &&
currentShadowedLocalNames?.contains(name) != true &&
!slotLoc.isDelegated
) {
resolutionSink?.reference(name, pos)
return FastLocalVarRef(name, pos)
}
if (slotLoc.depth == 0 && captureOwner != null) { if (slotLoc.depth == 0 && captureOwner != null) {
val ref = LocalSlotRef( val ref = LocalSlotRef(
name, name,
@ -1024,12 +1042,14 @@ class Compiler(
val strictSlotRefs: Boolean = true, val strictSlotRefs: Boolean = true,
val allowUnresolvedRefs: Boolean = false, val allowUnresolvedRefs: Boolean = false,
val seedScope: Scope? = null, val seedScope: Scope? = null,
val useFastLocalRefs: Boolean = false,
) )
// Optional sink for mini-AST streaming (null by default, zero overhead when not used) // Optional sink for mini-AST streaming (null by default, zero overhead when not used)
private val miniSink: MiniAstSink? = settings.miniAstSink private val miniSink: MiniAstSink? = settings.miniAstSink
private val resolutionSink: ResolutionSink? = settings.resolutionSink private val resolutionSink: ResolutionSink? = settings.resolutionSink
private val seedScope: Scope? = settings.seedScope private val seedScope: Scope? = settings.seedScope
private val useFastLocalRefs: Boolean = settings.useFastLocalRefs
private var resolutionScriptDepth = 0 private var resolutionScriptDepth = 0
private val resolutionPredeclared = mutableSetOf<String>() private val resolutionPredeclared = mutableSetOf<String>()
private data class ImportedModule(val scope: ModuleScope, val pos: Pos) private data class ImportedModule(val scope: ModuleScope, val pos: Pos)
@ -3964,6 +3984,7 @@ class Compiler(
private fun inferObjClassFromRef(ref: ObjRef): ObjClass? = when (ref) { private fun inferObjClassFromRef(ref: ObjRef): ObjClass? = when (ref) {
is ConstRef -> ref.constValue as? ObjClass ?: (ref.constValue as? Obj)?.objClass is ConstRef -> ref.constValue as? ObjClass ?: (ref.constValue as? Obj)?.objClass
is LocalVarRef -> nameObjClass[ref.name] ?: resolveClassByName(ref.name) is LocalVarRef -> nameObjClass[ref.name] ?: resolveClassByName(ref.name)
is FastLocalVarRef -> nameObjClass[ref.name] ?: resolveClassByName(ref.name)
is LocalSlotRef -> { is LocalSlotRef -> {
val ownerScopeId = ref.captureOwnerScopeId ?: ref.scopeId val ownerScopeId = ref.captureOwnerScopeId ?: ref.scopeId
val ownerSlot = ref.captureOwnerSlot ?: ref.slot val ownerSlot = ref.captureOwnerSlot ?: ref.slot
@ -3990,6 +4011,7 @@ class Compiler(
slotTypeDeclByScopeId[ownerScopeId]?.get(ownerSlot) slotTypeDeclByScopeId[ownerScopeId]?.get(ownerSlot)
} }
is LocalVarRef -> nameTypeDecl[ref.name] is LocalVarRef -> nameTypeDecl[ref.name]
is FastLocalVarRef -> nameTypeDecl[ref.name]
is MethodCallRef -> methodReturnTypeDeclByRef[ref] is MethodCallRef -> methodReturnTypeDeclByRef[ref]
is StatementRef -> (ref.statement as? ExpressionStatement)?.let { resolveReceiverTypeDecl(it.ref) } is StatementRef -> (ref.statement as? ExpressionStatement)?.let { resolveReceiverTypeDecl(it.ref) }
else -> null else -> null
@ -4009,6 +4031,9 @@ class Compiler(
is LocalVarRef -> nameObjClass[ref.name] is LocalVarRef -> nameObjClass[ref.name]
?: nameTypeDecl[ref.name]?.let { resolveTypeDeclObjClass(it) } ?: nameTypeDecl[ref.name]?.let { resolveTypeDeclObjClass(it) }
?: resolveClassByName(ref.name) ?: resolveClassByName(ref.name)
is FastLocalVarRef -> nameObjClass[ref.name]
?: nameTypeDecl[ref.name]?.let { resolveTypeDeclObjClass(it) }
?: resolveClassByName(ref.name)
is ClassScopeMemberRef -> { is ClassScopeMemberRef -> {
val targetClass = resolveClassByName(ref.ownerClassName()) val targetClass = resolveClassByName(ref.ownerClassName())
inferFieldReturnClass(targetClass, ref.name) inferFieldReturnClass(targetClass, ref.name)
@ -4072,6 +4097,8 @@ class Compiler(
?: resolveClassByName(target.name) ?: resolveClassByName(target.name)
is LocalVarRef -> callableReturnTypeByName[target.name] is LocalVarRef -> callableReturnTypeByName[target.name]
?: resolveClassByName(target.name) ?: resolveClassByName(target.name)
is FastLocalVarRef -> callableReturnTypeByName[target.name]
?: resolveClassByName(target.name)
is ConstRef -> when (val value = target.constValue) { is ConstRef -> when (val value = target.constValue) {
is ObjClass -> value is ObjClass -> value
is ObjString -> ObjString.type is ObjString -> ObjString.type
@ -4348,6 +4375,7 @@ class Compiler(
is ConstRef -> ref.constValue as? ObjClass is ConstRef -> ref.constValue as? ObjClass
is LocalSlotRef -> resolveTypeDeclObjClass(TypeDecl.Simple(ref.name, false)) ?: nameObjClass[ref.name] is LocalSlotRef -> resolveTypeDeclObjClass(TypeDecl.Simple(ref.name, false)) ?: nameObjClass[ref.name]
is LocalVarRef -> resolveTypeDeclObjClass(TypeDecl.Simple(ref.name, false)) ?: nameObjClass[ref.name] is LocalVarRef -> resolveTypeDeclObjClass(TypeDecl.Simple(ref.name, false)) ?: nameObjClass[ref.name]
is FastLocalVarRef -> resolveTypeDeclObjClass(TypeDecl.Simple(ref.name, false)) ?: nameObjClass[ref.name]
else -> null else -> null
} }
@ -4878,6 +4906,8 @@ class Compiler(
} }
is LocalVarRef -> nameObjClass[ref.name]?.className is LocalVarRef -> nameObjClass[ref.name]?.className
?: nameTypeDecl[ref.name]?.let { typeDeclName(it) } ?: nameTypeDecl[ref.name]?.let { typeDeclName(it) }
is FastLocalVarRef -> nameObjClass[ref.name]?.className
?: nameTypeDecl[ref.name]?.let { typeDeclName(it) }
is QualifiedThisRef -> ref.typeName is QualifiedThisRef -> ref.typeName
else -> resolveReceiverClassForMember(ref)?.className else -> resolveReceiverClassForMember(ref)?.className
} }
@ -8529,7 +8559,8 @@ class Compiler(
useBytecodeStatements: Boolean = true, useBytecodeStatements: Boolean = true,
strictSlotRefs: Boolean = true, strictSlotRefs: Boolean = true,
allowUnresolvedRefs: Boolean = false, allowUnresolvedRefs: Boolean = false,
seedScope: Scope? = null seedScope: Scope? = null,
useFastLocalRefs: Boolean = false
): Script { ): Script {
return Compiler( return Compiler(
CompilerContext(parseLyng(source)), CompilerContext(parseLyng(source)),
@ -8540,7 +8571,8 @@ class Compiler(
useBytecodeStatements = useBytecodeStatements, useBytecodeStatements = useBytecodeStatements,
strictSlotRefs = strictSlotRefs, strictSlotRefs = strictSlotRefs,
allowUnresolvedRefs = allowUnresolvedRefs, allowUnresolvedRefs = allowUnresolvedRefs,
seedScope = seedScope seedScope = seedScope,
useFastLocalRefs = useFastLocalRefs
) )
).parseScript() ).parseScript()
} }

View File

@ -340,6 +340,49 @@ class BytecodeCompiler(
} }
null null
} }
is FastLocalVarRef -> {
if (ref.name == "this") {
return compileThisRef()
}
loopSlotOverrides[ref.name]?.let { slot ->
val resolved = slotTypes[slot] ?: SlotType.UNKNOWN
return CompiledValue(slot, resolved)
}
if (allowLocalSlots) {
if (!forceScopeSlots) {
scopeSlotIndexByName[ref.name]?.let { slot ->
val resolved = slotTypes[slot] ?: SlotType.UNKNOWN
return CompiledValue(slot, resolved)
}
val localIndex = localSlotIndexByName[ref.name]
if (localIndex != null) {
val slot = scopeSlotCount + localIndex
val resolved = slotTypes[slot] ?: SlotType.UNKNOWN
return CompiledValue(slot, resolved)
}
}
if (forceScopeSlots) {
scopeSlotIndexByName[ref.name]?.let { slot ->
val resolved = slotTypes[slot] ?: SlotType.UNKNOWN
return CompiledValue(slot, resolved)
}
}
}
null
}
is BoundLocalVarRef -> {
if (!allowLocalSlots) return null
val slot = ref.slotIndex()
val resolved = slotTypes[slot] ?: SlotType.UNKNOWN
if (slot < scopeSlotCount && resolved != SlotType.UNKNOWN) {
val addrSlot = ensureScopeAddr(slot)
val local = allocSlot()
emitLoadFromAddr(addrSlot, local, resolved)
updateSlotType(local, resolved)
return CompiledValue(local, resolved)
}
CompiledValue(slot, resolved)
}
is ValueFnRef -> compileValueFnRef(ref) is ValueFnRef -> compileValueFnRef(ref)
is ListLiteralRef -> compileListLiteral(ref) is ListLiteralRef -> compileListLiteral(ref)
is MapLiteralRef -> compileMapLiteral(ref) is MapLiteralRef -> compileMapLiteral(ref)
@ -5360,9 +5403,10 @@ class BytecodeCompiler(
compiled = null compiled = null
} }
} }
if (ref is LocalVarRef || ref is LocalSlotRef) { if (ref is LocalVarRef || ref is LocalSlotRef || ref is FastLocalVarRef) {
val name = when (ref) { val name = when (ref) {
is LocalVarRef -> ref.name is LocalVarRef -> ref.name
is FastLocalVarRef -> ref.name
is LocalSlotRef -> ref.name is LocalSlotRef -> ref.name
else -> "unknown" else -> "unknown"
} }
@ -5423,6 +5467,20 @@ class BytecodeCompiler(
} ?: nameObjClass[ref.name] } ?: nameObjClass[ref.name]
?: resolveTypeNameClass(ref.name) ?: resolveTypeNameClass(ref.name)
} }
is FastLocalVarRef -> {
if (knownObjectNames.contains(ref.name)) {
return nameObjClass[ref.name] ?: ObjDynamic.type
}
val fromSlot = resolveDirectNameSlot(ref.name)?.let { slotObjClass[it.slot] }
if (fromSlot != null) return fromSlot
val key = localSlotInfoMap.entries.firstOrNull { it.value.name == ref.name }?.key
key?.let {
slotTypeByScopeId[it.scopeId]?.get(it.slot)
?: slotInitClassByKey[it]
} ?: nameObjClass[ref.name]
?: resolveTypeNameClass(ref.name)
}
is BoundLocalVarRef -> slotObjClass[ref.slotIndex()]
is QualifiedThisRef -> resolveTypeNameClass(ref.typeName) is QualifiedThisRef -> resolveTypeNameClass(ref.typeName)
is ListLiteralRef -> ObjList.type is ListLiteralRef -> ObjList.type
is MapLiteralRef -> ObjMap.type is MapLiteralRef -> ObjMap.type
@ -5463,6 +5521,7 @@ class BytecodeCompiler(
return when (ref) { return when (ref) {
is LocalVarRef -> knownClassNames.contains(ref.name) && !knownObjectNames.contains(ref.name) is LocalVarRef -> knownClassNames.contains(ref.name) && !knownObjectNames.contains(ref.name)
is LocalSlotRef -> knownClassNames.contains(ref.name) && !knownObjectNames.contains(ref.name) is LocalSlotRef -> knownClassNames.contains(ref.name) && !knownObjectNames.contains(ref.name)
is FastLocalVarRef -> knownClassNames.contains(ref.name) && !knownObjectNames.contains(ref.name)
else -> false else -> false
} }
} }
@ -5489,6 +5548,8 @@ class BytecodeCompiler(
return when (ref) { return when (ref) {
is LocalSlotRef -> nameObjClass[ref.name] ?: resolveTypeNameClass(ref.name) is LocalSlotRef -> nameObjClass[ref.name] ?: resolveTypeNameClass(ref.name)
is LocalVarRef -> nameObjClass[ref.name] ?: resolveTypeNameClass(ref.name) is LocalVarRef -> nameObjClass[ref.name] ?: resolveTypeNameClass(ref.name)
is FastLocalVarRef -> nameObjClass[ref.name] ?: resolveTypeNameClass(ref.name)
is BoundLocalVarRef -> slotObjClass[ref.slotIndex()]
is QualifiedThisRef -> resolveTypeNameClass(ref.typeName) is QualifiedThisRef -> resolveTypeNameClass(ref.typeName)
is ListLiteralRef -> ObjList.type is ListLiteralRef -> ObjList.type
is MapLiteralRef -> ObjMap.type is MapLiteralRef -> ObjMap.type
@ -5530,6 +5591,8 @@ class BytecodeCompiler(
is ConstRef -> ref.constValue as? ObjClass is ConstRef -> ref.constValue as? ObjClass
is LocalSlotRef -> resolveTypeNameClass(ref.name) ?: nameObjClass[ref.name] is LocalSlotRef -> resolveTypeNameClass(ref.name) ?: nameObjClass[ref.name]
is LocalVarRef -> resolveTypeNameClass(ref.name) ?: nameObjClass[ref.name] is LocalVarRef -> resolveTypeNameClass(ref.name) ?: nameObjClass[ref.name]
is FastLocalVarRef -> resolveTypeNameClass(ref.name) ?: nameObjClass[ref.name]
is QualifiedThisRef -> resolveTypeNameClass(ref.typeName)
else -> null else -> null
} }
} }

View File

@ -1971,6 +1971,7 @@ class BoundLocalVarRef(
private val slot: Int, private val slot: Int,
private val atPos: Pos, private val atPos: Pos,
) : ObjRef { ) : ObjRef {
internal fun slotIndex(): Int = slot
override suspend fun get(scope: Scope): ObjRecord { override suspend fun get(scope: Scope): ObjRecord {
scope.pos = atPos scope.pos = atPos
val rec = scope.getSlotRecord(slot) val rec = scope.getSlotRecord(slot)

View File

@ -16,8 +16,13 @@
*/ */
import kotlinx.coroutines.test.runTest import kotlinx.coroutines.test.runTest
import net.sergeych.lyng.Compiler
import net.sergeych.lyng.Script
import net.sergeych.lyng.Source
import net.sergeych.lyng.eval import net.sergeych.lyng.eval
import net.sergeych.lyng.obj.toInt
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals
class BytecodeRecentOpsTest { class BytecodeRecentOpsTest {
@ -154,4 +159,23 @@ class BytecodeRecentOpsTest {
""".trimIndent() """.trimIndent()
) )
} }
@Test
fun fastLocalVarRefRead() = runTest {
val code = """
fun addOne(x) {
val y = x + 1
y
}
addOne(1)
""".trimIndent()
val script = Compiler.compileWithResolution(
Source("<fast-local>", code),
Script.defaultImportManager,
useBytecodeStatements = true,
useFastLocalRefs = true
)
val result = script.execute(Script.defaultImportManager.newStdScope())
assertEquals(2, result.toInt())
}
} }