Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"strings"

"github.com/databricks/cli/cmd/psql"
"github.com/databricks/cli/cmd/ssh"

"github.com/databricks/cli/cmd/account"
"github.com/databricks/cli/cmd/api"
Expand Down Expand Up @@ -77,6 +78,7 @@ func New(ctx context.Context) *cobra.Command {
cli.AddCommand(version.New())
cli.AddCommand(selftest.New())
cli.AddCommand(pipelines.InstallPipelinesCLI())
cli.AddCommand(ssh.New())

// Add workspace command groups, filtering out empty groups or groups with only hidden commands.
allGroups := workspace.Groups()
Expand Down
60 changes: 60 additions & 0 deletions cmd/ssh/server.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package ssh

import (
"time"

"github.com/databricks/cli/cmd/root"
"github.com/databricks/cli/libs/cmdctx"
"github.com/databricks/cli/libs/ssh"
"github.com/spf13/cobra"
)

func newServerCommand() *cobra.Command {
cmd := &cobra.Command{
Use: "server",
Short: "Run SSH tunnel server",
Long: `Run SSH tunnel server.
This command starts an SSH tunnel server that accepts WebSocket connections
and proxies them to local SSH daemon processes.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This command is run only on the backend, correct? Users should never interact with this directly.

If so, perhaps we can check for a magic env var to be set. That prevents users who accidentally (or out of curiosity) run this from modifying a bunch of files in their home directory.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The check sort of exists right now, as we fail if PUBLIC_SSH_KEY is not set. But I just realized that we don't need to set this env var and can just pass the secret name to the go server - it can download the key itself (doing it now in the follow up PR, to avoid rebase conflicts). After that change it still won't be easy to run the server by mistake, since you'll need to read the code to understand what secret name to use (and make sure it actually exists).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ack. The failure happens somewhere deep in the stack. It would be useful to make this error more explicit. If a user runs it and gets some error about a secret missing, that's confusing. If instead they see "This command is supposed to be run on Databricks compute. Please find usage instructions at <>.", that's better.

` + disclaimer,
// This is an internal command spawned by the SSH client running the "ssh-server-bootstrap.py" job
Hidden: true,
}

var maxClients int
var shutdownDelay time.Duration
var clusterID string
var version string

cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID")
cmd.MarkFlagRequired("cluster")
cmd.Flags().IntVar(&maxClients, "max-clients", 10, "Maximum number of SSH clients")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does this need to be configured/configurable?

Copy link
Contributor Author

@ilia-db ilia-db Sep 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not strictly necessary, just didn't want to hard code some random number without a way to change it

cmd.Flags().DurationVar(&shutdownDelay, "shutdown-delay", 10*time.Minute, "Delay before shutting down after no pings from clients")
cmd.Flags().StringVar(&version, "version", "", "Client version of the Databricks CLI")

cmd.PreRunE = root.MustWorkspaceClient
// RunE assembles the server options from the parsed flags plus the fixed
// configuration values below and runs the tunnel server with the workspace
// client stored on the command context.
cmd.RunE = func(cmd *cobra.Command, args []string) error {
	ctx := cmd.Context()
	serverOpts := ssh.ServerOptions{
		ClusterID:            clusterID,
		MaxClients:           maxClients,
		ShutdownDelay:        shutdownDelay,
		Version:              version,
		ConfigDir:            ".ssh-tunnel",
		ServerPrivateKeyName: "server-private-key",
		ServerPublicKeyName:  "server-public-key",
		DefaultPort:          7772,
		PortRange:            100,
	}

	runErr := ssh.RunServer(ctx, cmdctx.WorkspaceClient(ctx), serverOpts)
	// An error that ssh.IsNormalClosure recognizes is not reported as a
	// command failure.
	if runErr == nil || ssh.IsNormalClosure(runErr) {
		return nil
	}
	return runErr
}

return cmd
}
1 change: 1 addition & 0 deletions cmd/ssh/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ Common workflows:

cmd.AddCommand(newSetupCommand())
cmd.AddCommand(newConnectCommand())
cmd.AddCommand(newServerCommand())

return cmd
}
185 changes: 185 additions & 0 deletions libs/ssh/jupyter-init.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
from typing import List, Optional
from IPython.core.getipython import get_ipython
from IPython.display import display as ip_display
from dbruntime import UserNamespaceInitializer


def _log_exceptions(func):
from functools import wraps

