無料で使えるシステムトレードフレームワーク「Jiji」 をリリースしました!

・OANDA Trade APIを利用した、オープンソースのシステムトレードフレームワークです。
・自分だけの取引アルゴリズムで、誰でも、いますぐ、かんたんに、自動取引を開始できます。

機械学習手習い: 回帰分析

「入門 機械学習」手習い、5日目。「5章 回帰:ページビューの予測」です。

www.amazon.co.jp

回帰分析を使って、Webサイトのページビューを予測します。

平均値による単純な予測と、予測精度の計測

回帰分析のイメージをつかむために、平均値を使った単純な予測を試します。 平均値を使って寿命を予測し、平均二乗誤差を使って精度を測定。そこに、追加の情報(喫煙者かどうか)を加えて精度を向上します。

まずは、依存ライブラリの読み込みから。

> setwd("05-Regression/")
> library('ggplot2')

寿命データを読み込む

> ages <- read.csv(file.path('data', 'longevity.csv'))
> head(ages)
  Smokes AgeAtDeath
1      1         75
2      1         72
3      1         66
4      1         74
5      1         69
6      1         65

ヒストグラムにしてみます。

> plot = ggplot(ages, aes(x = AgeAtDeath, fill = factor(Smokes))) +
  geom_density() + facet_grid(Smokes ~ .)
> ggsave(plot = plot,
       filename = file.path("images", "longevity_plot.png"),
       height = 4.8,
       width = 7)

f:id:unageanu:20160113141102p:plain

非喫煙者のピークが喫煙者より右に寄っていて、やはり、非喫煙者の方が全体的に長生きな感じ。

平均を使った予測とその精度を測る

データの概要がつかめたところで、平均値を使った予測をしてみます。 まずは、平均値を算出。

> mean(ages$AgeAtDeath)
[1] 72.723

平均2乗誤差を計算。

> guess <- 72.723
> with(ages, mean((AgeAtDeath - guess) ^ 2))
[1] 32.91427

32.91427 となりました。

喫煙者かどうかを考慮して、精度を向上する

喫煙者/非喫煙者それぞれで平均値を計測し、推測値として使います。

> smokers.guess <- with(subset(ages, Smokes == 1), mean(AgeAtDeath))
> non.smokers.guess <- with(subset(ages, Smokes == 0), mean(AgeAtDeath))
> ages <- transform(ages, NewPrediction = ifelse(Smokes == 0,
                     non.smokers.guess, smokers.guess))
> with(ages, mean((AgeAtDeath - NewPrediction) ^ 2))
[1] 26.50831

誤差が減り、精度が向上しました。

線形回帰入門

線形回帰で予測を行うには、次の2つの仮定を置くことが前提となります。

  • 可分性/加法性

    • 独立する特徴量が同時に起こった場合、効果は足し合わされること。
    • 例) アルコール中毒患者の平均寿命がそうでない人より1年短く、喫煙者の平均寿命がそうでない人より5年短い場合、アルコール中毒の喫煙者は6年短いと推測する
  • 単調性/線形性

    • 特徴量に応じて、予測値が常に上昇、または減少すること。

これを踏まえて、実習を開始。身長/体重データを読み込んで、簡単な予測を試します。

# データを読み込み
> heights.weights <- read.csv(
  file.path('data', '01_heights_weights_genders.csv'), header = TRUE, sep = ',')

散布図を描いてみます。 geom_smooth(method = 'lm') で回帰線も引きます。

# 散布図を描く
> plot <- ggplot(heights.weights, aes(x = Height, y = Weight)) +
  geom_point() + geom_smooth(method = 'lm')
> ggsave(plot = plot,
       filename = file.path("images", "height_weight.png"),
       height = 4.8,
       width = 7)

f:id:unageanu:20160113141059p:plain

青の線が身長から体重を予測する回帰線です。 身長60インチの人は、体重105ポンドくらいと予測できます。大体、妥当な感じ。

Rでは、lm 関数で、線形モデルでの回帰分析を行うことができます。

> fitted.regression <- lm(Weight ~ Height, data = heights.weights)

coef 関数を使うと、傾き( Height )と切片( Intercept )を確認できます。

> coef(fitted.regression)
(Intercept)      Height 
-350.737192    7.717288 

predict で予測値が得られます。

> head(predict(fitted.regression))                                                                                                                                 
       1        2        3        4        5        6 
219.1615 180.0725 221.1918 202.8314 188.5607 168.2737 

residuals を使うと、実際のデータとの誤差を計算できます。

head(residuals(fitted.regression))
         1          2          3          4          5          6 
 22.732083 -17.762074  -8.450953  17.211069  17.789073 -16.061519 

ウェブサイトのアクセス数を予測する

まずはデータを読み込み。

> top.1000.sites <- read.csv(file.path('data', 'top_1000_sites.tsv'), 
  sep = '\t', stringsAsFactors = FALSE)
> head(top.1000.sites)
  Rank          Site                     Category UniqueVisitors Reach
1    1  facebook.com              Social Networks      880000000  47.2
2    2   youtube.com                 Online Video      800000000  42.7
3    3     yahoo.com                  Web Portals      660000000  35.3
4    4      live.com               Search Engines      550000000  29.3
5    5 wikipedia.org Dictionaries & Encyclopedias      490000000  26.2
6    6       msn.com                  Web Portals      450000000  24.0
  PageViews HasAdvertising InEnglish TLD
1   9.1e+11            Yes       Yes com
2   1.0e+11            Yes       Yes com
3   7.7e+10            Yes       Yes com
4   3.6e+10            Yes       Yes com
5   7.0e+09             No       Yes org
6   1.5e+10            Yes       Yes com

このデータから、PageView を推測するのが今回の目的。 まずは、線形回帰が適用できるか、確認。ページビューの分布をみてみます。

> plot = ggplot(top.1000.sites, aes(x = PageViews)) + geom_density()
> ggsave(plot = plot,
       filename = file.path("images", "top_1000_sites_page_view.png"),
       height = 4.8,
       width = 7)

f:id:unageanu:20160113141103p:plain

ばらつきが大きすぎてよくわからない・・・。対数を取ってみます。

> plot = ggplot(top.1000.sites, aes(x = log(PageViews))) + geom_density()
> ggsave(plot = plot,
       filename = file.path("images", "top_1000_sites_page_view2.png"),
       height = 4.8,
       width = 7)

f:id:unageanu:20160113141104p:plain

この尺度なら、傾向が見えそう。UniqueVisitorPageView の散布図を描いてみます。

> ggplot(top.1000.sites, aes(x = log(PageViews), y = log(UniqueVisitors))) + geom_point()
> ggsave(file.path("images", "log_page_views_vs_log_visitors.png"))

f:id:unageanu:20160113141100p:plain

回帰直線も追加。

> ggplot(top.1000.sites, aes(x = log(PageViews), y = log(UniqueVisitors))) +
  geom_point() +
  geom_smooth(method = 'lm', se = FALSE)
> ggsave(file.path("images", "log_page_views_vs_log_visitors_with_lm.png"))

f:id:unageanu:20160113141101p:plain

うまくいっているっぽい。lm で回帰分析を実行。

> lm.fit <- lm(log(PageViews) ~ log(UniqueVisitors),
             data = top.1000.sites)

summary を使うと、分析結果の要約が表示されます。

> summary(lm.fit)

Call:
lm(formula = log(PageViews) ~ log(UniqueVisitors), data = top.1000.sites)

Residuals:
    Min      1Q  Median      3Q     Max 
-2.1825 -0.7986 -0.0741  0.6467  5.1549 

Coefficients:
                    Estimate Std. Error t value Pr(>|t|)    
(Intercept)         -2.83441    0.75201  -3.769 0.000173 ***
log(UniqueVisitors)  1.33628    0.04568  29.251  < 2e-16 ***
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

Residual standard error: 1.084 on 998 degrees of freedom
Multiple R-squared:  0.4616,    Adjusted R-squared:  0.4611 
F-statistic: 855.6 on 1 and 998 DF,  p-value: < 2.2e-16

このうち、 Residual standard error が平均2乗誤差の平方根(RMSE)をとったもの。

さらに、他の情報も考慮に加えてみます。

> lm.fit <- lm(log(PageViews) ~ HasAdvertising + log(UniqueVisitors) + InEnglish,
             data = top.1000.sites)
> summary(lm.fit)

Call:
lm(formula = log(PageViews) ~ HasAdvertising + log(UniqueVisitors) + 
    InEnglish, data = top.1000.sites)

Residuals:
    Min      1Q  Median      3Q     Max 
-2.4283 -0.7685 -0.0632  0.6298  5.4133 

Coefficients:
                    Estimate Std. Error t value Pr(>|t|)    
