Improve Set handling with type-aware operations and enhance generic function parsing logic

This commit is contained in:
Sergey Chernov 2026-03-12 19:08:29 +03:00
parent 32a7adf56e
commit 3bfb80a7c1
4 changed files with 210 additions and 9 deletions

View File

@ -169,6 +169,7 @@ class Compiler(
) )
private val typeAliases: MutableMap<String, TypeAliasDecl> = mutableMapOf() private val typeAliases: MutableMap<String, TypeAliasDecl> = mutableMapOf()
private val methodReturnTypeDeclByRef: MutableMap<ObjRef, TypeDecl> = mutableMapOf() private val methodReturnTypeDeclByRef: MutableMap<ObjRef, TypeDecl> = mutableMapOf()
private val callReturnTypeDeclByRef: MutableMap<CallRef, TypeDecl> = mutableMapOf()
private val callableReturnTypeByScopeId: MutableMap<Int, MutableMap<Int, ObjClass>> = mutableMapOf() private val callableReturnTypeByScopeId: MutableMap<Int, MutableMap<Int, ObjClass>> = mutableMapOf()
private val callableReturnTypeByName: MutableMap<String, ObjClass> = mutableMapOf() private val callableReturnTypeByName: MutableMap<String, ObjClass> = mutableMapOf()
private val lambdaReturnTypeByRef: MutableMap<ObjRef, ObjClass> = mutableMapOf() private val lambdaReturnTypeByRef: MutableMap<ObjRef, ObjClass> = mutableMapOf()
@ -2544,6 +2545,7 @@ class Compiler(
private suspend fun parseTerm(): ObjRef? { private suspend fun parseTerm(): ObjRef? {
var operand: ObjRef? = null var operand: ObjRef? = null
var pendingCallTypeArgs: List<TypeDecl>? = null
// newlines _before_ // newlines _before_
cc.skipWsTokens() cc.skipWsTokens()
@ -2791,20 +2793,35 @@ class Compiler(
operand = parseScopeOperator(operand) operand = parseScopeOperator(operand)
} }
Token.Type.LT -> {
val parsedTypeArgs = operand
?.takeIf { isGenericCallCalleeCandidate(it) }
?.let { tryParseCallTypeArgsAfterLt() }
if (parsedTypeArgs != null) {
pendingCallTypeArgs = parsedTypeArgs
continue
}
cc.previous()
return operand
}
Token.Type.LPAREN, Token.Type.NULL_COALESCE_INVOKE -> { Token.Type.LPAREN, Token.Type.NULL_COALESCE_INVOKE -> {
operand?.let { left -> operand?.let { left ->
// this is function call from <left> // this is function call from <left>
operand = parseFunctionCall( operand = parseFunctionCall(
left, left,
false, false,
t.type == Token.Type.NULL_COALESCE_INVOKE t.type == Token.Type.NULL_COALESCE_INVOKE,
pendingCallTypeArgs
) )
pendingCallTypeArgs = null
} ?: run { } ?: run {
// Expression in parentheses // Expression in parentheses
val statement = parseStatement() ?: throw ScriptError(t.pos, "Expecting expression") val statement = parseStatement() ?: throw ScriptError(t.pos, "Expecting expression")
operand = StatementRef(statement) operand = StatementRef(statement)
cc.skipTokenOfType(Token.Type.NEWLINE, isOptional = true) cc.skipTokenOfType(Token.Type.NEWLINE, isOptional = true)
cc.skipTokenOfType(Token.Type.RPAREN, "missing ')'") cc.skipTokenOfType(Token.Type.RPAREN, "missing ')'")
pendingCallTypeArgs = null
} }
} }
@ -2984,7 +3001,8 @@ class Compiler(
parseFunctionCall( parseFunctionCall(
left, left,
blockArgument = true, blockArgument = true,
isOptional = t.type == Token.Type.NULL_COALESCE_BLOCKINVOKE isOptional = t.type == Token.Type.NULL_COALESCE_BLOCKINVOKE,
explicitTypeArgs = pendingCallTypeArgs
) )
} ?: run { } ?: run {
// Disambiguate between lambda and map literal. // Disambiguate between lambda and map literal.
@ -3011,6 +3029,54 @@ class Compiler(
} }
} }
private suspend fun tryParseCallTypeArgsAfterLt(): List<TypeDecl>? {
val savedAfterLt = cc.savePos()
return try {
val args = mutableListOf<TypeDecl>()
do {
val (argSem, _) = parseTypeExpressionWithMini()
args += argSem
val sep = cc.next()
when (sep.type) {
Token.Type.COMMA -> continue
Token.Type.GT -> break
Token.Type.SHR -> {
cc.pushPendingGT()
break
}
else -> {
cc.restorePos(savedAfterLt)
return null
}
}
} while (true)
val nextType = cc.peekNextNonWhitespace().type
if (nextType != Token.Type.LPAREN && nextType != Token.Type.NULL_COALESCE_INVOKE) {
cc.restorePos(savedAfterLt)
return null
}
args
} catch (_: ScriptError) {
cc.restorePos(savedAfterLt)
null
}
}
private fun isGenericCallCalleeCandidate(ref: ObjRef): Boolean {
val name = when (ref) {
is LocalVarRef -> ref.name
is FastLocalVarRef -> ref.name
is LocalSlotRef -> ref.name
else -> null
}
if (name != null) {
if (lookupGenericFunctionDecl(name) != null) return true
if (name.firstOrNull()?.isUpperCase() == true) return true
return false
}
return ref is ConstRef && ref.constValue is ObjClass
}
/** /**
* Parse lambda expression, leading '{' is already consumed * Parse lambda expression, leading '{' is already consumed
*/ */
@ -4369,6 +4435,7 @@ class Compiler(
} }
} }
is MethodCallRef -> methodReturnTypeDeclByRef[ref] is MethodCallRef -> methodReturnTypeDeclByRef[ref]
is CallRef -> callReturnTypeDeclByRef[ref]
is StatementRef -> (ref.statement as? ExpressionStatement)?.let { resolveReceiverTypeDecl(it.ref) } is StatementRef -> (ref.statement as? ExpressionStatement)?.let { resolveReceiverTypeDecl(it.ref) }
else -> null else -> null
} }
@ -5407,7 +5474,8 @@ class Compiler(
private suspend fun parseFunctionCall( private suspend fun parseFunctionCall(
left: ObjRef, left: ObjRef,
blockArgument: Boolean, blockArgument: Boolean,
isOptional: Boolean isOptional: Boolean,
explicitTypeArgs: List<TypeDecl>? = null
): ObjRef { ): ObjRef {
var detectedBlockArgument = blockArgument var detectedBlockArgument = blockArgument
val expectedReceiver = tailBlockReceiverType(left) val expectedReceiver = tailBlockReceiverType(left)
@ -5448,7 +5516,9 @@ class Compiler(
val result = when (left) { val result = when (left) {
is ImplicitThisMemberRef -> is ImplicitThisMemberRef ->
if (left.methodId == null && left.fieldId != null) { if (left.methodId == null && left.fieldId != null) {
CallRef(left, args, detectedBlockArgument, isOptional) CallRef(left, args, detectedBlockArgument, isOptional).also { callRef ->
applyExplicitCallTypeArgs(callRef, explicitTypeArgs)
}
} else { } else {
ImplicitThisMethodCallRef( ImplicitThisMethodCallRef(
left.name, left.name,
@ -5481,7 +5551,9 @@ class Compiler(
checkFunctionTypeCallArity(left, args, left.pos()) checkFunctionTypeCallArity(left, args, left.pos())
checkFunctionTypeCallTypes(left, args, left.pos()) checkFunctionTypeCallTypes(left, args, left.pos())
checkGenericBoundsAtCall(left.name, args, left.pos()) checkGenericBoundsAtCall(left.name, args, left.pos())
CallRef(left, args, detectedBlockArgument, isOptional) CallRef(left, args, detectedBlockArgument, isOptional).also { callRef ->
applyExplicitCallTypeArgs(callRef, explicitTypeArgs)
}
} }
} }
is LocalSlotRef -> { is LocalSlotRef -> {
@ -5505,14 +5577,30 @@ class Compiler(
checkFunctionTypeCallArity(left, args, left.pos()) checkFunctionTypeCallArity(left, args, left.pos())
checkFunctionTypeCallTypes(left, args, left.pos()) checkFunctionTypeCallTypes(left, args, left.pos())
checkGenericBoundsAtCall(left.name, args, left.pos()) checkGenericBoundsAtCall(left.name, args, left.pos())
CallRef(left, args, detectedBlockArgument, isOptional) CallRef(left, args, detectedBlockArgument, isOptional).also { callRef ->
applyExplicitCallTypeArgs(callRef, explicitTypeArgs)
} }
} }
else -> CallRef(left, args, detectedBlockArgument, isOptional) }
else -> CallRef(left, args, detectedBlockArgument, isOptional).also { callRef ->
applyExplicitCallTypeArgs(callRef, explicitTypeArgs)
}
} }
return result return result
} }
private fun applyExplicitCallTypeArgs(callRef: CallRef, explicitTypeArgs: List<TypeDecl>?) {
if (explicitTypeArgs.isNullOrEmpty()) return
val baseName = when (val target = callRef.target) {
is LocalVarRef -> target.name
is FastLocalVarRef -> target.name
is LocalSlotRef -> target.name
is ConstRef -> (target.constValue as? ObjClass)?.className
else -> null
} ?: return
callReturnTypeDeclByRef[callRef] = TypeDecl.Generic(baseName, explicitTypeArgs, isNullable = false)
}
private fun inferReceiverTypeFromArgs(args: List<ParsedArgument>): String? { private fun inferReceiverTypeFromArgs(args: List<ParsedArgument>): String? {
val stmt = args.firstOrNull()?.value as? ExpressionStatement ?: return null val stmt = args.firstOrNull()?.value as? ExpressionStatement ?: return null
val ref = stmt.ref val ref = stmt.ref

View File

@ -471,6 +471,52 @@ open class Scope(
} }
} }
private fun resolvedRecordValueOrNull(record: ObjRecord): Obj? {
return when (val raw = record.value) {
is FrameSlotRef -> raw.read()
is RecordSlotRef -> raw.read()
else -> raw
}
}
private fun declaredTypeForValueInThisScope(value: Obj): TypeDecl? {
// Prefer direct bindings first.
for (record in objects.values) {
val decl = record.typeDecl ?: continue
if (resolvedRecordValueOrNull(record) === value) return decl
}
for ((_, record) in localBindings) {
val decl = record.typeDecl ?: continue
if (resolvedRecordValueOrNull(record) === value) return decl
}
// Then slots (for frame-first locals).
var i = 0
while (i < slots.size) {
val record = slots[i]
val decl = record.typeDecl
if (decl != null && resolvedRecordValueOrNull(record) === value) return decl
i++
}
return null
}
/**
* Best-effort lookup of the declared Set element type for a runtime set instance.
* Returns null when type info is unavailable.
*/
fun declaredSetElementTypeForValue(value: Obj): TypeDecl? {
var s: Scope? = this
var hops = 0
while (s != null && hops++ < 1024) {
val decl = s.declaredTypeForValueInThisScope(value)
if (decl is TypeDecl.Generic && decl.name.substringAfterLast('.') == "Set") {
return decl.args.firstOrNull()
}
s = s.parent
}
return null
}
internal fun applySlotPlanReset(plan: Map<String, Int>, records: Map<String, ObjRecord>) { internal fun applySlotPlanReset(plan: Map<String, Int>, records: Map<String, ObjRecord>) {
if (plan.isEmpty()) return if (plan.isEmpty()) return
slots.clear() slots.clear()

View File

@ -27,6 +27,16 @@ import net.sergeych.lynon.LynonEncoder
import net.sergeych.lynon.LynonType import net.sergeych.lynon.LynonType
class ObjSet(val set: MutableSet<Obj> = mutableSetOf()) : Obj() { class ObjSet(val set: MutableSet<Obj> = mutableSetOf()) : Obj() {
private fun shouldTreatAsSingleElement(scope: Scope, other: Obj): Boolean {
if (!other.isInstanceOf(ObjIterable)) return true
val declaredElementType = scope.declaredSetElementTypeForValue(this)
if (declaredElementType != null && matchesTypeDecl(scope, other, declaredElementType)) {
return true
}
// Strings and buffers are iterable but usually expected to be atomic values for set +/- operators.
if (other is ObjString || other is ObjBuffer) return true
return false
}
override suspend fun equals(scope: Scope, other: Obj): Boolean { override suspend fun equals(scope: Scope, other: Obj): Boolean {
if (this === other) return true if (this === other) return true
@ -53,6 +63,9 @@ class ObjSet(val set: MutableSet<Obj> = mutableSetOf()) : Obj() {
} }
override suspend fun plus(scope: Scope, other: Obj): Obj { override suspend fun plus(scope: Scope, other: Obj): Obj {
if (shouldTreatAsSingleElement(scope, other)) {
return ObjSet((set + other).toMutableSet())
}
return ObjSet( return ObjSet(
if (other is ObjSet) if (other is ObjSet)
(set + other.set).toMutableSet() (set + other.set).toMutableSet()
@ -73,6 +86,10 @@ class ObjSet(val set: MutableSet<Obj> = mutableSetOf()) : Obj() {
} }
override suspend fun plusAssign(scope: Scope, other: Obj): Obj { override suspend fun plusAssign(scope: Scope, other: Obj): Obj {
if (shouldTreatAsSingleElement(scope, other)) {
set += other
return this
}
when (other) { when (other) {
is ObjSet -> { is ObjSet -> {
set += other.set set += other.set
@ -105,6 +122,9 @@ class ObjSet(val set: MutableSet<Obj> = mutableSetOf()) : Obj() {
} }
override suspend fun minus(scope: Scope, other: Obj): Obj { override suspend fun minus(scope: Scope, other: Obj): Obj {
if (shouldTreatAsSingleElement(scope, other)) {
return ObjSet((set - other).toMutableSet())
}
return when { return when {
other is ObjSet -> ObjSet(set.minus(other.set).toMutableSet()) other is ObjSet -> ObjSet(set.minus(other.set).toMutableSet())
other.isInstanceOf(ObjIterable) -> { other.isInstanceOf(ObjIterable) -> {
@ -115,8 +135,7 @@ class ObjSet(val set: MutableSet<Obj> = mutableSetOf()) : Obj() {
} }
ObjSet((set - otherSet).toMutableSet()) ObjSet((set - otherSet).toMutableSet())
} }
else -> else -> ObjSet((set - other).toMutableSet())
scope.raiseIllegalArgument("set operator - requires another set or Iterable")
} }
} }

View File

@ -399,4 +399,52 @@ class TypesTest {
fun testOk5() { l4(1, "a", "b", "x") } fun testOk5() { l4(1, "a", "b", "x") }
""".trimIndent()) """.trimIndent())
} }
@Test
fun testSetTyped() = runTest {
eval("""
var s = Set<String>()
val typed: Set<String> = s
assertEquals(Set(), typed)
s += "foo"
assertEquals(Set("foo"), s)
s -= "foo"
assertEquals(Set(), s)
s += ["foo", "bar"]
assertEquals(Set("foo", "bar"), s)
""".trimIndent())
}
// @Test
// fun testAliasesInGenerics1() = runTest {
// val scope = Script.newScope()
// scope.eval("""
// type IntList<T: Int> = List<T>
// type IntMap<K,V> = Map<K,V>
// type IntSet<T: Int> = Set<T>
// type IntPair<T: Int> = Pair<T,T>
// type IntTriple<T: Int> = Triple<T,T,T>
// type IntQuad<T: Int> = Quad<T,T,T,T>
//
// import lyng.buffer
// type Tag = String | Buffer
//
// class X {
// var tags: Set<Tag> = Set()
// }
// val x = X()
// x.tags += "tag1"
// assertEquals(Set("tag1"), x.tags)
// x.tags += "tag2"
// assertEquals(Set("tag1", "tag2"), x.tags)
// x.tags += Buffer("tag3")
// assertEquals(Set("tag1", "tag2", Buffer("tag3")), x.tags)
// x.tags += Buffer("tag4")
// assertEquals(Set("tag1", "tag2", Buffer("tag3"), Buffer("tag4")), x.tags)
// x.tags += "tag3"
// x.tags += "tag4"
// assertEquals(Set("tag1", "tag2", Buffer("tag3"), Buffer("tag4")), x.tags)
// """)
// }
} }