/* Copyright (c) 2022 Contrast Security, Inc.  See
 * https://www.contrastsecurity.com/enduser-terms-0317a for more details. */

#include "cs__contrast_patch.h"
#include "../cs__common/cs__common.h"
#include "../cs__scope/cs__scope.h"
#include <ruby.h>

VALUE build_preshift(const VALUE method_policy, const VALUE object,
                     const int argc, const VALUE *params) {
    if (method_policy == Qnil) {
        return Qnil;
    }

    VALUE propagation_node, send, ret;
    propagation_node = rb_funcall(method_policy, rb_sym_propagation_node, 0);
    /* if there isn't a propagation node that applies to this method
     *  we don't need to build a preshift */
    if (propagation_node == Qnil) {
        return Qnil;
    }
    send = rb_ary_new_from_values(argc, params);
    ret = rb_funcall(preshift_class, rb_sym_build_preshift, 3, propagation_node,
                     object, send);
    return ret;
}

VALUE contrast_patch_call_original(const VALUE *args) {
    int argc;
    VALUE method, method_id, object;
    VALUE *params;
    argc = NUM2INT(args[0]);
    params = (VALUE *)args[1];
    object = args[2];
    method = args[3];
    method_id = SYM2ID(method);

/* It looks like we can find the last Ruby block given so long as we don't
 * change Ruby method scope (always call this function from C, not Ruby),
 * which is the point of this C call.
 */
/* Ruby >= 2.7 */
#ifdef RB_PASS_CALLED_KEYWORDS
    if (rb_block_given_p()) {
        return rb_funcall_with_block_kw(object, method_id, argc, params,
                                        rb_block_proc(),
                                        RB_PASS_CALLED_KEYWORDS);
    } else {
        return rb_funcallv_kw(object, method_id, argc, params,
                              RB_PASS_CALLED_KEYWORDS);
    }
/* Ruby < 2.7 */
#else
    if (rb_block_given_p()) {
        return rb_funcall_with_block(object, method_id, argc, params,
                                     rb_block_proc());
    } else {
        return rb_funcall2(object, method_id, argc, params);
    }
#endif
}

/* Run Contrast pre-call analysis (tracking, triggering, propagation).
 *
 * Forwards to
 *   contrast_patcher().apply_pre_patch(method_policy, method, exception,
 *                                      object, [*params])
 * Note that the Ruby-side argument order differs from this function's
 * parameter order. Returns whatever apply_pre_patch returns.
 *
 * The Ruby method is responsible for exiting scope itself if it raises a
 * SecurityException; revisit this if rescue handling moves into C.
 */
VALUE contrast_call_pre_patch(const VALUE method_policy, const VALUE method,
                              const VALUE object, const int count,
                              const VALUE *params, const VALUE exception) {
    const VALUE args_ary = rb_ary_new_from_values(count, params);
    const VALUE patcher = contrast_patcher();

    return rb_funcall(patcher, rb_sym_contrast_apply_pre_patch, 5,
                      method_policy, method, exception, object, args_ary);
}

/* Run Contrast post-call analysis.
 *
 * Calls
 *   contrast_patcher().apply_post_patch(method_policy, preshift, object,
 *                                       ret, [*params], block)
 * where block is the current block as a Proc, or nil if none was given.
 * The result is returned as-is; contrast_run_patches treats a non-nil
 * result as a replacement return value.
 */
VALUE contrast_call_post_patch(const VALUE method_policy, const VALUE preshift,
                               const VALUE object, const VALUE ret,
                               const int count, const VALUE *params) {
    const VALUE args_ary = rb_ary_new_from_values(count, params);
    const VALUE blk = rb_block_given_p() ? rb_block_proc() : Qnil;

    return rb_funcall(contrast_patcher(), rb_sym_contrast_apply_post_patch, 6,
                      method_policy, preshift, object, ret, args_ary, blk);
}

/* rb_rescue handler that immediately re-raises the exception in $!.
 *
 * It performs no analysis of its own; it only lets the intercepted
 * exception keep propagating to the caller. arg1 is ignored. rb_exc_raise
 * does not return, so the trailing return exists only to satisfy the
 * signature.
 */
