読者です 読者をやめる 読者になる 読者になる

倭マン's BLOG

くだらない日々の日記書いてます。 たまにプログラミング関連の記事書いてます。 書いてます。

Scala の Seq に定義されているメソッドを試す (7) ~畳み込み~

Scala の Seq に定義されているメソッドを試すシリーズ(目次)。 今回扱うのは、Seq の要素を走査して何かしらの結果を計算する、畳み込みに関するメソッドです。

今回扱うメソッド

畳み込みに関するメソッドには fold, reduce, scan があり、それぞれに左右どちらかから計算していくかを指定する Left/Right がついたメソッドがあります。 Left/Right が付いていないメソッドは左畳み込みで、かつ指定できる二項演算が要素の型(もしくはそのスーパー型)に対するものだけに制限されます。

// fold 系
def fold[A1 >: A](z: A1)(op: (A1, A1) ⇒ A1): A1
def foldLeft[B](z: B)(op: (B, A) ⇒ B): B
def foldRight[B](z: B)(op: (A, B) ⇒ B): B
def /:[B](z: B)(op: (B, A) ⇒ B): B
def :\[B](z: B)(op: (A, B) ⇒ B): B

def aggregate[B](z: ⇒ B)(seqop: (B, A) ⇒ B, combop: (B, B) ⇒ B): B

// reduce 系
def reduce[A1 >: A](op: (A1, A1) ⇒ A1): A1
def reduceOption[A1 >: A](op: (A1, A1) ⇒ A1): Option[A1]
def reduceLeft[B >: A](op: (B, A) ⇒ B): B
def reduceLeftOption[B >: A](op: (B, A) ⇒ B): Option[B]
def reduceRight[B >: A](op: (A, B) ⇒ B): B
def reduceRightOption[B >: A](op: (A, B) ⇒ B): Option[B]

// scan 系
def scan[B >: A, That](z: B)(op: (B, B) ⇒ B)(implicit cbf: CanBuildFrom[Seq[A], B, That]): That
def scanLeft[B, That](z: B)(op: (B, A) ⇒ B)(implicit bf: CanBuildFrom[Seq[A], B, That]): That
def scanRight[B, That](z: B)(op: (A, B) ⇒ B)(implicit bf: CanBuildFrom[Seq[A], B, That]): That

要素が数値の Seq オブジェクトに対して和や積を計算する sum, product メソッドも畳み込み関連のメソッドですが、これらは別の機会に(まぁ、説明しなくても使い方は分かると思いますが)。

サンプルコード

fold メソッド
fold メソッドは、初期値と要素に対する二項演算を与えて畳み込みを計算します。 たとえば、数値を要素に持つ Seq オブジェクトに対して、初期値を0、二項演算を加算とすると、要素全ての和となります:

  val intSeq = Seq(1, 3, 5, 7, 9, 11, 13, 15, 17, 19)  // == 1 to 19 by 2

  // fold メソッド
  // Int の Seq オブジェクトに対して要素全ての和を計算する
  val result0 = intSeq.fold(0)((m, n) => m + n)
  assert( result0 == 100 )

二項演算を指定するラムダ式は、以下のようにもう少し簡単に書けますね:

  val result1 = intSeq.fold(0)(_+_)
  assert( result1 == 100 )

ちなみに、要素の和はそのうちやる sum メソッドを使えば fold メソッドを使う必要がありませんが。

もちろん数値以外にも fold メソッドは使えます。 例えば以下のようにして文字列を連結できます:

  val strSeq = Seq("a", "b", "c", "d", "e")

  val result2 = strSeq.fold("")(_+_)
  assert( result2 == "abcde" )

fold メソッドは次にやる foldLeft メソッドのように左から畳み込みます:

  val strSeq = Seq("a", "b", "c", "d", "e")

  val result3 = intSeq.fold(0)((sum2, n) => sum2 + n*n)  // 要素の2乗の和を計算
  assert( result3 == 1330 )

ただし、二項演算の引数を同等に扱わない場合は fold よりも foldLeft にしておいた方がよいかと思います。

foldLeft, /: メソッド
foldLeft メソッドは Seq の要素を左(先頭)から順に畳み込んでいきます。 例えば Seq オブジェクトの要素が x0, x1, x2 の3つで、初期値が z、二項演算が f(x, y) のとき、foldLeft(z)(f) は

  f(f(f(f(z, x0), x1), x2), x3)

