#include "zstdruby.h"
#include "./libzstd/zstd.h"

#include <limits.h>
#include <stdint.h>

/*
 * Zstd.zstd_version -> Integer
 *
 * Returns the libzstd version number as reported by ZSTD_versionNumber().
 */
static VALUE zstdVersion(VALUE self)
{
  /* ZSTD_versionNumber() returns an unsigned value; convert it as unsigned. */
  unsigned version = ZSTD_versionNumber();
  return UINT2NUM(version);
}

/*
 * Zstd.compress(string, level = nil) -> String
 *
 * Compresses +string+ in one shot with ZSTD_compress().
 * When +level+ is nil, level 0 is passed (libzstd's default level).
 * Raises TypeError if +string+ is not a String and RuntimeError if
 * libzstd reports an error.
 */
static VALUE compress(int argc, VALUE *argv, VALUE self)
{
  VALUE input_value;
  VALUE compression_level_value;
  rb_scan_args(argc, argv, "11", &input_value, &compression_level_value);

  Check_Type(input_value, T_STRING);
  const char* input_data = RSTRING_PTR(input_value);
  size_t input_size = RSTRING_LEN(input_value);

  int compression_level;
  if (NIL_P(compression_level_value)) {
    compression_level = 0; // The default. See ZSTD_CLEVEL_DEFAULT in zstd_compress.c
  } else {
    compression_level = NUM2INT(compression_level_value);
  }

  // Allocate the worst-case size so ZSTD_compress cannot run out of room.
  // Newer libzstd versions report oversized inputs as an error code here.
  size_t max_compressed_size = ZSTD_compressBound(input_size);
  if (ZSTD_isError(max_compressed_size)) {
    rb_raise(rb_eRuntimeError, "%s: %s", "compress failed", ZSTD_getErrorName(max_compressed_size));
  }

  VALUE output = rb_str_new(NULL, max_compressed_size);
  char* output_data = RSTRING_PTR(output);
  size_t compressed_size = ZSTD_compress((void*)output_data, max_compressed_size,
                                         (const void*)input_data, input_size, compression_level);
  if (ZSTD_isError(compressed_size)) {
    rb_raise(rb_eRuntimeError, "%s: %s", "compress failed", ZSTD_getErrorName(compressed_size));
  }

  // Trim the worst-case allocation down to the bytes actually written.
  rb_str_resize(output, compressed_size);

  // input_data points into input_value; keep it alive across the allocation above.
  RB_GC_GUARD(input_value);
  return output;
}

/* Initial capacity of the output string when the content size is unknown. */
enum { DECOMPRESS_BUFFERED_INITIAL_CAPACITY = 4096 };

/* State handed from decompress_buffered() to its rb_ensure body and cleanup. */
struct decompress_buffered_state {
  ZSTD_DStream* dstream;
  const char* input_data;
  size_t input_size;
};

/* rb_ensure body: streams the whole input through the decoder into a String. */
static VALUE decompress_buffered_body(VALUE arg)
{
  struct decompress_buffered_state* state = (struct decompress_buffered_state*)arg;

  size_t init_result = ZSTD_initDStream(state->dstream);
  if (ZSTD_isError(init_result)) {
    rb_raise(rb_eRuntimeError, "%s: %s", "ZSTD_initDStream failed", ZSTD_getErrorName(init_result));
  }

  VALUE output_string = rb_str_new(NULL, 0);
  ZSTD_outBuffer output = { NULL, 0, 0 };
  ZSTD_inBuffer input = { state->input_data, state->input_size, 0 };

  // Last value returned by ZSTD_decompressStream: 0 means the current frame is
  // fully decoded and flushed. Starts at 0 so empty input yields "".
  size_t read_hint = 0;

  // Continue while input remains, or while the previous call filled the output
  // buffer without finishing the frame (decoded bytes may still be pending).
  while (input.pos < input.size || (read_hint != 0 && output.pos == output.size)) {
    if (output.pos == output.size) {
      // Grow geometrically to avoid quadratic reallocation on large outputs.
      output.size = (output.size == 0) ? DECOMPRESS_BUFFERED_INITIAL_CAPACITY : output.size * 2;
      rb_str_resize(output_string, (long)output.size);
      // Resizing may move the string's storage; refresh the pointer.
      output.dst = RSTRING_PTR(output_string);
    }

    read_hint = ZSTD_decompressStream(state->dstream, &output, &input);
    if (ZSTD_isError(read_hint)) {
      rb_raise(rb_eRuntimeError, "%s: %s", "ZSTD_decompressStream failed", ZSTD_getErrorName(read_hint));
    }
  }

  if (read_hint != 0) {
    // All input was consumed, but the decoder still expects more: the input
    // ends in the middle of a frame.
    rb_raise(rb_eRuntimeError, "%s: %s", "ZSTD_decompressStream failed", "truncated input");
  }

  rb_str_resize(output_string, (long)output.pos);
  return output_string;
}

