# Source code for torchvision.utils

import torch
import math
# Alias for the builtin ``range``. Functions in this module accept a parameter
# named ``range`` (e.g. make_grid's normalization range), which shadows the
# builtin inside their bodies, so they reach the builtin through this name.
irange = range

def make_grid(tensor, nrow=8, padding=2,
              normalize=False, range=None, scale_each=False, pad_value=0):
    """Make a grid of images.

    Args:
        tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W)
            or a list of images all of the same size. A single 3D (C x H x W)
            or 2D (H x W) image is also accepted.
        nrow (int, optional): Number of images displayed in each row of the
            grid. The final grid has ceil(B / nrow) rows and min(B, nrow)
            columns. Default is 8.
        padding (int, optional): Amount of padding, in pixels, placed around
            and between the images. Default is 2.
        normalize (bool, optional): If True, shift the image to the range
            (0, 1) by subtracting the minimum and dividing by
            (maximum - minimum).
        range (tuple, optional): Tuple (min, max) of numbers. When given (and
            ``normalize`` is True), pixel values are clamped to this range and
            these numbers are used to normalize the image. By default, min and
            max are computed from the tensor.
        scale_each (bool, optional): If True, scale each image in the batch
            separately rather than using the (min, max) over all images.
        pad_value (float, optional): Value for the padded pixels. Default is 0.

    Returns:
        Tensor: A (C x H' x W') grid image. Single-channel input is replicated
        to 3 channels. If the input holds exactly one image, that image is
        returned as a (C x H x W) tensor without any padding.

    Raises:
        TypeError: If ``tensor`` is neither a Tensor nor a list of Tensors.

    Example:
        See this notebook `here <https://gist.github.com/anonymous/bf16430f7750c023141c562f3e9f2a91>`_
    """
    if not (torch.is_tensor(tensor) or
            (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
        raise TypeError('tensor or list of tensors expected, got {}'.format(type(tensor)))

    # if list of tensors, convert to a 4D mini-batch Tensor
    if isinstance(tensor, list):
        tensor = torch.stack(tensor, dim=0)

    if tensor.dim() == 2:  # single image H x W
        tensor = tensor.unsqueeze(0)
    if tensor.dim() == 3:  # single image C x H x W
        if tensor.size(0) == 1:  # if single-channel, convert to 3-channel
            tensor = torch.cat((tensor, tensor, tensor), 0)
        tensor = tensor.unsqueeze(0)

    if tensor.dim() == 4 and tensor.size(1) == 1:  # single-channel images
        tensor = torch.cat((tensor, tensor, tensor), 1)

    if normalize:
        tensor = tensor.clone()  # avoid modifying the caller's tensor in-place
        if range is not None:
            assert isinstance(range, tuple), \
                "range has to be a tuple (min, max) if specified. min and max are numbers"

        def norm_ip(img, low, high):
            # Clamp, then map [low, high] onto [0, 1] in place. The floor on
            # the divisor keeps constant images at 0 instead of producing NaN.
            img.clamp_(min=low, max=high)
            img.sub_(low).div_(max(high - low, 1e-5))

        def norm_range(t, value_range):
            if value_range is not None:
                norm_ip(t, value_range[0], value_range[1])
            else:
                norm_ip(t, float(t.min()), float(t.max()))

        if scale_each:
            for t in tensor:  # loop over mini-batch dimension (views into the clone)
                norm_range(t, range)
        else:
            norm_range(tensor, range)

    if tensor.size(0) == 1:
        # A single image needs no grid. Drop only the batch dimension so that
        # images with H == 1 or W == 1 keep their spatial dimensions.
        return tensor.squeeze(0)

    # make the mini-batch of images into a grid
    nmaps = tensor.size(0)
    xmaps = min(nrow, nmaps)
    ymaps = int(math.ceil(float(nmaps) / xmaps))
    # Each cell is one image plus `padding` pixels on its top/left edge; the
    # extra `padding` added to the grid size closes the bottom/right border.
    height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding)
    num_channels = tensor.size(1)
    grid = tensor.new_full(
        (num_channels, height * ymaps + padding, width * xmaps + padding),
        pad_value)
    for k, img in enumerate(tensor):
        y, x = divmod(k, xmaps)  # fill the grid row by row
        grid.narrow(1, y * height + padding, height - padding) \
            .narrow(2, x * width + padding, width - padding) \
            .copy_(img)
    return grid

"""Save a given Tensor into an image file.

Args:
tensor (Tensor or list): Image to be saved. If given a mini-batch tensor,
saves the tensor as a grid of images by calling make_grid.
**kwargs: Other arguments are documented in make_grid.
"""
from PIL import Image