Training MNIST with TFRecords-format data

TFRecords is TensorFlow's binary format for storing data. It makes better use of memory, is easy to copy and move, and needs no separate label files, much like LMDB and LevelDB in Caffe, and it greatly improves I/O throughput.

A TFRecords file contains tf.train.Example protocol buffers (each of which holds a Features field). To create one, write code that fetches your data, fills it into an Example protocol buffer, serializes the protocol buffer to a string, and writes it to the TFRecords file via tf.python_io.TFRecordWriter.
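As a minimal sketch of this write path (the sample list here is made up; the full converter appears in the listing below):

import tensorflow as tf

# Hypothetical input: raw image bytes paired with an integer label
samples = [(b'\x00' * 28 * 28 * 3, 7)]

writer = tf.python_io.TFRecordWriter("demo.tfrecords")
for img_raw, label in samples:
    example = tf.train.Example(features=tf.train.Features(feature={
        'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
        'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
    }))
    writer.write(example.SerializeToString())
writer.close()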

To read data back from a TFRecords file, use a tf.TFRecordReader together with the tf.parse_single_example decoder. This op parses an Example protocol buffer into tensors.
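In graph form, and assuming the same two features as above, the read path looks roughly like this (tfrecord_to_data() in the full listing follows the same pattern):

filename_queue = tf.train.string_input_producer(["demo.tfrecords"])
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(serialized_example, features={
    'label': tf.FixedLenFeature([], tf.int64),
    'img_raw': tf.FixedLenFeature([], tf.string),
})
# 'img_raw' comes back as a scalar string tensor; decode it to pixels
image = tf.decode_raw(features['img_raw'], tf.uint8)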

Advantages of this queue-based pipeline (a minimal end-to-end sketch follows this list):

First, the graph in TensorFlow can keep state, which lets the TFRecordReader remember how far into the TFRecords file it has read and always return the next record. This requires the whole graph to be initialized before use; the older API for this was tf.initialize_all_variables(), now deprecated in favor of tf.global_variables_initializer() (used in the code below).

Second, queues in TensorFlow behave much like ordinary queues, except that the operations and tensors inside them are symbolic: nothing actually runs until sess.run() is called.

Third, the TFRecordReader keeps dequeuing file names from the filename queue until the queue is empty.
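Putting the three points together, a minimal driver for the read graph sketched above might look like this (the Coordinator is optional but is the usual way to shut the queue threads down cleanly):

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # The reader and the filename queue are stateful ops in the graph;
    # queue runners must be started before sess.run() can dequeue anything
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    pixels = sess.run(image)  # each call returns the next record
    coord.request_stop()
    coord.join(threads)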

import os
import tensorflow as tf
from PIL import Image

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

def data_to_tfrecord(images, labels, filename):
    # Save a list of image paths and labels into a TFRecord file
    if os.path.isfile(filename):
        print("%s exists" % filename)
        return
    print("Converting data into %s ..." % filename)
    writer = tf.python_io.TFRecordWriter(filename)
    for index, img_name in enumerate(images):
        print(index)
        img = Image.open(img_name)
        img = img.resize((28, 28))
        img_raw = img.tobytes()

        example = tf.train.Example(features=tf.train.Features(feature={
            'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[int(labels[index])])),
            'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
        }))
        writer.write(example.SerializeToString())  # serialize to string
    writer.close()

def tfrecord_to_data(filename):
    # Build a filename queue and a reader that parses one example at a time
    print("reading tfrecords from {}".format(filename))
    filename_queue = tf.train.string_input_producer([filename])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example, features={
        'label': tf.FixedLenFeature([], tf.int64),
        'img_raw': tf.FixedLenFeature([], tf.string),
    })
    img = tf.decode_raw(features['img_raw'], tf.uint8)
    img = tf.reshape(img, [28, 28, 3])  # images were stored as 28x28 RGB
    img = tf.cast(img, tf.float32) * (1. / 255) - 0.5  # normalize to [-0.5, 0.5]
    label = tf.cast(features['label'], tf.int64)
    return img, label

def decode_tfrecords(filename):
    # Print the raw contents of a TFRecords file for inspection
    for serialized_example in tf.python_io.tf_record_iterator(filename):
        example = tf.train.Example()
        example.ParseFromString(serialized_example)

        image = example.features.feature['img_raw'].bytes_list.value
        label = example.features.feature['label'].int64_list.value
        print(image, label)

def read_data_from_paths(file_path, name):
    # Read "image_path label" pairs, one per line, from a text index file
    labels = []
    file_names = []

    file_name = os.path.join(file_path, name)
    with open(file_name, 'r') as train_txt:
        for line in train_txt:
            line = line.rstrip('\n')
            spt = line.split(' ')
            file_names.append(os.path.join(file_path, spt[0]))
            labels.append(spt[1])
    return file_names, labels

def train():
    # Network
    batch_size = 64
    inputs = tf.placeholder(tf.float32, [batch_size, 28, 28, 3], name='inputs')
    conv1 = tf.layers.conv2d(inputs=inputs, filters=64, kernel_size=(3, 3), padding="same", activation=None)
    pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)
    conv2 = tf.layers.conv2d(inputs=pool1, filters=128, kernel_size=(3, 3), padding="same", activation=None)
    pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)

    pool2_flat = tf.reshape(pool2, [-1, 7 * 7 * 128])
    fc1 = tf.layers.dense(pool2_flat, 500, activation=tf.nn.relu)
    fc2 = tf.layers.dense(fc1, 10, activation=None)  # logits; no ReLU before softmax
    y_out = tf.nn.softmax(fc2)

    y_ = tf.placeholder(tf.float32, [batch_size, 10])
    # Cross entropy; the epsilon guards against log(0)
    cross_entropy = -tf.reduce_mean(y_ * tf.log(y_out + 1e-10))

    # Feed the learning rate through a placeholder so it can actually be
    # decayed during training (reassigning a Python float after minimize()
    # has no effect on the graph)
    learning_rate = tf.placeholder(tf.float32, [], name='learning_rate')
    train_step = tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy)
    correct_prediction = tf.equal(tf.argmax(y_out, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

    img, label = tfrecord_to_data("./mnist_train.tfrecords")
    img_batch, label_batch = tf.train.shuffle_batch([img, label], batch_size=batch_size, capacity=2000, min_after_dequeue=1000)
    # Build the one-hot op once, outside the training loop, so the graph
    # does not grow with every iteration
    label_batch_onehot = tf.one_hot(label_batch, depth=10)

    init = tf.global_variables_initializer()

    with tf.Session() as session:
        session.run(init)
        threads = tf.train.start_queue_runners(sess=session)
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=20)
        lr = 1e-3
        for i in range(400):
            img_batch_i, label_batch_i = session.run([img_batch, label_batch_onehot])

            feed = {inputs: img_batch_i, y_: label_batch_i, learning_rate: lr}
            loss, _, acc = session.run([cross_entropy, train_step, accuracy], feed_dict=feed)

            print("step %d loss: %f accuracy: %f" % (i, loss, acc))
            if i == 100:
                lr = lr * 0.1  # decay the learning rate once after 100 steps
            if i > 100:
                saver.save(session, "./save/mnist.ckpt")

if __name__ == "__main__":
    # Function 1: convert images to tfrecords; comment out if not needed
    file_path = ""
    name = "test.txt"
    file_names, labels = read_data_from_paths(file_path, name)
    tfrecord_name = "mnist_test.tfrecords"
    data_to_tfrecord(file_names, labels, tfrecord_name)

    # Function 2: decode tfrecords for inspection; comment out if not needed
    decode_tfrecords("./mnist_test.tfrecords")

    # Function 3: train from tfrecords; comment out if not needed
    train()
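For reference, read_data_from_paths() expects a plain-text index file (test.txt above) with one "image_path label" pair per line; the file names below are made up:

images/0.png 7
images/1.png 2
images/2.png 1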

 
