(19)国家知识产权局
(12)发明 专利申请
(10)申请公布号
(43)申请公布日
(21)申请 号 202210240273.8
(22)申请日 2022.03.10
(71)申请人 南京邮电大 学
地址 210003 江苏省南京市 鼓楼区新模范
马路66号
(72)发明人 范保杰 吴育竹 蒋国平 徐丰羽
(74)专利代理 机构 南京纵横知识产权代理有限
公司 32224
专利代理师 刘艳艳
(51)Int.Cl.
G06T 7/246(2017.01)
G06N 3/04(2006.01)
G06N 3/08(2006.01)
G06N 20/00(2019.01)
G06V 10/25(2022.01)G06V 10/77(2022.01)
G06V 10/774(2022.01)
G06V 10/82(2022.01)
G06V 20/40(2022.01)
G06V 20/52(2022.01)
(54)发明名称
一种基于强化学习算法SAC的目标跟踪方
法、 装置及存 储介质
(57)摘要
本发明公开了一种基于强化学习算法SAC的
目标跟踪方法、 装置及存储介质, 方法包括: 获取
视频数据; 在视频数据的当前帧中确定搜索区域
位置和大小; 判断当前帧是否为第一帧; 响应于
当前帧非第一帧, 将当前帧输入预训练好的
actor网络模型进行特征提取, 得到输出的预测
框, 根据所述预测框对目标进行跟踪; 其中所述
actor网络模型的训练方法, 包括: 通过第一帧对
actor、 target_actor网络进行初始化, 根据经验
池中存储的数据, 通过actor、 critic网络计算动
作, 计算actor、 critic1、 critic2网络损失, 利用
强化学习SAC算法更新网络权值。 将目标跟踪问
题转化为强化学习算法中在 线决策的问题, 并且
本发明只需要少量数据集, 充分利用现有技术,
提升训练速度。
权利要求书3页 说明书8页 附图2页
CN 114897930 A
2022.08.12
CN 114897930 A
1.一种基于强化学习算法SAC的目标跟踪方法, 其特 征在于, 包括:
获取视频 数据;
在视频数据的当前帧中确定 搜索区域 位置和大小;
判断当前帧是否为第一帧;
响应于当前帧非第一帧, 获取当前帧的上一帧的预测结果, 根据当前帧的上一帧的预
测结果在当前帧裁取图像得到当前帧裁取图像s, 将当前帧裁取图像s输入预训练好的
actor网络模型进行 特征提取, 得到 输出的预测框;
根据所述预测框对目标进行跟踪, 并将预测框作为下一帧的groundTruth;
其中所述actor网络模型的训练方法, 包括:
在视频数据的第一帧中确定目标的大小和位置, 设置经验 池参数;
初始化actor、 target_actor,critic1、 target_critic1,critic2、 target_critic2网
络参数;
响应于当前输入帧为第一帧, 对actor、 target_actor网络进行初始化,
响应于当前帧非第一帧, 根据当前帧预测框在当前帧裁取图像, 得到预测框裁取图像
s', 并计算所述当前帧预测框与当前帧真实框 之间的IoU; 根据IoU, 通过奖励函数计算得到
奖励值;
将上一帧的预测结果在当前帧裁取图像s、 动作、 奖励值、 预测框裁取图像s'存入经验
池;
根据经验池中存储的数据, 通过actor、 critic网络计算动作at、 动作at在定义的分布
Normal(mu.std)中对应的概 率的对数l ogπφ(a|s);
根据计算得到的动作at、 动作at在定义的分布Normal(mu.std)中对应的概率的对数log
πφ(a|s), 计算actor、 critic1、 critic2网络损失, 利用强化学习SAC算法更新网络 权值。
2.根据权利要求1所述的基于强化学习算法SAC的目标跟踪方法, 其特征在于, 其中, 所
述设置经验池参数, 包括经验池容量X, 表 示可以存储 X条数据, 每一条数据符号为: (s,a,r,
s'), 其中s表示上一帧的预测结果在当前帧裁取图像、 a表示动作、 r表示奖励值、 s'表示当
前帧预测框在当前帧裁取图像。
3.根据权利要求1所述的基于强化学习算法SAC的目标跟踪方法, 其特征在于, 所述设
置经验池参数, 包括: 根据经验设置经验池容量为X, 表 示可以存储 X条数据符号为: (s,a,r,
s'), 一条数据包含: 在当前帧所裁取1*3*107*107维度图像, 1*3维度的动作, 1*1的奖励值,
当前帧图像采取动作后裁取的1* 3*107*107维度图像。
4.根据权利要求1所述的基于强化学习算法SAC的目标跟踪方法, 其特征在于, 所述初
始化actor、 target_actor,critic1、 target_critic1,critic2、 target_critic2网络参数,
包括: 加载在imageNet预训练好的vgg ‑M网络的前四层网络参数, 并以此作为图片特征提取
模型网络, 并将actor、 critic1,critic2网络参数分别赋值给target_actor、 target_
critic1,target_critic2网络参数。
5.根据权利要求1所述的基于强化学习算法SAC的目标跟踪方法, 其特征在于, 响应于
当前输入帧为第一帧, 对actor、 target _actor网络进行初始化, 包括: 若 此帧图片为该视频
序列第一帧, 最小化actor以及target_actor网络输出和标签 之间的误差, 损失函数表达式
为:权 利 要 求 书 1/3 页
2
CN 114897930 A
2其中μ(sm|φμ)为在groundTruth加入高斯噪声, 产生M个样本, 经由actor网络处理后输
出的预测动作, am为标签, 是M个样本与groundTruth的真实距离, μ是actor网络, m表示第m
个数据;
通过Adam优化器训练actor以及target_actor网络参数。
6.根据权利要求1所述的基于强化学习算法SAC的目标跟踪方法, 其特征在于, 根据
IoU, 通过奖励函数计算得到奖励值, 包括:
其中, r表示奖励值, b ’预测框、 G为真实框 。
7.根据权利要求1所述的基于强化学习算法SAC的目标跟踪方法, 其特征在于, 根据经
验池中存储的数据, 通过actor、 critic网络计算动作at、 动作at在定义的分布Normal
(mu.std)中对应的概率的对数logπφ(a|s), 包括: 重参数方式所得动作at=fφ( εt; st)=pi,
和动作at在定义的分布Normal(mu.std)中对应的概率的对数; logπφ(a|s)其中πφ表示
actor网络, 网络参数为φ;
at=fφ( εt; st)=pi使用重参数方式计算方法包括: 将从经验池中取出的s送入actor网
络计算得到mu、 std, 然后计算以均值为mu, 方差为std的高斯分布: pi_distribution=
Normal(mu,std), 然后通过重参数方式进行采样, 重参数采样方式为: 先从数学期望为0、 标
准方差为1的高斯正太分布(Normal(0,1))中采样ε, ε~N(0,1), 再经过线性变换操作得到
at: at=mu+std* ε;
动作at在定义的分布Normal(mu.std)中对应的概率的对数logπφ(a|s)根据公式logπφ
(a|s)=log(pi) ‑∑log(1‑tanh2(pi))计算而得; 其中pi为上述重参数方式所得动作at, 其
中log(pi)由高斯似然函数获得, 函数公式为
8.根据权利要求7所述的基于强化学习算法SAC的目标跟踪方法, 其特征在于, 计算
actor、 critic1、 critic2网络损失, 利用强化学习SAC算法更新网络 权值, 包括:
actor网络l oss函数定义 为:
其中, Q表示critic网络, θ表示critic网络参数, fφ( εt; st)为重参数方式所得动作, st
~D表示从经验池D中采样出st, εt~N表示从正太分分布中采样; logπφ(fφ( εt; st)|st)即为
上述logπφ(a|s), fφ( εt; st)即为上述a;
经求导后, Actor网络梯度计算 为:
critic网络l oss定义为:
权 利 要 求 书 2/3 页
3
CN 114897930 A
3
专利 一种基于强化学习算法SAC的目标跟踪方法、装置及存储介质
文档预览
中文文档
14 页
50 下载
1000 浏览
0 评论
309 收藏
3.0分
温馨提示:本文档共14页,可预览 3 页,如浏览全部内容或当前文档出现乱码,可开通会员下载原始文档
本文档由 人生无常 于 2024-03-18 12:01:44上传分享