VALUE rescue_func(VALUE arg1) {
    (void)arg1;
    rb_exc_raise(rb_errinfo()); /* rb_errinfo() is the current $! */
    return Qnil;
}

VALUE contrast_patch_call_ensure(const VALUE *args) {
    int argc;
    VALUE object, preshift, method_policy, method;
    VALUE *argv;

    object = args[0];
    method = args[1];
    argc = NUM2INT(args[2]);
    argv = (VALUE *)args[3];
    method_policy = args[4];
    preshift = args[5];

    contrast_call_post_patch(method_policy, preshift, object, Qnil, argc, argv);

    return Qnil;
}

VALUE ensure_wrapper(const VALUE *args) {
    VALUE original_method, original_args, ensure_args;

    original_method = args[0];
    original_args = (VALUE)args[1];
    ensure_args = (VALUE)args[2];

    return rb_ensure(original_method, original_args, contrast_patch_call_ensure,
                     (VALUE)ensure_args);
}

/* Call the next implementation up the ancestor chain (prepend patches).
 *
 * args[0] is an Integer argc and args[1] a VALUE * argv carried in a VALUE
 * slot; any further slots are ignored. Returns super's return value.
 */
VALUE contrast_call_super(const VALUE *args) {
    const int count = NUM2INT(args[0]);
    const VALUE *call_argv = (const VALUE *)args[1];

    return rb_call_super(count, call_argv);
}

VALUE contrast_run_patches(const VALUE *wrapped_args) {
    VALUE impl, method, method_policy, object, original_args, original_ret,
        preshift, transformed_ret;
    int argc;
    VALUE *argv;
    VALUE ensure_args[6];
    VALUE rescue_wrapper_args[3];

    impl = wrapped_args[0];
    original_args = wrapped_args[1];
    method = wrapped_args[2];
    method_policy = wrapped_args[3];
    object = wrapped_args[4];
    argc = NUM2INT(wrapped_args[5]);
    argv = (VALUE *)wrapped_args[6];

    rescue_wrapper_args[0] = contrast_patch_call_original;
    rescue_wrapper_args[1] = original_args;
    rescue_wrapper_args[2] = ensure_args;

    ensure_args[0] = object;
    ensure_args[1] = method;
    ensure_args[2] = INT2NUM(argc);
    ensure_args[3] = (VALUE)argv;
    ensure_args[4] = method_policy;

    /* Tracking, triggering, and propagation here. */
    contrast_call_pre_patch(method_policy, method, object, argc, argv, Qnil);

    /* Capture pre-call state */
    preshift = build_preshift(method_policy, object, argc, argv);
    ensure_args[5] = preshift;

    /* We wrap a call to the original method with a rescue block, and we use
     * rb_rescue2 to capture all Exception-inheriting exceptions (and if your
     * software is well-behaved, all exceptions should inherit from Exception.)
     *
     * The rescue block is responsible for doing Contrast post-call analysis
     * in the event the original method has thrown an exception.
     *
     * EDGE CASES:
     * Given how extensively we patch and instrument, this code is
     * prone to some esoteric edge cases that are not well-documented or
     * easy to research.
     *
     * There is an esoteric edge case in core Ruby, upon Thread#kill, where
     * it raises Fixnum 8 (Qnil==8).  This is an intentional choice on the
     * part of the core Ruby devs, as blindly rescuing Thread#kill would be
     * disastrous.
     * A consequence of this is that Thread#kill will leak scope, if you
     * happen to ever instrument it.
     *
     * If you are within a catch block, and the original function results
     * in a throw, you will leak scope.  We handle this by not instrumenting
     * methods that do that.  (Tracked in RUBY-552.)
     *
     * If you're thinking of cleaning this up by using rb_protect,
     * you will catch ALL exceptions, as well as ANYTHING
     * else that unwinds the stack.  This includes fiber context switches
     * (which are used to implement Enumerator#next) and catch/throw blocks.
     * I spent a week debugging that so you don't have to.  -ajm
     */

    switch (impl) {
    case IMPL_ALIAS_INSTANCE:
    case IMPL_ALIAS_SINGLETON:
        original_ret =
            rb_rescue(ensure_wrapper, rescue_wrapper_args, rescue_func, Qnil);
        break;
    case IMPL_PREPEND_INSTANCE:
    case IMPL_PREPEND_SINGLETON:
        rescue_wrapper_args[0] = contrast_call_super;
        original_ret =
            rb_rescue(ensure_wrapper, rescue_wrapper_args, rescue_func, Qnil);
        break;
    };

    /* If you're here, the original method did not throw an exception
     * (or unwind the stack otherwise).
     * If the original method threw an exception, contrast_patch_call_rescue
     * re-raises the original exception, which unwinds the stack back to the
     * call site.  This means the rest of this function is not executed.
     */

    /* Invoke Contrast post-call patching.
     * Post-call patching may transform the return value,
     * hence the assignment.
     */
    transformed_ret = contrast_call_post_patch(method_policy, preshift, object,
                                               original_ret, argc, argv);

    /* Special case for tracking frozen sources */
    if (transformed_ret != Qnil) {
        return transformed_ret;
    } else {
        return original_ret;
    }
}

