Skip to content

Commit d9d5808

Browse files
add __call__() to Layer class
1 parent 142a5f4 commit d9d5808

File tree

3 files changed

+14
-4
lines changed

3 files changed

+14
-4
lines changed
11.9 MB
Binary file not shown.

src/TensorArray/core/tensor2.so

4 MB
Binary file not shown.

src/tensor_array/util/layer.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,30 @@
11
from collections import OrderedDict, namedtuple
22
from typing import Union, Tuple, Any, Callable, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict, List
33
from typing import Any
4-
from ...tensor_array.core import tensor2 as t
4+
from tensor_array.core import tensor2 as t
55
from .parameter import Parameter
66

77
class Layer:
8-
_layers = Dict[str, Optional['Layer']]
9-
_parameters = Dict[str, Optional[Parameter]]
10-
_tensors = Dict[str, Optional[t.Tensor]]
8+
is_running: bool
9+
_layers: Dict[str, Optional['Layer']]
10+
_parameters: Dict[str, Optional[Parameter]]
11+
_tensors: Dict[str, Optional[t.Tensor]]
1112

1213
def __init__(self) -> None:
14+
super().__setattr__('is_running', False)
1315
super().__setattr__('_layers', OrderedDict())
1416
super().__setattr__('_parameters', OrderedDict())
1517
super().__setattr__('_tensors', OrderedDict())
1618

1719
def __call__(self, *args: Any, **kwds: Any) -> Any:
20+
if not self.__dict__.get('is_running'):
21+
self.init_value(args, kwds)
22+
self.calculate(args, kwds)
23+
24+
def init_value(self, *args: Any, **kwds: Any) -> Any:
25+
pass
26+
27+
def calculate(self, *args: Any, **kwds: Any) -> Any:
1828
pass
1929

2030
def register_parameter(self, name: str, param: Optional[Parameter]) -> None:

0 commit comments

Comments
 (0)