Enable smart-cast handling for expressions and add relevant tests.

This commit is contained in:
Sergey Chernov 2026-03-09 10:29:04 +03:00
parent f03006ce37
commit 90e586b5d2
2 changed files with 67 additions and 4 deletions

View File

@ -2403,13 +2403,13 @@ class Compiler(
parseKeywordStatement(t)?.let { wrapBytecode(it) } parseKeywordStatement(t)?.let { wrapBytecode(it) }
?: run { ?: run {
cc.previous() cc.previous()
parseExpression()?.let { wrapBytecode(it) } parseExpressionWithContracts()
} }
} }
Token.Type.PLUS2, Token.Type.MINUS2 -> { Token.Type.PLUS2, Token.Type.MINUS2 -> {
cc.previous() cc.previous()
parseExpression()?.let { wrapBytecode(it) } parseExpressionWithContracts()
} }
Token.Type.ATLABEL -> { Token.Type.ATLABEL -> {
@ -2441,7 +2441,7 @@ class Compiler(
Token.Type.LBRACE -> { Token.Type.LBRACE -> {
cc.previous() cc.previous()
if (braceMeansLambda) if (braceMeansLambda)
parseExpression()?.let { wrapBytecode(it) } parseExpressionWithContracts()
else else
wrapBytecode(parseBlock()) wrapBytecode(parseBlock())
} }
@ -2456,12 +2456,18 @@ class Compiler(
else -> { else -> {
// could be expression // could be expression
cc.previous() cc.previous()
parseExpression()?.let { wrapBytecode(it) } parseExpressionWithContracts()
} }
} }
} }
} }
private suspend fun parseExpressionWithContracts(): Statement? {
val expression = parseExpression() ?: return null
applyAssertSmartCasts(expression)
return wrapBytecode(expression)
}
private suspend fun parseExpression(): Statement? { private suspend fun parseExpression(): Statement? {
val pos = cc.currentPos() val pos = cc.currentPos()
return parseExpressionLevel()?.let { ref -> return parseExpressionLevel()?.let { ref ->
@ -7205,6 +7211,27 @@ class Compiler(
} }
} }
private fun applyAssertSmartCasts(statement: Statement) {
val ref = unwrapDirectRef(statement) as? CallRef ?: return
val targetName = when (val target = ref.target) {
is LocalVarRef -> target.name
is FastLocalVarRef -> target.name
is LocalSlotRef -> target.name
else -> null
} ?: return
if (targetName != "assert") return
val conditionArg = ref.args.firstOrNull { !it.isSplat && it.name == null }
?: ref.args.firstOrNull { !it.isSplat && it.name == "condition" }
?: return
val conditionStatement = when (val argValue = conditionArg.value) {
is Statement -> argValue
is ObjRef -> ExpressionStatement(argValue, conditionArg.pos)
else -> return
}
val (trueCasts, _) = extractSmartCasts(conditionStatement)
applySmartCasts(trueCasts)
}
private suspend fun parseIfStatement(): Statement { private suspend fun parseIfStatement(): Statement {
val start = ensureLparen() val start = ensureLparen()

View File

@ -21,6 +21,7 @@ import net.sergeych.lyng.ScriptError
import net.sergeych.lyng.eval import net.sergeych.lyng.eval
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertFailsWith import kotlin.test.assertFailsWith
import kotlin.test.assertTrue
class TypesTest { class TypesTest {
@ -181,6 +182,41 @@ class TypesTest {
""".trimIndent()) """.trimIndent())
} }
@Test
fun testAssertIsSmartCastEnablesMemberCall() = runTest {
eval(
"""
class Ctx {
fun println(msg: String) = msg
}
fun use(callContext) {
assert(callContext is Ctx)
callContext.println("hello")
}
assertEquals("hello", use(Ctx()))
""".trimIndent()
)
}
@Test
fun testBareIsExpressionDoesNotSmartCastForMemberCall() = runTest {
val ex = assertFailsWith<ScriptError> {
eval(
"""
class Ctx {
fun println(msg: String) = msg
}
fun use(callContext) {
callContext is Ctx
callContext.println("hello")
}
use(Ctx())
""".trimIndent()
)
}
assertTrue(ex.message?.contains("member access requires compile-time receiver type: println") == true)
}
@Test @Test
fun testListLiteralInferenceForBounds() = runTest { fun testListLiteralInferenceForBounds() = runTest {
eval(""" eval("""