{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Introduction to PyCall\n", "\n", "First, a brief introduction to PyCall.\n", "\n", "PyCall is, in essence, \"**an extension library for using libpython.so from Ruby**\".\n", "PyCall uses the functionality of libpython.so to provide a bridge that lets Ruby manipulate Python objects.\n", "With PyCall you can, for example, bring Python's `sin` function over to the Ruby side and call it, as shown below." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "require 'pycall'\n", "\n", "# PyCall.import_module loads a Python module and brings the loaded module object into Ruby\n", "pymath = PyCall.import_module('math')" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Access the `sin` attribute of the `math` module\n", "pymath.sin" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "1.2246467991473532e-16" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Call the function object with `.()`, syntactic sugar for the `.call` method\n", "pymath.sin.(Math::PI)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Python objects brought over to the Ruby side are all wrapped in instances of the PyObject class, except for objects of a few basic classes." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "PyCall::PyObject" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# pymath is a module object in Python, but in Ruby it is wrapped by an instance of PyObject\n", "pymath.class" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "PyCall::PyObject" ] }, "execution_count": 5, "metadata": {}, "output_type": 
"execute_result" } ], "source": [ "# pymath.sin is a builtin-function object in Python, but in Ruby it is wrapped by an instance of PyObject\n", "pymath.sin.class" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "Float" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# The result of pymath.sin is a float object in Python, but it is automatically converted to a Float object in Ruby\n", "pymath.sin.(Math::PI).class" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "\"sin\"" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# The name of a function object\n", "pymath.sin.__name__" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "String" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# It is converted to a String object in Ruby\n", "pymath.sin.__name__.class" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `pycall/import` feature lets you import modules with a notation similar to Python's `import math`.\n", "\n", "Let's try it." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ ":math" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "require 'pycall/import'\n", "include PyCall::Import\n", "\n", "pyimport :math" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "math" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ 
"math.sin" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1.2246467991473532e-16" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "math.sin.(Math::PI)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "PyCall provides a module called PyObjectWrapper. With this module you can define a wrapper class that corresponds to a Python class. Once a wrapper class is defined, calls to instance methods and class methods can be written naturally.\n", "\n", "Let's see the difference using numpy as an example.\n", "\n", "First, we use numpy without defining a wrapper class." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ ":np" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pyimport :numpy, as: :np" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# `np.array` retrieves a function object\n", "np.array" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([ 0.78720422, 0.04989502, 0.61287186, 0.58935543, 0.10911719,\n", " 0.90534304, 0.17907297, 0.56424541, 0.7532547 , 0.32246486,\n", " 0.65825915, 0.70162822, 0.90634646, 0.19754435, 0.29295796,\n", " 0.61622797, 0.14356238, 0.64854779, 0.18438799, 0.93683436])" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Use the `.call` method to call `np.array`\n", "ary = np.array.([*1..20].map { rand })" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "PyCall::PyObject" ] }, "execution_count": 17, "metadata": {}, "output_type": 
"execute_result" } ], "source": [ "# This is a PyObject\n", "ary.class" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# `ary.mean` retrieves a function object\n", "ary.mean" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.5079560660124105" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Use the `.call` method to call `ary.mean`\n", "ary.mean.()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we define a wrapper class." ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[#, rbtype=Numpy::NDArray>]" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "module Numpy\n", "  class NDArray\n", "    include PyCall::PyObjectWrapper\n", "    wrap_class PyCall.import_module('numpy').ndarray\n", "  end\n", "end" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The Numpy::NDArray class is now a wrapper for np.ndarray.\n", "\n", "Let's create an ndarray object once more." ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([ 10.43929293, 10.774188 , 10.69860716, 10.349782 ,\n", " 10.40455471, 10.31237692, 10.42337371, 10.85054457,\n", " 10.797475 , 10.29543447, 10.97507191, 10.77216906,\n", " 10.80617675, 10.90382459, 10.80300021, 10.61905881,\n", " 10.67108304, 10.77988371, 10.91601262, 10.6303129 ])" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ary2 = np.array.([*1..20].map { rand + 10 })" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Numpy::NDArray" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# ary2 is a 
Numpy::NDArray!!\n", "ary2.class" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "10.661111153255195" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# ary2.mean now calls the mean method directly!!\n", "ary2.mean" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this way, we used PyObjectWrapper to define a wrapper for a Python class on the Ruby side.\n", "The matplotlib wrapper library uses this feature to define wrappers for classes such as Figure and Axes.\n", "\n", "Unfortunately, [there is an issue](https://github.com/mrkn/pycall/issues/16) where applying wrap_class to the pandas DataFrame class raises an error.\n", "For that reason, this tutorial uses pandas without defining wrappers for it.\n", "\n", "A feature for defining wrappers for modules has not been implemented yet, but it is planned to be available soon." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Data Analysis with Ruby using PyCall\n", "\n", "Now let's do some data analysis in Ruby using PyCall.\n", "\n", "## Preparation\n", "\n", "Before starting the analysis, we need to set up a few things.\n", "\n", "We use the seaborn library for data visualization. Because seaborn is built on matplotlib, we enable the integration between IRuby and matplotlib." ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[:inline, \"module://ruby.matplotlib.backend_inline\"]" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "require 'matplotlib/iruby'\n", "Matplotlib::IRuby.activate" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's import the libraries we will use." ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "text/plain": [ ":sns" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pyimport :pandas, as: :pd\n", "pyimport :seaborn, as: :sns" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we set things up so that pandas data frames are displayed legibly in an IRuby notebook.\n", "In the future, this is planned to happen automatically via something like `require 'pandas/iruby'`." ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "collapsed": true }, "outputs": [], 
"source": [ "module Pandas\n", "  class DataFrame < PyCall::PyObject\n", "  end\n", "end\n", "\n", "PyCall::Conversions.python_type_mapping(pd.DataFrame, Pandas::DataFrame)\n", "\n", "dataframe_max_rows = 20\n", "\n", "IRuby::Display::Registry.module_eval do\n", "  type { Pandas::DataFrame }\n", "  format \"text/html\" do |pyobj|\n", "    pyobj.to_html.(max_rows: dataframe_max_rows, show_dimensions: true, notebook: true)\n", "  end\n", "end" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data Analysis Walkthrough\n", "\n", "### Preparing and Preprocessing the Data\n", "\n", "Using data on the passengers of the Titanic, we will build a model to predict passenger survival.\n", "\n", "We use the `load_dataset` function of the seaborn library to download and load the data." ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table>\n",
"<thead>\n",
"<tr><th></th><th>survived</th><th>pclass</th><th>sex</th><th>age</th><th>sibsp</th><th>parch</th><th>fare</th><th>embarked</th><th>class</th><th>who</th><th>adult_male</th><th>deck</th><th>embark_town</th><th>alive</th><th>alone</th></tr>\n",
"</thead>\n",
"<tbody>\n",
"<tr><th>0</th><td>0</td><td>3</td><td>male</td><td>22.0</td><td>1</td><td>0</td><td>7.2500</td><td>S</td><td>Third</td><td>man</td><td>True</td><td>NaN</td><td>Southampton</td><td>no</td><td>False</td></tr>\n",
"<tr><th>1</th><td>1</td><td>1</td><td>female</td><td>38.0</td><td>1</td><td>0</td><td>71.2833</td><td>C</td><td>First</td><td>woman</td><td>False</td><td>C</td><td>Cherbourg</td><td>yes</td><td>False</td></tr>\n",
"<tr><th>2</th><td>1</td><td>3</td><td>female</td><td>26.0</td><td>0</td><td>0</td><td>7.9250</td><td>S</td><td>Third</td><td>woman</td><td>False</td><td>NaN</td><td>Southampton</td><td>yes</td><td>True</td></tr>\n",
"<tr><th>3</th><td>1</td><td>1</td><td>female</td><td>35.0</td><td>1</td><td>0</td><td>53.1000</td><td>S</td><td>First</td><td>woman</td><td>False</td><td>C</td><td>Southampton</td><td>yes</td><td>False</td></tr>\n",
"<tr><th>4</th><td>0</td><td>3</td><td>male</td><td>35.0</td><td>0</td><td>0</td><td>8.0500</td><td>S</td><td>Third</td><td>man</td><td>True</td><td>NaN</td><td>Southampton</td><td>no</td><td>True</td></tr>\n",
"<tr><th>5</th><td>0</td><td>3</td><td>male</td><td>NaN</td><td>0</td><td>0</td><td>8.4583</td><td>Q</td><td>Third</td><td>man</td><td>True</td><td>NaN</td><td>Queenstown</td><td>no</td><td>True</td></tr>\n",
"<tr><th>6</th><td>0</td><td>1</td><td>male</td><td>54.0</td><td>0</td><td>0</td><td>51.8625</td><td>S</td><td>First</td><td>man</td><td>True</td><td>E</td><td>Southampton</td><td>no</td><td>True</td></tr>\n",
"<tr><th>7</th><td>0</td><td>3</td><td>male</td><td>2.0</td><td>3</td><td>1</td><td>21.0750</td><td>S</td><td>Third</td><td>child</td><td>False</td><td>NaN</td><td>Southampton</td><td>no</td><td>False</td></tr>\n",
"<tr><th>8</th><td>1</td><td>3</td><td>female</td><td>27.0</td><td>0</td><td>2</td><td>11.1333</td><td>S</td><td>Third</td><td>woman</td><td>False</td><td>NaN</td><td>Southampton</td><td>yes</td><td>False</td></tr>\n",
"<tr><th>9</th><td>1</td><td>2</td><td>female</td><td>14.0</td><td>1</td><td>0</td><td>30.0708</td><td>C</td><td>Second</td><td>child</td><td>False</td><td>NaN</td><td>Cherbourg</td><td>yes</td><td>False</td></tr>\n",
"<tr><th>...</th><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td></tr>\n",
"<tr><th>881</th><td>0</td><td>3</td><td>male</td><td>33.0</td><td>0</td><td>0</td><td>7.8958</td><td>S</td><td>Third</td><td>man</td><td>True</td><td>NaN</td><td>Southampton</td><td>no</td><td>True</td></tr>\n",
"<tr><th>882</th><td>0</td><td>3</td><td>female</td><td>22.0</td><td>0</td><td>0</td><td>10.5167</td><td>S</td><td>Third</td><td>woman</td><td>False</td><td>NaN</td><td>Southampton</td><td>no</td><td>True</td></tr>\n",
"<tr><th>883</th><td>0</td><td>2</td><td>male</td><td>28.0</td><td>0</td><td>0</td><td>10.5000</td><td>S</td><td>Second</td><td>man</td><td>True</td><td>NaN</td><td>Southampton</td><td>no</td><td>True</td></tr>\n",
"<tr><th>884</th><td>0</td><td>3</td><td>male</td><td>25.0</td><td>0</td><td>0</td><td>7.0500</td><td>S</td><td>Third</td><td>man</td><td>True</td><td>NaN</td><td>Southampton</td><td>no</td><td>True</td></tr>\n",
"<tr><th>885</th><td>0</td><td>3</td><td>female</td><td>39.0</td><td>0</td><td>5</td><td>29.1250</td><td>Q</td><td>Third</td><td>woman</td><td>False</td><td>NaN</td><td>Queenstown</td><td>no</td><td>False</td></tr>\n",
"<tr><th>886</th><td>0</td><td>2</td><td>male</td><td>27.0</td><td>0</td><td>0</td><td>13.0000</td><td>S</td><td>Second</td><td>man</td><td>True</td><td>NaN</td><td>Southampton</td><td>no</td><td>True</td></tr>\n",
"<tr><th>887</th><td>1</td><td>1</td><td>female</td><td>19.0</td><td>0</td><td>0</td><td>30.0000</td><td>S</td><td>First</td><td>woman</td><td>False</td><td>B</td><td>Southampton</td><td>yes</td><td>True</td></tr>\n",
"<tr><th>888</th><td>0</td><td>3</td><td>female</td><td>NaN</td><td>1</td><td>2</td><td>23.4500</td><td>S</td><td>Third</td><td>woman</td><td>False</td><td>NaN</td><td>Southampton</td><td>no</td><td>False</td></tr>\n",
"<tr><th>889</th><td>1</td><td>1</td><td>male</td><td>26.0</td><td>0</td><td>0</td><td>30.0000</td><td>C</td><td>First</td><td>man</td><td>True</td><td>C</td><td>Cherbourg</td><td>yes</td><td>True</td></tr>\n",
"<tr><th>890</th><td>0</td><td>3</td><td>male</td><td>32.0</td><td>0</td><td>0</td><td>7.7500</td><td>Q</td><td>Third</td><td>man</td><td>True</td><td>NaN</td><td>Queenstown</td><td>no</td><td>True</td></tr>\n",
"</tbody>\n",
"</table>\n",
"<p>891 rows × 15 columns</p>\n",
"</div>
" ], "text/plain": [ " survived pclass sex age sibsp parch fare embarked class \\\n", "0 0 3 male 22.0 1 0 7.2500 S Third \n", "1 1 1 female 38.0 1 0 71.2833 C First \n", "2 1 3 female 26.0 0 0 7.9250 S Third \n", "3 1 1 female 35.0 1 0 53.1000 S First \n", "4 0 3 male 35.0 0 0 8.0500 S Third \n", "5 0 3 male NaN 0 0 8.4583 Q Third \n", "6 0 1 male 54.0 0 0 51.8625 S First \n", "7 0 3 male 2.0 3 1 21.0750 S Third \n", "8 1 3 female 27.0 0 2 11.1333 S Third \n", "9 1 2 female 14.0 1 0 30.0708 C Second \n", "10 1 3 female 4.0 1 1 16.7000 S Third \n", "11 1 1 female 58.0 0 0 26.5500 S First \n", "12 0 3 male 20.0 0 0 8.0500 S Third \n", "13 0 3 male 39.0 1 5 31.2750 S Third \n", "14 0 3 female 14.0 0 0 7.8542 S Third \n", "15 1 2 female 55.0 0 0 16.0000 S Second \n", "16 0 3 male 2.0 4 1 29.1250 Q Third \n", "17 1 2 male NaN 0 0 13.0000 S Second \n", "18 0 3 female 31.0 1 0 18.0000 S Third \n", "19 1 3 female NaN 0 0 7.2250 C Third \n", "20 0 2 male 35.0 0 0 26.0000 S Second \n", "21 1 2 male 34.0 0 0 13.0000 S Second \n", "22 1 3 female 15.0 0 0 8.0292 Q Third \n", "23 1 1 male 28.0 0 0 35.5000 S First \n", "24 0 3 female 8.0 3 1 21.0750 S Third \n", "25 1 3 female 38.0 1 5 31.3875 S Third \n", "26 0 3 male NaN 0 0 7.2250 C Third \n", "27 0 1 male 19.0 3 2 263.0000 S First \n", "28 1 3 female NaN 0 0 7.8792 Q Third \n", "29 0 3 male NaN 0 0 7.8958 S Third \n", ".. ... ... ... ... ... ... ... ... ... 
\n", "861 0 2 male 21.0 1 0 11.5000 S Second \n", "862 1 1 female 48.0 0 0 25.9292 S First \n", "863 0 3 female NaN 8 2 69.5500 S Third \n", "864 0 2 male 24.0 0 0 13.0000 S Second \n", "865 1 2 female 42.0 0 0 13.0000 S Second \n", "866 1 2 female 27.0 1 0 13.8583 C Second \n", "867 0 1 male 31.0 0 0 50.4958 S First \n", "868 0 3 male NaN 0 0 9.5000 S Third \n", "869 1 3 male 4.0 1 1 11.1333 S Third \n", "870 0 3 male 26.0 0 0 7.8958 S Third \n", "871 1 1 female 47.0 1 1 52.5542 S First \n", "872 0 1 male 33.0 0 0 5.0000 S First \n", "873 0 3 male 47.0 0 0 9.0000 S Third \n", "874 1 2 female 28.0 1 0 24.0000 C Second \n", "875 1 3 female 15.0 0 0 7.2250 C Third \n", "876 0 3 male 20.0 0 0 9.8458 S Third \n", "877 0 3 male 19.0 0 0 7.8958 S Third \n", "878 0 3 male NaN 0 0 7.8958 S Third \n", "879 1 1 female 56.0 0 1 83.1583 C First \n", "880 1 2 female 25.0 0 1 26.0000 S Second \n", "881 0 3 male 33.0 0 0 7.8958 S Third \n", "882 0 3 female 22.0 0 0 10.5167 S Third \n", "883 0 2 male 28.0 0 0 10.5000 S Second \n", "884 0 3 male 25.0 0 0 7.0500 S Third \n", "885 0 3 female 39.0 0 5 29.1250 Q Third \n", "886 0 2 male 27.0 0 0 13.0000 S Second \n", "887 1 1 female 19.0 0 0 30.0000 S First \n", "888 0 3 female NaN 1 2 23.4500 S Third \n", "889 1 1 male 26.0 0 0 30.0000 C First \n", "890 0 3 male 32.0 0 0 7.7500 Q Third \n", "\n", " who adult_male deck embark_town alive alone \n", "0 man True NaN Southampton no False \n", "1 woman False C Cherbourg yes False \n", "2 woman False NaN Southampton yes True \n", "3 woman False C Southampton yes False \n", "4 man True NaN Southampton no True \n", "5 man True NaN Queenstown no True \n", "6 man True E Southampton no True \n", "7 child False NaN Southampton no False \n", "8 woman False NaN Southampton yes False \n", "9 child False NaN Cherbourg yes False \n", "10 child False G Southampton yes False \n", "11 woman False C Southampton yes True \n", "12 man True NaN Southampton no True \n", "13 man True NaN Southampton no False 
\n", "14 child False NaN Southampton no True \n", "15 woman False NaN Southampton yes True \n", "16 child False NaN Queenstown no False \n", "17 man True NaN Southampton yes True \n", "18 woman False NaN Southampton no False \n", "19 woman False NaN Cherbourg yes True \n", "20 man True NaN Southampton no True \n", "21 man True D Southampton yes True \n", "22 child False NaN Queenstown yes True \n", "23 man True A Southampton yes True \n", "24 child False NaN Southampton no False \n", "25 woman False NaN Southampton yes False \n", "26 man True NaN Cherbourg no True \n", "27 man True C Southampton no False \n", "28 woman False NaN Queenstown yes True \n", "29 man True NaN Southampton no True \n", ".. ... ... ... ... ... ... \n", "861 man True NaN Southampton no False \n", "862 woman False D Southampton yes True \n", "863 woman False NaN Southampton no False \n", "864 man True NaN Southampton no True \n", "865 woman False NaN Southampton yes True \n", "866 woman False NaN Cherbourg yes False \n", "867 man True A Southampton no True \n", "868 man True NaN Southampton no True \n", "869 child False NaN Southampton yes False \n", "870 man True NaN Southampton no True \n", "871 woman False D Southampton yes False \n", "872 man True B Southampton no True \n", "873 man True NaN Southampton no True \n", "874 woman False NaN Cherbourg yes False \n", "875 child False NaN Cherbourg yes True \n", "876 man True NaN Southampton no True \n", "877 man True NaN Southampton no True \n", "878 man True NaN Southampton no True \n", "879 woman False C Cherbourg yes False \n", "880 woman False NaN Southampton yes False \n", "881 man True NaN Southampton no True \n", "882 woman False NaN Southampton no True \n", "883 man True NaN Southampton no True \n", "884 man True NaN Southampton no True \n", "885 woman False NaN Queenstown no False \n", "886 man True NaN Southampton no True \n", "887 woman False B Southampton yes True \n", "888 woman False NaN Southampton no False \n", "889 man True C 
Cherbourg yes True \n", "890 man True NaN Queenstown no True \n", "\n", "[891 rows x 15 columns]" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = sns.load_dataset.('titanic')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The object assigned to the variable `df` is a pandas data frame." ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "pytype(DataFrame)" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.type" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "\n", "The first step in data analysis is to look at what the data contains.\n", "\n", "As the table above shows, the data consists of 891 records, each with 15 columns.\n", "Some of these columns duplicate the content of others, as follows.\n", "\n", "- `survived` is derived from `alive` by mapping `no` -> 0 and `yes` -> 1\n", "- `embarked` is the first letter of `embark_town`\n", "- `pclass` is the numeric form of `class`\n", "- `sex` and `who` correspond as `male` => `man` and `female` => `woman` (`who` additionally labels some passengers as `child`)\n", "\n", "Keeping several columns with duplicated content increases the amount of processing without adding information, so we drop them." ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array(['survived', 'pclass', 'sex', 'age', 'sibsp', 'parch', 'fare',\n", " 'embarked', 'adult_male', 'deck', 'alone'], dtype=object)" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = df.drop.([:alive, :embark_town, :class, :who], axis: 1)\n", "df.columns.values" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The remaining columns have the following meanings.\n", "\n", "| Column | Meaning |\n", "|:--- |:--- |\n", "| `survived` | 1: survived, 0: died |\n", "| `pclass` | Passenger class (1: Upper, 2: Middle, 3: Lower) |\n", "| `sex` | Sex (`male`, `female`) |\n", "| `age` | Age (fractional if under 1 year) |\n", "| `sibsp` | Number of siblings/spouses aboard |\n", "| `parch` | Number of parents/children aboard |\n", "| `fare` | Ticket fare |\n", "| `embarked` | First letter of the port of embarkation |\n", "| `adult_male` | true for an adult male |\n", "| `deck` | Cabin category |\n", "| `alone` | true if traveling alone |" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "\n", "Raw data almost always contains missing values. What about this dataset? Let's find out.\n", "\n", "The data frame's `isnull` method produces a data frame of the same shape in which each cell is `true` if the corresponding value is missing and `false` otherwise. Applying the `sum` method to this data frame of missing-value flags counts the missing values in each column (the sum treats `true` as 1 and `false` as 0)." ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "survived 0\n", "pclass 0\n", "sex 0\n", "age 177\n", "sibsp 0\n", "parch 0\n", "fare 0\n", "embarked 2\n", "adult_male 0\n", "deck 688\n", "alone 0\n", "dtype: int64" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.isnull.().sum.()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This shows that the `age` column has 177 missing values, the `deck` column has 688, the `embarked` column has 2, and the remaining columns have none.\n", "\n", "With 688 of the 891 rows missing a value, the `deck` column does not look usable for analysis.\n", "We will drop the `deck` column this time." ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "collapsed": true }, "outputs": [], "source": [ "df = df.drop.(:deck, axis: 1)\n", "nil" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's look at the distribution of the `age` column." ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" }, { "data": { "text/plain": [ "" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sampled_age = df[:age].dropna.().sample.(100) # Random sampling, because using all the data takes a little while\n", "sns.kdeplot.(sampled_age, shade: true, cut: 0)\n", "sns.rugplot.(sampled_age)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's also look at the mean. While we are at it, we compute summary statistics for every numeric column with the `describe` method." ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table>\n",
"<thead>\n",
"<tr><th></th><th>survived</th><th>pclass</th><th>age</th><th>sibsp</th><th>parch</th><th>fare</th></tr>\n",
"</thead>\n",
"<tbody>\n",
"<tr><th>count</th><td>891.000000</td><td>891.000000</td><td>714.000000</td><td>891.000000</td><td>891.000000</td><td>891.000000</td></tr>\n",
"<tr><th>mean</th><td>0.383838</td><td>2.308642</td><td>29.699118</td><td>0.523008</td><td>0.381594</td><td>32.204208</td></tr>\n",
"<tr><th>std</th><td>0.486592</td><td>0.836071</td><td>14.526497</td><td>1.102743</td><td>0.806057</td><td>49.693429</td></tr>\n",
"<tr><th>min</th><td>0.000000</td><td>1.000000</td><td>0.420000</td><td>0.000000</td><td>0.000000</td><td>0.000000</td></tr>\n",
"<tr><th>25%</th><td>0.000000</td><td>2.000000</td><td>20.125000</td><td>0.000000</td><td>0.000000</td><td>7.910400</td></tr>\n",
"<tr><th>50%</th><td>0.000000</td><td>3.000000</td><td>28.000000</td><td>0.000000</td><td>0.000000</td><td>14.454200</td></tr>\n",
"<tr><th>75%</th><td>1.000000</td><td>3.000000</td><td>38.000000</td><td>1.000000</td><td>0.000000</td><td>31.000000</td></tr>\n",
"<tr><th>max</th><td>1.000000</td><td>3.000000</td><td>80.000000</td><td>8.000000</td><td>6.000000</td><td>512.329200</td></tr>\n",
"</tbody>\n",
"</table>\n",
"<p>8 rows × 6 columns</p>\n",
"</div>" ], "text/plain": [ " survived pclass age sibsp parch fare\n", "count 891.000000 891.000000 714.000000 891.000000 891.000000 891.000000\n", "mean 0.383838 2.308642 29.699118 0.523008 0.381594 32.204208\n", "std 0.486592 0.836071 14.526497 1.102743 0.806057 49.693429\n", "min 0.000000 1.000000 0.420000 0.000000 0.000000 0.000000\n", "25% 0.000000 2.000000 20.125000 0.000000 0.000000 7.910400\n", "50% 0.000000 3.000000 28.000000 0.000000 0.000000 14.454200\n", "75% 1.000000 3.000000 38.000000 1.000000 0.000000 31.000000\n", "max 1.000000 3.000000 80.000000 8.000000 6.000000 512.329200" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.describe.()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We now know that the mean of `age` is 29.699118 and the median is 28.\n", "\n", "We record the positions of the missing `age` values and, for now, fill them in with the median." ] }, { "cell_type": "code", "execution_count": 34, "metadata": { "collapsed": true }, "outputs": [], "source": [ "age_isnull = df[:age].isnull.() # Remember where the missing values are (we may need this later)\n", "nil" ] }, { "cell_type": "code", "execution_count": 35, "metadata": { "collapsed": true }, "outputs": [], "source": [ "df[:age].fillna.(df[:age].median.(), inplace: true) # Fill missing values with the median\n", "nil" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's count the missing values again." ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "survived 0\n", "pclass 0\n", "sex 0\n", "age 0\n", "sibsp 0\n", "parch 0\n", "fare 0\n", "embarked 2\n", "adult_male 0\n", "alone 0\n", "dtype: int64" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.isnull.().sum.()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Only the two missing values in `embarked` remain; since there are just two, we ignore them and move on." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Since we are building a model to predict survival, the target column is `survived`.\n", "First, let's see how strongly each column correlates with `survived`.\n", "To do that, we need to convert the values of the two label-valued columns, `sex` and `embarked`, to numbers.\n", "\n", "Numeric variables obtained by converting a label variable are called dummy variables; in pandas, the `get_dummies` function performs the conversion." ] }, { "cell_type": "code", "execution_count": 37, "metadata": { "scrolled": false }, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table>\n",
"<thead>\n",
"<tr><th></th><th>survived</th><th>pclass</th><th>age</th><th>sibsp</th><th>parch</th><th>fare</th><th>adult_male</th><th>alone</th><th>female</th><th>male</th><th>C</th><th>Q</th></tr>\n",
"</thead>\n",
"<tbody>\n",
"<tr><th>0</th><td>0</td><td>3</td><td>22.0</td><td>1</td><td>0</td><td>7.2500</td><td>True</td><td>False</td><td>0</td><td>1</td><td>0</td><td>0</td></tr>\n",
"<tr><th>1</th><td>1</td><td>1</td><td>38.0</td><td>1</td><td>0</td><td>71.2833</td><td>False</td><td>False</td><td>1</td><td>0</td><td>1</td><td>0</td></tr>\n",
"<tr><th>2</th><td>1</td><td>3</td><td>26.0</td><td>0</td><td>0</td><td>7.9250</td><td>False</td><td>True</td><td>1</td><td>0</td><td>0</td><td>0</td></tr>\n",
"<tr><th>3</th><td>1</td><td>1</td><td>35.0</td><td>1</td><td>0</td><td>53.1000</td><td>False</td><td>False</td><td>1</td><td>0</td><td>0</td><td>0</td></tr>\n",
"<tr><th>4</th><td>0</td><td>3</td><td>35.0</td><td>0</td><td>0</td><td>8.0500</td><td>True</td><td>True</td><td>0</td><td>1</td><td>0</td><td>0</td></tr>\n",
"<tr><th>5</th><td>0</td><td>3</td><td>28.0</td><td>0</td><td>0</td><td>8.4583</td><td>True</td><td>True</td><td>0</td><td>1</td><td>0</td><td>1</td></tr>\n",
"<tr><th>6</th><td>0</td><td>1</td><td>54.0</td><td>0</td><td>0</td><td>51.8625</td><td>True</td><td>True</td><td>0</td><td>1</td><td>0</td><td>0</td></tr>\n",
"<tr><th>7</th><td>0</td><td>3</td><td>2.0</td><td>3</td><td>1</td><td>21.0750</td><td>False</td><td>False</td><td>0</td><td>1</td><td>0</td><td>0</td></tr>\n",
"<tr><th>8</th><td>1</td><td>3</td><td>27.0</td><td>0</td><td>2</td><td>11.1333</td><td>False</td><td>False</td><td>1</td><td>0</td><td>0</td><td>0</td></tr>\n",
"<tr><th>9</th><td>1</td><td>2</td><td>14.0</td><td>1</td><td>0</td><td>30.0708</td><td>False</td><td>False</td><td>1</td><td>0</td><td>1</td><td>0</td></tr>\n",
"<tr><th>...</th><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td></tr>\n",
"<tr><th>881</th><td>0</td><td>3</td><td>33.0</td><td>0</td><td>0</td><td>7.8958</td><td>True</td><td>True</td><td>0</td><td>1</td><td>0</td><td>0</td></tr>\n",
"<tr><th>882</th><td>0</td><td>3</td><td>22.0</td><td>0</td><td>0</td><td>10.5167</td><td>False</td><td>True</td><td>1</td><td>0</td><td>0</td><td>0</td></tr>\n",
"<tr><th>883</th><td>0</td><td>2</td><td>28.0</td><td>0</td><td>0</td><td>10.5000</td><td>True</td><td>True</td><td>0</td><td>1</td><td>0</td><td>0</td></tr>\n",
"<tr><th>884</th><td>0</td><td>3</td><td>25.0</td><td>0</td><td>0</td><td>7.0500</td><td>True</td><td>True</td><td>0</td><td>1</td><td>0</td><td>0</td></tr>\n",
"<tr><th>885</th><td>0</td><td>3</td><td>39.0</td><td>0</td><td>5</td><td>29.1250</td><td>False</td><td>False</td><td>1</td><td>0</td><td>0</td><td>1</td></tr>\n",
"<tr><th>886</th><td>0</td><td>2</td><td>27.0</td><td>0</td><td>0</td><td>13.0000</td><td>True</td><td>True</td><td>0</td><td>1</td><td>0</td><td>0</td></tr>\n",
"<tr><th>887</th><td>1</td><td>1</td><td>19.0</td><td>0</td><td>0</td><td>30.0000</td><td>False</td><td>True</td><td>1</td><td>0</td><td>0</td><td>0</td></tr>\n",
"<tr><th>888</th><td>0</td><td>3</td><td>28.0</td><td>1</td><td>2</td><td>23.4500</td><td>False</td><td>False</td><td>1</td><td>0</td><td>0</td><td>0</td></tr>\n",
"<tr><th>889</th><td>1</td><td>1</td><td>26.0</td><td>0</td><td>0</td><td>30.0000</td><td>True</td><td>True</td><td>0</td><td>1</td><td>1</td><td>0</td></tr>\n",
"<tr><th>890</th><td>0</td><td>3</td><td>32.0</td><td>0</td><td>0</td><td>7.7500</td><td>True</td><td>True</td><td>0</td><td>1</td><td>0</td><td>1</td></tr>\n",
"</tbody>\n",
"</table>\n",
"<p>891 rows × 12 columns</p>\n",
"</div>
" ], "text/plain": [ " survived pclass age sibsp parch fare adult_male alone female \\\n", "0 0 3 22.0 1 0 7.2500 True False 0 \n", "1 1 1 38.0 1 0 71.2833 False False 1 \n", "2 1 3 26.0 0 0 7.9250 False True 1 \n", "3 1 1 35.0 1 0 53.1000 False False 1 \n", "4 0 3 35.0 0 0 8.0500 True True 0 \n", "5 0 3 28.0 0 0 8.4583 True True 0 \n", "6 0 1 54.0 0 0 51.8625 True True 0 \n", "7 0 3 2.0 3 1 21.0750 False False 0 \n", "8 1 3 27.0 0 2 11.1333 False False 1 \n", "9 1 2 14.0 1 0 30.0708 False False 1 \n", "10 1 3 4.0 1 1 16.7000 False False 1 \n", "11 1 1 58.0 0 0 26.5500 False True 1 \n", "12 0 3 20.0 0 0 8.0500 True True 0 \n", "13 0 3 39.0 1 5 31.2750 True False 0 \n", "14 0 3 14.0 0 0 7.8542 False True 1 \n", "15 1 2 55.0 0 0 16.0000 False True 1 \n", "16 0 3 2.0 4 1 29.1250 False False 0 \n", "17 1 2 28.0 0 0 13.0000 True True 0 \n", "18 0 3 31.0 1 0 18.0000 False False 1 \n", "19 1 3 28.0 0 0 7.2250 False True 1 \n", "20 0 2 35.0 0 0 26.0000 True True 0 \n", "21 1 2 34.0 0 0 13.0000 True True 0 \n", "22 1 3 15.0 0 0 8.0292 False True 1 \n", "23 1 1 28.0 0 0 35.5000 True True 0 \n", "24 0 3 8.0 3 1 21.0750 False False 1 \n", "25 1 3 38.0 1 5 31.3875 False False 1 \n", "26 0 3 28.0 0 0 7.2250 True True 0 \n", "27 0 1 19.0 3 2 263.0000 True False 0 \n", "28 1 3 28.0 0 0 7.8792 False True 1 \n", "29 0 3 28.0 0 0 7.8958 True True 0 \n", ".. ... ... ... ... ... ... ... ... ... 
\n", "861 0 2 21.0 1 0 11.5000 True False 0 \n", "862 1 1 48.0 0 0 25.9292 False True 1 \n", "863 0 3 28.0 8 2 69.5500 False False 1 \n", "864 0 2 24.0 0 0 13.0000 True True 0 \n", "865 1 2 42.0 0 0 13.0000 False True 1 \n", "866 1 2 27.0 1 0 13.8583 False False 1 \n", "867 0 1 31.0 0 0 50.4958 True True 0 \n", "868 0 3 28.0 0 0 9.5000 True True 0 \n", "869 1 3 4.0 1 1 11.1333 False False 0 \n", "870 0 3 26.0 0 0 7.8958 True True 0 \n", "871 1 1 47.0 1 1 52.5542 False False 1 \n", "872 0 1 33.0 0 0 5.0000 True True 0 \n", "873 0 3 47.0 0 0 9.0000 True True 0 \n", "874 1 2 28.0 1 0 24.0000 False False 1 \n", "875 1 3 15.0 0 0 7.2250 False True 1 \n", "876 0 3 20.0 0 0 9.8458 True True 0 \n", "877 0 3 19.0 0 0 7.8958 True True 0 \n", "878 0 3 28.0 0 0 7.8958 True True 0 \n", "879 1 1 56.0 0 1 83.1583 False False 1 \n", "880 1 2 25.0 0 1 26.0000 False False 1 \n", "881 0 3 33.0 0 0 7.8958 True True 0 \n", "882 0 3 22.0 0 0 10.5167 False True 1 \n", "883 0 2 28.0 0 0 10.5000 True True 0 \n", "884 0 3 25.0 0 0 7.0500 True True 0 \n", "885 0 3 39.0 0 5 29.1250 False False 1 \n", "886 0 2 27.0 0 0 13.0000 True True 0 \n", "887 1 1 19.0 0 0 30.0000 False True 1 \n", "888 0 3 28.0 1 2 23.4500 False False 1 \n", "889 1 1 26.0 0 0 30.0000 True True 0 \n", "890 0 3 32.0 0 0 7.7500 True True 0 \n", "\n", " male C Q \n", "0 1 0 0 \n", "1 0 1 0 \n", "2 0 0 0 \n", "3 0 0 0 \n", "4 1 0 0 \n", "5 1 0 1 \n", "6 1 0 0 \n", "7 1 0 0 \n", "8 0 0 0 \n", "9 0 1 0 \n", "10 0 0 0 \n", "11 0 0 0 \n", "12 1 0 0 \n", "13 1 0 0 \n", "14 0 0 0 \n", "15 0 0 0 \n", "16 1 0 1 \n", "17 1 0 0 \n", "18 0 0 0 \n", "19 0 1 0 \n", "20 1 0 0 \n", "21 1 0 0 \n", "22 0 0 1 \n", "23 1 0 0 \n", "24 0 0 0 \n", "25 0 0 0 \n", "26 1 1 0 \n", "27 1 0 0 \n", "28 0 0 1 \n", "29 1 0 0 \n", ".. ... .. .. 
\n", "861 1 0 0 \n", "862 0 0 0 \n", "863 0 0 0 \n", "864 1 0 0 \n", "865 0 0 0 \n", "866 0 1 0 \n", "867 1 0 0 \n", "868 1 0 0 \n", "869 1 0 0 \n", "870 1 0 0 \n", "871 0 0 0 \n", "872 1 0 0 \n", "873 1 0 0 \n", "874 0 1 0 \n", "875 0 1 0 \n", "876 1 0 0 \n", "877 1 0 0 \n", "878 1 0 0 \n", "879 0 1 0 \n", "880 0 0 0 \n", "881 1 0 0 \n", "882 0 0 0 \n", "883 1 0 0 \n", "884 1 0 0 \n", "885 0 0 1 \n", "886 1 0 0 \n", "887 0 0 0 \n", "888 0 0 0 \n", "889 1 1 0 \n", "890 1 0 1 \n", "\n", "[891 rows x 12 columns]" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# One-hot encode `sex` and `embarked`, append the dummy columns, then drop the originals and the redundant `S`\n", "sex_dummies = pd.get_dummies.(df[:sex])\n", "embarked_dummies = pd.get_dummies.(df[:embarked])\n", "df = pd.concat.(PyCall.tuple(df, sex_dummies, embarked_dummies), axis: 1)\n", "df = df.drop.([:sex, :embarked, :S], axis: 1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The dummy variables `female` and `male` for `sex`, and `C` and `Q` for `embarked`, have been added.\n", "The original `sex` and `embarked` columns have been dropped.\n", "\n", "`embarked` has one more dummy variable, `S`, but whenever both `C` and `Q` are 0, `S` must be 1 (except for the 2 rows with missing values). `S` therefore carries no additional information, so it is dropped as well." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now that every column is numeric, we compute the correlation coefficients between columns with the `corr` method." ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "
survived pclass age sibsp parch fare adult_male alone female male C Q
survived 1.000000 -0.338481 -0.064910 -0.035322 0.081629 0.257307 -0.557080 -0.203367 0.543351 -0.543351 0.168240 0.003650
pclass -0.338481 1.000000 -0.339898 0.083081 0.018443 -0.549500 0.094035 0.135207 -0.131900 0.131900 -0.243292 0.221009
age -0.064910 -0.339898 1.000000 -0.233296 -0.172482 0.096688 0.247704 0.171647 -0.081163 0.081163 0.030248 -0.031415
sibsp -0.035322 0.083081 -0.233296 1.000000 0.414838 0.159651 -0.253586 -0.584471 0.114631 -0.114631 -0.059528 -0.026354
parch 0.081629 0.018443 -0.172482 0.414838 1.000000 0.216225 -0.349943 -0.583398 0.245489 -0.245489 -0.011069 -0.081228
fare 0.257307 -0.549500 0.096688 0.159651 0.216225 1.000000 -0.182024 -0.271832 0.182333 -0.182333 0.269335 -0.117216
adult_male -0.557080 0.094035 0.247704 -0.253586 -0.349943 -0.182024 1.000000 0.404744 -0.908578 0.908578 -0.065980 -0.076789
alone -0.203367 0.135207 0.171647 -0.584471 -0.583398 -0.271832 0.404744 1.000000 -0.303646 0.303646 -0.095298 0.086464
female 0.543351 -0.131900 -0.081163 0.114631 0.245489 0.182333 -0.908578 -0.303646 1.000000 -1.000000 0.082853 0.074115
male -0.543351 0.131900 0.081163 -0.114631 -0.245489 -0.182333 0.908578 0.303646 -1.000000 1.000000 -0.082853 -0.074115
C 0.168240 -0.243292 0.030248 -0.059528 -0.011069 0.269335 -0.065980 -0.095298 0.082853 -0.082853 1.000000 -0.148258
Q 0.003650 0.221009 -0.031415 -0.026354 -0.081228 -0.117216 -0.076789 0.086464 0.074115 -0.074115 -0.148258 1.000000
\n", "

12 rows × 12 columns

\n", "
" ], "text/plain": [ " survived pclass age sibsp parch fare \\\n", "survived 1.000000 -0.338481 -0.064910 -0.035322 0.081629 0.257307 \n", "pclass -0.338481 1.000000 -0.339898 0.083081 0.018443 -0.549500 \n", "age -0.064910 -0.339898 1.000000 -0.233296 -0.172482 0.096688 \n", "sibsp -0.035322 0.083081 -0.233296 1.000000 0.414838 0.159651 \n", "parch 0.081629 0.018443 -0.172482 0.414838 1.000000 0.216225 \n", "fare 0.257307 -0.549500 0.096688 0.159651 0.216225 1.000000 \n", "adult_male -0.557080 0.094035 0.247704 -0.253586 -0.349943 -0.182024 \n", "alone -0.203367 0.135207 0.171647 -0.584471 -0.583398 -0.271832 \n", "female 0.543351 -0.131900 -0.081163 0.114631 0.245489 0.182333 \n", "male -0.543351 0.131900 0.081163 -0.114631 -0.245489 -0.182333 \n", "C 0.168240 -0.243292 0.030248 -0.059528 -0.011069 0.269335 \n", "Q 0.003650 0.221009 -0.031415 -0.026354 -0.081228 -0.117216 \n", "\n", " adult_male alone female male C Q \n", "survived -0.557080 -0.203367 0.543351 -0.543351 0.168240 0.003650 \n", "pclass 0.094035 0.135207 -0.131900 0.131900 -0.243292 0.221009 \n", "age 0.247704 0.171647 -0.081163 0.081163 0.030248 -0.031415 \n", "sibsp -0.253586 -0.584471 0.114631 -0.114631 -0.059528 -0.026354 \n", "parch -0.349943 -0.583398 0.245489 -0.245489 -0.011069 -0.081228 \n", "fare -0.182024 -0.271832 0.182333 -0.182333 0.269335 -0.117216 \n", "adult_male 1.000000 0.404744 -0.908578 0.908578 -0.065980 -0.076789 \n", "alone 0.404744 1.000000 -0.303646 0.303646 -0.095298 0.086464 \n", "female -0.908578 -0.303646 1.000000 -1.000000 0.082853 0.074115 \n", "male 0.908578 0.303646 -1.000000 1.000000 -0.082853 -0.074115 \n", "C -0.065980 -0.095298 0.082853 -0.082853 1.000000 -0.148258 \n", "Q -0.076789 0.086464 0.074115 -0.074115 -0.148258 1.000000 " ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.corr.()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The sex-related columns (`female`, `male`, `adult_male`) show the strongest correlation with `survived`." ] }, { 
"cell_type": "markdown", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "### Modeling\n", "\n", "Here we build three kinds of models, a random forest ( `sklearn.ensemble.RandomForestClassifier` ), logistic regression ( `sklearn.linear_model.LogisticRegression` ), and a support vector machine ( `sklearn.svm.SVC` ), and compare their performance.\n", "The hyperparameters of each model are optimized by grid search ( `sklearn.model_selection.GridSearchCV` )." ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "pytype(GridSearchCV)" ] }, "execution_count": 39, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Import the three classifiers and the grid-search helper from scikit-learn\n", "pyfrom 'sklearn.ensemble', import: :RandomForestClassifier\n", "pyfrom 'sklearn.linear_model', import: :LogisticRegression\n", "pyfrom 'sklearn.svm', import: :SVC\n", "pyfrom 'sklearn.model_selection', import: :GridSearchCV" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Building a classification model with a random forest" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "GridSearchCV(cv=5, error_score='raise',\n", " estimator=RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='auto', max_leaf_nodes=None,\n", " min_impurity_split=1e-07, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " n_estimators=10, n_jobs=2, oob_score=False, random_state=None,\n", " verbose=0, warm_start=False),\n", " fit_params={}, iid=True, n_jobs=4,\n", " param_grid={'n_estimators': [10, 20, 50], 'max_depth': [4, 5, 6, 7], 'max_features': ['auto', 'log2', None]},\n", " pre_dispatch='2*n_jobs', refit=True, return_train_score=True,\n", " scoring='roc_auc', verbose=0)" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Grid search over random forest hyperparameters, scored by ROC AUC with 5-fold cross-validation\n", "rfc = GridSearchCV.(\n", " RandomForestClassifier.(n_jobs: 2),\n", " {\n", " n_estimators: [10, 20, 50],\n", " max_depth: [4, 5, 6, 7],\n", " max_features: [:auto, 
:log2, PyCall.None],\n", " },\n", " scoring: :roc_auc,\n", " n_jobs: 4,\n", " cv: 5\n", ")" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "GridSearchCV(cv=5, error_score='raise',\n", " estimator=RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='auto', max_leaf_nodes=None,\n", " min_impurity_split=1e-07, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " n_estimators=10, n_jobs=2, oob_score=False, random_state=None,\n", " verbose=0, warm_start=False),\n", " fit_params={}, iid=True, n_jobs=4,\n", " param_grid={'n_estimators': [10, 20, 50], 'max_depth': [4, 5, 6, 7], 'max_features': ['auto', 'log2', None]},\n", " pre_dispatch='2*n_jobs', refit=True, return_train_score=True,\n", " scoring='roc_auc', verbose=0)" ] }, "execution_count": 41, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Use every column except the target `survived` as a feature\n", "x_names = [:pclass, :age, :sibsp, :parch, :fare, :adult_male, :alone, :female, :male, :C, :Q]\n", "x = df[x_names]\n", "y = df[:survived]\n", "rfc.fit.(x, y)" ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'max_depth': 7, 'max_features': 'auto', 'n_estimators': 20}" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" } ], "source": [ "rfc.best_params_" ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.8764839649977957" ] }, "execution_count": 43, "metadata": {}, "output_type": "execute_result" } ], "source": [ "rfc.best_score_" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The results of the grid search and cross-validation are stored in the `cv_results_` attribute. Its value can be passed directly to a pandas DataFrame." ] }, { "cell_type": "code", "execution_count": 44, "metadata": { "scrolled": false }, "outputs": [ { "data": { "text/html": [ "
\n", "
mean_fit_time mean_score_time mean_test_score mean_train_score param_max_depth param_max_features param_n_estimators rank_test_score split0_test_score split0_train_score split1_test_score split1_train_score split2_test_score split2_train_score split3_test_score split3_train_score split4_test_score split4_train_score std_fit_time std_score_time std_test_score std_train_score
0 0.135095 0.105550 0.863078 0.893600 4 auto 10 34 0.857510 0.893122 0.827273 0.907941 0.862099 0.888021 0.867513 0.890415 0.901444 0.888500 0.012107 0.000627 0.023671 0.007392
1 0.125255 0.103405 0.858664 0.895714 4 auto 20 36 0.854414 0.900206 0.820092 0.902451 0.870053 0.891721 0.846858 0.895167 0.902388 0.889026 0.006131 0.000456 0.027130 0.005031
2 0.182410 0.104324 0.866282 0.896997 4 auto 50 25 0.856061 0.900569 0.829710 0.907540 0.872928 0.894871 0.870254 0.896858 0.902928 0.885144 0.012350 0.001152 0.023842 0.007334
3 0.132539 0.104010 0.864671 0.891364 4 log2 10 31 0.862516 0.889626 0.825033 0.902718 0.866511 0.888491 0.874198 0.896073 0.895507 0.879915 0.011379 0.000705 0.022897 0.007661
4 0.152971 0.104265 0.864805 0.895517 4 log2 20 30 0.863900 0.898929 0.832279 0.901862 0.863837 0.889023 0.865307 0.902270 0.899083 0.885501 0.024055 0.000832 0.021127 0.006928
5 0.152543 0.103466 0.865151 0.896761 4 log2 50 29 0.853821 0.897198 0.836693 0.906193 0.871524 0.892988 0.863102 0.899436 0.901039 0.887989 0.006887 0.000839 0.021291 0.006123
6 0.114902 0.102968 0.868485 0.915455 4 None 10 21 0.853294 0.923227 0.832477 0.925860 0.877540 0.912467 0.858289 0.911706 0.921411 0.904015 0.003708 0.000234 0.030008 0.008032
7 0.124346 0.104010 0.865744 0.920027 4 None 20 27 0.840711 0.924174 0.817457 0.931454 0.882955 0.915913 0.867781 0.915767 0.920534 0.912828 0.004802 0.002074 0.035365 0.006851
8 0.155140 0.103040 0.865841 0.920221 4 None 50 26 0.837813 0.925367 0.813373 0.932243 0.890842 0.917293 0.880348 0.920427 0.907515 0.905777 0.010909 0.000223 0.034968 0.008812
9 0.115044 0.103578 0.873432 0.912243 5 auto 10 6 0.871344 0.926394 0.838867 0.917136 0.876805 0.908015 0.864171 0.911332 0.916419 0.898337 0.001589 0.000983 0.025048 0.009341
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
26 0.139350 0.102783 0.867716 0.959754 6 None 50 22 0.838669 0.961872 0.815547 0.965982 0.896925 0.953411 0.864973 0.960478 0.923233 0.957026 0.001827 0.000621 0.038736 0.004277
27 0.112384 0.102876 0.866479 0.946415 7 auto 10 24 0.851713 0.951426 0.813900 0.955706 0.886230 0.944403 0.878409 0.940450 0.902725 0.940088 0.001999 0.000697 0.031089 0.006182
28 0.123256 0.103252 0.876484 0.950703 7 auto 20 1 0.859618 0.956945 0.828195 0.960838 0.880214 0.945073 0.888636 0.950493 0.926403 0.940167 0.006665 0.000627 0.032449 0.007542
29 0.143406 0.102811 0.875191 0.955443 7 auto 50 2 0.863373 0.958969 0.829315 0.963099 0.895120 0.948007 0.865909 0.955843 0.922828 0.951298 0.003635 0.000405 0.031595 0.005360
30 0.114750 0.103571 0.860626 0.945636 7 log2 10 35 0.851713 0.951426 0.815942 0.951384 0.889639 0.941144 0.843650 0.944137 0.902725 0.940088 0.002156 0.000810 0.031542 0.004894
31 0.124344 0.105311 0.871964 0.952407 7 log2 20 8 0.845455 0.955351 0.831357 0.961618 0.887701 0.947384 0.888636 0.950493 0.907245 0.947188 0.003814 0.004163 0.028718 0.005471
32 0.149512 0.103868 0.869359 0.952323 7 log2 50 17 0.854348 0.956273 0.819433 0.959190 0.895120 0.948007 0.860294 0.950850 0.918241 0.947292 0.010145 0.001012 0.034201 0.004666
33 0.112530 0.104284 0.871560 0.968032 7 None 10 9 0.845652 0.974079 0.813768 0.975193 0.899666 0.959833 0.884759 0.969622 0.914665 0.961430 0.001891 0.000838 0.036980 0.006343
34 0.120763 0.102949 0.864654 0.971789 7 None 20 32 0.841700 0.974059 0.811858 0.977067 0.895455 0.967191 0.858422 0.974253 0.916554 0.966378 0.001116 0.000348 0.037393 0.004231
35 0.139771 0.103192 0.871335 0.973269 7 None 50 10 0.838603 0.973967 0.812319 0.978377 0.910227 0.970209 0.881952 0.974415 0.914328 0.969376 0.002513 0.000425 0.040085 0.003238
\n", "

36 rows × 22 columns

\n", "
" ], "text/plain": [ " mean_fit_time mean_score_time mean_test_score mean_train_score \\\n", "0 0.135095 0.105550 0.863078 0.893600 \n", "1 0.125255 0.103405 0.858664 0.895714 \n", "2 0.182410 0.104324 0.866282 0.896997 \n", "3 0.132539 0.104010 0.864671 0.891364 \n", "4 0.152971 0.104265 0.864805 0.895517 \n", "5 0.152543 0.103466 0.865151 0.896761 \n", "6 0.114902 0.102968 0.868485 0.915455 \n", "7 0.124346 0.104010 0.865744 0.920027 \n", "8 0.155140 0.103040 0.865841 0.920221 \n", "9 0.115044 0.103578 0.873432 0.912243 \n", "10 0.119204 0.105111 0.872546 0.914753 \n", "11 0.154862 0.103477 0.870855 0.917777 \n", "12 0.115686 0.103893 0.868964 0.910325 \n", "13 0.128181 0.103131 0.863800 0.910893 \n", "14 0.160281 0.103907 0.869836 0.917284 \n", "15 0.117599 0.104991 0.873508 0.934520 \n", "16 0.126553 0.105983 0.869542 0.940223 \n", "17 0.153157 0.103327 0.871002 0.942536 \n", "18 0.118098 0.103806 0.865280 0.930026 \n", "19 0.127066 0.103749 0.868677 0.932113 \n", "20 0.141172 0.103467 0.874621 0.936588 \n", "21 0.115320 0.103669 0.874284 0.928126 \n", "22 0.121289 0.103003 0.869396 0.932330 \n", "23 0.151305 0.103404 0.868924 0.938588 \n", "24 0.112815 0.103616 0.869369 0.953285 \n", "25 0.120892 0.102985 0.867146 0.957426 \n", "26 0.139350 0.102783 0.867716 0.959754 \n", "27 0.112384 0.102876 0.866479 0.946415 \n", "28 0.123256 0.103252 0.876484 0.950703 \n", "29 0.143406 0.102811 0.875191 0.955443 \n", "30 0.114750 0.103571 0.860626 0.945636 \n", "31 0.124344 0.105311 0.871964 0.952407 \n", "32 0.149512 0.103868 0.869359 0.952323 \n", "33 0.112530 0.104284 0.871560 0.968032 \n", "34 0.120763 0.102949 0.864654 0.971789 \n", "35 0.139771 0.103192 0.871335 0.973269 \n", "\n", " param_max_depth param_max_features param_n_estimators rank_test_score \\\n", "0 4 auto 10 34 \n", "1 4 auto 20 36 \n", "2 4 auto 50 25 \n", "3 4 log2 10 31 \n", "4 4 log2 20 30 \n", "5 4 log2 50 29 \n", "6 4 None 10 21 \n", "7 4 None 20 27 \n", "8 4 None 50 26 \n", "9 5 auto 10 6 \n", 
"10 5 auto 20 7 \n", "11 5 auto 50 12 \n", "12 5 log2 10 18 \n", "13 5 log2 20 33 \n", "14 5 log2 50 13 \n", "15 5 None 10 5 \n", "16 5 None 20 14 \n", "17 5 None 50 11 \n", "18 6 auto 10 28 \n", "19 6 auto 20 20 \n", "20 6 auto 50 3 \n", "21 6 log2 10 4 \n", "22 6 log2 20 15 \n", "23 6 log2 50 19 \n", "24 6 None 10 16 \n", "25 6 None 20 23 \n", "26 6 None 50 22 \n", "27 7 auto 10 24 \n", "28 7 auto 20 1 \n", "29 7 auto 50 2 \n", "30 7 log2 10 35 \n", "31 7 log2 20 8 \n", "32 7 log2 50 17 \n", "33 7 None 10 9 \n", "34 7 None 20 32 \n", "35 7 None 50 10 \n", "\n", " split0_test_score split0_train_score ... split2_test_score \\\n", "0 0.857510 0.893122 ... 0.862099 \n", "1 0.854414 0.900206 ... 0.870053 \n", "2 0.856061 0.900569 ... 0.872928 \n", "3 0.862516 0.889626 ... 0.866511 \n", "4 0.863900 0.898929 ... 0.863837 \n", "5 0.853821 0.897198 ... 0.871524 \n", "6 0.853294 0.923227 ... 0.877540 \n", "7 0.840711 0.924174 ... 0.882955 \n", "8 0.837813 0.925367 ... 0.890842 \n", "9 0.871344 0.926394 ... 0.876805 \n", "10 0.864163 0.911846 ... 0.888102 \n", "11 0.861067 0.918363 ... 0.876537 \n", "12 0.854875 0.916831 ... 0.873864 \n", "13 0.861924 0.915342 ... 0.877807 \n", "14 0.858893 0.917345 ... 0.883556 \n", "15 0.851713 0.943653 ... 0.905013 \n", "16 0.846706 0.942940 ... 0.898529 \n", "17 0.843083 0.946765 ... 0.897527 \n", "18 0.852701 0.938371 ... 0.873797 \n", "19 0.873847 0.941171 ... 0.882019 \n", "20 0.864097 0.941012 ... 0.895521 \n", "21 0.867852 0.927466 ... 0.893048 \n", "22 0.873847 0.941171 ... 0.879947 \n", "23 0.855599 0.940595 ... 0.885094 \n", "24 0.826943 0.951058 ... 0.914906 \n", "25 0.846772 0.957091 ... 0.890775 \n", "26 0.838669 0.961872 ... 0.896925 \n", "27 0.851713 0.951426 ... 0.886230 \n", "28 0.859618 0.956945 ... 0.880214 \n", "29 0.863373 0.958969 ... 0.895120 \n", "30 0.851713 0.951426 ... 0.889639 \n", "31 0.845455 0.955351 ... 0.887701 \n", "32 0.854348 0.956273 ... 0.895120 \n", "33 0.845652 0.974079 ... 
0.899666 \n", "34 0.841700 0.974059 ... 0.895455 \n", "35 0.838603 0.973967 ... 0.910227 \n", "\n", " split2_train_score split3_test_score split3_train_score \\\n", "0 0.888021 0.867513 0.890415 \n", "1 0.891721 0.846858 0.895167 \n", "2 0.894871 0.870254 0.896858 \n", "3 0.888491 0.874198 0.896073 \n", "4 0.889023 0.865307 0.902270 \n", "5 0.892988 0.863102 0.899436 \n", "6 0.912467 0.858289 0.911706 \n", "7 0.915913 0.867781 0.915767 \n", "8 0.917293 0.880348 0.920427 \n", "9 0.908015 0.864171 0.911332 \n", "10 0.911311 0.882821 0.915551 \n", "11 0.913560 0.883757 0.920066 \n", "12 0.903800 0.875401 0.907853 \n", "13 0.908289 0.853075 0.910027 \n", "14 0.911827 0.871791 0.919288 \n", "15 0.927203 0.873463 0.937802 \n", "16 0.934261 0.863570 0.942944 \n", "17 0.936801 0.874332 0.944324 \n", "18 0.923042 0.868783 0.930520 \n", "19 0.928737 0.857420 0.935046 \n", "20 0.935013 0.871658 0.936489 \n", "21 0.920988 0.870388 0.932694 \n", "22 0.930644 0.869452 0.926633 \n", "23 0.932690 0.871658 0.936489 \n", "24 0.952850 0.873663 0.952114 \n", "25 0.952376 0.872326 0.958831 \n", "26 0.953411 0.864973 0.960478 \n", "27 0.944403 0.878409 0.940450 \n", "28 0.945073 0.888636 0.950493 \n", "29 0.948007 0.865909 0.955843 \n", "30 0.941144 0.843650 0.944137 \n", "31 0.947384 0.888636 0.950493 \n", "32 0.948007 0.860294 0.950850 \n", "33 0.959833 0.884759 0.969622 \n", "34 0.967191 0.858422 0.974253 \n", "35 0.970209 0.881952 0.974415 \n", "\n", " split4_test_score split4_train_score std_fit_time std_score_time \\\n", "0 0.901444 0.888500 0.012107 0.000627 \n", "1 0.902388 0.889026 0.006131 0.000456 \n", "2 0.902928 0.885144 0.012350 0.001152 \n", "3 0.895507 0.879915 0.011379 0.000705 \n", "4 0.899083 0.885501 0.024055 0.000832 \n", "5 0.901039 0.887989 0.006887 0.000839 \n", "6 0.921411 0.904015 0.003708 0.000234 \n", "7 0.920534 0.912828 0.004802 0.002074 \n", "8 0.907515 0.905777 0.010909 0.000223 \n", "9 0.916419 0.898337 0.001589 0.000983 \n", "10 0.911360 0.911961 
0.001261 0.000830 \n", "11 0.909404 0.907590 0.008851 0.000683 \n", "12 0.907178 0.900527 0.003163 0.000692 \n", "13 0.908729 0.903243 0.004763 0.000626 \n", "14 0.912979 0.911894 0.023513 0.001555 \n", "15 0.925594 0.921155 0.007179 0.001646 \n", "16 0.922895 0.933975 0.005028 0.003048 \n", "17 0.922895 0.933033 0.008430 0.000568 \n", "18 0.915677 0.925245 0.006188 0.000669 \n", "19 0.910753 0.916004 0.006584 0.000852 \n", "20 0.918173 0.924498 0.001793 0.000669 \n", "21 0.923570 0.918161 0.002209 0.001296 \n", "22 0.910753 0.916004 0.001124 0.000493 \n", "23 0.917701 0.932735 0.007009 0.001459 \n", "24 0.914126 0.948984 0.001158 0.000908 \n", "25 0.916892 0.954931 0.002987 0.000276 \n", "26 0.923233 0.957026 0.001827 0.000621 \n", "27 0.902725 0.940088 0.001999 0.000697 \n", "28 0.926403 0.940167 0.006665 0.000627 \n", "29 0.922828 0.951298 0.003635 0.000405 \n", "30 0.902725 0.940088 0.002156 0.000810 \n", "31 0.907245 0.947188 0.003814 0.004163 \n", "32 0.918241 0.947292 0.010145 0.001012 \n", "33 0.914665 0.961430 0.001891 0.000838 \n", "34 0.916554 0.966378 0.001116 0.000348 \n", "35 0.914328 0.969376 0.002513 0.000425 \n", "\n", " std_test_score std_train_score \n", "0 0.023671 0.007392 \n", "1 0.027130 0.005031 \n", "2 0.023842 0.007334 \n", "3 0.022897 0.007661 \n", "4 0.021127 0.006928 \n", "5 0.021291 0.006123 \n", "6 0.030008 0.008032 \n", "7 0.035365 0.006851 \n", "8 0.034968 0.008812 \n", "9 0.025048 0.009341 \n", "10 0.031713 0.004435 \n", "11 0.028184 0.007212 \n", "12 0.024264 0.008222 \n", "13 0.029759 0.005106 \n", "14 0.029695 0.005292 \n", "15 0.039749 0.008884 \n", "16 0.037490 0.005201 \n", "17 0.037422 0.006772 \n", "18 0.032245 0.005479 \n", "19 0.029955 0.009140 \n", "20 0.031590 0.007148 \n", "21 0.034914 0.008315 \n", "22 0.031521 0.010964 \n", "23 0.033786 0.006600 \n", "24 0.041327 0.004270 \n", "25 0.036813 0.003892 \n", "26 0.038736 0.004277 \n", "27 0.031089 0.006182 \n", "28 0.032449 0.007542 \n", "29 0.031595 0.005360 \n", "30 
0.031542 0.004894 \n", "31 0.028718 0.005471 \n", "32 0.034201 0.004666 \n", "33 0.036980 0.006343 \n", "34 0.037393 0.004231 \n", "35 0.040085 0.003238 \n", "\n", "[36 rows x 22 columns]" ] }, "execution_count": 44, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pd.DataFrame.(data: rfc.cv_results_).drop.(:params, axis: 1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's look at the feature importances of the best-performing random forest model.\n", "\n", "The best-performing model is available as `best_estimator_`.\n", "Because this model is an instance of RandomForestClassifier, it has a `feature_importances_` attribute.\n", "We visualize these values together with `x_names` using seaborn's barplot." ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" }, { "data": { "text/plain": [ "" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Pair each feature name with its importance in the best estimator, then plot the pairs as a bar chart\n", "df_importance = pd.DataFrame.(data: {\n", " name: x_names,\n", " importance: rfc.best_estimator_.feature_importances_\n", "})\n", "sns.barplot.(x: :name, y: :importance, data: df_importance)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can see that `adult_male` and the sex columns (`female`, `male`) contribute strongly.\n", "In contrast, `alone`, `C`, and `Q` contribute almost nothing.\n", "\n", "Let's look at the correlation matrix between columns once more." ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "
survived pclass age sibsp parch fare adult_male alone female male C Q
survived 1.000000 -0.338481 -0.064910 -0.035322 0.081629 0.257307 -0.557080 -0.203367 0.543351 -0.543351 0.168240 0.003650
pclass -0.338481 1.000000 -0.339898 0.083081 0.018443 -0.549500 0.094035 0.135207 -0.131900 0.131900 -0.243292 0.221009
age -0.064910 -0.339898 1.000000 -0.233296 -0.172482 0.096688 0.247704 0.171647 -0.081163 0.081163 0.030248 -0.031415
sibsp -0.035322 0.083081 -0.233296 1.000000 0.414838 0.159651 -0.253586 -0.584471 0.114631 -0.114631 -0.059528 -0.026354
parch 0.081629 0.018443 -0.172482 0.414838 1.000000 0.216225 -0.349943 -0.583398 0.245489 -0.245489 -0.011069 -0.081228
fare 0.257307 -0.549500 0.096688 0.159651 0.216225 1.000000 -0.182024 -0.271832 0.182333 -0.182333 0.269335 -0.117216
adult_male -0.557080 0.094035 0.247704 -0.253586 -0.349943 -0.182024 1.000000 0.404744 -0.908578 0.908578 -0.065980 -0.076789
alone -0.203367 0.135207 0.171647 -0.584471 -0.583398 -0.271832 0.404744 1.000000 -0.303646 0.303646 -0.095298 0.086464
female 0.543351 -0.131900 -0.081163 0.114631 0.245489 0.182333 -0.908578 -0.303646 1.000000 -1.000000 0.082853 0.074115
male -0.543351 0.131900 0.081163 -0.114631 -0.245489 -0.182333 0.908578 0.303646 -1.000000 1.000000 -0.082853 -0.074115
C 0.168240 -0.243292 0.030248 -0.059528 -0.011069 0.269335 -0.065980 -0.095298 0.082853 -0.082853 1.000000 -0.148258
Q 0.003650 0.221009 -0.031415 -0.026354 -0.081228 -0.117216 -0.076789 0.086464 0.074115 -0.074115 -0.148258 1.000000
\n", "

12 rows × 12 columns

\n", "
" ], "text/plain": [ " survived pclass age sibsp parch fare \\\n", "survived 1.000000 -0.338481 -0.064910 -0.035322 0.081629 0.257307 \n", "pclass -0.338481 1.000000 -0.339898 0.083081 0.018443 -0.549500 \n", "age -0.064910 -0.339898 1.000000 -0.233296 -0.172482 0.096688 \n", "sibsp -0.035322 0.083081 -0.233296 1.000000 0.414838 0.159651 \n", "parch 0.081629 0.018443 -0.172482 0.414838 1.000000 0.216225 \n", "fare 0.257307 -0.549500 0.096688 0.159651 0.216225 1.000000 \n", "adult_male -0.557080 0.094035 0.247704 -0.253586 -0.349943 -0.182024 \n", "alone -0.203367 0.135207 0.171647 -0.584471 -0.583398 -0.271832 \n", "female 0.543351 -0.131900 -0.081163 0.114631 0.245489 0.182333 \n", "male -0.543351 0.131900 0.081163 -0.114631 -0.245489 -0.182333 \n", "C 0.168240 -0.243292 0.030248 -0.059528 -0.011069 0.269335 \n", "Q 0.003650 0.221009 -0.031415 -0.026354 -0.081228 -0.117216 \n", "\n", " adult_male alone female male C Q \n", "survived -0.557080 -0.203367 0.543351 -0.543351 0.168240 0.003650 \n", "pclass 0.094035 0.135207 -0.131900 0.131900 -0.243292 0.221009 \n", "age 0.247704 0.171647 -0.081163 0.081163 0.030248 -0.031415 \n", "sibsp -0.253586 -0.584471 0.114631 -0.114631 -0.059528 -0.026354 \n", "parch -0.349943 -0.583398 0.245489 -0.245489 -0.011069 -0.081228 \n", "fare -0.182024 -0.271832 0.182333 -0.182333 0.269335 -0.117216 \n", "adult_male 1.000000 0.404744 -0.908578 0.908578 -0.065980 -0.076789 \n", "alone 0.404744 1.000000 -0.303646 0.303646 -0.095298 0.086464 \n", "female -0.908578 -0.303646 1.000000 -1.000000 0.082853 0.074115 \n", "male 0.908578 0.303646 -1.000000 1.000000 -0.082853 -0.074115 \n", "C -0.065980 -0.095298 0.082853 -0.082853 1.000000 -0.148258 \n", "Q -0.076789 0.086464 0.074115 -0.074115 -0.148258 1.000000 " ] }, "execution_count": 46, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.corr.()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`adult_male`, `female`, `male`, 
all have correlation coefficients with `survived` above 0.5 in absolute value, and they also ranked high in feature importance.\n", "However, `fare` and `alone` have correlation coefficients of similar magnitude, yet `fare` is about as important a feature as `female`, whereas `alone` was the least important feature.\n", "In other words, looking at correlation coefficients alone does not tell us how important a feature will be for classification." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Building a classification model with logistic regression" ] }, { "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "GridSearchCV(cv=5, error_score='raise',\n", " estimator=LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,\n", " intercept_scaling=1, max_iter=100, multi_class='ovr', n_jobs=2,\n", " penalty='l2', random_state=None, solver='liblinear', tol=0.0001,\n", " verbose=0, warm_start=False),\n", " fit_params={}, iid=True, n_jobs=4,\n", " param_grid={'penalty': ['l2', 'l1'], 'C': [10.0, 1.0, 0.1, 0.01]},\n", " pre_dispatch='2*n_jobs', refit=True, return_train_score=True,\n", " scoring='roc_auc', verbose=0)" ] }, "execution_count": 47, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Grid search over the penalty type and the regularization strength C, using 5-fold CV scored by ROC AUC\n", "lrc = GridSearchCV.(\n", " LogisticRegression.(n_jobs: 2),\n", " {\n", " penalty: [:l2, :l1],\n", " C: [10.0, 1.0, 0.1, 0.01],\n", " },\n", " scoring: :roc_auc,\n", " n_jobs: 4,\n", " cv: 5\n", ")" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "GridSearchCV(cv=5, error_score='raise',\n", " estimator=LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,\n", " intercept_scaling=1, max_iter=100, multi_class='ovr', n_jobs=2,\n", " penalty='l2', random_state=None, solver='liblinear', tol=0.0001,\n", " verbose=0, warm_start=False),\n", " fit_params={}, iid=True, n_jobs=4,\n", " param_grid={'penalty': ['l2', 'l1'], 'C': [10.0, 1.0, 0.1, 0.01]},\n", " pre_dispatch='2*n_jobs', refit=True, return_train_score=True,\n", " scoring='roc_auc', verbose=0)" ] }, "execution_count": 48, "metadata": {}, "output_type": "execute_result" } ], "source": [ "lrc.fit.(x, y)" ] }, { 
"cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'C': 10.0, 'penalty': 'l1'}" ] }, "execution_count": 49, "metadata": {}, "output_type": "execute_result" } ], "source": [ "lrc.best_params_" ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n", 
"<table>\n", 
"<thead>\n", 
"<tr><th></th><th>mean_fit_time</th><th>mean_score_time</th><th>mean_test_score</th><th>mean_train_score</th><th>param_C</th><th>param_penalty</th><th>rank_test_score</th><th>split0_test_score</th><th>split0_train_score</th><th>split1_test_score</th><th>split1_train_score</th><th>split2_test_score</th><th>split2_train_score</th><th>split3_test_score</th><th>split3_train_score</th><th>split4_test_score</th><th>split4_train_score</th><th>std_fit_time</th><th>std_score_time</th><th>std_test_score</th><th>std_train_score</th></tr>\n", 
"</thead>\n", 
"<tbody>\n", 
"<tr><th>0</th><td>0.019578</td><td>0.002549</td><td>0.861875</td><td>0.870107</td><td>10</td><td>l2</td><td>2</td><td>0.851647</td><td>0.873222</td><td>0.839262</td><td>0.875963</td><td>0.867447</td><td>0.869257</td><td>0.854479</td><td>0.871303</td><td>0.896924</td><td>0.860791</td><td>0.003053</td><td>0.000502</td><td>0.019619</td><td>0.005156</td></tr>\n", 
"<tr><th>1</th><td>0.057700</td><td>0.001965</td><td>0.861902</td><td>0.870074</td><td>10</td><td>l1</td><td>1</td><td>0.852701</td><td>0.873080</td><td>0.838472</td><td>0.876004</td><td>0.867848</td><td>0.869216</td><td>0.853543</td><td>0.871253</td><td>0.897329</td><td>0.860816</td><td>0.030385</td><td>0.000329</td><td>0.019941</td><td>0.005139</td></tr>\n", 
"<tr><th>2</th><td>0.007003</td><td>0.001180</td><td>0.861099</td><td>0.869869</td><td>1</td><td>l2</td><td>4</td><td>0.852569</td><td>0.871678</td><td>0.838472</td><td>0.877331</td><td>0.865575</td><td>0.868343</td><td>0.854479</td><td>0.870737</td><td>0.894765</td><td>0.861256</td><td>0.001367</td><td>0.000282</td><td>0.018852</td><td>0.005220</td></tr>\n", 
"<tr><th>3</th><td>0.019780</td><td>0.001720</td><td>0.861548</td><td>0.870184</td><td>1</td><td>l1</td><td>3</td><td>0.854150</td><td>0.872579</td><td>0.840053</td><td>0.876972</td><td>0.865709</td><td>0.868526</td><td>0.851939</td><td>0.871685</td><td>0.896249</td><td>0.861156</td><td>0.005793</td><td>0.000564</td><td>0.019103</td><td>0.005260</td></tr>\n", 
"<tr><th>4</th><td>0.008617</td><td>0.001869</td><td>0.854299</td><td>0.863025</td><td>0.1</td><td>l2</td><td>5</td><td>0.843742</td><td>0.862508</td><td>0.838076</td><td>0.868503</td><td>0.849131</td><td>0.863521</td><td>0.854746</td><td>0.866447</td><td>0.886131</td><td>0.854147</td><td>0.001068</td><td>0.000648</td><td>0.016791</td><td>0.004920</td></tr>\n", 
"<tr><th>5</th><td>0.006309</td><td>0.001280</td><td>0.850051</td><td>0.858238</td><td>0.1</td><td>l1</td><td>6</td><td>0.850725</td><td>0.855441</td><td>0.835705</td><td>0.861778</td><td>0.842848</td><td>0.859580</td><td>0.844385</td><td>0.862124</td><td>0.876821</td><td>0.852264</td><td>0.001306</td><td>0.000357</td><td>0.014163</td><td>0.003820</td></tr>\n", 
"<tr><th>6</th><td>0.008014</td><td>0.001878</td><td>0.841188</td><td>0.851557</td><td>0.01</td><td>l2</td><td>7</td><td>0.813570</td><td>0.851761</td><td>0.837022</td><td>0.855086</td><td>0.829345</td><td>0.851915</td><td>0.842179</td><td>0.853162</td><td>0.884242</td><td>0.845861</td><td>0.001652</td><td>0.000428</td><td>0.023519</td><td>0.003087</td></tr>\n", 
"<tr><th>7</th><td>0.005691</td><td>0.001634</td><td>0.668766</td><td>0.678327</td><td>0.01</td><td>l1</td><td>8</td><td>0.569960</td><td>0.714090</td><td>0.730040</td><td>0.674669</td><td>0.643115</td><td>0.684552</td><td>0.696056</td><td>0.661370</td><td>0.705073</td><td>0.656955</td><td>0.001606</td><td>0.000693</td><td>0.057076</td><td>0.020374</td></tr>\n", 
"</tbody>\n", 
"</table>\n", 
"<p>8 rows × 21 columns</p>\n", 
"</div>
" ], "text/plain": [ " mean_fit_time mean_score_time mean_test_score mean_train_score param_C \\\n", "0 0.019578 0.002549 0.861875 0.870107 10 \n", "1 0.057700 0.001965 0.861902 0.870074 10 \n", "2 0.007003 0.001180 0.861099 0.869869 1 \n", "3 0.019780 0.001720 0.861548 0.870184 1 \n", "4 0.008617 0.001869 0.854299 0.863025 0.1 \n", "5 0.006309 0.001280 0.850051 0.858238 0.1 \n", "6 0.008014 0.001878 0.841188 0.851557 0.01 \n", "7 0.005691 0.001634 0.668766 0.678327 0.01 \n", "\n", " param_penalty rank_test_score split0_test_score split0_train_score \\\n", "0 l2 2 0.851647 0.873222 \n", "1 l1 1 0.852701 0.873080 \n", "2 l2 4 0.852569 0.871678 \n", "3 l1 3 0.854150 0.872579 \n", "4 l2 5 0.843742 0.862508 \n", "5 l1 6 0.850725 0.855441 \n", "6 l2 7 0.813570 0.851761 \n", "7 l1 8 0.569960 0.714090 \n", "\n", " split1_test_score ... split2_test_score split2_train_score \\\n", "0 0.839262 ... 0.867447 0.869257 \n", "1 0.838472 ... 0.867848 0.869216 \n", "2 0.838472 ... 0.865575 0.868343 \n", "3 0.840053 ... 0.865709 0.868526 \n", "4 0.838076 ... 0.849131 0.863521 \n", "5 0.835705 ... 0.842848 0.859580 \n", "6 0.837022 ... 0.829345 0.851915 \n", "7 0.730040 ... 
0.643115 0.684552 \n", "\n", " split3_test_score split3_train_score split4_test_score \\\n", "0 0.854479 0.871303 0.896924 \n", "1 0.853543 0.871253 0.897329 \n", "2 0.854479 0.870737 0.894765 \n", "3 0.851939 0.871685 0.896249 \n", "4 0.854746 0.866447 0.886131 \n", "5 0.844385 0.862124 0.876821 \n", "6 0.842179 0.853162 0.884242 \n", "7 0.696056 0.661370 0.705073 \n", "\n", " split4_train_score std_fit_time std_score_time std_test_score \\\n", "0 0.860791 0.003053 0.000502 0.019619 \n", "1 0.860816 0.030385 0.000329 0.019941 \n", "2 0.861256 0.001367 0.000282 0.018852 \n", "3 0.861156 0.005793 0.000564 0.019103 \n", "4 0.854147 0.001068 0.000648 0.016791 \n", "5 0.852264 0.001306 0.000357 0.014163 \n", "6 0.845861 0.001652 0.000428 0.023519 \n", "7 0.656955 0.001606 0.000693 0.057076 \n", "\n", " std_train_score \n", "0 0.005156 \n", "1 0.005139 \n", "2 0.005220 \n", "3 0.005260 \n", "4 0.004920 \n", "5 0.003820 \n", "6 0.003087 \n", "7 0.020374 \n", "\n", "[8 rows x 21 columns]" ] }, "execution_count": 50, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pd.DataFrame.(data: lrc.cv_results_).drop.(:params, axis: 1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Building a classification model with a support vector machine" ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "GridSearchCV(cv=5, error_score='raise',\n", " estimator=SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,\n", " decision_function_shape=None, degree=3, gamma='auto', kernel='rbf',\n", " max_iter=-1, probability=False, random_state=None, shrinking=True,\n", " tol=0.001, verbose=False),\n", " fit_params={}, iid=True, n_jobs=4,\n", " param_grid={'C': [10.0, 1.0, 0.1, 0.01], 'gamma': [0.2, 0.1, 0.06666666666666667, 0.05]},\n", " pre_dispatch='2*n_jobs', refit=True, return_train_score=True,\n", " scoring='roc_auc', verbose=0)" ] }, "execution_count": 51, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Grid search over C and the RBF kernel parameter gamma, using 5-fold CV scored by ROC AUC\n", "svc = 
GridSearchCV.(\n", " SVC.(kernel: :rbf),\n", " {\n", " C: [10.0, 1.0, 0.1, 0.01],\n", " gamma: [5, 10, 15, 20].map {|x| 1.0 / x },\n", " },\n", " scoring: :roc_auc,\n", " n_jobs: 4,\n", " cv: 5\n", ")" ] }, { "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "GridSearchCV(cv=5, error_score='raise',\n", " estimator=SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,\n", " decision_function_shape=None, degree=3, gamma='auto', kernel='rbf',\n", " max_iter=-1, probability=False, random_state=None, shrinking=True,\n", " tol=0.001, verbose=False),\n", " fit_params={}, iid=True, n_jobs=4,\n", " param_grid={'C': [10.0, 1.0, 0.1, 0.01], 'gamma': [0.2, 0.1, 0.06666666666666667, 0.05]},\n", " pre_dispatch='2*n_jobs', refit=True, return_train_score=True,\n", " scoring='roc_auc', verbose=0)" ] }, "execution_count": 52, "metadata": {}, "output_type": "execute_result" } ], "source": [ "svc.fit.(x, y)" ] }, { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'C': 1.0, 'gamma': 0.05}" ] }, "execution_count": 53, "metadata": {}, "output_type": "execute_result" } ], "source": [ "svc.best_params_" ] }, { "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.7861745269974777" ] }, "execution_count": 54, "metadata": {}, "output_type": "execute_result" } ], "source": [ "svc.best_score_" ] }, { "cell_type": "code", "execution_count": 55, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n", 
"<table>\n", 
"<thead>\n", 
"<tr><th></th><th>mean_fit_time</th><th>mean_score_time</th><th>mean_test_score</th><th>mean_train_score</th><th>param_C</th><th>param_gamma</th><th>rank_test_score</th><th>split0_test_score</th><th>split0_train_score</th><th>split1_test_score</th><th>split1_train_score</th><th>split2_test_score</th><th>split2_train_score</th><th>split3_test_score</th><th>split3_train_score</th><th>split4_test_score</th><th>split4_train_score</th><th>std_fit_time</th><th>std_score_time</th><th>std_test_score</th><th>std_train_score</th></tr>\n", 
"</thead>\n", 
"<tbody>\n", 
"<tr><th>0</th><td>0.079361</td><td>0.008525</td><td>0.764278</td><td>0.973148</td><td>10</td><td>0.2</td><td>7</td><td>0.706588</td><td>0.977446</td><td>0.739789</td><td>0.973087</td><td>0.792179</td><td>0.978713</td><td>0.796190</td><td>0.967224</td><td>0.787237</td><td>0.969268</td><td>0.018322</td><td>0.001247</td><td>0.035395</td><td>0.004463</td></tr>\n", 
"<tr><th>1</th><td>0.042405</td><td>0.005943</td><td>0.764156</td><td>0.969342</td><td>10</td><td>0.1</td><td>8</td><td>0.698419</td><td>0.975460</td><td>0.741634</td><td>0.969866</td><td>0.786832</td><td>0.975754</td><td>0.807821</td><td>0.964347</td><td>0.786697</td><td>0.961281</td><td>0.007624</td><td>0.001010</td><td>0.039429</td><td>0.005809</td></tr>\n", 
"<tr><th>2</th><td>0.038538</td><td>0.004008</td><td>0.770427</td><td>0.964944</td><td>10</td><td>0.0666667</td><td>6</td><td>0.731884</td><td>0.966557</td><td>0.740975</td><td>0.965961</td><td>0.777072</td><td>0.967174</td><td>0.814906</td><td>0.962710</td><td>0.787777</td><td>0.962318</td><td>0.004063</td><td>0.000151</td><td>0.030603</td><td>0.002025</td></tr>\n", 
"<tr><th>3</th><td>0.032319</td><td>0.004181</td><td>0.771351</td><td>0.957600</td><td>10</td><td>0.05</td><td>5</td><td>0.722793</td><td>0.957905</td><td>0.746113</td><td>0.959169</td><td>0.769184</td><td>0.961754</td><td>0.819452</td><td>0.956699</td><td>0.799784</td><td>0.952472</td><td>0.003546</td><td>0.000336</td><td>0.035006</td><td>0.003063</td></tr>\n", 
"<tr><th>4</th><td>0.024443</td><td>0.005262</td><td>0.776894</td><td>0.956813</td><td>1</td><td>0.2</td><td>3</td><td>0.722661</td><td>0.957220</td><td>0.741502</td><td>0.959978</td><td>0.784158</td><td>0.962236</td><td>0.827340</td><td>0.952409</td><td>0.809498</td><td>0.952223</td><td>0.000846</td><td>0.000805</td><td>0.039656</td><td>0.004001</td></tr>\n", 
"<tr><th>5</th><td>0.024468</td><td>0.004256</td><td>0.775444</td><td>0.941323</td><td>1</td><td>0.1</td><td>4</td><td>0.697497</td><td>0.945380</td><td>0.735968</td><td>0.945518</td><td>0.786430</td><td>0.948410</td><td>0.825735</td><td>0.934884</td><td>0.832569</td><td>0.932424</td><td>0.002496</td><td>0.000093</td><td>0.052070</td><td>0.006402</td></tr>\n", 
"<tr><th>6</th><td>0.023904</td><td>0.004472</td><td>0.781821</td><td>0.927249</td><td>1</td><td>0.0666667</td><td>2</td><td>0.700000</td><td>0.928359</td><td>0.742029</td><td>0.933845</td><td>0.787099</td><td>0.935283</td><td>0.835094</td><td>0.923703</td><td>0.845926</td><td>0.915055</td><td>0.001110</td><td>0.000268</td><td>0.055244</td><td>0.007356</td></tr>\n", 
"<tr><th>7</th><td>0.023374</td><td>0.004606</td><td>0.786175</td><td>0.918397</td><td>1</td><td>0.05</td><td>1</td><td>0.708827</td><td>0.924145</td><td>0.747694</td><td>0.929439</td><td>0.795789</td><td>0.925548</td><td>0.833623</td><td>0.907699</td><td>0.845926</td><td>0.905151</td><td>0.001236</td><td>0.000641</td><td>0.051738</td><td>0.009960</td></tr>\n", 
"<tr><th>8</th><td>0.026061</td><td>0.007847</td><td>0.762136</td><td>0.933353</td><td>0.1</td><td>0.2</td><td>9</td><td>0.697760</td><td>0.932472</td><td>0.715810</td><td>0.935155</td><td>0.769987</td><td>0.942865</td><td>0.820120</td><td>0.932241</td><td>0.807879</td><td>0.924030</td><td>0.003982</td><td>0.004482</td><td>0.048615</td><td>0.006044</td></tr>\n", 
"<tr><th>9</th><td>0.023314</td><td>0.008007</td><td>0.756844</td><td>0.909156</td><td>0.1</td><td>0.1</td><td>16</td><td>0.640053</td><td>0.902968</td><td>0.714229</td><td>0.915914</td><td>0.769452</td><td>0.915688</td><td>0.820922</td><td>0.905205</td><td>0.840934</td><td>0.906005</td><td>0.000844</td><td>0.004808</td><td>0.073225</td><td>0.005517</td></tr>\n", 
"<tr><th>10</th><td>0.028051</td><td>0.004808</td><td>0.758774</td><td>0.889690</td><td>0.1</td><td>0.0666667</td><td>12</td><td>0.623715</td><td>0.877310</td><td>0.728458</td><td>0.895588</td><td>0.777206</td><td>0.900051</td><td>0.820789</td><td>0.886109</td><td>0.845116</td><td>0.889391</td><td>0.005520</td><td>0.000427</td><td>0.078526</td><td>0.007852</td></tr>\n", 
"<tr><th>11</th><td>0.022324</td><td>0.004746</td><td>0.758699</td><td>0.875319</td><td>0.1</td><td>0.05</td><td>13</td><td>0.620949</td><td>0.860238</td><td>0.726746</td><td>0.879968</td><td>0.778008</td><td>0.885244</td><td>0.828543</td><td>0.871012</td><td>0.840664</td><td>0.880134</td><td>0.001216</td><td>0.000539</td><td>0.080005</td><td>0.008824</td></tr>\n", 
"<tr><th>12</th><td>0.023141</td><td>0.004723</td><td>0.762012</td><td>0.934755</td><td>0.01</td><td>0.2</td><td>10</td><td>0.697101</td><td>0.935292</td><td>0.717852</td><td>0.938409</td><td>0.770521</td><td>0.942732</td><td>0.819184</td><td>0.931966</td><td>0.806260</td><td>0.925373</td><td>0.002968</td><td>0.000146</td><td>0.047902</td><td>0.005883</td></tr>\n", 
"<tr><th>13</th><td>0.021730</td><td>0.004623</td><td>0.758100</td><td>0.908655</td><td>0.01</td><td>0.1</td><td>15</td><td>0.639262</td><td>0.901408</td><td>0.715415</td><td>0.916081</td><td>0.769853</td><td>0.916462</td><td>0.825468</td><td>0.903601</td><td>0.841878</td><td>0.905723</td><td>0.002484</td><td>0.000416</td><td>0.074376</td><td>0.006368</td></tr>\n", 
"<tr><th>14</th><td>0.025944</td><td>0.006189</td><td>0.758779</td><td>0.890756</td><td>0.01</td><td>0.0666667</td><td>11</td><td>0.621871</td><td>0.877185</td><td>0.727800</td><td>0.896906</td><td>0.780147</td><td>0.902345</td><td>0.820521</td><td>0.887522</td><td>0.844981</td><td>0.889822</td><td>0.003451</td><td>0.000704</td><td>0.079293</td><td>0.008575</td></tr>\n", 
"<tr><th>15</th><td>0.027601</td><td>0.007314</td><td>0.758561</td><td>0.875579</td><td>0.01</td><td>0.05</td><td>14</td><td>0.621212</td><td>0.862191</td><td>0.728327</td><td>0.880210</td><td>0.782687</td><td>0.884421</td><td>0.823596</td><td>0.872674</td><td>0.838370</td><td>0.878401</td><td>0.001207</td><td>0.001308</td><td>0.078741</td><td>0.007685</td></tr>\n", 
"</tbody>\n", 
"</table>\n", 
"<p>16 rows × 21 columns</p>\n", 
"</div>
" ], "text/plain": [ " mean_fit_time mean_score_time mean_test_score mean_train_score param_C \\\n", "0 0.079361 0.008525 0.764278 0.973148 10 \n", "1 0.042405 0.005943 0.764156 0.969342 10 \n", "2 0.038538 0.004008 0.770427 0.964944 10 \n", "3 0.032319 0.004181 0.771351 0.957600 10 \n", "4 0.024443 0.005262 0.776894 0.956813 1 \n", "5 0.024468 0.004256 0.775444 0.941323 1 \n", "6 0.023904 0.004472 0.781821 0.927249 1 \n", "7 0.023374 0.004606 0.786175 0.918397 1 \n", "8 0.026061 0.007847 0.762136 0.933353 0.1 \n", "9 0.023314 0.008007 0.756844 0.909156 0.1 \n", "10 0.028051 0.004808 0.758774 0.889690 0.1 \n", "11 0.022324 0.004746 0.758699 0.875319 0.1 \n", "12 0.023141 0.004723 0.762012 0.934755 0.01 \n", "13 0.021730 0.004623 0.758100 0.908655 0.01 \n", "14 0.025944 0.006189 0.758779 0.890756 0.01 \n", "15 0.027601 0.007314 0.758561 0.875579 0.01 \n", "\n", " param_gamma rank_test_score split0_test_score split0_train_score \\\n", "0 0.2 7 0.706588 0.977446 \n", "1 0.1 8 0.698419 0.975460 \n", "2 0.0666667 6 0.731884 0.966557 \n", "3 0.05 5 0.722793 0.957905 \n", "4 0.2 3 0.722661 0.957220 \n", "5 0.1 4 0.697497 0.945380 \n", "6 0.0666667 2 0.700000 0.928359 \n", "7 0.05 1 0.708827 0.924145 \n", "8 0.2 9 0.697760 0.932472 \n", "9 0.1 16 0.640053 0.902968 \n", "10 0.0666667 12 0.623715 0.877310 \n", "11 0.05 13 0.620949 0.860238 \n", "12 0.2 10 0.697101 0.935292 \n", "13 0.1 15 0.639262 0.901408 \n", "14 0.0666667 11 0.621871 0.877185 \n", "15 0.05 14 0.621212 0.862191 \n", "\n", " split1_test_score ... split2_test_score split2_train_score \\\n", "0 0.739789 ... 0.792179 0.978713 \n", "1 0.741634 ... 0.786832 0.975754 \n", "2 0.740975 ... 0.777072 0.967174 \n", "3 0.746113 ... 0.769184 0.961754 \n", "4 0.741502 ... 0.784158 0.962236 \n", "5 0.735968 ... 0.786430 0.948410 \n", "6 0.742029 ... 0.787099 0.935283 \n", "7 0.747694 ... 0.795789 0.925548 \n", "8 0.715810 ... 0.769987 0.942865 \n", "9 0.714229 ... 0.769452 0.915688 \n", "10 0.728458 ... 
0.777206 0.900051 \n", "11 0.726746 ... 0.778008 0.885244 \n", "12 0.717852 ... 0.770521 0.942732 \n", "13 0.715415 ... 0.769853 0.916462 \n", "14 0.727800 ... 0.780147 0.902345 \n", "15 0.728327 ... 0.782687 0.884421 \n", "\n", " split3_test_score split3_train_score split4_test_score \\\n", "0 0.796190 0.967224 0.787237 \n", "1 0.807821 0.964347 0.786697 \n", "2 0.814906 0.962710 0.787777 \n", "3 0.819452 0.956699 0.799784 \n", "4 0.827340 0.952409 0.809498 \n", "5 0.825735 0.934884 0.832569 \n", "6 0.835094 0.923703 0.845926 \n", "7 0.833623 0.907699 0.845926 \n", "8 0.820120 0.932241 0.807879 \n", "9 0.820922 0.905205 0.840934 \n", "10 0.820789 0.886109 0.845116 \n", "11 0.828543 0.871012 0.840664 \n", "12 0.819184 0.931966 0.806260 \n", "13 0.825468 0.903601 0.841878 \n", "14 0.820521 0.887522 0.844981 \n", "15 0.823596 0.872674 0.838370 \n", "\n", " split4_train_score std_fit_time std_score_time std_test_score \\\n", "0 0.969268 0.018322 0.001247 0.035395 \n", "1 0.961281 0.007624 0.001010 0.039429 \n", "2 0.962318 0.004063 0.000151 0.030603 \n", "3 0.952472 0.003546 0.000336 0.035006 \n", "4 0.952223 0.000846 0.000805 0.039656 \n", "5 0.932424 0.002496 0.000093 0.052070 \n", "6 0.915055 0.001110 0.000268 0.055244 \n", "7 0.905151 0.001236 0.000641 0.051738 \n", "8 0.924030 0.003982 0.004482 0.048615 \n", "9 0.906005 0.000844 0.004808 0.073225 \n", "10 0.889391 0.005520 0.000427 0.078526 \n", "11 0.880134 0.001216 0.000539 0.080005 \n", "12 0.925373 0.002968 0.000146 0.047902 \n", "13 0.905723 0.002484 0.000416 0.074376 \n", "14 0.889822 0.003451 0.000704 0.079293 \n", "15 0.878401 0.001207 0.001308 0.078741 \n", "\n", " std_train_score \n", "0 0.004463 \n", "1 0.005809 \n", "2 0.002025 \n", "3 0.003063 \n", "4 0.004001 \n", "5 0.006402 \n", "6 0.007356 \n", "7 0.009960 \n", "8 0.006044 \n", "9 0.005517 \n", "10 0.007852 \n", "11 0.008824 \n", "12 0.005883 \n", "13 0.006368 \n", "14 0.008575 \n", "15 0.007685 \n", "\n", "[16 rows x 21 columns]" ] }, 
"execution_count": 55, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pd.DataFrame.(data: svc.cv_results_).drop.(:params, axis: 1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 結果" ] }, { "cell_type": "code", "execution_count": 56, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "" ] }, "execution_count": 56, "metadata": {}, "output_type": "execute_result" }, { "data": { "text/plain": [ "" ] }, "execution_count": 56, "metadata": {}, "output_type": "execute_result" } ], "source": [ "result = pd.DataFrame.(data: {\n", " model: %w[RFC LRC SVC],\n", " score: [rfc.best_score_, lrc.best_score_, svc.best_score_]\n", "})\n", "sns.barplot.(x: :model, y: :score, data: result)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Ruby 2.4.0", "language": "ruby", "name": "ruby" }, "language_info": { "file_extension": ".rb", "mimetype": "application/x-ruby", "name": "ruby", "version": "2.4.0" } }, "nbformat": 4, "nbformat_minor": 2 }