# 自动标注启动文件
Octopus平台自动标注模块将加载“customer_auto_label.py”中定义的类并实例化,然后依次执行该实例对象的load_model、pre_auto_label、auto_label、post_auto_label方法。用户必须定义一个唯一的自定义类,并导入父类“AutoLabelModel”,同时自定义类中必须定义这四种方法。启动文件中必须导入的部分对象请参考下表。
表 1 自动标注启动文件必选项
# 启动文件编写实例
若要将自定义模型用于自动标注,启动文件编写请参考如下。
# -*- coding: utf-8 -*-
"""
用户自定义启动脚本,用于自动标注
"""
import os
import tensorflow as tf
from tensorflow.python.util import compat
from tensorflow.python.platform import gfile
from tensorflow.core.protobuf import saved_model_pb2
import customer_utils
from customer_package import tensor_define
from conf.common.global_logger import logger
# 必选。用户必须导入logger对象以打印日志
from label.octopus_label_app.services.auto_label_model import AutoLabelModel
# 必选。用户必须导入AutoLabelModel类,作为下面自定义类的唯一父类
class OctopusAutolabel(AutoLabelModel): #必选。父类AutoLabelModel
"""
用户必须定义一个唯一的自定类,并集成上述父类。自定义类的名称任意。
自定义类中禁止定义__init__(self)方法,以避免冲突;
自定义类中必须定义load_model(self)、pre_auto_label(self, sample_path)、auto_label(self, sample_data)、post_auto_label(self, label_data)方法;
允许用户在自定义类中定义其他方法;(建议使用)
允许用户使用self对象在自定义类的不同方法间共享对象;(建议使用)
"""
def load_graph(self, root_dir):
op_file = os.path.join(root_dir, "sub_directory/customer_operation.so")
logger.info("loading custom operation: {}".format(op_file))
tf.load_op_library(op_file)
pb_file_path = os.path.join(root_dir, 'sub_directory/customer_model.pb')
logger.info("loading graph from {}".format(pb_file_path))
with gfile.GFile(pb_file_path, "rb") as file_reader: # Tensorflow解析pb文件
data = compat.as_bytes(file_reader.read())
sm = saved_model_pb2.SavedModel()
sm.ParseFromString(data)
with tf.Graph().as_default() as graph: # Tensorflow加载运算图
tf.import_graph_def(sm.meta_graphs[0].graph_def)
return graph
def load_model(self):
logger.info("IN custom_load_model") # 使用logger对象打印日志(建议使用)
root_dir = os.path.dirname(__file__) # 通过os.path.dirname(__file__)获取根目录路径(建议使用)
graph = self.load_saved_model(root_dir)
self.sess = tf.Session(graph=graph) # 创建Tensorflow的Session对象
self.input_tensor = self.sess.graph.get_tensor_by_name("input_tensor:0") # 利用self对象实现在不同方法间共享变量(建议使用)
self.result_tensors = [self.sess.graph.get_tensor_by_name(output) for output in tensor_define.get_output_name(self.model_name)]
logger.info("OUT custom_load_model")
def pre_auto_label(self, sample_path): #必选。形参sample_path
"""
:param sample_path:图片路径
:return sample_data:该变量将传递给auto_label
"""
logger.info("IN custom_pre_auto_label")
image_data = customer_utils.load_data(sample_path)
logger.info("OUT custom_pre_auto_label")
return sample_data
def auto_label(self, sample_data): #必选。形参sample_data
"""
:param sample_data:由pre_auto_label传入
:return label_data:该变量将传递给post_auto_label
"""
logger.info("IN custom_auto_label")
label_data = dict()
res = self.sess.run(self.result_tensors, feed_dict={self.input_tensor: sample_data})
label_data['res'] = res
logger.info("OUT custom_auto_label")
return label_data
def post_auto_label(self, label_data): #必选。形参label_data
"""
:param label_data:由auto_label传入
:return upload_results:该变量将上传至数据库并在前端展示
"""
logger.info("IN custom_post_auto_label")
res = customer_utils.post_process(label_data['res'])
label_meta_dict = json.loads(os.getenv('label_meta_dict')) # 获取meta信息,用于将结果展示在前端
wanted_classes = label_meta_dict.keys() # 该变量为注册该模型时用户填入的类别信息(“训练服务”->“模型管理”->“添加模型”->“标注”)
# upload_results必选。是单张图片的结果(该变量类型为list),upload_results中每个元素对应一个bounding box结果(元素类型为dict)
# 每个元素格式为:
# {
# 'label_meta_id': label_meta_dict.get(label_name).get('label_meta_id'),
# 'score': float,
# 'bndbox': {
# 'xmin': float/int,
# 'ymin': float/int,
# 'xmax': float/int,
# 'ymax': float/int,
# },
# 'shape_type': label_meta_dict.get(label_name).get('shape_type'),
# 'difficult': boolean,
# 'occluded': int/None,
# 'truncated': int/None,
# }
upload_results = list() #必选
for r in res:
label_name = r.get('label')
if label_name in wanted_classes: # 在推理结果中,过滤掉那些没有注册的非法类别对应的bounding boxes结果
upload_results.append({
'label_meta_id': label_meta_dict.get(label_name).get('label_meta_id'),
'score': r.get('score'),
'bndbox': r.get("bndbox"),
'shape_type': label_meta_dict.get(label_name).get('shape_type'),
'difficult': False,
'occluded': None,
'truncated': None
})
return upload_results
说明:
- 为了能够在前端展示推理结果,需要从环境变量获取模型注册的合法标签信息,并且推理结果需要严格遵守post_auto_label中定义的格式;
- 不推荐在自定义启动脚本中执行文件读写操作。