Add context receiver extensions for DSLs

This commit is contained in:
Sergey Chernov 2026-04-29 20:49:50 +03:00
parent e107296bca
commit b2200e71ff
22 changed files with 860 additions and 71 deletions

View File

@ -479,6 +479,18 @@ val block: context(Html, Head) Body.()->String = {
} }
``` ```
Context receivers can also constrain extension functions. The extension is visible only when the required receiver is
already in the implicit receiver stack:
```lyng
class Tag { fun addText(text: String) { /* ... */ } }
context(Tag)
fun String.unaryPlus() {
this@Tag.addText(this)
}
```
- Field inheritance (`val`/`var`) and collisions - Field inheritance (`val`/`var`) and collisions
- Instance storage is kept per declaring class, internally disambiguated; unqualified read/write resolves to the first match in the resolution order (leftmost base). - Instance storage is kept per declaring class, internally disambiguated; unqualified read/write resolves to the first match in the resolution order (leftmost base).
- Qualified read/write (via `this@Type` or casts) targets the chosen ancestor’s storage. - Qualified read/write (via `this@Type` or casts) targets the chosen ancestor’s storage.
@ -650,10 +662,14 @@ Unary operators are overloaded by defining methods with no arguments:
| Operator | Method Name | | Operator | Method Name |
| :--- | :--- | | :--- | :--- |
| `+a` | `unaryPlus()` |
| `-a` | `negate()` | | `-a` | `negate()` |
| `!a` | `logicalNot()` | | `!a` | `logicalNot()` |
| `~a` | `bitNot()` | | `~a` | `bitNot()` |
`unaryPlus()` is useful in DSL-style builders where `+"text"` should append text to
the current receiver. See [samples/html_builder_dsl.lyng](samples/html_builder_dsl.lyng).
### Assignment Operators ### Assignment Operators
Assignment operators like `+=` first attempt to call a specific assignment method. If that method is not defined, they fall back to a combination of the binary operator and a regular assignment (e.g., `a = a + b`). Assignment operators like `+=` first attempt to call a specific assignment method. If that method is not defined, they fall back to a combination of the binary operator and a regular assignment (e.g., `a = a + b`).

View File

