본문 바로가기
Kaggle마스터가 되기 위한 몸부림/실용 머신러닝 A to Z

[scikit-learn] transform()과 fit_transform()의 차이는 무엇일까?

by Steve-Lee 2021. 2. 1.

왜 scikit-learn에서 모델을 학습할 때, train dataset에서만 .fit_transform()메서드를 사용하는 건가요?

TL;DR

안녕하세요 steve-lee입니다. 실용 머신러닝 A to Z 첫번 째 시간은 scikit-learn에서 자주 사용하는 transform(), fit_transform() 메서드와 관련된 질문에 대한 답을 찾아가는 시간입니다. 


시작하며

  • 오늘의 주제는 scikit-learn의 transfrom() 메서드fit_transform() 메서드의 차이를 알아보는 것입니다
  • 우리는 Machine Learning 프로젝트 또는 Kaggle 등을 할 때, 머신러닝을 위한 파이썬 프로그래밍 라이브러리로 scikit-learn을 많이 사용합니다
  • scikit-learn을 사용하다보면 transform()과 fit_transform() 을 자주 사용하게 되는데요. 한번쯤은 이런 생각을 하셨을 것 같습니다

왜 fit_transform()은 training data에만 사용하는 걸까? 왜 transfrom은 test data에만 사용하는 걸까?

  • 오늘은 위의 물음에 대한 정답을 찾아가는 시간이 될 것 같습니다
  • 그럼 시작하겠습니다

 

제 블로그의 글이 도움이 되셨다면 좋아요와 댓글 부탁드리겠습니다.글을 지속적으로 작성하는데 많은 동기부여가 됩니다.감사합니다. 오늘도 편안한 하루 보내세요!🙏

 

scikit-learn, fit_transform(), transform()

  • 우선 scikit-learn부터 간단히 살펴보겠습니다
  • scikit-learn이란 Machine Learning(이하 ML)을 위한 가장 강력한 라이브러리다라고 생각하면 좋을 것 같습니다
    • 특히 Python Programming을 하신다면 ML을 위한 최고의 라이브러리가 아닐까 생각하네요
  • 우리가 scikit-learn에서 제공하는 라이브러리등을 통해 ML을 할 때 아마도 가장 많이 마주치는 메서드는 fit_transform()transform()이 아닐까 생각합니다

fit_transform()과 transform()의 설명을 위해 scikit-learn에서 제공하는 sklearn.preprocessing.StandarrScaler() 클래스를 통해 train data와 test data를 스케일링 한다고 가정해 보겠습니다


Note

  1. Data Standardization이란 feature를 리스케일링 하여 feature의 평균이(mean) 0 분산이(variance) 1이 되게 만들어주는 과정입니다(일석이조입니다)
  2. 이러한 Standardization의 궁극적인 목표는 모든 feature들을 공통의 척도로 변경해 주는 것을 의미합니다. 즉 값의 범위의 차이를 왜곡하지 않으면서 모든 feature를 공통의 척도로 스케일을 해주는 것을 의미합니다
  3. sklearn.preprocessing.StandarScaler()는 각각의 feature마다 독립적으로 값을 중앙으로 옮기고 스케일링을 해줍니다
  • 수식은 아래와 같습니다

  • 수식으로보면 조금 더 이해하기 쉽습니다
  • 그림으로 이해하면 아래와 같습니다

값의 범위가 평균 0, 분산 1이 되도록 만들어 줍니다

 

fit_transform()

  • 결론부터 말씀드리자면 fit_transform()은 train dataset에서만 사용됩니다
  • 우리가 만든 모델은 train data에 있는 mean과 variance를 학습하게 됩니다
  • 이렇게 학습된 Scaler()의 parameter는 test data를 scale하는데 사용됩니다
  • 다시말해 train data로 학습된 Scaler()의 parameter를 통해 test data의 feature 값들이 스케일 되는 것입니다

transform()

  • train data로부터 학습된 mean값과 variance값을 test data에 적용하기 위해 transform() 메서드를 사용합니다

 

왜 test data에서는 fit_transform을 사용하지 않는가?🤔

  • 만약에 fit_transform을 test data에도 적용하게 된다면 test data로부터 새로운 mean값과 variance값을 얻게 되는 것입니다
  • 즉, 우리의 모델이 test data도 학습하게 되는 것입니다
  • test data는 'Surprise'한 데이터 셋입니다. 그런데 이 데이터마저 학습하게 된다면 우리의 모델이 처음 보는 데이터에 대해서 얼마나 성능이 좋은지 알 수 없게 되는 것입니다
  • 다시 말해, test data는 모델이 학습된 후에 평가할 때만 사용되어야 하는데 fit_transform을 test data에도 하게 된다면 모델의 성능을 평가할 수 없게 되는 것입니다

마치며...

그동안  scikit-learn의 모델들을 학습시킬 때 무의식중에 fit_transform()메서드와 transform()메서드를 사용했던 것 같습니다. Pseudo Lab에서 스터디원분께서 질문을 해주시지 않았더라면 두 메서드의 차이를 모른채 계속 사용했을 것 같습니다... 질문해주셔서 감사합니다...!!

 

머신러닝의 메카니즘을 이해한다면 fit_transform()메서드 transform()메서드의 차이를 보다 잘 이해할 수 있을 것 같습니다. 

  • 우리는 tran data를 통해 데이터의 패턴을 학습하고 test data를 통해 처음 보는 데이터에 대해서도 일반화된 성능을 얻길 원합니다
  • 모델링을 할 때도 train data로 모델의 파라미터를 학습시키고 test data에 대해서는 train data로 학습된 모델의 성능을 측정하길 원하는 것입니다
  • 따라서 fit_transform()메서드는 학습을 위한 train data에 사용되고 test data에서는 transform()메서드만 사용해줘야 합니다. 만약 test data에 대해서도 fit_transform()메서드를 사용하게 된다면 모델은 test data에 대해서도 학습을 하게 되는 꼴이 됩니다. (그렇다면 일반화된 성능을 기대할 순 없겠네요)

이상으로 포스팅을 마치도록 하겠습니다. 감사합니다.

P.S - 제가 잘 못 전달한 내용이 있거나 궁금하신 점이 있으시다면 얼마든지 댓글로 문의주신다면 감사하겠습니다!🙏

 

Reference

 

What and why behind fit_transform() vs transform() in scikit-learn !

Scikit-learn is the most useful library for machine learning in Python programming language. It has a lot of tools to build a machine…

towardsdatascience.com

 

댓글