From 90e586b5d2259c457d80b571c26afe58aea2f17f Mon Sep 17 00:00:00 2001 From: sergeych Date: Mon, 9 Mar 2026 10:29:04 +0300 Subject: [PATCH] Enable smart-cast handling for expressions and add relevant tests. --- .../kotlin/net/sergeych/lyng/Compiler.kt | 35 +++++++++++++++--- lynglib/src/commonTest/kotlin/TypesTest.kt | 36 +++++++++++++++++++ 2 files changed, 67 insertions(+), 4 deletions(-) diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Compiler.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Compiler.kt index 0065794..045822b 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Compiler.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Compiler.kt @@ -2403,13 +2403,13 @@ class Compiler( parseKeywordStatement(t)?.let { wrapBytecode(it) } ?: run { cc.previous() - parseExpression()?.let { wrapBytecode(it) } + parseExpressionWithContracts() } } Token.Type.PLUS2, Token.Type.MINUS2 -> { cc.previous() - parseExpression()?.let { wrapBytecode(it) } + parseExpressionWithContracts() } Token.Type.ATLABEL -> { @@ -2441,7 +2441,7 @@ class Compiler( Token.Type.LBRACE -> { cc.previous() if (braceMeansLambda) - parseExpression()?.let { wrapBytecode(it) } + parseExpressionWithContracts() else wrapBytecode(parseBlock()) } @@ -2456,12 +2456,18 @@ class Compiler( else -> { // could be expression cc.previous() - parseExpression()?.let { wrapBytecode(it) } + parseExpressionWithContracts() } } } } + private suspend fun parseExpressionWithContracts(): Statement? { + val expression = parseExpression() ?: return null + applyAssertSmartCasts(expression) + return wrapBytecode(expression) + } + private suspend fun parseExpression(): Statement? { val pos = cc.currentPos() return parseExpressionLevel()?.let { ref -> @@ -7205,6 +7211,27 @@ class Compiler( } } + private fun applyAssertSmartCasts(statement: Statement) { + val ref = unwrapDirectRef(statement) as? CallRef ?: return + val targetName = when (val target = ref.target) { + is LocalVarRef -> target.name + is FastLocalVarRef -> target.name + is LocalSlotRef -> target.name + else -> null + } ?: return + if (targetName != "assert") return + val conditionArg = ref.args.firstOrNull { !it.isSplat && it.name == null } + ?: ref.args.firstOrNull { !it.isSplat && it.name == "condition" } + ?: return + val conditionStatement = when (val argValue = conditionArg.value) { + is Statement -> argValue + is ObjRef -> ExpressionStatement(argValue, conditionArg.pos) + else -> return + } + val (trueCasts, _) = extractSmartCasts(conditionStatement) + applySmartCasts(trueCasts) + } + private suspend fun parseIfStatement(): Statement { val start = ensureLparen() diff --git a/lynglib/src/commonTest/kotlin/TypesTest.kt b/lynglib/src/commonTest/kotlin/TypesTest.kt index 7924e8e..906bf86 100644 --- a/lynglib/src/commonTest/kotlin/TypesTest.kt +++ b/lynglib/src/commonTest/kotlin/TypesTest.kt @@ -21,6 +21,7 @@ import net.sergeych.lyng.ScriptError import net.sergeych.lyng.eval import kotlin.test.Test import kotlin.test.assertFailsWith +import kotlin.test.assertTrue class TypesTest { @@ -181,6 +182,41 @@ class TypesTest { """.trimIndent()) } + @Test + fun testAssertIsSmartCastEnablesMemberCall() = runTest { + eval( + """ + class Ctx { + fun println(msg: String) = msg + } + fun use(callContext) { + assert(callContext is Ctx) + callContext.println("hello") + } + assertEquals("hello", use(Ctx())) + """.trimIndent() + ) + } + + @Test + fun testBareIsExpressionDoesNotSmartCastForMemberCall() = runTest { + val ex = assertFailsWith { + eval( + """ + class Ctx { + fun println(msg: String) = msg + } + fun use(callContext) { + callContext is Ctx + callContext.println("hello") + } + use(Ctx()) + """.trimIndent() + ) + } + assertTrue(ex.message?.contains("member access requires compile-time receiver type: println") == true) + } + @Test fun testListLiteralInferenceForBounds() = runTest { eval("""