読者です 読者をやめる 読者になる 読者になる
無料で使えるシステムトレードフレームワーク「Jiji」 をリリースしました!

・OANDA Trade APIを利用した、オープンソースのシステムトレードフレームワークです。
・自分だけの取引アルゴリズムで、誰でも、いますぐ、かんたんに、自動取引を開始できます。

機械学習手習い: スパムフィルタを作る

「入門 機械学習」手習い、3日目。「3章 分類:スパムフィルタ」です。

www.amazon.co.jp

ナイーブベイズ分類器を作って、メールがスパムかどうかを判定するフィルタを作ります。

分類器の仕組み

  • 1) 以下の単語セットを作成
    • (a) スパムメッセージに出現しやすい単語とその出現確率
    • (b) スパムメッセージに出現しにくい単語とその出現確率
  • 2) で作成した単語セットを元に、メール本文を評価し、以下を算出
    • (a2) メールをスパムと仮定した時の尤もらしさ
    • (b2) メールを非スパムと仮定した時の尤もらしさ
  • 3) a2 > b2 となるメールをスパムと判定する

という感じで判定を行います。

必要なモジュールとデータの読み込み

> setwd("03-Classification/")
> library('tm')
> library('ggplot2')

# テスト用データ
# 分類機の訓練用
> spam.path <- file.path("data", "spam")        # スパムデータ
> easyham.path <- file.path("data", "easy_ham") # 非スパムデータ(易)
> hardham.path <- file.path("data", "hard_ham") # 非スパムデータ(難)
# 分類機のテスト用
> spam2.path <- file.path("data", "spam_2")        # スパムデータ
> easyham2.path <- file.path("data", "easy_ham_2") # 非スパムデータ(易)
> hardham2.path <- file.path("data", "hard_ham_2") # 非スパムデータ(難)

メールから本文を取り出す

ファイルを読み込んで、本文を返す関数を作成します。

> get.msg <- function(path) {
  con <- file(path, open = "rt", encoding = "latin1")
  text <- readLines(con)
  # The message always begins after the first full line break
  msg <- text[seq(which(text == "")[1] + 1, length(text), 1)]
  close(con)
  return(paste(msg, collapse = "\n"))
}

動作テスト。

> get.msg("data/spam/00001.7848dde101aa985090474a91ec93fcf0")
[1] "<!DOCTYPE HTML PUBLIC \"-//W3C//DTD HTML 4.0 Transitio...

sapply を使って、data/spam 内のスパムメールから本文を読み込みます。

> spam.docs <- dir(spam.path)
> spam.docs <- spam.docs[which(spam.docs != "cmds")]
> all.spam <- sapply(spam.docs,function(p) get.msg(file.path(spam.path, p)))
> head(all.spam)

単語文書行列を作る

単語を行、文書を列として、文書中の単語の出現数をカウントした、単語文書行列(TDM:Term Document Matrix)を作ります。

↓こんなイメージ

a.txt b.txt c.txt
hello 2 1 0
world 0 2 3


まずは、TDMを作成する関数を定義。

> get.tdm <- function(doc.vec){
  control <- list(stopwords = TRUE,
                  removePunctuation = TRUE,
                  removeNumbers = TRUE,
                  minDocFreq = 2)
  doc.corpus <- Corpus(VectorSource(doc.vec))
  doc.dtm <- TermDocumentMatrix(doc.corpus, control)
  return(doc.dtm)
}

all.spam から、TDMを作成します。

> spam.tdm <- get.tdm(all.spam)

スパムの訓練データを作る

スパムに含まれる単語の出現確率を集計します。

# all.spamのTDMをRの行列に変換
> spam.matrix <- as.matrix(spam.tdm)
# 各単語の、全スパム中での出現頻度をカウント
> spam.counts <- rowSums(spam.matrix)
# データフレームに変換
> spam.df <- data.frame(cbind(names(spam.counts),
                      as.numeric(spam.counts)),
                      stringsAsFactors = FALSE)
> names(spam.df) <- c("term", "frequency")
# 出現確率を算出
> spam.df$frequency <- as.numeric(spam.df$frequency)
> spam.occurrence <- sapply(1:nrow(spam.matrix), function(i) {
  length(which(spam.matrix[i, ] > 0)) / ncol(spam.matrix)
})
> spam.density <- spam.df$frequency / sum(spam.df$frequency)
> spam.df <- transform(spam.df,
                     density = spam.density,
                     occurrence = spam.occurrence)

訓練データの中身を確認。

> head(spam.df[with(spam.df, order(-occurrence)),])
        term frequency     density occurrence
7693   email       813 0.005855914      0.566
18706 please       425 0.003061210      0.508
14623   list       409 0.002945964      0.444
27202   will       828 0.005963957      0.422
2970    body       379 0.002729879      0.408
9369    free       543 0.003911146      0.390

スパムメール中の56%が email を含んでいます。

非スパムの訓練データを作る

スパムの訓練データを作ったのと同じ手順で、非スパムの訓練データも作ります。

