Enable smart-cast handling for expressions and add relevant tests.
This commit is contained in:
parent
f03006ce37
commit
90e586b5d2
@ -2403,13 +2403,13 @@ class Compiler(
|
|||||||
parseKeywordStatement(t)?.let { wrapBytecode(it) }
|
parseKeywordStatement(t)?.let { wrapBytecode(it) }
|
||||||
?: run {
|
?: run {
|
||||||
cc.previous()
|
cc.previous()
|
||||||
parseExpression()?.let { wrapBytecode(it) }
|
parseExpressionWithContracts()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Token.Type.PLUS2, Token.Type.MINUS2 -> {
|
Token.Type.PLUS2, Token.Type.MINUS2 -> {
|
||||||
cc.previous()
|
cc.previous()
|
||||||
parseExpression()?.let { wrapBytecode(it) }
|
parseExpressionWithContracts()
|
||||||
}
|
}
|
||||||
|
|
||||||
Token.Type.ATLABEL -> {
|
Token.Type.ATLABEL -> {
|
||||||
@ -2441,7 +2441,7 @@ class Compiler(
|
|||||||
Token.Type.LBRACE -> {
|
Token.Type.LBRACE -> {
|
||||||
cc.previous()
|
cc.previous()
|
||||||
if (braceMeansLambda)
|
if (braceMeansLambda)
|
||||||
parseExpression()?.let { wrapBytecode(it) }
|
parseExpressionWithContracts()
|
||||||
else
|
else
|
||||||
wrapBytecode(parseBlock())
|
wrapBytecode(parseBlock())
|
||||||
}
|
}
|
||||||
@ -2456,12 +2456,18 @@ class Compiler(
|
|||||||
else -> {
|
else -> {
|
||||||
// could be expression
|
// could be expression
|
||||||
cc.previous()
|
cc.previous()
|
||||||
parseExpression()?.let { wrapBytecode(it) }
|
parseExpressionWithContracts()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private suspend fun parseExpressionWithContracts(): Statement? {
|
||||||
|
val expression = parseExpression() ?: return null
|
||||||
|
applyAssertSmartCasts(expression)
|
||||||
|
return wrapBytecode(expression)
|
||||||
|
}
|
||||||
|
|
||||||
private suspend fun parseExpression(): Statement? {
|
private suspend fun parseExpression(): Statement? {
|
||||||
val pos = cc.currentPos()
|
val pos = cc.currentPos()
|
||||||
return parseExpressionLevel()?.let { ref ->
|
return parseExpressionLevel()?.let { ref ->
|
||||||
@ -7205,6 +7211,27 @@ class Compiler(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private fun applyAssertSmartCasts(statement: Statement) {
|
||||||
|
val ref = unwrapDirectRef(statement) as? CallRef ?: return
|
||||||
|
val targetName = when (val target = ref.target) {
|
||||||
|
is LocalVarRef -> target.name
|
||||||
|
is FastLocalVarRef -> target.name
|
||||||
|
is LocalSlotRef -> target.name
|
||||||
|
else -> null
|
||||||
|
} ?: return
|
||||||
|
if (targetName != "assert") return
|
||||||
|
val conditionArg = ref.args.firstOrNull { !it.isSplat && it.name == null }
|
||||||
|
?: ref.args.firstOrNull { !it.isSplat && it.name == "condition" }
|
||||||
|
?: return
|
||||||
|
val conditionStatement = when (val argValue = conditionArg.value) {
|
||||||
|
is Statement -> argValue
|
||||||
|
is ObjRef -> ExpressionStatement(argValue, conditionArg.pos)
|
||||||
|
else -> return
|
||||||
|
}
|
||||||
|
val (trueCasts, _) = extractSmartCasts(conditionStatement)
|
||||||
|
applySmartCasts(trueCasts)
|
||||||
|
}
|
||||||
|
|
||||||
private suspend fun parseIfStatement(): Statement {
|
private suspend fun parseIfStatement(): Statement {
|
||||||
val start = ensureLparen()
|
val start = ensureLparen()
|
||||||
|
|
||||||
|
|||||||
@ -21,6 +21,7 @@ import net.sergeych.lyng.ScriptError
|
|||||||
import net.sergeych.lyng.eval
|
import net.sergeych.lyng.eval
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertFailsWith
|
import kotlin.test.assertFailsWith
|
||||||
|
import kotlin.test.assertTrue
|
||||||
|
|
||||||
class TypesTest {
|
class TypesTest {
|
||||||
|
|
||||||
@ -181,6 +182,41 @@ class TypesTest {
|
|||||||
""".trimIndent())
|
""".trimIndent())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testAssertIsSmartCastEnablesMemberCall() = runTest {
|
||||||
|
eval(
|
||||||
|
"""
|
||||||
|
class Ctx {
|
||||||
|
fun println(msg: String) = msg
|
||||||
|
}
|
||||||
|
fun use(callContext) {
|
||||||
|
assert(callContext is Ctx)
|
||||||
|
callContext.println("hello")
|
||||||
|
}
|
||||||
|
assertEquals("hello", use(Ctx()))
|
||||||
|
""".trimIndent()
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testBareIsExpressionDoesNotSmartCastForMemberCall() = runTest {
|
||||||
|
val ex = assertFailsWith<ScriptError> {
|
||||||
|
eval(
|
||||||
|
"""
|
||||||
|
class Ctx {
|
||||||
|
fun println(msg: String) = msg
|
||||||
|
}
|
||||||
|
fun use(callContext) {
|
||||||
|
callContext is Ctx
|
||||||
|
callContext.println("hello")
|
||||||
|
}
|
||||||
|
use(Ctx())
|
||||||
|
""".trimIndent()
|
||||||
|
)
|
||||||
|
}
|
||||||
|
assertTrue(ex.message?.contains("member access requires compile-time receiver type: println") == true)
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testListLiteralInferenceForBounds() = runTest {
|
fun testListLiteralInferenceForBounds() = runTest {
|
||||||
eval("""
|
eval("""
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user