Commit 8be1123

Add ssh server command
1 parent 2030553 commit 8be1123

10 files changed: +865 −2 lines changed

cmd/cmd.go

Lines changed: 2 additions & 0 deletions

@@ -5,6 +5,7 @@ import (
     "strings"

     "github.com/databricks/cli/cmd/psql"
+    "github.com/databricks/cli/cmd/ssh"

     "github.com/databricks/cli/cmd/account"
     "github.com/databricks/cli/cmd/api"
@@ -77,6 +78,7 @@ func New(ctx context.Context) *cobra.Command {
     cli.AddCommand(version.New())
     cli.AddCommand(selftest.New())
     cli.AddCommand(pipelines.InstallPipelinesCLI())
+    cli.AddCommand(ssh.New())

     // Add workspace command groups, filtering out empty groups or groups with only hidden commands.
     allGroups := workspace.Groups()

cmd/ssh/server.go

Lines changed: 60 additions & 0 deletions

@@ -0,0 +1,60 @@
+package ssh
+
+import (
+    "time"
+
+    "github.com/databricks/cli/cmd/root"
+    "github.com/databricks/cli/libs/cmdctx"
+    "github.com/databricks/cli/libs/ssh"
+    "github.com/spf13/cobra"
+)
+
+func newServerCommand() *cobra.Command {
+    cmd := &cobra.Command{
+        Use:   "server",
+        Short: "Run SSH tunnel server",
+        Long: `Run SSH tunnel server.
+
+This command starts an SSH tunnel server that accepts WebSocket connections
+and proxies them to local SSH daemon processes.
+
+` + disclaimer,
+        // This is an internal command spawned by the SSH client running the "ssh-server-bootstrap.py" job
+        Hidden: true,
+    }
+
+    var maxClients int
+    var shutdownDelay time.Duration
+    var clusterID string
+    var version string
+
+    cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID")
+    cmd.MarkFlagRequired("cluster")
+    cmd.Flags().IntVar(&maxClients, "max-clients", 10, "Maximum number of SSH clients")
+    cmd.Flags().DurationVar(&shutdownDelay, "shutdown-delay", 10*time.Minute, "Delay before shutting down after no pings from clients")
+    cmd.Flags().StringVar(&version, "version", "", "Client version of the Databricks CLI")
+
+    cmd.PreRunE = root.MustWorkspaceClient
+    cmd.RunE = func(cmd *cobra.Command, args []string) error {
+        ctx := cmd.Context()
+        client := cmdctx.WorkspaceClient(ctx)
+        opts := ssh.ServerOptions{
+            ClusterID:            clusterID,
+            MaxClients:           maxClients,
+            ShutdownDelay:        shutdownDelay,
+            Version:              version,
+            ConfigDir:            ".ssh-tunnel",
+            ServerPrivateKeyName: "server-private-key",
+            ServerPublicKeyName:  "server-public-key",
+            DefaultPort:          7772,
+            PortRange:            100,
+        }
+        err := ssh.RunServer(ctx, client, opts)
+        if err != nil && ssh.IsNormalClosure(err) {
+            return nil
+        }
+        return err
+    }
+
+    return cmd
+}
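
Note: the ssh.ServerOptions type lives in libs/ssh and its definition is not among the hunks shown here. A minimal sketch of its likely shape, inferred only from the fields set above — the field names and types come from this call site, while the comments describe assumed semantics:

package ssh

import "time"

// ServerOptions (sketch, inferred from newServerCommand; not the actual libs/ssh definition).
type ServerOptions struct {
    ClusterID            string        // target Databricks cluster ID
    MaxClients           int           // maximum number of concurrent SSH clients
    ShutdownDelay        time.Duration // idle time without client pings before shutdown (assumed)
    Version              string        // client version of the Databricks CLI
    ConfigDir            string        // directory for tunnel state, e.g. ".ssh-tunnel" (assumed)
    ServerPrivateKeyName string        // secret name holding the server's private key
    ServerPublicKeyName  string        // secret name holding the server's public key
    DefaultPort          int           // first port the server tries to listen on (assumed)
    PortRange            int           // number of ports to probe above DefaultPort (assumed)
}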

cmd/ssh/ssh.go

Lines changed: 1 addition & 0 deletions

@@ -29,6 +29,7 @@ Common workflows:

     cmd.AddCommand(newSetupCommand())
     cmd.AddCommand(newConnectCommand())
+    cmd.AddCommand(newServerCommand())

     return cmd
 }

libs/ssh/jupyter-init.py

Lines changed: 185 additions & 0 deletions

@@ -0,0 +1,185 @@
+from typing import List, Optional
+from IPython.core.getipython import get_ipython
+from IPython.display import display as ip_display
+from dbruntime import UserNamespaceInitializer
+
+
+def _log_exceptions(func):
+    from functools import wraps
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        try:
+            print(f"Executing {func.__name__}")
+            return func(*args, **kwargs)
+        except Exception as e:
+            print(f"Error in {func.__name__}: {e}")
+
+    return wrapper
+
+
+_user_namespace_initializer = UserNamespaceInitializer.getOrCreate()
+_entry_point = _user_namespace_initializer.get_spark_entry_point()
+_globals = _user_namespace_initializer.get_namespace_globals()
+for name, value in _globals.items():
+    print(f"Registering global: {name} = {value}")
+    if name not in globals():
+        globals()[name] = value
+
+
+# 'display' from the runtime uses custom widgets that don't work in Jupyter.
+# We use the IPython display instead (in combination with the html formatter for DataFrames).
+globals()["display"] = ip_display
+
+
+@_log_exceptions
+def _register_runtime_hooks():
+    from dbruntime.monkey_patches import apply_dataframe_display_patch
+    from dbruntime.IPythonShellHooks import load_ipython_hooks, IPythonShellHook
+    from IPython.core.interactiveshell import ExecutionInfo
+
+    # Setting executing_raw_cell before cell execution is required to make dbutils.library.restartPython() work
+    class PreRunHook(IPythonShellHook):
+        def pre_run_cell(self, info: ExecutionInfo) -> None:
+            get_ipython().executing_raw_cell = info.raw_cell
+
+    load_ipython_hooks(get_ipython(), PreRunHook())
+    apply_dataframe_display_patch(ip_display)
+
+
+def _warn_for_dbr_alternative(magic: str):
+    import warnings
+
+    """Warn users about magics that have Databricks alternatives."""
+    local_magic_dbr_alternative = {"%%sh": "%sh"}
+    if magic in local_magic_dbr_alternative:
+        warnings.warn(
+            f"\\n{magic} is not supported on Databricks. This notebook might fail when running on a Databricks cluster.\\n"
+            f"Consider using %{local_magic_dbr_alternative[magic]} instead."
+        )
+
+
+def _throw_if_not_supported(magic: str):
+    """Throw an error for magics that are not supported locally."""
+    unsupported_dbr_magics = ["%r", "%scala"]
+    if magic in unsupported_dbr_magics:
+        raise NotImplementedError(f"{magic} is not supported for local Databricks Notebooks.")
+
+
+def _get_cell_magic(lines: List[str]) -> Optional[str]:
+    """Extract cell magic from the first line if it exists."""
+    if len(lines) == 0:
+        return None
+    if lines[0].strip().startswith("%%"):
+        return lines[0].split(" ")[0].strip()
+    return None
+
+
+def _get_line_magic(lines: List[str]) -> Optional[str]:
+    """Extract line magic from the first line if it exists."""
+    if len(lines) == 0:
+        return None
+    if lines[0].strip().startswith("%"):
+        return lines[0].split(" ")[0].strip().strip("%")
+    return None
+
+
+def _handle_cell_magic(lines: List[str]) -> List[str]:
+    """Process cell magic commands."""
+    cell_magic = _get_cell_magic(lines)
+    if cell_magic is None:
+        return lines
+
+    _warn_for_dbr_alternative(cell_magic)
+    _throw_if_not_supported(cell_magic)
+    return lines
+
+
+def _handle_line_magic(lines: List[str]) -> List[str]:
+    """Process line magic commands and transform them appropriately."""
+    lmagic = _get_line_magic(lines)
+    if lmagic is None:
+        return lines
+
+    _warn_for_dbr_alternative(lmagic)
+    _throw_if_not_supported(lmagic)
+
+    if lmagic in ["md", "md-sandbox"]:
+        lines[0] = "%%markdown" + lines[0].partition("%" + lmagic)[2]
+        return lines
+
+    if lmagic == "sh":
+        lines[0] = "%%sh" + lines[0].partition("%" + lmagic)[2]
+        return lines
+
+    if lmagic == "sql":
+        lines = lines[1:]
+        spark_string = "global _sqldf\n" + "_sqldf = spark.sql('''" + "".join(lines).replace("'", "\\'") + "''')\n" + "display(_sqldf)\n"
+        return spark_string.splitlines(keepends=True)
+
+    if lmagic == "python":
+        return lines[1:]
+
+    return lines
+
+
+def _strip_hash_magic(lines: List[str]) -> List[str]:
+    if len(lines) == 0:
+        return lines
+    if lines[0].startswith("# MAGIC"):
+        return [line.partition("# MAGIC ")[2] for line in lines]
+    return lines
+
+
+def _parse_line_for_databricks_magics(lines: List[str]) -> List[str]:
+    """Main parser function for Databricks magic commands."""
+    if len(lines) == 0:
+        return lines
+
+    lines_to_ignore = ("# Databricks notebook source", "# COMMAND ----------", "# DBTITLE")
+    lines = [line for line in lines if not line.strip().startswith(lines_to_ignore)]
+    lines = "".join(lines).strip().splitlines(keepends=True)
+    lines = _strip_hash_magic(lines)
+
+    if _get_cell_magic(lines):
+        return _handle_cell_magic(lines)
+
+    if _get_line_magic(lines):
+        return _handle_line_magic(lines)
+
+    return lines
+
+
+@_log_exceptions
+def _register_magics():
+    """Register the magic command parser with IPython."""
+    from dbruntime.DatasetInfo import UserNamespaceDict
+    from dbruntime.PipMagicOverrides import PipMagicOverrides
+
+    user_ns = UserNamespaceDict(
+        _user_namespace_initializer.get_namespace_globals(),
+        _entry_point.getDriverConf(),
+        _entry_point,
+    )
+    ip = get_ipython()
+    ip.input_transformers_cleanup.append(_parse_line_for_databricks_magics)
+    ip.register_magics(PipMagicOverrides(_entry_point, _globals["sc"]._conf, user_ns))
+
+
+@_log_exceptions
+def _register_formatters():
+    from pyspark.sql import DataFrame
+    from pyspark.sql.connect.dataframe import DataFrame as SparkConnectDataframe
+
+    def df_html(df: DataFrame) -> str:
+        return df.toPandas().to_html()
+
+    ip = get_ipython()
+    html_formatter = ip.display_formatter.formatters["text/html"]
+    html_formatter.for_type(SparkConnectDataframe, df_html)
+    html_formatter.for_type(DataFrame, df_html)
+
+
+_register_magics()
+_register_formatters()
+_register_runtime_hooks()

libs/ssh/keys.go

Lines changed: 39 additions & 0 deletions

@@ -13,6 +13,7 @@ import (
     "strings"

     "github.com/databricks/cli/libs/cmdio"
+    "github.com/databricks/databricks-sdk-go"
     "golang.org/x/crypto/ssh"
 )

@@ -128,3 +129,41 @@ func checkAndGenerateSSHKeyPair(ctx context.Context, keyPath string) (string, st

     return keyPath, strings.TrimSpace(string(publicKeyBytes)), nil
 }
+
+func checkAndGenerateSSHKeyPairFromSecrets(ctx context.Context, client *databricks.WorkspaceClient, clusterID, privateKeyName, publicKeyName string) ([]byte, []byte, error) {
+    secretsScopeName, err := createSecretsScope(ctx, client, clusterID)
+    if err != nil {
+        return nil, nil, fmt.Errorf("failed to create secrets scope: %w", err)
+    }
+
+    privateKeyBytes, err := getSecret(ctx, client, secretsScopeName, privateKeyName)
+    if err != nil {
+        cmdio.LogString(ctx, "SSH key pair not found in secrets scope, generating a new one...")
+
+        privateKeyBytes, publicKeyBytes, err := generateSSHKeyPair()
+        if err != nil {
+            return nil, nil, fmt.Errorf("failed to generate SSH key pair: %w", err)
+        }
+
+        err = putSecret(ctx, client, secretsScopeName, privateKeyName, string(privateKeyBytes))
+        if err != nil {
+            return nil, nil, err
+        }
+
+        err = putSecret(ctx, client, secretsScopeName, publicKeyName, string(publicKeyBytes))
+        if err != nil {
+            return nil, nil, err
+        }
+
+        return privateKeyBytes, publicKeyBytes, nil
+    } else {
+        cmdio.LogString(ctx, "Using SSH key pair from secrets scope")
+
+        publicKeyBytes, err := getSecret(ctx, client, secretsScopeName, publicKeyName)
+        if err != nil {
+            return nil, nil, fmt.Errorf("failed to get public key from secrets scope: %w", err)
+        }
+
+        return privateKeyBytes, publicKeyBytes, nil
+    }
+}
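
RunServer itself is not among the hunks shown, so the exact call site of this helper is not visible. A hypothetical sketch of how the server path in libs/ssh might wire the key names configured in server.go into it — the wrapper name and placement are assumptions; only the helper's signature and the ServerOptions field names come from this diff:

// Hypothetical wiring inside package ssh (not part of this commit's visible hunks).
// RunServer presumably resolves the host key pair from the cluster's secrets scope
// before starting the WebSocket listener.
func loadServerHostKeys(ctx context.Context, client *databricks.WorkspaceClient, opts ServerOptions) (privateKey, publicKey []byte, err error) {
    return checkAndGenerateSSHKeyPairFromSecrets(ctx, client, opts.ClusterID, opts.ServerPrivateKeyName, opts.ServerPublicKeyName)
}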

libs/ssh/proxy.go

Lines changed: 1 addition & 1 deletion

@@ -180,7 +180,7 @@ func (pc *proxyConnection) Close() error {
     // Keep in mind that pc.sendMessage blocks during handover
     err := pc.sendMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
     if err != nil {
-        if websocket.IsCloseError(err, websocket.CloseNormalClosure) {
+        if IsNormalClosure(err) {
             return nil
         } else {
             return fmt.Errorf("failed to send close message: %w", err)
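
IsNormalClosure is the new helper referenced here and in cmd/ssh/server.go, but its definition is not among the hunks shown. Judging from the expression it replaces, it is presumably a thin wrapper along these lines — a sketch, not the actual implementation, with the import path assumed to be gorilla/websocket based on the API names used in this file:

package ssh

import "github.com/gorilla/websocket"

// IsNormalClosure (sketch): reports whether err is a WebSocket close error
// carrying the "normal closure" (1000) status code.
func IsNormalClosure(err error) bool {
    return websocket.IsCloseError(err, websocket.CloseNormalClosure)
}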

libs/ssh/secrets.go

Lines changed: 17 additions & 0 deletions

@@ -2,6 +2,7 @@ package ssh

 import (
     "context"
+    "encoding/base64"
     "errors"
     "fmt"

@@ -24,6 +25,22 @@ func createSecretsScope(ctx context.Context, client *databricks.WorkspaceClient,
     return secretsScope, nil
 }

+func getSecret(ctx context.Context, client *databricks.WorkspaceClient, scope, key string) ([]byte, error) {
+    resp, err := client.Secrets.GetSecret(ctx, workspace.GetSecretRequest{
+        Scope: scope,
+        Key:   key,
+    })
+    if err != nil {
+        return nil, fmt.Errorf("failed to get secret %s from scope %s: %w", key, scope, err)
+    }
+
+    value, err := base64.StdEncoding.DecodeString(resp.Value)
+    if err != nil {
+        return nil, fmt.Errorf("failed to decode secret key from base64: %w", err)
+    }
+    return value, nil
+}
+
 func putSecret(ctx context.Context, client *databricks.WorkspaceClient, scope, key, value string) error {
     err := client.Secrets.PutSecret(ctx, workspace.PutSecret{
         Scope: scope,