diff --git a/README.md b/README.md index 68a84706..976dba02 100644 --- a/README.md +++ b/README.md @@ -566,6 +566,74 @@ assert rows[0][0] == "-2001-08-22" assert cur.description[0][1] == "date" ``` +## Progress Callback + +The Trino client supports progress callbacks to track query execution progress in real-time. you can provide a callback function that gets called whenever the query status is updated. + +### Basic Usage + +```python +from trino.client import TrinoQuery, TrinoRequest, ClientSession, TrinoStatus +from typing import Dict, Any + +def progress_callback(status: TrinoStatus, stats: Dict[str, Any]) -> None: + """Progress callback function that gets called whenever the query status is updated.""" + state = stats.get('state', 'UNKNOWN') + processed_bytes = stats.get('processedBytes', 0) + processed_rows = stats.get('processedRows', 0) + completed_splits = stats.get('completedSplits', 0) + total_splits = stats.get('totalSplits', 0) + + print(f"Query {status.id}: {state} - {processed_bytes} bytes, {processed_rows} rows") + if total_splits > 0: + progress = (completed_splits / total_splits) * 100.0 + print(f"Progress: {progress:.1f}% ({completed_splits}/{total_splits} splits)") + +session = ClientSession(user="test_user", catalog="memory", schema="default") + +request = TrinoRequest( + host="localhost", + port=8080, + client_session=session, + http_scheme="http" +) + +query = TrinoQuery( + request=request, + query="SELECT * FROM large_table", + progress_callback=progress_callback +) + +result = query.execute() + +while not query.finished: + rows = query.fetch() +``` + +### Progress Calculation + +The callback receives a `stats` dictionary containing various metrics that can be used to calculate progress: + +- `state`: Query state (RUNNING, FINISHED, FAILED, etc.) +- `processedBytes`: Total bytes processed +- `processedRows`: Total rows processed +- `completedSplits`: Number of completed splits +- `totalSplits`: Total number of splits + +The most accurate progress calculation is based on splits completion: + +```python +def calculate_progress(stats: Dict[str, Any]) -> float: + """Calculate progress percentage based on splits completion.""" + completed_splits = stats.get('completedSplits', 0) + total_splits = stats.get('totalSplits', 0) + if total_splits > 0: + return min(100.0, (completed_splits / total_splits) * 100.0) + elif stats.get('state') == 'FINISHED': + return 100.0 + return 0.0 +``` + ### Trino to Python type mappings | Trino type | Python type | diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 909823ee..9cee2cf1 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -1447,3 +1447,251 @@ def delete_password(self, servicename, username): return None os.remove(file_path) + + +@mock.patch("trino.client.TrinoRequest.http") +def test_trinoquery_heartbeat_success(mock_requests, sample_post_response_data, sample_get_response_data): + """Test that heartbeat is sent periodically and does not stop on success.""" + head_call_count = 0 + + def fake_head(url, timeout=10): + nonlocal head_call_count + head_call_count += 1 + + class Resp: + status_code = 200 + return Resp() + mock_requests.head.side_effect = fake_head + mock_requests.Response.return_value.json.return_value = sample_post_response_data + mock_requests.get.return_value.json.return_value = sample_get_response_data + mock_requests.post.return_value.json.return_value = sample_post_response_data + req = TrinoRequest( + host="coordinator", + port=8080, + client_session=ClientSession(user="test"), + http_scheme="http", + ) + query = TrinoQuery(request=req, query="SELECT 1", heartbeat_interval=0.1) + + def finish_query(*args, **kwargs): + query._finished = True + return [] + query.fetch = finish_query + query._next_uri = "http://coordinator/v1/statement/next" + query._row_mapper = mock.Mock(map=lambda x: []) + query._start_heartbeat() + time.sleep(0.3) + query._stop_heartbeat() + assert head_call_count >= 2 + + +@mock.patch("trino.client.TrinoRequest.http") +def test_trinoquery_heartbeat_failure_stops(mock_requests, sample_post_response_data, sample_get_response_data): + """Test that heartbeat stops after 3 consecutive failures.""" + def fake_head(url, timeout=10): + class Resp: + status_code = 500 + return Resp() + mock_requests.head.side_effect = fake_head + mock_requests.Response.return_value.json.return_value = sample_post_response_data + mock_requests.get.return_value.json.return_value = sample_get_response_data + mock_requests.post.return_value.json.return_value = sample_post_response_data + req = TrinoRequest( + host="coordinator", + port=8080, + client_session=ClientSession(user="test"), + http_scheme="http", + ) + query = TrinoQuery(request=req, query="SELECT 1", heartbeat_interval=0.05) + query._next_uri = "http://coordinator/v1/statement/next" + query._row_mapper = mock.Mock(map=lambda x: []) + query._start_heartbeat() + time.sleep(0.3) + assert not query._heartbeat_enabled + query._stop_heartbeat() + + +@mock.patch("trino.client.TrinoRequest.http") +def test_trinoquery_heartbeat_404_405_stops(mock_requests, sample_post_response_data, sample_get_response_data): + """Test that heartbeat stops if server returns 404 or 405.""" + for code in (404, 405): + def fake_head(url, timeout=10, code=code): + class Resp: + status_code = code + return Resp() + mock_requests.head.side_effect = fake_head + mock_requests.Response.return_value.json.return_value = sample_post_response_data + mock_requests.get.return_value.json.return_value = sample_get_response_data + mock_requests.post.return_value.json.return_value = sample_post_response_data + req = TrinoRequest( + host="coordinator", + port=8080, + client_session=ClientSession(user="test"), + http_scheme="http", + ) + query = TrinoQuery(request=req, query="SELECT 1", heartbeat_interval=0.05) + query._next_uri = "http://coordinator/v1/statement/next" + query._row_mapper = mock.Mock(map=lambda x: []) + query._start_heartbeat() + time.sleep(0.2) + assert not query._heartbeat_enabled + query._stop_heartbeat() + + +@mock.patch("trino.client.TrinoRequest.http") +def test_trinoquery_heartbeat_stops_on_finish(mock_requests, sample_post_response_data, sample_get_response_data): + """Test that heartbeat stops when the query is finished.""" + head_call_count = 0 + + def fake_head(url, timeout=10): + nonlocal head_call_count + head_call_count += 1 + + class Resp: + status_code = 200 + return Resp() + mock_requests.head.side_effect = fake_head + mock_requests.Response.return_value.json.return_value = sample_post_response_data + mock_requests.get.return_value.json.return_value = sample_get_response_data + mock_requests.post.return_value.json.return_value = sample_post_response_data + req = TrinoRequest( + host="coordinator", + port=8080, + client_session=ClientSession(user="test"), + http_scheme="http", + ) + query = TrinoQuery(request=req, query="SELECT 1", heartbeat_interval=0.05) + query._next_uri = "http://coordinator/v1/statement/next" + query._row_mapper = mock.Mock(map=lambda x: []) + query._start_heartbeat() + time.sleep(0.1) + query._finished = True + time.sleep(0.1) + query._stop_heartbeat() + # Heartbeat should have stopped after query finished + assert head_call_count >= 1 + + +@mock.patch("trino.client.TrinoRequest.http") +def test_trinoquery_heartbeat_stops_on_cancel(mock_requests, sample_post_response_data, sample_get_response_data): + """Test that heartbeat stops when the query is cancelled.""" + head_call_count = 0 + + def fake_head(url, timeout=10): + nonlocal head_call_count + head_call_count += 1 + + class Resp: + status_code = 200 + return Resp() + mock_requests.head.side_effect = fake_head + mock_requests.Response.return_value.json.return_value = sample_post_response_data + mock_requests.get.return_value.json.return_value = sample_get_response_data + mock_requests.post.return_value.json.return_value = sample_post_response_data + req = TrinoRequest( + host="coordinator", + port=8080, + client_session=ClientSession(user="test"), + http_scheme="http", + ) + query = TrinoQuery(request=req, query="SELECT 1", heartbeat_interval=0.05) + query._next_uri = "http://coordinator/v1/statement/next" + query._row_mapper = mock.Mock(map=lambda x: []) + query._start_heartbeat() + time.sleep(0.1) + query._cancelled = True + time.sleep(0.1) + query._stop_heartbeat() + # Heartbeat should have stopped after query cancelled + assert head_call_count >= 1 + + +# Progress Callback Tests +def test_progress_callback_initialization(): + """Test that progress callback is properly initialized.""" + req = TrinoRequest( + host="coordinator", + port=8080, + client_session=ClientSession(user="test"), + http_scheme="http", + ) + + def callback(status, stats): + pass + + query = TrinoQuery(request=req, query="SELECT 1", progress_callback=callback) + assert query._progress_callback == callback + + # Test without callback + query_no_callback = TrinoQuery(request=req, query="SELECT 1") + assert query_no_callback._progress_callback is None + + +def test_calculate_progress_percentage(): + """Test progress percentage calculation.""" + req = TrinoRequest( + host="coordinator", + port=8080, + client_session=ClientSession(user="test"), + http_scheme="http", + ) + query = TrinoQuery(request=req, query="SELECT 1") + + # Test splits-based calculation + assert query.calculate_progress_percentage({'completedSplits': 5, 'totalSplits': 10}) == 50.0 + assert query.calculate_progress_percentage({'completedSplits': 10, 'totalSplits': 10}) == 100.0 + assert query.calculate_progress_percentage({'completedSplits': 15, 'totalSplits': 10}) == 100.0 # Cap at 100% + + # Test state-based calculation + assert query.calculate_progress_percentage({'state': 'FINISHED'}) == 100.0 + assert query.calculate_progress_percentage({'state': 'RUNNING'}) == 5.0 + assert query.calculate_progress_percentage({'state': 'FAILED'}) == 0.0 + + # Test empty stats + assert query.calculate_progress_percentage({}) == 0.0 + + +@mock.patch("trino.client.TrinoRequest.post") +@mock.patch("trino.client.TrinoRequest.get") +def test_progress_callback_execution(mock_get, mock_post): + """Test that progress callback is called during query execution.""" + callback_calls = [] + + def callback(status, stats): + callback_calls.append((status, stats)) + + # Mock responses + mock_post_response = Mock() + mock_post_response.json.return_value = { + 'id': 'test_query_id', + 'nextUri': 'http://localhost:8080/v1/statement/test_query_id/1', + 'stats': {'state': 'RUNNING', 'completedSplits': 0, 'totalSplits': 10}, + 'data': [], + 'columns': [] + } + mock_post.return_value = mock_post_response + + mock_get_response = Mock() + mock_get_response.json.return_value = { + 'id': 'test_query_id', + 'nextUri': None, + 'stats': {'state': 'FINISHED', 'completedSplits': 10, 'totalSplits': 10}, + 'data': [[1]], + 'columns': [{'name': 'col1', 'type': 'integer'}] + } + mock_get.return_value = mock_get_response + + req = TrinoRequest( + host="coordinator", + port=8080, + client_session=ClientSession(user="test"), + http_scheme="http", + ) + + query = TrinoQuery(request=req, query="SELECT 1", progress_callback=callback) + result = query.execute() + + # Verify callback was called + assert len(callback_calls) > 0 + assert isinstance(callback_calls[0][0], TrinoStatus) + assert isinstance(callback_calls[0][1], dict) diff --git a/trino/client.py b/trino/client.py index 7cc1f0f2..f204cf78 100644 --- a/trino/client.py +++ b/trino/client.py @@ -54,19 +54,21 @@ from email.utils import parsedate_to_datetime from enum import Enum from time import sleep -from typing import Any +from typing import Any, Callable, Optional from typing import cast from typing import Dict from typing import List from typing import Literal -from typing import Optional -from typing import Tuple -from typing import TypedDict from typing import Union +from typing import TypedDict +from typing import Tuple from zoneinfo import ZoneInfo import lz4.block import requests + +# Progress callback type definition +ProgressCallback = Callable[['TrinoStatus', Dict[str, Any]], None] import zstandard from requests import Response from requests import Session @@ -808,7 +810,9 @@ def __init__( request: TrinoRequest, query: str, legacy_primitive_types: bool = False, - fetch_mode: Literal["mapped", "segments"] = "mapped" + fetch_mode: Literal["mapped", "segments"] = "mapped", + heartbeat_interval: float = 60.0, # seconds + progress_callback: Optional[ProgressCallback] = None, ) -> None: self._query_id: Optional[str] = None self._stats: Dict[Any, Any] = {} @@ -826,6 +830,13 @@ def __init__( self._legacy_primitive_types = legacy_primitive_types self._row_mapper: Optional[RowMapper] = None self._fetch_mode = fetch_mode + self._heartbeat_interval = heartbeat_interval + self._heartbeat_thread = None + self._heartbeat_stop_event = threading.Event() + self._heartbeat_failures = 0 + self._heartbeat_enabled = True + self._progress_callback = progress_callback + self._last_progress_stats = {} @property def query_id(self) -> Optional[str]: @@ -868,6 +879,40 @@ def result(self): def info_uri(self): return self._info_uri + def _start_heartbeat(self): + if self._heartbeat_thread is not None: + return + self._heartbeat_stop_event.clear() + self._heartbeat_thread = threading.Thread(target=self._heartbeat_loop, daemon=True) + self._heartbeat_thread.start() + + def _stop_heartbeat(self): + self._heartbeat_stop_event.set() + if self._heartbeat_thread is not None: + self._heartbeat_thread.join(timeout=2) + self._heartbeat_thread = None + + def _heartbeat_loop(self): + while all([not self._heartbeat_stop_event.is_set(), not self.finished, not self.cancelled, + self._heartbeat_enabled]): + if self._next_uri is None: + break + try: + response = self._request.http.head(self._next_uri, timeout=10) + if response.status_code == 404 or response.status_code == 405: + self._heartbeat_enabled = False + break + if response.status_code == 200: + self._heartbeat_failures = 0 + else: + self._heartbeat_failures += 1 + except Exception: + self._heartbeat_failures += 1 + if self._heartbeat_failures >= 3: + self._heartbeat_enabled = False + break + self._heartbeat_stop_event.wait(self._heartbeat_interval) + def execute(self, additional_http_headers=None) -> TrinoResult: """Initiate a Trino query by sending the SQL statement @@ -895,6 +940,9 @@ def execute(self, additional_http_headers=None) -> TrinoResult: rows = self._row_mapper.map(status.rows) if self._row_mapper else status.rows self._result = TrinoResult(self, rows) + # Start heartbeat thread + self._start_heartbeat() + # Execute should block until at least one row is received or query is finished or cancelled while not self.finished and not self.cancelled and len(self._result.rows) == 0: self._result.rows += self.fetch() @@ -910,6 +958,10 @@ def _update_state(self, status): legacy_primitive_types=self._legacy_primitive_types) if status.columns: self._columns = status.columns + + # Call progress callback if provided + if self._progress_callback is not None: + self._progress_callback(status, self._stats) def fetch(self) -> List[Union[List[Any]], Any]: """Continue fetching data for the current query_id""" @@ -921,6 +973,7 @@ def fetch(self) -> List[Union[List[Any]], Any]: self._update_state(status) if status.next_uri is None: self._finished = True + self._stop_heartbeat() if not self._row_mapper: return [] @@ -968,6 +1021,7 @@ def cancel(self) -> None: if response.status_code == requests.codes.no_content: self._cancelled = True logger.debug("query cancelled: %s", self.query_id) + self._stop_heartbeat() return self._request.raise_response_error(response) @@ -985,6 +1039,40 @@ def finished(self) -> bool: def cancelled(self) -> bool: return self._cancelled + @property + def is_running(self) -> bool: + """Return True if the query is still running (not finished or cancelled).""" + return not self.finished and not self.cancelled + + def calculate_progress_percentage(self, stats: Dict[str, Any]) -> float: + """ + Calculate progress percentage based on available statistics. + + Args: + stats: The current query statistics from Trino + + Returns: + Progress percentage as a float between 0.0 and 100.0 + """ + # Try to calculate progress based on splits completion + if 'completedSplits' in stats and 'totalSplits' in stats: + completed_splits = stats.get('completedSplits', 0) + total_splits = stats.get('totalSplits', 0) + if total_splits > 0: + return min(100.0, (completed_splits / total_splits) * 100.0) + + # Fallback: check if query is finished + if stats.get('state') == 'FINISHED': + return 100.0 + + # If query is running but we don't have split info, estimate based on time + # This is a rough estimate and may not be accurate + if stats.get('state') == 'RUNNING': + # Return a conservative estimate - could be enhanced with more sophisticated logic + return 5.0 # Assume some progress has been made + + return 0.0 + def _retry_with(handle_retry, handled_exceptions, conditions, max_attempts): def wrapper(func):