@@ -33,6 +33,7 @@ import (
33
33
"github.com/aws/aws-sdk-go-v2/service/ec2"
34
34
"github.com/aws/aws-sdk-go-v2/service/ec2/types"
35
35
"github.com/aws/aws-sdk-go-v2/service/sagemaker"
36
+ "github.com/aws/aws-sdk-go-v2/service/sts"
36
37
"github.com/aws/smithy-go"
37
38
"github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/batcher"
38
39
dm "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/cloud/devicemanager"
88
89
)
89
90
90
91
const (
91
- cacheForgetDelay = 1 * time .Hour
92
- volInitCacheForgetDelay = 6 * time .Hour
92
+ cacheForgetDelay = 1 * time .Hour
93
+ volInitCacheForgetDelay = 6 * time .Hour
94
+ getCallerIdentityRetryDelay = 30 * time .Second
93
95
)
94
96
95
97
// VolumeStatusInitializingState is const reported by EC2 DescribeVolumeStatus which AWS SDK does not have type for.
@@ -320,6 +322,7 @@ type batcherManager struct {
320
322
}
321
323
322
324
type cloud struct {
325
+ awsConfig aws.Config
323
326
region string
324
327
ec2 EC2API
325
328
sm SageMakerAPI
@@ -331,18 +334,14 @@ type cloud struct {
331
334
latestClientTokens expiringcache.ExpiringCache [string , int ]
332
335
volumeInitializations expiringcache.ExpiringCache [string , volumeInitialization ]
333
336
accountID string
337
+ accountIDOnce sync.Once
334
338
}
335
339
336
340
var _ Cloud = & cloud {}
337
341
338
342
// NewCloud returns a new instance of AWS cloud
339
343
// It panics if session is invalid.
340
- func NewCloud (region string , accountID string , awsSdkDebugLog bool , userAgentExtra string , batching bool , deprecatedMetrics bool ) (Cloud , error ) {
341
- c := newEC2Cloud (region , accountID , awsSdkDebugLog , userAgentExtra , batching , deprecatedMetrics )
342
- return c , nil
343
- }
344
-
345
- func newEC2Cloud (region string , accountID string , awsSdkDebugLog bool , userAgentExtra string , batchingEnabled bool , deprecatedMetrics bool ) Cloud {
344
+ func NewCloud (region string , awsSdkDebugLog bool , userAgentExtra string , batchingEnabled bool , deprecatedMetrics bool ) Cloud {
346
345
cfg , err := config .LoadDefaultConfig (context .Background (), config .WithRegion (region ))
347
346
if err != nil {
348
347
panic (err )
@@ -386,11 +385,12 @@ func newEC2Cloud(region string, accountID string, awsSdkDebugLog bool, userAgent
386
385
387
386
var bm * batcherManager
388
387
if batchingEnabled {
389
- klog .V (4 ).InfoS ("newEC2Cloud : batching enabled" )
388
+ klog .V (4 ).InfoS ("NewCloud : batching enabled" )
390
389
bm = newBatcherManager (svc )
391
390
}
392
391
393
392
return & cloud {
393
+ awsConfig : cfg ,
394
394
region : region ,
395
395
dm : dm .NewDeviceManager (),
396
396
ec2 : svc ,
@@ -400,7 +400,6 @@ func newEC2Cloud(region string, accountID string, awsSdkDebugLog bool, userAgent
400
400
vwp : vwp ,
401
401
likelyBadDeviceNames : expiringcache.New [string , sync.Map ](cacheForgetDelay ),
402
402
latestClientTokens : expiringcache.New [string , int ](cacheForgetDelay ),
403
- accountID : accountID ,
404
403
volumeInitializations : expiringcache.New [string , volumeInitialization ](volInitCacheForgetDelay ),
405
404
}
406
405
}
@@ -997,7 +996,11 @@ func (c *cloud) attachDiskHyperPod(ctx context.Context, volumeID, nodeID string)
997
996
klog .V (2 ).InfoS ("AttachDisk: HyperPod node detected" , "volumeID" , volumeID , "nodeID" , nodeID )
998
997
999
998
instanceID := getInstanceIDFromHyperPodNode (nodeID )
1000
- clusterArn := c .buildHyperPodClusterArn (nodeID )
999
+ accountID , err := c .getAccountID (ctx )
1000
+ if err != nil {
1001
+ return "" , fmt .Errorf ("failed to get account ID: %w" , err )
1002
+ }
1003
+ clusterArn := buildHyperPodClusterArn (nodeID , c .region , accountID )
1001
1004
1002
1005
klog .V (5 ).InfoS ("HyperPod attachment details" ,
1003
1006
"volumeID" , volumeID ,
@@ -1025,7 +1028,7 @@ func (c *cloud) attachDiskHyperPod(ctx context.Context, volumeID, nodeID string)
1025
1028
1026
1029
// Wait for attachment completion
1027
1030
deviceName := aws .ToString (resp .DeviceName )
1028
- _ , err : = c .WaitForAttachmentState (
1031
+ _ , err = c .WaitForAttachmentState (
1029
1032
ctx ,
1030
1033
types .VolumeAttachmentStateAttached ,
1031
1034
volumeID ,
@@ -1099,7 +1102,11 @@ func (c *cloud) detachDiskHyperPod(ctx context.Context, volumeID, nodeID string)
1099
1102
klog .V (2 ).InfoS ("DetachDisk: HyperPod node detected" , "volumeID" , volumeID , "nodeID" , nodeID )
1100
1103
1101
1104
instanceID := getInstanceIDFromHyperPodNode (nodeID )
1102
- clusterArn := c .buildHyperPodClusterArn (nodeID )
1105
+ accountID , err := c .getAccountID (ctx )
1106
+ if err != nil {
1107
+ return fmt .Errorf ("failed to get account ID: %w" , err )
1108
+ }
1109
+ clusterArn := buildHyperPodClusterArn (nodeID , c .region , accountID )
1103
1110
1104
1111
klog .V (4 ).InfoS ("HyperPod detachment details" ,
1105
1112
"volumeID" , volumeID ,
@@ -1114,7 +1121,7 @@ func (c *cloud) detachDiskHyperPod(ctx context.Context, volumeID, nodeID string)
1114
1121
}
1115
1122
klog .V (4 ).InfoS ("Calling DetachClusterNodeVolumeInput" , "input" , input )
1116
1123
1117
- _ , err : = c .sm .DetachClusterNodeVolume (ctx , input )
1124
+ _ , err = c .sm .DetachClusterNodeVolume (ctx , input )
1118
1125
if err != nil {
1119
1126
if isAWSHyperPodErrorIncorrectState (err ) ||
1120
1127
isAWSHyperPodErrorInvalidAttachmentNotFound (err ) ||
@@ -1450,9 +1457,9 @@ func getInstanceIDFromHyperPodNode(nodeID string) string {
1450
1457
}
1451
1458
1452
1459
// Only for hyperpod node, buildHyperPodClusterArn: arn:aws:sagemaker:region:account:cluster/clusterID.
1453
- func ( c * cloud ) buildHyperPodClusterArn (nodeID string ) string {
1460
+ func buildHyperPodClusterArn (nodeID string , region string , accountID string ) string {
1454
1461
parts := strings .Split (nodeID , "-" )
1455
- return fmt .Sprintf ("arn:aws:sagemaker:%s:%s:cluster/%s" , c . region , c . accountID , parts [1 ])
1462
+ return fmt .Sprintf ("arn:aws:sagemaker:%s:%s:cluster/%s" , region , accountID , parts [1 ])
1456
1463
}
1457
1464
1458
1465
// For hyperpod node, AssociatedResource is in arn:aws:sagemaker:region:account:cluster/clusterID-instanceId format.
@@ -1916,6 +1923,54 @@ func (c *cloud) waitForVolume(ctx context.Context, volumeID string) error {
1916
1923
return err
1917
1924
}
1918
1925
1926
+ // getAccountID returns the account ID of the AWS Account for the IAM credentials in use.
1927
+ //
1928
+ // In the first call (or any calls made before the first call succeeds), getAccountID
1929
+ // will attempt to determine the Account ID via sts:GetCallerIdentity.
1930
+ // This attempt will retry indefinitely, however getAccountID will return when ctx is cancelled,
1931
+ // leaving the account ID thread to run in the background.
1932
+ //
1933
+ // In subsequent calls (after the first success), getAccountID will use a cached value.
1934
+ func (c * cloud ) getAccountID (ctx context.Context ) (string , error ) {
1935
+ accountIDRetrieved := make (chan struct {}, 1 )
1936
+
1937
+ // Start background thread if it isn't already.
1938
+ // Intentionally runs in the background until account ID is retrieved, so we don't pass the context.
1939
+ //nolint:contextcheck
1940
+ go func () {
1941
+ c .accountIDOnce .Do (func () {
1942
+ for c .accountID == "" {
1943
+ cfg , err := config .LoadDefaultConfig (context .Background (), config .WithRegion (c .region ))
1944
+ if err != nil {
1945
+ klog .ErrorS (err , "Failed to create AWS config for account ID retrieval" )
1946
+ }
1947
+
1948
+ stsClient := sts .NewFromConfig (cfg )
1949
+ resp , err := stsClient .GetCallerIdentity (context .Background (), & sts.GetCallerIdentityInput {})
1950
+ if err != nil {
1951
+ klog .ErrorS (err , "Failed to get AWS account ID, required for HyperPod operations, will retry" )
1952
+ time .Sleep (getCallerIdentityRetryDelay )
1953
+ } else {
1954
+ c .accountID = * resp .Account
1955
+ klog .V (5 ).InfoS ("Retrieved AWS account ID for HyperPod operations" , "accountID" , c .accountID )
1956
+ }
1957
+ }
1958
+ })
1959
+
1960
+ // Once.Do blocks until the function exits, even if we aren't the first caller.
1961
+ // So the account ID must be available now.
1962
+ accountIDRetrieved <- struct {}{}
1963
+ }()
1964
+
1965
+ select {
1966
+ case <- ctx .Done ():
1967
+ return "" , ctx .Err ()
1968
+
1969
+ case <- accountIDRetrieved :
1970
+ return c .accountID , nil
1971
+ }
1972
+ }
1973
+
1919
1974
// isAWSError returns a boolean indicating whether the error is AWS-related
1920
1975
// and has the given code. More information on AWS error codes at:
1921
1976
// https://docs.aws.amazon.com/AWSEC2/latest/APIReference/errors-overview.html
0 commit comments