from typing import List, Optional
from IPython.core.getipython import get_ipython
from IPython.display import display as ip_display
from dbruntime import UserNamespaceInitializer


def _log_exceptions(func):
    """Log entry into func and swallow any exception, returning None on failure."""
    from functools import wraps

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            print(f"Executing {func.__name__}")
            return func(*args, **kwargs)
        except Exception as e:
            print(f"Error in {func.__name__}: {e}")

    return wrapper


_user_namespace_initializer = UserNamespaceInitializer.getOrCreate()
_entry_point = _user_namespace_initializer.get_spark_entry_point()
_globals = _user_namespace_initializer.get_namespace_globals()
for name, value in _globals.items():
    print(f"Registering global: {name} = {value}")
    if name not in globals():
        globals()[name] = value


# 'display' from the runtime uses custom widgets that don't work in Jupyter.
# We use the IPython display instead (in combination with the HTML formatter for DataFrames).
globals()["display"] = ip_display


@_log_exceptions
def _register_runtime_hooks():
    from dbruntime.monkey_patches import apply_dataframe_display_patch
    from dbruntime.IPythonShellHooks import load_ipython_hooks, IPythonShellHook
    from IPython.core.interactiveshell import ExecutionInfo

    # Setting executing_raw_cell before cell execution is required to make
    # dbutils.library.restartPython() work.
    class PreRunHook(IPythonShellHook):
        def pre_run_cell(self, info: ExecutionInfo) -> None:
            get_ipython().executing_raw_cell = info.raw_cell

    load_ipython_hooks(get_ipython(), PreRunHook())
    apply_dataframe_display_patch(ip_display)


def _warn_for_dbr_alternative(magic: str):
    """Warn users about magics that have Databricks alternatives."""
    import warnings

    local_magic_dbr_alternative = {"%%sh": "%sh"}
    if magic in local_magic_dbr_alternative:
        warnings.warn(
            f"\n{magic} is not supported on Databricks. This notebook might fail when running on a Databricks cluster.\n"
            f"Consider using {local_magic_dbr_alternative[magic]} instead."
        )


def _throw_if_not_supported(magic: str):
    """Throw an error for magics that are not supported locally."""
    # Line magics arrive here with their leading "%" already stripped by
    # _get_line_magic, so normalize before comparing.
    unsupported_dbr_magics = ["r", "scala"]
    if magic.lstrip("%") in unsupported_dbr_magics:
        raise NotImplementedError(f"{magic} is not supported for local Databricks Notebooks.")


def _get_cell_magic(lines: List[str]) -> Optional[str]:
    """Extract cell magic from the first line if it exists."""
    if len(lines) == 0:
        return None
    if lines[0].strip().startswith("%%"):
        return lines[0].split(" ")[0].strip()
    return None
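
# Illustrative behaviour (a hedged sketch; these inputs are hypothetical):
#   _get_cell_magic(["%%sh\n", "ls\n"])  -> "%%sh"
#   _get_cell_magic(["print(1)\n"])      -> None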


def _get_line_magic(lines: List[str]) -> Optional[str]:
    """Extract line magic from the first line if it exists."""
    if len(lines) == 0:
        return None
    if lines[0].strip().startswith("%"):
        return lines[0].split(" ")[0].strip().strip("%")
    return None
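
# Illustrative behaviour (a hedged sketch; these inputs are hypothetical).
# Note the leading "%" is stripped from the returned name:
#   _get_line_magic(["%sql\n", "SELECT 1\n"]) -> "sql"
#   _get_line_magic(["x = 1\n"])              -> None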


def _handle_cell_magic(lines: List[str]) -> List[str]:
    """Process cell magic commands."""
    cell_magic = _get_cell_magic(lines)
    if cell_magic is None:
        return lines

    _warn_for_dbr_alternative(cell_magic)
    _throw_if_not_supported(cell_magic)
    return lines


def _handle_line_magic(lines: List[str]) -> List[str]:
    """Process line magic commands and transform them appropriately."""
    lmagic = _get_line_magic(lines)
    if lmagic is None:
        return lines

    _warn_for_dbr_alternative(lmagic)
    _throw_if_not_supported(lmagic)

    if lmagic in ["md", "md-sandbox"]:
        lines[0] = "%%markdown" + lines[0].partition("%" + lmagic)[2]
        return lines

    if lmagic == "sh":
        lines[0] = "%%sh" + lines[0].partition("%" + lmagic)[2]
        return lines

    if lmagic == "sql":
        # Rewrite the cell to run through spark.sql, keeping the result in the
        # global _sqldf, mirroring Databricks' %sql behaviour.
        lines = lines[1:]
        spark_string = (
            "global _sqldf\n"
            + "_sqldf = spark.sql('''"
            + "".join(lines).replace("'", "\\'")
            + "''')\n"
            + "display(_sqldf)\n"
        )
        return spark_string.splitlines(keepends=True)

    if lmagic == "python":
        return lines[1:]

    return lines
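
# For example (hypothetical cell), ["%sql\n", "SELECT 1"] would be rewritten to:
#   global _sqldf
#   _sqldf = spark.sql('''SELECT 1''')
#   display(_sqldf)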


def _strip_hash_magic(lines: List[str]) -> List[str]:
    """Remove the "# MAGIC " prefix Databricks uses in exported .py sources."""
    if len(lines) == 0:
        return lines
    if lines[0].startswith("# MAGIC"):
        return [line.partition("# MAGIC ")[2] for line in lines]
    return lines
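
# Illustrative example (hypothetical input):
#   ["# MAGIC %md\n", "# MAGIC # Title\n"] -> ["%md\n", "# Title\n"]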


def _parse_line_for_databricks_magics(lines: List[str]) -> List[str]:
    """Main parser function for Databricks magic commands."""
    if len(lines) == 0:
        return lines

    lines_to_ignore = ("# Databricks notebook source", "# COMMAND ----------", "# DBTITLE")
    lines = [line for line in lines if not line.strip().startswith(lines_to_ignore)]
    lines = "".join(lines).strip().splitlines(keepends=True)
    lines = _strip_hash_magic(lines)

    if _get_cell_magic(lines):
        return _handle_cell_magic(lines)

    if _get_line_magic(lines):
        return _handle_line_magic(lines)

    return lines
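
# End-to-end sketch (hypothetical exported cell):
#   ["# COMMAND ----------\n", "# MAGIC %md\n", "# MAGIC Hello\n"]
# has its separator line filtered out, its "# MAGIC " prefixes stripped, and is
# handed to the line-magic handler, yielding ["%%markdown\n", "Hello"].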


@_log_exceptions
def _register_magics():
    """Register the magic command parser with IPython."""
    from dbruntime.DatasetInfo import UserNamespaceDict
    from dbruntime.PipMagicOverrides import PipMagicOverrides

    user_ns = UserNamespaceDict(
        _user_namespace_initializer.get_namespace_globals(),
        _entry_point.getDriverConf(),
        _entry_point,
    )
    ip = get_ipython()
    ip.input_transformers_cleanup.append(_parse_line_for_databricks_magics)
    ip.register_magics(PipMagicOverrides(_entry_point, _globals["sc"]._conf, user_ns))


@_log_exceptions
def _register_formatters():
    from pyspark.sql import DataFrame
    from pyspark.sql.connect.dataframe import DataFrame as SparkConnectDataframe

    def df_html(df: DataFrame) -> str:
        return df.toPandas().to_html()

    ip = get_ipython()
    html_formatter = ip.display_formatter.formatters["text/html"]
    html_formatter.for_type(SparkConnectDataframe, df_html)
    html_formatter.for_type(DataFrame, df_html)
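
# After registration, a cell whose last expression is a DataFrame (e.g. the
# hypothetical `spark.range(3)`) should render as an HTML table via df_html.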


_register_magics()
_register_formatters()
_register_runtime_hooks()