diff --git a/docs/OOP.md b/docs/OOP.md index f16c2aa..45d56be 100644 --- a/docs/OOP.md +++ b/docs/OOP.md @@ -479,6 +479,18 @@ val block: context(Html, Head) Body.()->String = { } ``` +Context receivers can also constrain extension functions. The extension is visible only when the required receiver is +already in the implicit receiver stack: + +```lyng +class Tag { fun addText(text: String) { /* ... */ } } + +context(Tag) +fun String.unaryPlus() { + this@Tag.addText(this) +} +``` + - Field inheritance (`val`/`var`) and collisions - Instance storage is kept per declaring class, internally disambiguated; unqualified read/write resolves to the first match in the resolution order (leftmost base). - Qualified read/write (via `this@Type` or casts) targets the chosen ancestor’s storage. @@ -650,10 +662,14 @@ Unary operators are overloaded by defining methods with no arguments: | Operator | Method Name | | :--- | :--- | +| `+a` | `unaryPlus()` | | `-a` | `negate()` | | `!a` | `logicalNot()` | | `~a` | `bitNot()` | +`unaryPlus()` is useful in DSL-style builders where `+"text"` should append text to +the current receiver. See [samples/html_builder_dsl.lyng](samples/html_builder_dsl.lyng). + ### Assignment Operators Assignment operators like `+=` first attempt to call a specific assignment method. If that method is not defined, they fall back to a combination of the binary operator and a regular assignment (e.g., `a = a + b`). diff --git a/docs/ai_language_reference.md b/docs/ai_language_reference.md index df15065..c82cd43 100644 --- a/docs/ai_language_reference.md +++ b/docs/ai_language_reference.md @@ -83,6 +83,7 @@ Primary sources used: `lynglib/src/commonMain/kotlin/net/sergeych/lyng/{Parser,T ## 4. Operators (implemented) - Assignment: `=`, `+=`, `-=`, `*=`, `/=`, `%=`, `?=`. - Logical: `||`, `&&`, unary `!`. +- Unary arithmetic/bitwise: unary `+`, unary `-`, `~`. - Bitwise: `|`, `^`, `&`, `~`, shifts `<<`, `>>`. - Equality/comparison: `==`, `!=`, `===`, `!==`, `<`, `<=`, `>`, `>=`, `<=>`, `=~`, `!~`. - Type/containment: `is`, `!is`, `in`, `!in`, `as`, `as?`. @@ -119,6 +120,7 @@ Primary sources used: `lynglib/src/commonMain/kotlin/net/sergeych/lyng/{Parser,T - shorthand: `fun f(x) = expr`. - generics: `fun f(x: T): T`. - extension functions: `fun Type.name(...) { ... }`. + - context-aware extension functions: `context(Tag) fun String.unaryPlus() { this@Tag.addText(this) }`. - named singleton `object` declarations can be extension receivers too: `fun Config.describe(...) { ... }`, `val Config.tag get() = ...`. - static extension functions are callable on the type object: `static fun List.fill(...)` -> `List.fill(...)`. - delegated callable: `fun f(...) by delegate`. diff --git a/docs/samples/html_builder_dsl.lyng b/docs/samples/html_builder_dsl.lyng new file mode 100644 index 0000000..28a370a --- /dev/null +++ b/docs/samples/html_builder_dsl.lyng @@ -0,0 +1,50 @@ +class Tag(name: String) { + val name = name + var inner = "" + + fun child(tagName: String, block: Tag.()->void) { + val child = Tag(tagName) + with(child) { block(this) } + inner += child.render() + } + + fun head(block: Tag.()->void) { child("head", block) } + fun body(block: Tag.()->void) { child("body", block) } + fun title(block: Tag.()->void) { child("title", block) } + fun h1(block: Tag.()->void) { child("h1", block) } + + fun addText(text: String) { + inner += text + } + + fun render() { + "<" + name + ">" + inner + "" + } +} + +context(Tag) +fun String.unaryPlus() { + this@Tag.addText(this) +} + +fun html(block: Tag.()->void) { + val root = Tag("html") + with(root) { block(this) } + root.render() +} + +val page = html { + head { + title { + +"Demo" + } + } + body { + h1 { + +"Heading 1" + } + } +} + +println(page) +assertEquals("Demo

Heading 1

", page) diff --git a/docs/samples/operator_overloading.lyng b/docs/samples/operator_overloading.lyng index 2d42eff..ef26e40 100644 --- a/docs/samples/operator_overloading.lyng +++ b/docs/samples/operator_overloading.lyng @@ -1,6 +1,9 @@ // Sample: Operator Overloading in Lyng class Vector(val x: T, val y: T) { + // Overload unary + + fun unaryPlus() = this + // Overload + fun plus(other: Vector) = Vector(x + other.x, y + other.y) @@ -28,6 +31,11 @@ val v2 = Vector(5, 5) println("v1: " + v1) println("v2: " + v2) +// Test unary + +val v0 = +v1 +println("+v1 = " + v0) +assertEquals(Vector(10, 20), v0) + // Test binary + val v3 = v1 + v2 println("v1 + v2 = " + v3) diff --git a/lyng/src/commonMain/kotlin/Common.kt b/lyng/src/commonMain/kotlin/Common.kt index 120bf1e..7067e4a 100644 --- a/lyng/src/commonMain/kotlin/Common.kt +++ b/lyng/src/commonMain/kotlin/Common.kt @@ -45,6 +45,7 @@ import net.sergeych.lyng.io.db.createDbModule import net.sergeych.lyng.io.db.jdbc.createJdbcModule import net.sergeych.lyng.io.db.sqlite.createSqliteModule import net.sergeych.lyng.io.fs.createFs +import net.sergeych.lyng.io.html.createHtmlModule import net.sergeych.lyng.io.http.createHttpModule import net.sergeych.lyng.io.http.server.createHttpServerModule import net.sergeych.lyng.io.net.createNetModule @@ -146,6 +147,7 @@ private fun ImportManager.invalidateCliModuleCaches() { invalidatePackageCache("lyng.io.console") invalidatePackageCache("lyng.io.db.jdbc") invalidatePackageCache("lyng.io.db.sqlite") + invalidatePackageCache("lyng.io.html") invalidatePackageCache("lyng.io.http") invalidatePackageCache("lyng.io.http.server") invalidatePackageCache("lyng.io.ws") @@ -237,6 +239,7 @@ private fun installCliModules(manager: ImportManager) { createDbModule(manager) createJdbcModule(manager) createSqliteModule(manager) + createHtmlModule(manager) createHttpModule(PermitAllHttpAccessPolicy, manager) createHttpServerModule(PermitAllNetAccessPolicy, manager) createWsModule(PermitAllWsAccessPolicy, manager) diff --git a/lyngio/src/commonMain/kotlin/net/sergeych/lyng/io/html/LyngHtmlModule.kt b/lyngio/src/commonMain/kotlin/net/sergeych/lyng/io/html/LyngHtmlModule.kt new file mode 100644 index 0000000..bd42289 --- /dev/null +++ b/lyngio/src/commonMain/kotlin/net/sergeych/lyng/io/html/LyngHtmlModule.kt @@ -0,0 +1,44 @@ +/* + * Copyright 2026 Sergey S. Chernov real.sergeych@gmail.com + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package net.sergeych.lyng.io.html + +import net.sergeych.lyng.ModuleScope +import net.sergeych.lyng.Scope +import net.sergeych.lyng.Source +import net.sergeych.lyng.pacman.ImportManager +import net.sergeych.lyngio.stdlib_included.htmlLyng + +private const val HTML_MODULE_NAME = "lyng.io.html" + +fun createHtmlModule(scope: Scope): Boolean = createHtmlModule(scope.importManager) + +fun createHtml(scope: Scope): Boolean = createHtmlModule(scope) + +fun createHtmlModule(manager: ImportManager): Boolean { + if (manager.packageNames.contains(HTML_MODULE_NAME)) return false + manager.addPackage(HTML_MODULE_NAME) { module -> + buildHtmlModule(module) + } + return true +} + +fun createHtml(manager: ImportManager): Boolean = createHtmlModule(manager) + +private suspend fun buildHtmlModule(module: ModuleScope) { + module.eval(Source(HTML_MODULE_NAME, htmlLyng)) +} diff --git a/lyngio/src/commonTest/kotlin/net/sergeych/lyng/io/html/LyngHtmlModuleTest.kt b/lyngio/src/commonTest/kotlin/net/sergeych/lyng/io/html/LyngHtmlModuleTest.kt new file mode 100644 index 0000000..30a4790 --- /dev/null +++ b/lyngio/src/commonTest/kotlin/net/sergeych/lyng/io/html/LyngHtmlModuleTest.kt @@ -0,0 +1,57 @@ +/* + * Copyright 2026 Sergey S. Chernov real.sergeych@gmail.com + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package net.sergeych.lyng.io.html + +import kotlinx.coroutines.test.runTest +import net.sergeych.lyng.Compiler +import net.sergeych.lyng.Script +import net.sergeych.lyng.Source +import net.sergeych.lyng.pacman.ImportManager +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFalse +import kotlin.test.assertTrue + +class LyngHtmlModuleTest { + + @Test + fun testModuleRegistrationIsIdempotent() = runTest { + val importManager = ImportManager() + assertTrue(createHtmlModule(importManager)) + assertFalse(createHtmlModule(importManager)) + } + + @Test + fun testModuleCanBeImported() = runTest { + val scope = Script.newScope() + createHtmlModule(scope.importManager) + + val result = Compiler.compile( + Source( + "", + """ + import lyng.io.html + 42 + """.trimIndent() + ), + scope.importManager + ).execute(scope) + + assertEquals("42", result.inspect(scope)) + } +} diff --git a/lyngio/stdlib/lyng/io/html.lyng b/lyngio/stdlib/lyng/io/html.lyng new file mode 100644 index 0000000..d44260e --- /dev/null +++ b/lyngio/stdlib/lyng/io/html.lyng @@ -0,0 +1,22 @@ +package lyng.io.html + +/* + HTML helpers package. + API surface is intentionally empty for now; this package exists so Lyng code + can import `lyng.io.html` and grow declarations here incrementally. +*/ + +class HtmlBuilder { + + private val head: List = [] + private val body: List = [] + + fun build(): String = + "" + + (head.isEmpty() ? "" : head.joinToString("\n")) + + (body.isEmpty() ? "" : body.joinToString("\n")) + + "" +} + +fun buildHtml(f: HtmlBuilder.()->void): String { + HtmlBuilder().apply(f).build() \ No newline at end of file diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Compiler.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Compiler.kt index f9240a2..3078137 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Compiler.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Compiler.kt @@ -885,6 +885,19 @@ class Compiler( ) } + private fun promotePreferredReceiverArg(context: Scope, preferredTypeName: String?) { + if (preferredTypeName == null) return + val receiverArg = context.args.list.firstOrNull { arg -> + arg.isInstanceOf(preferredTypeName) || + ((context[preferredTypeName]?.value as? ObjClass)?.let { typeClass -> + arg.isInstanceOf(typeClass) + } == true) + } ?: return + if (context.thisVariants.firstOrNull() !== receiverArg) { + context.setThisVariants(receiverArg, context.thisVariants) + } + } + private fun currentImplicitReceiverTypeNames(): List { val result = mutableListOf() for (ctx in codeContexts.asReversed()) { @@ -2010,6 +2023,7 @@ class Compiler( callableReturnTypeByScopeId = callableReturnTypeByScopeId, callableReturnTypeByName = callableReturnTypeByName, callSignatureByName = callSignatureByName, + extensionContextReceiversByWrapperName = extensionContextReceiversByWrapperName, externBindingNames = externBindingNames, preparedModuleBindingNames = importBindings.keys, scopeRefPosByName = moduleReferencePosByName, @@ -2078,19 +2092,60 @@ class Compiler( private val rangeParamNamesStack = mutableListOf>() private val extensionNames = mutableSetOf() private val extensionNamesByType = mutableMapOf>() + private val extensionContextReceiversByWrapperName = mutableMapOf>() private val useScopeSlots: Boolean = seedScope == null private fun registerExtensionName(typeName: String, memberName: String) { extensionNamesByType.getOrPut(typeName) { mutableSetOf() }.add(memberName) } + private fun contextReceiverTypeName(typeDecl: TypeDecl): String? = when (typeDecl) { + is TypeDecl.Simple -> typeDecl.name.substringAfterLast('.') + is TypeDecl.Generic -> typeDecl.name.substringAfterLast('.') + else -> null + } + + private fun contextReceiversSatisfied(required: List, visibleReceivers: List = currentImplicitReceiverTypeNames()): Boolean { + if (required.isEmpty()) return true + return required.all { req -> + visibleReceivers.any { visible -> + visible == req || resolveClassByName(visible)?.let { cls -> + cls.className == req || cls.mro.any { it.className == req } + } == true + } + } + } + + private fun rememberExtensionContextReceivers(wrapperName: String, record: ObjRecord) { + val fnType = record.typeDecl as? TypeDecl.Function ?: return + if (fnType.contextReceivers.isEmpty()) return + val names = fnType.contextReceivers.mapNotNull(::contextReceiverTypeName) + if (names.size == fnType.contextReceivers.size) { + extensionContextReceiversByWrapperName[wrapperName] = names + } + } + private fun hasExtensionFor(typeName: String, memberName: String): Boolean { - if (extensionNamesByType[typeName]?.contains(memberName) == true) return true + if (extensionNamesByType[typeName]?.contains(memberName) == true) { + val wrapperName = extensionCallableName(typeName, memberName) + if (contextReceiversSatisfied(extensionContextReceiversByWrapperName[wrapperName].orEmpty())) return true + val getterName = extensionPropertyGetterName(typeName, memberName) + if (contextReceiversSatisfied(extensionContextReceiversByWrapperName[getterName].orEmpty())) return true + val setterName = extensionPropertySetterName(typeName, memberName) + if (contextReceiversSatisfied(extensionContextReceiversByWrapperName[setterName].orEmpty())) return true + } val scopeRec = seedScope?.get(typeName) ?: importManager.rootScope.get(typeName) val cls = (scopeRec?.value as? ObjClass) ?: resolveTypeDeclObjClass(TypeDecl.Simple(typeName, false)) if (cls != null) { for (base in cls.mro) { - if (extensionNamesByType[base.className]?.contains(memberName) == true) return true + if (extensionNamesByType[base.className]?.contains(memberName) == true) { + val wrapperName = extensionCallableName(base.className, memberName) + if (contextReceiversSatisfied(extensionContextReceiversByWrapperName[wrapperName].orEmpty())) return true + val getterName = extensionPropertyGetterName(base.className, memberName) + if (contextReceiversSatisfied(extensionContextReceiversByWrapperName[getterName].orEmpty())) return true + val setterName = extensionPropertySetterName(base.className, memberName) + if (contextReceiversSatisfied(extensionContextReceiversByWrapperName[setterName].orEmpty())) return true + } } } val candidates = mutableListOf(typeName) @@ -2102,7 +2157,10 @@ class Compiler( extensionPropertySetterName(baseName, memberName) ) for (wrapperName in wrapperNames) { + if (!contextReceiversSatisfied(extensionContextReceiversByWrapperName[wrapperName].orEmpty())) continue val resolved = resolveImportBinding(wrapperName, Pos.builtIn) ?: continue + rememberExtensionContextReceivers(wrapperName, resolved.record) + if (!contextReceiversSatisfied(extensionContextReceiversByWrapperName[wrapperName].orEmpty())) continue val plan = moduleSlotPlan() if (plan != null && !plan.slots.containsKey(wrapperName)) { declareSlotNameIn( @@ -2385,6 +2443,7 @@ class Compiler( callableReturnTypeByScopeId = callableReturnTypeByScopeId, callableReturnTypeByName = callableReturnTypeByName, callSignatureByName = callSignatureByName, + extensionContextReceiversByWrapperName = extensionContextReceiversByWrapperName, externCallableNames = externCallableNames, externBindingNames = externBindingNames, preparedModuleBindingNames = importBindings.keys, @@ -2420,6 +2479,7 @@ class Compiler( callableReturnTypeByScopeId = callableReturnTypeByScopeId, callableReturnTypeByName = callableReturnTypeByName, callSignatureByName = callSignatureByName, + extensionContextReceiversByWrapperName = extensionContextReceiversByWrapperName, externCallableNames = externCallableNames, externBindingNames = externBindingNames, preparedModuleBindingNames = importBindings.keys, @@ -2480,6 +2540,7 @@ class Compiler( callableReturnTypeByScopeId = callableReturnTypeByScopeId, callableReturnTypeByName = callableReturnTypeByName, callSignatureByName = callSignatureByName, + extensionContextReceiversByWrapperName = extensionContextReceiversByWrapperName, externCallableNames = externCallableNames, externBindingNames = externBindingNames, preparedModuleBindingNames = importBindings.keys, @@ -3721,7 +3782,7 @@ class Compiler( val inlineBodyRef = argsDeclaration?.let { null } ?: extractInlineLambdaBodyRef(body) val supportsDirectInvokeFastPath = bytecodeFn != null && bytecodeFn.scopeSlotCount == 0 && - expectedReceiverType == null && + effectiveExpectedReceiverType == null && !wrapAsExtensionCallable && !containsDelegatedRefs(body) val ref = LambdaFnRef( @@ -3733,9 +3794,10 @@ class Compiler( override fun bytecodeBody(): BytecodeStatement? = fnStatements as? BytecodeStatement override fun callOnFast(scope: Scope): Obj? { - val context = scope.applyClosureForBytecode(closureScope, preferredThisType = expectedReceiverType).also { + val context = scope.applyClosureForBytecode(closureScope, preferredThisType = effectiveExpectedReceiverType).also { it.args = scope.args } + promotePreferredReceiverArg(context, effectiveExpectedReceiverType) if (captureSlots.isNotEmpty()) { if (captureRecords != null) { context.captureRecords = captureRecords @@ -3804,9 +3866,10 @@ class Compiler( } override suspend fun execute(scope: Scope): Obj { - val context = scope.applyClosureForBytecode(closureScope, preferredThisType = expectedReceiverType).also { + val context = scope.applyClosureForBytecode(closureScope, preferredThisType = effectiveExpectedReceiverType).also { it.args = scope.args } + promotePreferredReceiverArg(context, effectiveExpectedReceiverType) if (captureSlots.isNotEmpty()) { if (captureRecords != null) { context.captureRecords = captureRecords @@ -5321,6 +5384,7 @@ class Compiler( is RangeRef -> ObjRange.type is ClassOperatorRef -> ObjClassType is CastRef -> resolveTypeRefClass(ref.castTypeRef()) + is UnaryOpRef -> inferUnaryOpReturnClass(ref) is IndexRef -> { val targetClass = resolveReceiverClassForMember(ref.targetRef) classMethodReturnClass(targetClass, "getAt") @@ -5439,6 +5503,7 @@ class Compiler( ?: resolveClassByName(ref.receiverTypeName())?.let { classMethodReturnTypeDecl(it, ref.methodName()) } } is CallRef -> callReturnTypeDeclByRef[ref] ?: inferCallReturnTypeDecl(ref) + is UnaryOpRef -> inferUnaryOpReturnTypeDecl(ref) is BinaryOpRef -> inferBinaryOpReturnTypeDecl(ref) is ElvisRef -> inferElvisTypeDecl(ref) is StatementRef -> (ref.statement as? ExpressionStatement)?.let { resolveReceiverTypeDecl(it.ref) } @@ -5513,6 +5578,7 @@ class Compiler( is QualifiedThisMethodSlotCallRef -> inferMethodCallReturnClass(resolveClassByName(ref.receiverTypeName()), ref.methodName()) is CallRef -> inferCallReturnTypeDecl(ref)?.let { resolveTypeDeclObjClass(it) } ?: inferCallReturnClass(ref) + is UnaryOpRef -> inferUnaryOpReturnClass(ref) is BinaryOpRef -> inferBinaryOpReturnClass(ref) is FieldRef -> { val targetClass = resolveReceiverClassForMember(ref.target) @@ -5539,6 +5605,13 @@ class Compiler( else -> null } + private fun unaryOpMethodName(op: UnaryOp): String? = when (op) { + UnaryOp.POSITIVE -> "unaryPlus" + UnaryOp.NEGATE -> "negate" + UnaryOp.BITNOT -> "bitNot" + UnaryOp.NOT -> null + } + private fun interopOperatorFor(op: BinOp): InteropOperator? = when (op) { BinOp.PLUS -> InteropOperator.Plus BinOp.MINUS -> InteropOperator.Minus @@ -5614,6 +5687,67 @@ class Compiler( } } + private fun inferExtensionMethodReturnTypeDecl( + receiverDecl: TypeDecl?, + receiverClass: ObjClass?, + memberName: String + ): TypeDecl? { + if (receiverClass == null) return null + for (cls in receiverClass.mro) { + val wrapperName = extensionCallableName(cls.className, memberName) + val resolved = resolveImportBinding(wrapperName, Pos.builtIn) ?: continue + registerImportBinding(wrapperName, resolved.binding, Pos.builtIn) + val wrapperType = resolved.record.typeDecl as? TypeDecl.Function ?: continue + val bindings = mutableMapOf() + val receiverParam = wrapperType.params.firstOrNull() ?: wrapperType.receiver + if (receiverParam != null && receiverDecl != null) { + collectTypeVarBindings(receiverParam, receiverDecl, bindings) + } + return if (bindings.isEmpty()) wrapperType.returnType + else substituteTypeAliasTypeVars(wrapperType.returnType, bindings) + } + return null + } + + private fun inferUnaryOpReturnTypeDecl(ref: UnaryOpRef): TypeDecl? { + val operandDecl = resolveReceiverTypeDecl(ref.a) + val operandClass = resolveReceiverClassForMember(ref.a) ?: inferObjClassFromRef(ref.a) + return when (ref.op) { + UnaryOp.NOT -> typeDeclOfClass(ObjBool.type) + UnaryOp.POSITIVE -> { + unaryOpMethodName(ref.op)?.let { methodName -> + classMethodReturnTypeDecl(operandClass, methodName)?.let { return it } + inferExtensionMethodReturnTypeDecl(operandDecl, operandClass, methodName)?.let { return it } + } + operandDecl ?: operandClass?.let(::typeDeclOfClass) + } + UnaryOp.NEGATE -> when (operandClass) { + ObjInt.type -> typeDeclOfClass(ObjInt.type) + ObjReal.type -> typeDeclOfClass(ObjReal.type) + else -> unaryOpMethodName(ref.op)?.let { methodName -> + classMethodReturnTypeDecl(operandClass, methodName) + ?: inferExtensionMethodReturnTypeDecl(operandDecl, operandClass, methodName) + } + } + UnaryOp.BITNOT -> when (operandClass) { + ObjInt.type -> typeDeclOfClass(ObjInt.type) + else -> unaryOpMethodName(ref.op)?.let { methodName -> + classMethodReturnTypeDecl(operandClass, methodName) + ?: inferExtensionMethodReturnTypeDecl(operandDecl, operandClass, methodName) + } + } + } + } + + private fun inferUnaryOpReturnClass(ref: UnaryOpRef): ObjClass? { + inferUnaryOpReturnTypeDecl(ref)?.let { declared -> + resolveTypeDeclObjClass(declared)?.let { return it } + if (declared is TypeDecl.TypeVar) return Obj.rootObjectType + } + val operandClass = resolveReceiverClassForMember(ref.a) ?: inferObjClassFromRef(ref.a) + return if (ref.op == UnaryOp.POSITIVE) operandClass else null + } + private fun inferBinaryOpReturnClass(ref: BinaryOpRef): ObjClass? { inferBinaryOpReturnTypeDecl(ref)?.let { declared -> resolveTypeDeclObjClass(declared)?.let { return it } @@ -7439,6 +7573,7 @@ class Compiler( is FastLocalVarRef -> nameObjClass[ref.name]?.className ?: nameTypeDecl[ref.name]?.let { typeDeclName(it) } is QualifiedThisRef -> ref.typeName + is UnaryOpRef -> inferUnaryOpReturnClass(ref)?.className else -> resolveReceiverClassForMember(ref)?.className } } @@ -7458,8 +7593,12 @@ class Compiler( Token.Type.CHAR -> ConstRef(ObjChar(t.value[0]).asReadonly) Token.Type.PLUS -> { - val n = parseNumber(true) - ConstRef(n.asReadonly) + parseNumberOrNull(true)?.let { n -> + ConstRef(n.asReadonly) + } ?: run { + val n = parseTerm() ?: throw ScriptError(t.pos, "Expecting expression after unary plus") + UnaryOpRef(UnaryOp.POSITIVE, n) + } } Token.Type.MINUS -> { @@ -7655,6 +7794,43 @@ class Compiler( } } + private fun parseContextReceiverDeclarationList(start: Pos): List { + if (!cc.skipTokenOfType(Token.Type.LPAREN, isOptional = true)) { + throw ScriptError(start, "expected '(' after context") + } + val receivers = mutableListOf() + cc.skipWsTokens() + if (cc.peekNextNonWhitespace().type == Token.Type.RPAREN) { + cc.nextNonWhitespace() + return receivers + } + while (true) { + val (decl, _) = parseTypeExpressionWithMini() + receivers += decl + val sep = cc.nextNonWhitespace() + when (sep.type) { + Token.Type.COMMA -> continue + Token.Type.RPAREN -> return receivers + else -> sep.raiseSyntax("expected ',' or ')' in context receiver list") + } + } + } + + private suspend fun parseContextFunctionDeclaration(contextToken: Token): Statement { + val contextReceivers = parseContextReceiverDeclarationList(contextToken.pos) + val fn = cc.nextNonWhitespace() + if (fn.type != Token.Type.ID || (fn.value != "fun" && fn.value != "fn")) { + throw ScriptError(fn.pos, "context receivers are currently supported only on function declarations") + } + pendingDeclStart = contextToken.pos + pendingDeclDoc = consumePendingDoc() + return parseFunctionDeclaration( + isExtern = false, + isStatic = false, + contextReceiverTypeDecls = contextReceivers + ) + } + /** * Parse keyword-starting statement. * @return parsed statement or null if, for example. [id] is not among keywords @@ -7693,6 +7869,7 @@ class Compiler( pendingDeclDoc = consumePendingDoc() parseFunctionDeclaration(isExtern = false, isStatic = false) } + "context" -> parseContextFunctionDeclaration(id) // Visibility modifiers for declarations: private/protected val/var/fun/fn "while" -> parseWhileStatement() "do" -> parseDoWhileStatement() @@ -9648,7 +9825,8 @@ class Compiler( isOverride: Boolean = false, isExtern: Boolean = false, isStatic: Boolean = false, - isTransient: Boolean = isTransientFlag + isTransient: Boolean = isTransientFlag, + contextReceiverTypeDecls: List = emptyList() ): Statement { isTransientFlag = false val declarationAnnotationSpecs = pendingDeclAnnotations.toList() @@ -9688,7 +9866,17 @@ class Compiler( ) } registerExtensionName(extTypeName, name) + if (contextReceiverTypeDecls.isNotEmpty()) { + val contextNames = contextReceiverTypeDecls.mapNotNull(::contextReceiverTypeName) + if (contextNames.size != contextReceiverTypeDecls.size) { + throw ScriptError(start, "context receiver types for extension functions must be class-like") + } + extensionContextReceiversByWrapperName[extensionCallableName(extTypeName, name)] = contextNames + } } else { + if (contextReceiverTypeDecls.isNotEmpty()) { + throw ScriptError(start, "context receivers are currently supported only on extension functions") + } val t = cc.next() if (t.type != Token.Type.ID) throw ScriptError(t.pos, "Expected identifier after 'fun'") @@ -9774,6 +9962,7 @@ class Compiler( if (parentContext is CodeContext.ClassBody && !isStatic && extTypeName == null) { classMemberTypeDeclByName.getOrPut(parentContext.name) { mutableMapOf() }[name] = TypeDecl.Function( receiver = receiverTypeDecl, + contextReceivers = contextReceiverTypeDecls, params = argsDeclaration.params.map { it.type }, returnType = returnTypeDecl ?: TypeDecl.TypeAny, nullable = false @@ -9851,7 +10040,7 @@ class Compiler( CodeContext.Function( name, implicitThisMembers = implicitThisMembers, - implicitReceiverTypeNames = listOfNotNull(implicitThisTypeName), + implicitReceiverTypeNames = listOfNotNull(implicitThisTypeName) + contextReceiverTypeDecls.mapNotNull(::contextReceiverTypeName), typeParams = typeParams, typeParamDecls = typeParamDecls, noImplicitThis = noImplicitThis @@ -9949,6 +10138,7 @@ class Compiler( run { val memberTypeDecl = TypeDecl.Function( receiver = receiverTypeDecl, + contextReceivers = contextReceiverTypeDecls, params = argsDeclaration.params.map { it.type }, returnType = inferredReturnDecl ?: TypeDecl.TypeAny, nullable = false @@ -10062,7 +10252,7 @@ class Compiler( } } if (extTypeName != null) { - context.thisObj = scope.thisObj + context.setThisVariants(scope.thisObj, context.thisVariants) } val localNames = frame.fn.localSlotNames for (i in localNames.indices) { @@ -10145,6 +10335,7 @@ class Compiler( annotation = annotation, typeDecl = if (isDelegated) null else TypeDecl.Function( receiver = receiverTypeDecl, + contextReceivers = contextReceiverTypeDecls, params = argsDeclaration.params.map { it.type }, returnType = inferredReturnDecl ?: TypeDecl.TypeAny, nullable = false @@ -11644,6 +11835,7 @@ class Compiler( val a = constOf(aRef) ?: return null return when (op) { UnaryOp.NOT -> if (a is ObjBool) if (!a.value) ObjTrue else ObjFalse else null + UnaryOp.POSITIVE -> a UnaryOp.NEGATE -> when (a) { is ObjInt -> ObjInt.of(-a.value) is ObjReal -> ObjReal.of(-a.value) diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Scope.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Scope.kt index 3852a2d..585121e 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Scope.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Scope.kt @@ -99,6 +99,20 @@ open class Scope( extensions.getOrPut(cls) { mutableMapOf() }[name] = record } + private fun extensionContextReceiversSatisfied(record: ObjRecord): Boolean { + val fnType = record.typeDecl as? TypeDecl.Function ?: return true + if (fnType.contextReceivers.isEmpty()) return true + return fnType.contextReceivers.all { required -> + thisVariants.any { variant -> + when (required) { + is TypeDecl.Simple -> variant.isInstanceOf(required.name.substringAfterLast('.')) + is TypeDecl.Generic -> variant.isInstanceOf(required.name.substringAfterLast('.')) + else -> false + } + } + } + } + internal fun findExtension(receiverClass: ObjClass, name: String): ObjRecord? { var s: Scope? = this var hops = 0 @@ -106,7 +120,9 @@ open class Scope( // Proximity rule: check all extensions in the current scope before going to parent. // Priority within scope: more specific class in MRO wins. for (cls in receiverClass.mro) { - s.extensions[cls]?.get(name)?.let { return it } + s.extensions[cls]?.get(name)?.let { + if (extensionContextReceiversSatisfied(it)) return it + } } if (s is BytecodeClosureScope) { s.closureScope.findExtension(receiverClass, name)?.let { return it } diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeCompiler.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeCompiler.kt index 99e456f..b20da72 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeCompiler.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeCompiler.kt @@ -43,6 +43,7 @@ class BytecodeCompiler( private val callableReturnTypeByScopeId: Map> = emptyMap(), private val callableReturnTypeByName: Map = emptyMap(), private val callSignatureByName: Map = emptyMap(), + private val extensionContextReceiversByWrapperName: Map> = emptyMap(), private val externCallableNames: Set = emptySet(), private val externBindingNames: Set = emptySet(), private val preparedModuleBindingNames: Set = emptySet(), @@ -1146,56 +1147,95 @@ class BytecodeCompiler( } private fun compileUnary(ref: UnaryOpRef): CompiledValue? { - val a = compileRef(unaryOperand(ref)) ?: return null - val out = allocSlot() return when (unaryOp(ref)) { - UnaryOp.NEGATE -> when (a.type) { - SlotType.INT -> { - builder.emit(Opcode.NEG_INT, a.slot, out) - CompiledValue(out, SlotType.INT) + UnaryOp.POSITIVE -> { + val operandRef = unaryOperand(ref) + if (hasUnaryCallable(operandRef, "unaryPlus")) { + return compileMethodCall(MethodCallRef(operandRef, "unaryPlus", emptyList(), false, false)) } - SlotType.REAL -> { - builder.emit(Opcode.NEG_REAL, a.slot, out) - CompiledValue(out, SlotType.REAL) - } - else -> compileObjUnaryOp(unaryOperand(ref), a, "negate", Pos.builtIn) - } - UnaryOp.NOT -> { - when (a.type) { - SlotType.BOOL -> builder.emit(Opcode.NOT_BOOL, a.slot, out) - SlotType.INT -> { - val tmp = allocSlot() - builder.emit(Opcode.INT_TO_BOOL, a.slot, tmp) - builder.emit(Opcode.NOT_BOOL, tmp, out) + val a = compileRef(operandRef) ?: return null + return when (a.type) { + SlotType.INT, SlotType.REAL -> a + else -> { + val obj = ensureObjSlot(a) + val out = allocSlot() + builder.emit(Opcode.POS_OBJ, obj.slot, out) + updateSlotType(out, SlotType.OBJ) + slotObjClass[obj.slot]?.let { slotObjClass[out] = it } + CompiledValue(out, SlotType.OBJ) } - SlotType.OBJ, SlotType.UNKNOWN -> { - val objSlot = ensureObjSlot(a) - val tmp = allocSlot() - builder.emit(Opcode.OBJ_TO_BOOL, objSlot.slot, tmp) - builder.emit(Opcode.NOT_BOOL, tmp, out) - updateSlotType(tmp, SlotType.BOOL) - } - else -> return null } - CompiledValue(out, SlotType.BOOL) } - UnaryOp.BITNOT -> { - if (a.type == SlotType.INT) { - builder.emit(Opcode.INV_INT, a.slot, out) - return CompiledValue(out, SlotType.INT) + else -> { + val a = compileRef(unaryOperand(ref)) ?: return null + val out = allocSlot() + when (unaryOp(ref)) { + UnaryOp.NEGATE -> when (a.type) { + SlotType.INT -> { + builder.emit(Opcode.NEG_INT, a.slot, out) + CompiledValue(out, SlotType.INT) + } + SlotType.REAL -> { + builder.emit(Opcode.NEG_REAL, a.slot, out) + CompiledValue(out, SlotType.REAL) + } + else -> compileObjUnaryOp(unaryOperand(ref), a, "negate", Pos.builtIn) + } + UnaryOp.NOT -> { + when (a.type) { + SlotType.BOOL -> builder.emit(Opcode.NOT_BOOL, a.slot, out) + SlotType.INT -> { + val tmp = allocSlot() + builder.emit(Opcode.INT_TO_BOOL, a.slot, tmp) + builder.emit(Opcode.NOT_BOOL, tmp, out) + } + SlotType.OBJ, SlotType.UNKNOWN -> { + val objSlot = ensureObjSlot(a) + val tmp = allocSlot() + builder.emit(Opcode.OBJ_TO_BOOL, objSlot.slot, tmp) + builder.emit(Opcode.NOT_BOOL, tmp, out) + updateSlotType(tmp, SlotType.BOOL) + } + else -> return null + } + CompiledValue(out, SlotType.BOOL) + } + UnaryOp.BITNOT -> { + if (a.type == SlotType.INT) { + builder.emit(Opcode.INV_INT, a.slot, out) + return CompiledValue(out, SlotType.INT) + } + return compileObjUnaryOp(unaryOperand(ref), a, "bitNot", Pos.builtIn) + } + UnaryOp.POSITIVE -> error("unreachable") } - return compileObjUnaryOp(unaryOperand(ref), a, "bitNot", Pos.builtIn) } } } + private fun hasUnaryCallable(ref: ObjRef, memberName: String): Boolean { + val receiverClass = resolveReceiverClass(ref) ?: return false + if (receiverClass == ObjDynamic.type) return false + if (receiverClass is ObjInstanceClass && !isThisReceiver(ref)) return true + val resolvedMember = receiverClass.resolveInstanceMember(memberName) + if (resolvedMember?.declaringClass?.className == "Obj") return false + val abstractRecord = receiverClass.members[memberName] ?: receiverClass.classScope?.objects?.get(memberName) + if (abstractRecord?.isAbstract == true) return false + val methodId = receiverClass.instanceMethodIdMap(includeAbstract = true)[memberName] + if (methodId != null && resolvedMember?.declaringClass?.className != "Obj") return true + val fieldId = if (resolvedMember != null) receiverClass.instanceFieldIdMap()[memberName] else null + if (fieldId != null) return true + return resolveExtensionCallableSlot(receiverClass, memberName) != null + } + private fun compileObjUnaryOp( ref: ObjRef, value: CompiledValue, memberName: String, - pos: Pos + pos: Pos, + defaultIdentity: Boolean = false ): CompiledValue? { - val receiverClass = resolveReceiverClass(ref) + val receiverClass = resolveReceiverClass(ref) ?: slotObjClass[value.slot] val methodId = receiverClass?.instanceMethodIdMap(includeAbstract = true)?.get(memberName) if (methodId != null) { val receiverObj = ensureObjSlot(value) @@ -1204,6 +1244,19 @@ class BytecodeCompiler( updateSlotType(dst, SlotType.OBJ) return CompiledValue(dst, SlotType.OBJ) } + val extSlot = when { + receiverClass != null -> resolveExtensionCallableSlot(receiverClass, memberName) + else -> resolveUniqueExtensionWrapperSlot(memberName, "__ext__") + } + if (extSlot != null) { + val callee = ensureObjSlot(extSlot) + val args = compileCallArgsWithReceiver(value, emptyList(), false) ?: return null + val encodedCount = encodeCallArgCount(args) ?: return null + val dst = allocSlot() + setPos(pos) + emitCallCompiled(callee, args.base, encodedCount, dst) + return CompiledValue(dst, SlotType.OBJ) + } if (memberName == "negate" && (receiverClass == null || isDelegateClass(receiverClass) || receiverClass in setOf(ObjInt.type, ObjReal.type)) ) { @@ -1217,6 +1270,9 @@ class BytecodeCompiler( updateSlotType(dst, SlotType.OBJ) return CompiledValue(dst, SlotType.OBJ) } + if (defaultIdentity) { + return value + } throw BytecodeCompileException( "Unknown member $memberName on ${receiverClass?.className ?: "unknown"}", pos @@ -5972,6 +6028,7 @@ class BytecodeCompiler( ): String? { for (receiverName in extensionReceiverTypeNames(receiverClass)) { val candidate = wrapperName(receiverName, memberName) + if (!extensionContextReceiversSatisfied(candidate)) continue if (allowedScopeNames != null && !allowedScopeNames.contains(candidate) && !localSlotIndexByName.containsKey(candidate) @@ -5983,6 +6040,31 @@ class BytecodeCompiler( return null } + private fun currentImplicitReceiverTypeNames(): List { + val result = mutableListOf() + inlineThisBindings.asReversed().forEach { binding -> + val typeName = binding.typeName ?: return@forEach + if (!result.contains(typeName)) result += typeName + } + implicitThisTypeName?.let { + if (!result.contains(it)) result += it + } + return result + } + + private fun extensionContextReceiversSatisfied(wrapperName: String): Boolean { + val required = extensionContextReceiversByWrapperName[wrapperName].orEmpty() + if (required.isEmpty()) return true + val visible = currentImplicitReceiverTypeNames() + return required.all { req -> + visible.any { visibleName -> + visibleName == req || resolveTypeNameClass(visibleName)?.let { cls -> + cls.className == req || cls.mro.any { it.className == req } + } == true + } + } + } + private fun resolveUniqueExtensionWrapperName( memberName: String, wrapperPrefix: String @@ -5991,12 +6073,12 @@ class BytecodeCompiler( val candidates = LinkedHashSet() for (name in localSlotIndexByName.keys) { if (name.startsWith(wrapperPrefix) && name.endsWith(suffix)) { - candidates.add(name) + if (extensionContextReceiversSatisfied(name)) candidates.add(name) } } for (name in scopeSlotIndexByName.keys) { if (name.startsWith(wrapperPrefix) && name.endsWith(suffix)) { - candidates.add(name) + if (extensionContextReceiversSatisfied(name)) candidates.add(name) } } return candidates.singleOrNull() @@ -8313,6 +8395,19 @@ class BytecodeCompiler( is ObjChar -> ObjChar.type else -> null } + is UnaryOpRef -> when (ref.op) { + UnaryOp.NOT -> ObjBool.type + UnaryOp.POSITIVE -> resolveReceiverClass(ref.a) + UnaryOp.NEGATE -> when (val operandClass = resolveReceiverClass(ref.a)) { + ObjInt.type -> ObjInt.type + ObjReal.type -> ObjReal.type + else -> inferMethodCallReturnClass(operandClass, "negate") + } + UnaryOp.BITNOT -> when (val operandClass = resolveReceiverClass(ref.a)) { + ObjInt.type -> ObjInt.type + else -> inferMethodCallReturnClass(operandClass, "bitNot") + } + } is CastRef -> resolveTypeRefClass(ref.castTypeRef()) ?: resolveReceiverClass(ref.castValueRef()) is FieldRef -> { diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeStatement.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeStatement.kt index 17312a9..b8d3a55 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeStatement.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeStatement.kt @@ -107,6 +107,7 @@ class BytecodeStatement private constructor( callableReturnTypeByScopeId: Map> = emptyMap(), callableReturnTypeByName: Map = emptyMap(), callSignatureByName: Map = emptyMap(), + extensionContextReceiversByWrapperName: Map> = emptyMap(), externCallableNames: Set = emptySet(), externBindingNames: Set = emptySet(), preparedModuleBindingNames: Set = emptySet(), @@ -148,6 +149,7 @@ class BytecodeStatement private constructor( callableReturnTypeByScopeId = callableReturnTypeByScopeId, callableReturnTypeByName = callableReturnTypeByName, callSignatureByName = callSignatureByName, + extensionContextReceiversByWrapperName = extensionContextReceiversByWrapperName, externCallableNames = externCallableNames, externBindingNames = externBindingNames, preparedModuleBindingNames = preparedModuleBindingNames, diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/CmdBuilder.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/CmdBuilder.kt index 344769e..d3731bd 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/CmdBuilder.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/CmdBuilder.kt @@ -143,7 +143,7 @@ class CmdBuilder { Opcode.UNBOX_INT_OBJ, Opcode.UNBOX_REAL_OBJ, Opcode.INT_TO_REAL, Opcode.REAL_TO_INT, Opcode.BOOL_TO_INT, Opcode.INT_TO_BOOL, Opcode.OBJ_TO_BOOL, Opcode.GET_OBJ_CLASS, - Opcode.NEG_INT, Opcode.NEG_REAL, Opcode.NOT_BOOL, Opcode.INV_INT, + Opcode.NEG_INT, Opcode.NEG_REAL, Opcode.NOT_BOOL, Opcode.INV_INT, Opcode.POS_OBJ, Opcode.ASSERT_IS -> listOf(OperandKind.SLOT, OperandKind.SLOT) Opcode.CHECK_IS, Opcode.MAKE_QUALIFIED_VIEW -> @@ -698,6 +698,7 @@ class CmdBuilder { } else { CmdNotBool(operands[0], operands[1]) } + Opcode.POS_OBJ -> CmdPosObj(operands[0], operands[1]) Opcode.AND_BOOL -> if (isFastLocal(operands[0]) && isFastLocal(operands[1]) && isFastLocal(operands[2])) { CmdAndBoolLocal(operands[0] - scopeSlotCount, operands[1] - scopeSlotCount, operands[2] - scopeSlotCount) } else { diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/CmdDisassembler.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/CmdDisassembler.kt index ceb54be..023eeaf 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/CmdDisassembler.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/CmdDisassembler.kt @@ -450,6 +450,7 @@ object CmdDisassembler { is CmdMulObj -> Opcode.MUL_OBJ to intArrayOf(cmd.a, cmd.b, cmd.dst) is CmdDivObj -> Opcode.DIV_OBJ to intArrayOf(cmd.a, cmd.b, cmd.dst) is CmdModObj -> Opcode.MOD_OBJ to intArrayOf(cmd.a, cmd.b, cmd.dst) + is CmdPosObj -> Opcode.POS_OBJ to intArrayOf(cmd.a, cmd.dst) is CmdContainsObj -> Opcode.CONTAINS_OBJ to intArrayOf(cmd.target, cmd.value, cmd.dst) is CmdAssignOpObj -> Opcode.ASSIGN_OP_OBJ to intArrayOf(cmd.opId, cmd.targetSlot, cmd.valueSlot, cmd.dst, cmd.nameId) is CmdJmp -> Opcode.JMP to intArrayOf(cmd.target) @@ -593,6 +594,8 @@ object CmdDisassembler { Opcode.ADD_OBJ, Opcode.SUB_OBJ, Opcode.MUL_OBJ, Opcode.DIV_OBJ, Opcode.MOD_OBJ, Opcode.CONTAINS_OBJ, Opcode.AND_BOOL, Opcode.OR_BOOL -> listOf(OperandKind.SLOT, OperandKind.SLOT, OperandKind.SLOT) + Opcode.POS_OBJ -> + listOf(OperandKind.SLOT, OperandKind.SLOT) Opcode.ASSIGN_OP_OBJ -> listOf(OperandKind.ID, OperandKind.SLOT, OperandKind.SLOT, OperandKind.SLOT, OperandKind.CONST) Opcode.INC_INT, Opcode.DEC_INT, Opcode.RET, Opcode.ITER_PUSH, Opcode.LOAD_THIS -> diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/CmdRuntime.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/CmdRuntime.kt index 18c0170..1514a64 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/CmdRuntime.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/CmdRuntime.kt @@ -1942,6 +1942,14 @@ class CmdCmpGteObj(internal val a: Int, internal val b: Int, internal val dst: I } } +class CmdPosObj(internal val a: Int, internal val dst: Int) : Cmd() { + override suspend fun perform(frame: CmdFrame) { + val result = frame.slotToObj(a).unaryPlus(frame.ensureScope()) + frame.storeObjResult(dst, result) + return + } +} + class CmdAddObj(internal val a: Int, internal val b: Int, internal val dst: Int) : Cmd() { override suspend fun perform(frame: CmdFrame) { val result = frame.slotToObj(a).plus(frame.ensureScope(), frame.slotToObj(b)) @@ -4176,6 +4184,15 @@ class BytecodeLambdaCallable( val context = callScope.applyClosureForBytecode(closureScope, preferredThisType).also { it.args = args } + preferredThisType?.let { typeName -> + val receiverArg = args.list.firstOrNull { arg -> + arg.isInstanceOf(typeName) || + ((context[typeName]?.value as? ObjClass)?.let { typeClass -> arg.isInstanceOf(typeClass) } == true) + } + if (receiverArg != null && context.thisVariants.firstOrNull() !== receiverArg) { + context.setThisVariants(receiverArg, context.thisVariants) + } + } if (captureRecords != null) { context.captureRecords = captureRecords context.captureNames = captureNames diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/Opcode.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/Opcode.kt index ffa8b0e..c1cdca1 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/Opcode.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/Opcode.kt @@ -144,6 +144,7 @@ enum class Opcode(val code: Int) { MOD_OBJ(0x7B), CONTAINS_OBJ(0x7C), ASSIGN_OP_OBJ(0x7D), + POS_OBJ(0x7E), JMP(0x80), JMP_IF_TRUE(0x81), diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/obj/Obj.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/obj/Obj.kt index d46a535..9e3f6f3 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/obj/Obj.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/obj/Obj.kt @@ -317,6 +317,12 @@ open class Obj { } } + open suspend fun unaryPlus(scope: Scope): Obj { + return invokeInstanceMethod(scope, "unaryPlus", Arguments.EMPTY) { + this + } + } + open suspend fun mul(scope: Scope, other: Obj): Obj { val otherValue = when (other) { is FrameSlotRef -> other.read() diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/obj/ObjRef.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/obj/ObjRef.kt index c1c59a4..755957b 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/obj/ObjRef.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/obj/ObjRef.kt @@ -73,7 +73,7 @@ class ClassOperatorRef(val target: ObjRef, val pos: Pos) : ObjRef { } /** Unary operations supported by ObjRef. */ -enum class UnaryOp { NOT, NEGATE, BITNOT } +enum class UnaryOp { NOT, POSITIVE, NEGATE, BITNOT } /** Binary operations supported by ObjRef. */ enum class BinOp { diff --git a/lynglib/src/commonTest/kotlin/LaunchPoolTest.kt b/lynglib/src/commonTest/kotlin/LaunchPoolTest.kt index 968c13f..3966792 100644 --- a/lynglib/src/commonTest/kotlin/LaunchPoolTest.kt +++ b/lynglib/src/commonTest/kotlin/LaunchPoolTest.kt @@ -15,17 +15,21 @@ * */ -import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.test.runTest +import kotlinx.coroutines.withContext import kotlinx.coroutines.withTimeout import kotlin.test.Test import net.sergeych.lyng.eval as lyngEval class LaunchPoolTest { - private suspend fun eval(code: String) = withTimeout(2_000L) { lyngEval(code) } + private suspend fun eval(code: String) = withContext(Dispatchers.Default) { + withTimeout(2_000L) { lyngEval(code) } + } @Test - fun testBasicExecution() = runBlocking { + fun testBasicExecution() = runTest { eval(""" val pool = LaunchPool(2) val d1 = pool.launch { 1 + 1 } @@ -37,7 +41,7 @@ class LaunchPoolTest { } @Test - fun testResultsCollected() = runBlocking { + fun testResultsCollected() = runTest { eval(""" val pool = LaunchPool(4) val jobs = (1..10).map { n -> pool.launch { n * n } } @@ -48,7 +52,7 @@ class LaunchPoolTest { } @Test - fun testConcurrencyLimit() = runBlocking { + fun testConcurrencyLimit() = runTest { eval(""" // With maxWorkers=2, at most 2 tasks run at the same time. val mu = Mutex() @@ -70,7 +74,7 @@ class LaunchPoolTest { } @Test - fun testExceptionCapturedInDeferred() = runBlocking { + fun testExceptionCapturedInDeferred() = runTest { eval(""" val pool = LaunchPool(2) val good = pool.launch { 42 } @@ -83,7 +87,7 @@ class LaunchPoolTest { } @Test - fun testPoolContinuesAfterLambdaException() = runBlocking { + fun testPoolContinuesAfterLambdaException() = runTest { eval(""" val pool = LaunchPool(1) val bad = pool.launch { throw IllegalArgumentException("fail") } @@ -96,7 +100,7 @@ class LaunchPoolTest { } @Test - fun testLaunchAfterCloseAndJoinThrows() = runBlocking { + fun testLaunchAfterCloseAndJoinThrows() = runTest { eval(""" val pool = LaunchPool(2) pool.launch { 1 } @@ -107,7 +111,7 @@ class LaunchPoolTest { } @Test - fun testLaunchAfterCancelThrows() = runBlocking { + fun testLaunchAfterCancelThrows() = runTest { eval(""" val pool = LaunchPool(2) pool.cancel() @@ -117,7 +121,7 @@ class LaunchPoolTest { } @Test - fun testCancelAndJoinWaitsForWorkers() = runBlocking { + fun testCancelAndJoinWaitsForWorkers() = runTest { eval(""" val pool = LaunchPool(2) pool.launch { delay(5) } @@ -128,7 +132,7 @@ class LaunchPoolTest { } @Test - fun testCloseAndJoinDrainsQueue() = runBlocking { + fun testCloseAndJoinDrainsQueue() = runTest { eval(""" val mu = Mutex() val results = [] @@ -147,7 +151,7 @@ class LaunchPoolTest { } @Test - fun testBoundedQueueSuspendsProducer() = runBlocking { + fun testBoundedQueueSuspendsProducer() = runTest { eval(""" // queue of 2 + 1 worker; producer can only be 1 ahead of what's running val pool = LaunchPool(1, 2) @@ -165,7 +169,7 @@ class LaunchPoolTest { } @Test - fun testUnlimitedQueueDefault() = runBlocking { + fun testUnlimitedQueueDefault() = runTest { eval(""" val pool = LaunchPool(4) val jobs = (1..50).map { n -> pool.launch { n } } @@ -177,7 +181,7 @@ class LaunchPoolTest { } @Test - fun testIdempotentClose() = runBlocking { + fun testIdempotentClose() = runTest { eval(""" val pool = LaunchPool(2) pool.closeAndJoin() diff --git a/lynglib/src/commonTest/kotlin/ScriptImportPreparationTest.kt b/lynglib/src/commonTest/kotlin/ScriptImportPreparationTest.kt index 2444b1a..bb0229d 100644 --- a/lynglib/src/commonTest/kotlin/ScriptImportPreparationTest.kt +++ b/lynglib/src/commonTest/kotlin/ScriptImportPreparationTest.kt @@ -190,4 +190,63 @@ class ScriptImportPreparationTest { session.cancelAndJoin() } } + + @Test + fun importedContextReceiverExtensionIsAvailableInReceiverDsl() = runTest { + val manager = Script.defaultImportManager.copy().apply { + addTextPackages( + """ + package imported.ctxdsl + + class Tag(name: String) { + val name = name + var inner = "" + + fun child(tagName: String, block: Tag.()->void) { + val child = Tag(tagName) + child.apply { block(this) } + inner += child.render() + } + + fun h3(block: Tag.()->void) { child("h3", block) } + fun addText(text: String) { inner += text } + fun render() = "<" + name + ">" + inner + "" + } + + context(Tag) + fun String.unaryPlus() { + this@Tag.addText(this) + } + """.trimIndent() + ) + } + val script = Compiler.compile( + Source( + "", + """ + import imported.ctxdsl + + fun html(block: Tag.()->void) { + val root = Tag("html") + root.apply { block(this) } + root.render() + } + + val page = html { + h3 { + +"Imported" + } + } + + assertEquals("

Imported

", page) + assertEquals("plain", +"plain") + page + """.trimIndent() + ), + manager + ) + + val result = script.execute(manager.newStdScope()) as ObjString + assertEquals("

Imported

", result.value) + } } diff --git a/lynglib/src/commonTest/kotlin/TypeInferenceTest.kt b/lynglib/src/commonTest/kotlin/TypeInferenceTest.kt index 766a78e..8f608ea 100644 --- a/lynglib/src/commonTest/kotlin/TypeInferenceTest.kt +++ b/lynglib/src/commonTest/kotlin/TypeInferenceTest.kt @@ -15,7 +15,7 @@ * */ -import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.test.runTest import net.sergeych.lyng.eval import kotlin.test.Test @@ -30,7 +30,7 @@ class TypeInferenceTest { /** Channel field type inferred from constructor — accessed in a launch closure */ @Test - fun testChannelFieldInLaunchClosure() = runBlocking { + fun testChannelFieldInLaunchClosure() = runTest { eval(""" class Foo { private val ch = Channel(Channel.UNLIMITED) @@ -52,7 +52,7 @@ class TypeInferenceTest { /** Mutex field type inferred from constructor — used directly in a method body */ @Test - fun testMutexFieldDirectUse() = runBlocking { + fun testMutexFieldDirectUse() = runTest { eval(""" class Bar { private val mu = Mutex() @@ -69,7 +69,7 @@ class TypeInferenceTest { /** CompletableDeferred field type inferred — complete/await used directly */ @Test - fun testCompletableDeferredFieldDirectUse() = runBlocking { + fun testCompletableDeferredFieldDirectUse() = runTest { eval(""" class Baz { private val d = CompletableDeferred() @@ -84,7 +84,7 @@ class TypeInferenceTest { /** Channel field accessed inside a map closure within class initializer */ @Test - fun testChannelFieldInMapAndLaunchClosure() = runBlocking { + fun testChannelFieldInMapAndLaunchClosure() = runTest { eval(""" class Pool(n) { private val ch = Channel(Channel.UNLIMITED) @@ -106,7 +106,7 @@ class TypeInferenceTest { } @Test - fun testIterableFirstPreservesElementTypeForBlockReturnInference() = runBlocking { + fun testIterableFirstPreservesElementTypeForBlockReturnInference() = runTest { eval(""" class Item(title: String) @@ -121,7 +121,7 @@ class TypeInferenceTest { } @Test - fun testCallableLocalInitializedFromFunctionCallPreservesReturnType() = runBlocking { + fun testCallableLocalInitializedFromFunctionCallPreservesReturnType() = runTest { eval(""" fun makeAdder(base) { return { x -> x + base + 0.5 } diff --git a/lynglib/src/commonTest/kotlin/net/sergeych/lyng/OperatorOverloadingTest.kt b/lynglib/src/commonTest/kotlin/net/sergeych/lyng/OperatorOverloadingTest.kt index 1e47410..f02a2f8 100644 --- a/lynglib/src/commonTest/kotlin/net/sergeych/lyng/OperatorOverloadingTest.kt +++ b/lynglib/src/commonTest/kotlin/net/sergeych/lyng/OperatorOverloadingTest.kt @@ -53,6 +53,197 @@ class OperatorOverloadingTest { """.trimIndent()) } + @Test + fun testUnaryPlusDefaultIdentity() = runTest { + eval(""" + assertEquals(42, +42) + assertEquals(3.5, +3.5) + assertEquals("abc", +"abc") + + class Box(val text: String) { + fun upper() = text.upper() + } + + assertEquals("ABC", (+Box("abc")).upper()) + """.trimIndent()) + } + + @Test + fun testUnaryPlusOverloading() = runTest { + eval(""" + class Counter(val n: Int) { + fun unaryPlus() = Counter(this.n + 1) + fun equals(other: Counter) = this.n == other.n + } + + assertEquals(Counter(6), Counter(5).unaryPlus()) + assertEquals(Counter(6), +Counter(5)) + """.trimIndent()) + } + + @Test + fun testUnaryPlusExtensionOverloading() = runTest { + eval(""" + var out = "" + fun String.unaryPlus() { + out = out + this + } + + "Hello".unaryPlus() + " ".unaryPlus() + "Lyng".unaryPlus() + assertEquals("Hello Lyng", out) + out = "" + + +"Hello" + +" " + +"Lyng" + assertEquals("Hello Lyng", out) + """.trimIndent()) + } + + @Test + fun testUnaryPlusDslBuilderStyle() = runTest { + eval(""" + class Tag(name: String) { + val name = name + var inner = "" + + fun child(tagName: String, block: Tag.()->void) { + val child = Tag(tagName) + with(child) { block(this) } + inner += child.render() + } + + fun head(block: Tag.()->void) { child("head", block) } + fun body(block: Tag.()->void) { child("body", block) } + fun title(block: Tag.()->void) { child("title", block) } + fun h1(block: Tag.()->void) { child("h1", block) } + + fun addText(text: String) { + inner += text + } + + fun render() { + "<" + name + ">" + inner + "" + } + } + + context(Tag) + fun String.unaryPlus() { + this@Tag.addText(this) + } + + fun html(block: Tag.()->void) { + val root = Tag("html") + with(root) { block(this) } + root.render() + } + + val page = html { + head { + title { + +"Demo" + } + } + body { + h1 { + +"Heading 1" + } + } + } + + assertEquals("Demo

Heading 1

", page) + """.trimIndent()) + } + + @Test + fun testContextReceiverUnaryPlusDslBuilderStyle() = runTest { + eval(""" + class Tag(name: String) { + val name = name + var inner = "" + + fun child(tagName: String, block: Tag.()->void) { + val child = Tag(tagName) + with(child) { block(this) } + inner += child.render() + } + + fun h3(block: Tag.()->void) { child("h3", block) } + + fun addText(text: String) { + inner += text + } + + fun render() { + "<" + name + ">" + inner + "" + } + } + + context(Tag) + fun String.unaryPlus() { + this@Tag.addText(this) + } + + fun html(block: Tag.()->void) { + val root = Tag("html") + with(root) { block(this) } + root.render() + } + + val page = html { + h3 { + +"Heading 3" + } + } + + assertEquals("

Heading 3

", page) + assertEquals("plain", +"plain") + """.trimIndent()) + } + + @Test + fun testContextReceiverExtensionIsHiddenOutsideContext() = runTest { + val ex = assertFailsWith { + eval(""" + class Tag { + fun wrap(text: String) = "[" + text + "]" + } + + context(Tag) + fun String.mark() = this@Tag.wrap(this) + + "x".mark() + """.trimIndent()) + } + assertContains(ex.message ?: "", "no such member: mark on String") + } + + @Test + fun testContextReceiverExtensionIsHiddenInWrongContext() = runTest { + val ex = assertFailsWith { + eval(""" + class Tag { + fun wrap(text: String) = "[" + text + "]" + } + class Other + + context(Tag) + fun String.mark() = this@Tag.wrap(this) + + fun other(block: Other.()->void) { + with(Other()) { block(this) } + } + + other { + "x".mark() + } + """.trimIndent()) + } + assertContains(ex.message ?: "", "no such member: mark on String") + } + @Test fun testPlusAssignOverloading() = runTest { eval("""