Skip to content

Commit fa9e94e

Browse files
committed
move trial_fn from trial table to experiment table
1 parent 8f34e56 commit fa9e94e

File tree

9 files changed

+89
-122
lines changed

9 files changed

+89
-122
lines changed

python/powerlift/powerlift/bench/benchmark.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import os
1717
import numpy as np
18+
import inspect
1819

1920

2021
class Benchmark:
@@ -81,6 +82,8 @@ def run(
8182
wheel = db.Wheel(name=name, embedded=content)
8283
wheels.append(wheel)
8384

85+
trial_fn = inspect.getsource(trial_run_fn)
86+
8487
self._store.reset()
8588
while self._store.do:
8689
with self._store:
@@ -92,6 +95,7 @@ def run(
9295
shell_install,
9396
pip_install,
9497
script_contents,
98+
trial_fn,
9599
wheels,
96100
)
97101

@@ -185,9 +189,7 @@ def run(
185189
if executor is None:
186190
executor = LocalMachine(self._store)
187191
self._executors.add(executor)
188-
executor.submit(
189-
self._experiment_id, trial_run_fn, pending_trials, timeout=timeout
190-
)
192+
executor.submit(self._experiment_id, pending_trials, timeout=timeout)
191193
return executor
192194

193195
def wait_until_complete(self):

python/powerlift/powerlift/bench/experiment.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ class Experiment:
146146
shell_install: str
147147
pip_install: str
148148
script: str
149+
trial_fn: str
149150
wheels: List[Wheel]
150151
trials: List
151152

python/powerlift/powerlift/bench/store.py

Lines changed: 12 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -44,26 +44,9 @@
4444
import traceback as tb
4545

4646

47-
def _parse_function(src):
48-
src_ast = ast.parse(src)
49-
if isinstance(src_ast, ast.Module) and isinstance(src_ast.body[0], ast.FunctionDef):
50-
return src_ast
51-
return None
52-
53-
54-
def _compile_function(src_ast):
55-
func_name = r"wired_function"
56-
src_ast.body[0].name = func_name
57-
compiled = compile(src_ast, "<string>", "exec")
58-
scope = locals()
59-
exec(compiled, scope, scope)
60-
return locals()[func_name]
61-
62-
6347
MIMETYPE_DF = "application/vnd.interpretml/parquet-df"
6448
MIMETYPE_SERIES = "application/vnd.interpretml/parquet-series"
6549
MIMETYPE_JSON = "application/json"
66-
MIMETYPE_FUNC = "application/vnd.interpretml/function-str"
6750

6851

6952
class BytesParser:
@@ -84,13 +67,6 @@ def deserialize(cls, mimetype, bytes):
8467
return pd.read_parquet(bstream)
8568
elif mimetype == MIMETYPE_SERIES:
8669
return pd.read_parquet(bstream)["Target"]
87-
elif mimetype == MIMETYPE_FUNC:
88-
src = bstream.getvalue().decode("utf-8")
89-
src_ast = _parse_function(src)
90-
if src_ast is None:
91-
raise RuntimeError("Serialized code not valid.")
92-
compiled_func = _compile_function(src_ast)
93-
return compiled_func
9470
else:
9571
return None
9672

@@ -125,13 +101,6 @@ def serialize(cls, obj):
125101
elif isinstance(obj, dict):
126102
bstream.write(json.dumps(obj).encode())
127103
mimetype = MIMETYPE_JSON
128-
elif isinstance(obj, FunctionType):
129-
src = inspect.getsource(obj)
130-
src_ast = _parse_function(src)
131-
if src_ast is None:
132-
raise RuntimeError("Serialized code not valid.")
133-
bstream.write(src.encode("utf-8"))
134-
mimetype = MIMETYPE_FUNC
135104
else:
136105
return None, None
137106

@@ -449,33 +418,6 @@ def end_trial(self, trial_id, errmsg=None):
449418
result = self._session.execute(query, params)
450419
rowcount = result.rowcount
451420

452-
def add_trial_run_fn(self, trial_ids, trial_run_fn):
453-
import sys
454-
455-
mimetype, bstream = BytesParser.serialize(trial_run_fn)
456-
trial_run_fn_asset_orm = db.Asset(
457-
name="trial_run_fn",
458-
description="Serialized trial run function.",
459-
version=sys.version,
460-
is_embedded=True,
461-
embedded=bstream.getvalue(),
462-
mimetype=mimetype,
463-
)
464-
465-
self.reset()
466-
while self.do:
467-
with self:
468-
trial_orms = self._session.query(db.Trial).filter(
469-
db.Trial.id.in_(trial_ids)
470-
)
471-
for trial_orm in trial_orms:
472-
trial_orm.input_assets.append(trial_run_fn_asset_orm)
473-
474-
if trial_orms.first() is not None:
475-
orms = [trial_run_fn_asset_orm]
476-
self._session.bulk_save_objects(orms, return_defaults=True)
477-
return None
478-
479421
def measure_from_db_task(self, task_orm):
480422
self.check_allowed()
481423
from powerlift.bench.experiment import Measure
@@ -581,6 +523,7 @@ def from_db_experiment(self, experiment_orm):
581523
experiment_orm.shell_install,
582524
experiment_orm.pip_install,
583525
experiment_orm.script,
526+
experiment_orm.trial_fn,
584527
wheels,
585528
trials,
586529
)
@@ -631,6 +574,15 @@ def find_task_by_id(self, _id: int):
631574
return None
632575
return self.from_db_task(task_orm)
633576

577+
def get_trial_fn(self, experiment_id) -> str:
578+
self.reset()
579+
while self.do:
580+
with self:
581+
trial_fn = self._session.execute(
582+
text(f"SELECT trial_fn FROM experiment WHERE id={experiment_id}")
583+
).scalar()
584+
return trial_fn
585+
634586
def pick_trial(self, experiment_id, runner_id):
635587
self.reset()
636588
while self.do:
@@ -680,6 +632,7 @@ def create_experiment(
680632
shell_install: str = None,
681633
pip_install: str = None,
682634
script: str = None,
635+
trial_fn: str = None,
683636
wheels=None,
684637
) -> Tuple[int, bool]:
685638
"""Create experiment keyed by name."""
@@ -692,6 +645,7 @@ def create_experiment(
692645
shell_install=shell_install,
693646
pip_install=pip_install,
694647
script=script,
648+
trial_fn=trial_fn,
695649
)
696650

697651
if wheels is not None:

python/powerlift/powerlift/db/schema.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,12 +91,13 @@ class Experiment(Base):
9191
"""The overall experiment, includes access to trials."""
9292

9393
__tablename__ = "experiment"
94-
id = Column(Integer, primary_key=True)
95-
name = Column(String(NAME_LEN), unique=True)
94+
id = Column(Integer, primary_key=True, nullable=False)
95+
name = Column(String(NAME_LEN), unique=True, nullable=False)
9696
description = Column(String(DESCRIPTION_LEN))
9797
shell_install = Column(Text)
9898
pip_install = Column(Text)
99-
script = Column(Text)
99+
script = Column(Text, nullable=False)
100+
trial_fn = Column(Text, nullable=False)
100101

101102
# TODO: consider removing the wheel relationship since it means we
102103
# spend time downloading the wheels each time we query the experiment

python/powerlift/powerlift/executors/azure_ci.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,13 +91,12 @@ def delete_credentials(self):
9191
"""Deletes credentials in object for accessing Azure Resources."""
9292
del self._azure_json
9393

94-
def submit(self, experiment_id, trial_run_fn, trials: List, timeout=None):
94+
def submit(self, experiment_id, trials: List, timeout=None):
9595
from powerlift.run_azure import __main__ as remote_process
9696

9797
uri = (
9898
self._docker_db_uri if self._docker_db_uri is not None else self._store.uri
9999
)
100-
self._store.add_trial_run_fn([x.id for x in trials], trial_run_fn)
101100

102101
n_runners = min(len(trials), self._n_running_containers)
103102
params = (

python/powerlift/powerlift/executors/docker.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,10 @@ def __init__(
7171
wheel_filepaths=wheel_filepaths,
7272
)
7373

74-
def submit(self, experiment_id, trial_run_fn, trials: List, timeout=None):
74+
def submit(self, experiment_id, trials: List, timeout=None):
7575
uri = (
7676
self._docker_db_uri if self._docker_db_uri is not None else self._store.uri
7777
)
78-
self._store.add_trial_run_fn([x.id for x in trials], trial_run_fn)
7978

8079
n_runners = min(
8180
len(trials),

python/powerlift/powerlift/executors/localmachine.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,9 @@ def __del__(self):
4545
if self._pool is not None:
4646
self._pool.close()
4747

48-
def submit(self, experiment_id, trial_run_fn, trials: List, timeout=None):
48+
def submit(self, experiment_id, trials: List, timeout=None):
4949
from powerlift.run import __main__ as runner
5050

51-
self._store.add_trial_run_fn([x.id for x in trials], trial_run_fn)
5251
n_runners = min(
5352
len(trials),
5453
multiprocessing.cpu_count() if self._n_cpus is None else self._n_cpus,

python/powerlift/powerlift/run/__main__.py

Lines changed: 63 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,7 @@ def run_trials(
1414
from powerlift.bench.store import Store
1515
import traceback
1616
from powerlift.executors.base import timed_run
17-
from powerlift.bench.store import MIMETYPE_FUNC, BytesParser
18-
from powerlift.bench.experiment import Store
19-
import subprocess
20-
import tempfile
21-
from pathlib import Path
22-
import sys
17+
import ast
2318

2419
if is_remote:
2520
print_exceptions = True
@@ -29,6 +24,24 @@ def run_trials(
2924
max_attempts = 5
3025

3126
store = Store(db_url, print_exceptions=print_exceptions, max_attempts=max_attempts)
27+
28+
if debug_fn is not None:
29+
trial_run_fn = debug_fn
30+
else:
31+
trial_run_fn = store.get_trial_fn(experiment_id)
32+
trial_run_fn = ast.parse(trial_run_fn)
33+
if not isinstance(trial_run_fn, ast.Module) or not isinstance(
34+
trial_run_fn.body[0], ast.FunctionDef
35+
):
36+
raise RuntimeError("Serialized code not valid.")
37+
38+
func_name = r"wired_function"
39+
trial_run_fn.body[0].name = func_name
40+
compiled = compile(trial_run_fn, "<string>", "exec")
41+
scope = locals()
42+
exec(compiled, scope, scope)
43+
trial_run_fn = locals()[func_name]
44+
3245
while True:
3346
trial_id = store.pick_trial(experiment_id, runner_id)
3447
if trial_id is None:
@@ -40,21 +53,6 @@ def run_trials(
4053
if trial is None:
4154
raise RuntimeError(f"No trial found for id {trial_id}")
4255

43-
# Handle input assets
44-
trial_run_fn = None
45-
for input_asset in trial.input_assets:
46-
if input_asset.mimetype == MIMETYPE_FUNC:
47-
trial_run_fn = BytesParser.deserialize(
48-
MIMETYPE_FUNC, input_asset.embedded
49-
)
50-
else:
51-
continue
52-
if debug_fn is not None:
53-
trial_run_fn = debug_fn
54-
55-
if trial_run_fn is None:
56-
raise RuntimeError("No trial run function found.")
57-
5856
# Run trial
5957
errmsg = None
6058
try:
@@ -72,36 +70,49 @@ def run_trials(
7270

7371

7472
if __name__ == "__main__":
75-
import os
76-
import time
77-
78-
experiment_id = os.getenv("EXPERIMENT_ID")
79-
runner_id = os.getenv("RUNNER_ID")
80-
db_url = os.getenv("DB_URL")
81-
timeout = float(os.getenv("TIMEOUT", 0.0))
82-
raise_exception = True if os.getenv("RAISE_EXCEPTION", False) == "True" else False
83-
run_trials(
84-
experiment_id, runner_id, db_url, timeout, raise_exception, is_remote=True
85-
)
73+
print("STARTING RUNNER")
8674

87-
# below here is Azure specific. Make optional in the future
88-
89-
from azure.identity import ManagedIdentityCredential
90-
from azure.mgmt.containerinstance import ContainerInstanceManagementClient
91-
92-
subscription_id = os.getenv("SUBSCRIPTION_ID")
93-
resource_group_name = os.getenv("RESOURCE_GROUP_NAME")
94-
container_group_name = os.getenv("CONTAINER_GROUP_NAME")
95-
96-
credential = ManagedIdentityCredential()
97-
aci_client = ContainerInstanceManagementClient(credential, subscription_id)
98-
99-
# self-delete the container that we're running on
100-
delete_poller = aci_client.container_groups.begin_delete(
101-
resource_group_name, container_group_name
102-
)
103-
while not delete_poller.done():
104-
print("Waiting to be deleted..")
105-
time.sleep(60)
75+
import time
76+
import traceback
10677

107-
print("THIS LINE SHOULD NEVER EXECUTE SINCE THIS CONTAINER SHOULD BE DELETED.")
78+
try:
79+
import os
80+
81+
experiment_id = os.getenv("EXPERIMENT_ID")
82+
runner_id = os.getenv("RUNNER_ID")
83+
db_url = os.getenv("DB_URL")
84+
timeout = float(os.getenv("TIMEOUT", 0.0))
85+
raise_exception = (
86+
True if os.getenv("RAISE_EXCEPTION", False) == "True" else False
87+
)
88+
run_trials(
89+
experiment_id, runner_id, db_url, timeout, raise_exception, is_remote=True
90+
)
91+
92+
# below here is Azure specific. Make optional in the future
93+
94+
from azure.identity import ManagedIdentityCredential
95+
from azure.mgmt.containerinstance import ContainerInstanceManagementClient
96+
97+
subscription_id = os.getenv("SUBSCRIPTION_ID")
98+
resource_group_name = os.getenv("RESOURCE_GROUP_NAME")
99+
container_group_name = os.getenv("CONTAINER_GROUP_NAME")
100+
101+
credential = ManagedIdentityCredential()
102+
aci_client = ContainerInstanceManagementClient(credential, subscription_id)
103+
104+
# self-delete the container that we're running on
105+
delete_poller = aci_client.container_groups.begin_delete(
106+
resource_group_name, container_group_name
107+
)
108+
while not delete_poller.done():
109+
print("Waiting to be deleted..")
110+
time.sleep(60)
111+
112+
print("THIS LINE SHOULD NEVER EXECUTE SINCE THIS CONTAINER SHOULD BE DELETED.")
113+
except Exception as e:
114+
print("EXCEPTION:")
115+
print("".join(traceback.format_exception(type(e), e, e.__traceback__)))
116+
for _ in range(60 * 60 * 24): # wait 24 hours
117+
time.sleep(1)
118+
print("Unandled exception.")

python/powerlift/powerlift/run_azure/__main__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def run_azure_process(
4747
fi
4848
result=$(psql "$DB_URL" -c "SELECT script FROM Experiment WHERE id='$EXPERIMENT_ID' LIMIT 1;" -t -A)
4949
printf "%s" "$result" > "startup.py"
50+
echo "Running startup.py"
5051
python startup.py
5152
"""
5253

0 commit comments

Comments
 (0)