/* Copyright (c) 2023 Contrast Security, Inc.  See
 * https://www.contrastsecurity.com/enduser-terms-0317a for more details. */

#include "cs__assess_fiber_track.h"
#include "../cs__common/cs__common.h"
#include <ruby.h>
#include <string.h>

/* Returns nonzero when obj is a typed-data object whose rb_data_type_t is
 * named "enumerator", the name ruby's enumerator.c gives the data type that
 * wraps its private `struct enumerator`. Enumerator subclasses backed by a
 * different struct (e.g. Enumerator::Chain) fail this check. That keeps the
 * struct reinterpretation in rb_fiber_new_hook from being applied to them.
 * NOTE(review): relies on the type name used by ruby's enumerator.c;
 * re-check when adding support for new ruby versions.
 */
static int wraps_ruby_enumerator_struct(VALUE obj) {
    const char *type_name;

    if (!RB_TYPE_P(obj, T_DATA) || !RTYPEDDATA_P(obj)) {
        return 0;
    }
    type_name = RTYPEDDATA_TYPE(obj)->wrap_struct_name;
    return type_name != NULL && strcmp(type_name, "enumerator") == 0;
}

/* Hook installed over rb_fiber_new (see install_fiber_hooks).
 *
 * Always creates the fiber through rb_fiber_new_original and returns it.
 * Tracking only happens when both of the following hold:
 *   - obj is an Enumerator whose data is ruby's `struct enumerator`;
 *   - the most proximate Ruby method asking for the fiber (per __method__)
 *     is #next.
 * In that case it reports the fiber, the enumerator, the enumerator's method
 * (as a Symbol), and the enumerator's underlying object to
 * FiberPropagator#track_rb_fiber_new.
 *
 * @param func  fiber body; forwarded unchanged to rb_fiber_new_original
 * @param obj   argument for func; forwarded unchanged, and inspected to
 *              decide whether the fiber belongs to an Enumerator
 * @return the fiber created by rb_fiber_new_original
 */
VALUE rb_fiber_new_hook(VALUE (*func)(ANYARGS), VALUE obj) {
    /* Truncated copy of the leading fields of `struct enumerator` from
     * ruby's enumerator.c, which no public header exposes. VALUE and ID are
     * both pointer-sized, so under the same ABI these fields sit at the same
     * offsets as in ruby's full definition. The real risk is ruby reordering
     * that struct in a future version.
     */
    struct enumerator {
        VALUE obj;
        ID meth;
    };

    VALUE fiber = rb_fiber_new_original(func, obj);
    VALUE calling_method;
    struct enumerator *enum_ptr;
    VALUE underlying;
    VALUE enumerator_method;

    /* Check the cheap, local condition first. Unrelated fibers then skip the
     * __method__ call entirely. The typed-data check guarantees DATA_PTR
     * really points at a `struct enumerator` before it is reinterpreted.
     */
    if (!RTEST(rb_obj_is_kind_of(obj, rb_cEnumerator)) ||
        !wraps_ruby_enumerator_struct(obj)) {
        return fiber;
    }

    /* The most proximate Ruby method asking for this fiber; only #next is
     * interesting. __method__ yields nil when there is no enclosing method,
     * and SYM2ID(nil) would raise TypeError, so require a Symbol first.
     */
    calling_method = rb_funcall(rb_cObject, rb_intern("__method__"), 0);
    if (!SYMBOL_P(calling_method) || SYM2ID(calling_method) != rb_sym_next) {
        return fiber;
    }

    /* TypedData_Get_Struct cannot be used: enumerator_data_type is static
     * in enumerator.c. The typed-data name was verified above instead.
     */
    enum_ptr = (struct enumerator *)DATA_PTR(obj);

    /* The object the enumerator operates upon, e.g. 1..100 ... */
    underlying = enum_ptr->obj;
    /* ... and the method it uses to do so, e.g. :each_value. */
    enumerator_method = ID2SYM(enum_ptr->meth);

    rb_funcall(fiber_propagator, track_rb_fiber_new, 5, fiber, obj,
               enumerator_method, underlying, calling_method);

    return fiber;
}

VALUE rb_fiber_yield_hook(int argc, const VALUE *argv) {
    VALUE calling_method = rb_funcall(rb_cObject, rb_intern("__method__"), 0);
    VALUE yielding_fiber = rb_fiber_current();

    /* propagate from yielding_fiber -> result */
    rb_funcall(fiber_propagator, track_rb_fiber_yield, 3, yielding_fiber,
               calling_method, *argv);

    return rb_fiber_yield_original(argc, argv);
}

/* Redirects rb_fiber_new and rb_fiber_yield to rb_fiber_new_hook and
 * rb_fiber_yield_hook via patch_via_funchook.
 *
 * Each *_original pointer is seeded with the real function's address before
 * that pointer's address is handed to patch_via_funchook. The hooks call
 * through these pointers, so the seeding must come first.
 * NOTE(review): patch_via_funchook presumably rewrites the pointer to a
 * trampoline reaching the unpatched code; confirm in cs__common.
 * Results of patch_via_funchook are not inspected, so this returns 0.
 */
int install_fiber_hooks(void) {
    const int status = 0;

    /* rb_fiber_new -> rb_fiber_new_hook */
    rb_fiber_new_original = rb_fiber_new;
    (void)patch_via_funchook(&rb_fiber_new_original, &rb_fiber_new_hook);

    /* rb_fiber_yield -> rb_fiber_yield_hook */
    rb_fiber_yield_original = rb_fiber_yield;
    (void)patch_via_funchook(&rb_fiber_yield_original, &rb_fiber_yield_hook);

    return status;
}

/* Extension entry point for cs__assess_fiber_track.
 *
 * Order matters. First the FiberPropagator class handle and the interned
 * IDs that the hooks read are set up. Only then are the rb_fiber_new /
 * rb_fiber_yield hooks installed. That way no hook runs against unset
 * globals.
 */
void Init_cs__assess_fiber_track(void) {
    /* core_assess::FiberPropagator receives the tracking callbacks. */
    fiber_propagator = rb_define_class_under(core_assess, "FiberPropagator",
                                             rb_cObject);

    /* Names of the methods the hooks invoke on fiber_propagator. */
    track_rb_fiber_new = rb_intern("track_rb_fiber_new");
    track_rb_fiber_yield = rb_intern("track_rb_fiber_yield");

    /* Despite its name, rb_sym_next holds an ID (from rb_intern), not a
     * Symbol VALUE. rb_fiber_new_hook compares it against SYM2ID(...).
     */
    rb_sym_next = rb_intern("next");

    /* install_fiber_hooks always returns 0; the result is deliberately
     * discarded.
     */
    (void)install_fiber_hooks();
}