Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,74 @@ assert rows[0][0] == "-2001-08-22"
assert cur.description[0][1] == "date"
```

## Progress Callback

The Trino client supports progress callbacks to track query execution progress in real-time. you can provide a callback function that gets called whenever the query status is updated.

### Basic Usage

```python
from trino.client import TrinoQuery, TrinoRequest, ClientSession, TrinoStatus
from typing import Dict, Any

def progress_callback(status: TrinoStatus, stats: Dict[str, Any]) -> None:
"""Progress callback function that gets called whenever the query status is updated."""
state = stats.get('state', 'UNKNOWN')
processed_bytes = stats.get('processedBytes', 0)
processed_rows = stats.get('processedRows', 0)
completed_splits = stats.get('completedSplits', 0)
total_splits = stats.get('totalSplits', 0)

print(f"Query {status.id}: {state} - {processed_bytes} bytes, {processed_rows} rows")
if total_splits > 0:
progress = (completed_splits / total_splits) * 100.0
print(f"Progress: {progress:.1f}% ({completed_splits}/{total_splits} splits)")

session = ClientSession(user="test_user", catalog="memory", schema="default")

request = TrinoRequest(
host="localhost",
port=8080,
client_session=session,
http_scheme="http"
)

query = TrinoQuery(
request=request,
query="SELECT * FROM large_table",
progress_callback=progress_callback
)

result = query.execute()

