
This article takes a look at RedisChatMemory in Spring AI Alibaba.

RedisChatMemory

community/memories/spring-ai-alibaba-redis-memory/src/main/java/com/alibaba/cloud/ai/memory/redis/RedisChatMemory.java

package com.alibaba.cloud.ai.memory.redis;

import com.alibaba.cloud.ai.memory.redis.serializer.MessageDeserializer;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.module.SimpleModule;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.messages.Message;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.JedisPool;
import redis.clients.jedis.JedisPoolConfig;

import java.util.ArrayList;
import java.util.List;

public class RedisChatMemory implements ChatMemory, AutoCloseable {

    private static final Logger logger = LoggerFactory.getLogger(RedisChatMemory.class);
    private static final String DEFAULT_KEY_PREFIX = "spring_ai_alibaba_chat_memory";
    private static final String DEFAULT_HOST = "127.0.0.1";
    private static final int DEFAULT_PORT = 6379;
    private static final String DEFAULT_PASSWORD = null;

    private final JedisPool jedisPool;

    // One connection borrowed in the constructor and shared by every call;
    // a Jedis instance is not thread-safe, so this object should not be used concurrently
    private final Jedis jedis;

    private final ObjectMapper objectMapper;

    public RedisChatMemory() {
        this(DEFAULT_HOST, DEFAULT_PORT, DEFAULT_PASSWORD);
    }

    public RedisChatMemory(String host, int port, String password) {
        JedisPoolConfig poolConfig = new JedisPoolConfig();
        // 2000 ms timeout
        this.jedisPool = new JedisPool(poolConfig, host, port, 2000, password);
        this.jedis = jedisPool.getResource();
        this.objectMapper = new ObjectMapper();
        // Message is an interface, so a custom deserializer picks the concrete type
        SimpleModule module = new SimpleModule();
        module.addDeserializer(Message.class, new MessageDeserializer());
        this.objectMapper.registerModule(module);
        logger.info("Connected to Redis at {}:{}", host, port);
    }

    @Override
    public void add(String conversationId, List<Message> messages) {
        // Prefix and id are concatenated without a separator
        String key = DEFAULT_KEY_PREFIX + conversationId;
        for (Message message : messages) {
            try {
                String messageJson = objectMapper.writeValueAsString(message);
                // One RPUSH round trip per message
                jedis.rpush(key, messageJson);
            }
            catch (JsonProcessingException e) {
                throw new RuntimeException("Error serializing message", e);
            }
        }
        logger.info("Added messages to conversationId: {}", conversationId);
    }

    @Override
    public List<Message> get(String conversationId, int lastN) {
        String key = DEFAULT_KEY_PREFIX + conversationId;
        // Negative offsets count from the tail, so this fetches the last lastN entries
        // (lastN = 0 becomes LRANGE key 0 -1, i.e. the whole list)
        List<String> messageStrings = jedis.lrange(key, -lastN, -1);
        List<Message> messages = new ArrayList<>();
        for (String messageString : messageStrings) {
            try {
                Message message = objectMapper.readValue(messageString, Message.class);
                messages.add(message);
            }
            catch (JsonProcessingException e) {
                throw new RuntimeException("Error deserializing message", e);
            }
        }
        logger.info("Retrieved {} messages for conversationId: {}", messages.size(), conversationId);
        return messages;
    }

    @Override
    public void clear(String conversationId) {
        String key = DEFAULT_KEY_PREFIX + conversationId;
        jedis.del(key);
        logger.info("Cleared messages for conversationId: {}", conversationId);
    }

    @Override
    public void close() {
        if (jedis != null) {
            jedis.close();
            logger.info("Redis connection closed.");
        }
        if (jedisPool != null) {
            jedisPool.close();
            logger.info("Jedis pool closed.");
        }
    }

    // Not part of ChatMemory: drops the oldest deleteSize entries once the list holds
    // at least maxLimit; read-all, DEL and re-RPUSH are separate, non-atomic steps
    public void clearOverLimit(String conversationId, int maxLimit, int deleteSize) {
        try {
            String key = DEFAULT_KEY_PREFIX + conversationId;
            List<String> all = jedis.lrange(key, 0, -1);
            if (all.size() >= maxLimit) {
                all = all.stream().skip(Math.max(0, deleteSize)).toList();
            }
            this.clear(conversationId);
            for (String message : all) {
                jedis.rpush(key, message);
            }
        }
        catch (Exception e) {
            logger.error("Error clearing messages from Redis chat memory", e);
            throw new RuntimeException(e);
        }
    }

    // Not part of ChatMemory: replaces the list with a single entry. Note the ':' in the key,
    // which differs from DEFAULT_KEY_PREFIX + conversationId used by add/get/clear
    public void updateMessageById(String conversationId, String messages) {
        String key = "spring_ai_alibaba_chat_memory:" + conversationId;
        try {
            this.jedis.del(key);
            this.jedis.rpush(key, new String[] { messages });
        }
        catch (Exception var6) {
            logger.error("Error updating messages from Redis chat memory", var6);
            throw new RuntimeException(var6);
        }
    }

}
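The RedisChatMemory constructor creates a JedisPool, borrows a single Jedis connection from it that every later call reuses, and registers a MessageDeserializer for the org.springframework.ai.chat.messages.Message type with its ObjectMapper. add serializes each message to JSON and RPUSHes it onto the list stored at spring_ai_alibaba_chat_memory{conversationId} (the prefix and the ID are concatenated with no separator). get uses LRANGE over the range -lastN to -1 to fetch the most recent lastN entries and deserializes them back into Message objects. clear simply deletes the key, and close closes the Jedis connection first and then the JedisPool.

The two extra public methods are not part of ChatMemory. clearOverLimit drops the oldest deleteSize entries once the list holds at least maxLimit of them, by reading the whole list, deleting the key and pushing the remainder back. updateMessageById replaces the list with a single entry, but it builds its key as spring_ai_alibaba_chat_memory:{conversationId}, with a colon, so it does not touch the list that add, get and clear use.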
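The index arithmetic in get is easy to misread, so here is a minimal sketch of the three Redis list commands RedisChatMemory relies on, written as a plain in-memory class. RedisListSketch is a hypothetical stand-in, not Jedis: it only mimics the documented behaviour of RPUSH, LRANGE (negative offsets count from the tail, out-of-range offsets are clamped) and DEL, so the effect of lrange(key, -lastN, -1) can be tried without a Redis server.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical in-memory stand-in for RPUSH / LRANGE / DEL; NOT the Jedis API.
final class RedisListSketch {
    private final Map<String, List<String>> lists = new HashMap<>();

    // RPUSH: append to the tail, creating the list if needed; returns the new length
    long rpush(String key, String... values) {
        List<String> list = lists.computeIfAbsent(key, k -> new ArrayList<>());
        for (String v : values) {
            list.add(v);
        }
        return list.size();
    }

    // LRANGE following Redis index rules: negative offsets count from the tail,
    // out-of-range offsets are clamped rather than raising an error
    List<String> lrange(String key, long start, long stop) {
        List<String> list = lists.getOrDefault(key, List.of());
        long len = list.size();
        if (start < 0) start += len;
        if (stop < 0) stop += len;
        if (start < 0) start = 0;
        if (start > stop || start >= len) return List.of();
        if (stop >= len) stop = len - 1;
        return List.copyOf(list.subList((int) start, (int) stop + 1));
    }

    // DEL: remove the whole list; returns the number of keys removed
    long del(String key) {
        return lists.remove(key) != null ? 1 : 0;
    }

    public static void main(String[] args) {
        RedisListSketch redis = new RedisListSketch();
        // Same concatenation as RedisChatMemory: no separator between prefix and id
        String key = "spring_ai_alibaba_chat_memory" + "conv-1";
        redis.rpush(key, "m1", "m2", "m3", "m4");

        int lastN = 2;
        System.out.println(redis.lrange(key, -lastN, -1)); // [m3, m4]
        // lastN = 0: -0 is 0, so this is LRANGE key 0 -1 and returns everything
        System.out.println(redis.lrange(key, -0, -1));     // [m1, m2, m3, m4]
        // Asking for more entries than exist is clamped, not an error
        System.out.println(redis.lrange(key, -10, -1));    // [m1, m2, m3, m4]
    }
}
```

One consequence worth knowing: with lastN = 0 the call degenerates to LRANGE key 0 -1, so get returns the full history instead of an empty list.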

MessageDeserializer

community/memories/spring-ai-alibaba-redis-memory/src/main/java/com/alibaba/cloud/ai/memory/redis/serializer/MessageDeserializer.java

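package com.alibaba.cloud.ai.memory.redis.serializer;

import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.JsonDeserializer;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
// Media is also imported from Spring AI; its package differs between Spring AI versions

import java.io.IOException;
import java.util.Collection;
import java.util.List;
import java.util.Map;

public class MessageDeserializer extends JsonDeserializer<Message> {

    private static final Logger logger = LoggerFactory.getLogger(MessageDeserializer.class);

    @Override
    public Message deserialize(JsonParser p, DeserializationContext ctxt) {
        ObjectMapper mapper = (ObjectMapper) p.getCodec();
        JsonNode node = null;
        Message message = null;
        try {
            node = mapper.readTree(p);
            // The "messageType" field selects the concrete Message subtype
            String messageType = node.get("messageType").asText();
            switch (messageType) {
                case "USER" -> message = new UserMessage(node.get("text").asText(),
                        mapper.convertValue(node.get("media"), new TypeReference<Collection<Media>>() {
                        }),
                        mapper.convertValue(node.get("metadata"), new TypeReference<Map<String, Object>>() {
                        }));
                case "ASSISTANT" -> message = new AssistantMessage(node.get("text").asText(),
                        mapper.convertValue(node.get("metadata"), new TypeReference<Map<String, Object>>() {
                        }),
                        (List<AssistantMessage.ToolCall>) mapper.convertValue(node.get("toolCalls"),
                                new TypeReference<Collection<AssistantMessage.ToolCall>>() {
                                }),
                        (List<Media>) mapper.convertValue(node.get("media"), new TypeReference<Collection<Media>>() {
                        }));
                // SYSTEM, TOOL and any other type are rejected
                default -> throw new IllegalArgumentException("Unknown message type: " + messageType);
            }
        }
        catch (IOException e) {
            // The error is only logged; the method then returns null
            logger.error("Error deserializing message", e);
        }
        return message;
    }

}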
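MessageDeserializer extends JsonDeserializer. It reads the messageType field and builds a UserMessage for USER or an AssistantMessage for ASSISTANT, filling in text, media and metadata and, for assistant messages, toolCalls. Any other type, such as SYSTEM or TOOL, raises an IllegalArgumentException, while an IOException during parsing is only logged and the method returns null.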
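The dispatch in MessageDeserializer follows the usual pattern of reading a type discriminator first and then constructing the matching subtype. The sketch below shows the same pattern using only the JDK, under deliberately artificial assumptions: a Map<String, Object> stands in for Jackson's JsonNode, and UserMsg and AssistantMsg are hypothetical records, not Spring AI's UserMessage and AssistantMessage.

```java
import java.util.List;
import java.util.Map;

// Hypothetical sketch of discriminator-based dispatch; a Map plays the role of JsonNode
// and the records below are illustrative types, not Spring AI classes.
final class MessageDispatchSketch {
    sealed interface ChatMsg permits UserMsg, AssistantMsg {}

    record UserMsg(String text, Map<String, Object> metadata) implements ChatMsg {}

    record AssistantMsg(String text, List<String> toolCalls) implements ChatMsg {}

    @SuppressWarnings("unchecked")
    static ChatMsg fromNode(Map<String, Object> node) {
        Object type = node.get("messageType");
        if (type == null) {
            throw new IllegalArgumentException("Missing messageType");
        }
        String text = (String) node.get("text");
        // Read the discriminator first, then build the matching concrete type
        return switch (type.toString()) {
            case "USER" -> new UserMsg(text,
                    (Map<String, Object>) node.getOrDefault("metadata", Map.of()));
            case "ASSISTANT" -> new AssistantMsg(text,
                    (List<String>) node.getOrDefault("toolCalls", List.of()));
            // Like the quoted deserializer, every other type is rejected
            default -> throw new IllegalArgumentException("Unknown message type: " + type);
        };
    }

    public static void main(String[] args) {
        ChatMsg m = fromNode(Map.of("messageType", "USER", "text", "hi"));
        System.out.println(m); // UserMsg[text=hi, metadata={}]
    }
}
```

Unlike the quoted deserializer, this sketch lets every failure propagate instead of logging it and returning null, so a corrupted entry cannot quietly turn into a null element in the list that get returns.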

Summary

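spring-ai-alibaba-redis-memory provides a Redis implementation of ChatMemory. Through Jedis it appends messages with RPUSH, reads the most recent N with LRANGE, and removes a conversation's messages with DEL.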
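The trimming rule inside clearOverLimit (quoted in the RedisChatMemory listing) can be isolated as a pure function, which makes its threshold behaviour easy to check. trimOverLimit below is a hypothetical helper that applies the same condition and the same skip; it does not talk to Redis.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper mirroring the condition and skip used in clearOverLimit; no Redis involved.
final class TrimSketch {
    // Only when size >= maxLimit are the oldest deleteSize entries dropped
    static List<String> trimOverLimit(List<String> all, int maxLimit, int deleteSize) {
        if (all.size() >= maxLimit) {
            return all.stream().skip(Math.max(0, deleteSize)).toList();
        }
        return all;
    }

    public static void main(String[] args) {
        List<String> history = new ArrayList<>(List.of("m1", "m2", "m3", "m4", "m5"));
        System.out.println(trimOverLimit(history, 5, 2)); // [m3, m4, m5]
        System.out.println(trimOverLimit(history, 6, 2)); // [m1, m2, m3, m4, m5]
    }
}
```

A possible alternative, not what the library does, is to check LLEN and then issue LTRIM key deleteSize -1 (Jedis exposes ltrim(key, start, stop)), which removes the same head entries on the server instead of reading the whole list back and recreating it with DEL plus RPUSH.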

codecraft