Skip to content

Commit 7478778

Browse files
committed
Refactor proxy and add unit tests
1 parent d4e26d5 commit 7478778

File tree

4 files changed

+304
-68
lines changed

4 files changed

+304
-68
lines changed

libs/ssh/client.go

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
"github.com/databricks/databricks-sdk-go"
2424
"github.com/databricks/databricks-sdk-go/service/jobs"
2525
"github.com/databricks/databricks-sdk-go/service/workspace"
26+
"github.com/gorilla/websocket"
2627
"golang.org/x/sync/errgroup"
2728
)
2829

@@ -269,21 +270,15 @@ func startSSHProxy(ctx context.Context, client *databricks.WorkspaceClient, clus
269270
g, gCtx := errgroup.WithContext(ctx)
270271

271272
cmdio.LogString(ctx, "Establishing SSH proxy connection...")
272-
proxy := newProxyConnection()
273-
if err := proxy.Connect(gCtx, client, clusterID, serverPort); err != nil {
273+
proxy := newProxyConnection(func(ctx context.Context, connID string) (*websocket.Conn, error) {
274+
return createWebsocketConnection(ctx, client, connID, clusterID, serverPort)
275+
})
276+
if err := proxy.Connect(gCtx); err != nil {
274277
return fmt.Errorf("failed to connect to proxy: %w", err)
275278
}
279+
defer proxy.Close()
276280
cmdio.LogString(ctx, "SSH proxy connection established")
277281

278-
go func() {
279-
<-ctx.Done()
280-
cmdio.LogString(ctx, "Closing ssh proxy connection")
281-
err := proxy.Close()
282-
if err != nil {
283-
cmdio.LogError(ctx, err)
284-
}
285-
}()
286-
287282
cmdio.LogString(ctx, fmt.Sprintf("Connection handover timeout: %v", handoverTimeout))
288283
handoverTicker := time.NewTicker(handoverTimeout)
289284
defer handoverTicker.Stop()
@@ -294,7 +289,7 @@ func startSSHProxy(ctx context.Context, client *databricks.WorkspaceClient, clus
294289
case <-gCtx.Done():
295290
return gCtx.Err()
296291
case <-handoverTicker.C:
297-
err := proxy.InitiateHandover(gCtx, client, clusterID, serverPort)
292+
err := proxy.InitiateHandover(gCtx)
298293
if err != nil {
299294
return err
300295
}
@@ -303,11 +298,7 @@ func startSSHProxy(ctx context.Context, client *databricks.WorkspaceClient, clus
303298
})
304299

305300
g.Go(func() error {
306-
return proxy.RunReceivingLoop(gCtx, os.Stdout)
307-
})
308-
309-
g.Go(func() error {
310-
return proxy.RunSendingLoop(gCtx, os.Stdin)
301+
return proxy.Start(gCtx, os.Stdin, os.Stdout)
311302
})
312303

313304
return g.Wait()

libs/ssh/proxy.go

Lines changed: 37 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@ import (
1010
"sync/atomic"
1111
"time"
1212

13-
"github.com/databricks/databricks-sdk-go"
1413
"github.com/google/uuid"
1514
"github.com/gorilla/websocket"
15+
"golang.org/x/sync/errgroup"
1616
)
1717

1818
var (
@@ -23,69 +23,45 @@ var (
2323
)
2424

2525
type proxyConnection struct {
26-
workspaceID int64
27-
connID string
28-
conn atomic.Value // *websocket.Conn
26+
connID string
27+
conn atomic.Value // *websocket.Conn
28+
createWebsocketConnection createWebsocketConnectionFunc
2929

3030
handoverMutex sync.Mutex
3131
isHandover atomic.Bool
3232
currentConnectionClosed chan error
3333
}
3434

35-
func newProxyConnection() *proxyConnection {
35+
type createWebsocketConnectionFunc func(ctx context.Context, connID string) (*websocket.Conn, error)
36+
37+
func newProxyConnection(createConn createWebsocketConnectionFunc) *proxyConnection {
3638
return &proxyConnection{
37-
connID: uuid.NewString(),
38-
currentConnectionClosed: make(chan error),
39+
connID: uuid.NewString(),
40+
currentConnectionClosed: make(chan error),
41+
createWebsocketConnection: createConn,
3942
}
4043
}
4144

42-
func (pc *proxyConnection) Connect(ctx context.Context, client *databricks.WorkspaceClient, clusterID string, serverPort int) error {
43-
conn, err := pc.createWebsocketConnection(ctx, client, clusterID, serverPort)
45+
func (pc *proxyConnection) Start(ctx context.Context, src io.Reader, dst io.Writer) error {
46+
g, gCtx := errgroup.WithContext(ctx)
47+
g.Go(func() error {
48+
return pc.runSendingLoop(gCtx, src)
49+
})
50+
g.Go(func() error {
51+
return pc.runReceivingLoop(gCtx, dst)
52+
})
53+
return g.Wait()
54+
}
55+
56+
func (pc *proxyConnection) Connect(ctx context.Context) error {
57+
conn, err := pc.createWebsocketConnection(ctx, pc.connID)
4458
if err != nil {
4559
return err
4660
}
4761
pc.conn.Store(conn)
4862
return nil
4963
}
5064

51-
func (pc *proxyConnection) createWebsocketConnection(ctx context.Context, client *databricks.WorkspaceClient, clusterID string, serverPort int) (*websocket.Conn, error) {
52-
url, err := pc.getProxyURL(ctx, client, clusterID, serverPort)
53-
if err != nil {
54-
return nil, fmt.Errorf("failed to get proxy URL: %w", err)
55-
}
56-
57-
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
58-
if err != nil {
59-
return nil, fmt.Errorf("failed to create request: %w", err)
60-
}
61-
62-
if err := client.Config.Authenticate(req); err != nil {
63-
return nil, fmt.Errorf("failed to authenticate: %w", err)
64-
}
65-
66-
req.URL.Scheme = "wss"
67-
// websocket connection manages lifecycle of the response object, no need to close the body
68-
conn, _, err := websocket.DefaultDialer.Dial(req.URL.String(), req.Header) // nolint:bodyclose
69-
if err != nil {
70-
return nil, fmt.Errorf("failed to establish websocket connection: %w", err)
71-
}
72-
73-
return conn, nil
74-
}
75-
76-
func (pc *proxyConnection) getProxyURL(ctx context.Context, client *databricks.WorkspaceClient, clusterID string, serverPort int) (string, error) {
77-
if pc.workspaceID == 0 {
78-
workspaceID, err := client.CurrentWorkspaceID(ctx)
79-
if err != nil {
80-
return "", fmt.Errorf("failed to get current workspace ID: %w", err)
81-
}
82-
pc.workspaceID = workspaceID
83-
}
84-
host := client.Config.Host
85-
url := fmt.Sprintf("%s/driver-proxy-api/o/%d/%s/%d/ssh?id=%s", host, pc.workspaceID, clusterID, serverPort, pc.connID)
86-
return url, nil
87-
}
88-
8965
func (pc *proxyConnection) Accept(w http.ResponseWriter, r *http.Request) error {
9066
conn, err := pc.acceptWebsocketConnection(w, r)
9167
if err != nil {
@@ -104,8 +80,11 @@ func (pc *proxyConnection) acceptWebsocketConnection(w http.ResponseWriter, r *h
10480
return conn, nil
10581
}
10682

107-
func (pc *proxyConnection) RunSendingLoop(ctx context.Context, src io.Reader) error {
83+
func (pc *proxyConnection) runSendingLoop(ctx context.Context, src io.Reader) error {
10884
for {
85+
if ctx.Err() != nil {
86+
return ctx.Err()
87+
}
10988
b := make([]byte, 32*1024)
11089
n, readErr := src.Read(b)
11190
if n > 0 {
@@ -133,8 +112,11 @@ func (pc *proxyConnection) sendMessage(mt int, data []byte) error {
133112
return conn.WriteMessage(mt, data)
134113
}
135114

136-
func (pc *proxyConnection) RunReceivingLoop(ctx context.Context, dst io.Writer) error {
115+
func (pc *proxyConnection) runReceivingLoop(ctx context.Context, dst io.Writer) error {
137116
for {
117+
if ctx.Err() != nil {
118+
return ctx.Err()
119+
}
138120
conn := pc.conn.Load().(*websocket.Conn)
139121
mt, data, err := conn.ReadMessage()
140122
if err != nil {
@@ -205,7 +187,7 @@ func (pc *proxyConnection) Close() error {
205187
return nil
206188
}
207189

208-
func (pc *proxyConnection) InitiateHandover(ctx context.Context, client *databricks.WorkspaceClient, clusterID string, serverPort int) error {
190+
func (pc *proxyConnection) InitiateHandover(ctx context.Context) error {
209191
// Blocks proxying any outgoing messages during the entire handover
210192
pc.handoverMutex.Lock()
211193
defer pc.handoverMutex.Unlock()
@@ -220,7 +202,7 @@ func (pc *proxyConnection) InitiateHandover(ctx context.Context, client *databri
220202

221203
// Create a new websocket connection by sending an /ssh?id=<connID> request to the server.
222204
// When server realises it's an ID of an existing connection, it will start AcceptHandover process.
223-
newConn, err := pc.createWebsocketConnection(ctx, client, clusterID, serverPort)
205+
newConn, err := pc.createWebsocketConnection(ctx, pc.connID)
224206
if err != nil {
225207
return fmt.Errorf("failed to create new websocket connection: %w", err)
226208
}
@@ -287,3 +269,7 @@ func (pc *proxyConnection) AcceptHandover(ctx context.Context, w http.ResponseWr
287269

288270
return nil
289271
}
272+
273+
func IsNormalClosure(err error) bool {
274+
return websocket.IsCloseError(err, websocket.CloseNormalClosure) || errors.Is(err, errProxyEOF)
275+
}

0 commit comments

Comments
 (0)