Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions build.mill
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import com.github.lolgab.mill.mima.worker
import com.github.lolgab.mill.mima.worker.MimaWorkerExternalModule

import scala.util.chaining._
import os.Path

val scala212 = "2.12.18"
val scala213 = "2.13.12"
Expand Down Expand Up @@ -56,6 +57,24 @@ trait UnrollModule extends Cross.Module[String]{
}
}

object negative extends ScalaModule {
override def moduleDir: Path =
if (UnrollModule.this.crossValue.startsWith("2.")) (super.moduleDir / "scala2") else (super.moduleDir / "scala3")

def moduleDeps: Seq[JavaModule] = Seq(annotation)

def scalaVersion = UnrollModule.this.crossValue

def scalacPluginClasspath = Task{ Seq(plugin.jar()) }

def scalacOptions = Task {
Seq(
s"-Xplugin:${plugin.jar().path}",
"-Xplugin-require:unroll",
)
}
}

object testutils extends InnerScalaModule

val testcases = Seq(
Expand All @@ -69,11 +88,8 @@ trait UnrollModule extends Cross.Module[String]{
"secondaryConstructor",
"caseclass",
"secondParameterList",
"abstractTraitMethod",
"abstractClassMethod"
)


object tests extends Cross[Tests](testcases)

trait Tests extends Cross.Module[String]{
Expand Down Expand Up @@ -280,4 +296,3 @@ trait LocalMimaModule extends ScalaModule{
else mill.api.Result.Success(())
}
}

31 changes: 31 additions & 0 deletions unroll/negative/scala2/src/invalidunroll.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package unroll

import com.lihaoyi.unroll

class Foo {
def foo(int: Int, @unroll str: String = "1") = int.toString() + str

def fooInner(int: Int) = {
def inner(a: Int, @unroll str: String = "1") = {
def anEvenMoreInnerMethUnrolled(@unroll int: Int = 1) = ()
str
}
def innerNonUnrollMethod(str: String) = str
}
}

abstract class FooAbstractClass {
def foo(s: String, @unroll n: Int = 1): String
}

trait FooTrait {
def foo(s: String, @unroll n: Int = 1): String
}

case class FooCaseClass(int: Int, str: String) {
final def copy(int: Int = int, @unroll str: String = str): FooCaseClass = FooCaseClass(int, str)
}

object FooCaseClass {
def apply(int: Int, @unroll str: String): FooCaseClass = FooCaseClass(int, str)
}
31 changes: 31 additions & 0 deletions unroll/negative/scala3/src/invalidunroll.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package unroll

import com.lihaoyi.unroll

class Foo {
def foo(int: Int, @unroll str: String = "1") = int.toString() + str

def fooInner(int: Int) = {
def inner(a: Int, @unroll str: String = "1") = {
def anEvenMoreInnerMethUnrolled(@unroll int: Int = 1) = ()
str
}
def innerNonUnrollMethod(str: String) = str
}
}

abstract class FooAbstractClass {
def foo(s: String, @unroll n: Int = 1): String
}

trait FooTrait(@unroll val param: Int = 1, @unroll val param2: String = "asd") {
def foo(s: String, @unroll n: Int = 1): String
}

case class FooCaseClass(int: Int, str: String) {
final def copy(int: Int = int, @unroll str: String = str): FooCaseClass = FooCaseClass(int, str)
}

object FooCaseClass {
def apply(int: Int, @unroll str: String): FooCaseClass = FooCaseClass(int, str)
}
51 changes: 49 additions & 2 deletions unroll/plugin/src-2/UnrollPhaseScala2.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,44 @@ class UnrollPhaseScala2(val global: Global) extends PluginComponent with TypingT
new UnrollTransformer(unit)
}

private def isValidUnrolledMethod(method: Symbol, origin: Position) = {
val isCtor = method.isConstructor

def explanation = {
def what = if (isCtor) s"a class constructor" else s"method ${method.name}"
val prefix = s"Cannot unroll parameters of $what"
if (isLocalMethod(method))
s"$prefix because it is a local method"
else if (!method.isEffectivelyFinal && !method.owner.isEffectivelyFinal)
s"$prefix because it can be overridden"
else if (method.owner.companionClass.isCaseClass)
s"$prefix of a case class companion object: please annotate the class constructor instead"
else
s"$prefix of a case class: please annotate the class constructor instead"
}

if (isCtor) true
else if (isLocalMethod(method)
|| !method.isEffectivelyFinal && !method.owner.isEffectivelyFinal
|| method.owner.companionClass.isCaseClass && method.name == nme.apply
|| method.owner.isCaseClass && method.name == nme.copy) {
globalError(origin, explanation)
false
} else true
}

