// Copyright (C) 2016  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.

#include "subprocess_stream.h"

#include <fcntl.h>
#include <signal.h>
#include <sys/select.h>
#include <sys/wait.h>
#include <unistd.h>

#include <algorithm>
#include <cerrno>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <sstream>
#include <utility>

#include "call_matlab.h"

using namespace std;

// ----------------------------------------------------------------------------------------

namespace dlib
{

// ----------------------------------------------------------------------------------------

    // Put fd into non-blocking mode by adding O_NONBLOCK to its file status flags.
    //
    // If the current flags can't be queried (for example because fd isn't an open
    // descriptor) the call is a no-op.  Without that check the -1 error value would be
    // OR-ed with O_NONBLOCK and handed to F_SETFL as if it were a real flag set.
    void make_fd_non_blocking(int fd)
    {
        const int flags = fcntl(fd, F_GETFL, 0);
        if (flags == -1)
            return;

        // Nothing to do if the descriptor is already non-blocking.
        if ((flags & O_NONBLOCK) == 0)
            fcntl(fd, F_SETFL, flags | O_NONBLOCK);
    }

// ----------------------------------------------------------------------------------------

    // Block until fd is ready to read, while also echoing whatever arrives on fd_printf
    // to cout.
    //
    // Returns 0 once select() reports fd as readable.  Returns 1 if select() or the read
    // from fd_printf fails, or if check_for_matlab_ctrl_c() throws.  If fd_printf < 0
    // there is nothing to echo and the function returns 0 immediately without waiting.
    int read_echoing_select(int fd, int fd_printf)
    {
        if (fd_printf < 0)
            return 0;

        // Set to false once fd_printf hits end-of-file.  A descriptor at EOF is always
        // "readable", so leaving it in the set would turn this loop into a busy spin
        // until fd finally has data.
        bool echo_open = true;

        while(true)
        {
            fd_set rfds;
            FD_ZERO(&rfds);
            FD_SET(fd, &rfds);
            if (echo_open)
                FD_SET(fd_printf, &rfds);

            // select times out every second just so we can check for matlab ctrl+c.
            struct timeval tv;
            tv.tv_sec = 1;
            tv.tv_usec = 0;

            try{check_for_matlab_ctrl_c();} catch(...) { return 1; }
            const int retval = select(std::max(fd,fd_printf)+1, &rfds, NULL, NULL, &tv);
            // Save errno now, before the ctrl+c check below gets a chance to change it.
            const int select_errno = errno;
            try{check_for_matlab_ctrl_c();} catch(...) { return 1; }

            if (retval == 0) // keep going if it was just a timeout.
                continue;
            if (retval == -1)
            {
                if (select_errno == EINTR)
                    continue;
                return 1;
            }

            if (FD_ISSET(fd,&rfds))
                return 0;

            if (!FD_ISSET(fd_printf,&rfds))
                continue;

            char buf[1024];
            const ssize_t num = read(fd_printf, buf, sizeof(buf));
            if (num == -1)
            {
                // An interrupted or would-block read isn't a failure, just try again.
                if (errno == EINTR || errno == EAGAIN || errno == EWOULDBLOCK)
                    continue;
                return 1;
            }
            if (num == 0)
            {
                echo_open = false;
                continue;
            }

            // Use write() rather than operator<< so embedded NUL bytes don't truncate
            // the echoed text.
            std::cout.write(buf, num);
            std::cout.flush();
        }
    }

