背景:
我们需要获取前端传给接口 Controller 的 JSON 对象参数值,然后据此修改本次接口调用的查询 SQL 语句。如果后台接收的是表单数据,通过 request.getParameterMap() 就可以全部获取到;但如果是 JSON 对象数据,我们在过滤器或拦截器里通过 request.getInputStream() 读取了 request 的输入流之后,请求走到 Controller 层时就会报错。问题在于 request 的输入流只能读取一次,不能重复读取。
1.示例:定义Controller查询UserList
/**
 * Returns one page of users matching the query sent in the JSON request body.
 *
 * @param req the query object bound from the request body; it is also handed to
 *            {@code PageUtils.startPage} before the query runs
 * @return the query result wrapped by {@code PageUtils.buildPageDataInfo}
 */
@PostMapping("/user/list")
public PageDataInfo<UserInfo> getUserList(@RequestBody ChkReq req) {
    // NOTE(review): presumably startPage must be called immediately before the query so that
    // paging applies to it — confirm against PageUtils.
    PageUtils.startPage(req);
    return PageUtils.buildPageDataInfo(
            userInfoService.getUserList(req));
}
2.定义一个 HttpServletRequestWrapper 包装类作为容器,把请求体(输入流中的数据)缓存到这个容器里面,以便重复读取
@Slf4j
public class RequestWrapper extends HttpServletRequestWrapper {/*** 存储body数据的容器*/private final byte[] body;public RequestWrapper(HttpServletRequest request) {super(request);// 将body数据存储起来String bodyStr = getBodyString(request);body = bodyStr.getBytes(Charset.defaultCharset());}public String getBodyString(final ServletRequest request) {try {return cloneInputStreamString(request.getInputStream());} catch (IOException e) {log.error("", e);throw new RuntimeException(e);}}public String getBodyString() {final InputStream inputStream = new ByteArrayInputStream(body);return cloneInputStreamString(inputStream);}private String cloneInputStreamString(InputStream inputStream) {StringBuilder sb = new StringBuilder();BufferedReader reader = null;try {reader = new BufferedReader(new InputStreamReader(inputStream, Charset.defaultCharset()));String line;while ((line = reader.readLine()) != null) {sb.append(line);}} catch (IOException e) {log.error("", e);throw new RuntimeException(e);} finally {if (reader != null) {try {reader.close();} catch (IOException e) {log.error("", e);}}}return sb.toString();}@Overridepublic BufferedReader getReader() throws IOException {return new BufferedReader(new InputStreamReader(getInputStream()));}@Overridepublic ServletInputStream getInputStream() throws IOException {final ByteArrayInputStream inputStream = new ByteArrayInputStream(body);return new ServletInputStream() {@Overridepublic int read() throws IOException {return inputStream.read();}@Overridepublic boolean isFinished() {return false;}@Overridepublic boolean isReady() {return false;}@Overridepublic void setReadListener(ReadListener readListener) {}};}}
3.我们要在过滤器中将原生的HttpServletRequest换成RequestWrapper对象
public class ReplaceStreamFilter implements Filter {@Overridepublic void init(FilterConfig filterConfig) throws ServletException {Filter.super.init(filterConfig);}@Overridepublic void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {ServletRequest requestWrapper = new RequestWrapper((HttpServletRequest) servletRequest);filterChain.doFilter(requestWrapper, servletResponse);}@Overridepublic void destroy() {Filter.super.destroy();}
}
4.注册过滤器
/**
 * Registers {@link ReplaceStreamFilter} for all URL patterns.
 */
@Configuration
public class FilterConfig {

    /**
     * Registers the {@code replaceStreamFilter} bean under the name {@code streamFilter} for {@code /*}.
     *
     * @return the typed registration (was a raw {@code FilterRegistrationBean}, an unchecked use)
     */
    @Bean
    public FilterRegistrationBean<Filter> someFilterRegistration() {
        FilterRegistrationBean<Filter> registration = new FilterRegistrationBean<>();
        // Inside a @Configuration class this call returns the singleton bean defined below.
        registration.setFilter(replaceStreamFilter());
        registration.addUrlPatterns("/*");
        registration.setName("streamFilter");
        return registration;
    }

    /**
     * The filter instance that swaps in a re-readable request wrapper.
     *
     * @return a new {@link ReplaceStreamFilter}
     */
    @Bean(name = "replaceStreamFilter")
    public Filter replaceStreamFilter() {
        return new ReplaceStreamFilter();
    }
}
5.然后我们可以在拦截器中获取json数据
public class MyRequestInterceptor implements HandlerInterceptor {private ObjectMapper objectMapper = new ObjectMapper();@Overridepublic boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {if ("POST".equalsIgnoreCase(request.getMethod()) && request.getContentType() != null && request.getContentType().contains("application/json")) {/*try {byte[] requestBodyBytes = readRequestBody(request);String requestBody = new String(requestBodyBytes, StandardCharsets.UTF_8);*/try (BufferedReader reader = request.getReader()) {StringBuilder requestBody = new StringBuilder();String line;while ((line = reader.readLine()) != null) {requestBody.append(line);}// 将请求体转换为 ChkReq 对象ChkReq chkReq = objectMapper.readValue(requestBody.toString(), ChkReq.class);// 将 ChkReq 对象存储在 ServletRequestAttributes 中ServletRequestAttributes attributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();if (attributes != null) {attributes.getRequest().setAttribute("ChkReq", chkReq);}// 继续处理请求return true;} catch (IOException e) {// 处理异常,例如返回错误响应response.setStatus(HttpServletResponse.SC_BAD_REQUEST);response.getWriter().write("Invalid JSON data");return false;}}// 如果不是 JSON 请求或者不是 POST 方法,则继续处理请求return true;}
}
6.注册拦截器
@Configuration
public class WebConfig implements WebMvcConfigurer {@Overridepublic void addInterceptors(InterceptorRegistry registry) {registry.addInterceptor(new MyRequestInterceptor()).addPathPatterns("/**"); // 指定需要拦截的路径}
}
7.在 MyBatis-Plus 的 SQL 拦截器中读取 request 中保存的参数值,并据此修改 SQL
@Component
public class MyInterceptor implements InnerInterceptor {@SneakyThrows@Overridepublic void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) throws SQLException {//InnerInterceptor.super.beforeQuery(executor, ms, parameter, rowBounds, resultHandler, boundSql);String sql = boundSql.getSql();System.out.println("sql更新之前:" + sql);//String condition = " name = '李四' " ;String condition = " 1 = 1 ";String name = null;ServletRequestAttributes attributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();if (attributes != null) {ChkReq chkReq = (ChkReq) attributes.getRequest().getAttribute("ChkReq");if (chkReq != null) {name = "name = '" + chkReq.getName() + "'";}}PluginUtils.MPBoundSql mpBs = PluginUtils.mpBoundSql(boundSql);Select select = (Select) CCJSqlParserUtil.parse(sql);PlainSelect plainSelect = (PlainSelect) select.getSelectBody();final Expression expression = plainSelect.getWhere();final Expression envCondition = CCJSqlParserUtil.parseCondExpression(condition);final Expression envCondition2 = CCJSqlParserUtil.parseCondExpression(name);if (expression == null) {plainSelect.setWhere(envCondition);plainSelect.setWhere(envCondition2);} else {AndExpression andExpression = new AndExpression(expression, envCondition);AndExpression andExpression2 = new AndExpression(andExpression, envCondition2);plainSelect.setWhere(andExpression2);}mpBs.sql(plainSelect.toString());System.out.println("sql更新之后:" + plainSelect.toString());}}