/* Copyright (c) 2022 Contrast Security, Inc.  See
 * https://www.contrastsecurity.com/enduser-terms-0317a for more details. */

#include "cs__common.h"
#include <dlfcn.h>
#include <ruby.h>

/* Globals */
/* These are declared `extern` in the header (cs__common.h) and defined here,
 * exactly once. */
/* Ruby modules, assigned in Init_cs__common:
 *   contrast          => Contrast
 *   agent             => Contrast::Agent
 *   patching          => Contrast::Agent::Patching
 *   policy            => Contrast::Agent::Patching::Policy
 *   assess            => Contrast::Agent::Assess
 *   core_extensions   => Contrast::Extension
 *   core_assess       => Contrast::Extension::Assess
 *   assess_policy     => Contrast::Agent::Assess::Policy
 *   assess_propagator => Contrast::Agent::Assess::Policy::Propagator
 *   components        => Contrast::Components
 */
VALUE contrast, agent, patching, policy, assess;
VALUE core_extensions, core_assess;
VALUE assess_policy, assess_propagator;
VALUE components;
/* Path of the funchook shared library; refreshed from the `Funchook` module's
 * `@path` instance variable every time patch_via_funchook runs. */
VALUE funchook_path;

/* NOTE: despite the `rb_sym_` prefix and the VALUE type, Init_cs__common
 * stores rb_intern() results in these, so they hold method-name IDs, not
 * Symbol objects. Wrap with ID2SYM() where a Symbol is required. */
VALUE rb_sym_enter_scope;
VALUE rb_sym_exit_scope;
VALUE rb_sym_in_scope;
VALUE rb_sym_skip_contrast_analysis;
VALUE rb_sym_skip_assess_analysis;
VALUE rb_sym_method;
VALUE rb_sym_hash_get, rb_sym_hash_set, rb_sym_hash_tracked;
/* end globals */

void patch_via_funchook(void *original_function, void *hook_function) {
    VALUE funchook_module_wrapper = rb_define_module("Funchook");
    funchook_path = rb_iv_get(funchook_module_wrapper, "@path");

    void *funchook_lib_handle;
    void *funchook_reference, *(*funchook_create)(void);
    int prepareResult, (*funchook_prepare)(void *, void **, void *);
    int installResult, (*funchook_install)(void *, int);

    funchook_lib_handle =
        dlopen(StringValueCStr(funchook_path), RTLD_NOW | RTLD_GLOBAL);

    /* Load the funchook methods we need */
    funchook_create =
        (void *(*)(void))dlsym(funchook_lib_handle, "funchook_create");
    funchook_prepare = (int (*)(void *, void **, void *))dlsym(
        funchook_lib_handle, "funchook_prepare");
    funchook_install =
        (int (*)(void *, int))dlsym(funchook_lib_handle, "funchook_install");

    funchook_reference = (void *)(*funchook_create)();

    prepareResult = (*funchook_prepare)(
        funchook_reference, (void **)original_function, hook_function);
    installResult = (*funchook_install)(funchook_reference, 0);
}

/*
 * Make `to` an alias of `from` on `target`; equivalent to the Ruby call
 *   target.send(:alias_method, :to, :from)
 * Going through `send` means the visibility of alias_method does not matter.
 */
void contrast_alias_method(const VALUE target, const char *to,
                           const char *from) {
    const VALUE new_name = ID2SYM(rb_intern(to));
    const VALUE old_name = ID2SYM(rb_intern(from));

    rb_funcall(target, cs__send_method, 3, cs__alias_method_sym, new_name,
               old_name);
}

/* Return the patcher module, Contrast::Agent::Patching::Policy::Patch, as
 * cached in `patcher` by Init_cs__common. An explicit `(void)` gives the
 * definition a real prototype. */
VALUE contrast_patcher(void) {
    return patcher;
}

/*
 * Register `c_fn` as the replacement for instance method `method_name` of
 * `module_name`, using the alias strategy (:alias_instance).
 *
 * Returns the value produced by _contrast_register_patch: the ID form of the
 * name reported back by the patcher's register_c_patch.
 */
VALUE contrast_register_patch(const char *module_name, const char *method_name,
                              VALUE(c_fn)(const int, VALUE *, const VALUE)) {
    const patch_impl strategy = IMPL_ALIAS_INSTANCE;

    return _contrast_register_patch(module_name, method_name, c_fn, strategy);
}

