假设有大量的CPU密集型计算任务,比如计算1-100的总和,普通的写法是单线程循环累加1-100,这样固然可以。不过 Doug lea 觉得太慢了,于是设计了 ForkJoinPool 。
ForkJoinPool 设计理念是分治思想,将大任务拆分成多个小任务,然后再多线程去执行小任务,最后再将小任务的结果合在一起得到最终结果。
啊?如果是多线程执行任务的话,那我在拆分任务的时候是new一个有返回值的线程去执行不一样的吗,为什么还要大费周章写个 ForkJoinPool?
因为ForkJoinPool本质上是一个线程池,线程池最大的优势就是可以线程复用,无需创建大量线程浪费系统资源。用线程池分的时候只需要向线程池提交一个任务就可以了。
啊?如果是线程复用的话,那我用普通的线程池不可以吗,为什么要用 ForkJoinPool?
因为如果用普通线程池的话,大任务的拆分以及小任务结果的归并这些操作的具体细节都需要自己去控制,搞不好还会出现死锁。
举个例子,线程池最大线程数为2,1-100,还是用分治去拆,最小粒度是10。
首先1-100被拆成1-50,51-100分别被两个线程执行,然后1-50,51-100再拆成1-25,26-50,51-75,76-100四个子任务,但此时已经没有多余的线程可以去执行了,所以四个子任务入队阻塞队列等待空闲线程去执行,但此时线程池中的线程又在阻塞等待拆出来子任务的结果。因而产生了死锁,两边互相等待,永远等不到。如下图:
当然啦,就1-100,每10个数字一组任务这个场景来说,完全可以不用分治思想去实现,直接在main函数提交10个任务,然后依次累加。但这里的1-100只是一个例子,重点在于分治思想的设计和实现,如果抛开分治去讨论就没有意义了。
所以 ForkJoinPool 是工作原理是什么?同样是线程池为什么不会死锁?
工作原理
首先介绍一下基本概念:
ForkJoinTask:任务类 。其fork、join方法对应拆、合任务。常用实现类有RecursiveTask(有返回值),RecursiveAction(无返回值)。
WorkQueue:工作队列。用于存放任务。每个工作线程都有自己的工作队列,所以 ForkJoinPool 存在成员变量 WorkQueue[] workQueues。WorkQueue[] 的容量为2次幂,索引为偶数的存放外部线程提交的任务,索引为奇数的存放内部fork出来的任务。
ForkJoinWorkerThread:工作线程。每个工作线程在处理自身工作队列的同时,会窃取其他工作队列的任务,窃取的位置是底部。
原理大概就是:ForkJoinPool 内部有多个工作队列(对应多个工作线程),提交任务时,会根据一定的规则提交到其中一个队列。在任务执行的过程中如果 fork 了子任务,子任务入队自己工作队列的top,后续当子任务 join 时,如果子任务未被窃取就当前线程直接执行子任务,反之就先处理其他任务等待其完成。
因为处理自身工作队列时是LIFO,而处理其他工作队列时是FIFO,所以工作队列属于双端队列。
为什么即可以多线程执行,又不会像普通线程池一样死锁?
多线程由来:sumbit和fork都会尝试运行一个工作线程。
不会死锁:因为子任务肯定可以得到执行,要么就在自己的工作队列由自身线程执行,要么就被其他线程窃取执行。所以 ForkJoinPool 可以理解为用任务窃取来弥补单线程执行的性能问题。
先将任务随机入队偶数位置的共享队列,然后创建工作线程和工作队列并运行,而工作线程运行时又会窃取其他队列的任务。所以总能窃取到首次提交到共享队列的任务。
public <T> ForkJoinTask<T> submit(ForkJoinTask<T> task) {
return externalSubmit(task);
}
private <T> ForkJoinTask<T> externalSubmit(ForkJoinTask<T> task) {
Thread t; ForkJoinWorkerThread w; WorkQueue q;
if (task == null)
throw new NullPointerException();
//如果是内部任务(fork)就尝试直接入队
if (((t = Thread.currentThread()) instanceof ForkJoinWorkerThread) &&
(w = (ForkJoinWorkerThread)t).pool == this &&
(q = w.workQueue) != null)
q.push(task);
//反之是外部提交的任务执行 externalPush
else
externalPush(task);
return task;
}
final void externalPush(ForkJoinTask<?> task) {
int r;
//随机数
if ((r = ThreadLocalRandom.getProbe()) == 0) {
ThreadLocalRandom.localInit();
r = ThreadLocalRandom.getProbe();
}
//自旋
for (;;) {
WorkQueue q;
int md = mode, n;
WorkQueue[] ws = workQueues;
//异常情况
if ((md & SHUTDOWN) != 0 || ws == null || (n = ws.length) <= 0)
throw new RejectedExecutionException();
//队列不存在,new Queue
//这里new的工作队列是偶数位的共享队列,只是用于存放外部提交的任务,所以owner为null。
else if ((q = ws[(n - 1) & r & SQMASK]) == null) {
int qid = (r | QUIET) & ~(FIFO | OWNED);
Object lock = workerNamePrefix;
ForkJoinTask<?>[] qa =
new ForkJoinTask<?>[INITIAL_QUEUE_CAPACITY];
q = new WorkQueue(this, null);
q.array = qa;
q.id = qid;
q.source = QUIET;
if (lock != null) {
synchronized (lock) {
WorkQueue[] vs; int i, vn;
if ((vs = workQueues) != null && (vn = vs.length) > 0 &&
vs[i = qid & (vn - 1) & SQMASK] == null)
vs[i] = q; // else another thread already installed
}
}
}
//队列忙,尝试移动到下一个队列
else if (!q.tryLockPhase())
r = ThreadLocalRandom.advanceProbe(r);
//将任务压入队列,并通知有可用任务
else {
//压入队列
if (q.lockedPush(task))
//通知有可用任务。
//主要是确保有足够的工作线程去执行任务。
//要么创建新的工作线程。要么唤醒阻塞的工作线程
signalWork();
return;
}
}
}
final void signalWork() {
//自旋
for (;;) {
long c; int sp; WorkQueue[] ws; int i; WorkQueue v;
//当前有足够多的活跃工作线程
if ((c = ctl) >= 0L)
break;
//工作线程还没满,新建一个工作线程和工作队列并绑定,之后启动工作线程
else if ((sp = (int)c) == 0) {
if ((c & ADD_WORKER) != 0L)
tryAddWorker(c);
break;
}
//如果工作队列未启动或已终止,返回
else if ((ws = workQueues) == null)
break;
else if (ws.length <= (i = sp & SMASK))
break
else if ((v = ws[i]) == null)
break;
//唤醒阻塞线程
else {
int np = sp & ~UNSIGNALLED;
int vp = v.phase;
long nc = (v.stackPred & SP_MASK) | (UC_MASK & (c + RC_UNIT));
Thread vt = v.owner;
if (sp == vp && CTL.compareAndSet(this, c, nc)) {
v.phase = np;
if (vt != null && v.source < 0)
//unpark
LockSupport.unpark(vt);
break;
}
}
}
}
*关于ctl,这是一个64位的long类型变量,由4个16位组成,分别是
* RC: 活跃(没有阻塞,即正在扫描或运行任务的)工作线程数 - 目标并行度
* TC: 总工作线程数 - 目标并行度
* SS: Treiber栈顶部阻塞线程的版本计数和状态(Treiber栈:未扫描到任务入栈阻塞等待)
* ID: Treiber栈顶部阻塞线程的poolIndex
ForkJoinPool用到了大量的位运算,比如这个ctl就是,具体我也没去深究,位运算看麻了,这里简单记录了解一下吧…
总的来说 signalWork 可以保证有足够的工作线程在运行(要么新建线程,要么唤醒阻塞线程)。
//尝试创建工作线程
private void tryAddWorker(long c) {
do {
long nc = ((RC_MASK & (c + RC_UNIT)) |
(TC_MASK & (c + TC_UNIT)));
if (ctl == c && CTL.compareAndSet(this, c, nc)) {
createWorker();
break;
}
} while (((c = ctl) & ADD_WORKER) != 0L && (int)c == 0);
}
//创建工作线程
private boolean createWorker() {
ForkJoinWorkerThreadFactory fac = factory;
Throwable ex = null;
ForkJoinWorkerThread wt = null;
try {
if (fac != null && (wt = fac.newThread(this)) != null) {
//启动创建的工作线程
wt.start();
return true;
}
} catch (Throwable rex) {
ex = rex;
}
deregisterWorker(wt, ex);
return false;
}
上面看到提交一个任务会启动一个工作线程,那么工作线程是如何运行的,又是如何窃取任务的?
public void run() {
if (workQueue.array == null) { // 只会运行一次,毕竟这个是线程的run方法
Throwable exception = null;
try {
onStart();
//主要逻辑在ForkJpinPool.runWorker方法
pool.runWorker(workQueue);
} catch (Throwable ex) {
exception = ex;
} finally {
try {
onTermination(exception);
} catch (Throwable ex) {
if (exception == null)
exception = ex;
} finally {
pool.deregisterWorker(this, exception);
}
}
}
}
final void runWorker(WorkQueue w) {
int r = (w.id ^ ThreadLocalRandom.nextSecondarySeed()) | FIFO;
w.array = new ForkJoinTask<?>[INITIAL_QUEUE_CAPACITY];
//自旋
for (;;) {
int phase;
//scan扫描所有工作队列的任务,任务窃取就是在scan窃取的
if (scan(w, r)) {
r ^= r << 13; r ^= r >>> 17; r ^= r << 5; //随机数
}
//没扫描到任务就阻塞线程,在阻塞之前更新栈前驱(stackPred )和ctl字段的值
else if ((phase = w.phase) >= 0) {
long np = (w.phase = (phase + SS_SEQ) | UNSIGNALLED) & SP_MASK;
long c, nc;
do {
w.stackPred = (int)(c = ctl);
nc = ((c - RC_UNIT) & UC_MASK) | np;
} while (!CTL.weakCompareAndSet(this, c, nc));
}
//阻塞线程
else {
int pred = w.stackPred;
Thread.interrupted();
w.source = DORMANT;
long c = ctl;
int md = mode, rc = (md & SMASK) + (int)(c >> RC_SHIFT);
if (md < 0)
break;
else if (rc <= 0 && (md & SHUTDOWN) != 0 &&
tryTerminate(false, false))
break;
else if (rc <= 0 && pred != 0 && phase == (int)c) {
long nc = (UC_MASK & (c - TC_UNIT)) | (SP_MASK & pred);
long d = keepAlive + System.currentTimeMillis();
LockSupport.parkUntil(this, d);
if (ctl == c &&
d - System.currentTimeMillis() <= TIMEOUT_SLOP &&
CTL.compareAndSet(this, c, nc)) {
w.phase = QUIET;
break;
}
}
else if (w.phase < 0)
LockSupport.park(this);
w.source = 0;
}
}
}
stackPred表示在线程池栈当前工作线程的前驱线程的索引,在唤醒线程时用到此属性。
任务窃取就在这个scan方法。
private boolean scan(WorkQueue w, int r) {
WorkQueue[] ws; int n;
if ((ws = workQueues) != null && (n = ws.length) > 0 && w != null) {
//扫描所有工作队列
for (int m = n - 1, j = r & m;;) {
WorkQueue q; int b;
//如果工作队列不为空且有任务
if ((q = ws[j]) != null && q.top != (b = q.base)) {
int qid = q.id;
ForkJoinTask<?>[] a; int cap, k; ForkJoinTask<?> t;
if ((a = q.array) != null && (cap = a.length) > 0) {
//从base窃取任务
t = (ForkJoinTask<?>)QA.getAcquire(a, k = (cap - 1) & b);
if (q.base == b++ && t != null &&
QA.compareAndSet(a, k, t, null)) {
q.base = b;
w.source = qid;
if (q.top - b > 0)
//如果窃取后还有任务,就调用 signalWork 看能否帮忙执行
signalWork();
//执行窃取的任务
w.topLevelExec(t, q,
r & ((n << TOP_BOUND_SHIFT) - 1));
}
}
return true;
}
else if (--n > 0)
//随机定一个位置后,线性扫描
j = (j + 1) & m;
else
break;
}
}
return false;
}
final void topLevelExec(ForkJoinTask<?> t, WorkQueue q, int n) {
if (t != null && q != null) {
int nstolen = 1;
//自旋
for (;;) {
//执行任务,进而调用到我们重写的 ForkJoinTask.exec 方法
t.doExec();
if (n-- < 0)
break;
//下一个任务优先处理自己工作队列中的任务
else if ((t = nextLocalTask()) == null) {
//如果自己工作队列中没有任务,就从这个队列中再窃取一个任务
if ((t = q.poll()) == null)
//如果都没有任务了,就结束这个方法重新scan扫描
break;
else
++nstolen;
}
}
ForkJoinWorkerThread thread = owner;
nsteals += nstolen;
source = 0;
if (thread != null)
thread.afterTopLevelExec();
}
}
fork的流程就相对简单了,其实就是入队后(如果需要)调用 signalWork 然后等待运行而已。
public final ForkJoinTask<V> fork() {
Thread t;
//如果是ForkJoinWorkerThread,入队工作队列,反之入队common全局队列
if ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread)
((ForkJoinWorkerThread)t).workQueue.push(this);
else
ForkJoinPool.common.externalPush(this);
return this;
}
final void push(ForkJoinTask<?> task) {
ForkJoinTask<?>[] a;
int s = top, d, cap, m;
ForkJoinPool p = pool;
if ((a = array) != null && (cap = a.length) > 0) {
QA.setRelease(a, (m = cap - 1) & s, task);
top = s + 1;
//size = 0 或 1
if (((d = s - (int)BASE.getAcquire(this)) & ~1) == 0 &&
p != null) {
VarHandle.fullFence();
//fork可能也会调用到signalWork
p.signalWork();
}
else if (d == m)
//扩容任务列表 ForkJoinTask>[]
growArray(false);
}
}
java8并行流用到的就是这里的common全局队列,所以java8并行流有个坑就是不同业务(线程)用到的队列是同一个,在某些情况下会相互影响。
join则是获取子任务的结果。如果join的时候子任务已有结果直接返回,反之看join的子任务是否还在自己的工作队列上,如果是的话自己运行,如果不是的话就等待结果。
值得注意的是,为提高性能,等待子任务结果时并不是直接阻塞等待,而是边执行其他任务边等待。
并且就算真正进入wait等待也会补偿一个活跃线程,以免无线程可用。补偿逻辑:如果有可以唤醒的线程就唤醒线程,如果线程数未满或者虽然已满但全都在awaitJoin子任务结果就新建线程。
因为如果所有线程都在awaitJoin子任务结果,这个时候就算线程数满了,补偿的时候仍会新建线程。所以在极端情况下ForkJoinPool的总线程数是可能大于参数值的。不过最大不会超过 0x7fff, 超过了就抛异常。
public final V join() {
int s;
if (((s = doJoin()) & ABNORMAL) != 0)
reportException(s);
return getRawResult();
}
private int doJoin() {
int s; Thread t; ForkJoinWorkerThread wt; ForkJoinPool.WorkQueue w;
return (s = status) < 0 ? s :
((t = Thread.currentThread()) instanceof ForkJoinWorkerThread) ?
(w = (wt = (ForkJoinWorkerThread)t).workQueue).
//如果已有结果直接返回
tryUnpush(this) && (s = doExec()) < 0 ? s :
//如果是ForkJoinWorkerThread线程内部等待
wt.pool.awaitJoin(w, this, 0L) :
//反之外部等待
externalAwaitDone();
}
final int awaitJoin(WorkQueue w, ForkJoinTask<?> task, long deadline) {
int s = 0;
int seed = ThreadLocalRandom.nextSecondarySeed();
if (w != null && task != null &&
(!(task instanceof CountedCompleter) ||
(s = w.helpCC((CountedCompleter<?>)task, 0, false)) >= 0)) {
//如果任务还在自身工作队列的话,自己执行任务
w.tryRemoveAndExec(task);
int src = w.source, id = w.id;
int r = (seed >>> 16) | 1, step = (seed & ~1) | 2;
s = task.status;
//s>=0说明还没有结果,继续等待
while (s >= 0) {
WorkQueue[] ws;
int n = (ws = workQueues) == null ? 0 : ws.length, m = n - 1;
//等待的过程中为提高性能可以扫描其他任务执行
while (n > 0) {
WorkQueue q; int b;
if ((q = ws[r & m]) != null && q.source == id &&
q.top != (b = q.base)) {
ForkJoinTask<?>[] a; int cap, k;
int qid = q.id;
if ((a = q.array) != null && (cap = a.length) > 0) {
ForkJoinTask<?> t = (ForkJoinTask<?>)
QA.getAcquire(a, k = (cap - 1) & b);
if (q.source == id && q.base == b++ &&
t != null && QA.compareAndSet(a, k, t, null)) {
q.base = b;
w.source = qid;
t.doExec();
w.source = src;
}
}
break;
}
else {
r += step;
--n;
}
}
//已有结果break
if ((s = task.status) < 0)
break;
//进入阻塞逻辑
else if (n == 0) {
long ms, ns; int block;
if (deadline == 0L)
ms = 0L;
else if ((ns = deadline - System.nanoTime()) <= 0L)
break;
else if ((ms = TimeUnit.NANOSECONDS.toMillis(ns)) <= 0L)
ms = 1L;
//实际阻塞前,tryCompensate补偿一个线程
//补偿逻辑:如果有可以唤醒的线程就唤醒线程
// 如果线程数未满或者虽然已满但全都在awaitJoin子结果就新建线程
if ((block = tryCompensate(w)) != 0) {
//内部调用Object.wait阻塞
task.internalWait(ms);
CTL.getAndAdd(this, (block > 0) ? RC_UNIT : 0L);
}
s = task.status;
}
}
}
return s;
}
final void internalWait(long timeout) {
if ((int)STATUS.getAndBitwiseOr(this, SIGNAL) >= 0) {
synchronized (this) {
if (status >= 0)
try { wait(timeout); } catch (InterruptedException ie) { }
else
notifyAll();
}
}
}