NumPy で畳込み層
NumPy で畳込み層(torch.nn.functional.conv2d
互換)を書いたメモです.解説を書く元気が残ってないので,興味がある人はがんばって解読してください.適当なパラメータで PyTorch と答えが allclose で一致することを確認してますが,網羅的にはテストしていないです.PyTorch の方がずっと速いので,これ自体は実用的ではないですが,NumPy 使ってスクラッチからニューラルネットや DL フレームワーク書きたい人には取っかかりとして手っ取り早いのではないかと思います.
# Copyright 2021 Seiya Tokui
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import numpy
def conv2d(x, w, bias=None, stride=1, padding=0, dilation=1, groups=1):
"""PyTorch-compatible simple implementation of conv2d in pure NumPy.
NOTE: This code prioritizes simplicity, sacrificing the performance.
It is much faster to use matmul instead of einsum (at least with NumPy 1.19.1).
MIT License.
"""
sY, sX = stride if isinstance(stride, (list, tuple)) else (stride, stride)
pY, pX = padding if isinstance(padding, (list, tuple)) else (padding, padding)
dY, dX = dilation if isinstance(dilation, (list, tuple)) else (dilation, dilation)
N, iC, iH, iW = x.shape
oC, iCg, kH, kW = w.shape
pY_ex = (sY - (iH + pY * 2 - (kH - 1) * dY) % sY) % sY
pX_ex = (sX - (iW + pX * 2 - (kW - 1) * dX) % sX) % sX
oH = (iH + pY * 2 + pY_ex - (kH - 1) * dY) // sY
oW = (iW + pX * 2 + pX_ex - (kW - 1) * dX) // sX
x = numpy.pad(x, ((0, 0), (0, 0), (pY, pY + pY_ex), (pX, pX + pX_ex)))
sN, sC, sH, sW = x.strides
col = numpy.lib.stride_tricks.as_strided(
x, shape=(N, groups, iCg, oH, oW, kH, kW),
strides=(sN, sC * iCg, sC, sH * sY, sW * sX, sH * dY, sW * dX),
)
w = w.reshape(groups, oC // groups, iCg, kH, kW)
y = numpy.einsum('ngihwkl,goikl->ngohw', col, w).reshape(N, oC, oH, oW)
if bias is not None:
y += bias[:, None, None]
return y
やってることは im2col による実装と同じです.Chainer のときは col
相当のものをループ書いて作ってましたが,ループなし・コピーなしで書けるじゃん,というのが今回の気づきです.
as_strided
で移動窓を作っています (cf. sliding_window_view).stride
とdilation
はどちらも strides で対応できます.padding
はおそらくどうしようもない.einsum
はreshape
/transpose
がんばればmatmul
になりますがシンプルさを優先しました.大抵の場合はmatmul
の方が断然速いです(Chainer の grouped conv 実装はmatmul
を使っていて,ちゃんと速い).einsum
爆速になってほしい.- 追記:
optimize=True
のことを忘れてたので試してみましたが,特に速度は変わりませんでした.残念.
- 追記:
- backward も似た方法で書けるはず(
einsum
の微分はeinsum
で,as_strided
は出力をas_strided
で作ってadd.at
). as_strided
をcupy
のものに置き換えれば CuPy で動きます.ちゃんと比べてませんが,こっちは実装の割にそこそこ速いみたいです.適当に P100 のマシンで比べると,autotune してない PyTorch と比べて 1.5〜2 倍くらいの時間で動いてそうです.cuDNN 叩くよりずっと簡単なのでいいですね.