Pytorch - Argmax

Pytorch - Argmax
“En este tutorial de Pytorch, veremos cómo devolver las posiciones de índice de los valores máximos de un tensor usando argmax ().

Pytorch es un marco de código abierto disponible con un lenguaje de programación de Python. Podemos procesar los datos en Pytorch en forma de tensor.

Un tensor es una matriz multidimensional que se utiliza para almacenar los datos. Entonces, para usar un tensor, tenemos que importar el módulo de antorcha.

Para crear un tensor, el método utilizado es tensor () "

Sintaxis:

antorcha.Tensor (datos)

Donde los datos son una matriz multidimensional.

argmax ()

ArgMax () en Pytorch se usa para devolver el índice del valor máximo de todos los elementos en el tensor de entrada.

Sintaxis:

antorcha.Argmax (Tensor, Dim, Keepdim)

Dónde

  1. El tensor es el tensor de entrada
  2. Dim es reducir la dimensión. dim = 0 Especifica la comparación de columna, que obtendrá el índice para el valor máximo a lo largo de una columna, y Dim = 1 especifica la comparación de filas, que obtendrá el índice para el valor máximo a lo largo de la fila.
  3. KeepDim verifica si el tensor de salida tiene dimensión (dim) retenida o no

Ejemplo 1

En este ejemplo, crearemos un tensor con 2 dimensiones que tengan 3 filas y 5 columnas y aplicaremos argmax () en filas y columnas.

#módulo de antorcha de Import
antorcha de importación
#cree un tensor con 2 dimensiones (3 * 5)
#Con elementos aleatorios usando la función randn ()
datos = antorcha.Randn (3,5)
#mostrar
Imprimir (datos)
#get el índice máximo a lo largo de las columnas con ArgMax
imprimir (antorcha.ArgMax (datos, dim = 0))
#Obtenga el índice máximo a lo largo de las filas con argmax
imprimir (antorcha.ArgMax (datos, dim = 1))

Producción:

tensor ([[0.6699, 1.3390, -1.0658, -1.8200, 0.6544],
[-0.3117, 0.2488, 0.2677, 0.2568, 0.5337],
[-1.0966, 1.8024, -0.7538, -0.2553, -1.0591]])
tensor ([0, 2, 1, 1, 0])
Tensor ([1, 4, 1])

Podemos ver que los valores máximos presentes en el índice a lo largo de las columnas son:

  1. Valor máximo - 0.6699. Su índice es 0.
  2. Valor máximo - 1.8024. Su índice es 2.
  3. Valor máximo - 0.2677. Su índice es 1.
  4. Valor máximo - 0.2568. Su índice es 1.
  5. Valor máximo - 0.6544. Su índice es 0.

Del mismo modo, los valores máximos presentes en el índice a lo largo de las filas son:

  1. Valor máximo - 1.3390. Su índice es 1.
  2. Valor máximo - 0.5337. Su índice es 4.
  3. Valor máximo - 1.8024. Su índice es 1.

Ejemplo 2

Crear tensor con 5 * 5 matriz y aplicar argMax ()

#módulo de antorcha de Import
antorcha de importación
#cree un tensor con 2 dimensiones (5 * 5)
#Con elementos aleatorios usando la función randn ()
datos = antorcha.Randn (5,5)
#mostrar
Imprimir (datos)
#get el índice máximo a lo largo de las columnas con ArgMax
imprimir (antorcha.ArgMax (datos, dim = 0))
#Obtenga el índice máximo a lo largo de las filas con argmax
imprimir (antorcha.ArgMax (datos, dim = 1))

Producción:

tensor ([[-0.9553, -0.2611, -2.1233, -0.5208, -0.3458],
[-0.5466, -1.6395, 0.2576, -0.3123, 0.6785],
[-0.4574, 1.5301, 0.4812, 0.3434, 0.1388],
[0.8364, 0.3821, 0.1529, 1.4529, 0.3747],
[-1.4991, -1.8821, -0.2861, -0.4067, 1.1323]])
Tensor ([3, 2, 2, 3, 4])
Tensor ([1, 4, 1, 3, 4])

Podemos ver que los valores máximos presentes en el índice a lo largo de las columnas son:

  1. Valor máximo - 0.8364. Su índice es 3.
  2. Valor máximo - 1.5301. Su índice es 2.
  3. Valor máximo - 0.4812. Su índice es 2.
  4. Valor máximo - 1.4529. Su índice es 3.
  5. Valor máximo - 1.1323. Su índice es 4.

