TensorFlow에서 affine 계층으로 신경망을 만들었는데, 체크포인트를 불러오고 나서는 학습이 진행되지 않아요.
조회수 913회
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# Load the MNIST data set from a local directory with one-hot labels.
# A raw string is used because the original literal depended on "\W" and "\m"
# being unrecognised escape sequences, which newer Python versions warn about.
# The resulting path is unchanged.
mnist = input_data.read_data_sets(r"C:\WD\mnist\mnist_data", one_hot=True)

# Graph inputs: a batch of 784-feature examples and their 10-way one-hot labels.
X = tf.placeholder(tf.float32, [None, 784])
Y = tf.placeholder(tf.float32, [None, 10])
# Dropout keep probability, fed on every run (this script feeds 0.7 while
# training and 1 when evaluating).
keep_prob = tf.placeholder(tf.float32)
# Non-trainable step counter handed to the saver.
# NOTE: the name is misspelled ('gloabal_step'), but it is the key under which
# the variable is stored in checkpoints, so renaming it would make existing
# checkpoints fail to restore. It is left as-is on purpose.
global_step = tf.Variable(0, trainable=False, name='gloabal_step')
def _affine(inputs, n_in, n_out):
    """Create one affine layer and return (weights, bias, inputs @ weights + bias).

    The weight variable is created before the bias variable, so variables are
    created in the order W, b for every layer.
    """
    weights = tf.Variable(tf.random_normal([n_in, n_out], stddev=0.01))
    bias = tf.Variable(tf.zeros([n_out]))
    return weights, bias, tf.add(tf.matmul(inputs, weights), bias)


# Layer 1: 784 -> 256, ReLU, dropout
W1, b1, L1 = _affine(X, 784, 256)
L1 = tf.nn.dropout(tf.nn.relu(L1), keep_prob)

# Layer 2: 256 -> 256, ReLU, dropout
W2, b2, L2 = _affine(L1, 256, 256)
L2 = tf.nn.dropout(tf.nn.relu(L2), keep_prob)

# Layer 3: 256 -> 100, ReLU, dropout
W3, b3, L3 = _affine(L2, 256, 100)
L3 = tf.nn.dropout(tf.nn.relu(L3), keep_prob)

# Layer 4: 100 -> 10, raw logits (no activation, no dropout)
W4, b4, model = _affine(L3, 100, 10)
# Mean softmax cross-entropy between the logits and the one-hot labels.
cost = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(logits=model, labels=Y))

# Passing global_step makes each optimizer update increment the counter.
# Without it the counter stays at 0 forever, so every checkpoint is written
# with the same "-0" suffix and the number of training steps is never recorded.
# NOTE(review): a learning rate of 0.01 may be too large for Adam to keep
# improving a model that is already close to converged; consider trying a
# smaller value (e.g. 0.001) if the cost plateaus. Not changed here.
optimizer = tf.train.AdamOptimizer(0.01).minimize(cost, global_step=global_step)

# session
sess = tf.Session()
# The saver is built after the optimizer and is given tf.global_variables(),
# i.e. whatever global variables exist in the graph at this point.
saver = tf.train.Saver(tf.global_variables())
#save params
def save_params():
    """Save the session's variables to the model3 checkpoint file.

    Uses the module-level ``saver``, ``sess`` and ``global_step``; the value
    of ``global_step`` is passed to the saver as the checkpoint step. Prints
    a confirmation line after saving.
    """
    # Raw string: the original literal mixed escaped and unescaped backslashes
    # and relied on "\y", "\P", "\m" and "\d" being unrecognised escapes.
    # The resulting path is byte-for-byte the same.
    saver.save(
        sess,
        r'C:\Users\yym30\PycharmProjects\tensorflow\mnist\data\models\model3\checkpoint0507.ckpt',
        global_step=global_step,
    )
    print('-------params saved!-------')
# Initialise every variable once up front. A successful restore below then
# overwrites the values of the variables tracked by the saver.
init = tf.global_variables_initializer()
sess.run(init)

