Improve Set handling with type-aware operations and enhance generic function parsing logic
This commit is contained in:
parent
32a7adf56e
commit
3bfb80a7c1
@ -169,6 +169,7 @@ class Compiler(
|
||||
)
|
||||
private val typeAliases: MutableMap<String, TypeAliasDecl> = mutableMapOf()
|
||||
private val methodReturnTypeDeclByRef: MutableMap<ObjRef, TypeDecl> = mutableMapOf()
|
||||
private val callReturnTypeDeclByRef: MutableMap<CallRef, TypeDecl> = mutableMapOf()
|
||||
private val callableReturnTypeByScopeId: MutableMap<Int, MutableMap<Int, ObjClass>> = mutableMapOf()
|
||||
private val callableReturnTypeByName: MutableMap<String, ObjClass> = mutableMapOf()
|
||||
private val lambdaReturnTypeByRef: MutableMap<ObjRef, ObjClass> = mutableMapOf()
|
||||
@ -2544,6 +2545,7 @@ class Compiler(
|
||||
|
||||
private suspend fun parseTerm(): ObjRef? {
|
||||
var operand: ObjRef? = null
|
||||
var pendingCallTypeArgs: List<TypeDecl>? = null
|
||||
|
||||
// newlines _before_
|
||||
cc.skipWsTokens()
|
||||
@ -2791,20 +2793,35 @@ class Compiler(
|
||||
operand = parseScopeOperator(operand)
|
||||
}
|
||||
|
||||
Token.Type.LT -> {
|
||||
val parsedTypeArgs = operand
|
||||
?.takeIf { isGenericCallCalleeCandidate(it) }
|
||||
?.let { tryParseCallTypeArgsAfterLt() }
|
||||
if (parsedTypeArgs != null) {
|
||||
pendingCallTypeArgs = parsedTypeArgs
|
||||
continue
|
||||
}
|
||||
cc.previous()
|
||||
return operand
|
||||
}
|
||||
|
||||
Token.Type.LPAREN, Token.Type.NULL_COALESCE_INVOKE -> {
|
||||
operand?.let { left ->
|
||||
// this is function call from <left>
|
||||
operand = parseFunctionCall(
|
||||
left,
|
||||
false,
|
||||
t.type == Token.Type.NULL_COALESCE_INVOKE
|
||||
t.type == Token.Type.NULL_COALESCE_INVOKE,
|
||||
pendingCallTypeArgs
|
||||
)
|
||||
pendingCallTypeArgs = null
|
||||
} ?: run {
|
||||
// Expression in parentheses
|
||||
val statement = parseStatement() ?: throw ScriptError(t.pos, "Expecting expression")
|
||||
operand = StatementRef(statement)
|
||||
cc.skipTokenOfType(Token.Type.NEWLINE, isOptional = true)
|
||||
cc.skipTokenOfType(Token.Type.RPAREN, "missing ')'")
|
||||
pendingCallTypeArgs = null
|
||||
}
|
||||
}
|
||||
|
||||
@ -2984,7 +3001,8 @@ class Compiler(
|
||||
parseFunctionCall(
|
||||
left,
|
||||
blockArgument = true,
|
||||
isOptional = t.type == Token.Type.NULL_COALESCE_BLOCKINVOKE
|
||||
isOptional = t.type == Token.Type.NULL_COALESCE_BLOCKINVOKE,
|
||||
explicitTypeArgs = pendingCallTypeArgs
|
||||
)
|
||||
} ?: run {
|
||||
// Disambiguate between lambda and map literal.
|
||||
@ -3011,6 +3029,54 @@ class Compiler(
|
||||
}
|
||||
}
|
||||
|
||||
private suspend fun tryParseCallTypeArgsAfterLt(): List<TypeDecl>? {
|
||||
val savedAfterLt = cc.savePos()
|
||||
return try {
|
||||
val args = mutableListOf<TypeDecl>()
|
||||
do {
|
||||
val (argSem, _) = parseTypeExpressionWithMini()
|
||||
args += argSem
|
||||
val sep = cc.next()
|
||||
when (sep.type) {
|
||||
Token.Type.COMMA -> continue
|
||||
Token.Type.GT -> break
|
||||
Token.Type.SHR -> {
|
||||
cc.pushPendingGT()
|
||||
break
|
||||
}
|
||||
else -> {
|
||||
cc.restorePos(savedAfterLt)
|
||||
return null
|
||||
}
|
||||
}
|
||||
} while (true)
|
||||
val nextType = cc.peekNextNonWhitespace().type
|
||||
if (nextType != Token.Type.LPAREN && nextType != Token.Type.NULL_COALESCE_INVOKE) {
|
||||
cc.restorePos(savedAfterLt)
|
||||
return null
|
||||
}
|
||||
args
|
||||
} catch (_: ScriptError) {
|
||||
cc.restorePos(savedAfterLt)
|
||||
null
|
||||
}
|
||||
}
|
||||
|
||||
private fun isGenericCallCalleeCandidate(ref: ObjRef): Boolean {
|
||||
val name = when (ref) {
|
||||
is LocalVarRef -> ref.name
|
||||
is FastLocalVarRef -> ref.name
|
||||
is LocalSlotRef -> ref.name
|
||||
else -> null
|
||||
}
|
||||
if (name != null) {
|
||||
if (lookupGenericFunctionDecl(name) != null) return true
|
||||
if (name.firstOrNull()?.isUpperCase() == true) return true
|
||||
return false
|
||||
}
|
||||
return ref is ConstRef && ref.constValue is ObjClass
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse lambda expression, leading '{' is already consumed
|
||||
*/
|
||||
@ -4369,6 +4435,7 @@ class Compiler(
|
||||
}
|
||||
}
|
||||
is MethodCallRef -> methodReturnTypeDeclByRef[ref]
|
||||
is CallRef -> callReturnTypeDeclByRef[ref]
|
||||
is StatementRef -> (ref.statement as? ExpressionStatement)?.let { resolveReceiverTypeDecl(it.ref) }
|
||||
else -> null
|
||||
}
|
||||
@ -5407,7 +5474,8 @@ class Compiler(
|
||||
private suspend fun parseFunctionCall(
|
||||
left: ObjRef,
|
||||
blockArgument: Boolean,
|
||||
isOptional: Boolean
|
||||
isOptional: Boolean,
|
||||
explicitTypeArgs: List<TypeDecl>? = null
|
||||
): ObjRef {
|
||||
var detectedBlockArgument = blockArgument
|
||||
val expectedReceiver = tailBlockReceiverType(left)
|
||||
@ -5448,7 +5516,9 @@ class Compiler(
|
||||
val result = when (left) {
|
||||
is ImplicitThisMemberRef ->
|
||||
if (left.methodId == null && left.fieldId != null) {
|
||||
CallRef(left, args, detectedBlockArgument, isOptional)
|
||||
CallRef(left, args, detectedBlockArgument, isOptional).also { callRef ->
|
||||
applyExplicitCallTypeArgs(callRef, explicitTypeArgs)
|
||||
}
|
||||
} else {
|
||||
ImplicitThisMethodCallRef(
|
||||
left.name,
|
||||
@ -5481,7 +5551,9 @@ class Compiler(
|
||||
checkFunctionTypeCallArity(left, args, left.pos())
|
||||
checkFunctionTypeCallTypes(left, args, left.pos())
|
||||
checkGenericBoundsAtCall(left.name, args, left.pos())
|
||||
CallRef(left, args, detectedBlockArgument, isOptional)
|
||||
CallRef(left, args, detectedBlockArgument, isOptional).also { callRef ->
|
||||
applyExplicitCallTypeArgs(callRef, explicitTypeArgs)
|
||||
}
|
||||
}
|
||||
}
|
||||
is LocalSlotRef -> {
|
||||
@ -5505,14 +5577,30 @@ class Compiler(
|
||||
checkFunctionTypeCallArity(left, args, left.pos())
|
||||
checkFunctionTypeCallTypes(left, args, left.pos())
|
||||
checkGenericBoundsAtCall(left.name, args, left.pos())
|
||||
CallRef(left, args, detectedBlockArgument, isOptional)
|
||||
CallRef(left, args, detectedBlockArgument, isOptional).also { callRef ->
|
||||
applyExplicitCallTypeArgs(callRef, explicitTypeArgs)
|
||||
}
|
||||
}
|
||||
}
|
||||
else -> CallRef(left, args, detectedBlockArgument, isOptional)
|
||||
else -> CallRef(left, args, detectedBlockArgument, isOptional).also { callRef ->
|
||||
applyExplicitCallTypeArgs(callRef, explicitTypeArgs)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
private fun applyExplicitCallTypeArgs(callRef: CallRef, explicitTypeArgs: List<TypeDecl>?) {
|
||||
if (explicitTypeArgs.isNullOrEmpty()) return
|
||||
val baseName = when (val target = callRef.target) {
|
||||
is LocalVarRef -> target.name
|
||||
is FastLocalVarRef -> target.name
|
||||
is LocalSlotRef -> target.name
|
||||
is ConstRef -> (target.constValue as? ObjClass)?.className
|
||||
else -> null
|
||||
} ?: return
|
||||
callReturnTypeDeclByRef[callRef] = TypeDecl.Generic(baseName, explicitTypeArgs, isNullable = false)
|
||||
}
|
||||
|
||||
private fun inferReceiverTypeFromArgs(args: List<ParsedArgument>): String? {
|
||||
val stmt = args.firstOrNull()?.value as? ExpressionStatement ?: return null
|
||||
val ref = stmt.ref
|
||||
|
||||
@ -471,6 +471,52 @@ open class Scope(
|
||||
}
|
||||
}
|
||||
|
||||
private fun resolvedRecordValueOrNull(record: ObjRecord): Obj? {
|
||||
return when (val raw = record.value) {
|
||||
is FrameSlotRef -> raw.read()
|
||||
is RecordSlotRef -> raw.read()
|
||||
else -> raw
|
||||
}
|
||||
}
|
||||
|
||||
private fun declaredTypeForValueInThisScope(value: Obj): TypeDecl? {
|
||||
// Prefer direct bindings first.
|
||||
for (record in objects.values) {
|
||||
val decl = record.typeDecl ?: continue
|
||||
if (resolvedRecordValueOrNull(record) === value) return decl
|
||||
}
|
||||
for ((_, record) in localBindings) {
|
||||
val decl = record.typeDecl ?: continue
|
||||
if (resolvedRecordValueOrNull(record) === value) return decl
|
||||
}
|
||||
// Then slots (for frame-first locals).
|
||||
var i = 0
|
||||
while (i < slots.size) {
|
||||
val record = slots[i]
|
||||
val decl = record.typeDecl
|
||||
if (decl != null && resolvedRecordValueOrNull(record) === value) return decl
|
||||
i++
|
||||
}
|
||||
return null
|
||||
}
|
||||
|
||||
/**
|
||||
* Best-effort lookup of the declared Set element type for a runtime set instance.
|
||||
* Returns null when type info is unavailable.
|
||||
*/
|
||||
fun declaredSetElementTypeForValue(value: Obj): TypeDecl? {
|
||||
var s: Scope? = this
|
||||
var hops = 0
|
||||
while (s != null && hops++ < 1024) {
|
||||
val decl = s.declaredTypeForValueInThisScope(value)
|
||||
if (decl is TypeDecl.Generic && decl.name.substringAfterLast('.') == "Set") {
|
||||
return decl.args.firstOrNull()
|
||||
}
|
||||
s = s.parent
|
||||
}
|
||||
return null
|
||||
}
|
||||
|
||||
internal fun applySlotPlanReset(plan: Map<String, Int>, records: Map<String, ObjRecord>) {
|
||||
if (plan.isEmpty()) return
|
||||
slots.clear()
|
||||
|
||||
@ -27,6 +27,16 @@ import net.sergeych.lynon.LynonEncoder
|
||||
import net.sergeych.lynon.LynonType
|
||||
|
||||
class ObjSet(val set: MutableSet<Obj> = mutableSetOf()) : Obj() {
|
||||
private fun shouldTreatAsSingleElement(scope: Scope, other: Obj): Boolean {
|
||||
if (!other.isInstanceOf(ObjIterable)) return true
|
||||
val declaredElementType = scope.declaredSetElementTypeForValue(this)
|
||||
if (declaredElementType != null && matchesTypeDecl(scope, other, declaredElementType)) {
|
||||
return true
|
||||
}
|
||||
// Strings and buffers are iterable but usually expected to be atomic values for set +/- operators.
|
||||
if (other is ObjString || other is ObjBuffer) return true
|
||||
return false
|
||||
}
|
||||
|
||||
override suspend fun equals(scope: Scope, other: Obj): Boolean {
|
||||
if (this === other) return true
|
||||
@ -53,6 +63,9 @@ class ObjSet(val set: MutableSet<Obj> = mutableSetOf()) : Obj() {
|
||||
}
|
||||
|
||||
override suspend fun plus(scope: Scope, other: Obj): Obj {
|
||||
if (shouldTreatAsSingleElement(scope, other)) {
|
||||
return ObjSet((set + other).toMutableSet())
|
||||
}
|
||||
return ObjSet(
|
||||
if (other is ObjSet)
|
||||
(set + other.set).toMutableSet()
|
||||
@ -73,6 +86,10 @@ class ObjSet(val set: MutableSet<Obj> = mutableSetOf()) : Obj() {
|
||||
}
|
||||
|
||||
override suspend fun plusAssign(scope: Scope, other: Obj): Obj {
|
||||
if (shouldTreatAsSingleElement(scope, other)) {
|
||||
set += other
|
||||
return this
|
||||
}
|
||||
when (other) {
|
||||
is ObjSet -> {
|
||||
set += other.set
|
||||
@ -105,6 +122,9 @@ class ObjSet(val set: MutableSet<Obj> = mutableSetOf()) : Obj() {
|
||||
}
|
||||
|
||||
override suspend fun minus(scope: Scope, other: Obj): Obj {
|
||||
if (shouldTreatAsSingleElement(scope, other)) {
|
||||
return ObjSet((set - other).toMutableSet())
|
||||
}
|
||||
return when {
|
||||
other is ObjSet -> ObjSet(set.minus(other.set).toMutableSet())
|
||||
other.isInstanceOf(ObjIterable) -> {
|
||||
@ -115,8 +135,7 @@ class ObjSet(val set: MutableSet<Obj> = mutableSetOf()) : Obj() {
|
||||
}
|
||||
ObjSet((set - otherSet).toMutableSet())
|
||||
}
|
||||
else ->
|
||||
scope.raiseIllegalArgument("set operator - requires another set or Iterable")
|
||||
else -> ObjSet((set - other).toMutableSet())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -399,4 +399,52 @@ class TypesTest {
|
||||
fun testOk5() { l4(1, "a", "b", "x") }
|
||||
""".trimIndent())
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testSetTyped() = runTest {
|
||||
eval("""
|
||||
var s = Set<String>()
|
||||
val typed: Set<String> = s
|
||||
assertEquals(Set(), typed)
|
||||
|
||||
s += "foo"
|
||||
assertEquals(Set("foo"), s)
|
||||
s -= "foo"
|
||||
assertEquals(Set(), s)
|
||||
s += ["foo", "bar"]
|
||||
assertEquals(Set("foo", "bar"), s)
|
||||
""".trimIndent())
|
||||
}
|
||||
|
||||
// @Test
|
||||
// fun testAliasesInGenerics1() = runTest {
|
||||
// val scope = Script.newScope()
|
||||
// scope.eval("""
|
||||
// type IntList<T: Int> = List<T>
|
||||
// type IntMap<K,V> = Map<K,V>
|
||||
// type IntSet<T: Int> = Set<T>
|
||||
// type IntPair<T: Int> = Pair<T,T>
|
||||
// type IntTriple<T: Int> = Triple<T,T,T>
|
||||
// type IntQuad<T: Int> = Quad<T,T,T,T>
|
||||
//
|
||||
// import lyng.buffer
|
||||
// type Tag = String | Buffer
|
||||
//
|
||||
// class X {
|
||||
// var tags: Set<Tag> = Set()
|
||||
// }
|
||||
// val x = X()
|
||||
// x.tags += "tag1"
|
||||
// assertEquals(Set("tag1"), x.tags)
|
||||
// x.tags += "tag2"
|
||||
// assertEquals(Set("tag1", "tag2"), x.tags)
|
||||
// x.tags += Buffer("tag3")
|
||||
// assertEquals(Set("tag1", "tag2", Buffer("tag3")), x.tags)
|
||||
// x.tags += Buffer("tag4")
|
||||
// assertEquals(Set("tag1", "tag2", Buffer("tag3"), Buffer("tag4")), x.tags)
|
||||
// x.tags += "tag3"
|
||||
// x.tags += "tag4"
|
||||
// assertEquals(Set("tag1", "tag2", Buffer("tag3"), Buffer("tag4")), x.tags)
|
||||
// """)
|
||||
// }
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user