Enable smart-cast handling for expressions and add relevant tests.
This commit is contained in:
parent
f03006ce37
commit
90e586b5d2
@ -2403,13 +2403,13 @@ class Compiler(
|
||||
parseKeywordStatement(t)?.let { wrapBytecode(it) }
|
||||
?: run {
|
||||
cc.previous()
|
||||
parseExpression()?.let { wrapBytecode(it) }
|
||||
parseExpressionWithContracts()
|
||||
}
|
||||
}
|
||||
|
||||
Token.Type.PLUS2, Token.Type.MINUS2 -> {
|
||||
cc.previous()
|
||||
parseExpression()?.let { wrapBytecode(it) }
|
||||
parseExpressionWithContracts()
|
||||
}
|
||||
|
||||
Token.Type.ATLABEL -> {
|
||||
@ -2441,7 +2441,7 @@ class Compiler(
|
||||
Token.Type.LBRACE -> {
|
||||
cc.previous()
|
||||
if (braceMeansLambda)
|
||||
parseExpression()?.let { wrapBytecode(it) }
|
||||
parseExpressionWithContracts()
|
||||
else
|
||||
wrapBytecode(parseBlock())
|
||||
}
|
||||
@ -2456,12 +2456,18 @@ class Compiler(
|
||||
else -> {
|
||||
// could be expression
|
||||
cc.previous()
|
||||
parseExpression()?.let { wrapBytecode(it) }
|
||||
parseExpressionWithContracts()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private suspend fun parseExpressionWithContracts(): Statement? {
|
||||
val expression = parseExpression() ?: return null
|
||||
applyAssertSmartCasts(expression)
|
||||
return wrapBytecode(expression)
|
||||
}
|
||||
|
||||
private suspend fun parseExpression(): Statement? {
|
||||
val pos = cc.currentPos()
|
||||
return parseExpressionLevel()?.let { ref ->
|
||||
@ -7205,6 +7211,27 @@ class Compiler(
|
||||
}
|
||||
}
|
||||
|
||||
private fun applyAssertSmartCasts(statement: Statement) {
|
||||
val ref = unwrapDirectRef(statement) as? CallRef ?: return
|
||||
val targetName = when (val target = ref.target) {
|
||||
is LocalVarRef -> target.name
|
||||
is FastLocalVarRef -> target.name
|
||||
is LocalSlotRef -> target.name
|
||||
else -> null
|
||||
} ?: return
|
||||
if (targetName != "assert") return
|
||||
val conditionArg = ref.args.firstOrNull { !it.isSplat && it.name == null }
|
||||
?: ref.args.firstOrNull { !it.isSplat && it.name == "condition" }
|
||||
?: return
|
||||
val conditionStatement = when (val argValue = conditionArg.value) {
|
||||
is Statement -> argValue
|
||||
is ObjRef -> ExpressionStatement(argValue, conditionArg.pos)
|
||||
else -> return
|
||||
}
|
||||
val (trueCasts, _) = extractSmartCasts(conditionStatement)
|
||||
applySmartCasts(trueCasts)
|
||||
}
|
||||
|
||||
private suspend fun parseIfStatement(): Statement {
|
||||
val start = ensureLparen()
|
||||
|
||||
|
||||
@ -21,6 +21,7 @@ import net.sergeych.lyng.ScriptError
|
||||
import net.sergeych.lyng.eval
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertFailsWith
|
||||
import kotlin.test.assertTrue
|
||||
|
||||
class TypesTest {
|
||||
|
||||
@ -181,6 +182,41 @@ class TypesTest {
|
||||
""".trimIndent())
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testAssertIsSmartCastEnablesMemberCall() = runTest {
|
||||
eval(
|
||||
"""
|
||||
class Ctx {
|
||||
fun println(msg: String) = msg
|
||||
}
|
||||
fun use(callContext) {
|
||||
assert(callContext is Ctx)
|
||||
callContext.println("hello")
|
||||
}
|
||||
assertEquals("hello", use(Ctx()))
|
||||
""".trimIndent()
|
||||
)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testBareIsExpressionDoesNotSmartCastForMemberCall() = runTest {
|
||||
val ex = assertFailsWith<ScriptError> {
|
||||
eval(
|
||||
"""
|
||||
class Ctx {
|
||||
fun println(msg: String) = msg
|
||||
}
|
||||
fun use(callContext) {
|
||||
callContext is Ctx
|
||||
callContext.println("hello")
|
||||
}
|
||||
use(Ctx())
|
||||
""".trimIndent()
|
||||
)
|
||||
}
|
||||
assertTrue(ex.message?.contains("member access requires compile-time receiver type: println") == true)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testListLiteralInferenceForBounds() = runTest {
|
||||
eval("""
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user