visualize.py (54 changes: 39 additions & 15 deletions)
@@ -35,8 +35,13 @@ def rollout_episode(
     t_counter = 0
     reward_seq = []
     obs_seq = []
+    state_seq = []  # Also collect states
+
+    # Add initial observation and state (after reset)
+    obs_seq.append(timestep.observation)
+    state_seq.append(timestep.state)
+
     while True:
-        obs_seq.append(timestep.observation)
         rng, rng_act, rng_step = jax.random.split(rng, 3)
         if model is not None:
             obs = obs_to_model_input(timestep.observation, prev_actions, rl_config)
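
Note: collecting the observation and state once right after reset and once after
every step keeps obs_seq and state_seq aligned, one entry per rendered frame. A
minimal sketch of the resulting invariant, assuming a reset that returns the same
timestep pytree used above (the reset call and the n_steps name are illustrative,
not part of this diff):

    timestep = env.reset(env_cfgs, rng)   # hypothetical reset signature
    obs_seq = [timestep.observation]      # entry 0: the world right after reset
    state_seq = [timestep.state]
    for _ in range(n_steps):
        timestep = env.step(timestep, action, rng_step)
        obs_seq.append(timestep.observation)
        state_seq.append(timestep.state)
    # n_steps + 1 entries each; obs_seq[i] and state_seq[i] describe the same frame.
    assert len(obs_seq) == len(state_seq)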
@@ -51,19 +56,31 @@ def rollout_episode(
         timestep = env.step(
             timestep, wrap_action(action, env.batch_cfg.action_type), rng_step
         )
 
+        t_counter += 1
+
+        # COLLECT OBSERVATION AFTER STEP (includes soil mechanics changes)
+        obs_seq.append(timestep.observation)
+        state_seq.append(timestep.state)
+
+        if t_counter <= 3:
+            action_map = timestep.observation['action_map']
+            state_action_map = timestep.state.world.action_map.map
+
+            # Compare first environment
+            obs_dirt = action_map[0][action_map[0] > 0] if action_map.shape[0] > 0 else []
+            state_dirt = state_action_map[0][state_action_map[0] > 0] if state_action_map.shape[0] > 0 else []
+
         reward_seq.append(timestep.reward)
         print(t_counter, timestep.reward, action, timestep.done)
         print(10 * "=")
-        t_counter += 1
         # if done or t_counter == max_frames:
         #     break
         # else:
 
         if jnp.all(timestep.done).item() or t_counter == max_frames:
             break
         # env_state = next_env_state
         # obs = next_obs
     print(f"Terra - Steps: {t_counter}, Return: {np.sum(reward_seq)}")
-    return obs_seq, np.cumsum(reward_seq)
+    return obs_seq, np.cumsum(reward_seq), state_seq
 
 
 def update_render(seq, env: TerraEnvBatch, frame):
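
Note: the t_counter <= 3 block above extracts the non-zero ("dirt") cells of the
observation's action_map and of the raw state map for the first environment, but
obs_dirt and state_dirt are never printed or compared. A sketch of the comparison
the block appears to be building toward, assuming the same pytree layout as the
diff (the helper name is hypothetical, not part of the PR):

    import jax.numpy as jnp

    def count_action_map_mismatches(timestep):
        # Cells where the observation's action_map disagrees with the raw
        # state map, for the first environment in the batch.
        obs_map = timestep.observation['action_map'][0]
        state_map = timestep.state.world.action_map.map[0]
        return int(jnp.sum(obs_map != state_map))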
@@ -93,28 +110,28 @@ def update_render(seq, env: TerraEnvBatch, frame):
"-nx",
"--n_envs_x",
type=int,
default=1,
default=3,
help="Number of environments on x.",
)
parser.add_argument(
"-ny",
"--n_envs_y",
type=int,
default=1,
default=3,
help="Number of environments on y.",
)
parser.add_argument(
"-steps",
"--n_steps",
type=int,
default=10,
default=100,
help="Number of steps.",
)
parser.add_argument(
"-o",
"--out_path",
type=str,
default=".",
default="./visualize.gif",
help="Output path.",
)
parser.add_argument(
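
Note: with these new defaults, running the script without flags rolls out a 3x3
grid of environments for 100 steps and writes the GIF to ./visualize.gif, instead
of a single environment for 10 steps with "." as the output path; the old
behaviour remains available via -nx 1 -ny 1 -steps 10 -o .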
Expand Down Expand Up @@ -152,7 +169,7 @@ def update_render(seq, env: TerraEnvBatch, frame):
     model_params = log["model"]
     # replicated_params = log['network']
     # model_params = jax.tree_map(lambda x: x[0], replicated_params)
-    obs_seq, cum_rewards = rollout_episode(
+    obs_seq, cum_rewards, state_seq = rollout_episode(
         env,
         model,
         model_params,
@@ -162,7 +179,14 @@ def update_render(seq, env: TerraEnvBatch, frame):
         seed=args.seed,
     )
 
-    for o in tqdm(obs_seq, desc="Rendering"):
-        env.terra_env.render_obs_pygame(o, generate_gif=True)
+    for i, o in enumerate(tqdm(obs_seq, desc="Rendering")):
+        # Try using state action_map instead of observation action_map
+        if i < len(state_seq):
+            # Create modified observation with raw state action_map
+            modified_obs = dict(o)
+            modified_obs['action_map'] = state_seq[i].world.action_map.map
+            env.terra_env.render_obs_pygame(modified_obs, generate_gif=True)
+        else:
+            env.terra_env.render_obs_pygame(o, generate_gif=True)
 
     env.terra_env.rendering_engine.create_gif(args.out_path)
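
Note: dict(o) makes a shallow copy, so overwriting the 'action_map' key points the
copy at the raw state map without mutating the original observation and without
duplicating any of the other arrays. A minimal illustration with hypothetical
placeholder arrays:

    import jax.numpy as jnp

    o = {'action_map': jnp.zeros((2, 2)), 'target_map': jnp.ones((2, 2))}
    modified = dict(o)                    # shallow copy: keys copied, values shared
    modified['action_map'] = jnp.full((2, 2), 7.0)
    assert modified['action_map'] is not o['action_map']  # replaced only in the copy
    assert modified['target_map'] is o['target_map']      # other entries still shared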