Sha256: bc2f06316859261b82f1740b9df4a8f51a7a97afa270c75dec0876f4601d5a83

Contents?: true

Size: 1.31 KB

Versions: 4

Compression:

Stored size: 1.31 KB

Contents

module EasyML::Data::Dataset::Splitters
  # Splits a dataframe into train / validation / test partitions by comparing
  # a datetime column against two cutoffs measured backwards from +today+.
  #
  # Attributes:
  #   today        - reference point for the cutoffs; defaults to UTC.now
  #   date_col     - name of the column to split on
  #   months_test  - length of the test window in months (default 2)
  #   months_valid - length of the validation window in months (default 2)
  class DateSplitter
    include GlueGun::DSL

    attribute :today, :datetime

    # Normalises the assigned value before storing it: converts it with
    # +in_time_zone(UTC)+ and then to a DateTime.
    # NOTE(review): UTC is a constant defined outside this file, presumably a
    # UTC time zone object; confirm.
    #
    # @param value [#in_time_zone] the reference date/time
    def today=(value)
      super(value.in_time_zone(UTC).to_datetime)
    end
    attribute :date_col, :string
    attribute :months_test, :integer, default: 2
    attribute :months_valid, :integer, default: 2

    # @param options [Hash] attribute values; +:today+ falls back to UTC.now
    #   when it is missing or nil
    def initialize(options = {})
      # Merge into a fresh hash instead of `options[:today] ||= ...` so the
      # caller's hash is never mutated: a reused hash must not pin later
      # splitters to the first default timestamp, and frozen hashes must work.
      super(options.merge(today: options[:today] || UTC.now))
    end

    # Partitions +df+ by the value of +date_col+.
    #
    #   test  : date_col >= test_date_start
    #   valid : validation_date_start <= date_col < test_date_start
    #   train : date_col < validation_date_start
    #
    # @param df [Polars::DataFrame] frame containing +date_col+
    # @return [Array] +[train_df, valid_df, test_df]+
    # @raise [RuntimeError] when the column's dtype is not a Polars::Datetime
    def split(df)
      unless df[date_col].dtype.is_a?(Polars::Datetime)
        raise "Date splitter cannot split on non-date col #{date_col}, dtype is #{df[date_col].dtype}"
      end

      validation_date_start, test_date_start = splits

      test_df = df.filter(Polars.col(date_col) >= test_date_start)
      # Validation and train are carved out of the rows that precede the test window.
      remaining_df = df.filter(Polars.col(date_col) < test_date_start)
      valid_df = remaining_df.filter(Polars.col(date_col) >= validation_date_start)
      train_df = remaining_df.filter(Polars.col(date_col) < validation_date_start)

      [train_df, valid_df, test_df]
    end

    # Builds a duration of +n+ months. Not called by the other methods of this
    # class; kept as part of the public interface.
    #
    # @param n [Integer] number of months
    # @return [ActiveSupport::Duration]
    def months(n)
      ActiveSupport::Duration.months(n)
    end

    # Computes the two cutoffs, each snapped to the beginning of its day.
    # The test window starts +months_test+ months before +today+; the
    # validation window starts +months_test + months_valid+ months before it.
    #
    # @return [Array] +[validation_date_start, test_date_start]+
    def splits
      test_date_start = today.advance(months: -months_test).beginning_of_day
      validation_date_start = today.advance(months: -(months_test + months_valid)).beginning_of_day
      [validation_date_start, test_date_start]
    end
  end
end

Version data entries

4 entries across 4 versions & 1 rubygems

Version Path
easy_ml-0.1.4 lib/easy_ml/data/dataset/splitters/date_splitter.rb
easy_ml-0.1.3 lib/easy_ml/data/dataset/splitters/date_splitter.rb
easy_ml-0.1.2 lib/easy_ml/data/dataset/splitters/date_splitter.rb
easy_ml-0.1.1 lib/easy_ml/data/dataset/splitters/date_splitter.rb