【创新实训 第一周】 CTPN 尝试复现 2019.3.23

本周开始创新实训,主要任务是熟悉 Keras 并尝试复现 CTPN 文本检测模型。完成了模型主体框架搭建,正在进行损失函数及数据集输入方法的设计。

本周工作进展

创新实训正式开始,第一周主要是熟悉 keras 的使用,并尝试复现文本检测模型 ctpn。

CTPN 论文地址

2019.3.24 昨天晚上继续做的时候发现我把模型理解错了,下面的内容如果有人看见的话就当我在胡说八道好了(/ω\*)……


详细工作内容

完成模型大题框架搭建,正在设计损失函数与数据集输入方法。

 

 如图,模型由以下部分组成:

  1. VGG16 网络,取到 conv5 的第三层;
  2. 在Conv5的feature map的每个位置上取3*3*C的窗口的特征,输入双向 LSTM;
  3. 全连接层;
  4. 最后输出三个分支,分别是 anchor 的纵坐标与高度,anchor 是否为文字的分数,以及边缘提纯的偏移量。
def ctpn_model():
    date = np.random.rand(1, 600, 900, 3).astype(np.float32)

    input_layer = vgg16_no_tail()(date)

    # unfold 不会弄,先用卷积代替
    x = keras.layers.Convolution2D(
        512, (3, 3),
        activation='relu',
        padding='same',
        name='cnn2rnn')(input_layer)

    # 下面双向 lstm 出了点问题,先将输出变形
    x = keras.layers.Reshape((x.shape[1], -1))(x)

    # 如果 shape0 是 batch,先假设 shape1 是 h
    x = keras.layers.Bidirectional(keras.layers.LSTM(128))(x)

    x = keras.layers.Dense(512)(x)

    vertical_coordinate = keras.layers.Dense(20)(x)
    score = keras.layers.Dense(20)(x)
    side_refinement = keras.layers.Dense(10)(x)


# vgg16 取 conv5 的第三层
def vgg16_no_tail():
    vgg = keras.applications.VGG16(weights=None)
    vgg_no_tail = keras.Model(
        inputs=vgg.input,
        outputs=vgg.get_layer("block5_conv3").output)
    return vgg_no_tail

下一步计划

  • 模型中 cnn 转向 rnn 的中间步骤需要提取 3*3 的矩阵并拼接,目前暂时以一个卷积层代替,视最终运行效果决定是否修改。
  • 完成损失函数和数据集输入方法的设计。

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值