Dall-E model selection
This commit is contained in:
parent
9abea6e3f8
commit
ef3118cbe3
5 changed files with 25 additions and 12 deletions
|
@ -67,6 +67,10 @@ LogLevel = info
|
||||||
#
|
#
|
||||||
# Model = gpt-3.5-turbo
|
# Model = gpt-3.5-turbo
|
||||||
|
|
||||||
|
# The Image Generation model you want to use.
|
||||||
|
#
|
||||||
|
# ImageModel = dall-e-2
|
||||||
|
|
||||||
# Your OpenAI API key
|
# Your OpenAI API key
|
||||||
#
|
#
|
||||||
# Find this in your OpenAI account:
|
# Find this in your OpenAI account:
|
||||||
|
|
|
@ -18,10 +18,6 @@ readme = "README.md"
|
||||||
license = { file="LICENSE" }
|
license = { file="LICENSE" }
|
||||||
requires-python = ">=3.10"
|
requires-python = ">=3.10"
|
||||||
|
|
||||||
packages = [
|
|
||||||
"src/gptbot"
|
|
||||||
]
|
|
||||||
|
|
||||||
classifiers = [
|
classifiers = [
|
||||||
"Programming Language :: Python :: 3",
|
"Programming Language :: Python :: 3",
|
||||||
"License :: OSI Approved :: MIT License",
|
"License :: OSI Approved :: MIT License",
|
||||||
|
@ -29,13 +25,17 @@ classifiers = [
|
||||||
]
|
]
|
||||||
|
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"matrix-nio[e2e]",
|
"mautrix[all]",
|
||||||
"markdown2[all]",
|
"markdown2[all]",
|
||||||
"tiktoken",
|
"tiktoken",
|
||||||
"python-magic",
|
"python-magic",
|
||||||
"pillow",
|
"pillow",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
packages = [
|
||||||
|
{ include = "gptbot", where = "src" },
|
||||||
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
openai = [
|
openai = [
|
||||||
"openai",
|
"openai",
|
||||||
|
@ -62,7 +62,10 @@ dev = [
|
||||||
"Bug Tracker" = "https://kumig.it/kumitterer/matrix-gptbot/issues"
|
"Bug Tracker" = "https://kumig.it/kumitterer/matrix-gptbot/issues"
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
gptbot = "gptbot:main"
|
gptbot = "gptbot.__main___:main"
|
||||||
|
|
||||||
[tool.hatch.build.targets.wheel]
|
[tool.hatch.build.targets.wheel]
|
||||||
packages = ["src/gptbot"]
|
only-include = ["src/gptbot"]
|
||||||
|
|
||||||
|
[tool.hatch.build.targets.wheel.sources]
|
||||||
|
"src" = ""
|
|
@ -10,8 +10,7 @@ import asyncio
|
||||||
def sigterm_handler(_signo, _stack_frame):
|
def sigterm_handler(_signo, _stack_frame):
|
||||||
exit()
|
exit()
|
||||||
|
|
||||||
|
def main():
|
||||||
if __name__ == "__main__":
|
|
||||||
# Parse command line arguments
|
# Parse command line arguments
|
||||||
parser = ArgumentParser()
|
parser = ArgumentParser()
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
@ -46,3 +45,7 @@ if __name__ == "__main__":
|
||||||
print("Received KeyboardInterrupt - exiting...")
|
print("Received KeyboardInterrupt - exiting...")
|
||||||
except SystemExit:
|
except SystemExit:
|
||||||
print("Received SIGTERM - exiting...")
|
print("Received SIGTERM - exiting...")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
@ -129,7 +129,8 @@ class GPTBot:
|
||||||
bot.allowed_users = json.loads(config["GPTBot"]["AllowedUsers"])
|
bot.allowed_users = json.loads(config["GPTBot"]["AllowedUsers"])
|
||||||
|
|
||||||
bot.chat_api = bot.image_api = bot.classification_api = OpenAI(
|
bot.chat_api = bot.image_api = bot.classification_api = OpenAI(
|
||||||
config["OpenAI"]["APIKey"], config["OpenAI"].get("Model"), bot.logger
|
config["OpenAI"]["APIKey"], config["OpenAI"].get("Model"),
|
||||||
|
config["OpenAI"].get("ImageModel"), bot.logger
|
||||||
)
|
)
|
||||||
bot.max_tokens = config["OpenAI"].getint("MaxTokens", bot.max_tokens)
|
bot.max_tokens = config["OpenAI"].getint("MaxTokens", bot.max_tokens)
|
||||||
bot.max_messages = config["OpenAI"].getint("MaxMessages", bot.max_messages)
|
bot.max_messages = config["OpenAI"].getint("MaxMessages", bot.max_messages)
|
||||||
|
|
|
@ -21,13 +21,14 @@ class OpenAI:
|
||||||
return self.chat_model
|
return self.chat_model
|
||||||
|
|
||||||
classification_api = chat_api
|
classification_api = chat_api
|
||||||
image_api: str = "dalle"
|
image_model: str = "dall-e-2"
|
||||||
|
|
||||||
operator: str = "OpenAI ([https://openai.com](https://openai.com))"
|
operator: str = "OpenAI ([https://openai.com](https://openai.com))"
|
||||||
|
|
||||||
def __init__(self, api_key, chat_model=None, logger=None):
|
def __init__(self, api_key, chat_model=None, image_model=None, logger=None):
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.chat_model = chat_model or self.chat_model
|
self.chat_model = chat_model or self.chat_model
|
||||||
|
self.image_model = image_model or self.image_model
|
||||||
self.logger = logger or Logger()
|
self.logger = logger or Logger()
|
||||||
self.base_url = openai.api_base
|
self.base_url = openai.api_base
|
||||||
|
|
||||||
|
@ -146,6 +147,7 @@ Only the event_types mentioned above are allowed, you must not respond in any ot
|
||||||
|
|
||||||
image_partial = partial(
|
image_partial = partial(
|
||||||
openai.Image.acreate,
|
openai.Image.acreate,
|
||||||
|
model=self.image_model,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
n=1,
|
n=1,
|
||||||
api_key=self.api_key,
|
api_key=self.api_key,
|
||||||
|
|
Loading…
Reference in a new issue