/*
 * Register `c_fn` as the replacement for singleton (class-level) method
 * `method_name` of `module_name`, using the alias strategy
 * (:alias_singleton).
 *
 * Returns the value produced by _contrast_register_patch.
 */
VALUE contrast_register_singleton_patch(const char *module_name,
                                        const char *method_name,
                                        VALUE(c_fn)(const int, VALUE *,
                                                    const VALUE)) {
    const patch_impl strategy = IMPL_ALIAS_SINGLETON;

    return _contrast_register_patch(module_name, method_name, c_fn, strategy);
}

/*
 * Register `c_fn` as the replacement for instance method `method_name` of
 * `module_name`, using the prepend strategy (:prepend_instance).
 *
 * Returns the value produced by _contrast_register_patch.
 */
VALUE contrast_register_prepend_patch(const char *module_name,
                                      const char *method_name,
                                      VALUE(c_fn)(const int, VALUE *,
                                                  const VALUE)) {
    const patch_impl strategy = IMPL_PREPEND_INSTANCE;

    return _contrast_register_patch(module_name, method_name, c_fn, strategy);
}

/*
 * Register `c_fn` as the replacement for singleton (class-level) method
 * `method_name` of `module_name`, using the prepend strategy
 * (:prepend_singleton).
 *
 * Returns the value produced by _contrast_register_patch.
 */
VALUE contrast_register_singleton_prepend_patch(const char *module_name,
                                                const char *method_name,
                                                VALUE(c_fn)(const int, VALUE *,
                                                            const VALUE)) {
    const patch_impl strategy = IMPL_PREPEND_SINGLETON;

    return _contrast_register_patch(module_name, method_name, c_fn, strategy);
}

/*
 * Register `c_fn` as an instance-method patch, picking the strategy
 * automatically:
 *   - :prepend_instance if a module prepended to `module_name` already
 *     defines `method_name` (as reported by contrast_check_prepended);
 *   - :alias_instance otherwise.
 *
 * @param module_name  top-level constant name, e.g. "String"; it is looked up
 *                     with rb_const_get on Object
 * @param method_name  method name, e.g. "gsub"
 * @param c_fn         C implementation using the (argc, argv, self) convention
 * @return the value produced by _contrast_register_patch
 */
VALUE contrast_check_and_register_instance_patch(
    const char *module_name, const char *method_name,
    VALUE(c_fn)(const int, VALUE *, const VALUE)) {
    const VALUE target = rb_const_get(rb_cObject, rb_intern(module_name));
    const VALUE method_sym = ID2SYM(rb_intern(method_name));
    const patch_impl strategy =
        (contrast_check_prepended(target, method_sym, Qtrue) == Qtrue)
            ? IMPL_PREPEND_INSTANCE
            : IMPL_ALIAS_INSTANCE;

    return _contrast_register_patch(module_name, method_name, c_fn, strategy);
}

/*
 * Shared implementation behind the contrast_register_*patch helpers.
 *
 * Wraps `c_fn` in a Ruby UnboundMethod and hands it to the patcher module
 * (Contrast::Agent::Patching::Policy::Patch, see contrast_patcher) via
 *   register_c_patch(module_name_string, unbound_method, impl_symbol)
 * where impl_symbol is :alias_instance, :alias_singleton, :prepend_instance
 * or :prepend_singleton, according to `patch`.
 *
 * @param module_name  target class/module name, passed to Ruby as a String
 * @param method_name  name under which `c_fn` is temporarily defined, then
 *                     unbound
 * @param c_fn         C implementation using the arity -1 (argc, argv, self)
 *                     calling convention
 * @param patch        patching strategy
 * @return the Symbol returned by register_c_patch, converted with SYM2ID.
 *         NOTE: this is an ID carried in a VALUE-typed return value.
 *
 * Raises ArgumentError if `patch` is not one of the four known strategies.
 */
