diff --git a/visualize.py b/visualize.py
index 0aba3d5..687f19e 100644
--- a/visualize.py
+++ b/visualize.py
@@ -35,8 +35,13 @@ def rollout_episode(
     t_counter = 0
     reward_seq = []
     obs_seq = []
+    state_seq = []  # Also collect states
+
+    # Record the initial observation and state (right after reset)
+    obs_seq.append(timestep.observation)
+    state_seq.append(timestep.state)
+
     while True:
-        obs_seq.append(timestep.observation)
         rng, rng_act, rng_step = jax.random.split(rng, 3)
         if model is not None:
             obs = obs_to_model_input(timestep.observation, prev_actions, rl_config)
@@ -51,19 +56,31 @@
         timestep = env.step(
             timestep, wrap_action(action, env.batch_cfg.action_type), rng_step
         )
+
+        t_counter += 1
+
+        # Collect the post-step observation and state (includes soil mechanics changes)
+        obs_seq.append(timestep.observation)
+        state_seq.append(timestep.state)
+
+        if t_counter <= 3:
+            action_map = timestep.observation["action_map"]
+            state_action_map = timestep.state.world.action_map.map
+
+            # Compare dug cells (> 0) of the first environment: observation vs. raw state
+            obs_dirt = action_map[0][action_map[0] > 0] if action_map.shape[0] > 0 else []
+            state_dirt = state_action_map[0][state_action_map[0] > 0] if state_action_map.shape[0] > 0 else []
+            print(f"obs dirt: {obs_dirt}, state dirt: {state_dirt}")
+
+
         reward_seq.append(timestep.reward)
         print(t_counter, timestep.reward, action, timestep.done)
         print(10 * "=")
-        t_counter += 1
-        # if done or t_counter == max_frames:
-        #     break
-        # else:
+        if jnp.all(timestep.done).item() or t_counter == max_frames:
             break
-        # env_state = next_env_state
-        # obs = next_obs
 
     print(f"Terra - Steps: {t_counter}, Return: {np.sum(reward_seq)}")
-    return obs_seq, np.cumsum(reward_seq)
+    return obs_seq, np.cumsum(reward_seq), state_seq
 
 
 def update_render(seq, env: TerraEnvBatch, frame):
@@ -93,28 +110,28 @@
         "-nx",
         "--n_envs_x",
         type=int,
-        default=1,
+        default=3,
         help="Number of environments on x.",
     )
     parser.add_argument(
         "-ny",
         "--n_envs_y",
         type=int,
-        default=1,
+        default=3,
         help="Number of environments on y.",
     )
     parser.add_argument(
         "-steps",
         "--n_steps",
         type=int,
-        default=10,
+        default=100,
         help="Number of steps.",
     )
     parser.add_argument(
         "-o",
         "--out_path",
         type=str,
-        default=".",
+        default="./visualize.gif",
         help="Output path.",
     )
     parser.add_argument(
@@ -152,7 +169,7 @@
     model_params = log["model"]
     # replicated_params = log['network']
     # model_params = jax.tree_map(lambda x: x[0], replicated_params)
-    obs_seq, cum_rewards = rollout_episode(
+    obs_seq, cum_rewards, state_seq = rollout_episode(
         env,
         model,
         model_params,
@@ -162,7 +179,14 @@
         seed=args.seed,
     )
 
-    for o in tqdm(obs_seq, desc="Rendering"):
-        env.terra_env.render_obs_pygame(o, generate_gif=True)
+    for i, o in enumerate(tqdm(obs_seq, desc="Rendering")):
+        # Render from the state action_map instead of the observation action_map
+        if i < len(state_seq):
+            # Create a modified observation carrying the raw state action_map
+            modified_obs = dict(o)
+            modified_obs["action_map"] = state_seq[i].world.action_map.map
+            env.terra_env.render_obs_pygame(modified_obs, generate_gif=True)
+        else:
+            env.terra_env.render_obs_pygame(o, generate_gif=True)
 
     env.terra_env.rendering_engine.create_gif(args.out_path)
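
Note: `obs_seq` and `state_seq` are now appended in lockstep (once after reset, once per step), so they always have equal length and the `else` fallback in the render loop is a pure safety guard. The `t_counter <= 3` block computes `obs_dirt` and `state_dirt` only for debugging; if the goal is to confirm that the observation's `action_map` agrees with the raw state map, a standalone check along the following lines may be clearer. This is a sketch only, assuming the `timestep` pytree layout used in the diff; `maps_match` is a hypothetical helper, not part of visualize.py.

    import jax.numpy as jnp

    def maps_match(timestep) -> bool:
        # Compare the observation's action_map with the raw map stored in the
        # simulation state. A mismatch means the observation pipeline applies
        # extra processing (e.g. normalization or clipping) before rendering.
        obs_map = timestep.observation["action_map"]
        state_map = timestep.state.world.action_map.map
        return bool(jnp.array_equal(obs_map, state_map))

    # Example usage inside the rollout loop (sketch):
    # assert maps_match(timestep), "observation action_map diverges from state map"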