diff --git a/mainargs/src/TokensReader.scala b/mainargs/src/TokensReader.scala index 42449d0..997e901 100644 --- a/mainargs/src/TokensReader.scala +++ b/mainargs/src/TokensReader.scala @@ -153,24 +153,35 @@ object TokensReader { def shortName = wrapped.shortName } - implicit def OptionRead[T: TokensReader.Simple]: TokensReader[Option[T]] = new OptionRead[T] + implicit def OptionRead[T: TokensReader.Simple]: TokensReader.Simple[Option[T]] = new OptionRead[T] class OptionRead[T: TokensReader.Simple] extends Simple[Option[T]] { def shortName = implicitly[TokensReader.Simple[T]].shortName def read(strs: Seq[String]) = { - strs.lastOption match { - case None => Right(None) - case Some(s) => implicitly[TokensReader.Simple[T]].read(Seq(s)) match { + if (implicitly[TokensReader.Simple[T]].alwaysRepeatable) { + Option(strs).filter(_.nonEmpty) match{ + case None => Right(None) + case Some(strs) => implicitly[TokensReader.Simple[T]].read(strs) match{ case Left(s) => Left(s) case Right(s) => Right(Some(s)) } + } + } else { + strs.lastOption match{ + case None => Right(None) + case Some(s) => implicitly[TokensReader.Simple[T]].read(Seq(s)) match{ + case Left(s) => Left(s) + case Right(s) => Right(Some(s)) + } + } } } + override def alwaysRepeatable = implicitly[TokensReader.Simple[T]].alwaysRepeatable override def allowEmpty = true } implicit def SeqRead[C[_] <: Iterable[_], T: TokensReader.Simple](implicit factory: Factory[T, C[T]] - ): TokensReader[C[T]] = + ): TokensReader.Simple[C[T]] = new SeqRead[C, T] class SeqRead[C[_] <: Iterable[_], T: TokensReader.Simple](implicit factory: Factory[T, C[T]]) @@ -194,7 +205,7 @@ object TokensReader { override def allowEmpty = true } - implicit def MapRead[K: TokensReader.Simple, V: TokensReader.Simple]: TokensReader[Map[K, V]] = + implicit def MapRead[K: TokensReader.Simple, V: TokensReader.Simple]: TokensReader.Simple[Map[K, V]] = new MapRead[K, V] class MapRead[K: TokensReader.Simple, V: TokensReader.Simple] extends Simple[Map[K, V]] { def shortName = "k=v" diff --git a/mainargs/test/src/OptionSeqTests.scala b/mainargs/test/src/OptionSeqTests.scala index d37b8d7..a21e170 100644 --- a/mainargs/test/src/OptionSeqTests.scala +++ b/mainargs/test/src/OptionSeqTests.scala @@ -17,6 +17,9 @@ object OptionSeqTests extends TestSuite { @main def runInt(int: Int) = int + + @main + def runOptionSeq(os: Option[Seq[Int]]) = os } val tests = Tests { @@ -41,6 +44,24 @@ object OptionSeqTests extends TestSuite { Seq(123) } } + + test("option seq") { + test { + ParserForMethods(Main).runOrThrow(Array("runOptionSeq", "--os", "123")) ==> + Some(Seq(123)) + } + + test { + ParserForMethods(Main).runOrThrow(Array("runOptionSeq", "--os", "123", "--os", "456")) ==> + Some(Seq(123, 456)) + } + + test { + ParserForMethods(Main).runOrThrow(Array("runOptionSeq")) ==> + None + } + } + test("vec") { ParserForMethods(Main).runOrThrow(Array("runVec", "--seq", "123", "--seq", "456")) ==> Vector(123, 456)