Skip to content

Commit 20e08b8

Browse files
authored
feat: video
* replace multiprocessing with pyav
* add config.RECORD_WINDOW_DATA
* video_write_q
* fix get_timestamp; extract_frames_to_pil_images with pyav
* add video.py; ActionEvent.original_timestamp
* use global SCT in get_monitor_dims
* fix tests
* fix window._windows.get_active_window_state (missing type)
* add tests/openadapt/test_video.py
* flake8
* black
* poetry lock
1 parent 7759996 commit 20e08b8

File tree

17 files changed

+2256
-1642
lines changed

17 files changed

+2256
-1642
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,6 @@ src
3030

3131
# MacOS file
3232
.DS_Store
33+
34+
*.pyc
35+
*.pt

openadapt/config.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
"OPENAI_API_KEY": "<set your api key in .env>",
3434
# "OPENAI_MODEL_NAME": "gpt-4",
3535
"OPENAI_MODEL_NAME": "gpt-3.5-turbo",
36+
"RECORD_WINDOW_DATA": False,
3637
# may incur significant performance penalty
3738
"RECORD_READ_ACTIVE_ELEMENT_STATE": False,
3839
# TODO: remove?
@@ -99,7 +100,6 @@
99100
"key_vk",
100101
"children",
101102
],
102-
"PLOT_PERFORMANCE": True,
103103
# VISUALIZATION CONFIGURATIONS
104104
"VISUALIZE_DARK_MODE": False,
105105
"VISUALIZE_RUN_NATIVELY": True,
@@ -111,6 +111,9 @@
111111
"SAVE_SCREENSHOT_DIFF": False,
112112
"SPACY_MODEL_NAME": "en_core_web_trf",
113113
"PRIVATE_AI_API_KEY": "<set your api key in .env>",
114+
"RECORD_VIDEO": False,
115+
"RECORD_IMAGES": True,
116+
"VIDEO_PIXEL_FORMAT": "rgb24",
114117
}
115118

116119
# each string in STOP_STRS should only contain strings

openadapt/db/crud.py

