遞歸神經網絡RNN怎樣加速?看PyTorch如何進行動態批處理

遞歸神經網絡RNN怎樣加速?看PyTorch如何進行動態批處理

原文來源:medium

作者:Illia Polosukhin

「機器人圈」編譯:多啦A亮

如果你讀過我的博客,你可能已經瞭解到我是一個TensorFlow的貢獻者,並在那裡建立了很多高級API。

而在2017年2月,我已經離開谷歌並創立了自己的公司——NEAR.ai。我們教機器用自然語言編寫代碼。

作為這項工作的一部分,我們正在構建以樹格式讀取或編寫代碼的深度學習模型。在試圖用TensorFlow管理這種複雜性之後,我已經決定嘗試用一下PyTorch。

PyTorch是由Facebook AI研究人員構建的框架,並且在自然語言和強化學習研究領域越來越受歡迎。它的主要優點是動態圖形構建原理——與Tensorflow相比,其中圖形一旦被構建,然後就會被“執行”多次,PyTorch可以使用簡單的Python邏輯動態重建圖形,就像你正在使用numpy數組進行計算一樣。

遞歸神經網絡RNN怎樣加速?看PyTorch如何進行動態批處理

來源: http://pytorch.org/about

這種靈活性吸引了一些人,他們使用複雜輸入/輸出數據(例如語言、樹、圖形)或需要在計算中運行一些自定義邏輯(深度強化學習)。

在這裡我想談談批處理的事情。即使PyTorch利用GPU加速器快速運行,並且通常推進C模塊的計算,如果你沒有對計算進行批處理——你仍然需要付出代價。

遞歸神經網絡(以樹形LSTM為例)特別難以批處理,因為每個示例都是不同的樹。

單純的實現將如下所示:

class TreeLSTM(nn.Module):

def __init__(self, num_units):

super(TreeLSTM, self).__init__()

self.num_units = num_units

self.left = nn.Linear(num_units, 5 * num_units)

self.right = nn.Linear(num_units, 5 * num_units)

def forward(self, left_in, right_in):

lstm_in = self.left(left_in[0])

lstm_in += self.right(right_in[0])

a, i, f1, f2, o = lstm_in.chunk(5, 1)

c = (a.tanh() * i.sigmoid() + f1.sigmoid() * left_in[1] +

f2.sigmoid() * right_in[1])

h = o.sigmoid() * c.tanh()

return h, c

class SPINN(nn.Module):

def __init__(self, n_classes, size, n_words):

super(SPINN, self).__init__()

self.size = size

self.tree_lstm = TreeLSTM(size)

self.embeddings = nn.Embedding(n_words, size)

self.out = nn.Linear(size, n_classes)

def leaf(self, word_id):

return self.embeddings(word_id), Variable(torch.FloatTensor(word_id.size()[0], self.size))

def children(self, left_h, left_c, right_h, right_c):

return self.tree_lstm((left_h, left_c), (right_h, right_c))

def logits(self, encoding):

return self.out(encoding)

def encode_tree_regular(model, tree):

def encode_node(node):

if node.is_leaf():

return model.leaf(Variable(torch.LongTensor([node.id])))

else:

left_h, left_c = encode_node(node.left)

right_h, right_c = encode_node(node.right)

return model.children(left_h, left_c, right_h, right_c)

encoding, _ = encode_node(tree.root)

return model.logits(encoding)

...

all_logits, all_labels = [], []

for tree in batch:

all_logits.append(encode_tree_regular(model, tree))

all_labels.append(tree.label)

loss = criterion(torch.cat(all_logits, 0), Variable(torch.LongTensor(all_labels)))

有一種手動批處理的方法:在每次處理輸入不同的操作之後,找出如何批處理輸入,然後解除輸出批處理。這是James Bradbury在其文章中的一個例子。

另一種選擇是,根據我們要計算的確切輸入/輸出,找到一個系統決定為我們的批處理對象。靈感來自Moshe等人的論文中描述的方法。 “動態計算圖深度學習”(在TensorFlow Fold 中實現但似乎並不被支持),在這個動畫中有很好的描繪:

遞歸神經網絡RNN怎樣加速?看PyTorch如何進行動態批處理

來源:http://github.com/tensorflow/fold

我已經在一個簡單的TorchFold中實現了這個原理:

class TorchFold(object):

def __init__(self, versatible=False, cuda=False):

...

def add(self, op, *args):

...

def apply(self, nn, return_values):

...

現在,如果我們想用以前的gist對樹形LSTM / 模型進行編碼,那麼我們需要這樣更改代碼:

from pytorch_tools import torchfold

def encode_tree_fold(fold, tree):

def encode_node(node):

if node.is_leaf():

return fold.add('leaf', node.id).split(2)

else:

left_h, left_c = encode_node(node.left)

right_h, right_c = encode_node(node.right)

return fold.add('children', left_h, left_c, right_h, right_c).split(2)

encoding, _ = encode_node(tree.root)

return fold.add('logits', encoding)

...

fold = torchfold.Fold(cuda=args.cuda)

all_logits, all_labels = [], []

for tree in batch:

all_logits.append(encode_tree_folded(fold, tree))

all_labels.append(tree.label)

res = fold.apply(model, [all_logits, all_labels])

loss = criterion(res[0], res[1])

這裡,在每次調用encode_tree_folded時,通過fold.add添加節點來動態構建“摺疊”圖,其中op是要調用的模型中的函數的名稱。它會自動顯示哪些op可以組合在一起,哪些應該遵循。

然後在fold.apply,調用傳遞的模型的操作,傳遞它們的批處理的輸入張量(可能在不同的步驟有不同的批處理大小),並自動輸出到接下來的步驟。

比較未摺疊和摺疊版本之間的速度(在這裡的簡單模型https://github.com/nearai/pytorch-tools/blob/master/examples/snli/spinn-example.py):

常規:0.18秒/步(100 dim),2.19秒/步(500 dim)

摺疊:0.05秒/步(100 dim),0.22秒/步(500 dim)

由於降低了計算非有效效率,提升了3-10倍的速度。

該工具通常對於任何複雜的架構(包括RNN)都是有用的,因為它至少在第一個實驗中不需要考慮批處理。

你可以在這裡找到實現和示例:https://github.com/nearai/pytorch-tools

另外,在撰寫本文時,我發現最近有關於這個主題的文章 - https://arxiv.org/pdf/1705.07860.pdf, DyNet的實現。

還有就是,自從升級到PyTorch 0.2.0後,我發現TorchFold的性能略有下降,所以為了最佳速度,嘗試運行0.1.12直到穩定即可。

相關推薦

推薦中...