Wie verwende ich die Methode „torch.argmax()“ in PyTorch?

Wie Verwende Ich Die Methode Torch Argmax In Pytorch



In PyTorch ist das „ Torch.argmax() Die Methode ist eine integrierte Funktion, die Indizes der Maximalwerte eines bestimmten Tensors über eine bestimmte Dimension zurückgibt. Benutzer verwenden diese Funktion, wenn sie mit Tensoren arbeiten und den Index des Maximalwerts entlang der gegebenen Dimension eines Tensors ermitteln möchten. Darüber hinaus kann diese Methode auch für die Klassifizierung nützlich sein, wenn Benutzer wissen möchten, welche Klasse die höchste Wahrscheinlichkeit hat.

In diesem Blog wird die Methode zur Verwendung der Methode „torch.argmax()“ in PyTorch veranschaulicht.

Wie verwende ich die Methode „torch.argmax()“ in PyTorch?

Die Methode „torch.argmax()“ nimmt einen beliebigen 1D- oder 2D-Tensor als Eingabe und gibt einen Tensor zurück, der die Indizes/Indizes der Maximalwerte entlang der angegebenen Dimension enthält.







Die Syntax der Methode „torch.argmax()“ ist unten angegeben:



Fackel. argmax ( < input_tensor > )

Um diese Methode in PyTorch zu verwenden, gehen Sie zum besseren Verständnis die folgenden Beispiele durch:



Beispiel 1: Verwenden Sie die Methode „torch.argmax()“ mit 1D-Tensor

Im ersten Beispiel erstellen wir einen 1D-Tensor und verwenden damit die Methode „torch.argmax()“. Folgen wir der folgenden Schritt-für-Schritt-Anleitung:





Schritt 1: PyTorch-Bibliothek importieren

Importieren Sie zunächst „ Fackel ”-Bibliothek zur Verwendung der Methode „torch.argmax()“:

importieren Fackel

Schritt 2: Erstellen Sie einen 1D-Tensor

Erstellen Sie dann einen 1D-Tensor und drucken Sie seine Elemente. Hier erstellen wir Folgendes: „ Tens1 ” Tensor aus einer Liste mit dem „ Torch.tensor() ” Funktion:



Tens1 = Fackel. Tensor ( [ 5 , 0 , - 8 , 1 , 9 , 7 ] )

drucken ( Tens1 )

Dadurch wurde ein 1D-Tensor erstellt, wie unten dargestellt:

Schritt 3: Finden Sie Indizes mit maximalem Wert

Nutzen Sie nun die „ Torch.argmax() ”-Funktion, um den Index/die Indizes des Maximalwerts in der „ Tens1 ” Tensor:

T1_ind = Fackel. argmax ( Tens1 )

Schritt 4: Index des Maximalwerts drucken

Zeigen Sie abschließend den Index des Maximalwerts im Eingabetensor an:

drucken ( „Indizes:“ , T1_ind )

Die folgende Ausgabe zeigt den Index des Maximalwerts im „ Tens1 „Tensor, d. h. 4. Dies bedeutet, dass der höchste Wert des Tensors beim 4. Index liegt, der „ 9 ”:

Beispiel 2: Verwenden Sie die Methode „torch.argmax()“ mit 2D-Tensor

Im zweiten Beispiel erstellen wir einen 2D-Tensor und verwenden damit die Methode „torch.argmax()“. Folgen wir den bereitgestellten Schritten:

Schritt 1: PyTorch-Bibliothek importieren

Importieren Sie zunächst „ Fackel ”-Bibliothek zur Verwendung der Methode „torch.argmax()“:

importieren Fackel

Schritt 2: Erstellen Sie einen 2D-Tensor

Verwenden Sie dann die „ Torch.tensor() ”-Funktion zum Erstellen eines 2D-Tensors und Drucken seiner Elemente. Hier erstellen wir Folgendes: „ Zehner2 „2D-Tensor:

Zehner2 = Fackel. Tensor ( [ [ 4 , 1 , - 7 ] , [ fünfzehn , 6 , 0 ] , [ - 7 , 9 , 2 ] ] )

drucken ( Zehner2 )

Dadurch wurde ein 2D-Tensor erstellt, wie unten dargestellt:

Schritt 3: Finden Sie Indizes mit maximalem Wert

Suchen Sie nun den Index des Maximalwerts im „ Zehner2 ”-Tensor unter Verwendung des „ Torch.argmax() ” Funktion:

T2_ind = Fackel. argmax ( Zehner2 )

Schritt 4: Index des Maximalwerts drucken

Zeigen Sie abschließend den Index des Maximalwerts im Eingabetensor an:

drucken ( „Indizes:“ , T2_ind )

Gemäß der folgenden Ausgabe ist der Index des Maximalwerts im „ Zehner2 Der Tensor ist „3“. Das bedeutet, dass der höchste Wert des Tensors beim 3. Index liegt, also „ fünfzehn ”:

Schritt 5: Finden Sie Indizes mit maximalem Wert entlang der Spalten

Darüber hinaus können Benutzer auch die Indizes/Indizes der Maximalwerte entlang jeder Spalte eines Tensors finden. Wir können zum Beispiel „ dim=0 ”-Argument mit der Funktion „torch.argmax()“. Es findet die Indizes der Maximalwerte entlang der Spalten im „ Zehner2 ” Tensor und gibt dann diese Indizes aus:

col_index = Fackel. argmax ( Zehner2 , schwach = 0 )

drucken ( „Indizes in Spalten:“ , col_index )

Die folgende Ausgabe zeigt die Indizes der Maximalwerte entlang jeder Spalte des Tensors:

Schritt 6: Finden Sie Indizes mit maximalem Wert entlang der Zeilen

Ebenso können Benutzer auch die Indizes/Indizes der Maximalwerte entlang jeder Zeile eines Tensors finden. Verwenden Sie zum Beispiel „ dim=1 ”-Argument mit der Funktion „torch.argmax()“, um die Indizes der Maximalwerte entlang der Zeilen im Tensor „Tens2“ zu finden und diese Indizes dann auszugeben:

row_index = Fackel. argmax ( Zehner2 , schwach = 1 )

drucken ( „Indizes in Zeilen:“ , row_index )

Die Indizes des Maximalwerts entlang jeder Zeile eines „Tens2“-Tensors sind unten zu sehen:

Wir haben die Methode zur Verwendung der Methode „torch.argmax()“ in PyTorch ausführlich erklärt.

Notiz : Hier können Sie auf unser Google Colab Notebook zugreifen Verknüpfung .

Abschluss

Um die Methode „torch.argmax()“ in PyTorch zu verwenden, importieren Sie zunächst die „ Fackel ' Bibliothek. Erstellen Sie dann den gewünschten 1D- oder 2D-Tensor und sehen Sie sich seine Elemente an. Als nächstes verwenden Sie die „ Torch.argmax() ”-Methode zum Finden/Berechnen der Indizes/Indizes der Maximalwerte im Tensor. Darüber hinaus können Benutzer auch die Indizes des Maximalwerts entlang jeder Zeile oder Spalte im Tensor finden, indem sie „ schwach ' Streit. Zeigen Sie abschließend den Index des Maximalwerts im Eingabetensor an. In diesem Blog wurde die Methode zur Verwendung der Methode „torch.argmax()“ in PyTorch veranschaulicht.