Skip to content

Commit 5c89e05

Browse files
committed
Move pure ssh client and server logic to proxy package and test it
1 parent c2e3322 commit 5c89e05

File tree

17 files changed

+851
-533
lines changed

17 files changed

+851
-533
lines changed

experimental/ssh/bench.sh

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
#!/bin/bash

# SSH Tunnel Performance Test Script
#
# Verifies that the cluster is reachable through the SSH tunnel, then copies
# zero-filled test files of every size in TEST_SIZES to the cluster and back
# with scp. Transfer speed is whatever scp itself prints; this script only
# measures the duration of the initial connectivity check.
#
# Usage:
# 1. Setup ssh config: ./cli ssh setup --cluster <cluster-id> --name <ssh-config-hostname>
# 2. Run: ./experimental/ssh/bench.sh <cluster-id> <ssh-config-hostname> [ssh-tunnel-binary-path] [profile]

set -e

CLUSTER_ID="$1"
# Deliberately not called HOSTNAME: bash maintains a variable of that name
# itself, and overwriting it changes what child processes see when it is exported.
SSH_HOST="$2"
CLI="${3:-./cli}"
PROFILE="${4:-DEFAULT}"

TEST_SIZES=(300 600) # MB

if [ -z "$CLUSTER_ID" ] || [ -z "$SSH_HOST" ]; then
    echo "Usage: $0 <cluster-id> <hostname> [ssh-tunnel-binary-path] [profile]" >&2
    exit 1
fi

# Prints the current time in seconds, with sub-second precision when the local
# `date` supports %N. Implementations without %N (reportedly BSD/macOS date)
# print the format letter literally, which would break the bc arithmetic below,
# so fall back to whole seconds in that case.
now() {
    local stamp
    stamp=$(date +%s.%N)
    case "$stamp" in
        *N) date +%s ;;
        *) echo "$stamp" ;;
    esac
}

# Removes the local test files and, best effort, the uploaded copies on the
# cluster. Installed as an EXIT trap so it also runs when a transfer fails
# under `set -e`.
cleanup() {
    local size
    echo "🧹 Cleaning up..."
    for size in "${TEST_SIZES[@]}"; do
        rm -f "test_${size}mb.dat" "test_download_${size}mb.dat"
        "$CLI" ssh connect --cluster="$CLUSTER_ID" --profile="$PROFILE" -- "rm -f /tmp/test_upload_${size}mb.dat" 2>/dev/null || true
    done
    echo
}

echo "=== SSH Tunnel Performance Test ==="
echo "Cluster ID: $CLUSTER_ID"
echo "Hostname: $SSH_HOST"
echo "Profile: $PROFILE"
echo "Start time: $(date)"
echo "SSH Tunnel: $CLI"
echo

# Test basic connectivity
echo "🔍 Testing basic connectivity..."

start_time=$(now)
if ! "$CLI" ssh connect --cluster="$CLUSTER_ID" --profile="$PROFILE" --releases-dir=./dist -- "echo 'Connection successful'"; then
    echo "❌ Failed to connect to the cluster"
    exit 1
fi
end_time=$(now)
duration=$(echo "$end_time - $start_time" | bc)
duration_ms=$(echo "$duration * 1000" | bc)
echo "✅ Basic connectivity OK ($duration_ms ms)"
echo

# Register cleanup only now: before this point nothing has been created, and a
# failed connectivity check should not trigger further connection attempts.
trap cleanup EXIT

# Throughput Tests
echo "⚡ Testing Throughput..."

# Create test files
echo "Creating test files..."
for size in "${TEST_SIZES[@]}"; do
    if [ ! -f "test_${size}mb.dat" ]; then
        # Block size is given in plain bytes (1 MiB) because size suffixes
        # such as "1M" are not accepted by every dd implementation.
        if ! dd if=/dev/zero of="test_${size}mb.dat" bs=1048576 count="$size" 2>/dev/null; then
            echo "❌ Failed to create test_${size}mb.dat" >&2
            exit 1
        fi
        echo " Created test_${size}mb.dat"
    fi
done
echo

# Upload tests
echo "📤 Upload Speed Tests:"
for size in "${TEST_SIZES[@]}"; do
    echo -n " ${size}MB file: "
    scp "test_${size}mb.dat" "$SSH_HOST:/tmp/test_upload_${size}mb.dat"
done
echo

# Download tests
echo "📥 Download Speed Tests:"
for size in "${TEST_SIZES[@]}"; do
    echo -n " ${size}MB file: "
    scp "$SSH_HOST:/tmp/test_upload_${size}mb.dat" "./test_download_${size}mb.dat"