/* Leave the scopes that contrast_patch_dispatch entered.
 *
 * Exits in reverse order of entry: first the policy-specific scopes from
 * method_policy.scopes_to_exit, then the general Contrast scope. Used as
 * the rb_ensure handler on the analysis path and called directly on the
 * pass-through path. Always returns Qnil.
 */
VALUE contrast_ensure_function(const VALUE method_policy) {
    const VALUE to_exit =
        rb_funcall(method_policy, rb_sym_scopes_to_exit, 0);

    inst_methods_exit_method_scope(contrast_patcher(), to_exit);
    inst_methods_exit_cntr_scope(contrast_patcher(), 0);

    return Qnil;
}

VALUE contrast_patch_dispatch(const int argc, const VALUE *argv,
                              const patch_impl impl, const VALUE object) {
    VALUE cs__method, known, method, method_policy;
    VALUE original_args[4];
    int do_contrast, nested_scope;

    /* Do Contrast analysis, unless our subsequent checks tell us no. */
    do_contrast = 1;

    /* Before we enter scope, see if we're already in scope.
     * We only want to Contrast methods called out of all scope.
     * Otherwise, we do things like propagate within propagators,
     * which is unnecessary, or run Contrast analysis on Contrast code,
     * which will never terminate.
     */
    nested_scope = inst_methods_in_cntr_scope(contrast_patcher(), 0);

    /* enter scope */
    inst_methods_enter_cntr_scope(contrast_patcher(), 0);

    /* Get the name of the calling method */
    method = rb_funcall(object, rb_sym_method, 0);

    /* Retrieve from Ruby Patcher class
     * [method_policy object, Symbol representing original method]
     */
    switch (impl) {
    case IMPL_ALIAS_INSTANCE:
    case IMPL_PREPEND_INSTANCE:
        known =
            rb_funcall(patch_status, rb_sym_info_for, 3, object, method, Qtrue);
        break;
    case IMPL_PREPEND_SINGLETON:
        known = rb_funcall(patch_status, rb_sym_info_for, 3, object, method,
                           Qfalse);
        break;
    case IMPL_ALIAS_SINGLETON:
        known = rb_funcall(patch_status, rb_sym_info_for, 3, object, method,
                           Qfalse);
        break;
    }

    /* Index into above array & retrieve contents */
    if (RTEST(known)) {
        method_policy = rb_funcall(known, rb_sym_brackets, 1, INT2NUM(0));
    } else {
        method_policy = Qnil;
    }

    /* Check conditions for not doing Contrast analysis */
    if (nested_scope == Qtrue) {
        /* if we were in scope */
        do_contrast = 0;
    } else if (!RTEST(known)) {
        /* nothing to be done with entirely unknown method*/
        do_contrast = 0;
    } else if (!RTEST(method_policy)) {
        /* nothing to be done without a method policy */
        do_contrast = 0;
    }

    original_args[0] = INT2NUM(argc);
    original_args[1] = (VALUE)argv;
    original_args[2] = object;

    if (impl == IMPL_ALIAS_INSTANCE || impl == IMPL_ALIAS_SINGLETON) {
        /* Alias patching moves the original method to "cs__#{method}" */
        cs__method = rb_funcall(known, rb_sym_brackets, 1, INT2NUM(1));

        /* We may not have built the alias yet */
        if (!RTEST(cs__method)) {
            cs__method =
                rb_funcall(contrast_patcher(), rb_sym_build_method_name, 2,
                           object, method);
        }
        original_args[3] = cs__method;
    }

    /* Enter any scopes specific to method policy */
    VALUE scopes = rb_funcall(method_policy, rb_sym_scopes_to_enter, 0);

    inst_methods_enter_method_scope(contrast_patcher(), scopes);

    /* If we're not doing Contrast analysis, exit scope and treat as normal. */
    if (!do_contrast) {
        goto call_original;
    }

    /* Otherwise, invoke Contrast analysis. */
    VALUE wrapped_args[7];
    wrapped_args[0] = impl;
    wrapped_args[1] = (VALUE)original_args;
    wrapped_args[2] = method;
    wrapped_args[3] = method_policy;
    wrapped_args[4] = object;
    wrapped_args[5] = INT2NUM(argc);
    wrapped_args[6] = (VALUE)argv;

    return rb_ensure(contrast_run_patches, (VALUE)wrapped_args,
                     contrast_ensure_function, method_policy);

call_original:

    /* exit scope */
    contrast_ensure_function(method_policy);

    switch (impl) {
    case IMPL_ALIAS_INSTANCE:
    case IMPL_ALIAS_SINGLETON:
        return contrast_patch_call_original(original_args);
    case IMPL_PREPEND_INSTANCE:
    case IMPL_PREPEND_SINGLETON:
        return contrast_call_super(original_args);
    };
}

