Skip to content

Commit c209183

Browse files
committed
Refactor the utility class
1 parent 0527d81 commit c209183

File tree

4 files changed

+119
-112
lines changed

4 files changed

+119
-112
lines changed

ml/src/main/java/com/baeldung/logreg/DataUtilities.java

Lines changed: 0 additions & 102 deletions
This file was deleted.

ml/src/main/java/com/baeldung/logreg/MnistClassifier.java

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -67,18 +67,20 @@ public static void main(String[] args) throws Exception {
6767

6868
final String path = basePath + "mnist_png" + File.separator;
6969
if (!new File(path).exists()) {
70-
logger.debug("Downloading data {}", dataUrl);
70+
logger.info("Downloading data {}", dataUrl);
7171
String localFilePath = basePath + "mnist_png.tar.gz";
72-
logger.info("local file: {}", localFilePath);
73-
if (DataUtilities.downloadFile(dataUrl, localFilePath)) {
74-
DataUtilities.extractTarGz(localFilePath, basePath);
72+
File file = new File(localFilePath);
73+
if (!file.exists()) {
74+
file.getParentFile()
75+
.mkdirs();
76+
Utils.downloadAndSave(dataUrl, file);
77+
Utils.extractTarArchive(file, basePath);
7578
}
7679
} else {
77-
logger.info("local file exists {}", path);
78-
80+
logger.info("Using the local data from folder {}", path);
7981
}
8082

81-
logger.info("Vectorizing data...");
83+
logger.info("Vectorizing the data from folder {}", path);
8284
// vectorization of train data
8385
File trainData = new File(path + "training");
8486
FileSplit trainSplit = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);

ml/src/main/java/com/baeldung/logreg/MnistPrediction.java

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,18 +36,22 @@ public static String fileChose() {
3636
}
3737

3838
public static void main(String[] args) throws IOException {
39-
String path = fileChose().toString();
39+
if (!modelPath.exists()) {
40+
logger.info("The model not found. Have you trained it?");
41+
return;
42+
}
4043
MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(modelPath);
44+
String path = fileChose();
4145
File file = new File(path);
4246

4347
INDArray image = new NativeImageLoader(height, width, channels).asMatrix(file);
4448
new ImagePreProcessingScaler(0, 1).transform(image);
45-
49+
4650
// Pass through to neural Net
4751
INDArray output = model.output(image);
4852

4953
logger.info("File: {}", path);
50-
logger.info(output.toString());
54+
logger.info("Probabilities: {}", output);
5155
}
5256

5357
}
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
package com.baeldung.logreg;
2+
3+
import java.io.BufferedInputStream;
4+
import java.io.BufferedOutputStream;
5+
import java.io.File;
6+
import java.io.FileInputStream;
7+
import java.io.FileOutputStream;
8+
import java.io.IOException;
9+
import java.io.InputStream;
10+
11+
import org.apache.commons.compress.archivers.ArchiveEntry;
12+
import org.apache.commons.compress.archivers.tar.TarArchiveEntry;
13+
import org.apache.commons.compress.archivers.tar.TarArchiveInputStream;
14+
import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream;
15+
import org.apache.http.HttpEntity;
16+
import org.apache.http.client.methods.CloseableHttpResponse;
17+
import org.apache.http.client.methods.HttpGet;
18+
import org.apache.http.impl.client.CloseableHttpClient;
19+
import org.apache.http.impl.client.HttpClientBuilder;
20+
import org.slf4j.Logger;
21+
import org.slf4j.LoggerFactory;
22+
23+
/**
24+
* Utility class for digit classifier.
25+
*
26+
*/
27+
public class Utils {
28+
29+
private static final Logger logger = LoggerFactory.getLogger(Utils.class);
30+
31+
private Utils() {
32+
}
33+
34+
/**
35+
* Download the content of the given url and save it into a file.
36+
* @param url
37+
* @param file
38+
*/
39+
public static void downloadAndSave(String url, File file) throws IOException {
40+
CloseableHttpClient client = HttpClientBuilder.create()
41+
.build();
42+
logger.info("Connecting to {}", url);
43+
try (CloseableHttpResponse response = client.execute(new HttpGet(url))) {
44+
HttpEntity entity = response.getEntity();
45+
if (entity != null) {
46+
logger.info("Downloaded {} bytes", entity.getContentLength());
47+
try (FileOutputStream outstream = new FileOutputStream(file)) {
48+
logger.info("Saving to the local file");
49+
entity.writeTo(outstream);
50+
outstream.flush();
51+
logger.info("Local file saved");
52+
}
53+
}
54+
}
55+
}
56+
57+
/**
58+
* Extract a "tar.gz" file into a given folder.
59+
* @param file
60+
* @param folder
61+
*/
62+
public static void extractTarArchive(File file, String folder) throws IOException {
63+
logger.info("Extracting archive {} into folder {}", file.getName(), folder);
64+
// @formatter:off
65+
try (FileInputStream fis = new FileInputStream(file);
66+
BufferedInputStream bis = new BufferedInputStream(fis);
67+
GzipCompressorInputStream gzip = new GzipCompressorInputStream(bis);
68+
TarArchiveInputStream tar = new TarArchiveInputStream(gzip)) {
69+
// @formatter:on
70+
TarArchiveEntry entry;
71+
while ((entry = (TarArchiveEntry) tar.getNextEntry()) != null) {
72+
extractEntry(entry, tar, folder);
73+
}
74+
}
75+
logger.info("Archive extracted");
76+
}
77+
78+
/**
79+
* Extract an entry of the input stream into a given folder
80+
* @param entry
81+
* @param tar
82+
* @param folder
83+
* @throws IOException
84+
*/
85+
public static void extractEntry(ArchiveEntry entry, InputStream tar, String folder) throws IOException {
86+
final int bufferSize = 4096;
87+
final String path = folder + entry.getName();
88+
if (entry.isDirectory()) {
89+
new File(path).mkdirs();
90+
} else {
91+
int count;
92+
byte[] data = new byte[bufferSize];
93+
// @formatter:off
94+
try (FileOutputStream os = new FileOutputStream(path);
95+
BufferedOutputStream dest = new BufferedOutputStream(os, bufferSize)) {
96+
// @formatter:off
97+
while ((count = tar.read(data, 0, bufferSize)) != -1) {
98+
dest.write(data, 0, count);
99+
}
100+
}
101+
}
102+
}
103+
}

0 commit comments

Comments
 (0)