File tree Expand file tree Collapse file tree 1 file changed +45
-0
lines changed
dl4scala-examples/src/main/scala/org/dl4scala/examples/nlp/paragraphvectors/tools Expand file tree Collapse file tree 1 file changed +45
-0
lines changed Original file line number Diff line number Diff line change
1
+ package org .dl4scala .examples .nlp .paragraphvectors .tools
2
+
3
+ import java .util .concurrent .atomic .AtomicInteger
4
+
5
+ import lombok .NonNull
6
+ import org .deeplearning4j .models .embeddings .inmemory .InMemoryLookupTable
7
+ import org .deeplearning4j .models .word2vec .VocabWord
8
+ import org .deeplearning4j .text .documentiterator .LabelledDocument
9
+ import org .deeplearning4j .text .tokenization .tokenizerfactory .TokenizerFactory
10
+ import org .nd4j .linalg .api .ndarray .INDArray
11
+ import org .nd4j .linalg .factory .Nd4j
12
+ import scala .collection .JavaConverters ._
13
+
14
+ /**
15
+ * Created by endy on 2017/6/25.
16
+ */
17
+ class MeansBuilder (lookupTable : InMemoryLookupTable [VocabWord ], tokenizerFactory : TokenizerFactory ) {
18
+
19
+ private val vocabCache = lookupTable.getVocab
20
+
21
+ /**
22
+ * This method returns centroid (mean vector) for document.
23
+ *
24
+ * @param document
25
+ */
26
+ def documentAsVector (@ NonNull document : LabelledDocument ): INDArray = {
27
+ val documentAsTokens = tokenizerFactory.create(document.getContent).getTokens.asScala
28
+ val cnt = new AtomicInteger (0 )
29
+
30
+ for (word <- documentAsTokens) {
31
+ if (vocabCache.containsWord(word)) cnt.incrementAndGet
32
+ }
33
+
34
+ val allWords = Nd4j .create(cnt.get, lookupTable.layerSize)
35
+ cnt.set(0 )
36
+
37
+ for (word <- documentAsTokens) {
38
+ if (vocabCache.containsWord(word)) allWords.putRow(cnt.getAndIncrement, lookupTable.vector(word))
39
+ }
40
+
41
+ val mean = allWords.mean(0 )
42
+
43
+ mean
44
+ }
45
+ }
You can’t perform that action at this time.
0 commit comments