Use estimated position of feature within block rather than just block's centre
[stop_motion.git] / correlate.py
1 #!/usr/bin/env python3
2
3 import gamma
4 import math
5 import numpy
6 import perspective
7 import scipy
8 import scipy.ndimage
9 import scipy.signal
10 import sys
11
12 # size of block that will be matched
13 # (correlate a block this size against a block with added slippage all around)
14 XM = 128
15 YM = 128
16
17 # pitch between the block centres
18 XP = 64
19 YP = 64
20
21 # allowable +/- slippage between pairs
22 XS = 64
23 YS = 64
24
25 CORNER_CANDIDATES = 8
26
27 CUTOFF0 = 64
28 CUTOFF1 = 4
29
30 EPSILON = 1e-6
31
32 def calc_bandpass(image):
33   _, _, cs = image.shape
34   return numpy.stack(
35     [
36       scipy.ndimage.gaussian_filter(image[:, :, i], CUTOFF1, mode = 'mirror') -
37         scipy.ndimage.gaussian_filter(image[:, :, i], CUTOFF0, mode = 'mirror')
38       for i in range(cs)
39     ],
40     2
41   )
42
43 def calc_match(bandpass0, bandpass1, xc, yc):
44   x0 = xc - XM // 2
45   y0 = yc - YM // 2
46   x1 = xc - XM // 2 - XS
47   y1 = yc - YM // 2 - YS
48
49   block0 = bandpass0[
50     y0:y0 + YM,
51     x0:x0 + XM,
52     :
53   ]
54   block1 = bandpass1[
55     y1:y1 + YM + YS * 2,
56     x1:x1 + XM + XS * 2,
57     :
58   ]
59
60   # note: swapping block1, block0 flips the output (subtracts x and y
61   # from BLOCK_SIZE) and we need this to find matching part of image1
62   corr = numpy.sum(
63     numpy.stack(
64       [
65         scipy.signal.correlate(
66           block1[:, :, i],
67           block0[:, :, i],
68           mode = 'valid'
69         )
70         for i in range(cs)
71       ],
72       0
73     ),
74     0
75   )
76   #temp = corr - numpy.mean(corr)
77   #temp /= 10. * numpy.sqrt(numpy.mean(numpy.square(temp)))
78   #gamma.write_image(f'corr_{i:d}_{j:d}.jpg', temp + .5)
79
80   # find slippage from correlation
81   yo, xo = numpy.unravel_index(numpy.argmax(corr), corr.shape)
82   max_corr = corr[yo, xo]
83   if (
84     max_corr < EPSILON or
85       xo < CUTOFF1 or
86       xo > XS * 2 - CUTOFF1 or
87       yo < CUTOFF1 or
88       yo > YS * 2 - CUTOFF1
89   ):
90     return None
91
92   # estimate position within block of feature being matched
93   block1 = block1[yo:yo + YM, xo:xo + XM, :]
94   match = numpy.sum(block0 * block1, 2)
95   #print('xxx', numpy.sum(match), max_corr)
96   #assert False
97   xf = (
98     numpy.sum(match, 0) @ numpy.arange(XM, dtype = numpy.double)
99   ) / max_corr
100   yf = (
101     numpy.sum(match, 1) @ numpy.arange(YM, dtype = numpy.double)
102   ) / max_corr
103   #if diag:
104   #  x = int(math.floor(xf))
105   #  y = int(math.floor(yf))
106
107   #  diag0 = block0 + .5
108   #  diag0[
109   #    max(y - 21, 0):max(y + 22, 0),
110   #    max(x - 1, 0):max(x + 2, 0),
111   #    :
112   #  ] = 0.
113   #  diag0[
114   #    max(y - 1, 0):max(y + 2, 0),
115   #    max(x - 21, 0):max(x + 22, 0),
116   #    :
117   #  ] = 0.
118   #  gamma.write_image(f'diag_{xc:d}_{yc:d}_0.jpg', diag0)
119
120   #  diag1 = block1 + .5 
121   #  diag1[
122   #    max(y - 21, 0):max(y + 22, 0),
123   #    max(x - 1, 0):max(x + 2, 0),
124   #    :
125   #  ] = 0.
126   #  diag1[
127   #    max(y - 1, 0):max(y + 2, 0),
128   #    max(x - 21, 0):max(x + 22, 0),
129   #    :
130   #  ] = 0.
131   #  gamma.write_image(f'diag_{xc:d}_{yc:d}_1.jpg', diag1)
132
133   # return offset and feature relative to block centre
134   return xo - XS, yo - YS, xf - XM // 2, yf - YM // 2
135
136
137
138 diag = False
139 if len(sys.argv) >= 2 and sys.argv[1] == '--diag':
140   diag = True
141   del sys.argv[1]
142
143 in_jpg0 = 'tank_battle/down_2364.jpg'
144 in_jpg1 = 'tank_battle/down_2365.jpg'
145 out_jpg0 = 'out0_2364.jpg'
146 out_jpg1 = 'out0_2365.jpg'
147
148 print(f'read {in_jpg0:s}')
149 image0 = gamma.read_image(in_jpg0)
150 shape = image0.shape
151
152 print('bandpass')
153 bandpass0 = calc_bandpass(image0)
154 if diag:
155   gamma.write_image('bandpass0.jpg', bandpass0 + .5)
156
157 print(f'read {in_jpg1:s}')
158 image1 = gamma.read_image(in_jpg1)
159 assert image1.shape == shape
160
161 print('bandpass')
162 bandpass1 = calc_bandpass(image1)
163 if diag:
164   gamma.write_image('bandpass1.jpg', bandpass1 + .5)
165
166 ys, xs, cs = shape
167 xb = (xs // 2 - XM - 2 * XS) // XP
168 yb = (ys // 2 - YM - 2 * YS) // YP
169 print('xb', xb, 'yb', yb)
170
171 print('find corner candidates')
172 p_all = []
173 q_all = []
174 corner_candidates = []
175 for i in range(2):
176   for j in range(2):
177     print('i', i, 'j', j)
178
179     # correlate blocks in (i, j)-corner
180     offsets = []
181     blocks = []
182     for k in range(yb):
183       yc = YS + YM // 2 + k * YP
184       if i:
185         yc = ys - yc
186       for l in range(xb):
187         xc = XS + XM // 2 + l * XP
188         if j:
189           xc = xs - xc
190         match = calc_match(bandpass0, bandpass1, xc, yc)
191         if match is not None:
192           offsets.append(match [:2])
193           xo, yo, xf, yf = match
194           xf0 = xc + xf
195           yf0 = yc + yf
196           xf1 = xf0 + xo
197           yf1 = yf0 + yo
198           p_all.append(numpy.array([xf0, yf0], numpy.double))
199           q_all.append(numpy.array([xf1, yf1], numpy.double))
200           blocks.append((xf0, yf0, xf1, yf1))
201
202     # find the offset trend (median offset per axis) in (i, j)-corner
203     k = len(blocks) // 2
204     xo_median = sorted([xo for xo, _ in offsets])[k]
205     yo_median = sorted([yo for _, yo in offsets])[k]
206     #print('i', i, 'j', j, 'xo_median', xo_median, 'yo_median', yo_median)
207
208     # choose CORNER_CANDIDATES blocks closest to trend in (i, j)-corner
209     u = numpy.array(offsets, numpy.double)
210     v = numpy.array([xo_median, yo_median], numpy.double)
211     dist = numpy.sum(numpy.square(u - v[numpy.newaxis, :]), 1)
212     corner_candidates.append(
213       sorted(
214         [(dist[i],) + blocks[i] for i in range(len(blocks))]
215       )[:CORNER_CANDIDATES]
216     )
217 p_all = numpy.stack(p_all, 1)
218 q_all = numpy.stack(q_all, 1)
219
220 # try all combinations of the corner candidates
221 print('try corner candidates')
222 p = numpy.zeros((2, 4), numpy.double)
223 q = numpy.zeros((2, 4), numpy.double)
224 best_dist = None
225 best_A = None
226 best_p = None # for diag
227 for _, xf00, yf00, xf10, yf10 in corner_candidates[0]:
228   p[0, 0] = xf00
229   p[1, 0] = yf00
230   q[0, 0] = xf10
231   q[1, 0] = yf10
232   for _, xf01, yf01, xf11, yf11 in corner_candidates[1]:
233     p[0, 1] = xf01
234     p[1, 1] = yf01
235     q[0, 1] = xf11
236     q[1, 1] = yf11
237     for _, xf02, yf02, xf12, yf12 in corner_candidates[2]:
238       p[0, 2] = xf02
239       p[1, 2] = yf02
240       q[0, 2] = xf12
241       q[1, 2] = yf12
242       for _, xf03, yf03, xf13, yf13 in corner_candidates[3]:
243         p[0, 3] = xf03
244         p[1, 3] = yf03
245         q[0, 3] = xf13
246         q[1, 3] = yf13
247
248         A = perspective.calc_transform(p, q)
249         dist = numpy.sum(
250           numpy.square(
251             q_all - perspective.apply_transform_multi(A, p_all)
252           )
253         )
254         if best_dist is None or dist < best_dist:
255           best_dist = dist
256           best_A = A
257           best_p = numpy.copy(p) # for diag
258
259 print('remap')
260 out_image1 = perspective.remap_image(best_A, image1)
261 if diag:
262   for i in range(4):
263     xf0, yf0 = numpy.floor(best_p[:, i]).astype(numpy.int32)
264
265     image0[
266       max(yf0 - 21, 0):max(yf0 + 22, 0),
267       max(xf0 - 1, 0):max(xf0 + 2, 0),
268       :
269     ] = 0.
270     image0[
271       max(yf0 - 1, 0):max(yf0 + 2, 0),
272       max(xf0 - 21, 0):max(xf0 + 22, 0),
273       :
274     ] = 0.
275
276     out_image1[
277       max(yf0 - 21, 0):max(yf0 + 22, 0),
278       max(xf0 - 1, 0):max(xf0 + 2, 0),
279       :
280     ] = 0.
281     out_image1[
282       max(yf0 - 1, 0):max(yf0 + 2, 0),
283       max(xf0 - 21, 0):max(xf0 + 22, 0),
284       :
285     ] = 0.
286
287 sys.stderr.write(f'write {out_jpg0:s}\n')
288 gamma.write_image(out_jpg0, image0)
289
290 sys.stderr.write(f'write {out_jpg1:s}\n')
291 gamma.write_image(out_jpg1, out_image1)