lib/blingfire.rb in blingfire-0.1.0 vs lib/blingfire.rb in blingfire-0.1.1
- old
+ new
@@ -33,56 +33,57 @@
# Builds a Model from the given path and returns it.
# Model is defined outside this hunk; these lines are identical in 0.1.0 and 0.1.1.
def load_model(path)
Model.new(path)
end
# Tokenizes text through FFI.TextToWords and returns the output split on " ".
#
# 0.1.0 (removed lines): tagged a copy of the text as UTF-8, allocated
# text.bytesize * 3 bytes, called the native function, checked the status
# and split the first out_size - 1 bytes of the buffer.
# 0.1.1 (added lines): delegates encoding, allocation, status checking and
# splitting to the text_to helper and only supplies the native call.
def text_to_words(text)
- text = encode_utf8(text.dup) unless text.encoding == Encoding::UTF_8
- out = Fiddle::Pointer.malloc(text.bytesize * 3)
- out_size = FFI.TextToWords(text, text.bytesize, out, out.size)
- check_status out_size
- encode_utf8(out[0, out_size - 1]).split(" ")
+ text_to(text, " ") do |t, out|
# t is the UTF-8 tagged text and out the output buffer, both yielded by text_to;
# the block's value is used by text_to as the reported output size.
+ FFI.TextToWords(t, t.bytesize, out, out.size)
+ end
end
# Tokenizes text through FFI.TextToWordsWithModel and returns the output
# split on " ". model is passed unchanged as the last native argument
# (presumably a handle obtained from load_model -- confirm against callers).
#
# 0.1.0 (removed lines): inline UTF-8 tagging, a text.bytesize * 3 byte
# buffer, status check and split.
# 0.1.1 (added lines): the same steps are delegated to the text_to helper.
def text_to_words_with_model(model, text)
- text = encode_utf8(text.dup) unless text.encoding == Encoding::UTF_8
- out = Fiddle::Pointer.malloc(text.bytesize * 3)
- out_size = FFI.TextToWordsWithModel(text, text.bytesize, out, out.size, model)
- check_status out_size
- encode_utf8(out[0, out_size - 1]).split(" ")
+ text_to(text, " ") do |t, out|
# t and out come from text_to; model is captured from the method argument.
+ FFI.TextToWordsWithModel(t, t.bytesize, out, out.size, model)
+ end
end
# Splits text into sentences through FFI.TextToSentences and returns the
# output split on "\n".
#
# 0.1.0 (removed lines): inline UTF-8 tagging, a text.bytesize * 3 byte
# buffer, status check and split.
# 0.1.1 (added lines): the same steps are delegated to the text_to helper,
# which sizes the buffer differently (see text_to).
def text_to_sentences(text)
- text = encode_utf8(text.dup) unless text.encoding == Encoding::UTF_8
- out = Fiddle::Pointer.malloc(text.bytesize * 3)
- out_size = FFI.TextToSentences(text, text.bytesize, out, out.size)
- check_status out_size
- encode_utf8(out[0, out_size - 1]).split("\n")
+ text_to(text, "\n") do |t, out|
# t is the UTF-8 tagged text and out the output buffer, both yielded by text_to.
+ FFI.TextToSentences(t, t.bytesize, out, out.size)
+ end
end
# Splits text into sentences through FFI.TextToSentencesWithModel and returns
# the output split on "\n". model is passed unchanged as the last native
# argument (presumably a handle obtained from load_model -- confirm).
#
# 0.1.0 (removed lines): inline UTF-8 tagging, a text.bytesize * 3 byte
# buffer, status check and split.
# 0.1.1 (added lines): the same steps are delegated to the text_to helper.
def text_to_sentences_with_model(model, text)
- text = encode_utf8(text.dup) unless text.encoding == Encoding::UTF_8
- out = Fiddle::Pointer.malloc(text.bytesize * 3)
- out_size = FFI.TextToSentencesWithModel(text, text.bytesize, out, out.size, model)
- check_status out_size
- encode_utf8(out[0, out_size - 1]).split("\n")
+ text_to(text, "\n") do |t, out|
# t and out come from text_to; model is captured from the method argument.
+ FFI.TextToSentencesWithModel(t, t.bytesize, out, out.size, model)
+ end
end
# Converts text to integer ids through FFI.TextToIds.
#
# model   - passed unchanged to the native call.
# text    - input string; a copy is tagged as UTF-8 when its encoding differs.
# max_len - optional id count; sizes the buffer and, when given, is also the
#           number of ints read back. When nil, text.size sizes the buffer
#           and out_size ints are read back.
# unk_id  - passed unchanged to the native call (default 0).
#
# Returns an Array of Integers unpacked with "i!*" (native int).
# Raises Error through check_status. The only change in 0.1.1 is that the
# buffer is handed to check_status as well.
#
# NOTE(review): ids.size is given to the native call as the capacity and is
# compared with out_size in check_status. If Fiddle::Pointer#size is a byte
# count while the native side counts ids, both uses are off by a factor of
# Fiddle::SIZEOF_INT -- confirm against the native TextToIds signature.
# NOTE(review): with max_len set, max_len ints are read back even when
# out_size is smaller; whether the native call fills the remainder is not
# visible here -- confirm.
# NOTE(review): malloc is called without a free function; verify in the
# Fiddle docs whether this buffer is ever released.
def text_to_ids(model, text, max_len = nil, unk_id = 0)
text = encode_utf8(text.dup) unless text.encoding == Encoding::UTF_8
ids = Fiddle::Pointer.malloc((max_len || text.size) * Fiddle::SIZEOF_INT)
out_size = FFI.TextToIds(model, text, text.bytesize, ids, ids.size, unk_id)
- check_status out_size
+ check_status out_size, ids
ids[0, (max_len || out_size) * Fiddle::SIZEOF_INT].unpack("i!*")
end
# Hands model to FFI.FreeModel and returns that call's result.
# These lines are identical in 0.1.0 and 0.1.1.
# NOTE(review): presumably model must not be used after this call -- confirm
# against the native library's documentation.
def free_model(model)
FFI.FreeModel(model)
end
private
# Validates the integer returned by a native call.
#
# 0.1.0 (removed lines): check_status(ret) raised Error, "Bad status" when
# ret was -1. Its closing `end` is the shared context line further down the
# hunk, which now closes text_to.
# 0.1.1 (added lines): check_status(ret, ptr) raises
# Error, "Not enough memory allocated" when ret is -1 or when ret is greater
# than ptr.size. Otherwise it returns nil.
#
# NOTE(review): -1 and "reported size exceeds the buffer" now share one
# message; whether -1 can signal other native failures is not visible here.
- def check_status(ret)
- raise Error, "Bad status" if ret == -1
+ def check_status(ret, ptr)
+ raise Error, "Not enough memory allocated" if ret == -1 || ret > ptr.size
+ end
+
# Shared helper added in 0.1.1 for the text_to_words* / text_to_sentences*
# methods.
#
# text - input string; a copy is tagged as UTF-8 when its encoding differs.
# sep  - separator used to split the native output.
#
# Yields the UTF-8 tagged text and the output buffer; the block must return
# the size reported by the native call.
# Returns an Array of Strings. Raises Error through check_status.
#
# NOTE(review): text.bytesize * 1.5 is a Float, so malloc receives a Float
# whenever that value exceeds 20 -- confirm Fiddle truncates it as intended.
# NOTE(review): malloc is called without a free function; verify in the
# Fiddle docs whether this buffer is ever released.
# NOTE(review): out_size - 1 presumably drops a trailing terminator byte; it
# would be negative if the native call ever returned 0 (e.g. for empty
# input) -- confirm.
+ def text_to(text, sep)
+ text = encode_utf8(text.dup) unless text.encoding == Encoding::UTF_8
+ # TODO allocate less, and try again if needed
+ out = Fiddle::Pointer.malloc([text.bytesize * 1.5, 20].max)
+ out_size = yield(text, out)
+ check_status out_size, out
+ encode_utf8(out.to_str(out_size - 1)).split(sep)
end
# Tags text as UTF-8 in place via String#force_encoding and returns it.
# The bytes are not transcoded, and the receiver itself is modified, which is
# why the callers in this hunk pass a dup or a freshly built string.
def encode_utf8(text)
text.force_encoding(Encoding::UTF_8)
end