(Intercept)         -1.94502    1.14777  -1.695  0.09046 .  
HasAdvertisingYes    0.30595    0.09170   3.336  0.00088 ***
log(UniqueVisitors)  1.26507    0.07053  17.936  < 2e-16 ***
InEnglishNo          0.83468    0.20860   4.001 6.77e-05 ***
InEnglishYes        -0.16913    0.20424  -0.828  0.40780    
---
Signif. codes:  0***0.001**0.01*0.05 ‘.’ 0.1 ‘ ’ 1

Residual standard error: 1.067 on 995 degrees of freedom
Multiple R-squared:  0.4798,    Adjusted R-squared:  0.4777 
F-statistic: 229.4 on 4 and 995 DF,  p-value: < 2.2e-16

精度が少し上がりました。

機械学習手習い: 重要度による電子メールの並び替え

「入門 機械学習」手習い、4日目。「4章 順位づけ:優先トレイ」です。

www.amazon.co.jp

電子メールを重要度で順位づけするシステムを作ります

並び替えのアプローチ

以下の素性を使って、メールに優先度をつけます。

  • 1) 送信者のメッセージ数
    • やり取りが多い送信者からのメールは重要とみなす
  • 2) スレッドの活性
    • 活発にやり取りされているスレッドのメールは優先度を高くする
    • 活性度は、1秒あたりのスレッドのメール数で算出。多いほど、活性が高い
  • 3) サブジェクトと本文に含まれる単語
    • 活動的なスレッドのサブジェクトで頻出する単語をサブジェクトに含むメールは、重要度が高いとみなす
    • 本文に頻出単語が含まれているメールは重要度が高いとみなす

必要なモジュールとデータの読み込み

> setwd("04-Ranking/")  
> library('tm')
> library('ggplot2')
> library('plyr')
> library('reshape')

# メールデータは、3章で使った非スパムメール(易)を使う
> data.path <- file.path("..", "03-Classification", "data")
> easyham.path <- file.path(data.path, "easy_ham")

メールから素性を取り出す

ファイルを読み込んで、本文を返す関数を作成します。

# メールファイルから、全データを読み込んで返す
> msg.full <- function(path) {
  con <- file(path, open = "rt", encoding = "latin1")
  msg <- readLines(con)
  close(con)
  return(msg)
}

# メールデータからFromアドレスを取り出す
> get.from <- function(msg.vec) {
  from <- msg.vec[grepl("From: ", msg.vec)]
  from <- strsplit(from, '[":<> ]')[1]
  from <- from[which(from  != "" & from != " ")]
  return(from[grepl("@", from)][1])
}

# メールデータから、サブジェクトを取り出す
> get.subject <- function(msg.vec) {
  subj <- msg.vec[grepl("Subject: ", msg.vec)]
  if(length(subj) > 0) {
    return(strsplit(subj, "Subject: ")[1][2])
  } else {
    return("")
  }
}

# メールデータから、本文を取り出す
> get.msg <- function(msg.vec) {
  msg <- msg.vec[seq(which(msg.vec == "")[1] + 1, length(msg.vec), 1)]
  return(paste(msg, collapse = "\n"))
}

# メールデータから、受信日時を取り出す
> get.date <- function(msg.vec) {
  date.grep <- grepl("^Date: ", msg.vec)
  date.grep <- which(date.grep == TRUE)
  date <- msg.vec[date.grep[1]]
  date <- strsplit(date, "\\+|\\-|: ")[1][2]
  date <- gsub("^\\s+|\\s+$", "", date)
  return(strtrim(date, 25))
}

# メールを読み込んで、必要な素性を返す。
> parse.email <- function(path) {
  full.msg <- msg.full(path)
  date <- get.date(full.msg)
  from <- get.from(full.msg)
  subj <- get.subject(full.msg)
  msg <- get.msg(full.msg)
  return(c(date, from, subj, msg, path))
}

動作テスト。

> parse.email("../03-Classification/data/easy_ham/00111.a478af0547f2fd548f7b412df2e71a92")
[1] "Mon, 7 Oct 2002 10:37:26"
[2] "niall@linux.ie"  
...

メールデータを読み込んで素性をデータフレームにまとめる

# 全メールを解析
> easyham.docs <- dir(easyham.path)
> easyham.docs <- easyham.docs[which(easyham.docs != "cmds")]
> easyham.parse <- lapply(easyham.docs, function(p) parse.email(file.path(easyham.path, p)))
# データフレームに変換
> ehparse.matrix <- do.call(rbind, easyham.parse)
> allparse.df <- data.frame(ehparse.matrix, stringsAsFactors = FALSE)
> names(allparse.df) <- c("Date", "From.EMail", "Subject", "Message", "Path")

できた。

> head(allparse.df)
                       Date                   From.EMail
1 Thu, 22 Aug 2002 18:26:25            kre@munnari.OZ.AU
2 Thu, 22 Aug 2002 12:46:18 steve.burt@cursor-system.com
3 Thu, 22 Aug 2002 13:52:38                timc@2ubh.com
4 Thu, 22 Aug 2002 09:15:25             monty@roscom.com

データの調整

送信日時が文字列になっているので、POSIXオブジェクトに変換します。

# 日本語環境だと、%b が Aug などの月名にマッチしないため、変更しておく。
> Sys.setlocale(locale="C")
> date.converter <- function(dates, pattern1, pattern2) {
  pattern1.convert <- strptime(dates, pattern1)
  pattern2.convert <- strptime(dates, pattern2)
  pattern1.convert[is.na(pattern1.convert)] <- pattern2.convert[is.na(pattern1.convert)]
  return(pattern1.convert)
}
> pattern1 <- "%a, %d %b %Y %H:%M:%S"
> pattern2 <- "%d %b %Y %H:%M:%S"
> allparse.df$Date <- date.converter(allparse.df$Date, pattern1, pattern2)
> head(allparse.df)                                                                                                                                                
                 Date                   From.EMail
1 2002-08-22 18:26:25            kre@munnari.OZ.AU
2 2002-08-22 12:46:18 steve.burt@cursor-system.com
3 2002-08-22 13:52:38                timc@2ubh.com
4 2002-08-22 09:15:25             monty@roscom.com
5 2002-08-22 14:38:22    Stewart.Smith@ee.ed.ac.uk

# ロケールを戻しておく。
> Sys.setlocale(local="ja_JP.UTF-8")

また、サブジェクトと送信者アドレスを小文字に変更します。

> allparse.df$Subject <- tolower(allparse.df$Subject)
> allparse.df$From.EMail <- tolower(allparse.df$From.EMail)

最後に、送信日時でソート。

> priority.df <- allparse.df[with(allparse.df, order(Date)), ]

データの最初の半分を訓練データに使うので、別の変数に格納しておきます。

> priority.train <- priority.df[1:(round(nrow(priority.df) / 2)), ]

送信者別メール件数での重みづけ

送信者ごとのメール件数で重みづけを行うため、まずは、件数がどんな感じになっているか確認します。

送信者ごとのメール件数を集計。

> from.weight <- melt(with(priority.train, table(From.EMail)))                                                                                        
> from.weight <- from.weight[with(from.weight, order(value)), ]
> head(from.weight)
                    From.EMail value
1            adam@homeport.org     1
2     admin@networksonline.com     1
4 albert.white@ireland.sun.com     1
5                andr@sandy.ru     1
6             andris@aernet.ru     1
9              antoin@eire.com     1

> summary(from.weight$value)
   Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
   1.00    1.00    2.00    4.63    4.00   55.00 

平均は4.63通。最大値は55通でばらつきが大きいかな? 7通以上送信しているアドレスをグラフに表示してみます。

> from.ex <- subset(from.weight, value >= 7)
> from.scales <- ggplot(from.ex) +
  geom_rect(aes(xmin = 1:nrow(from.ex) - 0.5,
                xmax = 1:nrow(from.ex) + 0.5,
                ymin = 0,
                ymax = value,
                fill = "lightgrey",
                color = "darkblue")) +
  scale_x_continuous(breaks = 1:nrow(from.ex), labels = from.ex$From.EMail) +
  coord_flip() +
  scale_fill_manual(values = c("lightgrey" = "lightgrey"), guide = "none") +
  scale_color_manual(values = c("darkblue" = "darkblue"), guide = "none") +
  ylab("Number of Emails Received (truncated at 6)") +
  xlab("Sender Address") +
  theme_bw() +
  theme(axis.text.y = element_text(size = 5, hjust = 1))
> ggsave(plot = from.scales,
       filename = file.path("images", "0011_from_scales.png"),
       height = 4.8,
       width = 7)

f:id:unageanu:20160112182344p:plain

一部の送信者が、平均的な送信者の10倍以上、メールを送信しています。送信数をそのまま重みにしてしまうと、これらの特殊な送信者の優先度が高くなりすぎてしまいます。グラフを見ると、指数関数的に増えている感じなので、自然対数を使って重みを調整します。

