Skip to content

Commit def35ed

Browse files
author
linyiqun
committed
svm支持向量机代码的完善和注释的添加
svm支持向量机代码的完善和注释的添加
1 parent d5d6855 commit def35ed

File tree

8 files changed

+203
-4
lines changed

8 files changed

+203
-4
lines changed

DataMining_SVM/Client.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@ public static void main(String[] args){
1111
String trainDataPath = "C:\\Users\\lyq\\Desktop\\icon\\trainInput.txt";
1212
//测试数据文件路径
1313
String testDataPath = "C:\\Users\\lyq\\Desktop\\icon\\testInput.txt";
14+
15+
SVMTool tool = new SVMTool(trainDataPath);
16+
//对测试数据进行svm支持向量机分类
17+
tool.svmPredictData(testDataPath);
1418
}
1519

1620
}

DataMining_SVM/SVM.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,9 @@ public static void main(String[] args) {
4040
// 定义svm_parameter对象
4141
svm_parameter param = new svm_parameter();
4242
param.svm_type = svm_parameter.EPSILON_SVR;
43+
//设置svm的核函数类型为线型
4344
param.kernel_type = svm_parameter.LINEAR;
45+
//后面的参数配置只针对训练集的数据
4446
param.cache_size = 100;
4547
param.eps = 0.00001;
4648
param.C = 1.9;

DataMining_SVM/SVMTool.java

Lines changed: 164 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,173 @@
11
package DataMining_SVM;
22

3+
import java.io.BufferedReader;
4+
import java.io.File;
5+
import java.io.FileReader;
6+
import java.util.ArrayList;
7+
import java.util.List;
8+
9+
import DataMining_SVM.libsvm.svm;
10+
import DataMining_SVM.libsvm.svm_model;
11+
import DataMining_SVM.libsvm.svm_node;
12+
import DataMining_SVM.libsvm.svm_parameter;
13+
import DataMining_SVM.libsvm.svm_problem;
14+
315
/**
416
* SVM支持向量机工具类
17+
*
518
* @author lyq
6-
*
19+
*
720
*/
821
public class SVMTool {
22+
// 训练集数据文件路径
23+
private String trainDataPath;
24+
// svm_problem对象,用于构造svm model模型
25+
private svm_problem sProblem;
26+
// svm参数,里面有svm支持向量机的类型和不同 的svm的核函数类型
27+
private svm_parameter sParam;
28+
29+
public SVMTool(String trainDataPath) {
30+
this.trainDataPath = trainDataPath;
31+
32+
// 初始化svm相关变量
33+
sProblem = initSvmProblem();
34+
sParam = initSvmParam();
35+
}
36+
37+
/**
38+
* 初始化操作,根据训练集数据构造分类模型
39+
*/
40+
private void initOperation(){
41+
42+
}
43+
44+
/**
45+
* svm_problem对象,训练集数据的相关信息配置
46+
*
47+
* @return
48+
*/
49+
private svm_problem initSvmProblem() {
50+
List<Double> label = new ArrayList<Double>();
51+
List<svm_node[]> nodeSet = new ArrayList<svm_node[]>();
52+
getData(nodeSet, label, trainDataPath);
53+
54+
int dataRange = nodeSet.get(0).length;
55+
svm_node[][] datas = new svm_node[nodeSet.size()][dataRange]; // 训练集的向量表
56+
for (int i = 0; i < datas.length; i++) {
57+
for (int j = 0; j < dataRange; j++) {
58+
datas[i][j] = nodeSet.get(i)[j];
59+
}
60+
}
61+
double[] lables = new double[label.size()]; // a,b 对应的lable
62+
for (int i = 0; i < lables.length; i++) {
63+
lables[i] = label.get(i);
64+
}
65+
66+
// 定义svm_problem对象
67+
svm_problem problem = new svm_problem();
68+
problem.l = nodeSet.size(); // 向量个数
69+
problem.x = datas; // 训练集向量表
70+
problem.y = lables; // 对应的lable数组
71+
72+
return problem;
73+
}
74+
75+
/**
76+
* 初始化svm支持向量机的参数,包括svm的类型和核函数的类型
77+
*
78+
* @return
79+
*/
80+
private svm_parameter initSvmParam() {
81+
// 定义svm_parameter对象
82+
svm_parameter param = new svm_parameter();
83+
param.svm_type = svm_parameter.EPSILON_SVR;
84+
// 设置svm的核函数类型为线型
85+
param.kernel_type = svm_parameter.LINEAR;
86+
// 后面的参数配置只针对训练集的数据
87+
param.cache_size = 100;
88+
param.eps = 0.00001;
89+
param.C = 1.9;
90+
91+
return param;
92+
}
93+
94+
/**
95+
* 通过svm方式预测数据的类型
96+
*
97+
* @param testDataPath
98+
*/
99+
public void svmPredictData(String testDataPath) {
100+
// 获取测试数据
101+
List<Double> testlabel = new ArrayList<Double>();
102+
List<svm_node[]> testnodeSet = new ArrayList<svm_node[]>();
103+
getData(testnodeSet, testlabel, testDataPath);
104+
int dataRange = testnodeSet.get(0).length;
105+
106+
svm_node[][] testdatas = new svm_node[testnodeSet.size()][dataRange]; // 训练集的向量表
107+
for (int i = 0; i < testdatas.length; i++) {
108+
for (int j = 0; j < dataRange; j++) {
109+
testdatas[i][j] = testnodeSet.get(i)[j];
110+
}
111+
}
112+
// 测试数据的真实值,在后面将会与svm的预测值做比较
113+
double[] testlables = new double[testlabel.size()]; // a,b 对应的lable
114+
for (int i = 0; i < testlables.length; i++) {
115+
testlables[i] = testlabel.get(i);
116+
}
117+
118+
// 如果参数没有问题,则svm.svm_check_parameter()函数返回null,否则返回error描述。
119+
// 对svm的配置参数叫验证,因为有些参数只针对部分的支持向量机的类型
120+
System.out.println(svm.svm_check_parameter(sProblem, sParam));
121+
System.out.println("------------检验参数-----------");
122+
// 训练SVM分类模型
123+
svm_model model = svm.svm_train(sProblem, sParam);
124+
125+
// 预测测试数据的lable
126+
double err = 0.0;
127+
for (int i = 0; i < testdatas.length; i++) {
128+
double truevalue = testlables[i];
129+
// 测试数据真实值
130+
System.out.print(truevalue + " ");
131+
double predictValue = svm.svm_predict(model, testdatas[i]);
132+
// 测试数据预测值
133+
System.out.println(predictValue);
134+
}
135+
}
136+
137+
/**
138+
* 从文件中获取数据
139+
*
140+
* @param nodeSet
141+
* 向量节点
142+
* @param label
143+
* 节点值类型值
144+
* @param filename
145+
* 数据文件地址
146+
*/
147+
private void getData(List<svm_node[]> nodeSet, List<Double> label,
148+
String filename) {
149+
try {
150+
151+
FileReader fr = new FileReader(new File(filename));
152+
BufferedReader br = new BufferedReader(fr);
153+
String line = null;
154+
while ((line = br.readLine()) != null) {
155+
String[] datas = line.split(",");
156+
svm_node[] vector = new svm_node[datas.length - 1];
157+
for (int i = 0; i < datas.length - 1; i++) {
158+
svm_node node = new svm_node();
159+
node.index = i + 1;
160+
node.value = Double.parseDouble(datas[i]);
161+
vector[i] = node;
162+
}
163+
nodeSet.add(vector);
164+
double lablevalue = Double.parseDouble(datas[datas.length - 1]);
165+
label.add(lablevalue);
166+
}
167+
} catch (Exception e) {
168+
e.printStackTrace();
169+
}
170+
171+
}
9172

10173
}

DataMining_SVM/libsvm/svm.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2637,6 +2637,12 @@ else if(cmd.startsWith("SV"))
26372637
return model;
26382638
}
26392639

