Skip to content

Commit 94c2ae2

Browse files
committed
Split and move ssh logic to experimental/ssh
Move from flat structure to cmd and internal client/server/proxy/keys/workspace sub packages
1 parent 7ebfe7e commit 94c2ae2

25 files changed

+814
-505
lines changed

Makefile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
default: checks fmt lint
22

3-
PACKAGES=./acceptance/... ./libs/... ./internal/... ./cmd/... ./bundle/... .
3+
PACKAGES=./acceptance/... ./libs/... ./internal/... ./cmd/... ./bundle/... ./experimental/ssh/... .
44

55
GOTESTSUM_FORMAT ?= pkgname-and-test-fails
66
GOTESTSUM_CMD ?= go tool gotestsum --format ${GOTESTSUM_FORMAT} --no-summary=skipped --jsonfile test-output.json
@@ -136,4 +136,4 @@ generate:
136136
$(GENKIT_BINARY) update-sdk
137137

138138

139-
.PHONY: lint lintfull tidy lintcheck fmt fmtfull test cover showcover build snapshot schema integration integration-short acc-cover acc-showcover docs ws links checks test-update test-update-aws test-update-all generate-validation
139+
.PHONY: lint lintfull tidy lintcheck fmt fmtfull test cover showcover build snapshot snapshot-release schema integration integration-short acc-cover acc-showcover docs ws links checks test-update test-update-aws test-update-all generate-validation

