無料で使えるシステムトレードフレームワーク「Jiji」 をリリースしました!

・OANDA Trade APIを利用した、オープンソースのシステムトレードフレームワークです。
・自分だけの取引アルゴリズムで、誰でも、いますぐ、かんたんに、自動取引を開始できます。

機械学習手習い: スパムフィルタを作る

「入門 機械学習」手習い、3日目。「3章 分類:スパムフィルタ」です。

www.amazon.co.jp

ナイーブベイズ分類器を作って、メールがスパムかどうかを判定するフィルタを作ります。

分類器の仕組み

  • 1) 以下の単語セットを作成
    • (a) スパムメッセージに出現しやすい単語とその出現確率
    • (b) スパムメッセージに出現しにくい単語とその出現確率
  • 2) で作成した単語セットを元に、メール本文を評価し、以下を算出
    • (a2) メールをスパムと仮定した時の尤もらしさ
    • (b2) メールを非スパムと仮定した時の尤もらしさ
  • 3) a2 > b2 となるメールをスパムと判定する

という感じで判定を行います。

必要なモジュールとデータの読み込み

> setwd("03-Classification/")
> library('tm')
> library('ggplot2')

# テスト用データ
# 分類機の訓練用
> spam.path <- file.path("data", "spam")        # スパムデータ
> easyham.path <- file.path("data", "easy_ham") # 非スパムデータ(易)
> hardham.path <- file.path("data", "hard_ham") # 非スパムデータ(難)
# 分類機のテスト用
> spam2.path <- file.path("data", "spam_2")        # スパムデータ
> easyham2.path <- file.path("data", "easy_ham_2") # 非スパムデータ(易)
> hardham2.path <- file.path("data", "hard_ham_2") # 非スパムデータ(難)

メールから本文を取り出す

ファイルを読み込んで、本文を返す関数を作成します。

> get.msg <- function(path) {
  con <- file(path, open = "rt", encoding = "latin1")
  text <- readLines(con)
  # The message always begins after the first full line break
  msg <- text[seq(which(text == "")[1] + 1, length(text), 1)]
  close(con)
  return(paste(msg, collapse = "\n"))
}

動作テスト。

> get.msg("data/spam/00001.7848dde101aa985090474a91ec93fcf0")
[1] "<!DOCTYPE HTML PUBLIC \"-//W3C//DTD HTML 4.0 Transitio...

sapply を使って、data/spam 内のスパムメールから本文を読み込みます。

> spam.docs <- dir(spam.path)
> spam.docs <- spam.docs[which(spam.docs != "cmds")]
> all.spam <- sapply(spam.docs,function(p) get.msg(file.path(spam.path, p)))
> head(all.spam)

単語文書行列を作る

単語を行、文書を列として、文書中の単語の出現数をカウントした、単語文書行列(TDM:Term Document Matrix)を作ります。

↓こんなイメージ

a.txt b.txt c.txt
hello 2 1 0
world 0 2 3


まずは、TDMを作成する関数を定義。

> get.tdm <- function(doc.vec){
  control <- list(stopwords = TRUE,
                  removePunctuation = TRUE,
                  removeNumbers = TRUE,
                  minDocFreq = 2)
  doc.corpus <- Corpus(VectorSource(doc.vec))
  doc.dtm <- TermDocumentMatrix(doc.corpus, control)
  return(doc.dtm)
}

all.spam から、TDMを作成します。

> spam.tdm <- get.tdm(all.spam)

スパムの訓練データを作る

スパムに含まれる単語の出現確率を集計します。

# all.spamのTDMをRの行列に変換
> spam.matrix <- as.matrix(spam.tdm)
# 各単語の、全スパム中での出現頻度をカウント
> spam.counts <- rowSums(spam.matrix)
# データフレームに変換
> spam.df <- data.frame(cbind(names(spam.counts),
                      as.numeric(spam.counts)),
                      stringsAsFactors = FALSE)
> names(spam.df) <- c("term", "frequency")
# 出現確率を算出
> spam.df$frequency <- as.numeric(spam.df$frequency)
> spam.occurrence <- sapply(1:nrow(spam.matrix), function(i) {
  length(which(spam.matrix[i, ] > 0)) / ncol(spam.matrix)
})
> spam.density <- spam.df$frequency / sum(spam.df$frequency)
> spam.df <- transform(spam.df,
                     density = spam.density,
                     occurrence = spam.occurrence)

訓練データの中身を確認。

> head(spam.df[with(spam.df, order(-occurrence)),])
        term frequency     density occurrence
7693   email       813 0.005855914      0.566
18706 please       425 0.003061210      0.508
14623   list       409 0.002945964      0.444
27202   will       828 0.005963957      0.422
2970    body       379 0.002729879      0.408
9369    free       543 0.003911146      0.390

スパムメール中の56%が email を含んでいます。

非スパムの訓練データを作る

スパムの訓練データを作ったのと同じ手順で、非スパムの訓練データも作ります。

> easyham.docs <- dir(easyham.path)
> easyham.docs <- easyham.docs[which(easyham.docs != "cmds")]
# 最初の分類機を作る際に、各メッセージがスパムである確率とそうでない確率を等しいと仮定するため、
# 訓練データもスパムの文書数と同じ数だけに限定する。
> all.easyham <- sapply(easyham.docs[1:length(spam.docs)],
                      function(p) get.msg(file.path(easyham.path, p)))
> easyham.tdm <- get.tdm(all.easyham)
> easyham.matrix <- as.matrix(easyham.tdm)
> easyham.counts <- rowSums(easyham.matrix)
> easyham.df <- data.frame(cbind(names(easyham.counts),
                         as.numeric(easyham.counts)),
                         stringsAsFactors = FALSE)
> names(easyham.df) <- c("term", "frequency")
> easyham.df$frequency <- as.numeric(easyham.df$frequency)
> easyham.occurrence <- sapply(1:nrow(easyham.matrix), function(i) {
 length(which(easyham.matrix[i, ] > 0)) / ncol(easyham.matrix)
})
> easyham.density <- easyham.df$frequency / sum(easyham.df$frequency)
> easyham.df <- transform(easyham.df,
                        density = easyham.density,
                        occurrence = easyham.occurrence)

訓練データの中身を確認。

> head(easyham.df[with(easyham.df, order(-occurrence)),])
       term frequency     density occurrence
5193  group       232 0.003317366      0.388
12877   use       271 0.003875027      0.378
13511 wrote       237 0.003388861      0.378
1629    can       348 0.004976049      0.368
7244   list       248 0.003546150      0.368
8581    one       356 0.005090441      0.336

分類器を作る

メールファイルを受け取り、それがスパム、または、非スパムである確率を計算する分類器を定義します。

> classify.email <- function(path, training.df, prior = 0.5, c = 1e-6) {

  # ファイルから本文を取り出し、単語の出現数をカウント
  msg <- get.msg(path)
  msg.tdm <- get.tdm(msg)
  msg.freq <- rowSums(as.matrix(msg.tdm))

  # メール内の単語のうち、訓練データに含まれる単語を取得
  msg.match <- intersect(names(msg.freq), training.df$term)
  
  # 単語の出現確率を掛け合わせて、条件付き確率を算出する
  # この時、訓練データに含まれない単語は、非常に小さい確率0.0001として扱う。
  if(length(msg.match) < 1) {
    # 訓練データに含まれる単語が1つもない場合、条件付き確率は、
    # 事前確率(prior) * (0.0001の単語数乗)
    # となる。
    return(prior * c ^ (length(msg.freq)))
  } else {
    # 訓練データから、単語ごとの出現確率を取り出す。
    match.probs <- training.df$occurrence[match(msg.match, training.df$term)]
    # 条件付き確率を計算
    # 事前確率(prior) * (訓練データに含まれる単語の出現確率の積) * (0.0001の訓練データに含まれない単語数乗)
    return(prior * prod(match.probs) * c ^ (length(msg.freq) - length(msg.match)))
  }
}