def findUnrollAnnotations(params: Seq[Symbol]): Seq[Int] = {
params.toList.zipWithIndex.collect {
case (v, i) if v.annotations.exists(_.tpe =:= typeOf[com.lihaoyi.unroll]) => i
case (v, i) if hasUnrollAnnotation(v) && isValidUnrolledMethod(v.owner, v.pos) => i
}
}

private def hasUnrollAnnotation(symbol: Symbol) =
symbol.annotations.exists(_.tpe =:= typeOf[com.lihaoyi.unroll])

private def methodHasUnroll(method: Symbol) =
method.paramss.exists(_.exists(hasUnrollAnnotation))

def copyValDef(vd: ValDef) = {
val newMods = vd.mods.copy(flags = vd.mods.flags ^ Flags.DEFAULTPARAM)
newStrictTreeCopier.ValDef(vd, newMods, vd.name, vd.tpt, EmptyTree)
Expand Down Expand Up @@ -157,11 +189,27 @@ class UnrollPhaseScala2(val global: Global) extends PluginComponent with TypingT
forwarderDef.substituteSymbols(fromSyms, toSyms).asInstanceOf[DefDef]
}

private object LintInnerMethods extends Traverser {
override def traverse(tree: Tree): Unit = {
tree match {
case method: DefDef =>
if (methodHasUnroll(method.symbol)) isValidUnrolledMethod(method.symbol, method.pos)
traverseChildren(method.rhs)
case _ => ()
}
}

final def traverseChildren(tree: Tree): Unit = super.traverse(tree)
}

private def isLocalMethod(sym: Symbol): Boolean = {
sym.ownerChain.tail.exists(_.isMethod) || sym.isLocalToBlock
}

