[Python] OpenCV matchTemplate

程式語言:Python
Package:opencv-python
官方網址
OpenCV-Python Tutorials

功能:找出圖片中類似的東西

演算法

  1. 輸入兩張影像,分別為 image、template
  2. 不斷滑動 template,得到 image 上各個位置的比較值,比較值代表相似程度
    然後將 image 左上角位置,作為 result 比較值的存放位置
  3. 完成後可得到 result
    可用 minMaxLoc() 函式,找出結果圖的最大或最小值,定位出搜尋位置

限制

  • 物體有旋轉時,會找不到
  • 物體大小改變時,會找不到

result = cv2.matchTemplate(image, templ, method[, result])

  • 參數
    • image
      • 被尋找的圖片
      • 必須為 8-bit or 32-bit
    • templ
      • 尋找的物品圖片
      • size 不能大於 image,且格式需一致
    • method
      • 比對的方法
    • result
      • 比較的結果,格式為 numpy.ndarray (dtype=float32)
      • 可傳入想儲存結果的 array
      • 因 image 大小為 W×H 且 templ 為 w×h ,所以大小為 (Ww+1)×(Hh+1)

比對方法

I 表示 image,T 表示 template,R 表示 result
  • CV_TM_SQDIFF
    • 平方差,越小越相似
    • R(x,y)=x,y(T(x,y)I(x+x,y+y))2
  • CV_TM_SQDIFF_NORMED
    • 正規化平方差,越小越相似
    • 保證當 pixel 亮度都乘上同一係數時,相似度不變
    • R(x,y)=x,y(T(x,y)I(x+x,y+y))2x,yT(x,y)2x,yI(x+x,y+y)2
  • CV_TM_CCORR
    • 相關係數,越大越相似
    • R(x,y)=x,y(T(x,y)I(x+x,y+y))
  • CV_TM_CCORR_NORMED
    • 正規化相關係數,越大越相似
    • 保證當 pixel 亮度都乘上同一係數時,相似度不變
    • R(x,y)=x,y(T(x,y)I(x+x,y+y))x,yT(x,y)2x,yI(x+x,y+y)2
  • CV_TM_CCOEFF
    • 去掉直流成份的相關係數,越大越相似
    • R(x,y)=x,y(T(x,y)I(x+x,y+y))
      where
      T(x,y)=T(x,y)1/(wh)x,yT(x,y)I(x+x,y+y)=I(x+x,y+y)1/(wh)x,yI(x+x,y+y)
  • CV_TM_CCOEFF_NORMED
    • 正規化 去掉直流成份的相關係數
    • 保證當 pixel 亮度都乘上同一係數時,相似度不變
    • 計算出的相關係數被限制在了 -1 到 1 之間
      • 1 表示完全相同
      • -1 表示亮度正好相反
      • 0 表示没有線性相關
    • R(x,y)=x,y(T(x,y)I(x+x,y+y))x,yT(x,y)2x,yI(x+x,y+y)2
      where
      T(x,y)=T(x,y)1/(wh)x,yT(x,y)I(x+x,y+y)=I(x+x,y+y)1/(wh)x,yI(x+x,y+y)

範例