判定してみる

非スパム(難)のデータを判定してみます。

> hardham.docs <- dir(hardham.path)
> hardham.docs <- hardham.docs[which(hardham.docs != "cmds")]
> hardham.spamtest <- sapply(hardham.docs,
                           function(p) classify.email(file.path(hardham.path, p), training.df = spam.df))
> hardham.hamtest <- sapply(hardham.docs,
                          function(p) classify.email(file.path(hardham.path, p), training.df = easyham.df))
> hardham.res <- ifelse(hardham.spamtest > hardham.hamtest,
                      TRUE,
                      FALSE)
> summary(hardham.res)
   Mode   FALSE    TRUE    NA's 
logical     243       6       0 

データはすべて非スパムのものなので、誤判定は6件、偽陽性率は2.4%。

テスト用データすべてを判定してみる

スパム判定を行う関数を作成。

spam.classifier <- function(path) {
  pr.spam <- classify.email(path, spam.df)
  pr.ham <- classify.email(path, easyham.df)
  return(c(pr.spam, pr.ham, ifelse(pr.spam > pr.ham, 1, 0)))
}

テスト用のメールデータをすべて判定してみます。

> easyham2.docs <- dir(easyham2.path)
> easyham2.docs <- easyham2.docs[which(easyham2.docs != "cmds")]

> hardham2.docs <- dir(hardham2.path)
> hardham2.docs <- hardham2.docs[which(hardham2.docs != "cmds")]

> spam2.docs <- dir(spam2.path)
> spam2.docs <- spam2.docs[which(spam2.docs != "cmds")]

> easyham2.class <- suppressWarnings(lapply(easyham2.docs, function(p) {
  spam.classifier(file.path(easyham2.path, p))
}))
> hardham2.class <- suppressWarnings(lapply(hardham2.docs, function(p) {
  spam.classifier(file.path(hardham2.path, p))
}))
> spam2.class <- suppressWarnings(lapply(spam2.docs, function(p) {
  spam.classifier(file.path(spam2.path, p))
}))

> easyham2.matrix <- do.call(rbind, easyham2.class)
> easyham2.final <- cbind(easyham2.matrix, "EASYHAM")

> hardham2.matrix <- do.call(rbind, hardham2.class)
> hardham2.final <- cbind(hardham2.matrix, "HARDHAM")

> spam2.matrix <- do.call(rbind, spam2.class)
> spam2.final <- cbind(spam2.matrix, "SPAM")

> class.matrix <- rbind(easyham2.final, hardham2.final, spam2.final)
> class.df <- data.frame(class.matrix, stringsAsFactors = FALSE)
> names(class.df) <- c("Pr.SPAM" ,"Pr.HAM", "Class", "Type")
> class.df$Pr.SPAM <- as.numeric(class.df$Pr.SPAM)
> class.df$Pr.HAM <- as.numeric(class.df$Pr.HAM)
> class.df$Class <- as.logical(as.numeric(class.df$Class))
> class.df$Type <- as.factor(class.df$Type)

文書ごとのスパム/非スパムの尤もらしさ、分類結果、メールの種別を含むデータフレームができました。

> head(class.df)
        Pr.SPAM        Pr.HAM Class    Type
1  0.000000e+00  0.000000e+00 FALSE EASYHAM
2 5.352364e-248 1.159512e-155 FALSE EASYHAM
3  0.000000e+00 5.103377e-216 FALSE EASYHAM
4  0.000000e+00  0.000000e+00 FALSE EASYHAM
5 2.083521e-169 1.221918e-108 FALSE EASYHAM
6  0.000000e+00  0.000000e+00 FALSE EASYHAM

文書の種類ごとに結果を集計してみます。

> get.results <- function(bool.vector) {
  results <- c(length(bool.vector[which(bool.vector == FALSE)]) / length(bool.vector),
               length(bool.vector[which(bool.vector == TRUE)]) / length(bool.vector))
  return(results)
}
> easyham2.col <- get.results(subset(class.df, Type == "EASYHAM")$Class)
> hardham2.col <- get.results(subset(class.df, Type == "HARDHAM")$Class)
> spam2.col <- get.results(subset(class.df, Type == "SPAM")$Class)

> class.res <- rbind(easyham2.col, hardham2.col, spam2.col)
> colnames(class.res) <- c("NOT SPAM", "SPAM")
> print(class.res)
              NOT SPAM       SPAM
easyham2.col 0.9871429 0.01285714
hardham2.col 0.9677419 0.03225806
spam2.col    0.4631353 0.53686471

非スパム(easyham2, hardham2)を間違ってスパムと判定する確率(偽陽性率)はそれぞれ1%,3%と低く、うまく分類できているもよう。 ただ、スパム(spam2)を間違ってスパムでないと判定する確率(偽陰性率)は46%。あれ、ちょっと高いかも?

結果を分散図にしてみる

> class.plot <- ggplot(class.df, aes(x = log(Pr.HAM), log(Pr.SPAM))) +
    geom_point(aes(shape = Type, alpha = 0.5)) +
    geom_abline(intercept = 0, slope = 1) +
    scale_shape_manual(values = c("EASYHAM" = 1,
                                  "HARDHAM" = 2,
                                  "SPAM" = 3),
                       name = "Email Type") +
    scale_alpha(guide = "none") +
    xlab("log[Pr(HAM)]") +
    ylab("log[Pr(SPAM)]") +
    theme_bw() +
    theme(axis.text.x = element_blank(), axis.text.y = element_blank())
> ggsave(plot = class.plot,
       filename = file.path("images", "03_final_classification.png"),
       height = 10,
       width = 10)

f:id:unageanu:20160111152705p:plain

横軸が「メールを非スパムと仮定した時の尤もらしさ」、縦軸が「メールをスパムと仮定した時の尤もらしさ」を示します。

「メールをスパムと仮定した時の尤もらしさ」 > 「メールを非スパムと仮定した時の尤もらしさ」となったメールをスパムと判定するので、真ん中の線より上の者はスパム、下は非スパムと判定されています。線より上に〇や△(非スパムのメール)がいくつかあったりはしますが、おおむね正しく判定できている感じですね。

事前分布を変えて、結果を改善する

↑では、とあるメールがあった時にそれがスパムである確率は50%(=世の中のメールの半分はスパムで半分はそうではない)と仮定していました。 現実には、80%は非スパム、残り20%がスパムなので、これを考慮して再計算することで結果を改善してみます。

spam.classifier を修正して、事前確率を変更。

spam.classifier <- function(path) {
  pr.spam <- classify.email(path, spam.df, prior=0.2)
  pr.ham <- classify.email(path, easyham.df, prior=0.8)
  return(c(pr.spam, pr.ham, ifelse(pr.spam > pr.ham, 1, 0)))
}

再計算した結果は以下。

> print(class.res)
              NOT SPAM       SPAM
easyham2.col 0.9892857 0.01071429
hardham2.col 0.9717742 0.02822581
spam2.col    0.4652827 0.53471725

非スパム(easyham2, hardham2)を間違ってスパムと判定する確率は少し改善しました。 一方、スパム(spam2)を間違ってスパムでないと判定する確率は悪化。。。

