Skip to content

Commit 128d5d1

Browse files
committed
tests: fix missing stage sets on state
Note that this change will also break API users (reminder that the API is NOT semantically versioned). The stage must be set correctly for op calls to work, primarily for CLI safety, possibly something to disable in API mode...
1 parent 0c48515 commit 128d5d1

File tree

2 files changed

+26
-1
lines changed

2 files changed

+26
-1
lines changed

tests/test_api/test_api_deploys.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from pyinfra.api.connect import connect_all, disconnect_all
33
from pyinfra.api.deploy import add_deploy, deploy
44
from pyinfra.api.operations import run_ops
5+
from pyinfra.api.state import StateStage
56
from pyinfra.operations import server
67

78
from ..paramiko_util import PatchSSHTestCase
@@ -15,6 +16,7 @@ def test_deploy(self):
1516
anotherhost = inventory.get_host("anotherhost")
1617

1718
state = State(inventory, Config())
19+
state.current_stage = StateStage.Prepare
1820

1921
# Enable printing on this test to catch any exceptions in the formatting
2022
state.print_output = True
@@ -74,6 +76,7 @@ def test_nested_deploy(self):
7476
somehost = inventory.get_host("somehost")
7577

7678
state = State(inventory, Config())
79+
state.current_stage = StateStage.Prepare
7780

7881
# Enable printing on this test to catch any exceptions in the formatting
7982
state.print_output = True

tests/test_api/test_api_operations.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from pyinfra.api.exceptions import PyinfraError
1919
from pyinfra.api.operation import OperationMeta, add_op
2020
from pyinfra.api.operations import run_ops
21-
from pyinfra.api.state import StateOperationMeta
21+
from pyinfra.api.state import StateOperationMeta, StateStage
2222
from pyinfra.connectors.util import CommandOutput, OutputLine
2323
from pyinfra.context import ctx_host, ctx_state
2424
from pyinfra.operations import files, python, server
@@ -44,6 +44,7 @@ def test_op(self):
4444
anotherhost = inventory.get_host("anotherhost")
4545

4646
state = State(inventory, Config())
47+
state.current_stage = StateStage.Prepare
4748
state.add_callback_handler(BaseStateCallback())
4849

4950
# Enable printing on this test to catch any exceptions in the formatting
@@ -122,6 +123,7 @@ def test_file_upload_op(self):
122123
inventory = make_inventory()
123124

124125
state = State(inventory, Config())
126+
state.current_stage = StateStage.Prepare
125127
connect_all(state)
126128

127129
# Test normal
@@ -197,6 +199,7 @@ def test_file_download_op(self):
197199
inventory = make_inventory()
198200

199201
state = State(inventory, Config())
202+
state.current_stage = StateStage.Prepare
200203
connect_all(state)
201204

202205
with patch("pyinfra.operations.files.os.path.isfile", lambda *args, **kwargs: True):
@@ -236,6 +239,7 @@ def test_file_download_op(self):
236239
def test_function_call_op(self):
237240
inventory = make_inventory()
238241
state = State(inventory, Config())
242+
state.current_stage = StateStage.Prepare
239243
connect_all(state)
240244

241245
is_called = []
@@ -257,6 +261,7 @@ def mocked_function(*args, **kwargs):
257261
def test_run_once_serial_op(self):
258262
inventory = make_inventory()
259263
state = State(inventory, Config())
264+
state.current_stage = StateStage.Prepare
260265
connect_all(state)
261266

262267
# Add a run once op
@@ -280,6 +285,7 @@ def test_run_once_serial_op(self):
280285
def test_rsync_op(self):
281286
inventory = make_inventory(hosts=("somehost",))
282287
state = State(inventory, Config())
288+
state.current_stage = StateStage.Prepare
283289
connect_all(state)
284290

285291
add_op(state, files.rsync, "src", "dest", _sudo=True, _sudo_user="root")
@@ -304,6 +310,7 @@ def test_rsync_op(self):
304310
def test_rsync_op_with_strict_host_key_checking_disabled(self):
305311
inventory = make_inventory(hosts=(("somehost", {"ssh_strict_host_key_checking": "no"}),))
306312
state = State(inventory, Config())
313+
state.current_stage = StateStage.Prepare
307314
connect_all(state)
308315

309316
add_op(state, files.rsync, "src", "dest", _sudo=True, _sudo_user="root")
@@ -338,6 +345,7 @@ def test_rsync_op_with_strict_host_key_checking_disabled_and_custom_config_file(
338345
)
339346
)
340347
state = State(inventory, Config())
348+
state.current_stage = StateStage.Prepare
341349
connect_all(state)
342350

343351
add_op(state, files.rsync, "src", "dest", _sudo=True, _sudo_user="root")
@@ -365,6 +373,7 @@ def test_rsync_op_with_sanitized_custom_config_file(self):
365373
hosts=(("somehost", {"ssh_config_file": "/home/me/ssh_test_config && echo hi"}),)
366374
)
367375
state = State(inventory, Config())
376+
state.current_stage = StateStage.Prepare
368377
connect_all(state)
369378

