@@ -26,7 +26,6 @@ import (
26
26
"github.com/databricks/databricks-sdk-go/service/jobs"
27
27
"github.com/databricks/databricks-sdk-go/service/workspace"
28
28
"github.com/gorilla/websocket"
29
- "golang.org/x/sync/errgroup"
30
29
)
31
30
32
31
//go:embed ssh-server-bootstrap.py
@@ -63,7 +62,7 @@ type ClientOptions struct {
63
62
AdditionalArgs []string
64
63
}
65
64
66
- func RunClient (ctx context.Context , client * databricks.WorkspaceClient , opts ClientOptions ) error {
65
+ func Run (ctx context.Context , client * databricks.WorkspaceClient , opts ClientOptions ) error {
67
66
ctx , cancel := context .WithCancel (ctx )
68
67
defer cancel ()
69
68
@@ -85,11 +84,11 @@ func RunClient(ctx context.Context, client *databricks.WorkspaceClient, opts Cli
85
84
}
86
85
cmdio .LogString (ctx , "Using SSH key: " + privateKeyPath )
87
86
88
- secretsScopeName , err := keys .PutSecretInScope (ctx , client , opts .ClusterID , opts .ClientPublicKeyName , publicKey )
87
+ keysSecretScopeName , err := keys .PutSecretInScope (ctx , client , opts .ClusterID , opts .ClientPublicKeyName , publicKey )
89
88
if err != nil {
90
89
return fmt .Errorf ("failed to store public key in secret scope: %w" , err )
91
90
}
92
- cmdio .LogString (ctx , fmt .Sprintf ("Secrets scope: %s, key name: %s" , secretsScopeName , opts .ClientPublicKeyName ))
91
+ cmdio .LogString (ctx , fmt .Sprintf ("Secrets scope: %s, key name: %s" , keysSecretScopeName , opts .ClientPublicKeyName ))
93
92
94
93
var userName string
95
94
var serverPort int
@@ -101,7 +100,7 @@ func RunClient(ctx context.Context, client *databricks.WorkspaceClient, opts Cli
101
100
if err := UploadTunnelReleases (ctx , client , version , opts .ReleasesDir ); err != nil {
102
101
return fmt .Errorf ("failed to upload ssh-tunnel binaries: %w" , err )
103
102
}
104
- userName , serverPort , err = ensureSSHServerIsRunning (ctx , client , opts .ClusterID , secretsScopeName , opts .ClientPublicKeyName , version , opts .ShutdownDelay , opts .MaxClients , opts .ServerTimeout )
103
+ userName , serverPort , err = ensureSSHServerIsRunning (ctx , client , opts .ClusterID , keysSecretScopeName , opts .ClientPublicKeyName , version , opts .ShutdownDelay , opts .MaxClients , opts .ServerTimeout )
105
104
if err != nil {
106
105
return fmt .Errorf ("failed to ensure that ssh server is running: %w" , err )
107
106
}
@@ -124,7 +123,7 @@ func RunClient(ctx context.Context, client *databricks.WorkspaceClient, opts Cli
124
123
cmdio .LogString (ctx , fmt .Sprintf ("Server port: %d" , serverPort ))
125
124
126
125
if opts .ProxyMode {
127
- return startSSHProxy (ctx , client , opts .ClusterID , serverPort , opts .HandoverTimeout )
126
+ return runSSHProxy (ctx , client , opts .ClusterID , serverPort , opts .HandoverTimeout )
128
127
} else {
129
128
cmdio .LogString (ctx , fmt .Sprintf ("Additional SSH arguments: %v" , opts .AdditionalArgs ))
130
129
return spawnSSHClient (ctx , opts .ClusterID , userName , privateKeyPath , serverPort , opts .HandoverTimeout , opts .AdditionalArgs )
@@ -165,7 +164,7 @@ func getServerMetadata(ctx context.Context, client *databricks.WorkspaceClient,
165
164
return serverPort , string (bodyBytes ), nil
166
165
}
167
166
168
- func submitSSHTunnelJob (ctx context.Context , client * databricks.WorkspaceClient , clusterID , secretsScope , publicKeySecretName , version string , shutdownDelay time.Duration , maxClients int , serverTimeout time.Duration ) (int64 , error ) {
167
+ func submitSSHTunnelJob (ctx context.Context , client * databricks.WorkspaceClient , clusterID , keysSecretScopeName , publicKeySecretName , version string , shutdownDelay time.Duration , maxClients int , serverTimeout time.Duration ) (int64 , error ) {
169
168
contentDir , err := sshWorkspace .GetWorkspaceContentDir (ctx , client , version , clusterID )
170
169
if err != nil {
171
170
return 0 , fmt .Errorf ("failed to get workspace content directory: %w" , err )
@@ -201,11 +200,11 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient,
201
200
NotebookTask : & jobs.NotebookTask {
202
201
NotebookPath : jobNotebookPath ,
203
202
BaseParameters : map [string ]string {
204
- "version" : version ,
205
- "secretsScope " : secretsScope ,
206
- "publicKeySecretName " : publicKeySecretName ,
207
- "shutdownDelay" : shutdownDelay .String (),
208
- "maxClients" : strconv .Itoa (maxClients ),
203
+ "version" : version ,
204
+ "keysSecretScopeName " : keysSecretScopeName ,
205
+ "authorizedKeySecretName " : publicKeySecretName ,
206
+ "shutdownDelay" : shutdownDelay .String (),
207
+ "maxClients" : strconv .Itoa (maxClients ),
209
208
},
210
209
},
211
210
TimeoutSeconds : int (serverTimeout .Seconds ()),
@@ -253,45 +252,14 @@ func spawnSSHClient(ctx context.Context, clusterID, userName, privateKeyPath str
253
252
return sshCmd .Run ()
254
253
}
255
254
256
- func startSSHProxy (ctx context.Context , client * databricks.WorkspaceClient , clusterID string , serverPort int , handoverTimeout time.Duration ) error {
257
- g , gCtx := errgroup .WithContext (ctx )
258
-
259
- cmdio .LogString (ctx , "Establishing SSH proxy connection..." )
260
- conn := proxy .NewProxyConnection (func (ctx context.Context , connID string ) (* websocket.Conn , error ) {
255
+ func runSSHProxy (ctx context.Context , client * databricks.WorkspaceClient , clusterID string , serverPort int , handoverTimeout time.Duration ) error {
256
+ createConn := func (ctx context.Context , connID string ) (* websocket.Conn , error ) {
261
257
return createWebsocketConnection (ctx , client , connID , clusterID , serverPort )
262
- })
263
- if err := conn .Connect (gCtx ); err != nil {
264
- return fmt .Errorf ("failed to connect to proxy: %w" , err )
265
258
}
266
- defer conn .Close ()
267
- cmdio .LogString (ctx , "SSH proxy connection established" )
268
-
269
- cmdio .LogString (ctx , fmt .Sprintf ("Connection handover timeout: %v" , handoverTimeout ))
270
- handoverTicker := time .NewTicker (handoverTimeout )
271
- defer handoverTicker .Stop ()
272
-
273
- g .Go (func () error {
274
- for {
275
- select {
276
- case <- gCtx .Done ():
277
- return gCtx .Err ()
278
- case <- handoverTicker .C :
279
- err := conn .InitiateHandover (gCtx )
280
- if err != nil {
281
- return err
282
- }
283
- }
284
- }
285
- })
286
-
287
- g .Go (func () error {
288
- return conn .Start (gCtx , os .Stdin , os .Stdout )
289
- })
290
-
291
- return g .Wait ()
259
+ return proxy .RunClientProxy (ctx , os .Stdin , os .Stdout , handoverTimeout , createConn )
292
260
}
293
261
294
- func ensureSSHServerIsRunning (ctx context.Context , client * databricks.WorkspaceClient , clusterID , secretsScope , publicKeySecretName , version string , shutdownDelay time.Duration , maxClients int , serverTimeout time.Duration ) (string , int , error ) {
262
+ func ensureSSHServerIsRunning (ctx context.Context , client * databricks.WorkspaceClient , clusterID , keysSecretScopeName , publicKeySecretName , version string , shutdownDelay time.Duration , maxClients int , serverTimeout time.Duration ) (string , int , error ) {
295
263
cmdio .LogString (ctx , "Ensuring the cluster is running: " + clusterID )
296
264
err := client .Clusters .EnsureClusterIsRunning (ctx , clusterID )
297
265
if err != nil {
@@ -302,7 +270,7 @@ func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceC
302
270
if errors .Is (err , errServerMetadata ) {
303
271
cmdio .LogString (ctx , "SSH server is not running, starting it now..." )
304
272
305
- runID , err := submitSSHTunnelJob (ctx , client , clusterID , secretsScope , publicKeySecretName , version , shutdownDelay , maxClients , serverTimeout )
273
+ runID , err := submitSSHTunnelJob (ctx , client , clusterID , keysSecretScopeName , publicKeySecretName , version , shutdownDelay , maxClients , serverTimeout )
306
274
if err != nil {
307
275
return "" , 0 , fmt .Errorf ("failed to submit ssh server job: %w" , err )
308
276
}
0 commit comments