lib/blingfire.rb in blingfire-0.1.0 vs lib/blingfire.rb in blingfire-0.1.1

- old
+ new

@@ -33,56 +33,57 @@ def load_model(path) Model.new(path) end def text_to_words(text) - text = encode_utf8(text.dup) unless text.encoding == Encoding::UTF_8 - out = Fiddle::Pointer.malloc(text.bytesize * 3) - out_size = FFI.TextToWords(text, text.bytesize, out, out.size) - check_status out_size - encode_utf8(out[0, out_size - 1]).split(" ") + text_to(text, " ") do |t, out| + FFI.TextToWords(t, t.bytesize, out, out.size) + end end def text_to_words_with_model(model, text) - text = encode_utf8(text.dup) unless text.encoding == Encoding::UTF_8 - out = Fiddle::Pointer.malloc(text.bytesize * 3) - out_size = FFI.TextToWordsWithModel(text, text.bytesize, out, out.size, model) - check_status out_size - encode_utf8(out[0, out_size - 1]).split(" ") + text_to(text, " ") do |t, out| + FFI.TextToWordsWithModel(t, t.bytesize, out, out.size, model) + end end def text_to_sentences(text) - text = encode_utf8(text.dup) unless text.encoding == Encoding::UTF_8 - out = Fiddle::Pointer.malloc(text.bytesize * 3) - out_size = FFI.TextToSentences(text, text.bytesize, out, out.size) - check_status out_size - encode_utf8(out[0, out_size - 1]).split("\n") + text_to(text, "\n") do |t, out| + FFI.TextToSentences(t, t.bytesize, out, out.size) + end end def text_to_sentences_with_model(model, text) - text = encode_utf8(text.dup) unless text.encoding == Encoding::UTF_8 - out = Fiddle::Pointer.malloc(text.bytesize * 3) - out_size = FFI.TextToSentencesWithModel(text, text.bytesize, out, out.size, model) - check_status out_size - encode_utf8(out[0, out_size - 1]).split("\n") + text_to(text, "\n") do |t, out| + FFI.TextToSentencesWithModel(t, t.bytesize, out, out.size, model) + end end def text_to_ids(model, text, max_len = nil, unk_id = 0) text = encode_utf8(text.dup) unless text.encoding == Encoding::UTF_8 ids = Fiddle::Pointer.malloc((max_len || text.size) * Fiddle::SIZEOF_INT) out_size = FFI.TextToIds(model, text, text.bytesize, ids, ids.size, unk_id) - check_status out_size + check_status out_size, ids ids[0, (max_len || out_size) * Fiddle::SIZEOF_INT].unpack("i!*") end def free_model(model) FFI.FreeModel(model) end private - def check_status(ret) - raise Error, "Bad status" if ret == -1 + def check_status(ret, ptr) + raise Error, "Not enough memory allocated" if ret == -1 || ret > ptr.size + end + + def text_to(text, sep) + text = encode_utf8(text.dup) unless text.encoding == Encoding::UTF_8 + # TODO allocate less, and try again if needed + out = Fiddle::Pointer.malloc([text.bytesize * 1.5, 20].max) + out_size = yield(text, out) + check_status out_size, out + encode_utf8(out.to_str(out_size - 1)).split(sep) end def encode_utf8(text) text.force_encoding(Encoding::UTF_8) end