Spring Boot 通过 AOP 与注解记录接口入参(GET / POST)

问题

request.getParameterMap() 只能获取 URL 查询参数和表单(application/x-www-form-urlencoded)参数,无法获取 POST 请求体中以 JSON 形式提交的参数。

解决方法

当 POST 请求以 JSON 形式提交时,参数位于请求体中,无法通过 getParameter 获取,只能通过 getInputStream 或 getReader 以流的方式读取。读取到请求体后,再通过 FastJson 将其转换为 Map 返回。

使用 @RequestBody 时,Spring 同样会通过流(getReader 或 getInputStream)读取请求体,而请求流只能读取一次,不能读取第二次。由于 Spring 在调用控制器方法(也就是进入切面)之前就已经为 @RequestBody 读取过请求体,切面中再调用 getReader 将读不到任何数据。

这里的解决方法是:先将请求体(RequestBody)保存为一个 byte 数组,然后继承 HttpServletRequestWrapper 类,重写 getReader() 和 getInputStream() 方法,使每次读取都从保存的 byte 数组中读出。最后通过一个 Filter 将原始请求替换为这个包装类,这样 @RequestBody 和切面都能读取到完整的请求体。

继承HttpServletRequestWrapper类重写getInputStream和getReader方法,每次读的时候读取保存在requestBody中的数据。

完整代码

SysLog

/**
 * Marks a method whose invocations are recorded in the system log by the logging aspect.
 *
 * <p>Only methods can carry it, it is visible via reflection at run time (so the aspect can
 * read it), and it appears in generated Javadoc.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.METHOD})
public @interface SysLog {

    /**
     * Human-readable name of the endpoint; empty when not given.
     */
    String value() default "";
}

SysLogAspect


/**
 * System-log aspect: records every call to a method annotated with {@link SysLog}.
 *
 * <p>For each call it captures the annotation's description, the target class and method name,
 * the client IP, the active Spring profile, the request parameters, the serialized result and
 * the elapsed time, and inserts them through {@link SysLogMapper}. Logging is best-effort: a
 * failure while collecting parameters, serializing the result or saving the entry is logged and
 * does not change the outcome of the intercepted call.
 */
@Aspect
@Component
public class SysLogAspect {
    private static final Logger logger = LoggerFactory.getLogger(SysLogAspect.class);

    /** Any run of whitespace; stripped from the serialized result before it is stored. */
    private static final java.util.regex.Pattern WHITESPACE = java.util.regex.Pattern.compile("\\s+");

    @Resource
    private SysLogMapper sysLogMapper;

    /**
     * Value of {@code spring.profiles.active}, stored with each log entry so that entries from
     * different environments can be told apart.
     */
    @Value("${spring.profiles.active}")
    private String profiles;

    /** Matches every method annotated with {@link SysLog}. */
    @Pointcut("@annotation(com.sxmingyun.liandong.common.annotation.SysLog)")
    public void logPointCut() {
    }

    /**
     * Runs the intercepted method and records a system-log entry for the call.
     *
     * @param joinPoint the intercepted invocation
     * @return whatever the intercepted method returned
     * @throws Throwable whatever the intercepted method threw; no log entry is saved in that case
     */
    @Around("logPointCut()")
    public Object around(ProceedingJoinPoint joinPoint) throws Throwable {
        Object attributes = RequestContextHolder.getRequestAttributes();
        if (!(attributes instanceof ServletRequestAttributes)) {
            // Not invoked while handling a servlet request (e.g. from a scheduled task):
            // there is no request to describe, so just run the method instead of failing.
            return joinPoint.proceed();
        }
        long beginTime = System.currentTimeMillis();
        HttpServletRequest request = ((ServletRequestAttributes) attributes).getRequest();
        MethodSignature signature = (MethodSignature) joinPoint.getSignature();

        SysLogEntity sysLog = new SysLogEntity();
        SysLog syslog = signature.getMethod().getAnnotation(SysLog.class);
        if (syslog != null) {
            // Description taken from the annotation
            String value = syslog.value();
            logger.info("=====> {} {}", value, request.getRequestURI());
            sysLog.setOperation(value);
        }

        // Invoked method as "fully.qualified.Class.method()"
        String className = joinPoint.getTarget().getClass().getName();
        sysLog.setMethod(className + "." + signature.getName() + "()");
        sysLog.setIp(IPUtils.getIp(request));
        sysLog.setCreateDate(new Date());
        sysLog.setProfiles(profiles);

        recordParams(request, sysLog);

        Object result = joinPoint.proceed(joinPoint.getArgs());
        String code = recordResult(result, sysLog);

        // Elapsed time in milliseconds, including result serialization
        sysLog.setTime(System.currentTimeMillis() - beginTime);

        saveLog(sysLog, code);
        return result;
    }