# 対数を取る。重みがゼロにならないように、値に1を足す。
> from.weight <- transform(from.weight, 
  Weight = log(value + 1), log10Weight = log10(value + 1))

スレッド活性での重みづけ

まずは、スレッド別のメール件数を集計します。

メールがスレッドに属するかどうかは、メールのサブジェクトを見て判定します。

# re:を除いたサブジェクト(=スレッド名)と送信者を取り出す。
> find.threads <- function(email.df) {
  response.threads <- strsplit(email.df$Subject, "re: ")
  is.thread <- sapply(response.threads, function(subj) ifelse(subj[1] == "", TRUE, FALSE))
  threads <- response.threads[is.thread]
  senders <- email.df$From.EMail[is.thread]
  threads <- sapply(threads, function(t) paste(t[2:length(t)], collapse = "re: "))
  return(cbind(senders,threads))
}
> threads.matrix <- find.threads(priority.train)
> head(threads.matrix)
     senders                     threads                                      
[1,] "kre@munnari.oz.au"         "new sequences window"                       
[2,] "stewart.smith@ee.ed.ac.uk" "[zzzzteana] nothing like mama used to make" 
[3,] "martin@srv0.ems.ed.ac.uk"  "[zzzzteana] nothing like mama used to make" 
[4,] "stewart.smith@ee.ed.ac.uk" "[zzzzteana] nothing like mama used to make" 
[5,] "marc@perkel.com"           "[sadev] live rule updates after release ???"
[6,] "cwg-exmh@deepeddy.com"     "new sequences window"    

次に、スレッドごとの活性度を集計します。

# スレッドごとの活性度一覧を返す
> get.threads <- function(threads.matrix, email.df) {
  threads <- unique(threads.matrix[, 2])
  thread.counts <- lapply(threads, function(t) thread.counts(t, email.df))
  thread.matrix <- do.call(rbind, thread.counts)
  return(cbind(threads, thread.matrix))
}

# スレッド名に属するメールの活性度を返す
> thread.counts <- function(thread, email.df) {
  # メールから、スレッドに属するメールの送信日時を取り出す
  thread.times <- email.df$Date[which(email.df$Subject == thread |
                                      email.df$Subject == paste("re:", thread))]
  freq <- length(thread.times)  # スレッドのメールの総数
  min.time <- min(thread.times) # 送信日時の最小値
  max.time <- max(thread.times) # 送信日時の最大値
  time.span <- as.numeric(difftime(max.time, min.time, units = "secs"))
  if(freq < 2) {
    # メールが1通しかない場合(返信がなくスレッドになっていない場合)、NAを返す
    return(c(NA, NA, NA))
  } else {
    trans.weight <- freq / time.span # 1秒当たりのメール送信数
    log.trans.weight <- 10 + log(trans.weight, base = 10) 
      # 対数を取る。負にならないよう、10を足す(アフィん変換)
    return(c(freq, time.span, log.trans.weight))
  }
}
> thread.weights <- data.frame(thread.weights, stringsAsFactors = FALSE)
> names(thread.weights) <- c("Thread", "Freq", "Response", "Weight")
> thread.weights$Freq <- as.numeric(thread.weights$Freq)
> thread.weights$Response <- as.numeric(thread.weights$Response)
> thread.weights$Weight <- as.numeric(thread.weights$Weight)
> thread.weights <- subset(thread.weights, is.na(thread.weights$Freq) == FALSE)
> head(thread.weights)
                                      Thread Freq Response   Weight
1   please help a newbie compile mplayer :-)    4    42309 5.975627
2                 prob. w/ install/uninstall    4    23745 6.226488
3                       http://apt.nixia.no/   10   265303 5.576258
4         problems with 'apt-get -f install'    3    55960 5.729244
5                   problems with apt update    2     6347 6.498461
6 about apt, kernel updates and dist-upgrade    5   240238 5.318328

また、送信者での重みづけの補完として、「送信者が何スレッドに参加しているか」を示す重みも計算しておきます。

> email.thread <- function(threads.matrix) {
  senders <- threads.matrix[, 1]
  senders.freq <- table(senders)
  senders.matrix <- cbind(names(senders.freq),
                          senders.freq,
                          log(senders.freq + 1))
  senders.df <- data.frame(senders.matrix, stringsAsFactors=FALSE)
  row.names(senders.df) <- 1:nrow(senders.df)
  names(senders.df) <- c("From.EMail", "Freq", "Weight")
  senders.df$Freq <- as.numeric(senders.df$Freq)
  senders.df$Weight <- as.numeric(senders.df$Weight)
  return(senders.df)
}
> senders.df <- email.thread(threads.matrix)
> head(senders.df)
                    From.EMail Freq    Weight
1            adam@homeport.org    1 0.6931472
2        aeriksson@fastmail.fm    5 1.7917595
3 albert.white@ireland.sun.com    1 0.6931472
4          alex@netwindows.org    1 0.6931472
5                andr@sandy.ru    1 0.6931472
6             andris@aernet.ru    1 0.6931472

サブジェクトと本文に含まれる単語による重みづけ

まずは、サブジェクト。

  • スレッド名に含まれる単語一覧を抽出して、単語ごとに重みを計算します。
  • 単語を含む全スレッドのweightを取り出して、その平均を重みとして使います。
# 単語と出現頻度の一覧を返す
> term.counts <- function(term.vec, control) {
  vec.corpus <- Corpus(VectorSource(term.vec))
  vec.tdm <- TermDocumentMatrix(vec.corpus, control = control)
  return(rowSums(as.matrix(vec.tdm)))
}

# スレッド名に含まれる単語一覧を抽出
> thread.terms <- term.counts(thread.weights$Thread, control = list(stopwords = TRUE))
> thread.terms <- names(thread.terms) # 出現頻度は使わないので捨てる
> head(thread.terms)
[1] "--with"    ":-)"       "..."       ".doc"      "'apt-get"  "\"holiday"

# 単語ごとに重みを算出
# 単語を含む全スレッドのweightを取り出して、その平均を重みとして使う 
> term.weights <- sapply(thread.terms, 
  function(t) mean(thread.weights$Weight[grepl(t, thread.weights$Thread, fixed = TRUE)]))
> head(term.weights)
  --with      :-)      ...     .doc 'apt-get "holiday 
7.109579 6.103883 6.050786 5.725911 5.729244 7.197911 
# 整形
> term.weights <- data.frame(list(Term = names(term.weights),
  Weight = term.weights), stringsAsFactors = FALSE, row.names = 1:length(term.weights))
> head(term.weights)
      Term   Weight
1   --with 7.109579
2      :-) 6.103883
3      ... 6.050786
4     .doc 5.725911
5 'apt-get 5.729244
6 "holiday 7.197911

次に本文。

# 本文に含まれる単語と頻度を集計
> msg.terms <- term.counts(priority.train$Message,
    control = list(stopwords = TRUE, removePunctuation = TRUE, removeNumbers = TRUE))
# 重みを算出。ここでも対数をとる
> msg.weights <- data.frame(list(Term = names(msg.terms), Weight = log(msg.terms, base = 10)), 
    stringsAsFactors = FALSE, row.names = 1:length(msg.terms))
# 重みがゼロのものは除外
> msg.weights <- subset(msg.weights, Weight > 0)

これで、すべての重みデータフレームがそろいました。

順位づけを行う

重要度を計算する関数を定義します。

# 単語の重みを返す
# 単語、検索する重みデータフレーム、term.weightが検索対象かどうか、を引数で受け取り、重みを返す。
> get.weights <- function(search.term, weight.df, term = TRUE) {
  if(length(search.term) > 0) {
    # weight.dfがterm.weightかどうかで列名が異なるので、ここで調整
    if(term) {
      term.match <- match(names(search.term), weight.df$Term)
    } else {
      term.match <- match(search.term, weight.df$Thread)
    }
    match.weights <- weight.df$Weight[which(!is.na(term.match))]
    if(length(match.weights) < 1) {
      # マッチする件数がゼロの場合、1を使う
      return(1)
    } else {
      # マッチする件数が1以上の場合、平均を使う
      return(mean(match.weights))
    }
  } else {
    return(1)
  }
}

