diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeBuilder.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeBuilder.kt index bd93c40..d58b153 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeBuilder.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeBuilder.kt @@ -30,6 +30,7 @@ class BytecodeBuilder { private val constPool = mutableListOf() private val labelPositions = mutableMapOf() private var nextLabelId = 0 + private val fallbackStatements = mutableListOf() fun addConst(c: BytecodeConst): Int { constPool += c @@ -50,6 +51,11 @@ class BytecodeBuilder { labelPositions[label] = instructions.size } + fun addFallback(stmt: net.sergeych.lyng.Statement): Int { + fallbackStatements += stmt + return fallbackStatements.lastIndex + } + fun build(name: String, localCount: Int): BytecodeFunction { val slotWidth = when { localCount < 256 -> 1 @@ -100,6 +106,7 @@ class BytecodeBuilder { ipWidth = ipWidth, constIdWidth = constIdWidth, constants = constPool.toList(), + fallbackStatements = fallbackStatements.toList(), code = code.toByteArray() ) } diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeCompiler.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeCompiler.kt index 6f050da..f61f8a5 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeCompiler.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeCompiler.kt @@ -18,6 +18,9 @@ package net.sergeych.lyng.bytecode import net.sergeych.lyng.ExpressionStatement import net.sergeych.lyng.IfStatement +import net.sergeych.lyng.Pos +import net.sergeych.lyng.Statement +import net.sergeych.lyng.ToBoolStatement import net.sergeych.lyng.obj.* class BytecodeCompiler { @@ -33,7 +36,7 @@ class BytecodeCompiler { } fun compileExpression(name: String, stmt: ExpressionStatement): BytecodeFunction? { - val value = compileRef(stmt.ref) ?: return null + val value = compileRefWithFallback(stmt.ref, null, stmt.pos) ?: return null builder.emit(Opcode.RET, value.slot) val localCount = maxOf(nextSlot, value.slot + 1) return builder.build(name, localCount) @@ -252,7 +255,7 @@ class BytecodeCompiler { private fun compileIf(name: String, stmt: IfStatement): BytecodeFunction? { val conditionStmt = stmt.condition as? ExpressionStatement ?: return null - val condValue = compileRef(conditionStmt.ref) ?: return null + val condValue = compileRefWithFallback(conditionStmt.ref, SlotType.BOOL, stmt.pos) ?: return null if (condValue.type != SlotType.BOOL) return null val resultSlot = allocSlot() @@ -282,9 +285,9 @@ class BytecodeCompiler { return builder.build(name, localCount) } - private fun compileStatementValue(stmt: net.sergeych.lyng.Statement): CompiledValue? { + private fun compileStatementValue(stmt: Statement): CompiledValue? { return when (stmt) { - is ExpressionStatement -> compileRef(stmt.ref) + is ExpressionStatement -> compileRefWithFallback(stmt.ref, null, stmt.pos) else -> null } } @@ -298,6 +301,24 @@ class BytecodeCompiler { } } + private fun compileRefWithFallback(ref: ObjRef, forceType: SlotType?, pos: Pos): CompiledValue? { + val compiled = compileRef(ref) + if (compiled != null && (forceType == null || compiled.type == forceType || compiled.type == SlotType.UNKNOWN)) { + return if (forceType != null && compiled.type == SlotType.UNKNOWN) { + CompiledValue(compiled.slot, forceType) + } else compiled + } + val slot = allocSlot() + val stmt = if (forceType == SlotType.BOOL) { + ToBoolStatement(ExpressionStatement(ref, pos), pos) + } else { + ExpressionStatement(ref, pos) + } + val id = builder.addFallback(stmt) + builder.emit(Opcode.EVAL_FALLBACK, id, slot) + return CompiledValue(slot, forceType ?: SlotType.OBJ) + } + private fun refSlot(ref: LocalSlotRef): Int = ref.slot private fun refDepth(ref: LocalSlotRef): Int = ref.depth private fun binaryLeft(ref: BinaryOpRef): ObjRef = ref.left diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeFunction.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeFunction.kt index b1e2602..b04d0c6 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeFunction.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeFunction.kt @@ -23,6 +23,7 @@ data class BytecodeFunction( val ipWidth: Int, val constIdWidth: Int, val constants: List, + val fallbackStatements: List, val code: ByteArray, ) { init { diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeVm.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeVm.kt index f085653..0473729 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeVm.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeVm.kt @@ -67,6 +67,21 @@ class BytecodeVm { ?: error("CONST_BOOL expects Bool at $constId") frame.setBool(dst, c.value) } + Opcode.CONST_OBJ -> { + val constId = decoder.readConstId(code, ip, fn.constIdWidth) + ip += fn.constIdWidth + val dst = decoder.readSlot(code, ip) + ip += fn.slotWidth + val c = fn.constants[constId] as? BytecodeConst.ObjRef + ?: error("CONST_OBJ expects ObjRef at $constId") + val obj = c.value + when (obj) { + is ObjInt -> frame.setInt(dst, obj.value) + is ObjReal -> frame.setReal(dst, obj.value) + is ObjBool -> frame.setBool(dst, obj.value) + else -> frame.setObj(dst, obj) + } + } Opcode.CONST_NULL -> { val dst = decoder.readSlot(code, ip) ip += fn.slotWidth @@ -140,6 +155,21 @@ class BytecodeVm { ip = target } } + Opcode.EVAL_FALLBACK -> { + val id = decoder.readConstId(code, ip, 2) + ip += 2 + val dst = decoder.readSlot(code, ip) + ip += fn.slotWidth + val stmt = fn.fallbackStatements.getOrNull(id) + ?: error("Fallback statement not found: $id") + val result = stmt.execute(scope) + when (result) { + is ObjInt -> frame.setInt(dst, result.value) + is ObjReal -> frame.setReal(dst, result.value) + is ObjBool -> frame.setBool(dst, result.value) + else -> frame.setObj(dst, result) + } + } Opcode.RET -> { val slot = decoder.readSlot(code, ip) return slotToObj(frame, slot) diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/statements.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/statements.kt index 0d43e5a..157182a 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/statements.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/statements.kt @@ -79,6 +79,15 @@ class IfStatement( } } +class ToBoolStatement( + val expr: Statement, + override val pos: Pos, +) : Statement() { + override suspend fun execute(scope: Scope): Obj { + return if (expr.execute(scope).toBool()) net.sergeych.lyng.obj.ObjTrue else net.sergeych.lyng.obj.ObjFalse + } +} + class ExpressionStatement( val ref: net.sergeych.lyng.obj.ObjRef, override val pos: Pos diff --git a/lynglib/src/commonTest/kotlin/BytecodeVmTest.kt b/lynglib/src/commonTest/kotlin/BytecodeVmTest.kt index 18bb0bd..dd822b5 100644 --- a/lynglib/src/commonTest/kotlin/BytecodeVmTest.kt +++ b/lynglib/src/commonTest/kotlin/BytecodeVmTest.kt @@ -26,6 +26,7 @@ import net.sergeych.lyng.obj.BinaryOpRef import net.sergeych.lyng.obj.BinOp import net.sergeych.lyng.obj.ConstRef import net.sergeych.lyng.obj.ObjInt +import net.sergeych.lyng.obj.ObjVoid import net.sergeych.lyng.obj.toInt import kotlin.test.Test import kotlin.test.assertEquals @@ -68,4 +69,28 @@ class BytecodeVmTest { val result = BytecodeVm().execute(fn, Scope(), emptyList()) assertEquals(10, result.toInt()) } + + @Test + fun ifWithoutElseReturnsVoid() = kotlinx.coroutines.test.runTest { + val cond = ExpressionStatement( + BinaryOpRef( + BinOp.LT, + ConstRef(ObjInt.of(2).asReadonly), + ConstRef(ObjInt.of(1).asReadonly), + ), + net.sergeych.lyng.Pos.builtIn + ) + val thenStmt = ExpressionStatement( + ConstRef(ObjInt.of(10).asReadonly), + net.sergeych.lyng.Pos.builtIn + ) + val ifStmt = IfStatement(cond, thenStmt, null, net.sergeych.lyng.Pos.builtIn) + val fn = BytecodeCompiler().compileStatement("ifNoElse", ifStmt).also { + if (it == null) { + error("bytecode compile failed for ifNoElse") + } + }!! + val result = BytecodeVm().execute(fn, Scope(), emptyList()) + assertEquals(ObjVoid, result) + } }