1
- package ssh
1
+ package client
2
2
3
3
import (
4
4
"context"
5
5
_ "embed"
6
6
"encoding/base64"
7
- "encoding/json"
8
7
"errors"
9
8
"fmt"
10
9
"io"
@@ -18,6 +17,9 @@ import (
18
17
"syscall"
19
18
"time"
20
19
20
+ "github.com/databricks/cli/experimental/ssh/internal/keys"
21
+ "github.com/databricks/cli/experimental/ssh/internal/proxy"
22
+ sshWorkspace "github.com/databricks/cli/experimental/ssh/internal/workspace"
21
23
"github.com/databricks/cli/internal/build"
22
24
"github.com/databricks/cli/libs/cmdio"
23
25
"github.com/databricks/databricks-sdk-go"
@@ -27,17 +29,11 @@ import (
27
29
"golang.org/x/sync/errgroup"
28
30
)
29
31
30
- type PortMetadata struct {
31
- Port int `json:"port"`
32
- }
33
-
34
32
//go:embed ssh-server-bootstrap.py
35
33
var sshServerBootstrapScript string
36
34
37
35
var errServerMetadata = errors .New ("server metadata error" )
38
36
39
- const serverJobTimeoutSeconds = 24 * 60 * 60
40
-
41
37
type ClientOptions struct {
42
38
// Id of the cluster to connect to
43
39
ClusterID string
@@ -54,6 +50,8 @@ type ClientOptions struct {
54
50
ServerMetadata string
55
51
// How often the CLI should reconnect to the server with new auth.
56
52
HandoverTimeout time.Duration
53
+ // Max amount of time the server process is allowed to live
54
+ ServerTimeout time.Duration
57
55
// Directory for local SSH tunnel development releases.
58
56
// If not present, the CLI will use github releases with the current version.
59
57
ReleasesDir string
@@ -77,16 +75,17 @@ func RunClient(ctx context.Context, client *databricks.WorkspaceClient, opts Cli
77
75
cancel ()
78
76
}()
79
77
80
- keyPath , err := getLocalSSHKeyPath (opts .ClusterID , opts .SSHKeysDir )
78
+ keyPath , err := keys . GetLocalSSHKeyPath (opts .ClusterID , opts .SSHKeysDir )
81
79
if err != nil {
82
80
return fmt .Errorf ("failed to get local keys folder: %w" , err )
83
81
}
84
- privateKeyPath , publicKey , err := checkAndGenerateSSHKeyPair (ctx , keyPath )
82
+ privateKeyPath , publicKey , err := keys . CheckAndGenerateSSHKeyPair (ctx , keyPath )
85
83
if err != nil {
86
84
return fmt .Errorf ("failed to check or generate SSH key pair: %w" , err )
87
85
}
86
+ cmdio .LogString (ctx , "Using SSH key: " + privateKeyPath )
88
87
89
- secretsScopeName , err := putSecretInScope (ctx , client , opts .ClusterID , opts .ClientPublicKeyName , publicKey )
88
+ secretsScopeName , err := keys . PutSecretInScope (ctx , client , opts .ClusterID , opts .ClientPublicKeyName , publicKey )
90
89
if err != nil {
91
90
return fmt .Errorf ("failed to store public key in secret scope: %w" , err )
92
91
}
@@ -99,10 +98,10 @@ func RunClient(ctx context.Context, client *databricks.WorkspaceClient, opts Cli
99
98
100
99
if opts .ServerMetadata == "" {
101
100
cmdio .LogString (ctx , "Checking for ssh-tunnel binaries to upload..." )
102
- if err := uploadTunnelBinaries (ctx , client , version , opts .ReleasesDir ); err != nil {
101
+ if err := UploadTunnelReleases (ctx , client , version , opts .ReleasesDir ); err != nil {
103
102
return fmt .Errorf ("failed to upload ssh-tunnel binaries: %w" , err )
104
103
}
105
- userName , serverPort , err = ensureSSHServerIsRunning (ctx , client , opts .ClusterID , secretsScopeName , opts .ClientPublicKeyName , version , opts .ShutdownDelay , opts .MaxClients )
104
+ userName , serverPort , err = ensureSSHServerIsRunning (ctx , client , opts .ClusterID , secretsScopeName , opts .ClientPublicKeyName , version , opts .ShutdownDelay , opts .MaxClients , opts . ServerTimeout )
106
105
if err != nil {
107
106
return fmt .Errorf ("failed to ensure that ssh server is running: %w" , err )
108
107
}
@@ -132,36 +131,8 @@ func RunClient(ctx context.Context, client *databricks.WorkspaceClient, opts Cli
132
131
}
133
132
}
134
133
135
- func getWorkspaceMetadata (ctx context.Context , client * databricks.WorkspaceClient , version , clusterID string ) (int , error ) {
136
- contentDir , err := getWorkspaceContentDir (ctx , client , version , clusterID )
137
- if err != nil {
138
- return 0 , fmt .Errorf ("failed to get workspace content directory: %w" , err )
139
- }
140
-
141
- metadataPath := filepath .ToSlash (filepath .Join (contentDir , "metadata.json" ))
142
-
143
- content , err := client .Workspace .Download (ctx , metadataPath )
144
- if err != nil {
145
- return 0 , fmt .Errorf ("failed to download metadata file: %w" , err )
146
- }
147
- defer content .Close ()
148
-
149
- metadataBytes , err := io .ReadAll (content )
150
- if err != nil {
151
- return 0 , fmt .Errorf ("failed to read metadata content: %w" , err )
152
- }
153
-
154
- var metadata PortMetadata
155
- err = json .Unmarshal (metadataBytes , & metadata )
156
- if err != nil {
157
- return 0 , fmt .Errorf ("failed to parse metadata JSON: %w" , err )
158
- }
159
-
160
- return metadata .Port , nil
161
- }
162
-
163
134
func getServerMetadata (ctx context.Context , client * databricks.WorkspaceClient , clusterID , version string ) (int , string , error ) {
164
- serverPort , err := getWorkspaceMetadata (ctx , client , version , clusterID )
135
+ serverPort , err := sshWorkspace . GetWorkspaceMetadata (ctx , client , version , clusterID )
165
136
if err != nil {
166
137
return 0 , "" , errors .Join (errServerMetadata , err )
167
138
}
@@ -194,8 +165,8 @@ func getServerMetadata(ctx context.Context, client *databricks.WorkspaceClient,
194
165
return serverPort , string (bodyBytes ), nil
195
166
}
196
167
197
- func submitSSHTunnelJob (ctx context.Context , client * databricks.WorkspaceClient , clusterID , secretsScope , publicKeySecretName , version string , shutdownDelay time.Duration , maxClients int ) (int64 , error ) {
198
- contentDir , err := getWorkspaceContentDir (ctx , client , version , clusterID )
168
+ func submitSSHTunnelJob (ctx context.Context , client * databricks.WorkspaceClient , clusterID , secretsScope , publicKeySecretName , version string , shutdownDelay time.Duration , maxClients int , serverTimeout time. Duration ) (int64 , error ) {
169
+ contentDir , err := sshWorkspace . GetWorkspaceContentDir (ctx , client , version , clusterID )
199
170
if err != nil {
200
171
return 0 , fmt .Errorf ("failed to get workspace content directory: %w" , err )
201
172
}
@@ -223,7 +194,7 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient,
223
194
224
195
submitRun := jobs.SubmitRun {
225
196
RunName : sshTunnelJobName ,
226
- TimeoutSeconds : serverJobTimeoutSeconds ,
197
+ TimeoutSeconds : int ( serverTimeout . Seconds ()) ,
227
198
Tasks : []jobs.SubmitTask {
228
199
{
229
200
TaskKey : "start_ssh_server" ,
@@ -237,7 +208,7 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient,
237
208
"maxClients" : strconv .Itoa (maxClients ),
238
209
},
239
210
},
240
- TimeoutSeconds : serverJobTimeoutSeconds ,
211
+ TimeoutSeconds : int ( serverTimeout . Seconds ()) ,
241
212
ExistingClusterId : clusterID ,
242
213
},
243
214
},
@@ -286,13 +257,13 @@ func startSSHProxy(ctx context.Context, client *databricks.WorkspaceClient, clus
286
257
g , gCtx := errgroup .WithContext (ctx )
287
258
288
259
cmdio .LogString (ctx , "Establishing SSH proxy connection..." )
289
- proxy := newProxyConnection (func (ctx context.Context , connID string ) (* websocket.Conn , error ) {
260
+ conn := proxy . NewProxyConnection (func (ctx context.Context , connID string ) (* websocket.Conn , error ) {
290
261
return createWebsocketConnection (ctx , client , connID , clusterID , serverPort )
291
262
})
292
- if err := proxy .Connect (gCtx ); err != nil {
263
+ if err := conn .Connect (gCtx ); err != nil {
293
264
return fmt .Errorf ("failed to connect to proxy: %w" , err )
294
265
}
295
- defer proxy .Close ()
266
+ defer conn .Close ()
296
267
cmdio .LogString (ctx , "SSH proxy connection established" )
297
268
298
269
cmdio .LogString (ctx , fmt .Sprintf ("Connection handover timeout: %v" , handoverTimeout ))
@@ -305,7 +276,7 @@ func startSSHProxy(ctx context.Context, client *databricks.WorkspaceClient, clus
305
276
case <- gCtx .Done ():
306
277
return gCtx .Err ()
307
278
case <- handoverTicker .C :
308
- err := proxy .InitiateHandover (gCtx )
279
+ err := conn .InitiateHandover (gCtx )
309
280
if err != nil {
310
281
return err
311
282
}
@@ -314,13 +285,13 @@ func startSSHProxy(ctx context.Context, client *databricks.WorkspaceClient, clus
314
285
})
315
286
316
287
g .Go (func () error {
317
- return proxy .Start (gCtx , os .Stdin , os .Stdout )
288
+ return conn .Start (gCtx , os .Stdin , os .Stdout )
318
289
})
319
290
320
291
return g .Wait ()
321
292
}
322
293
323
- func ensureSSHServerIsRunning (ctx context.Context , client * databricks.WorkspaceClient , clusterID , secretsScope , publicKeySecretName , version string , shutdownDelay time.Duration , maxClients int ) (string , int , error ) {
294
+ func ensureSSHServerIsRunning (ctx context.Context , client * databricks.WorkspaceClient , clusterID , secretsScope , publicKeySecretName , version string , shutdownDelay time.Duration , maxClients int , serverTimeout time. Duration ) (string , int , error ) {
324
295
cmdio .LogString (ctx , "Ensuring the cluster is running: " + clusterID )
325
296
err := client .Clusters .EnsureClusterIsRunning (ctx , clusterID )
326
297
if err != nil {
@@ -331,7 +302,7 @@ func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceC
331
302
if errors .Is (err , errServerMetadata ) {
332
303
cmdio .LogString (ctx , "SSH server is not running, starting it now..." )
333
304
334
- runID , err := submitSSHTunnelJob (ctx , client , clusterID , secretsScope , publicKeySecretName , version , shutdownDelay , maxClients )
305
+ runID , err := submitSSHTunnelJob (ctx , client , clusterID , secretsScope , publicKeySecretName , version , shutdownDelay , maxClients , serverTimeout )
335
306
if err != nil {
336
307
return "" , 0 , fmt .Errorf ("failed to submit ssh server job: %w" , err )
337
308
}
0 commit comments