# 模型评估启动文件

Octopus平台自动标注模块将加载“customer_inference.py”中定义的类并实例化,然后依次执行该实例对象的load_model、infer_preprocess、infer_execution、infer_postprocess方法。用户必须定义一个唯一的自定义类,并导入父类“AutoLabelModel”,同时自定义类中必须定义这四种方法。启动文件中必须导入的部分对象请参考下表。

表 1 模型评估启动文件必选项

参数

说明

logger

用户必须导入logger对象以打印日志

StandardInferenceClass

模型评估的类名。必须导入StandardInferenceClass类,并作为自定义类的唯一父类。

load_model (self, inference_rootdir)

用于加载模型的方法,必须在自定义类中定义。该方法形参必须唯一且参数名为inference_rootdir,且该变量为根目录的绝对路径,仅在推理图片前调用该方法1次。

infer_preprocess(self, inference_image)

用于前处理的方法,必须在自定义类中定义。该方法形参必须唯一且参数名为inference_image,推理每张图片时均会调用1次该方法。

infer_execution (self, inference_input)

用于推理图片的方法,必须在自定义类中定义。该方法形参必须唯一且参数名为inference_input,推理每张图片时均会调用1次该方法。

infer_postprocess(self, inference_output)

用于后处理的方法,必须在自定义类中定义。该方法形参必须唯一且参数名为inference_output,推理每张图片时均会调用1次该方法。

label_meta_dict

用于获取meta信息(标注名、形状和颜色等),并将推理结果与评估数据集中的类别信息匹配。可以通过“label_meta_dict.key()”读取注册该模型时用户填入的类别信息(“训练服务”->“模型管理”->“添加模型”->“模型评估”);

upload_results

用于获取单张图片的结果(该变量类型为list),upload_results中每个元素对应一个bounding box结果(元素类型为dict)。

# 启动文件编写实例

若要将自定义模型用于模型评估,启动文件编写请参考如下。

# -*- coding: utf-8 -*-
"""
用户自定义启动脚本,用于模型评估
"""
import os
import tensorflow as tf
from tensorflow.python.util import compat
from tensorflow.python.platform import gfile
from tensorflow.core.protobuf import saved_model_pb2

import customer_utils
from customer_package import tensor_define

from conf.common.global_logger import logger
# 必选。用户必须导入该对象(logger)以打印日志
from evaluate.octopus_evaluate_app.services.standard_inference_class import StandardInferenceClass
# 必选。用户必须导入该类,作为下面自定义类的唯一父类


class OctopusAutolabel(StandardInferenceClass): #必选。父类StandardInferenceClass
    """
    用户必须定义一个唯一的自定类,并集成上述父类。自定义类的名称任意。
    自定义类中禁止定义__init__(self)方法,以避免冲突;
    自定义类中必须定义load_model(self, inference_rootdir)、infer_preprocess(self, inference_image)、infer_execution(self, inference_input)、infer_postprocess(self, inference_output)方法;
    允许用户在自定义类中定义其他方法;(建议使用)
    允许用户使用self对象在自定义类的不同方法间共享对象;(建议使用)
    """
    
    def load_graph(self, root_dir):
        op_file = os.path.join(root_dir, "sub_directory/customer_operation.so")
        logger.info("loading custom operation: {}".format(op_file))
        tf.load_op_library(op_file)
        pb_file_path = os.path.join(root_dir, 'sub_directory/customer_model.pb')
        logger.info("loading graph from {}".format(pb_file_path))
        with gfile.GFile(pb_file_path, "rb") as file_reader:   # Tensorflow解析pb文件
            data = compat.as_bytes(file_reader.read())
            sm = saved_model_pb2.SavedModel()
            sm.ParseFromString(data)
        with tf.Graph().as_default() as graph:   # Tensorflow加载运算图
            tf.import_graph_def(sm.meta_graphs[0].graph_def)
        return graph

    def load_model(self, inference_rootdir):  #必选
        logger.info("IN custom_load_model")
        root_dir = os.path.dirname(__file__) # 既可以通过os.path.dirname(__file__)获取根目录路径(建议使用)
        root_dir = inference_rootdir # 也可以直接从inference_rootdir获取根目录路径(建议使用)
        graph = self.load_saved_model(root_dir)
        self.sess = tf.Session(graph=graph) # 创建Tensorflow的Session对象
        self.input_tensor = self.sess.graph.get_tensor_by_name("input_tensor:0") # 利用self对象实现在不同方法间共享变量(建议使用)
        self.result_tensors = [self.sess.graph.get_tensor_by_name(output) for output in tensor_define.get_output_name(self.model_name)]
        logger.info("OUT custom_load_model")

    def infer_preprocess(self, inference_image): #必选。形参inference_image必选
        """
        :param inference_image:图片路径
        :return inference_input:该变量将传递给infer_execution
        """
        logger.info("IN custom_pre_auto_label")
        image_data = customer_utils.load_data(inference_image)
        logger.info("OUT custom_pre_auto_label")
        return inference_input

    def infer_execution(self, inference_input): #必选。形参inference_input必选
        """
        :param inference_input:由infer_preprocess传入
        :return inference_output:该变量将传递给infer_postprocess
        """
        logger.info("IN custom_auto_label")
        label_data = dict()
        res = self.sess.run(self.result_tensors, feed_dict={self.input_tensor: inference_input})
        label_data['res'] = res
        logger.info("OUT custom_auto_label")
        return inference_output

    def infer_postprocess(self, inference_output): #必选。形参inference_output必选
        """
        :param inference_output:由infer_execution传入
        :return upload_results:该变量将上传至数据库并在前端展示
        """
        logger.info("IN custom_post_auto_label")
        res = customer_utils.post_process(inference_output['res'])
        label_meta_dict = json.loads(os.getenv('label_meta_dict'))  # 必选。获取meta信息,用于将推理结果与评估数据集中的类别信息匹配
        wanted_classes = label_meta_dict.keys() # 必选。该变量为注册该模型时用户填入的类别信息(“训练服务”->“模型管理”->“添加模型”->“模型评估”)

        # upload_results必选。是单张图片的结果(该变量类型为list),upload_results中每个元素对应一个bounding box结果(元素类型为dict)
        # 每个元素格式为:
        # {
        #      'label_meta_id': label_meta_dict.get(label_name).get('label_meta_id'),
        #      'score': float,
        #      'bndbox': {
        #                    'xmin': float/int,
        #                    'ymin': float/int,
        #                    'xmax': float/int,
        #                    'ymax': float/int,
        #                },
        #      'shape_type': label_meta_dict.get(label_name).get('shape_type'),
        #      'difficult': boolean,
        #      'occluded': int/None,
        #      'truncated': int/None,
        # }

        upload_results = list()   # 必选
        for r in res:
            label_name = r.get('label')
            if label_name in wanted_classes:   # 在推理结果中,过滤掉那些没有注册的非法类别对应的bounding boxes结果
                 upload_results.append({
                    'label_meta_id': label_meta_dict.get(label_name).get('label_meta_id'),
                    'score': r.get('score'),
                    'bndbox': r.get("bndbox"),
                    'shape_type': label_meta_dict.get(label_name).get('shape_type'),
                    'difficult': False,
                    'occluded': None,
                    'truncated': None
                })
        return upload_results

说明:

  • 为了能够保证模型评估正确读取推理结果,需要从环境变量获取模型注册的合法标签信息,并且推理结果需要严格遵守infer_postprocess中定义的格式;
  • 不推荐在自定义启动脚本中执行文件读写操作。
上次更新: 5/25/2021, 10:43:12 AM