倭マン's BLOG

くだらない日々の日記書いてます。 たまにプログラミング関連の記事書いてます。 書いてます。

Fork/Join で並列和

前回 Java 7 を使う環境設定を行ったので、ついでにちょっと Java 7 をいじってみます*1。 今回見ていくのは java.util.concurrent パッケージに追加された Fork/Join フレームワークです。 Fork/Join は大量の計算を小さい部分に分けてマルチスレッドで計算する手法です。

で、この記事では整数のリスト (List<Integer>) を受け取ってその要素全ての和を計算するコードを見ていきます。

参考 URL

java.util.concurrent パッケージ : Java 7


Java 7 で Fork/Join を行うには、まず ForkJoinTask クラスのサブクラスを作成する必要があります。 このクラスには実際に行う計算アルゴリズムを実装します。 ただし、通常は ForkJoinTask クラスを直接継承したクラスを作成するのではなく、RecursiveAction クラスもしくは RecursiveTask クラスのサブクラスを作成します。 これらの使い分けは

  • RecursiveAction クラス・・・返り値が必要ない
  • RecursiveTask クラス・・・返り値が必要

です。 ここでは返り値が必要なので RecursiveTask クラスを使います。 クラス名は ParallelSum (並列和)とします:

import java.util.*;
import java.util.concurrent.RecursiveTask;

class ParallelSum extends RecursiveTask<Integer> {
    
    private static final int THRESHOLD = 16;
    
    private final List<Integer> list;
    
    ParallelSum(List<Integer> list){
        this.list = list;
    }

    @Override
    protected Integer compute() {
        if(this.list.size() < THRESHOLD){
            // リストのサイズが THRESHOLD 未満なら普通に足し算する
            int sum = 0;
            for(Integer i : this.list) sum += i;
            return sum;
            
        }else{
            // リストのサイズが THRESHOLD 以上なら
            // 要素数が半分のサブリストに対して再帰(recursive)処理
            int m = this.list.size() / 2;
            ParallelSum ps1 = new ParallelSum(this.list.subList(0, m));
            ps1.fork();    // ps1 の処理実行
            ParallelSum ps2 = new ParallelSum(this.list.subList(m, this.list.size()));
            ps2.fork();    // ps2 の処理実行

            return ps1.join() + ps2.join();    // ps1, ps2 の結果を使ってこのタスクの結果を計算
        }
    }
}
  • ForkJoinTask クラスは返り値の型を型パラメータとして指定します。 実装しなければならない compute() メソッドの返り値がこの型になります。
  • この実装では、渡されたリストのサイズが THRESHOLD 未満なら、普通に足し算をして結果を返します。 Fork/Join では、あまり細かいタスクに分割しすぎるとパフォーマンスが劣化するようなので注意。
  • 渡されたリストのサイズが THRESHOLD 以上なら、サイズが半分のサブリスト2つに対して新たに ParallelSum オブジェクトを作成して fork() メソッドによって処理を実行します。 その後、join() メソッドによってサブリストに対する計算結果を取得して、それらの和を返します。

まぁ、名前の通りなんですが fork(), join() メソッドが大切だ!ってことですね。 ForkJoinTask のサブクラスが作成できれば後は簡単です:

  1. ForkJoinPool オブジェクトを生成する
  2. ForkJoinTask オブジェクトを生成する
  3. ForkJoin#invoke(ForkJoinTask) メソッドで ForkJoinTask を実行する
import java.util.*;
import java.util.concurrent.ForkJoinPool;

public class ParallelSumMain {
    
    public static void main(String... args){
        List<Integer> intList = newIntegerList();
        
        // Fork/Join の実行は実質的にこの2行だけ
        ForkJoinPool forkJoin = new ForkJoinPool();
        int n = forkJoin.invoke(new ParallelSum(intList));

        System.out.println(n);
    }
    
    /** [0, 1, 2, 3, ... , 2^p-1] といったリストを返す */
    private static List<Integer> newIntegerList(){
        int p = 10, n = 1;
        for(int i = 0; i < p; i++) n *= 2;
        
        List<Integer> list = new ArrayList<>(n);
        
        for(int i = 0; i < n; i++) list.add(i);

        return list;
    }
}

Fork/Join の実行は invoke() メソッド以外でも可能です:

  • execute(ForkJoinTask<T>) : void ・・・ 返り値が必要ない場合
  • invoke(ForkJoinTask<T>) : T ・・・ 返り値が必要な場合
  • submit(ForkJoinTask<T>) : ForkJoinTask<T> ・・・ Future<T> 型の返り値が必要な場合

submit() メソッドが ForkJoinTask オブジェクトを返すのは、ForkJoinTask が Future インターフェースを実装しているためです。 Future は get() メソッドによって計算結果を取得できますが、計算が終了していない場合は呼び出しスレッドを待たせます。

まぁ、こんな感じでどうでしょう?

GPars : Groovy


前回は Groovy の開発環境も設定したので、Groovy のマルチスレッド・プログラミング・ライブラリである GPars でも同じサンプルを書いてみます。

import static groovyx.gpars.GParsPool.runForkJoin
import static groovyx.gpars.GParsPool.withPool

def N = 2**10
def THRESHOLD = 16

def argList = 0..<N

withPool(){
    def sum = runForkJoin(argList){ List list ->
        if(list.size() < THRESHOLD){
            return list.sum()
            
        }else{
            int m = list.size() / 2
            forkOffChild(list.subList(0, m))
            forkOffChild(list.subList(m, list.size()))
            return childrenResults.sum(0)
        }
    }
    println sum
}

ふも、かなりコンパクトに書けてしまいましたがいいんでしょうか? スレッドプログラミングってバグ見つけにくいのでコワいのよねぇ。 ちなみに Java 7 のコードとの対応は(必ずしも1対1対応しているわけではありませんが)

  • GParsPool#withPool() ⇔ ForkJoinPool
  • GParsPool#runForkJoin() ⇔ ForkJoinPool#invoke()
  • forkOffChild ⇔ ForkJoinTask#fork()
  • childrenResults ⇔ ForkJoinTask#join()

といった感じでしょうか。

・・・よく考えると、GPars には sumParallel() ってメソッドがあって、あえて Fork/Join 使って実装する必要ないよね? そうだよね? 引数のリストを準備するコード含めても10行以内で書けるよね?

並行コンピューティング技法 ―実践マルチコア/マルチスレッドプログラミング

並行コンピューティング技法 ―実践マルチコア/マルチスレッドプログラミング


Java並行処理プログラミング ―その「基盤」と「最新API」を究める―

Java並行処理プログラミング ―その「基盤」と「最新API」を究める―

*1:そういえば前に nio2 の記事を書いてた気もするけど。