package chat;

import com.cjcrafter.openai.OpenAI;
import com.cjcrafter.openai.chat.ChatMessage;
import com.cjcrafter.openai.chat.ChatRequest;
import com.cjcrafter.openai.chat.ChatResponseChunk;
import com.cjcrafter.openai.chat.ChatUser;
import com.cjcrafter.openai.chat.tool.*;
import com.cjcrafter.openai.exception.HallucinationException;
import com.fasterxml.jackson.databind.JsonNode;
import io.github.cdimascio.dotenv.Dotenv;
import org.mariuszgromada.math.mxparser.Expression;
import org.mariuszgromada.math.mxparser.License;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Scanner;

/**
 * In this Java example, we will be using the Chat API to create a simple chatbot.
 * Instead of waiting for the full response to be generated, we will "stream" tokens
 * one by one as they are generated. We will also add a math tool so that the chatbot
 * can solve math problems with a math parser.
 */
public class StreamChatCompletionFunction {

    public static void main(String[] args) {

        // mXparser asks us to confirm how we are using it before evaluating expressions
        License.iConfirmNonCommercialUse("CJCrafter");
        // To use dotenv, you need to add the "io.github.cdimascio:dotenv-kotlin:version"
        // dependency. Then you can add a .env file in your project directory.
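        // For example, the .env file only needs a single line (the value shown is a
        // placeholder; use your own API key):
        //   OPENAI_TOKEN=sk-your-key-here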
        String key = Dotenv.load().get("OPENAI_TOKEN");
        OpenAI openai = OpenAI.builder()
                .apiKey(key)
                .build();

        // Notice that this is a *mutable* list. We will be adding messages later
        // so we can continue the conversation.
        List<ChatMessage> messages = new ArrayList<>();
        messages.add(ChatMessage.toSystemMessage("Help the user with their problem."));

        // Here you can change the model's settings, add tools, and more.
        ChatRequest request = ChatRequest.builder()
                .model("gpt-3.5-turbo")
                .messages(messages)
                .addTool(FunctionTool.builder()
                        .name("solve_math_problem")
                        .description("Returns the result of a math problem as a double")
                        .addStringParameter("equation", "The math problem for you to solve", true)
                        .build()
                )
                .build();
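        // For reference, the tool above is sent to the API using the standard
        // function-calling JSON schema, which looks roughly like this (an illustration
        // of the format, not the library's exact serialization):
        //   {"type": "function", "function": {"name": "solve_math_problem",
        //     "description": "Returns the result of a math problem as a double",
        //     "parameters": {"type": "object",
        //       "properties": {"equation": {"type": "string",
        //         "description": "The math problem for you to solve"}},
        //       "required": ["equation"]}}}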

        Scanner scan = new Scanner(System.in);
        while (true) {
            System.out.println("What are you having trouble with?");
            String input = scan.nextLine();

            messages.add(ChatMessage.toUserMessage(input));
            System.out.println("Generating Response...");

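            // Tool-call flow below: stream the assistant's reply; if that reply requests
            // a tool call, run the tool locally, append the result as a TOOL message, and
            // re-send the request so the model can produce its final answer.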
            boolean madeToolCall;
            do {
                madeToolCall = false;
                for (ChatResponseChunk chunk : openai.streamChatCompletion(request)) {
                    String delta = chunk.get(0).getDeltaContent();
                    if (delta != null)
                        System.out.print(delta);

                    // When the response is finished, we can add it to the messages list.
                    if (chunk.get(0).isFinished())
                        messages.add(chunk.get(0).getMessage());
                }

                // If the API returned a tool call to us, we need to handle it.
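                // A tool call from the model looks roughly like this (the id is made up
                // for illustration; the arguments arrive as a JSON string):
                //   {"id": "call_abc123", "type": "function",
                //    "function": {"name": "solve_math_problem",
                //                 "arguments": "{\"equation\": \"2 + 2\"}"}}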
                List<ToolCall> toolCalls = messages.get(messages.size() - 1).getToolCalls();
                if (toolCalls != null) {
                    madeToolCall = true;
                    for (ToolCall call : toolCalls) {
                        ChatMessage response = handleToolCall(call, request.getTools());
                        messages.add(response);
                    }
                }

                // Loop until we get a message without tool calls
            } while (madeToolCall);

            // Print a new line to separate the messages
            System.out.println();
        }
    }

    public static ChatMessage handleToolCall(ToolCall call, List<Tool> validTools) {
        // The try-catch here is *crucial*. ChatGPT isn't very good at tool calls
        // (and you probably aren't very good at prompt engineering yet!), so OpenAI
        // will often "hallucinate" arguments.
        try {
            if (call.getType() != ToolType.FUNCTION)
                throw new HallucinationException("Unknown tool call type: " + call.getType());

            FunctionCall function = call.getFunction();
            Map<String, JsonNode> arguments = function.tryParseArguments(validTools); // You can pass null here for less strict parsing
            String equation = arguments.get("equation").asText();
            double result = solveEquation(equation);

            // NaN implies that the equation was invalid
            if (Double.isNaN(result))
                throw new HallucinationException("Format was invalid: " + equation);

            // Return the result as a tool message; the caller adds it to the conversation
            String json = "{\"result\": " + result + "}";
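            // For example, if the model asked us to solve "2 + 2", this produces {"result": 4.0}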
            return new ChatMessage(ChatUser.TOOL, json, null, call.getId());

        } catch (HallucinationException ex) {

            // Let's let ChatGPT know it made a mistake so it can correct itself
            String json = "{\"error\": \"" + ex.getMessage() + "\"}";
            return new ChatMessage(ChatUser.TOOL, json, null, call.getId());
        }
    }

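    // Evaluates the expression with mXparser, e.g. "2 + 2 * 3" evaluates to 8.0.
    // mXparser returns Double.NaN when the expression cannot be parsed, which the
    // caller above treats as a hallucinated/invalid equation.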
    public static double solveEquation(String equation) {
        Expression expression = new Expression(equation);
        return expression.calculate();
    }
}