Skip to content

Commit

Permalink
.Net: Sanitize function names (#10257)
Browse files Browse the repository at this point in the history
### Motivation and Context
Currently, if an AI model hallucinates a function name (e.g., `bar.foo`
instead of the advertised `bar-foo`) that contains disallowed characters
(not in the range of `a-zA-Z0-9_-`), SK identifies it as an error case
and sends the error back to the model along with the original function
call. Given that the original function call contains the disallowed
function name, the request to the AI model fails with the error:
_Invalid 'messages[6].tool_calls[0].function.name': string does not
match pattern. Expected a string that matches the pattern
`^[a-zA-Z0-9_-]+$`._

### Description
This PR replaces disallowed characters in function names with an
underscore for all function calls before sending them to the AI model.
This fix prevents the request to the AI model from failing and allows
the model to auto-recover.

More context: [Function Calling Reliability
ADR](https://github.com/microsoft/semantic-kernel/blob/main/docs/decisions/0063-function-calling-reliability.md#function-calling-reliability)
Closes: #9850
  • Loading branch information
SergeyMenshykh authored Jan 22, 2025
1 parent ef28a1e commit e98155a
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,30 @@
using System;
using System.ClientModel;
using System.ClientModel.Primitives;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.OpenAI;
using Microsoft.SemanticKernel.Http;
using Microsoft.SemanticKernel.Services;
using Moq;
using OpenAI;
using OpenAI.Chat;
using Xunit;
using BinaryContent = System.ClientModel.BinaryContent;
using ChatMessageContent = Microsoft.SemanticKernel.ChatMessageContent;

namespace SemanticKernel.Connectors.OpenAI.UnitTests.Core;

public partial class ClientCoreTests
{
[Fact]
Expand Down Expand Up @@ -240,4 +250,90 @@ public void ItDoesNotThrowWhenUsingCustomEndpointAndApiKeyIsNotProvided()
clientCore = new ClientCore("modelId", "", endpoint: new Uri("http://localhost"));
clientCore = new ClientCore("modelId", apiKey: null!, endpoint: new Uri("http://localhost"));
}

[Theory]
[ClassData(typeof(ChatMessageContentWithFunctionCalls))]
public async Task ItShouldReplaceDisallowedCharactersInFunctionName(ChatMessageContent chatMessageContent, bool nameContainsDisallowedCharacter)
{
    // Arrange: a stubbed HTTP pipeline that answers with a canned chat completion
    // and records the outgoing request body for inspection.
    using var responseMessage = new HttpResponseMessage(HttpStatusCode.OK)
    {
        Content = new StringContent(File.ReadAllText("TestData/chat_completion_test_response.json"))
    };

    using HttpMessageHandlerStub handler = new();
    handler.ResponseToReturn = responseMessage;
    using HttpClient client = new(handler);

    var sut = new ClientCore("modelId", "apikey", httpClient: client);
    var history = new ChatHistory { chatMessageContent };

    // Act
    await sut.GetChatMessageContentsAsync("gpt-4", history, new OpenAIPromptExecutionSettings(), new Kernel());

    // Assert: read the function name that was actually serialized into the request.
    JsonElement requestBody = JsonSerializer.Deserialize<JsonElement>(handler.RequestContent);

    string? sentFunctionName = requestBody
        .GetProperty("messages")[0]
        .GetProperty("tool_calls")[0]
        .GetProperty("function")
        .GetProperty("name")
        .GetString();

    // "bar.foo" contains the disallowed character '.', so it must be sent as "bar_foo";
    // "bar-foo" contains no disallowed characters and must be sent unchanged.
    string expectedFunctionName = nameContainsDisallowedCharacter ? "bar_foo" : "bar-foo";

    Assert.Equal(expectedFunctionName, sentFunctionName);
}

/// <summary>
/// Theory data for <c>ItShouldReplaceDisallowedCharactersInFunctionName</c>.
/// Each row is a chat message carrying one function call plus a flag telling
/// whether the function name in that call contains a disallowed character.
/// </summary>
internal sealed class ChatMessageContentWithFunctionCalls : TheoryData<ChatMessageContent, bool>
{
    // Function call whose name "bar.foo" contains the disallowed character '.'.
    private static readonly ChatToolCall s_functionCallWithInvalidFunctionName = ChatToolCall.CreateFunctionToolCall(id: "call123", functionName: "bar.foo", functionArguments: BinaryData.FromString("{}"));

    // Function call whose name "bar-foo" contains only allowed characters.
    private static readonly ChatToolCall s_functionCallWithValidFunctionName = ChatToolCall.CreateFunctionToolCall(id: "call123", functionName: "bar-foo", functionArguments: BinaryData.FromString("{}"));

    /// <summary>
    /// Registers the rows for both the invalid-name and the valid-name function calls,
    /// so that both branches of the consuming theory are exercised.
    /// </summary>
    public ChatMessageContentWithFunctionCalls()
    {
        this.AddMessagesWithFunctionCallsWithInvalidFunctionName();

        // Without these rows the "name is left untouched" branch of the theory never runs.
        this.AddMessagesWithFunctionCallsWithValidFunctionName();
    }

    /// <summary>
    /// Adds one row per supported way of attaching a function call to a message,
    /// using the call whose name contains a disallowed character.
    /// </summary>
    private void AddMessagesWithFunctionCallsWithInvalidFunctionName()
    {
        // Case when function calls are available via the `Tools` property.
        this.Add(new OpenAIChatMessageContent(AuthorRole.Assistant, "", "", [s_functionCallWithInvalidFunctionName]), true);

        // Case when function calls are available via the `ChatResponseMessage.FunctionToolCalls` metadata as an array of ChatToolCall type.
        this.Add(new ChatMessageContent(AuthorRole.Assistant, "", metadata: new Dictionary<string, object?>()
        {
            [OpenAIChatMessageContent.FunctionToolCallsProperty] = new ChatToolCall[] { s_functionCallWithInvalidFunctionName }
        }), true);

        // Case when function calls are available via the `ChatResponseMessage.FunctionToolCalls` metadata as an array of JsonElement type.
        this.Add(new ChatMessageContent(AuthorRole.Assistant, "", metadata: new Dictionary<string, object?>()
        {
            [OpenAIChatMessageContent.FunctionToolCallsProperty] = JsonSerializer.Deserialize<JsonElement>($$"""[{"Id": "{{s_functionCallWithInvalidFunctionName.Id}}", "Name": "{{s_functionCallWithInvalidFunctionName.FunctionName}}", "Arguments": "{{s_functionCallWithInvalidFunctionName.FunctionArguments}}"}]""")
        }), true);
    }

    /// <summary>
    /// Adds one row per supported way of attaching a function call to a message,
    /// using the call whose name contains only allowed characters.
    /// </summary>
    private void AddMessagesWithFunctionCallsWithValidFunctionName()
    {
        // Case when function calls are available via the `Tools` property.
        this.Add(new OpenAIChatMessageContent(AuthorRole.Assistant, "", "", [s_functionCallWithValidFunctionName]), false);

        // Case when function calls are available via the `ChatResponseMessage.FunctionToolCalls` metadata as an array of ChatToolCall type.
        this.Add(new ChatMessageContent(AuthorRole.Assistant, "", metadata: new Dictionary<string, object?>()
        {
            [OpenAIChatMessageContent.FunctionToolCallsProperty] = new ChatToolCall[] { s_functionCallWithValidFunctionName }
        }), false);

        // Case when function calls are available via the `ChatResponseMessage.FunctionToolCalls` metadata as an array of JsonElement type.
        this.Add(new ChatMessageContent(AuthorRole.Assistant, "", metadata: new Dictionary<string, object?>()
        {
            [OpenAIChatMessageContent.FunctionToolCallsProperty] = JsonSerializer.Deserialize<JsonElement>($$"""[{"Id": "{{s_functionCallWithValidFunctionName.Id}}", "Name": "{{s_functionCallWithValidFunctionName.FunctionName}}", "Arguments": "{{s_functionCallWithValidFunctionName.FunctionArguments}}"}]""")
        }), false);
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using System.Runtime.CompilerServices;
using System.Text;
using System.Text.Json;
using System.Text.RegularExpressions;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
Expand All @@ -26,6 +27,13 @@ namespace Microsoft.SemanticKernel.Connectors.OpenAI;
/// </summary>
internal partial class ClientCore
{
// Matches any single character that is not allowed in a function name,
// i.e. anything outside the set a-z, A-Z, 0-9, '_' and '-'.
#if NET
[GeneratedRegex("[^a-zA-Z0-9_-]")]
private static partial Regex DisallowedFunctionNameCharactersRegex();
#else
// Build the compiled regex once and reuse it: constructing a new
// RegexOptions.Compiled instance on every call is expensive.
private static readonly Regex s_disallowedFunctionNameCharactersRegex = new("[^a-zA-Z0-9_-]", RegexOptions.Compiled);

private static Regex DisallowedFunctionNameCharactersRegex() => s_disallowedFunctionNameCharactersRegex;
#endif

/// <summary>Name of the model provider; the constant value is "openai".</summary>
// NOTE(review): how this value is consumed is not visible here — confirm against its usages.
protected const string ModelProvider = "openai";
/// <summary>
/// Groups the tool-calling settings into a single value: the tools, the tool choice,
/// two boolean flags and the function choice behavior options.
/// </summary>
// NOTE(review): the exact meaning of AutoInvoke and AllowAnyRequestedKernelFunction is defined
// by the code that builds and reads this record, which is not shown here — TODO confirm.
protected record ToolCallingConfig(IList<ChatTool>? Tools, ChatToolChoice? Choice, bool AutoInvoke, bool AllowAnyRequestedKernelFunction, FunctionChoiceBehaviorOptions? Options);

Expand Down Expand Up @@ -752,7 +760,7 @@ private static List<ChatMessage> CreateRequestMessages(ChatMessageContent messag
return [new AssistantChatMessage(message.Content) { ParticipantName = message.AuthorName }];
}

var assistantMessage = new AssistantChatMessage(toolCalls) { ParticipantName = message.AuthorName };
var assistantMessage = new AssistantChatMessage(SanitizeFunctionNames(toolCalls)) { ParticipantName = message.AuthorName };

// If message content is null, adding it as empty string,
// because chat message content must be string.
Expand Down Expand Up @@ -1054,4 +1062,27 @@ private void ProcessNonFunctionToolCalls(IEnumerable<ChatToolCall> toolCalls, Ch
chatHistory.Add(message);
}
}

/// <summary>
/// Sanitizes function names by replacing each disallowed character with an underscore.
/// </summary>
/// <remarks>
/// The list is modified in place. A tool call whose name changes is replaced by a new
/// function tool call created from the same id and arguments; all other entries are left as they are.
/// </remarks>
/// <param name="toolCalls">The function calls containing the function names which need to be sanitized.</param>
/// <returns>The same <paramref name="toolCalls"/> list instance, with sanitized function names.</returns>
private static List<ChatToolCall> SanitizeFunctionNames(List<ChatToolCall> toolCalls)
{
    // Resolve the regex once up front instead of on every iteration.
    Regex disallowedCharacters = DisallowedFunctionNameCharactersRegex();

    for (int i = 0; i < toolCalls.Count; i++)
    {
        ChatToolCall tool = toolCalls[i];

        // A single Replace pass is enough: when the name has no disallowed characters
        // the result equals the original name and the entry is kept untouched.
        string sanitizedName = disallowedCharacters.Replace(tool.FunctionName, "_");

        if (sanitizedName != tool.FunctionName)
        {
            toolCalls[i] = ChatToolCall.CreateFunctionToolCall(tool.Id, sanitizedName, tool.FunctionArguments);
        }
    }

    return toolCalls;
}
}

0 comments on commit e98155a

Please sign in to comment.