多租户数据隔离方案实践

背景:随着业务的发展,我们同一套业务系统需支持提供给多个客户(不同的企业用户)使用,所以需确保在多用户环境下,各用户间数据的隔离 。但目前系统在早期设计的时候没有考虑到多租户的情况,业务数据没有做到充分隔离(有些表做了字段区分,有些没有) 。
目前数据访问层用的是MyBatis框架,sql语句散布在xml里,dao注解里,量非常大 。另外,租户字段(companyId)定义也不是所有的业务实体类都有 。
基于现状,一个个修改sql,这样工作量太大了,所以排除掉一个个修改sql的方案 。只能考虑怎样可以统一修改sql 。而租户字段(companyId)的传递也需要有统一处理的地方 。
一、业务表添加数据隔字段我们先给没有租户字段(companyId)的表加上字段 。然后考虑怎样给字段统一添加值的改造 。因为业务系统目前是使用Mybatis做持久化,Mybatis有拦截器的功能,是否可以通过自定义Mybatis拦截器拦截下所有的 sql 语句,然后对其进行动态修改,自动添加company_id 字段及其字段值,实现数据隔离呢?答案是肯定的 。
二、添加Mybatis拦截器先看下Mybatis的核心对象:
Mybatis核心对象
解释
SqlSession
作为MyBatis工作的主要顶层API,表示和数据库交互的会话,完成必要数据库增删改查功能 。
Executor
MyBatis执行器,是MyBatis 调度的核心,负责SQL语句的生成和查询缓存的维护 。
StatementHandler
封装了JDBC Statement操作,负责对JDBC statement 的操作,如设置参数、将Statement结果集转换成List集合 。
ParameterHandler
负责对用户传递的参数转换成JDBC Statement 所需要的参数 。
ResultSetHandler
负责将JDBC返回的ResultSet结果集对象转换成List类型的集合 。
【多租户数据隔离方案实践】TypeHandler
负责JAVA数据类型和jdbc数据类型之间的映射和转换 。
MAppedStatement
MappedStatement维护了一条mapper.xml文件里面 select 、update、delete、insert节点的封装 。
SqlSource
负责根据用户传递的parameterObject,动态地生成SQL语句,将信息封装到BoundSql对象中 。
BoundSql
表示动态生成的SQL语句以及相应的参数信息 。
Configuration
MyBatis所有的配置信息都维持在Configuration对象 。
Mybatis拦截器可以拦截Executor、ParameterHandler、StatementHandler、ResultSetHandler四个对象里面的方法 。Executor是Mybatis的核心接口 。Mybatis中所有的Mapper语句的执行都是通过Executor进行的 。其中增删改语句是通过Executor接口的update方法,查询语句是通过query方法 。所以我们可以拦截Executor,拦载所有的select 、insert、update、delete语句进行改造,添加company_id字段及字段值 。
创建一个自定义的拦截器:
/** * Mybatis - 通用拦截器 。用于拦截sql并自动补充公共字段 。包括query、insert、update、delete语句 */@Slf4j@Intercepts({@Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),@Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class}),@Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class})})public class AutoFillParamInterceptor implements Interceptor {private static final String LAST_INSERT_ID_SQL = "LAST_INSERT_ID()";private static final String COMPANY_ID = "company_id";/*** 拦截主要的逻辑* @param invocation* @return* @throws Throwable*/@Overridepublic Object intercept(Invocation invocation) throws Throwable {final Object[] args = invocation.getArgs();final MappedStatement ms = (MappedStatement) args[0];final Object paramObj = args[1];//1.通过注解判断是否需要处理此SQLString namespace = ms.getId();String className = namespace.substring(0, namespace.lastIndexOf("."));//selectByExampleString methodName = StringUtils.substringAfterLast(namespace, ".");Class<?> classType = Class.forName(className);if (classType.isAnnotationPresent(IgnoreAutoFill.class)) {//注解在类上String userType = classType.getAnnotation(IgnoreAutoFill.class).userType();if (StringUtils.isNotBlank(userType)) {//ignore特定的用户类型,其他均拦截if (userType.equals(getCurrentUserType())) {return invocation.proceed();}} else {return invocation.proceed();}} else {//注解在方法上for (Method method : classType.getMethods()) {if (!methodName.equals(method.getName())) {continue;} else {if (method.isAnnotationPresent(IgnoreAutoFill.class)) {String userType = method.getAnnotation(IgnoreAutoFill.class).userType();if (StringUtils.isNotBlank(userType)) {//ignore特定的用户类型,其他均拦截if (userType.equals(getCurrentUserType())) {return invocation.proceed();}} else {return invocation.proceed();}}break;}}}//2.获取SQL语句BoundSql boundSql = ms.getBoundSql(paramObj);// 原始sqlString originalSql = boundSql.getSql();log.debug("originalSql:{}", originalSql);//3.根据语句类型改造SQL语句switch (ms.getSqlCommandType()) {case INSERT: {originalSql = convertInsertSQL(originalSql);args[0] = newMappedStatement(ms, boundSql, originalSql, paramObj);break;}case UPDATE:case DELETE: {originalSql = SQLUtils.addCondition(originalSql, COMPANY_ID + "='" + getCompanyId() +"'", null);args[0] = newMappedStatement(ms, boundSql, originalSql, paramObj);break;}case SELECT: {if (!StringUtils.containsIgnoreCase(originalSql, LAST_INSERT_ID_SQL)) {//where 条件拼接 companyIdMySQLStatementParser parser = new MySqlStatementParser(originalSql);SQLStatement statement = parser.parseStatement();SQLSelectStatement selectStatement = (SQLSelectStatement) statement;SQLSelect sqlSelect = selectStatement.getSelect();SQLSelectQuery query = sqlSelect.getQuery();addSelectCondition(query, COMPANY_ID + "='" + getCompanyId() + "'");originalSql = SQLUtils.toSQLString(selectStatement, JdbcConstants.MYSQL);// 将新生成的MappedStatement对象替换到参数列表中args[0] = newMappedStatement(ms, boundSql, originalSql, paramObj);}break;}}log.debug("modifiedSql:{}", originalSql);//4.应用修改后的SQL语句return invocation.proceed();}private void addSelectCondition(SQLSelectQuery query, String condition){if (query instanceof SQLUnionQuery) {SQLUnionQuery sqlUnionQuery = (SQLUnionQuery) query;addSelectCondition(sqlUnionQuery.getLeft(), condition);addSelectCondition(sqlUnionQuery.getRight(), condition);} else if (query instanceof SQLSelectQueryBlock) {SQLSelectQueryBlock selectQueryBlock = (SQLSelectQueryBlock) query;SQLTableSource tableSource = selectQueryBlock.getFrom();String conditionTmp = condition;String alias = getLeftAlias(tableSource);if (StringUtils.isNotBlank(alias)) {//拼接别名conditionTmp = alias + "." + condition;}SQLExpr conditionExpr = SQLUtils.toMySqlExpr(conditionTmp);selectQueryBlock.addCondition(conditionExpr);}}private String getLeftAlias(SQLTableSource tableSource) {if (tableSource != null) {if (tableSource instanceof SQLExprTableSource) {if (StringUtils.isNotBlank(tableSource.getAlias())) {return tableSource.getAlias();}} else if (tableSource instanceof SQLJoinTableSource) {SQLJoinTableSource join = (SQLJoinTableSource) tableSource;return getLeftAlias(join.getLeft());}}return null;}/*** 用于封装目标对象的,通过该方法我们可以返回目标对象本身,也可以返回一个它的代理* @param target* @return*/@Overridepublic Object plugin(Object target) {//只拦截Executor对象,减少目标被代理的次数if (target instanceof Executor) {return Plugin.wrap(target, this);}return target;}/*** 注册当前拦截器的时候可以设置一些属性*/@Overridepublic void setProperties(Properties properties) {}private String convertInsertSQL(String originalSql) {MySqlStatementParser parser = new MySqlStatementParser(originalSql);SQLStatement statement = parser.parseStatement();MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();statement.accept(visitor);MySqlInsertStatement myStatement = (MySqlInsertStatement) statement;String tableName = myStatement.getTableName().getSimpleName();List<SQLExpr> columns = myStatement.getColumns();List<SQLInsertStatement.ValuesClause> vcl = myStatement.getValuesList();if (columns == null || columns.size() <= 0 || myStatement.getQuery() != null) {return originalSql;}if (!visitor.containsColumn(tableName, COMPANY_ID)) {SQLExpr columnExpr = SQLUtils.toMySqlExpr(COMPANY_ID);columns.add(columnExpr);SQLExpr valuesExpr = SQLUtils.toMySqlExpr("'" + getCompanyId() + "'");vcl.stream().forEach(v -> v.addValue(valuesExpr));}return SQLUtils.toSQLString(myStatement, JdbcConstants.MYSQL);}private MappedStatement newMappedStatement(MappedStatement ms, BoundSql boundSql,String sql, Object parameter){BoundSql newBoundSql = new BoundSql(ms.getConfiguration(),sql, new ArrayList(boundSql.getParameterMappings()), parameter);for (ParameterMapping mapping : boundSql.getParameterMappings()) {String prop = mapping.getProperty();if (boundSql.hasAdditionalParameter(prop)) {newBoundSql.setAdditionalParameter(prop, boundSql.getAdditionalParameter(prop));}}return copyFromOriMappedStatement(ms, new WarpBoundSqlSqlSource(newBoundSql));}private MappedStatement copyFromOriMappedStatement(MappedStatement ms, SqlSource newSqlSource) {MappedStatement.Builder builder = new MappedStatement.Builder(ms.getConfiguration(),ms.getId(),newSqlSource,ms.getSqlCommandType());builder.cache(ms.getCache()).databaseId(ms.getDatabaseId()).fetchSize(ms.getFetchSize()).flushCacheRequired(ms.isFlushCacheRequired()).keyColumn(StringUtils.join(ms.getKeyColumns(), ',')).keyGenerator(ms.getKeyGenerator()).keyProperty(StringUtils.join(ms.getKeyProperties(), ',')).lang(ms.getLang()).parameterMap(ms.getParameterMap()).resource(ms.getResource()).resultMaps(ms.getResultMaps()).resultOrdered(ms.isResultOrdered()).resultSets(StringUtils.join(ms.getResultSets(), ',')).resultSetType(ms.getResultSetType()).statementType(ms.getStatementType()).timeout(ms.getTimeout()).useCache(ms.isUseCache());return builder.build();}static class WarpBoundSqlSqlSource implements SqlSource {private final BoundSql boundSql;public WarpBoundSqlSqlSource(BoundSql boundSql) {this.boundSql = boundSql;}@Overridepublic BoundSql getBoundSql(Object parameterObject) {return boundSql;}}public String getCompanyId() {//先从authenticationFacade取String companyId = CompanyContext.getCompanyId();if(StringUtils.isBlank(companyId)){log.error("Can not get the companyId! {}", companyId);throw new RuntimeException("Can not get the companyId! " + companyId);}return companyId;}public String getCurrentUserType() {//authenticationFacade取AuthenticationFacade authenticationFacade = ApplicationContextProvider.getBean(AuthenticationFacade.class);Integer currentUserType = authenticationFacade.getCurrentUserType();if (currentUserType == null) {log.error("Can not get the currentUserType! {}", currentUserType);throw new RuntimeException("Can not get the currentUserType! " + currentUserType);}UserTypeEnum userTypeEnum = UserTypeEnum.getByCode(currentUserType);return userTypeEnum.getUserType();}}


推荐阅读