카테고리 없음

Tensorflow 2.5 GradientTape()에서 속도가 느려지는데...

미친토끼 2021. 8. 10. 18:39

자연어 처리(NLP) 분야에서 채팅 모델을 개발하고 있는데, 아직은 공부하는 단계다.

 

며칠 전에 기존에 문제 많던 GTX 1060 3GB를 처분하고,  RTX 3060 12GB를 구입했는데, 딥러닝의 텐서 플로 가속에서 만족스러웠다.

 

현재 들여다보고 있는 코드는 '텐서플로 2와 머신러닝으로 시작하는 자연어처리' 책의 seq2seq 방식 채팅 모듈인데, 불만이 많은 책이다. 기본적으로 이 책 저자의 국어 실력부터가....(중략) 게다가,

 

이 책의 채팅 모듈은 훈련 모델과 채팅 실행 모듈이 분리되어 있지 않아 꽤 불편한 코드다. 훈련 코드에서도 훈련 제개 모드를 지원했으면 했는데, subclassing 방식으로 코딩을 해놔서 훈련 모델 저장 및 불러오기가 쉽지가 않다. (그러면서 subclassing 방식의 자유도가 높은 점만 칭송하고 있다. subclassing 방식이 코딩하기엔 편하긴 하지... 하지만 훈련 모델을 저장하고 불러오려고 해보라. 잘 되나...? 저장한 가중치를 다른 코드에서 불러오려고 해도 그것도 만만치 않더라.)

 

어쨌던 이 책의 부족한 점, 즉 채팅 코드 분리 및 훈련 재개 모드 지원을 위해 좀더 공부를 할 필요가 있어서 아래 코드를 뜯어보고 있다.

 

https://github.com/milinzhang/simple-seq2seq-chatbot

 

GitHub - milinzhang/simple-seq2seq-chatbot: a simple seq2seq model based on tensorflow 2, using cornell movie dialog corpus

a simple seq2seq model based on tensorflow 2, using cornell movie dialog corpus - GitHub - milinzhang/simple-seq2seq-chatbot: a simple seq2seq model based on tensorflow 2, using cornell movie dialo...

github.com

 

이 분의 코드를 하나하나 뜯어보면서, tensorflow의 편리한 model.fit() 방식이 아닌, GradientTape() 방식을 공부하면서 느낀 점은, 지난해 미분 공부해놓길 잘했다는 점이다. GradientTape은 pyTorch의 코딩 방식과 비슷해서 많이 어렵지는 않은데, 어쨌던 복잡한 방식이라 하루이틀 공부해보니 꽤 재미있는 녀석이다.

 

문제는 위의 코드를 실행하면 CPU 방식보다 그다지 빨라지지 않는다는 점이다. RTX 3060이면 CPU 사용방식보다 20배 가량 빨라야 되는데 겨우 2배 정도 빠른 것이다.

model.fit() 방식은 매우 빠른데, GradientTape() 방식 코딩의 단점인가 싶었는데, 아뿔사, 원인을 찾았다. 

 

def onehotencoding(matrix, dim):
  onehot = np.zeros((matrix.shape[0], matrix.shape[1], dim)) # (32, 22, 8000) 3차원 원핫 
  for i, seq in enumerate(matrix):
    for j, idx in enumerate(seq):
      if idx > 0:
        onehot[i][j][idx-1] = 1
  return onehot  


for epoch in range(epochs):
	....
    for (batch, (input_q, input_a, target_a):
    	....
        target_a_onehot = onehotencoding(target_a, vocab_size)
        # target_a_onehot = tf.one_hot(target_a, vocab_size)

decoder의 target 데이터를 원핫 인코딩 하고 있는데, tf.one_hot()을 사용하지 않고 직접 코드를 작성하고 있다.

보시다시피 3차원 텐서를 만들어서 원핫 인코딩을 시시콜콜하게 하고 있다. 그것을 for문으로 epochs을 돌릴 때 이 함수를 호출하고 있는 것이다. tf.one_hot() 함수로 바꾸어주니 엄청나게 속도가 빨라졌다. ( 주석처리한 # target_a_onehot = tf.one_hot(...) 코드이다.)

 

텐서 플로 코드가 최적화가 잘 되어 있고 처리 속도 향상을 위해 인덱스 해싱 기법, 데이터 캐싱 기법 등을 동원했을 것이다.

원핫을 텐서 플로 함수를 사용했을 때는 CPU만으로 돌릴 때는 2~3배 속도 향상이 있었고, GPU로 돌렸을 대는 약 50배 속도 향상이 있었다. GradientTape() 의 문제가 아니고, 엄청나게 시간을 잡아먹는 느린 수작업 코드가 떡하니 가운데 틀어박혀 있었던 것이다. 이런 것들을 잘 잡아내어 최적화 잘 하는 사람이 진정 고수이리라.

 

문제를 하나 해결해놓으니 왜 이리 기분이 좋은지...^^