NumPy クイズです.

問題(n, n) 行列 X について,Xi 行と i 列を同時に取り除いて得られる (n-1, n-1) 行列考えます.こういった行列は n 個ありますが,これらをすべて並べた (n, n-1, n-1) 配列を作ってください.

たとえば,X

array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11],
       [12, 13, 14, 15]])

ならば,出力は

array([[[ 5,  6,  7],
        [ 9, 10, 11],
        [13, 14, 15]],

       [[ 0,  2,  3],
        [ 8, 10, 11],
        [12, 14, 15]],

       [[ 0,  1,  3],
        [ 4,  5,  7],
        [12, 13, 15]],

       [[ 0,  1,  2],
        [ 4,  5,  6],
        [ 8,  9, 10]]])

となってほしいです.


さて,この問題を考える前に,まず次の問題を考えます.

問題(n, n) 行列 X に対して,X の対角要素を取り除いて (n, n-1) 行列を作ってください.

たとえば,最初に挙げた X

array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11],
       [12, 13, 14, 15]])

に対して,出力は

array([[ 1,  2,  3],
       [ 4,  6,  7],
       [ 8,  9, 11],
       [12, 13, 14]])

となってほしい,ということです(各行で 0, 5, 10, 15 がスキップされています).

一番簡単なのは,マスクを作って取り出す方法.以下のような感じです.

X[numpy.eye(n) == 0].reshape(n, n - 1)

この方法はわかりやすくてよいのですが,CuPy や PyTorch などで GPU 上で計算する場合には,ホストとデバイスの同期が挟まってしまう問題があります.同期は,bool 配列によるインデクシングで発生します.bool 配列によるインデクシングでは,出力サイズが True の要素数になるので,配列の中身を見ないとサイズが決まりません.そこで,値が True の要素数をカウントして,ホストに送り,メモリを確保してから,デバイスでコピーをする,というような手順が必要になります.というわけで,同期が気になるような場面では,できれば使いたくありません.

この非対角要素を取り出す操作を,サイズが静的に決まる配列だけで実行するコードとして以下があります(PyTorch のフォーラムで見つけた手法を NumPy 語に直したものです).

X.reshape(n * n)[:-1].reshape(n - 1, n + 1)[:, 1:].reshape(n, n - 1)

完全に呪文って感じですが,このコードはスライスと reshape だけで書かれているので,ホスト・デバイス間の同期は発生しません.PyTorch でも同じコードで実行できます.

例に挙げた行列X で,各ステップを追ってみましょう.

>>> X.reshape(n * n)
array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15])

まず 1 本の長いベクトルにします.X が contiguous なら,この操作はゼロコピーです.

>>> X.reshape(n * n)[:-1]
array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14])

最後の要素を取り除いています.最後の要素は対角要素なので,ここで一つ対角要素が削除されることになります.最後の対角要素だけまず取り除くのは,次の reshape への布石です.

>>> X.reshape(n * n)[:-1].reshape(n - 1, n + 1)
array([[ 0,  1,  2,  3,  4],
       [ 5,  6,  7,  8,  9],
       [10, 11, 12, 13, 14]])

ここが一番の肝です.もともと n 要素ずつで折り返していた行列を,n + 1 要素ずつで折り返すようにしています.これは,隣り合う対角要素が,flatten した列の上では n + 1 要素分だけ離れていることに対応しています.こうすることで,もとの行列の対角要素が,すべて一番左の列に集まっていることがわかります.先ほど最後の対角要素だけ先に取り除いていたのは,そうしないと要素が 1 つ余ってしまい,この reshape ができないからです.ちなみに,この reshape もゼロコピーです.

>>> X.reshape(n * n)[:-1].reshape(n - 1, n + 1)[:, 1:]
array([[ 1,  2,  3,  4],
       [ 6,  7,  8,  9],
       [11, 12, 13, 14]])

上のがほぼ種明かしで,あとはわかりやすいと思います.まず,左に集まった対角要素をスライスでスキップしています.

>>> X.reshape(n * n)[:-1].reshape(n - 1, n + 1)[:, 1:].reshape(n, n - 1)
array([[ 1,  2,  3],
       [ 4,  6,  7],
       [ 8,  9, 11],
       [12, 13, 14]])

あとは shape を直して完了です.この最後の reshape だけは,直前のスライスで配列が non-contiguous になっているため,コピーが発生します.

これを関数化しておきましょう.行列だけじゃなくて,行列をたくさん並べた高階の配列でも動くようにしておきます.

def offdiagonal(X, axis1, axis2):
    X = numpy.moveaxis(X, (axis1, axis2), (-2, -1))
    *s, n, _ = X.shape
    X = X.reshape(*s, n*n)[..., :-1].reshape(*s, n-1, n+1)[..., 1:].reshape(*s, n, n-1)
    return numpy.moveaxis(X, (-2, -1), (axis1, axis2))

対角要素を取り除く軸を指定できるようにしました.moveaxis でこれらの軸を後ろに回してから,この後ろに 2 軸について上で解説した操作を実行して,最後に軸の位置を戻しています.


さて,最初の問題に戻りましょう.

まず,Xn 個ならべて (n, n, n) 配列を作ります(最初の軸が新しく追加された軸です).すると,主小行列を取り出す操作は,i 番目の行列から i 番目の行と列を取り除く,という操作になります.ここで,i 番目の行列から i 番目の行を取り除く,という操作を考えると,これは最初の 2 つの軸について対角要素を取り除くということ,と捉えることができます.つまり,さきほど作った非対角要素を取り出すコードがそのまま使えます.列についても同様です.

offdiagonal(offdiagonal(numpy.stack([X] * n), 0, 1), 0, 2)

stack で行列を n 個コピーしたあと,offdiagonal を2回実行して行,列それぞれについて対角要素を取り除いています.これで,Python でループを書くことなく,主小行列を列挙することができました.