Skip to content

Commit 6cb8f6a

Browse files
committed
Move pure ssh client and server logic to proxy package and test it
1 parent 8f69376 commit 6cb8f6a

File tree

12 files changed

+574
-290
lines changed

12 files changed

+574
-290
lines changed

experimental/ssh/bench.sh

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
#!/bin/bash
2+
3+
# SSH Tunnel Performance Test Script
4+
# Usage:
5+
# 1. Setup ssh config: ./cli ssh setup --cluster --name
6+
# 2. Run: ./experimental/ssh/bench.sh <cluster-id> <ssh-config-hostname> [ssh-tunnel-binary-path] [profile]
7+
8+
set -e
9+
10+
CLUSTER_ID="$1"
11+
HOSTNAME="$2"
12+
CLI=${3:-./cli}
13+
PROFILE="${4:-DEFAULT}"
14+
15+
TEST_SIZES=(300 600) # MB
16+
17+
if [ -z "$CLUSTER_ID" ] || [ -z "$HOSTNAME" ]; then
18+
echo "Usage: $0 <cluster-id> <hostname> [ssh-tunnel-binary-path] [profile]"
19+
exit 1
20+
fi
21+
22+
echo "=== SSH Tunnel Performance Test ==="
23+
echo "Cluster ID: $CLUSTER_ID"
24+
echo "Hostname: $HOSTNAME"
25+
echo "Profile: $PROFILE"
26+
echo "Start time: $(date)"
27+
echo "SSH Tunnel: $CLI"
28+
echo
29+
30+
# Test basic connectivity
31+
echo "🔍 Testing basic connectivity..."
32+
33+
start_time=$(date +%s.%N)
34+
if ! $CLI ssh connect --cluster="$CLUSTER_ID" --profile="$PROFILE" --releases-dir=./dist -- "echo 'Connection successful'"; then
35+
echo "❌ Failed to connect to the cluster"
36+
exit 1
37+
fi
38+
end_time=$(date +%s.%N)
39+
duration=$(echo "$end_time - $start_time" | bc)
40+
duration_ms=$(echo "$duration * 1000" | bc)
41+
echo "✅ Basic connectivity OK ($duration_ms ms)"
42+
echo
43+
44+
# Throughput Tests
45+
echo "⚡ Testing Throughput..."
46+
47+
# Create test files
48+
echo "Creating test files..."
49+
for size in "${TEST_SIZES[@]}"; do
50+
if [ ! -f "test_${size}mb.dat" ]; then
51+
dd if=/dev/zero of="test_${size}mb.dat" bs=1M count=$size 2>/dev/null
52+
echo " Created test_${size}mb.dat"
53+
fi
54+
done
55+
echo
56+
57+
# Upload tests
58+
echo "📤 Upload Speed Tests:"
59+
for size in "${TEST_SIZES[@]}"; do
60+
echo -n " ${size}MB file: "
61+
scp "test_${size}mb.dat" "$HOSTNAME:/tmp/test_upload_${size}mb.dat"
62+
done
63+
echo
64+
65+
# Download tests
66+
echo "📥 Download Speed Tests:"
67+
for size in "${TEST_SIZES[@]}"; do
68+
echo -n " ${size}MB file: "
69+
scp "$HOSTNAME:/tmp/test_upload_${size}mb.dat" "./test_download_${size}mb.dat"
70+
done
71+
echo
72+
73+
# Cleanup
74+
echo "🧹 Cleaning up..."
75+
for size in "${TEST_SIZES[@]}"; do
76+
rm -f "test_${size}mb.dat" "test_download_${size}mb.dat"
77+
$CLI ssh connect --cluster="$CLUSTER_ID" --profile="$PROFILE" -- "rm -f /tmp/test_upload_${size}mb.dat" 2>/dev/null || true
78+
done
79+
echo

experimental/ssh/cmd/connect.go

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import (
55

66
"github.com/databricks/cli/cmd/root"
77
"github.com/databricks/cli/experimental/ssh/internal/client"
8-
"github.com/databricks/cli/experimental/ssh/internal/proxy"
98
"github.com/databricks/cli/libs/cmdctx"
109
"github.com/spf13/cobra"
1110
)
@@ -61,11 +60,7 @@ the SSH server and handling the connection proxy.
6160
ClientPublicKeyName: defaultClientPublicKeyName,
6261
ServerTimeout: serverTimeout,
6362
}
64-
err := client.RunClient(ctx, wsClient, opts)
65-
if err != nil && proxy.IsNormalClosure(err) {
66-
return nil
67-
}
68-
return err
63+
return client.Run(ctx, wsClient, opts)
6964
}
7065

7166
return cmd

experimental/ssh/cmd/server.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ and proxies them to local SSH daemon processes.
4444
cmd.PreRunE = root.MustWorkspaceClient
4545
cmd.RunE = func(cmd *cobra.Command, args []string) error {
4646
ctx := cmd.Context()
47-
client := cmdctx.WorkspaceClient(ctx)
47+
wsc := cmdctx.WorkspaceClient(ctx)
4848
opts := server.ServerOptions{
4949
ClusterID: clusterID,
5050
MaxClients: maxClients,
@@ -58,7 +58,7 @@ and proxies them to local SSH daemon processes.
5858
DefaultPort: defaultServerPort,
5959
PortRange: serverPortRange,
6060
}
61-
return server.RunServer(ctx, client, opts)
61+
return server.Run(ctx, wsc, opts)
6262
}
6363

