Skip to content

Commit 48a4d96

Browse files
authored
Merge pull request #15 from arainko/issue-12-unroll-on-finals
[Issue 12] `@unroll` on final or effectively final methods
2 parents 3db69e9 + 2d36ed1 commit 48a4d96

File tree

43 files changed

+211
-473
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+211
-473
lines changed

build.mill

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import com.github.lolgab.mill.mima.worker
1010
import com.github.lolgab.mill.mima.worker.MimaWorkerExternalModule
1111

1212
import scala.util.chaining._
13+
import os.Path
1314

1415
val scala212 = "2.12.18"
1516
val scala213 = "2.13.12"
@@ -56,6 +57,24 @@ trait UnrollModule extends Cross.Module[String]{
5657
}
5758
}
5859

60+
object negative extends ScalaModule {
61+
override def moduleDir: Path =
62+
if (UnrollModule.this.crossValue.startsWith("2.")) (super.moduleDir / "scala2") else (super.moduleDir / "scala3")
63+
64+
def moduleDeps: Seq[JavaModule] = Seq(annotation)
65+
66+
def scalaVersion = UnrollModule.this.crossValue
67+
68+
def scalacPluginClasspath = Task{ Seq(plugin.jar()) }
69+
70+
def scalacOptions = Task {
71+
Seq(
72+
s"-Xplugin:${plugin.jar().path}",
73+
"-Xplugin-require:unroll",
74+
)
75+
}
76+
}
77+
5978
object testutils extends InnerScalaModule
6079