static VALUE
_contrast_register_patch(const char *module_name, const char *method_name,
                         VALUE(c_fn)(const int, VALUE *, const VALUE),
                         patch_impl patch) {
    VALUE impl = Qnil;

    /* map impl enum -> ruby symbol; reject unknown values before doing any
     * work, rather than passing nil through to Ruby */
    switch (patch) {
    case IMPL_ALIAS_INSTANCE:
        impl = ID2SYM(rb_sym_alias_instance);
        break;
    case IMPL_ALIAS_SINGLETON:
        impl = ID2SYM(rb_sym_alias_singleton);
        break;
    case IMPL_PREPEND_INSTANCE:
        impl = ID2SYM(rb_sym_prepend_instance);
        break;
    case IMPL_PREPEND_SINGLETON:
        impl = ID2SYM(rb_sym_prepend_singleton);
        break;
    default:
        rb_raise(rb_eArgError, "Contrast: unknown patch implementation %d",
                 (int)patch);
    }

    VALUE contrast_bind_module = rb_funcall(rb_cModule, rb_intern("new"), 0);
    VALUE rb_str_module_name = rb_str_new_cstr(module_name);
    VALUE rb_str_method_name = rb_str_new_cstr(method_name);

    /* We register the c function as an instance method on an
     * anonymous module, and then unbind it.  This creates
     * an UnboundMethod object that we register.  This
     * UnboundMethod object can then be bound to a class
     * with define_method:
     *
     * my_method = get_patch_from_c_somehow()
     *
     * class MyClass
     *   define_method('name', my_method.bind(self))
     * end
     *
     * MyClass.new.name # calls my_method
     */
    /* This is an anonymous module, upon which we define C functions
     * as Ruby functions.  We immediately unbind and undef them.
     * This module doesn't do anything.
     */
    rb_define_method(contrast_bind_module, method_name, (VALUE(*)())c_fn, -1);
    /* rb_intern yields an ID (not a Symbol VALUE), which is what rb_funcall
     * expects for the method to call. */
    ID instance_method_id = rb_intern("instance_method");
    VALUE unbound_method = rb_funcall(contrast_bind_module, instance_method_id,
                                      1, rb_str_method_name);
    rb_undef_method(contrast_bind_module, method_name);

    VALUE underlying_method_name =
        rb_funcall(contrast_patcher(), rb_sym_register_c_patch, 3,
                   rb_str_module_name, unbound_method, impl);
    return SYM2ID(underlying_method_name);
}

int rb_ver_below_three() {
    int ruby_version =
        FIX2INT(rb_funcall(rb_const_get(rb_cObject, rb_intern("RUBY_VERSION")),
                           rb_intern("to_i"), 0));
    return ruby_version < 3;
}

/*
 * Ruby: SomeClass.cs__prepended?(method_name, is_instance)
 *
 * Registered in Init_cs__common as a singleton method on Object, so the
 * receiver `self` is the class/module being inspected. Returns Qtrue when a
 * module prepended ahead of `self` in its ancestors defines `method_name`
 * (instance methods when `is_instance` is Qtrue, singleton methods otherwise),
 * Qfalse otherwise.
 */
extern VALUE contrast_check_prepended(VALUE self, VALUE method_name,
                                      VALUE is_instance) {
    const VALUE is_prepended =
        _contrast_check_prepended(self, method_name, is_instance);
    return is_prepended;
}

/*
 * Backs Contrast::Agent::Assess.cs__object_method_prepended?.
 *
 * Performs the same check as cs__prepended?, but on an explicitly supplied
 * `object_name` (the class/module to inspect) instead of the receiver. The
 * receiver `self` is not inspected. See the registration in Init_cs__common
 * for the Ruby-visible arity.
 */
extern VALUE contrast_lookout_prepended(VALUE self, VALUE object_name,
                                        VALUE method_name, VALUE is_instance) {
    (void)self; /* the receiver is not the object being inspected */
    return _contrast_check_prepended(object_name, method_name, is_instance);
}

/*
 * Core check shared by cs__prepended? and cs__object_method_prepended?.
 *
 * Module#ancestors lists modules prepended to `object` *before* `object`
 * itself, e.g. [Prepended2, Prepended1, object, Included, Super, ...]. Only
 * that leading prefix is scanned, and the result says whether any of those
 * modules defines `method_name`.
 *
 * @param object       the Class or Module to inspect; anything else raises
 *                     TypeError (rb_mod_ancestors requires a module)
 * @param method_name  method name as a Symbol; compared by identity with the
 *                     Symbols returned by instance_methods/singleton_methods,
 *                     so a String argument never matches
 * @param is_instance  Qtrue: check each prepended module's instance methods;
 *                     any other value: check each prepended module's own
 *                     singleton methods
 * @return Qtrue if found, Qfalse otherwise
 *
 * NOTE(review): for is_instance != Qtrue this inspects singleton methods of
 * the modules prepended to `object`'s *instance* side, not modules prepended
 * to `object.singleton_class` -- confirm that is intended.
 */
static VALUE _contrast_check_prepended(VALUE object, VALUE method_name,
                                       VALUE is_instance) {
    if (!RB_TYPE_P(object, T_CLASS) && !RB_TYPE_P(object, T_MODULE)) {
        rb_raise(rb_eTypeError, "Contrast: expected a Class or Module");
    }

    /* get self ancestors */
    VALUE ancestors = rb_mod_ancestors(object);
    long length = RARRAY_LEN(ancestors);

    /* Index of `object` in its ancestors == number of prepended modules.
     * Defaults to 0 (nothing prepended) should it somehow be absent. */
    long object_idx = 0;
    for (long i = 0; i < length; ++i) {
        if (rb_ary_entry(ancestors, i) == object) {
            object_idx = i;
            break;
        }
    }

    /* These C APIs take Ruby-style (argc, argv); this single argument asks
     * for methods from modules mixed into each entry as well, i.e.
     * Module#instance_methods(true) / Object#singleton_methods(true). */
    VALUE include_inherited = Qtrue;

    /* [suspect, suspect, object, ...]: scan only the prepended suspects */
    for (long i = 0; i < object_idx; ++i) {
        VALUE entry = rb_ary_entry(ancestors, i);
        VALUE entry_methods =
            (is_instance == Qtrue)
                ? rb_class_instance_methods(1, &include_inherited, entry)
                : rb_obj_singleton_methods(1, &include_inherited, entry);

        long entry_methods_length = RARRAY_LEN(entry_methods);
        for (long y = 0; y < entry_methods_length; ++y) {
            if (rb_ary_entry(entry_methods, y) == method_name) {
                return Qtrue;
            }
        }
    }
    return Qfalse;
}

void Init_cs__common(void) {
    cs__send_method = rb_intern("send");
    cs__alias_method_sym = ID2SYM(rb_intern("alias_method"));

    /* Define symbols */
    rb_sym_enter_scope = rb_intern("enter_contrast_scope!");
    rb_sym_exit_scope = rb_intern("exit_contrast_scope!");
    rb_sym_in_scope = rb_intern("in_contrast_scope?");
    rb_sym_skip_contrast_analysis = rb_intern("skip_contrast_analysis?");
    rb_sym_skip_assess_analysis = rb_intern("skip_assess_analysis?");
    rb_sym_method = rb_intern("__method__");
    rb_sym_hash_get = rb_intern("[]");
    rb_sym_hash_set = rb_intern("[]=");
    rb_sym_hash_tracked = rb_intern("tracked?");

    /* Used for returning unbound C functions */
    rb_sym_register_c_patch = rb_intern("register_c_patch");
    rb_sym_alias_instance = rb_intern("alias_instance");
    rb_sym_alias_singleton = rb_intern("alias_singleton");
    rb_sym_prepend_instance = rb_intern("prepend_instance");
    rb_sym_prepend_singleton = rb_intern("prepend_singleton");

    /* Ensure definition of core Contrast instrumentation modules */
    contrast = rb_define_module("Contrast");
    agent = rb_define_module_under(contrast, "Agent");

    /* components => Contrast::Components */
    components = rb_define_module_under(contrast, "Components");

    assess = rb_define_module_under(agent, "Assess");

    patching = rb_define_module_under(agent, "Patching");
    policy = rb_define_module_under(patching, "Policy");
    patcher = rb_define_module_under(policy, "Patch");

    assess_policy = rb_define_module_under(assess, "Policy");
    assess_propagator = rb_define_module_under(assess_policy, "Propagator");

    core_extensions = rb_define_module_under(contrast, "Extension");
    core_assess = rb_define_module_under(core_extensions, "Assess");
    /* defined for direct object check */
    rb_define_singleton_method(rb_cObject, "cs__prepended?",
                               contrast_check_prepended, 2);
    /* defined for object lookout */
    rb_define_singleton_method(assess, "cs__object_method_prepended?",
                               contrast_lookout_prepended, 4);
}