PyTorch: The ToTensor(object) Class


Source: compiled from the web.

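In typical PyTorch deep-learning image-processing tasks, images are first loaded through the Dataset and DataLoader classes, and transforms are applied at load time. Almost every transform pipeline includes ToTensor(), which converts the PIL or OpenCV (cv2) image data read in the Dataset's __getitem__() method into a torch.FloatTensor. The details are as follows: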

```python
import torchvision.transforms.functional as F  # in torchvision's own source: from . import functional as F


class ToTensor(object):
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.

    Converts a PIL Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
    if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
    or if the numpy.ndarray has dtype = np.uint8

    In the other cases, tensors are returned without scaling.
    """

    def __call__(self, pic):
        """
        Args:
            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        return F.to_tensor(pic)

    def __repr__(self):
        return self.__class__.__name__ + '()'


class ToPILImage(object):
    """Convert a tensor or an ndarray to PIL Image.

    Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
    H x W x C to a PIL Image while preserving the value range.

    Args:
        mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).
            If ``mode`` is ``None`` (default) there are some assumptions made about the input data:
             - If the input has 4 channels, the ``mode`` is assumed to be ``RGBA``.
             - If the input has 3 channels, the ``mode`` is assumed to be ``RGB``.
             - If the input has 2 channels, the ``mode`` is assumed to be ``LA``.
             - If the input has 1 channel, the ``mode`` is determined by the data type (i.e ``int``, ``float``,
               ``short``).

    .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes
    """

    def __init__(self, mode=None):
        self.mode = mode

    def __call__(self, pic):
        """
        Args:
            pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.

        Returns:
            PIL Image: Image converted to PIL Image.
        """
        return F.to_pil_image(pic, self.mode)

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        if self.mode is not None:
            format_string += 'mode={0}'.format(self.mode)
        format_string += ')'
        return format_string
```

ToTensor.__call__ simply delegates to F.to_tensor(). That function accepts a PIL Image or a numpy.ndarray, first transposes it from HWC to CHW layout, then converts it to float and divides every pixel by 255. As the docstring notes, the division by 255 applies only to 8-bit inputs (the listed PIL modes or uint8 arrays); other inputs are returned without scaling.
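To make the HWC-to-CHW transpose and the conditional division by 255 concrete, here is a minimal NumPy-only sketch of the same arithmetic for ndarray inputs. It is an illustration under simplifying assumptions, not torchvision's code: the helper name `to_tensor_like` is hypothetical, it returns a NumPy array rather than a torch.FloatTensor, and it ignores PIL-specific modes.

```python
import numpy as np


def to_tensor_like(pic):
    """Hypothetical NumPy-only sketch of ToTensor's effect on an ndarray.

    (H, W, C) -> (C, H, W); uint8 input is converted to float32 and divided
    by 255, other dtypes are only transposed (no scaling).
    """
    if pic.ndim == 2:  # grayscale (H, W): add a channel axis -> (H, W, 1)
        pic = pic[:, :, None]
    chw = np.ascontiguousarray(pic.transpose(2, 0, 1))  # HWC -> CHW
    if pic.dtype == np.uint8:
        return chw.astype(np.float32) / 255.0  # scale [0, 255] into [0.0, 1.0]
    return chw  # non-uint8 inputs keep their values and dtype


hwc = np.array([[[0, 128, 255], [255, 0, 64]]], dtype=np.uint8)  # H=1, W=2, C=3
out = to_tensor_like(hwc)
print(out.shape)     # (3, 1, 2)
print(out[:, 0, 0])  # the three channel values of pixel (0, 0), scaled to [0, 1]
```

Keeping the no-scaling branch mirrors the docstring's "In the other cases, tensors are returned without scaling", which is why float images passed to ToTensor are not divided by 255 a second time.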

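transforms.ToTensor()

(1) transforms.ToTensor() converts an image read as a numpy ndarray (e.g. by cv2) or as a PIL.Image into a Tensor of shape (C, H, W) and, for 8-bit input, divides it by 255 to normalize it into [0, 1.0].
(2) The channel order of the result depends on how the image was read: cv2 gives (B, G, R), while PIL.Image gives (R, G, B). ToTensor itself does not reorder channels.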
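Because ToTensor keeps whatever channel order it receives, the usual fix for cv2 input is to reverse the channel axis before converting. The NumPy sketch below uses a hypothetical one-pixel image (`rgb_pixel`) to show that reversing the last axis swaps B and R. The copy with `np.ascontiguousarray` is included on the assumption that the array will be passed to ToTensor, which uses `torch.from_numpy` for ndarrays, and `torch.from_numpy` does not accept arrays with negative strides.

```python
import numpy as np

# Hypothetical 1x1 image holding a pure-red pixel, in PIL's (R, G, B) order.
rgb_pixel = np.array([[[255, 0, 0]]], dtype=np.uint8)
# The same pixel as cv2.imread would store it: (B, G, R).
bgr_pixel = rgb_pixel[..., ::-1]

# Reversing the channel axis converts between the two orders; for 3-channel
# images this matches cv2.cvtColor(img, cv2.COLOR_BGR2RGB). The reversed view
# has a negative stride, so make a contiguous copy before handing it to torch.
back_to_rgb = np.ascontiguousarray(bgr_pixel[..., ::-1])

# After the HWC -> CHW transpose, channel 0 is B for cv2 data and R for PIL data.
print(bgr_pixel.transpose(2, 0, 1)[:, 0, 0])    # [  0   0 255]
print(back_to_rgb.transpose(2, 0, 1)[:, 0, 0])  # [255   0   0]
```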

Code example:

```python
import torch
import cv2
from PIL import Image
from torchvision import transforms

image = cv2.imread('myimage.jpg')   # numpy array (H, W, C=3), channel order (B, G, R)
image2 = Image.open('myimage.jpg')  # PIL JpegImageFile (size=(W, H))
print(image.shape)   # (H, W, 3)
print(image2.size)   # (W, H)

# Converts an image read as a numpy array or PIL.Image into a (C, H, W) tensor,
# divided by 255 into [0, 1.0]
tran = transforms.ToTensor()
img_tensor = tran(image)
img2_tensor = tran(image2)
print(img_tensor.size())   # (C, H, W), channel order (B, G, R)
print(img2_tensor.size())  # (C, H, W), channel order (R, G, B)

# To get (R, G, B) order from cv2 as well, convert before ToTensor:
# image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
```

torchvision.transforms.ToTensor

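For a PIL image img, calling ToTensor does not directly turn the image's three RGB channel matrices into a tensor. The image is stored in memory as bytes, and the conversion proceeds as follows: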

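(1) img.tobytes() converts the image into its in-memory byte representation.
(2) torch.ByteStorage.from_buffer(img.tobytes()) reads those bytes as a stream into a one-dimensional tensor.
(3) The tensor is reshaped to (H, W, C).
(4) The tensor is permuted with permute(2, 0, 1), giving (C, H, W).
(5) Every element is converted to float and divided by 255.
(6) The resulting tensor is returned.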
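These byte-level steps can be mimicked with NumPy, with `np.frombuffer` standing in for `torch.ByteStorage.from_buffer`. This is a sketch under assumptions: `bytes_to_chw` is a hypothetical helper, the caller must supply H, W and C (torchvision derives them from the PIL image's size and mode), and the pixel data is assumed to be 8-bit and channel-interleaved in row-major order, which is what `tobytes()` returns for an RGB image.

```python
import numpy as np


def bytes_to_chw(raw, height, width, channels):
    """Hypothetical helper: raw interleaved uint8 bytes -> float32 (C, H, W) in [0, 1]."""
    flat = np.frombuffer(raw, dtype=np.uint8)     # bytes -> 1-D uint8 array
    hwc = flat.reshape(height, width, channels)   # reshape to (H, W, C)
    chw = hwc.transpose(2, 0, 1)                  # same effect as permute(2, 0, 1)
    return chw.astype(np.float32) / 255.0         # to float, divide every element by 255


# Hypothetical 2x2 RGB image; ndarray.tobytes() produces the same row-major
# HWC byte layout that PIL's img.tobytes() gives for an RGB image.
pixels = np.arange(12, dtype=np.uint8).reshape(2, 2, 3)
tensor_like = bytes_to_chw(pixels.tobytes(), height=2, width=2, channels=3)
print(tensor_like.shape)  # (3, 2, 2)
```

The reshape must come before the permute: the flat buffer only knows byte order, so the (H, W, C) shape has to be imposed first for the axis swap to mean anything.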

 

torchvision.transforms.ToPILImage

For a floating-point Tensor, ToPILImage converts as follows:

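(1) Multiply every element of the tensor by 255.
(2) Convert the tensor's data type from FloatTensor to uint8.
(3) Convert the tensor into a numpy ndarray.
(4) Transpose the ndarray with axes (1, 2, 0) using np.transpose (ndarrays have no permute), going from (C, H, W) to (H, W, C).
(5) Use Image.fromarray to convert the ndarray into a PIL Image.
(6) Return the PIL Image.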
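The reverse direction can be sketched the same way. The hypothetical helper `chw_float_to_hwc_uint8` below performs the scaling, the uint8 cast and the (1, 2, 0) transpose with NumPy only; it assumes a float array already in [0, 1] and stops just before the Pillow call, which is left as a comment.

```python
import numpy as np


def chw_float_to_hwc_uint8(tensor_like):
    """Hypothetical NumPy sketch of ToPILImage's steps for a float (C, H, W) array in [0, 1]."""
    scaled = tensor_like * 255               # multiply every element by 255
    as_uint8 = scaled.astype(np.uint8)       # float -> uint8; truncates like tensor.byte()
    hwc = np.transpose(as_uint8, (1, 2, 0))  # (C, H, W) -> (H, W, C)
    return hwc                               # next step (needs Pillow): Image.fromarray(hwc)


chw = np.array([[[0.0, 1.0]], [[0.5, 0.25]], [[1.0, 0.0]]], dtype=np.float32)  # C=3, H=1, W=2
hwc = chw_float_to_hwc_uint8(chw)
print(hwc.shape)  # (1, 2, 3)
```

With Pillow installed, `Image.fromarray(hwc)` on the returned (H, W, 3) uint8 array yields an RGB-mode image. Note that the cast truncates rather than rounds, so 0.5 becomes 127 rather than 128; a ToTensor/ToPILImage round trip is therefore not guaranteed to be exactly lossless.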


Copyright 2018-2019 实验室设备网 (Laboratory Equipment Network). All rights reserved.