From 3bfb80a7c13cb90892536f7d76f27df22b366658 Mon Sep 17 00:00:00 2001 From: sergeych Date: Thu, 12 Mar 2026 19:08:29 +0300 Subject: [PATCH] Improve Set handling with type-aware operations and enhance generic function parsing logic --- .../kotlin/net/sergeych/lyng/Compiler.kt | 102 ++++++++++++++++-- .../kotlin/net/sergeych/lyng/Scope.kt | 46 ++++++++ .../kotlin/net/sergeych/lyng/obj/ObjSet.kt | 23 +++- lynglib/src/commonTest/kotlin/TypesTest.kt | 48 +++++++++ 4 files changed, 210 insertions(+), 9 deletions(-) diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Compiler.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Compiler.kt index 1dfe218..d6cf578 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Compiler.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Compiler.kt @@ -169,6 +169,7 @@ class Compiler( ) private val typeAliases: MutableMap = mutableMapOf() private val methodReturnTypeDeclByRef: MutableMap = mutableMapOf() + private val callReturnTypeDeclByRef: MutableMap = mutableMapOf() private val callableReturnTypeByScopeId: MutableMap> = mutableMapOf() private val callableReturnTypeByName: MutableMap = mutableMapOf() private val lambdaReturnTypeByRef: MutableMap = mutableMapOf() @@ -2544,6 +2545,7 @@ class Compiler( private suspend fun parseTerm(): ObjRef? { var operand: ObjRef? = null + var pendingCallTypeArgs: List? = null // newlines _before_ cc.skipWsTokens() @@ -2791,20 +2793,35 @@ class Compiler( operand = parseScopeOperator(operand) } + Token.Type.LT -> { + val parsedTypeArgs = operand + ?.takeIf { isGenericCallCalleeCandidate(it) } + ?.let { tryParseCallTypeArgsAfterLt() } + if (parsedTypeArgs != null) { + pendingCallTypeArgs = parsedTypeArgs + continue + } + cc.previous() + return operand + } + Token.Type.LPAREN, Token.Type.NULL_COALESCE_INVOKE -> { operand?.let { left -> // this is function call from operand = parseFunctionCall( left, false, - t.type == Token.Type.NULL_COALESCE_INVOKE + t.type == Token.Type.NULL_COALESCE_INVOKE, + pendingCallTypeArgs ) + pendingCallTypeArgs = null } ?: run { // Expression in parentheses val statement = parseStatement() ?: throw ScriptError(t.pos, "Expecting expression") operand = StatementRef(statement) cc.skipTokenOfType(Token.Type.NEWLINE, isOptional = true) cc.skipTokenOfType(Token.Type.RPAREN, "missing ')'") + pendingCallTypeArgs = null } } @@ -2984,7 +3001,8 @@ class Compiler( parseFunctionCall( left, blockArgument = true, - isOptional = t.type == Token.Type.NULL_COALESCE_BLOCKINVOKE + isOptional = t.type == Token.Type.NULL_COALESCE_BLOCKINVOKE, + explicitTypeArgs = pendingCallTypeArgs ) } ?: run { // Disambiguate between lambda and map literal. @@ -3011,6 +3029,54 @@ class Compiler( } } + private suspend fun tryParseCallTypeArgsAfterLt(): List? { + val savedAfterLt = cc.savePos() + return try { + val args = mutableListOf() + do { + val (argSem, _) = parseTypeExpressionWithMini() + args += argSem + val sep = cc.next() + when (sep.type) { + Token.Type.COMMA -> continue + Token.Type.GT -> break + Token.Type.SHR -> { + cc.pushPendingGT() + break + } + else -> { + cc.restorePos(savedAfterLt) + return null + } + } + } while (true) + val nextType = cc.peekNextNonWhitespace().type + if (nextType != Token.Type.LPAREN && nextType != Token.Type.NULL_COALESCE_INVOKE) { + cc.restorePos(savedAfterLt) + return null + } + args + } catch (_: ScriptError) { + cc.restorePos(savedAfterLt) + null + } + } + + private fun isGenericCallCalleeCandidate(ref: ObjRef): Boolean { + val name = when (ref) { + is LocalVarRef -> ref.name + is FastLocalVarRef -> ref.name + is LocalSlotRef -> ref.name + else -> null + } + if (name != null) { + if (lookupGenericFunctionDecl(name) != null) return true + if (name.firstOrNull()?.isUpperCase() == true) return true + return false + } + return ref is ConstRef && ref.constValue is ObjClass + } + /** * Parse lambda expression, leading '{' is already consumed */ @@ -4369,6 +4435,7 @@ class Compiler( } } is MethodCallRef -> methodReturnTypeDeclByRef[ref] + is CallRef -> callReturnTypeDeclByRef[ref] is StatementRef -> (ref.statement as? ExpressionStatement)?.let { resolveReceiverTypeDecl(it.ref) } else -> null } @@ -5407,7 +5474,8 @@ class Compiler( private suspend fun parseFunctionCall( left: ObjRef, blockArgument: Boolean, - isOptional: Boolean + isOptional: Boolean, + explicitTypeArgs: List? = null ): ObjRef { var detectedBlockArgument = blockArgument val expectedReceiver = tailBlockReceiverType(left) @@ -5448,7 +5516,9 @@ class Compiler( val result = when (left) { is ImplicitThisMemberRef -> if (left.methodId == null && left.fieldId != null) { - CallRef(left, args, detectedBlockArgument, isOptional) + CallRef(left, args, detectedBlockArgument, isOptional).also { callRef -> + applyExplicitCallTypeArgs(callRef, explicitTypeArgs) + } } else { ImplicitThisMethodCallRef( left.name, @@ -5481,7 +5551,9 @@ class Compiler( checkFunctionTypeCallArity(left, args, left.pos()) checkFunctionTypeCallTypes(left, args, left.pos()) checkGenericBoundsAtCall(left.name, args, left.pos()) - CallRef(left, args, detectedBlockArgument, isOptional) + CallRef(left, args, detectedBlockArgument, isOptional).also { callRef -> + applyExplicitCallTypeArgs(callRef, explicitTypeArgs) + } } } is LocalSlotRef -> { @@ -5505,14 +5577,30 @@ class Compiler( checkFunctionTypeCallArity(left, args, left.pos()) checkFunctionTypeCallTypes(left, args, left.pos()) checkGenericBoundsAtCall(left.name, args, left.pos()) - CallRef(left, args, detectedBlockArgument, isOptional) + CallRef(left, args, detectedBlockArgument, isOptional).also { callRef -> + applyExplicitCallTypeArgs(callRef, explicitTypeArgs) + } } } - else -> CallRef(left, args, detectedBlockArgument, isOptional) + else -> CallRef(left, args, detectedBlockArgument, isOptional).also { callRef -> + applyExplicitCallTypeArgs(callRef, explicitTypeArgs) + } } return result } + private fun applyExplicitCallTypeArgs(callRef: CallRef, explicitTypeArgs: List?) { + if (explicitTypeArgs.isNullOrEmpty()) return + val baseName = when (val target = callRef.target) { + is LocalVarRef -> target.name + is FastLocalVarRef -> target.name + is LocalSlotRef -> target.name + is ConstRef -> (target.constValue as? ObjClass)?.className + else -> null + } ?: return + callReturnTypeDeclByRef[callRef] = TypeDecl.Generic(baseName, explicitTypeArgs, isNullable = false) + } + private fun inferReceiverTypeFromArgs(args: List): String? { val stmt = args.firstOrNull()?.value as? ExpressionStatement ?: return null val ref = stmt.ref diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Scope.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Scope.kt index 688019a..a4e62d0 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Scope.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Scope.kt @@ -471,6 +471,52 @@ open class Scope( } } + private fun resolvedRecordValueOrNull(record: ObjRecord): Obj? { + return when (val raw = record.value) { + is FrameSlotRef -> raw.read() + is RecordSlotRef -> raw.read() + else -> raw + } + } + + private fun declaredTypeForValueInThisScope(value: Obj): TypeDecl? { + // Prefer direct bindings first. + for (record in objects.values) { + val decl = record.typeDecl ?: continue + if (resolvedRecordValueOrNull(record) === value) return decl + } + for ((_, record) in localBindings) { + val decl = record.typeDecl ?: continue + if (resolvedRecordValueOrNull(record) === value) return decl + } + // Then slots (for frame-first locals). + var i = 0 + while (i < slots.size) { + val record = slots[i] + val decl = record.typeDecl + if (decl != null && resolvedRecordValueOrNull(record) === value) return decl + i++ + } + return null + } + + /** + * Best-effort lookup of the declared Set element type for a runtime set instance. + * Returns null when type info is unavailable. + */ + fun declaredSetElementTypeForValue(value: Obj): TypeDecl? { + var s: Scope? = this + var hops = 0 + while (s != null && hops++ < 1024) { + val decl = s.declaredTypeForValueInThisScope(value) + if (decl is TypeDecl.Generic && decl.name.substringAfterLast('.') == "Set") { + return decl.args.firstOrNull() + } + s = s.parent + } + return null + } + internal fun applySlotPlanReset(plan: Map, records: Map) { if (plan.isEmpty()) return slots.clear() diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/obj/ObjSet.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/obj/ObjSet.kt index f487939..8cd68f0 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/obj/ObjSet.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/obj/ObjSet.kt @@ -27,6 +27,16 @@ import net.sergeych.lynon.LynonEncoder import net.sergeych.lynon.LynonType class ObjSet(val set: MutableSet = mutableSetOf()) : Obj() { + private fun shouldTreatAsSingleElement(scope: Scope, other: Obj): Boolean { + if (!other.isInstanceOf(ObjIterable)) return true + val declaredElementType = scope.declaredSetElementTypeForValue(this) + if (declaredElementType != null && matchesTypeDecl(scope, other, declaredElementType)) { + return true + } + // Strings and buffers are iterable but usually expected to be atomic values for set +/- operators. + if (other is ObjString || other is ObjBuffer) return true + return false + } override suspend fun equals(scope: Scope, other: Obj): Boolean { if (this === other) return true @@ -53,6 +63,9 @@ class ObjSet(val set: MutableSet = mutableSetOf()) : Obj() { } override suspend fun plus(scope: Scope, other: Obj): Obj { + if (shouldTreatAsSingleElement(scope, other)) { + return ObjSet((set + other).toMutableSet()) + } return ObjSet( if (other is ObjSet) (set + other.set).toMutableSet() @@ -73,6 +86,10 @@ class ObjSet(val set: MutableSet = mutableSetOf()) : Obj() { } override suspend fun plusAssign(scope: Scope, other: Obj): Obj { + if (shouldTreatAsSingleElement(scope, other)) { + set += other + return this + } when (other) { is ObjSet -> { set += other.set @@ -105,6 +122,9 @@ class ObjSet(val set: MutableSet = mutableSetOf()) : Obj() { } override suspend fun minus(scope: Scope, other: Obj): Obj { + if (shouldTreatAsSingleElement(scope, other)) { + return ObjSet((set - other).toMutableSet()) + } return when { other is ObjSet -> ObjSet(set.minus(other.set).toMutableSet()) other.isInstanceOf(ObjIterable) -> { @@ -115,8 +135,7 @@ class ObjSet(val set: MutableSet = mutableSetOf()) : Obj() { } ObjSet((set - otherSet).toMutableSet()) } - else -> - scope.raiseIllegalArgument("set operator - requires another set or Iterable") + else -> ObjSet((set - other).toMutableSet()) } } diff --git a/lynglib/src/commonTest/kotlin/TypesTest.kt b/lynglib/src/commonTest/kotlin/TypesTest.kt index 7100295..860c41c 100644 --- a/lynglib/src/commonTest/kotlin/TypesTest.kt +++ b/lynglib/src/commonTest/kotlin/TypesTest.kt @@ -399,4 +399,52 @@ class TypesTest { fun testOk5() { l4(1, "a", "b", "x") } """.trimIndent()) } + + @Test + fun testSetTyped() = runTest { + eval(""" + var s = Set() + val typed: Set = s + assertEquals(Set(), typed) + + s += "foo" + assertEquals(Set("foo"), s) + s -= "foo" + assertEquals(Set(), s) + s += ["foo", "bar"] + assertEquals(Set("foo", "bar"), s) + """.trimIndent()) + } + +// @Test +// fun testAliasesInGenerics1() = runTest { +// val scope = Script.newScope() +// scope.eval(""" +// type IntList = List +// type IntMap = Map +// type IntSet = Set +// type IntPair = Pair +// type IntTriple = Triple +// type IntQuad = Quad +// +// import lyng.buffer +// type Tag = String | Buffer +// +// class X { +// var tags: Set = Set() +// } +// val x = X() +// x.tags += "tag1" +// assertEquals(Set("tag1"), x.tags) +// x.tags += "tag2" +// assertEquals(Set("tag1", "tag2"), x.tags) +// x.tags += Buffer("tag3") +// assertEquals(Set("tag1", "tag2", Buffer("tag3")), x.tags) +// x.tags += Buffer("tag4") +// assertEquals(Set("tag1", "tag2", Buffer("tag3"), Buffer("tag4")), x.tags) +// x.tags += "tag3" +// x.tags += "tag4" +// assertEquals(Set("tag1", "tag2", Buffer("tag3"), Buffer("tag4")), x.tags) +// """) +// } }