Skip to content

Commit 18cf8f1

Browse files
committed
add MeansBuilder
1 parent 0d3e336 commit 18cf8f1

File tree

1 file changed

+45
-0
lines changed
  • dl4scala-examples/src/main/scala/org/dl4scala/examples/nlp/paragraphvectors/tools

1 file changed

+45
-0
lines changed
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package org.dl4scala.examples.nlp.paragraphvectors.tools
2+
3+
import java.util.concurrent.atomic.AtomicInteger
4+
5+
import lombok.NonNull
6+
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable
7+
import org.deeplearning4j.models.word2vec.VocabWord
8+
import org.deeplearning4j.text.documentiterator.LabelledDocument
9+
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory
10+
import org.nd4j.linalg.api.ndarray.INDArray
11+
import org.nd4j.linalg.factory.Nd4j
12+
import scala.collection.JavaConverters._
13+
14+
/**
15+
* Created by endy on 2017/6/25.
16+
*/
17+
/**
  * Builds a single vector for a document by averaging the vectors of its
  * tokens that are present in the lookup table's vocabulary.
  *
  * @param lookupTable      source of the vocabulary, the per-word vectors and the vector width
  * @param tokenizerFactory creates the tokenizer used to split document content into tokens
  */
class MeansBuilder(lookupTable: InMemoryLookupTable[VocabWord], tokenizerFactory: TokenizerFactory) {

  private val vocabCache = lookupTable.getVocab

  /**
    * Returns the centroid (mean vector) of a document.
    *
    * The document content is tokenized, tokens missing from the vocabulary are
    * dropped, and the vectors of the remaining tokens are averaged. A token that
    * occurs several times contributes one row per occurrence.
    *
    * @param document document whose content is tokenized and averaged
    * @return the mean, taken along dimension 0, of the matrix holding one row per
    *         in-vocabulary token
    * @throws IllegalArgumentException if none of the document's tokens is in the vocabulary
    */
  def documentAsVector(@NonNull document: LabelledDocument): INDArray = {
    // Single pass over the tokens: keep only those the vocabulary knows about.
    val knownWords = tokenizerFactory
      .create(document.getContent)
      .getTokens
      .asScala
      .filter(word => vocabCache.containsWord(word))

    // With no known tokens there are no rows to average; fail with a clear
    // message instead of trying to allocate and reduce a zero-row matrix.
    require(
      knownWords.nonEmpty,
      "Document contains no tokens present in the vocabulary; cannot compute a mean vector"
    )

    // One row per known token, in document order.
    val allWords = Nd4j.create(knownWords.size, lookupTable.layerSize)
    knownWords.zipWithIndex.foreach { case (word, row) =>
      allWords.putRow(row, lookupTable.vector(word))
    }

    allWords.mean(0)
  }
}

0 commit comments

Comments
 (0)