Sha256: 13ae962f21ccf375da0e14449f8ba4ca8951801e006c5416958f69135b902b20

Contents?: true

Size: 1.63 KB

Versions: 3

Compression:

Stored size: 1.63 KB

Contents

require 'logger'

module Baran
  # Base class for text splitters. Subclasses implement #splitted to cut a text
  # into strings; this class supplies the shared helpers that wrap those strings
  # into chunk hashes (#chunks) and pack small pieces into size-limited,
  # overlapping chunks (#merged).
  class TextSplitter
    attr_accessor :chunk_size, :chunk_overlap

    # @param chunk_size [Integer] size limit that #merged compares summed split lengths against
    # @param chunk_overlap [Integer] summed split length #merged may carry over into the next chunk
    # @raise [RuntimeError] if chunk_overlap is greater than or equal to chunk_size
    def initialize(chunk_size: 1024, chunk_overlap: 64)
      @chunk_size = chunk_size
      @chunk_overlap = chunk_overlap
      raise "Cannot have chunk_overlap >= chunk_size" if @chunk_overlap >= @chunk_size
    end

    # Splits +text+ into an array of strings (nil entries are tolerated by #chunks).
    # Abstract here; must be overridden.
    #
    # @param text [String]
    # @raise [NotImplementedError] always, in this base class
    def splitted(text)
      raise NotImplementedError, "splitted method should be implemented in a subclass"
    end

    # Wraps every non-nil string returned by #splitted in a hash.
    #
    # +:cursor+ is the running sum of the lengths of the chunks emitted before
    # this one; it is computed from the chunk texts only, not looked up in +text+.
    #
    # @param text [String] text handed to #splitted
    # @param metadata [Object, nil] when truthy, stored under +:metadata+ in every chunk
    # @return [Array<Hash>] hashes with +:text+, +:cursor+ and optionally +:metadata+
    def chunks(text, metadata: nil)
      cursor = 0

      splitted(text).compact.map do |piece|
        chunk = { text: piece, cursor: cursor }
        chunk[:metadata] = metadata if metadata
        cursor += piece.length
        chunk
      end
    end

    # Joins +items+ with +separator+ and strips surrounding whitespace.
    #
    # @param items [Array<String>]
    # @param separator [String]
    # @return [String, nil] the joined text, or nil when it is empty after stripping
    def joined(items, separator)
      text = items.join(separator).strip
      text.empty? ? nil : text
    end

    # Packs +splits+ into chunks. Splits accumulate in a window until adding the
    # next one would bring the summed split length to chunk_size or more; the
    # window is then emitted via #joined and trimmed from the front so that only
    # a tail (the overlap) is carried into the next chunk. Separator lengths are
    # not counted towards the size.
    #
    # A split that is by itself longer than the limit is kept whole and a
    # warning is logged to STDOUT.
    #
    # @param splits [Array<String>] pieces to pack, in order
    # @param separator [String] string placed between pieces of one chunk
    # @return [Array<String, nil>] merged chunks; an entry is nil when the joined
    #   window was blank, and the final window is always appended (so an empty
    #   +splits+ yields +[nil]+)
    def merged(splits, separator)
      results = [] # Array of strings (or nil, see #joined)
      current_splits = [] # Window of splits forming the chunk being built
      total = 0 # Summed length of current_splits, separators excluded
      logger = nil # Created lazily so at most one Logger is built per call

      splits.each do |split|
        if total + split.length >= chunk_size && current_splits.length.positive?
          results << joined(current_splits, separator)

          # Drop leading splits until the remainder fits within the overlap and
          # leaves room for +split+. The emptiness guard stops the loop once the
          # window is drained, which matters when chunk_overlap is negative:
          # +total > chunk_overlap+ would otherwise stay true forever.
          until current_splits.empty?
            break unless total > chunk_overlap || (total + split.length >= chunk_size && total.positive?)

            total -= current_splits.shift.length
          end
        end

        current_splits << split
        total += split.length
        next unless total > @chunk_size

        logger ||= Logger.new(STDOUT)
        logger.warn("Created a chunk of size #{total}, which is longer than the specified #{@chunk_size}")
      end

      results << joined(current_splits, separator)

      results
    end
  end
end

Version data entries

3 entries across 3 versions & 1 rubygems

Version Path
baran-0.2.1 lib/baran/text_splitter.rb
baran-0.2.0 lib/baran/text_splitter.rb
baran-0.1.12 lib/baran/text_splitter.rb