深入理解并发工具类CountDownLatch

2023年 7月 17日 30.4k 0

CountDownLatch 概述及使用方式

本篇文章想要讲解 JUC 工具类 CountDownLatch,因为 CountDownLatch 提供了简单有效的线程协调和控制机制,所以实际开发中是比较常用的,所以有必要了解一下 CountDownLatch。

初识 CountDownLatch

CountDownLatch 作为 Java 中的一个同步工具类,用于在多线程间实现协调和控制,允许一个或多个线程等待其他线程完成操作后再继续执行。

CountDownLatch 内部维护了一个计数器,可以通过构造函数指定初始计数值。当一个线程完成了自己的任务后,可以调用 countDown() 方法将计数值减一。而其他线程可以通过调用 await() 方法等待计数值减为零,然后再继续执行。

一般情况下,主线程会创建 CountDownLatch 对象,然后传递给其他线程。其他线程执行完自己的任务后,调用 countDown() 方法进行计数,主线程调用 await() 方法等待计数值为零。

CountDownLatch 的核心方法

CountDownLatch 提供了四个核心方法来实现线程的协调和控制,核心方法如下:

  • public CountDownLatch(int count)
    • CountDownLatch 的构造方法,用于创建一个 CountDownLatch 对象,并指定初始计数值(计数值表示需要等待的线程数量)。
  • public void countDown()
    • 当一个线程完成任务后,可以调用该方法将计数器的值减一(如果计数器的值已经为零,那么调用该方法没有任何影响,即计数器的值不会再减,而是一直为零)。
  • public void await()
    • 当一个线程需要等待其他线程完成任务后再继续执行时,可以调用该方法进行等待(如果计数器的值已经为零,那么调用该方法会立即返回)。
    • 如果在等待过程中,当前线程被中断,则会抛出 InterruptedException 异常。
    • 需要注意的是调用该方法时,计数器的值应当在所有线程都能够完成任务后变为零,否则可能导致线程一直等待或提前继续执行的问题。
  • public boolean await(long timeout, TimeUnit unit)
    • await() 方法作用一样都能使当前线程等待,不同点在于允许设置超时时间(即如果计数器的值在超时时间内变为零,那么方法会返回 true,否则返回 false)。
    • 参数中的 timeout 表示超时时间的数值,unit 表示超时时间的单位。
    • 如果在等待过程中,当前线程被中断,则会抛出 InterruptedException 异常。
  • CountDownLatch 的应用场景

    通过上面的介绍,应该能了解到 CountDownLatch 是什么以及如何使用,接下来通过具体的应用场景来看看 CountDownLatch 都可以在实际开发中起到怎样的作用。

    应用场景一:等待多个线程任务执行完成

    场景:如果需要等待多个线程执行完成后,才能进行下一步操作,就可以使用 CountDownLatch 来实现。通过创建一个 CountDownLatch 对象,并将计数器的值初始化为线程数(任务数),每个线程执行完成后,调用 countDown() 方法将计数器减一,主线程通过调用 await() 方法等待所有线程执行完成后执行下一步操作。

    示例:有一个主线程需要等待五个子任务(线程)都完成后再进行后续操作(汇总子任务的结果)。

    示例代码:

    /**
     * CountDownLatch 示例
     * @author 单程车票
     */
    public class CountDownLatchDemo {
        public static void main(String[] args) {
            // 任务数为5
            CountDownLatch countDownLatch = new CountDownLatch(5);
            for (int i = 0; i  {
                    try {
                        System.out.println("执行任务" + task + "业务");
                        try { TimeUnit.SECONDS.sleep(1);  } catch (InterruptedException e) {e.printStackTrace();}
                    } finally {
                        countDownLatch.countDown();
                    }
                }).start();
            }
            // 阻塞直到所有任务执行完成或超出超时时间(30min)
            try {
                countDownLatch.await(30, TimeUnit.MINUTES);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
            System.out.println("子线程任务完成,主线程合并子线程结果");
        }
    }
    

    示例结果:

    image.png

    应用场景二:等待外部资源初始化

    场景:当多个线程在执行前需要初始化某个系统组件或外部资源(如数据库连接池)时,可以使用 CountDownLatch 实现。通过主线程创建 CountDownLatch 对象,设定计数值为 1。初始化线程在完成资源初始化后调用 countDown() 方法,然后其他线程通过 await() 方法等待初始化完成后再开始使用资源。

    示例:有三个线程等待外部资源初始化线程执行完成后再执行各自线程的业务。

    示例代码:

    /**
     * CountDownLatch 示例
     * @author 单程车票
     */
    public class CountDownLatchDemo {
        public static void main(String[] args) {
            // 初始计数值为1
            CountDownLatch countDownLatch = new CountDownLatch(1);
            // 三个线程等待外部资源线程初始化后在执行
            for (int i = 0; i  {
                    // 阻塞直到外部资源初始化完成
                    try {
                        countDownLatch.await(30, TimeUnit.MINUTES);
                    } catch (InterruptedException e) {
                        e.printStackTrace();
                    }
                    System.out.println("外部资源初始化完成,执行任务" + task + "业务");
                }).start();
            }
            // 创建线程进行外部资源初始化
            new Thread(() -> {
                try {
                    System.out.println("初始化外部资源");
                    try { TimeUnit.SECONDS.sleep(1);  } catch (InterruptedException e) {e.printStackTrace();}
                } finally {
                    countDownLatch.countDown();
                }
            }).start();
        }
    }
    

    示例结果:

    image.png

    应用场景三:控制线程执行顺序

    场景:当需要保证多个线程按照特定的顺序执行时,可以通过 CountDownLatch 实现。主线程可以根据特定执行顺序创建多个 CountDownLatch 对象对应多个线程,每个 CountDownLatch 对象的初始计数值都为 1,保证某一时刻只有指定顺序的线程执行,执行完成后,调用下一个 CountDownLatch 对象的 countDown() 方法唤醒下一个指定顺序线程执行。

    示例:有三个线程,需要按照 3 1 2 的顺序依次执行各自线程的业务。

    示例代码:

    /**
     * CountDownLatch 示例
     * @author 单程车票
     */
    public class CountDownLatchDemo {
        public static void main(String[] args) {
            // 初始计数值为1
            CountDownLatch order1 = new CountDownLatch(1);
            CountDownLatch order2 = new CountDownLatch(1);
            CountDownLatch order3 = new CountDownLatch(1);
            // 三个线程按照 3 1 2 的顺序执行
            order3.countDown();  // 开启多个线程顺序执行
            // 创建线程1
            new Thread(() -> {
                // 阻塞直到线程3完成
                try {
                    order1.await(30, TimeUnit.MINUTES);
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
                try {
                    System.out.println("执行任务 1 的业务");
                } finally {
                    order2.countDown();
                }
            }).start();
            // 创建线程2
            new Thread(() -> {
                // 阻塞直到线程1完成
                try {
                    order2.await(30, TimeUnit.MINUTES);
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
                System.out.println("执行任务 2 的业务");
            }).start();
            // 创建线程3
            new Thread(() -> {
                // 阻塞直到主线程开启顺序执行
                try {
                    order3.await(30, TimeUnit.MINUTES);
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
                try {
                    System.out.println("执行任务 3 的业务");
                } finally {
                    order1.countDown();
                }
            }).start();
        }
    }
    

    示例结果:

    image.png

    CountDownLatch 的源码分析

    通过前两部分的内容可以了解到 CountDownLatch 的使用方式和应用场景了,可以看到 CountDownLatch 最为核心的两个方法是 countDown()await()。接下来通过源码分析来看看这两个方法是如何实现的。

    通过源码可以看到 CountDownLatch 其实是基于 AQS 实现的(想进一步了解 AQS 的,可以查看深入理解AbstractQueuedSynchronizer - 掘金 (juejin.cn)), CountDownLatch 内部通过一个静态内部类 Sync 继承 AQS 来实现构建同步锁。下面从 countDown()await() 这两个方法开始进行源码分析。

    核心方法一:await()

    await() 源码:

    image.png

    可以看到 await() 方法中调用了 Sync 的 acquireSharedInterruptibly() 方法,但是 Sync 中并没有实现该方法,所以实际上调用的是 AQS 中的 acquireSharedInterruptibly() 方法,进入方法:

    image.png

    方法中先判断线程是否被中断,如果被中断则抛出 InterruptedException 异常,通过调用 tryAcquireShared() 方法尝试抢占共享锁,这个方法是 AQS 的抽象方法由子类实现,这里实际上调用的就是 Sync 的 tryAcquireShared() 方法,进入方法:

    image.png

    该方法调用 getState() 方法获取当前计数器的值,并判断是否为 0,若为 0 则返回 1,不为 0 则返回 -1。回到上面的 tryAcquireShared() 中可以看到当计数器的值为 0 时则不需要进入等待队列,当计数器的值不为 0 时,则调用 doAcquireSharedInterruptibly())。进入方法:

    image.png

    深入方法代码可以分为以下几步:

    • 首先通过 addWaiter() 构建一个共享模式的 Node 并加入等待队列。
    • 然后通过无限循环,判断当前节点的前驱节点是否是头节点(前驱节点为头节点表示意味着具有尝试资源获取的机会)
      • 前驱节点是头节点,则不断地尝试获取资源(即调用 tryAcquireShared() 这个方法前面有提到,用于判断计数器的值是否为 0),计数值为 0,则表示获取资源成功,即线程可以运行,所以执行 setHeadAndPropagate() 将当前节点设置为新的头结点,并设置 p.next=null 等待 GC 回收。
      • 前驱结点不是头节点,则执行 shouldParkAfterFailedAcquire(p, node) && parkAndCheckInterrupt() 根据一定条件判断线程是否应该被阻塞并检查是否发生中断,等待后续唤醒。
    • 最后的 finally 通过标志 failed (表示是否获取资源失败),如果为 true,则执行 cancelAcquire() 方法取消对资源的获取,并移出等待队列。

    所以这个方法核心为通过无限循环不断地尝试获取共享资源,获取成功则将当前节点设置为头结点,获取失败则判断是否需要阻塞并检查是否被中断,如果最后获取失败,则放弃获取资源并移出等待队列。

    到这里就是 await() 方法的整个实现流程了,底层通过调用 AQS 的 doAcquireSharedInterruptibly() 方法以及 CountDownLatch 实现 AQS 的抽象方法 tryAcquireShared() 实现线程阻塞和唤醒。

    核心方法二:countDown()

    countDown() 源码:

    image.png

    可以看到 countDown() 方法中调用了 Sync 的 releaseShared() 方法,但是 Sync 中并没有实现该方法,所以实际上调用的是 AQS 中的 releaseShared() 方法,进入方法:

    image.png

    方法中调用 Sync 实现 AQS 的抽象方法 tryReleaseShared() 来进行判断,进入方法:

    image.png

    方法中判断当前计数器值是否为 0,是则返回 false 不做任何操作,也就是当计数器值为 0 时调用 CountDownLatch() 方法不会做任何操作。不是 0 则进行计数器值减一,并通过 CAS 操作更新计数器值,如果更新后的值为 0,则调用 AQS 内部的 doReleaseShared() 方法释放共享资源,否则除了更新计数器值之外不做任何操作。进入 doReleaseShared() 方法:

    image.png

    doReleaseShared() 方法的目的是在释放共享资源时,确保唤醒等待的线程,并通过循环和 CAS 操作来处理并发情况和头节点的变化。

    到这里就是 countDown() 方法的整个实现过程了,底层通过 CountDownLatch 实现 AQS 的抽象方法 tryReleaseShared() 采用 CAS 来完成计数器减一,并通过 AQS 的内部方法 doReleaseShared() 实现释放资源。

    相关文章

    JavaScript2024新功能:Object.groupBy、正则表达式v标志
    PHP trim 函数对多字节字符的使用和限制
    新函数 json_validate() 、randomizer 类扩展…20 个PHP 8.3 新特性全面解析
    使用HTMX为WordPress增效:如何在不使用复杂框架的情况下增强平台功能
    为React 19做准备:WordPress 6.6用户指南
    如何删除WordPress中的所有评论

    发布评论