Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions ast/ksp/src/main/kotlin/me/tatarka/kotlin/ast/KSAst.kt
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,12 @@ private class KSAstProperty(override val resolver: Resolver, override val declar
return declaration.getter?.hasAnnotation(packageName, simpleName) == true
}

override fun annotation(packageName: String, simpleName: String): AstAnnotation? {
return declaration.getter?.annotations(packageName, simpleName)
?.firstOrNull()
?.let { KSAstAnnotation(resolver, it) }
}

override fun annotationsAnnotatedWith(packageName: String, simpleName: String): Sequence<AstAnnotation> {
val declarationAnnotations = super.annotationsAnnotatedWith(packageName, simpleName)
val getter = declaration.getter
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package me.tatarka.inject.test

import assertk.assertThat
import assertk.assertions.containsOnly
import assertk.assertions.isEmpty
import kotlin.test.Test

class MultibindsTest {
Expand All @@ -15,13 +16,23 @@ class MultibindsTest {
assertThat(component.lazyItems.map { it.value }).containsOnly(FooValue("1"), FooValue("2"))
}

@Test
fun generates_a_component_that_provides_an_empty_set() {
val component = EmptySetComponent::class.create()

assertThat(component.items).isEmpty()
assertThat(component.nestedItems).containsOnly(emptySet<FooValue>())
}

@Test
fun generates_a_component_that_provides_multiple_items_into_a_map() {
val component = DynamicKeyComponent::class.create()

assertThat(component.items).containsOnly(
"1" to FooValue("1"),
"2" to FooValue("2")
"2" to FooValue("2"),
"3" to FooValue("3"),
"4" to FooValue("4"),
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,21 @@ abstract class SetComponent {
@Provides @IntoSet get() = FooValue("2")
}

@Component
abstract class EmptySetComponent {
abstract val items: Set<FooValue>

abstract val nestedItems: Set<Set<FooValue>>

@Provides
@IntoSet(multiple = true)
fun provideDefaultEmptySetUsingMultiple() = emptySet<FooValue>()

@Provides
@IntoSet
fun provideEmptySetIntoSet() = emptySet<FooValue>()
}

@Component
abstract class DynamicKeyComponent {

Expand All @@ -35,6 +50,15 @@ abstract class DynamicKeyComponent {

val fooValue2
@Provides @IntoMap get() = "2" to FooValue("2")

@Provides
@IntoMap(multiple = true)
fun fooValue3And4() = mapOf("3" to FooValue("3"), "4" to FooValue("4"))

val emptyFooValue
@Provides
@IntoMap(multiple = true)
get() = emptyMap<String, FooValue>()
}

@Component
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ val SCOPE = ClassName(ANNOTATION_PACKAGE_NAME, "Scope")
val INJECT = ClassName(ANNOTATION_PACKAGE_NAME, "Inject")
val INTO_MAP = ClassName(ANNOTATION_PACKAGE_NAME, "IntoMap")
val INTO_SET = ClassName(ANNOTATION_PACKAGE_NAME, "IntoSet")
const val ANNOTATION_MULTIPLE_ARG = "multiple"
val ASSISTED = ClassName(ANNOTATION_PACKAGE_NAME, "Assisted")
val ASSISTED_FACTORY = ClassName(ANNOTATION_PACKAGE_NAME, "AssistedFactory")
const val ASSISTED_FACTORY_FUNCTION_ARG = "injectFunction"
Expand Down Expand Up @@ -227,6 +228,19 @@ fun AstAnnotated.assistedFactoryFunctionName() =
annotation(ASSISTED_FACTORY.packageName, ASSISTED_FACTORY.simpleName)
?.argument(ASSISTED_FACTORY_FUNCTION_ARG) as? String

fun AstAnnotated.isIntoMap() = hasAnnotation(INTO_MAP)

fun AstAnnotated.isIntoSet() = hasAnnotation(INTO_SET)

private fun AstAnnotated.multipleArgValue(annotationClass: ClassName) =
annotation(annotationClass.packageName, annotationClass.simpleName)
?.argument(ANNOTATION_MULTIPLE_ARG) as Boolean?
?: false

fun AstAnnotated.isIntoMapMultiple() = multipleArgValue(INTO_MAP)

fun AstAnnotated.isIntoSetMultiple() = multipleArgValue(INTO_SET)

fun AstClass.findInjectConstructors(messenger: Messenger, options: Options): AstConstructor? {
val injectCtors = constructors.filter {
if (options.enableJavaxAnnotations) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class TypeCollector(private val provider: AstProvider, private val options: Opti
private val types = mutableMapOf<TypeKey, Member>()

// Map of container types to inject. Used for multibinding.
private val containerTypes = mutableMapOf<ContainerKey, MutableList<Member>>()
private val containerTypes = mutableMapOf<ContainerKey, MutableList<ContainerMember>>()

// Map of types obtained from generated provider methods. This can be used for lookup when the underlying method
// is not available (ex: because we only see an interface, or it's marked protected).
Expand Down Expand Up @@ -66,8 +66,8 @@ class TypeCollector(private val provider: AstProvider, private val options: Opti
return null
}

fun containerArgs(key: ContainerKey): List<Pair<Member, Result>> {
val results = mutableListOf<Pair<Member, Result>>()
fun containerArgs(key: ContainerKey): List<Pair<ContainerMember, Result>> {
val results = mutableListOf<Pair<ContainerMember, Result>>()
for (result in iterator()) {
val types = result.containerTypes[key]
if (types != null) {
Expand Down Expand Up @@ -114,39 +114,12 @@ class TypeCollector(private val provider: AstProvider, private val options: Opti
}
}
val scopedComponent = if (scope != null) astClass else null
if (member.hasAnnotation(INTO_MAP.packageName, INTO_MAP.simpleName)) {
// Pair<A, B> -> Map<A, B>
val returnType = member.returnTypeFor(astClass)
val key = TypeKey(returnType, qualifier)
val resolvedType = returnType.resolvedType()
if (resolvedType.isPair()) {
val containerKey = ContainerKey.MapKey(
resolvedType.arguments[0],
resolvedType.arguments[1],
key.qualifier
)
addContainerType(provider, key, containerKey, member, accessor, scope, scopedComponent)
} else {
provider.error("@IntoMap must have return type of type Pair", member)
}
} else if (member.hasAnnotation(INTO_SET.packageName, INTO_SET.simpleName)) {
// A -> Set<A>
val returnType = member.returnTypeFor(astClass)
val key = TypeKey(returnType, qualifier)
val containerKey = ContainerKey.SetKey(returnType, key.qualifier)
addContainerType(provider, key, containerKey, member, accessor, scope, scopedComponent)
} else {
val returnType = member.returnTypeFor(astClass)
val key = TypeKey(returnType, qualifier)
if (accessor.isNotEmpty()) {
// May have already added from a resolvable provider
if (providerTypes.containsKey(key)) continue
// We out outside the current class, so complain if not accessible
if (member.visibility == AstVisibility.PROTECTED) {
provider.error("@Provides method is not accessible", member)
}
}
addMethod(key, member, accessor, scope, scopedComponent)
val returnType = member.returnTypeFor(astClass)
val key = TypeKey(returnType, qualifier)
when {
member.isIntoMap() -> collectIntoMapProvider(key, member, accessor, scope)
member.isIntoSet() -> collectIntoSetProvider(key, member, accessor, scope)
else -> collectProvider(key, member, accessor, scope, scopedComponent)
}
}

Expand Down Expand Up @@ -215,6 +188,76 @@ class TypeCollector(private val provider: AstProvider, private val options: Opti
}
}

private fun collectProvider(
key: TypeKey,
member: AstMember,
accessor: Accessor,
scope: AstAnnotation?,
scopedComponent: AstClass?,
) {
if (accessor.isNotEmpty()) {
// May have already added from a resolvable provider
if (providerTypes.containsKey(key)) return
// We out outside the current class, so complain if not accessible
if (member.visibility == AstVisibility.PROTECTED) {
provider.error("@Provides method is not accessible", member)
}
}
addMethod(key, member, accessor, scope, scopedComponent)
}

private fun collectIntoSetProvider(
key: TypeKey,
member: AstMember,
accessor: Accessor,
scope: AstAnnotation?,
) {
val isMultiple = member.isIntoSetMultiple()
val containerKey = if (isMultiple) {
// Set<A> -> Set<A>
val resolvedType = key.type.resolvedType()
if (!resolvedType.isSet()) {
provider.error("@IntoSet(multiple = true) must have return type of type Set", member)
return
}

ContainerKey.SetKey(resolvedType.arguments[0], key.qualifier)
} else {
// A -> Set<A>
ContainerKey.SetKey(key.type, key.qualifier)
}
addContainerType(provider, key, containerKey, member, accessor, scope, isMultiple)
}

private fun collectIntoMapProvider(
key: TypeKey,
member: AstMember,
accessor: Accessor,
scope: AstAnnotation?,
) {
val resolvedType = key.type.resolvedType()
val isMultiple = member.isIntoMapMultiple()
if (isMultiple) {
// Map<A, B> -> Map<A, B>
if (!resolvedType.isMap()) {
provider.error("@IntoMap(multiple = true) must have return type of type Map", member)
return
}
} else {
// Pair<A, B> -> Map<A, B>
if (!resolvedType.isPair()) {
provider.error("@IntoMap(multiple = false) must have return type of type Pair", member)
return
}
}
val containerKey = ContainerKey.MapKey(
resolvedType.arguments[0],
resolvedType.arguments[1],
key.qualifier
)
addContainerType(provider, key, containerKey, member, accessor, scope, isMultiple)
}

private fun checkDuplicateTypesBetweenResults(
result1Types: Map<TypeKey, AstMember>,
result1ContainerTypes: Map<TypeKey, AstMember>,
Expand Down Expand Up @@ -242,7 +285,7 @@ class TypeCollector(private val provider: AstProvider, private val options: Opti
method: AstMember,
accessor: Accessor,
scope: AstAnnotation?,
scopedComponent: AstClass?,
isMultiple: Boolean
) {
val current = type(containerKey.containerTypeKey(provider))
if (current != null) {
Expand All @@ -251,7 +294,7 @@ class TypeCollector(private val provider: AstProvider, private val options: Opti
}

containerTypes.getOrPut(containerKey) { mutableListOf() }
.add(method(method, accessor, scope, scopedComponent))
.add(ContainerMember(method, accessor, scope, isMultiple))
}

private fun addMethod(
Expand All @@ -276,7 +319,7 @@ class TypeCollector(private val provider: AstProvider, private val options: Opti
}
}

types[key] = method(method, accessor, scope, scopedComponent)
types[key] = Member(method, accessor, scope, scopedComponent)
}

private fun addProviderMethod(key: TypeKey, member: AstMember, accessor: Accessor) {
Expand All @@ -286,14 +329,6 @@ class TypeCollector(private val provider: AstProvider, private val options: Opti
}
}

private fun method(method: AstMember, accessor: Accessor, scope: AstAnnotation?, scopedComponent: AstClass?) =
Member(
method = method,
accessor = accessor,
scope = scope,
scopedComponent = scopedComponent
)

private fun duplicate(key: TypeKey, newValue: AstElement, oldValue: AstElement) {
provider.error("Cannot provide: $key", newValue)
provider.error("as it is already provided", oldValue)
Expand Down Expand Up @@ -503,21 +538,23 @@ class Member(
val scopedComponent: AstClass? = null,
)

class ContainerMember(
val method: AstMember,
val accessor: Accessor,
val scope: AstAnnotation?,
val isMultiple: Boolean,
)

sealed class ContainerKey {
abstract val creator: String
abstract fun containerTypeKey(provider: AstProvider): TypeKey

data class SetKey(val type: AstType, val qualifier: AstAnnotation? = null) : ContainerKey() {
override val creator: String = "setOf"

override fun containerTypeKey(provider: AstProvider): TypeKey {
return TypeKey(provider.declaredTypeOf(Set::class, type), qualifier)
}
}

data class MapKey(val key: AstType, val value: AstType, val qualifier: AstAnnotation? = null) : ContainerKey() {
override val creator: String = "mapOf"

override fun containerTypeKey(provider: AstProvider): TypeKey {
return TypeKey(provider.declaredTypeOf(Map::class, key, value), qualifier)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ sealed class TypeResult {
val accessor: Accessor = Accessor.Empty,
val receiver: TypeResultRef? = null,
val isProperty: Boolean = false,
val isMultiple: Boolean = false,
val parameters: Map<String, TypeResultRef> = emptyMap(),
) : TypeResult() {
override val children
Expand Down Expand Up @@ -80,11 +81,14 @@ sealed class TypeResult {
}
}

/**
* A container that holds the type (ex: Set or Map).
*/
class Container(
val creator: String,
class SetContainer(
val args: List<TypeResultRef>,
) : TypeResult() {
override val children
get() = args.iterator()
}

class MapContainer(
val args: List<TypeResultRef>,
) : TypeResult() {
override val children
Expand Down
Loading