    // Block until fd is ready to be written to, while also echoing whatever arrives on
    // fd_printf to cout.
    //
    // Returns 0 once select() reports fd as writable.  Returns 1 if select() or the read
    // from fd_printf fails, or if check_for_matlab_ctrl_c() throws.  If fd_printf < 0
    // there is nothing to echo and the function returns 0 immediately without waiting.
    int write_echoing_select(int fd, int fd_printf)
    {
        if (fd_printf < 0)
            return 0;

        // Set to false once fd_printf hits end-of-file.  A descriptor at EOF is always
        // "readable", so leaving it in the set would turn this loop into a busy spin
        // until fd finally becomes writable.
        bool echo_open = true;

        while(true)
        {
            fd_set rfds, wfds;
            FD_ZERO(&rfds);
            FD_ZERO(&wfds);
            FD_SET(fd, &wfds);
            if (echo_open)
                FD_SET(fd_printf, &rfds);

            // select times out every second just so we can check for matlab ctrl+c.
            struct timeval tv;
            tv.tv_sec = 1;
            tv.tv_usec = 0;

            try{check_for_matlab_ctrl_c();} catch(...) { return 1; }
            const int retval = select(std::max(fd,fd_printf)+1, &rfds, &wfds, NULL, &tv);
            // Save errno now, before the ctrl+c check below gets a chance to change it.
            const int select_errno = errno;
            try{check_for_matlab_ctrl_c();} catch(...) { return 1; }

            if (retval == 0) // keep going if it was just a timeout.
                continue;
            if (retval == -1)
            {
                if (select_errno == EINTR)
                    continue;
                return 1;
            }

            if (FD_ISSET(fd,&wfds))
                return 0;

            if (!FD_ISSET(fd_printf,&rfds))
                continue;

            char buf[1024];
            const ssize_t num = read(fd_printf, buf, sizeof(buf));
            if (num == -1)
            {
                // An interrupted or would-block read isn't a failure, just try again.
                if (errno == EINTR || errno == EAGAIN || errno == EWOULDBLOCK)
                    continue;
                return 1;
            }
            if (num == 0)
            {
                echo_open = false;
                continue;
            }

            // Use write() rather than operator<< so embedded NUL bytes don't truncate
            // the echoed text.
            std::cout.write(buf, num);
            std::cout.flush();
        }
    }

// ----------------------------------------------------------------------------------------

    class filestreambuf : public std::streambuf
    {
        /*!
            WHAT THIS OBJECT REPRESENTS
                A std::streambuf that performs buffered reads and writes on the file
                descriptor fd.  Before every read()/write() on fd it calls
                read_echoing_select()/write_echoing_select() with (fd, fd_printf).

            INITIAL VALUE
                - fd == the file descriptor we read from and write to.
                - fd_printf == the second descriptor given to the *_echoing_select()
                  helpers (or a negative value for none).
                - in_buffer == an array of in_buffer_size bytes
                - out_buffer == an array of out_buffer_size bytes

            CONVENTION
                - in_buffer == the input buffer used by this streambuf
                - out_buffer == the output buffer used by this streambuf.  The put area
                  is one byte shorter than the array so overflow() always has a spare
                  slot for its character.
                - max_putback == the maximum number of chars to have in the put back buffer.
        !*/

    public:

        filestreambuf (
            int fd_,
            int fd_printf_
        ) :
            fd(fd_),
            fd_printf(fd_printf_),
            out_buffer(0),
            in_buffer(0)
        {
            init();
        }

        // This object owns two raw arrays, so copying it would lead to a double delete.
        filestreambuf (const filestreambuf&) = delete;
        filestreambuf& operator= (const filestreambuf&) = delete;

        virtual ~filestreambuf (
        )
        {
            sync();
            delete [] out_buffer;
            delete [] in_buffer;
        }

        // Writes any buffered output to fd.  Returns 0 on success and -1 on failure.
        int sync (
        ) override
        {
            if (flush_out_buffer() == EOF)
            {
                // an error occurred
                return -1;
            }
            return 0;
        }
    protected:

        // Allocates both buffers and sets up the get/put areas.  If the second
        // allocation throws, the first buffer is released before rethrowing.
        void init (
        )
        {
            try
            {
                out_buffer = new char[out_buffer_size];
                in_buffer = new char[in_buffer_size];
            }
            catch (...)
            {
                if (out_buffer) delete [] out_buffer;
                throw;
            }
            // Leave one byte of out_buffer outside the put area for overflow().
            setp(out_buffer, out_buffer + (out_buffer_size-1));
            setg(in_buffer+max_putback, 
                in_buffer+max_putback, 
                in_buffer+max_putback);
        }

        // Writes everything in the put area to fd and empties the put area.  Returns the
        // number of bytes written, or EOF if waiting for or writing to fd failed.
        int flush_out_buffer (
        )
        {
            int num = static_cast<int>(pptr()-pbase());
            const int num_written = num;
            const char* buf = out_buffer;
            while(num != 0)
            {
                if(write_echoing_select(fd, fd_printf))
                    return EOF;
                const ssize_t status = write(fd,buf,static_cast<size_t>(num));
                if (status < 0)
                {
                    if (should_retry_io(errno))
                        continue;
                    // the write was not successful so return EOF 
                    return EOF;
                } 
                num -= static_cast<int>(status);
                buf += status;
            }
            pbump(-num_written);
            return num_written;
        }

