手撸RPC框架 SPI机制基础功能实现

大家好,我是小趴菜,接下来我会从0到1手写一个RPC框架,该专题包括以下内容,有兴趣的小伙伴就跟着我一起学习吧

本章源码地址:gitee.com/baojh123/se…

自定义注解 -> opt-01
服务提供者收发消息基础实现 -> opt-01
自定义网络传输协议的实现 -> opt-02
自定义编解码实现 -> opt-03
服务提供者调用真实方法实现 -> opt-04
完善服务消费者发送消息基础功能 -> opt-05
注册中心基础功能实现 -> opt-06
服务提供者整合注册中心 -> opt-07
服务消费者整合注册中心 -> opt-08
完善服务消费者接收响应结果 -> opt-09
服务消费者,服务提供者整合SpringBoot -> opt-10
动态代理屏蔽RPC服务调用底层细节 -> opt-10
SPI机制基础功能实现 -> opt-11
SPI机制扩展随机负载均衡策略 -> opt-12
SPI机制扩展轮询负载均衡策略 -> opt-13
SPI机制扩展JDK序列化 -> opt-14
SPI机制扩展JSON序列化 -> opt-15
SPI机制扩展protostuff序列化 -> opt-16

目标

我们之前已经完成了服务提供者与消费者,并且将它们与SpringBoot整合到一起了,但是我们发现很多地方的扩展性并不够,甚至是直接写死的,比如下面这个地方

这里是给标记了@DubboReference注解的字段所对应的接口创建代理,但这里直接写死使用了JDK动态代理,如果要使用CGLIB或者其他代理方式,就只能修改源代码,扩展性和灵活性都不够

好在Java为我们提供了SPI机制,能够动态扩展对应的功能,不过我们会在Java原生SPI的基础上进行扩展,对标Dubbo的SPI机制

/**
 * Injects a proxy into every {@code @DubboReference} field of the scanned classes.
 *
 * <p>NOTE(review): {@code packageName} is not read in this excerpt; {@code classList} is
 * presumably filled by an earlier package scan — confirm against the full class.
 *
 * @param packageName the package being scanned (not used in this excerpt)
 * @throws Exception declared by the signature; per-class failures are printed, not propagated
 */
public void doScanDubboReferenceByPackage(String packageName) throws Exception {
    classList.forEach(className -> injectDubboReferenceProxies(className));
}

/**
 * Replaces each {@code @DubboReference} field of one class with a JDK dynamic proxy.
 * Any failure is printed and the class is skipped, so the remaining classes are still processed.
 */
private void injectDubboReferenceProxies(String className) {
    try {
        Class owner = Class.forName(className);
        for (Field candidate : owner.getDeclaredFields()) {
            if (candidate.getAnnotation(DubboReference.class) == null) {
                continue;
            }
            Class referenceType = candidate.getType();
            // JDK dynamic proxy is hard-coded here; this is the extension point SPI is meant to open up.
            JdkProxy proxyFactory = new JdkProxy(RpcConsumer.getInstance());
            Object proxyInstance = proxyFactory.getProxy(referenceType);
            setField(candidate, RpcConsumerAutoConfig.getObject(owner), proxyInstance, true);
        }
    } catch (Exception e) {
        e.printStackTrace();
    }
}

实现

在 xpc-rpc-annoation 模块中新增两个注解:@SPI 和 @SPIClass

package com.xpc.rpc.annotation;

import java.lang.annotation.*;

/**
 * Marks an interface as an SPI extension point.
 *
 * <p>The extension loader refuses to create a loader for an interface that lacks this annotation.
 * NOTE(review): {@link #value()} looks like the name of the default implementation; confirm
 * against the loader code that reads it.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface SPI {

    /**
     * Name associated with this extension point.
     *
     * @return the configured name, or an empty string when none is given
     */
    String value() default "";
}
package com.xpc.rpc.annotation;

import java.lang.annotation.*;