VALUE contrast_alias_instance_patch(const int argc, const VALUE *argv,
                                    const VALUE object) {
    return contrast_patch_dispatch(argc, argv, IMPL_ALIAS_INSTANCE, object);
}

VALUE contrast_alias_singleton_patch(const int argc, const VALUE *argv,
                                     const VALUE object) {
    return contrast_patch_dispatch(argc, argv, IMPL_ALIAS_SINGLETON, object);
}

VALUE contrast_prepend_instance_patch(const int argc, const VALUE *argv,
                                      const VALUE object) {
    return contrast_patch_dispatch(argc, argv, IMPL_PREPEND_INSTANCE, object);
}

VALUE contrast_prepend_singleton_patch(const int argc, const VALUE *argv,
                                       const VALUE object) {
    return contrast_patch_dispatch(argc, argv, IMPL_PREPEND_SINGLETON, object);
}

/* Install a generic alias-style Contrast patch for one method.
 *
 * Records the policy with
 *   patch_status.set_info_for(clazz, method_name, method_policy,
 *                             instance?, cs_method).
 * Unless the policy requires a custom C-level patch, it then aliases the
 * original method to cs_method and redefines the original name to dispatch
 * through contrast_alias_instance_patch (instance methods) or
 * contrast_alias_singleton_patch (singleton methods). If the original is
 * private, the alias is made public and the patched name is made private
 * again. `self` (the patcher module) is unused. Always returns Qnil.
 */
