ForkJoinPool

Java 1.7 引入了一种新的并发框架—— Fork/Join Framework

Fork/Join的思路:通过分而治之(把大任务分割成若干个小任务),只不过划分之后的任务更适合分派给不同的计算资源,可以并行的完成任务。

原理

Fork/Join框架主要依靠forkjoin两个操作,一般对这两个操作的解释如下:

  • fork():开启一个新线程(或是重用线程池内的空闲线程),将任务交给该线程处理。
  • join():等待该任务的处理线程处理完毕,获得返回值。

这里有个问题,不断的fork()如果是不断创建线程,岂不是要“线程数量爆炸”?事实上,ForkJoinPool用了一种work stealing的算法,避免产生大量线程。所以如果一开始设置线程池的线程数为N,实际上使用ForkJoinPool的时候也只会有固定的线程数(默认和CPU核数一样)。

Fork/Join的基本用法

1
2
3
4
5
if (当前这个任务工作量足够小)
直接完成这个任务
else
将这个任务分解成两个部分
分别触发(invoke)这两个子任务的执行,并等待结果

ForkJoin框架组成

  • ForkJoinPool:管理worker线程,类似ThreadPoolExecutor,提供接口用于提交或者执行任务;
  • ForkJoinWorkerThreadworker线程,任务保存在一个deque中;
  • ForkJoinTaskForkJoin框架中运行的任务,可以fork子任务,可以join子任务完成。

ForkJoinPool构造函数

1
2
3
4
5
6
7
8
9
10
public class ForkJoinPool extends AbstractExecutorService {
public ForkJoinPool() {
// 默认并行数为CPU核数
this(Math.min(MAX_CAP, Runtime.getRuntime().availableProcessors()),
defaultForkJoinWorkerThreadFactory, null, false);
}
public ForkJoinPool(int parallelism) {
this(parallelism, defaultForkJoinWorkerThreadFactory, null, false);
}
}

work stealing(窃取)算法

一个任务通过fork()被分割成若干个小任务。比如线程1和线程2都被分割成4个小任务,如果线程1执行完毕,那么他可以去窃取线程2的工作。当要发生线程窃取的时候,两个线程内的任务可以理解成放在一个线程自己的双端队列中。例如下图线程2中的任务被分割成若干个小任务(就WorkQueue里面的一个个小方块)放在其线程的双端队列中。被窃取任务线程从其他双端队列的头部拿任务执行,而窃取任务的线程永远从双端队列的尾部拿任务执行。

img

  • ForkJoinPool 的每个工作线程都维护着一个工作队列WorkQueue),这是一个双端队列(Deque),里面存放的对象是任务ForkJoinTask)。
  • 每个工作线程在运行中产生新的任务(通常是因为调用了 fork())时,会放入工作队列的队尾,并且工作线程在处理自己的工作队列时,使用的是 LIFO 方式,也就是说每次从队尾取出任务来执行。
  • 每个工作线程在处理自己的工作队列同时,会尝试窃取一个任务(或是来自于刚刚提交到 pool 的任务,或是来自于其他工作线程的工作队列),窃取的任务位于其他线程的工作队列的队首,也就是说工作线程在窃取其他工作线程的任务时,使用的是 FIFO 方式。
  • 在遇到 join() 时,如果需要join的任务尚未完成,则会先处理其他任务,并等待其完成。
  • 在既没有自己的任务,也没有可以窃取的任务时,进入休眠

优缺点

work stealing算法的优点:利用了线程进行并行计算,减少了线程间的竞争。

work stealing算法的缺点:

  • 如果双端队列中只有一个任务时,线程间会存在竞争。
  • 额外的开销,例如双端队列

fork()

fork() 做的工作只有一件事,既是把任务推入当前工作线程的工作队列里

1
2
3
4
5
6
7
8
public final ForkJoinTask<V> fork() {
Thread t;
if ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread)
((ForkJoinWorkerThread)t).workQueue.push(this);
else
ForkJoinPool.common.externalPush(this);
return this;
}

