diff --git a/zio-http/jvm/src/test/scala/zio/http/endpoint/UnauthorizedSpec.scala b/zio-http/jvm/src/test/scala/zio/http/endpoint/UnauthorizedSpec.scala new file mode 100644 index 000000000..74512b3df --- /dev/null +++ b/zio-http/jvm/src/test/scala/zio/http/endpoint/UnauthorizedSpec.scala @@ -0,0 +1,24 @@ +package zio.http.endpoint + +import zio.ZIO +import zio.test._ + +import zio.http._ +import zio.http.codec._ + +object UnauthorizedSpec extends ZIOSpecDefault { + override def spec = + suite("UnauthorizedSpec")( + test("should respond with 401 Unauthorized when required authorization header is missing") { + val endpoint = Endpoint(Method.GET / "test") + .header(HeaderCodec.authorization) + .out[Unit] + val route = endpoint.implement(_ => ZIO.unit) + val request = + Request(method = Method.GET, url = url"/test") + for { + response <- route.toRoutes.runZIO(request) + } yield assertTrue(Status.Unauthorized == response.status) + }, + ) +} diff --git a/zio-http/shared/src/main/scala/zio/http/endpoint/AuthType.scala b/zio-http/shared/src/main/scala/zio/http/endpoint/AuthType.scala index aa5214c7a..9b7aeea56 100644 --- a/zio-http/shared/src/main/scala/zio/http/endpoint/AuthType.scala +++ b/zio-http/shared/src/main/scala/zio/http/endpoint/AuthType.scala @@ -16,6 +16,7 @@ sealed trait AuthType { self => AuthType { type ClientRequirement = ClientReq }, ] + def asWWWAuthenticateHeader: Option[Header.WWWAuthenticate] } object AuthType { @@ -27,28 +28,43 @@ object AuthType { case object None extends AuthType { type ClientRequirement = Unit override val codec: HeaderCodec[Unit] = HttpCodec.empty.asInstanceOf[HeaderCodec[Unit]] + + override def asWWWAuthenticateHeader: Option[Header.WWWAuthenticate] = + Option.empty } case object Basic extends AuthType { type ClientRequirement = Header.Authorization.Basic override val codec: HeaderCodec[Header.Authorization.Basic] = HeaderCodec.basicAuth + + override def asWWWAuthenticateHeader: Option[Header.WWWAuthenticate] = + Some(Header.WWWAuthenticate.Basic()) } case object Bearer extends AuthType { type ClientRequirement = Header.Authorization.Bearer override val codec: HeaderCodec[Header.Authorization.Bearer] = HeaderCodec.bearerAuth + + override def asWWWAuthenticateHeader: Option[Header.WWWAuthenticate] = + Some(Header.WWWAuthenticate.Bearer(???)) } case object Digest extends AuthType { type ClientRequirement = Header.Authorization.Digest override val codec: HeaderCodec[Header.Authorization.Digest] = HeaderCodec.digestAuth + + override def asWWWAuthenticateHeader: Option[Header.WWWAuthenticate] = + Some(Header.WWWAuthenticate.Digest(None)) } final case class Custom[ClientReq](override val codec: HttpCodec[HttpCodecType.RequestType, ClientReq]) extends AuthType { type ClientRequirement = ClientReq + + override def asWWWAuthenticateHeader: Option[Header.WWWAuthenticate] = + Some(Header.WWWAuthenticate.Unknown(???, ???, ???)) } final case class Or[ClientReq1, ClientReq2, ClientReq]( @@ -57,8 +73,10 @@ object AuthType { alternator: Alternator.WithOut[ClientReq1, ClientReq2, ClientReq], ) extends AuthType { type ClientRequirement = ClientReq - override val codec: HttpCodec[HttpCodecType.RequestType, ClientReq] = + override val codec: HttpCodec[HttpCodecType.RequestType, ClientReq] = auth1.codec.orElseEither(auth2.codec)(alternator) + override def asWWWAuthenticateHeader: Option[Header.WWWAuthenticate] = + auth1.asWWWAuthenticateHeader } final case class ScopedAuth[ClientReq]( @@ -68,6 +86,9 @@ object AuthType { type ClientRequirement = ClientReq override val codec: HttpCodec[HttpCodecType.RequestType, ClientReq] = authType.codec + override def asWWWAuthenticateHeader: Option[Header.WWWAuthenticate] = + authType.asWWWAuthenticateHeader + def scopes: List[String] = _scopes def scopes(newScopes: List[String]) = copy(_scopes = newScopes) diff --git a/zio-http/shared/src/main/scala/zio/http/endpoint/Endpoint.scala b/zio-http/shared/src/main/scala/zio/http/endpoint/Endpoint.scala index 1207cb8a1..c3c8d19f5 100644 --- a/zio-http/shared/src/main/scala/zio/http/endpoint/Endpoint.scala +++ b/zio-http/shared/src/main/scala/zio/http/endpoint/Endpoint.scala @@ -384,7 +384,14 @@ final case class Endpoint[PathInput, Input, Err, Output, Auth <: AuthType]( case Some(HttpCodecError.CustomError("SchemaTransformationFailure", message)) if maybeUnauthedResponse.isDefined && message.endsWith(" auth required") => maybeUnauthedResponse.get - case Some(_) => + case Some(HttpCodecError.MissingHeaders(headerNames)) + if headerNames.contains(Header.Authorization.name) => + Handler.succeed(Response.unauthorized) + case Some(HttpCodecError.MissingHeader(headerName)) if headerName == Header.Authorization.name => + Handler.succeed(Response.unauthorized) + case Some(HttpCodecError.DecodingErrorHeader(headerName, _)) if headerName == Header.Authorization.name => + Handler.succeed(Response.unauthorized) + case _: Some[_] => Handler.fromFunctionZIO { (request: zio.http.Request) => val error = cause.defects.head.asInstanceOf[HttpCodecError] val response = { @@ -399,7 +406,7 @@ final case class Endpoint[PathInput, Input, Err, Output, Auth <: AuthType]( } ZIO.succeed(response) } - case None => + case None => Handler.failCause(cause) } }