/* rb_ensure cleanup: releases the decoder whether the body returned or raised. */
static VALUE decompress_buffered_cleanup(VALUE arg)
{
  struct decompress_buffered_state* state = (struct decompress_buffered_state*)arg;
  ZSTD_freeDStream(state->dstream);
  return Qnil;
}

/*
 * Decompresses input whose content size is not known up front, using the
 * streaming decoder and a growing output String.
 *
 * The ZSTD_DStream is released via rb_ensure, so it is freed on every path,
 * including Ruby exceptions raised by allocation or rb_raise.
 * Raises RuntimeError on decoder errors or truncated input.
 */
static VALUE decompress_buffered(const char* input_data, size_t input_size)
{
  ZSTD_DStream* const dstream = ZSTD_createDStream();
  if (dstream == NULL) {
    rb_raise(rb_eRuntimeError, "%s", "ZSTD_createDStream failed");
  }

  struct decompress_buffered_state state = { dstream, input_data, input_size };
  return rb_ensure(decompress_buffered_body, (VALUE)&state,
                   decompress_buffered_cleanup, (VALUE)&state);
}

/*
 * Zstd.decompress(string) -> String
 *
 * When the frame header records a non-zero content size, the output is
 * allocated once and filled by ZSTD_decompress(). Otherwise the streaming
 * path (decompress_buffered) is used.
 * Raises TypeError if +string+ is not a String and RuntimeError on malformed input.
 */
static VALUE decompress(VALUE self, VALUE input)
{
  Check_Type(input, T_STRING);
  const char* input_data = RSTRING_PTR(input);
  size_t input_size = RSTRING_LEN(input);

  // ZSTD_getDecompressedSize() returns 0 for "unknown", "empty" and "error"
  // alike; the streaming path handles all three.
  uint64_t uncompressed_size = ZSTD_getDecompressedSize(input_data, input_size);
  if (uncompressed_size == 0) {
    VALUE result = decompress_buffered(input_data, input_size);
    RB_GC_GUARD(input);
    return result;
  }

  // The declared size comes from the (untrusted) frame header, and
  // rb_str_new takes a long: reject values that cannot be represented.
  if (uncompressed_size > (uint64_t)LONG_MAX) {
    rb_raise(rb_eRuntimeError, "%s: %s", "decompress error", "declared content size too large");
  }

  VALUE output = rb_str_new(NULL, (long)uncompressed_size);
  char* output_data = RSTRING_PTR(output);
  size_t decompress_size = ZSTD_decompress((void*)output_data, (size_t)uncompressed_size,
                                           (const void*)input_data, input_size);
  if (ZSTD_isError(decompress_size)) {
    rb_raise(rb_eRuntimeError, "%s: %s", "decompress error", ZSTD_getErrorName(decompress_size));
  }

  // Never expose bytes beyond what the decoder actually wrote (the buffer
  // from rb_str_new(NULL, ...) is not zero-filled).
  rb_str_resize(output, (long)decompress_size);

  // input_data points into input; keep it alive across the allocation above.
  RB_GC_GUARD(input);
  return output;
}

VALUE rb_mZstd;

/* Extension entry point: defines the Zstd module and its module functions. */
void
Init_zstdruby(void)
{
  rb_mZstd = rb_define_module("Zstd");
  rb_define_module_function(rb_mZstd, "zstd_version", zstdVersion, 0);
  rb_define_module_function(rb_mZstd, "compress", compress, -1);
  rb_define_module_function(rb_mZstd, "decompress", decompress, 1);
}