機械学習手習い: 数値によるデータの要約と可視化手法

「入門 機械学習」手習い、今日は「2章 データの調査」です。

www.amazon.co.jp

数値によるデータの要約と、可視化手法を学びます。

テスト用データの読み込み

> setwd("02-Exploration/")
> data.file <- file.path('data', '01_heights_weights_genders.csv')
> heights.weights <- read.csv(data.file, header = TRUE, sep = ',') 
> head(heights.weights)
  Gender   Height   Weight
1   Male 73.84702 241.8936
2   Male 68.78190 162.3105
3   Male 74.11011 212.7409
4   Male 71.73098 220.0425
5   Male 69.88180 206.3498
6   Male 67.25302 152.2122

データの数値による要約

summaryでベクトルの数値を要約します

> summary(heights.weights$Height)
   Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
  54.26   63.51   66.32   66.37   69.17   79.00 

左から、

  • Min .. 最小値
  • 1st Qu .. 第一四分位(データ全体の下から25%の位置にあたる値)
  • Median .. 中央値(データ全体の50%の位置にあたる値)
  • Mean .. 平均値
  • 3rd Qu. .. (データ全体の下から75%の位置にあたる値)
  • Max .. 最大値

が表示されます。

最小値、最大値を求める

min/maxを使って、最小値/最大値を算出できます

# Heightだけを含むベクトルを作成
> heights <- with(heights.weights, Height)
> head(heights)
[1] 73.84702 68.78190 74.11011 71.73098 69.88180 67.25302
> min(heights)
[1] 54.26313
> max(heights)
[1] 78.99874

rangeで、両方をまとめて計算することもできます。

> range(heights)
[1] 54.26313 78.99874

分位数を求める

quantile で、データ中の各位置のデータを出力できます。

> quantile(heights)
      0%      25%      50%      75%     100% 
54.26313 63.50562 66.31807 69.17426 78.99874 

分割幅を指定することもできます。

> quantile(heights, probs = seq(0, 1, by = 0.20))
      0%      20%      40%      60%      80%     100% 
54.26313 62.85901 65.19422 67.43537 69.81162 78.99874 

分散と標準偏差を求める

var,sd を使います。

# 標準偏差
> var(heights)
[1] 14.80347
# 分散
> sd(heights)
[1] 3.847528

データの可視化

必要なライブラリを読み込み。

> library('ggplot2')

ヒストグラム

> plot = ggplot(heights.weights, aes(x = Height)) + geom_histogram(binwidth = 1)
> ggsave(plot = plot, filename = "histgram.png", width = 6, height = 8)

f:id:unageanu:20160110141407p:plain

密度プロットにしてみます。少ないデータ量でも、データセットの形状が分かりやすいのがメリット。

> plot = ggplot(heights.weights, aes(x = Height)) + geom_density()
> ggsave(plot = plot, filename = "kde_histgram.png", width = 6, height = 8)

f:id:unageanu:20160110141408p:plain

性別ごとの特徴をみるため、性別ごとのヒストグラムを表示してみます。

> plot = ggplot(heights.weights, aes(x = Height, fill = Gender)) + geom_density() + facet_grid(Gender ~ .)
> ggsave(plot = plot, filename = "gender_histgram.png", width = 6, height = 8)

f:id:unageanu:20160110141405p:plain

ヒストグラムの分類を整理。詳しくはWikipediaで。

  • 正規分布
    • ピーク(=最頻値)が1つしかない、単峰分布
    • 左右が対称
    • 裾が薄い(データのばらつきが小さい)
  • コーシー分布
    • ピーク(=最頻値)が1つしかない、単峰分布
    • 左右が対称
    • 裾が厚い(データのばらつきが大きい)
  • ガンマ分布
    • 左右が非対称で、平均値と中央値が大きく異なる
  • 指数分布
    • 左右が非対称で、最頻値がゼロ。

正規分布の例。

> set.seed(1)

> normal.values <- rnorm(250, 0, 1)
> plot = ggplot(data.frame(X = normal.values), aes(x = X)) + geom_density()
> ggsave(plot = plot, filename = "normal_histgram.png", width = 6, height = 8)

f:id:unageanu:20160110141409p:plain

コーシー分布。

> cauchy.values <- rcauchy(250, 0, 1)
> plot = ggplot(data.frame(X = cauchy.values), aes(x = X)) + geom_density()
> ggsave(plot = plot, filename = "cauchy_histgram.png", width = 6, height = 8)

f:id:unageanu:20160110141403p:plain

ガンマ分布。

> gamma.values <- rgamma(100000, 1, 0.001)
> plot = ggplot(data.frame(X = gamma.values), aes(x = X)) + geom_density()
> ggsave(plot = plot, filename = "gamma_histgram.png", width = 6, height = 8)

f:id:unageanu:20160110141404p:plain

指数分布、はない。

散布図

身長と体重の散布図を描きます。

> plot = ggplot(heights.weights, aes(x = Height, y = Weight)) + geom_point()
> ggsave(plot = plot, filename = "scatterplots.png", width = 6, height = 8)

f:id:unageanu:20160110141410p:plain

身長、体重には相関関係がありそう。geom_smooth()を使って、妥当な予測領域を表示してみます。

> plot = ggplot(heights.weights, aes(x = Height, y = Weight)) + geom_point() + geom_smooth()
> ggsave(plot = plot, filename = "scatterplots2.png", width = 6, height = 8)

f:id:unageanu:20160110141411p:plain

最後に、男女別の散布図を描いて終わり。

> plot = ggplot(heights.weights, aes(x = Height, y = Weight, color = Gender)) + geom_point()
> ggsave(plot = plot, filename = "gender_scatterplots.png", width = 6, height = 8)

f:id:unageanu:20160110141406p:plain

機械学習手習い : Rをインストールして、基本的な使い方を学ぶ

オライリーの「入門 機械学習」を手に入れたので、手を動かしながら学びます。

www.amazon.co.jp

まずは、1章。Rのインストールと基本的な使い方の学習まで。

Rのインストール

手元にあったCentOS7にインストールしました。

$ cat /etc/redhat-release
CentOS Linux release 7.1.1503 (Core) 
$ sudo yum install epel-release
$ sudo yum install R
$ R

R version 3.2.3 (2015-12-10) -- "Wooden Christmas-Tree"
Copyright (C) 2015 The R Foundation for Statistical Computing
Platform: x86_64-redhat-linux-gnu (64-bit)
...

サンプルコードのダウンロード

次に、GitHubで公開されている「入門 機械学習」のサンプルコードを取得します。

$ cd ~
$ git clone https://github.com/johnmyleswhite/ML_for_Hackers.git ml_for_hackers
$ cd  ml_for_hackers

必要モジュールのインストール

サンプルコードを動かす時に必要なパッケージをインストール。

$ R
> source("package_installer.R")

ユーザー権限で実行すると、デフォルトのインストール先に書き込み権がないとのこと。y を押して、ホームディレクトリにインストール。 そこそこ時間がかかるので待つ・・・。

・・・いくつか、エラーになりました。

1:  install.packages(p, dependencies = TRUE, type = "source"): 
  installation of package ‘rgl’ had non-zero exit status
2:  install.packages(p, dependencies = TRUE, type = "source"): 
  installation of package ‘XML’ had non-zero exit status
3:  install.packages(p, dependencies = TRUE, type = "source"): 
  installation of package ‘RCurl’ had non-zero exit status
