1
+ package org .dl4scala .examples .misc .customlayers .layer
2
+
3
+ import java .util
4
+
5
+ import org .deeplearning4j .nn .conf .layers .{FeedForwardLayer , Layer }
6
+ import CustomLayer .Builder
7
+ import org .deeplearning4j .nn .api
8
+ import org .deeplearning4j .nn .api .ParamInitializer
9
+ import org .deeplearning4j .nn .conf .NeuralNetConfiguration
10
+ import org .deeplearning4j .nn .params .DefaultParamInitializer
11
+ import org .deeplearning4j .optimize .api .IterationListener
12
+ import org .nd4j .linalg .activations .{Activation , IActivation }
13
+ import org .nd4j .linalg .api .ndarray .INDArray
14
+ /**
15
+ * Created by endy on 2017/7/1.
16
+ */
17
+ class CustomLayer (builder : Builder ) extends FeedForwardLayer (builder){
18
+ private var secondActivationFunction = builder.secondActivationFunction
19
+
20
+ def getSecondActivationFunction : IActivation = secondActivationFunction
21
+
22
+ def setSecondActivationFunction (secondActivationFunction : IActivation ): Unit = { // We also need setter/getter methods for our layer configuration fields (if any) for JSON serialization
23
+ this .secondActivationFunction = secondActivationFunction
24
+ }
25
+
26
+ override def instantiate (conf : NeuralNetConfiguration , iterationListeners : util.Collection [IterationListener ],
27
+ layerIndex : Int , layerParamsView : INDArray , initializeParams : Boolean ): api.Layer = {
28
+ // The instantiate method is how we go from the configuration class (i.e., this class) to the implementation class
29
+
30
+ // (i.e., a CustomLayerImpl instance)
31
+ // For the most part, it's the same for each type of layer
32
+
33
+ val myCustomLayer = new CustomLayerImpl (conf)
34
+ myCustomLayer.setListeners(iterationListeners) // Set the iteration listeners, if any
35
+ myCustomLayer.setIndex(layerIndex) // Integer index of the layer
36
+
37
+ // Parameter view array: the network parameters for the entire network (all layers) are
38
+ // allocated in one big array. The relevant section of this parameter vector is extracted out for each layer,
39
+ // (i.e., it's a "view" array in that it's a subset of a larger array)
40
+ // This is a row vector, with length equal to the number of parameters in the layer
41
+ myCustomLayer.setParamsViewArray(layerParamsView)
42
+
43
+ // Initialize the layer parameters. For example,
44
+ // Note that the entries in paramTable (2 entries here: a weight array of shape [nIn,nOut] and biases of shape [1,nOut]
45
+ // are in turn a view of the 'layerParamsView' array.
46
+ val paramTable = initializer().init(conf, layerParamsView, initializeParams)
47
+ myCustomLayer.setParamTable(paramTable)
48
+ myCustomLayer.setConf(conf)
49
+ myCustomLayer
50
+ }
51
+
52
+ // This method returns the parameter initializer for this type of layer
53
+ // In this case, we can use the DefaultParamInitializer, which is the same one used for DenseLayer
54
+ // For more complex layers, you may need to implement a custom parameter initializer
55
+ override def initializer (): ParamInitializer = DefaultParamInitializer .getInstance
56
+ }
57
+
58
+ object CustomLayer {
59
+ class Builder extends FeedForwardLayer .Builder [Builder ] {
60
+ var secondActivationFunction : IActivation = _
61
+
62
+ /**
63
+ * A custom property used in this custom layer example. See the CustomLayerExampleReadme.md for details
64
+ *
65
+ * @param secondActivationFunction Second activation function for the layer
66
+ */
67
+ def secondActivationFunction (secondActivationFunction : String ): Builder =
68
+ secondActivationFunction1(Activation .fromString(secondActivationFunction))
69
+
70
+ /**
71
+ * A custom property used in this custom layer example. See the CustomLayerExampleReadme.md for details
72
+ *
73
+ * @param secondActivationFunction Second activation function for the layer
74
+ */
75
+ def secondActivationFunction1 (secondActivationFunction : Activation ): Builder = {
76
+ this .secondActivationFunction = secondActivationFunction.getActivationFunction
77
+ this
78
+ }
79
+
80
+ override def build [E <: Layer ](): CustomLayer = {
81
+ new CustomLayer (this )
82
+ }
83
+ }
84
+ }
0 commit comments