# メールの重要度を返す
> rank.message <- function(path) {
  
  # メールを解析
  msg <- parse.email(path)
  
  # 送信者が送信したメール数に基づく重みを取得
  from <- ifelse(length(which(from.weight$From.EMail == msg[2])) > 0,
                 from.weight$Weight[which(from.weight$From.EMail == msg[2])],
                 1)
  
  # 送信者が参加したスレッド数に基づく重みを取得
  thread.from <- ifelse(length(which(senders.df$From.EMail == msg[2])) > 0,
                        senders.df$Weight[which(senders.df$From.EMail == msg[2])],
                        1)
  
  # メールがスレッドへの投降かどうかを判定し、スレッドへの投稿であれば、スレッドの重みを取得
  subj <- strsplit(tolower(msg[3]), "re: ")
  is.thread <- ifelse(subj[[1]][1] == "", TRUE, FALSE)
  if(is.thread){
    activity <- get.weights(subj[[1]][2], thread.weights, term = FALSE)
  } else {
    # スレッドへの投稿でない場合、重みは1
    activity <- 1
  }
  
  # メールサブジェクトに基づく重みを取得
  thread.terms <- term.counts(msg[3], control = list(stopwords = TRUE))
  thread.terms.weights <- get.weights(thread.terms, term.weights)
  
  # メール本文に基づく重みを取得
  msg.terms <- term.counts(msg[4],
                           control = list(stopwords = TRUE,
                           removePunctuation = TRUE,
                           removeNumbers = TRUE))
  msg.weights <- get.weights(msg.terms, msg.weights)
  
  # 重みをすべて掛け合わせて、重要度を算出する
  rank <- prod(from,
               thread.from,
               activity, 
               thread.terms.weights,
               msg.weights)
  
  return(c(msg[1], msg[2], msg[3], rank))
}

動作テスト。

> rank.message("../03-Classification/data/easy_ham/00111.a478af0547f2fd548f7b412df2e71a92")
[1] "Mon, 7 Oct 2002 10:37:26"                                
[2] "niall@linux.ie"                                          
[3] "Re: [ILUG] Interesting article on free software licences"
[4] "5.27542087468428"

優先メールとみなす閾値が妥当か確認する

今回は、優先度の中央値を閾値として使います。 データの半分を使って、閾値が妥当かチェックします。

train.paths <- priority.df$Path[1:(round(nrow(priority.df) / 2))]
test.paths <- priority.df$Path[((round(nrow(priority.df) / 2)) + 1):nrow(priority.df)]

# train.pathsに含まれるメールの重要度を算出
train.ranks <- suppressWarnings(lapply(train.paths, rank.message))
# データフレームに変換
> train.ranks.matrix <- do.call(rbind, train.ranks)
> train.ranks.matrix <- cbind(train.paths, train.ranks.matrix, "TRAINING")
> train.ranks.df <- data.frame(train.ranks.matrix, stringsAsFactors = FALSE)
> names(train.ranks.df) <- c("Message", "Date", "From", "Subj", "Rank", "Type")
> train.ranks.df$Rank <- as.numeric(train.ranks.df$Rank)
> head(train.ranks.df)
                                                                    Message
1 ../03-Classification/data/easy_ham/01061.6610124afa2a5844d41951439d1c1068
2 ../03-Classification/data/easy_ham/01062.ef7955b391f9b161f3f2106c8cda5edb
3 ../03-Classification/data/easy_ham/01063.ad3449bd2890a29828ac3978ca8c02ab
4 ../03-Classification/data/easy_ham/01064.9f4fc60b4e27bba3561e322c82d5f7ff
5 ../03-Classification/data/easy_ham/01070.6e34c1053a1840779780a315fb083057
6 ../03-Classification/data/easy_ham/01072.81ed44b31e111f9c1e47e53f4dfbefe3
                       Date                   From
1 Thu, 31 Jan 2002 22:44:14  robinderbains@shaw.ca
2      01 Feb 2002 00:53:41 lance_tt@bellsouth.net
3 Fri, 01 Feb 2002 02:01:44  robinderbains@shaw.ca
4  Fri, 1 Feb 2002 10:29:23      matthias@egwn.net
5  Fri, 1 Feb 2002 12:42:02     bfrench@ematic.com
6  Fri, 1 Feb 2002 13:39:31     bfrench@ematic.com
                                          Subj       Rank     Type
1     Please help a newbie compile mplayer :-)   3.614003 TRAINING
2 Re: Please help a newbie compile mplayer :-) 120.742481 TRAINING
3 Re: Please help a newbie compile mplayer :-)  20.348502 TRAINING
4 Re: Please help a newbie compile mplayer :-) 307.809626 TRAINING
5                   Prob. w/ install/uninstall   3.653047 TRAINING
6               RE: Prob. w/ install/uninstall  21.685750 TRAINING

閾値を中央値に設定して、訓練データの重要度と密度を図にします。

# 閾値を中央値に設定
> priority.threshold <- median(train.ranks.df$Rank)

# 訓練データの重要度と密度を図示
> threshold.plot <- ggplot(train.ranks.df, aes(x = Rank)) +
  stat_density(aes(fill="darkred")) +
  geom_vline(xintercept = priority.threshold, linetype = 2) +
  scale_fill_manual(values = c("darkred" = "darkred"), guide = "none") +
  theme_bw()
> ggsave(plot = threshold.plot,
       filename = file.path("images", "01_threshold_plot.png"),
       height = 4.7,
       width = 7)

f:id:unageanu:20160112182342p:plain

図中の点線が中央値。 ここを閾値にすれば、ランクの高い裾部分と、密度の高い部分の電子メールもある程度含まれるので、これらを優先メールと判定したのでよさそう。

残りのデータも加えて、図にしてみます。

# test.ranksに含まれるメールの重要度を算出
> train.ranks.df$Priority <- ifelse(train.ranks.df$Rank >= priority.threshold, 1, 0)
> test.ranks <- suppressWarnings(lapply(test.paths,rank.message))
> test.ranks.matrix <- do.call(rbind, test.ranks)
> test.ranks.matrix <- cbind(test.paths, test.ranks.matrix, "TESTING")
> test.ranks.df <- data.frame(test.ranks.matrix, stringsAsFactors = FALSE)
> names(test.ranks.df) <- c("Message","Date","From","Subj","Rank","Type")
> test.ranks.df$Rank <- as.numeric(test.ranks.df$Rank)
> test.ranks.df$Priority <- ifelse(test.ranks.df$Rank >= priority.threshold, 1, 0)

# 訓練用データとテスト用データをマージ
> final.df <- rbind(train.ranks.df, test.ranks.df)
> final.df$Date <- date.converter(final.df$Date, pattern1, pattern2)
> final.df <- final.df[rev(with(final.df, order(Date))), ]
> head(final.df)
                                                                       Message
2500 ../03-Classification/data/easy_ham/00883.c44a035e7589e83076b7f1fed8fa97d5
2499 ../03-Classification/data/easy_ham/02500.05b3496ce7bca306bed0805425ec8621
2498 ../03-Classification/data/easy_ham/02499.b4af165650f138b10f9941f6cc5bce3c
2497 ../03-Classification/data/easy_ham/02498.09835f512f156da210efb99fcc523e21
2496 ../03-Classification/data/easy_ham/02497.60497db0a06c2132ec2374b2898084d3
2495 ../03-Classification/data/easy_ham/02496.aae0c81581895acfe65323f344340856
     Date                    From
2500 <NA>             sdw@lig.net
2499 <NA> ilug_gmc@fiachra.ucd.ie
2498 <NA>          mwh@python.net
2497 <NA>            nickm@go2.ie
2496 <NA>       phil@techworks.ie
2495 <NA>           timc@2ubh.com
                                                      Subj      Rank    Type
2500                                       Re: ActiveBuddy  6.219744 TESTING
2499                              Re: [ILUG] Linux Install  2.278890 TESTING
2498 [Spambayes] Re: New Application of SpamBayesian tech?  4.265954 TESTING
2497                              Re: [ILUG] Linux Install  4.576643 TESTING
2496                              Re: [ILUG] Linux Install  3.652100 TESTING
2495                          [zzzzteana] Surfing the tube 27.987331 TESTING
     Priority
2500        0
2499        0
2498        0
2497        0
2496        0
2495        1

# 図示
> testing.plot <- ggplot(subset(final.df, Type == "TRAINING"), aes(x = Rank)) +
  stat_density(aes(fill = Type, alpha = 0.65)) +
  stat_density(data = subset(final.df, Type == "TESTING"),
               aes(fill = Type, alpha = 0.65)) +
  geom_vline(xintercept = priority.threshold, linetype = 2) +
  scale_alpha(guide = "none") +
  scale_fill_manual(values = c("TRAINING" = "darkred", "TESTING" = "darkblue")) +
  theme_bw()
> ggsave(plot = testing.plot,
       filename = file.path("images", "02_testing_plot.png"),
       height = 4.7,
       width = 7)

f:id:unageanu:20160112182343p:plain

テストデータは、訓練データより優先度低のメールが多く含まれる結果になっています。 これは、テストデータの素性に、訓練データに含まれないデータが多く含まれ、これらが順序付け時に無視されているためであり、妥当らしい。ふむ。

最後に優先度一覧をcsvに出力しておしまい。

