機械学習手習い: 回帰分析
「入門 機械学習」手習い、5日目。「5章 回帰:ページビューの予測」です。
回帰分析を使って、Webサイトのページビューを予測します。
平均値による単純な予測と、予測精度の計測
回帰分析のイメージをつかむために、平均値を使った単純な予測を試します。 平均値を使って寿命を予測し、平均二乗誤差を使って精度を測定。そこに、追加の情報(喫煙者かどうか)を加えて精度を向上します。
まずは、依存ライブラリの読み込みから。
> setwd("05-Regression/") > library('ggplot2')
寿命データを読み込む
> ages <- read.csv(file.path('data', 'longevity.csv')) > head(ages) Smokes AgeAtDeath 1 1 75 2 1 72 3 1 66 4 1 74 5 1 69 6 1 65
ヒストグラムにしてみます。
> plot = ggplot(ages, aes(x = AgeAtDeath, fill = factor(Smokes))) + geom_density() + facet_grid(Smokes ~ .) > ggsave(plot = plot, filename = file.path("images", "longevity_plot.png"), height = 4.8, width = 7)
非喫煙者のピークが喫煙者より右に寄っていて、やはり、非喫煙者の方が全体的に長生きな感じ。
平均を使った予測とその精度を測る
データの概要がつかめたところで、平均値を使った予測をしてみます。 まずは、平均値を算出。
> mean(ages$AgeAtDeath) [1] 72.723
平均2乗誤差を計算。
> guess <- 72.723 > with(ages, mean((AgeAtDeath - guess) ^ 2)) [1] 32.91427
32.91427
となりました。
喫煙者かどうかを考慮して、精度を向上する
喫煙者/非喫煙者それぞれで平均値を計測し、推測値として使います。
> smokers.guess <- with(subset(ages, Smokes == 1), mean(AgeAtDeath)) > non.smokers.guess <- with(subset(ages, Smokes == 0), mean(AgeAtDeath)) > ages <- transform(ages, NewPrediction = ifelse(Smokes == 0, non.smokers.guess, smokers.guess)) > with(ages, mean((AgeAtDeath - NewPrediction) ^ 2)) [1] 26.50831
誤差が減り、精度が向上しました。
線形回帰入門
線形回帰で予測を行うには、次の2つの仮定を置くことが前提となります。
可分性/加法性
単調性/線形性
- 特徴量に応じて、予測値が常に上昇、または減少すること。
これを踏まえて、実習を開始。身長/体重データを読み込んで、簡単な予測を試します。
# データを読み込み > heights.weights <- read.csv( file.path('data', '01_heights_weights_genders.csv'), header = TRUE, sep = ',')
散布図を描いてみます。 geom_smooth(method = 'lm')
で回帰線も引きます。
# 散布図を描く > plot <- ggplot(heights.weights, aes(x = Height, y = Weight)) + geom_point() + geom_smooth(method = 'lm') > ggsave(plot = plot, filename = file.path("images", "height_weight.png"), height = 4.8, width = 7)
青の線が身長から体重を予測する回帰線です。 身長60インチの人は、体重105ポンドくらいと予測できます。大体、妥当な感じ。
Rでは、lm
関数で、線形モデルでの回帰分析を行うことができます。
> fitted.regression <- lm(Weight ~ Height, data = heights.weights)
coef
関数を使うと、傾き( Height
)と切片( Intercept
)を確認できます。
> coef(fitted.regression) (Intercept) Height -350.737192 7.717288
predict
で予測値が得られます。
> head(predict(fitted.regression)) 1 2 3 4 5 6 219.1615 180.0725 221.1918 202.8314 188.5607 168.2737
residuals
を使うと、実際のデータとの誤差を計算できます。
head(residuals(fitted.regression)) 1 2 3 4 5 6 22.732083 -17.762074 -8.450953 17.211069 17.789073 -16.061519
ウェブサイトのアクセス数を予測する
まずはデータを読み込み。
> top.1000.sites <- read.csv(file.path('data', 'top_1000_sites.tsv'), sep = '\t', stringsAsFactors = FALSE) > head(top.1000.sites) Rank Site Category UniqueVisitors Reach 1 1 facebook.com Social Networks 880000000 47.2 2 2 youtube.com Online Video 800000000 42.7 3 3 yahoo.com Web Portals 660000000 35.3 4 4 live.com Search Engines 550000000 29.3 5 5 wikipedia.org Dictionaries & Encyclopedias 490000000 26.2 6 6 msn.com Web Portals 450000000 24.0 PageViews HasAdvertising InEnglish TLD 1 9.1e+11 Yes Yes com 2 1.0e+11 Yes Yes com 3 7.7e+10 Yes Yes com 4 3.6e+10 Yes Yes com 5 7.0e+09 No Yes org 6 1.5e+10 Yes Yes com
このデータから、PageView
を推測するのが今回の目的。
まずは、線形回帰が適用できるか、確認。ページビューの分布をみてみます。
> plot = ggplot(top.1000.sites, aes(x = PageViews)) + geom_density() > ggsave(plot = plot, filename = file.path("images", "top_1000_sites_page_view.png"), height = 4.8, width = 7)
ばらつきが大きすぎてよくわからない・・・。対数を取ってみます。
> plot = ggplot(top.1000.sites, aes(x = log(PageViews))) + geom_density() > ggsave(plot = plot, filename = file.path("images", "top_1000_sites_page_view2.png"), height = 4.8, width = 7)
この尺度なら、傾向が見えそう。UniqueVisitor
と PageView
の散布図を描いてみます。
> ggplot(top.1000.sites, aes(x = log(PageViews), y = log(UniqueVisitors))) + geom_point() > ggsave(file.path("images", "log_page_views_vs_log_visitors.png"))
回帰直線も追加。
> ggplot(top.1000.sites, aes(x = log(PageViews), y = log(UniqueVisitors))) + geom_point() + geom_smooth(method = 'lm', se = FALSE) > ggsave(file.path("images", "log_page_views_vs_log_visitors_with_lm.png"))
うまくいっているっぽい。lm
で回帰分析を実行。
> lm.fit <- lm(log(PageViews) ~ log(UniqueVisitors), data = top.1000.sites)
summary
を使うと、分析結果の要約が表示されます。
> summary(lm.fit) Call: lm(formula = log(PageViews) ~ log(UniqueVisitors), data = top.1000.sites) Residuals: Min 1Q Median 3Q Max -2.1825 -0.7986 -0.0741 0.6467 5.1549 Coefficients: Estimate Std. Error t value Pr(>|t|) (Intercept) -2.83441 0.75201 -3.769 0.000173 *** log(UniqueVisitors) 1.33628 0.04568 29.251 < 2e-16 *** --- Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1 Residual standard error: 1.084 on 998 degrees of freedom Multiple R-squared: 0.4616, Adjusted R-squared: 0.4611 F-statistic: 855.6 on 1 and 998 DF, p-value: < 2.2e-16
このうち、 Residual standard error
が平均2乗誤差の平方根(RMSE)をとったもの。
さらに、他の情報も考慮に加えてみます。
> lm.fit <- lm(log(PageViews) ~ HasAdvertising + log(UniqueVisitors) + InEnglish, data = top.1000.sites) > summary(lm.fit) Call: lm(formula = log(PageViews) ~ HasAdvertising + log(UniqueVisitors) + InEnglish, data = top.1000.sites) Residuals: Min 1Q Median 3Q Max -2.4283 -0.7685 -0.0632 0.6298 5.4133 Coefficients: Estimate Std. Error t value Pr(>|t|) (Intercept) -1.94502 1.14777 -1.695 0.09046 . HasAdvertisingYes 0.30595 0.09170 3.336 0.00088 *** log(UniqueVisitors) 1.26507 0.07053 17.936 < 2e-16 *** InEnglishNo 0.83468 0.20860 4.001 6.77e-05 *** InEnglishYes -0.16913 0.20424 -0.828 0.40780 --- Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1 Residual standard error: 1.067 on 995 degrees of freedom Multiple R-squared: 0.4798, Adjusted R-squared: 0.4777 F-statistic: 229.4 on 4 and 995 DF, p-value: < 2.2e-16
精度が少し上がりました。