From e98155ab2393cb679dcde76ece50a845f48a35dc Mon Sep 17 00:00:00 2001
From: SergeyMenshykh <68852919+SergeyMenshykh@users.noreply.github.com>
Date: Wed, 22 Jan 2025 16:10:40 +0000
Subject: [PATCH] .Net: Sanitize function names (#10257)

### Motivation and Context

Currently, if an AI model hallucinates a function name (e.g., `bar.foo` instead of the advertised `bar-foo`) that contains disallowed characters (not in the range of `a-zA-Z0-9_-`), SK identifies it as an error case and sends the error back to the model along with the original function call. Given that the original function call contains the disallowed function name, the request to the AI model fails with the error: _Invalid 'messages[6].tool_calls[0].function.name': string does not match pattern. Expected a string that matches the pattern `^[a-zA-Z0-9_-]+$`._

### Description

This PR replaces disallowed characters in function names with an underscore for all function calls before sending them to the AI model. This fix prevents the request to the AI model from failing and allows the model to auto-recover.

More context: [Function Calling Reliability ADR](https://github.com/microsoft/semantic-kernel/blob/main/docs/decisions/0063-function-calling-reliability.md#function-calling-reliability)

Closes: https://github.com/microsoft/semantic-kernel/issues/9850
---
Review notes (not part of the commit message):
 * Restored the generic type arguments and XML doc tags that had been stripped from the diff text.
 * ClientCoreTests: the theory data constructor now also registers the valid-name cases, which were defined but never added.
 * ClientCore: the non-NET regex is cached in a static field instead of being re-created on every call.
 * The first ClientCore.ChatCompletion.cs hunk header was corrected from "+10,7" to "+9,7"; hunk line counts and the
   diffstat were updated for the added lines, the "index" lines were left unchanged.

 .../Core/ClientCoreTests.cs                   | 97 +++++++++++++++++++
 .../Core/ClientCore.ChatCompletion.cs         | 35 ++++++-
 2 files changed, 131 insertions(+), 1 deletion(-)

diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/ClientCoreTests.cs b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/ClientCoreTests.cs
index 8597fb4b9dd9..4aafd3601bb3 100644
--- a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/ClientCoreTests.cs
+++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/ClientCoreTests.cs
@@ -3,20 +3,30 @@
 using System;
 using System.ClientModel;
 using System.ClientModel.Primitives;
+using System.Collections.Generic;
+using System.IO;
 using System.Linq;
+using System.Net;
 using System.Net.Http;
+using System.Text.Json;
 using System.Threading;
 using System.Threading.Tasks;
 using Microsoft.Extensions.Logging;
 using Microsoft.Extensions.Logging.Abstractions;
+using Microsoft.SemanticKernel;
+using Microsoft.SemanticKernel.ChatCompletion;
 using Microsoft.SemanticKernel.Connectors.OpenAI;
 using Microsoft.SemanticKernel.Http;
 using Microsoft.SemanticKernel.Services;
 using Moq;
 using OpenAI;
+using OpenAI.Chat;
 using Xunit;
+using BinaryContent = System.ClientModel.BinaryContent;
+using ChatMessageContent = Microsoft.SemanticKernel.ChatMessageContent;
 
 namespace SemanticKernel.Connectors.OpenAI.UnitTests.Core;
+
 public partial class ClientCoreTests
 {
     [Fact]
@@ -240,4 +250,91 @@ public void ItDoesNotThrowWhenUsingCustomEndpointAndApiKeyIsNotProvided()
         clientCore = new ClientCore("modelId", "", endpoint: new Uri("http://localhost"));
         clientCore = new ClientCore("modelId", apiKey: null!, endpoint: new Uri("http://localhost"));
     }
+
+    [Theory]
+    [ClassData(typeof(ChatMessageContentWithFunctionCalls))]
+    public async Task ItShouldReplaceDisallowedCharactersInFunctionName(ChatMessageContent chatMessageContent, bool nameContainsDisallowedCharacter)
+    {
+        // Arrange
+        using var responseMessage = new HttpResponseMessage(HttpStatusCode.OK)
+        {
+            Content = new StringContent(File.ReadAllText("TestData/chat_completion_test_response.json"))
+        };
+
+        using HttpMessageHandlerStub handler = new();
+        handler.ResponseToReturn = responseMessage;
+        using HttpClient client = new(handler);
+
+        var clientCore = new ClientCore("modelId", "apikey", httpClient: client);
+
+        ChatHistory chatHistory = [chatMessageContent];
+
+        // Act
+        await clientCore.GetChatMessageContentsAsync("gpt-4", chatHistory, new OpenAIPromptExecutionSettings(), new Kernel());
+
+        // Assert
+        JsonElement jsonString = JsonSerializer.Deserialize<JsonElement>(handler.RequestContent);
+
+        var function = jsonString.GetProperty("messages")[0].GetProperty("tool_calls")[0].GetProperty("function");
+
+        if (nameContainsDisallowedCharacter)
+        {
+            // The original name specified in function calls is "bar.foo", which contains a disallowed character '.'.
+            Assert.Equal("bar_foo", function.GetProperty("name").GetString());
+        }
+        else
+        {
+            // The original name specified in function calls is "bar-foo" and contains no disallowed characters.
+            Assert.Equal("bar-foo", function.GetProperty("name").GetString());
+        }
+    }
+
+    internal sealed class ChatMessageContentWithFunctionCalls : TheoryData<ChatMessageContent, bool>
+    {
+        private static readonly ChatToolCall s_functionCallWithInvalidFunctionName = ChatToolCall.CreateFunctionToolCall(id: "call123", functionName: "bar.foo", functionArguments: BinaryData.FromString("{}"));
+
+        private static readonly ChatToolCall s_functionCallWithValidFunctionName = ChatToolCall.CreateFunctionToolCall(id: "call123", functionName: "bar-foo", functionArguments: BinaryData.FromString("{}"));
+
+        public ChatMessageContentWithFunctionCalls()
+        {
+            this.AddMessagesWithFunctionCallsWithInvalidFunctionName();
+            this.AddMessagesWithFunctionCallsWithValidFunctionName();
+        }
+
+        private void AddMessagesWithFunctionCallsWithInvalidFunctionName()
+        {
+            // Case when function calls are available via the `Tools` property.
+            this.Add(new OpenAIChatMessageContent(AuthorRole.Assistant, "", "", [s_functionCallWithInvalidFunctionName]), true);
+
+            // Case when function calls are available via the `ChatResponseMessage.FunctionToolCalls` metadata as an array of ChatToolCall type.
+            this.Add(new ChatMessageContent(AuthorRole.Assistant, "", metadata: new Dictionary<string, object?>()
+            {
+                [OpenAIChatMessageContent.FunctionToolCallsProperty] = new ChatToolCall[] { s_functionCallWithInvalidFunctionName }
+            }), true);
+
+            // Case when function calls are available via the `ChatResponseMessage.FunctionToolCalls` metadata as an array of JsonElement type.
+            this.Add(new ChatMessageContent(AuthorRole.Assistant, "", metadata: new Dictionary<string, object?>()
+            {
+                [OpenAIChatMessageContent.FunctionToolCallsProperty] = JsonSerializer.Deserialize<JsonElement>($$"""[{"Id": "{{s_functionCallWithInvalidFunctionName.Id}}", "Name": "{{s_functionCallWithInvalidFunctionName.FunctionName}}", "Arguments": "{{s_functionCallWithInvalidFunctionName.FunctionArguments}}"}]""")
+            }), true);
+        }
+
+        private void AddMessagesWithFunctionCallsWithValidFunctionName()
+        {
+            // Case when function calls are available via the `Tools` property.
+            this.Add(new OpenAIChatMessageContent(AuthorRole.Assistant, "", "", [s_functionCallWithValidFunctionName]), false);
+
+            // Case when function calls are available via the `ChatResponseMessage.FunctionToolCalls` metadata as an array of ChatToolCall type.
+            this.Add(new ChatMessageContent(AuthorRole.Assistant, "", metadata: new Dictionary<string, object?>()
+            {
+                [OpenAIChatMessageContent.FunctionToolCallsProperty] = new ChatToolCall[] { s_functionCallWithValidFunctionName }
+            }), false);
+
+            // Case when function calls are available via the `ChatResponseMessage.FunctionToolCalls` metadata as an array of JsonElement type.
+            this.Add(new ChatMessageContent(AuthorRole.Assistant, "", metadata: new Dictionary<string, object?>()
+            {
+                [OpenAIChatMessageContent.FunctionToolCallsProperty] = JsonSerializer.Deserialize<JsonElement>($$"""[{"Id": "{{s_functionCallWithValidFunctionName.Id}}", "Name": "{{s_functionCallWithValidFunctionName.FunctionName}}", "Arguments": "{{s_functionCallWithValidFunctionName.FunctionArguments}}"}]""")
+            }), false);
+        }
+    }
 }
diff --git a/dotnet/src/Connectors/Connectors.OpenAI/Core/ClientCore.ChatCompletion.cs b/dotnet/src/Connectors/Connectors.OpenAI/Core/ClientCore.ChatCompletion.cs
index 6b6a039a0acd..129e7913b788 100644
--- a/dotnet/src/Connectors/Connectors.OpenAI/Core/ClientCore.ChatCompletion.cs
+++ b/dotnet/src/Connectors/Connectors.OpenAI/Core/ClientCore.ChatCompletion.cs
@@ -9,6 +9,7 @@
 using System.Runtime.CompilerServices;
 using System.Text;
 using System.Text.Json;
+using System.Text.RegularExpressions;
 using System.Threading;
 using System.Threading.Tasks;
 using Microsoft.Extensions.Logging;
@@ -26,6 +27,15 @@ namespace Microsoft.SemanticKernel.Connectors.OpenAI;
 /// </summary>
 internal partial class ClientCore
 {
+#if NET
+    [GeneratedRegex("[^a-zA-Z0-9_-]")]
+    private static partial Regex DisallowedFunctionNameCharactersRegex();
+#else
+    // Cached so that the compiled regex is constructed once rather than on every call.
+    private static readonly Regex s_disallowedFunctionNameCharactersRegex = new("[^a-zA-Z0-9_-]", RegexOptions.Compiled);
+    private static Regex DisallowedFunctionNameCharactersRegex() => s_disallowedFunctionNameCharactersRegex;
+#endif
+
     protected const string ModelProvider = "openai";
     protected record ToolCallingConfig(IList<ChatTool>? Tools, ChatToolChoice? Choice, bool AutoInvoke, bool AllowAnyRequestedKernelFunction, FunctionChoiceBehaviorOptions? Options);
 
@@ -752,7 +762,7 @@ private static List<ChatMessage> CreateRequestMessages(ChatMessageContent messag
                 return [new AssistantChatMessage(message.Content) { ParticipantName = message.AuthorName }];
             }
 
-            var assistantMessage = new AssistantChatMessage(toolCalls) { ParticipantName = message.AuthorName };
+            var assistantMessage = new AssistantChatMessage(SanitizeFunctionNames(toolCalls)) { ParticipantName = message.AuthorName };
 
             // If message content is null, adding it as empty string,
             // because chat message content must be string.
@@ -1054,4 +1064,27 @@ private void ProcessNonFunctionToolCalls(IEnumerable<ChatToolCall> toolCalls, Ch
             chatHistory.Add(message);
         }
     }
+
+    /// <summary>
+    /// Sanitizes function names by replacing disallowed characters.
+    /// </summary>
+    /// <param name="toolCalls">The function calls containing the function names which need to be sanitized.</param>
+    /// <returns>The function calls with sanitized function names.</returns>
+    private static List<ChatToolCall> SanitizeFunctionNames(List<ChatToolCall> toolCalls)
+    {
+        for (int i = 0; i < toolCalls.Count; i++)
+        {
+            ChatToolCall tool = toolCalls[i];
+
+            // Check if function name contains disallowed characters and replace them with '_'.
+            if (DisallowedFunctionNameCharactersRegex().IsMatch(tool.FunctionName))
+            {
+                var sanitizedName = DisallowedFunctionNameCharactersRegex().Replace(tool.FunctionName, "_");
+
+                toolCalls[i] = ChatToolCall.CreateFunctionToolCall(tool.Id, sanitizedName, tool.FunctionArguments);
+            }
+        }
+
+        return toolCalls;
+    }
 }