from __future__ import annotations
from collections import ChainMap, OrderedDict
from collections.abc import Iterable
from .base import ParticleDataImp, VersionDataImp
from .kernel import ParticleData, ParticleUpdater, VersionData, VersionUpdater
from .mtype import ConfigDict, MetaDict, Module, ModuleDict, ModuleID, ModuleMeta, ParamConfig, ParamDict, ParamID, ParamType
from .particle_updater import AddModuleParticleUpdater, AddParamParticleUpdater, DelBufferParticleUpdater, DelModuleParticleUpdater, DelParamParticleUpdater, RegisterBufferParticleUpdater, SetModuleNoneParticleUpdater, SetParamNoneParticleUpdater
from typing import Generic, Optional, TypeVar
from typing_extensions import Self
import torch
[docs]class AddModuleVersionUpdater(VersionUpdater[VersionDataImp, ParticleDataImp]):
r"""
Updater for VersionKernel instances that adds a module to a particle.
This class facilitates the addition of a module to a particle by updating both version and particle data.
Args:
module_id (ModuleID):
ID of the module where the new module is added.
attr_name (str):
Name of the attribute where the new module is added.
add_module (Module):
The module to be added.
Attributes:
module_id (ModuleID):
ID of the module in the particle where the module is added.
attr_name (str):
Name of the attribute where the module is added.
add_module (Module):
The module to be added to the particle.
owner_modules (Optional[ModuleDict]):
Modules of the particle initiating the event.
owner_params (Optional[ParamDict]):
Parameters of the particle initiating the event.
Example::
class MyNet(NytoModule):
...
net0 = MyNet()
net1 = MyNet()
particle_kernel: ParticleKernel = net0._particle_kernel
particle_kernel.version_update(AddModuleVersionUpdater(net0._module_id,
"attar_net1",
net1))
>>> net0.attar_net1 is net1
True
"""
module_id: ModuleID
attr_name: str
add_module: Module
owner_modules: Optional[ModuleDict]
owner_params: Optional[ParamDict]
def __init__(self, module_id: ModuleID, attr_name: str, add_module: Module) -> None:
r"""
Initialize the AddModuleVersionUpdater.
Args:
module_id (ModuleID):
ID of the module where the new module is added.
attr_name (str):
Name of the attribute where the new module is added.
add_module (Module):
The module to be added.
"""
self.module_id = module_id
self.attr_name = attr_name
self.add_module = add_module
self.owner_modules = None
self.owner_params = None
[docs] def set_version_data(self, vdata: VersionDataImp) -> Self:
return self
[docs] def set_particle_data(self, pdata: ParticleDataImp) -> Self:
self.owner_modules, self.owner_params = pdata.modules, pdata.params
return self
[docs] def run(self, vdata: VersionDataImp) -> tuple[VersionDataImp, ParticleUpdater]:
K = TypeVar("K", ModuleID, ParamID)
V = TypeVar("V", Module, ParamType)
class UniqueDict(Generic[K, V]):
def __init__(self, data: OrderedDict[K, V]) -> None:
self.data: OrderedDict[K, V] = data
self.add_data: OrderedDict[K, V] = OrderedDict()
def _next_key(self) -> K:
if len(ChainMap(self.data, self.add_data)) == 0:
return 0
return max(ChainMap(self.data, self.add_data).keys())+1
def add_value(self, value: Optional[V]) -> None:
if value is None: return
for k, v in self.data.items():
if value is v: return
self.add_data[self._next_key()] = value
def add_values(self, values: Iterable[Optional[V]]) -> None:
for v in values: self.add_value(v)
def value_to_key(self, value: V) -> K:
for k, v in ChainMap(self.data, self.add_data).items():
if value is v: return k
raise ValueError(f"can not found {value}")
assert self.owner_modules is not None
assert self.owner_params is not None
unq_modules = UniqueDict[ModuleID, Module](self.owner_modules)
unq_params = UniqueDict[ParamID, ParamType](self.owner_params)
for mod in self.add_module.modules():
unq_modules.add_value(mod)
unq_params.add_values(mod._parameters.values())
add_modules_meta: MetaDict = OrderedDict(
(unq_modules.value_to_key(mod),
ModuleMeta(sub_modules=OrderedDict((sub_name, sub_mod)
if sub_mod is None else
(sub_name, unq_modules.value_to_key(sub_mod))
for sub_name, sub_mod in mod._modules.items()),
sub_params=OrderedDict((sub_name, sub_param)
if sub_param is None else
(sub_name, unq_params.value_to_key(sub_param))
for sub_name, sub_param in mod._parameters.items())))
for mid, mod in unq_modules.add_data.items())
add_params_config: ConfigDict = OrderedDict((pid, ParamConfig()) for pid, param in unq_params.add_data.items())
add_modules_id: ModuleID = unq_modules.value_to_key(self.add_module)
new_vdata: VersionDataImp = vdata.copy()
if self.attr_name in new_vdata.meta[self.module_id].sub_params:
del new_vdata.meta[self.module_id].sub_params[self.attr_name]
new_vdata.meta[self.module_id].sub_modules[self.attr_name] = add_modules_id
new_vdata.meta.update(add_modules_meta)
new_vdata.config.update(add_params_config)
remove_modules: set[ModuleID] = new_vdata.get_garbage_modules()
new_vdata.remove_modules(remove_modules)
remove_params: set[ParamID] = new_vdata.get_garbage_params()
new_vdata.remove_params(remove_params)
return (new_vdata,
AddModuleParticleUpdater(self.module_id,
self.attr_name,
add_modules_id,
unq_modules.add_data,
unq_params.add_data,
remove_modules,
remove_params))
[docs]class AddParamVersionUpdater(VersionUpdater[VersionDataImp, ParticleDataImp]):
r"""
Updater for VersionKernel instances that adds a parameter to a particle.
Args:
module_id (ModuleID):
ID of the module where the new parameter is added.
attr_name (str):
Name of the attribute where the new parameter is added.
add_param (ParamType):
The parameter to be added.
Attributes:
module_id (ModuleID):
ID of the module in the particle where the parameter is added.
attr_name (str):
Name of the attribute where the parameter is added.
add_param (ParamType):
The parameter to be added to the particle.
owner_params (Optional[ParamDict]):
Parameters of the particle initiating the event.
Example::
class MyNet(NytoModule):
...
net = MyNet()
add_param = torch.nn.Parameter(torch.randn(3, 3))
particle_kernel: ParticleKernel = net._particle_kernel
particle_kernel.version_update(AddParamVersionUpdater(net._module_id,
"attar_add_param",
add_param))
>>> net.attar_add_param is add_param
True
"""
module_id: ModuleID
attr_name: str
add_param: ParamType
owner_params: Optional[ParamDict]
def __init__(self, module_id: ModuleID, attr_name: str, add_param: ParamType) -> None:
"""
Initialize the AddParamVersionUpdater.
Args:
module_id (ModuleID):
ID of the module where the new parameter is added.
attr_name (str):
Name of the attribute where the new parameter is added.
add_param (ParamType):
The parameter to be added.
"""
self.module_id = module_id
self.attr_name = attr_name
self.add_param = add_param
self.owner_params = None
[docs] def set_version_data(self, vdata: VersionDataImp) -> Self:
return self
[docs] def set_particle_data(self, pdata: ParticleDataImp) -> Self:
self.owner_params = pdata.params
return self
[docs] def run(self, vdata: VersionDataImp) -> tuple[VersionDataImp, ParticleUpdater]:
def is_in_dict(data, data_dict) -> bool:
for k, v in data_dict.items():
if v is data: return True
return False
if is_in_dict(self.add_param, self.owner_params):
def value_to_key(data, data_dict):
for k, v in data_dict.items():
if v is data: return k
raise ValueError(f"can not found {value}")
pid: ParamID = value_to_key(self.add_param, self.owner_params)
new_vdata: VersionDataImp = vdata.copy()
if self.attr_name in new_vdata.meta[self.module_id].sub_modules:
del new_vdata.meta[self.module_id].sub_modules[self.attr_name]
new_vdata.meta[self.module_id].sub_params[self.attr_name] = pid
remove_modules: set[ModuleID] = new_vdata.get_garbage_modules()
new_vdata.remove_modules(remove_modules)
remove_params: set[ParamID] = new_vdata.get_garbage_params()
new_vdata.remove_params(remove_params)
return (new_vdata,
AddParamParticleUpdater(self.module_id,
self.attr_name,
pid,
remove_modules,
remove_params))
def get_next_key(param_dict: ConfigDict) -> ParamID:
if len(param_dict) == 0:
return 0
return max(param_dict.keys()) + 1
next_key: ParamID = get_next_key(vdata.config)
new_vdata = vdata.copy()
if self.attr_name in new_vdata.meta[self.module_id].sub_modules:
del new_vdata.meta[self.module_id].sub_modules[self.attr_name]
new_vdata.config[next_key] = ParamConfig()
new_vdata.meta[self.module_id].sub_params[self.attr_name] = next_key
remove_modules = new_vdata.get_garbage_modules()
new_vdata.remove_modules(remove_modules)
remove_params = new_vdata.get_garbage_params()
new_vdata.remove_params(remove_params)
return (new_vdata,
AddParamParticleUpdater(self.module_id,
self.attr_name,
next_key,
remove_modules,
remove_params,
self.add_param))
[docs]class RegisterBufferVersionUpdater(VersionUpdater[VersionDataImp, ParticleDataImp]):
r"""
Updater for VersionKernel instances that registers a buffer to a particle.
Args:
module_id (ModuleID):
ID of the module where the buffer is added.
attr_name (str):
Name of the attribute where the buffer is added.
value (Optional[torch.Tensor]):
The buffer to be added to the particle.
persistent (bool):
If True, the buffer becomes part of the module and is saved or loaded with it.
Attributes:
module_id (ModuleID):
ID of the module in the particle where the buffer is added.
attr_name (str):
Name of the attribute where the buffer is added.
value (Optional[torch.Tensor]):
The buffer to be added to the particle.
persistent (bool):
If True, the buffer becomes part of the module and is saved or loaded with it.
Example::
class MyNet(NytoModule):
...
net = MyNet()
add_tensor = torch.randn(3, 3)
particle_kernel: ParticleKernel = net._particle_kernel
particle_kernel.version_update(RegisterBufferVersionUpdater(net._module_id,
"attar_add_tensor",
add_tensor))
>>> net.attar_add_tensor is add_tensor
True
"""
module_id: ModuleID
attr_name: str
value: Optional[torch.Tensor]
persistent: bool
def __init__(self,
module_id: ModuleID,
attr_name: str,
value: Optional[torch.Tensor],
persistent: bool) -> None:
"""
Initialize the RegisterBufferVersionUpdater.
Args:
module_id (ModuleID):
ID of the module where the buffer is added.
attr_name (str):
Name of the attribute where the buffer is added.
value (Optional[torch.Tensor]):
The buffer to be added to the particle.
persistent (bool):
If True, the buffer becomes part of the module and is saved or loaded with it.
"""
self.module_id = module_id
self.attr_name = attr_name
self.value = value
self.persistent = persistent
[docs] def set_version_data(self, vdata: VersionDataImp) -> Self:
return self
[docs] def set_particle_data(self, pdata: ParticleDataImp) -> Self:
return self
[docs] def run(self, vdata: VersionDataImp) -> tuple[VersionDataImp, ParticleUpdater]:
new_vdata: VersionDataImp = vdata.copy()
if self.attr_name in new_vdata.meta[self.module_id].sub_params:
del new_vdata.meta[self.module_id].sub_params[self.attr_name]
if self.attr_name in new_vdata.meta[self.module_id].sub_modules:
del new_vdata.meta[self.module_id].sub_modules[self.attr_name]
remove_modules: set[ModuleID] = new_vdata.get_garbage_modules()
new_vdata.remove_modules(remove_modules)
remove_params: set[ParamID] = new_vdata.get_garbage_params()
new_vdata.remove_params(remove_params)
return (new_vdata,
RegisterBufferParticleUpdater(self.module_id,
self.attr_name,
self.value,
self.persistent,
remove_modules,
remove_params))
[docs]class SetModuleNoneVersionUpdater(VersionUpdater[VersionDataImp, ParticleDataImp]):
r"""
Updater for VersionKernel instances that sets a module property in a particle to None.
Args:
module_id (ModuleID):
ID of the module performing this operation within the particle.
attr_name (str):
Name of the attribute being operated on.
Attributes:
module_id (ModuleID):
ID of the module performing this operation within the particle.
attr_name (str):
Name of the attribute being operated on.
Example::
class MyNet(NytoModule):
...
net = MyNet()
net.attar_module = MyNet()
particle_kernel: ParticleKernel = net._particle_kernel
particle_kernel.version_update(SetModuleNoneVersionUpdater(net._module_id,
"attar_module"))
>>> net.attar_module is None
True
"""
module_id: ModuleID
attr_name: str
def __init__(self, module_id: ModuleID, attr_name: str) -> None:
r"""
Initialize the SetModuleNoneVersionUpdater.
Args:
module_id (ModuleID):
ID of the module performing this operation within the particle.
attr_name (str):
Name of the attribute being operated on.
"""
self.module_id = module_id
self.attr_name = attr_name
[docs] def set_version_data(self, vdata: VersionDataImp) -> Self:
return self
[docs] def set_particle_data(self, pdata: ParticleDataImp) -> Self:
return self
[docs] def run(self, vdata: VersionDataImp) -> tuple[VersionDataImp, ParticleUpdater]:
new_vdata: VersionDataImp = vdata.copy()
if self.attr_name in new_vdata.meta[self.module_id].sub_params:
del new_vdata.meta[self.module_id].sub_params[self.attr_name]
new_vdata.meta[self.module_id].sub_modules[self.attr_name] = None
remove_modules: set[ModuleID] = new_vdata.get_garbage_modules()
new_vdata.remove_modules(remove_modules)
remove_params: set[ParamID] = new_vdata.get_garbage_params()
new_vdata.remove_params(remove_params)
return (new_vdata,
SetModuleNoneParticleUpdater(self.module_id,
self.attr_name,
remove_modules,
remove_params))
[docs]class SetParamNoneVersionUpdater(VersionUpdater[VersionDataImp, ParticleDataImp]):
r"""
Updater for VersionKernel instances that sets a param property in a particle to None.
Args:
module_id (ModuleID):
ID of the module performing this operation within the particle.
attr_name (str):
Name of the attribute being operated on.
Attributes:
module_id (ModuleID):
ID of the module performing this operation within the particle.
attr_name (str):
Name of the attribute being operated on.
Example::
class MyNet(NytoModule):
...
net = MyNet()
net.attar_param = torch.nn.Parameter(torch.randn(3, 3))
particle_kernel: ParticleKernel = net._particle_kernel
particle_kernel.version_update(SetParamNoneVersionUpdater(net._module_id,
"attar_param"))
>>> net.attar_param is None
True
"""
module_id: ModuleID
attr_name: str
def __init__(self, module_id: ModuleID, attr_name: str) -> None:
r"""
Initialize the SetParamNoneVersionUpdater.
Args:
module_id (ModuleID):
ID of the module performing this operation within the particle.
attr_name (str):
Name of the attribute being operated on.
"""
self.module_id = module_id
self.attr_name = attr_name
[docs] def set_version_data(self, vdata: VersionDataImp) -> Self:
return self
[docs] def set_particle_data(self, pdata: ParticleDataImp) -> Self:
return self
[docs] def run(self, vdata: VersionDataImp) -> tuple[VersionDataImp, ParticleUpdater]:
new_vdata: VersionDataImp = vdata.copy()
if self.attr_name in new_vdata.meta[self.module_id].sub_modules:
del new_vdata.meta[self.module_id].sub_modules[self.attr_name]
new_vdata.meta[self.module_id].sub_params[self.attr_name] = None
remove_modules: set[ModuleID] = new_vdata.get_garbage_modules()
new_vdata.remove_modules(remove_modules)
remove_params: set[ParamID] = new_vdata.get_garbage_params()
new_vdata.remove_params(remove_params)
return (new_vdata,
SetParamNoneParticleUpdater(self.module_id,
self.attr_name,
remove_modules,
remove_params))
[docs]class DelModuleVersionUpdater(VersionUpdater[VersionDataImp, ParticleDataImp]):
r"""
Updater for VersionKernel instances that deletes a module attribute from a particle.
Args:
module_id (ModuleID):
ID of the module performing this operation within the particle.
attr_name (str):
Name of the attribute being deleted.
Attributes:
module_id (ModuleID):
ID of the module in particle where this operation is executed.
attr_name (str):
Name of the attribute to be deleted.
Example::
class MyNet(NytoModule):
...
net = MyNet()
net.attar_moduel = MyNet()
particle_kernel: ParticleKernel = net._particle_kernel
particle_kernel.version_update(DelModuleVersionUpdater(net._module_id,
"attar_moduel"))
>>> hasattr(net, "attar_moduel")
False
"""
module_id: ModuleID
attr_name: str
def __init__(self, module_id: ModuleID, attr_name: str) -> None:
"""
Initialize the DelModuleVersionUpdater.
Args:
module_id (ModuleID):
ID of the module performing this operation within the particle.
attr_name (str):
Name of the attribute being deleted.
"""
self.module_id = module_id
self.attr_name = attr_name
[docs] def set_version_data(self, vdata: VersionDataImp) -> Self:
return self
[docs] def set_particle_data(self, pdata: ParticleDataImp) -> Self:
return self
[docs] def run(self, vdata: VersionDataImp) -> tuple[VersionDataImp, ParticleUpdater]:
new_vdata: VersionDataImp = vdata.copy()
del new_vdata.meta[self.module_id].sub_modules[self.attr_name]
remove_modules: set[ModuleID] = new_vdata.get_garbage_modules()
new_vdata.remove_modules(remove_modules)
remove_params: set[ParamID] = new_vdata.get_garbage_params()
new_vdata.remove_params(remove_params)
return (new_vdata,
DelModuleParticleUpdater(self.module_id,
self.attr_name,
remove_modules,
remove_params))
[docs]class DelParamVersionUpdater(VersionUpdater[VersionDataImp, ParticleDataImp]):
r"""
Updater for VersionKernel instances that deletes a parameter attribute from a particle.
Args:
module_id (ModuleID):
ID of the module performing this operation within the particle.
attr_name (str):
Name of the attribute being deleted.
Attributes:
module_id (ModuleID):
ID of the module in particle where this operation is executed.
attr_name (str):
Name of the attribute to be deleted.
Example::
class MyNet(NytoModule):
...
net = MyNet()
net.attar_param = torch.nn.Parameter(torch.randn(3, 3))
particle_kernel: ParticleKernel = net._particle_kernel
particle_kernel.version_update(DelParamVersionUpdater(net._module_id,
"attar_moduel"))
>>> hasattr(net, "attar_param")
False
"""
module_id: ModuleID
attr_name: str
def __init__(self, module_id: ModuleID, attr_name: str) -> None:
"""
Initialize the DelParamVersionUpdater.
Args:
module_id (ModuleID):
ID of the module performing this operation within the particle.
attr_name (str):
Name of the attribute being deleted.
"""
self.module_id = module_id
self.attr_name = attr_name
[docs] def set_version_data(self, vdata: VersionDataImp) -> Self:
return self
[docs] def set_particle_data(self, pdata: ParticleDataImp) -> Self:
return self
[docs] def run(self, vdata: VersionDataImp) -> tuple[VersionDataImp, ParticleUpdater]:
new_vdata: VersionDataImp = vdata.copy()
del new_vdata.meta[self.module_id].sub_params[self.attr_name]
remove_params: set[ParamID] = new_vdata.get_garbage_params()
new_vdata.remove_params(remove_params)
return (new_vdata,
DelParamParticleUpdater(self.module_id,
self.attr_name,
remove_params))
[docs]class DelBufferVersionUpdater(VersionUpdater[VersionDataImp, ParticleDataImp]):
r"""
Updater for VersionKernel instances that deletes a buffer attribute from a particle.
Args:
module_id (ModuleID):
ID of the module performing this operation within the particle.
attr_name (str):
Name of the attribute being deleted.
Attributes:
module_id (ModuleID):
ID of the module in particle where this operation is executed.
attr_name (str):
Name of the attribute to be deleted.
Example::
class MyNet(NytoModule):
...
net = MyNet()
net.register_buffer("attar_buffer", torch.randn(3, 3))
particle_kernel: ParticleKernel = net._particle_kernel
particle_kernel.version_update(DelParamVersionUpdater(net._module_id,
"attar_buffer"))
>>> hasattr(net, "attar_buffer")
False
"""
module_id: ModuleID
attr_name: str
def __init__(self, module_id: ModuleID, attr_name: str) -> None:
"""
Initialize the DelBufferVersionUpdater.
Args:
module_id (ModuleID):
ID of the module performing this operation within the particle.
attr_name (str):
Name of the attribute being deleted.
"""
self.module_id = module_id
self.attr_name = attr_name
[docs] def set_version_data(self, vdata: VersionDataImp) -> Self:
return self
[docs] def set_particle_data(self, pdata: ParticleDataImp) -> Self:
return self
[docs] def run(self, vdata: VersionDataImp) -> tuple[VersionDataImp, ParticleUpdater]:
new_vdata: VersionDataImp = vdata.copy()
return (new_vdata,
DelBufferParticleUpdater(self.module_id,
self.attr_name))