深入理解 ForkJoinPool:入门、使用、原理

开发 前端
一般情况下,线程获取自己队列中的任务是 LIFO(Last Input First Output 后进先出)的方式,即类似于栈的操作方式。如下图所示,首先放入队列的时候先将任务 Push 进队列的头部(top),之后消费的时候在 pop 出队列头部(top)。

​大家好,我是树哥。

本文将从一个简单的例子出发,与大家解释为啥要有 ForkJoinPool 的存在。接着向大家介绍 ForkJoinPool 的基本信息及使用,最后讲解 ForkJoinPool 的基本原理。

诞生原因

对于线程池来说,我们经常使用的是 ThreadPoolExecutor,可以用来提升任务处理效率。一般情况下,我们使用 ThreadPoolExecutor 的时候,各个任务之间都是没有联系的。但在某些特殊情况下,我们处理的任务之间是有联系的,例如经典的 Fibonacci 算法就是其中一种情况。

对于 Fibonacci 数列来说,我们知道 F (N) = F (N-1) + F (N-2)。当前数值的结果,都依赖后面几个数值的结果。这时候如果用 ThreadPoolExecutor 貌似就无法解决问题了。虽然我们可以单线程的递归算法,则其计算速度较慢,并且无法进行并行计算,无法发挥 CPU 多核的优势。

ForkJoinPool 就是设计用来解决父子任务有依赖的并行计算问题的。 类似于快速排序、二分查找、集合运算等有父子依赖的并行计算问题,都可以用 ForkJoinPool 来解决。对于 Fibonacci 数列问题,如果用 ForkJoinPool 来实现,其实现代码为:

@Slf4j
public class ForkJoinDemo {
    // 1. 运行入口
    public static void main(String[] args) {
        int n = 20;

        // 为了追踪子线程名称,需要重写 ForkJoinWorkerThreadFactory 的方法
        final ForkJoinPool.ForkJoinWorkerThreadFactory factory = pool -> {
            final ForkJoinWorkerThread worker = ForkJoinPool.defaultForkJoinWorkerThreadFactory.newThread(pool);
            worker.setName("my-thread" + worker.getPoolIndex());
            return worker;
        };

        //创建分治任务线程池,可以追踪到线程名称
        ForkJoinPool forkJoinPool = new ForkJoinPool(4, factory, null, false);

        // 快速创建 ForkJoinPool 方法
        // ForkJoinPool forkJoinPool = new ForkJoinPool(4);

        //创建分治任务
        Fibonacci fibonacci = new Fibonacci(n);

        //调用 invoke 方法启动分治任务
        Integer result = forkJoinPool.invoke(fibonacci);
        log.info("Fibonacci {} 的结果是 {}", n, result);
    }
}

// 2. 定义拆分任务,写好拆分逻辑
@Slf4j
class Fibonacci extends RecursiveTask<Integer> {
    final int n;
    Fibonacci(int n) {
        this.n = n;
    }

    @Override
    public Integer compute() {
        //和递归类似,定义可计算的最小单元
        if (n <= 1) {
            return n;
        }
        // 想查看子线程名称输出的可以打开下面注释
        //log.info(Thread.currentThread().getName());

        Fibonacci f1 = new Fibonacci(n - 1);
        // 拆分成子任务
        f1.fork();
        Fibonacci f2 = new Fibonacci(n - 2);
        // f1.join 等待子任务执行结果
        return f2.compute() + f1.join();
    }
}
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.
  • 6.
  • 7.
  • 8.
  • 9.
  • 10.
  • 11.
  • 12.
  • 13.
  • 14.
  • 15.
  • 16.
  • 17.
  • 18.
  • 19.
  • 20.
  • 21.
  • 22.
  • 23.
  • 24.
  • 25.
  • 26.
  • 27.
  • 28.
  • 29.
  • 30.
  • 31.
  • 32.
  • 33.
  • 34.
  • 35.
  • 36.
  • 37.
  • 38.
  • 39.
  • 40.
  • 41.
  • 42.
  • 43.
  • 44.
  • 45.
  • 46.
  • 47.
  • 48.
  • 49.
  • 50.
  • 51.
  • 52.
  • 53.