    /**
     * Stores the request parameters (as a JSON string) and the {@code rpid} parameter on the
     * log entry. A failure to read or parse the parameters is logged and treated as
     * "no parameters" so it cannot fail the business call.
     */
    private void recordParams(HttpServletRequest request, SysLogEntity sysLog) {
        JSONObject paramObj;
        try {
            paramObj = JSONObject.parseObject(JSON.toJSONString(getRequestParams(request)));
        } catch (IOException | RuntimeException e) {
            logger.warn("Failed to read request parameters: {}", e.getMessage(), e);
            paramObj = new JSONObject();
        }
        String params = paramObj.isEmpty() ? "" : JSONObject.toJSONString(paramObj);
        sysLog.setParams(params);
        String rpid = paramObj.getString("rpid");
        sysLog.setRpid(StringUtils.isNotBlank(rpid) ? rpid : "");
        logger.info("==> 参数:{}", params);
    }

    /**
     * Stores the JSON-serialized result (with all whitespace removed) on the log entry and
     * extracts its {@code retcode} field.
     *
     * @return the result's {@code retcode}, or {@code null} when the result is not a JSON object
     *         (e.g. {@code null}, a string or a list) or could not be serialized
     */
    private String recordResult(Object result, SysLogEntity sysLog) {
        try {
            String results = JSON.toJSONString(result);
            sysLog.setResult(WHITESPACE.matcher(results).replaceAll(""));
            // JSON.parse instead of parseObject: the latter returns null for "null" and throws
            // for strings/arrays, which used to break the request after the method had run.
            Object parsed = JSON.parse(results);
            String code = parsed instanceof JSONObject ? ((JSONObject) parsed).getString("retcode") : null;
            sysLog.setCode(code);
            return code;
        } catch (RuntimeException e) {
            logger.warn("Failed to serialize result: {}", e.getMessage(), e);
            return null;
        }
    }

    /**
     * Inserts the log entry and logs how long the insert took. A database failure is logged
     * with its stack trace but not rethrown.
     */
    private void saveLog(SysLogEntity sysLog, String code) {
        StopWatch watch = StopWatch.create("log");
        watch.start();
        try {
            sysLogMapper.insertSysLog(sysLog);
        } catch (Exception e) {
            logger.error("日志记录失败:{}", e.getMessage(), e);
        }
        watch.stop();
        logger.info("insert time: {}", watch.getTotalTimeMillis());
        logger.info("==> 返回:code={}, logId={}", code, sysLog.getId());
    }

    /**
     * Collects the parameters of the current request as a flat map.
     *
     * <ul>
     *   <li>GET, and POST with a form or multipart body: the servlet parameter map, with
     *       repeated values joined by commas.</li>
     *   <li>Any other POST: the body is read through {@link HttpServletRequest#getReader()} and
     *       parsed as a JSON object. A blank body, or one that is not a JSON object, yields an
     *       empty map.</li>
     *   <li>Other HTTP methods: an empty map.</li>
     * </ul>
     *
     * <p>Reading the body consumes the request stream. This relies on the request having been
     * wrapped (e.g. by {@code CustomRequestWrapper}) so that the body can be read again.
     *
     * @param request the current request
     * @return the parameters; never {@code null}
     * @throws IOException if the body cannot be read
     */
    public Map<String, String> getRequestParams(HttpServletRequest request) throws IOException {
        String httpMethod = request.getMethod();
        boolean post = "POST".equalsIgnoreCase(httpMethod);
        if (post && !isFormContent(request.getContentType())) {
            return readJsonBody(request);
        }
        if (post || "GET".equalsIgnoreCase(httpMethod)) {
            // Form posts must use the parameter map: reading the body ourselves would consume
            // the stream the container parses form fields from.
            return joinParameterMap(request);
        }
        return new HashMap<>();
    }

    /** Returns true for content types whose parameters the servlet container parses itself. */
    private static boolean isFormContent(String contentType) {
        if (contentType == null) {
            return false;
        }
        String type = contentType.toLowerCase(java.util.Locale.ROOT);
        return type.startsWith("application/x-www-form-urlencoded") || type.startsWith("multipart/");
    }