VALUE contrast_patch_define_method(const VALUE self, const VALUE clazz,
                                   const VALUE method_policy,
                                   const VALUE cs_method) {
    const VALUE original_method_name =
        rb_funcall(method_policy, rb_sym_method_name, 0);
    const VALUE is_instance_method =
        rb_funcall(method_policy, rb_sym_instance_method, 0);
    char *cStr;
    VALUE str;
    rb_funcall(patch_status, rb_sym_set_info_for, 5, clazz,
               original_method_name, method_policy, is_instance_method,
               cs_method);

    /* Some methods we patch rely on a specific C-level patch. In those cases
     * the method is still recorded in the info_for hash on the class, but
     * the generic Contrast monkeypatch is not set up.
     */
    VALUE is_custom_patch = rb_funcall(method_policy, rb_sym_custom_patch, 0);
    if (RTEST(is_custom_patch)) {
        return Qnil;
    }
    str = rb_funcall(original_method_name, rb_sym_cs_to_s, 0);
    /* cStr points into str's buffer. The rb_funcall/rb_define_* calls below
     * can allocate and trigger GC, so str is kept alive with RB_GC_GUARD at
     * the end of this function. */
    cStr = StringValueCStr(str);

    const VALUE is_private =
        rb_funcall(method_policy, rb_sym_private_method, 0);
    if (RTEST(is_instance_method)) {
        rb_funcall(clazz, rb_sym_alias_method, 2, cs_method,
                   original_method_name);
        if (RTEST(is_private)) {
            rb_funcall(clazz, rb_sym_public, 1, cs_method);
        }
        rb_define_method(clazz, cStr, contrast_alias_instance_patch, -1);
        if (RTEST(is_private)) {
            rb_funcall(clazz, rb_sym_private, 1, original_method_name);
        }
    } else {
        /* Singleton (class-level) methods live on the singleton class, so the
         * alias and visibility changes are sent there; the patched method is
         * installed with rb_define_singleton_method.
         * https://stackoverflow.com/questions/212407/what-exactly-is-the-singleton-class-in-ruby
         */
        const VALUE singleton_class =
            rb_funcall(clazz, rb_sym_cs_singleton_class, 0);
        rb_funcall(singleton_class, rb_sym_alias_method, 2, cs_method,
                   original_method_name);
        if (RTEST(is_private)) {
            rb_funcall(singleton_class, rb_sym_public, 1, cs_method);
        }
        rb_define_singleton_method(clazz, cStr, contrast_alias_singleton_patch,
                                   -1);
        if (RTEST(is_private)) {
            rb_funcall(singleton_class, rb_sym_private, 1,
                       original_method_name);
        }
    }
    RB_GC_GUARD(str);
    return Qnil;
}

/* Install a prepend-style Contrast patch for one method.
 *
 * Records the policy with patch_status.set_info_for (there is no aliased
 * method, so the last argument is nil). It then defines the patch on
 * originalModule::ContrastPrepend, created or reused via
 * rb_define_module_under, and prepends that module to originalModule.
 * Instance methods get contrast_prepend_instance_patch, made private when
 * the policy says the original is private. Other methods get
 * contrast_prepend_singleton_patch as a singleton method of ContrastPrepend.
 *
 * When rb_ver_below_three() is true, ContrastPrepend is also included into
 * every T_MODULE listed by originalModule.included_in. This presumably
 * compensates for pre-3.0 Ruby, where changes to a module's ancestors did not
 * reach modules that had already included it (see Ruby 3.0 NEWS); confirm.
 *
 * `self` (the patcher module) is unused. Always returns Qtrue.
 */
