深入理解numpy中argmax的具體使用

一、基本介紹

numpy中的argmax簡而言之就是返回最大值的索引,當使用np.argmax(axis),這裡方向axis的指定往往讓人不理解。
簡而言之:這裡axis可以讓我們從、或者是深度方向來看一個高維數組。

二、代碼實驗

1、一維數組情況

在這裡插入圖片描述

簡單一維情況,np.argmax()直接返回最大值的索引,不指定axis可以認為是將數組拉平之後尋找最大值的索引

1.1、axis=0

當我們指定axis=0時,其實是在中作比較,尋找最大的的索引

在這裡插入圖片描述

當然對於這個一維情況沒有什麼影響。

1.2、axis=1

在這裡插入圖片描述

當我們指定axis=1的時候報錯瞭,這是因為我們的a是一維數組,沒有axis=1這個軸,可見當我們使用np.argmax()時axis的指定不能超過所需要排序的數組

2、二維數組情況

在這裡插入圖片描述

不指定axis就是相當於把二維數組拉平,直接選取最大值的索引

2.1、axis=0

在這裡插入圖片描述

指定axis=0就是比較,返回索引中的最大值

在這裡插入圖片描述

我們改寫一個b中的元素,我們期望的結果是[2,2,1,2]

在這裡插入圖片描述

實際結果和我們期望相符合

2.2、axis=1

在這裡插入圖片描述

指定axis=0就是比較,返回索引中的最大值

3、三維數組情況

一個三維數組可以視作一張圖片,它的三個維度分別為(high, width, channels) 分別表示圖像的高、寬、通道數(深度)。常見的彩色圖像都有三個通道,我們以常見的RGB圖像為例構建一個數組。

在這裡插入圖片描述

直接使用np.argmax(),就是之間將三維數組拉平,尋找最大值的索引

3.1、axis=0

單獨查看c的三個通道的數據,如圖所示

在這裡插入圖片描述

對於三個通道取axis=0意味分別比較列返回行的最大值索引

在這裡插入圖片描述

我們期望的返回值應該是[[1,1,1,],[1,1,1],[1,1,1]],實際的結果和我們的期望一致

在這裡插入圖片描述

3.2、axis=1

在這裡插入圖片描述

對於三個通道取axis=1意味分別比較行返回列的最大值索引

我們期望的結果是[[2, 2, 2],[2, 2, 2],[2, 2, 2],[2, 2, 2]],,實際的結果和我們的期望一致

在這裡插入圖片描述

3.3、axis=2

取axis=2意味著我們從圖像的深度方向(通道方向)來進行比較,可以認為三個數組的疊在一起的,分別對應channel0,channel1,channel2而我們取最大值的索引就是返回對應pixel像素所在的通道索引

在這裡插入圖片描述

c的channel2所有的像素值均大於其他兩個channel所有返回值應該是[[2,2,2,],[2,2,2,],[2,2,2,],[2,2,2,]],實際結果和我的期望一致

在這裡插入圖片描述

3.4、axis=-1

axis=-1即是反過來看軸,對於三維情況axis=-1axis=2一致

在這裡插入圖片描述

其他
對於二維情況axis=-1anxis=1一致
對於一維情況axis=0anxis=-1一致

四、Reference

https://blog.csdn.net/weixin_39190382/article/details/105854567

https://www.cnblogs.com/zhouyang209117/p/6512302.html

PS:補充

1.對一個一維向量

import numpy as np
a = np.array([3, 1, 2, 4, 6, 1])
b=np.argmax(a)#取出a中元素最大值所對應的索引,此時最大值位6,其對應的位置索引值為4,(索引值默認從0開始)
print(b)#4

2.對2維向量(通常意義下的矩陣)a[][]

import numpy as np
a = np.array([[1, 5, 5, 2],
              [9, 6, 2, 8],
              [3, 7, 9, 1]])
b=np.argmax(a, axis=0)#對二維矩陣來講a[0][1]會有兩個索引方向,第一個方向為a[0],默認按列方向搜索最大值
#a的第一列為1,9,3,最大值為9,所在位置為1,
#a的第一列為5,6,7,最大值為7,所在位置為2,
#此此類推,因為a有4列,所以得到的b為1行4列,
print(b)#[1 2 2 1]
 
c=np.argmax(a, axis=1)#現在按照a[0][1]中的a[1]方向,即行方向搜索最大值,
#a的第一行為1,5,5,2,最大值為5(雖然有2個5,但取第一個5所在的位置),索引值為1,
#a的第2行為9,6,2,8,最大值為9,索引值為0,
#因為a有3行,所以得到的c有3個值,即為1行3列
print(c)#[1 0 2]