cmd/cmd.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import (
55
"strings"
66

77
"github.com/databricks/cli/cmd/psql"
8-
"github.com/databricks/cli/cmd/ssh"
8+
ssh "github.com/databricks/cli/experimental/ssh/cmd"
99

1010
"github.com/databricks/cli/cmd/account"
1111
"github.com/databricks/cli/cmd/api"
File renamed without changes.

cmd/ssh/connect.go renamed to experimental/ssh/cmd/connect.go

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,12 @@ import (
44
"time"
55

66
"github.com/databricks/cli/cmd/root"
7+
"github.com/databricks/cli/experimental/ssh/internal/client"
8+
"github.com/databricks/cli/experimental/ssh/internal/proxy"
79
"github.com/databricks/cli/libs/cmdctx"
8-
"github.com/databricks/cli/libs/ssh"
910
"github.com/spf13/cobra"
1011
)
1112

12-
const (
13-
defaultClientPublicKeyName = "client-public-key"
14-
defaultShutdownDelay = 10 * time.Minute
15-
defaultHandoverTimeout = 30 * time.Minute
16-
defaultMaxClients = 10
17-
)
18-
1913
func newConnectCommand() *cobra.Command {
2014
cmd := &cobra.Command{
2115
Use: "connect",
@@ -54,8 +48,8 @@ the SSH server and handling the connection proxy.
5448
cmd.PreRunE = root.MustWorkspaceClient
5549
cmd.RunE = func(cmd *cobra.Command, args []string) error {
5650
ctx := cmd.Context()
57-
client := cmdctx.WorkspaceClient(ctx)
58-
opts := ssh.ClientOptions{
51+
wsClient := cmdctx.WorkspaceClient(ctx)
52+
opts := client.ClientOptions{
5953
ClusterID: clusterID,
6054
ProxyMode: proxyMode,
6155
ServerMetadata: serverMetadata,
@@ -65,8 +59,13 @@ the SSH server and handling the connection proxy.
6559
ReleasesDir: releasesDir,
6660
AdditionalArgs: args,
6761
ClientPublicKeyName: defaultClientPublicKeyName,
62+
ServerTimeout: serverTimeout,
63+
}
64+
err := client.RunClient(ctx, wsClient, opts)
65+
if err != nil && proxy.IsNormalClosure(err) {
66+
return nil
6867
}
69-
return ssh.RunClient(ctx, client, opts)
68+
return err
7069
}
7170

7271
return cmd

experimental/ssh/cmd/constants.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
package ssh
2+
3+
import "time"
4+
5+
const (
6+
defaultServerPort = 7772
7+
defaultMaxClients = 10
8+
defaultShutdownDelay = 10 * time.Minute
9+
defaultHandoverTimeout = 30 * time.Minute
10+
11+
serverTimeout = 24 * time.Hour
12+
serverPortRange = 100
13+
serverConfigDir = ".ssh-tunnel"
14+
serverPrivateKeyName = "server-private-key"
15+
serverPublicKeyName = "server-public-key"
16+
defaultClientPublicKeyName = "client-public-key"
17+
)

cmd/ssh/server.go renamed to experimental/ssh/cmd/server.go

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ import (
44
"time"
55

66
"github.com/databricks/cli/cmd/root"
7+
"github.com/databricks/cli/experimental/ssh/internal/server"
78
"github.com/databricks/cli/libs/cmdctx"
8-
"github.com/databricks/cli/libs/ssh"
99
"github.com/spf13/cobra"
1010
)
1111

@@ -27,33 +27,38 @@ and proxies them to local SSH daemon processes.
2727
var shutdownDelay time.Duration
2828
var clusterID string
2929
var version string
30+
var secretsScope string
31+
var publicKeySecretName string
3032

3133
cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID")
3234
cmd.MarkFlagRequired("cluster")
33-
cmd.Flags().IntVar(&maxClients, "max-clients", 10, "Maximum number of SSH clients")
34-
cmd.Flags().DurationVar(&shutdownDelay, "shutdown-delay", 10*time.Minute, "Delay before shutting down after no pings from clients")
35+
cmd.Flags().StringVar(&secretsScope, "secrets-scope-name", "", "Databricks secrets scope name")
36+
cmd.MarkFlagRequired("secrets-scope-name")
37+
cmd.Flags().StringVar(&publicKeySecretName, "client-key-name", "", "Databricks client key name")
38+
cmd.MarkFlagRequired("client-key-name")
39+
40+
cmd.Flags().IntVar(&maxClients, "max-clients", defaultMaxClients, "Maximum number of SSH clients")
41+
cmd.Flags().DurationVar(&shutdownDelay, "shutdown-delay", defaultShutdownDelay, "Delay before shutting down after no pings from clients")
3542
cmd.Flags().StringVar(&version, "version", "", "Client version of the Databricks CLI")
3643

3744
cmd.PreRunE = root.MustWorkspaceClient
3845
cmd.RunE = func(cmd *cobra.Command, args []string) error {
3946
ctx := cmd.Context()
4047
client := cmdctx.WorkspaceClient(ctx)
41-
opts := ssh.ServerOptions{
48+
opts := server.ServerOptions{
4249
ClusterID: clusterID,
4350
MaxClients: maxClients,
4451
ShutdownDelay: shutdownDelay,
4552
Version: version,
46-
ConfigDir: ".ssh-tunnel",
47-
ServerPrivateKeyName: "server-private-key",
48-
ServerPublicKeyName: "server-public-key",
49-
DefaultPort: 7772,
50-
PortRange: 100,
51-
}
52-
err := ssh.RunServer(ctx, client, opts)
53-
if err != nil && ssh.IsNormalClosure(err) {
54-
return nil
53+
ConfigDir: serverConfigDir,
54+
SecretsScope: secretsScope,
55+
ClientPublicKeyName: publicKeySecretName,
56+
ServerPrivateKeyName: serverPrivateKeyName,
57+
ServerPublicKeyName: serverPublicKeyName,
58+
DefaultPort: defaultServerPort,
59+
PortRange: serverPortRange,
5560
}
56-
return err
61+
return server.RunServer(ctx, client, opts)
5762
}
5863

5964
return cmd

cmd/ssh/setup.go renamed to experimental/ssh/cmd/setup.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ import (
44
"time"
55

66
"github.com/databricks/cli/cmd/root"
7+
"github.com/databricks/cli/experimental/ssh/internal/setup"
78
"github.com/databricks/cli/libs/cmdctx"
8-
"github.com/databricks/cli/libs/ssh"
99
"github.com/spf13/cobra"
1010
)
1111

@@ -31,20 +31,20 @@ an SSH host configuration to your SSH config file.
3131
cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID")
3232
cmd.MarkFlagRequired("cluster")
3333
cmd.Flags().StringVar(&sshConfigPath, "ssh-config", "", "Path to SSH config file (default ~/.ssh/config)")
34-
cmd.Flags().DurationVar(&shutdownDelay, "shutdown-delay", 10*time.Minute, "SSH server will terminate after this delay if there are no active connections")
34+
cmd.Flags().DurationVar(&shutdownDelay, "shutdown-delay", defaultShutdownDelay, "SSH server will terminate after this delay if there are no active connections")
3535

3636
cmd.PreRunE = root.MustWorkspaceClient
3737
cmd.RunE = func(cmd *cobra.Command, args []string) error {
3838
ctx := cmd.Context()
3939
client := cmdctx.WorkspaceClient(ctx)
40-
opts := ssh.SetupOptions{
40+
opts := setup.SetupOptions{
4141
HostName: hostName,
4242
ClusterID: clusterID,
4343
SSHConfigPath: sshConfigPath,
4444
ShutdownDelay: shutdownDelay,
4545
Profile: client.Config.Profile,
4646
}
47-
return ssh.Setup(ctx, client, opts)
47+
return setup.Setup(ctx, client, opts)
4848
}
4949

5050
return cmd
File renamed without changes.

libs/ssh/client.go renamed to experimental/ssh/internal/client/client.go

Lines changed: 24 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
1-
package ssh
1+
package client
22

33
import (
44
"context"
55
_ "embed"
66
"encoding/base64"
7-
"encoding/json"
87
"errors"
98
"fmt"
109
"io"
@@ -18,6 +17,9 @@ import (
1817
"syscall"
1918
"time"
2019

20+
"github.com/databricks/cli/experimental/ssh/internal/keys"
21+
"github.com/databricks/cli/experimental/ssh/internal/proxy"
22+
sshWorkspace "github.com/databricks/cli/experimental/ssh/internal/workspace"
2123
"github.com/databricks/cli/internal/build"
2224
"github.com/databricks/cli/libs/cmdio"
2325
"github.com/databricks/databricks-sdk-go"
@@ -27,17 +29,11 @@ import (
2729
"golang.org/x/sync/errgroup"
2830
)
2931

30-
type PortMetadata struct {
31-
Port int `json:"port"`
32-
}
33-
3432
//go:embed ssh-server-bootstrap.py
3533
var sshServerBootstrapScript string
3634

3735
var errServerMetadata = errors.New("server metadata error")
3836

39-
const serverJobTimeoutSeconds = 24 * 60 * 60
40-
4137
type ClientOptions struct {
4238
// Id of the cluster to connect to
4339
ClusterID string
@@ -54,6 +50,8 @@ type ClientOptions struct {
5450
ServerMetadata string
5551
// How often the CLI should reconnect to the server with new auth.
5652
HandoverTimeout time.Duration
53+
// Max amount of time the server process is allowed to live
54+
ServerTimeout time.Duration
5755
// Directory for local SSH tunnel development releases.
5856
// If not present, the CLI will use github releases with the current version.
5957
ReleasesDir string
@@ -77,16 +75,17 @@ func RunClient(ctx context.Context, client *databricks.WorkspaceClient, opts Cli
7775
cancel()
7876
}()
7977

80-
keyPath, err := getLocalSSHKeyPath(opts.ClusterID, opts.SSHKeysDir)
78+
keyPath, err := keys.GetLocalSSHKeyPath(opts.ClusterID, opts.SSHKeysDir)
8179
if err != nil {
8280
return fmt.Errorf("failed to get local keys folder: %w", err)
8381
}
84-
privateKeyPath, publicKey, err := checkAndGenerateSSHKeyPair(ctx, keyPath)
82+
privateKeyPath, publicKey, err := keys.CheckAndGenerateSSHKeyPair(ctx, keyPath)
8583
if err != nil {
8684
return fmt.Errorf("failed to check or generate SSH key pair: %w", err)
8785
}
86+
cmdio.LogString(ctx, "Using SSH key: "+privateKeyPath)
8887

89-
secretsScopeName, err := putSecretInScope(ctx, client, opts.ClusterID, opts.ClientPublicKeyName, publicKey)
88+
secretsScopeName, err := keys.PutSecretInScope(ctx, client, opts.ClusterID, opts.ClientPublicKeyName, publicKey)
9089
if err != nil {
9190
return fmt.Errorf("failed to store public key in secret scope: %w", err)
9291
}
@@ -99,10 +98,10 @@ func RunClient(ctx context.Context, client *databricks.WorkspaceClient, opts Cli
9998

10099
if opts.ServerMetadata == "" {
101100
cmdio.LogString(ctx, "Checking for ssh-tunnel binaries to upload...")
102-
if err := uploadTunnelBinaries(ctx, client, version, opts.ReleasesDir); err != nil {
101+
if err := UploadTunnelReleases(ctx, client, version, opts.ReleasesDir); err != nil {
103102
return fmt.Errorf("failed to upload ssh-tunnel binaries: %w", err)
104103
}
105-
userName, serverPort, err = ensureSSHServerIsRunning(ctx, client, opts.ClusterID, secretsScopeName, opts.ClientPublicKeyName, version, opts.ShutdownDelay, opts.MaxClients)
104+
userName, serverPort, err = ensureSSHServerIsRunning(ctx, client, opts.ClusterID, secretsScopeName, opts.ClientPublicKeyName, version, opts.ShutdownDelay, opts.MaxClients, opts.ServerTimeout)
106105
if err != nil {
107106
return fmt.Errorf("failed to ensure that ssh server is running: %w", err)
108107
}
@@ -132,36 +131,8 @@ func RunClient(ctx context.Context, client *databricks.WorkspaceClient, opts Cli
132131
}
133132
}
134133

135-
func getWorkspaceMetadata(ctx context.Context, client *databricks.WorkspaceClient, version, clusterID string) (int, error) {
136-
contentDir, err := getWorkspaceContentDir(ctx, client, version, clusterID)
137-
if err != nil {
138-
return 0, fmt.Errorf("failed to get workspace content directory: %w", err)
139-
}
140-
141-
metadataPath := filepath.ToSlash(filepath.Join(contentDir, "metadata.json"))
142-
143-
content, err := client.Workspace.Download(ctx, metadataPath)
144-
if err != nil {
145-
return 0, fmt.Errorf("failed to download metadata file: %w", err)
146-
}
147-
defer content.Close()
148-
149-
metadataBytes, err := io.ReadAll(content)
150-
if err != nil {
151-
return 0, fmt.Errorf("failed to read metadata content: %w", err)
152-
}
153-
154-
var metadata PortMetadata
155-
err = json.Unmarshal(metadataBytes, &metadata)
156-
if err != nil {
157-
return 0, fmt.Errorf("failed to parse metadata JSON: %w", err)
158-
}
159-
160-
return metadata.Port, nil
161-
}
162-
163134
func getServerMetadata(ctx context.Context, client *databricks.WorkspaceClient, clusterID, version string) (int, string, error) {
164-
serverPort, err := getWorkspaceMetadata(ctx, client, version, clusterID)
135+
serverPort, err := sshWorkspace.GetWorkspaceMetadata(ctx, client, version, clusterID)
165136
if err != nil {
166137
return 0, "", errors.Join(errServerMetadata, err)
167138
}
@@ -194,8 +165,8 @@ func getServerMetadata(ctx context.Context, client *databricks.WorkspaceClient,
194165
return serverPort, string(bodyBytes), nil
195166
}
196167

197-
func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, clusterID, secretsScope, publicKeySecretName, version string, shutdownDelay time.Duration, maxClients int) (int64, error) {
198-
contentDir, err := getWorkspaceContentDir(ctx, client, version, clusterID)
168+
func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, clusterID, secretsScope, publicKeySecretName, version string, shutdownDelay time.Duration, maxClients int, serverTimeout time.Duration) (int64, error) {
169+
contentDir, err := sshWorkspace.GetWorkspaceContentDir(ctx, client, version, clusterID)
199170
if err != nil {
200171
return 0, fmt.Errorf("failed to get workspace content directory: %w", err)
201172
}
@@ -223,7 +194,7 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient,
223194

224195
submitRun := jobs.SubmitRun{
225196
RunName: sshTunnelJobName,
226-
TimeoutSeconds: serverJobTimeoutSeconds,
197+
TimeoutSeconds: int(serverTimeout.Seconds()),
227198
Tasks: []jobs.SubmitTask{
228199
{
229200
TaskKey: "start_ssh_server",
@@ -237,7 +208,7 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient,
237208
"maxClients": strconv.Itoa(maxClients),
238209
},
239210
},
240-
TimeoutSeconds: serverJobTimeoutSeconds,
211+
TimeoutSeconds: int(serverTimeout.Seconds()),
241212
ExistingClusterId: clusterID,
242213
},
243214
},
@@ -286,13 +257,13 @@ func startSSHProxy(ctx context.Context, client *databricks.WorkspaceClient, clus
286257
g, gCtx := errgroup.WithContext(ctx)
287258

288259
cmdio.LogString(ctx, "Establishing SSH proxy connection...")
289-
proxy := newProxyConnection(func(ctx context.Context, connID string) (*websocket.Conn, error) {
260+
conn := proxy.NewProxyConnection(func(ctx context.Context, connID string) (*websocket.Conn, error) {
290261
return createWebsocketConnection(ctx, client, connID, clusterID, serverPort)
291262
})
292-
if err := proxy.Connect(gCtx); err != nil {
263+
if err := conn.Connect(gCtx); err != nil {
293264
return fmt.Errorf("failed to connect to proxy: %w", err)
294265
}
295-
defer proxy.Close()
266+
defer conn.Close()
296267
cmdio.LogString(ctx, "SSH proxy connection established")
297268

298269
cmdio.LogString(ctx, fmt.Sprintf("Connection handover timeout: %v", handoverTimeout))
@@ -305,7 +276,7 @@ func startSSHProxy(ctx context.Context, client *databricks.WorkspaceClient, clus
305276
case <-gCtx.Done():
306277
return gCtx.Err()
307278
case <-handoverTicker.C:
308-
err := proxy.InitiateHandover(gCtx)
279+
err := conn.InitiateHandover(gCtx)
309280
if err != nil {
310281
return err
311282
}
@@ -314,13 +285,13 @@ func startSSHProxy(ctx context.Context, client *databricks.WorkspaceClient, clus
314285
})
315286

316287
g.Go(func() error {
317-
return proxy.Start(gCtx, os.Stdin, os.Stdout)
288+
return conn.Start(gCtx, os.Stdin, os.Stdout)
318289
})
319290

320291
return g.Wait()
321292
}
322293

323-
func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceClient, clusterID, secretsScope, publicKeySecretName, version string, shutdownDelay time.Duration, maxClients int) (string, int, error) {
294+
func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceClient, clusterID, secretsScope, publicKeySecretName, version string, shutdownDelay time.Duration, maxClients int, serverTimeout time.Duration) (string, int, error) {
324295
cmdio.LogString(ctx, "Ensuring the cluster is running: "+clusterID)
325296
err := client.Clusters.EnsureClusterIsRunning(ctx, clusterID)
326297
if err != nil {
@@ -331,7 +302,7 @@ func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceC
331302
if errors.Is(err, errServerMetadata) {
332303
cmdio.LogString(ctx, "SSH server is not running, starting it now...")
333304

334-
runID, err := submitSSHTunnelJob(ctx, client, clusterID, secretsScope, publicKeySecretName, version, shutdownDelay, maxClients)
305+
runID, err := submitSSHTunnelJob(ctx, client, clusterID, secretsScope, publicKeySecretName, version, shutdownDelay, maxClients, serverTimeout)
335306
if err != nil {
336307
return "", 0, fmt.Errorf("failed to submit ssh server job: %w", err)
337308
}

0 commit comments

Comments
 (0)