# 自动标注启动文件

Octopus平台自动标注模块将加载“customer_auto_label.py”中定义的类并实例化,然后依次执行该实例对象的load_model、pre_auto_label、auto_label、post_auto_label方法。用户必须定义一个唯一的自定义类,并导入父类“AutoLabelModel”,同时自定义类中必须定义这四种方法。启动文件中必须导入的部分对象请参考下表。

表 1 自动标注启动文件必选项

参数

说明

logger

用户必须导入logger对象以打印日志

AutoLabelModel

自动标注模型的类名。必须导入AutoLabelModel类,并作为自定义类的唯一父类。

load_model (self)

用于加载模型的方法,必须在自定义类中定义,且仅在推理图片前调用该方法1次。

pre_auto_label (self, sample_path)

用于前处理的方法,必须在自定义类中定义。该方法形参必须唯一且参数名为sample_path,推理每张图片时均会调用1次该方法。

auto_label (self, sample_path)

用于推理图片的方法,必须在自定义类中定义。该方法形参必须唯一且参数名为sample_data,推理每张图片时均会调用1次该方法。

post_auto_label (self, label_data)

用于后处理的方法,必须在自定义类中定义。该方法形参必须唯一且参数名为label_data,推理每张图片时均会调用1次该方法。

label_meta_dict

用于获取meta信息(标注名、形状和颜色等),并将结果展示在前端。可以通过“label_meta_dict.key()”读取注册该模型时用户填入的类别信息(“训练服务”->“模型管理”->“添加模型”->“标注”)。

upload_results

用于获取单张图片的结果(该变量类型为list),upload_results中每个元素对应一个bounding box结果(元素类型为dict)。

# 启动文件编写实例

若要将自定义模型用于自动标注,启动文件编写请参考如下。

# -*- coding: utf-8 -*-
"""
用户自定义启动脚本,用于自动标注
"""
import os
import tensorflow as tf
from tensorflow.python.util import compat
from tensorflow.python.platform import gfile
from tensorflow.core.protobuf import saved_model_pb2

import customer_utils
from customer_package import tensor_define

from conf.common.global_logger import logger
# 必选。用户必须导入logger对象以打印日志
from label.octopus_label_app.services.auto_label_model import AutoLabelModel
# 必选。用户必须导入AutoLabelModel类,作为下面自定义类的唯一父类

class OctopusAutolabel(AutoLabelModel): #必选。父类AutoLabelModel
    """
    用户必须定义一个唯一的自定类,并集成上述父类。自定义类的名称任意。
    自定义类中禁止定义__init__(self)方法,以避免冲突;
    自定义类中必须定义load_model(self)、pre_auto_label(self, sample_path)、auto_label(self, sample_data)、post_auto_label(self, label_data)方法;
    允许用户在自定义类中定义其他方法;(建议使用)
    允许用户使用self对象在自定义类的不同方法间共享对象;(建议使用)
    """
    def load_graph(self, root_dir):
        op_file = os.path.join(root_dir, "sub_directory/customer_operation.so")
        logger.info("loading custom operation: {}".format(op_file))
        tf.load_op_library(op_file)
        pb_file_path = os.path.join(root_dir, 'sub_directory/customer_model.pb')
        logger.info("loading graph from {}".format(pb_file_path))
        with gfile.GFile(pb_file_path, "rb") as file_reader: # Tensorflow解析pb文件
            data = compat.as_bytes(file_reader.read())
            sm = saved_model_pb2.SavedModel()
            sm.ParseFromString(data)
        with tf.Graph().as_default() as graph: # Tensorflow加载运算图
            tf.import_graph_def(sm.meta_graphs[0].graph_def)
        return graph

    def load_model(self):
        logger.info("IN custom_load_model") # 使用logger对象打印日志(建议使用)
        root_dir = os.path.dirname(__file__) # 通过os.path.dirname(__file__)获取根目录路径(建议使用)
        graph = self.load_saved_model(root_dir)
        self.sess = tf.Session(graph=graph) # 创建Tensorflow的Session对象
        self.input_tensor = self.sess.graph.get_tensor_by_name("input_tensor:0") # 利用self对象实现在不同方法间共享变量(建议使用)
        self.result_tensors = [self.sess.graph.get_tensor_by_name(output) for output in tensor_define.get_output_name(self.model_name)]
        logger.info("OUT custom_load_model")

    def pre_auto_label(self, sample_path):   #必选。形参sample_path
        """
        :param sample_path:图片路径
        :return sample_data:该变量将传递给auto_label
        """
        logger.info("IN custom_pre_auto_label")
        image_data = customer_utils.load_data(sample_path)
        logger.info("OUT custom_pre_auto_label")
        return sample_data

    def auto_label(self, sample_data):   #必选。形参sample_data
        """
        :param sample_data:由pre_auto_label传入
        :return label_data:该变量将传递给post_auto_label
        """
        logger.info("IN custom_auto_label")
        label_data = dict()
        res = self.sess.run(self.result_tensors, feed_dict={self.input_tensor: sample_data})
        label_data['res'] = res
        logger.info("OUT custom_auto_label")
        return label_data

    def post_auto_label(self, label_data): #必选。形参label_data
        """
        :param label_data:由auto_label传入
        :return upload_results:该变量将上传至数据库并在前端展示
        """
        logger.info("IN custom_post_auto_label")
        res = customer_utils.post_process(label_data['res'])
        label_meta_dict = json.loads(os.getenv('label_meta_dict')) # 获取meta信息,用于将结果展示在前端 
        wanted_classes = label_meta_dict.keys() # 该变量为注册该模型时用户填入的类别信息(“训练服务”->“模型管理”->“添加模型”->“标注”)
  
        # upload_results必选。是单张图片的结果(该变量类型为list),upload_results中每个元素对应一个bounding box结果(元素类型为dict)
        # 每个元素格式为:
        # {
        #      'label_meta_id': label_meta_dict.get(label_name).get('label_meta_id'),
        #      'score': float,
        #      'bndbox': {
        #                    'xmin': float/int,
        #                    'ymin': float/int,
        #                    'xmax': float/int,
        #                    'ymax': float/int,
        #                },
        #      'shape_type': label_meta_dict.get(label_name).get('shape_type'),
        #      'difficult': boolean,
        #      'occluded': int/None,
        #      'truncated': int/None,
        # }
        upload_results = list()      #必选
        for r in res:
            label_name = r.get('label')
            if label_name in wanted_classes: # 在推理结果中,过滤掉那些没有注册的非法类别对应的bounding boxes结果
                 upload_results.append({
                    'label_meta_id': label_meta_dict.get(label_name).get('label_meta_id'),
                    'score': r.get('score'),
                    'bndbox': r.get("bndbox"),
                    'shape_type': label_meta_dict.get(label_name).get('shape_type'),
                    'difficult': False,
                    'occluded': None,
                    'truncated': None
                })
        return upload_results

说明:

  • 为了能够在前端展示推理结果,需要从环境变量获取模型注册的合法标签信息,并且推理结果需要严格遵守post_auto_label中定义的格式;
  • 不推荐在自定义启动脚本中执行文件读写操作。
上次更新: 5/25/2021, 10:43:12 AM