done
echo

# Cleanup happens in the EXIT trap registered above.

experimental/ssh/cmd/connect.go

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import (
55

66
"github.com/databricks/cli/cmd/root"
77
"github.com/databricks/cli/experimental/ssh/internal/client"
8-
"github.com/databricks/cli/experimental/ssh/internal/proxy"
98
"github.com/databricks/cli/libs/cmdctx"
109
"github.com/spf13/cobra"
1110
)
@@ -61,11 +60,7 @@ the SSH server and handling the connection proxy.
6160
ClientPublicKeyName: defaultClientPublicKeyName,
6261
ServerTimeout: serverTimeout,
6362
}
64-
err := client.RunClient(ctx, wsClient, opts)
65-
if err != nil && proxy.IsNormalClosure(err) {
66-
return nil
67-
}
68-
return err
63+
return client.Run(ctx, wsClient, opts)
6964
}
7065

7166
return cmd

experimental/ssh/cmd/server.go

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,15 @@ and proxies them to local SSH daemon processes.
2727
var shutdownDelay time.Duration
2828
var clusterID string
2929
var version string
30-
var secretsScope string
31-
var publicKeySecretName string
30+
var keysSecretScopeName string
31+
var authorizedKeyName string
3232

3333
cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID")
3434
cmd.MarkFlagRequired("cluster")
35-
cmd.Flags().StringVar(&secretsScope, "secrets-scope-name", "", "Databricks secrets scope name")
36-
cmd.MarkFlagRequired("secrets-scope-name")
37-
cmd.Flags().StringVar(&publicKeySecretName, "client-key-name", "", "Databricks client key name")
38-
cmd.MarkFlagRequired("client-key-name")
35+
cmd.Flags().StringVar(&keysSecretScopeName, "keys-secret-scope-name", "", "Databricks secret scope name to store SSH keys")
36+
cmd.MarkFlagRequired("keys-secret-scope-name")
37+
cmd.Flags().StringVar(&authorizedKeyName, "authorized-key-secret-name", "", "Authorized key secret name in the secret scope")
38+
cmd.MarkFlagRequired("authorized-key-secret-name")
3939

4040
cmd.Flags().IntVar(&maxClients, "max-clients", defaultMaxClients, "Maximum number of SSH clients")
4141
cmd.Flags().DurationVar(&shutdownDelay, "shutdown-delay", defaultShutdownDelay, "Delay before shutting down after no pings from clients")
@@ -44,21 +44,21 @@ and proxies them to local SSH daemon processes.
4444
cmd.PreRunE = root.MustWorkspaceClient
4545
cmd.RunE = func(cmd *cobra.Command, args []string) error {
4646
ctx := cmd.Context()
47-
client := cmdctx.WorkspaceClient(ctx)
47+
wsc := cmdctx.WorkspaceClient(ctx)
4848
opts := server.ServerOptions{
4949
ClusterID: clusterID,
5050
MaxClients: maxClients,
5151
ShutdownDelay: shutdownDelay,
5252
Version: version,
5353
ConfigDir: serverConfigDir,
54-
SecretsScope: secretsScope,
55-
ClientPublicKeyName: publicKeySecretName,
54+
KeysSecretScopeName: keysSecretScopeName,
55+
AuthorizedKeyName: authorizedKeyName,
5656
ServerPrivateKeyName: serverPrivateKeyName,
5757
ServerPublicKeyName: serverPublicKeyName,
5858
DefaultPort: defaultServerPort,
5959
PortRange: serverPortRange,
6060
}
61-
return server.RunServer(ctx, client, opts)
61+
return server.Run(ctx, wsc, opts)
6262
}
6363

6464
return cmd

experimental/ssh/internal/client/client.go

Lines changed: 16 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ import (
2626
"github.com/databricks/databricks-sdk-go/service/jobs"
2727
"github.com/databricks/databricks-sdk-go/service/workspace"
2828
"github.com/gorilla/websocket"
29-
"golang.org/x/sync/errgroup"
3029
)
3130

3231
//go:embed ssh-server-bootstrap.py
@@ -63,7 +62,7 @@ type ClientOptions struct {
6362
AdditionalArgs []string
6463
}
6564

