import hashlib
import os
import pickle
from typing import Self
import numpy as np
import torch
from PIL import Image
device = "cuda" if torch.cuda.is_available() else "cpu"
class ImageTensor(torch.Tensor):
@classmethod
def from_pixels(cls, pixels):
instance = (
cls._square_prune(pixels).to(device).as_subclass(cls)
if isinstance(pixels, torch.Tensor)
else cls._square_prune(
torch.tensor(pixels, device=device, dtype=torch.float)
).as_subclass(cls)
)
instance.__init__()
return instance
def __init__(self) -> None:
super().__init__()
assert self.dim() in {2, 3}
self.img_shape = torch.tensor(self.shape[:2])
self.radius = torch.min(self.img_shape).divide(2)
self.center_y, self.center_x = self.img_shape / 2.0
self._circle_mask()
def show(self) -> Self:
numpy_array = torch.clamp(self, min=0, max=255).cpu().numpy().astype("uint8")
Image.fromarray(
numpy_array, mode="CMYK"
).show() if self.dim() == 3 else Image.fromarray(numpy_array).show()
return self
def _circle_mask(self) -> None:
file = hashlib.sha256(f"circle_mask{self.shape}".encode()).hexdigest()
if os.path.isfile(f"tmp/{file}.mask"):
with open(f"tmp/{file}.mask", "rb") as f:
mask = pickle.load(f)
else:
y_indices, x_indices = meshgrid(*self.img_shape)
distances = torch.hypot(
(y_indices - self.center_y).float(),
(x_indices - self.center_x).float(),
)
mask = distances > self.radius
with open(f"tmp/{file}.mask", "wb") as f:
pickle.dump(mask, f)
self[mask] = 255
@staticmethod
def _square_prune(tensor: torch.Tensor) -> torch.Tensor:
assert tensor.dim() in {2, 3}
center_y, center_x = torch.tensor(tensor.shape[:2]) / torch.tensor(2.0)
radius = min(tensor.shape[:2]) / torch.tensor(2.0)
if center_y == center_x:
return tensor
y_min = int(center_y - radius)
y_max = int(center_y + radius)
x_min = int(center_x - radius)
x_max = int(center_x + radius)
return (
tensor[y_min:y_max, x_min:x_max]
if tensor.dim() == 2
else tensor[y_min:y_max, x_min:x_max, :]
)
def __add__(self, other) -> Self:
return ImageTensor.from_pixels(torch.add(self, other))
def __sub__(self, other) -> Self:
return ImageTensor.from_pixels(torch.sub(self, other))
def add_to_tensor(self, other: Self) -> torch.Tensor:
return torch.add(self, other)
def channel(self, idx):
assert self.dim() == 3
return self.from_pixels(self[:, :, idx])
def meshgrid(
y_size: int | torch.Tensor, x_size: int | torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
y_indices, x_indices = torch.meshgrid(
torch.arange(int(y_size), device=device),
torch.arange(int(x_size), device=device),
indexing="ij",
)
return y_indices, x_indices
def calculate_lines(
parent_img: ImageTensor,
current_point: torch.Tensor,
points: torch.Tensor,
decay_factor: float,
brightness_factor: float,
) -> torch.Tensor:
assert parent_img.dim() == 2
assert current_point.dim() == 1
assert points.dim() == 2
y_indices, x_indices = meshgrid(*parent_img.img_shape)
y_indices = y_indices.unsqueeze(0).tile((points.shape[0], 1, 1)).to(device)
x_indices = x_indices.unsqueeze(0).tile((points.shape[0], 1, 1)).to(device)
y0 = points[:, 0].unsqueeze(1).unsqueeze(2).expand_as(y_indices).to(device)
x0 = points[:, 1].unsqueeze(1).unsqueeze(2).expand_as(x_indices).to(device)
y1 = current_point[0]
x1 = current_point[1]
delta_y = y1 - y0
delta_x = x1 - x0
numerator = torch.abs(
(delta_y * x_indices) - (delta_x * y_indices) + (x1 * y0) - (y1 * x0)
)
denominator = torch.hypot(delta_y.float(), delta_x.float())
distance = numerator / denominator
brightness = brightness_factor * brightness_decay(distance, decay_factor)
brightness[(points == current_point).all(dim=1).nonzero(as_tuple=True)[0], :, :] = (
torch.zeros_like(parent_img)
)
return brightness.permute((1, 2, 0))
def brightness_decay(distance: torch.Tensor, decay_factor: float):
return torch.exp((-((distance / decay_factor) ** 2)))
def pickle_object(obj, filename):
with open(filename, "wb") as f:
pickle.dump(obj, f)
return obj
def unpickle_object(filename):
with open(filename, "rb") as f:
return pickle.load(f)
def create_string_art(
image_file,
nail_count: int,
brightness_factor: float,
decay_factor: float,
print_status: bool,
) -> ImageTensor:
img, cyan_channel, magenta_channel, yellow_channel, key_channel = cmyk_split(
image_file
)
cyan_string_art = ImageTensor.from_pixels(torch.zeros_like(cyan_channel))
magenta_string_art = ImageTensor.from_pixels(torch.zeros_like(magenta_channel))
yellow_string_art = ImageTensor.from_pixels(torch.zeros_like(yellow_channel))
nail_angles = torch.linspace(0, 2 * torch.pi, nail_count)
nail_locations = torch.stack(
[
torch.tensor(
[
img.center_y + img.radius * torch.sin(angle),
img.center_x + img.radius * torch.cos(angle),
]
)
for angle in nail_angles
]
)
for channel, string_art in [
(cyan_channel, cyan_string_art),
(magenta_channel, magenta_string_art),
(yellow_channel, yellow_string_art),
]:
is_done = False
nail_order = [nail_count // 2]
channel_mse = np.inf
while not is_done:
current_nail = nail_order[-1]
all_lines = calculate_lines(
channel,
nail_locations[current_nail],
nail_locations,
decay_factor,
brightness_factor,
)
line_error = bulk_square_error(
torch.add(
all_lines, torch.unsqueeze(string_art, 2).expand_as(all_lines)
),
channel,
)
best_mse = torch.min(line_error)
line_idx = int(line_error.argmin())
if torch.min(line_error) < channel_mse:
channel_mse = best_mse
nail_order.append(line_idx)
string_art += ImageTensor.from_pixels(all_lines[:, :, line_idx])
print(f"Iteration: {len(nail_order) - 1}") if print_status else ...
else:
is_done = True
return pickle_object(
ImageTensor.from_pixels(
torch.stack(
[cyan_string_art, magenta_string_art, yellow_string_art, key_channel],
dim=-1,
)
),
f"{image_file}.stringart",
)
def bulk_square_error(inputs: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
target = target.unsqueeze(2).expand_as(inputs)
return torch.sum((inputs - target) ** 2, dim=(0, 1))
def cmyk_split(image_file: str):
img = ImageTensor.from_pixels(np.array(Image.open(image_file).convert("CMYK")))
cyan_channel = img.channel(0)
magenta_channel = img.channel(1)
yellow_channel = img.channel(2)
key_channel = img.channel(3)
return img, cyan_channel, magenta_channel, yellow_channel, key_channel
def main() -> None:
create_string_art("img/unnamed.jpg", 600, 30, 0.9, True).show()
if __name__ == "__main__":
main()
I am trying to make a program that approximates an image by wrapping "string" around "nails" that are around the edge of the circle.
The basic flow of my code is as follows:
- Load the image and split it into CMYK
- Generate a 3d tensor where each 2d slice is a possible line that can be drawn between the current "nail" and all other nails (including itself)
- Find the slice with the lowese summed square error when added to the current string art and compared with the real image
- Repeat until all possible lines you can draw have worse/the same summed square error as the current string art
- Repeat for each color channel
This works really well for very small images (100x100 to 250x250) and I can use up to like almost 2000 "nails" and have it run really quick. It does start to get really slow really fast, and I get a CUDA timeout error on around 800x800 and anything around 1000x1000 I just get a GPU memory error. The timeout error had some advice on solving it, but I don't know what that actually does/means and I would rather fix any issues with allocating more memory than necessary in my code first. (I am using a GTX 980 which is almost 10 years old and only has 4GB of VRAM, so I would assume on a more modern GPU with more VRAM, it would be able to handle larger images before running into the CUDA launch error and GPU memory error)
RuntimeError: CUDA error: the launch timed out and was terminated
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.
I am fairly certain the calculate_lines()
and bulk_square_error()
methods can be improved a lot.
calculate_lines()
takes in a parent image of type ImageTensor
which is just my subclass of torch.Tensor
, the current points of type torch.Tensor
and of the form [y, x]
, and then the list of all points to draw to of type torch.Tensor
of the form [[y_0, x_0], [y_1, x_1], ...]
and returns a torch.Tensor
of shape [y, x, nail_index]
. I am pretty sure that the biggest bottleneck of this function is the creation of y0, x0, y1, x1, delta_y, and delta_x
. I am sure there is a better way to be able to handle them than what I am doing, which is just expanding them to be the desired output shape. but I am sure that allocates a lot more memory than is necessary and that there is a better way to handle that, but I am very new to using pytorch and am not very familiar with all of its methods. I am using this formula to draw the lines.
bulk_square_error()
takes in a 3d tensor of each of the lines added to the current string art and the 2d real target image and returns a 1d tensor of the summed square error of each 2d slice. This function has the same problem as calculate_lines()
where I am expanding the target image to be the same shape as the 3d tensor. I am sure there is a way to compare each slice to the same 2d tensor without having to expand it to take up more memory.
I am looking for any any ways to improve the speed/GPU memory usage of my program, and any criticisms/feedback regarding code style or my approach to the problem.
Any help is appreciated. Thanks!
There is no good reason as to why I am using pytorch rather than numpy. Originally, I was calculating the mean squared error of using torch.nn.functional.mse_loss()
method to compare each line+string art to the real image individually, but it was extremely slow to iterate over each possible line with a for loop. Now that I am not using that function any more, I could switch to numpy, but pytorch is fine.
Additionally, I feel like I should share my inspiration for this project. This youtube video was my inspiration. The creator recently made a follow-up to a video using the radon transform. Maybe I will attempt that implementation eventually, but for now I am content with this implementation.
If you are going to try the script, you need to have a folder named tmp/
in the same directory as the script, since it stores the circle masks for future use in the folder to avoid constantly generating them.
1 Answer 1
Strategies to speed up the code:
Instead of computing all the lines in one large array, then picking one, you should compute one line at the time, and keep only the best one iteratively. Loops in Python are slow, and vectorizing things is good, but the larger the image, and the more points around the circle you evaluate, the larger this intermediate array becomes. And large intermediate arrays are also slow. Instead, use Numba to speed up the loops.
This line is way more expensive than it needs to be:
brightness = brightness_factor * brightness_decay(distance, decay_factor)
brightness_decay
draws the lines with a Gaussian profile. This is really good. But you're computing the Gaussian also for points far away from the line where the output will be (near) zero. You should evaluate thebrightness_decay
function only for pixels that are close to the line.Furthermore, the very next line erases a bunch of lines. Can't you determine first which lines you want to keep, and only draw those?
Other comments:
The function
cmyk_split
does two things: it reads an image from file, then it converts the color space and returns individual channels. The function name reveals only the latter action. It is best when functions do only one thing.You're computing the string art for the cyan, magenta and yellow channels, but not for black channel (the 4th channel in CMYK).
I don't see the reason you needed to make
ImageTensor
derived fromtorch.Tensor
. You could just use atorch.Tensor
in your program._circle_mask
could be a regular function. Other than that, all you do is add a few values (img_shape
,radius
,center_y
,center_x
) to the object, this you could equally well do with atorch.Tensor
object. Or you could just compute them when you need them. I think thisImageTensor
object makes you code more difficult to read than it needs to be, it adds complication rather than simplifying things.I'll admit I'm biased against OOP. But I think classes should only be introduced when they simplify code or the API.
-
\$\begingroup\$ "Furthermore, the very next line erases a bunch of lines. Can't you determine first which lines you want to keep, and only draw those?" It is only one line that gets erased, the one drawn with the same start and end point because of dividing by zero. Does that really have a noticeable effect on performance? Either way if I switch to using a Numba jit loop it is as simple as an if check before drawing the line, I just didn't know how to skip drawing the in the bulk tensor. \$\endgroup\$flakpm– flakpm2024年06月27日 19:57:48 +00:00Commented Jun 27, 2024 at 19:57
1 - max(R,G,B)
. so if it's zero, then your original RGB color has a max value in one of the channels". If that is accurate, that means that K is just a constant value and does not need to have string art done on it. \$\endgroup\$