Skip to content

Commit 585e63c

Browse files
committed
add CustomLayer
1 parent 59d7add commit 585e63c

File tree

1 file changed

+84
-0
lines changed
  • dl4scala-examples/src/main/scala/org/dl4scala/examples/misc/customlayers/layer

1 file changed

+84
-0
lines changed
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
package org.dl4scala.examples.misc.customlayers.layer
2+
3+
import java.util
4+
5+
import org.deeplearning4j.nn.conf.layers.{FeedForwardLayer, Layer}
6+
import CustomLayer.Builder
7+
import org.deeplearning4j.nn.api
8+
import org.deeplearning4j.nn.api.ParamInitializer
9+
import org.deeplearning4j.nn.conf.NeuralNetConfiguration
10+
import org.deeplearning4j.nn.params.DefaultParamInitializer
11+
import org.deeplearning4j.optimize.api.IterationListener
12+
import org.nd4j.linalg.activations.{Activation, IActivation}
13+
import org.nd4j.linalg.api.ndarray.INDArray
14+
/**
15+
* Created by endy on 2017/7/1.
16+
*/
17+
class CustomLayer(builder: Builder) extends FeedForwardLayer(builder){
18+
private var secondActivationFunction = builder.secondActivationFunction
19+
20+
def getSecondActivationFunction: IActivation = secondActivationFunction
21+
22+
def setSecondActivationFunction(secondActivationFunction: IActivation): Unit = { //We also need setter/getter methods for our layer configuration fields (if any) for JSON serialization
23+
this.secondActivationFunction = secondActivationFunction
24+
}
25+
26+
override def instantiate(conf: NeuralNetConfiguration, iterationListeners: util.Collection[IterationListener],
27+
layerIndex: Int, layerParamsView: INDArray, initializeParams: Boolean): api.Layer = {
28+
//The instantiate method is how we go from the configuration class (i.e., this class) to the implementation class
29+
30+
// (i.e., a CustomLayerImpl instance)
31+
//For the most part, it's the same for each type of layer
32+
33+
val myCustomLayer = new CustomLayerImpl(conf)
34+
myCustomLayer.setListeners(iterationListeners) //Set the iteration listeners, if any
35+
myCustomLayer.setIndex(layerIndex) //Integer index of the layer
36+
37+
// Parameter view array: the network parameters for the entire network (all layers) are
38+
// allocated in one big array. The relevant section of this parameter vector is extracted out for each layer,
39+
// (i.e., it's a "view" array in that it's a subset of a larger array)
40+
// This is a row vector, with length equal to the number of parameters in the layer
41+
myCustomLayer.setParamsViewArray(layerParamsView)
42+
43+
// Initialize the layer parameters. For example,
44+
// Note that the entries in paramTable (2 entries here: a weight array of shape [nIn,nOut] and biases of shape [1,nOut]
45+
// are in turn a view of the 'layerParamsView' array.
46+
val paramTable = initializer().init(conf, layerParamsView, initializeParams)
47+
myCustomLayer.setParamTable(paramTable)
48+
myCustomLayer.setConf(conf)
49+
myCustomLayer
50+
}
51+
52+
// This method returns the parameter initializer for this type of layer
53+
// In this case, we can use the DefaultParamInitializer, which is the same one used for DenseLayer
54+
// For more complex layers, you may need to implement a custom parameter initializer
55+
override def initializer(): ParamInitializer = DefaultParamInitializer.getInstance
56+
}
57+
58+
object CustomLayer {
59+
class Builder extends FeedForwardLayer.Builder[Builder] {
60+
var secondActivationFunction: IActivation = _
61+
62+
/**
63+
* A custom property used in this custom layer example. See the CustomLayerExampleReadme.md for details
64+
*
65+
* @param secondActivationFunction Second activation function for the layer
66+
*/
67+
def secondActivationFunction(secondActivationFunction: String): Builder =
68+
secondActivationFunction1(Activation.fromString(secondActivationFunction))
69+
70+
/**
71+
* A custom property used in this custom layer example. See the CustomLayerExampleReadme.md for details
72+
*
73+
* @param secondActivationFunction Second activation function for the layer
74+
*/
75+
def secondActivationFunction1(secondActivationFunction: Activation): Builder = {
76+
this.secondActivationFunction = secondActivationFunction.getActivationFunction
77+
this
78+
}
79+
80+
override def build[E <: Layer](): CustomLayer = {
81+
new CustomLayer(this)
82+
}
83+
}
84+
}

0 commit comments

Comments
 (0)