diff --git a/Demo/App/APIProvidedView.swift b/Demo/App/APIProvidedView.swift
index 25e7ad87..e13c3f8a 100644
--- a/Demo/App/APIProvidedView.swift
+++ b/Demo/App/APIProvidedView.swift
@@ -42,7 +42,9 @@ struct APIProvidedView: View {
             miscStore: miscStore
         )
         .onChange(of: apiKey) { newApiKey in
-            chatStore.openAIClient = OpenAI(apiToken: newApiKey)
+            let client = OpenAI(apiToken: newApiKey)
+            chatStore.openAIClient = client
+            miscStore.openAIClient = client
         }
     }
 }
diff --git a/Demo/DemoChat/Sources/ChatStore.swift b/Demo/DemoChat/Sources/ChatStore.swift
index a24b095d..51ee6b11 100644
--- a/Demo/DemoChat/Sources/ChatStore.swift
+++ b/Demo/DemoChat/Sources/ChatStore.swift
@@ -85,22 +85,53 @@ public final class ChatStore: ObservableObject {
             return
         }
 
+        let weatherFunction = ChatFunctionDeclaration(
+            name: "getWeatherData",
+            description: "Get the current weather in a given location",
+            parameters: .init(
+                type: .object,
+                properties: [
+                    "location": .init(type: .string, description: "The city and state, e.g. San Francisco, CA")
+                ],
+                required: ["location"]
+            )
+        )
+
+        let functions = [weatherFunction]
+
         let chatsStream: AsyncThrowingStream<ChatStreamResult, Error> = openAIClient.chatsStream(
             query: ChatQuery(
                 model: model,
                 messages: conversation.messages.map { message in
                     Chat(role: message.role, content: message.content)
-                }
+                },
+                functions: functions
             )
         )
 
+        var functionCallName = ""
+        var functionCallArguments = ""
         for try await partialChatResult in chatsStream {
             for choice in partialChatResult.choices {
                 let existingMessages = conversations[conversationIndex].messages
+                // Function calls are also streamed, so we need to accumulate.
+                if let functionCallDelta = choice.delta.functionCall {
+                    if let nameDelta = functionCallDelta.name {
+                        functionCallName += nameDelta
+                    }
+                    if let argumentsDelta = functionCallDelta.arguments {
+                        functionCallArguments += argumentsDelta
+                    }
+                }
+                var messageText = choice.delta.content ?? ""
+                if let finishReason = choice.finishReason,
+                   finishReason == "function_call" {
+                    messageText += "Function call: name=\(functionCallName) arguments=\(functionCallArguments)"
+                }
                 let message = Message(
                     id: partialChatResult.id,
                     role: choice.delta.role ?? .assistant,
-                    content: choice.delta.content ?? "",
+                    content: messageText,
                     createdAt: Date(timeIntervalSince1970: TimeInterval(partialChatResult.created))
                 )
                 if let existingMessageIndex = existingMessages.firstIndex(where: { $0.id == partialChatResult.id }) {
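An aside on the demo code above: it stops at rendering the accumulated call as text and never executes anything. Below is a minimal sketch (not part of this diff) of the missing round trip, assuming a hypothetical `WeatherArgs` payload and a canned weather string in place of a real lookup; the `Chat` and `ChatQuery` signatures used are the ones added in `ChatQuery.swift` further down.

```swift
import Foundation
import OpenAI

// Hypothetical arguments payload matching the `getWeatherData` schema above.
struct WeatherArgs: Decodable { let location: String }

// Once the stream finishes with `finishReason == "function_call"`, decode the
// accumulated arguments, run the local function, and send its output back as a
// `.function` message so the model can produce a natural-language answer.
func completeFunctionCall(
    client: OpenAI,
    model: Model,
    history: [Chat],
    functions: [ChatFunctionDeclaration],
    name: String,
    arguments: String
) async throws -> ChatResult {
    let args = try JSONDecoder().decode(WeatherArgs.self, from: Data(arguments.utf8))
    let report = "72°F and sunny in \(args.location)" // stand-in for a real weather lookup
    // OpenAI's guide also echoes the assistant's `function_call` message into the
    // history at this point; `ChatFunctionCall` gets no public initializer in this
    // diff, so this sketch sends only the function result.
    let query = ChatQuery(
        model: model,
        messages: history + [Chat(role: .function, content: report, name: name)],
        functions: functions
    )
    return try await client.chats(query: query)
}
```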
diff --git a/Demo/DemoChat/Sources/UI/DetailView.swift b/Demo/DemoChat/Sources/UI/DetailView.swift
index cd8529ac..55ff60af 100644
--- a/Demo/DemoChat/Sources/UI/DetailView.swift
+++ b/Demo/DemoChat/Sources/UI/DetailView.swift
@@ -17,9 +17,9 @@ struct DetailView: View {
     @State var inputText: String = ""
     @FocusState private var isFocused: Bool
     @State private var showsModelSelectionSheet = false
-    @State private var selectedChatModel: Model = .gpt3_5Turbo
+    @State private var selectedChatModel: Model = .gpt4_0613
 
-    private let availableChatModels: [Model] = [.gpt3_5Turbo, .gpt4]
+    private let availableChatModels: [Model] = [.gpt3_5Turbo0613, .gpt4_0613]
 
     let conversation: Conversation
     let error: Error?
@@ -237,6 +237,14 @@ struct ChatBubble: View {
                     .foregroundColor(userForegroundColor)
                     .background(userBackgroundColor)
                     .clipShape(RoundedRectangle(cornerRadius: 16, style: .continuous))
+            case .function:
+                Text(message.content)
+                    .font(.footnote.monospaced())
+                    .padding(.horizontal, 16)
+                    .padding(.vertical, 12)
+                    .background(assistantBackgroundColor)
+                    .clipShape(RoundedRectangle(cornerRadius: 16, style: .continuous))
+                Spacer(minLength: 24)
             case .system:
                 EmptyView()
             }
@@ -252,7 +260,14 @@ struct DetailView_Previews: PreviewProvider {
                 messages: [
                     Message(id: "1", role: .assistant, content: "Hello, how can I help you today?", createdAt: Date(timeIntervalSinceReferenceDate: 0)),
                     Message(id: "2", role: .user, content: "I need help with my subscription.", createdAt: Date(timeIntervalSinceReferenceDate: 100)),
-                    Message(id: "3", role: .assistant, content: "Sure, what seems to be the problem with your subscription?", createdAt: Date(timeIntervalSinceReferenceDate: 200))
+                    Message(id: "3", role: .assistant, content: "Sure, what seems to be the problem with your subscription?", createdAt: Date(timeIntervalSinceReferenceDate: 200)),
+                    Message(id: "4", role: .function, content:
+                        """
+                        get_current_weather({
+                          "location": "Glasgow, Scotland",
+                          "format": "celsius"
+                        })
+                        """, createdAt: Date(timeIntervalSinceReferenceDate: 200))
                 ]
             ),
             error: nil,
diff --git a/README.md b/README.md
index d130c5fe..0e30cd22 100644
--- a/README.md
+++ b/README.md
@@ -208,6 +208,8 @@ Using the OpenAI Chat API, you can build your own applications with `gpt-3.5-tur
     public let model: Model
     /// The messages to generate chat completions for
    public let messages: [Chat]
+    /// A list of functions the model may generate JSON inputs for.
+    public let functions: [ChatFunctionDeclaration]?
     /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend altering this or top_p but not both.
     public let temperature: Double?
     /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
@@ -318,6 +320,61 @@ for try await result in openAI.chatsStream(query: query) {
 }
 ```
 
+**Function calls**
+```swift
+let openAI = OpenAI(apiToken: "...")
+// Declare functions which the model might decide to call.
+let functions = [
+  ChatFunctionDeclaration(
+      name: "get_current_weather",
+      description: "Get the current weather in a given location",
+      parameters:
+        JSONSchema(
+          type: .object,
+          properties: [
+            "location": .init(type: .string, description: "The city and state, e.g. San Francisco, CA"),
+            "unit": .init(type: .string, enumValues: ["celsius", "fahrenheit"])
+          ],
+          required: ["location"]
+        )
+  )
+]
+let query = ChatQuery(
+  model: "gpt-3.5-turbo-0613",  // 0613 is the earliest version with function calls support.
+  messages: [
+      Chat(role: .user, content: "What's the weather like in Boston?")
+  ],
+  functions: functions
+)
+let result = try await openAI.chats(query: query)
+```
+
+The result will be (serialized as JSON here for readability):
+```json
+{
+  "id": "chatcmpl-1234",
+  "object": "chat.completion",
+  "created": 1686000000,
+  "model": "gpt-3.5-turbo-0613",
+  "choices": [
+    {
+      "index": 0,
+      "message": {
+        "role": "assistant",
+        "function_call": {
+          "name": "get_current_weather",
+          "arguments": "{\n  \"location\": \"Boston, MA\"\n}"
+        }
+      },
+      "finish_reason": "function_call"
+    }
+  ],
+  "usage": { "total_tokens": 100, "completion_tokens": 18, "prompt_tokens": 82 }
+}
+```
+
 Review [Chat Documentation](https://platform.openai.com/docs/guides/chat) for more info.
 
 ### Images
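A note on consuming that response (a sketch, not part of the diff): continuing from `result` in the README example above, the call surfaces through the `functionCall` property added to `Chat` below, and the arguments arrive as a JSON string that should be validated before use.

```swift
if let call = result.choices.first?.message.functionCall,
   call.name == "get_current_weather",
   let arguments = call.arguments {
    // The model may emit invalid JSON or hallucinate parameters,
    // so parse defensively rather than force-unwrapping.
    let params = (try? JSONSerialization.jsonObject(with: Data(arguments.utf8))) as? [String: Any]
    if let location = params?["location"] as? String {
        // e.g. "Boston, MA" — look up the real weather here, then send it back in a
        // follow-up query as Chat(role: .function, content: report, name: call.name).
    }
}
```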
diff --git a/Sources/OpenAI/Public/Models/ChatQuery.swift b/Sources/OpenAI/Public/Models/ChatQuery.swift
index d7d4d8fc..1ca78826 100644
--- a/Sources/OpenAI/Public/Models/ChatQuery.swift
+++ b/Sources/OpenAI/Public/Models/ChatQuery.swift
@@ -9,25 +9,188 @@ import Foundation
 
 public struct Chat: Codable, Equatable {
     public let role: Role
-    public let content: String
+    /// The contents of the message. `content` is required for all messages except assistant messages with function calls.
+    public let content: String?
+    /// The name of the author of this message. `name` is required if role is `function`, and it should be the name of the function whose response is in the `content`. May contain a-z, A-Z, 0-9, and underscores, with a maximum length of 64 characters.
+    public let name: String?
+    public let functionCall: ChatFunctionCall?
 
     public enum Role: String, Codable, Equatable {
         case system
         case assistant
         case user
+        case function
     }
-
-    public init(role: Role, content: String) {
+
+    enum CodingKeys: String, CodingKey {
+        case role
+        case content
+        case name
+        case functionCall = "function_call"
+    }
+
+    public init(role: Role, content: String? = nil, name: String? = nil, functionCall: ChatFunctionCall? = nil) {
         self.role = role
         self.content = content
+        self.name = name
+        self.functionCall = functionCall
+    }
+}
+
+public struct ChatFunctionCall: Codable, Equatable {
+    /// The name of the function to call.
+    public let name: String?
+    /// The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function.
+    public let arguments: String?
+}
+
+/// See the [guide](/docs/guides/gpt/function-calling) for examples, and the [JSON Schema reference](https://json-schema.org/understanding-json-schema/) for documentation about the format.
+public struct JSONSchema: Codable, Equatable {
+    public let type: JSONType
+    public let properties: [String: Property]?
+    public let required: [String]?
+    public let pattern: String?
+    public let const: String?
+    public let enumValues: [String]?
+    public let multipleOf: Int?
+    public let minimum: Int?
+    public let maximum: Int?
+
+    private enum CodingKeys: String, CodingKey {
+        case type, properties, required, pattern, const
+        case enumValues = "enum"
+        case multipleOf, minimum, maximum
+    }
+
+    public struct Property: Codable, Equatable {
+        public let type: JSONType
+        public let description: String?
+        public let format: String?
+        public let items: Items?
+        public let required: [String]?
+        public let pattern: String?
+        public let const: String?
+        public let enumValues: [String]?
+        public let multipleOf: Int?
+        public let minimum: Double?
+        public let maximum: Double?
+        public let minItems: Int?
+        public let maxItems: Int?
+        public let uniqueItems: Bool?
+
+        private enum CodingKeys: String, CodingKey {
+            case type, description, format, items, required, pattern, const
+            case enumValues = "enum"
+            case multipleOf, minimum, maximum
+            case minItems, maxItems, uniqueItems
+        }
+
+        public init(type: JSONType, description: String? = nil, format: String? = nil, items: Items? = nil, required: [String]? = nil, pattern: String? = nil, const: String? = nil, enumValues: [String]? = nil, multipleOf: Int? = nil, minimum: Double? = nil, maximum: Double? = nil, minItems: Int? = nil, maxItems: Int? = nil, uniqueItems: Bool? = nil) {
+            self.type = type
+            self.description = description
+            self.format = format
+            self.items = items
+            self.required = required
+            self.pattern = pattern
+            self.const = const
+            self.enumValues = enumValues
+            self.multipleOf = multipleOf
+            self.minimum = minimum
+            self.maximum = maximum
+            self.minItems = minItems
+            self.maxItems = maxItems
+            self.uniqueItems = uniqueItems
+        }
+    }
+
+    public enum JSONType: String, Codable {
+        case integer = "integer"
+        case string = "string"
+        case boolean = "boolean"
+        case array = "array"
+        case object = "object"
+        case number = "number"
+        case `null` = "null"
+    }
+
+    public struct Items: Codable, Equatable {
+        public let type: JSONType
+        public let properties: [String: Property]?
+        public let pattern: String?
+        public let const: String?
+        public let enumValues: [String]?
+        public let multipleOf: Int?
+        public let minimum: Double?
+        public let maximum: Double?
+        public let minItems: Int?
+        public let maxItems: Int?
+        public let uniqueItems: Bool?
+
+        private enum CodingKeys: String, CodingKey {
+            case type, properties, pattern, const
+            case enumValues = "enum"
+            case multipleOf, minimum, maximum, minItems, maxItems, uniqueItems
+        }
+
+        public init(type: JSONType, properties: [String : Property]? = nil, pattern: String? = nil, const: String? = nil, enumValues: [String]? = nil, multipleOf: Int? = nil, minimum: Double? = nil, maximum: Double? = nil, minItems: Int? = nil, maxItems: Int? = nil, uniqueItems: Bool? = nil) {
+            self.type = type
+            self.properties = properties
+            self.pattern = pattern
+            self.const = const
+            self.enumValues = enumValues
+            self.multipleOf = multipleOf
+            self.minimum = minimum
+            self.maximum = maximum
+            self.minItems = minItems
+            self.maxItems = maxItems
+            self.uniqueItems = uniqueItems
+        }
+    }
+
+    public init(type: JSONType, properties: [String : Property]? = nil, required: [String]? = nil, pattern: String? = nil, const: String? = nil, enumValues: [String]? = nil, multipleOf: Int? = nil, minimum: Int? = nil, maximum: Int? = nil) {
+        self.type = type
+        self.properties = properties
+        self.required = required
+        self.pattern = pattern
+        self.const = const
+        self.enumValues = enumValues
+        self.multipleOf = multipleOf
+        self.minimum = minimum
+        self.maximum = maximum
+    }
+}
+
+public struct ChatFunctionDeclaration: Codable, Equatable {
+    /// The name of the function to be called. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64.
+    public let name: String
+
+    /// The description of what the function does.
+    public let description: String
+
+    /// The parameters the function accepts, described as a JSON Schema object.
+    public let parameters: JSONSchema
+
+    public init(name: String, description: String, parameters: JSONSchema) {
+        self.name = name
+        self.description = description
+        self.parameters = parameters
+    }
+}
+
+public struct ChatQueryFunctionCall: Codable, Equatable {
+    /// The name of the function to call.
+    public let name: String?
+    /// The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function.
+    public let arguments: String?
 }
 
-public struct ChatQuery: Codable, Streamable {
+public struct ChatQuery: Equatable, Codable, Streamable {
     /// ID of the model to use. Currently, only gpt-3.5-turbo and gpt-3.5-turbo-0301 are supported.
     public let model: Model
     /// The messages to generate chat completions for
     public let messages: [Chat]
+    /// A list of functions the model may generate JSON inputs for.
+    public let functions: [ChatFunctionDeclaration]?
     /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend altering this or top_p but not both.
     public let temperature: Double?
     /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
@@ -52,6 +215,7 @@ public struct ChatQuery: Codable, Streamable {
     enum CodingKeys: String, CodingKey {
         case model
         case messages
+        case functions
         case temperature
         case topP = "top_p"
         case n
@@ -64,9 +228,10 @@ public struct ChatQuery: Codable, Streamable {
         case user
     }
 
-    public init(model: Model, messages: [Chat], temperature: Double? = nil, topP: Double? = nil, n: Int? = nil, stop: [String]? = nil, maxTokens: Int? = nil, presencePenalty: Double? = nil, frequencyPenalty: Double? = nil, logitBias: [String : Int]? = nil, user: String? = nil) {
+    public init(model: Model, messages: [Chat], functions: [ChatFunctionDeclaration]? = nil, temperature: Double? = nil, topP: Double? = nil, n: Int? = nil, stop: [String]? = nil, maxTokens: Int? = nil, presencePenalty: Double? = nil, frequencyPenalty: Double? = nil, logitBias: [String : Int]? = nil, user: String? = nil, stream: Bool = false) {
         self.model = model
         self.messages = messages
+        self.functions = functions
         self.temperature = temperature
         self.topP = topP
         self.n = n
@@ -76,5 +241,6 @@ public struct ChatQuery: Codable, Streamable {
         self.frequencyPenalty = frequencyPenalty
         self.logitBias = logitBias
         self.user = user
+        self.stream = stream
     }
 }
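The `JSONSchema` type above carries `items`, `minItems`/`maxItems`, and nested `Property` support that neither the README nor the tests exercise. A small hypothetical declaration (illustrative only; `search_products` and its fields are invented) showing how an array parameter could be described with them:

```swift
// Hypothetical function declaration exercising the array/Items support.
let searchFunction = ChatFunctionDeclaration(
    name: "search_products",
    description: "Search the catalog for products matching the given tags",
    parameters: JSONSchema(
        type: .object,
        properties: [
            "tags": .init(
                type: .array,
                description: "Tags to filter by, e.g. [\"outdoor\", \"waterproof\"]",
                items: JSONSchema.Items(type: .string),
                minItems: 1,
                maxItems: 10
            ),
            "limit": .init(type: .integer, description: "Maximum number of results to return")
        ],
        required: ["tags"]
    )
)
```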
+
+        enum CodingKeys: String, CodingKey {
+            case role
+            case content
+            case name
+            case functionCall = "function_call"
+        }
     }
 
     public let index: Int
diff --git a/Tests/OpenAITests/OpenAITestsDecoder.swift b/Tests/OpenAITests/OpenAITestsDecoder.swift
index f2111532..bb064261 100644
--- a/Tests/OpenAITests/OpenAITestsDecoder.swift
+++ b/Tests/OpenAITests/OpenAITestsDecoder.swift
@@ -23,6 +23,10 @@ class OpenAITestsDecoder: XCTestCase {
         XCTAssertEqual(decoded, expectedValue)
     }
 
+    func jsonDataAsNSDictionary(_ data: Data) throws -> NSDictionary {
+        return NSDictionary(dictionary: try JSONSerialization.jsonObject(with: data, options: []) as! [String: Any])
+    }
+
     func testCompletions() async throws {
         let data = """
         {
@@ -102,7 +106,106 @@ class OpenAITestsDecoder: XCTestCase {
             ], usage: .init(promptTokens: 9, completionTokens: 12, totalTokens: 21))
         try decode(data, expectedValue)
     }
-
+
+    func testChatQueryWithFunctionCall() async throws {
+        let chatQuery = ChatQuery(
+            model: .gpt3_5Turbo,
+            messages: [
+                Chat(role: .user, content: "What's the weather like in Boston?")
+            ],
+            functions: [
+                ChatFunctionDeclaration(
+                    name: "get_current_weather",
+                    description: "Get the current weather in a given location",
+                    parameters:
+                        JSONSchema(
+                            type: .object,
+                            properties: [
+                                "location": .init(type: .string, description: "The city and state, e.g. San Francisco, CA"),
+                                "unit": .init(type: .string, enumValues: ["celsius", "fahrenheit"])
+                            ],
+                            required: ["location"]
+                        )
+                )
+            ]
+        )
+        let expectedValue = """
+        {
+          "model": "gpt-3.5-turbo",
+          "messages": [
+            { "role": "user", "content": "What's the weather like in Boston?" }
+          ],
+          "functions": [
+            {
+              "name": "get_current_weather",
+              "description": "Get the current weather in a given location",
+              "parameters": {
+                "type": "object",
+                "properties": {
+                  "location": {
+                    "type": "string",
+                    "description": "The city and state, e.g. San Francisco, CA"
+                  },
+                  "unit": { "type": "string", "enum": ["celsius", "fahrenheit"] }
+                },
+                "required": ["location"]
+              }
+            }
+          ],
+          "stream": false
+        }
+        """
+
+        // To compare serialized JSONs we first convert them both into NSDictionary, which are comparable (unlike native Swift dictionaries).
+        let chatQueryAsDict = try jsonDataAsNSDictionary(JSONEncoder().encode(chatQuery))
+        let expectedValueAsDict = try jsonDataAsNSDictionary(expectedValue.data(using: .utf8)!)
+
+        XCTAssertEqual(chatQueryAsDict, expectedValueAsDict)
+    }
+
+    func testChatCompletionWithFunctionCall() async throws {
+        let data = """
+        {
+          "id": "chatcmpl-1234",
+          "object": "chat.completion",
+          "created": 1677652288,
+          "model": "gpt-3.5-turbo",
+          "choices": [
+            {
+              "index": 0,
+              "message": {
+                "role": "assistant",
+                "content": null,
+                "function_call": {
+                  "name": "get_current_weather"
+                }
+              },
+              "finish_reason": "function_call"
+            }
+          ],
+          "usage": {
+            "prompt_tokens": 82,
+            "completion_tokens": 18,
+            "total_tokens": 100
+          }
+        }
+        """
+
+        let expectedValue = ChatResult(
+            id: "chatcmpl-1234",
+            object: "chat.completion",
+            created: 1677652288,
+            model: .gpt3_5Turbo,
+            choices: [
+                .init(index: 0,
+                      message: Chat(role: .assistant,
+                                    functionCall: ChatFunctionCall(name: "get_current_weather", arguments: nil)),
+                      finishReason: "function_call")
+            ],
+            usage: .init(promptTokens: 82, completionTokens: 18, totalTokens: 100))
+        try decode(data, expectedValue)
+    }
+
     func testEdits() async throws {
         let data = """
         {