末尾再帰
ScalaByExample - Example 4.6 Tail Recursionより。Scalaでは末尾再帰は最適化されます。無駄なスタックを生成しないために、再帰を利用する関数は末尾再帰となるように設計しましょう、とのこと。
末尾再帰とは
再帰呼び出しを含む手続きが、自分自身の呼び出しをその手続きの末尾に行うように記述をすること。
例として「1から引数までの総和を計算する関数」の末尾再帰版とそうでない版を作ってみました。
// 1から引数までの総和を計算する関数 // 末尾再帰でない版 def sum( i:int ):int = { if ( i == 0 ) i else i + sum(i-1) } // 末尾再帰版 def sumTailCalls( i:int ):int = { def _sum( j:int, total:int ):int = { if ( j == 0 ) total else _sum( j-1, total+j ) } _sum( i, 0 ) } var count = 4 println( sum( count )) // 10 println( sumTailCalls( count )) // 10
ここで、各関数の処理を1ステップごとみていくと次のようになります。
// 末尾再帰でない版 sum(4) 4 + sum(4-1) 4 + { 3 + sum(3-1) } 4 + { 3 + { 2 + sum(2-1) } } 4 + { 3 + { 2 + { 1 + sum(1-1) } } } 4 + { 3 + { 2 + { 1 + { 0 } } } } // 末尾再帰版 sumTailCalls(4) _sum( 4, 0 ) _sum( 4-1, 0+4 ) _sum( 3-1, 4+3 ) _sum( 2-1, 7+2 ) _sum( 1-1, 9+1 ) 10
末尾再帰でない版はステップごとに4,3,2と式が長くなっていき、最後にそれらを合計して結果が生成されます。このとき、4や3や2は関数のスタックに存在するiとして保持されます。そのため、iが0になるまでスタックを覚えておく必要があります。
一方、末尾再帰版では式の長さは一定です。必要な値は引数として次の再帰関数に渡されており、スタックの値は次の再帰関数を呼び出した後に使用されることはありません。なので、VMやコンパイラは不要になったスタックを破棄したり、再利用したりして、無駄なスタックの生成を抑制することができます。
末尾再帰しないとどうなるか
再帰する回数が増えた場合に、「StackOverflowError」になる可能性があります。
// 末尾再帰版だと最適化されるため、余計なスタックが生成されない // 引数を増やすと、末尾再帰でない版はjava.lang.StackOverflowErrorになる。 // 末尾再帰版はOK count = 10000 println( sumTailCalls( count )) // OK println( sum( count )) // これはjava.lang.StackOverflowError
実行結果です。
50005000 java.lang.StackOverflowError at TailCallsSample$.sum$1(TailCallsSample.scala:9) at TailCallsSample$.sum$1(TailCallsSample.scala:9) at TailCallsSample$.sum$1(TailCallsSample.scala:9) at TailCallsSample$.sum$1(TailCallsSample.scala:9)
なので、再帰回数が多くなりそうな関数は末尾再帰になるように実装しましょう!というか、推定再帰回数に関係なく、とりあえず末尾再帰になるようにしときましょう!
おまけ:Javaではどうなの?
JavaVMにそんな機能あったっけ?と思って試してみた。
// 1から引数までの総和を計算する関数 // 末尾再帰版 static int sumTailCalls( int i ) { return _sumTailCalls( i, 0 ); } static int _sumTailCalls( int i, int total ) { return i == 0 ? total : _sumTailCalls( i-1, total+i ); } // メイン public static void main( String[] args ) { int i = 4; System.out.println( sumTailCalls(i) ); i = 100000; System.out.println( sumTailCalls(i) ); }
実行結果です。そんな機能はないようだ。
10 Exception in thread "main" java.lang.StackOverflowError at Sum._sumTailCalls(Sum.java:19) at Sum._sumTailCalls(Sum.java:19) at Sum._sumTailCalls(Sum.java:19) at Sum._sumTailCalls(Sum.java:19)