        // output functions

        // Appends c (unless it is EOF) to the put area and flushes.  Returns EOF on
        // failure and a value other than EOF on success.
        int_type overflow (
            int_type c
        ) override
        {
            const bool have_char = (c != EOF);
            if (have_char)
            {
                *pptr() = traits_type::to_char_type(c);
                pbump(1);
            }
            if (flush_out_buffer() == EOF)
            {
                // Take the character back out.  Otherwise pptr() would sit past the
                // spare slot and the next overflow()/xsputn() would run off the end of
                // out_buffer.
                if (have_char)
                    pbump(-1);
                return EOF;
            }
            // overflow() must not report EOF on success, even when called with c == EOF.
            return traits_type::not_eof(c);
        }


        // Writes num bytes from s.  Returns num on success and 0 if fd could not be
        // written to.
        std::streamsize xsputn (
            const char* s,
            std::streamsize num
        ) override
        {
            // Add a sanity check here 
            DLIB_ASSERT(num >= 0,
                "\tstd::streamsize filestreambuf::xsputn"
                << "\n\tThe number of bytes to write can't be negative"
                << "\n\tnum:  " << num 
                << "\n\tthis: " << this
            );

            const std::streamsize space_left = static_cast<std::streamsize>(epptr()-pptr());
            if (num <= space_left)
            {
                // everything fits into the put area
                std::memcpy(pptr(),s,static_cast<size_t>(num));
                pbump(static_cast<int>(num));
                return num;
            }

            // Fill up the put area, flush it, and then deal with whatever is left.
            std::memcpy(pptr(),s,static_cast<size_t>(space_left));
            s += space_left;
            pbump(static_cast<int>(space_left));
            std::streamsize num_left = num - space_left;

            if (flush_out_buffer() == EOF)
            {
                // the write was not successful so return that 0 bytes were written
                return 0;
            }

            if (num_left < out_buffer_size)
            {
                std::memcpy(pptr(),s,static_cast<size_t>(num_left));
                pbump(static_cast<int>(num_left));
                return num;
            }

            // The remainder is too big to buffer so send it straight to fd.
            while(num_left != 0)
            {
                if(write_echoing_select(fd, fd_printf))
                {
                    // Report failure the same way as the other error paths in this
                    // function, as 0 bytes written rather than a negative count.
                    return 0;
                }
                const ssize_t status = write(fd,s,static_cast<size_t>(num_left));
                if (status < 0)
                {
                    if (should_retry_io(errno))
                        continue;
                    // the write was not successful so return that 0 bytes were written
                    return 0;
                } 
                num_left -= status;
                s += status;
            }
            return num;
        }

        // input functions

        // Refills the get area from fd, keeping up to max_putback previously read
        // characters available for putback.  Returns the next character, or EOF on error
        // or end of input.
        int_type underflow( 
        ) override
        {
            if (gptr() < egptr())
            {
                return static_cast<unsigned char>(*gptr());
            }

            int num_put_back = static_cast<int>(gptr() - eback());
            if (num_put_back > max_putback)
            {
                num_put_back = max_putback;
            }

            // copy the putback characters into the putback end of the in_buffer
            std::memmove(in_buffer+(max_putback-num_put_back), gptr()-num_put_back, num_put_back);

            ssize_t num;
            while (true)
            {
                if (read_echoing_select(fd, fd_printf))
                    return EOF;
                num = read(fd,in_buffer+max_putback, static_cast<size_t>(in_buffer_size-max_putback));
                if (num >= 0 || !should_retry_io(errno))
                    break;
            }
            if (num <= 0)
            {
                // an error occurred or the connection is over which is EOF
                return EOF;
            }

            // reset in_buffer pointers
            setg (in_buffer+(max_putback-num_put_back),
                in_buffer+max_putback,
                in_buffer+max_putback+num);

            return static_cast<unsigned char>(*gptr());
        }

