Version Manager

In Nytorch, two main functionalities are implemented:

  1. Particle operation

  2. Version manager

We have already discussed particle operation in previous sections. Now, let’s delve into the version manager. So, what does the version manager do?

Consider the following example:

import nytorch as nyto
import torch
import torch.nn as nn
import torch.nn.functional as F

class LSTMTagger(nyto.NytoModule):
    def __init__(self, word_embeddings, embedding_dim, hidden_dim, tagset_size):
        super().__init__()
        self.word_embeddings = word_embeddings
        self.lstm = nn.LSTM(embedding_dim, hidden_dim)
        self.hidden2tag = nn.Linear(hidden_dim, tagset_size)

        self.set_param_config(operational=False,
                              clone=False,
                              name='word_embeddings')

    def forward(self, sentence):
        embeds = self.word_embeddings(sentence)
        lstm_out, _ = self.lstm(embeds.view(len(sentence), 1, -1))
        tag_space = self.hidden2tag(lstm_out.view(len(sentence), -1))
        tag_scores = F.softmax(tag_space, dim=1)
        return tag_scores

In this example, we define an LSTM-based tagger. Given a sentence, it looks up the word embeddings, passes them through an LSTM and a linear classifier, and returns tag scores for each word. Here, word_embeddings is pretrained elsewhere and must be supplied when the model is created:

word_embeddings = nn.Embedding(32, 6).requires_grad_(False)
net1 = LSTMTagger(word_embeddings, 6, 6, 3)
net2 = net1.randn()
net3 = net1.randn()

At this point, net1, net2, and net3 are particles of the same species: their parameters have the same shapes, parameter IDs, and attributes.

Now, suppose we need to adjust the model structure, such as replacing the existing word_embeddings with a new one that supports 64 words instead of 32:

new_word_embeddings = nn.Embedding(64, 6).requires_grad_(False)
net1.word_embeddings = new_word_embeddings

Because such an adjustment should apply to every particle of the species, we would otherwise have to update each particle manually:

net2.word_embeddings = new_word_embeddings
net3.word_embeddings = new_word_embeddings

However, this approach is inefficient. Instead, Nytorch provides a series of automatic mechanisms to handle such situations, which is the topic of this chapter.

Version

Before delving into the details, let’s define what a version is. A version is the metadata state of a particle’s modules, buffers, and parameters at a certain point in time. When the metadata is modified, the state before the modification is regarded as the old version, and the state after the modification is regarded as the latest version.

What constitutes a modification to the metadata? From the perspective of a particle, we consider the following six events as modifications:

  1. Adding a module

  2. Deleting a module

  3. Adding a buffer

  4. Deleting a buffer

  5. Adding a parameter

  6. Deleting a parameter
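To make the six events concrete, the following toy sketch treats metadata as three sets of names and derives the events that turn one state into another. It is only an illustration under stated assumptions: Event and diff_metadata are hypothetical names, and Nytorch's actual bookkeeping is not shown in this chapter.

```python
from enum import Enum


class Event(Enum):
    """The six metadata modifications listed above (hypothetical toy enum)."""
    ADD_MODULE = "add module"
    DEL_MODULE = "delete module"
    ADD_BUFFER = "add buffer"
    DEL_BUFFER = "delete buffer"
    ADD_PARAMETER = "add parameter"
    DEL_PARAMETER = "delete parameter"


def diff_metadata(old, new):
    """Return the (event, name) pairs that turn `old` metadata into `new`."""
    events = []
    for kind in ("module", "buffer", "parameter"):
        # Names present only in `new` were added; names present only in `old` were deleted.
        for name in sorted(new[kind] - old[kind]):
            events.append((Event[f"ADD_{kind.upper()}"], name))
        for name in sorted(old[kind] - new[kind]):
            events.append((Event[f"DEL_{kind.upper()}"], name))
    return events


# Metadata of a two-parameter model before and after deleting `weight`.
old = {"module": set(), "buffer": set(), "parameter": {"weight", "bias"}}
new = {"module": set(), "buffer": set(), "parameter": {"bias"}}
events = diff_metadata(old, new)
print(events)
```

In this toy model, an empty result means the two states are the same version, and any non-empty result marks the boundary between an old version and a newer one.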

Let’s look at a specific example. We can access the version instance pointed to by a particle using _version_kernel:

from torch import nn
import nytorch as nyto
import torch

class Linear(nyto.NytoModule):
    def __init__(self, w: float, b: float) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.Tensor([w]))
        self.bias = nn.Parameter(torch.Tensor([b]))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.weight * x + self.bias

net1 = Linear(1, 2)
net2 = net1.randn()

version_before_del = net1._version_kernel

When the metadata is modified, the system records relevant information and saves it to a version instance. Simultaneously, a new version instance is created as the latest version:

del net1.weight

version_after_del = net1._version_kernel
>>> version_before_del is not version_after_del
True

At this point, we observe that the metadata of net1 has changed, and the version instance it points to has also changed. However, if we inspect net2, we find that both its metadata and the version instance it points to remain unchanged:

>>> hasattr(net2, 'weight')
True

>>> version_before_del is net2._version_kernel
True

To upgrade net2 to the latest version, we can call the touch() method. Nytorch automatically updates the particle to the latest version based on the previously recorded modification information:

net2.touch()
>>> hasattr(net2, 'weight')
False

>>> version_after_del is net2._version_kernel
True

In practice, you rarely need to call touch() explicitly: whenever a particle operation or metadata modification occurs, touch() is invoked automatically to bring the particle to the latest version first.
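One way to picture the behavior described in this section is a chain of version nodes, where each old version stores the change that leads to its successor and touch() replays those changes in order. The sketch below is a toy model in plain Python, not Nytorch's implementation; ToyVersion, ToyParticle, and the delete method are hypothetical names introduced for illustration.

```python
class ToyVersion:
    """One node in a version chain (hypothetical, for illustration only)."""

    def __init__(self):
        self.change = None  # callable recorded once this version becomes old
        self.next = None    # the version that replaced this one, if any


class ToyParticle:
    def __init__(self, version, attrs):
        self.version = version
        self.attrs = dict(attrs)

    def touch(self):
        """Replay recorded changes until this particle reaches the latest version."""
        while self.version.next is not None:
            self.version.change(self.attrs)
            self.version = self.version.next
        return self

    def delete(self, name):
        """A metadata modification: catch up first, record the change, then apply it."""
        self.touch()

        def change(attrs):
            attrs.pop(name, None)

        self.version.change = change
        self.version.next = ToyVersion()
        return self.touch()


v0 = ToyVersion()
p1 = ToyParticle(v0, {"weight": 1.0, "bias": 2.0})
p2 = ToyParticle(v0, {"weight": 0.3, "bias": -0.7})

p1.delete("weight")                      # p1 moves to a new version
stale_has_weight = "weight" in p2.attrs  # p2 still sits on the old version
p2.touch()                               # p2 replays the recorded deletion
```

Note that each particle keeps its own values; only the structural change (which names exist) is replayed, which mirrors how net2 loses weight after touch() in the example above.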

Update Behavior

Next, we’ll discuss the differences in update behavior for adding and deleting modules, buffers, and parameters between the particle initiating the event and other particles of the same species.

Here, we define two new terms: event initiator and event recipient.

The event initiator is the particle where the modification event occurs, while the event recipient is another particle of the same species as the event initiator. Consider the following example, where net1 is the event initiator and net2 is the event recipient:

net1 = Linear(10, 5)
net2 = net1.randn()

del net1.weight

net2.touch()

For certain modification events, the update behavior differs between the event initiator and the event recipient. For example, when adding a parameter, the event initiator adds the parameter itself, while the event recipient adds a clone of the parameter:

net1 = Linear(1, 2)
net2 = net1.randn()

add_parameter = nn.Parameter(torch.randn(1))
net1.add_parameter = add_parameter
net2.touch()
>>> net1.add_parameter is add_parameter
True

>>> net2.add_parameter is add_parameter
False

>>> torch.equal(net2.add_parameter, add_parameter)
True

Similarly, for adding modules, the event initiator adds the module itself, while the event recipient adds a clone of the module:

add_linear = nn.Linear(3, 4)
net1.add_linear = add_linear
net2.touch()
>>> net1.add_linear is add_linear
True

>>> net2.add_linear is add_linear
False

>>> torch.equal(net2.add_linear.weight, add_linear.weight)
True

>>> torch.equal(net2.add_linear.bias, add_linear.bias)
True

However, for adding buffers, both the event initiator and the event recipient add a reference to the same buffer:

add_tensor = torch.randn(3, 3)
net1.register_buffer("add_tensor", add_tensor)
net2.touch()
>>> net1.add_tensor is add_tensor
True

>>> net2.add_tensor is add_tensor
True

As for deleting modules, buffers, and parameters, there is no difference in behavior between the event initiator and the event recipient.
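The rules for add events can be summarized in a few lines. The sketch below is a simplified model in plain Python: value_for_add is a hypothetical helper, lists stand in for tensors, and copy.deepcopy stands in for whatever cloning Nytorch performs internally.

```python
import copy


def value_for_add(kind, value, is_initiator):
    """Return the object a particle stores when an add event is applied.

    kind is "parameter", "module", or "buffer" (toy labels, not a Nytorch API).
    """
    if is_initiator or kind == "buffer":
        return value               # the very same object
    return copy.deepcopy(value)    # recipients get an independent copy


param = [0.5]      # stands in for a parameter
buf = [1.0, 2.0]   # stands in for a buffer

initiator_param = value_for_add("parameter", param, is_initiator=True)
recipient_param = value_for_add("parameter", param, is_initiator=False)
recipient_buf = value_for_add("buffer", buf, is_initiator=False)

# An in-place change to the shared buffer is visible to the recipient,
# while the recipient's parameter copy is unaffected by changes to the original.
buf[0] = 9.0
param[0] = 7.0
```

The practical consequence is that an added buffer is shared state across the species, whereas added parameters and modules start out equal but evolve independently in each particle.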

Avoiding Update Behavior

In Nytorch, version updates on a NytoModule instance are mandatory: each of the following six modification events always triggers a version update, and there is no option to suppress it:

  1. Adding a module

  2. Deleting a module

  3. Adding a buffer

  4. Deleting a buffer

  5. Adding a parameter

  6. Deleting a parameter

However, we may want to avoid triggering a version update in certain situations. Below is a specific example:

import torch
import torch.nn as nn
import nytorch as nyto

class MyModule(nyto.NytoModule):
    def __init__(self, rate=0.5):
        super().__init__()
        self.layer = nn.Linear(4, 1)
        self.register_buffer("mean", torch.zeros(1))
        self.rate = rate

    def forward(self, x):
        y = self.layer(x)
        self.update_mean(y)
        return y

    def update_mean(self, y):
        with torch.no_grad():
            self.mean = (1-self.rate)*self.mean + self.rate*y.mean()

model = MyModule()
model_clone = model.clone()

Here, each forward pass records a running mean of the output in the mean buffer. We expect different particles of the same species to keep their own independent mean records. However, when one particle updates mean, every particle of the same species receives the update:

>>> model.mean
tensor([0.])
>>> model_clone.mean
tensor([0.])
>>> y = model(torch.randn(10, 4))
>>> model.mean
tensor([0.2386])
>>> model_clone.touch().mean
tensor([0.2386])
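The update rule in update_mean is an exponential moving average. With rate 0.5 and a starting value of 0, the recorded value 0.2386 corresponds to a batch output mean of about 0.4772. The short sketch below reproduces that arithmetic in plain Python; the sequence of batch means in the loop is hypothetical and only illustrates how the record evolves.

```python
def update_mean(mean, batch_mean, rate=0.5):
    """Exponential moving average: blend the old record with the new batch mean."""
    return (1 - rate) * mean + rate * batch_mean


# One step from zero, matching the numbers shown above.
first = update_mean(0.0, 0.4772)

# A few more steps with made-up batch means.
mean = 0.0
history = []
for batch_mean in (0.4, 0.8, 0.8):
    mean = update_mean(mean, batch_mean)
    history.append(round(mean, 4))
```

Because every step replaces the buffer with a newly computed value, the record is rebuilt on each forward pass rather than accumulated in place.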

To give each particle its own independent mean record, we have two different solutions.

Solution 1: Move the buffer to a torch.nn.Module

The first solution is to move the buffer that needs to be modified to a torch.nn.Module. This way, when the buffer is modified, it will not trigger a version update:

import torch
import torch.nn as nn
import nytorch as nyto

class MeanRecord(nn.Module):
    def __init__(self, rate=0.5):
        super().__init__()
        self.register_buffer("mean", torch.zeros(1))
        self.rate = rate

    def update_mean(self, output):
        with torch.no_grad():
            self.mean = (1-self.rate)*self.mean + self.rate*output.mean()

class MyModule(nyto.NytoModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 1)
        self.mean_record = MeanRecord()

    def forward(self, x):
        y = self.layer(x)
        self.mean_record.update_mean(y)
        return y

In the example above, we define a MeanRecord class, a plain torch.nn.Module, that stores the running mean of MyModule's output in its own buffer. This way, different particles of the same species have their own independent mean records:

>>> model.mean_record.mean
tensor([0.])
>>> model_clone.mean_record.mean
tensor([0.])
>>> y = model(torch.randn(10, 4))
>>> model.mean_record.mean
tensor([0.2386])
>>> model_clone.touch().mean_record.mean
tensor([0.])

Solution 2: Use super().register_buffer to modify the buffer

The second solution is to use super().register_buffer to modify the buffer, which can also avoid triggering a version update:

import torch
import torch.nn as nn
import nytorch as nyto

class MyModule(nyto.NytoModule):
    def __init__(self, rate=0.5):
        super().__init__()
        self.layer = nn.Linear(4, 1)
        self.register_buffer("mean", torch.zeros(1))
        self.rate = rate

    def forward(self, x):
        y = self.layer(x)
        self.update_mean(y)
        return y

    def update_mean(self, y):
        with torch.no_grad():
            new_mean = (1-self.rate)*self.mean + self.rate*y.mean()
            super().register_buffer("mean", new_mean)

Warning

Although the same techniques can also be used to modify parameters and modules without triggering a version update, doing so is not recommended, as it risks causing incorrect behavior.