VALUE contrast_patch_prepend(const VALUE self, const VALUE originalModule,
                             const VALUE method_policy) {
    const VALUE original_method_name =
        rb_funcall(method_policy, rb_sym_method_name, 0);
    const VALUE is_private =
        rb_funcall(method_policy, rb_sym_private_method, 0);
    const VALUE is_instance_method =
        rb_funcall(method_policy, rb_sym_instance_method, 0);
    VALUE module, str;
    const char *cMethodName;

    /* 4th argument flags instance (Qtrue) vs. singleton (Qfalse). */
    rb_funcall(patch_status, rb_sym_set_info_for, 5, originalModule,
               original_method_name, method_policy,
               RTEST(is_instance_method) ? Qtrue : Qfalse, Qnil);

    module = rb_define_module_under(originalModule, "ContrastPrepend");
    str = rb_funcall(original_method_name, rb_sym_cs_to_s, 0);
    /* cMethodName points into str's buffer. The rb_define_* calls below may
     * allocate (e.g. interning the name), so str is guarded afterwards. */
    cMethodName = StringValueCStr(str);
    if (RTEST(is_instance_method)) {
        if (RTEST(is_private)) {
            rb_define_private_method(module, cMethodName,
                                     contrast_prepend_instance_patch, -1);
        } else {
            rb_define_method(module, cMethodName,
                             contrast_prepend_instance_patch, -1);
        }
    } else {
        /* NOTE(review): this defines ContrastPrepend.<name>, a singleton
         * method of the prepended module itself; confirm callers pass the
         * singleton class as originalModule for class-level methods. */
        rb_define_singleton_method(module, cMethodName,
                                   contrast_prepend_singleton_patch, -1);
    }
    RB_GC_GUARD(str);
    rb_prepend_module(originalModule, module);

    if (rb_ver_below_three()) {
        const VALUE includers =
            rb_funcall(originalModule, rb_intern("included_in"), 0);
        if (RB_TYPE_P(includers, T_ARRAY)) {
            /* RARRAY_LEN yields a long; keep it long rather than truncating
             * to int. */
            const long size = RARRAY_LEN(includers);
            long i;
            for (i = 0; i < size; ++i) {
                const VALUE module_at = rb_ary_entry(includers, i);
                if (RB_TYPE_P(module_at, T_MODULE)) {
                    rb_include_module(module_at, module);
                }
            }
        }
    }
    return Qtrue;
}

/* Extension initializer for cs__contrast_patch.
 *
 * 1. Interns the method IDs this extension sends.
 * 2. Registers the patch installers as module functions on
 *    contrast_patcher():
 *      contrast_define_method(clazz, method_policy, cs_method)
 *      contrast_prepend_method(module, method_policy)
 * 3. Resolves Contrast::Agent::Patching::Policy::PatchStatus and
 *    Contrast::Agent::Assess::PreShift under the `policy` and `assess`
 *    namespaces (rb_define_class_under returns an existing class of that
 *    name if one is already defined).
 */
void Init_cs__contrast_patch(void) {
    /* Messages sent to method_policy objects. */
    rb_sym_custom_patch = rb_intern("requires_custom_patch?");
    rb_sym_instance_method = rb_intern("instance_method");
    rb_sym_method_name = rb_intern("method_name");
    rb_sym_private_method = rb_intern("private_method?");
    rb_sym_propagation_node = rb_intern("propagation_node");
    rb_sym_scopes_to_enter = rb_intern("scopes_to_enter");
    rb_sym_scopes_to_exit = rb_intern("scopes_to_exit");

    /* Messages sent to the patcher, PatchStatus, and PreShift. */
    rb_sym_build_method_name = rb_intern("build_method_name");
    rb_sym_build_preshift = rb_intern("build_preshift");
    rb_sym_contrast_apply_post_patch = rb_intern("apply_post_patch");
    rb_sym_contrast_apply_pre_patch = rb_intern("apply_pre_patch");
    rb_sym_info_for = rb_intern("info_for");
    rb_sym_set_info_for = rb_intern("set_info_for");

    /* Method-scope messages (not sent directly in this file; presumably
     * used by the scope helpers). */
    rb_sym_enter_method_scope = rb_intern("enter_method_scope!");
    rb_sym_exit_method_scope = rb_intern("exit_method_scope!");

    /* General Ruby messages. */
    rb_sym_alias_method = rb_intern("alias_method");
    rb_sym_brackets = rb_intern("[]");
    rb_sym_cs_singleton_class = rb_intern("cs__singleton_class");
    rb_sym_cs_to_s = rb_intern("to_s");
    rb_sym_private = rb_intern("private");
    rb_sym_public = rb_intern("public");

    /* Ruby-callable installers. */
    rb_define_module_function(contrast_patcher(), "contrast_define_method",
                              contrast_patch_define_method, 3);
    rb_define_module_function(contrast_patcher(), "contrast_prepend_method",
                              contrast_patch_prepend, 2);

    /* Contrast::Agent::Patching::Policy::PatchStatus */
    patch_status = rb_define_class_under(policy, "PatchStatus", rb_cObject);
    /* Contrast::Agent::Assess::PreShift */
    preshift_class = rb_define_class_under(assess, "PreShift", rb_cObject);
}