본문 바로가기
Java

재귀 호출 최적화(Tail-Call)

by ybs 2021. 5. 11.
반응형

일반적인 재귀 호출 방식은 입력 데이터가 매우 많은 경우에 StackOverflowError 가 발생할 위험이 있다.

TCO(Tail-Call Optimization) 기술을 사용하면 이 문제를 해결할 수 있다.

 

먼저 일반적인 재귀를 사용해서 팩토리얼을 계산해보자.

public class Factorial {

	public static void main(String[] args) {
		System.out.println(factorialRec(5)); // 120
	}

	public static int factorialRec(final int number) {
		if (number == 1) {
			return number;
		} else {
			return number * factorialRec(number-1); 
		}
	}
}

 

문제가 없어보이지만 숫자가 커지면 java.lang.StackOverflowError 가 발생한다.

이것은 재귀 자체의 문제가 아니라 재귀가 완료할 때까지 연산의 부분 결과를 계속 스택에 보관하고 있어야 한다는 것이 문제다.

 

return number * factorialRec(number-1); 

위 코드에서 실행하는 마지막 오퍼레이션은 곱셈(*)이다. number 에 값을 갖고 있는 상태에서 다음 factorialRec() 호출 결과를 기다린다.

결국 계속 호출할 때마다 콜 스택(call stack)에 저장하게 되고 팩토리얼을 구하려는 숫자가 커지면 스택은 폭발한다. 이 문제를 해결하려면 스택에 저장하지 않고 재귀를 사용할 수 있는 방법이 필요하다.

 

TCO 방식으로 바꾼 Factorial 클래스는 아래와 같다. TailCall 과 TailCalls 가 사용되고 이게 핵심이기 때문에 각각의 역할을 잘 이해해야 전체 Factorial 코드가 이해된다.

import static com.explore.recur.TailCalls.call;
import static com.explore.recur.TailCalls.done;

public class Factorial {

  public static int factorial(final int number) {
    return factorialTailRec(1, number).invoke();
  }
  
  private static TailCall<Integer> factorialTailRec(final int factorial, final int number) {
    if (number == 1)
      return done(factorial);
    else
      return call(() -> factorialTailRec(factorial * number, number - 1));
  }

  public static void main(final String[] args) {
    System.out.println(factorial(5));
  }
}

본격적으로 TailCall, TailCalls 를 알아보기 전에 대략적으로 코드를 살펴보자. main 문에서 factorial 을 호출하고, 실제적으로 factorialTailRec invoke 를 통해 실행이 된다. factorialTailRec 메서드 내부를 보면 number 가 1일 때 done 으로 끝내는 엔드조건이 있고 그 외에는 call 로 람다 함수가 호출된다(number 1씩 줄음).

 

factorialTailRec 메서드의 리턴 타입은 TailCall<Integer> 이다. 여기서 중요한 아이디어는 done() 메서드를 호출하면 재귀 종료 시그널을 보내고 call() 메서드를 계속 실행한다면 계속 재귀 호출을 요청한다.

 

먼저 TailCall 인터페이스를 살펴보자.

@FunctionalInterface
public interface TailCall<T> {

  TailCall<T> apply();

  default boolean isComplete() {
    return false;
  }

  default T result() {
    throw new Error("not implemented");
  }

  default T invoke() {
    return Stream.iterate(this, TailCall::apply)
            .filter(TailCall::isComplete)
            .findFirst()
            .get()
            .result();
  }
}

하나의 추상메서드와 3개의 디폴트 메서드가 있다. isComplete() 메서드는 항상 false 값을 리턴한다. result() 메서드는 재귀가 진행되는 동안 절대 호출되지 않기 때문에 예외를 발생시킨다. 이게 무슨말이냐면 재귀가 돌면서 여러 TailCall 인터페이스 구현체들이 존재하게 되는데 마지막 TailCall 인터페이스 구현체(종단) 에서만 result 가 사용된다.

 

중요한 작업은 invoke() 메서드에 있는 짧은 코드에서 이루어진다. 이 메서드는 apply() 메서드와 함께 동작하며 이 apply() 메서드는 다음 실행을 기다리고 있는 TailCall 인스턴스를 리턴한다. 여기서 말하는 apply 는 아래 람다 구현체를 말한다.