4:  install.packages(p, dependencies = TRUE, type = "source"): 
  installation of package ‘Rpoppler’ had non-zero exit status
5:  install.packages(p, dependencies = TRUE, type = "source"): 
  installation of package ‘XML’ had non-zero exit status
6:  install.packages(p, dependencies = TRUE, type = "source"): 
  installation of package ‘RCurl’ had non-zero exit status
7:  install.packages(p, dependencies = TRUE, type = "source"): 
  installation of package ‘XML’ had non-zero exit status

個別にインストールしてみると、依存モジュールがないのが原因のよう。

install.packages("XML", dependencies=TRUE)

Cannot find xml2-config
ERROR: configuration failed for package ‘XML’
* removing ‘/home/yamautim/R/x86_64-redhat-linux-gnu-library/3.2/XML’

必要なモジュールをインストールして、

$ sudo yum -y install libxml2-devel curl-devel poppler-glib-devel freeglut-devel

再試行。

> install.packages("rgl", dependencies=TRUE)
> install.packages("Rpoppler", dependencies=TRUE)
> install.packages("XML", dependencies=TRUE)
> install.packages("RCurl", dependencies=TRUE)

基礎練習で使うライブラリとデータの読み込み

基礎練習で使うライブラリとデータ(UFOの目撃情報データ)を読み込みます。

> setwd("01-Introduction/")
> library("ggplot2")
> library(plyr)
> library(scales)
> ufo <- read.delim("data/ufo/ufo_awesome.tsv", sep="\t", stringsAsFactors=FALSE, header=FALSE, na.strings="")

先頭、末尾の6行を表示。

> head(ufo)                                    
        V1       V2                    V3   V4      V5
1 19951009 19951009         Iowa City, IA <NA>    <NA>
2 19951010 19951011         Milwaukee, WI <NA>  2 min.
3 19950101 19950103           Shelton, WA <NA>    <NA>
...
> tail(ufo)
            V1       V2                   V3        V4                V5
61865 20100828 20100828      Los Angeles, CA      disk        40 seconds
61866 20090424 20100820         Hartwell, GA      oval            10 min
61867 20100821 20100826  Franklin Square, NY  fireball        20 minutes
61868 20100827 20100827         Brighton, CO    circle    at lest 45 min
61869 20100818 20100821  Dryden (Canada), ON     other 5 Min. maybe more
61870 20050502 20100824        Fort Knox, KY  triangle        15 seconds

データをクリーニングする

読み込んだデータを、処理しやすい形に変換します。

データ列に名前をつける

> names(ufo) <- c("DateOccurred", "DateReported", "Location", "ShortDescription",
"Duration", "LongDescription")
> head(ufo)
  DateOccurred DateReported              Location ShortDescription Duration
1     19951009     19951009         Iowa City, IA             <NA>     <NA>
2     19951010     19951011         Milwaukee, WI             <NA>   2 min.
3     19950101     19950103           Shelton, WA             <NA>     <NA>
4     19950510     19950510          Columbia, MO             <NA>   2 min.
5     19950611     19950614           Seattle, WA             <NA>     <NA>
6     19951025     19951024  Brunswick County, ND             <NA>  30 min.

V1 などとなっていたところに、DateOccurredのようなわかりやすい名前が付きました。

DateOccurred,DateReported を日付に変換する

DateOccurred 列は、日付を示す文字列なので、Dateオブジェクトに変換します。

> ufo$DateOccurred <- as.Date(ufo$DateOccurred, format = "%Y%m%d")
 strptime(x, format, tz = "GMT") でエラー:  入力文字列が長すぎます 

エラーになりました。長い文字列を含む列があるよう。探します。

> head(ufo[which(nchar(ufo$DateOccurred)!=8|nchar(ufo$DateReported)!=8),1])
[1] "ler@gnv.ifas.ufl.edu" 
[2] "0000"
[3] "Callers report sighting a number of soft white  balls of lights headingin an easterly directing then changing direction to the west beforespeeding off to the north west."
[4] "0000"
[5] "0000"
[6] "0000"

変なデータがありますな。ということで、8文字でない列を削除します。

# 行のDateOccurred または、DateReported が8文字かどうかを格納するデータ(good.rows)を作成
> good.rows <- ifelse(nchar(ufo$DateOccurred) != 8 |
 nchar(ufo$DateReported) != 8,
 FALSE, TRUE)
# FALSE(=DateOccurred または、DateReported が8文字でない行)の数を確認。
> length(which(!good.rows))
[1] 731
# good.rowsがTRUEの行のみ抽出
> ufo <- ufo[good.rows, ]    

変換を再試行。今度は、日付型に変換できました。

> ufo$DateOccurred <- as.Date(ufo$DateOccurred, format = "%Y%m%d")
> ufo$DateReported <- as.Date(ufo$DateReported, format = "%Y%m%d")
> head(ufo)
  DateOccurred DateReported              Location ShortDescription Duration
1   1995-10-09   1995-10-09         Iowa City, IA             <NA>     <NA>
2   1995-10-10   1995-10-11         Milwaukee, WI             <NA>   2 min.
3   1995-01-01   1995-01-03           Shelton, WA             <NA>     <NA>

関数を作ってLocation列のデータを都市名と州名に分割する

Location列のデータは、Iowa City, IA のような「都市名, 州名」形式の文字列になっています。 関数を使ってこれを都市名, 州名に分解します。

まずは、分解を行う関数を作成。

> get.location <- function(l) {
  split.location <- tryCatch(strsplit(l, ",")[[1]], error = function(e) return(c(NA, NA)))
  clean.location <- gsub("^ ","",split.location)
  if (length(clean.location) > 2) {
    return(c(NA,NA))
  } else {
    return(clean.location)
  }
}

lapplyを使って、関数を適用したデータのリストを作成。

> city.state <- lapply(ufo$Location, get.location)
> head(city.state)
[[1]]
[1] "Iowa City" "IA"       

[[2]]
[1] "Milwaukee" "WI"   

これを、ufoに追加します。まずは、do.callでリストを行列に変換。

> location.matrix <- do.call(rbind, city.state)
> head(location.matrix)
     [,1]               [,2]
[1,] "Iowa City"        "IA"
[2,] "Milwaukee"        "WI"
[3,] "Shelton"          "WA"
[4,] "Columbia"         "MO"
[5,] "Seattle"          "WA"
[6,] "Brunswick County" "ND"

transformufoに追加します。

> ufo <- transform(ufo, USCity = location.matrix[, 1], USState = location.matrix[, 2], stringsAsFactors = FALSE)

また、データには、カナダのものが含まれているので、これも除去します。 USStateにアメリカの州名以外が入っているデータを削除。

> ufo$USState <- state.abb[match(ufo$USState, state.abb)]
> ufo.us <- subset(ufo, !is.na(USState))

これで、データのクリーニングは完了。

> summary(ufo.us)                                                                                                                                                        
  DateOccurred         DateReported          Location        
 Min.   :1400-06-30   Min.   :1905-06-23   Length:51636      
 1st Qu.:1999-09-07   1st Qu.:2002-04-14   Class :character  
 Median :2004-01-10   Median :2005-03-27   Mode  :character  
 Mean   :2001-02-13   Mean   :2004-11-30                     
 3rd Qu.:2007-07-27   3rd Qu.:2008-01-20                     
 Max.   :2010-08-30   Max.   :2010-08-30                     
 ShortDescription     Duration         LongDescription       USCity         
 Length:51636       Length:51636       Length:51636       Length:51636      
 Class :character   Class :character   Class :character   Class :character  
 Mode  :character   Mode  :character   Mode  :character   Mode  :character  
                                                                            
                                                                            
                                                                            
   USState         
 Length:51636      
 Class :character  
 Mode  :character  
                   
                   
                   