while not query.finished:
rows = query.fetch()
```

### Progress Calculation

The callback receives a `stats` dictionary containing various metrics that can be used to calculate progress:

- `state`: Query state (RUNNING, FINISHED, FAILED, etc.)
- `processedBytes`: Total bytes processed
- `processedRows`: Total rows processed
- `completedSplits`: Number of completed splits
- `totalSplits`: Total number of splits

The most accurate progress calculation is based on splits completion:

```python
def calculate_progress(stats: Dict[str, Any]) -> float:
"""Calculate progress percentage based on splits completion."""
completed_splits = stats.get('completedSplits', 0)
total_splits = stats.get('totalSplits', 0)
if total_splits > 0:
return min(100.0, (completed_splits / total_splits) * 100.0)
elif stats.get('state') == 'FINISHED':
return 100.0
return 0.0
```

### Trino to Python type mappings

| Trino type | Python type |
Expand Down
248 changes: 248 additions & 0 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1447,3 +1447,251 @@ def delete_password(self, servicename, username):
return None

os.remove(file_path)


@mock.patch("trino.client.TrinoRequest.http")
def test_trinoquery_heartbeat_success(mock_requests, sample_post_response_data, sample_get_response_data):
"""Test that heartbeat is sent periodically and does not stop on success."""
head_call_count = 0

def fake_head(url, timeout=10):
nonlocal head_call_count
head_call_count += 1

class Resp:
status_code = 200
return Resp()
mock_requests.head.side_effect = fake_head
mock_requests.Response.return_value.json.return_value = sample_post_response_data
mock_requests.get.return_value.json.return_value = sample_get_response_data
mock_requests.post.return_value.json.return_value = sample_post_response_data
req = TrinoRequest(
host="coordinator",
port=8080,
client_session=ClientSession(user="test"),
http_scheme="http",
)
query = TrinoQuery(request=req, query="SELECT 1", heartbeat_interval=0.1)

def finish_query(*args, **kwargs):
query._finished = True
return []
query.fetch = finish_query
query._next_uri = "http://coordinator/v1/statement/next"
query._row_mapper = mock.Mock(map=lambda x: [])
query._start_heartbeat()
time.sleep(0.3)
query._stop_heartbeat()
assert head_call_count >= 2


@mock.patch("trino.client.TrinoRequest.http")
def test_trinoquery_heartbeat_failure_stops(mock_requests, sample_post_response_data, sample_get_response_data):
"""Test that heartbeat stops after 3 consecutive failures."""
def fake_head(url, timeout=10):
class Resp:
status_code = 500
return Resp()
mock_requests.head.side_effect = fake_head
mock_requests.Response.return_value.json.return_value = sample_post_response_data
mock_requests.get.return_value.json.return_value = sample_get_response_data
mock_requests.post.return_value.json.return_value = sample_post_response_data
req = TrinoRequest(
host="coordinator",
port=8080,
client_session=ClientSession(user="test"),
http_scheme="http",
)
query = TrinoQuery(request=req, query="SELECT 1", heartbeat_interval=0.05)
query._next_uri = "http://coordinator/v1/statement/next"
query._row_mapper = mock.Mock(map=lambda x: [])
query._start_heartbeat()
time.sleep(0.3)
assert not query._heartbeat_enabled
query._stop_heartbeat()


@mock.patch("trino.client.TrinoRequest.http")
def test_trinoquery_heartbeat_404_405_stops(mock_requests, sample_post_response_data, sample_get_response_data):
"""Test that heartbeat stops if server returns 404 or 405."""
for code in (404, 405):
def fake_head(url, timeout=10, code=code):
class Resp:
status_code = code
return Resp()
mock_requests.head.side_effect = fake_head
mock_requests.Response.return_value.json.return_value = sample_post_response_data
mock_requests.get.return_value.json.return_value = sample_get_response_data
mock_requests.post.return_value.json.return_value = sample_post_response_data
req = TrinoRequest(
host="coordinator",
port=8080,
client_session=ClientSession(user="test"),
http_scheme="http",
)
query = TrinoQuery(request=req, query="SELECT 1", heartbeat_interval=0.05)
query._next_uri = "http://coordinator/v1/statement/next"
query._row_mapper = mock.Mock(map=lambda x: [])
query._start_heartbeat()
time.sleep(0.2)
assert not query._heartbeat_enabled
query._stop_heartbeat()


@mock.patch("trino.client.TrinoRequest.http")
def test_trinoquery_heartbeat_stops_on_finish(mock_requests, sample_post_response_data, sample_get_response_data):
"""Test that heartbeat stops when the query is finished."""
head_call_count = 0

def fake_head(url, timeout=10):
nonlocal head_call_count
head_call_count += 1

class Resp:
status_code = 200
return Resp()
mock_requests.head.side_effect = fake_head
mock_requests.Response.return_value.json.return_value = sample_post_response_data
mock_requests.get.return_value.json.return_value = sample_get_response_data
mock_requests.post.return_value.json.return_value = sample_post_response_data
req = TrinoRequest(
host="coordinator",
port=8080,
client_session=ClientSession(user="test"),
http_scheme="http",
)
query = TrinoQuery(request=req, query="SELECT 1", heartbeat_interval=0.05)
query._next_uri = "http://coordinator/v1/statement/next"
query._row_mapper = mock.Mock(map=lambda x: [])
query._start_heartbeat()
time.sleep(0.1)
query._finished = True
time.sleep(0.1)
query._stop_heartbeat()
# Heartbeat should have stopped after query finished
assert head_call_count >= 1


@mock.patch("trino.client.TrinoRequest.http")
def test_trinoquery_heartbeat_stops_on_cancel(mock_requests, sample_post_response_data, sample_get_response_data):
"""Test that heartbeat stops when the query is cancelled."""
head_call_count = 0

def fake_head(url, timeout=10):
nonlocal head_call_count
head_call_count += 1

class Resp:
status_code = 200
return Resp()
mock_requests.head.side_effect = fake_head
mock_requests.Response.return_value.json.return_value = sample_post_response_data
mock_requests.get.return_value.json.return_value = sample_get_response_data
mock_requests.post.return_value.json.return_value = sample_post_response_data
req = TrinoRequest(
host="coordinator",
port=8080,
client_session=ClientSession(user="test"),
http_scheme="http",
)
query = TrinoQuery(request=req, query="SELECT 1", heartbeat_interval=0.05)
query._next_uri = "http://coordinator/v1/statement/next"
query._row_mapper = mock.Mock(map=lambda x: [])
query._start_heartbeat()
time.sleep(0.1)
query._cancelled = True
time.sleep(0.1)
query._stop_heartbeat()
# Heartbeat should have stopped after query cancelled
assert head_call_count >= 1


# Progress Callback Tests
def test_progress_callback_initialization():
"""Test that progress callback is properly initialized."""
req = TrinoRequest(
host="coordinator",
port=8080,
client_session=ClientSession(user="test"),
http_scheme="http",
)

def callback(status, stats):
pass

query = TrinoQuery(request=req, query="SELECT 1", progress_callback=callback)
assert query._progress_callback == callback

# Test without callback
query_no_callback = TrinoQuery(request=req, query="SELECT 1")
assert query_no_callback._progress_callback is None


def test_calculate_progress_percentage():
"""Test progress percentage calculation."""
req = TrinoRequest(
host="coordinator",
port=8080,
client_session=ClientSession(user="test"),
http_scheme="http",
)
query = TrinoQuery(request=req, query="SELECT 1")

# Test splits-based calculation
assert query.calculate_progress_percentage({'completedSplits': 5, 'totalSplits': 10}) == 50.0
assert query.calculate_progress_percentage({'completedSplits': 10, 'totalSplits': 10}) == 100.0
assert query.calculate_progress_percentage({'completedSplits': 15, 'totalSplits': 10}) == 100.0 # Cap at 100%

# Test state-based calculation
assert query.calculate_progress_percentage({'state': 'FINISHED'}) == 100.0
assert query.calculate_progress_percentage({'state': 'RUNNING'}) == 5.0
assert query.calculate_progress_percentage({'state': 'FAILED'}) == 0.0

# Test empty stats
assert query.calculate_progress_percentage({}) == 0.0


@mock.patch("trino.client.TrinoRequest.post")
@mock.patch("trino.client.TrinoRequest.get")
def test_progress_callback_execution(mock_get, mock_post):
"""Test that progress callback is called during query execution."""
callback_calls = []

def callback(status, stats):
callback_calls.append((status, stats))

# Mock responses
mock_post_response = Mock()
mock_post_response.json.return_value = {
'id': 'test_query_id',
'nextUri': 'http://localhost:8080/v1/statement/test_query_id/1',
'stats': {'state': 'RUNNING', 'completedSplits': 0, 'totalSplits': 10},
'data': [],
'columns': []
}
mock_post.return_value = mock_post_response

mock_get_response = Mock()
mock_get_response.json.return_value = {
'id': 'test_query_id',
'nextUri': None,
'stats': {'state': 'FINISHED', 'completedSplits': 10, 'totalSplits': 10},
'data': [[1]],
'columns': [{'name': 'col1', 'type': 'integer'}]
}
mock_get.return_value = mock_get_response

req = TrinoRequest(
host="coordinator",
port=8080,
client_session=ClientSession(user="test"),
http_scheme="http",
)

query = TrinoQuery(request=req, query="SELECT 1", progress_callback=callback)
result = query.execute()

# Verify callback was called
assert len(callback_calls) > 0
assert isinstance(callback_calls[0][0], TrinoStatus)
assert isinstance(callback_calls[0][1], dict)
Loading
Loading