write.csv(final.df, file.path("data", "final_df.csv"), row.names = FALSE)

機械学習手習い: スパムフィルタを作る

「入門 機械学習」手習い、3日目。「3章 分類:スパムフィルタ」です。

www.amazon.co.jp

ナイーブベイズ分類器を作って、メールがスパムかどうかを判定するフィルタを作ります。

分類器の仕組み

  • 1) 以下の単語セットを作成
    • (a) スパムメッセージに出現しやすい単語とその出現確率
    • (b) スパムメッセージに出現しにくい単語とその出現確率
  • 2) で作成した単語セットを元に、メール本文を評価し、以下を算出
    • (a2) メールをスパムと仮定した時の尤もらしさ
    • (b2) メールを非スパムと仮定した時の尤もらしさ
  • 3) a2 > b2 となるメールをスパムと判定する

という感じで判定を行います。

必要なモジュールとデータの読み込み

> setwd("03-Classification/")
> library('tm')
> library('ggplot2')

# テスト用データ
# 分類機の訓練用
> spam.path <- file.path("data", "spam")        # スパムデータ
> easyham.path <- file.path("data", "easy_ham") # 非スパムデータ(易)
> hardham.path <- file.path("data", "hard_ham") # 非スパムデータ(難)
# 分類機のテスト用
> spam2.path <- file.path("data", "spam_2")        # スパムデータ
> easyham2.path <- file.path("data", "easy_ham_2") # 非スパムデータ(易)
> hardham2.path <- file.path("data", "hard_ham_2") # 非スパムデータ(難)

メールから本文を取り出す

ファイルを読み込んで、本文を返す関数を作成します。

> get.msg <- function(path) {
  con <- file(path, open = "rt", encoding = "latin1")
  text <- readLines(con)
  # The message always begins after the first full line break
  msg <- text[seq(which(text == "")[1] + 1, length(text), 1)]
  close(con)
  return(paste(msg, collapse = "\n"))
}

動作テスト。

> get.msg("data/spam/00001.7848dde101aa985090474a91ec93fcf0")
[1] "<!DOCTYPE HTML PUBLIC \"-//W3C//DTD HTML 4.0 Transitio...

sapply を使って、data/spam 内のスパムメールから本文を読み込みます。

> spam.docs <- dir(spam.path)
> spam.docs <- spam.docs[which(spam.docs != "cmds")]
> all.spam <- sapply(spam.docs,function(p) get.msg(file.path(spam.path, p)))
> head(all.spam)

単語文書行列を作る

単語を行、文書を列として、文書中の単語の出現数をカウントした、単語文書行列(TDM:Term Document Matrix)を作ります。

↓こんなイメージ

a.txt b.txt c.txt
hello 2 1 0
world 0 2 3


まずは、TDMを作成する関数を定義。

> get.tdm <- function(doc.vec){
  control <- list(stopwords = TRUE,
                  removePunctuation = TRUE,
                  removeNumbers = TRUE,
                  minDocFreq = 2)
  doc.corpus <- Corpus(VectorSource(doc.vec))
  doc.dtm <- TermDocumentMatrix(doc.corpus, control)
  return(doc.dtm)
}

all.spam から、TDMを作成します。

> spam.tdm <- get.tdm(all.spam)

スパムの訓練データを作る

スパムに含まれる単語の出現確率を集計します。

# all.spamのTDMをRの行列に変換
> spam.matrix <- as.matrix(spam.tdm)
# 各単語の、全スパム中での出現頻度をカウント
> spam.counts <- rowSums(spam.matrix)
# データフレームに変換
> spam.df <- data.frame(cbind(names(spam.counts),
                      as.numeric(spam.counts)),
                      stringsAsFactors = FALSE)
> names(spam.df) <- c("term", "frequency")
# 出現確率を算出
> spam.df$frequency <- as.numeric(spam.df$frequency)
> spam.occurrence <- sapply(1:nrow(spam.matrix), function(i) {
  length(which(spam.matrix[i, ] > 0)) / ncol(spam.matrix)
})
> spam.density <- spam.df$frequency / sum(spam.df$frequency)
> spam.df <- transform(spam.df,
                     density = spam.density,
                     occurrence = spam.occurrence)

訓練データの中身を確認。

> head(spam.df[with(spam.df, order(-occurrence)),])
        term frequency     density occurrence
7693   email       813 0.005855914      0.566
18706 please       425 0.003061210      0.508
14623   list       409 0.002945964      0.444
27202   will       828 0.005963957      0.422
2970    body       379 0.002729879      0.408
9369    free       543 0.003911146      0.390

スパムメール中の56%が email を含んでいます。

非スパムの訓練データを作る

スパムの訓練データを作ったのと同じ手順で、非スパムの訓練データも作ります。

> easyham.docs <- dir(easyham.path)
> easyham.docs <- easyham.docs[which(easyham.docs != "cmds")]
# 最初の分類機を作る際に、各メッセージがスパムである確率とそうでない確率を等しいと仮定するため、
# 訓練データもスパムの文書数と同じ数だけに限定する。
> all.easyham <- sapply(easyham.docs[1:length(spam.docs)],
                      function(p) get.msg(file.path(easyham.path, p)))
> easyham.tdm <- get.tdm(all.easyham)
> easyham.matrix <- as.matrix(easyham.tdm)
> easyham.counts <- rowSums(easyham.matrix)
> easyham.df <- data.frame(cbind(names(easyham.counts),
                         as.numeric(easyham.counts)),
                         stringsAsFactors = FALSE)
> names(easyham.df) <- c("term", "frequency")
> easyham.df$frequency <- as.numeric(easyham.df$frequency)
> easyham.occurrence <- sapply(1:nrow(easyham.matrix), function(i) {
 length(which(easyham.matrix[i, ] > 0)) / ncol(easyham.matrix)
})
> easyham.density <- easyham.df$frequency / sum(easyham.df$frequency)
> easyham.df <- transform(easyham.df,
                        density = easyham.density,
                        occurrence = easyham.occurrence)

訓練データの中身を確認。

> head(easyham.df[with(easyham.df, order(-occurrence)),])
       term frequency     density occurrence
5193  group       232 0.003317366      0.388
12877   use       271 0.003875027      0.378
13511 wrote       237 0.003388861      0.378
1629    can       348 0.004976049      0.368
7244   list       248 0.003546150      0.368
8581    one       356 0.005090441      0.336

分類器を作る

メールファイルを受け取り、それがスパム、または、非スパムである確率を計算する分類器を定義します。

> classify.email <- function(path, training.df, prior = 0.5, c = 1e-6) {

  # ファイルから本文を取り出し、単語の出現数をカウント
  msg <- get.msg(path)
  msg.tdm <- get.tdm(msg)
  msg.freq <- rowSums(as.matrix(msg.tdm))

  # メール内の単語のうち、訓練データに含まれる単語を取得
  msg.match <- intersect(names(msg.freq), training.df$term)
  
  # 単語の出現確率を掛け合わせて、条件付き確率を算出する
  # この時、訓練データに含まれない単語は、非常に小さい確率0.0001として扱う。
  if(length(msg.match) < 1) {
    # 訓練データに含まれる単語が1つもない場合、条件付き確率は、
    # 事前確率(prior) * (0.0001の単語数乗)
    # となる。
    return(prior * c ^ (length(msg.freq)))
  } else {
    # 訓練データから、単語ごとの出現確率を取り出す。
    match.probs <- training.df$occurrence[match(msg.match, training.df$term)]
    # 条件付き確率を計算
    # 事前確率(prior) * (訓練データに含まれる単語の出現確率の積) * (0.0001の訓練データに含まれない単語数乗)
    return(prior * prod(match.probs) * c ^ (length(msg.freq) - length(msg.match)))
  }
}

判定してみる

非スパム(難)のデータを判定してみます。

> hardham.docs <- dir(hardham.path)
> hardham.docs <- hardham.docs[which(hardham.docs != "cmds")]
> hardham.spamtest <- sapply(hardham.docs,
                           function(p) classify.email(file.path(hardham.path, p), training.df = spam.df))
> hardham.hamtest <- sapply(hardham.docs,
                          function(p) classify.email(file.path(hardham.path, p), training.df = easyham.df))
> hardham.res <- ifelse(hardham.spamtest > hardham.hamtest,
                      TRUE,
                      FALSE)
> summary(hardham.res)
   Mode   FALSE    TRUE    NA's 
logical     243       6       0 

データはすべて非スパムのものなので、誤判定は6件、偽陽性率は2.4%。

テスト用データすべてを判定してみる

スパム判定を行う関数を作成。

spam.classifier <- function(path) {
  pr.spam <- classify.email(path, spam.df)
  pr.ham <- classify.email(path, easyham.df)
  return(c(pr.spam, pr.ham, ifelse(pr.spam > pr.ham, 1, 0)))
}

