计算的 Python 代码如下：
def _as_3xN(points, label):
    """Return *points* laid out as 3xN, transposing it when it has != 3 rows."""
    if points.ndim != 2:
        raise ValueError(
            f"matrix {label} must be 2-D, got shape {points.shape}"
        )
    if points.shape[0] != 3:
        points = points.transpose()
    num_rows, num_cols = points.shape
    if num_rows != 3:
        raise ValueError(
            f"matrix {label} is not 3xN, it is {num_rows}x{num_cols}"
        )
    return points


def rigid_transform_3D(A, B, scale=False):
    """Estimate the rotation and translation that map point set A onto B.

    The rotation is obtained from the SVD of the cross-covariance of the
    centred point sets; a solution with negative determinant (a reflection)
    is corrected by flipping the last row of ``Vt``.

    Args:
        A: Array-like of shape 3xN or Nx3. An input whose first dimension
            is not 3 is transposed; a 3x3 input is therefore taken as-is,
            with each column treated as one point.
        B: Array-like with the same shape as ``A``; column ``i`` of B
            corresponds to column ``i`` of A (after any transposition).
        scale: When truthy, an isotropic scale ``s = std(B - mean(B)) /
            std(A - mean(A))`` is estimated and applied to A before the
            rotation and translation are solved.

    Returns:
        ``(R, t)`` when ``scale`` is falsy, with ``R`` of shape 3x3 and ``t``
        of shape 3x1 such that ``B ~= R @ A + t``.
        ``(R, t, s)`` when ``scale`` is truthy, such that
        ``B ~= R @ (s * A) + t``. Three values are always returned in this
        mode, including when the estimated ``s`` is 0.

    Raises:
        AssertionError: If ``A`` and ``B`` do not have the same shape.
        ValueError: If an input is not 2-D, cannot be laid out as 3xN,
            contains no points, or (with ``scale``) A has zero or
            non-finite spread so that no scale can be estimated.
    """
    A = np.asarray(A)
    B = np.asarray(B)

    # Raised explicitly (instead of via ``assert``) so the check survives
    # ``python -O``; the exception type is kept for existing callers.
    if A.shape != B.shape:
        raise AssertionError(
            f"A and B must have the same shape, got {A.shape} and {B.shape}"
        )

    A = _as_3xN(A, "A")
    B = _as_3xN(B, "B")
    if A.shape[1] == 0:
        raise ValueError("at least one point is required")

    # Centroids as 3x1 columns so they broadcast across the points.
    centroid_A = np.mean(A, axis=1).reshape(-1, 1)
    centroid_B = np.mean(B, axis=1).reshape(-1, 1)
    Am = A - centroid_A
    Bm = B - centroid_B

    # Remember the caller's request separately: the return arity must depend
    # on the flag, not on the numeric value of the estimated scale.
    want_scale = bool(scale)
    if want_scale:
        spread_A = np.std(Am)
        # ``not (x > 0)`` also rejects NaN.
        if not spread_A > 0:
            raise ValueError(
                "cannot estimate scale: spread of A is zero or not finite"
            )
        scale = np.std(Bm) / spread_A
        # Solve for R and t against the scaled copy of A.
        A = A * scale
        centroid_A = np.mean(A, axis=1).reshape(-1, 1)
        Am = A - centroid_A

    # Cross-covariance of the centred point sets.
    H = Am @ np.transpose(Bm)

    U, _, Vt = np.linalg.svd(H)
    R = Vt.T @ U.T

    # A negative determinant means the SVD solution is a reflection.
    if np.linalg.det(R) < 0:
        print("det(R) < 0, reflection detected!, correcting for it ...")
        Vt[2, :] *= -1
        R = Vt.T @ U.T

    t = -R @ centroid_A + centroid_B

    if want_scale:
        return R, t, scale
    return R, t