@wraps(func)
def wrapper(*args, **kwargs):
try:
print(f"Executing {func.__name__}")
return func(*args, **kwargs)
except Exception as e:
print(f"Error in {func.__name__}: {e}")

return wrapper


# Module-level bootstrap: obtain the runtime's namespace initializer, its Spark
# entry point and the globals it provides. These three names are reused by the
# registration functions further down in this file.
_user_namespace_initializer = UserNamespaceInitializer.getOrCreate()
_entry_point = _user_namespace_initializer.get_spark_entry_point()
_globals = _user_namespace_initializer.get_namespace_globals()

# Copy every runtime-provided global into this module's namespace.
# The loop variables keep the names `name` and `value` on purpose: they become
# module globals themselves and therefore take part in the "already defined" check.
for name, value in _globals.items():
    print(f"Registering global: {name} = {value}")
    # setdefault only stores the value when the name is not defined yet.
    globals().setdefault(name, value)


# 'display' from the runtime uses custom widgets that don't work in Jupyter.
# The IPython display is used instead (in combination with the html formatter for DataFrames),
# and it is assigned unconditionally so it replaces any 'display' registered by the loop above.
globals()["display"] = ip_display


@_log_exceptions
def _register_runtime_hooks():
    """Load the runtime's IPython shell hooks and apply the DataFrame display patch."""
    from dbruntime.monkey_patches import apply_dataframe_display_patch
    from dbruntime.IPythonShellHooks import load_ipython_hooks, IPythonShellHook
    from IPython.core.interactiveshell import ExecutionInfo

    class PreRunHook(IPythonShellHook):
        """Records the raw source of the cell that is about to run on the shell."""

        # dbutils.library.restartPython() only works if executing_raw_cell
        # has been set before the cell is executed.
        def pre_run_cell(self, info: ExecutionInfo) -> None:
            shell = get_ipython()
            shell.executing_raw_cell = info.raw_cell

    active_shell = get_ipython()
    hook = PreRunHook()
    load_ipython_hooks(active_shell, hook)
    apply_dataframe_display_patch(ip_display)


def _warn_for_dbr_alternative(magic: str):
import warnings

"""Warn users about magics that have Databricks alternatives."""
local_magic_dbr_alternative = {"%%sh": "%sh"}
if magic in local_magic_dbr_alternative:
warnings.warn(
f"\\n{magic} is not supported on Databricks. This notebook might fail when running on a Databricks cluster.\\n"
f"Consider using %{local_magic_dbr_alternative[magic]} instead."
)


def _throw_if_not_supported(magic: str):
"""Throw an error for magics that are not supported locally."""
unsupported_dbr_magics = ["%r", "%scala"]
if magic in unsupported_dbr_magics:
raise NotImplementedError(f"{magic} is not supported for local Databricks Notebooks.")


def _get_cell_magic(lines: List[str]) -> Optional[str]:
"""Extract cell magic from the first line if it exists."""
if len(lines) == 0:
return None
if lines[0].strip().startswith("%%"):
return lines[0].split(" ")[0].strip()
return None


def _get_line_magic(lines: List[str]) -> Optional[str]:
"""Extract line magic from the first line if it exists."""
if len(lines) == 0:
return None
if lines[0].strip().startswith("%"):
return lines[0].split(" ")[0].strip().strip("%")
return None


def _handle_cell_magic(lines: List[str]) -> List[str]:
    """Run the cell-magic checks for a cell and hand its lines back untouched.

    When the first line carries a cell magic, a warning is issued if the magic
    has a Databricks alternative and an error is raised if it is unsupported.
    The lines themselves are never modified.
    """
    magic = _get_cell_magic(lines)
    if magic is not None:
        _warn_for_dbr_alternative(magic)
        _throw_if_not_supported(magic)
    return lines