() -> factorialTailRec(factorial * number, number - 1)

 

invoke() 메서드는 두 가지를 책임져야 한다. 첫번째는 재귀 과정이 끝날 때까지 대기하고 있는 TailCall 재귀 메서드를 통해 반복적으로 이터레이션한다. 두 번째는 재귀 과정이 끝에 도달하면, 최종 결과(종단 TailCall 인스턴스의 result() 메서드에 있는)를 리턴해야 한다.

 

Stream 인터페이스는 정적 메서드 iterate() 를 가지며 이 메서드는 무한 스트림(infinite Stream)을 생성한다. 이 메서드는 두 개의 파라미터를 갖는데 하나는 컬렉션을 시작하기 위한 시드(seed) 값이고 다른 하나는 UnaryOperator 인터페이스의 인스턴스로서 컬렉션에 데이터를 공급하는 역할을 한다. iterate() 메서드가 리턴하는 스트림은 terminating 메서드를 사용하기 전까지 엘리먼트에 대한 생성을 지연한다.

 

이터레이션은 isComplete() 메서드가 완료됐다는 것을 리포트할 때까지 계속 실행된다. 그러나 TailCall 인터페이스에 있는 이 메서드의 default 구현은 항상 false 를 리턴한다. 즉, done() 메서드에서 TailCall 의 특별한 버전을 리턴하며 재귀 과정이 종료되었음을 알려주는 역할을 한다(isComplete 를 true 로 리턴). 또한 TailCall 의 마지막 단계에서는 apply() 메서드가 절대 호출되지 않기 때문에 예외를 발생시킨다.

public class TailCalls {
  public static <T> TailCall<T> call(final TailCall<T> nextCall) {
    return nextCall;
  }

  public static <T> TailCall<T> done(final T value) {
    return new TailCall<T>() {
      @Override
      public boolean isComplete() {
        return true;
      }

      @Override
      public T result() {
        return value;
      }

      @Override
      public TailCall<T> apply() {
        throw new Error("not implemented");
      }
    };
  }
}

 

5 팩토리얼을 계산한다고 했을 때 전체 flow  정리

factorialTailRec(1, 5) 가 실행되고, number가 1이 아니니까 else 로 call(() -> factorialTailRec(5, 4)); 가 수행된다.

TailCalls 클래스의 call 메서드를 통해 () -> factorialTailRec(5, 4) 구현체를 갖는 TailCall 인스턴스가 리턴이 되고

factorialTailRec 메서드가 리턴되면서 그 다음 invoke() 메서드가 실행된다. invoke() 메서드에서 Stream.iterate 로 TailCall::apply 가 실행되면서 factorialTailRec(5, 4) 가 실행이 된다. number 가 1이 아니니까 다시 call(() -> factorialTailRec(20, 3)) 이 리턴되고(call 메서드 안에 있는 구현체를 갖고 있는 TailCall 인스턴스가 리턴되고), TailCall 의 isComplete 가 false 이기 때문에 다시 TailCall::apply 가 실행되면서 factorialTailRec(20, 3) 이 실행이 된다. number가 1이 아니니까 다시 call(() -> factorialTailRec(60, 2)) 가 리턴되고 isComplete 가 false 이기 때문에 다시 TailCall::apply 가 실행되면서 factorialTailRec(60, 2) 이 실행이 된다. number 가 1이 아니니까 다시 call(() -> factorialTailRec(120, 1)) 이 리턴되고 isComplete 가 false 이기 때문에 다시 TailCall::apply 가 실행되면서 factorialTailRec(120, 1) 이 실행이 된다. 이제 number 가 1이니까 done 메서드로 120 이 전달되고

isComplete 가 true 이고 result 메서드 value가 120인 TailCall 인스턴스가 새롭게 만들어져 리턴된다. 이제 isComplete 가 true 이므로 findFirst().get().result() 를 통해 120이 리턴되고 invoke 메서드가 끝난다.

 

이렇게 하면 큰수의 팩토리얼 계산도 StackOverflowError 없이 가능하지만 한가지 더 신경써야 할 부분이 있다. 