    /** Reads the whole request body and parses it as a JSON object of string values. */
    private static Map<String, String> readJsonBody(HttpServletRequest request) throws IOException {
        StringBuilder body = new StringBuilder();
        BufferedReader reader = request.getReader();
        char[] buffer = new char[1024];
        int read;
        while ((read = reader.read(buffer)) != -1) {
            body.append(buffer, 0, read);
        }
        String text = body.toString();
        if (text.trim().isEmpty()) {
            return new HashMap<>();
        }
        try {
            Map<String, String> parsed = JSONObject.parseObject(text, new TypeReference<Map<String, String>>() {
            });
            return parsed != null ? parsed : new HashMap<>();
        } catch (RuntimeException e) {
            // Not a JSON object (e.g. a JSON array or plain text): record no parameters
            // rather than failing the request.
            logger.warn("Request body is not a JSON object, parameters not recorded: {}", e.getMessage());
            return new HashMap<>();
        }
    }

    /** Flattens the servlet parameter map, joining repeated values with commas. */
    private static Map<String, String> joinParameterMap(HttpServletRequest request) {
        return request.getParameterMap().entrySet().stream()
                .collect(Collectors.toMap(Map.Entry::getKey, e -> String.join(",", e.getValue())));
    }

}

CustomRequestWrapper


/**
 * Request wrapper that caches the request body so it can be read more than once.
 *
 * <p>The first call to {@link #getInputStream()} (directly or through {@link #getReader()})
 * drains the wrapped request's input stream into an in-memory byte array. Every call, including
 * the first, then returns a fresh stream over that cached copy, so {@code @RequestBody} binding
 * and a logging aspect can both read the same body.
 *
 * <p>NOTE(review): the whole body is held in memory once it is read through this wrapper, and
 * the lazy caching is not synchronized — assumes one thread handles the request; confirm.
 */
public class CustomRequestWrapper extends HttpServletRequestWrapper {

    /** Cached request body; {@code null} until the body is first read. */
    private byte[] requestBody;

    public CustomRequestWrapper(HttpServletRequest request) {
        super(request);
    }

    /**
     * Returns a new stream positioned at the start of the cached body, reading and caching the
     * body from the wrapped request on first use.
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        if (this.requestBody == null) {
            ByteArrayOutputStream bos = new ByteArrayOutputStream();
            IOUtils.copy(super.getInputStream(), bos);
            this.requestBody = bos.toByteArray();
        }
        final ByteArrayInputStream bis = new ByteArrayInputStream(requestBody);
        return new ServletInputStream() {
            @Override
            public boolean isFinished() {
                // Finished once every cached byte has been consumed.
                return bis.available() == 0;
            }

            @Override
            public boolean isReady() {
                // The data is already in memory, so a read never blocks.
                return true;
            }

            @Override
            public void setReadListener(ReadListener readListener) {
                // Intentionally a no-op: this in-memory stream does not support the
                // non-blocking (async) read API.
            }

            @Override
            public int read() throws IOException {
                return bis.read();
            }

            @Override
            public int read(byte[] b, int off, int len) throws IOException {
                // Bulk read; the inherited default would copy one byte at a time.
                return bis.read(b, off, len);
            }

            @Override
            public int available() throws IOException {
                return bis.available();
            }
        };
    }

    /**
     * Returns a reader over the cached body, decoded with the request's character encoding
     * (UTF-8 when the request declares none, instead of the platform default charset).
     */
    @Override
    public BufferedReader getReader() throws IOException {
        String encoding = getCharacterEncoding();
        return new BufferedReader(new InputStreamReader(this.getInputStream(),
                encoding != null ? encoding : "UTF-8"));
    }
}

CustomFilter

/**
 * Servlet filter that passes every HTTP request down the chain wrapped in a
 * {@link CustomRequestWrapper}, so that its body can be read more than once. Non-HTTP requests
 * are forwarded unchanged.
 */
@Component
@WebFilter(filterName = "channelFilter", urlPatterns = {"/*"})
public class CustomFilter implements Filter {
    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        // No configuration needed.
    }

    @Override
    public void doFilter(ServletRequest request, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
        // Pick what goes down the chain: the caching wrapper for HTTP requests, else the original.
        ServletRequest forwarded = (request instanceof HttpServletRequest)
                ? new CustomRequestWrapper((HttpServletRequest) request)
                : request;
        filterChain.doFilter(forwarded, servletResponse);
    }
}

参考文章:

https://www.jb51.net/article/248246.htm

本文链接: https://jianz.xyz/index.php/archives/406/

1 + 9 =
快来做第一个评论的人吧~