/* Copyright (c) 2020 Contrast Security, Inc.  See
 * https://www.contrastsecurity.com/enduser-terms-0317a for more details. */

#include "cs__common.h"
#include <ruby.h>
#include <dlfcn.h>

/* Globals.
 * These are declared `extern` in cs__common.h; this file provides the
 * single definition of each one.
 */

/* Handles for the Contrast module hierarchy. */
VALUE contrast;
VALUE agent;
VALUE patching;
VALUE policy;
VALUE assess;
VALUE core_extensions;
VALUE core_assess;
VALUE assess_policy;
VALUE assess_propagator;

/* Path to the funchook shared library, copied from Funchook's @path ivar. */
VALUE funchook_path;

/* Interned method names. NOTE(review): these are assigned the IDs
 * returned by rb_intern() even though they are typed VALUE here. */
VALUE rb_sym_enter_scope;
VALUE rb_sym_exit_scope;
VALUE rb_sym_in_scope;
VALUE rb_sym_skip_contrast_analysis;
VALUE rb_sym_skip_assess_analysis;
VALUE rb_sym_method;
VALUE rb_sym_cs_tracked;
/* end globals */

void patch_via_funchook(void *original_function, void *hook_function) {
    VALUE funchook_module_wrapper = rb_define_module("Funchook");
    funchook_path = rb_iv_get(funchook_module_wrapper, "@path");

    void *funchook_lib_handle;
    void *funchook_reference, *(*funchook_create)(void);
    int prepareResult, (*funchook_prepare)(void*, void**, void*);
    int installResult, (*funchook_install)(void*, int);

    funchook_lib_handle = dlopen(StringValueCStr(funchook_path), RTLD_NOW | RTLD_GLOBAL);

    /* Load the funchook methods we need */
    funchook_create  = (void* (*)(void))dlsym(funchook_lib_handle, "funchook_create");
    funchook_prepare = (int (*)(void*, void**, void*))dlsym(funchook_lib_handle, "funchook_prepare");
    funchook_install = (int (*)(void*, int))dlsym(funchook_lib_handle, "funchook_install");

    funchook_reference = (void*)(*funchook_create)();

    prepareResult = (*funchook_prepare)(funchook_reference, (void**)original_function, hook_function);
    installResult = (*funchook_install)(funchook_reference, 0);
}

/*
 * Make `to` an alias for the existing method `from` on `target`.
 * This is done by sending `target` the message `alias_method(:to, :from)`
 * through `send`, presumably so the call works whatever the visibility of
 * alias_method is.
 */
void contrast_alias_method(const VALUE target, const char *to,
                           const char *from) {
    const VALUE new_name = ID2SYM(rb_intern(to));
    const VALUE old_name = ID2SYM(rb_intern(from));

    rb_funcall(target, cs__send_method, 3,
               cs__alias_method_sym, new_name, old_name);
}

/*
 * Return the module stored in the `patcher` global. Init_cs__common sets
 * it to Contrast::Agent::Patching::Policy::Patch.
 *
 * The parameter list is now `(void)`. An empty `()` declares a function
 * with unspecified parameters in pre-C23 C, which disables argument
 * checking at call sites.
 */
VALUE contrast_patcher(void) {
    return patcher;
}

/*
 * Register `c_fn` as the C implementation of `method_name` on
 * `module_name`, using the IMPL_ALIAS_INSTANCE strategy. That strategy is
 * forwarded to the Ruby patcher as :alias_instance.
 * Returns the result of _contrast_register_patch.
 */
VALUE contrast_register_patch(const char *module_name,
                              const char *method_name,
                              VALUE(c_fn)(const int, const VALUE*, const VALUE)
                              ) {
    const patch_impl strategy = IMPL_ALIAS_INSTANCE;
    return _contrast_register_patch(module_name, method_name, c_fn, strategy);
}


/*
 * Register `c_fn` as the C implementation of `method_name` on
 * `module_name`, using the IMPL_ALIAS_SINGLETON strategy. That strategy is
 * forwarded to the Ruby patcher as :alias_singleton.
 * Returns the result of _contrast_register_patch.
 */
VALUE contrast_register_singleton_patch(const char *module_name,
                                        const char *method_name,
                                        VALUE(c_fn)(const int, const VALUE*, const VALUE)
                                        ) {
    const patch_impl strategy = IMPL_ALIAS_SINGLETON;
    return _contrast_register_patch(module_name, method_name, c_fn, strategy);
}

/*
 * Register `c_fn` as the C implementation of `method_name` on
 * `module_name`, using the IMPL_PREPEND strategy. That strategy is
 * forwarded to the Ruby patcher as :prepend.
 * Returns the result of _contrast_register_patch.
 */
VALUE contrast_register_singleton_prepend_patch(const char *module_name,
                                                const char *method_name,
                                                VALUE(c_fn)(const int, const VALUE*, const VALUE)
                                                ) {
    const patch_impl strategy = IMPL_PREPEND;
    return _contrast_register_patch(module_name, method_name, c_fn, strategy);
}