6464
return cmd

experimental/ssh/internal/client/client.go

Lines changed: 5 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ import (
2626
"github.com/databricks/databricks-sdk-go/service/jobs"
2727
"github.com/databricks/databricks-sdk-go/service/workspace"
2828
"github.com/gorilla/websocket"
29-
"golang.org/x/sync/errgroup"
3029
)
3130

3231
//go:embed ssh-server-bootstrap.py
@@ -63,7 +62,7 @@ type ClientOptions struct {
6362
AdditionalArgs []string
6463
}
6564

66-
func RunClient(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOptions) error {
65+
func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOptions) error {
6766
ctx, cancel := context.WithCancel(ctx)
6867
defer cancel()
6968

@@ -124,7 +123,7 @@ func RunClient(ctx context.Context, client *databricks.WorkspaceClient, opts Cli
124123
cmdio.LogString(ctx, fmt.Sprintf("Server port: %d", serverPort))
125124

126125
if opts.ProxyMode {
127-
return startSSHProxy(ctx, client, opts.ClusterID, serverPort, opts.HandoverTimeout)
126+
return runSSHProxy(ctx, client, opts.ClusterID, serverPort, opts.HandoverTimeout)
128127
} else {
129128
cmdio.LogString(ctx, fmt.Sprintf("Additional SSH arguments: %v", opts.AdditionalArgs))
130129
return spawnSSHClient(ctx, opts.ClusterID, userName, privateKeyPath, serverPort, opts.HandoverTimeout, opts.AdditionalArgs)
@@ -253,42 +252,11 @@ func spawnSSHClient(ctx context.Context, clusterID, userName, privateKeyPath str
253252
return sshCmd.Run()
254253
}
255254

256-
func startSSHProxy(ctx context.Context, client *databricks.WorkspaceClient, clusterID string, serverPort int, handoverTimeout time.Duration) error {
257-
g, gCtx := errgroup.WithContext(ctx)
258-
259-
cmdio.LogString(ctx, "Establishing SSH proxy connection...")
260-
conn := proxy.NewProxyConnection(func(ctx context.Context, connID string) (*websocket.Conn, error) {
255+
func runSSHProxy(ctx context.Context, client *databricks.WorkspaceClient, clusterID string, serverPort int, handoverTimeout time.Duration) error {
256+
createConn := func(ctx context.Context, connID string) (*websocket.Conn, error) {
261257
return createWebsocketConnection(ctx, client, connID, clusterID, serverPort)
262-
})
263-
if err := conn.Connect(gCtx); err != nil {
264-
return fmt.Errorf("failed to connect to proxy: %w", err)
265258
}
266-
defer conn.Close()
267-
cmdio.LogString(ctx, "SSH proxy connection established")
268-
269-
cmdio.LogString(ctx, fmt.Sprintf("Connection handover timeout: %v", handoverTimeout))
270-
handoverTicker := time.NewTicker(handoverTimeout)
271-
defer handoverTicker.Stop()
272-
273-
g.Go(func() error {
274-
for {
275-
select {
276-
case <-gCtx.Done():
277-
return gCtx.Err()
278-
case <-handoverTicker.C:
279-
err := conn.InitiateHandover(gCtx)
280-
if err != nil {
281-
return err
282-
}
283-
}
284-
}
285-
})
286-
287-
g.Go(func() error {
288-
return conn.Start(gCtx, os.Stdin, os.Stdout)
289-
})
290-
291-
return g.Wait()
259+
return proxy.RunClientProxy(ctx, os.Stdin, os.Stdout, handoverTimeout, createConn)
292260
}
293261

294262
func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceClient, clusterID, secretsScope, publicKeySecretName, version string, shutdownDelay time.Duration, maxClients int, serverTimeout time.Duration) (string, int, error) {
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
package proxy
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"io"
7+
"time"
8+
9+
"github.com/databricks/cli/libs/cmdio"
10+
"golang.org/x/sync/errgroup"
11+
)
12+
13+
func RunClientProxy(ctx context.Context, src io.Reader, dst io.Writer, handoverTimeout time.Duration, createConn createWebsocketConnectionFunc) error {
14+
proxy := newProxyConnection(createConn)
15+
cmdio.LogString(ctx, "Establishing SSH proxy connection...")
16+
g, gCtx := errgroup.WithContext(ctx)
17+
if err := proxy.connect(gCtx); err != nil {
18+
return fmt.Errorf("failed to connect to proxy: %w", err)
19+
}
20+
defer proxy.close()
21+
cmdio.LogString(ctx, "SSH proxy connection established")
22+
23+
cmdio.LogString(ctx, fmt.Sprintf("Connection handover timeout: %v", handoverTimeout))
24+
handoverTicker := time.NewTicker(handoverTimeout)
25+
defer handoverTicker.Stop()
26+
27+
g.Go(func() error {
28+
for {
29+
select {
30+
case <-gCtx.Done():
31+
return gCtx.Err()
32+
case <-handoverTicker.C:
33+
err := proxy.initiateHandover(gCtx)
34+
if err != nil {
35+
return err
36+
}
37+
}
38+
}
39+
})
40+
41+
g.Go(func() error {
42+
return proxy.start(gCtx, src, dst)
43+
})
44+
45+
return g.Wait()
46+
}

0 commit comments

Comments
 (0)