@@ -26,7 +26,6 @@ import (
26
26
"github.com/databricks/databricks-sdk-go/service/jobs"
27
27
"github.com/databricks/databricks-sdk-go/service/workspace"
28
28
"github.com/gorilla/websocket"
29
- "golang.org/x/sync/errgroup"
30
29
)
31
30
32
31
//go:embed ssh-server-bootstrap.py
@@ -63,7 +62,7 @@ type ClientOptions struct {
63
62
AdditionalArgs []string
64
63
}
65
64
66
- func RunClient (ctx context.Context , client * databricks.WorkspaceClient , opts ClientOptions ) error {
65
+ func Run (ctx context.Context , client * databricks.WorkspaceClient , opts ClientOptions ) error {
67
66
ctx , cancel := context .WithCancel (ctx )
68
67
defer cancel ()
69
68
@@ -124,7 +123,7 @@ func RunClient(ctx context.Context, client *databricks.WorkspaceClient, opts Cli
124
123
cmdio .LogString (ctx , fmt .Sprintf ("Server port: %d" , serverPort ))
125
124
126
125
if opts .ProxyMode {
127
- return startSSHProxy (ctx , client , opts .ClusterID , serverPort , opts .HandoverTimeout )
126
+ return runSSHProxy (ctx , client , opts .ClusterID , serverPort , opts .HandoverTimeout )
128
127
} else {
129
128
cmdio .LogString (ctx , fmt .Sprintf ("Additional SSH arguments: %v" , opts .AdditionalArgs ))
130
129
return spawnSSHClient (ctx , opts .ClusterID , userName , privateKeyPath , serverPort , opts .HandoverTimeout , opts .AdditionalArgs )
@@ -253,42 +252,11 @@ func spawnSSHClient(ctx context.Context, clusterID, userName, privateKeyPath str
253
252
return sshCmd .Run ()
254
253
}
255
254
256
- func startSSHProxy (ctx context.Context , client * databricks.WorkspaceClient , clusterID string , serverPort int , handoverTimeout time.Duration ) error {
257
- g , gCtx := errgroup .WithContext (ctx )
258
-
259
- cmdio .LogString (ctx , "Establishing SSH proxy connection..." )
260
- conn := proxy .NewProxyConnection (func (ctx context.Context , connID string ) (* websocket.Conn , error ) {
255
+ func runSSHProxy (ctx context.Context , client * databricks.WorkspaceClient , clusterID string , serverPort int , handoverTimeout time.Duration ) error {
256
+ createConn := func (ctx context.Context , connID string ) (* websocket.Conn , error ) {
261
257
return createWebsocketConnection (ctx , client , connID , clusterID , serverPort )
262
- })
263
- if err := conn .Connect (gCtx ); err != nil {
264
- return fmt .Errorf ("failed to connect to proxy: %w" , err )
265
258
}
266
- defer conn .Close ()
267
- cmdio .LogString (ctx , "SSH proxy connection established" )
268
-
269
- cmdio .LogString (ctx , fmt .Sprintf ("Connection handover timeout: %v" , handoverTimeout ))
270
- handoverTicker := time .NewTicker (handoverTimeout )
271
- defer handoverTicker .Stop ()
272
-
273
- g .Go (func () error {
274
- for {
275
- select {
276
- case <- gCtx .Done ():
277
- return gCtx .Err ()
278
- case <- handoverTicker .C :
279
- err := conn .InitiateHandover (gCtx )
280
- if err != nil {
281
- return err
282
- }
283
- }
284
- }
285
- })
286
-
287
- g .Go (func () error {
288
- return conn .Start (gCtx , os .Stdin , os .Stdout )
289
- })
290
-
291
- return g .Wait ()
259
+ return proxy .RunClientProxy (ctx , os .Stdin , os .Stdout , handoverTimeout , createConn )
292
260
}
293
261
294
262
func ensureSSHServerIsRunning (ctx context.Context , client * databricks.WorkspaceClient , clusterID , secretsScope , publicKeySecretName , version string , shutdownDelay time.Duration , maxClients int , serverTimeout time.Duration ) (string , int , error ) {
0 commit comments