Dall-E model selection

This commit is contained in:
Kumi 2023-11-07 13:58:25 +01:00
parent 9abea6e3f8
commit ef3118cbe3
Signed by: kumi
GPG key ID: ECBCC9082395383F
5 changed files with 25 additions and 12 deletions

View file

@ -67,6 +67,10 @@ LogLevel = info
# #
# Model = gpt-3.5-turbo # Model = gpt-3.5-turbo
# The Image Generation model you want to use.
#
# ImageModel = dall-e-2
# Your OpenAI API key # Your OpenAI API key
# #
# Find this in your OpenAI account: # Find this in your OpenAI account:

View file

@ -18,10 +18,6 @@ readme = "README.md"
license = { file="LICENSE" } license = { file="LICENSE" }
requires-python = ">=3.10" requires-python = ">=3.10"
packages = [
"src/gptbot"
]
classifiers = [ classifiers = [
"Programming Language :: Python :: 3", "Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License", "License :: OSI Approved :: MIT License",
@ -29,13 +25,17 @@ classifiers = [
] ]
dependencies = [ dependencies = [
"matrix-nio[e2e]", "mautrix[all]",
"markdown2[all]", "markdown2[all]",
"tiktoken", "tiktoken",
"python-magic", "python-magic",
"pillow", "pillow",
] ]
packages = [
{ include = "gptbot", where = "src" },
]
[project.optional-dependencies] [project.optional-dependencies]
openai = [ openai = [
"openai", "openai",
@ -62,7 +62,10 @@ dev = [
"Bug Tracker" = "https://kumig.it/kumitterer/matrix-gptbot/issues" "Bug Tracker" = "https://kumig.it/kumitterer/matrix-gptbot/issues"
[project.scripts] [project.scripts]
gptbot = "gptbot:main" gptbot = "gptbot.__main___:main"
[tool.hatch.build.targets.wheel] [tool.hatch.build.targets.wheel]
packages = ["src/gptbot"] only-include = ["src/gptbot"]
[tool.hatch.build.targets.wheel.sources]
"src" = ""

View file

@ -10,8 +10,7 @@ import asyncio
def sigterm_handler(_signo, _stack_frame): def sigterm_handler(_signo, _stack_frame):
exit() exit()
def main():
if __name__ == "__main__":
# Parse command line arguments # Parse command line arguments
parser = ArgumentParser() parser = ArgumentParser()
parser.add_argument( parser.add_argument(
@ -46,3 +45,7 @@ if __name__ == "__main__":
print("Received KeyboardInterrupt - exiting...") print("Received KeyboardInterrupt - exiting...")
except SystemExit: except SystemExit:
print("Received SIGTERM - exiting...") print("Received SIGTERM - exiting...")
if __name__ == "__main__":
main()

View file

@ -129,7 +129,8 @@ class GPTBot:
bot.allowed_users = json.loads(config["GPTBot"]["AllowedUsers"]) bot.allowed_users = json.loads(config["GPTBot"]["AllowedUsers"])
bot.chat_api = bot.image_api = bot.classification_api = OpenAI( bot.chat_api = bot.image_api = bot.classification_api = OpenAI(
config["OpenAI"]["APIKey"], config["OpenAI"].get("Model"), bot.logger config["OpenAI"]["APIKey"], config["OpenAI"].get("Model"),
config["OpenAI"].get("ImageModel"), bot.logger
) )
bot.max_tokens = config["OpenAI"].getint("MaxTokens", bot.max_tokens) bot.max_tokens = config["OpenAI"].getint("MaxTokens", bot.max_tokens)
bot.max_messages = config["OpenAI"].getint("MaxMessages", bot.max_messages) bot.max_messages = config["OpenAI"].getint("MaxMessages", bot.max_messages)

View file

@ -21,13 +21,14 @@ class OpenAI:
return self.chat_model return self.chat_model
classification_api = chat_api classification_api = chat_api
image_api: str = "dalle" image_model: str = "dall-e-2"
operator: str = "OpenAI ([https://openai.com](https://openai.com))" operator: str = "OpenAI ([https://openai.com](https://openai.com))"
def __init__(self, api_key, chat_model=None, logger=None): def __init__(self, api_key, chat_model=None, image_model=None, logger=None):
self.api_key = api_key self.api_key = api_key
self.chat_model = chat_model or self.chat_model self.chat_model = chat_model or self.chat_model
self.image_model = image_model or self.image_model
self.logger = logger or Logger() self.logger = logger or Logger()
self.base_url = openai.api_base self.base_url = openai.api_base
@ -146,6 +147,7 @@ Only the event_types mentioned above are allowed, you must not respond in any ot
image_partial = partial( image_partial = partial(
openai.Image.acreate, openai.Image.acreate,
model=self.image_model,
prompt=prompt, prompt=prompt,
n=1, n=1,
api_key=self.api_key, api_key=self.api_key,