Demo entry 6641495

my_model_detection.py

   

Submitted by anonymous on Sep 18, 2017 at 17:18
Language: Python 3. Code size: 6.1 kB.

import numpy as np
import os
import sys
import tensorflow as tf
from matplotlib import pyplot as plt
from PIL import Image
from utils import label_map_util
from utils import visualization_utils as vis_util

def load_image_into_numpy_array(image):
    """Convert a PIL image into a ``(height, width, 3)`` ``uint8`` array.

    Images whose mode is not ``'RGB'`` (e.g. grayscale ``'L'``, ``'RGBA'``,
    palette ``'P'``) are converted to RGB first, because their channel count
    differs from 3 and the reshape below would otherwise fail.

    Args:
        image: a PIL ``Image`` (anything exposing ``mode``, ``size``,
            ``getdata()`` and ``convert()``).

    Returns:
        numpy.ndarray of shape ``(height, width, 3)`` and dtype ``uint8``.
    """
    if image.mode != 'RGB':
        image = image.convert('RGB')
    im_width, im_height = image.size  # PIL reports size as (width, height)
    # Pixel values are already 0..255, so build the uint8 array directly
    # instead of going through a default int64 array and casting afterwards.
    return np.array(image.getdata(), dtype=np.uint8).reshape(
        (im_height, im_width, 3))

class CONFIG_bisai:
    """Settings for the custom ``my_detection`` frozen detection graph."""

    # Label map and the maximum number of classes read from it.
    PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt')
    NUM_CLASSES = 90

    # Model directory name; the archive and graph paths below derive from it.
    MODEL_NAME = 'my_detection'
    MODEL_FILE = MODEL_NAME + '.tar.gz'
    PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb'

    # Class IDs passed to the visualiser. The original comment described them
    # as "people and cars" (presumably COCO IDs; verify against the label map).
    SHOW_CLASSES = list(range(1, 4))

class CONFIG_resnet:
    """Settings for the ``faster_rcnn_resnet`` frozen detection graph."""

    # Label map and the maximum number of classes read from it.
    PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt')
    NUM_CLASSES = 90

    # Model directory name; the archive and graph paths below derive from it.
    MODEL_NAME = 'faster_rcnn_resnet'
    MODEL_FILE = MODEL_NAME + '.tar.gz'
    PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb'

    # Class IDs 1 through 8 are passed to the visualiser.
    SHOW_CLASSES = list(range(1, 9))

