メインコンテンツへスキップ

pytorch、張量、および torch.Tensor

Comfy におけるすべての核心的な数値計算は pytorch によって行われています。カスタムノードが stable diffusion の内部構造に深く関わる場合、このライブラリに精通する必要がありますが、それはこの導入文の範囲を大きく超えています。 しかし、多くのカスタムノードは画像、潜変量、マスクを操作する必要があり、これらは内部で torch.Tensor として表現されているため、torch.Tensor のドキュメント をブックマークしておくとよいでしょう。

張量とは?

torch.Tensor は張量を表します。張量は、ベクトルや行列を任意の次元数に一般化した数学的な概念です。張量の (rank)はそれが持つ次元の数であり(つまりベクトルは 1、行列は 2)、形状(shape)は各次元のサイズを記述します。 したがって、RGB 画像(高さ H、幅 W)は、3 つの配列(各色チャンネルごとに 1 つ)であり、それぞれが H x W であるため、形状 [H,W,3] の張量として表現できると考えられます。Comfy では、画像はほぼ常にバッチとして扱われます(バッチに画像が 1 枚しか含まれていない場合でも)。torch は常にバッチ次元を最初に配置するため、Comfy の画像の 形状[B,H,W,3] となり、通常は C をチャンネル(Channels)を表すものとして [B,H,W,C] と記述されます。

squeeze、unsqueeze、および reshape

張量の次元のサイズが 1 である場合(縮退次元と呼ばれます)、その次元を削除した張量と同じものとなります(バッチに画像が 1 枚しかない場合は、単なる画像です)。このような縮退次元を削除することを squeezing(スクイーズ)と呼び、挿入することを unsqueezing(アンスクイーズ)と呼びます。
一部の torch コードやカスタムノードの作者は、次元が縮退している場合に squeezed された張量を返すことがあります——例えばバッチにメンバーが 1 つしかない場合などです。これはバグの一般的な原因となります!
同じデータを異なる形状で表現することを reshape(リシェイプ)と呼びます。これには多くの場合、基礎となるデータ構造を知る必要があるため、注意して扱ってください!

重要な記法

torch.Tensor は、ほとんどの Python のスライス記法、イテレーション、以及其他常见的类列表操作 (その他の一般的なリストのような操作) をサポートしています。張量には .shape 属性もあり、サイズを torch.Size として返します(これは tuple のサブクラスであり、そのように扱えます)。 他にも、よく目にする重要な記法がいくつかあります(これらのいくつかは標準的な Python の記法としてはあまり一般的ではありませんが、張量を扱う際には非常に頻繁に見られます)
  • torch.Tensor はスライス記法において None を使用して、サイズ 1 の次元の挿入を示すことができます。
  • : は張量をスライスする際によく使用されます。これは単に「次元全体を保持する」ことを意味します。Python の a[start:end] を使用するようなものですが、開始点と終点を省略した形です。
  • ... は「指定されていない数の次元全体」を表します。したがって、a[0, ...] は次元の数に関係なく、バッチから最初のアイテムを抽出します。
  • 形状の受け渡しを必要とするメソッドでは、しばしば次元の tuple として渡され、その中で単一の次元にサイズ -1 を指定できます。これは、この次元のサイズがデータの総サイズに基づいて計算されるべきであることを示します。
>>> a = torch.Tensor((1,2))
>>> a.shape
torch.Size([2])
>>> a[:,None].shape 
torch.Size([2, 1])
>>> a.reshape((1,-1)).shape
torch.Size([1, 2])

要素ごとの操作

torch.Tensor における多くの二項演算(’+’, ’-’, ’*’, ’/’、’==’ などを含む)は要素ごとに適用されます(各要素に独立して適用されます)。オペランドは、同じ形状の 2 つの張量 または 張量とスカラーの いずれか でなければなりません。したがって:
>>> import torch
>>> a = torch.Tensor((1,2))
>>> b = torch.Tensor((3,2))
>>> a*b
tensor([3., 4.])
>>> a/b
tensor([0.3333, 1.0000])
>>> a==b
tensor([False,  True])
>>> a==1
tensor([ True, False])
>>> c = torch.Tensor((3,2,1)) 
>>> a==c
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 0

張量の真偽値

張量の「真偽値(truthiness)」は Python リストのそれと同じではありません。
Python リストの真偽値については馴染みがあるかもしれません。空でないリストは TrueNone または []False となります。対照的に、torch.Tensor(要素が 2 つ以上ある場合)には定義された真偽値がありません。代わりに、.all() または .any() を使用して要素ごとの真偽値を結合する必要があります:
>>> a = torch.Tensor((1,2))
>>> print("yes" if a else "no")
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Boolean value of Tensor with more than one value is ambiguous
>>> a.all()
tensor(False)
>>> a.any()
tensor(True)
これはつまり、張量変数が設定されているかどうかを判定するには、if a: ではなく if a is not None: を使用する必要があることを意味します。