在 PyTorch 实现图像的 Normalize 和 反 Normalize 的实验中,发现经过这两个转换后存储的图像和原始图像虽然视觉上没什么差异,但在二进制上却不能完全匹配,这里记录下问题的原因分析及最终的解决过程。
下面是抽象出来的问题及解决问题的代码:
#!/usr/bin/env python # -*- coding: utf-8 -*- import torch from PIL import Image import torchvision.transforms as transforms # 1. Read image imgFolder = "/home/test/image/" imgSrc = Image.open(imgFolder + "src.jpg") imgSrc.save(imgFolder + "./00src.png") # 2. Save source image tensorSrc = transforms.ToTensor()(imgSrc) imgRlt = transforms.ToPILImage()(tensorSrc) imgRlt.save(imgFolder + "./00rlt.png") # Normalized transform tensorTrans = tensorSrc.clone() tensorTrans = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(tensorTrans) tensorRevert = tensorTrans * 0.5 + 0.5 # 3. Usual revert transform imgRevert = transforms.ToPILImage()(tensorRevert) imgRevert.save(imgFolder + "01rlt.png") # 4. Rectified revert transform if isinstance(tensorTrans, torch.FloatTensor): imgRevert = tensorRevert.mul(255).round().byte() imgRevert = transforms.ToPILImage()(imgRevert) imgRevert.save(imgFolder + "02rlt.png")
主要对比的是上面代码中存储的几幅图像:
(1)输入图像 src.jpg,读入后为了防止编码差异, 便于和后面的结果图进行二进制对比,直接存储为 00src.png;
(2)将图像直接转换成 Tensor 并立即重新转换成 PILImage,存储为 00rlt.png;
(3)将图像转换成的 Tensor 经过 Normalize 和按照 Normalize 定义推算出的反向计算反转回图像,存储为 01rlt.png;
(4)通过对前面 1,2,3 中结果不一致的原因分析得到的修正方案,结果图像存储为 02rlt.png。
对于存储的图像,通过 Beyond Compare 进行对比,发现 00src.png 和 00rlt.png 是完全一致的(符合预期),但 00src.png 和 01rlt.png 却不完全一致(不符合预期)。出现这个现象后,直观的想法是由于精度不够引起的,但具体哪一步的精度出了问题,还需进一步调查。
在查看了 transforms.ToTensor 和 transforms.ToPILImage 的源代码后,对问题进行进一步抽象,见下面代码:
import torch torch.set_printoptions(precision = 32) a = torch.tensor(4, dtype=torch.uint8) b = a.float().div(255) c = (b - 0.5) / 0.5 d = c * 0.5 + 0.5 e = d.mul(255) f = e.byte() print("a = " + str(a) + "\nb = " + str(b) + "\nc = " + str(c)) print("d = " + str(d) + "\ne = " + str(e) + "\nf = " + str(f))
输出结果为:
a = tensor(4, dtype=torch.uint8) b = tensor(0.01568627543747425079345703125000) c = tensor(-0.96862745285034179687500000000000) d = tensor(0.01568627357482910156250000000000) e = tensor(3.99999976158142089843750000000000) f = tensor(3, dtype=torch.uint8)
为便于查看问题,通过 torch.set_printoptions 设置了输出精度。对于每个变量代表的意义,大致理解如下:a 相当于原始图像,b 相当于图像转换为 Tensor,c 相当于 Normalize,d 相当于反 Normalize,e 为反转回图像的一个中间结果,f 为最终结果。很显然按此步骤,图像像素值 4 在经历 Normalize 和反 Normalize 过程后,最终在新的图像上像素值变成了 3。
接着将上面的计算结果和 wolframalpha 中计算的结果进行对比,可以发现,b 的精度已经出问题了,高精度计算中 4/255 的结果如下:
4/255 = 0.015686274509803921568627450980392156862745098039215686274...
很明显,问题的原因就是算法精度不够而最后的转换又是直接取整,导致出现了不符合预期的结果( 有兴趣的话,也可以对比验证下其余步骤的结果)。 既然搞清楚了原因,修正的方法就很简单了,再取整前加个 round 函数就可以了,见最上面得到 02rlt.png 图像的代码,最终可验证 00src.png 和 02rlt.png 是二进制一致的。
欢迎转载,转载请注明出处:蔓草札记 » PyTorch 中 Tensor 和 PILImage 的相互转换