如何手写动态代理实现数据库事务

2023年 9月 2日 44.9k 0

动态代理类似于ioc,但具体的说动态代理编程方式符合AOP面向切面编程,动态代理就是,在程序运行期,创建目标对象的代理对象,并对目标对象中的方法进行功能性增强的一种技术。在生成代理对象的过程中,目标对象不变,代理对象中的方法是目标对象方法的增强方法。可以理解为运行期间,对象中方法的动态拦截,在拦截方法的前后执行功能操作。代理类在程序运行期间,创建的代理对象称之为动态代理对象。这种情况下,创建的代理对象,并不是事先在Java代码中定义好的。而是在运行期间,根据我们在动态代理对象中的“指示”,动态生成的。也就是说,你想获取哪个对象的代理,动态代理就会为你动态的生成这个对象的代理对象。动态代理可以对被代理对象的方法进行功能增强。有了动态代理的技术,那么就可以在不修改方法源码的情况下,增强被代理对象的方法的功能,在方法执行前后做任何你想做的事情。实现动态代理的方式不同,这里主要采用JDK自带的 Proxy.newProxyInstance(三个参数);

首先是代码流程的解析:老样子还是通过注解标记

//标记查询方法
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface Select {
    String value();
}

//标记插入方法
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface Insert {
    String value();
}

//标记修改方法
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface Update {
    String value();
}

//标记删除方法
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface Delete {
    String value();
}

//标记参数
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.PARAMETER)
public @interface Param {
    String value();
}

接下来是通用的JDBCUtils,与往前一样,基本可以直接用,这里通过代理类去调用代替原先的DAO层调用


public final class JDBCUtils {

    private static boolean autoCommit;

    /** 声明一个 Connection类型的静态属性,用来缓存一个已经存在的连接对象 */
    private static DbConfig dbConfig = new DbConfig();
    private static ConnectionPool connectionPool = new ConnectionPool(dbConfig);


    static {
        config();
    }

    /**
     * 开头配置数据库信息
     */
    private static void config() {
        autoCommit = false;
    }



    /**
     * 建立数据库连接
     */
    //连接数据库
    public static Connection getConnection() {
        Connection conn = connectionPool.getConnection();
        if (conn!=null){
            LoggerUtil.info("get Connection successfully");
            return conn;
        }else return null;
    }

    /**
     * 设置是否自动提交事务
     **/
    public static void transaction(Connection conn) {

        try {
            conn.setAutoCommit(autoCommit);
        } catch (SQLException e) {
            System.out.println("设置事务的提交方式为 : " + (autoCommit ? "自动提交" : "手动提交") + " 时失败: " + e.getMessage());
        }

    }

    /**
     * 创建 Statement 对象
     */
    public static Statement statement(Connection connection) {
        Statement st = null;
        /* 如果连接是无效的就重新连接 */
        transaction(connection);
        /* 设置事务的提交方式 */
        try {
            st = connection.createStatement(ResultSet.TYPE_SCROLL_INSENSITIVE, ResultSet.CONCUR_READ_ONLY);
        } catch (SQLException e) {
            System.out.println("创建 Statement 对象失败: " + e.getMessage());
        }

        return st;
    }

    /**
     * 根据给定的带参数占位符的SQL语句,创建 PreparedStatement 对象
     *
     * @param SQL
     *            带参数占位符的SQL语句
     * @return 返回相应的 PreparedStatement 对象
     */
    private static PreparedStatement prepare(Connection connection,String SQL, boolean autoGeneratedKeys) {

        PreparedStatement ps = null;
        /* 如果连接是无效的就重新连接 */
        transaction(connection);
        /* 设置事务的提交方式 */
        try {
            if (autoGeneratedKeys) {
                ps = connection.prepareStatement(SQL, Statement.RETURN_GENERATED_KEYS);
            } else {
                ps = connection.prepareStatement(SQL);
            }
        } catch (SQLException e) {
            System.out.println("创建 PreparedStatement 对象失败: " + e.getMessage());
        }

        return ps;

    }