# load checkpoint
# Raw string: same path as before, without depending on unrecognised escape
# sequences such as "\y" or "\m".
ckpt = tf.train.get_checkpoint_state(
    r'C:\Users\yym30\PycharmProjects\tensorflow\mnist\data\models\model3')
if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
    saver.restore(sess, ckpt.model_checkpoint_path)
    print('\n' + 'params loaded!')
else:
    # Nothing to restore. The variables were already initialised above, so the
    # second initializer the original built and ran here was redundant.
    print('params aren\'t loaded!')
# training
batch_size = 100
# Number of whole mini-batches in one pass over the training set.
total_batch = mnist.train.num_examples // batch_size

for epoch in range(30):
    total_cost = 0
    for i in range(total_batch):
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        # Dropout is active during training: keep 70% of the activations.
        train_feed = {X: batch_xs, Y: batch_ys, keep_prob: 0.7}
        _, cost_val = sess.run([optimizer, cost], feed_dict=train_feed)
        total_cost += cost_val

    # Report the mean mini-batch cost for this epoch, then checkpoint.
    average_cost = total_cost / total_batch
    print('Epoch', '%04d' % (epoch + 1), 'Average cost= ', '{:.3f}'.format(average_cost))
    save_params()

print("optimizing completed")
# calculating accuracy
# A prediction is correct when the arg-max of the logits matches the arg-max
# of the one-hot label.
predicted_class = tf.argmax(model, 1)
true_class = tf.argmax(Y, 1)
is_correct = tf.equal(predicted_class, true_class)
accuracy = tf.reduce_mean(tf.cast(is_correct, tf.float32))

# Evaluate on the test split with keep_prob fed as 1.
test_feed = {X: mnist.test.images, Y: mnist.test.labels, keep_prob: 1}
print('accuracy: ', sess.run(accuracy, feed_dict=test_feed), '!')
일단 체크포인트 로드하는 건 성공하는 듯 싶어요.
params loaded!
Epoch 0001 Average cost= 0.294
-------params saved!-------
Epoch 0002 Average cost= 0.297
-------params saved!-------
Epoch 0003 Average cost= 0.278
-------params saved!-------
Epoch 0004 Average cost= 0.268
-------params saved!-------
Epoch 0005 Average cost= 0.314
-------params saved!-------
Epoch 0006 Average cost= 0.273
-------params saved!-------
Epoch 0007 Average cost= 0.271
-------params saved!-------
Epoch 0008 Average cost= 0.301
-------params saved!-------
Epoch 0009 Average cost= 0.338
-------params saved!-------
Epoch 0010 Average cost= 0.360
-------params saved!-------
Epoch 0011 Average cost= 0.287
-------params saved!-------
Epoch 0012 Average cost= 0.273
-------params saved!-------
Epoch 0013 Average cost= 0.289
-------params saved!-------
Epoch 0014 Average cost= 0.273
-------params saved!-------
Epoch 0015 Average cost= 0.254
-------params saved!-------
Epoch 0016 Average cost= 0.256
-------params saved!-------
Epoch 0017 Average cost= 0.292
-------params saved!-------
Epoch 0018 Average cost= 0.288
-------params saved!-------
Epoch 0019 Average cost= 0.290
-------params saved!-------
Epoch 0020 Average cost= 0.316
-------params saved!-------
Epoch 0021 Average cost= 0.293
-------params saved!-------
Epoch 0022 Average cost= 0.291
-------params saved!-------
Epoch 0023 Average cost= 0.268
-------params saved!-------
Epoch 0024 Average cost= 0.317
-------params saved!-------
Epoch 0025 Average cost= 0.324
-------params saved!-------
Epoch 0026 Average cost= 0.323
-------params saved!-------
Epoch 0027 Average cost= 0.281
-------params saved!-------
Epoch 0028 Average cost= 0.272
-------params saved!-------
Epoch 0029 Average cost= 0.280
-------params saved!-------
Epoch 0030 Average cost= 0.319
-------params saved!-------
optimizing completed
accuracy: 0.9544 !
그런데 이렇게 체크포인트를 불러와서 학습을 이어 가면 cost가 항상 0.3 근처에 머물 뿐 더 나아지지 않아요. 왜 그럴까요?
댓글 입력