Skip to content

Commit d63e0bd

Browse files
authored
[Edit] PyTorch: .mean() (#7623)
* [Edit] PyTorch: .mean() * slight faq update * added backlink ---------
1 parent 78017a4 commit d63e0bd

File tree

1 file changed

+107
-19
lines changed
  • content/pytorch/concepts/tensors/terms/mean

1 file changed

+107
-19
lines changed
Lines changed: 107 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
---
22
Title: '.mean()'
3-
Description: 'Calculates the mean of all elements in a PyTorch tensor or along a specified dimension.'
3+
Description: 'Calculates the mean of all elements or along a specified dimension in a PyTorch tensor.'
44
Subjects:
55
- 'AI'
66
- 'Data Science'
@@ -14,22 +14,29 @@ CatalogContent:
1414
- 'paths/data-science'
1515
---
1616

17-
The **.mean()** method in PyTorch computes the arithmetic mean (average) of tensor elements. It can calculate the mean for all elements in the tensor or along a specified dimension. This method is widely used in data preprocessing and analysis for summarizing data.
17+
The **`torch.mean()`** method in PyTorch computes the arithmetic mean (average) of a given [tensor](https://www.codecademy.com/resources/docs/pytorch/tensors). It can calculate the mean of all elements or along a specified dimension in the tensor. This method is widely used in data preprocessing and analysis for summarizing data.
1818

19-
## Syntax
19+
## `torch.mean()` Syntax
2020

2121
```pseudo
22-
tensor.mean(dim=None, keepdim=False)
22+
torch.mean(input, dim, keepdim=False, *, dtype=None, out=None)
2323
```
2424

25-
- `dim` (optional): The dimension along which the mean is computed. If not specified, the mean of all elements is calculated.
26-
- `keepdim` (optional): If `True`, retains the reduced dimension with size `1`. Defaults to `False`.
25+
**Parameters:**
2726

28-
The function returns a tensor containing the mean value(s).
27+
- `input`: The input tensor.
28+
- `dim` (Optional): The dimension along which the mean is computed. If not specified, the mean of all elements is calculated.
29+
- `keepdim` (Optional): If `True`, retains the reduced dimension(s) with size `1`. Defaults to `False`.
30+
- `dtype` (Optional): The desired data type for the output tensor.
31+
- `out` (Optional): The output tensor.
2932

30-
## Example
33+
**Return value:**
3134

32-
This example demonstrates calculating the mean of all elements in a tensor and along a specific dimension:
35+
The `torch.mean()` method returns a tensor containing the mean value(s).
36+
37+
## Example 1: Mean of All Elements Using `torch.mean()`
38+
39+
This example calculates the mean of all elements in a tensor using `torch.mean()`:
3340

3441
```py
3542
import torch
@@ -38,23 +45,104 @@ import torch
3845
tensor = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
3946

4047
# Calculate the mean of all elements
41-
mean_all = tensor.mean()
42-
43-
# Calculate the mean along dimension 0 (columns)
44-
mean_dim0 = tensor.mean(dim=0)
48+
mean_all = torch.mean(tensor)
4549

4650
print("Mean of all elements:", mean_all)
47-
print("Mean along dimension 0:", mean_dim0)
4851
```
4952

50-
This example results in the following output:
53+
Here is the output:
5154

5255
```shell
5356
Mean of all elements: tensor(2.5000)
54-
Mean along dimension 0: tensor([2.0000, 3.0000])
5557
```
5658

57-
In this example:
59+
## Example 2: Mean Along Columns Using `torch.mean()`
60+
61+
This example calculates the mean along dimension `0` (columns) in a tensor using `torch.mean()`:
62+
63+
```py
64+
import torch
65+
66+
# Create a tensor
67+
tensor = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
68+
69+
# Calculate the mean along dimension 0 (columns)
70+
mean_dim0 = torch.mean(tensor, dim=0)
71+
72+
print("Mean along columns:", mean_dim0)
73+
```
74+
75+
Here is the output:
76+
77+
```shell
78+
Mean along columns: tensor([2., 3.])
79+
```
80+
81+
## Example 3: Mean Along Rows Using `torch.mean()`
82+
83+
This example calculates the mean along dimension `1` (rows) in a tensor using `torch.mean()`:
84+
85+
```py
86+
import torch
87+
88+
# Create a tensor
89+
tensor = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
90+
91+
# Calculate the mean along dimension 1 (rows)
92+
mean_dim1 = torch.mean(tensor, dim=1)
93+
94+
print("Mean along rows:", mean_dim1)
95+
```
96+
97+
Here is the output:
98+
99+
```shell
100+
Mean along rows: tensor([1.5000, 3.5000])
101+
```
102+
103+
## Frequently Asked Questions
104+
105+
### 1. What is the mean function in PyTorch?
106+
107+
`torch.mean()` computes the arithmetic mean (average) of a given tensor. By default, it calculates the mean of all elements in the tensor:
108+
109+
```py
110+
import torch
111+
112+
# Create a tensor
113+
x = torch.tensor([1., 2., 3., 4.])
114+
115+
# Calculate the mean of all elements
116+
print(torch.mean(x)) # tensor(2.5000)
117+
```
118+
119+
### 2. How do I compute the mean along a specific axis using `torch.mean()`?
120+
121+
To compute the mean along a specific axis, Use the `dim` parameter with `torch.mean()`:
122+
123+
```py
124+
import torch
125+
126+
# Create a tensor
127+
x = torch.tensor([[1., 2.], [3., 4.]])
128+
129+
# Calculate the mean along dimension 0 (columns)
130+
print(torch.mean(x, dim=0)) # tensor([2., 3.])
131+
132+
# Calculate the mean along dimension 1 (rows)
133+
print(torch.mean(x, dim=1)) # tensor([1.5000, 3.5000])
134+
```
135+
136+
### 3. What does `keepdim=True` do in `torch.mean()`?
137+
138+
`keepdim=True` in `torch.mean()` keeps the reduced dimension(s) with size `1`:
139+
140+
```py
141+
import torch
142+
143+
# Create a tensor
144+
x = torch.tensor([[1., 2.], [3., 4.]])
58145

59-
- `mean_all` computes the mean of all elements in the tensor.
60-
- `mean_dim0` computes the mean along each column (dimension 0), reducing the rows. This makes `.mean()` a versatile tool for data analysis.
146+
# Calculate the mean along rows with keepdim=True
147+
print(torch.mean(x, dim=1, keepdim=True)) # tensor([[1.5000], [3.5000]])
148+
```

0 commit comments

Comments
 (0)