def _handle_line_magic(lines: List[str]) -> List[str]:
    """Process line magic commands and transform them appropriately.

    Args:
        lines: The lines of the cell; the first line carries the magic.

    Returns:
        The lines to execute instead:
        - ``%md`` / ``%md-sandbox``: first line rewritten to ``%%markdown`` (in place).
        - ``%sh``: first line rewritten to ``%%sh`` (in place).
        - ``%sql``: Python source that runs the query through ``spark.sql``,
          stores the result in the global ``_sqldf`` and displays it.
        - ``%python``: the cell without the magic.
        - anything else: ``lines`` unchanged.

    Raises:
        NotImplementedError: If the magic is not supported locally.
    """
    lmagic = _get_line_magic(lines)
    if lmagic is None:
        return lines

    # _get_line_magic strips the percent sign, but the checks below are keyed by
    # the full spelling (e.g. "%r"), so put it back before consulting them.
    spelled_magic = "%" + lmagic
    _warn_for_dbr_alternative(spelled_magic)
    _throw_if_not_supported(spelled_magic)

    # Whatever follows the magic on its own line, e.g. " select 1\n" for "%sql select 1\n".
    inline_rest = lines[0].partition(spelled_magic)[2]

    if lmagic in ["md", "md-sandbox"]:
        lines[0] = "%%markdown" + inline_rest
        return lines

    if lmagic == "sh":
        lines[0] = "%%sh" + inline_rest
        return lines

    if lmagic == "sql":
        # Keep a query written on the magic line itself ("%sql select 1").
        query = inline_rest.lstrip() + "".join(lines[1:])
        # The query is embedded in a ''' literal: escape backslashes first so the
        # SQL text reaches spark.sql unchanged, then the quotes.
        escaped_query = query.replace("\\", "\\\\").replace("'", "\\'")
        spark_string = "global _sqldf\n" + "_sqldf = spark.sql('''" + escaped_query + "''')\n" + "display(_sqldf)\n"
        return spark_string.splitlines(keepends=True)

    if lmagic == "python":
        # Keep code written on the magic line itself ("%python print(1)").
        inline_code = inline_rest.lstrip()
        if inline_code:
            return [inline_code] + lines[1:]
        return lines[1:]

    return lines


def _strip_hash_magic(lines: List[str]) -> List[str]:
if len(lines) == 0:
return lines
if lines[0].startswith("# MAGIC"):
return [line.partition("# MAGIC ")[2] for line in lines]
return lines


def _parse_line_for_databricks_magics(lines: List[str]) -> List[str]:
    """Translate one cell of Databricks notebook source into lines IPython can run.

    Notebook-source marker lines are dropped, surrounding whitespace of the cell
    is trimmed, ``# MAGIC`` prefixes are removed and the result is dispatched to
    the cell-magic or line-magic handler when the first line carries a magic.
    """
    if len(lines) == 0:
        return lines

    # Lines that only exist in the exported notebook source format.
    source_markers = ("# Databricks notebook source", "# COMMAND ----------", "# DBTITLE")
    kept = []
    for line in lines:
        if line.strip().startswith(source_markers):
            continue
        kept.append(line)

    # Trim the cell as a whole, then split it back into lines with their endings.
    cell_text = "".join(kept).strip()
    cell_lines = _strip_hash_magic(cell_text.splitlines(keepends=True))

    if _get_cell_magic(cell_lines):
        return _handle_cell_magic(cell_lines)
    if _get_line_magic(cell_lines):
        return _handle_line_magic(cell_lines)
    return cell_lines


@_log_exceptions
def _register_magics():
    """Hook the Databricks magic parser and the pip magic overrides into IPython."""
    from dbruntime.DatasetInfo import UserNamespaceDict
    from dbruntime.PipMagicOverrides import PipMagicOverrides

    namespace = UserNamespaceDict(
        _user_namespace_initializer.get_namespace_globals(),
        _entry_point.getDriverConf(),
        _entry_point,
    )

    shell = get_ipython()
    # The parser is appended to the shell's cleanup input transformers.
    shell.input_transformers_cleanup.append(_parse_line_for_databricks_magics)

    spark_conf = _globals["sc"]._conf
    pip_magics = PipMagicOverrides(_entry_point, spark_conf, namespace)
    shell.register_magics(pip_magics)


@_log_exceptions
def _register_formatters():
    """Register an HTML formatter for Spark DataFrames with IPython."""
    from pyspark.sql import DataFrame
    from pyspark.sql.connect.dataframe import DataFrame as SparkConnectDataframe

    def df_html(df: DataFrame) -> str:
        # Convert to pandas and use its HTML table rendering.
        pandas_df = df.toPandas()
        return pandas_df.to_html()

    shell = get_ipython()
    html_formatter = shell.display_formatter.formatters["text/html"]
    # Same renderer for the Spark Connect and the classic DataFrame class.
    for dataframe_type in (SparkConnectDataframe, DataFrame):
        html_formatter.for_type(dataframe_type, df_html)


