機械学習手習い: 回帰分析
「入門 機械学習」手習い、5日目。「5章 回帰:ページビューの予測」です。
回帰分析を使って、Webサイトのページビューを予測します。
平均値による単純な予測と、予測精度の計測
回帰分析のイメージをつかむために、平均値を使った単純な予測を試します。 平均値を使って寿命を予測し、平均二乗誤差を使って精度を測定。そこに、追加の情報(喫煙者かどうか)を加えて精度を向上します。
まずは、依存ライブラリの読み込みから。
> setwd("05-Regression/") > library('ggplot2')
寿命データを読み込む
> ages <- read.csv(file.path('data', 'longevity.csv')) > head(ages) Smokes AgeAtDeath 1 1 75 2 1 72 3 1 66 4 1 74 5 1 69 6 1 65
ヒストグラムにしてみます。
> plot = ggplot(ages, aes(x = AgeAtDeath, fill = factor(Smokes))) + geom_density() + facet_grid(Smokes ~ .) > ggsave(plot = plot, filename = file.path("images", "longevity_plot.png"), height = 4.8, width = 7)
非喫煙者のピークが喫煙者より右に寄っていて、やはり、非喫煙者の方が全体的に長生きな感じ。
平均を使った予測とその精度を測る
データの概要がつかめたところで、平均値を使った予測をしてみます。 まずは、平均値を算出。
> mean(ages$AgeAtDeath) [1] 72.723
平均2乗誤差を計算。
> guess <- 72.723 > with(ages, mean((AgeAtDeath - guess) ^ 2)) [1] 32.91427
32.91427
となりました。
喫煙者かどうかを考慮して、精度を向上する
喫煙者/非喫煙者それぞれで平均値を計測し、推測値として使います。
> smokers.guess <- with(subset(ages, Smokes == 1), mean(AgeAtDeath)) > non.smokers.guess <- with(subset(ages, Smokes == 0), mean(AgeAtDeath)) > ages <- transform(ages, NewPrediction = ifelse(Smokes == 0, non.smokers.guess, smokers.guess)) > with(ages, mean((AgeAtDeath - NewPrediction) ^ 2)) [1] 26.50831
誤差が減り、精度が向上しました。
線形回帰入門
線形回帰で予測を行うには、次の2つの仮定を置くことが前提となります。
可分性/加法性
単調性/線形性
- 特徴量に応じて、予測値が常に上昇、または減少すること。
これを踏まえて、実習を開始。身長/体重データを読み込んで、簡単な予測を試します。
# データを読み込み > heights.weights <- read.csv( file.path('data', '01_heights_weights_genders.csv'), header = TRUE, sep = ',')
散布図を描いてみます。 geom_smooth(method = 'lm')
で回帰線も引きます。
# 散布図を描く > plot <- ggplot(heights.weights, aes(x = Height, y = Weight)) + geom_point() + geom_smooth(method = 'lm') > ggsave(plot = plot, filename = file.path("images", "height_weight.png"), height = 4.8, width = 7)
青の線が身長から体重を予測する回帰線です。 身長60インチの人は、体重105ポンドくらいと予測できます。大体、妥当な感じ。
Rでは、lm
関数で、線形モデルでの回帰分析を行うことができます。
> fitted.regression <- lm(Weight ~ Height, data = heights.weights)
coef
関数を使うと、傾き( Height
)と切片( Intercept
)を確認できます。
> coef(fitted.regression) (Intercept) Height -350.737192 7.717288
predict
で予測値が得られます。
> head(predict(fitted.regression)) 1 2 3 4 5 6 219.1615 180.0725 221.1918 202.8314 188.5607 168.2737
residuals
を使うと、実際のデータとの誤差を計算できます。
head(residuals(fitted.regression)) 1 2 3 4 5 6 22.732083 -17.762074 -8.450953 17.211069 17.789073 -16.061519
ウェブサイトのアクセス数を予測する
まずはデータを読み込み。
> top.1000.sites <- read.csv(file.path('data', 'top_1000_sites.tsv'), sep = '\t', stringsAsFactors = FALSE) > head(top.1000.sites) Rank Site Category UniqueVisitors Reach 1 1 facebook.com Social Networks 880000000 47.2 2 2 youtube.com Online Video 800000000 42.7 3 3 yahoo.com Web Portals 660000000 35.3 4 4 live.com Search Engines 550000000 29.3 5 5 wikipedia.org Dictionaries & Encyclopedias 490000000 26.2 6 6 msn.com Web Portals 450000000 24.0 PageViews HasAdvertising InEnglish TLD 1 9.1e+11 Yes Yes com 2 1.0e+11 Yes Yes com 3 7.7e+10 Yes Yes com 4 3.6e+10 Yes Yes com 5 7.0e+09 No Yes org 6 1.5e+10 Yes Yes com
このデータから、PageView
を推測するのが今回の目的。
まずは、線形回帰が適用できるか、確認。ページビューの分布をみてみます。
> plot = ggplot(top.1000.sites, aes(x = PageViews)) + geom_density() > ggsave(plot = plot, filename = file.path("images", "top_1000_sites_page_view.png"), height = 4.8, width = 7)
ばらつきが大きすぎてよくわからない・・・。対数を取ってみます。
> plot = ggplot(top.1000.sites, aes(x = log(PageViews))) + geom_density() > ggsave(plot = plot, filename = file.path("images", "top_1000_sites_page_view2.png"), height = 4.8, width = 7)
この尺度なら、傾向が見えそう。UniqueVisitor
と PageView
の散布図を描いてみます。
> ggplot(top.1000.sites, aes(x = log(PageViews), y = log(UniqueVisitors))) + geom_point() > ggsave(file.path("images", "log_page_views_vs_log_visitors.png"))
回帰直線も追加。
> ggplot(top.1000.sites, aes(x = log(PageViews), y = log(UniqueVisitors))) + geom_point() + geom_smooth(method = 'lm', se = FALSE) > ggsave(file.path("images", "log_page_views_vs_log_visitors_with_lm.png"))
うまくいっているっぽい。lm
で回帰分析を実行。
> lm.fit <- lm(log(PageViews) ~ log(UniqueVisitors), data = top.1000.sites)
summary
を使うと、分析結果の要約が表示されます。
> summary(lm.fit) Call: lm(formula = log(PageViews) ~ log(UniqueVisitors), data = top.1000.sites) Residuals: Min 1Q Median 3Q Max -2.1825 -0.7986 -0.0741 0.6467 5.1549 Coefficients: Estimate Std. Error t value Pr(>|t|) (Intercept) -2.83441 0.75201 -3.769 0.000173 *** log(UniqueVisitors) 1.33628 0.04568 29.251 < 2e-16 *** --- Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1 Residual standard error: 1.084 on 998 degrees of freedom Multiple R-squared: 0.4616, Adjusted R-squared: 0.4611 F-statistic: 855.6 on 1 and 998 DF, p-value: < 2.2e-16
このうち、 Residual standard error
が平均2乗誤差の平方根(RMSE)をとったもの。
さらに、他の情報も考慮に加えてみます。
> lm.fit <- lm(log(PageViews) ~ HasAdvertising + log(UniqueVisitors) + InEnglish, data = top.1000.sites) > summary(lm.fit) Call: lm(formula = log(PageViews) ~ HasAdvertising + log(UniqueVisitors) + InEnglish, data = top.1000.sites) Residuals: Min 1Q Median 3Q Max -2.4283 -0.7685 -0.0632 0.6298 5.4133 Coefficients: Estimate Std. Error t value Pr(>|t|) (Intercept) -1.94502 1.14777 -1.695 0.09046 . HasAdvertisingYes 0.30595 0.09170 3.336 0.00088 *** log(UniqueVisitors) 1.26507 0.07053 17.936 < 2e-16 *** InEnglishNo 0.83468 0.20860 4.001 6.77e-05 *** InEnglishYes -0.16913 0.20424 -0.828 0.40780 --- Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1 Residual standard error: 1.067 on 995 degrees of freedom Multiple R-squared: 0.4798, Adjusted R-squared: 0.4777 F-statistic: 229.4 on 4 and 995 DF, p-value: < 2.2e-16
精度が少し上がりました。
機械学習手習い: 重要度による電子メールの並び替え
「入門 機械学習」手習い、4日目。「4章 順位づけ:優先トレイ」です。
電子メールを重要度で順位づけするシステムを作ります
並び替えのアプローチ
以下の素性を使って、メールに優先度をつけます。
- 1) 送信者のメッセージ数
- やり取りが多い送信者からのメールは重要とみなす
- 2) スレッドの活性
- 活発にやり取りされているスレッドのメールは優先度を高くする
- 活性度は、1秒あたりのスレッドのメール数で算出。多いほど、活性が高い
- 3) サブジェクトと本文に含まれる単語
必要なモジュールとデータの読み込み
> setwd("04-Ranking/") > library('tm') > library('ggplot2') > library('plyr') > library('reshape') # メールデータは、3章で使った非スパムメール(易)を使う > data.path <- file.path("..", "03-Classification", "data") > easyham.path <- file.path(data.path, "easy_ham")
メールから素性を取り出す
ファイルを読み込んで、本文を返す関数を作成します。
# メールファイルから、全データを読み込んで返す > msg.full <- function(path) { con <- file(path, open = "rt", encoding = "latin1") msg <- readLines(con) close(con) return(msg) } # メールデータからFromアドレスを取り出す > get.from <- function(msg.vec) { from <- msg.vec[grepl("From: ", msg.vec)] from <- strsplit(from, '[":<> ]')[1] from <- from[which(from != "" & from != " ")] return(from[grepl("@", from)][1]) } # メールデータから、サブジェクトを取り出す > get.subject <- function(msg.vec) { subj <- msg.vec[grepl("Subject: ", msg.vec)] if(length(subj) > 0) { return(strsplit(subj, "Subject: ")[1][2]) } else { return("") } } # メールデータから、本文を取り出す > get.msg <- function(msg.vec) { msg <- msg.vec[seq(which(msg.vec == "")[1] + 1, length(msg.vec), 1)] return(paste(msg, collapse = "\n")) } # メールデータから、受信日時を取り出す > get.date <- function(msg.vec) { date.grep <- grepl("^Date: ", msg.vec) date.grep <- which(date.grep == TRUE) date <- msg.vec[date.grep[1]] date <- strsplit(date, "\\+|\\-|: ")[1][2] date <- gsub("^\\s+|\\s+$", "", date) return(strtrim(date, 25)) } # メールを読み込んで、必要な素性を返す。 > parse.email <- function(path) { full.msg <- msg.full(path) date <- get.date(full.msg) from <- get.from(full.msg) subj <- get.subject(full.msg) msg <- get.msg(full.msg) return(c(date, from, subj, msg, path)) }
動作テスト。
> parse.email("../03-Classification/data/easy_ham/00111.a478af0547f2fd548f7b412df2e71a92") [1] "Mon, 7 Oct 2002 10:37:26" [2] "niall@linux.ie" ...
メールデータを読み込んで素性をデータフレームにまとめる
# 全メールを解析 > easyham.docs <- dir(easyham.path) > easyham.docs <- easyham.docs[which(easyham.docs != "cmds")] > easyham.parse <- lapply(easyham.docs, function(p) parse.email(file.path(easyham.path, p))) # データフレームに変換 > ehparse.matrix <- do.call(rbind, easyham.parse) > allparse.df <- data.frame(ehparse.matrix, stringsAsFactors = FALSE) > names(allparse.df) <- c("Date", "From.EMail", "Subject", "Message", "Path")
できた。
> head(allparse.df) Date From.EMail 1 Thu, 22 Aug 2002 18:26:25 kre@munnari.OZ.AU 2 Thu, 22 Aug 2002 12:46:18 steve.burt@cursor-system.com 3 Thu, 22 Aug 2002 13:52:38 timc@2ubh.com 4 Thu, 22 Aug 2002 09:15:25 monty@roscom.com
データの調整
送信日時が文字列になっているので、POSIXオブジェクトに変換します。
# 日本語環境だと、%b が Aug などの月名にマッチしないため、変更しておく。 > Sys.setlocale(locale="C") > date.converter <- function(dates, pattern1, pattern2) { pattern1.convert <- strptime(dates, pattern1) pattern2.convert <- strptime(dates, pattern2) pattern1.convert[is.na(pattern1.convert)] <- pattern2.convert[is.na(pattern1.convert)] return(pattern1.convert) } > pattern1 <- "%a, %d %b %Y %H:%M:%S" > pattern2 <- "%d %b %Y %H:%M:%S" > allparse.df$Date <- date.converter(allparse.df$Date, pattern1, pattern2) > head(allparse.df) Date From.EMail 1 2002-08-22 18:26:25 kre@munnari.OZ.AU 2 2002-08-22 12:46:18 steve.burt@cursor-system.com 3 2002-08-22 13:52:38 timc@2ubh.com 4 2002-08-22 09:15:25 monty@roscom.com 5 2002-08-22 14:38:22 Stewart.Smith@ee.ed.ac.uk # ロケールを戻しておく。 > Sys.setlocale(local="ja_JP.UTF-8")
また、サブジェクトと送信者アドレスを小文字に変更します。
> allparse.df$Subject <- tolower(allparse.df$Subject) > allparse.df$From.EMail <- tolower(allparse.df$From.EMail)
最後に、送信日時でソート。
> priority.df <- allparse.df[with(allparse.df, order(Date)), ]
データの最初の半分を訓練データに使うので、別の変数に格納しておきます。
> priority.train <- priority.df[1:(round(nrow(priority.df) / 2)), ]
送信者別メール件数での重みづけ
送信者ごとのメール件数で重みづけを行うため、まずは、件数がどんな感じになっているか確認します。
送信者ごとのメール件数を集計。
> from.weight <- melt(with(priority.train, table(From.EMail))) > from.weight <- from.weight[with(from.weight, order(value)), ] > head(from.weight) From.EMail value 1 adam@homeport.org 1 2 admin@networksonline.com 1 4 albert.white@ireland.sun.com 1 5 andr@sandy.ru 1 6 andris@aernet.ru 1 9 antoin@eire.com 1 > summary(from.weight$value) Min. 1st Qu. Median Mean 3rd Qu. Max. 1.00 1.00 2.00 4.63 4.00 55.00
平均は4.63通。最大値は55通でばらつきが大きいかな? 7通以上送信しているアドレスをグラフに表示してみます。
> from.ex <- subset(from.weight, value >= 7) > from.scales <- ggplot(from.ex) + geom_rect(aes(xmin = 1:nrow(from.ex) - 0.5, xmax = 1:nrow(from.ex) + 0.5, ymin = 0, ymax = value, fill = "lightgrey", color = "darkblue")) + scale_x_continuous(breaks = 1:nrow(from.ex), labels = from.ex$From.EMail) + coord_flip() + scale_fill_manual(values = c("lightgrey" = "lightgrey"), guide = "none") + scale_color_manual(values = c("darkblue" = "darkblue"), guide = "none") + ylab("Number of Emails Received (truncated at 6)") + xlab("Sender Address") + theme_bw() + theme(axis.text.y = element_text(size = 5, hjust = 1)) > ggsave(plot = from.scales, filename = file.path("images", "0011_from_scales.png"), height = 4.8, width = 7)
一部の送信者が、平均的な送信者の10倍以上、メールを送信しています。送信数をそのまま重みにしてしまうと、これらの特殊な送信者の優先度が高くなりすぎてしまいます。グラフを見ると、指数関数的に増えている感じなので、自然対数を使って重みを調整します。
# 対数を取る。重みがゼロにならないように、値に1を足す。 > from.weight <- transform(from.weight, Weight = log(value + 1), log10Weight = log10(value + 1))
スレッド活性での重みづけ
まずは、スレッド別のメール件数を集計します。
メールがスレッドに属するかどうかは、メールのサブジェクトを見て判定します。
# re:を除いたサブジェクト(=スレッド名)と送信者を取り出す。 > find.threads <- function(email.df) { response.threads <- strsplit(email.df$Subject, "re: ") is.thread <- sapply(response.threads, function(subj) ifelse(subj[1] == "", TRUE, FALSE)) threads <- response.threads[is.thread] senders <- email.df$From.EMail[is.thread] threads <- sapply(threads, function(t) paste(t[2:length(t)], collapse = "re: ")) return(cbind(senders,threads)) } > threads.matrix <- find.threads(priority.train) > head(threads.matrix) senders threads [1,] "kre@munnari.oz.au" "new sequences window" [2,] "stewart.smith@ee.ed.ac.uk" "[zzzzteana] nothing like mama used to make" [3,] "martin@srv0.ems.ed.ac.uk" "[zzzzteana] nothing like mama used to make" [4,] "stewart.smith@ee.ed.ac.uk" "[zzzzteana] nothing like mama used to make" [5,] "marc@perkel.com" "[sadev] live rule updates after release ???" [6,] "cwg-exmh@deepeddy.com" "new sequences window"
次に、スレッドごとの活性度を集計します。
# スレッドごとの活性度一覧を返す > get.threads <- function(threads.matrix, email.df) { threads <- unique(threads.matrix[, 2]) thread.counts <- lapply(threads, function(t) thread.counts(t, email.df)) thread.matrix <- do.call(rbind, thread.counts) return(cbind(threads, thread.matrix)) } # スレッド名に属するメールの活性度を返す > thread.counts <- function(thread, email.df) { # メールから、スレッドに属するメールの送信日時を取り出す thread.times <- email.df$Date[which(email.df$Subject == thread | email.df$Subject == paste("re:", thread))] freq <- length(thread.times) # スレッドのメールの総数 min.time <- min(thread.times) # 送信日時の最小値 max.time <- max(thread.times) # 送信日時の最大値 time.span <- as.numeric(difftime(max.time, min.time, units = "secs")) if(freq < 2) { # メールが1通しかない場合(返信がなくスレッドになっていない場合)、NAを返す return(c(NA, NA, NA)) } else { trans.weight <- freq / time.span # 1秒当たりのメール送信数 log.trans.weight <- 10 + log(trans.weight, base = 10) # 対数を取る。負にならないよう、10を足す(アフィん変換) return(c(freq, time.span, log.trans.weight)) } } > thread.weights <- data.frame(thread.weights, stringsAsFactors = FALSE) > names(thread.weights) <- c("Thread", "Freq", "Response", "Weight") > thread.weights$Freq <- as.numeric(thread.weights$Freq) > thread.weights$Response <- as.numeric(thread.weights$Response) > thread.weights$Weight <- as.numeric(thread.weights$Weight) > thread.weights <- subset(thread.weights, is.na(thread.weights$Freq) == FALSE) > head(thread.weights) Thread Freq Response Weight 1 please help a newbie compile mplayer :-) 4 42309 5.975627 2 prob. w/ install/uninstall 4 23745 6.226488 3 http://apt.nixia.no/ 10 265303 5.576258 4 problems with 'apt-get -f install' 3 55960 5.729244 5 problems with apt update 2 6347 6.498461 6 about apt, kernel updates and dist-upgrade 5 240238 5.318328
また、送信者での重みづけの補完として、「送信者が何スレッドに参加しているか」を示す重みも計算しておきます。
> email.thread <- function(threads.matrix) { senders <- threads.matrix[, 1] senders.freq <- table(senders) senders.matrix <- cbind(names(senders.freq), senders.freq, log(senders.freq + 1)) senders.df <- data.frame(senders.matrix, stringsAsFactors=FALSE) row.names(senders.df) <- 1:nrow(senders.df) names(senders.df) <- c("From.EMail", "Freq", "Weight") senders.df$Freq <- as.numeric(senders.df$Freq) senders.df$Weight <- as.numeric(senders.df$Weight) return(senders.df) } > senders.df <- email.thread(threads.matrix) > head(senders.df) From.EMail Freq Weight 1 adam@homeport.org 1 0.6931472 2 aeriksson@fastmail.fm 5 1.7917595 3 albert.white@ireland.sun.com 1 0.6931472 4 alex@netwindows.org 1 0.6931472 5 andr@sandy.ru 1 0.6931472 6 andris@aernet.ru 1 0.6931472
サブジェクトと本文に含まれる単語による重みづけ
まずは、サブジェクト。
- スレッド名に含まれる単語一覧を抽出して、単語ごとに重みを計算します。
- 単語を含む全スレッドのweightを取り出して、その平均を重みとして使います。
# 単語と出現頻度の一覧を返す > term.counts <- function(term.vec, control) { vec.corpus <- Corpus(VectorSource(term.vec)) vec.tdm <- TermDocumentMatrix(vec.corpus, control = control) return(rowSums(as.matrix(vec.tdm))) } # スレッド名に含まれる単語一覧を抽出 > thread.terms <- term.counts(thread.weights$Thread, control = list(stopwords = TRUE)) > thread.terms <- names(thread.terms) # 出現頻度は使わないので捨てる > head(thread.terms) [1] "--with" ":-)" "..." ".doc" "'apt-get" "\"holiday" # 単語ごとに重みを算出 # 単語を含む全スレッドのweightを取り出して、その平均を重みとして使う > term.weights <- sapply(thread.terms, function(t) mean(thread.weights$Weight[grepl(t, thread.weights$Thread, fixed = TRUE)])) > head(term.weights) --with :-) ... .doc 'apt-get "holiday 7.109579 6.103883 6.050786 5.725911 5.729244 7.197911 # 整形 > term.weights <- data.frame(list(Term = names(term.weights), Weight = term.weights), stringsAsFactors = FALSE, row.names = 1:length(term.weights)) > head(term.weights) Term Weight 1 --with 7.109579 2 :-) 6.103883 3 ... 6.050786 4 .doc 5.725911 5 'apt-get 5.729244 6 "holiday 7.197911
次に本文。
# 本文に含まれる単語と頻度を集計 > msg.terms <- term.counts(priority.train$Message, control = list(stopwords = TRUE, removePunctuation = TRUE, removeNumbers = TRUE)) # 重みを算出。ここでも対数をとる > msg.weights <- data.frame(list(Term = names(msg.terms), Weight = log(msg.terms, base = 10)), stringsAsFactors = FALSE, row.names = 1:length(msg.terms)) # 重みがゼロのものは除外 > msg.weights <- subset(msg.weights, Weight > 0)
これで、すべての重みデータフレームがそろいました。
順位づけを行う
重要度を計算する関数を定義します。
# 単語の重みを返す # 単語、検索する重みデータフレーム、term.weightが検索対象かどうか、を引数で受け取り、重みを返す。 > get.weights <- function(search.term, weight.df, term = TRUE) { if(length(search.term) > 0) { # weight.dfがterm.weightかどうかで列名が異なるので、ここで調整 if(term) { term.match <- match(names(search.term), weight.df$Term) } else { term.match <- match(search.term, weight.df$Thread) } match.weights <- weight.df$Weight[which(!is.na(term.match))] if(length(match.weights) < 1) { # マッチする件数がゼロの場合、1を使う return(1) } else { # マッチする件数が1以上の場合、平均を使う return(mean(match.weights)) } } else { return(1) } } # メールの重要度を返す > rank.message <- function(path) { # メールを解析 msg <- parse.email(path) # 送信者が送信したメール数に基づく重みを取得 from <- ifelse(length(which(from.weight$From.EMail == msg[2])) > 0, from.weight$Weight[which(from.weight$From.EMail == msg[2])], 1) # 送信者が参加したスレッド数に基づく重みを取得 thread.from <- ifelse(length(which(senders.df$From.EMail == msg[2])) > 0, senders.df$Weight[which(senders.df$From.EMail == msg[2])], 1) # メールがスレッドへの投降かどうかを判定し、スレッドへの投稿であれば、スレッドの重みを取得 subj <- strsplit(tolower(msg[3]), "re: ") is.thread <- ifelse(subj[[1]][1] == "", TRUE, FALSE) if(is.thread){ activity <- get.weights(subj[[1]][2], thread.weights, term = FALSE) } else { # スレッドへの投稿でない場合、重みは1 activity <- 1 } # メールサブジェクトに基づく重みを取得 thread.terms <- term.counts(msg[3], control = list(stopwords = TRUE)) thread.terms.weights <- get.weights(thread.terms, term.weights) # メール本文に基づく重みを取得 msg.terms <- term.counts(msg[4], control = list(stopwords = TRUE, removePunctuation = TRUE, removeNumbers = TRUE)) msg.weights <- get.weights(msg.terms, msg.weights) # 重みをすべて掛け合わせて、重要度を算出する rank <- prod(from, thread.from, activity, thread.terms.weights, msg.weights) return(c(msg[1], msg[2], msg[3], rank)) }
動作テスト。
> rank.message("../03-Classification/data/easy_ham/00111.a478af0547f2fd548f7b412df2e71a92") [1] "Mon, 7 Oct 2002 10:37:26" [2] "niall@linux.ie" [3] "Re: [ILUG] Interesting article on free software licences" [4] "5.27542087468428"
優先メールとみなす閾値が妥当か確認する
今回は、優先度の中央値を閾値として使います。 データの半分を使って、閾値が妥当かチェックします。
train.paths <- priority.df$Path[1:(round(nrow(priority.df) / 2))] test.paths <- priority.df$Path[((round(nrow(priority.df) / 2)) + 1):nrow(priority.df)] # train.pathsに含まれるメールの重要度を算出 train.ranks <- suppressWarnings(lapply(train.paths, rank.message)) # データフレームに変換 > train.ranks.matrix <- do.call(rbind, train.ranks) > train.ranks.matrix <- cbind(train.paths, train.ranks.matrix, "TRAINING") > train.ranks.df <- data.frame(train.ranks.matrix, stringsAsFactors = FALSE) > names(train.ranks.df) <- c("Message", "Date", "From", "Subj", "Rank", "Type") > train.ranks.df$Rank <- as.numeric(train.ranks.df$Rank) > head(train.ranks.df) Message 1 ../03-Classification/data/easy_ham/01061.6610124afa2a5844d41951439d1c1068 2 ../03-Classification/data/easy_ham/01062.ef7955b391f9b161f3f2106c8cda5edb 3 ../03-Classification/data/easy_ham/01063.ad3449bd2890a29828ac3978ca8c02ab 4 ../03-Classification/data/easy_ham/01064.9f4fc60b4e27bba3561e322c82d5f7ff 5 ../03-Classification/data/easy_ham/01070.6e34c1053a1840779780a315fb083057 6 ../03-Classification/data/easy_ham/01072.81ed44b31e111f9c1e47e53f4dfbefe3 Date From 1 Thu, 31 Jan 2002 22:44:14 robinderbains@shaw.ca 2 01 Feb 2002 00:53:41 lance_tt@bellsouth.net 3 Fri, 01 Feb 2002 02:01:44 robinderbains@shaw.ca 4 Fri, 1 Feb 2002 10:29:23 matthias@egwn.net 5 Fri, 1 Feb 2002 12:42:02 bfrench@ematic.com 6 Fri, 1 Feb 2002 13:39:31 bfrench@ematic.com Subj Rank Type 1 Please help a newbie compile mplayer :-) 3.614003 TRAINING 2 Re: Please help a newbie compile mplayer :-) 120.742481 TRAINING 3 Re: Please help a newbie compile mplayer :-) 20.348502 TRAINING 4 Re: Please help a newbie compile mplayer :-) 307.809626 TRAINING 5 Prob. w/ install/uninstall 3.653047 TRAINING 6 RE: Prob. w/ install/uninstall 21.685750 TRAINING
閾値を中央値に設定して、訓練データの重要度と密度を図にします。
# 閾値を中央値に設定 > priority.threshold <- median(train.ranks.df$Rank) # 訓練データの重要度と密度を図示 > threshold.plot <- ggplot(train.ranks.df, aes(x = Rank)) + stat_density(aes(fill="darkred")) + geom_vline(xintercept = priority.threshold, linetype = 2) + scale_fill_manual(values = c("darkred" = "darkred"), guide = "none") + theme_bw() > ggsave(plot = threshold.plot, filename = file.path("images", "01_threshold_plot.png"), height = 4.7, width = 7)
図中の点線が中央値。 ここを閾値にすれば、ランクの高い裾部分と、密度の高い部分の電子メールもある程度含まれるので、これらを優先メールと判定したのでよさそう。
残りのデータも加えて、図にしてみます。
# test.ranksに含まれるメールの重要度を算出 > train.ranks.df$Priority <- ifelse(train.ranks.df$Rank >= priority.threshold, 1, 0) > test.ranks <- suppressWarnings(lapply(test.paths,rank.message)) > test.ranks.matrix <- do.call(rbind, test.ranks) > test.ranks.matrix <- cbind(test.paths, test.ranks.matrix, "TESTING") > test.ranks.df <- data.frame(test.ranks.matrix, stringsAsFactors = FALSE) > names(test.ranks.df) <- c("Message","Date","From","Subj","Rank","Type") > test.ranks.df$Rank <- as.numeric(test.ranks.df$Rank) > test.ranks.df$Priority <- ifelse(test.ranks.df$Rank >= priority.threshold, 1, 0) # 訓練用データとテスト用データをマージ > final.df <- rbind(train.ranks.df, test.ranks.df) > final.df$Date <- date.converter(final.df$Date, pattern1, pattern2) > final.df <- final.df[rev(with(final.df, order(Date))), ] > head(final.df) Message 2500 ../03-Classification/data/easy_ham/00883.c44a035e7589e83076b7f1fed8fa97d5 2499 ../03-Classification/data/easy_ham/02500.05b3496ce7bca306bed0805425ec8621 2498 ../03-Classification/data/easy_ham/02499.b4af165650f138b10f9941f6cc5bce3c 2497 ../03-Classification/data/easy_ham/02498.09835f512f156da210efb99fcc523e21 2496 ../03-Classification/data/easy_ham/02497.60497db0a06c2132ec2374b2898084d3 2495 ../03-Classification/data/easy_ham/02496.aae0c81581895acfe65323f344340856 Date From 2500 <NA> sdw@lig.net 2499 <NA> ilug_gmc@fiachra.ucd.ie 2498 <NA> mwh@python.net 2497 <NA> nickm@go2.ie 2496 <NA> phil@techworks.ie 2495 <NA> timc@2ubh.com Subj Rank Type 2500 Re: ActiveBuddy 6.219744 TESTING 2499 Re: [ILUG] Linux Install 2.278890 TESTING 2498 [Spambayes] Re: New Application of SpamBayesian tech? 4.265954 TESTING 2497 Re: [ILUG] Linux Install 4.576643 TESTING 2496 Re: [ILUG] Linux Install 3.652100 TESTING 2495 [zzzzteana] Surfing the tube 27.987331 TESTING Priority 2500 0 2499 0 2498 0 2497 0 2496 0 2495 1 # 図示 > testing.plot <- ggplot(subset(final.df, Type == "TRAINING"), aes(x = Rank)) + stat_density(aes(fill = Type, alpha = 0.65)) + stat_density(data = subset(final.df, Type == "TESTING"), aes(fill = Type, alpha = 0.65)) + geom_vline(xintercept = priority.threshold, linetype = 2) + scale_alpha(guide = "none") + scale_fill_manual(values = c("TRAINING" = "darkred", "TESTING" = "darkblue")) + theme_bw() > ggsave(plot = testing.plot, filename = file.path("images", "02_testing_plot.png"), height = 4.7, width = 7)
テストデータは、訓練データより優先度低のメールが多く含まれる結果になっています。 これは、テストデータの素性に、訓練データに含まれないデータが多く含まれ、これらが順序付け時に無視されているためであり、妥当らしい。ふむ。
最後に優先度一覧をcsvに出力しておしまい。
write.csv(final.df, file.path("data", "final_df.csv"), row.names = FALSE)
機械学習手習い: スパムフィルタを作る
「入門 機械学習」手習い、3日目。「3章 分類:スパムフィルタ」です。
ナイーブベイズ分類器を作って、メールがスパムかどうかを判定するフィルタを作ります。
分類器の仕組み
- 1) 以下の単語セットを作成
- (a) スパムメッセージに出現しやすい単語とその出現確率
- (b) スパムメッセージに出現しにくい単語とその出現確率
- 2) で作成した単語セットを元に、メール本文を評価し、以下を算出
- (a2) メールをスパムと仮定した時の尤もらしさ
- (b2) メールを非スパムと仮定した時の尤もらしさ
- 3) a2 > b2 となるメールをスパムと判定する
という感じで判定を行います。
必要なモジュールとデータの読み込み
> setwd("03-Classification/") > library('tm') > library('ggplot2') # テスト用データ # 分類機の訓練用 > spam.path <- file.path("data", "spam") # スパムデータ > easyham.path <- file.path("data", "easy_ham") # 非スパムデータ(易) > hardham.path <- file.path("data", "hard_ham") # 非スパムデータ(難) # 分類機のテスト用 > spam2.path <- file.path("data", "spam_2") # スパムデータ > easyham2.path <- file.path("data", "easy_ham_2") # 非スパムデータ(易) > hardham2.path <- file.path("data", "hard_ham_2") # 非スパムデータ(難)
メールから本文を取り出す
ファイルを読み込んで、本文を返す関数を作成します。
> get.msg <- function(path) { con <- file(path, open = "rt", encoding = "latin1") text <- readLines(con) # The message always begins after the first full line break msg <- text[seq(which(text == "")[1] + 1, length(text), 1)] close(con) return(paste(msg, collapse = "\n")) }
動作テスト。
> get.msg("data/spam/00001.7848dde101aa985090474a91ec93fcf0") [1] "<!DOCTYPE HTML PUBLIC \"-//W3C//DTD HTML 4.0 Transitio...
sapply
を使って、data/spam
内のスパムメールから本文を読み込みます。
> spam.docs <- dir(spam.path) > spam.docs <- spam.docs[which(spam.docs != "cmds")] > all.spam <- sapply(spam.docs,function(p) get.msg(file.path(spam.path, p))) > head(all.spam)
単語文書行列を作る
単語を行、文書を列として、文書中の単語の出現数をカウントした、単語文書行列(TDM:Term Document Matrix)を作ります。
↓こんなイメージ
a.txt | b.txt | c.txt | |
---|---|---|---|
hello | 2 | 1 | 0 |
world | 0 | 2 | 3 |
まずは、TDMを作成する関数を定義。
> get.tdm <- function(doc.vec){ control <- list(stopwords = TRUE, removePunctuation = TRUE, removeNumbers = TRUE, minDocFreq = 2) doc.corpus <- Corpus(VectorSource(doc.vec)) doc.dtm <- TermDocumentMatrix(doc.corpus, control) return(doc.dtm) }
all.spam
から、TDMを作成します。
> spam.tdm <- get.tdm(all.spam)
スパムの訓練データを作る
スパムに含まれる単語の出現確率を集計します。
# all.spamのTDMをRの行列に変換 > spam.matrix <- as.matrix(spam.tdm) # 各単語の、全スパム中での出現頻度をカウント > spam.counts <- rowSums(spam.matrix) # データフレームに変換 > spam.df <- data.frame(cbind(names(spam.counts), as.numeric(spam.counts)), stringsAsFactors = FALSE) > names(spam.df) <- c("term", "frequency") # 出現確率を算出 > spam.df$frequency <- as.numeric(spam.df$frequency) > spam.occurrence <- sapply(1:nrow(spam.matrix), function(i) { length(which(spam.matrix[i, ] > 0)) / ncol(spam.matrix) }) > spam.density <- spam.df$frequency / sum(spam.df$frequency) > spam.df <- transform(spam.df, density = spam.density, occurrence = spam.occurrence)
訓練データの中身を確認。
> head(spam.df[with(spam.df, order(-occurrence)),]) term frequency density occurrence 7693 email 813 0.005855914 0.566 18706 please 425 0.003061210 0.508 14623 list 409 0.002945964 0.444 27202 will 828 0.005963957 0.422 2970 body 379 0.002729879 0.408 9369 free 543 0.003911146 0.390
スパムメール中の56%が email
を含んでいます。
非スパムの訓練データを作る
スパムの訓練データを作ったのと同じ手順で、非スパムの訓練データも作ります。
> easyham.docs <- dir(easyham.path) > easyham.docs <- easyham.docs[which(easyham.docs != "cmds")] # 最初の分類機を作る際に、各メッセージがスパムである確率とそうでない確率を等しいと仮定するため、 # 訓練データもスパムの文書数と同じ数だけに限定する。 > all.easyham <- sapply(easyham.docs[1:length(spam.docs)], function(p) get.msg(file.path(easyham.path, p))) > easyham.tdm <- get.tdm(all.easyham) > easyham.matrix <- as.matrix(easyham.tdm) > easyham.counts <- rowSums(easyham.matrix) > easyham.df <- data.frame(cbind(names(easyham.counts), as.numeric(easyham.counts)), stringsAsFactors = FALSE) > names(easyham.df) <- c("term", "frequency") > easyham.df$frequency <- as.numeric(easyham.df$frequency) > easyham.occurrence <- sapply(1:nrow(easyham.matrix), function(i) { length(which(easyham.matrix[i, ] > 0)) / ncol(easyham.matrix) }) > easyham.density <- easyham.df$frequency / sum(easyham.df$frequency) > easyham.df <- transform(easyham.df, density = easyham.density, occurrence = easyham.occurrence)
訓練データの中身を確認。
> head(easyham.df[with(easyham.df, order(-occurrence)),]) term frequency density occurrence 5193 group 232 0.003317366 0.388 12877 use 271 0.003875027 0.378 13511 wrote 237 0.003388861 0.378 1629 can 348 0.004976049 0.368 7244 list 248 0.003546150 0.368 8581 one 356 0.005090441 0.336
分類器を作る
メールファイルを受け取り、それがスパム、または、非スパムである確率を計算する分類器を定義します。
> classify.email <- function(path, training.df, prior = 0.5, c = 1e-6) { # ファイルから本文を取り出し、単語の出現数をカウント msg <- get.msg(path) msg.tdm <- get.tdm(msg) msg.freq <- rowSums(as.matrix(msg.tdm)) # メール内の単語のうち、訓練データに含まれる単語を取得 msg.match <- intersect(names(msg.freq), training.df$term) # 単語の出現確率を掛け合わせて、条件付き確率を算出する # この時、訓練データに含まれない単語は、非常に小さい確率0.0001として扱う。 if(length(msg.match) < 1) { # 訓練データに含まれる単語が1つもない場合、条件付き確率は、 # 事前確率(prior) * (0.0001の単語数乗) # となる。 return(prior * c ^ (length(msg.freq))) } else { # 訓練データから、単語ごとの出現確率を取り出す。 match.probs <- training.df$occurrence[match(msg.match, training.df$term)] # 条件付き確率を計算 # 事前確率(prior) * (訓練データに含まれる単語の出現確率の積) * (0.0001の訓練データに含まれない単語数乗) return(prior * prod(match.probs) * c ^ (length(msg.freq) - length(msg.match))) } }
判定してみる
非スパム(難)のデータを判定してみます。
> hardham.docs <- dir(hardham.path) > hardham.docs <- hardham.docs[which(hardham.docs != "cmds")] > hardham.spamtest <- sapply(hardham.docs, function(p) classify.email(file.path(hardham.path, p), training.df = spam.df)) > hardham.hamtest <- sapply(hardham.docs, function(p) classify.email(file.path(hardham.path, p), training.df = easyham.df)) > hardham.res <- ifelse(hardham.spamtest > hardham.hamtest, TRUE, FALSE) > summary(hardham.res) Mode FALSE TRUE NA's logical 243 6 0
データはすべて非スパムのものなので、誤判定は6件、偽陽性率は2.4%。
テスト用データすべてを判定してみる
スパム判定を行う関数を作成。
spam.classifier <- function(path) { pr.spam <- classify.email(path, spam.df) pr.ham <- classify.email(path, easyham.df) return(c(pr.spam, pr.ham, ifelse(pr.spam > pr.ham, 1, 0))) }
テスト用のメールデータをすべて判定してみます。
> easyham2.docs <- dir(easyham2.path) > easyham2.docs <- easyham2.docs[which(easyham2.docs != "cmds")] > hardham2.docs <- dir(hardham2.path) > hardham2.docs <- hardham2.docs[which(hardham2.docs != "cmds")] > spam2.docs <- dir(spam2.path) > spam2.docs <- spam2.docs[which(spam2.docs != "cmds")] > easyham2.class <- suppressWarnings(lapply(easyham2.docs, function(p) { spam.classifier(file.path(easyham2.path, p)) })) > hardham2.class <- suppressWarnings(lapply(hardham2.docs, function(p) { spam.classifier(file.path(hardham2.path, p)) })) > spam2.class <- suppressWarnings(lapply(spam2.docs, function(p) { spam.classifier(file.path(spam2.path, p)) })) > easyham2.matrix <- do.call(rbind, easyham2.class) > easyham2.final <- cbind(easyham2.matrix, "EASYHAM") > hardham2.matrix <- do.call(rbind, hardham2.class) > hardham2.final <- cbind(hardham2.matrix, "HARDHAM") > spam2.matrix <- do.call(rbind, spam2.class) > spam2.final <- cbind(spam2.matrix, "SPAM") > class.matrix <- rbind(easyham2.final, hardham2.final, spam2.final) > class.df <- data.frame(class.matrix, stringsAsFactors = FALSE) > names(class.df) <- c("Pr.SPAM" ,"Pr.HAM", "Class", "Type") > class.df$Pr.SPAM <- as.numeric(class.df$Pr.SPAM) > class.df$Pr.HAM <- as.numeric(class.df$Pr.HAM) > class.df$Class <- as.logical(as.numeric(class.df$Class)) > class.df$Type <- as.factor(class.df$Type)
文書ごとのスパム/非スパムの尤もらしさ、分類結果、メールの種別を含むデータフレームができました。
> head(class.df) Pr.SPAM Pr.HAM Class Type 1 0.000000e+00 0.000000e+00 FALSE EASYHAM 2 5.352364e-248 1.159512e-155 FALSE EASYHAM 3 0.000000e+00 5.103377e-216 FALSE EASYHAM 4 0.000000e+00 0.000000e+00 FALSE EASYHAM 5 2.083521e-169 1.221918e-108 FALSE EASYHAM 6 0.000000e+00 0.000000e+00 FALSE EASYHAM
文書の種類ごとに結果を集計してみます。
> get.results <- function(bool.vector) { results <- c(length(bool.vector[which(bool.vector == FALSE)]) / length(bool.vector), length(bool.vector[which(bool.vector == TRUE)]) / length(bool.vector)) return(results) } > easyham2.col <- get.results(subset(class.df, Type == "EASYHAM")$Class) > hardham2.col <- get.results(subset(class.df, Type == "HARDHAM")$Class) > spam2.col <- get.results(subset(class.df, Type == "SPAM")$Class) > class.res <- rbind(easyham2.col, hardham2.col, spam2.col) > colnames(class.res) <- c("NOT SPAM", "SPAM") > print(class.res) NOT SPAM SPAM easyham2.col 0.9871429 0.01285714 hardham2.col 0.9677419 0.03225806 spam2.col 0.4631353 0.53686471
非スパム(easyham2, hardham2)を間違ってスパムと判定する確率(偽陽性率)はそれぞれ1%,3%と低く、うまく分類できているもよう。 ただ、スパム(spam2)を間違ってスパムでないと判定する確率(偽陰性率)は46%。あれ、ちょっと高いかも?
結果を分散図にしてみる
> class.plot <- ggplot(class.df, aes(x = log(Pr.HAM), log(Pr.SPAM))) + geom_point(aes(shape = Type, alpha = 0.5)) + geom_abline(intercept = 0, slope = 1) + scale_shape_manual(values = c("EASYHAM" = 1, "HARDHAM" = 2, "SPAM" = 3), name = "Email Type") + scale_alpha(guide = "none") + xlab("log[Pr(HAM)]") + ylab("log[Pr(SPAM)]") + theme_bw() + theme(axis.text.x = element_blank(), axis.text.y = element_blank()) > ggsave(plot = class.plot, filename = file.path("images", "03_final_classification.png"), height = 10, width = 10)
横軸が「メールを非スパムと仮定した時の尤もらしさ」、縦軸が「メールをスパムと仮定した時の尤もらしさ」を示します。
「メールをスパムと仮定した時の尤もらしさ」 > 「メールを非スパムと仮定した時の尤もらしさ」となったメールをスパムと判定するので、真ん中の線より上の者はスパム、下は非スパムと判定されています。線より上に〇や△(非スパムのメール)がいくつかあったりはしますが、おおむね正しく判定できている感じですね。
事前分布を変えて、結果を改善する
↑では、とあるメールがあった時にそれがスパムである確率は50%(=世の中のメールの半分はスパムで半分はそうではない)と仮定していました。 現実には、80%は非スパム、残り20%がスパムなので、これを考慮して再計算することで結果を改善してみます。
spam.classifier
を修正して、事前確率を変更。
spam.classifier <- function(path) { pr.spam <- classify.email(path, spam.df, prior=0.2) pr.ham <- classify.email(path, easyham.df, prior=0.8) return(c(pr.spam, pr.ham, ifelse(pr.spam > pr.ham, 1, 0))) }
再計算した結果は以下。
> print(class.res) NOT SPAM SPAM easyham2.col 0.9892857 0.01071429 hardham2.col 0.9717742 0.02822581 spam2.col 0.4652827 0.53471725
非スパム(easyham2, hardham2)を間違ってスパムと判定する確率は少し改善しました。 一方、スパム(spam2)を間違ってスパムでないと判定する確率は悪化。。。
機械学習手習い: 数値によるデータの要約と可視化手法
「入門 機械学習」手習い、今日は「2章 データの調査」です。
数値によるデータの要約と、可視化手法を学びます。
テスト用データの読み込み
> setwd("02-Exploration/") > data.file <- file.path('data', '01_heights_weights_genders.csv') > heights.weights <- read.csv(data.file, header = TRUE, sep = ',') > head(heights.weights) Gender Height Weight 1 Male 73.84702 241.8936 2 Male 68.78190 162.3105 3 Male 74.11011 212.7409 4 Male 71.73098 220.0425 5 Male 69.88180 206.3498 6 Male 67.25302 152.2122
データの数値による要約
summary
でベクトルの数値を要約します
> summary(heights.weights$Height) Min. 1st Qu. Median Mean 3rd Qu. Max. 54.26 63.51 66.32 66.37 69.17 79.00
左から、
Min
.. 最小値1st Qu
.. 第一四分位(データ全体の下から25%の位置にあたる値)Median
.. 中央値(データ全体の50%の位置にあたる値)Mean
.. 平均値3rd Qu.
.. (データ全体の下から75%の位置にあたる値)Max
.. 最大値
が表示されます。
最小値、最大値を求める
min/max
を使って、最小値/最大値を算出できます
# Heightだけを含むベクトルを作成 > heights <- with(heights.weights, Height) > head(heights) [1] 73.84702 68.78190 74.11011 71.73098 69.88180 67.25302 > min(heights) [1] 54.26313 > max(heights) [1] 78.99874
range
で、両方をまとめて計算することもできます。
> range(heights) [1] 54.26313 78.99874
分位数を求める
quantile
で、データ中の各位置のデータを出力できます。
> quantile(heights) 0% 25% 50% 75% 100% 54.26313 63.50562 66.31807 69.17426 78.99874
分割幅を指定することもできます。
> quantile(heights, probs = seq(0, 1, by = 0.20)) 0% 20% 40% 60% 80% 100% 54.26313 62.85901 65.19422 67.43537 69.81162 78.99874
分散と標準偏差を求める
var
,sd
を使います。
# 標準偏差 > var(heights) [1] 14.80347 # 分散 > sd(heights) [1] 3.847528
データの可視化
必要なライブラリを読み込み。
> library('ggplot2')
ヒストグラム
> plot = ggplot(heights.weights, aes(x = Height)) + geom_histogram(binwidth = 1) > ggsave(plot = plot, filename = "histgram.png", width = 6, height = 8)
密度プロットにしてみます。少ないデータ量でも、データセットの形状が分かりやすいのがメリット。
> plot = ggplot(heights.weights, aes(x = Height)) + geom_density() > ggsave(plot = plot, filename = "kde_histgram.png", width = 6, height = 8)
性別ごとの特徴をみるため、性別ごとのヒストグラムを表示してみます。
> plot = ggplot(heights.weights, aes(x = Height, fill = Gender)) + geom_density() + facet_grid(Gender ~ .) > ggsave(plot = plot, filename = "gender_histgram.png", width = 6, height = 8)
- 正規分布
- ピーク(=最頻値)が1つしかない、単峰分布
- 左右が対称
- 裾が薄い(データのばらつきが小さい)
- コーシー分布
- ピーク(=最頻値)が1つしかない、単峰分布
- 左右が対称
- 裾が厚い(データのばらつきが大きい)
- ガンマ分布
- 左右が非対称で、平均値と中央値が大きく異なる
- 指数分布
- 左右が非対称で、最頻値がゼロ。
正規分布の例。
> set.seed(1) > normal.values <- rnorm(250, 0, 1) > plot = ggplot(data.frame(X = normal.values), aes(x = X)) + geom_density() > ggsave(plot = plot, filename = "normal_histgram.png", width = 6, height = 8)
コーシー分布。
> cauchy.values <- rcauchy(250, 0, 1) > plot = ggplot(data.frame(X = cauchy.values), aes(x = X)) + geom_density() > ggsave(plot = plot, filename = "cauchy_histgram.png", width = 6, height = 8)
ガンマ分布。
> gamma.values <- rgamma(100000, 1, 0.001) > plot = ggplot(data.frame(X = gamma.values), aes(x = X)) + geom_density() > ggsave(plot = plot, filename = "gamma_histgram.png", width = 6, height = 8)
指数分布、はない。
散布図
身長と体重の散布図を描きます。
> plot = ggplot(heights.weights, aes(x = Height, y = Weight)) + geom_point() > ggsave(plot = plot, filename = "scatterplots.png", width = 6, height = 8)
身長、体重には相関関係がありそう。geom_smooth()
を使って、妥当な予測領域を表示してみます。
> plot = ggplot(heights.weights, aes(x = Height, y = Weight)) + geom_point() + geom_smooth() > ggsave(plot = plot, filename = "scatterplots2.png", width = 6, height = 8)
最後に、男女別の散布図を描いて終わり。
> plot = ggplot(heights.weights, aes(x = Height, y = Weight, color = Gender)) + geom_point() > ggsave(plot = plot, filename = "gender_scatterplots.png", width = 6, height = 8)
機械学習手習い : Rをインストールして、基本的な使い方を学ぶ
オライリーの「入門 機械学習」を手に入れたので、手を動かしながら学びます。
まずは、1章。Rのインストールと基本的な使い方の学習まで。
Rのインストール
手元にあったCentOS7にインストールしました。
$ cat /etc/redhat-release CentOS Linux release 7.1.1503 (Core) $ sudo yum install epel-release $ sudo yum install R $ R R version 3.2.3 (2015-12-10) -- "Wooden Christmas-Tree" Copyright (C) 2015 The R Foundation for Statistical Computing Platform: x86_64-redhat-linux-gnu (64-bit) ...
サンプルコードのダウンロード
次に、GitHubで公開されている「入門 機械学習」のサンプルコードを取得します。
$ cd ~ $ git clone https://github.com/johnmyleswhite/ML_for_Hackers.git ml_for_hackers $ cd ml_for_hackers
必要モジュールのインストール
サンプルコードを動かす時に必要なパッケージをインストール。
$ R > source("package_installer.R")
ユーザー権限で実行すると、デフォルトのインストール先に書き込み権がないとのこと。y を押して、ホームディレクトリにインストール。 そこそこ時間がかかるので待つ・・・。
・・・いくつか、エラーになりました。
1: install.packages(p, dependencies = TRUE, type = "source") で: installation of package ‘rgl’ had non-zero exit status 2: install.packages(p, dependencies = TRUE, type = "source") で: installation of package ‘XML’ had non-zero exit status 3: install.packages(p, dependencies = TRUE, type = "source") で: installation of package ‘RCurl’ had non-zero exit status 4: install.packages(p, dependencies = TRUE, type = "source") で: installation of package ‘Rpoppler’ had non-zero exit status 5: install.packages(p, dependencies = TRUE, type = "source") で: installation of package ‘XML’ had non-zero exit status 6: install.packages(p, dependencies = TRUE, type = "source") で: installation of package ‘RCurl’ had non-zero exit status 7: install.packages(p, dependencies = TRUE, type = "source") で: installation of package ‘XML’ had non-zero exit status
個別にインストールしてみると、依存モジュールがないのが原因のよう。
install.packages("XML", dependencies=TRUE) Cannot find xml2-config ERROR: configuration failed for package ‘XML’ * removing ‘/home/yamautim/R/x86_64-redhat-linux-gnu-library/3.2/XML’
必要なモジュールをインストールして、
$ sudo yum -y install libxml2-devel curl-devel poppler-glib-devel freeglut-devel
再試行。
> install.packages("rgl", dependencies=TRUE) > install.packages("Rpoppler", dependencies=TRUE) > install.packages("XML", dependencies=TRUE) > install.packages("RCurl", dependencies=TRUE)
基礎練習で使うライブラリとデータの読み込み
基礎練習で使うライブラリとデータ(UFOの目撃情報データ)を読み込みます。
> setwd("01-Introduction/") > library("ggplot2") > library(plyr) > library(scales) > ufo <- read.delim("data/ufo/ufo_awesome.tsv", sep="\t", stringsAsFactors=FALSE, header=FALSE, na.strings="")
先頭、末尾の6行を表示。
> head(ufo) V1 V2 V3 V4 V5 1 19951009 19951009 Iowa City, IA <NA> <NA> 2 19951010 19951011 Milwaukee, WI <NA> 2 min. 3 19950101 19950103 Shelton, WA <NA> <NA> ... > tail(ufo) V1 V2 V3 V4 V5 61865 20100828 20100828 Los Angeles, CA disk 40 seconds 61866 20090424 20100820 Hartwell, GA oval 10 min 61867 20100821 20100826 Franklin Square, NY fireball 20 minutes 61868 20100827 20100827 Brighton, CO circle at lest 45 min 61869 20100818 20100821 Dryden (Canada), ON other 5 Min. maybe more 61870 20050502 20100824 Fort Knox, KY triangle 15 seconds
データをクリーニングする
読み込んだデータを、処理しやすい形に変換します。
データ列に名前をつける
> names(ufo) <- c("DateOccurred", "DateReported", "Location", "ShortDescription", "Duration", "LongDescription") > head(ufo) DateOccurred DateReported Location ShortDescription Duration 1 19951009 19951009 Iowa City, IA <NA> <NA> 2 19951010 19951011 Milwaukee, WI <NA> 2 min. 3 19950101 19950103 Shelton, WA <NA> <NA> 4 19950510 19950510 Columbia, MO <NA> 2 min. 5 19950611 19950614 Seattle, WA <NA> <NA> 6 19951025 19951024 Brunswick County, ND <NA> 30 min.
V1
などとなっていたところに、DateOccurred
のようなわかりやすい名前が付きました。
DateOccurred,DateReported を日付に変換する
DateOccurred
列は、日付を示す文字列なので、Dateオブジェクトに変換します。
> ufo$DateOccurred <- as.Date(ufo$DateOccurred, format = "%Y%m%d") strptime(x, format, tz = "GMT") でエラー: 入力文字列が長すぎます
エラーになりました。長い文字列を含む列があるよう。探します。
> head(ufo[which(nchar(ufo$DateOccurred)!=8|nchar(ufo$DateReported)!=8),1]) [1] "ler@gnv.ifas.ufl.edu" [2] "0000" [3] "Callers report sighting a number of soft white balls of lights headingin an easterly directing then changing direction to the west beforespeeding off to the north west." [4] "0000" [5] "0000" [6] "0000"
変なデータがありますな。ということで、8文字でない列を削除します。
# 行のDateOccurred または、DateReported が8文字かどうかを格納するデータ(good.rows)を作成 > good.rows <- ifelse(nchar(ufo$DateOccurred) != 8 | nchar(ufo$DateReported) != 8, FALSE, TRUE) # FALSE(=DateOccurred または、DateReported が8文字でない行)の数を確認。 > length(which(!good.rows)) [1] 731 # good.rowsがTRUEの行のみ抽出 > ufo <- ufo[good.rows, ]
変換を再試行。今度は、日付型に変換できました。
> ufo$DateOccurred <- as.Date(ufo$DateOccurred, format = "%Y%m%d") > ufo$DateReported <- as.Date(ufo$DateReported, format = "%Y%m%d") > head(ufo) DateOccurred DateReported Location ShortDescription Duration 1 1995-10-09 1995-10-09 Iowa City, IA <NA> <NA> 2 1995-10-10 1995-10-11 Milwaukee, WI <NA> 2 min. 3 1995-01-01 1995-01-03 Shelton, WA <NA> <NA>
関数を作ってLocation列のデータを都市名と州名に分割する
Location列のデータは、Iowa City, IA
のような「都市名, 州名」形式の文字列になっています。
関数を使ってこれを都市名, 州名に分解します。
まずは、分解を行う関数を作成。
> get.location <- function(l) { split.location <- tryCatch(strsplit(l, ",")[[1]], error = function(e) return(c(NA, NA))) clean.location <- gsub("^ ","",split.location) if (length(clean.location) > 2) { return(c(NA,NA)) } else { return(clean.location) } }
lapplyを使って、関数を適用したデータのリストを作成。
> city.state <- lapply(ufo$Location, get.location) > head(city.state) [[1]] [1] "Iowa City" "IA" [[2]] [1] "Milwaukee" "WI"
これを、ufoに追加します。まずは、do.call
でリストを行列に変換。
> location.matrix <- do.call(rbind, city.state) > head(location.matrix) [,1] [,2] [1,] "Iowa City" "IA" [2,] "Milwaukee" "WI" [3,] "Shelton" "WA" [4,] "Columbia" "MO" [5,] "Seattle" "WA" [6,] "Brunswick County" "ND"
transform
で ufo
に追加します。
> ufo <- transform(ufo, USCity = location.matrix[, 1], USState = location.matrix[, 2], stringsAsFactors = FALSE)
また、データには、カナダのものが含まれているので、これも除去します。
USState
にアメリカの州名以外が入っているデータを削除。
> ufo$USState <- state.abb[match(ufo$USState, state.abb)] > ufo.us <- subset(ufo, !is.na(USState))
これで、データのクリーニングは完了。
> summary(ufo.us) DateOccurred DateReported Location Min. :1400-06-30 Min. :1905-06-23 Length:51636 1st Qu.:1999-09-07 1st Qu.:2002-04-14 Class :character Median :2004-01-10 Median :2005-03-27 Mode :character Mean :2001-02-13 Mean :2004-11-30 3rd Qu.:2007-07-27 3rd Qu.:2008-01-20 Max. :2010-08-30 Max. :2010-08-30 ShortDescription Duration LongDescription USCity Length:51636 Length:51636 Length:51636 Length:51636 Class :character Class :character Class :character Class :character Mode :character Mode :character Mode :character Mode :character USState Length:51636 Class :character Mode :character > head(ufo.us) DateOccurred DateReported Location ShortDescription Duration 1 1995-10-09 1995-10-09 Iowa City, IA <NA> <NA> 2 1995-10-10 1995-10-11 Milwaukee, WI <NA> 2 min. 3 1995-01-01 1995-01-03 Shelton, WA <NA> <NA> 4 1995-05-10 1995-05-10 Columbia, MO <NA> 2 min. 5 1995-06-11 1995-06-14 Seattle, WA <NA> <NA> 6 1995-10-25 1995-10-24 Brunswick County, ND <NA> 30 min.
データを分析する
クリーニングしたデータを分析して、州/月ごとの目撃情報の傾向を分析します。
DateOccurredのばらつきを調べる
> summary(ufo.us$DateOccurred) Min. 1st Qu. Median Mean 3rd Qu. Max. NA's "1400-06-30" "1999-09-07" "2004-01-10" "2001-02-13" "2007-07-27" "2010-08-30" "1"
1400年のデータが含まれている・・・。ヒストグラムで、分布をみてみます。
> quick.hist <- ggplot(ufo.us, aes(x = DateOccurred)) + geom_histogram() + scale_x_date(date_breaks = "50 years", date_labels = "%Y") > ggsave(plot = quick.hist, filename = file.path("images", "quick_hist.png"), height = 6, width = 8)
大部分は最近の20年に集中している模様。この範囲に絞って分析するため、古いデータを取り除きます。
> ufo.us <- subset(ufo.us, DateOccurred >= as.Date("1990-01-01"))
データを州/月ごとに集計する
まずは、DateOccurred
列を「年-月」に変換した列を作成。
> ufo.us$YearMonth <- strftime(ufo.us$DateOccurred, format = "%Y-%m")
次に、月/州ごとにデータをグループ化して、データ数を集計します。
> sightings.counts <- ddply(ufo.us, .(USState,YearMonth), nrow) > head(sightings.counts) USState YearMonth V1 1 AK 1990-01 1 2 AK 1990-03 1 3 AK 1990-05 1 4 AK 1993-11 1 5 AK 1994-11 1 6 AK 1995-01 1
これだけだと、データが一つもない月が集計結果に含まれないので、それを補完します。
seq.Date
を使ってシーケンスを作成。
> date.range <- seq.Date(from = as.Date(min(ufo.us$DateOccurred)), to = as.Date(max(ufo.us$DateOccurred)), by = "month") > date.strings <- strftime(date.range, "%Y-%m")
作成したシーケンスに、州の一覧を掛け合わせて、州/月の行列を作成します。
> states.dates <- lapply(state.abb, function(s) cbind(s, date.strings)) > states.dates <- data.frame(do.call(rbind, states.dates), stringsAsFactors = FALSE) > head(states.dates) s date.strings 1 AL 1990-01 2 AL 1990-02 3 AL 1990-03 4 AL 1990-04 5 AL 1990-05 6 AL 1990-06
さらに、sightings.counts
をマージして、月の欠落がない集計データを作成。
> all.sightings <- merge(states.dates, sightings.counts, by.x = c("s", "date.strings"), by.y = c("USState", "YearMonth"), all = TRUE) > head(all.sightings) s date.strings V1 1 AK 1990-01 1 2 AK 1990-02 NA 3 AK 1990-03 1 4 AK 1990-04 NA 5 AK 1990-05 1 6 AK 1990-06 NA
わかりやすいよう、列に名前を付けます。また、集計しやすいようにNA
を0
に変換するなどの操作を行っておきます。
> names(all.sightings) <- c("State", "YearMonth", "Sightings") > all.sightings$Sightings[is.na(all.sightings$Sightings)] <- 0 > all.sightings$YearMonth <- as.Date(rep(date.range, length(state.abb))) > all.sightings$State <- as.factor(all.sightings$State) > head(all.sightings) State YearMonth Sightings 1 AK 1990-01-01 1 2 AK 1990-02-01 0 3 AK 1990-03-01 1 4 AK 1990-04-01 0 5 AK 1990-05-01 1 6 AK 1990-06-01 0
分析用データはこれで完成。
月/州ごとの目撃情報数をグラフにして分析する
> state.plot <- ggplot(all.sightings, aes(x = YearMonth,y = Sightings)) + geom_line(aes(color = "darkblue")) + facet_wrap(~State, nrow = 10, ncol = 5) + theme_bw() + scale_color_manual(values = c("darkblue" = "darkblue"), guide = "none") + scale_x_date(date_breaks = "5 years", date_labels = '%Y') + xlab("Years") + ylab("Number of Sightings") + ggtitle("Number of UFO sightings by Month-Year and U.S. State (1990-2010)") > ggsave(plot = state.plot, filename = file.path("images", "ufo_sightings.png"), width = 14, height = 8.5)
こんなグラフになります。
分析は省略。