機械学習手習い: スパムフィルタを作る
「入門 機械学習」手習い、3日目。「3章 分類:スパムフィルタ」です。
ナイーブベイズ分類器を作って、メールがスパムかどうかを判定するフィルタを作ります。
分類器の仕組み
- 1) 以下の単語セットを作成
- (a) スパムメッセージに出現しやすい単語とその出現確率
- (b) スパムメッセージに出現しにくい単語とその出現確率
- 2) で作成した単語セットを元に、メール本文を評価し、以下を算出
- (a2) メールをスパムと仮定した時の尤もらしさ
- (b2) メールを非スパムと仮定した時の尤もらしさ
- 3) a2 > b2 となるメールをスパムと判定する
という感じで判定を行います。
必要なモジュールとデータの読み込み
> setwd("03-Classification/") > library('tm') > library('ggplot2') # テスト用データ # 分類機の訓練用 > spam.path <- file.path("data", "spam") # スパムデータ > easyham.path <- file.path("data", "easy_ham") # 非スパムデータ(易) > hardham.path <- file.path("data", "hard_ham") # 非スパムデータ(難) # 分類機のテスト用 > spam2.path <- file.path("data", "spam_2") # スパムデータ > easyham2.path <- file.path("data", "easy_ham_2") # 非スパムデータ(易) > hardham2.path <- file.path("data", "hard_ham_2") # 非スパムデータ(難)
メールから本文を取り出す
ファイルを読み込んで、本文を返す関数を作成します。
> get.msg <- function(path) { con <- file(path, open = "rt", encoding = "latin1") text <- readLines(con) # The message always begins after the first full line break msg <- text[seq(which(text == "")[1] + 1, length(text), 1)] close(con) return(paste(msg, collapse = "\n")) }
動作テスト。
> get.msg("data/spam/00001.7848dde101aa985090474a91ec93fcf0") [1] "<!DOCTYPE HTML PUBLIC \"-//W3C//DTD HTML 4.0 Transitio...
sapply
を使って、data/spam
内のスパムメールから本文を読み込みます。
> spam.docs <- dir(spam.path) > spam.docs <- spam.docs[which(spam.docs != "cmds")] > all.spam <- sapply(spam.docs,function(p) get.msg(file.path(spam.path, p))) > head(all.spam)
単語文書行列を作る
単語を行、文書を列として、文書中の単語の出現数をカウントした、単語文書行列(TDM:Term Document Matrix)を作ります。
↓こんなイメージ
a.txt | b.txt | c.txt | |
---|---|---|---|
hello | 2 | 1 | 0 |
world | 0 | 2 | 3 |
まずは、TDMを作成する関数を定義。
> get.tdm <- function(doc.vec){ control <- list(stopwords = TRUE, removePunctuation = TRUE, removeNumbers = TRUE, minDocFreq = 2) doc.corpus <- Corpus(VectorSource(doc.vec)) doc.dtm <- TermDocumentMatrix(doc.corpus, control) return(doc.dtm) }
all.spam
から、TDMを作成します。
> spam.tdm <- get.tdm(all.spam)
スパムの訓練データを作る
スパムに含まれる単語の出現確率を集計します。
# all.spamのTDMをRの行列に変換 > spam.matrix <- as.matrix(spam.tdm) # 各単語の、全スパム中での出現頻度をカウント > spam.counts <- rowSums(spam.matrix) # データフレームに変換 > spam.df <- data.frame(cbind(names(spam.counts), as.numeric(spam.counts)), stringsAsFactors = FALSE) > names(spam.df) <- c("term", "frequency") # 出現確率を算出 > spam.df$frequency <- as.numeric(spam.df$frequency) > spam.occurrence <- sapply(1:nrow(spam.matrix), function(i) { length(which(spam.matrix[i, ] > 0)) / ncol(spam.matrix) }) > spam.density <- spam.df$frequency / sum(spam.df$frequency) > spam.df <- transform(spam.df, density = spam.density, occurrence = spam.occurrence)
訓練データの中身を確認。
> head(spam.df[with(spam.df, order(-occurrence)),]) term frequency density occurrence 7693 email 813 0.005855914 0.566 18706 please 425 0.003061210 0.508 14623 list 409 0.002945964 0.444 27202 will 828 0.005963957 0.422 2970 body 379 0.002729879 0.408 9369 free 543 0.003911146 0.390
スパムメール中の56%が email
を含んでいます。
非スパムの訓練データを作る
スパムの訓練データを作ったのと同じ手順で、非スパムの訓練データも作ります。
> easyham.docs <- dir(easyham.path) > easyham.docs <- easyham.docs[which(easyham.docs != "cmds")] # 最初の分類機を作る際に、各メッセージがスパムである確率とそうでない確率を等しいと仮定するため、 # 訓練データもスパムの文書数と同じ数だけに限定する。 > all.easyham <- sapply(easyham.docs[1:length(spam.docs)], function(p) get.msg(file.path(easyham.path, p))) > easyham.tdm <- get.tdm(all.easyham) > easyham.matrix <- as.matrix(easyham.tdm) > easyham.counts <- rowSums(easyham.matrix) > easyham.df <- data.frame(cbind(names(easyham.counts), as.numeric(easyham.counts)), stringsAsFactors = FALSE) > names(easyham.df) <- c("term", "frequency") > easyham.df$frequency <- as.numeric(easyham.df$frequency) > easyham.occurrence <- sapply(1:nrow(easyham.matrix), function(i) { length(which(easyham.matrix[i, ] > 0)) / ncol(easyham.matrix) }) > easyham.density <- easyham.df$frequency / sum(easyham.df$frequency) > easyham.df <- transform(easyham.df, density = easyham.density, occurrence = easyham.occurrence)
訓練データの中身を確認。
> head(easyham.df[with(easyham.df, order(-occurrence)),]) term frequency density occurrence 5193 group 232 0.003317366 0.388 12877 use 271 0.003875027 0.378 13511 wrote 237 0.003388861 0.378 1629 can 348 0.004976049 0.368 7244 list 248 0.003546150 0.368 8581 one 356 0.005090441 0.336
分類器を作る
メールファイルを受け取り、それがスパム、または、非スパムである確率を計算する分類器を定義します。
> classify.email <- function(path, training.df, prior = 0.5, c = 1e-6) { # ファイルから本文を取り出し、単語の出現数をカウント msg <- get.msg(path) msg.tdm <- get.tdm(msg) msg.freq <- rowSums(as.matrix(msg.tdm)) # メール内の単語のうち、訓練データに含まれる単語を取得 msg.match <- intersect(names(msg.freq), training.df$term) # 単語の出現確率を掛け合わせて、条件付き確率を算出する # この時、訓練データに含まれない単語は、非常に小さい確率0.0001として扱う。 if(length(msg.match) < 1) { # 訓練データに含まれる単語が1つもない場合、条件付き確率は、 # 事前確率(prior) * (0.0001の単語数乗) # となる。 return(prior * c ^ (length(msg.freq))) } else { # 訓練データから、単語ごとの出現確率を取り出す。 match.probs <- training.df$occurrence[match(msg.match, training.df$term)] # 条件付き確率を計算 # 事前確率(prior) * (訓練データに含まれる単語の出現確率の積) * (0.0001の訓練データに含まれない単語数乗) return(prior * prod(match.probs) * c ^ (length(msg.freq) - length(msg.match))) } }
判定してみる
非スパム(難)のデータを判定してみます。
> hardham.docs <- dir(hardham.path) > hardham.docs <- hardham.docs[which(hardham.docs != "cmds")] > hardham.spamtest <- sapply(hardham.docs, function(p) classify.email(file.path(hardham.path, p), training.df = spam.df)) > hardham.hamtest <- sapply(hardham.docs, function(p) classify.email(file.path(hardham.path, p), training.df = easyham.df)) > hardham.res <- ifelse(hardham.spamtest > hardham.hamtest, TRUE, FALSE) > summary(hardham.res) Mode FALSE TRUE NA's logical 243 6 0
データはすべて非スパムのものなので、誤判定は6件、偽陽性率は2.4%。
テスト用データすべてを判定してみる
スパム判定を行う関数を作成。
spam.classifier <- function(path) { pr.spam <- classify.email(path, spam.df) pr.ham <- classify.email(path, easyham.df) return(c(pr.spam, pr.ham, ifelse(pr.spam > pr.ham, 1, 0))) }
テスト用のメールデータをすべて判定してみます。
> easyham2.docs <- dir(easyham2.path) > easyham2.docs <- easyham2.docs[which(easyham2.docs != "cmds")] > hardham2.docs <- dir(hardham2.path) > hardham2.docs <- hardham2.docs[which(hardham2.docs != "cmds")] > spam2.docs <- dir(spam2.path) > spam2.docs <- spam2.docs[which(spam2.docs != "cmds")] > easyham2.class <- suppressWarnings(lapply(easyham2.docs, function(p) { spam.classifier(file.path(easyham2.path, p)) })) > hardham2.class <- suppressWarnings(lapply(hardham2.docs, function(p) { spam.classifier(file.path(hardham2.path, p)) })) > spam2.class <- suppressWarnings(lapply(spam2.docs, function(p) { spam.classifier(file.path(spam2.path, p)) })) > easyham2.matrix <- do.call(rbind, easyham2.class) > easyham2.final <- cbind(easyham2.matrix, "EASYHAM") > hardham2.matrix <- do.call(rbind, hardham2.class) > hardham2.final <- cbind(hardham2.matrix, "HARDHAM") > spam2.matrix <- do.call(rbind, spam2.class) > spam2.final <- cbind(spam2.matrix, "SPAM") > class.matrix <- rbind(easyham2.final, hardham2.final, spam2.final) > class.df <- data.frame(class.matrix, stringsAsFactors = FALSE) > names(class.df) <- c("Pr.SPAM" ,"Pr.HAM", "Class", "Type") > class.df$Pr.SPAM <- as.numeric(class.df$Pr.SPAM) > class.df$Pr.HAM <- as.numeric(class.df$Pr.HAM) > class.df$Class <- as.logical(as.numeric(class.df$Class)) > class.df$Type <- as.factor(class.df$Type)
文書ごとのスパム/非スパムの尤もらしさ、分類結果、メールの種別を含むデータフレームができました。
> head(class.df) Pr.SPAM Pr.HAM Class Type 1 0.000000e+00 0.000000e+00 FALSE EASYHAM 2 5.352364e-248 1.159512e-155 FALSE EASYHAM 3 0.000000e+00 5.103377e-216 FALSE EASYHAM 4 0.000000e+00 0.000000e+00 FALSE EASYHAM 5 2.083521e-169 1.221918e-108 FALSE EASYHAM 6 0.000000e+00 0.000000e+00 FALSE EASYHAM
文書の種類ごとに結果を集計してみます。
> get.results <- function(bool.vector) { results <- c(length(bool.vector[which(bool.vector == FALSE)]) / length(bool.vector), length(bool.vector[which(bool.vector == TRUE)]) / length(bool.vector)) return(results) } > easyham2.col <- get.results(subset(class.df, Type == "EASYHAM")$Class) > hardham2.col <- get.results(subset(class.df, Type == "HARDHAM")$Class) > spam2.col <- get.results(subset(class.df, Type == "SPAM")$Class) > class.res <- rbind(easyham2.col, hardham2.col, spam2.col) > colnames(class.res) <- c("NOT SPAM", "SPAM") > print(class.res) NOT SPAM SPAM easyham2.col 0.9871429 0.01285714 hardham2.col 0.9677419 0.03225806 spam2.col 0.4631353 0.53686471
非スパム(easyham2, hardham2)を間違ってスパムと判定する確率(偽陽性率)はそれぞれ1%,3%と低く、うまく分類できているもよう。 ただ、スパム(spam2)を間違ってスパムでないと判定する確率(偽陰性率)は46%。あれ、ちょっと高いかも?
結果を分散図にしてみる
> class.plot <- ggplot(class.df, aes(x = log(Pr.HAM), log(Pr.SPAM))) + geom_point(aes(shape = Type, alpha = 0.5)) + geom_abline(intercept = 0, slope = 1) + scale_shape_manual(values = c("EASYHAM" = 1, "HARDHAM" = 2, "SPAM" = 3), name = "Email Type") + scale_alpha(guide = "none") + xlab("log[Pr(HAM)]") + ylab("log[Pr(SPAM)]") + theme_bw() + theme(axis.text.x = element_blank(), axis.text.y = element_blank()) > ggsave(plot = class.plot, filename = file.path("images", "03_final_classification.png"), height = 10, width = 10)
横軸が「メールを非スパムと仮定した時の尤もらしさ」、縦軸が「メールをスパムと仮定した時の尤もらしさ」を示します。
「メールをスパムと仮定した時の尤もらしさ」 > 「メールを非スパムと仮定した時の尤もらしさ」となったメールをスパムと判定するので、真ん中の線より上の者はスパム、下は非スパムと判定されています。線より上に〇や△(非スパムのメール)がいくつかあったりはしますが、おおむね正しく判定できている感じですね。
事前分布を変えて、結果を改善する
↑では、とあるメールがあった時にそれがスパムである確率は50%(=世の中のメールの半分はスパムで半分はそうではない)と仮定していました。 現実には、80%は非スパム、残り20%がスパムなので、これを考慮して再計算することで結果を改善してみます。
spam.classifier
を修正して、事前確率を変更。
spam.classifier <- function(path) { pr.spam <- classify.email(path, spam.df, prior=0.2) pr.ham <- classify.email(path, easyham.df, prior=0.8) return(c(pr.spam, pr.ham, ifelse(pr.spam > pr.ham, 1, 0))) }
再計算した結果は以下。
> print(class.res) NOT SPAM SPAM easyham2.col 0.9892857 0.01071429 hardham2.col 0.9717742 0.02822581 spam2.col 0.4652827 0.53471725
非スパム(easyham2, hardham2)を間違ってスパムと判定する確率は少し改善しました。 一方、スパム(spam2)を間違ってスパムでないと判定する確率は悪化。。。
機械学習手習い: 数値によるデータの要約と可視化手法
「入門 機械学習」手習い、今日は「2章 データの調査」です。
数値によるデータの要約と、可視化手法を学びます。
テスト用データの読み込み
> setwd("02-Exploration/") > data.file <- file.path('data', '01_heights_weights_genders.csv') > heights.weights <- read.csv(data.file, header = TRUE, sep = ',') > head(heights.weights) Gender Height Weight 1 Male 73.84702 241.8936 2 Male 68.78190 162.3105 3 Male 74.11011 212.7409 4 Male 71.73098 220.0425 5 Male 69.88180 206.3498 6 Male 67.25302 152.2122
データの数値による要約
summary
でベクトルの数値を要約します
> summary(heights.weights$Height) Min. 1st Qu. Median Mean 3rd Qu. Max. 54.26 63.51 66.32 66.37 69.17 79.00
左から、
Min
.. 最小値1st Qu
.. 第一四分位(データ全体の下から25%の位置にあたる値)Median
.. 中央値(データ全体の50%の位置にあたる値)Mean
.. 平均値3rd Qu.
.. (データ全体の下から75%の位置にあたる値)Max
.. 最大値
が表示されます。
最小値、最大値を求める
min/max
を使って、最小値/最大値を算出できます
# Heightだけを含むベクトルを作成 > heights <- with(heights.weights, Height) > head(heights) [1] 73.84702 68.78190 74.11011 71.73098 69.88180 67.25302 > min(heights) [1] 54.26313 > max(heights) [1] 78.99874
range
で、両方をまとめて計算することもできます。
> range(heights) [1] 54.26313 78.99874
分位数を求める
quantile
で、データ中の各位置のデータを出力できます。
> quantile(heights) 0% 25% 50% 75% 100% 54.26313 63.50562 66.31807 69.17426 78.99874
分割幅を指定することもできます。
> quantile(heights, probs = seq(0, 1, by = 0.20)) 0% 20% 40% 60% 80% 100% 54.26313 62.85901 65.19422 67.43537 69.81162 78.99874
分散と標準偏差を求める
var
,sd
を使います。
# 標準偏差 > var(heights) [1] 14.80347 # 分散 > sd(heights) [1] 3.847528
データの可視化
必要なライブラリを読み込み。
> library('ggplot2')
ヒストグラム
> plot = ggplot(heights.weights, aes(x = Height)) + geom_histogram(binwidth = 1) > ggsave(plot = plot, filename = "histgram.png", width = 6, height = 8)
密度プロットにしてみます。少ないデータ量でも、データセットの形状が分かりやすいのがメリット。
> plot = ggplot(heights.weights, aes(x = Height)) + geom_density() > ggsave(plot = plot, filename = "kde_histgram.png", width = 6, height = 8)
性別ごとの特徴をみるため、性別ごとのヒストグラムを表示してみます。
> plot = ggplot(heights.weights, aes(x = Height, fill = Gender)) + geom_density() + facet_grid(Gender ~ .) > ggsave(plot = plot, filename = "gender_histgram.png", width = 6, height = 8)
- 正規分布
- ピーク(=最頻値)が1つしかない、単峰分布
- 左右が対称
- 裾が薄い(データのばらつきが小さい)
- コーシー分布
- ピーク(=最頻値)が1つしかない、単峰分布
- 左右が対称
- 裾が厚い(データのばらつきが大きい)
- ガンマ分布
- 左右が非対称で、平均値と中央値が大きく異なる
- 指数分布
- 左右が非対称で、最頻値がゼロ。
正規分布の例。
> set.seed(1) > normal.values <- rnorm(250, 0, 1) > plot = ggplot(data.frame(X = normal.values), aes(x = X)) + geom_density() > ggsave(plot = plot, filename = "normal_histgram.png", width = 6, height = 8)
コーシー分布。
> cauchy.values <- rcauchy(250, 0, 1) > plot = ggplot(data.frame(X = cauchy.values), aes(x = X)) + geom_density() > ggsave(plot = plot, filename = "cauchy_histgram.png", width = 6, height = 8)
ガンマ分布。
> gamma.values <- rgamma(100000, 1, 0.001) > plot = ggplot(data.frame(X = gamma.values), aes(x = X)) + geom_density() > ggsave(plot = plot, filename = "gamma_histgram.png", width = 6, height = 8)
指数分布、はない。
散布図
身長と体重の散布図を描きます。
> plot = ggplot(heights.weights, aes(x = Height, y = Weight)) + geom_point() > ggsave(plot = plot, filename = "scatterplots.png", width = 6, height = 8)
身長、体重には相関関係がありそう。geom_smooth()
を使って、妥当な予測領域を表示してみます。
> plot = ggplot(heights.weights, aes(x = Height, y = Weight)) + geom_point() + geom_smooth() > ggsave(plot = plot, filename = "scatterplots2.png", width = 6, height = 8)
最後に、男女別の散布図を描いて終わり。
> plot = ggplot(heights.weights, aes(x = Height, y = Weight, color = Gender)) + geom_point() > ggsave(plot = plot, filename = "gender_scatterplots.png", width = 6, height = 8)
機械学習手習い : Rをインストールして、基本的な使い方を学ぶ
オライリーの「入門 機械学習」を手に入れたので、手を動かしながら学びます。
まずは、1章。Rのインストールと基本的な使い方の学習まで。
Rのインストール
手元にあったCentOS7にインストールしました。
$ cat /etc/redhat-release CentOS Linux release 7.1.1503 (Core) $ sudo yum install epel-release $ sudo yum install R $ R R version 3.2.3 (2015-12-10) -- "Wooden Christmas-Tree" Copyright (C) 2015 The R Foundation for Statistical Computing Platform: x86_64-redhat-linux-gnu (64-bit) ...
サンプルコードのダウンロード
次に、GitHubで公開されている「入門 機械学習」のサンプルコードを取得します。
$ cd ~ $ git clone https://github.com/johnmyleswhite/ML_for_Hackers.git ml_for_hackers $ cd ml_for_hackers
必要モジュールのインストール
サンプルコードを動かす時に必要なパッケージをインストール。
$ R > source("package_installer.R")
ユーザー権限で実行すると、デフォルトのインストール先に書き込み権がないとのこと。y を押して、ホームディレクトリにインストール。 そこそこ時間がかかるので待つ・・・。
・・・いくつか、エラーになりました。
1: install.packages(p, dependencies = TRUE, type = "source") で: installation of package ‘rgl’ had non-zero exit status 2: install.packages(p, dependencies = TRUE, type = "source") で: installation of package ‘XML’ had non-zero exit status 3: install.packages(p, dependencies = TRUE, type = "source") で: installation of package ‘RCurl’ had non-zero exit status 4: install.packages(p, dependencies = TRUE, type = "source") で: installation of package ‘Rpoppler’ had non-zero exit status 5: install.packages(p, dependencies = TRUE, type = "source") で: installation of package ‘XML’ had non-zero exit status 6: install.packages(p, dependencies = TRUE, type = "source") で: installation of package ‘RCurl’ had non-zero exit status 7: install.packages(p, dependencies = TRUE, type = "source") で: installation of package ‘XML’ had non-zero exit status
個別にインストールしてみると、依存モジュールがないのが原因のよう。
install.packages("XML", dependencies=TRUE) Cannot find xml2-config ERROR: configuration failed for package ‘XML’ * removing ‘/home/yamautim/R/x86_64-redhat-linux-gnu-library/3.2/XML’
必要なモジュールをインストールして、
$ sudo yum -y install libxml2-devel curl-devel poppler-glib-devel freeglut-devel
再試行。
> install.packages("rgl", dependencies=TRUE) > install.packages("Rpoppler", dependencies=TRUE) > install.packages("XML", dependencies=TRUE) > install.packages("RCurl", dependencies=TRUE)
基礎練習で使うライブラリとデータの読み込み
基礎練習で使うライブラリとデータ(UFOの目撃情報データ)を読み込みます。
> setwd("01-Introduction/") > library("ggplot2") > library(plyr) > library(scales) > ufo <- read.delim("data/ufo/ufo_awesome.tsv", sep="\t", stringsAsFactors=FALSE, header=FALSE, na.strings="")
先頭、末尾の6行を表示。
> head(ufo) V1 V2 V3 V4 V5 1 19951009 19951009 Iowa City, IA <NA> <NA> 2 19951010 19951011 Milwaukee, WI <NA> 2 min. 3 19950101 19950103 Shelton, WA <NA> <NA> ... > tail(ufo) V1 V2 V3 V4 V5 61865 20100828 20100828 Los Angeles, CA disk 40 seconds 61866 20090424 20100820 Hartwell, GA oval 10 min 61867 20100821 20100826 Franklin Square, NY fireball 20 minutes 61868 20100827 20100827 Brighton, CO circle at lest 45 min 61869 20100818 20100821 Dryden (Canada), ON other 5 Min. maybe more 61870 20050502 20100824 Fort Knox, KY triangle 15 seconds
データをクリーニングする
読み込んだデータを、処理しやすい形に変換します。
データ列に名前をつける
> names(ufo) <- c("DateOccurred", "DateReported", "Location", "ShortDescription", "Duration", "LongDescription") > head(ufo) DateOccurred DateReported Location ShortDescription Duration 1 19951009 19951009 Iowa City, IA <NA> <NA> 2 19951010 19951011 Milwaukee, WI <NA> 2 min. 3 19950101 19950103 Shelton, WA <NA> <NA> 4 19950510 19950510 Columbia, MO <NA> 2 min. 5 19950611 19950614 Seattle, WA <NA> <NA> 6 19951025 19951024 Brunswick County, ND <NA> 30 min.
V1
などとなっていたところに、DateOccurred
のようなわかりやすい名前が付きました。
DateOccurred,DateReported を日付に変換する
DateOccurred
列は、日付を示す文字列なので、Dateオブジェクトに変換します。
> ufo$DateOccurred <- as.Date(ufo$DateOccurred, format = "%Y%m%d") strptime(x, format, tz = "GMT") でエラー: 入力文字列が長すぎます
エラーになりました。長い文字列を含む列があるよう。探します。
> head(ufo[which(nchar(ufo$DateOccurred)!=8|nchar(ufo$DateReported)!=8),1]) [1] "ler@gnv.ifas.ufl.edu" [2] "0000" [3] "Callers report sighting a number of soft white balls of lights headingin an easterly directing then changing direction to the west beforespeeding off to the north west." [4] "0000" [5] "0000" [6] "0000"
変なデータがありますな。ということで、8文字でない列を削除します。
# 行のDateOccurred または、DateReported が8文字かどうかを格納するデータ(good.rows)を作成 > good.rows <- ifelse(nchar(ufo$DateOccurred) != 8 | nchar(ufo$DateReported) != 8, FALSE, TRUE) # FALSE(=DateOccurred または、DateReported が8文字でない行)の数を確認。 > length(which(!good.rows)) [1] 731 # good.rowsがTRUEの行のみ抽出 > ufo <- ufo[good.rows, ]
変換を再試行。今度は、日付型に変換できました。
> ufo$DateOccurred <- as.Date(ufo$DateOccurred, format = "%Y%m%d") > ufo$DateReported <- as.Date(ufo$DateReported, format = "%Y%m%d") > head(ufo) DateOccurred DateReported Location ShortDescription Duration 1 1995-10-09 1995-10-09 Iowa City, IA <NA> <NA> 2 1995-10-10 1995-10-11 Milwaukee, WI <NA> 2 min. 3 1995-01-01 1995-01-03 Shelton, WA <NA> <NA>
関数を作ってLocation列のデータを都市名と州名に分割する
Location列のデータは、Iowa City, IA
のような「都市名, 州名」形式の文字列になっています。
関数を使ってこれを都市名, 州名に分解します。
まずは、分解を行う関数を作成。
> get.location <- function(l) { split.location <- tryCatch(strsplit(l, ",")[[1]], error = function(e) return(c(NA, NA))) clean.location <- gsub("^ ","",split.location) if (length(clean.location) > 2) { return(c(NA,NA)) } else { return(clean.location) } }
lapplyを使って、関数を適用したデータのリストを作成。
> city.state <- lapply(ufo$Location, get.location) > head(city.state) [[1]] [1] "Iowa City" "IA" [[2]] [1] "Milwaukee" "WI"
これを、ufoに追加します。まずは、do.call
でリストを行列に変換。
> location.matrix <- do.call(rbind, city.state) > head(location.matrix) [,1] [,2] [1,] "Iowa City" "IA" [2,] "Milwaukee" "WI" [3,] "Shelton" "WA" [4,] "Columbia" "MO" [5,] "Seattle" "WA" [6,] "Brunswick County" "ND"
transform
で ufo
に追加します。
> ufo <- transform(ufo, USCity = location.matrix[, 1], USState = location.matrix[, 2], stringsAsFactors = FALSE)
また、データには、カナダのものが含まれているので、これも除去します。
USState
にアメリカの州名以外が入っているデータを削除。
> ufo$USState <- state.abb[match(ufo$USState, state.abb)] > ufo.us <- subset(ufo, !is.na(USState))
これで、データのクリーニングは完了。
> summary(ufo.us) DateOccurred DateReported Location Min. :1400-06-30 Min. :1905-06-23 Length:51636 1st Qu.:1999-09-07 1st Qu.:2002-04-14 Class :character Median :2004-01-10 Median :2005-03-27 Mode :character Mean :2001-02-13 Mean :2004-11-30 3rd Qu.:2007-07-27 3rd Qu.:2008-01-20 Max. :2010-08-30 Max. :2010-08-30 ShortDescription Duration LongDescription USCity Length:51636 Length:51636 Length:51636 Length:51636 Class :character Class :character Class :character Class :character Mode :character Mode :character Mode :character Mode :character USState Length:51636 Class :character Mode :character > head(ufo.us) DateOccurred DateReported Location ShortDescription Duration 1 1995-10-09 1995-10-09 Iowa City, IA <NA> <NA> 2 1995-10-10 1995-10-11 Milwaukee, WI <NA> 2 min. 3 1995-01-01 1995-01-03 Shelton, WA <NA> <NA> 4 1995-05-10 1995-05-10 Columbia, MO <NA> 2 min. 5 1995-06-11 1995-06-14 Seattle, WA <NA> <NA> 6 1995-10-25 1995-10-24 Brunswick County, ND <NA> 30 min.
データを分析する
クリーニングしたデータを分析して、州/月ごとの目撃情報の傾向を分析します。
DateOccurredのばらつきを調べる
> summary(ufo.us$DateOccurred) Min. 1st Qu. Median Mean 3rd Qu. Max. NA's "1400-06-30" "1999-09-07" "2004-01-10" "2001-02-13" "2007-07-27" "2010-08-30" "1"
1400年のデータが含まれている・・・。ヒストグラムで、分布をみてみます。
> quick.hist <- ggplot(ufo.us, aes(x = DateOccurred)) + geom_histogram() + scale_x_date(date_breaks = "50 years", date_labels = "%Y") > ggsave(plot = quick.hist, filename = file.path("images", "quick_hist.png"), height = 6, width = 8)
大部分は最近の20年に集中している模様。この範囲に絞って分析するため、古いデータを取り除きます。
> ufo.us <- subset(ufo.us, DateOccurred >= as.Date("1990-01-01"))
データを州/月ごとに集計する
まずは、DateOccurred
列を「年-月」に変換した列を作成。
> ufo.us$YearMonth <- strftime(ufo.us$DateOccurred, format = "%Y-%m")
次に、月/州ごとにデータをグループ化して、データ数を集計します。
> sightings.counts <- ddply(ufo.us, .(USState,YearMonth), nrow) > head(sightings.counts) USState YearMonth V1 1 AK 1990-01 1 2 AK 1990-03 1 3 AK 1990-05 1 4 AK 1993-11 1 5 AK 1994-11 1 6 AK 1995-01 1
これだけだと、データが一つもない月が集計結果に含まれないので、それを補完します。
seq.Date
を使ってシーケンスを作成。
> date.range <- seq.Date(from = as.Date(min(ufo.us$DateOccurred)), to = as.Date(max(ufo.us$DateOccurred)), by = "month") > date.strings <- strftime(date.range, "%Y-%m")
作成したシーケンスに、州の一覧を掛け合わせて、州/月の行列を作成します。
> states.dates <- lapply(state.abb, function(s) cbind(s, date.strings)) > states.dates <- data.frame(do.call(rbind, states.dates), stringsAsFactors = FALSE) > head(states.dates) s date.strings 1 AL 1990-01 2 AL 1990-02 3 AL 1990-03 4 AL 1990-04 5 AL 1990-05 6 AL 1990-06
さらに、sightings.counts
をマージして、月の欠落がない集計データを作成。
> all.sightings <- merge(states.dates, sightings.counts, by.x = c("s", "date.strings"), by.y = c("USState", "YearMonth"), all = TRUE) > head(all.sightings) s date.strings V1 1 AK 1990-01 1 2 AK 1990-02 NA 3 AK 1990-03 1 4 AK 1990-04 NA 5 AK 1990-05 1 6 AK 1990-06 NA
わかりやすいよう、列に名前を付けます。また、集計しやすいようにNA
を0
に変換するなどの操作を行っておきます。
> names(all.sightings) <- c("State", "YearMonth", "Sightings") > all.sightings$Sightings[is.na(all.sightings$Sightings)] <- 0 > all.sightings$YearMonth <- as.Date(rep(date.range, length(state.abb))) > all.sightings$State <- as.factor(all.sightings$State) > head(all.sightings) State YearMonth Sightings 1 AK 1990-01-01 1 2 AK 1990-02-01 0 3 AK 1990-03-01 1 4 AK 1990-04-01 0 5 AK 1990-05-01 1 6 AK 1990-06-01 0
分析用データはこれで完成。
月/州ごとの目撃情報数をグラフにして分析する
> state.plot <- ggplot(all.sightings, aes(x = YearMonth,y = Sightings)) + geom_line(aes(color = "darkblue")) + facet_wrap(~State, nrow = 10, ncol = 5) + theme_bw() + scale_color_manual(values = c("darkblue" = "darkblue"), guide = "none") + scale_x_date(date_breaks = "5 years", date_labels = '%Y') + xlab("Years") + ylab("Number of Sightings") + ggtitle("Number of UFO sightings by Month-Year and U.S. State (1990-2010)") > ggsave(plot = state.plot, filename = file.path("images", "ufo_sightings.png"), width = 14, height = 8.5)
こんなグラフになります。
分析は省略。
トラップリピートイフダンのような注文を発行するエージェントのサンプル
FXシステムトレードフレームワーク「Jiji」のサンプル その2、として、 トラップリピートイフダンのような注文を発行するエージェントを作ってみました。
※トラップリピートイフダン(トラリピ)は、マネースクウェアジャパン(M2J)の登録商標です。
トラップリピートイフダンとは
指値/逆指値の注文と決済を複数組み合わせて行い、その中でレートが上下することで利益を出すことを狙う、発注ロジックです。 具体的にどういった動きをするのかは、マネースクウェアジャパン のサイトがとてもわかりやすいので、そちらをご覧ください。
特徴
FX研究日記さんの評価記事が参考になります。
- レンジ相場では、利益を出しやすい
- ×レートが逆行すると損失を貯めこんでしまう
仕組みからして、いわゆるコツコツドカンなシステムという印象です。 レンジ相場なら利益を積み上げやすいので、トレンドを判定するロジックと組み合わせて、レートが一定のレンジで動作しそうになったら稼働させる、などすれば使えるかも。
エージェントのコード
- 実装は、こちらのサイトで配布されているEAを参考にさせていただきました。
TrapRepeatIfDoneAgent
が、エージェントの本体です。これをバックテストやリアルトレードで動作させればOK。- エージェントファイルの追加の方法など、Jijiの基本的な使い方はこちらをご覧ください。
- 機能の再利用ができるように、発注処理は
TrapRepeatIfDone
に実装しています。
# === トラップリピートイフダンのような注文を発行するエージェント class TrapRepeatIfDoneAgent include Jiji::Model::Agents::Agent def self.description <<-STR トラップリピートイフダンのような注文を発行するエージェント STR end # UIから設定可能なプロパティの一覧 def self.property_infos [ Property.new('trap_interval_pips', 'トラップを仕掛ける間隔(pips)', 50), Property.new('trade_units', '1注文あたりの取引数量', 1), Property.new('profit_pips', '利益を確定するpips', 100), Property.new('slippage', '許容スリッページ(pips)', 3) ] end def post_create @trap_repeat_if_done = TrapRepeatIfDone.new( broker.pairs.find {|p| p.name == :USDJPY }, :buy, @trap_interval_pips.to_i, @trade_units.to_i, @profit_pips.to_i, @slippage.to_i, logger) end def next_tick(tick) @trap_repeat_if_done.register_orders(broker) end def state @trap_repeat_if_done.state end def restore_state(state) @trap_repeat_if_done.restore_state(state) end end # トラップリピートイフダンのような注文を発行するクラス class TrapRepeatIfDone # コンストラクタ # # target_pair:: 現在の価格を格納するTick::Valueオブジェクト # sell_or_buy:: 取引モード。 :buy の場合、買い注文を発行する。 :sellの場合、売 # trap_interval_pips:: トラップを仕掛ける間隔(pips) # trade_units:: 1注文あたりの取引数量 # profit_pips:: 利益を確定するpips # slippage:: 許容スリッページ。nilの場合、指定しない def initialize(target_pair, sell_or_buy=:buy, trap_interval_pips=50, trade_units=1, profit_pips=100, slippage=3, logger=nil) @target_pair = target_pair @trap_interval_pips = trap_interval_pips @slippage = slippage @mode = if sell_or_buy == :sell Sell.new(target_pair, trade_units, profit_pips, slippage, logger) else Buy.new(target_pair, trade_units, profit_pips, slippage, logger) end @logger = logger @registerd_orders = {} end # 注文を登録する # # broker:: broker def register_orders(broker) broker.instance_variable_get(:@broker).refresh_positions # 常に最新の建玉を取得して利用するようにする # TODO 公開APIにする each_traps(broker.tick) do |trap_open_price| next if order_or_position_exists?(trap_open_price, broker) register_order(trap_open_price, broker) end end def state @registerd_orders end def restore_state(state) @registerd_orders = state unless state.nil? end private def each_traps(tick) current_price = @mode.resolve_current_price(tick[@target_pair.name]) base = resolve_base_price(current_price) 6.times do |n| # baseを基準に、上下3つのトラップを仕掛ける trap_open_price = BigDecimal.new(base, 10) \ + BigDecimal.new(@trap_interval_pips, 10) * (n-3) * @target_pair.pip yield trap_open_price end end # 現在価格をtrap_interval_pipsで丸めた価格を返す。 # # 例) trap_interval_pipsが50の場合、 # resolve_base_price(120.10) # -> 120.00 # resolve_base_price(120.49) # -> 120.00 # resolve_base_price(120.51) # -> 120.50 # def resolve_base_price(current_price) current_price = BigDecimal.new(current_price, 10) pip_precision = 1 / @target_pair.pip (current_price * pip_precision / @trap_interval_pips ).ceil \ * @trap_interval_pips / pip_precision end # trap_open_priceに対応するオーダーを登録する def register_order(trap_open_price, broker) result = @mode.register_order(trap_open_price, broker) unless result.order_opened.nil? @registerd_orders[key_for(trap_open_price)] \ = result.order_opened.internal_id end end # trap_open_priceに対応するオーダーを登録済みか評価する def order_or_position_exists?(trap_open_price, broker) order_exists?(trap_open_price, broker) \ || position_exists?(trap_open_price, broker) end def order_exists?(trap_open_price, broker) key = key_for(trap_open_price) return false unless @registerd_orders.include? key id = @registerd_orders[key] order = broker.orders.find {|o| o.internal_id == id } return !order.nil? end def position_exists?(trap_open_price, broker) # trapのリミット付近でレートが上下して注文が大量に発注されないよう、 # trapのリミット付近を開始値とする建玉が存在する間は、trapの注文を発行しない slipage_price = (@slippage.nil? ? 10 : @slippage) * @target_pair.pip position = broker.positions.find do |p| # 注文時に指定したpriceちょうどで約定しない場合を考慮して、 # 指定したslippage(指定なしの場合は10pips)の誤差を考慮して存在判定をする p.entry_price < trap_open_price + slipage_price \ && p.entry_price > trap_open_price - slipage_price end return !position.nil? end def key_for(trap_open_price) (trap_open_price * (1 / @target_pair.pip)).to_i.to_s end # 取引モード(売 or 買) # 買(Buy)の場合、買でオーダーを行う。売(Sell)の場合、売でオーダーを行う。 class Mode def initialize(target_pair, trade_units, profit_pips, slippage, logger) @target_pair = target_pair @trade_units = trade_units @profit_pips = profit_pips @slippage = slippage @logger = logger end # 現在価格を取得する(買の場合Askレート、売の場合Bidレートを使う) # # tick_value:: 現在の価格を格納するTick::Valueオブジェクト # 戻り値:: 現在価格 def resolve_current_price(tick_value) end # 注文を登録する def register_order(trap_open_price, broker) end def calculate_price(price, pips) price = BigDecimal.new(price, 10) pips = BigDecimal.new(pips, 10) * @target_pair.pip (price + pips).to_f end def pring_order_log(mode, options, timestamp) return unless @logger message = [ mode, timestamp, options[:price], options[:take_profit], options[:lower_bound], options[:upper_bound] ].map {|item| item.to_s }.join(" ") @logger.info message end end class Sell < Mode def resolve_current_price(tick_value) tick_value.bid end def register_order(trap_open_price, broker) timestamp = broker.tick.timestamp options = create_option(trap_open_price, timestamp) pring_order_log("sell", options, timestamp) broker.sell(@target_pair.name, @trade_units, :marketIfTouched, options) end def create_option(trap_open_price, timestamp) options = { price: trap_open_price.to_f, take_profit: calculate_price(trap_open_price, @profit_pips*-1), expiry: timestamp + 60*60*24*7 } unless @slippage.nil? options[:lower_bound] = calculate_price(trap_open_price, @slippage*-1) options[:upper_bound] = calculate_price(trap_open_price, @slippage) end options end end class Buy < Mode def resolve_current_price(tick_value) tick_value.ask end def register_order(trap_open_price, broker) timestamp = broker.tick.timestamp options = create_option(trap_open_price, timestamp) pring_order_log("buy", options, timestamp) broker.buy(@target_pair.name, @trade_units, :marketIfTouched, options) end def create_option(trap_open_price, timestamp) options = { price: trap_open_price.to_f, take_profit: calculate_price(trap_open_price, @profit_pips), expiry: timestamp + 60*60*24*7 } unless @slippage.nil? options[:lower_bound] = calculate_price(trap_open_price, @slippage*-1) options[:upper_bound] = calculate_price(trap_open_price, @slippage) end options end end end
インタラクティブにトレーリングストップ決済を行うBotを作ってみた
FXシステムトレードフレームワーク「Jiji」の使い方サンプル その1、ということで、 Jijiを使って、インタラクティブにトレーリングストップ決済を行うBotを作ってみました。
トレーリングストップとは
建玉(ポジション)の決済方法の一つで、「最高値を更新するごとに、逆指値の決済価格を切り上げていく」決済ロジックです。
例) USDJPY/120.10で買建玉を作成。これを、10 pips でトレーリングストップする場合、
- 建玉作成直後は、120.00 で逆指値決済される状態になる
- レートが 120.30 になった場合、逆指値の決済価格が高値に合わせて上昇し、120.20に切り上がる
- その後、レートが120.20 になると、逆指値で決済される
トレンドに乗っている間はそのまま利益を増やし、トレンドが変わって下げ始めたら決済する、という動きをする決済ロジックですね。
インタラクティブにしてみる
単純なトレーリングストップだけなら証券会社が提供している機能で実現できるので、少し手を加えてインタラクティブにしてみました。
トレーリングストップでは、以下のようなパターンがありがち。
- すこし大きなドローダウンがきて、トレンド変わってないのに決済されてしまい、利益を逃した・・
- レートが急落した時に、決済が遅れて損失が広がった・・・
これを回避できるように、Botでの強制決済に加えて、人が状況をみて決済するかどうか判断できる仕組みをいれてみます。
仕様
以下のような動作をします。
トレーリングストップの閾値を2段階で指定できるようにして、1つ目の閾値を超えたタイミングでは警告通知を送信。
- 通知を確認して、即時決済するか、保留するか判断できる。
- 決済をスムーズに行えるよう、通知から1タップで決済を実行できるようにする。
-
- 夜間など通知を受けとっても対処できない場合を考慮して、2つ目の閾値を超えたら、強制決済するようにしておきます。
- なお、決済時にはOANDA JAPANから通知が送信されるので、Jijiからの通知は省略しました。
Bot(エージェント)のコード
TrailingStopAgent
が、Botの本体。これをバックテストやリアルトレードで動作させればOKです。- エージェントファイルの追加の方法など、Jijiの基本的な使い方はこちらをご覧ください。
TrailingStopAgent
自体は、新規に建玉を作ることはしません。- 機能の再利用ができるように、処理は
TrailingStopManager
に実装しています。
# トレーリングストップで建玉を決済するエージェント class TrailingStopAgent include Jiji::Model::Agents::Agent def self.description <<-STR トレーリングストップで建玉を決済するエージェント。 - 損益が警告を送る閾値を下回ったら、1度だけ警告をPush通知で送信。 - さらに決済する閾値も下回ったら、建玉を決済します。 STR end # UIから設定可能なプロパティの一覧 def self.property_infos [ Property.new('warning_limit', '警告を送る閾値', 20), Property.new('closing_limit', '決済する閾値', 40) ] end def post_create @manager = TrailingStopManager.new( @warning_limit.to_i, @closing_limit.to_i, notifier) end def next_tick(tick) @manager.check(broker.positions, broker.pairs) end def execute_action(action) @manager.process_action(action, broker.positions) || '???' end def state { trailing_stop_manager: @manager.state } end def restore_state(state) if state[:trailing_stop_manager] @manager.restore_state(state[:trailing_stop_manager]) end end end # 建玉を監視し、最新のレートに基づいてトレールストップを行う class TrailingStopManager # コンストラクタ # # warning_limit:: 警告を送信する閾値(pip) # closing_limit:: 決済を行う閾値(pip) # notifier:: notifier def initialize(warning_limit, closing_limit, notifier) @warning_limit = warning_limit @closing_limit = closing_limit @notifier = notifier @states = {} end # 建玉がトレールストップの閾値に達していないかチェックする。 # warning_limit を超えている場合、警告通知を送信、 # closing_limit を超えた場合、強制的に決済する。 # # positions:: 建て玉一覧(broker#positions) # pairs:: 通貨ペア一覧(broker#pairs) def check(positions, pairs) @states = positions.each_with_object({}) do |position, r| r[position.id.to_s] = check_position(position, pairs) end end # アクションを処理する # # action:: アクション # positions:: 建て玉一覧(broker#positions) # 戻り値:: アクションを処理できた場合、レスポンスメッセージ。 # TrailingStopManagerが管轄するアクションでない場合、nil def process_action(action, positions) return nil unless action =~ /trailing\_stop\_\_([a-z]+)_(.*)$/ case $1 when "close" then position = positions.find {|p| p.id.to_s == $2 } return nil unless position position.close return "建玉を決済しました。" end end # 永続化する状態。 def state @states.each_with_object({}) {|s, r| r[s[0]] = s[1].state } end # 永続化された状態から、インスタンスを復元する def restore_state(state) @states = state.each_with_object({}) do |s, r| state = PositionState.new( nil, @warning_limit, @closing_limit ) state.restore_state(s[1]) r[s[0]] = state end end private # 建玉の状態を更新し、閾値を超えていたら対応するアクションを実行する。 def check_position(position, pairs) state = get_and_update_state(position, pairs) if state.under_closing_limit? position.close elsif state.under_warning_limit? unless state.sent_warning # 通知は1度だけ送信する send_notification(position, state) state.sent_warning = true end end return state end def get_and_update_state(position, pairs) state = create_or_get_state(position, pairs) state.update(position) state end def create_or_get_state(position, pairs) key = position.id.to_s return @states[key] if @states.include? key PositionState.new( retrieve_pip_for(position.pair_name, pairs), @warning_limit, @closing_limit ) end def retrieve_pip_for(pair_name, pairs) pairs.find {|p| p.name == pair_name }.pip end def send_notification(position, state) message = "#{create_position_description(position)}" \ + " がトレールストップの閾値を下回りました。決済しますか?" @notifier.push_notification(message, [{ 'label' => '決済する', 'action' => 'trailing_stop__close_' + position.id.to_s }]) end def create_position_description(position) sell_or_buy = position.sell_or_buy == :sell ? "売" : "買" "#{position.pair_name}/#{position.entry_price}/#{sell_or_buy}" end end class PositionState attr_reader :max_profit, :profit_or_loss, :max_profit_time, :last_update_time attr_accessor :sent_warning def initialize(pip, warning_limit, closing_limit) @pip = pip @warning_limit = warning_limit @closing_limit = closing_limit @sent_warning = false end def update(position) @units = position.units @profit_or_loss = position.profit_or_loss @last_update_time = position.updated_at if @max_profit.nil? || position.profit_or_loss > @max_profit @max_profit = position.profit_or_loss @max_profit_time = position.updated_at @sent_warning = false # 高値を更新したあと、 warning_limit を超えたら再度警告を送るようにする end end def under_warning_limit? return false if @max_profit.nil? difference >= @warning_limit * @units * @pip end def under_closing_limit? return false if @max_profit.nil? difference >= @closing_limit * @units * @pip end def state { "max_profit" => @max_profit, "max_profit_time" => @max_profit_time, "pip" => @pip, "sent_warning" => @sent_warning } end def restore_state(state) @max_profit = state["max_profit"] @max_profit_time = state["max_profit_time"] @pip = state["pip"] @sent_warning = state["sent_warning"] end private def difference @max_profit - @profit_or_loss end end
それでは、みなさま、良いお年を。
MongoDBのinsert/updateをまとめて、bulk insert/update に流すユーティリティを書いた
バッチ処理などでMongoDBに大量のinsert/updateを行うとき、Mongoidを使って1つずつ #save
してると遅い。
ということで、複数の #save
をまとめて bulk insert/update に流すユーティリティを書いてみました。
使い方
- モデルクラスで、
Mongoid::Document
とUtils::BulkWriteOperationSupport
をinclude
する。 Utils::BulkWriteOperationSupport.begin_transaction
を呼び出してから、モデルの#save
を呼び出す。- この時点ではMongoDBへのinsert/updateは行われず、バッファに蓄積されます。
Utils::BulkWriteOperationSupport.end_transaction
を呼び出すと、バッファのデータが#bulk_write
でまとめて永続化される。
class TestModel # Mongoid::Document と Utils::BulkWriteOperationSupport をincludeする include Mongoid::Document include Utils::BulkWriteOperationSupport store_in collection: 'test_model' field :name, type: String end #略 puts TestModel.count # => 0 Utils::BulkWriteOperationSupport.begin_transaction # #begin_transaction を呼び出したあと、モデルを作成/変更して、 #save を呼び出す。 a = TestModel.new a.name = 'a' b = TestModel.new b.name = 'b' a.save b.save # #end_transaction を実行するまで、永続化されない puts TestModel.count # => 0 # #end_transaction を呼び出すと、バッファのデータが #bulk_write でまとめて永続化される。 Utils::BulkWriteOperationSupport.end_transaction puts TestModel.count # => 2
ユーティリティのコード
Document#save
を書き換えて、#begin_transaction
~#end_transaction
の間であれば、スレッドローカルに永続化対象としてマーク。#end_transaction
が呼び出されたタイミンクで、まとめて#bulk_write
で永続化します。- 参照整合性のチェックとか、いろいろ手抜きなので必要に応じて改造してください。
module BulkWriteOperationSupport KEY = BulkWriteOperationSupport.name def save if BulkWriteOperationSupport.in_transaction? BulkWriteOperationSupport.transaction << self else super end end def self.in_transaction? !transaction.nil? end def self.begin_transaction Thread.current[KEY] = Transaction.new end def self.end_transaction return unless in_transaction? transaction.execute Thread.current[KEY] = nil end def self.transaction Thread.current[KEY] end def create_insert_operation { :insert_one => as_document } end def create_update_operation { :update_one => { :filter => { :_id => id }, :update => {'$set' => collect_changed_values } } } end private def collect_changed_values changes.each_with_object({}) do |change, r| r[change[0].to_sym] = change[1][1] end end class Transaction def initialize @targets = {} end def <<(model) targets_of( model.class )[model.object_id] = model end def execute until @targets.empty? model_class = @targets.keys.first execute_bulk_write_operations(model_class) end end def size @targets.values.reduce(0) {|a, e| a + e.length } end private def targets_of( model_class ) @targets[model_class] ||= {} end def execute_bulk_write_operations(model_class) return unless @targets.include?(model_class) execute_parent_object_bulk_write_operations_if_exists(model_class) client = model_class.mongo_client[model_class.collection_name] operations = create_operations(@targets[model_class].values) client.bulk_write(operations) unless operations.empty? @targets.delete model_class end def execute_parent_object_bulk_write_operations_if_exists(model_class) parents = model_class.reflect_on_all_associations(:belongs_to) parents.each do |m| klass = m.klass execute_bulk_write_operations(klass) end end def create_operations(targets) targets.each_with_object([]) do |model, array| if model.new_record? model.new_record = false array << model.create_insert_operation else array << model.create_update_operation if model.changed? end end end end end
nukeproof/oanda_api のコネクションリーク問題とその対策
OANDA fx Trade APIのRubyクライアント「nukeproof/oanda_api」には、TCPコネクションリークの問題があり、長時間連続で利用しているとファイルディスクリプタが枯渇します。
内部で利用している persistent_http の古いバージョンにある不具合が原因(最新の2.0.1では改修済み)のため、Gemfileなどで最新バージョンを使うようにすると回避できます。
gem 'persistent_http', '2.0.1'
問題の詳細
Jijiを10日程度連続稼働させていて発覚。突然、以下のエラーが発生するようになりました。
E, [2015-12-09T01:23:06.337582 #7932] ERROR -- : Too many open files - getaddrinfo (Errno::EMFILE) /home/yamautim/.rbenv/versions/2.2.3/lib/ruby/2.2.0/net/http.rb:879:in `initialize' /home/yamautim/.rbenv/versions/2.2.3/lib/ruby/2.2.0/net/http.rb:879:in `open' /home/yamautim/.rbenv/versions/2.2.3/lib/ruby/2.2.0/net/http.rb:879:in `block in connect'
lsof コマンドの出力行数も少しずつ増えていきます。
$ lsof -p <Jijiのpid> | wc -l 77 $ lsof -p <Jijiのpid> | wc -l 79
起動直後と、しばらく起動した後で、lsofの出力結果のDiffをとると、CLOSE_WAITのhttpsコネクションが増加していました。
ruby 4389 xxxx 23u IPv4 9265756 0t0 TCP localhost.localdomain:35773->unknown.xxxx.net:https (CLOSE_WAIT)
外向きのHTTPS通信なので、OANDA へのアクセスっぽい。
原因
最初にも書きましたが、persistent_http の古いバージョンにある不具合が原因です。最新の 2.0.1 を使えば改修されます。
- 1.0.6では、内部で利用しているGenePoolに渡すオプションがnilになっており、コネクションの破棄が正しく行われない状態になっています。
- このコミットで改修されていて、最新の 2.0.1 を使えばOK。
- ただし、oanda_api が直接依存している persistent_httparty で バージョン2以下を使うよう明示されているため、普通に使うと 1.0.6 が使われてしまいます。
- このため、Gemfileなどで最新バージョンを使うように明示する等の対応が必要です。
- ちなみに、persistent_httparty に、依存するpersistent_httpのバージョンを上げるPull Requestはあるのですが、マージされていないようです・・・。