使用 torch.Tensor
pytorch、张量与 torch.Tensor
Comfy 的所有核心数值计算都是由 pytorch 完成的。如果你的自定义节点需要深入 stable diffusion 的底层,你就需要熟悉这个库,这远超本简介的范围。
不过,许多自定义节点都需要操作图像、潜变量和蒙版,这些在内部都表示为 torch.Tensor
,因此你可能需要收藏
torch.Tensor 的官方文档。
什么是张量?
torch.Tensor
表示张量,张量是向量或矩阵在任意维度上的数学泛化。张量的 秩(rank)是它的维度数量(所以向量秩为 1,矩阵秩为 2);它的 形状(shape)描述了每个维度的大小。
因此,一个 RGB 图像(高为 H,宽为 W)可以被看作是三组数组(每个颜色通道一组),每组大小为 H x W,可以表示为形状为 [H,W,3]
的张量。在 Comfy 中,图像几乎总是以批量(batch)形式出现(即使批量中只有一张图)。torch
总是将批量维放在第一位,所以 Comfy 的图像形状为 [B,H,W,3]
,通常写作 [B,H,W,C]
,其中 C 代表通道数(Channels)。
squeeze、unsqueeze 与 reshape
如果张量的某个维度大小为 1(称为折叠维度),那么去掉这个维度后的张量与原张量等价(比如只有一张图片的批量其实就是一张图片)。去除这种折叠维度称为 squeeze,插入一个这样的维度称为 unsqueeze。
将同样的数据以不同的形状表示称为 reshape。通常你需要了解底层数据结构,因此请谨慎操作!
重要符号说明
torch.Tensor
支持大多数 Python 的切片符号、迭代和其他常见的类列表操作。张量还有一个 .shape
属性,返回其大小,类型为 torch.Size
(它是 tuple
的子类,可以当作元组使用)。
还有一些你经常会见到的重要符号(其中几个在标准 Python 里不常见,但在处理张量时很常用):
torch.Tensor
支持在切片符号中使用None
,表示插入一个大小为 1 的新维度。:
在切片张量时常用,表示”保留整个维度”。就像 Python 里的a[start:end]
,但省略了起止点。...
表示”未指定数量的所有维度”。所以a[0, ...]
会提取批量中的第一个元素,无论有多少维度。- 在需要传递形状的函数中,形状通常以
tuple
形式传递,其中某个维度可以用-1
,表示该维度的大小由数据总量自动推算。
元素级操作
许多 torch.Tensor
的二元操作(包括 ’+’, ’-’, ’*’, ’/’ 和 ’==‘)都是元素级的(即对每个元素独立操作)。操作数必须是形状相同的两个张量,或一个张量和一个标量。所以:
张量的布尔值
你可能熟悉 Python 列表的真值:非空列表为 True
,None
或 []
为 False
。而 torch.Tensor
(只要有多个元素)没有定义的真值。你需要用 .all()
或 .any()
来合并元素级的真值:
这也意味着你需要用 if a is not None:
而不是 if a:
来判断一个张量变量是否已被赋值。