fix #59 implement iterator cancellation on premature termination (break, exception) and ensure no cancellation on natural completion; add tests
This commit is contained in:
parent
8fae4709ed
commit
cbca8cacb5
@ -34,6 +34,7 @@ class Compiler(
|
|||||||
) {
|
) {
|
||||||
|
|
||||||
// Stack of parameter-to-slot plans for current function being parsed (by declaration index)
|
// Stack of parameter-to-slot plans for current function being parsed (by declaration index)
|
||||||
|
@Suppress("unused")
|
||||||
private val paramSlotPlanStack = mutableListOf<Map<String, Int>>()
|
private val paramSlotPlanStack = mutableListOf<Map<String, Int>>()
|
||||||
// private val currentParamSlotPlan: Map<String, Int>?
|
// private val currentParamSlotPlan: Map<String, Int>?
|
||||||
// get() = paramSlotPlanStack.lastOrNull()
|
// get() = paramSlotPlanStack.lastOrNull()
|
||||||
@ -727,7 +728,7 @@ class Compiler(
|
|||||||
|
|
||||||
// Commit to map literal parsing
|
// Commit to map literal parsing
|
||||||
cc.skipWsTokens()
|
cc.skipWsTokens()
|
||||||
val entries = mutableListOf<net.sergeych.lyng.obj.MapLiteralEntry>()
|
val entries = mutableListOf<MapLiteralEntry>()
|
||||||
val usedKeys = mutableSetOf<String>()
|
val usedKeys = mutableSetOf<String>()
|
||||||
|
|
||||||
while (true) {
|
while (true) {
|
||||||
@ -736,7 +737,7 @@ class Compiler(
|
|||||||
when (t0.type) {
|
when (t0.type) {
|
||||||
Token.Type.RBRACE -> {
|
Token.Type.RBRACE -> {
|
||||||
// end of map literal
|
// end of map literal
|
||||||
return net.sergeych.lyng.obj.MapLiteralRef(entries)
|
return MapLiteralRef(entries)
|
||||||
}
|
}
|
||||||
Token.Type.COMMA -> {
|
Token.Type.COMMA -> {
|
||||||
// allow stray commas; continue
|
// allow stray commas; continue
|
||||||
@ -745,7 +746,7 @@ class Compiler(
|
|||||||
Token.Type.ELLIPSIS -> {
|
Token.Type.ELLIPSIS -> {
|
||||||
// spread element: ... expression
|
// spread element: ... expression
|
||||||
val expr = parseExpressionLevel() ?: throw ScriptError(t0.pos, "invalid map spread: expecting expression")
|
val expr = parseExpressionLevel() ?: throw ScriptError(t0.pos, "invalid map spread: expecting expression")
|
||||||
entries += net.sergeych.lyng.obj.MapLiteralEntry.Spread(expr)
|
entries += MapLiteralEntry.Spread(expr)
|
||||||
// Expect comma or '}' next; loop will handle
|
// Expect comma or '}' next; loop will handle
|
||||||
}
|
}
|
||||||
Token.Type.STRING, Token.Type.ID -> {
|
Token.Type.STRING, Token.Type.ID -> {
|
||||||
@ -769,14 +770,14 @@ class Compiler(
|
|||||||
if (next.type == Token.Type.RBRACE) cc.previous()
|
if (next.type == Token.Type.RBRACE) cc.previous()
|
||||||
// Duplicate detection for literals only
|
// Duplicate detection for literals only
|
||||||
if (!usedKeys.add(keyName)) throw ScriptError(t0.pos, "duplicate key '$keyName'")
|
if (!usedKeys.add(keyName)) throw ScriptError(t0.pos, "duplicate key '$keyName'")
|
||||||
entries += net.sergeych.lyng.obj.MapLiteralEntry.Named(keyName, net.sergeych.lyng.obj.LocalVarRef(keyName, t0.pos))
|
entries += MapLiteralEntry.Named(keyName, LocalVarRef(keyName, t0.pos))
|
||||||
// If the token was COMMA, the loop continues; if it's RBRACE, next iteration will end
|
// If the token was COMMA, the loop continues; if it's RBRACE, next iteration will end
|
||||||
} else {
|
} else {
|
||||||
// There is a value expression: push back token and parse expression
|
// There is a value expression: push back token and parse expression
|
||||||
cc.previous()
|
cc.previous()
|
||||||
val valueRef = parseExpressionLevel() ?: throw ScriptError(colon.pos, "expecting map entry value")
|
val valueRef = parseExpressionLevel() ?: throw ScriptError(colon.pos, "expecting map entry value")
|
||||||
if (!usedKeys.add(keyName)) throw ScriptError(t0.pos, "duplicate key '$keyName'")
|
if (!usedKeys.add(keyName)) throw ScriptError(t0.pos, "duplicate key '$keyName'")
|
||||||
entries += net.sergeych.lyng.obj.MapLiteralEntry.Named(keyName, valueRef)
|
entries += MapLiteralEntry.Named(keyName, valueRef)
|
||||||
// After value, allow optional comma; do not require it
|
// After value, allow optional comma; do not require it
|
||||||
cc.skipTokenOfType(Token.Type.COMMA, isOptional = true)
|
cc.skipTokenOfType(Token.Type.COMMA, isOptional = true)
|
||||||
// The loop will continue and eventually see '}'
|
// The loop will continue and eventually see '}'
|
||||||
@ -893,6 +894,7 @@ class Compiler(
|
|||||||
return ArgsDeclaration(result, endTokenType)
|
return ArgsDeclaration(result, endTokenType)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Suppress("unused")
|
||||||
private fun parseTypeDeclaration(): TypeDecl {
|
private fun parseTypeDeclaration(): TypeDecl {
|
||||||
return parseTypeDeclarationWithMini().first
|
return parseTypeDeclarationWithMini().first
|
||||||
}
|
}
|
||||||
@ -2052,6 +2054,8 @@ class Compiler(
|
|||||||
): Obj {
|
): Obj {
|
||||||
val iterObj = sourceObj.invokeInstanceMethod(forScope, "iterator")
|
val iterObj = sourceObj.invokeInstanceMethod(forScope, "iterator")
|
||||||
var result: Obj = ObjVoid
|
var result: Obj = ObjVoid
|
||||||
|
var completedNaturally = false
|
||||||
|
try {
|
||||||
while (iterObj.invokeInstanceMethod(forScope, "hasNext").toBool()) {
|
while (iterObj.invokeInstanceMethod(forScope, "hasNext").toBool()) {
|
||||||
if (catchBreak)
|
if (catchBreak)
|
||||||
try {
|
try {
|
||||||
@ -2060,6 +2064,7 @@ class Compiler(
|
|||||||
} catch (lbe: LoopBreakContinueException) {
|
} catch (lbe: LoopBreakContinueException) {
|
||||||
if (lbe.label == label || lbe.label == null) {
|
if (lbe.label == label || lbe.label == null) {
|
||||||
if (lbe.doContinue) continue
|
if (lbe.doContinue) continue
|
||||||
|
// premature finish, will trigger cancel in finally
|
||||||
return lbe.result
|
return lbe.result
|
||||||
}
|
}
|
||||||
throw lbe
|
throw lbe
|
||||||
@ -2069,7 +2074,16 @@ class Compiler(
|
|||||||
result = body.execute(forScope)
|
result = body.execute(forScope)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
completedNaturally = true
|
||||||
return elseStatement?.execute(forScope) ?: result
|
return elseStatement?.execute(forScope) ?: result
|
||||||
|
} finally {
|
||||||
|
if (!completedNaturally) {
|
||||||
|
// Best-effort cancellation on premature termination
|
||||||
|
runCatching {
|
||||||
|
iterObj.invokeInstanceMethod(forScope, "cancelIteration") { ObjVoid }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Suppress("UNUSED_VARIABLE")
|
@Suppress("UNUSED_VARIABLE")
|
||||||
@ -2383,7 +2397,7 @@ class Compiler(
|
|||||||
}
|
}
|
||||||
fnStatements.execute(context)
|
fnStatements.execute(context)
|
||||||
}
|
}
|
||||||
parentContext
|
// parentContext
|
||||||
val fnCreateStatement = statement(start) { context ->
|
val fnCreateStatement = statement(start) { context ->
|
||||||
// we added fn in the context. now we must save closure
|
// we added fn in the context. now we must save closure
|
||||||
// for the function, unless we're in the class scope:
|
// for the function, unless we're in the class scope:
|
||||||
|
|||||||
@ -87,13 +87,25 @@ suspend fun Obj.enumerate(scope: Scope, callback: suspend (Obj) -> Boolean) {
|
|||||||
val hasNext = iterator.getInstanceMethod(scope, "hasNext")
|
val hasNext = iterator.getInstanceMethod(scope, "hasNext")
|
||||||
val next = iterator.getInstanceMethod(scope, "next")
|
val next = iterator.getInstanceMethod(scope, "next")
|
||||||
var closeIt = false
|
var closeIt = false
|
||||||
|
try {
|
||||||
while (hasNext.invoke(scope, iterator).toBool()) {
|
while (hasNext.invoke(scope, iterator).toBool()) {
|
||||||
val nextValue = next.invoke(scope, iterator)
|
val nextValue = next.invoke(scope, iterator)
|
||||||
if (!callback(nextValue)) {
|
val shouldContinue = try {
|
||||||
|
callback(nextValue)
|
||||||
|
} catch (e: Exception) {
|
||||||
|
// iteration aborted due to exception in callback
|
||||||
|
closeIt = true
|
||||||
|
throw e
|
||||||
|
}
|
||||||
|
if (!shouldContinue) {
|
||||||
closeIt = true
|
closeIt = true
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (closeIt)
|
} finally {
|
||||||
|
if (closeIt) {
|
||||||
|
// Best-effort cancel on premature termination
|
||||||
iterator.invokeInstanceMethod(scope, "cancelIteration") { ObjVoid }
|
iterator.invokeInstanceMethod(scope, "cancelIteration") { ObjVoid }
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -37,6 +37,97 @@ class ScriptTest {
|
|||||||
println("version = ${LyngVersion}")
|
println("version = ${LyngVersion}")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// --- Helpers to test iterator cancellation semantics ---
|
||||||
|
class ObjTestIterable : Obj() {
|
||||||
|
|
||||||
|
var cancelCount: Int = 0
|
||||||
|
|
||||||
|
override val objClass: ObjClass = type
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
val type = ObjClass("TestIterable", ObjIterable).apply {
|
||||||
|
addFn("iterator") {
|
||||||
|
ObjTestIterator(thisAs<ObjTestIterable>())
|
||||||
|
}
|
||||||
|
addFn("cancelCount") { thisAs<ObjTestIterable>().cancelCount.toObj() }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
class ObjTestIterator(private val owner: ObjTestIterable) : Obj() {
|
||||||
|
override val objClass: ObjClass = type
|
||||||
|
private var i = 0
|
||||||
|
|
||||||
|
private fun hasNext(): Boolean = i < 5
|
||||||
|
private fun next(): Obj = ObjInt((++i).toLong())
|
||||||
|
private fun cancelIteration() {
|
||||||
|
owner.cancelCount += 1
|
||||||
|
}
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
val type = ObjClass("TestIterator", ObjIterator).apply {
|
||||||
|
addFn("hasNext") { thisAs<ObjTestIterator>().hasNext().toObj() }
|
||||||
|
addFn("next") { thisAs<ObjTestIterator>().next() }
|
||||||
|
addFn("cancelIteration") {
|
||||||
|
thisAs<ObjTestIterator>().cancelIteration()
|
||||||
|
ObjVoid
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testForLoopDoesNotCancelOnNaturalCompletion() = runTest {
|
||||||
|
val scope = Script.newScope()
|
||||||
|
val ti = ObjTestIterable()
|
||||||
|
scope.addConst("ti", ti)
|
||||||
|
scope.eval(
|
||||||
|
"""
|
||||||
|
var s = 0
|
||||||
|
for( i in ti ) {
|
||||||
|
s += i
|
||||||
|
}
|
||||||
|
s
|
||||||
|
""".trimIndent()
|
||||||
|
)
|
||||||
|
assertEquals(0, ti.cancelCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testForLoopCancelsOnBreak() = runTest {
|
||||||
|
val scope = Script.newScope()
|
||||||
|
val ti = ObjTestIterable()
|
||||||
|
scope.addConst("ti", ti)
|
||||||
|
scope.eval(
|
||||||
|
"""
|
||||||
|
for( i in ti ) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
""".trimIndent()
|
||||||
|
)
|
||||||
|
assertEquals(1, ti.cancelCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testForLoopCancelsOnException() = runTest {
|
||||||
|
val scope = Script.newScope()
|
||||||
|
val ti = ObjTestIterable()
|
||||||
|
scope.addConst("ti", ti)
|
||||||
|
try {
|
||||||
|
scope.eval(
|
||||||
|
"""
|
||||||
|
for( i in ti ) {
|
||||||
|
throw "boom"
|
||||||
|
}
|
||||||
|
""".trimIndent()
|
||||||
|
)
|
||||||
|
fail("Exception expected")
|
||||||
|
} catch (_: Exception) {
|
||||||
|
// ignore
|
||||||
|
}
|
||||||
|
assertEquals(1, ti.cancelCount)
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun parseNewlines() {
|
fun parseNewlines() {
|
||||||
fun check(expected: String, type: Token.Type, row: Int, col: Int, src: String, offset: Int = 0) {
|
fun check(expected: String, type: Token.Type, row: Int, col: Int, src: String, offset: Int = 0) {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user