Skip to content

Commit 266b9bf

Browse files
authored
feat(models): add ActionEvent.prompt_for_description (#933)
* add ActionEvent.prompt_for_description * add display_event(darken_outside, display_text) * add experiments/describe_action.py * default RECORD_AUDIO to false * use joinedload in get_latest_recording * set anthropic.py MODEL_NAME to claude-3-5-sonnet-20241022 * support PNG in utils.image2utf8 * python>=3.10,<3.12
1 parent e595dd3 commit 266b9bf

File tree

9 files changed

+288
-51
lines changed

9 files changed

+288
-51
lines changed

experiments/describe_action.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
"""Generate action descriptions."""
2+
3+
from pprint import pformat
4+
5+
from loguru import logger
6+
import cv2
7+
import numpy as np
8+
9+
from openadapt.db import crud
10+
11+
12+
def embed_description(
13+
image: np.ndarray,
14+
description: str,
15+
x: int = None,
16+
y: int = None,
17+
) -> np.ndarray:
18+
"""Embed a description into an image at the specified location.
19+
20+
Args:
21+
image (np.ndarray): The image to annotate.
22+
description (str): The text to embed.
23+
x (int, optional): The x-coordinate. Defaults to None (centered).
24+
y (int, optional): The y-coordinate. Defaults to None (centered).
25+
26+
Returns:
27+
np.ndarray: The annotated image.
28+
"""
29+
font = cv2.FONT_HERSHEY_SIMPLEX
30+
font_scale = 1
31+
font_color = (255, 255, 255) # White
32+
line_type = 1
33+
34+
# Split description into multiple lines
35+
max_width = 60 # Maximum characters per line
36+
words = description.split()
37+
lines = []
38+
current_line = []
39+
for word in words:
40+
if len(" ".join(current_line + [word])) <= max_width:
41+
current_line.append(word)
42+
else:
43+
lines.append(" ".join(current_line))
44+
current_line = [word]
45+
if current_line:
46+
lines.append(" ".join(current_line))
47+
48+
# Default to center if coordinates are not provided
49+
if x is None or y is None:
50+
x = image.shape[1] // 2
51+
y = image.shape[0] // 2
52+
53+
# Draw semi-transparent background and text
54+
for i, line in enumerate(lines):
55+
text_size, _ = cv2.getTextSize(line, font, font_scale, line_type)
56+
text_x = max(0, min(x - text_size[0] // 2, image.shape[1] - text_size[0]))
57+
text_y = y + i * 20
58+
59+
# Draw background
60+
cv2.rectangle(
61+
image,
62+
(text_x - 15, text_y - 25),
63+
(text_x + text_size[0] + 15, text_y + 15),
64+
(0, 0, 0),
65+
-1,
66+
)
67+
68+
# Draw text
69+
cv2.putText(
70+
image,
71+
line,
72+
(text_x, text_y),
73+
font,
74+
font_scale,
75+
font_color,
76+
line_type,
77+
)
78+
79+
return image
80+
81+
82+
def main() -> None:
83+
"""Main function."""
84+
with crud.get_new_session(read_only=True) as session:
85+
recording = crud.get_latest_recording(session)
86+
action_events = recording.processed_action_events
87+
descriptions = []
88+
for action in action_events:
89+
description, image = action.prompt_for_description(return_image=True)
90+
91+
# Convert image to numpy array for OpenCV compatibility
92+
image = np.array(image)
93+
94+
if action.mouse_x is not None and action.mouse_y is not None:
95+
# Use the mouse coordinates for mouse events
96+
annotated_image = embed_description(
97+
image,
98+
description,
99+
x=int(action.mouse_x) * 2,
100+
y=int(action.mouse_y) * 2,
101+
)
102+
else:
103+
# Center the text for other events
104+
annotated_image = embed_description(image, description)
105+
106+
logger.info(f"{action=}")
107+
logger.info(f"{description=}")
108+
cv2.imshow("Annotated Image", annotated_image)
109+
cv2.waitKey(0)
110+
descriptions.append(description)
111+
112+
logger.info(f"descriptions=\n{pformat(descriptions)}")
113+
114+
115+
if __name__ == "__main__":
116+
main()

openadapt/config.defaults.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
"RECORD_READ_ACTIVE_ELEMENT_STATE": false,
1818
"REPLAY_STRIP_ELEMENT_STATE": true,
1919
"RECORD_VIDEO": true,
20-
"RECORD_AUDIO": true,
20+
"RECORD_AUDIO": false,
2121
"RECORD_BROWSER_EVENTS": false,
2222
"RECORD_FULL_VIDEO": false,
2323
"RECORD_IMAGES": false,

openadapt/db/crud.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -337,16 +337,18 @@ def get_all_scrubbed_recordings(
337337

338338

339339
def get_latest_recording(session: SaSession) -> Recording:
340-
"""Get the latest recording.
341-
342-
Args:
343-
session (sa.orm.Session): The database session.
344-
345-
Returns:
346-
Recording: The latest recording object.
347-
"""
340+
"""Get the latest recording with preloaded relationships."""
348341
return (
349-
session.query(Recording).order_by(sa.desc(Recording.timestamp)).limit(1).first()
342+
session.query(Recording)
343+
.options(
344+
sa.orm.joinedload(Recording.screenshots),
345+
sa.orm.joinedload(Recording.action_events)
346+
.joinedload(ActionEvent.screenshot)
347+
.joinedload(Screenshot.recording),
348+
sa.orm.joinedload(Recording.window_events),
349+
)
350+
.order_by(sa.desc(Recording.timestamp))
351+
.first()
350352
)
351353

352354

openadapt/drivers/anthropic.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,14 @@
55
from PIL import Image
66
import anthropic
77

8-
from openadapt import cache, utils
8+
from openadapt import cache
99
from openadapt.config import config
1010
from openadapt.custom_logger import logger
1111

1212
MAX_TOKENS = 4096
1313
# from https://docs.anthropic.com/claude/docs/vision
1414
MAX_IMAGES = 20
15-
MODEL_NAME = "claude-3-opus-20240229"
15+
MODEL_NAME = "claude-3-5-sonnet-20241022"
1616

1717

1818
@cache.cache()
@@ -24,6 +24,8 @@ def create_payload(
2424
max_tokens: int | None = None,
2525
) -> dict:
2626
"""Creates the payload for the Anthropic API request with image support."""
27+
from openadapt import utils
28+
2729
messages = []
2830

2931
user_message_content = []
@@ -36,7 +38,7 @@ def create_payload(
3638
# Add base64 encoded images to the user message content
3739
if images:
3840
for image in images:
39-
image_base64 = utils.image2utf8(image)
41+
image_base64 = utils.image2utf8(image, "PNG")
4042
# Extract media type and base64 data
4143
# TODO: don't add it to begin with
4244
media_type, image_base64_data = image_base64.split(";base64,", 1)
@@ -90,7 +92,7 @@ def get_completion(
9092
"""Sends a request to the Anthropic API and returns the response."""
9193
client = anthropic.Anthropic(api_key=api_key)
9294
try:
93-
response = client.messages.create(**payload)
95+
response = client.beta.messages.create(**payload)
9496
except Exception as exc:
9597
logger.exception(exc)
9698
if dev_mode:

openadapt/models.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import copy
88
import io
99
import sys
10+
import textwrap
1011

1112
from bs4 import BeautifulSoup
1213
from pynput import keyboard
@@ -16,6 +17,7 @@
1617

1718
from openadapt.config import config
1819
from openadapt.custom_logger import logger
20+
from openadapt.drivers import anthropic
1921
from openadapt.db import db
2022
from openadapt.privacy.base import ScrubbingProvider, TextScrubbingMixin
2123
from openadapt.privacy.providers import ScrubProvider
@@ -110,6 +112,9 @@ def processed_action_events(self) -> list:
110112
if not self._processed_action_events:
111113
session = crud.get_new_session(read_only=True)
112114
self._processed_action_events = events.get_events(session, self)
115+
# Preload screenshots to avoid lazy loading later
116+
for event in self._processed_action_events:
117+
event.screenshot
113118
return self._processed_action_events
114119

115120
def scrub(self, scrubber: ScrubbingProvider) -> None:
@@ -125,6 +130,7 @@ class ActionEvent(db.Base):
125130
"""Class representing an action event in the database."""
126131

127132
__tablename__ = "action_event"
133+
_repr_ignore_attrs = ["reducer_names"]
128134

129135
_segment_description_separator = ";"
130136

@@ -333,6 +339,11 @@ def canonical_text(self, value: str) -> None:
333339
if not value == self.canonical_text:
334340
logger.warning(f"{value=} did not match {self.canonical_text=}")
335341

342+
@property
343+
def raw_text(self) -> str:
344+
"""Return a string containing the raw action text (without separators)."""
345+
return "".join(self.text.split(config.ACTION_TEXT_SEP))
346+
336347
def __str__(self) -> str:
337348
"""Return a string representation of the action event."""
338349
attr_names = [
@@ -544,6 +555,75 @@ def next_event(self) -> Union["ActionEvent", None]:
544555

545556
return None
546557

558+
def prompt_for_description(self, return_image: bool = False) -> str:
559+
"""Use the Anthropic API to describe what is happening in the action event.
560+
561+
Args:
562+
return_image (bool): Whether to return the image sent to the model.
563+
564+
Returns:
565+
str: The description of the action event.
566+
"""
567+
from openadapt.plotting import display_event
568+
569+
image = display_event(
570+
self,
571+
marker_width_pct=0.05,
572+
marker_height_pct=0.05,
573+
darken_outside=0.7,
574+
display_text=False,
575+
marker_fill_transparency=0,
576+
)
577+
578+
if self.text:
579+
description = f"Type '{self.raw_text}'"
580+
else:
581+
prompt = (
582+
"What user interface element is contained in the highlighted circle "
583+
"of the image?"
584+
)
585+
# TODO: disambiguate
586+
system_prompt = textwrap.dedent(
587+
"""
588+
Briefly describe the user interface element in the screenshot at the
589+
highlighted location.
590+
For example:
591+
- "OK button"
592+
- "URL bar"
593+
- "Down arrow"
594+
DO NOT DESCRIBE ANYTHING OUTSIDE THE HIGHLIGHTED AREA.
595+
Do not append anything like "is contained within the highlighted circle
596+
in the calculator interface." Just name the user interface element.
597+
"""
598+
)
599+
600+
logger.info(f"system_prompt=\n{system_prompt}")
601+
logger.info(f"prompt=\n{prompt}")
602+
603+
# Call the Anthropic API
604+
element = anthropic.prompt(
605+
prompt=prompt,
606+
system_prompt=system_prompt,
607+
images=[image],
608+
)
609+
610+
if self.name == "move":
611+
description = f"Move mouse to '{element}'"
612+
elif self.name == "scroll":
613+
# TODO: "scroll to", dx/dy
614+
description = f"Scroll mouse on '{element}'"
615+
elif "click" in self.name:
616+
description = (
617+
f"{self.mouse_button_name.capitalize()} {self.name} '{element}'"
618+
)
619+
else:
620+
raise ValueError(f"Unhandled {self.name=} {self}")
621+
622+
if return_image:
623+
return description, image
624+
else:
625+
return description
626+
547627

548628
class WindowEvent(db.Base):
549629
"""Class representing a window event in the database."""

0 commit comments

Comments
 (0)