I made this script that takes 3 images taken with a polarising filter 45° apart as inputs and outputs an RGB preview and an image which encodes the polarization parameters as HSV.
However, it's way too slow, taking 155.8125 seconds to process the three images. What can I do to improve it?
import cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from IPython.display import Image
import glob
import math
from pystackreg import StackReg
def isolateblue(image):
    r,g,b=cv2.split(image)
    return b
def rgb_preview(imagelist):
    return cv2.merge((imagelist[2],imagelist[1],imagelist[0])).astype('uint8')
def hsv_processing(imagelist):
    i0 =imagelist[0]/1
    i45 =imagelist[1]/1
    i90 =imagelist[1]/1
    stokesI = i0 + i90
    stokesQ = i0 - i90
    stokesU = (np.ones(stokesI.shape)*(2.0 * i45))- stokesI
    polint = np.sqrt(stokesQ*stokesQ+stokesU*stokesU)
    poldolp = polint/(stokesI+((np.ones(stokesI.shape)+0.001)))
    polaop = 0.5 * np.arctan(stokesU, stokesQ)
    h=(polaop+(np.ones(polaop.shape)*(np.pi/2.0)))/np.pi
    s=poldolp*200
    s[s<0]=0
    s[s>255]=255
    v=polint
    hsvpolar=cv2.merge((h,s,v))
    rgbimg = cv2.cvtColor(hsvpolar.astype('uint8'),cv2.COLOR_HSV2RGB)*2
    rgbimg[rgbimg<0]=0
    rgbimg[rgbimg>255]=255
    return rgbimg
if __name__ == '__main__':
    imagefiles=glob.glob(r"#whatever your filepath is")
    imagefiles.sort()
    images=[]
    for filename in imagefiles:
        img=cv2.imread(filename)
        img=cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
        img=np.int16(img)
        images.append(img)
    polchannels=[]
    for image in images:
        polchannels.append(isolateblue(image))
    num_images=len(polchannels)
    sr=StackReg(StackReg.AFFINE)
    polchannels=sr.register_transform_stack(np.stack((polchannels[0],polchannels[1],polchannels[2])), reference='first')
    cv2.imwrite("rgb preview.jpg",rgb_preview(polchannels))
    cv2.imwrite("polarimetric image.jpg",hsv_processing(polchannels))
    print("done")
Associated files (Google Drive):
1 Answer
You need to care about correctness before performance, and (though it's difficult to say for sure because your numerical methods are undocumented), it's highly unlikely that the output is correct. But, from the top:
You need blank lines in your source. No, seriously. Listen to a PEP8 linter.
isolateblue does not need cv2 and can use a NumPy slice directly. Further to that, though: images should not be a list; it should be a single, pre-allocated NumPy array that receives only the blue channel of each image as it is loaded.
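A minimal sketch of both points (assuming cv2.imread's BGR ordering, so blue is channel 0; the Suggested code below does the same thing inside main()):

    import glob
    import cv2
    import numpy as np

    pol_channels = None
    for i, filename in enumerate(sorted(glob.glob('*.jpg'))):
        img = cv2.imread(filename)                    # BGR
        if pol_channels is None:                      # allocate once, sized from the first image
            pol_channels = np.empty((3, *img.shape[:2]), dtype=np.uint8)
        pol_channels[i] = img[..., 0]                 # blue channel via a plain slice; no cv2.split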
int16 is not necessary here.
StackReg is slow. But (unlike your previous test images) these images are well enough aligned that you might get away with not aligning them at all. If you need to preserve this step, the approach I already showed you in the previous question, using OpenCV's own homography machinery, is about four times as fast as StackReg and drops one external dependency (see warp_align in the Suggested code below).
Don't cast arrays to float by dividing by 1 (/1). Use .astype().
i90 = imagelist[1] should certainly pull from the third channel, [2], and not the second, [1].
Stop calling np.ones when you should just broadcast. It was a habit in your previous question, it's a habit here, and you need to break it. You've managed to introduce a numerical error because of it: when you write np.ones(stokesI.shape)+0.001, that add should have been a multiply; but really the entire ones() call should go away.
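Concretely, a small runnable illustration of the difference (the stand-in array is only for demonstration):

    import numpy as np

    stokesI = np.arange(5, dtype=float)            # stand-in for the real Stokes I image
    stokesI + (np.ones(stokesI.shape) + 0.001)     # what the code computes: adds 1.001 everywhere
    stokesI + np.ones(stokesI.shape)*0.001         # what was intended: adds 0.001
    stokesI + 0.001                                # same result via broadcasting, no ones() at all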
Don't write s[s>255]=255; use np.clip().
cv2.merge is really just an np.stack().
Don't post-multiply the cvtColor result by 2. If you need brighter colours, multiply the value channel.
Your hue calculation is probably incorrect. After you divide out pi, you need to multiply by 180, since OpenCV's HSV colour space has H ranging from 0 through 180.
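Putting the clipping, stacking, and hue-scaling points together, a sketch of the corrected channel construction (assuming polint, poldolp and polaop have already been computed, with polaop in radians in [-pi/2, pi/2], and keeping the question's ×200 saturation scale):

    # given polint, poldolp, polaop computed from the Stokes parameters above
    h = (polaop/np.pi + 0.5) * 180            # map [-pi/2, pi/2] onto OpenCV's [0, 180] hue range
    s = np.clip(poldolp * 200, 0, 255)        # clip instead of two boolean-mask assignments
    v = np.clip(2 * polint, 0, 255)           # brighten the value channel, not the cvtColor output
    hsvpolar = np.stack((h, s, v), axis=-1).astype('uint8')   # np.stack instead of cv2.merge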
Suggested
from typing import Iterable

import cv2
import glob
import numpy as np

BLUE_CHANNEL = 0


def warp_align(images: np.ndarray) -> None:
    print('SIFT detect and compute...')
    sift = cv2.SIFT_create()
    keys: list[tuple] = []
    descriptors: list[np.ndarray] = []
    for image in images:
        key, descriptor = sift.detectAndCompute(image, mask=None)
        keys.append(key)
        descriptors.append(descriptor)

    FLANN_INDEX_KDTREE = 1
    flann = cv2.FlannBasedMatcher(
        indexParams={'algorithm': FLANN_INDEX_KDTREE, 'trees': 5},
        searchParams={'checks': 50},
    )

    print('knn match...')
    LOWES_RATIO = 0.7
    train_desc, *query_descs = descriptors
    matches = [
        [
            m
            for m, n in flann.knnMatch(query_desc, train_desc, k=2)
            if m.distance < LOWES_RATIO*n.distance
        ]
        for query_desc in query_descs
    ]

    def keys_to_points(matched_keys: Iterable[tuple[float, float]]) -> np.ndarray:
        return np.array(tuple(matched_keys), dtype=np.float32)

    print('Dewarping with homographies...')
    train_key, *query_keys = keys
    for query_key, target_matches, image in zip(query_keys, matches, images[1:]):
        query_points = keys_to_points(query_key[m.queryIdx].pt for m in target_matches)
        train_points = keys_to_points(train_key[m.trainIdx].pt for m in target_matches)
        M, mask = cv2.findHomography(query_points, train_points, method=cv2.RANSAC, ransacReprojThreshold=5)
        print(M)
        cv2.warpPerspective(src=image, dst=image, M=M, dsize=image.shape[::-1])


def rgb_preview(image: np.ndarray) -> np.ndarray:
    """Convert from (rgb), x, y to x, y, (bgr)"""
    return np.moveaxis(image, 0, -1)[..., ::-1]
def hsv_processing(image: np.ndarray) -> np.ndarray:
    # Cast all three channels to float up front so 2*i45 cannot overflow in uint8
    i00, i45, i90 = image.astype(float)
    stokesI = i00 + i90
    stokesQ = i00 - i90
    stokesU = 2*i45 - stokesI
    polint = np.sqrt(stokesQ*stokesQ + stokesU*stokesU)
    # In [0, inf]
    poldolp = polint/(stokesI + 1e-6)
    # In [-pi/2, pi/2]; arctan2, not arctan (whose second argument is an output array)
    polaop = 0.5*np.arctan2(stokesU, stokesQ)
    h = (polaop/np.pi + 0.5)*180
    s = np.clip(100*poldolp, a_min=0, a_max=255)
    v = np.clip(2*polint, a_min=0, a_max=255)
    hsvpolar = np.stack((h, s, v), axis=-1).astype('uint8')
    return cv2.cvtColor(hsvpolar, cv2.COLOR_HSV2RGB)
def main() -> None:
    print('Loading images...')
    image_filenames = glob.glob('Test*degrees.jpg')
    image_filenames.sort()

    pol_channels = None
    for i, filename in enumerate(image_filenames):
        img = cv2.imread(filename)  # BGR
        if pol_channels is None:
            pol_channels = np.empty((3, *img.shape[:2]), dtype=np.uint8)
        pol_channels[i, ...] = img[..., BLUE_CHANNEL]

    warp_align(pol_channels)

    print('Generating preview...')
    cv2.imwrite("rgb preview.jpg", rgb_preview(pol_channels))

    print('Generating polarimetry...')
    cv2.imwrite("polarimetric image.jpg", hsv_processing(pol_channels))
    print('Done.')


if __name__ == '__main__':
    main()
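To check the improvement over the original 155-second run on your own machine, a minimal timing wrapper is enough (generic Python, not specific to this script; it simply replaces the __main__ block above):

    import time

    if __name__ == '__main__':
        start = time.perf_counter()
        main()
        print(f'Elapsed: {time.perf_counter() - start:.1f} s')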
Output
Comments
- That add is there to avoid division by zero, I'm sure it's not a multiplication. – Omar Morales Rivera, Oct 21, 2022 at 20:39
- I do use a linter. Bandit is my go-to because it identifies security issues: bandit.readthedocs.io/en/latest – Omar Morales Rivera, Oct 21, 2022 at 20:42
- Re. multiplication - look closer. You should be adding a product but instead you're adding a sum. – Reinderien, Oct 22, 2022 at 3:13
- Re. linter: whatever you're using, it isn't enough. You need a whitespace linter. – Reinderien, Oct 22, 2022 at 3:13
- i90 = imagelist[1] should be i90 = imagelist[2]; you're throwing away a channel otherwise.