
Class ModuleDictImpl#

Inheritance Relationships#

Base Type#

public torch::nn::Cloneable<ModuleDictImpl>

Class Documentation#

class ModuleDictImpl : public torch::nn::Cloneable<ModuleDictImpl>#

An OrderedDict of Modules that registers its elements by their keys.

torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
  {"linear", Linear(10, 3).ptr()},
  {"conv", Conv2d(1, 2, 3).ptr()},
  {"dropout", Dropout(0.5).ptr()},
};
torch::nn::ModuleDict dict1(ordereddict);

// Iterating yields OrderedDict items; value() is the stored module pointer.
for (const auto &item : *dict1) {
  item.value()->pretty_print(std::cout);
}

std::vector<std::pair<std::string, std::shared_ptr<Module>>> list = {
  {"linear", Linear(10, 3).ptr()},
  {"conv", Conv2d(1, 2, 3).ptr()},
  {"dropout", Dropout(0.5).ptr()},
};
torch::nn::ModuleDict dict2(list);

// Iterating yields OrderedDict items; value() is the stored module pointer.
for (const auto &item : *dict2) {
  item.value()->pretty_print(std::cout);
}

Why use ModuleDict instead of a plain map or OrderedDict? Unlike a manually managed ordered map of modules, a ModuleDict can be treated as a single module: a transformation applied to the ModuleDict applies to every module it stores, because each one is a registered submodule of the ModuleDict. For example, calling .to(torch::kCUDA) on a ModuleDict moves each module in the map to CUDA memory:

torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
  {"linear", Linear(10, 3).ptr()},
  {"conv", Conv2d(1, 2, 3).ptr()},
  {"dropout", Dropout(0.5).ptr()},
};
torch::nn::ModuleDict dict(ordereddict);

// Convert all modules to CUDA.
dict->to(torch::kCUDA);

Finally, ModuleDict provides a lightweight container API: iteration over submodules, access by key, and, after construction, adding modules via update from a vector of key-module pairs, an OrderedDict, or another ModuleDict.

Public Types

using Iterator = torch::OrderedDict<std::string, std::shared_ptr<Module>>::Iterator#
using ConstIterator = torch::OrderedDict<std::string, std::shared_ptr<Module>>::ConstIterator#

Public Functions

ModuleDictImpl() = default#
inline explicit ModuleDictImpl(const std::vector<std::pair<std::string, std::shared_ptr<Module>>> &modules)#

Constructs the ModuleDict from a list of string-Module pairs.

inline explicit ModuleDictImpl(const torch::OrderedDict<std::string, std::shared_ptr<Module>> &modules)#

Constructs the ModuleDict from an OrderedDict.

inline std::vector<std::pair<std::string, std::shared_ptr<Module>>> items() const#

Return the items in the ModuleDict.

inline std::vector<std::string> keys() const#

Return the keys in the ModuleDict.

inline std::vector<std::shared_ptr<Module>> values() const#

Return the values in the ModuleDict.

inline Iterator begin()#

Return an iterator to the start of ModuleDict.

inline ConstIterator begin() const#

Return a const iterator to the start of ModuleDict.

inline Iterator end()#

Return an iterator to the end of ModuleDict.

inline ConstIterator end() const#

Return a const iterator to the end of ModuleDict.

inline size_t size() const noexcept#

Return the number of items currently stored in the ModuleDict.

inline bool empty() const noexcept#

Return true if the ModuleDict is empty, otherwise return false.

inline bool contains(const std::string &key) const noexcept#

Checks whether a module with the given key is stored in the ModuleDict.

inline void clear()#

Remove all items from the ModuleDict.

inline virtual std::shared_ptr<Module> clone(const std::optional<Device> &device = std::nullopt) const override#

Special cloning function for ModuleDict because it does not use reset().

inline virtual void reset() override#

reset() is empty for ModuleDict, since it does not have parameters of its own.

inline virtual void pretty_print(std::ostream &stream) const override#

Pretty prints the ModuleDict into the given stream.

inline std::shared_ptr<Module> operator[](const std::string &key) const#

Attempts to return the Module associated with the given key.

Throws an exception if no such key is stored in the ModuleDict. Call contains(key) first to avoid the exception.

template<typename T>
inline T &at(const std::string &key)#

Attempts to return the module at the given key as the requested type.

Throws an exception if no such key is stored in the ModuleDict. Call contains(key) first to avoid the exception.

template<typename T>
inline const T &at(const std::string &key) const#

Attempts to return the module at the given key as the requested type.

Throws an exception if no such key is stored in the ModuleDict. Call contains(key) first to avoid the exception.

inline std::shared_ptr<Module> pop(const std::string &key)#

Removes and returns the Module associated with the given key.

Throws an exception if no such key is stored in the ModuleDict. Call contains(key) first to avoid the exception.

inline void update(const std::vector<std::pair<std::string, std::shared_ptr<Module>>> &modules)#

Updates the ModuleDict with a vector of key-module pairs.

template<typename Container>
inline void update(const Container &container)#

Updates the ModuleDict with key-value pairs from an OrderedDict or another ModuleDict.