/* Map a patch_impl strategy to the Ruby symbol passed to the patcher's
 * register_c_patch. Raises ArgumentError for a value outside the enum.
 * Previously such a value silently sent nil. */
static VALUE patch_impl_to_symbol(patch_impl patch) {
    switch (patch) {
      case IMPL_ALIAS_INSTANCE:
        return ID2SYM(rb_sym_alias_instance);
      case IMPL_ALIAS_SINGLETON:
        return ID2SYM(rb_sym_alias_singleton);
      case IMPL_PREPEND:
        return ID2SYM(rb_sym_prepend);
      default:
        rb_raise(rb_eArgError, "unknown patch implementation: %d", (int)patch);
    }
}

/*
 * Wrap the C function `c_fn` in a Ruby UnboundMethod and hand it to the
 * patcher's register_c_patch, together with `module_name` and the symbol
 * for the `patch` strategy.
 *
 * module_name: name of the module to patch, passed on as a Ruby String.
 * method_name: name under which `c_fn` is temporarily defined (arity -1).
 * c_fn:        C implementation with the (argc, argv, self) signature.
 * patch:       which patching strategy the Ruby side should apply.
 *
 * Returns SYM2ID of the value returned by register_c_patch.
 * NOTE(review): that is an ID returned through a VALUE-typed result.
 * Callers presumably treat it as an ID; confirm before changing the type.
 */
static VALUE _contrast_register_patch(const char *module_name,
                                      const char *method_name,
                                      VALUE(c_fn)(const int, const VALUE*, const VALUE),
                                      patch_impl patch
                                      ) {
    /* Resolve the strategy first so an invalid value fails before any
     * Ruby objects are created. */
    VALUE impl = patch_impl_to_symbol(patch);

    VALUE contrast_bind_module = rb_funcall(rb_cModule, rb_intern("new"), 0);
    VALUE rb_str_module_name = rb_str_new_cstr(module_name);
    VALUE rb_str_method_name = rb_str_new_cstr(method_name);

    /* We register the C function as an instance method on an anonymous
     * module and then take it back out as an UnboundMethod. That
     * UnboundMethod can later be bound to a class with define_method:
     *
     * my_method = get_patch_from_c_somehow()
     *
     * class MyClass
     *   define_method('name', my_method.bind(self))
     * end
     *
     * MyClass.new.name # calls my_method
     *
     * The anonymous module exists only for this conversion. The method is
     * undefined on it as soon as the UnboundMethod has been obtained.
     */
    rb_define_method(contrast_bind_module, method_name, (VALUE(*)())c_fn, -1);
    /* rb_intern yields an ID, not a VALUE. */
    ID id_instance_method = rb_intern("instance_method");
    VALUE unbound_method = rb_funcall(contrast_bind_module, id_instance_method,
                                      1, rb_str_method_name);
    rb_undef_method(contrast_bind_module, method_name);

    VALUE underlying_method_name = rb_funcall(contrast_patcher(),
                                              rb_sym_register_c_patch, 3,
                                              rb_str_module_name,
                                              unbound_method, impl);
    return SYM2ID(underlying_method_name);
}

void Init_cs__common(void) {
    cs__send_method = rb_intern("send");
    cs__alias_method_sym = ID2SYM(rb_intern("alias_method"));

    /* Define symbols */
    rb_sym_enter_scope = rb_intern("enter_contrast_scope!");
    rb_sym_exit_scope = rb_intern("exit_contrast_scope!");
    rb_sym_in_scope = rb_intern("in_contrast_scope?");
    rb_sym_skip_contrast_analysis = rb_intern("skip_contrast_analysis?");
    rb_sym_skip_assess_analysis = rb_intern("skip_assess_analysis?");
    rb_sym_method = rb_intern("__method__");
    rb_sym_cs_tracked = rb_intern("cs__tracked?");

    /* Used for returning unbound C functions */
    rb_sym_register_c_patch = rb_intern("register_c_patch");
    rb_sym_alias_instance   = rb_intern("alias_instance");
    rb_sym_alias_singleton  = rb_intern("alias_singleton");
    rb_sym_prepend          = rb_intern("prepend");

    /* Ensure definition of core Contrast instrumentation modules */
    contrast = rb_define_module("Contrast");
    agent = rb_define_module_under(contrast, "Agent");

    assess = rb_define_module_under(agent, "Assess");

    patching = rb_define_module_under(agent, "Patching");
    policy = rb_define_module_under(patching, "Policy");
    patcher = rb_define_module_under(policy, "Patch");

    assess_policy = rb_define_module_under(assess, "Policy");
    assess_propagator = rb_define_module_under(assess_policy, "Propagator");

    core_extensions = rb_define_module_under(contrast, "Extension");
    core_assess = rb_define_module_under(core_extensions, "Assess");
}