370379
add_op(state, files.rsync, "src", "dest", _sudo=True, _sudo_user="root")
@@ -389,6 +398,7 @@ def test_rsync_op_with_sanitized_custom_config_file(self):
389398
def test_rsync_op_failure(self):
390399
inventory = make_inventory(hosts=("somehost",))
391400
state = State(inventory, Config())
401+
state.current_stage = StateStage.Prepare
392402
connect_all(state)
393403

394404
with patch("pyinfra.connectors.ssh.which", lambda x: None):
@@ -401,6 +411,7 @@ def test_op_cannot_change_execution_kwargs(self):
401411
inventory = make_inventory()
402412

403413
state = State(inventory, Config())
414+
state.current_stage = StateStage.Prepare
404415

405416
class NoSetDefaultDict(defaultdict):
406417
def setdefault(self, key, _):
@@ -422,6 +433,7 @@ class TestNestedOperationsApi(PatchSSHTestCase):
422433
def test_nested_op_api(self):
423434
inventory = make_inventory()
424435
state = State(inventory, Config())
436+
state.current_stage = StateStage.Prepare
425437

426438
connect_all(state)
427439

@@ -458,6 +470,7 @@ class TestOperationFailures(PatchSSHTestCase):
458470
def test_full_op_fail(self):
459471
inventory = make_inventory()
460472
state = State(inventory, Config())
473+
state.current_stage = StateStage.Prepare
461474
connect_all(state)
462475

463476
add_op(state, server.shell, 'echo "hi"')
@@ -484,6 +497,7 @@ def test_full_op_fail(self):
484497
def test_ignore_errors_op_fail(self):
485498
inventory = make_inventory()
486499
state = State(inventory, Config())
500+
state.current_stage = StateStage.Prepare
487501
connect_all(state)
488502

489503
add_op(state, server.shell, 'echo "hi"', _ignore_errors=True)
@@ -514,6 +528,7 @@ class TestOperationOrdering(PatchSSHTestCase):
514528
def test_cli_op_line_numbers(self):
515529
inventory = make_inventory()
516530
state = State(inventory, Config())
531+
state.current_stage = StateStage.Prepare
517532
connect_all(state)
518533

519534
state.current_deploy_filename = __file__
@@ -560,6 +575,7 @@ def test_cli_op_line_numbers(self):
560575
def test_api_op_line_numbers(self):
561576
inventory = make_inventory()
562577
state = State(inventory, Config())
578+
state.current_stage = StateStage.Prepare
563579
connect_all(state)
564580

565581
another_host = inventory.get_host("anotherhost")
@@ -590,6 +606,7 @@ def test_basic_retry_behavior(self, fake_run_command):
590606
# Create inventory with just one host to simplify testing
591607
inventory = make_inventory(hosts=("somehost",))
592608
state = State(inventory, Config())
609+
state.current_stage = StateStage.Prepare
593610
connect_all(state)
594611

595612
# Add operation with retry settings
@@ -649,6 +666,7 @@ def test_retry_max_attempts_failure(self, fake_run_command):
649666
"""
650667
inventory = make_inventory(hosts=("somehost",))
651668
state = State(inventory, Config())
669+
state.current_stage = StateStage.Prepare
652670
connect_all(state)
653671

654672
# Add operation with retry settings
@@ -698,6 +716,7 @@ def test_retry_until_condition(self, fake_sleep, fake_run_command):
698716
# Setup inventory and state using the utility function
699717
inventory = make_inventory(hosts=("somehost",))
700718
state = State(inventory, Config())
719+
state.current_stage = StateStage.Prepare
701720
connect_all(state)
702721

703722
# Create a counter to track retry_until calls
@@ -757,6 +776,7 @@ def test_retry_delay(self, fake_sleep, fake_run_command):
757776
"""
758777
inventory = make_inventory(hosts=("somehost",))
759778
state = State(inventory, Config())
779+
state.current_stage = StateStage.Prepare
760780
connect_all(state)
761781

762782
retry_delay = 5
@@ -799,6 +819,7 @@ def test_retry_until_with_error_handling(self, fake_sleep, fake_run_command):
799819
"""
800820
inventory = make_inventory(hosts=("somehost",))
801821
state = State(inventory, Config())
822+
state.current_stage = StateStage.Prepare
802823
connect_all(state)
803824

804825
# Create a retry_until function that raises an exception
@@ -842,6 +863,7 @@ def test_retry_until_with_complex_output_parsing(self, fake_sleep, fake_run_comm
842863
"""
843864
inventory = make_inventory(hosts=("somehost",))
844865
state = State(inventory, Config())
866+
state.current_stage = StateStage.Prepare
845867
connect_all(state)
846868

847869
# Track what output we've seen

0 commit comments

Comments
 (0)