を計算します。 foldLeft メソッドは fold メソッドと同じように使えます:

  val intSeq = Seq(1, 3, 5, 7, 9, 11, 13, 15, 17, 19)

  // foldLeft メソッド
  val result0 = intSeq.foldLeft(0)((sum2, n) => sum2 + n*n )  // 要素の2乗を計算
  assert( result == 1330)

左畳み込みの場合は、二項演算の第0引数は計算中の値(今の場合、和の途中の値)を保持するアキュムレータ (accumulator) となります。

folodLeft メソッドは fold メソッドと異なり、要素の型とは違う型を計算結果として用いることができます(この意味では第2引数リストの関数は「二項演算」ではない気もしますが)。 

  val strSeq = Seq("a", "b", "c", "d", "e")
  def builder = Seq.newBuilder[String]

  // アキュムレータとして Builder オブジェクトを用いてみる
  //(結果値で result を呼び出して Seq オブジェクトへ変換)
  val result1 = strSeq.foldLeft(builder)((b, s) => b += s).result
  assert( result1 == Seq("a", "b", "c", "d", "e"))

  // もしくは二項演算を簡単に書いて
  val result2 = strSeq.foldLeft(builder)(_+=_).result
  assert( result2 == Seq("a", "b", "c", "d", "e"))

このコードでは、返される Seq オブジェクトがレシーバの Seq オブジェクトと実質同じなので、大して意味のないコードになってしまってますが、まぁ使い方を見るためのコードということで。

/: メソッドは foldLeft メソッドと同じです。

  def builder = Seq.newBuilder[String]

  // /: メソッド
  assert( strSeq./:(builder)(_+=_).result  == Seq("a", "b", "c", "d", "e") )

  // もしくは演算子っぽく書くとこんな感じ
  assert( (builder /: strSeq)(_+=_).result  == Seq("a", "b", "c", "d", "e") )

foldRight, :\ メソッド
foldRight メソッドは Seq オブジェクトの要素を右から畳み込みます。 例えば Seq オブジェクトの要素が x0, x1, x2 の3つで、初期値が z、二項演算が f(x, y) のとき、foldRight(z)(f) は

  f(x0, f(x1, f(x2, f(x3, z))))

を計算します。 foldRight メソッドでは、二項演算の関数の第2引数がアキュムレータになります:

  val intSeq = Seq(1, 3, 5, 7, 9, 11, 13, 15, 17, 19)

  // foldRight メソッド
  val result0 = intSeq.foldRight(0)((n, sum2) => n*n + sum2)
  assert( result0 == 1330 )

foldRight メソッドも foldLeft メソッドと同じく、要素型と異なる型を計算結果として用いることができます。 foldLeft メソッドのサンプルコードでは Builder オブジェクトを使って結果の Seq オブジェクトを構築していましたが、これは、foldLeft メソッドの畳み込みの順序では要素をアキュムレータの末尾に追加しないといけないためでした(別に Vector オブジェクトでもよかったのですが)。

List などの LinearSeq 型のコレクションでは、先頭に要素を追加する方が高速なので、以下のように List オブジェクトの構築には foldRight メソッドを用いる方がパフォーマンスが高くなります:

  val strSeq = Seq("a", "b", "c", "d", "e")
  val emp = Seq[String]()  // 空の List オブジェクト

  val result1 = strSeq.foldRight(emp)((s, acc) => s +: acc)
  assert( result1 == Seq("a", "b", "c", "d", "e") )

  // もしくは
  assert( strSeq.foldRight(emp)(_+:_) == Seq("a", "b", "c", "d", "e") )

foldRight を使うのは List (LinearSeq) を構築する場合が大半ではないかと思います*1

:\ メソッドは foldRight メソッドと同じです。

  val strSeq = Seq("a", "b", "c", "d", "e")
  val emp = Seq[String]() 

  // :\ メソッド
  assert( strSeq.:\(emp)(_+:_) == Seq("a", "b", "c", "d", "e") )

  // もしくは演算子っぽく書くと
  assert( (strSeq :\ emp)(_+:_) == Seq("a", "b", "c", "d", "e") )

aggregate メソッド
aggregate メソッドは foldLeft メソッドを一般化したようなメソッドです。 第1引数リストには初期値を、第2引数リストの第1引数にはアキュムレータと要素を結合する演算(foldLeft の二項演算に当たる)を、第2引数には2つのアキュムレータを結合する演算を指定します。 アキュムレータ同士の結合を指定することで、畳み込みを並列実行できるようになります。

