Add "ssh server" command #3475
@@ -0,0 +1,60 @@

```go
package ssh

import (
    "time"

    "github.com/databricks/cli/cmd/root"
    "github.com/databricks/cli/libs/cmdctx"
    "github.com/databricks/cli/libs/ssh"
    "github.com/spf13/cobra"
)

func newServerCommand() *cobra.Command {
    cmd := &cobra.Command{
        Use:   "server",
        Short: "Run SSH tunnel server",
        Long: `Run SSH tunnel server.
This command starts an SSH tunnel server that accepts WebSocket connections
and proxies them to local SSH daemon processes.
` + disclaimer,
        // This is an internal command spawned by the SSH client running the "ssh-server-bootstrap.py" job.
        Hidden: true,
    }

    var maxClients int
    var shutdownDelay time.Duration
    var clusterID string
    var version string

    cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID")
    cmd.MarkFlagRequired("cluster")
    cmd.Flags().IntVar(&maxClients, "max-clients", 10, "Maximum number of SSH clients")
```
Review comment: Why does this need to be configured/configurable?

Reply: Not strictly necessary; I just didn't want to hard-code some random number without a way to change it.
```go
    cmd.Flags().DurationVar(&shutdownDelay, "shutdown-delay", 10*time.Minute, "Delay before shutting down after no pings from clients")
    cmd.Flags().StringVar(&version, "version", "", "Client version of the Databricks CLI")

    cmd.PreRunE = root.MustWorkspaceClient
    cmd.RunE = func(cmd *cobra.Command, args []string) error {
        ctx := cmd.Context()
        client := cmdctx.WorkspaceClient(ctx)
        opts := ssh.ServerOptions{
            ClusterID:            clusterID,
            MaxClients:           maxClients,
            ShutdownDelay:        shutdownDelay,
            Version:              version,
            ConfigDir:            ".ssh-tunnel",
            ServerPrivateKeyName: "server-private-key",
            ServerPublicKeyName:  "server-public-key",
            DefaultPort:          7772,
            PortRange:            100,
        }
        err := ssh.RunServer(ctx, client, opts)
        if err != nil && ssh.IsNormalClosure(err) {
            return nil
        }
        return err
    }

    return cmd
}
```
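The RunE handler above treats a "normal closure" error from ssh.RunServer as a clean shutdown rather than a failure. The actual helper lives in libs/ssh and is not part of this diff; a minimal sketch of the idea, assuming the tunnel is built on gorilla/websocket (an assumption, not something this PR confirms), could look like:

```go
package ssh

import (
    "errors"
    "io"

    "github.com/gorilla/websocket"
)

// IsNormalClosure reports whether err indicates an orderly shutdown of the
// tunnel (the peer sent a normal or going-away close frame, or the stream
// simply hit EOF) rather than a real failure. Sketch only; the real
// implementation in libs/ssh may differ.
func IsNormalClosure(err error) bool {
    return errors.Is(err, io.EOF) ||
        websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway)
}
```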
@@ -0,0 +1,185 @@
```python
from typing import List, Optional
from IPython.core.getipython import get_ipython
from IPython.display import display as ip_display
from dbruntime import UserNamespaceInitializer


def _log_exceptions(func):
    from functools import wraps

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            print(f"Executing {func.__name__}")
            return func(*args, **kwargs)
        except Exception as e:
            print(f"Error in {func.__name__}: {e}")

    return wrapper


_user_namespace_initializer = UserNamespaceInitializer.getOrCreate()
_entry_point = _user_namespace_initializer.get_spark_entry_point()
_globals = _user_namespace_initializer.get_namespace_globals()
for name, value in _globals.items():
    print(f"Registering global: {name} = {value}")
    if name not in globals():
        globals()[name] = value


# 'display' from the runtime uses custom widgets that don't work in Jupyter.
# We use the IPython display instead (in combination with the html formatter for DataFrames).
globals()["display"] = ip_display


@_log_exceptions
def _register_runtime_hooks():
    from dbruntime.monkey_patches import apply_dataframe_display_patch
    from dbruntime.IPythonShellHooks import load_ipython_hooks, IPythonShellHook
    from IPython.core.interactiveshell import ExecutionInfo

    # Setting executing_raw_cell before cell execution is required to make dbutils.library.restartPython() work
    class PreRunHook(IPythonShellHook):
        def pre_run_cell(self, info: ExecutionInfo) -> None:
            get_ipython().executing_raw_cell = info.raw_cell

    load_ipython_hooks(get_ipython(), PreRunHook())
    apply_dataframe_display_patch(ip_display)


def _warn_for_dbr_alternative(magic: str):
    import warnings

    """Warn users about magics that have Databricks alternatives."""
    local_magic_dbr_alternative = {"%%sh": "%sh"}
    if magic in local_magic_dbr_alternative:
        warnings.warn(
            f"\\n{magic} is not supported on Databricks. This notebook might fail when running on a Databricks cluster.\\n"
            f"Consider using %{local_magic_dbr_alternative[magic]} instead."
        )


def _throw_if_not_supported(magic: str):
    """Throw an error for magics that are not supported locally."""
    unsupported_dbr_magics = ["%r", "%scala"]
    if magic in unsupported_dbr_magics:
        raise NotImplementedError(f"{magic} is not supported for local Databricks Notebooks.")


def _get_cell_magic(lines: List[str]) -> Optional[str]:
    """Extract cell magic from the first line if it exists."""
    if len(lines) == 0:
        return None
    if lines[0].strip().startswith("%%"):
        return lines[0].split(" ")[0].strip()
    return None


def _get_line_magic(lines: List[str]) -> Optional[str]:
    """Extract line magic from the first line if it exists."""
    if len(lines) == 0:
        return None
    if lines[0].strip().startswith("%"):
        return lines[0].split(" ")[0].strip().strip("%")
    return None


def _handle_cell_magic(lines: List[str]) -> List[str]:
    """Process cell magic commands."""
    cell_magic = _get_cell_magic(lines)
    if cell_magic is None:
        return lines

    _warn_for_dbr_alternative(cell_magic)
    _throw_if_not_supported(cell_magic)
    return lines


def _handle_line_magic(lines: List[str]) -> List[str]:
    """Process line magic commands and transform them appropriately."""
    lmagic = _get_line_magic(lines)
    if lmagic is None:
        return lines

    _warn_for_dbr_alternative(lmagic)
    _throw_if_not_supported(lmagic)

    if lmagic in ["md", "md-sandbox"]:
        lines[0] = "%%markdown" + lines[0].partition("%" + lmagic)[2]
        return lines

    if lmagic == "sh":
        lines[0] = "%%sh" + lines[0].partition("%" + lmagic)[2]
        return lines

    if lmagic == "sql":
        lines = lines[1:]
        spark_string = "global _sqldf\n" + "_sqldf = spark.sql('''" + "".join(lines).replace("'", "\\'") + "''')\n" + "display(_sqldf)\n"
        return spark_string.splitlines(keepends=True)

    if lmagic == "python":
        return lines[1:]

    return lines


def _strip_hash_magic(lines: List[str]) -> List[str]:
    if len(lines) == 0:
        return lines
    if lines[0].startswith("# MAGIC"):
        return [line.partition("# MAGIC ")[2] for line in lines]
    return lines


def _parse_line_for_databricks_magics(lines: List[str]) -> List[str]:
    """Main parser function for Databricks magic commands."""
    if len(lines) == 0:
        return lines

    lines_to_ignore = ("# Databricks notebook source", "# COMMAND ----------", "# DBTITLE")
    lines = [line for line in lines if not line.strip().startswith(lines_to_ignore)]
    lines = "".join(lines).strip().splitlines(keepends=True)
    lines = _strip_hash_magic(lines)

    if _get_cell_magic(lines):
        return _handle_cell_magic(lines)

    if _get_line_magic(lines):
        return _handle_line_magic(lines)

    return lines


@_log_exceptions
def _register_magics():
    """Register the magic command parser with IPython."""
    from dbruntime.DatasetInfo import UserNamespaceDict
    from dbruntime.PipMagicOverrides import PipMagicOverrides

    user_ns = UserNamespaceDict(
        _user_namespace_initializer.get_namespace_globals(),
        _entry_point.getDriverConf(),
        _entry_point,
    )
    ip = get_ipython()
    ip.input_transformers_cleanup.append(_parse_line_for_databricks_magics)
    ip.register_magics(PipMagicOverrides(_entry_point, _globals["sc"]._conf, user_ns))


@_log_exceptions
def _register_formatters():
    from pyspark.sql import DataFrame
    from pyspark.sql.connect.dataframe import DataFrame as SparkConnectDataframe

    def df_html(df: DataFrame) -> str:
        return df.toPandas().to_html()

    ip = get_ipython()
    html_formatter = ip.display_formatter.formatters["text/html"]
    html_formatter.for_type(SparkConnectDataframe, df_html)
    html_formatter.for_type(DataFrame, df_html)


_register_magics()
_register_formatters()
_register_runtime_hooks()
```
Review comment: Let's move the Python code embedded in the binary to a subpackage. A README there can describe the purpose of the files, and a Go file can embed them.

Reply: In the follow-up PR I've moved the Python logic into two separate packages (client and server), but they still live together with Go files. I'll improve the situation in another follow-up (right now IDE integration is freaking out while editing these files).
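For reference, the subpackage suggested above could embed the script with go:embed so it still ships inside the CLI binary. A minimal sketch, with hypothetical package and file names (nothing here is taken from the actual follow-up PR):

```go
// Package pysrc embeds the Python bootstrap scripts shipped with the CLI.
// Package and file names are illustrative only.
package pysrc

import _ "embed"

// NotebookSetup is the bootstrap code that registers Databricks-style magics
// and formatters in a local IPython kernel.
//
//go:embed notebook_setup.py
var NotebookSetup string
```

A short README in the same directory can explain what the scripts are for and where they run, per the suggestion above.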
@@ -13,6 +13,7 @@ import (

```go
    "strings"

    "github.com/databricks/cli/libs/cmdio"
    "github.com/databricks/databricks-sdk-go"
    "golang.org/x/crypto/ssh"
)
```

@@ -128,3 +129,41 @@ func checkAndGenerateSSHKeyPair(ctx context.Context, keyPath string) (string, string, error) {

```go
    return keyPath, strings.TrimSpace(string(publicKeyBytes)), nil
}

func checkAndGenerateSSHKeyPairFromSecrets(ctx context.Context, client *databricks.WorkspaceClient, clusterID, privateKeyName, publicKeyName string) ([]byte, []byte, error) {
    secretsScopeName, err := createSecretsScope(ctx, client, clusterID)
    if err != nil {
        return nil, nil, fmt.Errorf("failed to create secrets scope: %w", err)
    }

    privateKeyBytes, err := getSecret(ctx, client, secretsScopeName, privateKeyName)
    if err != nil {
        cmdio.LogString(ctx, "SSH key pair not found in secrets scope, generating a new one...")

        privateKeyBytes, publicKeyBytes, err := generateSSHKeyPair()
        if err != nil {
            return nil, nil, fmt.Errorf("failed to generate SSH key pair: %w", err)
        }

        err = putSecret(ctx, client, secretsScopeName, privateKeyName, string(privateKeyBytes))
        if err != nil {
            return nil, nil, err
        }

        err = putSecret(ctx, client, secretsScopeName, publicKeyName, string(publicKeyBytes))
        if err != nil {
            return nil, nil, err
        }

        return privateKeyBytes, publicKeyBytes, nil
    } else {
        cmdio.LogString(ctx, "Using SSH key pair from secrets scope")
```
Review comment: Is this relevant to print during nominal operation? You can consider using libs/log instead.

Reply: Yes, moved everything in the server package to libs/log. I'm using stdout as the log file for now, as I don't want to add log rotation yet. The output is accessible through the Jobs API (for 60 days), so server troubleshooting is still possible.
```go
        publicKeyBytes, err := getSecret(ctx, client, secretsScopeName, publicKeyName)
        if err != nil {
            return nil, nil, fmt.Errorf("failed to get public key from secrets scope: %w", err)
        }

        return privateKeyBytes, publicKeyBytes, nil
    }
}
```
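Following the logging discussion above (the reply mentions moving the server package to libs/log), the cmdio.LogString calls could be swapped for leveled logging. A rough sketch, assuming the libs/log helpers take a context as their first argument; the helper function itself is hypothetical and not part of this PR:

```go
package ssh

import (
    "context"

    "github.com/databricks/cli/libs/log"
)

// logKeyPairSource is a hypothetical helper illustrating the libs/log swap
// discussed in the review; it is not part of this PR.
func logKeyPairSource(ctx context.Context, foundInSecrets bool) {
    if foundInSecrets {
        log.Infof(ctx, "Using SSH key pair from secrets scope")
        return
    }
    log.Infof(ctx, "SSH key pair not found in secrets scope, generating a new one...")
}
```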
Review comment: This command is run only on the backend, correct? Users should never interact with this directly. If so, perhaps we can check for a magic env var to be set. That prevents users who accidentally (or out of curiosity) run this from modifying a bunch of files in their home directory.

Reply: The check sort of exists right now, as we fail if PUBLIC_SSH_KEY is not set. But I just realized that we don't need to set this env var and can just pass the secret name to the Go server - it can download the key itself (doing it now in the follow-up PR, to avoid rebase conflicts). After that change it still won't be easy to run the server by mistake, since you'll need to read the code to understand which secret name to use (and make sure it actually exists).

Review comment: Ack. The failure happens somewhere deep in the stack. It would be useful to make this error more explicit. If a user runs it and gets some error about a missing secret, that's confusing. If instead they see "This command is supposed to be run on Databricks compute. Please find usage instructions at <>.", that's better.
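A minimal sketch of the guard discussed in this thread, assuming the DATABRICKS_RUNTIME_VERSION environment variable (set on Databricks compute) is an acceptable signal; the variable choice, wording, and function name are illustrative and not part of this PR:

```go
package ssh

import (
    "errors"
    "os"
)

// ensureRunningOnDatabricksCompute is a hypothetical pre-run check that stops
// the hidden server command from being run on a developer's machine by mistake.
func ensureRunningOnDatabricksCompute() error {
    // DATABRICKS_RUNTIME_VERSION is set on Databricks clusters; treating its
    // absence as "not on Databricks compute" is an assumption for illustration.
    if os.Getenv("DATABRICKS_RUNTIME_VERSION") == "" {
        return errors.New("this command is supposed to be run on Databricks compute; see the SSH tunnel documentation for usage instructions")
    }
    return nil
}
```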