> head(ufo.us)
  DateOccurred DateReported              Location ShortDescription Duration
1   1995-10-09   1995-10-09         Iowa City, IA             <NA>     <NA>
2   1995-10-10   1995-10-11         Milwaukee, WI             <NA>   2 min.
3   1995-01-01   1995-01-03           Shelton, WA             <NA>     <NA>
4   1995-05-10   1995-05-10          Columbia, MO             <NA>   2 min.
5   1995-06-11   1995-06-14           Seattle, WA             <NA>     <NA>
6   1995-10-25   1995-10-24  Brunswick County, ND             <NA>  30 min.

データを分析する

クリーニングしたデータを分析して、州/月ごとの目撃情報の傾向を分析します。

DateOccurredのばらつきを調べる

> summary(ufo.us$DateOccurred) 
        Min.      1st Qu.       Median         Mean      3rd Qu.         Max.         NA's 
"1400-06-30" "1999-09-07" "2004-01-10" "2001-02-13" "2007-07-27" "2010-08-30"          "1" 

1400年のデータが含まれている・・・。ヒストグラムで、分布をみてみます。

> quick.hist <- ggplot(ufo.us, aes(x = DateOccurred)) +
  geom_histogram() + scale_x_date(date_breaks = "50 years", date_labels = "%Y")
> ggsave(plot = quick.hist,
 filename = file.path("images", "quick_hist.png"),
 height = 6,
 width = 8)

f:id:unageanu:20160109175440p:plain

大部分は最近の20年に集中している模様。この範囲に絞って分析するため、古いデータを取り除きます。

> ufo.us <- subset(ufo.us, DateOccurred >= as.Date("1990-01-01"))

データを州/月ごとに集計する

まずは、DateOccurred 列を「年-月」に変換した列を作成。

> ufo.us$YearMonth <- strftime(ufo.us$DateOccurred, format = "%Y-%m")

次に、月/州ごとにデータをグループ化して、データ数を集計します。

> sightings.counts <- ddply(ufo.us, .(USState,YearMonth), nrow)
> head(sightings.counts)
  USState YearMonth V1
1      AK   1990-01  1
2      AK   1990-03  1
3      AK   1990-05  1
4      AK   1993-11  1
5      AK   1994-11  1
6      AK   1995-01  1

これだけだと、データが一つもない月が集計結果に含まれないので、それを補完します。 seq.Dateを使ってシーケンスを作成。

> date.range <- seq.Date(from = as.Date(min(ufo.us$DateOccurred)),
                         to = as.Date(max(ufo.us$DateOccurred)),
                         by = "month")
> date.strings <- strftime(date.range, "%Y-%m")

作成したシーケンスに、州の一覧を掛け合わせて、州/月の行列を作成します。

> states.dates <- lapply(state.abb, function(s) cbind(s, date.strings))
> states.dates <- data.frame(do.call(rbind, states.dates),
 stringsAsFactors = FALSE)
> head(states.dates)
   s date.strings
1 AL      1990-01
2 AL      1990-02
3 AL      1990-03
4 AL      1990-04
5 AL      1990-05
6 AL      1990-06

さらに、sightings.counts をマージして、月の欠落がない集計データを作成。

> all.sightings <- merge(states.dates,
 sightings.counts,
 by.x = c("s", "date.strings"),
 by.y = c("USState", "YearMonth"),
 all = TRUE)
> head(all.sightings)                                                                                                                                 
   s date.strings V1
1 AK      1990-01  1
2 AK      1990-02 NA
3 AK      1990-03  1
4 AK      1990-04 NA
5 AK      1990-05  1
6 AK      1990-06 NA

わかりやすいよう、列に名前を付けます。また、集計しやすいようにNA0に変換するなどの操作を行っておきます。

> names(all.sightings) <- c("State", "YearMonth", "Sightings")
> all.sightings$Sightings[is.na(all.sightings$Sightings)] <- 0
> all.sightings$YearMonth <- as.Date(rep(date.range, length(state.abb)))
> all.sightings$State <- as.factor(all.sightings$State)
> head(all.sightings)
  State  YearMonth Sightings
1    AK 1990-01-01         1
2    AK 1990-02-01         0
3    AK 1990-03-01         1
4    AK 1990-04-01         0
5    AK 1990-05-01         1
6    AK 1990-06-01         0

分析用データはこれで完成。

月/州ごとの目撃情報数をグラフにして分析する

> state.plot <- ggplot(all.sightings, aes(x = YearMonth,y = Sightings)) +
  geom_line(aes(color = "darkblue")) +
  facet_wrap(~State, nrow = 10, ncol = 5) + 
  theme_bw() + 
  scale_color_manual(values = c("darkblue" = "darkblue"), guide = "none") +
  scale_x_date(date_breaks = "5 years", date_labels = '%Y') +
  xlab("Years") +
  ylab("Number of Sightings") +
  ggtitle("Number of UFO sightings by Month-Year and U.S. State (1990-2010)")
> ggsave(plot = state.plot,
       filename = file.path("images", "ufo_sightings.png"),
       width = 14,
       height = 8.5)

こんなグラフになります。

f:id:unageanu:20160109175441p:plain

分析は省略。

トラップリピートイフダンのような注文を発行するエージェントのサンプル

FXシステムトレードフレームワーク「Jiji」のサンプル その2、として、 トラップリピートイフダンのような注文を発行するエージェントを作ってみました。

※トラップリピートイフダン(トラリピ)は、マネースクウェアジャパン(M2J)の登録商標です。

トラップリピートイフダンとは

指値/逆指値の注文と決済を複数組み合わせて行い、その中でレートが上下することで利益を出すことを狙う、発注ロジックです。 具体的にどういった動きをするのかは、マネースクウェアジャパン のサイトがとてもわかりやすいので、そちらをご覧ください。

www.toraripifx.com

特徴

FX研究日記さんの評価記事が参考になります。

tasfx.net

  • レンジ相場では、利益を出しやすい
  • ×レートが逆行すると損失を貯めこんでしまう

仕組みからして、いわゆるコツコツドカンなシステムという印象です。 レンジ相場なら利益を積み上げやすいので、トレンドを判定するロジックと組み合わせて、レートが一定のレンジで動作しそうになったら稼働させる、などすれば使えるかも。

エージェントのコード

  • 実装は、こちらのサイトで配布されているEAを参考にさせていただきました。
  • TrapRepeatIfDoneAgentが、エージェントの本体です。これをバックテストやリアルトレードで動作させればOK。
  • 機能の再利用ができるように、発注処理はTrapRepeatIfDoneに実装しています。
# === トラップリピートイフダンのような注文を発行するエージェント
class TrapRepeatIfDoneAgent

  include Jiji::Model::Agents::Agent

  def self.description
    <<-STR
トラップリピートイフダンのような注文を発行するエージェント
      STR
  end

  # UIから設定可能なプロパティの一覧
  def self.property_infos
    [
      Property.new('trap_interval_pips', 'トラップを仕掛ける間隔(pips)', 50),
      Property.new('trade_units',        '1注文あたりの取引数量',         1),
      Property.new('profit_pips',        '利益を確定するpips',         100),
      Property.new('slippage',           '許容スリッページ(pips)',       3)
    ]
  end

  def post_create
    @trap_repeat_if_done = TrapRepeatIfDone.new(
      broker.pairs.find {|p| p.name == :USDJPY }, :buy,
      @trap_interval_pips.to_i,
      @trade_units.to_i, @profit_pips.to_i, @slippage.to_i, logger)
  end

  def next_tick(tick)
    @trap_repeat_if_done.register_orders(broker)
  end

  def state
    @trap_repeat_if_done.state
  end

  def restore_state(state)
    @trap_repeat_if_done.restore_state(state)
  end