如上面代码所示,我们定义了一个 Fibonacci 类,继承了 RecursiveTask 抽象类。在 Fibonacci 类中,我们定义了拆分逻辑,并调用了 join () 等待子线程执行结果。运行程序,我们会得到如下的结果:

17:29:10.336 [main] INFO tech.shuyi.javacodechip.forkjoinpool.ForkJoinDemo - Fibonacci 20 的结果是 6765
  • 1.

上面代码中提到的 fork () 和 join () 是 ForkJoinPool 提供的 API 接口,主要用于执行任务以及等待子线程结果。关于其详细用法,我们稍后会讲到。

除了用于处理父子任务有依赖的情形,其实 ForkJoinPool 也可以用于处理需要获取子任务执行结果的场景。 例如:我们要计算 1 到 1 亿的和,为了加快计算的速度,我们自然想到算法中的分治原理,将 1 亿个数字分成 1 万个任务,每个任务计算 1 万个数值的综合,利用 CPU 的并发计算性能缩短计算时间。

因为 ThreadPoolExecutor 也可以通过 Future 获取执行结果,因此 ThreadPoolExecutor 也是可行的。这时候我们有两种实现方式,一种是用 ThreadPoolExecutor 实现,一种是用 ForkJoinPool 实现。下面我们将这两种方式都实现一下,看看这两种实现方式有什么不同。

无论哪种实现方式,其大致思路都是:

  • 按照线程池里线程个数 N,将 1 亿个数划分成 N 等份,随后丢入线程池进行计算。
  • 每个计算任务使用 Future 接口获取计算结果,最后积加即可。

我们先使用 ThreadPoolExecutor 实现。

首先,定义一个 Calculator 接口,表示计算数字总和这个动作,如下所示。

public interface Calculator {
    /**
     * 把传进来的所有numbers 做求和处理
     *
     * @param numbers
     * @return 总和
     */
    long sumUp(long[] numbers);
}
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.
  • 6.
  • 7.
  • 8.
  • 9.

接着,我们定义一个使用 ThreadPoolExecutor 线程池实现的类,如下所示。

package tech.shuyi.javacodechip.forkjoinpool;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

public class ExecutorServiceCalculator implements Calculator {

    private int parallism;
    private ExecutorService pool;

    public ExecutorServiceCalculator() {
        // CPU的核心数 默认就用cpu核心数了
        parallism = Runtime.getRuntime().availableProcessors(); 
        pool = Executors.newFixedThreadPool(parallism);
    }

    // 1. 处理计算任务的线程
    private static class SumTask implements Callable<Long> {
        private long[] numbers;
        private int from;
        private int to;

        public SumTask(long[] numbers, int from, int to) {
            this.numbers = numbers;
            this.from = from;
            this.to = to;
        }

        @Override
        public Long call() {
            long total = 0;
            for (int i = from; i <= to; i++) {
                total += numbers[i];
            }
            return total;
        }
    }

