[ML] 機器學習基石:第二講 Learning to Answer Yes/No

ML:基礎學習
課程:機器學習基石
簡介:第二講 Learning to Answer Yes/No

PLA (Perceptron Learning Algorithm)

  1. 選擇 w0,可設為 0
  2. For t=0,1,
    1. (xn(t),yn(t)) 找到 wt 的下一個錯誤 (可用依序 or 亂數 or 其他選擇點的方法)
    2. sign[wtTxn(t)]yn(t)
    3. 修正錯誤
    4. wt+1wt+yn(t)xn(t)
    5. 直到所有點皆無錯誤


基本定義

x=(x1,x2,,xd)y={+1,1}hH
忽略 y=0
h(x)=sign[(i=1dwixi)threshold]=sign[(i=1dwixi)+(threshold)w0(+1)x0]=sign[(i=1dwixi)+w0x0]=sign[i=0dwixi]=sign[wTx]

更新原理

wt+1wt+yn(t)xn(t)
  • y=+1 時,但 wx 卻為負,則將 w 靠近 x
  • y=1 時,但 wx 卻為正,則將 w 遠離 x


必要條件

資料需是線性可分,也就是可被分為各一半
線性可分的 D 存在一完美 wf 使得 yn=sign(wfTxn) 證明
wf yn(t)wfTxn(t)minn ynwfTxn>0 wt 是否越來越靠近 wf
也就是 wfTwt 隨著 (xn(t),yn(t)) 更新,是否越來越大?
wfTwt+1=wfT(wt+yn(t)xn(t))wfTwt+minn ynwfTxn>wfTwt+0 以上證明,只得出一半,內積的變大也有可能是 wt 長度越來越長導致的
wt 只在犯錯才更新
sign(wtTxn(t))yn(t)yn(t)wtTxn(t)0 wt+12=wt+yn(t)xn(t)2=wt2+2yn(t)wtTxn(t)+yn(t)xn(t)2wt2+0+yn(t)xn(t)2wt2+maxn ynxn2 wt2 會慢慢成長,即使更新的是最長的 xn2,因 yn 不影響長度
或是從以下公式也可看出,會持續更新直到上限
初始於 w0=0,經過 T 次的更新後
wfTwTwfwTTconstant 證明如下
wf 完美 yn(t)wfTxn(t)minn ynwfTxn>0(1) wt 只在犯錯才更新
sign(wtTxn(t))yn(t)yn(t)wtTxn(t)0(2) 更新的方法
wt=wt1+yn(t)xn(t)(3) 針對 wt wfTwt=wfT(wt1+yn(t1)xn(t1))by (3)wfTwt1+minn ynwfTxnby (1)w0+Tminn ynwfTxnby T timesTminn ynwfTxn 針對 wt wt2=wt1+yn(t1)xn(t1)2by (3)=wt12+2yn(t1)wt1Txn(t1)+yn(t1)xn(t1)2wt12+0+yn(t1)xn(t1)2by (2)wt12+maxn xn2w02+Tmaxn xn2=Tmaxn xn2 代入得證: wfTwtwfwtTminn ynwfTxnwfwtTminn ynwfTxnwfTmaxn xnTminn ynwfTxnwfmaxn xn=Tconstant 因為 wfTwtwfwt1 ,內積最大也只是長度的相乘 Tconstant1Tminn ynwfTxnwfmaxn xn1Twf2maxn xn2minn2 ynwfTxn=R2ρ2 R=maxn xn ()ρ=minn ynwfTwfxn (xn)


PLA 程式碼

Python 原始碼
  1. import matplotlib.pyplot as plt
  2. import numpy as np
  3. import operator
  4.  
  5. # 網路上找的 dataset 可以線性分割
  6. rawData = [
  7. ((-0.4, 0.3), -1),
  8. ((-0.3, -0.1), -1),
  9. ((-0.2, 0.4), -1),
  10. ((-0.1, 0.1), -1),
  11. ((0.9, -0.5), 1),
  12. ((0.7, -0.9), 1),
  13. ((0.8, 0.2), 1),
  14. ((0.4, -0.6), 1)]
  15.  
  16. # 加入 x0
  17. dataset = [((1,) + x, y) for x, y in rawData]
  18.  
  19.  
  20. # 內積
  21. def dot(*v):
  22. return sum(map(operator.mul, *v))
  23.  
  24.  
  25. # 取 sign (1, 0, -1)
  26. def sign(v):
  27. if v > 0:
  28. return 1
  29. elif v == 0:
  30. return 0
  31. else: # v < 0
  32. return -1
  33.  
  34.  
  35. # 判斷有沒有分類錯誤
  36. def check_error(w, x, y):
  37. if sign(dot(w, x)) != y:
  38. return True
  39. else:
  40. return False
  41.  
  42.  
  43. # 更新 w
  44. def update(w, x, y):
  45. u = map(operator.mul, [y] * len(x), x)
  46. w = map(operator.add, w, u)
  47. return list(w)
  48.  
  49.  
  50. # PLA演算法實作
  51. def pla(dataset):
  52. # 初始化 w
  53. w = [0] * 3
  54.  
  55. t = 0
  56. while True:
  57. print("{}: {}".format(t, tuple(w)))
  58. no_error = True
  59. for x, y in dataset:
  60. if check_error(w, x, y):
  61. w = update(w, x, y)
  62. no_error = False
  63. break
  64. t += 1
  65.  
  66. if no_error:
  67. break
  68.  
  69. return w
  70.  
  71.  
  72. # 主程式
  73. def main():
  74. # 執行
  75. w = pla(dataset)
  76.  
  77. # 畫圖
  78. fig = plt.figure()
  79.  
  80. # numrows=1, numcols=1, fignum=1
  81. ax1 = fig.add_subplot(111)
  82.  
  83. xx = list(filter(lambda d: d[1] == -1, dataset))
  84. ax1.scatter([x[0][1] for x in xx], [x[0][2] for x in xx],
  85. s=100, c='b', marker="x", label='-1')
  86. oo = list(filter(lambda d: d[1] == 1, dataset))
  87. ax1.scatter([x[0][1] for x in oo], [x[0][2] for x in oo],
  88. s=100, c='r', marker="o", label='1')
  89. l = np.linspace(-2, 2)
  90.  
  91. # w0 + w1x + w2y = 0
  92. # y = -w0/w2 - w1/w2 x
  93. if w[2]:
  94. a, b = -w[1] / w[2], -w[0] / w[2]
  95. ax1.plot(l, a * l + b, 'b-')
  96. else:
  97. ax1.plot([-w[0] / w[1]] * len(l), l, 'b-')
  98.  
  99. plt.legend(loc='upper left', scatterpoints=1)
  100. plt.show()
  101.  
  102.  
  103. if __name__ == '__main__':
  104. main()
  105.  


PLA 缺點

  • 無法得知 D 是否線性可分
  • 即使線性可分,也無法得知要跑多久,因為需要 wf


若資料不為線性可分 或 有雜訊

找到錯誤最少的線
wgargminwn=1N[ynsign(wTxn)] 但為 NP-hard 需花費大量時間解,故使用 Pocket Algorithm (PLA 變形)


Pocket Algorithm (PLA 變形)

  1. 初始化 w^
  2. For t=0,1,
    1. (xn(t),yn(t)) 找到 wt 的下一個錯誤 (亂數選擇較佳)
    2. sign[wtTxn(t)]yn(t)
    3. 修正錯誤
    4. wt+1wt+yn(t)xn(t)
    5. 假如 wt+1 的錯誤比 w^ 少,則將 wt+1 設為 w^
    6. 直到足夠的次數 或 小於設定的錯誤門檻


Pocket Algorithm 程式碼

