fork/join作为一个并发框架在jdk7的时候就加入到了我们的java并发包java.util.concurrent中,并且在java 8 的lambda并行流中充当着底层框架的角色。
这样一个优秀的框架设计,我自己想了解一下它的底层代码是如何实现的,所以我尝试的去阅读了JDK相关的源码。下面我打算分享一下阅读完之后的心得~。
1、fork/join的设计思路
了解一个框架的第一件事,就是先了解别人的设计思路!
fork/join大体的执行过程就如上图所示,先把一个大任务分解(fork)成许多个独立的小任务,然后起多线程并行去处理这些小任务。处理完得到结果后再进行合并(join)就得到我们的最终结果。
显而易见的这个框架是借助了现代计算机多核的优势并行去处理数据。这看起来好像没有什么特别之处,这个套路很多人都会,并且工作中也会经常运用~。其实fork/join的最特别之处在于它还运用了一种叫work-stealing(工作窃取)的算法,这种算法的设计思路在于把分解出来的小任务放在多个双端队列中,而线程在队列的头和尾部都可获取任务。
当有线程把当前负责队列的任务处理完之后,它还可以从那些还没有处理完的队列的尾部窃取任务来处理,这连线程的空余时间也充分利用了!。
work-stealing原理图如下:
2、实现fork/join 定义了哪些角色?
了解设计原理,这仅仅是第一步!要了解别人整个的实现思路, 还需要了解别人为了实现这个框架定义了哪些角色,并了解这些角色的职责范围是什么的。因为知道谁负责了什么,谁做什么,这样整个逻辑才能串起来!在JAVA里面角色是以类的形式定义的,而了解类的行为最直接的方式就是看定义的公共方法~。
这里介绍JDK里面与fork/join相关的主要几个类:
- ForkJoinPool:充当fork/join框架里面的管理者,最原始的任务都要交给它才能处理。它负责控制整个fork/join有多少个workerThread,workerThread的创建,激活都是由它来掌控。它还负责workQueue队列的创建和分配,每当创建一个workerThread,它负责分配相应的workQueue。然后它把接到的活都交给workerThread去处理,它可以说是整个frok/join的容器。
- ForkJoinWorkerThread:fork/join里面真正干活的"工人",本质是一个线程。里面有一个ForkJoinPool.WorkQueue的队列存放着它要干的活,接活之前它要向ForkJoinPool注册(registerWorker),拿到相应的workQueue。然后就从workQueue里面拿任务出来处理。它是依附于ForkJoinPool而存活,如果ForkJoinPool的销毁了,它也会跟着结束。
- ForkJoinPool.WorkQueue: 双端队列就是它,它负责存储接收的任务。
- ForkJoinTask:代表fork/join里面任务类型,我们一般用它的两个子类RecursiveTask、RecursiveAction。这两个区别在于RecursiveTask任务是有返回值,RecursiveAction没有返回值。任务的处理逻辑包括任务的切分都集中在compute()方法里面。
3、fork/join初始化时做了什么
大到一个系统,小到一个框架,初始化工作往往是体现逻辑的一个重要地方!因为这是开始的地方,后面的逻辑会有依赖!所以把初始化看明白了,后面很多逻辑就容易理解多了。
下面上一段代码,(ps:这段代码是在网上找到的,并做了一小部分的修改)
- public class CountTask extends RecursiveTask<Integer> {
- private static final int THRESHOLD = 2; //阀值
- private int start;
- private int end;
- public CountTask(int start,int end){
- this.start = start;
- this.end = end;
- }
- @Override
- protected Integer compute() {
- int sum = 0;
- boolean canCompute = (end - start) <= THRESHOLD;
- if(canCompute){
- for(int i = start; i <= end; i++){
- sum += i;
- }
- }else{
- int middle = (start + end) / 2;
- CountTask leftTask = new CountTask(start,middle);
- CountTask rightTask = new CountTask(middle + 1,end);
- //执行子任务
- leftTask.fork();
- rightTask.fork();
- //等待子任务执行完,并得到其结果
- Integer rightResult = rightTask.join();
- Integer leftResult = leftTask.join();
- //合并子任务
- sum = leftResult + rightResult;
- }
- return sum;
- }
- public static void main(String[] args) throws ExecutionException, InterruptedException {
- ForkJoinPool forkJoinPool = new ForkJoinPool();
- CountTask countTask = new CountTask(1,200);
- ForkJoinTask<Integer> forkJoinTask = forkJoinPool.submit(countTask);
- System.out.println(forkJoinTask.get());
- }
- }
代码的执行过程解释起来也是很简单就是把[1,200],分成[1,100],[101,200],然后再对每个部分进行一个递归分解最终分解成[1,2],[3,4],[5,6]…..[199,200]独立的小任务,然后两两求和合并。
其实显然易见负责整个fork/join初始化工作的就是ForkJoinPool!初始化代码就是那一行 ForkJoinPool forkJoinPool = new ForkJoinPool(),点进去查看源码。
- ForkJoinPool forkJoinPool = new ForkJoinPool();
- //最终调用到这段代码
- public ForkJoinPool(int parallelism,
- ForkJoinWorkerThreadFactory factory,
- UncaughtExceptionHandler handler,
- boolean asyncMode) {
- this(checkParallelism(parallelism), //并行度,当前机器的cpu核数
- checkFactory(factory), //工作线程创建工厂
- handler, //异常处理handler
- asyncMode ? FIFO_QUEUE : LIFO_QUEUE, //任务队列出队模式 异步:先进先出,同步:后进先出
- "ForkJoinPool-" + nextPoolId() + "-worker-");
- checkPermission();
- }
看完初始化的代码我们可以知道原来创建ForkJoinPool创建workerThread的工作都是统一由一个叫ForkJoinWorkerThreadFactory的工厂去创建,创建出来的线程都有一个统一的前辍名称"ForkJoinPool-" + nextPoolId() + "-worker-".队列出队模式是LIFO(后进先出),那这样后面的入队的任务是会被先处理的。
所以上面提到对代码做了一些修改就是先处理rightTask,再处理leftTask。这其实是对代码的一种优化!
- //执行子任务
- leftTask.fork();
- rightTask.fork();
- Integer rightResult = rightTask.join();
- Integer leftResult = leftTask.join();
4、任务的提交逻辑?
fork/join其实大部分逻辑处理操作都集中在提交任务和处理任务这两块,了解任务的提交基本上后面就很容易理解了。
fork/join提交任务主要分为两种:
第一种:第一次提交到forkJoinPool
- ForkJoinTask<Integer> forkJoinTask = forkJoinPool.submit(countTask);
第二种:任务切分之后的提交
- leftTask.fork();
- rightTask.fork();
提交到forkJoinPool :
代码调用路径 submit(ForkJoinTask<T> task) -> externalPush(ForkJoinTask<?> task) -> externalSubmit(ForkJoinTask<?> task)
下面贴上externalSubmit的详细代码,着重留意注释的部分。
- private void externalSubmit(ForkJoinTask<?> task) {
- int r; // initialize caller's probe
- if ((r = ThreadLocalRandom.getProbe()) == 0) {
- ThreadLocalRandom.localInit();
- r = ThreadLocalRandom.getProbe();
- }
- for (;;) { //采用循环入队的方式
- WorkQueue[] ws; WorkQueue q; int rs, m, k;
- boolean move = false;
- if ((rs = runState) < 0) {
- tryTerminate(false, false); // help terminate
- throw new RejectedExecutionException();
- }
- else if ((rs & STARTED) == 0 || // initialize 初始化操作
- ((ws = workQueues) == null || (m = ws.length - 1) < 0)) {
- int ns = 0;
- rs = lockRunState();
- try {
- if ((rs & STARTED) == 0) {
- U.compareAndSwapObject(this, STEALCOUNTER, null,
- new AtomicLong());
- // create workQueues array with size a power of two
- int p = config & SMASK; // ensure at least 2 slots //config就是cpu的核数
- int n = (p > 1) ? p - 1 : 1;
- n |= n >>> 1; n |= n >>> 2; n |= n >>> 4;
- n |= n >>> 8; n |= n >>> 16; n = (n + 1) << 1; //算出workQueues的大小n,n一定是2的次方数
- workQueues = new WorkQueue[n]; //初始化队列,然后跳到最外面的循环继续把任务入队~
- ns = STARTED;
- }
- } finally {
- unlockRunState(rs, (rs & ~RSLOCK) | ns);
- }
- }
- else if ((q = ws[k = r & m & SQMASK]) != null) { //选中了一个一个非空队列
- if (q.qlock == 0 && U.compareAndSwapInt(q, QLOCK, 0, 1)) { //利用cas操作加锁成功!
- ForkJoinTask<?>[] a = q.array;
- int s = q.top;
- boolean submitted = false; // initial submission or resizing
- try { // locked version of push
- if ((a != null && a.length > s + 1 - q.base) ||
- (a = q.growArray()) != null) {
- int j = (((a.length - 1) & s) << ASHIFT) + ABASE; //计算出任务在队列中的位置
- U.putOrderedObject(a, j, task); //把任务放在队列中
- U.putOrderedInt(q, QTOP, s + 1); //更新一次存放的位置
- submitted = true;
- }
- } finally {
- U.compareAndSwapInt(q, QLOCK, 1, 0); //利用cas操作释放锁!
- }
- if (submitted) {
- signalWork(ws, q);
- return; //任务入队成功了!跳出循环!
- }
- }
- move = true; // move on failure
- }
- else if (((rs = runState) & RSLOCK) == 0) { // create new queue 选中的队列是空,初始化完队列,然后继续入队!
- q = new WorkQueue(this, null);
- q.hint = r;
- q.config = k | SHARED_QUEUE;
- q.scanState = INACTIVE;
- rs = lockRunState(); // publish index
- if (rs > 0 && (ws = workQueues) != null &&
- k < ws.length && ws[k] == null)
- ws[k] = q; // else terminated
- unlockRunState(rs, rs & ~RSLOCK);
- }
- else
- move = true; // move if busy
- if (move)
- r = ThreadLocalRandom.advanceProbe(r);
- }
- }
通过对externalSubmit方法的代码进行分析,我们知道了第一次提交任务给forkJoinPool时是在无限循环for (;;)中入队。第一步先检查workQueues是不是还没有创建,如果没有,则进行创建。之后跳到外层for循环并随机选取workQueues里面一个队列,并判断队列是否已创建。没有创建,则进行创建!后又跳到外层for循环直到选到一个非空队列并且加锁成功!这样最后才把任务入队~。
所以我们知道fork/join的任务队列workQueues并不是初始化的时候就创建好了,而是在有任务提交的时候才创建!并且每次入队时都需要利用cas操作来进行加锁和释放锁!
任务切分之后的提交:
- public final ForkJoinTask<V> fork() {
- Thread t;
- if ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread)
- ((ForkJoinWorkerThread)t).workQueue.push(this); //workerThread直接入自己的workQueue
- else
- ForkJoinPool.common.externalPush(this);
- return this;
- }
- final void externalPush(ForkJoinTask<?> task) {
- WorkQueue[] ws; WorkQueue q; int m;
- int r = ThreadLocalRandom.getProbe();
- int rs = runState;
- if ((ws = workQueues) != null && (m = (ws.length - 1)) >= 0 &&
- (q = ws[m & r & SQMASK]) != null && r != 0 && rs > 0 &&
- U.compareAndSwapInt(q, QLOCK, 0, 1)) { //随机选取了一个非空队列,并且加锁成功!下面是普通的入队过程~
- ForkJoinTask<?>[] a; int am, n, s;
- if ((a = q.array) != null &&
- (aam = a.length - 1) > (n = (s = q.top) - q.base)) {
- int j = ((am & s) << ASHIFT) + ABASE;
- U.putOrderedObject(a, j, task);
- U.putOrderedInt(q, QTOP, s + 1);
- U.putIntVolatile(q, QLOCK, 0);
- if (n <= 1)
- signalWork(ws, q);
- return; //结束方法
- }
- U.compareAndSwapInt(q, QLOCK, 1, 0); //一定要释放锁!
- }<br> //这个就是上面的externalSummit方法,逻辑是一样的~
- externalSubmit(task);
- }
从代码中我们知道了提交一个fork任务的过程和第一次提交到forkJoinPool的过程是大同小异的。主要区分了提交任务的线程是不是workerThread,如果是,任务直接入workerThread当前的workQueue,不是则尝试选中一个workQueue q。如果q非空并且加锁成功则进行入队,否则执行与第一次任务提交到forkJoinPool差不多的逻辑~。
5、任务的消费
提交到任务的最终目的,是为了消费任务并最终获取到我们想要的结果。介绍任务消费之前我们先了解一个我们的任务ForkJoinTask有哪些关键属性和方法。
- /** The run status of this task */
- volatile int status; // accessed directly by pool and workers
- static final int DONE_MASK = 0xf0000000; // mask out non-completion bits
- static final int NORMAL = 0xf0000000; // must be negative
- static final int CANCELLED = 0xc0000000; // must be < NORMAL
- static final int EXCEPTIONAL = 0x80000000; // must be < CANCELLED
- static final int SIGNAL = 0x00010000; // must be >= 1 << 16
- static final int SMASK = 0x0000ffff; // short bits for tags
- final int doExec() { //任务的执行入口
- int s; boolean completed;
- if ((s = status) >= 0) {
- try {
- completed = exec();
- } catch (Throwable rex) {
- return setExceptionalCompletion(rex);
- }
- if (completed)
- s = setCompletion(NORMAL);
- }
- return s;
- }
再看一下RecursiveTask的定义
- public abstract class RecursiveTask<V> extends ForkJoinTask<V> {
- private static final long serialVersionUID = 5232453952276485270L;
- /**
- * The result of the computation.
- */
- V result;
- /**
- * The main computation performed by this task.
- * @return the result of the computation
- */
- protected abstract V compute(); //我们实现的处理逻辑
- public final V getRawResult() { //获取返回计算结果
- return result;
- }
- protected final void setRawResult(V value) {
- result = value;
- }
- /**
- * Implements execution conventions for RecursiveTask.
- */
- protected final boolean exec() {
- result = compute(); //存储计算结果
- return true;
- }
- }
在代码中我们看到任务的真正执行链路是 doExec -> exec -> compute -> 最后设置status 和 result。既然定义状态status并且还是volatile类型我们可以推断出workerThread在获取到执行任务之后都会先判断status是不是已完成或者异常状态,才决定要不要处理该任务。
下面看一下任务真正的处理逻辑代码!
- Integer rightResult = rightTask.join()
- public final V join() {
- int s;
- if ((s = doJoin() & DONE_MASK) != NORMAL)
- reportException(s);
- return getRawResult();
- }
- //执行处理前先判断staus是不是已完成,如果完成了就直接返回
- //因为这个任务可能被其它线程窃取过去处理完了
- private int doJoin() {
- int s; Thread t; ForkJoinWorkerThread wt; ForkJoinPool.WorkQueue w;
- return (s = status) < 0 ? s :
- ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread) ?
- (w = (wt = (ForkJoinWorkerThread)t).workQueue).
- tryUnpush(this) && (s = doExec()) < 0 ? s :
- wt.pool.awaitJoin(w, this, 0L) :
- externalAwaitDone();
- }
代码的调用链是从上到下。整体处理逻辑如下:
线程是workerThread:
先判断任务是否已经处理完成,任务完成直接返回,没有则直接尝试出队tryUnpush(this) 然后执行任务处理doExec()。如果没有出队成功或者处理成功,则执行wt.pool.awaitJoin(w, this, 0L)。wt.pool.awaitJoin(w, this, 0L)的处理逻辑简单来说也是在一个for(;;)中不断的轮询任务的状态是不是已完成,完成就直接退出方法。否就继续尝试出队处理。直到任务完成或者超时为止。
线程不是workerThread:
直接进行入externalAwaitDone()
- private int externalAwaitDone() {
- int s = ((this instanceof CountedCompleter) ? // try helping
- ForkJoinPool.common.externalHelpComplete(
- (CountedCompleter<?>)this, 0) :
- ForkJoinPool.common.tryExternalUnpush(this) ? doExec() : 0);
- if (s >= 0 && (s = status) >= 0) {
- boolean interrupted = false;
- do {
- if (U.compareAndSwapInt(this, STATUS, s, s | SIGNAL)) {
- synchronized (this) {
- if (status >= 0) {
- try {
- wait(0L);
- } catch (InterruptedException ie) {
- interrupted = true;
- }
- }
- else
- notifyAll();
- }
- }
- } while ((s = status) >= 0);
- if (interrupted)
- Thread.currentThread().interrupt();
- }
- return s;
externalAwaitDone的处理逻辑其实也比较简单,当前线程自己先尝试把任务出队ForkJoinPool.common.tryExternalUnpush(this) ? doExec()然后处理掉,如果不成功就交给workerThread去处理,然后利用object/wait的经典方法去监听任务status的状态变更。
6、任务的窃取
一直说fork/join的任务是work-stealing(工作窃取),那任务究竟是怎么被窃取的呢。我们分析一下任务是由workThread来窃取的,workThread是一个线程。线程的所有逻辑都是由run()方法执行,所以任务的窃取逻辑一定在run()方法中可以找到!
- public void run() { //线程run方法
- if (workQueue.array == null) { // only run once
- Throwable exception = null;
- try {
- onStart();
- pool.runWorker(workQueue); //在这里处理任务队列!
- } catch (Throwable ex) {
- exexception = ex;
- } finally {
- try {
- onTermination(exception);
- } catch (Throwable ex) {
- if (exception == null)
- exexception = ex;
- } finally {
- pool.deregisterWorker(this, exception);
- }
- }
- }
- }
- /**
- * Top-level runloop for workers, called by ForkJoinWorkerThread.run.
- */
- final void runWorker(WorkQueue w) {
- w.growArray(); // allocate queue 进行队列的初始化
- int seed = w.hint; // initially holds randomization hint
- int r = (seed == 0) ? 1 : seed; // avoid 0 for xorShift
- for (ForkJoinTask<?> t;;) { //又是无限循环处理任务!
- if ((t = scan(w, r)) != null) //在这里获取任务!
- w.runTask(t);
- else if (!awaitWork(w, r))
- break;
- r ^= r << 13; r ^= r >>> 17; r ^= r << 5; // xorshift
- }
- }
其实只要看下面的英文注释就知道了大概scan(WorkQueue w, int r)就是用来窃取任务的!
- /**
- * Scans for and tries to steal a top-level task. Scans start at a
- * random location, randomly moving on apparent contention,
- * otherwise continuing linearly until reaching two consecutive
- * empty passes over all queues with the same checksum (summing
- * each base index of each queue, that moves on each steal), at
- * which point the worker tries to inactivate and then re-scans,
- * attempting to re-activate (itself or some other worker) if
- * finding a task; otherwise returning null to await work. Scans
- * otherwise touch as little memory as possible, to reduce
- * disruption on other scanning threads.
- *
- * @param w the worker (via its WorkQueue)
- * @param r a random seed
- * @return a task, or null if none found
- */
- private ForkJoinTask<?> scan(WorkQueue w, int r) {
- WorkQueue[] ws; int m;
- if ((ws = workQueues) != null && (m = ws.length - 1) > 0 && w != null) {
- int ss = w.scanState; // initially non-negative
- for (int origin = r & m, k = origin, oldSum = 0, checkSum = 0;;) {
- WorkQueue q; ForkJoinTask<?>[] a; ForkJoinTask<?> t;
- int b, n; long c;
- if ((q = ws[k]) != null) { //随机选中了非空队列 q
- if ((n = (b = q.base) - q.top) < 0 &&
- (a = q.array) != null) { // non-empty
- long i = (((a.length - 1) & b) << ASHIFT) + ABASE; //从尾部出队,b是尾部下标
- if ((t = ((ForkJoinTask<?>)
- U.getObjectVolatile(a, i))) != null &&
- q.base == b) {
- if (ss >= 0) {
- if (U.compareAndSwapObject(a, i, t, null)) { //利用cas出队
- q.base = b + 1;
- if (n < -1) // signal others
- signalWork(ws, q);
- return t; //出队成功,成功窃取一个任务!
- }
- }
- else if (oldSum == 0 && // try to activate 队列没有激活,尝试激活
- w.scanState < 0)
- tryRelease(c = ctl, ws[m & (int)c], AC_UNIT);
- }
- if (ss < 0) // refresh
- ss = w.scanState;
- r ^= r << 1; r ^= r >>> 3; r ^= r << 10;
- origin = k = r & m; // move and rescan
- oldSum = checkSum = 0;
- continue;
- }
- checkSum += b;
- }<br> //kk = k + 1表示取下一个队列 如果(k + 1) & m == origin表示 已经遍历完所有队列了
- if ((k = (k + 1) & m) == origin) { // continue until stable
- if ((ss >= 0 || (ss == (ss = w.scanState))) &&
- oldSum == (oldSum = checkSum)) {
- if (ss < 0 || w.qlock < 0) // already inactive
- break;
- int ns = ss | INACTIVE; // try to inactivate
- long nc = ((SP_MASK & ns) |
- (UC_MASK & ((c = ctl) - AC_UNIT)));
- w.stackPred = (int)c; // hold prev stack top
- U.putInt(w, QSCANSTATE, ns);
- if (U.compareAndSwapLong(this, CTL, c, nc))
- ss = ns;
- else
- w.scanState = ss; // back out
- }
- checkSum = 0;
- }
- }
- }
- return null;
- }
所以我们知道任务的窃取从workerThread运行的那一刻就已经开始了!先随机选中一条队列看能不能窃取到任务,取不到则窃取下一条队列,直接遍历完一遍所有的队列,如果都窃取不到就返回null。
以上就是我阅读fork/join源码之后总结出来一些心得,写了那么多我觉得也只是描述了个大概而已,真正详细有用的东西还需要仔细去阅读里面的代码才行。如果大家有兴趣的话,不妨也去尝试一下吧-。-~