Enable smart-cast handling for expressions and add relevant tests.

This commit is contained in:
Sergey Chernov 2026-03-09 10:29:04 +03:00
parent f03006ce37
commit 90e586b5d2
2 changed files with 67 additions and 4 deletions

View File

@ -2403,13 +2403,13 @@ class Compiler(
parseKeywordStatement(t)?.let { wrapBytecode(it) }
?: run {
cc.previous()
parseExpression()?.let { wrapBytecode(it) }
parseExpressionWithContracts()
}
}
Token.Type.PLUS2, Token.Type.MINUS2 -> {
cc.previous()
parseExpression()?.let { wrapBytecode(it) }
parseExpressionWithContracts()
}
Token.Type.ATLABEL -> {
@ -2441,7 +2441,7 @@ class Compiler(
Token.Type.LBRACE -> {
cc.previous()
if (braceMeansLambda)
parseExpression()?.let { wrapBytecode(it) }
parseExpressionWithContracts()
else
wrapBytecode(parseBlock())
}
@ -2456,12 +2456,18 @@ class Compiler(
else -> {
// could be expression
cc.previous()
parseExpression()?.let { wrapBytecode(it) }
parseExpressionWithContracts()
}
}
}
}
private suspend fun parseExpressionWithContracts(): Statement? {
val expression = parseExpression() ?: return null
applyAssertSmartCasts(expression)
return wrapBytecode(expression)
}
private suspend fun parseExpression(): Statement? {
val pos = cc.currentPos()
return parseExpressionLevel()?.let { ref ->
@ -7205,6 +7211,27 @@ class Compiler(
}
}
private fun applyAssertSmartCasts(statement: Statement) {
val ref = unwrapDirectRef(statement) as? CallRef ?: return
val targetName = when (val target = ref.target) {
is LocalVarRef -> target.name
is FastLocalVarRef -> target.name
is LocalSlotRef -> target.name
else -> null
} ?: return
if (targetName != "assert") return
val conditionArg = ref.args.firstOrNull { !it.isSplat && it.name == null }
?: ref.args.firstOrNull { !it.isSplat && it.name == "condition" }
?: return
val conditionStatement = when (val argValue = conditionArg.value) {
is Statement -> argValue
is ObjRef -> ExpressionStatement(argValue, conditionArg.pos)
else -> return
}
val (trueCasts, _) = extractSmartCasts(conditionStatement)
applySmartCasts(trueCasts)
}
private suspend fun parseIfStatement(): Statement {
val start = ensureLparen()

View File

@ -21,6 +21,7 @@ import net.sergeych.lyng.ScriptError
import net.sergeych.lyng.eval
import kotlin.test.Test
import kotlin.test.assertFailsWith
import kotlin.test.assertTrue
class TypesTest {
@ -181,6 +182,41 @@ class TypesTest {
""".trimIndent())
}
@Test
fun testAssertIsSmartCastEnablesMemberCall() = runTest {
eval(
"""
class Ctx {
fun println(msg: String) = msg
}
fun use(callContext) {
assert(callContext is Ctx)
callContext.println("hello")
}
assertEquals("hello", use(Ctx()))
""".trimIndent()
)
}
@Test
fun testBareIsExpressionDoesNotSmartCastForMemberCall() = runTest {
val ex = assertFailsWith<ScriptError> {
eval(
"""
class Ctx {
fun println(msg: String) = msg
}
fun use(callContext) {
callContext is Ctx
callContext.println("hello")
}
use(Ctx())
""".trimIndent()
)
}
assertTrue(ex.message?.contains("member access requires compile-time receiver type: println") == true)
}
@Test
fun testListLiteralInferenceForBounds() = runTest {
eval("""