图像匹配 | NCC 归一化互相关损失 | 代码 + 讲解

  • 文章转载自:微信公众号「机器学习炼丹术」
  • 作者:炼丹兄(已授权)
  • 作者联系方式:微信cyx645016617(欢迎交流共同进步)

本次的内容主要讲解NCCNormalized cross-correlation 归一化互相关。

两张图片是否是同一个内容,现在深度学习的方案自然是用神经网络,比方说:孪生网络的架构做人面识别等等;

在传统的非参数方法中,常见的也有相关系数等。我在上一片文章voxelmorph的模型的学习中发现,在医学图像配准任务(不限于医学),衡量两个图片相似的度量有一种叫做NCC的

而这个NCC就是Normalized Cross-Correlation归一化互相关系数。

1 互相关系数

如果你知道互相关系数,那么你就能很好的理解归一化互相关系数。

相关系数的计算公式如下:

\[r(X,Y) = \frac{Cov(X,Y)}{\sqrt{Var(X)Var(Y)}} \]

公式中的X,Y分别表示两个图片,\(Cov(X,Y)\)表示两个图片的协方差,\(Var(X)\)表示X自身的方差;

2 归一化互相关NCC

如果把一张图片,按照一定的像素,比方说9×9的一个框滑动,那么就可以把图片分成很多的9×9的小图片,那么NCC就是X,Y两张大图片中的对应的小图片的互相关系数的平均值。

这里看一下协方差的计算方式:
\(Cov(X,Y) = E[(X-E(X))(Y-E(Y))]\)

方差的计算为:
\(Var(X) = E[(X-E(X))^2]\)

其实NCC不难理解,但是如何用代码计算呢?当然我们可以一行一行遍历求解,但是这样时间复杂度过高,所以我们做好还是选择矩阵运算。

3 NCC损失函数的代码

class NCC:
    """
    Local (over window) normalized cross correlation loss.
    """

    def __init__(self, win=None):
        self.win = win

    def loss(self, y_true, y_pred):

        I = y_true
        J = y_pred

        # get dimension of volume
        # assumes I, J are sized [batch_size, *vol_shape, nb_feats]
        ndims = len(list(I.size())) - 2
        assert ndims in [1, 2, 3], "volumes should be 1 to 3 dimensions. found: %d" % ndims

        # set window size
        win = [9] * ndims if self.win is None else self.win

        # compute filters
        sum_filt = torch.ones([1, 1, *win]).to("cuda")

        pad_no = math.floor(win[0]/2)

        if ndims == 1:
            stride = (1)
            padding = (pad_no)
        elif ndims == 2:
            stride = (1,1)
            padding = (pad_no, pad_no)
        else:
            stride = (1,1,1)
            padding = (pad_no, pad_no, pad_no)

        # get convolution function
        conv_fn = getattr(F, 'conv%dd' % ndims)

        # compute CC squares
        I2 = I * I
        J2 = J * J
        IJ = I * J

        I_sum = conv_fn(I, sum_filt, stride=stride, padding=padding)
        J_sum = conv_fn(J, sum_filt, stride=stride, padding=padding)
        I2_sum = conv_fn(I2, sum_filt, stride=stride, padding=padding)
        J2_sum = conv_fn(J2, sum_filt, stride=stride, padding=padding)
        IJ_sum = conv_fn(IJ, sum_filt, stride=stride, padding=padding)

        win_size = np.prod(win)
        u_I = I_sum / win_size
        u_J = J_sum / win_size

        cross = IJ_sum - u_J * I_sum - u_I * J_sum + u_I * u_J * win_size
        I_var = I2_sum - 2 * u_I * I_sum + u_I * u_I * win_size
        J_var = J2_sum - 2 * u_J * J_sum + u_J * u_J * win_size

        cc = cross * cross / (I_var * J_var + 1e-5)

        return -torch.mean(cc)

这段代码其实不是很好看懂,我思考了很久才明白。其中的关键就在于如何理解:

# compute CC squares
        I2 = I * I
        J2 = J * J
        IJ = I * J

        I_sum = conv_fn(I, sum_filt, stride=stride, padding=padding)
        J_sum = conv_fn(J, sum_filt, stride=stride, padding=padding)
        I2_sum = conv_fn(I2, sum_filt, stride=stride, padding=padding)
        J2_sum = conv_fn(J2, sum_filt, stride=stride, padding=padding)
        IJ_sum = conv_fn(IJ, sum_filt, stride=stride, padding=padding)

        win_size = np.prod(win)
        u_I = I_sum / win_size
        u_J = J_sum / win_size

        cross = IJ_sum - u_J * I_sum - u_I * J_sum + u_I * u_J * win_size
        I_var = I2_sum - 2 * u_I * I_sum + u_I * u_I * win_size
        J_var = J2_sum - 2 * u_J * J_sum + u_J * u_J * win_size

我们可以才到,这个cross应该是协方差部分,I_var和J_var是方差部分。

我们对协方差公式进行推导:\(Cov(X,Y) = E[(X-E(X))(Y-E(Y))]\)
\(=E[XY-XE(Y)-YE(X)+E(X)E(Y)]\)

这样刚好和cross对应上。

  • IJ_sum = E[XY]
  • u_J * I_sum = E[XE(Y)]
  • u_I * u_J * win_size = E[E(X)E(Y)]

对方差公式进行推导:\(Var(X) = E[(X-E(X))^2]=E[X^2-2XE(X)+E(X)^2]\)

  • J2_sum = E(X^2)
  • 2 * u_J * J_sum = E[2XE(X)]
  • u_J * u_J * win_size = E[E(X)^2]

给TA买糖
共{{data.count}}人
人已赞赏
经验教程

Vue3手册译稿 - 深入组件 - 自定义事件

2021-3-15 19:52:00

经验教程

基于autofac的属性注入

2021-3-16 9:37:00

⚠️
免责声明:根据《计算机软件保护条例》第十七条规定“为了学习和研究软件内含的设计思想和原理,通过安装、显示、传输或者存储软件等方式使用软件的,可以不经软件著作权人许可,不向其支付报酬。”您需知晓本站所有内容资源均来源于网络,仅供用户交流学习与研究使用,版权归属原版权方所有,版权争议与本站无关,用户本人下载后不能用作商业或非法用途,需在24个小时之内从您的电脑中彻底删除上述内容,否则后果均由用户承担责任;如果您访问和下载此文件,表示您同意只将此文件用于参考、学习而非其他用途,否则一切后果请您自行承担,如果您喜欢该程序,请支持正版软件,购买注册,得到更好的正版服务。 本站为个人博客非盈利性站点,所有软件信息均来自网络,所有资源仅供学习参考研究目的,并不贩卖软件,不存在任何商业目的及用途,网站会员捐赠是您喜欢本站而产生的赞助支持行为,仅为维持服务器的开支与维护,全凭自愿无任何强求。本站部份代码及教程来源于互联网,仅供网友学习交流,若您喜欢本文可附上原文链接随意转载。
无意侵害您的权益,请发送邮件至 momeis6@qq.com 或点击右侧 私信:momeis 反馈,我们将尽快处理。
0 条回复 A文章作者 M管理员
    暂无讨论,说说你的看法吧
个人中心
今日签到
有新私信 私信列表
搜索