3.對於三維矩陣a[0][1][2],情況最為復制,但在lstm中應用最廣

import numpy as np
a = np.array([
              [
                  [1, 5, 5, 2],
                  [9, -6, 2, 8],
                  [-3, 7, -9, 1]
              ],
 
              [
                  [-1, 7, -5, 2],
                  [9, 6, 2, 8],
                  [3, 7, 9, 1]
              ],
            [
                  [21, 6, -5, 2],
                  [9, 36, 2, 8],
                  [3, 7, 79, 1]
              ]
            ])
b=np.argmax(a, axis=0)#對於三維度矩陣,a有三個方向a[0][1][2]
#當axis=0時,是在a[0]方向上找最大值,即兩個矩陣做比較,具體
#(1)比較3個矩陣的第一行,即拿[1, 5, 5, 2],
#                         [-1, 7, -5, 2],
#                         [21, 6, -5, 2],
#再比較每一列的最大值在那個矩陣中,可以看出第一列1,-2,21最大值為21,在第三個矩陣中,索引值為2
#第2列5,7,6最大值為7,在第2個矩陣中,索引值為1.....,最終得出比較結果[2 1 0 0]
#再拿出三個矩陣的第二行,按照上述方法,得出比較結果 [0 2 0 0]
#一共有三個,所以最終得到的結果b就為3行4列矩陣
print(b)
#[[0 0 0 0]
 #[0 1 0 0]
 #[1 0 1 0]]
 
c=np.argmax(a, axis=1)#對於三維度矩陣,a有三個方向a[0][1][2]
#當axis=1時,是在a[1]方向上找最大值,即在列方向比較,此時就是指在每個矩陣內部的列方向上進行比較
#(1)看第一個矩陣
                  # [1, 5, 5, 2],
                  # [9, -6, 2, 8],
                  # [-3, 7, -9, 1]
#比較每一列的最大值,可以看出第一列1,9,-3最大值為9,,索引值為1
#第2列5,-6,7最大值為7,,索引值為2
# 因此對第一個矩陣,找出索引結果為[1,2,0,1]
#再拿出2個,按照上述方法,得出比較結果 [1 0 2 1]
#一共有三個,所以最終得到的結果b就為3行4列矩陣
print(c)
#[[1 2 0 1]
 # [1 0 2 1]
 # [0 1 2 1]]
 
d=np.argmax(a, axis=2)#對於三維度矩陣,a有三個方向a[0][1][2]
#當axis=2時,是在a[2]方向上找最大值,即在行方向比較,此時就是指在每個矩陣內部的行方向上進行比較
#(1)看第一個矩陣
                  # [1, 5, 5, 2],
                  # [9, -6, 2, 8],
                  # [-3, 7, -9, 1]
#尋找第一行的最大值,可以看出第一行[1, 5, 5, 2]最大值為5,,索引值為1
#第2行[9, -6, 2, 8],最大值為9,,索引值為0
# 因此對第一個矩陣,找出行最大索引結果為[1,0,1]
#再拿出2個矩陣,按照上述方法,得出比較結果 [1 0 2 1]
#一共有三個,所以最終得到的結果d就為3行3列矩陣
print(d)
# [[1 0 1]
#  [1 0 2]
#  [0 1 2]]
###################################################################
#最後一種情況,指定矩陣a[0, -1, :],第一個數字0代表取出第一個矩陣(從前面可以看出a有3個矩陣)為
# [1, 5, 5, 2],
# [9, -6, 2, 8],
# [-3, 7, -9, 1]
#第二個數字“-1”代表拿出倒數第一行,為
# [-3, 7, -9, 1]
#這一行的最大索引值為1
 
# ,-1,代表最後一行
m=np.argmax(a[0, -1, :])
print(m)#1
 
#h,取a的第2個矩陣
# [-1, 7, -5, 2],
# [9, 6, 2, 8],
# [3, 7, 9, 1]
#的第3行
# [3, 7, 9, 1]
#的最大值為9,索引為2
h=np.argmax(a[1, 2, :])
print(h)#2
 
g=np.argmax(a[1,:, 2])#g,取出矩陣a,第2個矩陣的第3列為-5,2,9,最大值為9,索引為2
print(g)#2

到此這篇關於深入理解numpy中argmax的具體使用的文章就介紹到這瞭,更多相關numpy argmax內容請搜索WalkonNet以前的文章或繼續瀏覽下面的相關文章希望大傢以後多多支持WalkonNet!

推薦閱讀: