From ef3118cbe33bccdf4bfa69164b569e34b44348b7 Mon Sep 17 00:00:00 2001 From: Kumi Date: Tue, 7 Nov 2023 13:58:25 +0100 Subject: [PATCH] Dall-E model selection --- config.dist.ini | 4 ++++ pyproject.toml | 17 ++++++++++------- src/gptbot/__main__.py | 7 +++++-- src/gptbot/classes/bot.py | 3 ++- src/gptbot/classes/openai.py | 6 ++++-- 5 files changed, 25 insertions(+), 12 deletions(-) diff --git a/config.dist.ini b/config.dist.ini index 91fe066..a4026a5 100644 --- a/config.dist.ini +++ b/config.dist.ini @@ -67,6 +67,10 @@ LogLevel = info # # Model = gpt-3.5-turbo +# The Image Generation model you want to use. +# +# ImageModel = dall-e-2 + # Your OpenAI API key # # Find this in your OpenAI account: diff --git a/pyproject.toml b/pyproject.toml index 9cb355d..cd1a0c2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,10 +18,6 @@ readme = "README.md" license = { file="LICENSE" } requires-python = ">=3.10" -packages = [ - "src/gptbot" -] - classifiers = [ "Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License", @@ -29,13 +25,17 @@ classifiers = [ ] dependencies = [ - "matrix-nio[e2e]", + "mautrix[all]", "markdown2[all]", "tiktoken", "python-magic", "pillow", ] +packages = [ + { include = "gptbot", where = "src" }, +] + [project.optional-dependencies] openai = [ "openai", @@ -62,7 +62,10 @@ dev = [ "Bug Tracker" = "https://kumig.it/kumitterer/matrix-gptbot/issues" [project.scripts] -gptbot = "gptbot:main" +gptbot = "gptbot.__main___:main" [tool.hatch.build.targets.wheel] -packages = ["src/gptbot"] \ No newline at end of file +only-include = ["src/gptbot"] + +[tool.hatch.build.targets.wheel.sources] +"src" = "" \ No newline at end of file diff --git a/src/gptbot/__main__.py b/src/gptbot/__main__.py index f50eb3b..8d864e4 100644 --- a/src/gptbot/__main__.py +++ b/src/gptbot/__main__.py @@ -10,8 +10,7 @@ import asyncio def sigterm_handler(_signo, _stack_frame): exit() - -if __name__ == "__main__": +def main(): # Parse command line arguments parser = ArgumentParser() parser.add_argument( @@ -46,3 +45,7 @@ if __name__ == "__main__": print("Received KeyboardInterrupt - exiting...") except SystemExit: print("Received SIGTERM - exiting...") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/gptbot/classes/bot.py b/src/gptbot/classes/bot.py index 948632d..f1f69d8 100644 --- a/src/gptbot/classes/bot.py +++ b/src/gptbot/classes/bot.py @@ -129,7 +129,8 @@ class GPTBot: bot.allowed_users = json.loads(config["GPTBot"]["AllowedUsers"]) bot.chat_api = bot.image_api = bot.classification_api = OpenAI( - config["OpenAI"]["APIKey"], config["OpenAI"].get("Model"), bot.logger + config["OpenAI"]["APIKey"], config["OpenAI"].get("Model"), + config["OpenAI"].get("ImageModel"), bot.logger ) bot.max_tokens = config["OpenAI"].getint("MaxTokens", bot.max_tokens) bot.max_messages = config["OpenAI"].getint("MaxMessages", bot.max_messages) diff --git a/src/gptbot/classes/openai.py b/src/gptbot/classes/openai.py index fcff3e3..d540db5 100644 --- a/src/gptbot/classes/openai.py +++ b/src/gptbot/classes/openai.py @@ -21,13 +21,14 @@ class OpenAI: return self.chat_model classification_api = chat_api - image_api: str = "dalle" + image_model: str = "dall-e-2" operator: str = "OpenAI ([https://openai.com](https://openai.com))" - def __init__(self, api_key, chat_model=None, logger=None): + def __init__(self, api_key, chat_model=None, image_model=None, logger=None): self.api_key = api_key self.chat_model = chat_model or self.chat_model + self.image_model = image_model or self.image_model self.logger = logger or Logger() self.base_url = openai.api_base @@ -146,6 +147,7 @@ Only the event_types mentioned above are allowed, you must not respond in any ot image_partial = partial( openai.Image.acreate, + model=self.image_model, prompt=prompt, n=1, api_key=self.api_key,