    // 2. 核心业务逻辑实现
    @Override
    public long sumUp(long[] numbers) {
        List<Future<Long>> results = new ArrayList<>();

        // 2.1 数字拆分
        // 把任务分解为 n 份,交给 n 个线程处理   4核心 就等分成4份呗
        // 然后把每一份都扔个一个SumTask线程 进行处理
        int part = numbers.length / parallism;
        for (int i = 0; i < parallism; i++) {
            int from = i * part; //开始位置
            int to = (i == parallism - 1) ? numbers.length - 1 : (i + 1) * part - 1; //结束位置

            //扔给线程池计算
            results.add(pool.submit(new SumTask(numbers, from, to)));
        }

        // 2.2 阻塞等待结果
        // 把每个线程的结果相加,得到最终结果 get()方法 是阻塞的
        // 优化方案:可以采用CompletableFuture来优化  JDK1.8的新特性
        long total = 0L;
        for (Future<Long> f : results) {
            try {
                total += f.get();
            } catch (Exception ignore) {
            }
        }

        return total;
    }
}
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.
  • 6.
  • 7.
  • 8.
  • 9.
  • 10.
  • 11.
  • 12.
  • 13.
  • 14.
  • 15.
  • 16.
  • 17.
  • 18.
  • 19.
  • 20.
  • 21.
  • 22.
  • 23.
  • 24.
  • 25.
  • 26.
  • 27.
  • 28.
  • 29.
  • 30.
  • 31.
  • 32.
  • 33.
  • 34.
  • 35.
  • 36.
  • 37.
  • 38.
  • 39.
  • 40.
  • 41.
  • 42.
  • 43.
  • 44.
  • 45.
  • 46.
  • 47.
  • 48.
  • 49.
  • 50.
  • 51.
  • 52.
  • 53.
  • 54.
  • 55.
  • 56.
  • 57.
  • 58.
  • 59.
  • 60.
  • 61.
  • 62.
  • 63.
  • 64.
  • 65.
  • 66.
  • 67.
  • 68.
  • 69.
  • 70.
  • 71.
  • 72.
  • 73.

如上面代码所示,我们实现了一个计算单个任务的类 SumTask,在该类中对数值进行累加。其次,我们在 sumUp () 方法中,对 1 亿的数字进行拆分,接着扔给线程池计算,最后阻塞等待计算结果,最终累加起来。

我们运行上面的代码,可以得到顺利得到最终结果,如下所示。

耗时:10ms
结果为:50000005000000
  • 1.
  • 2.

接着我们使用 ForkJoinPool 来实现。

我们首先实现 SumTask 继承 RecursiveTask 抽象类,并在 compute () 方法中定义拆分逻辑及计算。最后在爱 sumUp () 方法中调用 pool 方法进行计算,代码如下所示。

public class ForkJoinCalculator implements Calculator {

    private ForkJoinPool pool;

    // 1. 定义计算逻辑
    private static class SumTask extends RecursiveTask<Long> {
        private long[] numbers;
        private int from;
        private int to;

        public SumTask(long[] numbers, int from, int to) {
            this.numbers = numbers;
            this.from = from;
            this.to = to;
        }

        //此方法为ForkJoin的核心方法:对任务进行拆分  拆分的好坏决定了效率的高低
        @Override
        protected Long compute() {

            // 当需要计算的数字个数小于6时,直接采用for loop方式计算结果
            if (to - from < 6) {
                long total = 0;
                for (int i = from; i <= to; i++) {
                    total += numbers[i];
                }
                return total;
            } else { 
                // 否则,把任务一分为二,递归拆分(注意此处有递归)到底拆分成多少分 需要根据具体情况而定
                int middle = (from + to) / 2;
                SumTask taskLeft = new SumTask(numbers, from, middle);
                SumTask taskRight = new SumTask(numbers, middle + 1, to);
                taskLeft.fork();
                taskRight.fork();
                return taskLeft.join() + taskRight.join();
            }
        }
    }

    public ForkJoinCalculator() {
        // 也可以使用公用的线程池 ForkJoinPool.commonPool():
        // pool = ForkJoinPool.commonPool()
        pool = new ForkJoinPool();
    }

    @Override
    public long sumUp(long[] numbers) {
        Long result = pool.invoke(new SumTask(numbers, 0, numbers.length - 1));
        pool.shutdown();
        return result;
    }
}
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.
  • 6.
  • 7.
  • 8.
  • 9.
  • 10.
  • 11.
  • 12.
  • 13.
  • 14.
  • 15.
  • 16.
  • 17.
  • 18.
  • 19.
  • 20.
  • 21.
  • 22.
  • 23.
  • 24.
  • 25.
  • 26.
  • 27.
  • 28.
  • 29.
  • 30.
  • 31.
  • 32.
  • 33.
  • 34.
  • 35.
  • 36.
  • 37.
  • 38.
  • 39.
  • 40.
  • 41.
  • 42.
  • 43.
  • 44.
  • 45.
  • 46.
  • 47.
  • 48.
  • 49.
  • 50.
  • 51.
  • 52.