6180
val testcases = Seq(
@@ -69,11 +88,8 @@ trait UnrollModule extends Cross.Module[String]{
6988
"secondaryConstructor",
7089
"caseclass",
7190
"secondParameterList",
72-
"abstractTraitMethod",
73-
"abstractClassMethod"
7491
)
7592

76-
7793
object tests extends Cross[Tests](testcases)
7894

7995
trait Tests extends Cross.Module[String]{
@@ -280,4 +296,3 @@ trait LocalMimaModule extends ScalaModule{
280296
else mill.api.Result.Success(())
281297
}
282298
}
283-
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package unroll
2+
3+
import com.lihaoyi.unroll
4+
5+
class Foo {
6+
def foo(int: Int, @unroll str: String = "1") = int.toString() + str
7+
8+
def fooInner(int: Int) = {
9+
def inner(a: Int, @unroll str: String = "1") = {
10+
def anEvenMoreInnerMethUnrolled(@unroll int: Int = 1) = ()
11+
str
12+
}
13+
def innerNonUnrollMethod(str: String) = str
14+
}
15+
}
16+
17+
abstract class FooAbstractClass {
18+
def foo(s: String, @unroll n: Int = 1): String
19+
}
20+
21+
trait FooTrait {
22+
def foo(s: String, @unroll n: Int = 1): String
23+
}
24+
25+
case class FooCaseClass(int: Int, str: String) {
26+
final def copy(int: Int = int, @unroll str: String = str): FooCaseClass = FooCaseClass(int, str)
27+
}
28+
29+
object FooCaseClass {
30+
def apply(int: Int, @unroll str: String): FooCaseClass = FooCaseClass(int, str)
31+
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package unroll
2+
3+
import com.lihaoyi.unroll
4+
5+
class Foo {
6+
def foo(int: Int, @unroll str: String = "1") = int.toString() + str
7+
8+
def fooInner(int: Int) = {
9+
def inner(a: Int, @unroll str: String = "1") = {
10+
def anEvenMoreInnerMethUnrolled(@unroll int: Int = 1) = ()
11+
str
12+
}
13+
def innerNonUnrollMethod(str: String) = str
14+
}
15+
}
16+
17+
abstract class FooAbstractClass {
18+
def foo(s: String, @unroll n: Int = 1): String
19+
}
20+
21+
trait FooTrait(@unroll val param: Int = 1, @unroll val param2: String = "asd") {
22+
def foo(s: String, @unroll n: Int = 1): String
23+
}
24+
25+
case class FooCaseClass(int: Int, str: String) {
26+
final def copy(int: Int = int, @unroll str: String = str): FooCaseClass = FooCaseClass(int, str)
27+
}
28+
29+
object FooCaseClass {
30+
def apply(int: Int, @unroll str: String): FooCaseClass = FooCaseClass(int, str)
31+
}

unroll/plugin/src-2/UnrollPhaseScala2.scala

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,44 @@ class UnrollPhaseScala2(val global: Global) extends PluginComponent with TypingT
1919
new UnrollTransformer(unit)
2020
}
2121

22+
private def isValidUnrolledMethod(method: Symbol, origin: Position) = {
23+
val isCtor = method.isConstructor
24+
25+
def explanation = {
26+
def what = if (isCtor) s"a class constructor" else s"method ${method.name}"
27+
val prefix = s"Cannot unroll parameters of $what"
28+
if (isLocalMethod(method))
29+
s"$prefix because it is a local method"
30+
else if (!method.isEffectivelyFinal && !method.owner.isEffectivelyFinal)
31+
s"$prefix because it can be overridden"
32+
else if (method.owner.companionClass.isCaseClass)
33+
s"$prefix of a case class companion object: please annotate the class constructor instead"
34+
else
35+
s"$prefix of a case class: please annotate the class constructor instead"
36+
}
37+
38+
if (isCtor) true
39+
else if (isLocalMethod(method)
40+
|| !method.isEffectivelyFinal && !method.owner.isEffectivelyFinal
41+
|| method.owner.companionClass.isCaseClass && method.name == nme.apply
42+
|| method.owner.isCaseClass && method.name == nme.copy) {
43+
globalError(origin, explanation)
44+
false
45+
} else true
46+
}
47+
2248
def findUnrollAnnotations(params: Seq[Symbol]): Seq[Int] = {
2349
params.toList.zipWithIndex.collect {
24-
case (v, i) if v.annotations.exists(_.tpe =:= typeOf[com.lihaoyi.unroll]) => i
50+
case (v, i) if hasUnrollAnnotation(v) && isValidUnrolledMethod(v.owner, v.pos) => i
2551
}
2652
}
2753

54+
private def hasUnrollAnnotation(symbol: Symbol) =
55+
symbol.annotations.exists(_.tpe =:= typeOf[com.lihaoyi.unroll])
56+
57+
private def methodHasUnroll(method: Symbol) =
58+
method.paramss.exists(_.exists(hasUnrollAnnotation))
59+
2860
def copyValDef(vd: ValDef) = {
2961
val newMods = vd.mods.copy(flags = vd.mods.flags ^ Flags.DEFAULTPARAM)
3062
newStrictTreeCopier.ValDef(vd, newMods, vd.name, vd.tpt, EmptyTree)
@@ -157,11 +189,27 @@ class UnrollPhaseScala2(val global: Global) extends PluginComponent with TypingT
157189
forwarderDef.substituteSymbols(fromSyms, toSyms).asInstanceOf[DefDef]
158190
}
159191

192+
private object LintInnerMethods extends Traverser {
193+
override def traverse(tree: Tree): Unit = {
194+
tree match {
195+
case method: DefDef =>
196+
if (methodHasUnroll(method.symbol)) isValidUnrolledMethod(method.symbol, method.pos)
197+
traverseChildren(method.rhs)
198+
case _ => ()
199+
}
200+
}
201+
202+
final def traverseChildren(tree: Tree): Unit = super.traverse(tree)
203+
}
160204

205+
private def isLocalMethod(sym: Symbol): Boolean = {
206+
sym.ownerChain.tail.exists(_.isMethod) || sym.isLocalToBlock
207+
}
161208

162209
class UnrollTransformer(unit: global.CompilationUnit) extends TypingTransformer(unit) {
163210
def generateDefForwarders(implDef: ImplDef): List[(Option[Symbol], Seq[DefDef])] = {
164211
implDef.impl.body.collect{ case defdef: DefDef =>
212+
LintInnerMethods.traverseChildren(defdef.rhs)
165213

166214
val annotatedOpt =
167215
if (defdef.symbol.isCaseCopy && defdef.symbol.name.toString == "copy") {
@@ -280,4 +328,3 @@ class UnrollPhaseScala2(val global: Global) extends PluginComponent with TypingT
280328
}
281329
}
282330
}
283-

unroll/plugin/src-3/UnrollPhaseScala3.scala

Lines changed: 67 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import dotty.tools.dotc.core.Types.{MethodType, NamedType, PolyType, Type}
1818
import dotty.tools.dotc.core.Symbols
1919

2020
import scala.language.implicitConversions
21+
import dotty.tools.dotc.util.SrcPos
2122

2223
class UnrollPhaseScala3() extends PluginPhase {
2324
import tpd._
@@ -40,14 +41,63 @@ class UnrollPhaseScala3() extends PluginPhase {
4041
)
4142
}
4243

43-
def findUnrollAnnotations(params: List[Symbol])(using Context): List[Int] = {
44+
/**
45+
* Adapted from:
46+
* - https://github.com/scala/scala3/pull/21693/files#diff-367e32065f51ed46353fdaaf526ab7e7899404b219d24d5b1054fce7f376dfe5R127
47+
* - https://github.com/scala/scala3/pull/21693/files#diff-e054695755ff26925ae51361df0f7cd4940bc1fd7ceb658023d0dc38c18178c3R3412
48+
* - https://github.com/scala/scala3/pull/22926/files#diff-367e32065f51ed46353fdaaf526ab7e7899404b219d24d5b1054fce7f376dfe5R135
49+
*/
50+
private def isValidUnrolledMethod(method: Symbol, origin: SrcPos)(using Context) = {
51+
val isCtor = method.isConstructor
52+
53+
def explanation =
54+
def what = if isCtor then i"a ${if method.owner.is(Trait) then "trait" else "class"} constructor" else i"method ${method.name}"
55+
val prefix = s"Cannot unroll parameters of $what"
56+
if method.isLocal then
57+
i"$prefix because it is a local method"
58+
else if !method.isEffectivelyFinal then
59+
i"$prefix because it can be overridden"
60+
else if isCtor && method.owner.is(Trait) then
61+
i"implementation restriction: $prefix"
62+
else if method.owner.companionClass.is(CaseClass) then
63+
i"$prefix of a case class companion object: please annotate the class constructor instead"
64+
else
65+
i"$prefix of a case class: please annotate the class constructor instead"
66+
67+
if method.name.is(DefaultGetterName) then false
68+
else if method.isLocal
69+
|| !method.isEffectivelyFinal
70+
|| isCtor && method.owner.is(Trait)
71+
|| method.owner.companionClass.is(CaseClass) && (method.name == nme.apply || method.name == nme.fromProduct)
72+
|| method.owner.is(CaseClass) && method.name == nme.copy then
73+
report.error(explanation, origin)
74+
false
75+
else true
76+
}
77+
78+
79+
// if `annotatedMethod` is a case class apply, case class copy or case class fromProduct then `params` actually comes from the case class' primary constructor
80+
// but we still need to check the actual annotated method hence the two-step check
81+
def findUnrollAnnotations(params: List[Symbol], specialCasedMethod: Option[Symbol])(using Context): List[Int] = {
82+
83+
specialCasedMethod.foreach { annotatedMethod =>
84+
if methodHasUnroll(annotatedMethod) then isValidUnrolledMethod(annotatedMethod, annotatedMethod.sourcePos)
85+
}
86+
4487
params
4588
.zipWithIndex
4689
.collect {
47-
case (v, i) if v.annotations.exists(_.symbol.fullName.toString == "com.lihaoyi.unroll") =>
90+
case (v, i) if hasUnrollAnnotation(v) && isValidUnrolledMethod(v.owner, v.sourcePos) =>
4891
i
4992
}
5093
}
94+
95+
private def methodHasUnroll(method: Symbol)(using Context) =
96+
method.paramSymss.exists(_.exists(hasUnrollAnnotation))
97+
98+
private def hasUnrollAnnotation(sym: Symbol)(using Context) =
99+
sym.annotations.exists(_.symbol.fullName.toString == "com.lihaoyi.unroll")
100+
51101
def isTypeClause(p: ParamClause) = p.headOption.exists(_.isInstanceOf[TypeDef])
52102
def generateSingleForwarder(defdef: DefDef,
53103
prevMethodType: Type,
@@ -186,6 +236,8 @@ class UnrollPhaseScala3() extends PluginPhase {
186236
case defdef: DefDef if defdef.paramss.nonEmpty =>
187237
import dotty.tools.dotc.core.NameOps.isConstructorName
188238

239+
LintInnerMethods.traverseChildren(defdef.rhs)
240+
189241
val isCaseCopy =
190242
defdef.name.toString == "copy" && defdef.symbol.owner.is(CaseClass)
191243

@@ -200,12 +252,13 @@ class UnrollPhaseScala3() extends PluginPhase {
200252
else if (isCaseFromProduct) defdef.symbol.owner.companionClass.primaryConstructor
201253
else defdef.symbol
202254

255+
val specialCasedMethod = Option.when(isCaseCopy || isCaseApply || isCaseFromProduct)(defdef.symbol)
203256

204257
annotated
205258
.paramSymss
206259
.zipWithIndex
207260
.flatMap{case (paramClause, paramClauseIndex) =>
208-
val annotationIndices = findUnrollAnnotations(paramClause)
261+
val annotationIndices = findUnrollAnnotations(paramClause, specialCasedMethod)
209262
if (annotationIndices.isEmpty) None
210263
else Some((paramClauseIndex, annotationIndices))
211264
} match{
@@ -262,6 +315,17 @@ class UnrollPhaseScala3() extends PluginPhase {
262315
case _ => (None, Nil)
263316
}
264317

318+
private object LintInnerMethods extends TreeTraverser {
319+
override def traverse(tree: Tree)(using Context): Unit =
320+
tree match {
321+
case method: DefDef =>
322+
if methodHasUnroll(method.symbol) then isValidUnrolledMethod(method.symbol, method.sourcePos)
323+
traverseChildren(method.rhs)
324+
case _ => ()
325+
}
326+
override def traverseChildren(tree: Tree)(using Context): Unit = super.traverseChildren(tree)
327+
}
328+
265329
override def transformTemplate(tmpl: tpd.Template)(using Context): tpd.Tree = {
266330

267331
val (removed0, generatedDefs) = tmpl.body.map(generateSyntheticDefs).unzip

unroll/tests/abstractClassMethod/v1/downstream/src/Downstream.scala

Lines changed: 0 additions & 13 deletions
This file was deleted.

unroll/tests/abstractClassMethod/v1/src/Unrolled.scala

Lines changed: 0 additions & 6 deletions
This file was deleted.

unroll/tests/abstractClassMethod/v1/test/src/UnrollTestMain.scala

Lines changed: 0 additions & 16 deletions
This file was deleted.

unroll/tests/abstractClassMethod/v2/downstream/src/Downstream.scala

Lines changed: 0 additions & 15 deletions
This file was deleted.

unroll/tests/abstractClassMethod/v2/src/Unrolled.scala

Lines changed: 0 additions & 8 deletions
This file was deleted.

0 commit comments

Comments
 (0)