바로 Integer 범위를 넘어서게 되는 Arithmetic Overflow 다. 이 문제 해결을 위해서는 BigInteger 로 바꿔줘야 한다.

public class BigFactorial {
  public static BigInteger decrement(final BigInteger number) {
    return number.subtract(BigInteger.ONE);
  }

  public static BigInteger multiply(final BigInteger first, final BigInteger second) {
    return first.multiply(second);
  }

  final static BigInteger ONE = BigInteger.ONE;
  final static BigInteger FIVE = new BigInteger("5");
  final static BigInteger TWENTYK = new BigInteger("20000");

  public static BigInteger factorial(final BigInteger number) {
    return factorialTailRec(BigInteger.ONE, number).invoke();
  }
  
  private static TailCall<BigInteger> factorialTailRec(
          final BigInteger factorial, final BigInteger number) {
    if (number.equals(BigInteger.ONE))
      return done(factorial);
    else
      return call(() ->
              factorialTailRec(multiply(factorial, number), decrement(number)));
  }

  public static void main(final String[] args) {
    System.out.println(factorial(FIVE));
    System.out.println(factorial(TWENTYK));
  }
}

 

마지막으로 이전에 썼던 글, reflection 사용해서 api 문서화 yangbongsoo.tistory.com/23에서 메서드 리턴타입을 알기 위해서 재귀 호출이 필요했었는데 그걸 TCO 방식으로 바꿔보자.

import static com.explore.reflection.TailCalls.*;

import java.lang.reflect.Method;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.List;

import org.springframework.http.ResponseEntity;

public class Main2 {
  public static void main(final String[] args) throws Exception {
    Method method = Main2.class.getDeclaredMethod("method1", null);
    System.out.println(getReturnType(method.getGenericReturnType(), new ReturnTypeDto()).invoke());
  }

  private ResponseEntity<ResultData<List<String>>> method1() {
    return null;
  }

  public static TailCall<ReturnTypeDto> getReturnType(Type genericReturnType, ReturnTypeDto returnTypeDto) {
    if (genericReturnType instanceof Class<?>) {
      returnTypeDto.setReturnType((Class<?>)genericReturnType);
      return done(returnTypeDto);
    }

    else if (genericReturnType instanceof ParameterizedType) {

      Type unWrappedType = (((ParameterizedType)genericReturnType).getActualTypeArguments()[0]);

      if (unWrappedType instanceof Class<?>) {
        returnTypeDto.setWrappingReturnType(null);
      } else {
        ParameterizedType unWrappedParameterizedType = (ParameterizedType)unWrappedType;
        // List 한정
        if (unWrappedParameterizedType.getRawType().getTypeName().equalsIgnoreCase("java.util.List")) {
          returnTypeDto.setWrappingReturnType(unWrappedParameterizedType);
          return call(() -> getReturnType(unWrappedParameterizedType.getActualTypeArguments()[0], returnTypeDto));
        }
      }

      return call(() -> getReturnType(unWrappedType, returnTypeDto));

    } else {
      returnTypeDto.setWrappingReturnType(null);
      returnTypeDto.setReturnType(null);
      return done(returnTypeDto);

    }
  }

  private static class ResultData<T> {
    private T data;

    public T getData() {
      return data;
    }

    public void setData(T data) {
      this.data = data;
    }
  }

  private static class ReturnTypeDto {
    private ParameterizedType wrappingReturnType;
    private Class<?> returnType;

    public ParameterizedType getWrappingReturnType() {
      return wrappingReturnType;
    }

    public void setWrappingReturnType(ParameterizedType wrappingReturnType) {
      this.wrappingReturnType = wrappingReturnType;
    }

    public Class<?> getReturnType() {
      return returnType;
    }

    public void setReturnType(Class<?> returnType) {
      this.returnType = returnType;
    }

    @Override
    public String toString() {
      return "ReturnTypeDto{" +
              "wrappingReturnType=" + wrappingReturnType +
              ", returnType=" + returnType +
              '}';
    }
  }
}

 

TailCall, TailCalls 은 그대로 재활용이 가능하다.

 

 

원문 : Functional Programming in Java8(출판사 : 루비페이퍼)

반응형