無料で使えるシステムトレードフレームワーク「Jiji」 をリリースしました!

・OANDA Trade APIを利用した、オープンソースのシステムトレードフレームワークです。
・自分だけの取引アルゴリズムで、誰でも、いますぐ、かんたんに、自動取引を開始できます。

機械学習手習い: 回帰分析

「入門 機械学習」手習い、5日目。「5章 回帰:ページビューの予測」です。

www.amazon.co.jp

回帰分析を使って、Webサイトのページビューを予測します。

平均値による単純な予測と、予測精度の計測

回帰分析のイメージをつかむために、平均値を使った単純な予測を試します。 平均値を使って寿命を予測し、平均二乗誤差を使って精度を測定。そこに、追加の情報(喫煙者かどうか)を加えて精度を向上します。

まずは、依存ライブラリの読み込みから。

> setwd("05-Regression/")
> library('ggplot2')

寿命データを読み込む

> ages <- read.csv(file.path('data', 'longevity.csv'))
> head(ages)
  Smokes AgeAtDeath
1      1         75
2      1         72
3      1         66
4      1         74
5      1         69
6      1         65

ヒストグラムにしてみます。

> plot = ggplot(ages, aes(x = AgeAtDeath, fill = factor(Smokes))) +
  geom_density() + facet_grid(Smokes ~ .)
> ggsave(plot = plot,
       filename = file.path("images", "longevity_plot.png"),
       height = 4.8,
       width = 7)

f:id:unageanu:20160113141102p:plain

非喫煙者のピークが喫煙者より右に寄っていて、やはり、非喫煙者の方が全体的に長生きな感じ。

平均を使った予測とその精度を測る

データの概要がつかめたところで、平均値を使った予測をしてみます。 まずは、平均値を算出。

> mean(ages$AgeAtDeath)
[1] 72.723

平均2乗誤差を計算。

> guess <- 72.723
> with(ages, mean((AgeAtDeath - guess) ^ 2))
[1] 32.91427

32.91427 となりました。

喫煙者かどうかを考慮して、精度を向上する

喫煙者/非喫煙者それぞれで平均値を計測し、推測値として使います。

> smokers.guess <- with(subset(ages, Smokes == 1), mean(AgeAtDeath))
> non.smokers.guess <- with(subset(ages, Smokes == 0), mean(AgeAtDeath))
> ages <- transform(ages, NewPrediction = ifelse(Smokes == 0,
                     non.smokers.guess, smokers.guess))
> with(ages, mean((AgeAtDeath - NewPrediction) ^ 2))
[1] 26.50831

誤差が減り、精度が向上しました。

線形回帰入門

線形回帰で予測を行うには、次の2つの仮定を置くことが前提となります。

  • 可分性/加法性

    • 独立する特徴量が同時に起こった場合、効果は足し合わされること。
    • 例) アルコール中毒患者の平均寿命がそうでない人より1年短く、喫煙者の平均寿命がそうでない人より5年短い場合、アルコール中毒の喫煙者は6年短いと推測する
  • 単調性/線形性

    • 特徴量に応じて、予測値が常に上昇、または減少すること。

これを踏まえて、実習を開始。身長/体重データを読み込んで、簡単な予測を試します。

# データを読み込み
> heights.weights <- read.csv(
  file.path('data', '01_heights_weights_genders.csv'), header = TRUE, sep = ',')

散布図を描いてみます。 geom_smooth(method = 'lm') で回帰線も引きます。

# 散布図を描く
> plot <- ggplot(heights.weights, aes(x = Height, y = Weight)) +
  geom_point() + geom_smooth(method = 'lm')
> ggsave(plot = plot,
       filename = file.path("images", "height_weight.png"),
       height = 4.8,
       width = 7)

f:id:unageanu:20160113141059p:plain

青の線が身長から体重を予測する回帰線です。 身長60インチの人は、体重105ポンドくらいと予測できます。大体、妥当な感じ。

Rでは、lm 関数で、線形モデルでの回帰分析を行うことができます。

> fitted.regression <- lm(Weight ~ Height, data = heights.weights)

coef 関数を使うと、傾き( Height )と切片( Intercept )を確認できます。

> coef(fitted.regression)
(Intercept)      Height 
-350.737192    7.717288 

predict で予測値が得られます。

> head(predict(fitted.regression))                                                                                                                                 
       1        2        3        4        5        6 
219.1615 180.0725 221.1918 202.8314 188.5607 168.2737 

residuals を使うと、実際のデータとの誤差を計算できます。

head(residuals(fitted.regression))
         1          2          3          4          5          6 
 22.732083 -17.762074  -8.450953  17.211069  17.789073 -16.061519 

ウェブサイトのアクセス数を予測する

まずはデータを読み込み。