运行上面的代码,结果为:

耗时:860ms
结果为:50000005000000
  • 1.
  • 2.

对比 ThreadPoolExecutor 和 ForkJoinPool 这两者的实现,可以发现它们都有任务拆分的逻辑,以及最终合并数值的逻辑。但 ForkJoinPool 相比 ThreadPoolExecutor 来说,做了一些实现上的封装,例如:

  • 不用手动去获取子任务的结果,而是使用 join () 方法直接获取结果。
  • 将任务拆分的逻辑,封装到 RecursiveTask 实现类中,而不是裸露在外。

因此对于没有父子任务依赖,但是希望获取到子任务执行结果的并行计算任务,也可以使用 ForkJoinPool 来实现。在这种情况下,使用 ForkJoinPool 实现更多是代码实现方便,封装做得更加好。

使用指南

使用 ForkJoinPool 来进行并行计算,主要分为两步:

  • 定义 RecursiveTask 或 RecursiveAction 的任务子类。
  • 初始化线程池及计算任务,丢入线程池处理,取得处理结果。

首先,我们需要定义一个 RecursiveTask 或 RecursiveAction 的子类,然后再该类的 compute () 方法中定义拆分逻辑和计算逻辑。 这两个抽象类的区别在于:前者有返回值,后者没有返回值。例如前面讲到的 1 到 1 亿的叠加问题,其定义的 RecursiveTask 实现类 SumTask 的代码如下:

private static class SumTask extends RecursiveTask<Long> {
    private long[] numbers;
    private int from;
    private int to;

    public SumTask(long[] numbers, int from, int to) {
        this.numbers = numbers;
        this.from = from;
        this.to = to;
    }

    @Override
    protected Long compute() {
        // 1. 定义拆分退出逻辑
        if (to - from < 6) {
            long total = 0;
            for (int i = from; i <= to; i++) {
                total += numbers[i];
            }
            return total;
        } else {
        // 2. 定义计算逻辑
            int middle = (from + to) / 2;
            SumTask taskLeft = new SumTask(numbers, from, middle);
            SumTask taskRight = new SumTask(numbers, middle + 1, to);
            taskLeft.fork();
            taskRight.fork();
            return taskLeft.join() + taskRight.join();
        }
    }
}
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.
  • 6.
  • 7.
  • 8.
  • 9.
  • 10.
  • 11.
  • 12.
  • 13.
  • 14.
  • 15.
  • 16.
  • 17.
  • 18.
  • 19.
  • 20.
  • 21.
  • 22.
  • 23.
  • 24.
  • 25.
  • 26.
  • 27.
  • 28.
  • 29.
  • 30.
  • 31.

对于 compute () 方法的实现,核心是想清楚:怎么拆分成子任务?什么时候结束拆分?

接着,初始化 ForkJoinPool 线程池,初始化计算任务,最后将任务丢入线程池中。

// 初始化线程池
public ForkJoinCalculator() {
    pool = new ForkJoinPool();
}
// 初始化计算任务,将任务丢入线程池
public long sumUp(long[] numbers) {
    Long result = pool.invoke(new SumTask(numbers, 0, numbers.length - 1));
    pool.shutdown();
    return result;
}
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.
  • 6.
  • 7.
  • 8.
  • 9.
  • 10.

通过上面两步操作,我们就完成了一个 ForkJoinPool 任务代码的编写。

原理解析

ForkJoinPool 的设计思想是分治算法,即将任务不断拆分(fork)成更小的任务,最终再合并(join)各个任务的计算结果。通过这种方式,可以充分利用 CPU 资源,再结合工作窃取算法(worksteal)整体提高执行效率。其简单的流程如下图:

图片

图片来源于思否用户「日拱一兵」