テスト用のメールデータをすべて判定してみます。

> easyham2.docs <- dir(easyham2.path)
> easyham2.docs <- easyham2.docs[which(easyham2.docs != "cmds")]

> hardham2.docs <- dir(hardham2.path)
> hardham2.docs <- hardham2.docs[which(hardham2.docs != "cmds")]

> spam2.docs <- dir(spam2.path)
> spam2.docs <- spam2.docs[which(spam2.docs != "cmds")]

> easyham2.class <- suppressWarnings(lapply(easyham2.docs, function(p) {
  spam.classifier(file.path(easyham2.path, p))
}))
> hardham2.class <- suppressWarnings(lapply(hardham2.docs, function(p) {
  spam.classifier(file.path(hardham2.path, p))
}))
> spam2.class <- suppressWarnings(lapply(spam2.docs, function(p) {
  spam.classifier(file.path(spam2.path, p))
}))

> easyham2.matrix <- do.call(rbind, easyham2.class)
> easyham2.final <- cbind(easyham2.matrix, "EASYHAM")

> hardham2.matrix <- do.call(rbind, hardham2.class)
> hardham2.final <- cbind(hardham2.matrix, "HARDHAM")

> spam2.matrix <- do.call(rbind, spam2.class)
> spam2.final <- cbind(spam2.matrix, "SPAM")

> class.matrix <- rbind(easyham2.final, hardham2.final, spam2.final)
> class.df <- data.frame(class.matrix, stringsAsFactors = FALSE)
> names(class.df) <- c("Pr.SPAM" ,"Pr.HAM", "Class", "Type")
> class.df$Pr.SPAM <- as.numeric(class.df$Pr.SPAM)
> class.df$Pr.HAM <- as.numeric(class.df$Pr.HAM)
> class.df$Class <- as.logical(as.numeric(class.df$Class))
> class.df$Type <- as.factor(class.df$Type)

文書ごとのスパム/非スパムの尤もらしさ、分類結果、メールの種別を含むデータフレームができました。

> head(class.df)
        Pr.SPAM        Pr.HAM Class    Type
1  0.000000e+00  0.000000e+00 FALSE EASYHAM
2 5.352364e-248 1.159512e-155 FALSE EASYHAM
3  0.000000e+00 5.103377e-216 FALSE EASYHAM
4  0.000000e+00  0.000000e+00 FALSE EASYHAM
5 2.083521e-169 1.221918e-108 FALSE EASYHAM
6  0.000000e+00  0.000000e+00 FALSE EASYHAM

文書の種類ごとに結果を集計してみます。

> get.results <- function(bool.vector) {
  results <- c(length(bool.vector[which(bool.vector == FALSE)]) / length(bool.vector),
               length(bool.vector[which(bool.vector == TRUE)]) / length(bool.vector))
  return(results)
}
> easyham2.col <- get.results(subset(class.df, Type == "EASYHAM")$Class)
> hardham2.col <- get.results(subset(class.df, Type == "HARDHAM")$Class)
> spam2.col <- get.results(subset(class.df, Type == "SPAM")$Class)

> class.res <- rbind(easyham2.col, hardham2.col, spam2.col)
> colnames(class.res) <- c("NOT SPAM", "SPAM")
> print(class.res)
              NOT SPAM       SPAM
easyham2.col 0.9871429 0.01285714
hardham2.col 0.9677419 0.03225806
spam2.col    0.4631353 0.53686471

非スパム(easyham2, hardham2)を間違ってスパムと判定する確率(偽陽性率)はそれぞれ1%,3%と低く、うまく分類できているもよう。 ただ、スパム(spam2)を間違ってスパムでないと判定する確率(偽陰性率)は46%。あれ、ちょっと高いかも?

結果を分散図にしてみる

> class.plot <- ggplot(class.df, aes(x = log(Pr.HAM), log(Pr.SPAM))) +
    geom_point(aes(shape = Type, alpha = 0.5)) +
    geom_abline(intercept = 0, slope = 1) +
    scale_shape_manual(values = c("EASYHAM" = 1,
                                  "HARDHAM" = 2,
                                  "SPAM" = 3),
                       name = "Email Type") +
    scale_alpha(guide = "none") +
    xlab("log[Pr(HAM)]") +
    ylab("log[Pr(SPAM)]") +
    theme_bw() +
    theme(axis.text.x = element_blank(), axis.text.y = element_blank())
> ggsave(plot = class.plot,
       filename = file.path("images", "03_final_classification.png"),
       height = 10,
       width = 10)

f:id:unageanu:20160111152705p:plain

横軸が「メールを非スパムと仮定した時の尤もらしさ」、縦軸が「メールをスパムと仮定した時の尤もらしさ」を示します。

「メールをスパムと仮定した時の尤もらしさ」 > 「メールを非スパムと仮定した時の尤もらしさ」となったメールをスパムと判定するので、真ん中の線より上の者はスパム、下は非スパムと判定されています。線より上に〇や△(非スパムのメール)がいくつかあったりはしますが、おおむね正しく判定できている感じですね。

事前分布を変えて、結果を改善する

↑では、とあるメールがあった時にそれがスパムである確率は50%(=世の中のメールの半分はスパムで半分はそうではない)と仮定していました。 現実には、80%は非スパム、残り20%がスパムなので、これを考慮して再計算することで結果を改善してみます。

spam.classifier を修正して、事前確率を変更。

spam.classifier <- function(path) {
  pr.spam <- classify.email(path, spam.df, prior=0.2)
  pr.ham <- classify.email(path, easyham.df, prior=0.8)
  return(c(pr.spam, pr.ham, ifelse(pr.spam > pr.ham, 1, 0)))
}

再計算した結果は以下。

> print(class.res)
              NOT SPAM       SPAM
easyham2.col 0.9892857 0.01071429
hardham2.col 0.9717742 0.02822581
spam2.col    0.4652827 0.53471725

非スパム(easyham2, hardham2)を間違ってスパムと判定する確率は少し改善しました。 一方、スパム(spam2)を間違ってスパムでないと判定する確率は悪化。。。

機械学習手習い: 数値によるデータの要約と可視化手法

「入門 機械学習」手習い、今日は「2章 データの調査」です。

www.amazon.co.jp

数値によるデータの要約と、可視化手法を学びます。

テスト用データの読み込み

> setwd("02-Exploration/")
> data.file <- file.path('data', '01_heights_weights_genders.csv')
> heights.weights <- read.csv(data.file, header = TRUE, sep = ',') 
> head(heights.weights)
  Gender   Height   Weight
1   Male 73.84702 241.8936
2   Male 68.78190 162.3105
3   Male 74.11011 212.7409
4   Male 71.73098 220.0425
5   Male 69.88180 206.3498
6   Male 67.25302 152.2122

データの数値による要約

summaryでベクトルの数値を要約します

> summary(heights.weights$Height)
   Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
  54.26   63.51   66.32   66.37   69.17   79.00 

左から、

  • Min .. 最小値
  • 1st Qu .. 第一四分位(データ全体の下から25%の位置にあたる値)
  • Median .. 中央値(データ全体の50%の位置にあたる値)
  • Mean .. 平均値
  • 3rd Qu. .. (データ全体の下から75%の位置にあたる値)
  • Max .. 最大値

が表示されます。

最小値、最大値を求める

min/maxを使って、最小値/最大値を算出できます

# Heightだけを含むベクトルを作成
> heights <- with(heights.weights, Height)
> head(heights)
[1] 73.84702 68.78190 74.11011 71.73098 69.88180 67.25302
> min(heights)
[1] 54.26313
> max(heights)
[1] 78.99874

rangeで、両方をまとめて計算することもできます。

> range(heights)
[1] 54.26313 78.99874

分位数を求める

quantile で、データ中の各位置のデータを出力できます。

> quantile(heights)
      0%      25%      50%      75%     100% 
54.26313 63.50562 66.31807 69.17426 78.99874 

分割幅を指定することもできます。

> quantile(heights, probs = seq(0, 1, by = 0.20))
      0%      20%      40%      60%      80%     100% 
54.26313 62.85901 65.19422 67.43537 69.81162 78.99874 

分散と標準偏差を求める

var,sd を使います。

# 標準偏差
> var(heights)
[1] 14.80347
# 分散
> sd(heights)
[1] 3.847528

データの可視化

必要なライブラリを読み込み。

> library('ggplot2')

ヒストグラム

> plot = ggplot(heights.weights, aes(x = Height)) + geom_histogram(binwidth = 1)
> ggsave(plot = plot, filename = "histgram.png", width = 6, height = 8)

f:id:unageanu:20160110141407p:plain

密度プロットにしてみます。少ないデータ量でも、データセットの形状が分かりやすいのがメリット。

