diff --git a/config.dist.ini b/config.dist.ini index 91fe066..a4026a5 100644 --- a/config.dist.ini +++ b/config.dist.ini @@ -67,6 +67,10 @@ LogLevel = info # # Model = gpt-3.5-turbo +# The Image Generation model you want to use. +# +# ImageModel = dall-e-2 + # Your OpenAI API key # # Find this in your OpenAI account: diff --git a/src/gptbot/classes/bot.py b/src/gptbot/classes/bot.py index 5e7c271..9800581 100644 --- a/src/gptbot/classes/bot.py +++ b/src/gptbot/classes/bot.py @@ -138,7 +138,8 @@ class GPTBot: bot.allowed_users = json.loads(config["GPTBot"]["AllowedUsers"]) bot.chat_api = bot.image_api = bot.classification_api = OpenAI( - config["OpenAI"]["APIKey"], config["OpenAI"].get("Model"), bot.logger + config["OpenAI"]["APIKey"], config["OpenAI"].get("Model"), + config["OpenAI"].get("ImageModel"), bot.logger ) bot.max_tokens = config["OpenAI"].getint("MaxTokens", bot.max_tokens) bot.max_messages = config["OpenAI"].getint("MaxMessages", bot.max_messages) diff --git a/src/gptbot/classes/openai.py b/src/gptbot/classes/openai.py index fcff3e3..8eb9344 100644 --- a/src/gptbot/classes/openai.py +++ b/src/gptbot/classes/openai.py @@ -21,11 +21,11 @@ class OpenAI: return self.chat_model classification_api = chat_api - image_api: str = "dalle" + image_model: str = "dall-e-2" operator: str = "OpenAI ([https://openai.com](https://openai.com))" - def __init__(self, api_key, chat_model=None, logger=None): + def __init__(self, api_key, chat_model=None, image_model=None, logger=None): self.api_key = api_key self.chat_model = chat_model or self.chat_model self.logger = logger or Logger() @@ -146,6 +146,7 @@ Only the event_types mentioned above are allowed, you must not respond in any ot image_partial = partial( openai.Image.acreate, + model=self.image_model, prompt=prompt, n=1, api_key=self.api_key,