
Commit

Merge pull request #52 from dsdude123/expert-changes
Add new Expert command, support for gravity
dsdude123 authored Apr 28, 2024
2 parents 13f5dd8 + fcef6b3 commit 24e627f
Showing 20 changed files with 124 additions and 93 deletions.
1 change: 1 addition & 0 deletions MariBot-Common/Model/GpuWorker/Command.cs
@@ -31,6 +31,7 @@ public enum Command
DkOldies,
DsKoopa,
Edges2Hentai,
Expert,
Herschel,
Kevin,
Kurisu,
1 change: 1 addition & 0 deletions MariBot-Common/Model/GpuWorker/CommandCapabilityMapping.cs
@@ -34,6 +34,7 @@ public class CommandCapabilityMapping
{ Command.DkOldies, WorkerCapability.CPU},
{ Command.DsKoopa, WorkerCapability.CPU},
{ Command.Edges2Hentai, WorkerCapability.ConsumerGPU},
{ Command.Expert, WorkerCapability.CPU },
{ Command.Herschel, WorkerCapability.CPU},
{ Command.Kevin, WorkerCapability.CPU},
{ Command.Kurisu, WorkerCapability.CPU},
2 changes: 1 addition & 1 deletion MariBot-Core/MariBot.Core.csproj
@@ -19,7 +19,7 @@
</ItemGroup>

<ItemGroup>
<PackageReference Include="Betalgo.OpenAI.GPT3" Version="6.8.1" />
<PackageReference Include="Betalgo.OpenAI" Version="8.1.1" />
<PackageReference Include="BooruSharp" Version="3.5.5" />
<PackageReference Include="ByteSize" Version="2.0.0" />
<PackageReference Include="ConsoleTables" Version="2.4.2" />
19 changes: 0 additions & 19 deletions MariBot-Core/Modules/Text/ImageModule.cs
@@ -151,25 +151,6 @@ public async Task Daryl([Remainder] string text = null)
HandleCommonImageScenario(Command.Daryl);
}

[Command("dalle", RunMode = RunMode.Async)]
[RequireOwner]
public async Task dalle([Remainder] string prompt)
{
try
{
string imageUrl = openAiService.ExecuteDalleQuery(prompt).Result;
await Context.Channel.SendMessageAsync($"{imageUrl}", messageReference: new MessageReference(Context.Message.Id));
}
catch (ArgumentException)
{
await Context.Channel.SendMessageAsync("Your input prompt failed safety checks.", messageReference: new MessageReference(Context.Message.Id));
}
catch (ApplicationException ex)
{
await Context.Channel.SendMessageAsync($"{ex.Message}", messageReference: new MessageReference(Context.Message.Id));
}
}

[Command("dave", RunMode = RunMode.Async)]
public async Task dave([Remainder] string text = null)
{
53 changes: 44 additions & 9 deletions MariBot-Core/Modules/Text/OpenAIModule.cs
@@ -13,11 +13,13 @@ public class OpenAiModule : ModuleBase<SocketCommandContext>
{
private readonly DataService dataService;
private readonly OpenAiService openAiService;
private readonly ImageService imageService;

public OpenAiModule(OpenAiService openAIService, DataService dataService)
public OpenAiModule(OpenAiService openAIService, DataService dataService, ImageService imageService)
{
this.openAiService = openAIService;
this.dataService = dataService;
this.imageService = imageService;
}

/// <summary>
@@ -30,7 +32,7 @@ public async Task StartChatGptSession([Remainder] string input)
{
try
{
var result = await openAiService.ExecuteChatGptQuery(Context.Guild.Id, Context.Channel.Id, Context.Message.Id, input);
var result = await openAiService.ExecuteChatGptQuery(Context.Guild.Id, Context.Channel.Id, Context.Message.Id, input, Context.User.Id.ToString());
var sentMessage = Context.Channel.SendMessageAsync($"```\n{result.Replace("```","")}\n```", messageReference: new MessageReference(Context.Message.Id)).Result;
if (!dataService.UpdateChatGptMessageHistoryId(Context.Guild.Id, Context.Channel.Id, Context.Message.Id,
sentMessage.Id))
@@ -54,22 +56,55 @@ public async Task StartChatGptSession([Remainder] string input)
/// </summary>
/// <param name="prompt">Input prompt</param>
/// <returns></returns>
[RequireOwner]
[Command("dalle", RunMode = RunMode.Async)]
public async Task Dalle([Remainder] string prompt)
{
try
{
var imageUrl = openAiService.ExecuteDalleQuery(prompt).Result;
await Context.Channel.SendMessageAsync($"{imageUrl}", messageReference: new MessageReference(Context.Message.Id));
var imageUrl = openAiService.ExecuteDalleQuery(prompt, Context.User.Id.ToString(), "standard").Result;
var stream = await imageService.GetWebResource(imageUrl);
await Context.Channel.SendFileAsync(stream, "dalle.png", messageReference: new MessageReference(Context.Message.Id));
}
catch (ArgumentException)
{
await Context.Channel.SendMessageAsync("Your input prompt failed safety checks.", messageReference: new MessageReference(Context.Message.Id));
}
catch (ApplicationException ex)
catch (AggregateException ex)
{
await Context.Channel.SendMessageAsync($"{ex.Message}", messageReference: new MessageReference(Context.Message.Id));
if (ex.InnerException.GetType() == typeof(ApplicationException))
{
await Context.Channel.SendMessageAsync($"{ex.InnerException.Message}", messageReference: new MessageReference(Context.Message.Id));
}
else
{
throw ex.InnerException;
}
}
}

[Command("dallehd", RunMode = RunMode.Async)]
public async Task DalleHd([Remainder] string prompt)
{
try
{
var imageUrl = openAiService.ExecuteDalleQuery(prompt, Context.User.Id.ToString(), "hd").Result;
var stream = await imageService.GetWebResource(imageUrl);
await Context.Channel.SendFileAsync(stream, "dalle.png", messageReference: new MessageReference(Context.Message.Id));
}
catch (ArgumentException)
{
await Context.Channel.SendMessageAsync("Your input prompt failed safety checks.", messageReference: new MessageReference(Context.Message.Id));
}
catch (AggregateException ex)
{
if (ex.InnerException.GetType() == typeof(ApplicationException))
{
await Context.Channel.SendMessageAsync($"{ex.InnerException.Message}", messageReference: new MessageReference(Context.Message.Id));
}
else
{
throw ex.InnerException;
}
}
}

@@ -83,7 +118,7 @@ public async Task Gpt3TextCompletion([Remainder] string input)
{
try
{
var result = await openAiService.ExecuteGpt3Query(input);
var result = await openAiService.ExecuteGpt3Query(input, Context.User.Id.ToString());
await Context.Channel.SendMessageAsync($"```\n{result.Replace("```", "")}\n```", messageReference: new MessageReference(Context.Message.Id));
}
catch (ArgumentException)
@@ -106,7 +141,7 @@ public async Task Gpt4TextCompletion([Remainder] string input)
{
try
{
var result = await openAiService.ExecuteGpt4Query(input);
var result = await openAiService.ExecuteGpt4Query(input, Context.User.Id.ToString());
await Context.Channel.SendMessageAsync($"```\n{result.Replace("```", "")}\n```", messageReference: new MessageReference(Context.Message.Id));
}
catch (ArgumentException)
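
A note on the error handling in the two DALL-E handlers above: because they block on openAiService.ExecuteDalleQuery(...).Result rather than awaiting the task, any exception thrown inside the service surfaces wrapped in an AggregateException, which is why the catch blocks unwrap InnerException before deciding whether to report or rethrow. A minimal stand-alone illustration of that behaviour (not part of this commit):

using System;
using System.Threading.Tasks;

class ResultVsAwait
{
    // Stand-in for a failing service call, e.g. an OpenAI API error.
    static Task<string> Fail() =>
        Task.FromException<string>(new ApplicationException("API error"));

    static async Task Main()
    {
        try
        {
            _ = Fail().Result;                  // blocking access wraps the failure
        }
        catch (AggregateException ex)
        {
            // Prints "ApplicationException": the real error sits in InnerException.
            Console.WriteLine(ex.InnerException?.GetType().Name);
        }

        try
        {
            _ = await Fail();                   // awaiting rethrows the original exception
        }
        catch (ApplicationException ex)
        {
            Console.WriteLine(ex.Message);      // "API error"
        }
    }
}
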
5 changes: 5 additions & 0 deletions MariBot-Core/Program.cs
@@ -6,6 +6,7 @@
using MariBot.Core;
using MariBot.Core.Services;
using MariBot.Services;
using OpenAI.Extensions;

var builder = WebApplication.CreateBuilder(args);
var clientConfig = new DiscordSocketConfig()
@@ -46,6 +47,10 @@
builder.Services.AddSingleton<WorkerManagerService>();
builder.Services.AddSingleton<YahooFantasyService>();
builder.Services.AddSingleton<MediawikiSharp_API.Mediawiki>();
builder.Services.AddOpenAIService(settings => {
settings.ApiKey = builder.Configuration["DiscordSettings:OpenAiApiKey"];
settings.Organization = builder.Configuration["DiscordSettings:OpenAiOrganization"];
});

builder.Logging.ClearProviders();
builder.Logging.AddDebug();
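
For context, the AddOpenAIService call above comes from the Betalgo.OpenAI 8.x package (OpenAI.Extensions) and registers an IOpenAIService that downstream services receive through constructor injection, as OpenAiService does later in this diff. A minimal self-contained sketch of that pattern, with placeholder key, organization, and consumer names:

using Microsoft.Extensions.DependencyInjection;
using OpenAI.Extensions;
using OpenAI.Interfaces;

var services = new ServiceCollection();

// Same registration shape as Program.cs above; the key/org values are placeholders.
services.AddOpenAIService(settings =>
{
    settings.ApiKey = "sk-placeholder";
    settings.Organization = "org-placeholder";
});
services.AddSingleton<ExampleOpenAiConsumer>();

using var provider = services.BuildServiceProvider();
var consumer = provider.GetRequiredService<ExampleOpenAiConsumer>();

// Hypothetical consumer standing in for OpenAiService: the container supplies IOpenAIService,
// whose Moderation, Image, and ChatCompletion sub-clients are the ones used in OpenAIService.cs.
public class ExampleOpenAiConsumer
{
    private readonly IOpenAIService apiClient;

    public ExampleOpenAiConsumer(IOpenAIService apiClient)
    {
        this.apiClient = apiClient;
    }
}
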
26 changes: 14 additions & 12 deletions MariBot-Core/Services/CommandHandlingService.cs
@@ -151,7 +151,20 @@ private async Task MessageReceived(SocketMessage rawMessage)
}

var staticResponse = staticTextResponseService.GetResponse(requestedCommand, context.Guild.Id);
if (staticResponse != null)


var result = await commandService.ExecuteAsync(context, argPos, serviceProvider);

if (result.Error.HasValue &&
result.Error.Value != CommandError.UnknownCommand)
{
logger.LogError("Command encountered an error. {}", result.ToString());
await context.Channel.SendMessageAsync(result.ToString());
}


if (staticResponse != null && result.Error.HasValue &&
result.Error.Value != CommandError.UnknownCommand)
{
logger.LogInformation("Found matching static text response for {}", requestedCommand);
if (staticResponse.Attachments != null && staticResponse.Attachments.Count > 0)
@@ -170,17 +183,6 @@ private async Task MessageReceived(SocketMessage rawMessage)
}

}
else
{
var result = await commandService.ExecuteAsync(context, argPos, serviceProvider);

if (result.Error.HasValue &&
result.Error.Value != CommandError.UnknownCommand)
{
logger.LogError("Command encountered an error. {}", result.ToString());
await context.Channel.SendMessageAsync(result.ToString());
}
}
}
else
{
61 changes: 33 additions & 28 deletions MariBot-Core/Services/OpenAIService.cs
@@ -1,8 +1,9 @@
using Discord;
using Discord.Commands;
using MariBot.Core.Models.ChatGPT;
using OpenAI.GPT3;
using OpenAI.GPT3.ObjectModels.RequestModels;
using OpenAI.Interfaces;
using OpenAI.ObjectModels.RequestModels;
using ApplicationException = System.ApplicationException;
using MessageType = MariBot.Core.Models.ChatGPT.MessageType;

namespace MariBot.Core.Services
@@ -12,19 +13,15 @@ namespace MariBot.Core.Services
/// </summary>
public class OpenAiService
{
private readonly OpenAI.GPT3.Managers.OpenAIService apiClient;
private readonly IOpenAIService apiClient;
private readonly DataService dataService;
private readonly ILogger<OpenAiService> logger;

public OpenAiService(IConfiguration configuration, ILogger<OpenAiService> logger, DataService dataService)
public OpenAiService(IConfiguration configuration, ILogger<OpenAiService> logger, DataService dataService, IOpenAIService apiClient)
{
this.logger = logger;
this.dataService = dataService;
apiClient = new OpenAI.GPT3.Managers.OpenAIService(
new OpenAiOptions()
{
ApiKey = configuration["DiscordSettings:OpenAiApiKey"]
});
this.apiClient = apiClient;
}

/// <summary>
@@ -46,10 +43,10 @@ public bool CheckIfChatGpt(ulong guildId, ulong channelId, ulong messageId)
/// <returns>Url to the image</returns>
/// <exception cref="ArgumentException">Input fails safety checks</exception>
/// <exception cref="ApplicationException">API error</exception>
public async Task<string> ExecuteDalleQuery(string input)
public async Task<string> ExecuteDalleQuery(string input, string userId, string quality)
{
// Perform safety checks
var moderationResult = await apiClient.CreateModeration(new CreateModerationRequest()
var moderationResult = await apiClient.Moderation.CreateModeration(new CreateModerationRequest()
{
Input = input,
Model = "text-moderation-latest"
@@ -61,11 +58,14 @@ public async Task<string> ExecuteDalleQuery(string input)
}

// Call OpenAI
var imageResult = await apiClient.CreateImage(new ImageCreateRequest()
var imageResult = await apiClient.Image.CreateImage(new ImageCreateRequest()
{
Prompt = input,
Size = "1024x1024",
N = 1 // Generate N images
N = 1, // Generate N images,
User = userId,
Model = "dall-e-3",
Quality = quality
});

// Return result
@@ -86,10 +86,10 @@ public async Task<string> ExecuteDalleQuery(string input)
/// <returns>Response text</returns>
/// <exception cref="ArgumentException">Input fails safety checks</exception>
/// <exception cref="ApplicationException">API error</exception>
public async Task<string> ExecuteGpt3Query(string input)
public async Task<string> ExecuteGpt3Query(string input, string userId)
{
// Perform safety checks
var moderationResult = await apiClient.CreateModeration(new CreateModerationRequest()
var moderationResult = await apiClient.Moderation.CreateModeration(new CreateModerationRequest()
{
Input = input,
Model = "text-moderation-latest"
@@ -101,16 +101,18 @@ public async Task<string> ExecuteGpt3Query(string input)
}

// Call OpenAI
var textResult = await apiClient.CreateCompletion(new CompletionCreateRequest()
var textResult = await apiClient.ChatCompletion.CreateCompletion(new ChatCompletionCreateRequest()
{
Prompt = input,
MaxTokens = 500
}, OpenAI.GPT3.ObjectModels.Models.TextDavinciV3);
Messages = new[] { new ChatMessage("user", input) },
MaxTokens = 500,
Model = "gpt-3.5-turbo-1106",
User = userId
});

// Return result
if (textResult.Successful)
{
var text = textResult.Choices.FirstOrDefault().Text;
var text = textResult.Choices.FirstOrDefault().Message.Content;

// Trim to meet Discord message length limits
if (text.Length > 1992)
@@ -133,10 +135,10 @@ public async Task<string> ExecuteGpt3Query(string input)
/// <returns>Response text</returns>
/// <exception cref="ArgumentException">Input fails safety checks</exception>
/// <exception cref="ApplicationException">API error</exception>
public async Task<string> ExecuteGpt4Query(string input)
public async Task<string> ExecuteGpt4Query(string input, string userId)
{
// Perform safety checks
var moderationResult = await apiClient.CreateModeration(new CreateModerationRequest()
var moderationResult = await apiClient.Moderation.CreateModeration(new CreateModerationRequest()
{
Input = input,
Model = "text-moderation-latest"
@@ -151,9 +153,11 @@ public async Task<string> ExecuteGpt4Query(string input)
var completionResult = await apiClient.ChatCompletion.CreateCompletion(new ChatCompletionCreateRequest()
{
Messages = new[] {new ChatMessage("user", input)},
MaxTokens = 500
MaxTokens = 500,
Model = "gpt-4-1106-preview",
User = userId

}, OpenAI.GPT3.ObjectModels.Models.Gpt4); // TODO: Upgrade to 32k when available
});

// Return result
if (completionResult.Successful)
@@ -185,11 +189,11 @@ public async Task<string> ExecuteGpt4Query(string input)
/// <exception cref="NotImplementedException">Message type not supported</exception>
/// <exception cref="ArgumentException">Input failed safety checks</exception>
/// <exception cref="ApplicationException">API error</exception>
public async Task<string> ExecuteChatGptQuery(ulong guildId, ulong channelId, ulong messageId, string input)
public async Task<string> ExecuteChatGptQuery(ulong guildId, ulong channelId, ulong messageId, string input, string userId)
{

// Perform safety checks
var moderationResult = await apiClient.CreateModeration(new CreateModerationRequest()
var moderationResult = await apiClient.Moderation.CreateModeration(new CreateModerationRequest()
{
Input = input,
Model = "text-moderation-latest"
@@ -233,7 +237,8 @@ public async Task<string> ExecuteChatGptQuery(ulong guildId, ulong channelId, ul
{
Messages = messages,
MaxTokens = 500,
Model = OpenAI.GPT3.ObjectModels.Models.ChatGpt3_5Turbo
Model = "gpt-4-1106-preview",
User = userId
});

// Save new message history and return result
@@ -270,7 +275,7 @@ public async Task HandleReply(SocketCommandContext replyContext)
{
try
{
var result = await ExecuteChatGptQuery(replyContext.Guild.Id, replyContext.Channel.Id, replyContext.Message.ReferencedMessage.Id, replyContext.Message.Content);
var result = await ExecuteChatGptQuery(replyContext.Guild.Id, replyContext.Channel.Id, replyContext.Message.ReferencedMessage.Id, replyContext.Message.Content, replyContext.User.Id.ToString());
var sentMessage = replyContext.Channel.SendMessageAsync($"```\n{result}\n```", messageReference: new MessageReference(replyContext.Message.Id)).Result;
if (!dataService.UpdateChatGptMessageHistoryId(replyContext.Guild.Id, replyContext.Channel.Id, replyContext.Message.ReferencedMessage.Id,
sentMessage.Id))
1 change: 1 addition & 0 deletions MariBot-Core/appsettings.json
@@ -11,6 +11,7 @@
"GoogleCloudKey": "",
"GoogleCustomSearchId": "",
"OpenAiApiKey": "",
"OpenAiOrganization": "",
"WolframAlphaAppId": "",
"TwitterApiKey": "",
"TwitterApiKeySecret": "",
4 changes: 2 additions & 2 deletions MariBot-Core/dynamic-config.json
@@ -21,13 +21,13 @@
"Id": 297485054836342786,
"Name": "CSS",
"EnabledFeatures": [ "auto-image-conversion", "latex", "emoji-triggers", "auto-vxtwitter", "auto-ddinstagram", "auto-vxtiktok", "auto-rxddit" ],
"BlockedTextCommands": [ "danbooru", "gelbooru", "konachan", "realbooru", "r34", "safebooru", "sakugabooru", "sankakucomplex", "xbooru", "yandere", "e2h", "gpt3", "chatgpt", "gpt4" ]
"BlockedTextCommands": [ "danbooru", "gelbooru", "konachan", "realbooru", "r34", "safebooru", "sakugabooru", "sankakucomplex", "xbooru", "yandere", "e2h", "gpt3", "chatgpt", "gpt4", "dalle", "dallehd" ]
},
{
"Id": 829910467622338580,
"Name": "MarukiCountry",
"EnabledFeatures": [ "emoji-triggers", "auto-image-conversion", "auto-vxtwitter", "auto-ddinstagram", "auto-vxtiktok", "auto-rxddit" ],
"BlockedTextCommands": [ "danbooru", "gelbooru", "konachan", "realbooru", "r34", "safebooru", "sakugabooru", "sankakucomplex", "xbooru", "yandere" ],
"BlockedTextCommands": [ "danbooru", "gelbooru", "konachan", "realbooru", "r34", "safebooru", "sakugabooru", "sankakucomplex", "xbooru", "yandere"],
"AutoReactions": [
{
"TriggerWords": [ "cowboy", "cowboys", "dak", "zeke", "lamb", "diggs", "parsons", "ring", "dallas", "amari" ],