@ -83,6 +83,7 @@ Primary sources used: `lynglib/src/commonMain/kotlin/net/sergeych/lyng/{Parser,T
## 4. Operators (implemented) ## 4. Operators (implemented)
- Assignment: `=`, `+=`, `-=`, `*=`, `/=`, `%=`, `?=`. - Assignment: `=`, `+=`, `-=`, `*=`, `/=`, `%=`, `?=`.
- Logical: `||`, `&&`, unary `!`. - Logical: `||`, `&&`, unary `!`.
- Unary arithmetic/bitwise: unary `+`, unary `-`, `~`.
- Bitwise: `|`, `^`, `&`, `~`, shifts `<<`, `>>`. - Bitwise: `|`, `^`, `&`, `~`, shifts `<<`, `>>`.
- Equality/comparison: `==`, `!=`, `===`, `!==`, `<`, `<=`, `>`, `>=`, `<=>`, `=~`, `!~`. - Equality/comparison: `==`, `!=`, `===`, `!==`, `<`, `<=`, `>`, `>=`, `<=>`, `=~`, `!~`.
- Type/containment: `is`, `!is`, `in`, `!in`, `as`, `as?`. - Type/containment: `is`, `!is`, `in`, `!in`, `as`, `as?`.
@ -119,6 +120,7 @@ Primary sources used: `lynglib/src/commonMain/kotlin/net/sergeych/lyng/{Parser,T
- shorthand: `fun f(x) = expr`. - shorthand: `fun f(x) = expr`.
- generics: `fun f<T>(x: T): T`. - generics: `fun f<T>(x: T): T`.
- extension functions: `fun Type.name(...) { ... }`. - extension functions: `fun Type.name(...) { ... }`.
- context-aware extension functions: `context(Tag) fun String.unaryPlus() { this@Tag.addText(this) }`.
- named singleton `object` declarations can be extension receivers too: `fun Config.describe(...) { ... }`, `val Config.tag get() = ...`. - named singleton `object` declarations can be extension receivers too: `fun Config.describe(...) { ... }`, `val Config.tag get() = ...`.
- static extension functions are callable on the type object: `static fun List<T>.fill(...)` -> `List.fill(...)`. - static extension functions are callable on the type object: `static fun List<T>.fill(...)` -> `List.fill(...)`.
- delegated callable: `fun f(...) by delegate`. - delegated callable: `fun f(...) by delegate`.

View File

@ -0,0 +1,50 @@
class Tag(name: String) {
val name = name
var inner = ""
fun child(tagName: String, block: Tag.()->void) {
val child = Tag(tagName)
with(child) { block(this) }
inner += child.render()
}
fun head(block: Tag.()->void) { child("head", block) }
fun body(block: Tag.()->void) { child("body", block) }
fun title(block: Tag.()->void) { child("title", block) }
fun h1(block: Tag.()->void) { child("h1", block) }
fun addText(text: String) {
inner += text
}
fun render() {
"<" + name + ">" + inner + "</" + name + ">"
}
}
context(Tag)
fun String.unaryPlus() {
this@Tag.addText(this)
}
fun html(block: Tag.()->void) {
val root = Tag("html")
with(root) { block(this) }
root.render()
}
val page = html {
head {
title {
+"Demo"
}
}
body {
h1 {
+"Heading 1"
}
}
}
println(page)
assertEquals("<html><head><title>Demo</title></head><body><h1>Heading 1</h1></body></html>", page)

View File

@ -1,6 +1,9 @@
// Sample: Operator Overloading in Lyng // Sample: Operator Overloading in Lyng
class Vector<T>(val x: T, val y: T) { class Vector<T>(val x: T, val y: T) {
// Overload unary +
fun unaryPlus() = this
// Overload + // Overload +
fun plus(other: Vector<U>) = Vector(x + other.x, y + other.y) fun plus(other: Vector<U>) = Vector(x + other.x, y + other.y)
@ -28,6 +31,11 @@ val v2 = Vector(5, 5)
println("v1: " + v1) println("v1: " + v1)
println("v2: " + v2) println("v2: " + v2)
// Test unary +
val v0 = +v1
println("+v1 = " + v0)
assertEquals(Vector(10, 20), v0)
// Test binary + // Test binary +
val v3 = v1 + v2 val v3 = v1 + v2
println("v1 + v2 = " + v3) println("v1 + v2 = " + v3)

View File

@ -45,6 +45,7 @@ import net.sergeych.lyng.io.db.createDbModule
import net.sergeych.lyng.io.db.jdbc.createJdbcModule import net.sergeych.lyng.io.db.jdbc.createJdbcModule
import net.sergeych.lyng.io.db.sqlite.createSqliteModule import net.sergeych.lyng.io.db.sqlite.createSqliteModule
import net.sergeych.lyng.io.fs.createFs import net.sergeych.lyng.io.fs.createFs
import net.sergeych.lyng.io.html.createHtmlModule
import net.sergeych.lyng.io.http.createHttpModule import net.sergeych.lyng.io.http.createHttpModule
import net.sergeych.lyng.io.http.server.createHttpServerModule import net.sergeych.lyng.io.http.server.createHttpServerModule
import net.sergeych.lyng.io.net.createNetModule import net.sergeych.lyng.io.net.createNetModule
@ -146,6 +147,7 @@ private fun ImportManager.invalidateCliModuleCaches() {
invalidatePackageCache("lyng.io.console") invalidatePackageCache("lyng.io.console")
invalidatePackageCache("lyng.io.db.jdbc") invalidatePackageCache("lyng.io.db.jdbc")
invalidatePackageCache("lyng.io.db.sqlite") invalidatePackageCache("lyng.io.db.sqlite")
invalidatePackageCache("lyng.io.html")
invalidatePackageCache("lyng.io.http") invalidatePackageCache("lyng.io.http")
invalidatePackageCache("lyng.io.http.server") invalidatePackageCache("lyng.io.http.server")
invalidatePackageCache("lyng.io.ws") invalidatePackageCache("lyng.io.ws")
@ -237,6 +239,7 @@ private fun installCliModules(manager: ImportManager) {
createDbModule(manager) createDbModule(manager)
createJdbcModule(manager) createJdbcModule(manager)
createSqliteModule(manager) createSqliteModule(manager)
createHtmlModule(manager)
createHttpModule(PermitAllHttpAccessPolicy, manager) createHttpModule(PermitAllHttpAccessPolicy, manager)
createHttpServerModule(PermitAllNetAccessPolicy, manager) createHttpServerModule(PermitAllNetAccessPolicy, manager)
createWsModule(PermitAllWsAccessPolicy, manager) createWsModule(PermitAllWsAccessPolicy, manager)

View File

@ -0,0 +1,44 @@
/*
* Copyright 2026 Sergey S. Chernov real.sergeych@gmail.com
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package net.sergeych.lyng.io.html
import net.sergeych.lyng.ModuleScope
import net.sergeych.lyng.Scope
import net.sergeych.lyng.Source
import net.sergeych.lyng.pacman.ImportManager
import net.sergeych.lyngio.stdlib_included.htmlLyng
private const val HTML_MODULE_NAME = "lyng.io.html"
fun createHtmlModule(scope: Scope): Boolean = createHtmlModule(scope.importManager)
fun createHtml(scope: Scope): Boolean = createHtmlModule(scope)
fun createHtmlModule(manager: ImportManager): Boolean {
if (manager.packageNames.contains(HTML_MODULE_NAME)) return false
manager.addPackage(HTML_MODULE_NAME) { module ->
buildHtmlModule(module)
}
return true
}
fun createHtml(manager: ImportManager): Boolean = createHtmlModule(manager)
private suspend fun buildHtmlModule(module: ModuleScope) {
module.eval(Source(HTML_MODULE_NAME, htmlLyng))
}

View File

@ -0,0 +1,57 @@
/*
* Copyright 2026 Sergey S. Chernov real.sergeych@gmail.com
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package net.sergeych.lyng.io.html
import kotlinx.coroutines.test.runTest
import net.sergeych.lyng.Compiler
import net.sergeych.lyng.Script
import net.sergeych.lyng.Source
import net.sergeych.lyng.pacman.ImportManager
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertFalse
import kotlin.test.assertTrue
class LyngHtmlModuleTest {
@Test
fun testModuleRegistrationIsIdempotent() = runTest {
val importManager = ImportManager()
assertTrue(createHtmlModule(importManager))
assertFalse(createHtmlModule(importManager))
}
@Test
fun testModuleCanBeImported() = runTest {
val scope = Script.newScope()
createHtmlModule(scope.importManager)
val result = Compiler.compile(
Source(
"<html-test>",
"""
import lyng.io.html
42
""".trimIndent()
),
scope.importManager
).execute(scope)
assertEquals("42", result.inspect(scope))
}
}

View File

@ -0,0 +1,22 @@
package lyng.io.html
/*
HTML helpers package.
API surface is intentionally empty for now; this package exists so Lyng code
can import `lyng.io.html` and grow declarations here incrementally.
*/
class HtmlBuilder {
private val head: List<String> = []
private val body: List<String> = []
fun build(): String =
"<HTML>" +
(head.isEmpty() ? "" : head.joinToString("\n")) +
(body.isEmpty() ? "" : body.joinToString("\n")) +
"</HTML>"
}
fun buildHtml(f: HtmlBuilder.()->void): String {
HtmlBuilder().apply(f).build()

View File

@ -885,6 +885,19 @@ class Compiler(
) )
} }
private fun promotePreferredReceiverArg(context: Scope, preferredTypeName: String?) {
if (preferredTypeName == null) return
val receiverArg = context.args.list.firstOrNull { arg ->
arg.isInstanceOf(preferredTypeName) ||
((context[preferredTypeName]?.value as? ObjClass)?.let { typeClass ->
arg.isInstanceOf(typeClass)
} == true)
} ?: return
if (context.thisVariants.firstOrNull() !== receiverArg) {
context.setThisVariants(receiverArg, context.thisVariants)
}
}
private fun currentImplicitReceiverTypeNames(): List<String> { private fun currentImplicitReceiverTypeNames(): List<String> {
val result = mutableListOf<String>() val result = mutableListOf<String>()
for (ctx in codeContexts.asReversed()) { for (ctx in codeContexts.asReversed()) {
@ -2010,6 +2023,7 @@ class Compiler(
callableReturnTypeByScopeId = callableReturnTypeByScopeId, callableReturnTypeByScopeId = callableReturnTypeByScopeId,
callableReturnTypeByName = callableReturnTypeByName, callableReturnTypeByName = callableReturnTypeByName,
callSignatureByName = callSignatureByName, callSignatureByName = callSignatureByName,
extensionContextReceiversByWrapperName = extensionContextReceiversByWrapperName,
externBindingNames = externBindingNames, externBindingNames = externBindingNames,
preparedModuleBindingNames = importBindings.keys, preparedModuleBindingNames = importBindings.keys,
scopeRefPosByName = moduleReferencePosByName, scopeRefPosByName = moduleReferencePosByName,
@ -2078,19 +2092,60 @@ class Compiler(
private val rangeParamNamesStack = mutableListOf<Set<String>>() private val rangeParamNamesStack = mutableListOf<Set<String>>()
private val extensionNames = mutableSetOf<String>() private val extensionNames = mutableSetOf<String>()
private val extensionNamesByType = mutableMapOf<String, MutableSet<String>>() private val extensionNamesByType = mutableMapOf<String, MutableSet<String>>()
private val extensionContextReceiversByWrapperName = mutableMapOf<String, List<String>>()
private val useScopeSlots: Boolean = seedScope == null private val useScopeSlots: Boolean = seedScope == null
private fun registerExtensionName(typeName: String, memberName: String) { private fun registerExtensionName(typeName: String, memberName: String) {
extensionNamesByType.getOrPut(typeName) { mutableSetOf() }.add(memberName) extensionNamesByType.getOrPut(typeName) { mutableSetOf() }.add(memberName)
} }
private fun contextReceiverTypeName(typeDecl: TypeDecl): String? = when (typeDecl) {
is TypeDecl.Simple -> typeDecl.name.substringAfterLast('.')
is TypeDecl.Generic -> typeDecl.name.substringAfterLast('.')
else -> null
}
private fun contextReceiversSatisfied(required: List<String>, visibleReceivers: List<String> = currentImplicitReceiverTypeNames()): Boolean {
if (required.isEmpty()) return true
return required.all { req ->
visibleReceivers.any { visible ->
visible == req || resolveClassByName(visible)?.let { cls ->
cls.className == req || cls.mro.any { it.className == req }
} == true
}
}
}
private fun rememberExtensionContextReceivers(wrapperName: String, record: ObjRecord) {
val fnType = record.typeDecl as? TypeDecl.Function ?: return
if (fnType.contextReceivers.isEmpty()) return
val names = fnType.contextReceivers.mapNotNull(::contextReceiverTypeName)
if (names.size == fnType.contextReceivers.size) {
extensionContextReceiversByWrapperName[wrapperName] = names
}
}
private fun hasExtensionFor(typeName: String, memberName: String): Boolean { private fun hasExtensionFor(typeName: String, memberName: String): Boolean {
if (extensionNamesByType[typeName]?.contains(memberName) == true) return true if (extensionNamesByType[typeName]?.contains(memberName) == true) {
val wrapperName = extensionCallableName(typeName, memberName)
if (contextReceiversSatisfied(extensionContextReceiversByWrapperName[wrapperName].orEmpty())) return true
val getterName = extensionPropertyGetterName(typeName, memberName)
if (contextReceiversSatisfied(extensionContextReceiversByWrapperName[getterName].orEmpty())) return true
val setterName = extensionPropertySetterName(typeName, memberName)
if (contextReceiversSatisfied(extensionContextReceiversByWrapperName[setterName].orEmpty())) return true
}
val scopeRec = seedScope?.get(typeName) ?: importManager.rootScope.get(typeName) val scopeRec = seedScope?.get(typeName) ?: importManager.rootScope.get(typeName)
val cls = (scopeRec?.value as? ObjClass) ?: resolveTypeDeclObjClass(TypeDecl.Simple(typeName, false)) val cls = (scopeRec?.value as? ObjClass) ?: resolveTypeDeclObjClass(TypeDecl.Simple(typeName, false))
if (cls != null) { if (cls != null) {
for (base in cls.mro) { for (base in cls.mro) {
if (extensionNamesByType[base.className]?.contains(memberName) == true) return true if (extensionNamesByType[base.className]?.contains(memberName) == true) {
val wrapperName = extensionCallableName(base.className, memberName)
if (contextReceiversSatisfied(extensionContextReceiversByWrapperName[wrapperName].orEmpty())) return true
val getterName = extensionPropertyGetterName(base.className, memberName)
if (contextReceiversSatisfied(extensionContextReceiversByWrapperName[getterName].orEmpty())) return true
val setterName = extensionPropertySetterName(base.className, memberName)
if (contextReceiversSatisfied(extensionContextReceiversByWrapperName[setterName].orEmpty())) return true
}
} }
} }
val candidates = mutableListOf(typeName) val candidates = mutableListOf(typeName)
@ -2102,7 +2157,10 @@ class Compiler(
extensionPropertySetterName(baseName, memberName) extensionPropertySetterName(baseName, memberName)
) )
for (wrapperName in wrapperNames) { for (wrapperName in wrapperNames) {
if (!contextReceiversSatisfied(extensionContextReceiversByWrapperName[wrapperName].orEmpty())) continue
val resolved = resolveImportBinding(wrapperName, Pos.builtIn) ?: continue val resolved = resolveImportBinding(wrapperName, Pos.builtIn) ?: continue
rememberExtensionContextReceivers(wrapperName, resolved.record)
if (!contextReceiversSatisfied(extensionContextReceiversByWrapperName[wrapperName].orEmpty())) continue
val plan = moduleSlotPlan() val plan = moduleSlotPlan()
if (plan != null && !plan.slots.containsKey(wrapperName)) { if (plan != null && !plan.slots.containsKey(wrapperName)) {
declareSlotNameIn( declareSlotNameIn(
@ -2385,6 +2443,7 @@ class Compiler(
callableReturnTypeByScopeId = callableReturnTypeByScopeId, callableReturnTypeByScopeId = callableReturnTypeByScopeId,
callableReturnTypeByName = callableReturnTypeByName, callableReturnTypeByName = callableReturnTypeByName,
callSignatureByName = callSignatureByName, callSignatureByName = callSignatureByName,
extensionContextReceiversByWrapperName = extensionContextReceiversByWrapperName,
externCallableNames = externCallableNames, externCallableNames = externCallableNames,
externBindingNames = externBindingNames, externBindingNames = externBindingNames,
preparedModuleBindingNames = importBindings.keys, preparedModuleBindingNames = importBindings.keys,
@ -2420,6 +2479,7 @@ class Compiler(
callableReturnTypeByScopeId = callableReturnTypeByScopeId, callableReturnTypeByScopeId = callableReturnTypeByScopeId,
callableReturnTypeByName = callableReturnTypeByName, callableReturnTypeByName = callableReturnTypeByName,
callSignatureByName = callSignatureByName, callSignatureByName = callSignatureByName,
extensionContextReceiversByWrapperName = extensionContextReceiversByWrapperName,
externCallableNames = externCallableNames, externCallableNames = externCallableNames,
externBindingNames = externBindingNames, externBindingNames = externBindingNames,
preparedModuleBindingNames = importBindings.keys, preparedModuleBindingNames = importBindings.keys,
@ -2480,6 +2540,7 @@ class Compiler(
callableReturnTypeByScopeId = callableReturnTypeByScopeId, callableReturnTypeByScopeId = callableReturnTypeByScopeId,
callableReturnTypeByName = callableReturnTypeByName, callableReturnTypeByName = callableReturnTypeByName,
callSignatureByName = callSignatureByName, callSignatureByName = callSignatureByName,
extensionContextReceiversByWrapperName = extensionContextReceiversByWrapperName,
externCallableNames = externCallableNames, externCallableNames = externCallableNames,
externBindingNames = externBindingNames, externBindingNames = externBindingNames,
preparedModuleBindingNames = importBindings.keys, preparedModuleBindingNames = importBindings.keys,
@ -3721,7 +3782,7 @@ class Compiler(
val inlineBodyRef = argsDeclaration?.let { null } ?: extractInlineLambdaBodyRef(body) val inlineBodyRef = argsDeclaration?.let { null } ?: extractInlineLambdaBodyRef(body)
val supportsDirectInvokeFastPath = bytecodeFn != null && val supportsDirectInvokeFastPath = bytecodeFn != null &&
bytecodeFn.scopeSlotCount == 0 && bytecodeFn.scopeSlotCount == 0 &&
expectedReceiverType == null && effectiveExpectedReceiverType == null &&
!wrapAsExtensionCallable && !wrapAsExtensionCallable &&
!containsDelegatedRefs(body) !containsDelegatedRefs(body)
val ref = LambdaFnRef( val ref = LambdaFnRef(
@ -3733,9 +3794,10 @@ class Compiler(
override fun bytecodeBody(): BytecodeStatement? = fnStatements as? BytecodeStatement override fun bytecodeBody(): BytecodeStatement? = fnStatements as? BytecodeStatement
override fun callOnFast(scope: Scope): Obj? { override fun callOnFast(scope: Scope): Obj? {
val context = scope.applyClosureForBytecode(closureScope, preferredThisType = expectedReceiverType).also { val context = scope.applyClosureForBytecode(closureScope, preferredThisType = effectiveExpectedReceiverType).also {
it.args = scope.args it.args = scope.args
} }
promotePreferredReceiverArg(context, effectiveExpectedReceiverType)
if (captureSlots.isNotEmpty()) { if (captureSlots.isNotEmpty()) {
if (captureRecords != null) { if (captureRecords != null) {
context.captureRecords = captureRecords context.captureRecords = captureRecords
@ -3804,9 +3866,10 @@ class Compiler(
} }
override suspend fun execute(scope: Scope): Obj { override suspend fun execute(scope: Scope): Obj {
val context = scope.applyClosureForBytecode(closureScope, preferredThisType = expectedReceiverType).also { val context = scope.applyClosureForBytecode(closureScope, preferredThisType = effectiveExpectedReceiverType).also {
it.args = scope.args it.args = scope.args
} }
promotePreferredReceiverArg(context, effectiveExpectedReceiverType)
if (captureSlots.isNotEmpty()) { if (captureSlots.isNotEmpty()) {
if (captureRecords != null) { if (captureRecords != null) {
context.captureRecords = captureRecords context.captureRecords = captureRecords
@ -5321,6 +5384,7 @@ class Compiler(
is RangeRef -> ObjRange.type is RangeRef -> ObjRange.type
is ClassOperatorRef -> ObjClassType is ClassOperatorRef -> ObjClassType
is CastRef -> resolveTypeRefClass(ref.castTypeRef()) is CastRef -> resolveTypeRefClass(ref.castTypeRef())
is UnaryOpRef -> inferUnaryOpReturnClass(ref)
is IndexRef -> { is IndexRef -> {
val targetClass = resolveReceiverClassForMember(ref.targetRef) val targetClass = resolveReceiverClassForMember(ref.targetRef)
classMethodReturnClass(targetClass, "getAt") classMethodReturnClass(targetClass, "getAt")
@ -5439,6 +5503,7 @@ class Compiler(
?: resolveClassByName(ref.receiverTypeName())?.let { classMethodReturnTypeDecl(it, ref.methodName()) } ?: resolveClassByName(ref.receiverTypeName())?.let { classMethodReturnTypeDecl(it, ref.methodName()) }
} }
is CallRef -> callReturnTypeDeclByRef[ref] ?: inferCallReturnTypeDecl(ref) is CallRef -> callReturnTypeDeclByRef[ref] ?: inferCallReturnTypeDecl(ref)
is UnaryOpRef -> inferUnaryOpReturnTypeDecl(ref)
is BinaryOpRef -> inferBinaryOpReturnTypeDecl(ref) is BinaryOpRef -> inferBinaryOpReturnTypeDecl(ref)
is ElvisRef -> inferElvisTypeDecl(ref) is ElvisRef -> inferElvisTypeDecl(ref)
is StatementRef -> (ref.statement as? ExpressionStatement)?.let { resolveReceiverTypeDecl(it.ref) } is StatementRef -> (ref.statement as? ExpressionStatement)?.let { resolveReceiverTypeDecl(it.ref) }
@ -5513,6 +5578,7 @@ class Compiler(
is QualifiedThisMethodSlotCallRef -> is QualifiedThisMethodSlotCallRef ->
inferMethodCallReturnClass(resolveClassByName(ref.receiverTypeName()), ref.methodName()) inferMethodCallReturnClass(resolveClassByName(ref.receiverTypeName()), ref.methodName())
is CallRef -> inferCallReturnTypeDecl(ref)?.let { resolveTypeDeclObjClass(it) } ?: inferCallReturnClass(ref) is CallRef -> inferCallReturnTypeDecl(ref)?.let { resolveTypeDeclObjClass(it) } ?: inferCallReturnClass(ref)
is UnaryOpRef -> inferUnaryOpReturnClass(ref)
is BinaryOpRef -> inferBinaryOpReturnClass(ref) is BinaryOpRef -> inferBinaryOpReturnClass(ref)
is FieldRef -> { is FieldRef -> {
val targetClass = resolveReceiverClassForMember(ref.target) val targetClass = resolveReceiverClassForMember(ref.target)
@ -5539,6 +5605,13 @@ class Compiler(
else -> null else -> null
} }
private fun unaryOpMethodName(op: UnaryOp): String? = when (op) {
UnaryOp.POSITIVE -> "unaryPlus"
UnaryOp.NEGATE -> "negate"
UnaryOp.BITNOT -> "bitNot"
UnaryOp.NOT -> null
}
private fun interopOperatorFor(op: BinOp): InteropOperator? = when (op) { private fun interopOperatorFor(op: BinOp): InteropOperator? = when (op) {
BinOp.PLUS -> InteropOperator.Plus BinOp.PLUS -> InteropOperator.Plus
BinOp.MINUS -> InteropOperator.Minus BinOp.MINUS -> InteropOperator.Minus
@ -5614,6 +5687,67 @@ class Compiler(
} }
} }
private fun inferExtensionMethodReturnTypeDecl(
receiverDecl: TypeDecl?,
receiverClass: ObjClass?,
memberName: String
): TypeDecl? {
if (receiverClass == null) return null
for (cls in receiverClass.mro) {
val wrapperName = extensionCallableName(cls.className, memberName)
val resolved = resolveImportBinding(wrapperName, Pos.builtIn) ?: continue
registerImportBinding(wrapperName, resolved.binding, Pos.builtIn)
val wrapperType = resolved.record.typeDecl as? TypeDecl.Function ?: continue
val bindings = mutableMapOf<String, TypeDecl>()
val receiverParam = wrapperType.params.firstOrNull() ?: wrapperType.receiver
if (receiverParam != null && receiverDecl != null) {
collectTypeVarBindings(receiverParam, receiverDecl, bindings)
}
return if (bindings.isEmpty()) wrapperType.returnType
else substituteTypeAliasTypeVars(wrapperType.returnType, bindings)
}
return null
}
private fun inferUnaryOpReturnTypeDecl(ref: UnaryOpRef): TypeDecl? {
val operandDecl = resolveReceiverTypeDecl(ref.a)
val operandClass = resolveReceiverClassForMember(ref.a) ?: inferObjClassFromRef(ref.a)
return when (ref.op) {
UnaryOp.NOT -> typeDeclOfClass(ObjBool.type)
UnaryOp.POSITIVE -> {
unaryOpMethodName(ref.op)?.let { methodName ->
classMethodReturnTypeDecl(operandClass, methodName)?.let { return it }
inferExtensionMethodReturnTypeDecl(operandDecl, operandClass, methodName)?.let { return it }
}
operandDecl ?: operandClass?.let(::typeDeclOfClass)
}
UnaryOp.NEGATE -> when (operandClass) {
ObjInt.type -> typeDeclOfClass(ObjInt.type)
ObjReal.type -> typeDeclOfClass(ObjReal.type)
else -> unaryOpMethodName(ref.op)?.let { methodName ->
classMethodReturnTypeDecl(operandClass, methodName)
?: inferExtensionMethodReturnTypeDecl(operandDecl, operandClass, methodName)
}
}
UnaryOp.BITNOT -> when (operandClass) {
ObjInt.type -> typeDeclOfClass(ObjInt.type)
else -> unaryOpMethodName(ref.op)?.let { methodName ->
classMethodReturnTypeDecl(operandClass, methodName)
?: inferExtensionMethodReturnTypeDecl(operandDecl, operandClass, methodName)
}
}
}
}
private fun inferUnaryOpReturnClass(ref: UnaryOpRef): ObjClass? {
inferUnaryOpReturnTypeDecl(ref)?.let { declared ->
resolveTypeDeclObjClass(declared)?.let { return it }
if (declared is TypeDecl.TypeVar) return Obj.rootObjectType
}
val operandClass = resolveReceiverClassForMember(ref.a) ?: inferObjClassFromRef(ref.a)
return if (ref.op == UnaryOp.POSITIVE) operandClass else null
}
private fun inferBinaryOpReturnClass(ref: BinaryOpRef): ObjClass? { private fun inferBinaryOpReturnClass(ref: BinaryOpRef): ObjClass? {
inferBinaryOpReturnTypeDecl(ref)?.let { declared -> inferBinaryOpReturnTypeDecl(ref)?.let { declared ->
resolveTypeDeclObjClass(declared)?.let { return it } resolveTypeDeclObjClass(declared)?.let { return it }
@ -7439,6 +7573,7 @@ class Compiler(
is FastLocalVarRef -> nameObjClass[ref.name]?.className is FastLocalVarRef -> nameObjClass[ref.name]?.className
?: nameTypeDecl[ref.name]?.let { typeDeclName(it) } ?: nameTypeDecl[ref.name]?.let { typeDeclName(it) }
is QualifiedThisRef -> ref.typeName is QualifiedThisRef -> ref.typeName
is UnaryOpRef -> inferUnaryOpReturnClass(ref)?.className
else -> resolveReceiverClassForMember(ref)?.className else -> resolveReceiverClassForMember(ref)?.className
} }
} }
@ -7458,8 +7593,12 @@ class Compiler(
Token.Type.CHAR -> ConstRef(ObjChar(t.value[0]).asReadonly) Token.Type.CHAR -> ConstRef(ObjChar(t.value[0]).asReadonly)
Token.Type.PLUS -> { Token.Type.PLUS -> {
val n = parseNumber(true) parseNumberOrNull(true)?.let { n ->
ConstRef(n.asReadonly) ConstRef(n.asReadonly)
} ?: run {
val n = parseTerm() ?: throw ScriptError(t.pos, "Expecting expression after unary plus")
UnaryOpRef(UnaryOp.POSITIVE, n)
}
} }
Token.Type.MINUS -> { Token.Type.MINUS -> {
@ -7655,6 +7794,43 @@ class Compiler(
} }
} }
private fun parseContextReceiverDeclarationList(start: Pos): List<TypeDecl> {
if (!cc.skipTokenOfType(Token.Type.LPAREN, isOptional = true)) {
throw ScriptError(start, "expected '(' after context")
}
val receivers = mutableListOf<TypeDecl>()
cc.skipWsTokens()
if (cc.peekNextNonWhitespace().type == Token.Type.RPAREN) {
cc.nextNonWhitespace()
return receivers
}
while (true) {
val (decl, _) = parseTypeExpressionWithMini()
receivers += decl
val sep = cc.nextNonWhitespace()
when (sep.type) {
Token.Type.COMMA -> continue
Token.Type.RPAREN -> return receivers
else -> sep.raiseSyntax("expected ',' or ')' in context receiver list")
}
}
}
private suspend fun parseContextFunctionDeclaration(contextToken: Token): Statement {
val contextReceivers = parseContextReceiverDeclarationList(contextToken.pos)
val fn = cc.nextNonWhitespace()
if (fn.type != Token.Type.ID || (fn.value != "fun" && fn.value != "fn")) {
throw ScriptError(fn.pos, "context receivers are currently supported only on function declarations")
}
pendingDeclStart = contextToken.pos
pendingDeclDoc = consumePendingDoc()
return parseFunctionDeclaration(
isExtern = false,
isStatic = false,
contextReceiverTypeDecls = contextReceivers
)
}
/** /**
* Parse keyword-starting statement. * Parse keyword-starting statement.
* @return parsed statement or null if, for example. [id] is not among keywords * @return parsed statement or null if, for example. [id] is not among keywords
@ -7693,6 +7869,7 @@ class Compiler(
pendingDeclDoc = consumePendingDoc() pendingDeclDoc = consumePendingDoc()
parseFunctionDeclaration(isExtern = false, isStatic = false) parseFunctionDeclaration(isExtern = false, isStatic = false)
} }
"context" -> parseContextFunctionDeclaration(id)
// Visibility modifiers for declarations: private/protected val/var/fun/fn // Visibility modifiers for declarations: private/protected val/var/fun/fn
"while" -> parseWhileStatement() "while" -> parseWhileStatement()
"do" -> parseDoWhileStatement() "do" -> parseDoWhileStatement()
@ -9648,7 +9825,8 @@ class Compiler(
isOverride: Boolean = false, isOverride: Boolean = false,
isExtern: Boolean = false, isExtern: Boolean = false,
isStatic: Boolean = false, isStatic: Boolean = false,
isTransient: Boolean = isTransientFlag isTransient: Boolean = isTransientFlag,
contextReceiverTypeDecls: List<TypeDecl> = emptyList()
): Statement { ): Statement {
isTransientFlag = false isTransientFlag = false
val declarationAnnotationSpecs = pendingDeclAnnotations.toList() val declarationAnnotationSpecs = pendingDeclAnnotations.toList()
@ -9688,7 +9866,17 @@ class Compiler(
) )
} }
registerExtensionName(extTypeName, name) registerExtensionName(extTypeName, name)
if (contextReceiverTypeDecls.isNotEmpty()) {
val contextNames = contextReceiverTypeDecls.mapNotNull(::contextReceiverTypeName)
if (contextNames.size != contextReceiverTypeDecls.size) {
throw ScriptError(start, "context receiver types for extension functions must be class-like")
}
extensionContextReceiversByWrapperName[extensionCallableName(extTypeName, name)] = contextNames
}
} else { } else {
if (contextReceiverTypeDecls.isNotEmpty()) {
throw ScriptError(start, "context receivers are currently supported only on extension functions")
}
val t = cc.next() val t = cc.next()
if (t.type != Token.Type.ID) if (t.type != Token.Type.ID)
throw ScriptError(t.pos, "Expected identifier after 'fun'") throw ScriptError(t.pos, "Expected identifier after 'fun'")
@ -9774,6 +9962,7 @@ class Compiler(
if (parentContext is CodeContext.ClassBody && !isStatic && extTypeName == null) { if (parentContext is CodeContext.ClassBody && !isStatic && extTypeName == null) {
classMemberTypeDeclByName.getOrPut(parentContext.name) { mutableMapOf() }[name] = TypeDecl.Function( classMemberTypeDeclByName.getOrPut(parentContext.name) { mutableMapOf() }[name] = TypeDecl.Function(
receiver = receiverTypeDecl, receiver = receiverTypeDecl,
contextReceivers = contextReceiverTypeDecls,
params = argsDeclaration.params.map { it.type }, params = argsDeclaration.params.map { it.type },
returnType = returnTypeDecl ?: TypeDecl.TypeAny, returnType = returnTypeDecl ?: TypeDecl.TypeAny,
nullable = false nullable = false
@ -9851,7 +10040,7 @@ class Compiler(
CodeContext.Function( CodeContext.Function(
name, name,
implicitThisMembers = implicitThisMembers, implicitThisMembers = implicitThisMembers,
implicitReceiverTypeNames = listOfNotNull(implicitThisTypeName), implicitReceiverTypeNames = listOfNotNull(implicitThisTypeName) + contextReceiverTypeDecls.mapNotNull(::contextReceiverTypeName),
typeParams = typeParams, typeParams = typeParams,
typeParamDecls = typeParamDecls, typeParamDecls = typeParamDecls,
noImplicitThis = noImplicitThis noImplicitThis = noImplicitThis
@ -9949,6 +10138,7 @@ class Compiler(
run { run {
val memberTypeDecl = TypeDecl.Function( val memberTypeDecl = TypeDecl.Function(
receiver = receiverTypeDecl, receiver = receiverTypeDecl,
contextReceivers = contextReceiverTypeDecls,
params = argsDeclaration.params.map { it.type }, params = argsDeclaration.params.map { it.type },
returnType = inferredReturnDecl ?: TypeDecl.TypeAny, returnType = inferredReturnDecl ?: TypeDecl.TypeAny,
nullable = false nullable = false
@ -10062,7 +10252,7 @@ class Compiler(
} }
} }
if (extTypeName != null) { if (extTypeName != null) {
context.thisObj = scope.thisObj context.setThisVariants(scope.thisObj, context.thisVariants)
} }
val localNames = frame.fn.localSlotNames val localNames = frame.fn.localSlotNames
for (i in localNames.indices) { for (i in localNames.indices) {
@ -10145,6 +10335,7 @@ class Compiler(
annotation = annotation, annotation = annotation,
typeDecl = if (isDelegated) null else TypeDecl.Function( typeDecl = if (isDelegated) null else TypeDecl.Function(
receiver = receiverTypeDecl, receiver = receiverTypeDecl,
contextReceivers = contextReceiverTypeDecls,
params = argsDeclaration.params.map { it.type }, params = argsDeclaration.params.map { it.type },
returnType = inferredReturnDecl ?: TypeDecl.TypeAny, returnType = inferredReturnDecl ?: TypeDecl.TypeAny,
nullable = false nullable = false
@ -11644,6 +11835,7 @@ class Compiler(
val a = constOf(aRef) ?: return null val a = constOf(aRef) ?: return null
return when (op) { return when (op) {
UnaryOp.NOT -> if (a is ObjBool) if (!a.value) ObjTrue else ObjFalse else null UnaryOp.NOT -> if (a is ObjBool) if (!a.value) ObjTrue else ObjFalse else null
UnaryOp.POSITIVE -> a
UnaryOp.NEGATE -> when (a) { UnaryOp.NEGATE -> when (a) {
is ObjInt -> ObjInt.of(-a.value) is ObjInt -> ObjInt.of(-a.value)
is ObjReal -> ObjReal.of(-a.value) is ObjReal -> ObjReal.of(-a.value)

View File

@ -99,6 +99,20 @@ open class Scope(
extensions.getOrPut(cls) { mutableMapOf() }[name] = record extensions.getOrPut(cls) { mutableMapOf() }[name] = record
} }
private fun extensionContextReceiversSatisfied(record: ObjRecord): Boolean {
val fnType = record.typeDecl as? TypeDecl.Function ?: return true
if (fnType.contextReceivers.isEmpty()) return true
return fnType.contextReceivers.all { required ->
thisVariants.any { variant ->
when (required) {
is TypeDecl.Simple -> variant.isInstanceOf(required.name.substringAfterLast('.'))
is TypeDecl.Generic -> variant.isInstanceOf(required.name.substringAfterLast('.'))
else -> false
}
}
}
}
internal fun findExtension(receiverClass: ObjClass, name: String): ObjRecord? { internal fun findExtension(receiverClass: ObjClass, name: String): ObjRecord? {
var s: Scope? = this var s: Scope? = this
var hops = 0 var hops = 0
@ -106,7 +120,9 @@ open class Scope(
// Proximity rule: check all extensions in the current scope before going to parent. // Proximity rule: check all extensions in the current scope before going to parent.
// Priority within scope: more specific class in MRO wins. // Priority within scope: more specific class in MRO wins.
for (cls in receiverClass.mro) { for (cls in receiverClass.mro) {
s.extensions[cls]?.get(name)?.let { return it } s.extensions[cls]?.get(name)?.let {
if (extensionContextReceiversSatisfied(it)) return it
}
} }
if (s is BytecodeClosureScope) { if (s is BytecodeClosureScope) {
s.closureScope.findExtension(receiverClass, name)?.let { return it } s.closureScope.findExtension(receiverClass, name)?.let { return it }

View File

@ -43,6 +43,7 @@ class BytecodeCompiler(
private val callableReturnTypeByScopeId: Map<Int, Map<Int, ObjClass>> = emptyMap(), private val callableReturnTypeByScopeId: Map<Int, Map<Int, ObjClass>> = emptyMap(),
private val callableReturnTypeByName: Map<String, ObjClass> = emptyMap(), private val callableReturnTypeByName: Map<String, ObjClass> = emptyMap(),
private val callSignatureByName: Map<String, CallSignature> = emptyMap(), private val callSignatureByName: Map<String, CallSignature> = emptyMap(),
private val extensionContextReceiversByWrapperName: Map<String, List<String>> = emptyMap(),
private val externCallableNames: Set<String> = emptySet(), private val externCallableNames: Set<String> = emptySet(),
private val externBindingNames: Set<String> = emptySet(), private val externBindingNames: Set<String> = emptySet(),
private val preparedModuleBindingNames: Set<String> = emptySet(), private val preparedModuleBindingNames: Set<String> = emptySet(),
@ -1146,56 +1147,95 @@ class BytecodeCompiler(
} }
private fun compileUnary(ref: UnaryOpRef): CompiledValue? { private fun compileUnary(ref: UnaryOpRef): CompiledValue? {
val a = compileRef(unaryOperand(ref)) ?: return null
val out = allocSlot()
return when (unaryOp(ref)) { return when (unaryOp(ref)) {
UnaryOp.NEGATE -> when (a.type) { UnaryOp.POSITIVE -> {
SlotType.INT -> { val operandRef = unaryOperand(ref)
builder.emit(Opcode.NEG_INT, a.slot, out) if (hasUnaryCallable(operandRef, "unaryPlus")) {
CompiledValue(out, SlotType.INT) return compileMethodCall(MethodCallRef(operandRef, "unaryPlus", emptyList(), false, false))
} }
SlotType.REAL -> { val a = compileRef(operandRef) ?: return null
builder.emit(Opcode.NEG_REAL, a.slot, out) return when (a.type) {
CompiledValue(out, SlotType.REAL) SlotType.INT, SlotType.REAL -> a
} else -> {
else -> compileObjUnaryOp(unaryOperand(ref), a, "negate", Pos.builtIn) val obj = ensureObjSlot(a)
} val out = allocSlot()
UnaryOp.NOT -> { builder.emit(Opcode.POS_OBJ, obj.slot, out)
when (a.type) { updateSlotType(out, SlotType.OBJ)
SlotType.BOOL -> builder.emit(Opcode.NOT_BOOL, a.slot, out) slotObjClass[obj.slot]?.let { slotObjClass[out] = it }
SlotType.INT -> { CompiledValue(out, SlotType.OBJ)
val tmp = allocSlot()
builder.emit(Opcode.INT_TO_BOOL, a.slot, tmp)
builder.emit(Opcode.NOT_BOOL, tmp, out)
} }
SlotType.OBJ, SlotType.UNKNOWN -> {
val objSlot = ensureObjSlot(a)
val tmp = allocSlot()
builder.emit(Opcode.OBJ_TO_BOOL, objSlot.slot, tmp)
builder.emit(Opcode.NOT_BOOL, tmp, out)
updateSlotType(tmp, SlotType.BOOL)
}
else -> return null
} }
CompiledValue(out, SlotType.BOOL)
} }
UnaryOp.BITNOT -> { else -> {
if (a.type == SlotType.INT) { val a = compileRef(unaryOperand(ref)) ?: return null
builder.emit(Opcode.INV_INT, a.slot, out) val out = allocSlot()
return CompiledValue(out, SlotType.INT) when (unaryOp(ref)) {
UnaryOp.NEGATE -> when (a.type) {
SlotType.INT -> {
builder.emit(Opcode.NEG_INT, a.slot, out)
CompiledValue(out, SlotType.INT)
}
SlotType.REAL -> {
builder.emit(Opcode.NEG_REAL, a.slot, out)
CompiledValue(out, SlotType.REAL)
}
else -> compileObjUnaryOp(unaryOperand(ref), a, "negate", Pos.builtIn)
}
UnaryOp.NOT -> {
when (a.type) {
SlotType.BOOL -> builder.emit(Opcode.NOT_BOOL, a.slot, out)
SlotType.INT -> {
val tmp = allocSlot()
builder.emit(Opcode.INT_TO_BOOL, a.slot, tmp)
builder.emit(Opcode.NOT_BOOL, tmp, out)
}
SlotType.OBJ, SlotType.UNKNOWN -> {
val objSlot = ensureObjSlot(a)
val tmp = allocSlot()
builder.emit(Opcode.OBJ_TO_BOOL, objSlot.slot, tmp)
builder.emit(Opcode.NOT_BOOL, tmp, out)
updateSlotType(tmp, SlotType.BOOL)
}
else -> return null
}
CompiledValue(out, SlotType.BOOL)
}
UnaryOp.BITNOT -> {
if (a.type == SlotType.INT) {
builder.emit(Opcode.INV_INT, a.slot, out)
return CompiledValue(out, SlotType.INT)
}
return compileObjUnaryOp(unaryOperand(ref), a, "bitNot", Pos.builtIn)
}
UnaryOp.POSITIVE -> error("unreachable")
} }
return compileObjUnaryOp(unaryOperand(ref), a, "bitNot", Pos.builtIn)
} }
} }
} }
private fun hasUnaryCallable(ref: ObjRef, memberName: String): Boolean {
val receiverClass = resolveReceiverClass(ref) ?: return false
if (receiverClass == ObjDynamic.type) return false
if (receiverClass is ObjInstanceClass && !isThisReceiver(ref)) return true
val resolvedMember = receiverClass.resolveInstanceMember(memberName)
if (resolvedMember?.declaringClass?.className == "Obj") return false
val abstractRecord = receiverClass.members[memberName] ?: receiverClass.classScope?.objects?.get(memberName)
if (abstractRecord?.isAbstract == true) return false
val methodId = receiverClass.instanceMethodIdMap(includeAbstract = true)[memberName]
if (methodId != null && resolvedMember?.declaringClass?.className != "Obj") return true
val fieldId = if (resolvedMember != null) receiverClass.instanceFieldIdMap()[memberName] else null
if (fieldId != null) return true
return resolveExtensionCallableSlot(receiverClass, memberName) != null
}
private fun compileObjUnaryOp( private fun compileObjUnaryOp(
ref: ObjRef, ref: ObjRef,
value: CompiledValue, value: CompiledValue,
memberName: String, memberName: String,
pos: Pos pos: Pos,
defaultIdentity: Boolean = false
): CompiledValue? { ): CompiledValue? {
val receiverClass = resolveReceiverClass(ref) val receiverClass = resolveReceiverClass(ref) ?: slotObjClass[value.slot]
val methodId = receiverClass?.instanceMethodIdMap(includeAbstract = true)?.get(memberName) val methodId = receiverClass?.instanceMethodIdMap(includeAbstract = true)?.get(memberName)
if (methodId != null) { if (methodId != null) {
val receiverObj = ensureObjSlot(value) val receiverObj = ensureObjSlot(value)
@ -1204,6 +1244,19 @@ class BytecodeCompiler(
updateSlotType(dst, SlotType.OBJ) updateSlotType(dst, SlotType.OBJ)
return CompiledValue(dst, SlotType.OBJ) return CompiledValue(dst, SlotType.OBJ)
} }
val extSlot = when {
receiverClass != null -> resolveExtensionCallableSlot(receiverClass, memberName)
else -> resolveUniqueExtensionWrapperSlot(memberName, "__ext__")
}
if (extSlot != null) {
val callee = ensureObjSlot(extSlot)
val args = compileCallArgsWithReceiver(value, emptyList(), false) ?: return null
val encodedCount = encodeCallArgCount(args) ?: return null
val dst = allocSlot()
setPos(pos)
emitCallCompiled(callee, args.base, encodedCount, dst)
return CompiledValue(dst, SlotType.OBJ)
}
if (memberName == "negate" && if (memberName == "negate" &&
(receiverClass == null || isDelegateClass(receiverClass) || receiverClass in setOf(ObjInt.type, ObjReal.type)) (receiverClass == null || isDelegateClass(receiverClass) || receiverClass in setOf(ObjInt.type, ObjReal.type))
) { ) {
@ -1217,6 +1270,9 @@ class BytecodeCompiler(
updateSlotType(dst, SlotType.OBJ) updateSlotType(dst, SlotType.OBJ)
return CompiledValue(dst, SlotType.OBJ) return CompiledValue(dst, SlotType.OBJ)
} }
if (defaultIdentity) {
return value
}
throw BytecodeCompileException( throw BytecodeCompileException(
"Unknown member $memberName on ${receiverClass?.className ?: "unknown"}", "Unknown member $memberName on ${receiverClass?.className ?: "unknown"}",
pos pos
@ -5972,6 +6028,7 @@ class BytecodeCompiler(
): String? { ): String? {
for (receiverName in extensionReceiverTypeNames(receiverClass)) { for (receiverName in extensionReceiverTypeNames(receiverClass)) {
val candidate = wrapperName(receiverName, memberName) val candidate = wrapperName(receiverName, memberName)
if (!extensionContextReceiversSatisfied(candidate)) continue
if (allowedScopeNames != null && if (allowedScopeNames != null &&
!allowedScopeNames.contains(candidate) && !allowedScopeNames.contains(candidate) &&
!localSlotIndexByName.containsKey(candidate) !localSlotIndexByName.containsKey(candidate)
@ -5983,6 +6040,31 @@ class BytecodeCompiler(
return null return null
} }
private fun currentImplicitReceiverTypeNames(): List<String> {
val result = mutableListOf<String>()
inlineThisBindings.asReversed().forEach { binding ->
val typeName = binding.typeName ?: return@forEach
if (!result.contains(typeName)) result += typeName
}
implicitThisTypeName?.let {
if (!result.contains(it)) result += it
}
return result
}
private fun extensionContextReceiversSatisfied(wrapperName: String): Boolean {
val required = extensionContextReceiversByWrapperName[wrapperName].orEmpty()
if (required.isEmpty()) return true
val visible = currentImplicitReceiverTypeNames()
return required.all { req ->
visible.any { visibleName ->
visibleName == req || resolveTypeNameClass(visibleName)?.let { cls ->
cls.className == req || cls.mro.any { it.className == req }
} == true
}
}
}
private fun resolveUniqueExtensionWrapperName( private fun resolveUniqueExtensionWrapperName(
memberName: String, memberName: String,
wrapperPrefix: String wrapperPrefix: String
@ -5991,12 +6073,12 @@ class BytecodeCompiler(
val candidates = LinkedHashSet<String>() val candidates = LinkedHashSet<String>()
for (name in localSlotIndexByName.keys) { for (name in localSlotIndexByName.keys) {
if (name.startsWith(wrapperPrefix) && name.endsWith(suffix)) { if (name.startsWith(wrapperPrefix) && name.endsWith(suffix)) {
candidates.add(name) if (extensionContextReceiversSatisfied(name)) candidates.add(name)
} }
} }
for (name in scopeSlotIndexByName.keys) { for (name in scopeSlotIndexByName.keys) {
if (name.startsWith(wrapperPrefix) && name.endsWith(suffix)) { if (name.startsWith(wrapperPrefix) && name.endsWith(suffix)) {
candidates.add(name) if (extensionContextReceiversSatisfied(name)) candidates.add(name)
} }
} }
return candidates.singleOrNull() return candidates.singleOrNull()
@ -8313,6 +8395,19 @@ class BytecodeCompiler(
is ObjChar -> ObjChar.type is ObjChar -> ObjChar.type
else -> null else -> null
} }
is UnaryOpRef -> when (ref.op) {
UnaryOp.NOT -> ObjBool.type
UnaryOp.POSITIVE -> resolveReceiverClass(ref.a)
UnaryOp.NEGATE -> when (val operandClass = resolveReceiverClass(ref.a)) {
ObjInt.type -> ObjInt.type
ObjReal.type -> ObjReal.type
else -> inferMethodCallReturnClass(operandClass, "negate")
}
UnaryOp.BITNOT -> when (val operandClass = resolveReceiverClass(ref.a)) {
ObjInt.type -> ObjInt.type
else -> inferMethodCallReturnClass(operandClass, "bitNot")
}
}
is CastRef -> resolveTypeRefClass(ref.castTypeRef()) is CastRef -> resolveTypeRefClass(ref.castTypeRef())
?: resolveReceiverClass(ref.castValueRef()) ?: resolveReceiverClass(ref.castValueRef())
is FieldRef -> { is FieldRef -> {

View File

@ -107,6 +107,7 @@ class BytecodeStatement private constructor(
callableReturnTypeByScopeId: Map<Int, Map<Int, ObjClass>> = emptyMap(), callableReturnTypeByScopeId: Map<Int, Map<Int, ObjClass>> = emptyMap(),
callableReturnTypeByName: Map<String, ObjClass> = emptyMap(), callableReturnTypeByName: Map<String, ObjClass> = emptyMap(),
callSignatureByName: Map<String, CallSignature> = emptyMap(), callSignatureByName: Map<String, CallSignature> = emptyMap(),
extensionContextReceiversByWrapperName: Map<String, List<String>> = emptyMap(),
externCallableNames: Set<String> = emptySet(), externCallableNames: Set<String> = emptySet(),
externBindingNames: Set<String> = emptySet(), externBindingNames: Set<String> = emptySet(),
preparedModuleBindingNames: Set<String> = emptySet(), preparedModuleBindingNames: Set<String> = emptySet(),
@ -148,6 +149,7 @@ class BytecodeStatement private constructor(
callableReturnTypeByScopeId = callableReturnTypeByScopeId, callableReturnTypeByScopeId = callableReturnTypeByScopeId,
callableReturnTypeByName = callableReturnTypeByName, callableReturnTypeByName = callableReturnTypeByName,
callSignatureByName = callSignatureByName, callSignatureByName = callSignatureByName,
extensionContextReceiversByWrapperName = extensionContextReceiversByWrapperName,
externCallableNames = externCallableNames, externCallableNames = externCallableNames,
externBindingNames = externBindingNames, externBindingNames = externBindingNames,
preparedModuleBindingNames = preparedModuleBindingNames, preparedModuleBindingNames = preparedModuleBindingNames,

View File

@ -143,7 +143,7 @@ class CmdBuilder {
Opcode.UNBOX_INT_OBJ, Opcode.UNBOX_REAL_OBJ, Opcode.UNBOX_INT_OBJ, Opcode.UNBOX_REAL_OBJ,
Opcode.INT_TO_REAL, Opcode.REAL_TO_INT, Opcode.BOOL_TO_INT, Opcode.INT_TO_BOOL, Opcode.INT_TO_REAL, Opcode.REAL_TO_INT, Opcode.BOOL_TO_INT, Opcode.INT_TO_BOOL,
Opcode.OBJ_TO_BOOL, Opcode.GET_OBJ_CLASS, Opcode.OBJ_TO_BOOL, Opcode.GET_OBJ_CLASS,
Opcode.NEG_INT, Opcode.NEG_REAL, Opcode.NOT_BOOL, Opcode.INV_INT, Opcode.NEG_INT, Opcode.NEG_REAL, Opcode.NOT_BOOL, Opcode.INV_INT, Opcode.POS_OBJ,
Opcode.ASSERT_IS -> Opcode.ASSERT_IS ->
listOf(OperandKind.SLOT, OperandKind.SLOT) listOf(OperandKind.SLOT, OperandKind.SLOT)
Opcode.CHECK_IS, Opcode.MAKE_QUALIFIED_VIEW -> Opcode.CHECK_IS, Opcode.MAKE_QUALIFIED_VIEW ->
@ -698,6 +698,7 @@ class CmdBuilder {
} else { } else {
CmdNotBool(operands[0], operands[1]) CmdNotBool(operands[0], operands[1])
} }
Opcode.POS_OBJ -> CmdPosObj(operands[0], operands[1])
Opcode.AND_BOOL -> if (isFastLocal(operands[0]) && isFastLocal(operands[1]) && isFastLocal(operands[2])) { Opcode.AND_BOOL -> if (isFastLocal(operands[0]) && isFastLocal(operands[1]) && isFastLocal(operands[2])) {
CmdAndBoolLocal(operands[0] - scopeSlotCount, operands[1] - scopeSlotCount, operands[2] - scopeSlotCount) CmdAndBoolLocal(operands[0] - scopeSlotCount, operands[1] - scopeSlotCount, operands[2] - scopeSlotCount)
} else { } else {

View File

@ -450,6 +450,7 @@ object CmdDisassembler {
is CmdMulObj -> Opcode.MUL_OBJ to intArrayOf(cmd.a, cmd.b, cmd.dst) is CmdMulObj -> Opcode.MUL_OBJ to intArrayOf(cmd.a, cmd.b, cmd.dst)
is CmdDivObj -> Opcode.DIV_OBJ to intArrayOf(cmd.a, cmd.b, cmd.dst) is CmdDivObj -> Opcode.DIV_OBJ to intArrayOf(cmd.a, cmd.b, cmd.dst)
is CmdModObj -> Opcode.MOD_OBJ to intArrayOf(cmd.a, cmd.b, cmd.dst) is CmdModObj -> Opcode.MOD_OBJ to intArrayOf(cmd.a, cmd.b, cmd.dst)
is CmdPosObj -> Opcode.POS_OBJ to intArrayOf(cmd.a, cmd.dst)
is CmdContainsObj -> Opcode.CONTAINS_OBJ to intArrayOf(cmd.target, cmd.value, cmd.dst) is CmdContainsObj -> Opcode.CONTAINS_OBJ to intArrayOf(cmd.target, cmd.value, cmd.dst)
is CmdAssignOpObj -> Opcode.ASSIGN_OP_OBJ to intArrayOf(cmd.opId, cmd.targetSlot, cmd.valueSlot, cmd.dst, cmd.nameId) is CmdAssignOpObj -> Opcode.ASSIGN_OP_OBJ to intArrayOf(cmd.opId, cmd.targetSlot, cmd.valueSlot, cmd.dst, cmd.nameId)
is CmdJmp -> Opcode.JMP to intArrayOf(cmd.target) is CmdJmp -> Opcode.JMP to intArrayOf(cmd.target)
@ -593,6 +594,8 @@ object CmdDisassembler {
Opcode.ADD_OBJ, Opcode.SUB_OBJ, Opcode.MUL_OBJ, Opcode.DIV_OBJ, Opcode.MOD_OBJ, Opcode.CONTAINS_OBJ, Opcode.ADD_OBJ, Opcode.SUB_OBJ, Opcode.MUL_OBJ, Opcode.DIV_OBJ, Opcode.MOD_OBJ, Opcode.CONTAINS_OBJ,
Opcode.AND_BOOL, Opcode.OR_BOOL -> Opcode.AND_BOOL, Opcode.OR_BOOL ->
listOf(OperandKind.SLOT, OperandKind.SLOT, OperandKind.SLOT) listOf(OperandKind.SLOT, OperandKind.SLOT, OperandKind.SLOT)
Opcode.POS_OBJ ->
listOf(OperandKind.SLOT, OperandKind.SLOT)
Opcode.ASSIGN_OP_OBJ -> Opcode.ASSIGN_OP_OBJ ->
listOf(OperandKind.ID, OperandKind.SLOT, OperandKind.SLOT, OperandKind.SLOT, OperandKind.CONST) listOf(OperandKind.ID, OperandKind.SLOT, OperandKind.SLOT, OperandKind.SLOT, OperandKind.CONST)
Opcode.INC_INT, Opcode.DEC_INT, Opcode.RET, Opcode.ITER_PUSH, Opcode.LOAD_THIS -> Opcode.INC_INT, Opcode.DEC_INT, Opcode.RET, Opcode.ITER_PUSH, Opcode.LOAD_THIS ->

View File

@ -1942,6 +1942,14 @@ class CmdCmpGteObj(internal val a: Int, internal val b: Int, internal val dst: I
} }
} }
class CmdPosObj(internal val a: Int, internal val dst: Int) : Cmd() {
override suspend fun perform(frame: CmdFrame) {
val result = frame.slotToObj(a).unaryPlus(frame.ensureScope())
frame.storeObjResult(dst, result)
return
}
}
class CmdAddObj(internal val a: Int, internal val b: Int, internal val dst: Int) : Cmd() { class CmdAddObj(internal val a: Int, internal val b: Int, internal val dst: Int) : Cmd() {
override suspend fun perform(frame: CmdFrame) { override suspend fun perform(frame: CmdFrame) {
val result = frame.slotToObj(a).plus(frame.ensureScope(), frame.slotToObj(b)) val result = frame.slotToObj(a).plus(frame.ensureScope(), frame.slotToObj(b))
@ -4176,6 +4184,15 @@ class BytecodeLambdaCallable(
val context = callScope.applyClosureForBytecode(closureScope, preferredThisType).also { val context = callScope.applyClosureForBytecode(closureScope, preferredThisType).also {
it.args = args it.args = args
} }
preferredThisType?.let { typeName ->
val receiverArg = args.list.firstOrNull { arg ->
arg.isInstanceOf(typeName) ||
((context[typeName]?.value as? ObjClass)?.let { typeClass -> arg.isInstanceOf(typeClass) } == true)
}
if (receiverArg != null && context.thisVariants.firstOrNull() !== receiverArg) {
context.setThisVariants(receiverArg, context.thisVariants)
}
}
if (captureRecords != null) { if (captureRecords != null) {
context.captureRecords = captureRecords context.captureRecords = captureRecords
context.captureNames = captureNames context.captureNames = captureNames

View File

@ -144,6 +144,7 @@ enum class Opcode(val code: Int) {
MOD_OBJ(0x7B), MOD_OBJ(0x7B),
CONTAINS_OBJ(0x7C), CONTAINS_OBJ(0x7C),
ASSIGN_OP_OBJ(0x7D), ASSIGN_OP_OBJ(0x7D),
POS_OBJ(0x7E),
JMP(0x80), JMP(0x80),
JMP_IF_TRUE(0x81), JMP_IF_TRUE(0x81),

View File

@ -317,6 +317,12 @@ open class Obj {
} }
} }
open suspend fun unaryPlus(scope: Scope): Obj {
return invokeInstanceMethod(scope, "unaryPlus", Arguments.EMPTY) {
this
}
}
open suspend fun mul(scope: Scope, other: Obj): Obj { open suspend fun mul(scope: Scope, other: Obj): Obj {
val otherValue = when (other) { val otherValue = when (other) {
is FrameSlotRef -> other.read() is FrameSlotRef -> other.read()

View File

@ -73,7 +73,7 @@ class ClassOperatorRef(val target: ObjRef, val pos: Pos) : ObjRef {
} }
/** Unary operations supported by ObjRef. */ /** Unary operations supported by ObjRef. */
enum class UnaryOp { NOT, NEGATE, BITNOT } enum class UnaryOp { NOT, POSITIVE, NEGATE, BITNOT }
/** Binary operations supported by ObjRef. */ /** Binary operations supported by ObjRef. */
enum class BinOp { enum class BinOp {

View File

@ -15,17 +15,21 @@
* *
*/ */
import kotlinx.coroutines.runBlocking import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.test.runTest
import kotlinx.coroutines.withContext
import kotlinx.coroutines.withTimeout import kotlinx.coroutines.withTimeout
import kotlin.test.Test import kotlin.test.Test
import net.sergeych.lyng.eval as lyngEval import net.sergeych.lyng.eval as lyngEval
class LaunchPoolTest { class LaunchPoolTest {
private suspend fun eval(code: String) = withTimeout(2_000L) { lyngEval(code) } private suspend fun eval(code: String) = withContext(Dispatchers.Default) {
withTimeout(2_000L) { lyngEval(code) }
}
@Test @Test
fun testBasicExecution() = runBlocking<Unit> { fun testBasicExecution() = runTest {
eval(""" eval("""
val pool = LaunchPool(2) val pool = LaunchPool(2)
val d1 = pool.launch { 1 + 1 } val d1 = pool.launch { 1 + 1 }
@ -37,7 +41,7 @@ class LaunchPoolTest {
} }
@Test @Test
fun testResultsCollected() = runBlocking<Unit> { fun testResultsCollected() = runTest {
eval(""" eval("""
val pool = LaunchPool(4) val pool = LaunchPool(4)
val jobs = (1..10).map { n -> pool.launch { n * n } } val jobs = (1..10).map { n -> pool.launch { n * n } }
@ -48,7 +52,7 @@ class LaunchPoolTest {
} }
@Test @Test
fun testConcurrencyLimit() = runBlocking<Unit> { fun testConcurrencyLimit() = runTest {
eval(""" eval("""
// With maxWorkers=2, at most 2 tasks run at the same time. // With maxWorkers=2, at most 2 tasks run at the same time.
val mu = Mutex() val mu = Mutex()
@ -70,7 +74,7 @@ class LaunchPoolTest {
} }
@Test @Test
fun testExceptionCapturedInDeferred() = runBlocking<Unit> { fun testExceptionCapturedInDeferred() = runTest {
eval(""" eval("""
val pool = LaunchPool(2) val pool = LaunchPool(2)
val good = pool.launch { 42 } val good = pool.launch { 42 }
@ -83,7 +87,7 @@ class LaunchPoolTest {
} }
@Test @Test
fun testPoolContinuesAfterLambdaException() = runBlocking<Unit> { fun testPoolContinuesAfterLambdaException() = runTest {
eval(""" eval("""
val pool = LaunchPool(1) val pool = LaunchPool(1)
val bad = pool.launch { throw IllegalArgumentException("fail") } val bad = pool.launch { throw IllegalArgumentException("fail") }
@ -96,7 +100,7 @@ class LaunchPoolTest {
} }
@Test @Test
fun testLaunchAfterCloseAndJoinThrows() = runBlocking<Unit> { fun testLaunchAfterCloseAndJoinThrows() = runTest {
eval(""" eval("""
val pool = LaunchPool(2) val pool = LaunchPool(2)
pool.launch { 1 } pool.launch { 1 }
@ -107,7 +111,7 @@ class LaunchPoolTest {
} }
@Test @Test
fun testLaunchAfterCancelThrows() = runBlocking<Unit> { fun testLaunchAfterCancelThrows() = runTest {
eval(""" eval("""
val pool = LaunchPool(2) val pool = LaunchPool(2)
pool.cancel() pool.cancel()
@ -117,7 +121,7 @@ class LaunchPoolTest {
} }
@Test @Test
fun testCancelAndJoinWaitsForWorkers() = runBlocking<Unit> { fun testCancelAndJoinWaitsForWorkers() = runTest {
eval(""" eval("""
val pool = LaunchPool(2) val pool = LaunchPool(2)
pool.launch { delay(5) } pool.launch { delay(5) }
@ -128,7 +132,7 @@ class LaunchPoolTest {
} }
@Test @Test
fun testCloseAndJoinDrainsQueue() = runBlocking<Unit> { fun testCloseAndJoinDrainsQueue() = runTest {
eval(""" eval("""
val mu = Mutex() val mu = Mutex()
val results = [] val results = []
@ -147,7 +151,7 @@ class LaunchPoolTest {
} }
@Test @Test
fun testBoundedQueueSuspendsProducer() = runBlocking<Unit> { fun testBoundedQueueSuspendsProducer() = runTest {
eval(""" eval("""
// queue of 2 + 1 worker; producer can only be 1 ahead of what's running // queue of 2 + 1 worker; producer can only be 1 ahead of what's running
val pool = LaunchPool(1, 2) val pool = LaunchPool(1, 2)
@ -165,7 +169,7 @@ class LaunchPoolTest {
} }
@Test @Test
fun testUnlimitedQueueDefault() = runBlocking<Unit> { fun testUnlimitedQueueDefault() = runTest {
eval(""" eval("""
val pool = LaunchPool(4) val pool = LaunchPool(4)
val jobs = (1..50).map { n -> pool.launch { n } } val jobs = (1..50).map { n -> pool.launch { n } }
@ -177,7 +181,7 @@ class LaunchPoolTest {
} }
@Test @Test
fun testIdempotentClose() = runBlocking<Unit> { fun testIdempotentClose() = runTest {
eval(""" eval("""
val pool = LaunchPool(2) val pool = LaunchPool(2)
pool.closeAndJoin() pool.closeAndJoin()

View File

@ -190,4 +190,63 @@ class ScriptImportPreparationTest {
session.cancelAndJoin() session.cancelAndJoin()
} }
} }
@Test
fun importedContextReceiverExtensionIsAvailableInReceiverDsl() = runTest {
val manager = Script.defaultImportManager.copy().apply {
addTextPackages(
"""
package imported.ctxdsl
class Tag(name: String) {
val name = name
var inner = ""
fun child(tagName: String, block: Tag.()->void) {
val child = Tag(tagName)
child.apply { block(this) }
inner += child.render()
}
fun h3(block: Tag.()->void) { child("h3", block) }
fun addText(text: String) { inner += text }
fun render() = "<" + name + ">" + inner + "</" + name + ">"
}
context(Tag)
fun String.unaryPlus() {
this@Tag.addText(this)
}
""".trimIndent()
)
}
val script = Compiler.compile(
Source(
"<ctx-dsl-import>",
"""
import imported.ctxdsl
fun html(block: Tag.()->void) {
val root = Tag("html")
root.apply { block(this) }
root.render()
}
val page = html {
h3 {
+"Imported"
}
}
assertEquals("<html><h3>Imported</h3></html>", page)
assertEquals("plain", +"plain")
page
""".trimIndent()
),
manager
)
val result = script.execute(manager.newStdScope()) as ObjString
assertEquals("<html><h3>Imported</h3></html>", result.value)
}
} }

View File

@ -15,7 +15,7 @@
* *
*/ */
import kotlinx.coroutines.runBlocking import kotlinx.coroutines.test.runTest
import net.sergeych.lyng.eval import net.sergeych.lyng.eval
import kotlin.test.Test import kotlin.test.Test
@ -30,7 +30,7 @@ class TypeInferenceTest {
/** Channel field type inferred from constructor — accessed in a launch closure */ /** Channel field type inferred from constructor — accessed in a launch closure */
@Test @Test
fun testChannelFieldInLaunchClosure() = runBlocking<Unit> { fun testChannelFieldInLaunchClosure() = runTest {
eval(""" eval("""
class Foo { class Foo {
private val ch = Channel(Channel.UNLIMITED) private val ch = Channel(Channel.UNLIMITED)
@ -52,7 +52,7 @@ class TypeInferenceTest {
/** Mutex field type inferred from constructor — used directly in a method body */ /** Mutex field type inferred from constructor — used directly in a method body */
@Test @Test
fun testMutexFieldDirectUse() = runBlocking<Unit> { fun testMutexFieldDirectUse() = runTest {
eval(""" eval("""
class Bar { class Bar {
private val mu = Mutex() private val mu = Mutex()
@ -69,7 +69,7 @@ class TypeInferenceTest {
/** CompletableDeferred field type inferred — complete/await used directly */ /** CompletableDeferred field type inferred — complete/await used directly */
@Test @Test
fun testCompletableDeferredFieldDirectUse() = runBlocking<Unit> { fun testCompletableDeferredFieldDirectUse() = runTest {
eval(""" eval("""
class Baz { class Baz {
private val d = CompletableDeferred() private val d = CompletableDeferred()
@ -84,7 +84,7 @@ class TypeInferenceTest {
/** Channel field accessed inside a map closure within class initializer */ /** Channel field accessed inside a map closure within class initializer */
@Test @Test
fun testChannelFieldInMapAndLaunchClosure() = runBlocking<Unit> { fun testChannelFieldInMapAndLaunchClosure() = runTest {
eval(""" eval("""
class Pool(n) { class Pool(n) {
private val ch = Channel(Channel.UNLIMITED) private val ch = Channel(Channel.UNLIMITED)
@ -106,7 +106,7 @@ class TypeInferenceTest {
} }
@Test @Test
fun testIterableFirstPreservesElementTypeForBlockReturnInference() = runBlocking<Unit> { fun testIterableFirstPreservesElementTypeForBlockReturnInference() = runTest {
eval(""" eval("""
class Item(title: String) class Item(title: String)
@ -121,7 +121,7 @@ class TypeInferenceTest {
} }
@Test @Test
fun testCallableLocalInitializedFromFunctionCallPreservesReturnType() = runBlocking<Unit> { fun testCallableLocalInitializedFromFunctionCallPreservesReturnType() = runTest {
eval(""" eval("""
fun makeAdder(base) { fun makeAdder(base) {
return { x -> x + base + 0.5 } return { x -> x + base + 0.5 }

View File

@ -53,6 +53,197 @@ class OperatorOverloadingTest {
""".trimIndent()) """.trimIndent())
} }
@Test
fun testUnaryPlusDefaultIdentity() = runTest {
eval("""
assertEquals(42, +42)
assertEquals(3.5, +3.5)
assertEquals("abc", +"abc")
class Box(val text: String) {
fun upper() = text.upper()
}
assertEquals("ABC", (+Box("abc")).upper())
""".trimIndent())
}
@Test
fun testUnaryPlusOverloading() = runTest {
eval("""
class Counter(val n: Int) {
fun unaryPlus() = Counter(this.n + 1)
fun equals(other: Counter) = this.n == other.n
}
assertEquals(Counter(6), Counter(5).unaryPlus())
assertEquals(Counter(6), +Counter(5))
""".trimIndent())
}
@Test
fun testUnaryPlusExtensionOverloading() = runTest {
eval("""
var out = ""
fun String.unaryPlus() {
out = out + this
}
"Hello".unaryPlus()
" ".unaryPlus()
"Lyng".unaryPlus()
assertEquals("Hello Lyng", out)
out = ""
+"Hello"
+" "
+"Lyng"
assertEquals("Hello Lyng", out)
""".trimIndent())
}
@Test
fun testUnaryPlusDslBuilderStyle() = runTest {
eval("""
class Tag(name: String) {
val name = name
var inner = ""
fun child(tagName: String, block: Tag.()->void) {
val child = Tag(tagName)
with(child) { block(this) }
inner += child.render()
}
fun head(block: Tag.()->void) { child("head", block) }
fun body(block: Tag.()->void) { child("body", block) }
fun title(block: Tag.()->void) { child("title", block) }
fun h1(block: Tag.()->void) { child("h1", block) }
fun addText(text: String) {
inner += text
}
fun render() {
"<" + name + ">" + inner + "</" + name + ">"
}
}
context(Tag)
fun String.unaryPlus() {
this@Tag.addText(this)
}
fun html(block: Tag.()->void) {
val root = Tag("html")
with(root) { block(this) }
root.render()
}
val page = html {
head {
title {
+"Demo"
}
}
body {
h1 {
+"Heading 1"
}
}
}
assertEquals("<html><head><title>Demo</title></head><body><h1>Heading 1</h1></body></html>", page)
""".trimIndent())
}
@Test
fun testContextReceiverUnaryPlusDslBuilderStyle() = runTest {
eval("""
class Tag(name: String) {
val name = name
var inner = ""
fun child(tagName: String, block: Tag.()->void) {
val child = Tag(tagName)
with(child) { block(this) }
inner += child.render()
}
fun h3(block: Tag.()->void) { child("h3", block) }
fun addText(text: String) {
inner += text
}
fun render() {
"<" + name + ">" + inner + "</" + name + ">"
}
}
context(Tag)
fun String.unaryPlus() {
this@Tag.addText(this)
}
fun html(block: Tag.()->void) {
val root = Tag("html")
with(root) { block(this) }
root.render()
}
val page = html {
h3 {
+"Heading 3"
}
}
assertEquals("<html><h3>Heading 3</h3></html>", page)
assertEquals("plain", +"plain")
""".trimIndent())
}
@Test
fun testContextReceiverExtensionIsHiddenOutsideContext() = runTest {
val ex = assertFailsWith<Throwable> {
eval("""
class Tag {
fun wrap(text: String) = "[" + text + "]"
}
context(Tag)
fun String.mark() = this@Tag.wrap(this)
"x".mark()
""".trimIndent())
}
assertContains(ex.message ?: "", "no such member: mark on String")
}
@Test
fun testContextReceiverExtensionIsHiddenInWrongContext() = runTest {
val ex = assertFailsWith<Throwable> {
eval("""
class Tag {
fun wrap(text: String) = "[" + text + "]"
}
class Other
context(Tag)
fun String.mark() = this@Tag.wrap(this)
fun other(block: Other.()->void) {
with(Other()) { block(this) }
}
other {
"x".mark()
}
""".trimIndent())
}
assertContains(ex.message ?: "", "no such member: mark on String")
}
@Test @Test
fun testPlusAssignOverloading() = runTest { fun testPlusAssignOverloading() = runTest {
eval(""" eval("""