JDK Dynamic Proxies: Basic Usage and Source Code Analysis

Usage example

1. Define an interface:

public interface IWaiter {
    void service();
}

2. Define its implementation:

// com.zze.service.impl.Waiter
import com.zze.service.IWaiter;

public class Waiter implements IWaiter {
    @Override
    public void service() {
        System.out.println("Serving");
    }
}

3. Create a proxy with the JDK's dynamic proxy API:

@Test
public void test() {
    IWaiter waiter = new Waiter();
    Class<?>[] interfaces = Waiter.class.getInterfaces();
    IWaiter waiterProxy = (IWaiter) Proxy.newProxyInstance(Waiter.class.getClassLoader(), interfaces, new InvocationHandler() {
        @Override
        public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
            Object obj = null;
            // Only intercept service(); Java method names are case-sensitive, so compare with equals
            if ("service".equals(method.getName())) {
                System.out.println("Before service");
                // Delegate to the real object
                obj = method.invoke(waiter, args);
                System.out.println("After service");
            }
            return obj;
        }
    });
    waiterProxy.service();
    /*
    Output:
    Before service
    Serving
    After service
    */
}

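Note: JDK dynamic proxies work on interfaces only, so a proxy can be created only for a class that implements at least one interface.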
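To make the note above concrete, here is a minimal sketch of what happens when a class type is passed where an interface is expected. The class Cashier and the demo class name are hypothetical and exist only for this illustration.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Proxy;

public class InterfaceOnlyDemo {

    // Hypothetical class for illustration: it implements no interface.
    static class Cashier {
        public void checkout() {
        }
    }

    /**
     * Tries to build a proxy for a class type instead of an interface.
     * Returns the exception message, or null if no exception was thrown.
     */
    public static String tryProxyForClass() {
        InvocationHandler handler = (proxy, method, args) -> null;
        try {
            Proxy.newProxyInstance(
                    Cashier.class.getClassLoader(),
                    new Class<?>[] { Cashier.class }, // a class, not an interface
                    handler);
            return null;
        } catch (IllegalArgumentException e) {
            // newProxyInstance rejects anything that is not an interface
            return e.getMessage();
        }
    }

    public static void main(String[] args) {
        System.out.println(tryProxyForClass());
    }
}
```

The check that rejects the class is the isInterface() test in ProxyClassFactory.apply, which appears later in the source walkthrough.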

Source code analysis

Start with the newProxyInstance method (the listings below are from the JDK 8 source):

// java.lang.reflect.Proxy.newProxyInstance
@CallerSensitive
public static Object newProxyInstance(ClassLoader loader,
                                      Class<?>[] interfaces,
                                      InvocationHandler h)
        throws IllegalArgumentException
{
    Objects.requireNonNull(h);

    final Class<?>[] intfs = interfaces.clone();
    final SecurityManager sm = System.getSecurityManager();
    if (sm != null) {
        checkProxyAccess(Reflection.getCallerClass(), loader, intfs);
    }

    /*
     * Look up or generate the proxy class
     */
    // <2>
    Class<?> cl = getProxyClass0(loader, intfs);


    try {
        if (sm != null) {
            checkNewProxyPermission(Reflection.getCallerClass(), cl);
        }
        /*
         * Get the constructor of the proxy class
         */
        final Constructor<?> cons = cl.getConstructor(constructorParams);
        final InvocationHandler ih = h;
        if (!Modifier.isPublic(cl.getModifiers())) {
            AccessController.doPrivileged(new PrivilegedAction<Void>() {
                public Void run() {
                    cons.setAccessible(true);
                    return null;
                }
            });
        }
        /*
         * Invoke the proxy class constructor with h (our InvocationHandler
         * instance) to create the proxy instance, and return it
         */
        // <1>
        return cons.newInstance(new Object[]{h});
    } catch (IllegalAccessException|InstantiationException e) {
        throw new InternalError(e.toString(), e);
    } catch (InvocationTargetException e) {
        Throwable t = e.getCause();
        if (t instanceof RuntimeException) {
            throw (RuntimeException) t;
        } else {
            throw new InternalError(t.toString(), t);
        }
    } catch (NoSuchMethodException e) {
        throw new InternalError(e.toString(), e);
    }
}
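The two marked steps in the listing above, <2> (obtain the proxy class) and <1> (call its constructor with the handler), can be reproduced with public API. The following is a rough sketch, not the JDK's own code: the class and method names are hypothetical, it skips the security checks, and Proxy.getProxyClass has been deprecated since JDK 9 although it is still available.

```java
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Proxy;

public class TwoStepProxyDemo {

    /** Hypothetical helper that mirrors the two marked steps of newProxyInstance. */
    public static Object createInTwoSteps(ClassLoader loader,
                                          Class<?>[] interfaces,
                                          InvocationHandler h) {
        try {
            // Step <2>: look up (or generate) the proxy class.
            @SuppressWarnings("deprecation") // deprecated since JDK 9, still present
            Class<?> proxyClass = Proxy.getProxyClass(loader, interfaces);

            // Step <1>: every proxy class has a public constructor that
            // takes a single InvocationHandler argument.
            Constructor<?> cons = proxyClass.getConstructor(InvocationHandler.class);
            return cons.newInstance(h);
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException(e);
        }
    }

    /** Creates a Runnable proxy and returns the name of the method the handler saw. */
    public static String demo() {
        StringBuilder log = new StringBuilder();
        InvocationHandler handler = (proxy, method, args) -> {
            log.append(method.getName());
            return null;
        };
        Runnable runnable = (Runnable) createInTwoSteps(
                TwoStepProxyDemo.class.getClassLoader(),
                new Class<?>[] { Runnable.class },
                handler);
        runnable.run();
        return log.toString();
    }

    public static void main(String[] args) {
        System.out.println(demo()); // prints "run"
    }
}
```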

The method returns the proxy instance created at <1>, while the proxy class itself is looked up or generated by the getProxyClass0 call at <2>:

// java.lang.reflect.Proxy.getProxyClass0
private static Class<?> getProxyClass0(ClassLoader loader,
                                       Class<?>... interfaces) {
    // Throw if more than 65535 interfaces are passed in
    if (interfaces.length > 65535) {
        throw new IllegalArgumentException("interface limit exceeded");
    }
    // <3>
    return proxyClassCache.get(loader, interfaces);
}
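Because getProxyClass0 goes through proxyClassCache, asking twice for a proxy with the same class loader and the same interfaces should yield two instances of one and the same proxy class. A small sketch to observe this (the class name ProxyClassCacheDemo is made up for the example):

```java
import java.lang.reflect.Proxy;

public class ProxyClassCacheDemo {

    /** Returns true if two distinct proxy instances share one proxy class. */
    public static boolean sameProxyClass() {
        ClassLoader loader = ProxyClassCacheDemo.class.getClassLoader();
        Class<?>[] interfaces = { Runnable.class };

        // Two separate calls with two separate handlers
        Object first = Proxy.newProxyInstance(loader, interfaces, (p, m, a) -> null);
        Object second = Proxy.newProxyInstance(loader, interfaces, (p, m, a) -> null);

        // Different instances, but the class is generated once and reused
        return first != second && first.getClass() == second.getClass();
    }

    public static void main(String[] args) {
        System.out.println(sameProxyClass()); // prints "true"
    }
}
```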

Next, look at the proxyClassCache.get method called at <3>:

// java.lang.reflect.WeakCache.get
public V get(K key, P parameter) {
        Objects.requireNonNull(parameter);

        expungeStaleEntries();

        Object cacheKey = CacheKey.valueOf(key, refQueue);

        ConcurrentMap<Object, Supplier<V>> valuesMap = map.get(cacheKey);
        if (valuesMap == null) {
            ConcurrentMap<Object, Supplier<V>> oldValuesMap
                = map.putIfAbsent(cacheKey,
                                  valuesMap = new ConcurrentHashMap<>());
            if (oldValuesMap != null) {
                valuesMap = oldValuesMap;
            }
        }

        Object subKey = Objects.requireNonNull(subKeyFactory.apply(key, parameter));
        Supplier<V> supplier = valuesMap.get(subKey);
        Factory factory = null;

        // <4>
        while (true) {
            if (supplier != null) {
                // <7>
                V value = supplier.get();
                if (value != null) {
                    // <5>
                    return value;
                }
            }
            
            if (factory == null) {
                factory = new Factory(key, parameter, subKey, valuesMap);
            }

            if (supplier == null) {
                // <6 - begin>
                supplier = valuesMap.putIfAbsent(subKey, factory);
                if (supplier == null) {
                    supplier = factory;
                }
                // <6 - end>
            } else {
                if (valuesMap.replace(subKey, supplier, factory)) {
                    supplier = factory;
                } else {
                    supplier = valuesMap.get(subKey);
                }
            }
        }
    }

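Start reading at <4>. The first thing you see is an infinite loop whose only exit is at <5>: it returns as soon as supplier.get() yields a non-null value. On the first call for a given key, supplier is null, so a Factory is created and assigned to supplier at <6>; the next iteration then calls supplier.get at <7>, which is Factory.get: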

// java.lang.reflect.WeakCache.Factory.get
@Override
public synchronized V get() { 
    Supplier<V> supplier = valuesMap.get(subKey);
    if (supplier != this) {
        return null;
    }

    V value = null;
    try {
        // <8>
        value = Objects.requireNonNull(valueFactory.apply(key, parameter));
    } finally {
        if (value == null) { 
            valuesMap.remove(subKey, this);
        }
    }

    assert value != null;

    CacheValue<V> cacheValue = new CacheValue<>(value);


    if (valuesMap.replace(subKey, this, cacheValue)) {

        reverseMap.put(cacheValue, Boolean.TRUE);
    } else {
        throw new AssertionError("Should not reach here");
    }
    return value;
}
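The interplay between the loop in WeakCache.get and Factory.get can be easier to follow in a stripped-down form. The sketch below is a simplification written for this article, not JDK code: LazyCacheSketch and demo are hypothetical names, and it leaves out the weak references, the two-level key, and reverseMap. It keeps only the core idea: install a Factory with putIfAbsent, let the Factory compute the value once, then swap the Factory for a plain supplier of the finished value.

```java
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.function.Supplier;

public class LazyCacheSketch<K, V> {
    private final ConcurrentMap<K, Supplier<V>> map = new ConcurrentHashMap<>();
    private final Function<K, V> valueFactory;

    public LazyCacheSketch(Function<K, V> valueFactory) {
        this.valueFactory = valueFactory;
    }

    public V get(K key) {
        Supplier<V> supplier = map.get(key);
        Factory factory = null;
        while (true) {
            if (supplier != null) {
                V value = supplier.get();      // a Factory or a cached value
                if (value != null) {
                    return value;              // the only exit of the loop
                }
            }
            if (factory == null) {
                factory = new Factory(key);
            }
            if (supplier == null) {
                supplier = map.putIfAbsent(key, factory);
                if (supplier == null) {
                    supplier = factory;        // we installed our Factory
                }
            } else if (map.replace(key, supplier, factory)) {
                supplier = factory;            // replaced a stale supplier
            } else {
                supplier = map.get(key);       // lost a race, retry
            }
        }
    }

    private final class Factory implements Supplier<V> {
        private final K key;

        Factory(K key) {
            this.key = key;
        }

        @Override
        public synchronized V get() {
            if (map.get(key) != this) {
                return null;                   // already replaced, caller retries
            }
            V value = null;
            try {
                value = Objects.requireNonNull(valueFactory.apply(key));
            } finally {
                if (value == null) {
                    map.remove(key, this);     // computation failed, clean up
                }
            }
            final V result = value;
            map.replace(key, this, () -> result); // swap Factory for the value
            return result;
        }
    }

    /** Looks up the same key twice and returns how often the factory ran. */
    public static int demo() {
        AtomicInteger calls = new AtomicInteger();
        LazyCacheSketch<String, String> cache = new LazyCacheSketch<String, String>(key -> {
            calls.incrementAndGet();
            return key.toUpperCase();
        });
        cache.get("waiter");
        cache.get("waiter");
        return calls.get();
    }

    public static void main(String[] args) {
        System.out.println(demo()); // prints 1
    }
}
```

In the real WeakCache the role of valueFactory is played by ProxyClassFactory, so the expensive step that runs only once per key is the generation of the proxy class.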

The value returned by this method is produced at <8>: it is the result of valueFactory.apply(key, parameter). Here valueFactory is an instance of java.lang.reflect.Proxy.ProxyClassFactory, so look at that class's apply method:

// java.lang.reflect.Proxy.ProxyClassFactory.apply
@Override
public Class<?> apply(ClassLoader loader, Class<?>[] interfaces) {

    Map<Class<?>, Boolean> interfaceSet = new IdentityHashMap<>(interfaces.length);
    for (Class<?> intf : interfaces) {
        Class<?> interfaceClass = null;
        try {
            interfaceClass = Class.forName(intf.getName(), false, loader);
        } catch (ClassNotFoundException e) {
        }
        if (interfaceClass != intf) {
            throw new IllegalArgumentException(
                intf + " is not visible from class loader");
        }

        if (!interfaceClass.isInterface()) {
            throw new IllegalArgumentException(
                interfaceClass.getName() + " is not an interface");
        }
        if (interfaceSet.put(interfaceClass, Boolean.TRUE) != null) {
            throw new IllegalArgumentException(
                "repeated interface: " + interfaceClass.getName());
        }
    }
    // <9 - begin>
    String proxyPkg = null;     
    int accessFlags = Modifier.PUBLIC | Modifier.FINAL;

    for (Class<?> intf : interfaces) {
        int flags = intf.getModifiers();
        if (!Modifier.isPublic(flags)) {
            accessFlags = Modifier.FINAL;
            String name = intf.getName();
            int n = name.lastIndexOf('.');
            String pkg = ((n == -1) ? "" : name.substring(0, n + 1));
            if (proxyPkg == null) {
                proxyPkg = pkg;
            } else if (!pkg.equals(proxyPkg)) {
                throw new IllegalArgumentException(
                    "non-public interfaces from different packages");
            }
        }
    }

    if (proxyPkg == null) {
        proxyPkg = ReflectUtil.PROXY_PACKAGE + ".";
    }
    // <9 - end>

    long num = nextUniqueNumber.getAndIncrement();
    // <10>
    String proxyName = proxyPkg + proxyClassNamePrefix + num;

    // <11>
    byte[] proxyClassFile = ProxyGenerator.generateProxyClass(
        proxyName, interfaces, accessFlags);
    try {
        return defineClass0(loader, proxyName,
                            proxyClassFile, 0, proxyClassFile.length);
    } catch (ClassFormatError e) {
        throw new IllegalArgumentException(e.toString());
    }
}

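The block at <9> works out the package of the proxy class and stores it in proxyPkg; <10> then builds the fully qualified name of the proxy class.

The call to ProxyGenerator.generateProxyClass at <11> is where the proxy class bytecode is actually generated; the result is stored in the byte array proxyClassFile. Look at ProxyGenerator.generateProxyClass: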
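The effect of the package logic can be observed from the outside. In the sketch below (all names are hypothetical), a proxy for a public interface ends up in a JDK-chosen package whose exact name depends on the JDK version, while a proxy for a non-public interface is placed in the same package as that interface, as the loop over the interfaces at <9> suggests.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Proxy;

public class ProxyNamingDemo {

    // Hypothetical non-public interface (package-private, nested for brevity).
    interface Hidden {
        void ping();
    }

    /** Package part of a class name, or "" for the unnamed package. */
    public static String packageOf(Class<?> type) {
        String name = type.getName();
        int lastDot = name.lastIndexOf('.');
        return lastDot == -1 ? "" : name.substring(0, lastDot);
    }

    /** Creates a proxy for a single interface and returns the generated class. */
    public static Class<?> proxyClassFor(Class<?> iface) {
        InvocationHandler noop = (proxy, method, args) -> null;
        return Proxy.newProxyInstance(
                ProxyNamingDemo.class.getClassLoader(),
                new Class<?>[] { iface },
                noop).getClass();
    }

    /** True if the proxy for the non-public interface lives in its package. */
    public static boolean hiddenProxySharesPackage() {
        return packageOf(proxyClassFor(Hidden.class)).equals(packageOf(Hidden.class));
    }

    public static void main(String[] args) {
        // The package of the first name varies between JDK versions.
        System.out.println(proxyClassFor(Runnable.class).getName());
        System.out.println(proxyClassFor(Hidden.class).getName());
        System.out.println(hiddenProxySharesPackage());
    }
}
```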

// sun.misc.ProxyGenerator.generateProxyClass
public static byte[] generateProxyClass(final String var0, Class<?>[] var1, int var2) {
        ProxyGenerator var3 = new ProxyGenerator(var0, var1, var2);
        // <12>
        final byte[] var4 = var3.generateClassFile();
        // <14>
        if (saveGeneratedFiles) {
            AccessController.doPrivileged(new PrivilegedAction<Void>() {
                public Void run() {
                    try {
                        int var1 = var0.lastIndexOf(46);
                        Path var2;
                        if (var1 > 0) {
                            Path var3 = Paths.get(var0.substring(0, var1).replace('.', File.separatorChar));
                            Files.createDirectories(var3);
                            var2 = var3.resolve(var0.substring(var1 + 1, var0.length()) + ".class");
                        } else {
                            var2 = Paths.get(var0 + ".class");
                        }

                        Files.write(var2, var4, new OpenOption[0]);
                        return null;
                    } catch (IOException var4x) {
                        throw new InternalError("I/O exception saving generated file: " + var4x);
                    }
                }
            });
        }

        return var4;
    }

At <12>, generateClassFile builds the bytecode of the proxy class and returns it as a byte array. Look at generateClassFile:

// sun.misc.ProxyGenerator.generateClassFile
private byte[] generateClassFile() {
        // <13 - begin>
        this.addProxyMethod(hashCodeMethod, Object.class);
        this.addProxyMethod(equalsMethod, Object.class);
        this.addProxyMethod(toStringMethod, Object.class);
        // <13 - end>
        Class[] var1 = this.interfaces;
        int var2 = var1.length;

        int var3;
        Class var4;
        for(var3 = 0; var3 < var2; ++var3) {
            var4 = var1[var3];
            Method[] var5 = var4.getMethods();
            int var6 = var5.length;

            for(int var7 = 0; var7 < var6; ++var7) {
                Method var8 = var5[var7];
                this.addProxyMethod(var8, var4);
            }
        }

        Iterator var11 = this.proxyMethods.values().iterator();

        List var12;
        while(var11.hasNext()) {
            var12 = (List)var11.next();
            checkReturnTypes(var12);
        }

        Iterator var15;
        try {
            this.methods.add(this.generateConstructor());
            var11 = this.proxyMethods.values().iterator();

            while(var11.hasNext()) {
                var12 = (List)var11.next();
                var15 = var12.iterator();

                while(var15.hasNext()) {
                    ProxyGenerator.ProxyMethod var16 = (ProxyGenerator.ProxyMethod)var15.next();
                    this.fields.add(new ProxyGenerator.FieldInfo(var16.methodFieldName, "Ljava/lang/reflect/Method;", 10));
                    this.methods.add(var16.generateMethod());
                }
            }

            this.methods.add(this.generateStaticInitializer());
        } catch (IOException var10) {
            throw new InternalError("unexpected I/O Exception", var10);
        }

        if (this.methods.size() > 65535) {
            throw new IllegalArgumentException("method limit exceeded");
        } else if (this.fields.size() > 65535) {
            throw new IllegalArgumentException("field limit exceeded");
        } else {
            this.cp.getClass(dotToSlash(this.className));
            this.cp.getClass("java/lang/reflect/Proxy");
            var1 = this.interfaces;
            var2 = var1.length;

            for(var3 = 0; var3 < var2; ++var3) {
                var4 = var1[var3];
                this.cp.getClass(dotToSlash(var4.getName()));
            }

            this.cp.setReadOnly();
            ByteArrayOutputStream var13 = new ByteArrayOutputStream();
            DataOutputStream var14 = new DataOutputStream(var13);

            try {
                var14.writeInt(-889275714);
                var14.writeShort(0);
                var14.writeShort(49);
                this.cp.write(var14);
                var14.writeShort(this.accessFlags);
                var14.writeShort(this.cp.getClass(dotToSlash(this.className)));
                var14.writeShort(this.cp.getClass("java/lang/reflect/Proxy"));
                var14.writeShort(this.interfaces.length);
                Class[] var17 = this.interfaces;
                int var18 = var17.length;

                for(int var19 = 0; var19 < var18; ++var19) {
                    Class var22 = var17[var19];
                    var14.writeShort(this.cp.getClass(dotToSlash(var22.getName())));
                }

                var14.writeShort(this.fields.size());
                var15 = this.fields.iterator();

                while(var15.hasNext()) {
                    ProxyGenerator.FieldInfo var20 = (ProxyGenerator.FieldInfo)var15.next();
                    var20.write(var14);
                }

                var14.writeShort(this.methods.size());
                var15 = this.methods.iterator();

                while(var15.hasNext()) {
                    ProxyGenerator.MethodInfo var21 = (ProxyGenerator.MethodInfo)var15.next();
                    var21.write(var14);
                }

                var14.writeShort(0);
                return var13.toByteArray();
            } catch (IOException var9) {
                throw new InternalError("unexpected I/O Exception", var9);
            }
        }
    }
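The three writes at the start of the try block in the listing above look cryptic in decompiled form. They are the standard class file header: -889275714 is the int value of the magic number 0xCAFEBABE, followed by minor version 0 and major version 49. A tiny sketch (with made-up class and method names) that writes the same three values and prints them as hex:

```java
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

public class ClassFileHeaderDemo {

    /** Writes the same three header values as generateClassFile and returns them as hex. */
    public static String headerHex() {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream out = new DataOutputStream(bytes)) {
            out.writeInt(-889275714); // magic number, same bits as 0xCAFEBABE
            out.writeShort(0);        // minor version
            out.writeShort(49);       // major version (0x31)
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        StringBuilder hex = new StringBuilder();
        for (byte b : bytes.toByteArray()) {
            hex.append(String.format("%02x", b));
        }
        return hex.toString();
    }

    public static void main(String[] args) {
        System.out.println(headerHex()); // prints cafebabe00000031
    }
}
```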

As the block at <13> at the start of the method shows, the generator also proxies the hashCode, equals, and toString methods for us.
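A quick way to confirm this is to record every method name that reaches the handler. The following sketch uses hypothetical names and a Runnable proxy; note that the handler has to return sensible values for these three methods itself.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Proxy;
import java.util.ArrayList;
import java.util.List;

public class ObjectMethodsDemo {

    /** Calls a few Object methods on a proxy and returns what the handler saw. */
    public static List<String> interceptedMethods() {
        List<String> seen = new ArrayList<>();
        InvocationHandler handler = (proxy, method, args) -> {
            seen.add(method.getName());
            switch (method.getName()) {
                case "toString":
                    return "proxy@" + System.identityHashCode(proxy);
                case "hashCode":
                    return System.identityHashCode(proxy);
                case "equals":
                    return proxy == args[0];
                default:
                    return null; // fine for void methods such as run()
            }
        };
        Runnable runnable = (Runnable) Proxy.newProxyInstance(
                ObjectMethodsDemo.class.getClassLoader(),
                new Class<?>[] { Runnable.class },
                handler);

        runnable.toString();
        runnable.hashCode();
        runnable.equals(runnable);
        runnable.getClass(); // final in Object, never reaches the handler
        return seen;
    }

    public static void main(String[] args) {
        System.out.println(interceptedMethods()); // prints [toString, hashCode, equals]
    }
}
```

One consequence worth keeping in mind: a handler that returns null for everything it does not recognize will make proxy.hashCode() and proxy.equals(...) fail with a NullPointerException, because the generated method unboxes the result to a primitive.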

Back to <14> in sun.misc.ProxyGenerator.generateProxyClass: saveGeneratedFiles is a boolean that controls whether the block below it saves the class file to disk. It defaults to false, and its value is read from a system property:

private static final boolean saveGeneratedFiles = (Boolean)AccessController.doPrivileged(new GetBooleanAction("sun.misc.ProxyGenerator.saveGeneratedFiles"));

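So once this property is set to true, the generated bytecode is saved to disk. Because the field is static final, the property has to be set before ProxyGenerator is first loaded, that is, before the first proxy class is generated. (The property name above is the one used by the JDK 8 source quoted here; in JDK 9 and later the generator lives in java.lang.reflect and reads jdk.proxy.ProxyGenerator.saveGeneratedFiles instead.)

To test it, run the following code from the src root directory: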

import com.zze.service.impl.Waiter;
import com.zze.service.IWaiter;

import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;

public class Test {

    public static void main(String[] args) {
        System.getProperties().put("sun.misc.ProxyGenerator.saveGeneratedFiles", "true");
        IWaiter waiter = new Waiter();
        Class<?>[] interfaces = Waiter.class.getInterfaces();
        IWaiter waiterProxy = (IWaiter) Proxy.newProxyInstance(Waiter.class.getClassLoader(), interfaces, new InvocationHandler() {
            @Override
            public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
                Object obj = null;
                // Only intercept service(); Java method names are case-sensitive
                if ("service".equals(method.getName())) {
                    System.out.println("Before service");
                    obj = method.invoke(waiter, args);
                    System.out.println("After service");
                }
                return obj;
            }
        });
        waiterProxy.service();
    }
}

The file com/sun/proxy/$Proxy0.class is then generated under the project root. Decompiled, it looks like this:

// com.sun.proxy.$Proxy0
import com.zze.service.IWaiter;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.lang.reflect.UndeclaredThrowableException;

public final class $Proxy0 extends Proxy implements IWaiter {
    private static Method m1;
    private static Method m2;
    private static Method m3;
    private static Method m0;

    public $Proxy0(InvocationHandler var1) {
        super(var1);
    }

    public final boolean equals(Object var1) {
        try {
            return (Boolean)super.h.invoke(this, m1, new Object[]{var1});
        } catch (RuntimeException | Error var3) {
            throw var3;
        } catch (Throwable var4) {
            throw new UndeclaredThrowableException(var4);
        }
    }

    public final String toString() {
        try {
            return (String)super.h.invoke(this, m2, (Object[])null);
        } catch (RuntimeException | Error var2) {
            throw var2;
        } catch (Throwable var3) {
            throw new UndeclaredThrowableException(var3);
        }
    }

    public final void service() {
        try {
            super.h.invoke(this, m3, (Object[])null);
        } catch (RuntimeException | Error var2) {
            throw var2;
        } catch (Throwable var3) {
            throw new UndeclaredThrowableException(var3);
        }
    }

    public final int hashCode() {
        try {
            return (Integer)super.h.invoke(this, m0, (Object[])null);
        } catch (RuntimeException | Error var2) {
            throw var2;
        } catch (Throwable var3) {
            throw new UndeclaredThrowableException(var3);
        }
    }

    static {
        try {
            m1 = Class.forName("java.lang.Object").getMethod("equals", Class.forName("java.lang.Object"));
            m2 = Class.forName("java.lang.Object").getMethod("toString");
            m3 = Class.forName("com.zze.service.IWaiter").getMethod("service");
            m0 = Class.forName("java.lang.Object").getMethod("hashCode");
        } catch (NoSuchMethodException var2) {
            throw new NoSuchMethodError(var2.getMessage());
        } catch (ClassNotFoundException var3) {
            throw new NoClassDefFoundError(var3.getMessage());
        }
    }
}

The proxy object we end up using is an instance of this com.sun.proxy.$Proxy0 class.
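The properties visible in the decompiled class (final, extends Proxy, holds the handler passed to its constructor) can also be checked at runtime without saving any files. A short sketch, assuming a hypothetical Greeter interface in place of IWaiter so that the example is self-contained:

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Modifier;
import java.lang.reflect.Proxy;

public class ProxyInspectionDemo {

    // Hypothetical interface used only for this illustration.
    interface Greeter {
        String greet(String name);
    }

    /** Creates a proxy and reports a few facts about its generated class. */
    public static String describe() {
        InvocationHandler handler = (proxy, method, args) -> "hello " + args[0];
        Greeter greeter = (Greeter) Proxy.newProxyInstance(
                Greeter.class.getClassLoader(),
                new Class<?>[] { Greeter.class },
                handler);

        Class<?> type = greeter.getClass();
        return Proxy.isProxyClass(type)                           // generated by Proxy
                + " " + (type.getSuperclass() == Proxy.class)     // extends Proxy
                + " " + Modifier.isFinal(type.getModifiers())     // declared final
                + " " + (Proxy.getInvocationHandler(greeter) == handler) // holds our handler
                + " " + greeter.greet("zze");                     // call goes through the handler
    }

    public static void main(String[] args) {
        System.out.println(describe()); // prints "true true true true hello zze"
    }
}
```

Since the generated class already extends java.lang.reflect.Proxy and Java has single inheritance, this also shows why the JDK mechanism can only implement interfaces and cannot subclass the target class.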

Copyright: Licensed under the Creative Commons Attribution 4.0 International License

Links: https://www.zze.xyz/archives/jdk-dynamic-proxy.html