end


# トラップリピートイフダンのような注文を発行するクラス
class TrapRepeatIfDone

  # コンストラクタ
  #
  # target_pair:: 現在の価格を格納するTick::Valueオブジェクト
  # sell_or_buy:: 取引モード。 :buy の場合、買い注文を発行する。 :sellの場合、売
  # trap_interval_pips:: トラップを仕掛ける間隔(pips)
  # trade_units:: 1注文あたりの取引数量
  # profit_pips:: 利益を確定するpips
  # slippage:: 許容スリッページ。nilの場合、指定しない
  def initialize(target_pair, sell_or_buy=:buy, trap_interval_pips=50,
    trade_units=1, profit_pips=100, slippage=3, logger=nil)

    @target_pair        = target_pair
    @trap_interval_pips = trap_interval_pips
    @slippage           = slippage

    @mode = if sell_or_buy == :sell
      Sell.new(target_pair, trade_units, profit_pips, slippage, logger)
    else
      Buy.new(target_pair, trade_units, profit_pips, slippage, logger)
    end

    @logger = logger

    @registerd_orders   = {}
  end

  # 注文を登録する
  #
  # broker:: broker
  def register_orders(broker)
    broker.instance_variable_get(:@broker).refresh_positions
    # 常に最新の建玉を取得して利用するようにする
    # TODO 公開APIにする

    each_traps(broker.tick) do |trap_open_price|
      next if order_or_position_exists?(trap_open_price, broker)
      register_order(trap_open_price, broker)
    end
  end

  def state
    @registerd_orders
  end

  def restore_state(state)
    @registerd_orders = state unless state.nil?
  end

  private

  def each_traps(tick)
    current_price = @mode.resolve_current_price(tick[@target_pair.name])
    base = resolve_base_price(current_price)
    6.times do |n| # baseを基準に、上下3つのトラップを仕掛ける
      trap_open_price = BigDecimal.new(base, 10) \
        + BigDecimal.new(@trap_interval_pips, 10) * (n-3) * @target_pair.pip
      yield trap_open_price
    end
  end

  # 現在価格をtrap_interval_pipsで丸めた価格を返す。
  #
  #  例) trap_interval_pipsが50の場合、
  #  resolve_base_price(120.10) # -> 120.00
  #  resolve_base_price(120.49) # -> 120.00
  #  resolve_base_price(120.51) # -> 120.50
  #
  def resolve_base_price(current_price)
    current_price = BigDecimal.new(current_price, 10)
    pip_precision = 1 / @target_pair.pip
    (current_price * pip_precision / @trap_interval_pips ).ceil \
      * @trap_interval_pips / pip_precision
  end

  # trap_open_priceに対応するオーダーを登録する
  def register_order(trap_open_price, broker)
    result = @mode.register_order(trap_open_price, broker)
    unless result.order_opened.nil?
      @registerd_orders[key_for(trap_open_price)] \
        = result.order_opened.internal_id
    end
  end

  # trap_open_priceに対応するオーダーを登録済みか評価する
  def order_or_position_exists?(trap_open_price, broker)
    order_exists?(trap_open_price, broker) \
    || position_exists?(trap_open_price, broker)
  end
  def order_exists?(trap_open_price, broker)
    key = key_for(trap_open_price)
    return false unless @registerd_orders.include? key
    id = @registerd_orders[key]
    order = broker.orders.find {|o| o.internal_id == id }
    return !order.nil?
  end
  def position_exists?(trap_open_price, broker)

    # trapのリミット付近でレートが上下して注文が大量に発注されないよう、
    # trapのリミット付近を開始値とする建玉が存在する間は、trapの注文を発行しない
    slipage_price = (@slippage.nil? ? 10 : @slippage) * @target_pair.pip
    position = broker.positions.find do |p|
      # 注文時に指定したpriceちょうどで約定しない場合を考慮して、
      # 指定したslippage(指定なしの場合は10pips)の誤差を考慮して存在判定をする
      p.entry_price < trap_open_price + slipage_price \
      && p.entry_price > trap_open_price - slipage_price
    end
    return !position.nil?
  end

  def key_for(trap_open_price)
    (trap_open_price * (1 / @target_pair.pip)).to_i.to_s
  end

  # 取引モード(売 or 買)
  # 買(Buy)の場合、買でオーダーを行う。売(Sell)の場合、売でオーダーを行う。
  class Mode

    def initialize(target_pair, trade_units, profit_pips, slippage, logger)
      @target_pair  = target_pair
      @trade_units  = trade_units
      @profit_pips  = profit_pips
      @slippage     = slippage
      @logger       = logger
    end

    # 現在価格を取得する(買の場合Askレート、売の場合Bidレートを使う)
    #
    # tick_value:: 現在の価格を格納するTick::Valueオブジェクト
    # 戻り値:: 現在価格
    def resolve_current_price(tick_value)
    end

    # 注文を登録する
    def register_order(trap_open_price, broker)
    end

    def calculate_price(price, pips)
      price = BigDecimal.new(price, 10)
      pips  = BigDecimal.new(pips,  10) * @target_pair.pip
      (price + pips).to_f
    end
    def pring_order_log(mode, options, timestamp)
      return unless @logger
      message = [
        mode, timestamp, options[:price], options[:take_profit],
        options[:lower_bound], options[:upper_bound]
      ].map {|item| item.to_s }.join(" ")
      @logger.info message
    end
  end

  class Sell < Mode
    def resolve_current_price(tick_value)
      tick_value.bid
    end
    def register_order(trap_open_price, broker)
      timestamp = broker.tick.timestamp
      options = create_option(trap_open_price, timestamp)
      pring_order_log("sell", options, timestamp)
      broker.sell(@target_pair.name, @trade_units, :marketIfTouched, options)
    end
    def create_option(trap_open_price, timestamp)
      options = {
        price:       trap_open_price.to_f,
        take_profit: calculate_price(trap_open_price, @profit_pips*-1),
        expiry:      timestamp + 60*60*24*7
      }
      unless @slippage.nil?
        options[:lower_bound] = calculate_price(trap_open_price, @slippage*-1)
        options[:upper_bound] = calculate_price(trap_open_price, @slippage)
      end
      options
    end
  end

  class Buy < Mode
    def resolve_current_price(tick_value)
      tick_value.ask
    end
    def register_order(trap_open_price, broker)
      timestamp = broker.tick.timestamp
      options = create_option(trap_open_price, timestamp)
      pring_order_log("buy", options, timestamp)
      broker.buy(@target_pair.name, @trade_units, :marketIfTouched, options)
    end
    def create_option(trap_open_price, timestamp)
      options = {
        price:       trap_open_price.to_f,
        take_profit: calculate_price(trap_open_price, @profit_pips),
        expiry:      timestamp + 60*60*24*7
      }
      unless @slippage.nil?
        options[:lower_bound] = calculate_price(trap_open_price, @slippage*-1)
        options[:upper_bound] = calculate_price(trap_open_price, @slippage)
      end
      options
    end
  end

end

インタラクティブにトレーリングストップ決済を行うBotを作ってみた

