Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 68aeb73

Browse files
Merge pull request #18 from ugwun/master
Add Azure OpenAI support
2 parents fc25791 + 620bfc9 commit 68aeb73

File tree

5 files changed

+357
-7
lines changed

5 files changed

+357
-7
lines changed

‎README.md‎

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ A community-maintained easy-to-use Java/Kotlin OpenAI API for ChatGPT, Text Comp
1313
## Features
1414
* [Completions](https://platform.openai.com/docs/api-reference/completions)
1515
* [Chat Completions](https://platform.openai.com/docs/api-reference/chat)
16+
* [Azure OpenAI](https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference) support via `AzureOpenAI` class
1617

1718
## Installation
1819
For Kotlin DSL (`build.gradle.kts`), add this to your dependencies block:
@@ -85,6 +86,7 @@ public class JavaChatTest {
8586
}
8687
}
8788
```
89+
To use the Azure OpenAI API, use the `AzureOpenAI` class instead of `OpenAI`.
8890
> **Note**: OpenAI recommends using environment variables for your API token
8991
([Read more](https://help.openai.com/en/articles/5112595-best-practices-for-api-key-safety)).
9092

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package com.cjcrafter.openai
2+
3+
import okhttp3.OkHttpClient
4+
import okhttp3.Request
5+
import okhttp3.RequestBody
6+
import okhttp3.RequestBody.Companion.toRequestBody
7+
8+
/**
9+
* The Azure OpenAI API client.
10+
*
11+
* See {@link OpenAI} for more information.
12+
*
13+
* This class constructs url in the form of: https://<azureBaseUrl>/openai/deployments/<modelName>/<endpoint>?api-version=<apiVersion>
14+
*
15+
* @property azureBaseUrl The base URL for the Azure OpenAI API. Usually https://<your_resource_group>.openai.azure.com
16+
* @property apiVersion The API version to use. Defaults to 2023年03月15日-preview.
17+
* @property modelName The model name to use. This is the name of the model deployed to Azure.
18+
*/
19+
class AzureOpenAI @JvmOverloads constructor(
20+
apiKey: String,
21+
organization: String? = null,
22+
client: OkHttpClient = OkHttpClient(),
23+
private val azureBaseUrl: String = "",
24+
private val apiVersion: String = "2023年03月15日-preview",
25+
private val modelName: String = ""
26+
) : OpenAI(apiKey, organization, client) {
27+
28+
override fun buildRequest(request: Any, endpoint: String): Request {
29+
val json = gson.toJson(request)
30+
val body: RequestBody = json.toRequestBody(mediaType)
31+
return Request.Builder()
32+
.url("$azureBaseUrl/openai/deployments/$modelName/$endpoint?api-version=$apiVersion")
33+
.addHeader("Content-Type", "application/json")
34+
.addHeader("api-key", apiKey)
35+
.apply { if (organization != null) addHeader("OpenAI-Organization", organization) }
36+
.post(body).build()
37+
}
38+
}

‎src/main/kotlin/com/cjcrafter/openai/OpenAI.kt‎

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,15 +52,15 @@ import java.util.function.Consumer
5252
* @property client Controls proxies, timeouts, etc.
5353
* @constructor Create a ChatBot for responding to requests.
5454
*/
55-
class OpenAI @JvmOverloads constructor(
56-
private val apiKey: String,
57-
private val organization: String? = null,
55+
openclass OpenAI @JvmOverloads constructor(
56+
protected val apiKey: String,
57+
protected val organization: String? = null,
5858
private val client: OkHttpClient = OkHttpClient()
5959
) {
60-
private val mediaType = "application/json; charset=utf-8".toMediaType()
61-
private val gson = createGson()
60+
protected val mediaType = "application/json; charset=utf-8".toMediaType()
61+
protected val gson = createGson()
6262

63-
private fun buildRequest(request: Any, endpoint: String): Request {
63+
protectedopen fun buildRequest(request: Any, endpoint: String): Request {
6464
val json = gson.toJson(request)
6565
val body: RequestBody = json.toRequestBody(mediaType)
6666
return Request.Builder()
@@ -95,7 +95,7 @@ class OpenAI @JvmOverloads constructor(
9595
val httpRequest = buildRequest(request, COMPLETIONS_ENDPOINT)
9696

9797
try {
98-
val httpResponse = client.newCall(httpRequest).execute();
98+
val httpResponse = client.newCall(httpRequest).execute()
9999
lateinit var response: CompletionResponse
100100
OpenAICallback(true, { throw it }) {
101101
response = gson.fromJson(it, CompletionResponse::class.java)

‎src/test/java/JavaTestAzure.java‎

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
import com.cjcrafter.openai.AzureOpenAI;
2+
import com.cjcrafter.openai.OpenAI;
3+
import com.cjcrafter.openai.chat.ChatMessage;
4+
import com.cjcrafter.openai.chat.ChatRequest;
5+
import com.cjcrafter.openai.chat.ChatResponse;
6+
import com.cjcrafter.openai.completions.CompletionRequest;
7+
import com.cjcrafter.openai.exception.OpenAIError;
8+
import io.github.cdimascio.dotenv.Dotenv;
9+
10+
import java.util.ArrayList;
11+
import java.util.Collections;
12+
import java.util.List;
13+
import java.util.Scanner;
14+
15+
16+
public class JavaTestAzure {
17+
18+
// Colors for pretty formatting
19+
public static final String RESET = "033円[0m";
20+
public static final String BLACK = "033円[0;30m";
21+
public static final String RED = "033円[0;31m";
22+
public static final String GREEN = "033円[0;32m";
23+
public static final String YELLOW = "033円[0;33m";
24+
public static final String BLUE = "033円[0;34m";
25+
public static final String PURPLE = "033円[0;35m";
26+
public static final String CYAN = "033円[0;36m";
27+
public static final String WHITE = "033円[0;37m";
28+
29+
public static void main(String[] args) throws OpenAIError {
30+
Scanner scanner = new Scanner(System.in);
31+
32+
// Add test cases for AzureOpenAI
33+
System.out.println(GREEN + " 9. Azure Completion (create, sync)");
34+
System.out.println(GREEN + " 10. Azure Completion (stream, sync)");
35+
System.out.println(GREEN + " 11. Azure Completion (create, async)");
36+
System.out.println(GREEN + " 12. Azure Completion (stream, async)");
37+
System.out.println(GREEN + " 13. Azure Chat (create, sync)");
38+
System.out.println(GREEN + " 14. Azure Chat (stream, sync)");
39+
System.out.println(GREEN + " 15. Azure Chat (create, async)");
40+
System.out.println(GREEN + " 16. Azure Chat (stream, async)");
41+
System.out.println();
42+
43+
// Determine which method to call
44+
switch (scanner.nextLine()) {
45+
// ...
46+
case "9":
47+
doCompletionAzure(false, false);
48+
break;
49+
case "10":
50+
doCompletionAzure(true, false);
51+
break;
52+
case "11":
53+
doCompletionAzure(false, true);
54+
break;
55+
case "12":
56+
doCompletionAzure(true, true);
57+
break;
58+
case "13":
59+
doChatAzure(false, false);
60+
break;
61+
case "14":
62+
doChatAzure(true, false);
63+
break;
64+
case "15":
65+
doChatAzure(false, true);
66+
break;
67+
case "16":
68+
doChatAzure(true, true);
69+
break;
70+
default:
71+
System.err.println("Invalid option");
72+
break;
73+
}
74+
}
75+
76+
public static void doCompletionAzure(boolean stream, boolean async) throws OpenAIError {
77+
Scanner scan = new Scanner(System.in);
78+
System.out.println(YELLOW + "Enter completion: ");
79+
String input = scan.nextLine();
80+
81+
// CompletionRequest contains the data we sent to the OpenAI API. We use
82+
// 128 tokens, so we have a bit of a delay before the response (for testing).
83+
CompletionRequest request = CompletionRequest.builder()
84+
.model("davinci")
85+
.prompt(input)
86+
.maxTokens(128).build();
87+
88+
// Loads the API key from the .env file in the root directory.
89+
String key = Dotenv.load().get("OPENAI_TOKEN");
90+
OpenAI openai = new AzureOpenAI(key);
91+
System.out.println(RESET + "Generating Response" + PURPLE);
92+
93+
// Generate a print the message
94+
if (stream) {
95+
if (async)
96+
openai.streamCompletionAsync(request, response -> System.out.print(response.get(0).getText()));
97+
else
98+
openai.streamCompletion(request, response -> System.out.print(response.get(0).getText()));
99+
} else {
100+
if (async)
101+
openai.createCompletionAsync(request, response -> System.out.println(response.get(0).getText()));
102+
else
103+
System.out.println(openai.createCompletion(request).get(0).getText());
104+
}
105+
106+
System.out.println(CYAN + " !!! Code has finished executing. Wait for async code to complete." + RESET);
107+
}
108+
109+
public static void doChatAzure(boolean stream, boolean async) throws OpenAIError {
110+
Scanner scan = new Scanner(System.in);
111+
112+
// This is the prompt that the bot will refer back to for every message.
113+
ChatMessage prompt = ChatMessage.toSystemMessage("You are a helpful chatbot.");
114+
115+
// Use a mutable (modifiable) list! Always! You should be reusing the
116+
// ChatRequest variable, so in order for a conversation to continue
117+
// you need to be able to modify the list.
118+
List<ChatMessage> messages = new ArrayList<>(Collections.singletonList(prompt));
119+
120+
// ChatRequest is the request we send to OpenAI API. You can modify the
121+
// model, temperature, maxTokens, etc. This should be saved, so you can
122+
// reuse it for a conversation.
123+
ChatRequest request = ChatRequest.builder()
124+
.model("gpt-3.5-turbo")
125+
.messages(messages).build();
126+
127+
// Loads the API key from the .env file in the root directory.
128+
String key = Dotenv.load().get("OPENAI_TOKEN");
129+
OpenAI openai = new AzureOpenAI(key);
130+
131+
// The conversation lasts until the user quits the program
132+
while (true) {
133+
134+
// Prompt the user to enter a response
135+
System.out.println(YELLOW + "Enter text below:\n\n");
136+
String input = scan.nextLine();
137+
138+
// Add the newest user message to the conversation
139+
messages.add(ChatMessage.toUserMessage(input));
140+
141+
System.out.println(RESET + "Generating Response" + PURPLE);
142+
if (stream) {
143+
if (async) {
144+
openai.streamChatCompletionAsync(request, response -> {
145+
System.out.print(response.get(0).getDelta());
146+
if (response.get(0).isFinished())
147+
messages.add(response.get(0).getMessage());
148+
});
149+
} else {
150+
openai.streamChatCompletion(request, response -> {
151+
System.out.print(response.get(0).getDelta());
152+
if (response.get(0).isFinished())
153+
messages.add(response.get(0).getMessage());
154+
});
155+
}
156+
} else {
157+
if (async) {
158+
openai.createChatCompletionAsync(request, response -> {
159+
System.out.println(response.get(0).getMessage().getContent());
160+
messages.add(response.get(0).getMessage());
161+
});
162+
} else {
163+
ChatResponse response = openai.createChatCompletion(request);
164+
System.out.println(response.get(0).getMessage().getContent());
165+
messages.add(response.get(0).getMessage());
166+
}
167+
}
168+
169+
System.out.println(CYAN + " !!! Code has finished executing. Wait for async code to complete.");
170+
}
171+
}
172+
}

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /