diff --git a/camel/model_backend.py b/camel/model_backend.py index 2c664ba33..de974d290 100644 --- a/camel/model_backend.py +++ b/camel/model_backend.py @@ -92,11 +92,11 @@ def run(self, *args, **kwargs): "gpt-4-32k": 32768, "gpt-4-turbo": 100000, } - num_max_token = num_max_token_map[self.model_type.value] + num_max_token = num_max_token_map.get(self.model_type.value, os.environ['CHATDEV_NUM_MAX_TOKEN']) num_max_completion_tokens = num_max_token - num_prompt_tokens self.model_config_dict['max_tokens'] = num_max_completion_tokens - response = client.chat.completions.create(*args, **kwargs, model=self.model_type.value, + response = client.chat.completions.create(*args, **kwargs, model=os.environ.get('CHATDEV_CUSTOM_MODEL', self.model_type.value), **self.model_config_dict) cost = prompt_cost( diff --git a/chatdev/chat_env.py b/chatdev/chat_env.py index b9504969d..c85271bac 100644 --- a/chatdev/chat_env.py +++ b/chatdev/chat_env.py @@ -1,5 +1,6 @@ import os import re +import base64 import shutil import signal import subprocess @@ -12,7 +13,10 @@ from chatdev.codes import Codes from chatdev.documents import Documents from chatdev.roster import Roster -from chatdev.utils import log_visualize +from chatdev.utils import ( + log_visualize, + is_url +) from ecl.memory import Memory try: @@ -214,12 +218,15 @@ def write_meta(self) -> None: def generate_images_from_codes(self): def download(img_url, file_name): - r = requests.get(img_url) + if is_url(img_url): + content = requests.get(img_url).content + else: + content = base64.b64decode(img_url) filepath = os.path.join(self.env_dict['directory'], file_name) if os.path.exists(filepath): os.remove(filepath) with open(filepath, "wb") as f: - f.write(r.content) + f.write(content) print("{} Downloaded".format(filepath)) regex = r"(\w+.png)" @@ -241,28 +248,39 @@ def download(img_url, file_name): print("{}: {}".format(filename, desc)) if openai_new_api: response = openai.images.generate( + model=os.environ.get('CHATDEV_CUSTOM_IMAGE_MODEL', None), prompt=desc, n=1, size="256x256" ) - image_url = response.data[0].url + try: + image_url = response.data[0].url + except KeyError: + image_url = response.data['url'] else: response = openai.Image.create( + model=os.environ.get('CHATDEV_CUSTOM_IMAGE_MODEL', None), prompt=desc, n=1, size="256x256" ) - image_url = response['data'][0]['url'] + try: + image_url = response['data'][0]['url'] + except KeyError: + image_url = response['data']['url'] download(image_url, filename) def get_proposed_images_from_message(self, messages): def download(img_url, file_name): - r = requests.get(img_url) + if is_url(img_url): + content = requests.get(img_url).content + else: + content = base64.b64decode(img_url) filepath = os.path.join(self.env_dict['directory'], file_name) if os.path.exists(filepath): os.remove(filepath) with open(filepath, "wb") as f: - f.write(r.content) + f.write(content) print("{} Downloaded".format(filepath)) regex = r"(\w+.png):(.*?)\n" @@ -292,6 +310,7 @@ def download(img_url, file_name): if openai_new_api: response = openai.images.generate( + model=os.environ.get('CHATDEV_CUSTOM_IMAGE_MODEL', None), prompt=desc, n=1, size="256x256" @@ -299,6 +318,7 @@ def download(img_url, file_name): image_url = response.data[0].url else: response = openai.Image.create( + model=os.environ.get('CHATDEV_CUSTOM_IMAGE_MODEL', None), prompt=desc, n=1, size="256x256" diff --git a/chatdev/utils.py b/chatdev/utils.py index 4f17947ca..47d40f487 100644 --- a/chatdev/utils.py +++ b/chatdev/utils.py @@ -3,8 +3,11 @@ import re import time +from urllib.parse import urlparse + import markdown import inspect + from camel.messages.system_messages import SystemMessage from visualizer.app import send_msg @@ -87,3 +90,13 @@ def escape_string(value): value = re.sub(r'<[^>]*>', '', value) value = value.replace("\n", " ") return value + +def is_url(url): + """ + Adapted from https://stackoverflow.com/questions/7160737/how-to-validate-a-url-in-python-malformed-or-not + """ + try: + result = urlparse(url) + return all([result.scheme, result.netloc]) + except ValueError: + return False