        // Reads up to n bytes into s.  Returns the number of bytes actually read, which
        // is less than n only if underflow() reported EOF.
        std::streamsize xsgetn (
            char_type* s, 
            std::streamsize n
        ) override
        { 
            const std::streamsize requested = n;
            while (n > 0)
            {
                const int num = static_cast<int>(egptr() - gptr());
                if (num >= n)
                {
                    // copy data from our buffer 
                    std::memcpy(s, gptr(), static_cast<size_t>(n));
                    gbump(static_cast<int>(n));
                    return requested;
                }

                // read more data into our buffer  
                if (num == 0)
                {
                    if (underflow() == EOF)
                        break;
                    continue;
                }

                // copy all the data from our buffer 
                std::memcpy(s, gptr(), static_cast<size_t>(num));
                n -= num;
                gbump(num);
                s += num;
            }
            return requested-n;       
        }

    private:

        // Returns true if a read()/write() that failed with errno value err should just
        // be attempted again.  An interrupted call is always retried.  A would-block
        // result is retried only when fd_printf >= 0, because the *_echoing_select()
        // helpers return immediately when fd_printf < 0 and retrying would then spin.
        bool should_retry_io (
            int err
        ) const
        {
            if (err == EINTR)
                return true;
            return fd_printf >= 0 && (err == EAGAIN || err == EWOULDBLOCK);
        }

        // member data
        int  fd;
        int  fd_printf;
        static const std::streamsize max_putback = 4;
        static const std::streamsize out_buffer_size = 10000;
        static const std::streamsize in_buffer_size = 10000;
        char* out_buffer;
        char* in_buffer;

    };

    namespace impl
    {
        // Returns the file descriptor number stored in the DLIB_SUBPROCESS_DATA_FD
        // environment variable.  Asserts (via DLIB_CASSERT) that the variable is set and
        // that it holds nothing but a non-negative decimal integer that fits in an int.
        int get_data_fd()
        {
            const char* env_fd = getenv("DLIB_SUBPROCESS_DATA_FD");
            DLIB_CASSERT(env_fd != 0,"The DLIB_SUBPROCESS_DATA_FD environment variable is not set.");

            // atoi() can't report errors: it turns a malformed value into 0, which is
            // the descriptor number of stdin.  strtol() lets us reject bad input.
            char* end = 0;
            errno = 0;
            const long value = strtol(env_fd, &end, 10);
            const bool parsed_ok = errno == 0 &&
                                   end != env_fd &&
                                   *end == '\0' &&
                                   value >= 0 &&
                                   static_cast<long>(static_cast<int>(value)) == value;
            DLIB_CASSERT(parsed_ok,"The DLIB_SUBPROCESS_DATA_FD environment variable does not contain a valid file descriptor number.");
            return static_cast<int>(value);
        }

        std::iostream& get_data_iostream()
        {
            static filestreambuf dbuff(get_data_fd(), -1);
            static iostream out(&dbuff);
            return out;
        }
    }

// ---------------------------------------------------------------------------------------- 

    // Starts program_name as a child process.
    //
    // The child's stdout and stderr are redirected to stdout_pipe and stderr_pipe, and
    // its environment contains only DLIB_SUBPROCESS_DATA_FD (the child's end of
    // data_pipe) plus CUDA_VISIBLE_DEVICES if that is set in this process.  In the
    // parent, iosub is attached to the data pipe and stderr to the child's stderr pipe;
    // both stream buffers echo the child's stdout while they wait.
    //
    // Throws dlib::error if program_name doesn't exist, isn't executable, or fork() fails.
    subprocess_stream::
    subprocess_stream(const char* program_name) : stderr(NULL), iosub(NULL)
    {
        if (access(program_name, F_OK))
            throw dlib::error("Error: '" + std::string(program_name) + "' file does not exist.");
        if (access(program_name, X_OK))
            throw dlib::error("Error: '" + std::string(program_name) + "' file is not executable.");

        // Build everything execve() needs before forking.  If this process is
        // multithreaded then the child of fork() may only safely call
        // async-signal-safe functions, which rules out allocating strings there.
        std::ostringstream sout;
        sout << "DLIB_SUBPROCESS_DATA_FD="<<data_pipe.child_fd();
        const std::string data_fd_var = sout.str();

        const char* cudadevs = getenv("CUDA_VISIBLE_DEVICES");
        const std::string cuda_var = cudadevs ? std::string("CUDA_VISIBLE_DEVICES=") + cudadevs
                                              : std::string();

        char* argv[] = {const_cast<char*>(program_name), nullptr};
        // When CUDA_VISIBLE_DEVICES isn't set the second slot is the terminating null.
        char* envp[] = {const_cast<char*>(data_fd_var.c_str()),
                        cudadevs ? const_cast<char*>(cuda_var.c_str()) : nullptr,
                        nullptr};

        child_pid = fork();
        if (child_pid == -1) 
            throw dlib::error("Failed to start child process"); 

        if (child_pid == 0) 
        {   
            // In child process
            if (dup2(stdout_pipe.child_fd(), STDOUT_FILENO) == -1 ||
                dup2(stderr_pipe.child_fd(), STDERR_FILENO) == -1)
            {
                // Without the redirection the child would write to the parent's own
                // stdout/stderr, so treat this like a failed launch.
                _Exit(1);
            }
            stdout_pipe.close();
            stderr_pipe.close();

            execve(argv[0], argv, envp);

            // If launching the child didn't work then bail immediately so the parent
            // process has no chance to get tweaked out (*cough* MATLAB *cough*).
            _Exit(1);
        }
        else 
        {
            // In parent process
            close(data_pipe.child_fd());
            close(stdout_pipe.child_fd());
            close(stderr_pipe.child_fd());
            make_fd_non_blocking(data_pipe.parent_fd());
            make_fd_non_blocking(stdout_pipe.parent_fd());
            make_fd_non_blocking(stderr_pipe.parent_fd());
            inout_buf = std::unique_ptr<filestreambuf>(new filestreambuf(data_pipe.parent_fd(), stdout_pipe.parent_fd()));
            err_buf = std::unique_ptr<filestreambuf>(new filestreambuf(stderr_pipe.parent_fd(), stdout_pipe.parent_fd()));
            iosub.rdbuf(inout_buf.get());
            stderr.rdbuf(err_buf.get());
            // Tie both streams to iosub so that pending output on iosub gets flushed
            // before either stream does I/O.
            iosub.tie(&iosub);
            stderr.tie(&iosub);
        }
    }

// ----------------------------------------------------------------------------------------

    // Waits for the child process to finish.  Any error reported by wait() is printed
    // to std::cerr instead of being propagated, since letting an exception leave a
    // destructor would normally terminate the program.
    subprocess_stream::
    ~subprocess_stream() 
    {
        try
        {
            wait();
        }
        catch (const dlib::error& e)
        {
            std::cerr << e.what() << std::endl;
        }
        catch (const std::exception& e)
        {
            // wait() builds strings, so it can also fail with things like
            // std::bad_alloc.  Report those too rather than letting them escape.
            std::cerr << e.what() << std::endl;
        }
    }

// ----------------------------------------------------------------------------------------

    // Closes the data channel to the child, collects everything the child wrote to its
    // stderr, and waits for the child to exit.  Only the first call does anything.
    //
    // Throws dlib::error, with the collected stderr text appended, if the child's wait
    // status is non-zero or if it wrote anything to stderr.
    void subprocess_stream::
    wait() 
    {
        if (wait_called)
            return;

        wait_called = true;
        send_eof();

        std::ostringstream sout;
        sout << stderr.rdbuf();

        // If the user hit ctrl+c in MATLAB then ask the child to stop.
        try{check_for_matlab_ctrl_c();} catch(...) 
        { 
            kill(child_pid, SIGTERM);
        }

        // status starts at 0 so it is never read uninitialized: if waitpid() fails
        // without reaping the child we have no exit status and fall back to judging
        // the child only by its stderr output.
        int status = 0;
        pid_t reaped;
        do
        {
            reaped = waitpid(child_pid, &status, 0);
        } while (reaped == -1 && errno == EINTR); // a signal interrupted the wait, retry
        if (reaped == -1)
            status = 0;

        if (status != 0 || sout.str().size() != 0)
            throw dlib::error("Child process terminated with an error.\n" + sout.str());
    }

// ----------------------------------------------------------------------------------------

    // Flushes any output still buffered for the child and then closes the parent's end
    // of the data pipe.
    //
    // NOTE(review): wait() calls this function as well, so calling it directly and then
    // calling wait() runs ::close() on the same descriptor number twice, unless
    // data_pipe guards against that.  Confirm against the header.
    void subprocess_stream::
    send_eof()
    {
        inout_buf->sync();
        ::close(data_pipe.parent_fd());
    }

// ----------------------------------------------------------------------------------------

}