// Copyright (C) 2011 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_SQLiTE_H_ #define DLIB_SQLiTE_H_ #include "sqlite_abstract.h" #include <iostream> #include <vector> #include "../algs.h" #include <sqlite3.h> #include "../smart_pointers.h" #include "../serialize.h" #include <limits> // -------------------------------------------------------------------------------------------- namespace dlib { // -------------------------------------------------------------------------------------------- struct sqlite_error : public error { sqlite_error(const std::string& message): error(message) {} }; // -------------------------------------------------------------------------------------------- namespace impl { struct db_deleter { void operator()( sqlite3* db )const { sqlite3_close(db); } }; } // -------------------------------------------------------------------------------------------- class database : noncopyable { public: database( ) { } database ( const std::string& file ) { open(file); } bool is_open ( ) const { return db.get() != 0; } void open ( const std::string& file ) { filename = file; sqlite3* ptr = 0; int status = sqlite3_open(file.c_str(), &ptr); db.reset(ptr, impl::db_deleter()); if (status != SQLITE_OK) { throw sqlite_error(sqlite3_errmsg(db.get())); } } const std::string& get_database_filename ( ) const { // make sure requires clause is not broken DLIB_ASSERT(is_open() == true, "\t std::string database::get_database_filename()" << "\n\t The database must be opened before calling this routine." << "\n\t this: " << this ); return filename; } inline void exec ( const std::string& sql_statement ); int64 last_insert_rowid ( ) const { return sqlite3_last_insert_rowid(db.get()); } private: friend class statement; std::string filename; shared_ptr<sqlite3> db; }; // -------------------------------------------------------------------------------------------- class statement : noncopyable { public: statement ( database& db_, const std::string sql_statement ) : needs_reset(false), step_status(SQLITE_DONE), at_first_step(true), db(db_.db), stmt(0), sql_string(sql_statement) { // make sure requires clause is not broken DLIB_ASSERT(db_.is_open() == true, "\t statement::statement()" << "\n\t The database must be opened before calling this routine." << "\n\t this: " << this ); int status = sqlite3_prepare_v2(db.get(), sql_string.c_str(), sql_string.size()+1, &stmt, NULL); if (status != SQLITE_OK) { sqlite3_finalize(stmt); throw sqlite_error(sqlite3_errmsg(db.get())); } if (stmt == 0) { throw sqlite_error("Invalid SQL statement"); } } ~statement( ) { sqlite3_finalize(stmt); } void exec( ) { reset(); step_status = sqlite3_step(stmt); needs_reset = true; if (step_status != SQLITE_DONE && step_status != SQLITE_ROW) { if (step_status == SQLITE_ERROR) throw sqlite_error(sqlite3_errmsg(db.get())); else if (step_status == SQLITE_BUSY) throw sqlite_error("statement::exec() failed. SQLITE_BUSY returned"); else throw sqlite_error("statement::exec() failed."); } } bool move_next ( ) { if (step_status == SQLITE_ROW) { if (at_first_step) { at_first_step = false; return true; } else { step_status = sqlite3_step(stmt); if (step_status == SQLITE_DONE) { return false; } else if (step_status == SQLITE_ROW) { return true; } else { throw sqlite_error(sqlite3_errmsg(db.get())); } } } else { return false; } } unsigned long get_num_columns( ) const { if( (at_first_step==false) && (step_status==SQLITE_ROW)) { return sqlite3_column_count(stmt); } else { return 0; } } const std::string& get_sql_string ( ) const { return sql_string; } template <typename T> typename enable_if_c<std::numeric_limits<T>::is_integer>::type get_column ( unsigned long idx, T& item ) const { // unsigned ints won't fit into int all the time so put those into 64bit ints. if (sizeof(T) < sizeof(int) || (sizeof(T)==sizeof(int) && is_signed_type<T>::value)) item = get_column_as_int(idx); else item = get_column_as_int64(idx); } void get_column(unsigned long idx, std::string& item) const { item = get_column_as_text(idx); } void get_column(unsigned long idx, float& item ) const { item = get_column_as_double(idx); } void get_column(unsigned long idx, double& item ) const { item = get_column_as_double(idx); } void get_column(unsigned long idx, long double& item) const { item = get_column_as_double(idx); } template <typename T> typename disable_if_c<std::numeric_limits<T>::is_integer>::type get_column ( unsigned long idx, T& item ) const { get_column_as_object(idx, item); } const std::vector<char> get_column_as_blob ( unsigned long idx ) const { // make sure requires clause is not broken DLIB_ASSERT(idx < get_num_columns(), "\t std::vector<char> statement::get_column_as_blob()" << "\n\t Invalid column index." << "\n\t idx: " << idx << "\n\t this: " << this ); const char* data = static_cast<const char*>(sqlite3_column_blob(stmt, idx)); const int size = sqlite3_column_bytes(stmt, idx); return std::vector<char>(data, data+size); } template <typename T> void get_column_as_object ( unsigned long idx, T& item ) const { // make sure requires clause is not broken DLIB_ASSERT(idx < get_num_columns(), "\t void statement::get_column_as_object()" << "\n\t Invalid column index." << "\n\t idx: " << idx << "\n\t this: " << this ); const char* data = static_cast<const char*>(sqlite3_column_blob(stmt, idx)); const int size = sqlite3_column_bytes(stmt, idx); std::istringstream sin(std::string(data,size)); deserialize(item, sin); } const std::string get_column_as_text ( unsigned long idx ) const { // make sure requires clause is not broken DLIB_ASSERT(idx < get_num_columns(), "\t std::string statement::get_column_as_text()" << "\n\t Invalid column index." << "\n\t idx: " << idx << "\n\t this: " << this ); const char* data = reinterpret_cast<const char*>(sqlite3_column_text(stmt, idx)); if (data != 0) return std::string(data); else return std::string(); } double get_column_as_double ( unsigned long idx ) const { // make sure requires clause is not broken DLIB_ASSERT(idx < get_num_columns(), "\t double statement::get_column_as_double()" << "\n\t Invalid column index." << "\n\t idx: " << idx << "\n\t this: " << this ); return sqlite3_column_double(stmt, idx); } int get_column_as_int ( unsigned long idx ) const { // make sure requires clause is not broken DLIB_ASSERT(idx < get_num_columns(), "\t int statement::get_column_as_int()" << "\n\t Invalid column index." << "\n\t idx: " << idx << "\n\t this: " << this ); return sqlite3_column_int(stmt, idx); } int64 get_column_as_int64 ( unsigned long idx ) const { // make sure requires clause is not broken DLIB_ASSERT(idx < get_num_columns(), "\t int64 statement::get_column_as_int64()" << "\n\t Invalid column index." << "\n\t idx: " << idx << "\n\t this: " << this ); return sqlite3_column_int64(stmt, idx); } const std::string get_column_name ( unsigned long idx ) const { // make sure requires clause is not broken DLIB_ASSERT(idx < get_num_columns(), "\t std::string statement::get_column_name()" << "\n\t Invalid column index." << "\n\t idx: " << idx << "\n\t this: " << this ); return std::string(sqlite3_column_name(stmt,idx)); } unsigned long get_max_parameter_id ( ) const { return sqlite3_limit(db.get(), SQLITE_LIMIT_VARIABLE_NUMBER, -1); } unsigned long get_parameter_id ( const std::string& name ) const { return sqlite3_bind_parameter_index(stmt, name.c_str()); } template <typename T> typename enable_if_c<std::numeric_limits<T>::is_integer>::type bind ( unsigned long idx, const T& item ) { // unsigned ints won't fit into int all the time so put those into 64bit ints. if (sizeof(T) < sizeof(int) || (sizeof(T)==sizeof(int) && is_signed_type<T>::value)) bind_int(idx, item); else bind_int64(idx, item); } void bind(unsigned long idx, const std::string& item) { bind_text(idx, item); } void bind(unsigned long idx, const float& item ) { bind_double(idx, item); } void bind(unsigned long idx, const double& item ) { bind_double(idx, item); } void bind(unsigned long idx, const long double& item) { bind_double(idx, item); } template <typename T> typename disable_if_c<std::numeric_limits<T>::is_integer>::type bind ( unsigned long idx, const T& item ) { bind_object(idx, item); } void bind_blob ( unsigned long parameter_id, const std::vector<char>& item ) { // make sure requires clause is not broken DLIB_ASSERT(1 <= parameter_id && parameter_id <= get_max_parameter_id(), "\t void statement::bind_blob()" << "\n\t Invalid parameter id." << "\n\t parameter_id: " << parameter_id << "\n\t get_max_parameter_id(): " << get_max_parameter_id() << "\n\t this: " << this ); reset(); int status = sqlite3_bind_blob(stmt, parameter_id, &item[0], item.size(), SQLITE_TRANSIENT); if (status != SQLITE_OK) { throw sqlite_error(sqlite3_errmsg(db.get())); } } template <typename T> void bind_object ( unsigned long parameter_id, const T& item ) { // make sure requires clause is not broken DLIB_ASSERT(1 <= parameter_id && parameter_id <= get_max_parameter_id(), "\t void statement::bind_object()" << "\n\t Invalid parameter id." << "\n\t parameter_id: " << parameter_id << "\n\t get_max_parameter_id(): " << get_max_parameter_id() << "\n\t this: " << this ); reset(); std::ostringstream sout; serialize(item, sout); const std::string& str = sout.str(); int status = sqlite3_bind_blob(stmt, parameter_id, str.data(), str.size(), SQLITE_TRANSIENT); if (status != SQLITE_OK) { throw sqlite_error(sqlite3_errmsg(db.get())); } } void bind_double ( unsigned long parameter_id, const double& item ) { // make sure requires clause is not broken DLIB_ASSERT(1 <= parameter_id && parameter_id <= get_max_parameter_id(), "\t void statement::bind_double()" << "\n\t Invalid parameter id." << "\n\t parameter_id: " << parameter_id << "\n\t get_max_parameter_id(): " << get_max_parameter_id() << "\n\t this: " << this ); reset(); int status = sqlite3_bind_double(stmt, parameter_id, item); if (status != SQLITE_OK) { throw sqlite_error(sqlite3_errmsg(db.get())); } } void bind_int ( unsigned long parameter_id, const int& item ) { // make sure requires clause is not broken DLIB_ASSERT(1 <= parameter_id && parameter_id <= get_max_parameter_id(), "\t void statement::bind_int()" << "\n\t Invalid parameter id." << "\n\t parameter_id: " << parameter_id << "\n\t get_max_parameter_id(): " << get_max_parameter_id() << "\n\t this: " << this ); reset(); int status = sqlite3_bind_int(stmt, parameter_id, item); if (status != SQLITE_OK) { throw sqlite_error(sqlite3_errmsg(db.get())); } } void bind_int64 ( unsigned long parameter_id, const int64& item ) { // make sure requires clause is not broken DLIB_ASSERT(1 <= parameter_id && parameter_id <= get_max_parameter_id(), "\t void statement::bind_int64()" << "\n\t Invalid parameter id." << "\n\t parameter_id: " << parameter_id << "\n\t get_max_parameter_id(): " << get_max_parameter_id() << "\n\t this: " << this ); reset(); int status = sqlite3_bind_int64(stmt, parameter_id, item); if (status != SQLITE_OK) { throw sqlite_error(sqlite3_errmsg(db.get())); } } void bind_null ( unsigned long parameter_id ) { // make sure requires clause is not broken DLIB_ASSERT(1 <= parameter_id && parameter_id <= get_max_parameter_id(), "\t void statement::bind_null()" << "\n\t Invalid parameter id." << "\n\t parameter_id: " << parameter_id << "\n\t get_max_parameter_id(): " << get_max_parameter_id() << "\n\t this: " << this ); reset(); int status = sqlite3_bind_null(stmt, parameter_id); if (status != SQLITE_OK) { throw sqlite_error(sqlite3_errmsg(db.get())); } } void bind_text ( unsigned long parameter_id, const std::string& item ) { // make sure requires clause is not broken DLIB_ASSERT(1 <= parameter_id && parameter_id <= get_max_parameter_id(), "\t void statement::bind_text()" << "\n\t Invalid parameter id." << "\n\t parameter_id: " << parameter_id << "\n\t get_max_parameter_id(): " << get_max_parameter_id() << "\n\t this: " << this ); reset(); int status = sqlite3_bind_text(stmt, parameter_id, item.c_str(), -1, SQLITE_TRANSIENT); if (status != SQLITE_OK) { throw sqlite_error(sqlite3_errmsg(db.get())); } } private: void reset() { if (needs_reset) { if (sqlite3_reset(stmt) != SQLITE_OK) { step_status = SQLITE_DONE; throw sqlite_error(sqlite3_errmsg(db.get())); } needs_reset = false; step_status = SQLITE_DONE; at_first_step = true; } } bool needs_reset; // true if sqlite3_step() has been called more recently than sqlite3_reset() int step_status; bool at_first_step; shared_ptr<sqlite3> db; sqlite3_stmt* stmt; std::string sql_string; }; // -------------------------------------------------------------------------------------------- void database:: exec ( const std::string& sql_statement ) { // make sure requires clause is not broken DLIB_ASSERT(is_open() == true, "\t void database::exec()" << "\n\t The database must be opened before calling this routine." << "\n\t this: " << this ); statement(*this, sql_statement).exec(); } // -------------------------------------------------------------------------------------------- } #endif // DLIB_SQLiTE_H_