Seq オブジェクトの要素が x0, x1, x2, x3、初期値が z、第2引数リストの2つの関数を順に f, g とすると、aggregate メソッドは以下のような計算をします(必ずしもこの通りかどうかは知りませんが):

g(f(f(z, x0), x1), f(f(z, x2), x3))

数値を要素とする Seq オブジェクトに対して、要素の2乗の和を計算するには以下のようにします:

  val intSeq = Seq(1, 3, 5, 7, 9, 11, 13, 15, 17, 19)

  // aggregate メソッド
  val result0 = intSeq.aggregate(0)((s, n) => s + n*n, _+_)
  assert( result0 == 1330)

アキュムレータ同士の結合は単なる加算なので特に難しくはありませんね。

結果型がコレクションになる場合もやっておきましょう。 第2引数リストの第1引数はアキュムレータと要素の結合なので :+ や +: を使い、第2引数はアキュムレータ同士の結合なので ++ や ++: を使います:

  val intSeq = Seq(1, 3, 5, 7, 9, 11, 13, 15, 17, 19)
  val emp = Vector[Int]()

  // 要素の2乗を要素とする Seq を返す
  val result1 = intSeq.aggregate(emp)((s, n) => s :+ n*n, _++: _)
  assert( result1 == Seq(1, 9, 25, 49, 81, 121, 169, 225, 289, 361) )

aggregate メソッドは左畳み込みしかないようなので、コレクションを構築する際にアキュムレータの末尾に要素を効率よく追加する Vector オブジェクトを使っています。

reduce 系のメソッド
reduce メソッドは、fold 系のメソッドと異なり、初期値を与えずに要素を畳み込みます。 もしくは、Seq の先頭 (head) を初期値として、その後続 (tail) を fold で畳み込むような感じです:

  seq.tail.fold(seq.head)(f)

reduce 系のメソッドでは、foldLeft や foldRight などと異なり、返り値の型は要素型(もしくはそのスーパー型)に制限されます。 和や積のような数学的な演算に寄った畳み込みという感じですかね(和と積はそれぞれ sum, product メソッドで計算できますが)。

  val intSeq = Seq(1, 3, 5, 7, 9, 11, 13, 15, 17, 19)

  // reduce メソッド
  assert( intSeq.reduce(_+_) == 100 )

reduce 系のメソッドは、初期値を与えないので fold 系のものより少し書くのが楽ですが、元の Seq オブジェクトが空だと例外を投げます。

  try{
    Seq[String]().reduce(_+_)
      // 空の Seq に対して呼び出すと例外を投げる
    assert(false)
  }catch{
    case ex: UnsupportedOperationException => assert(true)
  }

空の Seq オブジェクトに対しても例外が投げられないように、結果を Option で包んで返す reduceOption メソッドが定義されています:

  val intSeq = Seq(1, 3, 5, 7, 9, 11, 13, 15, 17, 19)

  // reduceOption メソッド
  intSeq.reduceOption(_+_) match {
    case Some(sum) => println(s"[reduceOption] $sum")
    case None      => println("[reduceOption] error!")
  }
    // 「[reduceOption] 100」と表示される

reduceLeftOption, reduceRightOption メソッドも同様です。

reduce 系のメソッドと fold 系のメソッドの動作の違いは、以下のようなコードを見ると分かりやすいかと思います:

  val strSeq = Seq("a", "b", "c", "d", "e")
  val f: (String, String) => String = (x, y) => x + ", " + y  // ", " を挟んで文字列を結合

  // fold 系メソッド
  assert( strSeq.fold("")(f) == ", a, b, c, d, e")
  assert( strSeq.foldLeft("")(f) == ", a, b, c, d, e")  // fold と同じ
  assert( strSeq.foldRight("")(f) == "a, b, c, d, e, " )

  // reduce 系メソッド
  assert( strSeq.reduce(f) == "a, b, c, d, e" )
  assert( strSeq.reduceLeft(f) == "a, b, c, d, e" )  // reduce と同じ
  assert( strSeq.reduceRight(f) == "a, b, c, d, e" )  // reduce と同じ結果

fold 系のメソッドでは、(おそらく)望んでいない位置に ", " が挿入されますが、reduce ではどの場合でも要素間にしか挿入されません。

