Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit b6b117a

Browse files
reimplement completions api
1 parent de6be93 commit b6b117a

File tree

3 files changed

+92
-55
lines changed

3 files changed

+92
-55
lines changed

‎examples/src/main/kotlin/completion/Completion.kt‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import com.cjcrafter.openai.openAI
55
import io.github.cdimascio.dotenv.dotenv
66

77
/**
8-
* In this Kotlin example, we will be using the Chat API to create a simple chatbot.
8+
* In this Kotlin example, we will be using the Completions API to generate a response.
99
*/
1010
fun main() {
1111

@@ -17,7 +17,7 @@ fun main() {
1717
// Here you can change the model's settings, add tools, and more.
1818
val request = completionRequest {
1919
model("davinci")
20-
prompt("What is 9+10")
20+
prompt("The wheels on the bus go")
2121
}
2222

2323
val completion = openai.createCompletion(request)[0]
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package completion
2+
3+
import com.cjcrafter.openai.completions.completionRequest
4+
import com.cjcrafter.openai.openAI
5+
import io.github.cdimascio.dotenv.dotenv
6+
7+
/**
8+
* In this Kotlin example, we will be using the Completions API to generate a
9+
* response. We will stream the tokens 1 at a time for a faster response time.
10+
*/
11+
fun main() {
12+
13+
// To use dotenv, you need to add the "io.github.cdimascio:dotenv-kotlin:version"
14+
// dependency. Then you can add a .env file in your project directory.
15+
val key = dotenv()["OPENAI_TOKEN"]
16+
val openai = openAI { apiKey(key) }
17+
18+
// Here you can change the model's settings, add tools, and more.
19+
val request = completionRequest {
20+
model("davinci")
21+
prompt("The wheels on the bus go")
22+
maxTokens(500)
23+
}
24+
25+
for (chunk in openai.streamCompletion(request)) {
26+
print(chunk.choices[0].text)
27+
}
28+
}

‎src/main/kotlin/com/cjcrafter/openai/OpenAIImpl.kt‎

Lines changed: 62 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import com.cjcrafter.openai.chat.*
44
import com.cjcrafter.openai.completions.CompletionRequest
55
import com.cjcrafter.openai.completions.CompletionResponse
66
import com.cjcrafter.openai.completions.CompletionResponseChunk
7-
import com.cjcrafter.openai.completions.CompletionUsage
7+
import com.fasterxml.jackson.databind.JavaType
88
import com.fasterxml.jackson.databind.node.ObjectNode
99
import okhttp3.*
1010
import okhttp3.MediaType.Companion.toMediaType
@@ -32,62 +32,40 @@ open class OpenAIImpl @ApiStatus.Internal constructor(
3232
.post(body).build()
3333
}
3434

35-
override fun createCompletion(request: CompletionRequest): CompletionResponse {
36-
@Suppress("DEPRECATION")
37-
request.stream = false // use streamCompletion for stream=true
38-
val httpRequest = buildRequest(request, COMPLETIONS_ENDPOINT)
39-
40-
val httpResponse = client.newCall(httpRequest).execute()
41-
println(httpResponse)
42-
43-
return CompletionResponse("1", 1, "1", listOf(), CompletionUsage(1, 1, 1))
44-
}
45-
46-
override fun streamCompletion(request: CompletionRequest): Iterable<CompletionResponseChunk> {
47-
@Suppress("DEPRECATION")
48-
request.stream = true // use createCompletion for stream=false
49-
val httpRequest = buildRequest(request, COMPLETIONS_ENDPOINT)
50-
51-
return listOf()
52-
}
53-
54-
override fun createChatCompletion(request: ChatRequest): ChatResponse {
55-
@Suppress("DEPRECATION")
56-
request.stream = false // use streamChatCompletion for stream=true
57-
val httpRequest = buildRequest(request, CHAT_ENDPOINT)
58-
35+
protected open fun <T> executeRequest(httpRequest: Request, responseType: Class<T>): T {
5936
val httpResponse = client.newCall(httpRequest).execute()
6037
if (!httpResponse.isSuccessful) {
6138
val json = httpResponse.body?.byteStream()?.bufferedReader()?.readText()
6239
httpResponse.close()
63-
throw IOException("Unexpected code $httpResponse, recieved: $json")
40+
throw IOException("Unexpected code $httpResponse, received: $json")
6441
}
6542

66-
val json = httpResponse.body?.byteStream()?.bufferedReader() ?: throw IOException("Response body is null")
67-
val str = json.readText()
68-
return objectMapper.readValue(str, ChatResponse::class.java)
43+
val jsonReader = httpResponse.body?.byteStream()?.bufferedReader()
44+
?: throw IOException("Response body is null")
45+
val responseStr = jsonReader.readText()
46+
return objectMapper.readValue(responseStr, responseType)
6947
}
7048

71-
override fun streamChatCompletion(request: ChatRequest): Iterable<ChatResponseChunk> {
72-
request.stream = true // Set streaming to true
73-
val httpRequest = buildRequest(request, CHAT_ENDPOINT)
74-
75-
return object : Iterable<ChatResponseChunk> {
76-
override fun iterator(): Iterator<ChatResponseChunk> {
77-
val httpResponse = client.newCall(httpRequest).execute()
49+
private fun <T> streamResponses(
50+
request: Request,
51+
responseType: JavaType,
52+
updateResponse: (T, String) -> T
53+
): Iterable<T> {
54+
return object : Iterable<T> {
55+
override fun iterator(): Iterator<T> {
56+
val httpResponse = client.newCall(request).execute()
7857

7958
if (!httpResponse.isSuccessful) {
8059
httpResponse.close()
8160
throw IOException("Unexpected code $httpResponse")
8261
}
8362

84-
val reader = httpResponse.body?.byteStream()?.bufferedReader() ?: throw IOException("Response body is null")
63+
val reader = httpResponse.body?.byteStream()?.bufferedReader()
64+
?: throw IOException("Response body is null")
8565

86-
// Only instantiate 1 ChatResponseChunk, otherwise simply update
87-
// the existing one. This lets us accumulate the message.
88-
var chunk: ChatResponseChunk? = null
66+
var currentResponse: T? = null
8967

90-
return object : Iterator<ChatResponseChunk> {
68+
return object : Iterator<T> {
9169
private var nextLine: String? = readNextLine(reader)
9270

9371
private fun readNextLine(reader: BufferedReader): String? {
@@ -98,8 +76,6 @@ open class OpenAIImpl @ApiStatus.Internal constructor(
9876
reader.close()
9977
return null
10078
}
101-
102-
// Check if the line starts with 'data:' and skip empty lines
10379
} while (line != null && (line.isEmpty() || !line.startsWith("data: ")))
10480
return line?.removePrefix("data: ")
10581
}
@@ -108,24 +84,57 @@ open class OpenAIImpl @ApiStatus.Internal constructor(
10884
return nextLine != null
10985
}
11086

111-
override fun next(): ChatResponseChunk {
112-
val currentLine = nextLine ?: throw NoSuchElementException("No more lines")
113-
//println(" $currentLine")
114-
chunk = chunk?.apply { update(objectMapper.readTree(currentLine) as ObjectNode) } ?: objectMapper.readValue(currentLine, ChatResponseChunk::class.java)
115-
nextLine = readNextLine(reader) // Prepare the next line
116-
return chunk!!
117-
//return ChatResponseChunk("1", 1, listOf())
87+
override fun next(): T {
88+
val line = nextLine ?: throw NoSuchElementException("No more lines")
89+
currentResponse = if (currentResponse == null) {
90+
objectMapper.readValue(line, responseType)
91+
} else {
92+
updateResponse(currentResponse!!, line)
93+
}
94+
nextLine = readNextLine(reader)
95+
return currentResponse!!
11896
}
11997
}
12098
}
12199
}
122100
}
123101

102+
override fun createCompletion(request: CompletionRequest): CompletionResponse {
103+
@Suppress("DEPRECATION")
104+
request.stream = false // use streamCompletion for stream=true
105+
val httpRequest = buildRequest(request, COMPLETIONS_ENDPOINT)
106+
return executeRequest(httpRequest, CompletionResponse::class.java)
107+
}
108+
109+
override fun streamCompletion(request: CompletionRequest): Iterable<CompletionResponseChunk> {
110+
@Suppress("DEPRECATION")
111+
request.stream = true
112+
val httpRequest = buildRequest(request, COMPLETIONS_ENDPOINT)
113+
return streamResponses(httpRequest, objectMapper.typeFactory.constructType(CompletionResponseChunk::class.java)) { response, newLine ->
114+
// We don't have any update logic, so we should ignore the old response and just return a new one
115+
objectMapper.readValue(newLine, CompletionResponseChunk::class.java)
116+
}
117+
}
118+
119+
override fun createChatCompletion(request: ChatRequest): ChatResponse {
120+
@Suppress("DEPRECATION")
121+
request.stream = false // use streamChatCompletion for stream=true
122+
val httpRequest = buildRequest(request, CHAT_ENDPOINT)
123+
return executeRequest(httpRequest, ChatResponse::class.java)
124+
}
125+
126+
override fun streamChatCompletion(request: ChatRequest): Iterable<ChatResponseChunk> {
127+
@Suppress("DEPRECATION")
128+
request.stream = true
129+
val httpRequest = buildRequest(request, CHAT_ENDPOINT)
130+
return streamResponses(httpRequest, objectMapper.typeFactory.constructType(ChatResponseChunk::class.java)) { response, newLine ->
131+
response.update(objectMapper.readTree(newLine) as ObjectNode)
132+
response
133+
}
134+
}
135+
124136
companion object {
125137
const val COMPLETIONS_ENDPOINT = "v1/completions"
126138
const val CHAT_ENDPOINT = "v1/chat/completions"
127-
const val IMAGE_CREATE_ENDPOINT = "v1/images/generations"
128-
const val IMAGE_EDIT_ENDPOINT = "v1/images/edits"
129-
const val IMAGE_VARIATION_ENDPOINT = "v1/images/variations"
130139
}
131140
}

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /