- 取得連結
- X
- 以電子郵件傳送
- 其他應用程式
ML:基礎學習
課程:機器學習基石
簡介:第二講 Learning to Answer Yes/No
忽略
完 美
是否越來越靠近 ?
也就是 隨著 更新,是否越來越大?
以上證明,只得出一半,內積的變大也有可能是 長度越來越長導致的
若 只在犯錯才更新
會慢慢成長,即使更新的是最長的 ,因 不影響長度
或是從以下公式也可看出,會持續更新直到上限
初始於 ,經過 次的更新後
證明如下
若 完美
若 只在犯錯才更新
更新的方法
針對
針對
代入得證:
因為 ,內積最大也只是長度的相乘
最 遠 的 半 徑 最 接 近 分 隔 線 的 與 法 向 量 的 內 積
但為 NP-hard 需花費大量時間解,故使用 Pocket Algorithm (PLA 變形)
課程:機器學習基石
簡介:第二講 Learning to Answer Yes/No
PLA (Perceptron Learning Algorithm)
- 選擇
,可設為 0 - For
-
找到 的下一個錯誤 (可用依序 or 亂數 or 其他選擇點的方法) - 修正錯誤
- 直到所有點皆無錯誤
基本定義
忽略
更新原理
- 當
時,但 卻為負,則將 靠近 - 當
時,但 卻為正,則將 遠離
必要條件
資料需是線性可分,也就是可被分為各一半
線性可分的 存在一完美 使得 證明
也就是
若
或是從以下公式也可看出,會持續更新直到上限
初始於
若
PLA 程式碼
Python 原始碼
- import matplotlib.pyplot as plt
- import numpy as np
- import operator
- # 網路上找的 dataset 可以線性分割
- rawData = [
- ((-0.4, 0.3), -1),
- ((-0.3, -0.1), -1),
- ((-0.2, 0.4), -1),
- ((-0.1, 0.1), -1),
- ((0.9, -0.5), 1),
- ((0.7, -0.9), 1),
- ((0.8, 0.2), 1),
- ((0.4, -0.6), 1)]
- # 加入 x0
- dataset = [((1,) + x, y) for x, y in rawData]
- # 內積
- def dot(*v):
- return sum(map(operator.mul, *v))
- # 取 sign (1, 0, -1)
- def sign(v):
- if v > 0:
- return 1
- elif v == 0:
- return 0
- else: # v < 0
- return -1
- # 判斷有沒有分類錯誤
- def check_error(w, x, y):
- if sign(dot(w, x)) != y:
- return True
- else:
- return False
- # 更新 w
- def update(w, x, y):
- u = map(operator.mul, [y] * len(x), x)
- w = map(operator.add, w, u)
- return list(w)
- # PLA演算法實作
- def pla(dataset):
- # 初始化 w
- w = [0] * 3
- t = 0
- while True:
- print("{}: {}".format(t, tuple(w)))
- no_error = True
- for x, y in dataset:
- if check_error(w, x, y):
- w = update(w, x, y)
- no_error = False
- break
- t += 1
- if no_error:
- break
- return w
- # 主程式
- def main():
- # 執行
- w = pla(dataset)
- # 畫圖
- fig = plt.figure()
- # numrows=1, numcols=1, fignum=1
- ax1 = fig.add_subplot(111)
- xx = list(filter(lambda d: d[1] == -1, dataset))
- ax1.scatter([x[0][1] for x in xx], [x[0][2] for x in xx],
- s=100, c='b', marker="x", label='-1')
- oo = list(filter(lambda d: d[1] == 1, dataset))
- ax1.scatter([x[0][1] for x in oo], [x[0][2] for x in oo],
- s=100, c='r', marker="o", label='1')
- l = np.linspace(-2, 2)
- # w0 + w1x + w2y = 0
- # y = -w0/w2 - w1/w2 x
- if w[2]:
- a, b = -w[1] / w[2], -w[0] / w[2]
- ax1.plot(l, a * l + b, 'b-')
- else:
- ax1.plot([-w[0] / w[1]] * len(l), l, 'b-')
- plt.legend(loc='upper left', scatterpoints=1)
- plt.show()
- if __name__ == '__main__':
- main()
PLA 缺點
- 無法得知
是否線性可分 - 即使線性可分,也無法得知要跑多久,因為需要
若資料不為線性可分 或 有雜訊
找到錯誤最少的線Pocket Algorithm (PLA 變形)
- 初始化
- For
-
找到 的下一個錯誤 (亂數選擇較佳) - 修正錯誤
- 假如
的錯誤比 少,則將 設為 - 直到足夠的次數 或 小於設定的錯誤門檻
Pocket Algorithm 程式碼
Python 原始碼
- import matplotlib.pyplot as plt
- import numpy as np
- import operator
- import random
- # 網路上找的 dataset 再加上亂數的 data 不可以線性分割
- rawData = [
- ((-0.4, 0.3), -1),
- ((-0.3, -0.1), -1),
- ((-0.2, 0.4), -1),
- ((-0.1, 0.1), -1),
- ((0.9, -0.5), 1),
- ((0.7, -0.9), 1),
- ((0.8, 0.2), 1),
- ((0.4, -0.6), 1),
- ((0.2, 0.6), -1),
- ((-0.5, -0.5), -1),
- ((0.7, 0.3), 1),
- ((0.9, -0.6), 1),
- ((-0.1, 0.2), -1),
- ((0.3, -0.6), 1),
- ((0.5, 0.1), -1), ]
- # 加入 x0
- dataset = [((1,) + x, y) for x, y in rawData]
- # 內積
- def dot(*v):
- return sum(map(operator.mul, *v))
- # 取 sign (1, 0, -1)
- def sign(v):
- if v > 0:
- return 1
- elif v == 0:
- return 0
- else: # v < 0
- return -1
- # 判斷有沒有分類錯誤
- def check_error(w, x, y):
- if sign(dot(w, x)) != y:
- return True
- else:
- return False
- # 更新 w
- def update(w, x, y):
- u = map(operator.mul, [y] * len(x), x)
- w = map(operator.add, w, u)
- return list(w)
- # 總錯誤數
- def sum_errors(w, dataset):
- errors = 0
- for x, y in dataset:
- if check_error(w, x, y):
- errors += 1
- return errors
- # POCKET 演算法實作
- def pocket(dataset):
- # 初始化 w
- w = [0] * 3
- min_e = sum_errors(w, dataset)
- max_t = 500
- for t in range(0, max_t):
- wt = None
- et = None
- while True:
- x, y = random.choice(dataset)
- if check_error(w, x, y):
- wt = update(w, x, y)
- et = sum_errors(wt, dataset)
- break
- if et < min_e:
- w = wt
- min_e = et
- print("{}: {}".format(t, tuple(w)))
- print("min erros: {}".format(min_e))
- t += 1
- if min_e == 0:
- break
- return (w, min_e)
- # 主程式
- def main():
- # 執行,並輸入新的 list
- w, e = pocket(list(dataset))
- # 畫圖
- fig = plt.figure()
- # numrows=1, numcols=1, fignum=1
- ax1 = fig.add_subplot(111)
- xx = list(filter(lambda d: d[1] == -1, dataset))
- ax1.scatter([x[0][1] for x in xx], [x[0][2] for x in xx],
- s=100, c='b', marker="x", label='-1')
- oo = list(filter(lambda d: d[1] == 1, dataset))
- ax1.scatter([x[0][1] for x in oo], [x[0][2] for x in oo],
- s=100, c='r', marker="o", label='1')
- l = np.linspace(-2, 2)
- # w0 + w1x + w2y = 0
- # y = -w0/w2 - w1/w2 x
- if w[2]:
- a, b = -w[1] / w[2], -w[0] / w[2]
- ax1.plot(l, a * l + b, 'b-')
- else:
- ax1.plot([-w[0] / w[1]] * len(l), l, 'b-')
- plt.legend(loc='upper left', scatterpoints=1)
- plt.show()
- if __name__ == '__main__':
- main()
假設資料為線性可分,卻跑 Pocket Algorithm 的缺點
- Pocket Algorithm 比 PLA 慢
- 花力氣存檔,把最佳解存起來
- 要檢查所有資料的錯誤才能比較
和 - 但跑夠久,仍可得到
留言
張貼留言