1
1
package com .gzhu .funai .service .impl ;
2
2
3
- import com .baomidou .mybatisplus .core .conditions .query .QueryWrapper ;
4
3
import com .baomidou .mybatisplus .extension .service .impl .ServiceImpl ;
5
4
import com .google .common .collect .ImmutableMap ;
6
5
import com .gzhu .funai .api .openai .ChatGPTApi ;
20
19
import org .springframework .util .CollectionUtils ;
21
20
22
21
import javax .annotation .Resource ;
22
+ import java .util .Comparator ;
23
23
import java .util .List ;
24
24
import java .util .Map ;
25
+ import java .util .concurrent .ConcurrentHashMap ;
26
+ import java .util .concurrent .CopyOnWriteArrayList ;
27
+ import java .util .concurrent .CountDownLatch ;
25
28
import java .util .stream .Collectors ;
26
29
27
30
/**
@@ -107,20 +110,50 @@ public String roundRobinGetByType(ApiType apiTypes) {
107
110
}
108
111
109
112
/**
110
- * 1 初始化数据并按照apikey的类型分组
113
+ * 1 多线程判断apiKey是否能够被使用
111
114
* 2 重置轮询下标
112
- * 定时任务:每隔60分钟执行一次
115
+ * 定时任务:每隔1小时执行一次
113
116
*/
114
117
@ Scheduled (initialDelay = TimeInterval .ZERO , fixedRate = TimeInterval .ONE_HOUR )
115
118
@ Override
116
119
public void load (){
117
- Map <Integer , List <AdminApiKeyEntity >> collect = baseMapper .selectList (
118
- new QueryWrapper <AdminApiKeyEntity >().orderByDesc ("priority" )).stream ()
119
- .filter (item -> filterInValidOpenAiApiKey (item ))
120
- .collect (Collectors .groupingBy (AdminApiKeyEntity ::getType )
121
- );
122
- this .cache = ImmutableMap .copyOf (collect );
120
+ Map <Integer , List <AdminApiKeyEntity >> collect = new ConcurrentHashMap <>();
121
+ List <AdminApiKeyEntity > adminApiKeyEntityList = baseMapper .selectList (null );
122
+
123
+ // 使用减少计数辅助类让主线程等待多线程执行完毕
124
+ CountDownLatch countDownLatch = new CountDownLatch (adminApiKeyEntityList .size ());
125
+ for (AdminApiKeyEntity adminApiKeyEntity : adminApiKeyEntityList ){
126
+ queueThreadPool .execute (()->{
127
+ try {
128
+ if (isValidOpenAiApiKey (adminApiKeyEntity )){
129
+ if (!collect .containsKey (adminApiKeyEntity .getType ())){
130
+ collect .putIfAbsent (adminApiKeyEntity .getType (), new CopyOnWriteArrayList <>());
131
+ }
132
+ collect .get (adminApiKeyEntity .getType ()).add (adminApiKeyEntity );
133
+ }
134
+ }
135
+ finally {
136
+ countDownLatch .countDown ();
137
+ }
138
+ });
139
+ }
140
+ try {
141
+ countDownLatch .await ();
142
+ } catch (InterruptedException e ) {
143
+ log .error ("error{}" , e .getMessage ());
144
+ Thread .currentThread ().interrupt ();
145
+ }
146
+
147
+ // 排序
148
+ Map <Integer , List <AdminApiKeyEntity >> sortedCollect = collect .entrySet ().stream ().collect (
149
+ Collectors .toMap (Map .Entry ::getKey , entry -> entry .getValue ().stream ()
150
+ // 按照优先级字段进行降序排序 再按照id进行升序排序
151
+ .sorted (Comparator .comparing (AdminApiKeyEntity ::getPriority ).reversed ().thenComparing (AdminApiKeyEntity ::getId ))
152
+ .collect (Collectors .toList ())));
153
+
123
154
155
+ // 复制到缓存中
156
+ this .cache = ImmutableMap .copyOf (sortedCollect );
124
157
if (CollectionUtils .isEmpty (this .cache )){
125
158
this .roundRobinIndex = new int [2 ];
126
159
return ;
@@ -150,7 +183,7 @@ public String getBestByType(ApiType apiTypes) {
150
183
* @param adminApiKeyEntity
151
184
* @return
152
185
*/
153
- private boolean filterInValidOpenAiApiKey (AdminApiKeyEntity adminApiKeyEntity ){
186
+ private boolean isValidOpenAiApiKey (AdminApiKeyEntity adminApiKeyEntity ){
154
187
// 非openai类型,放行
155
188
if (!ApiType .OPENAI .typeNo .equals (adminApiKeyEntity .getType ())){
156
189
return true ;
@@ -168,7 +201,7 @@ private boolean filterInValidOpenAiApiKey(AdminApiKeyEntity adminApiKeyEntity){
168
201
// 余额不足
169
202
if (billingUsage .getTotalAmount ().compareTo (billingUsage .getTotalUsage ()) <= 0 ){
170
203
log .error ("{}的额度使用完毕!" , adminApiKeyEntity .getName ());
171
- queueThreadPool . execute (() -> baseMapper .deleteById (adminApiKeyEntity .getId () ));
204
+ baseMapper .deleteById (adminApiKeyEntity .getId ());
172
205
return false ;
173
206
}
174
207
@@ -177,14 +210,14 @@ private boolean filterInValidOpenAiApiKey(AdminApiKeyEntity adminApiKeyEntity){
177
210
adminApiKeyEntity .setTotalUsage (billingUsage .getTotalUsage ());
178
211
adminApiKeyEntity .setExpiredTime (billingUsage .getExpiredTime ());
179
212
180
- queueThreadPool . execute (() -> baseMapper .updateById (adminApiKeyEntity ) );
213
+ baseMapper .updateById (adminApiKeyEntity );
181
214
return true ;
182
215
}
183
216
// 捕获 请求openai错误的异常, 删掉这个apiKey,不加载到缓存
184
217
catch (BaseException e ){
185
218
log .error ("apiKey:{}, error:{}" ,adminApiKeyEntity .getName (), e .getMsg ());
186
219
if (e .getCode () != OpenAiRespError .OPENAI_LIMIT_ERROR .code ){
187
- queueThreadPool . execute (() -> baseMapper .deleteById (adminApiKeyEntity .getId () ));
220
+ baseMapper .deleteById (adminApiKeyEntity .getId ());
188
221
}
189
222
190
223
return false ;
0 commit comments