diff --git a/.changelog/3738.txt b/.changelog/3738.txt new file mode 100644 index 0000000000..2c423e83a5 --- /dev/null +++ b/.changelog/3738.txt @@ -0,0 +1,3 @@ +```release-note:bug +provider: Enforces strict hierarchy when selecting the credential source such as AWS Secrets Manager, provider attributes, or environment variables to prevent combining with values from different sources +``` diff --git a/internal/config/client.go b/internal/config/client.go index 4c1b09b0f9..574eb579a0 100644 --- a/internal/config/client.go +++ b/internal/config/client.go @@ -17,13 +17,14 @@ import ( matlasClient "go.mongodb.org/atlas/mongodbatlas" realmAuth "go.mongodb.org/realm/auth" "go.mongodb.org/realm/realm" - "golang.org/x/oauth2" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/logging" "github.com/mongodb-forks/digest" adminpreview "github.com/mongodb/atlas-sdk-go/admin" "github.com/mongodb/terraform-provider-mongodbatlas/version" + + "golang.org/x/oauth2" ) const ( @@ -41,41 +42,12 @@ const ( type AuthMethod int const ( - ServiceAccount AuthMethod = iota - Digest + Unknown AuthMethod = iota AccessToken - Unknown + ServiceAccount + Digest ) -// CredentialProvider interface for types that can provide MongoDB Atlas credentials -type CredentialProvider interface { - GetPublicKey() string - GetPrivateKey() string - GetClientID() string - GetClientSecret() string - GetAccessToken() string -} - -// IsDigestAuth checks if public/private key credentials are present -func IsDigestAuthPresent(cp CredentialProvider) bool { - return cp.GetPublicKey() != "" && cp.GetPrivateKey() != "" -} - -// IsServiceAccountAuth checks if client ID/secret credentials are present -func IsServiceAccountAuthPresent(cp CredentialProvider) bool { - return cp.GetClientID() != "" && cp.GetClientSecret() != "" -} - -// IsAccessTokenAuth checks if access token credentials are present -func IsAccessTokenAuthPresent(cp CredentialProvider) bool { - return cp.GetAccessToken() != "" -} - -// HasValidAuthCredentials checks if any valid authentication method is provided -func HasValidAuthCredentials(cp CredentialProvider) bool { - return IsDigestAuthPresent(cp) || IsServiceAccountAuthPresent(cp) || IsAccessTokenAuthPresent(cp) -} - var baseTransport = &http.Transport{ DialContext: (&net.Dialer{ Timeout: timeout, @@ -101,74 +73,33 @@ func tfLoggingInterceptor(base http.RoundTripper) http.RoundTripper { // MongoDBClient contains the mongodbatlas clients and configurations type MongoDBClient struct { - Atlas *matlasClient.Client - AtlasV2 *admin.APIClient - AtlasPreview *adminpreview.APIClient - AtlasV220240805 *admin20240805.APIClient // used in advanced_cluster to avoid adopting 2024-10-23 release with ISS autoscaling - AtlasV220240530 *admin20240530.APIClient // used in advanced_cluster and cloud_backup_schedule for avoiding breaking changes (supporting deprecated replication_specs.id) - AtlasV220241113 *admin20241113.APIClient // used in teams and atlas_users to avoiding breaking changes - Config *Config -} - -// Config contains the configurations needed to use SDKs -type Config struct { - AssumeRoleARN string - PublicKey string - PrivateKey string - BaseURL string - RealmBaseURL string - TerraformVersion string - ClientID string - ClientSecret string - AccessToken string -} - -// CredentialProvider implementation for Config -func (c *Config) GetPublicKey() string { return c.PublicKey } -func (c *Config) GetPrivateKey() string { return c.PrivateKey } -func (c *Config) GetClientID() string { return c.ClientID } -func (c *Config) GetClientSecret() string { return c.ClientSecret } -func (c *Config) GetAccessToken() string { return c.AccessToken } - -type SecretData struct { - PublicKey string `json:"public_key"` - PrivateKey string `json:"private_key"` + Atlas *matlasClient.Client + AtlasV2 *admin.APIClient + AtlasPreview *adminpreview.APIClient + AtlasV220240805 *admin20240805.APIClient // used in advanced_cluster to avoid adopting 2024-10-23 release with ISS autoscaling + AtlasV220240530 *admin20240530.APIClient // used in advanced_cluster and cloud_backup_schedule for avoiding breaking changes (supporting deprecated replication_specs.id) + AtlasV220241113 *admin20241113.APIClient // used in teams and atlas_users to avoiding breaking changes + Realm *RealmClient + BaseURL string // needed by organization resource + TerraformVersion string // needed by organization resource } -type UAMetadata struct { - Name string - Value string +type RealmClient struct { + publicKey string + privateKey string + realmBaseURL string + terraformVersion string } -func (c *Config) NewClient(ctx context.Context) (*MongoDBClient, error) { - transport := networkLoggingBaseTransport() - switch ResolveAuthMethod(c) { - case AccessToken: - tokenSource := oauth2.StaticTokenSource(&oauth2.Token{ - AccessToken: c.AccessToken, - TokenType: "Bearer", // Use a static bearer token with oauth2 transport. - }) - transport = &oauth2.Transport{ - Source: tokenSource, - Base: networkLoggingBaseTransport(), - } - case ServiceAccount: - tokenSource, err := getTokenSource(c, networkLoggingBaseTransport()) - if err != nil { - return nil, err - } - transport = &oauth2.Transport{ - Source: tokenSource, - Base: networkLoggingBaseTransport(), - } - case Digest: - transport = digest.NewTransportWithHTTPRoundTripper(c.PublicKey, c.PrivateKey, networkLoggingBaseTransport()) - case Unknown: +func NewClient(c *Credentials, terraformVersion string) (*MongoDBClient, error) { + userAgent := userAgent(terraformVersion) + client, err := getHTTPClient(c) + if err != nil { + return nil, err } - client := &http.Client{Transport: tfLoggingInterceptor(transport)} // Initialize the old SDK - optsAtlas := []matlasClient.ClientOpt{matlasClient.SetUserAgent(userAgent(c))} + optsAtlas := []matlasClient.ClientOpt{matlasClient.SetUserAgent(userAgent)} if c.BaseURL != "" { optsAtlas = append(optsAtlas, matlasClient.SetBaseURL(c.BaseURL)) } @@ -178,124 +109,137 @@ func (c *Config) NewClient(ctx context.Context) (*MongoDBClient, error) { } // Initialize the new SDK for different versions - sdkV2Client, err := c.newSDKV2Client(client) + sdkV2Client, err := newSDKV2Client(client, c.BaseURL, userAgent) if err != nil { return nil, err } - sdkPreviewClient, err := c.newSDKPreviewClient(client) + sdkPreviewClient, err := newSDKPreviewClient(client, c.BaseURL, userAgent) if err != nil { return nil, err } - sdkV220240530Client, err := c.newSDKV220240530Client(client) + sdkV220240530Client, err := newSDKV220240530Client(client, c.BaseURL, userAgent) if err != nil { return nil, err } - sdkV220240805Client, err := c.newSDKV220240805Client(client) + sdkV220240805Client, err := newSDKV220240805Client(client, c.BaseURL, userAgent) if err != nil { return nil, err } - sdkV220241113Client, err := c.newSDKV220241113Client(client) + sdkV220241113Client, err := newSDKV220241113Client(client, c.BaseURL, userAgent) if err != nil { return nil, err } + clients := &MongoDBClient{ - Atlas: atlasClient, - AtlasV2: sdkV2Client, - AtlasPreview: sdkPreviewClient, - AtlasV220240530: sdkV220240530Client, - AtlasV220240805: sdkV220240805Client, - AtlasV220241113: sdkV220241113Client, - Config: c, + Atlas: atlasClient, + AtlasV2: sdkV2Client, + AtlasPreview: sdkPreviewClient, + AtlasV220240530: sdkV220240530Client, + AtlasV220240805: sdkV220240805Client, + AtlasV220241113: sdkV220241113Client, + BaseURL: c.BaseURL, + TerraformVersion: terraformVersion, + Realm: &RealmClient{ + publicKey: c.PublicKey, + privateKey: c.PrivateKey, + realmBaseURL: c.RealmBaseURL, + terraformVersion: terraformVersion, + }, } return clients, nil } -func (c *Config) newSDKV2Client(client *http.Client) (*admin.APIClient, error) { - opts := []admin.ClientModifier{ - admin.UseHTTPClient(client), - admin.UseUserAgent(userAgent(c)), - admin.UseBaseURL(c.BaseURL), - admin.UseDebug(false)} - - sdk, err := admin.NewClient(opts...) - if err != nil { - return nil, err +func getHTTPClient(c *Credentials) (*http.Client, error) { + transport := networkLoggingBaseTransport() + switch c.AuthMethod() { + case AccessToken: + tokenSource := oauth2.StaticTokenSource(&oauth2.Token{ + AccessToken: c.AccessToken, + TokenType: "Bearer", // Use a static bearer token with oauth2 transport. + }) + transport = &oauth2.Transport{ + Source: tokenSource, + Base: networkLoggingBaseTransport(), + } + case ServiceAccount: + tokenSource, err := getTokenSource(c.ClientID, c.ClientSecret, c.BaseURL, networkLoggingBaseTransport()) + if err != nil { + return nil, err + } + transport = &oauth2.Transport{ + Source: tokenSource, + Base: networkLoggingBaseTransport(), + } + case Digest: + transport = digest.NewTransportWithHTTPRoundTripper(c.PublicKey, c.PrivateKey, networkLoggingBaseTransport()) + case Unknown: } - return sdk, nil + return &http.Client{Transport: tfLoggingInterceptor(transport)}, nil } -func (c *Config) newSDKPreviewClient(client *http.Client) (*adminpreview.APIClient, error) { - opts := []adminpreview.ClientModifier{ - adminpreview.UseHTTPClient(client), - adminpreview.UseUserAgent(userAgent(c)), - adminpreview.UseBaseURL(c.BaseURL), - adminpreview.UseDebug(false)} +func newSDKV2Client(client *http.Client, baseURL, userAgent string) (*admin.APIClient, error) { + return admin.NewClient( + admin.UseHTTPClient(client), + admin.UseUserAgent(userAgent), + admin.UseBaseURL(baseURL), + admin.UseDebug(false), + ) +} - sdk, err := adminpreview.NewClient(opts...) - if err != nil { - return nil, err - } - return sdk, nil +func newSDKPreviewClient(client *http.Client, baseURL, userAgent string) (*adminpreview.APIClient, error) { + return adminpreview.NewClient( + adminpreview.UseHTTPClient(client), + adminpreview.UseUserAgent(userAgent), + adminpreview.UseBaseURL(baseURL), + adminpreview.UseDebug(false), + ) } -func (c *Config) newSDKV220240530Client(client *http.Client) (*admin20240530.APIClient, error) { - opts := []admin20240530.ClientModifier{ +func newSDKV220240530Client(client *http.Client, baseURL, userAgent string) (*admin20240530.APIClient, error) { + return admin20240530.NewClient( admin20240530.UseHTTPClient(client), - admin20240530.UseUserAgent(userAgent(c)), - admin20240530.UseBaseURL(c.BaseURL), - admin20240530.UseDebug(false)} - - sdk, err := admin20240530.NewClient(opts...) - if err != nil { - return nil, err - } - return sdk, nil + admin20240530.UseUserAgent(userAgent), + admin20240530.UseBaseURL(baseURL), + admin20240530.UseDebug(false), + ) } -func (c *Config) newSDKV220240805Client(client *http.Client) (*admin20240805.APIClient, error) { - opts := []admin20240805.ClientModifier{ +func newSDKV220240805Client(client *http.Client, baseURL, userAgent string) (*admin20240805.APIClient, error) { + return admin20240805.NewClient( admin20240805.UseHTTPClient(client), - admin20240805.UseUserAgent(userAgent(c)), - admin20240805.UseBaseURL(c.BaseURL), - admin20240805.UseDebug(false)} - - sdk, err := admin20240805.NewClient(opts...) - if err != nil { - return nil, err - } - return sdk, nil + admin20240805.UseUserAgent(userAgent), + admin20240805.UseBaseURL(baseURL), + admin20240805.UseDebug(false), + ) } -func (c *Config) newSDKV220241113Client(client *http.Client) (*admin20241113.APIClient, error) { - opts := []admin20241113.ClientModifier{ +func newSDKV220241113Client(client *http.Client, baseURL, userAgent string) (*admin20241113.APIClient, error) { + return admin20241113.NewClient( admin20241113.UseHTTPClient(client), - admin20241113.UseUserAgent(userAgent(c)), - admin20241113.UseBaseURL(c.BaseURL), - admin20241113.UseDebug(false)} - - sdk, err := admin20241113.NewClient(opts...) - if err != nil { - return nil, err - } - return sdk, nil + admin20241113.UseUserAgent(userAgent), + admin20241113.UseBaseURL(baseURL), + admin20241113.UseDebug(false), + ) } -func (c *MongoDBClient) GetRealmClient(ctx context.Context) (*realm.Client, error) { - // Realm - if c.Config.PublicKey == "" && c.Config.PrivateKey == "" { +// Get in RealmClient is a method instead of Atlas fields so it's lazy initialized as it needs a roundtrip to authenticate. +func (r *RealmClient) Get(ctx context.Context) (*realm.Client, error) { + if r.publicKey == "" && r.privateKey == "" { return nil, errors.New("please set `public_key` and `private_key` in order to use the realm client") } - optsRealm := []realm.ClientOpt{realm.SetUserAgent(userAgent(c.Config))} + optsRealm := []realm.ClientOpt{ + realm.SetUserAgent(userAgent(r.terraformVersion)), + } authConfig := realmAuth.NewConfig(nil) - if c.Config.BaseURL != "" && c.Config.RealmBaseURL != "" { - adminURL := c.Config.RealmBaseURL + "api/admin/v3.0/" + if r.realmBaseURL != "" { + adminURL := r.realmBaseURL + "api/admin/v3.0/" optsRealm = append(optsRealm, realm.SetBaseURL(adminURL)) authConfig.AuthURL, _ = url.Parse(adminURL + "auth/providers/mongodb-cloud/login") } - token, err := authConfig.NewTokenFromCredentials(ctx, c.Config.PublicKey, c.Config.PrivateKey) + token, err := authConfig.NewTokenFromCredentials(ctx, r.publicKey, r.privateKey) if err != nil { return nil, err } @@ -359,30 +303,21 @@ func (c *MongoDBClient) UntypedAPICall(ctx context.Context, params *APICallParam return apiResp, err } -func userAgent(c *Config) string { - metadata := []UAMetadata{ +func userAgent(terraformVersion string) string { + metadata := []struct { + Name string + Value string + }{ {toolName, version.ProviderVersion}, - {terraformPlatformName, c.TerraformVersion}, + {terraformPlatformName, terraformVersion}, } var parts []string for _, info := range metadata { + if info.Value == "" { + continue + } part := fmt.Sprintf("%s/%s", info.Name, info.Value) parts = append(parts, part) } - return strings.Join(parts, " ") } - -// ResolveAuthMethod determines the authentication method from any credential provider -func ResolveAuthMethod(cg CredentialProvider) AuthMethod { - if IsAccessTokenAuthPresent(cg) { - return AccessToken - } - if IsServiceAccountAuthPresent(cg) { - return ServiceAccount - } - if IsDigestAuthPresent(cg) { - return Digest - } - return Unknown -} diff --git a/internal/config/credentials.go b/internal/config/credentials.go new file mode 100644 index 0000000000..14adc7bdb4 --- /dev/null +++ b/internal/config/credentials.go @@ -0,0 +1,185 @@ +package config + +import ( + "os" + + "github.com/mongodb/terraform-provider-mongodbatlas/internal/common/conversion" +) + +// Credentials has all the authentication fields, it also matches with fields that can be stored in AWS Secrets Manager. +type Credentials struct { + AccessToken string `json:"access_token"` + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` + PublicKey string `json:"public_key"` + PrivateKey string `json:"private_key"` + BaseURL string `json:"base_url"` + RealmBaseURL string `json:"realm_base_url"` +} + +// GetCredentials follows the order of AWS Secrets Manager, provider vars and env vars. +func GetCredentials(providerVars, envVars *Vars, getAWSCredentials func(*AWSVars) (*Credentials, error)) (*Credentials, error) { + if awsVars := CoalesceAWSVars(providerVars.GetAWS(), envVars.GetAWS()); awsVars != nil { + awsCredentials, err := getAWSCredentials(awsVars) + if err != nil { + return nil, err + } + return awsCredentials, nil + } + if c := CoalesceCredentials(providerVars.GetCredentials(), envVars.GetCredentials()); c != nil { + return c, nil + } + return &Credentials{}, nil +} + +// AuthMethod follows the order of token, SA and PAK. +func (c *Credentials) AuthMethod() AuthMethod { + switch { + case c.HasAccessToken(): + return AccessToken + case c.HasServiceAccount(): + return ServiceAccount + case c.HasDigest(): + return Digest + default: + return Unknown + } +} + +func (c *Credentials) HasAccessToken() bool { + return c.AccessToken != "" +} + +func (c *Credentials) HasServiceAccount() bool { + return c.ClientID != "" || c.ClientSecret != "" +} + +func (c *Credentials) HasDigest() bool { + return c.PublicKey != "" || c.PrivateKey != "" +} + +func (c *Credentials) IsPresent() bool { + return c.AuthMethod() != Unknown +} + +func (c *Credentials) Warnings() string { + if !c.IsPresent() { + return "No credentials set" + } + // Prefer specific checks over generic code as there are few combinations and code is clearer. + if c.HasAccessToken() && c.HasServiceAccount() && c.HasDigest() { + return "Access Token will be used although Service Account and API Keys are also set" + } + if c.HasAccessToken() && c.HasServiceAccount() { + return "Access Token will be used although Service Account is also set" + } + if c.HasAccessToken() && c.HasDigest() { + return "Access Token will be used although API Keys is also set" + } + if c.HasServiceAccount() && c.HasDigest() { + return "Service Account will be used although API Keys is also set" + } + return "" +} + +type AWSVars struct { + AssumeRoleARN string + SecretName string + Region string + AccessKeyID string + SecretAccessKey string + SessionToken string + Endpoint string +} + +func (a *AWSVars) IsPresent() bool { + return a.AssumeRoleARN != "" +} + +type Vars struct { + AccessToken string + ClientID string + ClientSecret string + PublicKey string + PrivateKey string + BaseURL string + RealmBaseURL string + AWSAssumeRoleARN string + AWSSecretName string + AWSRegion string + AWSAccessKeyID string + AWSSecretAccessKey string + AWSSessionToken string + AWSEndpoint string +} + +func NewEnvVars() *Vars { + return &Vars{ + AccessToken: getEnv("MONGODB_ATLAS_ACCESS_TOKEN", "TF_VAR_ACCESS_TOKEN"), + ClientID: getEnv("MONGODB_ATLAS_CLIENT_ID", "TF_VAR_CLIENT_ID"), + ClientSecret: getEnv("MONGODB_ATLAS_CLIENT_SECRET", "TF_VAR_CLIENT_SECRET"), + PublicKey: getEnv("MONGODB_ATLAS_PUBLIC_API_KEY", "MONGODB_ATLAS_PUBLIC_KEY", "MCLI_PUBLIC_API_KEY"), + PrivateKey: getEnv("MONGODB_ATLAS_PRIVATE_API_KEY", "MONGODB_ATLAS_PRIVATE_KEY", "MCLI_PRIVATE_API_KEY"), + BaseURL: getEnv("MONGODB_ATLAS_BASE_URL", "MCLI_OPS_MANAGER_URL"), + RealmBaseURL: getEnv("MONGODB_REALM_BASE_URL"), + AWSAssumeRoleARN: getEnv("ASSUME_ROLE_ARN", "TF_VAR_ASSUME_ROLE_ARN"), + AWSSecretName: getEnv("SECRET_NAME", "TF_VAR_SECRET_NAME"), + AWSRegion: getEnv("AWS_REGION", "TF_VAR_AWS_REGION"), + AWSAccessKeyID: getEnv("AWS_ACCESS_KEY_ID", "TF_VAR_AWS_ACCESS_KEY_ID"), + AWSSecretAccessKey: getEnv("AWS_SECRET_ACCESS_KEY", "TF_VAR_AWS_SECRET_ACCESS_KEY"), + AWSSessionToken: getEnv("AWS_SESSION_TOKEN", "TF_VAR_AWS_SESSION_TOKEN"), + AWSEndpoint: getEnv("STS_ENDPOINT", "TF_VAR_STS_ENDPOINT"), + } +} + +func (e *Vars) GetCredentials() *Credentials { + return &Credentials{ + AccessToken: e.AccessToken, + ClientID: e.ClientID, + ClientSecret: e.ClientSecret, + PublicKey: e.PublicKey, + PrivateKey: e.PrivateKey, + BaseURL: e.BaseURL, + RealmBaseURL: e.RealmBaseURL, + } +} + +// GetAWS returns variables in the format AWS expects, e.g. region in lowercase. +func (e *Vars) GetAWS() *AWSVars { + return &AWSVars{ + AssumeRoleARN: e.AWSAssumeRoleARN, + SecretName: e.AWSSecretName, + Region: conversion.MongoDBRegionToAWSRegion(e.AWSRegion), + AccessKeyID: e.AWSAccessKeyID, + SecretAccessKey: e.AWSSecretAccessKey, + SessionToken: e.AWSSessionToken, + Endpoint: e.AWSEndpoint, + } +} + +func getEnv(key ...string) string { + for _, k := range key { + if v := os.Getenv(k); v != "" { + return v + } + } + return "" +} + +func CoalesceAWSVars(awsVars ...*AWSVars) *AWSVars { + for _, awsVar := range awsVars { + if awsVar.IsPresent() { + return awsVar + } + } + return nil +} + +func CoalesceCredentials(credentials ...*Credentials) *Credentials { + for _, credential := range credentials { + if credential.IsPresent() { + return credential + } + } + return nil +} diff --git a/internal/config/credentials_test.go b/internal/config/credentials_test.go new file mode 100644 index 0000000000..86641189eb --- /dev/null +++ b/internal/config/credentials_test.go @@ -0,0 +1,504 @@ +package config_test + +import ( + "errors" + "testing" + + "github.com/mongodb/terraform-provider-mongodbatlas/internal/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCredentials_AuthMethod(t *testing.T) { + testCases := map[string]struct { + credentials config.Credentials + want config.AuthMethod + }{ + "Empty credentials returns Unknown": { + credentials: config.Credentials{}, + want: config.Unknown, + }, + "Access token takes priority": { + credentials: config.Credentials{ + AccessToken: "token", + ClientID: "id", + ClientSecret: "secret", + PublicKey: "public", + PrivateKey: "private", + }, + want: config.AccessToken, + }, + "Service account when no access token": { + credentials: config.Credentials{ + ClientID: "id", + ClientSecret: "secret", + PublicKey: "public", + PrivateKey: "private", + }, + want: config.ServiceAccount, + }, + "Service account with only ClientID": { + credentials: config.Credentials{ + ClientID: "id", + }, + want: config.ServiceAccount, + }, + "Service account with only ClientSecret": { + credentials: config.Credentials{ + ClientSecret: "secret", + }, + want: config.ServiceAccount, + }, + "Digest when only digest credentials": { + credentials: config.Credentials{ + PublicKey: "public", + PrivateKey: "private", + }, + want: config.Digest, + }, + "Digest with only PublicKey": { + credentials: config.Credentials{ + PublicKey: "public", + }, + want: config.Digest, + }, + "Digest with only PrivateKey": { + credentials: config.Credentials{ + PrivateKey: "private", + }, + want: config.Digest, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + got := tc.credentials.AuthMethod() + assert.Equal(t, tc.want, got) + }) + } +} + +func TestCredentials_HasAccessToken(t *testing.T) { + testCases := map[string]struct { + credentials config.Credentials + want bool + }{ + "Empty credentials": { + credentials: config.Credentials{}, + want: false, + }, + "With access token": { + credentials: config.Credentials{ + AccessToken: "token", + }, + want: true, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + got := tc.credentials.HasAccessToken() + assert.Equal(t, tc.want, got) + }) + } +} + +func TestCredentials_HasServiceAccount(t *testing.T) { + testCases := map[string]struct { + credentials config.Credentials + want bool + }{ + "Empty credentials": { + credentials: config.Credentials{}, + want: false, + }, + "With ClientID only": { + credentials: config.Credentials{ + ClientID: "id", + }, + want: true, + }, + "With ClientSecret only": { + credentials: config.Credentials{ + ClientSecret: "secret", + }, + want: true, + }, + "With both ClientID and ClientSecret": { + credentials: config.Credentials{ + ClientID: "id", + ClientSecret: "secret", + }, + want: true, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + got := tc.credentials.HasServiceAccount() + assert.Equal(t, tc.want, got) + }) + } +} + +func TestCredentials_HasDigest(t *testing.T) { + testCases := map[string]struct { + credentials config.Credentials + want bool + }{ + "Empty credentials": { + credentials: config.Credentials{}, + want: false, + }, + "With PublicKey only": { + credentials: config.Credentials{ + PublicKey: "public", + }, + want: true, + }, + "With PrivateKey only": { + credentials: config.Credentials{ + PrivateKey: "private", + }, + want: true, + }, + "With both PublicKey and PrivateKey": { + credentials: config.Credentials{ + PublicKey: "public", + PrivateKey: "private", + }, + want: true, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + got := tc.credentials.HasDigest() + assert.Equal(t, tc.want, got) + }) + } +} + +func TestCredentials_IsPresent(t *testing.T) { + testCases := map[string]struct { + credentials config.Credentials + want bool + }{ + "Empty credentials": { + credentials: config.Credentials{}, + want: false, + }, + "With access token": { + credentials: config.Credentials{ + AccessToken: "token", + }, + want: true, + }, + "With service account": { + credentials: config.Credentials{ + ClientID: "id", + }, + want: true, + }, + "With digest": { + credentials: config.Credentials{ + PublicKey: "public", + }, + want: true, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + got := tc.credentials.IsPresent() + assert.Equal(t, tc.want, got) + }) + } +} + +func TestCredentials_Warnings(t *testing.T) { + testCases := map[string]struct { + credentials config.Credentials + want string + }{ + "No credentials": { + credentials: config.Credentials{}, + want: "No credentials set", + }, + "Only access token - no warning": { + credentials: config.Credentials{ + AccessToken: "token", + }, + want: "", + }, + "Only service account - no warning": { + credentials: config.Credentials{ + ClientID: "id", + }, + want: "", + }, + "Only digest - no warning": { + credentials: config.Credentials{ + PublicKey: "public", + }, + want: "", + }, + "Access token and service account": { + credentials: config.Credentials{ + AccessToken: "token", + ClientID: "id", + ClientSecret: "secret", + }, + want: "Access Token will be used although Service Account is also set", + }, + "Access token and digest": { + credentials: config.Credentials{ + AccessToken: "token", + PublicKey: "public", + PrivateKey: "private", + }, + want: "Access Token will be used although API Keys is also set", + }, + "Service account and digest": { + credentials: config.Credentials{ + ClientID: "id", + PublicKey: "public", + PrivateKey: "private", + }, + want: "Service Account will be used although API Keys is also set", + }, + "All three methods": { + credentials: config.Credentials{ + AccessToken: "token", + ClientID: "id", + ClientSecret: "secret", + PublicKey: "public", + PrivateKey: "private", + }, + want: "Access Token will be used although Service Account and API Keys are also set", + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + got := tc.credentials.Warnings() + assert.Equal(t, tc.want, got) + }) + } +} + +func TestGetCredentials(t *testing.T) { + mockGetAWSCredentials := func(awsVars *config.AWSVars) (*config.Credentials, error) { + if awsVars.AssumeRoleARN == "error" { + return nil, errors.New("AWS error") + } + return &config.Credentials{ + AccessToken: "aws-token", + }, nil + } + + testCases := map[string]struct { + providerVars *config.Vars + envVars *config.Vars + want *config.Credentials + wantErr bool + }{ + "AWS credentials take priority": { + providerVars: &config.Vars{ + AWSAssumeRoleARN: "arn", + PublicKey: "provider-public", + }, + envVars: &config.Vars{ + PublicKey: "env-public", + }, + want: &config.Credentials{ + AccessToken: "aws-token", + }, + wantErr: false, + }, + "AWS credentials error": { + providerVars: &config.Vars{ + AWSAssumeRoleARN: "error", + }, + envVars: &config.Vars{}, + want: nil, + wantErr: true, + }, + "Provider vars take priority over env vars": { + providerVars: &config.Vars{ + PublicKey: "provider-public", + }, + envVars: &config.Vars{ + PublicKey: "env-public", + }, + want: &config.Credentials{ + PublicKey: "provider-public", + }, + wantErr: false, + }, + "Env vars when no provider vars": { + providerVars: &config.Vars{}, + envVars: &config.Vars{ + PublicKey: "env-public", + }, + want: &config.Credentials{ + PublicKey: "env-public", + }, + wantErr: false, + }, + "Empty credentials when nothing provided": { + providerVars: &config.Vars{}, + envVars: &config.Vars{}, + want: &config.Credentials{}, + wantErr: false, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + got, err := config.GetCredentials(tc.providerVars, tc.envVars, mockGetAWSCredentials) + if tc.wantErr { + assert.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tc.want, got) + } + }) + } +} + +func TestAWSVars_IsPresent(t *testing.T) { + testCases := map[string]struct { + awsVars *config.AWSVars + want bool + }{ + "Empty AWS vars": { + awsVars: &config.AWSVars{}, + want: false, + }, + "With AssumeRoleARN": { + awsVars: &config.AWSVars{ + AssumeRoleARN: "arn", + }, + want: true, + }, + "With other fields but no AssumeRoleARN": { + awsVars: &config.AWSVars{ + SecretName: "secret", + Region: "us-east-1", + }, + want: false, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + got := tc.awsVars.IsPresent() + assert.Equal(t, tc.want, got) + }) + } +} + +func TestNewEnvVars(t *testing.T) { + // Test the first env var for each attribute. + t.Setenv("MONGODB_ATLAS_ACCESS_TOKEN", "env-token") + t.Setenv("MONGODB_ATLAS_CLIENT_ID", "env-client-id") + t.Setenv("MONGODB_ATLAS_CLIENT_SECRET", "env-client-secret") + t.Setenv("MONGODB_ATLAS_PUBLIC_API_KEY", "env-public") + t.Setenv("MONGODB_ATLAS_PRIVATE_API_KEY", "env-private") + t.Setenv("MONGODB_ATLAS_BASE_URL", "url1") + t.Setenv("MONGODB_REALM_BASE_URL", "url2") + t.Setenv("ASSUME_ROLE_ARN", "arn") + t.Setenv("SECRET_NAME", "env-secret") + t.Setenv("AWS_REGION", "us-west-2") + t.Setenv("AWS_ACCESS_KEY_ID", "env-access") + t.Setenv("AWS_SECRET_ACCESS_KEY", "env-secret-key") + t.Setenv("AWS_SESSION_TOKEN", "env-token") + t.Setenv("STS_ENDPOINT", "https://sts.amazonaws.com") + + vars := config.NewEnvVars() + assert.Equal(t, "env-token", vars.AccessToken) + assert.Equal(t, "env-client-id", vars.ClientID) + assert.Equal(t, "env-client-secret", vars.ClientSecret) + assert.Equal(t, "env-public", vars.PublicKey) + assert.Equal(t, "env-private", vars.PrivateKey) + assert.Equal(t, "url1", vars.BaseURL) + assert.Equal(t, "url2", vars.RealmBaseURL) + assert.Equal(t, "arn", vars.AWSAssumeRoleARN) + assert.Equal(t, "env-secret", vars.AWSSecretName) + assert.Equal(t, "us-west-2", vars.AWSRegion) + assert.Equal(t, "env-access", vars.AWSAccessKeyID) + assert.Equal(t, "env-secret-key", vars.AWSSecretAccessKey) + assert.Equal(t, "env-token", vars.AWSSessionToken) + assert.Equal(t, "https://sts.amazonaws.com", vars.AWSEndpoint) +} + +func TestCoalesceAWSVars(t *testing.T) { + awsVars1 := &config.AWSVars{AssumeRoleARN: "arn1"} + awsVars2 := &config.AWSVars{AssumeRoleARN: "arn2"} + awsVarsEmpty := &config.AWSVars{} + + testCases := map[string]struct { + want *config.AWSVars + awsVars []*config.AWSVars + }{ + "First present AWS vars": { + awsVars: []*config.AWSVars{awsVars1, awsVars2}, + want: awsVars1, + }, + "Skip empty, return first present": { + awsVars: []*config.AWSVars{awsVarsEmpty, awsVars2}, + want: awsVars2, + }, + "All empty returns nil": { + awsVars: []*config.AWSVars{awsVarsEmpty, awsVarsEmpty}, + want: nil, + }, + "No vars returns nil": { + awsVars: []*config.AWSVars{}, + want: nil, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + got := config.CoalesceAWSVars(tc.awsVars...) + assert.Equal(t, tc.want, got) + }) + } +} + +func TestCoalesceCredentials(t *testing.T) { + creds1 := &config.Credentials{PublicKey: "key1"} + creds2 := &config.Credentials{PublicKey: "key2"} + credsEmpty := &config.Credentials{} + + testCases := map[string]struct { + want *config.Credentials + credentials []*config.Credentials + }{ + "First present credentials": { + credentials: []*config.Credentials{creds1, creds2}, + want: creds1, + }, + "Skip empty, return first present": { + credentials: []*config.Credentials{credsEmpty, creds2}, + want: creds2, + }, + "All empty returns nil": { + credentials: []*config.Credentials{credsEmpty, credsEmpty}, + want: nil, + }, + "No credentials returns nil": { + credentials: []*config.Credentials{}, + want: nil, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + got := config.CoalesceCredentials(tc.credentials...) + assert.Equal(t, tc.want, got) + }) + } +} diff --git a/internal/config/service_account.go b/internal/config/service_account.go index 6a214e9b17..84b41cc64e 100644 --- a/internal/config/service_account.go +++ b/internal/config/service_account.go @@ -24,20 +24,20 @@ var saInfo = struct { mu sync.Mutex }{} -func getTokenSource(c *Config, tokenRenewalBase http.RoundTripper) (auth.TokenSource, error) { +func getTokenSource(clientID, clientSecret, baseURL string, tokenRenewalBase http.RoundTripper) (auth.TokenSource, error) { saInfo.mu.Lock() defer saInfo.mu.Unlock() + baseURL = strings.TrimRight(baseURL, "/") if saInfo.tokenSource != nil { // Token source in cache. - if saInfo.clientID != c.ClientID || saInfo.clientSecret != c.ClientSecret || saInfo.baseURL != c.BaseURL { + if saInfo.clientID != clientID || saInfo.clientSecret != clientSecret || saInfo.baseURL != baseURL { return nil, fmt.Errorf("service account credentials changed") } return saInfo.tokenSource, nil } - conf := clientcredentials.NewConfig(c.ClientID, c.ClientSecret) - if c.BaseURL != "" { - baseURL := strings.TrimRight(c.BaseURL, "/") + conf := clientcredentials.NewConfig(clientID, clientSecret) + if baseURL != "" { conf.TokenURL = baseURL + clientcredentials.TokenAPIPath conf.RevokeURL = baseURL + clientcredentials.RevokeAPIPath } @@ -47,9 +47,9 @@ func getTokenSource(c *Config, tokenRenewalBase http.RoundTripper) (auth.TokenSo if _, err := tokenSource.Token(); err != nil { // Retrieve token to fail-fast if credentials are invalid. return nil, err } - saInfo.clientID = c.ClientID - saInfo.clientSecret = c.ClientSecret - saInfo.baseURL = c.BaseURL + saInfo.clientID = clientID + saInfo.clientSecret = clientSecret + saInfo.baseURL = baseURL saInfo.tokenSource = tokenSource return saInfo.tokenSource, nil } diff --git a/internal/config/transport_test.go b/internal/config/transport_test.go index b718ac1b72..2c54b5d77f 100644 --- a/internal/config/transport_test.go +++ b/internal/config/transport_test.go @@ -159,17 +159,17 @@ func TestAccNetworkLogging(t *testing.T) { var logOutput bytes.Buffer log.SetOutput(&logOutput) defer log.SetOutput(os.Stderr) - cfg := &config.Config{ + c := &config.Credentials{ PublicKey: os.Getenv("MONGODB_ATLAS_PUBLIC_KEY"), PrivateKey: os.Getenv("MONGODB_ATLAS_PRIVATE_KEY"), ClientID: os.Getenv("MONGODB_ATLAS_CLIENT_ID"), ClientSecret: os.Getenv("MONGODB_ATLAS_CLIENT_SECRET"), BaseURL: os.Getenv("MONGODB_ATLAS_BASE_URL"), } - client, err := cfg.NewClient(t.Context()) + client, err := config.NewClient(c, "") require.NoError(t, err) - // Make a simple API call that should trigger our enhanced logging + // Make a simple API call that should trigger our enhanced logging. _, _, err = client.AtlasV2.OrganizationsApi.ListOrgs(t.Context()).Execute() require.NoError(t, err) logStr := logOutput.String() diff --git a/internal/provider/aws_credentials.go b/internal/provider/aws_credentials.go index 335589d681..34ceaf93fb 100644 --- a/internal/provider/aws_credentials.go +++ b/internal/provider/aws_credentials.go @@ -25,63 +25,34 @@ const ( minSegmentsForSTSRegionalHost = 4 ) -func configureCredentialsSTS(cfg *config.Config, secret, region, awsAccessKeyID, awsSecretAccessKey, awsSessionToken, endpoint string) (config.Config, error) { +func getAWSCredentials(c *config.AWSVars) (*config.Credentials, error) { defaultResolver := endpoints.DefaultResolver() stsCustResolverFn := func(service, _ string, optFns ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) { if service == sts.EndpointsID { - resolved, err := ResolveSTSEndpoint(endpoint, region) + resolved, err := ResolveSTSEndpoint(c.Endpoint, c.Region) if err != nil { return endpoints.ResolvedEndpoint{}, err } return resolved, nil } - return defaultResolver.EndpointFor(service, region, optFns...) + return defaultResolver.EndpointFor(service, c.Region, optFns...) } - sess := session.Must(session.NewSession(&aws.Config{ - Region: aws.String(region), - Credentials: credentials.NewStaticCredentials(awsAccessKeyID, awsSecretAccessKey, awsSessionToken), + Region: aws.String(c.Region), + Credentials: credentials.NewStaticCredentials(c.AccessKeyID, c.SecretAccessKey, c.SessionToken), EndpointResolver: endpoints.ResolverFunc(stsCustResolverFn), })) - - creds := stscreds.NewCredentials(sess, cfg.AssumeRoleARN) - - _, err := sess.Config.Credentials.Get() + creds := stscreds.NewCredentials(sess, c.AssumeRoleARN) + secretString, err := secretsManagerGetSecretValue(sess, &aws.Config{Credentials: creds, Region: aws.String(c.Region)}, c.SecretName) if err != nil { - log.Printf("Session get credentials error: %s", err) - return *cfg, err + return nil, err } - _, err = creds.Get() + var secret config.Credentials + err = json.Unmarshal([]byte(secretString), &secret) if err != nil { - log.Printf("STS get credentials error: %s", err) - return *cfg, err + return nil, err } - secretString, err := secretsManagerGetSecretValue(sess, &aws.Config{Credentials: creds, Region: aws.String(region)}, secret) - if err != nil { - log.Printf("Get Secrets error: %s", err) - return *cfg, err - } - - var secretData SecretData - err = json.Unmarshal([]byte(secretString), &secretData) - if err != nil { - return *cfg, err - } - - switch config.ResolveAuthMethod(&secretData) { - case config.AccessToken: - cfg.AccessToken = secretData.AccessToken - case config.Digest: - cfg.PublicKey = secretData.PublicKey - cfg.PrivateKey = secretData.PrivateKey - case config.ServiceAccount: - cfg.ClientID = secretData.ClientID - cfg.ClientSecret = secretData.ClientSecret - case config.Unknown: - return *cfg, fmt.Errorf("secret missing value for supported credentials: PrivateKey/PublicKey, ClientID/ClientSecret or AccessToken") - } - - return *cfg, nil + return &secret, nil } func DeriveSTSRegionFromEndpoint(ep string) string { diff --git a/internal/provider/provider.go b/internal/provider/provider.go index 54da063b2b..3801398d08 100644 --- a/internal/provider/provider.go +++ b/internal/provider/provider.go @@ -3,12 +3,10 @@ package provider import ( "context" "log" - "os" + "slices" "github.com/hashicorp/terraform-plugin-framework-validators/listvalidator" - "github.com/hashicorp/terraform-plugin-framework/attr" "github.com/hashicorp/terraform-plugin-framework/datasource" - "github.com/hashicorp/terraform-plugin-framework/diag" "github.com/hashicorp/terraform-plugin-framework/provider" "github.com/hashicorp/terraform-plugin-framework/provider/metaschema" "github.com/hashicorp/terraform-plugin-framework/provider/schema" @@ -20,7 +18,6 @@ import ( "github.com/hashicorp/terraform-plugin-mux/tf5to6server" "github.com/hashicorp/terraform-plugin-mux/tf6muxserver" - "github.com/mongodb/terraform-provider-mongodbatlas/internal/common/conversion" "github.com/mongodb/terraform-provider-mongodbatlas/internal/config" "github.com/mongodb/terraform-provider-mongodbatlas/internal/service/advancedcluster" "github.com/mongodb/terraform-provider-mongodbatlas/internal/service/alertconfiguration" @@ -53,8 +50,7 @@ import ( ) const ( - MongodbGovCloudURL = "https://cloud.mongodbgov.com" - MongodbGovCloudQAURL = "https://cloud-qa.mongodbgov.com" + govURL = "https://cloud.mongodbgov.com" MongodbGovCloudDevURL = "https://cloud-dev.mongodbgov.com" ProviderConfigError = "error in configuring the provider." MissingAuthAttrError = "either AWS Secrets Manager, Service Accounts or Atlas Programmatic API Keys attributes must be set" @@ -66,35 +62,38 @@ const ( ProviderMetaModuleVersionDesc = "The version of the module using the provider" ) +var ( + govAdditionalURLs = []string{ + "https://cloud-dev.mongodbgov.com", + "https://cloud-qa.mongodbgov.com", + } +) + type MongodbtlasProvider struct { } -type tfMongodbAtlasProviderModel struct { - AssumeRole types.List `tfsdk:"assume_role"` - Region types.String `tfsdk:"region"` - PrivateKey types.String `tfsdk:"private_key"` - BaseURL types.String `tfsdk:"base_url"` - RealmBaseURL types.String `tfsdk:"realm_base_url"` - SecretName types.String `tfsdk:"secret_name"` - PublicKey types.String `tfsdk:"public_key"` - StsEndpoint types.String `tfsdk:"sts_endpoint"` - AwsAccessKeyID types.String `tfsdk:"aws_access_key_id"` - AwsSecretAccessKeyID types.String `tfsdk:"aws_secret_access_key"` - AwsSessionToken types.String `tfsdk:"aws_session_token"` - ClientID types.String `tfsdk:"client_id"` - ClientSecret types.String `tfsdk:"client_secret"` - AccessToken types.String `tfsdk:"access_token"` - IsMongodbGovCloud types.Bool `tfsdk:"is_mongodbgov_cloud"` +type tfModel struct { + Region types.String `tfsdk:"region"` + PrivateKey types.String `tfsdk:"private_key"` + BaseURL types.String `tfsdk:"base_url"` + RealmBaseURL types.String `tfsdk:"realm_base_url"` + SecretName types.String `tfsdk:"secret_name"` + PublicKey types.String `tfsdk:"public_key"` + StsEndpoint types.String `tfsdk:"sts_endpoint"` + AwsAccessKeyID types.String `tfsdk:"aws_access_key_id"` + AwsSecretAccessKeyID types.String `tfsdk:"aws_secret_access_key"` + AwsSessionToken types.String `tfsdk:"aws_session_token"` + ClientID types.String `tfsdk:"client_id"` + ClientSecret types.String `tfsdk:"client_secret"` + AccessToken types.String `tfsdk:"access_token"` + AssumeRole []tfAssumeRoleModel `tfsdk:"assume_role"` + IsMongodbGovCloud types.Bool `tfsdk:"is_mongodbgov_cloud"` } type tfAssumeRoleModel struct { RoleARN types.String `tfsdk:"role_arn"` } -var AssumeRoleType = types.ObjectType{AttrTypes: map[string]attr.Type{ - "role_arn": types.StringType, -}} - func (p *MongodbtlasProvider) Metadata(ctx context.Context, req provider.MetadataRequest, resp *provider.MetadataResponse) { resp.TypeName = "mongodbatlas" resp.Version = version.ProviderVersion @@ -200,194 +199,57 @@ var fwAssumeRoleSchema = schema.ListNestedBlock{ } func (p *MongodbtlasProvider) Configure(ctx context.Context, req provider.ConfigureRequest, resp *provider.ConfigureResponse) { - var data tfMongodbAtlasProviderModel - - resp.Diagnostics.Append(req.Config.Get(ctx, &data)...) + providerVars := getProviderVars(ctx, req, resp) if resp.Diagnostics.HasError() { return } - - data = setDefaultValuesWithValidations(ctx, &data, resp) - if resp.Diagnostics.HasError() { + c, err := config.GetCredentials(providerVars, config.NewEnvVars(), getAWSCredentials) + if err != nil { + resp.Diagnostics.AddError("Error getting credentials for provider", err.Error()) return } - - cfg := config.Config{ - PublicKey: data.PublicKey.ValueString(), - PrivateKey: data.PrivateKey.ValueString(), - BaseURL: data.BaseURL.ValueString(), - RealmBaseURL: data.RealmBaseURL.ValueString(), - TerraformVersion: req.TerraformVersion, - ClientID: data.ClientID.ValueString(), - ClientSecret: data.ClientSecret.ValueString(), - AccessToken: data.AccessToken.ValueString(), - } - - var assumeRoles []tfAssumeRoleModel - data.AssumeRole.ElementsAs(ctx, &assumeRoles, true) - awsRoleDefined := len(assumeRoles) > 0 - if awsRoleDefined { - cfg.AssumeRoleARN = assumeRoles[0].RoleARN.ValueString() - secret := data.SecretName.ValueString() - region := conversion.MongoDBRegionToAWSRegion(data.Region.ValueString()) - awsAccessKeyID := data.AwsAccessKeyID.ValueString() - awsSecretAccessKey := data.AwsSecretAccessKeyID.ValueString() - awsSessionToken := data.AwsSessionToken.ValueString() - endpoint := data.StsEndpoint.ValueString() - var err error - cfg, err = configureCredentialsSTS(&cfg, secret, region, awsAccessKeyID, awsSecretAccessKey, awsSessionToken, endpoint) - if err != nil { - resp.Diagnostics.AddError("failed to configure credentials STS", err.Error()) - return - } + if c.Warnings() != "" { + resp.Diagnostics.AddWarning("Warning getting credentials for provider", c.Warnings()) } - - client, err := cfg.NewClient(ctx) - + client, err := config.NewClient(c, req.TerraformVersion) if err != nil { - resp.Diagnostics.AddError( - "failed to initialize a new client", - err.Error(), - ) + resp.Diagnostics.AddError("Error initializing provider", err.Error()) return } - resp.DataSourceData = client resp.ResourceData = client } -func setDefaultValuesWithValidations(ctx context.Context, data *tfMongodbAtlasProviderModel, resp *provider.ConfigureResponse) tfMongodbAtlasProviderModel { - if mongodbgovCloud := data.IsMongodbGovCloud.ValueBool(); mongodbgovCloud { - if !isGovBaseURLConfiguredForProvider(data) { - data.BaseURL = types.StringValue(MongodbGovCloudURL) - } - } - if data.BaseURL.ValueString() == "" { - data.BaseURL = types.StringValue(MultiEnvDefaultFunc([]string{ - "MONGODB_ATLAS_BASE_URL", - "MCLI_OPS_MANAGER_URL", - }, "").(string)) - } - - awsRoleDefined := false - if len(data.AssumeRole.Elements()) == 0 { - assumeRoleArn := MultiEnvDefaultFunc([]string{ - "ASSUME_ROLE_ARN", - "TF_VAR_ASSUME_ROLE_ARN", - }, "").(string) - if assumeRoleArn != "" { - awsRoleDefined = true - var diags diag.Diagnostics - data.AssumeRole, diags = types.ListValueFrom(ctx, AssumeRoleType, []tfAssumeRoleModel{ - { - RoleARN: types.StringValue(assumeRoleArn), - }, - }) - if diags.HasError() { - resp.Diagnostics.Append(diags...) - } - } - } else { - awsRoleDefined = true - } - - if data.PublicKey.ValueString() == "" { - data.PublicKey = types.StringValue(MultiEnvDefaultFunc([]string{ - "MONGODB_ATLAS_PUBLIC_API_KEY", - "MONGODB_ATLAS_PUBLIC_KEY", - "MCLI_PUBLIC_API_KEY", - }, "").(string)) - } - - if data.PrivateKey.ValueString() == "" { - data.PrivateKey = types.StringValue(MultiEnvDefaultFunc([]string{ - "MONGODB_ATLAS_PRIVATE_API_KEY", - "MONGODB_ATLAS_PRIVATE_KEY", - "MCLI_PRIVATE_API_KEY", - }, "").(string)) - } - - if data.RealmBaseURL.ValueString() == "" { - data.RealmBaseURL = types.StringValue(MultiEnvDefaultFunc([]string{ - "MONGODB_REALM_BASE_URL", - }, "").(string)) - } - - if data.Region.ValueString() == "" { - data.Region = types.StringValue(MultiEnvDefaultFunc([]string{ - "AWS_REGION", - "TF_VAR_AWS_REGION", - }, "").(string)) - } - - if data.StsEndpoint.ValueString() == "" { - data.StsEndpoint = types.StringValue(MultiEnvDefaultFunc([]string{ - "STS_ENDPOINT", - "TF_VAR_STS_ENDPOINT", - }, "").(string)) - } - - if data.AwsAccessKeyID.ValueString() == "" { - data.AwsAccessKeyID = types.StringValue(MultiEnvDefaultFunc([]string{ - "AWS_ACCESS_KEY_ID", - "TF_VAR_AWS_ACCESS_KEY_ID", - }, "").(string)) - } - - if data.AwsSecretAccessKeyID.ValueString() == "" { - data.AwsSecretAccessKeyID = types.StringValue(MultiEnvDefaultFunc([]string{ - "AWS_SECRET_ACCESS_KEY", - "TF_VAR_AWS_SECRET_ACCESS_KEY", - }, "").(string)) - } - - if data.AwsSessionToken.ValueString() == "" { - data.AwsSessionToken = types.StringValue(MultiEnvDefaultFunc([]string{ - "AWS_SESSION_TOKEN", - "TF_VAR_AWS_SESSION_TOKEN", - }, "").(string)) - } - - if data.SecretName.ValueString() == "" { - data.SecretName = types.StringValue(MultiEnvDefaultFunc([]string{ - "SECRET_NAME", - "TF_VAR_SECRET_NAME", - }, "").(string)) - } - - if data.ClientID.ValueString() == "" { - data.ClientID = types.StringValue(MultiEnvDefaultFunc([]string{ - "MONGODB_ATLAS_CLIENT_ID", - "TF_VAR_CLIENT_ID", - }, "").(string)) +func getProviderVars(ctx context.Context, req provider.ConfigureRequest, resp *provider.ConfigureResponse) *config.Vars { + var data tfModel + resp.Diagnostics.Append(req.Config.Get(ctx, &data)...) + if resp.Diagnostics.HasError() { + return nil } - - if data.ClientSecret.ValueString() == "" { - data.ClientSecret = types.StringValue(MultiEnvDefaultFunc([]string{ - "MONGODB_ATLAS_CLIENT_SECRET", - "TF_VAR_CLIENT_SECRET", - }, "").(string)) + assumeRoleARN := "" + if len(data.AssumeRole) > 0 { + assumeRoleARN = data.AssumeRole[0].RoleARN.ValueString() } - - if data.AccessToken.ValueString() == "" { - data.AccessToken = types.StringValue(MultiEnvDefaultFunc([]string{ - "MONGODB_ATLAS_OAUTH_TOKEN", - "TF_VAR_OAUTH_TOKEN", - }, "").(string)) + baseURL := data.BaseURL.ValueString() + if data.IsMongodbGovCloud.ValueBool() && !slices.Contains(govAdditionalURLs, baseURL) { + baseURL = govURL } - - // Check if any valid authentication method is provided - if !config.HasValidAuthCredentials(&config.Config{ - PublicKey: data.PublicKey.ValueString(), - PrivateKey: data.PrivateKey.ValueString(), - ClientID: data.ClientID.ValueString(), - ClientSecret: data.ClientSecret.ValueString(), - AccessToken: data.AccessToken.ValueString(), - }) && !awsRoleDefined { - resp.Diagnostics.AddError(ProviderConfigError, MissingAuthAttrError) + return &config.Vars{ + AccessToken: data.AccessToken.ValueString(), + ClientID: data.ClientID.ValueString(), + ClientSecret: data.ClientSecret.ValueString(), + PublicKey: data.PublicKey.ValueString(), + PrivateKey: data.PrivateKey.ValueString(), + BaseURL: baseURL, + RealmBaseURL: data.RealmBaseURL.ValueString(), + AWSAssumeRoleARN: assumeRoleARN, + AWSSecretName: data.SecretName.ValueString(), + AWSRegion: data.Region.ValueString(), + AWSAccessKeyID: data.AwsAccessKeyID.ValueString(), + AWSSecretAccessKey: data.AwsSecretAccessKeyID.ValueString(), + AWSSessionToken: data.AwsSessionToken.ValueString(), + AWSEndpoint: data.StsEndpoint.ValueString(), } - - return *data } func (p *MongodbtlasProvider) DataSources(context.Context) []func() datasource.DataSource { @@ -486,26 +348,3 @@ func MuxProviderFactory() func() tfprotov6.ProviderServer { } return muxServer.ProviderServer } - -func MultiEnvDefaultFunc(ks []string, def any) any { - for _, k := range ks { - if v := os.Getenv(k); v != "" { - return v - } - } - return def -} - -func isGovBaseURLConfigured(baseURL string) bool { - if baseURL == "" { - baseURL = MultiEnvDefaultFunc([]string{ - "MONGODB_ATLAS_BASE_URL", - "MCLI_OPS_MANAGER_URL", - }, "").(string) - } - return baseURL == MongodbGovCloudDevURL || baseURL == MongodbGovCloudQAURL -} - -func isGovBaseURLConfiguredForProvider(data *tfMongodbAtlasProviderModel) bool { - return isGovBaseURLConfigured(data.BaseURL.ValueString()) -} diff --git a/internal/provider/provider_sdk2.go b/internal/provider/provider_sdk2.go index 283b2c447d..0259167d13 100644 --- a/internal/provider/provider_sdk2.go +++ b/internal/provider/provider_sdk2.go @@ -2,11 +2,12 @@ package provider import ( "context" + "fmt" + "slices" "github.com/hashicorp/terraform-plugin-sdk/v2/diag" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" - "github.com/mongodb/terraform-provider-mongodbatlas/internal/common/conversion" "github.com/mongodb/terraform-provider-mongodbatlas/internal/config" "github.com/mongodb/terraform-provider-mongodbatlas/internal/service/accesslistapikey" "github.com/mongodb/terraform-provider-mongodbatlas/internal/service/apikey" @@ -52,21 +53,6 @@ import ( "github.com/mongodb/terraform-provider-mongodbatlas/internal/service/x509authenticationdatabaseuser" ) -type SecretData struct { - PublicKey string `json:"public_key"` - PrivateKey string `json:"private_key"` - ClientID string `json:"client_id"` - ClientSecret string `json:"client_secret"` - AccessToken string `json:"access_token"` -} - -// CredentialProvider implementation for SecretData -func (s *SecretData) GetPublicKey() string { return s.PublicKey } -func (s *SecretData) GetPrivateKey() string { return s.PrivateKey } -func (s *SecretData) GetClientID() string { return s.ClientID } -func (s *SecretData) GetClientSecret() string { return s.ClientSecret } -func (s *SecretData) GetAccessToken() string { return s.AccessToken } - // NewSdkV2Provider returns the provider to be use by the code. func NewSdkV2Provider() *schema.Provider { provider := &schema.Provider{ @@ -169,6 +155,23 @@ func NewSdkV2Provider() *schema.Provider { return provider } +func assumeRoleSchema() *schema.Schema { + return &schema.Schema{ + Type: schema.TypeList, + Optional: true, + MaxItems: 1, + Elem: &schema.Resource{ + Schema: map[string]*schema.Schema{ + "role_arn": { + Type: schema.TypeString, + Optional: true, + Description: "Amazon Resource Name (ARN) of an IAM Role to assume prior to making API calls.", + }, + }, + }, + } +} + func getDataSourcesMap() map[string]*schema.Resource { dataSourcesMap := map[string]*schema.Resource{ "mongodbatlas_custom_db_role": customdbrole.DataSource(), @@ -293,218 +296,47 @@ func getResourcesMap() map[string]*schema.Resource { func providerConfigure(provider *schema.Provider) func(ctx context.Context, d *schema.ResourceData) (any, diag.Diagnostics) { return func(ctx context.Context, d *schema.ResourceData) (any, diag.Diagnostics) { - diagnostics := setDefaultsAndValidations(d) - if diagnostics.HasError() { - return nil, diagnostics - } - - cfg := config.Config{ - PublicKey: d.Get("public_key").(string), - PrivateKey: d.Get("private_key").(string), - BaseURL: d.Get("base_url").(string), - RealmBaseURL: d.Get("realm_base_url").(string), - TerraformVersion: provider.TerraformVersion, - ClientID: d.Get("client_id").(string), - ClientSecret: d.Get("client_secret").(string), - AccessToken: d.Get("access_token").(string), - } - - assumeRoleValue, ok := d.GetOk("assume_role") - awsRoleDefined := ok && len(assumeRoleValue.([]any)) > 0 && assumeRoleValue.([]any)[0] != nil - if awsRoleDefined { - cfg.AssumeRoleARN = getAssumeRoleARN(assumeRoleValue.([]any)[0].(map[string]any)) - secret := d.Get("secret_name").(string) - region := conversion.MongoDBRegionToAWSRegion(d.Get("region").(string)) - awsAccessKeyID := d.Get("aws_access_key_id").(string) - awsSecretAccessKey := d.Get("aws_secret_access_key").(string) - awsSessionToken := d.Get("aws_session_token").(string) - endpoint := d.Get("sts_endpoint").(string) - var err error - cfg, err = configureCredentialsSTS(&cfg, secret, region, awsAccessKeyID, awsSecretAccessKey, awsSessionToken, endpoint) - if err != nil { - return nil, append(diagnostics, diag.FromErr(err)...) - } + var diags diag.Diagnostics + providerVars := getSDKv2ProviderVars(d) + c, err := config.GetCredentials(providerVars, config.NewEnvVars(), getAWSCredentials) + if err != nil { + return nil, append(diags, diag.FromErr(fmt.Errorf("error getting credentials for provider: %w", err))...) } - - client, err := cfg.NewClient(ctx) + // Don't log possible warnings as they will be logged by the TPF provider. + client, err := config.NewClient(c, provider.TerraformVersion) if err != nil { - return nil, append(diagnostics, diag.FromErr(err)...) + return nil, append(diags, diag.FromErr(fmt.Errorf("error initializing provider: %w", err))...) } - return client, diagnostics + return client, nil } } -func setDefaultsAndValidations(d *schema.ResourceData) diag.Diagnostics { - diagnostics := []diag.Diagnostic{} - - mongodbgovCloud := conversion.Pointer(d.Get("is_mongodbgov_cloud").(bool)) - if *mongodbgovCloud { - if !isGovBaseURLConfiguredForSDK2Provider(d) { - if err := d.Set("base_url", MongodbGovCloudURL); err != nil { - return append(diagnostics, diag.FromErr(err)...) - } - } - } - - if err := setValueFromConfigOrEnv(d, "base_url", []string{ - "MONGODB_ATLAS_BASE_URL", - "MCLI_OPS_MANAGER_URL", - }); err != nil { - return append(diagnostics, diag.FromErr(err)...) - } - - awsRoleDefined := false +func getSDKv2ProviderVars(d *schema.ResourceData) *config.Vars { + assumeRoleARN := "" assumeRoles := d.Get("assume_role").([]any) - if len(assumeRoles) == 0 { - roleArn := MultiEnvDefaultFunc([]string{ - "ASSUME_ROLE_ARN", - "TF_VAR_ASSUME_ROLE_ARN", - }, "").(string) - if roleArn != "" { - awsRoleDefined = true - if err := d.Set("assume_role", []map[string]any{{"role_arn": roleArn}}); err != nil { - return append(diagnostics, diag.FromErr(err)...) - } + if len(assumeRoles) > 0 { + if assumeRole, ok := assumeRoles[0].(map[string]any); ok { + assumeRoleARN = assumeRole["role_arn"].(string) } - } else { - awsRoleDefined = true } - - if err := setValueFromConfigOrEnv(d, "public_key", []string{ - "MONGODB_ATLAS_PUBLIC_API_KEY", - "MONGODB_ATLAS_PUBLIC_KEY", - "MCLI_PUBLIC_API_KEY", - }); err != nil { - return append(diagnostics, diag.FromErr(err)...) - } - - if err := setValueFromConfigOrEnv(d, "private_key", []string{ - "MONGODB_ATLAS_PRIVATE_API_KEY", - "MONGODB_ATLAS_PRIVATE_KEY", - "MCLI_PRIVATE_API_KEY", - }); err != nil { - return append(diagnostics, diag.FromErr(err)...) - } - - if err := setValueFromConfigOrEnv(d, "realm_base_url", []string{ - "MONGODB_REALM_BASE_URL", - }); err != nil { - return append(diagnostics, diag.FromErr(err)...) + baseURL := d.Get("base_url").(string) + if d.Get("is_mongodbgov_cloud").(bool) && !slices.Contains(govAdditionalURLs, baseURL) { + baseURL = govURL } - - if err := setValueFromConfigOrEnv(d, "region", []string{ - "AWS_REGION", - "TF_VAR_AWS_REGION", - }); err != nil { - return append(diagnostics, diag.FromErr(err)...) - } - - if err := setValueFromConfigOrEnv(d, "sts_endpoint", []string{ - "STS_ENDPOINT", - "TF_VAR_STS_ENDPOINT", - }); err != nil { - return append(diagnostics, diag.FromErr(err)...) - } - - if err := setValueFromConfigOrEnv(d, "aws_access_key_id", []string{ - "AWS_ACCESS_KEY_ID", - "TF_VAR_AWS_ACCESS_KEY_ID", - }); err != nil { - return append(diagnostics, diag.FromErr(err)...) + return &config.Vars{ + AccessToken: d.Get("access_token").(string), + ClientID: d.Get("client_id").(string), + ClientSecret: d.Get("client_secret").(string), + PublicKey: d.Get("public_key").(string), + PrivateKey: d.Get("private_key").(string), + BaseURL: baseURL, + RealmBaseURL: d.Get("realm_base_url").(string), + AWSAssumeRoleARN: assumeRoleARN, + AWSSecretName: d.Get("secret_name").(string), + AWSRegion: d.Get("region").(string), + AWSAccessKeyID: d.Get("aws_access_key_id").(string), + AWSSecretAccessKey: d.Get("aws_secret_access_key").(string), + AWSSessionToken: d.Get("aws_session_token").(string), + AWSEndpoint: d.Get("sts_endpoint").(string), } - - if err := setValueFromConfigOrEnv(d, "aws_secret_access_key", []string{ - "AWS_SECRET_ACCESS_KEY", - "TF_VAR_AWS_SECRET_ACCESS_KEY", - }); err != nil { - return append(diagnostics, diag.FromErr(err)...) - } - - if err := setValueFromConfigOrEnv(d, "secret_name", []string{ - "SECRET_NAME", - "TF_VAR_SECRET_NAME", - }); err != nil { - return append(diagnostics, diag.FromErr(err)...) - } - - if err := setValueFromConfigOrEnv(d, "aws_session_token", []string{ - "AWS_SESSION_TOKEN", - "TF_VAR_AWS_SESSION_TOKEN", - }); err != nil { - return append(diagnostics, diag.FromErr(err)...) - } - - if err := setValueFromConfigOrEnv(d, "client_id", []string{ - "MONGODB_ATLAS_CLIENT_ID", - "TF_VAR_CLIENT_ID", - }); err != nil { - return append(diagnostics, diag.FromErr(err)...) - } - - if err := setValueFromConfigOrEnv(d, "client_secret", []string{ - "MONGODB_ATLAS_CLIENT_SECRET", - "TF_VAR_CLIENT_SECRET", - }); err != nil { - return append(diagnostics, diag.FromErr(err)...) - } - - if err := setValueFromConfigOrEnv(d, "access_token", []string{ - "MONGODB_ATLAS_OAUTH_TOKEN", - "TF_VAR_OAUTH_TOKEN", - }); err != nil { - return append(diagnostics, diag.FromErr(err)...) - } - - // Check if any valid authentication method is provided - if !config.HasValidAuthCredentials(&config.Config{ - PublicKey: d.Get("public_key").(string), - PrivateKey: d.Get("private_key").(string), - ClientID: d.Get("client_id").(string), - ClientSecret: d.Get("client_secret").(string), - AccessToken: d.Get("access_token").(string), - }) && !awsRoleDefined { - diagnostics = append(diagnostics, diag.Diagnostic{Severity: diag.Error, Summary: MissingAuthAttrError}) - } - - return diagnostics -} - -func setValueFromConfigOrEnv(d *schema.ResourceData, attrName string, envVars []string) error { - var val = d.Get(attrName).(string) - if val == "" { - val = MultiEnvDefaultFunc(envVars, "").(string) - } - return d.Set(attrName, val) -} - -// assumeRoleSchema From aws provider.go -func assumeRoleSchema() *schema.Schema { - return &schema.Schema{ - Type: schema.TypeList, - Optional: true, - MaxItems: 1, - Elem: &schema.Resource{ - Schema: map[string]*schema.Schema{ - "role_arn": { - Type: schema.TypeString, - Optional: true, - Description: "Amazon Resource Name (ARN) of an IAM Role to assume prior to making API calls.", - }, - }, - }, - } -} - -func getAssumeRoleARN(tfMap map[string]any) string { - if tfMap == nil { - return "" - } - if v, ok := tfMap["role_arn"].(string); ok && v != "" { - return v - } - return "" -} - -func isGovBaseURLConfiguredForSDK2Provider(d *schema.ResourceData) bool { - return isGovBaseURLConfigured(d.Get("base_url").(string)) } diff --git a/internal/service/eventtrigger/data_source_event_trigger.go b/internal/service/eventtrigger/data_source_event_trigger.go index 5cf8c318ff..bac40a4d0a 100644 --- a/internal/service/eventtrigger/data_source_event_trigger.go +++ b/internal/service/eventtrigger/data_source_event_trigger.go @@ -133,7 +133,7 @@ func DataSource() *schema.Resource { } func dataSourceMongoDBAtlasEventTriggerRead(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { - conn, err := meta.(*config.MongoDBClient).GetRealmClient(ctx) + conn, err := meta.(*config.MongoDBClient).Realm.Get(ctx) if err != nil { return diag.FromErr(err) } diff --git a/internal/service/eventtrigger/data_source_event_triggers.go b/internal/service/eventtrigger/data_source_event_triggers.go index 120b42bdca..02295a331d 100644 --- a/internal/service/eventtrigger/data_source_event_triggers.go +++ b/internal/service/eventtrigger/data_source_event_triggers.go @@ -144,9 +144,8 @@ func PluralDataSource() *schema.Resource { } func dataSourceMongoDBAtlasEventTriggersRead(d *schema.ResourceData, meta any) error { - // Get client connection. ctx := context.Background() - conn, err := meta.(*config.MongoDBClient).GetRealmClient(ctx) + conn, err := meta.(*config.MongoDBClient).Realm.Get(ctx) if err != nil { return err } diff --git a/internal/service/eventtrigger/resource_event_trigger.go b/internal/service/eventtrigger/resource_event_trigger.go index b1e99aee41..1a90c8f3ab 100644 --- a/internal/service/eventtrigger/resource_event_trigger.go +++ b/internal/service/eventtrigger/resource_event_trigger.go @@ -210,7 +210,7 @@ func Resource() *schema.Resource { } func resourceMongoDBAtlasEventTriggersCreate(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { - conn, err := meta.(*config.MongoDBClient).GetRealmClient(ctx) + conn, err := meta.(*config.MongoDBClient).Realm.Get(ctx) if err != nil { return diag.FromErr(err) } @@ -312,7 +312,7 @@ func resourceMongoDBAtlasEventTriggersCreate(ctx context.Context, d *schema.Reso } func resourceMongoDBAtlasEventTriggersRead(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { - conn, err := meta.(*config.MongoDBClient).GetRealmClient(ctx) + conn, err := meta.(*config.MongoDBClient).Realm.Get(ctx) if err != nil { return diag.FromErr(err) } @@ -402,7 +402,7 @@ func resourceMongoDBAtlasEventTriggersRead(ctx context.Context, d *schema.Resour } func resourceMongoDBAtlasEventTriggersUpdate(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { - conn, err := meta.(*config.MongoDBClient).GetRealmClient(ctx) + conn, err := meta.(*config.MongoDBClient).Realm.Get(ctx) if err != nil { return diag.FromErr(err) } @@ -453,8 +453,7 @@ func resourceMongoDBAtlasEventTriggersUpdate(ctx context.Context, d *schema.Reso } func resourceMongoDBAtlasEventTriggersDelete(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { - // Get the client connection. - conn, err := meta.(*config.MongoDBClient).GetRealmClient(ctx) + conn, err := meta.(*config.MongoDBClient).Realm.Get(ctx) if err != nil { return diag.FromErr(err) } @@ -515,7 +514,7 @@ func flattenTriggerEventProcessorAWSEventBridge(eventProcessor map[string]any) [ } func resourceMongoDBAtlasEventTriggerImportState(ctx context.Context, d *schema.ResourceData, meta any) ([]*schema.ResourceData, error) { - conn, err := meta.(*config.MongoDBClient).GetRealmClient(ctx) + conn, err := meta.(*config.MongoDBClient).Realm.Get(ctx) if err != nil { return nil, err } diff --git a/internal/service/eventtrigger/resource_event_trigger_test.go b/internal/service/eventtrigger/resource_event_trigger_test.go index 30f2c2734d..865690c6a7 100644 --- a/internal/service/eventtrigger/resource_event_trigger_test.go +++ b/internal/service/eventtrigger/resource_event_trigger_test.go @@ -484,7 +484,7 @@ func TestAccEventTrigger_functionBasic(t *testing.T) { func checkExists(resourceName string) resource.TestCheckFunc { return func(s *terraform.State) error { ctx := context.Background() - conn, err := acc.MongoDBClient.GetRealmClient(ctx) + conn, err := acc.MongoDBClient.Realm.Get(ctx) if err != nil { return err } @@ -513,7 +513,7 @@ func checkExists(resourceName string) resource.TestCheckFunc { func checkDestroy(s *terraform.State) error { ctx := context.Background() - conn, err := acc.MongoDBClient.GetRealmClient(ctx) + conn, err := acc.MongoDBClient.Realm.Get(ctx) if err != nil { return err } diff --git a/internal/service/organization/resource_organization.go b/internal/service/organization/resource_organization.go index d8cf0c5080..3cd342f7f0 100644 --- a/internal/service/organization/resource_organization.go +++ b/internal/service/organization/resource_organization.go @@ -113,7 +113,7 @@ func resourceCreate(ctx context.Context, d *schema.ResourceData, meta any) diag. if err := ValidateAPIKeyIsOrgOwner(conversion.ExpandStringList(d.Get("role_names").(*schema.Set).List())); err != nil { return diag.FromErr(err) } - conn := getAtlasV2Connection(ctx, d, meta) // Using provider credentials. + conn := getAtlasV2Connection(d, meta) // Using provider credentials. organization, resp, err := conn.OrganizationsApi.CreateOrg(ctx, newCreateOrganizationRequest(d)).Execute() if err != nil { if validate.StatusNotFound(resp) && !strings.Contains(err.Error(), "USER_NOT_FOUND") { @@ -128,7 +128,7 @@ func resourceCreate(ctx context.Context, d *schema.ResourceData, meta any) diag. if err := d.Set("public_key", organization.ApiKey.GetPublicKey()); err != nil { return diag.FromErr(fmt.Errorf("error setting `public_key`: %s", err)) } - conn = getAtlasV2Connection(ctx, d, meta) // Using new credentials from the created organization. + conn = getAtlasV2Connection(d, meta) // Using new credentials from the created organization. orgID := organization.Organization.GetId() _, _, errUpdate := conn.OrganizationsApi.UpdateOrgSettings(ctx, orgID, newOrganizationSettings(d)).Execute() if errUpdate != nil { @@ -146,7 +146,7 @@ func resourceCreate(ctx context.Context, d *schema.ResourceData, meta any) diag. } func resourceRead(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { - conn := getAtlasV2Connection(ctx, d, meta) + conn := getAtlasV2Connection(d, meta) ids := conversion.DecodeStateID(d.Id()) orgID := ids["org_id"] @@ -194,7 +194,7 @@ func resourceRead(ctx context.Context, d *schema.ResourceData, meta any) diag.Di } func resourceUpdate(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { - conn := getAtlasV2Connection(ctx, d, meta) + conn := getAtlasV2Connection(d, meta) ids := conversion.DecodeStateID(d.Id()) orgID := ids["org_id"] for _, attr := range attrsCreateOnly { @@ -227,7 +227,7 @@ func resourceUpdate(ctx context.Context, d *schema.ResourceData, meta any) diag. } func resourceDelete(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { - conn := getAtlasV2Connection(ctx, d, meta) + conn := getAtlasV2Connection(d, meta) ids := conversion.DecodeStateID(d.Id()) orgID := ids["org_id"] @@ -293,18 +293,21 @@ func ValidateAPIKeyIsOrgOwner(roles []string) error { // getAtlasV2Connection uses the created credentials for the organization if they exist. // Otherwise, it uses the provider credentials, e.g. if the resource was imported. -func getAtlasV2Connection(ctx context.Context, d *schema.ResourceData, meta any) *admin.APIClient { +func getAtlasV2Connection(d *schema.ResourceData, meta any) *admin.APIClient { + currentClient := meta.(*config.MongoDBClient) publicKey := d.Get("public_key").(string) privateKey := d.Get("private_key").(string) if publicKey == "" || privateKey == "" { - return meta.(*config.MongoDBClient).AtlasV2 + return currentClient.AtlasV2 } - cfg := config.Config{ - PublicKey: publicKey, - PrivateKey: privateKey, - BaseURL: meta.(*config.MongoDBClient).Config.BaseURL, - TerraformVersion: meta.(*config.MongoDBClient).Config.TerraformVersion, + c := &config.Credentials{ + PublicKey: publicKey, + PrivateKey: privateKey, + BaseURL: currentClient.BaseURL, } - clients, _ := cfg.NewClient(ctx) - return clients.AtlasV2 + newClient, err := config.NewClient(c, currentClient.TerraformVersion) + if err != nil { + return currentClient.AtlasV2 + } + return newClient.AtlasV2 } diff --git a/internal/service/organization/resource_organization_test.go b/internal/service/organization/resource_organization_test.go index c862b6e672..7402c219c0 100644 --- a/internal/service/organization/resource_organization_test.go +++ b/internal/service/organization/resource_organization_test.go @@ -427,18 +427,16 @@ func getTestClientWithNewOrgCreds(rs *terraform.ResourceState) (*admin.APIClient if rs.Primary.Attributes["public_key"] == "" { return nil, fmt.Errorf("no public_key is set") } - if rs.Primary.Attributes["private_key"] == "" { return nil, fmt.Errorf("no private_key is set") } - - cfg := config.Config{ + c := &config.Credentials{ PublicKey: rs.Primary.Attributes["public_key"], PrivateKey: rs.Primary.Attributes["private_key"], - BaseURL: acc.MongoDBClient.Config.BaseURL, + BaseURL: acc.MongoDBClient.BaseURL, } - clients, _ := cfg.NewClient(context.Background()) - return clients.AtlasV2, nil + client, _ := config.NewClient(c, acc.MongoDBClient.TerraformVersion) + return client.AtlasV2, nil } func TestValidateAPIKeyIsOrgOwner(t *testing.T) { diff --git a/internal/testutil/acc/factory.go b/internal/testutil/acc/factory.go index 29f572540c..e654ec914c 100644 --- a/internal/testutil/acc/factory.go +++ b/internal/testutil/acc/factory.go @@ -1,7 +1,6 @@ package acc import ( - "context" "os" matlas "go.mongodb.org/atlas/mongodbatlas" @@ -42,12 +41,12 @@ func ConnV220241113() *admin20241113.APIClient { } func ConnV2UsingGov() *admin.APIClient { - cfg := config.Config{ + c := &config.Credentials{ PublicKey: os.Getenv("MONGODB_ATLAS_GOV_PUBLIC_KEY"), PrivateKey: os.Getenv("MONGODB_ATLAS_GOV_PRIVATE_KEY"), BaseURL: os.Getenv("MONGODB_ATLAS_GOV_BASE_URL"), } - client, _ := cfg.NewClient(context.Background()) + client, _ := config.NewClient(c, "") return client.AtlasV2 } @@ -57,7 +56,7 @@ func init() { return provider.MuxProviderFactory()(), nil }, } - cfg := config.Config{ + c := &config.Credentials{ PublicKey: os.Getenv("MONGODB_ATLAS_PUBLIC_KEY"), PrivateKey: os.Getenv("MONGODB_ATLAS_PRIVATE_KEY"), ClientID: os.Getenv("MONGODB_ATLAS_CLIENT_ID"), @@ -65,5 +64,5 @@ func init() { BaseURL: os.Getenv("MONGODB_ATLAS_BASE_URL"), RealmBaseURL: os.Getenv("MONGODB_REALM_BASE_URL"), } - MongoDBClient, _ = cfg.NewClient(context.Background()) + MongoDBClient, _ = config.NewClient(c, "") } diff --git a/internal/testutil/acc/pre_check.go b/internal/testutil/acc/pre_check.go index 8d98c621b5..0bc99d38b7 100644 --- a/internal/testutil/acc/pre_check.go +++ b/internal/testutil/acc/pre_check.go @@ -347,7 +347,7 @@ func PreCheckAwsMsk(tb testing.TB) { func PreCheckAccessToken(tb testing.TB) { tb.Helper() - if os.Getenv("MONGODB_ATLAS_OAUTH_TOKEN") == "" { - tb.Fatal("`MONGODB_ATLAS_OAUTH_TOKEN` must be set for Atlas Access Token acceptance testing") + if os.Getenv("MONGODB_ATLAS_ACCESS_TOKEN") == "" { + tb.Fatal("`MONGODB_ATLAS_ACCESS_TOKEN` must be set for Atlas Access Token acceptance testing") } }