diff --git a/common/src/main/scala/org/mockito/JavaReflectionUtils.scala b/common/src/main/scala/org/mockito/JavaReflectionUtils.scala new file mode 100644 index 00000000..0c40c822 --- /dev/null +++ b/common/src/main/scala/org/mockito/JavaReflectionUtils.scala @@ -0,0 +1,53 @@ +package org.mockito + +import org.mockito.invocation.InvocationOnMock +import ru.vyarus.java.generics.resolver.GenericsResolver + +import java.lang.reflect.Field +import scala.util.control.NonFatal + +/** + * Utility methods for Java reflection operations, particularly for Mockito mocks. + */ +object JavaReflectionUtils { + + def resolveWithJavaGenerics(invocation: InvocationOnMock): Option[Class[_]] = + try Some(GenericsResolver.resolve(invocation.getMock.getClass).`type`(invocation.method.getDeclaringClass).method(invocation.method).resolveReturnClass()) + catch { + case _: Throwable => None + } + + def setFinalStatic(field: Field, newValue: AnyRef): Unit = + try { + // Try to get Unsafe instance (works with both sun.misc.Unsafe and jdk.internal.misc.Unsafe) + val unsafeClass: Class[_] = + try + Class.forName("sun.misc.Unsafe") + catch { + case _: ClassNotFoundException => Class.forName("jdk.internal.misc.Unsafe") + } + + val unsafeField = unsafeClass.getDeclaredField("theUnsafe") + unsafeField.setAccessible(true) + val unsafe = unsafeField.get(null) + + // Get methods via reflection to handle both Unsafe implementations + val staticFieldBaseMethod = unsafeClass.getMethod("staticFieldBase", classOf[Field]) + val staticFieldOffsetMethod = unsafeClass.getMethod("staticFieldOffset", classOf[Field]) + val putObjectMethod = unsafeClass.getMethod("putObject", classOf[Object], classOf[Long], classOf[Object]) + + // Make the field accessible + field.setAccessible(true) + + // Get base and offset for the field + val base: Object = staticFieldBaseMethod.invoke(unsafe, field) + val offset: Long = staticFieldOffsetMethod.invoke(unsafe, field).asInstanceOf[Long] + + // Set the field value directly + putObjectMethod.invoke(unsafe, base, java.lang.Long.valueOf(offset), newValue) + } catch { + case NonFatal(e) => + throw new IllegalStateException(s"Cannot modify final field ${field.getName}", e) + } + +} diff --git a/common/src/main/scala/org/mockito/MockitoAPI.scala b/common/src/main/scala/org/mockito/MockitoAPI.scala index e42ff084..1e88f210 100644 --- a/common/src/main/scala/org/mockito/MockitoAPI.scala +++ b/common/src/main/scala/org/mockito/MockitoAPI.scala @@ -12,7 +12,6 @@ package org.mockito import org.mockito.Answers.CALLS_REAL_METHODS -import org.mockito.ReflectionUtils.InvocationOnMockOps import org.mockito.internal.configuration.plugins.Plugins.getMockMaker import org.mockito.internal.creation.MockSettingsImpl import org.mockito.internal.exceptions.Reporter.notAMockPassedToVerifyNoMoreInteractions @@ -453,7 +452,6 @@ private[mockito] trait DoSomething { } private[mockito] trait MockitoEnhancer extends MockCreator { - implicit val invocationOps: InvocationOnMock => InvocationOnMockOps = InvocationOps /** * Delegates to Mockito.mock(type: Class[T]) It provides a nicer API as you can, for instance, do mock[MyClass] instead of @@ -630,9 +628,9 @@ private[mockito] trait MockitoEnhancer extends MockCreator { (settings: MockCreationSettings[O], pt: Prettifier) => ThreadAwareMockHandler(settings, realImpl)(pt) ) - ReflectionUtils.setFinalStatic(moduleField, threadAwareMock) + JavaReflectionUtils.setFinalStatic(moduleField, threadAwareMock) try block - finally ReflectionUtils.setFinalStatic(moduleField, realImpl) + finally JavaReflectionUtils.setFinalStatic(moduleField, realImpl) } } } diff --git a/common/src/main/scala/org/mockito/ReflectionUtils.scala b/common/src/main/scala/org/mockito/ReflectionUtils.scala index 2b80ab41..4443dd2c 100644 --- a/common/src/main/scala/org/mockito/ReflectionUtils.scala +++ b/common/src/main/scala/org/mockito/ReflectionUtils.scala @@ -1,16 +1,13 @@ package org.mockito -import java.lang.reflect.{ Field, Method, Modifier } - -import org.mockito.internal.ValueClassWrapper +import org.mockito.JavaReflectionUtils.resolveWithJavaGenerics import org.mockito.invocation.InvocationOnMock import org.scalactic.TripleEquals._ -import ru.vyarus.java.generics.resolver.GenericsResolver +import java.lang.reflect.Method import scala.reflect.ClassTag import scala.reflect.internal.Symbols import scala.util.{ Try => uTry } -import scala.util.control.NonFatal object ReflectionUtils { import scala.reflect.runtime.{ universe => ru } @@ -23,58 +20,37 @@ object ReflectionUtils { def methodToJava(sym: Symbols#MethodSymbol): Method }] - def listToTuple(l: List[Object]): Any = - l match { - case Nil => Nil - case h :: Nil => h - case _ => Class.forName(s"scala.Tuple${l.size}").getDeclaredConstructors.head.newInstance(l: _*) - } - - implicit class InvocationOnMockOps(val invocation: InvocationOnMock) extends AnyVal { - def mock[M]: M = invocation.getMock.asInstanceOf[M] - def method: Method = invocation.getMethod - def arg[A: ValueClassWrapper](index: Int): A = ValueClassWrapper[A].wrapAs[A](invocation.getArgument(index)) - def args: List[Any] = invocation.getArguments.toList - def callRealMethod[R](): R = invocation.callRealMethod.asInstanceOf[R] - def argsAsTuple: Any = listToTuple(args.map(_.asInstanceOf[Object])) - - def returnType: Class[_] = { - val javaReturnType = method.getReturnType + private[mockito] def returnType(invocation: InvocationOnMock): Class[_] = { + val javaReturnType = invocation.method.getReturnType - if (javaReturnType == classOf[Object]) - resolveWithScalaGenerics - .orElse(resolveWithJavaGenerics) - .getOrElse(javaReturnType) - else javaReturnType - } + if (javaReturnType == classOf[Object]) + resolveWithScalaGenerics(invocation) + .orElse(resolveWithJavaGenerics(invocation)) + .getOrElse(javaReturnType) + else javaReturnType + } - def returnsValueClass: Boolean = findTypeSymbol.exists(_.returnType.typeSymbol.isDerivedValueClass) + private[mockito] def returnsValueClass(invocation: InvocationOnMock): Boolean = + findTypeSymbol(invocation).exists(_.returnType.typeSymbol.isDerivedValueClass) - private def resolveWithScalaGenerics: Option[Class[_]] = - uTry { - findTypeSymbol - .filter(_.returnType.typeSymbol.isClass) - .map(_.asMethod.returnType.typeSymbol.asClass) - .map(mirror.runtimeClass) - }.toOption.flatten - - private def findTypeSymbol = - uTry { - mirror - .classSymbol(method.getDeclaringClass) - .info - .decls - .collectFirst { - case symbol if isNonConstructorMethod(symbol) && customMirror.methodToJava(symbol) === method => symbol - } - }.toOption.flatten + private def resolveWithScalaGenerics(invocation: InvocationOnMock): Option[Class[_]] = + uTry { + findTypeSymbol(invocation) + .filter(_.returnType.typeSymbol.isClass) + .map(_.asMethod.returnType.typeSymbol.asClass) + .map(mirror.runtimeClass) + }.toOption.flatten - private def resolveWithJavaGenerics: Option[Class[_]] = - try Some(GenericsResolver.resolve(invocation.getMock.getClass).`type`(method.getDeclaringClass).method(method).resolveReturnClass()) - catch { - case _: Throwable => None - } - } + private def findTypeSymbol(invocation: InvocationOnMock) = + uTry { + mirror + .classSymbol(invocation.method.getDeclaringClass) + .info + .decls + .collectFirst { + case symbol if isNonConstructorMethod(symbol) && customMirror.methodToJava(symbol) === invocation.method => symbol + } + }.toOption.flatten private def isNonConstructorMethod(d: ru.Symbol): Boolean = d.isMethod && !d.isConstructor @@ -113,37 +89,4 @@ object ReflectionUtils { .getOrElse(Seq.empty) } - def setFinalStatic(field: Field, newValue: AnyRef): Unit = - try { - // Try to get Unsafe instance (works with both sun.misc.Unsafe and jdk.internal.misc.Unsafe) - val unsafeClass: Class[_] = - try - Class.forName("sun.misc.Unsafe") - catch { - case _: ClassNotFoundException => Class.forName("jdk.internal.misc.Unsafe") - } - - val unsafeField = unsafeClass.getDeclaredField("theUnsafe") - unsafeField.setAccessible(true) - val unsafe = unsafeField.get(null) - - // Get methods via reflection to handle both Unsafe implementations - val staticFieldBaseMethod = unsafeClass.getMethod("staticFieldBase", classOf[Field]) - val staticFieldOffsetMethod = unsafeClass.getMethod("staticFieldOffset", classOf[Field]) - val putObjectMethod = unsafeClass.getMethod("putObject", classOf[Object], classOf[Long], classOf[Object]) - - // Make the field accessible - field.setAccessible(true) - - // Get base and offset for the field - val base: Object = staticFieldBaseMethod.invoke(unsafe, field) - val offset: Long = staticFieldOffsetMethod.invoke(unsafe, field).asInstanceOf[Long] - - // Set the field value directly - putObjectMethod.invoke(unsafe, base, java.lang.Long.valueOf(offset), newValue) - } catch { - case NonFatal(e) => - throw new IllegalStateException(s"Cannot modify final field ${field.getName}", e) - } - } diff --git a/common/src/main/scala/org/mockito/mockito.scala b/common/src/main/scala/org/mockito/mockito.scala index 5fa0bcbf..1335cdf1 100644 --- a/common/src/main/scala/org/mockito/mockito.scala +++ b/common/src/main/scala/org/mockito/mockito.scala @@ -2,7 +2,6 @@ package org import java.lang.reflect.Method -import org.mockito.ReflectionUtils.InvocationOnMockOps import org.mockito.internal.{ ValueClassExtractor, ValueClassWrapper } import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.ScalaAnswer @@ -21,7 +20,20 @@ package object mockito { def clazz[T](implicit classTag: ClassTag[T]): Class[T] = classTag.runtimeClass.asInstanceOf[Class[T]] - implicit val InvocationOps: InvocationOnMock => InvocationOnMockOps = new InvocationOnMockOps(_) + implicit class InvocationOnMockOps(val invocation: InvocationOnMock) { + def mock[M]: M = invocation.getMock.asInstanceOf[M] + def method: Method = invocation.getMethod + def arg[A: ValueClassWrapper](index: Int): A = ValueClassWrapper[A].wrapAs[A](invocation.getArgument(index)) + def args: List[Any] = invocation.getArguments.toList + def callRealMethod[R](): R = invocation.callRealMethod.asInstanceOf[R] + def argsAsTuple: Any = args.map(_.asInstanceOf[Object]) match { + case Nil => Nil + case h :: Nil => h + case l => Class.forName(s"scala.Tuple${l.size}").getDeclaredConstructors.head.newInstance(l: _*) + } + def returnType: Class[_] = ReflectionUtils.returnType(invocation) + def returnsValueClass: Boolean = ReflectionUtils.returnsValueClass(invocation) + } def invocationToAnswer[T: ValueClassExtractor](f: InvocationOnMock => T): ScalaAnswer[T] = ScalaAnswer.lift(f.andThen(ValueClassExtractor[T].extractAs[T]))