.. _sec_rnn-concise:
Concise Implementation of Recurrent Neural Networks
===================================================
Like most of our from-scratch implementations,
:numref:`sec_rnn-scratch` was designed to provide insight into how
each component works. But when you are using RNNs every day or writing
production code, you will want to rely more on libraries that cut down
on both implementation time (by supplying library code for common models
and functions) and computation time (by optimizing the heck out of these
library implementations). This section will show you how to implement
the same language model more efficiently using the high-level API
provided by your deep learning framework. We begin, as before, by
loading *The Time Machine* dataset.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
from mxnet import np, npx
from mxnet.gluon import nn, rnn
from d2l import mxnet as d2l
npx.set_np()
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
from flax import linen as nn
from jax import numpy as jnp
from d2l import jax as d2l
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
import tensorflow as tf
from d2l import tensorflow as d2l
Defining the Model
------------------
We define the following class using the RNN implemented by high-level
APIs.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    class RNN(d2l.Module):  #@save
        """The RNN model implemented with high-level APIs."""
        def __init__(self, num_inputs, num_hiddens):
            super().__init__()
            self.save_hyperparameters()
            self.rnn = nn.RNN(num_inputs, num_hiddens)

        def forward(self, inputs, H=None):
            return self.rnn(inputs, H)
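
To make the interface of such a recurrent layer concrete, the following
framework-free sketch mimics what it computes, assuming the tanh
recurrence and the time-major layout (number of time steps, batch size,
number of inputs) used in :numref:`sec_rnn-scratch`. The functions
``rnn_forward`` and ``rand_matrix`` and the toy sizes are illustrative
assumptions, not part of any library.

.. code:: python

    import math
    import random

    def rnn_forward(inputs, W_xh, W_hh, b_h, H=None):
        """Run a tanh RNN over time-major inputs given as nested lists."""
        batch_size, num_hiddens = len(inputs[0]), len(b_h)
        if H is None:  # Default to an all-zero initial hidden state
            H = [[0.0] * num_hiddens for _ in range(batch_size)]
        outputs = []
        for X in inputs:  # X holds one time step: (batch_size, num_inputs)
            H = [[math.tanh(sum(x_i * W_xh[i][j] for i, x_i in enumerate(x))
                            + sum(h_k * W_hh[k][j] for k, h_k in enumerate(h))
                            + b_h[j])
                  for j in range(num_hiddens)]
                 for x, h in zip(X, H)]
            outputs.append(H)
        return outputs, H

    def rand_matrix(rows, cols):
        """Hypothetical helper: small random weights as a nested list."""
        return [[random.gauss(0, 0.1) for _ in range(cols)] for _ in range(rows)]

    random.seed(0)
    num_steps, batch_size, num_inputs, num_hiddens = 5, 2, 4, 3
    W_xh = rand_matrix(num_inputs, num_hiddens)
    W_hh = rand_matrix(num_hiddens, num_hiddens)
    b_h = [0.0] * num_hiddens
    inputs = [rand_matrix(batch_size, num_inputs) for _ in range(num_steps)]
    outputs, H = rnn_forward(inputs, W_xh, W_hh, b_h)
    print(len(outputs), len(outputs[0]), len(outputs[0][0]))  # 5 2 3

The sketch returns the hidden state at every time step together with the
final hidden state, which for a single layer equals the last output.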
In the MXNet implementation, we initialize the hidden state by invoking
the recurrent layer's member method ``begin_state``. It returns a list
containing the initial hidden state for the minibatch, a tensor of shape
(number of hidden layers, batch size, number of hidden units). For some
models to be introduced later (e.g., long short-term memory), this list
will also contain other information.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    class RNN(d2l.Module):  #@save
        """The RNN model implemented with high-level APIs."""
        def __init__(self, num_hiddens):
            super().__init__()
            self.save_hyperparameters()
            self.rnn = rnn.RNN(num_hiddens)

        def forward(self, inputs, H=None):
            if H is None:
                H, = self.rnn.begin_state(inputs.shape[1], ctx=inputs.ctx)
            outputs, (H, ) = self.rnn(inputs, (H, ))
            return outputs, H
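
As a rough illustration of this state layout, the sketch below builds
such a list in plain Python, assuming an all-zero initial state. The
helpers ``zero_state`` and ``shape`` are hypothetical and exist only for
this example.

.. code:: python

    def zero_state(num_layers, batch_size, num_hiddens):
        """Return a one-element list holding an all-zero hidden state."""
        H = [[[0.0] * num_hiddens for _ in range(batch_size)]
             for _ in range(num_layers)]
        return [H]

    def shape(nested):
        """Shape of a regular (non-ragged) nested list."""
        dims = []
        while isinstance(nested, list):
            dims.append(len(nested))
            nested = nested[0]
        return tuple(dims)

    state = zero_state(num_layers=1, batch_size=1024, num_hiddens=32)
    H, = state  # Unpack the single entry of the list
    print(shape(H))  # (1, 1024, 32)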
As of this writing, Flax does not provide an ``RNNCell`` for a concise
implementation of vanilla RNNs, so the class below only raises
``NotImplementedError``. More advanced variants, such as LSTMs and GRUs,
are available in the Flax ``linen`` API.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    class RNN(nn.Module):  #@save
        """The RNN model implemented with high-level APIs."""
        num_hiddens: int

        @nn.compact
        def __call__(self, inputs, H=None):
            raise NotImplementedError
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    class RNN(d2l.Module):  #@save
        """The RNN model implemented with high-level APIs."""
        def __init__(self, num_hiddens):
            super().__init__()
            self.save_hyperparameters()
            self.rnn = tf.keras.layers.SimpleRNN(
                num_hiddens, return_sequences=True, return_state=True,
                time_major=True)

        def forward(self, inputs, H=None):
            outputs, H = self.rnn(inputs, H)
            return outputs, H
Inheriting from the ``RNNLMScratch`` class in
:numref:`sec_rnn-scratch`, the following ``RNNLM`` class defines a
complete RNN-based language model. Note that we need to create a
separate fully connected output layer, because the recurrent layer
returns only hidden states.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    class RNNLM(d2l.RNNLMScratch):  #@save
        """The RNN-based language model implemented with high-level APIs."""
        def init_params(self):
            self.linear = nn.LazyLinear(self.vocab_size)

        def output_layer(self, hiddens):
            return self.linear(hiddens).swapaxes(0, 1)
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    class RNNLM(d2l.RNNLMScratch):  #@save
        """The RNN-based language model implemented with high-level APIs."""
        def init_params(self):
            self.linear = nn.Dense(self.vocab_size, flatten=False)
            self.initialize()

        def output_layer(self, hiddens):
            return self.linear(hiddens).swapaxes(0, 1)
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    class RNNLM(d2l.RNNLMScratch):  #@save
        """The RNN-based language model implemented with high-level APIs."""
        training: bool = True

        def setup(self):
            self.linear = nn.Dense(self.vocab_size)

        def output_layer(self, hiddens):
            return self.linear(hiddens).swapaxes(0, 1)

        def forward(self, X, state=None):
            embs = self.one_hot(X)
            rnn_outputs, _ = self.rnn(embs, state, self.training)
            return self.output_layer(rnn_outputs)
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    class RNNLM(d2l.RNNLMScratch):  #@save
        """The RNN-based language model implemented with high-level APIs."""
        def init_params(self):
            self.linear = tf.keras.layers.Dense(self.vocab_size)

        def output_layer(self, hiddens):
            return tf.transpose(self.linear(hiddens), (1, 0, 2))
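
The effect of ``output_layer`` can be sketched without any framework:
the same dense layer is applied to the hidden state of every time step,
and the first two axes are then swapped so that the batch dimension
comes first. The standalone function below, the parameter names ``W_hq``
and ``b_q``, and the toy sizes are assumptions chosen for illustration.

.. code:: python

    def output_layer(hiddens, W_hq, b_q):
        """Map time-major hidden states to batch-major logits."""
        num_steps, batch_size = len(hiddens), len(hiddens[0])
        # Apply the same dense layer to every time step and every example
        logits = [[[sum(h_k * W_hq[k][q] for k, h_k in enumerate(h)) + b_q[q]
                    for q in range(len(b_q))]
                   for h in H_t]
                  for H_t in hiddens]
        # Swap axes 0 and 1: (steps, batch, vocab) -> (batch, steps, vocab)
        return [[logits[t][b] for t in range(num_steps)]
                for b in range(batch_size)]

    num_steps, batch_size, num_hiddens, vocab_size = 3, 2, 4, 5
    hiddens = [[[0.1 * (t + 1)] * num_hiddens for _ in range(batch_size)]
               for t in range(num_steps)]
    W_hq = [[1.0] * vocab_size for _ in range(num_hiddens)]
    b_q = [0.0] * vocab_size
    out = output_layer(hiddens, W_hq, b_q)
    print(len(out), len(out[0]), len(out[0][0]))  # 2 3 5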
Training and Predicting
-----------------------
Before training the model, let’s make a prediction with a model
initialized with random weights. Given that we have not trained the
network, it will generate nonsensical predictions.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
data = d2l.TimeMachine(batch_size=1024, num_steps=32)
rnn = RNN(num_inputs=len(data.vocab), num_hiddens=32)
model = RNNLM(rnn, vocab_size=len(data.vocab), lr=1)
model.predict('it has', 20, data.vocab)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
'it hasoadd dd dd dd dd dd '
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
data = d2l.TimeMachine(batch_size=1024, num_steps=32)
rnn = RNN(num_hiddens=32)
model = RNNLM(rnn, vocab_size=len(data.vocab), lr=1)
model.predict('it has', 20, data.vocab)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
'it hasxlxlxlxlxlxlxlxlxlxl'
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
data = d2l.TimeMachine(batch_size=1024, num_steps=32)
rnn = RNN(num_hiddens=32)
model = RNNLM(rnn, vocab_size=len(data.vocab), lr=1)
model.predict('it has', 20, data.vocab)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
'it hasretsnrnrxnrnrgczntgq'
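
The ``predict`` call first feeds the prefix to warm up the hidden state
and then repeatedly appends the most likely next token. A minimal sketch
of this loop, assuming the greedy scheme from :numref:`sec_rnn-scratch`
and using a hypothetical ``toy_step`` function in place of a real model,
looks as follows.

.. code:: python

    def predict(prefix, num_preds, step, tokens):
        """Warm up on the prefix, then greedily generate num_preds tokens."""
        state, outputs = None, [prefix[0]]
        for i in range(len(prefix) + num_preds - 1):
            scores, state = step(outputs[-1], state)  # Feed the latest token
            if i < len(prefix) - 1:
                outputs.append(prefix[i + 1])  # Warm-up: follow the prefix
            else:
                best = max(range(len(tokens)), key=lambda j: scores[j])
                outputs.append(tokens[best])  # Prediction: take the argmax
        return ''.join(outputs)

    tokens = ['a', 'b', 'c']

    def toy_step(token, state):
        """Hypothetical stand-in for a model: favors the next letter cyclically."""
        nxt = (tokens.index(token) + 1) % len(tokens)
        scores = [1.0 if j == nxt else 0.0 for j in range(len(tokens))]
        return scores, state

    print(predict('ab', 4, toy_step, tokens))  # abcabc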
Next, we train our model, leveraging the high-level API.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
trainer = d2l.Trainer(max_epochs=100, gradient_clip_val=1, num_gpus=1)
trainer.fit(model, data)
.. figure:: output_rnn-concise_eff2f4_62_0.svg
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
trainer = d2l.Trainer(max_epochs=100, gradient_clip_val=1, num_gpus=1)
trainer.fit(model, data)
.. figure:: output_rnn-concise_eff2f4_65_0.svg
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    with d2l.try_gpu():
        trainer = d2l.Trainer(max_epochs=100, gradient_clip_val=1)
    trainer.fit(model, data)
.. figure:: output_rnn-concise_eff2f4_68_0.svg
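
The argument ``gradient_clip_val=1`` asks the trainer to clip gradients.
A minimal sketch of one common variant, clipping by the global L2 norm,
is given below. We assume, without relying on the trainer's actual code,
that all gradients are rescaled jointly whenever their norm exceeds the
threshold. The function ``clip_gradients`` is hypothetical.

.. code:: python

    import math

    def clip_gradients(grads, clip_val):
        """Rescale a list of flat gradient lists to global norm <= clip_val."""
        norm = math.sqrt(sum(g * g for grad in grads for g in grad))
        if norm > clip_val:
            scale = clip_val / norm
            grads = [[g * scale for g in grad] for grad in grads]
        return grads, norm

    grads = [[3.0, 0.0], [0.0, 4.0]]  # Toy gradients with global norm 5
    clipped, norm = clip_gradients(grads, clip_val=1)
    print(norm)  # 5.0

After clipping, the gradients keep their direction but have a global
norm of (approximately) one.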
Compared with :numref:`sec_rnn-scratch`, this model achieves
comparable perplexity, but runs faster due to the optimized
implementations. As before, we can generate predicted tokens following
the specified prefix string.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
model.predict('it has', 20, data.vocab, d2l.try_gpu())
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
'it has and the trave the t'
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
model.predict('it has', 20, data.vocab, d2l.try_gpu())
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
'it has and the time the ti'
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
model.predict('it has', 20, data.vocab)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
'it has and the pas an and '
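
Recall that perplexity is the exponential of the average cross-entropy
over the tokens of a sequence. The following sketch computes it from the
probabilities that a model assigns to the correct tokens. The function
name and the vocabulary size of 28 are illustrative assumptions.

.. code:: python

    import math

    def perplexity(target_probs):
        """exp of the average negative log-probability of the true tokens."""
        nll = -sum(math.log(p) for p in target_probs) / len(target_probs)
        return math.exp(nll)

    vocab_size = 28  # Assumed size of a character-level vocabulary
    print(perplexity([1.0, 1.0, 1.0]))        # Perfect predictions
    print(perplexity([1 / vocab_size] * 10))  # Uniform guessing

A perfect model attains a perplexity of 1, whereas uniform guessing over
the vocabulary yields a perplexity equal to the vocabulary size.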
Summary
-------
High-level APIs in deep learning frameworks provide implementations of
standard RNNs. These libraries help you to avoid wasting time
reimplementing standard models. Moreover, framework implementations are
often highly optimized, leading to significant (computational)
performance gains when compared with implementations from scratch.
Exercises
---------
1. Can you make the RNN model overfit using the high-level APIs?
2. Implement the autoregressive model of :numref:`sec_sequence` using
an RNN.
`Discussions `__