@@ -10,9 +10,9 @@ import (
10
10
"sync/atomic"
11
11
"time"
12
12
13
- "github.com/databricks/databricks-sdk-go"
14
13
"github.com/google/uuid"
15
14
"github.com/gorilla/websocket"
15
+ "golang.org/x/sync/errgroup"
16
16
)
17
17
18
18
var (
@@ -23,69 +23,45 @@ var (
23
23
)
24
24
25
25
type proxyConnection struct {
26
- workspaceID int64
27
- connID string
28
- conn atomic. Value // *websocket.Conn
26
+ connID string
27
+ conn atomic. Value // *websocket.Conn
28
+ createWebsocketConnection createWebsocketConnectionFunc
29
29
30
30
handoverMutex sync.Mutex
31
31
isHandover atomic.Bool
32
32
currentConnectionClosed chan error
33
33
}
34
34
35
- func newProxyConnection () * proxyConnection {
35
+ type createWebsocketConnectionFunc func (ctx context.Context , connID string ) (* websocket.Conn , error )
36
+
37
+ func newProxyConnection (createConn createWebsocketConnectionFunc ) * proxyConnection {
36
38
return & proxyConnection {
37
- connID : uuid .NewString (),
38
- currentConnectionClosed : make (chan error ),
39
+ connID : uuid .NewString (),
40
+ currentConnectionClosed : make (chan error ),
41
+ createWebsocketConnection : createConn ,
39
42
}
40
43
}
41
44
42
- func (pc * proxyConnection ) Connect (ctx context.Context , client * databricks.WorkspaceClient , clusterID string , serverPort int ) error {
43
- conn , err := pc .createWebsocketConnection (ctx , client , clusterID , serverPort )
45
+ func (pc * proxyConnection ) Start (ctx context.Context , src io.Reader , dst io.Writer ) error {
46
+ g , gCtx := errgroup .WithContext (ctx )
47
+ g .Go (func () error {
48
+ return pc .runSendingLoop (gCtx , src )
49
+ })
50
+ g .Go (func () error {
51
+ return pc .runReceivingLoop (gCtx , dst )
52
+ })
53
+ return g .Wait ()
54
+ }
55
+
56
+ func (pc * proxyConnection ) Connect (ctx context.Context ) error {
57
+ conn , err := pc .createWebsocketConnection (ctx , pc .connID )
44
58
if err != nil {
45
59
return err
46
60
}
47
61
pc .conn .Store (conn )
48
62
return nil
49
63
}
50
64
51
- func (pc * proxyConnection ) createWebsocketConnection (ctx context.Context , client * databricks.WorkspaceClient , clusterID string , serverPort int ) (* websocket.Conn , error ) {
52
- url , err := pc .getProxyURL (ctx , client , clusterID , serverPort )
53
- if err != nil {
54
- return nil , fmt .Errorf ("failed to get proxy URL: %w" , err )
55
- }
56
-
57
- req , err := http .NewRequestWithContext (ctx , "GET" , url , nil )
58
- if err != nil {
59
- return nil , fmt .Errorf ("failed to create request: %w" , err )
60
- }
61
-
62
- if err := client .Config .Authenticate (req ); err != nil {
63
- return nil , fmt .Errorf ("failed to authenticate: %w" , err )
64
- }
65
-
66
- req .URL .Scheme = "wss"
67
- // websocket connection manages lifecycle of the response object, no need to close the body
68
- conn , _ , err := websocket .DefaultDialer .Dial (req .URL .String (), req .Header ) // nolint:bodyclose
69
- if err != nil {
70
- return nil , fmt .Errorf ("failed to establish websocket connection: %w" , err )
71
- }
72
-
73
- return conn , nil
74
- }
75
-
76
- func (pc * proxyConnection ) getProxyURL (ctx context.Context , client * databricks.WorkspaceClient , clusterID string , serverPort int ) (string , error ) {
77
- if pc .workspaceID == 0 {
78
- workspaceID , err := client .CurrentWorkspaceID (ctx )
79
- if err != nil {
80
- return "" , fmt .Errorf ("failed to get current workspace ID: %w" , err )
81
- }
82
- pc .workspaceID = workspaceID
83
- }
84
- host := client .Config .Host
85
- url := fmt .Sprintf ("%s/driver-proxy-api/o/%d/%s/%d/ssh?id=%s" , host , pc .workspaceID , clusterID , serverPort , pc .connID )
86
- return url , nil
87
- }
88
-
89
65
func (pc * proxyConnection ) Accept (w http.ResponseWriter , r * http.Request ) error {
90
66
conn , err := pc .acceptWebsocketConnection (w , r )
91
67
if err != nil {
@@ -104,8 +80,11 @@ func (pc *proxyConnection) acceptWebsocketConnection(w http.ResponseWriter, r *h
104
80
return conn , nil
105
81
}
106
82
107
- func (pc * proxyConnection ) RunSendingLoop (ctx context.Context , src io.Reader ) error {
83
+ func (pc * proxyConnection ) runSendingLoop (ctx context.Context , src io.Reader ) error {
108
84
for {
85
+ if ctx .Err () != nil {
86
+ return ctx .Err ()
87
+ }
109
88
b := make ([]byte , 32 * 1024 )
110
89
n , readErr := src .Read (b )
111
90
if n > 0 {
@@ -133,8 +112,11 @@ func (pc *proxyConnection) sendMessage(mt int, data []byte) error {
133
112
return conn .WriteMessage (mt , data )
134
113
}
135
114
136
- func (pc * proxyConnection ) RunReceivingLoop (ctx context.Context , dst io.Writer ) error {
115
+ func (pc * proxyConnection ) runReceivingLoop (ctx context.Context , dst io.Writer ) error {
137
116
for {
117
+ if ctx .Err () != nil {
118
+ return ctx .Err ()
119
+ }
138
120
conn := pc .conn .Load ().(* websocket.Conn )
139
121
mt , data , err := conn .ReadMessage ()
140
122
if err != nil {
@@ -205,7 +187,7 @@ func (pc *proxyConnection) Close() error {
205
187
return nil
206
188
}
207
189
208
- func (pc * proxyConnection ) InitiateHandover (ctx context.Context , client * databricks. WorkspaceClient , clusterID string , serverPort int ) error {
190
+ func (pc * proxyConnection ) InitiateHandover (ctx context.Context ) error {
209
191
// Blocks proxying any outgoing messages during the entire handover
210
192
pc .handoverMutex .Lock ()
211
193
defer pc .handoverMutex .Unlock ()
@@ -220,7 +202,7 @@ func (pc *proxyConnection) InitiateHandover(ctx context.Context, client *databri
220
202
221
203
// Create a new websocket connection by sending an /ssh?id=<connID> request to the server.
222
204
// When server realises it's an ID of an existing connection, it will start AcceptHandover process.
223
- newConn , err := pc .createWebsocketConnection (ctx , client , clusterID , serverPort )
205
+ newConn , err := pc .createWebsocketConnection (ctx , pc . connID )
224
206
if err != nil {
225
207
return fmt .Errorf ("failed to create new websocket connection: %w" , err )
226
208
}
@@ -287,3 +269,7 @@ func (pc *proxyConnection) AcceptHandover(ctx context.Context, w http.ResponseWr
287
269
288
270
return nil
289
271
}
272
+
273
+ func IsNormalClosure (err error ) bool {
274
+ return websocket .IsCloseError (err , websocket .CloseNormalClosure ) || errors .Is (err , errProxyEOF )
275
+ }
0 commit comments