reduce 系のメソッドでは結果型が要素型と同じでなければならないので、Left/Right の使い分けが fold 系のメソッドのサンプルコードでは参考にならなさそうですね。 ということで、reduceRight メソッドを使うサンプルコードを書いておきましょう。 やはり List (LinearSeq) の構築がよく使うのではないでしょうか。

  val seqSeq = Seq(Seq(0), Seq(0, 1), Seq(0, 1, 2), Seq(0, 1, 2, 3))
  
  // reduceRight メソッド
  assert( seqSeq.reduceRight(_++:_) == Seq(0, 0, 1, 0, 1, 2, 0, 1, 2, 3) )

まぁ、このコードは flatten メソッド使えばいいんですけどね。

scan 系のメソッド
scan メソッドは畳み込みの途中のアキュムレータの値を要素とする Seq オブジェクトを返します。 数値を要素とする Seq オブジェクトに対しては、その項までの和を要素とする Seq オブジェクトを返します。

  val intSeq = Seq(1, 3, 5, 7, 9, 11, 13, 15, 17, 19)  // intSeq の長さは10

  // scan メソッド
  val resultSeq = intSeq.scan(0)(_+_)
  assert( resultSeq.length == 11)  // 返される Seq オブジェクトの長さは11
  assert( resultSeq == Seq(0, 1, 4, 9, 16, 25, 36, 49, 64, 81, 100) )
    // == Seq(0, 0+1, 0+1+3, 0+1+3+5, 0+1+3+5+7, ... )

scan メソッドに与えた初期値が返される Seq オブジェクトの先頭要素となるので、結果の Seq オブジェクトの長さは元の Seq オブジェクトの長さより1だけ大きくなります。

文字列を要素とする Seq を結合するサンプルの方が分かりやすいかもしれません:

  val strSeq = Seq("a", "b", "c", "d", "e")

  assert( strSeq.scan("")(_+_) == Seq("", "a", "ab", "abc", "abcd", "abcde") )

scanLeft メソッド, scanRight メソッドはそれぞれ foldLeft, foldRight メソッドの畳み込み途中のアキュムレータを Seq オブジェクトとして返します。

  val strSeq = Seq("a", "b", "c", "d", "e")

  // scanLeft メソッド
  assert( strSeq.scanLeft("")(_+_) == Seq("", "a", "ab", "abc", "abcd", "abcde") )

  // scanRight メソッド
  assert( strSeq.scanRight("")(_+_) == Seq("abcde", "bcde", "cde", "de", "e", ""))

scanRight メソッドでは計算途中のアキュムレータの値が右(末尾)から並んでいます(まぁ当然かと思いますが)。

scanLeft, scanRight メソッドでは、scan メソッドとは異なり(foldLeft, foldRight と同じように)要素型以外の型を結果型にできます:

  val strSeq = Seq("a", "b", "c", "d", "e")

  // scanLeft メソッド
  val emp0 = Vector[String]()
  assert( strSeq.scanLeft(emp0)(_:+_) ==
    Seq(
      Seq(),
      Seq("a"),
      Seq("a", "b"),
      Seq("a", "b", "c"),
      Seq("a", "b", "c","d"),
      Seq("a", "b", "c", "d", "e"))
  )

  // scanRight メソッド
  val emp1 = Seq[String]()
  assert( strSeq.scanRight(emp1)(_+:_) ==
    Seq(
      Seq("a", "b", "c", "d", "e"),
      Seq("b", "c", "d", "e"),
      Seq("c", "d", "e"),
      Seq("d", "e"),
      Seq("e"),
      Seq()
    )
  )

やはり、List (LinearSeq) を構築する場合には右畳み込みを使いましょう。

思いの外長くなりましたが、fold, reduce などの畳み込み関連のメソッド終了。 畳み込みは関数型プログラミングを行う、理解する上で重要なメソッドだと思うので、理解が怪しげなところがあったら自分でいろいろ試してみましょう。 次回は多重コレクションに関連するメソッドを試していく予定。

Scalaスケーラブルプログラミング第3版

Scalaスケーラブルプログラミング第3版

*1:Haskell では foldr を使って無限長のコレクションに対して動作するコードを書けるみたいですが、Scala では同様のコードを書いても処理が返ってこなくなるようです。 ただし、Scala でも contains や forall, exists などのメソッドのように、無限長のコレクションに対しても要素を走査するが結果を返す(場合がある)ものもあります。