
Converting Between Tensors and Common Image Formats in PyTorch

Preface

When working with PyTorch you often need to convert between image formats, for example turning an image read with the PIL library into a Tensor, or turning a Tensor into a numpy image. Different image-processing libraries also return images in different formats. Knowing how to convert correctly between these formats (PIL, numpy, Tensor) is therefore important when debugging.

This article explains how to convert correctly between the formats returned by the common image libraries and PyTorch tensors. According to the author's tests, the code below can be used directly with PyTorch 0.4.0 or 0.3.0.

Format conversion

The images we usually handle in PyTorch or Python come in one of these formats:

PIL: the image format returned by PIL (Pillow), the common Python image-processing library

numpy: the array format returned by the python-opencv (cv2) library

tensor: the tensor format PyTorch uses during training (which can, of course, also be regarded as an image)

Note that all images discussed below are RGB, three-channel, 24-bit true colour, which is the format we normally use.
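To make the differences between the three formats concrete, the following minimal sketch builds a tiny synthetic RGB image and prints how each library describes it. The 4 x 2 size and the pure-red colour are arbitrary choices for illustration and do not come from the article.

```python
import numpy as np
import torch
from PIL import Image

# A synthetic 4 x 2 (width x height) pure-red image; size and colour are arbitrary.
pil_img = Image.new("RGB", (4, 2), (255, 0, 0))

# PIL reports size as (width, height) and keeps the channel layout in `mode`.
print(pil_img.size, pil_img.mode)

# As a numpy array the same image is H x W x C with dtype uint8 (values 0..255).
np_img = np.array(pil_img)
print(np_img.shape, np_img.dtype)

# PyTorch conventionally uses C x H x W floats in [0, 1].
tensor_img = torch.from_numpy(np_img).permute(2, 0, 1).float().div(255)
print(tuple(tensor_img.shape), tensor_img.dtype)
```

The axis order is the main thing to keep track of: PIL speaks in (width, height), numpy in (height, width, channel), and PyTorch in (channel, height, width).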

PIL and Tensor

Converting between PIL and Tensor is relatively easy, because torchvision already provides the necessary transforms; we only need to combine them.

The imports and shared objects are listed once here (later snippets omit them):

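import torch
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import transforms

# Device used by the functions below (an assumption; the original does not show its definition)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# loader uses the transforms that come with torchvision
loader = transforms.Compose([
    transforms.ToTensor()])

unloader = transforms.ToPILImage()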

1 Read an image with PIL and convert it to a Tensor

# Input: path to the image file
# Output: tensor of shape 1 x C x H x W
def image_loader(image_name):
    image = Image.open(image_name).convert('RGB')
    image = loader(image).unsqueeze(0)  # add a batch dimension
    return image.to(device, torch.float)

2 Convert a PIL image to a Tensor

# Input: image in PIL format
# Output: tensor of shape 1 x C x H x W
def PIL_to_tensor(image):
    image = loader(image).unsqueeze(0)  # add a batch dimension
    return image.to(device, torch.float)
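The call `image.to(device, torch.float)` does two things at once: it moves the tensor to the target device and casts it to 32-bit floats. A minimal sketch on a synthetic tensor, assuming `device` is chosen as CUDA when available and the CPU otherwise (that choice is an assumption, not something the article specifies):

```python
import torch

# Assumption: use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# A synthetic uint8 batch standing in for a loaded image (1 x C x H x W).
image = torch.zeros(1, 3, 2, 4, dtype=torch.uint8)

# One call both moves the tensor to `device` and casts it to float32.
image = image.to(device, torch.float)
print(image.device.type, image.dtype)
```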

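3 Convert a Tensor to a PIL image

# Input: tensor of shape 1 x C x H x W
# Output: image in PIL format
def tensor_to_PIL(tensor):
    image = tensor.cpu().clone()  # clone so the original tensor is not modified
    image = image.squeeze(0)  # remove the batch dimension
    image = unloader(image)
    return image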
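As a sanity check of the round trip, the sketch below converts a PIL image to a batched tensor and back without torchvision, so that the layout changes are explicit. It assumes the usual behaviour of `ToTensor` (H x W x C uint8 to C x H x W floats in [0, 1]) and of `ToPILImage` (the reverse); the helper names `pil_to_tensor_manual` and `tensor_to_pil_manual` are hypothetical and not part of any library.

```python
import numpy as np
import torch
from PIL import Image


def pil_to_tensor_manual(image):
    """PIL RGB image -> 1 x C x H x W float tensor in [0, 1]."""
    array = np.array(image.convert("RGB"))              # H x W x C, uint8
    tensor = torch.from_numpy(array).permute(2, 0, 1)   # C x H x W
    return tensor.float().div(255).unsqueeze(0)         # add the batch dimension


def tensor_to_pil_manual(tensor):
    """1 x C x H x W float tensor in [0, 1] -> PIL RGB image."""
    image = tensor.cpu().clone().squeeze(0)             # drop the batch dimension
    array = image.mul(255).byte().permute(1, 2, 0).numpy()  # H x W x C, uint8
    return Image.fromarray(np.ascontiguousarray(array))


# Synthetic 4 x 2 pure-red image; the values are arbitrary.
original = Image.new("RGB", (4, 2), (255, 0, 0))
batch = pil_to_tensor_manual(original)
restored = tensor_to_pil_manual(batch)
print(tuple(batch.shape), restored.size, restored.getpixel((0, 0)))
```

Note that `mul(255).byte()` truncates rather than rounds, so tensors that have been modified (for example by an optimiser) may need clamping to [0, 1] before conversion.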

4 Display a tensor-format image directly

def imshow(tensor, title=None):
    image = tensor.cpu().clone()  # we clone the tensor to not do changes on it
    image = image.squeeze(0)  # remove the fake batch dimension
    image = unloader(image)
    plt.imshow(image)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated

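5 Save a tensor-format image directly

import os

# num: index used in the output directory name (undefined in the original; passed in here as an assumption)
def save_image(tensor, num, **para):
    dir = 'results_{}'.format(num)  # the directory the file is written to
    image = tensor.cpu().clone()  # we clone the tensor to not do changes on it
    image = image.squeeze(0)  # remove the fake batch dimension
    image = unloader(image)
    if not os.path.exists(dir):
        os.makedirs(dir)
    image.save('results_{}/s{}-c{}-l{}-e{}-sl{:4f}-cl{:4f}.jpg'
               .format(num, para['style_weight'], para['content_weight'], para['lr'], para['epoch'],
                       para['style_loss'], para['content_loss']))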
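The file name above encodes the training hyperparameters. The following minimal sketch shows that naming scheme in isolation, assuming `num` is a run index supplied by the caller (the article does not define it); the helper name `result_path` and the sample values are hypothetical.

```python
import os


def result_path(num, **para):
    """Build the output path following the naming scheme above (hypothetical helper)."""
    directory = 'results_{}'.format(num)
    filename = 's{}-c{}-l{}-e{}-sl{:4f}-cl{:4f}.jpg'.format(
        para['style_weight'], para['content_weight'], para['lr'], para['epoch'],
        para['style_loss'], para['content_loss'])
    return os.path.join(directory, filename)


# Sample values chosen only for illustration.
path = result_path(1, style_weight=1000, content_weight=1, lr=0.01, epoch=300,
                   style_loss=0.5, content_loss=1.25)
print(path)
```

Note that the format spec `{:4f}` sets a minimum field width of 4 with the default six decimal places, so 0.5 is written as `0.500000`; if four decimal places were intended, `{:.4f}` would be the spec to use.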

numpy and Tensor

The numpy format is the image format returned by cv2, that is, the python-opencv library. Note that image data read with python-opencv differs slightly from data read with PIL. In the author's tests, results obtained with images read by python-opencv were slightly worse than those obtained with PIL (the detailed process is to be published later).

Imports used by the code that follows:

import cv2
import numpy as np
import torch
import matplotlib.pyplot as plt

Convert numpy to Tensor

def toTensor(img):
    assert type(img) == np.ndarray, 'the img type is {}, but ndarray expected'.format(type(img))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV reads images in BGR order
    img = torch.from_numpy(img.transpose((2, 0, 1)))  # H x W x C -> C x H x W
    return img.float().div(255).unsqueeze(0)  # 255 can also be changed to 256

Convert Tensor to numpy

def tensor_to_np(tensor):
    img = tensor.mul(255).byte()  # scale [0, 1] floats back to 0..255
    img = img.cpu().numpy().squeeze(0).transpose((1, 2, 0))  # 1 x C x H x W -> H x W x C
    return img
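A minimal round-trip sketch on synthetic data may help show what comes back from these two conversions. To stay free of OpenCV it swaps channels with numpy slicing, which for 3-channel images gives the same result as `cv2.cvtColor(img, cv2.COLOR_BGR2RGB)`; the helper names `bgr_to_tensor` and `tensor_to_rgb` and the pixel values are assumptions made for illustration.

```python
import numpy as np
import torch


def bgr_to_tensor(img):
    """H x W x C uint8 BGR array -> 1 x C x H x W float RGB tensor in [0, 1]."""
    rgb = img[..., ::-1].copy()  # reverse the channel axis; copy() removes negative strides
    return torch.from_numpy(rgb.transpose((2, 0, 1))).float().div(255).unsqueeze(0)


def tensor_to_rgb(tensor):
    """1 x C x H x W float tensor in [0, 1] -> H x W x C uint8 RGB array."""
    img = tensor.mul(255).byte()
    return img.cpu().numpy().squeeze(0).transpose((1, 2, 0))


# Synthetic 2 x 4 image that is pure blue in OpenCV's BGR order.
bgr = np.zeros((2, 4, 3), dtype=np.uint8)
bgr[..., 0] = 255

restored = tensor_to_rgb(bgr_to_tensor(bgr))
print(restored.shape, restored[0, 0].tolist())
```

The array that comes back is in RGB order, so the blue value now sits in the last channel. It can be passed straight to matplotlib, but it would need converting with `cv2.COLOR_RGB2BGR` before being handed to `cv2.imwrite` or `cv2.imshow`.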

Display a numpy-format image

def show_from_cv(img, title=None):
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # matplotlib expects RGB
    plt.figure()
    plt.imshow(img)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)

Display a tensor-format image

def show_from_tensor(tensor, title=None):
    img = tensor.clone()
    img = tensor_to_np(img)
    plt.figure()
    plt.imshow(img)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)

Note

Everything above converts a single image. To handle n images at once, only the corresponding code needs to be modified.

For example, a slight change to the toTensor function shown earlier is enough:

# Convert numpy images of shape N x H x W x C into the corresponding tensor format (N x C x H x W)

def toTensor(img):
    img = torch.from_numpy(img.transpose((0, 3, 1, 2)))
    return img.float().div(255)  # the batch axis already exists, so no unsqueeze(0) is needed
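A minimal sketch of the batched case on synthetic data, assuming the input is a uint8 array of shape N x H x W x C and the desired output is N x C x H x W. The helper name `batch_to_tensor` is hypothetical, and the channel order is left unchanged, so BGR input would still need reordering.

```python
import numpy as np
import torch


def batch_to_tensor(imgs):
    """N x H x W x C uint8 array -> N x C x H x W float tensor in [0, 1]."""
    assert isinstance(imgs, np.ndarray) and imgs.ndim == 4, \
        'expected an N x H x W x C ndarray'
    return torch.from_numpy(imgs.transpose((0, 3, 1, 2))).float().div(255)


# Five synthetic 2 x 4 three-channel images filled with the maximum value.
imgs = np.full((5, 2, 4, 3), 255, dtype=np.uint8)
batch = batch_to_tensor(imgs)
print(tuple(batch.shape), batch.dtype)
```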