join()

  1. 检查调用 join() 的线程是否是 ForkJoinThread线程。如果不是,则阻塞当前线程,等待任务完成。如果是,则不阻塞。
  2. 查看任务的完成状态,如果已经完成,直接返回结果。
  3. 如果任务尚未完成,但处于自己的工作队列内,则完成它。
  4. 如果任务已经被其他的工作线程偷走,则窃取这个任务的worker执行(以 FIFO 方式),以期帮助它早日完成join的任务。
  5. 如果偷走任务的worker也已经把自己的任务全部做完,正在等待需要join的任务时,则找到该小偷的小偷,帮助它完成它的任务。
  6. 递归地执行第5步。

img

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
private int doJoin() {
Thread t; ForkJoinWorkerThread w; int s; boolean completed;
if ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread) {
// status值的初始化值是0,在任务没有完成以前一直是非负值
if ((s = status) < 0)
return s;
// 从当前工作线程的栈顶中 pop 该任务,准备执行
if ((w = (ForkJoinWorkerThread)t).unpushTask(this)) {
try {
completed = exec();
} catch (Throwable rex) {
return setExceptionalCompletion(rex);
}
if (completed)
return setCompletion(NORMAL);
}
// 当工作线程队列为空或者任务没有正常完成,则会给helpJoinTask stolen->joining 方式执行
return w.joinTask(this);
}
else
// 不是worker线程,直接调用Object.wait等待任务完成(阻塞)。
return externalAwaitDone();
}

案例

统计1~1000整数之和

单线程For循环

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
public class ForLoopCalculator {
public static long sum(long[] numbers) {
long sum = 0;
for (long i : numbers) {
sum += i;
}
return sum;
}
public static void main(String[] args) {
long[] numbers = LongStream.rangeClosed(1, 1000).toArray();
System.out.println(sum(numbers));
}
}

ExecutorService线程池

把大任务分割成若干个小任务,并行计算再合并结果

1
2
3
4
5
6
7
8
9
10
11
12
13
//任务Task
class SumTask implements Callable<Long> {
private long[] numbers;
public SumTask(long[] numbers) {
this.numbers = numbers;
}
@Override
public Long call() throws Exception {
return Arrays.stream(numbers).sum();
}
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
// Task执行线程池
public class ExecutorServiceCalculator {
int parallism;
private ExecutorService pool;
public ExecutorServiceCalculator(int parallism) {
this.parallism = parallism;
pool = Executors.newFixedThreadPool(parallism);
}
public long sum(long[] numbers) {
List<Future<Long>> results = new ArrayList<>();
// 把任务分解为 n 份,交给 n 个线程处理
int part = numbers.length / parallism;
for (int i = 0; i < parallism; i++) {
int from = i * part + 1;
int to = (i + 1) * part;
results.add(pool.submit(new SumTask(LongStream.rangeClosed(from, to).toArray())));
}
// 把每个线程的结果相加,得到最终结果
long sum = results.stream().map(f -> {
try {
return f.get();
} catch (InterruptedException | ExecutionException e) {
throw new RuntimeException(e);
}
}).reduce(0L, Long::sum);
return sum;
}
}

ForkJoinPool线程池

ForkJoinPool主要用于实现“分而治之”的算法,特别是分治之后递归调用的函数,例如quick sort

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
public class ForkJoinCalculator {
private ForkJoinPool pool;
public ForkJoinCalculator() {
pool = new ForkJoinPool();
}
public long sum(long[] numbers) {
return pool.invoke(new SumTask(numbers, 0, numbers.length - 1));
}
private static class SumTask extends RecursiveTask<Long> {
private long[] numbers;
private int from;
private int to;
public SumTask(long[] numbers, int from, int to) {
this.numbers = numbers;
this.from = from;
this.to = to;
}
@Override
protected Long compute() {
if (to - from <= 0) {
return numbers[from];
}
if (to - from <= 1) {
return numbers[from] + numbers[to];
// 否则,把任务一分为二,递归计算
} else {
int middle = (from + to) / 2;
SumTask taskLeft = new SumTask(numbers, from, middle);
SumTask taskRight = new SumTask(numbers, middle + 1, to);
taskLeft.fork();
taskRight.fork();
return taskLeft.join() + taskRight.join();
}
}
}
}

在这段代码里没有显式地“把任务分配给线程”,只是分解了任务,而把具体的任务到线程的映射交给了 ForkJoinPool 来完成。

参考

如何使用 ForkJoinPool 以及原理

ForkJoinPool解读

Java Fork&Join框架使用和实现分析

热评文章