/**
 * Marker annotation for SPI implementation types; it carries no attributes.
 *
 * <p>NOTE(review): ExtensionLoader imports this annotation, but the check that uses it is not
 * visible in this listing — presumably implementations named in SPI config files must carry it.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface SPIClass {
}

创建一个SPI模块 xpc-rpc-spi

  • SPI机制实现的核心类:com.xpc.rpc.spi.loader.ExtensionLoader
package com.xpc.rpc.spi.loader;

import com.xpc.rpc.annotation.SPI;
import com.xpc.rpc.annotation.SPIClass;
import com.xpc.rpc.spi.factory.ExtensionFactory;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;

public class ExtensionLoader {

    private static final Logger LOG = LoggerFactory.getLogger(ExtensionLoader.class);

    private static final String SERVICES_DIRECTORY = "META-INF/services/";
    private static final String BINGHE_DIRECTORY = "META-INF/xpc/";
    private static final String BINGHE_DIRECTORY_EXTERNAL = "META-INF/xpc/external/";
    private static final String BINGHE_DIRECTORY_INTERNAL = "META-INF/xpc/internal/";

    private static final String[] SPI_DIRECTORIES = new String[]{
            SERVICES_DIRECTORY,
            BINGHE_DIRECTORY,
            BINGHE_DIRECTORY_EXTERNAL,
            BINGHE_DIRECTORY_INTERNAL
    };

    private static final Map> LOADERS = new ConcurrentHashMap();

    private final Class clazz;

    private final ClassLoader classLoader;

    private final Holder, Object> spiClassInstances = new ConcurrentHashMap();

    private String cachedDefaultName;

    /**
     * Instantiates a new Extension loader.
     *
     * @param clazz the clazz.
     */
    private ExtensionLoader(final Class clazz, final ClassLoader cl) {
        this.clazz = clazz;
        this.classLoader = cl;
        if (!Objects.equals(clazz, ExtensionFactory.class)) {
            ExtensionLoader.getExtensionLoader(ExtensionFactory.class).getExtensionClasses();
        }
    }

    /**
     * Returns the cached loader for an {@code @SPI} extension interface, creating it on first use.
     *
     * <p>Loaders are cached per interface only: once a loader for {@code clazz} exists it is returned
     * as-is and {@code cl} is ignored.
     *
     * @param <T>   the extension type
     * @param clazz the extension interface; must be an interface annotated with {@link SPI}
     * @param cl    the class loader passed to a newly created loader
     * @return the extension loader for {@code clazz}
     * @throws NullPointerException     if {@code clazz} is null
     * @throws IllegalArgumentException if {@code clazz} is not an interface or is not annotated with {@code @SPI}
     */
    public static  ExtensionLoader getExtensionLoader(final Class clazz, final ClassLoader cl) {
        Objects.requireNonNull(clazz, "extension clazz is null");
        if (!clazz.isInterface()) {
            throw new IllegalArgumentException("extension clazz (" + clazz + ") is not interface!");
        }
        if (!clazz.isAnnotationPresent(SPI.class)) {
            // getSimpleName() renders "@SPI"; concatenating the Class object printed "@interface com.xpc...SPI".
            throw new IllegalArgumentException("extension clazz (" + clazz + ") without @"
                    + SPI.class.getSimpleName() + " Annotation");
        }
        ExtensionLoader extensionLoader = (ExtensionLoader) LOADERS.get(clazz);
        if (Objects.nonNull(extensionLoader)) {
            return extensionLoader;
        }
        // Deliberately not computeIfAbsent: the constructor itself calls
        // getExtensionLoader(ExtensionFactory.class), and a ConcurrentHashMap mapping function must
        // not modify the map. Under a race a spare loader may be built, but putIfAbsent guarantees
        // that every caller receives the single registered instance.
        ExtensionLoader created = new ExtensionLoader(clazz, cl);
        ExtensionLoader previous = (ExtensionLoader) LOADERS.putIfAbsent(clazz, created);
        return Objects.isNull(previous) ? created : previous;
    }

    /**
     * Looks up an extension instance of the given interface by name.
     *
     * <p>A null, empty or whitespace-only {@code name} selects the default extension. This matches
     * {@link #getSpiClassInstance(String)}, which rejects any blank name.
     *
     * @param <T>   the extension type
     * @param clazz the {@code @SPI} extension interface
     * @param name  the extension name, or blank to use the default extension
     * @return the extension instance; null when {@code name} is blank and no default name is known
     */
    public static  T getExtension(final Class clazz, String name){
        // isBlank rather than isEmpty: a whitespace-only name used to reach getSpiClassInstance and throw.
        return StringUtils.isBlank(name)
                ? getExtensionLoader(clazz).getDefaultSpiClassInstance()
                : getExtensionLoader(clazz).getSpiClassInstance(name);
    }

    /**
     * Returns the extension loader for {@code clazz}, using the class loader that loaded
     * {@code ExtensionLoader} itself.
     *
     * @param <T>   the extension type
     * @param clazz the {@code @SPI} extension interface
     * @return the extension loader for {@code clazz}
     */
    public static  ExtensionLoader getExtensionLoader(final Class clazz) {
        final ClassLoader ownClassLoader = ExtensionLoader.class.getClassLoader();
        return getExtensionLoader(clazz, ownClassLoader);
    }

    /**
     * Returns the instance of the default extension.
     *
     * <p>Extension classes are loaded first. NOTE(review): this is presumably where
     * {@code cachedDefaultName} gets populated; the code that sets it is not visible in this listing.
     *
     * @return the default extension instance, or null when no default name is known
     */
    public T getDefaultSpiClassInstance() {
        getExtensionClasses();
        final String defaultName = cachedDefaultName;
        return StringUtils.isBlank(defaultName) ? null : getSpiClassInstance(defaultName);
    }

    /**
     * Returns the instance of the extension registered under {@code name}, creating and caching it
     * on the first request.
     *
     * <p>Creation uses double-checked locking on {@code cachedInstances}. NOTE(review): this is only
     * safe if {@code Holder}'s value field is volatile; the Holder class is not visible here — confirm.
     *
     * @param name the extension name
     * @return the cached extension instance
     * @throws NullPointerException if {@code name} is null, empty or blank
     */
    public T getSpiClassInstance(final String name) {
        if (StringUtils.isBlank(name)) {
            throw new NullPointerException("get spi class name is null");
        }
        final Holder holder = cachedInstances.computeIfAbsent(name, key -> new Holder());
        Object instance = holder.getValue();
        if (instance != null) {
            return (T) instance;
        }
        synchronized (cachedInstances) {
            // Re-check under the lock: another thread may have created the instance meanwhile.
            instance = holder.getValue();
            if (instance == null) {
                instance = createExtension(name);
                holder.setValue(instance);
            }
        }
        return (T) instance;
    }

    /**
     * get all spi class spi.
     *
     * @return list. spi instances
     */
    public List getSpiClassInstances() {
        Map aClass = getExtensionClasses().get(name);
        if (Objects.isNull(aClass)) {
            throw new IllegalArgumentException("name is error");
        }
        Object o = spiClassInstances.get(aClass);
        if (Objects.isNull(o)) {
            try {
                spiClassInstances.putIfAbsent(aClass, aClass.newInstance());
                o = spiClassInstances.get(aClass);
            } catch (InstantiationException | IllegalAccessException e) {
                throw new IllegalStateException("Extension instance(name: " + name + ", class: "
                        + aClass + ")  could not be instantiated: " + e.getMessage(), e);

            }
        }
        return (T) o;
    }

    /**
     * Gets extension classes.
     *
     * @return the extension classes
     */
    public Map> classes = cachedClasses.getValue();
        if (Objects.isNull(classes)) {
            synchronized (cachedClasses) {
                classes = cachedClasses.getValue();
                if (Objects.isNull(classes)) {
                    classes = loadExtensionClass();
                    cachedClasses.setValue(classes);
                }
            }
        }
        return classes;
    }

    private Map> classes = new HashMap(16);
        loadDirectory(classes);
        return classes;
    }

    /**
     * Reads one extension configuration resource and registers every {@code name=className} entry.
     *
     * <p>The resource is parsed as a {@link Properties} file; entries with a blank key or value are
     * skipped. I/O failures and unresolvable classes are both rethrown as IllegalStateException.
     *
     * @param classes map that receives the loaded extension classes
     * @param url     location of the configuration resource
     * @throws IOException declared by the signature, although I/O errors are wrapped before escaping
     */
    private void loadDirectory(final Map> classes, final URL url) throws IOException {
        try (InputStream resourceStream = url.openStream()) {
            final Properties entries = new Properties();
            entries.load(resourceStream);
            for (Map.Entry<Object, Object> entry : entries.entrySet()) {
                final String extensionName = (String) entry.getKey();
                final String className = (String) entry.getValue();
                if (StringUtils.isBlank(extensionName) || StringUtils.isBlank(className)) {
                    continue;
                }
                try {
                    loadClass(classes, extensionName, className);
                } catch (ClassNotFoundException e) {
                    throw new IllegalStateException("load extension resources error", e);
                }
            }
        } catch (IOException e) {
            throw new IllegalStateException("load extension resources error", e);
        }
    }

    private void loadClass(final Map