2640+
/**
2641+
* Validates the SVM configuration parameters, since some parameters only apply to certain SVM types.
2642+
* @param prob
2643+
* @param param
2644+
* @return
2645+
*/
26402646
public static String svm_check_parameter(svm_problem prob, svm_parameter param)
26412647
{
26422648
// svm_type

DataMining_SVM/libsvm/svm_model.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
package DataMining_SVM.libsvm;
55
public class svm_model implements java.io.Serializable
66
{
7+
//svm支持向量机的参数
78
svm_parameter param; // parameter
9+
//分类的类型数
810
int nr_class; // number of classes, = 2 in regression/one class svm
911
int l; // total #SV
1012
svm_node[][] SV; // SVs (SV[l])
@@ -15,6 +17,7 @@ public class svm_model implements java.io.Serializable
1517

1618
// for classification only
1719

20+
//每个类型的类型值
1821
int[] label; // label of each class (label[k])
1922
int[] nSV; // number of SVs for each class (nSV[k])
2023
// nSV[0] + nSV[1] + ... + nSV[k-1] = l

DataMining_SVM/libsvm/svm_node.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,14 @@
11
package DataMining_SVM.libsvm;
2+
/**
3+
*
4+
* A single node (one entry) of an SVM vector.
5+
* @author lyq
6+
*
7+
*/
28
public class svm_node implements java.io.Serializable
39
{
10+
// Index of the node
411
public int index;
12+
// Value of the node
513
public double value;
614
}

DataMining_SVM/libsvm/svm_parameter.java

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,22 @@
11
package DataMining_SVM.libsvm;
22
public class svm_parameter implements Cloneable,java.io.Serializable
33
{
4-
/* svm_type */
4+
/* svm_type 支持向量机的类型*/
55
public static final int C_SVC = 0;
66
public static final int NU_SVC = 1;
7+
//一类svm
78
public static final int ONE_CLASS = 2;
89
public static final int EPSILON_SVR = 3;
910
public static final int NU_SVR = 4;
1011

11-
/* kernel_type */
12+
/* kernel_type 核函数类型*/
13+
//线型核函数
1214
public static final int LINEAR = 0;
15+
//多项式核函数
1316
public static final int POLY = 1;
17+
//RBF径向基函数
1418
public static final int RBF = 2;
19+
//二层神经网络核函数
1520
public static final int SIGMOID = 3;
1621
public static final int PRECOMPUTED = 4;
1722

@@ -21,7 +26,7 @@ public class svm_parameter implements Cloneable,java.io.Serializable
2126
public double gamma; // for poly/rbf/sigmoid
2227
public double coef0; // for poly/sigmoid
2328

24-
// these are for training only
29+
// these are for training only 后面这些参数只针对训练集的数据
2530
public double cache_size; // in MB
2631
public double eps; // stopping criteria
2732
public double C; // for C_SVC, EPSILON_SVR and NU_SVR
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,15 @@
11
package DataMining_SVM.libsvm;
2+
/**
3+
* Holds the basic information of the training data set.
4+
* @author lyq
5+
*
6+
*/
27
public class svm_problem implements java.io.Serializable
38
{
9+
// Total number of vectors
410
public int l;
11+
// Array of label (class type) values
512
public double[] y;
13+
// Table of training set vectors
614
public svm_node[][] x;
715
}

0 commit comments

Comments
 (0)