class UnrollTransformer(unit: global.CompilationUnit) extends TypingTransformer(unit) {
def generateDefForwarders(implDef: ImplDef): List[(Option[Symbol], Seq[DefDef])] = {
implDef.impl.body.collect{ case defdef: DefDef =>
LintInnerMethods.traverseChildren(defdef.rhs)

val annotatedOpt =
if (defdef.symbol.isCaseCopy && defdef.symbol.name.toString == "copy") {
Expand Down Expand Up @@ -280,4 +328,3 @@ class UnrollPhaseScala2(val global: Global) extends PluginComponent with TypingT
}
}
}

70 changes: 67 additions & 3 deletions unroll/plugin/src-3/UnrollPhaseScala3.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import dotty.tools.dotc.core.Types.{MethodType, NamedType, PolyType, Type}
import dotty.tools.dotc.core.Symbols

import scala.language.implicitConversions
import dotty.tools.dotc.util.SrcPos

class UnrollPhaseScala3() extends PluginPhase {
import tpd._
Expand All @@ -40,14 +41,63 @@ class UnrollPhaseScala3() extends PluginPhase {
)
}

def findUnrollAnnotations(params: List[Symbol])(using Context): List[Int] = {
/**
* Adapted from:
* - https://github.com/scala/scala3/pull/21693/files#diff-367e32065f51ed46353fdaaf526ab7e7899404b219d24d5b1054fce7f376dfe5R127
* - https://github.com/scala/scala3/pull/21693/files#diff-e054695755ff26925ae51361df0f7cd4940bc1fd7ceb658023d0dc38c18178c3R3412
* - https://github.com/scala/scala3/pull/22926/files#diff-367e32065f51ed46353fdaaf526ab7e7899404b219d24d5b1054fce7f376dfe5R135
*/
private def isValidUnrolledMethod(method: Symbol, origin: SrcPos)(using Context) = {
val isCtor = method.isConstructor

def explanation =
def what = if isCtor then i"a ${if method.owner.is(Trait) then "trait" else "class"} constructor" else i"method ${method.name}"
val prefix = s"Cannot unroll parameters of $what"
if method.isLocal then
i"$prefix because it is a local method"
else if !method.isEffectivelyFinal then
i"$prefix because it can be overridden"
else if isCtor && method.owner.is(Trait) then
i"implementation restriction: $prefix"
else if method.owner.companionClass.is(CaseClass) then
i"$prefix of a case class companion object: please annotate the class constructor instead"
else
i"$prefix of a case class: please annotate the class constructor instead"

if method.name.is(DefaultGetterName) then false
else if method.isLocal
|| !method.isEffectivelyFinal
|| isCtor && method.owner.is(Trait)
|| method.owner.companionClass.is(CaseClass) && (method.name == nme.apply || method.name == nme.fromProduct)
|| method.owner.is(CaseClass) && method.name == nme.copy then
report.error(explanation, origin)
false
else true
}


// if `annotatedMethod` is a case class apply, case class copy or case class fromProduct then `params` actually comes from the case class' primary constructor
// but we still need to check the actual annotated method hence the two-step check
def findUnrollAnnotations(params: List[Symbol], specialCasedMethod: Option[Symbol])(using Context): List[Int] = {

specialCasedMethod.foreach { annotatedMethod =>
if methodHasUnroll(annotatedMethod) then isValidUnrolledMethod(annotatedMethod, annotatedMethod.sourcePos)
}

params
.zipWithIndex
.collect {
case (v, i) if v.annotations.exists(_.symbol.fullName.toString == "com.lihaoyi.unroll") =>
case (v, i) if hasUnrollAnnotation(v) && isValidUnrolledMethod(v.owner, v.sourcePos) =>
i
}
}

private def methodHasUnroll(method: Symbol)(using Context) =
method.paramSymss.exists(_.exists(hasUnrollAnnotation))

private def hasUnrollAnnotation(sym: Symbol)(using Context) =
sym.annotations.exists(_.symbol.fullName.toString == "com.lihaoyi.unroll")

def isTypeClause(p: ParamClause) = p.headOption.exists(_.isInstanceOf[TypeDef])
def generateSingleForwarder(defdef: DefDef,
prevMethodType: Type,
Expand Down Expand Up @@ -186,6 +236,8 @@ class UnrollPhaseScala3() extends PluginPhase {
case defdef: DefDef if defdef.paramss.nonEmpty =>
import dotty.tools.dotc.core.NameOps.isConstructorName

LintInnerMethods.traverseChildren(defdef.rhs)

val isCaseCopy =
defdef.name.toString == "copy" && defdef.symbol.owner.is(CaseClass)

Expand All @@ -200,12 +252,13 @@ class UnrollPhaseScala3() extends PluginPhase {
else if (isCaseFromProduct) defdef.symbol.owner.companionClass.primaryConstructor
else defdef.symbol

val specialCasedMethod = Option.when(isCaseCopy || isCaseApply || isCaseFromProduct)(defdef.symbol)

annotated
.paramSymss
.zipWithIndex
.flatMap{case (paramClause, paramClauseIndex) =>
val annotationIndices = findUnrollAnnotations(paramClause)
val annotationIndices = findUnrollAnnotations(paramClause, specialCasedMethod)
if (annotationIndices.isEmpty) None
else Some((paramClauseIndex, annotationIndices))
} match{
Expand Down Expand Up @@ -262,6 +315,17 @@ class UnrollPhaseScala3() extends PluginPhase {
case _ => (None, Nil)
}

private object LintInnerMethods extends TreeTraverser {
override def traverse(tree: Tree)(using Context): Unit =
tree match {
case method: DefDef =>
if methodHasUnroll(method.symbol) then isValidUnrolledMethod(method.symbol, method.sourcePos)
traverseChildren(method.rhs)
case _ => ()
}
override def traverseChildren(tree: Tree)(using Context): Unit = super.traverseChildren(tree)
}

override def transformTemplate(tmpl: tpd.Template)(using Context): tpd.Tree = {

val (removed0, generatedDefs) = tmpl.body.map(generateSyntheticDefs).unzip
Expand Down

This file was deleted.

6 changes: 0 additions & 6 deletions unroll/tests/abstractClassMethod/v1/src/Unrolled.scala

This file was deleted.

16 changes: 0 additions & 16 deletions unroll/tests/abstractClassMethod/v1/test/src/UnrollTestMain.scala

This file was deleted.

This file was deleted.

8 changes: 0 additions & 8 deletions unroll/tests/abstractClassMethod/v2/src/Unrolled.scala

This file was deleted.

Loading