.. code:: python

    board = d2l.ProgressBoard('x')
    for x in np.arange(0, 10, 0.1):
        board.draw(x, np.sin(x), 'sin', every_n=2)
        board.draw(x, np.cos(x), 'cos', every_n=10)

.. figure:: output_oo-design_a0c19f_56_0.svg
.. _subsec_oo-design-models:
Models
------
The ``Module`` class is the base class of all models we will implement.
At the very least we need three methods: ``__init__`` stores the
learnable parameters; ``training_step`` accepts a data batch and
returns the loss value; and ``configure_optimizers`` returns the
optimization method, or a list of them, that is used to update the
learnable parameters. Optionally we can define ``validation_step`` to
report the evaluation measures. Sometimes we put the code for computing
the output into a separate ``forward`` method to make it more reusable.
.. code:: python

    class Module(nn.Module, d2l.HyperParameters): #@save
        """The base class of models."""
        def __init__(self, plot_train_per_epoch=2, plot_valid_per_epoch=1):
            super().__init__()
            self.save_hyperparameters()
            self.board = ProgressBoard()

        def loss(self, y_hat, y):
            raise NotImplementedError

        def forward(self, X):
            assert hasattr(self, 'net'), 'Neural network is not defined'
            return self.net(X)

        def plot(self, key, value, train):
            """Plot a point in animation."""
            assert hasattr(self, 'trainer'), 'Trainer is not initialized'
            self.board.xlabel = 'epoch'
            if train:
                x = self.trainer.train_batch_idx / \
                    self.trainer.num_train_batches
                n = self.trainer.num_train_batches / \
                    self.plot_train_per_epoch
            else:
                x = self.trainer.epoch + 1
                n = self.trainer.num_val_batches / \
                    self.plot_valid_per_epoch
            self.board.draw(x, value.to(d2l.cpu()).detach().numpy(),
                            ('train_' if train else 'val_') + key,
                            every_n=int(n))

        def training_step(self, batch):
            l = self.loss(self(*batch[:-1]), batch[-1])
            self.plot('loss', l, train=True)
            return l

        def validation_step(self, batch):
            l = self.loss(self(*batch[:-1]), batch[-1])
            self.plot('loss', l, train=False)

        def configure_optimizers(self):
            raise NotImplementedError
You may notice that ``Module`` is a subclass of ``nn.Module``, the base
class of neural networks in PyTorch. It provides convenient features for
handling neural networks. For example, if we define a ``forward``
method, such as ``forward(self, X)``, then for an instance ``a`` we can
invoke this method as ``a(X)``. This works because the built-in
``__call__`` method invokes ``forward``. You can find more details and
examples about ``nn.Module`` in :numref:`sec_model_construction`.
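To see how these pieces fit together, here is a minimal sketch of a
concrete subclass; the name ``LinearModel``, the learning rate, and the
use of ``nn.LazyLinear`` are illustrative assumptions, not part of the
d2l library:

.. code:: python

    import torch
    from torch import nn

    class LinearModel(d2l.Module):  # hypothetical example
        """A one-layer linear regression model on top of Module."""
        def __init__(self, lr=0.03):
            super().__init__()
            self.save_hyperparameters()
            # Module.forward delegates to self.net
            self.net = nn.LazyLinear(1)

        def loss(self, y_hat, y):
            return nn.functional.mse_loss(y_hat, y.reshape(y_hat.shape))

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=self.lr)

With such a subclass, ``model(X)`` computes predictions via ``forward``,
``training_step`` returns a differentiable loss, and a trainer can
obtain an optimizer from ``configure_optimizers``.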
The implementation for MXNet's Gluon API is analogous:

.. code:: python

    class Module(nn.Block, d2l.HyperParameters): #@save
        """The base class of models."""
        def __init__(self, plot_train_per_epoch=2, plot_valid_per_epoch=1):
            super().__init__()
            self.save_hyperparameters()
            self.board = ProgressBoard()

        def loss(self, y_hat, y):
            raise NotImplementedError

        def forward(self, X):
            assert hasattr(self, 'net'), 'Neural network is not defined'
            return self.net(X)

        def plot(self, key, value, train):
            """Plot a point in animation."""
            assert hasattr(self, 'trainer'), 'Trainer is not initialized'
            self.board.xlabel = 'epoch'
            if train:
                x = self.trainer.train_batch_idx / \
                    self.trainer.num_train_batches
                n = self.trainer.num_train_batches / \
                    self.plot_train_per_epoch
            else:
                x = self.trainer.epoch + 1
                n = self.trainer.num_val_batches / \
                    self.plot_valid_per_epoch
            self.board.draw(x, value.asnumpy(),
                            ('train_' if train else 'val_') + key,
                            every_n=int(n))

        def training_step(self, batch):
            l = self.loss(self(*batch[:-1]), batch[-1])
            self.plot('loss', l, train=True)
            return l

        def validation_step(self, batch):
            l = self.loss(self(*batch[:-1]), batch[-1])
            self.plot('loss', l, train=False)

        def configure_optimizers(self):
            raise NotImplementedError
You may notice that this ``Module`` is a subclass of ``nn.Block``, the
base class of neural networks in Gluon. It provides convenient features
for handling neural networks. For example, if we define a ``forward``
method, such as ``forward(self, X)``, then for an instance ``a`` we can
invoke this method as ``a(X)``. This works because the built-in
``__call__`` method invokes ``forward``. You can find more details and
examples about ``nn.Block`` in :numref:`sec_model_construction`.
With the introduction of
`dataclasses <https://docs.python.org/3/library/dataclasses.html>`__ in
Python 3.7, classes decorated with ``@dataclass`` automatically gain
magic methods such as ``__init__`` and ``__repr__``. The member
variables are defined using type annotations. All Flax modules are
Python 3.7 dataclasses.
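As a generic Python illustration (unrelated to the d2l code), a
dataclass generates ``__init__`` and ``__repr__`` from the annotated
fields, and ``field(default_factory=...)`` creates a fresh default per
instance, exactly the pattern used for ``board`` below:

.. code:: python

    from dataclasses import dataclass, field

    @dataclass
    class Point:
        x: float = 0.0
        y: float = 0.0
        tags: list = field(default_factory=list)  # new list per instance

    p = Point(1.0, 2.0)
    print(p)  # Point(x=1.0, y=2.0, tags=[]), from the generated __repr__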
.. code:: python

    class Module(nn.Module, d2l.HyperParameters): #@save
        """The base class of models."""
        # No need for save_hyperparam when using Python dataclass
        plot_train_per_epoch: int = field(default=2, init=False)
        plot_valid_per_epoch: int = field(default=1, init=False)
        # Use default_factory to make sure new plots are generated on each run
        board: ProgressBoard = field(default_factory=lambda: ProgressBoard(),
                                     init=False)

        def loss(self, y_hat, y):
            raise NotImplementedError

        # JAX & Flax do not have a forward-method-like syntax. Flax uses setup
        # and the built-in __call__ magic method for the forward pass. We add
        # forward here for consistency
        def forward(self, X, *args, **kwargs):
            assert hasattr(self, 'net'), 'Neural network is not defined'
            return self.net(X, *args, **kwargs)

        def __call__(self, X, *args, **kwargs):
            return self.forward(X, *args, **kwargs)

        def plot(self, key, value, train):
            """Plot a point in animation."""
            assert hasattr(self, 'trainer'), 'Trainer is not initialized'
            self.board.xlabel = 'epoch'
            if train:
                x = self.trainer.train_batch_idx / \
                    self.trainer.num_train_batches
                n = self.trainer.num_train_batches / \
                    self.plot_train_per_epoch
            else:
                x = self.trainer.epoch + 1
                n = self.trainer.num_val_batches / \
                    self.plot_valid_per_epoch
            self.board.draw(x, jax.device_put(value, d2l.cpu()),
                            ('train_' if train else 'val_') + key,
                            every_n=int(n))

        def training_step(self, params, batch, state):
            l, grads = jax.value_and_grad(self.loss)(params, batch[:-1],
                                                     batch[-1], state)
            self.plot('loss', l, train=True)
            return l, grads

        def validation_step(self, params, batch, state):
            l = self.loss(params, batch[:-1], batch[-1], state)
            self.plot('loss', l, train=False)

        def apply_init(self, dummy_input, key):
            """To be defined later in :numref:`sec_lazy_init`"""
            raise NotImplementedError

        def configure_optimizers(self):
            raise NotImplementedError
You may notice that this ``Module`` is a subclass of ``linen.Module``,
the base class of neural networks in Flax. It provides convenient
features for handling neural networks: for example, it manages the model
parameters, provides the ``nn.compact`` decorator to simplify code, and
invokes the ``__call__`` method, among other things. Here we also
redirect ``__call__`` to the ``forward`` method. We do this to make our
code more similar to the other framework implementations.
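As a generic illustration of the ``nn.compact`` style (not part of the
d2l code), a Flax module can declare its layers inline in ``__call__``
and create its parameters from a PRNG key:

.. code:: python

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    class MLP(nn.Module):  # generic Flax example
        @nn.compact
        def __call__(self, x):
            x = nn.relu(nn.Dense(16)(x))
            return nn.Dense(1)(x)

    # Parameters live outside the module and are created by init
    params = MLP().init(jax.random.PRNGKey(0), jnp.ones((1, 4)))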
The TensorFlow implementation follows the same pattern:

.. code:: python

    class Module(tf.keras.Model, d2l.HyperParameters): #@save
        """The base class of models."""
        def __init__(self, plot_train_per_epoch=2, plot_valid_per_epoch=1):
            super().__init__()
            self.save_hyperparameters()
            self.board = ProgressBoard()
            self.training = None

        def loss(self, y_hat, y):
            raise NotImplementedError

        def forward(self, X):
            assert hasattr(self, 'net'), 'Neural network is not defined'
            return self.net(X)

        def call(self, X, *args, **kwargs):
            if kwargs and 'training' in kwargs:
                self.training = kwargs['training']
            return self.forward(X, *args)

        def plot(self, key, value, train):
            """Plot a point in animation."""
            assert hasattr(self, 'trainer'), 'Trainer is not initialized'
            self.board.xlabel = 'epoch'
            if train:
                x = self.trainer.train_batch_idx / \
                    self.trainer.num_train_batches
                n = self.trainer.num_train_batches / \
                    self.plot_train_per_epoch
            else:
                x = self.trainer.epoch + 1
                n = self.trainer.num_val_batches / \
                    self.plot_valid_per_epoch
            self.board.draw(x, value.numpy(),
                            ('train_' if train else 'val_') + key,
                            every_n=int(n))

        def training_step(self, batch):
            l = self.loss(self(*batch[:-1]), batch[-1])
            self.plot('loss', l, train=True)
            return l

        def validation_step(self, batch):
            l = self.loss(self(*batch[:-1]), batch[-1])
            self.plot('loss', l, train=False)

        def configure_optimizers(self):
            raise NotImplementedError
You may notice that this ``Module`` is a subclass of ``tf.keras.Model``,
the base class of neural networks in TensorFlow. It provides convenient
features for handling neural networks. For example, its built-in
``__call__`` method invokes the ``call`` method. Here we redirect
``call`` to the ``forward`` method, saving the ``training`` argument as
a class attribute. We do this to make our code more similar to the other
framework implementations.
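A minimal sketch of this redirection in action; ``TinyModel`` and its
single dense layer are illustrative assumptions:

.. code:: python

    import tensorflow as tf

    class TinyModel(d2l.Module):  # hypothetical example
        def __init__(self):
            super().__init__()
            self.net = tf.keras.layers.Dense(1)

    model = TinyModel()
    y_hat = model(tf.ones((2, 4)), training=True)
    print(model.training)  # True, recorded by Module.call before forward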
.. _oo-design-data:
Data
----
The ``DataModule`` class is the base class for data. Quite frequently
the ``__init__`` method is used to prepare the data, including
downloading and preprocessing if needed. The ``train_dataloader``
method returns the data loader for the training dataset. A data loader
is a (Python) generator that yields a data batch each time it is used.
This batch is then fed into the ``training_step`` method of ``Module``
to compute the loss. There is an optional ``val_dataloader`` that
returns the validation dataset loader. It behaves in the same manner,
except that it yields data batches for the ``validation_step`` method
of ``Module``.
.. code:: python

    class DataModule(d2l.HyperParameters): #@save
        """The base class of data."""
        def __init__(self, root='../data', num_workers=4):
            self.save_hyperparameters()

        def get_dataloader(self, train):
            raise NotImplementedError

        def train_dataloader(self):
            return self.get_dataloader(train=True)

        def val_dataloader(self):
            return self.get_dataloader(train=False)
In the JAX and TensorFlow versions, the ``num_workers`` argument is
omitted:

.. code:: python

    class DataModule(d2l.HyperParameters): #@save
        """The base class of data."""
        def __init__(self, root='../data'):
            self.save_hyperparameters()

        def get_dataloader(self, train):
            raise NotImplementedError

        def train_dataloader(self):
            return self.get_dataloader(train=True)

        def val_dataloader(self):
            return self.get_dataloader(train=False)
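As a sketch of a concrete subclass (using the PyTorch variant; the class
name, the generated tensors, and the batch size are illustrative
assumptions), a data module can synthesize a regression dataset and
serve it through standard data loaders:

.. code:: python

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    class SyntheticRegressionData(d2l.DataModule):  # hypothetical example
        """Synthetic data for linear regression: y = Xw + b + noise."""
        def __init__(self, w, b, noise=0.01, num_train=1000, num_val=1000,
                     batch_size=32):
            super().__init__()
            self.save_hyperparameters()
            n = num_train + num_val
            self.X = torch.randn(n, len(w))
            self.y = self.X @ w.reshape(-1, 1) + b + noise * torch.randn(n, 1)

        def get_dataloader(self, train):
            # First num_train examples for training, the rest for validation
            i = slice(0, self.num_train) if train else slice(self.num_train, None)
            return DataLoader(TensorDataset(self.X[i], self.y[i]),
                              self.batch_size, shuffle=train)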
.. _oo-design-training:
Training
--------
The ``Trainer`` class trains the learnable parameters in the ``Module``
class with data specified in ``DataModule``. The key method is ``fit``,
which accepts two arguments: ``model``, an instance of ``Module``, and
``data``, an instance of ``DataModule``. It then iterates over the
entire dataset ``max_epochs`` times to train the model. As before, we
will defer the implementation of the ``fit_epoch`` method to later
chapters.
.. code:: python

    class Trainer(d2l.HyperParameters): #@save
        """The base class for training models with data."""
        def __init__(self, max_epochs, num_gpus=0, gradient_clip_val=0):
            self.save_hyperparameters()
            assert num_gpus == 0, 'No GPU support yet'

        def prepare_data(self, data):
            self.train_dataloader = data.train_dataloader()
            self.val_dataloader = data.val_dataloader()
            self.num_train_batches = len(self.train_dataloader)
            self.num_val_batches = (len(self.val_dataloader)
                                    if self.val_dataloader is not None else 0)

        def prepare_model(self, model):
            model.trainer = self
            model.board.xlim = [0, self.max_epochs]
            self.model = model

        def fit(self, model, data):
            self.prepare_data(data)
            self.prepare_model(model)
            self.optim = model.configure_optimizers()
            self.epoch = 0
            self.train_batch_idx = 0
            self.val_batch_idx = 0
            for self.epoch in range(self.max_epochs):
                self.fit_epoch()

        def fit_epoch(self):
            raise NotImplementedError
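Once ``fit_epoch`` is implemented, training a model becomes a few-line
affair. A hypothetical usage sketch, reusing the illustrative
``LinearModel`` and ``SyntheticRegressionData`` classes from above (as
written here, ``fit`` would raise ``NotImplementedError`` until
``fit_epoch`` is defined):

.. code:: python

    model = LinearModel(lr=0.03)
    data = SyntheticRegressionData(w=torch.tensor([2.0, -3.4]), b=4.2)
    trainer = d2l.Trainer(max_epochs=3)
    trainer.fit(model, data)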
In JAX, the ``Trainer`` class trains the learnable parameters
``params`` with data specified in ``DataModule``. The key method is
``fit``, which accepts three arguments: ``model``, an instance of
``Module``; ``data``, an instance of ``DataModule``; and ``key``, a JAX
``PRNGKeyArray``. We make the ``key`` argument optional here to
simplify the interface, but in JAX and Flax it is recommended always to
pass a root key and use it to initialize the model parameters. ``fit``
then iterates over the entire dataset ``max_epochs`` times to train the
model. As before, we will defer the implementation of ``fit_epoch`` to
later chapters.
.. code:: python

    class Trainer(d2l.HyperParameters): #@save
        """The base class for training models with data."""
        def __init__(self, max_epochs, num_gpus=0, gradient_clip_val=0):
            self.save_hyperparameters()
            assert num_gpus == 0, 'No GPU support yet'

        def prepare_data(self, data):
            self.train_dataloader = data.train_dataloader()
            self.val_dataloader = data.val_dataloader()
            self.num_train_batches = len(self.train_dataloader)
            self.num_val_batches = (len(self.val_dataloader)
                                    if self.val_dataloader is not None else 0)

        def prepare_model(self, model):
            model.trainer = self
            model.board.xlim = [0, self.max_epochs]
            self.model = model

        def fit(self, model, data, key=None):
            self.prepare_data(data)
            self.prepare_model(model)
            self.optim = model.configure_optimizers()

            if key is None:
                root_key = d2l.get_key()
            else:
                root_key = key
            params_key, dropout_key = jax.random.split(root_key)
            key = {'params': params_key, 'dropout': dropout_key}

            dummy_input = next(iter(self.train_dataloader))[:-1]
            variables = model.apply_init(dummy_input, key=key)
            params = variables['params']

            if 'batch_stats' in variables.keys():
                # Here batch_stats will be used later (e.g., for batch norm)
                batch_stats = variables['batch_stats']
            else:
                batch_stats = {}

            # Flax uses optax under the hood for a single state obj TrainState.
            # More will be discussed later in the dropout and batch
            # normalization section
            class TrainState(train_state.TrainState):
                batch_stats: Any
                dropout_rng: jax.random.PRNGKeyArray

            self.state = TrainState.create(apply_fn=model.apply,
                                           params=params,
                                           batch_stats=batch_stats,
                                           dropout_rng=dropout_key,
                                           tx=model.configure_optimizers())
            self.epoch = 0
            self.train_batch_idx = 0
            self.val_batch_idx = 0
            for self.epoch in range(self.max_epochs):
                self.fit_epoch()

        def fit_epoch(self):
            raise NotImplementedError
Summary
-------
To highlight the object-oriented design of our future deep learning
implementations, the classes above simply show how their objects store
data and interact with each other. We will keep enriching the
implementations of these classes, for example via ``@add_to_class``, in
the rest of the book. Moreover, these fully implemented classes are
saved in the `D2L library <https://github.com/d2l-ai/d2l-en>`__, so
that they can be reused in later chapters.