diff --git a/README.md b/README.md index a7c2028..14db04a 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,5 @@ -### 兄 Dei,有用能不能给个Star呀 -### 兄 Dei,有用能不能给个Star呀 -### 兄 Dei,有用能不能给个Star呀 +- 想看哪个模块就打开那个模块就行,因为没有使用pom管理 + ### 项目目录介绍 - [hello word](https://rstyro.github.io/blog/2017/07/25/Spring%20Boot%20%EF%BC%88%E4%B8%80%EF%BC%89%EF%BC%9A%E5%88%9D%E8%AF%86%E4%B9%8B%E5%85%A5%E9%97%A8%E7%AF%87/) *最简单的版本* - [Springboot-web](https://rstyro.github.io/blog/2017/07/27/Spring%20Boot%20(%E4%BA%8C)%EF%BC%9AWeb%20%E5%BC%80%E5%8F%91%E7%AF%87/) *web 版本的* @@ -23,6 +22,21 @@ - [SpringBoot2-Redisson](https://rstyro.github.io/blog/2019/06/25/SpringBoot2%E4%B8%8ERedisson%E7%A2%B0%E6%92%9E/) *SpringBoot 与Redisson 整合之分布式锁与发布订阅* - [SpringBoot2-RedisCacheManager](https://rstyro.github.io/blog/2019/04/16/SpringBoot%E4%B8%8ERedisCacheManager%E6%95%B4%E5%90%88/) *SpringBoot 与RedisCacheManager整合* - [Springboot2-api-encrypt](https://rstyro.github.io/blog/2020/10/22/Springboot2接口加解密全过程详解(含前端代码)/) *SpringBoot接口RSA+AES加解密(含前端代码)* -- [springboot-elk](https://rstyro.gitee.io/blog/2021/04/28/Centos7搭建ELK与Springboot整合/) *SpringBoot与ELK整合demo)* - +- [Springboot-elk](https://rstyro.gitee.io/blog/2021/04/28/Centos7搭建ELK与Springboot整合/) *SpringBoot与ELK整合demo)* +- [Springboot-sqlite](https://github.com/rstyro/spring-boot/tree/master/springboot-sqlite/) *SpringBoot与SQLite整合demo)* +- [Springboot-es](https://github.com/rstyro/spring-boot/tree/master/springboot-es/) *SpringBoot与ES 7版本以上整合demo)* +- [Springboot-neo4j-multiple-sdn](https://github.com/rstyro/spring-boot/tree/master/springboot-neo4j-multiple-sdn/) *springboot与neo4j多数据源Demo* +- [Springboot-mqtt](https://github.com/rstyro/spring-boot/tree/master/springboot-mqtt/) *Springboot集成mqtt支持多客户端* +- [Springboot-camunda](https://github.com/rstyro/spring-boot/tree/master/springboot-camunda/) *Springboot集成camunda工作流* +- [Springboot-2FA](https://github.com/rstyro/Springboot/tree/master/springboot-2FA) *Springboot集成2FA二步验证* +- [Springboot-shedlock](https://github.com/rstyro/Springboot/tree/master/springboot-shedlock) *Springboot集群部署之定时任务分布式锁* +- [Springboot-Jasypt](https://github.com/rstyro/Springboot/tree/master/springboot-jasypt) *Springboot集成Jasypt,配置加密* - ...持续更新 + + + + + +## Star History + +[![Star History Chart](https://api.star-history.com/svg?repos=rstyro/Springboot&type=Date)](https://star-history.com/#rstyro/Springboot&Date) diff --git a/SpringBoot-limit/pom.xml b/SpringBoot-limit/pom.xml index 6bab3b1..3ce52f5 100644 --- a/SpringBoot-limit/pom.xml +++ b/SpringBoot-limit/pom.xml @@ -16,6 +16,7 @@ 1.8 + 3.22.0 @@ -28,7 +29,11 @@ org.projectlombok lombok - 1.18.6 + + + + org.springframework.boot + spring-boot-starter-aop @@ -41,18 +46,18 @@ commons-pool2 - - - com.alibaba - fastjson - 1.2.56 - - org.springframework.boot spring-boot-starter-test test + + + org.redisson + redisson + ${redisson.version} + + diff --git a/SpringBoot-limit/src/main/java/top/lrshuai/limit/annotation/LeakyBucketLimit.java b/SpringBoot-limit/src/main/java/top/lrshuai/limit/annotation/LeakyBucketLimit.java new file mode 100644 index 0000000..58f009c --- /dev/null +++ b/SpringBoot-limit/src/main/java/top/lrshuai/limit/annotation/LeakyBucketLimit.java @@ -0,0 +1,28 @@ +package top.lrshuai.limit.annotation; + +import java.lang.annotation.*; + +/** + * 漏桶限流注解 + */ +@Target(ElementType.METHOD) +@Retention(RetentionPolicy.RUNTIME) +@Documented +public @interface LeakyBucketLimit { + + /** + * 限流key,支持SpEL表达式 + */ + String key() default ""; + + /** + * 桶的容量(最大请求数) + */ + int capacity() default 100; + + /** + * 流出速率(每秒处理多少个请求) + */ + int rate() default 10; + +} \ No newline at end of file diff --git a/SpringBoot-limit/src/main/java/top/lrshuai/limit/annotation/RedissonRateLimit.java b/SpringBoot-limit/src/main/java/top/lrshuai/limit/annotation/RedissonRateLimit.java new file mode 100644 index 0000000..493581b --- /dev/null +++ b/SpringBoot-limit/src/main/java/top/lrshuai/limit/annotation/RedissonRateLimit.java @@ -0,0 +1,32 @@ +package top.lrshuai.limit.annotation; + +import java.lang.annotation.*; + +/** + * redisson限流注解 + */ +@Target(ElementType.METHOD) +@Retention(RetentionPolicy.RUNTIME) +@Documented +public @interface RedissonRateLimit { + + /** + * 限流key,支持SpEL表达式 + */ + String key() default ""; + + /** + * 令牌生成速率 (每秒生成的令牌数) + */ + long rate() default 10; + + /** + * 每次请求消耗的令牌数 + */ + int tokens() default 1; + + /** + * 限流时的提示信息 + */ + String message() default "请求过于频繁,请稍后再试"; +} diff --git a/SpringBoot-limit/src/main/java/top/lrshuai/limit/annotation/RequestLimit.java b/SpringBoot-limit/src/main/java/top/lrshuai/limit/annotation/RequestLimit.java index 99a52d5..b0daf10 100644 --- a/SpringBoot-limit/src/main/java/top/lrshuai/limit/annotation/RequestLimit.java +++ b/SpringBoot-limit/src/main/java/top/lrshuai/limit/annotation/RequestLimit.java @@ -3,7 +3,7 @@ import java.lang.annotation.*; /** - * 请求限制的自定义注解 + * 请求限制的自定义注解: 固定计数器限流 * * @Target 注解可修饰的对象范围,ElementType.METHOD 作用于方法,ElementType.TYPE 作用于类 * (ElementType)取值有: @@ -32,7 +32,17 @@ @Target({ElementType.METHOD,ElementType.TYPE}) @Retention(RetentionPolicy.RUNTIME) public @interface RequestLimit { - // 在 second 秒内,最大只能请求 maxCount 次 + /** + * 资源key,用于区分不同的接口,默认为方法名 + */ + String key() default ""; + + /** + * 在 second 秒内,最大只能请求 maxCount 次 + */ int second() default 1; + /** + * 在时间窗口内允许访问的次数 + */ int maxCount() default 1; } diff --git a/SpringBoot-limit/src/main/java/top/lrshuai/limit/annotation/SlidingWindowLimit.java b/SpringBoot-limit/src/main/java/top/lrshuai/limit/annotation/SlidingWindowLimit.java new file mode 100644 index 0000000..e7ca35b --- /dev/null +++ b/SpringBoot-limit/src/main/java/top/lrshuai/limit/annotation/SlidingWindowLimit.java @@ -0,0 +1,28 @@ +package top.lrshuai.limit.annotation; + +import java.lang.annotation.*; + +/** + * 滑动时间窗口计数器限流注解 + */ +@Target(ElementType.METHOD) +@Retention(RetentionPolicy.RUNTIME) +@Documented +public @interface SlidingWindowLimit { + + /** + * 限流key,支持SpEL表达式 + */ + String key() default ""; + + /** + * 时间窗口大小(秒) + */ + int window() default 60; + + /** + * 时间窗口内允许的最大请求数 + */ + int maxCount() default 100; + +} diff --git a/SpringBoot-limit/src/main/java/top/lrshuai/limit/annotation/TokenBucketRateLimit.java b/SpringBoot-limit/src/main/java/top/lrshuai/limit/annotation/TokenBucketRateLimit.java new file mode 100644 index 0000000..3db63b8 --- /dev/null +++ b/SpringBoot-limit/src/main/java/top/lrshuai/limit/annotation/TokenBucketRateLimit.java @@ -0,0 +1,37 @@ +package top.lrshuai.limit.annotation; + +import java.lang.annotation.*; + +/** + * 令牌桶限流注解 + */ +@Target(ElementType.METHOD) +@Retention(RetentionPolicy.RUNTIME) +@Documented +public @interface TokenBucketRateLimit { + + /** + * 限流key,支持SpEL表达式 + */ + String key() default ""; + + /** + * 令牌生成速率 (每秒生成的令牌数) + */ + double rate() default 10.0; + + /** + * 桶容量 + */ + int capacity() default 20; + + /** + * 每次请求消耗的令牌数 + */ + int tokens() default 1; + + /** + * 限流时的提示信息 + */ + String message() default "请求过于频繁,请稍后再试"; +} diff --git a/SpringBoot-limit/src/main/java/top/lrshuai/limit/aspect/LeakyBucketLimitAspect.java b/SpringBoot-limit/src/main/java/top/lrshuai/limit/aspect/LeakyBucketLimitAspect.java new file mode 100644 index 0000000..956d424 --- /dev/null +++ b/SpringBoot-limit/src/main/java/top/lrshuai/limit/aspect/LeakyBucketLimitAspect.java @@ -0,0 +1,72 @@ +package top.lrshuai.limit.aspect; + +import lombok.extern.slf4j.Slf4j; +import org.aspectj.lang.ProceedingJoinPoint; +import org.aspectj.lang.annotation.Around; +import org.aspectj.lang.annotation.Aspect; +import org.aspectj.lang.reflect.MethodSignature; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Component; +import top.lrshuai.limit.annotation.LeakyBucketLimit; +import top.lrshuai.limit.common.ApiException; +import top.lrshuai.limit.common.ApiResultEnum; +import top.lrshuai.limit.service.LeakyBucketRateLimiter; +import top.lrshuai.limit.util.AopUtil; + +import java.lang.reflect.Method; + + +@Slf4j +@Aspect +@Component +public class LeakyBucketLimitAspect { + + @Autowired + private LeakyBucketRateLimiter rateLimiter; + + @Around("@annotation(leakyBucketLimit)") + public Object around(ProceedingJoinPoint joinPoint, LeakyBucketLimit leakyBucketLimit) throws Throwable { + String key = buildRateLimitKey(joinPoint, leakyBucketLimit); + int capacity = leakyBucketLimit.capacity(); + int rate = leakyBucketLimit.rate(); + + LeakyBucketRateLimiter.BucketStatus bucketStatus = rateLimiter.getBucketStatus(key); + log.debug("bucket status: key={}, water={},lastLeakTime={},ttl={}",key, + bucketStatus.getCurrentWater(), bucketStatus.getLastLeakTime(), bucketStatus.getTtl()); + if (!rateLimiter.tryAcquire(key, capacity, rate, 1)) { + throw new ApiException(ApiResultEnum.REQUEST_LIMIT); + } + + return joinPoint.proceed(); + } + + /** + * 构建限流key + */ + private String buildRateLimitKey(ProceedingJoinPoint joinPoint, LeakyBucketLimit rateLimit) { + String key = rateLimit.key(); + + // 如果key为空,使用默认格式 + if (key.isEmpty()) { + MethodSignature signature = (MethodSignature) joinPoint.getSignature(); + Method method = signature.getMethod(); + String className = method.getDeclaringClass().getSimpleName(); + String methodName = method.getName(); + + // 尝试获取用户信息 + String userKey = getCurrentUserId(); + return String.format("leaky_bucket:%s:%s:%s", className, methodName, userKey); + } + + // 如果key包含SpEL表达式,进行解析 + if (key.contains("#")) { + return AopUtil.parseSpel(key, joinPoint); + } + return key; + } + + private String getCurrentUserId() { + // 实际项目中从安全上下文获取 + return "user123"; + } +} diff --git a/SpringBoot-limit/src/main/java/top/lrshuai/limit/aspect/RateLimitAspect.java b/SpringBoot-limit/src/main/java/top/lrshuai/limit/aspect/RateLimitAspect.java new file mode 100644 index 0000000..b5b2f50 --- /dev/null +++ b/SpringBoot-limit/src/main/java/top/lrshuai/limit/aspect/RateLimitAspect.java @@ -0,0 +1,97 @@ +package top.lrshuai.limit.aspect; + +import lombok.extern.slf4j.Slf4j; +import org.aspectj.lang.ProceedingJoinPoint; +import org.aspectj.lang.annotation.Around; +import org.aspectj.lang.annotation.Aspect; +import org.aspectj.lang.reflect.MethodSignature; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.data.redis.core.RedisTemplate; +import org.springframework.data.redis.core.ValueOperations; +import org.springframework.stereotype.Component; +import org.springframework.web.context.request.RequestContextHolder; +import org.springframework.web.context.request.ServletRequestAttributes; +import top.lrshuai.limit.annotation.RequestLimit; +import top.lrshuai.limit.common.ApiResultEnum; +import top.lrshuai.limit.common.R; +import top.lrshuai.limit.util.IpUtil; + +import javax.servlet.http.HttpServletRequest; +import java.lang.annotation.Annotation; +import java.lang.reflect.Method; +import java.util.concurrent.TimeUnit; + +@Aspect +@Component +@Slf4j +public class RateLimitAspect { + + @Autowired + private RedisTemplate redisTemplate; + + + /** + * 环绕通知,切入所有被@RateLimit注解标记的方法 + * "@annotation(requestLimit)" 只匹配方法上的 + * "@within(requestLimit)" 匹配类上的 + */ + @Around("(@annotation(requestLimit) || @within(requestLimit))") + public Object around(ProceedingJoinPoint joinPoint, RequestLimit requestLimit) throws Throwable { + + // 获取HttpServletRequest对象,从而拿到客户端IP + ServletRequestAttributes attributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes(); + if (attributes == null) { + // 非Web请求,直接放行 + return joinPoint.proceed(); + } + HttpServletRequest request = attributes.getRequest(); + // 获取客户端真实IP的方法 + String ip = IpUtil.getClientIpAddress(request); + + // 优先从方法上获取注解 + MethodSignature signature = (MethodSignature) joinPoint.getSignature(); + Method method = signature.getMethod(); + requestLimit = getTagAnnotation(method, RequestLimit.class); + + + // 构建Redis的key,格式为:rate_limit:接口key:IP + String methodName = method.getName(); + String key = requestLimit.key().isEmpty() ? methodName : requestLimit.key(); + String redisKey = "rate_limit:" + key + ":" + ip; + + // 操作Redis,进行计数和判断 + ValueOperations valueOps = redisTemplate.opsForValue(); + Integer currentCount = (Integer) valueOps.get(redisKey); + + if (currentCount == null) { + // 第一次访问,设置key,初始值为1,并设置过期时间 + valueOps.set(redisKey, 1, requestLimit.second(), TimeUnit.SECONDS); + } else if (currentCount < requestLimit.maxCount()) { + // 计数未达到阈值,计数器+1 (注意:这里Redis的过期时间保持不变) + valueOps.increment(redisKey); + } else { + // 计数已达到或超过阈值,抛出异常或返回错误信息 + log.warn("IP【{}】访问接口【{}】过于频繁,已被限流", ip, methodName); + return R.fail(ApiResultEnum.REQUEST_LIMIT); + } + + // 执行目标方法(即正常的业务逻辑) + return joinPoint.proceed(); + } + + /** + * 获取目标注解 + * 如果方法上有注解就返回方法上的注解配置,否则类上的 + * @param method + * @param annotationClass + * @param + * @return + */ + public A getTagAnnotation(Method method, Class annotationClass) { + // 获取方法中是否包含注解 + Annotation methodAnnotate = method.getAnnotation(annotationClass); + //获取 类中是否包含注解,也就是controller 是否有注解 + Annotation classAnnotate = method.getDeclaringClass().getAnnotation(annotationClass); + return (A) (methodAnnotate!= null?methodAnnotate:classAnnotate); + } +} diff --git a/SpringBoot-limit/src/main/java/top/lrshuai/limit/aspect/RedissonRateLimitAspect.java b/SpringBoot-limit/src/main/java/top/lrshuai/limit/aspect/RedissonRateLimitAspect.java new file mode 100644 index 0000000..e037045 --- /dev/null +++ b/SpringBoot-limit/src/main/java/top/lrshuai/limit/aspect/RedissonRateLimitAspect.java @@ -0,0 +1,63 @@ +package top.lrshuai.limit.aspect; + +import lombok.extern.slf4j.Slf4j; +import org.aspectj.lang.ProceedingJoinPoint; +import org.aspectj.lang.annotation.Around; +import org.aspectj.lang.annotation.Aspect; +import org.aspectj.lang.reflect.MethodSignature; +import org.redisson.api.RRateLimiter; +import org.redisson.api.RateIntervalUnit; +import org.redisson.api.RateType; +import org.redisson.api.RedissonClient; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Component; +import top.lrshuai.limit.annotation.RedissonRateLimit; +import top.lrshuai.limit.common.R; +import top.lrshuai.limit.util.AopUtil; + +import java.lang.reflect.Method; + +@Aspect +@Component +@Slf4j +public class RedissonRateLimitAspect { + + @Autowired + private RedissonClient redissonClient; + + /** + * 切片-方法级别 + */ + @Around("@annotation(rateLimit)") + public Object around(ProceedingJoinPoint joinPoint, RedissonRateLimit rateLimit) throws Throwable { + String key = buildRateLimitKey(joinPoint, rateLimit); + RRateLimiter rRateLimiter = redissonClient.getRateLimiter(key); + // 初始化限流器 + rRateLimiter.trySetRate(RateType.OVERALL, rateLimit.rate(), 1, RateIntervalUnit.SECONDS); + if (!rRateLimiter.tryAcquire(rateLimit.tokens())) { + log.warn("接口限流触发 - key: {}, 方法: {}", key, joinPoint.getSignature().getName()); + return R.fail(rateLimit.message()); + } + return joinPoint.proceed(); + } + + /** + * 构建限流key + */ + private String buildRateLimitKey(ProceedingJoinPoint joinPoint, RedissonRateLimit rateLimit) { + String key = rateLimit.key(); + // 如果key为空,使用默认格式 + if (key.isEmpty()) { + MethodSignature signature = (MethodSignature) joinPoint.getSignature(); + Method method = signature.getMethod(); + String className = method.getDeclaringClass().getSimpleName(); + String methodName = method.getName(); + return String.format("rate_limit:%s:%s", className, methodName); + } + // 如果key包含SpEL表达式,进行解析 + if (key.contains("#")) { + return AopUtil.parseSpel(key, joinPoint); + } + return key; + } +} diff --git a/SpringBoot-limit/src/main/java/top/lrshuai/limit/aspect/SlidingWindowLimitAspect.java b/SpringBoot-limit/src/main/java/top/lrshuai/limit/aspect/SlidingWindowLimitAspect.java new file mode 100644 index 0000000..5d77a35 --- /dev/null +++ b/SpringBoot-limit/src/main/java/top/lrshuai/limit/aspect/SlidingWindowLimitAspect.java @@ -0,0 +1,79 @@ +package top.lrshuai.limit.aspect; + +import lombok.extern.slf4j.Slf4j; +import org.aspectj.lang.ProceedingJoinPoint; +import org.aspectj.lang.annotation.Around; +import org.aspectj.lang.annotation.Aspect; +import org.aspectj.lang.reflect.MethodSignature; +import org.springframework.expression.Expression; +import org.springframework.expression.ExpressionParser; +import org.springframework.expression.spel.standard.SpelExpressionParser; +import org.springframework.expression.spel.support.StandardEvaluationContext; +import org.springframework.stereotype.Component; +import org.springframework.util.StringUtils; +import org.springframework.web.context.request.RequestContextHolder; +import org.springframework.web.context.request.ServletRequestAttributes; +import top.lrshuai.limit.annotation.LeakyBucketLimit; +import top.lrshuai.limit.annotation.SlidingWindowLimit; +import top.lrshuai.limit.common.ApiException; +import top.lrshuai.limit.common.ApiResultEnum; +import top.lrshuai.limit.service.SlidingWindowRateLimiter; +import top.lrshuai.limit.util.AopUtil; + +import javax.servlet.http.HttpServletRequest; +import java.lang.reflect.Method; + +@Slf4j +@Aspect +@Component +public class SlidingWindowLimitAspect { + + private final SlidingWindowRateLimiter rateLimiter; + + public SlidingWindowLimitAspect(SlidingWindowRateLimiter rateLimiter) { + this.rateLimiter = rateLimiter; + } + + @Around("@annotation(slidingWindowLimit)") + public Object around(ProceedingJoinPoint joinPoint, SlidingWindowLimit slidingWindowLimit) throws Throwable { + String key = buildRateLimitKey(joinPoint, slidingWindowLimit); + int window = slidingWindowLimit.window(); + int maxCount = slidingWindowLimit.maxCount(); + + if (!rateLimiter.tryAcquire(key, window, maxCount, 1)) { + throw new ApiException(ApiResultEnum.REQUEST_LIMIT); + } + + return joinPoint.proceed(); + } + + /** + * 构建限流key + */ + private String buildRateLimitKey(ProceedingJoinPoint joinPoint, SlidingWindowLimit rateLimit) { + String key = rateLimit.key(); + + // 如果key为空,使用默认格式 + if (key.isEmpty()) { + MethodSignature signature = (MethodSignature) joinPoint.getSignature(); + Method method = signature.getMethod(); + String className = method.getDeclaringClass().getSimpleName(); + String methodName = method.getName(); + + // 尝试获取用户信息 + String userKey = getCurrentUserId(); + return String.format("sliding_window:%s:%s:%s", className, methodName, userKey); + } + + // 如果key包含SpEL表达式,进行解析 + if (key.contains("#")) { + return AopUtil.parseSpel(key, joinPoint); + } + return key; + } + + private String getCurrentUserId() { + // 实际项目中从安全上下文获取 + return "user123"; + } +} diff --git a/SpringBoot-limit/src/main/java/top/lrshuai/limit/aspect/TokenBucketRateLimitAspect.java b/SpringBoot-limit/src/main/java/top/lrshuai/limit/aspect/TokenBucketRateLimitAspect.java new file mode 100644 index 0000000..9ab9e28 --- /dev/null +++ b/SpringBoot-limit/src/main/java/top/lrshuai/limit/aspect/TokenBucketRateLimitAspect.java @@ -0,0 +1,98 @@ +package top.lrshuai.limit.aspect; + +import lombok.extern.slf4j.Slf4j; +import org.aspectj.lang.ProceedingJoinPoint; +import org.aspectj.lang.annotation.Around; +import org.aspectj.lang.annotation.Aspect; +import org.aspectj.lang.reflect.MethodSignature; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.expression.Expression; +import org.springframework.expression.ExpressionParser; +import org.springframework.expression.spel.standard.SpelExpressionParser; +import org.springframework.expression.spel.support.StandardEvaluationContext; +import org.springframework.stereotype.Component; +import org.springframework.web.context.request.RequestContextHolder; +import org.springframework.web.context.request.ServletRequestAttributes; +import top.lrshuai.limit.annotation.TokenBucketRateLimit; +import top.lrshuai.limit.common.R; +import top.lrshuai.limit.service.TokenBucketRateLimiter; +import top.lrshuai.limit.util.AopUtil; +import top.lrshuai.limit.util.IpUtil; + +import javax.servlet.http.HttpServletRequest; +import java.lang.reflect.Method; + +@Aspect +@Component +@Slf4j +public class TokenBucketRateLimitAspect { + + @Autowired + private TokenBucketRateLimiter rateLimiter; + + /** + * 切片-方法级别 + */ + @Around("@annotation(rateLimit)") + public Object around(ProceedingJoinPoint joinPoint, TokenBucketRateLimit rateLimit) throws Throwable { + String key = buildRateLimitKey(joinPoint, rateLimit); + + boolean allowed = rateLimiter.tryAcquire(key,rateLimit.rate(),rateLimit.capacity(),rateLimit.tokens()); + if (!allowed) { + log.warn("接口限流触发 - key: {}, 方法: {}", key, joinPoint.getSignature().getName()); + // 这里可以返回统一的错误结果 + return R.fail(rateLimit.message()); + } + return joinPoint.proceed(); + } + + /** + * 构建限流key + */ + private String buildRateLimitKey(ProceedingJoinPoint joinPoint, TokenBucketRateLimit rateLimit) { + String key = rateLimit.key(); + + // 如果key为空,使用默认格式 + if (key.isEmpty()) { + MethodSignature signature = (MethodSignature) joinPoint.getSignature(); + Method method = signature.getMethod(); + String className = method.getDeclaringClass().getSimpleName(); + String methodName = method.getName(); + + // 尝试获取用户信息,实现更细粒度的限流 + String userKey = getUserKey(); + return String.format("rate_limit:%s:%s:%s", className, methodName, userKey); + } + + // 如果key包含SpEL表达式,进行解析 + if (key.contains("#")) { + return AopUtil.parseSpel(key, joinPoint); + } + + return key; + } + + /** + * 获取用户标识(用户ID或IP) + */ + private String getUserKey() { + try { + ServletRequestAttributes attributes = (ServletRequestAttributes) + RequestContextHolder.getRequestAttributes(); + if (attributes != null) { + HttpServletRequest request = attributes.getRequest(); + // 优先使用登录用户ID + String userId = (String) request.getAttribute("userId"); + if (userId != null) { + return "user:" + userId; + } + // 降级为使用IP + return "ip:" + IpUtil.getClientIpAddress(request); + } + } catch (Exception e) { + log.debug("获取用户标识失败", e); + } + return "anonymous"; + } + +} diff --git a/SpringBoot-limit/src/main/java/top/lrshuai/limit/common/ApiException.java b/SpringBoot-limit/src/main/java/top/lrshuai/limit/common/ApiException.java index b9d67ef..0aad269 100644 --- a/SpringBoot-limit/src/main/java/top/lrshuai/limit/common/ApiException.java +++ b/SpringBoot-limit/src/main/java/top/lrshuai/limit/common/ApiException.java @@ -11,14 +11,14 @@ @Data public class ApiException extends RuntimeException{ private static final long serialVersionUID = 1L; - private String status; + private int status; private String message; private Object data; private Exception exception; public ApiException() { super(); } - public ApiException(String status, String message, Object data, Exception exception) { + public ApiException(int status, String message, Object data, Exception exception) { this.status = status; this.message = message; this.data = data; diff --git a/SpringBoot-limit/src/main/java/top/lrshuai/limit/common/ApiResultEnum.java b/SpringBoot-limit/src/main/java/top/lrshuai/limit/common/ApiResultEnum.java index be4247b..0b2dc60 100644 --- a/SpringBoot-limit/src/main/java/top/lrshuai/limit/common/ApiResultEnum.java +++ b/SpringBoot-limit/src/main/java/top/lrshuai/limit/common/ApiResultEnum.java @@ -1,30 +1,30 @@ package top.lrshuai.limit.common; public enum ApiResultEnum { - SUCCESS("200","ok"), - FAILED("400","请求失败"), - ERROR("500","不知名错误"), - ERROR_NULL("501","空指针异常"), - ERROR_CLASS_CAST("502","类型转换异常"), - ERROR_RUNTION("503","运行时异常"), - ERROR_IO("504","上传文件异常"), - ERROR_MOTHODNOTSUPPORT("505","请求方法错误"), + SUCCESS(200,"ok"), + FAILED(400,"请求失败"), + ERROR(500,"不知名错误"), + ERROR_NULL(501,"空指针异常"), + ERROR_CLASS_CAST(502,"类型转换异常"), + ERROR_RUNTIME(503,"运行时异常"), + ERROR_IO(504,"上传文件异常"), + ERROR_MONTH_NOT_SUPPORT(505,"请求方法错误"), - REQUST_LIMIT("10001","请求次数受限"), + REQUEST_LIMIT(10001,"请求次数受限"), ; private String message; - private String status; + private int status; public String getMessage() { return message; } - public String getStatus() { + public int getStatus() { return status; } - private ApiResultEnum(String status, String message) { + private ApiResultEnum(int status, String message) { this.message = message; this.status = status; } diff --git a/SpringBoot-limit/src/main/java/top/lrshuai/limit/common/R.java b/SpringBoot-limit/src/main/java/top/lrshuai/limit/common/R.java new file mode 100644 index 0000000..1c0f7b5 --- /dev/null +++ b/SpringBoot-limit/src/main/java/top/lrshuai/limit/common/R.java @@ -0,0 +1,137 @@ +package top.lrshuai.limit.common; + +import lombok.Data; + +import java.io.Serializable; +import java.util.HashMap; +import java.util.Map; + +/** + * 响应信息主体 + */ +@Data +public class R implements Serializable { + + /** + * 成功 + */ + public static final int SUCCESS = 200; + public static final String SUCCESS_MSG = "success"; + + /** + * 失败 + */ + public static final int FAIL = 500; + public static final String FAIL_MSG = "fail"; + + private int code; + + private String msg; + + private String trackerId; + + private T data; + + private Map extendMap; + + /** + * 空构造,避免反序列化问题 + */ + public R() { + this.code = SUCCESS; + this.msg = SUCCESS_MSG; + } + + public R(T data, int code, String msg) { + this.code = code; + this.msg = msg; + this.data = data; + } + + public static R ok() { + return restResult(null, SUCCESS, SUCCESS_MSG); + } + + public static R ok(T data) { + return restResult(data, SUCCESS, SUCCESS_MSG); + } + + public static R ok(T data, String msg) { + return restResult(data, SUCCESS, msg); + } + + public static R fail() { + return restResult(null, FAIL, FAIL_MSG); + } + + public static R fail(String msg) { + return restResult(null, FAIL, msg); + } + + public static R fail(ApiResultEnum resultEnum) { + return restResult(null, resultEnum.getStatus(), resultEnum.getMessage()); + } + + public static R fail(T data) { + return restResult(data, FAIL, FAIL_MSG); + } + + public static R fail(T data, String msg) { + return restResult(data, FAIL, msg); + } + + public static R fail(int code, String msg) { + return restResult(null, code, msg); + } + + + private static R restResult(T data, int code, String msg) { + return new R(data,code,msg); + } + + public static Boolean isError(R ret) { + return !isSuccess(ret); + } + + public static Boolean isSuccess(R ret) { + return R.SUCCESS == ret.getCode(); + } + + public boolean isSuccess(){ + return R.SUCCESS == code; + } + + /** + * 链式调用 + */ + public R code(int code) { + this.code = code; + return this; + } + + public R msg(String msg) { + this.msg = msg; + return this; + } + + public R data(T data) { + this.data = data; + return this; + } + + /** + * 添加扩展参数 + * @param key key + * @param data value + * @return this + */ + public R addExtend(String key,Object data){ + if(extendMap==null){ + extendMap=new HashMap(); + } + extendMap.put(key,data); + return this; + } + + +} diff --git a/SpringBoot-limit/src/main/java/top/lrshuai/limit/common/Result.java b/SpringBoot-limit/src/main/java/top/lrshuai/limit/common/Result.java deleted file mode 100644 index 9c54843..0000000 --- a/SpringBoot-limit/src/main/java/top/lrshuai/limit/common/Result.java +++ /dev/null @@ -1,58 +0,0 @@ -package top.lrshuai.limit.common; - -import java.util.HashMap; -import java.util.Map; - - -public class Result extends HashMap { - - private static final long serialVersionUID = 1L; - - public Result() { - put("status", 200); - put("message", "ok"); - } - - public static Result error() { - return error("500", "系统错误,请联系管理员"); - } - - public static Result error(String msg) { - return error("500", msg); - } - - public static Result error(String status, String msg) { - Result r = new Result(); - r.put("status", status); - r.put("message", msg); - return r; - } - - public static Result error(ApiResultEnum resultEnum) { - Result r = new Result(); - r.put("status", resultEnum.getStatus()); - r.put("message", resultEnum.getMessage()); - return r; - } - - public static Result ok(Map map) { - Result r = new Result(); - r.putAll(map); - return r; - } - public static Result ok(Object data) { - Result r = new Result(); - r.put("data",data); - return r; - } - - public static Result ok() { - return new Result(); - } - - @Override - public Result put(String key, Object value) { - super.put(key, value); - return this; - } -} \ No newline at end of file diff --git a/SpringBoot-limit/src/main/java/top/lrshuai/limit/config/GlobalExceptionHandler.java b/SpringBoot-limit/src/main/java/top/lrshuai/limit/config/GlobalExceptionHandler.java index c3769f3..8f2eb45 100644 --- a/SpringBoot-limit/src/main/java/top/lrshuai/limit/config/GlobalExceptionHandler.java +++ b/SpringBoot-limit/src/main/java/top/lrshuai/limit/config/GlobalExceptionHandler.java @@ -7,7 +7,7 @@ import org.springframework.web.bind.annotation.RestControllerAdvice; import top.lrshuai.limit.common.ApiException; import top.lrshuai.limit.common.ApiResultEnum; -import top.lrshuai.limit.common.Result; +import top.lrshuai.limit.common.R; import java.io.IOException; @@ -22,45 +22,45 @@ public class GlobalExceptionHandler { private Logger logger = LoggerFactory.getLogger(GlobalExceptionHandler.class); @ExceptionHandler(NullPointerException.class) - public Result NullPointer(NullPointerException ex){ + public R NullPointer(NullPointerException ex){ logger.error(ex.getMessage(),ex); - return Result.error(ApiResultEnum.ERROR_NULL); + return R.fail(ApiResultEnum.ERROR_NULL); } @ExceptionHandler(ClassCastException.class) - public Result ClassCastException(ClassCastException ex){ + public R ClassCastException(ClassCastException ex){ logger.error(ex.getMessage(),ex); - return Result.error(ApiResultEnum.ERROR_CLASS_CAST); + return R.fail(ApiResultEnum.ERROR_CLASS_CAST); } @ExceptionHandler(IOException.class) - public Result IOException(IOException ex){ + public R IOException(IOException ex){ logger.error(ex.getMessage(),ex); - return Result.error(ApiResultEnum.ERROR_IO); + return R.fail(ApiResultEnum.ERROR_IO); } @ExceptionHandler(HttpRequestMethodNotSupportedException.class) - public Result HttpRequestMethodNotSupportedException(HttpRequestMethodNotSupportedException ex){ + public R HttpRequestMethodNotSupportedException(HttpRequestMethodNotSupportedException ex){ logger.error(ex.getMessage(),ex); - return Result.error(ApiResultEnum.ERROR_MOTHODNOTSUPPORT); + return R.fail(ApiResultEnum.ERROR_MONTH_NOT_SUPPORT); } @ExceptionHandler(ApiException.class) - public Result ApiException(ApiException ex) { + public R ApiException(ApiException ex) { logger.error(ex.getMessage(),ex); - return Result.error(ex.getStatus(),ex.getMessage()); + return R.fail(ex.getStatus(),ex.getMessage()); } @ExceptionHandler(RuntimeException.class) - public Result RuntimeException(RuntimeException ex){ + public R RuntimeException(RuntimeException ex){ logger.error(ex.getMessage(),ex); - return Result.error(ApiResultEnum.ERROR_RUNTION); + return R.fail(ApiResultEnum.ERROR_RUNTIME); } @ExceptionHandler(Exception.class) - public Result exception(Exception ex){ + public R exception(Exception ex){ logger.error(ex.getMessage(),ex); - return Result.error(ApiResultEnum.ERROR); + return R.fail(ApiResultEnum.ERROR); } } diff --git a/SpringBoot-limit/src/main/java/top/lrshuai/limit/config/RedisConfig.java b/SpringBoot-limit/src/main/java/top/lrshuai/limit/config/RedisConfig.java index 360cb9b..365f7a5 100644 --- a/SpringBoot-limit/src/main/java/top/lrshuai/limit/config/RedisConfig.java +++ b/SpringBoot-limit/src/main/java/top/lrshuai/limit/config/RedisConfig.java @@ -3,22 +3,19 @@ import com.fasterxml.jackson.annotation.JsonAutoDetect; import com.fasterxml.jackson.annotation.PropertyAccessor; import com.fasterxml.jackson.databind.ObjectMapper; -import org.springframework.cache.CacheManager; -import org.springframework.cache.annotation.CachingConfigurerSupport; -import org.springframework.cache.interceptor.KeyGenerator; +import org.springframework.beans.factory.annotation.Value; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; -import org.springframework.data.redis.cache.RedisCacheManager; -import org.springframework.data.redis.connection.lettuce.LettuceConnectionFactory; +import org.springframework.core.io.Resource; +import org.springframework.data.redis.connection.RedisConnectionFactory; import org.springframework.data.redis.core.RedisTemplate; +import org.springframework.data.redis.core.script.DefaultRedisScript; +import org.springframework.data.redis.core.script.RedisScript; import org.springframework.data.redis.serializer.Jackson2JsonRedisSerializer; import org.springframework.data.redis.serializer.RedisSerializer; import org.springframework.data.redis.serializer.StringRedisSerializer; -import javax.annotation.Resource; -import java.lang.reflect.Method; -import java.util.HashSet; -import java.util.Set; +import java.util.List; /** * @@ -27,49 +24,64 @@ * */ @Configuration -//@EnableCaching // 开启缓存支持 -public class RedisConfig extends CachingConfigurerSupport { - @Resource - private LettuceConnectionFactory lettuceConnectionFactory; +public class RedisConfig{ + /** + * 令牌桶-lua脚本 + */ + @Value("classpath:lua/tokenRate.lua") + private Resource tokenLuaFile; + + /** + * 漏牌-lua脚本 + */ + @Value("classpath:lua/leakyBucket.lua") + private Resource leakyLuaFile; + + /** + * 漏牌-lua脚本 + */ + @Value("classpath:lua/slidingWindow.lua") + private Resource slidingWindowLuaFile; + + /** + * 令牌桶限流 Lua 脚本 + */ @Bean - public KeyGenerator keyGenerator() { - return new KeyGenerator() { - @Override - public Object generate(Object target, Method method, Object... params) { - StringBuffer sb = new StringBuffer(); - sb.append(target.getClass().getName()); - sb.append(method.getName()); - for (Object obj : params) { - sb.append(obj.toString()); - } - return sb.toString(); - } - }; + public RedisScript tokenBucketScript() { + DefaultRedisScript redisScript = new DefaultRedisScript(); + redisScript.setLocation(tokenLuaFile); + redisScript.setResultType(List.class); + return redisScript; } - - // 缓存管理器 + /** + * 漏桶限流 Lua 脚本 + */ @Bean - public CacheManager cacheManager() { - RedisCacheManager.RedisCacheManagerBuilder builder = RedisCacheManager.RedisCacheManagerBuilder - .fromConnectionFactory(lettuceConnectionFactory); - @SuppressWarnings("serial") - Set cacheNames = new HashSet() { - { - add("codeNameCache"); - } - }; - builder.initialCacheNames(cacheNames); - return builder.build(); + public RedisScript leakyBucketScript() { + DefaultRedisScript redisScript = new DefaultRedisScript(); + redisScript.setLocation(leakyLuaFile); + redisScript.setResultType(Long.class); + return redisScript; } + /** + * 滑动时间窗口计数器限流 Lua 脚本 + */ + @Bean + public RedisScript slidingWindowScript() { + DefaultRedisScript redisScript = new DefaultRedisScript(); + redisScript.setLocation(slidingWindowLuaFile); + redisScript.setResultType(Long.class); + return redisScript; + } /** * RedisTemplate配置 */ @Bean - public RedisTemplate redisTemplate(LettuceConnectionFactory lettuceConnectionFactory) { + public RedisTemplate redisTemplate(RedisConnectionFactory redisConnectionFactory) { // 设置序列化 Jackson2JsonRedisSerializer jackson2JsonRedisSerializer = new Jackson2JsonRedisSerializer( Object.class); @@ -79,7 +91,7 @@ public RedisTemplate redisTemplate(LettuceConnectionFactory lett jackson2JsonRedisSerializer.setObjectMapper(om); // 配置redisTemplate RedisTemplate redisTemplate = new RedisTemplate(); - redisTemplate.setConnectionFactory(lettuceConnectionFactory); + redisTemplate.setConnectionFactory(redisConnectionFactory); RedisSerializer stringSerializer = new StringRedisSerializer(); redisTemplate.setKeySerializer(stringSerializer);// key序列化 redisTemplate.setValueSerializer(jackson2JsonRedisSerializer);// value序列化 diff --git a/SpringBoot-limit/src/main/java/top/lrshuai/limit/config/RedissonConfig.java b/SpringBoot-limit/src/main/java/top/lrshuai/limit/config/RedissonConfig.java new file mode 100644 index 0000000..1c8c1ea --- /dev/null +++ b/SpringBoot-limit/src/main/java/top/lrshuai/limit/config/RedissonConfig.java @@ -0,0 +1,46 @@ +package top.lrshuai.limit.config; + + +import org.redisson.Redisson; +import org.redisson.api.RedissonClient; +import org.redisson.config.Config; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.util.StringUtils; + +/** + * redisson 配置,下面是单节点配置: + * 官方wiki地址:https://github.com/redisson/redisson/wiki/2.-%E9%85%8D%E7%BD%AE%E6%96%B9%E6%B3%95#26-%E5%8D%95redis%E8%8A%82%E7%82%B9%E6%A8%A1%E5%BC%8F + * + */ +@Configuration +public class RedissonConfig { + + @Value("${spring.redis.host}") + private String host; + + @Value("${spring.redis.port}") + private String port; + + @Value("${spring.redis.password}") + private String password; + + @Bean + public RedissonClient redissonClient(){ + Config config = new Config(); + //单节点 + config.useSingleServer().setAddress("redis://" + host + ":" + port); + if(StringUtils.isEmpty(password)){ + config.useSingleServer().setPassword(null); + }else{ + config.useSingleServer().setPassword(password); + } + //添加主从配置 +// config.useMasterSlaveServers().setMasterAddress("").setPassword("").addSlaveAddress(new String[]{"",""}); + + // 集群模式配置 setScanInterval()扫描间隔时间,单位是毫秒, //可以用"rediss://"来启用SSL连接 +// config.useClusterServers().setScanInterval(2000).addNodeAddress("redis://127.0.0.1:7000", "redis://127.0.0.1:7001").addNodeAddress("redis://127.0.0.1:7002"); + return Redisson.create(config); + } +} diff --git a/SpringBoot-limit/src/main/java/top/lrshuai/limit/config/WebMvcConfig.java b/SpringBoot-limit/src/main/java/top/lrshuai/limit/config/WebMvcConfig.java deleted file mode 100644 index 35a29ee..0000000 --- a/SpringBoot-limit/src/main/java/top/lrshuai/limit/config/WebMvcConfig.java +++ /dev/null @@ -1,22 +0,0 @@ -package top.lrshuai.limit.config; - -import lombok.extern.slf4j.Slf4j; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.stereotype.Component; -import org.springframework.web.servlet.config.annotation.InterceptorRegistry; -import org.springframework.web.servlet.config.annotation.WebMvcConfigurer; -import top.lrshuai.limit.interceptor.RequestLimitIntercept; - -@Slf4j -@Component -public class WebMvcConfig implements WebMvcConfigurer { - - @Autowired - private RequestLimitIntercept requestLimitIntercept; - - @Override - public void addInterceptors(InterceptorRegistry registry) { - log.info("添加拦截"); - registry.addInterceptor(requestLimitIntercept); - } -} diff --git a/SpringBoot-limit/src/main/java/top/lrshuai/limit/controller/IndexController.java b/SpringBoot-limit/src/main/java/top/lrshuai/limit/controller/IndexController.java index 1ea7b5a..3428903 100644 --- a/SpringBoot-limit/src/main/java/top/lrshuai/limit/controller/IndexController.java +++ b/SpringBoot-limit/src/main/java/top/lrshuai/limit/controller/IndexController.java @@ -4,11 +4,11 @@ import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; import top.lrshuai.limit.annotation.RequestLimit; -import top.lrshuai.limit.common.Result; +import top.lrshuai.limit.common.R; @RestController @RequestMapping("/index") -@RequestLimit(maxCount = 5,second = 1) +@RequestLimit(maxCount = 5,second = 10) public class IndexController { /** @@ -17,9 +17,9 @@ public class IndexController { */ @GetMapping("/test1") @RequestLimit - public Result test(){ + public R test(){ //TODO ... - return Result.ok(); + return R.ok(); } /** @@ -27,8 +27,8 @@ public Result test(){ * @return */ @GetMapping("/test2") - public Result test2(){ + public R test2(){ //TODO ... - return Result.ok(); + return R.ok(); } } diff --git a/SpringBoot-limit/src/main/java/top/lrshuai/limit/controller/LeakyRateController.java b/SpringBoot-limit/src/main/java/top/lrshuai/limit/controller/LeakyRateController.java new file mode 100644 index 0000000..2d03dac --- /dev/null +++ b/SpringBoot-limit/src/main/java/top/lrshuai/limit/controller/LeakyRateController.java @@ -0,0 +1,51 @@ +package top.lrshuai.limit.controller; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.web.bind.annotation.*; +import top.lrshuai.limit.annotation.LeakyBucketLimit; +import top.lrshuai.limit.common.R; +import top.lrshuai.limit.service.LeakyBucketRateLimiter; + +@RestController +@RequestMapping("/leakyRate") +public class LeakyRateController { + + @Autowired + private LeakyBucketRateLimiter leakyBucketRateLimiter; + + + /** + * 测试 + */ + @GetMapping("/test1") + @LeakyBucketLimit(rate = 1, capacity = 3) + public R test1() { + //TODO ... + return R.ok(); + } + + @GetMapping("/test2") + @LeakyBucketLimit(key = "leaky_rate:test2",rate = 2, capacity = 1) + public R test2() { + //TODO ... + return R.ok(); + } + + @LeakyBucketLimit(key = "'user :' + #username", rate = 1, capacity = 5) + @GetMapping("/search") + public R search(@RequestParam String username) { + // 搜索逻辑 - 这里key会根据username动态变化 + return R.ok("username:" + username); + } + + @GetMapping("/status/{key}") + public R getStatus(@PathVariable String key) { + return R.ok(leakyBucketRateLimiter.getBucketStatus( key)); + } + + @PostMapping("/reset/{key}") + public R reset(@PathVariable String key) { + leakyBucketRateLimiter.reset(key); + return R.ok(); + } +} diff --git a/SpringBoot-limit/src/main/java/top/lrshuai/limit/controller/RedissonRateController.java b/SpringBoot-limit/src/main/java/top/lrshuai/limit/controller/RedissonRateController.java new file mode 100644 index 0000000..e96a54a --- /dev/null +++ b/SpringBoot-limit/src/main/java/top/lrshuai/limit/controller/RedissonRateController.java @@ -0,0 +1,19 @@ +package top.lrshuai.limit.controller; + +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RequestParam; +import org.springframework.web.bind.annotation.RestController; +import top.lrshuai.limit.annotation.RedissonRateLimit; +import top.lrshuai.limit.common.R; + +@RestController +@RequestMapping("/redissonRate") +public class RedissonRateController { + + @GetMapping("/queryQuotaInfo") + @RedissonRateLimit(key = "'queryQuotaInfo:' + #storageType",rate = 1) + public R queryQuotaInfo(@RequestParam(value = "storageType") String storageType) { + return R.ok("storageType:"+storageType); + } +} diff --git a/SpringBoot-limit/src/main/java/top/lrshuai/limit/controller/SlidingWindowRateController.java b/SpringBoot-limit/src/main/java/top/lrshuai/limit/controller/SlidingWindowRateController.java new file mode 100644 index 0000000..1004ee8 --- /dev/null +++ b/SpringBoot-limit/src/main/java/top/lrshuai/limit/controller/SlidingWindowRateController.java @@ -0,0 +1,50 @@ +package top.lrshuai.limit.controller; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.web.bind.annotation.*; +import top.lrshuai.limit.annotation.SlidingWindowLimit; +import top.lrshuai.limit.common.R; +import top.lrshuai.limit.service.SlidingWindowRateLimiter; + +@RestController +@RequestMapping("/SlidingWindowRate") +public class SlidingWindowRateController { + + @Autowired + private SlidingWindowRateLimiter slidingWindowRateLimiter; + + /** + * 测试 + */ + @GetMapping("/test1") + @SlidingWindowLimit(window = 3, maxCount = 1) + public R test1() { + //TODO ... + return R.ok(); + } + + @GetMapping("/test2") + @SlidingWindowLimit(key = "sliding_window:test2",window = 60, maxCount = 5) + public R test2() { + //TODO ... + return R.ok(); + } + + @SlidingWindowLimit(key = "'user :' + #username") + @GetMapping("/search") + public R search(@RequestParam String username) { + // 搜索逻辑 - 这里key会根据username动态变化 + return R.ok("username:" + username); + } + + @GetMapping("/status/{key}") + public R getStatus(@PathVariable String key, int window) { + return R.ok(slidingWindowRateLimiter.getWindowStatus(key, window)); + } + + @PostMapping("/reset/{key}") + public R reset(@PathVariable String key) { + slidingWindowRateLimiter.reset(key); + return R.ok(); + } +} diff --git a/SpringBoot-limit/src/main/java/top/lrshuai/limit/controller/TokenRateController.java b/SpringBoot-limit/src/main/java/top/lrshuai/limit/controller/TokenRateController.java new file mode 100644 index 0000000..ecb895f --- /dev/null +++ b/SpringBoot-limit/src/main/java/top/lrshuai/limit/controller/TokenRateController.java @@ -0,0 +1,33 @@ +package top.lrshuai.limit.controller; + +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RequestParam; +import org.springframework.web.bind.annotation.RestController; +import top.lrshuai.limit.annotation.TokenBucketRateLimit; +import top.lrshuai.limit.common.R; + +@RestController +@RequestMapping("/tokenRate") +public class TokenRateController { + + /** + * 测试发送短信 + */ + @GetMapping("/sendSms") + @TokenBucketRateLimit(rate = 0.1, capacity = 2, message = "短信发送过于频繁") + public R sendSms(){ + //TODO ... + return R.ok(); + } + + /** + * "@TokenBucketRateLimit(rate = 5.0, capacity = 20)" 每秒5次,突发20次 + */ + @TokenBucketRateLimit(key = "'search:' + #keyword", rate = 5.0, capacity = 10) + @GetMapping("/search") + public R search(@RequestParam String keyword) { + // 搜索逻辑 - 这里key会根据keyword动态变化 + return R.ok("搜索结果:"+keyword); + } +} diff --git a/SpringBoot-limit/src/main/java/top/lrshuai/limit/interceptor/RequestLimitIntercept.java b/SpringBoot-limit/src/main/java/top/lrshuai/limit/interceptor/RequestLimitIntercept.java deleted file mode 100644 index 525b6a4..0000000 --- a/SpringBoot-limit/src/main/java/top/lrshuai/limit/interceptor/RequestLimitIntercept.java +++ /dev/null @@ -1,102 +0,0 @@ -package top.lrshuai.limit.interceptor; - -import com.alibaba.fastjson.JSONObject; -import lombok.extern.slf4j.Slf4j; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.data.redis.core.RedisTemplate; -import org.springframework.stereotype.Component; -import org.springframework.web.method.HandlerMethod; -import org.springframework.web.servlet.handler.HandlerInterceptorAdapter; -import top.lrshuai.limit.annotation.RequestLimit; -import top.lrshuai.limit.common.ApiResultEnum; -import top.lrshuai.limit.common.Result; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import java.io.IOException; -import java.io.PrintWriter; -import java.lang.annotation.Annotation; -import java.lang.reflect.Method; -import java.util.concurrent.TimeUnit; - -/** - * 请求拦截 - */ -@Slf4j -@Component -public class RequestLimitIntercept extends HandlerInterceptorAdapter { - - @Autowired - private RedisTemplate redisTemplate; - - @Override - public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception { - /** - * isAssignableFrom() 判定此 Class 对象所表示的类或接口与指定的 Class 参数所表示的类或接口是否相同,或是否是其超类或超接口 - * isAssignableFrom()方法是判断是否为某个类的父类 - * instanceof关键字是判断是否某个类的子类 - */ - if(handler.getClass().isAssignableFrom(HandlerMethod.class)){ - //HandlerMethod 封装方法定义相关的信息,如类,方法,参数等 - HandlerMethod handlerMethod = (HandlerMethod) handler; - Method method = handlerMethod.getMethod(); - // 如果 方法上有注解就优先选择方法上的参数,否则类上的参数 - RequestLimit requestLimit = getTagAnnotation(method, RequestLimit.class); - if(requestLimit != null){ - if(isLimit(request,requestLimit)){ - resonseOut(response,Result.error(ApiResultEnum.REQUST_LIMIT)); - return false; - } - } - } - return super.preHandle(request, response, handler); - } - //判断请求是否受限 - public boolean isLimit(HttpServletRequest request,RequestLimit requestLimit){ - // 受限的redis 缓存key ,因为这里用浏览器做测试,我就用sessionid 来做唯一key,如果是app ,可以使用 用户ID 之类的唯一标识。 - String limitKey = request.getServletPath()+request.getSession().getId(); - // 从缓存中获取,当前这个请求访问了几次 - Integer redisCount = (Integer) redisTemplate.opsForValue().get(limitKey); - if(redisCount == null){ - //初始 次数 - redisTemplate.opsForValue().set(limitKey,1,requestLimit.second(), TimeUnit.SECONDS); - }else{ - if(redisCount.intValue()>= requestLimit.maxCount()){ - return true; - } - // 次数自增 - redisTemplate.opsForValue().increment(limitKey); - } - return false; - } - - /** - * 获取目标注解 - * 如果方法上有注解就返回方法上的注解配置,否则类上的 - * @param method - * @param annotationClass - * @param - * @return - */ - public A getTagAnnotation(Method method, Class annotationClass) { - // 获取方法中是否包含注解 - Annotation methodAnnotate = method.getAnnotation(annotationClass); - //获取 类中是否包含注解,也就是controller 是否有注解 - Annotation classAnnotate = method.getDeclaringClass().getAnnotation(annotationClass); - return (A) (methodAnnotate!= null?methodAnnotate:classAnnotate); - } - - /** - * 回写给客户端 - * @param response - * @param result - * @throws IOException - */ - private void resonseOut(HttpServletResponse response, Result result) throws IOException { - response.setCharacterEncoding("UTF-8"); - response.setContentType("application/json; charset=utf-8"); - PrintWriter out = null ; - String json = JSONObject.toJSON(result).toString(); - out = response.getWriter(); - out.append(json); - } -} diff --git a/SpringBoot-limit/src/main/java/top/lrshuai/limit/service/LeakyBucketRateLimiter.java b/SpringBoot-limit/src/main/java/top/lrshuai/limit/service/LeakyBucketRateLimiter.java new file mode 100644 index 0000000..692fcf4 --- /dev/null +++ b/SpringBoot-limit/src/main/java/top/lrshuai/limit/service/LeakyBucketRateLimiter.java @@ -0,0 +1,108 @@ +package top.lrshuai.limit.service; + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.data.redis.core.RedisTemplate; +import org.springframework.data.redis.core.script.RedisScript; +import org.springframework.stereotype.Service; + +import java.util.Collections; + +@Slf4j +@Service +public class LeakyBucketRateLimiter { + + + @Autowired + private RedisTemplate redisTemplate; + + @Autowired + private RedisScript leakyBucketScript; + + /** + * 尝试获取通行证 + * @param key 限流key + * @param capacity 桶容量 + * @param rate 流出速率(每秒请求数) + * @param requestCount 请求数量 + * @return true-允许访问,false-被限流 + */ + public boolean tryAcquire(String key, int capacity, int rate, int requestCount) { + long now = System.currentTimeMillis() / 1000; // 使用秒级时间戳 + + Long result = redisTemplate.execute(leakyBucketScript, + Collections.singletonList(key), + capacity,rate,now,requestCount); + + return result != null && result == 1; + } + + /** + * 获取桶的当前状态(用于监控和调试) + */ + public BucketStatus getBucketStatus(String key) { + try { + // 使用 RedisTemplate 的哈希操作获取值 + Object waterObj = redisTemplate.opsForHash().get(key, "water"); + Object lastLeakTimeObj = redisTemplate.opsForHash().get(key, "lastLeakTime"); + Long ttl = redisTemplate.getExpire(key); + + long water = 0; + long lastLeakTime = 0; + + // 转换值 + if (waterObj != null) { + water = Long.parseLong(waterObj.toString()); + } + if (lastLeakTimeObj != null) { + lastLeakTime = Long.parseLong(lastLeakTimeObj.toString()); + } + + return new BucketStatus( + water, + lastLeakTime, + ttl != null ? ttl : -2 + ); + } catch (Exception e) { + log.error("Failed to get bucket status for key: {}", key, e); + return new BucketStatus(0, 0, -2); + } + } + + + /** + * 清理限流数据 + */ + public void reset(String key) { + redisTemplate.delete(key); + } + + /** + * 桶状态信息 + */ + @Data + @AllArgsConstructor + public static class BucketStatus { + /** + * 当前桶中积压的请求数量 + * 这个值表示漏桶中当前有多少个"水单位",每个水单位代表一个待处理的请求 + * 当有请求进入系统时,currentWater 会增加 + * 随着时间推移,水会以恒定速率从桶底漏出,currentWater 会相应减少 + * 如果 currentWater>= capacity(桶容量),新的请求会被拒绝 + */ + private long currentWater; + /** + * 最后一次计算漏水的时间戳 + * 用于计算从上次漏水到现在应该漏掉多少水 + * 计算公式:漏水量 = (当前时间 - lastLeakTime) ×ばつ 流出速率 + */ + private long lastLeakTime; + /** + * Redis 中该限流 key 的剩余生存时间(单位:秒) + * 表示这个限流桶还有多少秒会被 Redis 自动删除 + */ + private long ttl; + } +} diff --git a/SpringBoot-limit/src/main/java/top/lrshuai/limit/service/SlidingWindowRateLimiter.java b/SpringBoot-limit/src/main/java/top/lrshuai/limit/service/SlidingWindowRateLimiter.java new file mode 100644 index 0000000..e8fb51d --- /dev/null +++ b/SpringBoot-limit/src/main/java/top/lrshuai/limit/service/SlidingWindowRateLimiter.java @@ -0,0 +1,136 @@ +package top.lrshuai.limit.service; + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.data.redis.core.RedisTemplate; +import org.springframework.data.redis.core.script.RedisScript; +import org.springframework.stereotype.Service; + +import java.util.Collections; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; + +@Slf4j +@Service +public class SlidingWindowRateLimiter { + + @Autowired + private RedisTemplate redisTemplate; + + @Autowired + private RedisScript slidingWindowScript; + + /** + * 尝试获取通行证 + * @param key 限流key + * @param window 时间窗口大小(秒) + * @param maxCount 时间窗口内允许的最大请求数 + * @param requestCount 请求数量 + * @return true-允许访问,false-被限流 + */ + public boolean tryAcquire(String key, int window, int maxCount, int requestCount) { + long now = System.currentTimeMillis() / 1000; // 使用秒级时间戳 + Long result = redisTemplate.execute(slidingWindowScript, + Collections.singletonList(key), window, maxCount, now, requestCount); + + return result != null && result == 1; + } + + /** + * 获取时间窗口的当前状态 + */ + public WindowStatus getWindowStatus(String key, int window) { + try { + long now = System.currentTimeMillis() / 1000; + long windowStart = now - window; + + // 获取时间窗口内的请求总数 + Long count = redisTemplate.opsForZSet().count(key, windowStart, Double.MAX_VALUE); + + // 获取最早和最晚的请求时间 + Set members = redisTemplate.opsForZSet().range(key, 0, -1); + long earliestTime = 0; + long latestTime = 0; + + if (members != null && !members.isEmpty()) { + List times = members.stream() + .map(member -> Long.parseLong(member.toString()) / 1000) // 转回秒级 + .sorted() + .collect(Collectors.toList()); + + earliestTime = times.get(0); + latestTime = times.get(times.size() - 1); + } + Long ttl = redisTemplate.getExpire(key); + return new WindowStatus( + count != null ? count : 0, + windowStart, + now, + earliestTime, + latestTime, + ttl != null ? ttl : -2 + ); + } catch (Exception e) { + log.error("Failed to get window status for key: {}", key, e); + return new WindowStatus(0, 0, 0, 0, 0, -2); + } + } + + /** + * 清理限流数据 + */ + public void reset(String key) { + redisTemplate.delete(key); + } + + /** + * 时间窗口状态信息 + */ + @Data + @AllArgsConstructor + public static class WindowStatus { + /** + * 当前时间窗口内的请求数量 + */ + private long currentCount; + + /** + * 时间窗口起始时间戳(秒) + */ + private long windowStart; + + /** + * 当前时间戳(秒) + */ + private long currentTime; + + /** + * 时间窗口内最早的请求时间戳(秒) + */ + private long earliestRequestTime; + + /** + * 时间窗口内最晚的请求时间戳(秒) + */ + private long latestRequestTime; + + /** + * key的剩余生存时间(秒) + */ + private long ttl; + + @Override + public String toString() { + return String.format( + "WindowStatus{currentCount=%d, windowStart=%d, currentTime=%d, " + + "earliestRequest=%d, latestRequest=%d, ttl=%d}", + currentCount, windowStart, currentTime, + earliestRequestTime, latestRequestTime, ttl + ); + } + } + +} diff --git a/SpringBoot-limit/src/main/java/top/lrshuai/limit/service/TokenBucketRateLimiter.java b/SpringBoot-limit/src/main/java/top/lrshuai/limit/service/TokenBucketRateLimiter.java new file mode 100644 index 0000000..77d8cd5 --- /dev/null +++ b/SpringBoot-limit/src/main/java/top/lrshuai/limit/service/TokenBucketRateLimiter.java @@ -0,0 +1,64 @@ +package top.lrshuai.limit.service; + +import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.data.redis.core.RedisTemplate; +import org.springframework.data.redis.core.script.RedisScript; +import org.springframework.stereotype.Service; +import org.springframework.util.ObjectUtils; + +import java.util.Arrays; +import java.util.List; + +@Service +@Slf4j +public class TokenBucketRateLimiter { + + @Autowired + private RedisTemplate redisTemplate; + + @Autowired + private RedisScript tokenBucketScript; + + /** + * 尝试获取令牌 + * + * @param key 限流key + * @param rate 令牌生成速率 (个/秒) + * @param capacity 桶容量 + * @param tokenRequest 请求的令牌数 + * @return 是否允许访问 + */ + public boolean tryAcquire(String key, double rate, int capacity, int tokenRequest) { + // 转换为秒 + long now = System.currentTimeMillis() / 1000; + + List keys = Arrays.asList(key); + + @SuppressWarnings("unchecked") + List result = (List) redisTemplate.execute(tokenBucketScript, keys, rate, capacity, now, tokenRequest); + + if (ObjectUtils.isEmpty(result)) { + log.warn("令牌桶限流脚本执行异常, key: {}", key); + return false; + } + + boolean allowed = result.get(0) == 1; + long remainingTokens = result.get(1); + long bucketCapacity = result.get(2); + + if (log.isDebugEnabled()) { + log.debug("限流检查 - key: {}, 允许: {}, 剩余令牌: {}/{}, 请求令牌数: {}", + key, allowed, remainingTokens, bucketCapacity, tokenRequest); + } + + return allowed; + } + + /** + * 简化方法 - 默认请求1个令牌 + */ + public boolean tryAcquire(String key, double rate, int capacity) { + return tryAcquire(key, rate, capacity, 1); + } +} diff --git a/SpringBoot-limit/src/main/java/top/lrshuai/limit/util/AopUtil.java b/SpringBoot-limit/src/main/java/top/lrshuai/limit/util/AopUtil.java new file mode 100644 index 0000000..2266406 --- /dev/null +++ b/SpringBoot-limit/src/main/java/top/lrshuai/limit/util/AopUtil.java @@ -0,0 +1,38 @@ +package top.lrshuai.limit.util; + +import lombok.extern.slf4j.Slf4j; +import org.aspectj.lang.ProceedingJoinPoint; +import org.aspectj.lang.reflect.MethodSignature; +import org.springframework.expression.Expression; +import org.springframework.expression.ExpressionParser; +import org.springframework.expression.spel.standard.SpelExpressionParser; +import org.springframework.expression.spel.support.StandardEvaluationContext; + +@Slf4j +public class AopUtil { + + private static final ExpressionParser parser = new SpelExpressionParser(); + + /** + * 解析SpEL表达式 + */ + public static String parseSpel(String expression, ProceedingJoinPoint joinPoint) { + try { + MethodSignature signature = (MethodSignature) joinPoint.getSignature(); + StandardEvaluationContext context = new StandardEvaluationContext(); + + // 设置方法参数 + String[] parameterNames = signature.getParameterNames(); + Object[] args = joinPoint.getArgs(); + for (int i = 0; i < parameterNames.length; i++) { + context.setVariable(parameterNames[i], args[i]); + } + + Expression expr = parser.parseExpression(expression); + return expr.getValue(context, String.class); + } catch (Exception e) { + log.warn("解析SpEL表达式失败: {}", expression, e); + return expression; + } + } +} diff --git a/SpringBoot-limit/src/main/java/top/lrshuai/limit/util/IpUtil.java b/SpringBoot-limit/src/main/java/top/lrshuai/limit/util/IpUtil.java new file mode 100644 index 0000000..bb75c46 --- /dev/null +++ b/SpringBoot-limit/src/main/java/top/lrshuai/limit/util/IpUtil.java @@ -0,0 +1,64 @@ +package top.lrshuai.limit.util; + +import javax.servlet.http.HttpServletRequest; + +public class IpUtil { + + /** + * 获取客户端真实IP地址 + * 优先级: X-Forwarded-For -> X-Real-IP -> Proxy-Client-IP -> WL-Proxy-Client-IP -> RemoteAddr + * @param request HttpServletRequest对象 + * @return 客户端的真实IP地址 + */ + public static String getClientIpAddress(HttpServletRequest request) { + String ip = null; + + // 1. 优先检查X-Forwarded-For头部 + ip = getIpFromHeader(request, "X-Forwarded-For"); + if (isValidIp(ip)) { + // 取第一个非unknown的有效IP + String[] ips = ip.split(","); + for (String i : ips) { + i = i.trim(); + if (isValidIp(i) && !"unknown".equalsIgnoreCase(i)) { + return i; // 返回第一个有效的客户端IP + } + } + } + + // 2. 检查其他头部,按优先级排序 + String[] headers = {"X-Real-IP", "Proxy-Client-IP", "WL-Proxy-Client-IP"}; + for (String header : headers) { + ip = getIpFromHeader(request, header); + if (isValidIp(ip)) { + return ip; + } + } + + // 3. 最后使用远程地址 + ip = request.getRemoteAddr(); + return "0:0:0:0:0:0:0:1".equals(ip) ? "127.0.0.1" : ip; // 处理本地IPv6回环地址 + } + + /** + * 从请求头中获取IP值 + */ + public static String getIpFromHeader(HttpServletRequest request, String headerName) { + String ip = request.getHeader(headerName); + if (ip == null || ip.isEmpty() || "unknown".equalsIgnoreCase(ip)) { + return null; + } + return ip.trim(); + } + + /** + * 基础IP地址有效性验证 + */ + private static boolean isValidIp(String ip) { + if (ip == null || ip.isEmpty() || "unknown".equalsIgnoreCase(ip)) { + return false; + } + // 基础格式验证,可根据需要增强(例如使用正则表达式或InetAddress验证) + return ip.chars().allMatch(ch -> ch == '.' || Character.isDigit(ch) || (ch>= 'a' && ch <= 'f') || (ch>= 'A' && ch <= 'F') || ch == ':'); + } +} diff --git a/SpringBoot-limit/src/main/resources/application-dev.yml b/SpringBoot-limit/src/main/resources/application-dev.yml index e1b8564..2ded96a 100644 --- a/SpringBoot-limit/src/main/resources/application-dev.yml +++ b/SpringBoot-limit/src/main/resources/application-dev.yml @@ -2,7 +2,7 @@ server: port: 8000 spring: redis: - password: + password: abcsee2see database: 0 port: 6379 host: 127.0.0.1 @@ -13,3 +13,7 @@ spring: max-active: 10 max-wait: -1ms timeout: 10000ms +logging: + level: + root: info + top.lrshuai.limit: debug diff --git a/SpringBoot-limit/src/main/resources/lua/leakyBucket.lua b/SpringBoot-limit/src/main/resources/lua/leakyBucket.lua new file mode 100644 index 0000000..d76b88f --- /dev/null +++ b/SpringBoot-limit/src/main/resources/lua/leakyBucket.lua @@ -0,0 +1,47 @@ +-- 漏桶限流算法 Lua 脚本 + +-- 参数说明: +-- KEYS[1]: 限流的key +-- ARGV[1]: 桶的容量 +-- ARGV[2]: 流出速率(每秒处理数) +-- ARGV[3]: 当前时间戳(秒) +-- ARGV[4]: 本次请求数量 + +local key = KEYS[1] +local capacity = tonumber(ARGV[1]) +local rate = tonumber(ARGV[2]) +local now = tonumber(ARGV[3]) +local requestCount = tonumber(ARGV[4]) + +-- 获取桶的当前状态 +local bucketInfo = redis.call('hmget', key, 'water', 'lastLeakTime') +local currentWater = 0 +local lastLeakTime = now + +-- 如果桶存在,获取当前水量和上次漏水时间 +if bucketInfo[1] then + currentWater = tonumber(bucketInfo[1]) +end + +if bucketInfo[2] then + lastLeakTime = tonumber(bucketInfo[2]) +end + +-- 计算从上次漏水到现在的漏出量 +local leakAmount = (now - lastLeakTime) * rate +if leakAmount> 0 then + currentWater = math.max(0, currentWater - leakAmount) + lastLeakTime = now +end + +-- 检查桶是否有足够空间容纳新请求 +if currentWater + requestCount <= capacity then + -- 允许请求,更新桶状态 + currentWater = currentWater + requestCount + redis.call('hmset', key, 'water', currentWater, 'lastLeakTime', lastLeakTime) + redis.call('expire', key, 3600) -- 设置过期时间,防止内存泄漏 + return 1 -- 允许访问 +else + -- 桶已满,拒绝请求 + return 0 -- 被限流 +end \ No newline at end of file diff --git a/SpringBoot-limit/src/main/resources/lua/slidingWindow.lua b/SpringBoot-limit/src/main/resources/lua/slidingWindow.lua new file mode 100644 index 0000000..20355c5 --- /dev/null +++ b/SpringBoot-limit/src/main/resources/lua/slidingWindow.lua @@ -0,0 +1,41 @@ +-- 滑动时间窗口计数器限流算法 + +-- 参数说明: +-- KEYS[1]: 限流的key +-- ARGV[1]: 时间窗口大小(秒) +-- ARGV[2]: 时间窗口内允许的最大请求数 +-- ARGV[3]: 当前时间戳(秒) +-- ARGV[4]: 本次请求数量(默认为1) + +local key = KEYS[1] +local window = tonumber(ARGV[1]) +local maxCount = tonumber(ARGV[2]) +local now = tonumber(ARGV[3]) +local requestCount = tonumber(ARGV[4]) or 1 + +-- 计算时间窗口的起始时间戳 +local windowStart = now - window + +-- 移除时间窗口之前的数据 +redis.call('zremrangebyscore', key, 0, windowStart) + +-- 获取当前时间窗口内的请求总数 +local currentCount = redis.call('zcard', key) + +-- 检查是否超过限制 +if currentCount + requestCount <= maxCount then + -- 没有超过限制,添加当前请求 + for i = 1, requestCount do + -- 使用毫秒级时间戳+随机数确保成员唯一性 + local member = now * 1000 + math.random(0, 999) + redis.call('zadd', key, member, member) + end + + -- 设置key的过期时间为窗口大小+1秒,确保数据自动清理 + redis.call('expire', key, window + 1) + + return 1 -- 允许访问 +else + -- 超过限制,拒绝请求 + return 0 -- 被限流 +end \ No newline at end of file diff --git a/SpringBoot-limit/src/main/resources/lua/tokenRate.lua b/SpringBoot-limit/src/main/resources/lua/tokenRate.lua new file mode 100644 index 0000000..df9db86 --- /dev/null +++ b/SpringBoot-limit/src/main/resources/lua/tokenRate.lua @@ -0,0 +1,51 @@ +-- 令牌桶限流 Lua 脚本 +-- KEYS[1]: 限流的key +-- ARGV[1]: 令牌生成速率 (每秒生成的令牌数) +-- ARGV[2]: 桶的容量 (最大令牌数) +-- ARGV[3]: 当前时间戳 (秒) +-- ARGV[4]: 本次请求的令牌数 (默认为1) + +local key = KEYS[1] +local rate = tonumber(ARGV[1]) +local capacity = tonumber(ARGV[2]) +local now = tonumber(ARGV[3]) +local requested = tonumber(ARGV[4]) + +-- 计算填满桶需要的时间,用于设置key的过期时间 +local fill_time = capacity / rate +local ttl = math.floor(fill_time * 2) -- 过期时间为填满时间的2倍 + +-- 从Redis获取上次的令牌数和刷新时间 +local last_tokens = tonumber(redis.call("get", key)) +if last_tokens == nil then + last_tokens = capacity -- 第一次访问,令牌数为桶容量 +end + +local last_refreshed = tonumber(redis.call("get", key .. ":ts")) +if last_refreshed == nil then + last_refreshed = now -- 第一次访问,刷新时间为当前时间 +end + +-- 计算时间差和应该补充的令牌数 +local delta = math.max(0, now - last_refreshed) +local filled_tokens = math.min(capacity, last_tokens + (delta * rate)) + +-- 判断是否允许本次请求 +local allowed = filled_tokens>= requested +local new_tokens = filled_tokens +local allowed_num = 0 + +if allowed then + new_tokens = filled_tokens - requested + allowed_num = 1 + -- 更新令牌数和时间戳 + redis.call("setex", key, ttl, new_tokens) + redis.call("setex", key .. ":ts", ttl, now) +else + -- 即使不允许,也更新状态(为了计算下一次的令牌数) + redis.call("setex", key, ttl, new_tokens) + redis.call("setex", key .. ":ts", ttl, last_refreshed) +end + +-- 返回结果:是否允许(1/0),剩余令牌数,桶容量 +return {allowed_num, new_tokens, capacity} \ No newline at end of file diff --git a/springboot-2FA/.gitignore b/springboot-2FA/.gitignore new file mode 100644 index 0000000..667aaef --- /dev/null +++ b/springboot-2FA/.gitignore @@ -0,0 +1,33 @@ +HELP.md +target/ +.mvn/wrapper/maven-wrapper.jar +!**/src/main/**/target/ +!**/src/test/**/target/ + +### STS ### +.apt_generated +.classpath +.factorypath +.project +.settings +.springBeans +.sts4-cache + +### IntelliJ IDEA ### +.idea +*.iws +*.iml +*.ipr + +### NetBeans ### +/nbproject/private/ +/nbbuild/ +/dist/ +/nbdist/ +/.nb-gradle/ +build/ +!**/src/main/**/build/ +!**/src/test/**/build/ + +### VS Code ### +.vscode/ diff --git a/springboot-2FA/pom.xml b/springboot-2FA/pom.xml new file mode 100644 index 0000000..82967aa --- /dev/null +++ b/springboot-2FA/pom.xml @@ -0,0 +1,51 @@ + + + 4.0.0 + + org.springframework.boot + spring-boot-starter-parent + 3.5.7 + + + top.lrshuai.ai + springboot-2FA + 0.0.1-SNAPSHOT + springboot-2FA + 身份验证器 demo + + + 17 + 1.19.0 + + + + + org.springframework.boot + spring-boot-starter-web + + + + org.springframework.boot + spring-boot-starter-test + test + + + + commons-codec + commons-codec + ${commons-codec.version} + + + + + + + + org.springframework.boot + spring-boot-maven-plugin + + + + + diff --git a/springboot-2FA/src/main/java/top/lrshuai/totp/Springboot2FaApplication.java b/springboot-2FA/src/main/java/top/lrshuai/totp/Springboot2FaApplication.java new file mode 100644 index 0000000..b39f761 --- /dev/null +++ b/springboot-2FA/src/main/java/top/lrshuai/totp/Springboot2FaApplication.java @@ -0,0 +1,13 @@ +package top.lrshuai.totp; + +import org.springframework.boot.SpringApplication; +import org.springframework.boot.autoconfigure.SpringBootApplication; + +@SpringBootApplication +public class Springboot2FaApplication { + + public static void main(String[] args) { + SpringApplication.run(Springboot2FaApplication.class, args); + } + +} diff --git a/springboot-2FA/src/main/java/top/lrshuai/totp/auth/GoogleAuthenticator.java b/springboot-2FA/src/main/java/top/lrshuai/totp/auth/GoogleAuthenticator.java new file mode 100644 index 0000000..83699c1 --- /dev/null +++ b/springboot-2FA/src/main/java/top/lrshuai/totp/auth/GoogleAuthenticator.java @@ -0,0 +1,612 @@ +package top.lrshuai.totp.auth; + + +import org.apache.commons.codec.binary.Base32; +import org.apache.commons.codec.binary.Hex; + +import java.net.URLEncoder; +import java.nio.charset.StandardCharsets; +import java.security.NoSuchAlgorithmException; +import java.security.SecureRandom; +import java.util.HashMap; +import java.util.Map; + +/** + * Google Authenticator 工具类 + * 基于 TOTP (Time-based One-Time Password) 算法实现双因素认证 + * 支持 HMAC-SHA1、HMAC-SHA256、HMAC-SHA512 算法 + * 参考 RFC 6238 标准,兼容 Google Authenticator 移动应用 + * + * 主要功能: + * - 生成随机密钥(支持不同算法推荐长度) + * - 生成TOTP动态验证码 + * - 生成Google Authenticator可识别的二维码数据 + * - 验证用户输入的验证码 + * - 支持多种HMAC算法 + * + * @author rstyro + */ +public final class GoogleAuthenticator { + + /** 默认密钥长度(字节)- SHA1 */ + public static final int DEFAULT_SECRET_KEY_LENGTH_SHA1 = 20; + /** SHA256算法推荐密钥长度(字节) */ + public static final int DEFAULT_SECRET_KEY_LENGTH_SHA256 = 32; + /** SHA512算法推荐密钥长度(字节) */ + public static final int DEFAULT_SECRET_KEY_LENGTH_SHA512 = 64; + + /** 默认密钥长度 */ + private static final int DEFAULT_SECRET_KEY_LENGTH = DEFAULT_SECRET_KEY_LENGTH_SHA1; + + /** 默认时间窗口大小(30秒单位) */ + private static final int DEFAULT_WINDOW_SIZE = 2; + /** 最大允许的时间窗口大小 */ + private static final int MAX_WINDOW_SIZE = 17; + /** 时间步长(秒) */ + private static final long TIME_STEP = 30L; + /** 验证码位数 */ + private static final int CODE_DIGITS = 6; + + /** 算法名称常量 */ + public static final String HMAC_SHA1 = "HmacSHA1"; + public static final String HMAC_SHA256 = "HmacSHA256"; + public static final String HMAC_SHA512 = "HmacSHA512"; + + /** 默认算法 */ + private static final String DEFAULT_ALGORITHM = HMAC_SHA1; + + /** 算法对应的推荐密钥长度映射 */ + private static final Map ALGORITHM_KEY_LENGTH_MAP = new HashMap(); + + static { + ALGORITHM_KEY_LENGTH_MAP.put(HMAC_SHA1, DEFAULT_SECRET_KEY_LENGTH_SHA1); + ALGORITHM_KEY_LENGTH_MAP.put(HMAC_SHA256, DEFAULT_SECRET_KEY_LENGTH_SHA256); + ALGORITHM_KEY_LENGTH_MAP.put(HMAC_SHA512, DEFAULT_SECRET_KEY_LENGTH_SHA512); + } + + /** 当前时间窗口大小 */ + private static int windowSize = DEFAULT_WINDOW_SIZE; + /** 当前使用的算法 */ + private static String currentAlgorithm = DEFAULT_ALGORITHM; + + /** + * 私有构造方法,防止实例化 + */ + private GoogleAuthenticator() { + throw new AssertionError("GoogleAuthenticator是工具类,不能实例化"); + } + + // ==================== 密钥生成相关方法 ==================== + + /** + * 生成随机的Base32编码密钥(使用默认算法和长度) + * 密钥用于在客户端和服务器端之间共享,用于生成验证码 + * + * @return Base32编码的随机密钥(大写,无分隔符) + * @throws SecurityException 如果随机数生成失败 + */ + public static String generateRandomSecretKey() { + return generateRandomSecretKey(DEFAULT_SECRET_KEY_LENGTH); + } + + /** + * 生成指定长度的随机Base32编码密钥 + * + * @param length 密钥长度(字节) + * @return Base32编码的随机密钥 + */ + public static String generateRandomSecretKey(int length) { + try { + SecureRandom random = SecureRandom.getInstanceStrong(); + byte[] bytes = new byte[length]; + random.nextBytes(bytes); + + Base32 base32 = new Base32(); + return base32.encodeToString(bytes).toUpperCase(); + } catch (NoSuchAlgorithmException e) { + throw new SecurityException("安全随机数生成器不可用", e); + } + } + + /** + * 为指定算法生成推荐长度的随机密钥 + * + * @param algorithm 算法(HMAC_SHA1, HMAC_SHA256, HMAC_SHA512) + * @return Base32编码的随机密钥 + */ + public static String generateRandomSecretKey(String algorithm) { + Integer length = ALGORITHM_KEY_LENGTH_MAP.get(algorithm); + if (length == null) { + throw new IllegalArgumentException("不支持的算法: " + algorithm + + ",支持的算法: " + ALGORITHM_KEY_LENGTH_MAP.keySet()); + } + return generateRandomSecretKey(length); + } + + /** + * 生成指定算法和长度的随机密钥 + * + * @param algorithm 算法 + * @param length 密钥长度 + * @return Base32编码的随机密钥 + */ + public static String generateRandomSecretKey(String algorithm, int length) { + Integer recommendedLength = ALGORITHM_KEY_LENGTH_MAP.get(algorithm); + if (recommendedLength != null && length < recommendedLength) { + System.err.println("警告: 密钥长度" + length + "字节小于" + algorithm + + "推荐长度" + recommendedLength + "字节,可能存在安全风险"); + } + return generateRandomSecretKey(length); + } + + // ==================== TOTP生成方法 ==================== + + /** + * 生成当前时间的TOTP验证码(默认SHA1算法) + * + * @param secretKey Base32编码的共享密钥 + * @return 6位数字的TOTP验证码 + * @throws IllegalArgumentException 如果密钥为空或格式错误 + * @throws SecurityException 如果加密操作失败 + */ + public static String generateTOTPCode(String secretKey) { + return generateTOTPCode(secretKey, DEFAULT_ALGORITHM); + } + + /** + * 生成当前时间的TOTP验证码(指定算法) + * + * @param secretKey Base32编码的共享密钥 + * @param algorithm 算法(HMAC_SHA1, HMAC_SHA256, HMAC_SHA512) + * @return 6位数字的TOTP验证码 + */ + public static String generateTOTPCode(String secretKey, String algorithm) { + validateSecretKey(secretKey); + validateAlgorithm(algorithm); + + try { + // 标准化密钥:移除空格并转为大写 + String normalizedKey = secretKey.replace(" ", "").toUpperCase(); + Base32 base32 = new Base32(); + byte[] decodedBytes = base32.decode(normalizedKey); + String hexKey = Hex.encodeHexString(decodedBytes); + + // 计算当前时间窗口 + long timeWindow = (System.currentTimeMillis() / 1000L) / TIME_STEP; + String hexTime = Long.toHexString(timeWindow); + + // 调用TOTP类的对应方法 + switch (algorithm) { + case HMAC_SHA256: + return TOTP.generateTOTP256(hexKey, hexTime, CODE_DIGITS); + case HMAC_SHA512: + return TOTP.generateTOTP512(hexKey, hexTime, CODE_DIGITS); + case HMAC_SHA1: + default: + return TOTP.generateTOTP(hexKey, hexTime, CODE_DIGITS, algorithm); + } + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException("无效的密钥格式: " + e.getMessage(), e); + } catch (Exception e) { + throw new SecurityException("生成TOTP验证码失败: " + e.getMessage(), e); + } + } + + /** + * 生成当前时间的TOTP验证码(SHA256算法) + * + * @param secretKey Base32编码的共享密钥 + * @return 6位数字的TOTP验证码 + */ + public static String generateTOTPCode256(String secretKey) { + return generateTOTPCode(secretKey, HMAC_SHA256); + } + + /** + * 生成当前时间的TOTP验证码(SHA512算法) + * + * @param secretKey Base32编码的共享密钥 + * @return 6位数字的TOTP验证码 + */ + public static String generateTOTPCode512(String secretKey) { + return generateTOTPCode(secretKey, HMAC_SHA512); + } + + /** + * 生成指定时间戳的TOTP验证码 + * + * @param secretKey Base32编码的共享密钥 + * @param timestamp 时间戳(毫秒) + * @param algorithm 算法 + * @return 6位数字的TOTP验证码 + */ + public static String generateTOTPCode(String secretKey, long timestamp, String algorithm) { + validateSecretKey(secretKey); + validateAlgorithm(algorithm); + + try { + String normalizedKey = secretKey.replace(" ", "").toUpperCase(); + Base32 base32 = new Base32(); + byte[] decodedBytes = base32.decode(normalizedKey); + String hexKey = Hex.encodeHexString(decodedBytes); + + long timeWindow = (timestamp / 1000L) / TIME_STEP; + String hexTime = Long.toHexString(timeWindow); + + switch (algorithm) { + case HMAC_SHA256: + return TOTP.generateTOTP256(hexKey, hexTime, CODE_DIGITS); + case HMAC_SHA512: + return TOTP.generateTOTP512(hexKey, hexTime, CODE_DIGITS); + case HMAC_SHA1: + default: + return TOTP.generateTOTP(hexKey, hexTime, CODE_DIGITS, algorithm); + } + } catch (Exception e) { + throw new SecurityException("生成TOTP验证码失败: " + e.getMessage(), e); + } + } + + // ==================== 二维码生成方法 ==================== + + /** + * 生成Google Authenticator二维码内容URL(默认SHA1算法) + * + * @param secretKey 共享密钥 + * @param account 用户账号(如邮箱或用户名) + * @param issuer 发行者名称(应用或网站名称) + * @return 二维码内容URL + * @throws IllegalArgumentException 如果参数为空或格式错误 + */ + public static String generateQRCodeUrl(String secretKey, String account, String issuer) { + return generateQRCodeUrl(secretKey, account, issuer, DEFAULT_ALGORITHM); + } + + /** + * 生成Google Authenticator二维码内容URL(指定算法) + * 注意:Google Authenticator应用可能不支持SHA256/SHA512 + * + * @param secretKey 共享密钥 + * @param account 用户账号 + * @param issuer 发行者名称 + * @param algorithm 算法 + * @return 二维码内容URL + */ + public static String generateQRCodeUrl(String secretKey, String account, String issuer, String algorithm) { + validateParameters(secretKey, account, issuer); + validateAlgorithm(algorithm); + + String normalizedKey = secretKey.replace(" ", "").toUpperCase(); + + // 构建OTP Auth URL,符合Google Authenticator标准格式 + StringBuilder url = new StringBuilder("otpauth://totp/") + .append(URLEncoder.encode(issuer + ":" + account, StandardCharsets.UTF_8).replace("+", "%20")) + .append("?secret=").append(URLEncoder.encode(normalizedKey, StandardCharsets.UTF_8).replace("+", "%20")) + .append("&issuer=").append(URLEncoder.encode(issuer, StandardCharsets.UTF_8).replace("+", "%20")); + + // 添加算法参数(SHA1是默认值,可以省略) + if (!HMAC_SHA1.equals(algorithm)) { + url.append("&algorithm=").append(algorithm.toUpperCase()); + } + + // 添加位数参数 + url.append("&digits=").append(CODE_DIGITS); + + // 添加时间步长参数 + url.append("&period=").append(TIME_STEP); + + return url.toString(); + } + + // ==================== 验证方法 ==================== + + /** + * 验证TOTP验证码(默认SHA1算法) + * 考虑时间窗口偏移,以处理客户端和服务端之间的时间差异 + * + * @param secretKey 共享密钥 + * @param code 待验证的验证码 + * @param timestamp 时间戳(毫秒) + * @return 验证是否成功 + * @throws IllegalArgumentException 如果参数无效 + */ + public static boolean verifyCode(String secretKey, long code, long timestamp) { + return verifyCode(secretKey, code, timestamp, DEFAULT_ALGORITHM); + } + + /** + * 验证TOTP验证码(指定算法) + * + * @param secretKey 共享密钥 + * @param code 待验证的验证码 + * @param timestamp 时间戳(毫秒) + * @param algorithm 算法 + * @return 验证是否成功 + */ + public static boolean verifyCode(String secretKey, long code, long timestamp, String algorithm) { + validateSecretKey(secretKey); + validateAlgorithm(algorithm); + + if (code < 0 || code> 999999) { + throw new IllegalArgumentException("验证码必须是6位数字"); + } + + // 计算基准时间窗口 + long timeWindow = (timestamp / 1000L) / TIME_STEP; + + // 检查当前及前后时间窗口内的验证码 + for (int i = -windowSize; i <= windowSize; i++) { + try { + String generatedCode = generateTOTPCode(secretKey, timestamp + (i * TIME_STEP * 1000L), algorithm); + if (Long.parseLong(generatedCode) == code) { + return true; + } + } catch (Exception e) { + // 记录日志但继续检查其他时间窗口 + System.err.println("验证码验证过程中出现异常: " + e.getMessage()); + } + } + + return false; + } + + /** + * 验证当前时间的TOTP验证码(默认SHA1算法) + * + * @param secretKey 共享密钥 + * @param code 待验证的验证码 + * @return 验证是否成功 + */ + public static boolean verifyCurrentCode(String secretKey, long code) { + return verifyCode(secretKey, code, System.currentTimeMillis(), DEFAULT_ALGORITHM); + } + + /** + * 验证当前时间的TOTP验证码(指定算法) + * + * @param secretKey 共享密钥 + * @param code 待验证的验证码 + * @param algorithm 算法 + * @return 验证是否成功 + */ + public static boolean verifyCurrentCode(String secretKey, long code, String algorithm) { + return verifyCode(secretKey, code, System.currentTimeMillis(), algorithm); + } + + /** + * 验证当前时间的TOTP验证码字符串(更易用的方法) + * + * @param secretKey 共享密钥 + * @param code 待验证的验证码字符串 + * @return 验证是否成功 + */ + public static boolean verifyCurrentCode(String secretKey, String code) { + return verifyCurrentCode(secretKey, code, DEFAULT_ALGORITHM); + } + + /** + * 验证当前时间的TOTP验证码字符串(指定算法) + * + * @param secretKey 共享密钥 + * @param code 待验证的验证码字符串 + * @param algorithm 算法 + * @return 验证是否成功 + */ + public static boolean verifyCurrentCode(String secretKey, String code, String algorithm) { + try { + long codeValue = Long.parseLong(code); + return verifyCurrentCode(secretKey, codeValue, algorithm); + } catch (NumberFormatException e) { + return false; + } + } + + /** + * 验证当前时间的TOTP验证码字符串(SHA256算法) + * + * @param secretKey 共享密钥 + * @param code 待验证的验证码字符串 + * @return 验证是否成功 + */ + public static boolean verifyCurrentCode256(String secretKey, String code) { + return verifyCurrentCode(secretKey, code, HMAC_SHA256); + } + + /** + * 验证当前时间的TOTP验证码字符串(SHA512算法) + * + * @param secretKey 共享密钥 + * @param code 待验证的验证码字符串 + * @return 验证是否成功 + */ + public static boolean verifyCurrentCode512(String secretKey, String code) { + return verifyCurrentCode(secretKey, code, HMAC_SHA512); + } + + // ==================== 配置方法 ==================== + + /** + * 设置验证时间窗口大小 + * 时间窗口大小决定了允许的时间偏移范围(每个窗口30秒) + * @param size 窗口大小(1-17) + * @throws IllegalArgumentException 如果窗口大小超出范围 + */ + public static void setWindowSize(int size) { + if (size < 1 || size> MAX_WINDOW_SIZE) { + throw new IllegalArgumentException("窗口大小必须在1到" + MAX_WINDOW_SIZE + "之间"); + } + windowSize = size; + } + + /** + * 设置默认算法 + * + * @param algorithm 算法(HMAC_SHA1, HMAC_SHA256, HMAC_SHA512) + */ + public static void setDefaultAlgorithm(String algorithm) { + validateAlgorithm(algorithm); + currentAlgorithm = algorithm; + } + + /** + * 获取当前时间窗口大小 + * + * @return 当前时间窗口大小 + */ + public static int getWindowSize() { + return windowSize; + } + + /** + * 获取当前默认算法 + * + * @return 当前算法 + */ + public static String getDefaultAlgorithm() { + return currentAlgorithm; + } + + /** + * 获取算法对应的推荐密钥长度 + * + * @param algorithm 算法 + * @return 推荐密钥长度(字节) + */ + public static int getRecommendedKeyLength(String algorithm) { + Integer length = ALGORITHM_KEY_LENGTH_MAP.get(algorithm); + if (length == null) { + throw new IllegalArgumentException("不支持的算法: " + algorithm); + } + return length; + } + + /** + * 获取支持的算法列表 + * + * @return 支持的算法名称数组 + */ + public static String[] getSupportedAlgorithms() { + return new String[]{HMAC_SHA1, HMAC_SHA256, HMAC_SHA512}; + } + + // ==================== 辅助方法 ==================== + + /** + * 验证密钥格式 + */ + private static void validateSecretKey(String secretKey) { + if (secretKey == null || secretKey.trim().isEmpty()) { + throw new IllegalArgumentException("密钥不能为空"); + } + if (!secretKey.matches("^[A-Z2-7=\\s]+$")) { + throw new IllegalArgumentException("密钥必须包含有效的Base32字符(A-Z, 2-7)"); + } + } + + /** + * 验证算法 + */ + private static void validateAlgorithm(String algorithm) { + if (!ALGORITHM_KEY_LENGTH_MAP.containsKey(algorithm)) { + throw new IllegalArgumentException("不支持的算法: " + algorithm + + ",支持的算法: " + String.join(", ", ALGORITHM_KEY_LENGTH_MAP.keySet())); + } + } + + /** + * 验证二维码生成参数 + */ + private static void validateParameters(String secretKey, String account, String issuer) { + validateSecretKey(secretKey); + + if (account == null || account.trim().isEmpty()) { + throw new IllegalArgumentException("账号不能为空"); + } + if (issuer == null || issuer.trim().isEmpty()) { + throw new IllegalArgumentException("发行者名称不能为空"); + } + } + + /** + * 获取算法的显示名称 + * + * @param algorithm 算法标识 + * @return 显示名称 + */ + public static String getAlgorithmDisplayName(String algorithm) { + switch (algorithm) { + case HMAC_SHA1: return "HMAC-SHA1"; + case HMAC_SHA256: return "HMAC-SHA256"; + case HMAC_SHA512: return "HMAC-SHA512"; + default: return algorithm; + } + } + + // ==================== 测试方法 ==================== + + /** + * 完整测试示例 + */ + public static void testAllAlgorithms() { + System.out.println("=== Google Authenticator 多算法测试 ===\n"); + + String[] algorithms = {HMAC_SHA1, HMAC_SHA256, HMAC_SHA512}; + + for (String algorithm : algorithms) { + System.out.println("\n--- 测试 " + getAlgorithmDisplayName(algorithm) + " 算法 ---"); + + // 生成密钥 + String secretKey = generateRandomSecretKey(algorithm); + int recommendedLength = getRecommendedKeyLength(algorithm); + System.out.println("1. 生成密钥 (" + recommendedLength + "字节): " + secretKey); + + // 生成当前验证码 + String totpCode = generateTOTPCode(secretKey, algorithm); + System.out.println("2. 当前TOTP验证码: " + totpCode); + + // 生成二维码URL + String qrCodeUrl = generateQRCodeUrl(secretKey, "test@example.com", "TOTP-Test", algorithm); + System.out.println("3. 二维码URL: " + (qrCodeUrl.length()> 100 ? qrCodeUrl.substring(0, 100) + "..." : qrCodeUrl)); + + // 验证验证码 + boolean isValid = verifyCurrentCode(secretKey, totpCode, algorithm); + System.out.println("4. 验证码验证结果: " + (isValid ? "✓ 通过" : "✗ 失败")); + + // 错误验证码测试 + boolean isInvalid = verifyCurrentCode(secretKey, "123456", algorithm); + System.out.println("5. 错误验证码测试: " + (!isInvalid ? "✓ 测试通过" : "✗ 测试不通过-错误验证码也通过")); + } + + System.out.println("\n=== 测试完成 ==="); + } + + /** + * 主方法:测试多算法支持 + */ + public static void main(String[] args) { + // 测试所有算法 +// testAllAlgorithms(); + + // 或者单独测试特定算法 + String secretKey = "NIHMRAK5ZS73PC3HOAGDTK65QDNCZ6QY"; + String totp1 = generateTOTPCode(secretKey); + System.out.println("URL: " +generateQRCodeUrl(secretKey, "test-sha1@example.com", "TOTP-Test", HMAC_SHA1)); + System.out.println("当前SHA1验证码: " + totp1); + System.out.println("当前SHA1验证码: " + verifyCurrentCode(secretKey, "050761")); + System.out.println(); + + String secretKey256 = "VBS6IG6VLRSRVPZUQBFM6G6WE6YGXRF7SCFTUVBJPTWUMPRBAWVQ===="; + String totp256 = generateTOTPCode256(secretKey256); + System.out.println("URL: " +generateQRCodeUrl(secretKey256, "testsha256@example.com", "TOTP-Test", HMAC_SHA256)); + System.out.println("当前SHA256验证码: " + totp256); + System.out.println("当前SHA256验证码: " + verifyCurrentCode256(secretKey256, "794120")); + System.out.println(); + + String secretKey512 = "Z535MJVUZWDXKRXHB7LMDS7YMTZOEZE37ZUXAXF6TKMU4MLOZGCHFFPAPY43EMW7MUZJZ7W74T2PFCEUVWRN4Z36XXGPZIX6W7XVIKI="; + String totp512 = generateTOTPCode256(secretKey512); + System.out.println("URL: " +generateQRCodeUrl(secretKey512, "testsha512@example.com", "TOTP-Test", HMAC_SHA512)); + System.out.println("当前SHA512验证码: " + totp512); + System.out.println("当前SHA512验证码: " + verifyCurrentCode512(secretKey512, "149488")); + + } +} \ No newline at end of file diff --git a/springboot-2FA/src/main/java/top/lrshuai/totp/auth/TOTP.java b/springboot-2FA/src/main/java/top/lrshuai/totp/auth/TOTP.java new file mode 100644 index 0000000..7cb7085 --- /dev/null +++ b/springboot-2FA/src/main/java/top/lrshuai/totp/auth/TOTP.java @@ -0,0 +1,307 @@ +package top.lrshuai.totp.auth; + +import javax.crypto.Mac; +import javax.crypto.spec.SecretKeySpec; +import java.math.BigInteger; +import java.security.GeneralSecurityException; +import java.time.Instant; +import java.time.ZoneId; +import java.time.format.DateTimeFormatter; + +/** + * TOTP (Time-based One-Time Password) 算法实现 + * 基于 RFC 6238 标准,用于生成基于时间的一次性密码。 + * 该类是工具类,所有方法均为静态方法,不可实例化。 + * 功能特点: + * - 支持 HMAC-SHA1、HMAC-SHA256、HMAC-SHA512 算法 + * - 可自定义密码位数(1-8位)和时间步长 + * - 提供密码验证功能,支持时间偏移容错 + * 使用示例: + * String key = "3132333435363738393031323334353637383930"; + * String totp = TOTP.generateCurrentTOTP(key); + * boolean isValid = TOTP.verifyTOTP(key, "123456"); + * + * @author rstyro + */ +public final class TOTP { + + /** + * 数字幂数组,用于计算10的n次方,索引对应位数(1-8位) + * 例如:DIGITS_POWER[6] = 1000000 + */ + private static final int[] DIGITS_POWER = {1, 10, 100, 1000, 10000, 100000, 1000000, 10000000, 100000000}; + + /** HMAC-SHA1 算法标识 */ + public static final String HMAC_SHA1 = "HmacSHA1"; + /** HMAC-SHA256 算法标识 */ + public static final String HMAC_SHA256 = "HmacSHA256"; + /** HMAC-SHA512 算法标识 */ + public static final String HMAC_SHA512 = "HmacSHA512"; + + /** 默认动态密码位数(6位) */ + private static final int DEFAULT_DIGITS = 6; + /** 默认时间步长(秒) */ + private static final long DEFAULT_TIME_STEP = 30L; + /** 默认起始时间(Unix纪元) */ + private static final long DEFAULT_START_TIME = 0L; + /** 默认验证时间窗口大小(允许前后偏移的步数) */ + private static final int DEFAULT_TIME_WINDOW = 1; + + /** + * 私有构造方法,防止类实例化 + * 工具类应避免实例化,所有方法均为静态方法 + */ + private TOTP() { + throw new AssertionError("TOTP 是工具类,不能实例化"); + } + + /** + * 使用HMAC算法计算哈希值 + * + * @param crypto 加密算法 (HmacSHA1, HmacSHA256, HmacSHA512) + * @param keyBytes 密钥字节数组 + * @param text 要认证的消息文本 + * @return HMAC哈希值 + * @throws GeneralSecurityException 安全算法异常 + */ + private static byte[] hmacSha(String crypto, byte[] keyBytes, byte[] text) + throws GeneralSecurityException { + Mac hmac = Mac.getInstance(crypto); + SecretKeySpec macKey = new SecretKeySpec(keyBytes, "RAW"); + hmac.init(macKey); + return hmac.doFinal(text); + } + + /** + * 将十六进制字符串转换为字节数组 + * + * @param hex 十六进制字符串 + * @return 字节数组 + * @throws IllegalArgumentException 当十六进制字符串格式错误时 + */ + private static byte[] hexStr2Bytes(String hex) { + // 使用BigInteger处理十六进制字符串,确保正确转换 + byte[] bArray = new BigInteger("10" + hex, 16).toByteArray(); + byte[] ret = new byte[bArray.length - 1]; + System.arraycopy(bArray, 1, ret, 0, ret.length); + return ret; + } + + /** + * 生成TOTP值 + * + * @param key 共享密钥,十六进制编码字符串 + * @param time 时间计数器值,十六进制编码字符串 + * @param returnDigits 返回的TOTP位数,必须在1到8之间 + * @param crypto 加密算法,如 "HmacSHA1" + * @return TOTP数值字符串,指定位数 + * @throws IllegalArgumentException 如果位数无效或参数错误 + * @throws RuntimeException 如果安全算法出错 + */ + public static String generateTOTP(String key, String time, int returnDigits, String crypto) { + // 参数校验 + if (returnDigits < 1 || returnDigits> 8) { + throw new IllegalArgumentException("TOTP位数必须在1到8之间"); + } + if (key == null || key.isEmpty() || time == null || time.isEmpty()) { + throw new IllegalArgumentException("密钥和时间参数不能为空"); + } + + // 时间字符串填充至16字符(64位十六进制表示) + String paddedTime = time; + while (paddedTime.length() < 16) { + paddedTime = "0" + paddedTime; + } + + try { + byte[] msg = hexStr2Bytes(paddedTime); + byte[] k = hexStr2Bytes(key); + byte[] hash = hmacSha(crypto, k, msg); + + // 动态截取:取最后一字节的低4位作为偏移量 + int offset = hash[hash.length - 1] & 0x0f; + + // 从偏移位置取4字节,按大端序组合为整数 + int binary = ((hash[offset] & 0x7f) << 24) + | ((hash[offset + 1] & 0xff) << 16) + | ((hash[offset + 2] & 0xff) << 8) + | (hash[offset + 3] & 0xff); + + // 取模得到指定位数的TOTP值 + int otp = binary % DIGITS_POWER[returnDigits]; + + // 格式化为指定位数字符串,不足位补零 + return String.format("%0" + returnDigits + "d", otp); + + } catch (GeneralSecurityException e) { + throw new RuntimeException("TOTP生成安全错误: " + e.getMessage(), e); + } catch (Exception e) { + throw new RuntimeException("TOTP生成失败: " + e.getMessage(), e); + } + } + + + /** + * 生成TOTP(默认6位数,HMAC-SHA1算法) + */ + public static String generateTOTP(String key, String time) { + return generateTOTP(key, time, DEFAULT_DIGITS, HMAC_SHA1); + } + + /** + * 生成TOTP(指定位数,HMAC-SHA1算法) + */ + public static String generateTOTP(String key, String time, int returnDigits) { + return generateTOTP(key, time, returnDigits, HMAC_SHA1); + } + + /** + * 生成TOTP(指定位数,HMAC-SHA256算法) + */ + public static String generateTOTP256(String key, String time, int returnDigits) { + return generateTOTP(key, time, returnDigits, HMAC_SHA256); + } + + /** + * 生成TOTP(指定位数,HMAC-SHA512算法) + */ + public static String generateTOTP512(String key, String time, int returnDigits) { + return generateTOTP(key, time, returnDigits, HMAC_SHA512); + } + + /** + * 基于当前时间生成TOTP + * @param key 共享密钥(十六进制字符串) + * @return TOTP值(6位数) + */ + public static String generateCurrentTOTP(String key) { + long currentTime = System.currentTimeMillis() / 1000; + long timeStep = (currentTime - DEFAULT_START_TIME) / DEFAULT_TIME_STEP; + return generateTOTP(key, Long.toHexString(timeStep).toUpperCase()); + } + + /** + * 验证TOTP代码,考虑时间偏移容错 + * + * @param key 共享密钥 + * @param code 要验证的代码 + * @param timeWindow 时间窗口大小(允许前后偏移的步数) + * @return 验证是否成功 + */ + public static boolean verifyTOTP(String key, String code, int timeWindow) { + if (key == null || key.isEmpty() || code == null || code.isEmpty()) { + return false; + } + + long currentTime = System.currentTimeMillis() / 1000; + long currentTimeStep = (currentTime - DEFAULT_START_TIME) / DEFAULT_TIME_STEP; + + // 检查当前时间步及其前后时间窗口内的步数 + for (long i = -timeWindow; i <= timeWindow; i++) { + long timeStep = currentTimeStep + i; + String steps = Long.toHexString(timeStep).toUpperCase(); + try { + String totp = generateTOTP(key, steps); + if (totp.equals(code)) { + return true; + } + } catch (Exception e) { + // 忽略单个时间步的错误,继续验证其他步数 + continue; + } + } + return false; + } + + /** + * 验证TOTP代码(使用默认时间窗口) + */ + public static boolean verifyTOTP(String key, String code) { + return verifyTOTP(key, code, DEFAULT_TIME_WINDOW); + } + + /** + * 主方法:测试TOTP算法实现 + * 使用RFC 6238中的测试向量验证算法正确性,并演示当前TOTP生成 + */ + public static void main(String[] args) { + System.out.println("TOTP算法测试程序"); + System.out.println("================\n"); + + // RFC 6238 测试向量 + String seed20 = "3132333435363738393031323334353637383930"; // 20字节密钥(SHA1) + String seed32 = "3132333435363738393031323334353637383930313233343536373839303132"; // 32字节密钥(SHA256) + String seed64 = "3132333435363738393031323334353637383930" + + "3132333435363738393031323334353637383930" + + "3132333435363738393031323334353637383930" + + "31323334"; // 64字节密钥(SHA512) + + // 测试时间点(Unix时间戳) + long[] testTime = {59L, 1111111109L, 1111111111L, 1234567890L, 2000000000L, 20000000000L}; + + DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss").withZone(ZoneId.of("UTC")); + + // 打印测试结果表格头 + System.out.println("RFC 6238 测试向量验证结果:"); + System.out.println("+---------------+-----------------------+------------------+----------+----------+"); + System.out.println("| 时间(秒) | UTC时间 | T值(十六进制) | TOTP值 | 算法 |"); + System.out.println("+---------------+-----------------------+------------------+----------+----------+"); + + // 测试每个时间点 + for (long timeValue : testTime) { + long T = (timeValue - DEFAULT_START_TIME) / DEFAULT_TIME_STEP; + String steps = Long.toHexString(T).toUpperCase(); + // 填充至16字符 + while (steps.length() < 16) { + steps = "0" + steps; + } + + String fmtTime = String.format("%1$-11s", timeValue); + String utcTime = formatter.format(Instant.ofEpochSecond(timeValue)); + + // 测试SHA1算法 + printResult(fmtTime, utcTime, steps, generateTOTP(seed20, steps, 8, HMAC_SHA1), "SHA1"); + + // 测试SHA256算法 + printResult(fmtTime, utcTime, steps, generateTOTP(seed32, steps, 8, HMAC_SHA256), "SHA256"); + + // 测试SHA512算法 + printResult(fmtTime, utcTime, steps, generateTOTP(seed64, steps, 8, HMAC_SHA512), "SHA512"); + + System.out.println("+---------------+-----------------------+------------------+----------+----------+"); + } + + // 演示当前时间TOTP生成 + System.out.println("\n当前时间TOTP演示:"); + System.out.println("----------------"); + + String currentTOTP = generateCurrentTOTP(seed20); + System.out.println("共享密钥: " + seed20); + System.out.println("当前TOTP: " + currentTOTP); + + // 验证演示 + boolean isValid = verifyTOTP(seed20, currentTOTP); + System.out.println("TOTP验证: " + (isValid ? "通过" : "失败")); + + // 错误代码验证测试 + boolean isInvalid = verifyTOTP(seed20, "000000"); + System.out.println("错误代码验证: " + (isInvalid ? "通过" : "错误")); + + System.out.println("\n测试完成"); + } + + /** + * 打印单行测试结果 + * + * @param time 时间戳字符串 + * @param utcTime UTC时间字符串 + * @param steps 时间步十六进制值 + * @param totp TOTP值 + * @param mode 算法模式 + */ + private static void printResult(String time, String utcTime, String steps, + String totp, String mode) { + System.out.printf("| %s | %s | %s | %s | %-8s |%n", + time, utcTime, steps, totp, mode); + } +} \ No newline at end of file diff --git a/springboot-2FA/src/main/resources/application.yml b/springboot-2FA/src/main/resources/application.yml new file mode 100644 index 0000000..7ed7e2f --- /dev/null +++ b/springboot-2FA/src/main/resources/application.yml @@ -0,0 +1,3 @@ +spring: + application: + name: springboot-2FA diff --git a/springboot-2FA/src/test/java/top/lrshuai/totp/Springboot2FaApplicationTests.java b/springboot-2FA/src/test/java/top/lrshuai/totp/Springboot2FaApplicationTests.java new file mode 100644 index 0000000..4af8abf --- /dev/null +++ b/springboot-2FA/src/test/java/top/lrshuai/totp/Springboot2FaApplicationTests.java @@ -0,0 +1,13 @@ +package top.lrshuai.totp; + +import org.junit.jupiter.api.Test; +import org.springframework.boot.test.context.SpringBootTest; + +@SpringBootTest +class Springboot2FaApplicationTests { + + @Test + void contextLoads() { + } + +} diff --git a/springboot-camunda/.gitattributes b/springboot-camunda/.gitattributes new file mode 100644 index 0000000..3b41682 --- /dev/null +++ b/springboot-camunda/.gitattributes @@ -0,0 +1,2 @@ +/mvnw text eol=lf +*.cmd text eol=crlf diff --git a/springboot-camunda/.gitignore b/springboot-camunda/.gitignore new file mode 100644 index 0000000..667aaef --- /dev/null +++ b/springboot-camunda/.gitignore @@ -0,0 +1,33 @@ +HELP.md +target/ +.mvn/wrapper/maven-wrapper.jar +!**/src/main/**/target/ +!**/src/test/**/target/ + +### STS ### +.apt_generated +.classpath +.factorypath +.project +.settings +.springBeans +.sts4-cache + +### IntelliJ IDEA ### +.idea +*.iws +*.iml +*.ipr + +### NetBeans ### +/nbproject/private/ +/nbbuild/ +/dist/ +/nbdist/ +/.nb-gradle/ +build/ +!**/src/main/**/build/ +!**/src/test/**/build/ + +### VS Code ### +.vscode/ diff --git a/springboot-camunda/README.md b/springboot-camunda/README.md new file mode 100644 index 0000000..31e76c6 --- /dev/null +++ b/springboot-camunda/README.md @@ -0,0 +1,680 @@ +## 引言 + +### 为什么需要工作流引擎? + +在当今快速变化的商业环境中,企业需要处理越来越复杂的业务流程。想象一下:一个员工请假申请需要经过部门经理审批、HR备案、财务记录等多个环节;一个电商订单需要经历库存检查、支付确认、发货通知、物流跟踪等步骤。这些业务流程如果硬编码在系统中,不仅难以维护,更无法快速适应业务变化。 + + + +## 一、什么是Camunda? + +- Camunda 是一个开源的工作流和业务流程管理平台,基于BPMN 2.0(业务流程模型与标记)标准构建。它提供了一个强大的流程引擎,允许开发人员将复杂的业务流程建模、执行、监控和优化。 + + + +### 1、传统开发 vs Camunda开发的对比: + + + +```java +// 传统硬编码方式 - 紧密耦合,难以维护 +public class LeaveApplicationService { + public void applyLeave(LeaveRequest request) { + // 1. 保存申请 + leaveRepository.save(request); + + // 2. 通知部门经理 + emailService.notifyManager(request); + + // 3. 如果经理批准,通知HR + // 4. 如果HR通过,更新考勤系统 + // ... 更多嵌套的条件判断 + } +} + +// 使用Camunda - 关注点分离,易于维护 +@Service +public class LeaveApplicationService { + + @Autowired + private RuntimeService runtimeService; + + public void applyLeave(LeaveRequest request) { + // 启动流程,具体步骤在BPMN图中定义 + runtimeService.startProcessInstanceByKey( + "LeaveProcess", + Variables.putValue("leaveRequest", request) + ); + } +} +``` + + + +### 2、核心组件 + +| 组件 | 功能描述 | 适用场景 | +| :------------------- | :----------------------------- | :--------------- | +| **Camunda Engine** | 核心流程引擎,负责执行BPMN流程 | 嵌入到Java应用中 | +| **Camunda Modeler** | 图形化流程设计工具 | 业务流程建模 | +| **Camunda Tasklist** | 用户任务管理界面 | 人工任务处理 | +| **Camunda Cockpit** | 流程监控和管理控制台 | 流程运维和监控 | +| **Camunda Optimize** | 流程分析和优化工具 | 性能分析和改进 | + + + + + +## 二、Springboot快速开始 + +### 1、引入依赖 + + +```text + + + org.camunda.bpm.springboot + camunda-bpm-spring-boot-starter + ${camunda.version} + + + + org.camunda.bpm.springboot + camunda-bpm-spring-boot-starter-webapp + ${camunda.version} + + + + + org.camunda.bpm.springboot + camunda-bpm-spring-boot-starter-rest + ${camunda.version} + + + + + com.mysql + mysql-connector-j + ${mysql.version} + + +``` + +### 2、配置yml + +```yml +server: + port: 8081 + +camunda.bpm: + database: + type: mysql + schema-update: true # 首次启动设置为true,自动创建表 + admin-user: + id: admin #用户名 + password: admin #密码 + firstName: rstyro- + filter: + create: All tasks + # 自动部署resources/下的BPMN文件 + auto-deployment-enabled: true + # 历史级别: none, activity, audit, full + history-level: full + generic-properties: + properties: + historyTimeToLive: P30D # 设置全局默认历史记录生存时间为30天 + enforceHistoryTimeToLive: false # 可选:禁用强制TTL检查 + # 作业执行配置 + job-execution: + enabled: true + core-pool-size: 3 + max-pool-size: 10 + +# mysql连接信息 +spring: + datasource: + driver-class-name: com.mysql.cj.jdbc.Driver + type: com.mysql.cj.jdbc.MysqlDataSource + url: jdbc:mysql://localhost:3306/camunda + username: root + password: root + jackson: + date-format: yyyy-MM-dd HH:mm:ss + time-zone: GMT+8 + +# 日志配置 +logging: + level: + org.camunda: INFO + org.springframework.web: INFO +``` + + +### 3、camunda的表解释 + + +| 表类别与前缀 | 核心职责 | 数据生命周期特点 | 代表性数据表 | +| :---------------------- | :--------------------------------------------- | :----------------------------------------------------- | :---------------------------------------------------- | +| **ACT_GE_*** (通用数据) | 存储引擎的二进制资源、属性配置和版本日志。 | 静态或长期存在,与流程定义同生命周期。 | `ACT_GE_BYTEARRAY`, `ACT_GE_PROPERTY` | +| **ACT_RE_*** (资源存储) | 存储流程定义、决策规则等"静态"部署资源。 | 静态数据,部署后一般不变化,是流程的蓝图。 | `ACT_RE_PROCDEF`, `ACT_RE_DEPLOYMENT` | +| **ACT_RU_*** (运行时) | 存储正在运行的流程实例、任务、变量等实时数据。 | **临时数据**,流程实例结束后立即被删除,保持表小而快。 | `ACT_RU_TASK`, `ACT_RU_EXECUTION`, `ACT_RU_VARIABLE` | +| **ACT_HI_*** (历史记录) | 记录所有流程实例的完整历史、活动和变量变更。 | **历史数据**,长期保存,用于查询、报告与审计。 | `ACT_HI_PROCINST`, `ACT_HI_ACTINST`, `ACT_HI_VARINST` | +| **ACT_ID_*** (身份认证) | 管理用户、用户组以及他们之间的关联关系。 | 基础主数据,独立于流程生命周期。 | `ACT_ID_USER`, `ACT_ID_GROUP`, `ACT_ID_MEMBERSHIP` | + +### 4、业务流程建模 + + +#### 安装Camunda Modeler + +我们一般会在`Camunda Modeler` 画出整个工作流的流程,然后导出 `.bpmn` 文件,然后在代码里面加载文件,进行编码的。 + +- Camunda Modeler下载地址:[https://camunda.com/download/modeler/](https://camunda.com/download/modeler/) +- 下载安装完成之后,我们可以新建一个请假流程。 + + + +![请假工作流流程](leave.png) + + + + +我们的BPMN文件内容放在`src/main/resources/process/leave.bpmn`中: + +```text + + + + + Flow_StartToApply + + + + + + 年假 + 病假 + 事假 + + + + + + + + + + + + + + + + Flow_StartToApply + Flow_ApplyToGateway + + + Flow_ApplyToGateway + Flow_GatewayToManager + Flow_GatewayToDirector + + + + + + 同意 + 拒绝 + + + + + Flow_GatewayToManager + Flow_ManagerToEnd + Flow_ManagerReject + + + + + + 同意 + 拒绝 + + + + + Flow_GatewayToDirector + Flow_DirectorToEnd + Flow_DirectorReject + + + Flow_ManagerToEnd + Flow_DirectorToEnd + Flow_HRToEnd + + + Flow_ManagerReject + Flow_DirectorReject + Flow_NotifyToEnd + + + Flow_HRToEnd + + + Flow_NotifyToEnd + + + + + ${leaveDays <= 3} + + + ${leaveDays > 3} + + + ${managerApproved == true} + + + ${managerApproved == false} + + + ${directorApproved == true} + + + ${directorApproved == false} + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +``` + + + +### 5、Java服务实现 + +创建相应的Java服务类来处理业务流程: + + + +```java +@RestController +@RequestMapping("/api/leave") +public class LeaveProcessController { + + @Resource + private RuntimeService runtimeService; + + @Resource + private TaskService taskService; + + @Resource + private HistoryService historyService; + + @Resource + private IdentityService identityService; + + /** + * 启动请假流程 + */ + @PostMapping("/start") + public ResponseEntity