FXシステムトレードフレームワーク「Jiji」の使い方サンプル その1、ということで、 Jijiを使って、インタラクティブにトレーリングストップ決済を行うBotを作ってみました。

トレーリングストップとは

建玉(ポジション)の決済方法の一つで、「最高値を更新するごとに、逆指値の決済価格を切り上げていく」決済ロジックです。

例) USDJPY/120.10で買建玉を作成。これを、10 pips でトレーリングストップする場合、

f:id:unageanu:20151228130735p:plain

  • 建玉作成直後は、120.00 で逆指値決済される状態になる
  • レートが 120.30 になった場合、逆指値の決済価格が高値に合わせて上昇し、120.20に切り上がる
  • その後、レートが120.20 になると、逆指値で決済される

トレンドに乗っている間はそのまま利益を増やし、トレンドが変わって下げ始めたら決済する、という動きをする決済ロジックですね。

インタラクティブにしてみる

単純なトレーリングストップだけなら証券会社が提供している機能で実現できるので、少し手を加えてインタラクティブにしてみました。

トレーリングストップでは、以下のようなパターンがありがち。

  • すこし大きなドローダウンがきて、トレンド変わってないのに決済されてしまい、利益を逃した・・
  • レートが急落した時に、決済が遅れて損失が広がった・・・

これを回避できるように、Botでの強制決済に加えて、人が状況をみて決済するかどうか判断できる仕組みをいれてみます。

仕様

以下のような動作をします。

f:id:unageanu:20151228130736p:plain

  • トレーリングストップの閾値を2段階で指定できるようにして、1つ目の閾値を超えたタイミングでは警告通知を送信。

    • 通知を確認して、即時決済するか、保留するか判断できる。
    • 決済をスムーズに行えるよう、通知から1タップで決済を実行できるようにする。 f:id:unageanu:20151228105949p:plain
  • 2つ目の閾値を超えた場合、Bot建玉を決済。

    • 夜間など通知を受けとっても対処できない場合を考慮して、2つ目の閾値を超えたら、強制決済するようにしておきます。
    • なお、決済時にはOANDA JAPANから通知が送信されるので、Jijiからの通知は省略しました。

Bot(エージェント)のコード

  • TrailingStopAgentが、Botの本体。これをバックテストやリアルトレードで動作させればOKです。
  • TrailingStopAgent自体は、新規に建玉を作ることはしません。
    • 裁量トレードや他のエージェントが作成した建玉を自動で監視し、トレーリングストップを行います。
    • バックテストで試す場合は、建玉を作成するエージェントと一緒に動作させてください。
  • 機能の再利用ができるように、処理はTrailingStopManagerに実装しています。
# トレーリングストップで建玉を決済するエージェント
class TrailingStopAgent

  include Jiji::Model::Agents::Agent

  def self.description
    <<-STR
トレーリングストップで建玉を決済するエージェント。
 - 損益が警告を送る閾値を下回ったら、1度だけ警告をPush通知で送信。
 - さらに決済する閾値も下回ったら、建玉を決済します。
      STR
  end

  # UIから設定可能なプロパティの一覧
  def self.property_infos
    [
      Property.new('warning_limit', '警告を送る閾値', 20),
      Property.new('closing_limit', '決済する閾値',   40)
    ]
  end

  def post_create
    @manager = TrailingStopManager.new(
      @warning_limit.to_i, @closing_limit.to_i, notifier)
  end

  def next_tick(tick)
    @manager.check(broker.positions, broker.pairs)
  end

  def execute_action(action)
    @manager.process_action(action, broker.positions) || '???'
  end

  def state
    {
      trailing_stop_manager: @manager.state
    }
  end

  def restore_state(state)
    if state[:trailing_stop_manager]
      @manager.restore_state(state[:trailing_stop_manager])
    end
  end

end

# 建玉を監視し、最新のレートに基づいてトレールストップを行う
class TrailingStopManager

  # コンストラクタ
  #
  # warning_limit:: 警告を送信する閾値(pip)
  # closing_limit:: 決済を行う閾値(pip)
  # notifier:: notifier
  def initialize(warning_limit, closing_limit, notifier)
    @warning_limit = warning_limit
    @closing_limit = closing_limit
    @notifier  = notifier

    @states = {}
  end

  # 建玉がトレールストップの閾値に達していないかチェックする。
  # warning_limit を超えている場合、警告通知を送信、
  # closing_limit を超えた場合、強制的に決済する。
  #
  # positions:: 建て玉一覧(broker#positions)
  # pairs:: 通貨ペア一覧(broker#pairs)
  def check(positions, pairs)
    @states = positions.each_with_object({}) do |position, r|
      r[position.id.to_s] = check_position(position, pairs)
    end
  end

  # アクションを処理する
  #
  # action:: アクション
  # positions:: 建て玉一覧(broker#positions)
  # 戻り値:: アクションを処理できた場合、レスポンスメッセージ。
  #         TrailingStopManagerが管轄するアクションでない場合、nil
  def process_action(action, positions)
    return nil unless action =~ /trailing\_stop\_\_([a-z]+)_(.*)$/
    case $1
    when "close" then
        position = positions.find {|p| p.id.to_s == $2 }
        return nil unless position
        position.close
        return "建玉を決済しました。"
    end
  end

  # 永続化する状態。
  def state
    @states.each_with_object({}) {|s, r| r[s[0]] = s[1].state }
  end

  # 永続化された状態から、インスタンスを復元する
  def restore_state(state)
    @states = state.each_with_object({}) do |s, r|
      state = PositionState.new( nil,
        @warning_limit, @closing_limit )
      state.restore_state(s[1])
      r[s[0]] = state
    end
  end

  private

  # 建玉の状態を更新し、閾値を超えていたら対応するアクションを実行する。
  def check_position(position, pairs)
    state = get_and_update_state(position, pairs)
    if state.under_closing_limit?
      position.close
    elsif state.under_warning_limit?
      unless state.sent_warning # 通知は1度だけ送信する
        send_notification(position, state)
        state.sent_warning = true
      end
    end
    return state
  end

  def get_and_update_state(position, pairs)
    state = create_or_get_state(position, pairs)
    state.update(position)
    state
  end

  def create_or_get_state(position, pairs)
    key = position.id.to_s
    return @states[key] if @states.include? key
    PositionState.new(
      retrieve_pip_for(position.pair_name, pairs),
      @warning_limit, @closing_limit )
  end

  def retrieve_pip_for(pair_name, pairs)
    pairs.find {|p| p.name == pair_name }.pip
  end

  def send_notification(position, state)
    message = "#{create_position_description(position)}" \
      + " がトレールストップの閾値を下回りました。決済しますか?"
    @notifier.push_notification(message,  [{
        'label'  => '決済する',
        'action' => 'trailing_stop__close_' + position.id.to_s
    }])
  end

  def create_position_description(position)
    sell_or_buy = position.sell_or_buy == :sell ? "" : ""
    "#{position.pair_name}/#{position.entry_price}/#{sell_or_buy}"
  end

end

