diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Compiler.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Compiler.kt index f5e06c4..4f89b2c 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Compiler.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Compiler.kt @@ -131,6 +131,15 @@ class Compiler( private fun moduleSlotPlan(): SlotPlan? = slotPlanStack.firstOrNull() private val slotTypeByScopeId: MutableMap> = mutableMapOf() private val nameObjClass: MutableMap = mutableMapOf() + private val slotTypeDeclByScopeId: MutableMap> = mutableMapOf() + private val nameTypeDecl: MutableMap = mutableMapOf() + private val methodReturnTypeDeclByRef: MutableMap = mutableMapOf() + private val callableReturnTypeByScopeId: MutableMap> = mutableMapOf() + private val callableReturnTypeByName: MutableMap = mutableMapOf() + private val lambdaReturnTypeByRef: MutableMap = mutableMapOf() + private val classFieldTypesByName: MutableMap> = mutableMapOf() + private val encodedPayloadTypeByScopeId: MutableMap> = mutableMapOf() + private val encodedPayloadTypeByName: MutableMap = mutableMapOf() private fun seedSlotPlanFromScope(scope: Scope, includeParents: Boolean = false) { val plan = moduleSlotPlan() ?: return @@ -685,6 +694,8 @@ class Compiler( private val seedScope: Scope? = settings.seedScope private var resolutionScriptDepth = 0 private val resolutionPredeclared = mutableSetOf() + private val importedScopes = mutableListOf() + private val enumEntriesByName = mutableMapOf>() // --- Doc-comment collection state (for immediate preceding declarations) --- private val pendingDocLines = mutableListOf() @@ -727,6 +738,111 @@ class Compiler( return sourceName != "lyng.stdlib" } + private fun looksLikeExtensionReceiver(): Boolean { + val saved = cc.savePos() + try { + if (cc.peekNextNonWhitespace().type != Token.Type.ID) return false + cc.nextNonWhitespace() + // consume qualified name segments + while (cc.peekNextNonWhitespace().type == Token.Type.DOT) { + val dotPos = cc.savePos() + cc.nextNonWhitespace() + if (cc.peekNextNonWhitespace().type != Token.Type.ID) { + cc.restorePos(dotPos) + break + } + cc.nextNonWhitespace() + val afterSegment = cc.peekNextNonWhitespace() + if (afterSegment.type != Token.Type.DOT && + afterSegment.type != Token.Type.LT && + afterSegment.type != Token.Type.QUESTION && + afterSegment.type != Token.Type.IFNULLASSIGN + ) { + cc.restorePos(dotPos) + break + } + } + // optional generic arguments + if (cc.peekNextNonWhitespace().type == Token.Type.LT) { + var depth = 0 + while (true) { + val tok = cc.nextNonWhitespace() + when (tok.type) { + Token.Type.LT -> depth += 1 + Token.Type.GT -> { + depth -= 1 + if (depth <= 0) break + } + Token.Type.SHR -> { + depth -= 2 + if (depth <= 0) break + } + Token.Type.EOF -> return false + else -> {} + } + } + } + // nullable suffix + if (cc.peekNextNonWhitespace().type == Token.Type.QUESTION || + cc.peekNextNonWhitespace().type == Token.Type.IFNULLASSIGN + ) { + cc.nextNonWhitespace() + } + val dotTok = cc.peekNextNonWhitespace() + if (dotTok.type != Token.Type.DOT) return false + val savedDot = cc.savePos() + cc.nextNonWhitespace() + val nameTok = cc.peekNextNonWhitespace() + cc.restorePos(savedDot) + return nameTok.type == Token.Type.ID + } finally { + cc.restorePos(saved) + } + } + + private fun shouldImplicitTypeVar(name: String, explicit: Set): Boolean { + if (explicit.contains(name)) return true + if (name.contains('.')) return false + if (resolveClassByName(name) != null) return false + if (resolveTypeDeclObjClass(TypeDecl.Simple(name, false)) != null) return false + return name.length == 1 || name in setOf("T", "R", "E", "K", "V") + } + + private fun normalizeReceiverTypeDecl( + receiver: TypeDecl?, + explicitTypeParams: Set + ): Pair> { + if (receiver == null) return null to emptySet() + val implicit = mutableSetOf() + fun transform(decl: TypeDecl): TypeDecl = when (decl) { + is TypeDecl.Simple -> { + if (shouldImplicitTypeVar(decl.name, explicitTypeParams)) { + if (!explicitTypeParams.contains(decl.name)) implicit += decl.name + TypeDecl.TypeVar(decl.name, decl.isNullable) + } else decl + } + is TypeDecl.TypeVar -> { + if (!explicitTypeParams.contains(decl.name)) implicit += decl.name + decl + } + is TypeDecl.Generic -> TypeDecl.Generic( + decl.name, + decl.args.map { transform(it) }, + decl.isNullable + ) + is TypeDecl.Function -> TypeDecl.Function( + receiver = decl.receiver?.let { transform(it) }, + params = decl.params.map { transform(it) }, + returnType = transform(decl.returnType), + nullable = decl.isNullable + ) + is TypeDecl.Union -> TypeDecl.Union(decl.options.map { transform(it) }, decl.isNullable) + is TypeDecl.Intersection -> TypeDecl.Intersection(decl.options.map { transform(it) }, decl.isNullable) + else -> decl + } + return transform(receiver) to implicit + } + private var anonCounter = 0 private fun generateAnonName(pos: Pos): String { return "${"$"}${"Anon"}_${pos.line+1}_${pos.column}_${++anonCounter}" @@ -909,6 +1025,7 @@ class Compiler( } } val module = importManager.prepareImport(pos, name, null) + importedScopes.add(module) seedResolutionFromScope(module, pos) seedSlotPlanFromScope(module) statements += object : Statement() { @@ -1711,6 +1828,7 @@ class Compiler( } MethodCallRef(left, next.value, args, tailBlock, isOptional) } else { + enforceReceiverTypeForMember(left, next.value, next.pos) MethodCallRef(left, next.value, args, tailBlock, isOptional) } is QualifiedThisRef -> @@ -1724,7 +1842,10 @@ class Compiler( ).also { resolutionSink?.referenceMember(next.value, next.pos, left.typeName) } - else -> MethodCallRef(left, next.value, args, tailBlock, isOptional) + else -> { + enforceReceiverTypeForMember(left, next.value, next.pos) + MethodCallRef(left, next.value, args, tailBlock, isOptional) + } } } @@ -1736,7 +1857,10 @@ class Compiler( val receiverType = if (next.value == "apply" || next.value == "run") { inferReceiverTypeFromRef(left) } else null - val lambda = parseLambdaExpression(receiverType) + val itType = if (next.value == "let" || next.value == "also") { + inferReceiverTypeFromRef(left) + } else null + val lambda = parseLambdaExpression(receiverType, implicitItType = itType) val argPos = next.pos val args = listOf(ParsedArgument(ExpressionStatement(lambda, argPos), next.pos)) operand = when (left) { @@ -1756,6 +1880,7 @@ class Compiler( } MethodCallRef(left, next.value, args, true, isOptional) } else { + enforceReceiverTypeForMember(left, next.value, next.pos) MethodCallRef(left, next.value, args, true, isOptional) } is QualifiedThisRef -> @@ -1769,7 +1894,10 @@ class Compiler( ).also { resolutionSink?.referenceMember(next.value, next.pos, left.typeName) } - else -> MethodCallRef(left, next.value, args, true, isOptional) + else -> { + enforceReceiverTypeForMember(left, next.value, next.pos) + MethodCallRef(left, next.value, args, true, isOptional) + } } } @@ -1784,6 +1912,7 @@ class Compiler( val ids = resolveMemberIds(next.value, next.pos, implicitType) ThisFieldSlotRef(next.value, ids.fieldId, ids.methodId, isOptional) } else { + enforceReceiverTypeForMember(left, next.value, next.pos) FieldRef(left, next.value, isOptional) } is QualifiedThisRef -> run { @@ -1798,7 +1927,10 @@ class Compiler( }.also { resolutionSink?.referenceMember(next.value, next.pos, left.typeName) } - else -> FieldRef(left, next.value, isOptional) + else -> { + enforceReceiverTypeForMember(left, next.value, next.pos) + FieldRef(left, next.value, isOptional) + } } } } @@ -2006,7 +2138,11 @@ class Compiler( /** * Parse lambda expression, leading '{' is already consumed */ - private suspend fun parseLambdaExpression(expectedReceiverType: String? = null): ObjRef { + private suspend fun parseLambdaExpression( + expectedReceiverType: String? = null, + wrapAsExtensionCallable: Boolean = false, + implicitItType: String? = null + ): ObjRef { // lambda args are different: val startPos = cc.currentPos() val label = lastLabel @@ -2021,6 +2157,15 @@ class Compiler( val hasImplicitIt = argsDeclaration == null val slotParamNames = if (hasImplicitIt) paramNames + "it" else paramNames val paramSlotPlan = buildParamSlotPlan(slotParamNames) + if (implicitItType != null) { + val cls = resolveClassByName(implicitItType) + ?: resolveTypeDeclObjClass(TypeDecl.Simple(implicitItType, false)) + val itSlot = paramSlotPlan.slots["it"]?.index + if (cls != null && itSlot != null) { + val paramTypeMap = slotTypeByScopeId.getOrPut(paramSlotPlan.id) { mutableMapOf() } + paramTypeMap[itSlot] = cls + } + } label?.let { cc.labels.add(it) } slotPlanStack.add(paramSlotPlan) @@ -2052,7 +2197,8 @@ class Compiler( val paramSlotPlanSnapshot = slotPlanIndices(paramSlotPlan) val captureSlots = capturePlan.captures.toList() - return ValueFnRef { closureScope -> + val returnClass = inferReturnClassFromStatement(body) + val ref = ValueFnRef { closureScope -> val stmt = object : Statement() { override val pos: Pos = body.pos override suspend fun execute(scope: Scope): Obj { @@ -2104,8 +2250,17 @@ class Compiler( } } } - stmt.asReadonly + val callable: Obj = if (wrapAsExtensionCallable) { + ObjExtensionMethodCallable("", stmt) + } else { + stmt + } + callable.asReadonly } + if (returnClass != null) { + lambdaReturnTypeByRef[ref] = returnClass + } + return ref } private suspend fun parseArrayLiteral(): List { @@ -2523,20 +2678,24 @@ class Compiler( true } else false - val rangeStart = when (receiverMini) { + val normalizedReceiverDecl = receiverDecl?.let { decl -> + if (decl is TypeDecl.Simple && (decl.name == "Object" || decl.name == "Obj")) null else decl + } + val normalizedReceiverMini = if (normalizedReceiverDecl == null) null else receiverMini + val rangeStart = when (normalizedReceiverMini) { null -> startPos - else -> receiverMini.range.start + else -> normalizedReceiverMini.range.start } val rangeEnd = cc.currentPos() val mini = MiniFunctionType( range = MiniRange(rangeStart, rangeEnd), - receiver = receiverMini, + receiver = normalizedReceiverMini, params = params.map { it.second }, returnType = retMini, nullable = isNullable ) val sem = TypeDecl.Function( - receiver = receiverDecl, + receiver = normalizedReceiverDecl, params = params.map { it.first }, returnType = retDecl, nullable = isNullable @@ -2653,16 +2812,134 @@ class Compiler( return Pair(sem, miniRef) } + private fun parseExtensionReceiverTypeWithMini(): Pair { + val segments = mutableListOf() + var first = true + val typeStart = cc.currentPos() + var lastEnd = typeStart + var lastName: String? = null + var lastPos: Pos? = null + while (true) { + val idTok = + if (first) cc.requireToken(Token.Type.ID, "type name or type expression required") else cc.requireToken( + Token.Type.ID, + "identifier expected after '.' in type" + ) + first = false + segments += MiniTypeName.Segment(idTok.value, MiniRange(idTok.pos, idTok.pos)) + lastEnd = cc.currentPos() + lastName = idTok.value + lastPos = idTok.pos + val dotPos = cc.savePos() + val t = cc.next() + if (t.type == Token.Type.DOT) { + val nextAfterDot = cc.peekNextNonWhitespace() + if (nextAfterDot.type != Token.Type.ID) { + cc.restorePos(dotPos) + break + } + cc.nextNonWhitespace() + val afterSegment = cc.peekNextNonWhitespace() + if (afterSegment.type != Token.Type.DOT && + afterSegment.type != Token.Type.LT && + afterSegment.type != Token.Type.QUESTION && + afterSegment.type != Token.Type.IFNULLASSIGN + ) { + cc.restorePos(dotPos) + break + } + segments += MiniTypeName.Segment(nextAfterDot.value, MiniRange(nextAfterDot.pos, nextAfterDot.pos)) + lastEnd = cc.currentPos() + lastName = nextAfterDot.value + lastPos = nextAfterDot.pos + continue + } else { + cc.restorePos(dotPos) + break + } + } + + val qualified = segments.joinToString(".") { it.name } + val typeParams = currentTypeParams() + if (segments.size == 1 && typeParams.contains(qualified)) { + val isNullable = if (cc.skipTokenOfType(Token.Type.QUESTION, isOptional = true)) { + true + } else if (cc.skipTokenOfType(Token.Type.IFNULLASSIGN, isOptional = true)) { + cc.pushPendingAssign() + true + } else false + val rangeEnd = cc.currentPos() + val miniRef = MiniTypeVar(MiniRange(typeStart, rangeEnd), qualified, isNullable) + return TypeDecl.TypeVar(qualified, isNullable) to miniRef + } + if (segments.size > 1) { + lastPos?.let { pos -> resolutionSink?.reference(qualified, pos) } + } else { + lastName?.let { name -> + lastPos?.let { pos -> resolutionSink?.reference(name, pos) } + } + } + fun buildBaseRef(rangeEnd: Pos, args: List?, nullable: Boolean): MiniTypeRef { + val base = MiniTypeName(MiniRange(typeStart, rangeEnd), segments.toList(), nullable = false) + return if (args == null || args.isEmpty()) base.copy( + range = MiniRange(typeStart, rangeEnd), + nullable = nullable + ) + else MiniGenericType(MiniRange(typeStart, rangeEnd), base, args, nullable) + } + + var miniArgs: MutableList? = null + var semArgs: MutableList? = null + val afterBasePos = cc.savePos() + if (cc.skipTokenOfType(Token.Type.LT, isOptional = true)) { + miniArgs = mutableListOf() + semArgs = mutableListOf() + do { + val (argSem, argMini) = parseTypeExpressionWithMini() + miniArgs += argMini + semArgs += argSem + + val sep = cc.next() + if (sep.type == Token.Type.COMMA) { + // continue + } else if (sep.type == Token.Type.GT) { + break + } else if (sep.type == Token.Type.SHR) { + cc.pushPendingGT() + break + } else { + sep.raiseSyntax("expected ',' or '>' in generic arguments") + } + } while (true) + lastEnd = cc.currentPos() + } else { + cc.restorePos(afterBasePos) + } + + val isNullable = if (cc.skipTokenOfType(Token.Type.QUESTION, isOptional = true)) { + true + } else if (cc.skipTokenOfType(Token.Type.IFNULLASSIGN, isOptional = true)) { + cc.pushPendingAssign() + true + } else false + val endPos = cc.currentPos() + + val miniRef = buildBaseRef(if (miniArgs != null) endPos else lastEnd, miniArgs, isNullable) + val sem = if (semArgs != null) TypeDecl.Generic(qualified, semArgs, isNullable) + else TypeDecl.Simple(qualified, isNullable) + return Pair(sem, miniRef) + } + private fun typeDeclToTypeRef(typeDecl: TypeDecl, pos: Pos): ObjRef { return when (typeDecl) { TypeDecl.TypeAny, TypeDecl.TypeNullableAny -> ConstRef(Obj.rootObjectType.asReadonly) is TypeDecl.TypeVar -> resolveLocalTypeRef(typeDecl.name, pos) ?: ConstRef(Obj.rootObjectType.asReadonly) else -> { - val cls = resolveTypeDeclObjClass(typeDecl) - if (cls != null) return ConstRef(cls.asReadonly) val name = typeDeclName(typeDecl) resolveLocalTypeRef(name, pos)?.let { return it } + val cls = resolveTypeDeclObjClass(typeDecl) + if (cls != null) return ConstRef(cls.asReadonly) throw ScriptError(pos, "unknown type $name") } } @@ -2682,7 +2959,11 @@ class Compiler( private fun inferObjClassFromRef(ref: ObjRef): ObjClass? = when (ref) { is ConstRef -> ref.constValue as? ObjClass ?: (ref.constValue as? Obj)?.objClass is LocalVarRef -> nameObjClass[ref.name] - is LocalSlotRef -> nameObjClass[ref.name] + is LocalSlotRef -> { + val ownerScopeId = ref.captureOwnerScopeId ?: ref.scopeId + val ownerSlot = ref.captureOwnerSlot ?: ref.slot + slotTypeByScopeId[ownerScopeId]?.get(ownerSlot) ?: nameObjClass[ref.name] + } is ListLiteralRef -> ObjList.type is MapLiteralRef -> ObjMap.type is RangeRef -> ObjRange.type @@ -2690,6 +2971,307 @@ class Compiler( else -> null } + private fun resolveReceiverTypeDecl(ref: ObjRef): TypeDecl? { + return when (ref) { + is LocalSlotRef -> { + val ownerScopeId = ref.captureOwnerScopeId ?: ref.scopeId + val ownerSlot = ref.captureOwnerSlot ?: ref.slot + slotTypeDeclByScopeId[ownerScopeId]?.get(ownerSlot) + } + is LocalVarRef -> nameTypeDecl[ref.name] + is MethodCallRef -> methodReturnTypeDeclByRef[ref] + is StatementRef -> (ref.statement as? ExpressionStatement)?.let { resolveReceiverTypeDecl(it.ref) } + else -> null + } + } + + private fun resolveReceiverClassForMember(ref: ObjRef): ObjClass? { + return when (ref) { + is LocalSlotRef -> { + val ownerScopeId = ref.captureOwnerScopeId ?: ref.scopeId + val ownerSlot = ref.captureOwnerSlot ?: ref.slot + slotTypeByScopeId[ownerScopeId]?.get(ownerSlot) + ?: slotTypeDeclByScopeId[ownerScopeId]?.get(ownerSlot)?.let { resolveTypeDeclObjClass(it) } + ?: nameObjClass[ref.name] + ?: resolveClassByName(ref.name) + } + is LocalVarRef -> nameObjClass[ref.name] + ?: nameTypeDecl[ref.name]?.let { resolveTypeDeclObjClass(it) } + ?: resolveClassByName(ref.name) + is ConstRef -> ref.constValue as? ObjClass ?: (ref.constValue as? Obj)?.objClass + is ListLiteralRef -> ObjList.type + is MapLiteralRef -> ObjMap.type + is RangeRef -> ObjRange.type + is CastRef -> resolveTypeRefClass(ref.castTypeRef()) + is QualifiedThisRef -> resolveClassByName(ref.typeName) + is StatementRef -> (ref.statement as? ExpressionStatement)?.let { resolveReceiverClassForMember(it.ref) } + is MethodCallRef -> inferMethodCallReturnClass(ref) + is ImplicitThisMethodCallRef -> inferMethodCallReturnClass(ref.methodName()) + is ThisMethodSlotCallRef -> inferMethodCallReturnClass(ref.methodName()) + is QualifiedThisMethodSlotCallRef -> inferMethodCallReturnClass(ref.methodName()) + is CallRef -> inferCallReturnClass(ref) + is FieldRef -> { + val targetClass = resolveReceiverClassForMember(ref.target) + inferFieldReturnClass(targetClass, ref.name) + } + else -> null + } + } + + private fun inferBinaryOpReturnClass(ref: BinaryOpRef): ObjClass? { + val leftClass = resolveReceiverClassForMember(ref.left) ?: inferObjClassFromRef(ref.left) + val rightClass = resolveReceiverClassForMember(ref.right) ?: inferObjClassFromRef(ref.right) + if (leftClass == null || rightClass == null) return null + return when (ref.op) { + BinOp.PLUS, BinOp.MINUS -> when { + leftClass == ObjInstant.type && rightClass == ObjInstant.type && ref.op == BinOp.MINUS -> ObjDuration.type + leftClass == ObjInstant.type && rightClass == ObjDuration.type -> ObjInstant.type + leftClass == ObjDuration.type && rightClass == ObjInstant.type && ref.op == BinOp.PLUS -> ObjInstant.type + leftClass == ObjDuration.type && rightClass == ObjDuration.type -> ObjDuration.type + (leftClass == ObjBuffer.type || leftClass.allParentsSet.contains(ObjBuffer.type)) && + (rightClass == ObjBuffer.type || rightClass.allParentsSet.contains(ObjBuffer.type)) && + ref.op == BinOp.PLUS -> ObjBuffer.type + else -> null + } + else -> null + } + } + + private fun inferCallReturnClass(ref: CallRef): ObjClass? { + return when (val target = ref.target) { + is LocalSlotRef -> callableReturnTypeByScopeId[target.scopeId]?.get(target.slot) + ?: resolveClassByName(target.name) + is LocalVarRef -> callableReturnTypeByName[target.name] + ?: resolveClassByName(target.name) + is ConstRef -> when (val value = target.constValue) { + is ObjClass -> value + is ObjString -> ObjString.type + else -> null + } + else -> null + } + } + + private fun inferMethodCallReturnClass(ref: MethodCallRef): ObjClass? { + val receiverDecl = resolveReceiverTypeDecl(ref.receiver) + val genericReturnDecl = inferMethodCallReturnTypeDecl(ref.name, receiverDecl) + if (genericReturnDecl != null) { + methodReturnTypeDeclByRef[ref] = genericReturnDecl + resolveTypeDeclObjClass(genericReturnDecl)?.let { return it } + if (genericReturnDecl is TypeDecl.TypeVar) { + return Obj.rootObjectType + } + } + if (ref.name == "decode") { + val payload = inferEncodedPayloadClass(ref.args) + if (payload != null) return payload + } + return inferMethodCallReturnClass(ref.name) + } + + private fun inferMethodCallReturnTypeDecl(name: String, receiver: TypeDecl?): TypeDecl? { + val base = when (receiver) { + is TypeDecl.Generic -> receiver.name.substringAfterLast('.') + is TypeDecl.Simple -> receiver.name.substringAfterLast('.') + else -> null + } + return when { + name == "iterator" && receiver is TypeDecl.Generic && base == "Iterable" -> { + val arg = receiver.args.firstOrNull() ?: TypeDecl.TypeAny + TypeDecl.Generic("Iterator", listOf(arg), false) + } + name == "next" && receiver is TypeDecl.Generic && base == "Iterator" -> { + receiver.args.firstOrNull() + } + else -> null + } + } + + private fun inferEncodedPayloadClass(args: List): ObjClass? { + val stmt = args.firstOrNull()?.value as? ExpressionStatement ?: return null + val ref = stmt.ref + val byEncoded = when (ref) { + is LocalSlotRef -> encodedPayloadTypeByScopeId[ref.scopeId]?.get(ref.slot) + is LocalVarRef -> encodedPayloadTypeByName[ref.name] + else -> null + } + if (byEncoded != null) return byEncoded + return when (ref) { + is LocalSlotRef -> { + val ownerScopeId = ref.captureOwnerScopeId ?: ref.scopeId + val ownerSlot = ref.captureOwnerSlot ?: ref.slot + slotTypeByScopeId[ownerScopeId]?.get(ownerSlot) + } + is LocalVarRef -> nameObjClass[ref.name] + is CastRef -> resolveTypeRefClass(ref.castTypeRef()) + is MethodCallRef -> if (ref.name == "encode") inferEncodedPayloadClass(ref.args) else null + else -> null + } + } + + private fun inferMethodCallReturnClass(name: String): ObjClass? = when (name) { + "map", + "mapNotNull", + "filter", + "filterNotNull", + "drop", + "take", + "flatMap", + "flatten", + "sorted", + "sortedBy", + "sortedWith", + "reversed", + "toList", + "shuffle", + "shuffled" -> ObjList.type + "dropLast" -> ObjFlow.type + "takeLast" -> ObjRingBuffer.type + "iterator" -> ObjIterator + "now", + "truncateToSecond", + "truncateToMinute", + "truncateToMillisecond" -> ObjInstant.type + "toDateTime", + "toTimeZone", + "toUTC", + "parseRFC3339", + "addYears", + "addMonths", + "addDays", + "addHours", + "addMinutes", + "addSeconds" -> ObjDateTime.type + "toInstant" -> ObjInstant.type + "toRFC3339", + "toSortableString", + "toJsonString", + "decodeUtf8", + "toDump", + "toString" -> ObjString.type + "startsWith", + "matches" -> ObjBool.type + "toInt", + "toEpochSeconds" -> ObjInt.type + "toMutable" -> ObjMutableBuffer.type + "seq" -> ObjFlow.type + "encode" -> ObjBitBuffer.type + "assertThrows" -> ObjException.Root + else -> null + } + + private fun inferFieldReturnClass(targetClass: ObjClass?, name: String): ObjClass? { + if (targetClass == null) return null + classFieldTypesByName[targetClass.className]?.get(name)?.let { return it } + enumEntriesByName[targetClass.className]?.let { entries -> + return when { + name == "entries" -> ObjList.type + name == "name" -> ObjString.type + name == "ordinal" -> ObjInt.type + entries.contains(name) -> targetClass + else -> null + } + } + if (targetClass == ObjInstant.type && (name == "distantFuture" || name == "distantPast")) { + return ObjInstant.type + } + if (targetClass == ObjString.type && name == "re") { + return ObjRegex.type + } + if (targetClass == ObjInt.type || targetClass == ObjReal.type) { + return when (name) { + "day", + "days", + "hour", + "hours", + "minute", + "minutes", + "second", + "seconds", + "millisecond", + "milliseconds", + "microsecond", + "microseconds" -> ObjDuration.type + else -> null + } + } + if (targetClass == ObjDuration.type) { + return when (name) { + "days", + "hours", + "minutes", + "seconds", + "milliseconds", + "microseconds" -> ObjReal.type + else -> null + } + } + if (targetClass == ObjInstant.type) { + return when (name) { + "epochSeconds" -> ObjInt.type + "epochWholeSeconds" -> ObjInt.type + "truncateToSecond", + "truncateToMinute", + "truncateToMillisecond" -> ObjInstant.type + else -> null + } + } + if (targetClass == ObjDateTime.type) { + return when (name) { + "year", + "month", + "day", + "hour", + "minute", + "second", + "dayOfWeek", + "nanosecond" -> ObjInt.type + "timeZone" -> ObjString.type + else -> null + } + } + if (targetClass == ObjException.Root || targetClass.allParentsSet.contains(ObjException.Root)) { + return when (name) { + "message" -> ObjString.type + "stackTrace" -> ObjList.type + else -> null + } + } + if (targetClass == ObjRegex.type && name == "pattern") { + return ObjString.type + } + return null + } + + private fun enforceReceiverTypeForMember(left: ObjRef, memberName: String, pos: Pos) { + if (left is LocalVarRef && left.name == "scope") return + if (left is LocalSlotRef && left.name == "scope") return + val receiverClass = resolveReceiverClassForMember(left) + if (receiverClass == null) { + val allowed = memberName == "toString" || memberName == "toInspectString" + if (allowed) return + throw ScriptError(pos, "member access requires compile-time receiver type: $memberName") + } + if (receiverClass == Obj.rootObjectType) { + val allowed = isAllowedObjectMember(memberName) + if (!allowed && !hasExtensionFor(receiverClass.className, memberName)) { + throw ScriptError(pos, "member $memberName is not available on Object without explicit cast") + } + } + } + + private fun isAllowedObjectMember(memberName: String): Boolean { + return when (memberName) { + "toString", + "toInspectString", + "let", + "also", + "apply", + "run" -> true + else -> false + } + } + private fun resolveTypeRefClass(ref: ObjRef): ObjClass? = when (ref) { is ConstRef -> ref.constValue as? ObjClass is LocalSlotRef -> resolveTypeDeclObjClass(TypeDecl.Simple(ref.name, false)) ?: nameObjClass[ref.name] @@ -2929,7 +3511,7 @@ class Compiler( val end = cc.next() if (end.type == Token.Type.LBRACE) { val receiverType = inferReceiverTypeFromArgs(parsedArgs) - val callableAccessor = parseLambdaExpression(receiverType) + val callableAccessor = parseLambdaExpression(receiverType, wrapAsExtensionCallable = true) parsedArgs += ParsedArgument(ExpressionStatement(callableAccessor, end.pos), end.pos) detectedBlockArgument = true } else { @@ -3013,10 +3595,16 @@ class Compiler( private fun inferReceiverTypeFromRef(ref: ObjRef): String? { return when (ref) { - is LocalSlotRef -> slotTypeByScopeId[ref.scopeId]?.get(ref.slot)?.className + is LocalSlotRef -> { + val ownerScopeId = ref.captureOwnerScopeId ?: ref.scopeId + val ownerSlot = ref.captureOwnerSlot ?: ref.slot + slotTypeByScopeId[ownerScopeId]?.get(ownerSlot)?.className + ?: slotTypeDeclByScopeId[ownerScopeId]?.get(ownerSlot)?.let { typeDeclName(it) } + } is LocalVarRef -> nameObjClass[ref.name]?.className + ?: nameTypeDecl[ref.name]?.let { typeDeclName(it) } is QualifiedThisRef -> ref.typeName - else -> null + else -> resolveReceiverClassForMember(ref)?.className } } @@ -3555,6 +4143,13 @@ class Compiler( if (name == "Exception") return ObjException.Root scope.raiseSymbolNotFound("error class does not exist or is not a class: $name") } + fun resolveCatchVarClass(names: List): ObjClass? { + if (names.size == 1) { + val name = names.first() + return resolveClassByName(name) ?: if (name == "Exception") ObjException.Root else null + } + return ObjException.Root + } val body = unwrapBytecodeDeep(parseBlock()) val catches = mutableListOf() @@ -3597,9 +4192,15 @@ class Compiler( val block = try { resolutionSink?.enterScope(ScopeKind.BLOCK, catchVar.pos, null) resolutionSink?.declareSymbol(catchVar.value, SymbolKind.LOCAL, isMutable = false, pos = catchVar.pos) + val catchType = resolveCatchVarClass(exClassNames) stripCatchCaptures( withCatchSlot( - unwrapBytecodeDeep(parseBlockWithPredeclared(listOf(catchVar.value to false))), + unwrapBytecodeDeep( + parseBlockWithPredeclared( + listOf(catchVar.value to false), + predeclaredTypes = catchType?.let { mapOf(catchVar.value to it) } ?: emptyMap() + ) + ), catchVar.value ) ) @@ -3616,9 +4217,16 @@ class Compiler( val block = try { resolutionSink?.enterScope(ScopeKind.BLOCK, itToken.pos, null) resolutionSink?.declareSymbol(itToken.value, SymbolKind.LOCAL, isMutable = false, pos = itToken.pos) + val catchType = resolveCatchVarClass(listOf("Exception")) stripCatchCaptures( withCatchSlot( - unwrapBytecodeDeep(parseBlockWithPredeclared(listOf(itToken.value to false), skipLeadingBrace = true)), + unwrapBytecodeDeep( + parseBlockWithPredeclared( + listOf(itToken.value to false), + skipLeadingBrace = true, + predeclaredTypes = catchType?.let { mapOf(itToken.value to it) } ?: emptyMap() + ) + ), itToken.value ) ) @@ -3741,6 +4349,20 @@ class Compiler( entryPositions = positions ) ) + val fieldIds = LinkedHashMap(names.size + 1) + fieldIds["entries"] = 0 + for ((index, entry) in names.withIndex()) { + fieldIds[entry] = index + 1 + } + val methodIds = mapOf("valueOf" to 0) + compileClassInfos[nameToken.value] = CompileClassInfo( + name = nameToken.value, + fieldIds = fieldIds, + methodIds = methodIds, + nextFieldId = fieldIds.size, + nextMethodId = methodIds.size + ) + enumEntriesByName[nameToken.value] = names.toList() val stmtPos = startPos val enumDeclStatement = object : Statement() { @@ -4522,35 +5144,47 @@ class Compiler( ): Statement { isTransientFlag = false val actualExtern = isExtern || (codeContexts.lastOrNull() as? CodeContext.ClassBody)?.isExtern == true - var t = cc.next() - val start = t.pos + var start = cc.currentPos() var extTypeName: String? = null - var name = if (t.type != Token.Type.ID) - throw ScriptError(t.pos, "Expected identifier after 'fun'") - else t.value - var nameStartPos: Pos = t.pos + var receiverTypeDecl: TypeDecl? = null + var name: String + var nameStartPos: Pos var receiverMini: MiniTypeRef? = null val annotation = lastAnnotation val parentContext = codeContexts.last() // Is extension? - if (cc.peekNextNonWhitespace().type == Token.Type.DOT) { - cc.nextNonWhitespace() // consume DOT - extTypeName = name - resolutionSink?.reference(extTypeName, start) - val receiverEnd = Pos(start.source, start.line, start.column + name.length) - receiverMini = MiniTypeName( - range = MiniRange(start, receiverEnd), - segments = listOf(MiniTypeName.Segment(name, MiniRange(start, receiverEnd))), - nullable = false - ) - t = cc.next() - if (t.type != Token.Type.ID) + if (looksLikeExtensionReceiver()) { + val (recvDecl, recvMini) = parseExtensionReceiverTypeWithMini() + receiverTypeDecl = recvDecl + receiverMini = recvMini + val dot = cc.nextNonWhitespace() + if (dot.type != Token.Type.DOT) { + throw ScriptError(dot.pos, "illegal extension format: expected '.' after receiver type") + } + val t = cc.next() + if (t.type != Token.Type.ID) { throw ScriptError(t.pos, "illegal extension format: expected function name") + } name = t.value nameStartPos = t.pos + extTypeName = when (recvDecl) { + is TypeDecl.Simple -> recvDecl.name.substringAfterLast('.') + is TypeDecl.Generic -> recvDecl.name.substringAfterLast('.') + else -> throw ScriptError( + recvMini.range.start, + "illegal extension receiver type: ${typeDeclName(recvDecl)}" + ) + } registerExtensionName(extTypeName, name) + } else { + val t = cc.next() + if (t.type != Token.Type.ID) + throw ScriptError(t.pos, "Expected identifier after 'fun'") + start = t.pos + name = t.value + nameStartPos = t.pos } val extensionWrapperName = extTypeName?.let { extensionCallableName(it, name) } val classCtx = codeContexts.asReversed().firstOrNull { it is CodeContext.ClassBody } as? CodeContext.ClassBody @@ -4570,10 +5204,21 @@ class Compiler( } val typeParamDecls = parseTypeParamList() - val typeParams = typeParamDecls.map { it.name }.toSet() + val explicitTypeParams = typeParamDecls.map { it.name }.toSet() + val receiverNormalization = normalizeReceiverTypeDecl(receiverTypeDecl, explicitTypeParams) + receiverTypeDecl = receiverNormalization.first + val implicitTypeParams = receiverNormalization.second + val mergedTypeParamDecls = if (implicitTypeParams.isEmpty()) { + typeParamDecls + } else { + typeParamDecls + implicitTypeParams.filter { it !in explicitTypeParams } + .map { TypeDecl.TypeParam(it) } + } + val typeParams = mergedTypeParamDecls.map { it.name }.toSet() pendingTypeParamStack.add(typeParams) val argsDeclaration: ArgsDeclaration val returnTypeMini: MiniTypeRef? + val returnTypeDecl: TypeDecl? try { argsDeclaration = if (cc.peekNextNonWhitespace().type == Token.Type.LPAREN) { @@ -4581,14 +5226,16 @@ class Compiler( parseArgsDeclaration() ?: ArgsDeclaration(emptyList(), Token.Type.RPAREN) } else ArgsDeclaration(emptyList(), Token.Type.RPAREN) - if (typeParamDecls.isNotEmpty() && declKind != SymbolKind.MEMBER) { - currentGenericFunctionDecls()[name] = GenericFunctionDecl(typeParamDecls, argsDeclaration.params, nameStartPos) + if (mergedTypeParamDecls.isNotEmpty() && declKind != SymbolKind.MEMBER) { + currentGenericFunctionDecls()[name] = GenericFunctionDecl(mergedTypeParamDecls, argsDeclaration.params, nameStartPos) } // Optional return type - returnTypeMini = if (cc.peekNextNonWhitespace().type == Token.Type.COLON) { - parseTypeDeclarationWithMini().second + val parsedReturn = if (cc.peekNextNonWhitespace().type == Token.Type.COLON) { + parseTypeDeclarationWithMini() } else null + returnTypeMini = parsedReturn?.second + returnTypeDecl = parsedReturn?.first } finally { pendingTypeParamStack.removeLast() } @@ -4603,7 +5250,7 @@ class Compiler( if (!isDelegated && argsDeclaration.endTokenType != Token.Type.RPAREN) throw ScriptError( - t.pos, + nameStartPos, "Bad function definition: expected valid argument declaration or () after 'fn ${name}'" ) @@ -4652,7 +5299,7 @@ class Compiler( outerLabel?.let { cc.labels.add(it) } val paramNamesList = argsDeclaration.params.map { it.name } - val typeParamNames = typeParamDecls.map { it.name } + val typeParamNames = mergedTypeParamDecls.map { it.name } val paramNames: Set = paramNamesList.toSet() val paramSlotPlan = buildParamSlotPlan(paramNamesList + typeParamNames) val capturePlan = CapturePlan(paramSlotPlan) @@ -4660,6 +5307,17 @@ class Compiler( .filter { isRangeType(it.type) } .map { it.name } .toSet() + val paramTypeMap = slotTypeByScopeId.getOrPut(paramSlotPlan.id) { mutableMapOf() } + val paramTypeDeclMap = slotTypeDeclByScopeId.getOrPut(paramSlotPlan.id) { mutableMapOf() } + for (param in argsDeclaration.params) { + val cls = resolveTypeDeclObjClass(param.type) ?: continue + val slot = paramSlotPlan.slots[param.name]?.index ?: continue + paramTypeMap[slot] = cls + } + for (param in argsDeclaration.params) { + val slot = paramSlotPlan.slots[param.name]?.index ?: continue + paramTypeDeclMap[slot] = param.type + } // Parse function body while tracking declared locals to compute precise capacity hints currentLocalDeclCount @@ -4714,6 +5372,16 @@ class Compiler( val rawFnStatements = parsedFnStatements?.let { if (containsUnsupportedForBytecode(it)) unwrapBytecodeDeep(it) else it } + val inferredReturnClass = returnTypeDecl?.let { resolveTypeDeclObjClass(it) } + ?: inferReturnClassFromStatement(rawFnStatements) + if (declKind != SymbolKind.MEMBER && inferredReturnClass != null) { + callableReturnTypeByName[name] = inferredReturnClass + val slotLoc = lookupSlotLocation(name, includeModule = true) + if (slotLoc != null) { + callableReturnTypeByScopeId.getOrPut(slotLoc.scopeId) { mutableMapOf() }[slotLoc.slot] = + inferredReturnClass + } + } val fnStatements = rawFnStatements?.let { stmt -> if (useBytecodeStatements && !containsUnsupportedForBytecode(stmt)) { wrapFunctionBytecode(stmt, name) @@ -4730,7 +5398,7 @@ class Compiler( val paramSlotPlanSnapshot = slotPlanIndices(paramSlotPlan) val captureSlots = capturePlan.captures.toList() val fnBody = object : Statement(), BytecodeBodyProvider { - override val pos: Pos = t.pos + override val pos: Pos = start override fun bytecodeBody(): BytecodeStatement? = fnStatements as? BytecodeStatement override suspend fun execute(callerContext: Scope): Obj { callerContext.pos = start @@ -4756,7 +5424,7 @@ class Compiler( // load params from caller context argsDeclaration.assignToContext(context, callerContext.args, defaultAccessType = AccessType.Val) - bindTypeParamsAtRuntime(context, argsDeclaration, typeParamDecls) + bindTypeParamsAtRuntime(context, argsDeclaration, mergedTypeParamDecls) if (extTypeName != null) { context.thisObj = callerContext.thisObj } @@ -4989,7 +5657,6 @@ class Compiler( val blockSlotPlan = SlotPlan(mutableMapOf(), 0, nextScopeId++) for ((name, isMutable) in predeclared) { declareSlotNameIn(blockSlotPlan, name, isMutable, isDelegated = false) - resolutionSink?.declareSymbol(name, SymbolKind.LOCAL, isMutable, startPos, isOverride = false) } slotPlanStack.add(blockSlotPlan) val capturePlan = CapturePlan(blockSlotPlan) @@ -5007,25 +5674,80 @@ class Compiler( return stmt } + private fun inferReturnClassFromStatement(stmt: Statement?): ObjClass? { + if (stmt == null) return null + val unwrapped = unwrapBytecodeDeep(stmt) + return when (unwrapped) { + is ExpressionStatement -> resolveInitializerObjClass(unwrapped) + is ReturnStatement -> resolveInitializerObjClass(unwrapped.resultExpr) + is VarDeclStatement -> unwrapped.initializerObjClass ?: resolveInitializerObjClass(unwrapped.initializer) + is BlockStatement -> { + val stmts = unwrapped.statements() + val returnTypes = stmts.mapNotNull { s -> + (s as? ReturnStatement)?.let { resolveInitializerObjClass(it.resultExpr) } + } + if (returnTypes.isNotEmpty()) { + val first = returnTypes.first() + if (returnTypes.all { it == first }) first else Obj.rootObjectType + } else { + val last = stmts.lastOrNull() + inferReturnClassFromStatement(last) + } + } + is InlineBlockStatement -> { + val stmts = unwrapped.statements() + val last = stmts.lastOrNull() + inferReturnClassFromStatement(last) + } + is IfStatement -> { + val ifType = inferReturnClassFromStatement(unwrapped.ifBody) + val elseType = unwrapped.elseBody?.let { inferReturnClassFromStatement(it) } + when { + ifType == null && elseType == null -> null + ifType != null && elseType != null && ifType == elseType -> ifType + else -> Obj.rootObjectType + } + } + else -> null + } + } + + private fun unwrapDirectRef(initializer: Statement?): ObjRef? { + var initStmt = initializer + while (initStmt is BytecodeStatement) { + initStmt = initStmt.original + } + val initRef = (initStmt as? ExpressionStatement)?.ref + return when (initRef) { + is StatementRef -> (initRef.statement as? ExpressionStatement)?.ref + else -> initRef + } + } + private fun resolveInitializerObjClass(initializer: Statement?): ObjClass? { if (initializer is BytecodeStatement) { val fn = initializer.bytecodeFunction() if (fn.cmds.any { it is CmdListLiteral }) return ObjList.type if (fn.cmds.any { it is CmdMakeRange || it is CmdRangeIntBounds }) return ObjRange.type } - var initStmt = initializer - while (initStmt is BytecodeStatement) { - initStmt = initStmt.original - } - val initRef = (initStmt as? ExpressionStatement)?.ref - val directRef = when (initRef) { - is StatementRef -> (initRef.statement as? ExpressionStatement)?.ref - else -> initRef + if (initializer is DoWhileStatement) { + val bodyType = inferReturnClassFromStatement(initializer.body) + val elseType = initializer.elseStatement?.let { inferReturnClassFromStatement(it) } + return when { + bodyType == null && elseType == null -> null + bodyType != null && elseType != null && bodyType == elseType -> bodyType + bodyType != null && elseType == null -> bodyType + bodyType == null -> elseType + else -> Obj.rootObjectType + } } + val directRef = unwrapDirectRef(initializer) return when (directRef) { is ListLiteralRef -> ObjList.type is MapLiteralRef -> ObjMap.type is RangeRef -> ObjRange.type + is CastRef -> resolveTypeRefClass(directRef.castTypeRef()) + is BinaryOpRef -> inferBinaryOpReturnClass(directRef) is ImplicitThisMethodCallRef -> { if (directRef.methodName() == "iterator") ObjIterator else null } @@ -5033,20 +5755,37 @@ class Compiler( if (directRef.methodName() == "iterator") ObjIterator else null } is MethodCallRef -> { - if (directRef.name == "iterator") ObjIterator else null + inferMethodCallReturnClass(directRef) + } + is ImplicitThisMethodCallRef -> { + inferMethodCallReturnClass(directRef.methodName()) + } + is FieldRef -> { + val targetClass = resolveReceiverClassForMember(directRef.target) + inferFieldReturnClass(targetClass, directRef.name) } is CallRef -> { val target = directRef.target when { + target is LocalSlotRef -> { + callableReturnTypeByScopeId[target.scopeId]?.get(target.slot) + ?: if (target.name == "iterator") ObjIterator else resolveClassByName(target.name) + } + target is LocalVarRef -> { + callableReturnTypeByName[target.name] + ?: if (target.name == "iterator") ObjIterator else resolveClassByName(target.name) + } target is LocalVarRef && target.name == "List" -> ObjList.type target is LocalVarRef && target.name == "Map" -> ObjMap.type target is LocalVarRef && target.name == "iterator" -> ObjIterator target is ImplicitThisMemberRef && target.name == "iterator" -> ObjIterator target is ThisFieldSlotRef && target.name == "iterator" -> ObjIterator target is FieldRef && target.name == "iterator" -> ObjIterator - target is LocalSlotRef -> resolveClassByName(target.name) - target is LocalVarRef -> resolveClassByName(target.name) - target is ConstRef -> target.constValue as? ObjClass + target is ConstRef -> when (val value = target.constValue) { + is ObjClass -> value + is ObjString -> ObjString.type + else -> null + } else -> null } } @@ -5054,6 +5793,13 @@ class Compiler( is ObjList -> ObjList.type is ObjMap -> ObjMap.type is ObjRange -> ObjRange.type + is ObjString -> ObjString.type + is ObjInt -> ObjInt.type + is ObjReal -> ObjReal.type + is ObjBool -> ObjBool.type + is ObjChar -> ObjChar.type + is ObjNull -> Obj.rootObjectType + is ObjVoid -> ObjVoid.objClass else -> null } else -> null @@ -5095,6 +5841,12 @@ class Compiler( "RegexMatch" -> ObjRegexMatch.type "MapEntry" -> ObjMapEntry.type "Exception" -> ObjException.Root + "Instant" -> ObjInstant.type + "DateTime" -> ObjDateTime.type + "Duration" -> ObjDuration.type + "Buffer" -> ObjBuffer.type + "MutableBuffer" -> ObjMutableBuffer.type + "RingBuffer" -> ObjRingBuffer.type "Callable" -> Statement.type else -> resolveClassByName(rawName) ?: resolveClassByName(name) } @@ -5103,6 +5855,10 @@ class Compiler( private fun resolveClassByName(name: String): ObjClass? { val rec = seedScope?.get(name) ?: importManager.rootScope.get(name) (rec?.value as? ObjClass)?.let { return it } + for (scope in importedScopes.asReversed()) { + val imported = scope.get(name) + (imported?.value as? ObjClass)?.let { return it } + } val info = compileClassInfos[name] ?: return null return compileClassStubs.getOrPut(info.name) { val stub = ObjInstanceClass(info.name) @@ -5137,7 +5893,8 @@ class Compiler( private suspend fun parseBlockWithPredeclared( predeclared: List>, - skipLeadingBrace: Boolean = false + skipLeadingBrace: Boolean = false, + predeclaredTypes: Map = emptyMap() ): Statement { val startPos = cc.currentPos() if (!skipLeadingBrace) { @@ -5149,6 +5906,14 @@ class Compiler( val blockSlotPlan = SlotPlan(mutableMapOf(), 0, nextScopeId++) for ((name, isMutable) in predeclared) { declareSlotNameIn(blockSlotPlan, name, isMutable, isDelegated = false) + resolutionSink?.declareSymbol(name, SymbolKind.LOCAL, isMutable, startPos, isOverride = false) + } + if (predeclaredTypes.isNotEmpty()) { + val typeMap = slotTypeByScopeId.getOrPut(blockSlotPlan.id) { mutableMapOf() } + for ((name, cls) in predeclaredTypes) { + val slot = blockSlotPlan.slots[name]?.index ?: continue + typeMap[slot] = cls + } } slotPlanStack.add(blockSlotPlan) val capturePlan = CapturePlan(blockSlotPlan) @@ -5245,6 +6010,7 @@ class Compiler( ): Statement { isTransientFlag = false val actualExtern = isExtern || (codeContexts.lastOrNull() as? CodeContext.ClassBody)?.isExtern == true + val markStart = cc.savePos() val nextToken = cc.next() val start = nextToken.pos @@ -5301,32 +6067,47 @@ class Compiler( ) } - if (nextToken.type != Token.Type.ID) - throw ScriptError(nextToken.pos, "Expected identifier or [ here") - var name = nextToken.value + cc.restorePos(markStart) + var name: String var extTypeName: String? = null - var nameStartPos: Pos = nextToken.pos + var nameStartPos: Pos var receiverMini: MiniTypeRef? = null + var receiverTypeDecl: TypeDecl? = null - if (cc.peekNextNonWhitespace().type == Token.Type.DOT) { - cc.skipWsTokens() - cc.next() // consume dot - extTypeName = name - resolutionSink?.reference(extTypeName, nextToken.pos) - val receiverEnd = Pos(nextToken.pos.source, nextToken.pos.line, nextToken.pos.column + name.length) - receiverMini = MiniTypeName( - range = MiniRange(nextToken.pos, receiverEnd), - segments = listOf(MiniTypeName.Segment(name, MiniRange(nextToken.pos, receiverEnd))), - nullable = false - ) + if (looksLikeExtensionReceiver()) { + val (recvDecl, recvMini) = parseExtensionReceiverTypeWithMini() + receiverTypeDecl = recvDecl + receiverMini = recvMini + val dot = cc.nextNonWhitespace() + if (dot.type != Token.Type.DOT) + throw ScriptError(dot.pos, "Expected '.' after extension receiver type") val nameToken = cc.next() if (nameToken.type != Token.Type.ID) throw ScriptError(nameToken.pos, "Expected identifier after dot in extension declaration") name = nameToken.value nameStartPos = nameToken.pos + extTypeName = when (recvDecl) { + is TypeDecl.Simple -> recvDecl.name.substringAfterLast('.') + is TypeDecl.Generic -> recvDecl.name.substringAfterLast('.') + else -> throw ScriptError( + recvMini.range.start, + "illegal extension receiver type: ${typeDeclName(recvDecl)}" + ) + } registerExtensionName(extTypeName, name) + } else { + val nameToken = cc.next() + if (nameToken.type != Token.Type.ID) + throw ScriptError(nameToken.pos, "Expected identifier or [ here") + name = nameToken.value + nameStartPos = nameToken.pos } + val receiverNormalization = normalizeReceiverTypeDecl(receiverTypeDecl, emptySet()) + val implicitTypeParams = receiverNormalization.second + if (implicitTypeParams.isNotEmpty()) pendingTypeParamStack.add(implicitTypeParams) + try { + val classCtx = codeContexts.asReversed().firstOrNull { it is CodeContext.ClassBody } as? CodeContext.ClassBody val memberFieldId = if (extTypeName == null) classCtx?.memberFieldIds?.get(name) else null val memberMethodId = if (extTypeName == null) classCtx?.memberMethodIds?.get(name) else null @@ -5508,7 +6289,35 @@ class Compiler( val slotPlan = slotPlanStack.lastOrNull() val slotIndex = slotPlan?.slots?.get(name)?.index val scopeId = slotPlan?.id - val initObjClass = resolveInitializerObjClass(initialExpression) ?: resolveTypeDeclObjClass(varTypeDecl) + val directRef = unwrapDirectRef(initialExpression) + val declClass = resolveTypeDeclObjClass(varTypeDecl) + val initFromExpr = resolveInitializerObjClass(initialExpression) + val isNullLiteral = (directRef as? ConstRef)?.constValue == ObjNull + val initObjClass = if (declClass != null && isNullLiteral) declClass else initFromExpr ?: declClass + if (varTypeDecl !is TypeDecl.TypeAny && varTypeDecl !is TypeDecl.TypeNullableAny) { + if (slotIndex != null && scopeId != null) { + slotTypeDeclByScopeId.getOrPut(scopeId) { mutableMapOf() }[slotIndex] = varTypeDecl + } + nameTypeDecl[name] = varTypeDecl + } + if (directRef is ValueFnRef) { + val returnClass = lambdaReturnTypeByRef[directRef] + if (returnClass != null) { + if (slotIndex != null && scopeId != null) { + callableReturnTypeByScopeId.getOrPut(scopeId) { mutableMapOf() }[slotIndex] = returnClass + } + callableReturnTypeByName[name] = returnClass + } + } + if (directRef is MethodCallRef && directRef.name == "encode") { + val payloadClass = inferEncodedPayloadClass(directRef.args) + if (payloadClass != null) { + if (slotIndex != null && scopeId != null) { + encodedPayloadTypeByScopeId.getOrPut(scopeId) { mutableMapOf() }[slotIndex] = payloadClass + } + encodedPayloadTypeByName[name] = payloadClass + } + } if (initObjClass != null) { if (slotIndex != null && scopeId != null) { slotTypeByScopeId.getOrPut(scopeId) { mutableMapOf() }[slotIndex] = initObjClass @@ -5529,6 +6338,16 @@ class Compiler( } if (isStatic) { + if (declaringClassNameCaptured != null) { + val directRef = unwrapDirectRef(initialExpression) + val declClass = resolveTypeDeclObjClass(varTypeDecl) + val initFromExpr = resolveInitializerObjClass(initialExpression) + val isNullLiteral = (directRef as? ConstRef)?.constValue == ObjNull + val initClass = if (declClass != null && isNullLiteral) declClass else initFromExpr ?: declClass + if (initClass != null) { + classFieldTypesByName.getOrPut(declaringClassNameCaptured) { mutableMapOf() }[name] = initClass + } + } // find objclass instance: this is tricky: this code executes in object initializer, // when creating instance, but we need to execute it in the class initializer which // is missing as for now. Add it to the compiler context? @@ -6068,7 +6887,10 @@ class Compiler( } } } + } finally { + if (implicitTypeParams.isNotEmpty()) pendingTypeParamStack.removeLast() } +} data class Operator( val tokenType: Token.Type, diff --git a/lynglib/stdlib/lyng/root.lyng b/lynglib/stdlib/lyng/root.lyng index 8dcfbf4..0fd0f03 100644 --- a/lynglib/stdlib/lyng/root.lyng +++ b/lynglib/stdlib/lyng/root.lyng @@ -15,15 +15,15 @@ extern fun pow(x: Object, y: Object): Real extern fun sqrt(x: Object): Real // Last regex match result, updated by =~ / !~. -var $~ = null +var $~: Object? = null /* Wrap a builder into a zero-argument thunk that computes once and caches the result. The first call invokes builder() and stores the value; subsequent calls return the cached value. */ fun cached(builder: ()->T): ()->T { - var calculated = false - var value = null + var calculated: Bool = false + var value: Object? = null { if( !calculated ) { value = builder() @@ -35,7 +35,7 @@ fun cached(builder: ()->T): ()->T { /* Filter elements of this iterable using the provided predicate and provide a flow of results. Coudl be used to map infinte flows, etc. */ -fun Iterable.filterFlow(predicate: (T)->Bool): Flow { +fun Iterable.filterFlow(predicate: (T)->Bool): Flow { val list = this flow { for( item in list ) { @@ -49,7 +49,7 @@ fun Iterable.filterFlow(predicate: (T)->Bool): Flow { /* Filter this iterable and return List of elements */ -fun Iterable.filter(predicate: (T)->Bool): List { +fun Iterable.filter(predicate: (T)->Bool): List { var result: List = List() for( item in this ) if( predicate(item) ) result += item result @@ -58,7 +58,7 @@ fun Iterable.filter(predicate: (T)->Bool): List { /* Count all items in this iterable for which predicate returns true */ -fun Iterable.count(predicate: (T)->Bool): Int { +fun Iterable.count(predicate: (T)->Bool): Int { var hits = 0 this.forEach { if( predicate(it) ) hits++ @@ -69,25 +69,25 @@ fun Iterable.count(predicate: (T)->Bool): Int { filter out all null elements from this collection (Iterable); flow of non-null elements is returned */ -fun Iterable.filterFlowNotNull(): Flow { +fun Iterable.filterFlowNotNull(): Flow { filterFlow { it != null } } /* Filter non-null elements and collect them into a List */ -fun Iterable.filterNotNull(): List { +fun Iterable.filterNotNull(): List { filter { it != null } } /* Skip the first N elements of this iterable. */ -fun Iterable.drop(n: Int): List { +fun Iterable.drop(n: Int): List { var cnt = 0 filter { cnt++ >= n } } /* Return the first element or throw if the iterable is empty. */ -val Iterable.first: Object get() { - val i: Iterator = iterator() +val Iterable.first: T get() { + val i: Iterator = iterator() if( !i.hasNext() ) throw NoSuchElementException() i.next().also { i.cancelIteration() } } @@ -96,7 +96,7 @@ val Iterable.first: Object get() { Return the first element that matches the predicate or throws NuSuchElementException */ -fun Iterable.findFirst(predicate: (T)->Bool): T { +fun Iterable.findFirst(predicate: (T)->Bool): T { for( x in this ) { if( predicate(x) ) break x @@ -107,7 +107,7 @@ fun Iterable.findFirst(predicate: (T)->Bool): T { /* return the first element matching the predicate or null */ -fun Iterable.findFirstOrNull(predicate: (T)->Bool): T? { +fun Iterable.findFirstOrNull(predicate: (T)->Bool): T? { for( x in this ) { if( predicate(x) ) break x @@ -117,19 +117,19 @@ fun Iterable.findFirstOrNull(predicate: (T)->Bool): T? { /* Return the last element or throw if the iterable is empty. */ -val Iterable.last: Object get() { +val Iterable.last: T get() { var found = false - var element = null + var element: Object = Unset for( i in this ) { element = i found = true } if( !found ) throw NoSuchElementException() - element + element as T } /* Emit all but the last N elements of this iterable. */ -fun Iterable.dropLast(n: Int): Flow { +fun Iterable.dropLast(n: Int): Flow { val list = this val buffer = RingBuffer(n) flow { @@ -142,25 +142,25 @@ fun Iterable.dropLast(n: Int): Flow { } /* Return the last N elements of this iterable as a buffer/list. */ -fun Iterable.takeLast(n: Int): RingBuffer { +fun Iterable.takeLast(n: Int): RingBuffer { val buffer: RingBuffer = RingBuffer(n) for( item in this ) buffer += item buffer } /* Join elements into a string with a separator (separator parameter) and optional transformer. */ -fun Iterable.joinToString(separator: String=" ", transformer: (T)->Object = { it }): String { - var result = null +fun Iterable.joinToString(separator: String=" ", transformer: (T)->Object = { it }): String { + var result: String? = null for( part in this ) { val transformed = transformer(part).toString() if( result == null ) result = transformed - else result += separator + transformed + else result = (result as String) + separator + transformed } result ?: "" } /* Return true if any element matches the predicate. */ -fun Iterable.any(predicate: (T)->Bool): Bool { +fun Iterable.any(predicate: (T)->Bool): Bool { for( i in this ) { if( predicate(i) ) break true @@ -168,13 +168,13 @@ fun Iterable.any(predicate: (T)->Bool): Bool { } /* Return true if all elements match the predicate. */ -fun Iterable.all(predicate: (T)->Bool): Bool { +fun Iterable.all(predicate: (T)->Bool): Bool { !any { !predicate(it) } } /* Sum all elements; returns null for empty collections. */ -fun Iterable.sum(): T? { - val i: Iterator = iterator() +fun Iterable.sum(): T? { + val i: Iterator = iterator() if( i.hasNext() ) { var result = i.next() while( i.hasNext() ) result += i.next() @@ -184,8 +184,8 @@ fun Iterable.sum(): T? { } /* Sum mapped values of elements; returns null for empty collections. */ -fun Iterable.sumOf(f: (T)->R): R? { - val i: Iterator = iterator() +fun Iterable.sumOf(f: (T)->R): R? { + val i: Iterator = iterator() if( i.hasNext() ) { var result = f(i.next()) while( i.hasNext() ) result += f(i.next()) @@ -195,8 +195,8 @@ fun Iterable.sumOf(f: (T)->R): R? { } /* Minimum value of the given function applied to elements of the collection. */ -fun Iterable.minOf(lambda: (T)->R): R { - val i: Iterator = iterator() +fun Iterable.minOf(lambda: (T)->R): R { + val i: Iterator = iterator() var minimum = lambda( i.next() ) while( i.hasNext() ) { val x = lambda(i.next()) @@ -206,8 +206,8 @@ fun Iterable.minOf(lambda: (T)->R): R { } /* Maximum value of the given function applied to elements of the collection. */ -fun Iterable.maxOf(lambda: (T)->R): R { - val i: Iterator = iterator() +fun Iterable.maxOf(lambda: (T)->R): R { + val i: Iterator = iterator() var maximum = lambda( i.next() ) while( i.hasNext() ) { val x = lambda(i.next()) @@ -217,17 +217,17 @@ fun Iterable.maxOf(lambda: (T)->R): R { } /* Return elements sorted by natural order. */ -fun Iterable.sorted(): List { +fun Iterable.sorted(): List { sortedWith { a, b -> a <=> b } } /* Return elements sorted by the key selector. */ -fun Iterable.sortedBy(predicate: (T)->R): List { +fun Iterable.sortedBy(predicate: (T)->R): List { sortedWith { a, b -> predicate(a) <=> predicate(b) } } /* Return a shuffled copy of the iterable as a list. */ -fun Iterable.shuffled(): List { +fun Iterable.shuffled(): List { val list: List = toList() list.shuffle() list @@ -237,7 +237,7 @@ fun Iterable.shuffled(): List { Returns a single list of all elements from all collections in the given collection. @return List */ -fun Iterable.flatten(): List { +fun Iterable>.flatten(): List { var result: List = List() forEach { i -> i.forEach { result += it } @@ -249,13 +249,13 @@ fun Iterable.flatten(): List { Returns a single list of all elements yielded from results of transform function being invoked on each element of original collection. */ -fun Iterable.flatMap(transform: (T)->Iterable): List { +fun Iterable.flatMap(transform: (T)->Iterable): List { val mapped: List> = map(transform) mapped.flatten() } /* Return string representation like [a,b,c]. */ -override fun List.toString() { +override fun List.toString() { var first = true var result = "[" for (item in this) { @@ -267,12 +267,12 @@ override fun List.toString() { } /* Sort list in-place by key selector. */ -fun List.sortBy(predicate: (T)->R): Void { +fun List.sortBy(predicate: (T)->R): Void { sortWith { a, b -> predicate(a) <=> predicate(b) } } /* Sort list in-place by natural order. */ -fun List.sort(): Void { +fun List.sort(): Void { sortWith { a, b -> a <=> b } } @@ -327,9 +327,7 @@ interface Delegate { returns what the block returns. */ fun with(self: T, block: T.()->R): R { - var result = Unset - self.apply { result = block() } - result as R + block(self) } /* @@ -337,16 +335,16 @@ fun with(self: T, block: T.()->R): R { The provided creator lambda is called once on the first access to compute the value. Can only be used with 'val' properties. */ -class lazy(creatorParam: Object.()->T) : Delegate { - private val creator: Object.()->T = creatorParam +class lazy(creatorParam: ThisRefType.()->T) : Delegate { + private val creator: ThisRefType.()->T = creatorParam private var value = Unset - override fun bind(name: String, access: DelegateAccess, thisRef: Object): Object { + override fun bind(name: String, access: DelegateAccess, thisRef: ThisRefType): Object { if (access.toString() != "DelegateAccess.Val") throw "lazy delegate can only be used with 'val'" this } - override fun getValue(thisRef: Object, name: String): T { + override fun getValue(thisRef: ThisRefType, name: String): T { if (value == Unset) value = with(thisRef,creator) value as T