From ae3b878c6203e0a545cf86064318c65f114550e4 Mon Sep 17 00:00:00 2001 From: Scott Jasso Date: Wed, 19 Feb 2025 00:40:29 -0800 Subject: [PATCH 1/3] Add support for IntoSet/Map(multiple = true) --- .../me/tatarka/inject/test/Multibinds.kt | 13 ++++ .../inject/compiler/InjectGenerator.kt | 14 +++++ .../tatarka/inject/compiler/TypeCollector.kt | 63 +++++++++++++------ .../me/tatarka/inject/compiler/TypeResult.kt | 13 ++-- .../inject/compiler/TypeResultGenerator.kt | 35 ++++++++--- .../inject/compiler/TypeResultResolver.kt | 58 +++++++++++------ .../tatarka/inject/annotations/Annotations.kt | 4 +- 7 files changed, 145 insertions(+), 55 deletions(-) diff --git a/integration-tests/common/src/main/kotlin/me/tatarka/inject/test/Multibinds.kt b/integration-tests/common/src/main/kotlin/me/tatarka/inject/test/Multibinds.kt index 2a7b06d4..dcb01149 100644 --- a/integration-tests/common/src/main/kotlin/me/tatarka/inject/test/Multibinds.kt +++ b/integration-tests/common/src/main/kotlin/me/tatarka/inject/test/Multibinds.kt @@ -24,6 +24,15 @@ abstract class SetComponent { @Provides @IntoSet get() = FooValue("2") } +@Component +abstract class EmptySetComponent { + abstract val items: Set + + @Provides + @IntoSet(multiple = true) + fun defaultEmptySet() = emptySet() +} + @Component abstract class DynamicKeyComponent { @@ -35,6 +44,10 @@ abstract class DynamicKeyComponent { val fooValue2 @Provides @IntoMap get() = "2" to FooValue("2") + + @Provides + @IntoMap(multiple = true) + fun fooValue3And4() = mapOf("3" to FooValue("3"), "4" to FooValue("4")) } @Component diff --git a/kotlin-inject-compiler/core/src/main/kotlin/me/tatarka/inject/compiler/InjectGenerator.kt b/kotlin-inject-compiler/core/src/main/kotlin/me/tatarka/inject/compiler/InjectGenerator.kt index 2cb13d37..29d59717 100644 --- a/kotlin-inject-compiler/core/src/main/kotlin/me/tatarka/inject/compiler/InjectGenerator.kt +++ b/kotlin-inject-compiler/core/src/main/kotlin/me/tatarka/inject/compiler/InjectGenerator.kt @@ -31,6 +31,7 @@ val SCOPE = ClassName(ANNOTATION_PACKAGE_NAME, "Scope") val INJECT = ClassName(ANNOTATION_PACKAGE_NAME, "Inject") val INTO_MAP = ClassName(ANNOTATION_PACKAGE_NAME, "IntoMap") val INTO_SET = ClassName(ANNOTATION_PACKAGE_NAME, "IntoSet") +const val ANNOTATION_MULTIPLE_ARG = "multiple" val ASSISTED = ClassName(ANNOTATION_PACKAGE_NAME, "Assisted") val ASSISTED_FACTORY = ClassName(ANNOTATION_PACKAGE_NAME, "AssistedFactory") const val ASSISTED_FACTORY_FUNCTION_ARG = "injectFunction" @@ -227,6 +228,19 @@ fun AstAnnotated.assistedFactoryFunctionName() = annotation(ASSISTED_FACTORY.packageName, ASSISTED_FACTORY.simpleName) ?.argument(ASSISTED_FACTORY_FUNCTION_ARG) as? String +fun AstAnnotated.isIntoMap() = hasAnnotation(INTO_MAP) + +fun AstAnnotated.isIntoSet() = hasAnnotation(INTO_SET) + +private fun AstAnnotated.multipleArgValue(annotationClass: ClassName) = + annotation(annotationClass.packageName, annotationClass.simpleName) + ?.argument(ANNOTATION_MULTIPLE_ARG) as Boolean? + ?: false + +fun AstAnnotated.isIntoMapMultiple() = multipleArgValue(INTO_MAP) + +fun AstAnnotated.isIntoSetMultiple() = multipleArgValue(INTO_SET) + fun AstClass.findInjectConstructors(messenger: Messenger, options: Options): AstConstructor? { val injectCtors = constructors.filter { if (options.enableJavaxAnnotations) { diff --git a/kotlin-inject-compiler/core/src/main/kotlin/me/tatarka/inject/compiler/TypeCollector.kt b/kotlin-inject-compiler/core/src/main/kotlin/me/tatarka/inject/compiler/TypeCollector.kt index 48a7a895..82e3c41c 100644 --- a/kotlin-inject-compiler/core/src/main/kotlin/me/tatarka/inject/compiler/TypeCollector.kt +++ b/kotlin-inject-compiler/core/src/main/kotlin/me/tatarka/inject/compiler/TypeCollector.kt @@ -114,27 +114,57 @@ class TypeCollector(private val provider: AstProvider, private val options: Opti } } val scopedComponent = if (scope != null) astClass else null - if (member.hasAnnotation(INTO_MAP.packageName, INTO_MAP.simpleName)) { - // Pair -> Map + if (member.isIntoMap()) { val returnType = member.returnTypeFor(astClass) val key = TypeKey(returnType, qualifier) val resolvedType = returnType.resolvedType() - if (resolvedType.isPair()) { - val containerKey = ContainerKey.MapKey( - resolvedType.arguments[0], - resolvedType.arguments[1], - key.qualifier - ) - addContainerType(provider, key, containerKey, member, accessor, scope, scopedComponent) + + val containerKey = if (member.isIntoMapMultiple()) { + // Map -> Map + if (resolvedType.isMap()) { + ContainerKey.fromContainer(key) + } else { + provider.error("@IntoMap(multiple = true) must have return type of type Map", member) + null + } } else { - provider.error("@IntoMap must have return type of type Pair", member) + // Pair -> Map + if (resolvedType.isPair()) { + ContainerKey.MapKey( + resolvedType.arguments[0], + resolvedType.arguments[1], + key.qualifier + ) + } else { + provider.error("@IntoMap(multiple = false) must have return type of type Pair", member) + null + } + } + + if (containerKey != null) { + addContainerType(provider, key, containerKey, member, accessor, scope, scopedComponent) } - } else if (member.hasAnnotation(INTO_SET.packageName, INTO_SET.simpleName)) { - // A -> Set + } else if (member.isIntoSet()) { val returnType = member.returnTypeFor(astClass) val key = TypeKey(returnType, qualifier) - val containerKey = ContainerKey.SetKey(returnType, key.qualifier) - addContainerType(provider, key, containerKey, member, accessor, scope, scopedComponent) + + val containerKey = if (member.isIntoSetMultiple()) { + // Set -> Set + val resolvedType = returnType.resolvedType() + if (resolvedType.isSet()) { + ContainerKey.fromContainer(key) + } else { + provider.error("@IntoSet(multiple = true) must have return type of type Set", member) + null + } + } else { + // A -> Set + ContainerKey.SetKey(returnType, key.qualifier) + } + + if (containerKey != null) { + addContainerType(provider, key, containerKey, member, accessor, scope, scopedComponent) + } } else { val returnType = member.returnTypeFor(astClass) val key = TypeKey(returnType, qualifier) @@ -504,20 +534,15 @@ class Member( ) sealed class ContainerKey { - abstract val creator: String abstract fun containerTypeKey(provider: AstProvider): TypeKey data class SetKey(val type: AstType, val qualifier: AstAnnotation? = null) : ContainerKey() { - override val creator: String = "setOf" - override fun containerTypeKey(provider: AstProvider): TypeKey { return TypeKey(provider.declaredTypeOf(Set::class, type), qualifier) } } data class MapKey(val key: AstType, val value: AstType, val qualifier: AstAnnotation? = null) : ContainerKey() { - override val creator: String = "mapOf" - override fun containerTypeKey(provider: AstProvider): TypeKey { return TypeKey(provider.declaredTypeOf(Map::class, key, value), qualifier) } diff --git a/kotlin-inject-compiler/core/src/main/kotlin/me/tatarka/inject/compiler/TypeResult.kt b/kotlin-inject-compiler/core/src/main/kotlin/me/tatarka/inject/compiler/TypeResult.kt index 9aab42bf..70f5699c 100644 --- a/kotlin-inject-compiler/core/src/main/kotlin/me/tatarka/inject/compiler/TypeResult.kt +++ b/kotlin-inject-compiler/core/src/main/kotlin/me/tatarka/inject/compiler/TypeResult.kt @@ -80,11 +80,14 @@ sealed class TypeResult { } } - /** - * A container that holds the type (ex: Set or Map). - */ - class Container( - val creator: String, + class SetContainer( + val args: List, + ) : TypeResult() { + override val children + get() = args.iterator() + } + + class MapContainer( val args: List, ) : TypeResult() { override val children diff --git a/kotlin-inject-compiler/core/src/main/kotlin/me/tatarka/inject/compiler/TypeResultGenerator.kt b/kotlin-inject-compiler/core/src/main/kotlin/me/tatarka/inject/compiler/TypeResultGenerator.kt index 4811b3e5..a14367db 100644 --- a/kotlin-inject-compiler/core/src/main/kotlin/me/tatarka/inject/compiler/TypeResultGenerator.kt +++ b/kotlin-inject-compiler/core/src/main/kotlin/me/tatarka/inject/compiler/TypeResultGenerator.kt @@ -12,6 +12,7 @@ import com.squareup.kotlinpoet.PropertySpec import com.squareup.kotlinpoet.TypeName import com.squareup.kotlinpoet.TypeSpec import com.squareup.kotlinpoet.tags.TypeAliasTag +import com.squareup.kotlinpoet.withIndent data class TypeResultGenerator(val options: Options, val implicitAccessor: Accessor = Accessor.Empty) { @@ -68,7 +69,8 @@ data class TypeResultGenerator(val options: Options, val implicitAccessor: Acces is TypeResult.Provides -> generate() is TypeResult.Scoped -> generate() is TypeResult.Constructor -> generate() - is TypeResult.Container -> generate() + is TypeResult.SetContainer -> generate() + is TypeResult.MapContainer -> generate() is TypeResult.Function -> generate() is TypeResult.NamedFunction -> generate() is TypeResult.Object -> generate() @@ -245,18 +247,31 @@ data class TypeResultGenerator(val options: Options, val implicitAccessor: Acces private val TypeName.isTypeAlias: Boolean get() = tag(TypeAliasTag::class) != null - private fun TypeResult.Container.generate(): CodeBlock { + private fun TypeResult.SetContainer.generate(): CodeBlock { return CodeBlock.builder().apply { - add("$creator(") - add("\n⇥") - args.forEachIndexed { index, arg -> - if (index != 0) { - add(",\n") + beginControlFlow("buildSet(%L)", args.size) + withIndent { + for (arg in args) { + if (arg.key.type.isSet()) { + add("addAll(%L)\n", arg.generate()) + } else { + add("add(%L)\n", arg.generate()) + } } - add(arg.generate()) } - add("\n⇤") - add(")") + add("}") + }.build() + } + + private fun TypeResult.MapContainer.generate(): CodeBlock { + return CodeBlock.builder().apply { + beginControlFlow("buildMap(%L)", args.size) + withIndent { + for (arg in args) { + add("this += %L\n", arg.generate()) + } + } + add("}") }.build() } diff --git a/kotlin-inject-compiler/core/src/main/kotlin/me/tatarka/inject/compiler/TypeResultResolver.kt b/kotlin-inject-compiler/core/src/main/kotlin/me/tatarka/inject/compiler/TypeResultResolver.kt index 618e131f..ce483050 100644 --- a/kotlin-inject-compiler/core/src/main/kotlin/me/tatarka/inject/compiler/TypeResultResolver.kt +++ b/kotlin-inject-compiler/core/src/main/kotlin/me/tatarka/inject/compiler/TypeResultResolver.kt @@ -285,8 +285,7 @@ class TypeResultResolver(private val provider: AstProvider, private val options: val containerKey = ContainerKey.SetKey(innerType, key.qualifier) val args = types.containerArgs(containerKey) if (args.isNotEmpty()) { - return Container( - creator = containerKey.creator, + return SetContainer( args = args, mapArg = { key, arg, types -> Provides(withTypes(types), arg.accessor, arg.method, arg.scope, key) @@ -298,8 +297,15 @@ class TypeResultResolver(private val provider: AstProvider, private val options: val containerKey = ContainerKey.SetKey(innerType.arguments.last(), key.qualifier) val args = types.containerArgs(containerKey) if (args.isEmpty()) return null - return Container( - creator = containerKey.creator, + val intoSetArgs = args.filter { (member, _) -> member.method.isIntoSetMultiple() } + if (intoSetArgs.isNotEmpty()) { + for (arg in intoSetArgs) { + provider.error("Cannot use @IntoSet(multiple = true) with a Set", arg.first.method) + } + return null + } + + return SetContainer( args = args, mapArg = { key, arg, types -> Function(withTypes(types), args = innerType.arguments.dropLast(1)) { context -> @@ -313,8 +319,15 @@ class TypeResultResolver(private val provider: AstProvider, private val options: val containerKey = ContainerKey.SetKey(innerType.arguments[0], key.qualifier) val args = types.containerArgs(containerKey) if (args.isEmpty()) return null - return Container( - creator = containerKey.creator, + val intoSetArgs = args.filter { (member, _) -> member.method.isIntoSetMultiple() } + if (intoSetArgs.isNotEmpty()) { + for (arg in intoSetArgs) { + provider.error("Cannot use @IntoSet(multiple = true) with a Set", arg.first.method) + } + return null + } + + return SetContainer( args = args, mapArg = { key, arg, types -> Lazy(key) { @@ -331,8 +344,7 @@ class TypeResultResolver(private val provider: AstProvider, private val options: val containerKey = ContainerKey.MapKey(type.arguments[0], type.arguments[1], key.qualifier) val args = types.containerArgs(containerKey) if (args.isEmpty()) return null - return Container( - creator = containerKey.creator, + return MapContainer( args = args, mapArg = { key, arg, types -> Provides(withTypes(types), arg.accessor, arg.method, arg.scope, key) @@ -522,20 +534,28 @@ class TypeResultResolver(private val provider: AstProvider, private val options: ) } - private fun Container( - creator: String, + private fun SetContainer( args: List>, mapArg: (TypeKey, Member, TypeCollector.Result) -> TypeResult, ): TypeResult { - return TypeResult.Container( - creator = creator, - args = args.map { (arg, types) -> - val returnType = arg.method.returnType - val qualifier = qualifier(provider, options, arg.method, returnType) - val key = TypeKey(returnType, qualifier) - TypeResultRef(key, mapArg(key, arg, types)) - } - ) + return TypeResult.SetContainer(args = containerArgs(args, mapArg)) + } + + private fun MapContainer( + args: List>, + mapArg: (TypeKey, Member, TypeCollector.Result) -> TypeResult, + ): TypeResult { + return TypeResult.MapContainer(args = containerArgs(args, mapArg)) + } + + private fun containerArgs( + args: List>, + mapArg: (TypeKey, Member, TypeCollector.Result) -> TypeResult, + ) = args.map { (arg, types) -> + val returnType = arg.method.returnType + val qualifier = qualifier(provider, options, arg.method, returnType) + val key = TypeKey(returnType, qualifier) + TypeResultRef(key, mapArg(key, arg, types)) } private inline fun Function( diff --git a/kotlin-inject-runtime/src/commonMain/kotlin/me/tatarka/inject/annotations/Annotations.kt b/kotlin-inject-runtime/src/commonMain/kotlin/me/tatarka/inject/annotations/Annotations.kt index 93655447..dde8dc69 100644 --- a/kotlin-inject-runtime/src/commonMain/kotlin/me/tatarka/inject/annotations/Annotations.kt +++ b/kotlin-inject-runtime/src/commonMain/kotlin/me/tatarka/inject/annotations/Annotations.kt @@ -20,10 +20,10 @@ annotation class Provides annotation class Scope @Target(FUNCTION, PROPERTY_GETTER) -annotation class IntoSet +annotation class IntoSet(val multiple: Boolean = false) @Target(FUNCTION, PROPERTY_GETTER) -annotation class IntoMap +annotation class IntoMap(val multiple: Boolean = false) @Target(VALUE_PARAMETER) annotation class Assisted From a606d4a897bb7d50b5c69387a7d99e7a7a4ef203 Mon Sep 17 00:00:00 2001 From: Scott Jasso Date: Wed, 19 Feb 2025 01:19:39 -0800 Subject: [PATCH 2/3] Fix some tests --- .../kotlin/me/tatarka/inject/test/MultibindsTest.kt | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/integration-tests/common-test/src/test/kotlin/me/tatarka/inject/test/MultibindsTest.kt b/integration-tests/common-test/src/test/kotlin/me/tatarka/inject/test/MultibindsTest.kt index df81ccde..8cc80a5a 100644 --- a/integration-tests/common-test/src/test/kotlin/me/tatarka/inject/test/MultibindsTest.kt +++ b/integration-tests/common-test/src/test/kotlin/me/tatarka/inject/test/MultibindsTest.kt @@ -2,6 +2,7 @@ package me.tatarka.inject.test import assertk.assertThat import assertk.assertions.containsOnly +import assertk.assertions.isEmpty import kotlin.test.Test class MultibindsTest { @@ -15,13 +16,22 @@ class MultibindsTest { assertThat(component.lazyItems.map { it.value }).containsOnly(FooValue("1"), FooValue("2")) } + @Test + fun generates_a_component_that_provides_an_empty_set_by_default() { + val component = EmptySetComponent::class.create() + + assertThat(component.items).isEmpty() + } + @Test fun generates_a_component_that_provides_multiple_items_into_a_map() { val component = DynamicKeyComponent::class.create() assertThat(component.items).containsOnly( "1" to FooValue("1"), - "2" to FooValue("2") + "2" to FooValue("2"), + "3" to FooValue("3"), + "4" to FooValue("4"), ) } From 7e3f40c9d9ca1e63b1803c950e11f9e5a959d14d Mon Sep 17 00:00:00 2001 From: Scott Jasso Date: Wed, 19 Feb 2025 22:52:51 -0800 Subject: [PATCH 3/3] Fix support for Set>, fix annotations for properties --- .../kotlin/me/tatarka/kotlin/ast/KSAst.kt | 6 + .../me/tatarka/inject/test/MultibindsTest.kt | 3 +- .../me/tatarka/inject/test/Multibinds.kt | 13 +- .../tatarka/inject/compiler/TypeCollector.kt | 166 ++++++++++-------- .../me/tatarka/inject/compiler/TypeResult.kt | 1 + .../inject/compiler/TypeResultGenerator.kt | 3 +- .../inject/compiler/TypeResultResolver.kt | 83 ++++----- 7 files changed, 147 insertions(+), 128 deletions(-) diff --git a/ast/ksp/src/main/kotlin/me/tatarka/kotlin/ast/KSAst.kt b/ast/ksp/src/main/kotlin/me/tatarka/kotlin/ast/KSAst.kt index 7f1fccdc..9e460f6a 100644 --- a/ast/ksp/src/main/kotlin/me/tatarka/kotlin/ast/KSAst.kt +++ b/ast/ksp/src/main/kotlin/me/tatarka/kotlin/ast/KSAst.kt @@ -426,6 +426,12 @@ private class KSAstProperty(override val resolver: Resolver, override val declar return declaration.getter?.hasAnnotation(packageName, simpleName) == true } + override fun annotation(packageName: String, simpleName: String): AstAnnotation? { + return declaration.getter?.annotations(packageName, simpleName) + ?.firstOrNull() + ?.let { KSAstAnnotation(resolver, it) } + } + override fun annotationsAnnotatedWith(packageName: String, simpleName: String): Sequence { val declarationAnnotations = super.annotationsAnnotatedWith(packageName, simpleName) val getter = declaration.getter diff --git a/integration-tests/common-test/src/test/kotlin/me/tatarka/inject/test/MultibindsTest.kt b/integration-tests/common-test/src/test/kotlin/me/tatarka/inject/test/MultibindsTest.kt index 8cc80a5a..f60028d6 100644 --- a/integration-tests/common-test/src/test/kotlin/me/tatarka/inject/test/MultibindsTest.kt +++ b/integration-tests/common-test/src/test/kotlin/me/tatarka/inject/test/MultibindsTest.kt @@ -17,10 +17,11 @@ class MultibindsTest { } @Test - fun generates_a_component_that_provides_an_empty_set_by_default() { + fun generates_a_component_that_provides_an_empty_set() { val component = EmptySetComponent::class.create() assertThat(component.items).isEmpty() + assertThat(component.nestedItems).containsOnly(emptySet()) } @Test diff --git a/integration-tests/common/src/main/kotlin/me/tatarka/inject/test/Multibinds.kt b/integration-tests/common/src/main/kotlin/me/tatarka/inject/test/Multibinds.kt index dcb01149..701b0d6d 100644 --- a/integration-tests/common/src/main/kotlin/me/tatarka/inject/test/Multibinds.kt +++ b/integration-tests/common/src/main/kotlin/me/tatarka/inject/test/Multibinds.kt @@ -28,9 +28,15 @@ abstract class SetComponent { abstract class EmptySetComponent { abstract val items: Set + abstract val nestedItems: Set> + @Provides @IntoSet(multiple = true) - fun defaultEmptySet() = emptySet() + fun provideDefaultEmptySetUsingMultiple() = emptySet() + + @Provides + @IntoSet + fun provideEmptySetIntoSet() = emptySet() } @Component @@ -48,6 +54,11 @@ abstract class DynamicKeyComponent { @Provides @IntoMap(multiple = true) fun fooValue3And4() = mapOf("3" to FooValue("3"), "4" to FooValue("4")) + + val emptyFooValue + @Provides + @IntoMap(multiple = true) + get() = emptyMap() } @Component diff --git a/kotlin-inject-compiler/core/src/main/kotlin/me/tatarka/inject/compiler/TypeCollector.kt b/kotlin-inject-compiler/core/src/main/kotlin/me/tatarka/inject/compiler/TypeCollector.kt index 82e3c41c..e5555b6d 100644 --- a/kotlin-inject-compiler/core/src/main/kotlin/me/tatarka/inject/compiler/TypeCollector.kt +++ b/kotlin-inject-compiler/core/src/main/kotlin/me/tatarka/inject/compiler/TypeCollector.kt @@ -32,7 +32,7 @@ class TypeCollector(private val provider: AstProvider, private val options: Opti private val types = mutableMapOf() // Map of container types to inject. Used for multibinding. - private val containerTypes = mutableMapOf>() + private val containerTypes = mutableMapOf>() // Map of types obtained from generated provider methods. This can be used for lookup when the underlying method // is not available (ex: because we only see an interface, or it's marked protected). @@ -66,8 +66,8 @@ class TypeCollector(private val provider: AstProvider, private val options: Opti return null } - fun containerArgs(key: ContainerKey): List> { - val results = mutableListOf>() + fun containerArgs(key: ContainerKey): List> { + val results = mutableListOf>() for (result in iterator()) { val types = result.containerTypes[key] if (types != null) { @@ -114,69 +114,12 @@ class TypeCollector(private val provider: AstProvider, private val options: Opti } } val scopedComponent = if (scope != null) astClass else null - if (member.isIntoMap()) { - val returnType = member.returnTypeFor(astClass) - val key = TypeKey(returnType, qualifier) - val resolvedType = returnType.resolvedType() - - val containerKey = if (member.isIntoMapMultiple()) { - // Map -> Map - if (resolvedType.isMap()) { - ContainerKey.fromContainer(key) - } else { - provider.error("@IntoMap(multiple = true) must have return type of type Map", member) - null - } - } else { - // Pair -> Map - if (resolvedType.isPair()) { - ContainerKey.MapKey( - resolvedType.arguments[0], - resolvedType.arguments[1], - key.qualifier - ) - } else { - provider.error("@IntoMap(multiple = false) must have return type of type Pair", member) - null - } - } - - if (containerKey != null) { - addContainerType(provider, key, containerKey, member, accessor, scope, scopedComponent) - } - } else if (member.isIntoSet()) { - val returnType = member.returnTypeFor(astClass) - val key = TypeKey(returnType, qualifier) - - val containerKey = if (member.isIntoSetMultiple()) { - // Set -> Set - val resolvedType = returnType.resolvedType() - if (resolvedType.isSet()) { - ContainerKey.fromContainer(key) - } else { - provider.error("@IntoSet(multiple = true) must have return type of type Set", member) - null - } - } else { - // A -> Set - ContainerKey.SetKey(returnType, key.qualifier) - } - - if (containerKey != null) { - addContainerType(provider, key, containerKey, member, accessor, scope, scopedComponent) - } - } else { - val returnType = member.returnTypeFor(astClass) - val key = TypeKey(returnType, qualifier) - if (accessor.isNotEmpty()) { - // May have already added from a resolvable provider - if (providerTypes.containsKey(key)) continue - // We out outside the current class, so complain if not accessible - if (member.visibility == AstVisibility.PROTECTED) { - provider.error("@Provides method is not accessible", member) - } - } - addMethod(key, member, accessor, scope, scopedComponent) + val returnType = member.returnTypeFor(astClass) + val key = TypeKey(returnType, qualifier) + when { + member.isIntoMap() -> collectIntoMapProvider(key, member, accessor, scope) + member.isIntoSet() -> collectIntoSetProvider(key, member, accessor, scope) + else -> collectProvider(key, member, accessor, scope, scopedComponent) } } @@ -245,6 +188,76 @@ class TypeCollector(private val provider: AstProvider, private val options: Opti } } + private fun collectProvider( + key: TypeKey, + member: AstMember, + accessor: Accessor, + scope: AstAnnotation?, + scopedComponent: AstClass?, + ) { + if (accessor.isNotEmpty()) { + // May have already added from a resolvable provider + if (providerTypes.containsKey(key)) return + // We out outside the current class, so complain if not accessible + if (member.visibility == AstVisibility.PROTECTED) { + provider.error("@Provides method is not accessible", member) + } + } + addMethod(key, member, accessor, scope, scopedComponent) + } + + private fun collectIntoSetProvider( + key: TypeKey, + member: AstMember, + accessor: Accessor, + scope: AstAnnotation?, + ) { + val isMultiple = member.isIntoSetMultiple() + val containerKey = if (isMultiple) { + // Set -> Set + val resolvedType = key.type.resolvedType() + if (!resolvedType.isSet()) { + provider.error("@IntoSet(multiple = true) must have return type of type Set", member) + return + } + + ContainerKey.SetKey(resolvedType.arguments[0], key.qualifier) + } else { + // A -> Set + ContainerKey.SetKey(key.type, key.qualifier) + } + addContainerType(provider, key, containerKey, member, accessor, scope, isMultiple) + } + + private fun collectIntoMapProvider( + key: TypeKey, + member: AstMember, + accessor: Accessor, + scope: AstAnnotation?, + ) { + val resolvedType = key.type.resolvedType() + val isMultiple = member.isIntoMapMultiple() + if (isMultiple) { + // Map -> Map + if (!resolvedType.isMap()) { + provider.error("@IntoMap(multiple = true) must have return type of type Map", member) + return + } + } else { + // Pair -> Map + if (!resolvedType.isPair()) { + provider.error("@IntoMap(multiple = false) must have return type of type Pair", member) + return + } + } + val containerKey = ContainerKey.MapKey( + resolvedType.arguments[0], + resolvedType.arguments[1], + key.qualifier + ) + addContainerType(provider, key, containerKey, member, accessor, scope, isMultiple) + } + private fun checkDuplicateTypesBetweenResults( result1Types: Map, result1ContainerTypes: Map, @@ -272,7 +285,7 @@ class TypeCollector(private val provider: AstProvider, private val options: Opti method: AstMember, accessor: Accessor, scope: AstAnnotation?, - scopedComponent: AstClass?, + isMultiple: Boolean ) { val current = type(containerKey.containerTypeKey(provider)) if (current != null) { @@ -281,7 +294,7 @@ class TypeCollector(private val provider: AstProvider, private val options: Opti } containerTypes.getOrPut(containerKey) { mutableListOf() } - .add(method(method, accessor, scope, scopedComponent)) + .add(ContainerMember(method, accessor, scope, isMultiple)) } private fun addMethod( @@ -306,7 +319,7 @@ class TypeCollector(private val provider: AstProvider, private val options: Opti } } - types[key] = method(method, accessor, scope, scopedComponent) + types[key] = Member(method, accessor, scope, scopedComponent) } private fun addProviderMethod(key: TypeKey, member: AstMember, accessor: Accessor) { @@ -316,14 +329,6 @@ class TypeCollector(private val provider: AstProvider, private val options: Opti } } - private fun method(method: AstMember, accessor: Accessor, scope: AstAnnotation?, scopedComponent: AstClass?) = - Member( - method = method, - accessor = accessor, - scope = scope, - scopedComponent = scopedComponent - ) - private fun duplicate(key: TypeKey, newValue: AstElement, oldValue: AstElement) { provider.error("Cannot provide: $key", newValue) provider.error("as it is already provided", oldValue) @@ -533,6 +538,13 @@ class Member( val scopedComponent: AstClass? = null, ) +class ContainerMember( + val method: AstMember, + val accessor: Accessor, + val scope: AstAnnotation?, + val isMultiple: Boolean, +) + sealed class ContainerKey { abstract fun containerTypeKey(provider: AstProvider): TypeKey diff --git a/kotlin-inject-compiler/core/src/main/kotlin/me/tatarka/inject/compiler/TypeResult.kt b/kotlin-inject-compiler/core/src/main/kotlin/me/tatarka/inject/compiler/TypeResult.kt index 70f5699c..f535e2eb 100644 --- a/kotlin-inject-compiler/core/src/main/kotlin/me/tatarka/inject/compiler/TypeResult.kt +++ b/kotlin-inject-compiler/core/src/main/kotlin/me/tatarka/inject/compiler/TypeResult.kt @@ -42,6 +42,7 @@ sealed class TypeResult { val accessor: Accessor = Accessor.Empty, val receiver: TypeResultRef? = null, val isProperty: Boolean = false, + val isMultiple: Boolean = false, val parameters: Map = emptyMap(), ) : TypeResult() { override val children diff --git a/kotlin-inject-compiler/core/src/main/kotlin/me/tatarka/inject/compiler/TypeResultGenerator.kt b/kotlin-inject-compiler/core/src/main/kotlin/me/tatarka/inject/compiler/TypeResultGenerator.kt index a14367db..94e64541 100644 --- a/kotlin-inject-compiler/core/src/main/kotlin/me/tatarka/inject/compiler/TypeResultGenerator.kt +++ b/kotlin-inject-compiler/core/src/main/kotlin/me/tatarka/inject/compiler/TypeResultGenerator.kt @@ -252,7 +252,8 @@ data class TypeResultGenerator(val options: Options, val implicitAccessor: Acces beginControlFlow("buildSet(%L)", args.size) withIndent { for (arg in args) { - if (arg.key.type.isSet()) { + val result = arg.result + if (result is TypeResult.Provides && result.isMultiple) { add("addAll(%L)\n", arg.generate()) } else { add("add(%L)\n", arg.generate()) diff --git a/kotlin-inject-compiler/core/src/main/kotlin/me/tatarka/inject/compiler/TypeResultResolver.kt b/kotlin-inject-compiler/core/src/main/kotlin/me/tatarka/inject/compiler/TypeResultResolver.kt index ce483050..18c171b7 100644 --- a/kotlin-inject-compiler/core/src/main/kotlin/me/tatarka/inject/compiler/TypeResultResolver.kt +++ b/kotlin-inject-compiler/core/src/main/kotlin/me/tatarka/inject/compiler/TypeResultResolver.kt @@ -285,71 +285,56 @@ class TypeResultResolver(private val provider: AstProvider, private val options: val containerKey = ContainerKey.SetKey(innerType, key.qualifier) val args = types.containerArgs(containerKey) if (args.isNotEmpty()) { - return SetContainer( - args = args, - mapArg = { key, arg, types -> - Provides(withTypes(types), arg.accessor, arg.method, arg.scope, key) - } - ) + return SetContainer(args) { key, arg, types -> + Provides(withTypes(types), arg.accessor, arg.method, arg.scope, key, isMultiple = arg.isMultiple) + } } if (innerType.isFunction()) { val containerKey = ContainerKey.SetKey(innerType.arguments.last(), key.qualifier) val args = types.containerArgs(containerKey) if (args.isEmpty()) return null - val intoSetArgs = args.filter { (member, _) -> member.method.isIntoSetMultiple() } - if (intoSetArgs.isNotEmpty()) { - for (arg in intoSetArgs) { - provider.error("Cannot use @IntoSet(multiple = true) with a Set", arg.first.method) - } - return null - } + addErrorForMultipleProvider(key, args) - return SetContainer( - args = args, - mapArg = { key, arg, types -> - Function(withTypes(types), args = innerType.arguments.dropLast(1)) { context -> - TypeResultRef(key, Provides(context, arg.accessor, arg.method, arg.scope, key)) - } + return SetContainer(args) { key, arg, types -> + Function(withTypes(types), args = innerType.arguments.dropLast(1)) { context -> + TypeResultRef(key, Provides(context, arg.accessor, arg.method, arg.scope, key)) } - ) + } } if (innerType.isLazy()) { val containerKey = ContainerKey.SetKey(innerType.arguments[0], key.qualifier) val args = types.containerArgs(containerKey) if (args.isEmpty()) return null - val intoSetArgs = args.filter { (member, _) -> member.method.isIntoSetMultiple() } - if (intoSetArgs.isNotEmpty()) { - for (arg in intoSetArgs) { - provider.error("Cannot use @IntoSet(multiple = true) with a Set", arg.first.method) - } - return null - } + addErrorForMultipleProvider(key, args) - return SetContainer( - args = args, - mapArg = { key, arg, types -> - Lazy(key) { - TypeResultRef(key, Provides(withTypes(types), arg.accessor, arg.method, arg.scope, key)) - } + return SetContainer(args) { key, arg, types -> + Lazy(key) { + TypeResultRef(key, Provides(withTypes(types), arg.accessor, arg.method, arg.scope, key)) } - ) + } } + return null } + private fun addErrorForMultipleProvider(key: TypeKey, args: List>) { + args.forEach { (member, _) -> + if (member.method.isIntoSetMultiple()) { + provider.error("Cannot use @IntoSet(multiple = true) with a ${key.type} binding", member.method) + } + } + } + private fun Context.map(key: TypeKey): TypeResult? { val type = key.type.resolvedType() val containerKey = ContainerKey.MapKey(type.arguments[0], type.arguments[1], key.qualifier) val args = types.containerArgs(containerKey) if (args.isEmpty()) return null - return MapContainer( - args = args, - mapArg = { key, arg, types -> - Provides(withTypes(types), arg.accessor, arg.method, arg.scope, key) - } - ) + return MapContainer(args) { key, arg, types -> + Provides(withTypes(types), arg.accessor, arg.method, arg.scope, key, isMultiple = arg.isMultiple) + } } private fun Context.assistedFactory(astClass: AstClass, key: TypeKey): TypeResult? { @@ -482,6 +467,7 @@ class TypeResultResolver(private val provider: AstProvider, private val options: method: AstMember, scope: AstAnnotation?, key: TypeKey, + isMultiple: Boolean = false ): TypeResult { if (scope != null && method is AstFunction && method.isSuspend) { throw FailedToGenerateException( @@ -499,6 +485,7 @@ class TypeResultResolver(private val provider: AstProvider, private val options: resolve(context, method, key) }, isProperty = method is AstProperty, + isMultiple = isMultiple, parameters = (method as? AstFunction)?.let { resolveParams(context, method, scope, it.parameters) } ?: emptyMap(), @@ -535,22 +522,22 @@ class TypeResultResolver(private val provider: AstProvider, private val options: } private fun SetContainer( - args: List>, - mapArg: (TypeKey, Member, TypeCollector.Result) -> TypeResult, + args: List>, + mapArg: (TypeKey, ContainerMember, TypeCollector.Result) -> TypeResult, ): TypeResult { - return TypeResult.SetContainer(args = containerArgs(args, mapArg)) + return TypeResult.SetContainer(containerArgs(args, mapArg)) } private fun MapContainer( - args: List>, - mapArg: (TypeKey, Member, TypeCollector.Result) -> TypeResult, + args: List>, + mapArg: (TypeKey, ContainerMember, TypeCollector.Result) -> TypeResult, ): TypeResult { - return TypeResult.MapContainer(args = containerArgs(args, mapArg)) + return TypeResult.MapContainer(containerArgs(args, mapArg)) } private fun containerArgs( - args: List>, - mapArg: (TypeKey, Member, TypeCollector.Result) -> TypeResult, + args: List>, + mapArg: (TypeKey, ContainerMember, TypeCollector.Result) -> TypeResult, ) = args.map { (arg, types) -> val returnType = arg.method.returnType val qualifier = qualifier(provider, options, arg.method, returnType)