Skip to content

Commit 97e7fb3

Browse files
authored
Merge pull request #2621 from ConnorJC3/cherrypick-gca-fix-147
[CHERRYPICK release-1.47] Call sts:GetCallerIdentity Just in Time instead of blocking on startup
2 parents f39b6b7 + 5b9f2bc commit 97e7fb3

File tree

5 files changed

+82
-68
lines changed

5 files changed

+82
-68
lines changed

cmd/main.go

Lines changed: 1 addition & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,6 @@ import (
2323
"strings"
2424
"time"
2525

26-
"github.com/aws/aws-sdk-go-v2/config"
27-
"github.com/aws/aws-sdk-go-v2/service/sts"
2826
"github.com/kubernetes-sigs/aws-ebs-csi-driver/cmd/hooks"
2927
"github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/cloud"
3028
"github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/cloud/metadata"
@@ -181,30 +179,7 @@ func main() {
181179
region = md.GetRegion()
182180
}
183181

184-
var accountID string
185-
if options.Mode == driver.ControllerMode || options.Mode == driver.AllMode {
186-
cfg, err := config.LoadDefaultConfig(context.Background(), config.WithRegion(region))
187-
if err != nil {
188-
klog.ErrorS(err, "Failed to create AWS config for account ID retrieval")
189-
klog.FlushAndExit(klog.ExitFlushTimeout, 1)
190-
}
191-
192-
stsClient := sts.NewFromConfig(cfg)
193-
resp, err := stsClient.GetCallerIdentity(context.Background(), &sts.GetCallerIdentityInput{})
194-
if err != nil {
195-
klog.ErrorS(err, "Failed to get AWS account ID, HyperPod functionality may not work")
196-
// Continue without account ID - existing functionality should still work
197-
} else {
198-
accountID = *resp.Account
199-
klog.V(5).InfoS("Retrieved AWS account ID for HyperPod operations", "accountID", accountID)
200-
}
201-
}
202-
203-
cloud, err := cloud.NewCloud(region, accountID, options.AwsSdkDebugLog, options.UserAgentExtra, options.Batching, options.DeprecatedMetrics)
204-
if err != nil {
205-
klog.ErrorS(err, "failed to create cloud service")
206-
klog.FlushAndExit(klog.ExitFlushTimeout, 1)
207-
}
182+
cloud := cloud.NewCloud(region, options.AwsSdkDebugLog, options.UserAgentExtra, options.Batching, options.DeprecatedMetrics)
208183

209184
m, err := mounter.NewNodeMounter(options.WindowsHostProcess)
210185
if err != nil {

pkg/cloud/cloud.go

Lines changed: 71 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ import (
3333
"github.com/aws/aws-sdk-go-v2/service/ec2"
3434
"github.com/aws/aws-sdk-go-v2/service/ec2/types"
3535
"github.com/aws/aws-sdk-go-v2/service/sagemaker"
36+
"github.com/aws/aws-sdk-go-v2/service/sts"
3637
"github.com/aws/smithy-go"
3738
"github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/batcher"
3839
dm "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/cloud/devicemanager"
@@ -88,8 +89,9 @@ var (
8889
)
8990

9091
const (
91-
cacheForgetDelay = 1 * time.Hour
92-
volInitCacheForgetDelay = 6 * time.Hour
92+
cacheForgetDelay = 1 * time.Hour
93+
volInitCacheForgetDelay = 6 * time.Hour
94+
getCallerIdentityRetryDelay = 30 * time.Second
9395
)
9496

9597
// VolumeStatusInitializingState is const reported by EC2 DescribeVolumeStatus which AWS SDK does not have type for.
@@ -320,6 +322,7 @@ type batcherManager struct {
320322
}
321323

322324
type cloud struct {
325+
awsConfig aws.Config
323326
region string
324327
ec2 EC2API
325328
sm SageMakerAPI
@@ -331,18 +334,14 @@ type cloud struct {
331334
latestClientTokens expiringcache.ExpiringCache[string, int]
332335
volumeInitializations expiringcache.ExpiringCache[string, volumeInitialization]
333336
accountID string
337+
accountIDOnce sync.Once
334338
}
335339

336340
var _ Cloud = &cloud{}
337341

338342
// NewCloud returns a new instance of AWS cloud
339343
// It panics if the AWS SDK configuration cannot be loaded.
340-
func NewCloud(region string, accountID string, awsSdkDebugLog bool, userAgentExtra string, batching bool, deprecatedMetrics bool) (Cloud, error) {
341-
c := newEC2Cloud(region, accountID, awsSdkDebugLog, userAgentExtra, batching, deprecatedMetrics)
342-
return c, nil
343-
}
344-
345-
func newEC2Cloud(region string, accountID string, awsSdkDebugLog bool, userAgentExtra string, batchingEnabled bool, deprecatedMetrics bool) Cloud {
344+
func NewCloud(region string, awsSdkDebugLog bool, userAgentExtra string, batchingEnabled bool, deprecatedMetrics bool) Cloud {
346345
cfg, err := config.LoadDefaultConfig(context.Background(), config.WithRegion(region))
347346
if err != nil {
348347
panic(err)
@@ -386,11 +385,12 @@ func newEC2Cloud(region string, accountID string, awsSdkDebugLog bool, userAgent
386385

387386
var bm *batcherManager
388387
if batchingEnabled {
389-
klog.V(4).InfoS("newEC2Cloud: batching enabled")
388+
klog.V(4).InfoS("NewCloud: batching enabled")
390389
bm = newBatcherManager(svc)
391390
}
392391

393392
return &cloud{
393+
awsConfig: cfg,
394394
region: region,
395395
dm: dm.NewDeviceManager(),
396396
ec2: svc,
@@ -400,7 +400,6 @@ func newEC2Cloud(region string, accountID string, awsSdkDebugLog bool, userAgent
400400
vwp: vwp,
401401
likelyBadDeviceNames: expiringcache.New[string, sync.Map](cacheForgetDelay),
402402
latestClientTokens: expiringcache.New[string, int](cacheForgetDelay),
403-
accountID: accountID,
404403
volumeInitializations: expiringcache.New[string, volumeInitialization](volInitCacheForgetDelay),
405404
}
406405
}
@@ -997,7 +996,11 @@ func (c *cloud) attachDiskHyperPod(ctx context.Context, volumeID, nodeID string)
997996
klog.V(2).InfoS("AttachDisk: HyperPod node detected", "volumeID", volumeID, "nodeID", nodeID)
998997

999998
instanceID := getInstanceIDFromHyperPodNode(nodeID)
1000-
clusterArn := c.buildHyperPodClusterArn(nodeID)
999+
accountID, err := c.getAccountID(ctx)
1000+
if err != nil {
1001+
return "", fmt.Errorf("failed to get account ID: %w", err)
1002+
}
1003+
clusterArn := buildHyperPodClusterArn(nodeID, c.region, accountID)
10011004

10021005
klog.V(5).InfoS("HyperPod attachment details",
10031006
"volumeID", volumeID,
@@ -1025,7 +1028,7 @@ func (c *cloud) attachDiskHyperPod(ctx context.Context, volumeID, nodeID string)
10251028

10261029
// Wait for attachment completion
10271030
deviceName := aws.ToString(resp.DeviceName)
1028-
_, err := c.WaitForAttachmentState(
1031+
_, err = c.WaitForAttachmentState(
10291032
ctx,
10301033
types.VolumeAttachmentStateAttached,
10311034
volumeID,
@@ -1099,7 +1102,11 @@ func (c *cloud) detachDiskHyperPod(ctx context.Context, volumeID, nodeID string)
10991102
klog.V(2).InfoS("DetachDisk: HyperPod node detected", "volumeID", volumeID, "nodeID", nodeID)
11001103

11011104
instanceID := getInstanceIDFromHyperPodNode(nodeID)
1102-
clusterArn := c.buildHyperPodClusterArn(nodeID)
1105+
accountID, err := c.getAccountID(ctx)
1106+
if err != nil {
1107+
return fmt.Errorf("failed to get account ID: %w", err)
1108+
}
1109+
clusterArn := buildHyperPodClusterArn(nodeID, c.region, accountID)
11031110

11041111
klog.V(4).InfoS("HyperPod detachment details",
11051112
"volumeID", volumeID,
@@ -1114,7 +1121,7 @@ func (c *cloud) detachDiskHyperPod(ctx context.Context, volumeID, nodeID string)
11141121
}
11151122
klog.V(4).InfoS("Calling DetachClusterNodeVolumeInput", "input", input)
11161123

1117-
_, err := c.sm.DetachClusterNodeVolume(ctx, input)
1124+
_, err = c.sm.DetachClusterNodeVolume(ctx, input)
11181125
if err != nil {
11191126
if isAWSHyperPodErrorIncorrectState(err) ||
11201127
isAWSHyperPodErrorInvalidAttachmentNotFound(err) ||
@@ -1450,9 +1457,9 @@ func getInstanceIDFromHyperPodNode(nodeID string) string {
14501457
}
14511458

14521459
// buildHyperPodClusterArn returns the SageMaker HyperPod cluster ARN, in the form
// arn:aws:sagemaker:region:account:cluster/clusterID, for a HyperPod node.
//
// The cluster ID is the second "-"-separated field of nodeID, e.g. "abc123" in
// "hyperpod-abc123-i-1234567890abcdef0". This is only meaningful for HyperPod
// node IDs: if nodeID contains no "-" separator, the empty string is returned
// instead of panicking on an out-of-range index.
func buildHyperPodClusterArn(nodeID string, region string, accountID string) string {
	// Only the second field is needed, so stop splitting after it.
	parts := strings.SplitN(nodeID, "-", 3)
	if len(parts) < 2 {
		return ""
	}
	return fmt.Sprintf("arn:aws:sagemaker:%s:%s:cluster/%s", region, accountID, parts[1])
}
14571464

14581465
// For hyperpod node, AssociatedResource is in arn:aws:sagemaker:region:account:cluster/clusterID-instanceId format.
@@ -1916,6 +1923,54 @@ func (c *cloud) waitForVolume(ctx context.Context, volumeID string) error {
19161923
return err
19171924
}
19181925

1926+
// getAccountID returns the account ID of the AWS Account for the IAM credentials in use.
1927+
//
1928+
// In the first call (or any calls made before the first call succeeds), getAccountID
1929+
// will attempt to determine the Account ID via sts:GetCallerIdentity.
1930+
// This attempt will retry indefinitely, however getAccountID will return when ctx is cancelled,
1931+
// leaving the account ID thread to run in the background.
1932+
//
1933+
// In subsequent calls (after the first success), getAccountID will use a cached value.
1934+
func (c *cloud) getAccountID(ctx context.Context) (string, error) {
1935+
accountIDRetrieved := make(chan struct{}, 1)
1936+
1937+
// Start background thread if it isn't already.
1938+
// Intentionally runs in the background until account ID is retrieved, so we don't pass the context.
1939+
//nolint:contextcheck
1940+
go func() {
1941+
c.accountIDOnce.Do(func() {
1942+
for c.accountID == "" {
1943+
cfg, err := config.LoadDefaultConfig(context.Background(), config.WithRegion(c.region))
1944+
if err != nil {
1945+
klog.ErrorS(err, "Failed to create AWS config for account ID retrieval")
1946+
}
1947+
1948+
stsClient := sts.NewFromConfig(cfg)
1949+
resp, err := stsClient.GetCallerIdentity(context.Background(), &sts.GetCallerIdentityInput{})
1950+
if err != nil {
1951+
klog.ErrorS(err, "Failed to get AWS account ID, required for HyperPod operations, will retry")
1952+
time.Sleep(getCallerIdentityRetryDelay)
1953+
} else {
1954+
c.accountID = *resp.Account
1955+
klog.V(5).InfoS("Retrieved AWS account ID for HyperPod operations", "accountID", c.accountID)
1956+
}
1957+
}
1958+
})
1959+
1960+
// Once.Do blocks until the function exits, even if we aren't the first caller.
1961+
// So the account ID must be available now.
1962+
accountIDRetrieved <- struct{}{}
1963+
}()
1964+
1965+
select {
1966+
case <-ctx.Done():
1967+
return "", ctx.Err()
1968+
1969+
case <-accountIDRetrieved:
1970+
return c.accountID, nil
1971+
}
1972+
}
1973+
19191974
// isAWSError returns a boolean indicating whether the error is AWS-related
19201975
// and has the given code. More information on AWS error codes at:
19211976
// https://docs.aws.amazon.com/AWSEC2/latest/APIReference/errors-overview.html

pkg/cloud/cloud_test.go

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,6 @@ func TestNewCloud(t *testing.T) {
8989
testCases := []struct {
9090
name string
9191
region string
92-
accountID string
9392
awsSdkDebugLog bool
9493
userAgentExtra string
9594
batchingEnabled bool
@@ -98,15 +97,13 @@ func TestNewCloud(t *testing.T) {
9897
{
9998
name: "success: with awsSdkDebugLog, userAgentExtra, and batchingEnabled",
10099
region: "us-east-1",
101-
accountID: "123456789012",
102100
awsSdkDebugLog: true,
103101
userAgentExtra: "example_user_agent_extra",
104102
batchingEnabled: true,
105103
},
106104
{
107105
name: "success: with only awsSdkDebugLog, and userAgentExtra",
108106
region: "us-east-1",
109-
accountID: "123456789012",
110107
awsSdkDebugLog: true,
111108
userAgentExtra: "example_user_agent_extra",
112109
},
@@ -116,10 +113,7 @@ func TestNewCloud(t *testing.T) {
116113
},
117114
}
118115
for _, tc := range testCases {
119-
ec2Cloud, err := NewCloud(tc.region, tc.accountID, tc.awsSdkDebugLog, tc.userAgentExtra, tc.batchingEnabled, tc.deprecatedMetrics)
120-
if err != nil {
121-
t.Fatalf("error %v", err)
122-
}
116+
ec2Cloud := NewCloud(tc.region, tc.awsSdkDebugLog, tc.userAgentExtra, tc.batchingEnabled, tc.deprecatedMetrics)
123117
ec2CloudAscloud, ok := ec2Cloud.(*cloud)
124118
if !ok {
125119
t.Fatalf("could not assert object ec2Cloud as cloud type, %v", ec2Cloud)
@@ -132,6 +126,7 @@ func TestNewCloud(t *testing.T) {
132126
}
133127
}
134128
}
129+
135130
func TestBatchDescribeVolumes(t *testing.T) {
136131
t.Parallel()
137132
testCases := []struct {
@@ -2442,11 +2437,15 @@ func TestBuildHyperPodClusterArn(t *testing.T) {
24422437
testCases := []struct {
24432438
name string
24442439
nodeID string
2440+
region string
2441+
accountID string
24452442
expectedArn string
24462443
}{
24472444
{
24482445
name: "success: valid HyperPod node",
24492446
nodeID: "hyperpod-abc123-i-1234567890abcdef0",
2447+
region: "test-region",
2448+
accountID: "123456789012",
24502449
expectedArn: "arn:aws:sagemaker:test-region:123456789012:cluster/abc123",
24512450
},
24522451
}
@@ -2456,13 +2455,7 @@ func TestBuildHyperPodClusterArn(t *testing.T) {
24562455
mockCtrl := gomock.NewController(t)
24572456
defer mockCtrl.Finish()
24582457

2459-
mockEC2 := NewMockEC2API(mockCtrl)
2460-
c := newCloud(mockEC2)
2461-
cloudInstance, ok := c.(*cloud)
2462-
if !ok {
2463-
t.Fatalf("could not assert cloudInstance as type cloud, %v", cloudInstance)
2464-
}
2465-
result := cloudInstance.buildHyperPodClusterArn(tc.nodeID)
2458+
result := buildHyperPodClusterArn(tc.nodeID, tc.region, tc.accountID)
24662459
assert.Equal(t, tc.expectedArn, result)
24672460
})
24682461
}

tests/e2e/dynamic_provisioning.go

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -621,10 +621,7 @@ var _ = Describe("[ebs-csi-e2e] [single-az] Dynamic Provisioning", func() {
621621
availabilityZones := strings.Split(os.Getenv(awsAvailabilityZonesEnv), ",")
622622
availabilityZone := availabilityZones[rand.Intn(len(availabilityZones))]
623623
region := availabilityZone[0 : len(availabilityZone)-1]
624-
cloud, err := awscloud.NewCloud(region, "", false, "", true, false)
625-
if err != nil {
626-
Fail(fmt.Sprintf("could not get NewCloud: %v", err))
627-
}
624+
cloud := awscloud.NewCloud(region, false, "", true, false)
628625

629626
test := testsuites.DynamicallyProvisionedReclaimPolicyTest{
630627
CSIDriver: ebsDriver,

tests/e2e/pre_provsioning.go

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,7 @@ var _ = Describe("[ebs-csi-e2e] [single-az] Pre-Provisioned", func() {
8787
Tags: map[string]string{awscloud.VolumeNameTagKey: dummyVolumeName, awscloud.AwsEbsDriverTagKey: "true"},
8888
}
8989
var err error
90-
cloud, err = awscloud.NewCloud(region, "", false, "", true, false)
91-
if err != nil {
92-
Fail(fmt.Sprintf("could not get NewCloud: %v", err))
93-
}
90+
cloud = awscloud.NewCloud(region, false, "", true, false)
9491
r1 := rand.New(rand.NewSource(time.Now().UnixNano()))
9592
disk, err := cloud.CreateDisk(context.Background(), fmt.Sprintf("pvc-%d", r1.Uint64()), diskOptions)
9693
if err != nil {
@@ -260,10 +257,7 @@ var _ = Describe("[ebs-csi-e2e] [single-az] Pre-Provisioned with Multi-Attach",
260257
Tags: map[string]string{awscloud.VolumeNameTagKey: dummyVolumeName, awscloud.AwsEbsDriverTagKey: "true"},
261258
}
262259
var err error
263-
cloud, err = awscloud.NewCloud(region, "", false, "", true, false)
264-
if err != nil {
265-
Fail(fmt.Sprintf("could not get NewCloud: %v", err))
266-
}
260+
cloud = awscloud.NewCloud(region, false, "", true, false)
267261
r1 := rand.New(rand.NewSource(time.Now().UnixNano()))
268262
disk, err := cloud.CreateDisk(context.Background(), fmt.Sprintf("pvc-%d", r1.Uint64()), diskOptions)
269263
if err != nil {

0 commit comments

Comments
 (0)