    //查询语句
    public static ResultSet query(String SQL, List params) {
        Connection connection = null;
        PreparedStatement ps  = null;
        ResultSet rs = null;
        connection = getConnection();
        if (SQL == null || SQL.trim().isEmpty() || !SQL.trim().toLowerCase().startsWith("select")) {
            throw new RuntimeException("你的SQL语句为空或不是查询语句");
        }

        if (params.size() > 0) {
            /* 说明 有参数 传入,就需要处理参数 */
            ps = prepare(connection,SQL, false);
            try {
                for (int i = 0; i  0) { // 说明有参数
            Connection c = getConnection();
            PreparedStatement ps = prepare(c,SQL, false);
            try {
                c = ps.getConnection();
            } catch (SQLException e) {
                e.printStackTrace();
            }
            try {
                for (int i = 0; i  0) { // 说明有参数
            Connection c = getConnection();
            PreparedStatement ps = prepare(c,SQL, autoGeneratedKeys);

            try {
                for (int i = 0; i < params.size(); i++) {
                    Object p = params.get(i);
                    p = typeof(p);
                    ps.setObject(i + 1, p);
                }
                int count = ps.executeUpdate();
                if (autoGeneratedKeys) { // 如果希望获得数据库产生的键
                    ResultSet rs = ps.getGeneratedKeys(); // 获得数据库产生的键集
                    if (rs.next()) { // 因为是保存的是单条记录,因此至多返回一个键
                        var = rs.getInt(1); // 获得值并赋值给 var 变量
                    }
                } else {
                    var = count; // 如果不需要获得,则将受SQL影像的记录数赋值给 var 变量
                }
                commit(c);

            } catch (SQLException e) {
                System.out.println("数据保存失败: " + e.getMessage());
                rollback(c);
            }finally {
                release(c);
            }
        } else { // 说明没有参数
            Connection c = getConnection();
            Statement st = statement(c);
            // 执行 DDL 或 DML 语句,并返回执行结果
            try {
                int count = st.executeUpdate(SQL);
                if (autoGeneratedKeys) { // 如果企望获得数据库产生的键
                    ResultSet rs = st.getGeneratedKeys(); // 获得数据库产生的键集
                    if (rs.next()) { // 因为是保存的是单条记录,因此至多返回一个键
                        var = rs.getInt(1); // 获得值并赋值给 var 变量
                    }
                } else {
                    var = count; // 如果不需要获得,则将受SQL影像的记录数赋值给 var 变量
                }
                commit(c); // 提交事务
            } catch (SQLException e) {
                System.out.println("数据保存失败: " + e.getMessage());
                rollback(c); // 回滚事务
            }finally {
                release(c);
            }
        }

        return var;
    }

    /** 提交事务 */
    private static void commit(Connection c) {
        if (c != null && !autoCommit) {
            try {
                c.commit();
            } catch (SQLException e) {
                e.printStackTrace();
            }
        }
    }

    /** 回滚事务 */
    private static void rollback(Connection c) {
        if (c != null && !autoCommit) {
            try {
                c.rollback();
            } catch (SQLException e) {
                e.printStackTrace();
            }
        }
    }

    /**
     * 释放资源
     **/
    public static void release(Object cloaseable) {

        if (cloaseable != null) {

            if (cloaseable instanceof ResultSet) {
                ResultSet rs = (ResultSet) cloaseable;
                try {
                    rs.close();
                } catch (SQLException e) {
                    e.printStackTrace();
                }
            }

            if (cloaseable instanceof Statement) {
                Statement st = (Statement) cloaseable;
                try {
                    st.close();
                } catch (SQLException e) {
                    e.printStackTrace();
                }
            }

            if (cloaseable instanceof Connection) {
                Connection c = (Connection) cloaseable;
                connectionPool.releaseConnection(c);
            }

        }

    }

}

代码较长,不做详细解释,注释中都有相应的解释,重点放在代理类

首先让代理类实现InvocationHandler的接口并重写invoke方法完成代理反射执行该方法中判断相应的注解标记方法并调用

@SuppressWarnings("unused")//忽略未使用的变量
public class MyInvocationHandlerMybatis implements InvocationHandler {

    // 代理对象
    private Object target;

    public MyInvocationHandlerMybatis(Object target) {
        this.target = target;
    }

    /**
     * 代理反射执行
     */
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        Object reuslt = null;
        // reuslt = method.invoke(target, args);
        // 判断方法上是否有注解
        Annotation[] annotations = method.getAnnotations();
        if (null != annotations && annotations.length > 0) {
            // Annotation ann = annotations[0];
            // 可能有多个注解
            for (Annotation annotation : annotations) {
                //判断注解类型
                if (annotation instanceof Insert) {
                    // 添加
                    reuslt = intsertSql(method, args);
                } else if (annotation instanceof Delete) {
                    // 删除
                    reuslt = deleteSql(method,args);
                } else if (annotation instanceof Update) {
                    // 修改
                    reuslt = updateSql(method,args);
                } else if (annotation instanceof Select) {
                    // 查询
                    reuslt = selectSql(method, args);
                }
            }
        }
        return reuslt;
    }

之后需要一个getMethodParsMap方法,用于封装传入的参数列表并返回参数Map集合

/**
     * 获取方法里面的参数和param注解中的value值封装成map对象
     *
     * @param args
     * @return
     */
    private Map getMethodParsMap(Method method, Object[] args) {
        Map parsMap = new HashMap();
        //就是对sql方法的处理 public int add(@XwfParam(value = "username") String username, @XwfParam(value = "password") String password);
        Parameter[] parameters = method.getParameters();
        for (int i = 0; i < parameters.length; i++) {
            Parameter parameter = parameters[i];
            // 获取方法参数中是否存在param注解
            Param param = parameter.getAnnotation(Param.class);
            String parName = "";
            // 获取param注解中value的值
            if (null != param) {
                parName = param.value();
            } else {
                parName = parameter.getName();
            }
            // 将参数名称和参数封装成map集合
            //parName为参数名,args为实际参数
            parsMap.put(parName, args[i]);
        }
        return parsMap;
    }

然后需要一个getSqlParNameList方法用于 根据#{string}参数的顺序 排列 pars里面的顺序 封装成list

// 截取sql中的参数并封装成map集合
    private List getSqlParNameList(String sql, Map parsMap) {
        List ListPar = new ArrayList();
        // 获取参数
        List listPars = this.getSqlParsList(sql);
        // 获取设置jdbc参数 按顺序
        for (String string : listPars) {
            ListPar.add(parsMap.get(string));
        }
        return ListPar;
    }

接着需要一个setSqlPars方法将将sql中替换参数为?,就是普通jdbc能处理的sql,主要通过String字符串处理

private String setSqlPars(String sql) {
        List ListPar = new ArrayList();
        if (sql.indexOf("#") > 0) {
            // 获取到}位置
            int indexNext = sql.indexOf("#") + sql.substring(sql.indexOf("#")).indexOf("}");
            // 获取到{位置
            int indexPre = sql.indexOf("#");
            // 截取#{}中的值
            String parName = sql.substring(indexPre, indexNext + 1);
            ListPar.add(parName.trim());
            sql = sql.replace(parName, "?");
            if (sql.indexOf("#") > 0) {
                //递归处理多个#
                sql = setSqlPars(sql);
            }
        }
        return sql;
    }

处理完成之后就可以执行JDBCUtils.query方法 返回一个结果集,该结果集需要处理因此需要一个selectQueryForObject方法将rse封装成返回对象其主要思路是:1 检查方法的返回类型并确定它是一个 List 还是一个 JavaBean。代码首先通过 method.getReturnType() 方法获取方法的返回类型,2 并使用 instanceof 关键字判断该类型是否为 List。3 如果返回类型为 List,则使用 method.getGenericReturnType() 方法获取方法的泛型返回类型,4 并检查该类型是否是 ParameterizedType 的实例。5 如果是,则通过 pType.getActualTypeArguments() 方法获取 List 中的泛型类型。6 接着,使用 Class.forName() 方法将泛型类型转换为 Class 类型,并将其赋值给 returnType 变量,同时将 returnTypeFlag 变量设为 "List"。7 如果返回类型不是 List,则 returnType 变量将保持原值,returnTypeFlag 变量将保持未初始化状态。

private Object selectQueryForObject(Method method, ResultSet rse) {
        List listObject = new ArrayList();
        try {
//            // 判断是否有记录
//            if (!rse.next()) {
//                return null;
//            }
            //注意此处rse结果集不能向上回滚,否则报错,只能判断长度!!!
            ResultSetMetaData rsmd = rse.getMetaData();
            // 通过ResultSetMetaData获取结果集中的列数
            int columnCount = rsmd.getColumnCount();
            if (columnCount>=1){
                // 光标往上移动一位/此方法有报错风险
//                rse.previous();
                String reutrnTypeFlag = "Bean";
                // 将结果封装成方法的返回类型

                Class returnType = method.getReturnType();
                // 判断返回类型为List还是JavaBean
                // 如果为List 再获取List中泛型的值
                Type genericReturnType = method.getGenericReturnType();
                if (genericReturnType instanceof ParameterizedType) {
                    ParameterizedType pType = (ParameterizedType) genericReturnType;
                    Type rType = pType.getRawType();// 主类型
                    Type[] tArgs = pType.getActualTypeArguments();// 泛型类型
                    for (int i = 0; i < tArgs.length; i++) {
                        System.out.println(pType + "第" + i + "泛型类型是:" + tArgs[i]);
                        returnType = Class.forName(tArgs[i].getTypeName());
                        reutrnTypeFlag = "List";
                    }
                }
                // 设置每条记录的属性值
                while (rse.next()) {
                    Object returnInstance = returnType.newInstance();
                    // 获取实例的字段属性
                    Field[] declaredFields = returnInstance.getClass().getDeclaredFields();
                    for (Field field : declaredFields) {
                        // 获取光标值
                        Object value = rse.getObject(field.getName());
                        // 获取set方法对象
                        Method methodSet = returnInstance.getClass().getMethod("set" + toUpperCaseFirstOne(field.getName()),
                                field.getType());
                        field.setAccessible(true);
                        // 为返回结果实例设置值
                        methodSet.invoke(returnInstance, value);
                    }
                    listObject.add(returnInstance);
                }
                if (listObject.size()==0){
                    return null;
                }
                // 如果返还对象为List
                if (reutrnTypeFlag.equals("List")) {
                    return listObject;
                } else {
                    // 如果返还对象为普通bean
                    return listObject.get(0);
                }

            }else return null;

        } catch (Exception e) {
            throw new RuntimeException("返回类型反射生成实例失败", e);
        }
    }

其中有部分方法是对字符串的大小写转换较为简单,在此省略

然后就是开始invoke调用的相应处理CURD方法,调用jdbc的CURD方法,还需要一些字符串的处理

    /**
     * 查询
     *
     * @param method 方法
     * @param args 参数
     * @return obj
     */
    private Object selectSql(Method method, Object[] args) {

        Select select = method.getAnnotation(Select.class);
        if (null != select) {
            // 获取注解值 sql语句
            String sql = select.value();
            // 判断sql中是否有参数
            //就是将"insert into test_table values (#{username},#{password}) " 转为 (?,?)的形式
            if (sql.indexOf("#") > 0) {
                // 获取参数 封装成Map
                Map parsMap = this.getMethodParsMap(method, args);
                // 替换参数封装jdbc能执行的语句
                // 根据#{string}参数的顺序 排列 pars里面的顺序 封装成list
                List ListPar = this.getSqlParNameList(sql, parsMap);
                // 将sql语句 中的 #{string} 替换成 ?
                sql = this.setSqlPars(sql);
                // 使用jdbc执行sql语句获取结果

                ResultSet rse = JDBCUtils.query(sql, ListPar);
                // 将rse封装成返回对象
                return this.selectQueryForObject(method, rse);
            } else {
                // 没有参数
                // 执行sql语句 返回结果 封装成 返回类型
                ResultSet rse = JDBCUtils.query(sql, null);
                // 将rse封装成返回对象
                return this.selectQueryForObject(method, rse);
            }
        }
        // 返回结果
        return null;
    }

    /**
     * 插入
     *
     * @param method 方法
     * @param args 参数
     * @return obj
     */
    public Object intsertSql(Method method, Object[] args) {

        Insert insert = method.getAnnotation(Insert.class);
        if (null != insert) {
            // 获取注解值 sql语句
            String sql = insert.value();
            // 判断sql中是否有参数
            if (sql.indexOf("#") > 0) {
                // 获取参数 封装成Map
                Map parsMap = this.getMethodParsMap(method, args);
                // 替换参数封装jdbc能执行的语句
                // 根据#{string}参数的顺序 排列 pars里面的顺序 封装成list
                List ListPar = this.getSqlParNameList(sql, parsMap);
                // 将sql语句 中的 #{string} 替换成 ?
                sql = this.setSqlPars(sql);
                // 使用jdbc执行sql语句获取结果
                return JDBCUtils.update(sql, false, ListPar);
            } else {
                // 没有参数
                return JDBCUtils.execute(sql);
            }
        }
        // 返回结果
        return false;
    }

    
    
    public Object updateSql(Method method, Object[] args) {

        Update update = method.getAnnotation(Update.class);
        if (null != update) {
            // 获取注解值 sql语句
            String sql = update.value();
            // 判断sql中是否有参数
            if (sql.indexOf("#") > 0) {
                // 获取参数 封装成Map
                Map parsMap = this.getMethodParsMap(method, args);
                // 替换参数封装jdbc能执行的语句
                // 根据#{string}参数的顺序 排列 pars里面的顺序 封装成list
                List ListPar = this.getSqlParNameList(sql, parsMap);
                // 将sql语句 中的 #{string} 替换成 ?
                sql = this.setSqlPars(sql);
                // 使用jdbc执行sql语句获取结果
                return JDBCUtils.update(sql, false, ListPar);
            } else {
                // 没有参数
                return JDBCUtils.execute(sql);
            }
        }
        // 返回结果
        return false;
    }


    public Object deleteSql(Method method, Object[] args) {
        Delete delete = method.getAnnotation(Delete.class);
        if (null != delete) {
            // 获取注解值 sql语句
            String sql = delete.value();
            // 判断sql中是否有参数
            if (sql.indexOf("#") > 0) {
                // 获取参数 封装成Map
                Map parsMap = this.getMethodParsMap(method, args);
                // 替换参数封装jdbc能执行的语句
                // 根据#{string}参数的顺序 排列 pars里面的顺序 封装成list
                List ListPar = this.getSqlParNameList(sql, parsMap);
                // 将sql语句 中的 #{string} 替换成 ?
                sql = this.setSqlPars(sql);
                // 使用jdbc执行sql语句获取结果
                return JDBCUtils.update(sql, false, ListPar);
            } else {
                // 没有参数
                return JDBCUtils.execute(sql);
            }
        }
        // 返回结果
        return false;
    }

最后要有一个静态的获取代理对象的方法,利于在其他类中传入要代理的类class就可以获得代理对象实例,并执行调用

// 获取代理对象
    @SuppressWarnings("unchecked")//抑制编译器的警告信息
    public static  T getObjectProxy(Class objSreviceCla) throws Exception {
        MyInvocationHandlerMybatis invocationHandlerImpl = new MyInvocationHandlerMybatis(objSreviceCla);
        ClassLoader loader = objSreviceCla.getClassLoader();
        // Class[] interfaces = objSreviceCla.getInterfaces();
        T newProxyInstance = (T) Proxy.newProxyInstance(loader, new Class[] { objSreviceCla },
                invocationHandlerImpl);
        return newProxyInstance;
    }

这样一个基本的MyInvocationHandlerMybatis代理类就完成了

相关文章

JavaScript2024新功能:Object.groupBy、正则表达式v标志
PHP trim 函数对多字节字符的使用和限制
新函数 json_validate() 、randomizer 类扩展…20 个PHP 8.3 新特性全面解析
使用HTMX为WordPress增效:如何在不使用复杂框架的情况下增强平台功能
为React 19做准备:WordPress 6.6用户指南
如何删除WordPress中的所有评论

发布评论