From 0dfb877a06222352e1af9dd0a34db33ade3a8a29 Mon Sep 17 00:00:00 2001 From: Gaurav Dasson Date: Wed, 25 Dec 2024 09:18:20 -0600 Subject: [PATCH 1/2] Added support for AWS IAM Auth Method. --- CHANGELOG.md | 5 + README.md | 2 + examples/SecretProviderClassExample1.md | 25 ++ examples/SecretProviderClassExample2.md | 29 +++ go.mod | 4 + go.sum | 23 ++ internal/auth/auth.go | 23 ++ internal/auth/awsiam.go | 234 +++++++++++++++++ internal/auth/awsiam_test.go | 324 ++++++++++++++++++++++++ internal/auth/kubernetes_jwt.go | 14 +- internal/auth/kubernetes_jwt_test.go | 115 +++++++++ internal/client/client.go | 12 +- internal/client/client_test.go | 3 +- internal/config/config.go | 58 ++++- internal/config/config_test.go | 5 + internal/provider/provider.go | 4 +- internal/provider/provider_test.go | 3 +- internal/server/server.go | 5 +- main.go | 3 +- 19 files changed, 872 insertions(+), 19 deletions(-) create mode 100644 examples/SecretProviderClassExample1.md create mode 100644 examples/SecretProviderClassExample2.md create mode 100644 internal/auth/auth.go create mode 100644 internal/auth/awsiam.go create mode 100644 internal/auth/awsiam_test.go create mode 100644 internal/auth/kubernetes_jwt_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 0b54db67..a6edd6e7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,10 @@ ## Unreleased +FEATURES: + +* Added support for AWS IAM Auth type. + + CHANGES: * Test with K8s 1.27-1.31 diff --git a/README.md b/README.md index f80b8a99..72a8ac06 100644 --- a/README.md +++ b/README.md @@ -44,6 +44,8 @@ full details of deploying, configuring and using Vault CSI provider. The integration tests in [test/bats/provider.bats](./test/bats/provider.bats) also provide a good set of fully worked and tested examples to build on. +For details on different SecretProviderClass configurations, see [examples](./examples). + ## Troubleshooting To troubleshoot issues with Vault CSI provider, look at logs from the Vault CSI diff --git a/examples/SecretProviderClassExample1.md b/examples/SecretProviderClassExample1.md new file mode 100644 index 00000000..c0be201a --- /dev/null +++ b/examples/SecretProviderClassExample1.md @@ -0,0 +1,25 @@ +Below is an example for a SecretProviderClass for Vault with AWS IAM auth method. + +```yaml +apiVersion: v1 +items: +- apiVersion: secrets-store.csi.x-k8s.io/v1 + kind: SecretProviderClass + metadata: + name: vault-foo + namespace: default + spec: + parameters: + auth: |- # This block is optional. If this block is not specified, the default auth method is kubernetes + type: kubernetes # Auth method type + mouthPath: kubernetes # Mount path for Kubernetes auth method. Defaults to kubernetes if not specified. + objects: | + - secretPath: "secret/web-app" + objectName: "creds" + secretKey: "api-token" + roleName: secret-store-csi-test # Vault Role Name + vaultAddress: https://vault.address:8200 + provider: vault + resourceVersion: "" + +``` \ No newline at end of file diff --git a/examples/SecretProviderClassExample2.md b/examples/SecretProviderClassExample2.md new file mode 100644 index 00000000..e600087b --- /dev/null +++ b/examples/SecretProviderClassExample2.md @@ -0,0 +1,29 @@ +Below is an example for a SecretProviderClass for Vault with AWS IAM auth method. + +```yaml +apiVersion: v1 +items: +- apiVersion: secrets-store.csi.x-k8s.io/v1 + kind: SecretProviderClass + metadata: + name: vault-foo + namespace: default + spec: + parameters: + auth: |- # This block is optional. If this block is not specified, the default auth method is kubernetes + type: aws # Auth method type + mouthPath: aws # Mount path for AWS auth method. Defaults to aws if not specified. + aws: + region: us-east-1 # AWS Region + awsIAMRole: secrets-store-inline-irsa-role # AWS IAM Role + xVaultAWSIAMServerID: vault.example.com # Vault AWS IAM Server ID. More info: https://www.vaultproject.io/docs/auth/aws#server-id + objects: | + - secretPath: "secret/web-app" + objectName: "creds" + secretKey: "api-token" + roleName: secret-store-csi-test # Vault Role Name + vaultAddress: https://vault.address:8200 + provider: vault + resourceVersion: "" + +``` \ No newline at end of file diff --git a/go.mod b/go.mod index b9fd775c..823cddd7 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,10 @@ module github.com/hashicorp/vault-csi-provider go 1.22.0 require ( + github.com/aws/aws-sdk-go v1.55.5 + github.com/golang-jwt/jwt/v4 v4.5.1 github.com/hashicorp/go-hclog v1.6.3 + github.com/hashicorp/go-secure-stdlib/awsutil v0.3.0 github.com/hashicorp/golang-lru/v2 v2.0.7 github.com/hashicorp/vault/api v1.15.0 github.com/stretchr/testify v1.9.0 @@ -42,6 +45,7 @@ require ( github.com/hashicorp/go-secure-stdlib/strutil v0.1.2 // indirect github.com/hashicorp/go-sockaddr v1.0.2 // indirect github.com/hashicorp/hcl v1.0.0 // indirect + github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/mailru/easyjson v0.7.7 // indirect diff --git a/go.sum b/go.sum index 3210ad4e..a7e6ae37 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,7 @@ github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= +github.com/aws/aws-sdk-go v1.34.0/go.mod h1:5zCpMtNQVjRREroY7sYe8lOMRSxkhG6MZveU8YkpAk0= +github.com/aws/aws-sdk-go v1.55.5 h1:KKUZBfBoyqy5d3swXyiC7Q76ic40rYcbqH7qjh59kzU= +github.com/aws/aws-sdk-go v1.55.5/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU= github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= @@ -26,12 +29,15 @@ github.com/go-openapi/jsonreference v0.20.2/go.mod h1:Bl1zwGIM8/wsvqjsOQLJ/SH+En github.com/go-openapi/swag v0.22.3/go.mod h1:UzaqsxGiab7freDnrUUra0MwWfN/q7tE4j+VcZ0yl14= github.com/go-openapi/swag v0.22.4 h1:QLMzNJnMGPRNDCbySlcj1x01tzU8/9LTTL9hZZZogBU= github.com/go-openapi/swag v0.22.4/go.mod h1:UzaqsxGiab7freDnrUUra0MwWfN/q7tE4j+VcZ0yl14= +github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI= github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8= github.com/go-test/deep v1.0.2 h1:onZX1rnHT3Wv6cqNgYyFOOlgVKJrksuCMCRvJStbMYw= github.com/go-test/deep v1.0.2/go.mod h1:wGDj63lr65AM2AQyKZd/NYHGb0R+1RLqB8NKt3aSFNA= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/golang-jwt/jwt/v4 v4.5.1 h1:JdqV9zKUdtaa9gdPlywC3aeoEsR681PlKC+4F5gQgeo= +github.com/golang-jwt/jwt/v4 v4.5.1/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/gnostic-models v0.6.8 h1:yo/ABAfM5IMRsS1VnXjTBvUb61tFIHozhlYvRgGre9I= @@ -51,6 +57,7 @@ github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48= +github.com/hashicorp/go-hclog v1.5.0/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M= github.com/hashicorp/go-hclog v1.6.3 h1:Qr2kF+eVWjTiYmU7Y31tYlP1h0q/X3Nl3tPGdaB11/k= github.com/hashicorp/go-hclog v1.6.3/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M= github.com/hashicorp/go-multierror v1.0.0/go.mod h1:dHtQlpGsu+cZNNAkkCN/P3hoUDHhCYQXV3UM06sGGrk= @@ -60,6 +67,8 @@ github.com/hashicorp/go-retryablehttp v0.7.7 h1:C8hUCYzor8PIfXHa4UrZkU4VvK8o9ISH github.com/hashicorp/go-retryablehttp v0.7.7/go.mod h1:pkQpWZeYWskR+D1tR2O5OcBFOxfA7DoAO6xtkuQnHTk= github.com/hashicorp/go-rootcerts v1.0.2 h1:jzhAVGtqPKbwpyCPELlgNWhE1znq+qwJtW5Oi2viEzc= github.com/hashicorp/go-rootcerts v1.0.2/go.mod h1:pqUvnprVnM5bf7AOirdbb01K4ccR319Vf4pU3K5EGc8= +github.com/hashicorp/go-secure-stdlib/awsutil v0.3.0 h1:I8bynUKMh9I7JdwtW9voJ0xmHvBpxQtLjrMFDYmhOxY= +github.com/hashicorp/go-secure-stdlib/awsutil v0.3.0/go.mod h1:oKHSQs4ivIfZ3fbXGQOop1XuDfdSb8RIsWTGaAanSfg= github.com/hashicorp/go-secure-stdlib/parseutil v0.1.6 h1:om4Al8Oy7kCm/B86rLCLah4Dt5Aa0Fr5rYBG60OzwHQ= github.com/hashicorp/go-secure-stdlib/parseutil v0.1.6/go.mod h1:QmrqtbKuxxSWTN3ETMPuB+VtEiBJ/A9XhoYGv8E1uD8= github.com/hashicorp/go-secure-stdlib/strutil v0.1.1/go.mod h1:gKOamz3EwoIoJq7mlMIRBpVTAUn8qPCrEclOKKWhD3U= @@ -73,13 +82,20 @@ github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/hashicorp/vault/api v1.15.0 h1:O24FYQCWwhwKnF7CuSqP30S51rTV7vz1iACXE/pj5DA= github.com/hashicorp/vault/api v1.15.0/go.mod h1:+5YTO09JGn0u+b6ySD/LLVf8WkJCPLAL2Vkmrn2+CM8= +github.com/jmespath/go-jmespath v0.3.0/go.mod h1:9QtRXoHjLGCJ5IBSaohpXITPlowMeeYCZ7fLUTSywik= +github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= +github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= +github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= +github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= @@ -124,6 +140,7 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/posener/complete v1.1.1/go.mod h1:em0nMJCgc9GFtwrmVmEMR/ZL6WyhyjMBndrE9hABlRI= +github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/ryanuber/columnize v2.1.0+incompatible/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts= @@ -135,11 +152,13 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= @@ -155,6 +174,7 @@ golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= @@ -202,12 +222,15 @@ google.golang.org/grpc v1.67.1/go.mod h1:1gLDyUQU7CTLJI90u3nXZ9ekeghjeM7pTDZlqFN google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/evanphx/json-patch.v4 v4.12.0 h1:n6jtcsulIzXPJaxegRbvFNNrZDjbij7ny3gmSPG+6V4= gopkg.in/evanphx/json-patch.v4 v4.12.0/go.mod h1:p8EYWUEYMpynmqDbY58zCKCFZw8pRWMG4EsWvDvM72M= gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc= gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= diff --git a/internal/auth/auth.go b/internal/auth/auth.go new file mode 100644 index 00000000..91f4b018 --- /dev/null +++ b/internal/auth/auth.go @@ -0,0 +1,23 @@ +package auth + +import ( + "context" + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault-csi-provider/internal/config" + "k8s.io/client-go/kubernetes" +) + +type Auth interface { + AuthRequest(context.Context) (string, map[string]any, map[string]string, error) +} + +func NewAuth(logger hclog.Logger, k8sClient kubernetes.Interface, params config.Parameters, defaultMountPath string) (Auth, error) { + if params.VaultAuth.Type == "kubernetes" || params.VaultAuth.Type == "jwt" { + return NewKubernetesJWTAuth(logger, k8sClient, params, defaultMountPath) + } + if params.VaultAuth.Type == "aws" { + return NewAWSIAMAuth(logger, k8sClient, params, defaultMountPath) + } + // Default to Kubernetes + return NewKubernetesJWTAuth(logger, k8sClient, params, defaultMountPath) +} diff --git a/internal/auth/awsiam.go b/internal/auth/awsiam.go new file mode 100644 index 00000000..9a22c3ae --- /dev/null +++ b/internal/auth/awsiam.go @@ -0,0 +1,234 @@ +package auth + +import ( + "context" + "errors" + "fmt" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/credentials/stscreds" + "github.com/aws/aws-sdk-go/aws/defaults" + "github.com/aws/aws-sdk-go/aws/endpoints" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/sts" + "github.com/aws/aws-sdk-go/service/sts/stsiface" + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/go-secure-stdlib/awsutil" + "github.com/hashicorp/vault-csi-provider/internal/config" + authv1 "k8s.io/api/authentication/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes" + k8scorev1 "k8s.io/client-go/kubernetes/typed/core/v1" + "os" + "regexp" +) + +type AWSIAMAuth struct { + logger hclog.Logger + k8sClient kubernetes.Interface + params config.Parameters + defaultMountPath string + stsClient stsiface.STSAPI +} + +const ( + roleARNAnnotation = "eks.amazonaws.com/role-arn" + audienceAnnotation = "eks.amazonaws.com/audience" + defaultTokenAudience = "sts.amazonaws.com" + defaultAWSRegion = "us-east-1" + STSEndpointEnv = "AWS_STS_ENDPOINT" + defaultAWSMountPath = "aws" +) + +func setupConfig(params config.Parameters, credentials *credentials.Credentials) *aws.Config { + // Get an initial session to use for STS calls. + regionAWS := defaultAWSRegion + if params.VaultAuth.AWSIAMAuth.Region != "" { + regionAWS = params.VaultAuth.AWSIAMAuth.Region + } + handlers := defaults.Handlers() + handlers.Build.PushBack(request.WithAppendUserAgent("vault-csi-provider")) + awsConfig := aws.NewConfig().WithEndpointResolver(ResolveEndpoint()) + if regionAWS != "" { + awsConfig.WithRegion(regionAWS) + } + + if credentials != nil { + awsConfig.WithCredentials(credentials) + } + return awsConfig +} + +func NewAWSIAMAuth(logger hclog.Logger, k8sClient kubernetes.Interface, params config.Parameters, defaultMountPath string) (*AWSIAMAuth, error) { + // Get an initial session to use for STS calls. + awsConfig := setupConfig(params, nil) + sess, err := session.NewSession(awsConfig) + if err != nil { + return nil, err + } + + return &AWSIAMAuth{ + logger: logger, + k8sClient: k8sClient, + params: params, + defaultMountPath: defaultMountPath, + stsClient: sts.New(sess), + }, nil +} + +func ResolveEndpointWithServiceMap(customEndpoints map[string]string) endpoints.ResolverFunc { + defaultResolver := endpoints.DefaultResolver() + return func(service, region string, opts ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) { + if ep, ok := customEndpoints[service]; ok { + return endpoints.ResolvedEndpoint{ + URL: ep, + }, nil + } + return defaultResolver.EndpointFor(service, region, opts...) + } +} + +// ResolveEndpoint returns a ResolverFunc with +// customizable endpoints. +func ResolveEndpoint() endpoints.ResolverFunc { + customEndpoints := make(map[string]string) + if v := os.Getenv(STSEndpointEnv); v != "" { + customEndpoints["sts"] = v + } + return ResolveEndpointWithServiceMap(customEndpoints) +} + +var regexReqIDs = []*regexp.Regexp{ + regexp.MustCompile(`request id: (\S+)`), + regexp.MustCompile(` Credential=.+`), +} + +func SanitizeErr(err error) error { + msg := err.Error() + for _, regex := range regexReqIDs { + msg = string(regex.ReplaceAll([]byte(msg), nil)) + } + return errors.New(msg) +} + +type authTokenFetcher struct { + Namespace string + // Audience is the token aud claim + // which is verified by the aws oidc provider + // see: https://github.com/external-secrets/external-secrets/issues/1251#issuecomment-1161745849 + Audiences []string + ServiceAccount string + k8sClient k8scorev1.CoreV1Interface +} + +// FetchToken satisfies the stscreds.TokenFetcher interface +// it is used to generate service account tokens which are consumed by the aws sdk. +func (p authTokenFetcher) FetchToken(ctx credentials.Context) ([]byte, error) { + tokRsp, err := p.k8sClient.ServiceAccounts(p.Namespace).CreateToken(ctx, p.ServiceAccount, &authv1.TokenRequest{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: p.Namespace, + Name: p.ServiceAccount, + }, + Spec: authv1.TokenRequestSpec{ + Audiences: p.Audiences, + }, + }, metav1.CreateOptions{}) + if err != nil { + return nil, fmt.Errorf("error creating service account token: %w", err) + } + return []byte(tokRsp.Status.Token), nil +} + +func getTokenFetcher(ctx context.Context, k *AWSIAMAuth) (string, *authTokenFetcher, error) { + sa, err := k.k8sClient.CoreV1().ServiceAccounts(k.params.PodInfo.Namespace).Get(ctx, k.params.PodInfo.ServiceAccountName, metav1.GetOptions{}) + if err != nil { + return "", nil, err + } + // the service account is expected to have a well-known annotation + // this is used as input to assumeRoleWithWebIdentity + roleArn := sa.Annotations[roleARNAnnotation] + if roleArn == "" { + return "", nil, fmt.Errorf("an IAM role must be associated with service account %s (namespace: %s)", k.params.PodInfo.ServiceAccountName, k.params.PodInfo.Namespace) + } + + tokenAud := sa.Annotations[audienceAnnotation] + if tokenAud == "" { + tokenAud = defaultTokenAudience + } + + audiences := []string{tokenAud} + + return roleArn, &authTokenFetcher{ + Namespace: k.params.PodInfo.Namespace, + Audiences: audiences, + ServiceAccount: k.params.PodInfo.ServiceAccountName, + k8sClient: k.k8sClient.CoreV1(), + }, nil +} + +func (k *AWSIAMAuth) AuthRequest(ctx context.Context) (path string, body map[string]any, additionalHeaders map[string]string, err error) { + + roleArn, tokenFetcher, err := getTokenFetcher(ctx, k) + if err != nil { + return "", nil, nil, err + } + + webIdentityProvider := stscreds.NewWebIdentityRoleProviderWithOptions( + k.stsClient, roleArn, "vault-csi-provider", tokenFetcher) + + awsConfig := setupConfig(k.params, credentials.NewCredentials(webIdentityProvider)) + + sess, err := session.NewSession(awsConfig) + if err != nil { + return "", nil, nil, SanitizeErr(err) + } + + awsCredentials, err := sess.Config.Credentials.Get() + if err != nil { + return "", nil, nil, SanitizeErr(err) + } + + credentialsConfig := awsutil.CredentialsConfig{ + AccessKey: awsCredentials.AccessKeyID, + SecretKey: awsCredentials.SecretAccessKey, + SessionToken: awsCredentials.SessionToken, + Logger: k.logger, + } + + credChainCredentials, err := credentialsConfig.GenerateCredentialChain() + if err != nil { + return "", nil, nil, err + } + if credChainCredentials == nil { + return "", nil, nil, fmt.Errorf("could not compile valid credential providers from config") + } + + _, err = credChainCredentials.Get() + if err != nil { + return "", nil, nil, fmt.Errorf("failed to retrieve credentials from credential chain: %w", err) + } + + data, err := awsutil.GenerateLoginData(credChainCredentials, k.params.VaultAuth.AWSIAMAuth.XVaultAWSIAMServerID, *sess.Config.Region, k.logger) + if err != nil { + return "", nil, nil, fmt.Errorf("unable to generate login data for AWS auth endpoint: %w", err) + } + mountPath := k.params.VaultAuth.MouthPath + if mountPath == "" { + mountPath = defaultAWSMountPath + } + + // Add role if we have one. If not, Vault will infer the role name based + // on the IAM friendly name (iam auth type) or EC2 instance's + // AMI ID (ec2 auth type). + if k.params.VaultAuth.AWSIAMAuth.AWSIAMRole != "" { + data["role"] = k.params.VaultAuth.AWSIAMAuth.AWSIAMRole + } + + h := make(map[string]string) + if k.params.VaultAuth.AWSIAMAuth.XVaultAWSIAMServerID != "" { + h = map[string]string{ + "iam_server_id_header_value": k.params.VaultAuth.AWSIAMAuth.XVaultAWSIAMServerID} + } + return fmt.Sprintf("/v1/auth/%s/login", mountPath), data, h, nil +} diff --git a/internal/auth/awsiam_test.go b/internal/auth/awsiam_test.go new file mode 100644 index 00000000..4111441e --- /dev/null +++ b/internal/auth/awsiam_test.go @@ -0,0 +1,324 @@ +package auth + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/service/sts" + "github.com/aws/aws-sdk-go/service/sts/stsiface" + "github.com/golang-jwt/jwt/v4" + "github.com/hashicorp/go-hclog" + authv1 "k8s.io/api/authentication/v1" + "k8s.io/apimachinery/pkg/runtime" + "net/http" + "net/url" + "strings" + "time" + + "errors" + "fmt" + "k8s.io/client-go/kubernetes" + "testing" + + "github.com/hashicorp/vault-csi-provider/internal/config" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes/fake" + k8stesting "k8s.io/client-go/testing" +) + +// Mock STS Client +type mockSTSClient struct { + stsiface.STSAPI +} + +func (m *mockSTSClient) AssumeRoleWithWebIdentity(input *sts.AssumeRoleWithWebIdentityInput) (*sts.AssumeRoleWithWebIdentityOutput, error) { + if aws.StringValue(input.RoleArn) == "" { + return nil, errors.New("role ARN is empty") + } + if aws.StringValue(input.WebIdentityToken) == "" { + return nil, errors.New("web identity token is empty") + } + + return &sts.AssumeRoleWithWebIdentityOutput{ + Credentials: &sts.Credentials{ + AccessKeyId: aws.String("mockAccessKey"), + SecretAccessKey: aws.String("mockSecretKey"), + SessionToken: aws.String("mockSessionToken"), + Expiration: aws.Time(time.Now().Add(1 * time.Hour)), + }, + }, nil +} + +func (m *mockSTSClient) AssumeRoleWithWebIdentityRequest(input *sts.AssumeRoleWithWebIdentityInput) (*request.Request, *sts.AssumeRoleWithWebIdentityOutput) { + _ = input + req := &request.Request{ + HTTPRequest: &http.Request{ + Method: "POST", + URL: &url.URL{Scheme: "https", Host: "sts.amazonaws.com", Path: "/"}, + Header: make(http.Header), + }, + Operation: &request.Operation{ + Name: "AssumeRoleWithWebIdentity", + HTTPMethod: "POST", + HTTPPath: "/", + }, + Data: &sts.AssumeRoleWithWebIdentityOutput{ + Credentials: &sts.Credentials{ + AccessKeyId: aws.String("mockAccessKey"), + SecretAccessKey: aws.String("mockSecretKey"), + SessionToken: aws.String("mockSessionToken"), + Expiration: aws.Time(time.Now().Add(1 * time.Hour)), + }, + }, + } + return req, req.Data.(*sts.AssumeRoleWithWebIdentityOutput) +} + +// GenerateDummyPrivateKey generates a dummy RSA private key for testing. +func GenerateDummyPrivateKey() (string, error) { + // Generate a new RSA private key. + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return "", fmt.Errorf("failed to generate private key: %w", err) + } + + // Encode the private key into PEM format. + privKeyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(key), + }) + + return string(privKeyPEM), nil +} + +const ( + dummyIssuer = "https://oidc.eks.us-east-1.amazonaws.com/id/ABCDEFG7383928EEC764D2049AE19A7F5" + // Mock service account + serviceAccountName = "test-service-account" + namespace = "test-namespace" + roleArn = "arn:aws:iam::123456789012:role/test-role" + tokenAudience = "sts.amazonaws.com" +) + +// GenerateValidToken generates a Kubernetes-like ServiceAccount token. +func GenerateMockValidToken(privateKey []byte, audiences []string, expiration time.Duration) (string, error) { + key, err := jwt.ParseRSAPrivateKeyFromPEM(privateKey) + if err != nil { + return "", fmt.Errorf("unable to parse private key: %w", err) + } + + now := time.Now() + claims := jwt.MapClaims{ + "iss": dummyIssuer, + "sub": "system:serviceaccount:" + namespace + ":" + serviceAccountName, + "aud": audiences, + "exp": now.Add(expiration).Unix(), + "iat": now.Unix(), + "nbf": now.Unix(), + } + + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + signedToken, err := token.SignedString(key) + if err != nil { + return "", fmt.Errorf("unable to sign token: %w", err) + } + + return signedToken, nil +} + +func MockNewIAMAuth(logger hclog.Logger, k8sClient kubernetes.Interface, params config.Parameters, defaultMountPath string) (*AWSIAMAuth, error) { + return &AWSIAMAuth{ + logger: logger, + k8sClient: k8sClient, + params: params, + defaultMountPath: defaultMountPath, + stsClient: &mockSTSClient{}, + }, nil + +} + +func SetupFakeClientWithTokenReactor() *fake.Clientset { + fakeClient := fake.NewClientset() + + // Add reactor for ServiceAccount token creation + fakeClient.PrependReactor("create", "serviceaccounts/token", func(action k8stesting.Action) (handled bool, ret runtime.Object, err error) { + createAction, ok := action.(k8stesting.CreateAction) + if !ok { + return false, nil, fmt.Errorf("invalid action type") + } + + tokenRequest, ok := createAction.GetObject().(*authv1.TokenRequest) + if !ok { + return false, nil, fmt.Errorf("unexpected object type: %T", createAction.GetObject()) + } + + if !strings.Contains(strings.Join(tokenRequest.Spec.Audiences, ","), "sts.amazonaws.com") { + return true, nil, fmt.Errorf("invalid audience") + } + + privateKey, err := GenerateDummyPrivateKey() + if err != nil { + fmt.Printf("Error generating private key: %v\n", err) + } + + token, err := GenerateMockValidToken([]byte(privateKey), tokenRequest.Spec.Audiences, 1*time.Hour) + if err != nil { + return true, nil, fmt.Errorf("failed to generate token: %w", err) + } + + // Mock TokenResponse + expiration := metav1.NewTime(time.Now().Add(1 * time.Hour)) + tokenResponse := &authv1.TokenRequest{ + Status: authv1.TokenRequestStatus{ + Token: token, + ExpirationTimestamp: expiration, + }, + } + + return true, tokenResponse, nil + }) + + return fakeClient +} + +func TestAuthRequest(t *testing.T) { + // Mock Kubernetes client + k8sClient := SetupFakeClientWithTokenReactor() + + // Create a mock service account with annotations + mockSA := &corev1.ServiceAccount{ + TypeMeta: metav1.TypeMeta{ + Kind: "ServiceAccount", + APIVersion: "v1", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: serviceAccountName, + Namespace: namespace, + Annotations: map[string]string{ + roleARNAnnotation: roleArn, + audienceAnnotation: tokenAudience, + }, + }, + } + + ctx := context.TODO() + + _, err := k8sClient.CoreV1().ServiceAccounts(namespace).Create(ctx, mockSA, metav1.CreateOptions{}) + if err != nil { + t.Fatalf("failed to create mock service account: %v", err) + } + + // Create a logger + logger := hclog.New(&hclog.LoggerOptions{ + Name: "test-logger", + Level: hclog.Debug, + }) + + // Mock parameters + params := config.Parameters{ + PodInfo: config.PodInfo{ + Namespace: namespace, + ServiceAccountName: serviceAccountName, + }, + VaultAuth: config.Auth{ + MouthPath: "awstest", + AWSIAMAuth: config.AWSIAMAuth{ + Region: "us-east-1", + AWSIAMRole: "test-role", + XVaultAWSIAMServerID: "test-server-id", + }, + }, + } + + // Initialize Mock AWSIAMAuth + // Initialize AWSIAMAuth + auth, err := MockNewIAMAuth(logger, k8sClient, params, "aws") + if err != nil { + t.Fatalf("failed to create AWSIAMAuth: %v", err) + } + + // Call AuthRequest + path, body, headers, err := auth.AuthRequest(context.TODO()) + if err != nil { + t.Fatalf("AuthRequest failed: %v", err) + } + + // Validate outputs + expectedPath := "/v1/auth/awstest/login" + if path != expectedPath { + t.Errorf("expected path %s, got %s", expectedPath, path) + } + + if body["role"] != "test-role" { + t.Errorf("expected role %s, got %s", "test-role", body["role"]) + } + + if len(headers) == 0 { + t.Errorf("expected headers, got none") + } + + if len(headers) > 0 && headers["iam_server_id_header_value"] != "test-server-id" { + t.Errorf("unexpected IAM server ID header value: %s", headers["iam_server_id_header_value"]) + } +} + +func TestAuthRequestMissingAnnotations(t *testing.T) { + // Mock Kubernetes client + k8sClient := fake.NewClientset() + + // Create a service account without annotations + serviceAccountName := "test-service-account" + namespace := "test-namespace" + + mockSA := &corev1.ServiceAccount{ + ObjectMeta: metav1.ObjectMeta{ + Name: serviceAccountName, + Namespace: namespace, + }, + } + _, err := k8sClient.CoreV1().ServiceAccounts(namespace).Create(context.TODO(), mockSA, metav1.CreateOptions{}) + if err != nil { + t.Fatalf("failed to create mock service account: %v", err) + } + + // Create a logger + logger := hclog.New(&hclog.LoggerOptions{ + Name: "test-logger", + Level: hclog.Debug, + }) + + // Mock parameters + params := config.Parameters{ + PodInfo: config.PodInfo{ + Namespace: namespace, + ServiceAccountName: serviceAccountName, + }, + VaultAuth: config.Auth{ + AWSIAMAuth: config.AWSIAMAuth{ + Region: "us-east-1", + }, + }, + } + + // Initialize AWSIAMAuth + auth, err := NewAWSIAMAuth(logger, k8sClient, params, "aws") + if err != nil { + t.Fatalf("failed to create AWSIAMAuth: %v", err) + } + + // Call AuthRequest and expect an error + _, _, _, err = auth.AuthRequest(context.TODO()) + if err == nil { + t.Fatalf("expected error, got none") + } + + expectedError := fmt.Sprintf("an IAM role must be associated with service account %s (namespace: %s)", serviceAccountName, namespace) + if err.Error() != expectedError { + t.Errorf("expected error %s, got %s", expectedError, err.Error()) + } +} diff --git a/internal/auth/kubernetes_jwt.go b/internal/auth/kubernetes_jwt.go index 1fc0205a..f6eb80e8 100644 --- a/internal/auth/kubernetes_jwt.go +++ b/internal/auth/kubernetes_jwt.go @@ -25,40 +25,40 @@ type KubernetesJWTAuth struct { defaultMountPath string } -func NewKubernetesJWTAuth(logger hclog.Logger, k8sClient kubernetes.Interface, params config.Parameters, defaultMountPath string) *KubernetesJWTAuth { +func NewKubernetesJWTAuth(logger hclog.Logger, k8sClient kubernetes.Interface, params config.Parameters, defaultMountPath string) (*KubernetesJWTAuth, error) { return &KubernetesJWTAuth{ logger: logger, k8sClient: k8sClient, params: params, defaultMountPath: defaultMountPath, - } + }, nil } // AuthRequest returns the request path and body required to authenticate // using the configured auth role in Vault. If no appropriate JWT is provided // in the CSI mount request, it will create a new one. -func (k *KubernetesJWTAuth) AuthRequest(ctx context.Context) (path string, body map[string]string, err error) { +func (k *KubernetesJWTAuth) AuthRequest(ctx context.Context) (path string, body map[string]any, additionalHeaders map[string]string, err error) { jwt := k.params.PodInfo.ServiceAccountToken if jwt == "" { k.logger.Debug("no suitable token found in mount request, using self-generated service account JWT") var err error jwt, err = k.createJWTToken(ctx, k.params.PodInfo, k.params.Audience) if err != nil { - return "", nil, err + return "", nil, nil, err } } else { k.logger.Debug("using token from mount request for login") } - mountPath := k.params.VaultAuthMountPath + mountPath := k.params.VaultAuth.MouthPath if mountPath == "" { mountPath = k.defaultMountPath } - return fmt.Sprintf("/v1/auth/%s/login", mountPath), map[string]string{ + return fmt.Sprintf("/v1/auth/%s/login", mountPath), map[string]any{ "jwt": jwt, "role": k.params.VaultRoleName, - }, nil + }, nil, nil } func (k *KubernetesJWTAuth) createJWTToken(ctx context.Context, podInfo config.PodInfo, audience string) (string, error) { diff --git a/internal/auth/kubernetes_jwt_test.go b/internal/auth/kubernetes_jwt_test.go new file mode 100644 index 00000000..470917de --- /dev/null +++ b/internal/auth/kubernetes_jwt_test.go @@ -0,0 +1,115 @@ +package auth + +import ( + "context" + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault-csi-provider/internal/config" + "github.com/stretchr/testify/assert" + authenticationv1 "k8s.io/api/authentication/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/client-go/kubernetes/fake" + k8stesting "k8s.io/client-go/testing" + "testing" +) + +func TestAuthRequestWithExistingToken(t *testing.T) { + logger := hclog.New(&hclog.LoggerOptions{ + Name: "test-logger", + Level: hclog.Debug, + }) + + // Mock Kubernetes client + k8sClient := fake.NewClientset() + + params := config.Parameters{ + PodInfo: config.PodInfo{ + ServiceAccountToken: "existing-token", + }, + VaultAuth: config.Auth{ + MouthPath: "kubernetes", + }, + VaultRoleName: "test-role", + } + + auth, err := NewKubernetesJWTAuth(logger, k8sClient, params, "kubernetes") + assert.NoError(t, err) + + path, body, _, err := auth.AuthRequest(context.TODO()) + assert.NoError(t, err) + assert.Equal(t, "/v1/auth/kubernetes/login", path) + assert.Equal(t, "existing-token", body["jwt"]) + assert.Equal(t, "test-role", body["role"]) +} + +func TestAuthRequestWithGeneratedToken(t *testing.T) { + logger := hclog.New(&hclog.LoggerOptions{ + Name: "test-logger", + Level: hclog.Debug, + }) + + // Mock Kubernetes client with service account token response + token := "generated-token" + k8sClient := fake.NewClientset() + k8sClient.Fake.PrependReactor("create", "serviceaccounts", func(action k8stesting.Action) (handled bool, ret runtime.Object, err error) { + return true, &authenticationv1.TokenRequest{ + Status: authenticationv1.TokenRequestStatus{ + Token: token, + }, + }, nil + }) + + params := config.Parameters{ + PodInfo: config.PodInfo{ + Namespace: "default", + ServiceAccountName: "default", + UID: "1234", + }, + VaultAuth: config.Auth{ + MouthPath: "kubernetes", + }, + Audience: "vault", + VaultRoleName: "test-role", + } + + auth, err := NewKubernetesJWTAuth(logger, k8sClient, params, "kubernetes") + assert.NoError(t, err) + + path, body, _, err := auth.AuthRequest(context.TODO()) + assert.NoError(t, err) + assert.Equal(t, "/v1/auth/kubernetes/login", path) + assert.Equal(t, token, body["jwt"]) + assert.Equal(t, "test-role", body["role"]) +} + +func TestCreateJWTToken(t *testing.T) { + logger := hclog.New(&hclog.LoggerOptions{ + Name: "test-logger", + Level: hclog.Debug, + }) + + // Mock Kubernetes client with token generation + token := "generated-token" + k8sClient := fake.NewClientset() + k8sClient.Fake.PrependReactor("create", "serviceaccounts", func(action k8stesting.Action) (handled bool, ret runtime.Object, err error) { + return true, &authenticationv1.TokenRequest{ + Status: authenticationv1.TokenRequestStatus{ + Token: token, + }, + }, nil + }) + + auth := &KubernetesJWTAuth{ + logger: logger, + k8sClient: k8sClient, + defaultMountPath: "kubernetes", + } + + jwt, err := auth.createJWTToken(context.TODO(), config.PodInfo{ + Namespace: "default", + ServiceAccountName: "default", + UID: "1234", + Name: "test-pod", + }, "vault") + assert.NoError(t, err) + assert.Equal(t, token, jwt) +} diff --git a/internal/client/client.go b/internal/client/client.go index 048ee7cf..5db72c95 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -80,7 +80,7 @@ func overlayConfig(cfg *api.Config, vaultAddr string, tlsConfig api.TLSConfig) e // We follow this pattern because we assume Vault Agent is caching and renewing // our auth token, and we have no universal way to check it's still valid and // in the Agent's cache before making a request. -func (c *Client) RequestSecret(ctx context.Context, authMethod *auth.KubernetesJWTAuth, secretConfig config.Secret) (*api.Secret, error) { +func (c *Client) RequestSecret(ctx context.Context, authMethod auth.Auth, secretConfig config.Secret) (*api.Secret, error) { // Ensure we have a token available. authed, err := c.auth(ctx, authMethod, "") if err != nil { @@ -125,7 +125,7 @@ func (c *Client) RequestSecret(ctx context.Context, authMethod *auth.KubernetesJ // authentications so that when a token expires, multiple consumers asking it // to reauthenticate at the same time only trigger one new authentication with // Vault. -func (c *Client) auth(ctx context.Context, authMethod *auth.KubernetesJWTAuth, failedToken string) (authed bool, err error) { +func (c *Client) auth(ctx context.Context, authMethod auth.Auth, failedToken string) (authed bool, err error) { c.mtx.Lock() defer c.mtx.Unlock() @@ -136,11 +136,17 @@ func (c *Client) auth(ctx context.Context, authMethod *auth.KubernetesJWTAuth, f } c.logger.Debug("performing vault login") - path, body, err := authMethod.AuthRequest(ctx) + path, body, additionalHeaders, err := authMethod.AuthRequest(ctx) if err != nil { return false, err } + if len(additionalHeaders) > 0 { + for k, v := range additionalHeaders { + c.inner.AddHeader(k, v) + } + } + req := c.inner.NewRequest(http.MethodPost, path) if err := req.SetJSONBody(body); err != nil { return false, err diff --git a/internal/client/client_test.go b/internal/client/client_test.go index d333cbc1..8953c7b3 100644 --- a/internal/client/client_test.go +++ b/internal/client/client_test.go @@ -117,7 +117,8 @@ func TestRequestSecret_OnlyAuthenticatesOnce(t *testing.T) { k8sClient := fake.NewSimpleClientset( &corev1.ServiceAccount{}, ) - authMethod := auth.NewKubernetesJWTAuth(hclog.Default(), k8sClient, config.Parameters{}, "") + authMethod, err := auth.NewKubernetesJWTAuth(hclog.Default(), k8sClient, config.Parameters{}, "") + require.NoError(t, err) client, err := New(hclog.Default(), config.Parameters{}, flagsConfig) require.NoError(t, err) diff --git a/internal/config/config.go b/internal/config/config.go index fb4af20f..86ebeb18 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -39,6 +39,7 @@ type FlagsConfig struct { CacheSize int VaultAddr string + VaultAuthType string VaultMount string VaultNamespace string @@ -61,6 +62,28 @@ func (fc FlagsConfig) TLSConfig() api.TLSConfig { } } +type AWSIAMAuth struct { + XVaultAWSIAMServerID string `yaml:"xVaultAWSIAMServerID,omitempty"` + Region string `yaml:"region,omitempty"` + AWSIAMRole string `yaml:"awsIAMRole,omitempty"` +} + +// JWTAuth : placeholder, for any future values +type JWTAuth struct { +} + +// K8sAuth : placeholder, for any future values +type K8sAuth struct { +} + +type Auth struct { + Type string `yaml:"type"` + MouthPath string `yaml:"mouthPath"` // Preferred way to specify mount path + AWSIAMAuth AWSIAMAuth `yaml:"aws,omitempty"` + JWTAuth JWTAuth `yaml:"jwt,omitempty"` // Placeholder, for any future values + K8sAuth K8sAuth `yaml:"k8s,omitempty"` // Placeholder, for any future values +} + // Parameters stores the parameters specified in a mount request's `Attributes` field. // It consists of the parameters section from the SecretProviderClass being mounted // and pod metadata provided by the driver. @@ -73,9 +96,10 @@ func (fc FlagsConfig) TLSConfig() api.TLSConfig { type Parameters struct { VaultAddress string VaultRoleName string - VaultAuthMountPath string + VaultAuthMountPath string // Still supported for backward compatibility. Preferred way is under auth block. VaultNamespace string VaultTLSConfig api.TLSConfig + VaultAuth Auth Secrets []Secret PodInfo PodInfo Audience string @@ -129,6 +153,25 @@ func parseParameters(parametersStr string) (Parameters, error) { } var parameters Parameters + authBlockDefinedByUser := false + authBlock, ok := params["auth"] + if !ok { + // If auth block is missing, default to kubernetes auth + // This will help with backward compatibility + params["auth"] = "type: kubernetes" + } else { + authBlockDefinedByUser = true + } + err = yaml.Unmarshal([]byte(authBlock), ¶meters.VaultAuth) + if err != nil { + return Parameters{}, err + } + if authBlockDefinedByUser { + if parameters.VaultAuth.Type != "aws" && parameters.VaultAuth.Type != "kubernetes" && parameters.VaultAuth.Type != "jwt" { + return Parameters{}, errors.New("unsupported auth type") + } + } + parameters.VaultRoleName = params["roleName"] parameters.VaultAddress = params["vaultAddress"] parameters.VaultNamespace = params["vaultNamespace"] @@ -137,16 +180,25 @@ func parseParameters(parametersStr string) (Parameters, error) { parameters.VaultTLSConfig.TLSServerName = params["vaultTLSServerName"] parameters.VaultTLSConfig.ClientCert = params["vaultTLSClientCertPath"] parameters.VaultTLSConfig.ClientKey = params["vaultTLSClientKeyPath"] + + // Continuing to support these parameters to support backward compatibility + // But only if auth type is kubernetes/jwt + // If explicitly set inside the auth block, these params will be ignored k8sMountPath, k8sSet := params["vaultKubernetesMountPath"] authMountPath, authSet := params["vaultAuthMountPath"] switch { case k8sSet && authSet: return Parameters{}, fmt.Errorf("cannot set both vaultKubernetesMountPath and vaultAuthMountPath") case k8sSet: - parameters.VaultAuthMountPath = k8sMountPath + if (parameters.VaultAuth.Type == "kubernetes" || parameters.VaultAuth.Type == "jwt") && parameters.VaultAuth.MouthPath == "" { + parameters.VaultAuth.MouthPath = k8sMountPath + } case authSet: - parameters.VaultAuthMountPath = authMountPath + if (parameters.VaultAuth.Type == "kubernetes" || parameters.VaultAuth.Type == "jwt") && parameters.VaultAuth.MouthPath == "" { + parameters.VaultAuth.MouthPath = authMountPath + } } + parameters.VaultAuthMountPath = parameters.VaultAuth.MouthPath parameters.PodInfo.Name = params["csi.storage.k8s.io/pod.name"] parameters.PodInfo.UID = types.UID(params["csi.storage.k8s.io/pod.uid"]) parameters.PodInfo.Namespace = params["csi.storage.k8s.io/pod.namespace"] diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 8f581e06..546b7678 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -145,6 +145,7 @@ func TestParseConfig(t *testing.T) { name: "set all options", targetPath: targetPath, parameters: map[string]string{ + "auth": "type: kubernetes", "roleName": "example-role", "vaultSkipTLSVerify": "true", "vaultAddress": "my-vault-address", @@ -171,6 +172,10 @@ func TestParseConfig(t *testing.T) { VaultAddress: "my-vault-address", VaultNamespace: "my-vault-namespace", VaultAuthMountPath: "my-mount-path", + VaultAuth: Auth{ + Type: "kubernetes", + MouthPath: "my-mount-path", + }, Secrets: []Secret{ {"bar1", "v1/secret/foo1", "", "", nil, 0o600, ""}, }, diff --git a/internal/provider/provider.go b/internal/provider/provider.go index bfc18bca..dfc97f91 100644 --- a/internal/provider/provider.go +++ b/internal/provider/provider.go @@ -30,12 +30,12 @@ type provider struct { vaultResponseCache map[vaultResponseCacheKey]*api.Secret // Allows mocking Kubernetes API for tests. - authMethod *auth.KubernetesJWTAuth + authMethod auth.Auth hmacGenerator *hmacgen.HMACGenerator clientCache *clientcache.ClientCache } -func NewProvider(logger hclog.Logger, authMethod *auth.KubernetesJWTAuth, hmacGenerator *hmacgen.HMACGenerator, clientCache *clientcache.ClientCache) *provider { +func NewProvider(logger hclog.Logger, authMethod auth.Auth, hmacGenerator *hmacgen.HMACGenerator, clientCache *clientcache.ClientCache) *provider { p := &provider{ logger: logger, vaultResponseCache: make(map[vaultResponseCacheKey]*api.Secret), diff --git a/internal/provider/provider_test.go b/internal/provider/provider_test.go index 8d8e3b42..2db7d2cf 100644 --- a/internal/provider/provider_test.go +++ b/internal/provider/provider_test.go @@ -231,7 +231,8 @@ func TestHandleMountRequest(t *testing.T) { k8sClient := fake.NewSimpleClientset( &corev1.ServiceAccount{}, ) - authMethod := auth.NewKubernetesJWTAuth(hclog.Default(), k8sClient, spcConfig.Parameters, "") + authMethod, err := auth.NewKubernetesJWTAuth(hclog.Default(), k8sClient, spcConfig.Parameters, "") + require.NoError(t, err) hmacGenerator := hmac.NewHMACGenerator(k8sClient, &corev1.Secret{}) clientCache, err := clientcache.NewClientCache(hclog.Default(), 10) require.NoError(t, err) diff --git a/internal/server/server.go b/internal/server/server.go index eefb1e64..55018fe4 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -53,7 +53,10 @@ func (s *Server) Mount(ctx context.Context, req *pb.MountRequest) (*pb.MountResp return nil, err } - authMethod := auth.NewKubernetesJWTAuth(s.logger.Named("auth"), s.k8sClient, cfg.Parameters, s.flagsConfig.VaultMount) + authMethod, err := auth.NewAuth(s.logger.Named("auth"), s.k8sClient, cfg.Parameters, s.flagsConfig.VaultMount) + if err != nil { + return nil, fmt.Errorf("error creating auth method: %w", err) + } provider := provider.NewProvider(s.logger.Named("provider"), authMethod, s.hmacGenerator, s.clientCache) resp, err := provider.HandleMountRequest(ctx, cfg, s.flagsConfig) if err != nil { diff --git a/main.go b/main.go index 2f536d5a..c21a5ed5 100644 --- a/main.go +++ b/main.go @@ -71,7 +71,8 @@ func realMain(logger hclog.Logger) error { flag.IntVar(&flags.CacheSize, "cache-size", 1000, "Set the maximum number of Vault tokens that will be cached in-memory. One Vault token will be stored for each pod on the same node that mounts secrets.") flag.StringVar(&flags.VaultAddr, "vault-addr", "", "Default address for connecting to Vault. Can also be specified via the VAULT_ADDR environment variable.") - flag.StringVar(&flags.VaultMount, "vault-mount", "kubernetes", "Default Vault mount path for authentication. Can refer to a Kubernetes or JWT auth mount.") + flag.StringVar(&flags.VaultAuthType, "vault-auth-type", "kubernetes", "Default auth type for Vault.") + flag.StringVar(&flags.VaultMount, "vault-mount", "kubernetes", "Default Vault mount path for authentication. Can refer to a Kubernetes or JWT auth mount or AWS auth mount. Can also be specified via the VAULT_MOUNT environment variable.") flag.StringVar(&flags.VaultNamespace, "vault-namespace", "", "Default Vault namespace for Vault requests. Can also be specified via the VAULT_NAMESPACE environment variable.") flag.StringVar(&flags.TLSCACertPath, "vault-tls-ca-cert", "", "Path on disk to a single PEM-encoded CA certificate to trust for Vault. Takes precendence over -vault-tls-ca-directory. Can also be specified via the VAULT_CACERT environment variable.") From 9bf255751dbd54717ea4915da65c95d492d1110d Mon Sep 17 00:00:00 2001 From: Gaurav Dasson Date: Wed, 25 Dec 2024 13:52:53 -0600 Subject: [PATCH 2/2] Added support for AWS IAM Auth Method. --- examples/SecretProviderClassExample1.md | 2 +- internal/auth/auth.go | 6 ++-- internal/auth/awsiam.go | 18 +++++------ internal/auth/awsiam_test.go | 40 ++++++++++++------------- internal/auth/kubernetes_jwt.go | 2 +- internal/auth/kubernetes_jwt_test.go | 4 +-- internal/client/client_test.go | 2 +- internal/provider/provider_test.go | 2 +- 8 files changed, 38 insertions(+), 38 deletions(-) diff --git a/examples/SecretProviderClassExample1.md b/examples/SecretProviderClassExample1.md index c0be201a..29de0777 100644 --- a/examples/SecretProviderClassExample1.md +++ b/examples/SecretProviderClassExample1.md @@ -1,4 +1,4 @@ -Below is an example for a SecretProviderClass for Vault with AWS IAM auth method. +Below is an example for a SecretProviderClass for Vault with Kubernetes auth method. ```yaml apiVersion: v1 diff --git a/internal/auth/auth.go b/internal/auth/auth.go index 91f4b018..c1424742 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -13,11 +13,11 @@ type Auth interface { func NewAuth(logger hclog.Logger, k8sClient kubernetes.Interface, params config.Parameters, defaultMountPath string) (Auth, error) { if params.VaultAuth.Type == "kubernetes" || params.VaultAuth.Type == "jwt" { - return NewKubernetesJWTAuth(logger, k8sClient, params, defaultMountPath) + return newKubernetesJWTAuth(logger, k8sClient, params, defaultMountPath) } if params.VaultAuth.Type == "aws" { - return NewAWSIAMAuth(logger, k8sClient, params, defaultMountPath) + return newAWSIAMAuth(logger, k8sClient, params, defaultMountPath) } // Default to Kubernetes - return NewKubernetesJWTAuth(logger, k8sClient, params, defaultMountPath) + return newKubernetesJWTAuth(logger, k8sClient, params, defaultMountPath) } diff --git a/internal/auth/awsiam.go b/internal/auth/awsiam.go index 9a22c3ae..43887f80 100644 --- a/internal/auth/awsiam.go +++ b/internal/auth/awsiam.go @@ -49,7 +49,7 @@ func setupConfig(params config.Parameters, credentials *credentials.Credentials) } handlers := defaults.Handlers() handlers.Build.PushBack(request.WithAppendUserAgent("vault-csi-provider")) - awsConfig := aws.NewConfig().WithEndpointResolver(ResolveEndpoint()) + awsConfig := aws.NewConfig().WithEndpointResolver(resolveEndpoint()) if regionAWS != "" { awsConfig.WithRegion(regionAWS) } @@ -60,7 +60,7 @@ func setupConfig(params config.Parameters, credentials *credentials.Credentials) return awsConfig } -func NewAWSIAMAuth(logger hclog.Logger, k8sClient kubernetes.Interface, params config.Parameters, defaultMountPath string) (*AWSIAMAuth, error) { +func newAWSIAMAuth(logger hclog.Logger, k8sClient kubernetes.Interface, params config.Parameters, defaultMountPath string) (*AWSIAMAuth, error) { // Get an initial session to use for STS calls. awsConfig := setupConfig(params, nil) sess, err := session.NewSession(awsConfig) @@ -77,7 +77,7 @@ func NewAWSIAMAuth(logger hclog.Logger, k8sClient kubernetes.Interface, params c }, nil } -func ResolveEndpointWithServiceMap(customEndpoints map[string]string) endpoints.ResolverFunc { +func resolveEndpointWithServiceMap(customEndpoints map[string]string) endpoints.ResolverFunc { defaultResolver := endpoints.DefaultResolver() return func(service, region string, opts ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) { if ep, ok := customEndpoints[service]; ok { @@ -89,14 +89,14 @@ func ResolveEndpointWithServiceMap(customEndpoints map[string]string) endpoints. } } -// ResolveEndpoint returns a ResolverFunc with +// resolveEndpoint returns a ResolverFunc with // customizable endpoints. -func ResolveEndpoint() endpoints.ResolverFunc { +func resolveEndpoint() endpoints.ResolverFunc { customEndpoints := make(map[string]string) if v := os.Getenv(STSEndpointEnv); v != "" { customEndpoints["sts"] = v } - return ResolveEndpointWithServiceMap(customEndpoints) + return resolveEndpointWithServiceMap(customEndpoints) } var regexReqIDs = []*regexp.Regexp{ @@ -104,7 +104,7 @@ var regexReqIDs = []*regexp.Regexp{ regexp.MustCompile(` Credential=.+`), } -func SanitizeErr(err error) error { +func sanitizeErr(err error) error { msg := err.Error() for _, regex := range regexReqIDs { msg = string(regex.ReplaceAll([]byte(msg), nil)) @@ -181,12 +181,12 @@ func (k *AWSIAMAuth) AuthRequest(ctx context.Context) (path string, body map[str sess, err := session.NewSession(awsConfig) if err != nil { - return "", nil, nil, SanitizeErr(err) + return "", nil, nil, sanitizeErr(err) } awsCredentials, err := sess.Config.Credentials.Get() if err != nil { - return "", nil, nil, SanitizeErr(err) + return "", nil, nil, sanitizeErr(err) } credentialsConfig := awsutil.CredentialsConfig{ diff --git a/internal/auth/awsiam_test.go b/internal/auth/awsiam_test.go index 4111441e..8278be98 100644 --- a/internal/auth/awsiam_test.go +++ b/internal/auth/awsiam_test.go @@ -31,6 +31,15 @@ import ( k8stesting "k8s.io/client-go/testing" ) +const ( + dummyIssuer = "https://oidc.eks.us-east-1.amazonaws.com/id/ABCDEFG7383928EEC764D2049AE19A7F5" + // Mock service account + serviceAccountName = "test-service-account" + namespace = "test-namespace" + roleArn = "arn:aws:iam::123456789012:role/test-role" + tokenAudience = "sts.amazonaws.com" +) + // Mock STS Client type mockSTSClient struct { stsiface.STSAPI @@ -79,8 +88,8 @@ func (m *mockSTSClient) AssumeRoleWithWebIdentityRequest(input *sts.AssumeRoleWi return req, req.Data.(*sts.AssumeRoleWithWebIdentityOutput) } -// GenerateDummyPrivateKey generates a dummy RSA private key for testing. -func GenerateDummyPrivateKey() (string, error) { +// generateDummyPrivateKey generates a dummy RSA private key for testing. +func generateDummyPrivateKey() (string, error) { // Generate a new RSA private key. key, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { @@ -96,17 +105,8 @@ func GenerateDummyPrivateKey() (string, error) { return string(privKeyPEM), nil } -const ( - dummyIssuer = "https://oidc.eks.us-east-1.amazonaws.com/id/ABCDEFG7383928EEC764D2049AE19A7F5" - // Mock service account - serviceAccountName = "test-service-account" - namespace = "test-namespace" - roleArn = "arn:aws:iam::123456789012:role/test-role" - tokenAudience = "sts.amazonaws.com" -) - -// GenerateValidToken generates a Kubernetes-like ServiceAccount token. -func GenerateMockValidToken(privateKey []byte, audiences []string, expiration time.Duration) (string, error) { +// generateValidToken generates a Kubernetes-like ServiceAccount token. +func generateMockValidToken(privateKey []byte, audiences []string, expiration time.Duration) (string, error) { key, err := jwt.ParseRSAPrivateKeyFromPEM(privateKey) if err != nil { return "", fmt.Errorf("unable to parse private key: %w", err) @@ -131,7 +131,7 @@ func GenerateMockValidToken(privateKey []byte, audiences []string, expiration ti return signedToken, nil } -func MockNewIAMAuth(logger hclog.Logger, k8sClient kubernetes.Interface, params config.Parameters, defaultMountPath string) (*AWSIAMAuth, error) { +func mockNewIAMAuth(logger hclog.Logger, k8sClient kubernetes.Interface, params config.Parameters, defaultMountPath string) (*AWSIAMAuth, error) { return &AWSIAMAuth{ logger: logger, k8sClient: k8sClient, @@ -142,7 +142,7 @@ func MockNewIAMAuth(logger hclog.Logger, k8sClient kubernetes.Interface, params } -func SetupFakeClientWithTokenReactor() *fake.Clientset { +func setupFakeClientWithTokenReactor() *fake.Clientset { fakeClient := fake.NewClientset() // Add reactor for ServiceAccount token creation @@ -161,12 +161,12 @@ func SetupFakeClientWithTokenReactor() *fake.Clientset { return true, nil, fmt.Errorf("invalid audience") } - privateKey, err := GenerateDummyPrivateKey() + privateKey, err := generateDummyPrivateKey() if err != nil { fmt.Printf("Error generating private key: %v\n", err) } - token, err := GenerateMockValidToken([]byte(privateKey), tokenRequest.Spec.Audiences, 1*time.Hour) + token, err := generateMockValidToken([]byte(privateKey), tokenRequest.Spec.Audiences, 1*time.Hour) if err != nil { return true, nil, fmt.Errorf("failed to generate token: %w", err) } @@ -188,7 +188,7 @@ func SetupFakeClientWithTokenReactor() *fake.Clientset { func TestAuthRequest(t *testing.T) { // Mock Kubernetes client - k8sClient := SetupFakeClientWithTokenReactor() + k8sClient := setupFakeClientWithTokenReactor() // Create a mock service account with annotations mockSA := &corev1.ServiceAccount{ @@ -237,7 +237,7 @@ func TestAuthRequest(t *testing.T) { // Initialize Mock AWSIAMAuth // Initialize AWSIAMAuth - auth, err := MockNewIAMAuth(logger, k8sClient, params, "aws") + auth, err := mockNewIAMAuth(logger, k8sClient, params, "aws") if err != nil { t.Fatalf("failed to create AWSIAMAuth: %v", err) } @@ -306,7 +306,7 @@ func TestAuthRequestMissingAnnotations(t *testing.T) { } // Initialize AWSIAMAuth - auth, err := NewAWSIAMAuth(logger, k8sClient, params, "aws") + auth, err := newAWSIAMAuth(logger, k8sClient, params, "aws") if err != nil { t.Fatalf("failed to create AWSIAMAuth: %v", err) } diff --git a/internal/auth/kubernetes_jwt.go b/internal/auth/kubernetes_jwt.go index f6eb80e8..c8e708af 100644 --- a/internal/auth/kubernetes_jwt.go +++ b/internal/auth/kubernetes_jwt.go @@ -25,7 +25,7 @@ type KubernetesJWTAuth struct { defaultMountPath string } -func NewKubernetesJWTAuth(logger hclog.Logger, k8sClient kubernetes.Interface, params config.Parameters, defaultMountPath string) (*KubernetesJWTAuth, error) { +func newKubernetesJWTAuth(logger hclog.Logger, k8sClient kubernetes.Interface, params config.Parameters, defaultMountPath string) (*KubernetesJWTAuth, error) { return &KubernetesJWTAuth{ logger: logger, k8sClient: k8sClient, diff --git a/internal/auth/kubernetes_jwt_test.go b/internal/auth/kubernetes_jwt_test.go index 470917de..da8ea444 100644 --- a/internal/auth/kubernetes_jwt_test.go +++ b/internal/auth/kubernetes_jwt_test.go @@ -31,7 +31,7 @@ func TestAuthRequestWithExistingToken(t *testing.T) { VaultRoleName: "test-role", } - auth, err := NewKubernetesJWTAuth(logger, k8sClient, params, "kubernetes") + auth, err := newKubernetesJWTAuth(logger, k8sClient, params, "kubernetes") assert.NoError(t, err) path, body, _, err := auth.AuthRequest(context.TODO()) @@ -71,7 +71,7 @@ func TestAuthRequestWithGeneratedToken(t *testing.T) { VaultRoleName: "test-role", } - auth, err := NewKubernetesJWTAuth(logger, k8sClient, params, "kubernetes") + auth, err := newKubernetesJWTAuth(logger, k8sClient, params, "kubernetes") assert.NoError(t, err) path, body, _, err := auth.AuthRequest(context.TODO()) diff --git a/internal/client/client_test.go b/internal/client/client_test.go index 8953c7b3..49596356 100644 --- a/internal/client/client_test.go +++ b/internal/client/client_test.go @@ -117,7 +117,7 @@ func TestRequestSecret_OnlyAuthenticatesOnce(t *testing.T) { k8sClient := fake.NewSimpleClientset( &corev1.ServiceAccount{}, ) - authMethod, err := auth.NewKubernetesJWTAuth(hclog.Default(), k8sClient, config.Parameters{}, "") + authMethod, err := auth.NewAuth(hclog.Default(), k8sClient, config.Parameters{}, "") require.NoError(t, err) client, err := New(hclog.Default(), config.Parameters{}, flagsConfig) require.NoError(t, err) diff --git a/internal/provider/provider_test.go b/internal/provider/provider_test.go index 2db7d2cf..9080f32e 100644 --- a/internal/provider/provider_test.go +++ b/internal/provider/provider_test.go @@ -231,7 +231,7 @@ func TestHandleMountRequest(t *testing.T) { k8sClient := fake.NewSimpleClientset( &corev1.ServiceAccount{}, ) - authMethod, err := auth.NewKubernetesJWTAuth(hclog.Default(), k8sClient, spcConfig.Parameters, "") + authMethod, err := auth.NewAuth(hclog.Default(), k8sClient, spcConfig.Parameters, "") require.NoError(t, err) hmacGenerator := hmac.NewHMACGenerator(k8sClient, &corev1.Secret{}) clientCache, err := clientcache.NewClientCache(hclog.Default(), 10)