In this tutorial, you will learn how to use the NumPy argmax() function to find the index of the maximum element in arrays.

NumPy is a powerful library for scientific computing in Python; it provides N-dimensional arrays that are more performant than Python lists. One of the common operations you’ll perform when working with NumPy arrays is to find the maximum value in the array. However, you may sometimes want to find the index at which the maximum value occurs.

The argmax() function helps you find the index of the maximum in both one-dimensional and multidimensional arrays. Let’s proceed to learn how it works.

How to Find the Index of Maximum Element in a NumPy Array

To follow along with this tutorial, you need to have Python and NumPy installed. You can code along by starting a Python REPL or launching a Jupyter notebook.

First, let’s import NumPy under the usual alias np.

import numpy as np

You can use the NumPy max() function to get the maximum value in an array (optionally along a specific axis).

array_1 = np.array([1,5,7,2,10,9,8,4])
print(np.max(array_1))

# Output
10

In this case, np.max(array_1) returns 10, which is correct.

Suppose you’d like to find the index at which the maximum value occurs in the array. You can take the following two-step approach:

  1. Find the maximum element.
  2. Find the index of the maximum element.

In array_1, the maximum value 10 occurs at index 4. NumPy uses zero-based indexing: the first element is at index 0, the second at index 1, and so on.

[Figure: the elements of array_1 and their indices]

To find the index at which the maximum occurs, you can use the NumPy where() function. np.where(condition) returns a tuple of arrays, one per dimension, containing the indices where the condition is True.

Because array_1 is one-dimensional, this tuple holds a single array of indices, which you access at index 0. To find where the maximum value occurs, set the condition to array_1==10; recall that 10 is the maximum value in array_1.

# [0] selects the index array from the tuple; the second [0] takes its first entry
print(np.where(array_1==10)[0][0])

# Output
4

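We have used np.where() with only the condition. This works, but the NumPy documentation describes this form as a shorthand for np.asarray(condition).nonzero() and recommends calling nonzero() directly instead.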
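For the condition-only case, the NumPy documentation points to np.nonzero(). A minimal sketch, reusing the array_1 values from above, shows that it returns the same kind of tuple of index arrays:

```python
import numpy as np

array_1 = np.array([1, 5, 7, 2, 10, 9, 8, 4])

# np.nonzero() returns a tuple with one index array per dimension
indices = np.nonzero(array_1 == 10)
print(indices)

# For a 1D array the tuple holds a single array; take its first entry
print(indices[0][0])  # 4
```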

📑 Note: NumPy where() Function:

np.where(condition,x,y) returns:

– Elements from x when the condition is True, and

– Elements from y when the condition is False.
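Here is a small sketch of the three-argument form, using array_1 from above: it keeps the elements greater than 5 and replaces the rest with 0.

```python
import numpy as np

array_1 = np.array([1, 5, 7, 2, 10, 9, 8, 4])

# Take elements from array_1 where the condition is True, else take 0
result = np.where(array_1 > 5, array_1, 0)
print(result)  # [ 0  0  7  0 10  9  8  0]
```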

Therefore, chaining the np.max() and np.where() functions, we can find the maximum element, followed by the index at which it occurs.
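Put together, the two steps might look like the sketch below. Instead of hard-coding 10, it passes the result of np.max() into the condition, so it works for any array_1 values.

```python
import numpy as np

array_1 = np.array([1, 5, 7, 2, 10, 9, 8, 4])

# Step 1: find the maximum value
max_value = np.max(array_1)

# Step 2: find where it occurs; np.where() returns a tuple of index arrays
max_index = np.where(array_1 == max_value)[0][0]

print(max_value, max_index)  # 10 4
```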

Instead of the above two-step process, you can use the NumPy argmax() function to get the index of the maximum element in the array.

Syntax of the NumPy argmax() Function

The general syntax to use the NumPy argmax() function is as follows:

np.argmax(array, axis=None, out=None)
# we've imported numpy under the alias np

In the above syntax:

  • array is any valid NumPy array.
  • axis is an optional parameter. When working with multidimensional arrays, you can use the axis parameter to find the index of the maximum along a specific axis.
  • out is another optional parameter. You can set the out parameter to a NumPy array to store the output of the argmax() function.

Note: Starting with NumPy 1.22.0, argmax() also accepts a keepdims parameter. When you specify the axis parameter in the argmax() call, the array is reduced along that axis, so that axis is dropped from the result. Setting keepdims=True keeps the reduced axis as a dimension of length 1, so the output has the same number of dimensions as the input array and broadcasts correctly against it.
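A quick way to see the effect of keepdims is to compare output shapes. This sketch assumes NumPy 1.22.0 or later and uses a small 2 x 4 array:

```python
import numpy as np  # keepdims in argmax() requires NumPy 1.22.0 or later

a = np.array([[1, 5, 7, 2],
              [10, 9, 8, 4]])

# Without keepdims, axis 1 is removed from the result
print(np.argmax(a, axis=1).shape)                 # (2,)

# With keepdims=True, axis 1 is kept with length 1
print(np.argmax(a, axis=1, keepdims=True).shape)  # (2, 1)
```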

Using NumPy argmax() to Find the Index of the Maximum Element

#1. Let us use the NumPy argmax() function to find the index of the maximum element in array_1.

array_1 = np.array([1,5,7,2,10,9,8,4])
print(np.argmax(array_1))

# Output
4

The argmax() function returns 4, which is correct! ✅

#2. If we redefine array_1 such that 10 occurs twice, the argmax() function returns only the index of the first occurrence.

array_1 = np.array([1,5,7,2,10,10,8,4])
print(np.argmax(array_1))

# Output
4

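If you need every index at which the maximum occurs, not just the first, one option is np.flatnonzero() with an equality condition. A minimal sketch using the redefined array_1:

```python
import numpy as np

array_1 = np.array([1, 5, 7, 2, 10, 10, 8, 4])

# argmax() reports only the first occurrence of the maximum
print(np.argmax(array_1))  # 4

# np.flatnonzero() returns every index where the condition is True
all_max_indices = np.flatnonzero(array_1 == array_1.max())
print(all_max_indices)  # [4 5]
```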
For the rest of the examples, we’ll use the elements of array_1 we defined in example #1.

Using NumPy argmax() to Find the Index of the Maximum Element in a 2D Array

Let’s reshape the NumPy array array_1 into a two-dimensional array with two rows and four columns.

array_2 = array_1.reshape(2,4)
print(array_2)

# Output
[[ 1  5  7  2]
 [10  9  8  4]]

For a two-dimensional array, axis 0 denotes the rows and axis 1 denotes the columns. NumPy arrays follow zero-indexing. So the indices of the rows and columns for the NumPy array array_2 are as follows:

[Figure: row indices (axis 0) and column indices (axis 1) of array_2]
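You can also list each element of array_2 with its (row, column) index in code. This sketch uses np.ndenumerate(), which yields an index tuple and a value for every element:

```python
import numpy as np

array_2 = np.array([1, 5, 7, 2, 10, 9, 8, 4]).reshape(2, 4)

# np.ndenumerate() yields ((row, column), value) for every element
for (row, col), value in np.ndenumerate(array_2):
    print(f"row {row}, column {col}: {value}")
```

The first line printed is "row 0, column 0: 1", and the maximum shows up as "row 1, column 0: 10".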

Now, let’s call the argmax() function on the two-dimensional array, array_2.

print(np.argmax(array_2))

# Output
4

Even though we called argmax() on the two-dimensional array, it still returns 4. This is identical to the output for the one-dimensional array, array_1 from the previous section.

Why does this happen? 🤔

This is because we have not specified a value for the axis parameter. When axis is not set, argmax() returns the index of the maximum element in the flattened array.

What is a flattened array? If an N-dimensional array has shape d1 x d2 x … x dN, where d1, d2, …, dN are its sizes along the N dimensions, then the flattened array is a one-dimensional array of size d1 * d2 * … * dN that contains the same elements, read in row-major (C) order by default.

To see what the flattened array looks like for array_2, you can call the flatten() method, as shown below:

array_2.flatten()

# Output
array([ 1,  5,  7,  2, 10,  9,  8,  4])
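Because the default result is an index into this flattened array, you may want to map it back to a (row, column) position. One way to do this is np.unravel_index(), as in this sketch:

```python
import numpy as np

array_2 = np.array([1, 5, 7, 2, 10, 9, 8, 4]).reshape(2, 4)

flat_index = np.argmax(array_2)          # index into the flattened array
print(array_2.flatten()[flat_index])     # 10

# Convert the flat index back to a (row, column) pair for array_2's shape
row, col = np.unravel_index(flat_index, array_2.shape)
print(row, col)                          # 1 0
```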

Index of the Maximum Element Along the Rows (axis = 0)

Let’s proceed to find the index of the maximum element along the rows (axis = 0).

np.argmax(array_2,axis=0)

# Output
array([1, 1, 1, 1])

This output may be hard to interpret at first, so let's break it down.

We've set the axis parameter to zero (axis = 0), as we'd like to find the index of the maximum element along the rows. Therefore, for each of the four columns, the argmax() function returns the row number in which the maximum element occurs.

Let’s visualize this for better understanding.

[Figure: for each column of array_2, the row index (along axis 0) of the maximum value]

From the above diagram and the argmax() output, we have the following:

  • For the first column at index 0, the maximum value 10 occurs in the second row, at index = 1.
  • For the second column at index 1, the maximum value 9 occurs in the second row, at index = 1.
  • For the third and fourth columns at index 2 and 3, the maximum values 8 and 4 both occur in the second row, at index = 1.

This is why the output is array([1, 1, 1, 1]): in every column, the maximum element occurs in the second row.
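To check this result, you can use the returned row indices to pick out the maximum of each column and compare them with np.max() along the same axis. A minimal sketch:

```python
import numpy as np

array_2 = np.array([1, 5, 7, 2, 10, 9, 8, 4]).reshape(2, 4)

row_idx = np.argmax(array_2, axis=0)     # one row index per column
col_idx = np.arange(array_2.shape[1])    # 0, 1, 2, 3

# Pair each column with its argmax row to pick out the column maxima
col_max = array_2[row_idx, col_idx]
print(col_max)                                           # [10  9  8  4]
print(np.array_equal(col_max, np.max(array_2, axis=0)))  # True
```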

Index of the Maximum Element Along the Columns (axis = 1)

Next, let’s use the argmax() function to find the index of the maximum element along the columns.

Run the following code snippet and observe the output.

np.argmax(array_2,axis=1)

# Output
array([2, 0])

Can you parse the output?

We have set axis = 1 to compute the index of the maximum element along the columns.

The argmax() function returns, for each row, the column number in which the maximum value occurs.

Here’s a visual explanation:

[Figure: for each row of array_2, the column index (along axis 1) of the maximum value]

From the above diagram and the argmax() output, we have the following:

  • For the first row at index 0, the maximum value 7 occurs in the third column, at index = 2.
  • For the second row at index 1, the maximum value 10 occurs in the first column, at index = 0.

You should now be able to see what the output array([2, 0]) means.
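A related pattern, sketched below under the assumption of NumPy 1.22.0 or later, combines keepdims=True with np.take_along_axis() to look up the maximum value in each row from the indices that argmax() returns:

```python
import numpy as np  # keepdims in argmax() requires NumPy 1.22.0 or later

array_2 = np.array([1, 5, 7, 2, 10, 9, 8, 4]).reshape(2, 4)

# keepdims=True gives shape (2, 1), which matches array_2's number of dimensions
col_idx = np.argmax(array_2, axis=1, keepdims=True)

# Look up the value at each row's argmax column
row_max = np.take_along_axis(array_2, col_idx, axis=1)
print(row_max.ravel())  # [ 7 10]
```

Keeping the reduced axis is what lets take_along_axis() line the indices up with array_2 without any reshaping.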

Using the Optional out Parameter in NumPy argmax()

You can use the optional out parameter in the NumPy argmax() function to store the output in a NumPy array.

Let's initialize an array of zeros to store the output of the previous argmax() call, which finds the index of the maximum along the columns (axis = 1).

out_arr = np.zeros((2,))
print(out_arr)

# Output
[0. 0.]

Now, let's revisit the example of finding the index of the maximum element along the columns (axis = 1), this time setting out to the out_arr array we defined above.

np.argmax(array_2,axis=1,out=out_arr)

NumPy raises a TypeError because out_arr is an array of floats: np.zeros() creates float64 arrays by default, while argmax() needs an integer output array.

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/numpy/core/fromnumeric.py in _wrapfunc(obj, method, *args, **kwds)
     56     try:
---> 57         return bound(*args, **kwds)
     58     except TypeError:

TypeError: Cannot cast array data from dtype('float64') to dtype('int64') according to the rule 'safe'

Therefore, when you pass an output array through the out parameter, make sure it has the correct shape and data type. Because array indices are integers, set the dtype parameter to int when creating the output array.

out_arr = np.zeros((2,),dtype=int)
print(out_arr)

# Output
[0 0]

We can now go ahead and call the argmax() function with both the axis and out parameters, and this time, it runs without error.

np.argmax(array_2,axis=1,out=out_arr)

The output of the argmax() function can now be accessed in the array out_arr.

print(out_arr)
# Output
[2 0]
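For arrays with more dimensions, the out array must have the input's shape with the reduced axis removed. The sketch below uses a hypothetical helper, make_argmax_out(), to allocate such an array; it uses np.intp, NumPy's integer type for indexing, as the dtype.

```python
import numpy as np

def make_argmax_out(a, axis):
    """Hypothetical helper: allocate an out array for np.argmax(a, axis=axis)."""
    axis = axis % a.ndim  # allow negative axes such as -1
    # The result drops the reduced axis, and the indices are stored as np.intp
    shape = a.shape[:axis] + a.shape[axis + 1:]
    return np.empty(shape, dtype=np.intp)

a = np.arange(24).reshape(2, 3, 4)

out_arr = make_argmax_out(a, axis=1)
np.argmax(a, axis=1, out=out_arr)
print(out_arr.shape)  # (2, 4)
```

Here every entry of out_arr is 2, because the values in a increase along axis 1.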

Conclusion

I hope this tutorial helped you understand how to use the NumPy argmax() function. You can run the code examples in a Jupyter notebook.

Let’s review what we’ve learned.

  • The NumPy argmax() function returns the index of the maximum element in an array. If the maximum element occurs more than once in an array a, then np.argmax(a) returns the index of the first occurrence of the element.
  • When working with multidimensional arrays, you can use the optional axis parameter to get the index of the maximum element along a particular axis. For example, in a two-dimensional array: by setting axis = 0 and axis = 1, you can get the index of the maximum element along the rows and columns, respectively.
  • If you’d like to store the returned value in another array, you can set the optional out parameter to the output array. However, the output array should be of compatible shape and data type.
