diff --git a/zio-http/jvm/src/test/scala/zio/http/endpoint/MissingAuthHeaderSpec.scala b/zio-http/jvm/src/test/scala/zio/http/endpoint/MissingAuthHeaderSpec.scala new file mode 100644 index 000000000..bffed40b2 --- /dev/null +++ b/zio-http/jvm/src/test/scala/zio/http/endpoint/MissingAuthHeaderSpec.scala @@ -0,0 +1,69 @@ +/* + * Copyright 2021 - 2023 Sporta Technologies PVT LTD & the ZIO HTTP contributors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package zio.http.endpoint + +import zio._ +import zio.test._ + +import zio.http._ +import zio.http.codec._ +import zio.http.endpoint._ + +object MissingAuthHeaderSpec extends ZIOHttpSpec { + + def spec = suite("MissingAuthHeaderSpec")( + test("missing Authorization header should return 401 not 400") { + val endpoint = Endpoint(Method.GET / "test") + .header(HeaderCodec.authorization) + .out[String] + + val routes = endpoint.implementHandler( + Handler.succeed("success"), + ) + + for { + response <- routes.toRoutes.runZIO( + Request.get(url"/test").addHeader(Header.Accept(MediaType.application.`json`)), + ) + status = response.status + } yield assertTrue( + status.code == 401, + status == Status.Unauthorized, + ) + }, + test("missing non-auth header should still return 400") { + val endpoint = Endpoint(Method.GET / "test") + .header(HeaderCodec.headerAs[String]("X-Custom-Header")) + .out[String] + + val routes = endpoint.implementHandler( + Handler.succeed("success"), + ) + + for { + response <- routes.toRoutes.runZIO( + Request.get(url"/test").addHeader(Header.Accept(MediaType.application.`json`)), + ) + status = response.status + } yield assertTrue( + status.code == 400, + status == Status.BadRequest, + ) + }, + ) + +} diff --git a/zio-http/shared/src/main/scala/zio/http/endpoint/Endpoint.scala b/zio-http/shared/src/main/scala/zio/http/endpoint/Endpoint.scala index e830e496d..76f82d155 100644 --- a/zio-http/shared/src/main/scala/zio/http/endpoint/Endpoint.scala +++ b/zio-http/shared/src/main/scala/zio/http/endpoint/Endpoint.scala @@ -300,6 +300,11 @@ final case class Endpoint[PathInput, Input, Err, Output, Auth <: AuthType]( def implementHandler[Env](original: Handler[Env, Err, Input, Output])(implicit trace: Trace): Route[Env, Nothing] = { import HttpCodecError.asHttpCodecError + def isAuthHeader(headerName: String): Boolean = { + val lowerName = headerName.toLowerCase + lowerName == "authorization" || lowerName == "www-authenticate" || lowerName.startsWith("x-auth") + } + def authCodec(authType: AuthType): HttpCodec[HttpCodecType.RequestType, Unit] = authType match { case AuthType.None => HttpCodec.empty case AuthType.Basic => @@ -419,7 +424,11 @@ final case class Endpoint[PathInput, Input, Err, Output, Auth <: AuthType]( case Some(HttpCodecError.CustomError("SchemaTransformationFailure", message)) if maybeUnauthedResponse.isDefined && message.endsWith(" auth required") => maybeUnauthedResponse.get - case Some(_) => + case Some(HttpCodecError.MissingHeader(headerName)) if isAuthHeader(headerName) => + Handler.succeed(Response.unauthorized) + case Some(HttpCodecError.MissingHeaders(headerNames)) if headerNames.exists(isAuthHeader) => + Handler.succeed(Response.unauthorized) + case Some(_) => Handler.fromFunctionZIO { (request: zio.http.Request) => val error = cause.defects.head.asInstanceOf[HttpCodecError] val response = { @@ -434,7 +443,7 @@ final case class Endpoint[PathInput, Input, Err, Output, Auth <: AuthType]( } ZIO.succeed(response) } - case None => + case None => Handler.failCause(cause) } }