/* Copyright (c) 2020 Contrast Security, Inc.  See
 * https://www.contrastsecurity.com/enduser-terms-0317a for more details. */

#include "cs__assess_regexp_track.h"
#include <funchook.h>
#include <ruby.h>

/*
 * Detour for rb_reg_match_pre (wired up in install_regexp_hooks).
 *
 * Runs the original implementation through rb_reg_match_pre_original, then
 * hands both the match argument and that result to the method named by
 * track_rb_pre_match on regexp_class. Whatever that method returns is what
 * the caller of rb_reg_match_pre receives.
 */
static VALUE rb_reg_match_pre_hook(VALUE match) {
    const VALUE untracked = rb_reg_match_pre_original(match);

    return rb_funcall(regexp_class, track_rb_pre_match, 2, match, untracked);
}

/*
 * Detour for rb_reg_match_post (wired up in install_regexp_hooks).
 *
 * Runs the original implementation through rb_reg_match_post_original, then
 * hands both the match argument and that result to the method named by
 * track_rb_post_match on regexp_class. Whatever that method returns is what
 * the caller of rb_reg_match_post receives.
 */
static VALUE rb_reg_match_post_hook(VALUE match) {
    const VALUE untracked = rb_reg_match_post_original(match);

    return rb_funcall(regexp_class, track_rb_post_match, 2, match, untracked);
}

/*
 * Detour for rb_reg_match_last (wired up in install_regexp_hooks).
 *
 * Runs the original implementation through rb_reg_match_last_original, then
 * hands both the match argument and that result to the method named by
 * track_rb_reg_match_last on regexp_class. Whatever that method returns is
 * what the caller of rb_reg_match_last receives.
 */
static VALUE rb_reg_match_last_hook(VALUE match) {
    const VALUE untracked = rb_reg_match_last_original(match);

    return rb_funcall(regexp_class, track_rb_reg_match_last, 2, match,
                      untracked);
}

/*
 * Detour for rb_reg_nth_match (wired up in install_regexp_hooks).
 *
 * Runs the original implementation through rb_reg_nth_match_original with
 * the same nth and match arguments, then hands the match argument and that
 * result to the method named by track_rb_n_match on regexp_class. Whatever
 * that method returns is what the caller of rb_reg_nth_match receives.
 *
 * Note: nth is only used for the original call; it is not forwarded to the
 * tracking method.
 */
static VALUE rb_reg_nth_match_hook(int nth, VALUE match) {
    const VALUE untracked = rb_reg_nth_match_original(nth, match);

    return rb_funcall(regexp_class, track_rb_n_match, 2, match, untracked);
}

static int install_regexp_hooks() {
    funchook_t *funchook = funchook_create();

    rb_reg_match_pre_original = rb_reg_match_pre;
    funchook_prepare(funchook, (void **)&rb_reg_match_pre_original,
                     rb_reg_match_pre_hook);

    rb_reg_match_post_original = rb_reg_match_post;
    funchook_prepare(funchook, (void **)&rb_reg_match_post_original,
                     rb_reg_match_post_hook);

    rb_reg_match_last_original = rb_reg_match_last;
    funchook_prepare(funchook, (void **)&rb_reg_match_last_original,
                     rb_reg_match_last_hook);

    rb_reg_nth_match_original = rb_reg_nth_match;
    funchook_prepare(funchook, (void **)&rb_reg_nth_match_original,
                     rb_reg_nth_match_hook);

    funchook_install(funchook, 0);
    return 0;
}

void Init_cs__assess_regexp_track(void) {
    regexp_class = rb_define_class("Regexp", rb_cObject);
    track_rb_n_match = rb_intern("track_rb_n_match");
    track_rb_pre_match = rb_intern("track_rb_pre_match");
    track_rb_post_match = rb_intern("track_rb_post_match");
    track_rb_reg_match_last = rb_intern("track_rb_reg_match_last");
    install_regexp_hooks();
}