-
[tensorflow] What argmax means?(axis) ?AI 2020. 3. 25. 10:31
텐서플로우 예제를 실습 하던 도중
아래와 같은 코드를 접했다.
epoch_accuracy(tf.argmax(model(x), axis=1, output_type=tf.int32), y)
epoch_accuracy는 매트릭스를 이용해 정확도를 측정하는 함수이다. 참조
이 함수는 1차원 배열 형태인 2개의 파라미터를 받으며, 예측된 값과 y값을 넘겨주면 된다.
만약 label이 3개인 경우 [1, 0, 0, 1, 1, 2] 이런 형태가 될 것이다.
y값은 그냥 넘겨주면 된다.
예측값은 다차원의 배열이기 때문에 argmax를 사용해 1차원의 배열로 만드는 동시에 최대값만 남긴다.
argmax에서 2번째 인자로 axis(중심축)를 받는데 이 매개변수가 주요 포인트이다.(헷갈릴 수 있다.)
2차원 배열의 경우 rank가 2이기 때문에 행/열 그러니까 axis값으로 0과 1을 줄 수 있다.
빠른 이해를 위해 예를 바로 들겠다.
import tensorflow as tf sess = tf.Session() sess.run(tf.global_variables_initializer()) x = [ [2, 4, 7], [8, 3, 1] ] # case axis=0 argmax(x, axis=0).eval(session=sess) #result: array([1, 0, 0]) # case axis=1 argmax(x, axis=1).eval(session=sess) #result: array([2, 0])
위의 결과를 보면 axis 값에 따라 결과가 달라지는 것을 확인 할 수 있다.
- axis 값이 0이면,
세로방향(열)으로 최대값의 인덱스를 리턴한다.
[ [2, 4, 7] , [8, 3, 1] ] -> [ maxIdx(2, 8), maxIdx(4, 3), maxIdx(7, 1) ] -> [ 1, 0, 0 ]
- axis 값이 1이면,
가로방향(행)으로 최대값의 인덱스를 리턴한다.
[ [2, 4, 7] , [8, 3, 1] ] -> [ maxIdx(2, 4, 7), maxIdx(8, 3, 1) ] -> [ 2, 0 ]
알고나면 간단한 것을..!
'AI' 카테고리의 다른 글