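Recently I have been trying to extract parameter information from MyBatis SQL templates, and along the way I learned about some of MyBatis's internal structures. In this post I share what I learned about MyBatis, along with the concrete code implementation.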

1. How MyBatis loads mapper files

In the MyBatis "Getting Started" guide, the official documentation shows how to quickly initialize a MyBatis demo. The sample code is as follows:

String resource = "org/mybatis/example/mybatis-config.xml";
InputStream inputStream = Resources.getResourceAsStream(resource);
SqlSessionFactory sqlSessionFactory = new SqlSessionFactoryBuilder().build(inputStream);

In the code above, MyBatis uses SqlSessionFactoryBuilder to create the SqlSessionFactory. If we open the SqlSessionFactoryBuilder class, we find that the call ends up in the following overload:

public SqlSessionFactory build(InputStream inputStream, String environment, Properties properties) {
    try {
        XMLConfigBuilder parser = new XMLConfigBuilder(inputStream, environment, properties);
        return build(parser.parse());
    } catch (Exception e) {
        throw ExceptionFactory.wrapException("Error building SqlSession.", e);
    } finally {
        ErrorContext.instance().reset();
        try {
            inputStream.close();
        } catch (IOException e) {
            // Intentionally ignore. Prefer previous error.
        }
    }
}

In XMLConfigBuilder, if we follow the parse method, we find that it delegates to the internal parseConfiguration method to parse the configuration.

private void parseConfiguration(XNode root) {
    try {
        // issue #117 read properties first
        propertiesElement(root.evalNode("properties"));
        Properties settings = settingsAsProperties(root.evalNode("settings"));
        loadCustomVfs(settings);
        loadCustomLogImpl(settings);
        typeAliasesElement(root.evalNode("typeAliases"));
        pluginElement(root.evalNode("plugins"));
        objectFactoryElement(root.evalNode("objectFactory"));
        objectWrapperFactoryElement(root.evalNode("objectWrapperFactory"));
        reflectorFactoryElement(root.evalNode("reflectorFactory"));
        settingsElement(settings);
        // read it after objectFactory and objectWrapperFactory issue #631
        environmentsElement(root.evalNode("environments"));
        databaseIdProviderElement(root.evalNode("databaseIdProvider"));
        typeHandlerElement(root.evalNode("typeHandlers"));
        mapperElement(root.evalNode("mappers"));
    } catch (Exception e) {
        throw new BuilderException("Error parsing SQL Mapper Configuration. Cause: " + e, e);
    }
}

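Even without reading it closely, you can see that this logic hands each part of the XML configuration to a dedicated method, which parses and processes that part. The piece we care about most, mapper loading, happens in the mapperElement method.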
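The dispatch pattern in parseConfiguration can be reproduced in miniature with the JDK's DOM parser. The sketch below is not MyBatis code. The class name ConfigDispatchSketch and its handlers are hypothetical. It only shows how each top-level element is routed to its own handler in a fixed order, the way parseConfiguration calls propertiesElement, environmentsElement, mapperElement and so on.

```java
import java.io.ByteArrayInputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import javax.xml.parsers.DocumentBuilderFactory;
import org.w3c.dom.Element;
import org.w3c.dom.NodeList;

// Hypothetical sketch (not MyBatis code): route each top-level section of a config
// document to its own handler, in a fixed order, like parseConfiguration does.
class ConfigDispatchSketch {

    static List<String> dispatch(String xml) {
        List<String> log = new ArrayList<>();
        // Insertion order = processing order, independent of the order in the document.
        Map<String, Consumer<Element>> handlers = new LinkedHashMap<>();
        handlers.put("properties", section -> log.add("properties"));
        handlers.put("settings", section -> log.add("settings"));
        handlers.put("environments", section -> log.add("environments"));
        handlers.put("mappers", section -> {
            NodeList mappers = section.getElementsByTagName("mapper");
            for (int i = 0; i < mappers.getLength(); i++) {
                log.add("mapper:" + ((Element) mappers.item(i)).getAttribute("resource"));
            }
        });

        Element root;
        try {
            root = DocumentBuilderFactory.newInstance().newDocumentBuilder()
                .parse(new ByteArrayInputStream(xml.getBytes(StandardCharsets.UTF_8)))
                .getDocumentElement();
        } catch (Exception e) {
            throw new IllegalArgumentException("Invalid configuration XML", e);
        }
        for (Map.Entry<String, Consumer<Element>> entry : handlers.entrySet()) {
            NodeList found = root.getElementsByTagName(entry.getKey());
            // In this sketch a missing section is simply skipped.
            if (found.getLength() > 0) {
                entry.getValue().accept((Element) found.item(0));
            }
        }
        return log;
    }
}
```

Feeding it a document that lists `<mappers>` before `<settings>` still handles settings first. That fixed sequence is what the comments in parseConfiguration ("read properties first", "read it after objectFactory") depend on.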

2. Parsing mapper files in detail

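In section 1 we arrived at the mapperElement method. It contains quite a few branches, but if your mappers are declared with the resource attribute as shown in the official documentation, it effectively runs only the following code: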

ErrorContext.instance().resource(resource);
InputStream inputStream = Resources.getResourceAsStream(resource);
XMLMapperBuilder mapperParser = new XMLMapperBuilder(inputStream, configuration, resource, configuration.getSqlFragments());
mapperParser.parse();

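You can debug or read through mapperParser's parse method to see where the parsed SQL templates end up. After debugging and reading the code, I determined that the flow is as follows: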

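  1. mapperParser.parse calls configurationElement, which passes every select|insert|update|delete node to buildStatementFromContext.
  2. buildStatementFromContext creates an XMLStatementBuilder for each node and calls its parseStatementNode method. (parsePendingStatements, also called from parse, only retries statements that could not be completed on the first pass, for example because they reference an element from a mapper that has not been loaded yet.)
  3. XMLStatementBuilder.parseStatementNode is fairly long; it processes the statement's attributes and content.
  4. At the end, XMLStatementBuilder.parseStatementNode calls builderAssistant.addMappedStatement.
  5. builderAssistant.addMappedStatement builds a single MappedStatement and puts it into Configuration's mappedStatements property.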
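Step 5 stores each statement in Configuration.mappedStatements. As far as I can tell from the MyBatis source, that map is a StrictMap that registers every statement twice: once under the full namespace.id key and once under the short id. If two namespaces use the same short id, the short key is marked ambiguous. This is why ParamUtils later in this post can look up "verify" without the "Verify." prefix. The class below is a simplified, hypothetical re-creation of that idea for illustration, not MyBatis's own code.

```java
import java.util.HashMap;
import java.util.Map;

// Simplified, hypothetical re-creation of the idea behind Configuration.StrictMap:
// every "namespace.id" key is also reachable via its short id unless that id is ambiguous.
class StatementRegistrySketch {

    // Marker stored under a short id that more than one namespace uses.
    private static final Object AMBIGUOUS = new Object();

    private final Map<String, Object> statements = new HashMap<>();

    void add(String fullId, String statement) {
        if (statements.containsKey(fullId)) {
            throw new IllegalArgumentException("Duplicate statement id: " + fullId);
        }
        int dot = fullId.lastIndexOf('.');
        if (dot >= 0) {
            String shortId = fullId.substring(dot + 1);
            // The first registration owns the short id; a second one makes it ambiguous.
            statements.put(shortId, statements.containsKey(shortId) ? AMBIGUOUS : statement);
        }
        statements.put(fullId, statement);
    }

    String get(String id) {
        Object value = statements.get(id);
        if (value == null) {
            throw new IllegalArgumentException("No statement registered for " + id);
        }
        if (value == AMBIGUOUS) {
            throw new IllegalArgumentException(id + " is ambiguous; use the full namespace.id");
        }
        return (String) value;
    }
}
```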

3. SqlSource and SqlNode

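XMLStatementBuilder ultimately turns each individual statement into a MappedStatement. Within a MappedStatement, the SQL template is held in the sqlSource property (of type SqlSource), and SqlSource implementations are built from SqlNode objects that hold the parsed template.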

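MyBatis divides the content of a SQL template into the following node types:

  1. ChooseSqlNode corresponds to the <choose> tag
  2. ForEachSqlNode corresponds to the <foreach> tag
  3. IfSqlNode corresponds to the <if> tag (the <when> branches inside <choose> are also stored as IfSqlNode)
  4. MixedSqlNode wraps a collection of nodes
  5. StaticTextSqlNode holds content that contains no tags and no ${} placeholder
  6. TextSqlNode holds a contiguous run of text with no tags that contains a ${} placeholder; the ${} content is substituted when the SQL is rendered
  7. TrimSqlNode corresponds to the <trim> tag; its subclass WhereSqlNode corresponds to <where>, and its subclass SetSqlNode to <set>
  8. VarDeclSqlNode corresponds to the <bind> tag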

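I won't go into these eight node types in detail here; you can read their source code yourself. Their fields are the tag's attributes plus some related information.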
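The node types above form a composite: every node exposes a single apply method, and container nodes such as MixedSqlNode simply delegate to their children. The JDK-only sketch below mimics that shape with hypothetical, heavily simplified classes. A plain Map stands in for DynamicContext's bindings, and the if-test is a Java predicate instead of an OGNL expression, so the rendering order is easy to follow.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;

// Hypothetical, simplified mirror of MyBatis's SqlNode composite; not the real classes.
interface MiniSqlNode {
    // Appends this node's SQL to out; returns whether the node contributed anything.
    boolean apply(Map<String, Object> params, StringBuilder out);

    // Renders a whole tree, roughly what DynamicSqlSource does with its rootSqlNode.
    static String render(MiniSqlNode root, Map<String, Object> params) {
        StringBuilder out = new StringBuilder();
        root.apply(params, out);
        return out.toString();
    }
}

// Plain text, like StaticTextSqlNode.
class MiniStaticText implements MiniSqlNode {
    private final String text;

    MiniStaticText(String text) { this.text = text; }

    @Override
    public boolean apply(Map<String, Object> params, StringBuilder out) {
        out.append(text);
        return true;
    }
}

// Conditional node, like IfSqlNode; the test is a Java predicate instead of OGNL.
class MiniIf implements MiniSqlNode {
    private final Predicate<Map<String, Object>> test;
    private final MiniSqlNode contents;

    MiniIf(Predicate<Map<String, Object>> test, MiniSqlNode contents) {
        this.test = test;
        this.contents = contents;
    }

    @Override
    public boolean apply(Map<String, Object> params, StringBuilder out) {
        // Children are rendered only when the test passes.
        return test.test(params) && contents.apply(params, out);
    }
}

// Container node, like MixedSqlNode: applies every child in order.
class MiniMixed implements MiniSqlNode {
    private final List<MiniSqlNode> contents;

    MiniMixed(MiniSqlNode... contents) { this.contents = Arrays.asList(contents); }

    @Override
    public boolean apply(Map<String, Object> params, StringBuilder out) {
        contents.forEach(node -> node.apply(params, out));
        return true;
    }
}
```

A tree made of the text "SELECT * FROM BLOG WHERE 1=1" followed by an if-node on state renders differently depending on whether state is present. In MyBatis, the #{} placeholders left in the rendered text are then handled by SqlSourceBuilder, which is described in section 3.1.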

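Continuing from section 2: XMLStatementBuilder's parseStatementNode method calls the createSqlSource method of the language driver (XMLLanguageDriver by default) to produce a SqlSource, which is then stored as a property of the MappedStatement.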

XMLLanguageDriver's createSqlSource method, in turn, calls XMLScriptBuilder's parseScriptNode method to create the SqlSource:

public SqlSource parseScriptNode() {
    MixedSqlNode rootSqlNode = parseDynamicTags(context);
    SqlSource sqlSource;
    if (isDynamic) {
        sqlSource = new DynamicSqlSource(configuration, rootSqlNode);
    } else {
        sqlSource = new RawSqlSource(configuration, rootSqlNode, parameterType);
    }
    return sqlSource;
}

Digging into isDynamic, I found that it is set to true whenever the SQL template contains ${} content or any XML tag.
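The rule can be approximated in a few lines. A template is dynamic if any of its direct children is an element, or if any text contains a ${...} token. This sketch is a hypothetical approximation built on the JDK DOM parser and a plain string search. The class name and the wrapper element are my own choices. It skips details MyBatis handles, such as escaped tokens and the `<include>` expansion that happens before this check.

```java
import java.io.ByteArrayInputStream;
import java.nio.charset.StandardCharsets;
import javax.xml.parsers.DocumentBuilderFactory;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;

// Hypothetical approximation of how XMLScriptBuilder ends up with isDynamic = true.
class DynamicCheckSketch {

    static boolean isDynamic(String sqlTemplate) {
        Node root;
        try {
            // Wrap the template so it parses as one element, like the body of a <select>.
            String xml = "<script>" + sqlTemplate + "</script>";
            root = DocumentBuilderFactory.newInstance().newDocumentBuilder()
                .parse(new ByteArrayInputStream(xml.getBytes(StandardCharsets.UTF_8)))
                .getDocumentElement();
        } catch (Exception e) {
            throw new IllegalArgumentException("Template is not well-formed XML", e);
        }
        NodeList children = root.getChildNodes();
        for (int i = 0; i < children.getLength(); i++) {
            Node child = children.item(i);
            // Any tag (<if>, <where>, <foreach>, ...) makes the template dynamic.
            if (child.getNodeType() == Node.ELEMENT_NODE) {
                return true;
            }
            // Text (including CDATA) is dynamic only if it contains a ${...} token.
            short type = child.getNodeType();
            if (type == Node.TEXT_NODE || type == Node.CDATA_SECTION_NODE) {
                String text = child.getNodeValue();
                int start = text.indexOf("${");
                if (start >= 0 && text.indexOf('}', start + 2) >= 0) {
                    return true;
                }
            }
        }
        return false;
    }
}
```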

3.1 RawSqlSource

So what is the difference between DynamicSqlSource and RawSqlSource? Let's start with the relevant methods of RawSqlSource:

RawSqlSource

public RawSqlSource(Configuration configuration, SqlNode rootSqlNode, Class<?> parameterType) {
    this(configuration, getSql(configuration, rootSqlNode), parameterType);
}

public RawSqlSource(Configuration configuration, String sql, Class<?> parameterType) {
    SqlSourceBuilder sqlSourceParser = new SqlSourceBuilder(configuration);
    Class<?> clazz = parameterType == null ? Object.class : parameterType;
    sqlSource = sqlSourceParser.parse(sql, clazz, new HashMap<>());
}

private static String getSql(Configuration configuration, SqlNode rootSqlNode) {
    DynamicContext context = new DynamicContext(configuration, null);
    rootSqlNode.apply(context);
    return context.getSql();
}

SqlSourceBuilder

public SqlSource parse(String originalSql, Class<?> parameterType, Map<String, Object> additionalParameters) {
    SqlSourceBuilder.ParameterMappingTokenHandler handler = new SqlSourceBuilder.ParameterMappingTokenHandler(configuration, parameterType, additionalParameters);
    GenericTokenParser parser = new GenericTokenParser("#{", "}", handler);
    String sql;
    if (configuration.isShrinkWhitespacesInSql()) {
        sql = parser.parse(removeExtraWhitespaces(originalSql));
    } else {
        sql = parser.parse(originalSql);
    }
    return new StaticSqlSource(configuration, sql, handler.getParameterMappings());
}

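While debugging, I found that because this SQL template contains only #{} parameters, MyBatis replaces each #{} parameter with ? during initialization and adds a corresponding ParameterMapping to the parameterMappings list. This avoids extracting the #{} parameters and generating the parameterized SQL (paramedSql) on every query, which speeds up queries and reduces per-query memory allocation.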

Note that what RawSqlSource ultimately holds is a StaticSqlSource, and that StaticSqlSource contains the parameterMappings list with the parameters already extracted.
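What SqlSourceBuilder does at this point can be imitated with a small token scanner that replaces every #{...} with ? and records the property names in order. The sketch below is a hypothetical, simplified stand-in for GenericTokenParser plus ParameterMappingTokenHandler. It handles neither escaped tokens nor nested braces, and it keeps only the part before the first comma as the property name.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical, simplified imitation of SqlSourceBuilder.parse: #{...} -> ? plus an ordered
// list of property names standing in for the ParameterMapping list. No escape handling.
class PlaceholderSketch {

    final String sql;
    final List<String> properties;

    private PlaceholderSketch(String sql, List<String> properties) {
        this.sql = sql;
        this.properties = properties;
    }

    static PlaceholderSketch parse(String template) {
        StringBuilder sql = new StringBuilder();
        List<String> properties = new ArrayList<>();
        int from = 0;
        while (true) {
            int open = template.indexOf("#{", from);
            int close = open < 0 ? -1 : template.indexOf('}', open + 2);
            if (open < 0 || close < 0) {
                // No more complete tokens: copy the rest verbatim.
                sql.append(template, from, template.length());
                break;
            }
            sql.append(template, from, open).append('?');
            // "#{state, jdbcType=VARCHAR}" -> property "state"; options are ignored here.
            String content = template.substring(open + 2, close);
            properties.add(content.split(",")[0].trim());
            from = close + 1;
        }
        return new PlaceholderSketch(sql.toString(), properties);
    }
}
```

Because this work happens once at startup, each later query only has to bind values to the prepared list of ? placeholders.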

3.2 DynamicSqlSource

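DynamicSqlSource is the opposite case: because the final form of the SQL depends on the input parameters, MyBatis cannot preprocess such templates and has to render the paramedSql dynamically at runtime.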

Since this is not the focus of this article, I won't go into it further. All you need to know is that the rootSqlNode held by DynamicSqlSource is a MixedSqlNode.

4. Extracting parameters from SQL templates

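From the analysis in the previous three sections, we know that a SQL template ultimately lives in Configuration -> MappedStatement -> SqlSource. Next, we can simulate MyBatis initialization and read the parameter information from the SqlSource.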

Here I define an enum, ParamType, to distinguish the kinds of parameters.

package com.gavinzh.learn.mybatis;

public enum ParamType {
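    // #{} parameter: precompiled into a ? placeholder
    PRE_COMPILE,
    // ${} parameter: substituted as text
    REPLACE,
    // variables produced by foreach (item / index)
    INTERNAL,
    // variables produced by the bind tag
    BIND;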
}

Next, define a bean, InputParam, to hold each extracted parameter.

package com.gavinzh.learn.mybatis;

public class InputParam {
    private String name;
    private ParamType type;
    private String source;
    private boolean required;

    public InputParam(String name, ParamType type, String source, boolean required) {
        this.name = name;
        this.type = type;
        this.source = source;
        this.required = required;
    }

    public String getName() {
        return name;
    }

    public void setName(String name) {
        this.name = name;
    }

    public ParamType getType() {
        return type;
    }

    public void setType(ParamType type) {
        this.type = type;
    }

    public String getSource() {
        return source;
    }

    public void setSource(String source) {
        this.source = source;
    }

    public boolean isRequired() {
        return required;
    }

    public void setRequired(boolean required) {
        this.required = required;
    }
}

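The most important piece is the utility class ParamUtils. It borrows MyBatis's GenericTokenParser to find #{} and ${} parameters, and it reads the private fields of each SqlNode via reflection.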

package com.gavinzh.learn.mybatis;

import org.apache.ibatis.builder.StaticSqlSource;
import org.apache.ibatis.builder.xml.XMLMapperBuilder;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.mapping.SqlSource;
import org.apache.ibatis.parsing.GenericTokenParser;
import org.apache.ibatis.parsing.TokenHandler;
import org.apache.ibatis.parsing.XNode;
import org.apache.ibatis.scripting.defaults.RawSqlSource;
import org.apache.ibatis.scripting.xmltags.*;
import org.apache.ibatis.session.Configuration;

import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.lang.reflect.Field;
import java.util.*;

import static com.gavinzh.learn.mybatis.ParamType.*;
import static java.util.stream.Collectors.toList;

/**
 * MyBatis parameter processor
 *
 * @author zhangheng
 * @date 2021-01-09 17:15
 */
public class ParamUtils {

    private static final String verifyTemplate = "<?xml version=\"1.0\" encoding=\"UTF-8\" ?>\n"
        + "<!DOCTYPE mapper\n"
        + "        PUBLIC \"-//mybatis.org//DTD Mapper 3.0//EN\"\n"
        + "        \"http://mybatis.org/dtd/mybatis-3-mapper.dtd\">\n"
        + "<mapper namespace=\"Verify\">\n"
        + "    <select id=\"verify\">\n"
        + "    %s\n"
        + "    </select>\n"
        + "</mapper>";

    public static SqlSource xmlVerify(String xml) {
        xml = String.format(verifyTemplate, xml);
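        // Encode explicitly as UTF-8 to match the encoding declared in verifyTemplate
        InputStream inputStream = new ByteArrayInputStream(xml.getBytes(java.nio.charset.StandardCharsets.UTF_8));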
        Configuration configuration = new Configuration();
        Map<String, XNode> sqlFragments = configuration.getSqlFragments();
        XMLMapperBuilder xmlMapperBuilder =
            new XMLMapperBuilder(inputStream, configuration, "Verify", sqlFragments);

        xmlMapperBuilder.parse();
        MappedStatement mappedStatement = configuration.getMappedStatement("verify", false);
        return mappedStatement.getSqlSource();
    }

    public static List<InputParam> parseInputParam(String xml) {
        SqlSource sqlSource = xmlVerify(xml);
        return parseInputParam(sqlSource);
    }

    public static List<InputParam> parseInputParam(SqlSource sqlSource) {
        if (sqlSource instanceof RawSqlSource) {
            return parseRawParam((RawSqlSource)sqlSource);
        }
        if (sqlSource instanceof DynamicSqlSource) {
            return parseDynamicParam((DynamicSqlSource)sqlSource);
        }
        return Collections.emptyList();
    }

    private static List<InputParam> parseDynamicParam(DynamicSqlSource sqlSource) {
        SqlNode sqlNode = getFieldValue(sqlSource, "rootSqlNode");
        return parseSqlNode(sqlNode, true);
    }

    private static List<InputParam> parseSqlNode(SqlNode sqlNode, boolean required) {
        if (sqlNode instanceof MixedSqlNode) {
            return parseMixedSqlNodeParam((MixedSqlNode)sqlNode, required);
        }
        if (sqlNode instanceof TextSqlNode) {
            return parseTextSqlNodeParam((TextSqlNode)sqlNode, required);
        }
        if (sqlNode instanceof StaticTextSqlNode) {
            return parseStaticTextSqlNodeParam((StaticTextSqlNode)sqlNode, required);
        }
        if (sqlNode instanceof IfSqlNode) {
            return parseIfSqlNodeParam((IfSqlNode)sqlNode);
        }
        if (sqlNode instanceof ForEachSqlNode) {
            return parseForEachSqlNodeParam((ForEachSqlNode)sqlNode, required);
        }
        if (sqlNode instanceof ChooseSqlNode) {
            return parseChooseSqlNodeParam((ChooseSqlNode)sqlNode);
        }
        if (sqlNode instanceof TrimSqlNode) {
            return parseTrimSqlNodeParam((TrimSqlNode)sqlNode, required);
        }
        if (sqlNode instanceof VarDeclSqlNode) {
            return parseVarDeclSqlNodeParam((VarDeclSqlNode)sqlNode, required);
        }
        return Collections.emptyList();
    }

    private static List<InputParam> parseVarDeclSqlNodeParam(VarDeclSqlNode sqlNode, boolean required) {
        String name = getFieldValue(sqlNode, "name");
        String expression = getFieldValue(sqlNode, "expression");
        return Collections.singletonList(new InputParam(name, BIND, expression, required));
    }

    private static List<InputParam> parseTrimSqlNodeParam(TrimSqlNode sqlNode, boolean required) {
        SqlNode contents = getFieldValue(sqlNode, "contents");
        return parseSqlNode(contents, required);
    }

    private static List<InputParam> parseChooseSqlNodeParam(ChooseSqlNode sqlNode) {
        List<InputParam> chooseParamList = new ArrayList<InputParam>();

        List<SqlNode> ifSqlNodes = getFieldValue(sqlNode, "ifSqlNodes");
        ifSqlNodes.forEach(content -> chooseParamList.addAll(parseSqlNode(content, false)));

        SqlNode defaultSqlNode = getFieldValue(sqlNode, "defaultSqlNode");
        if (defaultSqlNode != null) {
            chooseParamList.addAll(parseSqlNode(defaultSqlNode, false));
        }
        return chooseParamList;
    }

    private static List<InputParam> parseForEachSqlNodeParam(ForEachSqlNode sqlNode, boolean required) {
        List<InputParam> forEachParamList = new ArrayList<InputParam>();

        // TODO collectionExpression may be an expression, but in most cases it is just a variable name
        String collectionExpression = getFieldValue(sqlNode, "collectionExpression");
        forEachParamList.add(new InputParam(collectionExpression, PRE_COMPILE, null, required));

        String item = getFieldValue(sqlNode, "item");
        if (item != null) {
            forEachParamList.add(new InputParam(item, INTERNAL, collectionExpression, false));
        }
        String index = getFieldValue(sqlNode, "index");
        if (index != null) {
            forEachParamList.add(new InputParam(index, INTERNAL, collectionExpression, false));
        }
        // TODO parameters inside the foreach body could also be extracted

        return forEachParamList;
    }

    private static List<InputParam> parseIfSqlNodeParam(IfSqlNode sqlNode) {
        // TODO variables could also be extracted from the test expression
        SqlNode contents = getFieldValue(sqlNode, "contents");
        return parseSqlNode(contents, false);
    }

    private static List<InputParam> parseStaticTextSqlNodeParam(StaticTextSqlNode sqlNode, boolean required) {
        TextSqlNodeTokenHandler handler = new TextSqlNodeTokenHandler();
        GenericTokenParser parser = new GenericTokenParser("#{", "}", handler);
        parser.parse(getFieldValue(sqlNode, "text"));
        // TODO MyBatis allows type hints inside the braces, so the type could be read from there too
        return handler.getParamSet().stream()
            .map(param -> new InputParam(param, PRE_COMPILE, null, required))
            .collect(toList());
    }

    private static List<InputParam> parseTextSqlNodeParam(TextSqlNode sqlNode, boolean required) {
        TextSqlNodeTokenHandler handler1 = new TextSqlNodeTokenHandler();
        GenericTokenParser parser1 = new GenericTokenParser("${", "}", handler1);
        parser1.parse(getFieldValue(sqlNode, "text"));

        TextSqlNodeTokenHandler handler2 = new TextSqlNodeTokenHandler();
        GenericTokenParser parser2 = new GenericTokenParser("#{", "}", handler2);
        parser2.parse(getFieldValue(sqlNode, "text"));
        // TODO MyBatis allows type hints inside the braces, so the type could be read from there too
        List<InputParam> all = new ArrayList<InputParam>();
        all.addAll(handler1.getParamSet().stream()
            .map(param -> new InputParam(param, REPLACE, null, required))
            .collect(toList()));
        all.addAll(handler2.getParamSet().stream()
            .map(param -> new InputParam(param, PRE_COMPILE, null, required))
            .collect(toList()));
        return all;
    }

    private static List<InputParam> parseMixedSqlNodeParam(MixedSqlNode sqlNode, boolean required) {
        List<SqlNode> contents = getFieldValue(sqlNode, "contents");
        return contents.stream()
            .map(node -> parseSqlNode(node, required))
            .flatMap(Collection::stream)
            .collect(toList());
    }

    private static List<InputParam> parseRawParam(RawSqlSource sqlSource) {
        StaticSqlSource staticSqlSource = getFieldValue(sqlSource, "sqlSource");
        List<ParameterMapping> parameterMappings = getFieldValue(staticSqlSource, "parameterMappings");
        return parameterMappings.stream()
            .map(parameterMapping ->
                new InputParam(parameterMapping.getProperty(), PRE_COMPILE, null, true))
            .collect(toList());
    }

    private static <T> T getFieldValue(Object o, String fieldName) {
        return getFieldValue(o.getClass(), o, fieldName);
    }

    private static <T> T getFieldValue(Class clazz, Object o, String fieldName) {
        try {
            Field field = clazz.getDeclaredField(fieldName);
            field.setAccessible(true);
            return (T)field.get(o);
        } catch (Exception e) {
            if (clazz.getSuperclass() != null) {
                return getFieldValue(clazz.getSuperclass(), o, fieldName);
            } else {
                return null;
            }
        }
    }

    static class TextSqlNodeTokenHandler implements TokenHandler {

        // LinkedHashSet keeps parameters in the order they appear in the template
        private Set<String> paramSet = new LinkedHashSet<String>();

        public Set<String> getParamSet() {
            return paramSet;
        }

        @Override
        public String handleToken(String content) {
            String param = content;
            if (param.contains(",")) {
                param = param.split(",")[0].trim();
            }
            paramSet.add(param);
        }
    }

}
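Without the reflection and the MyBatis classes, the core of ParamUtils is a recursive walk. The walk carries a required flag and clears it on entering a conditional node, as parseIfSqlNodeParam and parseChooseSqlNodeParam do. The JDK-only sketch below uses a hypothetical three-kind node model to show only that flag propagation. It illustrates the idea and does not replace ParamUtils.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

// Hypothetical, reflection-free sketch of the idea behind ParamUtils.parseSqlNode:
// #{} parameters found below a conditional node (<if>, <when>, <otherwise>) are optional.
class RequiredWalkSketch {

    enum Kind { TEXT, CONDITIONAL, MIXED }

    static final class Node {
        final Kind kind;
        final List<String> params;  // #{} names, only for TEXT nodes
        final List<Node> children;  // only for CONDITIONAL and MIXED nodes

        private Node(Kind kind, List<String> params, List<Node> children) {
            this.kind = kind;
            this.params = params;
            this.children = children;
        }

        static Node text(String... params) {
            return new Node(Kind.TEXT, Arrays.asList(params), Collections.emptyList());
        }

        static Node conditional(Node... children) {
            return new Node(Kind.CONDITIONAL, Collections.emptyList(), Arrays.asList(children));
        }

        static Node mixed(Node... children) {
            return new Node(Kind.MIXED, Collections.emptyList(), Arrays.asList(children));
        }
    }

    // Returns "name:required" / "name:optional" for every parameter, in document order.
    static List<String> walk(Node node, boolean required) {
        List<String> result = new ArrayList<>();
        if (node.kind == Kind.TEXT) {
            for (String param : node.params) {
                result.add(param + (required ? ":required" : ":optional"));
            }
            return result;
        }
        // A conditional makes everything below it optional; a container passes the flag through.
        boolean childRequired = required && node.kind != Kind.CONDITIONAL;
        for (Node child : node.children) {
            result.addAll(walk(child, childRequired));
        }
        return result;
    }
}
```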

The utility class contains several TODO items. They mark possible further improvements; if you are interested, feel free to build on them.
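One of those TODOs is reading type information from inside the braces. MyBatis accepts forms such as #{age,javaType=int,jdbcType=NUMERIC}, so the property name and its options could be separated before an InputParam is built. The helper below is a hypothetical, simplified splitter. It supports only comma-separated key=value pairs and ignores the other forms that MyBatis's own parameter-expression parser understands.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical helper: split the content of a #{...} token into a property name and options.
// Only "name, key=value, key=value" is supported; anything fancier is out of scope.
class ParamOptionsSketch {

    final String property;
    final Map<String, String> options;

    private ParamOptionsSketch(String property, Map<String, String> options) {
        this.property = property;
        this.options = options;
    }

    static ParamOptionsSketch parse(String tokenContent) {
        String[] parts = tokenContent.split(",");
        Map<String, String> options = new LinkedHashMap<>();
        for (int i = 1; i < parts.length; i++) {
            String[] kv = parts[i].split("=", 2);
            // Skip fragments that are not key=value pairs instead of failing.
            if (kv.length == 2) {
                options.put(kv[0].trim(), kv[1].trim());
            }
        }
        return new ParamOptionsSketch(parts[0].trim(), options);
    }
}
```

The jdbcType or javaType value could then be carried on InputParam, for example in a new field (hypothetical, not present in the class above).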

Finally, let's run it on an example based on the ones in the official documentation:

package com.gavinzh.learn.mybatis;

/**
 * @author zhangheng
 * @date 2021-01-09 17:32
 */
public class Main {
    private static String mybatisSql = "<bind name=\"likeStr\" value=\"'%' + like + '%'\" />"
        + "SELECT * FROM BLOG\n"
        + "<where>\n"
        + "    <choose>\n"
        + "        <when test=\"title != null\">\n"
        + "            AND title like #{title}\n"
        + "        </when>\n"
        + "        <when test=\"author != null and author_name != null\">\n"
        + "            AND author_name like #{author_name}\n"
        + "        </when>\n"
        + "        <otherwise>\n"
        + "            AND featured = 1\n"
        + "        </otherwise>\n"
        + "    </choose>\n"
        + "    <foreach item=\"item\" index=\"index\" collection=\"list\"\n"
        + "             open=\"(\" separator=\",\" close=\")\">\n"
        + "        #{item}\n"
        + "    </foreach>\n"
        + "    <if test=\"state != null\">\n"
        + "        state = #{state}\n"
        + "    </if>\n"
        + "</where>\n"
        + "limit ${pageSize} offset ${offset} #{test}";

    public static void main(String[] args) {
        ParamUtils.parseInputParam(mybatisSql)
            .forEach(inputParam -> {
                if (inputParam.getSource() != null) {
                    System.out.println(String.format("%s type:%s required:%s source:%s", inputParam.getName(),
                        inputParam.getType(), inputParam.isRequired(), inputParam.getSource()));
                } else {
                    System.out.println(String.format("%s type:%s required:%s", inputParam.getName(),
                        inputParam.getType(), inputParam.isRequired()));
                }

            });
    }
}

The output is as follows:

[Screenshot: console output of Main (2021-01-09T11:32:53.png)]

5. Summary

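I essentially followed the trail through the source code and implemented parameter extraction with MyBatis's own classes. Looking back, a regular expression could probably also pick out the parameters, but handling foreach tags would likely be error-prone, and there would be no way to tell whether a parameter must be passed.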
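For comparison, the regex approach takes only a few lines. This is a hypothetical sketch and not part of the tool above. It collects #{} and ${} names from the raw template text, but it treats the template as flat text. As a result, it cannot tell whether a name sits inside `<if>` or `<choose>`, it reports foreach loop variables such as item as if they were inputs, and it never sees the collection attribute at all.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

// Hypothetical regex-based extractor, shown only to illustrate its limits.
class RegexParamSketch {

    // Matches #{...} or ${...}; group 1 is the prefix, group 2 the content up to '}'.
    private static final Pattern TOKEN = Pattern.compile("([#$])\\{([^}]*)\\}");

    static List<String> extract(String template) {
        List<String> names = new ArrayList<>();
        Matcher m = TOKEN.matcher(template);
        while (m.find()) {
            // Keep the prefix so #{} and ${} stay distinguishable; drop options after a comma.
            names.add(m.group(1) + m.group(2).split(",")[0].trim());
        }
        return names;
    }
}
```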

To sum up: of the approaches I considered, reading parameter information through MyBatis's SqlNode structure is the best way to obtain it.

Tags: java, Mybatis, database
