NumPy で対角要素を取り除く
NumPy クイズです.
問題:(n, n)
行列 X
について,X
の i
行と i
列を同時に取り除いて得られる (n-1, n-1)
行列考えます.こういった行列は n
個ありますが,これらをすべて並べた (n, n-1, n-1)
配列を作ってください.
たとえば,X
が
array([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
ならば,出力は
array([[[ 5, 6, 7],
[ 9, 10, 11],
[13, 14, 15]],
[[ 0, 2, 3],
[ 8, 10, 11],
[12, 14, 15]],
[[ 0, 1, 3],
[ 4, 5, 7],
[12, 13, 15]],
[[ 0, 1, 2],
[ 4, 5, 6],
[ 8, 9, 10]]])
となってほしいです.
さて,この問題を考える前に,まず次の問題を考えます.
問題:(n, n)
行列 X
に対して,X
の対角要素を取り除いて (n, n-1)
行列を作ってください.
たとえば,最初に挙げた X
array([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
に対して,出力は
array([[ 1, 2, 3],
[ 4, 6, 7],
[ 8, 9, 11],
[12, 13, 14]])
となってほしい,ということです(各行で 0
, 5
, 10
, 15
がスキップされています).
一番簡単なのは,マスクを作って取り出す方法.以下のような感じです.
X[numpy.eye(n) == 0].reshape(n, n - 1)
この方法はわかりやすくてよいのですが,CuPy や PyTorch などで GPU 上で計算する場合には,ホストとデバイスの同期が挟まってしまう問題があります.同期は,bool 配列によるインデクシングで発生します.bool 配列によるインデクシングでは,出力サイズが True
の要素数になるので,配列の中身を見ないとサイズが決まりません.そこで,値が True
の要素数をカウントして,ホストに送り,メモリを確保してから,デバイスでコピーをする,というような手順が必要になります.というわけで,同期が気になるような場面では,できれば使いたくありません.
この非対角要素を取り出す操作を,サイズが静的に決まる配列だけで実行するコードとして以下があります(PyTorch のフォーラムで見つけた手法を NumPy 語に直したものです).
X.reshape(n * n)[:-1].reshape(n - 1, n + 1)[:, 1:].reshape(n, n - 1)
完全に呪文って感じですが,このコードはスライスと reshape
だけで書かれているので,ホスト・デバイス間の同期は発生しません.PyTorch でも同じコードで実行できます.
例に挙げた行列X
で,各ステップを追ってみましょう.
>>> X.reshape(n * n)
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
まず 1 本の長いベクトルにします.X
が contiguous なら,この操作はゼロコピーです.
>>> X.reshape(n * n)[:-1]
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14])
最後の要素を取り除いています.最後の要素は対角要素なので,ここで一つ対角要素が削除されることになります.最後の対角要素だけまず取り除くのは,次の reshape への布石です.
>>> X.reshape(n * n)[:-1].reshape(n - 1, n + 1)
array([[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14]])
ここが一番の肝です.もともと n
要素ずつで折り返していた行列を,n + 1
要素ずつで折り返すようにしています.これは,隣り合う対角要素が,flatten した列の上では n + 1
要素分だけ離れていることに対応しています.こうすることで,もとの行列の対角要素が,すべて一番左の列に集まっていることがわかります.先ほど最後の対角要素だけ先に取り除いていたのは,そうしないと要素が 1 つ余ってしまい,この reshape ができないからです.ちなみに,この reshape もゼロコピーです.
>>> X.reshape(n * n)[:-1].reshape(n - 1, n + 1)[:, 1:]
array([[ 1, 2, 3, 4],
[ 6, 7, 8, 9],
[11, 12, 13, 14]])
上のがほぼ種明かしで,あとはわかりやすいと思います.まず,左に集まった対角要素をスライスでスキップしています.
>>> X.reshape(n * n)[:-1].reshape(n - 1, n + 1)[:, 1:].reshape(n, n - 1)
array([[ 1, 2, 3],
[ 4, 6, 7],
[ 8, 9, 11],
[12, 13, 14]])
あとは shape を直して完了です.この最後の reshape だけは,直前のスライスで配列が non-contiguous になっているため,コピーが発生します.
これを関数化しておきましょう.行列だけじゃなくて,行列をたくさん並べた高階の配列でも動くようにしておきます.
def offdiagonal(X, axis1, axis2):
X = numpy.moveaxis(X, (axis1, axis2), (-2, -1))
*s, n, _ = X.shape
X = X.reshape(*s, n*n)[..., :-1].reshape(*s, n-1, n+1)[..., 1:].reshape(*s, n, n-1)
return numpy.moveaxis(X, (-2, -1), (axis1, axis2))
対角要素を取り除く軸を指定できるようにしました.moveaxis
でこれらの軸を後ろに回してから,この後ろに 2 軸について上で解説した操作を実行して,最後に軸の位置を戻しています.
さて,最初の問題に戻りましょう.
まず,X
を n
個ならべて (n, n, n)
配列を作ります(最初の軸が新しく追加された軸です).すると,主小行列を取り出す操作は,i
番目の行列から i
番目の行と列を取り除く,という操作になります.ここで,i
番目の行列から i
番目の行を取り除く,という操作を考えると,これは最初の 2 つの軸について対角要素を取り除くということ,と捉えることができます.つまり,さきほど作った非対角要素を取り出すコードがそのまま使えます.列についても同様です.
offdiagonal(offdiagonal(numpy.stack([X] * n), 0, 1), 0, 2)
stack
で行列を n
個コピーしたあと,offdiagonal
を2回実行して行,列それぞれについて対角要素を取り除いています.これで,Python でループを書くことなく,主小行列を列挙することができました.