Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package me.tatarka.inject.test

import assertk.assertThat
import assertk.assertions.containsOnly
import assertk.assertions.isEmpty
import kotlin.test.Test

class MultibindsTest {
Expand All @@ -15,13 +16,22 @@ class MultibindsTest {
assertThat(component.lazyItems.map { it.value }).containsOnly(FooValue("1"), FooValue("2"))
}

@Test
fun generates_a_component_that_provides_an_empty_set_by_default() {
val component = EmptySetComponent::class.create()

assertThat(component.items).isEmpty()
}

@Test
fun generates_a_component_that_provides_multiple_items_into_a_map() {
val component = DynamicKeyComponent::class.create()

assertThat(component.items).containsOnly(
"1" to FooValue("1"),
"2" to FooValue("2")
"2" to FooValue("2"),
"3" to FooValue("3"),
"4" to FooValue("4"),
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,15 @@ abstract class SetComponent {
@Provides @IntoSet get() = FooValue("2")
}

@Component
abstract class EmptySetComponent {
abstract val items: Set<FooValue>

@Provides
@IntoSet(multiple = true)
fun defaultEmptySet() = emptySet<FooValue>()
}

@Component
abstract class DynamicKeyComponent {

Expand All @@ -35,6 +44,10 @@ abstract class DynamicKeyComponent {

val fooValue2
@Provides @IntoMap get() = "2" to FooValue("2")

@Provides
@IntoMap(multiple = true)
fun fooValue3And4() = mapOf("3" to FooValue("3"), "4" to FooValue("4"))
}

@Component
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ val SCOPE = ClassName(ANNOTATION_PACKAGE_NAME, "Scope")
val INJECT = ClassName(ANNOTATION_PACKAGE_NAME, "Inject")
val INTO_MAP = ClassName(ANNOTATION_PACKAGE_NAME, "IntoMap")
val INTO_SET = ClassName(ANNOTATION_PACKAGE_NAME, "IntoSet")
const val ANNOTATION_MULTIPLE_ARG = "multiple"
val ASSISTED = ClassName(ANNOTATION_PACKAGE_NAME, "Assisted")
val ASSISTED_FACTORY = ClassName(ANNOTATION_PACKAGE_NAME, "AssistedFactory")
const val ASSISTED_FACTORY_FUNCTION_ARG = "injectFunction"
Expand Down Expand Up @@ -227,6 +228,19 @@ fun AstAnnotated.assistedFactoryFunctionName() =
annotation(ASSISTED_FACTORY.packageName, ASSISTED_FACTORY.simpleName)
?.argument(ASSISTED_FACTORY_FUNCTION_ARG) as? String

fun AstAnnotated.isIntoMap() = hasAnnotation(INTO_MAP)

fun AstAnnotated.isIntoSet() = hasAnnotation(INTO_SET)

private fun AstAnnotated.multipleArgValue(annotationClass: ClassName) =
annotation(annotationClass.packageName, annotationClass.simpleName)
?.argument(ANNOTATION_MULTIPLE_ARG) as Boolean?
?: false

fun AstAnnotated.isIntoMapMultiple() = multipleArgValue(INTO_MAP)

fun AstAnnotated.isIntoSetMultiple() = multipleArgValue(INTO_SET)

fun AstClass.findInjectConstructors(messenger: Messenger, options: Options): AstConstructor? {
val injectCtors = constructors.filter {
if (options.enableJavaxAnnotations) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,27 +114,57 @@ class TypeCollector(private val provider: AstProvider, private val options: Opti
}
}
val scopedComponent = if (scope != null) astClass else null
if (member.hasAnnotation(INTO_MAP.packageName, INTO_MAP.simpleName)) {
// Pair<A, B> -> Map<A, B>
if (member.isIntoMap()) {
val returnType = member.returnTypeFor(astClass)
val key = TypeKey(returnType, qualifier)
val resolvedType = returnType.resolvedType()
if (resolvedType.isPair()) {
val containerKey = ContainerKey.MapKey(
resolvedType.arguments[0],
resolvedType.arguments[1],
key.qualifier
)
addContainerType(provider, key, containerKey, member, accessor, scope, scopedComponent)

val containerKey = if (member.isIntoMapMultiple()) {
// Map<A, B> -> Map<A, B>
if (resolvedType.isMap()) {
ContainerKey.fromContainer(key)
} else {
provider.error("@IntoMap(multiple = true) must have return type of type Map", member)
null
}
} else {
provider.error("@IntoMap must have return type of type Pair", member)
// Pair<A, B> -> Map<A, B>
if (resolvedType.isPair()) {
ContainerKey.MapKey(
resolvedType.arguments[0],
resolvedType.arguments[1],
key.qualifier
)
} else {
provider.error("@IntoMap(multiple = false) must have return type of type Pair", member)
null
}
}

if (containerKey != null) {
addContainerType(provider, key, containerKey, member, accessor, scope, scopedComponent)
}
} else if (member.hasAnnotation(INTO_SET.packageName, INTO_SET.simpleName)) {
// A -> Set<A>
} else if (member.isIntoSet()) {
val returnType = member.returnTypeFor(astClass)
val key = TypeKey(returnType, qualifier)
val containerKey = ContainerKey.SetKey(returnType, key.qualifier)
addContainerType(provider, key, containerKey, member, accessor, scope, scopedComponent)

val containerKey = if (member.isIntoSetMultiple()) {
// Set<A> -> Set<A>
val resolvedType = returnType.resolvedType()
if (resolvedType.isSet()) {
ContainerKey.fromContainer(key)
} else {
provider.error("@IntoSet(multiple = true) must have return type of type Set", member)
null
}
} else {
// A -> Set<A>
ContainerKey.SetKey(returnType, key.qualifier)
}

if (containerKey != null) {
addContainerType(provider, key, containerKey, member, accessor, scope, scopedComponent)
}
} else {
val returnType = member.returnTypeFor(astClass)
val key = TypeKey(returnType, qualifier)
Expand Down Expand Up @@ -504,20 +534,15 @@ class Member(
)

sealed class ContainerKey {
abstract val creator: String
abstract fun containerTypeKey(provider: AstProvider): TypeKey

data class SetKey(val type: AstType, val qualifier: AstAnnotation? = null) : ContainerKey() {
override val creator: String = "setOf"

override fun containerTypeKey(provider: AstProvider): TypeKey {
return TypeKey(provider.declaredTypeOf(Set::class, type), qualifier)
}
}

data class MapKey(val key: AstType, val value: AstType, val qualifier: AstAnnotation? = null) : ContainerKey() {
override val creator: String = "mapOf"

override fun containerTypeKey(provider: AstProvider): TypeKey {
return TypeKey(provider.declaredTypeOf(Map::class, key, value), qualifier)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,14 @@ sealed class TypeResult {
}
}

/**
* A container that holds the type (ex: Set or Map).
*/
class Container(
val creator: String,
class SetContainer(
val args: List<TypeResultRef>,
) : TypeResult() {
override val children
get() = args.iterator()
}

class MapContainer(
val args: List<TypeResultRef>,
) : TypeResult() {
override val children
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import com.squareup.kotlinpoet.PropertySpec
import com.squareup.kotlinpoet.TypeName
import com.squareup.kotlinpoet.TypeSpec
import com.squareup.kotlinpoet.tags.TypeAliasTag
import com.squareup.kotlinpoet.withIndent

data class TypeResultGenerator(val options: Options, val implicitAccessor: Accessor = Accessor.Empty) {

Expand Down Expand Up @@ -68,7 +69,8 @@ data class TypeResultGenerator(val options: Options, val implicitAccessor: Acces
is TypeResult.Provides -> generate()
is TypeResult.Scoped -> generate()
is TypeResult.Constructor -> generate()
is TypeResult.Container -> generate()
is TypeResult.SetContainer -> generate()
is TypeResult.MapContainer -> generate()
is TypeResult.Function -> generate()
is TypeResult.NamedFunction -> generate()
is TypeResult.Object -> generate()
Expand Down Expand Up @@ -245,18 +247,31 @@ data class TypeResultGenerator(val options: Options, val implicitAccessor: Acces
private val TypeName.isTypeAlias: Boolean
get() = tag(TypeAliasTag::class) != null

private fun TypeResult.Container.generate(): CodeBlock {
private fun TypeResult.SetContainer.generate(): CodeBlock {
return CodeBlock.builder().apply {
add("$creator(")
add("\n⇥")
args.forEachIndexed { index, arg ->
if (index != 0) {
add(",\n")
beginControlFlow("buildSet(%L)", args.size)
withIndent {
for (arg in args) {
if (arg.key.type.isSet()) {
add("addAll(%L)\n", arg.generate())
} else {
add("add(%L)\n", arg.generate())
}
}
add(arg.generate())
}
add("\n⇤")
add(")")
add("}")
}.build()
}

private fun TypeResult.MapContainer.generate(): CodeBlock {
return CodeBlock.builder().apply {
beginControlFlow("buildMap(%L)", args.size)
withIndent {
for (arg in args) {
add("this += %L\n", arg.generate())
}
}
add("}")
}.build()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -285,8 +285,7 @@ class TypeResultResolver(private val provider: AstProvider, private val options:
val containerKey = ContainerKey.SetKey(innerType, key.qualifier)
val args = types.containerArgs(containerKey)
if (args.isNotEmpty()) {
return Container(
creator = containerKey.creator,
return SetContainer(
args = args,
mapArg = { key, arg, types ->
Provides(withTypes(types), arg.accessor, arg.method, arg.scope, key)
Expand All @@ -298,8 +297,15 @@ class TypeResultResolver(private val provider: AstProvider, private val options:
val containerKey = ContainerKey.SetKey(innerType.arguments.last(), key.qualifier)
val args = types.containerArgs(containerKey)
if (args.isEmpty()) return null
return Container(
creator = containerKey.creator,
val intoSetArgs = args.filter { (member, _) -> member.method.isIntoSetMultiple() }
if (intoSetArgs.isNotEmpty()) {
for (arg in intoSetArgs) {
provider.error("Cannot use @IntoSet(multiple = true) with a Set<Function>", arg.first.method)
}
return null
}

return SetContainer(
args = args,
mapArg = { key, arg, types ->
Function(withTypes(types), args = innerType.arguments.dropLast(1)) { context ->
Expand All @@ -313,8 +319,15 @@ class TypeResultResolver(private val provider: AstProvider, private val options:
val containerKey = ContainerKey.SetKey(innerType.arguments[0], key.qualifier)
val args = types.containerArgs(containerKey)
if (args.isEmpty()) return null
return Container(
creator = containerKey.creator,
val intoSetArgs = args.filter { (member, _) -> member.method.isIntoSetMultiple() }
if (intoSetArgs.isNotEmpty()) {
for (arg in intoSetArgs) {
provider.error("Cannot use @IntoSet(multiple = true) with a Set<Lazy>", arg.first.method)
}
return null
}

return SetContainer(
args = args,
mapArg = { key, arg, types ->
Lazy(key) {
Expand All @@ -331,8 +344,7 @@ class TypeResultResolver(private val provider: AstProvider, private val options:
val containerKey = ContainerKey.MapKey(type.arguments[0], type.arguments[1], key.qualifier)
val args = types.containerArgs(containerKey)
if (args.isEmpty()) return null
return Container(
creator = containerKey.creator,
return MapContainer(
args = args,
mapArg = { key, arg, types ->
Provides(withTypes(types), arg.accessor, arg.method, arg.scope, key)
Expand Down Expand Up @@ -522,20 +534,28 @@ class TypeResultResolver(private val provider: AstProvider, private val options:
)
}

private fun Container(
creator: String,
private fun SetContainer(
args: List<Pair<Member, TypeCollector.Result>>,
mapArg: (TypeKey, Member, TypeCollector.Result) -> TypeResult,
): TypeResult {
return TypeResult.Container(
creator = creator,
args = args.map { (arg, types) ->
val returnType = arg.method.returnType
val qualifier = qualifier(provider, options, arg.method, returnType)
val key = TypeKey(returnType, qualifier)
TypeResultRef(key, mapArg(key, arg, types))
}
)
return TypeResult.SetContainer(args = containerArgs(args, mapArg))
}

private fun MapContainer(
args: List<Pair<Member, TypeCollector.Result>>,
mapArg: (TypeKey, Member, TypeCollector.Result) -> TypeResult,
): TypeResult {
return TypeResult.MapContainer(args = containerArgs(args, mapArg))
}

private fun containerArgs(
args: List<Pair<Member, TypeCollector.Result>>,
mapArg: (TypeKey, Member, TypeCollector.Result) -> TypeResult,
) = args.map { (arg, types) ->
val returnType = arg.method.returnType
val qualifier = qualifier(provider, options, arg.method, returnType)
val key = TypeKey(returnType, qualifier)
TypeResultRef(key, mapArg(key, arg, types))
}

private inline fun Function(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ annotation class Provides
annotation class Scope

@Target(FUNCTION, PROPERTY_GETTER)
annotation class IntoSet
annotation class IntoSet(val multiple: Boolean = false)

@Target(FUNCTION, PROPERTY_GETTER)
annotation class IntoMap
annotation class IntoMap(val multiple: Boolean = false)

@Target(VALUE_PARAMETER)
annotation class Assisted
Expand Down