Add bytecode support for when statements

This commit is contained in:
Sergey Chernov 2026-01-29 04:11:27 +03:00
parent 91624a30b8
commit e143f31f3d
5 changed files with 218 additions and 56 deletions

View File

@ -2031,8 +2031,6 @@ class Compiler(
} }
} }
data class WhenCase(val condition: Statement, val block: Statement)
private suspend fun parseWhenStatement(): Statement { private suspend fun parseWhenStatement(): Statement {
// has a value, when(value) ? // has a value, when(value) ?
var t = cc.nextNonWhitespace() var t = cc.nextNonWhitespace()
@ -2044,7 +2042,6 @@ class Compiler(
if (t.type != Token.Type.LBRACE) throw ScriptError(t.pos, "when { ... } expected") if (t.type != Token.Type.LBRACE) throw ScriptError(t.pos, "when { ... } expected")
val cases = mutableListOf<WhenCase>() val cases = mutableListOf<WhenCase>()
var elseCase: Statement? = null var elseCase: Statement? = null
lateinit var whenValue: Obj
// there could be 0+ then clauses // there could be 0+ then clauses
// condition could be a value, in and is clauses: // condition could be a value, in and is clauses:
@ -2053,9 +2050,8 @@ class Compiler(
// loop cases // loop cases
outer@ while (true) { outer@ while (true) {
var skipParseBody = false var skipParseBody = false
val currentCondition = mutableListOf<Statement>() val currentConditions = mutableListOf<WhenCondition>()
// loop conditions // loop conditions
while (true) { while (true) {
@ -2064,31 +2060,16 @@ class Compiler(
when (t.type) { when (t.type) {
Token.Type.IN, Token.Type.IN,
Token.Type.NOTIN -> { Token.Type.NOTIN -> {
// we need a copy in the closure: val negated = t.type == Token.Type.NOTIN
val isIn = t.type == Token.Type.IN
val container = parseExpression() ?: throw ScriptError(cc.currentPos(), "type expected") val container = parseExpression() ?: throw ScriptError(cc.currentPos(), "type expected")
val condPos = t.pos currentConditions += WhenInCondition(container, negated, t.pos)
currentCondition += object : Statement() {
override val pos: Pos = condPos
override suspend fun execute(scope: Scope): Obj {
val r = container.execute(scope).contains(scope, whenValue)
return ObjBool(if (isIn) r else !r)
}
}
} }
Token.Type.IS, Token.Type.NOTIS -> { Token.Type.IS,
// we need a copy in the closure: Token.Type.NOTIS -> {
val isIn = t.type == Token.Type.IS val negated = t.type == Token.Type.NOTIS
val caseType = parseExpression() ?: throw ScriptError(cc.currentPos(), "type expected") val caseType = parseExpression() ?: throw ScriptError(cc.currentPos(), "type expected")
val condPos = t.pos currentConditions += WhenIsCondition(caseType, negated, t.pos)
currentCondition += object : Statement() {
override val pos: Pos = condPos
override suspend fun execute(scope: Scope): Obj {
val r = whenValue.isInstanceOf(caseType.execute(scope))
return ObjBool(if (isIn) r else !r)
}
}
} }
Token.Type.COMMA -> Token.Type.COMMA ->
@ -2117,13 +2098,7 @@ class Compiler(
cc.previous() cc.previous()
val x = parseExpression() val x = parseExpression()
?: throw ScriptError(cc.currentPos(), "when case condition expected") ?: throw ScriptError(cc.currentPos(), "when case condition expected")
val condPos = t.pos currentConditions += WhenEqualsCondition(x, t.pos)
currentCondition += object : Statement() {
override val pos: Pos = condPos
override suspend fun execute(scope: Scope): Obj {
return ObjBool(x.execute(scope).compareTo(scope, whenValue) == 0)
}
}
} }
} }
} }
@ -2132,28 +2107,11 @@ class Compiler(
if (!skipParseBody) { if (!skipParseBody) {
val block = parseStatement()?.let { unwrapBytecodeDeep(it) } val block = parseStatement()?.let { unwrapBytecodeDeep(it) }
?: throw ScriptError(cc.currentPos(), "when case block expected") ?: throw ScriptError(cc.currentPos(), "when case block expected")
for (c in currentCondition) cases += WhenCase(c, block) cases += WhenCase(currentConditions, block)
} }
} }
val whenPos = t.pos val whenPos = t.pos
object : Statement() { WhenStatement(value, cases, elseCase, whenPos)
override val pos: Pos = whenPos
override suspend fun execute(scope: Scope): Obj {
var result: Obj = ObjVoid
// in / is and like uses whenValue from closure:
whenValue = value.execute(scope)
var found = false
for (c in cases) {
if (c.condition.execute(scope).toBool()) {
result = c.block.execute(scope)
found = true
break
}
}
if (!found && elseCase != null) result = elseCase.execute(scope)
return result
}
}
} else { } else {
// when { cond -> ... } // when { cond -> ... }
TODO("when without object is not yet implemented") TODO("when without object is not yet implemented")

View File

@ -0,0 +1,76 @@
/*
* Copyright 2026 Sergey S. Chernov
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package net.sergeych.lyng
import net.sergeych.lyng.obj.Obj
import net.sergeych.lyng.obj.ObjVoid
sealed class WhenCondition(open val expr: Statement, open val pos: Pos) {
abstract suspend fun matches(scope: Scope, value: Obj): Boolean
}
class WhenEqualsCondition(
override val expr: Statement,
override val pos: Pos,
) : WhenCondition(expr, pos) {
override suspend fun matches(scope: Scope, value: Obj): Boolean {
return expr.execute(scope).compareTo(scope, value) == 0
}
}
class WhenInCondition(
override val expr: Statement,
val negated: Boolean,
override val pos: Pos,
) : WhenCondition(expr, pos) {
override suspend fun matches(scope: Scope, value: Obj): Boolean {
val result = expr.execute(scope).contains(scope, value)
return if (negated) !result else result
}
}
class WhenIsCondition(
override val expr: Statement,
val negated: Boolean,
override val pos: Pos,
) : WhenCondition(expr, pos) {
override suspend fun matches(scope: Scope, value: Obj): Boolean {
val result = value.isInstanceOf(expr.execute(scope))
return if (negated) !result else result
}
}
data class WhenCase(val conditions: List<WhenCondition>, val block: Statement)
class WhenStatement(
val value: Statement,
val cases: List<WhenCase>,
val elseCase: Statement?,
override val pos: Pos,
) : Statement() {
override suspend fun execute(scope: Scope): Obj {
val whenValue = value.execute(scope)
for (case in cases) {
for (condition in case.conditions) {
if (condition.matches(scope, whenValue)) {
return case.block.execute(scope)
}
}
}
return elseCase?.execute(scope) ?: ObjVoid
}
}

View File

@ -24,6 +24,11 @@ import net.sergeych.lyng.Pos
import net.sergeych.lyng.Statement import net.sergeych.lyng.Statement
import net.sergeych.lyng.ToBoolStatement import net.sergeych.lyng.ToBoolStatement
import net.sergeych.lyng.VarDeclStatement import net.sergeych.lyng.VarDeclStatement
import net.sergeych.lyng.WhenCondition
import net.sergeych.lyng.WhenEqualsCondition
import net.sergeych.lyng.WhenInCondition
import net.sergeych.lyng.WhenIsCondition
import net.sergeych.lyng.WhenStatement
import net.sergeych.lyng.obj.* import net.sergeych.lyng.obj.*
class BytecodeCompiler( class BytecodeCompiler(
@ -1515,6 +1520,96 @@ class BytecodeCompiler(
return CompiledValue(resultSlot, SlotType.OBJ) return CompiledValue(resultSlot, SlotType.OBJ)
} }
private fun compileWhen(stmt: WhenStatement, wantResult: Boolean): CompiledValue? {
val subjectValue = compileStatementValueOrFallback(stmt.value) ?: return null
val subjectObj = ensureObjSlot(subjectValue)
val resultSlot = allocSlot()
if (wantResult) {
val voidId = builder.addConst(BytecodeConst.ObjRef(ObjVoid))
builder.emit(Opcode.CONST_OBJ, voidId, resultSlot)
updateSlotType(resultSlot, SlotType.OBJ)
}
val endLabel = builder.label()
for (case in stmt.cases) {
val caseLabel = builder.label()
val nextCaseLabel = builder.label()
for (cond in case.conditions) {
val condValue = compileWhenCondition(cond, subjectObj) ?: return null
builder.emit(
Opcode.JMP_IF_TRUE,
listOf(CmdBuilder.Operand.IntVal(condValue.slot), CmdBuilder.Operand.LabelRef(caseLabel))
)
}
builder.emit(Opcode.JMP, listOf(CmdBuilder.Operand.LabelRef(nextCaseLabel)))
builder.mark(caseLabel)
val bodyValue = compileStatementValueOrFallback(case.block, wantResult) ?: return null
if (wantResult) {
val bodyObj = ensureObjSlot(bodyValue)
builder.emit(Opcode.MOVE_OBJ, bodyObj.slot, resultSlot)
}
builder.emit(Opcode.JMP, listOf(CmdBuilder.Operand.LabelRef(endLabel)))
builder.mark(nextCaseLabel)
}
stmt.elseCase?.let {
val elseValue = compileStatementValueOrFallback(it, wantResult) ?: return null
if (wantResult) {
val elseObj = ensureObjSlot(elseValue)
builder.emit(Opcode.MOVE_OBJ, elseObj.slot, resultSlot)
}
}
builder.mark(endLabel)
return if (wantResult) {
updateSlotType(resultSlot, SlotType.OBJ)
CompiledValue(resultSlot, SlotType.OBJ)
} else {
subjectObj
}
}
private fun compileWhenCondition(cond: WhenCondition, subjectObj: CompiledValue): CompiledValue? {
val subject = ensureObjSlot(subjectObj)
return when (cond) {
is WhenEqualsCondition -> {
val expected = compileStatementValueOrFallback(cond.expr) ?: return null
val expectedObj = ensureObjSlot(expected)
val dst = allocSlot()
builder.emit(Opcode.CMP_EQ_OBJ, expectedObj.slot, subject.slot, dst)
updateSlotType(dst, SlotType.BOOL)
CompiledValue(dst, SlotType.BOOL)
}
is WhenInCondition -> {
val container = compileStatementValueOrFallback(cond.expr) ?: return null
val containerObj = ensureObjSlot(container)
val baseDst = allocSlot()
builder.emit(Opcode.CONTAINS_OBJ, containerObj.slot, subject.slot, baseDst)
updateSlotType(baseDst, SlotType.BOOL)
if (!cond.negated) {
CompiledValue(baseDst, SlotType.BOOL)
} else {
val neg = allocSlot()
builder.emit(Opcode.NOT_BOOL, baseDst, neg)
updateSlotType(neg, SlotType.BOOL)
CompiledValue(neg, SlotType.BOOL)
}
}
is WhenIsCondition -> {
val typeValue = compileStatementValueOrFallback(cond.expr) ?: return null
val typeObj = ensureObjSlot(typeValue)
val baseDst = allocSlot()
builder.emit(Opcode.CHECK_IS, subject.slot, typeObj.slot, baseDst)
updateSlotType(baseDst, SlotType.BOOL)
if (!cond.negated) {
CompiledValue(baseDst, SlotType.BOOL)
} else {
val neg = allocSlot()
builder.emit(Opcode.NOT_BOOL, baseDst, neg)
updateSlotType(neg, SlotType.BOOL)
CompiledValue(neg, SlotType.BOOL)
}
}
}
}
private fun ensureObjSlot(value: CompiledValue): CompiledValue { private fun ensureObjSlot(value: CompiledValue): CompiledValue {
if (value.type == SlotType.OBJ) return value if (value.type == SlotType.OBJ) return value
val dst = allocSlot() val dst = allocSlot()
@ -1883,6 +1978,7 @@ class BytecodeCompiler(
is net.sergeych.lyng.FunctionDeclStatement -> emitStatementEval(target) is net.sergeych.lyng.FunctionDeclStatement -> emitStatementEval(target)
is net.sergeych.lyng.EnumDeclStatement -> emitStatementEval(target) is net.sergeych.lyng.EnumDeclStatement -> emitStatementEval(target)
is net.sergeych.lyng.TryStatement -> emitStatementEval(target) is net.sergeych.lyng.TryStatement -> emitStatementEval(target)
is net.sergeych.lyng.WhenStatement -> compileWhen(target, true)
is net.sergeych.lyng.BreakStatement -> compileBreak(target) is net.sergeych.lyng.BreakStatement -> compileBreak(target)
is net.sergeych.lyng.ContinueStatement -> compileContinue(target) is net.sergeych.lyng.ContinueStatement -> compileContinue(target)
is net.sergeych.lyng.ReturnStatement -> compileReturn(target) is net.sergeych.lyng.ReturnStatement -> compileReturn(target)
@ -1927,6 +2023,7 @@ class BytecodeCompiler(
is net.sergeych.lyng.ContinueStatement -> compileContinue(target) is net.sergeych.lyng.ContinueStatement -> compileContinue(target)
is net.sergeych.lyng.ReturnStatement -> compileReturn(target) is net.sergeych.lyng.ReturnStatement -> compileReturn(target)
is net.sergeych.lyng.ThrowStatement -> compileThrow(target) is net.sergeych.lyng.ThrowStatement -> compileThrow(target)
is net.sergeych.lyng.WhenStatement -> compileWhen(target, false)
else -> { else -> {
emitFallbackStatement(target) emitFallbackStatement(target)
} }

View File

@ -19,6 +19,12 @@ package net.sergeych.lyng.bytecode
import net.sergeych.lyng.Pos import net.sergeych.lyng.Pos
import net.sergeych.lyng.Scope import net.sergeych.lyng.Scope
import net.sergeych.lyng.Statement import net.sergeych.lyng.Statement
import net.sergeych.lyng.WhenCase
import net.sergeych.lyng.WhenCondition
import net.sergeych.lyng.WhenEqualsCondition
import net.sergeych.lyng.WhenInCondition
import net.sergeych.lyng.WhenIsCondition
import net.sergeych.lyng.WhenStatement
import net.sergeych.lyng.obj.Obj import net.sergeych.lyng.obj.Obj
import net.sergeych.lyng.obj.RangeRef import net.sergeych.lyng.obj.RangeRef
@ -106,6 +112,14 @@ class BytecodeStatement private constructor(
is net.sergeych.lyng.FunctionDeclStatement -> false is net.sergeych.lyng.FunctionDeclStatement -> false
is net.sergeych.lyng.EnumDeclStatement -> false is net.sergeych.lyng.EnumDeclStatement -> false
is net.sergeych.lyng.TryStatement -> false is net.sergeych.lyng.TryStatement -> false
is net.sergeych.lyng.WhenStatement -> {
containsUnsupportedStatement(target.value) ||
target.cases.any { case ->
case.conditions.any { cond -> containsUnsupportedStatement(cond.expr) } ||
containsUnsupportedStatement(case.block)
} ||
(target.elseCase?.let { containsUnsupportedStatement(it) } ?: false)
}
else -> true else -> true
} }
} }
@ -187,8 +201,29 @@ class BytecodeStatement private constructor(
} }
is net.sergeych.lyng.ThrowStatement -> is net.sergeych.lyng.ThrowStatement ->
net.sergeych.lyng.ThrowStatement(unwrapDeep(stmt.throwExpr), stmt.pos) net.sergeych.lyng.ThrowStatement(unwrapDeep(stmt.throwExpr), stmt.pos)
is net.sergeych.lyng.WhenStatement -> {
net.sergeych.lyng.WhenStatement(
unwrapDeep(stmt.value),
stmt.cases.map { case ->
net.sergeych.lyng.WhenCase(
case.conditions.map { unwrapWhenCondition(it) },
unwrapDeep(case.block)
)
},
stmt.elseCase?.let { unwrapDeep(it) },
stmt.pos
)
}
else -> stmt else -> stmt
} }
} }
private fun unwrapWhenCondition(cond: WhenCondition): WhenCondition {
return when (cond) {
is WhenEqualsCondition -> WhenEqualsCondition(unwrapDeep(cond.expr), cond.pos)
is WhenInCondition -> WhenInCondition(unwrapDeep(cond.expr), cond.negated, cond.pos)
is WhenIsCondition -> WhenIsCondition(unwrapDeep(cond.expr), cond.negated, cond.pos)
}
}
} }
} }

View File

@ -2351,7 +2351,6 @@ class ScriptTest {
} }
} }
@Ignore("Bytecode: unsupported or incorrect behavior")
@Test @Test
fun testSimpleWhen() = runTest { fun testSimpleWhen() = runTest {
eval( eval(
@ -2376,7 +2375,6 @@ class ScriptTest {
) )
} }
@Ignore("Bytecode: unsupported or incorrect behavior")
@Test @Test
fun testWhenIs() = runTest { fun testWhenIs() = runTest {
eval( eval(
@ -2407,7 +2405,6 @@ class ScriptTest {
) )
} }
@Ignore("Bytecode: unsupported or incorrect behavior")
@Test @Test
fun testWhenIn() = runTest { fun testWhenIn() = runTest {
eval( eval(
@ -2502,7 +2499,6 @@ class ScriptTest {
// ) // )
// } // }
@Ignore("Bytecode: unsupported or incorrect behavior")
@Test @Test
fun testWhenSample1() = runTest { fun testWhenSample1() = runTest {
eval( eval(