5 changes: 5 additions & 0 deletions crates/core/src/host/module_host.rs
@@ -32,6 +32,7 @@ use spacetimedb_data_structures::map::{HashCollectionExt as _, IntMap};
use spacetimedb_datastore::execution_context::{ExecutionContext, ReducerContext, Workload, WorkloadType};
use spacetimedb_datastore::locking_tx_datastore::MutTxId;
use spacetimedb_datastore::traits::{IsolationLevel, Program, TxData};
use spacetimedb_durability::DurableOffset;
use spacetimedb_execution::pipelined::PipelinedProject;
use spacetimedb_lib::db::raw_def::v9::Lifecycle;
use spacetimedb_lib::identity::{AuthCtx, RequestId};
@@ -1233,6 +1234,10 @@ impl ModuleHost {
&self.replica_ctx().database
}

pub fn durable_tx_offset(&self) -> Option<DurableOffset> {
self.replica_ctx().relational_db.durable_tx_offset()
}

pub(crate) fn replica_ctx(&self) -> &ReplicaContext {
self.module.replica_ctx()
}
5 changes: 3 additions & 2 deletions smoketests/__init__.py
@@ -216,7 +216,7 @@ def log_records(self, n):
logs = self.spacetime("logs", "--format=json", "-n", str(n), "--", self.database_identity)
return list(map(json.loads, logs.splitlines()))

def publish_module(self, domain=None, *, clear=True, capture_stderr=True):
def publish_module(self, domain=None, *, clear=True, capture_stderr=True, num_replicas=None):
print("publishing module", self.publish_module)
publish_output = self.spacetime(
"publish",
@@ -227,10 +227,11 @@ def publish_module(self, domain=None, *, clear=True, capture_stderr=True):
# because the server address is `node` which doesn't look like `localhost` or `127.0.0.1`
# and so the publish step prompts for confirmation.
"--yes",
*["--num-replicas", f"{num_replicas}"] if num_replicas is not None else [],
capture_stderr=capture_stderr,
)
self.resolved_identity = re.search(r"identity: ([0-9a-fA-F]+)", publish_output)[1]
self.database_identity = domain if domain is not None else self.resolved_identity
self.database_identity = self.resolved_identity

@classmethod
def reset_config(cls):
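For context, a hypothetical smoketest using the new num_replicas keyword argument might look like the sketch below. The class name and flow are illustrative assumptions, not part of this diff; the pattern mirrors the EnableReplication test added in smoketests/tests/replication.py further down.

# Illustrative sketch only (not part of this PR): publish a fresh module
# with a single replica using the new keyword argument.
from .. import Smoketest, random_string

class PublishWithReplicas(Smoketest):
    AUTOPUBLISH = False

    def test_publish_with_replicas(self):
        # Request one replica at publish time; publish_module() forwards this
        # to the CLI as --num-replicas.
        self.publish_module(random_string(), num_replicas=1)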
100 changes: 96 additions & 4 deletions smoketests/tests/replication.py
@@ -1,9 +1,10 @@
from .. import COMPOSE_FILE, Smoketest, requires_docker, spacetime, parse_sql_result
from ..docker import DockerManager

import time
from typing import Callable
import unittest
from typing import Callable
import json

from .. import COMPOSE_FILE, Smoketest, random_string, requires_docker, spacetime, parse_sql_result
from ..docker import DockerManager

def retry(func: Callable, max_retries: int = 3, retry_delay: int = 2):
"""Retry a function on failure with delay."""
@@ -113,6 +114,18 @@ def ensure_leader_health(self, id):
# TODO: Replace with confirmed read.
time.sleep(0.6)

def wait_counter_value(self, id, value, max_attempts=10, delay=1):
"""Wait for the value for `id` in the counter table to reach `value`"""

for _ in range(max_attempts):
rows = self.sql(f"select * from counter where id={id}")
if len(rows) >= 1 and int(rows[0]['value']) >= value:
return
else:
time.sleep(delay)

raise ValueError(f"Counter {id} below {value}")


def fail_leader(self, action='kill'):
"""Force leader failure through either killing or network disconnect."""
@@ -240,6 +253,9 @@ def start(self, id: int, count: int):
def collect_counter_rows(self):
return int_vals(self.cluster.sql("select * from counter"))

def call_control(self, reducer, *args):
self.spacetime("call", "spacetime-control", reducer, *map(json.dumps, args))


class LeaderElection(ReplicationTest):
def test_leader_election_in_loop(self):
@@ -393,3 +409,79 @@ def test_quorum_loss(self):
with self.assertRaises(Exception):
for i in range(1001):
self.call("send_message", "terminal")


class EnableReplication(ReplicationTest):
AUTOPUBLISH = False

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

self.expected_counter_rows = []

def run_counter(self, id, n = 100):
self.start(id, n)
self.cluster.wait_counter_value(id, n)
self.expected_counter_rows.append({"id": id, "value": n})
self.assertEqual(self.collect_counter_rows(), self.expected_counter_rows)

def test_enable_replication(self):
"""Tests enabling and disabling replication"""

self.add_me_as_admin()
name = random_string()
n = 100

self.publish_module(name, num_replicas = 1)
self.cluster.wait_for_leader_change(None)

# start un-replicated
self.run_counter(1, n)
# enable replication
self.call_control("enable_replication", {"Name": name}, 3)
self.run_counter(2, n)
# disable replication
self.call_control("disable_replication", {"Name": name })
self.run_counter(3, n)
# enable it one more time
self.call_control("enable_replication", {"Name": name}, 3)
self.run_counter(4, n)


class EnableReplicationSuspended(ReplicationTest):
AUTOPUBLISH = False

def test_enable_replication_on_suspended_database(self):
"""Tests that we can enable replication on a suspended database"""

self.add_me_as_admin()
name = random_string()

self.publish_module(name, num_replicas = 1)
self.cluster.wait_for_leader_change(None)
self.cluster.ensure_leader_health(1)

id = self.cluster.get_db_id()

self.call_control("suspend_database", {"Name": name})
# Database is now unreachable.
with self.assertRaises(Exception):
self.call("send_message", "hi")

self.call_control("enable_replication", {"Name": name}, 3)
# Still unreachable until we call unsuspend.
with self.assertRaises(Exception):
self.call("send_message", "hi")

self.call_control("unsuspend_database", {"Name": name})
self.cluster.wait_for_leader_change(None)
self.cluster.ensure_leader_health(2)

# We can't directly observe that there are indeed three replicas running,
# so, as a sanity check, inspect the event log.
rows = self.cluster.read_controldb(
f"select message from staged_enable_replication_event where database_id={id}")
self.assertEqual(rows, [
{'message': '"bootstrap requested"'},
{'message': '"bootstrap complete"'},
])