Skip to content

Commit be67912

Browse files
author
sheng.xu
committed
train
1 parent d6c561b commit be67912

File tree

8 files changed

+40027
-0
lines changed

8 files changed

+40027
-0
lines changed

Chapter_6 BP/bp_train.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,14 +224,17 @@ def err_rate(label, pre):
224224
# 1、导入数据
225225
print "--------- 1.load data ------------"
226226
feature, label, n_class = load_data("data.txt")
227+
print feature, label, n_class
227228
# 2、训练网络模型
228229
print "--------- 2.training ------------"
229230
w0, w1, b0, b1 = bp_train(feature, label, 20, 1000, 0.1, n_class)
231+
print w0, w1,b0,b1
230232
# 3、保存最终的模型
231233
print "--------- 3.save model ------------"
232234
save_model(w0, w1, b0, b1)
233235
# 4、得到最终的预测结果
234236
print "--------- 4.get prediction ------------"
235237
result = get_predict(feature, w0, w1, b0, b1)
238+
print result
236239
print "训练准确性为:", (1 - err_rate(np.argmax(label, axis=1), np.argmax(result, axis=1)))
237240

Chapter_6 BP/bp_train.pyc

9.18 KB
Binary file not shown.

0 commit comments

Comments
 (0)