首页 >> 大全

CAM(class attention map)

2023-08-04 大全 23 作者:考证青年

1 主要思想

上图中,网络的最后一个卷积层,经过全局平均池化后,后跟上全连接,经过输出类别。其公式描述为:

Sc是c类对应的全连接的输入,也就是加权求和,k代表第k层,c代表第c类。f为第k层(x,y)处的激活值,求和符号是GAP,其前面经过w加权求和,输出为Sc。其等价于:

在(1)中第二行是先利用第c类对应的权重值w1~wk与特征图的每一层对应相乘,也就是给每层都分配一个权重,之后所有特征图沿通道合并,也就是将k通道变为1通道,最后再全局平均池化。

也就是说,先GAP后加权求和,等价于,先加权求和后GAP。

此处的对于类c,先对特征图进行加权求和,得到的就是所说的cam图。图上的每一点反映了其对分类为类c的贡献程度,其经过GAP就是全连接的输入了。

2 适用场景

只适用网络中全连接前是GAP的网络。至于其他的场景,也有使用梯度,或者非梯度作为特征图层的权重的方法。

3 代码

思路:

获取图片以及类别,对图片进行预处理;

创建网络,获取网络特征图以及权重参数;

生成CAM;

展示。

_map赋值给另一个map_map转class对象

1 )导入包

from torchvision import models, transforms
import torchsnooper
import numpy as np
import cv2
import requests
from PIL import Image
import io
from torch.autograd import Variable
import torch
import torch.nn.functional as F

2)获取图片以及对应的类别,并进行预处理

LABELS_URL = 'https://s3.amazonaws.com/outcome-blog/imagenet/labels.json'
classes = {int(key):value for (key, value) in requests.get(LABELS_URL).json().items()} #获取类别信息,键值对IMG_URL = 'http://media.mlive.com/news_impact/photo/9933031-large.jpg'
reponse = requests.get(IMG_URL) 
img_pil = Image.open(io.BytesIO(reponse.content))
img_pil.save('test.jpg')#获取图片normalize = transforms.Normalize(mean=(0.485, 0.456, 0.406),#图片预处理std=(0.229, 0.224, 0.225))
preprocess = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),normalize
])
img_tensor = preprocess(img_pil)
img_variable = Variable(torch.unsqueeze(img_tensor, dim=0)) #图片转为变量类型,注意图片在pytorch进行处理时格式:(批量,通道,高度,宽度)(B,C,H,W),此外注意输出工具与torch中数据的格式不同。

3)创建网络,获取权重参数,创建钩子

中net.eval() 和net.train()的使用

net = models.resnet18(pretrained=True) #预先训练的
net.eval() params = list(net.parameters()) #此时net.parameters是生成器,只能使用next来获取,或者使用list将其转为列表
weight_softmax = params[-2].data.numpy() #最后一层的weight,因为还有bias,所以是-2. 另外此处不需要squeeze(),沿着axis=0获取的数据维度不再包含这个维度,如params的维度为<100,20,4>,params[-2]的维度为<20,4>features_blob = []
def hook_feature(module, input, output): #创建钩子,钩子的输入为模型,层的输入,层的输出,此处的层是广义的层,如果网络中有以Sequential嵌套在整个网络中,那么这个Sequential相当于为一个层。钩子一般用于获取模型中层的数据参数等。不对层进行更改。features_blob.append(output.data.cpu().numpy())net.layer4.register_forward_hook(hook=hook_feature) #注册钩子,也就是对特定的层进行注册。logit = net(img_variable) #网络的输出为<1,k>格式
logit_softmax = F.softmax(logit, dim=1).data.squeeze() #去除0维,变成向量
logit_sort, index = logit_softmax.sort(descending=True) #sort函数返回两个值,一个是排序后的向量,一个是排序后对应的索引,此时仍然是tensor
logit = logit_sort.numpy()
index = index.numpy()

4)生成CAM

def returnCAM(feature_conv, weight_softmax, classid):size_unsample = (256, 256)output_cam = []for id in classid:cam = weight_softmax[id] #获取对应类的k个权重,分别于特征图的k个通道相对应#***cam = np.expand_dims(np.expand_dims(cam, axis=1), axis=2) #weight, feature,利用广播机制对应相乘cam = cam * feature_conv #***cam = np.sum(cam, axis=0) #所有特征图叠加cam = cam - np.min(cam)cam_image = 255* cam / np.max(cam) #归一化,并转为0~255,CV2处理图片的格式是uint8也就是无符号数,取值为0~255cam_image = cam_image.astype(np.uint8)output_cam.append(cv2.resize(cam_image, size_unsample))return output_cam

上面的#***的代码也可以改用下面的,意思是将特征图打成一个向量,权重@同一位置(x,y)处的激活值向量。

cam = weight_softmax[idx].dot(feature_conv.reshape((nc, h * w)))#或者矩阵相乘 weight<1, k> @ feature => <1, 7*7>,
cam = cam.reshape(h, w)

5)展示

cams = returnCAM(np.squeeze(features_blob[0]), weight_softmax, [index[0]])img = cv2.imread('test.jpg')
height, width,_ =img.shape
heatmap = cv2.applyColorMap(cv2.resize(cams[0], (width, height)), cv2.COLORMAP_JET) #此处的cams[0],因为返回的时候就是一个列表
result = heatmap * 0.3 + img * 0.5
cv2.imwrite('CAM0.jpg', result)

注意事项总结:

思路:

获取图片以及类别,对图片进行预处理;

_map赋值给另一个map_map转class对象

创建网络,获取网络特征图以及权重参数;

生成CAM;

展示。

1、获取图片以及类别,进行预处理

1)采用或者cv2对图片进行处理较好,因为其操作的均为array,而PIL得到的是Image对象。

2)其次,注意读取的图片的格式(b,h,w,c),而torch中是(b,c,h,w)。

3)进行测试或者训练时,特别是单张图时,要转化为四维的数据,添加一个维度。

2、创建网络

1)使用print(net)可以输出网络的结构,可以通过net.layer,也就是层的名字来获取此层。

2)net.()获取的是一个生成器,使用list()函数转化为列表。记得参数中不仅有权重,还有偏置bias。

3)钩子的使用:钩子用于从某个层中获取参数,钩子函数输入为模型、层的输入,层的输出。返回None。注册钩子时,采用net.layer.k(hook=hook)。

4)获取torch中的参数时,若为变量,需要使用.data得到,再使用.numpy()得到array。

3、广播机制

两个向量点乘时,使用广播机制。从最后一位开始,两个向量的形状shape或者有一个为1,或者相同,或者有一个不存在。如与其可以进行广播。常用于特征层的加权求和。

4、使用cv2的函数

heatmap = cv2.applyColorMap(cv2.resize(cams[0], (width, height)), cv2.COLORMAP_JET)

中伪彩色函数

关于我们

最火推荐

小编推荐

联系我们


版权声明:本站内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 88@qq.com 举报,一经查实,本站将立刻删除。备案号:桂ICP备2021009421号
Powered By Z-BlogPHP.
复制成功
微信号:
我知道了