Add fallback expression handling for bytecode if

This commit is contained in:
Sergey Chernov 2026-01-25 19:24:27 +03:00
parent 6560457e3d
commit c8a8b12dfc
6 changed files with 97 additions and 4 deletions

View File

@ -30,6 +30,7 @@ class BytecodeBuilder {
private val constPool = mutableListOf<BytecodeConst>() private val constPool = mutableListOf<BytecodeConst>()
private val labelPositions = mutableMapOf<Label, Int>() private val labelPositions = mutableMapOf<Label, Int>()
private var nextLabelId = 0 private var nextLabelId = 0
private val fallbackStatements = mutableListOf<net.sergeych.lyng.Statement>()
fun addConst(c: BytecodeConst): Int { fun addConst(c: BytecodeConst): Int {
constPool += c constPool += c
@ -50,6 +51,11 @@ class BytecodeBuilder {
labelPositions[label] = instructions.size labelPositions[label] = instructions.size
} }
fun addFallback(stmt: net.sergeych.lyng.Statement): Int {
fallbackStatements += stmt
return fallbackStatements.lastIndex
}
fun build(name: String, localCount: Int): BytecodeFunction { fun build(name: String, localCount: Int): BytecodeFunction {
val slotWidth = when { val slotWidth = when {
localCount < 256 -> 1 localCount < 256 -> 1
@ -100,6 +106,7 @@ class BytecodeBuilder {
ipWidth = ipWidth, ipWidth = ipWidth,
constIdWidth = constIdWidth, constIdWidth = constIdWidth,
constants = constPool.toList(), constants = constPool.toList(),
fallbackStatements = fallbackStatements.toList(),
code = code.toByteArray() code = code.toByteArray()
) )
} }

View File

@ -18,6 +18,9 @@ package net.sergeych.lyng.bytecode
import net.sergeych.lyng.ExpressionStatement import net.sergeych.lyng.ExpressionStatement
import net.sergeych.lyng.IfStatement import net.sergeych.lyng.IfStatement
import net.sergeych.lyng.Pos
import net.sergeych.lyng.Statement
import net.sergeych.lyng.ToBoolStatement
import net.sergeych.lyng.obj.* import net.sergeych.lyng.obj.*
class BytecodeCompiler { class BytecodeCompiler {
@ -33,7 +36,7 @@ class BytecodeCompiler {
} }
fun compileExpression(name: String, stmt: ExpressionStatement): BytecodeFunction? { fun compileExpression(name: String, stmt: ExpressionStatement): BytecodeFunction? {
val value = compileRef(stmt.ref) ?: return null val value = compileRefWithFallback(stmt.ref, null, stmt.pos) ?: return null
builder.emit(Opcode.RET, value.slot) builder.emit(Opcode.RET, value.slot)
val localCount = maxOf(nextSlot, value.slot + 1) val localCount = maxOf(nextSlot, value.slot + 1)
return builder.build(name, localCount) return builder.build(name, localCount)
@ -252,7 +255,7 @@ class BytecodeCompiler {
private fun compileIf(name: String, stmt: IfStatement): BytecodeFunction? { private fun compileIf(name: String, stmt: IfStatement): BytecodeFunction? {
val conditionStmt = stmt.condition as? ExpressionStatement ?: return null val conditionStmt = stmt.condition as? ExpressionStatement ?: return null
val condValue = compileRef(conditionStmt.ref) ?: return null val condValue = compileRefWithFallback(conditionStmt.ref, SlotType.BOOL, stmt.pos) ?: return null
if (condValue.type != SlotType.BOOL) return null if (condValue.type != SlotType.BOOL) return null
val resultSlot = allocSlot() val resultSlot = allocSlot()
@ -282,9 +285,9 @@ class BytecodeCompiler {
return builder.build(name, localCount) return builder.build(name, localCount)
} }
private fun compileStatementValue(stmt: net.sergeych.lyng.Statement): CompiledValue? { private fun compileStatementValue(stmt: Statement): CompiledValue? {
return when (stmt) { return when (stmt) {
is ExpressionStatement -> compileRef(stmt.ref) is ExpressionStatement -> compileRefWithFallback(stmt.ref, null, stmt.pos)
else -> null else -> null
} }
} }
@ -298,6 +301,24 @@ class BytecodeCompiler {
} }
} }
private fun compileRefWithFallback(ref: ObjRef, forceType: SlotType?, pos: Pos): CompiledValue? {
val compiled = compileRef(ref)
if (compiled != null && (forceType == null || compiled.type == forceType || compiled.type == SlotType.UNKNOWN)) {
return if (forceType != null && compiled.type == SlotType.UNKNOWN) {
CompiledValue(compiled.slot, forceType)
} else compiled
}
val slot = allocSlot()
val stmt = if (forceType == SlotType.BOOL) {
ToBoolStatement(ExpressionStatement(ref, pos), pos)
} else {
ExpressionStatement(ref, pos)
}
val id = builder.addFallback(stmt)
builder.emit(Opcode.EVAL_FALLBACK, id, slot)
return CompiledValue(slot, forceType ?: SlotType.OBJ)
}
private fun refSlot(ref: LocalSlotRef): Int = ref.slot private fun refSlot(ref: LocalSlotRef): Int = ref.slot
private fun refDepth(ref: LocalSlotRef): Int = ref.depth private fun refDepth(ref: LocalSlotRef): Int = ref.depth
private fun binaryLeft(ref: BinaryOpRef): ObjRef = ref.left private fun binaryLeft(ref: BinaryOpRef): ObjRef = ref.left

View File

@ -23,6 +23,7 @@ data class BytecodeFunction(
val ipWidth: Int, val ipWidth: Int,
val constIdWidth: Int, val constIdWidth: Int,
val constants: List<BytecodeConst>, val constants: List<BytecodeConst>,
val fallbackStatements: List<net.sergeych.lyng.Statement>,
val code: ByteArray, val code: ByteArray,
) { ) {
init { init {

View File

@ -67,6 +67,21 @@ class BytecodeVm {
?: error("CONST_BOOL expects Bool at $constId") ?: error("CONST_BOOL expects Bool at $constId")
frame.setBool(dst, c.value) frame.setBool(dst, c.value)
} }
Opcode.CONST_OBJ -> {
val constId = decoder.readConstId(code, ip, fn.constIdWidth)
ip += fn.constIdWidth
val dst = decoder.readSlot(code, ip)
ip += fn.slotWidth
val c = fn.constants[constId] as? BytecodeConst.ObjRef
?: error("CONST_OBJ expects ObjRef at $constId")
val obj = c.value
when (obj) {
is ObjInt -> frame.setInt(dst, obj.value)
is ObjReal -> frame.setReal(dst, obj.value)
is ObjBool -> frame.setBool(dst, obj.value)
else -> frame.setObj(dst, obj)
}
}
Opcode.CONST_NULL -> { Opcode.CONST_NULL -> {
val dst = decoder.readSlot(code, ip) val dst = decoder.readSlot(code, ip)
ip += fn.slotWidth ip += fn.slotWidth
@ -140,6 +155,21 @@ class BytecodeVm {
ip = target ip = target
} }
} }
Opcode.EVAL_FALLBACK -> {
val id = decoder.readConstId(code, ip, 2)
ip += 2
val dst = decoder.readSlot(code, ip)
ip += fn.slotWidth
val stmt = fn.fallbackStatements.getOrNull(id)
?: error("Fallback statement not found: $id")
val result = stmt.execute(scope)
when (result) {
is ObjInt -> frame.setInt(dst, result.value)
is ObjReal -> frame.setReal(dst, result.value)
is ObjBool -> frame.setBool(dst, result.value)
else -> frame.setObj(dst, result)
}
}
Opcode.RET -> { Opcode.RET -> {
val slot = decoder.readSlot(code, ip) val slot = decoder.readSlot(code, ip)
return slotToObj(frame, slot) return slotToObj(frame, slot)

View File

@ -79,6 +79,15 @@ class IfStatement(
} }
} }
class ToBoolStatement(
val expr: Statement,
override val pos: Pos,
) : Statement() {
override suspend fun execute(scope: Scope): Obj {
return if (expr.execute(scope).toBool()) net.sergeych.lyng.obj.ObjTrue else net.sergeych.lyng.obj.ObjFalse
}
}
class ExpressionStatement( class ExpressionStatement(
val ref: net.sergeych.lyng.obj.ObjRef, val ref: net.sergeych.lyng.obj.ObjRef,
override val pos: Pos override val pos: Pos

View File

@ -26,6 +26,7 @@ import net.sergeych.lyng.obj.BinaryOpRef
import net.sergeych.lyng.obj.BinOp import net.sergeych.lyng.obj.BinOp
import net.sergeych.lyng.obj.ConstRef import net.sergeych.lyng.obj.ConstRef
import net.sergeych.lyng.obj.ObjInt import net.sergeych.lyng.obj.ObjInt
import net.sergeych.lyng.obj.ObjVoid
import net.sergeych.lyng.obj.toInt import net.sergeych.lyng.obj.toInt
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
@ -68,4 +69,28 @@ class BytecodeVmTest {
val result = BytecodeVm().execute(fn, Scope(), emptyList()) val result = BytecodeVm().execute(fn, Scope(), emptyList())
assertEquals(10, result.toInt()) assertEquals(10, result.toInt())
} }
@Test
fun ifWithoutElseReturnsVoid() = kotlinx.coroutines.test.runTest {
val cond = ExpressionStatement(
BinaryOpRef(
BinOp.LT,
ConstRef(ObjInt.of(2).asReadonly),
ConstRef(ObjInt.of(1).asReadonly),
),
net.sergeych.lyng.Pos.builtIn
)
val thenStmt = ExpressionStatement(
ConstRef(ObjInt.of(10).asReadonly),
net.sergeych.lyng.Pos.builtIn
)
val ifStmt = IfStatement(cond, thenStmt, null, net.sergeych.lyng.Pos.builtIn)
val fn = BytecodeCompiler().compileStatement("ifNoElse", ifStmt).also {
if (it == null) {
error("bytecode compile failed for ifNoElse")
}
}!!
val result = BytecodeVm().execute(fn, Scope(), emptyList())
assertEquals(ObjVoid, result)
}
} }