Skip to content

Commit 8a00db3

Browse files
committed
make freeze optional
1 parent 89246f2 commit 8a00db3

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

exts/nav_tasks/nav_tasks/mdp/actions/navigation_actions.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@ def __init__(self, cfg: NavigationSE2ActionCfg, env: ManagerBasedRLEnv):
3030
# load policies
3131
file_bytes = read_file(self.cfg.low_level_policy_file)
3232
self.low_level_policy = torch.jit.load(file_bytes, map_location=self.device)
33-
self.low_level_policy = torch.jit.freeze(self.low_level_policy.eval())
33+
self.low_level_policy.eval()
34+
if self.cfg.freeze_low_level_policy:
35+
self.low_level_policy = torch.jit.freeze(self.low_level_policy)
3436

3537
# prepare joint position actions
3638
if not isinstance(self.cfg.low_level_action, list):

exts/nav_tasks/nav_tasks/mdp/actions/navigation_actions_cfg.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,11 @@ class NavigationSE2ActionCfg(ActionTermCfg):
2828
low_level_policy_file: str = MISSING
2929
"""Path to the low level policy file."""
3030

31+
freeze_low_level_policy: bool = True
32+
"""Whether to freeze the low level policy.
33+
34+
Can improve performance but will also eliminate possible functions such as `reset`."""
35+
3136
low_level_obs_group: str = "low_level_policy"
3237
"""Observation group of the low level policy."""
3338

0 commit comments

Comments
 (0)