torch.bmm

torch.bmm(input, mat2, out_dtype=None, *, out=None) → Tensor

Performs a batch matrix-matrix product of matrices stored in input and mat2.

input and mat2 must be 3-D tensors each containing the same number of matrices.

If input is a $(b \times n \times m)$ tensor and mat2 is a $(b \times m \times p)$ tensor, out will be a $(b \times n \times p)$ tensor.

$$\text{out}_i = \text{input}_i \mathbin{@} \text{mat2}_i$$
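The per-batch definition can be checked directly. The following is a minimal sketch, not part of the official example: the variable names (`inp`, `reference`) and the tolerance are illustrative choices. It compares torch.bmm against multiplying each pair of matrices separately.

```python
import torch

torch.manual_seed(0)
b, n, m, p = 10, 3, 4, 5
inp = torch.randn(b, n, m)    # plays the role of `input`
mat2 = torch.randn(b, m, p)

out = torch.bmm(inp, mat2)

# Reference: multiply each pair of matrices on its own, then stack the results.
reference = torch.stack([inp[i] @ mat2[i] for i in range(b)])

print(out.shape)
print(torch.allclose(out, reference, atol=1e-5))
```

The comparison uses a small tolerance because the batched and per-matrix kernels may accumulate floating-point sums in a different order.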

This operator supports TensorFloat32.

On certain ROCm devices, when using float16 inputs, this operator will use a different precision for the backward pass.

Note

This function does not broadcast. For broadcasting matrix products, see torch.matmul().
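As a rough illustration of the difference, the sketch below multiplies a batch of matrices by a single 2-D matrix. The tensor names are hypothetical, and the explicit `expand` call is just one way to build a matching batch for torch.bmm.

```python
import torch

batched = torch.randn(10, 3, 4)
single = torch.randn(4, 5)  # one 2-D matrix shared by every batch entry

# torch.bmm needs two 3-D tensors with equal batch sizes, so this call fails.
try:
    torch.bmm(batched, single)
    bmm_failed = False
except RuntimeError:
    bmm_failed = True

# torch.matmul broadcasts the 2-D matrix across the batch dimension.
broadcast = torch.matmul(batched, single)

# To stay with torch.bmm, expand the matrix to an explicit batch first.
# expand() creates a view, so no data is copied here.
explicit = torch.bmm(batched, single.expand(10, 4, 5))

print(bmm_failed)
print(broadcast.shape, explicit.shape)
```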

Parameters
  • input (Tensor) – the first batch of matrices to be multiplied

  • mat2 (Tensor) – the second batch of matrices to be multiplied

  • out_dtype (dtype, optional) – the dtype of the output tensor. Supported only on CUDA, and only for torch.float32 output given torch.float16 or torch.bfloat16 input dtypes

Keyword Arguments

out (Tensor, optional) – the output tensor.
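A short sketch of the out keyword follows, assuming a preallocated buffer of the result shape and the same dtype as the inputs. The names `inp` and `buf` are illustrative only.

```python
import torch

inp = torch.randn(10, 3, 4)
mat2 = torch.randn(10, 4, 5)

# Preallocate a buffer with the result shape (b, n, p) and a matching dtype.
buf = torch.empty(10, 3, 5)

# The product is written into buf instead of a freshly allocated tensor.
result = torch.bmm(inp, mat2, out=buf)

# The returned tensor shares its memory with the buffer.
same_storage = result.data_ptr() == buf.data_ptr()
print(same_storage)
```

Reusing one buffer this way can avoid repeated allocations when the same shapes are multiplied many times in a loop.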

Example:

>>> input = torch.randn(10, 3, 4)
>>> mat2 = torch.randn(10, 4, 5)
>>> res = torch.bmm(input, mat2)
>>> res.size()
torch.Size([10, 3, 5])