class PositionState

  attr_reader :max_profit, :profit_or_loss, :max_profit_time, :last_update_time
  attr_accessor :sent_warning

  def initialize(pip, warning_limit, closing_limit)
    @pip           = pip
    @warning_limit = warning_limit
    @closing_limit = closing_limit
    @sent_warning  = false
  end

  def update(position)
    @units            = position.units
    @profit_or_loss   = position.profit_or_loss
    @last_update_time = position.updated_at

    if @max_profit.nil? || position.profit_or_loss > @max_profit
      @max_profit      = position.profit_or_loss
      @max_profit_time = position.updated_at
      @sent_warning    = false
      # 高値を更新したあと、 warning_limit を超えたら再度警告を送るようにする
    end
  end

  def under_warning_limit?
    return false if @max_profit.nil?
    difference >= @warning_limit * @units * @pip
  end

  def under_closing_limit?
    return false if @max_profit.nil?
    difference >= @closing_limit * @units * @pip
  end

  def state
    {
      "max_profit"      => @max_profit,
      "max_profit_time" => @max_profit_time,
      "pip"             => @pip,
      "sent_warning"    => @sent_warning
    }
  end

  def restore_state(state)
    @max_profit      = state["max_profit"]
    @max_profit_time = state["max_profit_time"]
    @pip             = state["pip"]
    @sent_warning    = state["sent_warning"]
  end

  private

  def difference
    @max_profit - @profit_or_loss
  end

end

それでは、みなさま、良いお年を。

MongoDBのinsert/updateをまとめて、bulk insert/update に流すユーティリティを書いた

バッチ処理などでMongoDBに大量のinsert/updateを行うとき、Mongoidを使って1つずつ #save してると遅い。 ということで、複数#save をまとめて bulk insert/update に流すユーティリティを書いてみました。

使い方

  • モデルクラスで、Mongoid::DocumentUtils::BulkWriteOperationSupportinclude する。
  • Utils::BulkWriteOperationSupport.begin_transaction を呼び出してから、モデルの #save を呼び出す。
    • この時点ではMongoDBへのinsert/updateは行われず、バッファに蓄積されます。
  • Utils::BulkWriteOperationSupport.end_transaction を呼び出すと、バッファのデータが #bulk_write でまとめて永続化される。
class TestModel
  
  # Mongoid::Document と Utils::BulkWriteOperationSupport をincludeする
  include Mongoid::Document
  include Utils::BulkWriteOperationSupport

  store_in collection: 'test_model'

  field :name,      type: String

end

#略

puts TestModel.count # => 0

Utils::BulkWriteOperationSupport.begin_transaction

# #begin_transaction を呼び出したあと、モデルを作成/変更して、 #save を呼び出す。
a = TestModel.new
a.name = 'a'
b = TestModel.new
b.name = 'b'
a.save
b.save

# #end_transaction を実行するまで、永続化されない
puts TestModel.count # => 0

# #end_transaction を呼び出すと、バッファのデータが #bulk_write でまとめて永続化される。
Utils::BulkWriteOperationSupport.end_transaction

puts TestModel.count # => 2

ユーティリティのコード

  • Document#save を書き換えて、#begin_transaction#end_transaction の間であれば、スレッドローカルに永続化対象としてマーク。
  • #end_transaction が呼び出されたタイミンクで、まとめて #bulk_write で永続化します。
  • 参照整合性のチェックとか、いろいろ手抜きなので必要に応じて改造してください。
module BulkWriteOperationSupport

  KEY = BulkWriteOperationSupport.name

  def save
    if BulkWriteOperationSupport.in_transaction?
      BulkWriteOperationSupport.transaction << self
    else
      super
    end
  end

  def self.in_transaction?
    !transaction.nil?
  end

  def self.begin_transaction
    Thread.current[KEY] = Transaction.new
  end

  def self.end_transaction
    return unless in_transaction?
    transaction.execute
    Thread.current[KEY] = nil
  end

  def self.transaction
    Thread.current[KEY]
  end

  def create_insert_operation
    { :insert_one => as_document }
  end

  def create_update_operation
    {
      :update_one => {
        :filter => { :_id => id },
        :update => {'$set' => collect_changed_values }
      }
    }
  end

  private

  def collect_changed_values
    changes.each_with_object({}) do |change, r|
      r[change[0].to_sym] = change[1][1]
    end
  end

  class Transaction

    def initialize
      @targets = {}
    end

    def <<(model)
      targets_of( model.class )[model.object_id] = model
    end

    def execute
      until @targets.empty?
        model_class = @targets.keys.first
        execute_bulk_write_operations(model_class)
      end
    end

    def size
      @targets.values.reduce(0) {|a, e| a + e.length }
    end

    private

    def targets_of( model_class )
      @targets[model_class] ||= {}
    end

    def execute_bulk_write_operations(model_class)
      return unless @targets.include?(model_class)
      execute_parent_object_bulk_write_operations_if_exists(model_class)

      client = model_class.mongo_client[model_class.collection_name]
      operations = create_operations(@targets[model_class].values)
      client.bulk_write(operations) unless operations.empty?

      @targets.delete model_class
    end

    def execute_parent_object_bulk_write_operations_if_exists(model_class)
      parents = model_class.reflect_on_all_associations(:belongs_to)
      parents.each do |m|
        klass = m.klass
        execute_bulk_write_operations(klass)
      end
    end

    def create_operations(targets)
      targets.each_with_object([]) do |model, array|
        if model.new_record?
          model.new_record = false
          array << model.create_insert_operation
        else
          array << model.create_update_operation if model.changed?
        end
      end
    end

  end

end

nukeproof/oanda_api のコネクションリーク問題とその対策

OANDA fx Trade APIRubyクライアント「nukeproof/oanda_api」には、TCPコネクションリークの問題があり、長時間連続で利用しているとファイルディスクリプタが枯渇します。

内部で利用している persistent_http の古いバージョンにある不具合が原因(最新の2.0.1では改修済み)のため、Gemfileなどで最新バージョンを使うようにすると回避できます。

gem 'persistent_http', '2.0.1'

問題の詳細

Jijiを10日程度連続稼働させていて発覚。突然、以下のエラーが発生するようになりました。

E, [2015-12-09T01:23:06.337582 #7932] ERROR -- : Too many open files - getaddrinfo (Errno::EMFILE)
/home/yamautim/.rbenv/versions/2.2.3/lib/ruby/2.2.0/net/http.rb:879:in `initialize'
/home/yamautim/.rbenv/versions/2.2.3/lib/ruby/2.2.0/net/http.rb:879:in `open'
/home/yamautim/.rbenv/versions/2.2.3/lib/ruby/2.2.0/net/http.rb:879:in `block in connect'

lsof コマンドの出力行数も少しずつ増えていきます。

$ lsof -p <Jijiのpid> | wc -l 
77
$ lsof -p <Jijiのpid> | wc -l 
79

起動直後と、しばらく起動した後で、lsofの出力結果のDiffをとると、CLOSE_WAITのhttpsコネクションが増加していました。

ruby    4389 xxxx   23u  IPv4 9265756       0t0       TCP localhost.localdomain:35773->unknown.xxxx.net:https (CLOSE_WAIT)

外向きのHTTPS通信なので、OANDA へのアクセスっぽい。

原因

最初にも書きましたが、persistent_http の古いバージョンにある不具合が原因です。最新の 2.0.1 を使えば改修されます。

  • 1.0.6では、内部で利用しているGenePoolに渡すオプションがnilになっており、コネクションの破棄が正しく行われない状態になっています。
  • このコミットで改修されていて、最新の 2.0.1 を使えばOK。
  • ただし、oanda_api が直接依存している persistent_httparty で バージョン2以下を使うよう明示されているため、普通に使うと 1.0.6 が使われてしまいます。
  • このため、Gemfileなどで最新バージョンを使うように明示する等の対応が必要です。
  • ちなみに、persistent_httparty に、依存するpersistent_httpのバージョンを上げるPull Requestはあるのですが、マージされていないようです・・・。