관리 메뉴

Leo's Garage

[Numpy] Cost Function 연산 시, linalg.norm을 사용하는 이유 본문

Study/파이썬

[Numpy] Cost Function 연산 시, linalg.norm을 사용하는 이유

LeoBehindK 2024. 9. 18. 23:05
728x90
반응형

예시 코드는 아래와 같다. 

# PACKAGE
# First load the worksheet dependencies.
# Here is the activation function and its derivative.
sigma = lambda z : 1 / (1 + np.exp(-z))
d_sigma = lambda z : np.cosh(z/2)**(-2) / 4

# This function initialises the network with it's structure, it also resets any training already done.
def reset_network (n1 = 6, n2 = 7, random=np.random) :
    global W1, W2, W3, b1, b2, b3
    W1 = random.randn(n1, 1) / 2
    W2 = random.randn(n2, n1) / 2
    W3 = random.randn(2, n2) / 2
    b1 = random.randn(n1, 1) / 2
    b2 = random.randn(n2, 1) / 2
    b3 = random.randn(2, 1) / 2

# This function feeds forward each activation to the next layer. It returns all weighted sums and activations.
def network_function(a0) :
    z1 = W1 @ a0 + b1
    a1 = sigma(z1)
    z2 = W2 @ a1 + b2
    a2 = sigma(z2)
    z3 = W3 @ a2 + b3
    a3 = sigma(z3)
    return a0, z1, a1, z2, a2, z3, a3

# This is the cost function of a neural network with respect to a training set.
def cost(x, y) :
    return np.linalg.norm(network_function(x)[-1] - y)**2 / x.size

여기서 마지막에 np.linalg.norm 함수를 호출하는 부분이 있다. 

Cost function을 계산하는 부분에서 왜 linalg.norm 함수가 호출되었을까? 

 
Cost Function
 

np.linalg.norm 함수는 벡터나 행렬의 크기를 계산하는 함수로, 이 코드에서 사용된 이유는 신경망의 예측값과 실제 값(레이블)의 차이를 계산하는 데 있어서 그 차이를 유클리드 거리(Euclidean distance)로 표현하기 위해서이다.

구체적으로, 신경망 함수 network_function(x)는 입력 x에 대한 예측값을 출력하고, y는 실제 값(레이블)을 나타낸다. network_function(x)[-1]는 신경망의 최종 출력값을 의미하며, network_function(x)[-1] - y는 예측값과 실제 값 간의 차이(오차)를 나타낸다.

np.linalg.norm을 사용하여 이 차이를 유클리드 거리로 계산하면, 다차원 벡터일 경우 각 차원의 차이를 제곱하여 더한 후 제곱근을 취한 값이 나온다. 즉, 오차 벡터의 크기(또는 거리)를 구하는 것이고, 이는 신경망이 예측값을 실제 값과 얼마나 잘 맞추는지를 나타낸다. 이를 제곱(**2)하여 전체 오차의 제곱합(Sum of Squared Errors, SSE)을 구한 뒤, 학습 데이터의 크기(x.size)로 나누어 평균 오차를 계산하는 방식이다.

따라서 np.linalg.norm은 예측값과 실제 값 간의 오차 크기를 계산하는 중요한 부분이다.

 

728x90
반응형

'Study > 파이썬' 카테고리의 다른 글

[Numpy] 행렬 랭크 구하기  (0) 2024.09.19
[Numpy] Data Split  (1) 2024.09.15
[python] pathlib 사용 정리  (0) 2023.09.01
Comments