> plot = ggplot(heights.weights, aes(x = Height)) + geom_density()
> ggsave(plot = plot, filename = "kde_histgram.png", width = 6, height = 8)

f:id:unageanu:20160110141408p:plain

性別ごとの特徴をみるため、性別ごとのヒストグラムを表示してみます。

> plot = ggplot(heights.weights, aes(x = Height, fill = Gender)) + geom_density() + facet_grid(Gender ~ .)
> ggsave(plot = plot, filename = "gender_histgram.png", width = 6, height = 8)

f:id:unageanu:20160110141405p:plain

ヒストグラムの分類を整理。詳しくはWikipediaで。

  • 正規分布
    • ピーク(=最頻値)が1つしかない、単峰分布
    • 左右が対称
    • 裾が薄い(データのばらつきが小さい)
  • コーシー分布
    • ピーク(=最頻値)が1つしかない、単峰分布
    • 左右が対称
    • 裾が厚い(データのばらつきが大きい)
  • ガンマ分布
    • 左右が非対称で、平均値と中央値が大きく異なる
  • 指数分布
    • 左右が非対称で、最頻値がゼロ。

正規分布の例。

> set.seed(1)

> normal.values <- rnorm(250, 0, 1)
> plot = ggplot(data.frame(X = normal.values), aes(x = X)) + geom_density()
> ggsave(plot = plot, filename = "normal_histgram.png", width = 6, height = 8)

f:id:unageanu:20160110141409p:plain

コーシー分布。

> cauchy.values <- rcauchy(250, 0, 1)
> plot = ggplot(data.frame(X = cauchy.values), aes(x = X)) + geom_density()
> ggsave(plot = plot, filename = "cauchy_histgram.png", width = 6, height = 8)

f:id:unageanu:20160110141403p:plain

ガンマ分布。

> gamma.values <- rgamma(100000, 1, 0.001)
> plot = ggplot(data.frame(X = gamma.values), aes(x = X)) + geom_density()
> ggsave(plot = plot, filename = "gamma_histgram.png", width = 6, height = 8)

f:id:unageanu:20160110141404p:plain

指数分布、はない。

散布図

身長と体重の散布図を描きます。

> plot = ggplot(heights.weights, aes(x = Height, y = Weight)) + geom_point()
> ggsave(plot = plot, filename = "scatterplots.png", width = 6, height = 8)

f:id:unageanu:20160110141410p:plain

身長、体重には相関関係がありそう。geom_smooth()を使って、妥当な予測領域を表示してみます。

> plot = ggplot(heights.weights, aes(x = Height, y = Weight)) + geom_point() + geom_smooth()
> ggsave(plot = plot, filename = "scatterplots2.png", width = 6, height = 8)

f:id:unageanu:20160110141411p:plain

最後に、男女別の散布図を描いて終わり。

> plot = ggplot(heights.weights, aes(x = Height, y = Weight, color = Gender)) + geom_point()
> ggsave(plot = plot, filename = "gender_scatterplots.png", width = 6, height = 8)

f:id:unageanu:20160110141406p:plain

機械学習手習い : Rをインストールして、基本的な使い方を学ぶ

オライリーの「入門 機械学習」を手に入れたので、手を動かしながら学びます。

www.amazon.co.jp

まずは、1章。Rのインストールと基本的な使い方の学習まで。

Rのインストール

手元にあったCentOS7にインストールしました。

$ cat /etc/redhat-release
CentOS Linux release 7.1.1503 (Core) 
$ sudo yum install epel-release
$ sudo yum install R
$ R

R version 3.2.3 (2015-12-10) -- "Wooden Christmas-Tree"
Copyright (C) 2015 The R Foundation for Statistical Computing
Platform: x86_64-redhat-linux-gnu (64-bit)
...

サンプルコードのダウンロード

次に、GitHubで公開されている「入門 機械学習」のサンプルコードを取得します。

$ cd ~
$ git clone https://github.com/johnmyleswhite/ML_for_Hackers.git ml_for_hackers
$ cd  ml_for_hackers

必要モジュールのインストール

サンプルコードを動かす時に必要なパッケージをインストール。

$ R
> source("package_installer.R")

ユーザー権限で実行すると、デフォルトのインストール先に書き込み権がないとのこと。y を押して、ホームディレクトリにインストール。 そこそこ時間がかかるので待つ・・・。

・・・いくつか、エラーになりました。

1:  install.packages(p, dependencies = TRUE, type = "source"): 
  installation of package ‘rgl’ had non-zero exit status
2:  install.packages(p, dependencies = TRUE, type = "source"): 
  installation of package ‘XML’ had non-zero exit status
3:  install.packages(p, dependencies = TRUE, type = "source"): 
  installation of package ‘RCurl’ had non-zero exit status
4:  install.packages(p, dependencies = TRUE, type = "source"): 
  installation of package ‘Rpoppler’ had non-zero exit status
5:  install.packages(p, dependencies = TRUE, type = "source"): 
  installation of package ‘XML’ had non-zero exit status
6:  install.packages(p, dependencies = TRUE, type = "source"): 
  installation of package ‘RCurl’ had non-zero exit status
7:  install.packages(p, dependencies = TRUE, type = "source"): 
  installation of package ‘XML’ had non-zero exit status

個別にインストールしてみると、依存モジュールがないのが原因のよう。

install.packages("XML", dependencies=TRUE)

Cannot find xml2-config
ERROR: configuration failed for package ‘XML’
* removing ‘/home/yamautim/R/x86_64-redhat-linux-gnu-library/3.2/XML’

必要なモジュールをインストールして、

$ sudo yum -y install libxml2-devel curl-devel poppler-glib-devel freeglut-devel

再試行。

> install.packages("rgl", dependencies=TRUE)
> install.packages("Rpoppler", dependencies=TRUE)
> install.packages("XML", dependencies=TRUE)
> install.packages("RCurl", dependencies=TRUE)

基礎練習で使うライブラリとデータの読み込み

基礎練習で使うライブラリとデータ(UFOの目撃情報データ)を読み込みます。

> setwd("01-Introduction/")
> library("ggplot2")
> library(plyr)
> library(scales)
> ufo <- read.delim("data/ufo/ufo_awesome.tsv", sep="\t", stringsAsFactors=FALSE, header=FALSE, na.strings="")

先頭、末尾の6行を表示。

> head(ufo)                                    
        V1       V2                    V3   V4      V5
1 19951009 19951009         Iowa City, IA <NA>    <NA>
2 19951010 19951011         Milwaukee, WI <NA>  2 min.
3 19950101 19950103           Shelton, WA <NA>    <NA>
...
> tail(ufo)
            V1       V2                   V3        V4                V5
61865 20100828 20100828      Los Angeles, CA      disk        40 seconds
61866 20090424 20100820         Hartwell, GA      oval            10 min
61867 20100821 20100826  Franklin Square, NY  fireball        20 minutes
61868 20100827 20100827         Brighton, CO    circle    at lest 45 min
61869 20100818 20100821  Dryden (Canada), ON     other 5 Min. maybe more
61870 20050502 20100824        Fort Knox, KY  triangle        15 seconds

データをクリーニングする

読み込んだデータを、処理しやすい形に変換します。

データ列に名前をつける

> names(ufo) <- c("DateOccurred", "DateReported", "Location", "ShortDescription",
"Duration", "LongDescription")
> head(ufo)
  DateOccurred DateReported              Location ShortDescription Duration
1     19951009     19951009         Iowa City, IA             <NA>     <NA>
2     19951010     19951011         Milwaukee, WI             <NA>   2 min.
3     19950101     19950103           Shelton, WA             <NA>     <NA>
4     19950510     19950510          Columbia, MO             <NA>   2 min.
5     19950611     19950614           Seattle, WA             <NA>     <NA>
6     19951025     19951024  Brunswick County, ND             <NA>  30 min.

V1 などとなっていたところに、DateOccurredのようなわかりやすい名前が付きました。

DateOccurred,DateReported を日付に変換する

DateOccurred 列は、日付を示す文字列なので、Dateオブジェクトに変換します。

> ufo$DateOccurred <- as.Date(ufo$DateOccurred, format = "%Y%m%d")
 strptime(x, format, tz = "GMT") でエラー:  入力文字列が長すぎます 

エラーになりました。長い文字列を含む列があるよう。探します。

> head(ufo[which(nchar(ufo$DateOccurred)!=8|nchar(ufo$DateReported)!=8),1])
[1] "ler@gnv.ifas.ufl.edu" 
[2] "0000"
[3] "Callers report sighting a number of soft white  balls of lights headingin an easterly directing then changing direction to the west beforespeeding off to the north west."
[4] "0000"
[5] "0000"
[6] "0000"

変なデータがありますな。ということで、8文字でない列を削除します。

