diff --git a/ruoyi-common/src/main/java/com/ruoyi/common/annotation/RateLimiter.java b/ruoyi-common/src/main/java/com/ruoyi/common/annotation/RateLimiter.java index 2e3ff7a82..e2693b520 100644 --- a/ruoyi-common/src/main/java/com/ruoyi/common/annotation/RateLimiter.java +++ b/ruoyi-common/src/main/java/com/ruoyi/common/annotation/RateLimiter.java @@ -15,7 +15,8 @@ import java.lang.annotation.*; @Documented public @interface RateLimiter { /** - * 限流key + * 限流key,支持使用Spring el表达式来动态获取方法上的参数值 + * 格式类似于 #code.id #{#code} */ String key() default CacheConstants.RATE_LIMIT_KEY; @@ -33,4 +34,9 @@ public @interface RateLimiter { * 限流类型 */ LimitType limitType() default LimitType.DEFAULT; + + /** + * 提示消息 支持国际化 格式为 {code} + */ + String message() default "{rate.limiter.message}"; } diff --git a/ruoyi-framework/src/main/java/com/ruoyi/framework/aspectj/RateLimiterAspect.java b/ruoyi-framework/src/main/java/com/ruoyi/framework/aspectj/RateLimiterAspect.java index 542b16bfd..cc953b1fb 100644 --- a/ruoyi-framework/src/main/java/com/ruoyi/framework/aspectj/RateLimiterAspect.java +++ b/ruoyi-framework/src/main/java/com/ruoyi/framework/aspectj/RateLimiterAspect.java @@ -5,6 +5,7 @@ import com.ruoyi.common.enums.LimitType; import com.ruoyi.common.exception.ServiceException; import com.ruoyi.common.utils.MessageUtils; import com.ruoyi.common.utils.ServletUtils; +import com.ruoyi.common.utils.StringUtils; import com.ruoyi.common.utils.redis.RedisUtils; import lombok.extern.slf4j.Slf4j; import org.aspectj.lang.JoinPoint; @@ -12,6 +13,14 @@ import org.aspectj.lang.annotation.Aspect; import org.aspectj.lang.annotation.Before; import org.aspectj.lang.reflect.MethodSignature; import org.redisson.api.RateType; +import org.springframework.core.DefaultParameterNameDiscoverer; +import org.springframework.core.ParameterNameDiscoverer; +import org.springframework.expression.EvaluationContext; +import org.springframework.expression.ExpressionParser; +import org.springframework.expression.ParserContext; +import org.springframework.expression.common.TemplateParserContext; +import org.springframework.expression.spel.standard.SpelExpressionParser; +import org.springframework.expression.spel.support.StandardEvaluationContext; import org.springframework.stereotype.Component; import java.lang.reflect.Method; @@ -26,6 +35,16 @@ import java.lang.reflect.Method; @Component public class RateLimiterAspect { + //定义spel表达式解析器 + private final ExpressionParser parser = new SpelExpressionParser(); + //定义spel解析模版 + private final ParserContext parserContext = new TemplateParserContext(); + //定义spel上下文对象进行解析 + private final EvaluationContext context = new StandardEvaluationContext(); + //方法参数解析器 + private final ParameterNameDiscoverer pnd = new DefaultParameterNameDiscoverer(); + + @Before("@annotation(rateLimiter)") public void doBefore(JoinPoint point, RateLimiter rateLimiter) throws Throwable { int time = rateLimiter.time(); @@ -38,18 +57,45 @@ public class RateLimiterAspect { } long number = RedisUtils.rateLimiter(combineKey, rateType, count, time); if (number == -1) { - throw new ServiceException(MessageUtils.message("rate.limiter.message")); + String message = rateLimiter.message(); + if (StringUtils.startsWith(message, "{") && StringUtils.endsWith(message, "}")) { + message = MessageUtils.message(StringUtils.substring(message, 1, message.length() - 1)); + } + throw new ServiceException(message); } log.info("限制令牌 => {}, 剩余令牌 => {}, 缓存key => '{}'", count, number, combineKey); - } catch (ServiceException e) { - throw e; } catch (Exception e) { - throw new RuntimeException("服务器限流异常,请稍候再试"); + if (e instanceof ServiceException) { + throw e; + } else { + throw new RuntimeException("服务器限流异常,请稍候再试"); + } } } public String getCombineKey(RateLimiter rateLimiter, JoinPoint point) { - StringBuilder stringBuffer = new StringBuilder(rateLimiter.key()); + String key = rateLimiter.key(); + //获取方法(通过方法签名来获取) + MethodSignature signature = (MethodSignature) point.getSignature(); + Method method = signature.getMethod(); + Class targetClass = method.getDeclaringClass(); + //判断是否是spel格式 + if (StringUtils.containsAny(key, "#")) { + //获取参数值 + Object[] args = point.getArgs(); + //获取方法上参数的名称 + String[] parameterNames = pnd.getParameterNames(method); + for (int i = 0; i < parameterNames.length; i++) { + context.setVariable(parameterNames[i], args[i]); + } + //解析返回给key + try { + key = parser.parseExpression(key, parserContext).getValue(context, String.class) + ":"; + } catch (Exception e) { + throw new ServiceException("限流key解析异常!请联系管理员!"); + } + } + StringBuilder stringBuffer = new StringBuilder(key); if (rateLimiter.limitType() == LimitType.IP) { // 获取请求ip stringBuffer.append(ServletUtils.getClientIP()).append("-"); @@ -57,9 +103,6 @@ public class RateLimiterAspect { // 获取客户端实例id stringBuffer.append(RedisUtils.getClient().getId()).append("-"); } - MethodSignature signature = (MethodSignature) point.getSignature(); - Method method = signature.getMethod(); - Class targetClass = method.getDeclaringClass(); stringBuffer.append(targetClass.getName()).append("-").append(method.getName()); return stringBuffer.toString(); }