> easyham.docs <- dir(easyham.path)
> easyham.docs <- easyham.docs[which(easyham.docs != "cmds")]
# 最初の分類機を作る際に、各メッセージがスパムである確率とそうでない確率を等しいと仮定するため、
# 訓練データもスパムの文書数と同じ数だけに限定する。
> all.easyham <- sapply(easyham.docs[1:length(spam.docs)],
                      function(p) get.msg(file.path(easyham.path, p)))
> easyham.tdm <- get.tdm(all.easyham)
> easyham.matrix <- as.matrix(easyham.tdm)
> easyham.counts <- rowSums(easyham.matrix)
> easyham.df <- data.frame(cbind(names(easyham.counts),
                         as.numeric(easyham.counts)),
                         stringsAsFactors = FALSE)
> names(easyham.df) <- c("term", "frequency")
> easyham.df$frequency <- as.numeric(easyham.df$frequency)
> easyham.occurrence <- sapply(1:nrow(easyham.matrix), function(i) {
 length(which(easyham.matrix[i, ] > 0)) / ncol(easyham.matrix)
})
> easyham.density <- easyham.df$frequency / sum(easyham.df$frequency)
> easyham.df <- transform(easyham.df,
                        density = easyham.density,
                        occurrence = easyham.occurrence)

訓練データの中身を確認。

> head(easyham.df[with(easyham.df, order(-occurrence)),])
       term frequency     density occurrence
5193  group       232 0.003317366      0.388
12877   use       271 0.003875027      0.378
13511 wrote       237 0.003388861      0.378
1629    can       348 0.004976049      0.368
7244   list       248 0.003546150      0.368
8581    one       356 0.005090441      0.336

分類器を作る

メールファイルを受け取り、それがスパム、または、非スパムである確率を計算する分類器を定義します。

> classify.email <- function(path, training.df, prior = 0.5, c = 1e-6) {

  # ファイルから本文を取り出し、単語の出現数をカウント
  msg <- get.msg(path)
  msg.tdm <- get.tdm(msg)
  msg.freq <- rowSums(as.matrix(msg.tdm))

  # メール内の単語のうち、訓練データに含まれる単語を取得
  msg.match <- intersect(names(msg.freq), training.df$term)
  
  # 単語の出現確率を掛け合わせて、条件付き確率を算出する
  # この時、訓練データに含まれない単語は、非常に小さい確率0.0001として扱う。
  if(length(msg.match) < 1) {
    # 訓練データに含まれる単語が1つもない場合、条件付き確率は、
    # 事前確率(prior) * (0.0001の単語数乗)
    # となる。
    return(prior * c ^ (length(msg.freq)))
  } else {
    # 訓練データから、単語ごとの出現確率を取り出す。
    match.probs <- training.df$occurrence[match(msg.match, training.df$term)]
    # 条件付き確率を計算
    # 事前確率(prior) * (訓練データに含まれる単語の出現確率の積) * (0.0001の訓練データに含まれない単語数乗)
    return(prior * prod(match.probs) * c ^ (length(msg.freq) - length(msg.match)))
  }
}

判定してみる

非スパム(難)のデータを判定してみます。

> hardham.docs <- dir(hardham.path)
> hardham.docs <- hardham.docs[which(hardham.docs != "cmds")]
> hardham.spamtest <- sapply(hardham.docs,
                           function(p) classify.email(file.path(hardham.path, p), training.df = spam.df))
> hardham.hamtest <- sapply(hardham.docs,
                          function(p) classify.email(file.path(hardham.path, p), training.df = easyham.df))
> hardham.res <- ifelse(hardham.spamtest > hardham.hamtest,
                      TRUE,
                      FALSE)
> summary(hardham.res)
   Mode   FALSE    TRUE    NA's 
logical     243       6       0 

データはすべて非スパムのものなので、誤判定は6件、偽陽性率は2.4%。

テスト用データすべてを判定してみる

スパム判定を行う関数を作成。

spam.classifier <- function(path) {
  pr.spam <- classify.email(path, spam.df)
  pr.ham <- classify.email(path, easyham.df)
  return(c(pr.spam, pr.ham, ifelse(pr.spam > pr.ham, 1, 0)))
}

テスト用のメールデータをすべて判定してみます。

> easyham2.docs <- dir(easyham2.path)
> easyham2.docs <- easyham2.docs[which(easyham2.docs != "cmds")]

> hardham2.docs <- dir(hardham2.path)
> hardham2.docs <- hardham2.docs[which(hardham2.docs != "cmds")]

> spam2.docs <- dir(spam2.path)
> spam2.docs <- spam2.docs[which(spam2.docs != "cmds")]

> easyham2.class <- suppressWarnings(lapply(easyham2.docs, function(p) {
  spam.classifier(file.path(easyham2.path, p))
}))
> hardham2.class <- suppressWarnings(lapply(hardham2.docs, function(p) {
  spam.classifier(file.path(hardham2.path, p))
}))
> spam2.class <- suppressWarnings(lapply(spam2.docs, function(p) {
  spam.classifier(file.path(spam2.path, p))
}))

> easyham2.matrix <- do.call(rbind, easyham2.class)
> easyham2.final <- cbind(easyham2.matrix, "EASYHAM")