Python 原始碼
  1. import matplotlib.pyplot as plt
  2. import numpy as np
  3. import operator
  4. import random
  5.  
  6. # 網路上找的 dataset 再加上亂數的 data 不可以線性分割
  7. rawData = [
  8. ((-0.4, 0.3), -1),
  9. ((-0.3, -0.1), -1),
  10. ((-0.2, 0.4), -1),
  11. ((-0.1, 0.1), -1),
  12. ((0.9, -0.5), 1),
  13. ((0.7, -0.9), 1),
  14. ((0.8, 0.2), 1),
  15. ((0.4, -0.6), 1),
  16. ((0.2, 0.6), -1),
  17. ((-0.5, -0.5), -1),
  18. ((0.7, 0.3), 1),
  19. ((0.9, -0.6), 1),
  20. ((-0.1, 0.2), -1),
  21. ((0.3, -0.6), 1),
  22. ((0.5, 0.1), -1), ]
  23.  
  24. # 加入 x0
  25. dataset = [((1,) + x, y) for x, y in rawData]
  26.  
  27.  
  28. # 內積
  29. def dot(*v):
  30. return sum(map(operator.mul, *v))
  31.  
  32.  
  33. # 取 sign (1, 0, -1)
  34. def sign(v):
  35. if v > 0:
  36. return 1
  37. elif v == 0:
  38. return 0
  39. else: # v < 0
  40. return -1
  41.  
  42.  
  43. # 判斷有沒有分類錯誤
  44. def check_error(w, x, y):
  45. if sign(dot(w, x)) != y:
  46. return True
  47. else:
  48. return False
  49.  
  50.  
  51. # 更新 w
  52. def update(w, x, y):
  53. u = map(operator.mul, [y] * len(x), x)
  54. w = map(operator.add, w, u)
  55. return list(w)
  56.  
  57.  
  58. # 總錯誤數
  59. def sum_errors(w, dataset):
  60. errors = 0
  61. for x, y in dataset:
  62. if check_error(w, x, y):
  63. errors += 1
  64.  
  65. return errors
  66.  
  67.  
  68. # POCKET 演算法實作
  69. def pocket(dataset):
  70. # 初始化 w
  71. w = [0] * 3
  72. min_e = sum_errors(w, dataset)
  73.  
  74. max_t = 500
  75. for t in range(0, max_t):
  76. wt = None
  77. et = None
  78.  
  79. while True:
  80. x, y = random.choice(dataset)
  81. if check_error(w, x, y):
  82. wt = update(w, x, y)
  83. et = sum_errors(wt, dataset)
  84. break
  85.  
  86. if et < min_e:
  87. w = wt
  88. min_e = et
  89.  
  90. print("{}: {}".format(t, tuple(w)))
  91. print("min erros: {}".format(min_e))
  92.  
  93. t += 1
  94.  
  95. if min_e == 0:
  96. break
  97.  
  98. return (w, min_e)
  99.  
  100.  
  101. # 主程式
  102. def main():
  103. # 執行,並輸入新的 list
  104. w, e = pocket(list(dataset))
  105.  
  106. # 畫圖
  107. fig = plt.figure()
  108.  
  109. # numrows=1, numcols=1, fignum=1
  110. ax1 = fig.add_subplot(111)
  111.  
  112. xx = list(filter(lambda d: d[1] == -1, dataset))
  113. ax1.scatter([x[0][1] for x in xx], [x[0][2] for x in xx],
  114. s=100, c='b', marker="x", label='-1')
  115. oo = list(filter(lambda d: d[1] == 1, dataset))
  116. ax1.scatter([x[0][1] for x in oo], [x[0][2] for x in oo],
  117. s=100, c='r', marker="o", label='1')
  118. l = np.linspace(-2, 2)
  119.  
  120. # w0 + w1x + w2y = 0
  121. # y = -w0/w2 - w1/w2 x
  122. if w[2]:
  123. a, b = -w[1] / w[2], -w[0] / w[2]
  124. ax1.plot(l, a * l + b, 'b-')
  125. else:
  126. ax1.plot([-w[0] / w[1]] * len(l), l, 'b-')
  127.  
  128. plt.legend(loc='upper left', scatterpoints=1)
  129. plt.show()
  130.  
  131.  
  132. if __name__ == '__main__':
  133. main()
  134.  


假設資料為線性可分,卻跑 Pocket Algorithm 的缺點

  • Pocket Algorithm 比 PLA 慢
    • 花力氣存檔,把最佳解存起來
    • 要檢查所有資料的錯誤才能比較 wt+1w^
  • 但跑夠久,仍可得到 wPOCKET=wPLA

留言