Improve Set handling with type-aware operations and enhance generic function parsing logic
This commit is contained in:
parent
32a7adf56e
commit
3bfb80a7c1
@ -169,6 +169,7 @@ class Compiler(
|
|||||||
)
|
)
|
||||||
private val typeAliases: MutableMap<String, TypeAliasDecl> = mutableMapOf()
|
private val typeAliases: MutableMap<String, TypeAliasDecl> = mutableMapOf()
|
||||||
private val methodReturnTypeDeclByRef: MutableMap<ObjRef, TypeDecl> = mutableMapOf()
|
private val methodReturnTypeDeclByRef: MutableMap<ObjRef, TypeDecl> = mutableMapOf()
|
||||||
|
private val callReturnTypeDeclByRef: MutableMap<CallRef, TypeDecl> = mutableMapOf()
|
||||||
private val callableReturnTypeByScopeId: MutableMap<Int, MutableMap<Int, ObjClass>> = mutableMapOf()
|
private val callableReturnTypeByScopeId: MutableMap<Int, MutableMap<Int, ObjClass>> = mutableMapOf()
|
||||||
private val callableReturnTypeByName: MutableMap<String, ObjClass> = mutableMapOf()
|
private val callableReturnTypeByName: MutableMap<String, ObjClass> = mutableMapOf()
|
||||||
private val lambdaReturnTypeByRef: MutableMap<ObjRef, ObjClass> = mutableMapOf()
|
private val lambdaReturnTypeByRef: MutableMap<ObjRef, ObjClass> = mutableMapOf()
|
||||||
@ -2544,6 +2545,7 @@ class Compiler(
|
|||||||
|
|
||||||
private suspend fun parseTerm(): ObjRef? {
|
private suspend fun parseTerm(): ObjRef? {
|
||||||
var operand: ObjRef? = null
|
var operand: ObjRef? = null
|
||||||
|
var pendingCallTypeArgs: List<TypeDecl>? = null
|
||||||
|
|
||||||
// newlines _before_
|
// newlines _before_
|
||||||
cc.skipWsTokens()
|
cc.skipWsTokens()
|
||||||
@ -2791,20 +2793,35 @@ class Compiler(
|
|||||||
operand = parseScopeOperator(operand)
|
operand = parseScopeOperator(operand)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Token.Type.LT -> {
|
||||||
|
val parsedTypeArgs = operand
|
||||||
|
?.takeIf { isGenericCallCalleeCandidate(it) }
|
||||||
|
?.let { tryParseCallTypeArgsAfterLt() }
|
||||||
|
if (parsedTypeArgs != null) {
|
||||||
|
pendingCallTypeArgs = parsedTypeArgs
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
cc.previous()
|
||||||
|
return operand
|
||||||
|
}
|
||||||
|
|
||||||
Token.Type.LPAREN, Token.Type.NULL_COALESCE_INVOKE -> {
|
Token.Type.LPAREN, Token.Type.NULL_COALESCE_INVOKE -> {
|
||||||
operand?.let { left ->
|
operand?.let { left ->
|
||||||
// this is function call from <left>
|
// this is function call from <left>
|
||||||
operand = parseFunctionCall(
|
operand = parseFunctionCall(
|
||||||
left,
|
left,
|
||||||
false,
|
false,
|
||||||
t.type == Token.Type.NULL_COALESCE_INVOKE
|
t.type == Token.Type.NULL_COALESCE_INVOKE,
|
||||||
|
pendingCallTypeArgs
|
||||||
)
|
)
|
||||||
|
pendingCallTypeArgs = null
|
||||||
} ?: run {
|
} ?: run {
|
||||||
// Expression in parentheses
|
// Expression in parentheses
|
||||||
val statement = parseStatement() ?: throw ScriptError(t.pos, "Expecting expression")
|
val statement = parseStatement() ?: throw ScriptError(t.pos, "Expecting expression")
|
||||||
operand = StatementRef(statement)
|
operand = StatementRef(statement)
|
||||||
cc.skipTokenOfType(Token.Type.NEWLINE, isOptional = true)
|
cc.skipTokenOfType(Token.Type.NEWLINE, isOptional = true)
|
||||||
cc.skipTokenOfType(Token.Type.RPAREN, "missing ')'")
|
cc.skipTokenOfType(Token.Type.RPAREN, "missing ')'")
|
||||||
|
pendingCallTypeArgs = null
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2984,7 +3001,8 @@ class Compiler(
|
|||||||
parseFunctionCall(
|
parseFunctionCall(
|
||||||
left,
|
left,
|
||||||
blockArgument = true,
|
blockArgument = true,
|
||||||
isOptional = t.type == Token.Type.NULL_COALESCE_BLOCKINVOKE
|
isOptional = t.type == Token.Type.NULL_COALESCE_BLOCKINVOKE,
|
||||||
|
explicitTypeArgs = pendingCallTypeArgs
|
||||||
)
|
)
|
||||||
} ?: run {
|
} ?: run {
|
||||||
// Disambiguate between lambda and map literal.
|
// Disambiguate between lambda and map literal.
|
||||||
@ -3011,6 +3029,54 @@ class Compiler(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private suspend fun tryParseCallTypeArgsAfterLt(): List<TypeDecl>? {
|
||||||
|
val savedAfterLt = cc.savePos()
|
||||||
|
return try {
|
||||||
|
val args = mutableListOf<TypeDecl>()
|
||||||
|
do {
|
||||||
|
val (argSem, _) = parseTypeExpressionWithMini()
|
||||||
|
args += argSem
|
||||||
|
val sep = cc.next()
|
||||||
|
when (sep.type) {
|
||||||
|
Token.Type.COMMA -> continue
|
||||||
|
Token.Type.GT -> break
|
||||||
|
Token.Type.SHR -> {
|
||||||
|
cc.pushPendingGT()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
else -> {
|
||||||
|
cc.restorePos(savedAfterLt)
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} while (true)
|
||||||
|
val nextType = cc.peekNextNonWhitespace().type
|
||||||
|
if (nextType != Token.Type.LPAREN && nextType != Token.Type.NULL_COALESCE_INVOKE) {
|
||||||
|
cc.restorePos(savedAfterLt)
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
args
|
||||||
|
} catch (_: ScriptError) {
|
||||||
|
cc.restorePos(savedAfterLt)
|
||||||
|
null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun isGenericCallCalleeCandidate(ref: ObjRef): Boolean {
|
||||||
|
val name = when (ref) {
|
||||||
|
is LocalVarRef -> ref.name
|
||||||
|
is FastLocalVarRef -> ref.name
|
||||||
|
is LocalSlotRef -> ref.name
|
||||||
|
else -> null
|
||||||
|
}
|
||||||
|
if (name != null) {
|
||||||
|
if (lookupGenericFunctionDecl(name) != null) return true
|
||||||
|
if (name.firstOrNull()?.isUpperCase() == true) return true
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return ref is ConstRef && ref.constValue is ObjClass
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Parse lambda expression, leading '{' is already consumed
|
* Parse lambda expression, leading '{' is already consumed
|
||||||
*/
|
*/
|
||||||
@ -4369,6 +4435,7 @@ class Compiler(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
is MethodCallRef -> methodReturnTypeDeclByRef[ref]
|
is MethodCallRef -> methodReturnTypeDeclByRef[ref]
|
||||||
|
is CallRef -> callReturnTypeDeclByRef[ref]
|
||||||
is StatementRef -> (ref.statement as? ExpressionStatement)?.let { resolveReceiverTypeDecl(it.ref) }
|
is StatementRef -> (ref.statement as? ExpressionStatement)?.let { resolveReceiverTypeDecl(it.ref) }
|
||||||
else -> null
|
else -> null
|
||||||
}
|
}
|
||||||
@ -5407,7 +5474,8 @@ class Compiler(
|
|||||||
private suspend fun parseFunctionCall(
|
private suspend fun parseFunctionCall(
|
||||||
left: ObjRef,
|
left: ObjRef,
|
||||||
blockArgument: Boolean,
|
blockArgument: Boolean,
|
||||||
isOptional: Boolean
|
isOptional: Boolean,
|
||||||
|
explicitTypeArgs: List<TypeDecl>? = null
|
||||||
): ObjRef {
|
): ObjRef {
|
||||||
var detectedBlockArgument = blockArgument
|
var detectedBlockArgument = blockArgument
|
||||||
val expectedReceiver = tailBlockReceiverType(left)
|
val expectedReceiver = tailBlockReceiverType(left)
|
||||||
@ -5448,7 +5516,9 @@ class Compiler(
|
|||||||
val result = when (left) {
|
val result = when (left) {
|
||||||
is ImplicitThisMemberRef ->
|
is ImplicitThisMemberRef ->
|
||||||
if (left.methodId == null && left.fieldId != null) {
|
if (left.methodId == null && left.fieldId != null) {
|
||||||
CallRef(left, args, detectedBlockArgument, isOptional)
|
CallRef(left, args, detectedBlockArgument, isOptional).also { callRef ->
|
||||||
|
applyExplicitCallTypeArgs(callRef, explicitTypeArgs)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
ImplicitThisMethodCallRef(
|
ImplicitThisMethodCallRef(
|
||||||
left.name,
|
left.name,
|
||||||
@ -5481,7 +5551,9 @@ class Compiler(
|
|||||||
checkFunctionTypeCallArity(left, args, left.pos())
|
checkFunctionTypeCallArity(left, args, left.pos())
|
||||||
checkFunctionTypeCallTypes(left, args, left.pos())
|
checkFunctionTypeCallTypes(left, args, left.pos())
|
||||||
checkGenericBoundsAtCall(left.name, args, left.pos())
|
checkGenericBoundsAtCall(left.name, args, left.pos())
|
||||||
CallRef(left, args, detectedBlockArgument, isOptional)
|
CallRef(left, args, detectedBlockArgument, isOptional).also { callRef ->
|
||||||
|
applyExplicitCallTypeArgs(callRef, explicitTypeArgs)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
is LocalSlotRef -> {
|
is LocalSlotRef -> {
|
||||||
@ -5505,14 +5577,30 @@ class Compiler(
|
|||||||
checkFunctionTypeCallArity(left, args, left.pos())
|
checkFunctionTypeCallArity(left, args, left.pos())
|
||||||
checkFunctionTypeCallTypes(left, args, left.pos())
|
checkFunctionTypeCallTypes(left, args, left.pos())
|
||||||
checkGenericBoundsAtCall(left.name, args, left.pos())
|
checkGenericBoundsAtCall(left.name, args, left.pos())
|
||||||
CallRef(left, args, detectedBlockArgument, isOptional)
|
CallRef(left, args, detectedBlockArgument, isOptional).also { callRef ->
|
||||||
|
applyExplicitCallTypeArgs(callRef, explicitTypeArgs)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else -> CallRef(left, args, detectedBlockArgument, isOptional)
|
}
|
||||||
|
else -> CallRef(left, args, detectedBlockArgument, isOptional).also { callRef ->
|
||||||
|
applyExplicitCallTypeArgs(callRef, explicitTypeArgs)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private fun applyExplicitCallTypeArgs(callRef: CallRef, explicitTypeArgs: List<TypeDecl>?) {
|
||||||
|
if (explicitTypeArgs.isNullOrEmpty()) return
|
||||||
|
val baseName = when (val target = callRef.target) {
|
||||||
|
is LocalVarRef -> target.name
|
||||||
|
is FastLocalVarRef -> target.name
|
||||||
|
is LocalSlotRef -> target.name
|
||||||
|
is ConstRef -> (target.constValue as? ObjClass)?.className
|
||||||
|
else -> null
|
||||||
|
} ?: return
|
||||||
|
callReturnTypeDeclByRef[callRef] = TypeDecl.Generic(baseName, explicitTypeArgs, isNullable = false)
|
||||||
|
}
|
||||||
|
|
||||||
private fun inferReceiverTypeFromArgs(args: List<ParsedArgument>): String? {
|
private fun inferReceiverTypeFromArgs(args: List<ParsedArgument>): String? {
|
||||||
val stmt = args.firstOrNull()?.value as? ExpressionStatement ?: return null
|
val stmt = args.firstOrNull()?.value as? ExpressionStatement ?: return null
|
||||||
val ref = stmt.ref
|
val ref = stmt.ref
|
||||||
|
|||||||
@ -471,6 +471,52 @@ open class Scope(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private fun resolvedRecordValueOrNull(record: ObjRecord): Obj? {
|
||||||
|
return when (val raw = record.value) {
|
||||||
|
is FrameSlotRef -> raw.read()
|
||||||
|
is RecordSlotRef -> raw.read()
|
||||||
|
else -> raw
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun declaredTypeForValueInThisScope(value: Obj): TypeDecl? {
|
||||||
|
// Prefer direct bindings first.
|
||||||
|
for (record in objects.values) {
|
||||||
|
val decl = record.typeDecl ?: continue
|
||||||
|
if (resolvedRecordValueOrNull(record) === value) return decl
|
||||||
|
}
|
||||||
|
for ((_, record) in localBindings) {
|
||||||
|
val decl = record.typeDecl ?: continue
|
||||||
|
if (resolvedRecordValueOrNull(record) === value) return decl
|
||||||
|
}
|
||||||
|
// Then slots (for frame-first locals).
|
||||||
|
var i = 0
|
||||||
|
while (i < slots.size) {
|
||||||
|
val record = slots[i]
|
||||||
|
val decl = record.typeDecl
|
||||||
|
if (decl != null && resolvedRecordValueOrNull(record) === value) return decl
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Best-effort lookup of the declared Set element type for a runtime set instance.
|
||||||
|
* Returns null when type info is unavailable.
|
||||||
|
*/
|
||||||
|
fun declaredSetElementTypeForValue(value: Obj): TypeDecl? {
|
||||||
|
var s: Scope? = this
|
||||||
|
var hops = 0
|
||||||
|
while (s != null && hops++ < 1024) {
|
||||||
|
val decl = s.declaredTypeForValueInThisScope(value)
|
||||||
|
if (decl is TypeDecl.Generic && decl.name.substringAfterLast('.') == "Set") {
|
||||||
|
return decl.args.firstOrNull()
|
||||||
|
}
|
||||||
|
s = s.parent
|
||||||
|
}
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
|
||||||
internal fun applySlotPlanReset(plan: Map<String, Int>, records: Map<String, ObjRecord>) {
|
internal fun applySlotPlanReset(plan: Map<String, Int>, records: Map<String, ObjRecord>) {
|
||||||
if (plan.isEmpty()) return
|
if (plan.isEmpty()) return
|
||||||
slots.clear()
|
slots.clear()
|
||||||
|
|||||||
@ -27,6 +27,16 @@ import net.sergeych.lynon.LynonEncoder
|
|||||||
import net.sergeych.lynon.LynonType
|
import net.sergeych.lynon.LynonType
|
||||||
|
|
||||||
class ObjSet(val set: MutableSet<Obj> = mutableSetOf()) : Obj() {
|
class ObjSet(val set: MutableSet<Obj> = mutableSetOf()) : Obj() {
|
||||||
|
private fun shouldTreatAsSingleElement(scope: Scope, other: Obj): Boolean {
|
||||||
|
if (!other.isInstanceOf(ObjIterable)) return true
|
||||||
|
val declaredElementType = scope.declaredSetElementTypeForValue(this)
|
||||||
|
if (declaredElementType != null && matchesTypeDecl(scope, other, declaredElementType)) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
// Strings and buffers are iterable but usually expected to be atomic values for set +/- operators.
|
||||||
|
if (other is ObjString || other is ObjBuffer) return true
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
override suspend fun equals(scope: Scope, other: Obj): Boolean {
|
override suspend fun equals(scope: Scope, other: Obj): Boolean {
|
||||||
if (this === other) return true
|
if (this === other) return true
|
||||||
@ -53,6 +63,9 @@ class ObjSet(val set: MutableSet<Obj> = mutableSetOf()) : Obj() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
override suspend fun plus(scope: Scope, other: Obj): Obj {
|
override suspend fun plus(scope: Scope, other: Obj): Obj {
|
||||||
|
if (shouldTreatAsSingleElement(scope, other)) {
|
||||||
|
return ObjSet((set + other).toMutableSet())
|
||||||
|
}
|
||||||
return ObjSet(
|
return ObjSet(
|
||||||
if (other is ObjSet)
|
if (other is ObjSet)
|
||||||
(set + other.set).toMutableSet()
|
(set + other.set).toMutableSet()
|
||||||
@ -73,6 +86,10 @@ class ObjSet(val set: MutableSet<Obj> = mutableSetOf()) : Obj() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
override suspend fun plusAssign(scope: Scope, other: Obj): Obj {
|
override suspend fun plusAssign(scope: Scope, other: Obj): Obj {
|
||||||
|
if (shouldTreatAsSingleElement(scope, other)) {
|
||||||
|
set += other
|
||||||
|
return this
|
||||||
|
}
|
||||||
when (other) {
|
when (other) {
|
||||||
is ObjSet -> {
|
is ObjSet -> {
|
||||||
set += other.set
|
set += other.set
|
||||||
@ -105,6 +122,9 @@ class ObjSet(val set: MutableSet<Obj> = mutableSetOf()) : Obj() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
override suspend fun minus(scope: Scope, other: Obj): Obj {
|
override suspend fun minus(scope: Scope, other: Obj): Obj {
|
||||||
|
if (shouldTreatAsSingleElement(scope, other)) {
|
||||||
|
return ObjSet((set - other).toMutableSet())
|
||||||
|
}
|
||||||
return when {
|
return when {
|
||||||
other is ObjSet -> ObjSet(set.minus(other.set).toMutableSet())
|
other is ObjSet -> ObjSet(set.minus(other.set).toMutableSet())
|
||||||
other.isInstanceOf(ObjIterable) -> {
|
other.isInstanceOf(ObjIterable) -> {
|
||||||
@ -115,8 +135,7 @@ class ObjSet(val set: MutableSet<Obj> = mutableSetOf()) : Obj() {
|
|||||||
}
|
}
|
||||||
ObjSet((set - otherSet).toMutableSet())
|
ObjSet((set - otherSet).toMutableSet())
|
||||||
}
|
}
|
||||||
else ->
|
else -> ObjSet((set - other).toMutableSet())
|
||||||
scope.raiseIllegalArgument("set operator - requires another set or Iterable")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -399,4 +399,52 @@ class TypesTest {
|
|||||||
fun testOk5() { l4(1, "a", "b", "x") }
|
fun testOk5() { l4(1, "a", "b", "x") }
|
||||||
""".trimIndent())
|
""".trimIndent())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testSetTyped() = runTest {
|
||||||
|
eval("""
|
||||||
|
var s = Set<String>()
|
||||||
|
val typed: Set<String> = s
|
||||||
|
assertEquals(Set(), typed)
|
||||||
|
|
||||||
|
s += "foo"
|
||||||
|
assertEquals(Set("foo"), s)
|
||||||
|
s -= "foo"
|
||||||
|
assertEquals(Set(), s)
|
||||||
|
s += ["foo", "bar"]
|
||||||
|
assertEquals(Set("foo", "bar"), s)
|
||||||
|
""".trimIndent())
|
||||||
|
}
|
||||||
|
|
||||||
|
// @Test
|
||||||
|
// fun testAliasesInGenerics1() = runTest {
|
||||||
|
// val scope = Script.newScope()
|
||||||
|
// scope.eval("""
|
||||||
|
// type IntList<T: Int> = List<T>
|
||||||
|
// type IntMap<K,V> = Map<K,V>
|
||||||
|
// type IntSet<T: Int> = Set<T>
|
||||||
|
// type IntPair<T: Int> = Pair<T,T>
|
||||||
|
// type IntTriple<T: Int> = Triple<T,T,T>
|
||||||
|
// type IntQuad<T: Int> = Quad<T,T,T,T>
|
||||||
|
//
|
||||||
|
// import lyng.buffer
|
||||||
|
// type Tag = String | Buffer
|
||||||
|
//
|
||||||
|
// class X {
|
||||||
|
// var tags: Set<Tag> = Set()
|
||||||
|
// }
|
||||||
|
// val x = X()
|
||||||
|
// x.tags += "tag1"
|
||||||
|
// assertEquals(Set("tag1"), x.tags)
|
||||||
|
// x.tags += "tag2"
|
||||||
|
// assertEquals(Set("tag1", "tag2"), x.tags)
|
||||||
|
// x.tags += Buffer("tag3")
|
||||||
|
// assertEquals(Set("tag1", "tag2", Buffer("tag3")), x.tags)
|
||||||
|
// x.tags += Buffer("tag4")
|
||||||
|
// assertEquals(Set("tag1", "tag2", Buffer("tag3"), Buffer("tag4")), x.tags)
|
||||||
|
// x.tags += "tag3"
|
||||||
|
// x.tags += "tag4"
|
||||||
|
// assertEquals(Set("tag1", "tag2", Buffer("tag3"), Buffer("tag4")), x.tags)
|
||||||
|
// """)
|
||||||
|
// }
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user