# Run the registration steps at import time. Each function is wrapped by
# _log_exceptions, so a failure in one step is printed and does not prevent
# the following steps from running.
_register_magics()
_register_formatters()
_register_runtime_hooks()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's move the Python code embedded in the binary to a subpackage.

A README there can describe the purpose of the files and a Go file can embed them.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the follow up PR I've moved the python logic into two separate packages (client and server), but they still live together with go files. I'll improve the situation in another follow up (as right now IDE integration is freaking out while editing these files)

39 changes: 39 additions & 0 deletions libs/ssh/keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"strings"

"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/databricks-sdk-go"
"golang.org/x/crypto/ssh"
)

Expand Down Expand Up @@ -128,3 +129,41 @@ func checkAndGenerateSSHKeyPair(ctx context.Context, keyPath string) (string, st

return keyPath, strings.TrimSpace(string(publicKeyBytes)), nil
}

func checkAndGenerateSSHKeyPairFromSecrets(ctx context.Context, client *databricks.WorkspaceClient, clusterID, privateKeyName, publicKeyName string) ([]byte, []byte, error) {
secretsScopeName, err := createSecretsScope(ctx, client, clusterID)
if err != nil {
return nil, nil, fmt.Errorf("failed to create secrets scope: %w", err)
}

privateKeyBytes, err := getSecret(ctx, client, secretsScopeName, privateKeyName)
if err != nil {
cmdio.LogString(ctx, "SSH key pair not found in secrets scope, generating a new one...")

privateKeyBytes, publicKeyBytes, err := generateSSHKeyPair()
if err != nil {
return nil, nil, fmt.Errorf("failed to generate SSH key pair: %w", err)
}

err = putSecret(ctx, client, secretsScopeName, privateKeyName, string(privateKeyBytes))
if err != nil {
return nil, nil, err
}

err = putSecret(ctx, client, secretsScopeName, publicKeyName, string(publicKeyBytes))
if err != nil {
return nil, nil, err
}

return privateKeyBytes, publicKeyBytes, nil
} else {
cmdio.LogString(ctx, "Using SSH key pair from secrets scope")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this relevant to print when it is nominal operations?

You can consider using libs/log to emit a debug log message instead. Those are shown in a friendly format when the CLI is run with --debug or --log-level debug. You can use the trace level for higher verbosity.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, moved everything in the server package to libs/log. I'm using stdout as log-file for now, as I don't want to add log rotation right now. The output is accessible through jobs api (for 60 days), so server troubleshooting is still possible


publicKeyBytes, err := getSecret(ctx, client, secretsScopeName, publicKeyName)
if err != nil {
return nil, nil, fmt.Errorf("failed to get public key from secrets scope: %w", err)
}

return privateKeyBytes, publicKeyBytes, nil
}
}
2 changes: 1 addition & 1 deletion libs/ssh/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ func (pc *proxyConnection) Close() error {
// Keep in mind that pc.sendMessage blocks during handover
err := pc.sendMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
if err != nil {
if websocket.IsCloseError(err, websocket.CloseNormalClosure) {
if IsNormalClosure(err) {
return nil
} else {
return fmt.Errorf("failed to send close message: %w", err)
Expand Down
17 changes: 17 additions & 0 deletions libs/ssh/secrets.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package ssh

import (
"context"
"encoding/base64"
"errors"
"fmt"

Expand All @@ -24,6 +25,22 @@ func createSecretsScope(ctx context.Context, client *databricks.WorkspaceClient,
return secretsScope, nil
}

// getSecret fetches the secret stored under key in the given scope and returns
// its value after base64-decoding the value carried in the API response.
// An error is returned when the lookup fails or the value is not valid base64.
func getSecret(ctx context.Context, client *databricks.WorkspaceClient, scope, key string) ([]byte, error) {
	request := workspace.GetSecretRequest{Scope: scope, Key: key}

	secret, err := client.Secrets.GetSecret(ctx, request)
	if err != nil {
		return nil, fmt.Errorf("failed to get secret %s from scope %s: %w", key, scope, err)
	}

	decoded, decodeErr := base64.StdEncoding.DecodeString(secret.Value)
	if decodeErr != nil {
		return nil, fmt.Errorf("failed to decode secret key from base64: %w", decodeErr)
	}
	return decoded, nil
}

func putSecret(ctx context.Context, client *databricks.WorkspaceClient, scope, key, value string) error {
err := client.Secrets.PutSecret(ctx, workspace.PutSecret{
Scope: scope,
Expand Down
Loading
Loading