class TheModel():
    """Wrap a frozen TensorFlow object-detection graph and its label map.

    Typical use: construct with a config class (e.g. ``CONFIG_resnet``), call
    ``inference()`` once to load the frozen graph, then call ``predict``,
    ``detect`` or ``detect_and_predict`` per image. Call ``close()`` (or use
    the instance as a context manager) to release the TensorFlow session.

    The config object must provide ``PATH_TO_CKPT``, ``PATH_TO_LABELS``,
    ``NUM_CLASSES`` and ``SHOW_CLASSES``.
    """

    def __init__(self, config):
        """Create a private graph and session, then load the label map.

        Args:
            config: configuration class/object (see class docstring).
        """
        self.config = config
        # NOTE(review): ``utils`` is imported at module load time, before this
        # runs, so this path tweak cannot affect those imports. It is kept for
        # compatibility but no longer appended again on every instantiation.
        if ".." not in sys.path:
            sys.path.append("..")
        self.graph = tf.Graph()
        # The session is bound to this model's own graph.
        self.session = tf.Session(graph=self.graph)
        self.label_map = label_map_util.load_labelmap(self.config.PATH_TO_LABELS)
        self.categories = label_map_util.convert_label_map_to_categories(
            self.label_map,
            max_num_classes=self.config.NUM_CLASSES,
            use_display_name=True)
        self.category_index = label_map_util.create_category_index(self.categories)

    def inference(self):
        """Load the frozen graph at ``config.PATH_TO_CKPT`` into ``self.graph``.

        Despite its name this runs no inference. It must be called before the
        prediction methods, which look up tensors that only exist once the
        graph has been imported.
        """
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(self.config.PATH_TO_CKPT, 'rb') as fid:
            od_graph_def.ParseFromString(fid.read())
        with self.graph.as_default():
            # name='' imports without a name prefix, so tensors keep names like
            # 'image_tensor:0' that _run_detection looks up.
            tf.import_graph_def(od_graph_def, name='')

    def _run_detection(self, image_np):
        """Run the detector on a single image.

        Args:
            image_np: image array; a batch axis is prepended before it is fed
                to ``image_tensor:0``. It is not modified.

        Returns:
            ``(boxes, scores, classes)`` for that image, with the batch axis
            removed and ``classes`` cast to ``np.int32``.
        """
        with self.graph.as_default():
            image_tensor = self.graph.get_tensor_by_name('image_tensor:0')
            fetches = [self.graph.get_tensor_by_name(name) for name in (
                'detection_boxes:0', 'detection_scores:0', 'detection_classes:0')]
            boxes, scores, classes = self.session.run(
                fetches,
                feed_dict={image_tensor: np.expand_dims(image_np, axis=0)})
        # Index the batch axis rather than np.squeeze(). Squeeze would also
        # collapse the per-box axis when only one detection is returned.
        return boxes[0], scores[0], classes[0].astype(np.int32)

    @staticmethod
    def _any_detected(classes, scores, CLASSES, thresh_hold):
        """Return True if some box scores above ``thresh_hold`` with a class in ``CLASSES``."""
        wanted = set(CLASSES)
        return any(int(cls) in wanted
                   for cls, score in zip(classes, scores)
                   if score > thresh_hold)

    def _draw(self, image_np, boxes, classes, scores, min_score_thresh):
        """Draw boxes for ``config.SHOW_CLASSES`` onto ``image_np``.

        The vis_util return value is ignored, so drawing is relied on to
        happen in place. ``SHOW_CLASSES`` is not a stock vis_util argument;
        presumably this is a locally modified ``utils`` package.
        """
        vis_util.visualize_boxes_and_labels_on_image_array(
            image_np,
            boxes,
            classes,
            scores,
            self.category_index,
            use_normalized_coordinates=True,
            min_score_thresh=min_score_thresh,
            line_thickness=1,
            SHOW_CLASSES=self.config.SHOW_CLASSES)

    def predict(self, image_np, min_score_thresh=0.5):
        """Detect objects and draw them onto ``image_np``.

        Args:
            image_np: image array; boxes are drawn onto it.
            min_score_thresh: minimum score for a box to be drawn.

        Returns:
            ``image_np`` itself (the same array object).
        """
        boxes, scores, classes = self._run_detection(image_np)
        self._draw(image_np, boxes, classes, scores, min_score_thresh)
        return image_np

    def detect(self, image_np, CLASSES, thresh_hold):
        """Report whether any class in ``CLASSES`` is detected.

        Args:
            image_np: image array; not modified.
            CLASSES: collection of class IDs to look for.
            thresh_hold: a box counts only if its score is strictly greater.

        Returns:
            True if at least one qualifying box was found, else False.
        """
        _, scores, classes = self._run_detection(image_np)
        return self._any_detected(classes, scores, CLASSES, thresh_hold)

    def detect_and_predict(self, image_np, CLASSES, thresh_hold,
                           min_score_thresh=0.5):
        """Run ``detect`` and ``predict`` with a single forward pass.

        ``thresh_hold`` governs the returned flag, while ``min_score_thresh``
        governs which boxes are drawn. The two are independent.

        Returns:
            ``(image_np, flag)``, where ``image_np`` has boxes drawn onto it.
        """
        boxes, scores, classes = self._run_detection(image_np)
        flag = self._any_detected(classes, scores, CLASSES, thresh_hold)
        self._draw(image_np, boxes, classes, scores, min_score_thresh)
        return image_np, flag

    def close(self):
        """Release the TensorFlow session held by this model."""
        self.session.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False  # never suppress exceptions

if __name__ == '__main__':
    # Only the ResNet model is used. Building a CONFIG_bisai model first just
    # created a session and loaded a label map that were thrown away.
    m = TheModel(CONFIG_resnet)
    m.inference()
    im = Image.open('video_data/test1.jpg')
    X = load_image_into_numpy_array(im)
    # predict() draws boxes onto its argument. Pass a copy so that detect()
    # below runs on the clean image rather than the annotated one.
    r = m.predict(X.copy())
    flag = m.detect(X, [1, 2], 0.6)
    print('classes [1, 2] detected:', flag)
    m.session.close()
    plt.imshow(r)
    plt.show()  # without this, a plain script exits before showing the figure

This snippet took 0.01 seconds to highlight.

Back to the Entry List or Home.

Delete this entry (admin only).