-Python- ロジスティック回帰

二値分類に使われるアルゴリズムであるロジスティック回帰のプログラム例を以下に示します.

import numpy as np
from scipy import linalg
 
# Lower clip bound for the per-sample weights yhat * (1 - yhat) computed in
# LogisticRegression.fit, so that the 1 / r term used there stays finite.
THRESHMIN = 1e-10
 
 
def sigmoid(x):
    """Return the logistic function 1 / (1 + exp(-x)), element-wise.

    Parameters
    ----------
    x : float or numpy.ndarray
        Input value(s).

    Returns
    -------
    numpy.float64 or numpy.ndarray
        Values in [0, 1] with the same shape as ``x``.
    """
    # 1 / (1 + exp(-x)) == exp(-log(1 + exp(-x))).  logaddexp evaluates
    # log(exp(0) + exp(-x)) without forming exp(-x) directly, so large
    # negative inputs no longer overflow (and no RuntimeWarning is emitted).
    return np.exp(-np.logaddexp(0, -x))
 
 
class LogisticRegression:
    """Binary logistic regression fitted by iteratively reweighted least squares.

    Parameters
    ----------
    tol : float
        Iteration stops once the mean absolute change of the weights between
        two consecutive updates is no larger than this value.
    max_iter : int
        Upper bound on the number of weight updates.
    random_seed : int
        Seed of the ``numpy.random.RandomState`` that draws the initial weights.

    Attributes
    ----------
    w_ : numpy.ndarray or None
        Weight vector; element 0 multiplies the constant column (intercept).
        ``None`` until ``fit`` has been called.
    """

    def __init__(self, tol=0.001, max_iter=3, random_seed=0):
        self.tol = tol
        self.max_iter = max_iter
        self.random_state = np.random.RandomState(random_seed)
        self.w_ = None

    def fit(self, X, y):
        """Estimate the weights from training data.

        Parameters
        ----------
        X : numpy.ndarray
            2-D array with one row per sample and one column per feature.
        y : numpy.ndarray
            1-D array of targets, one per row of ``X``.
        """
        self.w_ = self.random_state.randn(X.shape[1] + 1)
        # Prepend a column of ones so that w_[0] acts as the intercept.
        Xtil = np.c_[np.ones(X.shape[0]), X]
        diff = np.inf
        n_iter = 0
        while diff > self.tol and n_iter < self.max_iter:
            z = np.dot(Xtil, self.w_)
            yhat = sigmoid(z)
            # Per-sample weights; clipped from below so 1 / r stays finite.
            r = np.clip(yhat * (1 - yhat), THRESHMIN, np.inf)
            XR = Xtil.T * r
            XRX = np.dot(XR, Xtil)
            # Weighted least-squares problem: (X^T R X) w = X^T R (Xw - R^-1 (yhat - y))
            b = np.dot(XR, z - 1 / r * (yhat - y))
            w_prev = self.w_
            self.w_ = linalg.solve(XRX, b)
            diff = np.abs(w_prev - self.w_).mean()
            n_iter += 1

    def predict(self, X):
        """Return the predicted class (0 or 1) for every row of ``X``.

        Parameters
        ----------
        X : numpy.ndarray
            2-D array with the same number of columns as the data given to
            ``fit``.

        Returns
        -------
        numpy.ndarray
            1 where the estimated probability exceeds 0.5, otherwise 0.

        Raises
        ------
        RuntimeError
            If called before ``fit``.
        """
        if self.w_ is None:
            raise RuntimeError("fit() must be called before predict()")
        Xtil = np.c_[np.ones(X.shape[0]), X]
        yhat = sigmoid(np.dot(Xtil, self.w_))
        return np.where(yhat > .5, 1, 0)
 
上記のプログラムを logisticreg.py という名前で保存し,UCIリポジトリのBreast Cancer Wisconsin (Diagnostic) Data Setに適用してみます.こちらのURL(UCIリポジトリの配布ページ)から wdbc.data をダウンロードして,プログラムと同じフォルダに置いておきます.
 
import logisticreg
import csv
import numpy as np
 
# Number of samples (taken from the end of the file) held out for testing.
n_test = 100
X = []
y = []
# newline="" lets the csv module handle line endings itself.
with open("wdbc.data", newline="") as fp:
    for row in csv.reader(fp):
        if not row:
            # Skip blank lines (e.g. a trailing newline at the end of the file).
            continue
        # Column 1 holds the diagnosis label: "B" -> 0, anything else -> 1.
        y.append(0 if row[1] == "B" else 1)
        # Columns 2 onward are the numeric features (column 0 is not used).
        X.append(row[2:])

y = np.array(y, dtype=np.float64)
X = np.array(X, dtype=np.float64)
# The last n_test rows form the test set; everything before is training data.
y_train = y[:-n_test]
X_train = X[:-n_test]
y_test = y[-n_test:]
X_test = X[-n_test:]
model = logisticreg.LogisticRegression(tol=0.01)
model.fit(X_train, y_train)

y_predict = model.predict(X_test)
n_hits = (y_test == y_predict).sum()
print("Accuracy: {}/{} = {}".format(n_hits, n_test, n_hits / n_test))
 

実行結果は以下のようになります.

Accuracy: 97/100 = 0.97