Del mismo modo, los valores máximos presentes en el índice a lo largo de las filas son:

  1. Valor máximo - -0.2611. Su índice es 1.
  2. Valor máximo - 0.6785. Su índice es 4.
  3. Valor máximo - 1.5301. Su índice es 1.
  4. Valor máximo - 1.4529. Su índice es 3.
  5. Valor máximo - 1.1323. Su índice es 4.

Trabajar con CPU

Si desea ejecutar una función argMax () en la CPU, entonces tenemos que crear un tensor con una función CPU (). Esto se ejecutará en una máquina CPU.

Cuando estamos creando un tensor, en este momento, podemos usar la función CPU ().

Sintaxis:

antorcha.Tensor (datos).UPC()

Ejemplo 1

Crear tensor con 5 * 5 matriz con cpu () y aplique argMax ()
#módulo de antorcha de Import
antorcha de importación
#cree un tensor con 2 dimensiones (5 * 5)
#Con elementos aleatorios usando la función randn () con cpu ()
datos = antorcha.Randn (5,5).UPC()
#mostrar
Imprimir (datos)
#get el índice máximo a lo largo de las columnas con ArgMax
imprimir (antorcha.ArgMax (datos, dim = 0))
#Obtenga el índice máximo a lo largo de las filas con argmax
imprimir (antorcha.ArgMax (datos, dim = 1))

Producción:

tensor ([[-0.2213, 1.6140, -0.0774, 0.4135, 0.1379],
[-0.4415, -2.5789, 0.8294, -0.9309, 1.3535],
[-1.3256, -0.7233, -0.9713, 1.0742, 1.9350],
[-0.7126, -1.3336, 0.7371, -0.2253, 0.1675],
[-0.1174, -0.5773, 0.8887, -0.2563, -1.0667]])
Tensor ([4, 0, 4, 2, 2])
Tensor ([1, 4, 4, 2, 2])

Podemos ver que los valores máximos presentes en el índice a lo largo de las columnas son:

  1. Valor máximo - -0.1174. Su índice es 4.
  2. Valor máximo - 1.6140. Su índice es 0.
  3. Valor máximo - 0.8887. Su índice es 4.
  4. Valor máximo - 1.0742. Su índice es 2.
  5. Valor máximo - 1.9350. Su índice es 2.

Del mismo modo, los valores máximos presentes en el índice a lo largo de las filas son:

  1. Valor máximo - 1.6140. Su índice es 1.
  2. Valor máximo - 1.3535. Su índice es 4.
  3. Valor máximo - 1.9350. Su índice es 4.
  4. Valor máximo - 0.7371. Su índice es 2.
  5. Valor máximo - 0.8887. Su índice es 2.

Ejemplo 2

En este ejemplo, crearemos un tensor con 2 dimensiones que tengan 3 filas y 5 columnas usando la función CPU () y aplicaremos argmax () en filas y columnas.

#módulo de antorcha de Import
antorcha de importación
#cree un tensor con 2 dimensiones (3 * 5)
#Con elementos aleatorios usando randn () con la función CPU ()
datos = antorcha.Randn (3,5).UPC()
#mostrar
Imprimir (datos)
#get el índice máximo a lo largo de las columnas con ArgMax
imprimir (antorcha.ArgMax (datos, dim = 0))
#Obtenga el índice máximo a lo largo de las filas con argmax
imprimir (antorcha.ArgMax (datos, dim = 1))

Producción:

tensor ([[0.6699, 1.3390, -1.0658, -1.8200, 0.6544],
[-0.3117, 0.2488, 0.2677, 0.2568, 0.5337],
[-1.0966, 1.8024, -0.7538, -0.2553, -1.0591]])
tensor ([0, 2, 1, 1, 0])
Tensor ([1, 4, 1])

Podemos ver que los valores máximos presentes en el índice a lo largo de las columnas son:

  1. Valor máximo - 0.6699. Su índice es 0.
  2. Valor máximo - 1.8024. Su índice es 2.
  3. Valor máximo - 0.2677. Su índice es 1.
  4. Valor máximo - 0.2568. Su índice es 1.
  5. Valor máximo - 0.6544. Su índice es 0.

Del mismo modo, los valores máximos presentes en el índice a lo largo de las filas son:

  1. Valor máximo - 1.3390. Su índice es 1.
  2. Valor máximo - 0.5337. Su índice es 4.
  3. Valor máximo - 1.8024. Su índice es 1.

Conclusión

En esta lección de Pytorch, vimos lo que argmax () y cómo aplicar argMax () en un tensor para devolver índices de valores máximos en columnas y filas.

También creamos un tensor con la función CPU () y los índices devueltos de valores máximos. Dim es el parámetro utilizado para devolver índices de valores máximos en las columnas cuando se establece en 0 y devuelve los índices de valores máximos en las filas cuando se establece en 1.