python - How to feed a CIFAR-10 trained model my own image and get the label as output?


I am trying to use a model trained with the CIFAR-10 tutorial and feed it an external 32x32 image (JPG or PNG).
My goal is to get the label as output. In other words, I want to feed the network a single JPEG image of size 32 x 32 with 3 channels and no label as input, and have the inference process give me tf.argmax(logits, 1).
Basically, I want to use the trained CIFAR-10 model on an external image and see which class it spits out.

I have been trying to do this based on the CIFAR-10 tutorial, but unfortunately I keep running into issues, mostly with the session concept and the batch concept.

Any help doing this with CIFAR-10 would be appreciated.

Here is the code I have implemented so far, which has compilation issues:

```python
#!/usr/bin/env python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from datetime import datetime
import math
import time

import tensorflow.python.platform
from tensorflow.python.platform import gfile
import numpy as np
import tensorflow as tf

import cifar10
import cifar10_input
import os
import faultnet_flags
from PIL import Image

FLAGS = tf.app.flags.FLAGS


def evaluate():
    filename_queue = tf.train.string_input_producer(['/home/tensor/.../inputimage.jpg'])

    reader = tf.WholeFileReader()
    key, value = reader.read(filename_queue)

    input_img = tf.image.decode_jpeg(value)

    init_op = tf.initialize_all_variables()

    # Problem here: graph / session
    with tf.Session() as sess:
        sess.run(init_op)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        for i in range(1):
            image = input_img.eval()

        print(image.shape)
        Image.fromarray(np.asarray(image)).show()

        # Problem here: I have 1 image as input, with no label, and it has
        # to be compatible with the CIFAR-10 network
        reshaped_image = tf.cast(image, tf.float32)
        height = FLAGS.resized_image_size
        width = FLAGS.resized_image_size
        resized_image = tf.image.resize_image_with_crop_or_pad(reshaped_image, width, height)
        float_image = tf.image.per_image_whitening(resized_image)  # reshaped_image
        num_preprocess_threads = 1
        images = tf.train.batch(
            [float_image],
            batch_size=128,
            num_threads=num_preprocess_threads,
            capacity=128)
        coord.request_stop()
        coord.join(threads)

        logits = faultnet.inference(images)

        # Calculate predictions.
        # top_k_predict_op = tf.argmax(logits, 1)

        # print('Current image is: ')
        # print(top_k_predict_op[0])

        # This does not work since there is a problem with the session
        # and the graph conflicting
        my_classification = sess.run(tf.argmax(logits, 1))

        print('Predicted ', my_classification[0], " for your input image.")


def main(argv=None):
    evaluate()


if __name__ == '__main__':
    tf.app.run()
```

Some basics first:

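  1. First define the graph: image queue, image preprocessing, inference of the convnet, top-k predictions.
  2. Then create a tf.Session() and work inside it: start the queue runners and make your calls to sess.run(). Note that tf.train.start_queue_runners() only starts the runners of queues that already exist in the graph, so a tf.train.batch created after that call (as in your code) is never filled and sess.run() blocks.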
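To make the "build first, run later" idea concrete, here is a framework-free toy sketch. `Node` and `ToySession` are hypothetical names invented for illustration and are not TensorFlow APIs; the point is only that creating nodes computes nothing, and the whole chain is evaluated in dependency order when you ask the session for the final result.

```python
# A framework-free toy of the "build the graph first, run it later" model.
# Node and ToySession are hypothetical illustration classes, NOT TensorFlow APIs.


class Node:
    """Records an operation and its inputs; nothing is computed on creation."""

    def __init__(self, name, fn, *inputs):
        self.name = name
        self.fn = fn
        self.inputs = inputs


class ToySession:
    """Evaluates a Node on demand by first evaluating its inputs (like sess.run)."""

    def __init__(self):
        self.log = []  # names of nodes in the order they were actually computed

    def run(self, node):
        args = [self.run(i) for i in node.inputs]
        self.log.append(node.name)
        return node.fn(*args)


# 1. Graph creation: describe the pipeline, no work happens yet.
image = Node("image", lambda: [10.0, 20.0, 30.0])  # stand-in for a decoded image
normalized = Node("normalize", lambda xs: [x / 30.0 for x in xs], image)
logits = Node("inference", lambda xs: [sum(xs), -sum(xs)], normalized)  # fake 2-class "model"
prediction = Node("argmax", lambda ls: max(range(len(ls)), key=ls.__getitem__), logits)

# 2. Session: only now is the chain evaluated, in dependency order.
sess = ToySession()
print(sess.run(prediction))  # 0
print(sess.log)  # ['image', 'normalize', 'inference', 'argmax']
```

In real TensorFlow the queue runners play a similar role to the inputs here: they have to exist in the graph before the session starts them.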

Here is what your code should look like:

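```python
# 1. Graph creation
filename_queue = tf.train.string_input_producer(['/home/tensor/.../inputimage.jpg'])
...
# no creation of tf.Session here
float_image = ...
images = tf.expand_dims(float_image, 0)  # create a fake batch of images (batch_size=1)
logits = faultnet.inference(images)
_, top_k_pred = tf.nn.top_k(logits, k=5)
# Create the init op after the model so that it covers the model's variables.
# Note: this gives random weights; to use your trained weights, restore them
# from your checkpoint with tf.train.Saver().restore(sess, ...) instead.
init_op = tf.initialize_all_variables()

# 2. TensorFlow session
with tf.Session() as sess:
    sess.run(init_op)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    top_indices = sess.run([top_k_pred])
    print("Predicted ", top_indices[0], " for your input image.")

    coord.request_stop()
    coord.join(threads)
```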
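To see what the fake batch and the top-k step do to shapes, and how to turn the resulting indices into readable labels, here is a NumPy-only sketch. `add_batch_dim` and `top_k_labels` are hypothetical helpers that mimic `tf.expand_dims(..., 0)` and `tf.nn.top_k` for a batch of one; the logits are made-up numbers, and `CIFAR10_CLASSES` assumes your model uses the standard CIFAR-10 label order. With `k=1` this is the same as the `tf.argmax(logits, 1)` from the question.

```python
import numpy as np

# Standard CIFAR-10 class names, in the dataset's label order (assumed to match your model).
CIFAR10_CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
                   "dog", "frog", "horse", "ship", "truck"]


def add_batch_dim(float_image):
    """Hypothetical helper mimicking tf.expand_dims(float_image, 0): (H, W, C) -> (1, H, W, C)."""
    return np.expand_dims(float_image, axis=0)


def top_k_labels(logits, k=5):
    """Hypothetical helper mimicking tf.nn.top_k on logits of shape (1, 10).

    Returns (indices, names) for the single image in the batch, best first.
    """
    scores = np.asarray(logits)[0]          # drop the fake batch dimension
    indices = np.argsort(scores)[::-1][:k]  # indices sorted by descending score
    return indices.tolist(), [CIFAR10_CLASSES[i] for i in indices]


# Example with made-up numbers (not real model output):
batch = add_batch_dim(np.zeros((24, 24, 3), dtype=np.float32))
print(batch.shape)  # (1, 24, 24, 3)

fake_logits = np.array([[0.1, 2.0, -1.0, 3.5, 0.0, 1.2, -0.5, 0.3, 2.7, -2.0]])
print(top_k_labels(fake_logits, k=3))  # ([3, 8, 1], ['cat', 'ship', 'automobile'])
```

In the TensorFlow version you would apply the same name lookup to `top_indices[0][0]` after `sess.run`.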

Edit:

As @mrry suggests, if you only need to work on a single image, you can remove the queue runners:

```python
# 1. Graph creation
input_img = tf.image.decode_jpeg(tf.read_file("/home/.../your_image.jpg"), channels=3)
height = width = FLAGS.resized_image_size  # same crop size as in the question
reshaped_image = tf.image.resize_image_with_crop_or_pad(tf.cast(input_img, tf.float32), height, width)
float_image = tf.image.per_image_whitening(reshaped_image)
images = tf.expand_dims(float_image, 0)  # create a fake batch of images (batch_size=1)
logits = faultnet.inference(images)
_, top_k_pred = tf.nn.top_k(logits, k=5)
init_op = tf.initialize_all_variables()  # or restore your trained weights with tf.train.Saver

# 2. TensorFlow session
with tf.Session() as sess:
    sess.run(init_op)

    top_indices = sess.run([top_k_pred])
    print("Predicted ", top_indices[0], " for your input image.")
```
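If you want to check the preprocessing outside TensorFlow, here is a NumPy sketch of the two preprocessing ops used above. `central_crop_or_pad` and `per_image_whitening` are hypothetical re-implementations: the first center-crops or zero-pads like `tf.image.resize_image_with_crop_or_pad`, the second applies the documented whitening formula `(x - mean) / max(stddev, 1/sqrt(num_elements))` (later TensorFlow versions call this op `tf.image.per_image_standardization`). The 24x24 target size is an assumption matching the crop the CIFAR-10 tutorial evaluates on; use whatever size your model was trained with. The input is random data, not a real photo.

```python
import numpy as np


def central_crop_or_pad(image, target_height, target_width):
    """Center-crop and/or zero-pad an (H, W, C) array to the target size."""
    h, w, c = image.shape
    # Crop: take the central region where the image is larger than the target.
    top = max((h - target_height) // 2, 0)
    left = max((w - target_width) // 2, 0)
    cropped = image[top:top + target_height, left:left + target_width, :]
    # Pad: center the (possibly cropped) image on a zero canvas.
    ch, cw = cropped.shape[:2]
    out = np.zeros((target_height, target_width, c), dtype=image.dtype)
    pad_top = (target_height - ch) // 2
    pad_left = (target_width - cw) // 2
    out[pad_top:pad_top + ch, pad_left:pad_left + cw, :] = cropped
    return out


def per_image_whitening(image):
    """Scale one image to zero mean and (roughly) unit variance."""
    x = image.astype(np.float32)
    adjusted_std = max(x.std(), 1.0 / np.sqrt(x.size))  # avoids dividing by 0 on flat images
    return (x - x.mean()) / adjusted_std


# A stand-in 32x32 RGB "image" with random pixel values.
rng = np.random.default_rng(0)
raw = rng.integers(0, 256, size=(32, 32, 3)).astype(np.float32)
cropped = central_crop_or_pad(raw, 24, 24)
whitened = per_image_whitening(cropped)
batch = whitened[np.newaxis, ...]  # fake batch of one, like tf.expand_dims(..., 0)
print(batch.shape)  # (1, 24, 24, 3)
```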