# 行のDateOccurred または、DateReported が8文字かどうかを格納するデータ(good.rows)を作成
> good.rows <- ifelse(nchar(ufo$DateOccurred) != 8 |
 nchar(ufo$DateReported) != 8,
 FALSE, TRUE)
# FALSE(=DateOccurred または、DateReported が8文字でない行)の数を確認。
> length(which(!good.rows))
[1] 731
# good.rowsがTRUEの行のみ抽出
> ufo <- ufo[good.rows, ]    

変換を再試行。今度は、日付型に変換できました。

> ufo$DateOccurred <- as.Date(ufo$DateOccurred, format = "%Y%m%d")
> ufo$DateReported <- as.Date(ufo$DateReported, format = "%Y%m%d")
> head(ufo)
  DateOccurred DateReported              Location ShortDescription Duration
1   1995-10-09   1995-10-09         Iowa City, IA             <NA>     <NA>
2   1995-10-10   1995-10-11         Milwaukee, WI             <NA>   2 min.
3   1995-01-01   1995-01-03           Shelton, WA             <NA>     <NA>

関数を作ってLocation列のデータを都市名と州名に分割する

Location列のデータは、Iowa City, IA のような「都市名, 州名」形式の文字列になっています。 関数を使ってこれを都市名, 州名に分解します。

まずは、分解を行う関数を作成。

> get.location <- function(l) {
  split.location <- tryCatch(strsplit(l, ",")[[1]], error = function(e) return(c(NA, NA)))
  clean.location <- gsub("^ ","",split.location)
  if (length(clean.location) > 2) {
    return(c(NA,NA))
  } else {
    return(clean.location)
  }
}

lapplyを使って、関数を適用したデータのリストを作成。

> city.state <- lapply(ufo$Location, get.location)
> head(city.state)
[[1]]
[1] "Iowa City" "IA"       

[[2]]
[1] "Milwaukee" "WI"   

これを、ufoに追加します。まずは、do.callでリストを行列に変換。

> location.matrix <- do.call(rbind, city.state)
> head(location.matrix)
     [,1]               [,2]
[1,] "Iowa City"        "IA"
[2,] "Milwaukee"        "WI"
[3,] "Shelton"          "WA"
[4,] "Columbia"         "MO"
[5,] "Seattle"          "WA"
[6,] "Brunswick County" "ND"

transformufoに追加します。

> ufo <- transform(ufo, USCity = location.matrix[, 1], USState = location.matrix[, 2], stringsAsFactors = FALSE)

また、データには、カナダのものが含まれているので、これも除去します。 USStateにアメリカの州名以外が入っているデータを削除。

> ufo$USState <- state.abb[match(ufo$USState, state.abb)]
> ufo.us <- subset(ufo, !is.na(USState))

これで、データのクリーニングは完了。

> summary(ufo.us)                                                                                                                                                        
  DateOccurred         DateReported          Location        
 Min.   :1400-06-30   Min.   :1905-06-23   Length:51636      
 1st Qu.:1999-09-07   1st Qu.:2002-04-14   Class :character  
 Median :2004-01-10   Median :2005-03-27   Mode  :character  
 Mean   :2001-02-13   Mean   :2004-11-30                     
 3rd Qu.:2007-07-27   3rd Qu.:2008-01-20                     
 Max.   :2010-08-30   Max.   :2010-08-30                     
 ShortDescription     Duration         LongDescription       USCity         
 Length:51636       Length:51636       Length:51636       Length:51636      
 Class :character   Class :character   Class :character   Class :character  
 Mode  :character   Mode  :character   Mode  :character   Mode  :character  
                                                                            
                                                                            
                                                                            
   USState         
 Length:51636      
 Class :character  
 Mode  :character  
                   
                   
                   
> head(ufo.us)
  DateOccurred DateReported              Location ShortDescription Duration
1   1995-10-09   1995-10-09         Iowa City, IA             <NA>     <NA>
2   1995-10-10   1995-10-11         Milwaukee, WI             <NA>   2 min.
3   1995-01-01   1995-01-03           Shelton, WA             <NA>     <NA>
4   1995-05-10   1995-05-10          Columbia, MO             <NA>   2 min.
5   1995-06-11   1995-06-14           Seattle, WA             <NA>     <NA>
6   1995-10-25   1995-10-24  Brunswick County, ND             <NA>  30 min.

データを分析する

クリーニングしたデータを分析して、州/月ごとの目撃情報の傾向を分析します。

DateOccurredのばらつきを調べる

> summary(ufo.us$DateOccurred) 
        Min.      1st Qu.       Median         Mean      3rd Qu.         Max.         NA's 
"1400-06-30" "1999-09-07" "2004-01-10" "2001-02-13" "2007-07-27" "2010-08-30"          "1" 

1400年のデータが含まれている・・・。ヒストグラムで、分布をみてみます。

> quick.hist <- ggplot(ufo.us, aes(x = DateOccurred)) +
  geom_histogram() + scale_x_date(date_breaks = "50 years", date_labels = "%Y")
> ggsave(plot = quick.hist,
 filename = file.path("images", "quick_hist.png"),
 height = 6,
 width = 8)

f:id:unageanu:20160109175440p:plain

大部分は最近の20年に集中している模様。この範囲に絞って分析するため、古いデータを取り除きます。

> ufo.us <- subset(ufo.us, DateOccurred >= as.Date("1990-01-01"))

データを州/月ごとに集計する

まずは、DateOccurred 列を「年-月」に変換した列を作成。

> ufo.us$YearMonth <- strftime(ufo.us$DateOccurred, format = "%Y-%m")

次に、月/州ごとにデータをグループ化して、データ数を集計します。

> sightings.counts <- ddply(ufo.us, .(USState,YearMonth), nrow)
> head(sightings.counts)
  USState YearMonth V1
1      AK   1990-01  1
2      AK   1990-03  1
3      AK   1990-05  1
4      AK   1993-11  1
5      AK   1994-11  1
6      AK   1995-01  1

これだけだと、データが一つもない月が集計結果に含まれないので、それを補完します。 seq.Dateを使ってシーケンスを作成。

> date.range <- seq.Date(from = as.Date(min(ufo.us$DateOccurred)),
                         to = as.Date(max(ufo.us$DateOccurred)),
                         by = "month")
> date.strings <- strftime(date.range, "%Y-%m")

作成したシーケンスに、州の一覧を掛け合わせて、州/月の行列を作成します。

> states.dates <- lapply(state.abb, function(s) cbind(s, date.strings))
> states.dates <- data.frame(do.call(rbind, states.dates),
 stringsAsFactors = FALSE)
> head(states.dates)
   s date.strings
1 AL      1990-01
2 AL      1990-02
3 AL      1990-03
4 AL      1990-04
5 AL      1990-05
6 AL      1990-06

さらに、sightings.counts をマージして、月の欠落がない集計データを作成。

> all.sightings <- merge(states.dates,
 sightings.counts,
 by.x = c("s", "date.strings"),
 by.y = c("USState", "YearMonth"),
 all = TRUE)
> head(all.sightings)                                                                                                                                 
   s date.strings V1
1 AK      1990-01  1
2 AK      1990-02 NA
3 AK      1990-03  1
4 AK      1990-04 NA
5 AK      1990-05  1
6 AK      1990-06 NA

わかりやすいよう、列に名前を付けます。また、集計しやすいようにNA0に変換するなどの操作を行っておきます。

> names(all.sightings) <- c("State", "YearMonth", "Sightings")
> all.sightings$Sightings[is.na(all.sightings$Sightings)] <- 0
> all.sightings$YearMonth <- as.Date(rep(date.range, length(state.abb)))
> all.sightings$State <- as.factor(all.sightings$State)
> head(all.sightings)
  State  YearMonth Sightings
1    AK 1990-01-01         1
2    AK 1990-02-01         0
3    AK 1990-03-01         1
4    AK 1990-04-01         0
5    AK 1990-05-01         1
6    AK 1990-06-01         0

分析用データはこれで完成。

月/州ごとの目撃情報数をグラフにして分析する

> state.plot <- ggplot(all.sightings, aes(x = YearMonth,y = Sightings)) +
  geom_line(aes(color = "darkblue")) +
  facet_wrap(~State, nrow = 10, ncol = 5) + 
  theme_bw() + 
  scale_color_manual(values = c("darkblue" = "darkblue"), guide = "none") +
  scale_x_date(date_breaks = "5 years", date_labels = '%Y') +
  xlab("Years") +
  ylab("Number of Sightings") +
  ggtitle("Number of UFO sightings by Month-Year and U.S. State (1990-2010)")
> ggsave(plot = state.plot,
       filename = file.path("images", "ufo_sightings.png"),
       width = 14,
       height = 8.5)

こんなグラフになります。

f:id:unageanu:20160109175441p:plain

分析は省略。