从图中可以看出 ForkJoinPool 要先执行完子任务才能执行上一层任务。因此 ForkJoinPool 最适合有父子任务依赖的场景,其次就是需要获取子任务执行结果的场景。比如:Fibonacci 数列、快速排序、二分查找等。

源码实现

ForkJoinPool 的主要实现类为:ForkJoinPool 和 ForkJoinTask 抽象类。

ForkJoinTask 实现了 Future 接口,可以用于获取处理结果。ForkJoinTask 有两个抽象子类:RecursiveAction 和 RecursiveTask 抽象类,其区别在于前者没有返回值,后者有返回值,其类图如下所示。

图片

图片来源于思否用户「日拱一兵」

ForkJoinPool 则是具体的逻辑实现,由于暂时没有应用场景,就不了解这么深了,这里就不深入解析了。

感兴趣的朋友可以参考这篇文章:ForkJoinPool 大型图文现场(一阅到底 vs 直接收藏) - SegmentFault 思否。

窃取算法

我们知道 ForkJoinPool 的父子任务之间是有依赖关系的,那么 ForkJoinPool 是如何实现的呢?答案是:利用不同任务队列执行。 在 ForkJoinPool 中有一个数组形式的成员变量 workQueue[],其对应一个队列数组,每个队列对应一个消费线程。丢入线程池的任务,根据特定规则进行转发。

图片

图片来源于思否用户「日拱一兵」

这样就有一个问题:有些队列可能任务比较多,有些队列任务比较少,这样就会导致不同线程负载不一样,整体不够高效,怎么办呢?

答案是:利用窃取算法,空闲的线程从尾部去消费其他队列的任务。

一般情况下,线程获取自己队列中的任务是 LIFO(Last Input First Output 后进先出)的方式,即类似于栈的操作方式。如下图所示,首先放入队列的时候先将任务 Push 进队列的头部(top),之后消费的时候在 pop 出队列头部(top)。

图片

图片来源于思否用户「日拱一兵」

而当某个线程对应的队列空闲时,该线程则去队列的底部(base)窃取(poll)任务到自己的队列,然后进行消费。那问题来了:为什么不从头部(top)获取任务,而要从底部(base)获取任务呢? 那是为了避免冲突!如果两个线程同时从顶部获取任务,那就会有多线程的冲突问题,就需要加锁操作,从而降低了执行效率。

参考资料

  • 线程池 ForkJoinPool 简介 - 老 K 的 Java 博客
  • ForkJoinPool 大型图文现场(一阅到底 vs 直接收藏) - SegmentFault 思否
  • (2 条消息) 介绍 ForkJoinPool 的适用场景,实现原理_想跑步丶小胖子的博客 - CSDN 博客_forkjoinpool
  • ForkJoinPool 使用场景・ARLOOR​
责任编辑:武晓燕 来源: 树哥聊编程
相关推荐

2022-08-22 08:04:25

Spring事务Atomicity

2021-03-10 10:55:51

SpringJava代码

2022-11-04 09:43:05

Java线程

2022-09-05 08:39:04

kubernetesk8s

2024-03-12 00:00:00

Sora技术数据

2024-11-01 08:57:07

2020-08-10 18:03:54

Cache存储器CPU

2024-04-15 00:00:00

技术Attention架构

2020-03-26 16:40:07

MySQL索引数据库

2023-09-19 22:47:39

Java内存

2022-01-14 12:28:18

架构OpenFeign远程

2019-07-01 13:34:22

vue系统数据

2020-03-17 08:36:22

数据库存储Mysql

2020-11-04 15:35:13

Golang内存程序员

2022-09-05 22:22:00

Stream操作对象

2023-10-13 13:30:00

MySQL锁机制

2023-01-16 18:32:15

架构APNacos

2009-11-16 17:20:04

PHP多维数组排序

2020-07-03 17:20:07

Redux前端代码

2017-05-04 16:35:45

点赞
收藏

51CTO技术栈公众号