> top.1000.sites <- read.csv(file.path('data', 'top_1000_sites.tsv'), 
  sep = '\t', stringsAsFactors = FALSE)
> head(top.1000.sites)
  Rank          Site                     Category UniqueVisitors Reach
1    1  facebook.com              Social Networks      880000000  47.2
2    2   youtube.com                 Online Video      800000000  42.7
3    3     yahoo.com                  Web Portals      660000000  35.3
4    4      live.com               Search Engines      550000000  29.3
5    5 wikipedia.org Dictionaries & Encyclopedias      490000000  26.2
6    6       msn.com                  Web Portals      450000000  24.0
  PageViews HasAdvertising InEnglish TLD
1   9.1e+11            Yes       Yes com
2   1.0e+11            Yes       Yes com
3   7.7e+10            Yes       Yes com
4   3.6e+10            Yes       Yes com
5   7.0e+09             No       Yes org
6   1.5e+10            Yes       Yes com

このデータから、PageView を推測するのが今回の目的。 まずは、線形回帰が適用できるか、確認。ページビューの分布をみてみます。

> plot = ggplot(top.1000.sites, aes(x = PageViews)) + geom_density()
> ggsave(plot = plot,
       filename = file.path("images", "top_1000_sites_page_view.png"),
       height = 4.8,
       width = 7)

f:id:unageanu:20160113141103p:plain

ばらつきが大きすぎてよくわからない・・・。対数を取ってみます。

> plot = ggplot(top.1000.sites, aes(x = log(PageViews))) + geom_density()
> ggsave(plot = plot,
       filename = file.path("images", "top_1000_sites_page_view2.png"),
       height = 4.8,
       width = 7)

f:id:unageanu:20160113141104p:plain

この尺度なら、傾向が見えそう。UniqueVisitorPageView の散布図を描いてみます。

> ggplot(top.1000.sites, aes(x = log(PageViews), y = log(UniqueVisitors))) + geom_point()
> ggsave(file.path("images", "log_page_views_vs_log_visitors.png"))

f:id:unageanu:20160113141100p:plain

回帰直線も追加。

> ggplot(top.1000.sites, aes(x = log(PageViews), y = log(UniqueVisitors))) +
  geom_point() +
  geom_smooth(method = 'lm', se = FALSE)
> ggsave(file.path("images", "log_page_views_vs_log_visitors_with_lm.png"))

f:id:unageanu:20160113141101p:plain

うまくいっているっぽい。lm で回帰分析を実行。

> lm.fit <- lm(log(PageViews) ~ log(UniqueVisitors),
             data = top.1000.sites)

summary を使うと、分析結果の要約が表示されます。

> summary(lm.fit)

Call:
lm(formula = log(PageViews) ~ log(UniqueVisitors), data = top.1000.sites)

Residuals:
    Min      1Q  Median      3Q     Max 
-2.1825 -0.7986 -0.0741  0.6467  5.1549 

Coefficients:
                    Estimate Std. Error t value Pr(>|t|)    
(Intercept)         -2.83441    0.75201  -3.769 0.000173 ***
log(UniqueVisitors)  1.33628    0.04568  29.251  < 2e-16 ***
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

Residual standard error: 1.084 on 998 degrees of freedom
Multiple R-squared:  0.4616,    Adjusted R-squared:  0.4611 
F-statistic: 855.6 on 1 and 998 DF,  p-value: < 2.2e-16

このうち、 Residual standard error が平均2乗誤差の平方根(RMSE)をとったもの。

さらに、他の情報も考慮に加えてみます。

> lm.fit <- lm(log(PageViews) ~ HasAdvertising + log(UniqueVisitors) + InEnglish,
             data = top.1000.sites)
> summary(lm.fit)

Call:
lm(formula = log(PageViews) ~ HasAdvertising + log(UniqueVisitors) + 
    InEnglish, data = top.1000.sites)

Residuals:
    Min      1Q  Median      3Q     Max 
-2.4283 -0.7685 -0.0632  0.6298  5.4133 

Coefficients:
                    Estimate Std. Error t value Pr(>|t|)    
(Intercept)         -1.94502    1.14777  -1.695  0.09046 .  
HasAdvertisingYes    0.30595    0.09170   3.336  0.00088 ***
log(UniqueVisitors)  1.26507    0.07053  17.936  < 2e-16 ***
InEnglishNo          0.83468    0.20860   4.001 6.77e-05 ***
InEnglishYes        -0.16913    0.20424  -0.828  0.40780    
---
Signif. codes:  0***0.001**0.01*0.05 ‘.’ 0.1 ‘ ’ 1

Residual standard error: 1.067 on 995 degrees of freedom
Multiple R-squared:  0.4798,    Adjusted R-squared:  0.4777 
F-statistic: 229.4 on 4 and 995 DF,  p-value: < 2.2e-16

精度が少し上がりました。