尋找特定物品
程式碼
  1. import cv2
  2. import numpy as np
  3. from matplotlib import pyplot as plt
  4.  
  5. img = cv2.imread('image.png')
  6. img2 = img.copy()
  7. template = cv2.imread('template.jpg')
  8. w = template.shape[1]
  9. h = template.shape[0]
  10.  
  11. # All the 6 methods for comparison in a list
  12. methods = ['cv2.TM_CCOEFF', 'cv2.TM_CCOEFF_NORMED', 'cv2.TM_CCORR',
  13. 'cv2.TM_CCORR_NORMED', 'cv2.TM_SQDIFF', 'cv2.TM_SQDIFF_NORMED']
  14.  
  15. for meth in methods:
  16. img = img2.copy()
  17. method = eval(meth)
  18.  
  19. # Apply template Matching
  20. res = cv2.matchTemplate(img,template,method)
  21. min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(res)
  22.  
  23. # If the method is TM_SQDIFF or TM_SQDIFF_NORMED, take minimum
  24. if method in [cv2.TM_SQDIFF, cv2.TM_SQDIFF_NORMED]:
  25. top_left = min_loc
  26. else:
  27. top_left = max_loc
  28. bottom_right = (top_left[0] + w, top_left[1] + h)
  29.  
  30. cv2.rectangle(img,top_left, bottom_right, 255, 2)
  31.  
  32. plt.subplot(121),plt.imshow(res,cmap = 'gray')
  33. plt.title('Matching Result'), plt.xticks([]), plt.yticks([])
  34. plt.subplot(122),plt.imshow(img)
  35. plt.title('Detected Point'), plt.xticks([]), plt.yticks([])
  36. plt.suptitle(meth)
  37.  
  38. plt.show()
結果
標示所有相同物品
程式碼
  1. import cv2
  2. import numpy as np
  3. from matplotlib import pyplot as plt
  4.  
  5. def getPoints(img, template, threshold, method=cv2.TM_CCOEFF_NORMED):
  6. result = cv2.matchTemplate(img, template, method)
  7.  
  8. if method in [cv2.TM_SQDIFF, cv2.TM_SQDIFF_NORMED]:
  9. loc = np.where(result <= threshold) #回傳為 y, x
  10. else:
  11. loc = np.where(result >= threshold) #回傳為 y, x
  12. pts = zip(*loc[::-1])
  13. return removeSame(pts, min(template.shape[0], template.shape[1]))
  14.  
  15. # 去掉太過接近的座標
  16. def removeSame(pts, threshold):
  17. elements = []
  18. for x,y in pts:
  19. for ele in elements:
  20. if ((x-ele[0])**2 + (y-ele[1])**2) < threshold**2:
  21. break
  22. else:
  23. elements.append((x,y))
  24. return elements
  25. def drawRectangle(img, pts, w, h, color=(0,0,255), lineW=2):
  26. for pt in pts:
  27. cv2.rectangle(img, pt, (pt[0] + w, pt[1] + h), color, lineW) #因後面會反向顏色順序,所以這邊顯示紅色要反向
  28.  
  29. if __name__ == '__main__':
  30. img = cv2.imread('image.png')
  31. img_result = img.copy()
  32. template = cv2.imread('template.png')
  33. w = template.shape[1]
  34. h = template.shape[0]
  35. threshold = 0.4
  36. elements = getPoints(img, template, threshold)
  37. total_n = len(elements)
  38. drawRectangle(img_result, elements, w=template.shape[1], h=template.shape[0])
  39.  
  40. # 逆時針旋轉 90
  41. template_r90 = cv2.transpose(template)
  42. cv2.flip(template_r90, 0)
  43. pts1 = getPoints(img, template_r90, threshold)
  44. # 順時針旋轉 90
  45. template_r90 = cv2.transpose(template)
  46. cv2.flip(template_r90, 1)
  47. pts2 = getPoints(img, template_r90, threshold)
  48. elements = removeSame(pts1+pts2, w)
  49. total_90 = len(elements)
  50. drawRectangle(img_result, elements, w=template_r90.shape[1], h=template_r90.shape[0], color=(0,255,0))
  51. plt.subplot(121)
  52. plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) #因 openCV 記錄的顏色順序是反過來的
  53. plt.title('Original'), plt.xticks([]), plt.yticks([])
  54. plt.subplot(122)
  55. plt.imshow(cv2.cvtColor(img_result, cv2.COLOR_BGR2RGB)) #因 openCV 記錄的顏色順序是反過來的
  56. plt.title('Detected Point'), plt.xticks([]), plt.yticks([])
  57. plt.suptitle('Total:{}\nnomral:{}, 90:{}'.format(total_n+total_90, total_n, total_90))
  58.  
  59. plt.show()
  60. # cv2.imwrite('result.png', img_result)

參考

Template Matching
Template Matching

留言