# 模型评估启动文件
Octopus平台自动标注模块将加载“customer_inference.py”中定义的类并实例化,然后依次执行该实例对象的load_model、infer_preprocess、infer_execution、infer_postprocess方法。用户必须定义一个唯一的自定义类,并导入父类“AutoLabelModel”,同时自定义类中必须定义这四种方法。启动文件中必须导入的部分对象请参考下表。
表 1 模型评估启动文件必选项
# 启动文件编写实例
若要将自定义模型用于模型评估,启动文件编写请参考如下。
# -*- coding: utf-8 -*-
"""
用户自定义启动脚本,用于模型评估
"""
import os
import tensorflow as tf
from tensorflow.python.util import compat
from tensorflow.python.platform import gfile
from tensorflow.core.protobuf import saved_model_pb2
import customer_utils
from customer_package import tensor_define
from conf.common.global_logger import logger
# 必选。用户必须导入该对象(logger)以打印日志
from evaluate.octopus_evaluate_app.services.standard_inference_class import StandardInferenceClass
# 必选。用户必须导入该类,作为下面自定义类的唯一父类
class OctopusAutolabel(StandardInferenceClass): #必选。父类StandardInferenceClass
"""
用户必须定义一个唯一的自定类,并集成上述父类。自定义类的名称任意。
自定义类中禁止定义__init__(self)方法,以避免冲突;
自定义类中必须定义load_model(self, inference_rootdir)、infer_preprocess(self, inference_image)、infer_execution(self, inference_input)、infer_postprocess(self, inference_output)方法;
允许用户在自定义类中定义其他方法;(建议使用)
允许用户使用self对象在自定义类的不同方法间共享对象;(建议使用)
"""
def load_graph(self, root_dir):
op_file = os.path.join(root_dir, "sub_directory/customer_operation.so")
logger.info("loading custom operation: {}".format(op_file))
tf.load_op_library(op_file)
pb_file_path = os.path.join(root_dir, 'sub_directory/customer_model.pb')
logger.info("loading graph from {}".format(pb_file_path))
with gfile.GFile(pb_file_path, "rb") as file_reader: # Tensorflow解析pb文件
data = compat.as_bytes(file_reader.read())
sm = saved_model_pb2.SavedModel()
sm.ParseFromString(data)
with tf.Graph().as_default() as graph: # Tensorflow加载运算图
tf.import_graph_def(sm.meta_graphs[0].graph_def)
return graph
def load_model(self, inference_rootdir): #必选
logger.info("IN custom_load_model")
root_dir = os.path.dirname(__file__) # 既可以通过os.path.dirname(__file__)获取根目录路径(建议使用)
root_dir = inference_rootdir # 也可以直接从inference_rootdir获取根目录路径(建议使用)
graph = self.load_saved_model(root_dir)
self.sess = tf.Session(graph=graph) # 创建Tensorflow的Session对象
self.input_tensor = self.sess.graph.get_tensor_by_name("input_tensor:0") # 利用self对象实现在不同方法间共享变量(建议使用)
self.result_tensors = [self.sess.graph.get_tensor_by_name(output) for output in tensor_define.get_output_name(self.model_name)]
logger.info("OUT custom_load_model")
def infer_preprocess(self, inference_image): #必选。形参inference_image必选
"""
:param inference_image:图片路径
:return inference_input:该变量将传递给infer_execution
"""
logger.info("IN custom_pre_auto_label")
image_data = customer_utils.load_data(inference_image)
logger.info("OUT custom_pre_auto_label")
return inference_input
def infer_execution(self, inference_input): #必选。形参inference_input必选
"""
:param inference_input:由infer_preprocess传入
:return inference_output:该变量将传递给infer_postprocess
"""
logger.info("IN custom_auto_label")
label_data = dict()
res = self.sess.run(self.result_tensors, feed_dict={self.input_tensor: inference_input})
label_data['res'] = res
logger.info("OUT custom_auto_label")
return inference_output
def infer_postprocess(self, inference_output): #必选。形参inference_output必选
"""
:param inference_output:由infer_execution传入
:return upload_results:该变量将上传至数据库并在前端展示
"""
logger.info("IN custom_post_auto_label")
res = customer_utils.post_process(inference_output['res'])
label_meta_dict = json.loads(os.getenv('label_meta_dict')) # 必选。获取meta信息,用于将推理结果与评估数据集中的类别信息匹配
wanted_classes = label_meta_dict.keys() # 必选。该变量为注册该模型时用户填入的类别信息(“训练服务”->“模型管理”->“添加模型”->“模型评估”)
# upload_results必选。是单张图片的结果(该变量类型为list),upload_results中每个元素对应一个bounding box结果(元素类型为dict)
# 每个元素格式为:
# {
# 'label_meta_id': label_meta_dict.get(label_name).get('label_meta_id'),
# 'score': float,
# 'bndbox': {
# 'xmin': float/int,
# 'ymin': float/int,
# 'xmax': float/int,
# 'ymax': float/int,
# },
# 'shape_type': label_meta_dict.get(label_name).get('shape_type'),
# 'difficult': boolean,
# 'occluded': int/None,
# 'truncated': int/None,
# }
upload_results = list() # 必选
for r in res:
label_name = r.get('label')
if label_name in wanted_classes: # 在推理结果中,过滤掉那些没有注册的非法类别对应的bounding boxes结果
upload_results.append({
'label_meta_id': label_meta_dict.get(label_name).get('label_meta_id'),
'score': r.get('score'),
'bndbox': r.get("bndbox"),
'shape_type': label_meta_dict.get(label_name).get('shape_type'),
'difficult': False,
'occluded': None,
'truncated': None
})
return upload_results
说明:
- 为了能够保证模型评估正确读取推理结果,需要从环境变量获取模型注册的合法标签信息,并且推理结果需要严格遵守infer_postprocess中定义的格式;
- 不推荐在自定义启动脚本中执行文件读写操作。