본문 바로가기
AI/Machine Learning

cross_val_score란?

by comflex 2025. 3. 29.
728x90
반응형

cross-validation

cross_val_score란?

cross_val_score는 사이킷런(scikit-learn)에서 제공하는 함수로, 교차 검증(Cross Validation)을 수행하여 모델의 성능을 평가하는 데 사용됩니다. 데이터를 여러 개의 폴드(fold)로 나누고, 각 폴드를 한 번씩 검증 데이터로 사용하면서 모델을 훈련하고 평가하는 방식입니다.


1. cross_val_score의 동작 방식

교차 검증은 데이터셋을 K개의 폴드로 나누고, 각 폴드를 한 번씩 검증 데이터로 사용하며 K번 훈련 및 평가를 반복하는 방식입니다. 이를 통해 모델이 특정 데이터에 과적합(overfitting)되는 것을 방지하고 일반화 성능을 향상시킬 수 있습니다.

주요 과정:

  1. 데이터를 K개의 폴드로 나눕니다.
  2. K-1개의 폴드로 모델을 학습합니다.
  3. 남은 1개의 폴드로 모델을 평가합니다.
  4. 위 과정을 K번 반복하며, 각 평가 점수를 기록합니다.
  5. 최종적으로 평균 점수를 계산하여 모델의 성능을 측정합니다.

2. cross_val_score 사용법

2.1. 기본적인 사용법

from sklearn.model_selection import cross_val_score
from sklearn.linear_model import LinearRegression
from sklearn.datasets import make_regression

# 샘플 데이터 생성
X, y = make_regression(n_samples=100, n_features=1, noise=0.1, random_state=42)

# 모델 생성
model = LinearRegression()

# 교차 검증 수행 (5-폴드 교차 검증)
scores = cross_val_score(model, X, y, cv=5, scoring='r2')

print("각 폴드의 R^2 점수:", scores)
print("평균 R^2 점수:", scores.mean())

2.2. 다양한 scoring 옵션 사용하기

cross_val_score에서는 scoring 매개변수를 통해 다양한 성능 지표를 설정할 수 있습니다.

from sklearn.metrics import make_scorer, mean_squared_error

# MSE(평균 제곱 오차)로 평가
mse_scores = cross_val_score(model, X, y, cv=5, scoring=make_scorer(mean_squared_error))
print("각 폴드의 MSE:", mse_scores)
print("평균 MSE:", mse_scores.mean())
반응형
728x90

2.3. 분류 문제에서의 사용

from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier

# 아이리스 데이터셋 로드
iris = load_iris()
X, y = iris.data, iris.target

# 랜덤 포레스트 모델 생성
clf = RandomForestClassifier(n_estimators=100, random_state=42)

# 교차 검증 수행 (정확도 평가)
scores = cross_val_score(clf, X, y, cv=5, scoring='accuracy')

print("각 폴드의 정확도:", scores)
print("평균 정확도:", scores.mean())

3. cross_val_score 사용 시 주의할 점

  1. 데이터 분할 방식 선택:
    • cv=5로 설정하면 5-폴드 교차 검증을 수행합니다.
    • 데이터 크기가 작다면 cv 값을 너무 크게 설정하지 않는 것이 좋습니다.
  2. 불균형 데이터 처리:
    • 분류 문제에서 클래스 비율이 불균형할 경우, StratifiedKFold를 사용하여 클래스 분포를 유지하는 것이 좋습니다.
  3. 평균 성능과 편차 분석:
    • scores.mean() 뿐만 아니라 scores.std()도 함께 확인하여 모델 성능의 변동성을 분석할 수 있습니다.
print("평균 성능:", scores.mean())
print("표준 편차:", scores.std())
728x90
반응형