66-
func RunClient(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOptions) error {
65+
func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOptions) error {
6766
ctx, cancel := context.WithCancel(ctx)
6867
defer cancel()
6968

@@ -85,11 +84,11 @@ func RunClient(ctx context.Context, client *databricks.WorkspaceClient, opts Cli
8584
}
8685
cmdio.LogString(ctx, "Using SSH key: "+privateKeyPath)
8786

88-
secretsScopeName, err := keys.PutSecretInScope(ctx, client, opts.ClusterID, opts.ClientPublicKeyName, publicKey)
87+
keysSecretScopeName, err := keys.PutSecretInScope(ctx, client, opts.ClusterID, opts.ClientPublicKeyName, publicKey)
8988
if err != nil {
9089
return fmt.Errorf("failed to store public key in secret scope: %w", err)
9190
}
92-
cmdio.LogString(ctx, fmt.Sprintf("Secrets scope: %s, key name: %s", secretsScopeName, opts.ClientPublicKeyName))
91+
cmdio.LogString(ctx, fmt.Sprintf("Secrets scope: %s, key name: %s", keysSecretScopeName, opts.ClientPublicKeyName))
9392

9493
var userName string
9594
var serverPort int
@@ -101,7 +100,7 @@ func RunClient(ctx context.Context, client *databricks.WorkspaceClient, opts Cli
101100
if err := UploadTunnelReleases(ctx, client, version, opts.ReleasesDir); err != nil {
102101
return fmt.Errorf("failed to upload ssh-tunnel binaries: %w", err)
103102
}
104-
userName, serverPort, err = ensureSSHServerIsRunning(ctx, client, opts.ClusterID, secretsScopeName, opts.ClientPublicKeyName, version, opts.ShutdownDelay, opts.MaxClients, opts.ServerTimeout)
103+
userName, serverPort, err = ensureSSHServerIsRunning(ctx, client, opts.ClusterID, keysSecretScopeName, opts.ClientPublicKeyName, version, opts.ShutdownDelay, opts.MaxClients, opts.ServerTimeout)
105104
if err != nil {
106105
return fmt.Errorf("failed to ensure that ssh server is running: %w", err)
107106
}
@@ -124,7 +123,7 @@ func RunClient(ctx context.Context, client *databricks.WorkspaceClient, opts Cli
124123
cmdio.LogString(ctx, fmt.Sprintf("Server port: %d", serverPort))
125124

126125
if opts.ProxyMode {
127-
return startSSHProxy(ctx, client, opts.ClusterID, serverPort, opts.HandoverTimeout)
126+
return runSSHProxy(ctx, client, opts.ClusterID, serverPort, opts.HandoverTimeout)
128127
} else {
129128
cmdio.LogString(ctx, fmt.Sprintf("Additional SSH arguments: %v", opts.AdditionalArgs))
130129
return spawnSSHClient(ctx, opts.ClusterID, userName, privateKeyPath, serverPort, opts.HandoverTimeout, opts.AdditionalArgs)
@@ -165,7 +164,7 @@ func getServerMetadata(ctx context.Context, client *databricks.WorkspaceClient,
165164
return serverPort, string(bodyBytes), nil
166165
}
167166

168-
func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, clusterID, secretsScope, publicKeySecretName, version string, shutdownDelay time.Duration, maxClients int, serverTimeout time.Duration) (int64, error) {
167+
func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, clusterID, keysSecretScopeName, publicKeySecretName, version string, shutdownDelay time.Duration, maxClients int, serverTimeout time.Duration) (int64, error) {
169168
contentDir, err := sshWorkspace.GetWorkspaceContentDir(ctx, client, version, clusterID)
170169
if err != nil {
171170
return 0, fmt.Errorf("failed to get workspace content directory: %w", err)
@@ -201,11 +200,11 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient,
201200
NotebookTask: &jobs.NotebookTask{
202201
NotebookPath: jobNotebookPath,
203202
BaseParameters: map[string]string{
204-
"version": version,
205-
"secretsScope": secretsScope,
206-
"publicKeySecretName": publicKeySecretName,
207-
"shutdownDelay": shutdownDelay.String(),
208-
"maxClients": strconv.Itoa(maxClients),
203+
"version": version,
204+
"keysSecretScopeName": keysSecretScopeName,
205+
"authorizedKeySecretName": publicKeySecretName,
206+
"shutdownDelay": shutdownDelay.String(),
207+
"maxClients": strconv.Itoa(maxClients),
209208
},
210209
},
211210
TimeoutSeconds: int(serverTimeout.Seconds()),
@@ -253,45 +252,14 @@ func spawnSSHClient(ctx context.Context, clusterID, userName, privateKeyPath str
253252
return sshCmd.Run()
254253
}
255254

256-
func startSSHProxy(ctx context.Context, client *databricks.WorkspaceClient, clusterID string, serverPort int, handoverTimeout time.Duration) error {
257-
g, gCtx := errgroup.WithContext(ctx)
258-
259-
cmdio.LogString(ctx, "Establishing SSH proxy connection...")
260-
conn := proxy.NewProxyConnection(func(ctx context.Context, connID string) (*websocket.Conn, error) {
255+
func runSSHProxy(ctx context.Context, client *databricks.WorkspaceClient, clusterID string, serverPort int, handoverTimeout time.Duration) error {
256+
createConn := func(ctx context.Context, connID string) (*websocket.Conn, error) {
261257
return createWebsocketConnection(ctx, client, connID, clusterID, serverPort)
262-
})
263-
if err := conn.Connect(gCtx); err != nil {
264-
return fmt.Errorf("failed to connect to proxy: %w", err)
265258
}
266-
defer conn.Close()
267-
cmdio.LogString(ctx, "SSH proxy connection established")
268-
269-
cmdio.LogString(ctx, fmt.Sprintf("Connection handover timeout: %v", handoverTimeout))
270-
handoverTicker := time.NewTicker(handoverTimeout)
271-
defer handoverTicker.Stop()
272-
273-
g.Go(func() error {
274-
for {
275-
select {
276-
case <-gCtx.Done():
277-
return gCtx.Err()
278-
case <-handoverTicker.C:
279-
err := conn.InitiateHandover(gCtx)
280-
if err != nil {
281-
return err
282-
}
283-
}
284-
}
285-
})
286-
287-
g.Go(func() error {
288-
return conn.Start(gCtx, os.Stdin, os.Stdout)
289-
})
290-
291-
return g.Wait()
259+
return proxy.RunClientProxy(ctx, os.Stdin, os.Stdout, handoverTimeout, createConn)
292260
}
293261

294-
func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceClient, clusterID, secretsScope, publicKeySecretName, version string, shutdownDelay time.Duration, maxClients int, serverTimeout time.Duration) (string, int, error) {
262+
func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceClient, clusterID, keysSecretScopeName, publicKeySecretName, version string, shutdownDelay time.Duration, maxClients int, serverTimeout time.Duration) (string, int, error) {
295263
cmdio.LogString(ctx, "Ensuring the cluster is running: "+clusterID)
296264
err := client.Clusters.EnsureClusterIsRunning(ctx, clusterID)
297265
if err != nil {
@@ -302,7 +270,7 @@ func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceC
302270
if errors.Is(err, errServerMetadata) {
303271
cmdio.LogString(ctx, "SSH server is not running, starting it now...")
304272

305-
runID, err := submitSSHTunnelJob(ctx, client, clusterID, secretsScope, publicKeySecretName, version, shutdownDelay, maxClients, serverTimeout)
273+
runID, err := submitSSHTunnelJob(ctx, client, clusterID, keysSecretScopeName, publicKeySecretName, version, shutdownDelay, maxClients, serverTimeout)
306274
if err != nil {
307275
return "", 0, fmt.Errorf("failed to submit ssh server job: %w", err)
308276
}

experimental/ssh/internal/client/ssh-server-bootstrap.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
SSH_TUNNEL_BASENAME = "databricks_cli_linux"
1414

1515
dbutils.widgets.text("version", "")
16-
dbutils.widgets.text("secretsScope", "")
17-
dbutils.widgets.text("publicKeySecretName", "")
16+
dbutils.widgets.text("keysSecretScopeName", "")
17+
dbutils.widgets.text("authorizedKeySecretName", "")
1818
dbutils.widgets.text("maxClients", "10")
1919
dbutils.widgets.text("shutdownDelay", "10m")
2020

@@ -86,13 +86,13 @@ def run_ssh_server():
8686
if os.environ.get("VIRTUAL_ENV") is None:
8787
os.environ["VIRTUAL_ENV"] = sys.executable
8888

89-
secrets_scope = dbutils.widgets.get("secretsScope")
89+
secrets_scope = dbutils.widgets.get("keysSecretScopeName")
9090
if not secrets_scope:
91-
raise RuntimeError("Secrets scope is required. Please provide it using the 'secretsScope' widget.")
91+
raise RuntimeError("Secrets scope is required. Please provide it using the 'keysSecretScopeName' widget.")
9292

93-
public_key_secret_name = dbutils.widgets.get("publicKeySecretName")
93+
public_key_secret_name = dbutils.widgets.get("authorizedKeySecretName")
9494
if not public_key_secret_name:
95-
raise RuntimeError("Public key secret name is required. Please provide it using the 'publicKeySecretName' widget.")
95+
raise RuntimeError("Public key secret name is required. Please provide it using the 'authorizedKeySecretName' widget.")
9696

9797
version = dbutils.widgets.get("version")
9898
if not version:
@@ -117,8 +117,8 @@ def run_ssh_server():
117117
"ssh",
118118
"server",
119119
f"--cluster={ctx.clusterId}",
120-
f"--secrets-scope-name={secrets_scope}",
121-
f"--client-key-name={public_key_secret_name}",
120+
f"--keys-secret-scope-name={secrets_scope}",
121+
f"--authorized-key-secret-name={public_key_secret_name}",
122122
f"--max-clients={max_clients}",
123123
f"--shutdown-delay={shutdown_delay}",
124124
f"--version={version}",

experimental/ssh/internal/keys/keys.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -125,27 +125,27 @@ func CheckAndGenerateSSHKeyPair(ctx context.Context, keyPath string) (string, st
125125
return keyPath, strings.TrimSpace(string(publicKeyBytes)), nil
126126
}
127127

128-
func CheckAndGenerateSSHKeyPairFromSecrets(ctx context.Context, client *databricks.WorkspaceClient, clusterID, secretsScopeName, privateKeyName, publicKeyName string) ([]byte, []byte, error) {
129-
privateKeyBytes, err := GetSecret(ctx, client, secretsScopeName, privateKeyName)
128+
func CheckAndGenerateSSHKeyPairFromSecrets(ctx context.Context, client *databricks.WorkspaceClient, clusterID, secretScopeName, privateKeyName, publicKeyName string) ([]byte, []byte, error) {
129+
privateKeyBytes, err := GetSecret(ctx, client, secretScopeName, privateKeyName)
130130
if err != nil {
131131
privateKeyBytes, publicKeyBytes, err := generateSSHKeyPair()
132132
if err != nil {
133133
return nil, nil, fmt.Errorf("failed to generate SSH key pair: %w", err)
134134
}
135135

136-
err = putSecret(ctx, client, secretsScopeName, privateKeyName, string(privateKeyBytes))
136+
err = putSecret(ctx, client, secretScopeName, privateKeyName, string(privateKeyBytes))
137137
if err != nil {
138138
return nil, nil, err
139139
}
140140

141-
err = putSecret(ctx, client, secretsScopeName, publicKeyName, string(publicKeyBytes))
141+
err = putSecret(ctx, client, secretScopeName, publicKeyName, string(publicKeyBytes))
142142
if err != nil {
143143
return nil, nil, err
144144
}
145145

146146
return privateKeyBytes, publicKeyBytes, nil
147147
} else {
148-
publicKeyBytes, err := GetSecret(ctx, client, secretsScopeName, publicKeyName)
148+
publicKeyBytes, err := GetSecret(ctx, client, secretScopeName, publicKeyName)
149149
if err != nil {
150150
return nil, nil, fmt.Errorf("failed to get public key from secrets scope: %w", err)
151151
}

experimental/ssh/internal/keys/secrets.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,19 @@ import (
1010
"github.com/databricks/databricks-sdk-go/service/workspace"
1111
)
1212

13-
func createSecretsScope(ctx context.Context, client *databricks.WorkspaceClient, clusterID string) (string, error) {
13+
func createKeysSecretScope(ctx context.Context, client *databricks.WorkspaceClient, clusterID string) (string, error) {
1414
me, err := client.CurrentUser.Me(ctx)
1515
if err != nil {
1616
return "", fmt.Errorf("failed to get current user: %w", err)
1717
}
18-
secretsScope := fmt.Sprintf("%s-%s-ssh-tunnel-keys", me.UserName, clusterID)
18+
secretScopeName := fmt.Sprintf("%s-%s-ssh-tunnel-keys", me.UserName, clusterID)
1919
err = client.Secrets.CreateScope(ctx, workspace.CreateScope{
20-
Scope: secretsScope,
20+
Scope: secretScopeName,
2121
})
2222
if err != nil && !errors.Is(err, databricks.ErrResourceAlreadyExists) {
2323
return "", fmt.Errorf("failed to create secrets scope: %w", err)
2424
}
25-
return secretsScope, nil
25+
return secretScopeName, nil
2626
}
2727

2828
func GetSecret(ctx context.Context, client *databricks.WorkspaceClient, scope, key string) ([]byte, error) {
@@ -54,7 +54,7 @@ func putSecret(ctx context.Context, client *databricks.WorkspaceClient, scope, k
5454
}
5555

5656
func PutSecretInScope(ctx context.Context, client *databricks.WorkspaceClient, clusterID, key, value string) (string, error) {
57-
scopeName, err := createSecretsScope(ctx, client, clusterID)
57+
scopeName, err := createKeysSecretScope(ctx, client, clusterID)
5858
if err != nil {
5959
return "", err
6060
}

0 commit comments

Comments
 (0)