pytorch、张量与 torch.Tensor

Comfy 的所有核心数值计算都是由 pytorch 完成的。如果你的自定义节点需要深入 stable diffusion 的底层,你就需要熟悉这个库,这远超本简介的范围。

不过,许多自定义节点都需要操作图像、潜变量和蒙版,这些在内部都表示为 torch.Tensor,因此你可能需要收藏 torch.Tensor 的官方文档

什么是张量?

torch.Tensor 表示张量,张量是向量或矩阵在任意维度上的数学泛化。张量的 (rank)是它的维度数量(所以向量秩为 1,矩阵秩为 2);它的 形状(shape)描述了每个维度的大小。

因此,一个 RGB 图像(高为 H,宽为 W)可以被看作是三组数组(每个颜色通道一组),每组大小为 H x W,可以表示为形状为 [H,W,3] 的张量。在 Comfy 中,图像几乎总是以批量(batch)形式出现(即使批量中只有一张图)。torch 总是将批量维放在第一位,所以 Comfy 的图像形状为 [B,H,W,3],通常写作 [B,H,W,C],其中 C 代表通道数(Channels)。

squeeze、unsqueeze 与 reshape

如果张量的某个维度大小为 1(称为折叠维度),那么去掉这个维度后的张量与原张量等价(比如只有一张图片的批量其实就是一张图片)。去除这种折叠维度称为 squeeze,插入一个这样的维度称为 unsqueeze。

有些 torch 代码和自定义节点作者会在某个维度折叠时返回 squeeze 过的张量——比如批量只有一个成员时。这是常见的 bug 来源!

将同样的数据以不同的形状表示称为 reshape。通常你需要了解底层数据结构,因此请谨慎操作!

重要符号说明

torch.Tensor 支持大多数 Python 的切片符号、迭代和其他常见的类列表操作。张量还有一个 .shape 属性,返回其大小,类型为 torch.Size(它是 tuple 的子类,可以当作元组使用)。

还有一些你经常会见到的重要符号(其中几个在标准 Python 里不常见,但在处理张量时很常用):

  • torch.Tensor 支持在切片符号中使用 None,表示插入一个大小为 1 的新维度。
  • : 在切片张量时常用,表示”保留整个维度”。就像 Python 里的 a[start:end],但省略了起止点。
  • ... 表示”未指定数量的所有维度”。所以 a[0, ...] 会提取批量中的第一个元素,无论有多少维度。
  • 在需要传递形状的函数中,形状通常以 tuple 形式传递,其中某个维度可以用 -1,表示该维度的大小由数据总量自动推算。
>>> a = torch.Tensor((1,2))
>>> a.shape
torch.Size([2])
>>> a[:,None].shape 
torch.Size([2, 1])
>>> a.reshape((1,-1)).shape
torch.Size([1, 2])

元素级操作

许多 torch.Tensor 的二元操作(包括 ’+’, ’-’, ’*’, ’/’ 和 ’==‘)都是元素级的(即对每个元素独立操作)。操作数必须是形状相同的两个张量,或一个张量和一个标量。所以:

>>> import torch
>>> a = torch.Tensor((1,2))
>>> b = torch.Tensor((3,2))
>>> a*b
tensor([3., 4.])
>>> a/b
tensor([0.3333, 1.0000])
>>> a==b
tensor([False,  True])
>>> a==1
tensor([ True, False])
>>> c = torch.Tensor((3,2,1)) 
>>> a==c
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 0

张量的布尔值

张量的”真值”与 Python 列表的真值不同。

你可能熟悉 Python 列表的真值:非空列表为 TrueNone[]False。而 torch.Tensor(只要有多个元素)没有定义的真值。你需要用 .all().any() 来合并元素级的真值:

>>> a = torch.Tensor((1,2))
>>> print("yes" if a else "no")
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Boolean value of Tensor with more than one value is ambiguous
>>> a.all()
tensor(False)
>>> a.any()
tensor(True)

这也意味着你需要用 if a is not None: 而不是 if a: 来判断一个张量变量是否已被赋值。