[ML] 機器學習技法:第九講 Decision Tree

ML:基礎技法學習
Package:scikit-learn
課程:機器學習技法
簡介:第九講 Decision Tree

提要

  • 優點
    • 可輕易被解釋
    • 簡單實現
    • 有效率
  • 缺點
    • 理論支持薄弱
    • 參數無理論保證,難以設計
    • 無特定代表性的演算法,只有流不流行

基本算式

G(x)=c=1C[b(x)=c]Gc(x) G(x):表示整顆樹
b(x):表示分支的條件
Gc(x):表示 c 條件下分支的子樹
G(x)=t=1Tqt(x)gt(x) T:表示 leaf 的個數,如圖中共有六個 leaf
qt(x):表示滿足條件的 path,如紫色的部分
gt(x):表示到達 leaf 上最終執行的演算法
或者可以換個方向來看
一顆樹是由不同的分支子樹所組成的
G(x)=c=1C[b(x)=c]Gc(x)G(x) 表示整顆樹,依不同的分支條件 [b(x)=c],決定其分支的子樹 Gc(x)

基本演算法

  • function DecisionTree(D={(xn,yn)}n=1N
    • if (符合中止條件)
      • 回傳 gt(x)
    • else
      1. 學習  b(x)
      2. D 分割成 C 個部分 Dc={(xn,yn):b(xn)=c}
      3. 遞迴 Gc(x)=DecisionTree(Dc)
      4. 回傳 G(x)=c=1C[b(x)=c]Gc(x)

Classification and Regression Tree (C&RT)

來自於 CARTTM of California Statistical Software
  • function DecisionTree(D={(xn,yn)}n=1N
    • if (無法再切割) => 故為 fully-grown tree
      • 回傳 gt(x)=Einoptimal constant
    • else
      1. 學習 b(x)
        b(x)=argmindecision stumps h(x)c=12|Dc with h|impurity(Dc with h)
        • classification 使用 Gini index
          impurity(D)=1k=1K(n=1N[yn=k]N)2
        • regression
          impurity(D)=1Nn=1N(yny¯)2
          y¯{yn} 的平均
      2. D 分割成 2 個部分 Dc={(xn,yn):b(xn)=c}
      3. 遞迴 Gc(x)=DecisionTree(Dc)
      4. 回傳 G(x)=c=12[b(x)=c]Gc(x)
基於基本的 decisionTree 的演算法
首先決定樹的種類,為二元樹,所以 C=2
決定 leaf 上回傳的 gt(x) 是一個最佳化的常數
  • classification => 回傳 {yn} 最多數的類別
  • regression => 回傳 {yn} 的平均值
C=2 如何切兩半呢?可利用 decision stump,這裡比較不同的是,值為 {1,2},而不是 {0,1}
那麼怎樣才是好的切割?因 gt(x) 的決定方式,希望此時的 Dc 越純越好,這樣 Ein 才會低
b(x)=argmindecision stumps h(x)c=12impurity(Dc with h) 但 impurity function 該用什麼才能表示純度呢?
首先可考慮 Ein 的大小,越小越純
  • regression
    • 常用的 square error,y¯{yn} 的平均
      impurity(D)=1Nn=1N(yny¯)2
  • classification
    • 分類錯誤 impurity(D)=1Nn=1N[yny]y 為多數的 {yn}
    • 或者換個方式寫 impurity(D)=1max1kKn=1N[yn=k]Nk=y
    • 但以上兩者的缺點就是只考慮單一項,未考慮整體的分佈是否有集中在某項上
      故通常使用 Gini index impurity(D)=1k=1K(n=1N[yn=k]N)2
但資料切割後,大小不一樣,理論上越大的希望越純,越小的則較不重要
故加上一加權值 |Dc with h| 為資料的大小,也許可使用百分比
b(x)=argmindecision stumps h(x)c=12|Dc with h|impurity(Dc with h) 那麼何時可以終止呢?切到無法再切為止
所以滿足以下其中之一,即可中止
  • yn 都一樣時,表示 impurity = 0 => gt(x)=yn
  • xn 都一樣時,無法再執行 decision stumps

關於 Decision Tree

  • Regularization by Pruning
    • xn 都不一樣時,fully-grown tree 雖然可以 Ein(G)=0
      但這也表示很容易 overfit,因為越接近 leaf 的資料 Dc 越少,訓練出來的越有問題
    • regularized decision tree,通常也叫做 pruned decision tree
      利用此式子,Ω(G) 為 leaf 的個數
      argminall possible GEin(G)+λΩ(G)
    • 實際上無法考慮所有的 G,故實務上只針對訓練出來的 G 做處理
      • G(0) = fully-grown tree
      • G(i) = argminG Ein(G)
        G(i1) 移除一個 leaf 的所有選擇中,最好 Ein(G)G
      • 移除 leaf 的意義即是跟隔壁資料做合併,重新計算 gt(x)
      • 將全部的 G(i) 代入 regularizer decision tree,使用 Validation 得到合適的 λ
        或者直接用此得到合適的 G
  • Branching on Categorical Features
    • numerical feature,像是血壓
      b(x)=[xiθ]+1θR
    • 若換到 Categorical Features 也可輕易應用,像是症狀
      b(x)=[xiS]+1S1,2,,K
    • decision subset,某幾類走左邊,其餘走右邊,由 S 決定
  • Missing Features by Surrogate Branch
    • 可適用缺少資料的時候
    • 訓練時,同時記錄替代的 feature,令其分類跟原本的最佳解差不多
      替代的  branch b1(x),b2(x),best branch b(x)

C&RT 特點

  • 與 AdaBoost 比較
    • 只針對特定區域做切割,並非是整區切割
    • 較有效率,t=90 已完成,但 AdaBoost 仍在微調
  • 特點
    • 容易解釋
    • multiclass 實現簡單
    • categorical features 應用簡單
    • missing features 處理簡單
    • non-linear training (and testing) 有效率
  • 與之對應
    • C4.5 演算法,在關鍵點的處理不同而已

比較

  • blending
    • 適用已知 gtgt 來自不同的 H,只是單純合併
  • learning 
    • 適用未知 gtgt 來自同樣的 H,邊學邊決定

程式碼

Overfitting 需注意
  1. import numpy as np
  2. from sklearn.tree import DecisionTreeRegressor
  3. import matplotlib.pyplot as plt
  4.  
  5. # 建立訓練資料
  6. rng = np.random.RandomState(1)
  7. X = np.sort(5 * rng.rand(80, 1), axis=0)
  8. y = np.sin(X).ravel()
  9. # 每五個加入 noise
  10. y[::5] += 3 * (0.5 - rng.rand(16))
  11.  
  12. # 訓練模型
  13. estimators = [("max_depth=max", DecisionTreeRegressor(), "slategray"),
  14. ("max_depth=5", DecisionTreeRegressor(max_depth=5), "yellowgreen"),
  15. ("max_depth=2", DecisionTreeRegressor(max_depth=2), "cornflowerblue"),]
  16. size = len(estimators)
  17. # 預測資料
  18. X_test = np.arange(0.0, 5.0, 0.01)[:, None]
  19.  
  20. # 畫圖
  21. plt.figure()
  22.  
  23. for i, (label, clf, c) in enumerate(estimators):
  24. # 訓練
  25. clf.fit(X, y)
  26. # 預測
  27. yp = clf.predict(X_test)
  28. # 畫預測的線
  29. plt.plot(X_test, yp, label=label, linewidth=size-i, c=c)
  30.  
  31. # 畫出原始資料
  32. plt.scatter(X, y, c="darkorange", label="data")
  33. # 標示 x, y
  34. plt.xlabel("data")
  35. plt.ylabel("target")
  36. # 設定 title
  37. plt.title("Decision Tree Regression")
  38. plt.legend()
  39. plt.show()

參考

1.10. Decision Trees
sklearn.tree.DecisionTreeClassifier
sklearn.tree.DecisionTreeRegressor

留言