> hardham2.matrix <- do.call(rbind, hardham2.class)
> hardham2.final <- cbind(hardham2.matrix, "HARDHAM")

> spam2.matrix <- do.call(rbind, spam2.class)
> spam2.final <- cbind(spam2.matrix, "SPAM")

> class.matrix <- rbind(easyham2.final, hardham2.final, spam2.final)
> class.df <- data.frame(class.matrix, stringsAsFactors = FALSE)
> names(class.df) <- c("Pr.SPAM" ,"Pr.HAM", "Class", "Type")
> class.df$Pr.SPAM <- as.numeric(class.df$Pr.SPAM)
> class.df$Pr.HAM <- as.numeric(class.df$Pr.HAM)
> class.df$Class <- as.logical(as.numeric(class.df$Class))
> class.df$Type <- as.factor(class.df$Type)

文書ごとのスパム/非スパムの尤もらしさ、分類結果、メールの種別を含むデータフレームができました。

> head(class.df)
        Pr.SPAM        Pr.HAM Class    Type
1  0.000000e+00  0.000000e+00 FALSE EASYHAM
2 5.352364e-248 1.159512e-155 FALSE EASYHAM
3  0.000000e+00 5.103377e-216 FALSE EASYHAM
4  0.000000e+00  0.000000e+00 FALSE EASYHAM
5 2.083521e-169 1.221918e-108 FALSE EASYHAM
6  0.000000e+00  0.000000e+00 FALSE EASYHAM

文書の種類ごとに結果を集計してみます。

> get.results <- function(bool.vector) {
  results <- c(length(bool.vector[which(bool.vector == FALSE)]) / length(bool.vector),
               length(bool.vector[which(bool.vector == TRUE)]) / length(bool.vector))
  return(results)
}
> easyham2.col <- get.results(subset(class.df, Type == "EASYHAM")$Class)
> hardham2.col <- get.results(subset(class.df, Type == "HARDHAM")$Class)
> spam2.col <- get.results(subset(class.df, Type == "SPAM")$Class)

> class.res <- rbind(easyham2.col, hardham2.col, spam2.col)
> colnames(class.res) <- c("NOT SPAM", "SPAM")
> print(class.res)
              NOT SPAM       SPAM
easyham2.col 0.9871429 0.01285714
hardham2.col 0.9677419 0.03225806
spam2.col    0.4631353 0.53686471

非スパム(easyham2, hardham2)を間違ってスパムと判定する確率(偽陽性率)はそれぞれ1%,3%と低く、うまく分類できているもよう。 ただ、スパム(spam2)を間違ってスパムでないと判定する確率(偽陰性率)は46%。あれ、ちょっと高いかも?

結果を分散図にしてみる

> class.plot <- ggplot(class.df, aes(x = log(Pr.HAM), log(Pr.SPAM))) +
    geom_point(aes(shape = Type, alpha = 0.5)) +
    geom_abline(intercept = 0, slope = 1) +
    scale_shape_manual(values = c("EASYHAM" = 1,
                                  "HARDHAM" = 2,
                                  "SPAM" = 3),
                       name = "Email Type") +
    scale_alpha(guide = "none") +
    xlab("log[Pr(HAM)]") +
    ylab("log[Pr(SPAM)]") +
    theme_bw() +
    theme(axis.text.x = element_blank(), axis.text.y = element_blank())
> ggsave(plot = class.plot,
       filename = file.path("images", "03_final_classification.png"),
       height = 10,
       width = 10)

f:id:unageanu:20160111152705p:plain

横軸が「メールを非スパムと仮定した時の尤もらしさ」、縦軸が「メールをスパムと仮定した時の尤もらしさ」を示します。

「メールをスパムと仮定した時の尤もらしさ」 > 「メールを非スパムと仮定した時の尤もらしさ」となったメールをスパムと判定するので、真ん中の線より上の者はスパム、下は非スパムと判定されています。線より上に〇や△(非スパムのメール)がいくつかあったりはしますが、おおむね正しく判定できている感じですね。

事前分布を変えて、結果を改善する

↑では、とあるメールがあった時にそれがスパムである確率は50%(=世の中のメールの半分はスパムで半分はそうではない)と仮定していました。 現実には、80%は非スパム、残り20%がスパムなので、これを考慮して再計算することで結果を改善してみます。

spam.classifier を修正して、事前確率を変更。

spam.classifier <- function(path) {
  pr.spam <- classify.email(path, spam.df, prior=0.2)
  pr.ham <- classify.email(path, easyham.df, prior=0.8)
  return(c(pr.spam, pr.ham, ifelse(pr.spam > pr.ham, 1, 0)))
}

再計算した結果は以下。

> print(class.res)
              NOT SPAM       SPAM
easyham2.col 0.9892857 0.01071429
hardham2.col 0.9717742 0.02822581
spam2.col    0.4652827 0.53471725

非スパム(easyham2, hardham2)を間違ってスパムと判定する確率は少し改善しました。 一方、スパム(spam2)を間違ってスパムでないと判定する確率は悪化。。。