Lines changed: 45 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -70,12 +70,12 @@ def _insert(
7070

7171

7272
def insert_action_event(
73-
recording_timestamp: int, event_timestamp: int, event_data: dict[str, Any]
73+
recording_timestamp: float, event_timestamp: int, event_data: dict[str, Any]
7474
) -> None:
7575
"""Insert an action event into the database.
7676
7777
Args:
78-
recording_timestamp (int): The timestamp of the recording.
78+
recording_timestamp (float): The timestamp of the recording.
7979
event_timestamp (int): The timestamp of the event.
8080
event_data (dict): The data of the event.
8181
"""
@@ -88,12 +88,12 @@ def insert_action_event(
8888

8989

9090
def insert_screenshot(
91-
recording_timestamp: int, event_timestamp: int, event_data: dict[str, Any]
91+
recording_timestamp: float, event_timestamp: int, event_data: dict[str, Any]
9292
) -> None:
9393
"""Insert a screenshot into the database.
9494
9595
Args:
96-
recording_timestamp (int): The timestamp of the recording.
96+
recording_timestamp (float): The timestamp of the recording.
9797
event_timestamp (int): The timestamp of the event.
9898
event_data (dict): The data of the event.
9999
"""
@@ -106,14 +106,14 @@ def insert_screenshot(
106106

107107

108108
def insert_window_event(
109-
recording_timestamp: int,
109+
recording_timestamp: float,
110110
event_timestamp: int,
111111
event_data: dict[str, Any],
112112
) -> None:
113113
"""Insert a window event into the database.
114114
115115
Args:
116-
recording_timestamp (int): The timestamp of the recording.
116+
recording_timestamp (float): The timestamp of the recording.
117117
event_timestamp (int): The timestamp of the event.
118118
event_data (dict): The data of the event.
119119
"""
@@ -126,15 +126,15 @@ def insert_window_event(
126126

127127

128128
def insert_perf_stat(
129-
recording_timestamp: int,
129+
recording_timestamp: float,
130130
event_type: str,
131131
start_time: float,
132132
end_time: float,
133133
) -> None:
134134
"""Insert an event performance stat into the database.
135135
136136
Args:
137-
recording_timestamp (int): The timestamp of the recording.
137+
recording_timestamp (float): The timestamp of the recording.
138138
event_type (str): The type of the event.
139139
start_time (float): The start time of the event.
140140
end_time (float): The end time of the event.
@@ -148,11 +148,11 @@ def insert_perf_stat(
148148
_insert(event_perf_stat, PerformanceStat, performance_stats)
149149

150150

151-
def get_perf_stats(recording_timestamp: int) -> list[PerformanceStat]:
151+
def get_perf_stats(recording_timestamp: float) -> list[PerformanceStat]:
152152
"""Get performance stats for a given recording.
153153
154154
Args:
155-
recording_timestamp (int): The timestamp of the recording.
155+
recording_timestamp (float): The timestamp of the recording.
156156
157157
Returns:
158158
list[PerformanceStat]: A list of performance stats for the recording.
@@ -166,7 +166,7 @@ def get_perf_stats(recording_timestamp: int) -> list[PerformanceStat]:
166166

167167

168168
def insert_memory_stat(
169-
recording_timestamp: int, memory_usage_bytes: int, timestamp: int
169+
recording_timestamp: float, memory_usage_bytes: int, timestamp: int
170170
) -> None:
171171
"""Insert memory stat into db."""
172172
memory_stat = {
@@ -177,7 +177,7 @@ def insert_memory_stat(
177177
_insert(memory_stat, MemoryStat, memory_stats)
178178

179179

180-
def get_memory_stats(recording_timestamp: int) -> None:
180+
def get_memory_stats(recording_timestamp: float) -> None:
181181
"""Return memory stats for a given recording."""
182182
return (
183183
db.query(MemoryStat)
@@ -196,7 +196,7 @@ def insert_recording(recording_data: Recording) -> Recording:
196196
return db_obj
197197

198198

199-
def delete_recording(recording_timestamp: int) -> None:
199+
def delete_recording(recording_timestamp: float) -> None:
200200
"""Remove the recording from the db."""
201201
db.query(Recording).filter(Recording.timestamp == recording_timestamp).delete()
202202
db.commit()
@@ -241,12 +241,12 @@ def get_recording(timestamp: int) -> Recording:
241241
return db.query(Recording).filter(Recording.timestamp == timestamp).first()
242242

243243

244-
def _get(table: BaseModel, recording_timestamp: int) -> list[BaseModel]:
244+
def _get(table: BaseModel, recording_timestamp: float) -> list[BaseModel]:
245245
"""Retrieve records from the database table based on the recording timestamp.
246246
247247
Args:
248248
table (BaseModel): The database table to query.
249-
recording_timestamp (int): The recording timestamp to filter the records.
249+
recording_timestamp (float): The recording timestamp to filter the records.
250250
251251
Returns:
252252
list[BaseModel]: A list of records retrieved from the database table,
@@ -420,3 +420,33 @@ def new_session() -> None:
420420
if db:
421421
db.close()
422422
db = Session()
423+
424+
425+
def update_video_start_time(
    recording_timestamp: float, video_start_time: float
) -> None:
    """Set ``video_start_time`` on the recording with the given timestamp.

    Looks up the first ``Recording`` whose ``timestamp`` equals
    ``recording_timestamp``. When one is found, its ``video_start_time`` is
    overwritten, the session is committed, and an info message is logged.
    When none is found, an error is logged and nothing is changed.

    Args:
        recording_timestamp (float): The timestamp of the recording to update.
        video_start_time (float): The new video start time to set.
    """
    # Look up the target recording by its timestamp.
    matching_recordings = db.query(Recording).filter(
        Recording.timestamp == recording_timestamp
    )
    target = matching_recordings.first()

    if target:
        # Persist the new start time before reporting success.
        target.video_start_time = video_start_time
        db.commit()

        logger.info(
            f"Updated video start time for recording {recording_timestamp} to"
            f" {video_start_time}."
        )
        return

    # Nothing matched: report it and leave the database untouched.
    logger.error(f"No recording found with timestamp {recording_timestamp}.")

openadapt/deprecated/visualize.py

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,16 @@
1111
from bokeh.models.widgets import Div
1212
from loguru import logger
1313
from tqdm import tqdm
14+
import fire
1415

15-
from openadapt import config
16+
from openadapt import config, video
1617
from openadapt.db.crud import get_latest_recording
1718
from openadapt.events import get_events
1819
from openadapt.models import Recording
1920
from openadapt.privacy.providers.presidio import PresidioScrubbingProvider
2021
from openadapt.utils import (
2122
EMPTY,
23+
compute_diff,
2224
configure_logging,
2325
display_event,
2426
evenly_spaced,
@@ -184,11 +186,17 @@ def dict2html(
184186

185187

186188
@logger.catch
187-
def main(recording: Recording = None) -> bool:
189+
def main(
190+
recording: Recording = None,
191+
diff_video: bool = False,
192+
cleanup: bool = True,
193+
) -> bool:
188194
"""Visualize a recording.
189195
190196
Args:
191197
recording (Recording, optional): The recording to visualize.
198+
diff_video (bool): Whether to diff Screenshots against video frames.
199+
cleanup (bool): Whether to remove the HTML file after it is displayed.
192200
193201
Returns:
194202
bool: True if visualization was successful, None otherwise.
@@ -199,7 +207,8 @@ def main(recording: Recording = None) -> bool:
199207
recording = get_latest_recording()
200208
if SCRUB:
201209
scrub.scrub_text(recording.task_description)
202-
logger.debug(f"{recording=}")
210+
logger.info(f"{recording=}")
211+
logger.info(f"{diff_video=}")
203212

204213
meta = {}
205214
action_events = get_events(recording, process=PROCESS_EVENTS, meta=meta)
@@ -233,6 +242,14 @@ def main(recording: Recording = None) -> bool:
233242
]
234243
logger.info(f"{len(action_events)=}")
235244

245+
if diff_video:
246+
video_file_name = video.get_video_file_name(recording.timestamp)
247+
timestamps = [
248+
action_event.screenshot.timestamp - recording.video_start_time
249+
for action_event in action_events
250+
]
251+
frames = video.extract_frames(video_file_name, timestamps)
252+
236253
num_events = (
237254
min(MAX_EVENTS, len(action_events))
238255
if MAX_EVENTS is not None
@@ -248,9 +265,24 @@ def main(recording: Recording = None) -> bool:
248265
for idx, action_event in enumerate(action_events):
249266
if idx == MAX_EVENTS:
250267
break
251-
image = display_event(action_event)
252-
diff = display_event(action_event, diff=True)
253-
mask = action_event.screenshot.diff_mask
268+
269+
try:
270+
image = display_event(action_event)
271+
except TypeError as exc:
272+
# https://github.com/moses-palmer/pynput/issues/481
273+
logger.warning(exc)
274+
continue
275+
276+
if diff_video:
277+
frame_image = frames[idx]
278+
diff_image = compute_diff(frame_image, action_event.screenshot.image)
279+
280+
# TODO: rename
281+
diff = frame_image
282+
mask = diff_image
283+
else:
284+
diff = display_event(action_event, diff=True)
285+
mask = action_event.screenshot.diff_mask
254286

255287
if SCRUB:
256288
image = scrub.scrub_image(image)
@@ -323,14 +355,15 @@ def main(recording: Recording = None) -> bool:
323355
)
324356
)
325357

326-
def cleanup() -> None:
358+
def _cleanup() -> None:
327359
os.remove(fname_out)
328360
removed = not os.path.exists(fname_out)
329361
logger.info(f"{removed=}")
330362

331-
Timer(1, cleanup).start()
363+
if cleanup:
364+
Timer(1, _cleanup).start()
332365
return True
333366

334367

335368
if __name__ == "__main__":
336-
main()
369+
fire.Fire(main)

openadapt/events.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,18 @@ def get_merged_events(
450450
)
451451

452452

453+
def remove_invalid_keyboard_events(
    events: list[models.ActionEvent],
) -> list[models.ActionEvent]:
    """Return the events whose ``key`` does not stringify to ``"<0>"``.

    The input list is not modified; a new list is built in the original order.

    Args:
        events (list[models.ActionEvent]): The events to filter.

    Returns:
        list[models.ActionEvent]: The events with ``str(event.key) != "<0>"``.
    """
    # Keys rendered as "<0>" are treated as invalid; see
    # https://github.com/moses-palmer/pynput/issues/481
    invalid_key_repr = "<0>"

    valid_events = []
    for candidate in events:
        if str(candidate.key) != invalid_key_repr:
            valid_events.append(candidate)
    return valid_events
463+
464+
453465
def merge_consecutive_keyboard_events(
454466
events: list[models.ActionEvent],
455467
group_named_keys: bool = KEYBOARD_EVENTS_MERGE_GROUP_NAMED_KEYS,
@@ -717,6 +729,7 @@ def process_events(
717729
f"{num_total=}"
718730
)
719731
process_fns = [
732+
remove_invalid_keyboard_events,
720733
merge_consecutive_keyboard_events,
721734
merge_consecutive_mouse_move_events,
722735
merge_consecutive_mouse_scroll_events,

openadapt/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ class Recording(db.Base):
4444
double_click_distance_pixels = sa.Column(sa.Numeric(asdecimal=False))
4545
platform = sa.Column(sa.String)
4646
task_description = sa.Column(sa.String)
47+
video_start_time = sa.Column(ForceFloat)
4748

4849
action_events = sa.orm.relationship(
4950
"ActionEvent",

0 commit comments

Comments
 (0)